Merge "Consolidate DNS label comparison functions in DnsUtils" into main
diff --git a/PREUPLOAD.cfg b/PREUPLOAD.cfg
index 83619d6..39009cb 100644
--- a/PREUPLOAD.cfg
+++ b/PREUPLOAD.cfg
@@ -1,6 +1,15 @@
+[Builtin Hooks]
+bpfmt = true
+clang_format = true
+ktfmt = true
+
+[Builtin Hooks Options]
+clang_format = --commit ${PREUPLOAD_COMMIT} --style file --extensions c,h,cc,cpp,hpp
+ktfmt = --kotlinlang-style
+
 [Hook Scripts]
 checkstyle_hook = ${REPO_ROOT}/prebuilts/checkstyle/checkstyle.py --sha ${PREUPLOAD_COMMIT}
 
-ktlint_hook = ${REPO_ROOT}/prebuilts/ktlint/ktlint.py -f ${PREUPLOAD_FILES}
+ktlint_hook = ${REPO_ROOT}/prebuilts/ktlint/ktlint.py --no-verify-format -f ${PREUPLOAD_FILES}
 
 hidden_api_txt_checksorted_hook = ${REPO_ROOT}/tools/platform-compat/hiddenapi/checksorted_sha.sh ${PREUPLOAD_COMMIT} ${REPO_ROOT}
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 5167570..4d4dacf 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -330,9 +330,9 @@
 import com.android.net.module.util.LinkPropertiesUtils.CompareOrUpdateResult;
 import com.android.net.module.util.LinkPropertiesUtils.CompareResult;
 import com.android.net.module.util.LocationPermissionChecker;
-import com.android.net.module.util.RoutingCoordinatorService;
 import com.android.net.module.util.PerUidCounter;
 import com.android.net.module.util.PermissionUtils;
+import com.android.net.module.util.RoutingCoordinatorService;
 import com.android.net.module.util.TcUtils;
 import com.android.net.module.util.netlink.InetDiagMessage;
 import com.android.networkstack.apishim.BroadcastOptionsShimImpl;
@@ -7779,6 +7779,11 @@
             }
         }
 
+        boolean isCallbackOverridden(int callbackId) {
+            if (!mUseDeclaredMethodsForCallbacksEnabled) return true; // feature off: assume yes
+            return (mDeclaredMethodsFlags & (1 << callbackId)) != 0;
+        }
+
         boolean hasHigherOrderThan(@NonNull final NetworkRequestInfo target) {
             // Compare two preference orders.
             return mPreferenceOrder < target.mPreferenceOrder;
@@ -10248,6 +10253,18 @@
         return new LocalNetworkInfo.Builder().setUpstreamNetwork(upstream).build();
     }
 
+    private Bundle makeCommonBundleForCallback(@NonNull final NetworkRequestInfo nri,
+            @Nullable Network network) {
+        final Bundle bundle = new Bundle();
+        // TODO b/177608132: make sure callbacks are indexed by NRIs and not NetworkRequest objects.
+        // TODO: check if defensive copies of data are needed.
+        putParcelable(bundle, nri.getNetworkRequestForCallback());
+        if (network != null) { // null when the callback concerns no particular network
+            putParcelable(bundle, network);
+        }
+        return bundle;
+    }
+
     // networkAgent is only allowed to be null if notificationType is
     // CALLBACK_UNAVAIL. This is because UNAVAIL is about no network being
     // available, while all other cases are about some particular network.
@@ -10260,22 +10277,17 @@
             // are Type.LISTEN, but should not have NetworkCallbacks invoked.
             return;
         }
