Merge changes I7498f56a,I833d77d5,Ia822d4a2,Ia928cc43,Ie02f4ed0 into main

* changes:
  Prevent size jumps during interruptions
  Revert local change
  Remove calls to invokeOnCompletion
  Enforce that STLState is mutated on the right thread
  Read transitions during composition instead of layout/drawing
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt
index b5e9313..48a348b 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt
@@ -21,6 +21,7 @@
 import androidx.compose.animation.core.SpringSpec
 import kotlin.math.absoluteValue
 import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.CoroutineStart
 import kotlinx.coroutines.Job
 import kotlinx.coroutines.launch
 
@@ -190,16 +191,17 @@
         }
 
     // Animate the progress to its target value.
+    // Important: We start atomically to make sure that we start the coroutine even if it is
+    // cancelled right after it is launched, so that finishTransition() is correctly called.
+    // Otherwise, this transition will never be stopped and we will never settle to Idle.
     transition.job =
-        launch { animatable.animateTo(targetProgress, animationSpec, initialVelocity) }
-            .apply {
-                invokeOnCompletion {
-                    // Settle the state to Idle(target). Note that this will do nothing if this
-                    // transition was replaced/interrupted by another one, and this also runs if
-                    // this coroutine is cancelled, i.e. if [this] coroutine scope is cancelled.
-                    layoutState.finishTransition(transition, targetScene)
-                }
+        launch(start = CoroutineStart.ATOMIC) {
+            try {
+                animatable.animateTo(targetProgress, animationSpec, initialVelocity)
+            } finally {
+                layoutState.finishTransition(transition, targetScene)
             }
+        }
 
     return transition
 }
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/DraggableHandler.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/DraggableHandler.kt
index 6758990..1f81245 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/DraggableHandler.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/DraggableHandler.kt
@@ -33,6 +33,7 @@
 import com.android.compose.nestedscroll.PriorityNestedScrollConnection
 import kotlin.math.absoluteValue
 import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.CoroutineStart
 import kotlinx.coroutines.Job
 import kotlinx.coroutines.launch
 
@@ -684,7 +685,11 @@
             val isTargetGreater = targetOffset > animatable.value
             val job =
                 coroutineScope
-                    .launch {
+                    // Important: We start atomically to make sure that we start the coroutine even
+                    // if it is cancelled right after it is launched, so that snapToScene() is
+                    // correctly called. Otherwise, this transition will never be stopped and we
+                    // will never settle to Idle.
+                    .launch(start = CoroutineStart.ATOMIC) {
                         // TODO(b/327249191): Refactor the code so that we don't even launch a
                         // coroutine if we don't need to animate.
                         if (skipAnimation) {
@@ -726,18 +731,15 @@
                             }
                         } finally {
                             bouncingScene = null
+                            snapToScene(targetScene)
                         }
                     }
-                    // Make sure that we settle to target scene at the end of the animation or if
-                    // the animation is cancelled.
-                    .apply { invokeOnCompletion { snapToScene(targetScene) } }
 
             OffsetAnimation(animatable, job)
         }
     }
 
     fun snapToScene(scene: SceneKey) {
-        if (layoutState.transitionState != this) return
         cancelOffsetAnimation()
         layoutState.finishTransition(this, idleScene = scene)
     }
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
index 20742ee..18baee9 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
@@ -49,6 +49,7 @@
 import com.android.compose.animation.scene.transformation.PropertyTransformation
 import com.android.compose.animation.scene.transformation.SharedElementTransformation
 import com.android.compose.ui.util.lerp
+import kotlin.math.roundToInt
 import kotlinx.coroutines.launch
 
 /** An element on screen, that can be composed in one or more scenes. */
@@ -81,11 +82,13 @@
 
         /** The last state this element had in this scene. */
         var lastOffset = Offset.Unspecified
+        var lastSize = SizeUnspecified
         var lastScale = Scale.Unspecified
         var lastAlpha = AlphaUnspecified
 
         /** The state of this element in this scene right before the last interruption (if any). */
         var offsetBeforeInterruption = Offset.Unspecified
+        var sizeBeforeInterruption = SizeUnspecified
         var scaleBeforeInterruption = Scale.Unspecified
         var alphaBeforeInterruption = AlphaUnspecified
 
@@ -96,6 +99,7 @@
          * they nicely animate from their values down to 0.
          */
         var offsetInterruptionDelta = Offset.Zero
+        var sizeInterruptionDelta = IntSize.Zero
         var scaleInterruptionDelta = Scale.Zero
         var alphaInterruptionDelta = 0f
 
