diff --git a/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt b/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
index eed31e0..5acdb34 100644
--- a/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
+++ b/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
@@ -179,33 +179,31 @@
     }
 
     @Test
-    fun testCapabilitiesThat() {
+    fun testExpectCaps() {
         val net = Network(101)
         val netCaps = NetworkCapabilities().addCapability(NOT_METERED).addTransportType(WIFI)
         // Check that expecting capabilitiesThat anything fails when no callback has been received.
-        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { true } }
+        assertFails { mCallback.expectCaps(net, SHORT_TIMEOUT_MS) { true } }
 
         // Basic test for true and false
         mCallback.onCapabilitiesChanged(net, netCaps)
-        mCallback.expectCapabilitiesThat(net) { true }
+        mCallback.expectCaps(net) { true }
         mCallback.onCapabilitiesChanged(net, netCaps)
-        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { false } }
+        assertFails { mCallback.expectCaps(net, SHORT_TIMEOUT_MS) { false } }
 
         // Try a positive and a negative case
         mCallback.onCapabilitiesChanged(net, netCaps)
-        mCallback.expectCapabilitiesThat(net) { caps ->
-            caps.hasCapability(NOT_METERED) &&
-                    caps.hasTransport(WIFI) &&
-                    !caps.hasTransport(CELLULAR)
+        mCallback.expectCaps(net) {
+            it.hasCapability(NOT_METERED) && it.hasTransport(WIFI) && !it.hasTransport(CELLULAR)
         }
         mCallback.onCapabilitiesChanged(net, netCaps)
-        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { caps ->
-            caps.hasTransport(CELLULAR)
-        } }
+        assertFails { mCallback.expectCaps(net, SHORT_TIMEOUT_MS) { it.hasTransport(CELLULAR) } }
 
         // Try a matching callback on the wrong network
         mCallback.onCapabilitiesChanged(net, netCaps)
-        assertFails { mCallback.expectCapabilitiesThat(Network(100), SHORT_TIMEOUT_MS) { true } }
+        assertFails {
+            mCallback.expectCaps(Network(100), SHORT_TIMEOUT_MS) { true }
+        }
     }
 
     @Test
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
index 68d5fa9..d58d582 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
@@ -414,13 +414,6 @@
         crossinline predicate: (T) -> Boolean = { true }
     ) = history.poll(timeoutMs, from) { it is T && predicate(it) } as T?
 
-    inline fun expectCapabilitiesThat(
-        net: Network,
-        tmt: Long = defaultTimeoutMs,
-        valid: (NetworkCapabilities) -> Boolean
-    ): CapabilitiesChanged =
-            expect(net, tmt, "Capabilities don't match expectations") { valid(it.caps) }
-
     inline fun expectLinkPropertiesThat(
         net: Network,
         tmt: Long = defaultTimeoutMs,
@@ -472,10 +465,8 @@
         if (suspended) {
             expect<Suspended>(net, tmt)
         }
-        expectCapabilitiesThat(net, tmt) {
-            validated == null || validated == it.hasCapability(
-                NET_CAPABILITY_VALIDATED
-            )
+        expect<CapabilitiesChanged>(net, tmt) {
+            validated == null || validated == it.caps.hasCapability(NET_CAPABILITY_VALIDATED)
         }
         expect<LinkPropertiesChanged>(net, tmt)
     }
@@ -514,7 +505,7 @@
     // when a network connects and satisfies a callback, and then immediately validates.
     fun expectAvailableThenValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
         expectAvailableCallbacks(net, validated = false, tmt = tmt)
-        expectCapabilitiesThat(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
+        expectCaps(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
     }
 
     fun expectAvailableThenValidatedCallbacks(
@@ -524,7 +515,7 @@
     ) {
         expectAvailableCallbacks(net, validated = false, suspended = false,
                 blockedStatus = blockedStatus, tmt = tmt)
-        expectCapabilitiesThat(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
+        expectCaps(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
     }
 
     // Temporary Java compat measure : have MockNetworkAgent implement this so that all existing
@@ -578,29 +569,42 @@
     ) = expectLinkPropertiesThat(n.network, tmt, valid)
 
     @JvmOverloads
-    fun expectCapabilitiesThat(
+    fun expectCaps(
         n: HasNetwork,
         tmt: Long = defaultTimeoutMs,
+        valid: (NetworkCapabilities) -> Boolean = { true }
+    ) = expect<CapabilitiesChanged>(n.network, tmt) { valid(it.caps) }.caps
+
+    @JvmOverloads
+    fun expectCaps(
+        n: Network,
+        tmt: Long = defaultTimeoutMs,
         valid: (NetworkCapabilities) -> Boolean
-    ) = expectCapabilitiesThat(n.network, tmt, valid)
+    ) = expect<CapabilitiesChanged>(n, tmt) { valid(it.caps) }.caps
+
+    fun expectCaps(
+        n: HasNetwork,
+        valid: (NetworkCapabilities) -> Boolean
+    ) = expect<CapabilitiesChanged>(n.network) { valid(it.caps) }.caps
+
+    fun expectCaps(
+        tmt: Long,
+        valid: (NetworkCapabilities) -> Boolean
+    ) = expect<CapabilitiesChanged>(ANY_NETWORK, tmt) { valid(it.caps) }.caps
 
     @JvmOverloads
     fun expectCapabilitiesWith(
         capability: Int,
         n: HasNetwork,
         timeoutMs: Long = defaultTimeoutMs
-    ): NetworkCapabilities {
-        return expectCapabilitiesThat(n.network, timeoutMs) { it.hasCapability(capability) }.caps
-    }
+    ) = expectCaps(n, timeoutMs) { it.hasCapability(capability) }
 
     @JvmOverloads
     fun expectCapabilitiesWithout(
         capability: Int,
         n: HasNetwork,
         timeoutMs: Long = defaultTimeoutMs
-    ): NetworkCapabilities {
-        return expectCapabilitiesThat(n.network, timeoutMs) { !it.hasCapability(capability) }.caps
-    }
+    ) = expectCaps(n, timeoutMs) { !it.hasCapability(capability) }
 
     fun expectBlockedStatusCallback(expectBlocked: Boolean, n: HasNetwork) {
         expectBlockedStatusCallback(expectBlocked, n.network, defaultTimeoutMs)