-        if (mUseDeclaredMethodsForCallbacksEnabled
-                && (nri.mDeclaredMethodsFlags & (1 << notificationType)) == 0) {
+        if (!nri.isCallbackOverridden(notificationType)) {
             // No need to send the notification as the recipient method is not overridden
             return;
         }
-        final Bundle bundle = new Bundle();
-        // TODO b/177608132: make sure callbacks are indexed by NRIs and not NetworkRequest objects.
-        // TODO: check if defensive copies of data is needed.
-        final NetworkRequest nrForCallback = nri.getNetworkRequestForCallback();
-        putParcelable(bundle, nrForCallback);
-        Message msg = Message.obtain();
-        if (notificationType != CALLBACK_UNAVAIL) {
-            putParcelable(bundle, networkAgent.network);
-        }
+        final Network bundleNetwork = notificationType == CALLBACK_UNAVAIL
+                ? null
+                : networkAgent.network;
+        final Bundle bundle = makeCommonBundleForCallback(nri, bundleNetwork);
         final boolean includeLocationSensitiveInfo =
                 (nri.mCallbackFlags & NetworkCallback.FLAG_INCLUDE_LOCATION_INFO) != 0;
+        final NetworkRequest nrForCallback = nri.getNetworkRequestForCallback();
         switch (notificationType) {
             case CALLBACK_AVAILABLE: {
                 final NetworkCapabilities nc =
@@ -10292,12 +10304,6 @@
                 // method here.
                 bundle.putParcelable(LocalNetworkInfo.class.getSimpleName(),
                         localNetworkInfoForNai(networkAgent));
-                // For this notification, arg1 contains the blocked status.
-                msg.arg1 = arg1;
-                break;
-            }
-            case CALLBACK_LOSING: {
-                msg.arg1 = arg1;
                 break;
             }
             case CALLBACK_CAP_CHANGED: {
@@ -10320,7 +10326,6 @@
             }
             case CALLBACK_BLK_CHANGED: {
                 maybeLogBlockedStatusChanged(nri, networkAgent.network, arg1);
-                msg.arg1 = arg1;
                 break;
             }
             case CALLBACK_LOCAL_NETWORK_INFO_CHANGED: {
@@ -10332,17 +10337,26 @@
                 break;
             }
         }
+        callCallbackForRequest(nri, notificationType, bundle, arg1);
+    }
+
+    private void callCallbackForRequest(@NonNull final NetworkRequestInfo nri, int notificationType,
+            Bundle bundle, int arg1) {
+        Message msg = Message.obtain();
+        msg.arg1 = arg1;
         msg.what = notificationType;
         msg.setData(bundle);
         try {
             if (VDBG) {
                 String notification = ConnectivityManager.getCallbackName(notificationType);
-                log("sending notification " + notification + " for " + nrForCallback);
+                log("sending notification " + notification + " for "
+                        + nri.getNetworkRequestForCallback());
             }
             nri.mMessenger.send(msg);
         } catch (RemoteException e) {
             // may occur naturally in the race of binder death.
-            loge("RemoteException caught trying to send a callback msg for " + nrForCallback);
+            loge("RemoteException caught trying to send a callback msg for "
+                    + nri.getNetworkRequestForCallback());
         }
     }
 
@@ -11431,11 +11445,7 @@
             return;
         }
 
-        final int blockedReasons = mUidBlockedReasons.get(nri.mAsUid, BLOCKED_REASON_NONE);
-        final boolean metered = nai.networkCapabilities.isMetered();
-        final boolean vpnBlocked = isUidBlockedByVpn(nri.mAsUid, mVpnBlockedUidRanges);
-        callCallbackForRequest(nri, nai, CALLBACK_AVAILABLE,
-                getBlockedState(nri.mAsUid, blockedReasons, metered, vpnBlocked));
+        callCallbackForRequest(nri, nai, CALLBACK_AVAILABLE, getBlockedState(nai, nri.mAsUid));
     }
 
     // Notify the requests on this NAI that the network is now lingered.
@@ -11465,6 +11475,13 @@
                 : reasons & ~BLOCKED_REASON_LOCKDOWN_VPN;
     }
 
