Color Picker reset support (1/3)

Add color picker reset support. Refactored WallpaperColorsViewModel as
well to be provided from the WPP2 injector rather than from the
ViewModelProvider. This CL also moves optimistic update from the
repository layer to the interactor layer.

Bug: 267803746
Bug: 269339630
Bug: 269451870
Test: Unit tests, and manual tests including changing wallpaper to verify that
color options are updated, setting system color, setting and resetting
color with a combination of wallpaper and basic colors, resetting after
making changes to both wallpaper and color option

Change-Id: I27f47621e2f187449b1642ed0794c63efef9d37f
diff --git a/src/com/android/customization/model/color/ColorOption.java b/src/com/android/customization/model/color/ColorOption.java
index 26e025d..66a3a3c 100644
--- a/src/com/android/customization/model/color/ColorOption.java
+++ b/src/com/android/customization/model/color/ColorOption.java
@@ -107,6 +107,9 @@
         if (other == null) {
             return false;
         }
+        if (mStyle != other.getStyle()) {
+            return false;
+        }
         if (mIsDefault) {
             return other.isDefault() || TextUtils.isEmpty(other.getSerializedPackages())
                     || EMPTY_JSON.equals(other.getSerializedPackages());
diff --git a/src/com/android/customization/module/ThemePickerInjector.kt b/src/com/android/customization/module/ThemePickerInjector.kt
index 09466e3..1ed9f87 100644
--- a/src/com/android/customization/module/ThemePickerInjector.kt
+++ b/src/com/android/customization/module/ThemePickerInjector.kt
@@ -47,6 +47,7 @@
 import com.android.customization.picker.clock.ui.viewmodel.ClockSettingsViewModel
 import com.android.customization.picker.color.data.repository.ColorPickerRepositoryImpl
 import com.android.customization.picker.color.domain.interactor.ColorPickerInteractor
+import com.android.customization.picker.color.domain.interactor.ColorPickerSnapshotRestorer
 import com.android.customization.picker.color.ui.viewmodel.ColorPickerViewModel
 import com.android.customization.picker.notifications.data.repository.NotificationsRepository
 import com.android.customization.picker.notifications.domain.interactor.NotificationsInteractor
@@ -100,6 +101,7 @@
     private var notificationSectionViewModelFactory: NotificationSectionViewModel.Factory? = null
     private var colorPickerInteractor: ColorPickerInteractor? = null
     private var colorPickerViewModelFactory: ColorPickerViewModel.Factory? = null
+    private var colorPickerSnapshotRestorer: ColorPickerSnapshotRestorer? = null
     private var darkModeSnapshotRestorer: DarkModeSnapshotRestorer? = null
     private var themedIconSnapshotRestorer: ThemedIconSnapshotRestorer? = null
     private var themedIconInteractor: ThemedIconInteractor? = null
@@ -113,8 +115,7 @@
             ?: DefaultCustomizationSections(
                     getColorPickerViewModelFactory(
                         context = activity,
-                        wallpaperColorsViewModel =
-                            ViewModelProvider(activity)[WallpaperColorsViewModel::class.java],
+                        wallpaperColorsViewModel = getWallpaperColorsViewModel(),
                     ),
                     getKeyguardQuickAffordancePickerInteractor(activity),
                     getKeyguardQuickAffordancePickerViewModelFactory(activity),
@@ -190,6 +191,8 @@
             this[KEY_DARK_MODE_SNAPSHOT_RESTORER] = getDarkModeSnapshotRestorer(context)
             this[KEY_THEMED_ICON_SNAPSHOT_RESTORER] = getThemedIconSnapshotRestorer(context)
             this[KEY_APP_GRID_SNAPSHOT_RESTORER] = getGridSnapshotRestorer(context)
+            this[KEY_COLOR_PICKER_SNAPSHOT_RESTORER] =
+                getColorPickerSnapshotRestorer(context, getWallpaperColorsViewModel())
         }
     }
 
@@ -346,7 +349,12 @@
         wallpaperColorsViewModel: WallpaperColorsViewModel,
     ): ColorPickerInteractor {
         return colorPickerInteractor
-            ?: ColorPickerInteractor(ColorPickerRepositoryImpl(context, wallpaperColorsViewModel))
+            ?: ColorPickerInteractor(
+                    repository = ColorPickerRepositoryImpl(context, wallpaperColorsViewModel),
+                    snapshotRestorer = {
+                        getColorPickerSnapshotRestorer(context, wallpaperColorsViewModel)
+                    }
+                )
                 .also { colorPickerInteractor = it }
     }
 
@@ -362,6 +370,17 @@
                 .also { colorPickerViewModelFactory = it }
     }
 
+    private fun getColorPickerSnapshotRestorer(
+        context: Context,
+        wallpaperColorsViewModel: WallpaperColorsViewModel,
+    ): ColorPickerSnapshotRestorer {
+        return colorPickerSnapshotRestorer
+            ?: ColorPickerSnapshotRestorer(
+                    getColorPickerInteractor(context, wallpaperColorsViewModel)
+                )
+                .also { colorPickerSnapshotRestorer = it }
+    }
+
     fun getDarkModeSnapshotRestorer(
         context: Context,
     ): DarkModeSnapshotRestorer {
@@ -460,6 +479,8 @@
         private val KEY_THEMED_ICON_SNAPSHOT_RESTORER = KEY_DARK_MODE_SNAPSHOT_RESTORER + 1
         @JvmStatic
         private val KEY_APP_GRID_SNAPSHOT_RESTORER = KEY_THEMED_ICON_SNAPSHOT_RESTORER + 1
+        @JvmStatic
+        private val KEY_COLOR_PICKER_SNAPSHOT_RESTORER = KEY_APP_GRID_SNAPSHOT_RESTORER + 1
 
         /**
          * When this injector is overridden, this is the minimal value that should be used by
@@ -467,6 +488,6 @@
          *
          * It should always be greater than the biggest restorer key.
          */
-        @JvmStatic protected val MIN_SNAPSHOT_RESTORER_KEY = KEY_APP_GRID_SNAPSHOT_RESTORER + 1
+        @JvmStatic protected val MIN_SNAPSHOT_RESTORER_KEY = KEY_COLOR_PICKER_SNAPSHOT_RESTORER + 1
     }
 }
diff --git a/src/com/android/customization/picker/clock/ui/fragment/ClockSettingsFragment.kt b/src/com/android/customization/picker/clock/ui/fragment/ClockSettingsFragment.kt
index 976907b..2ba03bd 100644
--- a/src/com/android/customization/picker/clock/ui/fragment/ClockSettingsFragment.kt
+++ b/src/com/android/customization/picker/clock/ui/fragment/ClockSettingsFragment.kt
@@ -25,7 +25,6 @@
 import com.android.customization.module.ThemePickerInjector
 import com.android.customization.picker.clock.ui.binder.ClockSettingsBinder
 import com.android.wallpaper.R
