Merge "Moving LauncherIcons to dagger" into main
diff --git a/quickstep/src/com/android/launcher3/model/WellbeingModel.java b/quickstep/src/com/android/launcher3/model/WellbeingModel.java
index 0f3aaa6..09433c5 100644
--- a/quickstep/src/com/android/launcher3/model/WellbeingModel.java
+++ b/quickstep/src/com/android/launcher3/model/WellbeingModel.java
@@ -109,9 +109,9 @@
                 ? Executors.UI_HELPER_EXECUTOR.getLooper()
                 : Executors.getPackageExecutor(mWellbeingProviderPkg).getLooper());
         mWellbeingAppChangeReceiver =
-                new SimpleBroadcastReceiver(mWorkerHandler, t -> restartObserver());
+                new SimpleBroadcastReceiver(context, mWorkerHandler, t -> restartObserver());
         mAppAddRemoveReceiver =
-                new SimpleBroadcastReceiver(mWorkerHandler, this::onAppPackageChanged);
+                new SimpleBroadcastReceiver(context, mWorkerHandler, this::onAppPackageChanged);
 
 
         mContentObserver = new ContentObserver(mWorkerHandler) {
@@ -148,8 +148,8 @@
     public void close() {
         if (!TextUtils.isEmpty(mWellbeingProviderPkg)) {
             mWorkerHandler.post(() -> {
-                mWellbeingAppChangeReceiver.unregisterReceiverSafely(mContext);
-                mAppAddRemoveReceiver.unregisterReceiverSafely(mContext);
+                mWellbeingAppChangeReceiver.unregisterReceiverSafely();
+                mAppAddRemoveReceiver.unregisterReceiverSafely();
                 mContext.getContentResolver().unregisterContentObserver(mContentObserver);
             });
         }
diff --git a/quickstep/src/com/android/launcher3/taskbar/TaskbarManager.java b/quickstep/src/com/android/launcher3/taskbar/TaskbarManager.java
index df3869e..96bcffd 100644
--- a/quickstep/src/com/android/launcher3/taskbar/TaskbarManager.java
+++ b/quickstep/src/com/android/launcher3/taskbar/TaskbarManager.java
@@ -124,8 +124,7 @@
     private final TaskbarNavButtonController mDefaultNavButtonController;
     private final ComponentCallbacks mDefaultComponentCallbacks;
 
-    private final SimpleBroadcastReceiver mShutdownReceiver =
-            new SimpleBroadcastReceiver(UI_HELPER_EXECUTOR, i -> destroyAllTaskbars());
+    private final SimpleBroadcastReceiver mShutdownReceiver;
 
     // The source for this provider is set when Launcher is available
     // We use 'non-destroyable' version here so the original provider won't be destroyed
@@ -183,8 +182,7 @@
 
     private boolean mUserUnlocked = false;
 
-    private final SimpleBroadcastReceiver mTaskbarBroadcastReceiver =
-            new SimpleBroadcastReceiver(UI_HELPER_EXECUTOR, this::showTaskbarFromBroadcast);
+    private final SimpleBroadcastReceiver mTaskbarBroadcastReceiver;
 
     private final AllAppsActionManager mAllAppsActionManager;
     private final RecentsDisplayModel mRecentsDisplayModel;
@@ -266,7 +264,13 @@
                 .register(NAV_BAR_KIDS_MODE, mOnSettingsChangeListener);
         Log.d(TASKBAR_NOT_DESTROYED_TAG, "registering component callbacks from constructor.");
         mPrimaryWindowContext.registerComponentCallbacks(mDefaultComponentCallbacks);
-        mShutdownReceiver.register(mPrimaryWindowContext, Intent.ACTION_SHUTDOWN);
+        mShutdownReceiver =
+                new SimpleBroadcastReceiver(
+                        mPrimaryWindowContext, UI_HELPER_EXECUTOR, i -> destroyAllTaskbars());
+        mTaskbarBroadcastReceiver =
+                new SimpleBroadcastReceiver(mPrimaryWindowContext,
+                        UI_HELPER_EXECUTOR, this::showTaskbarFromBroadcast);
+        mShutdownReceiver.register(Intent.ACTION_SHUTDOWN);
         UI_HELPER_EXECUTOR.execute(() -> {
             mSharedState.taskbarSystemActionPendingIntent = PendingIntent.getBroadcast(
                     mPrimaryWindowContext,
@@ -274,8 +278,7 @@
                     new Intent(ACTION_SHOW_TASKBAR).setPackage(
                             mPrimaryWindowContext.getPackageName()),
                     PendingIntent.FLAG_UPDATE_CURRENT | PendingIntent.FLAG_IMMUTABLE);
-            mTaskbarBroadcastReceiver.register(
-                    mPrimaryWindowContext, RECEIVER_NOT_EXPORTED, ACTION_SHOW_TASKBAR);
+            mTaskbarBroadcastReceiver.register(RECEIVER_NOT_EXPORTED, ACTION_SHOW_TASKBAR);
         });
 
         debugWhyTaskbarNotDestroyed("TaskbarManager created");
@@ -795,7 +798,7 @@
         mRecentsViewContainer = null;
         debugWhyTaskbarNotDestroyed("TaskbarManager#destroy()");
         removeActivityCallbacksAndListeners();
-        mTaskbarBroadcastReceiver.unregisterReceiverSafely(mPrimaryWindowContext);
+        mTaskbarBroadcastReceiver.unregisterReceiverSafely();
 
         if (mUserUnlocked) {
             DisplayController.INSTANCE.get(mPrimaryWindowContext).removeChangeListener(
@@ -807,7 +810,7 @@
                 .unregister(NAV_BAR_KIDS_MODE, mOnSettingsChangeListener);
         Log.d(TASKBAR_NOT_DESTROYED_TAG, "unregistering component callbacks from destroy().");
         mPrimaryWindowContext.unregisterComponentCallbacks(mDefaultComponentCallbacks);
-        mShutdownReceiver.unregisterReceiverSafely(mPrimaryWindowContext);
+        mShutdownReceiver.unregisterReceiverSafely();
         destroyAllTaskbars();
     }
 
diff --git a/quickstep/src/com/android/launcher3/taskbar/TaskbarUIController.java b/quickstep/src/com/android/launcher3/taskbar/TaskbarUIController.java
index e5d642d..89bcb41 100644
--- a/quickstep/src/com/android/launcher3/taskbar/TaskbarUIController.java
+++ b/quickstep/src/com/android/launcher3/taskbar/TaskbarUIController.java
@@ -283,7 +283,7 @@
                                     foundTask,
                                     taskContainer.getIconView().getDrawable(),
                                     taskContainer.getSnapshotView(),
-                                    taskContainer.getSplitAnimationThumbnail(),
+                                    taskContainer.getThumbnail(),
                                     null /* intent */,
                                     null /* user */,
                                     info);
diff --git a/quickstep/src/com/android/launcher3/uioverrides/touchcontrollers/TaskViewDismissTouchController.kt b/quickstep/src/com/android/launcher3/uioverrides/touchcontrollers/TaskViewDismissTouchController.kt
index 77a05c1..98737a5 100644
--- a/quickstep/src/com/android/launcher3/uioverrides/touchcontrollers/TaskViewDismissTouchController.kt
+++ b/quickstep/src/com/android/launcher3/uioverrides/touchcontrollers/TaskViewDismissTouchController.kt
@@ -26,10 +26,12 @@
 import com.android.launcher3.Utilities.isRtl
 import com.android.launcher3.Utilities.mapToRange
 import com.android.launcher3.touch.SingleAxisSwipeDetector
+import com.android.launcher3.util.MSDLPlayerWrapper
 import com.android.launcher3.util.TouchController
 import com.android.quickstep.views.RecentsView
 import com.android.quickstep.views.RecentsViewContainer
 import com.android.quickstep.views.TaskView
+import com.google.android.msdl.data.model.MSDLToken
 import kotlin.math.abs
 import kotlin.math.sign
 
@@ -53,6 +55,7 @@
     private var springAnimation: SpringAnimation? = null
     private var dismissLength: Int = 0
     private var verticalFactor: Int = 0
+    private var hasDismissThresholdHapticRun = false
     private var initialDisplacement: Float = 0f
 
     private fun canInterceptTouch(ev: MotionEvent): Boolean =
@@ -159,9 +162,30 @@
             }
             recentsView.redrawLiveTile()
         }
+        playDismissThresholdHaptic(displacement)
         return true
     }
 
+    /**
+     * Play a haptic to alert the user they have passed the dismiss threshold.
+     *
+     * <p>Check within a range of the threshold value, as the drag event does not necessarily happen
+     * at the exact threshold's displacement.
+     */
+    private fun playDismissThresholdHaptic(displacement: Float) {
+        val dismissThreshold = (DISMISS_THRESHOLD_FRACTION * dismissLength * verticalFactor)
+        val inHapticRange =
+            displacement >= (dismissThreshold - DISMISS_THRESHOLD_HAPTIC_RANGE) &&
+                displacement <= (dismissThreshold + DISMISS_THRESHOLD_HAPTIC_RANGE)
+        if (!inHapticRange) {
+            hasDismissThresholdHapticRun = false
+        } else if (!hasDismissThresholdHapticRun) {
+            MSDLPlayerWrapper.INSTANCE.get(recentsView.context)
+                .playToken(MSDLToken.SWIPE_THRESHOLD_INDICATOR)
+            hasDismissThresholdHapticRun = true
+        }
+    }
+
     override fun onDragEnd(velocity: Float) {
         val taskBeingDragged = taskBeingDragged ?: return
 
@@ -208,5 +232,6 @@
 
     companion object {
         private const val DISMISS_THRESHOLD_FRACTION = 0.5f
+        private const val DISMISS_THRESHOLD_HAPTIC_RANGE = 10f
     }
 }
diff --git a/quickstep/src/com/android/quickstep/OverviewComponentObserver.java b/quickstep/src/com/android/quickstep/OverviewComponentObserver.java
index 1f95c41..bc3de41 100644
--- a/quickstep/src/com/android/quickstep/OverviewComponentObserver.java
+++ b/quickstep/src/com/android/quickstep/OverviewComponentObserver.java
@@ -72,12 +72,9 @@
             new DaggerSingletonObject<>(LauncherAppComponent::getOverviewComponentObserver);
 
     // We register broadcast receivers on main thread to avoid missing updates.
-    private final SimpleBroadcastReceiver mUserPreferenceChangeReceiver =
-            new SimpleBroadcastReceiver(MAIN_EXECUTOR, this::updateOverviewTargets);
-    private final SimpleBroadcastReceiver mOtherHomeAppUpdateReceiver =
-            new SimpleBroadcastReceiver(MAIN_EXECUTOR, this::updateOverviewTargets);
+    private final SimpleBroadcastReceiver mUserPreferenceChangeReceiver;
+    private final SimpleBroadcastReceiver mOtherHomeAppUpdateReceiver;
 
-    private final Context mContext;
     private final RecentsDisplayModel mRecentsDisplayModel;
 
     private final Intent mCurrentHomeIntent;
@@ -101,10 +98,13 @@
             @ApplicationContext Context context,
             RecentsDisplayModel recentsDisplayModel,
             DaggerSingletonTracker lifecycleTracker) {
-        mContext = context;
+        mUserPreferenceChangeReceiver =
+                new SimpleBroadcastReceiver(context, MAIN_EXECUTOR, this::updateOverviewTargets);
+        mOtherHomeAppUpdateReceiver =
+                new SimpleBroadcastReceiver(context, MAIN_EXECUTOR, this::updateOverviewTargets);
         mRecentsDisplayModel = recentsDisplayModel;
         mCurrentHomeIntent = createHomeIntent();
-        mMyHomeIntent = new Intent(mCurrentHomeIntent).setPackage(mContext.getPackageName());
+        mMyHomeIntent = new Intent(mCurrentHomeIntent).setPackage(context.getPackageName());
         ResolveInfo info = context.getPackageManager().resolveActivity(mMyHomeIntent, 0);
         ComponentName myHomeComponent =
                 new ComponentName(context.getPackageName(), info.activityInfo.name);
@@ -112,7 +112,7 @@
         mConfigChangesMap.append(myHomeComponent.hashCode(), info.activityInfo.configChanges);
         mSetupWizardPkg = context.getString(R.string.setup_wizard_pkg);
 
-        ComponentName fallbackComponent = new ComponentName(mContext, RecentsActivity.class);
+        ComponentName fallbackComponent = new ComponentName(context, RecentsActivity.class);
         mFallbackIntent = new Intent(Intent.ACTION_MAIN)
                 .addCategory(Intent.CATEGORY_DEFAULT)
                 .setComponent(fallbackComponent)
@@ -124,7 +124,7 @@
             mConfigChangesMap.append(fallbackComponent.hashCode(), fallbackInfo.configChanges);
         } catch (PackageManager.NameNotFoundException ignored) { /* Impossible */ }
 
-        mUserPreferenceChangeReceiver.register(mContext, ACTION_PREFERRED_ACTIVITY_CHANGED);
+        mUserPreferenceChangeReceiver.register(ACTION_PREFERRED_ACTIVITY_CHANGED);
         updateOverviewTargets();
 
         lifecycleTracker.addCloseable(this::onDestroy);
