Cleanups to VPN hostside tests.

Use TestableNetworkCallback instead of a hand-rolled class.
Remove unnecessary runWithShellPermissionIdentity around
unregisterNetworkCallback calls.

Bug: 165835257
Test: test-only change
Change-Id: I4557dfc64136f9c0b4bdaa1248c33b13e96ba3ed
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
index a47d304..c0600e7 100755
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
@@ -55,7 +55,6 @@
 import android.net.VpnManager;
 import android.net.VpnService;
 import android.net.VpnTransportInfo;
-import android.net.cts.util.CtsNetUtils.TestNetworkCallback;
 import android.net.wifi.WifiManager;
 import android.os.Handler;
 import android.os.Looper;
@@ -78,6 +77,7 @@
 
 import com.android.compatibility.common.util.BlockingBroadcastReceiver;
 import com.android.modules.utils.build.SdkLevel;
+import com.android.testutils.TestableNetworkCallback;
 
 import java.io.Closeable;
 import java.io.FileDescriptor;
@@ -94,6 +94,7 @@
 import java.net.UnknownHostException;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
+import java.util.List;
 import java.util.Objects;
 import java.util.Random;
 import java.util.concurrent.CompletableFuture;
@@ -700,34 +701,6 @@
         setAndVerifyPrivateDns(initialMode);
     }
 
-    private class NeverChangeNetworkCallback extends NetworkCallback {
-        private CountDownLatch mLatch = new CountDownLatch(1);
-        private volatile Network mFirstNetwork;
-        private volatile Network mOtherNetwork;
-
-        public void onAvailable(Network n) {
-            // Don't assert here, as it crashes the test with a hard to debug message.
-            if (mFirstNetwork == null) {
-                mFirstNetwork = n;
-                mLatch.countDown();
-            } else if (mOtherNetwork == null) {
-                mOtherNetwork = n;
-            }
-        }
-
-        public Network getFirstNetwork() throws Exception {
-            assertTrue(
-                    "System default callback got no network after " + TIMEOUT_MS + "ms. "
-                    + "Please ensure the device has a working Internet connection.",
-                    mLatch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
-            return mFirstNetwork;
-        }
-
-        public void assertNeverChanged() {
-            assertNull(mOtherNetwork);
-        }
-    }
-
     public void testDefault() throws Exception {
         if (!supportedHardware()) return;
         // If adb TCP port opened, this test may running by adb over network.
@@ -745,9 +718,9 @@
 
         // Test the behaviour of a variety of types of network callbacks.
         final Network defaultNetwork = mCM.getActiveNetwork();
-        final NeverChangeNetworkCallback systemDefaultCallback = new NeverChangeNetworkCallback();
-        final NeverChangeNetworkCallback otherUidCallback = new NeverChangeNetworkCallback();
-        final TestNetworkCallback myUidCallback = new TestNetworkCallback();
+        final TestableNetworkCallback systemDefaultCallback = new TestableNetworkCallback();
+        final TestableNetworkCallback otherUidCallback = new TestableNetworkCallback();
+        final TestableNetworkCallback myUidCallback = new TestableNetworkCallback();
         if (SdkLevel.isAtLeastS()) {
             final int otherUid = UserHandle.getUid(UserHandle.of(5), Process.FIRST_APPLICATION_UID);
             final Handler h = new Handler(Looper.getMainLooper());
@@ -756,7 +729,11 @@
                 mCM.registerDefaultNetworkCallbackAsUid(otherUid, otherUidCallback, h);
                 mCM.registerDefaultNetworkCallbackAsUid(Process.myUid(), myUidCallback, h);
             }, NETWORK_SETTINGS);
-            assertEquals(defaultNetwork, myUidCallback.waitForAvailable());
+            for (TestableNetworkCallback callback :
+                    List.of(systemDefaultCallback, otherUidCallback, myUidCallback)) {
+                callback.expectAvailableCallbacks(defaultNetwork, false /* suspended */,
+                        true /* validated */, false /* blocked */, TIMEOUT_MS);
+            }
         }
 
         FileDescriptor fd = openSocketFdInOtherApp(TEST_HOST, 80, TIMEOUT_MS);
@@ -776,7 +753,8 @@
 
         checkTrafficOnVpn();
 
-        final Network vpnNetwork = myUidCallback.waitForAvailable();
+        final Network vpnNetwork = mCM.getActiveNetwork();
+        myUidCallback.expectAvailableThenValidatedCallbacks(vpnNetwork, TIMEOUT_MS);
         assertEquals(vpnNetwork, mCM.getActiveNetwork());
         assertNotEqual(defaultNetwork, vpnNetwork);
         maybeExpectVpnTransportInfo(vpnNetwork);
@@ -788,15 +766,11 @@
             // This needs to be done before testing  private DNS because checkStrictModePrivateDns
             // will set the private DNS server to a nonexistent name, which will cause validation to
             // fail and could cause the default network to switch (e.g., from wifi to cellular).
-            assertEquals(defaultNetwork, systemDefaultCallback.getFirstNetwork());
-            systemDefaultCallback.assertNeverChanged();
-            assertEquals(defaultNetwork, otherUidCallback.getFirstNetwork());
-            otherUidCallback.assertNeverChanged();
-            runWithShellPermissionIdentity(() -> {
-                mCM.unregisterNetworkCallback(systemDefaultCallback);
-                mCM.unregisterNetworkCallback(otherUidCallback);
-                mCM.unregisterNetworkCallback(myUidCallback);
-            }, NETWORK_SETTINGS);
+            systemDefaultCallback.assertNoCallback();
+            otherUidCallback.assertNoCallback();
+            mCM.unregisterNetworkCallback(systemDefaultCallback);
+            mCM.unregisterNetworkCallback(otherUidCallback);
+            mCM.unregisterNetworkCallback(myUidCallback);
         }
 
         checkStrictModePrivateDns();