Merge "Add configurable captive portal probes" into pi-dev
diff --git a/core/java/android/net/Network.java b/core/java/android/net/Network.java
index 15a0ee5..d75d439 100644
--- a/core/java/android/net/Network.java
+++ b/core/java/android/net/Network.java
@@ -85,6 +85,21 @@
     private static final long HANDLE_MAGIC = 0xcafed00dL;
     private static final int HANDLE_MAGIC_SIZE = 32;
 
+    // A boolean to control how getAllByName()/getByName() behave in the face
+    // of Private DNS.
+    //
+    // When true, these calls will request that DNS resolution bypass any
+    // Private DNS that might otherwise apply. Use of this feature is restricted
+    // and permission checks are made by netd (attempts to bypass Private DNS
+    // without appropriate permission are silently turned into vanilla DNS
+    // requests). This only affects DNS queries made using this network object.
+    //
+    // It is not parceled to receivers because (a) it can be set or cleared at
+    // any time and (b) receivers should be explicit about attempts to bypass
+    // Private DNS so that the intent of the code is easily determined and
+    // code search audits are possible.
+    private boolean mPrivateDnsBypass = false;
+
     /**
      * @hide
      */
@@ -108,7 +123,7 @@
      * @throws UnknownHostException if the address lookup fails.
      */
     public InetAddress[] getAllByName(String host) throws UnknownHostException {
-        return InetAddress.getAllByNameOnNet(host, netId);
+        return InetAddress.getAllByNameOnNet(host, getNetIdForResolv());
     }
 
     /**
@@ -122,7 +137,32 @@
      *             if the address lookup fails.
      */
     public InetAddress getByName(String host) throws UnknownHostException {
-        return InetAddress.getByNameOnNet(host, netId);
+        return InetAddress.getByNameOnNet(host, getNetIdForResolv());
+    }
+
+    /**
+     * Specify whether or not Private DNS should be bypassed when attempting
+     * to use {@link #getAllByName(String)}/{@link #getByName(String)} methods on the given
+     * instance for hostname resolution.
+     *
+     * @hide
+     */
+    public void setPrivateDnsBypass(boolean bypass) {
+        mPrivateDnsBypass = bypass;
+    }
+
+    /**
+     * Returns the netid to use for DNS resolution, flagged for Private DNS bypass when enabled.
+     *
+     * This flag must be kept in sync with the NETID_USE_LOCAL_NAMESERVERS flag
+     * in system/netd/include/NetdClient.h.
+     *
+     * @hide
+     */
+    public int getNetIdForResolv() {
+        return mPrivateDnsBypass
+                ? (int) (0x80000000L | (long) netId)  // Non-portable DNS resolution flag.
+                : netId;
     }
 
     /**
diff --git a/core/java/android/net/NetworkCapabilities.java b/core/java/android/net/NetworkCapabilities.java
index a808c64..e3a1107 100644
--- a/core/java/android/net/NetworkCapabilities.java
+++ b/core/java/android/net/NetworkCapabilities.java
@@ -63,16 +63,7 @@
 
     public NetworkCapabilities(NetworkCapabilities nc) {
         if (nc != null) {
-            mNetworkCapabilities = nc.mNetworkCapabilities;
-            mTransportTypes = nc.mTransportTypes;
-            mLinkUpBandwidthKbps = nc.mLinkUpBandwidthKbps;
-            mLinkDownBandwidthKbps = nc.mLinkDownBandwidthKbps;
-            mNetworkSpecifier = nc.mNetworkSpecifier;
-            mSignalStrength = nc.mSignalStrength;
-            mUids = nc.mUids;
-            mEstablishingVpnAppUid = nc.mEstablishingVpnAppUid;
-            mUnwantedNetworkCapabilities = nc.mUnwantedNetworkCapabilities;
-            mSSID = nc.mSSID;
+            set(nc);
         }
     }
 
@@ -92,6 +83,23 @@
     }
 
     /**
+     * Set all contents of this object to the contents of a NetworkCapabilities.
+     * @hide
+     */
+    public void set(NetworkCapabilities nc) {
+        mNetworkCapabilities = nc.mNetworkCapabilities;
+        mTransportTypes = nc.mTransportTypes;
+        mLinkUpBandwidthKbps = nc.mLinkUpBandwidthKbps;
+        mLinkDownBandwidthKbps = nc.mLinkDownBandwidthKbps;
+        mNetworkSpecifier = nc.mNetworkSpecifier;
+        mSignalStrength = nc.mSignalStrength;
+        setUids(nc.mUids); // Will make the defensive copy
+        mEstablishingVpnAppUid = nc.mEstablishingVpnAppUid;
+        mUnwantedNetworkCapabilities = nc.mUnwantedNetworkCapabilities;
+        mSSID = nc.mSSID;
+    }
+
+    /**
      * Represents the network's capabilities.  If any are specified they will be satisfied
      * by any Network that matches all of them.
      */
