Destroy sockets of apps under background firewall chain restriction

Currently, CS destroys sockets of apps under firewall chain restriction
only when the firewall chains get enabled by setFirewallChainEnabled.

This CL updates CS to also destroy sockets of apps when the
setUidFirewallRule and replaceFirewallChain update uid rule and apps
lose network access due to FIREWALL_CHAIN_BACKGROUND.
This socket destruction is delayed until the cell modem is up in the
same way that CS destroys sockets of frozen apps.

Test: CSSocketDestroyTest
Bug: 300681644
Change-Id: I6a6fb5d6d7b71c0949a4195b84d30efef95bff37
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 519391f..ca99935 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -36,6 +36,7 @@
 import static android.net.ConnectivityDiagnosticsManager.DataStallReport.KEY_TCP_PACKET_FAIL_RATE;
 import static android.net.ConnectivityManager.ACTION_RESTRICT_BACKGROUND_CHANGED;
 import static android.net.ConnectivityManager.BLOCKED_METERED_REASON_MASK;
+import static android.net.ConnectivityManager.BLOCKED_REASON_APP_BACKGROUND;
 import static android.net.ConnectivityManager.BLOCKED_REASON_LOCKDOWN_VPN;
 import static android.net.ConnectivityManager.BLOCKED_REASON_NONE;
 import static android.net.ConnectivityManager.BLOCKED_REASON_NETWORK_RESTRICTED;
@@ -878,6 +879,18 @@
     private static final int EVENT_UID_FROZEN_STATE_CHANGED = 61;
 
     /**
+     * Event to update firewall socket destroy reasons for uids.
+     * obj = List of Pair(uid, socketDestroyReasons)
+     */
+    private static final int EVENT_UPDATE_FIREWALL_DESTROY_SOCKET_REASONS = 62;
+
+    /**
+     * Event to clear firewall socket destroy reasons for all uids.
+     * arg1 = socketDestroyReason
+     */
+    private static final int EVENT_CLEAR_FIREWALL_DESTROY_SOCKET_REASONS = 63;
+
+    /**
      * Argument for {@link #EVENT_PROVISIONING_NOTIFICATION} to indicate that the notification
      * should be shown.
      */
@@ -1033,6 +1046,7 @@
 
     private static final int DESTROY_SOCKET_REASON_NONE = 0;
     private static final int DESTROY_SOCKET_REASON_FROZEN = 1 << 0;
+    private static final int DESTROY_SOCKET_REASON_FIREWALL_BACKGROUND = 1 << 1;
 
     // Flag to drop packets to VPN addresses ingressing via non-VPN interfaces.
     private final boolean mIngressToVpnAddressFiltering;
@@ -3404,6 +3418,10 @@
         return !mNetworkActivityTracker.isDefaultNetworkActive();
     }
 
+    private boolean shouldTrackFirewallDestroySocketReasons() {
+        return mDeps.isAtLeastV();
+    }
+
     private void updateDestroySocketReasons(final int uid, final int reason,
             final boolean addReason) {
         final int destroyReasons = mDestroySocketPendingUids.get(uid, DESTROY_SOCKET_REASON_NONE);
@@ -3432,6 +3450,43 @@
         }
     }
 