@@ -224,7 +224,7 @@
 
                 mUpdateRegisteredPackage = defaultHome.getPackageName();
                 mOtherHomeAppUpdateReceiver.registerPkgActions(
-                        mContext, mUpdateRegisteredPackage, ACTION_PACKAGE_ADDED,
+                        mUpdateRegisteredPackage, ACTION_PACKAGE_ADDED,
                         ACTION_PACKAGE_CHANGED, ACTION_PACKAGE_REMOVED);
             }
         }
@@ -235,13 +235,13 @@
      * Clean up any registered receivers.
      */
     private void onDestroy() {
-        mUserPreferenceChangeReceiver.unregisterReceiverSafely(mContext);
+        mUserPreferenceChangeReceiver.unregisterReceiverSafely();
         unregisterOtherHomeAppUpdateReceiver();
     }
 
     private void unregisterOtherHomeAppUpdateReceiver() {
         if (mUpdateRegisteredPackage != null) {
-            mOtherHomeAppUpdateReceiver.unregisterReceiverSafely(mContext);
+            mOtherHomeAppUpdateReceiver.unregisterReceiverSafely();
             mUpdateRegisteredPackage = null;
         }
     }
diff --git a/quickstep/src/com/android/quickstep/TaskShortcutFactory.java b/quickstep/src/com/android/quickstep/TaskShortcutFactory.java
index 7990aae..a69d472 100644
--- a/quickstep/src/com/android/quickstep/TaskShortcutFactory.java
+++ b/quickstep/src/com/android/quickstep/TaskShortcutFactory.java
@@ -236,15 +236,14 @@
                         position[0] + width, position[1] + height);
 
                 // Take the thumbnail of the task without a scrim and apply it back after
-                // TODO(b/348643341) add ability to get override the scrim for this Bitmap retrieval
-                float alpha = 0f;
-                if (!enableRefactorTaskThumbnail()) {
-                    alpha = mTaskContainer.getThumbnailViewDeprecated().getDimAlpha();
+                Bitmap thumbnail;
+                if (enableRefactorTaskThumbnail()) {
+                    thumbnail = mTaskContainer.getThumbnail();
+                } else {
+                    float alpha = mTaskContainer.getThumbnailViewDeprecated().getDimAlpha();
                     mTaskContainer.getThumbnailViewDeprecated().setDimAlpha(0);
-                }
-                Bitmap thumbnail = RecentsTransition.drawViewIntoHardwareBitmap(
-                        taskBounds.width(), taskBounds.height(), snapShotView, 1f, Color.BLACK);
-                if (!enableRefactorTaskThumbnail()) {
+                    thumbnail = RecentsTransition.drawViewIntoHardwareBitmap(
+                            taskBounds.width(), taskBounds.height(), snapShotView, 1f, Color.BLACK);
                     mTaskContainer.getThumbnailViewDeprecated().setDimAlpha(alpha);
                 }
 
diff --git a/quickstep/src/com/android/quickstep/task/thumbnail/TaskThumbnailView.kt b/quickstep/src/com/android/quickstep/task/thumbnail/TaskThumbnailView.kt
index 63e93ba..87aace8 100644
--- a/quickstep/src/com/android/quickstep/task/thumbnail/TaskThumbnailView.kt
+++ b/quickstep/src/com/android/quickstep/task/thumbnail/TaskThumbnailView.kt
@@ -74,6 +74,16 @@
     private var taskThumbnailViewHeader: TaskThumbnailViewHeader? = null
 
     private var uiState: TaskThumbnailUiState = Uninitialized
+
+    /**
+     * Sets the outline bounds of the view. When null, the view's bounds are used as the outline.
+     */
+    var outlineBounds: Rect? = null
+        set(value) {
+            field = value
+            invalidateOutline()
+        }
+
     private val bounds = Rect()
 
     var cornerRadius: Float = 0f
@@ -109,7 +119,7 @@
         outlineProvider =
             object : ViewOutlineProvider() {
                 override fun getOutline(view: View, outline: Outline) {
-                    outline.setRoundRect(bounds, cornerRadius)
+                    outline.setRoundRect(outlineBounds ?: bounds, cornerRadius)
                 }
             }
     }
@@ -124,6 +134,7 @@
 
     override fun onRecycle() {
         uiState = Uninitialized
+        outlineBounds = null
         resetViews()
     }
 
diff --git a/quickstep/src/com/android/quickstep/util/AsyncClockEventDelegate.java b/quickstep/src/com/android/quickstep/util/AsyncClockEventDelegate.java
index 54f6443..207e482 100644
--- a/quickstep/src/com/android/quickstep/util/AsyncClockEventDelegate.java
+++ b/quickstep/src/com/android/quickstep/util/AsyncClockEventDelegate.java
@@ -59,8 +59,7 @@
 
     private final Context mContext;
     private final SettingsCache mSettingsCache;
-    private final SimpleBroadcastReceiver mReceiver =
-            new SimpleBroadcastReceiver(UI_HELPER_EXECUTOR, this::onClockEventReceived);
+    private final SimpleBroadcastReceiver mReceiver;
 
     private final ArrayMap<BroadcastReceiver, Handler> mTimeEventReceivers = new ArrayMap<>();
     private final List<ContentObserver> mFormatObservers = new ArrayList<>();
@@ -76,7 +75,9 @@
         super(context);
         mContext = context;
         mSettingsCache = settingsCache;
-        mReceiver.register(mContext, ACTION_TIME_CHANGED, ACTION_TIMEZONE_CHANGED);
+        mReceiver = new SimpleBroadcastReceiver(
+                context, UI_HELPER_EXECUTOR, this::onClockEventReceived);
+        mReceiver.register(ACTION_TIME_CHANGED, ACTION_TIMEZONE_CHANGED);
         tracker.addCloseable(this);
     }
 
@@ -138,6 +139,6 @@
     public void close() {
         mDestroyed = true;
         mSettingsCache.unregister(mFormatUri, this);
-        mReceiver.unregisterReceiverSafely(mContext);
+        mReceiver.unregisterReceiverSafely();
     }
 }
diff --git a/quickstep/src/com/android/quickstep/util/ContextualSearchStateManager.java b/quickstep/src/com/android/quickstep/util/ContextualSearchStateManager.java
index ed96399..a8d3c6d 100644
--- a/quickstep/src/com/android/quickstep/util/ContextualSearchStateManager.java
+++ b/quickstep/src/com/android/quickstep/util/ContextualSearchStateManager.java
@@ -74,8 +74,7 @@
             Settings.Secure.getUriFor(Settings.Secure.SEARCH_ALL_ENTRYPOINTS_ENABLED);
 
     private final Runnable mSysUiStateChangeListener = this::updateOverridesToSysUi;
-    private final SimpleBroadcastReceiver mContextualSearchPackageReceiver =
-            new SimpleBroadcastReceiver(UI_HELPER_EXECUTOR, (unused) -> requestUpdateProperties());
+    private final SimpleBroadcastReceiver mContextualSearchPackageReceiver;
     protected final EventLogArray mEventLogArray = new EventLogArray(TAG, MAX_DEBUG_EVENT_SIZE);
 
     // Cached value whether the ContextualSearch intent filter matched any enabled components.
