Merge changes I09e780af,I21367c66

* changes:
  Implement proper subtype advertising
  Implement proper discovery with subtypes
diff --git a/service/src/com/android/server/BpfNetMaps.java b/service/src/com/android/server/BpfNetMaps.java
index 84e581e..ec168dd 100644
--- a/service/src/com/android/server/BpfNetMaps.java
+++ b/service/src/com/android/server/BpfNetMaps.java
@@ -384,7 +384,6 @@
      * ALLOWLIST means the firewall denies all by default, uids must be explicitly allowed
      * DENYLIST means the firewall allows all by default, uids must be explicitly denyed
      */
-    @VisibleForTesting
     public boolean isFirewallAllowList(final int chain) {
         switch (chain) {
             case FIREWALL_CHAIN_DOZABLE:
@@ -745,6 +744,74 @@
         }
     }
 
+    /**
+     * Collect the uids whose sUidOwnerMap entry has the match bit of the given chain enabled.
+     *
+     * @param childChain target chain
+     * @return Set of uids whose uid owner map entry has the chain's match bit set
+     * @throws ErrnoException if reading sUidOwnerMap fails
+     */
+    private Set<Integer> getUidsMatchEnabled(final int childChain) throws ErrnoException {
+        final long match = getMatchByFirewallChain(childChain);
+        final Set<Integer> uids = new ArraySet<>();
+        synchronized (sUidOwnerMap) {
+            sUidOwnerMap.forEach((uid, val) -> {
+                if (val == null) {
+                    Log.wtf(TAG, "sUidOwnerMap entry was deleted while holding a lock");
+                } else if ((val.rule & match) != 0) {
+                    uids.add(uid.val);
+                }
+            });
+        }
+        return uids;
+    }
+
+    /**
+     * Get uids that have FIREWALL_RULE_ALLOW on allowlist chain.
+     * Allowlist means the firewall denies all by default, uids must be explicitly allowed.
+     *
+     * Note that uids that have FIREWALL_RULE_DENY on allowlist chain cannot be computed from the
+     * bpf map, since all the uids that do not have explicit FIREWALL_RULE_ALLOW rule in bpf map
+     * are determined to have FIREWALL_RULE_DENY.
+     *
+     * @param childChain target chain
+     * @return Set of uids
+     * @throws IllegalArgumentException if {@code childChain} is not an allowlist chain
+     * @throws ErrnoException if reading sUidOwnerMap fails
+     */
+    public Set<Integer> getUidsWithAllowRuleOnAllowListChain(final int childChain)
+            throws ErrnoException {
+        if (!isFirewallAllowList(childChain)) {
+            throw new IllegalArgumentException("getUidsWithAllowRuleOnAllowListChain is called with"
+                    + " denylist chain:" + childChain);
+        }
+        // Uids with FIREWALL_RULE_ALLOW on an allowlist chain have the chain's match enabled.
+        return getUidsMatchEnabled(childChain);
+    }
+
+    /**
+     * Get uids that have FIREWALL_RULE_DENY on denylist chain.
+     * Denylist means the firewall allows all by default, uids must be explicitly denied.
+     *
+     * Note that uids that have FIREWALL_RULE_ALLOW on denylist chain cannot be computed from the
+     * bpf map, since all the uids that do not have explicit FIREWALL_RULE_DENY rule in bpf map
+     * are determined to have FIREWALL_RULE_ALLOW.
+     *
+     * @param childChain target chain
+     * @return Set of uids
+     * @throws IllegalArgumentException if {@code childChain} is an allowlist chain
+     * @throws ErrnoException if reading sUidOwnerMap fails
+     */
+    public Set<Integer> getUidsWithDenyRuleOnDenyListChain(final int childChain)
+            throws ErrnoException {
+        if (isFirewallAllowList(childChain)) {
+            throw new IllegalArgumentException("getUidsWithDenyRuleOnDenyListChain is called with"
+                    + " allowlist chain:" + childChain);
+        }
+        // Uids with FIREWALL_RULE_DENY on a denylist chain have the chain's match enabled.
+        return getUidsMatchEnabled(childChain);
+    }
+
     /**
      * Add ingress interface filtering rules to a list of UIDs
      *
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index b5c9b0a..b17af99 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -1509,6 +1509,16 @@
                 throws SocketException, InterruptedIOException, ErrnoException {
             InetDiagMessage.destroyLiveTcpSockets(ranges, exemptUids);
         }
+
+        /**
+         * Call {@link InetDiagMessage#destroyLiveTcpSocketsByOwnerUids(Set)} to close sockets.
+         *
+         * @param ownerUids uids whose sockets are to be closed
+         */
+        public void destroyLiveTcpSocketsByOwnerUids(final Set<Integer> ownerUids)
+                throws SocketException, InterruptedIOException, ErrnoException {
+            InetDiagMessage.destroyLiveTcpSocketsByOwnerUids(ownerUids);
+        }
     }
 
     public ConnectivityService(Context context) {
@@ -12048,6 +12058,35 @@
         return rule;
     }
 
