Allow ConnectivityServiceTest to change the calling UID.

Allow ConnectivityServiceTest to change the UID by replacing
static calls to Binder.getCallingUid() with a method that can
be mocked.

Add registerNetworkCallbackAsUid as an initial way to exercise
this, and add some test coverage to the always-on lockdown test
to confirm that things are working as expected.

Bug: 173331190
Test: new unit tests
Change-Id: Ie0b32460e20e5906a0f479191e11a062f21cc608
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index 2c765bd..184f1bf 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -876,6 +876,10 @@
      */
     @VisibleForTesting
     public static class Dependencies {
+        public int getCallingUid() {
+            return Binder.getCallingUid();
+        }
+
         /**
          * Get system properties to use in ConnectivityService.
          */
@@ -1408,7 +1412,7 @@
     @Override
     public NetworkInfo getActiveNetworkInfo() {
         enforceAccessPermission();
-        final int uid = Binder.getCallingUid();
+        final int uid = mDeps.getCallingUid();
         final NetworkState state = getUnfilteredActiveNetworkState(uid);
         filterNetworkStateForUid(state, uid, false);
         maybeLogBlockedNetworkInfo(state.networkInfo, uid);
@@ -1418,7 +1422,7 @@
     @Override
     public Network getActiveNetwork() {
         enforceAccessPermission();
-        return getActiveNetworkForUidInternal(Binder.getCallingUid(), false);
+        return getActiveNetworkForUidInternal(mDeps.getCallingUid(), false);
     }
 
     @Override
@@ -1458,7 +1462,7 @@
     // Public because it's used by mLockdownTracker.
     public NetworkInfo getActiveNetworkInfoUnfiltered() {
         enforceAccessPermission();
-        final int uid = Binder.getCallingUid();
+        final int uid = mDeps.getCallingUid();
         NetworkState state = getUnfilteredActiveNetworkState(uid);
         return state.networkInfo;
     }
@@ -1474,7 +1478,7 @@
     @Override
     public NetworkInfo getNetworkInfo(int networkType) {
         enforceAccessPermission();
-        final int uid = Binder.getCallingUid();
+        final int uid = mDeps.getCallingUid();
         if (getVpnUnderlyingNetworks(uid) != null) {
             // A VPN is active, so we may need to return one of its underlying networks. This
             // information is not available in LegacyTypeTracker, so we have to get it from
@@ -1519,7 +1523,7 @@
     @Override
     public Network getNetworkForType(int networkType) {
         enforceAccessPermission();
-        final int uid = Binder.getCallingUid();
+        final int uid = mDeps.getCallingUid();
         NetworkState state = getFilteredNetworkState(networkType, uid);
         if (!isNetworkWithLinkPropertiesBlocked(state.linkProperties, uid, false)) {
             return state.network;
@@ -1566,7 +1570,7 @@
             result.put(
                     nai.network,
                     maybeSanitizeLocationInfoForCaller(
-                            nc, Binder.getCallingUid(), callingPackageName));
+                            nc, mDeps.getCallingUid(), callingPackageName));
         }
 
         synchronized (mVpns) {
@@ -1581,7 +1585,7 @@
                                 result.put(
                                         network,
                                         maybeSanitizeLocationInfoForCaller(
-                                                nc, Binder.getCallingUid(), callingPackageName));
+                                                nc, mDeps.getCallingUid(), callingPackageName));
                             }
                         }
                     }
@@ -1611,7 +1615,7 @@
     @Override
     public LinkProperties getActiveLinkProperties() {
         enforceAccessPermission();
-        final int uid = Binder.getCallingUid();
+        final int uid = mDeps.getCallingUid();
         NetworkState state = getUnfilteredActiveNetworkState(uid);
         if (state.linkProperties == null) return null;
         return linkPropertiesRestrictedForCallerPermissions(state.linkProperties,
@@ -1625,7 +1629,7 @@
         final LinkProperties lp = getLinkProperties(nai);
         if (lp == null) return null;
         return linkPropertiesRestrictedForCallerPermissions(
-                lp, Binder.getCallingPid(), Binder.getCallingUid());
+                lp, Binder.getCallingPid(), mDeps.getCallingUid());
     }
 
     // TODO - this should be ALL networks