@@ -95,6 +94,9 @@
             TopTaskTracker topTaskTracker,
             DaggerSingletonTracker lifeCycle) {
         mContext = context;
+        mContextualSearchPackageReceiver =
+                new SimpleBroadcastReceiver(context, UI_HELPER_EXECUTOR,
+                        (unused) -> requestUpdateProperties());
         mContextualSearchPackage = mContext.getResources().getString(
                 com.android.internal.R.string.config_defaultContextualSearchPackageName);
         mSystemUiProxy = systemUiProxy;
@@ -112,7 +114,7 @@
         requestUpdateProperties();
         registerSearchScreenSystemAction();
         mContextualSearchPackageReceiver.registerPkgActions(
-                context, mContextualSearchPackage, Intent.ACTION_PACKAGE_ADDED,
+                mContextualSearchPackage, Intent.ACTION_PACKAGE_ADDED,
                 Intent.ACTION_PACKAGE_CHANGED, Intent.ACTION_PACKAGE_REMOVED);
 
         SettingsCache.OnChangeListener settingChangedListener =
@@ -124,7 +126,7 @@
         systemUiProxy.addOnStateChangeListener(mSysUiStateChangeListener);
 
         lifeCycle.addCloseable(() -> {
-            mContextualSearchPackageReceiver.unregisterReceiverSafely(mContext);
+            mContextualSearchPackageReceiver.unregisterReceiverSafely();
             unregisterSearchScreenSystemAction();
             settingsCache.unregister(SEARCH_ALL_ENTRYPOINTS_ENABLED_URI, settingChangedListener);
             systemUiProxy.removeOnStateChangeListener(mSysUiStateChangeListener);
diff --git a/quickstep/src/com/android/quickstep/util/SplitAnimationController.kt b/quickstep/src/com/android/quickstep/util/SplitAnimationController.kt
index 88d14b7..a1963c9 100644
--- a/quickstep/src/com/android/quickstep/util/SplitAnimationController.kt
+++ b/quickstep/src/com/android/quickstep/util/SplitAnimationController.kt
@@ -127,7 +127,7 @@
                     val drawable = getDrawable(container.iconView, splitSelectSource)
                     return SplitAnimInitProps(
                         container.snapshotView,
-                        container.splitAnimationThumbnail,
+                        container.thumbnail,
                         drawable,
                         fadeWithThumbnail = true,
                         isStagedTask = true,
@@ -147,7 +147,7 @@
                 val drawable = getDrawable(it.iconView, splitSelectSource)
                 return SplitAnimInitProps(
                     it.snapshotView,
-                    it.splitAnimationThumbnail,
+                    it.thumbnail,
                     drawable,
                     fadeWithThumbnail = true,
                     isStagedTask = true,
diff --git a/quickstep/src/com/android/quickstep/util/TaskGridNavHelper.java b/quickstep/src/com/android/quickstep/util/TaskGridNavHelper.java
deleted file mode 100644
index 35e90f2..0000000
--- a/quickstep/src/com/android/quickstep/util/TaskGridNavHelper.java
+++ /dev/null
@@ -1,148 +0,0 @@
-/*
- * Copyright (C) 2023 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package com.android.quickstep.util;
-
-import static java.lang.annotation.RetentionPolicy.SOURCE;
-
-import androidx.annotation.IntDef;
-
-import com.android.launcher3.util.IntArray;
-
-import java.lang.annotation.Retention;
-import java.util.List;
-
-/**
- * Helper class for navigating RecentsView grid tasks via arrow keys and tab.
- */
-public class TaskGridNavHelper {
-    public static final int CLEAR_ALL_PLACEHOLDER_ID = -1;
-    public static final int ADD_DESK_PLACEHOLDER_ID = -2;
-
-    public static final int DIRECTION_UP = 0;
-    public static final int DIRECTION_DOWN = 1;
-    public static final int DIRECTION_LEFT = 2;
-    public static final int DIRECTION_RIGHT = 3;
-    public static final int DIRECTION_TAB = 4;
-
-    @Retention(SOURCE)
-    @IntDef({DIRECTION_UP, DIRECTION_DOWN, DIRECTION_LEFT, DIRECTION_RIGHT, DIRECTION_TAB})
-    public @interface TASK_NAV_DIRECTION {}
-
-    private final IntArray mOriginalTopRowIds;
-    private final IntArray mTopRowIds = new IntArray();
-    private final IntArray mBottomRowIds = new IntArray();
-
-    public TaskGridNavHelper(IntArray topIds, IntArray bottomIds,
-            List<Integer> largeTileIds, boolean hasAddDesktopButton) {
-        mOriginalTopRowIds = topIds.clone();
-        generateTaskViewIdGrid(topIds, bottomIds, largeTileIds, hasAddDesktopButton);
-    }
-
-    private void generateTaskViewIdGrid(IntArray topRowIdArray, IntArray bottomRowIdArray,
-            List<Integer> largeTileIds, boolean hasAddDesktopButton) {
-        // Add AddDesktopButton and lage tiles to both rows.
-        if (hasAddDesktopButton) {
-            mTopRowIds.add(ADD_DESK_PLACEHOLDER_ID);
-            mBottomRowIds.add(ADD_DESK_PLACEHOLDER_ID);
-        }
-        for (Integer tileId : largeTileIds) {
-            mTopRowIds.add(tileId);
-            mBottomRowIds.add(tileId);
-        }
-
-        // Add row ids to their respective rows.
-        mTopRowIds.addAll(topRowIdArray);
-        mBottomRowIds.addAll(bottomRowIdArray);
-
-        // Fill in the shorter array with the ids from the longer one.
-        while (mTopRowIds.size() > mBottomRowIds.size()) {
-            mBottomRowIds.add(mTopRowIds.get(mBottomRowIds.size()));
-        }
-        while (mBottomRowIds.size() > mTopRowIds.size()) {
-            mTopRowIds.add(mBottomRowIds.get(mTopRowIds.size()));
-        }
-
-        // Add the clear all button to the end of both arrays.
-        mTopRowIds.add(CLEAR_ALL_PLACEHOLDER_ID);
-        mBottomRowIds.add(CLEAR_ALL_PLACEHOLDER_ID);
-    }
-
-    /**
-     * Returns the id of the next page in the grid or -1 for the clear all button.
-     */
-    public int getNextGridPage(int currentPageTaskViewId, int delta,
-            @TASK_NAV_DIRECTION int direction, boolean cycle) {
-        boolean inTop = mTopRowIds.contains(currentPageTaskViewId);
-        int index = inTop ? mTopRowIds.indexOf(currentPageTaskViewId)
-                : mBottomRowIds.indexOf(currentPageTaskViewId);
-        int maxSize = Math.max(mTopRowIds.size(), mBottomRowIds.size());
-        int nextIndex = index + delta;
-
-        switch (direction) {
-            case DIRECTION_UP:
-            case DIRECTION_DOWN: {
-                return inTop ? mBottomRowIds.get(index) : mTopRowIds.get(index);
-            }
-            case DIRECTION_LEFT: {
-                int boundedIndex = cycle ? nextIndex % maxSize : Math.min(nextIndex, maxSize - 1);
-                return inTop ? mTopRowIds.get(boundedIndex)
-                        : mBottomRowIds.get(boundedIndex);
-            }
-            case DIRECTION_RIGHT: {
-                int boundedIndex =
-                        cycle ? (nextIndex < 0 ? maxSize - 1 : nextIndex)
-                                : Math.max(nextIndex, 0);
-                boolean inOriginalTop = mOriginalTopRowIds.contains(currentPageTaskViewId);
-                return inOriginalTop ? mTopRowIds.get(boundedIndex)
-                        : mBottomRowIds.get(boundedIndex);
-            }
-            case DIRECTION_TAB: {
-                int boundedIndex =
-                        cycle ? (nextIndex < 0 ? maxSize - 1 : nextIndex % maxSize)
-                                : Math.min(nextIndex, maxSize - 1);
-                if (delta >= 0) {
-                    return inTop && mTopRowIds.get(index) != mBottomRowIds.get(index)
-                            ? mBottomRowIds.get(index)
-                            : mTopRowIds.get(boundedIndex);
-                } else {
-                    if (mTopRowIds.contains(currentPageTaskViewId)) {
-                        if (boundedIndex < 0) {
-                            // If no cycling, always return the first task.
-                            return mTopRowIds.get(0);
-                        } else {
-                            return mBottomRowIds.get(boundedIndex);
-                        }
-                    } else {
-                        // Go up to top if there is task above
-                        return mTopRowIds.get(index) != mBottomRowIds.get(index)
-                                ? mTopRowIds.get(index)
-                                : mBottomRowIds.get(boundedIndex);
-                    }
-                }
-            }
-            default:
-                return currentPageTaskViewId;
-        }
-    }
-
-    /**
-     * Returns the column of a task's id in the grid.
-     */
-    public int getColumn(int taskViewId) {
-        return mTopRowIds.contains(taskViewId) ? mTopRowIds.indexOf(taskViewId)
-                : mBottomRowIds.indexOf(taskViewId);
-    }
-}
diff --git a/quickstep/src/com/android/quickstep/util/TaskGridNavHelper.kt b/quickstep/src/com/android/quickstep/util/TaskGridNavHelper.kt
new file mode 100644
index 0000000..0e78801
--- /dev/null
+++ b/quickstep/src/com/android/quickstep/util/TaskGridNavHelper.kt
@@ -0,0 +1,159 @@
+/*
+ * Copyright (C) 2025 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.android.quickstep.util
+
+import com.android.launcher3.util.IntArray
+import kotlin.math.abs
+import kotlin.math.max
+
+/** Helper class for navigating RecentsView grid tasks via arrow keys and tab. */
+class TaskGridNavHelper(
+    private val topIds: IntArray,
+    bottomIds: IntArray,
+    largeTileIds: List<Int>,
+    hasAddDesktopButton: Boolean,
+) {
+    private val topRowIds = mutableListOf<Int>()
+    private val bottomRowIds = mutableListOf<Int>()
+
+    init {
+        // Add AddDesktopButton and large tiles to both rows.
+        if (hasAddDesktopButton) {
+            topRowIds += ADD_DESK_PLACEHOLDER_ID
+            bottomRowIds += ADD_DESK_PLACEHOLDER_ID
+        }
+        topRowIds += largeTileIds
+        bottomRowIds += largeTileIds
+
+        // Add row ids to their respective rows.
+        topRowIds += topIds
+        bottomRowIds += bottomIds
+
+        // Fill in the shorter array with the ids from the longer one.
+        topRowIds += bottomRowIds.takeLast(max(bottomRowIds.size - topRowIds.size, 0))
+        bottomRowIds += topRowIds.takeLast(max(topRowIds.size - bottomRowIds.size, 0))
+
+        // Add the clear all button to the end of both arrays.
+        topRowIds += CLEAR_ALL_PLACEHOLDER_ID
+        bottomRowIds += CLEAR_ALL_PLACEHOLDER_ID
+    }
+
+    /** Returns the id of the next page in the grid or -1 for the clear all button. */
+    fun getNextGridPage(
+        currentPageTaskViewId: Int,
+        delta: Int,
+        direction: TaskNavDirection,
+        cycle: Boolean,
+    ): Int {
+        val inTop = topRowIds.contains(currentPageTaskViewId)
+        val index =
+            if (inTop) topRowIds.indexOf(currentPageTaskViewId)
+            else bottomRowIds.indexOf(currentPageTaskViewId)
+        val maxSize = max(topRowIds.size, bottomRowIds.size)
+        val nextIndex = index + delta
+
+        return when (direction) {
+            TaskNavDirection.UP,
+            TaskNavDirection.DOWN -> {
+                if (inTop) bottomRowIds[index] else topRowIds[index]
+            }
+            TaskNavDirection.LEFT -> {
+                val boundedIndex =
+                    if (cycle) nextIndex % maxSize else nextIndex.coerceAtMost(maxSize - 1)
+                if (inTop) topRowIds[boundedIndex] else bottomRowIds[boundedIndex]
+            }
+            TaskNavDirection.RIGHT -> {
+                val boundedIndex =
+                    if (cycle) (if (nextIndex < 0) maxSize - 1 else nextIndex)
+                    else nextIndex.coerceAtLeast(0)
+                val inOriginalTop = topIds.contains(currentPageTaskViewId)
+                if (inOriginalTop) topRowIds[boundedIndex] else bottomRowIds[boundedIndex]
+            }
+            TaskNavDirection.TAB -> {
+                val boundedIndex =
+                    if (cycle) (if (nextIndex < 0) maxSize - 1 else nextIndex % maxSize)
+                    else nextIndex.coerceAtMost(maxSize - 1)
+                if (delta >= 0) {
+                    if (inTop && topRowIds[index] != bottomRowIds[index]) bottomRowIds[index]
+                    else topRowIds[boundedIndex]
+                } else {
+                    if (topRowIds.contains(currentPageTaskViewId)) {
+                        if (boundedIndex < 0) {
+                            // If no cycling, always return the first task.
+                            topRowIds[0]
+                        } else {
+                            bottomRowIds[boundedIndex]
+                        }
+                    } else {
+                        // Go up to top if there is a task above
+                        if (topRowIds[index] != bottomRowIds[index]) topRowIds[index]
+                        else bottomRowIds[boundedIndex]
+                    }
+                }
+            }
+            else -> currentPageTaskViewId
+        }
+    }
+
+    /**
+     * Returns a sequence of pairs of (TaskView ID, offset) in the grid, ordered according to tab
+     * navigation, starting from the initial TaskView ID, towards the start or end of the grid.
+     *
+     * <p>When [towardsStart] is false the sequence moves forward in the tab order towards the end
+     * of the grid; when it is true it moves backward towards the beginning. The offset is the
+     * absolute column distance between each yielded task and the initial task.
+     */
+    fun gridTaskViewIdOffsetPairInTabOrderSequence(
+        initialTaskViewId: Int,
+        towardsStart: Boolean,
+    ): Sequence<Pair<Int, Int>> = sequence {
+        val draggedTaskViewColumn = getColumn(initialTaskViewId)
+        var nextTaskViewId: Int = initialTaskViewId
+        var previousTaskViewId: Int = Int.MIN_VALUE
+        while (nextTaskViewId != previousTaskViewId && nextTaskViewId >= 0) {
+            previousTaskViewId = nextTaskViewId
+            nextTaskViewId =
+                getNextGridPage(
+                    nextTaskViewId,
+                    if (towardsStart) -1 else 1,
+                    TaskNavDirection.TAB,
+                    cycle = false,
+                )
+            if (nextTaskViewId >= 0 && nextTaskViewId != previousTaskViewId) {
+                val columnOffset = abs(getColumn(nextTaskViewId) - draggedTaskViewColumn)
+                yield(Pair(nextTaskViewId, columnOffset))
+            }
+        }
+    }
+
+    /** Returns the column of a task's id in the grid. */
+    private fun getColumn(taskViewId: Int): Int =
+        if (topRowIds.contains(taskViewId)) topRowIds.indexOf(taskViewId)
+        else bottomRowIds.indexOf(taskViewId)
+
+    enum class TaskNavDirection {
+        UP,
+        DOWN,
+        LEFT,
+        RIGHT,
+        TAB,
+    }
+
+    companion object {
+        const val CLEAR_ALL_PLACEHOLDER_ID: Int = -1
+        const val ADD_DESK_PLACEHOLDER_ID: Int = -2
+    }
+}
diff --git a/quickstep/src/com/android/quickstep/util/TransformParams.java b/quickstep/src/com/android/quickstep/util/TransformParams.java
index 1c1fbd8..716803a 100644
--- a/quickstep/src/com/android/quickstep/util/TransformParams.java
+++ b/quickstep/src/com/android/quickstep/util/TransformParams.java
@@ -222,7 +222,7 @@
         return null;
     }
 
-    // Pubic getters so outside packages can read the values.
+    // Public getters so outside packages can read the values.
 
     public float getProgress() {
         return mProgress;
diff --git a/quickstep/src/com/android/quickstep/views/DesktopTaskView.kt b/quickstep/src/com/android/quickstep/views/DesktopTaskView.kt
index da160f1..fbda3b3 100644
--- a/quickstep/src/com/android/quickstep/views/DesktopTaskView.kt
+++ b/quickstep/src/com/android/quickstep/views/DesktopTaskView.kt
@@ -21,6 +21,7 @@
 import android.graphics.Matrix
 import android.graphics.PointF
 import android.graphics.Rect
+import android.graphics.Rect.intersects
 import android.graphics.RectF
 import android.util.AttributeSet
 import android.util.Log
@@ -30,6 +31,7 @@
 import android.view.ViewStub
 import androidx.core.content.res.ResourcesCompat
 import androidx.core.view.updateLayoutParams
+import com.android.internal.hidden_from_bootclasspath.com.android.window.flags.Flags.enableDesktopRecentsTransitionsCornersBugfix
 import com.android.launcher3.Flags.enableDesktopExplodedView
 import com.android.launcher3.Flags.enableOverviewIconMenu
 import com.android.launcher3.Flags.enableRefactorTaskThumbnail
@@ -57,6 +59,7 @@
 import com.android.quickstep.task.thumbnail.TaskThumbnailView
 import com.android.quickstep.util.DesktopTask
 import com.android.quickstep.util.RecentsOrientedState
+import kotlin.math.roundToInt
 
 /** TaskView that contains all tasks that are part of the desktop. */
 class DesktopTaskView @JvmOverloads constructor(context: Context, attrs: AttributeSet? = null) :
@@ -102,7 +105,7 @@
      * Holds the default (user placed) positions of task windows. This can be moved into the
      * viewModel once RefactorTaskThumbnail has been launched.
      */
-    private var defaultTaskPositions: List<DesktopTaskBoundsData> = emptyList()
+    private var fullscreenTaskPositions: List<DesktopTaskBoundsData> = emptyList()
 
     /**
      * When enableDesktopExplodedView is enabled, this controls the gradual transition from the
@@ -172,42 +175,44 @@
 
         val thumbnailTopMarginPx = container.deviceProfile.overviewTaskThumbnailTopMarginPx
 
-        val containerWidth = layoutParams.width
-        val containerHeight = layoutParams.height - thumbnailTopMarginPx
+        val taskViewWidth = layoutParams.width
+        val taskViewHeight = layoutParams.height - thumbnailTopMarginPx
 
         BaseContainerInterface.getTaskDimension(mContext, container.deviceProfile, tempPointF)
 
-        val windowWidth = tempPointF.x.toInt()
-        val windowHeight = tempPointF.y.toInt()
-        val scaleWidth = containerWidth / windowWidth.toFloat()
-        val scaleHeight = containerHeight / windowHeight.toFloat()
+        val screenWidth = tempPointF.x.toInt()
+        val screenHeight = tempPointF.y.toInt()
+        val screenRect = Rect(0, 0, screenWidth, screenHeight)
+        val scaleWidth = taskViewWidth / screenWidth.toFloat()
+        val scaleHeight = taskViewHeight / screenHeight.toFloat()
 
         taskContainers.forEach {
             val taskId = it.task.key.id
-            val defaultPosition = defaultTaskPositions.firstOrNull { it.taskId == taskId } ?: return
-            val position =
+            val fullscreenTaskPosition =
+                fullscreenTaskPositions.firstOrNull { it.taskId == taskId } ?: return
+            val overviewTaskPosition =
                 if (enableDesktopExplodedView()) {
                     viewModel!!
                         .organizedDesktopTaskPositions
                         .firstOrNull { it.taskId == taskId }
                         ?.let { organizedPosition ->
-                            TEMP_RECT.apply {
+                            TEMP_OVERVIEW_TASK_POSITION.apply {
                                 lerpRect(
-                                    defaultPosition.bounds,
+                                    fullscreenTaskPosition.bounds,
                                     organizedPosition.bounds,
                                     explodeProgress,
                                 )
                             }
-                        } ?: defaultPosition.bounds
+                        } ?: fullscreenTaskPosition.bounds
                 } else {
-                    defaultPosition.bounds
+                    fullscreenTaskPosition.bounds
                 }
 
             if (enableDesktopExplodedView()) {
                 getRemoteTargetHandle(taskId)?.let { remoteTargetHandle ->
                     val fromRect =
-                        TEMP_RECTF1.apply {
-                            set(defaultPosition.bounds)
+                        TEMP_FROM_RECTF.apply {
+                            set(fullscreenTaskPosition.bounds)
                             scale(scaleWidth)
                             offset(
                                 lastComputedTaskSize.left.toFloat(),
@@ -215,8 +220,8 @@
                             )
                         }
                     val toRect =
-                        TEMP_RECTF2.apply {
-                            set(position)
+                        TEMP_TO_RECTF.apply {
+                            set(overviewTaskPosition)
                             scale(scaleWidth)
                             offset(
                                 lastComputedTaskSize.left.toFloat(),
@@ -230,10 +235,10 @@
                 }
             }
 
-            val taskLeft = position.left * scaleWidth
-            val taskTop = position.top * scaleHeight
-            val taskWidth = position.width() * scaleWidth
-            val taskHeight = position.height() * scaleHeight
+            val taskLeft = overviewTaskPosition.left * scaleWidth
+            val taskTop = overviewTaskPosition.top * scaleHeight
+            val taskWidth = overviewTaskPosition.width() * scaleWidth
+            val taskHeight = overviewTaskPosition.height() * scaleHeight
             // TODO(b/394660950): Revisit the choice to update the layout when explodeProgress == 1.
             // To run the explode animation in reverse, it may be simpler to use translation/scale
             // for all cases where the progress is non-zero.
@@ -254,6 +259,24 @@
                     leftMargin = taskLeft.toInt()
                     topMargin = taskTop.toInt()
                 }
+
+                if (
+                    enableDesktopRecentsTransitionsCornersBugfix() && enableRefactorTaskThumbnail()
+                ) {
+                    it.thumbnailView.outlineBounds =
+                        if (intersects(overviewTaskPosition, screenRect))
+                            Rect(overviewTaskPosition).apply {
+                                intersectUnchecked(screenRect)
+                                // Offset to 0,0 to transform into TaskThumbnailView's coordinate
+                                // system.
+                                offset(-overviewTaskPosition.left, -overviewTaskPosition.top)
+                                left = (left * scaleWidth).roundToInt()
+                                top = (top * scaleHeight).roundToInt()
+                                right = (right * scaleWidth).roundToInt()
+                                bottom = (bottom * scaleHeight).roundToInt()
+                            }
+                        else null
+                }
             } else {
                 // During the animation, apply translation and scale such that the view is
                 // transformed to where we want, without triggering layout.
@@ -346,13 +369,13 @@
         val desktopSize = Size(tempPointF.x.toInt(), tempPointF.y.toInt())
         DEFAULT_BOUNDS.set(0, 0, desktopSize.width / 4, desktopSize.height / 4)
 
-        defaultTaskPositions =
+        fullscreenTaskPositions =
             taskContainers.map {
                 DesktopTaskBoundsData(it.task.key.id, it.task.appBounds ?: DEFAULT_BOUNDS)
             }
 
         if (enableDesktopExplodedView()) {
-            viewModel?.organizeDesktopTasks(desktopSize, defaultTaskPositions)
+            viewModel?.organizeDesktopTasks(desktopSize, fullscreenTaskPositions)
         }
         positionTaskWindows()
     }
@@ -449,8 +472,8 @@
         private const val VIEW_POOL_INITIAL_SIZE = 0
         private val DEFAULT_BOUNDS = Rect()
         // Temporaries used for various purposes to avoid allocations.
-        private val TEMP_RECT = Rect()
-        private val TEMP_RECTF1 = RectF()
-        private val TEMP_RECTF2 = RectF()
+        private val TEMP_OVERVIEW_TASK_POSITION = Rect()
+        private val TEMP_FROM_RECTF = RectF()
+        private val TEMP_TO_RECTF = RectF()
     }
 }
diff --git a/quickstep/src/com/android/quickstep/views/RecentsView.java b/quickstep/src/com/android/quickstep/views/RecentsView.java
index 424271a..e4633e4 100644
--- a/quickstep/src/com/android/quickstep/views/RecentsView.java
+++ b/quickstep/src/com/android/quickstep/views/RecentsView.java
@@ -65,11 +65,6 @@
 import static com.android.quickstep.BaseContainerInterface.getTaskDimension;
 import static com.android.quickstep.TaskUtils.checkCurrentOrManagedUserId;
 import static com.android.quickstep.util.LogUtils.splitFailureMessage;
-import static com.android.quickstep.util.TaskGridNavHelper.DIRECTION_DOWN;
-import static com.android.quickstep.util.TaskGridNavHelper.DIRECTION_LEFT;
-import static com.android.quickstep.util.TaskGridNavHelper.DIRECTION_RIGHT;
-import static com.android.quickstep.util.TaskGridNavHelper.DIRECTION_TAB;
-import static com.android.quickstep.util.TaskGridNavHelper.DIRECTION_UP;
 import static com.android.quickstep.views.ClearAllButton.DISMISS_ALPHA;
 import static com.android.quickstep.views.OverviewActionsView.HIDDEN_ACTIONS_IN_MENU;
 import static com.android.quickstep.views.OverviewActionsView.HIDDEN_DESKTOP;
@@ -4540,7 +4535,7 @@
     }
 
     private boolean snapToPageRelative(int delta, boolean cycle,
-            @TaskGridNavHelper.TASK_NAV_DIRECTION int direction) {
+            TaskGridNavHelper.TaskNavDirection direction) {
         // Set next page if scroll animation is still running, otherwise cannot snap to the
         // next page on successive key presses. Setting the current page aborts the scroll.
         if (!mScroller.isFinished()) {
@@ -4559,7 +4554,7 @@
         return true;
     }
 
-    private int getNextPageInternal(int delta, @TaskGridNavHelper.TASK_NAV_DIRECTION int direction,
+    private int getNextPageInternal(int delta, TaskGridNavHelper.TaskNavDirection direction,
             boolean cycle) {
         if (!showAsGrid()) {
             return getNextPage() + delta;
@@ -4647,15 +4642,19 @@
         switch (event.getKeyCode()) {
             case KeyEvent.KEYCODE_TAB:
                 return snapToPageRelative(event.isShiftPressed() ? -1 : 1, true /* cycle */,
-                        DIRECTION_TAB);
+                        TaskGridNavHelper.TaskNavDirection.TAB);
             case KeyEvent.KEYCODE_DPAD_RIGHT:
-                return snapToPageRelative(mIsRtl ? -1 : 1, true /* cycle */, DIRECTION_RIGHT);
+                return snapToPageRelative(mIsRtl ? -1 : 1, true /* cycle */,
+                        TaskGridNavHelper.TaskNavDirection.RIGHT);
             case KeyEvent.KEYCODE_DPAD_LEFT:
-                return snapToPageRelative(mIsRtl ? 1 : -1, true /* cycle */, DIRECTION_LEFT);
+                return snapToPageRelative(mIsRtl ? 1 : -1, true /* cycle */,
+                        TaskGridNavHelper.TaskNavDirection.LEFT);
             case KeyEvent.KEYCODE_DPAD_UP:
-                return snapToPageRelative(1, false /* cycle */, DIRECTION_UP);
+                return snapToPageRelative(1, false /* cycle */,
+                        TaskGridNavHelper.TaskNavDirection.UP);
             case KeyEvent.KEYCODE_DPAD_DOWN:
-                return snapToPageRelative(1, false /* cycle */, DIRECTION_DOWN);
+                return snapToPageRelative(1, false /* cycle */,
+                        TaskGridNavHelper.TaskNavDirection.DOWN);
             case KeyEvent.KEYCODE_DEL:
             case KeyEvent.KEYCODE_FORWARD_DEL:
                 dismissCurrentTask();
diff --git a/quickstep/src/com/android/quickstep/views/RecentsViewUtils.kt b/quickstep/src/com/android/quickstep/views/RecentsViewUtils.kt
index c7fc448..d3549fb 100644
--- a/quickstep/src/com/android/quickstep/views/RecentsViewUtils.kt
+++ b/quickstep/src/com/android/quickstep/views/RecentsViewUtils.kt
@@ -17,6 +17,7 @@
 package com.android.quickstep.views
 
 import android.graphics.Rect
+import android.os.VibrationAttributes
 import android.view.View
 import androidx.core.view.children
 import androidx.dynamicanimation.animation.FloatPropertyCompat
@@ -26,14 +27,18 @@
 import com.android.launcher3.Flags.enableLargeDesktopWindowingTile
 import com.android.launcher3.Flags.enableSeparateExternalDisplayTasks
 import com.android.launcher3.R
+import com.android.launcher3.Utilities.boundToRange
 import com.android.launcher3.touch.SingleAxisSwipeDetector
 import com.android.launcher3.util.DynamicResource
 import com.android.launcher3.util.IntArray
+import com.android.launcher3.util.MSDLPlayerWrapper
 import com.android.quickstep.util.GroupTask
 import com.android.quickstep.util.TaskGridNavHelper
 import com.android.quickstep.util.isExternalDisplay
 import com.android.quickstep.views.RecentsView.RUNNING_TASK_ATTACH_ALPHA
 import com.android.systemui.shared.recents.model.ThumbnailData
+import com.google.android.msdl.data.model.MSDLToken
+import com.google.android.msdl.domain.InteractionProperties
 import java.util.function.BiConsumer
 import kotlin.math.abs
 
@@ -378,7 +383,7 @@
             // negative displacement to positive displacement). We do not check for an exact value
             // to compare to, as the update listener does not necessarily hit every value (e.g. a
             // value of zero). Do not check again once it has started settling, as a spring can
-            // bounce past the origin multiple times depending on the stifness and damping ratio.
+            // bounce past the origin multiple times depending on the stiffness and damping ratio.
             if (startSettling) return@addUpdateListener
             if (lastPosition < 0 && value >= 0) {
                 startSettling = true
@@ -386,6 +391,7 @@
             lastPosition = value
             if (startSettling) {
                 neighborsToSettle.setStartVelocity(velocity).animateToFinalPosition(0f)
+                playDismissSettlingHaptic(velocity)
             }
         }
 
@@ -426,7 +432,23 @@
         towardsStart: Boolean,
     ): Sequence<Pair<TaskView, Int>> {
         if (recentsView.showAsGrid()) {
-            return gridTaskOffsetPairInTabOrderSequence(draggedTaskView, towardsStart)
+            val taskGridNavHelper =
+                TaskGridNavHelper(
+                    recentsView.topRowIdArray,
+                    recentsView.bottomRowIdArray,
+                    getLargeTaskViewIds(),
+                    hasAddDesktopButton = false,
+                )
+            return taskGridNavHelper
+                .gridTaskViewIdOffsetPairInTabOrderSequence(
+                    draggedTaskView.taskViewId,
+                    towardsStart,
+                )
+                .mapNotNull { (taskViewId, columnOffset) ->
+                    recentsView.getTaskViewFromTaskViewId(taskViewId)?.let { taskView ->
+                        Pair(taskView, columnOffset)
+                    }
+                }
         } else {
             val taskViewList = taskViews.toList()
             val draggedTaskViewIndex = taskViewList.indexOf(draggedTaskView)
@@ -446,49 +468,6 @@
         }
     }
 
-    /**
-     * Returns a sequence of pairs of (TaskViews, offsets) in the grid, ordered according to tab
-     * navigation, starting from the dragged TaskView, towards the start or end of the grid.
-     *
-     * <p>A positive delta moves forward in the tab order towards the end of the grid, while a
-     * negative value moves backward towards the beginning. The offset is the distance between
-     * columns the tasks are in.
-     */
-    private fun gridTaskOffsetPairInTabOrderSequence(
-        draggedTaskView: TaskView,
-        towardsStart: Boolean,
-    ): Sequence<Pair<TaskView, Int>> = sequence {
-        val taskGridNavHelper =
-            TaskGridNavHelper(
-                recentsView.topRowIdArray,
-                recentsView.bottomRowIdArray,
-                getLargeTaskViewIds(),
-                /* hasAddDesktopButton= */ false,
-            )
-        val draggedTaskViewColumn = taskGridNavHelper.getColumn(draggedTaskView.taskViewId)
-        var nextTaskView: TaskView? = draggedTaskView
-        var previousTaskView: TaskView? = null
-        while (nextTaskView != previousTaskView && nextTaskView != null) {
-            previousTaskView = nextTaskView
-            nextTaskView =
-                recentsView.getTaskViewFromTaskViewId(
-                    taskGridNavHelper.getNextGridPage(
-                        nextTaskView.taskViewId,
-                        if (towardsStart) -1 else 1,
-                        TaskGridNavHelper.DIRECTION_TAB,
-                        /* cycle = */ false,
-                    )
-                )
-            if (nextTaskView != null && nextTaskView != previousTaskView) {
-                val columnOffset =
-                    abs(
-                        taskGridNavHelper.getColumn(nextTaskView.taskViewId) - draggedTaskViewColumn
-                    )
-                yield(Pair(nextTaskView, columnOffset))
-            }
-        }
-    }
-
     /** Creates a neighboring task view spring, driven by the spring of its neighbor. */
     private fun createNeighboringTaskViewSpringAnimation(
         taskView: TaskView,
@@ -532,6 +511,27 @@
             )
     }
 
+    /**
+     * Emits the [MSDLToken.CANCEL] haptic used while a dragged task view settles back to rest.
+     *
+     * <p>The vibration scale is [velocity] divided by the secondary dimension of the recents
+     * view, passed through boundToRange with the limits 0 and 1.
+     */
+    private fun playDismissSettlingHaptic(velocity: Float) {
+        val maxVelocity = recentsView.pagedOrientationHandler.getSecondaryDimension(recentsView)
+        val player = MSDLPlayerWrapper.INSTANCE.get(recentsView.context)
+        val scale = boundToRange(velocity / maxVelocity, 0f, 1f)
+        val touchAttributes =
+            VibrationAttributes.Builder()
+                .setUsage(VibrationAttributes.USAGE_TOUCH)
+                .setFlags(VibrationAttributes.FLAG_PIPELINED_EFFECT)
+                .build()
+        player.playToken(
+            MSDLToken.CANCEL,
+            InteractionProperties.DynamicVibrationScale(scale, touchAttributes),
+        )
+    }
+
     companion object {
         val TEMP_RECT = Rect()
 
diff --git a/quickstep/src/com/android/quickstep/views/TaskContainer.kt b/quickstep/src/com/android/quickstep/views/TaskContainer.kt
index bbe1af4..ef85aa1 100644
--- a/quickstep/src/com/android/quickstep/views/TaskContainer.kt
+++ b/quickstep/src/com/android/quickstep/views/TaskContainer.kt
@@ -75,7 +75,9 @@
     }
 
     internal var thumbnailData: ThumbnailData? = null
-    val splitAnimationThumbnail: Bitmap?
+
+    val thumbnail: Bitmap?
+        /** If possible don't use this. It should be replaced as part of b/331753115. */
         get() =
             if (enableRefactorTaskThumbnail()) thumbnailData?.thumbnail
             else thumbnailViewDeprecated.thumbnail
diff --git a/quickstep/src/com/android/quickstep/views/TaskView.kt b/quickstep/src/com/android/quickstep/views/TaskView.kt
index 56a35bd..39c15b1 100644
--- a/quickstep/src/com/android/quickstep/views/TaskView.kt
+++ b/quickstep/src/com/android/quickstep/views/TaskView.kt
@@ -1397,7 +1397,7 @@
             container.task,
             container.iconView.drawable,
             container.snapshotView,
-            container.splitAnimationThumbnail,
+            container.thumbnail,
             /* intent */ null,
             /* user */ null,
             container.itemInfo,
diff --git a/quickstep/tests/multivalentTests/src/com/android/quickstep/util/SplitAnimationControllerTest.kt b/quickstep/tests/multivalentTests/src/com/android/quickstep/util/SplitAnimationControllerTest.kt
index 5051251..9d000a4 100644
--- a/quickstep/tests/multivalentTests/src/com/android/quickstep/util/SplitAnimationControllerTest.kt
+++ b/quickstep/tests/multivalentTests/src/com/android/quickstep/util/SplitAnimationControllerTest.kt
@@ -89,7 +89,7 @@
     @Before
     fun setup() {
         whenever(mockTaskContainer.snapshotView).thenReturn(mockSnapshotView)
-        whenever(mockTaskContainer.splitAnimationThumbnail).thenReturn(mockBitmap)
+        whenever(mockTaskContainer.thumbnail).thenReturn(mockBitmap)
         whenever(mockTaskContainer.iconView).thenReturn(mockIconView)
         whenever(mockTaskContainer.task).thenReturn(mockTask)
         whenever(mockIconView.drawable).thenReturn(mockTaskViewDrawable)
@@ -117,7 +117,7 @@
         assertEquals(
             "Did not fallback to use splitSource icon drawable",
             mockSplitSourceDrawable,
-            splitAnimInitProps.iconDrawable
+            splitAnimInitProps.iconDrawable,
         )
     }
 
@@ -133,7 +133,7 @@
         assertEquals(
             "Did not use taskView icon drawable",
             mockTaskViewDrawable,
-            splitAnimInitProps.iconDrawable
+            splitAnimInitProps.iconDrawable,
         )
     }
 
