Merge "Do not expect validated when connect to Wi-Fi for upstream type None" into main
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 2872751..5b415c8 100644
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -9785,10 +9785,10 @@
                 newLp != null ? newLp.getAllInterfaceNames() : null);
 
         for (final String iface : interfaceDiff.added) {
-            addLocalAddressesToBpfMap(iface, MULTICAST_AND_BROADCAST_PREFIXES);
+            addLocalAddressesToBpfMap(iface, MULTICAST_AND_BROADCAST_PREFIXES, newLp);
         }
         for (final String iface : interfaceDiff.removed) {
-            removeLocalAddressesFromBpfMap(iface, MULTICAST_AND_BROADCAST_PREFIXES);
+            removeLocalAddressesFromBpfMap(iface, MULTICAST_AND_BROADCAST_PREFIXES, oldLp);
         }
 
         final CompareResult<LinkAddress> linkAddressDiff = new CompareResult<>(
@@ -9813,7 +9813,7 @@
         // If newLp is not null, adding local network prefixes using interface name of newLp
         if (newLp != null) {
             addLocalAddressesToBpfMap(newLp.getInterfaceName(),
-                    new ArrayList<>(unicastLocalPrefixesToBeAdded));
+                    new ArrayList<>(unicastLocalPrefixesToBeAdded), newLp);
         }
         if (oldLp != null) {
             // excluding removal of ip prefixes that needs to be added for newLp, but also
@@ -9824,7 +9824,7 @@
             }
             // removing ip local network prefixes because of change in link addresses.
             removeLocalAddressesFromBpfMap(oldLp.getInterfaceName(),
-                    new ArrayList<>(unicastLocalPrefixesToBeRemoved));
+                    new ArrayList<>(unicastLocalPrefixesToBeRemoved), oldLp);
         }
 
     }
@@ -9859,10 +9859,15 @@
      * Adds list of prefixes(addresses) to local network access map.
      * @param iface interface name
      * @param prefixes list of prefixes/addresses
+     * @param lp LinkProperties supplying DNS servers to allow; may be null
      */
