Move onBeforeUserSwitching call to the beginning of the user switch.

Bug: 331853529
Bug: 360838273
Test: atest FrameworksServicesTests:UserControllerTest
Flag: EXEMPT bugfix
Change-Id: I03e3756194ea0565ea5ce1b5ac383beffae36839
diff --git a/services/core/java/com/android/server/am/UserController.java b/services/core/java/com/android/server/am/UserController.java
index fb55700..262c76e 100644
--- a/services/core/java/com/android/server/am/UserController.java
+++ b/services/core/java/com/android/server/am/UserController.java
@@ -1980,6 +1980,7 @@
                 // it should be moved outside, but for now it's not as there are many calls to
                 // external components here afterwards
                 updateProfileRelatedCaches();
+                dispatchOnBeforeUserSwitching(userId);
                 mInjector.getWindowManager().setCurrentUser(userId);
                 mInjector.reportCurWakefulnessUsageEvent();
                 // Once the internal notion of the active user has switched, we lock the device
@@ -2285,6 +2286,25 @@
         mUserSwitchObservers.finishBroadcast();
     }
 
+    private void dispatchOnBeforeUserSwitching(@UserIdInt int newUserId) {
+        final TimingsTraceAndSlog t = new TimingsTraceAndSlog();
+        t.traceBegin("dispatchOnBeforeUserSwitching-" + newUserId);
+        final int observerCount = mUserSwitchObservers.beginBroadcast();
+        for (int i = 0; i < observerCount; i++) {
+            final String name = "#" + i + " " + mUserSwitchObservers.getBroadcastCookie(i);
+            t.traceBegin("onBeforeUserSwitching-" + name);
+            try {
+                mUserSwitchObservers.getBroadcastItem(i).onBeforeUserSwitching(newUserId);
+            } catch (RemoteException e) {
+                // Ignore
+            } finally {
+                t.traceEnd();
+            }
+        }
+        mUserSwitchObservers.finishBroadcast();
+        t.traceEnd();
+    }
+
     /** Called on handler thread */
     @VisibleForTesting
     void dispatchUserSwitchComplete(@UserIdInt int oldUserId, @UserIdInt int newUserId) {
@@ -2500,17 +2520,6 @@
 
         final int observerCount = mUserSwitchObservers.beginBroadcast();
         if (observerCount > 0) {
-            for (int i = 0; i < observerCount; i++) {
-                final String name = "#" + i + " " + mUserSwitchObservers.getBroadcastCookie(i);
-                t.traceBegin("onBeforeUserSwitching-" + name);
-                try {
-                    mUserSwitchObservers.getBroadcastItem(i).onBeforeUserSwitching(newUserId);
-                } catch (RemoteException e) {
-                    // Ignore
-                } finally {
-                    t.traceEnd();
-                }
-            }
             final ArraySet<String> curWaitingUserSwitchCallbacks = new ArraySet<>();
             synchronized (mLock) {
                 uss.switching = true;
diff --git a/services/tests/servicestests/src/com/android/server/am/UserControllerTest.java b/services/tests/servicestests/src/com/android/server/am/UserControllerTest.java
index d97a0fa..390eb93 100644
--- a/services/tests/servicestests/src/com/android/server/am/UserControllerTest.java
+++ b/services/tests/servicestests/src/com/android/server/am/UserControllerTest.java
@@ -427,6 +427,7 @@
         mUserController.registerUserSwitchObserver(observer, "mock");
         // Start user -- this will update state of mUserController
         mUserController.startUser(TEST_USER_ID, USER_START_MODE_FOREGROUND);
+        verify(observer, times(1)).onBeforeUserSwitching(eq(TEST_USER_ID));
         Message reportMsg = mInjector.mHandler.getMessageForCode(REPORT_USER_SWITCH_MSG);
         assertNotNull(reportMsg);
         UserState userState = (UserState) reportMsg.obj;
@@ -435,7 +436,6 @@
         // Call dispatchUserSwitch and verify that observer was called only once
         mInjector.mHandler.clearAllRecordedMessages();
         mUserController.dispatchUserSwitch(userState, oldUserId, newUserId);
-        verify(observer, times(1)).onBeforeUserSwitching(eq(TEST_USER_ID));
         verify(observer, times(1)).onUserSwitching(eq(TEST_USER_ID), any());
         Set<Integer> expectedCodes = Collections.singleton(CONTINUE_USER_SWITCH_MSG);
         Set<Integer> actualCodes = mInjector.mHandler.getMessageCodes();
@@ -458,6 +458,7 @@
         mUserController.registerUserSwitchObserver(observer, "mock");
         // Start user -- this will update state of mUserController
         mUserController.startUser(TEST_USER_ID, USER_START_MODE_FOREGROUND);
+        verify(observer, times(1)).onBeforeUserSwitching(eq(TEST_USER_ID));
         Message reportMsg = mInjector.mHandler.getMessageForCode(REPORT_USER_SWITCH_MSG);
         assertNotNull(reportMsg);
         UserState userState = (UserState) reportMsg.obj;
@@ -466,7 +467,6 @@
         // Call dispatchUserSwitch and verify that observer was called only once
         mInjector.mHandler.clearAllRecordedMessages();
         mUserController.dispatchUserSwitch(userState, oldUserId, newUserId);
-        verify(observer, times(1)).onBeforeUserSwitching(eq(TEST_USER_ID));
         verify(observer, times(1)).onUserSwitching(eq(TEST_USER_ID), any());
         // Verify that CONTINUE_USER_SWITCH_MSG is not sent (triggers timeout)
         Set<Integer> actualCodes = mInjector.mHandler.getMessageCodes();