Merge "Unblock screen-on after start transaction is completed" into main
diff --git a/core/java/android/window/flags/windowing_frontend.aconfig b/core/java/android/window/flags/windowing_frontend.aconfig
index 4402ac7..dd6b772a 100644
--- a/core/java/android/window/flags/windowing_frontend.aconfig
+++ b/core/java/android/window/flags/windowing_frontend.aconfig
@@ -9,6 +9,16 @@
 }
 
 flag {
+    name: "wait_for_transition_on_display_switch"
+    namespace: "windowing_frontend"
+    description: "Waits for Shell transition to start before unblocking the screen after display switch"
+    bug: "301420598"
+    metadata {
+        purpose: PURPOSE_BUGFIX
+    }
+}
+
+flag {
   name: "edge_to_edge_by_default"
   namespace: "windowing_frontend"
   description: "Make app go edge-to-edge by default when targeting SDK 35 or greater"
diff --git a/services/core/java/com/android/server/display/DisplayManagerService.java b/services/core/java/com/android/server/display/DisplayManagerService.java
index 40f0362..31092f2 100644
--- a/services/core/java/com/android/server/display/DisplayManagerService.java
+++ b/services/core/java/com/android/server/display/DisplayManagerService.java
@@ -757,6 +757,7 @@
             mContext.getSystemService(DeviceStateManager.class).registerCallback(
                     new HandlerExecutor(mHandler), new DeviceStateListener());
 
+            mLogicalDisplayMapper.onWindowManagerReady();
             scheduleTraversalLocked(false);
         }
     }
diff --git a/services/core/java/com/android/server/display/LogicalDisplayMapper.java b/services/core/java/com/android/server/display/LogicalDisplayMapper.java
index 6203a32..bca53cf 100644
--- a/services/core/java/com/android/server/display/LogicalDisplayMapper.java
+++ b/services/core/java/com/android/server/display/LogicalDisplayMapper.java
@@ -41,10 +41,12 @@
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.foldables.FoldGracePeriodProvider;
+import com.android.server.LocalServices;
 import com.android.server.display.feature.DisplayManagerFlags;
 import com.android.server.display.layout.DisplayIdProducer;
 import com.android.server.display.layout.Layout;
 import com.android.server.display.utils.DebugUtils;
+import com.android.server.policy.WindowManagerPolicy;
 import com.android.server.utils.FoldSettingProvider;
 
 import java.io.PrintWriter;
@@ -189,6 +191,7 @@
      * #updateLogicalDisplaysLocked} to establish which Virtual Devices own which Virtual Displays.
      */
     private final ArrayMap<String, Integer> mVirtualDeviceDisplayMapping = new ArrayMap<>();
+    private WindowManagerPolicy mWindowManagerPolicy;
 
     private int mNextNonDefaultGroupId = Display.DEFAULT_DISPLAY_GROUP + 1;
     private final DisplayIdProducer mIdProducer = (isDefault) ->
@@ -274,6 +277,10 @@
         mListener.onTraversalRequested();
     }
 
+    public void onWindowManagerReady() {
+        mWindowManagerPolicy = LocalServices.getService(WindowManagerPolicy.class);
+    }
+
     public LogicalDisplay getDisplayLocked(int displayId) {
         return getDisplayLocked(displayId, /* includeDisabled= */ true);
     }
@@ -1114,14 +1121,22 @@
             final int logicalDisplayId = displayLayout.getLogicalDisplayId();
 
             LogicalDisplay newDisplay = getDisplayLocked(logicalDisplayId);
+            boolean newDisplayCreated = false;
             if (newDisplay == null) {
                 newDisplay = createNewLogicalDisplayLocked(
                         null /*displayDevice*/, logicalDisplayId);
+                newDisplayCreated = true;
             }
 
             // Now swap the underlying display devices between the old display and the new display
             final LogicalDisplay oldDisplay = getDisplayLocked(device);
             if (newDisplay != oldDisplay) {
+                // Display is swapping, notify WindowManager, so it can prepare for
+                // the display switch
+                if (!newDisplayCreated && mWindowManagerPolicy != null) {
+                    mWindowManagerPolicy.onDisplaySwitchStart(newDisplay.getDisplayIdLocked());
+                }
+
                 newDisplay.swapDisplaysLocked(oldDisplay);
             }
             DisplayDeviceConfig config = device.getDisplayDeviceConfig();
