Merge changes I5b2ee38d,I5270eaf8 into main

* changes:
  Prevent elements from jump-cutting after an interruption
  Move Transition test utils to tests source folder
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt
index da07f6d..6b289f3 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt
@@ -190,4 +190,4 @@
 
 // TODO(b/290184746): Compute a good default visibility threshold that depends on the layout size
 // and screen density.
-private const val ProgressVisibilityThreshold = 1e-3f
+internal const val ProgressVisibilityThreshold = 1e-3f // Reused by interruptionProgress().
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
index 7d43ca8..4273b4f 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
@@ -24,6 +24,7 @@
 import androidx.compose.ui.ExperimentalComposeUiApi
 import androidx.compose.ui.Modifier
 import androidx.compose.ui.geometry.Offset
+import androidx.compose.ui.geometry.isSpecified
 import androidx.compose.ui.geometry.isUnspecified
 import androidx.compose.ui.geometry.lerp
 import androidx.compose.ui.graphics.CompositingStrategy
@@ -55,9 +56,15 @@
 internal class Element(val key: ElementKey) {
     /** The mapping between a scene and the state this element has in that scene, if any. */
     // TODO(b/316901148): Make this a normal map instead once we can make sure that new transitions
-    // are first seen by composition then layout/drawing code. See 316901148#comment2 for details.
+    // are first seen by composition then layout/drawing code. See b/316901148#comment2 for details.
     val sceneStates = SnapshotStateMap<SceneKey, SceneState>()
 
+    /**
+     * The last transition that was used when computing the state (size, position and alpha) of this
+     * element in any scene, or `null` if it was last laid out when idle.
+     */
+    var lastTransition: TransitionState.Transition? = null
+
     override fun toString(): String {
         return "Element(key=$key)"
     }
@@ -65,9 +72,33 @@
     /** The last and target state of this element in a given scene. */
     @Stable
     class SceneState(val scene: SceneKey) {
+        /**
+         * The *target* state of this element in this scene, i.e. the state of this element when we
+         * are idle on this scene.
+         */
         var targetSize by mutableStateOf(SizeUnspecified)
         var targetOffset by mutableStateOf(Offset.Unspecified)
 
+        /** The last state this element had in this scene. */
+        var lastOffset = Offset.Unspecified
+        var lastScale = Scale.Unspecified
+        var lastAlpha = AlphaUnspecified
+
+        /** The state of this element in this scene right before the last interruption (if any). */
+        var offsetBeforeInterruption = Offset.Unspecified
+        var scaleBeforeInterruption = Scale.Unspecified
+        var alphaBeforeInterruption = AlphaUnspecified
+
+        /**
+         * The delta values to add to this element state to have smoother interruptions. These
+         * should be multiplied by the
+         * [current interruption progress][TransitionState.Transition.interruptionProgress] so that
+         * they nicely animate from their values down to 0.
+         */
+        var offsetInterruptionDelta = Offset.Zero
+        var scaleInterruptionDelta = Scale.Zero
+        var alphaInterruptionDelta = 0f
+
         /**
          * The attached [ElementNode] a Modifier.element() for a given element and scene. During
          * composition, this set could have 0 to 2 elements. After composition and after all
@@ -78,12 +109,15 @@
 
     companion object {
         val SizeUnspecified = IntSize(Int.MAX_VALUE, Int.MAX_VALUE)
+        const val AlphaUnspecified = Float.MAX_VALUE // Sentinel: no alpha recorded.
     }
 }
 
 data class Scale(val scaleX: Float, val scaleY: Float, val pivot: Offset = Offset.Unspecified) {
     companion object {
         val Default = Scale(1f, 1f, Offset.Unspecified)
+        val Zero = Scale(0f, 0f, Offset.Zero) // The "no interruption delta" value.
+        val Unspecified = Scale(Float.MAX_VALUE, Float.MAX_VALUE, Offset.Unspecified)
     }
 }
 
@@ -212,6 +246,10 @@
         val isOtherSceneOverscrolling = overscrollScene != null && overscrollScene != scene.key
         val isNotPartOfAnyOngoingTransitions = transitions.isNotEmpty() && transition == null
         if (isNotPartOfAnyOngoingTransitions || isOtherSceneOverscrolling) {
+            sceneState.lastOffset = Offset.Unspecified
+            sceneState.lastScale = Scale.Unspecified
+            sceneState.lastAlpha = Element.AlphaUnspecified
+
             val placeable = measurable.measure(constraints)
             return layout(placeable.width, placeable.height) {}
         }
@@ -233,7 +271,7 @@
 
     override fun ContentDrawScope.draw() {
         val transition = elementTransition(element, layoutImpl.state.currentTransitions)
-        val drawScale = getDrawScale(layoutImpl, scene, element, transition)
+        val drawScale = getDrawScale(layoutImpl, scene, element, transition, sceneState)
         if (drawScale == Scale.Default) {
             drawContent()
         } else {
@@ -276,8 +314,116 @@
     element: Element,
     transitions: List<TransitionState.Transition>,
 ): TransitionState.Transition? {
-    return transitions.fastLastOrNull { transition ->
-        transition.fromScene in element.sceneStates || transition.toScene in element.sceneStates
+    val transition =
+        transitions.fastLastOrNull { transition ->
+            transition.fromScene in element.sceneStates || transition.toScene in element.sceneStates
+        }
+
+    val previousTransition = element.lastTransition
+    element.lastTransition = transition
+
+    if (transition != previousTransition && transition != null && previousTransition != null) {
+        // The previous transition was interrupted by another transition.
+        prepareInterruption(element)
+    }
+
+    if (transition == null && previousTransition != null) {
+        // The transition was just finished.
+        element.sceneStates.values.forEach { sceneState ->
+            sceneState.offsetInterruptionDelta = Offset.Zero
+            sceneState.scaleInterruptionDelta = Scale.Zero
+            sceneState.alphaInterruptionDelta = 0f
+        }
+    }
+
+    return transition
+}
+
+private fun prepareInterruption(element: Element) {
+    // Find the single scene in which this element was last placed, so that we can animate the
+    // delta between that state and its future state.
+    val sceneStates = element.sceneStates.values
+    var lastUniqueState: Element.SceneState? = null
+    for (sceneState in sceneStates) {
+        val offset = sceneState.lastOffset
+
+        // If the element was placed in this scene...
+        if (offset != Offset.Unspecified) {
+            // ... and it is the first (and potentially the only) scene where the element was
+            // placed, save the state for later.
+            if (lastUniqueState == null) {
+                lastUniqueState = sceneState
+            } else {
+                // The element was placed in multiple scenes: we abort the interruption for this
+                // element.
+                // TODO(b/290930950): Better support cases where a shared element animation is
+                // disabled and the same element is drawn/placed in multiple scenes at the same
+                // time.
+                lastUniqueState = null
+                break
+            }
+        }
+    }
+    // Without a unique state, these stay unspecified and no interruption delta is computed.
+    val lastOffset = lastUniqueState?.lastOffset ?: Offset.Unspecified
+    val lastScale = lastUniqueState?.lastScale ?: Scale.Unspecified
+    val lastAlpha = lastUniqueState?.lastAlpha ?: Element.AlphaUnspecified
+
+    // Store the state of the element before the interruption and reset the deltas.
+    sceneStates.forEach { sceneState ->
+        sceneState.offsetBeforeInterruption = lastOffset
+        sceneState.scaleBeforeInterruption = lastScale
+        sceneState.alphaBeforeInterruption = lastAlpha
+
+        sceneState.offsetInterruptionDelta = Offset.Zero
+        sceneState.scaleInterruptionDelta = Scale.Zero
+        sceneState.alphaInterruptionDelta = 0f
+    }
+}
+
+/**
+ * Return [value] plus the stored interruption delta scaled by the
+ * [interruption progress][TransitionState.Transition.interruptionProgress] of [transition]. The
+ * delta is computed and stored the first time this is called after an interruption.
+ */
+private inline fun <T> computeInterruptedValue(
+    layoutImpl: SceneTransitionLayoutImpl,
+    transition: TransitionState.Transition?,
+    value: T,
+    unspecifiedValue: T,
+    zeroValue: T,
+    getValueBeforeInterruption: () -> T,
+    setValueBeforeInterruption: (T) -> Unit,
+    getInterruptionDelta: () -> T,
+    setInterruptionDelta: (T) -> Unit,
+    diff: (a: T, b: T) -> T, // a - b
+    add: (a: T, b: T, bProgress: Float) -> T, // a + (b * bProgress)
+): T {
+    val valueBeforeInterruption = getValueBeforeInterruption()
+
+    // A specified value before the interruption means that this is the first time we compute
+    // [value] since the interruption was prepared.
+    if (valueBeforeInterruption != unspecifiedValue) {
+        // Compute and store the delta between the value before the interruption and the current
+        // value.
+        setInterruptionDelta(diff(valueBeforeInterruption, value))
+
+        // Consume the value before the interruption so that the delta is only computed once.
+        setValueBeforeInterruption(unspecifiedValue)
+    }
+
+    val delta = getInterruptionDelta()
+    return if (delta == zeroValue || transition == null) {
+        // There was no interruption or there is no transition: just return the value.
+        value
+    } else {
+        // Add `delta * interruptionProgress` so that the result smoothly animates towards [value].
+        val interruptionProgress = transition.interruptionProgress(layoutImpl)
+        if (interruptionProgress == 0f) {
+            value
+        } else {
+            add(value, delta, interruptionProgress)
+        }
     }
 }
 
@@ -417,20 +563,47 @@
     scene: Scene,
     element: Element,
     transition: TransitionState.Transition?,
+    sceneState: Element.SceneState,
 ): Float {
-    return computeValue(
-            layoutImpl,
-            scene,
-            element,
-            transition,
-            sceneValue = { 1f },
-            transformation = { it.alpha },
-            idleValue = 1f,
-            currentValue = { 1f },
-            isSpecified = { true },
-            ::lerp,
-        )
-        .fastCoerceIn(0f, 1f)
+    val alpha =
+        computeValue(
+                layoutImpl,
+                scene,
+                element,
+                transition,
+                sceneValue = { 1f },
+                transformation = { it.alpha },
+                idleValue = 1f,
+                currentValue = { 1f },
+                isSpecified = { true },
+                ::lerp,
+            )
+            .fastCoerceIn(0f, 1f)
+
+    val interruptedAlpha = interruptedAlpha(layoutImpl, transition, sceneState, alpha)
+    sceneState.lastAlpha = interruptedAlpha
+    return interruptedAlpha
+}
+
+private fun interruptedAlpha(
+    layoutImpl: SceneTransitionLayoutImpl,
+    transition: TransitionState.Transition?,
+    sceneState: Element.SceneState,
+    alpha: Float, // The alpha before applying the interruption delta.
+): Float {
+    return computeInterruptedValue( // May write sceneState's alpha interruption state.
+        layoutImpl,
+        transition,
+        value = alpha,
+        unspecifiedValue = Element.AlphaUnspecified,
+        zeroValue = 0f,
+        getValueBeforeInterruption = { sceneState.alphaBeforeInterruption },
+        setValueBeforeInterruption = { sceneState.alphaBeforeInterruption = it },
+        getInterruptionDelta = { sceneState.alphaInterruptionDelta },
+        setInterruptionDelta = { sceneState.alphaInterruptionDelta = it },
+        diff = { a, b -> a - b },
+        add = { a, b, bProgress -> a + b * bProgress },
+    )
 }
 
 @OptIn(ExperimentalComposeUiApi::class)
@@ -480,24 +653,70 @@
         )
 }
 
