Merge "Lock down networking when waiting for always-on" into nyc-dev
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index b5c2b89..c096fa5 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -915,6 +915,13 @@
         final boolean networkMetered;
         final int uidRules;
 
+        synchronized (mVpns) {
+            // The user's VPN may block this UID outright (e.g. always-on VPN in lockdown).
+            final Vpn userVpn = mVpns.get(UserHandle.getUserId(uid));
+            final boolean blockedByVpn = (userVpn != null) && userVpn.isBlockingUid(uid);
+            if (blockedByVpn) return true;
+        }
+
         final String iface = (lp == null ? "" : lp.getInterfaceName());
         synchronized (mRulesLock) {
             networkMetered = mMeteredIfaces.contains(iface);
@@ -3365,23 +3372,42 @@
     }
 
     /**
-     * Sets up or tears down the always-on VPN for user {@param user} as appropriate.
+     * Starts the always-on VPN {@link VpnService} for user {@code userId}, which should perform
+     * some setup and then call {@code establish()} to connect.
      *
-     * @return {@code false} in case of errors; {@code true} otherwise.
+     * @return {@code true} if the service was started, the service was already connected, or there
+     *         was no always-on VPN to start. {@code false} otherwise.
      */
-    private boolean updateAlwaysOnVpn(int user) {
-        final String lockdownPackage = getAlwaysOnVpnPackage(user);
-        if (lockdownPackage == null) {
-            return true;
+    private boolean startAlwaysOnVpn(int userId) {
+        final String alwaysOnPackage;
+        synchronized (mVpns) {
+            Vpn vpn = mVpns.get(userId);
+            if (vpn == null) {
+                // Shouldn't happen as all codepaths that point here should have checked the Vpn
+                // exists already.
+                Slog.wtf(TAG, "User " + userId + " has no Vpn configuration");
+                return false;
+            }
+            alwaysOnPackage = vpn.getAlwaysOnPackage();
+            // Skip if there is no service to start.
+            if (alwaysOnPackage == null) {
+                return true;
+            }
+            // Skip if the service is already established. This isn't bulletproof: it's not bound
+            // until after establish(), so if it's mid-setup onStartCommand will be sent twice,
+            // which may restart the connection.
+            if (vpn.getNetworkInfo().isConnected()) {
+                return true;
+            }
         }
 
-        // Create an intent to start the VPN service declared in the app's manifest.
+        // Start the VPN service declared in the app's manifest.
         Intent serviceIntent = new Intent(VpnConfig.SERVICE_INTERFACE);
-        serviceIntent.setPackage(lockdownPackage);
-
+        serviceIntent.setPackage(alwaysOnPackage);
         try {
-            return mContext.startServiceAsUser(serviceIntent, UserHandle.of(user)) != null;
+            return mContext.startServiceAsUser(serviceIntent, UserHandle.of(userId)) != null;
         } catch (RuntimeException e) {
+            Slog.w(TAG, "VpnService " + serviceIntent + " failed to start", e);
             return false;
         }
     }
@@ -3396,25 +3422,51 @@
             return false;
         }
 
-        // If the current VPN package is the same as the new one, this is a no-op
-        final String oldPackage = getAlwaysOnVpnPackage(userId);
-        if (TextUtils.equals(oldPackage, packageName)) {
-            return true;
-        }
-
         synchronized (mVpns) {
             Vpn vpn = mVpns.get(userId);
             if (vpn == null) {
                 Slog.w(TAG, "User " + userId + " has no Vpn configuration");
                 return false;
             }
-            if (!vpn.setAlwaysOnPackage(packageName)) {
+            // Return early only if nothing would change. Comparing the package alone would
+            // silently ignore a request that only toggles lockdown, so for a non-null package
+            // also compare against the lockdown flag last persisted to settings below.
+            if (TextUtils.equals(packageName, vpn.getAlwaysOnPackage())) {
+                if (packageName == null) {
+                    return true;
+                }
+                final boolean savedLockdown;
+                final long readToken = Binder.clearCallingIdentity();
+                try {
+                    savedLockdown = Settings.Secure.getIntForUser(mContext.getContentResolver(),
+                            Settings.Secure.ALWAYS_ON_VPN_LOCKDOWN, /* default */ 0, userId) != 0;
+                } finally {
+                    Binder.restoreCallingIdentity(readToken);
+                }
+                if (savedLockdown == lockdown) {
+                    return true;
+                }
+            }
+            if (!vpn.setAlwaysOnPackage(packageName, lockdown)) {
                 return false;
             }
-            if (!updateAlwaysOnVpn(userId)) {
-                vpn.setAlwaysOnPackage(null);
+            if (!startAlwaysOnVpn(userId)) {
+                // The service could not be started: clear the always-on package set above.
+                vpn.setAlwaysOnPackage(null, false);
                 return false;
             }
+
+            // Save the configuration
+            final long token = Binder.clearCallingIdentity();
+            try {
+                final ContentResolver cr = mContext.getContentResolver();
+                Settings.Secure.putStringForUser(cr, Settings.Secure.ALWAYS_ON_VPN_APP,
+                        packageName, userId);
+                Settings.Secure.putIntForUser(cr, Settings.Secure.ALWAYS_ON_VPN_LOCKDOWN,
+                        (lockdown ? 1 : 0), userId);
+            } finally {
+                Binder.restoreCallingIdentity(token);
+            }
         }
         return true;
     }
@@ -3685,11 +3721,18 @@
             }
             userVpn = new Vpn(mHandler.getLooper(), mContext, mNetd, userId);
             mVpns.put(userId, userVpn);
+
+            // Re-apply the always-on VPN configuration persisted for this user, if any.
+            final ContentResolver resolver = mContext.getContentResolver();
+            final String savedPackage = Settings.Secure.getStringForUser(resolver,
+                    Settings.Secure.ALWAYS_ON_VPN_APP, userId);
+            final int savedLockdownFlag = Settings.Secure.getIntForUser(resolver,
+                    Settings.Secure.ALWAYS_ON_VPN_LOCKDOWN, /* default */ 0, userId);
+            final boolean savedLockdown = savedLockdownFlag != 0;
+            if (savedPackage != null) userVpn.setAlwaysOnPackage(savedPackage, savedLockdown);
         }
         if (mUserManager.getUserInfo(userId).isPrimary() && LockdownVpnTracker.isEnabled()) {
             updateLockdownVpn();
-        } else {
-            updateAlwaysOnVpn(userId);
         }
     }
 
