Prevent elements from jump-cutting after an interruption

This CL adds initial support for interruptions.

Before this CL, whenever a new transition was started, all elements
taking part in this transition (i.e. elements in the transition
fromScene or toScene) would instantly jump to their new state in this
transition.

This CL improves this by tracking the current transition used by an
element. If that transition is interrupted (i.e. the transition for that
element changed), we save the current state of the element. Then, the
next time we compute the state of the element, which is the first time
we compute the state of the element given the new transition, we also
compute and save the diff/delta between the state we saved earlier and
the new state. This delta is then animated to zero and added to the
state computed with the new transition. That way, we nicely animate to
the new state of the new transition while preventing jump cuts caused by
the interruption.

This CL adds support for elements offset, alpha and scale. The size will
be supported in a follow-up CL; it is quite harder to support given that
elements can be measured only once.

The unit test focus on offset only at the moment, because testing scale
and alpha in a unit test can't really be done without changing
production code. I *might* add some screenshot tests for this in another
CL.

See b/290930950#comment5 for details on how this works.

Bug: 290930950
Test: ElementTest
Test: Performed a lot of different transitions in a row in the STL demo
 manually.
Test: This change should not have any impact on usages outside of the
 STL demo given that interruptions were disabled everywhere in
 ag/26621600 and will be enabled later after thorough manual testing.

Flag: N/A

Change-Id: I5b2ee38d71b947cd3aaf587194281c75eca1a0ca
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt
index da07f6d..6b289f3 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt
@@ -190,4 +190,4 @@
 
 // TODO(b/290184746): Compute a good default visibility threshold that depends on the layout size
 // and screen density.
-private const val ProgressVisibilityThreshold = 1e-3f
+internal const val ProgressVisibilityThreshold = 1e-3f
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
index 7d43ca8..4273b4f 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
@@ -24,6 +24,7 @@
 import androidx.compose.ui.ExperimentalComposeUiApi
 import androidx.compose.ui.Modifier
 import androidx.compose.ui.geometry.Offset
+import androidx.compose.ui.geometry.isSpecified
 import androidx.compose.ui.geometry.isUnspecified
 import androidx.compose.ui.geometry.lerp
 import androidx.compose.ui.graphics.CompositingStrategy
