Merge "Restrict VPN Diagnostics callbacks to underlying networks."
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index 4dbbd0d..f4b9b10 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -7944,10 +7944,13 @@
             return false;
         }
 
+        final Network[] underlyingNetworks;
         synchronized (mVpns) {
-            if (getVpnIfOwner(callbackUid) != null) {
-                return true;
-            }
+            final Vpn vpn = getVpnIfOwner(callbackUid);
+            underlyingNetworks = (vpn == null) ? null : vpn.getUnderlyingNetworks();
+        }
+        if (underlyingNetworks != null) {
+            if (Arrays.asList(underlyingNetworks).contains(nai.network)) return true;
         }
 
         // Administrator UIDs also contains the Owner UID
diff --git a/tests/net/java/com/android/server/ConnectivityServiceTest.java b/tests/net/java/com/android/server/ConnectivityServiceTest.java
index 9dc0a62..aebe3b6 100644
--- a/tests/net/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/net/java/com/android/server/ConnectivityServiceTest.java
@@ -308,6 +308,8 @@
 
     private static final long TIMESTAMP = 1234L;
 
+    private static final int NET_ID = 110;
+
     private static final String CLAT_PREFIX = "v4-";
     private static final String MOBILE_IFNAME = "test_rmnet_data0";
     private static final String WIFI_IFNAME = "test_wlan0";
@@ -1017,6 +1019,7 @@
         private int mVpnType = VpnManager.TYPE_VPN_SERVICE;
 
         private VpnInfo mVpnInfo;
+        private Network[] mUnderlyingNetworks;
 
         public MockVpn(int userId) {
             super(startHandlerThreadAndReturnLooper(), mServiceContext, mNetworkManagementService,
@@ -1106,9 +1109,21 @@
             return super.getVpnInfo();
         }
 
-        private void setVpnInfo(VpnInfo vpnInfo) {
+        private synchronized void setVpnInfo(VpnInfo vpnInfo) {
             mVpnInfo = vpnInfo;
         }
+
+        @Override
+        public synchronized Network[] getUnderlyingNetworks() {
+            if (mUnderlyingNetworks != null) return mUnderlyingNetworks;
+
+            return super.getUnderlyingNetworks();
+        }
+
+        /** Test-only override; {@link Vpn#setUnderlyingNetworks} itself is left unchanged. */
+        private synchronized void overrideUnderlyingNetworks(Network[] underlyingNetworks) {
+            mUnderlyingNetworks = underlyingNetworks;
+        }
     }
 
     private void mockVpn(int uid) {
@@ -6851,9 +6866,10 @@
 
     @Test
     public void testCheckConnectivityDiagnosticsPermissionsActiveVpn() throws Exception {
+        final Network network = new Network(NET_ID);
         final NetworkAgentInfo naiWithoutUid =
                 new NetworkAgentInfo(
-                        null, null, null, null, null, new NetworkCapabilities(), 0,
+                        null, null, network, null, null, new NetworkCapabilities(), 0,
                         mServiceContext, null, null, mService, null, null, null, 0);
 
         setupLocationPermissions(Build.VERSION_CODES.Q, true, AppOpsManager.OPSTR_FINE_LOCATION,
@@ -6866,11 +6882,19 @@
         info.ownerUid = Process.myUid();
         info.vpnIface = "interface";
         mMockVpn.setVpnInfo(info);
+        mMockVpn.overrideUnderlyingNetworks(new Network[] {network});
         assertTrue(
                 "Active VPN permission not applied",
                 mService.checkConnectivityDiagnosticsPermissions(
                         Process.myPid(), Process.myUid(), naiWithoutUid,
                         mContext.getOpPackageName()));
+
+        mMockVpn.overrideUnderlyingNetworks(null);
+        assertFalse(
+                "VPN shouldn't receive callback on non-underlying network",
+                mService.checkConnectivityDiagnosticsPermissions(
+                        Process.myPid(), Process.myUid(), naiWithoutUid,
+                        mContext.getOpPackageName()));
     }
 
     @Test