diff --git a/core/java/android/net/NetworkRequest.java b/core/java/android/net/NetworkRequest.java
index 227a4cb..16c2342 100644
--- a/core/java/android/net/NetworkRequest.java
+++ b/core/java/android/net/NetworkRequest.java
@@ -198,8 +198,7 @@
          * @hide
          */
         public Builder setCapabilities(NetworkCapabilities nc) {
-            mNetworkCapabilities.clearAll();
-            mNetworkCapabilities.combineCapabilities(nc);
+            mNetworkCapabilities.set(nc);
             return this;
         }
 
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index 797cb4b..e22a86e 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -2503,6 +2503,9 @@
                 ensureNetworkTransitionWakelock(nai.name());
             }
             mLegacyTypeTracker.remove(nai, wasDefault);
+            if (!nai.networkCapabilities.hasTransport(TRANSPORT_VPN)) {
+                updateAllVpnsCapabilities();
+            }
             rematchAllNetworksAndRequests(null, 0);
             mLingerMonitor.noteDisconnect(nai);
             if (nai.created) {
@@ -3778,6 +3781,26 @@
         }
     }
 
+    /**
+     * Ask all VPN objects to recompute and update their capabilities.
+     *
+     * When underlying networks change, VPNs may have to update capabilities to reflect things
+     * like the metered bit, their transports, and so on. This asks the VPN objects to update
+     * their capabilities, and as this will cause them to send messages to the ConnectivityService
+     * handler thread through their agent, this is asynchronous. When the capabilities objects
+     * are computed they will be up-to-date as they are computed synchronously from here and
+     * this is running on the ConnectivityService thread.
+     * TODO : Fix this and call updateCapabilities inline to remove out-of-order events.
+     */
+    private void updateAllVpnsCapabilities() {
+        synchronized (mVpns) {
+            for (int i = 0; i < mVpns.size(); i++) {
+                final Vpn vpn = mVpns.valueAt(i);
+                vpn.updateCapabilities();
+            }
+        }
+    }
+
     @Override
     public boolean updateLockdownVpn() {
         if (Binder.getCallingUid() != Process.SYSTEM_UID) {
@@ -4575,7 +4598,7 @@
 
     // Note: if mDefaultRequest is changed, NetworkMonitor needs to be updated.
     private final NetworkRequest mDefaultRequest;
- 
+
     // Request used to optionally keep mobile data active even when higher
     // priority networks like Wi-Fi are active.
     private final NetworkRequest mDefaultMobileDataRequest;
@@ -4651,7 +4674,7 @@
     }
 
     private void updateLinkProperties(NetworkAgentInfo networkAgent, LinkProperties oldLp) {
-        LinkProperties newLp = networkAgent.linkProperties;
+        LinkProperties newLp = new LinkProperties(networkAgent.linkProperties);
         int netId = networkAgent.network.netId;
 
         // The NetworkAgentInfo does not know whether clatd is running on its network or not. Before
@@ -4685,6 +4708,9 @@
         }
         // TODO - move this check to cover the whole function
         if (!Objects.equals(newLp, oldLp)) {
+            synchronized (networkAgent) {
+                networkAgent.linkProperties = newLp;
+            }
             notifyIfacesChangedForNetworkStats();
             notifyNetworkCallbacks(networkAgent, ConnectivityManager.CALLBACK_IP_CHANGED);
         }
@@ -4935,12 +4961,7 @@
         if (!newNc.hasTransport(TRANSPORT_VPN)) {
             // Tell VPNs about updated capabilities, since they may need to
             // bubble those changes through.
-            synchronized (mVpns) {
-                for (int i = 0; i < mVpns.size(); i++) {
-                    final Vpn vpn = mVpns.valueAt(i);
-                    vpn.updateCapabilities();
-                }
-            }
+            updateAllVpnsCapabilities();
         }
     }
 
