Prevent size jumps during interruptions

This CL adds the same kind of interruption support for the size of
elements that was added in ag/26597678. The main difference is that this
support unfortunately does not always work given that the element might
have already been measured when processing the interruption delta,
preventing us from re-measuring it again with the delta taken into
account.

See b/290930950#comment22 for more details.

Bug: 290930950
Test: ElementTest
Flag: com.android.systemui.scene_container
Change-Id: I7498f56ac9598abf931a121d43353d66935fd8fb
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
index 72da3b8..18baee9 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
@@ -49,6 +49,7 @@
 import com.android.compose.animation.scene.transformation.PropertyTransformation
 import com.android.compose.animation.scene.transformation.SharedElementTransformation
 import com.android.compose.ui.util.lerp
+import kotlin.math.roundToInt
 import kotlinx.coroutines.launch
 
 /** An element on screen, that can be composed in one or more scenes. */
@@ -81,11 +82,13 @@
 
         /** The last state this element had in this scene. */
         var lastOffset = Offset.Unspecified
+        var lastSize = SizeUnspecified
         var lastScale = Scale.Unspecified
         var lastAlpha = AlphaUnspecified
 
         /** The state of this element in this scene right before the last interruption (if any). */
         var offsetBeforeInterruption = Offset.Unspecified
+        var sizeBeforeInterruption = SizeUnspecified
         var scaleBeforeInterruption = Scale.Unspecified
         var alphaBeforeInterruption = AlphaUnspecified
 
@@ -96,6 +99,7 @@
          * they nicely animate from their values down to 0.
          */
         var offsetInterruptionDelta = Offset.Zero
+        var sizeInterruptionDelta = IntSize.Zero
         var scaleInterruptionDelta = Scale.Zero
         var alphaInterruptionDelta = 0f
 
@@ -263,11 +267,13 @@
             sceneState.lastAlpha = Element.AlphaUnspecified
 
             val placeable = measurable.measure(constraints)
+            sceneState.lastSize = placeable.size()
             return layout(placeable.width, placeable.height) {}
         }
 
         val placeable =
             measure(layoutImpl, scene, element, transition, sceneState, measurable, constraints)
