Merge "Add default network monitor to DevSdkIgnoreRunner" into main
diff --git a/staticlibs/tests/unit/src/com/android/testutils/DefaultNetworkRestoreMonitorTest.kt b/staticlibs/tests/unit/src/com/android/testutils/DefaultNetworkRestoreMonitorTest.kt
new file mode 100644
index 0000000..7e508fb
--- /dev/null
+++ b/staticlibs/tests/unit/src/com/android/testutils/DefaultNetworkRestoreMonitorTest.kt
@@ -0,0 +1,165 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.content.Context
+import android.content.pm.PackageManager
+import android.net.ConnectivityManager
+import android.net.ConnectivityManager.NetworkCallback
+import android.net.Network
+import android.net.NetworkCapabilities
+import android.net.NetworkCapabilities.TRANSPORT_CELLULAR
+import android.net.NetworkCapabilities.TRANSPORT_WIFI
+import org.junit.Test
+import org.junit.runner.Description
+import org.junit.runner.notification.RunListener
+import org.junit.runner.notification.RunNotifier
+import org.mockito.ArgumentCaptor
+import org.mockito.ArgumentMatchers.any
+import org.mockito.ArgumentMatchers.anyString
+import org.mockito.ArgumentMatchers.argThat
+import org.mockito.Mockito.doAnswer
+import org.mockito.Mockito.doNothing
+import org.mockito.Mockito.doReturn
+import org.mockito.Mockito.inOrder
+import org.mockito.Mockito.mock
+import org.mockito.Mockito.never
+import org.mockito.Mockito.verify
+
+/**
+ * Unit tests for [DefaultNetworkRestoreMonitor]. The [RunNotifier] is a mock, and the mocked
+ * [ConnectivityManager] answers default network callback registrations synchronously.
+ */
+class DefaultNetworkRestoreMonitorTest {
+    private val restoreDefaultNetworkDesc =
+            Description.createSuiteDescription("RestoreDefaultNetwork")
+    private val testDesc = Description.createTestDescription("testClass", "testMethod")
+    private val wifiCap = validatedCapabilities(TRANSPORT_WIFI)
+    private val cellCap = validatedCapabilities(TRANSPORT_CELLULAR)
+    private val cm = mock(ConnectivityManager::class.java)
+    private val pm = mock(PackageManager::class.java).apply {
+        // Pretend the device supports every feature, so init() asks for both Wi-Fi and cellular.
+        doReturn(true).`when`(this).hasSystemFeature(anyString())
+    }
+    private val ctx = mock(Context::class.java).apply {
+        doReturn(cm).`when`(this).getSystemService(ConnectivityManager::class.java)
+        doReturn(pm).`when`(this).getPackageManager()
+    }
+    private val notifier = mock(RunNotifier::class.java)
+    // No need to wait: the mocked ConnectivityManager delivers callbacks synchronously.
+    private val defaultNetworkMonitor = DefaultNetworkRestoreMonitor(ctx, notifier, timeoutMs = 0)
+
+    /** Builds validated capabilities with the single transport [transport]. */
+    private fun validatedCapabilities(transport: Int): NetworkCapabilities =
+            NetworkCapabilities.Builder()
+                    .addTransportType(transport)
+                    .addCapability(NetworkCapabilities.NET_CAPABILITY_VALIDATED)
+                    .build()
+
+    /** Returns the [RunListener] that the monitor registered with [notifier]. */
+    private fun getRunListener(): RunListener =
+            ArgumentCaptor.forClass(RunListener::class.java).let { captor ->
+                verify(notifier).addListener(captor.capture())
+                captor.value
+            }
+
+    /**
+     * Makes registerDefaultNetworkCallback immediately report [cap] for a fake network, or
+     * report nothing when [cap] is null (no default network).
+     */
+    private fun mockDefaultNetworkCapabilities(cap: NetworkCapabilities?) {
+        if (cap != null) {
+            doAnswer { invocation ->
+                val callback = invocation.getArgument<NetworkCallback>(0)
+                callback.onCapabilitiesChanged(Network(100), cap)
+            }.`when`(cm).registerDefaultNetworkCallback(any())
+        } else {
+            doNothing().`when`(cm).registerDefaultNetworkCallback(any())
+        }
+    }
+
+    /**
+     * Initializes the monitor while [capsBeforeTest] is the default network, finishes one fake
+     * test while [capsAfterTest] is the default network (null meaning there is none), then asks
+     * the monitor to report. Returns the listener that the monitor registered.
+     */
+    private fun runMonitoredTest(
+        capsBeforeTest: NetworkCapabilities?,
+        capsAfterTest: NetworkCapabilities?
+    ): RunListener {
+        mockDefaultNetworkCapabilities(capsBeforeTest)
+        defaultNetworkMonitor.init(mock(ConnectUtil::class.java))
+
+        mockDefaultNetworkCapabilities(capsAfterTest)
+        val listener = getRunListener()
+        listener.testFinished(testDesc)
+
+        defaultNetworkMonitor.reportResultAndCleanUp(restoreDefaultNetworkDesc)
+        return listener
+    }
+
+    /**
+     * Verifies, in order, that the result test was started, failed or not as expected, finished,
+     * and that the listener was removed. When [failureExpected] and [failureMessagePart] are both
+     * set, the failure message must contain [failureMessagePart].
+     */
+    private fun verifyResultReported(
+        listener: RunListener,
+        failureExpected: Boolean,
+        failureMessagePart: String? = null
+    ) {
+        val ordered = inOrder(notifier)
+        ordered.verify(notifier).fireTestStarted(restoreDefaultNetworkDesc)
+        when {
+            !failureExpected -> ordered.verify(notifier, never()).fireTestFailure(any())
+            failureMessagePart == null -> ordered.verify(notifier).fireTestFailure(any())
+            else -> ordered.verify(notifier).fireTestFailure(argThat { failure ->
+                failure.exception.message?.contains(failureMessagePart) ?: false
+            })
+        }
+        ordered.verify(notifier).fireTestFinished(restoreDefaultNetworkDesc)
+        ordered.verify(notifier).removeListener(listener)
+    }
+
+    @Test
+    fun testDefaultNetworkRestoreMonitor_defaultNetworkRestored() {
+        val listener = runMonitoredTest(capsBeforeTest = wifiCap, capsAfterTest = wifiCap)
+        verifyResultReported(listener, failureExpected = false)
+    }
+
+    @Test
+    fun testDefaultNetworkRestoreMonitor_testStartWithoutDefaultNetwork() {
+        // There is no default network when the tests start
+        val listener = runMonitoredTest(capsBeforeTest = null, capsAfterTest = wifiCap)
+        verifyResultReported(listener, failureExpected = true)
+    }
+
+    @Test
+    fun testDefaultNetworkRestoreMonitor_testEndWithoutDefaultNetwork() {
+        // There is no default network after the test; the failure must name the test method.
+        val listener = runMonitoredTest(capsBeforeTest = wifiCap, capsAfterTest = null)
+        verifyResultReported(listener, failureExpected = true, failureMessagePart = "testMethod")
+    }
+
+    @Test
+    fun testDefaultNetworkRestoreMonitor_testChangeDefaultNetwork() {
+        // The default network transport types change after the test; the failure must name the
+        // test method.
+        val listener = runMonitoredTest(capsBeforeTest = wifiCap, capsAfterTest = cellCap)
+        verifyResultReported(listener, failureExpected = true, failureMessagePart = "testMethod")
+    }
+}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/DefaultNetworkRestoreMonitor.kt b/staticlibs/testutils/devicetests/com/android/testutils/DefaultNetworkRestoreMonitor.kt
new file mode 100644
index 0000000..1b709b2
--- /dev/null
+++ b/staticlibs/testutils/devicetests/com/android/testutils/DefaultNetworkRestoreMonitor.kt
@@ -0,0 +1,151 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.content.Context
+import android.content.pm.PackageManager
+import android.net.ConnectivityManager
+import android.net.Network
+import android.net.NetworkCapabilities
+import com.android.internal.annotations.VisibleForTesting
+import com.android.net.module.util.BitUtils
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.TimeUnit
+import org.junit.runner.Description
+import org.junit.runner.notification.Failure
+import org.junit.runner.notification.RunListener
+import org.junit.runner.notification.RunNotifier
+
+/**
+ * Checks that tests leave the device with the default network they started with.
+ *
+ * [init] records the transport types of the current default network and registers [listener]
+ * with [notifier]. After each test, [listener] waits up to [timeoutMs] for a validated default
+ * network with the same transport types; the first test for which none shows up is blamed in
+ * [firstFailure]. [reportResultAndCleanUp] reports the outcome as an extra test result and
+ * unregisters [listener].
+ *
+ * @param ctx context used to get the [ConnectivityManager] and the [PackageManager].
+ * @param notifier notifier of the monitored test run; the outcome is reported through it.
+ * @param timeoutMs how long to wait for default network information, both in [init] and after
+ *        each test.
+ */
+@VisibleForTesting(visibility = VisibleForTesting.Visibility.PRIVATE)
+class DefaultNetworkRestoreMonitor(
+        ctx: Context,
+        private val notifier: RunNotifier,
+        private val timeoutMs: Long = 3000
+) {
+    /** First problem found: [init] failing, or the first test not restoring the network. */
+    var firstFailure: Exception? = null
+    /** Packed ([BitUtils.packBits]) transport types of the default network seen by [init]. */
+    var initialTransports = 0L
+    val cm = ctx.getSystemService(ConnectivityManager::class.java)!!
+    val pm = ctx.packageManager
+
+    // Set once reportResultAndCleanUp() starts. The notifier also forwards the events fired for
+    // the synthetic result to the listener, which must not check the network for them.
+    private var isReportingResult = false
+
+    /** Checks, after every test, that the initial default network is back. */
+    val listener = object : RunListener() {
+        override fun testFinished(desc: Description) {
+            // Only the first method that does not restore the default network should be blamed.
+            if (firstFailure != null || isReportingResult) {
+                return
+            }
+            val cb = TestableNetworkCallback()
+            cm.registerDefaultNetworkCallback(cb)
+            try {
+                cb.eventuallyExpect<RecorderCallback.CallbackEntry.CapabilitiesChanged>(
+                    timeoutMs = timeoutMs
+                ) {
+                    BitUtils.packBits(it.caps.transportTypes) == initialTransports &&
+                            it.caps.hasCapability(NetworkCapabilities.NET_CAPABILITY_VALIDATED)
+                }
+            } catch (e: AssertionError) {
+                // methodName is null for non-method descriptions; fall back to the display name
+                // so that the message still identifies the culprit.
+                val testName = desc.methodName ?: desc.displayName
+                firstFailure = IllegalStateException(
+                        "$testName does not restore the default network", e
+                )
+            } finally {
+                cm.unregisterNetworkCallback(cb)
+            }
+        }
+    }
+
+    /**
+     * Brings up Wi-Fi and cellular when supported, records the current default network and
+     * starts monitoring the tests run through [notifier].
+     *
+     * If the default network cannot be obtained within [timeoutMs], the error is stored in
+     * [firstFailure] (so no per-test check happens), but [listener] is still registered so that
+     * [reportResultAndCleanUp] reports the failure and cleans up as usual.
+     */
+    fun init(connectUtil: ConnectUtil) {
+        // Ensure Wi-Fi and cellular connection before running test to avoid starting test
+        // with unexpected default network.
+        // ConnectivityTestTargetPreparer does the same thing, but it's possible that previous tests
+        // don't enable DefaultNetworkRestoreMonitor and the default network is not restored.
+        // This can be removed if all tests enable DefaultNetworkRestoreMonitor
+        if (pm.hasSystemFeature(PackageManager.FEATURE_WIFI)) {
+            connectUtil.ensureWifiValidated()
+        }
+        if (pm.hasSystemFeature(PackageManager.FEATURE_TELEPHONY)) {
+            connectUtil.ensureCellularValidated()
+        }
+
+        val capFuture = CompletableFuture<NetworkCapabilities>()
+        val cb = object : ConnectivityManager.NetworkCallback() {
+            override fun onCapabilitiesChanged(
+                    network: Network,
+                    cap: NetworkCapabilities
+            ) {
+                capFuture.complete(cap)
+            }
+        }
+        cm.registerDefaultNetworkCallback(cb)
+        try {
+            // Use the same configurable timeout as the per-test check: a short fixed timeout can
+            // fail spuriously when the device is slow to deliver the default network callback.
+            val cap = capFuture.get(timeoutMs, TimeUnit.MILLISECONDS)
+            initialTransports = BitUtils.packBits(cap.transportTypes)
+        } catch (e: Exception) {
+            firstFailure = IllegalStateException(
+                    "Failed to get default network status before starting tests", e
+            )
+        } finally {
+            cm.unregisterNetworkCallback(cb)
+        }
+        notifier.addListener(listener)
+    }
+
+    /**
+     * Reports the outcome of the monitoring as a test described by [desc], then stops monitoring.
+     *
+     * The reported test fails with [firstFailure] if one was recorded, and passes otherwise.
+     */
+    fun reportResultAndCleanUp(desc: Description) {
+        isReportingResult = true
+        notifier.fireTestStarted(desc)
+        firstFailure?.let { notifier.fireTestFailure(Failure(desc, it)) }
+        notifier.fireTestFinished(desc)
+        notifier.removeListener(listener)
+    }
+}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/DevSdkIgnoreRunner.kt b/staticlibs/testutils/devicetests/com/android/testutils/DevSdkIgnoreRunner.kt
index 8687ac7..a014834 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/DevSdkIgnoreRunner.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/DevSdkIgnoreRunner.kt
@@ -16,6 +16,8 @@
 
 package com.android.testutils
 
