Merge "Fix SnapshotRecord being overwritten when requesting a new one." into main
diff --git a/libs/WindowManager/Shell/src/com/android/wm/shell/startingsurface/StartingSurfaceDrawer.java b/libs/WindowManager/Shell/src/com/android/wm/shell/startingsurface/StartingSurfaceDrawer.java
index e2be153..1ce87ef 100644
--- a/libs/WindowManager/Shell/src/com/android/wm/shell/startingsurface/StartingSurfaceDrawer.java
+++ b/libs/WindowManager/Shell/src/com/android/wm/shell/startingsurface/StartingSurfaceDrawer.java
@@ -22,6 +22,7 @@
 import static android.window.StartingWindowRemovalInfo.DEFER_MODE_ROTATION;
 
 import android.annotation.CallSuper;
+import android.annotation.NonNull;
 import android.app.TaskInfo;
 import android.app.WindowConfiguration;
 import android.content.Context;
@@ -306,7 +307,7 @@
         @CallSuper
         protected void removeImmediately() {
             mRemoveExecutor.removeCallbacks(mScheduledRunnable);
-            mRecordManager.onRecordRemoved(mTaskId);
+            mRecordManager.onRecordRemoved(this, mTaskId); // removed only if still current
         }
     }
 
@@ -327,6 +328,11 @@
         }
 
         void addRecord(int taskId, StartingWindowRecord record) {
+            final StartingWindowRecord original = mStartingWindowRecords.get(taskId);
+            if (original != null && original != record) { // never tear down a re-added record
+                mTmpRemovalInfo.taskId = taskId; // ask the replaced record to remove itself
+                original.removeIfPossible(mTmpRemovalInfo, true /* immediately */);
+            }
             mStartingWindowRecords.put(taskId, record);
         }
 
@@ -346,8 +352,11 @@
             removeWindow(mTmpRemovalInfo, true/* immediately */);
         }
 
-        void onRecordRemoved(int taskId) {
-            mStartingWindowRecords.remove(taskId);
+        void onRecordRemoved(@NonNull StartingWindowRecord record, int taskId) {
+            final StartingWindowRecord currentRecord = mStartingWindowRecords.get(taskId);
+            if (currentRecord == record) { // skip if addRecord replaced this entry
+                mStartingWindowRecords.remove(taskId);
+            }
         }
 
         StartingWindowRecord getRecord(int taskId) {