+    /**
+     * Destroy the live TCP sockets of uids that the given firewall chain blocks.
+     *
+     * For an allowlist chain, sockets owned by any uid in [FIRST_APPLICATION_UID, MAX_VALUE]
+     * are destroyed except those of uids with an explicit allow rule on the chain. For a
+     * denylist chain, only sockets owned by uids with an explicit deny rule are destroyed.
+     *
+     * NOTE(review): the "Locked" suffix suggests the caller must hold a lock, but none is
+     * visible at the call site in setFirewallChainEnabled; confirm the intended contract.
+     *
+     * @param chain target firewall chain
+     */
+    private void closeSocketsForFirewallChainLocked(final int chain)
+            throws ErrnoException, SocketException, InterruptedIOException {
+        if (mBpfNetMaps.isFirewallAllowList(chain)) {
+            // Allowlist means the firewall denies all by default, uids must be explicitly allowed
+            // So, close all non-system sockets owned by uids that are not explicitly allowed
+            final Set<Range<Integer>> ranges = new ArraySet<>();
+            ranges.add(new Range<>(Process.FIRST_APPLICATION_UID, Integer.MAX_VALUE));
+            final Set<Integer> exemptUids = mBpfNetMaps.getUidsWithAllowRuleOnAllowListChain(chain);
+            mDeps.destroyLiveTcpSockets(ranges, exemptUids);
+        } else {
+            // Denylist means the firewall allows all by default, uids must be explicitly denied
+            // So, close sockets owned by uids that are explicitly denied
+            final Set<Integer> ownerUids = mBpfNetMaps.getUidsWithDenyRuleOnDenyListChain(chain);
+            mDeps.destroyLiveTcpSocketsByOwnerUids(ownerUids);
+        }
+    }
+
     @Override
     public void setFirewallChainEnabled(final int chain, final boolean enable) {
         enforceNetworkStackOrSettingsPermission();
@@ -12057,6 +12084,16 @@
         } catch (ServiceSpecificException e) {
             throw new IllegalStateException(e);
         }
+
+        // Closing sockets is best-effort: a failure is logged and does not fail this call.
+        // Sockets are only closed when a chain is enabled, and only on U and above.
+        if (SdkLevel.isAtLeastU() && enable) {
+            try {
+                closeSocketsForFirewallChainLocked(chain);
+            } catch (ErrnoException | SocketException | InterruptedIOException e) {
+                Log.e(TAG, "Failed to close sockets after enabling chain (" + chain + ")", e);
+            }
+        }
     }
 
     @Override
diff --git a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
index 8b059e3..ee2f6bb 100644
--- a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
@@ -225,6 +225,7 @@
 import java.net.InetSocketAddress;
 import java.net.MalformedURLException;
 import java.net.Socket;
+import java.net.SocketException;
 import java.net.URL;
 import java.net.UnknownHostException;
 import java.nio.charset.StandardCharsets;
@@ -278,6 +279,7 @@
     // TODO(b/252972908): reset the original timer when aosp/2188755 is ramped up.
     private static final int LISTEN_ACTIVITY_TIMEOUT_MS = 30_000;
     private static final int NO_CALLBACK_TIMEOUT_MS = 100;
+    private static final int NETWORK_REQUEST_TIMEOUT_MS = 3000;
     private static final int SOCKET_TIMEOUT_MS = 100;
     private static final int NUM_TRIES_MULTIPATH_PREF_CHECK = 20;
     private static final long INTERVAL_MULTIPATH_PREF_CHECK_MS = 500;
@@ -3553,6 +3555,123 @@
         doTestFirewallBlocking(FIREWALL_CHAIN_OEM_DENY_3, DENYLIST);
     }
 
