Revert "Revert "Update VPN capabilities when its underlying network set is null.""

This reverts commit c0e7f75e704b284c297a7f3317ce944394b608b0.

Reason for revert: Retargeted for June monthly release

Bug: 119129310

Change-Id: I9d543415c5707859cfa2a14a1a8ce5909aae7d11
Merged-In: Id0abc4d304bb096e92479a118168690ccce634ed
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index c9f9ab6..1c66c5a 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -867,7 +867,8 @@
 
         mPermissionMonitor = new PermissionMonitor(mContext, mNetd);
 
-        //set up the listener for user state for creating user VPNs
+        // Set up the listener for user state for creating user VPNs.
+        // Should run on mHandler to avoid any races.
         IntentFilter intentFilter = new IntentFilter();
         intentFilter.addAction(Intent.ACTION_USER_STARTED);
         intentFilter.addAction(Intent.ACTION_USER_STOPPED);
@@ -875,7 +876,11 @@
         intentFilter.addAction(Intent.ACTION_USER_REMOVED);
         intentFilter.addAction(Intent.ACTION_USER_UNLOCKED);
         mContext.registerReceiverAsUser(
-                mUserIntentReceiver, UserHandle.ALL, intentFilter, null, null);
+                mUserIntentReceiver,
+                UserHandle.ALL,
+                intentFilter,
+                null /* broadcastPermission */,
+                mHandler);
         mContext.registerReceiverAsUser(mUserPresentReceiver, UserHandle.SYSTEM,
                 new IntentFilter(Intent.ACTION_USER_PRESENT), null, null);
 
@@ -3815,17 +3820,27 @@
      * handler thread through their agent, this is asynchronous. When the capabilities objects
      * are computed they will be up-to-date as they are computed synchronously from here and
      * this is running on the ConnectivityService thread.
-     * TODO : Fix this and call updateCapabilities inline to remove out-of-order events.
      */
     private void updateAllVpnsCapabilities() {
+        Network defaultNetwork = getNetwork(getDefaultNetwork());
         synchronized (mVpns) {
             for (int i = 0; i < mVpns.size(); i++) {
                 final Vpn vpn = mVpns.valueAt(i);
-                vpn.updateCapabilities();
+                NetworkCapabilities nc = vpn.updateCapabilities(defaultNetwork);
+                updateVpnCapabilities(vpn, nc);
             }
         }
     }
 
+    private void updateVpnCapabilities(Vpn vpn, @Nullable NetworkCapabilities nc) {
+        ensureRunningOnConnectivityServiceThread();
+        NetworkAgentInfo vpnNai = getNetworkAgentInfoForNetId(vpn.getNetId());
+        if (vpnNai == null || nc == null) {
+            return;
+        }
+        updateCapabilities(vpnNai.getCurrentScore(), vpnNai, nc);
+    }
+
     @Override
     public boolean updateLockdownVpn() {
         if (Binder.getCallingUid() != Process.SYSTEM_UID) {
@@ -4132,21 +4147,27 @@
     }
 
     private void onUserAdded(int userId) {
+        Network defaultNetwork = getNetwork(getDefaultNetwork());
         synchronized (mVpns) {
             final int vpnsSize = mVpns.size();
             for (int i = 0; i < vpnsSize; i++) {
                 Vpn vpn = mVpns.valueAt(i);
                 vpn.onUserAdded(userId);
+                NetworkCapabilities nc = vpn.updateCapabilities(defaultNetwork);
+                updateVpnCapabilities(vpn, nc);
             }
         }
     }
 
     private void onUserRemoved(int userId) {
+        Network defaultNetwork = getNetwork(getDefaultNetwork());
         synchronized (mVpns) {
             final int vpnsSize = mVpns.size();
             for (int i = 0; i < vpnsSize; i++) {
                 Vpn vpn = mVpns.valueAt(i);
                 vpn.onUserRemoved(userId);
+                NetworkCapabilities nc = vpn.updateCapabilities(defaultNetwork);
+                updateVpnCapabilities(vpn, nc);
             }
         }
     }
@@ -4165,6 +4186,7 @@
     private BroadcastReceiver mUserIntentReceiver = new BroadcastReceiver() {
         @Override
         public void onReceive(Context context, Intent intent) {
+            ensureRunningOnConnectivityServiceThread();
             final String action = intent.getAction();
             final int userId = intent.getIntExtra(Intent.EXTRA_USER_HANDLE, UserHandle.USER_NULL);
             if (userId == UserHandle.USER_NULL) return;
@@ -4650,6 +4672,19 @@
         return getNetworkForRequest(mDefaultRequest.requestId);
     }
 
+    @Nullable
+    private Network getNetwork(@Nullable NetworkAgentInfo nai) {
+        return nai != null ? nai.network : null;
+    }
+
+    private void ensureRunningOnConnectivityServiceThread() {
+        if (mHandler.getLooper().getThread() != Thread.currentThread()) {
+            throw new IllegalStateException(
+                    "Not running on ConnectivityService thread: "
+                            + Thread.currentThread().getName());
+        }
+    }
+
     private boolean isDefaultNetwork(NetworkAgentInfo nai) {
         return nai == getDefaultNetwork();
     }
@@ -5197,6 +5232,8 @@
         updateTcpBufferSizes(newNetwork);
         mDnsManager.setDefaultDnsSystemProperties(newNetwork.linkProperties.getDnsServers());
         notifyIfacesChangedForNetworkStats();
+        // Fix up the NetworkCapabilities of any VPNs that don't specify underlying networks.
+        updateAllVpnsCapabilities();
     }
 
     private void processListenRequests(NetworkAgentInfo nai, boolean capabilitiesChanged) {
@@ -5630,6 +5667,10 @@
             // doing.
             updateSignalStrengthThresholds(networkAgent, "CONNECT", null);
 
+            if (networkAgent.isVPN()) {
+                updateAllVpnsCapabilities();
+            }
+
             // Consider network even though it is not yet validated.
             final long now = SystemClock.elapsedRealtime();
             rematchNetworkAndRequests(networkAgent, ReapUnvalidatedNetworks.REAP, now);
@@ -5823,7 +5864,11 @@
             success = mVpns.get(user).setUnderlyingNetworks(networks);
         }
         if (success) {
-            mHandler.post(() -> notifyIfacesChangedForNetworkStats());
+            mHandler.post(() -> {
+                // Update VPN's capabilities based on updated underlying network set.
+                updateAllVpnsCapabilities();
+                notifyIfacesChangedForNetworkStats();
+            });
         }
         return success;
     }