diff --git a/tests/net/java/android/net/NetworkCapabilitiesTest.java b/tests/net/java/android/net/NetworkCapabilitiesTest.java
index da897ae..a112fa6 100644
--- a/tests/net/java/android/net/NetworkCapabilitiesTest.java
+++ b/tests/net/java/android/net/NetworkCapabilitiesTest.java
@@ -56,6 +56,7 @@
 @SmallTest
 public class NetworkCapabilitiesTest {
     private static final String TEST_SSID = "TEST_SSID";
+    private static final String DIFFERENT_TEST_SSID = "DIFFERENT_TEST_SSID";
 
     @Test
     public void testMaybeMarkCapabilitiesRestricted() {
@@ -374,6 +375,12 @@
         assertFalse(nc1.satisfiedByNetworkCapabilities(nc2));
     }
 
+    private ArraySet<UidRange> uidRange(int from, int to) {
+        final ArraySet<UidRange> range = new ArraySet<>(1);
+        range.add(new UidRange(from, to));
+        return range;
+    }
+
     @Test
     public void testCombineCapabilities() {
         NetworkCapabilities nc1 = new NetworkCapabilities();
@@ -400,14 +407,30 @@
         nc2.combineCapabilities(nc1);
         assertTrue(TEST_SSID.equals(nc2.getSSID()));
 
-        // Because they now have the same SSID, the folllowing call should not throw
+        // Because they now have the same SSID, the following call should not throw
         nc2.combineCapabilities(nc1);
 
-        nc1.setSSID("different " + TEST_SSID);
+        nc1.setSSID(DIFFERENT_TEST_SSID);
         try {
             nc2.combineCapabilities(nc1);
             fail("Expected IllegalStateException: can't combine different SSIDs");
         } catch (IllegalStateException expected) {}
+        nc1.setSSID(TEST_SSID);
+
+        nc1.setUids(uidRange(10, 13));
+        assertNotEquals(nc1, nc2);
+        nc2.combineCapabilities(nc1);  // Everything + 10~13 is still everything.
+        assertNotEquals(nc1, nc2);
+        nc1.combineCapabilities(nc2);  // 10~13 + everything is everything.
+        assertEquals(nc1, nc2);
+        nc1.setUids(uidRange(10, 13));
+        nc2.setUids(uidRange(20, 23));
+        assertNotEquals(nc1, nc2);
+        nc1.combineCapabilities(nc2);
+        assertTrue(nc1.appliesToUid(12));
+        assertFalse(nc2.appliesToUid(12));
+        assertTrue(nc1.appliesToUid(22));
+        assertTrue(nc2.appliesToUid(22));
     }
 
     @Test
@@ -446,4 +469,38 @@
         p.setDataPosition(0);
         assertEquals(NetworkCapabilities.CREATOR.createFromParcel(p), netCap);
     }
