Add support for two-stage transitions in STL

Bug: 333415330
Bug: 350705972
Test: ElementTest, PredictiveBackHandlerTest, TransitionDslTest
Flag: com.android.systemui.scene_container
Change-Id: Ibda3016c55ad42ee1ef1de792a2b7853b9a4c8c0
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
index 377b02b..3ad07d0 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
@@ -1133,10 +1133,98 @@
 
     val transformation =
         transformation(transition.transformationSpec.transformations(element.key, scene))
-            // If there is no transformation explicitly associated to this element value, let's use
-            // the value given by the system (like the current position and size given by the layout
-            // pass).
-            ?: return currentValue()
+
+    val previewTransformation =
+        transition.previewTransformationSpec?.let {
+            transformation(it.transformations(element.key, scene))
+        }
+    if (previewTransformation != null) {
+        val isInPreviewStage = transition.isInPreviewStage
+
+        val idleValue = sceneValue(sceneState)
+        val isEntering = scene == toScene
+        val previewTargetValue =
+            previewTransformation.transform(
+                layoutImpl,
+                scene,
+                element,
+                sceneState,
+                transition,
+                idleValue,
+            )
+
+        val targetValueOrNull =
+            transformation?.transform(
+                layoutImpl,
+                scene,
+                element,
+                sceneState,
+                transition,
+                idleValue,
+            )
+
+        // Make sure we don't read progress if values are the same and we don't need to interpolate,
+        // so we don't invalidate the phase where this is read.
+        when {
+            isInPreviewStage && isEntering && previewTargetValue == targetValueOrNull ->
+                return previewTargetValue
+            isInPreviewStage && !isEntering && idleValue == previewTargetValue -> return idleValue
+            previewTargetValue == targetValueOrNull && idleValue == previewTargetValue ->
+                return idleValue
+            else -> {}
+        }
+
+        val previewProgress = transition.previewProgress
+        // progress is not needed for all cases of the below when block, therefore read it lazily
+        // TODO(b/290184746): Make sure that we don't overflow transformations associated to a range
+        val previewRangeProgress =
+            previewTransformation.range?.progress(previewProgress) ?: previewProgress
+
+        if (isInPreviewStage) {
+            // if we're in the preview stage of the transition, interpolate between start state and
+            // preview target state:
+            return if (isEntering) {
+                // i.e. in the entering case between previewTargetValue and targetValue (or
+                // idleValue if no transformation is defined in the second stage transition)...
+                lerp(previewTargetValue, targetValueOrNull ?: idleValue, previewRangeProgress)
+            } else {
+                // ...and in the exiting case between the idleValue and the previewTargetValue.
+                lerp(idleValue, previewTargetValue, previewRangeProgress)
+            }
+        }
+
+        // if we're in the second stage of the transition, interpolate between the state the
+        // element was left at the end of the preview-phase and the target state:
+        return if (isEntering) {
+            // i.e. in the entering case between preview-end-state and the idleValue...
+            lerp(
+                lerp(previewTargetValue, targetValueOrNull ?: idleValue, previewRangeProgress),
+                idleValue,
+                transformation?.range?.progress(transition.progress) ?: transition.progress
+            )
+        } else {
+            if (targetValueOrNull == null) {
+                // ... and in the exiting case, the element should remain in the preview-end-state
+                // if no further transformation is defined in the second-stage transition...
+                lerp(idleValue, previewTargetValue, previewRangeProgress)
+            } else {
+                // ...and otherwise it should be interpolated between preview-end-state and
+                // targetValue
+                lerp(
+                    lerp(idleValue, previewTargetValue, previewRangeProgress),
+                    targetValueOrNull,
+                    transformation.range?.progress(transition.progress) ?: transition.progress
+                )
+            }
+        }
+    }
+
+    if (transformation == null) {
+        // If there is no transformation explicitly associated to this element value, let's use
+        // the value given by the system (like the current position and size given by the layout
+        // pass).
+        return currentValue()
+    }
 
     val idleValue = sceneValue(sceneState)
     val targetValue =
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/PredictiveBackHandler.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/PredictiveBackHandler.kt
index 734241e..081707b 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/PredictiveBackHandler.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/PredictiveBackHandler.kt
@@ -77,8 +77,17 @@
     private var progressAnimatable by mutableStateOf<Animatable<Float, AnimationVector1D>?>(null)
     var dragProgress: Float by mutableFloatStateOf(0f)
 
