diff --git a/src/com/android/launcher3/model/GridSizeMigrationUtil.java b/src/com/android/launcher3/model/GridSizeMigrationUtil.java
index 30d2cfb..c40484d 100644
--- a/src/com/android/launcher3/model/GridSizeMigrationUtil.java
+++ b/src/com/android/launcher3/model/GridSizeMigrationUtil.java
@@ -40,6 +40,7 @@
 import androidx.annotation.NonNull;
 import androidx.annotation.VisibleForTesting;
 
+import com.android.launcher3.Flags;
 import com.android.launcher3.InvariantDeviceProfile;
 import com.android.launcher3.LauncherPrefs;
 import com.android.launcher3.LauncherSettings;
@@ -122,6 +123,16 @@
         if (!needsToMigrate(srcDeviceState, destDeviceState)) {
             return true;
         }
+
+        if (Flags.gridMigrationFix()
+                && srcDeviceState.getColumns().equals(destDeviceState.getColumns())
+                && srcDeviceState.getRows() < destDeviceState.getRows()) {
+            // Only use this strategy when comparing the previous grid to the new grid and the
+            // columns are the same and the destination has more rows
+            copyTable(source, TABLE_NAME, target.getWritableDatabase(), TABLE_NAME, context);
+            destDeviceState.writeToPrefs(context);
+            return true;
+        }
         copyTable(source, TABLE_NAME, target.getWritableDatabase(), TMP_TABLE, context);
 
         HashSet<String> validPackages = getValidPackages(context);
diff --git a/tests/assets/databases/GridMigrationTest/flagged_result5x5to5x8.db b/tests/assets/databases/GridMigrationTest/flagged_result5x5to5x8.db
new file mode 100644
index 0000000..8bea3ce
--- /dev/null
+++ b/tests/assets/databases/GridMigrationTest/flagged_result5x5to5x8.db
Binary files differ
diff --git a/tests/src/com/android/launcher3/model/GridMigrationTest.kt b/tests/src/com/android/launcher3/model/GridMigrationTest.kt
index eb8604e..da4a208 100644
--- a/tests/src/com/android/launcher3/model/GridMigrationTest.kt
+++ b/tests/src/com/android/launcher3/model/GridMigrationTest.kt
@@ -16,14 +16,18 @@
 
 package com.android.launcher3.model
 
+import android.platform.test.flag.junit.SetFlagsRule
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import androidx.test.filters.SmallTest
 import androidx.test.platform.app.InstrumentationRegistry
+import com.android.launcher3.Flags
 import com.android.launcher3.InvariantDeviceProfile.TYPE_PHONE
 import com.android.launcher3.LauncherSettings.Favorites.TABLE_NAME
 import com.android.launcher3.celllayout.board.CellLayoutBoard
 import com.android.launcher3.pm.UserCache
 import com.android.launcher3.util.rule.TestToPhoneFileCopier
+import com.android.launcher3.util.rule.setFlags
+import org.junit.Before
 import org.junit.Rule
 import org.junit.Test
 import org.junit.runner.RunWith
