Merge changes I1f9a8e93,I74ec2bf3 into main

* changes:
  Exit immersive state on rotation
  Restore to pre-immersive bounds on immersive exit
diff --git a/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopFullImmersiveTransitionHandler.kt b/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopFullImmersiveTransitionHandler.kt
index 679179a..320c003 100644
--- a/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopFullImmersiveTransitionHandler.kt
+++ b/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopFullImmersiveTransitionHandler.kt
@@ -23,6 +23,7 @@
 import android.view.SurfaceControl
 import android.view.WindowManager.TRANSIT_CHANGE
 import android.view.animation.DecelerateInterpolator
+import android.window.DesktopModeFlags.ENABLE_WINDOWING_DYNAMIC_INITIAL_BOUNDS
 import android.window.TransitionInfo
 import android.window.TransitionRequestInfo
 import android.window.WindowContainerTransaction
@@ -102,10 +103,8 @@
             return
         }
 
-        val displayLayout = displayController.getDisplayLayout(taskInfo.displayId) ?: return
-        val destinationBounds = calculateMaximizeBounds(displayLayout, taskInfo)
         val wct = WindowContainerTransaction().apply {
-            setBounds(taskInfo.token, destinationBounds)
+            setBounds(taskInfo.token, getExitDestinationBounds(taskInfo))
         }
         logV("Moving task ${taskInfo.taskId} out of immersive mode")
         val transition = transitions.startTransition(TRANSIT_CHANGE, wct, /* handler= */ this)
@@ -145,11 +144,10 @@
         displayId: Int
     ): ((IBinder) -> Unit)? {
         if (!Flags.enableFullyImmersiveInDesktop()) return null
-        val displayLayout = displayController.getDisplayLayout(displayId) ?: return null
         val immersiveTask = desktopRepository.getTaskInFullImmersiveState(displayId) ?: return null
         val taskInfo = shellTaskOrganizer.getRunningTaskInfo(immersiveTask) ?: return null
         logV("Appending immersive exit for task: $immersiveTask in display: $displayId")
-        wct.setBounds(taskInfo.token, calculateMaximizeBounds(displayLayout, taskInfo))
+        wct.setBounds(taskInfo.token, getExitDestinationBounds(taskInfo))
         return { transition -> addPendingImmersiveExit(immersiveTask, displayId, transition) }
     }
 
@@ -168,16 +166,14 @@
         if (desktopRepository.isTaskInFullImmersiveState(taskInfo.taskId)) {
             // A full immersive task is being minimized, make sure the immersive state is broken
             // (i.e. resize back to max bounds).
-            displayController.getDisplayLayout(taskInfo.displayId)?.let { displayLayout ->
-                wct.setBounds(taskInfo.token, calculateMaximizeBounds(displayLayout, taskInfo))
-                logV("Appending immersive exit for task: ${taskInfo.taskId}")
-                return { transition ->
-                    addPendingImmersiveExit(
-                        taskId = taskInfo.taskId,
-                        displayId = taskInfo.displayId,
-                        transition = transition
-                    )
-                }
+            wct.setBounds(taskInfo.token, getExitDestinationBounds(taskInfo))
+            logV("Appending immersive exit for task: ${taskInfo.taskId}")
+            return { transition ->
+                addPendingImmersiveExit(
+                    taskId = taskInfo.taskId,
+                    displayId = taskInfo.displayId,
+                    transition = transition
+                )
             }
         }
         return null
@@ -302,14 +298,19 @@
                         taskId = pendingExit.taskId,
                         immersive = false
                     )
+                    if (Flags.enableRestoreToPreviousSizeFromDesktopImmersive()) {
+                        desktopRepository.removeBoundsBeforeFullImmersive(pendingExit.taskId)
+                    }
                 }
             }
             return
         }
 
         // Check if this is a direct immersive enter/exit transition.