-import com.android.wallpaper.model.WallpaperColorsViewModel
 import com.android.wallpaper.module.InjectorProvider
 import com.android.wallpaper.picker.AppbarFragment
 import com.android.wallpaper.picker.customization.ui.binder.ScreenPreviewBinder
@@ -63,7 +62,7 @@
         val injector = InjectorProvider.getInjector() as ThemePickerInjector
 
         val lockScreenView: CardView = view.requireViewById(R.id.lock_preview)
-        val colorViewModel = ViewModelProvider(activity)[WallpaperColorsViewModel::class.java]
+        val colorViewModel = injector.getWallpaperColorsViewModel()
         val displayUtils = injector.getDisplayUtils(context)
         ScreenPreviewBinder.bind(
                 activity = activity,
diff --git a/src/com/android/customization/picker/color/data/repository/ColorPickerRepository.kt b/src/com/android/customization/picker/color/data/repository/ColorPickerRepository.kt
index 0e65577..1a0f5a9 100644
--- a/src/com/android/customization/picker/color/data/repository/ColorPickerRepository.kt
+++ b/src/com/android/customization/picker/color/data/repository/ColorPickerRepository.kt
@@ -25,15 +25,13 @@
  * system color.
  */
 interface ColorPickerRepository {
-    /**
-     * The newly selected color option for overwriting the current active option during an
-     * optimistic update, the value is null when no overwriting is needed
-     */
-    val activeColorOption: Flow<ColorOptionModel?>
 
     /** List of wallpaper and preset color options on the device, categorized by Color Type */
     val colorOptions: Flow<Map<ColorType, List<ColorOptionModel>>>
 
     /** Selects a color option with optimistic update */
-    fun select(colorOptionModel: ColorOptionModel)
+    suspend fun select(colorOptionModel: ColorOptionModel)
+
+    /** Returns the current selected color option based on system settings */
+    fun getCurrentColorOption(): ColorOptionModel
 }
diff --git a/src/com/android/customization/picker/color/data/repository/ColorPickerRepositoryImpl.kt b/src/com/android/customization/picker/color/data/repository/ColorPickerRepositoryImpl.kt
index d6d5060..70382c7 100644
--- a/src/com/android/customization/picker/color/data/repository/ColorPickerRepositoryImpl.kt
+++ b/src/com/android/customization/picker/color/data/repository/ColorPickerRepositoryImpl.kt
@@ -20,6 +20,8 @@
 import android.content.Context
 import android.util.Log
 import com.android.customization.model.CustomizationManager
+import com.android.customization.model.ResourceConstants.OVERLAY_CATEGORY_COLOR
+import com.android.customization.model.ResourceConstants.OVERLAY_CATEGORY_SYSTEM_PALETTE
 import com.android.customization.model.color.ColorBundle
 import com.android.customization.model.color.ColorCustomizationManager
 import com.android.customization.model.color.ColorOption
@@ -27,11 +29,10 @@
 import com.android.customization.model.theme.OverlayManagerCompat
 import com.android.customization.picker.color.shared.model.ColorOptionModel
 import com.android.customization.picker.color.shared.model.ColorType
+import com.android.systemui.monet.Style
 import com.android.wallpaper.model.WallpaperColorsViewModel
 import kotlinx.coroutines.flow.Flow
-import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
-import kotlinx.coroutines.flow.asStateFlow
 import kotlinx.coroutines.flow.combine
 import kotlinx.coroutines.flow.map
 import kotlinx.coroutines.suspendCancellableCoroutine
@@ -50,17 +51,11 @@
     private val colorManager: ColorCustomizationManager =
         ColorCustomizationManager.getInstance(context, OverlayManagerCompat(context))
 
-    private val _activeColorOption = MutableStateFlow<ColorOptionModel?>(null)
-    override val activeColorOption: StateFlow<ColorOptionModel?> = _activeColorOption.asStateFlow()
-
     override val colorOptions: Flow<Map<ColorType, List<ColorOptionModel>>> =
-        combine(activeColorOption, homeWallpaperColors, lockWallpaperColors) {
-                activeOption,
-                homeColors,
-                lockColors ->
-                Triple(activeOption, homeColors, lockColors)
+        combine(homeWallpaperColors, lockWallpaperColors) { homeColors, lockColors ->
+                homeColors to lockColors
             }
-            .map { (activeOption, homeColors, lockColors) ->
+            .map { (homeColors, lockColors) ->
                 suspendCancellableCoroutine { continuation ->
                     colorManager.setWallpaperColors(homeColors, lockColors)
                     colorManager.fetchOptions(
@@ -73,9 +68,8 @@
                                 options?.forEach { option ->
                                     when (option) {
                                         is ColorSeedOption ->
-                                            wallpaperColorOptions.add(option.toModel(activeOption))
-                                        is ColorBundle ->
-                                            presetColorOptions.add(option.toModel(activeOption))
+                                            wallpaperColorOptions.add(option.toModel())
+                                        is ColorBundle -> presetColorOptions.add(option.toModel())
                                     }
                                 }
                                 continuation.resumeWith(
@@ -102,33 +96,48 @@
                 }
             }
 
-    override fun select(colorOptionModel: ColorOptionModel) {
-        _activeColorOption.value = colorOptionModel
-        val colorOption: ColorOption = colorOptionModel.colorOption
-        colorManager.apply(
-            colorOption,
-            object : CustomizationManager.Callback {
-                override fun onSuccess() {
-                    _activeColorOption.value = null
-                }
+    override suspend fun select(colorOptionModel: ColorOptionModel) =
+        suspendCancellableCoroutine { continuation ->
+            colorManager.apply(
+                colorOptionModel.colorOption,
+                object : CustomizationManager.Callback {
+                    override fun onSuccess() {
+                        continuation.resumeWith(Result.success(Unit))
+                    }
 
-                override fun onError(throwable: Throwable?) {
-                    _activeColorOption.value = null
-                    Log.w(TAG, "Apply theme with error", throwable)
+                    override fun onError(throwable: Throwable?) {
+                        Log.w(TAG, "Apply theme with error", throwable)
+                        continuation.resumeWith(
+                            Result.failure(throwable ?: Throwable("Error loading theme bundles"))
+                        )
+                    }
                 }
-            }
+            )
+        }
+
+    override fun getCurrentColorOption(): ColorOptionModel {
+        val overlays = colorManager.currentOverlays
+        return ColorOptionModel(
+            colorOption =
+                // Does not matter whether ColorSeedOption or ColorBundle builder is used here
+                // because to apply the color, one just needs a generic ColorOption
+                ColorSeedOption.Builder()
+                    .addOverlayPackage(
+                        OVERLAY_CATEGORY_SYSTEM_PALETTE,
+                        overlays[OVERLAY_CATEGORY_SYSTEM_PALETTE]
+                    )
+                    .addOverlayPackage(OVERLAY_CATEGORY_COLOR, overlays[OVERLAY_CATEGORY_COLOR])
+                    .setSource(colorManager.currentColorSource)
+                    .setStyle(Style.valueOf(colorManager.currentStyle))
+                    .build(),
+            isSelected = false,
         )
     }
 
-    private fun ColorOption.toModel(activeColorOption: ColorOptionModel?): ColorOptionModel {
+    private fun ColorOption.toModel(): ColorOptionModel {
         return ColorOptionModel(
             colorOption = this,
-            isSelected =
-                if (activeColorOption != null) {
-                    isEquivalent(activeColorOption.colorOption)
-                } else {
-                    isActive(colorManager)
-                },
+            isSelected = isActive(colorManager),
         )
     }
 
diff --git a/src/com/android/customization/picker/color/data/repository/FakeColorPickerRepository.kt b/src/com/android/customization/picker/color/data/repository/FakeColorPickerRepository.kt
index d2a25bc..7dab2d8 100644
--- a/src/com/android/customization/picker/color/data/repository/FakeColorPickerRepository.kt
+++ b/src/com/android/customization/picker/color/data/repository/FakeColorPickerRepository.kt
@@ -26,118 +26,104 @@
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.asStateFlow
 
-class FakeColorPickerRepository(context: Context) : ColorPickerRepository {
-    override val activeColorOption: StateFlow<ColorOptionModel?> =
-        MutableStateFlow<ColorOptionModel?>(null)
+class FakeColorPickerRepository(private val context: Context) : ColorPickerRepository {
 
-    private val colorSeedOption0: ColorSeedOption =
-        ColorSeedOption.Builder()
-            .setLightColors(
-                intArrayOf(
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT
-                )
-            )
-            .setDarkColors(
-                intArrayOf(
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT
-                )
-            )
-            .setIndex(0)
-            .build()
-    private val colorSeedOption1: ColorSeedOption =
-        ColorSeedOption.Builder()
-            .setLightColors(
-                intArrayOf(
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT
-                )
-            )
-            .setDarkColors(
-                intArrayOf(
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT
-                )
-            )
-            .setIndex(1)
-            .build()
-    private val colorSeedOption2: ColorSeedOption =
-        ColorSeedOption.Builder()
-            .setLightColors(
-                intArrayOf(
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT
-                )
-            )
-            .setDarkColors(
-                intArrayOf(
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT
-                )
-            )
-            .setIndex(2)
-            .build()
-    private val colorSeedOption3: ColorSeedOption =
-        ColorSeedOption.Builder()
-            .setLightColors(
-                intArrayOf(
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT
-                )
-            )
-            .setDarkColors(
-                intArrayOf(
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT,
-                    Color.TRANSPARENT
-                )
-            )
-            .setIndex(3)
-            .build()
-    private val colorBundle0: ColorBundle = ColorBundle.Builder().setIndex(0).build(context)
-    private val colorBundle1: ColorBundle = ColorBundle.Builder().setIndex(1).build(context)
-    private val colorBundle2: ColorBundle = ColorBundle.Builder().setIndex(2).build(context)
-    private val colorBundle3: ColorBundle = ColorBundle.Builder().setIndex(3).build(context)
+    private lateinit var selectedColorOption: ColorOptionModel
 
     private val _colorOptions =
         MutableStateFlow(
-            mapOf(
-                ColorType.WALLPAPER_COLOR to
-                    listOf(
-                        ColorOptionModel(colorOption = colorSeedOption0, isSelected = true),
-                        ColorOptionModel(colorOption = colorSeedOption1, isSelected = false),
-                        ColorOptionModel(colorOption = colorSeedOption2, isSelected = false),
-                        ColorOptionModel(colorOption = colorSeedOption3, isSelected = false)
-                    ),
-                ColorType.BASIC_COLOR to
-                    listOf(
-                        ColorOptionModel(colorOption = colorBundle0, isSelected = false),
-                        ColorOptionModel(colorOption = colorBundle1, isSelected = false),
-                        ColorOptionModel(colorOption = colorBundle2, isSelected = false),
-                        ColorOptionModel(colorOption = colorBundle3, isSelected = false)
-                    )
+            mapOf<ColorType, List<ColorOptionModel>>(
+                ColorType.WALLPAPER_COLOR to listOf(),
+                ColorType.BASIC_COLOR to listOf()
             )
         )
     override val colorOptions: StateFlow<Map<ColorType, List<ColorOptionModel>>> =
         _colorOptions.asStateFlow()
 
-    override fun select(colorOptionModel: ColorOptionModel) {
+    init {
+        setOptions(4, 4, ColorType.WALLPAPER_COLOR, 0)
+    }
+
+    fun setOptions(
+        numWallpaperOptions: Int,
+        numPresetOptions: Int,
+        selectedColorOptionType: ColorType,
+        selectedColorOptionIndex: Int
+    ) {
+        _colorOptions.value =
+            mapOf(
+                ColorType.WALLPAPER_COLOR to
+                    buildList {
+                        repeat(times = numWallpaperOptions) { index ->
+                            val isSelected =
+                                selectedColorOptionType == ColorType.WALLPAPER_COLOR &&
+                                    selectedColorOptionIndex == index
+                            val colorOption =
+                                ColorOptionModel(
+                                    colorOption = buildWallpaperOption(index),
+                                    isSelected = isSelected,
+                                )
+                            if (isSelected) {
+                                selectedColorOption = colorOption
+                            }
+                            add(colorOption)
+                        }
+                    },
+                ColorType.BASIC_COLOR to
+                    buildList {
+                        repeat(times = numPresetOptions) { index ->
+                            val isSelected =
+                                selectedColorOptionType == ColorType.BASIC_COLOR &&
+                                    selectedColorOptionIndex == index
+                            val colorOption =
+                                ColorOptionModel(
+                                    colorOption = buildPresetOption(index),
+                                    isSelected =
+                                        selectedColorOptionType == ColorType.BASIC_COLOR &&
+                                            selectedColorOptionIndex == index,
+                                )
+                            if (isSelected) {
+                                selectedColorOption = colorOption
+                            }
+                            add(colorOption)
+                        }
+                    }
+            )
+    }
+
+    private fun buildPresetOption(index: Int): ColorBundle {
+        return ColorBundle.Builder()
+            .addOverlayPackage("TEST_PACKAGE_TYPE", "preset_color")
+            .addOverlayPackage("TEST_PACKAGE_INDEX", "$index")
+            .setIndex(index)
+            .build(context)
+    }
+
+    private fun buildWallpaperOption(index: Int): ColorSeedOption {
+        return ColorSeedOption.Builder()
+            .setLightColors(
+                intArrayOf(
+                    Color.TRANSPARENT,
+                    Color.TRANSPARENT,
+                    Color.TRANSPARENT,
+                    Color.TRANSPARENT
+                )
+            )
+            .setDarkColors(
+                intArrayOf(
+                    Color.TRANSPARENT,
+                    Color.TRANSPARENT,
+                    Color.TRANSPARENT,
+                    Color.TRANSPARENT
+                )
+            )
+            .addOverlayPackage("TEST_PACKAGE_TYPE", "wallpaper_color")
+            .addOverlayPackage("TEST_PACKAGE_INDEX", "$index")
+            .setIndex(index)
+            .build()
+    }
+
+    override suspend fun select(colorOptionModel: ColorOptionModel) {
         val colorOptions = _colorOptions.value
         val wallpaperColorOptions = colorOptions[ColorType.WALLPAPER_COLOR]!!
         val newWallpaperColorOptions = buildList {
@@ -168,6 +154,8 @@
             )
     }
 
+    override fun getCurrentColorOption(): ColorOptionModel = selectedColorOption
+
     private fun ColorOptionModel.testEquals(other: Any?): Boolean {
         if (other == null) {
             return false
diff --git a/src/com/android/customization/picker/color/domain/interactor/ColorPickerInteractor.kt b/src/com/android/customization/picker/color/domain/interactor/ColorPickerInteractor.kt
index ce453c3..a932067 100644
--- a/src/com/android/customization/picker/color/domain/interactor/ColorPickerInteractor.kt
+++ b/src/com/android/customization/picker/color/domain/interactor/ColorPickerInteractor.kt
@@ -16,17 +16,57 @@
  */
 package com.android.customization.picker.color.domain.interactor
 
+import androidx.annotation.VisibleForTesting
 import com.android.customization.picker.color.data.repository.ColorPickerRepository
 import com.android.customization.picker.color.shared.model.ColorOptionModel
+import javax.inject.Provider
+import kotlinx.coroutines.flow.MutableStateFlow
+import kotlinx.coroutines.flow.combine
 
 /** Single entry-point for all application state and business logic related to system color. */
 class ColorPickerInteractor(
     private val repository: ColorPickerRepository,
+    private val snapshotRestorer: Provider<ColorPickerSnapshotRestorer>,
 ) {
-    /** List of wallpaper and preset color options on the device, categorized by Color Type */
-    val colorOptions = repository.colorOptions
+    /**
+     * The newly selected color option for overwriting the current active option during an
+     * optimistic update, the value is set to null when update fails
+     */
+    @VisibleForTesting private val activeColorOption = MutableStateFlow<ColorOptionModel?>(null)
 
-    fun select(colorOptionModel: ColorOptionModel) {
-        repository.select(colorOptionModel)
+    /** List of wallpaper and preset color options on the device, categorized by Color Type */
+    val colorOptions =
+        combine(repository.colorOptions, activeColorOption) { colorOptions, activeOption ->
+            colorOptions
+                .map { colorTypeEntry ->
+                    colorTypeEntry.key to
+                        colorTypeEntry.value.map { colorOptionModel ->
+                            val isSelected =
+                                if (activeOption != null) {
+                                    colorOptionModel.colorOption.isEquivalent(
+                                        activeOption.colorOption
+                                    )
+                                } else {
+                                    colorOptionModel.isSelected
+                                }
+                            ColorOptionModel(
+                                colorOption = colorOptionModel.colorOption,
+                                isSelected = isSelected
+                            )
+                        }
+                }
+                .toMap()
+        }
+
+    suspend fun select(colorOptionModel: ColorOptionModel) {
+        activeColorOption.value = colorOptionModel
+        try {
+            repository.select(colorOptionModel)
+            snapshotRestorer.get().storeSnapshot(colorOptionModel)
+        } catch (e: Exception) {
+            activeColorOption.value = null
+        }
     }
+
+    fun getCurrentColorOption(): ColorOptionModel = repository.getCurrentColorOption()
 }
diff --git a/src/com/android/customization/picker/color/domain/interactor/ColorPickerSnapshotRestorer.kt b/src/com/android/customization/picker/color/domain/interactor/ColorPickerSnapshotRestorer.kt
new file mode 100644
index 0000000..1635e01
--- /dev/null
+++ b/src/com/android/customization/picker/color/domain/interactor/ColorPickerSnapshotRestorer.kt
@@ -0,0 +1,80 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package com.android.customization.picker.color.domain.interactor
+
+import android.util.Log
+import com.android.customization.picker.color.shared.model.ColorOptionModel
+import com.android.wallpaper.picker.undo.domain.interactor.SnapshotRestorer
+import com.android.wallpaper.picker.undo.domain.interactor.SnapshotStore
+import com.android.wallpaper.picker.undo.shared.model.RestorableSnapshot
+
+/** Handles state restoration for the color picker system. */
+class ColorPickerSnapshotRestorer(
+    private val interactor: ColorPickerInteractor,
+) : SnapshotRestorer {
+
+    private lateinit var snapshotStore: SnapshotStore
+    private var originalOption: ColorOptionModel? = null
+
+    fun storeSnapshot(colorOptionModel: ColorOptionModel) {
+        snapshotStore.store(snapshot(colorOptionModel))
+    }
+
+    override suspend fun setUpSnapshotRestorer(
+        store: SnapshotStore,
+    ): RestorableSnapshot {
+        snapshotStore = store
+        originalOption = interactor.getCurrentColorOption()
+        return snapshot(originalOption)
+    }
+
+    override suspend fun restoreToSnapshot(snapshot: RestorableSnapshot) {
+        val optionPackagesFromSnapshot: String? = snapshot.args[KEY_COLOR_OPTION_PACKAGES]
+        originalOption?.let { optionToRestore ->
+            if (
+                optionToRestore.colorOption.serializedPackages != optionPackagesFromSnapshot ||
+                    optionToRestore.colorOption.style.toString() !=
+                        snapshot.args[KEY_COLOR_OPTION_STYLE]
+            ) {
+                Log.wtf(
+                    TAG,
+                    """ Original packages does not match snapshot packages to restore to. The 
+                        | current implementation doesn't support undo, only a reset back to the 
+                        | original color option.""".trimMargin(),
+                )
+            }
+
+            interactor.select(optionToRestore)
+        }
+    }
+
+    private fun snapshot(colorOptionModel: ColorOptionModel? = null): RestorableSnapshot {
+        val snapshotMap = mutableMapOf<String, String>()
+        colorOptionModel?.let {
+            snapshotMap[KEY_COLOR_OPTION_PACKAGES] = colorOptionModel.colorOption.serializedPackages
+            snapshotMap[KEY_COLOR_OPTION_STYLE] = colorOptionModel.colorOption.style.toString()
+        }
+        return RestorableSnapshot(snapshotMap)
+    }
+
+    companion object {
+        private const val TAG = "ColorPickerSnapshotRestorer"
+        private const val KEY_COLOR_OPTION_PACKAGES = "color_packages"
+        private const val KEY_COLOR_OPTION_STYLE = "color_style"
+    }
+}
diff --git a/src/com/android/customization/picker/color/ui/fragment/ColorPickerFragment.kt b/src/com/android/customization/picker/color/ui/fragment/ColorPickerFragment.kt
index 416faa6..fa7a344 100644
--- a/src/com/android/customization/picker/color/ui/fragment/ColorPickerFragment.kt
+++ b/src/com/android/customization/picker/color/ui/fragment/ColorPickerFragment.kt
@@ -27,7 +27,6 @@
 import com.android.customization.module.ThemePickerInjector
 import com.android.customization.picker.color.ui.binder.ColorPickerBinder
 import com.android.wallpaper.R
-import com.android.wallpaper.model.WallpaperColorsViewModel
 import com.android.wallpaper.module.InjectorProvider
 import com.android.wallpaper.picker.AppbarFragment
 import com.android.wallpaper.picker.customization.ui.binder.ScreenPreviewBinder
@@ -63,7 +62,7 @@
         val homeScreenView: CardView = view.requireViewById(R.id.home_preview)
         val wallpaperInfoFactory = injector.getCurrentWallpaperInfoFactory(requireContext())
         val displayUtils: DisplayUtils = injector.getDisplayUtils(requireContext())
-        val wcViewModel = ViewModelProvider(requireActivity())[WallpaperColorsViewModel::class.java]
+        val wcViewModel = injector.getWallpaperColorsViewModel()
         ColorPickerBinder.bind(
             view = view,
             viewModel =
diff --git a/src/com/android/customization/picker/color/ui/viewmodel/ColorPickerViewModel.kt b/src/com/android/customization/picker/color/ui/viewmodel/ColorPickerViewModel.kt
index 7eb5488..5e1e542 100644
--- a/src/com/android/customization/picker/color/ui/viewmodel/ColorPickerViewModel.kt
+++ b/src/com/android/customization/picker/color/ui/viewmodel/ColorPickerViewModel.kt
@@ -19,6 +19,7 @@
 import android.content.Context
 import androidx.lifecycle.ViewModel
 import androidx.lifecycle.ViewModelProvider
+import androidx.lifecycle.viewModelScope
 import com.android.customization.model.color.ColorBundle
 import com.android.customization.model.color.ColorSeedOption
 import com.android.customization.picker.color.domain.interactor.ColorPickerInteractor
@@ -30,6 +31,7 @@
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.combine
 import kotlinx.coroutines.flow.map
+import kotlinx.coroutines.launch
 
 /** Models UI state for a color picker experience. */
 class ColorPickerViewModel
@@ -90,7 +92,7 @@
                         if (colorOptionModel.isSelected) {
                             null
                         } else {
-                            { interactor.select(colorOptionModel) }
+                            { viewModelScope.launch { interactor.select(colorOptionModel) } }
                         }
                 )
             }
@@ -115,7 +117,7 @@
                         if (colorOptionModel.isSelected) {
                             null
                         } else {
-                            { interactor.select(colorOptionModel) }
+                            { viewModelScope.launch { interactor.select(colorOptionModel) } }
                         },
                 )
             }
diff --git a/tests/src/com/android/customization/model/picker/color/domain/interactor/ColorPickerInteractorTest.kt b/tests/src/com/android/customization/model/picker/color/domain/interactor/ColorPickerInteractorTest.kt
index 81ef55f..885d5a9 100644
--- a/tests/src/com/android/customization/model/picker/color/domain/interactor/ColorPickerInteractorTest.kt
+++ b/tests/src/com/android/customization/model/picker/color/domain/interactor/ColorPickerInteractorTest.kt
@@ -21,10 +21,13 @@
 import androidx.test.platform.app.InstrumentationRegistry
 import com.android.customization.picker.color.data.repository.FakeColorPickerRepository
 import com.android.customization.picker.color.domain.interactor.ColorPickerInteractor
+import com.android.customization.picker.color.domain.interactor.ColorPickerSnapshotRestorer
 import com.android.customization.picker.color.shared.model.ColorType
+import com.android.wallpaper.testing.FakeSnapshotStore
 import com.android.wallpaper.testing.collectLastValue
 import com.google.common.truth.Truth.assertThat
 import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.runBlocking
 import kotlinx.coroutines.test.runTest
 import org.junit.Before
 import org.junit.Test
@@ -36,16 +39,26 @@
 @RunWith(JUnit4::class)
 class ColorPickerInteractorTest {
     private lateinit var underTest: ColorPickerInteractor
+    private lateinit var repository: FakeColorPickerRepository
+    private lateinit var store: FakeSnapshotStore
 
     private lateinit var context: Context
 
     @Before
     fun setUp() {
         context = InstrumentationRegistry.getInstrumentation().targetContext
+        repository = FakeColorPickerRepository(context = context)
+        store = FakeSnapshotStore()
         underTest =
             ColorPickerInteractor(
-                repository = FakeColorPickerRepository(context = context),
+                repository = repository,
+                snapshotRestorer = {
+                    ColorPickerSnapshotRestorer(interactor = underTest).apply {
+                        runBlocking { setUpSnapshotRestorer(store = store) }
+                    }
+                },
             )
+        repository.setOptions(4, 4, ColorType.WALLPAPER_COLOR, 0)
     }
 
     @Test
@@ -66,4 +79,40 @@
         val presetColorOptionModelAfter = colorOptions()?.get(ColorType.BASIC_COLOR)?.get(1)
         assertThat(presetColorOptionModelAfter?.isSelected).isTrue()
     }
+
+    @Test
+    fun snapshotRestorer_updatesSnapshot() = runTest {
+        val colorOptions = collectLastValue(underTest.colorOptions)
+        val wallpaperColorOptionModel0 = colorOptions()?.get(ColorType.WALLPAPER_COLOR)?.get(0)
+        val wallpaperColorOptionModel1 = colorOptions()?.get(ColorType.WALLPAPER_COLOR)?.get(1)
+        assertThat(wallpaperColorOptionModel0?.isSelected).isTrue()
+        assertThat(wallpaperColorOptionModel1?.isSelected).isFalse()
+
+        val storedSnapshot = store.retrieve()
+        wallpaperColorOptionModel1?.let { underTest.select(it) }
+        val wallpaperColorOptionModel0After = colorOptions()?.get(ColorType.WALLPAPER_COLOR)?.get(0)
+        val wallpaperColorOptionModel1After = colorOptions()?.get(ColorType.WALLPAPER_COLOR)?.get(1)
+        assertThat(wallpaperColorOptionModel0After?.isSelected).isFalse()
+        assertThat(wallpaperColorOptionModel1After?.isSelected).isTrue()
+
+        assertThat(store.retrieve()).isNotEqualTo(storedSnapshot)
+    }
+
+    @Test
+    fun snapshotRestorer_doesNotUpdateSnapshotOnExternalUpdates() = runTest {
+        val colorOptions = collectLastValue(underTest.colorOptions)
+        val wallpaperColorOptionModel0 = colorOptions()?.get(ColorType.WALLPAPER_COLOR)?.get(0)
+        val wallpaperColorOptionModel1 = colorOptions()?.get(ColorType.WALLPAPER_COLOR)?.get(1)
+        assertThat(wallpaperColorOptionModel0?.isSelected).isTrue()
+        assertThat(wallpaperColorOptionModel1?.isSelected).isFalse()
+
+        val storedSnapshot = store.retrieve()
+        repository.setOptions(4, 4, ColorType.WALLPAPER_COLOR, 1)
+        val wallpaperColorOptionModel0After = colorOptions()?.get(ColorType.WALLPAPER_COLOR)?.get(0)
+        val wallpaperColorOptionModel1After = colorOptions()?.get(ColorType.WALLPAPER_COLOR)?.get(1)
+        assertThat(wallpaperColorOptionModel0After?.isSelected).isFalse()
+        assertThat(wallpaperColorOptionModel1After?.isSelected).isTrue()
+
+        assertThat(store.retrieve()).isEqualTo(storedSnapshot)
+    }
 }
diff --git a/tests/src/com/android/customization/model/picker/color/domain/interactor/ColorPickerSnapshotRestorerTest.kt b/tests/src/com/android/customization/model/picker/color/domain/interactor/ColorPickerSnapshotRestorerTest.kt
new file mode 100644
index 0000000..27b8550
--- /dev/null
+++ b/tests/src/com/android/customization/model/picker/color/domain/interactor/ColorPickerSnapshotRestorerTest.kt
@@ -0,0 +1,138 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package com.android.customization.model.picker.color.domain.interactor
+
+import android.content.Context
+import androidx.test.filters.SmallTest
+import androidx.test.platform.app.InstrumentationRegistry
+import com.android.customization.picker.color.data.repository.FakeColorPickerRepository
+import com.android.customization.picker.color.domain.interactor.ColorPickerInteractor
+import com.android.customization.picker.color.domain.interactor.ColorPickerSnapshotRestorer
+import com.android.customization.picker.color.shared.model.ColorOptionModel
+import com.android.customization.picker.color.shared.model.ColorType
+import com.android.wallpaper.testing.FakeSnapshotStore
+import com.android.wallpaper.testing.collectLastValue
+import com.google.common.truth.Truth
+import com.google.common.truth.Truth.assertThat
+import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.test.runTest
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+
+@OptIn(ExperimentalCoroutinesApi::class)
+@SmallTest
+@RunWith(JUnit4::class)
+class ColorPickerSnapshotRestorerTest {
+
+    private lateinit var underTest: ColorPickerSnapshotRestorer
+    private lateinit var repository: FakeColorPickerRepository
+    private lateinit var store: FakeSnapshotStore
+
+    private lateinit var context: Context
+
+    @Before
+    fun setUp() {
+        context = InstrumentationRegistry.getInstrumentation().targetContext
+        repository = FakeColorPickerRepository(context = context)
+        underTest =
+            ColorPickerSnapshotRestorer(
+                interactor =
+                    ColorPickerInteractor(
+                        repository = repository,
+                        snapshotRestorer = { underTest },
+                    )
+            )
+        store = FakeSnapshotStore()
+    }
+
+    @Test
+    fun restoreToSnapshot_noCallsToStore_restoresToInitialSnapshot() = runTest {
+        val colorOptions = collectLastValue(repository.colorOptions)
+
+        repository.setOptions(4, 4, ColorType.WALLPAPER_COLOR, 2)
+        val initialSnapshot = underTest.setUpSnapshotRestorer(store = store)
+        assertThat(initialSnapshot.args).isNotEmpty()
+
+        val colorOptionToSelect = colorOptions()?.get(ColorType.BASIC_COLOR)?.get(3)
+        colorOptionToSelect?.let { repository.select(it) }
+        assertState(colorOptions(), ColorType.BASIC_COLOR, 3)
+
+        underTest.restoreToSnapshot(initialSnapshot)
+        assertState(colorOptions(), ColorType.WALLPAPER_COLOR, 2)
+    }
+
+    @Test
+    fun restoreToSnapshot_withCallToStore_restoresToInitialSnapshot() = runTest {
+        val colorOptions = collectLastValue(repository.colorOptions)
+
+        repository.setOptions(4, 4, ColorType.WALLPAPER_COLOR, 2)
+        val initialSnapshot = underTest.setUpSnapshotRestorer(store = store)
+        assertThat(initialSnapshot.args).isNotEmpty()
+
+        val colorOptionToSelect = colorOptions()?.get(ColorType.BASIC_COLOR)?.get(3)
+        colorOptionToSelect?.let { repository.select(it) }
+        assertState(colorOptions(), ColorType.BASIC_COLOR, 3)
+
+        val colorOptionToStore = colorOptions()?.get(ColorType.BASIC_COLOR)?.get(1)
+        colorOptionToStore?.let { underTest.storeSnapshot(colorOptionToStore) }
+
+        underTest.restoreToSnapshot(initialSnapshot)
+        assertState(colorOptions(), ColorType.WALLPAPER_COLOR, 2)
+    }
+
+    private fun assertState(
+        colorOptions: Map<ColorType, List<ColorOptionModel>>?,
+        selectedColorType: ColorType,
+        selectedColorIndex: Int
+    ) {
+        var foundSelectedColorOption = false
+        assertThat(colorOptions).isNotNull()
+        val optionsOfSelectedColorType = colorOptions?.get(selectedColorType)
+        assertThat(optionsOfSelectedColorType).isNotNull()
+        if (optionsOfSelectedColorType != null) {
+            for (i in optionsOfSelectedColorType.indices) {
+                val colorOptionHasSelectedIndex = i == selectedColorIndex
+                Truth.assertWithMessage(
+                        "Expected color option with index \"${i}\" to have" +
+                            " isSelected=$colorOptionHasSelectedIndex but it was" +
+                            " ${optionsOfSelectedColorType[i].isSelected}, num options: ${colorOptions.size}"
+                    )
+                    .that(optionsOfSelectedColorType[i].isSelected)
+                    .isEqualTo(colorOptionHasSelectedIndex)
+                foundSelectedColorOption = foundSelectedColorOption || colorOptionHasSelectedIndex
+            }
+            if (selectedColorIndex == -1) {
+                Truth.assertWithMessage(
+                        "Expected no color options to be selected, but a color option is" +
+                            " selected"
+                    )
+                    .that(foundSelectedColorOption)
+                    .isFalse()
+            } else {
+                Truth.assertWithMessage(
+                        "Expected a color option to be selected, but no color option is" +
+                            " selected"
+                    )
+                    .that(foundSelectedColorOption)
+                    .isTrue()
+            }
+        }
+    }
+}
diff --git a/tests/src/com/android/customization/model/picker/color/ui/viewmodel/ColorPickerViewModelTest.kt b/tests/src/com/android/customization/model/picker/color/ui/viewmodel/ColorPickerViewModelTest.kt
index 6e5f776..b7567ed 100644
--- a/tests/src/com/android/customization/model/picker/color/ui/viewmodel/ColorPickerViewModelTest.kt
+++ b/tests/src/com/android/customization/model/picker/color/ui/viewmodel/ColorPickerViewModelTest.kt
@@ -21,15 +21,24 @@
 import androidx.test.platform.app.InstrumentationRegistry
 import com.android.customization.picker.color.data.repository.FakeColorPickerRepository
 import com.android.customization.picker.color.domain.interactor.ColorPickerInteractor
+import com.android.customization.picker.color.domain.interactor.ColorPickerSnapshotRestorer
 import com.android.customization.picker.color.shared.model.ColorType
 import com.android.customization.picker.color.ui.viewmodel.ColorOptionViewModel
 import com.android.customization.picker.color.ui.viewmodel.ColorPickerViewModel
 import com.android.customization.picker.color.ui.viewmodel.ColorTypeViewModel
+import com.android.wallpaper.testing.FakeSnapshotStore
 import com.android.wallpaper.testing.collectLastValue
 import com.google.common.truth.Truth.assertThat
 import com.google.common.truth.Truth.assertWithMessage
+import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.runBlocking
+import kotlinx.coroutines.test.StandardTestDispatcher
+import kotlinx.coroutines.test.TestScope
+import kotlinx.coroutines.test.resetMain
 import kotlinx.coroutines.test.runTest
+import kotlinx.coroutines.test.setMain
+import org.junit.After
 import org.junit.Before
 import org.junit.Test
 import org.junit.runner.RunWith
@@ -40,80 +49,111 @@
 @RunWith(JUnit4::class)
 class ColorPickerViewModelTest {
     private lateinit var underTest: ColorPickerViewModel
+    private lateinit var repository: FakeColorPickerRepository
+    private lateinit var interactor: ColorPickerInteractor
+    private lateinit var store: FakeSnapshotStore
 
     private lateinit var context: Context
+    private lateinit var testScope: TestScope
 
     @Before
     fun setUp() {
         context = InstrumentationRegistry.getInstrumentation().targetContext
+        val testDispatcher = StandardTestDispatcher()
+        testScope = TestScope(testDispatcher)
+        Dispatchers.setMain(testDispatcher)
+        repository = FakeColorPickerRepository(context = context)
+        store = FakeSnapshotStore()
+
+        interactor =
+            ColorPickerInteractor(
+                repository = repository,
+                snapshotRestorer = {
+                    ColorPickerSnapshotRestorer(interactor = interactor).apply {
+                        runBlocking { setUpSnapshotRestorer(store = store) }
+                    }
+                },
+            )
 
         underTest =
-            ColorPickerViewModel.Factory(
-                    context = context,
-                    interactor =
-                        ColorPickerInteractor(
-                            repository = FakeColorPickerRepository(context = context),
-                        ),
-                )
+            ColorPickerViewModel.Factory(context = context, interactor = interactor)
                 .create(ColorPickerViewModel::class.java)
+
+        repository.setOptions(4, 4, ColorType.WALLPAPER_COLOR, 0)
+    }
+
+    @After
+    fun tearDown() {
+        Dispatchers.resetMain()
     }
 
     @Test
-    fun `Select a color section color`() = runTest {
-        val colorSectionOptions = collectLastValue(underTest.colorSectionOptions)
+    fun `Select a color section color`() =
+        testScope.runTest {
+            val colorSectionOptions = collectLastValue(underTest.colorSectionOptions)
 
-        assertColorOptionUiState(colorOptions = colorSectionOptions(), selectedColorOptionIndex = 0)
+            assertColorOptionUiState(
+                colorOptions = colorSectionOptions(),
+                selectedColorOptionIndex = 0
+            )
 
-        colorSectionOptions()?.get(2)?.onClick?.invoke()
-        assertColorOptionUiState(colorOptions = colorSectionOptions(), selectedColorOptionIndex = 2)
+            colorSectionOptions()?.get(2)?.onClick?.invoke()
+            assertColorOptionUiState(
+                colorOptions = colorSectionOptions(),
+                selectedColorOptionIndex = 2
+            )
 
-        colorSectionOptions()?.get(4)?.onClick?.invoke()
-        assertColorOptionUiState(colorOptions = colorSectionOptions(), selectedColorOptionIndex = 4)
-    }
+            colorSectionOptions()?.get(4)?.onClick?.invoke()
+            assertColorOptionUiState(
+                colorOptions = colorSectionOptions(),
+                selectedColorOptionIndex = 4
+            )
+        }
 
     @Test
-    fun `Select a preset color`() = runTest {
-        val colorTypes = collectLastValue(underTest.colorTypes)
-        val colorOptions = collectLastValue(underTest.colorOptions)
+    fun `Select a preset color`() =
+        testScope.runTest {
+            val colorTypes = collectLastValue(underTest.colorTypes)
+            val colorOptions = collectLastValue(underTest.colorOptions)
 
-        // Initially, the wallpaper color tab should be selected
-        assertPickerUiState(
-            colorTypes = colorTypes(),
-            colorOptions = colorOptions(),
-            selectedColorTypeText = "Wallpaper colors",
-            selectedColorOptionIndex = 0
-        )
+            // Initially, the wallpaper color tab should be selected
+            assertPickerUiState(
+                colorTypes = colorTypes(),
+                colorOptions = colorOptions(),
+                selectedColorTypeText = "Wallpaper colors",
+                selectedColorOptionIndex = 0
+            )
 
-        // Select "Basic colors" tab
-        colorTypes()?.get(ColorType.BASIC_COLOR)?.onClick?.invoke()
-        assertPickerUiState(
-            colorTypes = colorTypes(),
-            colorOptions = colorOptions(),
-            selectedColorTypeText = "Basic colors",
-            selectedColorOptionIndex = -1
-        )
+            // Select "Basic colors" tab
+            colorTypes()?.get(ColorType.BASIC_COLOR)?.onClick?.invoke()
+            assertPickerUiState(
+                colorTypes = colorTypes(),
+                colorOptions = colorOptions(),
+                selectedColorTypeText = "Basic colors",
+                selectedColorOptionIndex = -1
+            )
 
-        // Select a color option
-        colorOptions()?.get(2)?.onClick?.invoke()
+            // Select a color option
+            colorOptions()?.get(2)?.onClick?.invoke()
 
-        // Check original option is no longer selected
-        colorTypes()?.get(ColorType.WALLPAPER_COLOR)?.onClick?.invoke()
-        assertPickerUiState(
-            colorTypes = colorTypes(),
-            colorOptions = colorOptions(),
-            selectedColorTypeText = "Wallpaper colors",
-            selectedColorOptionIndex = -1
-        )
+            // Check original option is no longer selected
+            colorTypes()?.get(ColorType.WALLPAPER_COLOR)?.onClick?.invoke()
+            assertPickerUiState(
+                colorTypes = colorTypes(),
+                colorOptions = colorOptions(),
+                selectedColorTypeText = "Wallpaper colors",
+                selectedColorOptionIndex = -1
+            )
 
-        // Check new option is selected
-        colorTypes()?.get(ColorType.BASIC_COLOR)?.onClick?.invoke()
-        assertPickerUiState(
-            colorTypes = colorTypes(),
-            colorOptions = colorOptions(),
-            selectedColorTypeText = "Basic colors",
-            selectedColorOptionIndex = 2
-        )
-    }
+            // Check new option is selected
+            colorTypes()?.get(ColorType.BASIC_COLOR)?.onClick?.invoke()
+            assertPickerUiState(
+                colorTypes = colorTypes(),
+                colorOptions = colorOptions(),
+                selectedColorTypeText = "Basic colors",
+                selectedColorOptionIndex = 2
+            )
+        }
 
     /**
      * Asserts the entire picker UI state is what is expected. This includes the color type tabs and
diff --git a/tests/src/com/android/customization/testing/TestCustomizationInjector.kt b/tests/src/com/android/customization/testing/TestCustomizationInjector.kt
index 3ab7c84..2a2ab5e 100644
--- a/tests/src/com/android/customization/testing/TestCustomizationInjector.kt
+++ b/tests/src/com/android/customization/testing/TestCustomizationInjector.kt
@@ -18,6 +18,7 @@
 import com.android.customization.picker.clock.ui.viewmodel.ClockSettingsViewModel
 import com.android.customization.picker.color.data.repository.ColorPickerRepositoryImpl
 import com.android.customization.picker.color.domain.interactor.ColorPickerInteractor
+import com.android.customization.picker.color.domain.interactor.ColorPickerSnapshotRestorer
 import com.android.customization.picker.color.ui.viewmodel.ColorPickerViewModel
 import com.android.customization.picker.quickaffordance.data.repository.KeyguardQuickAffordancePickerRepository
 import com.android.customization.picker.quickaffordance.domain.interactor.KeyguardQuickAffordancePickerInteractor
@@ -54,6 +55,7 @@
     private var clockViewFactory: ClockViewFactory? = null
     private var colorPickerInteractor: ColorPickerInteractor? = null
     private var colorPickerViewModelFactory: ColorPickerViewModel.Factory? = null
+    private var colorPickerSnapshotRestorer: ColorPickerSnapshotRestorer? = null
     private var clockCarouselViewModel: ClockCarouselViewModel? = null
     private var clockSettingsViewModelFactory: ClockSettingsViewModel.Factory? = null
 
@@ -118,6 +120,8 @@
         val restorers: MutableMap<Int, SnapshotRestorer> = HashMap()
         restorers[KEY_QUICK_AFFORDANCE_SNAPSHOT_RESTORER] =
             getKeyguardQuickAffordanceSnapshotRestorer(context)
+        restorers[KEY_COLOR_PICKER_SNAPSHOT_RESTORER] =
+            getColorPickerSnapshotRestorer(context, getWallpaperColorsViewModel())
         return restorers
     }
 
@@ -168,7 +172,12 @@
         wallpaperColorsViewModel: WallpaperColorsViewModel,
     ): ColorPickerInteractor {
         return colorPickerInteractor
-            ?: ColorPickerInteractor(ColorPickerRepositoryImpl(context, wallpaperColorsViewModel))
+            ?: ColorPickerInteractor(
+                    repository = ColorPickerRepositoryImpl(context, wallpaperColorsViewModel),
+                    snapshotRestorer = {
+                        getColorPickerSnapshotRestorer(context, wallpaperColorsViewModel)
+                    },
+                )
                 .also { colorPickerInteractor = it }
     }
 
@@ -184,6 +193,17 @@
                 .also { colorPickerViewModelFactory = it }
     }
 
+    private fun getColorPickerSnapshotRestorer(
+        context: Context,
+        wallpaperColorsViewModel: WallpaperColorsViewModel
+    ): ColorPickerSnapshotRestorer {
+        return colorPickerSnapshotRestorer
+            ?: ColorPickerSnapshotRestorer(
+                    getColorPickerInteractor(context, wallpaperColorsViewModel)
+                )
+                .also { colorPickerSnapshotRestorer = it }
+    }
+
     override fun getClockCarouselViewModel(context: Context): ClockCarouselViewModel {
         return clockCarouselViewModel
             ?: ClockCarouselViewModel(getClockPickerInteractor(context)).also {
@@ -209,5 +229,7 @@
 
     companion object {
         private const val KEY_QUICK_AFFORDANCE_SNAPSHOT_RESTORER = 1
+        private const val KEY_COLOR_PICKER_SNAPSHOT_RESTORER =
+            KEY_QUICK_AFFORDANCE_SNAPSHOT_RESTORER + 1
     }
 }