[CC03] Replace expectCallback* with expect

"callback" is redundant – it's a TestCallback object and
it takes a type of callback anyway.

"assertNextIs" is shorter than expectCallback, and more
directly explicit than "expect", but it's also 3 words
and very different from previous usage.
Using expect() is a little bit less in-your-face obvious,
but it's simple, familiar, and short.

Test: CtsNetTestCases
      FrameworksNetIntegrationTests
      FrameworksNetTests
      NetworkStackTests
Bug: 157405399
Change-Id: I6c8c85a8be3895dd8f0ef681faa4a8b2b4f2f493
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
index 0c16937..a40e57d 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
@@ -36,7 +36,6 @@
 import kotlin.reflect.KClass
 import kotlin.test.assertEquals
 import kotlin.test.assertNotNull
-import kotlin.test.assertTrue
 import kotlin.test.fail
 
 object NULL_NETWORK : Network(-1)
@@ -212,6 +211,162 @@
         errorMsg: String = "Did not receive callback after $timeoutMs"
     ): CallbackEntry = poll(timeoutMs) ?: fail(errorMsg)
 
+    /*****
+     * AssertNextIs family of methods.
+     * These methods fetch the next callback and assert it matches the conditions : type,
+     * passed predicate. If no callback is received within the timeout, these methods fail.
+     */
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        network: Network = ANY_NETWORK,
+        timeoutMs: Long = defaultTimeoutMs,
+        errorMsg: String? = null,
+        test: (T) -> Boolean = { true }
+    ) = expect<CallbackEntry>(network, timeoutMs, errorMsg) {
+        test(it as? T ?: fail("Expected callback ${type.simpleName}, got $it"))
+    } as T
+
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        network: HasNetwork,
+        timeoutMs: Long = defaultTimeoutMs,
+        errorMsg: String? = null,
+        test: (T) -> Boolean = { true }
+    ) = expect(type, network.network, timeoutMs, errorMsg, test)
+
+    // Java needs an explicit overload to let it omit arguments in the middle, so define these
+    // here. Note that @JvmOverloads give us the versions without the last arguments too, so
+    // there is no need to explicitly define versions without the test predicate.
+    // Without |network|
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        timeoutMs: Long,
+        errorMsg: String?,
+        test: (T) -> Boolean = { true }
+    ) = expect(type, ANY_NETWORK, timeoutMs, errorMsg, test)
+
+    // Without |timeout|, in Network and HasNetwork versions
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        network: Network,
+        errorMsg: String?,
+        test: (T) -> Boolean = { true }
+    ) = expect(type, network, defaultTimeoutMs, errorMsg, test)
+
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        network: HasNetwork,
+        errorMsg: String?,
+        test: (T) -> Boolean = { true }
+    ) = expect(type, network.network, defaultTimeoutMs, errorMsg, test)
+
+    // Without |errorMsg|, in Network and HasNetwork versions
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        network: Network,
+        timeoutMs: Long,
+        test: (T) -> Boolean = { true }
+    ) = expect(type, network, timeoutMs, null, test)
+
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        network: HasNetwork,
+        timeoutMs: Long,
+        test: (T) -> Boolean = { true }
+    ) = expect(type, network.network, timeoutMs, null, test)
+
+    // Without |network| or |timeout|
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        errorMsg: String?,
+        test: (T) -> Boolean = { true }
+    ) = expect(type, ANY_NETWORK, defaultTimeoutMs, errorMsg, test)
+
+    // Without |network| or |errorMsg|
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        timeoutMs: Long,
+        test: (T) -> Boolean = { true }
+    ) = expect(type, ANY_NETWORK, timeoutMs, null, test)
+
+    // Without |timeout| or |errorMsg|, in Network and HasNetwork versions
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        network: Network,
+        test: (T) -> Boolean = { true }
+    ) = expect(type, network, defaultTimeoutMs, null, test)
+
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        network: HasNetwork,
+        test: (T) -> Boolean = { true }
+    ) = expect(type, network.network, defaultTimeoutMs, null, test)
+
+    // Without |network| or |timeout| or |errorMsg|
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        test: (T) -> Boolean = { true }
+    ) = expect(type, ANY_NETWORK, defaultTimeoutMs, null, test)
+
+    // Kotlin reified versions. Don't call methods above, or the predicate would need to be noinline
+    inline fun <reified T : CallbackEntry> expect(
+        network: Network = ANY_NETWORK,
+        timeoutMs: Long = defaultTimeoutMs,
+        errorMsg: String? = null,
+        test: (T) -> Boolean = { true }
+    ) = pollOrThrow(timeoutMs).also {
+        if (it !is T) fail("Expected callback ${T::class.simpleName}, got $it")
+        if (ANY_NETWORK !== network && it.network != network) {
+            fail("Expected network $network for callback : $it")
+        }
+        if (!test(it)) {
+            fail("${errorMsg ?: "Callback doesn't match predicate"} : $it")
+        }
+    } as T
+
+    inline fun <reified T : CallbackEntry> expect(
+        network: HasNetwork,
+        timeoutMs: Long = defaultTimeoutMs,
+        errorMsg: String? = null,
+        test: (T) -> Boolean = { true }
+    ) = expect(network.network, timeoutMs, errorMsg, test)
+
+    // TODO : remove all expectCallback and expectCallbackThat methods after all callers have been
+    // migrated to expect().
+    inline fun <reified T : CallbackEntry> expectCallback(
+        network: Network = ANY_NETWORK,
+        timeoutMs: Long = defaultTimeoutMs
+    ): T = expect(network, timeoutMs)
+
+    @JvmOverloads
+    open fun <T : CallbackEntry> expectCallback(
+        type: KClass<T>,
+        n: Network?,
+        timeoutMs: Long = defaultTimeoutMs
+    ) = expect(type, n ?: ANY_NETWORK, timeoutMs)
+
+    @JvmOverloads
+    open fun <T : CallbackEntry> expectCallback(
+        type: KClass<T>,
+        n: HasNetwork?,
+        timeoutMs: Long = defaultTimeoutMs
+    ) = expect(type, n?.network ?: ANY_NETWORK, timeoutMs)
+
+    fun expectCallbackThat(
+        timeoutMs: Long = defaultTimeoutMs,
+        valid: (CallbackEntry) -> Boolean
+    ) = expect(timeoutMs = timeoutMs, test = valid)
+
     // Make open for use in ConnectivityServiceTest which is the only one knowing its handlers.
     // TODO : remove the necessity to overload this, remove the open qualifier, and give a
     // default argument to assertNoCallback instead, possibly with @JvmOverloads if necessary.
@@ -231,19 +386,6 @@
         }
     }
 