-private fun getDrawScale(
+private fun ContentDrawScope.getDrawScale(
     layoutImpl: SceneTransitionLayoutImpl,
     scene: Scene,
     element: Element,
     transition: TransitionState.Transition?,
+    sceneState: Element.SceneState,
 ): Scale {
-    return computeValue(
-        layoutImpl,
-        scene,
-        element,
-        transition,
-        sceneValue = { Scale.Default },
-        transformation = { it.drawScale },
-        idleValue = Scale.Default,
-        currentValue = { Scale.Default },
-        isSpecified = { true },
-        ::lerp,
-    )
+    val scale =
+        computeValue(
+            layoutImpl,
+            scene,
+            element,
+            transition,
+            sceneValue = { Scale.Default },
+            transformation = { it.drawScale },
+            idleValue = Scale.Default,
+            currentValue = { Scale.Default },
+            isSpecified = { true },
+            ::lerp,
+        )
+    // An unspecified pivot stands for the center of the current drawing area.
+    fun Offset.specifiedOrCenter(): Offset {
+        return this.takeIf { isSpecified } ?: center
+    }
+    // Smooth out interruptions. An unspecified delta pivot means that the pivot did not change.
+    val interruptedScale =
+        computeInterruptedValue(
+            layoutImpl,
+            transition,
+            value = scale,
+            unspecifiedValue = Scale.Unspecified,
+            zeroValue = Scale.Zero,
+            getValueBeforeInterruption = { sceneState.scaleBeforeInterruption },
+            setValueBeforeInterruption = { sceneState.scaleBeforeInterruption = it },
+            getInterruptionDelta = { sceneState.scaleInterruptionDelta },
+            setInterruptionDelta = { sceneState.scaleInterruptionDelta = it },
+            diff = { a, b ->
+                Scale(
+                    scaleX = a.scaleX - b.scaleX,
+                    scaleY = a.scaleY - b.scaleY,
+                    pivot =
+                        if (a.pivot.isUnspecified && b.pivot.isUnspecified) {
+                            Offset.Unspecified // Both pivots default to the center.
+                        } else {
+                            a.pivot.specifiedOrCenter() - b.pivot.specifiedOrCenter()
+                        }
+                )
+            },
+            add = { a, b, bProgress ->
+                Scale(
+                    scaleX = a.scaleX + b.scaleX * bProgress,
+                    scaleY = a.scaleY + b.scaleY * bProgress,
+                    pivot =
+                        if (b.pivot.isUnspecified) {
+                            a.pivot // An unspecified delta pivot means no pivot delta.
+                        } else {
+                            a.pivot.specifiedOrCenter() + b.pivot * bProgress
+                        }
+                )
+            }
+        )
+
+    sceneState.lastScale = interruptedScale
+    return interruptedScale
 }
 
 @OptIn(ExperimentalComposeUiApi::class)
@@ -524,6 +743,8 @@
 
         // No need to place the element in this scene if we don't want to draw it anyways.
         if (!shouldPlaceElement(layoutImpl, scene, element, transition)) {
+            sceneState.lastOffset = Offset.Unspecified
+            sceneState.offsetBeforeInterruption = Offset.Unspecified
             return
         }
 
@@ -542,15 +763,37 @@
                 ::lerp,
             )
 