@@ -127,7 +131,14 @@
     layoutImpl: SceneTransitionLayoutImpl,
     scene: Scene,
     key: ElementKey,
-): Modifier = this.then(ElementModifier(layoutImpl, scene, key)).testTag(key.testTag)
+): Modifier {
+    // Make sure that we read the current transitions during composition and not during
+    // layout/drawing.
+    // TODO(b/341072461): Revert this and read the current transitions in ElementNode directly once
+    // we can ensure that SceneTransitionLayoutImpl will compose new scenes first.
+    val currentTransitions = layoutImpl.state.currentTransitions
+    return then(ElementModifier(layoutImpl, currentTransitions, scene, key)).testTag(key.testTag)
+}
 
 /**
  * An element associated to [ElementNode]. Note that this element does not support updates as its
@@ -135,18 +146,20 @@
  */
 private data class ElementModifier(
     private val layoutImpl: SceneTransitionLayoutImpl,
+    private val currentTransitions: List<TransitionState.Transition>,
     private val scene: Scene,
     private val key: ElementKey,
 ) : ModifierNodeElement<ElementNode>() {
-    override fun create(): ElementNode = ElementNode(layoutImpl, scene, key)
+    override fun create(): ElementNode = ElementNode(layoutImpl, currentTransitions, scene, key)
 
     override fun update(node: ElementNode) {
-        node.update(layoutImpl, scene, key)
+        node.update(layoutImpl, currentTransitions, scene, key)
     }
 }
 
 internal class ElementNode(
     private var layoutImpl: SceneTransitionLayoutImpl,
+    private var currentTransitions: List<TransitionState.Transition>,
     private var scene: Scene,
     private var key: ElementKey,
 ) : Modifier.Node(), DrawModifierNode, ApproachLayoutModifierNode {
@@ -202,10 +215,13 @@
 
     fun update(
         layoutImpl: SceneTransitionLayoutImpl,
+        currentTransitions: List<TransitionState.Transition>,
         scene: Scene,
         key: ElementKey,
     ) {
         check(layoutImpl == this.layoutImpl && scene == this.scene)
+        this.currentTransitions = currentTransitions
+
         removeNodeFromSceneState()
 
         val prevElement = this.element
@@ -236,7 +252,7 @@
         measurable: Measurable,
         constraints: Constraints,
     ): MeasureResult {
-        val transitions = layoutImpl.state.currentTransitions
+        val transitions = currentTransitions
         val transition = elementTransition(element, transitions)
 
         // If this element is not supposed to be laid out now, either because it is not part of any
@@ -251,11 +267,13 @@
             sceneState.lastAlpha = Element.AlphaUnspecified
 
             val placeable = measurable.measure(constraints)
+            sceneState.lastSize = placeable.size()
             return layout(placeable.width, placeable.height) {}
         }
 
         val placeable =
             measure(layoutImpl, scene, element, transition, sceneState, measurable, constraints)
+        sceneState.lastSize = placeable.size()
         return layout(placeable.width, placeable.height) {
             place(
                 layoutImpl,
@@ -270,7 +288,7 @@
     }
 
     override fun ContentDrawScope.draw() {
-        val transition = elementTransition(element, layoutImpl.state.currentTransitions)
+        val transition = elementTransition(element, currentTransitions)
         val drawScale = getDrawScale(layoutImpl, scene, element, transition, sceneState)
         if (drawScale == Scale.Default) {
             drawContent()
@@ -365,12 +383,14 @@
     }
 
     val lastOffset = lastUniqueState?.lastOffset ?: Offset.Unspecified
+    val lastSize = lastUniqueState?.lastSize ?: Element.SizeUnspecified
     val lastScale = lastUniqueState?.lastScale ?: Scale.Unspecified
     val lastAlpha = lastUniqueState?.lastAlpha ?: Element.AlphaUnspecified
 
     // Store the state of the element before the interruption and reset the deltas.
     sceneStates.forEach { sceneState ->
         sceneState.offsetBeforeInterruption = lastOffset
+        sceneState.sizeBeforeInterruption = lastSize
         sceneState.scaleBeforeInterruption = lastScale
         sceneState.alphaBeforeInterruption = lastAlpha
 
@@ -380,6 +400,7 @@
 
 private fun Element.SceneState.clearInterruptionDeltas() {
     offsetInterruptionDelta = Offset.Zero
+    sizeInterruptionDelta = IntSize.Zero
     scaleInterruptionDelta = Scale.Zero
     alphaInterruptionDelta = 0f
 }
@@ -615,7 +636,6 @@
     )
 }
 
-@OptIn(ExperimentalComposeUiApi::class)
 private fun ApproachMeasureScope.measure(
     layoutImpl: SceneTransitionLayoutImpl,
     scene: Scene,
@@ -637,8 +657,6 @@
     // once.
     var maybePlaceable: Placeable? = null
 
-    fun Placeable.size() = IntSize(width, height)
-
     val targetSize =
         computeValue(
             layoutImpl,
@@ -653,15 +671,44 @@
             ::lerp,
         )
 
-    return maybePlaceable
-        ?: measurable.measure(
-            Constraints.fixed(
-                targetSize.width.coerceAtLeast(0),
-                targetSize.height.coerceAtLeast(0),
-            )
+    // The measurable was already measured, so we can't take interruptions into account here given
+    // that we are not allowed to measure the same measurable twice.
+    maybePlaceable?.let { placeable ->
+        sceneState.sizeBeforeInterruption = Element.SizeUnspecified
+        sceneState.sizeInterruptionDelta = IntSize.Zero
+        return placeable
+    }
+
+    val interruptedSize =
+        computeInterruptedValue(
+            layoutImpl,
+            transition,
+            value = targetSize,
+            unspecifiedValue = Element.SizeUnspecified,
+            zeroValue = IntSize.Zero,
+            getValueBeforeInterruption = { sceneState.sizeBeforeInterruption },
+            setValueBeforeInterruption = { sceneState.sizeBeforeInterruption = it },
+            getInterruptionDelta = { sceneState.sizeInterruptionDelta },
+            setInterruptionDelta = { sceneState.sizeInterruptionDelta = it },
+            diff = { a, b -> IntSize(a.width - b.width, a.height - b.height) },
+            add = { a, b, bProgress ->
+                IntSize(
+                    (a.width + b.width * bProgress).roundToInt(),
+                    (a.height + b.height * bProgress).roundToInt(),
+                )
+            },
         )
+
+    return measurable.measure(
+        Constraints.fixed(
+            interruptedSize.width.coerceAtLeast(0),
+            interruptedSize.height.coerceAtLeast(0),
+        )
+    )
 }
 
+private fun Placeable.size(): IntSize = IntSize(width, height)
+
 private fun ContentDrawScope.getDrawScale(
     layoutImpl: SceneTransitionLayoutImpl,
     scene: Scene,
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
index d383cec..7856498 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
@@ -126,6 +126,10 @@
                 orientation = Orientation.Vertical,
                 coroutineScope = coroutineScope,
             )
+
+        // Make sure that this STLImpl is created on the same thread as the state (most probably
+        // the main thread), given that the state must only be mutated on that thread.
+        state.checkThread()
     }
 
     internal fun draggableHandler(orientation: Orientation): DraggableHandlerImpl =
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
index 4e3a032..a5b6d24 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
@@ -26,8 +26,10 @@
 import androidx.compose.runtime.LaunchedEffect
 import androidx.compose.runtime.SideEffect
 import androidx.compose.runtime.Stable
+import androidx.compose.runtime.getValue
+import androidx.compose.runtime.mutableStateOf
 import androidx.compose.runtime.remember
-import androidx.compose.runtime.snapshots.SnapshotStateList
+import androidx.compose.runtime.setValue
 import androidx.compose.ui.util.fastAll
 import androidx.compose.ui.util.fastFilter
 import androidx.compose.ui.util.fastForEach
@@ -374,14 +376,17 @@
     // TODO(b/290930950): Remove this flag.
     internal var enableInterruptions: Boolean,
 ) : SceneTransitionLayoutState {
+    private val creationThread: Thread = Thread.currentThread()
+
     /**
      * The current [TransitionState]. This list will either be:
      * 1. A list with a single [TransitionState.Idle] element, when we are idle.
      * 2. A list with one or more [TransitionState.Transition], when we are transitioning.
      */
     @VisibleForTesting
-    internal val transitionStates: MutableList<TransitionState> =
-        SnapshotStateList<TransitionState>().apply { add(TransitionState.Idle(initialScene)) }
+    internal var transitionStates: List<TransitionState> by
+        mutableStateOf(listOf(TransitionState.Idle(initialScene)))
+        private set
 
     override val transitionState: TransitionState
         get() = transitionStates.last()
@@ -417,6 +422,20 @@
      */
     internal abstract fun CoroutineScope.onChangeScene(scene: SceneKey)
 
+    internal fun checkThread() {
+        val current = Thread.currentThread()
+        if (current !== creationThread) {
+            error(
+                """
+                    Only the original thread that created a SceneTransitionLayoutState can mutate it
+                      Expected: ${creationThread.name}
+                      Current: ${current.name}
+                """
+                    .trimIndent()
+            )
+        }
+    }
+
     override fun isTransitioning(from: SceneKey?, to: SceneKey?): Boolean {
         val transition = currentTransition ?: return false
         return transition.isTransitioning(from, to)
@@ -441,6 +460,8 @@
         transitionKey: TransitionKey?,
         chain: Boolean = true,
     ) {
+        checkThread()
+
         // Compute the [TransformationSpec] when the transition starts.
         val fromScene = transition.fromScene
         val toScene = transition.toScene
@@ -465,7 +486,7 @@
         if (!enableInterruptions) {
             // Set the current transition.
             check(transitionStates.size == 1)
-            transitionStates[0] = transition
+            transitionStates = listOf(transition)
             return
         }
 
@@ -473,14 +494,12 @@
             is TransitionState.Idle -> {
                 // Replace [Idle] by [transition].
                 check(transitionStates.size == 1)
-                transitionStates[0] = transition
+                transitionStates = listOf(transition)
             }
             is TransitionState.Transition -> {
-                // Force the current transition to finish to currentScene.
-                currentState.finish().invokeOnCompletion {
-                    // Make sure [finishTransition] is called at the end of the transition.
-                    finishTransition(currentState, currentState.currentScene)
-                }
+                // Force the current transition to finish to currentScene. The transition will call
+                // [finishTransition] once it's finished.
+                currentState.finish()
 
                 val tooManyTransitions = transitionStates.size >= MAX_CONCURRENT_TRANSITIONS
                 val clearCurrentTransitions = !chain || tooManyTransitions
@@ -497,11 +516,11 @@
                     // we end up only with the new transition after appending it.
                     check(transitionStates.size == 1)
                     check(transitionStates[0] is TransitionState.Idle)
-                    transitionStates.clear()
+                    transitionStates = listOf(transition)
+                } else {
+                    // Append the new transition.
+                    transitionStates = transitionStates + transition
                 }
-
-                // Append the new transition.
-                transitionStates.add(transition)
             }
         }
     }
