Merge "Correctly test for an object being of the correct type"
diff --git a/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt b/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
index 4ed881a..e838bdc 100644
--- a/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
+++ b/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
@@ -354,6 +354,13 @@
     }
 
     @Test
+    fun testExpectClass() { // expect() must fail when the received callback is not the expected one
+        val net = Network(1)
+        mCallback.onAvailable(net)
+        assertFails { mCallback.expect(LOST, net) } // onAvailable was sent above, not LOST
+    }
+
+    @Test
     fun testPoll() {
         assertNull(mCallback.poll(SHORT_TIMEOUT_MS))
         TNCInterpreter.interpretTestSpec(initial = mCallback, lineShift = 1,
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
index 0e73112..4b6aea2 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
@@ -236,7 +236,11 @@
         errorMsg: String? = null,
         test: (T) -> Boolean = { true }
     ) = expect<CallbackEntry>(network, timeoutMs, errorMsg) {
-        test(it as? T ?: fail("Expected callback ${type.simpleName}, got $it"))
+        if (type.isInstance(it)) {
+            test(it as T) // Cast can't fail since type.isInstance(it) and type: KClass<T>
+        } else {
+            fail("Expected callback ${type.simpleName}, got $it")
+        }
     } as T
 
     @JvmOverloads