Merge "Move reentry state into PipBoundsState"
diff --git a/libs/WindowManager/Shell/src/com/android/wm/shell/pip/PipBoundsHandler.java b/libs/WindowManager/Shell/src/com/android/wm/shell/pip/PipBoundsHandler.java
index de3261b..2ab087c 100644
--- a/libs/WindowManager/Shell/src/com/android/wm/shell/pip/PipBoundsHandler.java
+++ b/libs/WindowManager/Shell/src/com/android/wm/shell/pip/PipBoundsHandler.java
@@ -22,9 +22,9 @@
 import static android.view.Surface.ROTATION_0;
 import static android.view.Surface.ROTATION_180;
 
+import android.annotation.NonNull;
 import android.app.ActivityTaskManager;
 import android.app.ActivityTaskManager.RootTaskInfo;
-import android.content.ComponentName;
 import android.content.Context;
 import android.content.res.Resources;
 import android.graphics.Point;
@@ -52,14 +52,11 @@
     private static final String TAG = PipBoundsHandler.class.getSimpleName();
     private static final float INVALID_SNAP_FRACTION = -1f;
 
+    private final @NonNull PipBoundsState mPipBoundsState;
     private final PipSnapAlgorithm mSnapAlgorithm;
     private final DisplayInfo mDisplayInfo = new DisplayInfo();
     private DisplayLayout mDisplayLayout;
 
-    private ComponentName mLastPipComponentName;
-    private float mReentrySnapFraction = INVALID_SNAP_FRACTION;
-    private Size mReentrySize;
-
     private float mDefaultAspectRatio;
     private float mMinAspectRatio;
     private float mMaxAspectRatio;
@@ -75,7 +72,8 @@
     private boolean mIsShelfShowing;
     private int mShelfHeight;
 
-    public PipBoundsHandler(Context context) {
+    public PipBoundsHandler(Context context, @NonNull PipBoundsState pipBoundsState) {
+        mPipBoundsState = pipBoundsState;
         mSnapAlgorithm = new PipSnapAlgorithm(context);
         mDisplayLayout = new DisplayLayout();
         reloadResources(context);
@@ -175,40 +173,6 @@
     }
 
     /**
-     * Responds to IPinnedStackListener on saving reentry snap fraction and size
-     * for a given {@link ComponentName}.
-     */
-    public void onSaveReentryBounds(ComponentName componentName, Rect bounds) {
-        mReentrySnapFraction = getSnapFraction(bounds);
-        mReentrySize = new Size(bounds.width(), bounds.height());
-        mLastPipComponentName = componentName;
-    }
-
-    /**
-     * Responds to IPinnedStackListener on resetting reentry snap fraction and size
-     * for a given {@link ComponentName}.
-     */
-    public void onResetReentryBounds(ComponentName componentName) {
-        if (componentName.equals(mLastPipComponentName)) {
-            onResetReentryBoundsUnchecked();
-        }
-    }
-
-    private void onResetReentryBoundsUnchecked() {
-        mReentrySnapFraction = INVALID_SNAP_FRACTION;
-        mReentrySize = null;
-        mLastPipComponentName = null;
-    }
-
-    /**
-     * Returns ture if there's a valid snap fraction. This is used with {@link EXTRA_IS_FIRST_ENTRY}
-     * to see if this is the first time user has entered PIP for the component.
-     */
-    public boolean hasSaveReentryBounds() {
-        return mReentrySnapFraction != INVALID_SNAP_FRACTION;
-    }
-
-    /**
      * The {@link PipSnapAlgorithm} is couple on display bounds
      * @return {@link PipSnapAlgorithm}.
      */
@@ -250,37 +214,43 @@
     }
 
     /**
-     * See {@link #getDestinationBounds(ComponentName, float, Rect, Size, boolean)}
+     * See {@link #getDestinationBounds(float, Rect, Size, boolean)}
      */