@@ -561,6 +580,8 @@
      * nothing if [transition] was interrupted since it was started.
      */
     internal fun finishTransition(transition: TransitionState.Transition, idleScene: SceneKey) {
+        checkThread()
+
         val existingIdleScene = finishedTransitions[transition]
         if (existingIdleScene != null) {
             // This transition was already finished.
@@ -571,6 +592,7 @@
             return
         }
 
+        val transitionStates = this.transitionStates
         if (!transitionStates.contains(transition)) {
             // This transition was already removed from transitionStates.
             return
@@ -589,25 +611,42 @@
         var lastRemovedIdleScene: SceneKey? = null
 
         // Remove all first n finished transitions.
-        while (transitionStates.isNotEmpty()) {
-            val firstTransition = transitionStates[0]
-            if (!finishedTransitions.contains(firstTransition)) {
+        var i = 0
+        val nStates = transitionStates.size
+        while (i < nStates) {
+            val t = transitionStates[i]
+            if (!finishedTransitions.contains(t)) {
                 // Stop here.
                 break
             }
 
-            // Remove the transition from the list and from the set of finished transitions.
-            transitionStates.removeAt(0)
-            lastRemovedIdleScene = finishedTransitions.remove(firstTransition)
+            // Remove the transition from the set of finished transitions.
+            lastRemovedIdleScene = finishedTransitions.remove(t)
+            i++
         }
 
         // If all transitions are finished, we are idle.
-        if (transitionStates.isEmpty()) {
+        if (i == nStates) {
             check(finishedTransitions.isEmpty())
-            transitionStates.add(TransitionState.Idle(checkNotNull(lastRemovedIdleScene)))
+            this.transitionStates = listOf(TransitionState.Idle(checkNotNull(lastRemovedIdleScene)))
+        } else if (i > 0) {
+            this.transitionStates = transitionStates.subList(fromIndex = i, toIndex = nStates)
         }
     }
 
+    fun snapToScene(scene: SceneKey) {
+        checkThread()
+
+        // Force finish all transitions.
+        while (currentTransitions.isNotEmpty()) {
+            val transition = transitionStates[0] as TransitionState.Transition
+            finishTransition(transition, transition.currentScene)
+        }
+
+        check(transitionStates.size == 1)
+        transitionStates = listOf(TransitionState.Idle(scene))
+    }
+
     private fun finishActiveTransitionLinks(idleScene: SceneKey) {
         val previousTransition = this.transitionState as? TransitionState.Transition ?: return
         for ((link, linkedTransition) in activeTransitionLinks) {
@@ -736,6 +775,8 @@
         coroutineScope: CoroutineScope,
         transitionKey: TransitionKey?,
     ): TransitionState.Transition? {
+        checkThread()
+
         return coroutineScope.animateToScene(
             layoutState = this@MutableSceneTransitionLayoutStateImpl,
             target = targetScene,
@@ -748,17 +789,6 @@
     override fun CoroutineScope.onChangeScene(scene: SceneKey) {
         setTargetScene(scene, coroutineScope = this)
     }
-
-    override fun snapToScene(scene: SceneKey) {
-        // Force finish all transitions.
-        while (currentTransitions.isNotEmpty()) {
-            val transition = transitionStates[0] as TransitionState.Transition
-            finishTransition(transition, transition.currentScene)
-        }
-
-        check(transitionStates.size == 1)
-        transitionStates[0] = TransitionState.Idle(scene)
-    }
 }
 
 private const val TAG = "SceneTransitionLayoutState"
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
index e19dc96..6e114e3 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
@@ -56,6 +56,7 @@
 import androidx.compose.ui.unit.Dp
 import androidx.compose.ui.unit.DpOffset
 import androidx.compose.ui.unit.DpSize
+import androidx.compose.ui.unit.IntSize
 import androidx.compose.ui.unit.dp
 import androidx.compose.ui.unit.lerp
 import androidx.test.ext.junit.runners.AndroidJUnit4
@@ -63,11 +64,13 @@
 import com.android.compose.animation.scene.TestScenes.SceneB
 import com.android.compose.animation.scene.TestScenes.SceneC
 import com.android.compose.animation.scene.subjects.assertThat
+import com.android.compose.test.assertSizeIsEqualTo
 import com.google.common.truth.Truth.assertThat
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.launch
 import kotlinx.coroutines.test.runTest
 import org.junit.Assert.assertThrows
+import org.junit.Ignore
 import org.junit.Rule
 import org.junit.Test
 import org.junit.runner.RunWith
@@ -237,9 +240,9 @@
                 changeScene(SceneC)
             }
 
-            at(2 * frameDuration) { onElement(TestElements.Bar).assertIsNotDisplayed() }
+            at(3 * frameDuration) { onElement(TestElements.Bar).assertIsNotDisplayed() }
 
-            at(3 * frameDuration) { onElement(TestElements.Bar).assertDoesNotExist() }
+            at(4 * frameDuration) { onElement(TestElements.Bar).assertDoesNotExist() }
         }
     }
 
@@ -578,6 +581,7 @@
     }
 
     @Test
+    @Ignore("b/341072461")
     fun existingElementsDontRecomposeWhenTransitionStateChanges() {
         var fooCompositions = 0
 
@@ -603,6 +607,43 @@
         }
     }
 
+    @Test
+    // TODO(b/341072461): Remove this test.
+    fun layoutGetsCurrentTransitionStateFromComposition() {
+        val state =
+            rule.runOnUiThread {
+                MutableSceneTransitionLayoutStateImpl(
+                    SceneA,
+                    transitions {
+                        from(SceneA, to = SceneB) {
+                            scaleSize(TestElements.Foo, width = 2f, height = 2f)
+                        }
+                    }
+                )
+            }
+
+        rule.setContent {
+            SceneTransitionLayout(state) {
+                scene(SceneA) { Box(Modifier.element(TestElements.Foo).size(20.dp)) }
+                scene(SceneB) {}
+            }
+        }
+
+        // Pause the clock to block recompositions.
+        rule.mainClock.autoAdvance = false
+
+        // Change the current transition.
+        rule.runOnUiThread {
+            state.startTransition(
+                transition(from = SceneA, to = SceneB, progress = { 0.5f }),
+                transitionKey = null,
+            )
+        }
+
+        // The size of Foo should still be 20dp given that the new state was not composed yet.
+        rule.onNode(isElement(TestElements.Foo)).assertSizeIsEqualTo(20.dp, 20.dp)
+    }
+
     private fun setupOverscrollScenario(
         layoutWidth: Dp,
         layoutHeight: Dp,
@@ -616,11 +657,13 @@
         var touchSlop = 0f
 
         val state =
-            MutableSceneTransitionLayoutState(
-                initialScene = SceneA,
-                transitions = transitions(sceneTransitions),
-            )
-                as MutableSceneTransitionLayoutStateImpl
+            rule.runOnUiThread {
+                MutableSceneTransitionLayoutState(
+                    initialScene = SceneA,
+                    transitions = transitions(sceneTransitions),
+                )
+                    as MutableSceneTransitionLayoutStateImpl
+            }
 
         rule.setContent {
             touchSlop = LocalViewConfiguration.current.touchSlop
@@ -726,16 +769,18 @@
         val layoutHeight = 400.dp
 
         val state =
-            MutableSceneTransitionLayoutState(
-                initialScene = SceneB,
-                transitions =
-                    transitions {
-                        overscroll(SceneB, Orientation.Vertical) {
-                            translate(TestElements.Foo, y = overscrollTranslateY)
+            rule.runOnUiThread {
+                MutableSceneTransitionLayoutState(
+                    initialScene = SceneB,
+                    transitions =
+                        transitions {
+                            overscroll(SceneB, Orientation.Vertical) {
+                                translate(TestElements.Foo, y = overscrollTranslateY)
+                            }
                         }
-                    }
-            )
-                as MutableSceneTransitionLayoutStateImpl
+                )
+                    as MutableSceneTransitionLayoutStateImpl
+            }
 
         rule.setContent {
             touchSlop = LocalViewConfiguration.current.touchSlop
@@ -902,32 +947,36 @@
         val duration = 4 * 16
 
         val state =
-            MutableSceneTransitionLayoutState(
-                SceneA,
-                transitions {
-                    // Foo is at the top left corner of scene A. We make it disappear during A => B
-                    // to the right edge so it translates to the right.
-                    from(SceneA, to = SceneB) {
-                        spec = tween(duration, easing = LinearEasing)
-                        translate(
-                            TestElements.Foo,
-                            edge = Edge.Right,
-                            startsOutsideLayoutBounds = false,
-                        )
-                    }
+            rule.runOnUiThread {
+                MutableSceneTransitionLayoutState(
+                    SceneA,
+                    transitions {
+                        // Foo is at the top left corner of scene A. We make it disappear
+                        // during the A => B transition by translating it to the right
+                        // edge, so that it moves to the right.
+                        from(SceneA, to = SceneB) {
+                            spec = tween(duration, easing = LinearEasing)
+                            translate(
+                                TestElements.Foo,
+                                edge = Edge.Right,
+                                startsOutsideLayoutBounds = false,
+                            )
+                        }
 
-                    // Bar is at the top right corner of scene C. We make it appear during B => C
-                    // from the left edge so it translates to the right at same time as Foo.
-                    from(SceneB, to = SceneC) {
-                        spec = tween(duration, easing = LinearEasing)
-                        translate(
-                            TestElements.Bar,
-                            edge = Edge.Left,
-                            startsOutsideLayoutBounds = false,
-                        )
+                        // Bar is at the top right corner of scene C. We make it appear during
+                        // the B => C transition from the left edge, so that it translates to
+                        // the right at the same time as Foo.
+                        from(SceneB, to = SceneC) {
+                            spec = tween(duration, easing = LinearEasing)
+                            translate(
+                                TestElements.Bar,
+                                edge = Edge.Left,
+                                startsOutsideLayoutBounds = false,
+                            )
+                        }
                     }
-                }
-            )
+                )
+            }
 
         val layoutSize = 150.dp
         val elemSize = 50.dp
@@ -1023,23 +1072,28 @@
         val duration = 4 * 16
 
         val state =
-            MutableSceneTransitionLayoutStateImpl(
-                SceneA,
-                transitions {
-                    from(SceneA, to = SceneB) { spec = tween(duration, easing = LinearEasing) }
-                    from(SceneB, to = SceneC) { spec = tween(duration, easing = LinearEasing) }
-                },
-                enableInterruptions = false,
-            )
+            rule.runOnUiThread {
+                MutableSceneTransitionLayoutStateImpl(
+                    SceneA,
+                    transitions {
+                        from(SceneA, to = SceneB) { spec = tween(duration, easing = LinearEasing) }
+                        from(SceneB, to = SceneC) { spec = tween(duration, easing = LinearEasing) }
+                    },
+                )
+            }
 
         val layoutSize = DpSize(200.dp, 100.dp)
-        val fooSize = DpSize(20.dp, 10.dp)
 
         @Composable
-        fun SceneScope.Foo(modifier: Modifier = Modifier) {
-            Box(modifier.element(TestElements.Foo).size(fooSize))
+        fun SceneScope.Foo(size: Dp, modifier: Modifier = Modifier) {
+            Box(modifier.element(TestElements.Foo).size(size))
         }
 
+        // The size of Foo when idle in A, B or C.
+        val sizeInA = 10.dp
+        val sizeInB = 30.dp
+        val sizeInC = 50.dp
+
         lateinit var layoutImpl: SceneTransitionLayoutImpl
         rule.setContent {
             SceneTransitionLayoutForTesting(
@@ -1049,33 +1103,35 @@
             ) {
                 // In scene A, Foo is aligned at the TopStart.
                 scene(SceneA) {
-                    Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.TopStart)) }
+                    Box(Modifier.fillMaxSize()) { Foo(sizeInA, Modifier.align(Alignment.TopStart)) }
                 }
 
                 // In scene C, Foo is aligned at the BottomEnd, so it moves vertically when coming
                 // from B. We put it before (below) scene B so that we can check that interruptions
                 // values and deltas are properly cleared once all transitions are done.
                 scene(SceneC) {
-                    Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.BottomEnd)) }
+                    Box(Modifier.fillMaxSize()) {
+                        Foo(sizeInC, Modifier.align(Alignment.BottomEnd))
+                    }
                 }
 
                 // In scene B, Foo is aligned at the TopEnd, so it moves horizontally when coming
                 // from A.
                 scene(SceneB) {
-                    Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.TopEnd)) }
+                    Box(Modifier.fillMaxSize()) { Foo(sizeInB, Modifier.align(Alignment.TopEnd)) }
                 }
             }
         }
 
         // The offset of Foo when idle in A, B or C.
         val offsetInA = DpOffset.Zero
