Make TransitionLinks match more flexible

TransitionLinks can be initialized with `null` to indicate matching any
Scene in this slot.

Test: SceneTransitionLayoutStateTest
Bug: b/320257219
Flag: NONE
Change-Id: I3b3cf18cab4b8cc177c314f8401f8384b655876b
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
index 2661301..a8da551 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
@@ -251,12 +251,15 @@
     private fun setupTransitionLinks(transitionState: TransitionState) {
         if (transitionState !is TransitionState.Transition) return
         stateLinks.fastForEach { stateLink ->
-            val matchingLink =
-                stateLink.transitionLinks.firstOrNull() { it.isMatchingLink(transitionState) } ?: return@fastForEach
+            val matchingLinks =
+                stateLink.transitionLinks.fastFilter { it.isMatchingLink(transitionState) }
+            if (matchingLinks.isEmpty()) return@fastForEach
+            if (matchingLinks.size > 1) error("More than one link matched.")
 
             val targetCurrentScene = stateLink.target.transitionState.currentScene
+            val matchingLink = matchingLinks[0]
 
-            if (targetCurrentScene != matchingLink.targetFrom) return@fastForEach
+            if (!matchingLink.targetIsInValidState(targetCurrentScene)) return@fastForEach
 
             val linkedTransition =
                 LinkedTransition(
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/transition/link/StateLink.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/transition/link/StateLink.kt
index 9b51e44..6c29946 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/transition/link/StateLink.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/transition/link/StateLink.kt
@@ -35,9 +35,9 @@
      * target to `SceneA` from any current scene.
      */
     class TransitionLink(
-        val sourceFrom: SceneKey,
-        val sourceTo: SceneKey,
-        val targetFrom: SceneKey,
+        val sourceFrom: SceneKey?,
+        val sourceTo: SceneKey?,
+        val targetFrom: SceneKey?,
         val targetTo: SceneKey,
         val targetTransitionKey: TransitionKey? = null,
     ) {
@@ -50,12 +50,12 @@
         }
 
         internal fun isMatchingLink(transition: TransitionState.Transition): Boolean {
-            return (sourceFrom == transition.fromScene) &&
-                (sourceTo == transition.toScene)
+            return (sourceFrom == null || sourceFrom == transition.fromScene) &&
+                (sourceTo == null || sourceTo == transition.toScene)
         }
 
         internal fun targetIsInValidState(targetCurrentScene: SceneKey): Boolean {
-            return (targetFrom == targetCurrentScene) &&
+            return (targetFrom == null || targetFrom == targetCurrentScene) &&
                 targetTo != targetCurrentScene
         }
     }
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutStateTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutStateTest.kt
index 2a5a355..f81a7f2 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutStateTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutStateTest.kt
@@ -130,17 +130,23 @@
         assertThat(state.transitionState).isEqualTo(TransitionState.Idle(SceneB))
     }
 
-    private fun setupLinkedStates():
-            Pair<BaseSceneTransitionLayoutState, BaseSceneTransitionLayoutState> {
-        val parentState = MutableSceneTransitionLayoutState(SceneC)
+    private fun setupLinkedStates(
+        parentInitialScene: SceneKey = SceneC,
+        childInitialScene: SceneKey = SceneA,
+        sourceFrom: SceneKey? = SceneA,
+        sourceTo: SceneKey? = SceneB,
+        targetFrom: SceneKey? = SceneC,
+        targetTo: SceneKey = SceneD
+    ): Pair<BaseSceneTransitionLayoutState, BaseSceneTransitionLayoutState> {
+        val parentState = MutableSceneTransitionLayoutState(parentInitialScene)
         val link =
             listOf(
                 StateLink(
                     parentState,
-                    listOf(StateLink.TransitionLink(SceneA, SceneB, SceneC, SceneD))
+                    listOf(StateLink.TransitionLink(sourceFrom, sourceTo, targetFrom, targetTo))
                 )
             )
-        val childState = MutableSceneTransitionLayoutState(SceneA, stateLinks = link)
+        val childState = MutableSceneTransitionLayoutState(childInitialScene, stateLinks = link)
         return Pair(
             parentState as BaseSceneTransitionLayoutState,
             childState as BaseSceneTransitionLayoutState
@@ -342,4 +348,45 @@
         assertThat(state.isTransitioning()).isFalse()
         assertThat(state.transitionState).isEqualTo(TransitionState.Idle(TestScenes.SceneB))
     }
+
+    @Test
+    fun linkedTransition_fuzzyLinksAreMatchedAndStarted() {
+        val (parentState, childState) = setupLinkedStates(SceneC, SceneA, null, null, null, SceneD)
+        val childTransition = TestableTransition(SceneA, SceneB)
+
+        childState.startTransition(childTransition, null)
+        assertThat(childState.isTransitioning(SceneA, SceneB)).isTrue()
+        assertThat(parentState.isTransitioning(SceneC, SceneD)).isTrue()
+
+        childState.finishTransition(childTransition, SceneB)
+        assertThat(childState.transitionState).isEqualTo(TransitionState.Idle(SceneB))
+        assertThat(parentState.transitionState).isEqualTo(TransitionState.Idle(SceneD))
+    }
+
+    @Test
+    fun linkedTransition_fuzzyLinksAreMatchedAndResetToProperPreviousScene() {
+        val (parentState, childState) =
+            setupLinkedStates(SceneC, SceneA, SceneA, null, null, SceneD)
+
+        val childTransition = TestableTransition(SceneA, SceneB)
+
+        childState.startTransition(childTransition, null)
+        assertThat(childState.isTransitioning(SceneA, SceneB)).isTrue()
+        assertThat(parentState.isTransitioning(SceneC, SceneD)).isTrue()
+
+        childState.finishTransition(childTransition, SceneA)
+        assertThat(childState.transitionState).isEqualTo(TransitionState.Idle(SceneA))
+        assertThat(parentState.transitionState).isEqualTo(TransitionState.Idle(SceneC))
+    }
+
+    @Test
+    fun linkedTransition_fuzzyLinksAreNotMatched() {
+        val (parentState, childState) =
+            setupLinkedStates(SceneC, SceneA, SceneB, null, SceneC, SceneD)
+        val childTransition = TestableTransition(SceneA, SceneB)
+
+        childState.startTransition(childTransition, null)
+        assertThat(childState.isTransitioning(SceneA, SceneB)).isTrue()
+        assertThat(parentState.isTransitioning(SceneC, SceneD)).isFalse()
+    }
 }