@@ -1635,7 +1639,7 @@
         final LinkProperties lp = getLinkProperties(getNetworkAgentInfoForNetwork(network));
         if (lp == null) return null;
         return linkPropertiesRestrictedForCallerPermissions(
-                lp, Binder.getCallingPid(), Binder.getCallingUid());
+                lp, Binder.getCallingPid(), mDeps.getCallingUid());
     }
 
     @Nullable
@@ -1657,17 +1661,17 @@
         synchronized (nai) {
             if (nai.networkCapabilities == null) return null;
             return networkCapabilitiesRestrictedForCallerPermissions(
-                    nai.networkCapabilities, Binder.getCallingPid(), Binder.getCallingUid());
+                    nai.networkCapabilities, Binder.getCallingPid(), mDeps.getCallingUid());
         }
     }
 
     @Override
     public NetworkCapabilities getNetworkCapabilities(Network network, String callingPackageName) {
-        mAppOpsManager.checkPackage(Binder.getCallingUid(), callingPackageName);
+        mAppOpsManager.checkPackage(mDeps.getCallingUid(), callingPackageName);
         enforceAccessPermission();
         return maybeSanitizeLocationInfoForCaller(
                 getNetworkCapabilitiesInternal(network),
-                Binder.getCallingUid(), callingPackageName);
+                mDeps.getCallingUid(), callingPackageName);
     }
 
     @VisibleForTesting
