Merge "Snap to current page's relative position after dismissal." into sc-v2-dev
diff --git a/quickstep/src/com/android/quickstep/views/RecentsView.java b/quickstep/src/com/android/quickstep/views/RecentsView.java
index 853a3aa..618d9aa 100644
--- a/quickstep/src/com/android/quickstep/views/RecentsView.java
+++ b/quickstep/src/com/android/quickstep/views/RecentsView.java
@@ -2807,14 +2807,49 @@
                     resetTaskVisuals();
 
                     int pageToSnapTo = mCurrentPage;
-                    if ((dismissedIndex < pageToSnapTo && !showAsGrid)
-                            || pageToSnapTo == taskCount - 1) {
-                        pageToSnapTo -= 1;
-                    }
+                    int taskViewIdToSnapTo = -1;
                     if (showAsGrid) {
+                        // Get the id of the task view we will snap to based on the current
+                        // page's relative position as the order of indices change over time due
+                        // to dismissals.
+                        TaskView snappedTaskView = getTaskViewAtByAbsoluteIndex(mCurrentPage);
+                        if (snappedTaskView != null) {
+                            if (snappedTaskView.getTaskViewId() == mFocusedTaskViewId) {
+                                if (finalNextFocusedTaskView != null) {
+                                    taskViewIdToSnapTo = finalNextFocusedTaskView.getTaskViewId();
+                                } else {
+                                    taskViewIdToSnapTo = mFocusedTaskViewId;
+                                }
+                            } else {
+                                int snappedTaskViewId = snappedTaskView.getTaskViewId();
+                                boolean isSnappedTaskInTopRow = mTopRowIdSet.contains(
+                                        snappedTaskViewId);
+                                IntArray taskViewIdArray =
+                                        isSnappedTaskInTopRow ? getTopRowIdArray()
+                                                : getBottomRowIdArray();
+                                int snappedIndex = taskViewIdArray.indexOf(snappedTaskViewId);
+                                taskViewIdArray.removeValue(dismissedTaskViewId);
+                                if (snappedIndex < taskViewIdArray.size()) {
+                                    taskViewIdToSnapTo = taskViewIdArray.get(snappedIndex);
+                                } else if (snappedIndex == taskViewIdArray.size()) {
+                                    // If the snapped task is the last item from the dismissed row,
+                                    // snap to the same column in the other grid row
+                                    IntArray inverseRowTaskViewIdArray =
+                                            isSnappedTaskInTopRow ? getBottomRowIdArray()
+                                                    : getTopRowIdArray();
+                                    if (snappedIndex < inverseRowTaskViewIdArray.size()) {
+                                        taskViewIdToSnapTo = inverseRowTaskViewIdArray.get(
+                                                snappedIndex);
+                                    }
+                                }
+                            }
+                        }
+
                         int primaryScroll = mOrientationHandler.getPrimaryScroll(RecentsView.this);
                         int currentPageScroll = getScrollForPage(pageToSnapTo);
                         mCurrentPageScrollDiff = primaryScroll - currentPageScroll;
+                    } else if (dismissedIndex < pageToSnapTo || pageToSnapTo == taskCount - 1) {
+                        pageToSnapTo -= 1;
                     }
                     removeViewInLayout(dismissedTaskView);
                     mTopRowIdSet.remove(dismissedTaskViewId);
@@ -2834,40 +2869,50 @@
                         // Update scroll and snap to page.
                         updateScrollSynchronously();
 