-        val offsetInB = DpOffset(layoutSize.width - fooSize.width, 0.dp)
-        val offsetInC =
-            DpOffset(layoutSize.width - fooSize.width, layoutSize.height - fooSize.height)
+        val offsetInB = DpOffset(layoutSize.width - sizeInB, 0.dp)
+        val offsetInC = DpOffset(layoutSize.width - sizeInC, layoutSize.height - sizeInC)
 
         // Initial state (idle in A).
         rule
             .onNode(isElement(TestElements.Foo, SceneA))
+            .assertSizeIsEqualTo(sizeInA)
             .assertPositionInRootIsEqualTo(offsetInA.x, offsetInA.y)
 
         // Current transition is A => B at 50%.
@@ -1088,9 +1144,11 @@
                 onFinish = neverFinish(),
             )
         val offsetInAToB = lerp(offsetInA, offsetInB, aToBProgress)
+        val sizeInAToB = lerp(sizeInA, sizeInB, aToBProgress)
         rule.runOnUiThread { state.startTransition(aToB, transitionKey = null) }
         rule
             .onNode(isElement(TestElements.Foo, SceneB))
+            .assertSizeIsEqualTo(sizeInAToB)
             .assertPositionInRootIsEqualTo(offsetInAToB.x, offsetInAToB.y)
 
         // Start B => C at 0%.