+        sceneState.lastSize = placeable.size()
         return layout(placeable.width, placeable.height) {
             place(
                 layoutImpl,
@@ -377,12 +383,14 @@
     }
 
     val lastOffset = lastUniqueState?.lastOffset ?: Offset.Unspecified
+    val lastSize = lastUniqueState?.lastSize ?: Element.SizeUnspecified
     val lastScale = lastUniqueState?.lastScale ?: Scale.Unspecified
     val lastAlpha = lastUniqueState?.lastAlpha ?: Element.AlphaUnspecified
 
     // Store the state of the element before the interruption and reset the deltas.
     sceneStates.forEach { sceneState ->
         sceneState.offsetBeforeInterruption = lastOffset
+        sceneState.sizeBeforeInterruption = lastSize
         sceneState.scaleBeforeInterruption = lastScale
         sceneState.alphaBeforeInterruption = lastAlpha
 
@@ -392,6 +400,7 @@
 
 private fun Element.SceneState.clearInterruptionDeltas() {
     offsetInterruptionDelta = Offset.Zero
+    sizeInterruptionDelta = IntSize.Zero
     scaleInterruptionDelta = Scale.Zero
     alphaInterruptionDelta = 0f
 }
@@ -648,8 +657,6 @@
     // once.
     var maybePlaceable: Placeable? = null
 
-    fun Placeable.size() = IntSize(width, height)
-
     val targetSize =
         computeValue(
             layoutImpl,
@@ -664,15 +671,44 @@
             ::lerp,
         )
 
-    return maybePlaceable
-        ?: measurable.measure(
-            Constraints.fixed(
-                targetSize.width.coerceAtLeast(0),
-                targetSize.height.coerceAtLeast(0),
-            )
+    // The measurable was already measured, so we can't take interruptions into account here given
+    // that we are not allowed to measure the same measurable twice.
+    maybePlaceable?.let { placeable ->
+        sceneState.sizeBeforeInterruption = Element.SizeUnspecified
+        sceneState.sizeInterruptionDelta = IntSize.Zero
+        return placeable
+    }
+
+    val interruptedSize =
+        computeInterruptedValue(
+            layoutImpl,
+            transition,
+            value = targetSize,
+            unspecifiedValue = Element.SizeUnspecified,
+            zeroValue = IntSize.Zero,
+            getValueBeforeInterruption = { sceneState.sizeBeforeInterruption },
+            setValueBeforeInterruption = { sceneState.sizeBeforeInterruption = it },
+            getInterruptionDelta = { sceneState.sizeInterruptionDelta },
+            setInterruptionDelta = { sceneState.sizeInterruptionDelta = it },
+            diff = { a, b -> IntSize(a.width - b.width, a.height - b.height) },
+            add = { a, b, bProgress ->
+                IntSize(
+                    (a.width + b.width * bProgress).roundToInt(),
+                    (a.height + b.height * bProgress).roundToInt(),
+                )
+            },
         )
+
+    return measurable.measure(
+        Constraints.fixed(
+            interruptedSize.width.coerceAtLeast(0),
+            interruptedSize.height.coerceAtLeast(0),
+        )
+    )
 }
 
+private fun Placeable.size(): IntSize = IntSize(width, height)
+
 private fun ContentDrawScope.getDrawScale(
     layoutImpl: SceneTransitionLayoutImpl,
     scene: Scene,
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
index 7c449e4..6e114e3 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
@@ -56,6 +56,7 @@
 import androidx.compose.ui.unit.Dp
 import androidx.compose.ui.unit.DpOffset
 import androidx.compose.ui.unit.DpSize
+import androidx.compose.ui.unit.IntSize
 import androidx.compose.ui.unit.dp
 import androidx.compose.ui.unit.lerp
 import androidx.test.ext.junit.runners.AndroidJUnit4
@@ -1082,13 +1083,17 @@
             }
 
         val layoutSize = DpSize(200.dp, 100.dp)
-        val fooSize = DpSize(20.dp, 10.dp)
 
         @Composable
-        fun SceneScope.Foo(modifier: Modifier = Modifier) {
-            Box(modifier.element(TestElements.Foo).size(fooSize))
+        fun SceneScope.Foo(size: Dp, modifier: Modifier = Modifier) {
+            Box(modifier.element(TestElements.Foo).size(size))
         }
 
+        // The size of Foo when idle in A, B or C.
+        val sizeInA = 10.dp
+        val sizeInB = 30.dp
+        val sizeInC = 50.dp
+
         lateinit var layoutImpl: SceneTransitionLayoutImpl
         rule.setContent {
             SceneTransitionLayoutForTesting(
@@ -1098,33 +1103,35 @@
             ) {
                 // In scene A, Foo is aligned at the TopStart.
                 scene(SceneA) {
-                    Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.TopStart)) }
+                    Box(Modifier.fillMaxSize()) { Foo(sizeInA, Modifier.align(Alignment.TopStart)) }
                 }
 
                 // In scene C, Foo is aligned at the BottomEnd, so it moves vertically when coming
                 // from B. We put it before (below) scene B so that we can check that interruptions
                 // values and deltas are properly cleared once all transitions are done.
                 scene(SceneC) {
-                    Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.BottomEnd)) }
+                    Box(Modifier.fillMaxSize()) {
+                        Foo(sizeInC, Modifier.align(Alignment.BottomEnd))
+                    }
                 }
 
                 // In scene B, Foo is aligned at the TopEnd, so it moves horizontally when coming
                 // from A.
                 scene(SceneB) {
-                    Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.TopEnd)) }
+                    Box(Modifier.fillMaxSize()) { Foo(sizeInB, Modifier.align(Alignment.TopEnd)) }
                 }
             }
         }
 
         // The offset of Foo when idle in A, B or C.
         val offsetInA = DpOffset.Zero
-        val offsetInB = DpOffset(layoutSize.width - fooSize.width, 0.dp)
-        val offsetInC =
-            DpOffset(layoutSize.width - fooSize.width, layoutSize.height - fooSize.height)
+        val offsetInB = DpOffset(layoutSize.width - sizeInB, 0.dp)
+        val offsetInC = DpOffset(layoutSize.width - sizeInC, layoutSize.height - sizeInC)
 
         // Initial state (idle in A).
         rule
             .onNode(isElement(TestElements.Foo, SceneA))
+            .assertSizeIsEqualTo(sizeInA)
             .assertPositionInRootIsEqualTo(offsetInA.x, offsetInA.y)
 
         // Current transition is A => B at 50%.
@@ -1137,9 +1144,11 @@
                 onFinish = neverFinish(),
             )
         val offsetInAToB = lerp(offsetInA, offsetInB, aToBProgress)
+        val sizeInAToB = lerp(sizeInA, sizeInB, aToBProgress)
         rule.runOnUiThread { state.startTransition(aToB, transitionKey = null) }
         rule
             .onNode(isElement(TestElements.Foo, SceneB))