+    override val previewProgress: Float
+        get() = dragProgress
+
+    override val previewProgressVelocity: Float
+        get() = 0f // Currently, velocity is not exposed by predictive back API
+
+    override val isInPreviewStage: Boolean
+        get() = progressAnimatable == null && previewTransformationSpec != null
+
     override val progress: Float
-        get() = progressAnimatable?.value ?: dragProgress
+        get() = progressAnimatable?.value ?: previewTransformationSpec?.let { 0f } ?: dragProgress
 
     override val progressVelocity: Float
         get() = progressAnimatable?.velocity ?: 0f
@@ -109,8 +118,8 @@
                 toScene -> 1f
                 else -> error("scene $currentScene should be either $fromScene or $toScene")
             }
-
-        val animatable = Animatable(dragProgress).also { progressAnimatable = it }
+        val startProgress = if (previewTransformationSpec != null) 0f else dragProgress
+        val animatable = Animatable(startProgress).also { progressAnimatable = it }
 
         // Important: We start atomically to make sure that we start the coroutine even if it is
         // cancelled right after it is launched, so that finishTransition() is correctly called.
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
index 56c8752..db2a3d3 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
@@ -189,6 +189,19 @@
         /** The current velocity of [progress], in progress units. */
         abstract val progressVelocity: Float
 
+        /**
+         * The progress of the preview transition. This is usually in the `[0; 1]` range, but it can
+         * also be less than `0` or greater than `1` when using transitions with a spring
+         * AnimationSpec or when flinging quickly during a swipe gesture.
+         */
+        open val previewProgress: Float = 0f
+
+        /** The current velocity of [previewProgress], in progress units. */
+        open val previewProgressVelocity: Float = 0f
+
+        /** Whether the transition is currently in the preview stage */
+        open val isInPreviewStage: Boolean = false
+
         /** Whether the transition was triggered by user input rather than being programmatic. */
         abstract val isInitiatedByUserInput: Boolean
 
@@ -203,6 +216,7 @@
          * [started][BaseSceneTransitionLayoutState.startTransition].
          */
         internal var transformationSpec: TransformationSpecImpl = TransformationSpec.Empty
+        internal var previewTransformationSpec: TransformationSpecImpl? = null
         private var fromOverscrollSpec: OverscrollSpecImpl? = null
         private var toOverscrollSpec: OverscrollSpecImpl? = null
 
@@ -431,6 +445,10 @@
             transitions
                 .transitionSpec(fromScene, toScene, key = transition.key)
                 .transformationSpec()