@@ -55,9 +56,15 @@
 internal class Element(val key: ElementKey) {
     /** The mapping between a scene and the state this element has in that scene, if any. */
     // TODO(b/316901148): Make this a normal map instead once we can make sure that new transitions
-    // are first seen by composition then layout/drawing code. See 316901148#comment2 for details.
+    // are first seen by composition then layout/drawing code. See b/316901148#comment2 for details.
     val sceneStates = SnapshotStateMap<SceneKey, SceneState>()
 
+    /**
+     * The last transition that was used when computing the state (size, position and alpha) of this
+     * element in any scene, or `null` if it was last laid out when idle.
+     */
+    var lastTransition: TransitionState.Transition? = null
+
     override fun toString(): String {
         return "Element(key=$key)"
     }
@@ -65,9 +72,33 @@
     /** The last and target state of this element in a given scene. */
     @Stable
     class SceneState(val scene: SceneKey) {
+        /**
+         * The *target* state of this element in this scene, i.e. the state of this element when we
+         * are idle on this scene.
+         */
         var targetSize by mutableStateOf(SizeUnspecified)
         var targetOffset by mutableStateOf(Offset.Unspecified)
 
+        /** The last state this element had in this scene. */
+        var lastOffset = Offset.Unspecified
+        var lastScale = Scale.Unspecified
+        var lastAlpha = AlphaUnspecified
+
+        /** The state of this element in this scene right before the last interruption (if any). */
+        var offsetBeforeInterruption = Offset.Unspecified
+        var scaleBeforeInterruption = Scale.Unspecified
+        var alphaBeforeInterruption = AlphaUnspecified
+
+        /**
+         * The delta values to add to this element state to have smoother interruptions. These
+         * should be multiplied by the
+         * [current interruption progress][TransitionState.Transition.interruptionProgress] so that
+         * they nicely animate from their values down to 0.
+         */
+        var offsetInterruptionDelta = Offset.Zero
+        var scaleInterruptionDelta = Scale.Zero
+        var alphaInterruptionDelta = 0f
+
         /**
          * The attached [ElementNode] a Modifier.element() for a given element and scene. During
          * composition, this set could have 0 to 2 elements. After composition and after all
@@ -78,12 +109,15 @@
 
     companion object {
         val SizeUnspecified = IntSize(Int.MAX_VALUE, Int.MAX_VALUE)
+        val AlphaUnspecified = Float.MAX_VALUE
     }
 }
 
 data class Scale(val scaleX: Float, val scaleY: Float, val pivot: Offset = Offset.Unspecified) {
     companion object {
         val Default = Scale(1f, 1f, Offset.Unspecified)
+        val Zero = Scale(0f, 0f, Offset.Zero)
+        val Unspecified = Scale(Float.MAX_VALUE, Float.MAX_VALUE, Offset.Unspecified)
     }
 }
 
@@ -212,6 +246,10 @@
         val isOtherSceneOverscrolling = overscrollScene != null && overscrollScene != scene.key
         val isNotPartOfAnyOngoingTransitions = transitions.isNotEmpty() && transition == null
         if (isNotPartOfAnyOngoingTransitions || isOtherSceneOverscrolling) {
+            sceneState.lastOffset = Offset.Unspecified
+            sceneState.lastScale = Scale.Unspecified
+            sceneState.lastAlpha = Element.AlphaUnspecified
+
             val placeable = measurable.measure(constraints)
             return layout(placeable.width, placeable.height) {}
         }
@@ -233,7 +271,7 @@
 
     override fun ContentDrawScope.draw() {
         val transition = elementTransition(element, layoutImpl.state.currentTransitions)
-        val drawScale = getDrawScale(layoutImpl, scene, element, transition)
+        val drawScale = getDrawScale(layoutImpl, scene, element, transition, sceneState)
         if (drawScale == Scale.Default) {
             drawContent()
         } else {
@@ -276,8 +314,116 @@
     element: Element,
     transitions: List<TransitionState.Transition>,
 ): TransitionState.Transition? {
-    return transitions.fastLastOrNull { transition ->
-        transition.fromScene in element.sceneStates || transition.toScene in element.sceneStates
+    val transition =
+        transitions.fastLastOrNull { transition ->
+            transition.fromScene in element.sceneStates || transition.toScene in element.sceneStates
+        }
+
+    val previousTransition = element.lastTransition
+    element.lastTransition = transition
+
+    if (transition != previousTransition && transition != null && previousTransition != null) {
+        // The previous transition was interrupted by another transition.
+        prepareInterruption(element)
+    }
+
+    if (transition == null && previousTransition != null) {
+        // The transition was just finished.
+        element.sceneStates.values.forEach { sceneState ->
+            sceneState.offsetInterruptionDelta = Offset.Zero
+            sceneState.scaleInterruptionDelta = Scale.Zero
+            sceneState.alphaInterruptionDelta = 0f
+        }
+    }
+
+    return transition
+}
+
+private fun prepareInterruption(element: Element) {
+    // We look for the last unique state of this element so that we animate the delta with its
+    // future state.
+    val sceneStates = element.sceneStates.values
+    var lastUniqueState: Element.SceneState? = null
+    for (sceneState in sceneStates) {
+        val offset = sceneState.lastOffset
+
+        // If the element was placed in this scene...
+        if (offset != Offset.Unspecified) {
+            // ... and it is the first (and potentially the only) scene where the element was
+            // placed, save the state for later.
+            if (lastUniqueState == null) {
+                lastUniqueState = sceneState
+            } else {
+                // The element was placed in multiple scenes: we abort the interruption for this
+                // element.
+                // TODO(b/290930950): Better support cases where a shared element animation is
+                // disabled and the same element is drawn/placed in multiple scenes at the same
+                // time.
+                lastUniqueState = null
+                break
+            }
+        }
+    }
+
+    val lastOffset = lastUniqueState?.lastOffset ?: Offset.Unspecified
+    val lastScale = lastUniqueState?.lastScale ?: Scale.Unspecified
+    val lastAlpha = lastUniqueState?.lastAlpha ?: Element.AlphaUnspecified
+
+    // Store the state of the element before the interruption and reset the deltas.
+    sceneStates.forEach { sceneState ->
+        sceneState.offsetBeforeInterruption = lastOffset
+        sceneState.scaleBeforeInterruption = lastScale
+        sceneState.alphaBeforeInterruption = lastAlpha
+
+        sceneState.offsetInterruptionDelta = Offset.Zero
+        sceneState.scaleInterruptionDelta = Scale.Zero
+        sceneState.alphaInterruptionDelta = 0f
+    }
+}
+
+/**
+ * Compute what [value] should be if we take the
+ * [interruption progress][TransitionState.Transition.interruptionProgress] of [transition] into
+ * account.
+ */
+private inline fun <T> computeInterruptedValue(
+    layoutImpl: SceneTransitionLayoutImpl,
+    transition: TransitionState.Transition?,
+    value: T,
+    unspecifiedValue: T,
+    zeroValue: T,
+    getValueBeforeInterruption: () -> T,
+    setValueBeforeInterruption: (T) -> Unit,
+    getInterruptionDelta: () -> T,
+    setInterruptionDelta: (T) -> Unit,
+    diff: (a: T, b: T) -> T, // a - b
+    add: (a: T, b: T, bProgress: Float) -> T, // a + (b * bProgress)
+): T {
+    val valueBeforeInterruption = getValueBeforeInterruption()
+
+    // If the value before the interruption is specified, it means that this is the first time we
+    // compute [value] right after an interruption.
+    if (valueBeforeInterruption != unspecifiedValue) {
+        // Compute and store the delta between the value before the interruption and the current
+        // value.
+        setInterruptionDelta(diff(valueBeforeInterruption, value))
+
+        // Reset the value before interruption now that we processed it.
+        setValueBeforeInterruption(unspecifiedValue)
+    }
+
+    val delta = getInterruptionDelta()
+    return if (delta == zeroValue || transition == null) {
+        // There was no interruption or there is no transition: just return the value.
+        value
+    } else {
+        // Add `delta * interruptionProgress` to the value so that we animate to value.
+        val interruptionProgress = transition.interruptionProgress(layoutImpl)
+        if (interruptionProgress == 0f) {
+            value
+        } else {
+            add(value, delta, interruptionProgress)
+        }
     }
 }
 
@@ -417,20 +563,47 @@
     scene: Scene,
     element: Element,
     transition: TransitionState.Transition?,
+    sceneState: Element.SceneState,
 ): Float {
-    return computeValue(
-            layoutImpl,
-            scene,
-            element,
-            transition,
-            sceneValue = { 1f },
-            transformation = { it.alpha },
-            idleValue = 1f,
-            currentValue = { 1f },
-            isSpecified = { true },
-            ::lerp,
-        )
-        .fastCoerceIn(0f, 1f)
+    val alpha =
+        computeValue(
+                layoutImpl,
+                scene,
+                element,
+                transition,
+                sceneValue = { 1f },
+                transformation = { it.alpha },
+                idleValue = 1f,
+                currentValue = { 1f },
+                isSpecified = { true },
+                ::lerp,
+            )
+            .fastCoerceIn(0f, 1f)
+
+    val interruptedAlpha = interruptedAlpha(layoutImpl, transition, sceneState, alpha)
+    sceneState.lastAlpha = interruptedAlpha
+    return interruptedAlpha
+}
+
+private fun interruptedAlpha(
+    layoutImpl: SceneTransitionLayoutImpl,
+    transition: TransitionState.Transition?,
+    sceneState: Element.SceneState,
+    alpha: Float,
+): Float {
+    return computeInterruptedValue(
+        layoutImpl,
+        transition,
+        value = alpha,
+        unspecifiedValue = Element.AlphaUnspecified,
+        zeroValue = 0f,
+        getValueBeforeInterruption = { sceneState.alphaBeforeInterruption },
+        setValueBeforeInterruption = { sceneState.alphaBeforeInterruption = it },
+        getInterruptionDelta = { sceneState.alphaInterruptionDelta },
+        setInterruptionDelta = { sceneState.alphaInterruptionDelta = it },
+        diff = { a, b -> a - b },
+        add = { a, b, bProgress -> a + b * bProgress },
+    )
 }
 
 @OptIn(ExperimentalComposeUiApi::class)
@@ -480,24 +653,70 @@
         )
 }
 
