Merge changes Ie2944544,Ia7cec977,Id4e5ff33,Iad082ca5 into main

* changes:
  Add support for interruptions in shared values
  Move the implementation of AnimatedState into AnimatedStateImpl
  Make animated values support multiple transitions
  Expose the current scene key in SceneScope
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateSharedAsState.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateSharedAsState.kt
index 5d1a7c5..7fd3a176 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateSharedAsState.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateSharedAsState.kt
@@ -27,11 +27,12 @@
 import androidx.compose.runtime.snapshotFlow
 import androidx.compose.runtime.snapshots.SnapshotStateMap
 import androidx.compose.ui.graphics.Color
-import androidx.compose.ui.graphics.lerp
+import androidx.compose.ui.graphics.colorspace.ColorSpaces
 import androidx.compose.ui.unit.Dp
-import androidx.compose.ui.unit.lerp
+import androidx.compose.ui.unit.dp
 import androidx.compose.ui.util.fastCoerceIn
-import androidx.compose.ui.util.lerp
+import androidx.compose.ui.util.fastLastOrNull
+import kotlin.math.roundToInt
 
 /**
  * A [State] whose [value] is animated.
@@ -74,7 +75,7 @@
     key: ValueKey,
     canOverflow: Boolean = true,
 ): AnimatedState<Int> {
-    return animateSceneValueAsState(value, key, ::lerp, canOverflow)
+    return animateSceneValueAsState(value, key, SharedIntType, canOverflow)
 }
 
 /**
@@ -88,7 +89,25 @@
     key: ValueKey,
     canOverflow: Boolean = true,
 ): AnimatedState<Int> {
-    return animateElementValueAsState(value, key, ::lerp, canOverflow)
+    return animateElementValueAsState(value, key, SharedIntType, canOverflow)
+}
+
+/**
+ * [SharedValueType] for [Int] values. [Int.MIN_VALUE] is reserved as the unspecified value, so it
+ * can't be used as a target value.
+ */
+private object SharedIntType : SharedValueType<Int, Int> {
+    override val unspecifiedValue: Int = Int.MIN_VALUE
+    override val zeroDeltaValue: Int = 0
+
+    override fun lerp(a: Int, b: Int, progress: Float): Int =
+        androidx.compose.ui.util.lerp(a, b, progress)
+
+    override fun diff(a: Int, b: Int): Int = a - b
+
+    // Only the weighted delta goes through Float: converting `a` itself to Float would lose
+    // precision for magnitudes above 2^24. This is otherwise equivalent since `a` is an integer.
+    override fun addWeighted(a: Int, b: Int, bWeight: Float): Int = a + (b * bWeight).roundToInt()
 }
 
 /**
@@ -102,7 +115,7 @@
     key: ValueKey,
     canOverflow: Boolean = true,
 ): AnimatedState<Float> {
-    return animateSceneValueAsState(value, key, ::lerp, canOverflow)
+    return animateSceneValueAsState(value, key, SharedFloatType, canOverflow)
 }
 
 /**
@@ -116,7 +129,22 @@
     key: ValueKey,
     canOverflow: Boolean = true,
 ): AnimatedState<Float> {
-    return animateElementValueAsState(value, key, ::lerp, canOverflow)
+    return animateElementValueAsState(value, key, SharedFloatType, canOverflow)
+}
+
+/** [SharedValueType] for [Float] values. */
+private object SharedFloatType : SharedValueType<Float, Float> {
+    // NOTE(review): Float.MIN_VALUE is the smallest *positive* Float (~1.4e-45), not the most
+    // negative one. It is only used as a sentinel, so this exact value can't be a target value.
+    override val unspecifiedValue: Float = Float.MIN_VALUE
+    override val zeroDeltaValue: Float = 0f
+
+    override fun lerp(a: Float, b: Float, progress: Float): Float =
+        androidx.compose.ui.util.lerp(a, b, progress)
+
+    override fun diff(a: Float, b: Float): Float = a - b
+
+    override fun addWeighted(a: Float, b: Float, bWeight: Float): Float = a + b * bWeight
 }
 
 /**
@@ -130,7 +155,7 @@
     key: ValueKey,
     canOverflow: Boolean = true,
 ): AnimatedState<Dp> {
-    return animateSceneValueAsState(value, key, ::lerp, canOverflow)
+    return animateSceneValueAsState(value, key, SharedDpType, canOverflow)
 }
 
 /**
@@ -144,7 +169,20 @@
     key: ValueKey,
     canOverflow: Boolean = true,
 ): AnimatedState<Dp> {
-    return animateElementValueAsState(value, key, ::lerp, canOverflow)
+    return animateElementValueAsState(value, key, SharedDpType, canOverflow)
+}
+
+/** [SharedValueType] for [Dp] values, using [Dp.Unspecified] as the unspecified value. */
+private object SharedDpType : SharedValueType<Dp, Dp> {
+    override val zeroDeltaValue: Dp = 0.dp
+    override val unspecifiedValue: Dp = Dp.Unspecified
+
+    override fun lerp(a: Dp, b: Dp, progress: Float): Dp =
+        androidx.compose.ui.unit.lerp(a, b, progress)
+
+    override fun diff(a: Dp, b: Dp): Dp = a - b
+
+    override fun addWeighted(a: Dp, b: Dp, bWeight: Float): Dp = a + b * bWeight
 }
 
 /**
@@ -157,7 +195,7 @@
     value: Color,
     key: ValueKey,
 ): AnimatedState<Color> {
-    return animateSceneValueAsState(value, key, ::lerp, canOverflow = false)
+    return animateSceneValueAsState(value, key, SharedColorType, canOverflow = false)
 }
 
 /**
@@ -170,9 +208,67 @@
     value: Color,
     key: ValueKey,
 ): AnimatedState<Color> {
-    return animateElementValueAsState(value, key, ::lerp, canOverflow = false)
+    return animateElementValueAsState(value, key, SharedColorType, canOverflow = false)
 }
 
+/**
+ * [SharedValueType] for [Color] values. Similar to [androidx.compose.ui.graphics.lerp], deltas
+ * are computed and applied in the [Oklab][ColorSpaces.Oklab] color space.
+ */
+private object SharedColorType : SharedValueType<Color, ColorDelta> {
+    override val unspecifiedValue: Color = Color.Unspecified
+    override val zeroDeltaValue: ColorDelta = ColorDelta(0f, 0f, 0f, 0f)
+
+    override fun lerp(a: Color, b: Color, progress: Float): Color {
+        return androidx.compose.ui.graphics.lerp(a, b, progress)
+    }
+
+    override fun diff(a: Color, b: Color): ColorDelta {
+        // Similar to lerp, we convert colors to the Oklab color space to perform operations on
+        // colors.
+        val aOklab = a.convert(ColorSpaces.Oklab)
+        val bOklab = b.convert(ColorSpaces.Oklab)
+        return ColorDelta(
+            red = aOklab.red - bOklab.red,
+            green = aOklab.green - bOklab.green,
+            blue = aOklab.blue - bOklab.blue,
+            alpha = aOklab.alpha - bOklab.alpha,
+        )
+    }
+
+    override fun addWeighted(a: Color, b: ColorDelta, bWeight: Float): Color {
+        val aOklab = a.convert(ColorSpaces.Oklab)
+        // NOTE(review): Color() checks that each component is within the bounds of Oklab, so this
+        // assumes that bWeight keeps the result in range - TODO confirm.
+        val resultOklab =
+            Color(
+                red = aOklab.red + b.red * bWeight,
+                green = aOklab.green + b.green * bWeight,
+                blue = aOklab.blue + b.blue * bWeight,
+                alpha = aOklab.alpha + b.alpha * bWeight,
+                colorSpace = ColorSpaces.Oklab,
+            )
+
+        // Convert the result back to the color space of [a]. aOklab.colorSpace can't be used here:
+        // it is always Oklab, so the result would be returned in the Oklab color space.
+        return resultOklab.convert(a.colorSpace)
+    }
+}
+
+/**
+ * Represents the diff between two colors in the same color space.
+ *
+ * Note: This class is necessary because Color() checks the bounds of its values and UncheckedColor
+ * is internal. It is a data class so that an all-zero delta compares equal (`==`) to
+ * [SharedColorType.zeroDeltaValue], which lets callers skip applying no-op deltas.
+ */
+private data class ColorDelta(
+    val red: Float,
+    val green: Float,
+    val blue: Float,
+    val alpha: Float,
+)
+
 @Composable
 internal fun <T> animateSharedValueAsState(
     layoutImpl: SceneTransitionLayoutImpl,
@@ -180,23 +265,22 @@
     element: ElementKey?,
     key: ValueKey,
     value: T,
-    lerp: (T, T, Float) -> T,
+    type: SharedValueType<T, *>,
     canOverflow: Boolean,
 ): AnimatedState<T> {
     DisposableEffect(layoutImpl, scene, element, key) {
         // Create the associated maps that hold the current value for each (element, scene) pair.
         val valueMap = layoutImpl.sharedValues.getOrPut(key) { mutableMapOf() }
-        val sceneToValueMap =
-            valueMap.getOrPut(element) { SnapshotStateMap<SceneKey, Any>() }
-                as SnapshotStateMap<SceneKey, T>
-        sceneToValueMap[scene] = value
+        val sharedValue = valueMap.getOrPut(element) { SharedValue(type) } as SharedValue<T, *>
+        val targetValues = sharedValue.targetValues
+        targetValues[scene] = value
 
         onDispose {
             // Remove the value associated to the current scene, and eventually remove the maps if
             // they are empty.
-            sceneToValueMap.remove(scene)
+            targetValues.remove(scene)
 
-            if (sceneToValueMap.isEmpty() && valueMap[element] === sceneToValueMap) {
+            if (targetValues.isEmpty() && valueMap[element] === sharedValue) {
                 valueMap.remove(element)
 
                 if (valueMap.isEmpty() && layoutImpl.sharedValues[key] === valueMap) {
@@ -208,34 +292,28 @@
 
     // Update the current value. Note that side effects run after disposable effects, so we know
     // that the associated maps were created at this point.
-    SideEffect { sceneToValueMap<T>(layoutImpl, key, element)[scene] = value }
-
-    return remember(layoutImpl, scene, element, lerp, canOverflow) {
-        object : AnimatedState<T> {
-            override val value: T
-                get() = value(layoutImpl, scene, element, key, lerp, canOverflow)
-
-            @Composable
-            override fun unsafeCompositionState(initialValue: T): State<T> {
-                val state = remember { mutableStateOf(initialValue) }
-
-                val animatedState = this
-                LaunchedEffect(animatedState) {
-                    snapshotFlow { animatedState.value }.collect { state.value = it }
-                }
-
-                return state
-            }
+    SideEffect {
+        if (value == type.unspecifiedValue) {
+            error("value is equal to $value, which is the unspecified value for this type.")
         }
+
+        sharedValue<T, Any>(layoutImpl, key, element).targetValues[scene] = value
+    }
+
+    // AnimatedStateImpl reads the shared value associated to [key], so it must be recreated when
+    // [key] changes. Otherwise it would keep reading the value of the previous (possibly removed) key.
+    return remember(layoutImpl, scene, element, key, canOverflow) {
+        AnimatedStateImpl<T, Any>(layoutImpl, scene, element, key, canOverflow)
     }
 }
 
-private fun <T> sceneToValueMap(
+/** Return the [SharedValue] associated to [key] and [element], or throw if it was not created. */
+private fun <T, Delta> sharedValue(
     layoutImpl: SceneTransitionLayoutImpl,
     key: ValueKey,
     element: ElementKey?
-): MutableMap<SceneKey, T> {
-    return layoutImpl.sharedValues[key]?.get(element)?.let { it as SnapshotStateMap<SceneKey, T> }
+): SharedValue<T, Delta> {
+    return layoutImpl.sharedValues[key]?.get(element)?.let { it as SharedValue<T, Delta> }
         ?: error(valueReadTooEarlyMessage(key))
 }
 
@@ -244,62 +319,181 @@
         "means that you are reading it during composition, which you should not do. See the " +
         "documentation of AnimatedState for more information."
 
-private fun <T> value(
-    layoutImpl: SceneTransitionLayoutImpl,
-    scene: SceneKey,
-    element: ElementKey?,
-    key: ValueKey,
-    lerp: (T, T, Float) -> T,
-    canOverflow: Boolean,
-): T {
-    return valueOrNull(layoutImpl, scene, element, key, lerp, canOverflow)
-        ?: error(valueReadTooEarlyMessage(key))
+/**
+ * The state of a shared value identified by a [ValueKey] and an optional [ElementKey]: its target
+ * value in each scene, and the bookkeeping used to smooth interruptions.
+ */
+internal class SharedValue<T, Delta>(
+    val type: SharedValueType<T, Delta>,
+) {
+    /** The target value of this shared value for each scene. */
+    val targetValues = SnapshotStateMap<SceneKey, T>()
+
+    /** The last value returned when reading this shared value. */
+    var lastValue: T = type.unspecifiedValue
+
+    /** The value of this shared value before the last interruption (if any). */
+    var valueBeforeInterruption: T = type.unspecifiedValue
+
+    /** The delta value to add to this shared value to have smoother interruptions. */
+    var valueInterruptionDelta = type.zeroDeltaValue
+
+    /** The last transition that was used to compute the value of this shared value. */
+    var lastTransition: TransitionState.Transition? = null
 }
 
-private fun <T> valueOrNull(
-    layoutImpl: SceneTransitionLayoutImpl,
-    scene: SceneKey,
-    element: ElementKey?,
-    key: ValueKey,
-    lerp: (T, T, Float) -> T,
-    canOverflow: Boolean,
-): T? {
-    val sceneToValueMap = sceneToValueMap<T>(layoutImpl, key, element)
-    fun sceneValue(scene: SceneKey): T? = sceneToValueMap[scene]
+/**
+ * The [AnimatedState] of the shared value identified by [key] and [element], created in [scene].
+ *
+ * Note that reading [value] has side effects: it updates the bookkeeping of the [SharedValue]
+ * (last value, last transition and interruption delta) used to smooth interruptions.
+ */
+private class AnimatedStateImpl<T, Delta>(
+    private val layoutImpl: SceneTransitionLayoutImpl,
+    private val scene: SceneKey,
+    private val element: ElementKey?,
+    private val key: ValueKey,
+    private val canOverflow: Boolean,
+) : AnimatedState<T> {
+    override val value: T
+        get() = value()
 
-    return when (val transition = layoutImpl.state.transitionState) {
-        is TransitionState.Idle -> sceneValue(transition.currentScene)
-        is TransitionState.Transition -> {
-            // Note: no need to check for transition ready here given that all target values are
-            // defined during composition, we should already have the correct values to interpolate
-            // between here.
-            val fromValue = sceneValue(transition.fromScene)
-            val toValue = sceneValue(transition.toScene)
-            if (fromValue != null && toValue != null) {
-                if (fromValue == toValue) {
-                    // Optimization: avoid reading progress if the values are the same, so we don't
-                    // relayout/redraw for nothing.
-                    fromValue
-                } else {
-                    // In the case of bouncing, if the value remains constant during the overscroll,
-                    // we should use the value of the scene we are bouncing around.
-                    if (!canOverflow && transition is TransitionState.HasOverscrollProperties) {
-                        val bouncingScene = transition.bouncingScene
-                        if (bouncingScene != null) {
-                            return sceneValue(bouncingScene)
-                        }
+    private fun value(): T {
+        val sharedValue = sharedValue<T, Delta>(layoutImpl, key, element)
+        val transition = transition(sharedValue)
+        val value: T =
+            valueOrNull(sharedValue, transition)
+                // TODO(b/311600838): Remove this. We should not have to fallback to the current
+                // scene value, but we have to because code of removed nodes can still run if they
+                // are placed with a graphics layer.
+                ?: sharedValue[scene]
+                ?: error(valueReadTooEarlyMessage(key))
+        val interruptedValue = computeInterruptedValue(sharedValue, transition, value)
+        // Stored so that transition() can capture it if the current transition gets interrupted.
+        sharedValue.lastValue = interruptedValue
+        return interruptedValue
+    }
+
+    private operator fun SharedValue<T, *>.get(scene: SceneKey): T? = targetValues[scene]
+
+    /**
+     * Return the value of [sharedValue] for [transition], interpolating between the target values
+     * of its from and to scenes, or its target value in the current scene if [transition] is null.
+     * Return null when none of these target values is defined.
+     */
+    private fun valueOrNull(
+        sharedValue: SharedValue<T, *>,
+        transition: TransitionState.Transition?,
+    ): T? {
+        if (transition == null) {
+            return sharedValue[layoutImpl.state.transitionState.currentScene]
+        }
+
+        val fromValue = sharedValue[transition.fromScene]
+        val toValue = sharedValue[transition.toScene]
+        return if (fromValue != null && toValue != null) {
+            if (fromValue == toValue) {
+                // Optimization: avoid reading progress if the values are the same, so we don't
+                // relayout/redraw for nothing.
+                fromValue
+            } else {
+                // In the case of bouncing, if the value remains constant during the overscroll, we
+                // should use the value of the scene we are bouncing around.
+                if (!canOverflow && transition is TransitionState.HasOverscrollProperties) {
+                    val bouncingScene = transition.bouncingScene
+                    if (bouncingScene != null) {
+                        return sharedValue[bouncingScene]
                     }
-
-                    val progress =
-                        if (canOverflow) transition.progress
-                        else transition.progress.fastCoerceIn(0f, 1f)
-                    lerp(fromValue, toValue, progress)
                 }
-            } else fromValue ?: toValue
+
+                val progress =
+                    if (canOverflow) transition.progress
+                    else transition.progress.fastCoerceIn(0f, 1f)
+                sharedValue.type.lerp(fromValue, toValue, progress)
+            }
+        } else fromValue ?: toValue
+    }
+
+    /**
+     * Return the transition used to compute the value: the last transition of
+     * `layoutImpl.state.currentTransitions` whose from or to scene is in the `sceneStates` of
+     * [element] (or, if [element] is null, has a target value), or null if there is none.
+     *
+     * This also detects whether the previously used transition was interrupted by another one or
+     * finished, and updates the interruption state of [sharedValue] accordingly.
+     */
+    private fun transition(sharedValue: SharedValue<T, Delta>): TransitionState.Transition? {
+        val targetValues = sharedValue.targetValues
+        val transition =
+            if (element != null) {
+                layoutImpl.elements[element]?.sceneStates?.let { sceneStates ->
+                    layoutImpl.state.currentTransitions.fastLastOrNull { transition ->
+                        transition.fromScene in sceneStates || transition.toScene in sceneStates
+                    }
+                }
+            } else {
+                layoutImpl.state.currentTransitions.fastLastOrNull { transition ->
+                    transition.fromScene in targetValues || transition.toScene in targetValues
+                }
+            }
+
+        val previousTransition = sharedValue.lastTransition
+        sharedValue.lastTransition = transition
+
+        if (transition != previousTransition && transition != null && previousTransition != null) {
+            // The previous transition was interrupted by another transition.
+            sharedValue.valueBeforeInterruption = sharedValue.lastValue
+            sharedValue.valueInterruptionDelta = sharedValue.type.zeroDeltaValue
+        } else if (transition == null && previousTransition != null) {
+            // The transition was just finished.
+            sharedValue.valueBeforeInterruption = sharedValue.type.unspecifiedValue
+            sharedValue.valueInterruptionDelta = sharedValue.type.zeroDeltaValue
+        }
+
+        return transition
+    }
+
+    /**
+     * Compute what [value] should be if we take the
+     * [interruption progress][TransitionState.Transition.interruptionProgress] of [transition] into
+     * account.
+     */
+    private fun computeInterruptedValue(
+        sharedValue: SharedValue<T, Delta>,
+        transition: TransitionState.Transition?,
+        value: T,
+    ): T {
+        val type = sharedValue.type
+        // First read since an interruption: compute the delta between the value right before the
+        // interruption and the new value, then clear valueBeforeInterruption so it's done once.
+        if (sharedValue.valueBeforeInterruption != type.unspecifiedValue) {
+            sharedValue.valueInterruptionDelta =
+                type.diff(sharedValue.valueBeforeInterruption, value)
+            sharedValue.valueBeforeInterruption = type.unspecifiedValue
+        }
+
+        val delta = sharedValue.valueInterruptionDelta
+        return if (delta == type.zeroDeltaValue || transition == null) {
+            value
+        } else {
+            val interruptionProgress = transition.interruptionProgress(layoutImpl)
+            if (interruptionProgress == 0f) {
+                value
+            } else {
+                type.addWeighted(value, delta, interruptionProgress)
+            }
         }
     }
-    // TODO(b/311600838): Remove this. We should not have to fallback to the current scene value,
-    // but we have to because code of removed nodes can still run if they are placed with a graphics
-    // layer.
-    ?: sceneValue(scene)
+
+    @Composable
+    override fun unsafeCompositionState(initialValue: T): State<T> {
+        val state = remember { mutableStateOf(initialValue) }
+
+        val animatedState = this
+        LaunchedEffect(animatedState) {
+            snapshotFlow { animatedState.value }.collect { state.value = it }
+        }
+
+        return state
+    }
 }
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/MovableElement.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/MovableElement.kt
index 4b20aca..be005ea 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/MovableElement.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/MovableElement.kt
@@ -77,7 +77,7 @@
     override fun <T> animateElementValueAsState(
         value: T,
         key: ValueKey,
-        lerp: (start: T, stop: T, fraction: Float) -> T,
+        type: SharedValueType<T, *>,
         canOverflow: Boolean
     ): AnimatedState<T> {
         return animateSharedValueAsState(
@@ -86,7 +86,7 @@
             element,
             key,
             value,
-            lerp,
+            type,
             canOverflow,
         )
     }
@@ -184,8 +184,7 @@
                 fromSceneZIndex = layoutImpl.scenes.getValue(transition.fromScene).zIndex,
                 toSceneZIndex = layoutImpl.scenes.getValue(transition.toScene).zIndex,
             ) != null
-        }
-            ?: return false
+        } ?: return false
 
     // Always compose movable elements in the scene picked by their scene picker.
     return shouldDrawOrComposeSharedElement(
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Scene.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Scene.kt
index 6fef33c..936f4ba 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Scene.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Scene.kt
@@ -24,7 +24,6 @@
 import androidx.compose.runtime.mutableFloatStateOf
 import androidx.compose.runtime.mutableStateOf
 import androidx.compose.runtime.setValue
-import androidx.compose.ui.ExperimentalComposeUiApi
 import androidx.compose.ui.Modifier
 import androidx.compose.ui.layout.approachLayout
 import androidx.compose.ui.platform.testTag
@@ -69,7 +68,6 @@
     }
 
     @Composable
-    @OptIn(ExperimentalComposeUiApi::class)
     fun Content(modifier: Modifier = Modifier) {
         Box(
             modifier
@@ -96,6 +94,7 @@
     private val layoutImpl: SceneTransitionLayoutImpl,
     private val scene: Scene,
 ) : SceneScope, ElementStateScope by layoutImpl.elementStateScope {
+    override val sceneKey: SceneKey = scene.key
     override val layoutState: SceneTransitionLayoutState = layoutImpl.state
 
     override fun Modifier.element(key: ElementKey): Modifier {
@@ -124,7 +123,7 @@
     override fun <T> animateSceneValueAsState(
         value: T,
         key: ValueKey,
-        lerp: (T, T, Float) -> T,
+        type: SharedValueType<T, *>,
         canOverflow: Boolean
     ): AnimatedState<T> {
         return animateSharedValueAsState(
@@ -133,7 +132,7 @@
             element = null,
             key = key,
             value = value,
-            lerp = lerp,
+            type = type,
             canOverflow = canOverflow,
         )
     }
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayout.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayout.kt
index cf8c584..2946b04 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayout.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayout.kt
@@ -167,6 +167,9 @@
 @Stable
 @ElementDsl
 interface BaseSceneScope : ElementStateScope {
+    /** The key of this scene. */
+    val sceneKey: SceneKey
+
     /** The state of the [SceneTransitionLayout] in which this scene is contained. */
     val layoutState: SceneTransitionLayoutState
 
@@ -285,9 +288,7 @@
      *
      * @param value the value of this shared value in the current scene.
      * @param key the key of this shared value.
-     * @param lerp the *linear* interpolation function that should be used to interpolate between
-     *   two different values. Note that it has to be linear because the [fraction] passed to this
-     *   interpolator is already interpolated.
+     * @param type the [SharedValueType] of this animated value.
      * @param canOverflow whether this value can overflow past the values it is interpolated
      *   between, for instance because the transition is animated using a bouncy spring.
      * @see animateSceneIntAsState
@@ -299,11 +300,46 @@
     fun <T> animateSceneValueAsState(
         value: T,
         key: ValueKey,
-        lerp: (start: T, stop: T, fraction: Float) -> T,
+        type: SharedValueType<T, *>,
         canOverflow: Boolean,
     ): AnimatedState<T>
 }
 
+/**
+ * The type of a shared value animated using [ElementScope.animateElementValueAsState] or
+ * [SceneScope.animateSceneValueAsState].
+ *
+ * @param T the type of the animated values.
+ * @param Delta the type of the difference between two values of type [T]. It can differ from [T],
+ *   for instance when the components of [T] are bounded and can't represent any difference.
+ */
+@Stable
+interface SharedValueType<T, Delta> {
+    /**
+     * The unspecified value for this type. It is used internally as a sentinel, so using it as the
+     * target value of a shared value is an error.
+     */
+    val unspecifiedValue: T
+
+    /**
+     * The zero delta value of this type. It should be equal to what `diff(x, x)` returns for any
+     * value of x.
+     */
+    val zeroDeltaValue: Delta
+
+    /**
+     * Return the linear interpolation of [a] and [b] at the given [progress], i.e.
+     * `a + (b - a) * progress`.
+     */
+    fun lerp(a: T, b: T, progress: Float): T
+
+    /** Return `a - b`. */
+    fun diff(a: T, b: T): Delta
+
+    /** Return `a + b * bWeight`. */
+    fun addWeighted(a: T, b: Delta, bWeight: Float): T
+}
+
 @Stable
 @ElementDsl
 interface ElementScope<ContentScope> {
@@ -312,9 +341,7 @@
      *
      * @param value the value of this shared value in the current scene.
      * @param key the key of this shared value.
-     * @param lerp the *linear* interpolation function that should be used to interpolate between
-     *   two different values. Note that it has to be linear because the [fraction] passed to this
-     *   interpolator is already interpolated.
+     * @param type the [SharedValueType] of this animated value.
      * @param canOverflow whether this value can overflow past the values it is interpolated
      *   between, for instance because the transition is animated using a bouncy spring.
      * @see animateElementIntAsState
@@ -326,7 +353,7 @@
     fun <T> animateElementValueAsState(
         value: T,
         key: ValueKey,
-        lerp: (start: T, stop: T, fraction: Float) -> T,
+        type: SharedValueType<T, *>,
         canOverflow: Boolean,
     ): AnimatedState<T>
 
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
index c614265..5fa7c87 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
@@ -85,15 +85,14 @@
      * The different values of a shared value keyed by a a [ValueKey] and the different elements and
      * scenes it is associated to.
      */
-    private var _sharedValues:
-        MutableMap<ValueKey, MutableMap<ElementKey?, SnapshotStateMap<SceneKey, *>>>? =
+    private var _sharedValues: MutableMap<ValueKey, MutableMap<ElementKey?, SharedValue<*, *>>>? =
         null
-    internal val sharedValues:
-        MutableMap<ValueKey, MutableMap<ElementKey?, SnapshotStateMap<SceneKey, *>>>
+    internal val sharedValues: MutableMap<ValueKey, MutableMap<ElementKey?, SharedValue<*, *>>>
         get() =
             _sharedValues
-                ?: mutableMapOf<ValueKey, MutableMap<ElementKey?, SnapshotStateMap<SceneKey, *>>>()
-                    .also { _sharedValues = it }
+                ?: mutableMapOf<ValueKey, MutableMap<ElementKey?, SharedValue<*, *>>>().also {
+                    _sharedValues = it
+                }
 
     // TODO(b/317958526): Lazily allocate scene gesture handlers the first time they are needed.
     private val horizontalDraggableHandler: DraggableHandlerImpl
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
index a5b6d24..44affd9 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
@@ -457,7 +457,7 @@
      */
     internal fun startTransition(
         transition: TransitionState.Transition,
-        transitionKey: TransitionKey?,
+        transitionKey: TransitionKey? = null,
         chain: Boolean = true,
     ) {
         checkThread()
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/AnimatedSharedAsStateTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/AnimatedSharedAsStateTest.kt
index e8854cf..6e8b208 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/AnimatedSharedAsStateTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/AnimatedSharedAsStateTest.kt
@@ -32,7 +32,12 @@
 import androidx.compose.ui.unit.lerp
 import androidx.compose.ui.util.lerp
 import androidx.test.ext.junit.runners.AndroidJUnit4
+import com.android.compose.animation.scene.TestScenes.SceneA
+import com.android.compose.animation.scene.TestScenes.SceneB
+import com.android.compose.animation.scene.TestScenes.SceneC
+import com.android.compose.animation.scene.TestScenes.SceneD
 import com.google.common.truth.Truth.assertThat
+import kotlinx.coroutines.test.runTest
 import org.junit.Assert.assertThrows
 import org.junit.Rule
 import org.junit.Test
@@ -130,8 +135,8 @@
                 // The transition lasts 64ms = 4 frames.
                 spec = tween(durationMillis = 16 * 4, easing = LinearEasing)
             },
-            fromScene = TestScenes.SceneA,
-            toScene = TestScenes.SceneB,
+            fromScene = SceneA,
+            toScene = SceneB,
         ) {
             before {
                 assertThat(lastValueInFrom).isEqualTo(fromValues)
@@ -189,8 +194,8 @@
                 // The transition lasts 64ms = 4 frames.
                 spec = tween(durationMillis = 16 * 4, easing = LinearEasing)
             },
-            fromScene = TestScenes.SceneA,
-            toScene = TestScenes.SceneB,
+            fromScene = SceneA,
+            toScene = SceneB,
         ) {
             before {
                 assertThat(lastValueInFrom).isEqualTo(fromValues)
@@ -243,8 +248,8 @@
                 // The transition lasts 64ms = 4 frames.
                 spec = tween(durationMillis = 16 * 4, easing = LinearEasing)
             },
-            fromScene = TestScenes.SceneA,
-            toScene = TestScenes.SceneB,
+            fromScene = SceneA,
+            toScene = SceneB,
         ) {
             before {
                 assertThat(lastValueInFrom).isEqualTo(fromValues)
@@ -381,4 +386,65 @@
             }
         }
     }
+
+    /**
+     * Runs A => B and C => D at the same time and checks that each shared value is animated by the
+     * transition involving the scenes in which it is defined.
+     */
+    @Test
+    fun animatedValueIsUsingLastTransition() = runTest {
+        val state =
+            rule.runOnUiThread { MutableSceneTransitionLayoutStateImpl(SceneA, transitions {}) }
+
+        val foo = ValueKey("foo")
+        val bar = ValueKey("bar")
+        val lastValues = mutableMapOf<ValueKey, MutableMap<SceneKey, Float>>()
+
+        @Composable
+        fun SceneScope.animateFloat(value: Float, key: ValueKey) {
+            val animatedValue = animateSceneFloatAsState(value, key)
+            LaunchedEffect(animatedValue) {
+                snapshotFlow { animatedValue.value }
+                    .collect { lastValues.getOrPut(key) { mutableMapOf() }[sceneKey] = it }
+            }
+        }
+
+        rule.setContent {
+            SceneTransitionLayout(state) {
+                // foo goes from 0f to 100f in A => B.
+                scene(SceneA) { animateFloat(0f, foo) }
+                scene(SceneB) { animateFloat(100f, foo) }
+
+                // bar goes from 0f to 10f in C => D.
+                scene(SceneC) { animateFloat(0f, bar) }
+                scene(SceneD) { animateFloat(10f, bar) }
+            }
+        }
+
+        rule.runOnUiThread {
+            // A => B is at 30%.
+            state.startTransition(
+                transition(
+                    from = SceneA,
+                    to = SceneB,
+                    progress = { 0.3f },
+                    onFinish = neverFinish(),
+                )
+            )
+
+            // C => D is at 70%.
+            state.startTransition(transition(from = SceneC, to = SceneD, progress = { 0.7f }))
+        }
+        rule.waitForIdle()
+
+        assertThat(lastValues[foo]?.get(SceneA)).isWithin(0.001f).of(30f)
+        assertThat(lastValues[foo]?.get(SceneB)).isWithin(0.001f).of(30f)
+        assertThat(lastValues[foo]?.get(SceneC)).isNull()
+        assertThat(lastValues[foo]?.get(SceneD)).isNull()
+
+        assertThat(lastValues[bar]?.get(SceneA)).isNull()
+        assertThat(lastValues[bar]?.get(SceneB)).isNull()
+        assertThat(lastValues[bar]?.get(SceneC)).isWithin(0.001f).of(7f)
+        assertThat(lastValues[bar]?.get(SceneD)).isWithin(0.001f).of(7f)
+    }
 }
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
index 9692fae..beb74bc 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
@@ -59,6 +59,7 @@
 import androidx.compose.ui.unit.IntSize
 import androidx.compose.ui.unit.dp
 import androidx.compose.ui.unit.lerp
+import androidx.compose.ui.util.lerp
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import com.android.compose.animation.scene.TestScenes.SceneA
 import com.android.compose.animation.scene.TestScenes.SceneB
@@ -348,7 +349,7 @@
                     ),
                 onLayoutImpl = { nullableLayoutImpl = it },
             ) {
-                scene(SceneA) { /* Nothing */}
+                scene(SceneA) { /* Nothing */ }
                 scene(SceneB) { Box(Modifier.element(key)) }
                 scene(SceneC) {
                     when (sceneCState) {
@@ -1083,10 +1084,17 @@
             }
 
         val layoutSize = DpSize(200.dp, 100.dp)
+        val lastValues = mutableMapOf<SceneKey, Float>()
 
         @Composable
-        fun SceneScope.Foo(size: Dp, modifier: Modifier = Modifier) {
-            Box(modifier.element(TestElements.Foo).size(size))
+        fun SceneScope.Foo(size: Dp, value: Float, modifier: Modifier = Modifier) {
+            val sceneKey = this.sceneKey
+            Element(TestElements.Foo, modifier.size(size)) {
+                val animatedValue = animateElementFloatAsState(value, TestValues.Value1)
+                LaunchedEffect(animatedValue) {
+                    snapshotFlow { animatedValue.value }.collect { lastValues[sceneKey] = it }
+                }
+            }
         }
 
         // The size of Foo when idle in A, B or C.
@@ -1094,6 +1102,11 @@
         val sizeInB = 30.dp
         val sizeInC = 50.dp
 
+        // The target value when idle in A, B, or C.
+        val valueInA = 0f
+        val valueInB = 100f
+        val valueInC = 200f
+
         lateinit var layoutImpl: SceneTransitionLayoutImpl
         rule.setContent {
             SceneTransitionLayoutForTesting(
@@ -1103,7 +1116,9 @@
             ) {
                 // In scene A, Foo is aligned at the TopStart.
                 scene(SceneA) {
-                    Box(Modifier.fillMaxSize()) { Foo(sizeInA, Modifier.align(Alignment.TopStart)) }
+                    Box(Modifier.fillMaxSize()) {
+                        Foo(sizeInA, valueInA, Modifier.align(Alignment.TopStart))
+                    }
                 }
 
                 // In scene C, Foo is aligned at the BottomEnd, so it moves vertically when coming
@@ -1111,14 +1126,16 @@
                 // values and deltas are properly cleared once all transitions are done.
                 scene(SceneC) {
                     Box(Modifier.fillMaxSize()) {
-                        Foo(sizeInC, Modifier.align(Alignment.BottomEnd))
+                        Foo(sizeInC, valueInC, Modifier.align(Alignment.BottomEnd))
                     }
                 }
 
                 // In scene B, Foo is aligned at the TopEnd, so it moves horizontally when coming
                 // from A.
                 scene(SceneB) {
-                    Box(Modifier.fillMaxSize()) { Foo(sizeInB, Modifier.align(Alignment.TopEnd)) }
+                    Box(Modifier.fillMaxSize()) {
+                        Foo(sizeInB, valueInB, Modifier.align(Alignment.TopEnd))
+                    }
                 }
             }
         }
@@ -1134,6 +1151,10 @@
             .assertSizeIsEqualTo(sizeInA)
             .assertPositionInRootIsEqualTo(offsetInA.x, offsetInA.y)
 
+        assertThat(lastValues[SceneA]).isWithin(0.001f).of(valueInA)
+        assertThat(lastValues[SceneB]).isNull()
+        assertThat(lastValues[SceneC]).isNull()
+
         // Current transition is A => B at 50%.
         val aToBProgress = 0.5f
         val aToB =
@@ -1145,12 +1166,17 @@
             )
         val offsetInAToB = lerp(offsetInA, offsetInB, aToBProgress)
         val sizeInAToB = lerp(sizeInA, sizeInB, aToBProgress)
+        val valueInAToB = lerp(valueInA, valueInB, aToBProgress)
         rule.runOnUiThread { state.startTransition(aToB, transitionKey = null) }
         rule
             .onNode(isElement(TestElements.Foo, SceneB))
             .assertSizeIsEqualTo(sizeInAToB)
             .assertPositionInRootIsEqualTo(offsetInAToB.x, offsetInAToB.y)
 
+        assertThat(lastValues[SceneA]).isWithin(0.001f).of(valueInAToB)
+        assertThat(lastValues[SceneB]).isWithin(0.001f).of(valueInAToB)
+        assertThat(lastValues[SceneC]).isNull()
+
         // Start B => C at 0%.
         var bToCProgress by mutableFloatStateOf(0f)
         var interruptionProgress by mutableFloatStateOf(1f)
@@ -1167,6 +1193,11 @@
         // to the current transition offset and size.
         val offsetInterruptionDelta = offsetInAToB - offsetInB
         val sizeInterruptionDelta = sizeInAToB - sizeInB
+        val valueInterruptionDelta = valueInAToB - valueInB
+
+        assertThat(offsetInterruptionDelta).isNotEqualTo(DpOffset.Zero)
+        assertThat(sizeInterruptionDelta).isNotEqualTo(0.dp)
+        assertThat(valueInterruptionDelta).isNotEqualTo(0f)
 
         // Interruption progress is at 100% and bToC is at 0%, so Foo should be at the same offset
         // and size as right before the interruption.
@@ -1175,11 +1206,16 @@
             .assertPositionInRootIsEqualTo(offsetInAToB.x, offsetInAToB.y)
             .assertSizeIsEqualTo(sizeInAToB)
 
+        assertThat(lastValues[SceneA]).isWithin(0.001f).of(valueInAToB)
+        assertThat(lastValues[SceneB]).isWithin(0.001f).of(valueInAToB)
+        assertThat(lastValues[SceneC]).isWithin(0.001f).of(valueInAToB)
+
         // Move the transition forward at 30% and set the interruption progress to 50%.
         bToCProgress = 0.3f
         interruptionProgress = 0.5f
         val offsetInBToC = lerp(offsetInB, offsetInC, bToCProgress)
         val sizeInBToC = lerp(sizeInB, sizeInC, bToCProgress)
+        val valueInBToC = lerp(valueInB, valueInC, bToCProgress)
         val offsetInBToCWithInterruption =
             offsetInBToC +
                 DpOffset(
@@ -1187,6 +1223,9 @@
                     offsetInterruptionDelta.y * interruptionProgress,
                 )
         val sizeInBToCWithInterruption = sizeInBToC + sizeInterruptionDelta * interruptionProgress
+        val valueInBToCWithInterruption =
+            valueInBToC + valueInterruptionDelta * interruptionProgress
+
         rule.waitForIdle()
         rule
             .onNode(isElement(TestElements.Foo, SceneB))
@@ -1196,6 +1235,10 @@
             )
             .assertSizeIsEqualTo(sizeInBToCWithInterruption)
 
+        assertThat(lastValues[SceneA]).isWithin(0.001f).of(valueInBToCWithInterruption)
+        assertThat(lastValues[SceneB]).isWithin(0.001f).of(valueInBToCWithInterruption)
+        assertThat(lastValues[SceneC]).isWithin(0.001f).of(valueInBToCWithInterruption)
+
         // Finish the transition and interruption.
         bToCProgress = 1f
         interruptionProgress = 0f
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutTest.kt
index 3751a22..08532bd 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutTest.kt
@@ -461,4 +461,30 @@
         assertThat(exception).hasMessageThat().contains(Back.toString())
         assertThat(exception).hasMessageThat().contains(SceneA.debugName)
     }
+
+    @Test
+    fun sceneKeyInScope() {
+        // The sceneKey each scene's content observed in its SceneScope, keyed by that scene.
+        val observedKeys = mutableMapOf<SceneKey, SceneKey>()
+        val state = rule.runOnUiThread { MutableSceneTransitionLayoutState(SceneA) }
+        rule.setContent {
+            SceneTransitionLayout(state) {
+                scene(SceneA) { observedKeys[SceneA] = sceneKey }
+                scene(SceneB) { observedKeys[SceneB] = sceneKey }
+                scene(SceneC) { observedKeys[SceneC] = sceneKey }
+            }
+        }
+
+        // The layout starts idle on SceneA; snap to B and then to C so that each scene is
+        // composed at least once and records the key it sees.
+        listOf(SceneB, SceneC).forEach { target ->
+            rule.runOnUiThread { state.snapToScene(target) }
+            rule.waitForIdle()
+        }
+
+        // Every scene must have observed its own key through SceneScope.sceneKey.
+        listOf(SceneA, SceneB, SceneC).forEach { scene ->
+            assertThat(observedKeys[scene]).isEqualTo(scene)
+        }
+    }
 }