Moving InvariantDeviceProfile to Dagger

Bug: 361850561
Test: Presubmit
Flag: EXEMPT dagger
Change-Id: Ia99ee031e96c5e6069a56fe90a262a533dd8665a
diff --git a/src/com/android/launcher3/DeviceProfile.java b/src/com/android/launcher3/DeviceProfile.java
index 813d8f1..3b283c3 100644
--- a/src/com/android/launcher3/DeviceProfile.java
+++ b/src/com/android/launcher3/DeviceProfile.java
@@ -387,7 +387,8 @@
     }
 
     /** TODO: Once we fully migrate to staged split, remove "isMultiWindowMode" */
-    DeviceProfile(Context context, InvariantDeviceProfile inv, Info info, WindowBounds windowBounds,
+    DeviceProfile(Context context, InvariantDeviceProfile inv, Info info,
+            WindowManagerProxy wmProxy, IconShape iconShape, WindowBounds windowBounds,
             SparseArray<DotRenderer> dotRendererCache, boolean isMultiWindowMode,
             boolean transposeLayoutWithOrientation, boolean isMultiDisplay, boolean isGestureMode,
             @NonNull final ViewScaleProvider viewScaleProvider,
@@ -420,7 +421,7 @@
         isPhone = !isTablet;
         isTwoPanels = isTablet && isMultiDisplay;
         isTaskbarPresent = (isTablet || (enableTinyTaskbar() && isGestureMode))
-                && WindowManagerProxy.INSTANCE.get(context).isTaskbarDrawnInProcess();
+                && wmProxy.isTaskbarDrawnInProcess();
 
         // Some more constants.
         context = getContext(context, info, inv.isFixedLandscape
@@ -845,8 +846,8 @@
         dimensionOverrideProvider.accept(this);
 
         // This is done last, after iconSizePx is calculated above.
-        mDotRendererWorkSpace = createDotRenderer(context, iconSizePx, dotRendererCache);
-        mDotRendererAllApps = createDotRenderer(context, allAppsIconSizePx, dotRendererCache);
+        mDotRendererWorkSpace = createDotRenderer(iconShape, iconSizePx, dotRendererCache);
+        mDotRendererAllApps = createDotRenderer(iconShape, allAppsIconSizePx, dotRendererCache);
     }
 
     /**
@@ -868,12 +869,12 @@
     }
 
     private static DotRenderer createDotRenderer(
-            @NonNull Context context, int size, @NonNull SparseArray<DotRenderer> cache) {
+            @NonNull IconShape iconShape, int size, @NonNull SparseArray<DotRenderer> cache) {
         DotRenderer renderer = cache.get(size);
         if (renderer == null) {
             renderer = new DotRenderer(
                     size,
-                    IconShape.INSTANCE.get(context).getShape().getPath(DEFAULT_DOT_SIZE),
+                    iconShape.getShape().getPath(DEFAULT_DOT_SIZE),
                     DEFAULT_DOT_SIZE);
             cache.put(size, renderer);
         }
@@ -1090,7 +1091,7 @@
         dotRendererCache.put(iconSizePx, mDotRendererWorkSpace);
         dotRendererCache.put(allAppsIconSizePx, mDotRendererAllApps);
 
-        return new Builder(context, inv, mInfo)
+        return inv.newDPBuilder(context, mInfo)
                 .setWindowBounds(bounds)
                 .setIsMultiDisplay(isMultiDisplay)
                 .setMultiWindowMode(isMultiWindowMode)
@@ -2473,9 +2474,11 @@
     }
 
     public static class Builder {
-        private Context mContext;
-        private InvariantDeviceProfile mInv;
-        private Info mInfo;
+        private final Context mContext;
+        private final InvariantDeviceProfile mInv;
+        private final Info mInfo;
+        private final WindowManagerProxy mWMProxy;
+        private final IconShape mIconShape;
 
         private WindowBounds mWindowBounds;
         private boolean mIsMultiDisplay;
@@ -2491,10 +2494,13 @@
 
         private boolean mIsTransientTaskbar;
 
-        public Builder(Context context, InvariantDeviceProfile inv, Info info) {
+        public Builder(Context context, InvariantDeviceProfile inv, Info info,
+                WindowManagerProxy wmProxy, IconShape iconShape) {
             mContext = context;
             mInv = inv;
             mInfo = info;
+            mWMProxy = wmProxy;
+            mIconShape = iconShape;
             mIsTransientTaskbar = info.isTransientTaskbar();
         }
 
@@ -2575,7 +2581,8 @@
             if (mOverrideProvider == null) {
                 mOverrideProvider = DEFAULT_DIMENSION_PROVIDER;
             }
-            return new DeviceProfile(mContext, mInv, mInfo, mWindowBounds, mDotRendererCache,
+            return new DeviceProfile(mContext, mInv, mInfo, mWMProxy, mIconShape,
+                    mWindowBounds, mDotRendererCache,
                     mIsMultiWindowMode, mTransposeLayoutWithOrientation, mIsMultiDisplay,
                     mIsGestureMode, mViewScaleProvider, mOverrideProvider, mIsTransientTaskbar);
         }
diff --git a/src/com/android/launcher3/InvariantDeviceProfile.java b/src/com/android/launcher3/InvariantDeviceProfile.java
index 56c2b8e..43876a6 100644
--- a/src/com/android/launcher3/InvariantDeviceProfile.java
+++ b/src/com/android/launcher3/InvariantDeviceProfile.java
@@ -30,7 +30,6 @@
 import static com.android.launcher3.util.DisplayController.CHANGE_TASKBAR_PINNING;
 import static com.android.launcher3.util.Executors.MAIN_EXECUTOR;
 
-import android.annotation.TargetApi;
 import android.content.Context;
 import android.content.Intent;
 import android.content.res.Resources;
@@ -46,7 +45,6 @@
 import android.util.Log;
 import android.util.SparseArray;
 import android.util.Xml;
-import android.view.Display;
 
 import androidx.annotation.DimenRes;
 import androidx.annotation.IntDef;
@@ -56,18 +54,21 @@
 import androidx.core.content.res.ResourcesCompat;
 
 import com.android.launcher3.config.FeatureFlags;
+import com.android.launcher3.dagger.ApplicationContext;
+import com.android.launcher3.dagger.LauncherAppComponent;
+import com.android.launcher3.dagger.LauncherAppSingleton;
+import com.android.launcher3.graphics.IconShape;
 import com.android.launcher3.icons.DotRenderer;
 import com.android.launcher3.logging.FileLog;
 import com.android.launcher3.model.DeviceGridState;
 import com.android.launcher3.provider.RestoreDbTask;
 import com.android.launcher3.testing.shared.ResourceUtils;
+import com.android.launcher3.util.DaggerSingletonObject;
+import com.android.launcher3.util.DaggerSingletonTracker;
 import com.android.launcher3.util.DisplayController;
 import com.android.launcher3.util.DisplayController.Info;
-import com.android.launcher3.util.MainThreadInitializedObject;
 import com.android.launcher3.util.Partner;
 import com.android.launcher3.util.ResourceHelper;
-import com.android.launcher3.util.RunnableList;
-import com.android.launcher3.util.SafeCloseable;
 import com.android.launcher3.util.SimpleBroadcastReceiver;
 import com.android.launcher3.util.WindowBounds;
 import com.android.launcher3.util.window.CachedDisplayInfo;
@@ -86,12 +87,15 @@
 import java.util.Objects;
 import java.util.stream.Collectors;
 
-public class InvariantDeviceProfile implements SafeCloseable {
+import javax.inject.Inject;
+
+@LauncherAppSingleton
+public class InvariantDeviceProfile {
 
     public static final String TAG = "IDP";
     // We do not need any synchronization for this variable as its only written on UI thread.
-    public static final MainThreadInitializedObject<InvariantDeviceProfile> INSTANCE =
-            new MainThreadInitializedObject<>(InvariantDeviceProfile::new);
+    public static final DaggerSingletonObject<InvariantDeviceProfile> INSTANCE =
+            new DaggerSingletonObject<>(LauncherAppComponent::getIDP);
 
     public static final String GRID_NAME_PREFS_KEY = "idp_grid_name";
     public static final String NON_FIXED_LANDSCAPE_GRID_NAME_PREFS_KEY =
@@ -129,7 +133,10 @@
     private static final String RES_GRID_NUM_COLUMNS = "grid_num_columns";
     private static final String RES_GRID_ICON_SIZE_DP = "grid_icon_size_dp";
 
-    private final RunnableList mCloseActions = new RunnableList();
+    private final DisplayController mDisplayController;
+    private final WindowManagerProxy mWMProxy;
+    private final LauncherPrefs mPrefs;
+    private final IconShape mIconShape;
 
     /**
      * Number of icons per row and column in the workspace.
@@ -247,16 +254,22 @@
 
     private final ArrayList<OnIDPChangeListener> mChangeListeners = new ArrayList<>();
 
-    @VisibleForTesting
-    public InvariantDeviceProfile() {
-    }
+    @Inject
+    InvariantDeviceProfile(
+            @ApplicationContext Context context,
+            LauncherPrefs prefs,
+            DisplayController dc,
+            WindowManagerProxy wmProxy,
+            IconShape iconShape,
+            DaggerSingletonTracker lifeCycle) {
+        mDisplayController = dc;
+        mWMProxy = wmProxy;
+        mPrefs = prefs;
+        mIconShape = iconShape;
 
-    @TargetApi(23)
-    private InvariantDeviceProfile(Context context) {
-        String gridName = getCurrentGridName(context);
+        String gridName = prefs.get(GRID_NAME);
         initGrid(context, gridName);
 
-        DisplayController dc = DisplayController.INSTANCE.get(context);
         dc.setPriorityListener(
                 (displayContext, info, flags) -> {
                     if ((flags & (CHANGE_DENSITY | CHANGE_SUPPORTED_BOUNDS
@@ -265,126 +278,46 @@
                         onConfigChanged(displayContext);
                     }
                 });
-        mCloseActions.add(() -> dc.setPriorityListener(null));
+        lifeCycle.addCloseable(() -> dc.setPriorityListener(null));
 
         LauncherPrefChangeListener prefListener = key -> {
             if (FIXED_LANDSCAPE_MODE.getSharedPrefKey().equals(key)
-                    && isFixedLandscape != FIXED_LANDSCAPE_MODE.get(context)) {
+                    && isFixedLandscape != prefs.get(FIXED_LANDSCAPE_MODE)) {
                 Trace.beginSection("InvariantDeviceProfile#setFixedLandscape");
                 if (isFixedLandscape) {
-                    setCurrentGrid(
-                            context, LauncherPrefs.get(context).get(NON_FIXED_LANDSCAPE_GRID_NAME));
+                    setCurrentGrid(context, prefs.get(NON_FIXED_LANDSCAPE_GRID_NAME));
                 } else {
-                    LauncherPrefs.get(context)
-                            .put(NON_FIXED_LANDSCAPE_GRID_NAME, getCurrentGridName(context));
+                    prefs.put(NON_FIXED_LANDSCAPE_GRID_NAME, mPrefs.get(GRID_NAME));
                     onConfigChanged(context);
                 }
                 Trace.endSection();
             } else if (ENABLE_TWOLINE_ALLAPPS_TOGGLE.getSharedPrefKey().equals(key)
-                    && enableTwoLinesInAllApps != ENABLE_TWOLINE_ALLAPPS_TOGGLE.get(context)) {
+                    && enableTwoLinesInAllApps != prefs.get(ENABLE_TWOLINE_ALLAPPS_TOGGLE)) {
                 onConfigChanged(context);
             }
         };
-        LauncherPrefs prefs = LauncherPrefs.INSTANCE.get(context);
         prefs.addListener(prefListener, FIXED_LANDSCAPE_MODE, ENABLE_TWOLINE_ALLAPPS_TOGGLE);
-        mCloseActions.add(() -> prefs.removeListener(prefListener,
+        lifeCycle.addCloseable(() -> prefs.removeListener(prefListener,
                 FIXED_LANDSCAPE_MODE, ENABLE_TWOLINE_ALLAPPS_TOGGLE));
 
         SimpleBroadcastReceiver localeReceiver = new SimpleBroadcastReceiver(
                 MAIN_EXECUTOR, i -> onConfigChanged(context));
         localeReceiver.register(context, Intent.ACTION_LOCALE_CHANGED);
-        mCloseActions.add(() -> localeReceiver.unregisterReceiverSafely(context));
-    }
-
-    /**
-     * This constructor should NOT have any monitors by design.
-     */
-    public InvariantDeviceProfile(Context context, String gridName) {
-        String newName = initGrid(context, gridName);
-        if (newName == null || !newName.equals(gridName)) {
-            throw new IllegalArgumentException("Unknown grid name: " + gridName);
-        }
-    }
-
-    /**
-     * This constructor should NOT have any monitors by design.
-     */
-    public InvariantDeviceProfile(Context context, Display display) {
-        // Ensure that the main device profile is initialized
-        INSTANCE.get(context);
-        String gridName = getCurrentGridName(context);
-
-        // Get the display info based on default display and interpolate it to existing display
-        Info defaultInfo = DisplayController.INSTANCE.get(context).getInfo();
-        @DeviceType int defaultDeviceType = defaultInfo.getDeviceType();
-        DisplayOption defaultDisplayOption = invDistWeightedInterpolate(
-                defaultInfo,
-                getPredefinedDeviceProfiles(
-                        context,
-                        gridName,
-                        defaultInfo,
-                        /*allowDisabledGrid=*/false,
-                        FIXED_LANDSCAPE_MODE.get(context)
-                ),
-                defaultDeviceType);
-
-        Context displayContext = context.createDisplayContext(display);
-        Info myInfo = new Info(displayContext);
-        @DeviceType int deviceType = myInfo.getDeviceType();
-        DisplayOption myDisplayOption = invDistWeightedInterpolate(
-                myInfo,
-                getPredefinedDeviceProfiles(
-                        context,
-                        gridName,
-                        myInfo,
-                        /*allowDisabledGrid=*/false,
-                        FIXED_LANDSCAPE_MODE.get(context)
-                ),
-                deviceType);
-
-        DisplayOption result = new DisplayOption(defaultDisplayOption.grid)
-                .add(myDisplayOption);
-        result.iconSizes[INDEX_DEFAULT] =
-                defaultDisplayOption.iconSizes[INDEX_DEFAULT];
-        for (int i = 1; i < COUNT_SIZES; i++) {
-            result.iconSizes[i] = Math.min(
-                    defaultDisplayOption.iconSizes[i], myDisplayOption.iconSizes[i]);
-        }
-
-        System.arraycopy(defaultDisplayOption.minCellSize, 0, result.minCellSize, 0,
-                COUNT_SIZES);
-        System.arraycopy(defaultDisplayOption.borderSpaces, 0, result.borderSpaces, 0,
-                COUNT_SIZES);
-
-        initGrid(context, myInfo, result);
-    }
-
-    @Override
-    public void close() {
-        mCloseActions.executeAllAndDestroy();
-    }
-
-    public static String getCurrentGridName(Context context) {
-        return LauncherPrefs.get(context).get(GRID_NAME);
+        lifeCycle.addCloseable(() -> localeReceiver.unregisterReceiverSafely(context));
     }
 
     private String initGrid(Context context, String gridName) {
-        FileLog.d(TAG, "Before initGrid:"
-                + "gridName:" + gridName
-                + ", dbFile:" + dbFile
-                + ", LauncherPrefs GRID_NAME:" + LauncherPrefs.get(context).get(GRID_NAME)
-                + ", LauncherPrefs DB_FILE:" + LauncherPrefs.get(context).get(DB_FILE));
-        Info displayInfo = DisplayController.INSTANCE.get(context).getInfo();
+        Info displayInfo = mDisplayController.getInfo();
         List<DisplayOption> allOptions = getPredefinedDeviceProfiles(
                 context,
                 gridName,
                 displayInfo,
-                RestoreDbTask.isPending(context),
-                FIXED_LANDSCAPE_MODE.get(context)
+                RestoreDbTask.isPending(mPrefs),
+                mPrefs.get(FIXED_LANDSCAPE_MODE)
         );
 
         // Filter out options that don't have the same number of columns as the grid
-        DeviceGridState deviceGridState = new DeviceGridState(context);
+        DeviceGridState deviceGridState = new DeviceGridState(mPrefs);
         List<DisplayOption> allOptionsFilteredByColCount =
                 filterByColumnCount(allOptions, deviceGridState.getColumns());
 
@@ -395,15 +328,15 @@
                         displayInfo.getDeviceType());
 
         if (!displayOption.grid.name.equals(gridName)) {
-            LauncherPrefs.get(context).put(GRID_NAME, displayOption.grid.name);
+            mPrefs.put(GRID_NAME, displayOption.grid.name);
         }
 
         initGrid(context, displayInfo, displayOption);
         FileLog.d(TAG, "After initGrid:"
                 + "gridName:" + gridName
                 + ", dbFile:" + dbFile
-                + ", LauncherPrefs GRID_NAME:" + LauncherPrefs.get(context).get(GRID_NAME)
-                + ", LauncherPrefs DB_FILE:" + LauncherPrefs.get(context).get(DB_FILE));
+                + ", LauncherPrefs GRID_NAME:" + mPrefs.get(GRID_NAME)
+                + ", LauncherPrefs DB_FILE:" + mPrefs.get(DB_FILE));
         return displayOption.grid.name;
     }
 