@@ -152,7 +152,7 @@
         assertEquals(
             "Did not use taskView icon drawable",
             mockTaskViewDrawable,
-            splitAnimInitProps.iconDrawable
+            splitAnimInitProps.iconDrawable,
         )
     }
 
@@ -168,7 +168,7 @@
         assertEquals(
             "Did not use splitSource icon drawable",
             mockSplitSourceDrawable,
-            splitAnimInitProps.iconDrawable
+            splitAnimInitProps.iconDrawable,
         )
     }
 
@@ -190,13 +190,13 @@
         val splitAnimInitProps: SplitAnimationController.Companion.SplitAnimInitProps =
             splitAnimationController.getFirstAnimInitViews(
                 { mockGroupedTaskView },
-                { splitSelectSource }
+                { splitSelectSource },
             )
 
         assertEquals(
             "Did not use splitSource icon drawable",
             mockSplitSourceDrawable,
-            splitAnimInitProps.iconDrawable
+            splitAnimInitProps.iconDrawable,
         )
     }
 
@@ -214,7 +214,7 @@
                 any(),
                 any(),
                 any(),
-                any()
+                any(),
             )
 
         spySplitAnimationController.playSplitLaunchAnimation(
@@ -230,7 +230,7 @@
             null /* info */,
             null /* t */,
             {} /* finishCallback */,
-            1f /* cornerRadius */
+            1f, /* cornerRadius */
         )
 
         verify(spySplitAnimationController)