@@ -1105,26 +1163,30 @@
             )
         rule.runOnUiThread { state.startTransition(bToC, transitionKey = null) }
 
-        // The offset interruption delta, which will be multiplied by the interruption progress then
-        // added to the current transition offset.
-        val interruptionDelta = offsetInAToB - offsetInB
+        // The interruption deltas, which will be multiplied by the interruption progress then added
+        // to the current transition offset and size.
+        val offsetInterruptionDelta = offsetInAToB - offsetInB
+        val sizeInterruptionDelta = sizeInAToB - sizeInB
 
         // Interruption progress is at 100% and bToC is at 0%, so Foo should be at the same offset
-        // as right before the interruption.
+        // and size as right before the interruption.
         rule
             .onNode(isElement(TestElements.Foo, SceneB))
             .assertPositionInRootIsEqualTo(offsetInAToB.x, offsetInAToB.y)
+            .assertSizeIsEqualTo(sizeInAToB)
 
         // Move the transition forward at 30% and set the interruption progress to 50%.
         bToCProgress = 0.3f
         interruptionProgress = 0.5f
         val offsetInBToC = lerp(offsetInB, offsetInC, bToCProgress)
+        val sizeInBToC = lerp(sizeInB, sizeInC, bToCProgress)
         val offsetInBToCWithInterruption =
             offsetInBToC +
                 DpOffset(
-                    interruptionDelta.x * interruptionProgress,
-                    interruptionDelta.y * interruptionProgress,
+                    offsetInterruptionDelta.x * interruptionProgress,
+                    offsetInterruptionDelta.y * interruptionProgress,
                 )
