Merge "Destroy sockets of apps under background firewall chain restriction" into main
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 519391f..ca99935 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -36,6 +36,7 @@
 import static android.net.ConnectivityDiagnosticsManager.DataStallReport.KEY_TCP_PACKET_FAIL_RATE;
 import static android.net.ConnectivityManager.ACTION_RESTRICT_BACKGROUND_CHANGED;
 import static android.net.ConnectivityManager.BLOCKED_METERED_REASON_MASK;
+import static android.net.ConnectivityManager.BLOCKED_REASON_APP_BACKGROUND;
 import static android.net.ConnectivityManager.BLOCKED_REASON_LOCKDOWN_VPN;
 import static android.net.ConnectivityManager.BLOCKED_REASON_NONE;
 import static android.net.ConnectivityManager.BLOCKED_REASON_NETWORK_RESTRICTED;
@@ -878,6 +879,18 @@
     private static final int EVENT_UID_FROZEN_STATE_CHANGED = 61;
 
     /**
+     * Event to update firewall socket destroy reasons for uids.
+     * obj = List<Pair<Integer, Integer>> of (uid, DESTROY_SOCKET_REASON_* bitmask)
+     */
+    private static final int EVENT_UPDATE_FIREWALL_DESTROY_SOCKET_REASONS = 62;
+
+    /**
+     * Event to clear a firewall socket destroy reason for all uids.
+     * arg1 = DESTROY_SOCKET_REASON_* bit to unset from all pending uids
+     */
+    private static final int EVENT_CLEAR_FIREWALL_DESTROY_SOCKET_REASONS = 63;
+
+    /**
      * Argument for {@link #EVENT_PROVISIONING_NOTIFICATION} to indicate that the notification
      * should be shown.
      */
@@ -1033,6 +1046,7 @@
 
     private static final int DESTROY_SOCKET_REASON_NONE = 0;
     private static final int DESTROY_SOCKET_REASON_FROZEN = 1 << 0;
+    private static final int DESTROY_SOCKET_REASON_FIREWALL_BACKGROUND = 1 << 1;
 
     // Flag to drop packets to VPN addresses ingressing via non-VPN interfaces.
     private final boolean mIngressToVpnAddressFiltering;
@@ -3404,6 +3418,14 @@
         return !mNetworkActivityTracker.isDefaultNetworkActive();
     }
 
+    /**
+     * Returns whether CS tracks per-uid socket destroy reasons derived from firewall chain rules
+     * (currently {@link #DESTROY_SOCKET_REASON_FIREWALL_BACKGROUND}). Enabled on V+ only.
+     */
+    private boolean shouldTrackFirewallDestroySocketReasons() {
+        return mDeps.isAtLeastV();
+    }
+
     private void updateDestroySocketReasons(final int uid, final int reason,
             final boolean addReason) {
         final int destroyReasons = mDestroySocketPendingUids.get(uid, DESTROY_SOCKET_REASON_NONE);
@@ -3432,6 +3454,61 @@
         }
     }
 