@@ -243,7 +243,7 @@
                 any(),
                 any(),
                 any(),
-                any()
+                any(),
             )
     }
 
@@ -267,7 +267,7 @@
             transitionInfo,
             transaction,
             {} /* finishCallback */,
-            1f /* cornerRadius */
+            1f, /* cornerRadius */
         )
 
         verify(spySplitAnimationController)
@@ -296,7 +296,7 @@
             transitionInfo,
             transaction,
             {} /* finishCallback */,
-            1f /* cornerRadius */
+            1f, /* cornerRadius */
         )
 
         verify(spySplitAnimationController)
@@ -325,7 +325,7 @@
             transitionInfo,
             transaction,
             {} /* finishCallback */,
-            1f /* cornerRadius */
+            1f, /* cornerRadius */
         )
 
         verify(spySplitAnimationController)
@@ -353,7 +353,7 @@
             transitionInfo,
             transaction,
             {} /* finishCallback */,
-            1f /* cornerRadius */
+            1f, /* cornerRadius */
         )
 
         verify(spySplitAnimationController)
@@ -381,7 +381,7 @@
             transitionInfo,
             transaction,
             {} /* finishCallback */,
-            1f /* cornerRadius */
+            1f, /* cornerRadius */
         )
 
         verify(spySplitAnimationController)
@@ -408,7 +408,7 @@
             transitionInfo,
             transaction,
             {} /* finishCallback */,
-            1f /* cornerRadius */
+            1f, /* cornerRadius */
         )
 
         verify(spySplitAnimationController)
diff --git a/quickstep/tests/multivalentTests/src/com/android/quickstep/util/TaskGridNavHelperTest.kt b/quickstep/tests/multivalentTests/src/com/android/quickstep/util/TaskGridNavHelperTest.kt
index f2fa0c5..cb088fd 100644
--- a/quickstep/tests/multivalentTests/src/com/android/quickstep/util/TaskGridNavHelperTest.kt
+++ b/quickstep/tests/multivalentTests/src/com/android/quickstep/util/TaskGridNavHelperTest.kt
@@ -16,13 +16,13 @@
 package com.android.quickstep.util
 
 import com.android.launcher3.util.IntArray
-import com.android.quickstep.util.TaskGridNavHelper.ADD_DESK_PLACEHOLDER_ID
-import com.android.quickstep.util.TaskGridNavHelper.CLEAR_ALL_PLACEHOLDER_ID
-import com.android.quickstep.util.TaskGridNavHelper.DIRECTION_DOWN
-import com.android.quickstep.util.TaskGridNavHelper.DIRECTION_LEFT
-import com.android.quickstep.util.TaskGridNavHelper.DIRECTION_RIGHT
-import com.android.quickstep.util.TaskGridNavHelper.DIRECTION_TAB
-import com.android.quickstep.util.TaskGridNavHelper.DIRECTION_UP
+import com.android.quickstep.util.TaskGridNavHelper.Companion.ADD_DESK_PLACEHOLDER_ID
+import com.android.quickstep.util.TaskGridNavHelper.Companion.CLEAR_ALL_PLACEHOLDER_ID
+import com.android.quickstep.util.TaskGridNavHelper.TaskNavDirection.DOWN
+import com.android.quickstep.util.TaskGridNavHelper.TaskNavDirection.LEFT
+import com.android.quickstep.util.TaskGridNavHelper.TaskNavDirection.RIGHT
+import com.android.quickstep.util.TaskGridNavHelper.TaskNavDirection.TAB
+import com.android.quickstep.util.TaskGridNavHelper.TaskNavDirection.UP
 import com.google.common.truth.Truth.assertThat
 import org.junit.Test
 
@@ -35,8 +35,7 @@
     */
     @Test
     fun equalLengthRows_noFocused_onTop_pressDown_goesToBottom() {
-        assertThat(getNextGridPage(currentPageTaskViewId = 1, DIRECTION_DOWN, delta = 1))
-            .isEqualTo(2)
+        assertThat(getNextGridPage(currentPageTaskViewId = 1, DOWN, delta = 1)).isEqualTo(2)
     }
 
     /*                      ↑----→
@@ -46,7 +45,7 @@
     */
     @Test
     fun equalLengthRows_noFocused_onTop_pressUp_goesToBottom() {
-        assertThat(getNextGridPage(currentPageTaskViewId = 1, DIRECTION_UP, delta = 1)).isEqualTo(2)
+        assertThat(getNextGridPage(currentPageTaskViewId = 1, UP, delta = 1)).isEqualTo(2)
     }
 
     /*                      ↓----↑
@@ -57,8 +56,7 @@
     */
     @Test
     fun equalLengthRows_noFocused_onBottom_pressDown_goesToTop() {
-        assertThat(getNextGridPage(currentPageTaskViewId = 2, DIRECTION_DOWN, delta = 1))
-            .isEqualTo(1)
+        assertThat(getNextGridPage(currentPageTaskViewId = 2, DOWN, delta = 1)).isEqualTo(1)
     }
 
     /*
@@ -68,7 +66,7 @@
     */
     @Test
     fun equalLengthRows_noFocused_onBottom_pressUp_goesToTop() {
-        assertThat(getNextGridPage(currentPageTaskViewId = 2, DIRECTION_UP, delta = 1)).isEqualTo(1)
+        assertThat(getNextGridPage(currentPageTaskViewId = 2, UP, delta = 1)).isEqualTo(1)
     }
 
     /*
@@ -78,8 +76,7 @@
     */
     @Test
     fun equalLengthRows_noFocused_onTop_pressLeft_goesLeft() {
-        assertThat(getNextGridPage(currentPageTaskViewId = 1, DIRECTION_LEFT, delta = 1))
-            .isEqualTo(3)
+        assertThat(getNextGridPage(currentPageTaskViewId = 1, LEFT, delta = 1)).isEqualTo(3)
     }
 
     /*
@@ -89,8 +86,7 @@
     */
     @Test
     fun equalLengthRows_noFocused_onBottom_pressLeft_goesLeft() {
-        assertThat(getNextGridPage(currentPageTaskViewId = 2, DIRECTION_LEFT, delta = 1))
-            .isEqualTo(4)
+        assertThat(getNextGridPage(currentPageTaskViewId = 2, LEFT, delta = 1)).isEqualTo(4)
     }
 
     /*
@@ -100,8 +96,7 @@
     */
     @Test
     fun equalLengthRows_noFocused_onTop_secondItem_pressRight_goesRight() {
-        assertThat(getNextGridPage(currentPageTaskViewId = 3, DIRECTION_RIGHT, delta = -1))
-            .isEqualTo(1)
+        assertThat(getNextGridPage(currentPageTaskViewId = 3, RIGHT, delta = -1)).isEqualTo(1)
     }
 
     /*
@@ -111,8 +106,7 @@
     */
     @Test
     fun equalLengthRows_noFocused_onBottom_secondItem_pressRight_goesRight() {
-        assertThat(getNextGridPage(currentPageTaskViewId = 4, DIRECTION_RIGHT, delta = -1))
-            .isEqualTo(2)
+        assertThat(getNextGridPage(currentPageTaskViewId = 4, RIGHT, delta = -1)).isEqualTo(2)
     }
 
     /*
@@ -124,7 +118,7 @@
     */
     @Test
     fun equalLengthRows_noFocused_onTop_pressRight_cycleToClearAll() {
-        assertThat(getNextGridPage(currentPageTaskViewId = 1, DIRECTION_RIGHT, delta = -1))
+        assertThat(getNextGridPage(currentPageTaskViewId = 1, RIGHT, delta = -1))
             .isEqualTo(CLEAR_ALL_PLACEHOLDER_ID)
     }
 