-private fun getDrawScale(
+private fun ContentDrawScope.getDrawScale(
     layoutImpl: SceneTransitionLayoutImpl,
     scene: Scene,
     element: Element,
     transition: TransitionState.Transition?,
+    sceneState: Element.SceneState,
 ): Scale {
-    return computeValue(
-        layoutImpl,
-        scene,
-        element,
-        transition,
-        sceneValue = { Scale.Default },
-        transformation = { it.drawScale },
-        idleValue = Scale.Default,
-        currentValue = { Scale.Default },
-        isSpecified = { true },
-        ::lerp,
-    )
+    val scale =
+        computeValue(
+            layoutImpl,
+            scene,
+            element,
+            transition,
+            sceneValue = { Scale.Default },
+            transformation = { it.drawScale },
+            idleValue = Scale.Default,
+            currentValue = { Scale.Default },
+            isSpecified = { true },
+            ::lerp,
+        )
+
+    fun Offset.specifiedOrCenter(): Offset {
+        return this.takeIf { isSpecified } ?: center
+    }
+
+    val interruptedScale =
+        computeInterruptedValue(
+            layoutImpl,
+            transition,
+            value = scale,
+            unspecifiedValue = Scale.Unspecified,
+            zeroValue = Scale.Zero,
+            getValueBeforeInterruption = { sceneState.scaleBeforeInterruption },
+            setValueBeforeInterruption = { sceneState.scaleBeforeInterruption = it },
+            getInterruptionDelta = { sceneState.scaleInterruptionDelta },
+            setInterruptionDelta = { sceneState.scaleInterruptionDelta = it },
+            diff = { a, b ->
+                Scale(
+                    scaleX = a.scaleX - b.scaleX,
+                    scaleY = a.scaleY - b.scaleY,
+                    pivot =
+                        if (a.pivot.isUnspecified && b.pivot.isUnspecified) {
+                            Offset.Unspecified
+                        } else {
+                            a.pivot.specifiedOrCenter() - b.pivot.specifiedOrCenter()
+                        }
+                )
+            },
+            add = { a, b, bProgress ->
+                Scale(
+                    scaleX = a.scaleX + b.scaleX * bProgress,
+                    scaleY = a.scaleY + b.scaleY * bProgress,
+                    pivot =
+                        if (a.pivot.isUnspecified && b.pivot.isUnspecified) {
+                            Offset.Unspecified
+                        } else {
+                            a.pivot.specifiedOrCenter() + b.pivot.specifiedOrCenter() * bProgress
+                        }
+                )
+            }
+        )
+
+    sceneState.lastScale = interruptedScale
+    return interruptedScale
 }
 
 @OptIn(ExperimentalComposeUiApi::class)
@@ -524,6 +743,8 @@
 
         // No need to place the element in this scene if we don't want to draw it anyways.
         if (!shouldPlaceElement(layoutImpl, scene, element, transition)) {
+            sceneState.lastOffset = Offset.Unspecified
+            sceneState.offsetBeforeInterruption = Offset.Unspecified
             return
         }
 
@@ -542,15 +763,37 @@
                 ::lerp,
             )
 