@@ -65,11 +69,24 @@
 class GridMigrationTest {
     private val DB_FILE = "test_launcher.db"
 
+    @JvmField
+    @Rule
+    val setFlagsRule = SetFlagsRule(SetFlagsRule.DefaultInitValueType.DEVICE_DEFAULT)
+
     // Copying the src db for all tests.
     @JvmField
     @Rule
     val fileCopier =
-        TestToPhoneFileCopier("databases/GridMigrationTest/$DB_FILE", "databases/$DB_FILE", true)
+        TestToPhoneFileCopier(
+            src = "databases/GridMigrationTest/$DB_FILE",
+            dest = "databases/$DB_FILE",
+            removeOnFinish = true
+        )
+
+    @Before
+    fun setup() {
+        setFlagsRule.setFlags(false, Flags.FLAG_GRID_MIGRATION_FIX)
+    }
 
     private fun migrate(src: GridMigrationData, dst: GridMigrationData) {
         GridSizeMigrationUtil.migrateGridIfNeeded(
@@ -86,8 +103,10 @@
      * same space in the db.
      */
     private fun validateDb(data: GridMigrationData) {
-        val cellLayoutBoard = CellLayoutBoard(data.gridState.columns, data.gridState.rows)
+        // The array size is just a big enough number to fit all the number of workspaces
+        val boards = Array(100) { CellLayoutBoard(data.gridState.columns, data.gridState.rows) }
         data.readEntries().forEach {
+            val cellLayoutBoard = boards[it.screenId]
             assert(cellLayoutBoard.isEmpty(it.cellX, it.cellY, it.spanX, it.spanY)) {
                 "Db has overlapping items"
             }
@@ -96,13 +115,13 @@
     }
 
     private fun compare(dst: GridMigrationData, target: GridMigrationData) {
-        val sortX = { it: GridSizeMigrationUtil.DbEntry -> it.cellX }
-        val sortY = { it: GridSizeMigrationUtil.DbEntry -> it.cellX }
+        val sort = compareBy<GridSizeMigrationUtil.DbEntry>({ it.cellX }, { it.cellY })
         val mapF = { it: GridSizeMigrationUtil.DbEntry ->
             EntryData(it.cellX, it.cellY, it.spanX, it.spanY, it.rank)
         }
-        val entriesDst = dst.readEntries().sortedBy(sortX).sortedBy(sortY).map(mapF)
-        val entriesTarget = target.readEntries().sortedBy(sortX).sortedBy(sortY).map(mapF)
+        val entriesDst = dst.readEntries().sortedWith(sort).map(mapF)
+        val entriesTarget = target.readEntries().sortedWith(sort).map(mapF)
+
         assert(entriesDst == entriesTarget) {
             "The elements on the dst database is not the same as in the target"
         }
@@ -128,9 +147,9 @@
     @Rule
     val result5x5to3x3 =
         TestToPhoneFileCopier(
-            "databases/GridMigrationTest/result5x5to3x3.db",
-            "databases/result5x5to3x3.db",
-            true
+            src = "databases/GridMigrationTest/result5x5to3x3.db",
+            dest = "databases/result5x5to3x3.db",
+            removeOnFinish = true
         )
 
     @Test
@@ -151,9 +170,9 @@
     @Rule
     val result5x5to4x7 =
         TestToPhoneFileCopier(
-            "databases/GridMigrationTest/result5x5to4x7.db",
-            "databases/result5x5to4x7.db",
-            true
+            src = "databases/GridMigrationTest/result5x5to4x7.db",
+            dest = "databases/result5x5to4x7.db",
+            removeOnFinish = true
         )
 
     @Test
@@ -174,9 +193,9 @@
     @Rule
     val result5x5to5x8 =
         TestToPhoneFileCopier(
-            "databases/GridMigrationTest/result5x5to5x8.db",
-            "databases/result5x5to5x8.db",
-            true
+            src = "databases/GridMigrationTest/result5x5to5x8.db",
+            dest = "databases/result5x5to5x8.db",
+            removeOnFinish = true
         )
 
     @Test
@@ -192,4 +211,32 @@
             target =
                 GridMigrationData("result5x5to5x8.db", DeviceGridState(5, 8, 5, TYPE_PHONE, ""))
         )
+
+    @JvmField
+    @Rule
+    val flaggedResult5x5to5x8 =
+        TestToPhoneFileCopier(
+            src = "databases/GridMigrationTest/flagged_result5x5to5x8.db",
+            dest = "databases/flagged_result5x5to5x8.db",
+            removeOnFinish = true
+        )
+
+    @Test
+    fun `flagged 5x5 to 5x8`() {
+        setFlagsRule.setFlags(true, Flags.FLAG_GRID_MIGRATION_FIX)
+        runTest(
+            src = GridMigrationData(DB_FILE, DeviceGridState(5, 5, 5, TYPE_PHONE, DB_FILE)),
+            dst =
+                GridMigrationData(
+                    null, // in memory db, to download a new db change null for the filename of the
+                    // db name to store it. Do not use existing names.
+                    DeviceGridState(5, 8, 5, TYPE_PHONE, "")
+                ),
+            target =
+                GridMigrationData(
+                    "flagged_result5x5to5x8.db",
+                    DeviceGridState(5, 8, 5, TYPE_PHONE, "")
+                )
+        )
+    }
 }
diff --git a/tests/src/com/android/launcher3/util/rule/TestToPhoneFileCopier.kt b/tests/src/com/android/launcher3/util/rule/TestToPhoneFileCopier.kt
index 72c4f16..d3516d1 100644
--- a/tests/src/com/android/launcher3/util/rule/TestToPhoneFileCopier.kt
+++ b/tests/src/com/android/launcher3/util/rule/TestToPhoneFileCopier.kt
@@ -49,7 +49,11 @@
         object : Statement() {
             override fun evaluate() {
                 before()
-                base.evaluate()
+                try {
+                    base.evaluate()
+                } finally {
+                    after()
+                }
             }
         }
 }
