Merge "Migrate ConnectivityManagerTest to callback rule" into main
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/AutoReleaseNetworkCallbackRule.kt b/staticlibs/testutils/devicetests/com/android/testutils/AutoReleaseNetworkCallbackRule.kt
index 28ae609..93422ad 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/AutoReleaseNetworkCallbackRule.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/AutoReleaseNetworkCallbackRule.kt
@@ -21,6 +21,7 @@
 import android.net.Network
 import android.net.NetworkCapabilities
 import android.net.NetworkRequest
+import android.os.Handler
 import androidx.test.platform.app.InstrumentationRegistry
 import com.android.testutils.RecorderCallback.CallbackEntry
 import java.util.Collections
@@ -97,6 +98,15 @@
         cellRequestCb = null
     }
 
+    // Registers [cb] through [registrar] and, unless that throws, records it in cbToCleanup.
+    private fun addCallback(
+        cb: TestableNetworkCallback,
+        registrar: (TestableNetworkCallback) -> Unit
+    ): TestableNetworkCallback = cb.also {
+        registrar(it)
+        cbToCleanup.add(it)
+    }
+
     /**
      * File a request for a Network.
      *
@@ -109,14 +119,27 @@
     @JvmOverloads
     fun requestNetwork(
         request: NetworkRequest,
-        cb: TestableNetworkCallback = TestableNetworkCallback()
-    ): TestableNetworkCallback {
-        cm.requestNetwork(request, cb)
-        cbToCleanup.add(cb)
-        return cb
+        cb: TestableNetworkCallback = TestableNetworkCallback(),
+        handler: Handler? = null
+    ): TestableNetworkCallback = addCallback(cb) { callback ->
+        // Pick the ConnectivityManager overload matching whether a handler was supplied.
+        when (handler) {
+            null -> cm.requestNetwork(request, callback)
+            else -> cm.requestNetwork(request, callback, handler)
+        }
     }
 
     /**
+     * Variant of [requestNetwork] that passes [timeoutMs] through to ConnectivityManager.
+     */
+    @JvmOverloads
+    fun requestNetwork(
+        request: NetworkRequest,
+        cb: TestableNetworkCallback = TestableNetworkCallback(),
+        timeoutMs: Int
+    ): TestableNetworkCallback = addCallback(cb) { cm.requestNetwork(request, it, timeoutMs) }
+
+    /**
      * File a callback for a NetworkRequest.
      *
      * This will fail tests (throw) if the cell network cannot be obtained, or if it was already
@@ -129,13 +152,63 @@
     fun registerNetworkCallback(
         request: NetworkRequest,
         cb: TestableNetworkCallback = TestableNetworkCallback()
-    ): TestableNetworkCallback {
-        cm.registerNetworkCallback(request, cb)
-        cbToCleanup.add(cb)
-        return cb
+    ) = addCallback(cb) { cm.registerNetworkCallback(request, it) }
+
+    /**
+     * Cleanup-tracked wrapper around [ConnectivityManager.registerDefaultNetworkCallback].
+     */
+    @JvmOverloads
+    fun registerDefaultNetworkCallback(
+        cb: TestableNetworkCallback = TestableNetworkCallback(),
+        handler: Handler? = null
+    ): TestableNetworkCallback = addCallback(cb) { callback ->
+        // Without a handler, use the single-argument ConnectivityManager overload.
+        when (handler) {
+            null -> cm.registerDefaultNetworkCallback(callback)
+            else -> cm.registerDefaultNetworkCallback(callback, handler)
+        }
     }
 
     /**
+     * Cleanup-tracked wrapper around [ConnectivityManager.registerSystemDefaultNetworkCallback].
+     */
+    @JvmOverloads
+    fun registerSystemDefaultNetworkCallback(
+        cb: TestableNetworkCallback = TestableNetworkCallback(),
+        handler: Handler
+    ) = addCallback(cb) { cm.registerSystemDefaultNetworkCallback(it, handler) }
+
+    /**
+     * Cleanup-tracked wrapper around [ConnectivityManager.registerDefaultNetworkCallbackForUid].
+     */
+    @JvmOverloads
+    fun registerDefaultNetworkCallbackForUid(
+        uid: Int,
+        cb: TestableNetworkCallback = TestableNetworkCallback(),
+        handler: Handler
+    ) = addCallback(cb) { cm.registerDefaultNetworkCallbackForUid(uid, it, handler) }
+
+    /**
+     * Cleanup-tracked wrapper around [ConnectivityManager.registerBestMatchingNetworkCallback].
+     */
+    @JvmOverloads
+    fun registerBestMatchingNetworkCallback(
+        request: NetworkRequest,
+        cb: TestableNetworkCallback = TestableNetworkCallback(),
+        handler: Handler
+    ) = addCallback(cb) { cm.registerBestMatchingNetworkCallback(request, it, handler) }
+
+    /**
+     * Cleanup-tracked wrapper around [ConnectivityManager.requestBackgroundNetwork].
+     */
+    @JvmOverloads
+    fun requestBackgroundNetwork(
+        request: NetworkRequest,
+        cb: TestableNetworkCallback = TestableNetworkCallback(),
+        handler: Handler
+    ) = addCallback(cb) { cm.requestBackgroundNetwork(request, it, handler) }
+
+    /**
      * Unregister a callback filed using registration methods in this class.
      */
     fun unregisterNetworkCallback(cb: NetworkCallback) {
diff --git a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
index 4d465ba..c0f1080 100644
--- a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
@@ -71,7 +71,9 @@
 import static android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET;
 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_METERED;
 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED;
+import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VPN;
 import static android.net.NetworkCapabilities.NET_CAPABILITY_PARTIAL_CONNECTIVITY;
+import static android.net.NetworkCapabilities.NET_CAPABILITY_TRUSTED;
 import static android.net.NetworkCapabilities.NET_CAPABILITY_VALIDATED;
 import static android.net.NetworkCapabilities.TRANSPORT_TEST;
 import static android.net.NetworkCapabilities.TRANSPORT_VPN;
@@ -81,7 +83,6 @@
 import static android.net.cts.util.CtsNetUtils.HTTP_PORT;
 import static android.net.cts.util.CtsNetUtils.NETWORK_CALLBACK_ACTION;
 import static android.net.cts.util.CtsNetUtils.TEST_HOST;
-import static android.net.cts.util.CtsNetUtils.TestNetworkCallback;
 import static android.net.cts.util.CtsTetheringUtils.TestTetheringEventCallback;
 import static android.os.MessageQueue.OnFileDescriptorEventListener.EVENT_INPUT;
 import static android.os.Process.INVALID_UID;
@@ -333,8 +334,6 @@
     private final ArraySet<Integer> mNetworkTypes = new ArraySet<>();
     private UiAutomation mUiAutomation;
     private CtsNetUtils mCtsNetUtils;
-    // The registered callbacks.
-    private List<NetworkCallback> mRegisteredCallbacks = new ArrayList<>();
     // Used for cleanup purposes.
     private final List<Range<Integer>> mVpnRequiredUidRanges = new ArrayList<>();
 
@@ -425,15 +424,12 @@
         // All tests in this class require a working Internet connection as they start. Make
         // sure there is still one as they end that's ready to use for the next test to use.
         mTestValidationConfigRule.runAfterNextCleanup(() -> {
-            final TestNetworkCallback callback = new TestNetworkCallback();
-            registerDefaultNetworkCallback(callback);
-            try {
-                assertNotNull("Couldn't restore Internet connectivity",
-                        callback.waitForAvailable());
-            } finally {
-                // Unregister all registered callbacks.
-                unregisterRegisteredCallbacks();
-            }
+            // mTestValidationConfigRule has higher order than networkCallbackRule, so
+            // networkCallbackRule is the outer rule and will be cleaned up after this method.
+            final TestableNetworkCallback callback =
+                    networkCallbackRule.registerDefaultNetworkCallback();
+            assertNotNull("Couldn't restore Internet connectivity",
+                    callback.eventuallyExpect(CallbackEntry.AVAILABLE));
         });
     }
 
