Merge "Expose the current Transition in TransitionBuilder" into main
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
index 3c3c612..a9a8668 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
@@ -387,11 +387,11 @@
         transition.transformationSpec =
             transitions
                 .transitionSpec(fromContent, toContent, key = transition.key)
-                .transformationSpec()
+                .transformationSpec(transition)
         transition.previewTransformationSpec =
             transitions
                 .transitionSpec(fromContent, toContent, key = transition.key)
-                .previewTransformationSpec()
+                .previewTransformationSpec(transition)
         if (orientation != null) {
             transition.updateOverscrollSpecs(
                 fromSpec = transitions.overscrollSpec(fromContent, orientation),
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitions.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitions.kt
index 879dc54..8866fbf 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitions.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitions.kt
@@ -25,6 +25,7 @@
 import androidx.compose.ui.geometry.Offset
 import androidx.compose.ui.unit.IntSize
 import androidx.compose.ui.util.fastForEach
+import com.android.compose.animation.scene.content.state.TransitionState
 import com.android.compose.animation.scene.transformation.AnchoredSize
 import com.android.compose.animation.scene.transformation.AnchoredTranslate
 import com.android.compose.animation.scene.transformation.DrawScale
@@ -191,20 +192,21 @@
     fun reversed(): TransitionSpec
 
     /**
-     * The [TransformationSpec] associated to this [TransitionSpec].
+     * The [TransformationSpec] associated to this [TransitionSpec] for the given [transition].
      *
      * Note that this is called once whenever a transition associated to this [TransitionSpec] is
      * started.
      */
-    fun transformationSpec(): TransformationSpec
+    fun transformationSpec(transition: TransitionState.Transition): TransformationSpec
 
     /**
-     * The preview [TransformationSpec] associated to this [TransitionSpec].
+     * The preview [TransformationSpec] associated to this [TransitionSpec] for the given
+     * [transition].
      *
      * Note that this is called once whenever a transition associated to this [TransitionSpec] is
      * started.
      */
-    fun previewTransformationSpec(): TransformationSpec?
+    fun previewTransformationSpec(transition: TransitionState.Transition): TransformationSpec?
 }
 
 interface TransformationSpec {
@@ -241,7 +243,7 @@
                 distance = null,
                 transformations = emptyList(),
             )
-        internal val EmptyProvider = { Empty }
+        internal val EmptyProvider = { _: TransitionState.Transition -> Empty }
     }
 }
 
@@ -249,9 +251,13 @@
     override val key: TransitionKey?,
     override val from: ContentKey?,
     override val to: ContentKey?,