-        val offset = (targetOffset - currentOffset).round()
-        if (isElementOpaque(scene, element, transition)) {
+        val interruptedOffset =
+            computeInterruptedValue(
+                layoutImpl,
+                transition,
+                value = targetOffset,
+                unspecifiedValue = Offset.Unspecified,
+                zeroValue = Offset.Zero,
+                getValueBeforeInterruption = { sceneState.offsetBeforeInterruption },
+                setValueBeforeInterruption = { sceneState.offsetBeforeInterruption = it },
+                getInterruptionDelta = { sceneState.offsetInterruptionDelta },
+                setInterruptionDelta = { sceneState.offsetInterruptionDelta = it },
+                diff = { a, b -> a - b },
+                add = { a, b, bProgress -> a + b * bProgress },
+            )
+
+        sceneState.lastOffset = interruptedOffset
+
+        val offset = (interruptedOffset - currentOffset).round()
+        if (
+            isElementOpaque(scene, element, transition) &&
+                interruptedAlpha(layoutImpl, transition, sceneState, alpha = 1f) == 1f
+        ) {
+            sceneState.lastAlpha = 1f
+
             // TODO(b/291071158): Call placeWithLayer() if offset != IntOffset.Zero and size is not
             // animated once b/305195729 is fixed. Test that drawing is not invalidated in that
             // case.
             placeable.place(offset)
         } else {
             placeable.placeWithLayer(offset) {
-                alpha = elementAlpha(layoutImpl, scene, element, transition)
+                alpha = elementAlpha(layoutImpl, scene, element, transition, sceneState)
                 compositingStrategy = CompositingStrategy.ModulateAlpha
             }
         }
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
index 20dcc20..ad691ba 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
@@ -49,7 +49,7 @@
     internal var swipeSourceDetector: SwipeSourceDetector,
     internal var transitionInterceptionThreshold: Float,
     builder: SceneTransitionLayoutScope.() -> Unit,
-    private val coroutineScope: CoroutineScope,
+    internal val coroutineScope: CoroutineScope,
 ) {
     /**
      * The map of [Scene]s.
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
index f13c016..5fda77a 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
@@ -18,6 +18,9 @@
 
 import android.util.Log
 import androidx.annotation.VisibleForTesting
+import androidx.compose.animation.core.Animatable
+import androidx.compose.animation.core.AnimationVector1D
+import androidx.compose.animation.core.spring
 import androidx.compose.foundation.gestures.Orientation
 import androidx.compose.runtime.Composable
 import androidx.compose.runtime.LaunchedEffect
@@ -34,6 +37,7 @@
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.Job
 import kotlinx.coroutines.channels.Channel
+import kotlinx.coroutines.launch
 
 /**
  * The state of a [SceneTransitionLayout].
@@ -253,6 +257,12 @@
                 }
             }
 
+        /**
+         * An animatable that animates from 1f to 0f. This will be used to nicely animate the sudden
+         * jump of values when this transitions interrupts another one.
+         */
+        private var interruptionDecay: Animatable<Float, AnimationVector1D>? = null
+
         init {
             check(fromScene != toScene)
         }
@@ -289,6 +299,33 @@
             fromOverscrollSpec = fromSpec
             toOverscrollSpec = toSpec
         }
