Merge "Add test for CM#setAcceptUnvalidated"
diff --git a/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java b/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
index 36a2f10..4a05c9f 100644
--- a/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
+++ b/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
@@ -786,6 +786,7 @@
         dumpIpv6ForwardingRules(pw);
         dumpIpv4ForwardingRules(pw);
         pw.decreaseIndent();
+        pw.println();
 
         pw.println("Device map:");
         pw.increaseIndent();
@@ -872,24 +873,33 @@
         }
     }
 
-    private String ipv4RuleToString(long now, Tether4Key key, Tether4Value value) {
-        final String private4, public4, dst4;
+    private String ipv4RuleToString(long now, boolean downstream,
+            Tether4Key key, Tether4Value value) {
+        final String src4, public4, dst4;
+        final int publicPort;
         try {
-            private4 = InetAddress.getByAddress(key.src4).getHostAddress();
-            dst4 = InetAddress.getByAddress(key.dst4).getHostAddress();
-            public4 = InetAddress.getByAddress(value.src46).getHostAddress();
+            src4 = InetAddress.getByAddress(key.src4).getHostAddress();
+            if (downstream) {
+                public4 = InetAddress.getByAddress(key.dst4).getHostAddress();
+                publicPort = key.dstPort;
+            } else {
+                public4 = InetAddress.getByAddress(value.src46).getHostAddress();
+                publicPort = value.srcPort;
+            }
+            dst4 = InetAddress.getByAddress(value.dst46).getHostAddress();
         } catch (UnknownHostException impossible) {
-            throw new AssertionError("4-byte array not valid IPv4 address!");
+            throw new AssertionError("IP address array not valid IPv4 address!");
         }
-        long ageMs = (now - value.lastUsed) / 1_000_000;
-        return String.format("[%s] %d(%s) %s:%d -> %d(%s) %s:%d -> %s:%d %dms",
-                key.dstMac, key.iif, getIfName(key.iif), private4, key.srcPort,
+
+        final long ageMs = (now - value.lastUsed) / 1_000_000;
+        return String.format("[%s] %d(%s) %s:%d -> %d(%s) %s:%d -> %s:%d [%s] %dms",
+                key.dstMac, key.iif, getIfName(key.iif), src4, key.srcPort,
                 value.oif, getIfName(value.oif),
-                public4, value.srcPort, dst4, key.dstPort, ageMs);
+                public4, publicPort, dst4, value.dstPort, value.ethDstMac, ageMs);
     }
 
-    private void dumpIpv4ForwardingRuleMap(long now, BpfMap<Tether4Key, Tether4Value> map,
-            IndentingPrintWriter pw) throws ErrnoException {
+    private void dumpIpv4ForwardingRuleMap(long now, boolean downstream,
+            BpfMap<Tether4Key, Tether4Value> map, IndentingPrintWriter pw) throws ErrnoException {
         if (map == null) {
             pw.println("No IPv4 support");
             return;
@@ -898,7 +908,7 @@
             pw.println("No rules");
             return;
         }
-        map.forEach((k, v) -> pw.println(ipv4RuleToString(now, k, v)));
+        map.forEach((k, v) -> pw.println(ipv4RuleToString(now, downstream, k, v)));
     }
 
     private void dumpIpv4ForwardingRules(IndentingPrintWriter pw) {
@@ -906,14 +916,14 @@
 
         try (BpfMap<Tether4Key, Tether4Value> upstreamMap = mDeps.getBpfUpstream4Map();
                 BpfMap<Tether4Key, Tether4Value> downstreamMap = mDeps.getBpfDownstream4Map()) {
-            pw.println("IPv4 Upstream: [inDstMac] iif(iface) src -> nat -> dst");
+            pw.println("IPv4 Upstream: [inDstMac] iif(iface) src -> nat -> dst [outDstMac] age");
             pw.increaseIndent();
-            dumpIpv4ForwardingRuleMap(now, upstreamMap, pw);
+            dumpIpv4ForwardingRuleMap(now, UPSTREAM, upstreamMap, pw);
             pw.decreaseIndent();
 
-            pw.println("IPv4 Downstream: [inDstMac] iif(iface) src -> nat -> dst");
+            pw.println("IPv4 Downstream: [inDstMac] iif(iface) src -> nat -> dst [outDstMac] age");
             pw.increaseIndent();
-            dumpIpv4ForwardingRuleMap(now, downstreamMap, pw);
+            dumpIpv4ForwardingRuleMap(now, DOWNSTREAM, downstreamMap, pw);
             pw.decreaseIndent();
         } catch (ErrnoException e) {
             pw.println("Error dumping IPv4 map: " + e);
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 9d318cd..fd8397f 100644
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -1275,6 +1275,20 @@
         public boolean getCellular464XlatEnabled() {
             return NetworkProperties.isCellular464XlatEnabled().orElse(true);
         }
+
+        /**
+         * @see PendingIntent#intentFilterEquals
+         */
+        public boolean intentFilterEquals(PendingIntent a, PendingIntent b) {
+            return a.intentFilterEquals(b);
+        }
+
+        /**
+         * @see LocationPermissionChecker
+         */
+        public LocationPermissionChecker makeLocationPermissionChecker(Context context) {
+            return new LocationPermissionChecker(context);
+        }
     }
 
     public ConnectivityService(Context context) {
@@ -1342,7 +1356,7 @@
         mNetd = netd;
         mTelephonyManager = (TelephonyManager) mContext.getSystemService(Context.TELEPHONY_SERVICE);
         mAppOpsManager = (AppOpsManager) mContext.getSystemService(Context.APP_OPS_SERVICE);
-        mLocationPermissionChecker = new LocationPermissionChecker(mContext);
+        mLocationPermissionChecker = mDeps.makeLocationPermissionChecker(mContext);
 
         // To ensure uid state is synchronized with Network Policy, register for
         // NetworkPolicyManagerService events must happen prior to NetworkPolicyManagerService
@@ -3926,7 +3940,7 @@
         for (Map.Entry<NetworkRequest, NetworkRequestInfo> entry : mNetworkRequests.entrySet()) {
             PendingIntent existingPendingIntent = entry.getValue().mPendingIntent;
             if (existingPendingIntent != null &&
-                    existingPendingIntent.intentFilterEquals(pendingIntent)) {
+                    mDeps.intentFilterEquals(existingPendingIntent, pendingIntent)) {
                 return entry.getValue();
             }
         }
diff --git a/tests/common/java/android/net/NetworkProviderTest.kt b/tests/common/java/android/net/NetworkProviderTest.kt
index 7424157..97d3c5a 100644
--- a/tests/common/java/android/net/NetworkProviderTest.kt
+++ b/tests/common/java/android/net/NetworkProviderTest.kt
@@ -18,6 +18,7 @@
 
 import android.app.Instrumentation
 import android.content.Context
+import android.net.NetworkCapabilities.NET_CAPABILITY_TRUSTED
 import android.net.NetworkCapabilities.TRANSPORT_TEST
 import android.net.NetworkProviderTest.TestNetworkCallback.CallbackEntry.OnUnavailable
 import android.net.NetworkProviderTest.TestNetworkProvider.CallbackEntry.OnNetworkRequestWithdrawn
@@ -25,14 +26,18 @@
 import android.os.Build
 import android.os.HandlerThread
 import android.os.Looper
+import android.util.Log
 import androidx.test.InstrumentationRegistry
 import com.android.net.module.util.ArrayTrackRecord
 import com.android.testutils.CompatUtil
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
 import com.android.testutils.DevSdkIgnoreRunner
 import com.android.testutils.isDevSdkInRange
 import org.junit.After
 import org.junit.Before
+import org.junit.Rule
 import org.junit.Test
 import org.junit.runner.RunWith
 import org.mockito.Mockito.doReturn
@@ -41,6 +46,7 @@
 import java.util.UUID
 import kotlin.test.assertEquals
 import kotlin.test.assertNotEquals
+import kotlin.test.fail
 
 private const val DEFAULT_TIMEOUT_MS = 5000L
 private val instrumentation: Instrumentation
@@ -51,6 +57,8 @@
 @RunWith(DevSdkIgnoreRunner::class)
 @IgnoreUpTo(Build.VERSION_CODES.Q)
 class NetworkProviderTest {
+    @Rule @JvmField
+    val mIgnoreRule = DevSdkIgnoreRule()
     private val mCm = context.getSystemService(ConnectivityManager::class.java)
     private val mHandlerThread = HandlerThread("${javaClass.simpleName} handler thread")
 
@@ -68,6 +76,7 @@
 
     private class TestNetworkProvider(context: Context, looper: Looper) :
             NetworkProvider(context, looper, PROVIDER_NAME) {
+        private val TAG = this::class.simpleName
         private val seenEvents = ArrayTrackRecord<CallbackEntry>().newReadHead()
 
         sealed class CallbackEntry {
@@ -80,22 +89,30 @@
         }
 
         override fun onNetworkRequested(request: NetworkRequest, score: Int, id: Int) {
+            Log.d(TAG, "onNetworkRequested $request, $score, $id")
             seenEvents.add(OnNetworkRequested(request, score, id))
         }
 
         override fun onNetworkRequestWithdrawn(request: NetworkRequest) {
+            Log.d(TAG, "onNetworkRequestWithdrawn $request")
             seenEvents.add(OnNetworkRequestWithdrawn(request))
         }
 
-        inline fun <reified T : CallbackEntry> expectCallback(
+        inline fun <reified T : CallbackEntry> eventuallyExpectCallbackThat(
             crossinline predicate: (T) -> Boolean
         ) = seenEvents.poll(DEFAULT_TIMEOUT_MS) { it is T && predicate(it) }
+                ?: fail("Did not receive callback after ${DEFAULT_TIMEOUT_MS}ms")
     }
 
     private fun createNetworkProvider(ctx: Context = context): TestNetworkProvider {
         return TestNetworkProvider(ctx, mHandlerThread.looper)
     }
 
+    // Do not run this test on S+ frameworks: there, the provider no longer receives
+    // onNetworkRequested for every request. Instead, the provider must call
+    // {@code registerNetworkOffer} to describe the networks it might be able to set up,
+    // and then expect {@link NetworkOfferCallback#onNetworkNeeded}.
+    @IgnoreAfter(Build.VERSION_CODES.R)
     @Test
     fun testOnNetworkRequested() {
         val provider = createNetworkProvider()
@@ -105,13 +122,15 @@
 
         val specifier = CompatUtil.makeTestNetworkSpecifier(
                 UUID.randomUUID().toString())
+        // Test networks are not allowed to be trusted, so do not require TRUSTED.
         val nr: NetworkRequest = NetworkRequest.Builder()
                 .addTransportType(TRANSPORT_TEST)
+                .removeCapability(NET_CAPABILITY_TRUSTED)
                 .setNetworkSpecifier(specifier)
                 .build()
         val cb = ConnectivityManager.NetworkCallback()
         mCm.requestNetwork(nr, cb)
-        provider.expectCallback<OnNetworkRequested>() { callback ->
+        provider.eventuallyExpectCallbackThat<OnNetworkRequested>() { callback ->
             callback.request.getNetworkSpecifier() == specifier &&
             callback.request.hasTransport(TRANSPORT_TEST)
         }
@@ -131,22 +150,24 @@
         val config = NetworkAgentConfig.Builder().build()
         val agent = object : NetworkAgent(context, mHandlerThread.looper, "TestAgent", nc, lp,
                 initialScore, config, provider) {}
+        agent.register()
+        agent.markConnected()
 
-        provider.expectCallback<OnNetworkRequested>() { callback ->
+        provider.eventuallyExpectCallbackThat<OnNetworkRequested>() { callback ->
             callback.request.getNetworkSpecifier() == specifier &&
             callback.score == initialScore &&
             callback.id == agent.providerId
         }
 
         agent.sendNetworkScore(updatedScore)
-        provider.expectCallback<OnNetworkRequested>() { callback ->
+        provider.eventuallyExpectCallbackThat<OnNetworkRequested>() { callback ->
             callback.request.getNetworkSpecifier() == specifier &&
             callback.score == updatedScore &&
             callback.id == agent.providerId
         }
 
         mCm.unregisterNetworkCallback(cb)
-        provider.expectCallback<OnNetworkRequestWithdrawn>() { callback ->
+        provider.eventuallyExpectCallbackThat<OnNetworkRequestWithdrawn>() { callback ->
             callback.request.getNetworkSpecifier() == specifier &&
             callback.request.hasTransport(TRANSPORT_TEST)
         }
diff --git a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
index fa7cd3d..b249abf 100644
--- a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
@@ -669,6 +669,18 @@
         mCm.getBackgroundDataSetting();
     }
 
+    private NetworkRequest makeDefaultRequest() {
+        // Make a request that is similar to the way framework tracks the system
+        // default network.
+        return new NetworkRequest.Builder()
+                .clearCapabilities()
+                .addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED)
+                .addCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
+                .addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
+                .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
+                .build();
+    }
+
     private NetworkRequest makeWifiNetworkRequest() {
         return new NetworkRequest.Builder()
                 .addTransportType(NetworkCapabilities.TRANSPORT_WIFI)
@@ -731,12 +743,14 @@
 
         final TestNetworkCallback systemDefaultCallback = new TestNetworkCallback();
         final TestNetworkCallback perUidCallback = new TestNetworkCallback();
+        final TestNetworkCallback bestMatchingCallback = new TestNetworkCallback();
         final Handler h = new Handler(Looper.getMainLooper());
         if (TestUtils.shouldTestSApis()) {
             runWithShellPermissionIdentity(() -> {
                 mCmShim.registerSystemDefaultNetworkCallback(systemDefaultCallback, h);
                 mCmShim.registerDefaultNetworkCallbackForUid(Process.myUid(), perUidCallback, h);
             }, NETWORK_SETTINGS);
+            mCm.registerBestMatchingNetworkCallback(makeDefaultRequest(), bestMatchingCallback, h);
         }
 
         Network wifiNetwork = null;
@@ -762,6 +776,10 @@
                 assertNotNull("Did not receive onAvailable on per-UID default network callback",
                         perUidNetwork);
                 assertEquals(defaultNetwork, perUidNetwork);
+                final Network bestMatchingNetwork = bestMatchingCallback.waitForAvailable();
+                assertNotNull("Did not receive onAvailable on best matching network callback",
+                        bestMatchingNetwork);
+                assertEquals(defaultNetwork, bestMatchingNetwork);
             }
 
         } catch (InterruptedException e) {
@@ -772,6 +790,7 @@
             if (TestUtils.shouldTestSApis()) {
                 mCm.unregisterNetworkCallback(systemDefaultCallback);
                 mCm.unregisterNetworkCallback(perUidCallback);
+                mCm.unregisterNetworkCallback(bestMatchingCallback);
             }
         }
     }
diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
index 9017f1b..c505cef 100644
--- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
+++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
@@ -66,6 +66,7 @@
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnStopSocketKeepalive
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnValidationStatus
 import android.os.Build
+import android.os.Handler
 import android.os.HandlerThread
 import android.os.Looper
 import android.os.Message
@@ -269,10 +270,9 @@
             history.add(OnSignalStrengthThresholdsUpdated(thresholds))
         }
 
-        fun expectEmptySignalStrengths() {
+        fun expectSignalStrengths(thresholds: IntArray? = intArrayOf()) {
             expectCallback<OnSignalStrengthThresholdsUpdated>().let {
-                // intArrayOf() without arguments makes an empty array
-                assertArrayEquals(intArrayOf(), it.thresholds)
+                assertArrayEquals(thresholds, it.thresholds)
             }
         }
 
@@ -292,7 +292,7 @@
         // a NetworkAgent whose network does not require validation (which test networks do
         // not, since they lack the INTERNET capability). It always contains the default argument
         // for the URI.
-        fun expectNoInternetValidationStatus() = expectCallback<OnValidationStatus>().let {
+        fun expectValidationBypassedStatus() = expectCallback<OnValidationStatus>().let {
             assertEquals(it.status, VALID_NETWORK)
             // The returned Uri is parsed from the empty string, which means it's an
             // instance of the (private) Uri.StringUri. There are no real good ways
@@ -332,9 +332,30 @@
         callbacksToCleanUp.add(callback)
     }
 
+    private fun registerBestMatchingNetworkCallback(
+        request: NetworkRequest,
+        callback: TestableNetworkCallback,
+        handler: Handler
+    ) {
+        mCM!!.registerBestMatchingNetworkCallback(request, callback, handler)
+        callbacksToCleanUp.add(callback)
+    }
+
+    private fun makeTestNetworkRequest(specifier: String? = null): NetworkRequest {
+        return NetworkRequest.Builder()
+                .clearCapabilities()
+                .addTransportType(TRANSPORT_TEST)
+                .also {
+                    if (specifier != null) {
+                        it.setNetworkSpecifier(CompatUtil.makeEthernetNetworkSpecifier(specifier))
+                    }
+                }
+                .build()
+    }
+
     private fun createNetworkAgent(
         context: Context = realContext,
-        name: String? = null,
+        specifier: String? = null,
         initialNc: NetworkCapabilities? = null,
         initialLp: LinkProperties? = null,
         initialConfig: NetworkAgentConfig? = null
@@ -349,8 +370,8 @@
             if (SdkLevel.isAtLeastS()) {
                 addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
             }
-            if (null != name) {
-                setNetworkSpecifier(CompatUtil.makeEthernetNetworkSpecifier(name))
+            if (null != specifier) {
+                setNetworkSpecifier(CompatUtil.makeEthernetNetworkSpecifier(specifier))
             }
         }
         val lp = initialLp ?: LinkProperties().apply {
@@ -365,21 +386,22 @@
 
     private fun createConnectedNetworkAgent(
         context: Context = realContext,
-        name: String? = null,
-        initialConfig: NetworkAgentConfig? = null
+        specifier: String? = UUID.randomUUID().toString(),
+        initialConfig: NetworkAgentConfig? = null,
+        expectedInitSignalStrengthThresholds: IntArray? = intArrayOf()
     ): Pair<TestableNetworkAgent, TestableNetworkCallback> {
-        val request: NetworkRequest = NetworkRequest.Builder()
-                .clearCapabilities()
-                .addTransportType(TRANSPORT_TEST)
-                .build()
         val callback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
-        requestNetwork(request, callback)
-        val config = initialConfig ?: NetworkAgentConfig.Builder().build()
-        val agent = createNetworkAgent(context, name, initialConfig = config)
+        // Keep this NetworkAgent needed by filing a request that matches its specifier.
+        requestNetwork(makeTestNetworkRequest(specifier = specifier), callback)
+        val agent = createNetworkAgent(context, specifier, initialConfig = initialConfig)
         agent.setTeardownDelayMillis(0)
+        // Connect the agent and verify initial status callbacks.
         agent.register()
         agent.markConnected()
         agent.expectCallback<OnNetworkCreated>()
+        agent.expectSignalStrengths(expectedInitSignalStrengthThresholds)
+        agent.expectValidationBypassedStatus()
+        callback.expectAvailableThenValidatedCallbacks(agent.network!!)
         return agent to callback
     }
 
@@ -413,7 +435,6 @@
                 .setLegacySubType(subtypeLTE)
                 .setLegacySubTypeName(subtypeNameLTE).build()
         val (agent, callback) = createConnectedNetworkAgent(initialConfig = config)
-            callback.expectAvailableThenValidatedCallbacks(agent.network)
             agent.setLegacySubtype(subtypeUMTS, subtypeNameUMTS)
 
             // There is no callback when networkInfo changes,
@@ -433,12 +454,8 @@
     @Test
     fun testConnectAndUnregister() {
         val (agent, callback) = createConnectedNetworkAgent()
-        callback.expectAvailableThenValidatedCallbacks(agent.network)
-        agent.expectEmptySignalStrengths()
-        agent.expectNoInternetValidationStatus()
-
         unregister(agent)
-        callback.expectCallback<Lost>(agent.network)
+        callback.expectCallback<Lost>(agent.network!!)
         assertFailsWith<IllegalStateException>("Must not be able to register an agent twice") {
             agent.register()
         }
@@ -446,11 +463,8 @@
 
     @Test
     fun testOnBandwidthUpdateRequested() {
-        val (agent, callback) = createConnectedNetworkAgent()
-        callback.expectAvailableThenValidatedCallbacks(agent.network)
-        agent.expectEmptySignalStrengths()
-        agent.expectNoInternetValidationStatus()
-        mCM.requestBandwidthUpdate(agent.network)
+        val (agent, _) = createConnectedNetworkAgent()
+        mCM.requestBandwidthUpdate(agent.network!!)
         agent.expectCallback<OnBandwidthUpdateRequested>()
         unregister(agent)
     }
@@ -468,13 +482,8 @@
                 registerNetworkCallback(request, it)
             }
         }
-        createConnectedNetworkAgent().let { (agent, callback) ->
-            callback.expectAvailableThenValidatedCallbacks(agent.network)
-            agent.expectCallback<OnSignalStrengthThresholdsUpdated>().let {
-                assertArrayEquals(it.thresholds, thresholds)
-            }
-            agent.expectNoInternetValidationStatus()
-
+        createConnectedNetworkAgent(expectedInitSignalStrengthThresholds = thresholds).let {
+            (agent, callback) ->
             // Send signal strength and check that the callbacks are called appropriately.
             val nc = NetworkCapabilities(agent.nc)
             nc.setSignalStrength(20)
@@ -483,21 +492,21 @@
 
             nc.setSignalStrength(40)
             agent.sendNetworkCapabilities(nc)
-            callbacks[0].expectAvailableCallbacks(agent.network)
+            callbacks[0].expectAvailableCallbacks(agent.network!!)
             callbacks[1].assertNoCallback(NO_CALLBACK_TIMEOUT)
             callbacks[2].assertNoCallback(NO_CALLBACK_TIMEOUT)
 
             nc.setSignalStrength(80)
             agent.sendNetworkCapabilities(nc)
-            callbacks[0].expectCapabilitiesThat(agent.network) { it.signalStrength == 80 }
-            callbacks[1].expectAvailableCallbacks(agent.network)
-            callbacks[2].expectAvailableCallbacks(agent.network)
+            callbacks[0].expectCapabilitiesThat(agent.network!!) { it.signalStrength == 80 }
+            callbacks[1].expectAvailableCallbacks(agent.network!!)
+            callbacks[2].expectAvailableCallbacks(agent.network!!)
 
             nc.setSignalStrength(55)
             agent.sendNetworkCapabilities(nc)
-            callbacks[0].expectCapabilitiesThat(agent.network) { it.signalStrength == 55 }
-            callbacks[1].expectCapabilitiesThat(agent.network) { it.signalStrength == 55 }
-            callbacks[2].expectCallback<Lost>(agent.network)
+            callbacks[0].expectCapabilitiesThat(agent.network!!) { it.signalStrength == 55 }
+            callbacks[1].expectCapabilitiesThat(agent.network!!) { it.signalStrength == 55 }
+            callbacks[2].expectCallback<Lost>(agent.network!!)
         }
         callbacks.forEach {
             mCM.unregisterNetworkCallback(it)
@@ -546,20 +555,17 @@
 
     @Test
     fun testSendUpdates(): Unit = createConnectedNetworkAgent().let { (agent, callback) ->
-        callback.expectAvailableThenValidatedCallbacks(agent.network)
-        agent.expectEmptySignalStrengths()
-        agent.expectNoInternetValidationStatus()
         val ifaceName = "adhocIface"
         val lp = LinkProperties(agent.lp)
         lp.setInterfaceName(ifaceName)
         agent.sendLinkProperties(lp)
-        callback.expectLinkPropertiesThat(agent.network) {
+        callback.expectLinkPropertiesThat(agent.network!!) {
             it.getInterfaceName() == ifaceName
         }
         val nc = NetworkCapabilities(agent.nc)
         nc.addCapability(NET_CAPABILITY_NOT_METERED)
         agent.sendNetworkCapabilities(nc)
-        callback.expectCapabilitiesThat(agent.network) {
+        callback.expectCapabilitiesThat(agent.network!!) {
             it.hasCapability(NET_CAPABILITY_NOT_METERED)
         }
     }
@@ -568,56 +574,32 @@
     fun testSendScore() {
         // This test will create two networks and check that the one with the stronger
         // score wins out for a request that matches them both.
-        // First create requests to make sure both networks are kept up, using the
-        // specifier so they are specific to each network
-        val name1 = UUID.randomUUID().toString()
-        val name2 = UUID.randomUUID().toString()
-        val request1 = NetworkRequest.Builder()
-                .clearCapabilities()
-                .addTransportType(TRANSPORT_TEST)
-                .setNetworkSpecifier(CompatUtil.makeEthernetNetworkSpecifier(name1))
-                .build()
-        val request2 = NetworkRequest.Builder()
-                .clearCapabilities()
-                .addTransportType(TRANSPORT_TEST)
-                .setNetworkSpecifier(CompatUtil.makeEthernetNetworkSpecifier(name2))
-                .build()
-        val callback1 = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
-        val callback2 = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
-        requestNetwork(request1, callback1)
-        requestNetwork(request2, callback2)
 
-        // Then file the interesting request
-        val request = NetworkRequest.Builder()
-                .clearCapabilities()
-                .addTransportType(TRANSPORT_TEST)
-                .build()
+        // File the interesting request
         val callback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
-        requestNetwork(request, callback)
+        requestNetwork(makeTestNetworkRequest(), callback)
 
-        // Connect the first Network
-        createConnectedNetworkAgent(name = name1).let { (agent1, _) ->
-            callback.expectAvailableThenValidatedCallbacks(agent1.network)
-            // If using the int ranking, agent1 must be upgraded to a better score so that there is
-            // no ambiguity when agent2 connects that agent1 is still better. If using policy
-            // ranking, this is not necessary.
-            agent1.sendNetworkScore(NetworkScore.Builder().setLegacyInt(BETTER_NETWORK_SCORE)
-                    .build())
-            // Connect the second agent
-            createConnectedNetworkAgent(name = name2).let { (agent2, _) ->
-                agent2.markConnected()
-                // The callback should not see anything yet. With int ranking, agent1 was upgraded
-                // to a stronger score beforehand. With policy ranking, agent1 is preferred by
-                // virtue of already satisfying the request.
-                callback.assertNoCallback(NO_CALLBACK_TIMEOUT)
-                // Now downgrade the score and expect the callback now prefers agent2
-                agent1.sendNetworkScore(NetworkScore.Builder()
-                        .setLegacyInt(WORSE_NETWORK_SCORE)
-                        .setExiting(true)
-                        .build())
-                callback.expectCallback<Available>(agent2.network)
-            }
-        }
+        // Connect the first network, with an unused callback that keeps the network up.
+        val (agent1, _) = createConnectedNetworkAgent()
+        callback.expectAvailableThenValidatedCallbacks(agent1.network!!)
+        // If using the int ranking, agent1 must be upgraded to a better score so that there is
+        // no ambiguity when agent2 connects that agent1 is still better. If using policy
+        // ranking, this is not necessary.
+        agent1.sendNetworkScore(NetworkScore.Builder().setLegacyInt(BETTER_NETWORK_SCORE)
+                .build())
+
+        // Connect the second agent.
+        val (agent2, _) = createConnectedNetworkAgent()
+        // The callback should not see anything yet. With int ranking, agent1 was upgraded
+        // to a stronger score beforehand. With policy ranking, agent1 is preferred by
+        // virtue of already satisfying the request.
+        callback.assertNoCallback(NO_CALLBACK_TIMEOUT)
+        // Downgrade agent1 (worse score, exiting) and expect the callback to switch to agent2.
+        agent1.sendNetworkScore(NetworkScore.Builder()
+                .setLegacyInt(WORSE_NETWORK_SCORE)
+                .setExiting(true)
+                .build())
+        callback.expectCallback<Available>(agent2.network!!)
 
         // tearDown() will unregister the requests and agents
     }
@@ -658,7 +640,7 @@
         callback.expectAvailableThenValidatedCallbacks(agent.network!!)
 
         // Check that the default network's transport is propagated to the VPN.
-        var vpnNc = mCM.getNetworkCapabilities(agent.network)
+        var vpnNc = mCM.getNetworkCapabilities(agent.network!!)
         assertNotNull(vpnNc)
         assertEquals(VpnManager.TYPE_VPN_SERVICE,
                 (vpnNc.transportInfo as VpnTransportInfo).type)
@@ -690,7 +672,7 @@
         // This is not very accurate because the test does not control the capabilities of the
         // underlying networks, and because not congested, not roaming, and not suspended are the
         // default anyway. It's still useful as an extra check though.
-        vpnNc = mCM.getNetworkCapabilities(agent.network)
+        vpnNc = mCM.getNetworkCapabilities(agent.network!!)
         for (cap in listOf(NET_CAPABILITY_NOT_CONGESTED,
                 NET_CAPABILITY_NOT_ROAMING,
                 NET_CAPABILITY_NOT_SUSPENDED)) {
@@ -701,7 +683,7 @@
         }
 
         unregister(agent)
-        callback.expectCallback<Lost>(agent.network)
+        callback.expectCallback<Lost>(agent.network!!)
     }
 
     private fun unregister(agent: TestableNetworkAgent) {
@@ -789,43 +771,24 @@
     fun testTemporarilyUnmeteredCapability() {
         // This test will create a networks with/without NET_CAPABILITY_TEMPORARILY_NOT_METERED
         // and check that the callback reflects the capability changes.
-        // First create a request to make sure the network is kept up
-        val request1 = NetworkRequest.Builder()
-                .clearCapabilities()
-                .addTransportType(TRANSPORT_TEST)
-                .build()
-        val callback1 = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS).also {
-            registerNetworkCallback(request1, it)
-        }
-        requestNetwork(request1, callback1)
-
-        // Then file the interesting request
-        val request = NetworkRequest.Builder()
-                .clearCapabilities()
-                .addTransportType(TRANSPORT_TEST)
-                .build()
-        val callback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
-        requestNetwork(request, callback)
 
         // Connect the network
-        createConnectedNetworkAgent().let { (agent, _) ->
-            callback.expectAvailableThenValidatedCallbacks(agent.network)
+        val (agent, callback) = createConnectedNetworkAgent()
 
-            // Send TEMP_NOT_METERED and check that the callback is called appropriately.
-            val nc1 = NetworkCapabilities(agent.nc)
-                    .addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
-            agent.sendNetworkCapabilities(nc1)
-            callback.expectCapabilitiesThat(agent.network) {
-                it.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
-            }
+        // Send TEMP_NOT_METERED and check that the callback is called appropriately.
+        val nc1 = NetworkCapabilities(agent.nc)
+                .addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
+        agent.sendNetworkCapabilities(nc1)
+        callback.expectCapabilitiesThat(agent.network!!) {
+            it.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
+        }
 
-            // Remove TEMP_NOT_METERED and check that the callback is called appropriately.
-            val nc2 = NetworkCapabilities(agent.nc)
-                    .removeCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
-            agent.sendNetworkCapabilities(nc2)
-            callback.expectCapabilitiesThat(agent.network) {
-                !it.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
-            }
+        // Remove TEMP_NOT_METERED and check that the callback is called appropriately.
+        val nc2 = NetworkCapabilities(agent.nc)
+                .removeCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
+        agent.sendNetworkCapabilities(nc2)
+        callback.expectCapabilitiesThat(agent.network!!) {
+            !it.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
         }
 
         // tearDown() will unregister the requests and agents
@@ -838,88 +801,137 @@
         // score wins out for a request that matches them both. And the weaker agent will
         // be disconnected after customized linger duration.
 
-        // Connect the first Network
-        val name1 = UUID.randomUUID().toString()
-        val name2 = UUID.randomUUID().toString()
-        val (agent1, callback) = createConnectedNetworkAgent(name = name1)
-        callback.expectAvailableThenValidatedCallbacks(agent1.network!!)
-        // Downgrade agent1 to a worse score so that there is no ambiguity when
-        // agent2 connects.
-        agent1.sendNetworkScore(NetworkScore.Builder().setLegacyInt(WORSE_NETWORK_SCORE)
+        // Request the first Network with a request that can later move to agentStronger, so
+        // that agentWeaker will linger once agentStronger connects.
+        val specifierWeaker = UUID.randomUUID().toString()
+        val specifierStronger = UUID.randomUUID().toString()
+        val commonCallback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
+        requestNetwork(makeTestNetworkRequest(), commonCallback)
+        val agentWeaker = createNetworkAgent(specifier = specifierWeaker)
+        agentWeaker.register()
+        agentWeaker.markConnected()
+        commonCallback.expectAvailableThenValidatedCallbacks(agentWeaker.network!!)
+        // Downgrade agentWeaker to a worse score so that there is no ambiguity when
+        // agentStronger connects.
+        agentWeaker.sendNetworkScore(NetworkScore.Builder().setLegacyInt(WORSE_NETWORK_SCORE)
                 .setExiting(true).build())
 
         // Verify invalid linger duration cannot be set.
         assertFailsWith<IllegalArgumentException> {
-            agent1.setLingerDuration(Duration.ofMillis(-1))
+            agentWeaker.setLingerDuration(Duration.ofMillis(-1))
         }
-        assertFailsWith<IllegalArgumentException> { agent1.setLingerDuration(Duration.ZERO) }
+        assertFailsWith<IllegalArgumentException> { agentWeaker.setLingerDuration(Duration.ZERO) }
         assertFailsWith<IllegalArgumentException> {
-            agent1.setLingerDuration(Duration.ofMillis(Integer.MIN_VALUE.toLong()))
+            agentWeaker.setLingerDuration(Duration.ofMillis(Integer.MIN_VALUE.toLong()))
         }
         assertFailsWith<IllegalArgumentException> {
-            agent1.setLingerDuration(Duration.ofMillis(Integer.MAX_VALUE.toLong() + 1))
+            agentWeaker.setLingerDuration(Duration.ofMillis(Integer.MAX_VALUE.toLong() + 1))
         }
         assertFailsWith<IllegalArgumentException> {
-            agent1.setLingerDuration(Duration.ofMillis(
+            agentWeaker.setLingerDuration(Duration.ofMillis(
                     NetworkAgent.MIN_LINGER_TIMER_MS.toLong() - 1))
         }
         // Verify valid linger timer can be set, but it should not take effect since the network
         // is still needed.
-        agent1.setLingerDuration(Duration.ofMillis(Integer.MAX_VALUE.toLong()))
-        callback.assertNoCallback(NO_CALLBACK_TIMEOUT)
+        agentWeaker.setLingerDuration(Duration.ofMillis(Integer.MAX_VALUE.toLong()))
+        commonCallback.assertNoCallback(NO_CALLBACK_TIMEOUT)
         // Set to the value we want to verify the functionality.
-        agent1.setLingerDuration(Duration.ofMillis(NetworkAgent.MIN_LINGER_TIMER_MS.toLong()))
-        // Make a listener which can observe agent1 lost later.
+        agentWeaker.setLingerDuration(Duration.ofMillis(NetworkAgent.MIN_LINGER_TIMER_MS.toLong()))
+        // Make a listener which can observe agentWeaker lost later.
         val callbackWeaker = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
         registerNetworkCallback(NetworkRequest.Builder()
                 .clearCapabilities()
                 .addTransportType(TRANSPORT_TEST)
-                .setNetworkSpecifier(CompatUtil.makeEthernetNetworkSpecifier(name1))
+                .setNetworkSpecifier(CompatUtil.makeEthernetNetworkSpecifier(specifierWeaker))
                 .build(), callbackWeaker)
-        callbackWeaker.expectAvailableCallbacks(agent1.network!!)
+        callbackWeaker.expectAvailableCallbacks(agentWeaker.network!!)
 
-        // Connect the second agent with a score better than agent1. Verify the callback for
-        // agent1 sees the linger expiry while the callback for both sees the winner.
+        // Connect the agentStronger with a score better than agentWeaker. Verify the callback for
+        // agentWeaker sees the linger expiry while the callback for both sees the winner.
         // Record linger start timestamp prior to send score to prevent possible race, the actual
         // timestamp should be slightly late than this since the service handles update
         // network score asynchronously.
         val lingerStart = SystemClock.elapsedRealtime()
-        val agent2 = createNetworkAgent(name = name2)
-        agent2.register()
-        agent2.markConnected()
-        callback.expectAvailableCallbacks(agent2.network!!)
-        callbackWeaker.expectCallback<Losing>(agent1.network!!)
+        val agentStronger = createNetworkAgent(specifier = specifierStronger)
+        agentStronger.register()
+        agentStronger.markConnected()
+        commonCallback.expectAvailableCallbacks(agentStronger.network!!)
+        callbackWeaker.expectCallback<Losing>(agentWeaker.network!!)
         val expectedRemainingLingerDuration = lingerStart +
                 NetworkAgent.MIN_LINGER_TIMER_MS.toLong() - SystemClock.elapsedRealtime()
         // If the available callback is too late. The remaining duration will be reduced.
         assertTrue(expectedRemainingLingerDuration > 0,
                 "expected remaining linger duration is $expectedRemainingLingerDuration")
         callbackWeaker.assertNoCallback(expectedRemainingLingerDuration)
-        callbackWeaker.expectCallback<Lost>(agent1.network!!)
+        callbackWeaker.expectCallback<Lost>(agentWeaker.network!!)
     }
 
     @Test
     @IgnoreUpTo(Build.VERSION_CODES.R)
     fun testSetSubscriberId() {
-        val name = "TEST-AGENT"
         val imsi = UUID.randomUUID().toString()
         val config = NetworkAgentConfig.Builder().setSubscriberId(imsi).build()
 
-        val request: NetworkRequest = NetworkRequest.Builder()
-                .clearCapabilities()
-                .addTransportType(TRANSPORT_TEST)
-                .setNetworkSpecifier(CompatUtil.makeEthernetNetworkSpecifier(name))
-                .build()
-        val callback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
-        requestNetwork(request, callback)
-
-        val agent = createNetworkAgent(name = name, initialConfig = config)
-        agent.register()
-        agent.markConnected()
-        callback.expectAvailableThenValidatedCallbacks(agent.network!!)
+        val (agent, _) = createConnectedNetworkAgent(initialConfig = config)
         val snapshots = runWithShellPermissionIdentity(ThrowingSupplier {
                 mCM!!.allNetworkStateSnapshots }, NETWORK_SETTINGS)
         val testNetworkSnapshot = snapshots.findLast { it.network == agent.network }
         assertEquals(imsi, testNetworkSnapshot!!.subscriberId)
     }
+
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    // TODO: Refactor helper functions to util class and move this test case to
+    //  {@link android.net.cts.ConnectivityManagerTest}.
+    fun testRegisterBestMatchingNetworkCallback() {
+        // Register a best-matching network callback with an additional condition that is
+        // exercised later. This assumes the test network agent has NOT_VCN_MANAGED and does
+        // not have NET_CAPABILITY_TEMPORARILY_NOT_METERED.
+        val bestMatchingCb = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
+        registerBestMatchingNetworkCallback(NetworkRequest.Builder()
+                .clearCapabilities()
+                .addTransportType(TRANSPORT_TEST)
+                .addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
+                .build(), bestMatchingCb, mHandlerThread.threadHandler)
+
+        val (agent1, _) = createConnectedNetworkAgent(specifier = "AGENT-1")
+        bestMatchingCb.expectAvailableThenValidatedCallbacks(agent1.network!!)
+        // Make agent1 worse so that when agent2 shows up, the callback switches to agent2.
+        agent1.sendNetworkScore(NetworkScore.Builder().setExiting(true).build())
+        bestMatchingCb.assertNoCallback(NO_CALLBACK_TIMEOUT)
+
+        val (agent2, _) = createConnectedNetworkAgent(specifier = "AGENT-2")
+        bestMatchingCb.expectAvailableDoubleValidatedCallbacks(agent2.network!!)
+
+        // Change agent1's capabilities (on a copy, so agent1.nc itself is not mutated). Since
+        // the callback only tracks the best network, verify it receives nothing for agent1.
+        val ncAgent1 = NetworkCapabilities(agent1.nc)
+        ncAgent1.addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
+        agent1.sendNetworkCapabilities(ncAgent1)
+        bestMatchingCb.assertNoCallback(NO_CALLBACK_TIMEOUT)
+
+        // Make agent1 the best network again, verify the callback now tracks agent1.
+        agent1.sendNetworkScore(NetworkScore.Builder()
+                .setExiting(false).setTransportPrimary(true).build())
+        bestMatchingCb.expectAvailableCallbacks(agent1.network!!)
+
+        // Make agent1 temporarily VCN-managed, so it no longer satisfies the request, then
+        // restore it. Verify the callback switches to agent2 and back to agent1 accordingly.
+        ncAgent1.removeCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
+        agent1.sendNetworkCapabilities(ncAgent1)
+        bestMatchingCb.expectAvailableCallbacks(agent2.network!!)
+        ncAgent1.addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
+        agent1.sendNetworkCapabilities(ncAgent1)
+        bestMatchingCb.expectAvailableDoubleValidatedCallbacks(agent1.network!!)
+
+        // Verify the callback doesn't care about agent2 disconnecting, only about agent1.
+        agent2.unregister()
+        agentsToCleanUp.remove(agent2)
+        bestMatchingCb.assertNoCallback(NO_CALLBACK_TIMEOUT)
+        agent1.unregister()
+        agentsToCleanUp.remove(agent1)
+        bestMatchingCb.expectCallback<Lost>(agent1.network!!)
+
+        // tearDown() will unregister the requests and agents
+    }
 }
diff --git a/tests/unit/Android.bp b/tests/unit/Android.bp
index 6c4bb90..5eb43f3 100644
--- a/tests/unit/Android.bp
+++ b/tests/unit/Android.bp
@@ -56,7 +56,6 @@
         "java/**/*.kt",
     ],
     test_suites: ["device-tests"],
-    certificate: "platform",
     jarjar_rules: "jarjar-rules.txt",
     static_libs: [
         "androidx.test.rules",
diff --git a/tests/unit/java/android/net/ConnectivityManagerTest.java b/tests/unit/java/android/net/ConnectivityManagerTest.java
index 591e0cc..07f22a2 100644
--- a/tests/unit/java/android/net/ConnectivityManagerTest.java
+++ b/tests/unit/java/android/net/ConnectivityManagerTest.java
@@ -44,6 +44,7 @@
 import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.ArgumentMatchers.nullable;
+import static org.mockito.Mockito.CALLS_REAL_METHODS;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.anyInt;
 import static org.mockito.Mockito.mock;
@@ -215,7 +216,8 @@
     public void testCallbackRelease() throws Exception {
         ConnectivityManager manager = new ConnectivityManager(mCtx, mService);
         NetworkRequest request = makeRequest(1);
-        NetworkCallback callback = mock(ConnectivityManager.NetworkCallback.class);
+        NetworkCallback callback = mock(ConnectivityManager.NetworkCallback.class,
+                CALLS_REAL_METHODS);
         Handler handler = new Handler(Looper.getMainLooper());
         ArgumentCaptor<Messenger> captor = ArgumentCaptor.forClass(Messenger.class);
 
@@ -243,7 +245,8 @@
         ConnectivityManager manager = new ConnectivityManager(mCtx, mService);
         NetworkRequest req1 = makeRequest(1);
         NetworkRequest req2 = makeRequest(2);
-        NetworkCallback callback = mock(ConnectivityManager.NetworkCallback.class);
+        NetworkCallback callback = mock(ConnectivityManager.NetworkCallback.class,
+                CALLS_REAL_METHODS);
         Handler handler = new Handler(Looper.getMainLooper());
         ArgumentCaptor<Messenger> captor = ArgumentCaptor.forClass(Messenger.class);
 
diff --git a/tests/unit/java/android/net/nsd/NsdManagerTest.java b/tests/unit/java/android/net/nsd/NsdManagerTest.java
index b0a9b8a..370179c 100644
--- a/tests/unit/java/android/net/nsd/NsdManagerTest.java
+++ b/tests/unit/java/android/net/nsd/NsdManagerTest.java
@@ -20,12 +20,12 @@
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.fail;
 import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
 
 import android.content.Context;
 import android.os.Handler;
@@ -66,7 +66,7 @@
         MockitoAnnotations.initMocks(this);
 
         mServiceHandler = spy(MockServiceHandler.create(mContext));
-        when(mService.getMessenger()).thenReturn(new Messenger(mServiceHandler));
+        doReturn(new Messenger(mServiceHandler)).when(mService).getMessenger();
 
         mManager = makeManager();
     }
diff --git a/tests/unit/java/android/net/util/MultinetworkPolicyTrackerTest.kt b/tests/unit/java/android/net/util/MultinetworkPolicyTrackerTest.kt
index 25aa626..78c8fa4 100644
--- a/tests/unit/java/android/net/util/MultinetworkPolicyTrackerTest.kt
+++ b/tests/unit/java/android/net/util/MultinetworkPolicyTrackerTest.kt
@@ -45,6 +45,7 @@
 import org.mockito.ArgumentMatchers.argThat
 import org.mockito.ArgumentMatchers.eq
 import org.mockito.Mockito.any
+import org.mockito.Mockito.doCallRealMethod
 import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.mock
 import org.mockito.Mockito.times
@@ -74,6 +75,10 @@
         doReturn(Context.TELEPHONY_SERVICE).`when`(it)
                 .getSystemServiceName(TelephonyManager::class.java)
         doReturn(telephonyManager).`when`(it).getSystemService(Context.TELEPHONY_SERVICE)
+        if (it.getSystemService(TelephonyManager::class.java) == null) {
+            // Test is using mockito-extended
+            doCallRealMethod().`when`(it).getSystemService(TelephonyManager::class.java)
+        }
         doReturn(subscriptionManager).`when`(it)
                 .getSystemService(Context.TELEPHONY_SUBSCRIPTION_SERVICE)
         doReturn(resolver).`when`(it).contentResolver
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index e8f249e..3b030d6 100644
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -18,10 +18,15 @@
 
 import static android.Manifest.permission.CHANGE_NETWORK_STATE;
 import static android.Manifest.permission.CONNECTIVITY_USE_RESTRICTED_NETWORKS;
+import static android.Manifest.permission.CONTROL_OEM_PAID_NETWORK_PREFERENCE;
+import static android.Manifest.permission.CREATE_USERS;
 import static android.Manifest.permission.DUMP;
+import static android.Manifest.permission.GET_INTENT_SENDER_INTENT;
 import static android.Manifest.permission.LOCAL_MAC_ADDRESS;
 import static android.Manifest.permission.NETWORK_FACTORY;
 import static android.Manifest.permission.NETWORK_SETTINGS;
+import static android.Manifest.permission.NETWORK_STACK;
+import static android.Manifest.permission.PACKET_KEEPALIVE_OFFLOAD;
 import static android.app.PendingIntent.FLAG_IMMUTABLE;
 import static android.content.Intent.ACTION_PACKAGE_ADDED;
 import static android.content.Intent.ACTION_PACKAGE_REMOVED;
@@ -134,6 +139,7 @@
 import static com.android.testutils.MiscAsserts.assertRunsInAtMost;
 import static com.android.testutils.MiscAsserts.assertSameElements;
 import static com.android.testutils.MiscAsserts.assertThrows;
+import static com.android.testutils.TestPermissionUtil.runAsShell;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -259,6 +265,7 @@
 import android.net.shared.PrivateDnsConfig;
 import android.net.util.MultinetworkPolicyTracker;
 import android.os.BadParcelableException;
+import android.os.BatteryStatsManager;
 import android.os.Binder;
 import android.os.Build;
 import android.os.Bundle;
@@ -297,6 +304,7 @@
 import androidx.test.runner.AndroidJUnit4;
 
 import com.android.connectivity.resources.R;
+import com.android.internal.app.IBatteryStats;
 import com.android.internal.net.VpnConfig;
 import com.android.internal.net.VpnProfile;
 import com.android.internal.util.ArrayUtils;
@@ -305,6 +313,7 @@
 import com.android.internal.util.test.FakeSettingsProvider;
 import com.android.net.module.util.ArrayTrackRecord;
 import com.android.net.module.util.CollectionUtils;
+import com.android.net.module.util.LocationPermissionChecker;
 import com.android.server.ConnectivityService.ConnectivityDiagnosticsCallbackInfo;
 import com.android.server.ConnectivityService.NetworkRequestInfo;
 import com.android.server.connectivity.MockableSystemProperties;
@@ -384,7 +393,7 @@
 public class ConnectivityServiceTest {
     private static final String TAG = "ConnectivityServiceTest";
 
-    private static final int TIMEOUT_MS = 500;
+    private static final int TIMEOUT_MS = 2_000;
     // Broadcasts can take a long time to be delivered. The test will not wait for that long unless
     // there is a failure, so use a long timeout.
     private static final int BROADCAST_TIMEOUT_MS = 30_000;
@@ -488,6 +497,11 @@
     @Mock Resources mResources;
     @Mock ProxyTracker mProxyTracker;
 
+    // BatteryStatsManager is final and cannot be mocked with regular mockito, so just mock the
+    // underlying binder calls.
+    final BatteryStatsManager mBatteryStatsManager =
+            new BatteryStatsManager(mock(IBatteryStats.class));
+
     private ArgumentCaptor<ResolverParamsParcel> mResolverParamsParcelCaptor =
             ArgumentCaptor.forClass(ResolverParamsParcel.class);
 
@@ -579,6 +593,7 @@
             if (Context.NETWORK_POLICY_SERVICE.equals(name)) return mNetworkPolicyManager;
             if (Context.SYSTEM_CONFIG_SERVICE.equals(name)) return mSystemConfigManager;
             if (Context.NETWORK_STATS_SERVICE.equals(name)) return mStatsManager;
+            if (Context.BATTERY_STATS_SERVICE.equals(name)) return mBatteryStatsManager;
             return super.getSystemService(name);
         }
 
@@ -659,6 +674,15 @@
         public void setPermission(String permission, Integer granted) {
             mMockedPermissions.put(permission, granted);
         }
+
+        @Override
+        public Intent registerReceiverForAllUsers(@Nullable BroadcastReceiver receiver,
+                @NonNull IntentFilter filter, @Nullable String broadcastPermission,
+                @Nullable Handler scheduler) {
+            // No-op stub: the receiver is dropped rather than registered. TODO: ensure
+            // MultinetworkPolicyTracker's BroadcastReceiver is covered by a test.
+            return null;
+        }
     }
 
     private void waitForIdle() {
@@ -1208,7 +1232,24 @@
                             return mDeviceIdleInternal;
                         }
                     },
-                    mNetworkManagementService, mMockNetd, userId, mVpnProfileStore);
+                    mNetworkManagementService, mMockNetd, userId, mVpnProfileStore,
+                    new SystemServices(mServiceContext) {
+                        @Override
+                        public String settingsSecureGetStringForUser(String key, int userId) {
+                            switch (key) {
+                                // Settings keys not marked as @Readable are not readable from
+                                // non-privileged apps, unless marked as testOnly=true
+                                // (atest refuses to install testOnly=true apps), even if mocked
+                                // in the content provider, because
+                                // Settings.Secure.NameValueCache#getStringForUser checks the key
+                                // before querying the mock settings provider.
+                                case Settings.Secure.ALWAYS_ON_VPN_APP:
+                                    return null;
+                                default:
+                                    return super.settingsSecureGetStringForUser(key, userId);
+                            }
+                        }
+                    }, new Ikev2SessionCreator());
         }
 
         public void setUids(Set<UidRange> uids) {
@@ -1592,6 +1633,11 @@
         mServiceContext = new MockContext(InstrumentationRegistry.getContext(),
                 new FakeSettingsProvider());
         mServiceContext.setUseRegisteredHandlers(true);
+        mServiceContext.setPermission(NETWORK_FACTORY, PERMISSION_GRANTED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_GRANTED);
+        mServiceContext.setPermission(CONTROL_OEM_PAID_NETWORK_PREFERENCE, PERMISSION_GRANTED);
+        mServiceContext.setPermission(PACKET_KEEPALIVE_OFFLOAD, PERMISSION_GRANTED);
+        mServiceContext.setPermission(CONNECTIVITY_USE_RESTRICTED_NETWORKS, PERMISSION_GRANTED);
 
         mAlarmManagerThread = new HandlerThread("TestAlarmManager");
         mAlarmManagerThread.start();
@@ -1651,6 +1697,13 @@
             return mPolicyTracker;
         }).when(deps).makeMultinetworkPolicyTracker(any(), any(), any());
         doReturn(true).when(deps).getCellular464XlatEnabled();
+        doAnswer(inv ->
+            new LocationPermissionChecker(inv.getArgument(0)) {
+                @Override
+                protected int getCurrentUser() {
+                    return runAsShell(CREATE_USERS, () -> super.getCurrentUser());
+                }
+            }).when(deps).makeLocationPermissionChecker(any());
 
         doReturn(60000).when(mResources).getInteger(R.integer.config_networkTransitionTimeout);
         doReturn("").when(mResources).getString(R.string.config_networkCaptivePortalServerUrl);
@@ -1680,6 +1733,12 @@
         doReturn(mResources).when(mockResContext).getResources();
         ConnectivityResources.setResourcesContextForTest(mockResContext);
 
+        doAnswer(inv -> {
+            final PendingIntent a = inv.getArgument(0);
+            final PendingIntent b = inv.getArgument(1);
+            return runAsShell(GET_INTENT_SENDER_INTENT, () -> a.intentFilterEquals(b));
+        }).when(deps).intentFilterEquals(any(), any());
+
         return deps;
     }
 
@@ -9239,8 +9298,7 @@
         mServiceContext.setPermission(NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK,
                 PERMISSION_DENIED);
         mServiceContext.setPermission(NETWORK_SETTINGS, PERMISSION_DENIED);
-        mServiceContext.setPermission(Manifest.permission.NETWORK_STACK,
-                PERMISSION_DENIED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_DENIED);
         mServiceContext.setPermission(Manifest.permission.NETWORK_SETUP_WIZARD,
                 PERMISSION_DENIED);
     }
@@ -9681,7 +9739,7 @@
         setupConnectionOwnerUid(vpnOwnerUid, vpnType);
 
         // Test as VPN app
-        mServiceContext.setPermission(android.Manifest.permission.NETWORK_STACK, PERMISSION_DENIED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_DENIED);
         mServiceContext.setPermission(
                 NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK, PERMISSION_DENIED);
     }
@@ -9721,8 +9779,7 @@
     public void testGetConnectionOwnerUidVpnServiceNetworkStackDoesNotThrow() throws Exception {
         final int myUid = Process.myUid();
         setupConnectionOwnerUid(myUid, VpnManager.TYPE_VPN_SERVICE);
-        mServiceContext.setPermission(
-                android.Manifest.permission.NETWORK_STACK, PERMISSION_GRANTED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_GRANTED);
 
         assertEquals(42, mService.getConnectionOwnerUid(getTestConnectionInfo()));
     }
@@ -9890,8 +9947,7 @@
     public void testCheckConnectivityDiagnosticsPermissionsNetworkStack() throws Exception {
         final NetworkAgentInfo naiWithoutUid = fakeMobileNai(new NetworkCapabilities());
 
-        mServiceContext.setPermission(
-                android.Manifest.permission.NETWORK_STACK, PERMISSION_GRANTED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_GRANTED);
         assertTrue(
                 "NetworkStack permission not applied",
                 mService.checkConnectivityDiagnosticsPermissions(
@@ -9907,7 +9963,7 @@
         nc.setAdministratorUids(new int[] {wrongUid});
         final NetworkAgentInfo naiWithUid = fakeWifiNai(nc);
 
-        mServiceContext.setPermission(android.Manifest.permission.NETWORK_STACK, PERMISSION_DENIED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_DENIED);
 
         assertFalse(
                 "Mismatched uid/package name should not pass the location permission check",
@@ -9917,7 +9973,7 @@
 
     private void verifyConnectivityDiagnosticsPermissionsWithNetworkAgentInfo(
             NetworkAgentInfo info, boolean expectPermission) {
-        mServiceContext.setPermission(android.Manifest.permission.NETWORK_STACK, PERMISSION_DENIED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_DENIED);
 
         assertEquals(
                 "Unexpected ConnDiags permission",
@@ -9985,7 +10041,7 @@
 
         setupLocationPermissions(Build.VERSION_CODES.Q, true, AppOpsManager.OPSTR_FINE_LOCATION,
                 Manifest.permission.ACCESS_FINE_LOCATION);
-        mServiceContext.setPermission(android.Manifest.permission.NETWORK_STACK, PERMISSION_DENIED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_DENIED);
 
         assertTrue(
                 "NetworkCapabilities administrator uid permission not applied",
@@ -10002,7 +10058,7 @@
 
         setupLocationPermissions(Build.VERSION_CODES.Q, true, AppOpsManager.OPSTR_FINE_LOCATION,
                 Manifest.permission.ACCESS_FINE_LOCATION);
-        mServiceContext.setPermission(android.Manifest.permission.NETWORK_STACK, PERMISSION_DENIED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_DENIED);
 
         // Use wrong pid and uid
         assertFalse(
@@ -10028,8 +10084,7 @@
         final NetworkRequest request = new NetworkRequest.Builder().build();
         when(mConnectivityDiagnosticsCallback.asBinder()).thenReturn(mIBinder);
 
-        mServiceContext.setPermission(
-                android.Manifest.permission.NETWORK_STACK, PERMISSION_GRANTED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_GRANTED);
 
         mService.registerConnectivityDiagnosticsCallback(
                 mConnectivityDiagnosticsCallback, request, mContext.getPackageName());
@@ -10048,8 +10103,7 @@
         final NetworkRequest request = new NetworkRequest.Builder().build();
         when(mConnectivityDiagnosticsCallback.asBinder()).thenReturn(mIBinder);
 
-        mServiceContext.setPermission(
-                android.Manifest.permission.NETWORK_STACK, PERMISSION_GRANTED);
+        mServiceContext.setPermission(NETWORK_STACK, PERMISSION_GRANTED);
 
         mService.registerConnectivityDiagnosticsCallback(
                 mConnectivityDiagnosticsCallback, request, mContext.getPackageName());
diff --git a/tests/unit/java/com/android/server/connectivity/MultipathPolicyTrackerTest.java b/tests/unit/java/com/android/server/connectivity/MultipathPolicyTrackerTest.java
index 38f6d7f..4c80f6a 100644
--- a/tests/unit/java/com/android/server/connectivity/MultipathPolicyTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/MultipathPolicyTrackerTest.java
@@ -34,6 +34,7 @@
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -114,8 +115,12 @@
     private boolean mRecurrenceRuleClockMocked;
 
     private <T> void mockService(String serviceName, Class<T> serviceClass, T service) {
-        when(mContext.getSystemServiceName(serviceClass)).thenReturn(serviceName);
-        when(mContext.getSystemService(serviceName)).thenReturn(service);
+        doReturn(serviceName).when(mContext).getSystemServiceName(serviceClass);
+        doReturn(service).when(mContext).getSystemService(serviceName);
+        if (mContext.getSystemService(serviceClass) == null) {
+            // Null means mockito-extended stubbed the Class overload too; call the real method
+            doCallRealMethod().when(mContext).getSystemService(serviceClass);
+        }
     }
 
     @Before
diff --git a/tests/unit/java/com/android/server/connectivity/PermissionMonitorTest.java b/tests/unit/java/com/android/server/connectivity/PermissionMonitorTest.java
index e98f5db..8b45755 100644
--- a/tests/unit/java/com/android/server/connectivity/PermissionMonitorTest.java
+++ b/tests/unit/java/com/android/server/connectivity/PermissionMonitorTest.java
@@ -51,6 +51,7 @@
 import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.reset;
@@ -139,6 +140,10 @@
                 .thenReturn(Context.SYSTEM_CONFIG_SERVICE);
         when(mContext.getSystemService(Context.SYSTEM_CONFIG_SERVICE))
                 .thenReturn(mSystemConfigManager);
+        if (mContext.getSystemService(SystemConfigManager.class) == null) {
+            // Test is using mockito-extended
+            doCallRealMethod().when(mContext).getSystemService(SystemConfigManager.class);
+        }
         when(mSystemConfigManager.getSystemPermissionUids(anyString())).thenReturn(new int[0]);
         final Context asUserCtx = mock(Context.class, AdditionalAnswers.delegatesTo(mContext));
         doReturn(UserHandle.ALL).when(asUserCtx).getUser();
diff --git a/tests/unit/java/com/android/server/connectivity/VpnTest.java b/tests/unit/java/com/android/server/connectivity/VpnTest.java
index b725b82..6ff47ae 100644
--- a/tests/unit/java/com/android/server/connectivity/VpnTest.java
+++ b/tests/unit/java/com/android/server/connectivity/VpnTest.java
@@ -39,6 +39,7 @@
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.inOrder;
@@ -219,19 +220,11 @@
 
         when(mContext.getPackageName()).thenReturn(TEST_VPN_PKG);
         when(mContext.getOpPackageName()).thenReturn(TEST_VPN_PKG);
-        when(mContext.getSystemServiceName(UserManager.class))
-                .thenReturn(Context.USER_SERVICE);
-        when(mContext.getSystemService(eq(Context.USER_SERVICE))).thenReturn(mUserManager);
-        when(mContext.getSystemService(eq(Context.APP_OPS_SERVICE))).thenReturn(mAppOps);
-        when(mContext.getSystemServiceName(NotificationManager.class))
-                .thenReturn(Context.NOTIFICATION_SERVICE);
-        when(mContext.getSystemService(eq(Context.NOTIFICATION_SERVICE)))
-                .thenReturn(mNotificationManager);
-        when(mContext.getSystemService(eq(Context.CONNECTIVITY_SERVICE)))
-                .thenReturn(mConnectivityManager);
-        when(mContext.getSystemServiceName(eq(ConnectivityManager.class)))
-                .thenReturn(Context.CONNECTIVITY_SERVICE);
-        when(mContext.getSystemService(eq(Context.IPSEC_SERVICE))).thenReturn(mIpSecManager);
+        mockService(UserManager.class, Context.USER_SERVICE, mUserManager);
+        mockService(AppOpsManager.class, Context.APP_OPS_SERVICE, mAppOps);
+        mockService(NotificationManager.class, Context.NOTIFICATION_SERVICE, mNotificationManager);
+        mockService(ConnectivityManager.class, Context.CONNECTIVITY_SERVICE, mConnectivityManager);
+        mockService(IpSecManager.class, Context.IPSEC_SERVICE, mIpSecManager);
         when(mContext.getString(R.string.config_customVpnAlwaysOnDisconnectedDialogComponent))
                 .thenReturn(Resources.getSystem().getString(
                         R.string.config_customVpnAlwaysOnDisconnectedDialogComponent));
@@ -259,6 +252,16 @@
                 .thenReturn(tunnelResp);
     }
 
+    private <T> void mockService(Class<T> clazz, String name, T service) {
+        doReturn(name).when(mContext).getSystemServiceName(clazz);
+        doReturn(service).when(mContext).getSystemService(name);
+        final Object byClass = mContext.getSystemService(clazz);
+        if (byClass.getClass() == Object.class) {
+            // mockito-extended + RETURNS_DEEP_STUBS mocked this final method; use the real one.
+            doCallRealMethod().when(mContext).getSystemService(clazz);
+        }
+    }
+
     private Set<Range<Integer>> rangeSet(Range<Integer> ... ranges) {
         final Set<Range<Integer>> range = new ArraySet<>();
         for (Range<Integer> r : ranges) range.add(r);
diff --git a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
index 0ba5f7d..3dd6598 100644
--- a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
+++ b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
@@ -16,12 +16,17 @@
 
 package com.android.server.net;
 
+import static android.Manifest.permission.READ_NETWORK_USAGE_HISTORY;
+import static android.Manifest.permission.UPDATE_DEVICE_STATS;
 import static android.content.Intent.ACTION_UID_REMOVED;
 import static android.content.Intent.EXTRA_UID;
+import static android.content.pm.PackageManager.PERMISSION_DENIED;
+import static android.content.pm.PackageManager.PERMISSION_GRANTED;
 import static android.net.ConnectivityManager.TYPE_MOBILE;
 import static android.net.ConnectivityManager.TYPE_WIFI;
 import static android.net.NetworkIdentity.OEM_PAID;
 import static android.net.NetworkIdentity.OEM_PRIVATE;
+import static android.net.NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK;
 import static android.net.NetworkStats.DEFAULT_NETWORK_ALL;
 import static android.net.NetworkStats.DEFAULT_NETWORK_NO;
 import static android.net.NetworkStats.DEFAULT_NETWORK_YES;
@@ -106,6 +111,7 @@
 import android.provider.Settings;
 import android.telephony.TelephonyManager;
 
+import androidx.annotation.Nullable;
 import androidx.test.InstrumentationRegistry;
 import androidx.test.filters.SmallTest;
 import androidx.test.runner.AndroidJUnit4;
@@ -199,6 +205,34 @@
             if (Context.TELEPHONY_SERVICE.equals(name)) return mTelephonyManager;
             return mBaseContext.getSystemService(name);
         }
+
+        /**
+         * Throws a {@link SecurityException} unless {@link #checkCallingOrSelfPermission} grants
+         * {@code permission}. A non-null {@code message} prefixes the exception text.
+         */
+        @Override
+        public void enforceCallingOrSelfPermission(String permission, @Nullable String message) {
+            if (checkCallingOrSelfPermission(permission) != PERMISSION_GRANTED) {
+                final String prefix = (message == null) ? "" : message + ": ";
+                throw new SecurityException(
+                        prefix + "Test does not have mocked permission " + permission);
+            }
+        }
+
+        /**
+         * Grants the permissions this test mocks as held and denies every other permission.
+         */
+        @Override
+        public int checkCallingOrSelfPermission(String permission) {
+            switch (permission) {
+                case PERMISSION_MAINLINE_NETWORK_STACK:
+                case READ_NETWORK_USAGE_HISTORY:
+                case UPDATE_DEVICE_STATS:
+                    return PERMISSION_GRANTED;
+                default:
+                    return PERMISSION_DENIED;
+            }
+        }
     }
 
     private final Clock mClock = new SimpleClock(ZoneOffset.UTC) {