+    private void assertSocketOpen(final Socket socket) throws Exception {
+        mCtsNetUtils.testHttpRequest(socket);
+    }
+
+    private void assertSocketClosed(final Socket socket) throws Exception {
+        try {
+            mCtsNetUtils.testHttpRequest(socket);
+            fail("Socket is expected to be closed");
+        } catch (SocketException expected) {
+            // Expected: the firewall change closed the socket.
+        }
+    }
+
+    private static final boolean EXPECT_OPEN = false;
+    private static final boolean EXPECT_CLOSE = true;
+
+    /**
+     * Open a socket, set {@code rule} for {@code targetUid} on {@code chain}, enable the chain
+     * and check whether the socket is closed.
+     *
+     * The chain enabled state and the uid rule are restored, and the socket is closed, after the
+     * test.
+     *
+     * @param chain firewall chain to test
+     * @param rule firewall rule to set for {@code targetUid} on {@code chain}
+     * @param targetUid uid the rule is applied to
+     * @param expectClose whether the socket opened by this process is expected to be closed
+     */
+    private void doTestFirewallCloseSocket(final int chain, final int rule, final int targetUid,
+            final boolean expectClose) {
+        runWithShellPermissionIdentity(() -> {
+            // Firewall chain status will be restored after the test.
+            final boolean wasChainEnabled = mCm.getFirewallChainEnabled(chain);
+            final int previousUidFirewallRule = mCm.getUidFirewallRule(chain, targetUid);
+            // Holder so that the cleanup can close a socket that is opened inside the test.
+            final Socket[] socketHolder = new Socket[1];
+            testAndCleanup(() -> {
+                // Disable the chain before connecting so that the connection itself is not
+                // blocked if the chain was already enabled when the test started.
+                mCm.setFirewallChainEnabled(chain, false /* enable */);
+                final Socket socket = new Socket(TEST_HOST, HTTP_PORT);
+                socketHolder[0] = socket;
+                socket.setSoTimeout(NETWORK_REQUEST_TIMEOUT_MS);
+                assertSocketOpen(socket);
+
+                try {
+                    mCm.setUidFirewallRule(chain, targetUid, rule);
+                } catch (IllegalStateException ignored) {
+                    // Removing match causes an exception when the rule entry for the uid does
+                    // not exist. But this is fine and can be ignored.
+                }
+                mCm.setFirewallChainEnabled(chain, true /* enable */);
+
+                if (expectClose) {
+                    assertSocketClosed(socket);
+                } else {
+                    assertSocketOpen(socket);
+                }
+            }, /* cleanup */ () -> {
+                // Restore the global chain status
+                mCm.setFirewallChainEnabled(chain, wasChainEnabled);
+            }, /* cleanup */ () -> {
+                // Restore the uid firewall rule status
+                try {
+                    mCm.setUidFirewallRule(chain, targetUid, previousUidFirewallRule);
+                } catch (IllegalStateException ignored) {
+                    // Removing match causes an exception when the rule entry for the uid does
+                    // not exist. But this is fine and can be ignored.
+                }
+            }, /* cleanup */ () -> {
+                if (socketHolder[0] != null) {
+                    socketHolder[0].close();
+                }
+            });
+        }, NETWORK_SETTINGS);
+    }
+
+    @Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @ConnectivityModuleTest
+    public void testFirewallCloseSocketAllowlistChainAllow() {
+        doTestFirewallCloseSocket(FIREWALL_CHAIN_DOZABLE, FIREWALL_RULE_ALLOW,
+                Process.myUid(), EXPECT_OPEN);
+    }
+
+    @Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @ConnectivityModuleTest
+    public void testFirewallCloseSocketAllowlistChainDeny() {
+        doTestFirewallCloseSocket(FIREWALL_CHAIN_DOZABLE, FIREWALL_RULE_DENY,
+                Process.myUid(), EXPECT_CLOSE);
+    }
+
+    @Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @ConnectivityModuleTest
+    public void testFirewallCloseSocketAllowlistChainOtherUid() {
+        doTestFirewallCloseSocket(FIREWALL_CHAIN_DOZABLE, FIREWALL_RULE_ALLOW,
+                Process.myUid() + 1, EXPECT_CLOSE);
+        doTestFirewallCloseSocket(FIREWALL_CHAIN_DOZABLE, FIREWALL_RULE_DENY,
+                Process.myUid() + 1, EXPECT_CLOSE);
+    }
+
+    @Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @ConnectivityModuleTest
+    public void testFirewallCloseSocketDenylistChainAllow() {
+        doTestFirewallCloseSocket(FIREWALL_CHAIN_STANDBY, FIREWALL_RULE_ALLOW,
+                Process.myUid(), EXPECT_OPEN);
+    }
+
+    @Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @ConnectivityModuleTest
+    public void testFirewallCloseSocketDenylistChainDeny() {
+        doTestFirewallCloseSocket(FIREWALL_CHAIN_STANDBY, FIREWALL_RULE_DENY,
+                Process.myUid(), EXPECT_CLOSE);
+    }
+
+    @Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @ConnectivityModuleTest
+    public void testFirewallCloseSocketDenylistChainOtherUid() {
+        doTestFirewallCloseSocket(FIREWALL_CHAIN_STANDBY, FIREWALL_RULE_ALLOW,
+                Process.myUid() + 1, EXPECT_OPEN);
+        doTestFirewallCloseSocket(FIREWALL_CHAIN_STANDBY, FIREWALL_RULE_DENY,
+                Process.myUid() + 1, EXPECT_OPEN);
+    }
+
     private void assumeTestSApis() {
         // Cannot use @IgnoreUpTo(Build.VERSION_CODES.R) because this test also requires API 31
         // shims, and @IgnoreUpTo does not check that.
diff --git a/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java b/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java
index 0c4f794..ce789fc 100644
--- a/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java
+++ b/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java
@@ -426,7 +426,7 @@
                 .build();
     }
 
