Fix desktop immersive's state clean up

This change fixes two bugs related to state cleanup:
1) clearState() was called by #animateResizeChange when the animation
   finished even if the animation was not requested by a transition this
   handler is tracking (e.g. mixed transitions). To fix this, an end
   callback is sent to #animateResizeChange instead and only the
   DesktopImmersiveController#startAnimation call path clears the state
   when the callback is run.
2) Even when animating transitions handled by this handler, it is
   possible that its changes do not contain the expected immersive task,
   in which case there's nothing to animate, but the state should still
   be cleaned up.

Flag: com.android.window.flags.enable_fully_immersive_in_desktop
Fix: 376913079
Test: atest WMShellUnitTests
Change-Id: Ia7519bcb8e565b699d6b1cefd28869dc44da6413
diff --git a/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopImmersiveController.kt b/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopImmersiveController.kt
index 8b76aa3..1acde73 100644
--- a/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopImmersiveController.kt
+++ b/libs/WindowManager/Shell/src/com/android/wm/shell/desktopmode/DesktopImmersiveController.kt
@@ -74,7 +74,8 @@
         { SurfaceControl.Transaction() }
     )
 
-    private var state: TransitionState? = null
+    @VisibleForTesting
+    var state: TransitionState? = null
 
     @VisibleForTesting
     val pendingExternalExitTransitions = mutableListOf<ExternalPendingExit>()
@@ -99,7 +100,10 @@
     /** Starts a transition to enter full immersive state inside the desktop. */
     fun moveTaskToImmersive(taskInfo: RunningTaskInfo) {
         if (inProgress) {
-            logV("Cannot start entry because transition already in progress.")
+            logV(
+                "Cannot start entry because transition(s) already in progress: %s",
+                getRunningTransitions()
+            )
             return
         }
         val wct = WindowContainerTransaction().apply {
@@ -117,7 +121,10 @@
 
     fun moveTaskToNonImmersive(taskInfo: RunningTaskInfo) {
         if (inProgress) {
-            logV("Cannot start exit because transition already in progress.")
+            logV(
+                "Cannot start exit because transition(s) already in progress: %s",
+                getRunningTransitions()
+            )
             return
         }
 
@@ -242,14 +249,19 @@
         finishCallback: Transitions.TransitionFinishCallback
     ): Boolean {
         val state = requireState()
-        if (transition != state.transition) return false
+        check(state.transition == transition) {
+            "Transition $transition did not match expected state=$state"
+        }
         logD("startAnimation transition=%s", transition)
         animateResize(
             targetTaskId = state.taskId,
             info = info,
             startTransaction = startTransaction,
             finishTransaction = finishTransaction,
-            finishCallback = finishCallback
+            finishCallback = {
+                finishCallback.onTransitionFinished(/* wct= */ null)
+                clearState()
+            },
         )
         return true
     }
@@ -259,12 +271,18 @@
         info: TransitionInfo,
         startTransaction: SurfaceControl.Transaction,
         finishTransaction: SurfaceControl.Transaction,
-        finishCallback: Transitions.TransitionFinishCallback
+        finishCallback: Transitions.TransitionFinishCallback,
     ) {
         logD("animateResize for task#%d", targetTaskId)
-        val change = info.changes.first { c ->
+        val change = info.changes.firstOrNull { c ->
             val taskInfo = c.taskInfo
-            return@first taskInfo != null && taskInfo.taskId == targetTaskId
+            return@firstOrNull taskInfo != null && taskInfo.taskId == targetTaskId
+        }
+        if (change == null) {
+            logD("Did not find change for task#%d to animate", targetTaskId)
+            startTransaction.apply()
+            finishCallback.onTransitionFinished(/* wct= */ null)
+            return
         }
         animateResizeChange(change, startTransaction, finishTransaction, finishCallback)
     }
@@ -305,7 +323,6 @@
                         .apply()
                     onTaskResizeAnimationListener?.onAnimationEnd(taskId)
                     finishCallback.onTransitionFinished(null /* wct */)
-                    clearState()
                 }
             )
             addUpdateListener { animation ->
@@ -472,6 +489,17 @@
     private fun requireState(): TransitionState =
         state ?: error("Expected non-null transition state")
 
+    private fun getRunningTransitions(): List<IBinder> {
+        val running = mutableListOf<IBinder>()
+        state?.let {
+            running.add(it.transition)
+        }
+        pendingExternalExitTransitions.forEach {
+            running.add(it.transition)
+        }
+        return running
+    }
+
     private fun TransitionInfo.hasTaskChange(taskId: Int): Boolean =
         changes.any { c -> c.taskInfo?.taskId == taskId }
 
@@ -483,7 +511,8 @@
     }
 
     /** The state of the currently running transition. */