+    /**
+     * Updates the firewall socket destroy reasons of uids and destroys pending sockets if needed.
+     *
+     * For each (uid, reasons) pair, {@link #DESTROY_SOCKET_REASON_FIREWALL_BACKGROUND} is added
+     * to or removed from the uid's pending socket destroy reasons depending on whether it is set
+     * in {@code reasons}.
+     *
+     * @param reasonsList list of Pair(uid, DESTROY_SOCKET_REASON_* bitmask).
+     */
+    private void handleUpdateFirewallDestroySocketReasons(
+            List<Pair<Integer, Integer>> reasonsList) {
+        if (!shouldTrackFirewallDestroySocketReasons()) {
+            Log.wtf(TAG, "handleUpdateFirewallDestroySocketReasons is called unexpectedly");
+            return;
+        }
+        ensureRunningOnConnectivityServiceThread();
+
+        for (Pair<Integer, Integer> uidSocketDestroyReasons: reasonsList) {
+            final int uid = uidSocketDestroyReasons.first;
+            final int reasons = uidSocketDestroyReasons.second;
+            final boolean destroyByFirewallBackground =
+                    (reasons & DESTROY_SOCKET_REASON_FIREWALL_BACKGROUND)
+                            != DESTROY_SOCKET_REASON_NONE;
+            updateDestroySocketReasons(uid, DESTROY_SOCKET_REASON_FIREWALL_BACKGROUND,
+                    destroyByFirewallBackground);
+        }
+
+        // Destroy right away unless destroying is delayed while the cell network is idle, in
+        // which case the sockets stay pending in mDestroySocketPendingUids.
+        if (!mDelayDestroySockets || !isCellNetworkIdle()) {
+            destroyPendingSockets();
+        }
+    }
+
+    /**
+     * Unsets {@code reason} from every uid in mDestroySocketPendingUids. Called when a firewall
+     * chain is disabled so that CS does not destroy sockets of apps that regain network access.
+     *
+     * @param reason the DESTROY_SOCKET_REASON_* bit to unset.
+     */
+    private void handleClearFirewallDestroySocketReasons(final int reason) {
+        if (!shouldTrackFirewallDestroySocketReasons()) {
+            Log.wtf(TAG, "handleClearFirewallDestroySocketReasons is called unexpectedly");
+            return;
+        }
+        ensureRunningOnConnectivityServiceThread();
+
+        // Unset reason from all pending uids. Iterate backwards since updateDestroySocketReasons
+        // presumably removes a uid from mDestroySocketPendingUids once it has no reason left.
+        for (int i = mDestroySocketPendingUids.size() - 1; i >= 0; i--) {
+            final int uid = mDestroySocketPendingUids.keyAt(i);
+            updateDestroySocketReasons(uid, reason, false /* addReason */);
+        }
+    }
+
     private void destroyPendingSockets() {
         ensureRunningOnConnectivityServiceThread();
         if (mDestroySocketPendingUids.size() == 0) {
@@ -6617,6 +6672,12 @@
                     UidFrozenStateChangedArgs args = (UidFrozenStateChangedArgs) msg.obj;
                     handleFrozenUids(args.mUids, args.mFrozenStates);
                     break;
+                case EVENT_UPDATE_FIREWALL_DESTROY_SOCKET_REASONS:
+                    handleUpdateFirewallDestroySocketReasons((List) msg.obj);
+                    break;
+                case EVENT_CLEAR_FIREWALL_DESTROY_SOCKET_REASONS:
+                    handleClearFirewallDestroySocketReasons(msg.arg1);
+                    break;
             }
         }
     }
@@ -13734,6 +13795,9 @@
                 mHandler.sendMessage(mHandler.obtainMessage(EVENT_BLOCKED_REASONS_CHANGED,
                         List.of(new Pair<>(uid, mBpfNetMaps.getUidNetworkingBlockedReasons(uid)))));
             }
+            if (shouldTrackFirewallDestroySocketReasons()) {
+                maybePostFirewallDestroySocketReasons(chain, Set.of(uid));
+            }
         }
     }
 
@@ -13778,23 +13842,47 @@
     }
 
+    /**
+     * Returns the uids with an explicit rule on {@code chain}: uids with an allow rule for an
+     * allowlist chain, or uids with a deny rule for a denylist chain.
+     */
     @RequiresApi(Build.VERSION_CODES.TIRAMISU)
+    private Set<Integer> getUidsOnFirewallChain(final int chain) throws ErrnoException {
+        return BpfNetMapsUtils.isFirewallAllowList(chain)
+                ? mBpfNetMaps.getUidsWithAllowRuleOnAllowListChain(chain)
+                : mBpfNetMaps.getUidsWithDenyRuleOnDenyListChain(chain);
+    }
+
+    @RequiresApi(Build.VERSION_CODES.TIRAMISU)
     private void closeSocketsForFirewallChainLocked(final int chain)
             throws ErrnoException, SocketException, InterruptedIOException {
+        final Set<Integer> chainUids = getUidsOnFirewallChain(chain);
         if (BpfNetMapsUtils.isFirewallAllowList(chain)) {
             // Allowlist means the firewall denies all by default, uids must be explicitly allowed
             // So, close all non-system socket owned by uids that are not explicitly allowed
             Set<Range<Integer>> ranges = new ArraySet<>();
             ranges.add(new Range<>(Process.FIRST_APPLICATION_UID, Integer.MAX_VALUE));
-            final Set<Integer> exemptUids = mBpfNetMaps.getUidsWithAllowRuleOnAllowListChain(chain);
-            mDeps.destroyLiveTcpSockets(ranges, exemptUids);
+            mDeps.destroyLiveTcpSockets(ranges, chainUids /* exemptUids */);
         } else {
             // Denylist means the firewall allows all by default, uids must be explicitly denied
             // So, close socket owned by uids that are explicitly denied
-            final Set<Integer> ownerUids = mBpfNetMaps.getUidsWithDenyRuleOnDenyListChain(chain);
-            mDeps.destroyLiveTcpSocketsByOwnerUids(ownerUids);
+            mDeps.destroyLiveTcpSocketsByOwnerUids(chainUids /* ownerUids */);
         }
     }
 