-        val offset = (targetOffset - currentOffset).round()
-        if (isElementOpaque(scene, element, transition)) {
+        val interruptedOffset =
+            computeInterruptedValue(
+                layoutImpl,
+                transition,
+                value = targetOffset,
+                unspecifiedValue = Offset.Unspecified,
+                zeroValue = Offset.Zero,
+                getValueBeforeInterruption = { sceneState.offsetBeforeInterruption },
+                setValueBeforeInterruption = { sceneState.offsetBeforeInterruption = it },
+                getInterruptionDelta = { sceneState.offsetInterruptionDelta },
+                setInterruptionDelta = { sceneState.offsetInterruptionDelta = it },
+                diff = { a, b -> a - b },
+                add = { a, b, bProgress -> a + b * bProgress },
+            )
+
+        sceneState.lastOffset = interruptedOffset
+
+        val offset = (interruptedOffset - currentOffset).round()
+        if (
+            isElementOpaque(scene, element, transition) &&
+                interruptedAlpha(layoutImpl, transition, sceneState, alpha = 1f) == 1f
+        ) {
+            sceneState.lastAlpha = 1f
+
             // TODO(b/291071158): Call placeWithLayer() if offset != IntOffset.Zero and size is not
             // animated once b/305195729 is fixed. Test that drawing is not invalidated in that
             // case.
             placeable.place(offset)
         } else {
             placeable.placeWithLayer(offset) {
-                alpha = elementAlpha(layoutImpl, scene, element, transition)
+                alpha = elementAlpha(layoutImpl, scene, element, transition, sceneState)
                 compositingStrategy = CompositingStrategy.ModulateAlpha
             }
         }
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
index 20dcc20..ad691ba 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
@@ -49,7 +49,7 @@
     internal var swipeSourceDetector: SwipeSourceDetector,
     internal var transitionInterceptionThreshold: Float,
     builder: SceneTransitionLayoutScope.() -> Unit,
-    private val coroutineScope: CoroutineScope,
+    internal val coroutineScope: CoroutineScope,
 ) {
     /**
      * The map of [Scene]s.
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
index f13c016..5fda77a 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
@@ -18,6 +18,9 @@
 
 import android.util.Log
 import androidx.annotation.VisibleForTesting
+import androidx.compose.animation.core.Animatable
+import androidx.compose.animation.core.AnimationVector1D
+import androidx.compose.animation.core.spring
 import androidx.compose.foundation.gestures.Orientation
 import androidx.compose.runtime.Composable
 import androidx.compose.runtime.LaunchedEffect
@@ -34,6 +37,7 @@
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.Job
 import kotlinx.coroutines.channels.Channel
+import kotlinx.coroutines.launch
 
 /**
  * The state of a [SceneTransitionLayout].
@@ -253,6 +257,12 @@
                 }
             }
 
+        /**
+         * Animates from 1f to 0f to smooth out the sudden jump of values when this transition
+         * interrupts another one. Lazily created by [interruptionProgress].
+         */
+        private var interruptionDecay: Animatable<Float, AnimationVector1D>? = null
+
         init {
             check(fromScene != toScene)
         }
@@ -289,6 +299,33 @@
             fromOverscrollSpec = fromSpec
             toOverscrollSpec = toSpec
         }