-        val state = this.state ?: return
-        if (transition == state.transition) {
+        if (transition == state?.transition) {
+            val state = requireState()
+            val startBounds = info.changes.first { c -> c.taskInfo?.taskId == state.taskId }
+                .startAbsBounds
             logV("Direct move for task ${state.taskId} in ${state.direction} direction verified")
             when (state.direction) {
                 Direction.ENTER -> {
@@ -318,6 +319,9 @@
                         taskId = state.taskId,
                         immersive = true
                     )
+                    if (Flags.enableRestoreToPreviousSizeFromDesktopImmersive()) {
+                        desktopRepository.saveBoundsBeforeFullImmersive(state.taskId, startBounds)
+                    }
                 }
                 Direction.EXIT -> {
                     desktopRepository.setTaskInFullImmersiveState(
@@ -325,15 +329,48 @@
                         taskId = state.taskId,
                         immersive = false
                     )
+                    if (Flags.enableRestoreToPreviousSizeFromDesktopImmersive()) {
+                        desktopRepository.removeBoundsBeforeFullImmersive(state.taskId)
+                    }
                 }
             }
+            return
         }
+
+        // Check if this is an untracked exit transition, like display rotation.
+        info.changes
+            .filter { c -> c.taskInfo != null }
+            .filter { c -> desktopRepository.isTaskInFullImmersiveState(c.taskInfo!!.taskId) }
+            .filter { c -> c.startRotation != c.endRotation }
+            .forEach { c ->
+                logV("Detected immersive exit due to rotation for task: ${c.taskInfo!!.taskId}")
+                desktopRepository.setTaskInFullImmersiveState(
+                    displayId = c.taskInfo!!.displayId,
+                    taskId = c.taskInfo!!.taskId,
+                    immersive = false
+                )
+            }
     }
 
     private fun clearState() {
         state = null
     }
 
+    /** Bounds for [taskInfo] on immersive exit; consumes any saved pre-immersive bounds. */
+    private fun getExitDestinationBounds(taskInfo: RunningTaskInfo): Rect {
+        val displayLayout = displayController.getDisplayLayout(taskInfo.displayId)
+            ?: error("Expected non-null display layout for displayId: ${taskInfo.displayId}")
+        if (!Flags.enableRestoreToPreviousSizeFromDesktopImmersive()) {
+            return calculateMaximizeBounds(displayLayout, taskInfo)
+        }
+        return desktopRepository.removeBoundsBeforeFullImmersive(taskInfo.taskId)
+            ?: if (ENABLE_WINDOWING_DYNAMIC_INITIAL_BOUNDS.isTrue()) {
+                calculateInitialBounds(displayLayout, taskInfo)
+            } else {
+                calculateDefaultDesktopTaskBounds(displayLayout)
+            }
+    }
+
     private fun requireState(): TransitionState =
         state ?: error("Expected non-null transition state")
 
diff --git a/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopModeUtils.kt b/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopModeUtils.kt
index 6d47922..edcc877 100644
--- a/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopModeUtils.kt
+++ b/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopModeUtils.kt
@@ -37,6 +37,23 @@
     SystemProperties.getInt("persist.wm.debug.desktop_mode_landscape_app_padding", 25)
 
 /**
+ * Default task bounds: the display size scaled by [DESKTOP_MODE_INITIAL_BOUNDS_SCALE], centered.
+ */
+fun calculateDefaultDesktopTaskBounds(displayLayout: DisplayLayout): Rect {
+    // TODO(b/319819547): Account for app constraints so apps do not become letterboxed
+    val displayWidth = displayLayout.width()
+    val displayHeight = displayLayout.height()
+    val scaledWidth = (displayWidth * DESKTOP_MODE_INITIAL_BOUNDS_SCALE).toInt()
+    val scaledHeight = (displayHeight * DESKTOP_MODE_INITIAL_BOUNDS_SCALE).toInt()
+    // Split the leftover space evenly on each side so the task ends up centered.
+    val left = (displayWidth - scaledWidth) / 2
+    val top = (displayHeight - scaledHeight) / 2
+    val right = left + scaledWidth
+    val bottom = top + scaledHeight
+    return Rect(left, top, right, bottom)
+}
+
+/**
  * Calculates the initial bounds required for an application to fill a scale of the display bounds
  * without any letterboxing. This is done by taking into account the applications fullscreen size,
  * aspect ratio, orientation and resizability to calculate an area this is compatible with the
diff --git a/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopRepository.kt b/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopRepository.kt
index 5ac4ef5..eeb7ac8 100644
--- a/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopRepository.kt
+++ b/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopRepository.kt
@@ -102,6 +102,9 @@
     /* Tracks last bounds of task before toggled to stable bounds. */
     private val boundsBeforeMaximizeByTaskId = SparseArray<Rect>()
 