-    private val previewTransformationSpec: (() -> TransformationSpecImpl)? = null,
-    private val reversePreviewTransformationSpec: (() -> TransformationSpecImpl)? = null,
-    private val transformationSpec: () -> TransformationSpecImpl,
+    private val previewTransformationSpec:
+        ((TransitionState.Transition) -> TransformationSpecImpl)? =
+        null,
+    private val reversePreviewTransformationSpec:
+        ((TransitionState.Transition) -> TransformationSpecImpl)? =
+        null,
+    private val transformationSpec: (TransitionState.Transition) -> TransformationSpecImpl,
 ) : TransitionSpec {
     override fun reversed(): TransitionSpecImpl {
         return TransitionSpecImpl(
@@ -260,8 +266,8 @@
             to = from,
             previewTransformationSpec = reversePreviewTransformationSpec,
             reversePreviewTransformationSpec = previewTransformationSpec,
-            transformationSpec = {
-                val reverse = transformationSpec.invoke()
+            transformationSpec = { transition ->
+                val reverse = transformationSpec.invoke(transition)
                 TransformationSpecImpl(
                     progressSpec = reverse.progressSpec,
                     swipeSpec = reverse.swipeSpec,
@@ -272,10 +278,13 @@
         )
     }
 
-    override fun transformationSpec(): TransformationSpecImpl = this.transformationSpec.invoke()
+    override fun transformationSpec(
+        transition: TransitionState.Transition
+    ): TransformationSpecImpl = transformationSpec.invoke(transition)
 
-    override fun previewTransformationSpec(): TransformationSpecImpl? =
-        previewTransformationSpec?.invoke()
+    override fun previewTransformationSpec(
+        transition: TransitionState.Transition
+    ): TransformationSpecImpl? = previewTransformationSpec?.invoke(transition)
 }
 
 /** The definition of the overscroll behavior of the [content]. */
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDsl.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDsl.kt
index 763dc6b..e825c6e 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDsl.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDsl.kt
@@ -156,6 +156,9 @@
 
 @TransitionDsl
 interface TransitionBuilder : BaseTransitionBuilder {
+    /** The [TransitionState.Transition] this builder is computing the transformations for. */
+    val transition: TransitionState.Transition
+
     /**
      * The [AnimationSpec] used to animate the associated transition progress from `0` to `1` when
      * the transition is triggered (i.e. it is not gesture-based).
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDslImpl.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDslImpl.kt
index 7ec5e4f..a5ad999 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDslImpl.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/TransitionDslImpl.kt
@@ -27,6 +27,7 @@
 import androidx.compose.foundation.gestures.Orientation
 import androidx.compose.ui.geometry.Offset
 import androidx.compose.ui.unit.Dp
+import com.android.compose.animation.scene.content.state.TransitionState
 import com.android.compose.animation.scene.transformation.AnchoredSize
 import com.android.compose.animation.scene.transformation.AnchoredTranslate
 import com.android.compose.animation.scene.transformation.DrawScale
@@ -128,8 +129,11 @@
         reversePreview: (TransitionBuilder.() -> Unit)?,
         builder: TransitionBuilder.() -> Unit,
     ): TransitionSpec {
-        fun transformationSpec(builder: TransitionBuilder.() -> Unit): TransformationSpecImpl {
-            val impl = TransitionBuilderImpl().apply(builder)
+        fun transformationSpec(
+            transition: TransitionState.Transition,
+            builder: TransitionBuilder.() -> Unit,
+        ): TransformationSpecImpl {
+            val impl = TransitionBuilderImpl(transition).apply(builder)
             return TransformationSpecImpl(
                 progressSpec = impl.spec,
                 swipeSpec = impl.swipeSpec,
@@ -138,17 +142,15 @@
             )
         }
 
-        val previewTransformationSpec = preview?.let { { transformationSpec(it) } }
-        val reversePreviewTransformationSpec = reversePreview?.let { { transformationSpec(it) } }
-        val transformationSpec = { transformationSpec(builder) }
         val spec =
             TransitionSpecImpl(
                 key,
                 from,
                 to,
-                previewTransformationSpec,
-                reversePreviewTransformationSpec,
-                transformationSpec,
+                previewTransformationSpec = preview?.let { { t -> transformationSpec(t, it) } },
+                reversePreviewTransformationSpec =
+                    reversePreview?.let { { t -> transformationSpec(t, it) } },
+                transformationSpec = { t -> transformationSpec(t, builder) },
             )
         transitionSpecs.add(spec)
         return spec
@@ -227,7 +229,8 @@
     }
 }
 
-internal class TransitionBuilderImpl : BaseTransitionBuilderImpl(), TransitionBuilder {
+internal class TransitionBuilderImpl(override val transition: TransitionState.Transition) :
+    BaseTransitionBuilderImpl(), TransitionBuilder {
     override var spec: AnimationSpec<Float> = spring(stiffness = Spring.StiffnessLow)
     override var swipeSpec: SpringSpec<Float>? = null
     override var distance: UserActionDistance? = null
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/TransitionDslTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/TransitionDslTest.kt
index 223af80..d66d6b3 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/TransitionDslTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/TransitionDslTest.kt
@@ -23,11 +23,17 @@
 import androidx.compose.animation.core.tween
 import androidx.compose.foundation.gestures.Orientation
 import androidx.test.ext.junit.runners.AndroidJUnit4
+import com.android.compose.animation.scene.TestScenes.SceneA
+import com.android.compose.animation.scene.TestScenes.SceneB
+import com.android.compose.animation.scene.TestScenes.SceneC
+import com.android.compose.animation.scene.content.state.TransitionState
 import com.android.compose.animation.scene.transformation.OverscrollTranslate
 import com.android.compose.animation.scene.transformation.Transformation
 import com.android.compose.animation.scene.transformation.TransformationRange
+import com.android.compose.test.transition
 import com.google.common.truth.Correspondence
 import com.google.common.truth.Truth.assertThat
+import kotlinx.coroutines.test.runTest
 import org.junit.Assert.assertThrows
 import org.junit.Test
 import org.junit.runner.RunWith
@@ -43,9 +49,9 @@
     @Test
     fun manyTransitions() {
         val transitions = transitions {
-            from(TestScenes.SceneA, to = TestScenes.SceneB)
-            from(TestScenes.SceneB, to = TestScenes.SceneC)
-            from(TestScenes.SceneC, to = TestScenes.SceneA)
+            from(SceneA, to = SceneB)
+            from(SceneB, to = SceneC)
+            from(SceneC, to = SceneA)
         }
         assertThat(transitions.transitionSpecs).hasSize(3)
     }
@@ -53,9 +59,9 @@
     @Test
     fun toFromBuilders() {
         val transitions = transitions {
-            from(TestScenes.SceneA, to = TestScenes.SceneB)
-            from(TestScenes.SceneB)
-            to(TestScenes.SceneC)
+            from(SceneA, to = SceneB)
+            from(SceneB)
+            to(SceneC)
         }
 
         assertThat(transitions.transitionSpecs)
@@ -65,38 +71,34 @@
                     "has (from, to) equal to",
                 )
             )
-            .containsExactly(
-                TestScenes.SceneA to TestScenes.SceneB,
-                TestScenes.SceneB to null,
-                null to TestScenes.SceneC,
-            )
+            .containsExactly(SceneA to SceneB, SceneB to null, null to SceneC)
     }
 
+    private fun aToB() = transition(SceneA, SceneB)
+
     @Test
     fun defaultTransitionSpec() {
-        val transitions = transitions { from(TestScenes.SceneA, to = TestScenes.SceneB) }
-        val transformationSpec = transitions.transitionSpecs.single().transformationSpec()
+        val transitions = transitions { from(SceneA, to = SceneB) }
+        val transformationSpec = transitions.transitionSpecs.single().transformationSpec(aToB())
         assertThat(transformationSpec.progressSpec).isInstanceOf(SpringSpec::class.java)
     }
 
     @Test
     fun customTransitionSpec() {
         val transitions = transitions {
-            from(TestScenes.SceneA, to = TestScenes.SceneB) { spec = tween(durationMillis = 42) }
+            from(SceneA, to = SceneB) { spec = tween(durationMillis = 42) }
         }
-        val transformationSpec = transitions.transitionSpecs.single().transformationSpec()
+        val transformationSpec = transitions.transitionSpecs.single().transformationSpec(aToB())
         assertThat(transformationSpec.progressSpec).isInstanceOf(TweenSpec::class.java)
         assertThat((transformationSpec.progressSpec as TweenSpec).durationMillis).isEqualTo(42)
     }
 
     @Test
     fun defaultRange() {
-        val transitions = transitions {
-            from(TestScenes.SceneA, to = TestScenes.SceneB) { fade(TestElements.Foo) }
-        }
+        val transitions = transitions { from(SceneA, to = SceneB) { fade(TestElements.Foo) } }
 
         val transformations =
-            transitions.transitionSpecs.single().transformationSpec().transformations
+            transitions.transitionSpecs.single().transformationSpec(aToB()).transformations
         assertThat(transformations.size).isEqualTo(1)
         assertThat(transformations.single().range).isEqualTo(null)
     }
@@ -104,7 +106,7 @@
     @Test
     fun fractionRange() {
         val transitions = transitions {
-            from(TestScenes.SceneA, to = TestScenes.SceneB) {
+            from(SceneA, to = SceneB) {
                 fractionRange(start = 0.1f, end = 0.8f) { fade(TestElements.Foo) }
                 fractionRange(start = 0.2f) { fade(TestElements.Foo) }
                 fractionRange(end = 0.9f) { fade(TestElements.Foo) }
@@ -119,7 +121,7 @@
         }
 
         val transformations =
-            transitions.transitionSpecs.single().transformationSpec().transformations
+            transitions.transitionSpecs.single().transformationSpec(aToB()).transformations
         assertThat(transformations)
             .comparingElementsUsing(TRANSFORMATION_RANGE)
             .containsExactly(
@@ -133,7 +135,7 @@
     @Test
     fun timestampRange() {
         val transitions = transitions {
-            from(TestScenes.SceneA, to = TestScenes.SceneB) {
+            from(SceneA, to = SceneB) {
                 spec = tween(500)
 
                 timestampRange(startMillis = 100, endMillis = 300) { fade(TestElements.Foo) }
@@ -150,7 +152,7 @@
         }
 
         val transformations =
-            transitions.transitionSpecs.single().transformationSpec().transformations
+            transitions.transitionSpecs.single().transformationSpec(aToB()).transformations
         assertThat(transformations)
             .comparingElementsUsing(TRANSFORMATION_RANGE)
             .containsExactly(
@@ -168,7 +170,7 @@
     @Test
     fun reversed() {
         val transitions = transitions {
-            from(TestScenes.SceneA, to = TestScenes.SceneB) {
+            from(SceneA, to = SceneB) {
                 spec = tween(500)
                 reversed {
                     fractionRange(start = 0.1f, end = 0.8f) { fade(TestElements.Foo) }
@@ -178,7 +180,7 @@
         }
 
         val transformations =
-            transitions.transitionSpecs.single().transformationSpec().transformations
+            transitions.transitionSpecs.single().transformationSpec(aToB()).transformations
         assertThat(transformations)
             .comparingElementsUsing(TRANSFORMATION_RANGE)
             .containsExactly(
@@ -191,8 +193,8 @@
     fun defaultReversed() {
         val transitions = transitions {
             from(
-                TestScenes.SceneA,
-                to = TestScenes.SceneB,
+                SceneA,
+                to = SceneB,
                 preview = { fractionRange(start = 0.1f, end = 0.8f) { fade(TestElements.Foo) } },
                 reversePreview = {
                     fractionRange(start = 0.5f, end = 0.6f) { fade(TestElements.Foo) }
@@ -206,10 +208,9 @@
 
         // Fetch the transition from B to A, which will automatically reverse the transition from A
         // to B we defined.
-        val transitionSpec =
-            transitions.transitionSpec(from = TestScenes.SceneB, to = TestScenes.SceneA, key = null)
+        val transitionSpec = transitions.transitionSpec(from = SceneB, to = SceneA, key = null)
 
-        val transformations = transitionSpec.transformationSpec().transformations
+        val transformations = transitionSpec.transformationSpec(aToB()).transformations
 
         assertThat(transformations)
             .comparingElementsUsing(TRANSFORMATION_RANGE)
@@ -218,7 +219,8 @@
                 TransformationRange(start = 1f - 300 / 500f, end = 1f - 100 / 500f),
             )
 
-        val previewTransformations = transitionSpec.previewTransformationSpec()?.transformations
+        val previewTransformations =
+            transitionSpec.previewTransformationSpec(aToB())?.transformations
 
         assertThat(previewTransformations)
             .comparingElementsUsing(TRANSFORMATION_RANGE)
@@ -229,8 +231,8 @@
     fun defaultPredictiveBack() {
         val transitions = transitions {
             from(
-                TestScenes.SceneA,
-                to = TestScenes.SceneB,
+                SceneA,
+                to = SceneB,
                 preview = { fractionRange(start = 0.1f, end = 0.8f) { fade(TestElements.Foo) } },
             ) {
                 spec = tween(500)
@@ -243,12 +245,12 @@
         // transition despite it not having the PredictiveBack key set.
         val transitionSpec =
             transitions.transitionSpec(
-                from = TestScenes.SceneA,
-                to = TestScenes.SceneB,
+                from = SceneA,
+                to = SceneB,
                 key = TransitionKey.PredictiveBack,
             )
 
-        val transformations = transitionSpec.transformationSpec().transformations
+        val transformations = transitionSpec.transformationSpec(aToB()).transformations
 
         assertThat(transformations)
             .comparingElementsUsing(TRANSFORMATION_RANGE)
@@ -257,7 +259,8 @@
                 TransformationRange(start = 100 / 500f, end = 300 / 500f),
             )
 
-        val previewTransformations = transitionSpec.previewTransformationSpec()?.transformations
+        val previewTransformations =
+            transitionSpec.previewTransformationSpec(aToB())?.transformations
 
         assertThat(previewTransformations)
             .comparingElementsUsing(TRANSFORMATION_RANGE)
@@ -271,10 +274,10 @@
         val transitions = transitions {
             defaultSwipeSpec = defaultSpec
 
-            from(TestScenes.SceneA, to = TestScenes.SceneB) {
+            from(SceneA, to = SceneB) {
                 // Default swipe spec.
             }
-            from(TestScenes.SceneA, to = TestScenes.SceneC) { swipeSpec = specFromAToC }
+            from(SceneA, to = SceneC) { swipeSpec = specFromAToC }
         }
 
         assertThat(transitions.defaultSwipeSpec).isSameInstanceAs(defaultSpec)
@@ -282,8 +285,8 @@
         // A => B does not have a custom spec.
         assertThat(
                 transitions
-                    .transitionSpec(from = TestScenes.SceneA, to = TestScenes.SceneB, key = null)
-                    .transformationSpec()
+                    .transitionSpec(from = SceneA, to = SceneB, key = null)
+                    .transformationSpec(aToB())
                     .swipeSpec
             )
             .isNull()
@@ -291,8 +294,8 @@
         // A => C has a custom swipe spec.
         assertThat(
                 transitions
-                    .transitionSpec(from = TestScenes.SceneA, to = TestScenes.SceneC, key = null)
-                    .transformationSpec()
+                    .transitionSpec(from = SceneA, to = SceneC, key = null)
+                    .transformationSpec(transition(from = SceneA, to = SceneC))
                     .swipeSpec
             )
             .isSameInstanceAs(specFromAToC)
@@ -301,7 +304,7 @@
     @Test
     fun overscrollSpec() {
         val transitions = transitions {
-            overscroll(TestScenes.SceneA, Orientation.Vertical) {
+            overscroll(SceneA, Orientation.Vertical) {
                 translate(TestElements.Bar, x = { 1f }, y = { 2f })
             }
         }
@@ -313,9 +316,7 @@
 
     @Test
     fun overscrollSpec_for_overscrollDisabled() {
-        val transitions = transitions {
-            overscrollDisabled(TestScenes.SceneA, Orientation.Vertical)
-        }
+        val transitions = transitions { overscrollDisabled(SceneA, Orientation.Vertical) }
         val overscrollSpec = transitions.overscrollSpecs.single()
         assertThat(overscrollSpec.transformationSpec.transformations).isEmpty()
     }
@@ -323,10 +324,24 @@
     @Test
     fun overscrollSpec_throwIfTransformationsIsEmpty() {
         assertThrows(IllegalStateException::class.java) {
-            transitions { overscroll(TestScenes.SceneA, Orientation.Vertical) {} }
+            transitions { overscroll(SceneA, Orientation.Vertical) {} }
         }
     }
 
+    @Test
+    fun transitionIsPassedToBuilder() = runTest {
+        var transitionPassedToBuilder: TransitionState.Transition? = null
+        val state =
+            MutableSceneTransitionLayoutState(
+                SceneA,
+                transitions { from(SceneA, to = SceneB) { transitionPassedToBuilder = transition } },
+            )
+        // Starting the transition runs the builder, which must receive this very instance.
+        val transition = aToB()
+        state.startTransitionImmediately(animationScope = backgroundScope, transition)
+        assertThat(transitionPassedToBuilder).isSameInstanceAs(transition)
+    }
+
     companion object {
         private val TRANSFORMATION_RANGE =
             Correspondence.transforming<Transformation, TransformationRange?>(