+    private void handleUpdateFirewallDestroySocketReasons(
+            List<Pair<Integer, Integer>> reasonsList) {
+        if (!shouldTrackFirewallDestroySocketReasons()) {
+            Log.wtf(TAG, "handleUpdateFirewallDestroySocketReasons is called unexpectedly");
+            return;
+        }
+        ensureRunningOnConnectivityServiceThread();
+
+        for (Pair<Integer, Integer> uidSocketDestroyReasons: reasonsList) {
+            final int uid = uidSocketDestroyReasons.first;
+            final int reasons = uidSocketDestroyReasons.second;
+            final boolean destroyByFirewallBackground =
+                    (reasons & DESTROY_SOCKET_REASON_FIREWALL_BACKGROUND)
+                            != DESTROY_SOCKET_REASON_NONE;
+            updateDestroySocketReasons(uid, DESTROY_SOCKET_REASON_FIREWALL_BACKGROUND,
+                    destroyByFirewallBackground);
+        }
+
+        if (!mDelayDestroySockets || !isCellNetworkIdle()) {
+            destroyPendingSockets();
+        }
+    }
+
+    private void handleClearFirewallDestroySocketReasons(final int reason) {
+        if (!shouldTrackFirewallDestroySocketReasons()) {
+            Log.wtf(TAG, "handleClearFirewallDestroySocketReasons is called uexpectedly");
+            return;
+        }
+        ensureRunningOnConnectivityServiceThread();
+
+        // Unset reason from all pending uids
+        for (int i = mDestroySocketPendingUids.size() - 1; i >= 0; i--) {
+            final int uid = mDestroySocketPendingUids.keyAt(i);
+            updateDestroySocketReasons(uid, reason, false /* addReason */);
+        }
+    }
+
     private void destroyPendingSockets() {
         ensureRunningOnConnectivityServiceThread();
         if (mDestroySocketPendingUids.size() == 0) {
@@ -6617,6 +6672,12 @@
                     UidFrozenStateChangedArgs args = (UidFrozenStateChangedArgs) msg.obj;
                     handleFrozenUids(args.mUids, args.mFrozenStates);
                     break;
+                case EVENT_UPDATE_FIREWALL_DESTROY_SOCKET_REASONS:
+                    handleUpdateFirewallDestroySocketReasons((List) msg.obj);
+                    break;
+                case EVENT_CLEAR_FIREWALL_DESTROY_SOCKET_REASONS:
+                    handleClearFirewallDestroySocketReasons(msg.arg1);
+                    break;
             }
         }
     }
@@ -13734,6 +13795,9 @@
                 mHandler.sendMessage(mHandler.obtainMessage(EVENT_BLOCKED_REASONS_CHANGED,
                         List.of(new Pair<>(uid, mBpfNetMaps.getUidNetworkingBlockedReasons(uid)))));
             }
+            if (shouldTrackFirewallDestroySocketReasons()) {
+                maybePostFirewallDestroySocketReasons(chain, Set.of(uid));
+            }
         }
     }
 
@@ -13778,23 +13842,40 @@
     }
 
     @RequiresApi(Build.VERSION_CODES.TIRAMISU)
+    private Set<Integer> getUidsOnFirewallChain(final int chain) throws ErrnoException {
+        if (BpfNetMapsUtils.isFirewallAllowList(chain)) {
+            return mBpfNetMaps.getUidsWithAllowRuleOnAllowListChain(chain);
+        } else {
+            return mBpfNetMaps.getUidsWithDenyRuleOnDenyListChain(chain);
+        }
+    }
+
+    @RequiresApi(Build.VERSION_CODES.TIRAMISU)
     private void closeSocketsForFirewallChainLocked(final int chain)
             throws ErrnoException, SocketException, InterruptedIOException {
+        final Set<Integer> uidsOnChain = getUidsOnFirewallChain(chain);
         if (BpfNetMapsUtils.isFirewallAllowList(chain)) {
             // Allowlist means the firewall denies all by default, uids must be explicitly allowed
             // So, close all non-system socket owned by uids that are not explicitly allowed
             Set<Range<Integer>> ranges = new ArraySet<>();
             ranges.add(new Range<>(Process.FIRST_APPLICATION_UID, Integer.MAX_VALUE));
-            final Set<Integer> exemptUids = mBpfNetMaps.getUidsWithAllowRuleOnAllowListChain(chain);
-            mDeps.destroyLiveTcpSockets(ranges, exemptUids);
+            mDeps.destroyLiveTcpSockets(ranges, uidsOnChain /* exemptUids */);
         } else {
             // Denylist means the firewall allows all by default, uids must be explicitly denied
             // So, close socket owned by uids that are explicitly denied
-            final Set<Integer> ownerUids = mBpfNetMaps.getUidsWithDenyRuleOnDenyListChain(chain);
-            mDeps.destroyLiveTcpSocketsByOwnerUids(ownerUids);
+            mDeps.destroyLiveTcpSocketsByOwnerUids(uidsOnChain /* ownerUids */);
         }
     }
 