@@ -137,7 +131,7 @@
     */
     @Test
     fun equalLengthRows_noFocused_onBottom_pressRight_cycleToClearAll() {
-        assertThat(getNextGridPage(currentPageTaskViewId = 2, DIRECTION_RIGHT, delta = -1))
+        assertThat(getNextGridPage(currentPageTaskViewId = 2, RIGHT, delta = -1))
             .isEqualTo(CLEAR_ALL_PLACEHOLDER_ID)
     }
 
@@ -149,7 +143,7 @@
     */
     @Test
     fun equalLengthRows_noFocused_onTop_lastItem_pressLeft_toClearAll() {
-        assertThat(getNextGridPage(currentPageTaskViewId = 5, DIRECTION_LEFT, delta = 1))
+        assertThat(getNextGridPage(currentPageTaskViewId = 5, LEFT, delta = 1))
             .isEqualTo(CLEAR_ALL_PLACEHOLDER_ID)
     }
 
@@ -161,7 +155,7 @@
     */
     @Test
     fun equalLengthRows_noFocused_onBottom_lastItem_pressLeft_toClearAll() {
-        assertThat(getNextGridPage(currentPageTaskViewId = 6, DIRECTION_LEFT, delta = 1))
+        assertThat(getNextGridPage(currentPageTaskViewId = 6, LEFT, delta = 1))
             .isEqualTo(CLEAR_ALL_PLACEHOLDER_ID)
     }
 
@@ -176,11 +170,7 @@
     @Test
     fun equalLengthRows_noFocused_onClearAll_pressLeft_cycleToFirst() {
         assertThat(
-                getNextGridPage(
-                    currentPageTaskViewId = CLEAR_ALL_PLACEHOLDER_ID,
-                    DIRECTION_LEFT,
-                    delta = 1,
-                )
+                getNextGridPage(currentPageTaskViewId = CLEAR_ALL_PLACEHOLDER_ID, LEFT, delta = 1)
             )
             .isEqualTo(1)
     }
@@ -194,11 +184,7 @@
     @Test
     fun equalLengthRows_noFocused_onClearAll_pressRight_toLastInBottom() {
         assertThat(
-                getNextGridPage(
-                    currentPageTaskViewId = CLEAR_ALL_PLACEHOLDER_ID,
-                    DIRECTION_RIGHT,
-                    delta = -1,
-                )
+                getNextGridPage(currentPageTaskViewId = CLEAR_ALL_PLACEHOLDER_ID, RIGHT, delta = -1)
             )
             .isEqualTo(6)
     }
@@ -214,7 +200,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = FOCUSED_TASK_ID,
-                    DIRECTION_LEFT,
+                    LEFT,
                     delta = 1,
                     largeTileIds = listOf(FOCUSED_TASK_ID),
                 )
@@ -233,7 +219,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = FOCUSED_TASK_ID,
-                    DIRECTION_UP,
+                    UP,
                     delta = 1,
                     largeTileIds = listOf(FOCUSED_TASK_ID),
                 )
@@ -254,7 +240,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = FOCUSED_TASK_ID,
-                    DIRECTION_DOWN,
+                    DOWN,
                     delta = 1,
                     largeTileIds = listOf(FOCUSED_TASK_ID),
                 )
@@ -275,7 +261,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = FOCUSED_TASK_ID,
-                    DIRECTION_RIGHT,
+                    RIGHT,
                     delta = -1,
                     largeTileIds = listOf(FOCUSED_TASK_ID),
                 )
@@ -297,7 +283,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = CLEAR_ALL_PLACEHOLDER_ID,
-                    DIRECTION_LEFT,
+                    LEFT,
                     delta = 1,
                     largeTileIds = listOf(FOCUSED_TASK_ID),
                 )
@@ -316,7 +302,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = 7,
-                    DIRECTION_DOWN,
+                    DOWN,
                     delta = 1,
                     topIds = IntArray.wrap(1, 3, 5, 7),
                 )
@@ -335,7 +321,7 @@
         assertThat(
                 getNextGridPage(
                     /* topIds = */ currentPageTaskViewId = 7,
-                    DIRECTION_UP,
+                    UP,
                     delta = 1,
                     topIds = IntArray.wrap(1, 3, 5, 7),
                 )
@@ -353,7 +339,7 @@
         assertThat(
                 getNextGridPage(
                     /* topIds = */ currentPageTaskViewId = 6,
-                    DIRECTION_LEFT,
+                    LEFT,
                     delta = 1,
                     topIds = IntArray.wrap(1, 3, 5, 7),
                 )
@@ -372,7 +358,7 @@
         assertThat(
                 getNextGridPage(
                     /* topIds = */ currentPageTaskViewId = CLEAR_ALL_PLACEHOLDER_ID,
-                    DIRECTION_RIGHT,
+                    RIGHT,
                     delta = -1,
                     topIds = IntArray.wrap(1, 3, 5, 7),
                 )
@@ -391,7 +377,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = CLEAR_ALL_PLACEHOLDER_ID,
-                    DIRECTION_RIGHT,
+                    RIGHT,
                     delta = -1,
                     bottomIds = IntArray.wrap(2, 4, 6, 7),
                 )
@@ -406,8 +392,7 @@
     */
     @Test
     fun equalLengthRows_noFocused_onTop_pressTab_goesToBottom() {
-        assertThat(getNextGridPage(currentPageTaskViewId = 1, DIRECTION_TAB, delta = 1))
-            .isEqualTo(2)
+        assertThat(getNextGridPage(currentPageTaskViewId = 1, TAB, delta = 1)).isEqualTo(2)
     }
 
     /*
@@ -418,8 +403,7 @@
     */
     @Test
     fun equalLengthRows_noFocused_onBottom_pressTab_goesToNextTop() {
-        assertThat(getNextGridPage(currentPageTaskViewId = 2, DIRECTION_TAB, delta = 1))
-            .isEqualTo(3)
+        assertThat(getNextGridPage(currentPageTaskViewId = 2, TAB, delta = 1)).isEqualTo(3)
     }
 
     /*
@@ -431,8 +415,7 @@
     */
     @Test
     fun equalLengthRows_noFocused_onTop_pressTabWithShift_goesToPreviousBottom() {
-        assertThat(getNextGridPage(currentPageTaskViewId = 3, DIRECTION_TAB, delta = -1))
-            .isEqualTo(2)
+        assertThat(getNextGridPage(currentPageTaskViewId = 3, TAB, delta = -1)).isEqualTo(2)
     }
 
     /*
@@ -442,8 +425,7 @@
     */
     @Test
     fun equalLengthRows_noFocused_onBottom_pressTabWithShift_goesToTop() {
-        assertThat(getNextGridPage(currentPageTaskViewId = 2, DIRECTION_TAB, delta = -1))
-            .isEqualTo(1)
+        assertThat(getNextGridPage(currentPageTaskViewId = 2, TAB, delta = -1)).isEqualTo(1)
     }
 
     /*
@@ -453,9 +435,7 @@
     */
     @Test
     fun equalLengthRows_noFocused_onTop_pressTabWithShift_noCycle_staysOnTop() {
-        assertThat(
-                getNextGridPage(currentPageTaskViewId = 1, DIRECTION_TAB, delta = -1, cycle = false)
-            )
+        assertThat(getNextGridPage(currentPageTaskViewId = 1, TAB, delta = -1, cycle = false))
             .isEqualTo(1)
     }
 
@@ -469,7 +449,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = CLEAR_ALL_PLACEHOLDER_ID,
-                    DIRECTION_TAB,
+                    TAB,
                     delta = 1,
                     cycle = false,
                 )
@@ -487,7 +467,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = DESKTOP_TASK_ID,
-                    DIRECTION_LEFT,
+                    LEFT,
                     delta = 1,
                     largeTileIds = listOf(DESKTOP_TASK_ID, FOCUSED_TASK_ID),
                 )
@@ -505,7 +485,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = FOCUSED_TASK_ID,
-                    DIRECTION_RIGHT,
+                    RIGHT,
                     delta = -1,
                     largeTileIds = listOf(DESKTOP_TASK_ID, FOCUSED_TASK_ID),
                 )
@@ -525,7 +505,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = DESKTOP_TASK_ID,
-                    DIRECTION_RIGHT,
+                    RIGHT,
                     delta = -1,
                     largeTileIds = listOf(DESKTOP_TASK_ID, FOCUSED_TASK_ID),
                 )
@@ -546,7 +526,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = CLEAR_ALL_PLACEHOLDER_ID,
-                    DIRECTION_LEFT,
+                    LEFT,
                     delta = 1,
                     largeTileIds = listOf(DESKTOP_TASK_ID, FOCUSED_TASK_ID),
                 )
@@ -565,7 +545,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = 2,
-                    DIRECTION_RIGHT,
+                    RIGHT,
                     delta = -1,
                     largeTileIds = listOf(DESKTOP_TASK_ID, FOCUSED_TASK_ID),
                 )
@@ -584,7 +564,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = 1,
-                    DIRECTION_RIGHT,
+                    RIGHT,
                     delta = -1,
                     largeTileIds = listOf(DESKTOP_TASK_ID, FOCUSED_TASK_ID),
                 )
@@ -603,7 +583,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = DESKTOP_TASK_ID,
-                    DIRECTION_TAB,
+                    TAB,
                     delta = 1,
                     largeTileIds = listOf(DESKTOP_TASK_ID, FOCUSED_TASK_ID),
                 )
@@ -621,7 +601,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = FOCUSED_TASK_ID,
-                    DIRECTION_LEFT,
+                    LEFT,
                     delta = 1,
                     topIds = IntArray(),
                     bottomIds = IntArray.wrap(2),
@@ -643,7 +623,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = DESKTOP_TASK_ID,
-                    DIRECTION_TAB,
+                    TAB,
                     delta = -1,
                     largeTileIds = listOf(DESKTOP_TASK_ID, FOCUSED_TASK_ID),
                 )
@@ -662,7 +642,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = 1,
-                    DIRECTION_RIGHT,
+                    RIGHT,
                     delta = -1,
                     hasAddDesktopButton = true,
                 )
@@ -681,7 +661,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = 2,
-                    DIRECTION_RIGHT,
+                    RIGHT,
                     delta = -1,
                     hasAddDesktopButton = true,
                 )
@@ -701,7 +681,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = ADD_DESK_PLACEHOLDER_ID,
-                    DIRECTION_RIGHT,
+                    RIGHT,
                     delta = -1,
                     hasAddDesktopButton = true,
                 )
@@ -722,7 +702,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = CLEAR_ALL_PLACEHOLDER_ID,
-                    DIRECTION_LEFT,
+                    LEFT,
                     delta = 1,
                     hasAddDesktopButton = true,
                 )
@@ -741,7 +721,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = ADD_DESK_PLACEHOLDER_ID,
-                    DIRECTION_UP,
+                    UP,
                     delta = 1,
                     hasAddDesktopButton = true,
                 )
@@ -760,7 +740,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = ADD_DESK_PLACEHOLDER_ID,
-                    DIRECTION_DOWN,
+                    DOWN,
                     delta = 1,
                     hasAddDesktopButton = true,
                 )
@@ -778,7 +758,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = ADD_DESK_PLACEHOLDER_ID,
-                    DIRECTION_LEFT,
+                    LEFT,
                     delta = 1,
                     largeTileIds = listOf(DESKTOP_TASK_ID),
                     hasAddDesktopButton = true,