@@ -3700,6 +3743,7 @@
                 loge("Stopped user has no VPN");
                 return;
             }
+            userVpn.onUserStopped();  // Notify the Vpn before it is removed from mVpns below.
             mVpns.delete(userId);
         }
     }
@@ -3729,7 +3773,7 @@
         if (mUserManager.getUserInfo(userId).isPrimary() && LockdownVpnTracker.isEnabled()) {
             updateLockdownVpn();
         } else {
-            updateAlwaysOnVpn(userId);
+            startAlwaysOnVpn(userId);  // Best effort: a failed start is not retried here.
         }
     }
 
diff --git a/services/tests/servicestests/src/com/android/server/connectivity/VpnTest.java b/services/tests/servicestests/src/com/android/server/connectivity/VpnTest.java
index 3295bf5..5d8b843 100644
--- a/services/tests/servicestests/src/com/android/server/connectivity/VpnTest.java
+++ b/services/tests/servicestests/src/com/android/server/connectivity/VpnTest.java
@@ -20,9 +20,11 @@
 import static android.content.pm.UserInfo.FLAG_MANAGED_PROFILE;
 import static android.content.pm.UserInfo.FLAG_PRIMARY;
 import static android.content.pm.UserInfo.FLAG_RESTRICTED;
+import static org.mockito.AdditionalMatchers.*;
 import static org.mockito.Mockito.*;
 
 import android.annotation.UserIdInt;