-    // Expects a callback of the specified type on the specified network within the timeout.
-    // If no callback arrives, or a different callback arrives, fail. Returns the callback.
-    inline fun <reified T : CallbackEntry> expectCallback(
-        network: Network = ANY_NETWORK,
-        timeoutMs: Long = defaultTimeoutMs
-    ): T = pollOrThrow(timeoutMs).let {
-        if (it !is T || (ANY_NETWORK !== network && it.network != network)) {
-            fail("Unexpected callback : $it, expected ${T::class} with Network[$network]")
-        } else {
-            it
-        }
-    }
-
     // Expects a callback of the specified type matching the predicate within the timeout.
     // Any callback that doesn't match the predicate will be skipped. Fails only if
     // no matching callback is received within the timeout.
@@ -270,30 +412,19 @@
         crossinline predicate: (T) -> Boolean = { true }
     ) = history.poll(timeoutMs, from) { it is T && predicate(it) } as T?
 
-    fun expectCallbackThat(
-        timeoutMs: Long = defaultTimeoutMs,
-        valid: (CallbackEntry) -> Boolean
-    ) = pollOrThrow(timeoutMs).also { assertTrue(valid(it), "Unexpected callback : $it") }
-
-    fun expectCapabilitiesThat(
+    inline fun expectCapabilitiesThat(
         net: Network,
         tmt: Long = defaultTimeoutMs,
         valid: (NetworkCapabilities) -> Boolean
-    ): CapabilitiesChanged {
-        return expectCallback<CapabilitiesChanged>(net, tmt).also {
-            assertTrue(valid(it.caps), "Capabilities don't match expectations ${it.caps}")
-        }
-    }
+    ): CapabilitiesChanged =
+            expect(net, tmt, "Capabilities don't match expectations") { valid(it.caps) }
 