@@ -797,7 +777,7 @@
         assertThat(
                 getNextGridPage(
                     currentPageTaskViewId = DESKTOP_TASK_ID,
-                    DIRECTION_RIGHT,
+                    RIGHT,
                     delta = -1,
                     largeTileIds = listOf(DESKTOP_TASK_ID),
                     hasAddDesktopButton = true,
@@ -806,9 +786,43 @@
             .isEqualTo(ADD_DESK_PLACEHOLDER_ID)
     }
 
+    // Col offset:  0   1   2
+    //             -----------
+    // ID grid:     4   2   0  start
+    //         end [5]  3   1
+    @Test
+    fun gridTaskViewIdOffsetPairInTabOrderSequence_towardsStart() {
+        val expected = listOf(Pair(4, 0), Pair(3, 1), Pair(2, 1), Pair(1, 2), Pair(0, 2))
+        assertThat(
+                gridTaskViewIdOffsetPairInTabOrderSequence(
+                        initialTaskViewId = 5,
+                        towardsStart = true,
+                    )
+                    .toList()
+            )
+            .isEqualTo(expected)
+    }
+
+    // Col offset:  2   1   0
+    //             -----------
+    // ID grid:     4   2  [0] start
+    //          end 5   3   1
+    @Test
+    fun gridTaskViewIdOffsetPairInTabOrderSequence_towardsEnd() {
+        val expected = listOf(Pair(1, 0), Pair(2, 1), Pair(3, 1), Pair(4, 2), Pair(5, 2))
+        assertThat(
+                gridTaskViewIdOffsetPairInTabOrderSequence(
+                        initialTaskViewId = 0,
+                        towardsStart = false,
+                    )
+                    .toList()
+            )
+            .isEqualTo(expected)
+    }
+
     private fun getNextGridPage(
         currentPageTaskViewId: Int,
-        direction: Int,
+        direction: TaskGridNavHelper.TaskNavDirection,
         delta: Int,
         topIds: IntArray = IntArray.wrap(1, 3, 5),
         bottomIds: IntArray = IntArray.wrap(2, 4, 6),
@@ -821,6 +835,22 @@
         return taskGridNavHelper.getNextGridPage(currentPageTaskViewId, delta, direction, cycle)
     }
 
+    private fun gridTaskViewIdOffsetPairInTabOrderSequence(
+        initialTaskViewId: Int,
+        towardsStart: Boolean,
+        topIds: IntArray = IntArray.wrap(0, 2, 4),
+        bottomIds: IntArray = IntArray.wrap(1, 3, 5),
+        largeTileIds: List<Int> = emptyList(),
+        hasAddDesktopButton: Boolean = false,
+    ): Sequence<Pair<Int, Int>> {
+        val taskGridNavHelper =
+            TaskGridNavHelper(topIds, bottomIds, largeTileIds, hasAddDesktopButton)
+        return taskGridNavHelper.gridTaskViewIdOffsetPairInTabOrderSequence(
+            initialTaskViewId,
+            towardsStart,
+        )
+    }
+
     private companion object {
         const val FOCUSED_TASK_ID = 99
         const val DESKTOP_TASK_ID = 100
diff --git a/src/com/android/launcher3/InvariantDeviceProfile.java b/src/com/android/launcher3/InvariantDeviceProfile.java
index 43876a6..4107c8f 100644
--- a/src/com/android/launcher3/InvariantDeviceProfile.java
+++ b/src/com/android/launcher3/InvariantDeviceProfile.java
@@ -300,10 +300,10 @@
         lifeCycle.addCloseable(() -> prefs.removeListener(prefListener,
                 FIXED_LANDSCAPE_MODE, ENABLE_TWOLINE_ALLAPPS_TOGGLE));
 
-        SimpleBroadcastReceiver localeReceiver = new SimpleBroadcastReceiver(
+        SimpleBroadcastReceiver localeReceiver = new SimpleBroadcastReceiver(context,
                 MAIN_EXECUTOR, i -> onConfigChanged(context));
-        localeReceiver.register(context, Intent.ACTION_LOCALE_CHANGED);
-        lifeCycle.addCloseable(() -> localeReceiver.unregisterReceiverSafely(context));
+        localeReceiver.register(Intent.ACTION_LOCALE_CHANGED);
+        lifeCycle.addCloseable(() -> localeReceiver.unregisterReceiverSafely());
     }
 
     private String initGrid(Context context, String gridName) {
diff --git a/src/com/android/launcher3/LauncherAppState.java b/src/com/android/launcher3/LauncherAppState.java
index e560a14..71013c3 100644
--- a/src/com/android/launcher3/LauncherAppState.java
+++ b/src/com/android/launcher3/LauncherAppState.java
@@ -119,14 +119,13 @@
         }
 
         SimpleBroadcastReceiver modelChangeReceiver =
-                new SimpleBroadcastReceiver(UI_HELPER_EXECUTOR, mModel::onBroadcastIntent);
+                new SimpleBroadcastReceiver(context, UI_HELPER_EXECUTOR, mModel::onBroadcastIntent);
         modelChangeReceiver.register(
-                mContext,
                 ACTION_DEVICE_POLICY_RESOURCE_UPDATED);
         if (BuildConfig.IS_STUDIO_BUILD) {
-            modelChangeReceiver.register(mContext, RECEIVER_EXPORTED, ACTION_FORCE_ROLOAD);
+            modelChangeReceiver.register(RECEIVER_EXPORTED, ACTION_FORCE_ROLOAD);
         }
-        mOnTerminateCallback.add(() -> modelChangeReceiver.unregisterReceiverSafely(mContext));
+        mOnTerminateCallback.add(() -> modelChangeReceiver.unregisterReceiverSafely());
 
         SafeCloseable userChangeListener = UserCache.INSTANCE.get(mContext)
                 .addUserEventListener(mModel::onUserEvent);
diff --git a/src/com/android/launcher3/graphics/ThemeManager.kt b/src/com/android/launcher3/graphics/ThemeManager.kt
index f977999..e767290 100644
--- a/src/com/android/launcher3/graphics/ThemeManager.kt
+++ b/src/com/android/launcher3/graphics/ThemeManager.kt
@@ -62,8 +62,8 @@
     private val listeners = CopyOnWriteArrayList<ThemeChangeListener>()
 
     init {
-        val receiver = SimpleBroadcastReceiver(MAIN_EXECUTOR) { verifyIconState() }
-        receiver.registerPkgActions(context, "android", ACTION_OVERLAY_CHANGED)
+        val receiver = SimpleBroadcastReceiver(context, MAIN_EXECUTOR) { verifyIconState() }
+        receiver.registerPkgActions("android", ACTION_OVERLAY_CHANGED)
 
         val prefListener = LauncherPrefChangeListener { key ->
             when (key) {
@@ -74,7 +74,7 @@
         prefs.addListener(prefListener, THEMED_ICONS, PREF_ICON_SHAPE)
 
         lifecycle.addCloseable {
-            receiver.unregisterReceiverSafely(context)
+            receiver.unregisterReceiverSafely()
             prefs.removeListener(prefListener)
         }
     }
diff --git a/src/com/android/launcher3/pm/UserCache.java b/src/com/android/launcher3/pm/UserCache.java
index 0b18a87..20c0ecc 100644
--- a/src/com/android/launcher3/pm/UserCache.java
+++ b/src/com/android/launcher3/pm/UserCache.java
@@ -81,10 +81,7 @@
     }
 
     private final List<BiConsumer<UserHandle, String>> mUserEventListeners = new ArrayList<>();
-    private final SimpleBroadcastReceiver mUserChangeReceiver =
-            new SimpleBroadcastReceiver(MODEL_EXECUTOR, this::onUsersChanged);
-
-    private final Context mContext;
+    private final SimpleBroadcastReceiver mUserChangeReceiver;
     private final ApiWrapper mApiWrapper;
 
     @NonNull
@@ -99,16 +96,17 @@
             DaggerSingletonTracker tracker,
             ApiWrapper apiWrapper
     ) {
-        mContext = context;
         mApiWrapper = apiWrapper;
+        mUserChangeReceiver = new SimpleBroadcastReceiver(context,
+                MODEL_EXECUTOR, this::onUsersChanged);
         mUserToSerialMap = Collections.emptyMap();
         MODEL_EXECUTOR.execute(this::initAsync);
-        tracker.addCloseable(() -> mUserChangeReceiver.unregisterReceiverSafely(mContext));
+        tracker.addCloseable(() -> mUserChangeReceiver.unregisterReceiverSafely());
     }
 
     @WorkerThread
     private void initAsync() {
-        mUserChangeReceiver.register(mContext,
+        mUserChangeReceiver.register(
                 Intent.ACTION_MANAGED_PROFILE_AVAILABLE,
                 Intent.ACTION_MANAGED_PROFILE_UNAVAILABLE,
                 Intent.ACTION_MANAGED_PROFILE_REMOVED,
diff --git a/src/com/android/launcher3/util/DisplayController.java b/src/com/android/launcher3/util/DisplayController.java
index 638bd60..376a61e 100644
--- a/src/com/android/launcher3/util/DisplayController.java
+++ b/src/com/android/launcher3/util/DisplayController.java
@@ -107,7 +107,6 @@
     private static final String ACTION_OVERLAY_CHANGED = "android.intent.action.OVERLAY_CHANGED";
     private static final String TARGET_OVERLAY_PACKAGE = "android";
 
-    private final Context mContext;
     private final WindowManagerProxy mWMProxy;
 
     // Null for SDK < S
@@ -120,8 +119,7 @@
 
     // We will register broadcast receiver on main thread to ensure not missing changes on
     // TARGET_OVERLAY_PACKAGE and ACTION_OVERLAY_CHANGED.
-    private final SimpleBroadcastReceiver mReceiver =
-            new SimpleBroadcastReceiver(MAIN_EXECUTOR, this::onIntent);
+    private final SimpleBroadcastReceiver mReceiver;
 
     private Info mInfo;
     private boolean mDestroyed = false;
@@ -131,7 +129,6 @@
             WindowManagerProxy wmProxy,
             LauncherPrefs prefs,
             DaggerSingletonTracker lifecycle) {
-        mContext = context;
         mWMProxy = wmProxy;
 
         if (enableTaskbarPinning()) {
@@ -155,11 +152,12 @@
 
         Display display = context.getSystemService(DisplayManager.class)
                 .getDisplay(DEFAULT_DISPLAY);
-        mWindowContext = mContext.createWindowContext(display, TYPE_APPLICATION, null);
+        mWindowContext = context.createWindowContext(display, TYPE_APPLICATION, null);
         mWindowContext.registerComponentCallbacks(this);
 
         // Initialize navigation mode change listener
-        mReceiver.registerPkgActions(mContext, TARGET_OVERLAY_PACKAGE, ACTION_OVERLAY_CHANGED);
+        mReceiver = new SimpleBroadcastReceiver(context, MAIN_EXECUTOR, this::onIntent);
+        mReceiver.registerPkgActions(TARGET_OVERLAY_PACKAGE, ACTION_OVERLAY_CHANGED);
 
         mInfo = new Info(mWindowContext, wmProxy,
                 wmProxy.estimateInternalDisplayBounds(mWindowContext));
@@ -169,7 +167,7 @@
         lifecycle.addCloseable(() -> {
             mDestroyed = true;
             mWindowContext.unregisterComponentCallbacks(this);
-            mReceiver.unregisterReceiverSafely(mContext);
+            mReceiver.unregisterReceiverSafely();
             wmProxy.unregisterDesktopVisibilityListener(this);
         });
     }
diff --git a/src/com/android/launcher3/util/LockedUserState.kt b/src/com/android/launcher3/util/LockedUserState.kt
index a6a6ceb..742a327 100644
--- a/src/com/android/launcher3/util/LockedUserState.kt
+++ b/src/com/android/launcher3/util/LockedUserState.kt
@@ -44,7 +44,7 @@
 
     @VisibleForTesting
     val userUnlockedReceiver =
-        SimpleBroadcastReceiver(UI_HELPER_EXECUTOR) {
+        SimpleBroadcastReceiver(context, UI_HELPER_EXECUTOR) {
             if (Intent.ACTION_USER_UNLOCKED == it.action) {
                 isUserUnlocked = true
             }
@@ -61,7 +61,6 @@
         isUserUnlockedAtLauncherStartup = isUserUnlocked
         if (!isUserUnlocked) {
             userUnlockedReceiver.register(
-                context,
                 {
                     // If user is unlocked while registering broadcast receiver, we should update
                     // [isUserUnlocked], which will call [notifyUserUnlocked] in setter
@@ -72,7 +71,7 @@
                 Intent.ACTION_USER_UNLOCKED,
             )
         }
-        lifeCycle.addCloseable { userUnlockedReceiver.unregisterReceiverSafely(context) }
+        lifeCycle.addCloseable { userUnlockedReceiver.unregisterReceiverSafely() }
     }
 
     private fun checkIsUserUnlocked() =
@@ -80,7 +79,7 @@
 
     private fun notifyUserUnlocked() {
         mUserUnlockedActions.executeAllAndDestroy()
-        userUnlockedReceiver.unregisterReceiverSafely(context)
+        userUnlockedReceiver.unregisterReceiverSafely()
     }
 
     /**
diff --git a/src/com/android/launcher3/util/ScreenOnTracker.java b/src/com/android/launcher3/util/ScreenOnTracker.java
index 50be98b..8ffe9ea 100644
--- a/src/com/android/launcher3/util/ScreenOnTracker.java
+++ b/src/com/android/launcher3/util/ScreenOnTracker.java
@@ -46,34 +46,31 @@
     private final SimpleBroadcastReceiver mReceiver;
     private final CopyOnWriteArrayList<ScreenOnListener> mListeners = new CopyOnWriteArrayList<>();
 
-    private final Context mContext;
     private boolean mIsScreenOn;
 
     @Inject
     ScreenOnTracker(@ApplicationContext Context context, DaggerSingletonTracker tracker) {
         // Assume that the screen is on to begin with
-        mContext = context;
-        mReceiver = new SimpleBroadcastReceiver(UI_HELPER_EXECUTOR, this::onReceive);
+        mReceiver = new SimpleBroadcastReceiver(context, UI_HELPER_EXECUTOR, this::onReceive);
         init(tracker);
     }
 
     @VisibleForTesting
     ScreenOnTracker(@ApplicationContext Context context, SimpleBroadcastReceiver receiver,
             DaggerSingletonTracker tracker) {
-        mContext = context;
         mReceiver = receiver;
         init(tracker);
     }
 
     private void init(DaggerSingletonTracker tracker) {
         mIsScreenOn = true;
-        mReceiver.register(mContext, ACTION_SCREEN_ON, ACTION_SCREEN_OFF, ACTION_USER_PRESENT);
+        mReceiver.register(ACTION_SCREEN_ON, ACTION_SCREEN_OFF, ACTION_USER_PRESENT);
         tracker.addCloseable(this);
     }
 
     @Override
     public void close() {
-        mReceiver.unregisterReceiverSafely(mContext);
+        mReceiver.unregisterReceiverSafely();
     }
 
     @VisibleForTesting
diff --git a/src/com/android/launcher3/util/SimpleBroadcastReceiver.java b/src/com/android/launcher3/util/SimpleBroadcastReceiver.java
index 539a7cb..7a40abe 100644
--- a/src/com/android/launcher3/util/SimpleBroadcastReceiver.java
+++ b/src/com/android/launcher3/util/SimpleBroadcastReceiver.java
@@ -25,22 +25,29 @@
 import android.text.TextUtils;
 
 import androidx.annotation.AnyThread;
+import androidx.annotation.NonNull;
 import androidx.annotation.Nullable;
 
 import java.util.function.Consumer;
 
 public class SimpleBroadcastReceiver extends BroadcastReceiver {
+    public static final String TAG = "SimpleBroadcastReceiver";
+    // Keeps a strong reference to the context.
+    private final Context mContext;
 
     private final Consumer<Intent> mIntentConsumer;
 
     // Handler to register/unregister broadcast receiver
     private final Handler mHandler;
 
-    public SimpleBroadcastReceiver(LooperExecutor looperExecutor, Consumer<Intent> intentConsumer) {
-        this(looperExecutor.getHandler(), intentConsumer);
+    public SimpleBroadcastReceiver(@NonNull Context context, LooperExecutor looperExecutor,
+            Consumer<Intent> intentConsumer) {
+        this(context, looperExecutor.getHandler(), intentConsumer);
     }
 
-    public SimpleBroadcastReceiver(Handler handler, Consumer<Intent> intentConsumer) {
+    public SimpleBroadcastReceiver(@NonNull Context context, Handler handler,
+            Consumer<Intent> intentConsumer) {
+        mContext = context;
         mIntentConsumer = intentConsumer;
         mHandler = handler;
     }
@@ -50,18 +57,18 @@
         mIntentConsumer.accept(intent);
     }
 
-    /** Calls {@link #register(Context, Runnable, String...)} with null completionCallback. */
+    /** Calls {@link #register(Runnable, String...)} with null completionCallback. */
     @AnyThread
-    public void register(Context context, String... actions) {
-        register(context, null, actions);
+    public void register(String... actions) {
+        register(null, actions);
     }
 
     /**
-     * Calls {@link #register(Context, Runnable, int, String...)} with null completionCallback.
+     * Calls {@link #register(Runnable, int, String...)} with null completionCallback.
      */
     @AnyThread
-    public void register(Context context, int flags, String... actions) {
-        register(context, null, flags, actions);
+    public void register(int flags, String... actions) {
+        register(null, flags, actions);
     }
 
     /**
@@ -74,19 +81,18 @@
      *                           while registerReceiver() is executed on a binder call.
      */
     @AnyThread
-    public void register(
-            Context context, @Nullable Runnable completionCallback, String... actions) {
+    public void register(@Nullable Runnable completionCallback, String... actions) {
         if (Looper.myLooper() == mHandler.getLooper()) {
-            registerInternal(context, completionCallback, actions);
+            registerInternal(mContext, completionCallback, actions);
         } else {
-            mHandler.post(() -> registerInternal(context, completionCallback, actions));
+            mHandler.post(() -> registerInternal(mContext, completionCallback, actions));
         }
     }
 
     /** Register broadcast receiver and run completion callback if passed. */
     @AnyThread
     private void registerInternal(
-            Context context, @Nullable Runnable completionCallback, String... actions) {
+            @NonNull Context context, @Nullable Runnable completionCallback, String... actions) {
         context.registerReceiver(this, getFilter(actions));
         if (completionCallback != null) {
             completionCallback.run();
@@ -94,37 +100,37 @@
     }
 
     /**
-     * Same as {@link #register(Context, Runnable, String...)} above but with additional flags
-     * params.
+     * Same as {@link #register(Runnable, String...)} above but with additional flags
+     * params utilizing the original {@link Context}.
      */
     @AnyThread
-    public void register(
-            Context context, @Nullable Runnable completionCallback, int flags, String... actions) {
+    public void register(@Nullable Runnable completionCallback, int flags, String... actions) {
         if (Looper.myLooper() == mHandler.getLooper()) {
-            registerInternal(context, completionCallback, flags, actions);
+            registerInternal(mContext, completionCallback, flags, actions);
         } else {
-            mHandler.post(() -> registerInternal(context, completionCallback, flags, actions));
+            mHandler.post(() -> registerInternal(mContext, completionCallback, flags, actions));
         }
     }
 
     /** Register broadcast receiver and run completion callback if passed. */
     @AnyThread
     private void registerInternal(
-            Context context, @Nullable Runnable completionCallback, int flags, String... actions) {
+            @NonNull Context context, @Nullable Runnable completionCallback, int flags,
+            String... actions) {
         context.registerReceiver(this, getFilter(actions), flags);
         if (completionCallback != null) {
             completionCallback.run();
         }
     }
 
-    /** Same as {@link #register(Context, Runnable, String...)} above but with pkg name. */
+    /** Same as {@link #register(Runnable, String...)} above but with pkg name. */
     @AnyThread
-    public void registerPkgActions(Context context, @Nullable String pkg, String... actions) {
+    public void registerPkgActions(@Nullable String pkg, String... actions) {
         if (Looper.myLooper() == mHandler.getLooper()) {
-            context.registerReceiver(this, getPackageFilter(pkg, actions));
+            mContext.registerReceiver(this, getPackageFilter(pkg, actions));
         } else {
             mHandler.post(() -> {
-                context.registerReceiver(this, getPackageFilter(pkg, actions));
+                mContext.registerReceiver(this, getPackageFilter(pkg, actions));
             });
         }
     }
@@ -135,19 +141,19 @@
      * unregister happens on {@link #mHandler}'s looper.
      */
     @AnyThread
-    public void unregisterReceiverSafely(Context context) {
+    public void unregisterReceiverSafely() {
         if (Looper.myLooper() == mHandler.getLooper()) {
-            unregisterReceiverSafelyInternal(context);
+            unregisterReceiverSafelyInternal(mContext);
         } else {
             mHandler.post(() -> {
-                unregisterReceiverSafelyInternal(context);
+                unregisterReceiverSafelyInternal(mContext);
             });
         }
     }
 
     /** Unregister broadcast receiver ignoring any errors. */
     @AnyThread
-    private void unregisterReceiverSafelyInternal(Context context) {
+    private void unregisterReceiverSafelyInternal(@NonNull Context context) {
         try {
             context.unregisterReceiver(this);
         } catch (IllegalArgumentException e) {
diff --git a/src/com/android/launcher3/util/WallpaperOffsetInterpolator.java b/src/com/android/launcher3/util/WallpaperOffsetInterpolator.java
index f8cbe0d..26a04a5 100644
--- a/src/com/android/launcher3/util/WallpaperOffsetInterpolator.java
+++ b/src/com/android/launcher3/util/WallpaperOffsetInterpolator.java
@@ -31,8 +31,7 @@
     // Don't use all the wallpaper for parallax until you have at least this many pages
     private static final int MIN_PARALLAX_PAGE_SPAN = 4;
 
-    private final SimpleBroadcastReceiver mWallpaperChangeReceiver =
-            new SimpleBroadcastReceiver(UI_HELPER_EXECUTOR, i -> onWallpaperChanged());
+    private final SimpleBroadcastReceiver mWallpaperChangeReceiver;
     private final Workspace<?> mWorkspace;
     private final boolean mIsRtl;
     private final Handler mHandler;
@@ -46,6 +45,8 @@
 
     public WallpaperOffsetInterpolator(Workspace<?> workspace) {
         mWorkspace = workspace;
+        mWallpaperChangeReceiver = new SimpleBroadcastReceiver(
+                workspace.getContext(), UI_HELPER_EXECUTOR, i -> onWallpaperChanged());
         mIsRtl = Utilities.isRtl(workspace.getResources());
         mHandler = new OffsetHandler(workspace.getContext());
     }
@@ -198,11 +199,10 @@
     public void setWindowToken(IBinder token) {
         mWindowToken = token;
         if (mWindowToken == null && mRegistered) {
-            mWallpaperChangeReceiver.unregisterReceiverSafely(mWorkspace.getContext());
+            mWallpaperChangeReceiver.unregisterReceiverSafely();
             mRegistered = false;
         } else if (mWindowToken != null && !mRegistered) {
-            mWallpaperChangeReceiver.register(
-                    mWorkspace.getContext(), ACTION_WALLPAPER_CHANGED);
+            mWallpaperChangeReceiver.register(ACTION_WALLPAPER_CHANGED);
             onWallpaperChanged();
             mRegistered = true;
         }
diff --git a/tests/multivalentTests/src/com/android/launcher3/util/ScreenOnTrackerTest.kt b/tests/multivalentTests/src/com/android/launcher3/util/ScreenOnTrackerTest.kt
index 45cc19c..9c3f223 100644
--- a/tests/multivalentTests/src/com/android/launcher3/util/ScreenOnTrackerTest.kt
+++ b/tests/multivalentTests/src/com/android/launcher3/util/ScreenOnTrackerTest.kt
@@ -51,7 +51,7 @@
 
     @Test
     fun test_default_state() {
-        verify(receiver).register(context, ACTION_SCREEN_ON, ACTION_SCREEN_OFF, ACTION_USER_PRESENT)
+        verify(receiver).register(ACTION_SCREEN_ON, ACTION_SCREEN_OFF, ACTION_USER_PRESENT)
         assertThat(underTest.isScreenOn).isTrue()
     }
 
@@ -59,7 +59,7 @@
     fun close_unregister_receiver() {
         underTest.close()
 
-        verify(receiver).unregisterReceiverSafely(context)
+        verify(receiver).unregisterReceiverSafely()
     }
 
     @Test
diff --git a/tests/multivalentTests/src/com/android/launcher3/util/SimpleBroadcastReceiverTest.kt b/tests/multivalentTests/src/com/android/launcher3/util/SimpleBroadcastReceiverTest.kt
index d3e27b6..17933f2 100644
--- a/tests/multivalentTests/src/com/android/launcher3/util/SimpleBroadcastReceiverTest.kt
+++ b/tests/multivalentTests/src/com/android/launcher3/util/SimpleBroadcastReceiverTest.kt
@@ -52,7 +52,7 @@
     @Before
     fun setUp() {
         MockitoAnnotations.initMocks(this)
-        underTest = SimpleBroadcastReceiver(UI_HELPER_EXECUTOR, intentConsumer)
+        underTest = SimpleBroadcastReceiver(context, UI_HELPER_EXECUTOR, intentConsumer)
         if (Looper.getMainLooper() == null) {
             Looper.prepareMainLooper()
         }
@@ -60,7 +60,7 @@
 
     @Test
     fun async_register() {
-        underTest.register(context, "test_action_1", "test_action_2")
+        underTest.register("test_action_1", "test_action_2")
         awaitTasksCompleted()
 
         verify(context).registerReceiver(same(underTest), intentFilterCaptor.capture())
@@ -72,7 +72,7 @@
 
     @Test
     fun async_register_withCompletionRunnable() {
-        underTest.register(context, completionRunnable, "test_action_1", "test_action_2")
+        underTest.register(completionRunnable, "test_action_1", "test_action_2")
         awaitTasksCompleted()
 
         verify(context).registerReceiver(same(underTest), intentFilterCaptor.capture())
@@ -85,7 +85,7 @@
 
     @Test
     fun async_register_withCompletionRunnable_and_flag() {
-        underTest.register(context, completionRunnable, 1, "test_action_1", "test_action_2")
+        underTest.register(completionRunnable, 1, "test_action_1", "test_action_2")
         awaitTasksCompleted()
 
         verify(context).registerReceiver(same(underTest), intentFilterCaptor.capture(), eq(1))
@@ -98,7 +98,7 @@
 
     @Test
     fun async_register_with_package() {
-        underTest.registerPkgActions(context, "pkg", "test_action_1", "test_action_2")
+        underTest.registerPkgActions("pkg", "test_action_1", "test_action_2")
 
         awaitTasksCompleted()
         verify(context).registerReceiver(same(underTest), intentFilterCaptor.capture())
@@ -112,9 +112,10 @@
 
     @Test
     fun sync_register_withCompletionRunnable_and_flag() {
-        underTest = SimpleBroadcastReceiver(Handler(Looper.getMainLooper()), intentConsumer)
+        underTest =
+            SimpleBroadcastReceiver(context, Handler(Looper.getMainLooper()), intentConsumer)
 
-        underTest.register(context, completionRunnable, 1, "test_action_1", "test_action_2")
+        underTest.register(completionRunnable, 1, "test_action_1", "test_action_2")
         getInstrumentation().waitForIdleSync()
 
         verify(context).registerReceiver(same(underTest), intentFilterCaptor.capture(), eq(1))
@@ -127,7 +128,7 @@
 
     @Test
     fun async_unregister() {
-        underTest.unregisterReceiverSafely(context)
+        underTest.unregisterReceiverSafely()
 
         awaitTasksCompleted()
         verify(context).unregisterReceiver(same(underTest))
@@ -135,9 +136,10 @@
 
     @Test
     fun sync_unregister() {
-        underTest = SimpleBroadcastReceiver(Handler(Looper.getMainLooper()), intentConsumer)
+        underTest =
+            SimpleBroadcastReceiver(context, Handler(Looper.getMainLooper()), intentConsumer)
 
-        underTest.unregisterReceiverSafely(context)
+        underTest.unregisterReceiverSafely()
         getInstrumentation().waitForIdleSync()
 
         verify(context).unregisterReceiver(same(underTest))