-    private void testHttpRequest(Socket s) throws IOException {
+    public void testHttpRequest(Socket s) throws IOException {
         OutputStream out = s.getOutputStream();
         InputStream in = s.getInputStream();
 
@@ -434,7 +434,11 @@
         byte[] responseBytes = new byte[4096];
         out.write(requestBytes);
-        in.read(responseBytes);
-        assertTrue(new String(responseBytes, "UTF-8").startsWith("HTTP/1.0 204 No Content\r\n"));
+        // Decode only the bytes actually read (read() returns -1 at end of stream) so that the
+        // failure message does not include the unused, zero-filled tail of the buffer.
+        final int bytesRead = in.read(responseBytes);
+        final String response = new String(responseBytes, 0, Math.max(bytesRead, 0), "UTF-8");
+        assertTrue("Received unexpected response: " + response,
+                response.startsWith("HTTP/1.0 204 No Content\r\n"));
     }
 
     private Socket getBoundSocket(Network network, String host, int port) throws IOException {
diff --git a/tests/unit/java/com/android/server/BpfNetMapsTest.java b/tests/unit/java/com/android/server/BpfNetMapsTest.java
index d189848..19fa41d 100644
--- a/tests/unit/java/com/android/server/BpfNetMapsTest.java
+++ b/tests/unit/java/com/android/server/BpfNetMapsTest.java
@@ -66,6 +66,7 @@
 import android.os.Build;
 import android.os.ServiceSpecificException;
 import android.system.ErrnoException;
+import android.util.ArraySet;
 import android.util.IndentingPrintWriter;
 
 import androidx.test.filters.SmallTest;
@@ -1151,4 +1152,42 @@
         mCookieTagMap.updateEntry(new CookieTagMapKey(123), new CookieTagMapValue(456, 0x789));
         assertDumpContains(getDump(), "cookie=123 tag=0x789 uid=456");
     }
+
+    @Test
+    public void testGetUids() throws ErrnoException {
+        final int uid0 = TEST_UIDS[0];
+        final int uid1 = TEST_UIDS[1];
+        final long match0 = DOZABLE_MATCH | POWERSAVE_MATCH;
+        final long match1 = DOZABLE_MATCH | STANDBY_MATCH;
+        mUidOwnerMap.updateEntry(new S32(uid0), new UidOwnerValue(NULL_IIF, match0));
+        mUidOwnerMap.updateEntry(new S32(uid1), new UidOwnerValue(NULL_IIF, match1));
+
+        // Allowlist chains: uids whose entry has the chain's match bit have an allow rule.
+        assertEquals(new ArraySet<>(List.of(uid0, uid1)),
+                mBpfNetMaps.getUidsWithAllowRuleOnAllowListChain(FIREWALL_CHAIN_DOZABLE));
+        assertEquals(new ArraySet<>(List.of(uid0)),
+                mBpfNetMaps.getUidsWithAllowRuleOnAllowListChain(FIREWALL_CHAIN_POWERSAVE));
+
+        // Denylist chains: uids whose entry has the chain's match bit have a deny rule.
+        assertEquals(new ArraySet<>(List.of(uid1)),
+                mBpfNetMaps.getUidsWithDenyRuleOnDenyListChain(FIREWALL_CHAIN_STANDBY));
+        // No uid has the OEM_DENY_1 match bit set, so the result is empty.
+        assertEquals(new ArraySet<>(),
+                mBpfNetMaps.getUidsWithDenyRuleOnDenyListChain(FIREWALL_CHAIN_OEM_DENY_1));
+    }
+
+    @Test
+    public void testGetUidsIllegalArgument() {
+        final Class<IllegalArgumentException> expected = IllegalArgumentException.class;
+        // Querying deny rules on an allowlist chain is rejected.
+        assertThrows(expected,
+                () -> mBpfNetMaps.getUidsWithDenyRuleOnDenyListChain(FIREWALL_CHAIN_DOZABLE));
+        assertThrows(expected,
+                () -> mBpfNetMaps.getUidsWithDenyRuleOnDenyListChain(FIREWALL_CHAIN_POWERSAVE));
+        // Querying allow rules on a denylist chain is rejected.
+        assertThrows(expected,
+                () -> mBpfNetMaps.getUidsWithAllowRuleOnAllowListChain(FIREWALL_CHAIN_STANDBY));
+        assertThrows(expected,
+                () -> mBpfNetMaps.getUidsWithAllowRuleOnAllowListChain(FIREWALL_CHAIN_OEM_DENY_1));
+    }
 }
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 9d7b21f..31f3124 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -2173,6 +2173,11 @@
                 final Set<Integer> exemptUids) {
             // This function is empty since the invocation of this method is verified by mocks
         }
