Merge "Replace user switching events with UserTracker" into tm-qpr-dev
diff --git a/packages/SystemUI/src/com/android/keyguard/KeyguardUpdateMonitor.java b/packages/SystemUI/src/com/android/keyguard/KeyguardUpdateMonitor.java
index a4c5d0b..84dd368 100644
--- a/packages/SystemUI/src/com/android/keyguard/KeyguardUpdateMonitor.java
+++ b/packages/SystemUI/src/com/android/keyguard/KeyguardUpdateMonitor.java
@@ -73,11 +73,9 @@
 import android.annotation.AnyThread;
 import android.annotation.MainThread;
 import android.annotation.SuppressLint;
-import android.app.ActivityManager;
 import android.app.ActivityTaskManager;
 import android.app.ActivityTaskManager.RootTaskInfo;
 import android.app.AlarmManager;
-import android.app.UserSwitchObserver;
 import android.app.admin.DevicePolicyManager;
 import android.app.trust.TrustManager;
 import android.content.BroadcastReceiver;
@@ -104,7 +102,6 @@
 import android.nfc.NfcAdapter;
 import android.os.CancellationSignal;
 import android.os.Handler;
-import android.os.IRemoteCallback;
 import android.os.Looper;
 import android.os.Message;
 import android.os.PowerManager;
@@ -175,6 +172,7 @@
 import java.util.Optional;
 import java.util.Set;
 import java.util.TimeZone;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.Executor;
 import java.util.stream.Collectors;
 
@@ -2158,7 +2156,7 @@
                         handleDevicePolicyManagerStateChanged(msg.arg1);
                         break;
                     case MSG_USER_SWITCHING:
-                        handleUserSwitching(msg.arg1, (IRemoteCallback) msg.obj);
+                        handleUserSwitching(msg.arg1, (CountDownLatch) msg.obj);
                         break;
                     case MSG_USER_SWITCH_COMPLETE:
                         handleUserSwitchComplete(msg.arg1);
@@ -2283,11 +2281,7 @@
                 mHandler, UserHandle.ALL);

         mSubscriptionManager.addOnSubscriptionsChangedListener(mSubscriptionListener);
-        try {
-            ActivityManager.getService().registerUserSwitchObserver(mUserSwitchObserver, TAG);
-        } catch (RemoteException e) {
-            e.rethrowAsRuntimeException();
-        }
+        mUserTracker.addCallback(mUserChangedCallback, mainExecutor);

         mTrustManager.registerTrustListener(this);

@@ -2423,17 +2417,17 @@
         return mIsFaceEnrolled;
     }

-    private final UserSwitchObserver mUserSwitchObserver = new UserSwitchObserver() {
+    private final UserTracker.Callback mUserChangedCallback = new UserTracker.Callback() {
         @Override
-        public void onUserSwitching(int newUserId, IRemoteCallback reply) {
+        public void onUserChanging(int newUser, Context userContext, CountDownLatch latch) {
             mHandler.sendMessage(mHandler.obtainMessage(MSG_USER_SWITCHING,
-                    newUserId, 0, reply));
+                    newUser, 0, latch)); // latch is counted down in handleUserSwitching
         }

         @Override
-        public void onUserSwitchComplete(int newUserId) {
+        public void onUserChanged(int newUser, Context userContext) {
             mHandler.sendMessage(mHandler.obtainMessage(MSG_USER_SWITCH_COMPLETE,
-                    newUserId, 0));
+                    newUser, 0));
         }
     };

