Merge "Mark current tiles as auto-added" into main
diff --git a/packages/SystemUI/multivalentTests/src/com/android/systemui/qs/pipeline/domain/interactor/AutoAddInteractorTest.kt b/packages/SystemUI/multivalentTests/src/com/android/systemui/qs/pipeline/domain/interactor/AutoAddInteractorTest.kt
index 2ea12ef..8ae9172 100644
--- a/packages/SystemUI/multivalentTests/src/com/android/systemui/qs/pipeline/domain/interactor/AutoAddInteractorTest.kt
+++ b/packages/SystemUI/multivalentTests/src/com/android/systemui/qs/pipeline/domain/interactor/AutoAddInteractorTest.kt
@@ -26,6 +26,7 @@
 import com.android.systemui.qs.pipeline.domain.autoaddable.FakeAutoAddable
 import com.android.systemui.qs.pipeline.domain.model.AutoAddTracking
 import com.android.systemui.qs.pipeline.domain.model.AutoAddable
+import com.android.systemui.qs.pipeline.domain.model.TileModel
 import com.android.systemui.qs.pipeline.shared.TileSpec
 import com.android.systemui.qs.pipeline.shared.logging.QSPipelineLogger
 import com.android.systemui.util.mockito.any
@@ -65,6 +66,7 @@
         MockitoAnnotations.initMocks(this)
 
         whenever(currentTilesInteractor.userId).thenReturn(MutableStateFlow(USER))
+        whenever(currentTilesInteractor.currentTiles).thenReturn(MutableStateFlow(emptyList()))
     }
 
     @Test
@@ -201,6 +203,45 @@
             assertThat(autoAddedTiles).doesNotContain(SPEC)
         }
 