-    private void addLocalAddressesToBpfMap(final String iface, final List<IpPrefix> prefixes) {
+    private void addLocalAddressesToBpfMap(final String iface, final List<IpPrefix> prefixes,
+                                           @Nullable final LinkProperties lp) {
         if (!BpfNetMaps.isAtLeast25Q2()) return;
+
         for (IpPrefix prefix : prefixes) {
+            // Allow local DNS servers inside this prefix before adding its block rule.
+            addLocalDnsesToBpfMap(iface, prefix, lp);
             /*
             Prefix length is used by LPM trie map(local_net_access_map) for performing longest
             prefix matching, this length represents the maximum number of bits used for matching.
@@ -9884,18 +9889,98 @@
      * Removes list of prefixes(addresses) from local network access map.
      * @param iface interface name
      * @param prefixes list of prefixes/addresses
+     * @param lp LinkProperties supplying DNS servers whose allow rules to remove; may be null
      */
-    private void removeLocalAddressesFromBpfMap(final String iface, final List<IpPrefix> prefixes) {
+    private void removeLocalAddressesFromBpfMap(final String iface, final List<IpPrefix> prefixes,
+                                                @Nullable final LinkProperties lp) {
         if (!BpfNetMaps.isAtLeast25Q2()) return;
+
         for (IpPrefix prefix : prefixes) {
             // The reasoning for prefix length is explained in addLocalAddressesToBpfMap()
             final int prefixLengthConstant = (prefix.isIPv4() ? (32 + 96) : 32);
             mBpfNetMaps.removeLocalNetAccess(prefixLengthConstant
                     + prefix.getPrefixLength(), iface, prefix.getAddress(), 0, 0);
+
+            // Remove allow rules for local DNS servers in this prefix only after the block rule
+            // for the prefix has been removed.
+            removeLocalDnsesFromBpfMap(iface, prefix, lp);
         }
     }
 
     /**
+     * Port on which allow rules are added for local DNS servers (UDP and TCP).
+     */
+    private static final int LOCAL_DNS_PORT = 53;
+
+    /**
+     * Adds allow rules for the DNS servers of {@code lp} that fall inside {@code prefix}.
+     *
+     * One rule is added per protocol (UDP and TCP) on {@link #LOCAL_DNS_PORT}. The rules
+     * use the full-match prefix length, which is longer than the block rule's prefix length
+     * for {@code prefix}, so they are more specific under longest-prefix matching. Callers add
+     * them before adding that block rule.
+     *
+     * NOTE(review): these rules are only updated alongside prefix additions/removals; confirm
+     * that a change to the DNS servers alone (same prefixes) is handled elsewhere.
+     * @param iface interface name
+     * @param prefix local prefix whose block rule is about to be added
+     * @param lp LinkProperties supplying the DNS servers; nothing is added if null
+     */
+    private void addLocalDnsesToBpfMap(final String iface, final IpPrefix prefix,
+            @Nullable final LinkProperties lp) {
+        if (!BpfNetMaps.isAtLeast25Q2() || lp == null) return;
+
+        for (InetAddress dnsServer : lp.getDnsServers()) {
+            // Only DNS servers inside the local prefix need an exception to its block rule.
+            if (prefix.contains(dnsServer)) {
+                mBpfNetMaps.addLocalNetAccess(getLocalNetFullMatchPrefixLen(), iface,
+                        dnsServer, IPPROTO_UDP, LOCAL_DNS_PORT, true);
+                mBpfNetMaps.addLocalNetAccess(getLocalNetFullMatchPrefixLen(), iface,
+                        dnsServer, IPPROTO_TCP, LOCAL_DNS_PORT, true);
+            }
+        }
+    }
+
+    /**
+     * Removes the allow rules that addLocalDnsesToBpfMap() adds for the DNS servers of
+     * {@code lp} that fall inside {@code prefix}. Callers remove these rules after removing the
+     * block rule for {@code prefix}.
+     *
+     * NOTE(review): the allow rule does not record the prefix, so if a DNS server lies inside
+     * several local prefixes, removing any one of them appears to drop the shared allow rule.
+     * @param iface interface name
+     * @param prefix local prefix whose block rule has just been removed
+     * @param lp LinkProperties supplying the DNS servers; nothing is removed if null
+     */
+    private void removeLocalDnsesFromBpfMap(final String iface, final IpPrefix prefix,
+            @Nullable final LinkProperties lp) {
+        if (!BpfNetMaps.isAtLeast25Q2() || lp == null) return;
+
+        for (InetAddress dnsServer : lp.getDnsServers()) {
+            // Mirrors the selection made in addLocalDnsesToBpfMap().
+            if (prefix.contains(dnsServer)) {
+                mBpfNetMaps.removeLocalNetAccess(getLocalNetFullMatchPrefixLen(), iface,
+                        dnsServer, IPPROTO_UDP, LOCAL_DNS_PORT);
+                mBpfNetMaps.removeLocalNetAccess(getLocalNetFullMatchPrefixLen(), iface,
+                        dnsServer, IPPROTO_TCP, LOCAL_DNS_PORT);
+            }
+        }
+    }
+
+    /**
+     * Returns the prefix length that makes a local network access rule match exactly one
+     * interface, address, protocol and port, i.e. the sum of the interface, address and
+     * port/protocol bit lengths. IPv4 (as IPv4-mapped) and IPv6 addresses both occupy 128
+     * bits, so the same value applies to DNS servers of either address family.
+     */
+    private int getLocalNetFullMatchPrefixLen() {
+        final int ifaceLen = 32; // bit length of interface
+        final int inetAddressLen = 32 + 96; // bit length of an address (IPv4 is IPv4-mapped)
+        final int portProtocolLen = 32; // 16 for port + 16 for protocol
+        return ifaceLen + inetAddressLen + portProtocolLen;
+    }
+
+    /**
      * Have netd update routes from oldLp to newLp.
      * @return true if routes changed between oldLp and newLp
      */
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSBpfNetMapsTest.kt b/tests/unit/java/com/android/server/connectivityservice/CSBpfNetMapsTest.kt
index 8155fd0..06cb7ee 100644
--- a/tests/unit/java/com/android/server/connectivityservice/CSBpfNetMapsTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/CSBpfNetMapsTest.kt
@@ -21,7 +21,9 @@
 import android.net.ConnectivityManager.RESTRICT_BACKGROUND_STATUS_DISABLED
 import android.net.ConnectivityManager.RESTRICT_BACKGROUND_STATUS_ENABLED
 import android.net.ConnectivityManager.RESTRICT_BACKGROUND_STATUS_WHITELISTED
+import android.net.InetAddresses
+import android.net.LinkProperties
 import android.os.Build
 import androidx.test.filters.SmallTest
 import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter
@@ -33,11 +35,39 @@
 import org.junit.runner.RunWith
 import org.mockito.ArgumentMatchers.anyBoolean
 import org.mockito.ArgumentMatchers.anyInt
+import org.mockito.ArgumentMatchers.eq
+import org.mockito.Mockito.atLeastOnce
 import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.inOrder
 import org.mockito.Mockito.never
 import org.mockito.Mockito.verify
 
+// The test LinkProperties carry no link addresses, so their local prefixes presumably come
+// only from the multicast/broadcast prefixes added per interface. LOCAL_DNS is a multicast
+// address and is expected to fall inside them; NON_LOCAL_DNS is a public unicast address.
+internal val LOCAL_DNS = InetAddresses.parseNumericAddress("224.0.1.2")
+internal val NON_LOCAL_DNS = InetAddresses.parseNumericAddress("76.76.75.75")
+
+private const val IFNAME_1 = "wlan1"
+private const val IFNAME_2 = "wlan2"
+private const val PORT_53 = 53
+private const val PROTOCOL_TCP = 6
+private const val PROTOCOL_UDP = 17
+
+// Prefix length of a local network access rule matching one interface, address, protocol
+// and port exactly: 32 (interface) + 128 (address) + 32 (port/protocol) bits.
+private const val DNS_RULE_PREFIX_LEN = 192
+
+private val lpWithNoLocalDns = LinkProperties().apply {
+    addDnsServer(NON_LOCAL_DNS)
+    interfaceName = IFNAME_1
+}
+
+private val lpWithLocalDns = LinkProperties().apply {
+    addDnsServer(LOCAL_DNS)
+    interfaceName = IFNAME_2
+}
+
 @DevSdkIgnoreRunner.MonitorThreadLeak
 @RunWith(DevSdkIgnoreRunner::class)
 @SmallTest
@@ -69,6 +99,81 @@
         }
     }
 
+    @IgnoreUpTo(Build.VERSION_CODES.VANILLA_ICE_CREAM)
+    @Test
+    fun testLocalPrefixesUpdatedInBpfMap() {
+        // Connect Wi-Fi network with non-local dns.
+        val wifiAgent = Agent(nc = defaultNc(), lp = lpWithNoLocalDns)
+        wifiAgent.connect()
+
+        // Verify that block rule is added to BpfMap for local prefixes.
+        verify(bpfNetMaps, atLeastOnce()).addLocalNetAccess(any(), eq(IFNAME_1),
+            any(), eq(0), eq(0), eq(false))
+
+        wifiAgent.disconnect()
+        val cellAgent = Agent(nc = defaultNc(), lp = lpWithLocalDns)
+        cellAgent.connect()
+
+        // Verify that block rule is removed from BpfMap for local prefixes.
+        verify(bpfNetMaps, atLeastOnce()).removeLocalNetAccess(any(), eq(IFNAME_1),
+            any(), eq(0), eq(0))
+
+        cellAgent.disconnect()
+    }
+
+    @IgnoreUpTo(Build.VERSION_CODES.VANILLA_ICE_CREAM)
+    @Test
+    fun testLocalDnsNotUpdatedInBpfMap() {
+        // Connect Wi-Fi network with non-local dns.
+        val wifiAgent = Agent(nc = defaultNc(), lp = lpWithNoLocalDns)
+        wifiAgent.connect()
+
+        // Verify that no allow rule is added to BpfMap since there is no local DNS server.
+        verify(bpfNetMaps, never()).addLocalNetAccess(any(), any(), any(), any(), any(),
+            eq(true))
+
+        wifiAgent.disconnect()
+        val cellAgent = Agent(nc = defaultNc(), lp = lpWithLocalDns)
+        cellAgent.connect()
+
+        // Verify that no port 53 allow rule is removed for IFNAME_1 on network change,
+        // because none was added for its non-local DNS server.
+        verify(bpfNetMaps, never()).removeLocalNetAccess(eq(DNS_RULE_PREFIX_LEN),
+            eq(IFNAME_1), eq(NON_LOCAL_DNS), any(), eq(PORT_53))
+
+        cellAgent.disconnect()
+    }
+
+    @IgnoreUpTo(Build.VERSION_CODES.VANILLA_ICE_CREAM)
+    @Test
+    fun testLocalDnsUpdatedInBpfMap() {
+        // Connect Wi-Fi network with one local DNS server.
+        val wifiAgent = Agent(nc = defaultNc(), lp = lpWithLocalDns)
+        wifiAgent.connect()
+
+        // Verify that allow rule is added to BpfMap for local dns at port 53,
+        // for TCP(=6) protocol
+        verify(bpfNetMaps, atLeastOnce()).addLocalNetAccess(eq(DNS_RULE_PREFIX_LEN),
+            eq(IFNAME_2), eq(LOCAL_DNS), eq(PROTOCOL_TCP), eq(PORT_53), eq(true))
+        // And for UDP(=17) protocol
+        verify(bpfNetMaps, atLeastOnce()).addLocalNetAccess(eq(DNS_RULE_PREFIX_LEN),
+            eq(IFNAME_2), eq(LOCAL_DNS), eq(PROTOCOL_UDP), eq(PORT_53), eq(true))
+
+        wifiAgent.disconnect()
+        val cellAgent = Agent(nc = defaultNc(), lp = lpWithNoLocalDns)
+        cellAgent.connect()
+
+        // Verify that allow rule is removed for local dns on network change,
+        // for TCP(=6) protocol
+        verify(bpfNetMaps, atLeastOnce()).removeLocalNetAccess(eq(DNS_RULE_PREFIX_LEN),
+            eq(IFNAME_2), eq(LOCAL_DNS), eq(PROTOCOL_TCP), eq(PORT_53))
+        // And for UDP(=17) protocol
+        verify(bpfNetMaps, atLeastOnce()).removeLocalNetAccess(eq(DNS_RULE_PREFIX_LEN),
+            eq(IFNAME_2), eq(LOCAL_DNS), eq(PROTOCOL_UDP), eq(PORT_53))
+
+        cellAgent.disconnect()
+    }
+
     private fun mockDataSaverStatus(status: Int) {
         doReturn(status).`when`(context.networkPolicyManager).getRestrictBackgroundStatus(anyInt())
         // While the production code dispatches the intent on the handler thread,