+        val sizeInBToCWithInterruption = sizeInBToC + sizeInterruptionDelta * interruptionProgress
         rule.waitForIdle()
         rule
             .onNode(isElement(TestElements.Foo, SceneB))
@@ -1132,6 +1194,7 @@
                 offsetInBToCWithInterruption.x,
                 offsetInBToCWithInterruption.y,
             )
+            .assertSizeIsEqualTo(sizeInBToCWithInterruption)
 
         // Finish the transition and interruption.
         bToCProgress = 1f
@@ -1139,10 +1202,13 @@
         rule
             .onNode(isElement(TestElements.Foo, SceneB))
             .assertPositionInRootIsEqualTo(offsetInC.x, offsetInC.y)
+            .assertSizeIsEqualTo(sizeInC)
 
         // Manually finish the transition.
-        state.finishTransition(aToB, SceneB)
-        state.finishTransition(bToC, SceneC)
+        rule.runOnUiThread {
+            state.finishTransition(aToB, SceneB)
+            state.finishTransition(bToC, SceneC)
+        }
         rule.waitForIdle()
         assertThat(state.transitionState).isIdle()
 
@@ -1151,9 +1217,11 @@
         assertThat(foo.sceneStates.keys).containsExactly(SceneC)
         val stateInC = foo.sceneStates.getValue(SceneC)
         assertThat(stateInC.offsetBeforeInterruption).isEqualTo(Offset.Unspecified)