-    public Rect getDestinationBounds(ComponentName componentName, float aspectRatio, Rect bounds,
-            Size minimalSize) {
-        return getDestinationBounds(componentName, aspectRatio, bounds, minimalSize,
+    public Rect getDestinationBounds(float aspectRatio, Rect bounds, Size minimalSize) {
+        return getDestinationBounds(aspectRatio, bounds, minimalSize,
                 false /* useCurrentMinEdgeSize */);
     }
 
     /**
      * @return {@link Rect} of the destination PiP window bounds.
      */
-    public Rect getDestinationBounds(ComponentName componentName, float aspectRatio, Rect bounds,
+    public Rect getDestinationBounds(float aspectRatio, Rect bounds,
             Size minimalSize, boolean useCurrentMinEdgeSize) {
-        if (!componentName.equals(mLastPipComponentName)) {
-            onResetReentryBoundsUnchecked();
-            mLastPipComponentName = componentName;
-        }
+        boolean isReentryBounds = false;
         final Rect destinationBounds;
         if (bounds == null) {
-            final Rect defaultBounds = getDefaultBounds(mReentrySnapFraction, mReentrySize);
-            destinationBounds = new Rect(defaultBounds);
-            if (mReentrySnapFraction == INVALID_SNAP_FRACTION && mReentrySize == null) {
+            // No current bounds were given: compute the initial entry bounds.
+            final PipBoundsState.PipReentryState state = mPipBoundsState.getReentryState();
+
+            final Rect defaultBounds;
+            if (state != null) {
+                // Restore the size and snap fraction saved in the reentry state.
+                defaultBounds = getDefaultBounds(state.getSnapFraction(), state.getSize());
+                isReentryBounds = true;
+            } else {
+                // No reentry state was saved: fall back to the default bounds.
+                defaultBounds = getDefaultBounds(INVALID_SNAP_FRACTION, null /* size */);
                 mOverrideMinimalSize = minimalSize;
             }
+
+            destinationBounds = new Rect(defaultBounds);
         } else {
+            // Adjusting the given bounds (e.g. when the aspect ratio changes).
             destinationBounds = new Rect(bounds);
         }
         if (isValidPictureInPictureAspectRatio(aspectRatio)) {
-            boolean useCurrentSize = bounds == null && mReentrySize != null;
             transformBoundsToAspectRatio(destinationBounds, aspectRatio, useCurrentMinEdgeSize,
-                    useCurrentSize);
+                    isReentryBounds);
         }
         mAspectRatio = aspectRatio;
         return destinationBounds;
@@ -533,9 +503,6 @@
     public void dump(PrintWriter pw, String prefix) {
         final String innerPrefix = prefix + "  ";
         pw.println(prefix + TAG);
-        pw.println(innerPrefix + "mLastPipComponentName=" + mLastPipComponentName);
-        pw.println(innerPrefix + "mReentrySnapFraction=" + mReentrySnapFraction);
-        pw.println(innerPrefix + "mReentrySize=" + mReentrySize);
         pw.println(innerPrefix + "mDisplayInfo=" + mDisplayInfo);
         pw.println(innerPrefix + "mDefaultAspectRatio=" + mDefaultAspectRatio);
         pw.println(innerPrefix + "mMinAspectRatio=" + mMinAspectRatio);
diff --git a/libs/WindowManager/Shell/src/com/android/wm/shell/pip/PipBoundsState.java b/libs/WindowManager/Shell/src/com/android/wm/shell/pip/PipBoundsState.java
index 10e5c3d..2625f16 100644
--- a/libs/WindowManager/Shell/src/com/android/wm/shell/pip/PipBoundsState.java
+++ b/libs/WindowManager/Shell/src/com/android/wm/shell/pip/PipBoundsState.java
@@ -17,9 +17,15 @@
 package com.android.wm.shell.pip;
 
 import android.annotation.NonNull;
+import android.annotation.Nullable;
+import android.content.ComponentName;
 import android.graphics.Rect;
+import android.util.Size;
+
+import com.android.internal.annotations.VisibleForTesting;
 
 import java.io.PrintWriter;
+import java.util.Objects;
 
 /**
  * Singleton source of truth for the current state of PIP bounds.
@@ -28,6 +34,8 @@
     private static final String TAG = PipBoundsState.class.getSimpleName();
 
     private final @NonNull Rect mBounds = new Rect();
+    private PipReentryState mPipReentryState;
+    private ComponentName mLastPipComponentName;
 
     void setBounds(@NonNull Rect bounds) {
         mBounds.set(bounds);
@@ -39,11 +47,99 @@
     }
 
     /**
+     * Save the reentry state to restore to when re-entering PIP mode.
+     *
+     * TODO(b/169373982): consider refactoring this so that this class alone can use mBounds and
+     * calculate the snap fraction to save for re-entry.
+     *
+     * @param bounds the PIP bounds whose size is saved for re-entry
+     * @param fraction the snap fraction of {@code bounds} to save for re-entry
+     */
+    public void saveReentryState(@NonNull Rect bounds, float fraction) {
+        mPipReentryState = new PipReentryState(new Size(bounds.width(), bounds.height()), fraction);
+    }
+
+    /**
+     * Returns the saved reentry state, or {@code null} if there is none.
+     */
+    @Nullable
+    public PipReentryState getReentryState() {
+        return mPipReentryState;
+    }
+
+    /**
+     * Set the last {@link ComponentName} to enter PIP mode.
+     *
+     * <p>Changing the component (including to {@code null}) clears the saved reentry state, so
+     * that state is never restored for a different component.
+     */
+    public void setLastPipComponentName(@Nullable ComponentName lastPipComponentName) {
+        final boolean changed = !Objects.equals(mLastPipComponentName, lastPipComponentName);
+        mLastPipComponentName = lastPipComponentName;
+        if (changed) {
+            clearReentryState();
+        }
+    }
+
+    /**
+     * Returns the last {@link ComponentName} to enter PIP mode, or {@code null} if unset.
+     */
+    @Nullable
+    public ComponentName getLastPipComponentName() {
+        return mLastPipComponentName;
+    }
+
+    /**
+     * Discards the saved reentry state so that it is not restored on the next entry.
+     */
+    @VisibleForTesting
+    void clearReentryState() {
+        mPipReentryState = null;
+    }
+
+    /**
+     * The PIP size and snap fraction saved for restoring PIP on re-entry.
+     */
+    static final class PipReentryState {
+        private static final String TAG = PipReentryState.class.getSimpleName();
+
+        private final @NonNull Size mSize;
+        private final float mSnapFraction;
+
+        PipReentryState(@NonNull Size size, float snapFraction) {
+            mSize = size;
+            mSnapFraction = snapFraction;
+        }
+
+        @NonNull
+        Size getSize() {
+            return mSize;
+        }
+
+        float getSnapFraction() {
+            return mSnapFraction;
+        }
+
+        void dump(PrintWriter pw, String prefix) {
+            final String innerPrefix = prefix + "  ";
+            pw.println(prefix + TAG);
+            pw.println(innerPrefix + "mSize=" + mSize);
+            pw.println(innerPrefix + "mSnapFraction=" + mSnapFraction);
+        }
+    }
+
+    /**
      * Dumps internal state.
      */
     public void dump(PrintWriter pw, String prefix) {
         final String innerPrefix = prefix + "  ";
         pw.println(prefix + TAG);
         pw.println(innerPrefix + "mBounds=" + mBounds);
+        pw.println(innerPrefix + "mLastPipComponentName=" + mLastPipComponentName);
+        if (mPipReentryState == null) {
+            pw.println(innerPrefix + "mPipReentryState=null");
+        } else {
+            mPipReentryState.dump(pw, innerPrefix);
+        }
     }
 }
diff --git a/libs/WindowManager/Shell/src/com/android/wm/shell/pip/PipTaskOrganizer.java b/libs/WindowManager/Shell/src/com/android/wm/shell/pip/PipTaskOrganizer.java
index 15fd424..22dc084 100644
--- a/libs/WindowManager/Shell/src/com/android/wm/shell/pip/PipTaskOrganizer.java
+++ b/libs/WindowManager/Shell/src/com/android/wm/shell/pip/PipTaskOrganizer.java
@@ -329,7 +329,8 @@
             PictureInPictureParams pictureInPictureParams) {
         mShouldIgnoreEnteringPipTransition = true;
         mState = State.ENTERING_PIP;
-        return mPipBoundsHandler.getDestinationBounds(componentName,
+        mPipBoundsState.setLastPipComponentName(componentName);
+        return mPipBoundsHandler.getDestinationBounds(
                 getAspectRatioOrDefault(pictureInPictureParams),
                 null /* bounds */, getMinimalSize(activityInfo));
     }
@@ -465,6 +466,7 @@
         mLeash = leash;
         mInitialState.put(mToken.asBinder(), new Configuration(mTaskInfo.configuration));
         mPictureInPictureParams = mTaskInfo.pictureInPictureParams;
+        mPipBoundsState.setLastPipComponentName(mTaskInfo.topActivity);
 
         mPipUiEventLoggerLogger.setTaskInfo(mTaskInfo);
         mPipUiEventLoggerLogger.log(PipUiEventLogger.PipUiEventEnum.PICTURE_IN_PICTURE_ENTER);
@@ -491,7 +493,7 @@
         }
 
         final Rect destinationBounds = mPipBoundsHandler.getDestinationBounds(
-                mTaskInfo.topActivity, getAspectRatioOrDefault(mPictureInPictureParams),
+                getAspectRatioOrDefault(mPictureInPictureParams),
                 null /* bounds */, getMinimalSize(mTaskInfo.topActivityInfo));
         Objects.requireNonNull(destinationBounds, "Missing destination bounds");
         final Rect currentBounds = mTaskInfo.configuration.windowConfiguration.getBounds();
@@ -686,13 +688,14 @@
     @Override
     public void onTaskInfoChanged(ActivityManager.RunningTaskInfo info) {
         Objects.requireNonNull(mToken, "onTaskInfoChanged requires valid existing mToken");
+        mPipBoundsState.setLastPipComponentName(info.topActivity);
         final PictureInPictureParams newParams = info.pictureInPictureParams;
         if (newParams == null || !applyPictureInPictureParams(newParams)) {
             Log.d(TAG, "Ignored onTaskInfoChanged with PiP param: " + newParams);
             return;
         }
         final Rect destinationBounds = mPipBoundsHandler.getDestinationBounds(
-                info.topActivity, getAspectRatioOrDefault(newParams),
+                getAspectRatioOrDefault(newParams),
                 mPipBoundsState.getBounds(), getMinimalSize(info.topActivityInfo),
-                true /* userCurrentMinEdgeSize */);
+                true /* useCurrentMinEdgeSize */);
         Objects.requireNonNull(destinationBounds, "Missing destination bounds");
@@ -709,7 +712,7 @@
     public void onFixedRotationFinished(int displayId) {
         if (mShouldDeferEnteringPip && mState.isInPip()) {
             final Rect destinationBounds = mPipBoundsHandler.getDestinationBounds(
-                    mTaskInfo.topActivity, getAspectRatioOrDefault(mPictureInPictureParams),
+                    getAspectRatioOrDefault(mPictureInPictureParams),
                     null /* bounds */, getMinimalSize(mTaskInfo.topActivityInfo));
             // schedule a regular animation to ensure all the callbacks are still being sent
             enterPipWithAlphaAnimation(destinationBounds, 0 /* durationMs */);
@@ -783,7 +786,7 @@
         }
 
         final Rect newDestinationBounds = mPipBoundsHandler.getDestinationBounds(
-                mTaskInfo.topActivity, getAspectRatioOrDefault(mPictureInPictureParams),
+                getAspectRatioOrDefault(mPictureInPictureParams),
                 null /* bounds */, getMinimalSize(mTaskInfo.topActivityInfo));
         if (newDestinationBounds.equals(currentDestinationBounds)) return;
         if (animator.getAnimationType() == ANIM_TYPE_BOUNDS) {
diff --git a/libs/WindowManager/Shell/src/com/android/wm/shell/pip/phone/PipController.java b/libs/WindowManager/Shell/src/com/android/wm/shell/pip/phone/PipController.java
index 13f5ac3..a3d21f2 100644
--- a/libs/WindowManager/Shell/src/com/android/wm/shell/pip/phone/PipController.java
+++ b/libs/WindowManager/Shell/src/com/android/wm/shell/pip/phone/PipController.java
@@ -173,7 +173,13 @@
 
         @Override
         public void onActivityHidden(ComponentName componentName) {
-            mHandler.post(() -> mPipBoundsHandler.onResetReentryBounds(componentName));
+            mHandler.post(() -> {
+                if (componentName.equals(mPipBoundsState.getLastPipComponentName())) {
+                    // The last PIP activity was hidden; resetting the component to null clears
+                    // the reentry state saved for it, so it is not restored on the next entry.
+                    mPipBoundsState.setLastPipComponentName(null);
+                }
+            });
         }
 
         @Override
@@ -384,7 +390,8 @@
         if (isOutPipDirection(direction)) {
             // Exiting PIP, save the reentry bounds to restore to when re-entering.
             updateReentryBounds(pipBounds);
-            mPipBoundsHandler.onSaveReentryBounds(activity, mReentryBounds);
+            final float snapFraction = mPipBoundsHandler.getSnapFraction(mReentryBounds);
+            mPipBoundsState.saveReentryState(mReentryBounds, snapFraction);
         }
         // Disable touches while the animation is running
         mTouchHandler.setTouchEnabled(false);
diff --git a/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/pip/PipBoundsHandlerTest.java b/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/pip/PipBoundsHandlerTest.java
index d9e3148..e0ac8e2 100644
--- a/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/pip/PipBoundsHandlerTest.java
+++ b/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/pip/PipBoundsHandlerTest.java
@@ -20,7 +20,6 @@
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 
-import android.content.ComponentName;
 import android.graphics.Rect;
 import android.testing.AndroidTestingRunner;
 import android.testing.TestableLooper;
@@ -58,15 +57,13 @@
 
     private PipBoundsHandler mPipBoundsHandler;
     private DisplayInfo mDefaultDisplayInfo;
-    private ComponentName mTestComponentName1;
-    private ComponentName mTestComponentName2;
+    private PipBoundsState mPipBoundsState;
 
     @Before
     public void setUp() throws Exception {
         initializeMockResources();
-        mPipBoundsHandler = new PipBoundsHandler(mContext);
-        mTestComponentName1 = new ComponentName(mContext, "component1");
-        mTestComponentName2 = new ComponentName(mContext, "component2");
+        mPipBoundsState = new PipBoundsState();
+        mPipBoundsHandler = new PipBoundsHandler(mContext, mPipBoundsState);
 
         mPipBoundsHandler.onDisplayInfoChanged(mDefaultDisplayInfo);
     }
@@ -126,8 +123,8 @@
                 (MAX_ASPECT_RATIO + DEFAULT_ASPECT_RATIO) / 2
         };
         for (float aspectRatio : aspectRatios) {
-            final Rect destinationBounds = mPipBoundsHandler.getDestinationBounds(
-                    mTestComponentName1, aspectRatio, EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
+            final Rect destinationBounds = mPipBoundsHandler.getDestinationBounds(aspectRatio,
+                    EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
             final float actualAspectRatio =
                     destinationBounds.width() / (destinationBounds.height() * 1f);
             assertEquals("Destination bounds matches the given aspect ratio",
@@ -142,8 +139,8 @@
                 MAX_ASPECT_RATIO * 2
         };
         for (float aspectRatio : invalidAspectRatios) {
-            final Rect destinationBounds = mPipBoundsHandler.getDestinationBounds(
-                    mTestComponentName1, aspectRatio, EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
+            final Rect destinationBounds = mPipBoundsHandler.getDestinationBounds(aspectRatio,
+                    EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
             final float actualAspectRatio =
                     destinationBounds.width() / (destinationBounds.height() * 1f);
             assertEquals("Destination bounds fallbacks to default aspect ratio",
@@ -158,8 +155,8 @@
         final Rect currentBounds = new Rect(0, 0, 0, 100);
         currentBounds.right = (int) (currentBounds.height() * aspectRatio) + currentBounds.left;
 
-        final Rect destinationBounds = mPipBoundsHandler.getDestinationBounds(
-                mTestComponentName1, aspectRatio, currentBounds, EMPTY_MINIMAL_SIZE);
+        final Rect destinationBounds = mPipBoundsHandler.getDestinationBounds(aspectRatio,
+                currentBounds, EMPTY_MINIMAL_SIZE);
 
         final float actualAspectRatio =
                 destinationBounds.width() / (destinationBounds.height() * 1f);
@@ -182,8 +179,8 @@
         for (int i = 0; i < aspectRatios.length; i++) {
             final float aspectRatio = aspectRatios[i];
             final Size minimalSize = minimalSizes[i];
-            final Rect destinationBounds = mPipBoundsHandler.getDestinationBounds(
-                    mTestComponentName1, aspectRatio, EMPTY_CURRENT_BOUNDS, minimalSize);
+            final Rect destinationBounds = mPipBoundsHandler.getDestinationBounds(aspectRatio,
+                    EMPTY_CURRENT_BOUNDS, minimalSize);
             assertTrue("Destination bounds is no smaller than minimal requirement",
                     (destinationBounds.width() == minimalSize.getWidth()
                             && destinationBounds.height() >= minimalSize.getHeight())
@@ -203,8 +200,8 @@
         currentBounds.right = (int) (currentBounds.height() * aspectRatio) + currentBounds.left;
         final Size minSize = new Size(currentBounds.width() / 2, currentBounds.height() / 2);
 
-        final Rect destinationBounds = mPipBoundsHandler.getDestinationBounds(
-                mTestComponentName1, aspectRatio, currentBounds, minSize);
+        final Rect destinationBounds = mPipBoundsHandler.getDestinationBounds(aspectRatio,
+                currentBounds, minSize);
 
         assertTrue("Destination bounds ignores minimal size",
                 destinationBounds.width() > minSize.getWidth()
@@ -212,28 +209,44 @@
     }
 
     @Test
-    public void getDestinationBounds_withDifferentComponentName_ignoreLastPosition() {
-        final Rect oldPosition = mPipBoundsHandler.getDestinationBounds(mTestComponentName1,
-                DEFAULT_ASPECT_RATIO, EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
+    public void getDestinationBounds_reentryStateExists_restoreLastSize() {
+        final Rect reentryBounds = mPipBoundsHandler.getDestinationBounds(DEFAULT_ASPECT_RATIO,
+                EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
+        reentryBounds.scale(1.25f);
+        final float reentrySnapFraction = mPipBoundsHandler.getSnapFraction(reentryBounds);
 
-        oldPosition.offset(0, -100);
-        mPipBoundsHandler.onSaveReentryBounds(mTestComponentName1, oldPosition);
+        mPipBoundsState.saveReentryState(reentryBounds, reentrySnapFraction);
+        final Rect destinationBounds = mPipBoundsHandler.getDestinationBounds(DEFAULT_ASPECT_RATIO,
+                EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
 
-        final Rect newPosition = mPipBoundsHandler.getDestinationBounds(mTestComponentName2,
-                DEFAULT_ASPECT_RATIO, EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
+        assertEquals(reentryBounds.width(), destinationBounds.width());
+        assertEquals(reentryBounds.height(), destinationBounds.height());
+    }
 
-        assertNonBoundsInclusionWithMargin("ignore saved bounds", oldPosition, newPosition);
+    @Test
+    public void getDestinationBounds_reentryStateExists_restoreLastPosition() {
+        final Rect reentryBounds = mPipBoundsHandler.getDestinationBounds(DEFAULT_ASPECT_RATIO,
+                EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
+        reentryBounds.offset(0, -100);
+        final float reentrySnapFraction = mPipBoundsHandler.getSnapFraction(reentryBounds);
+
+        mPipBoundsState.saveReentryState(reentryBounds, reentrySnapFraction);
+
+        final Rect destinationBounds = mPipBoundsHandler.getDestinationBounds(DEFAULT_ASPECT_RATIO,
+                EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
+
+        assertBoundsInclusionWithMargin("restoreLastPosition", reentryBounds, destinationBounds);
     }
 
     @Test
     public void setShelfHeight_offsetBounds() {
         final int shelfHeight = 100;
-        final Rect oldPosition = mPipBoundsHandler.getDestinationBounds(mTestComponentName1,
-                DEFAULT_ASPECT_RATIO, EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
+        final Rect oldPosition = mPipBoundsHandler.getDestinationBounds(DEFAULT_ASPECT_RATIO,
+                EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
 
         mPipBoundsHandler.setShelfHeight(true, shelfHeight);
-        final Rect newPosition = mPipBoundsHandler.getDestinationBounds(mTestComponentName1,
-                DEFAULT_ASPECT_RATIO, EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
+        final Rect newPosition = mPipBoundsHandler.getDestinationBounds(DEFAULT_ASPECT_RATIO,
+                EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
 
         oldPosition.offset(0, -shelfHeight);
         assertBoundsInclusionWithMargin("offsetBounds by shelf", oldPosition, newPosition);
@@ -242,92 +255,36 @@
     @Test
     public void onImeVisibilityChanged_offsetBounds() {
         final int imeHeight = 100;
-        final Rect oldPosition = mPipBoundsHandler.getDestinationBounds(mTestComponentName1,
-                DEFAULT_ASPECT_RATIO, EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
+        final Rect oldPosition = mPipBoundsHandler.getDestinationBounds(DEFAULT_ASPECT_RATIO,
+                EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
 
         mPipBoundsHandler.onImeVisibilityChanged(true, imeHeight);
-        final Rect newPosition = mPipBoundsHandler.getDestinationBounds(mTestComponentName1,
-                DEFAULT_ASPECT_RATIO, EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
+        final Rect newPosition = mPipBoundsHandler.getDestinationBounds(DEFAULT_ASPECT_RATIO,
+                EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
 
         oldPosition.offset(0, -imeHeight);
         assertBoundsInclusionWithMargin("offsetBounds by IME", oldPosition, newPosition);
     }
 
     @Test
-    public void onSaveReentryBounds_restoreLastPosition() {
-        final Rect oldPosition = mPipBoundsHandler.getDestinationBounds(mTestComponentName1,
-                DEFAULT_ASPECT_RATIO, EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
+    public void getDestinationBounds_noReentryState_useDefaultBounds() {
+        final Rect defaultBounds = mPipBoundsHandler.getDestinationBounds(DEFAULT_ASPECT_RATIO,
+                EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
+        // Save a reentry state away from the default bounds so that clearing it has an
+        // observable effect on the bounds computed below.
+        final Rect reentryBounds = new Rect(defaultBounds);
+        reentryBounds.offset(0, -100);
+        mPipBoundsState.saveReentryState(reentryBounds,
+                mPipBoundsHandler.getSnapFraction(reentryBounds));
 
-        oldPosition.offset(0, -100);
-        mPipBoundsHandler.onSaveReentryBounds(mTestComponentName1, oldPosition);
+        mPipBoundsState.clearReentryState();
 
-        final Rect newPosition = mPipBoundsHandler.getDestinationBounds(mTestComponentName1,
-                DEFAULT_ASPECT_RATIO, EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
-
-        assertBoundsInclusionWithMargin("restoreLastPosition", oldPosition, newPosition);
-    }
-
-    @Test
-    public void onSaveReentryBounds_restoreLastSize() {
-        final Rect oldSize = mPipBoundsHandler.getDestinationBounds(mTestComponentName1,
-                DEFAULT_ASPECT_RATIO, EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
-
-        oldSize.scale(1.25f);
-        mPipBoundsHandler.onSaveReentryBounds(mTestComponentName1, oldSize);
-
-        final Rect newSize = mPipBoundsHandler.getDestinationBounds(mTestComponentName1,
-                DEFAULT_ASPECT_RATIO, EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
-
-        assertEquals(oldSize.width(), newSize.width());
-        assertEquals(oldSize.height(), newSize.height());
-    }
-
-    @Test
-    public void onResetReentryBounds_useDefaultBounds() {
-        final Rect defaultBounds = mPipBoundsHandler.getDestinationBounds(mTestComponentName1,
-                DEFAULT_ASPECT_RATIO, EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
-        final Rect newBounds = new Rect(defaultBounds);
-        newBounds.offset(0, -100);
-        mPipBoundsHandler.onSaveReentryBounds(mTestComponentName1, newBounds);
-
-        mPipBoundsHandler.onResetReentryBounds(mTestComponentName1);
-        final Rect actualBounds = mPipBoundsHandler.getDestinationBounds(mTestComponentName1,
-                DEFAULT_ASPECT_RATIO, EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
+        final Rect actualBounds = mPipBoundsHandler.getDestinationBounds(DEFAULT_ASPECT_RATIO,
+                EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
 
         assertBoundsInclusionWithMargin("useDefaultBounds", defaultBounds, actualBounds);
     }
 
-    @Test
-    public void onResetReentryBounds_componentMismatch_restoreLastPosition() {
-        final Rect defaultBounds = mPipBoundsHandler.getDestinationBounds(mTestComponentName1,
-                DEFAULT_ASPECT_RATIO, EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
-        final Rect newBounds = new Rect(defaultBounds);
-        newBounds.offset(0, -100);
-        mPipBoundsHandler.onSaveReentryBounds(mTestComponentName1, newBounds);
-
-        mPipBoundsHandler.onResetReentryBounds(mTestComponentName2);
-        final Rect actualBounds = mPipBoundsHandler.getDestinationBounds(mTestComponentName1,
-                DEFAULT_ASPECT_RATIO, EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
-
-        assertBoundsInclusionWithMargin("restoreLastPosition", newBounds, actualBounds);
-    }
-
-    @Test
-    public void onSaveReentryBounds_componentMismatch_restoreLastSize() {
-        final Rect oldSize = mPipBoundsHandler.getDestinationBounds(mTestComponentName1,
-                DEFAULT_ASPECT_RATIO, EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
-
-        oldSize.scale(1.25f);
-        mPipBoundsHandler.onSaveReentryBounds(mTestComponentName1, oldSize);
-
-        mPipBoundsHandler.onResetReentryBounds(mTestComponentName2);
-        final Rect newSize = mPipBoundsHandler.getDestinationBounds(mTestComponentName1,
-                DEFAULT_ASPECT_RATIO, EMPTY_CURRENT_BOUNDS, EMPTY_MINIMAL_SIZE);
-
-        assertEquals(oldSize.width(), newSize.width());
-        assertEquals(oldSize.height(), newSize.height());
-    }
-
     private void assertBoundsInclusionWithMargin(String from, Rect expected, Rect actual) {
         final Rect expectedWithMargin = new Rect(expected);
         expectedWithMargin.inset(-ROUNDING_ERROR_MARGIN, -ROUNDING_ERROR_MARGIN);
diff --git a/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/pip/PipBoundsStateTest.java b/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/pip/PipBoundsStateTest.java
new file mode 100644
index 0000000..dc9399e
--- /dev/null
+++ b/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/pip/PipBoundsStateTest.java
@@ -0,0 +1,123 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.wm.shell.pip;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+
+import android.content.ComponentName;
+import android.graphics.Rect;
+import android.testing.AndroidTestingRunner;
+import android.testing.TestableLooper;
+import android.util.Size;
+
+import androidx.test.filters.SmallTest;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+/**
+ * Tests for {@link PipBoundsState}.
+ */
+@RunWith(AndroidTestingRunner.class)
+@TestableLooper.RunWithLooper
+@SmallTest
+public class PipBoundsStateTest extends PipTestCase {
+
+    private static final Rect DEFAULT_BOUNDS = new Rect(0, 0, 10, 10);
+    private static final float DEFAULT_SNAP_FRACTION = 1.0f;
+
+    private PipBoundsState mPipBoundsState;
+    private ComponentName mTestComponentName1;
+    private ComponentName mTestComponentName2;
+
+    @Before
+    public void setUp() {
+        mPipBoundsState = new PipBoundsState();
+        mTestComponentName1 = new ComponentName(mContext, "component1");
+        mTestComponentName2 = new ComponentName(mContext, "component2");
+    }
+
+    @Test
+    public void testSetBounds() {
+        final Rect bounds = new Rect(0, 0, 100, 100);
+        mPipBoundsState.setBounds(bounds);
+
+        assertEquals(bounds, mPipBoundsState.getBounds());
+    }
+
+    @Test
+    public void testSetReentryState() {
+        final Rect bounds = new Rect(0, 0, 100, 100);
+        final float snapFraction = 0.5f;
+
+        mPipBoundsState.saveReentryState(bounds, snapFraction);
+
+        final PipBoundsState.PipReentryState state = mPipBoundsState.getReentryState();
+        assertNotNull(state);
+        assertEquals(new Size(100, 100), state.getSize());
+        assertEquals(snapFraction, state.getSnapFraction(), 0.01);
+    }
+
+    @Test
+    public void testClearReentryState() {
+        final Rect bounds = new Rect(0, 0, 100, 100);
+        final float snapFraction = 0.5f;
+
+        mPipBoundsState.saveReentryState(bounds, snapFraction);
+        mPipBoundsState.clearReentryState();
+
+        assertNull(mPipBoundsState.getReentryState());
+    }
+
+    @Test
+    public void testSetLastPipComponentName_notChanged_doesNotClearReentryState() {
+        mPipBoundsState.setLastPipComponentName(mTestComponentName1);
+        mPipBoundsState.saveReentryState(DEFAULT_BOUNDS, DEFAULT_SNAP_FRACTION);
+
+        mPipBoundsState.setLastPipComponentName(mTestComponentName1);
+
+        final PipBoundsState.PipReentryState state = mPipBoundsState.getReentryState();
+        assertNotNull(state);
+        assertEquals(new Size(DEFAULT_BOUNDS.width(), DEFAULT_BOUNDS.height()), state.getSize());
+        assertEquals(DEFAULT_SNAP_FRACTION, state.getSnapFraction(), 0.01);
+    }
+
+    @Test
+    public void testSetLastPipComponentName_changed_clearReentryState() {
+        mPipBoundsState.setLastPipComponentName(mTestComponentName1);
+        mPipBoundsState.saveReentryState(DEFAULT_BOUNDS, DEFAULT_SNAP_FRACTION);
+
+        mPipBoundsState.setLastPipComponentName(mTestComponentName2);
+
+        assertNull(mPipBoundsState.getReentryState());
+    }
+
+    @Test
+    public void testSetLastPipComponentName_null_clearReentryState() {
+        mPipBoundsState.setLastPipComponentName(mTestComponentName1);
+        mPipBoundsState.saveReentryState(DEFAULT_BOUNDS, DEFAULT_SNAP_FRACTION);
+
+        // Mirrors PipController resetting the last component when its activity is hidden.
+        mPipBoundsState.setLastPipComponentName(null);
+
+        assertNull(mPipBoundsState.getLastPipComponentName());
+        assertNull(mPipBoundsState.getReentryState());
+    }
+}
diff --git a/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/pip/phone/PipTaskOrganizerTest.java b/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/pip/PipTaskOrganizerTest.java
similarity index 88%
rename from libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/pip/phone/PipTaskOrganizerTest.java
rename to libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/pip/PipTaskOrganizerTest.java
index 46ebbf3..2f7faaf 100644
--- a/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/pip/phone/PipTaskOrganizerTest.java
+++ b/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/pip/PipTaskOrganizerTest.java
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-package com.android.wm.shell.pip.phone;
+package com.android.wm.shell.pip;
 
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
@@ -27,12 +27,6 @@
 
 import com.android.wm.shell.ShellTaskOrganizer;
 import com.android.wm.shell.common.DisplayController;
-import com.android.wm.shell.pip.PipBoundsHandler;
-import com.android.wm.shell.pip.PipBoundsState;
-import com.android.wm.shell.pip.PipSurfaceTransactionHelper;
-import com.android.wm.shell.pip.PipTaskOrganizer;
-import com.android.wm.shell.pip.PipTestCase;
-import com.android.wm.shell.pip.PipUiEventLogger;
 import com.android.wm.shell.splitscreen.SplitScreen;
 
 import org.junit.Before;
diff --git a/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/pip/phone/PipTouchHandlerTest.java b/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/pip/phone/PipTouchHandlerTest.java
index 4713142..3f60cc0 100644
--- a/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/pip/phone/PipTouchHandlerTest.java
+++ b/libs/WindowManager/Shell/tests/unittest/src/com/android/wm/shell/pip/phone/PipTouchHandlerTest.java
@@ -93,7 +93,7 @@
     public void setUp() throws Exception {
         MockitoAnnotations.initMocks(this);
         mPipBoundsState = new PipBoundsState();
-        mPipBoundsHandler = new PipBoundsHandler(mContext);
+        mPipBoundsHandler = new PipBoundsHandler(mContext, mPipBoundsState);
         mPipSnapAlgorithm = mPipBoundsHandler.getSnapAlgorithm();
         mPipSnapAlgorithm = new PipSnapAlgorithm(mContext);
         mPipTouchHandler = new PipTouchHandler(mContext, mPipMenuActivityController,
diff --git a/packages/SystemUI/src/com/android/systemui/wmshell/TvPipModule.java b/packages/SystemUI/src/com/android/systemui/wmshell/TvPipModule.java
index e14af23..5310b3f 100644
--- a/packages/SystemUI/src/com/android/systemui/wmshell/TvPipModule.java
+++ b/packages/SystemUI/src/com/android/systemui/wmshell/TvPipModule.java
@@ -85,8 +85,9 @@
 
     @WMSingleton
     @Provides
-    static PipBoundsHandler providePipBoundsHandler(Context context) {
-        return new PipBoundsHandler(context);
+    static PipBoundsHandler providePipBoundsHandler(
+            Context context, PipBoundsState pipBoundsState) {
+        return new PipBoundsHandler(context, pipBoundsState);
     }
 
     @WMSingleton
diff --git a/packages/SystemUI/src/com/android/systemui/wmshell/WMShellModule.java b/packages/SystemUI/src/com/android/systemui/wmshell/WMShellModule.java
index 70632b2..a6fe728 100644
--- a/packages/SystemUI/src/com/android/systemui/wmshell/WMShellModule.java
+++ b/packages/SystemUI/src/com/android/systemui/wmshell/WMShellModule.java
@@ -97,8 +97,9 @@
 
     @WMSingleton
     @Provides
-    static PipBoundsHandler providePipBoundsHandler(Context context) {
-        return new PipBoundsHandler(context);
+    static PipBoundsHandler providePipBoundsHandler(Context context,
+            PipBoundsState pipBoundsState) {
+        return new PipBoundsHandler(context, pipBoundsState);
     }
 
     @WMSingleton