@@ -993,10 +989,10 @@
         // default network.
         return new NetworkRequest.Builder()
                 .clearCapabilities()
-                .addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED)
-                .addCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
-                .addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
-                .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
+                .addCapability(NET_CAPABILITY_NOT_RESTRICTED)
+                .addCapability(NET_CAPABILITY_TRUSTED)
+                .addCapability(NET_CAPABILITY_NOT_VPN)
+                .addCapability(NET_CAPABILITY_INTERNET)
                 .build();
     }
 
@@ -1027,10 +1023,10 @@
         final String invalidPrivateDnsServer = "invalidhostname.example.com";
         final String goodPrivateDnsServer = "dns.google";
         mCtsNetUtils.storePrivateDnsSetting();
-        final TestableNetworkCallback cb = new TestableNetworkCallback();
         final NetworkRequest networkRequest = new NetworkRequest.Builder()
                 .addCapability(NET_CAPABILITY_INTERNET).build();
-        registerNetworkCallback(networkRequest, cb);
+        final TestableNetworkCallback cb =
+                networkCallbackRule.registerNetworkCallback(networkRequest);
         final Network networkForPrivateDns = mCm.getActiveNetwork();
         try {
             // Verifying the good private DNS sever
@@ -1068,24 +1064,27 @@
         assumeTrue(mPackageManager.hasSystemFeature(FEATURE_WIFI));
 
         // We will register for a WIFI network being available or lost.
-        final TestNetworkCallback callback = new TestNetworkCallback();
-        registerNetworkCallback(makeWifiNetworkRequest(), callback);
+        final TestableNetworkCallback callback = networkCallbackRule.registerNetworkCallback(
+                makeWifiNetworkRequest());
 
-        final TestNetworkCallback defaultTrackingCallback = new TestNetworkCallback();
-        registerDefaultNetworkCallback(defaultTrackingCallback);
+        final TestableNetworkCallback defaultTrackingCallback =
+                networkCallbackRule.registerDefaultNetworkCallback();
 
-        final TestNetworkCallback systemDefaultCallback = new TestNetworkCallback();
-        final TestNetworkCallback perUidCallback = new TestNetworkCallback();
-        final TestNetworkCallback bestMatchingCallback = new TestNetworkCallback();
+        final TestableNetworkCallback systemDefaultCallback = new TestableNetworkCallback();
+        final TestableNetworkCallback perUidCallback = new TestableNetworkCallback();
+        final TestableNetworkCallback bestMatchingCallback = new TestableNetworkCallback();
         final Handler h = new Handler(Looper.getMainLooper());
         if (TestUtils.shouldTestSApis()) {
             assertThrows(SecurityException.class, () ->
-                    registerSystemDefaultNetworkCallback(systemDefaultCallback, h));
+                    networkCallbackRule.registerSystemDefaultNetworkCallback(
+                            systemDefaultCallback, h));
             runWithShellPermissionIdentity(() -> {
-                registerSystemDefaultNetworkCallback(systemDefaultCallback, h);
-                registerDefaultNetworkCallbackForUid(Process.myUid(), perUidCallback, h);
+                networkCallbackRule.registerSystemDefaultNetworkCallback(systemDefaultCallback, h);
+                networkCallbackRule.registerDefaultNetworkCallbackForUid(Process.myUid(),
+                        perUidCallback, h);
             }, NETWORK_SETTINGS);
-            registerBestMatchingNetworkCallback(makeDefaultRequest(), bestMatchingCallback, h);
+            networkCallbackRule.registerBestMatchingNetworkCallback(
+                    makeDefaultRequest(), bestMatchingCallback, h);
         }
 
         Network wifiNetwork = null;