+        assertThat(stateInC.sizeBeforeInterruption).isEqualTo(Element.SizeUnspecified)
         assertThat(stateInC.scaleBeforeInterruption).isEqualTo(Scale.Unspecified)
         assertThat(stateInC.alphaBeforeInterruption).isEqualTo(Element.AlphaUnspecified)
         assertThat(stateInC.offsetInterruptionDelta).isEqualTo(Offset.Zero)
+        assertThat(stateInC.sizeInterruptionDelta).isEqualTo(IntSize.Zero)
         assertThat(stateInC.scaleInterruptionDelta).isEqualTo(Scale.Zero)
         assertThat(stateInC.alphaInterruptionDelta).isEqualTo(0f)
     }
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutTest.kt
index 692c18b..3751a22 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutTest.kt
@@ -76,12 +76,13 @@
 
     /** The content under test. */
     @Composable
-    private fun TestContent() {
+    private fun TestContent(enableInterruptions: Boolean = true) {
         layoutState =
             updateSceneTransitionLayoutState(
                 currentScene,
                 { currentScene = it },
-                EmptyTestTransitions
+                EmptyTestTransitions,
+                enableInterruptions = enableInterruptions,
             )
 
         SceneTransitionLayout(
@@ -219,7 +220,7 @@
 
     @Test
     fun testSharedElement() {
-        rule.setContent { TestContent() }
+        rule.setContent { TestContent(enableInterruptions = false) }
 
         // In scene A, the shared element SharedFoo() is at the top end of the layout and has a size
         // of 50.dp.
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SwipeToSceneTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SwipeToSceneTest.kt
index 1dd9322..3a806a4 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SwipeToSceneTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SwipeToSceneTest.kt
@@ -70,7 +70,9 @@
     private fun layoutState(
         initialScene: SceneKey = SceneA,
         transitions: SceneTransitions = EmptyTestTransitions,
-    ) = MutableSceneTransitionLayoutState(initialScene, transitions)
+    ): MutableSceneTransitionLayoutState {
+        return rule.runOnUiThread { MutableSceneTransitionLayoutState(initialScene, transitions) }
+    }
 
     /** The content under test. */
     @Composable
@@ -455,7 +457,7 @@
 
     @Test
     fun swipeEnabledLater() {
-        val layoutState = MutableSceneTransitionLayoutState(SceneA)
+        val layoutState = layoutState()
         var swipesEnabled by mutableStateOf(false)
         var touchSlop = 0f
         rule.setContent {
@@ -489,7 +491,7 @@
     fun transitionKey() {
         val transitionkey = TransitionKey(debugName = "foo")
         val state =
-            MutableSceneTransitionLayoutStateImpl(
+            layoutState(
                 SceneA,
                 transitions {
                     from(SceneA, to = SceneB) { fade(TestElements.Foo) }
@@ -553,7 +555,7 @@
             }
 
         val state =
-            MutableSceneTransitionLayoutState(
+            layoutState(
                 SceneA,
                 transitions { from(SceneA, to = SceneB) { distance = swipeDistance } }
             )
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/test/SizeAssertions.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/test/SizeAssertions.kt
index fbd1b51..bca710f 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/test/SizeAssertions.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/test/SizeAssertions.kt
@@ -21,7 +21,11 @@
 import androidx.compose.ui.test.assertWidthIsEqualTo
 import androidx.compose.ui.unit.Dp
 
-fun SemanticsNodeInteraction.assertSizeIsEqualTo(expectedWidth: Dp, expectedHeight: Dp) {
+fun SemanticsNodeInteraction.assertSizeIsEqualTo(
+    expectedWidth: Dp,
+    expectedHeight: Dp = expectedWidth,
+): SemanticsNodeInteraction {
     assertWidthIsEqualTo(expectedWidth)
     assertHeightIsEqualTo(expectedHeight)
+    return this
 }