Merge "Fix the test flake on ConnectivityServiceTest"
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 7993a5c..f9f63ed 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -1068,38 +1068,41 @@
          * @param hasInternet Indicate if network should pretend to have NET_CAPABILITY_INTERNET.
          */
         public void connect(boolean validated, boolean hasInternet, boolean isStrictMode) {
-            ConnectivityManager.NetworkCallback callback = null;
             final ConditionVariable validatedCv = new ConditionVariable();
+            final ConditionVariable capsChangedCv = new ConditionVariable();
+            final NetworkRequest request = new NetworkRequest.Builder()
+                    .addTransportType(getNetworkCapabilities().getTransportTypes()[0])
+                    .clearCapabilities()
+                    .build();
             if (validated) {
                 setNetworkValid(isStrictMode);
-                NetworkRequest request = new NetworkRequest.Builder()
-                        .addTransportType(getNetworkCapabilities().getTransportTypes()[0])
-                        .clearCapabilities()
-                        .build();
-                callback = new ConnectivityManager.NetworkCallback() {
-                    public void onCapabilitiesChanged(Network network,
-                            NetworkCapabilities networkCapabilities) {
-                        if (network.equals(getNetwork()) &&
-                                networkCapabilities.hasCapability(NET_CAPABILITY_VALIDATED)) {
+            }
+            final NetworkCallback callback = new NetworkCallback() {
+                public void onCapabilitiesChanged(Network network,
+                        NetworkCapabilities networkCapabilities) {
+                    if (network.equals(getNetwork())) {
+                        capsChangedCv.open();
+                        if (networkCapabilities.hasCapability(NET_CAPABILITY_VALIDATED)) {
                             validatedCv.open();
                         }
                     }
-                };
-                mCm.registerNetworkCallback(request, callback);
-            }
+                }
+            };
+            mCm.registerNetworkCallback(request, callback);
+
             if (hasInternet) {
                 addCapability(NET_CAPABILITY_INTERNET);
             }
 
             connectWithoutInternet();
+            waitFor(capsChangedCv);
 
             if (validated) {
                 // Wait for network to validate.
                 waitFor(validatedCv);
                 setNetworkInvalid(isStrictMode);
             }
-
-            if (callback != null) mCm.unregisterNetworkCallback(callback);
+            mCm.unregisterNetworkCallback(callback);
         }
 
         public void connectWithCaptivePortal(String redirectUrl, boolean isStrictMode) {