diff --git a/tests/net/java/com/android/server/ConnectivityServiceTest.java b/tests/net/java/com/android/server/ConnectivityServiceTest.java
index c2c627d..0dcb21a 100644
--- a/tests/net/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/net/java/com/android/server/ConnectivityServiceTest.java
@@ -20,6 +20,7 @@
 import static android.net.ConnectivityManager.PRIVATE_DNS_MODE_OFF;
 import static android.net.ConnectivityManager.PRIVATE_DNS_MODE_OPPORTUNISTIC;
 import static android.net.ConnectivityManager.PRIVATE_DNS_MODE_PROVIDER_HOSTNAME;
+import static android.net.ConnectivityManager.NETID_UNSET;
 import static android.net.ConnectivityManager.TYPE_ETHERNET;
 import static android.net.ConnectivityManager.TYPE_MOBILE;
 import static android.net.ConnectivityManager.TYPE_MOBILE_FOTA;
@@ -775,11 +776,14 @@
 
         public void setUids(Set<UidRange> uids) {
             mNetworkCapabilities.setUids(uids);
-            updateCapabilities();
+            updateCapabilities(null /* defaultNetwork */);
         }
 
         @Override
         public int getNetId() {
+            if (mMockNetworkAgent == null) {
+                return NETID_UNSET;
+            }
             return mMockNetworkAgent.getNetwork().netId;
         }
 
@@ -800,12 +804,13 @@
         }
 
         @Override
