Flip set subtraction to produce correct list of tasks to remove.

Fix: 380067701
Flag: com.android.launcher3.enable_refactor_task_thumbnail
Test: TasksRepositoryTest
Change-Id: I26933614241e622a11a440829e89023822c8041b
diff --git a/quickstep/src/com/android/quickstep/recents/data/TasksRepository.kt b/quickstep/src/com/android/quickstep/recents/data/TasksRepository.kt
index a315775..8a1b211 100644
--- a/quickstep/src/com/android/quickstep/recents/data/TasksRepository.kt
+++ b/quickstep/src/com/android/quickstep/recents/data/TasksRepository.kt
@@ -50,14 +50,15 @@
 
     override fun getAllTaskData(forceRefresh: Boolean): Flow<List<Task>> {
         if (forceRefresh) {
-            recentsModel.getTasks { result ->
+            recentsModel.getTasks { newTaskList ->
+                val oldTaskMap = tasks.value
                 val recentTasks =
-                    result
+                    newTaskList
                         .flatMap { groupTask -> groupTask.tasks }
                         .associateBy { it.key.id }
-                        .also { hashMap ->
+                        .also { newTaskMap ->
                             // Clean tasks that are not in the latest group tasks list.
-                            val tasksNoLongerVisible = hashMap.keys.subtract(tasks.value.keys)
+                            val tasksNoLongerVisible = oldTaskMap.keys.subtract(newTaskMap.keys)
                             removeTasks(tasksNoLongerVisible)
 
                             // Use pre-loaded thumbnail data and icon from the previous list.
@@ -66,12 +67,12 @@
                             val cache =
                                 taskRequests.keys
                                     .mapNotNull { key ->
-                                        val task = tasks.value[key] ?: return@mapNotNull null
+                                        val task = oldTaskMap[key] ?: return@mapNotNull null
                                         key to Pair(task.thumbnail, task.icon)
                                     }
                                     .toMap()
 
-                            hashMap.values.forEach { task ->
+                            newTaskMap.values.forEach { task ->
                                 task.thumbnail = task.thumbnail ?: cache[task.key.id]?.first
                                 task.icon = task.icon ?: cache[task.key.id]?.second
                             }
diff --git a/quickstep/tests/multivalentTests/src/com/android/quickstep/recents/data/TasksRepositoryTest.kt b/quickstep/tests/multivalentTests/src/com/android/quickstep/recents/data/TasksRepositoryTest.kt
index 624310b..ee1ec6e 100644
--- a/quickstep/tests/multivalentTests/src/com/android/quickstep/recents/data/TasksRepositoryTest.kt
+++ b/quickstep/tests/multivalentTests/src/com/android/quickstep/recents/data/TasksRepositoryTest.kt
@@ -37,7 +37,9 @@
 import kotlinx.coroutines.test.runTest
 import org.junit.Test
 import org.junit.runner.RunWith
+import org.mockito.Mockito.spy
 import org.mockito.kotlin.mock
+import org.mockito.kotlin.verify
 import org.mockito.kotlin.whenever
 
 @OptIn(ExperimentalCoroutinesApi::class)
@@ -56,7 +58,7 @@
     private val taskVisualsChangeNotifier = FakeTaskVisualsChangeNotifier()
     private val highResLoadingStateNotifier = FakeHighResLoadingStateNotifier()
     private val taskVisualsChangedDelegate =
-        TaskVisualsChangedDelegateImpl(taskVisualsChangeNotifier, highResLoadingStateNotifier)
+        spy(TaskVisualsChangedDelegateImpl(taskVisualsChangeNotifier, highResLoadingStateNotifier))
 
     private val dispatcher = UnconfinedTestDispatcher()
     private val testScope = TestScope(dispatcher)
@@ -131,6 +133,29 @@
         }
 
     @Test
+    fun getAllTaskData_clearsPreviouslyLoadedImagesForRemovedTasks() =
+        testScope.runTest {
+            // Setup data
+            recentsModel.seedTasks(defaultTaskList)
+            systemUnderTest.getAllTaskData(forceRefresh = true)
+            val bitmap1 = taskThumbnailDataSource.taskIdToBitmap[1]
+
+            // Load images for task 1
+            systemUnderTest.setVisibleTasks(setOf(1))
+            assertThat(systemUnderTest.getThumbnailById(1).first()!!.thumbnail).isEqualTo(bitmap1)
+
+            // Remove task 1 from "all data"
+            recentsModel.seedTasks(
+                defaultTaskList.filterNot { groupTask -> groupTask.tasks.any { it.key.id == 1 } }
+            )
+            systemUnderTest.getAllTaskData(forceRefresh = true)
+
+            // Assert task 1 was fully removed
+            assertThat(systemUnderTest.getThumbnailById(1).first()?.thumbnail).isNull()
+            verify(taskVisualsChangedDelegate).unregisterTaskThumbnailChangedCallback(tasks[1].key)
+        }
+
+    @Test
     fun getCurrentThumbnailByIdReturnsThumbnailWithLoadedThumbnails() =
         testScope.runTest {
             recentsModel.seedTasks(defaultTaskList)