+
+        @Override
+        public void destroyLiveTcpSocketsByOwnerUids(final Set<Integer> ownerUids) {
+            // Intentionally empty: calls are checked with verify(mDeps) in the tests.
+        }
     }
 
     private class AutomaticOnOffKeepaliveTrackerDependencies
@@ -10269,6 +10274,62 @@
         }
     }
 
+    /**
+     * Enable then disable {@code chain} and verify which socket destroy method is called.
+     *
+     * Expects mBpfNetMaps to be stubbed to return {TEST_PACKAGE_UID, TEST_PACKAGE_UID2} for
+     * the uid queries, and isFirewallAllowList to return {@code isAllowList}.
+     *
+     * @param chain firewall chain to enable and disable
+     * @param isAllowList whether {@code chain} is treated as an allowlist chain
+     */
+    private void doTestSetFirewallChainEnabledCloseSocket(final int chain,
+            final boolean isAllowList) throws Exception {
+        reset(mDeps);
+
+        mCm.setFirewallChainEnabled(chain, true /* enabled */);
+        final Set<Integer> uids =
+                new ArraySet<>(List.of(TEST_PACKAGE_UID, TEST_PACKAGE_UID2));
+        if (isAllowList) {
+            // Allowlist chain: sockets of all app uids except the allowed ones are destroyed.
+            final Set<Range<Integer>> ranges = new ArraySet<>(
+                    List.of(new Range<>(Process.FIRST_APPLICATION_UID, Integer.MAX_VALUE)));
+            verify(mDeps).destroyLiveTcpSockets(ranges, uids);
+        } else {
+            // Denylist chain: only sockets of the denied uids are destroyed.
+            verify(mDeps).destroyLiveTcpSocketsByOwnerUids(uids);
+        }
+
+        // Disabling a chain must not cause any further mDeps interaction (no socket destroy).
+        mCm.setFirewallChainEnabled(chain, false /* enabled */);
+        verifyNoMoreInteractions(mDeps);
+    }
+
+    @Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
+    public void testSetFirewallChainEnabledCloseSocket() throws Exception {
+        doReturn(new ArraySet<>(List.of(TEST_PACKAGE_UID, TEST_PACKAGE_UID2)))
+                .when(mBpfNetMaps)
+                .getUidsWithDenyRuleOnDenyListChain(anyInt());
+        doReturn(new ArraySet<>(List.of(TEST_PACKAGE_UID, TEST_PACKAGE_UID2)))
+                .when(mBpfNetMaps)
+                .getUidsWithAllowRuleOnAllowListChain(anyInt());
+
+        final boolean allowlist = true;
+        final boolean denylist = false;
+
+        doReturn(true).when(mBpfNetMaps).isFirewallAllowList(anyInt());
+        doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_DOZABLE, allowlist);
+        doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_POWERSAVE, allowlist);
+        doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_RESTRICTED, allowlist);
+        doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_LOW_POWER_STANDBY, allowlist);
+
+        doReturn(false).when(mBpfNetMaps).isFirewallAllowList(anyInt());
+        doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_STANDBY, denylist);
+        doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_OEM_DENY_1, denylist);
+        doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_OEM_DENY_2, denylist);
+        doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_OEM_DENY_3, denylist);
+    }
+
     private void doTestReplaceFirewallChain(final int chain) {
         final int[] uids = new int[] {1001, 1002};
         mCm.replaceFirewallChain(chain, uids);