Suspend callbacks for frozen apps

To avoid filling the binder buffers, and sending unnecessary callbacks
when apps are unfrozen (for example successive LinkProperties changes or
onAvailable+onLost are not useful), queue and aggregate callbacks for
apps that are suspended.

Test: atest
Bug: 327038794
Bug: 279392981
Change-Id: I6de357ac2834f1f7e960dfbaa63adfc3825f2a82
diff --git a/framework/src/android/net/ConnectivityManager.java b/framework/src/android/net/ConnectivityManager.java
index a6a967b..8cf6e04 100644
--- a/framework/src/android/net/ConnectivityManager.java
+++ b/framework/src/android/net/ConnectivityManager.java
@@ -4163,6 +4163,8 @@
          */
         @FilteredCallback(methodId = METHOD_ONAVAILABLE_5ARGS,
                 calledByCallbackId = CALLBACK_AVAILABLE,
+                // If this list is modified, ConnectivityService#addAvailableStateUpdateCallbacks
+                // needs to be updated too.
                 mayCall = { METHOD_ONAVAILABLE_4ARGS,
                         METHOD_ONLOCALNETWORKINFOCHANGED,
                         METHOD_ONBLOCKEDSTATUSCHANGED_INT })
@@ -4193,6 +4195,8 @@
          */
         @FilteredCallback(methodId = METHOD_ONAVAILABLE_4ARGS,
                 calledByCallbackId = CALLBACK_TRANSITIVE_CALLS_ONLY,
+                // If this list is modified, ConnectivityService#addAvailableStateUpdateCallbacks
+                // needs to be updated too.
                 mayCall = { METHOD_ONAVAILABLE_1ARG,
                         METHOD_ONNETWORKSUSPENDED,
                         METHOD_ONCAPABILITIESCHANGED,
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 953fd76..6ae7505 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -56,6 +56,7 @@
 import static android.net.ConnectivityManager.FIREWALL_RULE_ALLOW;
 import static android.net.ConnectivityManager.FIREWALL_RULE_DEFAULT;
 import static android.net.ConnectivityManager.FIREWALL_RULE_DENY;
+import static android.net.ConnectivityManager.NETID_UNSET;
 import static android.net.ConnectivityManager.NetworkCallback.DECLARED_METHODS_ALL;
 import static android.net.ConnectivityManager.NetworkCallback.DECLARED_METHODS_NONE;
 import static android.net.ConnectivityManager.TYPE_BLUETOOTH;
@@ -2147,7 +2148,7 @@
     }
 
     @VisibleForTesting
-    void updateMobileDataPreferredUids() {
+    public void updateMobileDataPreferredUids() {
         mHandler.sendEmptyMessage(EVENT_MOBILE_DATA_PREFERRED_UIDS_CHANGED);
     }
 
@@ -3403,7 +3404,7 @@
     }
 
     @VisibleForTesting
-    void handleBlockedReasonsChanged(List<Pair<Integer, Integer>> reasonsList) {
+    public void handleBlockedReasonsChanged(List<Pair<Integer, Integer>> reasonsList) {
         for (Pair<Integer, Integer> reasons: reasonsList) {
             final int uid = reasons.first;
             final int blockedReasons = reasons.second;
@@ -3472,7 +3473,7 @@
     private void handleFrozenUids(int[] uids, int[] frozenStates) {
         ensureRunningOnConnectivityServiceThread();
         handleDestroyFrozenSockets(uids, frozenStates);
-        // TODO: handle freezing NetworkCallbacks
+        handleFreezeNetworkCallbacks(uids, frozenStates);
     }
 
     private void handleDestroyFrozenSockets(int[] uids, int[] frozenStates) {
@@ -3490,6 +3491,73 @@
         }
     }
 
+    private void handleFreezeNetworkCallbacks(int[] uids, int[] frozenStates) {
+        if (!mQueueCallbacksForFrozenApps) {
+            return;
+        }
+        for (int i = 0; i < uids.length; i++) {
+            final int uid = uids[i];
+            // These counters may be modified on different threads, but using them here is fine
+            // because this is only an optimization where wrong behavior would only happen if they
+            // are zero even though there is a request registered. This is not possible as they are
+            // always incremented before posting messages to register, and decremented on the
+            // handler thread when unregistering.
+            if (mSystemNetworkRequestCounter.get(uid) == 0
+                    && mNetworkRequestCounter.get(uid) == 0) {
+                // Avoid iterating requests if there isn't any. The counters only track app requests
+                // and not internal requests (for example always-on requests which do not have a
+                // mMessenger), so it does not completely match the content of mRequests. This is OK
+                // as only app requests need to be frozen.
+                continue;
+            }
+
+            if (frozenStates[i] == UID_FROZEN_STATE_FROZEN) {
+                freezeNetworkCallbacksForUid(uid);
+            } else {
+                unfreezeNetworkCallbacksForUid(uid);
+            }
+        }
+    }
+
+    /**
+     * Suspend callbacks for a UID that was just frozen.
+     *
+     * <p>Note that it is not possible for a process to be frozen during a blocking binder call
+     * (see CachedAppOptimizer.freezeBinder), and IConnectivityManager callback registrations are
+     * blocking binder calls, so no callback can be registered while the UID is frozen. This means
+     * it is not necessary to check frozen state on new callback registrations, and calling this
+     * method when a UID is newly frozen is sufficient.
+     *
+     * <p>If it ever becomes possible for a process to be frozen during a blocking binder call,
+     * ConnectivityService will need to handle freezing callbacks that reach ConnectivityService
+     * after the app was frozen when being registered.
+     */
+    private void freezeNetworkCallbacksForUid(int uid) {
+        if (DDBG) Log.d(TAG, "Freezing callbacks for UID " + uid);
+        for (NetworkRequestInfo nri : mNetworkRequests.values()) {
+            if (nri.mUid != uid) continue;
+            // mNetworkRequests can have duplicate values for multilayer requests, but calling
+            // onFrozen multiple times is fine.
+            // If freezeNetworkCallbacksForUid was called multiple times in a raw for a frozen UID
+            // (which would be incorrect), this would also handle it gracefully.
+            nri.onFrozen();
+        }
+    }
+
+    private void unfreezeNetworkCallbacksForUid(int uid) {
+        // This sends all callbacks for one NetworkRequest at a time, which may not be the
+        // same order they were queued in, but different network requests use different
+        // binder objects, so the relative order of their callbacks is not guaranteed.
+        // If callbacks are not queued, callbacks from different binder objects may be
+        // posted on different threads when the process is unfrozen, so even if they were
+        // called a long time apart while the process was frozen, they may still appear in
+        // different order when unfreezing it.
+        for (NetworkRequestInfo nri : mNetworkRequests.values()) {
+            if (nri.mUid != uid) continue;
+            nri.sendQueuedCallbacks();
+        }
+    }
+
     private void handleUpdateFirewallDestroySocketReasons(
             List<Pair<Integer, Integer>> reasonsList) {
         if (!shouldTrackFirewallDestroySocketReasons()) {
@@ -7544,6 +7612,29 @@
         // single NetworkRequest in mRequests.
         final List<NetworkRequest> mRequests;
 
+        /**
+         * List of callbacks that are queued for sending later when the requesting app is unfrozen.
+         *
+         * <p>There may typically be hundreds of NetworkRequestInfo, so a memory-efficient structure
+         * (just an int[]) is used to keep queued callbacks. This reduces the number of object
+         * references.
+         *
+         * <p>This is intended to be used with {@link CallbackQueue} which defines the internal
+         * format.
+         */
+        @NonNull
+        private int[] mQueuedCallbacks = new int[0];
+
+        private static final int MATCHED_NETID_NOT_FROZEN = -1;
+
+        /**
+         * If this request was already satisfied by a network when the requesting UID was frozen,
+         * the netId that was matched at that time. Otherwise, NETID_UNSET if no network was
+         * satisfying this request when frozen (including if this is a listen and not a request),
+         * and MATCHED_NETID_NOT_FROZEN if not frozen.
+         */
+        private int mMatchedNetIdWhenFrozen = MATCHED_NETID_NOT_FROZEN;
+
         // mSatisfier and mActiveRequest rely on one another therefore set them together.
         void setSatisfier(
                 @Nullable final NetworkAgentInfo satisfier,
@@ -7715,6 +7806,8 @@
                 }
                 setSatisfier(satisfier, activeRequest);
             }
+            mMatchedNetIdWhenFrozen = nri.mMatchedNetIdWhenFrozen;
+            mQueuedCallbacks = nri.mQueuedCallbacks;
             mMessenger = nri.mMessenger;
             mBinder = nri.mBinder;
             mPid = nri.mPid;
@@ -7779,11 +7872,190 @@
             }
         }
 
+        /**
+         * Called when this NRI is being frozen.
+         *
+         * <p>Calling this method multiple times when the NRI is frozen is fine. This may happen
+         * if iterating through the NetworkRequest -> NRI map since there are duplicates in the
+         * NRI values for multilayer requests. It may also happen if an app is frozen, killed,
+         * restarted and refrozen since there is no callback sent when processes are killed, but in
+         * that case the callbacks to the killed app do not matter.
+         */
+        void onFrozen() {
+            if (mMatchedNetIdWhenFrozen != MATCHED_NETID_NOT_FROZEN) {
+                // Already frozen
+                return;
+            }
+            if (mSatisfier != null) {
+                mMatchedNetIdWhenFrozen = mSatisfier.network.netId;
+            } else {
+                mMatchedNetIdWhenFrozen = NETID_UNSET;
+            }
+        }
+
+        boolean maybeQueueCallback(@NonNull NetworkAgentInfo nai, int callbackId) {
+            if (mMatchedNetIdWhenFrozen == MATCHED_NETID_NOT_FROZEN) {
+                return false;
+            }
+
+            boolean ignoreThisCallback = false;
+            final int netId = nai.network.netId;
+            final CallbackQueue queue = new CallbackQueue(mQueuedCallbacks);
+            // Based on the new callback, clear previous callbacks that are no longer necessary.
+            // For example, if the network is lost, there is no need to send intermediate callbacks.
+            switch (callbackId) {
+                // PRECHECK is not an API and not very meaningful, do not deliver it for frozen apps
+                // Networks are likely to already be lost when the app is unfrozen, also skip LOSING
+                case CALLBACK_PRECHECK:
+                case CALLBACK_LOSING:
+                    ignoreThisCallback = true;
+                    break;
+                case CALLBACK_LOST:
+                    // All callbacks for this netId before onLost are unnecessary. And onLost itself
+                    // is also unnecessary if onAvailable was previously queued for this netId: the
+                    // Network just appeared and disappeared while the app was frozen.
+                    ignoreThisCallback = queue.hasCallback(netId, CALLBACK_AVAILABLE);
+                    queue.removeCallbacksForNetId(netId);
+                    break;
+                case CALLBACK_AVAILABLE:
+                    if (mSatisfier != null) {
+                        // For requests that are satisfied by individual networks (not LISTEN), when
+                        // AVAILABLE is received, the request is matching a new Network, so previous
+                        // callbacks (for other Networks) are unnecessary.
+                        queue.clear();
+                    }
+                    break;
+                case CALLBACK_SUSPENDED:
+                case CALLBACK_RESUMED:
+                    if (queue.hasCallback(netId, CALLBACK_AVAILABLE)) {
+                        // AVAILABLE will already send the latest suspended status
+                        ignoreThisCallback = true;
+                        break;
+                    }
+                    // If SUSPENDED was queued, just remove it from the queue instead of sending
+                    // RESUMED; and vice-versa.
+                    final int otherCb = callbackId == CALLBACK_SUSPENDED
+                            ? CALLBACK_RESUMED
+                            : CALLBACK_SUSPENDED;
+                    ignoreThisCallback = queue.removeCallbacks(netId, otherCb);
+                    break;
+                case CALLBACK_CAP_CHANGED:
+                case CALLBACK_IP_CHANGED:
+                case CALLBACK_LOCAL_NETWORK_INFO_CHANGED:
+                case CALLBACK_BLK_CHANGED:
+                    ignoreThisCallback = queue.hasCallback(netId, CALLBACK_AVAILABLE);
+                    break;
+                default:
+                    Log.wtf(TAG, "Unexpected callback type: "
+                            + ConnectivityManager.getCallbackName(callbackId));
+                    return false;
+            }
+
+            if (!ignoreThisCallback) {
+                // For non-listen (matching) callbacks, AVAILABLE can appear in the queue twice in a
+                // row for the same network if the new AVAILABLE suppressed intermediate AVAILABLEs
+                // for other networks. Example:
+                // A is matched, app is frozen, B is matched, A is matched again (removes callbacks
+                // for B), app is unfrozen.
+                // In that case call AVAILABLE sub-callbacks to update state, but not AVAILABLE
+                // itself.
+                if (callbackId == CALLBACK_AVAILABLE && netId == mMatchedNetIdWhenFrozen) {
+                    // The queue should have been cleared here, since this is AVAILABLE on a
+                    // non-listen callback (mMatchedNetIdWhenFrozen is set).
+                    addAvailableSubCallbacks(nai, queue);
+                } else {
+                    // When unfreezing, no need to send a callback multiple times for the same netId
+                    queue.removeCallbacks(netId, callbackId);
+                    // TODO: this code always adds the callback for simplicity. It would save
+                    // some CPU/memory if the code instead only added to the queue callbacks where
+                    // isCallbackOverridden=true, or which need to be in the queue because they
+                    // affect other callbacks that are overridden.
+                    queue.addCallback(netId, callbackId);
+                }
+            }
+            // Instead of shrinking the queue, possibly reallocating, the NRI could keep the array
+            // and length in memory for future adds, but this saves memory by avoiding the cost
+            // of an extra member and of unused array length (there are often hundreds of NRIs).
+            mQueuedCallbacks = queue.getShrinkedBackingArray();
+            return true;
+        }
+
+        /**
+         * Called when this NRI is being unfrozen to stop queueing, and send queued callbacks.
+         *
+         * <p>Calling this method multiple times when the NRI is unfrozen (for example iterating
+         * through the NetworkRequest -> NRI map where there are duplicate values for multilayer
+         * requests) is fine.
+         */
+        void sendQueuedCallbacks() {
+            mMatchedNetIdWhenFrozen = MATCHED_NETID_NOT_FROZEN;
+            if (mQueuedCallbacks.length == 0) {
+                return;
+            }
+            new CallbackQueue(mQueuedCallbacks).forEach((netId, callbackId) -> {
+                // For CALLBACK_LOST only, there will not be a NAI for the netId. Build and send the
+                // callback directly.
+                if (callbackId == CALLBACK_LOST) {
+                    if (isCallbackOverridden(CALLBACK_LOST)) {
+                        final Bundle cbBundle = makeCommonBundleForCallback(this,
+                                new Network(netId));
+                        callCallbackForRequest(this, CALLBACK_LOST, cbBundle, 0 /* arg1 */);
+                    }
+                    return; // Next item in forEach
+                }
+
+                // Other callbacks should always have a NAI, because if a Network disconnects
+                // LOST will be called, unless the request is no longer satisfied by that Network in
+                // which case AVAILABLE will have been called for another Network. In both cases
+                // previous callbacks are cleared.
+                final NetworkAgentInfo nai = getNetworkAgentInfoForNetId(netId);
+                if (nai == null) {
+                    Log.wtf(TAG, "Missing NetworkAgentInfo for net " + netId
+                            + " for callback " + callbackId);
+                    return; // Next item in forEach
+                }
+
+                final int arg1 =
+                        callbackId == CALLBACK_AVAILABLE || callbackId == CALLBACK_BLK_CHANGED
+                                ? getBlockedState(nai, mAsUid)
+                                : 0;
+                callCallbackForRequest(this, nai, callbackId, arg1);
+            });
+            mQueuedCallbacks = new int[0];
+        }
+
         boolean isCallbackOverridden(int callbackId) {
             return !mUseDeclaredMethodsForCallbacksEnabled
                     || (mDeclaredMethodsFlags & (1 << callbackId)) != 0;
         }
 
+        /**
+         * Queue all callbacks that are called by AVAILABLE, except onAvailable.
+         *
+         * <p>AVAILABLE may call SUSPENDED, CAP_CHANGED, IP_CHANGED, LOCAL_NETWORK_INFO_CHANGED,
+         * and BLK_CHANGED, in this order.
+         */
+        private void addAvailableSubCallbacks(
+                @NonNull NetworkAgentInfo nai, @NonNull CallbackQueue queue) {
+            final boolean callSuspended =
+                    !nai.networkCapabilities.hasCapability(NET_CAPABILITY_NOT_SUSPENDED);
+            final boolean callLocalInfoChanged = nai.isLocalNetwork();
+
+            final int cbCount = 3 + (callSuspended ? 1 : 0) + (callLocalInfoChanged ? 1 : 0);
+            // Avoid unnecessary re-allocations by reserving enough space for all callbacks to add.
+            queue.ensureHasCapacity(cbCount);
+            final int netId = nai.network.netId;
+            if (callSuspended) {
+                queue.addCallback(netId, CALLBACK_SUSPENDED);
+            }
+            queue.addCallback(netId, CALLBACK_CAP_CHANGED);
+            queue.addCallback(netId, CALLBACK_IP_CHANGED);
+            if (callLocalInfoChanged) {
+                queue.addCallback(netId, CALLBACK_LOCAL_NETWORK_INFO_CHANGED);
+            }
+            queue.addCallback(netId, CALLBACK_BLK_CHANGED);
+        }
+
         boolean hasHigherOrderThan(@NonNull final NetworkRequestInfo target) {
             // Compare two preference orders.
             return mPreferenceOrder < target.mPreferenceOrder;
@@ -10277,6 +10549,11 @@
             // are Type.LISTEN, but should not have NetworkCallbacks invoked.
             return;
         }
+        // Even if a callback ends up not being sent, it may affect other callbacks in the queue, so
+        // queue callbacks before checking the declared methods flags.
+        if (networkAgent != null && nri.maybeQueueCallback(networkAgent, notificationType)) {
+            return;
+        }
         if (!nri.isCallbackOverridden(notificationType)) {
             // No need to send the notification as the recipient method is not overridden
             return;
diff --git a/staticlibs/framework/com/android/net/module/util/PerUidCounter.java b/staticlibs/framework/com/android/net/module/util/PerUidCounter.java
index 463b0c4..98d91a5 100644
--- a/staticlibs/framework/com/android/net/module/util/PerUidCounter.java
+++ b/staticlibs/framework/com/android/net/module/util/PerUidCounter.java
@@ -87,7 +87,9 @@
         }
     }
 
-    @VisibleForTesting
+    /**
+     * Get the current counter value for the given uid.
+     */
     public synchronized int get(int uid) {
         return mUidToCount.get(uid, 0);
     }
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSQueuedCallbacksTest.kt b/tests/unit/java/com/android/server/connectivityservice/CSQueuedCallbacksTest.kt
new file mode 100644
index 0000000..fc2a06c
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivityservice/CSQueuedCallbacksTest.kt
@@ -0,0 +1,636 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivityservice
+
+import android.app.ActivityManager.UidFrozenStateChangedCallback
+import android.app.ActivityManager.UidFrozenStateChangedCallback.UID_FROZEN_STATE_FROZEN
+import android.app.ActivityManager.UidFrozenStateChangedCallback.UID_FROZEN_STATE_UNFROZEN
+import android.net.ConnectivityManager.BLOCKED_METERED_REASON_DATA_SAVER
+import android.net.ConnectivityManager.BLOCKED_REASON_NONE
+import android.net.ConnectivitySettingsManager
+import android.net.INetworkMonitor.NETWORK_VALIDATION_PROBE_DNS
+import android.net.INetworkMonitor.NETWORK_VALIDATION_PROBE_HTTPS
+import android.net.INetworkMonitor.NETWORK_VALIDATION_RESULT_VALID
+import android.net.LinkAddress
+import android.net.LinkProperties
+import android.net.LocalNetworkConfig
+import android.net.NetworkCapabilities
+import android.net.NetworkCapabilities.NET_CAPABILITY_FOREGROUND
+import android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET
+import android.net.NetworkCapabilities.NET_CAPABILITY_LOCAL_NETWORK
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_CONGESTED
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED
+import android.net.NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED
+import android.net.NetworkCapabilities.NET_CAPABILITY_VALIDATED
+import android.net.NetworkCapabilities.TRANSPORT_CELLULAR
+import android.net.NetworkCapabilities.TRANSPORT_ETHERNET
+import android.net.NetworkCapabilities.TRANSPORT_WIFI
+import android.net.NetworkPolicyManager.NetworkPolicyCallback
+import android.net.NetworkRequest
+import android.os.Build
+import android.os.Process
+import com.android.server.CALLING_UID_UNMOCKED
+import com.android.server.CSAgentWrapper
+import com.android.server.CSTest
+import com.android.server.FromS
+import com.android.server.HANDLER_TIMEOUT_MS
+import com.android.server.connectivity.ConnectivityFlags.QUEUE_CALLBACKS_FOR_FROZEN_APPS
+import com.android.server.defaultLp
+import com.android.server.defaultNc
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
+import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
+import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
+import com.android.testutils.RecorderCallback.CallbackEntry.LocalInfoChanged
+import com.android.testutils.RecorderCallback.CallbackEntry.Lost
+import com.android.testutils.RecorderCallback.CallbackEntry.Resumed
+import com.android.testutils.RecorderCallback.CallbackEntry.Suspended
+import com.android.testutils.TestableNetworkCallback
+import com.android.testutils.visibleOnHandlerThread
+import com.android.testutils.waitForIdleSerialExecutor
+import java.util.Collections
+import kotlin.test.fail
+import org.junit.Before
+import org.junit.Rule
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+import org.mockito.ArgumentCaptor
+import org.mockito.Mockito.any
+import org.mockito.Mockito.verify
+
+private const val TEST_UID = 42
+
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
+@DevSdkIgnoreRunner.MonitorThreadLeak
+@RunWith(DevSdkIgnoreRunner::class)
+class CSQueuedCallbacksTest(freezingBehavior: FreezingBehavior) : CSTest() {
+    companion object {
+        enum class FreezingBehavior {
+            UID_FROZEN,
+            UID_NOT_FROZEN,
+            UID_FROZEN_FEATURE_DISABLED
+        }
+
+        // Use a parameterized test with / without freezing to make it easy to compare and make sure
+        // freezing behavior (which callbacks are sent in which order) stays close to what happens
+        // without freezing.
+        @JvmStatic
+        @Parameterized.Parameters(name = "freezingBehavior={0}")
+        fun freezingBehavior() = listOf(
+            FreezingBehavior.UID_FROZEN,
+            FreezingBehavior.UID_NOT_FROZEN,
+            FreezingBehavior.UID_FROZEN_FEATURE_DISABLED
+        )
+
+        private val TAG = CSQueuedCallbacksTest::class.simpleName
+            ?: fail("Could not get test class name")
+    }
+
+    @get:Rule
+    val ignoreRule = DevSdkIgnoreRule()
+
+    private val mockedBlockedReasonsPerUid = Collections.synchronizedMap(mutableMapOf(
+        Process.myUid() to BLOCKED_REASON_NONE,
+        TEST_UID to BLOCKED_REASON_NONE
+    ))
+
+    private val freezeUids = freezingBehavior != FreezingBehavior.UID_NOT_FROZEN
+    private val expectAllCallbacks = freezingBehavior == FreezingBehavior.UID_NOT_FROZEN ||
+            freezingBehavior == FreezingBehavior.UID_FROZEN_FEATURE_DISABLED
+    init {
+        setFeatureEnabled(
+            QUEUE_CALLBACKS_FOR_FROZEN_APPS,
+            freezingBehavior != FreezingBehavior.UID_FROZEN_FEATURE_DISABLED
+        )
+    }
+
+    @Before
+    fun subclassSetUp() {
+        // Ensure cellular stays up. CS is recreated for each test so no cleanup is necessary.
+//        cm.requestNetwork(
+//            NetworkRequest.Builder().addTransportType(TRANSPORT_CELLULAR).build(),
+//            TestableNetworkCallback()
+//        )
+    }
+
+    @Test
+    fun testFrozenWhileNetworkConnects_UpdatesAreReceived() {
+        val cb = TestableNetworkCallback(logTag = TAG)
+        withCallingUid(TEST_UID) {
+            cm.registerNetworkCallback(NetworkRequest.Builder().build(), cb)
+        }
+        val agent = Agent(TRANSPORT_CELLULAR).apply { connect() }
+        val lpChangeOnConnect = agent.sendLpChange { setLinkAddresses("fe80:db8::123/64") }
+        val ncChangeOnConnect = agent.sendNcChange {
+            addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
+        }
+
+        maybeSetUidsFrozen(true, TEST_UID)
+
+        val lpChange1WhileFrozen = agent.sendLpChange {
+            setLinkAddresses("fe80:db8::126/64")
+        }
+        val ncChange1WhileFrozen = agent.sendNcChange {
+            removeCapability(NET_CAPABILITY_NOT_ROAMING)
+        }
+        val ncChange2WhileFrozen = agent.sendNcChange {
+            addCapability(NET_CAPABILITY_NOT_ROAMING)
+            addCapability(NET_CAPABILITY_NOT_CONGESTED)
+        }
+        val lpChange2WhileFrozen = agent.sendLpChange {
+            setLinkAddresses("fe80:db8::125/64")
+        }
+        maybeSetUidsFrozen(false, TEST_UID)
+
+        // Verify callbacks that are sent before freezing
+        cb.expectAvailableCallbacks(agent.network, validated = false)
+        cb.expectLpWith(agent, lpChangeOnConnect)
+        cb.expectNcWith(agent, ncChangeOnConnect)
+
+        // Below callbacks should be skipped if the processes were frozen, since a single callback
+        // will be sent with the latest state after unfreezing
+        if (expectAllCallbacks) {
+            cb.expectLpWith(agent, lpChange1WhileFrozen)
+            cb.expectNcWith(agent, ncChange1WhileFrozen)
+        }
+
+        cb.expectNcWith(agent, ncChange2WhileFrozen)
+        cb.expectLpWith(agent, lpChange2WhileFrozen)
+
+        cb.assertNoCallback(timeoutMs = 0L)
+    }
+
+    @Test
+    fun testFrozenWhileNetworkConnects_SuspendedUnsuspendedWhileFrozen() {
+        val cb = TestableNetworkCallback(logTag = TAG)
+        withCallingUid(TEST_UID) {
+            cm.registerNetworkCallback(NetworkRequest.Builder().build(), cb)
+        }
+
+        val agent = Agent(TRANSPORT_CELLULAR).apply { connect() }
+        maybeSetUidsFrozen(true, TEST_UID)
+        val rmCap = agent.sendNcChange { removeCapability(NET_CAPABILITY_NOT_SUSPENDED) }
+        val addCap = agent.sendNcChange { addCapability(NET_CAPABILITY_NOT_SUSPENDED) }
+
+        maybeSetUidsFrozen(false, TEST_UID)
+
+        cb.expectAvailableCallbacks(agent.network, validated = false)
+        if (expectAllCallbacks) {
+            cb.expectNcWith(agent, rmCap)
+            cb.expect<Suspended>(agent)
+            cb.expectNcWith(agent, addCap)
+            cb.expect<Resumed>(agent)
+        } else {
+            // When frozen, a single NetworkCapabilitiesChange will be sent at unfreezing time,
+            // with nc actually identical to the original ones. This is because NetworkCapabilities
+            // callbacks were sent, but CS does not keep initial NetworkCapabilities in memory, so
+            // it cannot detect A->B->A.
+            cb.expect<CapabilitiesChanged>(agent) {
+                it.caps.hasCapability(NET_CAPABILITY_NOT_SUSPENDED)
+            }
+        }
+        cb.assertNoCallback(timeoutMs = 0L)
+    }
+
+    @Test
+    fun testFrozenWhileNetworkConnects_UnsuspendedWhileFrozen_GetResumedCallbackWhenUnfrozen() {
+        val cb = TestableNetworkCallback(logTag = TAG)
+        withCallingUid(TEST_UID) {
+            cm.registerNetworkCallback(NetworkRequest.Builder().build(), cb)
+        }
+
+        val agent = Agent(TRANSPORT_CELLULAR).apply { connect() }
+        val rmCap = agent.sendNcChange { removeCapability(NET_CAPABILITY_NOT_SUSPENDED) }
+        maybeSetUidsFrozen(true, TEST_UID)
+        val addCap = agent.sendNcChange { addCapability(NET_CAPABILITY_NOT_SUSPENDED) }
+        maybeSetUidsFrozen(false, TEST_UID)
+
+        cb.expectAvailableCallbacks(agent.network, validated = false)
+        cb.expectNcWith(agent, rmCap)
+        cb.expect<Suspended>(agent)
+        cb.expectNcWith(agent, addCap)
+        cb.expect<Resumed>(agent)
+        cb.assertNoCallback(timeoutMs = 0L)
+    }
+
+    @Test
+    fun testFrozenWhileNetworkConnects_BlockedUnblockedWhileFrozen_SingleCallbackIfFrozen() {
+        val cb = TestableNetworkCallback(logTag = TAG)
+        withCallingUid(TEST_UID) {
+            cm.registerNetworkCallback(NetworkRequest.Builder().build(), cb)
+        }
+        val agent = Agent(TRANSPORT_CELLULAR).apply { connect() }
+
+        maybeSetUidsFrozen(true, TEST_UID)
+        setUidsBlockedForDataSaver(true, TEST_UID)
+        setUidsBlockedForDataSaver(false, TEST_UID)
+        maybeSetUidsFrozen(false, TEST_UID)
+
+        cb.expectAvailableCallbacks(agent.network, validated = false)
+        if (expectAllCallbacks) {
+            cb.expect<BlockedStatus>(agent) { it.blocked }
+        }
+        // The unblocked callback is sent in any case (with the latest blocked reason), as the
+        // blocked reason may have changed, and ConnectivityService cannot know that it is the same
+        // as the original reason as it does not keep pre-freeze blocked reasons in memory.
+        cb.expect<BlockedStatus>(agent) { !it.blocked }
+        cb.assertNoCallback(timeoutMs = 0L)
+    }
+
+    @Test
+    fun testFrozenWhileNetworkConnects_BlockedWhileFrozen_GetLastBlockedCallbackOnlyIfFrozen() {
+        val cb = TestableNetworkCallback(logTag = TAG)
+        withCallingUid(TEST_UID) {
+            cm.registerNetworkCallback(NetworkRequest.Builder().build(), cb)
+        }
+        val agent = Agent(TRANSPORT_CELLULAR).apply { connect() }
+
+        maybeSetUidsFrozen(true, TEST_UID)
+        setUidsBlockedForDataSaver(true, TEST_UID)
+        setUidsBlockedForDataSaver(false, TEST_UID)
+        setUidsBlockedForDataSaver(true, TEST_UID)
+        maybeSetUidsFrozen(false, TEST_UID)
+
+        cb.expectAvailableCallbacks(agent.network, validated = false)
+        if (expectAllCallbacks) {
+            cb.expect<BlockedStatus>(agent) { it.blocked }
+            cb.expect<BlockedStatus>(agent) { !it.blocked }
+        }
+        cb.expect<BlockedStatus>(agent) { it.blocked }
+        cb.assertNoCallback(timeoutMs = 0L)
+    }
+
+    @Test
+    fun testNetworkCallback_NetworkToggledWhileFrozen_NotSeen() {
+        val cb = TestableNetworkCallback(logTag = TAG)
+        withCallingUid(TEST_UID) {
+            cm.registerNetworkCallback(NetworkRequest.Builder().build(), cb)
+        }
+        val cellAgent = Agent(TRANSPORT_CELLULAR).apply { connect() }
+        maybeSetUidsFrozen(true, TEST_UID)
+        val wifiAgent = Agent(TRANSPORT_WIFI).apply { connect() }
+        wifiAgent.disconnect()
+        maybeSetUidsFrozen(false, TEST_UID)
+
+        cb.expectAvailableCallbacks(cellAgent.network, validated = false)
+        if (expectAllCallbacks) {
+            cb.expectAvailableCallbacks(wifiAgent.network, validated = false)
+            cb.expect<Lost>(wifiAgent)
+        }
+        cb.assertNoCallback(timeoutMs = 0L)
+    }
+
+    @Test
+    fun testNetworkCallback_NetworkAppearedWhileFrozen_ReceiveLatestInfoInOnAvailable() {
+        val cb = TestableNetworkCallback(logTag = TAG)
+        withCallingUid(TEST_UID) {
+            cm.registerNetworkCallback(NetworkRequest.Builder().build(), cb)
+        }
+        maybeSetUidsFrozen(true, TEST_UID)
+        val agent = Agent(TRANSPORT_CELLULAR).apply { connect() }
+        waitForIdle()
+        agent.makeValidationSuccess()
+        val lpChange = agent.sendLpChange {
+            setLinkAddresses("fe80:db8::123/64")
+        }
+        val suspendedChange = agent.sendNcChange {
+            removeCapability(NET_CAPABILITY_NOT_SUSPENDED)
+        }
+        setUidsBlockedForDataSaver(true, TEST_UID)
+
+        maybeSetUidsFrozen(false, TEST_UID)
+
+        val expectLatestStatusInOnAvailable = !expectAllCallbacks
+        cb.expectAvailableCallbacks(
+            agent.network,
+            suspended = expectLatestStatusInOnAvailable,
+            validated = expectLatestStatusInOnAvailable,
+            blocked = expectLatestStatusInOnAvailable
+        )
+        if (expectAllCallbacks) {
+            cb.expectNcWith(agent) { addCapability(NET_CAPABILITY_VALIDATED) }
+            cb.expectLpWith(agent, lpChange)
+            cb.expectNcWith(agent, suspendedChange)
+            cb.expect<Suspended>(agent)
+            cb.expect<BlockedStatus>(agent) { it.blocked }
+        }
+        cb.assertNoCallback(timeoutMs = 0L)
+    }
+
+    @Test
+    @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
+    fun testNetworkCallback_LocalNetworkAppearedWhileFrozen_ReceiveLatestInfoInOnAvailable() {
+        val cb = TestableNetworkCallback(logTag = TAG)
+        withCallingUid(TEST_UID) {
+            cm.registerNetworkCallback(
+                NetworkRequest.Builder().addCapability(NET_CAPABILITY_LOCAL_NETWORK).build(),
+                cb
+            )
+        }
+        val upstreamAgent = Agent(
+            nc = defaultNc()
+                .addTransportType(TRANSPORT_WIFI)
+                .addCapability(NET_CAPABILITY_INTERNET),
+            lp = defaultLp().apply { interfaceName = "wlan0" }
+        ).apply { connect() }
+        maybeSetUidsFrozen(true, TEST_UID)
+
+        val lnc = LocalNetworkConfig.Builder().build()
+        val localAgent = Agent(
+            nc = defaultNc()
+                .addCapability(NET_CAPABILITY_LOCAL_NETWORK)
+                .removeCapability(NET_CAPABILITY_INTERNET),
+            lp = defaultLp().apply { interfaceName = "local42" },
+            lnc = FromS(lnc)
+        ).apply { connect() }
+        localAgent.sendLocalNetworkConfig(
+            LocalNetworkConfig.Builder()
+                .setUpstreamSelector(
+                    NetworkRequest.Builder()
+                        .addCapability(NET_CAPABILITY_INTERNET)
+                        .build()
+                )
+                .build()
+        )
+
+        maybeSetUidsFrozen(false, TEST_UID)
+
+        cb.expectAvailableCallbacks(
+            localAgent.network,
+            validated = false,
+            upstream = if (expectAllCallbacks) null else upstreamAgent.network
+        )
+        if (expectAllCallbacks) {
+            cb.expect<LocalInfoChanged>(localAgent) {
+                it.info.upstreamNetwork == upstreamAgent.network
+            }
+        }
+        cb.assertNoCallback(timeoutMs = 0L)
+    }
+
+    @Test
+    fun testNetworkRequest_NetworkSwitchesWhileFrozen_ReceiveLastNetworkUpdatesOnly() {
+        val cb = TestableNetworkCallback(logTag = TAG)
+        withCallingUid(TEST_UID) {
+            cm.requestNetwork(NetworkRequest.Builder().build(), cb)
+        }
+        val cellAgent = Agent(TRANSPORT_CELLULAR).apply { connect() }
+        maybeSetUidsFrozen(true, TEST_UID)
+        val wifiAgent = Agent(TRANSPORT_WIFI).apply { connect() }
+        val ethAgent = Agent(TRANSPORT_ETHERNET).apply { connect() }
+        waitForIdle()
+        ethAgent.makeValidationSuccess()
+        maybeSetUidsFrozen(false, TEST_UID)
+
+        cb.expectAvailableCallbacks(cellAgent.network, validated = false)
+        if (expectAllCallbacks) {
+            cb.expectAvailableCallbacks(wifiAgent.network, validated = false)
+            cb.expectAvailableCallbacks(ethAgent.network, validated = false)
+            cb.expectNcWith(ethAgent) { addCapability(NET_CAPABILITY_VALIDATED) }
+        } else {
+            cb.expectAvailableCallbacks(ethAgent.network, validated = true)
+        }
+        cb.assertNoCallback(timeoutMs = 0L)
+    }
+
+    @Test
+    fun testNetworkRequest_NetworkSwitchesBackWhileFrozen_ReceiveNoAvailableCallback() {
+        val cb = TestableNetworkCallback(logTag = TAG)
+        withCallingUid(TEST_UID) {
+            cm.requestNetwork(NetworkRequest.Builder().build(), cb)
+        }
+        val cellAgent = Agent(TRANSPORT_CELLULAR).apply { connect() }
+        maybeSetUidsFrozen(true, TEST_UID)
+        val wifiAgent = Agent(TRANSPORT_WIFI).apply { connect() }
+        waitForIdle()
+
+        // CS switches back to validated cell over non-validated Wi-Fi
+        cellAgent.makeValidationSuccess()
+        val cellLpChange = cellAgent.sendLpChange {
+            setLinkAddresses("fe80:db8::123/64")
+        }
+        setUidsBlockedForDataSaver(true, TEST_UID)
+        maybeSetUidsFrozen(false, TEST_UID)
+
+        cb.expectAvailableCallbacks(cellAgent.network, validated = false)
+        if (expectAllCallbacks) {
+            cb.expectAvailableCallbacks(wifiAgent.network, validated = false)
+            // There is an extra "double validated" CapabilitiesChange callback (b/245893397), so
+            // callbacks are (AVAIL, NC, LP), extra NC, then further updates (LP and BLK here).
+            cb.expectAvailableDoubleValidatedCallbacks(cellAgent.network)
+            cb.expectLpWith(cellAgent, cellLpChange)
+            cb.expect<BlockedStatus>(cellAgent) { it.blocked }
+        } else {
+            cb.expectNcWith(cellAgent) {
+                addCapability(NET_CAPABILITY_VALIDATED)
+            }
+            cb.expectLpWith(cellAgent, cellLpChange)
+            cb.expect<BlockedStatus>(cellAgent) { it.blocked }
+        }
+        cb.assertNoCallback(timeoutMs = 0L)
+    }
+
+    @Test
+    fun testTrackDefaultRequest_AppFrozenWhilePerAppDefaultRequestFiled_ReceiveChangeCallbacks() {
+        val cellAgent = Agent(TRANSPORT_CELLULAR, baseNc = makeInternetNc()).apply { connect() }
+        waitForIdle()
+
+        val cb = TestableNetworkCallback(logTag = TAG)
+        withCallingUid(TEST_UID) {
+            cm.registerDefaultNetworkCallback(cb)
+        }
+        maybeSetUidsFrozen(true, TEST_UID)
+
+        // Change LinkProperties twice before the per-app network request is applied
+        val lpChange1 = cellAgent.sendLpChange {
+            setLinkAddresses("fe80:db8::123/64")
+        }
+        val lpChange2 = cellAgent.sendLpChange {
+            setLinkAddresses("fe80:db8::124/64")
+        }
+        setMobileDataPreferredUids(setOf(TEST_UID))
+
+        // Change NetworkCapabilities after the per-app network request is applied
+        val ncChange = cellAgent.sendNcChange {
+            addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
+        }
+        maybeSetUidsFrozen(false, TEST_UID)
+
+        // Even if a per-app network request was filed to replace the default network request for
+        // the app, all network change callbacks are received
+        cb.expectAvailableCallbacks(cellAgent.network, validated = false)
+        if (expectAllCallbacks) {
+            cb.expectLpWith(cellAgent, lpChange1)
+        }
+        cb.expectLpWith(cellAgent, lpChange2)
+        cb.expectNcWith(cellAgent, ncChange)
+        cb.assertNoCallback(timeoutMs = 0L)
+    }
+
+    @Test
+    fun testTrackDefaultRequest_AppFrozenWhilePerAppDefaultToggled_GetStatusUpdateCallbacksOnly() {
+        // Add validated Wi-Fi and non-validated cell, expect Wi-Fi is preferred by default
+        val wifiAgent = Agent(TRANSPORT_WIFI, baseNc = makeInternetNc()).apply { connect() }
+        wifiAgent.makeValidationSuccess()
+        val cellAgent = Agent(TRANSPORT_CELLULAR, baseNc = makeInternetNc()).apply { connect() }
+        waitForIdle()
+
+        val cb = TestableNetworkCallback(logTag = TAG)
+        withCallingUid(TEST_UID) {
+            cm.registerDefaultNetworkCallback(cb)
+        }
+        maybeSetUidsFrozen(true, TEST_UID)
+
+        // LP change on the original Wi-Fi network
+        val lpChange = wifiAgent.sendLpChange {
+            setLinkAddresses("fe80:db8::123/64")
+        }
+        // Set per-app default to cell, then unset it
+        setMobileDataPreferredUids(setOf(TEST_UID))
+        setMobileDataPreferredUids(emptySet())
+
+        maybeSetUidsFrozen(false, TEST_UID)
+
+        cb.expectAvailableCallbacks(wifiAgent.network)
+        if (expectAllCallbacks) {
+            cb.expectLpWith(wifiAgent, lpChange)
+            cb.expectAvailableCallbacks(cellAgent.network, validated = false)
+            // Cellular stops being foreground since it is now matched for this app
+            cb.expect<CapabilitiesChanged> { it.caps.hasCapability(NET_CAPABILITY_FOREGROUND) }
+            cb.expectAvailableCallbacks(wifiAgent.network)
+        } else {
+            // After switching to cell and back while frozen, only network attribute update
+            // callbacks (and not AVAILABLE) for the original Wi-Fi network should be sent
+            cb.expect<CapabilitiesChanged>(wifiAgent)
+            cb.expectLpWith(wifiAgent, lpChange)
+            cb.expect<BlockedStatus> { !it.blocked }
+        }
+        cb.assertNoCallback(timeoutMs = 0L)
+    }
+
+    private fun setUidsBlockedForDataSaver(blocked: Boolean, vararg uid: Int) {
+        val reason = if (blocked) {
+            BLOCKED_METERED_REASON_DATA_SAVER
+        } else {
+            BLOCKED_REASON_NONE
+        }
+        if (deps.isAtLeastV) {
+            visibleOnHandlerThread(csHandler) {
+                service.handleBlockedReasonsChanged(uid.map { android.util.Pair(it, reason) })
+            }
+        } else {
+            notifyLegacyBlockedReasonChanged(reason, uid)
+            waitForIdle()
+        }
+    }
+
+    @Suppress("DEPRECATION")
+    private fun notifyLegacyBlockedReasonChanged(reason: Int, uids: IntArray) {
+        val cbCaptor = ArgumentCaptor.forClass(NetworkPolicyCallback::class.java)
+        verify(context.networkPolicyManager).registerNetworkPolicyCallback(
+            any(),
+            cbCaptor.capture()
+        )
+        uids.forEach {
+            cbCaptor.value.onUidBlockedReasonChanged(it, reason)
+        }
+    }
+
+    private fun withCallingUid(uid: Int, action: () -> Unit) {
+        deps.callingUid = uid
+        action()
+        deps.callingUid = CALLING_UID_UNMOCKED
+    }
+
+    private fun getUidFrozenStateChangedCallback(): UidFrozenStateChangedCallback {
+        val captor = ArgumentCaptor.forClass(UidFrozenStateChangedCallback::class.java)
+        verify(activityManager).registerUidFrozenStateChangedCallback(any(), captor.capture())
+        return captor.value
+    }
+
+    private fun maybeSetUidsFrozen(frozen: Boolean, vararg uids: Int) {
+        if (!freezeUids) return
+        val state = if (frozen) UID_FROZEN_STATE_FROZEN else UID_FROZEN_STATE_UNFROZEN
+        getUidFrozenStateChangedCallback()
+            .onUidFrozenStateChanged(uids, IntArray(uids.size) { state })
+        waitForIdle()
+    }
+
+    private fun CSAgentWrapper.makeValidationSuccess() {
+        setValidationResult(
+            NETWORK_VALIDATION_RESULT_VALID,
+            probesCompleted = NETWORK_VALIDATION_PROBE_DNS or NETWORK_VALIDATION_PROBE_HTTPS,
+            probesSucceeded = NETWORK_VALIDATION_PROBE_DNS or NETWORK_VALIDATION_PROBE_HTTPS
+        )
+        cm.reportNetworkConnectivity(network, true)
+        // Ensure validation is scheduled
+        waitForIdle()
+        // Ensure validation completes on mock executor
+        waitForIdleSerialExecutor(CSTestExecutor, HANDLER_TIMEOUT_MS)
+        // Ensure validation results are processed
+        waitForIdle()
+    }
+
+    private fun setMobileDataPreferredUids(uids: Set<Int>) {
+        ConnectivitySettingsManager.setMobileDataPreferredUids(context, uids)
+        service.updateMobileDataPreferredUids()
+        waitForIdle()
+    }
+}
+
+private fun makeInternetNc() = NetworkCapabilities.Builder(defaultNc())
+    .addCapability(NET_CAPABILITY_INTERNET)
+    .build()
+
+private fun CSAgentWrapper.sendLpChange(
+    mutator: LinkProperties.() -> Unit
+): LinkProperties.() -> Unit {
+    lp.mutator()
+    sendLinkProperties(lp)
+    return mutator
+}
+
+private fun CSAgentWrapper.sendNcChange(
+    mutator: NetworkCapabilities.() -> Unit
+): NetworkCapabilities.() -> Unit {
+    nc.mutator()
+    sendNetworkCapabilities(nc)
+    return mutator
+}
+
+private fun TestableNetworkCallback.expectLpWith(
+    agent: CSAgentWrapper,
+    change: LinkProperties.() -> Unit
+) = expect<LinkPropertiesChanged>(agent) {
+    // This test uses changes that are no-op when already applied (idempotent): verify that the
+    // change is already applied.
+    it.lp == LinkProperties(it.lp).apply(change)
+}
+
+private fun TestableNetworkCallback.expectNcWith(
+    agent: CSAgentWrapper,
+    change: NetworkCapabilities.() -> Unit
+) = expect<CapabilitiesChanged>(agent) {
+    it.caps == NetworkCapabilities(it.caps).apply(change)
+}
+
+private fun LinkProperties.setLinkAddresses(vararg addrs: String) {
+    setLinkAddresses(addrs.map { LinkAddress(it) })
+}
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt
index 13c5cbc..1f5ee32 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt
@@ -192,7 +192,8 @@
         connect()
     }
 