+import android.app.AppOpsManager;
 import android.content.Context;
 import android.content.pm.PackageManager;
 import android.content.pm.UserInfo;
@@ -65,16 +67,35 @@
         managedProfileA.profileGroupId = primaryUser.id;
     }
 
+    /**
+     * Fake package names and their UIDs, matched by index. The tests rely on:
+     *  - UIDs being listed in increasing order.
+     *  - Exactly one pair of packages having consecutive UIDs.
+     */
+    static final String[] PKGS = {"com.example", "org.example", "net.example", "web.vpn"};
+    static final int[] PKG_UIDS = {66, 77, 78, 400};
+
+    // Package name -> UID, handed to setMockedPackages() in setUp().
+    static final Map<String, Integer> mPackages = new ArrayMap<>();
+    static {
+        for (int index = 0; index < PKGS.length; ++index) {
+            mPackages.put(PKGS[index], PKG_UIDS[index]);
+        }
+    }
+
     @Mock private Context mContext;
     @Mock private UserManager mUserManager;
     @Mock private PackageManager mPackageManager;
     @Mock private INetworkManagementService mNetService;
+    @Mock private AppOpsManager mAppOps;
 
     @Override
     public void setUp() throws Exception {
         MockitoAnnotations.initMocks(this);
         when(mContext.getPackageManager()).thenReturn(mPackageManager);
+        setMockedPackages(mPackages);
         when(mContext.getSystemService(eq(Context.USER_SERVICE))).thenReturn(mUserManager);
+        when(mContext.getSystemService(eq(Context.APP_OPS_SERVICE))).thenReturn(mAppOps);
         doNothing().when(mNetService).registerObserver(any());
     }
 