+    /**
+     * Posts {@link #EVENT_CLEAR_FIREWALL_DESTROY_SOCKET_REASONS} so that the socket destroy
+     * reason associated with {@code chain} is unset for all uids. Only
+     * {@code FIREWALL_CHAIN_BACKGROUND} is currently supported; other chains are ignored.
+     */
+    private void maybePostClearFirewallDestroySocketReasons(int chain) {
+        // TODO (b/300681644): Support other firewall chains
+        if (chain == FIREWALL_CHAIN_BACKGROUND) {
+            mHandler.sendMessage(mHandler.obtainMessage(
+                    EVENT_CLEAR_FIREWALL_DESTROY_SOCKET_REASONS,
+                    DESTROY_SOCKET_REASON_FIREWALL_BACKGROUND, 0 /* arg2 */));
+        }
+    }
+
     @Override
     public void setFirewallChainEnabled(final int chain, final boolean enable) {
         enforceNetworkStackOrSettingsPermission();
@@ -13820,6 +13901,11 @@
             if (shouldTrackUidsForBlockedStatusCallbacks()) {
                 updateTrackingUidsBlockedReasons();
             }
+            if (shouldTrackFirewallDestroySocketReasons() && !enable) {
+                // Clear destroy socket reasons so that CS does not destroy sockets of apps that
+                // have network access.
+                maybePostClearFirewallDestroySocketReasons(chain);
+            }
         }
 
         if (mDeps.isAtLeastU() && enable) {
@@ -13847,6 +13933,47 @@
                 uidBlockedReasonsList));
     }
 