+            .assertSizeIsEqualTo(sizeInAToB)
             .assertPositionInRootIsEqualTo(offsetInAToB.x, offsetInAToB.y)
 
         // Start B => C at 0%.
@@ -1154,26 +1163,30 @@
             )
         rule.runOnUiThread { state.startTransition(bToC, transitionKey = null) }
 
-        // The offset interruption delta, which will be multiplied by the interruption progress then
-        // added to the current transition offset.
-        val interruptionDelta = offsetInAToB - offsetInB
+        // The interruption deltas, which will be multiplied by the interruption progress then added
+        // to the current transition offset and size.
+        val offsetInterruptionDelta = offsetInAToB - offsetInB
+        val sizeInterruptionDelta = sizeInAToB - sizeInB
 
         // Interruption progress is at 100% and bToC is at 0%, so Foo should be at the same offset
-        // as right before the interruption.
+        // and size as right before the interruption.
         rule
             .onNode(isElement(TestElements.Foo, SceneB))
             .assertPositionInRootIsEqualTo(offsetInAToB.x, offsetInAToB.y)
+            .assertSizeIsEqualTo(sizeInAToB)
 
         // Move the transition forward at 30% and set the interruption progress to 50%.
         bToCProgress = 0.3f
         interruptionProgress = 0.5f
         val offsetInBToC = lerp(offsetInB, offsetInC, bToCProgress)
+        val sizeInBToC = lerp(sizeInB, sizeInC, bToCProgress)
         val offsetInBToCWithInterruption =
             offsetInBToC +
                 DpOffset(
-                    interruptionDelta.x * interruptionProgress,
-                    interruptionDelta.y * interruptionProgress,
+                    offsetInterruptionDelta.x * interruptionProgress,
+                    offsetInterruptionDelta.y * interruptionProgress,
                 )
+        val sizeInBToCWithInterruption = sizeInBToC + sizeInterruptionDelta * interruptionProgress
         rule.waitForIdle()
         rule
             .onNode(isElement(TestElements.Foo, SceneB))
@@ -1181,6 +1194,7 @@
                 offsetInBToCWithInterruption.x,
                 offsetInBToCWithInterruption.y,
             )
+            .assertSizeIsEqualTo(sizeInBToCWithInterruption)
 
         // Finish the transition and interruption.
         bToCProgress = 1f
@@ -1188,6 +1202,7 @@
         rule
             .onNode(isElement(TestElements.Foo, SceneB))
             .assertPositionInRootIsEqualTo(offsetInC.x, offsetInC.y)
+            .assertSizeIsEqualTo(sizeInC)
 
         // Manually finish the transition.
         rule.runOnUiThread {
@@ -1202,9 +1217,11 @@
         assertThat(foo.sceneStates.keys).containsExactly(SceneC)
         val stateInC = foo.sceneStates.getValue(SceneC)
         assertThat(stateInC.offsetBeforeInterruption).isEqualTo(Offset.Unspecified)
+        assertThat(stateInC.sizeBeforeInterruption).isEqualTo(Element.SizeUnspecified)
         assertThat(stateInC.scaleBeforeInterruption).isEqualTo(Scale.Unspecified)
         assertThat(stateInC.alphaBeforeInterruption).isEqualTo(Element.AlphaUnspecified)
         assertThat(stateInC.offsetInterruptionDelta).isEqualTo(Offset.Zero)
+        assertThat(stateInC.sizeInterruptionDelta).isEqualTo(IntSize.Zero)
         assertThat(stateInC.scaleInterruptionDelta).isEqualTo(Scale.Zero)
         assertThat(stateInC.alphaInterruptionDelta).isEqualTo(0f)
     }
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/test/SizeAssertions.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/test/SizeAssertions.kt
index fbd1b51..bca710f 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/test/SizeAssertions.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/test/SizeAssertions.kt
@@ -21,7 +21,11 @@
 import androidx.compose.ui.test.assertWidthIsEqualTo
 import androidx.compose.ui.unit.Dp
 
-fun SemanticsNodeInteraction.assertSizeIsEqualTo(expectedWidth: Dp, expectedHeight: Dp) {
+fun SemanticsNodeInteraction.assertSizeIsEqualTo(
+    expectedWidth: Dp,
+    expectedHeight: Dp = expectedWidth,
+): SemanticsNodeInteraction {
     assertWidthIsEqualTo(expectedWidth)
     assertHeightIsEqualTo(expectedHeight)
+    return this
 }