@@ -420,18 +353,13 @@
      */
     @Deprecated
     public void reset(Context context) {
-        initGrid(context, getCurrentGridName(context));
-    }
-
-    @VisibleForTesting
-    public static String getDefaultGridName(Context context) {
-        return new InvariantDeviceProfile().initGrid(context, null);
+        initGrid(context, mPrefs.get(GRID_NAME));
     }
 
     private void initGrid(Context context, Info displayInfo, DisplayOption displayOption) {
         enableTwoLinesInAllApps = Flags.enableTwolineToggle()
                 && Utilities.isEnglishLanguage(context)
-                && ENABLE_TWOLINE_ALLAPPS_TOGGLE.get(context);
+                && mPrefs.get(ENABLE_TWOLINE_ALLAPPS_TOGGLE);
         mLocale = context.getResources().getConfiguration().locale.toString();
 
         DisplayMetrics metrics = context.getResources().getDisplayMetrics();
@@ -522,7 +450,7 @@
         defaultWallpaperSize = new Point(displayInfo.currentSize);
         SparseArray<DotRenderer> dotRendererCache = new SparseArray<>();
         for (WindowBounds bounds : displayInfo.supportedBounds) {
-            localSupportedProfiles.add(new DeviceProfile.Builder(context, this, displayInfo)
+            localSupportedProfiles.add(newDPBuilder(context, displayInfo)
                     .setIsMultiDisplay(deviceType == TYPE_MULTI_DISPLAY)
                     .setWindowBounds(bounds)
                     .setDotRendererCache(dotRendererCache)
@@ -561,6 +489,10 @@
                 });
     }
 
+    DeviceProfile.Builder newDPBuilder(Context context, Info info) {
+        return new DeviceProfile.Builder(context, this, info, mWMProxy, mIconShape);
+    }
+
     public void addOnChangeListener(OnIDPChangeListener listener) {
         mChangeListeners.add(listener);
     }
@@ -574,7 +506,7 @@
      * migration.
      */
     public void setCurrentGrid(Context context, String newGridName) {
-        LauncherPrefs.get(context).put(GRID_NAME, newGridName);
+        mPrefs.put(GRID_NAME, newGridName);
         MAIN_EXECUTOR.execute(() -> {
             Trace.beginSection("InvariantDeviceProfile#setCurrentGrid");
             onConfigChanged(context.getApplicationContext());
@@ -594,8 +526,7 @@
         Object[] oldState = toModelState();
 
         // Re-init grid
-        String gridName = getCurrentGridName(context);
-        initGrid(context, gridName);
+        initGrid(context, mPrefs.get(GRID_NAME));
 
         boolean modelPropsChanged = !Arrays.equals(oldState, toModelState());
         for (OnIDPChangeListener listener : mChangeListeners) {
@@ -912,11 +843,20 @@
         return out;
     }
 
-    public DeviceProfile getDeviceProfile(Context context) {
-        WindowManagerProxy windowManagerProxy = WindowManagerProxy.INSTANCE.get(context);
-        Rect bounds = windowManagerProxy.getCurrentBounds(context);
-        int rotation = windowManagerProxy.getRotation(context);
+    public DeviceProfile createDeviceProfileForSecondaryDisplay(Context displayContext) {
+        // Disable transpose layout and use multi-window mode so that the icons are scaled properly
+        return newDPBuilder(displayContext, new Info(displayContext))
+                .setIsMultiDisplay(false)
+                .setMultiWindowMode(true)
+                .setWindowBounds(mWMProxy.getRealBounds(
+                        displayContext, mWMProxy.getDisplayInfo(displayContext)))
+                .setTransposeLayoutWithOrientation(false)
+                .build();
+    }
 
+    public DeviceProfile getDeviceProfile(Context context) {
+        Rect bounds = mWMProxy.getCurrentBounds(context);
+        int rotation = mWMProxy.getRotation(context);
         return getBestMatch(bounds.width(), bounds.height(), rotation);
     }
 
diff --git a/src/com/android/launcher3/dagger/LauncherBaseAppComponent.java b/src/com/android/launcher3/dagger/LauncherBaseAppComponent.java
index 7bd7c3e..87b5459 100644
--- a/src/com/android/launcher3/dagger/LauncherBaseAppComponent.java
+++ b/src/com/android/launcher3/dagger/LauncherBaseAppComponent.java
@@ -18,6 +18,7 @@
 
 import android.content.Context;
 
+import com.android.launcher3.InvariantDeviceProfile;
 import com.android.launcher3.LauncherPrefs;
 import com.android.launcher3.graphics.IconShape;
 import com.android.launcher3.graphics.ThemeManager;
@@ -72,6 +73,7 @@
     DisplayController getDisplayController();
     WallpaperColorHints getWallpaperColorHints();
     LockedUserState getLockedUserState();
+    InvariantDeviceProfile getIDP();
 
     /** Builder for LauncherBaseAppComponent. */
     interface Builder {
diff --git a/src/com/android/launcher3/dagger/LauncherComponentProvider.kt b/src/com/android/launcher3/dagger/LauncherComponentProvider.kt
index 5015e54..71e3354 100644
--- a/src/com/android/launcher3/dagger/LauncherComponentProvider.kt
+++ b/src/com/android/launcher3/dagger/LauncherComponentProvider.kt
@@ -47,6 +47,10 @@
             .component
     }
 
+    /** Extension method easily access LauncherAppComponent */
+    val Context.appComponent: LauncherAppComponent
+        get() = get(this)
+
     private data class Holder(
         val component: LauncherAppComponent,
         private val filter: LayoutInflater.Filter?,
diff --git a/src/com/android/launcher3/graphics/LauncherPreviewRenderer.java b/src/com/android/launcher3/graphics/LauncherPreviewRenderer.java
index 911064c..03f0582 100644
--- a/src/com/android/launcher3/graphics/LauncherPreviewRenderer.java
+++ b/src/com/android/launcher3/graphics/LauncherPreviewRenderer.java
@@ -24,6 +24,7 @@
 import static com.android.launcher3.BubbleTextView.DISPLAY_WORKSPACE;
 import static com.android.launcher3.DeviceProfile.DEFAULT_SCALE;
 import static com.android.launcher3.Hotseat.ALPHA_CHANNEL_PREVIEW_RENDERER;
+import static com.android.launcher3.LauncherPrefs.GRID_NAME;
 import static com.android.launcher3.LauncherSettings.Favorites.CONTAINER_HOTSEAT;
 import static com.android.launcher3.LauncherSettings.Favorites.CONTAINER_HOTSEAT_PREDICTION;
 import static com.android.launcher3.Utilities.SHOULD_SHOW_FIRST_PAGE_WIDGET;
@@ -137,14 +138,14 @@
 
         private final String mPrefName;
 
-        public PreviewContext(Context base, InvariantDeviceProfile idp) {
+        public PreviewContext(Context base, String gridName) {
             super(base);
             mPrefName = "preview-" + UUID.randomUUID().toString();
-            initDaggerComponent(DaggerLauncherPreviewRenderer_PreviewAppComponent.builder()
-                    .bindPrefs(new ProxyPrefs(
-                            this, getSharedPreferences(mPrefName, MODE_PRIVATE))));
-
-            putObject(InvariantDeviceProfile.INSTANCE, idp);
+            LauncherPrefs prefs =
+                    new ProxyPrefs(this, getSharedPreferences(mPrefName, MODE_PRIVATE));
+            prefs.put(GRID_NAME, gridName);
+            initDaggerComponent(
+                    DaggerLauncherPreviewRenderer_PreviewAppComponent.builder().bindPrefs(prefs));
             putObject(LauncherAppState.INSTANCE,
                     new LauncherAppState(this, null /* iconCacheFileName */));
         }
@@ -192,8 +193,8 @@
                 this::getAppWidgetScale).build();
         if (context instanceof PreviewContext) {
             Context tempContext = ((PreviewContext) context).getBaseContext();
-            mDpOrig = new InvariantDeviceProfile(tempContext, InvariantDeviceProfile
-                    .getCurrentGridName(tempContext)).getDeviceProfile(tempContext)
+            mDpOrig = InvariantDeviceProfile.INSTANCE.get(tempContext)
+                    .getDeviceProfile(tempContext)
                     .copy(tempContext);
         } else {
             mDpOrig = mDp;
diff --git a/src/com/android/launcher3/graphics/PreviewSurfaceRenderer.java b/src/com/android/launcher3/graphics/PreviewSurfaceRenderer.java
index 7a60814..686024d 100644
--- a/src/com/android/launcher3/graphics/PreviewSurfaceRenderer.java
+++ b/src/com/android/launcher3/graphics/PreviewSurfaceRenderer.java
@@ -20,6 +20,7 @@
 import static android.content.res.Configuration.UI_MODE_NIGHT_YES;
 import static android.view.Display.DEFAULT_DISPLAY;
 
+import static com.android.launcher3.LauncherPrefs.GRID_NAME;
 import static com.android.launcher3.LauncherSettings.Favorites.TABLE_NAME;
 import static com.android.launcher3.graphics.ThemeManager.PREF_ICON_SHAPE;
 import static com.android.launcher3.util.Executors.MAIN_EXECUTOR;
@@ -60,7 +61,6 @@
 import com.android.launcher3.model.BaseLauncherBinder;
 import com.android.launcher3.model.BgDataModel;
 import com.android.launcher3.model.BgDataModel.Callbacks;
-import com.android.launcher3.model.GridSizeMigrationDBController;
 import com.android.launcher3.model.LoaderTask;
 import com.android.launcher3.model.ModelDbController;
 import com.android.launcher3.provider.LauncherDbUtils;
@@ -117,7 +117,7 @@
         mGridName = bundle.getString("name");
         bundle.remove("name");
         if (mGridName == null) {
-            mGridName = InvariantDeviceProfile.getCurrentGridName(context);
+            mGridName = LauncherPrefs.get(context).get(GRID_NAME);
         }
         mWallpaperColors = bundle.getParcelable(KEY_COLORS);
         if (Flags.newCustomizationPickerUi()) {
@@ -316,11 +316,10 @@
     @WorkerThread
     private void loadModelData() {
         final Context inflationContext = getPreviewContext();
-        final InvariantDeviceProfile idp = new InvariantDeviceProfile(inflationContext, mGridName);
-        if (GridSizeMigrationDBController.needsToMigrate(inflationContext, idp)
+        if (!mGridName.equals(LauncherPrefs.INSTANCE.get(mContext).get(GRID_NAME))
                 || mShapeKey != null) {
             // Start the migration
-            PreviewContext previewContext = new PreviewContext(inflationContext, idp);
+            PreviewContext previewContext = new PreviewContext(inflationContext, mGridName);
             if (mShapeKey != null) {
                 LauncherPrefs.INSTANCE.get(previewContext).put(PREF_ICON_SHAPE, mShapeKey);
             }
@@ -348,6 +347,7 @@
 
                 @Override
                 public void run() {
+                    InvariantDeviceProfile idp = LauncherAppState.getIDP(previewContext);
                     DeviceProfile deviceProfile = idp.getDeviceProfile(previewContext);
                     String query =
                             LauncherSettings.Favorites.SCREEN + " = " + Workspace.FIRST_SCREEN_ID
@@ -371,7 +371,7 @@
             LauncherAppState.getInstance(inflationContext).getModel().loadAsync(dataModel -> {
                 if (dataModel != null) {
                     MAIN_EXECUTOR.execute(() -> renderView(inflationContext, dataModel, null,
-                            null, idp));
+                            null, LauncherAppState.getIDP(inflationContext)));
                 } else {
                     Log.e(TAG, "Model loading failed");
                 }
diff --git a/src/com/android/launcher3/model/DeviceGridState.java b/src/com/android/launcher3/model/DeviceGridState.java
index 90af215..6f12c97 100644
--- a/src/com/android/launcher3/model/DeviceGridState.java
+++ b/src/com/android/launcher3/model/DeviceGridState.java
@@ -68,7 +68,10 @@
     }
 
     public DeviceGridState(Context context) {
-        LauncherPrefs lp = LauncherPrefs.get(context);
+        this(LauncherPrefs.get(context));
+    }
+
+    public DeviceGridState(LauncherPrefs lp) {
         mGridSizeString = lp.get(WORKSPACE_SIZE);
         mNumHotseat = lp.get(HOTSEAT_COUNT);
         mDeviceType = lp.get(DEVICE_TYPE);
diff --git a/src/com/android/launcher3/provider/RestoreDbTask.java b/src/com/android/launcher3/provider/RestoreDbTask.java
index dc42920..34c9117 100644
--- a/src/com/android/launcher3/provider/RestoreDbTask.java
+++ b/src/com/android/launcher3/provider/RestoreDbTask.java
@@ -415,7 +415,11 @@
     }
 
     public static boolean isPending(Context context) {
-        return LauncherPrefs.get(context).has(RESTORE_DEVICE);
+        return isPending(LauncherPrefs.get(context));
+    }
+
+    public static boolean isPending(LauncherPrefs prefs) {
+        return prefs.has(RESTORE_DEVICE);
     }
 
     /**
diff --git a/src/com/android/launcher3/secondarydisplay/SecondaryDisplayLauncher.java b/src/com/android/launcher3/secondarydisplay/SecondaryDisplayLauncher.java
index df27b54..c20d655 100644
--- a/src/com/android/launcher3/secondarydisplay/SecondaryDisplayLauncher.java
+++ b/src/com/android/launcher3/secondarydisplay/SecondaryDisplayLauncher.java
@@ -115,15 +115,9 @@
         if (mDragLayer != null) {
             return;
         }
-        InvariantDeviceProfile currentDisplayIdp = new InvariantDeviceProfile(
-                this, getWindow().getDecorView().getDisplay());
 
-        // Disable transpose layout and use multi-window mode so that the icons are scaled properly
-        mDeviceProfile = currentDisplayIdp.getDeviceProfile(this)
-                .toBuilder(this)
-                .setMultiWindowMode(true)
-                .setTransposeLayoutWithOrientation(false)
-                .build();
+        mDeviceProfile = InvariantDeviceProfile.INSTANCE.get(this)
+                .createDeviceProfileForSecondaryDisplay(this);
         mDeviceProfile.autoResizeAllAppsCells();
 
         setContentView(R.layout.secondary_launcher);
diff --git a/tests/multivalentTests/src/com/android/launcher3/AbstractDeviceProfileTest.kt b/tests/multivalentTests/src/com/android/launcher3/AbstractDeviceProfileTest.kt
index 9c64ec9..9d42e1b 100644
--- a/tests/multivalentTests/src/com/android/launcher3/AbstractDeviceProfileTest.kt
+++ b/tests/multivalentTests/src/com/android/launcher3/AbstractDeviceProfileTest.kt
@@ -28,11 +28,13 @@
 import android.view.Surface
 import androidx.test.core.app.ApplicationProvider.getApplicationContext
 import androidx.test.platform.app.InstrumentationRegistry
+import com.android.launcher3.LauncherPrefs.Companion.GRID_NAME
 import com.android.launcher3.dagger.LauncherAppComponent
 import com.android.launcher3.dagger.LauncherAppSingleton
 import com.android.launcher3.testing.shared.ResourceUtils
 import com.android.launcher3.util.AllModulesMinusWMProxy
 import com.android.launcher3.util.DisplayController
+import com.android.launcher3.util.FakePrefsModule
 import com.android.launcher3.util.MainThreadInitializedObject.SandboxContext
 import com.android.launcher3.util.NavigationMode
 import com.android.launcher3.util.WindowBounds
@@ -70,7 +72,7 @@
     protected open val runningContext: Context = getApplicationContext()
     private val displayController: DisplayController = mock()
     private val windowManagerProxy: WindowManagerProxy = mock()
-    private val launcherPrefs: LauncherPrefs = mock()
+    private lateinit var launcherPrefs: LauncherPrefs
 
     @get:Rule val setFlagsRule = SetFlagsRule(SetFlagsRule.DefaultInitValueType.DEVICE_DEFAULT)
 
@@ -132,6 +134,7 @@
         isGestureMode: Boolean = true,
         isVerticalBar: Boolean = false,
         isFixedLandscape: Boolean = false,
+        gridName: String? = GRID_NAME.defaultValue,
     ) {
         val (naturalX, naturalY) = deviceSpec.naturalSize
         val windowsBounds = phoneWindowsBounds(deviceSpec, isGestureMode, naturalX, naturalY)
@@ -145,6 +148,7 @@
             isGestureMode,
             densityDpi = deviceSpec.densityDpi,
             isFixedLandscape = isFixedLandscape,
+            gridName = gridName,
         )
     }
 
@@ -152,6 +156,7 @@
         deviceSpec: DeviceSpec,
         isLandscape: Boolean = false,
         isGestureMode: Boolean = true,
+        gridName: String? = GRID_NAME.defaultValue,
     ) {
         val (naturalX, naturalY) = deviceSpec.naturalSize
         val windowsBounds = tabletWindowsBounds(deviceSpec, naturalX, naturalY)
@@ -164,6 +169,7 @@
             rotation = if (isLandscape) Surface.ROTATION_0 else Surface.ROTATION_90,
             isGestureMode,
             densityDpi = deviceSpec.densityDpi,
+            gridName = gridName,
         )
     }
 
@@ -173,6 +179,7 @@
         isLandscape: Boolean = false,
         isGestureMode: Boolean = true,
         isFolded: Boolean = false,
+        gridName: String? = GRID_NAME.defaultValue,
     ) {
         val (unfoldedNaturalX, unfoldedNaturalY) = deviceSpecUnfolded.naturalSize
         val unfoldedWindowsBounds =
@@ -199,6 +206,7 @@
                 rotation = if (isLandscape) Surface.ROTATION_90 else Surface.ROTATION_0,
                 isGestureMode = isGestureMode,
                 densityDpi = deviceSpecFolded.densityDpi,
+                gridName = gridName,
             )
         } else {
             initializeCommonVars(
@@ -207,6 +215,7 @@
                 rotation = if (isLandscape) Surface.ROTATION_0 else Surface.ROTATION_90,
                 isGestureMode = isGestureMode,
                 densityDpi = deviceSpecUnfolded.densityDpi,
+                gridName = gridName,
             )
         }
     }
@@ -282,6 +291,7 @@
         isGestureMode: Boolean = true,
         densityDpi: Int,
         isFixedLandscape: Boolean = false,
+        gridName: String? = GRID_NAME.defaultValue,
     ) {
         setFlagsRule.setFlags(true, Flags.FLAG_ENABLE_TWOLINE_TOGGLE)
         val windowsBounds = perDisplayBoundsCache[displayInfo]!!
@@ -311,18 +321,23 @@
         context.initDaggerComponent(
             DaggerAbsDPTestSandboxComponent.builder()
                 .bindWMProxy(windowManagerProxy)
-                .bindLauncherPrefs(launcherPrefs)
                 .bindDisplayController(displayController)
         )
+        launcherPrefs = context.appComponent.launcherPrefs
+        launcherPrefs.put(
+            LauncherPrefs.TASKBAR_PINNING.to(false),
+            LauncherPrefs.TASKBAR_PINNING_IN_DESKTOP_MODE.to(true),
+            LauncherPrefs.FIXED_LANDSCAPE_MODE.to(isFixedLandscape),
+            LauncherPrefs.HOTSEAT_COUNT.to(-1),
+            LauncherPrefs.DEVICE_TYPE.to(-1),
+            LauncherPrefs.WORKSPACE_SIZE.to(""),
+            LauncherPrefs.DB_FILE.to(""),
+            LauncherPrefs.ENABLE_TWOLINE_ALLAPPS_TOGGLE.to(true),
+        )
+        if (gridName != null) {
+            launcherPrefs.put(GRID_NAME, gridName)
+        }
 
-        whenever(launcherPrefs.get(LauncherPrefs.TASKBAR_PINNING)).thenReturn(false)
-        whenever(launcherPrefs.get(LauncherPrefs.TASKBAR_PINNING_IN_DESKTOP_MODE)).thenReturn(true)
-        whenever(launcherPrefs.get(LauncherPrefs.FIXED_LANDSCAPE_MODE)).thenReturn(isFixedLandscape)
-        whenever(launcherPrefs.get(LauncherPrefs.HOTSEAT_COUNT)).thenReturn(-1)
-        whenever(launcherPrefs.get(LauncherPrefs.DEVICE_TYPE)).thenReturn(-1)
-        whenever(launcherPrefs.get(LauncherPrefs.WORKSPACE_SIZE)).thenReturn("")
-        whenever(launcherPrefs.get(LauncherPrefs.DB_FILE)).thenReturn("")
-        whenever(launcherPrefs.get(LauncherPrefs.ENABLE_TWOLINE_ALLAPPS_TOGGLE)).thenReturn(true)
         val info = spy(DisplayController.Info(context, windowManagerProxy, perDisplayBoundsCache))
         whenever(displayController.info).thenReturn(info)
         whenever(info.isTransientTaskbar).thenReturn(isGestureMode)
@@ -365,15 +380,13 @@
 }
 
 @LauncherAppSingleton
-@Component(modules = [AllModulesMinusWMProxy::class])
+@Component(modules = [AllModulesMinusWMProxy::class, FakePrefsModule::class])
 interface AbsDPTestSandboxComponent : LauncherAppComponent {
 
     @Component.Builder
     interface Builder : LauncherAppComponent.Builder {
         @BindsInstance fun bindWMProxy(proxy: WindowManagerProxy): Builder
 
-        @BindsInstance fun bindLauncherPrefs(prefs: LauncherPrefs): Builder
-
         @BindsInstance fun bindDisplayController(displayController: DisplayController): Builder
 
         override fun build(): AbsDPTestSandboxComponent
diff --git a/tests/multivalentTests/src/com/android/launcher3/FakeInvariantDeviceProfileTest.kt b/tests/multivalentTests/src/com/android/launcher3/FakeInvariantDeviceProfileTest.kt
index bfbdb18..060c28c 100644
--- a/tests/multivalentTests/src/com/android/launcher3/FakeInvariantDeviceProfileTest.kt
+++ b/tests/multivalentTests/src/com/android/launcher3/FakeInvariantDeviceProfileTest.kt
@@ -15,7 +15,6 @@
  */
 package com.android.launcher3
 
-import android.content.Context
 import android.graphics.PointF
 import android.graphics.Rect
 import android.platform.test.rule.AllowedDevices
@@ -23,10 +22,11 @@
 import android.platform.test.rule.IgnoreLimit
 import android.platform.test.rule.LimitDevicesRule
 import android.util.SparseArray
-import androidx.test.core.app.ApplicationProvider
 import com.android.launcher3.DeviceProfile.DEFAULT_DIMENSION_PROVIDER
 import com.android.launcher3.DeviceProfile.DEFAULT_PROVIDER
+import com.android.launcher3.testing.shared.ResourceUtils.INVALID_RESOURCE_HANDLE
 import com.android.launcher3.util.DisplayController.Info
+import com.android.launcher3.util.SandboxApplication
 import com.android.launcher3.util.WindowBounds
 import java.io.PrintWriter
 import java.io.StringWriter
@@ -46,7 +46,8 @@
 @IgnoreLimit(ignoreLimit = BuildConfig.IS_STUDIO_BUILD)
 abstract class FakeInvariantDeviceProfileTest {
 
-    protected lateinit var context: Context
+    @get:Rule val context = SandboxApplication()
+
     protected lateinit var inv: InvariantDeviceProfile
     protected val info = mock<Info>()
     protected lateinit var windowBounds: WindowBounds
@@ -59,7 +60,6 @@
 
     @Before
     open fun setUp() {
-        context = ApplicationProvider.getApplicationContext()
         // make sure to reset values
         useTwoPanels = false
         isGestureMode = true
@@ -70,6 +70,8 @@
             context,
             inv,
             info,
+            context.appComponent.wmProxy,
+            context.appComponent.iconShape,
             windowBounds,
             SparseArray(),
             /*isMultiWindowMode=*/ false,
@@ -107,7 +109,7 @@
         transposeLayoutWithOrientation = true
 
         inv =
-            InvariantDeviceProfile().apply {
+            context.appComponent.idp.apply {
                 numRows = 5
                 numColumns = 4
                 numSearchContainerColumns = 4
@@ -169,6 +171,14 @@
                 inlineQsb = BooleanArray(4) { false }
 
                 devicePaddingId = R.xml.paddings_handhelds
+
+                isFixedLandscape = false
+                workspaceSpecsId = INVALID_RESOURCE_HANDLE
+                allAppsSpecsId = INVALID_RESOURCE_HANDLE
+                folderSpecsId = INVALID_RESOURCE_HANDLE
+                hotseatSpecsId = INVALID_RESOURCE_HANDLE
+                workspaceCellSpecsId = INVALID_RESOURCE_HANDLE
+                allAppsCellSpecsId = INVALID_RESOURCE_HANDLE
             }
     }
 
@@ -189,7 +199,7 @@
         useTwoPanels = false
 
         inv =
-            InvariantDeviceProfile().apply {
+            context.appComponent.idp.apply {
                 numRows = 5
                 numColumns = 6
                 numSearchContainerColumns = 3
@@ -252,6 +262,14 @@
                 inlineQsb = booleanArrayOf(false, true, false, false)
 
                 devicePaddingId = R.xml.paddings_handhelds
+
+                isFixedLandscape = false
+                workspaceSpecsId = INVALID_RESOURCE_HANDLE
+                allAppsSpecsId = INVALID_RESOURCE_HANDLE
+                folderSpecsId = INVALID_RESOURCE_HANDLE
+                hotseatSpecsId = INVALID_RESOURCE_HANDLE
+                workspaceCellSpecsId = INVALID_RESOURCE_HANDLE
+                allAppsCellSpecsId = INVALID_RESOURCE_HANDLE
             }
     }
 
@@ -274,7 +292,7 @@
         useTwoPanels = true
 
         inv =
-            InvariantDeviceProfile().apply {
+            context.appComponent.idp.apply {
                 numRows = rows
                 numColumns = cols
                 numSearchContainerColumns = cols
@@ -332,6 +350,14 @@
                 inlineQsb = booleanArrayOf(false, false, false, false)
 
                 devicePaddingId = R.xml.paddings_handhelds
+
+                isFixedLandscape = false
+                workspaceSpecsId = INVALID_RESOURCE_HANDLE
+                allAppsSpecsId = INVALID_RESOURCE_HANDLE
+                folderSpecsId = INVALID_RESOURCE_HANDLE
+                hotseatSpecsId = INVALID_RESOURCE_HANDLE
+                workspaceCellSpecsId = INVALID_RESOURCE_HANDLE
+                allAppsCellSpecsId = INVALID_RESOURCE_HANDLE
             }
     }
 
diff --git a/tests/multivalentTests/src/com/android/launcher3/FakeLauncherPrefs.kt b/tests/multivalentTests/src/com/android/launcher3/FakeLauncherPrefs.kt
index 5e1e548..4d01d4d 100644
--- a/tests/multivalentTests/src/com/android/launcher3/FakeLauncherPrefs.kt
+++ b/tests/multivalentTests/src/com/android/launcher3/FakeLauncherPrefs.kt
@@ -21,19 +21,24 @@
 import android.content.SharedPreferences
 import com.android.launcher3.dagger.ApplicationContext
 import com.android.launcher3.dagger.LauncherAppSingleton
-import java.io.File
+import com.android.launcher3.util.DaggerSingletonTracker
+import java.util.UUID
 import javax.inject.Inject
 
 /** Emulates Launcher preferences for a test environment. */
 @LauncherAppSingleton
-class FakeLauncherPrefs @Inject constructor(@ApplicationContext context: Context) :
+class FakeLauncherPrefs
+@Inject
+constructor(@ApplicationContext context: Context, lifeCycle: DaggerSingletonTracker) :
     LauncherPrefs(context) {
 
-    private val backingPrefs =
-        context.getSharedPreferences(
-            File.createTempFile("fake-pref", ".xml", context.filesDir),
-            MODE_PRIVATE,
-        )
+    private val prefName = "fake-pref-" + UUID.randomUUID().toString()
+
+    private val backingPrefs = context.getSharedPreferences(prefName, MODE_PRIVATE)
+
+    init {
+        lifeCycle.addCloseable { context.deleteSharedPreferences(prefName) }
+    }
 
     override fun getSharedPrefs(item: Item): SharedPreferences = backingPrefs
 }
diff --git a/tests/multivalentTests/src/com/android/launcher3/FakeLauncherPrefsTest.kt b/tests/multivalentTests/src/com/android/launcher3/FakeLauncherPrefsTest.kt
index c57c86f..0941c79 100644
--- a/tests/multivalentTests/src/com/android/launcher3/FakeLauncherPrefsTest.kt
+++ b/tests/multivalentTests/src/com/android/launcher3/FakeLauncherPrefsTest.kt
@@ -18,10 +18,14 @@
 
 import androidx.test.core.app.ApplicationProvider.getApplicationContext
 import androidx.test.platform.app.InstrumentationRegistry.getInstrumentation
+import com.android.launcher3.util.DaggerSingletonTracker
 import com.android.launcher3.util.LauncherMultivalentJUnit
 import com.google.common.truth.Truth.assertThat
+import org.junit.Before
 import org.junit.Test
 import org.junit.runner.RunWith
+import org.mockito.Mock
+import org.mockito.MockitoAnnotations
 
 private val TEST_CONSTANT_ITEM = LauncherPrefs.nonRestorableItem("TEST_BOOLEAN_ITEM", false)
 
@@ -36,7 +40,15 @@
 
 @RunWith(LauncherMultivalentJUnit::class)
 class FakeLauncherPrefsTest {
-    private val launcherPrefs = FakeLauncherPrefs(getApplicationContext())
+
+    @Mock lateinit var lifeCycle: DaggerSingletonTracker
+    private lateinit var launcherPrefs: FakeLauncherPrefs
+
+    @Before
+    fun setup() {
+        MockitoAnnotations.initMocks(this)
+        launcherPrefs = FakeLauncherPrefs(getApplicationContext(), lifeCycle)
+    }
 
     @Test
     fun testGet_constantItemNotInPrefs_returnsDefaultValue() {
diff --git a/tests/multivalentTests/src/com/android/launcher3/widget/GeneratedPreviewTest.kt b/tests/multivalentTests/src/com/android/launcher3/widget/GeneratedPreviewTest.kt
index b92582c..6af0950 100644
--- a/tests/multivalentTests/src/com/android/launcher3/widget/GeneratedPreviewTest.kt
+++ b/tests/multivalentTests/src/com/android/launcher3/widget/GeneratedPreviewTest.kt
@@ -17,7 +17,6 @@
 import androidx.test.filters.SmallTest
 import androidx.test.platform.app.InstrumentationRegistry.getInstrumentation
 import com.android.launcher3.Flags.FLAG_ENABLE_GENERATED_PREVIEWS
-import com.android.launcher3.InvariantDeviceProfile
 import com.android.launcher3.R
 import com.android.launcher3.icons.IconCache
 import com.android.launcher3.model.WidgetItem
@@ -100,7 +99,7 @@
 
     private fun createWidgetItem() {
         Executors.MODEL_EXECUTOR.submit {
-                val idp = InvariantDeviceProfile()
+                val idp = context.appComponent.idp
                 widgetItem = WidgetItem(appWidgetProviderInfo, idp, iconCache, context)
             }
             .get()
diff --git a/tests/multivalentTests/src/com/android/launcher3/widget/LauncherAppWidgetProviderInfoTest.java b/tests/multivalentTests/src/com/android/launcher3/widget/LauncherAppWidgetProviderInfoTest.java
index 48cf3df..b3fd0f7 100644
--- a/tests/multivalentTests/src/com/android/launcher3/widget/LauncherAppWidgetProviderInfoTest.java
+++ b/tests/multivalentTests/src/com/android/launcher3/widget/LauncherAppWidgetProviderInfoTest.java
@@ -15,15 +15,12 @@
  */
 package com.android.launcher3.widget;
 
-import static androidx.test.core.app.ApplicationProvider.getApplicationContext;
-
 import static com.google.common.truth.Truth.assertThat;
 
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.doAnswer;
 
 import android.appwidget.AppWidgetHostView;
-import android.content.Context;
 import android.graphics.Point;
 import android.graphics.Rect;
 
@@ -32,16 +29,14 @@
 
 import com.android.launcher3.DeviceProfile;
 import com.android.launcher3.InvariantDeviceProfile;
-import com.android.launcher3.LauncherAppState;
+import com.android.launcher3.util.SandboxApplication;
 
-import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mockito;
 
-import java.util.ArrayList;
 import java.util.Collections;
-import java.util.List;
 
 @SmallTest
 @RunWith(AndroidJUnit4.class)
@@ -52,12 +47,7 @@
     private static final int NUM_OF_COLS = 4;
     private static final int NUM_OF_ROWS = 5;
 
-    private Context mContext;
-
-    @Before
-    public void setUp() {
-        mContext = getApplicationContext();
-    }
+    @Rule public SandboxApplication mContext = new SandboxApplication();
 
     @Test
     public void initSpans_minWidthSmallerThanCellWidth_shouldInitializeSpansToOne() {
@@ -256,8 +246,9 @@
     }
 
     private InvariantDeviceProfile createIDP() {
-        DeviceProfile dp = LauncherAppState.getIDP(mContext)
-                .getDeviceProfile(mContext).copy(mContext);
+        InvariantDeviceProfile idp = InvariantDeviceProfile.INSTANCE.get(mContext);
+
+        DeviceProfile dp = idp.getDeviceProfile(mContext).copy(mContext);
         DeviceProfile profile = Mockito.spy(dp);
         doAnswer(i -> {
             ((Point) i.getArgument(0)).set(CELL_SIZE, CELL_SIZE);
@@ -267,10 +258,7 @@
         profile.cellLayoutBorderSpacePx = new Point(SPACE_SIZE, SPACE_SIZE);
         profile.widgetPadding.setEmpty();
 
-        InvariantDeviceProfile idp = new InvariantDeviceProfile();
-        List<DeviceProfile> supportedProfiles = new ArrayList<>(idp.supportedProfiles);
-        supportedProfiles.add(profile);
-        idp.supportedProfiles = Collections.unmodifiableList(supportedProfiles);
+        idp.supportedProfiles = Collections.singletonList(profile);
         idp.numColumns = NUM_OF_COLS;
         idp.numRows = NUM_OF_ROWS;
         return idp;
diff --git a/tests/multivalentTests/src/com/android/launcher3/widget/picker/WidgetRecommendationCategoryProviderTest.java b/tests/multivalentTests/src/com/android/launcher3/widget/picker/WidgetRecommendationCategoryProviderTest.java
index ac67d2b..2fbeaf1 100644
--- a/tests/multivalentTests/src/com/android/launcher3/widget/picker/WidgetRecommendationCategoryProviderTest.java
+++ b/tests/multivalentTests/src/com/android/launcher3/widget/picker/WidgetRecommendationCategoryProviderTest.java
@@ -25,8 +25,6 @@
 import static android.content.pm.ApplicationInfo.CATEGORY_VIDEO;
 import static android.content.pm.ApplicationInfo.FLAG_INSTALLED;
 
-import static androidx.test.platform.app.InstrumentationRegistry.getInstrumentation;
-
 import static com.google.common.truth.Truth.assertThat;
 
 import static org.mockito.ArgumentMatchers.any;
@@ -35,11 +33,9 @@
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.spy;
-import static org.mockito.Mockito.when;
 
 import android.appwidget.AppWidgetProviderInfo;
 import android.content.ComponentName;
-import android.content.Context;
 import android.content.pm.ApplicationInfo;
 import android.content.pm.LauncherApps;
 import android.os.Process;
@@ -53,12 +49,14 @@
 import com.android.launcher3.icons.IconCache;
 import com.android.launcher3.model.WidgetItem;
 import com.android.launcher3.util.Executors;
+import com.android.launcher3.util.SandboxApplication;
 import com.android.launcher3.util.WidgetUtils;
 import com.android.launcher3.widget.LauncherAppWidgetProviderInfo;
 
 import com.google.common.collect.ImmutableMap;
 
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
@@ -92,22 +90,22 @@
 
     private final ApplicationInfo mTestAppInfo = ApplicationInfoBuilder.newBuilder().setPackageName(
             TEST_PACKAGE).setName(TEST_APP_NAME).build();
-    private Context mContext;
+
+    @Rule public SandboxApplication mContext = new SandboxApplication();
     @Mock
     private IconCache mIconCache;
 
     private WidgetItem mTestWidgetItem;
-    @Mock
+
     private LauncherApps mLauncherApps;
     private InvariantDeviceProfile mTestProfile;
 
     @Before
     public void setUp() {
         MockitoAnnotations.initMocks(this);
-        mContext = spy(getInstrumentation().getTargetContext());
-        doReturn(mLauncherApps).when(mContext).getSystemService(LauncherApps.class);
+        mLauncherApps = mContext.spyService(LauncherApps.class);
         mTestAppInfo.flags = FLAG_INSTALLED;
-        mTestProfile = new InvariantDeviceProfile();
+        mTestProfile = InvariantDeviceProfile.INSTANCE.get(mContext);
         mTestProfile.numRows = 5;
         mTestProfile.numColumns = 5;
         createTestWidgetItem();
@@ -128,10 +126,10 @@
                 testCategories.entrySet()) {
 
             mTestAppInfo.category = testCategory.getKey();
-            when(mLauncherApps.getApplicationInfo(/*packageName=*/ eq(TEST_PACKAGE),
+            doReturn(mTestAppInfo).when(mLauncherApps).getApplicationInfo(
+                    /*packageName=*/ eq(TEST_PACKAGE),
                     /*flags=*/ anyInt(),
-                    /*user=*/ eq(Process.myUserHandle())))
-                    .thenReturn(mTestAppInfo);
+                    /*user=*/ eq(Process.myUserHandle()));
 
             WidgetRecommendationCategory category = Executors.MODEL_EXECUTOR.submit(() ->
                     new WidgetRecommendationCategoryProvider().getWidgetRecommendationCategory(
diff --git a/tests/multivalentTests/src/com/android/launcher3/widget/picker/WidgetsListHeaderViewHolderBinderTest.java b/tests/multivalentTests/src/com/android/launcher3/widget/picker/WidgetsListHeaderViewHolderBinderTest.java
index c9b6d4f..767ab63 100644
--- a/tests/multivalentTests/src/com/android/launcher3/widget/picker/WidgetsListHeaderViewHolderBinderTest.java
+++ b/tests/multivalentTests/src/com/android/launcher3/widget/picker/WidgetsListHeaderViewHolderBinderTest.java
@@ -15,8 +15,6 @@
  */
 package com.android.launcher3.widget.picker;
 
-import static androidx.test.core.app.ApplicationProvider.getApplicationContext;
-
 import static com.google.common.truth.Truth.assertThat;
 
 import static org.mockito.ArgumentMatchers.any;
@@ -48,11 +46,13 @@
 import com.android.launcher3.model.data.PackageItemInfo;
 import com.android.launcher3.util.ActivityContextWrapper;
 import com.android.launcher3.util.PackageUserKey;
+import com.android.launcher3.util.SandboxApplication;
 import com.android.launcher3.util.WidgetUtils;
 import com.android.launcher3.widget.LauncherAppWidgetProviderInfo;
 import com.android.launcher3.widget.model.WidgetsListHeaderEntry;
 
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
@@ -67,6 +67,7 @@
     private static final String TEST_PACKAGE = "com.google.test";
     private static final String APP_NAME = "Test app";
 
+    @Rule public SandboxApplication app = new SandboxApplication();
     private Context mContext;
     private WidgetsListHeaderViewHolderBinder mViewHolderBinder;
     private InvariantDeviceProfile mTestProfile;
@@ -80,9 +81,9 @@
     public void setUp() {
         MockitoAnnotations.initMocks(this);
 
-        mContext = new ActivityContextWrapper(new ContextThemeWrapper(getApplicationContext(),
-                R.style.WidgetContainerTheme));
-        mTestProfile = new InvariantDeviceProfile();
+        mContext = new ActivityContextWrapper(new ContextThemeWrapper(
+                app, R.style.WidgetContainerTheme));
+        mTestProfile = InvariantDeviceProfile.INSTANCE.get(app);
         mTestProfile.numRows = 5;
         mTestProfile.numColumns = 5;
 
diff --git a/tests/multivalentTests/src/com/android/launcher3/widget/picker/WidgetsListTableViewHolderBinderTest.java b/tests/multivalentTests/src/com/android/launcher3/widget/picker/WidgetsListTableViewHolderBinderTest.java
index 86bbcc1..e6f13a6 100644
--- a/tests/multivalentTests/src/com/android/launcher3/widget/picker/WidgetsListTableViewHolderBinderTest.java
+++ b/tests/multivalentTests/src/com/android/launcher3/widget/picker/WidgetsListTableViewHolderBinderTest.java
@@ -15,7 +15,6 @@
  */
 package com.android.launcher3.widget.picker;
 
-import static androidx.test.core.app.ApplicationProvider.getApplicationContext;
 import static androidx.test.platform.app.InstrumentationRegistry.getInstrumentation;
 
 import static com.google.common.truth.Truth.assertThat;
@@ -50,6 +49,7 @@
 import com.android.launcher3.model.WidgetItem;
 import com.android.launcher3.model.data.PackageItemInfo;
 import com.android.launcher3.util.ActivityContextWrapper;
+import com.android.launcher3.util.SandboxApplication;
 import com.android.launcher3.util.WidgetUtils;
 import com.android.launcher3.widget.DatabaseWidgetPreviewLoader;
 import com.android.launcher3.widget.LauncherAppWidgetProviderInfo;
@@ -57,6 +57,7 @@
 import com.android.launcher3.widget.model.WidgetsListContentEntry;
 
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
@@ -71,6 +72,7 @@
     private static final String TEST_PACKAGE = "com.google.test";
     private static final String APP_NAME = "Test app";
 
+    @Rule public SandboxApplication app = new SandboxApplication();
     private Context mContext;
     private WidgetsListTableViewHolderBinder mViewHolderBinder;
     private InvariantDeviceProfile mTestProfile;
@@ -85,8 +87,8 @@
     @Before
     public void setUp() {
         MockitoAnnotations.initMocks(this);
-        mContext = new ActivityContextWrapper(getApplicationContext());
-        mTestProfile = new InvariantDeviceProfile();
+        mContext = new ActivityContextWrapper(app);
+        mTestProfile = InvariantDeviceProfile.INSTANCE.get(app);
         mTestProfile.numRows = 5;
         mTestProfile.numColumns = 5;
 
diff --git a/tests/multivalentTests/src/com/android/launcher3/widget/picker/model/WidgetsListContentEntryTest.java b/tests/multivalentTests/src/com/android/launcher3/widget/picker/model/WidgetsListContentEntryTest.java
index 6088c8e..bd34de6 100644
--- a/tests/multivalentTests/src/com/android/launcher3/widget/picker/model/WidgetsListContentEntryTest.java
+++ b/tests/multivalentTests/src/com/android/launcher3/widget/picker/model/WidgetsListContentEntryTest.java
@@ -37,10 +37,12 @@
 import com.android.launcher3.icons.cache.CachedObject;
 import com.android.launcher3.model.WidgetItem;
 import com.android.launcher3.model.data.PackageItemInfo;
+import com.android.launcher3.util.SandboxApplication;
 import com.android.launcher3.widget.LauncherAppWidgetProviderInfo;
 import com.android.launcher3.widget.model.WidgetsListContentEntry;
 
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
@@ -64,6 +66,7 @@
     private final ComponentName mWidget3 = ComponentName.createRelative(PACKAGE_NAME, ".mWidget3");
     private final Map<ComponentName, String> mWidgetsToLabels = new HashMap();
 
+    @Rule public SandboxApplication app = new SandboxApplication();
     @Mock private IconCache mIconCache;
 
     private InvariantDeviceProfile mTestProfile;
@@ -76,7 +79,7 @@
         mWidgetsToLabels.put(mWidget2, "Dog");
         mWidgetsToLabels.put(mWidget3, "Bird");
 
-        mTestProfile = new InvariantDeviceProfile();
+        mTestProfile = InvariantDeviceProfile.INSTANCE.get(app);
         mTestProfile.numRows = 5;
         mTestProfile.numColumns = 5;
 
diff --git a/tests/multivalentTests/src/com/android/launcher3/widget/picker/search/SimpleWidgetsSearchAlgorithmTest.java b/tests/multivalentTests/src/com/android/launcher3/widget/picker/search/SimpleWidgetsSearchAlgorithmTest.java
index 59f352b..0cdda3a 100644
--- a/tests/multivalentTests/src/com/android/launcher3/widget/picker/search/SimpleWidgetsSearchAlgorithmTest.java
+++ b/tests/multivalentTests/src/com/android/launcher3/widget/picker/search/SimpleWidgetsSearchAlgorithmTest.java
@@ -46,6 +46,7 @@
 import com.android.launcher3.model.WidgetItem;
 import com.android.launcher3.model.data.PackageItemInfo;
 import com.android.launcher3.search.SearchCallback;
+import com.android.launcher3.util.SandboxApplication;
 import com.android.launcher3.widget.LauncherAppWidgetProviderInfo;
 import com.android.launcher3.widget.model.WidgetsListBaseEntry;
 import com.android.launcher3.widget.model.WidgetsListContentEntry;
@@ -53,6 +54,7 @@
 import com.android.launcher3.widget.picker.search.WidgetsSearchBar.WidgetsSearchDataProvider;
 
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
@@ -66,6 +68,7 @@
 @RunWith(AndroidJUnit4.class)
 public class SimpleWidgetsSearchAlgorithmTest {
 
+    @Rule public SandboxApplication app = new SandboxApplication();
     @Mock private IconCache mIconCache;
 
     private InvariantDeviceProfile mTestProfile;
@@ -90,7 +93,7 @@
             CachedObject componentWithLabel = invocation.getArgument(0);
             return componentWithLabel.getComponent().getShortClassName();
         }).when(mIconCache).getTitleNoCache(any());
-        mTestProfile = new InvariantDeviceProfile();
+        mTestProfile = InvariantDeviceProfile.INSTANCE.get(app);
         mTestProfile.numRows = 5;
         mTestProfile.numColumns = 5;
         mContext = getApplicationContext();
diff --git a/tests/multivalentTests/src/com/android/launcher3/widget/picker/util/WidgetPreviewContainerSizesTest.kt b/tests/multivalentTests/src/com/android/launcher3/widget/picker/util/WidgetPreviewContainerSizesTest.kt
index 7a858e4..2452a88 100644
--- a/tests/multivalentTests/src/com/android/launcher3/widget/picker/util/WidgetPreviewContainerSizesTest.kt
+++ b/tests/multivalentTests/src/com/android/launcher3/widget/picker/util/WidgetPreviewContainerSizesTest.kt
@@ -25,6 +25,7 @@
 import com.android.launcher3.DeviceProfile
 import com.android.launcher3.InvariantDeviceProfile
 import com.android.launcher3.LauncherAppState
+import com.android.launcher3.dagger.LauncherComponentProvider.appComponent
 import com.android.launcher3.icons.IconCache
 import com.android.launcher3.model.WidgetItem
 import com.android.launcher3.util.ActivityContextWrapper
@@ -53,7 +54,7 @@
         context = ActivityContextWrapper(ApplicationProvider.getApplicationContext())
         testInvariantProfile = LauncherAppState.getIDP(context)
         widgetItemInvariantProfile =
-            InvariantDeviceProfile().apply {
+            context.appComponent.idp.apply {
                 numRows = TEST_GRID_SIZE
                 numColumns = TEST_GRID_SIZE
             }
@@ -143,13 +144,13 @@
             widgetSize: Point,
             context: Context,
             invariantDeviceProfile: InvariantDeviceProfile,
-            iconCache: IconCache
+            iconCache: IconCache,
         ): WidgetItem {
             val providerInfo =
                 createAppWidgetProviderInfo(
                     ComponentName.createRelative(
                         TEST_PACKAGE,
-                        /*cls=*/ ".WidgetProvider_" + widgetSize.x + "x" + widgetSize.y
+                        /*cls=*/ ".WidgetProvider_" + widgetSize.x + "x" + widgetSize.y,
                     )
                 )
             val widgetInfo =
diff --git a/tests/multivalentTests/src/com/android/launcher3/widget/picker/util/WidgetsTableUtilsTest.java b/tests/multivalentTests/src/com/android/launcher3/widget/picker/util/WidgetsTableUtilsTest.java
index 2f5fcfe..a17e472 100644
--- a/tests/multivalentTests/src/com/android/launcher3/widget/picker/util/WidgetsTableUtilsTest.java
+++ b/tests/multivalentTests/src/com/android/launcher3/widget/picker/util/WidgetsTableUtilsTest.java
@@ -15,8 +15,6 @@
  */
 package com.android.launcher3.widget.picker.util;
 
-import static androidx.test.core.app.ApplicationProvider.getApplicationContext;
-
 import static com.android.launcher3.util.WidgetUtils.createAppWidgetProviderInfo;
 
 import static com.google.common.truth.Truth.assertThat;
@@ -44,10 +42,12 @@
 import com.android.launcher3.model.WidgetItem;
 import com.android.launcher3.pm.ShortcutConfigActivityInfo;
 import com.android.launcher3.util.ActivityContextWrapper;
+import com.android.launcher3.util.SandboxApplication;
 import com.android.launcher3.widget.LauncherAppWidgetProviderInfo;
 import com.android.launcher3.widget.util.WidgetsTableUtils;
 
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
@@ -68,6 +68,8 @@
     private static final int NUM_OF_COLS = 5;
     private static final int NUM_OF_ROWS = 5;
 
+    @Rule public SandboxApplication app = new SandboxApplication();
+
     @Mock
     private IconCache mIconCache;
 
@@ -89,9 +91,8 @@
     public void setUp() {
         MockitoAnnotations.initMocks(this);
 
-        mContext = new ActivityContextWrapper(getApplicationContext());
-
-        mTestInvariantProfile = new InvariantDeviceProfile();
+        mContext = new ActivityContextWrapper(app);
+        mTestInvariantProfile = InvariantDeviceProfile.INSTANCE.get(app);
         mTestInvariantProfile.numColumns = NUM_OF_COLS;
         mTestInvariantProfile.numRows = NUM_OF_ROWS;
 
diff --git a/tests/src/com/android/launcher3/model/LoaderTaskTest.kt b/tests/src/com/android/launcher3/model/LoaderTaskTest.kt
index cdb45fc..8f64e84 100644
--- a/tests/src/com/android/launcher3/model/LoaderTaskTest.kt
+++ b/tests/src/com/android/launcher3/model/LoaderTaskTest.kt
@@ -15,7 +15,6 @@
 import androidx.test.filters.SmallTest
 import com.android.dx.mockito.inline.extended.ExtendedMockito
 import com.android.launcher3.Flags
-import com.android.launcher3.InvariantDeviceProfile
 import com.android.launcher3.LauncherAppState
 import com.android.launcher3.LauncherModel
 import com.android.launcher3.LauncherModel.LoaderTransaction
@@ -120,12 +119,11 @@
                 .mockStatic(FirstScreenBroadcastHelper::class.java)
                 .startMocking()
         val idp =
-            InvariantDeviceProfile().apply {
+            context.appComponent.idp.apply {
                 numRows = 5
                 numColumns = 6
                 numDatabaseHotseatIcons = 5
             }
-        context.putObject(InvariantDeviceProfile.INSTANCE, idp)
         context.putObject(LauncherAppState.INSTANCE, app)
 
         doReturn(TestViewHelpers.findWidgetProvider(false))
diff --git a/tests/src/com/android/launcher3/nonquickstep/DeviceProfileDumpTest.kt b/tests/src/com/android/launcher3/nonquickstep/DeviceProfileDumpTest.kt
index 2e2b6cd..05cf926 100644
--- a/tests/src/com/android/launcher3/nonquickstep/DeviceProfileDumpTest.kt
+++ b/tests/src/com/android/launcher3/nonquickstep/DeviceProfileDumpTest.kt
@@ -19,7 +19,6 @@
 import com.android.launcher3.AbstractDeviceProfileTest
 import com.android.launcher3.DeviceProfile
 import com.android.launcher3.Flags
-import com.android.launcher3.InvariantDeviceProfile
 import com.android.launcher3.util.rule.setFlags
 import org.junit.Before
 import org.junit.Test
@@ -46,7 +45,7 @@
     @Test
     fun dumpPortraitGesture() {
         initializeDevice(instance.deviceName, isGestureMode = true, isLandscape = false)
-        val dp = getDeviceProfileForGrid(instance.gridName)
+        val dp = context.appComponent.idp.getDeviceProfile(context)
         dp.isTaskbarPresentInApps = instance.isTaskbarPresentInApps
 
         assertDump(dp, instance.filename("Portrait"))
@@ -55,7 +54,7 @@
     @Test
     fun dumpPortrait3Button() {
         initializeDevice(instance.deviceName, isGestureMode = false, isLandscape = false)
-        val dp = getDeviceProfileForGrid(instance.gridName)
+        val dp = context.appComponent.idp.getDeviceProfile(context)
         dp.isTaskbarPresentInApps = instance.isTaskbarPresentInApps
 
         assertDump(dp, instance.filename("Portrait3Button"))
@@ -64,7 +63,7 @@
     @Test
     fun dumpLandscapeGesture() {
         initializeDevice(instance.deviceName, isGestureMode = true, isLandscape = true)
-        val dp = getDeviceProfileForGrid(instance.gridName)
+        val dp = context.appComponent.idp.getDeviceProfile(context)
         dp.isTaskbarPresentInApps = instance.isTaskbarPresentInApps
 
         val testName =
@@ -79,7 +78,7 @@
     @Test
     fun dumpLandscape3Button() {
         initializeDevice(instance.deviceName, isGestureMode = false, isLandscape = true)
-        val dp = getDeviceProfileForGrid(instance.gridName)
+        val dp = context.appComponent.idp.getDeviceProfile(context)
         dp.isTaskbarPresentInApps = instance.isTaskbarPresentInApps
 
         val testName =
@@ -101,26 +100,25 @@
                     deviceSpecFolded = deviceSpecs["twopanel-phone"]!!,
                     isLandscape = isLandscape,
                     isGestureMode = isGestureMode,
+                    gridName = instance.gridName,
                 )
             "tablet" ->
                 initializeVarsForTablet(
                     deviceSpec = deviceSpec,
                     isLandscape = isLandscape,
                     isGestureMode = isGestureMode,
+                    gridName = instance.gridName,
                 )
             else ->
                 initializeVarsForPhone(
                     deviceSpec = deviceSpec,
                     isVerticalBar = isLandscape,
                     isGestureMode = isGestureMode,
+                    gridName = instance.gridName,
                 )
         }
     }
 
-    private fun getDeviceProfileForGrid(gridName: String): DeviceProfile {
-        return InvariantDeviceProfile(context, gridName).getDeviceProfile(context)
-    }
-
     private fun assertDump(dp: DeviceProfile, filename: String) {
         assertDump(dp, folderName, filename)
     }