+    /**
+     * Converts the networking blocked reasons of a uid to firewall socket destroy reasons.
+     *
+     * @param blockedReasons bitmask of BLOCKED_REASON_* values.
+     * @return bitmask of DESTROY_SOCKET_REASON_* values. Currently only
+     *         {@link #DESTROY_SOCKET_REASON_FIREWALL_BACKGROUND} is derived, from
+     *         {@code BLOCKED_REASON_APP_BACKGROUND}.
+     */
+    private static int getFirewallDestroySocketReasons(final int blockedReasons) {
+        int destroySocketReasons = DESTROY_SOCKET_REASON_NONE;
+        if ((blockedReasons & BLOCKED_REASON_APP_BACKGROUND) != BLOCKED_REASON_NONE) {
+            destroySocketReasons |= DESTROY_SOCKET_REASON_FIREWALL_BACKGROUND;
+        }
+        return destroySocketReasons;
+    }
+
+    /**
+     * Posts {@link #EVENT_UPDATE_FIREWALL_DESTROY_SOCKET_REASONS} with the firewall socket
+     * destroy reasons of {@code uids}, computed from their current networking blocked reasons.
+     * Only {@code FIREWALL_CHAIN_BACKGROUND} is currently supported; other chains are ignored.
+     *
+     * @param chain the firewall chain whose rules or state changed.
+     * @param uids the uids whose blocked reasons may have changed.
+     */
+    @RequiresApi(Build.VERSION_CODES.TIRAMISU)
+    @GuardedBy("mBlockedStatusTrackingUids")
+    private void maybePostFirewallDestroySocketReasons(int chain, Set<Integer> uids) {
+        if (chain != FIREWALL_CHAIN_BACKGROUND) {
+            // TODO (b/300681644): Support other firewall chains
+            return;
+        }
+        final ArrayList<Pair<Integer, Integer>> reasonsList = new ArrayList<>(uids.size());
+        for (int uid: uids) {
+            final int blockedReasons = mBpfNetMaps.getUidNetworkingBlockedReasons(uid);
+            final int destroySocketReasons = getFirewallDestroySocketReasons(blockedReasons);
+            reasonsList.add(new Pair<>(uid, destroySocketReasons));
+        }
+        mHandler.sendMessage(mHandler.obtainMessage(EVENT_UPDATE_FIREWALL_DESTROY_SOCKET_REASONS,
+                reasonsList));
+    }
+
     @Override
     public boolean getFirewallChainEnabled(final int chain) {
         enforceNetworkStackOrSettingsPermission();
@@ -13872,11 +13983,29 @@
         }
 
         synchronized (mBlockedStatusTrackingUids) {
-            mBpfNetMaps.replaceUidChain(chain, uids);
+            // replaceUidChain removes the uids that are currently on the chain and puts |uids|
+            // on the chain, so the blocked reasons of both the uids currently on the chain and
+            // |uids| may change. Collect all of them before replacing so that their firewall
+            // destroy socket reasons can be updated afterwards.
+            final Set<Integer> affectedUids = new ArraySet<>();
+            if (shouldTrackFirewallDestroySocketReasons()) {
+                try {
+                    affectedUids.addAll(getUidsOnFirewallChain(chain));
+                } catch (ErrnoException e) {
+                    Log.e(TAG, "Failed to get uids on chain(" + chain + "): " + e);
+                }
+                for (final int uid: uids) {
+                    affectedUids.add(uid);
+                }
+            }

+            mBpfNetMaps.replaceUidChain(chain, uids);
             if (shouldTrackUidsForBlockedStatusCallbacks()) {
                 updateTrackingUidsBlockedReasons();
             }
+            if (shouldTrackFirewallDestroySocketReasons()) {
+                maybePostFirewallDestroySocketReasons(chain, affectedUids);
+            }
         }
     }
 
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSDestroySocketTest.kt b/tests/unit/java/com/android/server/connectivityservice/CSDestroySocketTest.kt
new file mode 100644
index 0000000..bc5be78
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivityservice/CSDestroySocketTest.kt
@@ -0,0 +1,345 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server
+
+import android.app.ActivityManager.UidFrozenStateChangedCallback
+import android.app.ActivityManager.UidFrozenStateChangedCallback.UID_FROZEN_STATE_FROZEN
+import android.app.ActivityManager.UidFrozenStateChangedCallback.UID_FROZEN_STATE_UNFROZEN
+import android.net.ConnectivityManager.BLOCKED_REASON_APP_BACKGROUND
+import android.net.ConnectivityManager.BLOCKED_REASON_NONE
+import android.net.ConnectivityManager.FIREWALL_CHAIN_BACKGROUND
+import android.net.ConnectivityManager.FIREWALL_RULE_ALLOW
+import android.net.ConnectivityManager.FIREWALL_RULE_DENY
+import android.net.LinkProperties
+import android.net.NetworkCapabilities
+import android.os.Build
+import com.android.net.module.util.BaseNetdUnsolicitedEventListener
+import com.android.server.connectivity.ConnectivityFlags.DELAY_DESTROY_SOCKETS
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.ArgumentCaptor
+import org.mockito.Mockito.any
+import org.mockito.Mockito.doReturn
+import org.mockito.Mockito.inOrder
+import org.mockito.Mockito.never
+import org.mockito.Mockito.verify
+
+private const val TIMESTAMP = 1234L
+private const val TEST_UID = 1234
+private const val TEST_UID2 = 5678
+private const val TEST_CELL_IFACE = "test_rmnet"
+
+private fun cellNc() = NetworkCapabilities.Builder()
+        .addTransportType(NetworkCapabilities.TRANSPORT_CELLULAR)
+        .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
+        .addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED)
+        .addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VCN_MANAGED)
+        .build()
+
+private fun cellLp() = LinkProperties().apply {
+    interfaceName = TEST_CELL_IFACE
+}
+
+@RunWith(DevSdkIgnoreRunner::class)
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
+class CSDestroySocketTest : CSTest() {
+    private fun getRegisteredNetdUnsolicitedEventListener(): BaseNetdUnsolicitedEventListener {
+        val captor = ArgumentCaptor.forClass(BaseNetdUnsolicitedEventListener::class.java)
+        verify(netd).registerUnsolicitedEventListener(captor.capture())
+        return captor.value
+    }
+
+    private fun getUidFrozenStateChangedCallback(): UidFrozenStateChangedCallback {
+        val captor = ArgumentCaptor.forClass(UidFrozenStateChangedCallback::class.java)
+        verify(activityManager).registerUidFrozenStateChangedCallback(any(), captor.capture())
+        return captor.value
+    }
+
+    /**
+     * Simulates netd reporting that the interface class of the network [netId] became active
+     * or idle. This is how the tests make the cell default network active or idle.
+     */
+    private fun reportNetworkActivity(
+            listener: BaseNetdUnsolicitedEventListener,
+            netId: Int,
+            isActive: Boolean
+    ) {
+        listener.onInterfaceClassActivityChanged(isActive, netId, TIMESTAMP, TEST_UID)
+    }
+
+    /**
+     * Stubs [bpfNetMaps] to report [blockedReasons] for [uid], then sets [rule] for [uid] on
+     * the background firewall chain.
+     */
+    private fun setBackgroundChainRule(uid: Int, rule: Int, blockedReasons: Int) {
+        doReturn(blockedReasons).`when`(bpfNetMaps).getUidNetworkingBlockedReasons(uid)
+        cm.setUidFirewallRule(FIREWALL_CHAIN_BACKGROUND, uid, rule)
+    }
+
+    /**
+     * Sets a deny rule for TEST_UID on the background chain and verifies whether its sockets
+     * are destroyed right away or only once the cell default network becomes active.
+     *
+     * @param restrictionWithIdleNetwork whether the cell default network is idle when the
+     *         deny rule is set.
+     * @param expectDelay whether socket destruction is expected to be deferred until the
+     *         network becomes active.
+     */
+    private fun doTestBackgroundRestrictionDestroySockets(
+            restrictionWithIdleNetwork: Boolean,
+            expectDelay: Boolean
+    ) {
+        val netdEventListener = getRegisteredNetdUnsolicitedEventListener()
+        val inOrder = inOrder(destroySocketsWrapper)
+
+        val cellAgent = Agent(nc = cellNc(), lp = cellLp())
+        cellAgent.connect()
+        if (restrictionWithIdleNetwork) {
+            // Make cell default network idle
+            reportNetworkActivity(netdEventListener, cellAgent.network.netId, isActive = false)
+        }
+
+        // Set deny rule on background chain for TEST_UID
+        setBackgroundChainRule(
+                uid = TEST_UID,
+                rule = FIREWALL_RULE_DENY,
+                blockedReasons = BLOCKED_REASON_APP_BACKGROUND
+        )
+        waitForIdle()
+        if (expectDelay) {
+            inOrder.verify(destroySocketsWrapper, never())
+                    .destroyLiveTcpSocketsByOwnerUids(setOf(TEST_UID))
+        } else {
+            inOrder.verify(destroySocketsWrapper)
+                    .destroyLiveTcpSocketsByOwnerUids(setOf(TEST_UID))
+        }
+
+        // Make cell default network active
+        reportNetworkActivity(netdEventListener, cellAgent.network.netId, isActive = true)
+        waitForIdle()
+        if (expectDelay) {
+            inOrder.verify(destroySocketsWrapper)
+                    .destroyLiveTcpSocketsByOwnerUids(setOf(TEST_UID))
+        } else {
+            inOrder.verify(destroySocketsWrapper, never())
+                    .destroyLiveTcpSocketsByOwnerUids(setOf(TEST_UID))
+        }
+
+        cellAgent.disconnect()
+    }
+
+    @Test
+    @FeatureFlags(flags = [Flag(DELAY_DESTROY_SOCKETS, true)])
+    fun testBackgroundAppDestroySockets() {
+        doTestBackgroundRestrictionDestroySockets(
+                restrictionWithIdleNetwork = true,
+                expectDelay = true
+        )
+    }
+
+    @Test
+    @FeatureFlags(flags = [Flag(DELAY_DESTROY_SOCKETS, true)])
+    fun testBackgroundAppDestroySockets_activeNetwork() {
+        doTestBackgroundRestrictionDestroySockets(
+                restrictionWithIdleNetwork = false,
+                expectDelay = false
+        )
+    }
+
+    @Test
+    @FeatureFlags(flags = [Flag(DELAY_DESTROY_SOCKETS, false)])
+    fun testBackgroundAppDestroySockets_featureIsDisabled() {
+        doTestBackgroundRestrictionDestroySockets(
+                restrictionWithIdleNetwork = true,
+                expectDelay = false
+        )
+    }
+
+    @Test
+    fun testReplaceFirewallChain() {
+        val netdEventListener = getRegisteredNetdUnsolicitedEventListener()
+        val inOrder = inOrder(destroySocketsWrapper)
+
+        val cellAgent = Agent(nc = cellNc(), lp = cellLp())
+        cellAgent.connect()
+        // Make cell default network idle
+        reportNetworkActivity(netdEventListener, cellAgent.network.netId, isActive = false)
+
+        // Set allow rule on background chain for TEST_UID
+        setBackgroundChainRule(
+                uid = TEST_UID,
+                rule = FIREWALL_RULE_ALLOW,
+                blockedReasons = BLOCKED_REASON_NONE
+        )
+        // Set deny rule on background chain for TEST_UID2
+        setBackgroundChainRule(
+                uid = TEST_UID2,
+                rule = FIREWALL_RULE_DENY,
+                blockedReasons = BLOCKED_REASON_APP_BACKGROUND
+        )
+
+        // Put only TEST_UID2 on background chain (deny TEST_UID and allow TEST_UID2)
+        doReturn(setOf(TEST_UID))
+                .`when`(bpfNetMaps).getUidsWithAllowRuleOnAllowListChain(FIREWALL_CHAIN_BACKGROUND)
+        doReturn(BLOCKED_REASON_APP_BACKGROUND)
+                .`when`(bpfNetMaps).getUidNetworkingBlockedReasons(TEST_UID)
+        doReturn(BLOCKED_REASON_NONE)
+                .`when`(bpfNetMaps).getUidNetworkingBlockedReasons(TEST_UID2)
+        cm.replaceFirewallChain(FIREWALL_CHAIN_BACKGROUND, intArrayOf(TEST_UID2))
+        waitForIdle()
+        inOrder.verify(destroySocketsWrapper, never())
+                .destroyLiveTcpSocketsByOwnerUids(setOf(TEST_UID))
+
+        // Make cell default network active
+        reportNetworkActivity(netdEventListener, cellAgent.network.netId, isActive = true)
+        waitForIdle()
+        inOrder.verify(destroySocketsWrapper)
+                .destroyLiveTcpSocketsByOwnerUids(setOf(TEST_UID))
+
+        cellAgent.disconnect()
+    }
+
+    /**
+     * Verifies whether TEST_UID's sockets are destroyed once the idle cell default network
+     * becomes active.
+     *
+     * TEST_UID is first frozen and given a deny rule on the background chain. Then, depending
+     * on the parameters, it is unfrozen ([isFrozen] = false), the background chain is disabled
+     * ([enableBackgroundChain] = false) and/or its rule is changed to allow
+     * ([denyOnBackgroundChain] = false).
+     */
+    private fun doTestDestroySockets(
+            isFrozen: Boolean,
+            denyOnBackgroundChain: Boolean,
+            enableBackgroundChain: Boolean,
+            expectDestroySockets: Boolean
+    ) {
+        val netdEventListener = getRegisteredNetdUnsolicitedEventListener()
+        val frozenStateCallback = getUidFrozenStateChangedCallback()
+
+        // Make cell default network idle
+        val cellAgent = Agent(nc = cellNc(), lp = cellLp())
+        cellAgent.connect()
+        reportNetworkActivity(netdEventListener, cellAgent.network.netId, isActive = false)
+
+        // Set deny rule on background chain for TEST_UID
+        setBackgroundChainRule(
+                uid = TEST_UID,
+                rule = FIREWALL_RULE_DENY,
+                blockedReasons = BLOCKED_REASON_APP_BACKGROUND
+        )
+
+        // Freeze TEST_UID
+        frozenStateCallback.onUidFrozenStateChanged(
+                intArrayOf(TEST_UID),
+                intArrayOf(UID_FROZEN_STATE_FROZEN)
+        )
+
+        if (!isFrozen) {
+            // Unfreeze TEST_UID
+            frozenStateCallback.onUidFrozenStateChanged(
+                    intArrayOf(TEST_UID),
+                    intArrayOf(UID_FROZEN_STATE_UNFROZEN)
+            )
+        }
+        if (!enableBackgroundChain) {
+            // Disable background chain
+            cm.setFirewallChainEnabled(FIREWALL_CHAIN_BACKGROUND, false)
+        }
+        if (!denyOnBackgroundChain) {
+            // Set allow rule on background chain for TEST_UID
+            setBackgroundChainRule(
+                    uid = TEST_UID,
+                    rule = FIREWALL_RULE_ALLOW,
+                    blockedReasons = BLOCKED_REASON_NONE
+            )
+        }
+        // Process all posted events first, otherwise the check below could pass only because
+        // the events have not been handled yet. Sockets must not be destroyed while the cell
+        // network is idle.
+        waitForIdle()
+        verify(destroySocketsWrapper, never()).destroyLiveTcpSocketsByOwnerUids(setOf(TEST_UID))
+
+        // Make cell network active
+        reportNetworkActivity(netdEventListener, cellAgent.network.netId, isActive = true)
+        waitForIdle()
+
+        if (expectDestroySockets) {
+            verify(destroySocketsWrapper).destroyLiveTcpSocketsByOwnerUids(setOf(TEST_UID))
+        } else {
+            verify(destroySocketsWrapper, never()).destroyLiveTcpSocketsByOwnerUids(setOf(TEST_UID))
+        }
+
+        cellAgent.disconnect()
+    }
+
+    @Test
+    fun testDestroySockets_backgroundDeny_frozen() {
+        doTestDestroySockets(
+                isFrozen = true,
+                denyOnBackgroundChain = true,
+                enableBackgroundChain = true,
+                expectDestroySockets = true
+        )
+    }
+
+    @Test
+    fun testDestroySockets_backgroundDeny_nonFrozen() {
+        doTestDestroySockets(
+                isFrozen = false,
+                denyOnBackgroundChain = true,
+                enableBackgroundChain = true,
+                expectDestroySockets = true
+        )
+    }
+
+    @Test
+    fun testDestroySockets_backgroundAllow_frozen() {
+        doTestDestroySockets(
+                isFrozen = true,
+                denyOnBackgroundChain = false,
+                enableBackgroundChain = true,
+                expectDestroySockets = true
+        )
+    }
+
+    @Test
+    fun testDestroySockets_backgroundAllow_nonFrozen() {
+        // If the app is neither frozen nor under background restriction, sockets are not
+        // destroyed
+        doTestDestroySockets(
+                isFrozen = false,
+                denyOnBackgroundChain = false,
+                enableBackgroundChain = true,
+                expectDestroySockets = false
+        )
+    }
+
+    @Test
+    fun testDestroySockets_backgroundChainDisabled_nonFrozen() {
+        // If the app is not frozen and the background chain is disabled, sockets are not
+        // destroyed even though the app has a deny rule on the chain
+        doTestDestroySockets(
+                isFrozen = false,
+                denyOnBackgroundChain = true,
+                enableBackgroundChain = false,
+                expectDestroySockets = false
+        )
+    }
+}
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
index 99a8a3d..47a6763 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
@@ -210,6 +210,7 @@
 
     val multicastRoutingCoordinatorService = mock<MulticastRoutingCoordinatorService>()
     val satelliteAccessController = mock<SatelliteAccessController>()