+import android.content.Context
+import androidx.test.core.app.ApplicationProvider
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import com.android.net.module.util.LinkPropertiesUtils.CompareOrUpdateResult
 import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter
@@ -57,6 +59,10 @@
 class DevSdkIgnoreRunner(private val klass: Class<*>) : Runner(), Filterable, Sortable {
     private val leakMonitorDesc = Description.createTestDescription(klass, "ThreadLeakMonitor")
     private val shouldThreadLeakFailTest = klass.isAnnotationPresent(MonitorThreadLeak::class.java)
+    private val restoreDefaultNetworkDesc =
+            Description.createTestDescription(klass, "RestoreDefaultNetwork")
+    private val restoreDefaultNetwork = klass.isAnnotationPresent(RestoreDefaultNetwork::class.java)
+    val ctx = ApplicationProvider.getApplicationContext<Context>()
 
     // Inference correctly infers Runner & Filterable & Sortable for |baseRunner|, but the
     // Java bytecode doesn't have a way to express this. Give this type a name by wrapping it.
@@ -71,6 +77,10 @@
     // TODO(b/307693729): Remove this annotation and monitor thread leak by default.
     annotation class MonitorThreadLeak
 
+    // Annotation for test classes to indicate the test runner should verify the default network is
+    // restored after each test.
+    annotation class RestoreDefaultNetwork
+
     private val baseRunner: RunnerWrapper<*>? = klass.let {
         val ignoreAfter = it.getAnnotation(IgnoreAfter::class.java)
         val ignoreUpTo = it.getAnnotation(IgnoreUpTo::class.java)
@@ -125,6 +135,14 @@
             )
             return
         }