+    /* Tracks last bounds of task before toggled to immersive state. */
+    private val boundsBeforeFullImmersiveByTaskId = SparseArray<Rect>()
+
     private var desktopGestureExclusionListener: Consumer<Region>? = null
     private var desktopGestureExclusionExecutor: Executor? = null
 
@@ -414,6 +417,7 @@
         logD("Removes freeform task: taskId=%d, displayId=%d", taskId, displayId)
         desktopTaskDataByDisplayId[displayId]?.freeformTasksInZOrder?.remove(taskId)
         boundsBeforeMaximizeByTaskId.remove(taskId)
+        boundsBeforeFullImmersiveByTaskId.remove(taskId)
         logD("Remaining freeform tasks: %s",
             desktopTaskDataByDisplayId[displayId]?.freeformTasksInZOrder?.toDumpString())
         // Remove task from unminimized task if it is minimized.
@@ -472,6 +476,14 @@
     fun saveBoundsBeforeMaximize(taskId: Int, bounds: Rect) =
         boundsBeforeMaximizeByTaskId.set(taskId, Rect(bounds))
 
+    /** Stores a copy of [bounds] as the pre-immersive bounds of task [taskId]. */
+    fun saveBoundsBeforeFullImmersive(taskId: Int, bounds: Rect) =
+        boundsBeforeFullImmersiveByTaskId.put(taskId, Rect(bounds))
+
+    /** Pops the pre-immersive bounds of task [taskId], or returns null when none are stored. */
+    fun removeBoundsBeforeFullImmersive(taskId: Int): Rect? =
+        boundsBeforeFullImmersiveByTaskId.removeReturnOld(taskId)
+
     private fun updatePersistentRepository(displayId: Int) {
         // Create a deep copy of the data
         desktopTaskDataByDisplayId[displayId]?.deepCopy()?.let { desktopTaskDataByDisplayIdCopy ->
diff --git a/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopTasksController.kt b/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopTasksController.kt
index 18ba748..29e302a 100644
--- a/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopTasksController.kt
+++ b/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopTasksController.kt
@@ -751,7 +751,7 @@
                 if (ENABLE_WINDOWING_DYNAMIC_INITIAL_BOUNDS.isTrue()) {
                     destinationBounds.set(calculateInitialBounds(displayLayout, taskInfo))
                 } else {
-                    destinationBounds.set(getDefaultDesktopTaskBounds(displayLayout))
+                    destinationBounds.set(calculateDefaultDesktopTaskBounds(displayLayout))
                 }
             }
         } else {
@@ -920,20 +920,6 @@
         }
     }
 