@@ -1755,7 +1759,7 @@
     }
 
     private void restrictBackgroundRequestForCaller(NetworkCapabilities nc) {
-        if (!mPermissionMonitor.hasUseBackgroundNetworksPermission(Binder.getCallingUid())) {
+        if (!mPermissionMonitor.hasUseBackgroundNetworksPermission(mDeps.getCallingUid())) {
             nc.addCapability(NET_CAPABILITY_FOREGROUND);
         }
     }
@@ -1808,7 +1812,7 @@
         // requestRouteToHost. In Q, GnssLocationProvider is changed to not call requestRouteToHost
         // for devices launched with Q and above. However, existing devices upgrading to Q and
         // above must continued to be supported for few more releases.
-        if (isSystem(Binder.getCallingUid()) && SystemProperties.getInt(
+        if (isSystem(mDeps.getCallingUid()) && SystemProperties.getInt(
                 "ro.product.first_api_level", 0) > Build.VERSION_CODES.P) {
             log("This method exists only for app backwards compatibility"
                     + " and must not be called by system services.");
@@ -1874,7 +1878,7 @@
             return false;
         }
 
-        final int uid = Binder.getCallingUid();
+        final int uid = mDeps.getCallingUid();
         final long token = Binder.clearCallingIdentity();
         try {
             LinkProperties lp;
@@ -2294,7 +2298,7 @@
      */
     @Override
     public void systemReady() {
-        if (Binder.getCallingUid() != Process.SYSTEM_UID) {
+        if (mDeps.getCallingUid() != Process.SYSTEM_UID) {
             throw new SecurityException("Calling Uid is not system uid.");
         }
         systemReadyInternal();
@@ -2520,7 +2524,7 @@
         if (context.checkCallingOrSelfPermission(android.Manifest.permission.DUMP)
                 != PackageManager.PERMISSION_GRANTED) {
             pw.println("Permission Denial: can't dump " + tag + " from from pid="
-                    + Binder.getCallingPid() + ", uid=" + Binder.getCallingUid()
+                    + Binder.getCallingPid() + ", uid=" + mDeps.getCallingUid()
                     + " due to missing android.permission.DUMP permission");
             return false;
         } else {
@@ -3900,7 +3904,7 @@
 
             if (request == CaptivePortal.APP_REQUEST_REEVALUATION_REQUIRED) {
                 checkNetworkStackPermission();
-                nm.forceReevaluation(Binder.getCallingUid());
+                nm.forceReevaluation(mDeps.getCallingUid());
             }
         }
 
@@ -4367,7 +4371,7 @@
     public void reportNetworkConnectivity(Network network, boolean hasConnectivity) {
         enforceAccessPermission();
         enforceInternetPermission();
-        final int uid = Binder.getCallingUid();
+        final int uid = mDeps.getCallingUid();
         final int connectivityInfo = encodeBool(hasConnectivity);
 
         // Handle ConnectivityDiagnostics event before attempting to revalidate the network. This
@@ -4437,13 +4441,13 @@
         if (globalProxy != null) return globalProxy;
         if (network == null) {
             // Get the network associated with the calling UID.
-            final Network activeNetwork = getActiveNetworkForUidInternal(Binder.getCallingUid(),
+            final Network activeNetwork = getActiveNetworkForUidInternal(mDeps.getCallingUid(),
                     true);
             if (activeNetwork == null) {
                 return null;
             }
             return getLinkPropertiesProxyInfo(activeNetwork);
-        } else if (mDeps.queryUserAccess(Binder.getCallingUid(), network.getNetId())) {
+        } else if (mDeps.queryUserAccess(mDeps.getCallingUid(), network.getNetId())) {
             // Don't call getLinkProperties() as it requires ACCESS_NETWORK_STATE permission, which
             // caller may not have.
             return getLinkPropertiesProxyInfo(network);
@@ -4612,7 +4616,7 @@
      */
     @Override
     public ParcelFileDescriptor establishVpn(VpnConfig config) {
-        int user = UserHandle.getUserId(Binder.getCallingUid());
+        int user = UserHandle.getUserId(mDeps.getCallingUid());
         synchronized (mVpns) {
             throwIfLockdownEnabled();
             return mVpns.get(user).establish(config);
@@ -4633,7 +4637,7 @@
      */
     @Override
     public boolean provisionVpnProfile(@NonNull VpnProfile profile, @NonNull String packageName) {
-        final int user = UserHandle.getUserId(Binder.getCallingUid());
+        final int user = UserHandle.getUserId(mDeps.getCallingUid());
         synchronized (mVpns) {
             return mVpns.get(user).provisionVpnProfile(packageName, profile, mKeyStore);
         }
@@ -4651,7 +4655,7 @@
      */
     @Override
     public void deleteVpnProfile(@NonNull String packageName) {
-        final int user = UserHandle.getUserId(Binder.getCallingUid());
+        final int user = UserHandle.getUserId(mDeps.getCallingUid());
         synchronized (mVpns) {
             mVpns.get(user).deleteVpnProfile(packageName, mKeyStore);
         }
@@ -4668,7 +4672,7 @@
      */
     @Override
     public void startVpnProfile(@NonNull String packageName) {
-        final int user = UserHandle.getUserId(Binder.getCallingUid());
+        final int user = UserHandle.getUserId(mDeps.getCallingUid());
         synchronized (mVpns) {
             throwIfLockdownEnabled();
             mVpns.get(user).startVpnProfile(packageName, mKeyStore);
@@ -4685,7 +4689,7 @@
      */
     @Override
     public void stopVpnProfile(@NonNull String packageName) {
-        final int user = UserHandle.getUserId(Binder.getCallingUid());
+        final int user = UserHandle.getUserId(mDeps.getCallingUid());
         synchronized (mVpns) {
             mVpns.get(user).stopVpnProfile(packageName);
         }
@@ -4697,7 +4701,7 @@
      */
     @Override
     public void startLegacyVpn(VpnProfile profile) {
-        int user = UserHandle.getUserId(Binder.getCallingUid());
+        int user = UserHandle.getUserId(mDeps.getCallingUid());
         final LinkProperties egress = getActiveLinkProperties();
         if (egress == null) {
             throw new IllegalStateException("Missing active network connection");
@@ -4846,7 +4850,7 @@
 
     @Override
     public boolean updateLockdownVpn() {
-        if (Binder.getCallingUid() != Process.SYSTEM_UID) {
+        if (mDeps.getCallingUid() != Process.SYSTEM_UID) {
             logw("Lockdown VPN only available to AID_SYSTEM");
             return false;
         }
@@ -4868,7 +4872,7 @@
                     setLockdownTracker(null);
                     return true;
                 }
-                int user = UserHandle.getUserId(Binder.getCallingUid());
+                int user = UserHandle.getUserId(mDeps.getCallingUid());
                 Vpn vpn = mVpns.get(user);
                 if (vpn == null) {
                     logw("VPN for user " + user + " not ready yet. Skipping lockdown");
@@ -5433,7 +5437,7 @@
             messenger = null;
             mBinder = null;
             mPid = getCallingPid();
-            mUid = getCallingUid();
+            mUid = mDeps.getCallingUid();
             enforceRequestCountLimit();
         }
 
@@ -5445,7 +5449,7 @@
             ensureAllNetworkRequestsHaveType(mRequests);
             mBinder = binder;
             mPid = getCallingPid();
-            mUid = getCallingUid();
+            mUid = mDeps.getCallingUid();
             mPendingIntent = null;
             enforceRequestCountLimit();
 
@@ -5588,7 +5592,7 @@
     }
 
     private boolean checkUnsupportedStartingFrom(int version, String callingPackageName) {
-        final UserHandle user = UserHandle.getUserHandleForUid(Binder.getCallingUid());
+        final UserHandle user = UserHandle.getUserHandleForUid(mDeps.getCallingUid());
         final PackageManager pm =
                 mContext.createContextAsUser(user, 0 /* flags */).getPackageManager();
         try {
@@ -5608,7 +5612,7 @@
                 throw new SecurityException("Insufficient permissions to specify legacy type");
             }
         }
-        final int callingUid = Binder.getCallingUid();
+        final int callingUid = mDeps.getCallingUid();
         final NetworkRequest.Type type = (networkCapabilities == null)
                 ? NetworkRequest.Type.TRACK_DEFAULT
                 : NetworkRequest.Type.REQUEST;
@@ -5678,7 +5682,7 @@
         if (nai != null) {
             nai.asyncChannel.sendMessage(android.net.NetworkAgent.CMD_REQUEST_BANDWIDTH_UPDATE);
             synchronized (mBandwidthRequests) {
-                final int uid = Binder.getCallingUid();
+                final int uid = mDeps.getCallingUid();
                 Integer uidReqs = mBandwidthRequests.get(uid);
                 if (uidReqs == null) {
                     uidReqs = 0;
@@ -5695,7 +5699,7 @@
     }
 
     private void enforceMeteredApnPolicy(NetworkCapabilities networkCapabilities) {
-        final int uid = Binder.getCallingUid();
+        final int uid = mDeps.getCallingUid();
         if (isSystem(uid)) {
             // Exemption for system uid.
             return;
@@ -5715,7 +5719,7 @@
             PendingIntent operation, @NonNull String callingPackageName,
             @Nullable String callingAttributionTag) {
         Objects.requireNonNull(operation, "PendingIntent cannot be null.");
-        final int callingUid = Binder.getCallingUid();
+        final int callingUid = mDeps.getCallingUid();
         networkCapabilities = new NetworkCapabilities(networkCapabilities);
         enforceNetworkRequestPermissions(networkCapabilities, callingPackageName,
                 callingAttributionTag);
@@ -5774,7 +5778,7 @@
     @Override
     public NetworkRequest listenForNetwork(NetworkCapabilities networkCapabilities,
             Messenger messenger, IBinder binder, @NonNull String callingPackageName) {
-        final int callingUid = Binder.getCallingUid();
+        final int callingUid = mDeps.getCallingUid();
         if (!hasWifiNetworkListenPermission(networkCapabilities)) {
             enforceAccessPermission();
         }
@@ -5804,7 +5808,7 @@
     public void pendingListenForNetwork(NetworkCapabilities networkCapabilities,
             PendingIntent operation, @NonNull String callingPackageName) {
         Objects.requireNonNull(operation, "PendingIntent cannot be null.");
-        final int callingUid = Binder.getCallingUid();
+        final int callingUid = mDeps.getCallingUid();
         if (!hasWifiNetworkListenPermission(networkCapabilities)) {
             enforceAccessPermission();
         }
@@ -5905,7 +5909,7 @@
         } else {
             enforceNetworkFactoryPermission();
         }
-        mHandler.post(() -> handleReleaseNetworkRequest(request, Binder.getCallingUid(), true));
+        mHandler.post(() -> handleReleaseNetworkRequest(request, mDeps.getCallingUid(), true));
     }
 
     // NOTE: Accessed on multiple threads, must be synchronized on itself.
@@ -5999,7 +6003,7 @@
             enforceNetworkFactoryPermission();
         }
 
-        final int uid = Binder.getCallingUid();
+        final int uid = mDeps.getCallingUid();
         final long token = Binder.clearCallingIdentity();
         try {
             return registerNetworkAgentInternal(messenger, networkInfo, linkProperties,
@@ -7653,7 +7657,7 @@
 
     @Override
     public boolean addVpnAddress(String address, int prefixLength) {
-        int user = UserHandle.getUserId(Binder.getCallingUid());
+        int user = UserHandle.getUserId(mDeps.getCallingUid());
         synchronized (mVpns) {
             throwIfLockdownEnabled();
             return mVpns.get(user).addAddress(address, prefixLength);
@@ -7662,7 +7666,7 @@
 
     @Override
     public boolean removeVpnAddress(String address, int prefixLength) {
-        int user = UserHandle.getUserId(Binder.getCallingUid());
+        int user = UserHandle.getUserId(mDeps.getCallingUid());
         synchronized (mVpns) {
             throwIfLockdownEnabled();
             return mVpns.get(user).removeAddress(address, prefixLength);
@@ -7671,7 +7675,7 @@
 
     @Override
     public boolean setUnderlyingNetworksForVpn(Network[] networks) {
-        int user = UserHandle.getUserId(Binder.getCallingUid());
+        int user = UserHandle.getUserId(mDeps.getCallingUid());
         final boolean success;
         synchronized (mVpns) {
             throwIfLockdownEnabled();
@@ -7898,7 +7902,7 @@
 
     @GuardedBy("mVpns")
     private Vpn getVpnIfOwner() {
-        return getVpnIfOwner(Binder.getCallingUid());
+        return getVpnIfOwner(mDeps.getCallingUid());
     }
 
     @GuardedBy("mVpns")
@@ -8376,7 +8380,7 @@
             throw new IllegalArgumentException("ConnectivityManager.TYPE_* are deprecated."
                     + " Please use NetworkCapabilities instead.");
         }
-        final int callingUid = Binder.getCallingUid();
+        final int callingUid = mDeps.getCallingUid();
         mAppOpsManager.checkPackage(callingUid, callingPackageName);
 
         // This NetworkCapabilities is only used for matching to Networks. Clear out its owner uid
@@ -8411,7 +8415,7 @@
                 mConnectivityDiagnosticsHandler.obtainMessage(
                         ConnectivityDiagnosticsHandler
                                 .EVENT_UNREGISTER_CONNECTIVITY_DIAGNOSTICS_CALLBACK,
-                        Binder.getCallingUid(),
+                        mDeps.getCallingUid(),
                         0,
                         callback));
     }
@@ -8427,7 +8431,7 @@
         }
 
         final NetworkAgentInfo nai = getNetworkAgentInfoForNetwork(network);
-        if (nai == null || nai.creatorUid != Binder.getCallingUid()) {
+        if (nai == null || nai.creatorUid != mDeps.getCallingUid()) {
             throw new SecurityException("Data Stall simulation is only possible for network "
                 + "creators");
         }
diff --git a/tests/net/java/com/android/server/ConnectivityServiceTest.java b/tests/net/java/com/android/server/ConnectivityServiceTest.java
index af1f75e..21dbbc6 100644
--- a/tests/net/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/net/java/com/android/server/ConnectivityServiceTest.java
@@ -345,6 +345,7 @@
 
     private MockContext mServiceContext;
     private HandlerThread mCsHandlerThread;
+    private ConnectivityService.Dependencies mDeps;
     private ConnectivityService mService;
     private WrappedConnectivityManager mCm;
     private TestNetworkAgentWrapper mWiFiNetworkAgent;
@@ -1267,6 +1268,17 @@
         fail("ConditionVariable was blocked for more than " + TIMEOUT_MS + "ms");
     }
 
+    private void registerNetworkCallbackAsUid(NetworkRequest request, NetworkCallback callback,
+            int uid) {
+        when(mDeps.getCallingUid()).thenReturn(uid);
+        try {
+            mCm.registerNetworkCallback(request, callback);
+            waitForIdle();
+        } finally {
+            returnRealCallingUid();
+        }
+    }
+
     private static final int VPN_USER = 0;
     private static final int APP1_UID = UserHandle.getUid(VPN_USER, 10100);
     private static final int APP2_UID = UserHandle.getUid(VPN_USER, 10101);
@@ -1309,7 +1321,8 @@
         initAlarmManager(mAlarmManager, mAlarmManagerThread.getThreadHandler());
 
         mCsHandlerThread = new HandlerThread("TestConnectivityService");
-        final ConnectivityService.Dependencies deps = makeDependencies();
+        mDeps = makeDependencies();
+        returnRealCallingUid();
         mService = new ConnectivityService(mServiceContext,
                 mNetworkManagementService,
                 mStatsService,
@@ -1317,9 +1330,9 @@
                 mMockDnsResolver,
                 mock(IpConnectivityLog.class),
                 mMockNetd,
-                deps);
+                mDeps);
         mService.mLingerDelayMs = TEST_LINGER_DELAY_MS;
-        verify(deps).makeMultinetworkPolicyTracker(any(), any(), any());
+        verify(mDeps).makeMultinetworkPolicyTracker(any(), any(), any());
 
         final ArgumentCaptor<INetworkPolicyListener> policyListenerCaptor =
                 ArgumentCaptor.forClass(INetworkPolicyListener.class);
@@ -1339,6 +1352,10 @@
         setPrivateDnsSettings(PRIVATE_DNS_MODE_OFF, "ignored.example.com");
     }
 
+    private void returnRealCallingUid() {
+        doAnswer((invocationOnMock) -> Binder.getCallingUid()).when(mDeps).getCallingUid();
+    }
+
     private ConnectivityService.Dependencies makeDependencies() {
         doReturn(TEST_TCP_INIT_RWND).when(mSystemProperties)
                 .getInt("net.tcp.default_init_rwnd", 0);
@@ -6362,6 +6379,7 @@
         // Despite VPN using WiFi (which is unmetered), VPN itself is marked as always metered.
         assertTrue(mCm.isActiveNetworkMetered());
 
+
         // VPN explicitly declares WiFi as its underlying network.
         mService.setUnderlyingNetworksForVpn(
                 new Network[] { mWiFiNetworkAgent.getNetwork() });
@@ -6511,6 +6529,10 @@
         final TestNetworkCallback defaultCallback = new TestNetworkCallback();
         mCm.registerDefaultNetworkCallback(defaultCallback);
 
+        final TestNetworkCallback vpnUidCallback = new TestNetworkCallback();
+        final NetworkRequest vpnUidRequest = new NetworkRequest.Builder().build();
+        registerNetworkCallbackAsUid(vpnUidRequest, vpnUidCallback, VPN_UID);
+
         final int uid = Process.myUid();
         final int userId = UserHandle.getUserId(uid);
         final ArrayList<String> allowList = new ArrayList<>();
@@ -6526,6 +6548,7 @@
         mWiFiNetworkAgent.connect(false /* validated */);
         callback.expectAvailableCallbacksUnvalidatedAndBlocked(mWiFiNetworkAgent);
         defaultCallback.expectAvailableCallbacksUnvalidatedAndBlocked(mWiFiNetworkAgent);
+        vpnUidCallback.expectAvailableCallbacksUnvalidated(mWiFiNetworkAgent);
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetworkForUid(VPN_UID));
         assertNull(mCm.getActiveNetwork());
         assertActiveNetworkInfo(TYPE_WIFI, DetailedState.BLOCKED);
@@ -6537,6 +6560,7 @@
         // There are no callbacks because they are not implemented yet.
         mService.setAlwaysOnVpnPackage(userId, null, false /* lockdown */, allowList);
         expectNetworkRejectNonSecureVpn(inOrder, false, firstHalf, secondHalf);
+        vpnUidCallback.assertNoCallback();
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetworkForUid(VPN_UID));
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetwork());
         assertActiveNetworkInfo(TYPE_WIFI, DetailedState.CONNECTED);
@@ -6546,7 +6570,9 @@
         // Add our UID to the allowlist and re-enable lockdown, expect network is not blocked.
         allowList.add(TEST_PACKAGE_NAME);
         mService.setAlwaysOnVpnPackage(userId, ALWAYS_ON_PACKAGE, true /* lockdown */, allowList);
+        callback.assertNoCallback();
         defaultCallback.assertNoCallback();
+        vpnUidCallback.assertNoCallback();
 
         // The following requires that the UID of this test package is greater than VPN_UID. This
         // is always true in practice because a plain AOSP build with no apps installed has almost
@@ -6566,6 +6592,7 @@
         mCellNetworkAgent.connect(false /* validated */);
         callback.expectAvailableCallbacksUnvalidated(mCellNetworkAgent);
         defaultCallback.assertNoCallback();
+        vpnUidCallback.expectAvailableCallbacksUnvalidated(mCellNetworkAgent);
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetworkForUid(VPN_UID));
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetwork());
         assertActiveNetworkInfo(TYPE_WIFI, DetailedState.CONNECTED);
@@ -6580,6 +6607,7 @@
         allowList.clear();
         mService.setAlwaysOnVpnPackage(userId, ALWAYS_ON_PACKAGE, true /* lockdown */, allowList);
         expectNetworkRejectNonSecureVpn(inOrder, true, firstHalf, secondHalf);
+        vpnUidCallback.assertNoCallback();
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetworkForUid(VPN_UID));
         assertNull(mCm.getActiveNetwork());
         assertActiveNetworkInfo(TYPE_WIFI, DetailedState.BLOCKED);
@@ -6588,6 +6616,7 @@
 
         // Disable lockdown. Everything is unblocked.
         mService.setAlwaysOnVpnPackage(userId, null, false /* lockdown */, allowList);
+        vpnUidCallback.assertNoCallback();
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetworkForUid(VPN_UID));
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetwork());
         assertActiveNetworkInfo(TYPE_WIFI, DetailedState.CONNECTED);
@@ -6600,6 +6629,7 @@
         inOrder.verify(mMockNetd, never()).networkRejectNonSecureVpn(anyBoolean(), any());
         callback.assertNoCallback();
         defaultCallback.assertNoCallback();
+        vpnUidCallback.assertNoCallback();
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetworkForUid(VPN_UID));
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetwork());
         assertActiveNetworkInfo(TYPE_WIFI, DetailedState.CONNECTED);
@@ -6610,6 +6640,7 @@
         inOrder.verify(mMockNetd, never()).networkRejectNonSecureVpn(anyBoolean(), any());
         callback.assertNoCallback();
         defaultCallback.assertNoCallback();
+        vpnUidCallback.assertNoCallback();
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetworkForUid(VPN_UID));
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetwork());
         assertActiveNetworkInfo(TYPE_WIFI, DetailedState.CONNECTED);
@@ -6618,6 +6649,7 @@
 
         // Enable lockdown and connect a VPN. The VPN is not blocked.
         mService.setAlwaysOnVpnPackage(userId, ALWAYS_ON_PACKAGE, true /* lockdown */, allowList);
+        vpnUidCallback.assertNoCallback();
         assertEquals(mWiFiNetworkAgent.getNetwork(), mCm.getActiveNetworkForUid(VPN_UID));
         assertNull(mCm.getActiveNetwork());
         assertActiveNetworkInfo(TYPE_WIFI, DetailedState.BLOCKED);
@@ -6626,6 +6658,7 @@
 
         mMockVpn.establishForMyUid();
         defaultCallback.expectAvailableThenValidatedCallbacks(mMockVpn);
+        vpnUidCallback.assertNoCallback();  // vpnUidCallback has NOT_VPN capability.
         assertEquals(mMockVpn.getNetwork(), mCm.getActiveNetwork());
         assertEquals(null, mCm.getActiveNetworkForUid(VPN_UID));  // BUG?
         assertActiveNetworkInfo(TYPE_WIFI, DetailedState.CONNECTED);
@@ -6640,6 +6673,7 @@
 
         mCm.unregisterNetworkCallback(callback);
         mCm.unregisterNetworkCallback(defaultCallback);
+        mCm.unregisterNetworkCallback(vpnUidCallback);
     }
 
     @Test