+
+    @Test
+    public void testSet() {
+        NetworkCapabilities nc1 = new NetworkCapabilities();
+        NetworkCapabilities nc2 = new NetworkCapabilities();
+
+        nc1.addUnwantedCapability(NET_CAPABILITY_CAPTIVE_PORTAL);
+        nc1.addCapability(NET_CAPABILITY_NOT_ROAMING);
+        assertNotEquals(nc1, nc2);
+        nc2.set(nc1);
+        assertEquals(nc1, nc2);
+        assertTrue(nc2.hasCapability(NET_CAPABILITY_NOT_ROAMING));
+        assertTrue(nc2.hasUnwantedCapability(NET_CAPABILITY_CAPTIVE_PORTAL));
+
+        // This will effectively move NOT_ROAMING capability from required to unwanted for nc1.
+        nc1.addUnwantedCapability(NET_CAPABILITY_NOT_ROAMING);
+        nc1.setSSID(TEST_SSID);
+        nc2.set(nc1);
+        assertEquals(nc1, nc2);
+        // Contrary to combineCapabilities, set() will have removed the NOT_ROAMING capability
+        // from nc2.
+        assertFalse(nc2.hasCapability(NET_CAPABILITY_NOT_ROAMING));
+        assertTrue(nc2.hasUnwantedCapability(NET_CAPABILITY_NOT_ROAMING));
+        assertTrue(TEST_SSID.equals(nc2.getSSID()));
+
+        nc1.setSSID(DIFFERENT_TEST_SSID);
+        nc2.set(nc1);
+        assertEquals(nc1, nc2);
+        assertTrue(DIFFERENT_TEST_SSID.equals(nc2.getSSID()));
+
+        nc1.setUids(uidRange(10, 13));
+        nc2.set(nc1);  // Overwrites, as opposed to combineCapabilities
+        assertEquals(nc1, nc2);
+    }
 }
diff --git a/tests/net/java/com/android/server/ConnectivityServiceTest.java b/tests/net/java/com/android/server/ConnectivityServiceTest.java
index 7eef2d5..2208580 100644
--- a/tests/net/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/net/java/com/android/server/ConnectivityServiceTest.java
@@ -112,6 +112,7 @@
 import android.net.RouteInfo;
 import android.net.StringNetworkSpecifier;
 import android.net.UidRange;
+import android.net.VpnService;
 import android.net.captiveportal.CaptivePortalProbeResult;
 import android.net.metrics.IpConnectivityLog;
 import android.net.util.MultinetworkPolicyTracker;
@@ -134,6 +135,7 @@
 import android.util.ArraySet;
 import android.util.Log;
 
+import com.android.internal.net.VpnConfig;
 import com.android.internal.util.ArrayUtils;
 import com.android.internal.util.WakeupMessage;
 import com.android.internal.util.test.BroadcastInterceptingContext;
@@ -197,13 +199,13 @@
     private MockNetworkAgent mWiFiNetworkAgent;
     private MockNetworkAgent mCellNetworkAgent;
     private MockNetworkAgent mEthernetNetworkAgent;
+    private MockVpn mMockVpn;
     private Context mContext;
 
     @Mock IpConnectivityMetrics.Logger mMetricsService;
     @Mock DefaultNetworkMetrics mDefaultNetworkMetrics;
     @Mock INetworkManagementService mNetworkManagementService;
     @Mock INetworkStatsService mStatsService;
-    @Mock Vpn mMockVpn;
 
     private ArgumentCaptor<String[]> mStringArrayCaptor = ArgumentCaptor.forClass(String[].class);
 
@@ -479,6 +481,14 @@
             mNetworkAgent.sendNetworkCapabilities(mNetworkCapabilities);
         }
 
