Merge changes I1069e3a1,I2dc2f578 into main

* changes:
  Finish unfold Shell transition when fold is merged in
  Fix finishing unfold shell transition immediately after folding
diff --git a/libs/WindowManager/Shell/src/com/android/wm/shell/unfold/UnfoldTransitionHandler.java b/libs/WindowManager/Shell/src/com/android/wm/shell/unfold/UnfoldTransitionHandler.java
index f783b45..5a2abe1 100644
--- a/libs/WindowManager/Shell/src/com/android/wm/shell/unfold/UnfoldTransitionHandler.java
+++ b/libs/WindowManager/Shell/src/com/android/wm/shell/unfold/UnfoldTransitionHandler.java
@@ -16,9 +16,9 @@
 
 package com.android.wm.shell.unfold;
 
+import static android.view.Display.DEFAULT_DISPLAY;
 import static android.view.WindowManager.KEYGUARD_VISIBILITY_TRANSIT_FLAGS;
 import static android.view.WindowManager.TRANSIT_CHANGE;
-import static android.view.WindowManager.TRANSIT_FLAG_PHYSICAL_DISPLAY_SWITCH;
 
 import static com.android.wm.shell.protolog.ShellProtoLogGroup.WM_SHELL_TRANSITIONS;
 
@@ -30,6 +30,7 @@
 import android.window.TransitionRequestInfo;
 import android.window.WindowContainerTransaction;
 
+import androidx.annotation.IntDef;
 import androidx.annotation.NonNull;
 import androidx.annotation.Nullable;
 
