Extract shared values outside of the Element class

This CL moves the Maps and objects that are used to interpolate shared
values outside of the Element class. That way, animating a shared value
won't require the Element object, which will allow to remove the last
call to Snapshot.withoutReadObservation {} and map mutations during
compisition inside Modifier.element().

Test: AnimateSharedAsState
Bug: 291071158
Flag: N/A
Change-Id: Ib56dd943d233edf1276934c02ad37c08821e318d
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateSharedAsState.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateSharedAsState.kt
index 97a848a..b26194f 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateSharedAsState.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateSharedAsState.kt
@@ -25,12 +25,12 @@
 import androidx.compose.runtime.mutableStateOf
 import androidx.compose.runtime.remember
 import androidx.compose.runtime.snapshotFlow
+import androidx.compose.runtime.snapshots.SnapshotStateMap
 import androidx.compose.ui.graphics.Color
 import androidx.compose.ui.graphics.lerp
 import androidx.compose.ui.unit.Dp
 import androidx.compose.ui.unit.lerp
 import com.android.compose.ui.util.lerp
-import kotlinx.coroutines.flow.collect
 
 /**
  * A [State] whose [value] is animated.
@@ -175,29 +175,44 @@
 @Composable
 internal fun <T> animateSharedValueAsState(
     layoutImpl: SceneTransitionLayoutImpl,
-    scene: Scene,
-    element: Element?,
+    scene: SceneKey,
+    element: ElementKey?,
     key: ValueKey,
     value: T,
     lerp: (T, T, Float) -> T,
     canOverflow: Boolean,
 ): AnimatedState<T> {
-    // Create the associated SharedValue object that holds the current value.
-    DisposableEffect(scene, element, key) {
-        val sharedValues = sharedValues(scene, element)
-        sharedValues[key] = Element.SharedValue(key, value)
-        onDispose { sharedValues.remove(key) }
+    DisposableEffect(layoutImpl, scene, element, key) {
+        // Create the associated maps that hold the current value for each (element, scene) pair.
+        val valueMap = layoutImpl.sharedValues.getOrPut(key) { mutableMapOf() }
+        val sceneToValueMap =
+            valueMap.getOrPut(element) { SnapshotStateMap<SceneKey, Any>() }
+                as SnapshotStateMap<SceneKey, T>
+        sceneToValueMap[scene] = value
+
+        onDispose {
+            // Remove the value associated to the current scene, and eventually remove the maps if
+            // they are empty.
+            sceneToValueMap.remove(scene)
+
+            if (sceneToValueMap.isEmpty() && valueMap[element] === sceneToValueMap) {
+                valueMap.remove(element)
+
+                if (valueMap.isEmpty() && layoutImpl.sharedValues[key] === valueMap) {
+                    layoutImpl.sharedValues.remove(key)
+                }
+            }
+        }
     }
 
     // Update the current value. Note that side effects run after disposable effects, so we know
-    // that the SharedValue object was created at this point.
-    SideEffect { sharedValue<T>(scene, element, key).value = value }
+    // that the associated maps were created at this point.
+    SideEffect { sceneToValueMap<T>(layoutImpl, key, element)[scene] = value }
 
-    val sceneKey = scene.key
-    return remember(layoutImpl, sceneKey, element, lerp, canOverflow) {
+    return remember(layoutImpl, scene, element, lerp, canOverflow) {
         object : AnimatedState<T> {
             override val value: T
-                get() = value(layoutImpl, sceneKey, element, key, lerp, canOverflow)
+                get() = value(layoutImpl, scene, element, key, lerp, canOverflow)
 
             @Composable
             override fun unsafeCompositionState(initialValue: T): State<T> {
@@ -214,28 +229,13 @@
     }
 }
 
-private fun sharedValues(
-    scene: Scene,
-    element: Element?,
-): MutableMap<ValueKey, Element.SharedValue<*>> {
-    return element?.sceneValues?.getValue(scene.key)?.sharedValues ?: scene.sharedValues
-}
-
-private fun <T> sharedValueOrNull(
-    scene: Scene,
-    element: Element?,
+private fun <T> sceneToValueMap(
+    layoutImpl: SceneTransitionLayoutImpl,
     key: ValueKey,
-): Element.SharedValue<T>? {
-    val sharedValue = sharedValues(scene, element)[key] ?: return null
-    return sharedValue as Element.SharedValue<T>
-}
-
-private fun <T> sharedValue(
-    scene: Scene,
-    element: Element?,
-    key: ValueKey,
-): Element.SharedValue<T> {
-    return sharedValueOrNull(scene, element, key) ?: error(valueReadTooEarlyMessage(key))
+    element: ElementKey?
+): MutableMap<SceneKey, T> {
+    return layoutImpl.sharedValues[key]?.get(element)?.let { it as SnapshotStateMap<SceneKey, T> }
+        ?: error(valueReadTooEarlyMessage(key))
 }
 
 private fun valueReadTooEarlyMessage(key: ValueKey) =
@@ -246,7 +246,7 @@
 private fun <T> value(
     layoutImpl: SceneTransitionLayoutImpl,
     scene: SceneKey,
-    element: Element?,
+    element: ElementKey?,
     key: ValueKey,
     lerp: (T, T, Float) -> T,
     canOverflow: Boolean,
@@ -258,25 +258,16 @@
 private fun <T> valueOrNull(
     layoutImpl: SceneTransitionLayoutImpl,
     scene: SceneKey,
-    element: Element?,
+    element: ElementKey?,
     key: ValueKey,
     lerp: (T, T, Float) -> T,
     canOverflow: Boolean,
 ): T? {
-    fun sceneValue(scene: SceneKey): Element.SharedValue<T>? {
-        val sharedValues =
-            if (element == null) {
-                layoutImpl.scene(scene).sharedValues
-            } else {
-                element.sceneValues[scene]?.sharedValues
-            }
-                ?: return null
-        val value = sharedValues[key] ?: return null
-        return value as Element.SharedValue<T>
-    }
+    val sceneToValueMap = sceneToValueMap<T>(layoutImpl, key, element)
+    fun sceneValue(scene: SceneKey): T? = sceneToValueMap[scene]
 
     return when (val transition = layoutImpl.state.transitionState) {
-        is TransitionState.Idle -> sceneValue(transition.currentScene)?.value
+        is TransitionState.Idle -> sceneValue(transition.currentScene)
         is TransitionState.Transition -> {
             // Note: no need to check for transition ready here given that all target values are
             // defined during composition, we should already have the correct values to interpolate
@@ -284,25 +275,21 @@
             val fromValue = sceneValue(transition.fromScene)
             val toValue = sceneValue(transition.toScene)
             if (fromValue != null && toValue != null) {
-                val from = fromValue.value
-                val to = toValue.value
-                if (from == to) {
+                if (fromValue == toValue) {
                     // Optimization: avoid reading progress if the values are the same, so we don't
                     // relayout/redraw for nothing.
-                    from
+                    fromValue
                 } else {
                     val progress =
                         if (canOverflow) transition.progress
                         else transition.progress.coerceIn(0f, 1f)
-                    lerp(from, to, progress)
+                    lerp(fromValue, toValue, progress)
                 }
-            } else if (fromValue != null) {
-                fromValue.value
-            } else toValue?.value
+            } else fromValue ?: toValue
         }
     }
     // TODO(b/311600838): Remove this. We should not have to fallback to the current scene value,
     // but we have to because code of removed nodes can still run if they are placed with a graphics
     // layer.
-    ?: sceneValue(scene)?.value
+    ?: sceneValue(scene)
 }
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
index e4cc134..1cac477 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
@@ -100,12 +100,6 @@
         var targetSize by mutableStateOf(SizeUnspecified)
         var targetOffset by mutableStateOf(Offset.Unspecified)
 
-        private var _sharedValues: MutableMap<ValueKey, SharedValue<*>>? = null
-        val sharedValues: MutableMap<ValueKey, SharedValue<*>>
-            get() =
-                _sharedValues
-                    ?: SnapshotStateMap<ValueKey, SharedValue<*>>().also { _sharedValues = it }
-
         /**
          * The attached [ElementNode] a Modifier.element() for a given element and scene. During
          * composition, this set could have 0 to 2 elements. After composition and after all
@@ -114,12 +108,6 @@
         val nodes = mutableSetOf<ElementNode>()
     }
 
-    /** A shared value of this element. */
-    @Stable
-    class SharedValue<T>(val key: ValueKey, initialValue: T) {
-        var value by mutableStateOf(initialValue)
-    }
-
     companion object {
         val SizeUnspecified = IntSize(Int.MAX_VALUE, Int.MAX_VALUE)
         val AlphaUnspecified = Float.MIN_VALUE
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/MovableElement.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/MovableElement.kt
index 42d20f7..04b73fb 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/MovableElement.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/MovableElement.kt
@@ -102,7 +102,15 @@
         lerp: (start: T, stop: T, fraction: Float) -> T,
         canOverflow: Boolean
     ): AnimatedState<T> {
-        return animateSharedValueAsState(layoutImpl, scene, element, key, value, lerp, canOverflow)
+        return animateSharedValueAsState(
+            layoutImpl,
+            scene.key,
+            element.key,
+            key,
+            value,
+            lerp,
+            canOverflow,
+        )
     }
 
     @Composable
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Scene.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Scene.kt
index 3c2525e..4785716 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Scene.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Scene.kt
@@ -24,7 +24,6 @@
 import androidx.compose.runtime.mutableFloatStateOf
 import androidx.compose.runtime.mutableStateOf
 import androidx.compose.runtime.setValue