-                        int highestVisibleTaskIndex = getHighestVisibleTaskIndex();
-                        if (highestVisibleTaskIndex < Integer.MAX_VALUE) {
-                            TaskView taskView = getTaskViewAt(highestVisibleTaskIndex);
+                        if (showAsGrid) {
+                            // Rebalance tasks in the grid
+                            int highestVisibleTaskIndex = getHighestVisibleTaskIndex();
+                            if (highestVisibleTaskIndex < Integer.MAX_VALUE) {
+                                TaskView taskView = getTaskViewAt(highestVisibleTaskIndex);
 
-                            boolean shouldRebalance = false;
-                            int screenStart = mOrientationHandler.getPrimaryScroll(
-                                    RecentsView.this);
-                            int taskStart = mOrientationHandler.getChildStart(taskView)
-                                    + (int) taskView.getOffsetAdjustment(
-                                    /*fullscreenEnabled=*/ false,
-                                    /*gridEnabled=*/ true);
-
-                            // Rebalance only if there is a maximum gap between the task and the
-                            // screen's edge; this ensures that rebalanced tasks are outside the
-                            // visible screen.
-                            if (mIsRtl) {
-                                shouldRebalance = taskStart <= screenStart + mPageSpacing;
-                            } else {
-                                int screenEnd = screenStart + mOrientationHandler.getMeasuredSize(
+                                boolean shouldRebalance = false;
+                                int screenStart = mOrientationHandler.getPrimaryScroll(
                                         RecentsView.this);
-                                int taskSize = (int) (mOrientationHandler.getMeasuredSize(taskView)
-                                        * taskView.getSizeAdjustment(/*fullscreenEnabled=*/ false));
-                                int taskEnd = taskStart + taskSize;
+                                int taskStart = mOrientationHandler.getChildStart(taskView)
+                                        + (int) taskView.getOffsetAdjustment(/*fullscreenEnabled=*/
+                                        false, /*gridEnabled=*/ true);
 
-                                shouldRebalance = taskEnd >= screenEnd - mPageSpacing;
+                                // Rebalance only if there is a maximum gap between the task and the
+                                // screen's edge; this ensures that rebalanced tasks are outside the
+                                // visible screen.
+                                if (mIsRtl) {
+                                    shouldRebalance = taskStart <= screenStart + mPageSpacing;
+                                } else {
+                                    int screenEnd =
+                                            screenStart + mOrientationHandler.getMeasuredSize(
+                                                    RecentsView.this);
+                                    int taskSize = (int) (mOrientationHandler.getMeasuredSize(
+                                            taskView) * taskView
+                                            .getSizeAdjustment(/*fullscreenEnabled=*/false));
+                                    int taskEnd = taskStart + taskSize;
+
+                                    shouldRebalance = taskEnd >= screenEnd - mPageSpacing;
+                                }
+
+                                if (shouldRebalance) {
+                                    updateGridProperties(/*isTaskDismissal=*/ true,
+                                            highestVisibleTaskIndex);
+                                    updateScrollSynchronously();
+                                }
                             }
 
-                            if (shouldRebalance) {
-                                updateGridProperties(/*isTaskDismissal=*/ true,
-                                        highestVisibleTaskIndex);
-                                updateScrollSynchronously();
+                            // If snapping to another page due to indices rearranging, find the new
+                            // index after dismissal & rearrange using the task view id.
+                            if (taskViewIdToSnapTo != -1) {
+                                pageToSnapTo = indexOfChild(
+                                        getTaskViewFromTaskViewId(taskViewIdToSnapTo));
                             }
                         }
-
                         setCurrentPage(pageToSnapTo);
                         dispatchScrollChanged();
                     }
@@ -2880,10 +2925,32 @@
     }
 
     /**
+     * Returns the ids of the top-row task views in task view index order, without the focused task
+     */
+    private IntArray getTopRowIdArray() {
+        IntArray topArray = new IntArray(mTopRowIdSet.size());
+        if (mTopRowIdSet.isEmpty()) {
+            // No top-row ids to look for; skip scanning the task views entirely.
+            return topArray;
+        }
+        for (int index = 0, count = getTaskViewCount(); index < count; index++) {
+            int candidateId = getTaskViewAt(index).getTaskViewId();
+            if (mTopRowIdSet.contains(candidateId)) {
+                topArray.add(candidateId);
+            }
+        }
+        return topArray;
+    }
+
+    /**
      * Returns all the tasks in the bottom row, without the focused task
      */
     private IntArray getBottomRowIdArray() {
-        IntArray bottomArray = new IntArray();
+        int bottomRowIdArraySize = getTaskViewCount() - mTopRowIdSet.size() - 1;
+        if (bottomRowIdArraySize <= 0) {
+            return new IntArray(0);
+        }
+        IntArray bottomArray = new IntArray(bottomRowIdArraySize);
         int taskViewCount = getTaskViewCount();
         for (int i = 0; i < taskViewCount; i++) {
             int taskViewId = getTaskViewAt(i).getTaskViewId();
@@ -2904,7 +2971,7 @@
         if (mTopRowIdSet.isEmpty()) return Integer.MAX_VALUE; // return earlier
 
         int lastVisibleIndex = Integer.MAX_VALUE;
-        IntArray topRowIdArray = mTopRowIdSet.getArray();
+        IntArray topRowIdArray = getTopRowIdArray();
         IntArray bottomRowIdArray = getBottomRowIdArray();
         int balancedColumns = Math.min(bottomRowIdArray.size(), topRowIdArray.size());