Check the calling user instead of the current user.

In concurrent multi-user case, the calling user may be different from
the current user, thus we cannot assume that it is always the current
user to check if the calling user is an admin user. Instead, we use the
calling user id from the binder, before the caller identity is cleared.

Bug: 315382159

Test: atest android.bugreport.cts.BugreportManagerTest

Change-Id: I003037d37b72c47dc4be7c85f634116f9bc4ca63
diff --git a/services/core/java/com/android/server/os/BugreportManagerServiceImpl.java b/services/core/java/com/android/server/os/BugreportManagerServiceImpl.java
index 1660c3e..e546f42 100644
--- a/services/core/java/com/android/server/os/BugreportManagerServiceImpl.java
+++ b/services/core/java/com/android/server/os/BugreportManagerServiceImpl.java
@@ -21,13 +21,11 @@
 import android.Manifest;
 import android.annotation.Nullable;
 import android.annotation.RequiresPermission;
-import android.app.ActivityManager;
 import android.app.AppOpsManager;
 import android.app.admin.DevicePolicyManager;
 import android.app.role.RoleManager;
 import android.content.Context;
 import android.content.pm.PackageManager;
-import android.content.pm.UserInfo;
 import android.os.Binder;
 import android.os.BugreportManager.BugreportCallback;
 import android.os.BugreportParams;
@@ -39,6 +37,7 @@
 import android.os.SystemClock;
 import android.os.SystemProperties;
 import android.os.UserHandle;
+import android.os.UserManager;
 import android.telephony.TelephonyManager;
 import android.text.TextUtils;
 import android.util.ArrayMap;
@@ -96,6 +95,7 @@
     private static final long DEFAULT_BUGREPORT_SERVICE_TIMEOUT_MILLIS = 30 * 1000;
 
     private final Object mLock = new Object();
+    private final Injector mInjector;
     private final Context mContext;
     private final AppOpsManager mAppOps;
     private final TelephonyManager mTelephonyManager;
@@ -346,6 +346,14 @@
         AtomicFile getMappingFile() {
             return mMappingFile;
         }