+        public void setNetworkCapabilities(NetworkCapabilities nc,
+                boolean sendToConnectivityService) {
+            mNetworkCapabilities.set(nc);
+            if (sendToConnectivityService) {
+                mNetworkAgent.sendNetworkCapabilities(mNetworkCapabilities);
+            }
+        }
+
         public void connectWithoutInternet() {
             mNetworkInfo.setDetailedState(DetailedState.CONNECTED, null, null);
             mNetworkAgent.sendNetworkInfo(mNetworkInfo);
@@ -594,6 +604,10 @@
             return mRedirectUrl;
         }
 
+        public NetworkAgent getNetworkAgent() {
+            return mNetworkAgent;
+        }
+
         public NetworkCapabilities getNetworkCapabilities() {
             return mNetworkCapabilities;
         }
@@ -726,6 +740,87 @@
         }
     }
 
+    private static Looper startHandlerThreadAndReturnLooper() {
+        final HandlerThread handlerThread = new HandlerThread("MockVpnThread");
+        handlerThread.start();
+        return handlerThread.getLooper();
+    }
+
+    private class MockVpn extends Vpn {
+        // TODO : the interactions between this mock and the mock network agent are too
+        // hard to get right at this moment, because it's unclear in which case which
+        // target needs to get a method call or both, and in what order. It's because
+        // MockNetworkAgent wants to manage its own NetworkCapabilities, but the Vpn
+        // parent class of MockVpn agent wants that responsibility.
+        // That being said inside the test it should be possible to make the interactions
+        // harder to get wrong with precise speccing, judicious comments, helper methods
+        // and a few sprinkled assertions.
+
+        private boolean mConnected = false;
+        // Careful! This is different from mNetworkAgent, because MockNetworkAgent does
+        // not inherit from NetworkAgent.
+        private MockNetworkAgent mMockNetworkAgent;
+
+        public MockVpn(int userId) {
+            super(startHandlerThreadAndReturnLooper(), mServiceContext, mNetworkManagementService,
+                    userId);
+        }
+
+        public void setNetworkAgent(MockNetworkAgent agent) {
+            waitForIdle(agent, TIMEOUT_MS);
+            mMockNetworkAgent = agent;
+            mNetworkAgent = agent.getNetworkAgent();
+            mNetworkCapabilities.set(agent.getNetworkCapabilities());
+        }
+
+        public void setUids(Set<UidRange> uids) {
+            mNetworkCapabilities.setUids(uids);
+            updateCapabilities();
+        }
+
+        @Override
+        public int getNetId() {
+            return mMockNetworkAgent.getNetwork().netId;
+        }
+
+        @Override
+        public boolean appliesToUid(int uid) {
+            return mConnected;  // Trickery to simplify testing.
+        }
+
+        @Override
+        protected boolean isCallerEstablishedOwnerLocked() {
+            return mConnected;  // Similar trickery
+        }
+
+        public void connect() {
+            mNetworkCapabilities.set(mMockNetworkAgent.getNetworkCapabilities());
+            mConnected = true;
+            mConfig = new VpnConfig();
+        }
+
+        @Override
+        public void updateCapabilities() {
+            if (!mConnected) return;
+            super.updateCapabilities();
+            // Because super.updateCapabilities will update the capabilities of the agent but not
+            // the mock agent, the mock agent needs to know about them.
+            copyCapabilitiesToNetworkAgent();
+        }
+
+        private void copyCapabilitiesToNetworkAgent() {
+            if (null != mMockNetworkAgent) {
+                mMockNetworkAgent.setNetworkCapabilities(mNetworkCapabilities,
+                        false /* sendToConnectivityService */);
+            }
+        }
+
+        public void disconnect() {
+            mConnected = false;
+            mConfig = null;
+        }
+    }
+
     private class FakeWakeupMessage extends WakeupMessage {
         private static final int UNREASONABLY_LONG_WAIT = 1000;
 
@@ -894,10 +989,12 @@
 
         public void mockVpn(int uid) {
             synchronized (mVpns) {
+                int userId = UserHandle.getUserId(uid);
+                mMockVpn = new MockVpn(userId);
                 // This has no effect unless the VPN is actually connected, because things like
                 // getActiveNetworkForUidInternal call getNetworkAgentInfoForNetId on the VPN
                 // netId, and check if that network is actually connected.
-                mVpns.put(UserHandle.getUserId(Process.myUid()), mMockVpn);
+                mVpns.put(userId, mMockVpn);
             }
         }
 
@@ -927,7 +1024,6 @@
 
         MockitoAnnotations.initMocks(this);
         when(mMetricsService.defaultNetworkMetrics()).thenReturn(mDefaultNetworkMetrics);
-        when(mMockVpn.appliesToUid(Process.myUid())).thenReturn(true);
 
         // InstrumentationTestRunner prepares a looper, but AndroidJUnitRunner does not.
         // http://b/25897652 .
@@ -1549,7 +1645,8 @@
 
         void expectCapabilitiesLike(Predicate<NetworkCapabilities> fn, MockNetworkAgent agent) {
             CallbackInfo cbi = expectCallback(CallbackState.NETWORK_CAPABILITIES, agent);
-            assertTrue(fn.test((NetworkCapabilities) cbi.arg));
+            assertTrue("Received capabilities don't match expectations : " + cbi.arg,
+                    fn.test((NetworkCapabilities) cbi.arg));
         }
 
         void assertNoCallback() {
@@ -2577,9 +2674,10 @@
         final MockNetworkAgent vpnNetworkAgent = new MockNetworkAgent(TRANSPORT_VPN);
         final ArraySet<UidRange> ranges = new ArraySet<>();
         ranges.add(new UidRange(uid, uid));
-        when(mMockVpn.getNetId()).thenReturn(vpnNetworkAgent.getNetwork().netId);
-        vpnNetworkAgent.setUids(ranges);
+        mMockVpn.setNetworkAgent(vpnNetworkAgent);
+        mMockVpn.setUids(ranges);
         vpnNetworkAgent.connect(true);
+        mMockVpn.connect();
         defaultNetworkCallback.expectAvailableThenValidatedCallbacks(vpnNetworkAgent);
         assertEquals(defaultNetworkCallback.getLastAvailableNetwork(), mCm.getActiveNetwork());
 
@@ -4108,9 +4206,10 @@
         final MockNetworkAgent vpnNetworkAgent = new MockNetworkAgent(TRANSPORT_VPN);
         final ArraySet<UidRange> ranges = new ArraySet<>();
         ranges.add(new UidRange(uid, uid));
-        when(mMockVpn.getNetId()).thenReturn(vpnNetworkAgent.getNetwork().netId);
-        vpnNetworkAgent.setUids(ranges);
+        mMockVpn.setNetworkAgent(vpnNetworkAgent);
+        mMockVpn.setUids(ranges);
         vpnNetworkAgent.connect(false);
+        mMockVpn.connect();
 
         genericNetworkCallback.expectAvailableCallbacksUnvalidated(vpnNetworkAgent);
         genericNotVpnNetworkCallback.assertNoCallback();
@@ -4142,7 +4241,7 @@
         defaultCallback.expectCallback(CallbackState.NETWORK_CAPABILITIES, vpnNetworkAgent);
 
         ranges.add(new UidRange(uid, uid));
-        vpnNetworkAgent.setUids(ranges);
+        mMockVpn.setUids(ranges);
 
         genericNetworkCallback.expectAvailableCallbacksValidated(vpnNetworkAgent);
         genericNotVpnNetworkCallback.assertNoCallback();
@@ -4192,9 +4291,10 @@
         MockNetworkAgent vpnNetworkAgent = new MockNetworkAgent(TRANSPORT_VPN);
         final ArraySet<UidRange> ranges = new ArraySet<>();
         ranges.add(new UidRange(uid, uid));
-        when(mMockVpn.getNetId()).thenReturn(vpnNetworkAgent.getNetwork().netId);
-        vpnNetworkAgent.setUids(ranges);
+        mMockVpn.setNetworkAgent(vpnNetworkAgent);
+        mMockVpn.setUids(ranges);
         vpnNetworkAgent.connect(true /* validated */, false /* hasInternet */);
+        mMockVpn.connect();
 
         defaultCallback.assertNoCallback();
         assertEquals(defaultCallback.getLastAvailableNetwork(), mCm.getActiveNetwork());
@@ -4203,9 +4303,10 @@
         defaultCallback.assertNoCallback();
 
         vpnNetworkAgent = new MockNetworkAgent(TRANSPORT_VPN);
-        when(mMockVpn.getNetId()).thenReturn(vpnNetworkAgent.getNetwork().netId);
-        vpnNetworkAgent.setUids(ranges);
+        mMockVpn.setNetworkAgent(vpnNetworkAgent);
+        mMockVpn.setUids(ranges);
         vpnNetworkAgent.connect(true /* validated */, true /* hasInternet */);
+        mMockVpn.connect();
         defaultCallback.expectAvailableThenValidatedCallbacks(vpnNetworkAgent);
         assertEquals(defaultCallback.getLastAvailableNetwork(), mCm.getActiveNetwork());
 
@@ -4214,13 +4315,111 @@
         defaultCallback.expectAvailableCallbacksValidated(mWiFiNetworkAgent);
 
         vpnNetworkAgent = new MockNetworkAgent(TRANSPORT_VPN);
-        when(mMockVpn.getNetId()).thenReturn(vpnNetworkAgent.getNetwork().netId);
         ranges.clear();
-        vpnNetworkAgent.setUids(ranges);
-
+        mMockVpn.setNetworkAgent(vpnNetworkAgent);
+        mMockVpn.setUids(ranges);
         vpnNetworkAgent.connect(false /* validated */, true /* hasInternet */);
+        mMockVpn.connect();
         defaultCallback.assertNoCallback();
 
         mCm.unregisterNetworkCallback(defaultCallback);
     }
+
+    @Test
+    public void testVpnSetUnderlyingNetworks() {
+        final int uid = Process.myUid();
+
+        final TestNetworkCallback vpnNetworkCallback = new TestNetworkCallback();
+        final NetworkRequest vpnNetworkRequest = new NetworkRequest.Builder()
+                .removeCapability(NET_CAPABILITY_NOT_VPN)
+                .addTransportType(TRANSPORT_VPN)
+                .build();
+        NetworkCapabilities nc;
+        mCm.registerNetworkCallback(vpnNetworkRequest, vpnNetworkCallback);
+        vpnNetworkCallback.assertNoCallback();
+
+        final MockNetworkAgent vpnNetworkAgent = new MockNetworkAgent(TRANSPORT_VPN);
+        final ArraySet<UidRange> ranges = new ArraySet<>();
+        ranges.add(new UidRange(uid, uid));
+        mMockVpn.setNetworkAgent(vpnNetworkAgent);
+        mMockVpn.connect();
+        mMockVpn.setUids(ranges);
+        vpnNetworkAgent.connect(true /* validated */, false /* hasInternet */);
+
+        vpnNetworkCallback.expectAvailableThenValidatedCallbacks(vpnNetworkAgent);
+        nc = mCm.getNetworkCapabilities(vpnNetworkAgent.getNetwork());
+        assertTrue(nc.hasTransport(TRANSPORT_VPN));
+        assertFalse(nc.hasTransport(TRANSPORT_CELLULAR));
+        assertFalse(nc.hasTransport(TRANSPORT_WIFI));
+        // For safety reasons a VPN without underlying networks is considered metered.
+        assertFalse(nc.hasCapability(NET_CAPABILITY_NOT_METERED));
+
+        // Connect cell and use it as an underlying network.
+        mCellNetworkAgent = new MockNetworkAgent(TRANSPORT_CELLULAR);
+        mCellNetworkAgent.connect(true);
+
+        mService.setUnderlyingNetworksForVpn(
+                new Network[] { mCellNetworkAgent.getNetwork() });
+
+        vpnNetworkCallback.expectCapabilitiesLike((caps) -> caps.hasTransport(TRANSPORT_VPN)
+                && caps.hasTransport(TRANSPORT_CELLULAR) && !caps.hasTransport(TRANSPORT_WIFI)
+                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED),
+                vpnNetworkAgent);
+
+        mWiFiNetworkAgent = new MockNetworkAgent(TRANSPORT_WIFI);
+        mWiFiNetworkAgent.addCapability(NET_CAPABILITY_NOT_METERED);
+        mWiFiNetworkAgent.connect(true);
+
+        mService.setUnderlyingNetworksForVpn(
+                new Network[] { mCellNetworkAgent.getNetwork(), mWiFiNetworkAgent.getNetwork() });
+
+        vpnNetworkCallback.expectCapabilitiesLike((caps) -> caps.hasTransport(TRANSPORT_VPN)
+                && caps.hasTransport(TRANSPORT_CELLULAR) && caps.hasTransport(TRANSPORT_WIFI)
+                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED),
+                vpnNetworkAgent);
+
+        // Don't disconnect, but note the VPN is not using wifi any more.
+        mService.setUnderlyingNetworksForVpn(
+                new Network[] { mCellNetworkAgent.getNetwork() });
+
+        vpnNetworkCallback.expectCapabilitiesLike((caps) -> caps.hasTransport(TRANSPORT_VPN)
+                && caps.hasTransport(TRANSPORT_CELLULAR) && !caps.hasTransport(TRANSPORT_WIFI)
+                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED),
+                vpnNetworkAgent);
+
+        // Use Wifi but not cell. Note the VPN is now unmetered.
+        mService.setUnderlyingNetworksForVpn(
+                new Network[] { mWiFiNetworkAgent.getNetwork() });
+
+        vpnNetworkCallback.expectCapabilitiesLike((caps) -> caps.hasTransport(TRANSPORT_VPN)
+                && !caps.hasTransport(TRANSPORT_CELLULAR) && caps.hasTransport(TRANSPORT_WIFI)
+                && caps.hasCapability(NET_CAPABILITY_NOT_METERED),
+                vpnNetworkAgent);
+
+        // Use both again.
+        mService.setUnderlyingNetworksForVpn(
+                new Network[] { mCellNetworkAgent.getNetwork(), mWiFiNetworkAgent.getNetwork() });
+
+        vpnNetworkCallback.expectCapabilitiesLike((caps) -> caps.hasTransport(TRANSPORT_VPN)
+                && caps.hasTransport(TRANSPORT_CELLULAR) && caps.hasTransport(TRANSPORT_WIFI)
+                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED),
+                vpnNetworkAgent);
+
+        // Disconnect cell. Receive update without even removing the dead network from the
+        // underlying networks – it's dead anyway. Not metered any more.
+        mCellNetworkAgent.disconnect();
+        vpnNetworkCallback.expectCapabilitiesLike((caps) -> caps.hasTransport(TRANSPORT_VPN)
+                && !caps.hasTransport(TRANSPORT_CELLULAR) && caps.hasTransport(TRANSPORT_WIFI)
+                && caps.hasCapability(NET_CAPABILITY_NOT_METERED),
+                vpnNetworkAgent);
+
+        // Disconnect wifi too. No underlying networks means this is now metered.
+        mWiFiNetworkAgent.disconnect();
+        vpnNetworkCallback.expectCapabilitiesLike((caps) -> caps.hasTransport(TRANSPORT_VPN)
+                && !caps.hasTransport(TRANSPORT_CELLULAR) && !caps.hasTransport(TRANSPORT_WIFI)
+                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED),
+                vpnNetworkAgent);
+
+        mMockVpn.disconnect();
+    }
 }