-    private fun getDefaultDesktopTaskBounds(displayLayout: DisplayLayout): Rect {
-        // TODO(b/319819547): Account for app constraints so apps do not become letterboxed
-        val desiredWidth = (displayLayout.width() * DESKTOP_MODE_INITIAL_BOUNDS_SCALE).toInt()
-        val desiredHeight = (displayLayout.height() * DESKTOP_MODE_INITIAL_BOUNDS_SCALE).toInt()
-        val heightOffset = (displayLayout.height() - desiredHeight) / 2
-        val widthOffset = (displayLayout.width() - desiredWidth) / 2
-        return Rect(
-            widthOffset,
-            heightOffset,
-            desiredWidth + widthOffset,
-            desiredHeight + heightOffset
-        )
-    }
-
     private fun getSnapBounds(taskInfo: RunningTaskInfo, position: SnapPosition): Rect {
         val displayLayout = displayController.getDisplayLayout(taskInfo.displayId) ?: return Rect()
 
@@ -1487,7 +1473,7 @@
         val bounds = if (ENABLE_WINDOWING_DYNAMIC_INITIAL_BOUNDS.isTrue) {
             calculateInitialBounds(displayLayout, taskInfo)
         } else {
-            getDefaultDesktopTaskBounds(displayLayout)
+            calculateDefaultDesktopTaskBounds(displayLayout)
         }
 
         if (DesktopModeFlags.ENABLE_CASCADING_WINDOWS.isTrue) {
@@ -1883,7 +1869,7 @@
         when (indicatorType) {
             IndicatorType.TO_DESKTOP_INDICATOR -> {
                 // Use default bounds, but with the top-center at the drop point.
-                newWindowBounds.set(getDefaultDesktopTaskBounds(displayLayout))
+                newWindowBounds.set(calculateDefaultDesktopTaskBounds(displayLayout))
                 newWindowBounds.offsetTo(
                     dragEvent.x.toInt() - (newWindowBounds.width() / 2),
                     dragEvent.y.toInt()
diff --git a/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/desktopmode/DesktopFullImmersiveTransitionHandlerTest.kt b/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/desktopmode/DesktopFullImmersiveTransitionHandlerTest.kt
index 2e9effb4..b137468 100644
--- a/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/desktopmode/DesktopFullImmersiveTransitionHandlerTest.kt
+++ b/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/desktopmode/DesktopFullImmersiveTransitionHandlerTest.kt
@@ -16,12 +16,15 @@
 package com.android.wm.shell.desktopmode
 
 import android.app.WindowConfiguration.WINDOW_CONFIG_BOUNDS
+import android.graphics.Rect
 import android.os.Binder
 import android.os.IBinder
+import android.platform.test.annotations.DisableFlags
 import android.platform.test.annotations.EnableFlags
 import android.platform.test.flag.junit.SetFlagsRule
 import android.testing.AndroidTestingRunner
 import android.view.Display.DEFAULT_DISPLAY
+import android.view.Surface
 import android.view.SurfaceControl
 import android.view.WindowManager.TRANSIT_CHANGE
 import android.view.WindowManager.TransitionFlags
@@ -68,6 +71,7 @@
     private lateinit var desktopRepository: DesktopRepository
     @Mock private lateinit var mockDisplayController: DisplayController
     @Mock private lateinit var mockShellTaskOrganizer: ShellTaskOrganizer
+    @Mock private lateinit var mockDisplayLayout: DisplayLayout
     private val transactionSupplier = { SurfaceControl.Transaction() }
 
     private lateinit var immersiveHandler: DesktopFullImmersiveTransitionHandler
@@ -78,7 +82,10 @@
             context, ShellInit(TestShellExecutor()), mock(), mock()
         )
         whenever(mockDisplayController.getDisplayLayout(DEFAULT_DISPLAY))
-            .thenReturn(DisplayLayout())
+            .thenReturn(mockDisplayLayout)
+        whenever(mockDisplayLayout.getStableBounds(any())).thenAnswer { invocation ->
+            (invocation.getArgument(0) as Rect).set(STABLE_BOUNDS)
+        }
         immersiveHandler = DesktopFullImmersiveTransitionHandler(
             transitions = mockTransitions,
             desktopRepository = desktopRepository,
@@ -101,12 +108,50 @@
         )
 
         immersiveHandler.moveTaskToImmersive(task)
-        immersiveHandler.onTransitionReady(mockBinder, createTransitionInfo())
+        immersiveHandler.onTransitionReady(
+            transition = mockBinder,
+            info = createTransitionInfo(
+                changes = listOf(
+                    TransitionInfo.Change(task.token, SurfaceControl()).apply {
+                        taskInfo = task
+                    }
+                )
+            )
+        )
 
         assertThat(desktopRepository.isTaskInFullImmersiveState(task.taskId)).isTrue()
     }
 
     @Test
+    @EnableFlags(Flags.FLAG_ENABLE_RESTORE_TO_PREVIOUS_SIZE_FROM_DESKTOP_IMMERSIVE)
+    fun enterImmersive_savesPreImmersiveBounds() {
+        val task = createFreeformTask()
+        val mockBinder = mock(IBinder::class.java)
+        whenever(mockTransitions.startTransition(eq(TRANSIT_CHANGE), any(), eq(immersiveHandler)))
+            .thenReturn(mockBinder)
+        desktopRepository.setTaskInFullImmersiveState(
+            displayId = task.displayId,
+            taskId = task.taskId,
+            immersive = false
+        )
+        val preImmersiveBounds = Rect(100, 100, 600, 600)
+        immersiveHandler.moveTaskToImmersive(task)
+        immersiveHandler.onTransitionReady(
+            transition = mockBinder,
+            info = createTransitionInfo(
+                changes = listOf(
+                    TransitionInfo.Change(task.token, SurfaceControl()).apply {
+                        taskInfo = task
+                        setStartAbsBounds(preImmersiveBounds)
+                    }
+                )
+            )
+        )
+        assertThat(desktopRepository.removeBoundsBeforeFullImmersive(task.taskId))
+            .isEqualTo(preImmersiveBounds)
+    }
+
+    @Test
     fun exitImmersive_transitionReady_updatesRepository() {
         val task = createFreeformTask()
         val mockBinder = mock(IBinder::class.java)
@@ -119,7 +164,69 @@
         )
 
         immersiveHandler.moveTaskToNonImmersive(task)
-        immersiveHandler.onTransitionReady(mockBinder, createTransitionInfo())
+        immersiveHandler.onTransitionReady(
+            transition = mockBinder,
+            info = createTransitionInfo(
+                changes = listOf(
+                    TransitionInfo.Change(task.token, SurfaceControl()).apply {
+                        taskInfo = task
+                    }
+                )
+            )
+        )
+
+        assertThat(desktopRepository.isTaskInFullImmersiveState(task.taskId)).isFalse()
+    }
+
+    @Test
+    @EnableFlags(Flags.FLAG_ENABLE_RESTORE_TO_PREVIOUS_SIZE_FROM_DESKTOP_IMMERSIVE)
+    fun exitImmersive_onTransitionReady_removesBoundsBeforeImmersive() {
+        val task = createFreeformTask()
+        val mockBinder = mock(IBinder::class.java)
+        whenever(mockTransitions.startTransition(eq(TRANSIT_CHANGE), any(), eq(immersiveHandler)))
+            .thenReturn(mockBinder)
+        desktopRepository.setTaskInFullImmersiveState(
+            displayId = task.displayId,
+            taskId = task.taskId,
+            immersive = true
+        )
+        desktopRepository.saveBoundsBeforeFullImmersive(task.taskId, Rect(100, 100, 600, 600))
+
+        immersiveHandler.moveTaskToNonImmersive(task)
+        immersiveHandler.onTransitionReady(
+            transition = mockBinder,
+            info = createTransitionInfo(
+                changes = listOf(
+                    TransitionInfo.Change(task.token, SurfaceControl()).apply {
+                        taskInfo = task
+                    }
+                )
+            )
+        )
+
+        assertThat(desktopRepository.removeBoundsBeforeFullImmersive(task.taskId)).isNull()
+    }
+
+    @Test
+    fun onTransitionReady_displayRotation_exitsImmersive() {
+        val task = createFreeformTask()
+        desktopRepository.setTaskInFullImmersiveState(
+            displayId = task.displayId,
+            taskId = task.taskId,
+            immersive = true
+        )
+
+        // Simulates a rotation change coming from a transition this handler did not start.
+        val rotationChange = TransitionInfo.Change(task.token, SurfaceControl()).apply {
+            taskInfo = task
+            setRotation(/* start= */ Surface.ROTATION_0, /* end= */ Surface.ROTATION_90)
+        }
+        val untrackedTransition = mock(IBinder::class.java)
+
+        immersiveHandler.onTransitionReady(
+            transition = untrackedTransition,
+            info = createTransitionInfo(changes = listOf(rotationChange)),
+        )
 
         assertThat(desktopRepository.isTaskInFullImmersiveState(task.taskId)).isFalse()
     }
@@ -361,6 +468,103 @@
         assertThat(desktopRepository.isTaskInFullImmersiveState(task.taskId)).isFalse()
     }
 
+    @Test
+    @EnableFlags(
+        Flags.FLAG_ENABLE_FULLY_IMMERSIVE_IN_DESKTOP,
+        Flags.FLAG_ENABLE_RESTORE_TO_PREVIOUS_SIZE_FROM_DESKTOP_IMMERSIVE
+    )
+    fun onTransitionReady_pendingExit_removesBoundsBeforeImmersive() {
+        val task = createFreeformTask()
+        whenever(mockShellTaskOrganizer.getRunningTaskInfo(task.taskId)).thenReturn(task)
+        val wct = WindowContainerTransaction()
+        val transition = Binder()
+        desktopRepository.setTaskInFullImmersiveState(
+            displayId = DEFAULT_DISPLAY,
+            taskId = task.taskId,
+            immersive = true
+        )
+        desktopRepository.saveBoundsBeforeFullImmersive(task.taskId, Rect(100, 100, 600, 600))
+        immersiveHandler.exitImmersiveIfApplicable(transition, wct, DEFAULT_DISPLAY)
+
+        immersiveHandler.onTransitionReady(
+            transition = transition,
+            info = createTransitionInfo(
+                changes = listOf(
+                    TransitionInfo.Change(task.token, SurfaceControl()).apply { taskInfo = task }
+                )
+            )
+        )
+
+        assertThat(desktopRepository.removeBoundsBeforeFullImmersive(task.taskId)).isNull()
+    }
+
+    @Test
+    @EnableFlags(Flags.FLAG_ENABLE_FULLY_IMMERSIVE_IN_DESKTOP)
+    @DisableFlags(Flags.FLAG_ENABLE_RESTORE_TO_PREVIOUS_SIZE_FROM_DESKTOP_IMMERSIVE)
+    fun exitImmersiveIfApplicable_changesBoundsToMaximize() {
+        val task = createFreeformTask()
+        desktopRepository.setTaskInFullImmersiveState(
+            displayId = DEFAULT_DISPLAY,
+            taskId = task.taskId,
+            immersive = true
+        )
+        whenever(mockShellTaskOrganizer.getRunningTaskInfo(task.taskId)).thenReturn(task)
+        val transaction = WindowContainerTransaction()
+
+        immersiveHandler.exitImmersiveIfApplicable(wct = transaction, taskInfo = task)
+
+        // With restore-to-previous-size disabled, the task is resized to its maximized bounds.
+        val maximizedBounds = calculateMaximizeBounds(mockDisplayLayout, task)
+        assertThat(transaction.hasBoundsChange(task.token, maximizedBounds)).isTrue()
+    }
+
+    @Test
+    @EnableFlags(
+        Flags.FLAG_ENABLE_FULLY_IMMERSIVE_IN_DESKTOP,
+        Flags.FLAG_ENABLE_RESTORE_TO_PREVIOUS_SIZE_FROM_DESKTOP_IMMERSIVE
+    )
+    fun exitImmersiveIfApplicable_preImmersiveBoundsSaved_changesBoundsToPreImmersiveBounds() {
+        val task = createFreeformTask()
+        val savedBounds = Rect(100, 100, 500, 500)
+        desktopRepository.setTaskInFullImmersiveState(
+            displayId = DEFAULT_DISPLAY,
+            taskId = task.taskId,
+            immersive = true
+        )
+        desktopRepository.saveBoundsBeforeFullImmersive(task.taskId, savedBounds)
+        whenever(mockShellTaskOrganizer.getRunningTaskInfo(task.taskId)).thenReturn(task)
+        val transaction = WindowContainerTransaction()
+
+        immersiveHandler.exitImmersiveIfApplicable(wct = transaction, taskInfo = task)
+
+        // The bounds recorded before entering immersive are the ones restored on exit.
+        val restoresSavedBounds = transaction.hasBoundsChange(task.token, savedBounds)
+        assertThat(restoresSavedBounds).isTrue()
+    }
+
+    @Test
+    @EnableFlags(
+        Flags.FLAG_ENABLE_FULLY_IMMERSIVE_IN_DESKTOP,
+        Flags.FLAG_ENABLE_RESTORE_TO_PREVIOUS_SIZE_FROM_DESKTOP_IMMERSIVE,
+        Flags.FLAG_ENABLE_WINDOWING_DYNAMIC_INITIAL_BOUNDS
+    )
+    fun exitImmersiveIfApplicable_preImmersiveBoundsNotSaved_changesBoundsToInitialBounds() {
+        val task = createFreeformTask()
+        desktopRepository.setTaskInFullImmersiveState(
+            displayId = DEFAULT_DISPLAY,
+            taskId = task.taskId,
+            immersive = true
+        )
+        whenever(mockShellTaskOrganizer.getRunningTaskInfo(task.taskId)).thenReturn(task)
+        val transaction = WindowContainerTransaction()
+
+        immersiveHandler.exitImmersiveIfApplicable(wct = transaction, taskInfo = task)
+
+        // Nothing was saved, so the task falls back to its dynamic initial bounds.
+        val initialBounds = calculateInitialBounds(mockDisplayLayout, task)
+        assertThat(transaction.hasBoundsChange(task.token, initialBounds)).isTrue()
+    }
+
     private fun createTransitionInfo(
         @TransitionType type: Int = TRANSIT_CHANGE,
         @TransitionFlags flags: Int = 0,
@@ -374,4 +578,17 @@
             change.key == token.asBinder()
                     && (change.value.windowSetMask and WINDOW_CONFIG_BOUNDS) != 0
         }
+
+    private fun WindowContainerTransaction.hasBoundsChange(
+        token: WindowContainerToken,
+        bounds: Rect,
+    ): Boolean {
+        val containerChange = changes[token.asBinder()] ?: return false
+        val setsBounds = (containerChange.windowSetMask and WINDOW_CONFIG_BOUNDS) != 0
+        return setsBounds && containerChange.configuration.windowConfiguration.bounds == bounds
+    }
+
+    companion object { // STABLE_BOUNDS is what the mockDisplayLayout stub reports.
+        private val STABLE_BOUNDS: Rect = Rect(0, 100, 2000, 1900)
+    }
 }