-        public void updateCapabilities() {
-            if (!mConnected) return;
-            super.updateCapabilities();
-            // Because super.updateCapabilities will update the capabilities of the agent but not
-            // the mock agent, the mock agent needs to know about them.
+        public NetworkCapabilities updateCapabilities(Network defaultNetwork) {
+            if (!mConnected) return null;
+            super.updateCapabilities(defaultNetwork);
+            // Because super.updateCapabilities will update the capabilities of the agent but
+            // not the mock agent, the mock agent needs to know about them.
             copyCapabilitiesToNetworkAgent();
+            return new NetworkCapabilities(mNetworkCapabilities);
         }
 
         private void copyCapabilitiesToNetworkAgent() {
@@ -4218,6 +4223,7 @@
         mMockVpn.setUids(ranges);
         vpnNetworkAgent.connect(false);
         mMockVpn.connect();
+        mMockVpn.setUnderlyingNetworks(new Network[0]);
 
         genericNetworkCallback.expectAvailableCallbacksUnvalidated(vpnNetworkAgent);
         genericNotVpnNetworkCallback.assertNoCallback();
@@ -4250,6 +4256,7 @@
 
         ranges.add(new UidRange(uid, uid));
         mMockVpn.setUids(ranges);
+        vpnNetworkAgent.setUids(ranges);
 
         genericNetworkCallback.expectAvailableCallbacksValidated(vpnNetworkAgent);
         genericNotVpnNetworkCallback.assertNoCallback();
@@ -4283,12 +4290,11 @@
     }
 
     @Test
-    public void testVpnWithAndWithoutInternet() {
+    public void testVpnWithoutInternet() {
         final int uid = Process.myUid();
 
         final TestNetworkCallback defaultCallback = new TestNetworkCallback();
         mCm.registerDefaultNetworkCallback(defaultCallback);
-        defaultCallback.assertNoCallback();
 
         mWiFiNetworkAgent = new MockNetworkAgent(TRANSPORT_WIFI);
         mWiFiNetworkAgent.connect(true);
@@ -4310,11 +4316,30 @@
         vpnNetworkAgent.disconnect();
         defaultCallback.assertNoCallback();
 
-        vpnNetworkAgent = new MockNetworkAgent(TRANSPORT_VPN);
+        mCm.unregisterNetworkCallback(defaultCallback);
+    }
+
+    @Test
+    public void testVpnWithInternet() {
+        final int uid = Process.myUid();
+
+        final TestNetworkCallback defaultCallback = new TestNetworkCallback();
+        mCm.registerDefaultNetworkCallback(defaultCallback);
+
+        mWiFiNetworkAgent = new MockNetworkAgent(TRANSPORT_WIFI);
+        mWiFiNetworkAgent.connect(true);
+
+        defaultCallback.expectAvailableThenValidatedCallbacks(mWiFiNetworkAgent);
+        assertEquals(defaultCallback.getLastAvailableNetwork(), mCm.getActiveNetwork());
+
+        MockNetworkAgent vpnNetworkAgent = new MockNetworkAgent(TRANSPORT_VPN);
+        final ArraySet<UidRange> ranges = new ArraySet<>();
+        ranges.add(new UidRange(uid, uid));
         mMockVpn.setNetworkAgent(vpnNetworkAgent);
         mMockVpn.setUids(ranges);
         vpnNetworkAgent.connect(true /* validated */, true /* hasInternet */);
         mMockVpn.connect();
+
         defaultCallback.expectAvailableThenValidatedCallbacks(vpnNetworkAgent);
         assertEquals(defaultCallback.getLastAvailableNetwork(), mCm.getActiveNetwork());
 
@@ -4322,14 +4347,6 @@
         defaultCallback.expectCallback(CallbackState.LOST, vpnNetworkAgent);
         defaultCallback.expectAvailableCallbacksValidated(mWiFiNetworkAgent);
 
-        vpnNetworkAgent = new MockNetworkAgent(TRANSPORT_VPN);
-        ranges.clear();
-        mMockVpn.setNetworkAgent(vpnNetworkAgent);
-        mMockVpn.setUids(ranges);
-        vpnNetworkAgent.connect(false /* validated */, true /* hasInternet */);
-        mMockVpn.connect();
-        defaultCallback.assertNoCallback();
-
         mCm.unregisterNetworkCallback(defaultCallback);
     }
 
@@ -4430,4 +4447,68 @@
 
         mMockVpn.disconnect();
     }
+
+    @Test
+    public void testNullUnderlyingNetworks() {
+        final int uid = Process.myUid();
+
+        final TestNetworkCallback vpnNetworkCallback = new TestNetworkCallback();
+        final NetworkRequest vpnNetworkRequest = new NetworkRequest.Builder()
+                .removeCapability(NET_CAPABILITY_NOT_VPN)
+                .addTransportType(TRANSPORT_VPN)
+                .build();
+        NetworkCapabilities nc;
+        mCm.registerNetworkCallback(vpnNetworkRequest, vpnNetworkCallback);
+        vpnNetworkCallback.assertNoCallback();
+
+        final MockNetworkAgent vpnNetworkAgent = new MockNetworkAgent(TRANSPORT_VPN);
+        final ArraySet<UidRange> ranges = new ArraySet<>();
+        ranges.add(new UidRange(uid, uid));
+        mMockVpn.setNetworkAgent(vpnNetworkAgent);
+        mMockVpn.connect();
+        mMockVpn.setUids(ranges);
+        vpnNetworkAgent.connect(true /* validated */, false /* hasInternet */);
+
+        vpnNetworkCallback.expectAvailableThenValidatedCallbacks(vpnNetworkAgent);
+        nc = mCm.getNetworkCapabilities(vpnNetworkAgent.getNetwork());
+        assertTrue(nc.hasTransport(TRANSPORT_VPN));
+        assertFalse(nc.hasTransport(TRANSPORT_CELLULAR));
+        assertFalse(nc.hasTransport(TRANSPORT_WIFI));
+        // By default, VPN is set to track default network (i.e. its underlying networks is null).
+        // In case of no default network, VPN is considered metered.
+        assertFalse(nc.hasCapability(NET_CAPABILITY_NOT_METERED));
+
+        // Connect to Cell; Cell is the default network.
+        mCellNetworkAgent = new MockNetworkAgent(TRANSPORT_CELLULAR);
+        mCellNetworkAgent.connect(true);
+
+        vpnNetworkCallback.expectCapabilitiesLike((caps) -> caps.hasTransport(TRANSPORT_VPN)
+                && caps.hasTransport(TRANSPORT_CELLULAR) && !caps.hasTransport(TRANSPORT_WIFI)
+                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED),
+                vpnNetworkAgent);
+
+        // Connect to WiFi; WiFi is the new default.
+        mWiFiNetworkAgent = new MockNetworkAgent(TRANSPORT_WIFI);
+        mWiFiNetworkAgent.addCapability(NET_CAPABILITY_NOT_METERED);
+        mWiFiNetworkAgent.connect(true);
+
+        vpnNetworkCallback.expectCapabilitiesLike((caps) -> caps.hasTransport(TRANSPORT_VPN)
+                && !caps.hasTransport(TRANSPORT_CELLULAR) && caps.hasTransport(TRANSPORT_WIFI)
+                && caps.hasCapability(NET_CAPABILITY_NOT_METERED),
+                vpnNetworkAgent);
+
+        // Disconnect Cell. The default network did not change, so there shouldn't be any changes in
+        // the capabilities.
+        mCellNetworkAgent.disconnect();
+
+        // Disconnect wifi too. Now we have no default network.
+        mWiFiNetworkAgent.disconnect();
+
+        vpnNetworkCallback.expectCapabilitiesLike((caps) -> caps.hasTransport(TRANSPORT_VPN)
+                && !caps.hasTransport(TRANSPORT_CELLULAR) && !caps.hasTransport(TRANSPORT_WIFI)
+                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED),
+                vpnNetworkAgent);
+
+        mMockVpn.disconnect();
+    }
 }