@@ -3152,7 +3146,7 @@
      * Handle {@link #MSG_USER_SWITCHING}
      */
     @VisibleForTesting
-    void handleUserSwitching(int userId, IRemoteCallback reply) {
+    void handleUserSwitching(int userId, CountDownLatch latch) {
         Assert.isMainThread();
         clearBiometricRecognized();
         mUserTrustIsUsuallyManaged.put(userId, mTrustManager.isTrustUsuallyManaged(userId));
@@ -3162,11 +3156,7 @@
                 cb.onUserSwitching(userId);
             }
         }
-        try {
-            reply.sendResult(null);
-        } catch (RemoteException e) {
-            mLogger.logException(e, "Ignored exception while userSwitching");
-        }
+        latch.countDown(); // Must always run: UserTrackerImpl blocks on this latch.
     }

     /**
@@ -3936,13 +3926,7 @@
             mContext.getContentResolver().unregisterContentObserver(mTimeFormatChangeObserver);
         }

-        try {
-            ActivityManager.getService().unregisterUserSwitchObserver(mUserSwitchObserver);
-        } catch (RemoteException e) {
-            mLogger.logException(
-                    e,
-                    "RemoteException onDestroy. cannot unregister userSwitchObserver");
-        }
+        mUserTracker.removeCallback(mUserChangedCallback);

         TaskStackChangeListeners.getInstance().unregisterTaskStackListener(mTaskStackListener);

diff --git a/packages/SystemUI/src/com/android/systemui/settings/UserTracker.kt b/packages/SystemUI/src/com/android/systemui/settings/UserTracker.kt
index 287e810..33a3125 100644
--- a/packages/SystemUI/src/com/android/systemui/settings/UserTracker.kt
+++ b/packages/SystemUI/src/com/android/systemui/settings/UserTracker.kt
@@ -19,6 +19,7 @@
 import android.content.Context
 import android.content.pm.UserInfo
 import android.os.UserHandle
+import java.util.concurrent.CountDownLatch
 import java.util.concurrent.Executor
 
 /**
@@ -67,14 +68,29 @@
     interface Callback {

         /**
+         * Same as {@link #onUserChanging(Int, Context, CountDownLatch)} but the latch is counted
+         * down automatically once this method returns (or throws), so implementations of this
+         * overload never need to manage it.
+         */
+        @JvmDefault
+        fun onUserChanging(newUser: Int, userContext: Context) {}
+
+        /**
          * Notifies that the current user is being changed.
          * Override this method to run things while the screen is frozen for the user switch.
          * Please use {@link #onUserChanged} if the task doesn't need to push the unfreezing of the
          * screen further. Please be aware that code executed in this callback will lengthen the
-         * user switch duration.
+         * user switch duration. When overriding this method, countDown() MUST be called on the
+         * latch once execution is complete.
          */
         @JvmDefault