+    private int getBlockedState(@NonNull NetworkAgentInfo nai, int uid) {
+        // Collect the per-UID and per-network inputs, then defer to the four-argument overload.
+        final int reasons = mUidBlockedReasons.get(uid, BLOCKED_REASON_NONE);
+        final boolean isVpnBlocked = isUidBlockedByVpn(uid, mVpnBlockedUidRanges);
+        return getBlockedState(uid, reasons, nai.networkCapabilities.isMetered(), isVpnBlocked);
+    }
+
     private void setUidBlockedReasons(int uid, @BlockedReason int blockedReasons) {
         if (blockedReasons == BLOCKED_REASON_NONE) {
             mUidBlockedReasons.delete(uid);
diff --git a/staticlibs/tests/unit/src/com/android/testutils/DefaultNetworkRestoreMonitorTest.kt b/staticlibs/tests/unit/src/com/android/testutils/DefaultNetworkRestoreMonitorTest.kt
new file mode 100644
index 0000000..7e508fb
--- /dev/null
+++ b/staticlibs/tests/unit/src/com/android/testutils/DefaultNetworkRestoreMonitorTest.kt
@@ -0,0 +1,167 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.content.Context
+import android.content.pm.PackageManager
+import android.net.ConnectivityManager
+import android.net.ConnectivityManager.NetworkCallback
+import android.net.Network
+import android.net.NetworkCapabilities
+import android.net.NetworkCapabilities.TRANSPORT_CELLULAR
+import android.net.NetworkCapabilities.TRANSPORT_WIFI
+import org.junit.Test
+import org.junit.runner.Description
+import org.junit.runner.notification.RunListener
+import org.junit.runner.notification.RunNotifier
+import org.mockito.ArgumentCaptor
+import org.mockito.ArgumentMatchers.any
+import org.mockito.ArgumentMatchers.anyString
+import org.mockito.ArgumentMatchers.argThat
+import org.mockito.Mockito.doAnswer
+import org.mockito.Mockito.doNothing
+import org.mockito.Mockito.doReturn
+import org.mockito.Mockito.inOrder
+import org.mockito.Mockito.mock
+import org.mockito.Mockito.never
+import org.mockito.Mockito.verify
+
+class DefaultNetworkRestoreMonitorTest {
+    private val restoreDefaultNetworkDesc =
+            Description.createSuiteDescription("RestoreDefaultNetwork")
+    private val testDesc = Description.createTestDescription("testClass", "testMethod")
+    private val wifiCap = validatedCapabilities(TRANSPORT_WIFI)
+    private val cellCap = validatedCapabilities(TRANSPORT_CELLULAR)
+    private val cm = mock(ConnectivityManager::class.java)
+    private val pm = mock(PackageManager::class.java).also { mockPm ->
+        doReturn(true).`when`(mockPm).hasSystemFeature(anyString())
+    }
+    private val ctx = mock(Context::class.java).also { mockCtx ->
+        doReturn(cm).`when`(mockCtx).getSystemService(ConnectivityManager::class.java)
+        doReturn(pm).`when`(mockCtx).getPackageManager()
+    }
+    private val notifier = mock(RunNotifier::class.java)
+    private val defaultNetworkMonitor = DefaultNetworkRestoreMonitor(ctx, notifier, timeoutMs = 0)
+
+    // Builds capabilities for a validated network that uses the given transport.
+    private fun validatedCapabilities(transport: Int): NetworkCapabilities =
+            NetworkCapabilities.Builder()
+                    .addTransportType(transport)
+                    .addCapability(NetworkCapabilities.NET_CAPABILITY_VALIDATED)
+                    .build()
+
+    // Returns the RunListener that the monitor registered on the mocked notifier.
+    private fun getRunListener(): RunListener {
+        val listenerCaptor = ArgumentCaptor.forClass(RunListener::class.java)
+        verify(notifier).addListener(listenerCaptor.capture())
+        return listenerCaptor.value
+    }
+
+    // Makes registerDefaultNetworkCallback report |cap| for the default network, or report
+    // nothing at all when |cap| is null.
+    private fun mockDefaultNetworkCapabilities(cap: NetworkCapabilities?) {
+        val stubbing = if (cap == null) {
+            doNothing()
+        } else {
+            doAnswer { invocation ->
+                val callback = invocation.getArgument(0) as NetworkCallback
+                callback.onCapabilitiesChanged(Network(100), cap)
+            }
+        }
+        stubbing.`when`(cm).registerDefaultNetworkCallback(any())
+    }
+
+    // Signals the end of a test to the monitor's listener, asks the monitor to report its
+    // result, and returns the listener so callers can verify that it was removed.
+    private fun finishTestAndReport(): RunListener {
+        val listener = getRunListener()
+        listener.testFinished(testDesc)
+        defaultNetworkMonitor.reportResultAndCleanUp(restoreDefaultNetworkDesc)
+        return listener
+    }
+
+    @Test
+    fun testDefaultNetworkRestoreMonitor_defaultNetworkRestored() {
+        mockDefaultNetworkCapabilities(wifiCap)
+        defaultNetworkMonitor.init(mock(ConnectUtil::class.java))
+
+        val listener = finishTestAndReport()
+
+        val order = inOrder(notifier)
+        order.verify(notifier).fireTestStarted(restoreDefaultNetworkDesc)
+        // The default network is unchanged, so no failure is reported
+        order.verify(notifier, never()).fireTestFailure(any())
+        order.verify(notifier).fireTestFinished(restoreDefaultNetworkDesc)
+        order.verify(notifier).removeListener(listener)
+    }
+
+    @Test
+    fun testDefaultNetworkRestoreMonitor_testStartWithoutDefaultNetwork() {
+        // There is no default network when the tests start
+        mockDefaultNetworkCapabilities(null)
+        defaultNetworkMonitor.init(mock(ConnectUtil::class.java))
+
+        mockDefaultNetworkCapabilities(wifiCap)
+        val listener = finishTestAndReport()
+
+        val order = inOrder(notifier)
+        order.verify(notifier).fireTestStarted(restoreDefaultNetworkDesc)
+        // init could not read the default network, so a failure is reported even though a
+        // validated Wi-Fi network is available by the time the test finishes
+        order.verify(notifier).fireTestFailure(any())
+        order.verify(notifier).fireTestFinished(restoreDefaultNetworkDesc)
+        order.verify(notifier).removeListener(listener)
+    }
+
+    @Test
+    fun testDefaultNetworkRestoreMonitor_testEndWithoutDefaultNetwork() {
+        mockDefaultNetworkCapabilities(wifiCap)
+        defaultNetworkMonitor.init(mock(ConnectUtil::class.java))
+
+        // There is no default network after the test
+        mockDefaultNetworkCapabilities(null)
+        val listener = finishTestAndReport()
+
+        val order = inOrder(notifier)
+        order.verify(notifier).fireTestStarted(restoreDefaultNetworkDesc)
+        // fireTestFailure is called with method name
+        order.verify(notifier).fireTestFailure(
+                argThat { failure -> failure.exception.message?.contains("testMethod") ?: false }
+        )
+        order.verify(notifier).fireTestFinished(restoreDefaultNetworkDesc)
+        order.verify(notifier).removeListener(listener)
+    }
+
+    @Test
+    fun testDefaultNetworkRestoreMonitor_testChangeDefaultNetwork() {
+        mockDefaultNetworkCapabilities(wifiCap)
+        defaultNetworkMonitor.init(mock(ConnectUtil::class.java))
+
+        // The default network transport types change after the test
+        mockDefaultNetworkCapabilities(cellCap)
+        val listener = finishTestAndReport()
+
+        val order = inOrder(notifier)
+        order.verify(notifier).fireTestStarted(restoreDefaultNetworkDesc)
+        // fireTestFailure is called with method name
+        order.verify(notifier).fireTestFailure(
+                argThat { failure -> failure.exception.message?.contains("testMethod") ?: false }
+        )
+        order.verify(notifier).fireTestFinished(restoreDefaultNetworkDesc)
+        order.verify(notifier).removeListener(listener)
+    }
+}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/DefaultNetworkRestoreMonitor.kt b/staticlibs/testutils/devicetests/com/android/testutils/DefaultNetworkRestoreMonitor.kt
new file mode 100644
index 0000000..1b709b2
--- /dev/null
+++ b/staticlibs/testutils/devicetests/com/android/testutils/DefaultNetworkRestoreMonitor.kt
@@ -0,0 +1,113 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.content.Context
+import android.content.pm.PackageManager
+import android.net.ConnectivityManager
+import android.net.Network
+import android.net.NetworkCapabilities
+import com.android.internal.annotations.VisibleForTesting
+import com.android.net.module.util.BitUtils
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.TimeUnit
+import org.junit.runner.Description
+import org.junit.runner.notification.Failure
+import org.junit.runner.notification.RunListener
+import org.junit.runner.notification.RunNotifier
+
+/**
+ * Checks, after every test reported to [notifier], that the default network is validated and
+ * has the transports recorded by [init]; only the first offending test method is blamed.
+ */
+@VisibleForTesting(visibility = VisibleForTesting.Visibility.PRIVATE)
+class DefaultNetworkRestoreMonitor(
+        ctx: Context,
+        private val notifier: RunNotifier,
+        private val timeoutMs: Long = 3000
+) {
+    var firstFailure: Exception? = null
+    var initialTransports = 0L
+    val cm = ctx.getSystemService(ConnectivityManager::class.java)!!
+    val pm = ctx.packageManager
+    val listener = object : RunListener() {
+        override fun testFinished(desc: Description) {
+            // Only the first method that does not restore the default network should be blamed.
+            if (firstFailure != null) return
+            val cb = TestableNetworkCallback()
+            cm.registerDefaultNetworkCallback(cb)
+            try {
+                cb.eventuallyExpect<RecorderCallback.CallbackEntry.CapabilitiesChanged>(
+                    timeoutMs = timeoutMs
+                ) {
+                    BitUtils.packBits(it.caps.transportTypes) == initialTransports &&
+                            it.caps.hasCapability(NetworkCapabilities.NET_CAPABILITY_VALIDATED)
+                }
+            } catch (e: AssertionError) {
+                // Keep the assertion as the cause so its details reach the failure report.
+                firstFailure = IllegalStateException(
+                        desc.methodName + " does not restore the default network", e)
+            } finally {
+                cm.unregisterNetworkCallback(cb)
+            }
+        }
+    }
+
+    /** Ensures Wi-Fi/cellular are validated, records the baseline transports, adds [listener]. */
+    fun init(connectUtil: ConnectUtil) {
+        // Ensure Wi-Fi and cellular connection before running test to avoid starting test
+        // with unexpected default network.
+        // ConnectivityTestTargetPreparer does the same thing, but it's possible that previous tests
+        // don't enable DefaultNetworkRestoreMonitor and the default network is not restored.
+        // This can be removed if all tests enable DefaultNetworkRestoreMonitor
+        if (pm.hasSystemFeature(PackageManager.FEATURE_WIFI)) {
+            connectUtil.ensureWifiValidated()
+        }
+        if (pm.hasSystemFeature(PackageManager.FEATURE_TELEPHONY)) {
+            connectUtil.ensureCellularValidated()
+        }
+        // Record the transports of the current default network as the baseline.
+        val capFuture = CompletableFuture<NetworkCapabilities>()
+        val cb = object : ConnectivityManager.NetworkCallback() {
+            override fun onCapabilitiesChanged(network: Network, cap: NetworkCapabilities) {
+                capFuture.complete(cap)
+            }
+        }
+        cm.registerDefaultNetworkCallback(cb)
+        try {
+            val cap = capFuture.get(100, TimeUnit.MILLISECONDS)
+            initialTransports = BitUtils.packBits(cap.transportTypes)
+        } catch (e: Exception) {
+            firstFailure = IllegalStateException(
+                    "Failed to get default network status before starting tests", e
+            )
+        } finally {
+            cm.unregisterNetworkCallback(cb)
+        }
+        notifier.addListener(listener)
+    }
+
+    /** Reports the outcome as the test [desc] on [notifier], then removes [listener]. */
+    fun reportResultAndCleanUp(desc: Description) {
+        notifier.fireTestStarted(desc)
+        if (firstFailure != null) {
+            notifier.fireTestFailure(Failure(desc, firstFailure))
+        }
+        notifier.fireTestFinished(desc)
+        notifier.removeListener(listener)
+    }
+}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/DevSdkIgnoreRunner.kt b/staticlibs/testutils/devicetests/com/android/testutils/DevSdkIgnoreRunner.kt
index 8687ac7..a014834 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/DevSdkIgnoreRunner.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/DevSdkIgnoreRunner.kt
@@ -16,6 +16,8 @@
 
 package com.android.testutils
 
+import android.content.Context
+import androidx.test.core.app.ApplicationProvider
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import com.android.net.module.util.LinkPropertiesUtils.CompareOrUpdateResult
 import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter
@@ -57,6 +59,10 @@
 class DevSdkIgnoreRunner(private val klass: Class<*>) : Runner(), Filterable, Sortable {
     private val leakMonitorDesc = Description.createTestDescription(klass, "ThreadLeakMonitor")
     private val shouldThreadLeakFailTest = klass.isAnnotationPresent(MonitorThreadLeak::class.java)
+    private val restoreDefaultNetworkDesc =
+            Description.createTestDescription(klass, "RestoreDefaultNetwork")
+    private val restoreDefaultNetwork = klass.isAnnotationPresent(RestoreDefaultNetwork::class.java)
+    val ctx = ApplicationProvider.getApplicationContext<Context>()
 
     // Inference correctly infers Runner & Filterable & Sortable for |baseRunner|, but the
     // Java bytecode doesn't have a way to express this. Give this type a name by wrapping it.
@@ -71,6 +77,10 @@
     // TODO(b/307693729): Remove this annotation and monitor thread leak by default.
     annotation class MonitorThreadLeak
 
+    // Annotation for test classes to indicate the test runner should verify the default network is
+    // restored after each test.
+    annotation class RestoreDefaultNetwork
+
     private val baseRunner: RunnerWrapper<*>? = klass.let {
         val ignoreAfter = it.getAnnotation(IgnoreAfter::class.java)
         val ignoreUpTo = it.getAnnotation(IgnoreUpTo::class.java)
@@ -125,6 +135,14 @@
             )
             return
         }
+
+        val networkRestoreMonitor = if (restoreDefaultNetwork) {
+            DefaultNetworkRestoreMonitor(ctx, notifier).apply{
+                init(ConnectUtil(ctx))
+            }
+        } else {
+            null
+        }
         val threadCountsBeforeTest = if (shouldThreadLeakFailTest) {
             // Dump threads as a baseline to monitor thread leaks.
             getAllThreadNameCounts()
@@ -137,6 +155,7 @@
         if (threadCountsBeforeTest != null) {
             checkThreadLeak(notifier, threadCountsBeforeTest)
         }
+        networkRestoreMonitor?.reportResultAndCleanUp(restoreDefaultNetworkDesc)
         // Clears up internal state of all inline mocks.
         // TODO: Call clearInlineMocks() at the end of each test.
         Mockito.framework().clearInlineMocks()
@@ -163,6 +182,9 @@
             if (shouldThreadLeakFailTest) {
                 it.addChild(leakMonitorDesc)
             }
+            if (restoreDefaultNetwork) {
+                it.addChild(restoreDefaultNetworkDesc)
+            }
         }
     }
 
@@ -173,7 +195,14 @@
         // When ignoring the tests, a skipped placeholder test is reported, so test count is 1.
         if (baseRunner == null) return 1
 
-        return baseRunner.testCount() + if (shouldThreadLeakFailTest) 1 else 0
+        var testCount = baseRunner.testCount()
+        if (shouldThreadLeakFailTest) {
+            testCount += 1
+        }
+        if (restoreDefaultNetwork) {
+            testCount += 1
+        }
+        return testCount
     }
 
     @Throws(NoTestsRemainException::class)