+    private void maybePostClearFirewallDestroySocketReasons(int chain) {
+        if (chain != FIREWALL_CHAIN_BACKGROUND) {
+            // TODO (b/300681644): Support other firewall chains
+            return;
+        }
+        mHandler.sendMessage(mHandler.obtainMessage(EVENT_CLEAR_FIREWALL_DESTROY_SOCKET_REASONS,
+                DESTROY_SOCKET_REASON_FIREWALL_BACKGROUND, 0 /* arg2 */));
+    }
+
     @Override
     public void setFirewallChainEnabled(final int chain, final boolean enable) {
         enforceNetworkStackOrSettingsPermission();
@@ -13820,6 +13901,11 @@
             if (shouldTrackUidsForBlockedStatusCallbacks()) {
                 updateTrackingUidsBlockedReasons();
             }
+            if (shouldTrackFirewallDestroySocketReasons() && !enable) {
+                // Clear destroy socket reasons so that CS does not destroy sockets of apps that
+                // have network access.
+                maybePostClearFirewallDestroySocketReasons(chain);
+            }
         }
 
         if (mDeps.isAtLeastU() && enable) {
@@ -13847,6 +13933,31 @@
                 uidBlockedReasonsList));
     }
 
+    private int getFirewallDestroySocketReasons(final int blockedReasons) {
+        int destroySocketReasons = DESTROY_SOCKET_REASON_NONE;
+        if ((blockedReasons & BLOCKED_REASON_APP_BACKGROUND) != BLOCKED_REASON_NONE) {
+            destroySocketReasons |= DESTROY_SOCKET_REASON_FIREWALL_BACKGROUND;
+        }
+        return destroySocketReasons;
+    }
+
+    @RequiresApi(Build.VERSION_CODES.TIRAMISU)
+    @GuardedBy("mBlockedStatusTrackingUids")
+    private void maybePostFirewallDestroySocketReasons(int chain, Set<Integer> uids) {
+        if (chain != FIREWALL_CHAIN_BACKGROUND) {
+            // TODO (b/300681644): Support other firewall chains
+            return;
+        }
+        final ArrayList<Pair<Integer, Integer>> reasonsList = new ArrayList<>();
+        for (int uid: uids) {
+            final int blockedReasons = mBpfNetMaps.getUidNetworkingBlockedReasons(uid);
+            final int destroySocketReaons = getFirewallDestroySocketReasons(blockedReasons);
+            reasonsList.add(new Pair<>(uid, destroySocketReaons));
+        }
+        mHandler.sendMessage(mHandler.obtainMessage(EVENT_UPDATE_FIREWALL_DESTROY_SOCKET_REASONS,
+                reasonsList));
+    }
+
     @Override
     public boolean getFirewallChainEnabled(final int chain) {
         enforceNetworkStackOrSettingsPermission();
@@ -13872,11 +13983,29 @@
         }
 
         synchronized (mBlockedStatusTrackingUids) {
-            mBpfNetMaps.replaceUidChain(chain, uids);
+            // replaceFirewallChain removes uids that are currently on the chain and put |uids| on
+            // the chain.
+            // So this method could change blocked reasons of uids that are currently on chain +
+            // |uids|.
+            final Set<Integer> affectedUids = new ArraySet<>();
+            if (shouldTrackFirewallDestroySocketReasons()) {
+                try {
+                    affectedUids.addAll(getUidsOnFirewallChain(chain));
+                } catch (ErrnoException e) {
+                    Log.e(TAG, "Failed to get uids on chain(" + chain + "): " + e);
+                }
+                for (final int uid: uids) {
+                    affectedUids.add(uid);
+                }
+            }
 
+            mBpfNetMaps.replaceUidChain(chain, uids);
             if (shouldTrackUidsForBlockedStatusCallbacks()) {
                 updateTrackingUidsBlockedReasons();
             }
+            if (shouldTrackFirewallDestroySocketReasons()) {
+                maybePostFirewallDestroySocketReasons(chain, affectedUids);
+            }
         }
     }
 
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSDestroySocketTest.kt b/tests/unit/java/com/android/server/connectivityservice/CSDestroySocketTest.kt
new file mode 100644
index 0000000..bc5be78
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivityservice/CSDestroySocketTest.kt
@@ -0,0 +1,338 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server
+
+import android.app.ActivityManager.UidFrozenStateChangedCallback
+import android.app.ActivityManager.UidFrozenStateChangedCallback.UID_FROZEN_STATE_FROZEN
+import android.app.ActivityManager.UidFrozenStateChangedCallback.UID_FROZEN_STATE_UNFROZEN
+import android.net.ConnectivityManager.BLOCKED_REASON_APP_BACKGROUND
+import android.net.ConnectivityManager.BLOCKED_REASON_NONE
+import android.net.ConnectivityManager.FIREWALL_CHAIN_BACKGROUND
+import android.net.ConnectivityManager.FIREWALL_RULE_ALLOW
+import android.net.ConnectivityManager.FIREWALL_RULE_DENY
+import android.net.LinkProperties
+import android.net.NetworkCapabilities
+import android.os.Build
+import com.android.net.module.util.BaseNetdUnsolicitedEventListener
+import com.android.server.connectivity.ConnectivityFlags.DELAY_DESTROY_SOCKETS
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.ArgumentCaptor
+import org.mockito.Mockito.any
+import org.mockito.Mockito.doReturn
+import org.mockito.Mockito.inOrder
+import org.mockito.Mockito.never
+import org.mockito.Mockito.verify
+
+private const val TIMESTAMP = 1234L
+private const val TEST_UID = 1234
+private const val TEST_UID2 = 5678
+private const val TEST_CELL_IFACE = "test_rmnet"
+
+private fun cellNc() = NetworkCapabilities.Builder()
+        .addTransportType(NetworkCapabilities.TRANSPORT_CELLULAR)
+        .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
+        .addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED)
+        .addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VCN_MANAGED)
+        .build()
+
+private fun cellLp() = LinkProperties().also{
+    it.interfaceName = TEST_CELL_IFACE
+}
+
+@RunWith(DevSdkIgnoreRunner::class)
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
+class CSDestroySocketTest : CSTest() {
+    private fun getRegisteredNetdUnsolicitedEventListener(): BaseNetdUnsolicitedEventListener {
+        val captor = ArgumentCaptor.forClass(BaseNetdUnsolicitedEventListener::class.java)
+        verify(netd).registerUnsolicitedEventListener(captor.capture())
+        return captor.value
+    }
+
+    private fun getUidFrozenStateChangedCallback(): UidFrozenStateChangedCallback {
+        val captor = ArgumentCaptor.forClass(UidFrozenStateChangedCallback::class.java)
+        verify(activityManager).registerUidFrozenStateChangedCallback(any(), captor.capture())
+        return captor.value
+    }
+
+    private fun doTestBackgroundRestrictionDestroySockets(
+            restrictionWithIdleNetwork: Boolean,
+            expectDelay: Boolean
+    ) {
+        val netdEventListener = getRegisteredNetdUnsolicitedEventListener()
+        val inOrder = inOrder(destroySocketsWrapper)
+
+        val cellAgent = Agent(nc = cellNc(), lp = cellLp())
+        cellAgent.connect()
+        if (restrictionWithIdleNetwork) {
+            // Make cell default network idle
+            netdEventListener.onInterfaceClassActivityChanged(
+                    false, // isActive
+                    cellAgent.network.netId,
+                    TIMESTAMP,
+                    TEST_UID
+            )
+        }
+
+        // Set deny rule on background chain for TEST_UID
+        doReturn(BLOCKED_REASON_APP_BACKGROUND)
+                .`when`(bpfNetMaps).getUidNetworkingBlockedReasons(TEST_UID)
+        cm.setUidFirewallRule(
+                FIREWALL_CHAIN_BACKGROUND,
+                TEST_UID,
+                FIREWALL_RULE_DENY
+        )
+        waitForIdle()
+        if (expectDelay) {
+            inOrder.verify(destroySocketsWrapper, never())
+                    .destroyLiveTcpSocketsByOwnerUids(setOf(TEST_UID))
+        } else {
+            inOrder.verify(destroySocketsWrapper)
+                    .destroyLiveTcpSocketsByOwnerUids(setOf(TEST_UID))
+        }
+
+        netdEventListener.onInterfaceClassActivityChanged(
+                true, // isActive
+                cellAgent.network.netId,
+                TIMESTAMP,
+                TEST_UID
+        )
+        waitForIdle()
+        if (expectDelay) {
+            inOrder.verify(destroySocketsWrapper)
+                    .destroyLiveTcpSocketsByOwnerUids(setOf(TEST_UID))
+        } else {
+            inOrder.verify(destroySocketsWrapper, never())
+                    .destroyLiveTcpSocketsByOwnerUids(setOf(TEST_UID))
+        }
+
+        cellAgent.disconnect()
+    }
+
+    @Test
+    @FeatureFlags(flags = [Flag(DELAY_DESTROY_SOCKETS, true)])
+    fun testBackgroundAppDestroySockets() {
+        doTestBackgroundRestrictionDestroySockets(
+                restrictionWithIdleNetwork = true,
+                expectDelay = true
+        )
+    }
+
+    @Test
+    @FeatureFlags(flags = [Flag(DELAY_DESTROY_SOCKETS, true)])
+    fun testBackgroundAppDestroySockets_activeNetwork() {
+        doTestBackgroundRestrictionDestroySockets(
+                restrictionWithIdleNetwork = false,
+                expectDelay = false
+        )
+    }
+
+    @Test
+    @FeatureFlags(flags = [Flag(DELAY_DESTROY_SOCKETS, false)])
+    fun testBackgroundAppDestroySockets_featureIsDisabled() {
+        doTestBackgroundRestrictionDestroySockets(
+                restrictionWithIdleNetwork = true,
+                expectDelay = false
+        )
+    }
+
+    @Test
+    fun testReplaceFirewallChain() {
+        val netdEventListener = getRegisteredNetdUnsolicitedEventListener()
+        val inOrder = inOrder(destroySocketsWrapper)
+
+        val cellAgent = Agent(nc = cellNc(), lp = cellLp())
+        cellAgent.connect()
+        // Make cell default network idle
+        netdEventListener.onInterfaceClassActivityChanged(
+                false, // isActive
+                cellAgent.network.netId,
+                TIMESTAMP,
+                TEST_UID
+        )
+
+        // Set allow rule on background chain for TEST_UID
+        doReturn(BLOCKED_REASON_NONE)
+                .`when`(bpfNetMaps).getUidNetworkingBlockedReasons(TEST_UID)
+        cm.setUidFirewallRule(
+                FIREWALL_CHAIN_BACKGROUND,
+                TEST_UID,
+                FIREWALL_RULE_ALLOW
+        )
+        // Set deny rule on background chain for TEST_UID
+        doReturn(BLOCKED_REASON_APP_BACKGROUND)
+                .`when`(bpfNetMaps).getUidNetworkingBlockedReasons(TEST_UID2)
+        cm.setUidFirewallRule(
+                FIREWALL_CHAIN_BACKGROUND,
+                TEST_UID2,
+                FIREWALL_RULE_DENY
+        )
+
+        // Put only TEST_UID2 on background chain (deny TEST_UID and allow TEST_UID2)
+        doReturn(setOf(TEST_UID))
+                .`when`(bpfNetMaps).getUidsWithAllowRuleOnAllowListChain(FIREWALL_CHAIN_BACKGROUND)
+        doReturn(BLOCKED_REASON_APP_BACKGROUND)
+                .`when`(bpfNetMaps).getUidNetworkingBlockedReasons(TEST_UID)
+        doReturn(BLOCKED_REASON_NONE)
+                .`when`(bpfNetMaps).getUidNetworkingBlockedReasons(TEST_UID2)
+        cm.replaceFirewallChain(FIREWALL_CHAIN_BACKGROUND, intArrayOf(TEST_UID2))
+        waitForIdle()
+        inOrder.verify(destroySocketsWrapper, never())
+                .destroyLiveTcpSocketsByOwnerUids(setOf(TEST_UID))
+
+        netdEventListener.onInterfaceClassActivityChanged(
+                true, // isActive
+                cellAgent.network.netId,
+                TIMESTAMP,
+                TEST_UID
+        )
+        waitForIdle()
+        inOrder.verify(destroySocketsWrapper)
+                .destroyLiveTcpSocketsByOwnerUids(setOf(TEST_UID))
+
+        cellAgent.disconnect()
+    }
+
+    private fun doTestDestroySockets(
+            isFrozen: Boolean,
+            denyOnBackgroundChain: Boolean,
+            enableBackgroundChain: Boolean,
+            expectDestroySockets: Boolean
+    ) {
+        val netdEventListener = getRegisteredNetdUnsolicitedEventListener()
+        val frozenStateCallback = getUidFrozenStateChangedCallback()
+
+        // Make cell default network idle
+        val cellAgent = Agent(nc = cellNc(), lp = cellLp())
+        cellAgent.connect()
+        netdEventListener.onInterfaceClassActivityChanged(
+                false, // isActive
+                cellAgent.network.netId,
+                TIMESTAMP,
+                TEST_UID
+        )
+
+        // Set deny rule on background chain for TEST_UID
+        doReturn(BLOCKED_REASON_APP_BACKGROUND)
+                .`when`(bpfNetMaps).getUidNetworkingBlockedReasons(TEST_UID)
+        cm.setUidFirewallRule(
+                FIREWALL_CHAIN_BACKGROUND,
+                TEST_UID,
+                FIREWALL_RULE_DENY
+        )
+
+        // Freeze TEST_UID
+        frozenStateCallback.onUidFrozenStateChanged(
+                intArrayOf(TEST_UID),
+                intArrayOf(UID_FROZEN_STATE_FROZEN)
+        )
+
+        if (!isFrozen) {
+            // Unfreeze TEST_UID
+            frozenStateCallback.onUidFrozenStateChanged(
+                    intArrayOf(TEST_UID),
+                    intArrayOf(UID_FROZEN_STATE_UNFROZEN)
+            )
+        }
+        if (!enableBackgroundChain) {
+            // Disable background chain
+            cm.setFirewallChainEnabled(FIREWALL_CHAIN_BACKGROUND, false)
+        }
+        if (!denyOnBackgroundChain) {
+            // Set allow rule on background chain for TEST_UID
+            doReturn(BLOCKED_REASON_NONE)
+                    .`when`(bpfNetMaps).getUidNetworkingBlockedReasons(TEST_UID)
+            cm.setUidFirewallRule(
+                    FIREWALL_CHAIN_BACKGROUND,
+                    TEST_UID,
+                    FIREWALL_RULE_ALLOW
+            )
+        }
+        verify(destroySocketsWrapper, never()).destroyLiveTcpSocketsByOwnerUids(setOf(TEST_UID))
+
+        // Make cell network active
+        netdEventListener.onInterfaceClassActivityChanged(
+                true, // isActive
+                cellAgent.network.netId,
+                TIMESTAMP,
+                TEST_UID
+        )
+        waitForIdle()
+
+        if (expectDestroySockets) {
+            verify(destroySocketsWrapper).destroyLiveTcpSocketsByOwnerUids(setOf(TEST_UID))
+        } else {
+            verify(destroySocketsWrapper, never()).destroyLiveTcpSocketsByOwnerUids(setOf(TEST_UID))
+        }
+    }
+
+    @Test
+    fun testDestroySockets_backgroundDeny_frozen() {
+        doTestDestroySockets(
+                isFrozen = true,
+                denyOnBackgroundChain = true,
+                enableBackgroundChain = true,
+                expectDestroySockets = true
+        )
+    }
+
+    @Test
+    fun testDestroySockets_backgroundDeny_nonFrozen() {
+        doTestDestroySockets(
+                isFrozen = false,
+                denyOnBackgroundChain = true,
+                enableBackgroundChain = true,
+                expectDestroySockets = true
+        )
+    }
+
+    @Test
+    fun testDestroySockets_backgroundAllow_frozen() {
+        doTestDestroySockets(
+                isFrozen = true,
+                denyOnBackgroundChain = false,
+                enableBackgroundChain = true,
+                expectDestroySockets = true
+        )
+    }
+
+    @Test
+    fun testDestroySockets_backgroundAllow_nonFrozen() {
+        // If the app is neither frozen nor under background restriction, sockets are not
+        // destroyed
+        doTestDestroySockets(
+                isFrozen = false,
+                denyOnBackgroundChain = false,
+                enableBackgroundChain = true,
+                expectDestroySockets = false
+        )
+    }
+
+    @Test
+    fun testDestroySockets_backgroundChainDisabled_nonFrozen() {
+        // If the app is neither frozen nor under background restriction, sockets are not
+        // destroyed
+        doTestDestroySockets(
+                isFrozen = false,
+                denyOnBackgroundChain = true,
+                enableBackgroundChain = false,
+                expectDestroySockets = false
+        )
+    }
+}
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
index 99a8a3d..47a6763 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
@@ -210,6 +210,7 @@
 
     val multicastRoutingCoordinatorService = mock<MulticastRoutingCoordinatorService>()
     val satelliteAccessController = mock<SatelliteAccessController>()
+    val destroySocketsWrapper = mock<DestroySocketsWrapper>()
 
     val deps = CSDeps()
 
@@ -263,6 +264,11 @@
         alarmHandlerThread.join()
     }
 
+    // Class to be mocked and used to verify destroy sockets methods call
+    open inner class DestroySocketsWrapper {
+        open fun destroyLiveTcpSocketsByOwnerUids(ownerUids: Set<Int>) {}
+    }
+
     inner class CSDeps : ConnectivityService.Dependencies() {
         override fun getResources(ctx: Context) = connResources
         override fun getBpfNetMaps(context: Context, netd: INetd) = this@CSTest.bpfNetMaps
@@ -368,6 +374,11 @@
 
         override fun getCallingUid() =
                 if (callingUid == CALLING_UID_UNMOCKED) super.getCallingUid() else callingUid
+
+        override fun destroyLiveTcpSocketsByOwnerUids(ownerUids: Set<Int>) {
+            // Call mocked destroyLiveTcpSocketsByOwnerUids so that test can verify this method call
+            destroySocketsWrapper.destroyLiveTcpSocketsByOwnerUids(ownerUids)
+        }
     }
 
     inner class CSContext(base: Context) : BroadcastInterceptingContext(base) {