-    fun expectLinkPropertiesThat(
+    inline fun expectLinkPropertiesThat(
         net: Network,
         tmt: Long = defaultTimeoutMs,
         valid: (LinkProperties) -> Boolean
-    ): LinkPropertiesChanged {
-        return expectCallback<LinkPropertiesChanged>(net, tmt).also {
-            assertTrue(valid(it.lp), "LinkProperties don't match expectations ${it.lp}")
-        }
-    }
+    ): LinkPropertiesChanged =
+            expect(net, tmt, "LinkProperties don't match expectations") { valid(it.lp) }
 
     // Expects onAvailable and the callbacks that follow it. These are:
     // - onSuspended, iff the network was suspended when the callbacks fire.
@@ -334,16 +465,16 @@
         validated: Boolean?,
         tmt: Long
     ) {
-        expectCallback<Available>(net, tmt)
+        expect<Available>(net, tmt)
         if (suspended) {
-            expectCallback<Suspended>(net, tmt)
+            expect<Suspended>(net, tmt)
         }
         expectCapabilitiesThat(net, tmt) {
             validated == null || validated == it.hasCapability(
                 NET_CAPABILITY_VALIDATED
             )
         }
-        expectCallback<LinkPropertiesChanged>(net, tmt)
+        expect<LinkPropertiesChanged>(net, tmt)
     }
 
     // Backward compatibility for existing Java code. Use named arguments instead and remove all
@@ -354,17 +485,15 @@
         tmt: Long = defaultTimeoutMs
     ) = expectAvailableCallbacks(net, suspended = true, validated = validated, tmt = tmt)
 
-    fun expectBlockedStatusCallback(blocked: Boolean, net: Network, tmt: Long = defaultTimeoutMs) {
-        expectCallback<BlockedStatus>(net, tmt).also {
-            assertEquals(blocked, it.blocked, "Unexpected blocked status ${it.blocked}")
-        }
-    }
+    fun expectBlockedStatusCallback(blocked: Boolean, net: Network, tmt: Long = defaultTimeoutMs) =
+            expect<BlockedStatus>(net, tmt, "Unexpected blocked status") {
+                it.blocked == blocked
+            }
 
-    fun expectBlockedStatusCallback(blocked: Int, net: Network, tmt: Long = defaultTimeoutMs) {
-        expectCallback<BlockedStatusInt>(net, tmt).also {
-            assertEquals(blocked, it.blocked, "Unexpected blocked status ${it.blocked}")
-        }
-    }
+    fun expectBlockedStatusCallback(blocked: Int, net: Network, tmt: Long = defaultTimeoutMs) =
+            expect<BlockedStatusInt>(net, tmt, "Unexpected blocked status") {
+                it.blocked == blocked
+            }
 
     // Expects the available callbacks (where the onCapabilitiesChanged must contain the
     // VALIDATED capability), plus another onCapabilitiesChanged which is identical to the
@@ -374,7 +503,7 @@
         val mark = history.mark
         expectAvailableCallbacks(net, tmt = tmt)
         val firstCaps = history.poll(tmt, mark) { it is CapabilitiesChanged }
-        assertEquals(firstCaps, expectCallback<CapabilitiesChanged>(net, tmt))
+        assertEquals(firstCaps, expect<CapabilitiesChanged>(net, tmt))
     }
 
     // Expects the available callbacks where the onCapabilitiesChanged must not have validated,
@@ -402,26 +531,6 @@
         val network: Network
     }
 
-    @JvmOverloads
-    open fun <T : CallbackEntry> expectCallback(
-        type: KClass<T>,
-        n: Network?,
-        timeoutMs: Long = defaultTimeoutMs
-    ) = pollOrThrow(timeoutMs).also {
-        val network = n ?: NULL_NETWORK
-        // TODO : remove this .java access if the tests ever use kotlin-reflect. At the time of
-        // this writing this would be the only use of this library in the tests.
-        assertTrue(type.java.isInstance(it) && (ANY_NETWORK === n || it.network == network),
-                "Unexpected callback : $it, expected ${type.java} with Network[$network]")
-    } as T
-
-    @JvmOverloads
-    open fun <T : CallbackEntry> expectCallback(
-        type: KClass<T>,
-        n: HasNetwork?,
-        timeoutMs: Long = defaultTimeoutMs
-    ) = expectCallback(type, n?.network, timeoutMs)
-
     fun expectAvailableCallbacks(
         n: HasNetwork,
         suspended: Boolean,