diff --git a/staticlibs/testutils/host/python/wifip2p_utils.py b/staticlibs/testutils/host/python/wifip2p_utils.py
new file mode 100644
index 0000000..8b4ffa5
--- /dev/null
+++ b/staticlibs/testutils/host/python/wifip2p_utils.py
@@ -0,0 +1,50 @@
+#  Copyright (C) 2024 The Android Open Source Project
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+from mobly import asserts
+from mobly.controllers import android_device
+
+
+def assume_wifi_p2p_test_preconditions(
+    server_device: android_device, client_device: android_device
+) -> None:
+  """Skips the test unless both devices support Wi-Fi and Wi-Fi P2P."""
+  server = server_device.connectivity_multi_devices_snippet
+  client = client_device.connectivity_multi_devices_snippet
+
+  asserts.skip_if(not server.hasWifiFeature(), "Server requires Wifi feature")
+  asserts.skip_if(not client.hasWifiFeature(), "Client requires Wifi feature")
+  asserts.skip_if(
+      not server.isP2pSupported(), "Server requires Wi-fi P2P feature"
+  )
+  asserts.skip_if(
+      not client.isP2pSupported(), "Client requires Wi-fi P2P feature"
+  )
+
+
+def setup_wifi_p2p_server_and_client(
+    server_device: android_device, client_device: android_device
+) -> None:
+  """Starts Wi-Fi P2P on the server device and then on the client device."""
+  # The server is started before the client.
+  for device in (server_device, client_device):
+    device.connectivity_multi_devices_snippet.startWifiP2p()
+
+
+def cleanup_wifi_p2p(
+    server_device: android_device, client_device: android_device
+) -> None:
+  """Stops Wi-Fi P2P on the server device and then on the client device."""
+  for device in (server_device, client_device):
+    device.connectivity_multi_devices_snippet.stopWifiP2p()
diff --git a/tests/cts/multidevices/snippet/Android.bp b/tests/cts/multidevices/snippet/Android.bp
index b0b32c2..c94087e 100644
--- a/tests/cts/multidevices/snippet/Android.bp
+++ b/tests/cts/multidevices/snippet/Android.bp
@@ -26,6 +26,7 @@
     srcs: [
         "ConnectivityMultiDevicesSnippet.kt",
         "MdnsMultiDevicesSnippet.kt",
+        "Wifip2pMultiDevicesSnippet.kt",
     ],
     manifest: "AndroidManifest.xml",
     static_libs: [
diff --git a/tests/cts/multidevices/snippet/AndroidManifest.xml b/tests/cts/multidevices/snippet/AndroidManifest.xml
index 967e581..4637497 100644
--- a/tests/cts/multidevices/snippet/AndroidManifest.xml
+++ b/tests/cts/multidevices/snippet/AndroidManifest.xml
@@ -21,6 +21,8 @@
   <uses-permission android:name="android.permission.CHANGE_NETWORK_STATE" />
   <uses-permission android:name="android.permission.CHANGE_WIFI_STATE" />
   <uses-permission android:name="android.permission.INTERNET" />
+  <uses-permission android:name="android.permission.NEARBY_WIFI_DEVICES"
+                   android:usesPermissionFlags="neverForLocation" />
   <application>
     <!-- Add any classes that implement the Snippet interface as meta-data, whose
          value is a comma-separated string, each section being the package path
@@ -28,7 +30,8 @@
     <meta-data
         android:name="mobly-snippets"
         android:value="com.google.snippet.connectivity.ConnectivityMultiDevicesSnippet,
-                       com.google.snippet.connectivity.MdnsMultiDevicesSnippet" />
+                       com.google.snippet.connectivity.MdnsMultiDevicesSnippet,
+                       com.google.snippet.connectivity.Wifip2pMultiDevicesSnippet" />
   </application>
   <!-- Add an instrumentation tag so that the app can be launched through an
        instrument command. The runner `com.google.android.mobly.snippet.SnippetRunner`
diff --git a/tests/cts/multidevices/snippet/Wifip2pMultiDevicesSnippet.kt b/tests/cts/multidevices/snippet/Wifip2pMultiDevicesSnippet.kt
new file mode 100644
index 0000000..e0929bb
--- /dev/null
+++ b/tests/cts/multidevices/snippet/Wifip2pMultiDevicesSnippet.kt
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.snippet.connectivity
+
+import android.net.wifi.WifiManager
+import android.net.wifi.p2p.WifiP2pManager
+import androidx.test.platform.app.InstrumentationRegistry
+import com.google.android.mobly.snippet.Snippet
+import com.google.android.mobly.snippet.rpc.Rpc
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.TimeUnit
+import kotlin.test.fail
+
+private const val TIMEOUT_MS = 60000L
+
+class Wifip2pMultiDevicesSnippet : Snippet {
+    private val context by lazy { InstrumentationRegistry.getInstrumentation().getTargetContext() }
+    private val wifiManager by lazy {
+        context.getSystemService(WifiManager::class.java)
+                ?: fail("Could not get WifiManager service")
+    }
+    private val wifip2pManager by lazy {
+        context.getSystemService(WifiP2pManager::class.java)
+                ?: fail("Could not get WifiP2pManager service")
+    }
+
+    // Assigned by startWifiP2p(); stopWifiP2p() does nothing until then.
+    private lateinit var wifip2pChannel: WifiP2pManager.Channel
+
+    @Rpc(description = "Check whether the device supports Wi-Fi P2P.")
+    fun isP2pSupported() = wifiManager.isP2pSupported()
+
+    @Rpc(description = "Start Wi-Fi P2P")
+    fun startWifiP2p() {
+        // Open the channel that the later Wi-Fi P2P requests are issued on.
+        wifip2pChannel = wifip2pManager.initialize(context, context.mainLooper, null)
+
+        // Block until Wi-Fi P2P is reported as enabled; get() throws if that takes too long.
+        val enabledFuture = CompletableFuture<Boolean>()
+        wifip2pManager.requestP2pState(wifip2pChannel) { p2pState ->
+            if (p2pState == WifiP2pManager.WIFI_P2P_STATE_ENABLED) enabledFuture.complete(true)
+        }
+        enabledFuture.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)
+    }
+
+    @Rpc(description = "Stop Wi-Fi P2P")
+    fun stopWifiP2p() {
+        // Nothing to tear down when startWifiP2p() never created a channel.
+        if (!this::wifip2pChannel.isInitialized) return
+        wifip2pManager.cancelConnect(wifip2pChannel, null)
+        wifip2pManager.removeGroup(wifip2pChannel, null)
+    }
+}
diff --git a/tests/cts/net/src/android/net/cts/DnsResolverTest.java b/tests/cts/net/src/android/net/cts/DnsResolverTest.java
index 752891f..fa44ae9 100644
--- a/tests/cts/net/src/android/net/cts/DnsResolverTest.java
+++ b/tests/cts/net/src/android/net/cts/DnsResolverTest.java
@@ -852,13 +852,14 @@
     }
 
     public void doTestContinuousQueries(Executor executor) throws InterruptedException {
-        final String msg = "Test continuous " + QUERY_TIMES + " queries " + TEST_DOMAIN;
         for (Network network : getTestableNetworks()) {
             for (int i = 0; i < QUERY_TIMES ; ++i) {
-                final VerifyCancelInetAddressCallback callback =
-                        new VerifyCancelInetAddressCallback(msg, null);
                 // query v6/v4 in turn
                 boolean queryV6 = (i % 2 == 0);
+                final String msg = "Test continuous " + QUERY_TIMES + " queries " + TEST_DOMAIN
+                        + " on " + network + ", queryV6=" + queryV6;
+                final VerifyCancelInetAddressCallback callback =
+                        new VerifyCancelInetAddressCallback(msg, null);
                 mDns.query(network, TEST_DOMAIN, queryV6 ? TYPE_AAAA : TYPE_A,
                         FLAG_NO_CACHE_LOOKUP, executor, null, callback);
 
diff --git a/tests/cts/net/src/android/net/cts/MultinetworkApiTest.java b/tests/cts/net/src/android/net/cts/MultinetworkApiTest.java
index 06a827b..2c7d5c6 100644
--- a/tests/cts/net/src/android/net/cts/MultinetworkApiTest.java
+++ b/tests/cts/net/src/android/net/cts/MultinetworkApiTest.java
@@ -40,10 +40,10 @@
 import android.system.OsConstants;
 import android.util.ArraySet;
 
-import androidx.test.ext.junit.runners.AndroidJUnit4;
 import androidx.test.platform.app.InstrumentationRegistry;
 
 import com.android.testutils.AutoReleaseNetworkCallbackRule;
+import com.android.testutils.DevSdkIgnoreRunner;
 import com.android.testutils.DeviceConfigRule;
 
 import org.junit.Before;
@@ -53,7 +53,8 @@
 
 import java.util.Set;
 
-@RunWith(AndroidJUnit4.class)
+@DevSdkIgnoreRunner.RestoreDefaultNetwork
+@RunWith(DevSdkIgnoreRunner.class)
 public class MultinetworkApiTest {
     @Rule(order = 1)
     public final DeviceConfigRule mDeviceConfigRule = new DeviceConfigRule();
diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
index beb9274..60081d4 100644
--- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
+++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
@@ -194,15 +194,16 @@
 // TODO : enable this in a Mainline update or in V.
 private const val SHOULD_CREATE_NETWORKS_IMMEDIATELY = false
 
-@RunWith(DevSdkIgnoreRunner::class)
-// NetworkAgent is not updatable in R-, so this test does not need to be compatible with older
-// versions. NetworkAgent was also based on AsyncChannel before S so cannot be tested the same way.
-@IgnoreUpTo(Build.VERSION_CODES.R)
+@AppModeFull(reason = "Instant apps can't use NetworkAgent because it needs NETWORK_FACTORY'.")
 // NetworkAgent is updated as part of the connectivity module, and running NetworkAgent tests in MTS
 // for modules other than Connectivity does not provide much value. Only run them in connectivity
 // module MTS, so the tests only need to cover the case of an updated NetworkAgent.
 @ConnectivityModuleTest
-@AppModeFull(reason = "Instant apps can't use NetworkAgent because it needs NETWORK_FACTORY'.")
+@DevSdkIgnoreRunner.RestoreDefaultNetwork
+// NetworkAgent is not updatable in R-, so this test does not need to be compatible with older
+// versions. NetworkAgent was also based on AsyncChannel before S so cannot be tested the same way.
+@IgnoreUpTo(Build.VERSION_CODES.R)
+@RunWith(DevSdkIgnoreRunner::class)
 class NetworkAgentTest {
     private val LOCAL_IPV4_ADDRESS = InetAddresses.parseNumericAddress("192.0.2.1")
     private val REMOTE_IPV4_ADDRESS = InetAddresses.parseNumericAddress("192.0.2.2")