diff --git a/services/core/java/com/android/server/policy/PhoneWindowManager.java b/services/core/java/com/android/server/policy/PhoneWindowManager.java
index 76bf8fd..7db83d7 100644
--- a/services/core/java/com/android/server/policy/PhoneWindowManager.java
+++ b/services/core/java/com/android/server/policy/PhoneWindowManager.java
@@ -5664,6 +5664,13 @@
         }
     }
 
+    @Override
+    public void onDisplaySwitchStart(int displayId) {
+        if (displayId == DEFAULT_DISPLAY) {
+            mDefaultDisplayPolicy.onDisplaySwitchStart();
+        }
+    }
+
     private long getKeyguardDrawnTimeout() {
         final boolean bootCompleted =
                 LocalServices.getService(SystemServiceManager.class).isBootCompleted();
diff --git a/services/core/java/com/android/server/policy/WindowManagerPolicy.java b/services/core/java/com/android/server/policy/WindowManagerPolicy.java
index 5956594..2623025 100644
--- a/services/core/java/com/android/server/policy/WindowManagerPolicy.java
+++ b/services/core/java/com/android/server/policy/WindowManagerPolicy.java
@@ -895,6 +895,9 @@
         void onScreenOff();
     }
 
+    /** Called when the physical display starts to switch, e.g. fold/unfold. */
+    void onDisplaySwitchStart(int displayId);
+
     /**
      * Return whether the default display is on and not blocked by a black surface.
      */
diff --git a/services/core/java/com/android/server/wm/DeferredDisplayUpdater.java b/services/core/java/com/android/server/wm/DeferredDisplayUpdater.java
index a29cb60..ca5f26a 100644
--- a/services/core/java/com/android/server/wm/DeferredDisplayUpdater.java
+++ b/services/core/java/com/android/server/wm/DeferredDisplayUpdater.java
@@ -26,6 +26,10 @@
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.graphics.Rect;
+import android.os.Message;
+import android.os.Trace;
+import android.util.Log;
+import android.util.Slog;
 import android.view.DisplayInfo;
 import android.window.DisplayAreaInfo;
 import android.window.TransitionRequestInfo;
@@ -35,6 +39,7 @@
 import com.android.internal.display.BrightnessSynchronizer;
 import com.android.internal.protolog.common.ProtoLog;
 import com.android.server.wm.utils.DisplayInfoOverrides.DisplayInfoFieldsUpdater;
+import com.android.window.flags.Flags;
 
 import java.util.Arrays;
 import java.util.Objects;
@@ -65,6 +70,12 @@
         WM_OVERRIDE_FIELDS.setFields(out, override);
     };
 
+    private static final String TAG = "DeferredDisplayUpdater";
+
+    private static final String TRACE_TAG_WAIT_FOR_TRANSITION =
+            "Screen unblock: wait for transition";
+    private static final int WAIT_FOR_TRANSITION_TIMEOUT = 1000;
+
     private final DisplayContent mDisplayContent;
 
     @NonNull
@@ -88,6 +99,18 @@
     @NonNull
     private final DisplayInfo mOutputDisplayInfo = new DisplayInfo();
 
+    /** Whether {@link #mScreenUnblocker} should wait for transition to be ready. */
+    private boolean mShouldWaitForTransitionWhenScreenOn;
+
+    /** The message to notify PhoneWindowManager#finishWindowsDrawn. */
+    @Nullable
+    private Message mScreenUnblocker;
+
+    private final Runnable mScreenUnblockTimeoutRunnable = () -> {
+        Slog.e(TAG, "Timeout waiting for the display switch transition to start");
+        continueScreenUnblocking();
+    };
+
     public DeferredDisplayUpdater(@NonNull DisplayContent displayContent) {
         mDisplayContent = displayContent;
         mNonOverrideDisplayInfo.copyFrom(mDisplayContent.getDisplayInfo());
@@ -248,6 +271,7 @@
                 getCurrentDisplayChange(fromRotation, startBounds);
         displayChange.setPhysicalDisplayChanged(true);
 
+        transition.addTransactionCompletedListener(this::continueScreenUnblocking);
         mDisplayContent.mTransitionController.requestStartTransition(transition,
                 /* startTask= */ null, /* remoteTransition= */ null, displayChange);
 
@@ -277,6 +301,58 @@
         return !Objects.equals(first.uniqueId, second.uniqueId);
     }
 