+
+        val networkRestoreMonitor = if (restoreDefaultNetwork) {
+            DefaultNetworkRestoreMonitor(ctx, notifier).apply{
+                init(ConnectUtil(ctx))
+            }
+        } else {
+            null
+        }
         val threadCountsBeforeTest = if (shouldThreadLeakFailTest) {
             // Dump threads as a baseline to monitor thread leaks.
             getAllThreadNameCounts()
@@ -137,6 +155,7 @@
         if (threadCountsBeforeTest != null) {
             checkThreadLeak(notifier, threadCountsBeforeTest)
         }
+        networkRestoreMonitor?.reportResultAndCleanUp(restoreDefaultNetworkDesc)
         // Clears up internal state of all inline mocks.
         // TODO: Call clearInlineMocks() at the end of each test.
         Mockito.framework().clearInlineMocks()
@@ -163,6 +182,9 @@
             if (shouldThreadLeakFailTest) {
                 it.addChild(leakMonitorDesc)
             }
+            if (restoreDefaultNetwork) {
+                it.addChild(restoreDefaultNetworkDesc)
+            }
         }
     }
 
@@ -173,7 +195,14 @@
         // When ignoring the tests, a skipped placeholder test is reported, so test count is 1.
         if (baseRunner == null) return 1
 