-import androidx.compose.runtime.snapshots.SnapshotStateMap
 import androidx.compose.ui.ExperimentalComposeUiApi
 import androidx.compose.ui.Modifier
 import androidx.compose.ui.graphics.Shape
@@ -50,13 +49,6 @@
     var zIndex by mutableFloatStateOf(zIndex)
     var targetSize by mutableStateOf(IntSize.Zero)
 
-    /** The shared values in this scene that are not tied to a specific element. */
-    private var _sharedValues: MutableMap<ValueKey, Element.SharedValue<*>>? = null
-    val sharedValues: MutableMap<ValueKey, Element.SharedValue<*>>
-        get() =
-            _sharedValues
-                ?: SnapshotStateMap<ValueKey, Element.SharedValue<*>>().also { _sharedValues = it }
-
     @Composable
     @OptIn(ExperimentalComposeUiApi::class)
     fun Content(modifier: Modifier = Modifier) {
@@ -116,7 +108,7 @@
     ): AnimatedState<T> {
         return animateSharedValueAsState(
             layoutImpl = layoutImpl,
-            scene = scene,
+            scene = scene.key,
             element = null,
             key = key,
             value = value,
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
index 45e1a0f..c56202c 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
@@ -62,6 +62,20 @@
     internal val elements = SnapshotStateMap<ElementKey, Element>()
 
     /**
+     * The different values of a shared value keyed by a a [ValueKey] and the different elements and
+     * scenes it is associated to.
+     */
+    private var _sharedValues:
+        MutableMap<ValueKey, MutableMap<ElementKey?, SnapshotStateMap<SceneKey, *>>>? =
+        null
+    internal val sharedValues:
+        MutableMap<ValueKey, MutableMap<ElementKey?, SnapshotStateMap<SceneKey, *>>>
+        get() =
+            _sharedValues
+                ?: mutableMapOf<ValueKey, MutableMap<ElementKey?, SnapshotStateMap<SceneKey, *>>>()
+                    .also { _sharedValues = it }
+
+    /**
      * The scenes that are "ready", i.e. they were composed and fully laid-out at least once.
      *
      * Note that this map is *read* during composition, so it is a [SnapshotStateMap] to make sure