-    private data class TransitionState(
+    @VisibleForTesting
+    data class TransitionState(
         val transition: IBinder,
         val displayId: Int,
         val taskId: Int,
@@ -516,7 +545,8 @@
         fun asExit(): Exit? = if (this is Exit) this else null
     }
 
-    private enum class Direction {
+    @VisibleForTesting
+    enum class Direction {
         ENTER, EXIT
     }
 
@@ -528,9 +558,10 @@
         ProtoLog.d(WM_SHELL_DESKTOP_MODE, "%s: $msg", TAG, *arguments)
     }
 
-    private companion object {
+    companion object {
         private const val TAG = "DesktopImmersive"
 
-        private const val FULL_IMMERSIVE_ANIM_DURATION_MS = 336L
+        @VisibleForTesting
+        const val FULL_IMMERSIVE_ANIM_DURATION_MS = 336L
     }
 }
diff --git a/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/desktopmode/DesktopImmersiveControllerTest.kt b/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/desktopmode/DesktopImmersiveControllerTest.kt
index d5b1790..a4f4d05 100644
--- a/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/desktopmode/DesktopImmersiveControllerTest.kt
+++ b/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/desktopmode/DesktopImmersiveControllerTest.kt
@@ -15,6 +15,7 @@
  */
 package com.android.wm.shell.desktopmode
 
+import android.animation.AnimatorTestRule
 import android.app.ActivityManager.RunningTaskInfo
 import android.app.WindowConfiguration.WINDOW_CONFIG_BOUNDS
 import android.graphics.Rect
@@ -24,6 +25,7 @@
 import android.platform.test.annotations.EnableFlags
 import android.platform.test.flag.junit.SetFlagsRule
 import android.testing.AndroidTestingRunner
+import android.testing.TestableLooper
 import android.view.Display.DEFAULT_DISPLAY
 import android.view.Surface
 import android.view.SurfaceControl
@@ -43,6 +45,7 @@
 import com.android.wm.shell.desktopmode.DesktopTestHelpers.Companion.createFreeformTask
 import com.android.wm.shell.sysui.ShellInit
 import com.android.wm.shell.transition.Transitions
+import com.android.wm.shell.util.StubTransaction
 import com.google.common.truth.Truth.assertThat
 import org.junit.Before
 import org.junit.Rule
@@ -64,17 +67,19 @@
  * Usage: atest WMShellUnitTests:DesktopImmersiveControllerTest
  */
 @SmallTest
+@TestableLooper.RunWithLooper
 @RunWith(AndroidTestingRunner::class)
 class DesktopImmersiveControllerTest : ShellTestCase() {
 
     @JvmField @Rule val setFlagsRule = SetFlagsRule()
+    @JvmField @Rule val animatorTestRule = AnimatorTestRule(this)
 
     @Mock private lateinit var mockTransitions: Transitions
     private lateinit var desktopRepository: DesktopRepository
     @Mock private lateinit var mockDisplayController: DisplayController
     @Mock private lateinit var mockShellTaskOrganizer: ShellTaskOrganizer
     @Mock private lateinit var mockDisplayLayout: DisplayLayout
-    private val transactionSupplier = { SurfaceControl.Transaction() }
+    private val transactionSupplier = { StubTransaction() }
 
     private lateinit var controller: DesktopImmersiveController
 
@@ -674,6 +679,60 @@
         assertThat(controller.isImmersiveChange(transition, change)).isTrue()
     }
 
+    @Test
+    @EnableFlags(Flags.FLAG_ENABLE_FULLY_IMMERSIVE_IN_DESKTOP)
+    fun externalAnimateResizeChange_doesNotCleanUpPendingTransitionState() {
+        val task = createFreeformTask()
+        val mockBinder = mock(IBinder::class.java)
+        whenever(mockTransitions.startTransition(eq(TRANSIT_CHANGE), any(), eq(controller)))
+            .thenReturn(mockBinder)
+        desktopRepository.setTaskInFullImmersiveState(
+            displayId = task.displayId,
+            taskId = task.taskId,
+            immersive = true
+        )
+
+        controller.moveTaskToNonImmersive(task)
+
+        controller.animateResizeChange(
+            change = TransitionInfo.Change(task.token, SurfaceControl()).apply {
+                taskInfo = task
+            },
+            startTransaction = StubTransaction(),
+            finishTransaction = StubTransaction(),
+            finishCallback = { }
+        )
+        animatorTestRule.advanceTimeBy(DesktopImmersiveController.FULL_IMMERSIVE_ANIM_DURATION_MS)
+
+        assertThat(controller.state).isNotNull()
+    }
+
+    @Test
+    @EnableFlags(Flags.FLAG_ENABLE_FULLY_IMMERSIVE_IN_DESKTOP)
+    fun startAnimation_missingChange_clearsState() {
+        val task = createFreeformTask()
+        val mockBinder = mock(IBinder::class.java)
+        whenever(mockTransitions.startTransition(eq(TRANSIT_CHANGE), any(), eq(controller)))
+            .thenReturn(mockBinder)
+        desktopRepository.setTaskInFullImmersiveState(
+            displayId = task.displayId,
+            taskId = task.taskId,
+            immersive = false
+        )
+
+        controller.moveTaskToImmersive(task)
+
+        controller.startAnimation(
+            transition = mockBinder,
+            info = createTransitionInfo(changes = emptyList()),
+            startTransaction = StubTransaction(),
+            finishTransaction = StubTransaction(),
+            finishCallback = {}
+        )
+
+        assertThat(controller.state).isNull()
+    }
+
     private fun createTransitionInfo(
         @TransitionType type: Int = TRANSIT_CHANGE,
         @TransitionFlags flags: Int = 0,