+    /** A tracked [AutoAddTracking.IfNotAdded] spec already in the current tiles gets marked. */
+    @Test
+    fun autoAddable_trackIfNotAdded_currentTile_markedAsAdded() =
+        testScope.runTest {
+            // The tile is part of the current tiles before the interactor is created.
+            val tile = FakeQSTile(USER).apply { tileSpec = SPEC.spec }
+            val tilesWithSpec = MutableStateFlow(listOf(TileModel(SPEC, tile)))
+            whenever(currentTilesInteractor.currentTiles).thenReturn(tilesWithSpec)
+
+            val markedSpecs by collectLastValue(autoAddRepository.autoAddedTiles(USER))
+            val trackedAddable = FakeAutoAddable(SPEC, AutoAddTracking.IfNotAdded(SPEC))
+            underTest = createInteractor(setOf(trackedAddable))
+            runCurrent()
+
+            assertThat(markedSpecs).contains(SPEC)
+        }
+
+    /** A tracked [AutoAddTracking.IfNotAdded] spec is marked once it joins the current tiles. */
+    @Test
+    fun autoAddable_trackIfNotAdded_tileAddedToCurrentTiles_markedAsAdded() =
+        testScope.runTest {
+            val tile = FakeQSTile(USER).apply { tileSpec = SPEC.spec }
+            val tilesFlow = MutableStateFlow(emptyList<TileModel>())
+            whenever(currentTilesInteractor.currentTiles).thenReturn(tilesFlow)
+
+            val markedSpecs by collectLastValue(autoAddRepository.autoAddedTiles(USER))
+            val trackedAddable = FakeAutoAddable(SPEC, AutoAddTracking.IfNotAdded(SPEC))
+            underTest = createInteractor(setOf(trackedAddable))
+            runCurrent()
+
+            // The current tiles are still empty, so the spec must not be marked yet.
+            assertThat(markedSpecs).doesNotContain(SPEC)
+
+            // The spec shows up in the current tiles after the interactor was created.
+            tilesFlow.value = listOf(TileModel(SPEC, tile))
+
+            assertThat(markedSpecs).contains(SPEC)
+        }
+
     private fun createInteractor(autoAddables: Set<AutoAddable>): AutoAddInteractor {
         return AutoAddInteractor(
                 autoAddables,
diff --git a/packages/SystemUI/src/com/android/systemui/qs/pipeline/domain/interactor/AutoAddInteractor.kt b/packages/SystemUI/src/com/android/systemui/qs/pipeline/domain/interactor/AutoAddInteractor.kt
index cde3835..187b444 100644
--- a/packages/SystemUI/src/com/android/systemui/qs/pipeline/domain/interactor/AutoAddInteractor.kt
+++ b/packages/SystemUI/src/com/android/systemui/qs/pipeline/domain/interactor/AutoAddInteractor.kt
@@ -35,6 +35,7 @@
 import kotlinx.coroutines.flow.collectLatest
 import kotlinx.coroutines.flow.emptyFlow
 import kotlinx.coroutines.flow.filterIsInstance
+import kotlinx.coroutines.flow.map
 import kotlinx.coroutines.flow.merge
 import kotlinx.coroutines.flow.stateIn
 import kotlinx.coroutines.flow.take
@@ -55,6 +56,7 @@
 ) : Dumpable {
 
     private val initialized = AtomicBoolean(false)
+    private lateinit var currentTilesInteractor: CurrentTilesInteractor
 
     /** Start collection of signals following the user from [currentTilesInteractor]. */
     fun init(currentTilesInteractor: CurrentTilesInteractor) {
@@ -62,58 +64,74 @@
             return
         }
 
+        this.currentTilesInteractor = currentTilesInteractor
         dumpManager.registerNormalDumpable(TAG, this)
 
         scope.launch {
             currentTilesInteractor.userId.collectLatest { userId ->
                 coroutineScope {
-                    val previouslyAdded = repository.autoAddedTiles(userId).stateIn(this)
-
-                    autoAddables
-                        .map { addable ->
-                            val autoAddSignal = addable.autoAddSignal(userId)
-                            when (val lifecycle = addable.autoAddTracking) {
-                                is AutoAddTracking.Always -> autoAddSignal
-                                is AutoAddTracking.Disabled -> emptyFlow()
-                                is AutoAddTracking.IfNotAdded -> {
-                                    if (lifecycle.spec !in previouslyAdded.value) {
-                                        autoAddSignal.filterIsInstance<AutoAddSignal.Add>().take(1)
-                                    } else {
-                                        emptyFlow()
-                                    }
-                                }
-                            }
-                        }
-                        .merge()
-                        .collect { signal ->
-                            when (signal) {
-                                is AutoAddSignal.Add -> {
-                                    if (signal.spec !in previouslyAdded.value) {
-                                        currentTilesInteractor.addTile(signal.spec, signal.position)
-                                        qsPipelineLogger.logTileAutoAdded(
-                                            userId,
-                                            signal.spec,
-                                            signal.position
-                                        )
-                                        repository.markTileAdded(userId, signal.spec)
-                                    }
-                                }
-                                is AutoAddSignal.Remove -> {
-                                    currentTilesInteractor.removeTiles(setOf(signal.spec))
-                                    qsPipelineLogger.logTileAutoRemoved(userId, signal.spec)
-                                    repository.unmarkTileAdded(userId, signal.spec)
-                                }
-                                is AutoAddSignal.RemoveTracking -> {
-                                    qsPipelineLogger.logTileUnmarked(userId, signal.spec)
-                                    repository.unmarkTileAdded(userId, signal.spec)
-                                }
-                            }
-                        }
+                    launch { collectAutoAddSignalsForUser(userId) }
+                    launch { markTrackIfNotAddedTilesThatAreCurrent(userId) }
                 }
             }
         }
     }
 
+    /** Marks current tiles tracked as [AutoAddTracking.IfNotAdded] as auto-added for [userId]. */
+    private suspend fun markTrackIfNotAddedTilesThatAreCurrent(userId: Int) {
+        // Kept as a set so the membership check on every emission is constant-time.
+        val trackedSpecs =
+            autoAddables.map { addable -> addable.autoAddTracking }
+                .filterIsInstance<AutoAddTracking.IfNotAdded>()
+                .map { tracking -> tracking.spec }.toSet()
+        currentTilesInteractor.currentTiles
+            .map { tiles -> tiles.map { tile -> tile.spec }.filter { it in trackedSpecs } }
+            .collect { specs ->
+                specs.forEach { spec -> repository.markTileAdded(userId, spec) }
+            }
+    }
+
+    /** Collects every auto-add signal for [userId] and applies it to tiles and tracking. */
+    private suspend fun CoroutineScope.collectAutoAddSignalsForUser(userId: Int) {
+        val alreadyAdded = repository.autoAddedTiles(userId).stateIn(this)
+        // One flow per auto-addable, restricted according to its tracking mode.
+        val signalFlows =
+            autoAddables.map { autoAddable ->
+                val signals = autoAddable.autoAddSignal(userId)
+                when (val tracking = autoAddable.autoAddTracking) {
+                    is AutoAddTracking.Always -> signals
+                    is AutoAddTracking.Disabled -> emptyFlow()
+                    is AutoAddTracking.IfNotAdded ->
+                        if (tracking.spec in alreadyAdded.value) {
+                            emptyFlow()
+                        } else {
+                            signals.filterIsInstance<AutoAddSignal.Add>().take(1)
+                        }
+                }
+            }
+
+        signalFlows.merge().collect { signal ->
+            when (signal) {
+                is AutoAddSignal.Add -> {
+                    if (signal.spec !in alreadyAdded.value) {
+                        currentTilesInteractor.addTile(signal.spec, signal.position)
+                        qsPipelineLogger.logTileAutoAdded(userId, signal.spec, signal.position)
+                        repository.markTileAdded(userId, signal.spec)
+                    }
+                }
+                is AutoAddSignal.Remove -> {
+                    currentTilesInteractor.removeTiles(setOf(signal.spec))
+                    qsPipelineLogger.logTileAutoRemoved(userId, signal.spec)
+                    repository.unmarkTileAdded(userId, signal.spec)
+                }
+                is AutoAddSignal.RemoveTracking -> {
+                    qsPipelineLogger.logTileUnmarked(userId, signal.spec)
+                    repository.unmarkTileAdded(userId, signal.spec)
+                }
+            }
+        }
+    }
+
     override fun dump(pw: PrintWriter, args: Array<out String>) {
         with(pw.asIndenting()) {
             println("AutoAddables:")