Merge "Refactor grid migration" into main
diff --git a/src/com/android/launcher3/graphics/PreviewSurfaceRenderer.java b/src/com/android/launcher3/graphics/PreviewSurfaceRenderer.java
index e3c2d36..1dd7d45 100644
--- a/src/com/android/launcher3/graphics/PreviewSurfaceRenderer.java
+++ b/src/com/android/launcher3/graphics/PreviewSurfaceRenderer.java
@@ -55,7 +55,7 @@
 import com.android.launcher3.model.BaseLauncherBinder;
 import com.android.launcher3.model.BgDataModel;
 import com.android.launcher3.model.BgDataModel.Callbacks;
-import com.android.launcher3.model.GridSizeMigrationUtil;
+import com.android.launcher3.model.GridSizeMigrationDBController;
 import com.android.launcher3.model.LoaderTask;
 import com.android.launcher3.model.ModelDbController;
 import com.android.launcher3.provider.LauncherDbUtils;
@@ -284,7 +284,7 @@
     private void loadModelData() {
         final Context inflationContext = getPreviewContext();
         final InvariantDeviceProfile idp = new InvariantDeviceProfile(inflationContext, mGridName);
-        if (GridSizeMigrationUtil.needsToMigrate(inflationContext, idp)) {
+        if (GridSizeMigrationDBController.needsToMigrate(inflationContext, idp)) {
             // Start the migration
             PreviewContext previewContext = new PreviewContext(inflationContext, idp);
             // Copy existing data to preview DB
diff --git a/src/com/android/launcher3/model/DbEntry.java b/src/com/android/launcher3/model/DbEntry.java
new file mode 100644
index 0000000..c0c51da
--- /dev/null
+++ b/src/com/android/launcher3/model/DbEntry.java
@@ -0,0 +1,177 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.launcher3.model;
+
+import android.content.ContentValues;
+import android.content.Intent;
+import android.util.Log;
+
+import androidx.annotation.NonNull;
+
+import com.android.launcher3.LauncherSettings;
+import com.android.launcher3.model.data.ItemInfo;
+import com.android.launcher3.util.ContentWriter;
+
+import java.net.URISyntaxException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * A launcher DB row as seen by the grid size migration. Entries are equal when their
+ * {@link #getEntryMigrationId()} values match, so the same app on two different grids
+ * compares equal even though the database ids differ.
+ *
+ * <p>{@link #equals} and {@link #hashCode} are derived from mutable fields, so an entry must
+ * not be modified while it is a key in a hash-based collection.
+ */
+public class DbEntry extends ItemInfo implements Comparable<DbEntry> {
+
+    private static final String TAG = "DbEntry";
+
+    /** Raw intent URI; used for the migration id of entries other than folders and widgets. */
+    String mIntent;
+    /** Widget provider; with {@link #appWidgetId} it forms the migration id of a widget. */
+    String mProvider;
+    /**
+     * Contents of a folder or app pair: each child's intent URI mapped to a set whose size is
+     * part of the migration id (presumably the ids of the children with that intent).
+     */
+    Map<String, Set<Integer>> mFolderItems = new HashMap<>();
+
+    /**
+     * Id of the specific widget.
+     */
+    public int appWidgetId = NO_ID;
+
+    /** Comparator according to the reading order: screen, then row, then column. */
+    @Override
+    public int compareTo(DbEntry another) {
+        if (screenId != another.screenId) {
+            return Integer.compare(screenId, another.screenId);
+        }
+        if (cellY != another.cellY) {
+            return Integer.compare(cellY, another.cellY);
+        }
+        return Integer.compare(cellX, another.cellX);
+    }
+
+    /** Two entries are equal when their migration ids are equal. */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (!(o instanceof DbEntry)) return false;
+        DbEntry entry = (DbEntry) o;
+        return Objects.equals(getEntryMigrationId(), entry.getEntryMigrationId());
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(getEntryMigrationId());
+    }
+
+    /**
+     * Puts the updated position and span of this entry into ContentValues which we then use to
+     * insert the entry to the DB.
+     */
+    public void updateContentValues(ContentValues values) {
+        values.put(LauncherSettings.Favorites.SCREEN, screenId);
+        values.put(LauncherSettings.Favorites.CELLX, cellX);
+        values.put(LauncherSettings.Favorites.CELLY, cellY);
+        values.put(LauncherSettings.Favorites.SPANX, spanX);
+        values.put(LauncherSettings.Favorites.SPANY, spanY);
+    }
+
+    @Override
+    public void writeToValues(@NonNull ContentWriter writer) {
+        super.writeToValues(writer);
+        writer.put(LauncherSettings.Favorites.APPWIDGET_ID, appWidgetId);
+    }
+
+    /**
+     * Reads the entry from {@code values}. A missing or non-integer widget id falls back to
+     * {@link #NO_ID} instead of throwing a NullPointerException while unboxing.
+     */
+    @Override
+    public void readFromValues(@NonNull ContentValues values) {
+        super.readFromValues(values);
+        Integer widgetId = values.getAsInteger(LauncherSettings.Favorites.APPWIDGET_ID);
+        appWidgetId = widgetId != null ? widgetId : NO_ID;
+    }
+
+    /**
+     * This id is not stored in the DB; it is only used during the migration to identify an
+     * entry on each workspace. For example, two calculator icons would have the same
+     * migration id even though they have different database ids.
+     */
+    public String getEntryMigrationId() {
+        switch (itemType) {
+            case LauncherSettings.Favorites.ITEM_TYPE_FOLDER:
+            case LauncherSettings.Favorites.ITEM_TYPE_APP_PAIR:
+                return getFolderMigrationId();
+            case LauncherSettings.Favorites.ITEM_TYPE_APPWIDGET:
+                // mProvider is the app the widget belongs to and appWidgetId is the unique id
+                // of the widget. We need both because if you remove a widget and then add it
+                // again, the id can change and the WidgetProvider would not know the widget.
+                return mProvider + appWidgetId;
+            case LauncherSettings.Favorites.ITEM_TYPE_APPLICATION:
+                final String intentStr = cleanIntentString(mIntent);
+                try {
+                    // Apps are identified by their component alone.
+                    Intent i = Intent.parseUri(intentStr, 0);
+                    return Objects.requireNonNull(i.getComponent()).toString();
+                } catch (Exception e) {
+                    // Unparseable intent or no component: use the cleaned URI instead.
+                    return intentStr;
+                }
+            default:
+                return cleanIntentString(mIntent);
+        }
+    }
+
+    /**
+     * This method should return an id that should be the same for two folders containing the
+     * same elements: each child intent is prefixed with the size of its set, and the results
+     * are sorted so the iteration order of the map does not matter.
+     */
+    @NonNull
+    private String getFolderMigrationId() {
+        return mFolderItems.entrySet().stream()
+                .map(item -> item.getValue().size() + cleanIntentString(item.getKey()))
+                .sorted()
+                .collect(Collectors.joining(","));
+    }
+
+    /**
+     * Removes the source bounds from an intent URI. This is needed because sourceBounds can
+     * change and make the id of two equal items different. Returns the input unchanged if it
+     * cannot be parsed.
+     */
+    @NonNull
+    private String cleanIntentString(@NonNull String intentStr) {
+        try {
+            Intent i = Intent.parseUri(intentStr, 0);
+            i.setSourceBounds(null);
+            return i.toURI();
+        } catch (URISyntaxException e) {
+            Log.e(TAG, "Unable to parse Intent string", e);
+            return intentStr;
+        }
+    }
+}
diff --git a/src/com/android/launcher3/model/GridSizeMigrationUtil.java b/src/com/android/launcher3/model/GridSizeMigrationDBController.java
similarity index 83%
rename from src/com/android/launcher3/model/GridSizeMigrationUtil.java
rename to src/com/android/launcher3/model/GridSizeMigrationDBController.java
index 4c017e9..9531d5b 100644
--- a/src/com/android/launcher3/model/GridSizeMigrationUtil.java
+++ b/src/com/android/launcher3/model/GridSizeMigrationDBController.java
@@ -28,7 +28,6 @@
 import android.content.ComponentName;
 import android.content.ContentValues;
 import android.content.Context;
-import android.content.Intent;
 import android.database.Cursor;
 import android.database.DatabaseUtils;
 import android.database.sqlite.SQLiteDatabase;
@@ -45,15 +44,12 @@
 import com.android.launcher3.LauncherSettings;
 import com.android.launcher3.Utilities;
 import com.android.launcher3.config.FeatureFlags;
-import com.android.launcher3.model.data.ItemInfo;
 import com.android.launcher3.provider.LauncherDbUtils.SQLiteTransaction;
-import com.android.launcher3.util.ContentWriter;
 import com.android.launcher3.util.GridOccupancy;
 import com.android.launcher3.util.IntArray;
 import com.android.launcher3.widget.LauncherAppWidgetProviderInfo;
 import com.android.launcher3.widget.WidgetManagerHelper;
 
-import java.net.URISyntaxException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
@@ -61,7 +57,6 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
-import java.util.Objects;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -69,12 +64,12 @@
  * This class takes care of shrinking the workspace (by maximum of one row and one column), as a
  * result of restoring from a larger device or device density change.
  */
-public class GridSizeMigrationUtil {
+public class GridSizeMigrationDBController {
 
-    private static final String TAG = "GridSizeMigrationUtil";
+    private static final String TAG = "GridSizeMigrationDBController";
     private static final boolean DEBUG = true;
 
-    private GridSizeMigrationUtil() {
+    private GridSizeMigrationDBController() {
         // Util class should not be instantiated
     }
 
@@ -85,7 +80,7 @@
         return needsToMigrate(new DeviceGridState(context), new DeviceGridState(idp));
     }
 
-    private static boolean needsToMigrate(
+    static boolean needsToMigrate(
             DeviceGridState srcDeviceState, DeviceGridState destDeviceState) {
         boolean needsToMigrate = !destDeviceState.isCompatible(srcDeviceState);
         if (needsToMigrate) {
@@ -95,6 +90,9 @@
         return needsToMigrate;
     }
 
+    /**
+     * @return all the workspace and hotseat entries in the db.
+     */
     @VisibleForTesting
     public static List<DbEntry> readAllEntries(SQLiteDatabase db, String tableName,
             Context context) {
@@ -198,7 +196,7 @@
                     Collectors.joining(",\n", "[", "]"))
                     + "\n Removing Items:"
                     + dstWorkspaceItems.stream().filter(entry ->
-                            toBeRemoved.contains(entry.id)).map(DbEntry::toString).collect(
+                    toBeRemoved.contains(entry.id)).map(DbEntry::toString).collect(
                     Collectors.joining(",\n", "[", "]"))
                     + "\n Adding Workspace Items:"
                     + workspaceToBeAdded.stream().map(DbEntry::toString).collect(
@@ -291,7 +289,7 @@
         });
     }
 
-    private static void insertEntryInDb(DatabaseHelper helper, DbEntry entry,
+    static void insertEntryInDb(DatabaseHelper helper, DbEntry entry,
             String srcTableName, String destTableName, List<Integer> idsInUse) {
         int id = copyEntryAndUpdate(helper, entry, srcTableName, destTableName, idsInUse);
         if (entry.itemType == LauncherSettings.Favorites.ITEM_TYPE_FOLDER
@@ -341,7 +339,7 @@
         return newId;
     }
 
-    private static void removeEntryFromDb(SQLiteDatabase db, String tableName, IntArray entryIds) {
+    static void removeEntryFromDb(SQLiteDatabase db, String tableName, IntArray entryIds) {
         db.delete(tableName,
                 Utilities.createDbSelectionQuery(LauncherSettings.Favorites._ID, entryIds), null);
     }
@@ -387,7 +385,7 @@
     private static boolean findPlacementForEntry(@NonNull final DbEntry entry,
             @NonNull final Point next, @NonNull final Point trg,
             @NonNull final GridOccupancy occupied, final int screenId) {
-        for (int y = next.y; y <  trg.y; y++) {
+        for (int y = next.y; y < trg.y; y++) {
             for (int x = next.x; x < trg.x; x++) {
                 boolean fits = occupied.isRegionVacant(x, y, entry.spanX, entry.spanY);
                 boolean minFits = occupied.isRegionVacant(x, y, entry.minSpanX,
@@ -413,7 +411,7 @@
     private static void solveHotseatPlacement(
             @NonNull final DatabaseHelper helper, final int hotseatSize,
             @NonNull final DbReader srcReader, @NonNull final DbReader destReader,
-            @NonNull final  List<DbEntry> placedHotseatItems,
+            @NonNull final List<DbEntry> placedHotseatItems,
             @NonNull final List<DbEntry> itemsToPlace, List<Integer> idsInUse) {
 
         final boolean[] occupied = new boolean[hotseatSize];
@@ -436,15 +434,30 @@
         }
     }
 
-    @VisibleForTesting
+    /**
+     * Copies {@code TABLE_NAME} from {@code source} into {@code TABLE_NAME} of {@code target}
+     * and saves {@code destDeviceState} to prefs, so the migration does not run again. Only use
+     * this strategy when the new grid has the same number of columns and more rows.
+     */
+    static void copyCurrentGridToNewGrid(
+            @NonNull Context context,
+            @NonNull DeviceGridState destDeviceState,
+            @NonNull DatabaseHelper target,
+            @NonNull SQLiteDatabase source) {
+        final SQLiteDatabase targetDb = target.getWritableDatabase();
+        copyTable(source, TABLE_NAME, targetDb, TABLE_NAME, context);
+        destDeviceState.writeToPrefs(context);
+    }
+
+    @VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
     public static class DbReader {
 
-        private final SQLiteDatabase mDb;
-        private final String mTableName;
-        private final Context mContext;
-        private int mLastScreenId = -1;
+        final SQLiteDatabase mDb;
+        final String mTableName;
+        final Context mContext;
+        int mLastScreenId = -1;
 
-        private final Map<Integer, ArrayList<DbEntry>> mWorkspaceEntriesByScreenId =
+        final Map<Integer, ArrayList<DbEntry>> mWorkspaceEntriesByScreenId =
                 new ArrayMap<>();
 
         public DbReader(SQLiteDatabase db, String tableName, Context context) {
@@ -529,7 +538,7 @@
                             LauncherSettings.Favorites.INTENT,               // 7
                             LauncherSettings.Favorites.APPWIDGET_PROVIDER,   // 8
                             LauncherSettings.Favorites.APPWIDGET_ID},        // 9
-                        LauncherSettings.Favorites.CONTAINER + " = "
+                    LauncherSettings.Favorites.CONTAINER + " = "
                             + LauncherSettings.Favorites.CONTAINER_DESKTOP);
             final int indexId = c.getColumnIndexOrThrow(LauncherSettings.Favorites._ID);
             final int indexItemType = c.getColumnIndexOrThrow(LauncherSettings.Favorites.ITEM_TYPE);
@@ -648,118 +657,4 @@
             return mDb.query(mTableName, columns, where, null, null, null, null);
         }
     }
-
-    public static class DbEntry extends ItemInfo implements Comparable<DbEntry> {
-
-        private String mIntent;
-        private String mProvider;
-        private Map<String, Set<Integer>> mFolderItems = new HashMap<>();
-
-        /**
-         * Id of the specific widget.
-         */
-        public int appWidgetId = NO_ID;
-
-        /** Comparator according to the reading order */
-        @Override
-        public int compareTo(DbEntry another) {
-            if (screenId != another.screenId) {
-                return Integer.compare(screenId, another.screenId);
-            }
-            if (cellY != another.cellY) {
-                return Integer.compare(cellY, another.cellY);
-            }
-            return Integer.compare(cellX, another.cellX);
-        }
-
-        @Override
-        public boolean equals(Object o) {
-            if (this == o) return true;
-            if (o == null || getClass() != o.getClass()) return false;
-            DbEntry entry = (DbEntry) o;
-            return Objects.equals(getEntryMigrationId(), entry.getEntryMigrationId());
-        }
-
-        @Override
-        public int hashCode() {
-            return Objects.hash(getEntryMigrationId());
-        }
-
-        public void updateContentValues(ContentValues values) {
-            values.put(LauncherSettings.Favorites.SCREEN, screenId);
-            values.put(LauncherSettings.Favorites.CELLX, cellX);
-            values.put(LauncherSettings.Favorites.CELLY, cellY);
-            values.put(LauncherSettings.Favorites.SPANX, spanX);
-            values.put(LauncherSettings.Favorites.SPANY, spanY);
-        }
-
-        @Override
-        public void writeToValues(@NonNull ContentWriter writer) {
-            super.writeToValues(writer);
-            writer.put(LauncherSettings.Favorites.APPWIDGET_ID, appWidgetId);
-        }
-
-        @Override
-        public void readFromValues(@NonNull ContentValues values) {
-            super.readFromValues(values);
-            appWidgetId = values.getAsInteger(LauncherSettings.Favorites.APPWIDGET_ID);
-        }
-
-        /** This id is not used in the DB is only used while doing the migration and it identifies
-         * an entry on each workspace. For example two calculator icons would have the same
-         * migration id even thought they have different database ids.
-         */
-        public String getEntryMigrationId() {
-            switch (itemType) {
-                case LauncherSettings.Favorites.ITEM_TYPE_FOLDER:
-                case LauncherSettings.Favorites.ITEM_TYPE_APP_PAIR:
-                    return getFolderMigrationId();
-                case LauncherSettings.Favorites.ITEM_TYPE_APPWIDGET:
-                    // mProvider is the app the widget belongs to and appWidgetId it's the unique
-                    // is of the widget, we need both because if you remove a widget and then add it
-                    // again, then it can change and the WidgetProvider would not know the widget.
-                    return mProvider + appWidgetId;
-                case LauncherSettings.Favorites.ITEM_TYPE_APPLICATION:
-                    final String intentStr = cleanIntentString(mIntent);
-                    try {
-                        Intent i = Intent.parseUri(intentStr, 0);
-                        return Objects.requireNonNull(i.getComponent()).toString();
-                    } catch (Exception e) {
-                        return intentStr;
-                    }
-                default:
-                    return cleanIntentString(mIntent);
-            }
-        }
-
-        /**
-         * This method should return an id that should be the same for two folders containing the
-         * same elements.
-         */
-        @NonNull
-        private String getFolderMigrationId() {
-            return mFolderItems.keySet().stream()
-                    .map(intentString -> mFolderItems.get(intentString).size()
-                            + cleanIntentString(intentString))
-                    .sorted()
-                    .collect(Collectors.joining(","));
-        }
-
-        /**
-         * This is needed because sourceBounds can change and make the id of two equal items
-         * different.
-         */
-        @NonNull
-        private String cleanIntentString(@NonNull String intentStr) {
-            try {
-                Intent i = Intent.parseUri(intentStr, 0);
-                i.setSourceBounds(null);
-                return i.toURI();
-            } catch (URISyntaxException e) {
-                Log.e(TAG, "Unable to parse Intent string", e);
-                return intentStr;
-            }
-
-        }
-    }
 }
diff --git a/src/com/android/launcher3/model/GridSizeMigrationLogic.java b/src/com/android/launcher3/model/GridSizeMigrationLogic.java
new file mode 100644
index 0000000..12a14b2
--- /dev/null
+++ b/src/com/android/launcher3/model/GridSizeMigrationLogic.java
@@ -0,0 +1,466 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.launcher3.model;
+
+import static com.android.launcher3.Flags.enableSmartspaceRemovalToggle;
+import static com.android.launcher3.LauncherPrefs.IS_FIRST_LOAD_AFTER_RESTORE;
+import static com.android.launcher3.LauncherSettings.Favorites.TABLE_NAME;
+import static com.android.launcher3.LauncherSettings.Favorites.TMP_TABLE;
+import static com.android.launcher3.Utilities.SHOULD_SHOW_FIRST_PAGE_WIDGET;
+import static com.android.launcher3.model.GridSizeMigrationDBController.copyCurrentGridToNewGrid;
+import static com.android.launcher3.model.GridSizeMigrationDBController.insertEntryInDb;
+import static com.android.launcher3.model.GridSizeMigrationDBController.needsToMigrate;
+import static com.android.launcher3.model.GridSizeMigrationDBController.removeEntryFromDb;
+import static com.android.launcher3.model.LoaderTask.SMARTSPACE_ON_HOME_SCREEN;
+import static com.android.launcher3.provider.LauncherDbUtils.copyTable;
+import static com.android.launcher3.provider.LauncherDbUtils.dropTable;
+
+import android.content.Context;
+import android.database.sqlite.SQLiteDatabase;
+import android.graphics.Point;
+import android.util.Log;
+
+import androidx.annotation.NonNull;
+
+import com.android.launcher3.Flags;
+import com.android.launcher3.LauncherPrefs;
+import com.android.launcher3.LauncherSettings;
+import com.android.launcher3.config.FeatureFlags;
+import com.android.launcher3.provider.LauncherDbUtils;
+import com.android.launcher3.util.CellAndSpan;
+import com.android.launcher3.util.GridOccupancy;
+import com.android.launcher3.util.IntArray;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+public class GridSizeMigrationLogic {
+
+    private static final String TAG = "GridSizeMigrationLogic";
+    private static final boolean DEBUG = true;
+
+    /**
+     * Migrates the grid from srcDeviceState to destDeviceState, applying the changes to the
+     * target DB and using the source DB to decide what to add/remove/move/resize. Exceptions
+     * inside the migration transaction are logged, not rethrown; destDeviceState is saved anyway.
+     */
+    public void migrateGrid(
+            @NonNull Context context,
+            @NonNull DeviceGridState srcDeviceState,
+            @NonNull DeviceGridState destDeviceState,
+            @NonNull DatabaseHelper target,
+            @NonNull SQLiteDatabase source) {
+        if (!needsToMigrate(srcDeviceState, destDeviceState)) {
+            return;
+        }
+
+        boolean isFirstLoad = LauncherPrefs.get(context).get(IS_FIRST_LOAD_AFTER_RESTORE);
+        Log.d(TAG, "Begin grid migration. First load: " + isFirstLoad);
+
+        // Special case: if the destination grid has the same number of columns but more rows
+        // (only on the first load after a restore, with the fix flag on), copy the source grid
+        // as-is instead of running the general grid migration.
+        if (shouldMigrateToStrictlyTallerGrid(isFirstLoad, srcDeviceState, destDeviceState)) {
+            copyCurrentGridToNewGrid(context, destDeviceState, target, source);
+            return;
+        }
+        copyTable(source, TABLE_NAME, target.getWritableDatabase(), TMP_TABLE, context);
+        // The source rows are now in TMP_TABLE of the target DB; srcReader reads them there.
+        long migrationStartTime = System.currentTimeMillis();
+        try (LauncherDbUtils.SQLiteTransaction t =
+                     new LauncherDbUtils.SQLiteTransaction(target.getWritableDatabase())) {
+            GridSizeMigrationDBController.DbReader srcReader = new GridSizeMigrationDBController
+                    .DbReader(t.getDb(), TMP_TABLE, context);
+            GridSizeMigrationDBController.DbReader destReader =
+                    new GridSizeMigrationDBController.DbReader(
+                            t.getDb(), TABLE_NAME, context);
+
+            Point targetSize = new Point(destDeviceState.getColumns(), destDeviceState.getRows());
+            // Migrate hotseat.
+            migrateHotseat(destDeviceState.getNumHotseat(), srcReader, destReader, target);
+            // Migrate workspace.
+            migrateWorkspace(srcReader, destReader, target, targetSize);
+            // Only reached if both steps succeeded: drop the TMP_TABLE copy and commit.
+            dropTable(t.getDb(), TMP_TABLE);
+            t.commit();
+        } catch (Exception e) {
+            Log.e(TAG, "Error during grid migration", e);
+        } finally {
+            Log.v(TAG, "Workspace migration completed in "
+                    + (System.currentTimeMillis() - migrationStartTime));
+            // Save current configuration, so that the migration does not run again, even if the
+            // migration above threw (exceptions are only logged by the catch block).
+            destDeviceState.writeToPrefs(context);
+        }
+    }
+
+    /**
+     * Brings the destination hotseat in line with the source one: destination entries that the
+     * source lacks are deleted, and source entries that the destination lacks are placed.
+     */
+    private void migrateHotseat(int destHotseatSize,
+            GridSizeMigrationDBController.DbReader srcReader,
+            GridSizeMigrationDBController.DbReader destReader, DatabaseHelper helper) {
+        final List<DbEntry> srcItems = srcReader.loadHotseatEntries();
+        final List<DbEntry> dstItems = destReader.loadHotseatEntries();
+
+        final List<DbEntry> added = getItemsToBeAdded(srcItems, dstItems);
+        final IntArray removedIds = new IntArray();
+        removedIds.addAll(getItemsToBeRemoved(srcItems, dstItems));
+
+        if (DEBUG) {
+            final String removedStr = dstItems.stream()
+                    .filter(entry -> removedIds.contains(entry.id))
+                    .map(DbEntry::toString)
+                    .collect(Collectors.joining(",\n", "[", "]"));
+            final String addedStr = added.stream()
+                    .map(DbEntry::toString)
+                    .collect(Collectors.joining(",\n", "[", "]"));
+            Log.d(TAG, "Start hotseat migration:"
+                    + "\n Removing Hotseat Items:" + removedStr
+                    + "\n Adding Hotseat Items:" + addedStr);
+        }
+
+        // Delete the destination-only entries before placing the new ones.
+        if (!removedIds.isEmpty()) {
+            removeEntryFromDb(destReader.mDb, destReader.mTableName, removedIds);
+        }
+
+        placeHotseatItems(added, dstItems, destHotseatSize, helper, srcReader, destReader);
+    }
+
+    /** Solves the placement of {@code hotseatToBeAdded} and inserts it into the dest table. */
+    private void placeHotseatItems(List<DbEntry> hotseatToBeAdded,
+            List<DbEntry> dstHotseatItems, int destHotseatSize,
+            DatabaseHelper helper, GridSizeMigrationDBController.DbReader srcReader,
+            GridSizeMigrationDBController.DbReader destReader) {
+        if (hotseatToBeAdded.isEmpty()) return;
+        // Mutable list, as in placeWorkspaceItems: Stream.toList() would be unmodifiable (and
+        // needs API 34 on Android) while idsInUse is handed on to insertEntryInDb.
+        List<Integer> idsInUse = dstHotseatItems.stream().map(entry -> entry.id)
+                .collect(Collectors.toList());
+        // Place the new items in reading order (see DbEntry#compareTo).
+        Collections.sort(hotseatToBeAdded);
+        List<DbEntry> placementSolutionHotseat =
+                solveHotseatPlacement(destHotseatSize, dstHotseatItems, hotseatToBeAdded);
+        for (DbEntry entryToPlace : placementSolutionHotseat) {
+            insertEntryInDb(helper, entryToPlace, srcReader.mTableName, destReader.mTableName,
+                    idsInUse);
+        }
+    }
+
+    /**
+     * Brings the destination workspace in line with the source one: destination entries that
+     * the source lacks are deleted, and source entries that the destination lacks are placed on
+     * the destination screens (see placeWorkspaceItems).
+     *
+     * @param targetSize the destination grid size, as (columns, rows)
+     */
+    private void migrateWorkspace(GridSizeMigrationDBController.DbReader srcReader,
+            GridSizeMigrationDBController.DbReader destReader, DatabaseHelper helper,
+            Point targetSize) {
+        final List<DbEntry> srcItems = srcReader.loadAllWorkspaceEntries();
+        final List<DbEntry> dstItems = destReader.loadAllWorkspaceEntries();
+
+        // Entries are matched by migration id (DbEntry#equals), not by database id.
+        final IntArray removedIds = new IntArray();
+        final List<DbEntry> added = getItemsToBeAdded(srcItems, dstItems);
+        removedIds.addAll(getItemsToBeRemoved(srcItems, dstItems));
+
+        if (DEBUG) {
+            final String srcStr = srcItems.stream().map(DbEntry::toString)
+                    .collect(Collectors.joining(",\n", "[", "]"));
+            final String dstStr = dstItems.stream().map(DbEntry::toString)
+                    .collect(Collectors.joining(",\n", "[", "]"));
+            final String removedStr = dstItems.stream()
+                    .filter(entry -> removedIds.contains(entry.id))
+                    .map(DbEntry::toString)
+                    .collect(Collectors.joining(",\n", "[", "]"));
+            final String addedStr = added.stream().map(DbEntry::toString)
+                    .collect(Collectors.joining(",\n", "[", "]"));
+            Log.d(TAG, "Start workspace migration:"
+                    + "\n Source Device:" + srcStr
+                    + "\n Target Device:" + dstStr
+                    + "\n Removing Workspace Items:" + removedStr
+                    + "\n Adding Workspace Items:" + addedStr);
+        }
+
+        // Delete the destination-only entries before placing the new ones.
+        // NOTE(review): the deleted entries are still in dstItems and may still be in
+        // destReader.mWorkspaceEntriesByScreenId, which placement passes to solveGridPlacement;
+        // confirm they don't keep occupying their old cells.
+        if (!removedIds.isEmpty()) {
+            removeEntryFromDb(destReader.mDb, destReader.mTableName, removedIds);
+        }
+
+        placeWorkspaceItems(added, dstItems, targetSize.x, targetSize.y,
+                helper, srcReader, destReader);
+    }
+
+    /**
+     * Inserts {@code workspaceToBeAdded} into the destination workspace: the existing screens
+     * (0 to the reader's last screen id) are filled first, in order, and items that still don't
+     * fit go on new screens appended after the last one.
+     */
+    private void placeWorkspaceItems(
+            List<DbEntry> workspaceToBeAdded,
+            List<DbEntry> dstWorkspaceItems,
+            int trgX, int trgY, DatabaseHelper helper,
+            GridSizeMigrationDBController.DbReader srcReader,
+            GridSizeMigrationDBController.DbReader destReader) {
+        if (workspaceToBeAdded.isEmpty()) {
+            return;
+        }
+        List<Integer> idsInUse = dstWorkspaceItems.stream().map(entry -> entry.id).collect(
+                Collectors.toList());
+        // Place the new items in reading order (see DbEntry#compareTo).
+        Collections.sort(workspaceToBeAdded);
+
+        // Fill the existing screens first. Each pass inserts what it placed (placeItems drains
+        // the solution) and carries the remaining items over to the next screen.
+        final int lastScreenId = destReader.mLastScreenId;
+        WorkspaceItemsToPlace itemsToPlace = new WorkspaceItemsToPlace(workspaceToBeAdded);
+        for (int screenId = 0; screenId <= lastScreenId; screenId++) {
+            if (DEBUG) {
+                Log.d(TAG, "Migrating " + screenId);
+            }
+            itemsToPlace = solveGridPlacement(
+                    destReader.mContext, screenId, trgX, trgY, itemsToPlace.mRemainingItemsToPlace,
+                    destReader.mWorkspaceEntriesByScreenId.get(screenId));
+            placeItems(itemsToPlace, helper, srcReader, destReader, idsInUse);
+            if (itemsToPlace.mRemainingItemsToPlace.isEmpty()) {
+                break;
+            }
+        }
+
+        // If the new grid is smaller, some items may not fit on any existing screen; keep
+        // adding new screens after the last one until all of them are placed.
+        // NOTE(review): this never ends if an item can't fit even on an empty screen, unless
+        // solveGridPlacement drops such items; confirm.
+        int screenId = lastScreenId + 1;
+        while (!itemsToPlace.mRemainingItemsToPlace.isEmpty()) {
+            itemsToPlace = solveGridPlacement(destReader.mContext, screenId,
+                    trgX, trgY, itemsToPlace.mRemainingItemsToPlace,
+                    destReader.mWorkspaceEntriesByScreenId.get(screenId));
+            placeItems(itemsToPlace, helper, srcReader, destReader, idsInUse);
+            screenId++;
+        }
+    }
+
+    /**
+     * Drains {@code itemsToPlace.mPlacementSolution} in order, inserting each placed entry into
+     * the destination table via insertEntryInDb.
+     */
+    private void placeItems(WorkspaceItemsToPlace itemsToPlace, DatabaseHelper helper,
+            GridSizeMigrationDBController.DbReader srcReader,
+            GridSizeMigrationDBController.DbReader destReader, List<Integer> idsInUse) {
+        while (!itemsToPlace.mPlacementSolution.isEmpty()) {
+            insertEntryInDb(helper, itemsToPlace.mPlacementSolution.remove(0),
+                    srcReader.mTableName, destReader.mTableName, idsInUse);
+        }
+    }
+
+
+    /**
+     * Decides whether to use the "strictly taller" migration.
+     *
+     * <p>True only on first load, with {@code Flags.enableGridMigrationFix()} enabled, when the
+     * destination grid keeps the source's column count and has strictly more rows.
+     */
+    private boolean shouldMigrateToStrictlyTallerGrid(boolean isFirstLoad,
+            @NonNull DeviceGridState srcDeviceState,
+            @NonNull DeviceGridState destDeviceState) {
+        return isFirstLoad
+                && Flags.enableGridMigrationFix()
+                && srcDeviceState.getColumns().equals(destDeviceState.getColumns())
+                && srcDeviceState.getRows() < destDeviceState.getRows();
+    }
+
+    /**
+     * Finds all the items that are in the old grid which aren't in the new grid, meaning they
+     * need to be added to the new grid.
+     *
+     * @return a list of DbEntry's which we need to add.
+     */
+    private List<DbEntry> getItemsToBeAdded(
+            @NonNull final List<DbEntry> src,
+            @NonNull final List<DbEntry> dest) {
+        Map<DbEntry, Integer> entryCountDiff =
+                calcDiff(src, dest);
+        List<DbEntry> toBeAdded = new ArrayList<>();
+        src.forEach(entry -> {
+            if (entryCountDiff.get(entry) > 0) {
+                toBeAdded.add(entry);
+                entryCountDiff.put(entry, entryCountDiff.get(entry) - 1);
+            }
+        });
+        return toBeAdded;
+    }
+
+    /**
+     * Finds all the items that are in the new grid which aren't in the old grid, meaning they
+     * need to be removed from the new grid.
+     *
+     * @return an IntArray of item id's which we need to remove.
+     */
+    private IntArray getItemsToBeRemoved(
+            @NonNull final List<DbEntry> src,
+            @NonNull final List<DbEntry> dest) {
+        Map<DbEntry, Integer> entryCountDiff =
+                calcDiff(src, dest);
+        IntArray toBeRemoved = new IntArray();
+        dest.forEach(entry -> {
+            if (entryCountDiff.get(entry) < 0) {
+                toBeRemoved.add(entry.id);
+                if (entry.itemType == LauncherSettings.Favorites.ITEM_TYPE_FOLDER) {
+                    entry.mFolderItems.values().forEach(ids -> ids.forEach(toBeRemoved::add));
+                }
+                entryCountDiff.put(entry, entryCountDiff.get(entry) + 1);
+            }
+        });
+        return toBeRemoved;
+    }
+
+    /**
+     * Calculates the difference between the old and new grid items in terms of how many of each
+     * item there are. E.g. if the old grid had 2 Calculator icons but the new grid has 0, then the
+     * difference there would be 2. While if the old grid has 0 Calculator icons and the
+     * new grid has 1, then the difference would be -1.
+     *
+     * @return a Map with each DbEntry as a key and the count of said entry as the value.
+     */
+    private Map<DbEntry, Integer> calcDiff(
+            @NonNull final List<DbEntry> src,
+            @NonNull final List<DbEntry> dest) {
+        Map<DbEntry, Integer> entryCountDiff = new HashMap<>();
+        src.forEach(entry ->
+                entryCountDiff.put(entry, entryCountDiff.getOrDefault(entry, 0) + 1));
+        dest.forEach(entry ->
+                entryCountDiff.put(entry, entryCountDiff.getOrDefault(entry, 0) - 1));
+        return entryCountDiff;
+    }
+
+    /**
+     * Fills the hotseat slots not already taken by {@code placedHotseatItems} (by screenId),
+     * lowest slot first, with {@code itemsToPlace} in their given order. Placed entries are
+     * updated in place; items left over once no free slot remains are not returned.
+     */
+    private List<DbEntry> solveHotseatPlacement(final int hotseatSize,
+            @NonNull final List<DbEntry> placedHotseatItems,
+            @NonNull final List<DbEntry> itemsToPlace) {
+        final boolean[] taken = new boolean[hotseatSize];
+        placedHotseatItems.forEach(placed -> taken[placed.screenId] = true);
+
+        final List<DbEntry> solution = new ArrayList<>();
+        final Iterator<DbEntry> pending = itemsToPlace.iterator();
+        for (int slot = 0; slot < hotseatSize && pending.hasNext(); slot++) {
+            if (taken[slot]) {
+                continue;
+            }
+            final DbEntry entry = pending.next();
+            entry.screenId = slot;
+            // cellX/cellY do not affect hotseat position, but must not be left at -1.
+            entry.cellX = slot;
+            entry.cellY = 0;
+            solution.add(entry);
+        }
+        return solution;
+    }
+
+    /**
+     * Places what fits of {@code sortedItemsToPlace} on {@code screenId}, around existedEntries.
+     * Items larger than the grid are dropped; others that do not fit stay in the remaining list.
+     */
+    private WorkspaceItemsToPlace solveGridPlacement(
+            Context context,
+            final int screenId, final int trgX, final int trgY,
+            @NonNull final List<DbEntry> sortedItemsToPlace,
+            List<DbEntry> existedEntries) {
+        final WorkspaceItemsToPlace result = new WorkspaceItemsToPlace(sortedItemsToPlace);
+        final GridOccupancy grid = new GridOccupancy(trgX, trgY);
+        final Point size = new Point(trgX, trgY);
+        final boolean reserveSmartspaceRow = screenId == 0
+                && FeatureFlags.QSB_ON_FIRST_SCREEN
+                && (!enableSmartspaceRemovalToggle() || LauncherPrefs.getPrefs(context)
+                        .getBoolean(SMARTSPACE_ON_HOME_SCREEN, true))
+                && !SHOULD_SHOW_FIRST_PAGE_WIDGET;
+        final Point cursor = new Point(0, reserveSmartspaceRow ? 1 /* smartspace */ : 0);
+        if (existedEntries != null) {
+            existedEntries.forEach(existing -> grid.markCells(existing, true));
+        }
+        for (Iterator<DbEntry> it = result.mRemainingItemsToPlace.iterator(); it.hasNext(); ) {
+            final DbEntry entry = it.next();
+            if (entry.minSpanX > trgX || entry.minSpanY > trgY) {
+                it.remove(); // can never fit this grid, so it is dropped
+                continue;
+            }
+            final CellAndSpan cell = findPlacementForEntry(entry, cursor.x, cursor.y, size, grid);
+            if (cell == null) {
+                continue; // left in mRemainingItemsToPlace
+            }
+            entry.screenId = screenId;
+            entry.cellX = cell.cellX;
+            entry.cellY = cell.cellY;
+            entry.spanX = cell.spanX;
+            entry.spanY = cell.spanY;
+            grid.markCells(entry, true);
+            cursor.set(entry.cellX + entry.spanX, entry.cellY);
+            result.mPlacementSolution.add(entry);
+            it.remove();
+        }
+        return result;
+    }
+
+    /**
+     * Finds the first spot, scanning row by row from (startPosX, startPosY) and wrapping to
+     * column 0 on later rows, where an area of the entry's min span is vacant. Callers pass
+     * the cell after their previous placement so the search does not restart from the origin.
+     *
+     * @return a CellAndSpan at that spot sized to the entry's min span, or null if none is found.
+     */
+    private CellAndSpan findPlacementForEntry(
+            @NonNull final DbEntry entry,
+            int startPosX, int startPosY, @NonNull final Point trg,
+            @NonNull final GridOccupancy occupied) {
+        for (int y = startPosY; y < trg.y; y++) {
+            for (int x = startPosX; x < trg.x; x++) {
+                boolean minFits = occupied.isRegionVacant(x, y, entry.minSpanX, entry.minSpanY);
+                if (minFits) {
+                    return new CellAndSpan(x, y, entry.minSpanX, entry.minSpanY);
+                }
+            }
+            startPosX = 0; // later rows are scanned from their first column
+        }
+        return null;
+    }
+
+    /** Outcome of one placement pass over a screen. */
+    private static class WorkspaceItemsToPlace {
+        List<DbEntry> mRemainingItemsToPlace; // not placed yet, in the order they will be tried
+        List<DbEntry> mPlacementSolution; // placed; later drained via insertEntryInDb
+        WorkspaceItemsToPlace(List<DbEntry> items) {
+            // Copy so that placement never mutates the caller's list.
+            mRemainingItemsToPlace = new ArrayList<>(items);
+            mPlacementSolution = new ArrayList<>();
+        }
+    }
+}
diff --git a/src/com/android/launcher3/model/LoaderTask.java b/src/com/android/launcher3/model/LoaderTask.java
index 09d1146..b0108c2 100644
--- a/src/com/android/launcher3/model/LoaderTask.java
+++ b/src/com/android/launcher3/model/LoaderTask.java
@@ -435,7 +435,15 @@
         final WidgetInflater widgetInflater = new WidgetInflater(context);
 
         ModelDbController dbController = mApp.getModel().getModelDbController();
-        dbController.tryMigrateDB(restoreEventLogger);
+        if (Flags.gridMigrationRefactor()) {
+            try {
+                dbController.attemptMigrateDb(restoreEventLogger);
+            } catch (Exception e) {
+                FileLog.e(TAG, "Failed to migrate grid", e);
+            }
+        } else {
+            dbController.tryMigrateDB(restoreEventLogger);
+        }
         Log.d(TAG, "loadWorkspace: loading default favorites");
         dbController.loadDefaultFavoritesIfNecessary();
 
diff --git a/src/com/android/launcher3/model/ModelDbController.java b/src/com/android/launcher3/model/ModelDbController.java
index 787aef4..4f0f162 100644
--- a/src/com/android/launcher3/model/ModelDbController.java
+++ b/src/com/android/launcher3/model/ModelDbController.java
@@ -291,6 +291,100 @@
 
 
     /**
+     * Unconditionally resets the launcher DB via {@link #createEmptyDB()}; callers decide when.
+     */
+    public void resetLauncherDb(@Nullable LauncherRestoreEventLogger restoreEventLogger) {
+        if (restoreEventLogger != null) { // optional: report the failed migration first
+            sendMetricsForFailedMigration(restoreEventLogger, getDb());
+        }
+        FileLog.d(TAG, "Migration failed: resetting launcher database");
+        createEmptyDB();
+        // Remember that an empty DB was created for the currently open DB file.
+        LauncherPrefs.get(mContext).putSync(
+                getEmptyDbCreatedKey(mOpenHelper.getDatabaseName()).to(true));
+        // Write the grid state to avoid another migration
+        new DeviceGridState(LauncherAppState.getIDP(mContext)).writeToPrefs(mContext);
+    }
+
+    /**
+     * Decides whether the launcher DB should be reset instead of migrated.
+     *
+     * <p>Returns true if an empty DB was already created, or if a grid migration is needed but
+     * its target DB file is the one currently open. Otherwise returns false, in particular when
+     * no grid migration is needed. Checks are evaluated lazily in that order, so later checks
+     * (and their logging) only run when reached.
+     */
+    private boolean shouldResetDb() {
+        if (isThereExistingDb()) {
+            return true;
+        }
+        // Reset if a migration is pending but would target the DB that is already open.
+        return isGridMigrationNecessary() && isCurrentDbSameAsTarget();
+    }
+
+    private boolean isThereExistingDb() {
+        if (LauncherPrefs.get(mContext).get(getEmptyDbCreatedKey())) {
+            // If we already have a new DB, ignore migration
+            Log.d(TAG, "migrateGridIfNeeded: new DB already created, skipping migration");
+            return true;
+        }
+        return false;
+    }
+
+    private boolean isGridMigrationNecessary() {
+        InvariantDeviceProfile idp = LauncherAppState.getIDP(mContext);
+        if (GridSizeMigrationDBController.needsToMigrate(mContext, idp)) {
+            return true;
+        }
+        Log.d(TAG, "migrateGridIfNeeded: no grid migration needed");
+        return false;
+    }
+
+    private boolean isCurrentDbSameAsTarget() {
+        InvariantDeviceProfile idp = LauncherAppState.getIDP(mContext);
+        String targetDbName = new DeviceGridState(idp).getDbFile();
+        if (TextUtils.equals(targetDbName, mOpenHelper.getDatabaseName())) {
+            Log.e(TAG, "migrateGridIfNeeded: target db is same as current: " + targetDbName);
+            return true;
+        }
+        return false;
+    }
+
+    /**
+     * Migrates the DB to the current grid, or resets it when {@code shouldResetDb()} says so.
+     * @param restoreEventLogger optional; passed to {@link #resetLauncherDb} when resetting.
+     * @throws Exception wrapping the cause if migration fails; the DB has been reset by then.
+     */
+    public void attemptMigrateDb(@Nullable LauncherRestoreEventLogger restoreEventLogger)
+            throws Exception {
+        createDbIfNotExists();
+        if (shouldResetDb()) {
+            resetLauncherDb(restoreEventLogger);
+            return;
+        }
+        InvariantDeviceProfile idp = LauncherAppState.getIDP(mContext);
+        DatabaseHelper oldHelper = mOpenHelper;
+        mOpenHelper = (mContext instanceof SandboxContext) ? oldHelper
+                : createDatabaseHelper(true /* forMigration */);
+        try {
+            // This is the current grid we have, given by the mContext
+            DeviceGridState srcDeviceState = new DeviceGridState(mContext);
+            // This is the state we want to migrate to that is given by the idp
+            DeviceGridState destDeviceState = new DeviceGridState(idp);
+            GridSizeMigrationLogic gridSizeMigrationLogic = new GridSizeMigrationLogic();
+            gridSizeMigrationLogic.migrateGrid(mContext, srcDeviceState, destDeviceState,
+                    mOpenHelper, oldHelper.getWritableDatabase());
+        } catch (Exception e) {
+            resetLauncherDb(restoreEventLogger);
+            throw new Exception("Failed to migrate grid", e);
+        } finally {
+            if (mOpenHelper != oldHelper) {
+                oldHelper.close(); // only a replaced helper is closed
+            }
+        }
+    }
+
+    /**
      * Migrates the DB if needed. If the migration failed, it clears the DB.
      */
     public void tryMigrateDB(@Nullable LauncherRestoreEventLogger restoreEventLogger) {
@@ -333,8 +427,8 @@
             return false;
         }
         InvariantDeviceProfile idp = LauncherAppState.getIDP(mContext);
-        if (!GridSizeMigrationUtil.needsToMigrate(mContext, idp)) {
-            FileLog.d(TAG, "migrateGridIfNeeded: no grid migration needed");
+        if (!GridSizeMigrationDBController.needsToMigrate(mContext, idp)) {
+            Log.d(TAG, "migrateGridIfNeeded: no grid migration needed");
             return true;
         }
         String targetDbName = new DeviceGridState(idp).getDbFile();
@@ -350,7 +444,7 @@
             DeviceGridState srcDeviceState = new DeviceGridState(mContext);
             // This is the state we want to migrate to that is given by the idp
             DeviceGridState destDeviceState = new DeviceGridState(idp);
-            return GridSizeMigrationUtil.migrateGridIfNeeded(mContext, srcDeviceState,
+            return GridSizeMigrationDBController.migrateGridIfNeeded(mContext, srcDeviceState,
                     destDeviceState, mOpenHelper, oldHelper.getWritableDatabase());
         } catch (Exception e) {
             FileLog.e(TAG, "Failed to migrate grid", e);
diff --git a/tests/multivalentTests/src/com/android/launcher3/celllayout/FavoriteItemsTransaction.java b/tests/multivalentTests/src/com/android/launcher3/celllayout/FavoriteItemsTransaction.java
index 0c3081f..a9082e2 100644
--- a/tests/multivalentTests/src/com/android/launcher3/celllayout/FavoriteItemsTransaction.java
+++ b/tests/multivalentTests/src/com/android/launcher3/celllayout/FavoriteItemsTransaction.java
@@ -22,6 +22,7 @@
 
 import android.content.Context;
 
+import com.android.launcher3.Flags;
 import com.android.launcher3.LauncherAppState;
 import com.android.launcher3.LauncherModel;
 import com.android.launcher3.LauncherSettings;
@@ -59,7 +60,11 @@
         runOnExecutorSync(MODEL_EXECUTOR, () -> {
             ModelDbController controller = model.getModelDbController();
             // Migrate any previous data so that the DB state is correct
-            controller.tryMigrateDB(null /* restoreEventLogger */);
+            if (Flags.gridMigrationRefactor()) {
+                controller.attemptMigrateDb(null /* restoreEventLogger */);
+            } else {
+                controller.tryMigrateDB(null /* restoreEventLogger */);
+            }
 
             // Create DB again to load fresh data
             controller.createEmptyDB();
diff --git a/tests/multivalentTests/src/com/android/launcher3/model/GridSizeMigrationUtilTest.kt b/tests/multivalentTests/src/com/android/launcher3/model/GridSizeMigrationDBControllerTest.kt
similarity index 83%
rename from tests/multivalentTests/src/com/android/launcher3/model/GridSizeMigrationUtilTest.kt
rename to tests/multivalentTests/src/com/android/launcher3/model/GridSizeMigrationDBControllerTest.kt
index f57e8a1..c6f291d 100644
--- a/tests/multivalentTests/src/com/android/launcher3/model/GridSizeMigrationUtilTest.kt
+++ b/tests/multivalentTests/src/com/android/launcher3/model/GridSizeMigrationDBControllerTest.kt
@@ -22,13 +22,16 @@
 import android.database.sqlite.SQLiteDatabase
 import android.graphics.Point
 import android.os.Process
+import android.platform.test.annotations.DisableFlags
+import android.platform.test.annotations.EnableFlags
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import androidx.test.filters.SmallTest
+import com.android.launcher3.Flags
 import com.android.launcher3.InvariantDeviceProfile
 import com.android.launcher3.LauncherPrefs
 import com.android.launcher3.LauncherPrefs.Companion.WORKSPACE_SIZE
 import com.android.launcher3.LauncherSettings.Favorites.*
-import com.android.launcher3.model.GridSizeMigrationUtil.DbReader
+import com.android.launcher3.model.GridSizeMigrationDBController.DbReader
 import com.android.launcher3.pm.UserCache
 import com.android.launcher3.provider.LauncherDbUtils
 import com.android.launcher3.util.LauncherModelHelper
@@ -82,9 +85,22 @@
         modelHelper.destroy()
     }
 
-    /** Old migration logic, should be modified once is not needed anymore */
     @Test
     @Throws(Exception::class)
+    @EnableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun testMigrationRefactorFlagOn() { // testMigration() on the refactored migration path
+        testMigration()
+    }
+
+    @Test
+    @Throws(Exception::class)
+    @DisableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun testMigrationRefactorFlagOff() { // testMigration() on the legacy migration path
+        testMigration()
+    }
+
+    /** Run by the refactor-flag on/off tests above; revisit once the old logic is removed. */
+    @Throws(Exception::class)
     fun testMigration() {
         // Src Hotseat icons
         addItem(ITEM_TYPE_APPLICATION, 0, CONTAINER_HOTSEAT, 0, 0, testPackage1, 1, TMP_TABLE)
@@ -113,15 +129,26 @@
         idp.numRows = 4
         val srcReader = DbReader(db, TMP_TABLE, context)
         val destReader = DbReader(db, TABLE_NAME, context)
-        GridSizeMigrationUtil.migrate(
-            dbHelper,
-            srcReader,
-            destReader,
-            idp.numDatabaseHotseatIcons,
-            Point(idp.numColumns, idp.numRows),
-            DeviceGridState(context),
-            DeviceGridState(idp),
-        )
+        if (Flags.gridMigrationRefactor()) {
+            val gridSizeMigrationLogic = GridSizeMigrationLogic()
+            gridSizeMigrationLogic.migrateGrid(
+                context,
+                DeviceGridState(context),
+                DeviceGridState(idp),
+                dbHelper,
+                db,
+            )
+        } else {
+            GridSizeMigrationDBController.migrate(
+                dbHelper,
+                srcReader,
+                destReader,
+                idp.numDatabaseHotseatIcons,
+                Point(idp.numColumns, idp.numRows),
+                DeviceGridState(context),
+                DeviceGridState(idp),
+            )
+        }
 
         // Check hotseat items
         var c =
@@ -187,9 +214,22 @@
         assertThat(locMap[testPackage9]).isEqualTo(Point(0, 2))
     }
 
-    /** Old migration logic, should be modified once is not needed anymore */
     @Test
     @Throws(Exception::class)
+    @EnableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun testMigrationBackAndForthRefactorFlagOn() { // scenario below, refactor flag enabled
+        testMigrationBackAndForth()
+    }
+
+    @Test
+    @Throws(Exception::class)
+    @DisableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun testMigrationBackAndForthRefactorFlagOff() { // scenario below, refactor flag disabled
+        testMigrationBackAndForth()
+    }
+
+    /** Run by the refactor-flag on/off tests above; revisit once the old logic is removed. */
+    @Throws(Exception::class)
     fun testMigrationBackAndForth() {
         // Hotseat items in grid A
         // 1 2 _ 3 4
@@ -224,15 +264,26 @@
         val readerGridA = DbReader(db, TMP_TABLE, context)
         val readerGridB = DbReader(db, TABLE_NAME, context)
         // migrate from A -> B
-        GridSizeMigrationUtil.migrate(
-            dbHelper,
-            readerGridA,
-            readerGridB,
-            idp.numDatabaseHotseatIcons,
-            Point(idp.numColumns, idp.numRows),
-            DeviceGridState(context),
-            DeviceGridState(idp),
-        )
+        if (Flags.gridMigrationRefactor()) {
+            var gridSizeMigrationLogic = GridSizeMigrationLogic()
+            gridSizeMigrationLogic.migrateGrid(
+                context,
+                DeviceGridState(context),
+                DeviceGridState(idp),
+                dbHelper,
+                db,
+            )
+        } else {
+            GridSizeMigrationDBController.migrate(
+                dbHelper,
+                readerGridA,
+                readerGridB,
+                idp.numDatabaseHotseatIcons,
+                Point(idp.numColumns, idp.numRows),
+                DeviceGridState(context),
+                DeviceGridState(idp),
+            )
+        }
 
         // Check hotseat items in grid B
         var c =
@@ -280,15 +331,8 @@
         addItem(ITEM_TYPE_APPLICATION, 0, CONTAINER_DESKTOP, 0, 2, testPackage9)
 
         // migrate from B -> A
-        GridSizeMigrationUtil.migrate(
-            dbHelper,
-            readerGridB,
-            readerGridA,
-            5,
-            Point(5, 5),
-            DeviceGridState(idp),
-            DeviceGridState(context),
-        )
+        migrateGrid(dbHelper, readerGridB, readerGridA, 5, 5, 5)
+
         // Check hotseat items in grid A
         c =
             db.query(
@@ -339,14 +383,13 @@
         db.delete(TMP_TABLE, "$_ID=7", null)
 
         // migrate from A -> B
-        GridSizeMigrationUtil.migrate(
+        migrateGrid(
             dbHelper,
             readerGridA,
             readerGridB,
             idp.numDatabaseHotseatIcons,
-            Point(idp.numColumns, idp.numRows),
-            DeviceGridState(context),
-            DeviceGridState(idp),
+            idp.numColumns,
+            idp.numRows,
         )
 
         // Check hotseat items in grid B
@@ -392,6 +435,36 @@
         assertThat(locMap[testPackage9]).isEqualTo(Triple(0, 0, 2))
     }
 
+    private fun migrateGrid(
+        dbHelper: DatabaseHelper,
+        srcReader: DbReader,
+        destReader: DbReader,
+        destHotseatSize: Int,
+        pointX: Int,
+        pointY: Int,
+    ) {
+        // TMP_TABLE holds the grid described by `context`, TABLE_NAME the one described by `idp`;
+        // give the legacy path its grid states in the same direction as the readers.
+        val fromContextGrid = srcReader.mTableName == TMP_TABLE
+        val contextGrid = DeviceGridState(context)
+        val idpGrid = DeviceGridState(idp)
+        if (Flags.gridMigrationRefactor()) {
+            // NOTE(review): this path ignores the readers and always migrates context -> idp;
+            // confirm that is intended for the B -> A step of testMigrationBackAndForth.
+            GridSizeMigrationLogic().migrateGrid(context, contextGrid, idpGrid, dbHelper, db)
+        } else {
+            GridSizeMigrationDBController.migrate(
+                dbHelper,
+                srcReader,
+                destReader,
+                destHotseatSize,
+                Point(pointX, pointY),
+                if (fromContextGrid) contextGrid else idpGrid,
+                if (fromContextGrid) idpGrid else contextGrid,
+            )
+        }
+    }
+
     private fun verifyHotseat(c: Cursor, idp: InvariantDeviceProfile, expected: List<String?>) {
         assertThat(c.count).isEqualTo(idp.numDatabaseHotseatIcons)
         val screenIndex = c.getColumnIndex(SCREEN)
@@ -421,6 +494,17 @@
     }
 
     @Test
+    @EnableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun migrateToLargerHotseatRefactorFlagOn() { // scenario below, refactor flag enabled
+        migrateToLargerHotseat()
+    }
+
+    @Test
+    @DisableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun migrateToLargerHotseatRefactorFlagOff() { // scenario below, refactor flag disabled
+        migrateToLargerHotseat()
+    }
+
     fun migrateToLargerHotseat() {
         val srcHotseatItems =
             intArrayOf(
@@ -471,14 +555,13 @@
         idp.numRows = 4
         val srcReader = DbReader(db, TMP_TABLE, context)
         val destReader = DbReader(db, TABLE_NAME, context)
-        GridSizeMigrationUtil.migrate(
+        migrateGrid(
             dbHelper,
             srcReader,
             destReader,
             idp.numDatabaseHotseatIcons,
-            Point(idp.numColumns, idp.numRows),
-            DeviceGridState(context),
-            DeviceGridState(idp),
+            idp.numColumns,
+            idp.numRows,
         )
 
         // Check hotseat items
@@ -516,6 +599,17 @@
     }
 
     @Test
+    @EnableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun migrateFromLargerHotseatRefactorFlagOn() { // scenario below, refactor flag enabled
+        migrateFromLargerHotseat()
+    }
+
+    @Test
+    @DisableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun migrateFromLargerHotseatRefactorFlagOff() { // scenario below, refactor flag disabled
+        migrateFromLargerHotseat()
+    }
+
     fun migrateFromLargerHotseat() {
         addItem(ITEM_TYPE_APPLICATION, 0, CONTAINER_HOTSEAT, 0, 0, testPackage1, 1, TMP_TABLE)
         addItem(ITEM_TYPE_DEEP_SHORTCUT, 2, CONTAINER_HOTSEAT, 0, 0, testPackage2, 2, TMP_TABLE)
@@ -528,14 +622,13 @@
         idp.numRows = 4
         val srcReader = DbReader(db, TMP_TABLE, context)
         val destReader = DbReader(db, TABLE_NAME, context)
-        GridSizeMigrationUtil.migrate(
+        migrateGrid(
             dbHelper,
             srcReader,
             destReader,
             idp.numDatabaseHotseatIcons,
-            Point(idp.numColumns, idp.numRows),
-            DeviceGridState(context),
-            DeviceGridState(idp),
+            idp.numColumns,
+            idp.numRows,
         )
 
         // Check hotseat items
@@ -573,11 +666,24 @@
         c.close()
     }
 
+    @Test
+    @Throws(Exception::class)
+    @EnableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun migrateFromSmallerGridBigDifferenceRefactorFlagOn() { // refactor flag enabled
+        migrateFromSmallerGridBigDifference()
+    }
+
+    @Test
+    @Throws(Exception::class)
+    @DisableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun migrateFromSmallerGridBigDifferenceRefactorFlagOff() { // refactor flag disabled
+        migrateFromSmallerGridBigDifference()
+    }
+
     /**
      * Migrating from a smaller grid to a large one should reflow the pages if the column difference
      * is more than 2
      */
-    @Test
     @Throws(Exception::class)
     fun migrateFromSmallerGridBigDifference() {
         enableNewMigrationLogic("2,2")
@@ -594,14 +700,13 @@
         idp.numRows = 5
         val srcReader = DbReader(db, TMP_TABLE, context)
         val destReader = DbReader(db, TABLE_NAME, context)
-        GridSizeMigrationUtil.migrate(
+        migrateGrid(
             dbHelper,
             srcReader,
             destReader,
             idp.numDatabaseHotseatIcons,
-            Point(idp.numColumns, idp.numRows),
-            DeviceGridState(context),
-            DeviceGridState(idp),
+            idp.numColumns,
+            idp.numRows,
         )
 
         // Get workspace items
@@ -636,9 +741,22 @@
         assertThat(locMap[testPackage5]).isEqualTo(0)
     }
 
-    /** Migrating from a larger grid to a smaller, we reflow from page 0 */
     @Test
     @Throws(Exception::class)
+    @EnableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun migrateFromLargerGridRefactorFlagOn() { // scenario below, refactor flag enabled
+        migrateFromLargerGrid()
+    }
+
+    @Test
+    @Throws(Exception::class)
+    @DisableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun migrateFromLargerGridRefactorFlagOff() { // scenario below, refactor flag disabled
+        migrateFromLargerGrid()
+    }
+
+    /** Migrating from a larger grid to a smaller one reflows the items starting from page 0. */
+    @Throws(Exception::class)
     fun migrateFromLargerGrid() {
         enableNewMigrationLogic("5,5")
 
@@ -654,14 +772,13 @@
         idp.numRows = 4
         val srcReader = DbReader(db, TMP_TABLE, context)
         val destReader = DbReader(db, TABLE_NAME, context)
-        GridSizeMigrationUtil.migrate(
+        migrateGrid(
             dbHelper,
             srcReader,
             destReader,
             idp.numDatabaseHotseatIcons,
-            Point(idp.numColumns, idp.numRows),
-            DeviceGridState(context),
-            DeviceGridState(idp),
+            idp.numColumns,
+            idp.numRows,
         )
 
         // Get workspace items
diff --git a/tests/multivalentTests/src/com/android/launcher3/util/ModelTestExtensions.kt b/tests/multivalentTests/src/com/android/launcher3/util/ModelTestExtensions.kt
index 6bd182b..8d072d8 100644
--- a/tests/multivalentTests/src/com/android/launcher3/util/ModelTestExtensions.kt
+++ b/tests/multivalentTests/src/com/android/launcher3/util/ModelTestExtensions.kt
@@ -1,6 +1,7 @@
 package com.android.launcher3.util
 
 import android.content.ContentValues
+import com.android.launcher3.Flags
 import com.android.launcher3.LauncherModel
 import com.android.launcher3.LauncherSettings.Favorites
 import com.android.launcher3.LauncherSettings.Favorites.APPWIDGET_ID
@@ -30,7 +31,8 @@
         loadModelSync()
         TestUtil.runOnExecutorSync(Executors.MODEL_EXECUTOR) {
             modelDbController.run {
-                tryMigrateDB(null /* restoreEventLogger */)
+                if (Flags.gridMigrationRefactor()) attemptMigrateDb(null /* restoreEventLogger */)
+                else tryMigrateDB(null /* restoreEventLogger */)
                 createEmptyDB()
                 clearEmptyDbFlag()
             }
@@ -67,12 +69,12 @@
         tableName: String = Favorites.TABLE_NAME,
         appWidgetId: Int = -1,
         appWidgetSource: Int = -1,
-        appWidgetProvider: String? = null
+        appWidgetProvider: String? = null,
     ) {
         loadModelSync()
         TestUtil.runOnExecutorSync(Executors.MODEL_EXECUTOR) {
             val controller: ModelDbController = modelDbController
-            controller.tryMigrateDB(null /* restoreEventLogger */)
+            controller.attemptMigrateDb(null /* restoreEventLogger */)
             modelDbController.newTransaction().use { transaction ->
                 val values =
                     ContentValues().apply {
diff --git a/tests/src/com/android/launcher3/backuprestore/BackupAndRestoreDBSelectionTest.kt b/tests/src/com/android/launcher3/backuprestore/BackupAndRestoreDBSelectionTest.kt
index 35ac0a1..b4ee090 100644
--- a/tests/src/com/android/launcher3/backuprestore/BackupAndRestoreDBSelectionTest.kt
+++ b/tests/src/com/android/launcher3/backuprestore/BackupAndRestoreDBSelectionTest.kt
@@ -16,6 +16,8 @@
 
 package com.android.launcher3.backuprestore
 
+import android.platform.test.annotations.DisableFlags
+import android.platform.test.annotations.EnableFlags
 import android.platform.test.flag.junit.SetFlagsRule
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import androidx.test.filters.MediumTest
@@ -52,10 +54,24 @@
         setFlagsRule.setFlags(true, Flags.FLAG_ENABLE_NARROW_GRID_RESTORE)
     }
 
+    @Test @EnableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun oldDatabasesNotPresentAfterRestoreRefactorFlagEnabled() {
+        oldDatabasesNotPresentAfterRestore() // exercises attemptMigrateDb
+    }
+
+    @Test @DisableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun oldDatabasesNotPresentAfterRestoreRefactorFlagDisabled() {
+        oldDatabasesNotPresentAfterRestore() // exercises tryMigrateDB
+    }
+
     @Test
     fun oldDatabasesNotPresentAfterRestore() {
         val dbController = ModelDbController(getInstrumentation().targetContext)
-        dbController.tryMigrateDB(null)
+        if (Flags.gridMigrationRefactor()) {
+            dbController.attemptMigrateDb(null)
+        } else {
+            dbController.tryMigrateDB(null)
+        }
         TestUtil.runOnExecutorSync(MODEL_EXECUTOR) {
             assert(backAndRestoreRule.getDatabaseFiles().size == 1) {
                 "There should only be one database after restoring, the last one used. Actual databases ${backAndRestoreRule.getDatabaseFiles()}"
diff --git a/tests/src/com/android/launcher3/model/GridMigrationTest.kt b/tests/src/com/android/launcher3/model/GridMigrationTest.kt
index 15222a4..666ec16 100644
--- a/tests/src/com/android/launcher3/model/GridMigrationTest.kt
+++ b/tests/src/com/android/launcher3/model/GridMigrationTest.kt
@@ -52,11 +52,15 @@
             phoneContext,
             dbFileName,
             { UserCache.INSTANCE.get(phoneContext).getSerialNumberForUser(it) },
-            {}
+            {},
         )
 
-    fun readEntries(): List<GridSizeMigrationUtil.DbEntry> =
-        GridSizeMigrationUtil.readAllEntries(dbHelper.readableDatabase, TABLE_NAME, phoneContext)
+    fun readEntries(): List<DbEntry> =
+        GridSizeMigrationDBController.readAllEntries(
+            dbHelper.readableDatabase,
+            TABLE_NAME,
+            phoneContext,
+        )
 }
 
 /**
@@ -80,7 +84,7 @@
         TestToPhoneFileCopier(
             src = "databases/GridMigrationTest/$DB_FILE",
             dest = "databases/$DB_FILE",
-            removeOnFinish = true
+            removeOnFinish = true,
         )
 
     @Before
@@ -89,13 +93,24 @@
     }
 
     private fun migrate(src: GridMigrationData, dst: GridMigrationData) {
-        GridSizeMigrationUtil.migrateGridIfNeeded(
-            phoneContext,
-            src.gridState,
-            dst.gridState,
-            dst.dbHelper,
-            src.dbHelper.readableDatabase
-        )
+        if (Flags.gridMigrationRefactor()) {
+            val gridSizeMigrationLogic = GridSizeMigrationLogic()
+            gridSizeMigrationLogic.migrateGrid(
+                phoneContext,
+                src.gridState,
+                dst.gridState,
+                dst.dbHelper,
+                src.dbHelper.readableDatabase,
+            )
+        } else {
+            GridSizeMigrationDBController.migrateGridIfNeeded(
+                phoneContext,
+                src.gridState,
+                dst.gridState,
+                dst.dbHelper,
+                src.dbHelper.readableDatabase,
+            )
+        }
     }
 
     /**
@@ -115,10 +130,8 @@
     }
 
     private fun compare(dst: GridMigrationData, target: GridMigrationData) {
-        val sort = compareBy<GridSizeMigrationUtil.DbEntry>({ it.cellX }, { it.cellY })
-        val mapF = { it: GridSizeMigrationUtil.DbEntry ->
-            EntryData(it.cellX, it.cellY, it.spanX, it.spanY, it.rank)
-        }
+        val sort = compareBy<DbEntry>({ it.cellX }, { it.cellY })
+        val mapF = { it: DbEntry -> EntryData(it.cellX, it.cellY, it.spanX, it.spanY, it.rank) }
         val entriesDst = dst.readEntries().sortedWith(sort).map(mapF)
         val entriesTarget = target.readEntries().sortedWith(sort).map(mapF)
 
@@ -149,7 +162,7 @@
         TestToPhoneFileCopier(
             src = "databases/GridMigrationTest/result5x5to3x3.db",
             dest = "databases/result5x5to3x3.db",
-            removeOnFinish = true
+            removeOnFinish = true,
         )
 
     @Test
@@ -160,10 +173,10 @@
                 GridMigrationData(
                     null, // in memory db, to download a new db change null for the filename of the
                     // db name to store it. Do not use existing names.
-                    DeviceGridState(3, 3, 3, TYPE_PHONE, "")
+                    DeviceGridState(3, 3, 3, TYPE_PHONE, ""),
                 ),
             target =
-                GridMigrationData("result5x5to3x3.db", DeviceGridState(3, 3, 3, TYPE_PHONE, ""))
+                GridMigrationData("result5x5to3x3.db", DeviceGridState(3, 3, 3, TYPE_PHONE, "")),
         )
 
     @JvmField
@@ -172,7 +185,7 @@
         TestToPhoneFileCopier(
             src = "databases/GridMigrationTest/result5x5to4x7.db",
             dest = "databases/result5x5to4x7.db",
-            removeOnFinish = true
+            removeOnFinish = true,
         )
 
     @Test
@@ -183,10 +196,10 @@
                 GridMigrationData(
                     null, // in memory db, to download a new db change null for the filename of the
                     // db name to store it. Do not use existing names.
-                    DeviceGridState(4, 7, 4, TYPE_PHONE, "")
+                    DeviceGridState(4, 7, 4, TYPE_PHONE, ""),
                 ),
             target =
-                GridMigrationData("result5x5to4x7.db", DeviceGridState(4, 7, 4, TYPE_PHONE, ""))
+                GridMigrationData("result5x5to4x7.db", DeviceGridState(4, 7, 4, TYPE_PHONE, "")),
         )
 
     @JvmField
@@ -195,7 +208,7 @@
         TestToPhoneFileCopier(
             src = "databases/GridMigrationTest/result5x5to5x8.db",
             dest = "databases/result5x5to5x8.db",
-            removeOnFinish = true
+            removeOnFinish = true,
         )
 
     @Test
@@ -206,10 +219,10 @@
                 GridMigrationData(
                     null, // in memory db, to download a new db change null for the filename of the
                     // db name to store it. Do not use existing names.
-                    DeviceGridState(5, 8, 5, TYPE_PHONE, "")
+                    DeviceGridState(5, 8, 5, TYPE_PHONE, ""),
                 ),
             target =
-                GridMigrationData("result5x5to5x8.db", DeviceGridState(5, 8, 5, TYPE_PHONE, ""))
+                GridMigrationData("result5x5to5x8.db", DeviceGridState(5, 8, 5, TYPE_PHONE, "")),
         )
 
     @JvmField
@@ -218,7 +231,7 @@
         TestToPhoneFileCopier(
             src = "databases/GridMigrationTest/flagged_result5x5to5x8.db",
             dest = "databases/flagged_result5x5to5x8.db",
-            removeOnFinish = true
+            removeOnFinish = true,
         )
 
     @Test
@@ -230,13 +243,13 @@
                 GridMigrationData(
                     null, // in memory db, to download a new db change null for the filename of the
                     // db name to store it. Do not use existing names.
-                    DeviceGridState(5, 8, 5, TYPE_PHONE, "")
+                    DeviceGridState(5, 8, 5, TYPE_PHONE, ""),
                 ),
             target =
                 GridMigrationData(
                     "flagged_result5x5to5x8.db",
-                    DeviceGridState(5, 8, 5, TYPE_PHONE, "")
-                )
+                    DeviceGridState(5, 8, 5, TYPE_PHONE, ""),
+                ),
         )
     }
 }
diff --git a/tests/src/com/android/launcher3/model/gridmigration/ValidGridMigrationUnitTest.kt b/tests/src/com/android/launcher3/model/gridmigration/ValidGridMigrationUnitTest.kt
index 03d0195..c08237c 100644
--- a/tests/src/com/android/launcher3/model/gridmigration/ValidGridMigrationUnitTest.kt
+++ b/tests/src/com/android/launcher3/model/gridmigration/ValidGridMigrationUnitTest.kt
@@ -20,17 +20,21 @@
 import android.database.sqlite.SQLiteDatabase
 import android.graphics.Point
 import android.os.Process
+import android.platform.test.annotations.DisableFlags
+import android.platform.test.annotations.EnableFlags
 import android.util.Log
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import androidx.test.filters.SmallTest
 import androidx.test.platform.app.InstrumentationRegistry
+import com.android.launcher3.Flags
 import com.android.launcher3.InvariantDeviceProfile
 import com.android.launcher3.LauncherSettings.Favorites
 import com.android.launcher3.celllayout.testgenerator.ValidGridMigrationTestCaseGenerator
 import com.android.launcher3.celllayout.testgenerator.generateItemsForTest
 import com.android.launcher3.model.DatabaseHelper
 import com.android.launcher3.model.DeviceGridState
-import com.android.launcher3.model.GridSizeMigrationUtil
+import com.android.launcher3.model.GridSizeMigrationDBController
+import com.android.launcher3.model.GridSizeMigrationLogic
 import com.android.launcher3.pm.UserCache
 import com.android.launcher3.provider.LauncherDbUtils
 import com.android.launcher3.util.rule.TestStabilityRule
@@ -130,22 +134,44 @@
         addItemsToDb(dbHelper.writableDatabase, dstGrid)
 
         LauncherDbUtils.SQLiteTransaction(dbHelper.writableDatabase).use {
-            GridSizeMigrationUtil.migrate(
-                dbHelper,
-                GridSizeMigrationUtil.DbReader(it.db, srcGrid.tableName, context),
-                GridSizeMigrationUtil.DbReader(it.db, dstGrid.tableName, context),
-                dstGrid.size.x,
-                dstGrid.size,
-                srcGrid.toGridState(),
-                dstGrid.toGridState(),
-            )
+            if (Flags.gridMigrationRefactor()) {
+                val gridSizeMigrationLogic = GridSizeMigrationLogic()
+                gridSizeMigrationLogic.migrateGrid(
+                    context,
+                    srcGrid.toGridState(),
+                    dstGrid.toGridState(),
+                    dbHelper,
+                    it.db,
+                )
+            } else {
+                GridSizeMigrationDBController.migrate(
+                    dbHelper,
+                    GridSizeMigrationDBController.DbReader(it.db, srcGrid.tableName, context),
+                    GridSizeMigrationDBController.DbReader(it.db, dstGrid.tableName, context),
+                    dstGrid.size.x,
+                    dstGrid.size,
+                    srcGrid.toGridState(),
+                    dstGrid.toGridState(),
+                )
+            }
             it.commit()
         }
         return readDb(dstGrid.tableName, dbHelper.readableDatabase)
     }
 
     @Test
-    fun runTestCase() {
+    @EnableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun runTestCaseRefactorFlagEnabled() {
+        runTestCase()
+    }
+
+    @Test
+    @DisableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun runTestCaseRefactorFlagDisabled() {
+        runTestCase()
+    }
+
+    private fun runTestCase() {
         val caseGenerator = ValidGridMigrationTestCaseGenerator(Random(SEED.toLong()))
         for (i in 0..SMALL_TEST_SIZE) {
             val testCase = caseGenerator.generateTestCase(isDestEmpty = true)
@@ -163,7 +189,18 @@
     }
 
     @Test
-    fun mergeBoards() {
+    @EnableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun mergeBoardsRefactorFlagEnabled() {
+        mergeBoards()
+    }
+
+    @Test
+    @DisableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun mergeBoardsRefactorFlagDisabled() {
+        mergeBoards()
+    }
+
+    private fun mergeBoards() {
         val caseGenerator = ValidGridMigrationTestCaseGenerator(Random(SEED.toLong()))
         for (i in 0..SMALL_TEST_SIZE) {
             val testCase = caseGenerator.generateTestCase(isDestEmpty = false)
@@ -187,7 +224,20 @@
     // This test takes about 4 minutes, there is no need to run it in presubmit.
     @Stability(flavors = TestStabilityRule.LOCAL or TestStabilityRule.PLATFORM_POSTSUBMIT)
     @Test
-    fun runExtensiveTestCases() {
+    @EnableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun runExtensiveTestCasesRefactorFlagEnabled() {
+        runExtensiveTestCases()
+    }
+
+    // This test takes about 4 minutes, there is no need to run it in presubmit.
+    @Stability(flavors = TestStabilityRule.LOCAL or TestStabilityRule.PLATFORM_POSTSUBMIT)
+    @Test
+    @DisableFlags(Flags.FLAG_GRID_MIGRATION_REFACTOR)
+    fun runExtensiveTestCasesRefactorFlagDisabled() {
+        runExtensiveTestCases()
+    }
+
+    private fun runExtensiveTestCases() {
         val caseGenerator = ValidGridMigrationTestCaseGenerator(Random(SEED.toLong()))
         for (i in 0..LARGE_TEST_SIZE) {
             val testCase = caseGenerator.generateTestCase(isDestEmpty = true)