diff --git a/tests/net/java/com/android/server/connectivity/VpnTest.java b/tests/net/java/com/android/server/connectivity/VpnTest.java
index e377a47..a0a4ad1 100644
--- a/tests/net/java/com/android/server/connectivity/VpnTest.java
+++ b/tests/net/java/com/android/server/connectivity/VpnTest.java
@@ -457,7 +457,8 @@
 
         final NetworkCapabilities caps = new NetworkCapabilities();
 
-        Vpn.updateCapabilities(mConnectivityManager, new Network[] { }, caps);
+        Vpn.applyUnderlyingCapabilities(
+                mConnectivityManager, new Network[] {}, caps);
         assertTrue(caps.hasTransport(TRANSPORT_VPN));
         assertFalse(caps.hasTransport(TRANSPORT_CELLULAR));
         assertFalse(caps.hasTransport(TRANSPORT_WIFI));
@@ -467,7 +468,8 @@
         assertTrue(caps.hasCapability(NET_CAPABILITY_NOT_ROAMING));
         assertTrue(caps.hasCapability(NET_CAPABILITY_NOT_CONGESTED));
 
-        Vpn.updateCapabilities(mConnectivityManager, new Network[] { mobile }, caps);
+        Vpn.applyUnderlyingCapabilities(
+                mConnectivityManager, new Network[] {mobile}, caps);
         assertTrue(caps.hasTransport(TRANSPORT_VPN));
         assertTrue(caps.hasTransport(TRANSPORT_CELLULAR));
         assertFalse(caps.hasTransport(TRANSPORT_WIFI));
@@ -477,7 +479,8 @@
         assertFalse(caps.hasCapability(NET_CAPABILITY_NOT_ROAMING));
         assertTrue(caps.hasCapability(NET_CAPABILITY_NOT_CONGESTED));
 
-        Vpn.updateCapabilities(mConnectivityManager, new Network[] { wifi }, caps);
+        Vpn.applyUnderlyingCapabilities(
+                mConnectivityManager, new Network[] {wifi}, caps);
         assertTrue(caps.hasTransport(TRANSPORT_VPN));
         assertFalse(caps.hasTransport(TRANSPORT_CELLULAR));
         assertTrue(caps.hasTransport(TRANSPORT_WIFI));
@@ -487,7 +490,8 @@
         assertTrue(caps.hasCapability(NET_CAPABILITY_NOT_ROAMING));
         assertTrue(caps.hasCapability(NET_CAPABILITY_NOT_CONGESTED));
 
-        Vpn.updateCapabilities(mConnectivityManager, new Network[] { mobile, wifi }, caps);
+        Vpn.applyUnderlyingCapabilities(
+                mConnectivityManager, new Network[] {mobile, wifi}, caps);
         assertTrue(caps.hasTransport(TRANSPORT_VPN));
         assertTrue(caps.hasTransport(TRANSPORT_CELLULAR));
         assertTrue(caps.hasTransport(TRANSPORT_WIFI));