Store registered network callbacks and unregister all at tearDown

Add helper methods to register network callbacks and unregister
them all at tearDown to prevent missing unregisteration and causing
leak. This will also help to develop the follow up test cases.

Test: atest com.android.cts.net.HostsideVpnTests
Change-Id: Ifefee3f9bda63a9f45108516f86f2f5242c53bb2
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
index a5bf000..10845d7 100755
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
@@ -201,13 +201,15 @@
     private Context mTestContext;
     private Context mTargetContext;
     Network mNetwork;
-    NetworkCallback mCallback;
     final Object mLock = new Object();
     final Object mLockShutdown = new Object();
 
     private String mOldPrivateDnsMode;
     private String mOldPrivateDnsSpecifier;
 
+    // The registered callbacks.
+    private List<NetworkCallback> mRegisteredCallbacks = new ArrayList<>();
+
     @Rule
     public final DevSdkIgnoreRule mDevSdkIgnoreRule = new DevSdkIgnoreRule();
 
@@ -228,7 +230,6 @@
     @Before
     public void setUp() throws Exception {
         mNetwork = null;
-        mCallback = null;
         mTestContext = getInstrumentation().getContext();
         mTargetContext = getInstrumentation().getTargetContext();
         storePrivateDnsSetting();
@@ -248,15 +249,40 @@
     public void tearDown() throws Exception {
         restorePrivateDnsSetting();
         mRemoteSocketFactoryClient.unbind();
-        if (mCallback != null) {
-            mCM.unregisterNetworkCallback(mCallback);
-        }
         mCtsNetUtils.tearDown();
         Log.i(TAG, "Stopping VPN");
         stopVpn();
+        unregisterRegisteredCallbacks();
         mActivity.finish();
     }
 
+    private void registerNetworkCallback(NetworkRequest request, NetworkCallback callback) {
+        mCM.registerNetworkCallback(request, callback);
+        mRegisteredCallbacks.add(callback);
+    }
+
+    private void registerDefaultNetworkCallback(NetworkCallback callback) {
+        mCM.registerDefaultNetworkCallback(callback);
+        mRegisteredCallbacks.add(callback);
+    }
+
+    private void registerSystemDefaultNetworkCallback(NetworkCallback callback, Handler h) {
+        mCM.registerSystemDefaultNetworkCallback(callback, h);
+        mRegisteredCallbacks.add(callback);
+    }
+
+    private void registerDefaultNetworkCallbackForUid(int uid, NetworkCallback callback,
+            Handler h) {
+        mCM.registerDefaultNetworkCallbackForUid(uid, callback, h);
+        mRegisteredCallbacks.add(callback);
+    }
+
+    private void unregisterRegisteredCallbacks() {
+        for (NetworkCallback callback: mRegisteredCallbacks) {
+            mCM.unregisterNetworkCallback(callback);
+        }
+    }
+
     private void prepareVpn() throws Exception {
         final int REQUEST_ID = 42;
 
@@ -372,7 +398,7 @@
                 .removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
                 .removeCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
                 .build();
-        mCallback = new NetworkCallback() {
+        final NetworkCallback callback = new NetworkCallback() {
             public void onAvailable(Network network) {
                 synchronized (mLock) {
                     Log.i(TAG, "Got available callback for network=" + network);
@@ -381,7 +407,7 @@
                 }
             }
         };
-        mCM.registerNetworkCallback(request, mCallback);  // Unregistered in tearDown.
+        registerNetworkCallback(request, callback);
 
         // Start the service and wait up for TIMEOUT_MS ms for the VPN to come up.
         establishVpn(addresses, routes, excludedRoutes, allowedApplications, disallowedApplications,
@@ -406,7 +432,7 @@
                 .removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
                 .removeCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
                 .build();
-        mCallback = new NetworkCallback() {
+        final NetworkCallback callback = new NetworkCallback() {
             public void onLost(Network network) {
                 synchronized (mLockShutdown) {
                     Log.i(TAG, "Got lost callback for network=" + network
@@ -417,7 +443,7 @@
                 }
             }
        };
-        mCM.registerNetworkCallback(request, mCallback);  // Unregistered in tearDown.
+        registerNetworkCallback(request, callback);
         // Simply calling mActivity.stopService() won't stop the service, because the system binds
         // to the service for the purpose of sending it a revoke command if another VPN comes up,
         // and stopping a bound service has no effect. Instead, "start" the service again with an
@@ -778,14 +804,10 @@
             }
         };
 
-        mCM.registerNetworkCallback(request, callback);
+        registerNetworkCallback(request, callback);
 
-        try {
-            assertTrue("Private DNS hostname was not " + hostname + " after " + TIMEOUT_MS + "ms",
-                    latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
-        } finally {
-            mCM.unregisterNetworkCallback(callback);
-        }
+        assertTrue("Private DNS hostname was not " + hostname + " after " + TIMEOUT_MS + "ms",
+                latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
     }
 
     private void setAndVerifyPrivateDns(boolean strictMode) throws Exception {
@@ -872,7 +894,7 @@
                     false /* isAlwaysMetered */);
             // Acquire the NETWORK_SETTINGS permission for getting the underlying networks.
             runWithShellPermissionIdentity(() -> {
-                mCM.registerNetworkCallback(makeVpnNetworkRequest(), callback);
+                registerNetworkCallback(makeVpnNetworkRequest(), callback);
                 // Check that this VPN doesn't have any underlying networks.
                 expectUnderlyingNetworks(callback, new ArrayList<Network>());
 
@@ -905,8 +927,6 @@
                 } else {
                     mCtsNetUtils.ensureWifiDisconnected(null);
                 }
-            }, () -> {
-                mCM.unregisterNetworkCallback(callback);
             });
     }
 
@@ -940,9 +960,9 @@
                     UserHandle.of(5 /* userId */).getUid(Process.FIRST_APPLICATION_UID);
             final Handler h = new Handler(Looper.getMainLooper());
             runWithShellPermissionIdentity(() -> {
-                mCM.registerSystemDefaultNetworkCallback(systemDefaultCallback, h);
-                mCM.registerDefaultNetworkCallbackForUid(otherUid, otherUidCallback, h);
-                mCM.registerDefaultNetworkCallbackForUid(Process.myUid(), myUidCallback, h);
+                registerSystemDefaultNetworkCallback(systemDefaultCallback, h);
+                registerDefaultNetworkCallbackForUid(otherUid, otherUidCallback, h);
+                registerDefaultNetworkCallbackForUid(Process.myUid(), myUidCallback, h);
             }, NETWORK_SETTINGS);
             for (TestableNetworkCallback callback :
                     List.of(systemDefaultCallback, otherUidCallback, myUidCallback)) {
@@ -993,9 +1013,6 @@
             // fail and could cause the default network to switch (e.g., from wifi to cellular).
             systemDefaultCallback.assertNoCallback();
             otherUidCallback.assertNoCallback();
-            mCM.unregisterNetworkCallback(systemDefaultCallback);
-            mCM.unregisterNetworkCallback(otherUidCallback);
-            mCM.unregisterNetworkCallback(myUidCallback);
         }
 
         checkStrictModePrivateDns();
@@ -1623,7 +1640,7 @@
 
         testAndCleanup(() -> {
             runWithShellPermissionIdentity(() -> {
-                mCM.registerDefaultNetworkCallbackForUid(remoteUid, remoteUidCallback,
+                registerDefaultNetworkCallbackForUid(remoteUid, remoteUidCallback,
                         new Handler(Looper.getMainLooper()));
             }, NETWORK_SETTINGS);
             remoteUidCallback.expectAvailableCallbacksWithBlockedReasonNone(network);
@@ -1659,8 +1676,6 @@
 
             checkBlockIncomingPacket(tunFd, remoteUdpFd, EXPECT_BLOCK);
         }, /* cleanup */ () -> {
-                mCM.unregisterNetworkCallback(remoteUidCallback);
-            }, /* cleanup */ () -> {
                 Os.close(tunFd);
             }, /* cleanup */ () -> {
                 Os.close(remoteUdpFd);
@@ -1684,7 +1699,7 @@
         final int myUid = Process.myUid();
 
         testAndCleanup(() -> {
-            mCM.registerDefaultNetworkCallback(defaultNetworkCallback);
+            registerDefaultNetworkCallback(defaultNetworkCallback);
             defaultNetworkCallback.expectAvailableCallbacks(defaultNetwork);
 
             final Range<Integer> myUidRange = new Range<>(myUid, myUid);
@@ -1716,8 +1731,6 @@
                 defaultNetworkCallback.eventuallyExpect(CallbackEntry.AVAILABLE,
                         NETWORK_CALLBACK_TIMEOUT_MS,
                         entry -> defaultNetwork.equals(entry.getNetwork()));
-            }, /* cleanup */ () -> {
-                mCM.unregisterNetworkCallback(defaultNetworkCallback);
             });
     }