@@ -1094,24 +1093,22 @@
         // Now we should expect to get a network callback about availability of the wifi
         // network even if it was already connected as a state-based action when the callback
         // is registered.
-        wifiNetwork = callback.waitForAvailable();
+        wifiNetwork = callback.eventuallyExpect(CallbackEntry.AVAILABLE).getNetwork();
         assertNotNull("Did not receive onAvailable for TRANSPORT_WIFI request",
                 wifiNetwork);
 
-        final Network defaultNetwork = defaultTrackingCallback.waitForAvailable();
+        final Network defaultNetwork = defaultTrackingCallback.eventuallyExpect(
+                CallbackEntry.AVAILABLE).getNetwork();
         assertNotNull("Did not receive onAvailable on default network callback",
                 defaultNetwork);
 
         if (TestUtils.shouldTestSApis()) {
-            assertNotNull("Did not receive onAvailable on system default network callback",
-                    systemDefaultCallback.waitForAvailable());
-            final Network perUidNetwork = perUidCallback.waitForAvailable();
-            assertNotNull("Did not receive onAvailable on per-UID default network callback",
-                    perUidNetwork);
+            systemDefaultCallback.eventuallyExpect(CallbackEntry.AVAILABLE);
+            final Network perUidNetwork = perUidCallback.eventuallyExpect(CallbackEntry.AVAILABLE)
+                    .getNetwork();
             assertEquals(defaultNetwork, perUidNetwork);
-            final Network bestMatchingNetwork = bestMatchingCallback.waitForAvailable();
-            assertNotNull("Did not receive onAvailable on best matching network callback",
-                    bestMatchingNetwork);
+            final Network bestMatchingNetwork = bestMatchingCallback.eventuallyExpect(
+                    CallbackEntry.AVAILABLE).getNetwork();
             assertEquals(defaultNetwork, bestMatchingNetwork);
         }
     }
@@ -1123,8 +1120,8 @@
         final Handler h = new Handler(Looper.getMainLooper());
         // Verify registerSystemDefaultNetworkCallback can be accessed via
         // CONNECTIVITY_USE_RESTRICTED_NETWORKS permission.
-        runWithShellPermissionIdentity(() ->
-                        registerSystemDefaultNetworkCallback(new TestNetworkCallback(), h),
+        runWithShellPermissionIdentity(
+                () -> networkCallbackRule.registerSystemDefaultNetworkCallback(h),
                 CONNECTIVITY_USE_RESTRICTED_NETWORKS);
     }
 
@@ -1294,15 +1291,14 @@
      */
     @AppModeFull(reason = "CHANGE_NETWORK_STATE permission can't be granted to instant apps")
     @Test