+
+        internal open fun interruptionProgress(
+            layoutImpl: SceneTransitionLayoutImpl,
+        ): Float {
+            if (!layoutImpl.state.enableInterruptions) {
+                return 0f
+            }
+
+            fun create(): Animatable<Float, AnimationVector1D> {
+                val animatable = Animatable(1f, visibilityThreshold = ProgressVisibilityThreshold)
+                layoutImpl.coroutineScope.launch {
+                    val swipeSpec = layoutImpl.state.transitions.defaultSwipeSpec
+                    val progressSpec =
+                        spring(
+                            stiffness = swipeSpec.stiffness,
+                            dampingRatio = swipeSpec.dampingRatio,
+                            visibilityThreshold = ProgressVisibilityThreshold,
+                        )
+                    animatable.animateTo(0f, progressSpec)
+                }
+
+                return animatable
+            }
+
+            val animatable = interruptionDecay ?: create().also { interruptionDecay = it }
+            return animatable.value
+        }
     }
 
     interface HasOverscrollProperties {
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
index b7fc91c..b1d7055 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
@@ -37,6 +37,7 @@
 import androidx.compose.runtime.LaunchedEffect
 import androidx.compose.runtime.SideEffect
 import androidx.compose.runtime.getValue
+import androidx.compose.runtime.mutableFloatStateOf
 import androidx.compose.runtime.mutableStateOf
 import androidx.compose.runtime.rememberCoroutineScope
 import androidx.compose.runtime.setValue
@@ -55,7 +56,10 @@
 import androidx.compose.ui.test.onRoot
 import androidx.compose.ui.test.performTouchInput
 import androidx.compose.ui.unit.Dp
+import androidx.compose.ui.unit.DpOffset
+import androidx.compose.ui.unit.DpSize
 import androidx.compose.ui.unit.dp
+import androidx.compose.ui.unit.lerp
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import com.android.compose.animation.scene.TestScenes.SceneA
 import com.android.compose.animation.scene.TestScenes.SceneB
@@ -1019,4 +1023,122 @@
         rule.onNode(isElement(TestElements.Foo)).assertDoesNotExist()
         rule.onNode(isElement(TestElements.Bar)).assertPositionInRootIsEqualTo(100.dp, 100.dp)
     }