+    @Override
+    public void onDisplayContentDisplayPropertiesPostChanged(int previousRotation, int newRotation,
+            DisplayAreaInfo newDisplayAreaInfo) {
+        // No transition is running (unlikely), so unblock right away through the shared path;
+        // it also cancels the pending timeout and closes the async trace section.
+        if (mScreenUnblocker != null && !mDisplayContent.mTransitionController.inTransition()) {
+            continueScreenUnblocking();
+        }
+    }
+
+    @Override
+    public void onDisplaySwitching(boolean switching) {
+        mShouldWaitForTransitionWhenScreenOn = switching;
+    }
+
+    @Override
+    public boolean waitForTransition(@NonNull Message screenUnblocker) {
+        if (!Flags.waitForTransitionOnDisplaySwitch()) return false;
+        if (!mShouldWaitForTransitionWhenScreenOn) {
+            return false;
+        }
+        mScreenUnblocker = screenUnblocker;
+        if (Trace.isTagEnabled(Trace.TRACE_TAG_WINDOW_MANAGER)) {
+            Trace.beginAsyncSection(TRACE_TAG_WAIT_FOR_TRANSITION, screenUnblocker.hashCode());
+        }
+
+        mDisplayContent.mWmService.mH.removeCallbacks(mScreenUnblockTimeoutRunnable);
+        mDisplayContent.mWmService.mH.postDelayed(mScreenUnblockTimeoutRunnable,
+                WAIT_FOR_TRANSITION_TIMEOUT);
+        return true;
+    }
+
+    /**
+     * Continues the screen unblocking flow. May be called either on a binder thread, from the
+     * surface transaction completed listener, or on the {@link WindowManagerService#mH} handler
+     * when the wait times out. The pending unblock message, if any, is sent and then cleared.
+     */
+    private void continueScreenUnblocking() {
+        synchronized (mDisplayContent.mWmService.mGlobalLock) {
+            mShouldWaitForTransitionWhenScreenOn = false;
+            mDisplayContent.mWmService.mH.removeCallbacks(mScreenUnblockTimeoutRunnable);
+            if (mScreenUnblocker == null) {
+                return;
+            }
+            mScreenUnblocker.sendToTarget();
+            if (Trace.isTagEnabled(Trace.TRACE_TAG_WINDOW_MANAGER)) {
+                Trace.endAsyncSection(TRACE_TAG_WAIT_FOR_TRANSITION, mScreenUnblocker.hashCode());
+            }
+            mScreenUnblocker = null;
+        }
+    }
+
     /**
      * Diff result: fields are the same
      */
diff --git a/services/core/java/com/android/server/wm/DisplayContent.java b/services/core/java/com/android/server/wm/DisplayContent.java
index 54abbc3..cde3e68 100644
--- a/services/core/java/com/android/server/wm/DisplayContent.java
+++ b/services/core/java/com/android/server/wm/DisplayContent.java
@@ -470,7 +470,7 @@
     private final DisplayRotation mDisplayRotation;
     @Nullable final DisplayRotationCompatPolicy mDisplayRotationCompatPolicy;
     DisplayFrames mDisplayFrames;
-    private final DisplayUpdater mDisplayUpdater;
+    final DisplayUpdater mDisplayUpdater;
 
     private boolean mInTouchMode;
 
diff --git a/services/core/java/com/android/server/wm/DisplayPolicy.java b/services/core/java/com/android/server/wm/DisplayPolicy.java
index 16f7373..a5037ea 100644
--- a/services/core/java/com/android/server/wm/DisplayPolicy.java
+++ b/services/core/java/com/android/server/wm/DisplayPolicy.java
@@ -779,6 +779,11 @@
         return mLidState;
     }
 