-    public void testRequestNetworkCallback() throws Exception {
-        final TestNetworkCallback callback = new TestNetworkCallback();
-        requestNetwork(new NetworkRequest.Builder()
-                .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
-                .build(), callback);
+    public void testRequestNetworkCallback() {
+        final NetworkRequest request =
+                new NetworkRequest.Builder().addCapability(NET_CAPABILITY_INTERNET).build();
+        final TestableNetworkCallback callback = networkCallbackRule.requestNetwork(request);
 
         // Wait to get callback for availability of internet
+        // eventuallyExpect fails the test on timeout, so its result needs no further check.
+        callback.eventuallyExpect(CallbackEntry.AVAILABLE);
     }
 
     /**
@@ -1320,16 +1316,13 @@
             }
         }
 
-        final TestNetworkCallback callback = new TestNetworkCallback();
-        requestNetwork(new NetworkRequest.Builder().addTransportType(TRANSPORT_WIFI).build(),
-                callback, 100);
 
+        final TestableNetworkCallback callback = networkCallbackRule.requestNetwork(
+                new NetworkRequest.Builder().addTransportType(TRANSPORT_WIFI).build(),
+                100 /* timeoutMs */);
         try {
             // Wait to get callback for unavailability of requested network
-            assertTrue("Did not receive NetworkCallback#onUnavailable",
-                    callback.waitForUnavailable());
-        } catch (InterruptedException e) {
-            fail("NetworkCallback wait was interrupted.");
+            callback.eventuallyExpect(CallbackEntry.UNAVAILABLE, 2_000 /* timeoutMs */);
         } finally {
             if (previousWifiEnabledState) {
                 mCtsNetUtils.connectToWifi();
@@ -1416,40 +1409,48 @@
             final boolean useSystemDefault)
             throws Exception {
         final CompletableFuture<Network> networkFuture = new CompletableFuture<>();
-        final NetworkCallback networkCallback = new NetworkCallback() {
-            @Override
-            public void onCapabilitiesChanged(Network network, NetworkCapabilities nc) {
-                if (!nc.hasTransport(targetTransportType)) return;
 
-                final boolean metered = !nc.hasCapability(NET_CAPABILITY_NOT_METERED);
-                final boolean validated = nc.hasCapability(NET_CAPABILITY_VALIDATED);
-                if (metered == requestedMeteredness && (!waitForValidation || validated)) {
-                    networkFuture.complete(network);
+        // Registering a callback here guarantees onCapabilitiesChanged is called immediately
+        // with the current setting. Therefore, if the setting has already been changed,
+        // this method will return right away, and if not, it'll wait for the setting to change.
+        final TestableNetworkCallback networkCallback;
+        if (useSystemDefault) {
+            networkCallback = runWithShellPermissionIdentity(() -> {
+                if (isAtLeastS()) {
+                    return networkCallbackRule.registerSystemDefaultNetworkCallback(
+                            new Handler(Looper.getMainLooper()));
+                } else {
+                    // registerSystemDefaultNetworkCallback is only supported on S+.
+                    return networkCallbackRule.requestNetwork(
+                            new NetworkRequest.Builder()
+                                    .clearCapabilities()
+                                    .addCapability(NET_CAPABILITY_NOT_RESTRICTED)
+                                    .addCapability(NET_CAPABILITY_TRUSTED)
+                                    .addCapability(NET_CAPABILITY_NOT_VPN)
+                                    .addCapability(NET_CAPABILITY_INTERNET)
+                                    .build(),
+                            new TestableNetworkCallback(),
+                            new Handler(Looper.getMainLooper()));
                 }
-            }
-        };
-
-        try {
-            // Registering a callback here guarantees onCapabilitiesChanged is called immediately
-            // with the current setting. Therefore, if the setting has already been changed,
-            // this method will return right away, and if not, it'll wait for the setting to change.
-            if (useSystemDefault) {
-                runWithShellPermissionIdentity(() ->
-                                registerSystemDefaultNetworkCallback(networkCallback,
-                                        new Handler(Looper.getMainLooper())),
-                        NETWORK_SETTINGS);
-            } else {
-                registerDefaultNetworkCallback(networkCallback);
-            }
-
-            // Changing meteredness on wifi involves reconnecting, which can take several seconds
-            // (involves re-associating, DHCP...).
-            return networkFuture.get(NETWORK_CALLBACK_TIMEOUT_MS, TimeUnit.MILLISECONDS);
-        } catch (TimeoutException e) {
-            throw new AssertionError("Timed out waiting for active network metered status to "
-                    + "change to " + requestedMeteredness + " ; network = "
-                    + mCm.getActiveNetwork(), e);
+            },
+            NETWORK_SETTINGS);
+        } else {
+            networkCallback = networkCallbackRule.registerDefaultNetworkCallback();
         }
+
+        return networkCallback.eventuallyExpect(
+                CallbackEntry.NETWORK_CAPS_UPDATED,
+                // Changing meteredness on wifi involves reconnecting, which can take several
+                // seconds (involves re-associating, DHCP...).
+                NETWORK_CALLBACK_TIMEOUT_MS,
+                cb -> {
+                    final NetworkCapabilities nc = cb.getCaps();
+                    if (!nc.hasTransport(targetTransportType)) return false;
+
+                    final boolean metered = !nc.hasCapability(NET_CAPABILITY_NOT_METERED);
+                    final boolean validated = nc.hasCapability(NET_CAPABILITY_VALIDATED);
+                    return metered == requestedMeteredness && (!waitForValidation || validated);
+                }).getNetwork();
     }
 
     private Network setWifiMeteredStatusAndWait(String ssid, boolean isMetered,
@@ -2091,15 +2092,15 @@
     }
 
     private void verifyBindSocketToRestrictedNetworkDisallowed() throws Exception {
-        final TestableNetworkCallback testNetworkCb = new TestableNetworkCallback();
         final NetworkRequest testRequest = new NetworkRequest.Builder()
                 .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
-                .removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
-                .removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED)
+                .removeCapability(NET_CAPABILITY_TRUSTED)
+                .removeCapability(NET_CAPABILITY_NOT_RESTRICTED)
                 .setNetworkSpecifier(CompatUtil.makeTestNetworkSpecifier(
                         TEST_RESTRICTED_NW_IFACE_NAME))
                 .build();
-        runWithShellPermissionIdentity(() -> requestNetwork(testRequest, testNetworkCb),
+        final TestableNetworkCallback testNetworkCb = runWithShellPermissionIdentity(
+                () -> networkCallbackRule.requestNetwork(testRequest),
                 CONNECTIVITY_USE_RESTRICTED_NETWORKS,
                 // CONNECTIVITY_INTERNAL is for requesting restricted network because shell does not
                 // have CONNECTIVITY_USE_RESTRICTED_NETWORKS on R.
@@ -2115,7 +2116,7 @@
                     NETWORK_CALLBACK_TIMEOUT_MS,
                     entry -> network.equals(entry.getNetwork())
                             && (!((CallbackEntry.CapabilitiesChanged) entry).getCaps()
-                            .hasCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED)));
+                            .hasCapability(NET_CAPABILITY_NOT_RESTRICTED)));
             // CtsNetTestCases package doesn't hold CONNECTIVITY_USE_RESTRICTED_NETWORKS, so it
             // does not allow to bind socket to restricted network.
             assertThrows(IOException.class, () -> network.bindSocket(socket));
@@ -2238,7 +2239,7 @@
 
     private void registerCallbackAndWaitForAvailable(@NonNull final NetworkRequest request,
             @NonNull final TestableNetworkCallback cb) {
-        registerNetworkCallback(request, cb);
+        networkCallbackRule.registerNetworkCallback(request, cb);
         waitForAvailable(cb);
     }
 
@@ -2346,21 +2347,13 @@
 
     private void verifySsidFromCallbackNetworkCapabilities(@NonNull String ssid, boolean hasSsid)
             throws Exception {
-        final CompletableFuture<NetworkCapabilities> foundNc = new CompletableFuture();
-        final NetworkCallback callback = new NetworkCallback() {
-            @Override
-            public void onCapabilitiesChanged(Network network, NetworkCapabilities nc) {
-                foundNc.complete(nc);
-            }
-        };
-
-        registerNetworkCallback(makeWifiNetworkRequest(), callback);
+        final TestableNetworkCallback callback =
+                networkCallbackRule.registerNetworkCallback(makeWifiNetworkRequest());
         // Registering a callback here guarantees onCapabilitiesChanged is called immediately
         // because WiFi network should be connected.
-        final NetworkCapabilities nc =
-                foundNc.get(NETWORK_CALLBACK_TIMEOUT_MS, TimeUnit.MILLISECONDS);
+        final NetworkCapabilities nc = callback.eventuallyExpect(
+                CallbackEntry.NETWORK_CAPS_UPDATED, NETWORK_CALLBACK_TIMEOUT_MS).getCaps();
         // Verify if ssid is contained in the NetworkCapabilities received from callback.
-        assertNotNull("NetworkCapabilities of the network is null", nc);
         assertEquals(hasSsid, Pattern.compile(ssid).matcher(nc.toString()).find());
     }
 
@@ -2389,8 +2382,8 @@
         final NetworkRequest testRequest = new NetworkRequest.Builder()
                 .addTransportType(TRANSPORT_TEST)
                 // Test networks do not have NOT_VPN or TRUSTED capabilities by default
-                .removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
-                .removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
+                .removeCapability(NET_CAPABILITY_NOT_VPN)
+                .removeCapability(NET_CAPABILITY_TRUSTED)
                 .setNetworkSpecifier(CompatUtil.makeTestNetworkSpecifier(
                         testNetworkInterface.getInterfaceName()))
                 .build();
@@ -2399,14 +2392,15 @@
         final TestableNetworkCallback callback = new TestableNetworkCallback();
         final Handler handler = new Handler(Looper.getMainLooper());
         assertThrows(SecurityException.class,
-                () -> requestBackgroundNetwork(testRequest, callback, handler));
+                () -> networkCallbackRule.requestBackgroundNetwork(testRequest, callback, handler));
 
         Network testNetwork = null;
         try {
             // Request background test network via Shell identity which has NETWORK_SETTINGS
             // permission granted.
             runWithShellPermissionIdentity(
-                    () -> requestBackgroundNetwork(testRequest, callback, handler),
+                    () -> networkCallbackRule.requestBackgroundNetwork(
+                            testRequest, callback, handler),
                     new String[] { android.Manifest.permission.NETWORK_SETTINGS });
 
             // Register the test network agent which has no foreground request associated to it.
@@ -2499,9 +2493,10 @@
         final int otherUid = UserHandle.getUid(5, Process.FIRST_APPLICATION_UID);
         final Handler handler = new Handler(Looper.getMainLooper());
 
-        registerDefaultNetworkCallback(myUidCallback, handler);
-        runWithShellPermissionIdentity(() -> registerDefaultNetworkCallbackForUid(
-                otherUid, otherUidCallback, handler), NETWORK_SETTINGS);
+        networkCallbackRule.registerDefaultNetworkCallback(myUidCallback, handler);
+        runWithShellPermissionIdentity(
+                () -> networkCallbackRule.registerDefaultNetworkCallbackForUid(
+                        otherUid, otherUidCallback, handler), NETWORK_SETTINGS);
 
         final Network defaultNetwork = myUidCallback.expect(CallbackEntry.AVAILABLE).getNetwork();
         final List<DetailedBlockedStatusCallback> allCallbacks =
@@ -2557,14 +2552,14 @@
         assertNotNull(info);
         assertEquals(DetailedState.CONNECTED, info.getDetailedState());
 
-        final TestableNetworkCallback callback = new TestableNetworkCallback();
+        final TestableNetworkCallback callback;
         try {
             mCmShim.setLegacyLockdownVpnEnabled(true);
 
             // setLegacyLockdownVpnEnabled is asynchronous and only takes effect when the
             // ConnectivityService handler thread processes it. Ensure it has taken effect by doing
             // something that blocks until the handler thread is idle.
-            registerDefaultNetworkCallback(callback);
+            callback = networkCallbackRule.registerDefaultNetworkCallback();
             waitForAvailable(callback);
 
             // Test one of the effects of setLegacyLockdownVpnEnabled: the fact that any NetworkInfo
@@ -2831,9 +2826,9 @@
     private void registerTestOemNetworkPreferenceCallbacks(
             @NonNull final TestableNetworkCallback defaultCallback,
             @NonNull final TestableNetworkCallback systemDefaultCallback) {
-        registerDefaultNetworkCallback(defaultCallback);
+        networkCallbackRule.registerDefaultNetworkCallback(defaultCallback);
         runWithShellPermissionIdentity(() ->
-                registerSystemDefaultNetworkCallback(systemDefaultCallback,
+                networkCallbackRule.registerSystemDefaultNetworkCallback(systemDefaultCallback,
                         new Handler(Looper.getMainLooper())), NETWORK_SETTINGS);
     }
 
@@ -2949,18 +2944,18 @@
                         + " unless device supports WiFi",
                 mPackageManager.hasSystemFeature(FEATURE_WIFI));
 
-        final TestNetworkCallback cb = new TestNetworkCallback();
         try {
             // Wait for partial connectivity to be detected on the network
             final Network network = preparePartialConnectivity();
 
-            requestNetwork(makeWifiNetworkRequest(), cb);
+            final TestableNetworkCallback cb = networkCallbackRule.requestNetwork(
+                    makeWifiNetworkRequest());
             runAsShell(NETWORK_SETTINGS, () -> {
                 // The always bit is verified in NetworkAgentTest
                 mCm.setAcceptPartialConnectivity(network, false /* accept */, false /* always */);
             });
             // Reject partial connectivity network should cause the network being torn down
-            assertEquals(network, cb.waitForLost());
+            assertEquals(network, cb.eventuallyExpect(CallbackEntry.LOST).getNetwork());
         } finally {
             mHttpServer.stop();
             // Wifi will not automatically reconnect to the network. ensureWifiDisconnected cannot
@@ -2988,7 +2983,6 @@
         assumeTrue("testAcceptPartialConnectivity_validatedNetwork cannot execute"
                         + " unless device supports WiFi and telephony", canRunTest);
 
-        final TestableNetworkCallback wifiCb = new TestableNetworkCallback();
         try {
             // Ensure at least one default network candidate connected.
             networkCallbackRule.requestCell();
@@ -2998,7 +2992,8 @@
             // guarantee that it won't become the default in the future.
             assertNotEquals(wifiNetwork, mCm.getActiveNetwork());
 
-            registerNetworkCallback(makeWifiNetworkRequest(), wifiCb);
+            final TestableNetworkCallback wifiCb = networkCallbackRule.registerNetworkCallback(
+                    makeWifiNetworkRequest());
             runAsShell(NETWORK_SETTINGS, () -> {
                 mCm.setAcceptUnvalidated(wifiNetwork, false /* accept */, false /* always */);
             });
@@ -3025,8 +3020,6 @@
         assumeTrue("testSetAvoidUnvalidated cannot execute"
                 + " unless device supports WiFi and telephony", canRunTest);
 
-        final TestableNetworkCallback wifiCb = new TestableNetworkCallback();
-        final TestableNetworkCallback defaultCb = new TestableNetworkCallback();
         final int previousAvoidBadWifi =
                 ConnectivitySettingsManager.getNetworkAvoidBadWifi(mContext);
 
@@ -3036,8 +3029,10 @@
             final Network cellNetwork = networkCallbackRule.requestCell();
             final Network wifiNetwork = prepareValidatedNetwork();
 
-            registerDefaultNetworkCallback(defaultCb);
-            registerNetworkCallback(makeWifiNetworkRequest(), wifiCb);
+            final TestableNetworkCallback defaultCb =
+                    networkCallbackRule.registerDefaultNetworkCallback();
+            final TestableNetworkCallback wifiCb = networkCallbackRule.registerNetworkCallback(
+                    makeWifiNetworkRequest());
 
             // Verify wifi is the default network.
             defaultCb.eventuallyExpect(CallbackEntry.AVAILABLE, NETWORK_CALLBACK_TIMEOUT_MS,
@@ -3100,20 +3095,11 @@
         });
     }
 
-    private Network expectNetworkHasCapability(Network network, int expectedNetCap, long timeout)
-            throws Exception {
-        final CompletableFuture<Network> future = new CompletableFuture();
-        final NetworkCallback cb = new NetworkCallback() {
-            @Override
-            public void onCapabilitiesChanged(Network n, NetworkCapabilities nc) {
-                if (n.equals(network) && nc.hasCapability(expectedNetCap)) {
-                    future.complete(network);
-                }
-            }
-        };
-
-        registerNetworkCallback(new NetworkRequest.Builder().build(), cb);
-        return future.get(timeout, TimeUnit.MILLISECONDS);
+    private Network expectNetworkHasCapability(Network network, int expectedNetCap, long timeout) {
+        return networkCallbackRule.registerNetworkCallback(new NetworkRequest.Builder().build())
+                .eventuallyExpect(CallbackEntry.NETWORK_CAPS_UPDATED, timeout,
+                        cb -> cb.getNetwork().equals(network)
+                                && cb.getCaps().hasCapability(expectedNetCap)).getNetwork();
     }
 
     private void prepareHttpServer() throws Exception {
@@ -3265,12 +3251,13 @@
         // For testing mobile data preferred uids feature, it needs both wifi and cell network.
         final Network wifiNetwork = mCtsNetUtils.ensureWifiConnected();
         final Network cellNetwork = networkCallbackRule.requestCell();
-        final TestableNetworkCallback defaultTrackingCb = new TestableNetworkCallback();
-        final TestableNetworkCallback systemDefaultCb = new TestableNetworkCallback();
         final Handler h = new Handler(Looper.getMainLooper());
-        runWithShellPermissionIdentity(() -> registerSystemDefaultNetworkCallback(
-                systemDefaultCb, h), NETWORK_SETTINGS);
-        registerDefaultNetworkCallback(defaultTrackingCb);
+        final TestableNetworkCallback systemDefaultCb = runWithShellPermissionIdentity(
+                () -> networkCallbackRule.registerSystemDefaultNetworkCallback(h),
+                NETWORK_SETTINGS);
+
+        final TestableNetworkCallback defaultTrackingCb =
+                networkCallbackRule.registerDefaultNetworkCallback();
 
         try {
             // CtsNetTestCases uid is not listed in MOBILE_DATA_PREFERRED_UIDS setting, so the
@@ -3339,7 +3326,7 @@
         // Create test network agent with restricted network.
         final NetworkCapabilities nc = new NetworkCapabilities.Builder()
                 .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
-                .removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED)
+                .removeCapability(NET_CAPABILITY_NOT_RESTRICTED)
                 .setNetworkSpecifier(CompatUtil.makeTestNetworkSpecifier(
                         TEST_RESTRICTED_NW_IFACE_NAME))
                 .build();
@@ -3373,23 +3360,23 @@
                 mContext, originalUidsAllowedOnRestrictedNetworks), NETWORK_SETTINGS);
 
         // File a restricted network request with permission first to hold the connection.
-        final TestableNetworkCallback testNetworkCb = new TestableNetworkCallback();
         final NetworkRequest testRequest = new NetworkRequest.Builder()
                 .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
-                .removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
-                .removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED)
+                .removeCapability(NET_CAPABILITY_TRUSTED)
+                .removeCapability(NET_CAPABILITY_NOT_RESTRICTED)
                 .setNetworkSpecifier(CompatUtil.makeTestNetworkSpecifier(
                         TEST_RESTRICTED_NW_IFACE_NAME))
                 .build();
-        runWithShellPermissionIdentity(() -> requestNetwork(testRequest, testNetworkCb),
+        final TestableNetworkCallback testNetworkCb = runWithShellPermissionIdentity(
+                () -> networkCallbackRule.requestNetwork(testRequest),
                 CONNECTIVITY_USE_RESTRICTED_NETWORKS);
 
         // File another restricted network request without permission.
         final TestableNetworkCallback restrictedNetworkCb = new TestableNetworkCallback();
         final NetworkRequest restrictedRequest = new NetworkRequest.Builder()
                 .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
-                .removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
-                .removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED)
+                .removeCapability(NET_CAPABILITY_TRUSTED)
+                .removeCapability(NET_CAPABILITY_NOT_RESTRICTED)
                 .setNetworkSpecifier(CompatUtil.makeTestNetworkSpecifier(
                         TEST_RESTRICTED_NW_IFACE_NAME))
                 .build();
@@ -3406,7 +3393,7 @@
                     NETWORK_CALLBACK_TIMEOUT_MS,
                     entry -> network.equals(entry.getNetwork())
                             && (!((CallbackEntry.CapabilitiesChanged) entry).getCaps()
-                            .hasCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED)));
+                            .hasCapability(NET_CAPABILITY_NOT_RESTRICTED)));
             // CtsNetTestCases package doesn't hold CONNECTIVITY_USE_RESTRICTED_NETWORKS, so it
             // does not allow to bind socket to restricted network.
             assertThrows(IOException.class, () -> network.bindSocket(socket));