+
+    @Test
+    fun interruption() = runTest {
+        // 4 frames of animation.
+        val duration = 4 * 16
+
+        val state =
+            MutableSceneTransitionLayoutStateImpl(
+                SceneA,
+                transitions {
+                    from(SceneA, to = SceneB) { spec = tween(duration, easing = LinearEasing) }
+                    from(SceneB, to = SceneC) { spec = tween(duration, easing = LinearEasing) }
+                },
+                enableInterruptions = false,
+            )
+
+        val layoutSize = DpSize(200.dp, 100.dp)
+        val fooSize = DpSize(20.dp, 10.dp)
+
+        @Composable
+        fun SceneScope.Foo(modifier: Modifier = Modifier) {
+            Box(modifier.element(TestElements.Foo).size(fooSize))
+        }
+
+        rule.setContent {
+            SceneTransitionLayout(state, Modifier.size(layoutSize)) {
+                // In scene A, Foo is aligned at the TopStart.
+                scene(SceneA) {
+                    Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.TopStart)) }
+                }
+
+                // In scene B, Foo is aligned at the TopEnd, so it moves horizontally when coming
+                // from A.
+                scene(SceneB) {
+                    Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.TopEnd)) }
+                }
+
+                // In scene C, Foo is aligned at the BottomEnd, so it moves vertically when coming
+                // from B.
+                scene(SceneC) {
+                    Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.BottomEnd)) }
+                }
+            }
+        }
+
+        // The offset of Foo when idle in A, B or C.
+        val offsetInA = DpOffset.Zero
+        val offsetInB = DpOffset(layoutSize.width - fooSize.width, 0.dp)
+        val offsetInC =
+            DpOffset(layoutSize.width - fooSize.width, layoutSize.height - fooSize.height)
+
+        // Initial state (idle in A).
+        rule
+            .onNode(isElement(TestElements.Foo, SceneA))
+            .assertPositionInRootIsEqualTo(offsetInA.x, offsetInA.y)
+
+        // Current transition is A => B at 50%.
+        val aToBProgress = 0.5f
+        val aToB =
+            transition(
+                from = SceneA,
+                to = SceneB,
+                progress = { aToBProgress },
+                onFinish = neverFinish(),
+            )
+        val offsetInAToB = lerp(offsetInA, offsetInB, aToBProgress)
+        rule.runOnUiThread { state.startTransition(aToB, transitionKey = null) }
+        rule
+            .onNode(isElement(TestElements.Foo, SceneB))
+            .assertPositionInRootIsEqualTo(offsetInAToB.x, offsetInAToB.y)
+
+        // Start B => C at 0%.
+        var bToCProgress by mutableFloatStateOf(0f)
+        var interruptionProgress by mutableFloatStateOf(1f)
+        val bToC =
+            transition(
+                from = SceneB,
+                to = SceneC,
+                progress = { bToCProgress },
+                interruptionProgress = { interruptionProgress },
+            )
+        rule.runOnUiThread { state.startTransition(bToC, transitionKey = null) }
+
+        // The offset interruption delta, which will be multiplied by the interruption progress then
+        // added to the current transition offset.
+        val interruptionDelta = offsetInAToB - offsetInB
+
+        // Interruption progress is at 100% and bToC is at 0%, so Foo should be at the same offset
+        // as right before the interruption.
+        rule
+            .onNode(isElement(TestElements.Foo, SceneC))
+            .assertPositionInRootIsEqualTo(offsetInAToB.x, offsetInAToB.y)
+
+        // Move the transition forward at 30% and set the interruption progress to 50%.
+        bToCProgress = 0.3f
+        interruptionProgress = 0.5f
+        val offsetInBToC = lerp(offsetInB, offsetInC, bToCProgress)
+        val offsetInBToCWithInterruption =
+            offsetInBToC +
+                DpOffset(
+                    interruptionDelta.x * interruptionProgress,
+                    interruptionDelta.y * interruptionProgress,
+                )
+        rule.waitForIdle()
+        rule
+            .onNode(isElement(TestElements.Foo, SceneC))
+            .assertPositionInRootIsEqualTo(
+                offsetInBToCWithInterruption.x,
+                offsetInBToCWithInterruption.y,
+            )
+
+        // Finish the transition and interruption.
+        bToCProgress = 1f
+        interruptionProgress = 0f
+        rule
+            .onNode(isElement(TestElements.Foo, SceneC))
+            .assertPositionInRootIsEqualTo(offsetInC.x, offsetInC.y)
+    }
 }
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/Transition.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/Transition.kt
index 767057b..c1218ae 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/Transition.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/Transition.kt
@@ -18,12 +18,17 @@
 
 import androidx.compose.foundation.gestures.Orientation
 import kotlinx.coroutines.Job
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.sync.Mutex
+import kotlinx.coroutines.sync.withLock
+import kotlinx.coroutines.test.TestScope
 
 /** A utility to easily create a [TransitionState.Transition] in tests. */
 fun transition(
     from: SceneKey,
     to: SceneKey,
     progress: () -> Float = { 0f },
+    interruptionProgress: () -> Float = { 100f },
     isInitiatedByUserInput: Boolean = false,
     isUserInputOngoing: Boolean = false,
     isUpOrLeft: Boolean = false,
@@ -55,5 +60,22 @@
 
             return onFinish(this)
         }
+
+        override fun interruptionProgress(layoutImpl: SceneTransitionLayoutImpl): Float {
+            return interruptionProgress()
+        }
+    }
+}
+
+/**
+ * Return a onFinish lambda that can be used with [transition] so that the transition never
+ * finishes. This allows to keep the transition in the current transitions list.
+ */
+fun TestScope.neverFinish(): (TransitionState.Transition) -> Job {
+    return {
+        backgroundScope.launch {
+            // Try to acquire a locked mutex so that this code never completes.
+            Mutex(locked = true).withLock {}
+        }
     }
 }