@@ -45,6 +46,8 @@
 import com.android.wm.shell.unfold.animation.SplitTaskUnfoldAnimator;
 import com.android.wm.shell.unfold.animation.UnfoldTaskAnimator;
 
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.Executor;
@@ -56,6 +59,18 @@
  */
 public class UnfoldTransitionHandler implements TransitionHandler, UnfoldListener {
 
+    @Retention(RetentionPolicy.SOURCE)
+    @IntDef({
+            DefaultDisplayChange.DEFAULT_DISPLAY_NO_CHANGE,
+            DefaultDisplayChange.DEFAULT_DISPLAY_UNFOLD,
+            DefaultDisplayChange.DEFAULT_DISPLAY_FOLD,
+    })
+    private @interface DefaultDisplayChange {
+        int DEFAULT_DISPLAY_NO_CHANGE = 0;
+        int DEFAULT_DISPLAY_UNFOLD = 1;
+        int DEFAULT_DISPLAY_FOLD = 2;
+    }
+
     private final ShellUnfoldProgressProvider mUnfoldProgressProvider;
     private final Transitions mTransitions;
     private final Executor mExecutor;
@@ -66,7 +81,10 @@
     @Nullable
     private IBinder mTransition;
 
+    // TODO: b/318803244 - remove once we can guarantee that the animation always finishes
+    //  after the startAnimation callback
     private boolean mAnimationFinished = false;
+    private float mLastAnimationProgress = 0.0f;
     private final List<UnfoldTaskAnimator> mAnimators = new ArrayList<>();
 
     public UnfoldTransitionHandler(ShellInit shellInit,
@@ -107,16 +125,6 @@
             @NonNull SurfaceControl.Transaction startTransaction,
             @NonNull SurfaceControl.Transaction finishTransaction,
             @NonNull TransitionFinishCallback finishCallback) {
-        if (shouldPlayUnfoldAnimation(info) && transition != mTransition) {
-            // Take over transition that has unfold, we might receive it if no other handler
-            // accepted request in handleRequest, e.g. for rotation + unfold or
-            // TRANSIT_NONE + unfold transitions
-            mTransition = transition;
-
-            ProtoLog.v(WM_SHELL_TRANSITIONS, "UnfoldTransitionHandler: "
-                    + "take over startAnimation");
-        }
-
         if (transition != mTransition) return false;
 
         for (int i = 0; i < mAnimators.size(); i++) {
@@ -158,6 +166,8 @@
 
     @Override
     public void onStateChangeProgress(float progress) {
+        mLastAnimationProgress = progress;
+
         if (mTransition == null) return;
 
         SurfaceControl.Transaction transaction = null;
@@ -182,8 +192,14 @@
 
     @Override
     public void onStateChangeFinished() {
-        mAnimationFinished = true;
         finishTransitionIfNeeded();
+
+        // mLastAnimationProgress is guaranteed to be 0f when folding finishes, see
+        // {@link PhysicsBasedUnfoldTransitionProgressProvider#cancelTransition}.
+        // We can use it as an indication that the next animation progress events will be related
+        // to unfolding, so let's reset mAnimationFinished to 'false' in this case.
+        final boolean isFoldingFinished = mLastAnimationProgress == 0f;
+        mAnimationFinished = !isFoldingFinished;
     }
 
     @Override
@@ -211,6 +227,12 @@
         // Apply changes happening during the unfold animation immediately
         t.apply();
         finishCallback.onTransitionFinished(null);
+
+        if (getDefaultDisplayChange(info) == DefaultDisplayChange.DEFAULT_DISPLAY_FOLD) {
+            // Force-finish current unfold animation as we are processing folding now which doesn't
+            // have any animations on the Shell side
+            finishTransitionIfNeeded();
+        }
     }
 
     /** Whether `request` contains an unfold action. */
@@ -219,18 +241,25 @@
         if (!ValueAnimator.areAnimatorsEnabled()) return false;
 
         return (request.getType() == TRANSIT_CHANGE
-                && request.getDisplayChange() != null
-                && isUnfoldDisplayChange(request.getDisplayChange()));
+                && getDefaultDisplayChange(request.getDisplayChange())
+                == DefaultDisplayChange.DEFAULT_DISPLAY_UNFOLD);
     }
 
-    private boolean isUnfoldDisplayChange(
-            @NonNull TransitionRequestInfo.DisplayChange displayChange) {
+    @DefaultDisplayChange
+    private int getDefaultDisplayChange(
+            @Nullable TransitionRequestInfo.DisplayChange displayChange) {
+        if (displayChange == null) return DefaultDisplayChange.DEFAULT_DISPLAY_NO_CHANGE;
+
+        if (displayChange.getDisplayId() != DEFAULT_DISPLAY) {
+            return DefaultDisplayChange.DEFAULT_DISPLAY_NO_CHANGE;
+        }
+
         if (!displayChange.isPhysicalDisplayChanged()) {
-            return false;
+            return DefaultDisplayChange.DEFAULT_DISPLAY_NO_CHANGE;
         }
 
         if (displayChange.getStartAbsBounds() == null || displayChange.getEndAbsBounds() == null) {
-            return false;
+            return DefaultDisplayChange.DEFAULT_DISPLAY_NO_CHANGE;
         }
 
         // Handle only unfolding, currently we don't have an animation when folding
@@ -239,17 +268,11 @@
         final int startArea = displayChange.getStartAbsBounds().width()
                 * displayChange.getStartAbsBounds().height();
 
-        return endArea > startArea;
+        return endArea > startArea ? DefaultDisplayChange.DEFAULT_DISPLAY_UNFOLD
+                : DefaultDisplayChange.DEFAULT_DISPLAY_FOLD;
     }
 
-    /** Whether `transitionInfo` contains an unfold action. */
-    public boolean shouldPlayUnfoldAnimation(@NonNull TransitionInfo transitionInfo) {
-        // Unfold animation won't play when animations are disabled
-        if (!ValueAnimator.areAnimatorsEnabled()) return false;
-        // Only handle transitions that are marked as physical display switch
-        // See PhysicalDisplaySwitchTransitionLauncher for the conditions
-        if ((transitionInfo.getFlags() & TRANSIT_FLAG_PHYSICAL_DISPLAY_SWITCH) == 0) return false;
-
+    private int getDefaultDisplayChange(@NonNull TransitionInfo transitionInfo) {
         for (int i = 0; i < transitionInfo.getChanges().size(); i++) {
             final TransitionInfo.Change change = transitionInfo.getChanges().get(i);
             // We are interested only in display container changes
@@ -268,11 +291,13 @@
                     * change.getStartAbsBounds().height();
 
             if (afterArea > beforeArea) {
-                return true;
+                return DefaultDisplayChange.DEFAULT_DISPLAY_UNFOLD;
+            } else if (afterArea < beforeArea) { // same area (e.g. rotation): not a fold
+                return DefaultDisplayChange.DEFAULT_DISPLAY_FOLD;
             }
         }
 
-        return false;
+        return DefaultDisplayChange.DEFAULT_DISPLAY_NO_CHANGE;
     }
 
     @Nullable
@@ -293,10 +318,6 @@
     @Override
     public void onFoldStateChanged(boolean isFolded) {
         if (isFolded) {
-            // Reset unfold animation finished flag on folding, so it could be used next time
-            // when we unfold the device as an indication that animation hasn't finished yet
-            mAnimationFinished = false;
-
             // If we are currently animating unfold animation we should finish it because
             // the animation might not start and finish as the device was folded
             finishTransitionIfNeeded();
diff --git a/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/unfold/UnfoldTransitionHandlerTest.java b/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/unfold/UnfoldTransitionHandlerTest.java
index cf2de91..22da66d 100644
--- a/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/unfold/UnfoldTransitionHandlerTest.java
+++ b/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/unfold/UnfoldTransitionHandlerTest.java
@@ -18,13 +18,13 @@
 
 import static android.view.WindowManager.TRANSIT_CHANGE;
 import static android.view.WindowManager.TRANSIT_FLAG_KEYGUARD_GOING_AWAY;
-import static android.view.WindowManager.TRANSIT_FLAG_PHYSICAL_DISPLAY_SWITCH;
 import static android.view.WindowManager.TRANSIT_NONE;
 
 import static com.google.common.truth.Truth.assertThat;
 
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.clearInvocations;
+import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.verify;
@@ -50,6 +50,7 @@
 
 import org.junit.Before;
 import org.junit.Test;
+import org.mockito.InOrder;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -140,6 +141,32 @@
     }
 
     @Test
+    public void handleFoldMergeRequest_finishesTheTransition() {
+        TransitionRequestInfo requestInfo = createUnfoldTransitionRequestInfo();
+        mUnfoldTransitionHandler.handleRequest(mTransition, requestInfo);
+        TransitionFinishCallback finishCallback = mock(TransitionFinishCallback.class);
+        // Starts the animation, the handler should wait for mShellUnfoldProgressProvider to
+        // notify about the end of the animation
+        mUnfoldTransitionHandler.startAnimation(
+                mTransition,
+                mock(TransitionInfo.class),
+                mock(SurfaceControl.Transaction.class),
+                mock(SurfaceControl.Transaction.class),
+                finishCallback
+        );
+
+        // Send fold transition request
+        TransitionFinishCallback mergeFinishCallback = mock(TransitionFinishCallback.class);
+        mUnfoldTransitionHandler.mergeAnimation(new Binder(), createFoldTransitionInfo(),
+                mock(SurfaceControl.Transaction.class), mTransition, mergeFinishCallback);
+
+        // Verify that fold transition is merged into unfold and that unfold is finished
+        final InOrder inOrder = inOrder(mergeFinishCallback, finishCallback);
+        inOrder.verify(mergeFinishCallback).onTransitionFinished(any());
+        inOrder.verify(finishCallback).onTransitionFinished(any());
+    }
+
+    @Test
     public void startAnimation_animationHasNotFinishedYet_doesNotFinishTheTransition() {
         TransitionRequestInfo requestInfo = createUnfoldTransitionRequestInfo();
         mUnfoldTransitionHandler.handleRequest(mTransition, requestInfo);
@@ -174,29 +201,13 @@
     }
 
     @Test
-    public void startAnimation_differentTransitionFromRequestWithUnfold_startsAnimation() {
-        mUnfoldTransitionHandler.handleRequest(new Binder(), createNoneTransitionInfo());
-        TransitionFinishCallback finishCallback = mock(TransitionFinishCallback.class);
-
-        boolean animationStarted = mUnfoldTransitionHandler.startAnimation(
-                mTransition,
-                createUnfoldTransitionInfo(),
-                mock(SurfaceControl.Transaction.class),
-                mock(SurfaceControl.Transaction.class),
-                finishCallback
-        );
-
-        assertThat(animationStarted).isTrue();
-    }
-
-    @Test
     public void startAnimation_differentTransitionFromRequestWithResize_doesNotStartAnimation() {
         mUnfoldTransitionHandler.handleRequest(new Binder(), createNoneTransitionInfo());
         TransitionFinishCallback finishCallback = mock(TransitionFinishCallback.class);
 
         boolean animationStarted = mUnfoldTransitionHandler.startAnimation(
                 mTransition,
-                createDisplayResizeTransitionInfo(),
+                createUnfoldTransitionInfo(),
                 mock(SurfaceControl.Transaction.class),
                 mock(SurfaceControl.Transaction.class),
                 finishCallback
@@ -247,6 +258,7 @@
         TransitionFinishCallback finishCallback = mock(TransitionFinishCallback.class);
 
         mShellUnfoldProgressProvider.onStateChangeStarted();
+        mShellUnfoldProgressProvider.onStateChangeProgress(0.5f);
         mShellUnfoldProgressProvider.onStateChangeFinished();
         mUnfoldTransitionHandler.startAnimation(
                 mTransition,
@@ -279,6 +291,8 @@
         clearInvocations(finishCallback);
 
         // Fold
+        mShellUnfoldProgressProvider.onStateChangeProgress(/* progress= */ 0.0f);
+        mShellUnfoldProgressProvider.onStateChangeFinished();
         mShellUnfoldProgressProvider.onFoldStateChanged(/* isFolded= */ true);
 
         // Second unfold
@@ -370,6 +384,19 @@
                 triggerTaskInfo, /* remoteTransition= */ null, displayChange, 0 /* flags */);
     }
 
+    private TransitionInfo createFoldTransitionInfo() {
+        final TransitionInfo transitionInfo = new TransitionInfo(TRANSIT_CHANGE, /* flags= */ 0);
+
+        final TransitionInfo.Change change = new TransitionInfo.Change(/* container= */ null,
+                /* leash= */ null);
+        change.setFlags(TransitionInfo.FLAG_IS_DISPLAY);
+        change.setStartAbsBounds(new Rect(0, 0, 200, 200));
+        change.setEndAbsBounds(new Rect(0, 0, 100, 100));
+        transitionInfo.addChange(change);
+
+        return transitionInfo;
+    }
+
     private TransitionRequestInfo createNoneTransitionInfo() {
         return new TransitionRequestInfo(TRANSIT_NONE,
                 /* triggerTask= */ null, /* remoteTransition= */ null,
@@ -446,17 +473,6 @@
         change.setEndAbsBounds(new Rect(0, 0, 100, 100));
         change.setFlags(TransitionInfo.FLAG_IS_DISPLAY);
         transitionInfo.addChange(change);
-        transitionInfo.setFlags(TRANSIT_FLAG_PHYSICAL_DISPLAY_SWITCH);
-        return transitionInfo;
-    }
-
-    private TransitionInfo createDisplayResizeTransitionInfo() {
-        TransitionInfo transitionInfo = new TransitionInfo(TRANSIT_CHANGE, /* flags= */ 0);
-        TransitionInfo.Change change = new TransitionInfo.Change(null, mock(SurfaceControl.class));
-        change.setStartAbsBounds(new Rect(0, 0, 10, 10));
-        change.setEndAbsBounds(new Rect(0, 0, 100, 100));
-        change.setFlags(TransitionInfo.FLAG_IS_DISPLAY);
-        transitionInfo.addChange(change);
         return transitionInfo;
     }