+    private void onDisplaySwitchFinished() {
+        mDisplayContent.mWallpaperController.onDisplaySwitchFinished();
+        mDisplayContent.mDisplayUpdater.onDisplaySwitching(false);
+    }
+
     public void setAwake(boolean awake) {
         synchronized (mLock) {
             if (awake == mAwake) {
@@ -797,7 +802,7 @@
             mService.mAtmService.mKeyguardController.updateDeferTransitionForAod(
                     mAwake /* waiting */);
             if (!awake) {
-                mDisplayContent.mWallpaperController.onDisplaySwitchFinished();
+                onDisplaySwitchFinished();
             }
         }
     }
@@ -866,7 +871,7 @@
 
     /** It is called after {@link #finishScreenTurningOn}. This runs on PowerManager's thread. */
     public void screenTurnedOn() {
-        mDisplayContent.mWallpaperController.onDisplaySwitchFinished();
+        onDisplaySwitchFinished();
     }
 
     public void screenTurnedOff() {
@@ -2187,6 +2192,11 @@
                 mDisplayContent.mTransitionController.getCollectingTransitionId();
     }
 
+    /** Called when a display switch starts; expect an onDisplayChanged with a new unique id. */
+    public void onDisplaySwitchStart() {
+        mDisplayContent.mDisplayUpdater.onDisplaySwitching(true);
+    }
+
     @NavigationBarPosition
     int navigationBarPosition(int displayRotation) {
         if (mNavigationBar != null) {
diff --git a/services/core/java/com/android/server/wm/DisplayUpdater.java b/services/core/java/com/android/server/wm/DisplayUpdater.java
index e611177..918b180 100644
--- a/services/core/java/com/android/server/wm/DisplayUpdater.java
+++ b/services/core/java/com/android/server/wm/DisplayUpdater.java
@@ -17,6 +17,7 @@
 package com.android.server.wm;
 
 import android.annotation.NonNull;
+import android.os.Message;
 import android.view.Surface;
 import android.window.DisplayAreaInfo;
 
@@ -49,4 +50,16 @@
             @Surface.Rotation int previousRotation, @Surface.Rotation int newRotation,
             @NonNull DisplayAreaInfo newDisplayAreaInfo) {
     }
+
+    /**
+     * Called with {@code true} when the physical display is about to switch, and with
+     * {@code false} once the display is turned on or the device goes to sleep.
+     */
+    default void onDisplaySwitching(boolean switching) {
+    }
+
+    /** Returns {@code true} if the transition will control when to turn on the screen. */
+    default boolean waitForTransition(@NonNull Message screenUnBlocker) {
+        return false;
+    }
 }
diff --git a/services/core/java/com/android/server/wm/Transition.java b/services/core/java/com/android/server/wm/Transition.java
index 3779d9e..1b380aa 100644
--- a/services/core/java/com/android/server/wm/Transition.java
+++ b/services/core/java/com/android/server/wm/Transition.java
@@ -112,6 +112,7 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
+import java.util.concurrent.Executor;
 import java.util.function.Predicate;
 
 /**
@@ -233,6 +234,9 @@
      */
     private ArrayList<Task> mTransientHideTasks;
 
+    @VisibleForTesting
+    ArrayList<Runnable> mTransactionCompletedListeners = null;
+
     /** Custom activity-level animation options and callbacks. */
     private TransitionInfo.AnimationOptions mOverrideOptions;
     private IRemoteCallback mClientAnimationStartCallback = null;
@@ -1640,6 +1644,14 @@
         commitVisibleActivities(transaction);
         commitVisibleWallpapers();
 
+        if (mTransactionCompletedListeners != null) {
+            for (int i = 0; i < mTransactionCompletedListeners.size(); i++) {
+                final Runnable listener = mTransactionCompletedListeners.get(i);
+                transaction.addTransactionCompletedListener(Runnable::run,
+                        (stats) -> listener.run());
+            }
+        }
+
         // Fall-back to the default display if there isn't one participating.
         final DisplayContent primaryDisplay = !mTargetDisplays.isEmpty() ? mTargetDisplays.get(0)
                 : mController.mAtm.mRootWindowContainer.getDefaultDisplay();
@@ -1862,6 +1874,17 @@
     }
 
     /**
+     * Adds a listener that runs after the start transaction of this transition has been
+     * presented on the screen. The listener is executed on a binder thread.
+     */
+    void addTransactionCompletedListener(Runnable listener) {
+        if (mTransactionCompletedListeners == null) {
+            mTransactionCompletedListeners = new ArrayList<>();
+        }
+        mTransactionCompletedListeners.add(listener);
+    }
+
+    /**
      * Checks if the transition contains order changes.
      *
      * This is a shallow check that doesn't account for collection in parallel, unlike
diff --git a/services/core/java/com/android/server/wm/WindowManagerService.java b/services/core/java/com/android/server/wm/WindowManagerService.java
index 04ca0ae..2e72121 100644
--- a/services/core/java/com/android/server/wm/WindowManagerService.java
+++ b/services/core/java/com/android/server/wm/WindowManagerService.java
@@ -8077,6 +8077,10 @@
             }
             boolean allWindowsDrawn = false;
             synchronized (mGlobalLock) {
+                if (mRoot.getDefaultDisplay().mDisplayUpdater.waitForTransition(message)) {
+                    // Unblock when the transition's start transaction completes or times out.
+                    return;
+                }
                 container.waitForAllWindowsDrawn();
                 mWindowPlacerLocked.requestTraversal();
                 mH.removeMessages(H.WAITING_FOR_DRAWN_TIMEOUT, container);
diff --git a/services/tests/displayservicetests/src/com/android/server/display/LogicalDisplayMapperTest.java b/services/tests/displayservicetests/src/com/android/server/display/LogicalDisplayMapperTest.java
index 5a50510..1a03e78 100644
--- a/services/tests/displayservicetests/src/com/android/server/display/LogicalDisplayMapperTest.java
+++ b/services/tests/displayservicetests/src/com/android/server/display/LogicalDisplayMapperTest.java
@@ -79,12 +79,16 @@
 import androidx.test.filters.SmallTest;
 
 import com.android.internal.foldables.FoldGracePeriodProvider;
+import com.android.internal.util.test.LocalServiceKeeperRule;
+import com.android.server.LocalServices;
 import com.android.server.display.feature.DisplayManagerFlags;
 import com.android.server.display.layout.DisplayIdProducer;
 import com.android.server.display.layout.Layout;
+import com.android.server.policy.WindowManagerPolicy;
 import com.android.server.utils.FoldSettingProvider;
 
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
@@ -124,6 +128,9 @@
 
     private DeviceStateToLayoutMap mDeviceStateToLayoutMapSpy;
 
+    @Rule
+    public LocalServiceKeeperRule mLocalServiceKeeperRule = new LocalServiceKeeperRule();
+
     @Mock LogicalDisplayMapper.Listener mListenerMock;
     @Mock Context mContextMock;
     @Mock FoldSettingProvider mFoldSettingProviderMock;
@@ -133,6 +140,7 @@
     @Mock IThermalService mIThermalServiceMock;
     @Mock DisplayManagerFlags mFlagsMock;
     @Mock DisplayAdapter mDisplayAdapterMock;
+    @Mock WindowManagerPolicy mWindowManagerPolicy;
 
     @Captor ArgumentCaptor<LogicalDisplay> mDisplayCaptor;
     @Captor ArgumentCaptor<Integer> mDisplayEventCaptor;
@@ -143,6 +151,9 @@
         System.setProperty("dexmaker.share_classloader", "true");
         MockitoAnnotations.initMocks(this);
 
+        mLocalServiceKeeperRule.overrideLocalService(WindowManagerPolicy.class,
+                mWindowManagerPolicy);
+
         mDeviceStateToLayoutMapSpy =
                 spy(new DeviceStateToLayoutMap(mIdProducer, mFlagsMock, NON_EXISTING_FILE));
         mDisplayDeviceRepo = new DisplayDeviceRepository(
@@ -194,6 +205,7 @@
                 mDisplayDeviceRepo,
                 mListenerMock, new DisplayManagerService.SyncRoot(), mHandler,
                 mDeviceStateToLayoutMapSpy, mFlagsMock);
+        mLogicalDisplayMapper.onWindowManagerReady();
     }
 
 
@@ -757,6 +769,44 @@
     }
 
     @Test
+    public void testDisplaySwappedAfterDeviceStateChange_windowManagerIsNotified() {
+        FoldableDisplayDevices foldableDisplayDevices = createFoldableDeviceStateToLayoutMap();
+        mLogicalDisplayMapper.setDeviceStateLocked(DEVICE_STATE_OPEN);
+        mLogicalDisplayMapper.onEarlyInteractivityChange(true);
+        mLogicalDisplayMapper.onBootCompleted();
+        advanceTime(1000);
+        clearInvocations(mWindowManagerPolicy);
+
+        // Switch from 'inner' to 'outer' display (fold a foldable device)
+        mLogicalDisplayMapper.setDeviceStateLocked(DEVICE_STATE_CLOSED);
+        // Continue folding device state transition by turning off the inner display
+        foldableDisplayDevices.mInner.setState(STATE_OFF);
+        notifyDisplayChanges(foldableDisplayDevices.mOuter);
+        advanceTime(TIMEOUT_STATE_TRANSITION_MILLIS);
+
+        verify(mWindowManagerPolicy).onDisplaySwitchStart(DEFAULT_DISPLAY);
+    }
+
+    @Test
+    public void testCreateNewLogicalDisplay_windowManagerIsNotNotifiedAboutSwitch() {
+        DisplayDevice device1 = createDisplayDevice(TYPE_EXTERNAL, 600, 800,
+                FLAG_ALLOWED_TO_BE_DEFAULT_DISPLAY);
+        when(mDeviceStateToLayoutMapSpy.size()).thenReturn(1);
+        LogicalDisplay display1 = add(device1);
+
+        assertTrue(display1.isEnabledLocked());
+
+        DisplayDevice device2 = createDisplayDevice(TYPE_INTERNAL, 600, 800,
+                FLAG_ALLOWED_TO_BE_DEFAULT_DISPLAY);
+        when(mDeviceStateToLayoutMapSpy.size()).thenReturn(2);
+        add(device2);
+
+        // As it is not a display switch but adding a new display, we should not notify
+        // about display switch start to window manager
+        verify(mWindowManagerPolicy, never()).onDisplaySwitchStart(anyInt());
+    }
+
+    @Test
     public void testDoNotWaitForSleepWhenFoldSettingStayAwake() {
         // Test device should be marked ready for transition immediately when 'Continue using app
         // on fold' set to 'Always'
diff --git a/services/tests/wmtests/src/com/android/server/wm/DisplayContentDeferredUpdateTests.java b/services/tests/wmtests/src/com/android/server/wm/DisplayContentDeferredUpdateTests.java
index b11f9b2..073b551 100644
--- a/services/tests/wmtests/src/com/android/server/wm/DisplayContentDeferredUpdateTests.java
+++ b/services/tests/wmtests/src/com/android/server/wm/DisplayContentDeferredUpdateTests.java
@@ -31,6 +31,7 @@
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.when;
 
+import android.os.Message;
 import android.platform.test.annotations.Presubmit;
 import android.view.DisplayInfo;
 
@@ -60,6 +61,8 @@
     private int mColorMode;
     private int mLogicalDensityDpi;
 
+    private final Message mScreenUnblocker = mock(Message.class);
+
     @Override
     protected void onBeforeSystemServicesCreated() {
         // Set other flags to their default values
@@ -73,12 +76,11 @@
         doReturn(true).when(mDisplayContent).getLastHasContent();
         mockTransitionsController(/* enabled= */ true);
         mockRemoteDisplayChangeController();