+
+        internal open fun interruptionProgress(
+            layoutImpl: SceneTransitionLayoutImpl,
+        ): Float {
+            if (!layoutImpl.state.enableInterruptions) {
+                return 0f // Interruptions disabled: no smoothing.
+            }
+            // Lazily create the decay (which launches its animation) on the first call.
+            fun create(): Animatable<Float, AnimationVector1D> {
+                val animatable = Animatable(1f, visibilityThreshold = ProgressVisibilityThreshold)
+                layoutImpl.coroutineScope.launch {
+                    val swipeSpec = layoutImpl.state.transitions.defaultSwipeSpec
+                    val progressSpec =
+                        spring(
+                            stiffness = swipeSpec.stiffness,
+                            dampingRatio = swipeSpec.dampingRatio,
+                            visibilityThreshold = ProgressVisibilityThreshold,
+                        )
+                    animatable.animateTo(0f, progressSpec) // Decay from 1f down to 0f.
+                }
+
+                return animatable
+            }
+
+            val animatable = interruptionDecay ?: create().also { interruptionDecay = it }
+            return animatable.value
+        }
     }
 
     interface HasOverscrollProperties {
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
index b7fc91c..b1d7055 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
@@ -37,6 +37,7 @@
 import androidx.compose.runtime.LaunchedEffect
 import androidx.compose.runtime.SideEffect
 import androidx.compose.runtime.getValue
+import androidx.compose.runtime.mutableFloatStateOf
 import androidx.compose.runtime.mutableStateOf
 import androidx.compose.runtime.rememberCoroutineScope
 import androidx.compose.runtime.setValue
@@ -55,7 +56,10 @@
 import androidx.compose.ui.test.onRoot
 import androidx.compose.ui.test.performTouchInput
 import androidx.compose.ui.unit.Dp
+import androidx.compose.ui.unit.DpOffset
+import androidx.compose.ui.unit.DpSize
 import androidx.compose.ui.unit.dp
+import androidx.compose.ui.unit.lerp
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import com.android.compose.animation.scene.TestScenes.SceneA
 import com.android.compose.animation.scene.TestScenes.SceneB
@@ -1019,4 +1023,122 @@
         rule.onNode(isElement(TestElements.Foo)).assertDoesNotExist()
         rule.onNode(isElement(TestElements.Bar)).assertPositionInRootIsEqualTo(100.dp, 100.dp)
     }