+    val destroySocketsWrapper = mock<DestroySocketsWrapper>()
 
     val deps = CSDeps()
 
@@ -263,6 +264,11 @@
         alarmHandlerThread.join()
     }
 
+    // Mocked by tests to verify the socket destroy calls that CSDeps forwards to it.
+    open class DestroySocketsWrapper {
+        open fun destroyLiveTcpSocketsByOwnerUids(ownerUids: Set<Int>) {}
+    }
+
     inner class CSDeps : ConnectivityService.Dependencies() {
         override fun getResources(ctx: Context) = connResources
         override fun getBpfNetMaps(context: Context, netd: INetd) = this@CSTest.bpfNetMaps
@@ -368,6 +374,11 @@
 
         override fun getCallingUid() =
                 if (callingUid == CALLING_UID_UNMOCKED) super.getCallingUid() else callingUid
+
+        // Forward to the mocked wrapper so that tests can verify this call; the real
+        // implementation is not invoked, so no sockets are actually destroyed.
+        override fun destroyLiveTcpSocketsByOwnerUids(ownerUids: Set<Int>) =
+                destroySocketsWrapper.destroyLiveTcpSocketsByOwnerUids(ownerUids)
     }
 
     inner class CSContext(base: Context) : BroadcastInterceptingContext(base) {