-        fun onUserChanging(newUser: Int, userContext: Context) {}
+        fun onUserChanging(newUser: Int, userContext: Context, latch: CountDownLatch) {
+            try {
+                onUserChanging(newUser, userContext)
+            } finally {
+                // Always release the latch; otherwise the user switch would wait on it forever.
+                latch.countDown()
+            }
+        }

         /**
          * Notifies that the current user has changed.
diff --git a/packages/SystemUI/src/com/android/systemui/settings/UserTrackerImpl.kt b/packages/SystemUI/src/com/android/systemui/settings/UserTrackerImpl.kt
index 9f551c6..8674036 100644
--- a/packages/SystemUI/src/com/android/systemui/settings/UserTrackerImpl.kt
+++ b/packages/SystemUI/src/com/android/systemui/settings/UserTrackerImpl.kt
@@ -183,9 +183,25 @@
         Log.i(TAG, "Switching to user $newUserId")

         setUserIdInternal(newUserId)
-        notifySubscribers {
-            onUserChanging(newUserId, userContext)
-        }.await()
+
+        val list = synchronized(callbacks) {
+            callbacks.toList()
+        }
+        // One count per callback: live callbacks count the latch down themselves (the default
+        // Callback implementation does it after onUserChanging(Int, Context) returns), and
+        // entries whose callback reference has been cleared are counted down right away.
+        val latch = CountDownLatch(list.size)
+        list.forEach {
+            val callback = it.callback.get()
+            if (callback != null) {
+                it.executor.execute {
+                    callback.onUserChanging(newUserId, userContext, latch)
+                }
+            } else {
+                latch.countDown()
+            }
+        }
+        latch.await()
     }

     @WorkerThread
@@ -225,25 +241,19 @@
         }
     }

-    private inline fun notifySubscribers(
-            crossinline action: UserTracker.Callback.() -> Unit
-    ): CountDownLatch {
+    /** Runs [action] on every live callback's executor without waiting for it to complete. */
+    private inline fun notifySubscribers(crossinline action: UserTracker.Callback.() -> Unit) {
         val list = synchronized(callbacks) {
             callbacks.toList()
         }
-        val latch = CountDownLatch(list.size)

         list.forEach {
             if (it.callback.get() != null) {
                 it.executor.execute {
                     it.callback.get()?.action()
-                    latch.countDown()
                 }
-            } else {
-                latch.countDown()
             }
         }
-        return latch
     }

     override fun dump(pw: PrintWriter, args: Array<out String>) {
diff --git a/packages/SystemUI/src/com/android/systemui/statusbar/notification/InstantAppNotifier.java b/packages/SystemUI/src/com/android/systemui/statusbar/notification/InstantAppNotifier.java
index 0a5e986..11582d7 100644
--- a/packages/SystemUI/src/com/android/systemui/statusbar/notification/InstantAppNotifier.java
+++ b/packages/SystemUI/src/com/android/systemui/statusbar/notification/InstantAppNotifier.java
@@ -29,7 +29,6 @@
 import android.app.Notification;
 import android.app.NotificationManager;
 import android.app.PendingIntent;
-import android.app.SynchronousUserSwitchObserver;
 import android.content.ComponentName;
 import android.content.Context;
 import android.content.Intent;
@@ -52,7 +51,9 @@
 import com.android.systemui.CoreStartable;
 import com.android.systemui.R;
 import com.android.systemui.dagger.SysUISingleton;
+import com.android.systemui.dagger.qualifiers.Main;
 import com.android.systemui.dagger.qualifiers.UiBackground;
+import com.android.systemui.settings.UserTracker;
 import com.android.systemui.statusbar.CommandQueue;
 import com.android.systemui.statusbar.policy.KeyguardStateController;
 import com.android.systemui.util.NotificationChannels;
@@ -73,6 +74,8 @@

     private final Context mContext;
     private final Handler mHandler = new Handler();
+    private final UserTracker mUserTracker; // source of user-switch callbacks
+    private final Executor mMainExecutor; // executor mUserSwitchListener is invoked on
     private final Executor mUiBgExecutor;
     private final ArraySet<Pair<String, Integer>> mCurrentNotifs = new ArraySet<>();
     private final CommandQueue mCommandQueue;
@@ -82,10 +85,14 @@
     public InstantAppNotifier(
             Context context,
             CommandQueue commandQueue,
+            UserTracker userTracker,
+            @Main Executor mainExecutor,
             @UiBackground Executor uiBgExecutor,
             KeyguardStateController keyguardStateController) {
         mContext = context;
         mCommandQueue = commandQueue;
+        mUserTracker = userTracker;
+        mMainExecutor = mainExecutor;
         mUiBgExecutor = uiBgExecutor;
         mKeyguardStateController = keyguardStateController;
     }
@@ -93,11 +100,7 @@
     @Override
     public void start() {
         // listen for user / profile change.
-        try {
-            ActivityManager.getService().registerUserSwitchObserver(mUserSwitchListener, TAG);
-        } catch (RemoteException e) {
-            // Ignore
-        }
+        mUserTracker.addCallback(mUserSwitchListener, mMainExecutor);

         mCommandQueue.addCallback(this);
         mKeyguardStateController.addCallback(this);
@@ -129,13 +132,10 @@
         updateForegroundInstantApps();
     }

-    private final SynchronousUserSwitchObserver mUserSwitchListener =
-            new SynchronousUserSwitchObserver() {
+    private final UserTracker.Callback mUserSwitchListener =
+            new UserTracker.Callback() {
                 @Override
-                public void onUserSwitching(int newUserId) throws RemoteException {}
-
-                @Override
-                public void onUserSwitchComplete(int newUserId) throws RemoteException {
+                public void onUserChanged(int newUser, Context userContext) {
                     mHandler.post(
                             () -> {
                                 updateForegroundInstantApps();
diff --git a/packages/SystemUI/src/com/android/systemui/statusbar/phone/PhoneStatusBarPolicy.java b/packages/SystemUI/src/com/android/systemui/statusbar/phone/PhoneStatusBarPolicy.java
index 6c532a5..e6b76ad 100644
--- a/packages/SystemUI/src/com/android/systemui/statusbar/phone/PhoneStatusBarPolicy.java
+++ b/packages/SystemUI/src/com/android/systemui/statusbar/phone/PhoneStatusBarPolicy.java
@@ -22,8 +22,6 @@
 import android.app.ActivityTaskManager;
 import android.app.AlarmManager;
 import android.app.AlarmManager.AlarmClockInfo;
-import android.app.IActivityManager;
-import android.app.SynchronousUserSwitchObserver;
 import android.app.admin.DevicePolicyManager;
 import android.content.BroadcastReceiver;
 import android.content.Context;
@@ -134,7 +132,6 @@
     private final NextAlarmController mNextAlarmController;
     private final AlarmManager mAlarmManager;
     private final UserInfoController mUserInfoController;
-    private final IActivityManager mIActivityManager;
     private final UserManager mUserManager;
     private final UserTracker mUserTracker;
     private final DevicePolicyManager mDevicePolicyManager;
@@ -149,6 +146,7 @@
     private final KeyguardStateController mKeyguardStateController;
     private final LocationController mLocationController;
     private final PrivacyItemController mPrivacyItemController;
+    private final Executor mMainExecutor; // executor mUserSwitchListener is invoked on
     private final Executor mUiBgExecutor;
     private final SensorPrivacyController mSensorPrivacyController;
     private final RecordingController mRecordingController;
@@ -168,16 +166,17 @@
     @Inject
     public PhoneStatusBarPolicy(StatusBarIconController iconController,
             CommandQueue commandQueue, BroadcastDispatcher broadcastDispatcher,
-            @UiBackground Executor uiBgExecutor, @Main Looper looper, @Main Resources resources,
-            CastController castController, HotspotController hotspotController,
-            BluetoothController bluetoothController, NextAlarmController nextAlarmController,
-            UserInfoController userInfoController, RotationLockController rotationLockController,
-            DataSaverController dataSaverController, ZenModeController zenModeController,
+            @Main Executor mainExecutor, @UiBackground Executor uiBgExecutor, @Main Looper looper,
+            @Main Resources resources, CastController castController,
+            HotspotController hotspotController, BluetoothController bluetoothController,
+            NextAlarmController nextAlarmController, UserInfoController userInfoController,
+            RotationLockController rotationLockController, DataSaverController dataSaverController,
+            ZenModeController zenModeController,
             DeviceProvisionedController deviceProvisionedController,
             KeyguardStateController keyguardStateController,
             LocationController locationController,
-            SensorPrivacyController sensorPrivacyController, IActivityManager iActivityManager,
-            AlarmManager alarmManager, UserManager userManager, UserTracker userTracker,
+            SensorPrivacyController sensorPrivacyController, AlarmManager alarmManager,
+            UserManager userManager, UserTracker userTracker,
             DevicePolicyManager devicePolicyManager, RecordingController recordingController,
             @Nullable TelecomManager telecomManager, @DisplayId int displayId,
             @Main SharedPreferences sharedPreferences, DateFormatUtil dateFormatUtil,
@@ -195,7 +194,6 @@
         mNextAlarmController = nextAlarmController;
         mAlarmManager = alarmManager;
         mUserInfoController = userInfoController;
-        mIActivityManager = iActivityManager;
         mUserManager = userManager;
         mUserTracker = userTracker;
         mDevicePolicyManager = devicePolicyManager;
@@ -208,6 +206,7 @@
         mPrivacyItemController = privacyItemController;
         mSensorPrivacyController = sensorPrivacyController;
         mRecordingController = recordingController;
+        mMainExecutor = mainExecutor;
         mUiBgExecutor = uiBgExecutor;
         mTelecomManager = telecomManager;
         mRingerModeTracker = ringerModeTracker;
@@ -256,11 +255,7 @@
         mRingerModeTracker.getRingerModeInternal().observeForever(observer);

         // listen for user / profile change.
-        try {
-            mIActivityManager.registerUserSwitchObserver(mUserSwitchListener, TAG);
-        } catch (RemoteException e) {
-            // Ignore
-        }
+        mUserTracker.addCallback(mUserSwitchListener, mMainExecutor);

         // TTY status
         updateTTY();
@@ -555,15 +550,15 @@
         });
     }

-    private final SynchronousUserSwitchObserver mUserSwitchListener =
-            new SynchronousUserSwitchObserver() {
+    private final UserTracker.Callback mUserSwitchListener =
+            new UserTracker.Callback() {
                 @Override
-                public void onUserSwitching(int newUserId) throws RemoteException {
+                public void onUserChanging(int newUser, Context userContext) {
                     mHandler.post(() -> mUserInfoController.reloadUserInfo());
                 }

                 @Override
-                public void onUserSwitchComplete(int newUserId) throws RemoteException {
+                public void onUserChanged(int newUser, Context userContext) {
                     mHandler.post(() -> {
                         updateAlarm();
                         updateManagedProfile();
diff --git a/packages/SystemUI/src/com/android/systemui/user/data/repository/UserRepository.kt b/packages/SystemUI/src/com/android/systemui/user/data/repository/UserRepository.kt
index ad1e5fe..b2b7c0b 100644
--- a/packages/SystemUI/src/com/android/systemui/user/data/repository/UserRepository.kt
+++ b/packages/SystemUI/src/com/android/systemui/user/data/repository/UserRepository.kt
@@ -17,11 +17,8 @@
 
 package com.android.systemui.user.data.repository
 
-import android.app.IActivityManager
-import android.app.UserSwitchObserver
 import android.content.Context
 import android.content.pm.UserInfo
-import android.os.IRemoteCallback
 import android.os.UserHandle
 import android.os.UserManager
 import android.provider.Settings
@@ -118,7 +115,6 @@
     @Background private val backgroundDispatcher: CoroutineDispatcher,
     private val globalSettings: GlobalSettings,
     private val tracker: UserTracker,
-    private val activityManager: IActivityManager,
     featureFlags: FeatureFlags,
 ) : UserRepository {

@@ -203,18 +199,19 @@
     private fun observeUserSwitching() {
         conflatedCallbackFlow {
+                // Emits true when a user switch starts and false when it completes.
                 val callback =
-                    object : UserSwitchObserver() {
-                        override fun onUserSwitching(newUserId: Int, reply: IRemoteCallback) {
+                    object : UserTracker.Callback {
+                        override fun onUserChanging(newUser: Int, userContext: Context) {
                             trySendWithFailureLogging(true, TAG, "userSwitching started")
                         }

-                        override fun onUserSwitchComplete(newUserId: Int) {
+                        override fun onUserChanged(newUser: Int, userContext: Context) {
                             trySendWithFailureLogging(false, TAG, "userSwitching completed")
                         }
                     }
-                activityManager.registerUserSwitchObserver(callback, TAG)
+                tracker.addCallback(callback, mainDispatcher.asExecutor())
                 trySendWithFailureLogging(false, TAG, "initial value defaulting to false")
-                awaitClose { activityManager.unregisterUserSwitchObserver(callback) }
+                awaitClose { tracker.removeCallback(callback) }
             }
             .onEach { _isUserSwitchingInProgress.value = it }
             // TODO (b/262838215), Make this stateIn and initialize directly in field declaration
diff --git a/packages/SystemUI/tests/src/com/android/keyguard/KeyguardUpdateMonitorTest.java b/packages/SystemUI/tests/src/com/android/keyguard/KeyguardUpdateMonitorTest.java
index c363026..250f221 100644
--- a/packages/SystemUI/tests/src/com/android/keyguard/KeyguardUpdateMonitorTest.java
+++ b/packages/SystemUI/tests/src/com/android/keyguard/KeyguardUpdateMonitorTest.java
@@ -57,8 +57,6 @@
 import static org.mockito.Mockito.when;
 
 import android.app.Activity;
-import android.app.ActivityManager;
-import android.app.IActivityManager;
 import android.app.admin.DevicePolicyManager;
 import android.app.trust.IStrongAuthTracker;
 import android.app.trust.TrustManager;
@@ -91,7 +89,6 @@
 import android.os.Bundle;
 import android.os.CancellationSignal;
 import android.os.Handler;
-import android.os.IRemoteCallback;
 import android.os.PowerManager;
 import android.os.RemoteException;
 import android.os.UserHandle;
@@ -149,6 +146,7 @@
 import java.util.Arrays;
 import java.util.List;
 import java.util.Optional;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.Executor;
 import java.util.concurrent.atomic.AtomicBoolean;
 
@@ -232,8 +230,6 @@
     @Mock
     private KeyguardUpdateMonitorLogger mKeyguardUpdateMonitorLogger;
     @Mock
-    private IActivityManager mActivityService;
-    @Mock
     private SessionTracker mSessionTracker;
     @Mock
     private UiEventLogger mUiEventLogger;
@@ -270,8 +266,6 @@
     @Before
     public void setup() throws RemoteException {
         MockitoAnnotations.initMocks(this);
-        when(mActivityService.getCurrentUser()).thenReturn(mCurrentUserInfo);
-        when(mActivityService.getCurrentUserId()).thenReturn(mCurrentUserId);
         when(mFaceManager.isHardwareDetected()).thenReturn(true);
         when(mFaceManager.hasEnrolledTemplates()).thenReturn(true);
         when(mFaceManager.hasEnrolledTemplates(anyInt())).thenReturn(true);
@@ -311,13 +305,11 @@

         mMockitoSession = ExtendedMockito.mockitoSession()
                 .spyStatic(SubscriptionManager.class)
-                .spyStatic(ActivityManager.class)
                 .startMocking();
         ExtendedMockito.doReturn(SubscriptionManager.INVALID_SUBSCRIPTION_ID)
                 .when(SubscriptionManager::getDefaultSubscriptionId);
         KeyguardUpdateMonitor.setCurrentUser(mCurrentUserId);
         when(mUserTracker.getUserId()).thenReturn(mCurrentUserId);
-        ExtendedMockito.doReturn(mActivityService).when(ActivityManager::getService);

         mContext.getOrCreateTestableResources().addOverride(
                 com.android.systemui.R.integer.config_face_auth_supported_posture,
@@ -1091,11 +1083,6 @@

     @Test
     public void testBiometricsCleared_whenUserSwitches() throws Exception {
-        final IRemoteCallback reply = new IRemoteCallback.Stub() {
-            @Override
-            public void sendResult(Bundle data) {
-            } // do nothing
-        };
         final BiometricAuthenticated dummyAuthentication =
                 new BiometricAuthenticated(true /* authenticated */, true /* strong */);
         mKeyguardUpdateMonitor.mUserFaceAuthenticated.put(0 /* user */, dummyAuthentication);
@@ -1103,18 +1090,16 @@
         assertThat(mKeyguardUpdateMonitor.mUserFingerprintAuthenticated.size()).isEqualTo(1);
         assertThat(mKeyguardUpdateMonitor.mUserFaceAuthenticated.size()).isEqualTo(1);

-        mKeyguardUpdateMonitor.handleUserSwitching(10 /* user */, reply);
+        // A pending count lets the test verify that the user switch is acknowledged.
+        final CountDownLatch latch = new CountDownLatch(1);
+        mKeyguardUpdateMonitor.handleUserSwitching(10 /* user */, latch);
         assertThat(mKeyguardUpdateMonitor.mUserFingerprintAuthenticated.size()).isEqualTo(0);
         assertThat(mKeyguardUpdateMonitor.mUserFaceAuthenticated.size()).isEqualTo(0);
+        assertThat(latch.getCount()).isEqualTo(0L);
     }

     @Test
     public void testMultiUserJankMonitor_whenUserSwitches() throws Exception {
-        final IRemoteCallback reply = new IRemoteCallback.Stub() {
-            @Override
-            public void sendResult(Bundle data) {
-            } // do nothing
-        };
         mKeyguardUpdateMonitor.handleUserSwitchComplete(10 /* user */);
         verify(mInteractionJankMonitor).end(InteractionJankMonitor.CUJ_USER_SWITCH);
         verify(mLatencyTracker).onActionEnd(LatencyTracker.ACTION_USER_SWITCH);
diff --git a/packages/SystemUI/tests/src/com/android/systemui/statusbar/phone/PhoneStatusBarPolicyTest.kt b/packages/SystemUI/tests/src/com/android/systemui/statusbar/phone/PhoneStatusBarPolicyTest.kt
index 305b9fe..6b18169 100644
--- a/packages/SystemUI/tests/src/com/android/systemui/statusbar/phone/PhoneStatusBarPolicyTest.kt
+++ b/packages/SystemUI/tests/src/com/android/systemui/statusbar/phone/PhoneStatusBarPolicyTest.kt
@@ -17,7 +17,6 @@
 package com.android.systemui.statusbar.phone
 
 import android.app.AlarmManager
-import android.app.IActivityManager
 import android.app.admin.DevicePolicyManager
 import android.content.SharedPreferences
 import android.os.UserManager
@@ -87,7 +86,6 @@
     @Mock private lateinit var keyguardStateController: KeyguardStateController
     @Mock private lateinit var locationController: LocationController
     @Mock private lateinit var sensorPrivacyController: SensorPrivacyController
-    @Mock private lateinit var iActivityManager: IActivityManager
     @Mock private lateinit var alarmManager: AlarmManager
     @Mock private lateinit var userManager: UserManager
     @Mock private lateinit var userTracker: UserTracker
@@ -176,6 +174,7 @@
             commandQueue,
             broadcastDispatcher,
             executor,
+            executor, // uiBgExecutor (the preceding executor is passed as mainExecutor)
             testableLooper.looper,
             context.resources,
             castController,
@@ -190,7 +189,6 @@
             keyguardStateController,
             locationController,
             sensorPrivacyController,
-            iActivityManager,
             alarmManager,
             userManager,
             userTracker,
diff --git a/packages/SystemUI/tests/src/com/android/systemui/user/data/repository/UserRepositoryImplTest.kt b/packages/SystemUI/tests/src/com/android/systemui/user/data/repository/UserRepositoryImplTest.kt
index ccf378a..ddd880b 100644
--- a/packages/SystemUI/tests/src/com/android/systemui/user/data/repository/UserRepositoryImplTest.kt
+++ b/packages/SystemUI/tests/src/com/android/systemui/user/data/repository/UserRepositoryImplTest.kt
@@ -17,10 +17,7 @@
 
 package com.android.systemui.user.data.repository
 
-import android.app.IActivityManager
-import android.app.UserSwitchObserver
 import android.content.pm.UserInfo
-import android.os.IRemoteCallback
 import android.os.UserHandle
 import android.os.UserManager
 import android.provider.Settings
@@ -44,14 +41,8 @@
 import org.junit.Test
 import org.junit.runner.RunWith
 import org.junit.runners.JUnit4
-import org.mockito.ArgumentCaptor
-import org.mockito.Captor
 import org.mockito.Mock
-import org.mockito.Mockito.any
-import org.mockito.Mockito.anyString
 import org.mockito.Mockito.mock
-import org.mockito.Mockito.times
-import org.mockito.Mockito.verify
 import org.mockito.Mockito.`when` as whenever
 import org.mockito.MockitoAnnotations
 
@@ -60,8 +51,6 @@
 class UserRepositoryImplTest : SysuiTestCase() {

     @Mock private lateinit var manager: UserManager
-    @Mock private lateinit var activityManager: IActivityManager
-    @Captor private lateinit var userSwitchObserver: ArgumentCaptor<UserSwitchObserver>

     private lateinit var underTest: UserRepositoryImpl

@@ -229,30 +218,31 @@
     }

     @Test
-    fun userSwitchingInProgress_registersOnlyOneUserSwitchObserver() = runSelfCancelingTest {
+    fun userSwitchingInProgress_registersUserTrackerCallback() = runSelfCancelingTest {
         underTest = create(this)

         underTest.userSwitchingInProgress.launchIn(this)
         underTest.userSwitchingInProgress.launchIn(this)
         underTest.userSwitchingInProgress.launchIn(this)

-        verify(activityManager, times(1)).registerUserSwitchObserver(any(), anyString())
+        // Exactly two callbacks: one for user switching, shared by the three collectors above
+        // (so it is registered only once), and one for observing the selected user.
+        assertThat(tracker.callbacks.size).isEqualTo(2)
     }

     @Test
-    fun userSwitchingInProgress_propagatesStateFromActivityManager() = runSelfCancelingTest {
+    fun userSwitchingInProgress_propagatesStateFromUserTracker() = runSelfCancelingTest {
         underTest = create(this)
-        verify(activityManager)
-            .registerUserSwitchObserver(userSwitchObserver.capture(), anyString())
+        assertThat(tracker.callbacks.size).isEqualTo(2) // switching + selected-user callbacks

-        userSwitchObserver.value.onUserSwitching(0, mock(IRemoteCallback::class.java))
+        tracker.onUserChanging(0)

         var mostRecentSwitchingValue = false
         underTest.userSwitchingInProgress.onEach { mostRecentSwitchingValue = it }.launchIn(this)

         assertThat(mostRecentSwitchingValue).isTrue()

-        userSwitchObserver.value.onUserSwitchComplete(0)
+        tracker.onUserChanged(0)
         assertThat(mostRecentSwitchingValue).isFalse()
     }

@@ -332,7 +322,6 @@
             backgroundDispatcher = IMMEDIATE,
             globalSettings = globalSettings,
             tracker = tracker,
-            activityManager = activityManager,
             featureFlags = featureFlags,
         )
     }
diff --git a/packages/SystemUI/tests/utils/src/com/android/systemui/settings/FakeUserTracker.kt b/packages/SystemUI/tests/utils/src/com/android/systemui/settings/FakeUserTracker.kt
index 251014f..4242c16 100644
--- a/packages/SystemUI/tests/utils/src/com/android/systemui/settings/FakeUserTracker.kt
+++ b/packages/SystemUI/tests/utils/src/com/android/systemui/settings/FakeUserTracker.kt
@@ -22,6 +22,7 @@
 import android.os.UserHandle
 import android.test.mock.MockContentResolver
 import com.android.systemui.util.mockito.mock
+import java.util.concurrent.CountDownLatch
 import java.util.concurrent.Executor
 
 /** A fake [UserTracker] to be used in tests. */
@@ -66,11 +67,19 @@
         _userId = _userInfo.id
         _userHandle = UserHandle.of(_userId)

+        onUserChanging()
+        onUserChanged()
+    }
+
+    fun onUserChanging(userId: Int = _userId) { // notifies only; tracked user unchanged
         val copy = callbacks.toList()
-        copy.forEach {
-            it.onUserChanging(_userId, userContext)
-            it.onUserChanged(_userId, userContext)
-        }
+        val latch = CountDownLatch(copy.size) // not awaited; callbacks are called inline
+        copy.forEach { it.onUserChanging(userId, userContext, latch) }
+    }
+
+    fun onUserChanged(userId: Int = _userId) { // notifies only; tracked user unchanged
+        val copy = callbacks.toList()
+        copy.forEach { it.onUserChanged(userId, userContext) }
     }

     fun onProfileChanged() {