+
+    @Test
+    fun interruption() = runTest {
+        // 4 frames of animation.
+        val duration = 4 * 16
+
+        val state =
+            MutableSceneTransitionLayoutStateImpl(
+                SceneA,
+                transitions {
+                    from(SceneA, to = SceneB) { spec = tween(duration, easing = LinearEasing) }
+                    from(SceneB, to = SceneC) { spec = tween(duration, easing = LinearEasing) }
+                },
+                enableInterruptions = false,
+            )
+
+        val layoutSize = DpSize(200.dp, 100.dp)
+        val fooSize = DpSize(20.dp, 10.dp)
+
+        @Composable
+        fun SceneScope.Foo(modifier: Modifier = Modifier) {
+            Box(modifier.element(TestElements.Foo).size(fooSize))
+        }
+
+        rule.setContent {
+            SceneTransitionLayout(state, Modifier.size(layoutSize)) {
+                // In scene A, Foo is aligned at the TopStart.
+                scene(SceneA) {
+                    Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.TopStart)) }
+                }
+
+                // In scene B, Foo is aligned at the TopEnd, so it moves horizontally when coming
+                // from A.
+                scene(SceneB) {
+                    Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.TopEnd)) }
+                }
+
+                // In scene C, Foo is aligned at the BottomEnd, so it moves vertically when coming
+                // from B.
+                scene(SceneC) {
+                    Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.BottomEnd)) }
+                }
+            }
+        }
+
+        // The offset of Foo when idle in A, B or C.
+        val offsetInA = DpOffset.Zero
+        val offsetInB = DpOffset(layoutSize.width - fooSize.width, 0.dp)
+        val offsetInC =
+            DpOffset(layoutSize.width - fooSize.width, layoutSize.height - fooSize.height)
+
+        // Initial state (idle in A).
+        rule
+            .onNode(isElement(TestElements.Foo, SceneA))
+            .assertPositionInRootIsEqualTo(offsetInA.x, offsetInA.y)
+
+        // Current transition is A => B at 50%.
+        val aToBProgress = 0.5f
+        val aToB =
+            transition(
+                from = SceneA,
+                to = SceneB,
+                progress = { aToBProgress },
+                onFinish = neverFinish(),
+            )
+        val offsetInAToB = lerp(offsetInA, offsetInB, aToBProgress)
+        rule.runOnUiThread { state.startTransition(aToB, transitionKey = null) }
+        rule
+            .onNode(isElement(TestElements.Foo, SceneB))
+            .assertPositionInRootIsEqualTo(offsetInAToB.x, offsetInAToB.y)
+
+        // Start B => C at 0%, which interrupts A => B.
+        var bToCProgress by mutableFloatStateOf(0f)
+        var interruptionProgress by mutableFloatStateOf(1f)
+        val bToC =
+            transition(
+                from = SceneB,
+                to = SceneC,
+                progress = { bToCProgress },
+                interruptionProgress = { interruptionProgress },
+            )
+        rule.runOnUiThread { state.startTransition(bToC, transitionKey = null) }
+
+        // The offset interruption delta, which will be multiplied by the interruption progress then
+        // added to the current transition offset.
+        val interruptionDelta = offsetInAToB - offsetInB
+
+        // Interruption progress is at 100% and bToC is at 0%, so Foo should be at the same offset
+        // as right before the interruption.
+        rule
+            .onNode(isElement(TestElements.Foo, SceneC))
+            .assertPositionInRootIsEqualTo(offsetInAToB.x, offsetInAToB.y)
+
+        // Move the transition forward to 30% and set the interruption progress to 50%.
+        bToCProgress = 0.3f
+        interruptionProgress = 0.5f
+        val offsetInBToC = lerp(offsetInB, offsetInC, bToCProgress)
+        val offsetInBToCWithInterruption =
+            offsetInBToC +
+                DpOffset(
+                    interruptionDelta.x * interruptionProgress,
+                    interruptionDelta.y * interruptionProgress,
+                )
+        rule.waitForIdle()
+        rule
+            .onNode(isElement(TestElements.Foo, SceneC))
+            .assertPositionInRootIsEqualTo(
+                offsetInBToCWithInterruption.x,
+                offsetInBToCWithInterruption.y,
+            )
+
+        // Finish the transition and the interruption: Foo should now be at its offset in C.
+        bToCProgress = 1f
+        interruptionProgress = 0f
+        rule
+            .onNode(isElement(TestElements.Foo, SceneC))
+            .assertPositionInRootIsEqualTo(offsetInC.x, offsetInC.y)
+    }
 }