@@ -82,7 +103,7 @@
     public void testRestrictedProfilesAreAddedToVpn() {
         setMockedUsers(primaryUser, secondaryUser, restrictedProfileA, restrictedProfileB);
 
-        final Vpn vpn = createVpn(primaryUser.id);
+        final MockVpn vpn = new MockVpn(primaryUser.id);
         final Set<UidRange> ranges = vpn.createUserAndRestrictedProfilesRanges(primaryUser.id,
                 null, null);
 
@@ -96,7 +117,7 @@
     public void testManagedProfilesAreNotAddedToVpn() {
         setMockedUsers(primaryUser, managedProfileA);
 
-        final Vpn vpn = createVpn(primaryUser.id);
+        final MockVpn vpn = new MockVpn(primaryUser.id);
         final Set<UidRange> ranges = vpn.createUserAndRestrictedProfilesRanges(primaryUser.id,
                 null, null);
 
@@ -109,7 +130,7 @@
     public void testAddUserToVpnOnlyAddsOneUser() {
         setMockedUsers(primaryUser, restrictedProfileA, managedProfileA);
 
-        final Vpn vpn = createVpn(primaryUser.id);
+        final MockVpn vpn = new MockVpn(primaryUser.id);
         final Set<UidRange> ranges = new ArraySet<>();
         vpn.addUserToRanges(ranges, primaryUser.id, null, null);
 
@@ -120,42 +141,128 @@
 
     @SmallTest
     public void testUidWhiteAndBlacklist() throws Exception {
-        final Map<String, Integer> packages = new ArrayMap<>();
-        packages.put("com.example", 66);
-        packages.put("org.example", 77);
-        packages.put("net.example", 78);
-        setMockedPackages(packages);
-
-        final Vpn vpn = createVpn(primaryUser.id);
+        final Vpn vpn = new MockVpn(primaryUser.id);
         final UidRange user = UidRange.createForUser(primaryUser.id);
+        final String[] packages = {PKGS[0], PKGS[1], PKGS[2]};
 
         // Whitelist
         final Set<UidRange> allow = vpn.createUserAndRestrictedProfilesRanges(primaryUser.id,
-                new ArrayList<String>(packages.keySet()), null);
+                Arrays.asList(packages), null);
         assertEquals(new ArraySet<>(Arrays.asList(new UidRange[] {
-            new UidRange(user.start + 66, user.start + 66),
-            new UidRange(user.start + 77, user.start + 78)
+            new UidRange(user.start + PKG_UIDS[0], user.start + PKG_UIDS[0]),
+            new UidRange(user.start + PKG_UIDS[1], user.start + PKG_UIDS[2])
         })), allow);
 
         // Blacklist
         final Set<UidRange> disallow = vpn.createUserAndRestrictedProfilesRanges(primaryUser.id,
-                null, new ArrayList<String>(packages.keySet()));
+                null, Arrays.asList(packages));
         assertEquals(new ArraySet<>(Arrays.asList(new UidRange[] {
-            new UidRange(user.start, user.start + 65),
-            new UidRange(user.start + 67, user.start + 76),
-            new UidRange(user.start + 79, user.stop)
+            new UidRange(user.start, user.start + PKG_UIDS[0] - 1),
+            new UidRange(user.start + PKG_UIDS[0] + 1, user.start + PKG_UIDS[1] - 1),
+            /* PKG_UIDS[1] and PKG_UIDS[2] are consecutive, so no range lies between them. */
+            new UidRange(user.start + PKG_UIDS[2] + 1, user.stop)
         })), disallow);
     }
 
+    @SmallTest
+    public void testLockdownChangingPackage() throws Exception {
+        final MockVpn vpn = new MockVpn(primaryUser.id);
+        final UidRange user = UidRange.createForUser(primaryUser.id);
+
+        // Default state: nothing is blocked.
+        vpn.assertUnblocked(user.start + PKG_UIDS[0], user.start + PKG_UIDS[1],
+                user.start + PKG_UIDS[2], user.start + PKG_UIDS[3]);
+
+        // Set always-on without lockdown.
+        assertTrue(vpn.setAlwaysOnPackage(PKGS[1], false));
+        vpn.assertUnblocked(user.start + PKG_UIDS[0], user.start + PKG_UIDS[1],
+                user.start + PKG_UIDS[2], user.start + PKG_UIDS[3]);
+
+        // Set always-on with lockdown.
+        assertTrue(vpn.setAlwaysOnPackage(PKGS[1], true));
+        verify(mNetService).setAllowOnlyVpnForUids(eq(true), aryEq(new UidRange[] {
+            new UidRange(user.start, user.start + PKG_UIDS[1] - 1),
+            new UidRange(user.start + PKG_UIDS[1] + 1, user.stop)
+        }));
+        vpn.assertBlocked(user.start + PKG_UIDS[0], user.start + PKG_UIDS[2],
+                user.start + PKG_UIDS[3]);
+        vpn.assertUnblocked(user.start + PKG_UIDS[1]);
+
+        // Switch to another app.
+        assertTrue(vpn.setAlwaysOnPackage(PKGS[3], true));
+        verify(mNetService).setAllowOnlyVpnForUids(eq(false), aryEq(new UidRange[] {
+            new UidRange(user.start, user.start + PKG_UIDS[1] - 1),
+            new UidRange(user.start + PKG_UIDS[1] + 1, user.stop)
+        }));
+        verify(mNetService).setAllowOnlyVpnForUids(eq(true), aryEq(new UidRange[] {
+            new UidRange(user.start, user.start + PKG_UIDS[3] - 1),
+            new UidRange(user.start + PKG_UIDS[3] + 1, user.stop)
+        }));
+        vpn.assertBlocked(user.start + PKG_UIDS[0], user.start + PKG_UIDS[1],
+                user.start + PKG_UIDS[2]);
+        vpn.assertUnblocked(user.start + PKG_UIDS[3]);
+    }
+
+    @SmallTest
+    public void testLockdownAddingAProfile() throws Exception {
+        // Mock the users before constructing the Vpn, as the other tests do.
+        setMockedUsers(primaryUser);
+        final MockVpn vpn = new MockVpn(primaryUser.id);
+
+        // Make a copy of the restricted profile, as we're going to mark it deleted halfway through.
+        final UserInfo tempProfile = new UserInfo(restrictedProfileA.id, restrictedProfileA.name,
+                restrictedProfileA.flags);
+        tempProfile.restrictedProfileParentId = primaryUser.id;
+
+        final UidRange user = UidRange.createForUser(primaryUser.id);
+        final UidRange profile = UidRange.createForUser(tempProfile.id);
+
+        // Set lockdown.
+        assertTrue(vpn.setAlwaysOnPackage(PKGS[3], true));
+        verify(mNetService).setAllowOnlyVpnForUids(eq(true), aryEq(new UidRange[] {
+            new UidRange(user.start, user.start + PKG_UIDS[3] - 1),
+            new UidRange(user.start + PKG_UIDS[3] + 1, user.stop)
+        }));
+
+        // Verify restricted user isn't affected at first.
+        vpn.assertUnblocked(profile.start + PKG_UIDS[0]);
+
+        // Add the restricted user.
+        setMockedUsers(primaryUser, tempProfile);
+        vpn.onUserAdded(tempProfile.id);
+        verify(mNetService).setAllowOnlyVpnForUids(eq(true), aryEq(new UidRange[] {
+            new UidRange(profile.start, profile.start + PKG_UIDS[3] - 1),
+            new UidRange(profile.start + PKG_UIDS[3] + 1, profile.stop)
+        }));
+
+        // Remove the restricted user.
+        tempProfile.partial = true;
+        vpn.onUserRemoved(tempProfile.id);
+        verify(mNetService).setAllowOnlyVpnForUids(eq(false), aryEq(new UidRange[] {
+            new UidRange(profile.start, profile.start + PKG_UIDS[3] - 1),
+            new UidRange(profile.start + PKG_UIDS[3] + 1, profile.stop)
+        }));
+    }
+
     /**
-     * @return A subclass of {@link Vpn} which is reliably:
-     * <ul>
-     *   <li>Associated with a specific user ID</li>
-     *   <li>Not in always-on mode</li>
-     * </ul>
+     * A {@link Vpn} built on this test's mocks, plus helpers asserting {@code isBlockingUid()}.
      */
-    private Vpn createVpn(@UserIdInt int userId) {
-        return new Vpn(Looper.myLooper(), mContext, mNetService, userId);
+    private class MockVpn extends Vpn {
+        public MockVpn(@UserIdInt int userId) {
+            super(Looper.myLooper(), mContext, mNetService, userId);
+        }
+
+        public void assertBlocked(int... uids) {
+            for (int uid : uids) {
+                assertTrue("Uid " + uid + " should be blocked", isBlockingUid(uid));
+            }
+        }
+
+        public void assertUnblocked(int... uids) {
+            for (int uid : uids) {
+                assertFalse("Uid " + uid + " should not be blocked", isBlockingUid(uid));
+            }
+        }
     }
 
     /**
@@ -167,9 +269,19 @@
             userMap.put(user.id, user);
         }
 
+        /**
+         * @see UserManagerService#getUsers(boolean)
+         */
         doAnswer(invocation -> {
-            return new ArrayList(userMap.values());
-        }).when(mUserManager).getUsers();
+            final boolean excludeDying = (boolean) invocation.getArguments()[0];
+            final ArrayList<UserInfo> result = new ArrayList<>(users.length);
+            for (UserInfo ui : users) {
+                if (!excludeDying || (ui.isEnabled() && !ui.partial)) {
+                    result.add(ui);
+                }
+            }
+            return result;
+        }).when(mUserManager).getUsers(anyBoolean());
 
         doAnswer(invocation -> {
             final int id = (int) invocation.getArguments()[0];