+        transition.previewTransformationSpec =
+            transitions
+                .transitionSpec(fromScene, toScene, key = transition.key)
+                .previewTransformationSpec()
         if (orientation != null) {
             transition.updateOverscrollSpecs(
                 fromSpec = transitions.overscrollSpec(fromScene, orientation),
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitions.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitions.kt
index e30dd356..06b093d 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitions.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitions.kt
@@ -110,7 +110,7 @@
     }
 
     private fun defaultTransition(from: SceneKey, to: SceneKey) =
-        TransitionSpecImpl(key = null, from, to, TransformationSpec.EmptyProvider)
+        TransitionSpecImpl(key = null, from, to, null, null, TransformationSpec.EmptyProvider)
 
     internal fun overscrollSpec(scene: SceneKey, orientation: Orientation): OverscrollSpecImpl? =
         overscrollCache
@@ -177,10 +177,18 @@
     /**
      * The [TransformationSpec] associated to this [TransitionSpec].
      *
-     * Note that this is called once every a transition associated to this [TransitionSpec] is
+     * Note that this is called once whenever a transition associated to this [TransitionSpec] is
      * started.
      */
     fun transformationSpec(): TransformationSpec
+
+    /**
+     * The preview [TransformationSpec] associated to this [TransitionSpec].
+     *
+     * Note that this is called once whenever a transition associated to this [TransitionSpec] is
+     * started.
+     */
+    fun previewTransformationSpec(): TransformationSpec?
 }
 
 interface TransformationSpec {
@@ -225,13 +233,17 @@
     override val key: TransitionKey?,
     override val from: SceneKey?,
     override val to: SceneKey?,
-    private val transformationSpec: () -> TransformationSpecImpl,
+    private val previewTransformationSpec: (() -> TransformationSpecImpl)? = null,
+    private val reversePreviewTransformationSpec: (() -> TransformationSpecImpl)? = null,
+    private val transformationSpec: () -> TransformationSpecImpl
 ) : TransitionSpec {
     override fun reversed(): TransitionSpecImpl {
         return TransitionSpecImpl(
             key = key,
             from = to,
             to = from,
+            previewTransformationSpec = reversePreviewTransformationSpec,
+            reversePreviewTransformationSpec = previewTransformationSpec,
             transformationSpec = {
                 val reverse = transformationSpec.invoke()
                 TransformationSpecImpl(
@@ -245,6 +257,9 @@
     }
 
     override fun transformationSpec(): TransformationSpecImpl = this.transformationSpec.invoke()
+
+    override fun previewTransformationSpec(): TransformationSpecImpl? =
+        previewTransformationSpec?.invoke()
 }
 
 /** The definition of the overscroll behavior of the [scene]. */
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDsl.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDsl.kt
index 89ed8d6..3a87d41 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDsl.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDsl.kt
@@ -54,11 +54,17 @@
      * If [key] is not `null`, then this transition will only be used if the same key is specified
      * when triggering the transition.
      *
+     * Optionally, define a [preview] animation which will be played during the first stage of the
+     * transition, e.g. during the predictive back gesture. In case your transition should be
+     * reversible with the reverse animation having a preview as well, define a [reversePreview].
+     *
      * @see from
      */
     fun to(
         to: SceneKey,
         key: TransitionKey? = null,
+        preview: (TransitionBuilder.() -> Unit)? = null,
+        reversePreview: (TransitionBuilder.() -> Unit)? = null,
         builder: TransitionBuilder.() -> Unit = {},
     ): TransitionSpec
 
@@ -74,11 +80,17 @@
      * 2. to == A && from == B, which is then treated in reverse.
      * 3. (from == A && to == null) || (from == null && to == B)
      * 4. (from == B && to == null) || (from == null && to == A), which is then treated in reverse.
+     *
+     * Optionally, define a [preview] animation which will be played during the first stage of the
+     * transition, e.g. during the predictive back gesture. In case your transition should be
+     * reversible with the reverse animation having a preview as well, define a [reversePreview].
      */
     fun from(
         from: SceneKey,
         to: SceneKey? = null,
         key: TransitionKey? = null,
+        preview: (TransitionBuilder.() -> Unit)? = null,
+        reversePreview: (TransitionBuilder.() -> Unit)? = null,
         builder: TransitionBuilder.() -> Unit = {},
     ): TransitionSpec
 
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDslImpl.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDslImpl.kt
index 1e67aa9..02a4362 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDslImpl.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDslImpl.kt
@@ -62,18 +62,22 @@
     override fun to(
         to: SceneKey,
         key: TransitionKey?,
+        preview: (TransitionBuilder.() -> Unit)?,
+        reversePreview: (TransitionBuilder.() -> Unit)?,
         builder: TransitionBuilder.() -> Unit
     ): TransitionSpec {
-        return transition(from = null, to = to, key = key, builder)
+        return transition(from = null, to = to, key = key, preview, reversePreview, builder)
     }
 
     override fun from(
         from: SceneKey,
         to: SceneKey?,
         key: TransitionKey?,
+        preview: (TransitionBuilder.() -> Unit)?,
+        reversePreview: (TransitionBuilder.() -> Unit)?,
         builder: TransitionBuilder.() -> Unit
     ): TransitionSpec {
-        return transition(from = from, to = to, key = key, builder)
+        return transition(from = from, to = to, key = key, preview, reversePreview, builder)
     }
 
     override fun overscroll(
@@ -103,9 +107,11 @@
         from: SceneKey?,
         to: SceneKey?,
         key: TransitionKey?,
+        preview: (TransitionBuilder.() -> Unit)?,
+        reversePreview: (TransitionBuilder.() -> Unit)?,
         builder: TransitionBuilder.() -> Unit,
     ): TransitionSpec {
-        fun transformationSpec(): TransformationSpecImpl {
+        fun transformationSpec(builder: TransitionBuilder.() -> Unit): TransformationSpecImpl {
             val impl = TransitionBuilderImpl().apply(builder)
             return TransformationSpecImpl(
                 progressSpec = impl.spec,
@@ -115,7 +121,18 @@
             )
         }
 
-        val spec = TransitionSpecImpl(key, from, to, ::transformationSpec)
+        val previewTransformationSpec = preview?.let { { transformationSpec(it) } }
+        val reversePreviewTransformationSpec = reversePreview?.let { { transformationSpec(it) } }
+        val transformationSpec = { transformationSpec(builder) }
+        val spec =
+            TransitionSpecImpl(
+                key,
+                from,
+                to,
+                previewTransformationSpec,
+                reversePreviewTransformationSpec,
+                transformationSpec
+            )
         transitionSpecs.add(spec)
         return spec
     }
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
index 7988e0e..2baf134 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
@@ -2185,4 +2185,165 @@
         rule.onNode(isElement(TestElements.Foo, SceneA)).assertIsNotDisplayed()
         rule.onNode(isElement(TestElements.Foo, SceneB)).assertPositionInRootIsEqualTo(40.dp, 60.dp)
     }
+
+    @Test
+    fun previewInterpolation_previewStage() {
+        val exiting1 = ElementKey("exiting1")
+        val exiting2 = ElementKey("exiting2")
+        val exiting3 = ElementKey("exiting3")
+        val entering1 = ElementKey("entering1")
+        val entering2 = ElementKey("entering2")
+        val entering3 = ElementKey("entering3")
+
+        val layoutImpl =
+            testPreviewTransformation(
+                from = SceneB,
+                to = SceneA,
+                exitingElements = listOf(exiting1, exiting2, exiting3),
+                enteringElements = listOf(entering1, entering2, entering3),
+                preview = {
+                    scaleDraw(exiting1, scaleX = 0.8f, scaleY = 0.8f)
+                    translate(exiting2, x = 20.dp)
+                    scaleDraw(entering1, scaleX = 0f, scaleY = 0f)
+                    translate(entering2, y = 30.dp)
+                },
+                transition = {
+                    translate(exiting2, x = 30.dp)
+                    scaleSize(exiting3, width = 0.8f, height = 0.8f)
+                    scaleDraw(entering1, scaleX = 0.5f, scaleY = 0.5f)
+                    scaleSize(entering3, width = 0.2f, height = 0.2f)
+                },
+                previewProgress = 0.5f,
+                progress = 0f,
+                isInPreviewStage = true
+            )
+
+        // verify that preview transition for exiting elements is halfway played from
+        // current-scene-value -> preview-target-value
+        val exiting1InB = layoutImpl.elements.getValue(exiting1).sceneStates.getValue(SceneB)
+        // e.g. exiting1 is half scaled...
+        assertThat(exiting1InB.lastScale).isEqualTo(Scale(0.9f, 0.9f, Offset.Unspecified))
+        // ...and exiting2 is halfway translated from 0.dp to 20.dp...
+        rule.onNode(isElement(exiting2)).assertPositionInRootIsEqualTo(10.dp, 0.dp)
+        // ...whereas exiting3 remains in its original size because it is only affected by the
+        // second phase of the transition
+        rule.onNode(isElement(exiting3)).assertSizeIsEqualTo(100.dp, 100.dp)
+
+        // verify that preview transition for entering elements is halfway played from
+        // preview-target-value -> transition-target-value (or target-scene-value if no
+        // transition-target-value defined).
+        val entering1InA = layoutImpl.elements.getValue(entering1).sceneStates.getValue(SceneA)
+        // e.g. entering1 is half scaled between 0f and 0.5f -> 0.25f...
+        assertThat(entering1InA.lastScale).isEqualTo(Scale(0.25f, 0.25f, Offset.Unspecified))
+        // ...and entering2 is half way translated between 30.dp and 0.dp
+        rule.onNode(isElement(entering2)).assertPositionInRootIsEqualTo(0.dp, 15.dp)
+        // ...and entering3 is still at its start size of 0.2f * 100.dp, because it is unaffected
+        // by the preview phase
+        rule.onNode(isElement(entering3)).assertSizeIsEqualTo(20.dp, 20.dp)
+    }
+
+    @Test
+    fun previewInterpolation_transitionStage() {
+        val exiting1 = ElementKey("exiting1")
+        val exiting2 = ElementKey("exiting2")
+        val exiting3 = ElementKey("exiting3")
+        val entering1 = ElementKey("entering1")
+        val entering2 = ElementKey("entering2")
+        val entering3 = ElementKey("entering3")
+
+        val layoutImpl =
+            testPreviewTransformation(
+                from = SceneB,
+                to = SceneA,
+                exitingElements = listOf(exiting1, exiting2, exiting3),
+                enteringElements = listOf(entering1, entering2, entering3),
+                preview = {
+                    scaleDraw(exiting1, scaleX = 0.8f, scaleY = 0.8f)
+                    translate(exiting2, x = 20.dp)
+                    scaleDraw(entering1, scaleX = 0f, scaleY = 0f)
+                    translate(entering2, y = 30.dp)
+                },
+                transition = {
+                    translate(exiting2, x = 30.dp)
+                    scaleSize(exiting3, width = 0.8f, height = 0.8f)
+                    scaleDraw(entering1, scaleX = 0.5f, scaleY = 0.5f)
+                    scaleSize(entering3, width = 0.2f, height = 0.2f)
+                },
+                previewProgress = 0.5f,
+                progress = 0.5f,
+                isInPreviewStage = false
+            )
+
+        // verify that exiting elements remain in the preview-end state if no further transition is
+        // defined for them in the second stage
+        val exiting1InB = layoutImpl.elements.getValue(exiting1).sceneStates.getValue(SceneB)
+        // i.e. exiting1 remains half scaled
+        assertThat(exiting1InB.lastScale).isEqualTo(Scale(0.9f, 0.9f, Offset.Unspecified))
+        // in case there is an additional transition defined for the second stage, verify that the
+        // animation is seamlessly taken over from the preview-end-state, e.g. the translation of
+        // exiting2 is at 10.dp after the preview phase. After half of the second phase, it
+        // should be half-way between 10.dp and the target-value of 30.dp -> 20.dp
+        rule.onNode(isElement(exiting2)).assertPositionInRootIsEqualTo(20.dp, 0.dp)
+        // if the element is only modified by the second phase transition, verify it's in the middle
+        // of start-scene-state and target-scene-state, i.e. exiting3 is halfway between 100.dp and
+        // 80.dp
+        rule.onNode(isElement(exiting3)).assertSizeIsEqualTo(90.dp, 90.dp)
+
+        // verify that entering elements animate seamlessly to their target state
+        val entering1InA = layoutImpl.elements.getValue(entering1).sceneStates.getValue(SceneA)
+        // e.g. entering1, which was scaled from 0f to 0.25f during the preview phase, should now be
+        // half way scaled between 0.25f and its target-state of 1f -> 0.625f
+        assertThat(entering1InA.lastScale).isEqualTo(Scale(0.625f, 0.625f, Offset.Unspecified))
+        // entering2, which was translated from y=30.dp to y=15.dp should now be half way
+        // between 15.dp and its target state of 0.dp...
+        rule.onNode(isElement(entering2)).assertPositionInRootIsEqualTo(0.dp, 7.5.dp)
+        // entering3, which isn't affected by the preview transformation should be half scaled
+        // between start size (20.dp) and target size (100.dp) -> 60.dp
+        rule.onNode(isElement(entering3)).assertSizeIsEqualTo(60.dp, 60.dp)
+    }
+
+    private fun testPreviewTransformation(
+        from: SceneKey,
+        to: SceneKey,
+        exitingElements: List<ElementKey> = listOf(),
+        enteringElements: List<ElementKey> = listOf(),
+        preview: (TransitionBuilder.() -> Unit)? = null,
+        transition: TransitionBuilder.() -> Unit,
+        progress: Float = 0f,
+        previewProgress: Float = 0.5f,
+        isInPreviewStage: Boolean = true
+    ): SceneTransitionLayoutImpl {
+        val state =
+            rule.runOnIdle {
+                MutableSceneTransitionLayoutStateImpl(
+                    from,
+                    transitions { from(from, to = to, preview = preview, builder = transition) }
+                )
+            }
+
+        @Composable
+        fun SceneScope.Foo(elementKey: ElementKey) {
+            Box(Modifier.element(elementKey).size(100.dp))
+        }
+
+        lateinit var layoutImpl: SceneTransitionLayoutImpl
+        rule.setContent {
+            SceneTransitionLayoutForTesting(state, onLayoutImpl = { layoutImpl = it }) {
+                scene(from) { Box { exitingElements.forEach { Foo(it) } } }
+                scene(to) { Box { enteringElements.forEach { Foo(it) } } }
+            }
+        }
+
+        val bToA =
+            transition(
+                from = from,
+                to = to,
+                progress = { progress },
+                previewProgress = { previewProgress },
+                isInPreviewStage = { isInPreviewStage }
+            )
+        rule.runOnUiThread { state.startTransition(bToA) }
+        rule.waitForIdle()
+        return layoutImpl
+    }
 }
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/PredictiveBackHandlerTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/PredictiveBackHandlerTest.kt
index 6522eb3..0eaecb0 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/PredictiveBackHandlerTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/PredictiveBackHandlerTest.kt
@@ -80,6 +80,56 @@
         assertThat(transition).hasFromScene(SceneA)
         assertThat(transition).hasToScene(SceneB)
         assertThat(transition).hasProgress(0.4f)
+        assertThat(transition).isNotInPreviewStage()
+
+        // Cancel it.
+        rule.runOnUiThread { dispatcher.dispatchOnBackCancelled() }
+        rule.waitForIdle()
+        assertThat(layoutState.transitionState).hasCurrentScene(SceneA)
+        assertThat(layoutState.transitionState).isIdle()
+
+        // Start again and commit it.
+        rule.runOnUiThread {
+            dispatcher.dispatchOnBackStarted(backEvent())
+            dispatcher.dispatchOnBackProgressed(backEvent(progress = 0.4f))
+            dispatcher.onBackPressed()
+        }
+        rule.waitForIdle()
+        assertThat(layoutState.transitionState).hasCurrentScene(SceneB)
+        assertThat(layoutState.transitionState).isIdle()
+    }
+
+    @Test
+    fun testPredictiveBackWithPreview() {
+        val layoutState =
+            rule.runOnUiThread {
+                MutableSceneTransitionLayoutState(
+                    SceneA,
+                    transitions = transitions { from(SceneA, to = SceneB, preview = {}) }
+                )
+            }
+        rule.setContent {
+            SceneTransitionLayout(layoutState) {
+                scene(SceneA, mapOf(Back to SceneB)) { Box(Modifier.fillMaxSize()) }
+                scene(SceneB) { Box(Modifier.fillMaxSize()) }
+            }
+        }
+
+        assertThat(layoutState.transitionState).hasCurrentScene(SceneA)
+
+        // Start back.
+        val dispatcher = rule.activity.onBackPressedDispatcher
+        rule.runOnUiThread {
+            dispatcher.dispatchOnBackStarted(backEvent())
+            dispatcher.dispatchOnBackProgressed(backEvent(progress = 0.4f))
+        }
+
+        val transition = assertThat(layoutState.transitionState).isTransition()
+        assertThat(transition).hasFromScene(SceneA)
+        assertThat(transition).hasToScene(SceneB)
+        assertThat(transition).hasPreviewProgress(0.4f)
+        assertThat(transition).hasProgress(0f)
+        assertThat(transition).isInPreviewStage()
 
         // Cancel it.
         rule.runOnUiThread { dispatcher.dispatchOnBackCancelled() }
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/Transition.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/Transition.kt
index 65f4f9e..66d4059 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/Transition.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/Transition.kt
@@ -30,6 +30,9 @@
     current: () -> SceneKey = { from },
     progress: () -> Float = { 0f },
     progressVelocity: () -> Float = { 0f },
+    previewProgress: () -> Float = { 0f },
+    previewProgressVelocity: () -> Float = { 0f },
+    isInPreviewStage: () -> Boolean = { false },
     interruptionProgress: () -> Float = { 0f },
     isInitiatedByUserInput: Boolean = false,
     isUserInputOngoing: Boolean = false,
@@ -51,6 +54,15 @@
         override val progressVelocity: Float
             get() = progressVelocity()
 
+        override val previewProgress: Float
+            get() = previewProgress()
+
+        override val previewProgressVelocity: Float
+            get() = previewProgressVelocity()
+
+        override val isInPreviewStage: Boolean
+            get() = isInPreviewStage()
+
         override val isInitiatedByUserInput: Boolean = isInitiatedByUserInput
         override val isUserInputOngoing: Boolean = isUserInputOngoing
         override val isUpOrLeft: Boolean = isUpOrLeft
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/TransitionDslTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/TransitionDslTest.kt
index 825fe13..a3790f8 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/TransitionDslTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/TransitionDslTest.kt
@@ -168,7 +168,14 @@
     @Test
     fun defaultReversed() {
         val transitions = transitions {
-            from(TestScenes.SceneA, to = TestScenes.SceneB) {
+            from(
+                TestScenes.SceneA,
+                to = TestScenes.SceneB,
+                preview = { fractionRange(start = 0.1f, end = 0.8f) { fade(TestElements.Foo) } },
+                reversePreview = {
+                    fractionRange(start = 0.5f, end = 0.6f) { fade(TestElements.Foo) }
+                }
+            ) {
                 spec = tween(500)
                 fractionRange(start = 0.1f, end = 0.8f) { fade(TestElements.Foo) }
                 timestampRange(startMillis = 100, endMillis = 300) { fade(TestElements.Foo) }
@@ -177,11 +184,10 @@
 
         // Fetch the transition from B to A, which will automatically reverse the transition from A
         // to B we defined.
-        val transformations =
-            transitions
-                .transitionSpec(from = TestScenes.SceneB, to = TestScenes.SceneA, key = null)
-                .transformationSpec()
-                .transformations
+        val transitionSpec =
+            transitions.transitionSpec(from = TestScenes.SceneB, to = TestScenes.SceneA, key = null)
+
+        val transformations = transitionSpec.transformationSpec().transformations
 
         assertThat(transformations)
             .comparingElementsUsing(TRANSFORMATION_RANGE)
@@ -189,6 +195,14 @@
                 TransformationRange(start = 1f - 0.8f, end = 1f - 0.1f),
                 TransformationRange(start = 1f - 300 / 500f, end = 1f - 100 / 500f),
             )
+
+        val previewTransformations = transitionSpec.previewTransformationSpec()?.transformations
+
+        assertThat(previewTransformations)
+            .comparingElementsUsing(TRANSFORMATION_RANGE)
+            .containsExactly(
+                TransformationRange(start = 0.5f, end = 0.6f),
+            )
     }
 
     @Test
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/subjects/TransitionStateSubject.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/subjects/TransitionStateSubject.kt
index 3489892..e997a75 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/subjects/TransitionStateSubject.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/subjects/TransitionStateSubject.kt
@@ -95,6 +95,25 @@
             .of(progressVelocity)
     }
 
+    fun hasPreviewProgress(progress: Float, tolerance: Float = 0f) {
+        check("previewProgress").that(actual.previewProgress).isWithin(tolerance).of(progress)
+    }
+
+    fun hasPreviewProgressVelocity(progressVelocity: Float, tolerance: Float = 0f) {
+        check("previewProgressVelocity")
+            .that(actual.previewProgressVelocity)
+            .isWithin(tolerance)
+            .of(progressVelocity)
+    }
+
+    fun isInPreviewStage() {
+        check("isInPreviewStage").that(actual.isInPreviewStage).isTrue()
+    }
+
+    fun isNotInPreviewStage() {
+        check("isInPreviewStage").that(actual.isInPreviewStage).isFalse()
+    }
+
     fun isInitiatedByUserInput() {
         check("isInitiatedByUserInput").that(actual.isInitiatedByUserInput).isTrue()
     }