Use a CountDownLatch instead of sleep() in NetworkFactory tests.

This makes testNetworkFactoryRequests 2-3 times faster.

Bug: 22606153
Change-Id: I9657b6929e77f23ec811d0ab57b2ba974f0b6a69
diff --git a/services/tests/servicestests/src/com/android/server/ConnectivityServiceTest.java b/services/tests/servicestests/src/com/android/server/ConnectivityServiceTest.java
index 6281ad2..3160f66 100644
--- a/services/tests/servicestests/src/com/android/server/ConnectivityServiceTest.java
+++ b/services/tests/servicestests/src/com/android/server/ConnectivityServiceTest.java
@@ -61,7 +61,8 @@
 import com.android.server.connectivity.NetworkMonitor;
 
 import java.net.InetAddress;
-import java.util.concurrent.Future;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 
 /**
@@ -73,6 +74,8 @@
 public class ConnectivityServiceTest extends AndroidTestCase {
     private static final String TAG = "ConnectivityServiceTest";
 
+    private static final int TIMEOUT_MS = 500;
+
     private BroadcastInterceptingContext mServiceContext;
     private WrappedConnectivityService mService;
     private ConnectivityManager mCm;
@@ -311,13 +314,28 @@
         }
     }
 
+    /**
+     * A NetworkFactory that allows tests to wait until any in-flight NetworkRequest add or remove
+     * operations have been processed. Before ConnectivityService can add or remove any requests,
+     * the factory must be told to expect those operations by calling expectAddRequests or
+     * expectRemoveRequests.
+     */
     private static class MockNetworkFactory extends NetworkFactory {
         private final ConditionVariable mNetworkStartedCV = new ConditionVariable();
         private final ConditionVariable mNetworkStoppedCV = new ConditionVariable();
-        private final ConditionVariable mNetworkRequestedCV = new ConditionVariable();
-        private final ConditionVariable mNetworkReleasedCV = new ConditionVariable();
         private final AtomicBoolean mNetworkStarted = new AtomicBoolean(false);
 
+        // Used to expect that requests be removed or added on a separate thread, without sleeping.
+        // Callers can call either expectAddRequests() or expectRemoveRequests() exactly once, then
+        // cause some other thread to add or remove requests, then call waitForRequests(). We can
+        // either expect requests to be added or removed, but not both, because CountDownLatch can
+        // only count in one direction.
+        private CountDownLatch mExpectations;
+
+        // Whether we are currently expecting requests to be added or removed. Valid only if
+        // mExpectations is non-null.
+        private boolean mExpectingAdditions;
+
         public MockNetworkFactory(Looper looper, Context context, String logTag,
                 NetworkCapabilities filter) {
             super(looper, context, logTag, filter);
@@ -351,28 +369,75 @@
             return mNetworkStoppedCV;
         }
 
-        protected void needNetworkFor(NetworkRequest networkRequest, int score) {
-            super.needNetworkFor(networkRequest, score);
-            mNetworkRequestedCV.open();
+        @Override
+        protected void handleAddRequest(NetworkRequest request, int score) {
+            // If we're expecting anything, we must be expecting additions.
+            if (mExpectations != null && !mExpectingAdditions) {
+                fail("Can't add requests while expecting requests to be removed");
+            }
+
+            // Add the request.
+            super.handleAddRequest(request, score);
+
+            // Reduce the number of request additions we're waiting for.
+            if (mExpectingAdditions) {
+                assertTrue("Added more requests than expected", mExpectations.getCount() > 0);
+                mExpectations.countDown();
+            }
         }
 
-        protected void releaseNetworkFor(NetworkRequest networkRequest) {
-            super.releaseNetworkFor(networkRequest);
-            mNetworkReleasedCV.open();
+        @Override
+        protected void handleRemoveRequest(NetworkRequest request) {
+            // If we're expecting anything, we must be expecting removals.
+            if (mExpectations != null && mExpectingAdditions) {
+                fail("Can't remove requests while expecting requests to be added");
+            }
+
+            // Remove the request.
+            super.handleRemoveRequest(request);
+
+            // Reduce the number of request removals we're waiting for.
+            if (!mExpectingAdditions) {
+                assertTrue("Removed more requests than expected", mExpectations.getCount() > 0);
+                mExpectations.countDown();
+            }
         }
 
-        public ConditionVariable getNetworkRequestedCV() {
-            mNetworkRequestedCV.close();
-            return mNetworkRequestedCV;
+        private void assertNoExpectations() {
+            if (mExpectations != null) {
+                fail("Can't add expectation, " + mExpectations.getCount() + " already pending");
+            }
         }
 
-        public ConditionVariable getNetworkReleasedCV() {
-            mNetworkReleasedCV.close();
-            return mNetworkReleasedCV;
+        // Expects that count requests will be added.
+        public void expectAddRequests(final int count) {
+            assertNoExpectations();
+            mExpectingAdditions = true;
+            mExpectations = new CountDownLatch(count);
         }
 
-        public void waitForNetworkRequests(final int count) {
-            waitFor(new Criteria() { public boolean get() { return count == getRequestCount(); } });
+        // Expects that count requests will be removed.
+        public void expectRemoveRequests(final int count) {
+            assertNoExpectations();
+            mExpectingAdditions = false;
+            mExpectations = new CountDownLatch(count);
+        }
+
+        // Waits for the expected request additions or removals to happen within a timeout.
+        public void waitForRequests() throws InterruptedException {
+            assertNotNull("Nothing to wait for", mExpectations);
+            mExpectations.await(TIMEOUT_MS, TimeUnit.MILLISECONDS);
+            final long count = mExpectations.getCount();
+            final String msg = count + " requests still not " +
+                    (mExpectingAdditions ? "added" : "removed") +
+                    " after " + TIMEOUT_MS + " ms";
+            assertEquals(msg, 0, count);
+            mExpectations = null;
+        }
+
+        public void waitForNetworkRequests(final int count) throws InterruptedException {
+            waitForRequests();
+            assertEquals(count, getMyRequestCount());
         }
     }
 
@@ -450,7 +515,7 @@
         }
 
         public void waitForIdle() {
-            waitForIdle(500);
+            waitForIdle(TIMEOUT_MS);
         }
 
     }
@@ -475,11 +540,11 @@
     }
 
     /**
-     * Wait up to 500ms for {@code conditionVariable} to open.
-     * Fails if 500ms goes by before {@code conditionVariable} opens.
+     * Wait up to TIMEOUT_MS for {@code conditionVariable} to open.
+     * Fails if TIMEOUT_MS goes by before {@code conditionVariable} opens.
      */
     static private void waitFor(ConditionVariable conditionVariable) {
-        assertTrue(conditionVariable.block(500));
+        assertTrue(conditionVariable.block(TIMEOUT_MS));
     }
 
     @Override
@@ -963,18 +1028,21 @@
                 mServiceContext, "testFactory", filter);
         testFactory.setScoreFilter(40);
         ConditionVariable cv = testFactory.getNetworkStartedCV();
+        testFactory.expectAddRequests(1);
         testFactory.register();
+        testFactory.waitForNetworkRequests(1);
         int expectedRequestCount = 1;
         NetworkCallback networkCallback = null;
         // For non-INTERNET capabilities we cannot rely on the default request being present, so
         // add one.
         if (capability != NET_CAPABILITY_INTERNET) {
-            testFactory.waitForNetworkRequests(1);
             assertFalse(testFactory.getMyStartRequested());
             NetworkRequest request = new NetworkRequest.Builder().addCapability(capability).build();
             networkCallback = new NetworkCallback();
+            testFactory.expectAddRequests(1);
             mCm.requestNetwork(request, networkCallback);
             expectedRequestCount++;
+            testFactory.waitForNetworkRequests(expectedRequestCount);
         }
         waitFor(cv);
         assertEquals(expectedRequestCount, testFactory.getMyRequestCount());
@@ -987,13 +1055,20 @@
         // unvalidated penalty.
         testAgent.adjustScore(40);
         cv = testFactory.getNetworkStoppedCV();
+
+        // When testAgent connects, ConnectivityService will re-send us all current requests with
+        // the new score. There are expectedRequestCount such requests, and we must wait for all of
+        // them.
+        testFactory.expectAddRequests(expectedRequestCount);
         testAgent.connect(false);
         testAgent.addCapability(capability);
         waitFor(cv);
-        assertEquals(expectedRequestCount, testFactory.getMyRequestCount());
+        testFactory.waitForNetworkRequests(expectedRequestCount);
         assertFalse(testFactory.getMyStartRequested());
 
         // Bring in a bunch of requests.
+        testFactory.expectAddRequests(10);
+        assertEquals(expectedRequestCount, testFactory.getMyRequestCount());
         ConnectivityManager.NetworkCallback[] networkCallbacks =
                 new ConnectivityManager.NetworkCallback[10];
         for (int i = 0; i< networkCallbacks.length; i++) {
@@ -1006,6 +1081,7 @@
         assertFalse(testFactory.getMyStartRequested());
 
         // Remove the requests.
+        testFactory.expectRemoveRequests(10);
         for (int i = 0; i < networkCallbacks.length; i++) {
             mCm.unregisterNetworkCallback(networkCallbacks[i]);
         }