-        return baseRunner.testCount() + if (shouldThreadLeakFailTest) 1 else 0
+        var testCount = baseRunner.testCount()
+        if (shouldThreadLeakFailTest) {
+            testCount += 1
+        }
+        if (restoreDefaultNetwork) {
+            testCount += 1
+        }
+        return testCount
     }
 
     @Throws(NoTestsRemainException::class)
diff --git a/tests/cts/net/src/android/net/cts/MultinetworkApiTest.java b/tests/cts/net/src/android/net/cts/MultinetworkApiTest.java
index 06a827b..2c7d5c6 100644
--- a/tests/cts/net/src/android/net/cts/MultinetworkApiTest.java
+++ b/tests/cts/net/src/android/net/cts/MultinetworkApiTest.java
@@ -40,10 +40,10 @@
 import android.system.OsConstants;
 import android.util.ArraySet;
 
-import androidx.test.ext.junit.runners.AndroidJUnit4;
 import androidx.test.platform.app.InstrumentationRegistry;
 
 import com.android.testutils.AutoReleaseNetworkCallbackRule;
+import com.android.testutils.DevSdkIgnoreRunner;
 import com.android.testutils.DeviceConfigRule;
 
 import org.junit.Before;
@@ -53,7 +53,8 @@
 
 import java.util.Set;
 