-    fun setProbesStatus(probesCompleted: Int, probesSucceeded: Int) {
+    fun setValidationResult(result: Int, probesCompleted: Int, probesSucceeded: Int) {
+        nmValidationResult = result
         nmProbesCompleted = probesCompleted
         nmProbesSucceeded = probesSucceeded
     }
@@ -204,8 +205,10 @@
         // in the beginning. Because NETWORK_VALIDATION_PROBE_HTTP is the decisive probe for captive
         // portal, considering the NETWORK_VALIDATION_PROBE_HTTPS hasn't probed yet and set only
         // DNS and HTTP probes completed.
-        setProbesStatus(
-            NETWORK_VALIDATION_PROBE_DNS or NETWORK_VALIDATION_PROBE_HTTP /* probesCompleted */,
-            VALIDATION_RESULT_INVALID /* probesSucceeded */)
+        setValidationResult(
+            VALIDATION_RESULT_INVALID,
+            probesCompleted = NETWORK_VALIDATION_PROBE_DNS or NETWORK_VALIDATION_PROBE_HTTP,
+            probesSucceeded = NO_PROBE_RESULT
+        )
     }
 }
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
index de56ae5..46c25d2 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
@@ -538,8 +538,12 @@
             provider: NetworkProvider? = null
     ) = CSAgentWrapper(context, deps, csHandlerThread, networkStack,
             nac, nc, lp, lnc, score, provider)
-    fun Agent(vararg transports: Int, lp: LinkProperties = defaultLp()): CSAgentWrapper {
-        val nc = NetworkCapabilities.Builder().apply {
+    fun Agent(
+        vararg transports: Int,
+        baseNc: NetworkCapabilities = defaultNc(),
+        lp: LinkProperties = defaultLp()
+    ): CSAgentWrapper {
+        val nc = NetworkCapabilities.Builder(baseNc).apply {
             transports.forEach {
                 addTransportType(it)
             }