diff --git a/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/desktopmode/DesktopRepositoryTest.kt b/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/desktopmode/DesktopRepositoryTest.kt
index e20f0ec..3e22803 100644
--- a/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/desktopmode/DesktopRepositoryTest.kt
+++ b/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/desktopmode/DesktopRepositoryTest.kt
@@ -746,6 +746,18 @@
     }
 
     @Test
+    fun removeFreeformTask_removesTaskBoundsBeforeImmersive() {
+        val taskId = 1
+        val displayId = THIRD_DISPLAY
+        repo.addActiveTask(displayId, taskId)
+        repo.addOrMoveFreeformTaskToTop(displayId, taskId)
+        repo.saveBoundsBeforeFullImmersive(taskId, Rect(0, 0, 200, 200))
+        repo.removeFreeformTask(displayId, taskId)
+
+        assertThat(repo.removeBoundsBeforeFullImmersive(taskId)).isNull()
+    }
+
+    @Test
     fun removeFreeformTask_removesActiveTask() {
         val taskId = 1
         val listener = TestListener()
@@ -805,6 +817,28 @@
     }
 
     @Test
+    fun saveBoundsBeforeImmersive_boundsSavedByTaskId() {
+        val expectedBounds = Rect(0, 0, 200, 200)
+        repo.saveBoundsBeforeFullImmersive(/* taskId= */ 1, expectedBounds)
+
+        val savedBounds = repo.removeBoundsBeforeFullImmersive(/* taskId= */ 1)
+
+        assertThat(savedBounds).isEqualTo(expectedBounds)
+    }
+
+    @Test
+    fun removeBoundsBeforeImmersive_returnsNullAfterBoundsRemoved() {
+        val taskId = 1
+        repo.saveBoundsBeforeFullImmersive(taskId, Rect(0, 0, 200, 200))
+        // The first removal consumes the saved bounds...
+        repo.removeBoundsBeforeFullImmersive(taskId)
+
+        // ...so a second removal has nothing left to return.
+        val secondRemoval = repo.removeBoundsBeforeFullImmersive(taskId)
+        assertThat(secondRemoval).isNull()
+    }
+
+    @Test
     fun isMinimizedTask_minimizeTaskNotCalled_noTasksMinimized() {
         assertThat(repo.isMinimizedTask(taskId = 0)).isFalse()
         assertThat(repo.isMinimizedTask(taskId = 1)).isFalse()