-@RunWith(AndroidJUnit4.class)
+@DevSdkIgnoreRunner.RestoreDefaultNetwork
+@RunWith(DevSdkIgnoreRunner.class)
 public class MultinetworkApiTest {
     @Rule(order = 1)
     public final DeviceConfigRule mDeviceConfigRule = new DeviceConfigRule();
diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
index beb9274..60081d4 100644
--- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
+++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
@@ -194,15 +194,16 @@
 // TODO : enable this in a Mainline update or in V.
 private const val SHOULD_CREATE_NETWORKS_IMMEDIATELY = false
 
-@RunWith(DevSdkIgnoreRunner::class)
-// NetworkAgent is not updatable in R-, so this test does not need to be compatible with older
-// versions. NetworkAgent was also based on AsyncChannel before S so cannot be tested the same way.
-@IgnoreUpTo(Build.VERSION_CODES.R)
+@AppModeFull(reason = "Instant apps can't use NetworkAgent because it needs NETWORK_FACTORY'.")
 // NetworkAgent is updated as part of the connectivity module, and running NetworkAgent tests in MTS
 // for modules other than Connectivity does not provide much value. Only run them in connectivity
 // module MTS, so the tests only need to cover the case of an updated NetworkAgent.
 @ConnectivityModuleTest
-@AppModeFull(reason = "Instant apps can't use NetworkAgent because it needs NETWORK_FACTORY'.")
+@DevSdkIgnoreRunner.RestoreDefaultNetwork
+// NetworkAgent is not updatable in R-, so this test does not need to be compatible with older
+// versions. NetworkAgent was also based on AsyncChannel before S so cannot be tested the same way.
+@IgnoreUpTo(Build.VERSION_CODES.R)
+@RunWith(DevSdkIgnoreRunner::class)
 class NetworkAgentTest {
     private val LOCAL_IPV4_ADDRESS = InetAddresses.parseNumericAddress("192.0.2.1")
     private val REMOTE_IPV4_ADDRESS = InetAddresses.parseNumericAddress("192.0.2.2")