Merge "Improve getConnectionOwnerUid tests."
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index f2e1920..94e39da 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -222,6 +222,7 @@
 import java.io.PrintWriter;
 import java.net.Inet4Address;
 import java.net.InetAddress;
+import java.net.InetSocketAddress;
 import java.net.UnknownHostException;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -990,6 +991,15 @@
         }
 
         /**
+         * Gets the UID that owns a socket connection. Wrapped here so tests can mock it: opening
+         * SOCK_DIAG sockets requires CAP_NET_ADMIN, which the unit tests do not have.
+         */
+        public int getConnectionOwnerUid(int protocol, InetSocketAddress local,
+                InetSocketAddress remote) {
+            return InetDiagMessage.getConnectionOwnerUid(protocol, local, remote);
+        }
+
+        /**
          * @see MultinetworkPolicyTracker
          */
         public MultinetworkPolicyTracker makeMultinetworkPolicyTracker(
@@ -8350,7 +8360,7 @@
             throw new IllegalArgumentException("Unsupported protocol " + connectionInfo.protocol);
         }
 
-        final int uid = InetDiagMessage.getConnectionOwnerUid(connectionInfo.protocol,
+        final int uid = mDeps.getConnectionOwnerUid(connectionInfo.protocol,
                 connectionInfo.local, connectionInfo.remote);
 
         /* Filter out Uids not associated with the VPN. */
diff --git a/tests/net/java/com/android/server/ConnectivityServiceTest.java b/tests/net/java/com/android/server/ConnectivityServiceTest.java
index b0cc7f1..88a4377 100644
--- a/tests/net/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/net/java/com/android/server/ConnectivityServiceTest.java
@@ -8355,13 +8355,14 @@
     private void setupConnectionOwnerUid(int vpnOwnerUid, @VpnManager.VpnType int vpnType)
             throws Exception {
         final Set<UidRange> vpnRange = Collections.singleton(UidRange.createForUser(PRIMARY_USER));
+        mMockVpn.setVpnType(vpnType);
         mMockVpn.establish(new LinkProperties(), vpnOwnerUid, vpnRange);
         assertVpnUidRangesUpdated(true, vpnRange, vpnOwnerUid);
-        mMockVpn.setVpnType(vpnType);
 
         final UnderlyingNetworkInfo underlyingNetworkInfo =
                 new UnderlyingNetworkInfo(vpnOwnerUid, VPN_IFNAME, new ArrayList<String>());
         mMockVpn.setUnderlyingNetworkInfo(underlyingNetworkInfo);
+        when(mDeps.getConnectionOwnerUid(anyInt(), any(), any())).thenReturn(42);
     }
 
     private void setupConnectionOwnerUidAsVpnApp(int vpnOwnerUid, @VpnManager.VpnType int vpnType)
@@ -8410,8 +8411,7 @@
         final int myUid = Process.myUid();
         setupConnectionOwnerUidAsVpnApp(myUid, VpnManager.TYPE_VPN_SERVICE);
 
-        // TODO: Test the returned UID
-        mService.getConnectionOwnerUid(getTestConnectionInfo());
+        assertEquals(42, mService.getConnectionOwnerUid(getTestConnectionInfo()));
     }
 
     @Test
@@ -8421,8 +8421,7 @@
         mServiceContext.setPermission(
                 android.Manifest.permission.NETWORK_STACK, PERMISSION_GRANTED);
 
-        // TODO: Test the returned UID
-        mService.getConnectionOwnerUid(getTestConnectionInfo());
+        assertEquals(42, mService.getConnectionOwnerUid(getTestConnectionInfo()));
     }
 
     @Test
@@ -8433,8 +8432,7 @@
         mServiceContext.setPermission(
                 NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK, PERMISSION_GRANTED);
 
-        // TODO: Test the returned UID
-        mService.getConnectionOwnerUid(getTestConnectionInfo());
+        assertEquals(42, mService.getConnectionOwnerUid(getTestConnectionInfo()));
     }
 
     private static PackageInfo buildPackageInfo(boolean hasSystemPermission, int uid) {