diff --git a/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt b/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
index f8f2da0..eed31e0 100644
--- a/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
+++ b/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
@@ -22,29 +22,29 @@
 import android.net.Network
 import android.net.NetworkCapabilities
 import com.android.testutils.RecorderCallback.CallbackEntry
+import com.android.testutils.RecorderCallback.CallbackEntry.Available
+import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
+import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
 import com.android.testutils.RecorderCallback.CallbackEntry.Companion.AVAILABLE
 import com.android.testutils.RecorderCallback.CallbackEntry.Companion.BLOCKED_STATUS
 import com.android.testutils.RecorderCallback.CallbackEntry.Companion.LINK_PROPERTIES_CHANGED
 import com.android.testutils.RecorderCallback.CallbackEntry.Companion.LOSING
-import com.android.testutils.RecorderCallback.CallbackEntry.Companion.NETWORK_CAPS_UPDATED
 import com.android.testutils.RecorderCallback.CallbackEntry.Companion.LOST
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.NETWORK_CAPS_UPDATED
 import com.android.testutils.RecorderCallback.CallbackEntry.Companion.RESUMED
 import com.android.testutils.RecorderCallback.CallbackEntry.Companion.SUSPENDED
 import com.android.testutils.RecorderCallback.CallbackEntry.Companion.UNAVAILABLE
-import com.android.testutils.RecorderCallback.CallbackEntry.Available
-import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
-import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
-import org.junit.Before
-import org.junit.Test
-import org.junit.runner.RunWith
-import org.junit.runners.JUnit4
-import org.junit.Assume.assumeTrue
 import kotlin.reflect.KClass
 import kotlin.test.assertEquals
 import kotlin.test.assertFails
 import kotlin.test.assertNull
 import kotlin.test.assertTrue
 import kotlin.test.fail
+import org.junit.Assume.assumeTrue
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
 
 const val SHORT_TIMEOUT_MS = 20L
 const val DEFAULT_LINGER_DELAY_MS = 30000
@@ -121,20 +121,16 @@
         mCallback.assertNoCallback(SHORT_TIMEOUT_MS)
         mCallback.onAvailable(Network(100))
         assertFails { mCallback.assertNoCallback(SHORT_TIMEOUT_MS) }
-    }
-
-    @Test
-    fun testAssertNoCallbackThat() {
         val net = Network(101)
-        mCallback.assertNoCallbackThat { it is Available }
+        mCallback.assertNoCallback { it is Available }
         mCallback.onAvailable(net)
         // Expect no blocked status change. Receive other callback does not fail the test.
-        mCallback.assertNoCallbackThat { it is BlockedStatus }
+        mCallback.assertNoCallback { it is BlockedStatus }
         mCallback.onBlockedStatusChanged(net, true)
-        assertFails { mCallback.assertNoCallbackThat { it is BlockedStatus } }
+        assertFails { mCallback.assertNoCallback { it is BlockedStatus } }
         mCallback.onBlockedStatusChanged(net, false)
         mCallback.onCapabilitiesChanged(net, NetworkCapabilities())
-        assertFails { mCallback.assertNoCallbackThat { it is CapabilitiesChanged } }
+        assertFails { mCallback.assertNoCallback { it is CapabilitiesChanged } }
     }
 
     @Test
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
index 124d134..68d5fa9 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
@@ -167,20 +167,43 @@
 
 private const val DEFAULT_TIMEOUT = 30_000L // ms
 private const val DEFAULT_NO_CALLBACK_TIMEOUT = 200L // ms
+private val NOOP = Runnable {}
 
+/**
+ * See comments on the public constructor below for a description of the arguments.
+ */
 open class TestableNetworkCallback private constructor(
     src: TestableNetworkCallback?,
     val defaultTimeoutMs: Long = DEFAULT_TIMEOUT,
-    val defaultNoCallbackTimeoutMs: Long = DEFAULT_NO_CALLBACK_TIMEOUT
+    val defaultNoCallbackTimeoutMs: Long = DEFAULT_NO_CALLBACK_TIMEOUT,
+    val waiterFunc: Runnable = NOOP // "() -> Unit" would forbid calling with a void func from Java
 ) : RecorderCallback(src) {
+    /**
+     * Construct a testable network callback.
+     * @param timeoutMs the default timeout for expecting a callback. Default 30 seconds. This
+     *                  should be long in most cases, because the success case doesn't incur
+     *                  the wait.
+     * @param noCallbackTimeoutMs the timeout for expecting that no callback is received. Default
+     *                            200ms. Because the success case does incur the timeout, this
+     *                            should be short in most cases, but not so short as to frequently
+     *                            time out before an incorrect callback is received.
+     * @param waiterFunc a function to use before asserting no callback. For some specific tests,
+     *                   it is useful to run test-specific code before asserting no callback to
+     *                   increase the likelihood that a spurious callback is correctly detected.
+     *                   As an example, a unit test using mock loopers may want to use this to
+     *                   make sure the loopers are drained before asserting no callback, since
+     *                   one of them may cause a callback to be called. @see ConnectivityServiceTest
+     *                   for such an example.
+     */
     @JvmOverloads
     constructor(
         timeoutMs: Long = DEFAULT_TIMEOUT,
-        noCallbackTimeoutMs: Long = DEFAULT_NO_CALLBACK_TIMEOUT
-    ) : this(null, timeoutMs, noCallbackTimeoutMs)
+        noCallbackTimeoutMs: Long = DEFAULT_NO_CALLBACK_TIMEOUT,
+        waiterFunc: Runnable = NOOP
+    ) : this(null, timeoutMs, noCallbackTimeoutMs, waiterFunc)
 
     fun createLinkedCopy() = TestableNetworkCallback(
-            this, defaultTimeoutMs, defaultNoCallbackTimeoutMs)
+            this, defaultTimeoutMs, defaultNoCallbackTimeoutMs, waiterFunc)
 
     // The last available network, or null if any network was lost since the last call to
     // onAvailable. TODO : fix this by fixing the tests that rely on this behavior
@@ -343,25 +366,18 @@
         test: (T) -> Boolean = { true }
     ) = expect(network.network, timeoutMs, errorMsg, test)
 
-    // Make open for use in ConnectivityServiceTest which is the only one knowing its handlers.
-    // TODO : remove the necessity to overload this, remove the open qualifier, and give a
-    // default argument to assertNoCallback instead, possibly with @JvmOverloads if necessary.
-    open fun assertNoCallback() = assertNoCallback(defaultNoCallbackTimeoutMs)
-
-    fun assertNoCallback(timeoutMs: Long) {
-        val cb = history.poll(timeoutMs)
-        if (null != cb) fail("Expected no callback but got $cb")
-    }
-
-    fun assertNoCallbackThat(
+    @JvmOverloads
+    fun assertNoCallback(
         timeoutMs: Long = defaultNoCallbackTimeoutMs,
-        valid: (CallbackEntry) -> Boolean
+        valid: (CallbackEntry) -> Boolean = { true }
     ) {
-        val cb = history.poll(timeoutMs) { valid(it) }.let {
-            if (null != it) fail("Expected no callback but got $it")
-        }
+        waiterFunc.run()
+        history.poll(timeoutMs) { valid(it) }?.let { fail("Expected no callback but got $it") }
     }
 
+    fun assertNoCallback(valid: (CallbackEntry) -> Boolean) =
+            assertNoCallback(defaultNoCallbackTimeoutMs, valid)
+
     // Expects a callback of the specified type matching the predicate within the timeout.
     // Any callback that doesn't match the predicate will be skipped. Fails only if
     // no matching callback is received within the timeout.