@@ -3424,13 +3411,13 @@
 
             if (TestUtils.shouldTestTApis()) {
                 // Uid is in allowed list. Try file network request again.
-                requestNetwork(restrictedRequest, restrictedNetworkCb);
+                networkCallbackRule.requestNetwork(restrictedRequest, restrictedNetworkCb);
                 // Verify that the network is restricted.
                 restrictedNetworkCb.eventuallyExpect(CallbackEntry.NETWORK_CAPS_UPDATED,
                         NETWORK_CALLBACK_TIMEOUT_MS,
                         entry -> network.equals(entry.getNetwork())
                                 && (!((CallbackEntry.CapabilitiesChanged) entry).getCaps()
-                                .hasCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED)));
+                                .hasCapability(NET_CAPABILITY_NOT_RESTRICTED)));
             }
         } finally {
             agent.unregister();
@@ -3728,58 +3715,4 @@
         // shims, and @IgnoreUpTo does not check that.
         assumeTrue(TestUtils.shouldTestSApis());
     }
-
-    private void unregisterRegisteredCallbacks() {
-        for (NetworkCallback callback: mRegisteredCallbacks) {
-            mCm.unregisterNetworkCallback(callback);
-        }
-    }
-
-    private void registerDefaultNetworkCallback(NetworkCallback callback) {
-        mCm.registerDefaultNetworkCallback(callback);
-        mRegisteredCallbacks.add(callback);
-    }
-
-    private void registerDefaultNetworkCallback(NetworkCallback callback, Handler handler) {
-        mCm.registerDefaultNetworkCallback(callback, handler);
-        mRegisteredCallbacks.add(callback);
-    }
-
-    private void registerNetworkCallback(NetworkRequest request, NetworkCallback callback) {
-        mCm.registerNetworkCallback(request, callback);
-        mRegisteredCallbacks.add(callback);
-    }
-
-    private void registerSystemDefaultNetworkCallback(NetworkCallback callback, Handler handler) {
-        mCmShim.registerSystemDefaultNetworkCallback(callback, handler);
-        mRegisteredCallbacks.add(callback);
-    }
-
-    private void registerDefaultNetworkCallbackForUid(int uid, NetworkCallback callback,
-            Handler handler) throws Exception {
-        mCmShim.registerDefaultNetworkCallbackForUid(uid, callback, handler);
-        mRegisteredCallbacks.add(callback);
-    }
-
-    private void requestNetwork(NetworkRequest request, NetworkCallback callback) {
-        mCm.requestNetwork(request, callback);
-        mRegisteredCallbacks.add(callback);
-    }
-
-    private void requestNetwork(NetworkRequest request, NetworkCallback callback, int timeoutSec) {
-        mCm.requestNetwork(request, callback, timeoutSec);
-        mRegisteredCallbacks.add(callback);
-    }
-
-    private void registerBestMatchingNetworkCallback(NetworkRequest request,
-            NetworkCallback callback, Handler handler) {
-        mCm.registerBestMatchingNetworkCallback(request, callback, handler);
-        mRegisteredCallbacks.add(callback);
-    }
-
-    private void requestBackgroundNetwork(NetworkRequest request, NetworkCallback callback,
-            Handler handler) throws Exception {
-        mCmShim.requestBackgroundNetwork(request, callback, handler);
-        mRegisteredCallbacks.add(callback);
-    }
 }