diff --git a/packages/SystemUI/compose/scene/tests/utils/src/com/android/compose/animation/scene/Transition.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/Transition.kt
similarity index 73%
rename from packages/SystemUI/compose/scene/tests/utils/src/com/android/compose/animation/scene/Transition.kt
rename to packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/Transition.kt
index 767057b..c1218ae 100644
--- a/packages/SystemUI/compose/scene/tests/utils/src/com/android/compose/animation/scene/Transition.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/Transition.kt
@@ -18,12 +18,17 @@
 
 import androidx.compose.foundation.gestures.Orientation
 import kotlinx.coroutines.Job
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.sync.Mutex
+import kotlinx.coroutines.sync.withLock
+import kotlinx.coroutines.test.TestScope
 
 /** A utility to easily create a [TransitionState.Transition] in tests. */
 fun transition(
     from: SceneKey,
     to: SceneKey,
     progress: () -> Float = { 0f },
+    interruptionProgress: () -> Float = { 0f }, // In [0, 1]; 0f means no interruption effect.
     isInitiatedByUserInput: Boolean = false,
     isUserInputOngoing: Boolean = false,
     isUpOrLeft: Boolean = false,
@@ -55,5 +60,22 @@
 
             return onFinish(this)
         }
+
+        override fun interruptionProgress(layoutImpl: SceneTransitionLayoutImpl): Float {
+            return interruptionProgress() // Bypasses the enableInterruptions check.
+        }
+    }
+}
+
+/**
+ * Return an onFinish lambda that can be used with [transition] so that the transition never
+ * finishes. This keeps the transition in the current transitions list.
+ */
+fun TestScope.neverFinish(): (TransitionState.Transition) -> Job {
+    return {
+        backgroundScope.launch {
+            // Wait on a mutex that is never unlocked, so that this job never completes.
+            Mutex(locked = true).withLock {}
+        }
     }
 }