+        performInitialDisplayUpdate();
     }
 
     @Test
     public void testUpdate_deferrableFieldChangedTransitionStarted_deferrableFieldUpdated() {
-        performInitialDisplayUpdate();
-
         mUniqueId = "old";
         Runnable onUpdated = mock(Runnable.class);
         mDisplayContent.requestDisplayUpdate(onUpdated);
@@ -107,8 +109,6 @@
 
     @Test
     public void testUpdate_nonDeferrableUpdateAndTransitionDeferred_nonDeferrableFieldUpdated() {
-        performInitialDisplayUpdate();
-
         // Update only color mode (non-deferrable field) and keep the same unique id
         mUniqueId = "initial_unique_id";
         mColorMode = 123;
@@ -121,8 +121,6 @@
 
     @Test
     public void testUpdate_nonDeferrableUpdateTwiceAndTransitionDeferred_fieldHasLatestValue() {
-        performInitialDisplayUpdate();
-
         // Update only color mode (non-deferrable field) and keep the same unique id
         mUniqueId = "initial_unique_id";
         mColorMode = 123;
@@ -163,7 +161,6 @@
 
     @Test
     public void testUpdate_deferrableFieldUpdatedTransitionPending_fieldNotUpdated() {
-        performInitialDisplayUpdate();
         mUniqueId = "old";
         Runnable onUpdated = mock(Runnable.class);
         mDisplayContent.requestDisplayUpdate(onUpdated);
@@ -181,7 +178,6 @@
 
     @Test
     public void testTwoDisplayUpdates_transitionStarted_displayUpdated() {
-        performInitialDisplayUpdate();
         mUniqueId = "old";
         Runnable onUpdated = mock(Runnable.class);
         mDisplayContent.requestDisplayUpdate(onUpdated);
@@ -212,6 +208,51 @@
         assertThat(mDisplayContent.getDisplayInfo().uniqueId).isEqualTo("new2");
     }
 
+    @Test
+    public void testWaitForTransition_displaySwitching_waitsForTransitionToBeStarted() {
+        mSetFlagsRule.enableFlags(Flags.FLAG_WAIT_FOR_TRANSITION_ON_DISPLAY_SWITCH);
+        mDisplayContent.mDisplayUpdater.onDisplaySwitching(/* switching= */ true);
+        boolean willWait = mDisplayContent.mDisplayUpdater.waitForTransition(mScreenUnblocker);
+        assertThat(willWait).isTrue();
+        mUniqueId = "new";
+        mDisplayContent.requestDisplayUpdate(mock(Runnable.class));
+        when(mDisplayContent.mTransitionController.inTransition()).thenReturn(true);
+        captureStartTransitionCollection().getValue().onCollectStarted(/* deferred= */ true);
+
+        // Verify that screen is not unblocked yet as the start transaction hasn't been presented
+        verify(mScreenUnblocker, never()).sendToTarget();
+
+        when(mDisplayContent.mTransitionController.inTransition()).thenReturn(false);
+        final Transition transition = captureRequestedTransition().getValue();
+        makeTransitionTransactionCompleted(transition);
+
+        // Verify that screen is unblocked as start transaction of the transition
+        // has been completed
+        verify(mScreenUnblocker).sendToTarget();
+    }
+
+    @Test
+    public void testWaitForTransition_displayNotSwitching_doesNotWait() {
+        mSetFlagsRule.enableFlags(Flags.FLAG_WAIT_FOR_TRANSITION_ON_DISPLAY_SWITCH);
+        mDisplayContent.mDisplayUpdater.onDisplaySwitching(/* switching= */ false);
+
+        boolean willWait = mDisplayContent.mDisplayUpdater.waitForTransition(mScreenUnblocker);
+
+        assertThat(willWait).isFalse();
+        verify(mScreenUnblocker, never()).sendToTarget();
+    }
+
+    @Test
+    public void testWaitForTransition_displayIsSwitchingButFlagDisabled_doesNotWait() {
+        mSetFlagsRule.disableFlags(Flags.FLAG_WAIT_FOR_TRANSITION_ON_DISPLAY_SWITCH);
+        mDisplayContent.mDisplayUpdater.onDisplaySwitching(/* switching= */ true);
+
+        boolean willWait = mDisplayContent.mDisplayUpdater.waitForTransition(mScreenUnblocker);
+
+        assertThat(willWait).isFalse();
+        verify(mScreenUnblocker, never()).sendToTarget();
+    }
+
     private void mockTransitionsController(boolean enabled) {
         spyOn(mDisplayContent.mTransitionController);
         when(mDisplayContent.mTransitionController.isShellTransitionsEnabled()).thenReturn(enabled);
@@ -233,6 +274,23 @@
         return callbackCaptor;
     }
 
+    private ArgumentCaptor<Transition> captureRequestedTransition() {
+        ArgumentCaptor<Transition> callbackCaptor =
+                ArgumentCaptor.forClass(Transition.class);
+        verify(mDisplayContent.mTransitionController, atLeast(1))
+                .requestStartTransition(callbackCaptor.capture(), any(), any(), any());
+        return callbackCaptor;
+    }
+
+    private void makeTransitionTransactionCompleted(Transition transition) {
+        if (transition.mTransactionCompletedListeners != null) {
+            for (int i = 0; i < transition.mTransactionCompletedListeners.size(); i++) {
+                final Runnable listener = transition.mTransactionCompletedListeners.get(i);
+                listener.run();
+            }
+        }
+    }
+
     private void performInitialDisplayUpdate() {
         mUniqueId = "initial_unique_id";
         mColorMode = 0;
diff --git a/services/tests/wmtests/src/com/android/server/wm/TestWindowManagerPolicy.java b/services/tests/wmtests/src/com/android/server/wm/TestWindowManagerPolicy.java
index 1233686..00a8842 100644
--- a/services/tests/wmtests/src/com/android/server/wm/TestWindowManagerPolicy.java
+++ b/services/tests/wmtests/src/com/android/server/wm/TestWindowManagerPolicy.java
@@ -167,6 +167,10 @@
     }
 
     @Override
+    public void onDisplaySwitchStart(int displayId) {
+    }
+
+    @Override
     public boolean okToAnimate(boolean ignoreScreenOn) {
         return mOkToAnimate;
     }