+
+        UserManager getUserManager() {
+            return mContext.getSystemService(UserManager.class);
+        }
+
+        DevicePolicyManager getDevicePolicyManager() {
+            return mContext.getSystemService(DevicePolicyManager.class);
+        }
     }
 
     BugreportManagerServiceImpl(Context context) {
@@ -357,6 +365,7 @@
 
     @VisibleForTesting(visibility = VisibleForTesting.Visibility.PRIVATE)
     BugreportManagerServiceImpl(Injector injector) {
+        mInjector = injector;
         mContext = injector.getContext();
         mAppOps = mContext.getSystemService(AppOpsManager.class);
         mTelephonyManager = mContext.getSystemService(TelephonyManager.class);
@@ -389,12 +398,7 @@
         int callingUid = Binder.getCallingUid();
         enforcePermission(callingPackage, callingUid, bugreportMode
                 == BugreportParams.BUGREPORT_MODE_TELEPHONY /* checkCarrierPrivileges */);
-        final long identity = Binder.clearCallingIdentity();
-        try {
-            ensureUserCanTakeBugReport(bugreportMode);
-        } finally {
-            Binder.restoreCallingIdentity(identity);
-        }
+        ensureUserCanTakeBugReport(bugreportMode);
 
         Slogf.i(TAG, "Starting bugreport for %s / %d", callingPackage, callingUid);
         synchronized (mLock) {
@@ -433,7 +437,6 @@
     @RequiresPermission(value = Manifest.permission.DUMP, conditional = true)
     public void retrieveBugreport(int callingUidUnused, String callingPackage, int userId,
             FileDescriptor bugreportFd, String bugreportFile,
-
             boolean keepBugreportOnRetrievalUnused, IDumpstateListener listener) {
         int callingUid = Binder.getCallingUid();
         enforcePermission(callingPackage, callingUid, false);
@@ -565,54 +568,48 @@
     }
 
     /**
-     * Validates that the current user is an admin user or, when bugreport is requested remotely
-     * that the current user is an affiliated user.
+     * Validates that the calling user is an admin user or, when bugreport is requested remotely
+     * that the user is an affiliated user.
      *
-     * @throws IllegalArgumentException if the current user is not an admin user
+     * @throws IllegalArgumentException if the calling user is not an admin user
      */
     private void ensureUserCanTakeBugReport(int bugreportMode) {
-        UserInfo currentUser = null;
+        // Get the calling userId before clearing the caller identity.
+        int callingUserId = UserHandle.getUserId(Binder.getCallingUid());
+        boolean isAdminUser = false;
+        final long identity = Binder.clearCallingIdentity();
         try {
-            currentUser = ActivityManager.getService().getCurrentUser();
-        } catch (RemoteException e) {
-            // Impossible to get RemoteException for an in-process call.
+            isAdminUser = mInjector.getUserManager().isUserAdmin(callingUserId);
+        } finally {
+            Binder.restoreCallingIdentity(identity);
         }
-
-        if (currentUser == null) {
-            logAndThrow("There is no current user, so no bugreport can be requested.");
-        }
-
-        if (!currentUser.isAdmin()) {
+        if (!isAdminUser) {
             if (bugreportMode == BugreportParams.BUGREPORT_MODE_REMOTE
-                    && isCurrentUserAffiliated(currentUser.id)) {
+                    && isUserAffiliated(callingUserId)) {
                 return;
             }
-            logAndThrow(TextUtils.formatSimple("Current user %s is not an admin user."
-                    + " Only admin users are allowed to take bugreport.", currentUser.id));
+            logAndThrow(TextUtils.formatSimple("Calling user %s is not an admin user."
+                    + " Only admin users are allowed to take bugreport.", callingUserId));
         }
     }
 
     /**
-     * Returns {@code true} if the device has device owner and the current user is affiliated
+     * Returns {@code true} if the device has device owner and the specified user is affiliated
      * with the device owner.
      */
-    private boolean isCurrentUserAffiliated(int currentUserId) {
-        DevicePolicyManager dpm = mContext.getSystemService(DevicePolicyManager.class);
+    private boolean isUserAffiliated(int userId) {
+        DevicePolicyManager dpm = mInjector.getDevicePolicyManager();
         int deviceOwnerUid = dpm.getDeviceOwnerUserId();
         if (deviceOwnerUid == UserHandle.USER_NULL) {
             return false;
         }
 
-        int callingUserId = UserHandle.getUserId(Binder.getCallingUid());
-
-        Slog.i(TAG, "callingUid: " + callingUserId + " deviceOwnerUid: " + deviceOwnerUid
-                + " currentUserId: " + currentUserId);
-
-        if (callingUserId != deviceOwnerUid) {
-            logAndThrow("Caller is not device owner on provisioned device.");
+        if (DEBUG) {
+            Slog.d(TAG, "callingUid: " + userId + " deviceOwnerUid: " + deviceOwnerUid);
         }
-        if (!dpm.isAffiliatedUser(currentUserId)) {
-            logAndThrow("Current user is not affiliated to the device owner.");
+
+        if (userId != deviceOwnerUid && !dpm.isAffiliatedUser(userId)) {
+            logAndThrow("User " + userId + " is not affiliated to the device owner.");
         }
         return true;
     }
diff --git a/services/tests/servicestests/src/com/android/server/os/BugreportManagerServiceImplTest.java b/services/tests/servicestests/src/com/android/server/os/BugreportManagerServiceImplTest.java
index dc1d2c5..21b8a94 100644
--- a/services/tests/servicestests/src/com/android/server/os/BugreportManagerServiceImplTest.java
+++ b/services/tests/servicestests/src/com/android/server/os/BugreportManagerServiceImplTest.java
@@ -16,23 +16,26 @@
 
 package com.android.server.os;
 
-import android.app.admin.flags.Flags;
-import static android.app.admin.flags.Flags.onboardingBugreportV2Enabled;
-
 import static com.android.compatibility.common.util.SystemUtil.runWithShellPermissionIdentity;
 
 import static com.google.common.truth.Truth.assertThat;
 
 import static org.junit.Assert.assertThrows;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.Mockito.when;
 
+import android.app.admin.DevicePolicyManager;
+import android.app.admin.flags.Flags;
 import android.app.role.RoleManager;
 import android.content.Context;
 import android.os.Binder;
 import android.os.BugreportManager.BugreportCallback;
+import android.os.BugreportParams;
 import android.os.IBinder;
 import android.os.IDumpstateListener;
 import android.os.Process;
 import android.os.RemoteException;
+import android.os.UserManager;
 import android.platform.test.annotations.RequiresFlagsEnabled;
 import android.platform.test.flag.junit.CheckFlagsRule;
 import android.platform.test.flag.junit.DeviceFlagsValueProvider;
@@ -48,6 +51,8 @@
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
 
 import java.io.FileDescriptor;
 import java.util.concurrent.CompletableFuture;
@@ -66,6 +71,11 @@
     private BugreportManagerServiceImpl mService;
     private BugreportManagerServiceImpl.BugreportFileManager mBugreportFileManager;
 
+    @Mock
+    private UserManager mMockUserManager;
+    @Mock
+    private DevicePolicyManager mMockDevicePolicyManager;
+
     private int mCallingUid = 1234;
     private String mCallingPackage  = "test.package";
     private AtomicFile mMappingFile;
@@ -75,14 +85,17 @@
 
     @Before
     public void setUp() {
+        MockitoAnnotations.initMocks(this);
         mContext = InstrumentationRegistry.getInstrumentation().getContext();
         mMappingFile = new AtomicFile(mContext.getFilesDir(), "bugreport-mapping.xml");
         ArraySet<String> mAllowlistedPackages = new ArraySet<>();
         mAllowlistedPackages.add(mContext.getPackageName());
         mService = new BugreportManagerServiceImpl(
-                new BugreportManagerServiceImpl.Injector(mContext, mAllowlistedPackages,
-                        mMappingFile));
+                new TestInjector(mContext, mAllowlistedPackages, mMappingFile,
+                        mMockUserManager, mMockDevicePolicyManager));
         mBugreportFileManager = new BugreportManagerServiceImpl.BugreportFileManager(mMappingFile);
+        // The calling user is an admin user by default.
+        when(mMockUserManager.isUserAdmin(anyInt())).thenReturn(true);
     }
 
     @After
@@ -165,6 +178,36 @@
     }
 
     @Test
+    public void testStartBugreport_throwsForNonAdminUser() throws Exception {
+        when(mMockUserManager.isUserAdmin(anyInt())).thenReturn(false);
+
+        Exception thrown = assertThrows(Exception.class,
+                () -> mService.startBugreport(mCallingUid, mContext.getPackageName(),
+                        new FileDescriptor(), /* screenshotFd= */ null,
+                        BugreportParams.BUGREPORT_MODE_FULL,
+                        /* flags= */ 0, new Listener(new CountDownLatch(1)),
+                        /* isScreenshotRequested= */ false));
+
+        assertThat(thrown.getMessage()).contains("not an admin user");
+    }
+
+    @Test
+    public void testStartBugreport_throwsForNotAffiliatedUser() throws Exception {
+        when(mMockUserManager.isUserAdmin(anyInt())).thenReturn(false);
+        when(mMockDevicePolicyManager.getDeviceOwnerUserId()).thenReturn(-1);
+        when(mMockDevicePolicyManager.isAffiliatedUser(anyInt())).thenReturn(false);
+
+        Exception thrown = assertThrows(Exception.class,
+                () -> mService.startBugreport(mCallingUid, mContext.getPackageName(),
+                        new FileDescriptor(), /* screenshotFd= */ null,
+                        BugreportParams.BUGREPORT_MODE_REMOTE,
+                        /* flags= */ 0, new Listener(new CountDownLatch(1)),
+                        /* isScreenshotRequested= */ false));
+
+        assertThat(thrown.getMessage()).contains("not affiliated to the device owner");
+    }
+
+    @Test
     public void testRetrieveBugreportWithoutFilesForCaller() throws Exception {
         CountDownLatch latch = new CountDownLatch(1);
         Listener listener = new Listener(latch);
@@ -207,7 +250,8 @@
 
     private void clearAllowlist() {
         mService = new BugreportManagerServiceImpl(
-                new BugreportManagerServiceImpl.Injector(mContext, new ArraySet<>(), mMappingFile));
+                new TestInjector(mContext, new ArraySet<>(), mMappingFile,
+                        mMockUserManager, mMockDevicePolicyManager));
     }
 
     private static class Listener implements IDumpstateListener {
@@ -258,4 +302,27 @@
             complete(successful);
         }
     }
+
+    private static class TestInjector extends BugreportManagerServiceImpl.Injector {
+
+        private final UserManager mUserManager;
+        private final DevicePolicyManager mDevicePolicyManager;
+
+        TestInjector(Context context, ArraySet<String> allowlistedPackages, AtomicFile mappingFile,
+                UserManager um, DevicePolicyManager dpm) {
+            super(context, allowlistedPackages, mappingFile);
+            mUserManager = um;
+            mDevicePolicyManager = dpm;
+        }
+
+        @Override
+        public UserManager getUserManager() {
+            return mUserManager;
+        }
+
+        @Override
+        public DevicePolicyManager getDevicePolicyManager() {
+            return mDevicePolicyManager;
+        }
+    }
 }