Refactor STL Element maps logic

This CL refactors the logic to create/removes Elements and
Element.TargetValues in the SceneTransitionLayout maps, such that we
don't incorrectly throw an exception when a Modifier.element() is moved
from one node to another in a given scene.

Another major improvement this CL brings is that we don't need
Modifier.element() to run in a composition context anymore, which should
make Modifier.element() much more performant.

Bug: 310241171
Bug: 291566282
Flag: NA
Test: atest ElementTest
Change-Id: I1f513191bca1cbff4917091b9af069c47ea658da
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
index c01d32e..3b999e30 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
@@ -17,17 +17,14 @@
 package com.android.compose.animation.scene
 
 import androidx.compose.runtime.Composable
-import androidx.compose.runtime.DisposableEffect
 import androidx.compose.runtime.getValue
 import androidx.compose.runtime.movableContentOf
 import androidx.compose.runtime.mutableStateOf
-import androidx.compose.runtime.remember
 import androidx.compose.runtime.setValue
 import androidx.compose.runtime.snapshots.Snapshot
 import androidx.compose.runtime.snapshots.SnapshotStateMap
 import androidx.compose.ui.ExperimentalComposeUiApi
 import androidx.compose.ui.Modifier
-import androidx.compose.ui.composed
 import androidx.compose.ui.draw.drawWithContent
 import androidx.compose.ui.geometry.Offset
 import androidx.compose.ui.geometry.isSpecified
@@ -38,6 +35,7 @@
 import androidx.compose.ui.layout.Measurable
 import androidx.compose.ui.layout.Placeable
 import androidx.compose.ui.layout.intermediateLayout
+import androidx.compose.ui.node.ModifierNodeElement
 import androidx.compose.ui.platform.testTag
 import androidx.compose.ui.unit.Constraints
 import androidx.compose.ui.unit.IntSize
@@ -45,6 +43,7 @@
 import com.android.compose.animation.scene.transformation.PropertyTransformation
 import com.android.compose.animation.scene.transformation.SharedElementTransformation
 import com.android.compose.ui.util.lerp
+import kotlinx.coroutines.launch
 
 /** An element on screen, that can be composed in one or more scenes. */
 internal class Element(val key: ElementKey) {
@@ -91,13 +90,20 @@
     }
 
     /** The target values of this element in a given scene. */
-    class TargetValues {
+    class TargetValues(val scene: SceneKey) {
         val lastValues = Values()
 
         var targetSize by mutableStateOf(SizeUnspecified)
         var targetOffset by mutableStateOf(Offset.Unspecified)
 
         val sharedValues = SnapshotStateMap<ValueKey, SharedValue<*>>()
+
+        /**
+         * The attached [ElementNode] a Modifier.element() for a given element and scene. During
+         * composition, this set could have 0 to 2 elements. After composition and after all
+         * modifier nodes have been attached/detached, this set should contain exactly 1 element.
+         */
+        val nodes = mutableSetOf<ElementNode>()
     }
 
     /** A shared value of this element. */
@@ -124,37 +130,22 @@
     layoutImpl: SceneTransitionLayoutImpl,
     scene: Scene,
     key: ElementKey,
-): Modifier = composed {
-    val sceneValues = remember(scene, key) { Element.TargetValues() }
-    val element =
-        // Get the element associated to [key] if it was already composed in another scene,
-        // otherwise create it and add it to our Map<ElementKey, Element>. This is done inside a
-        // withoutReadObservation() because there is no need to recompose when that map is mutated.
-        Snapshot.withoutReadObservation {
-            val element =
-                layoutImpl.elements[key] ?: Element(key).also { layoutImpl.elements[key] = it }
-            val previousValues = element.sceneValues[scene.key]
-            if (previousValues == null) {
-                element.sceneValues[scene.key] = sceneValues
-            } else if (previousValues != sceneValues) {
-                error("$key was composed multiple times in $scene")
-            }
+): Modifier {
+    val element: Element
+    val sceneValues: Element.TargetValues
 
-            element
-        }
-
-    DisposableEffect(scene, sceneValues, element) {
-        onDispose {
-            element.sceneValues.remove(scene.key)
-
-            // This was the last scene this element was in, so remove it from the map.
-            if (element.sceneValues.isEmpty()) {
-                layoutImpl.elements.remove(element.key)
-            }
-        }
+    // Get the element associated to [key] if it was already composed in another scene,
+    // otherwise create it and add it to our Map<ElementKey, Element>. This is done inside a
+    // withoutReadObservation() because there is no need to recompose when that map is mutated.
+    Snapshot.withoutReadObservation {
+        element = layoutImpl.elements[key] ?: Element(key).also { layoutImpl.elements[key] = it }
+        sceneValues =
+            element.sceneValues[scene.key]
+                ?: Element.TargetValues(scene.key).also { element.sceneValues[scene.key] = it }
     }
 
-    drawWithContent {
+    return this.then(ElementModifier(layoutImpl, element, sceneValues))
+        .drawWithContent {
             if (shouldDrawElement(layoutImpl, scene, element)) {
                 val drawScale = getDrawScale(layoutImpl, element, scene, sceneValues)
                 if (drawScale == Scale.Default) {
@@ -181,6 +172,84 @@
         .testTag(key.testTag)
 }
 
+/**
+ * An element associated to [ElementNode]. Note that this element does not support updates as its
+ * arguments should always be the same.
+ */
+private data class ElementModifier(
+    private val layoutImpl: SceneTransitionLayoutImpl,
+    private val element: Element,
+    private val sceneValues: Element.TargetValues,
+) : ModifierNodeElement<ElementNode>() {
+    override fun create(): ElementNode = ElementNode(layoutImpl, element, sceneValues)
+
+    override fun update(node: ElementNode) {
+        node.update(layoutImpl, element, sceneValues)
+    }
+}
+
+internal class ElementNode(
+    layoutImpl: SceneTransitionLayoutImpl,
+    element: Element,
+    sceneValues: Element.TargetValues,
+) : Modifier.Node() {
+    private var layoutImpl: SceneTransitionLayoutImpl = layoutImpl
+    private var element: Element = element
+    private var sceneValues: Element.TargetValues = sceneValues
+
+    override fun onAttach() {
+        super.onAttach()
+        addNodeToSceneValues()
+    }
+
+    private fun addNodeToSceneValues() {
+        sceneValues.nodes.add(this)
+
+        coroutineScope.launch {
+            // At this point all [CodeLocationNode] have been attached or detached, which means that
+            // [sceneValues.codeLocations] should have exactly 1 element, otherwise this means that
+            // this element was composed multiple times in the same scene.
+            val nCodeLocations = sceneValues.nodes.size
+            if (nCodeLocations != 1 || !sceneValues.nodes.contains(this@ElementNode)) {
+                error("${element.key} was composed $nCodeLocations times in ${sceneValues.scene}")
+            }
+        }
+    }
+
+    override fun onDetach() {
+        super.onDetach()
+        removeNodeFromSceneValues()
+    }
+
+    private fun removeNodeFromSceneValues() {
+        sceneValues.nodes.remove(this)
+
+        // If element is not composed from this scene anymore, remove the scene values. This works
+        // because [onAttach] is called before [onDetach], so if an element is moved from the UI
+        // tree we will first add the new code location then remove the old one.
+        if (sceneValues.nodes.isEmpty()) {
+            element.sceneValues.remove(sceneValues.scene)
+        }
+
+        // If the element is not composed in any scene, remove it from the elements map.
+        if (element.sceneValues.isEmpty()) {
+            layoutImpl.elements.remove(element.key)
+        }
+    }
+
+    fun update(
+        layoutImpl: SceneTransitionLayoutImpl,
+        element: Element,
+        sceneValues: Element.TargetValues,
+    ) {
+        removeNodeFromSceneValues()
+        this.layoutImpl = layoutImpl
+        this.element = element
+        this.sceneValues = sceneValues
+        addNodeToSceneValues()
+    }
+}
+
 private fun shouldDrawElement(
     layoutImpl: SceneTransitionLayoutImpl,
     scene: Scene,
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayout.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayout.kt
index 9c31445..0f2c2f6 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayout.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayout.kt
@@ -57,30 +57,17 @@
     @FloatRange(from = 0.0, to = 0.5) transitionInterceptionThreshold: Float = 0f,
     scenes: SceneTransitionLayoutScope.() -> Unit,
 ) {
-    val density = LocalDensity.current
-    val coroutineScope = rememberCoroutineScope()
-    val layoutImpl = remember {
-        SceneTransitionLayoutImpl(
-            onChangeScene = onChangeScene,
-            builder = scenes,
-            transitions = transitions,
-            state = state,
-            density = density,
-            edgeDetector = edgeDetector,
-            transitionInterceptionThreshold = transitionInterceptionThreshold,
-            coroutineScope = coroutineScope,
-        )
-    }
-
-    layoutImpl.onChangeScene = onChangeScene
-    layoutImpl.transitions = transitions
-    layoutImpl.density = density
-    layoutImpl.edgeDetector = edgeDetector
-    layoutImpl.transitionInterceptionThreshold = transitionInterceptionThreshold
-
-    layoutImpl.setScenes(scenes)
-    layoutImpl.setCurrentScene(currentScene)
-    layoutImpl.Content(modifier)
+    SceneTransitionLayoutForTesting(
+        currentScene,
+        onChangeScene,
+        transitions,
+        state,
+        edgeDetector,
+        transitionInterceptionThreshold,
+        modifier,
+        onLayoutImpl = null,
+        scenes,
+    )
 }
 
 interface SceneTransitionLayoutScope {
@@ -228,3 +215,47 @@
     Left(Orientation.Horizontal),
     Right(Orientation.Horizontal),
 }
+
+/**
+ * An internal version of [SceneTransitionLayout] to be used for tests.
+ *
+ * Important: You should use this only in tests and if you need to access the underlying
+ * [SceneTransitionLayoutImpl]. In other cases, you should use [SceneTransitionLayout].
+ */
+@Composable
+internal fun SceneTransitionLayoutForTesting(
+    currentScene: SceneKey,
+    onChangeScene: (SceneKey) -> Unit,
+    transitions: SceneTransitions,
+    state: SceneTransitionLayoutState,
+    edgeDetector: EdgeDetector,
+    transitionInterceptionThreshold: Float,
+    modifier: Modifier,
+    onLayoutImpl: ((SceneTransitionLayoutImpl) -> Unit)?,
+    scenes: SceneTransitionLayoutScope.() -> Unit,
+) {
+    val density = LocalDensity.current
+    val coroutineScope = rememberCoroutineScope()
+    val layoutImpl = remember {
+        SceneTransitionLayoutImpl(
+                onChangeScene = onChangeScene,
+                builder = scenes,
+                transitions = transitions,
+                state = state,
+                density = density,
+                edgeDetector = edgeDetector,
+                transitionInterceptionThreshold = transitionInterceptionThreshold,
+                coroutineScope = coroutineScope,
+            )
+            .also { onLayoutImpl?.invoke(it) }
+    }
+
+    layoutImpl.onChangeScene = onChangeScene
+    layoutImpl.transitions = transitions
+    layoutImpl.density = density
+    layoutImpl.edgeDetector = edgeDetector
+
+    layoutImpl.setScenes(scenes)
+    layoutImpl.setCurrentScene(currentScene)
+    layoutImpl.Content(modifier)
+}
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
index 6401bb3..cc7a0b8 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
@@ -18,9 +18,15 @@
 
 import androidx.compose.animation.core.tween
 import androidx.compose.foundation.layout.Box
+import androidx.compose.foundation.layout.Column
+import androidx.compose.foundation.layout.Row
 import androidx.compose.foundation.layout.offset
 import androidx.compose.foundation.layout.size
 import androidx.compose.runtime.Composable
+import androidx.compose.runtime.getValue
+import androidx.compose.runtime.mutableStateOf
+import androidx.compose.runtime.remember
+import androidx.compose.runtime.setValue
 import androidx.compose.ui.ExperimentalComposeUiApi
 import androidx.compose.ui.Modifier
 import androidx.compose.ui.layout.intermediateLayout
@@ -29,6 +35,7 @@
 import androidx.compose.ui.unit.dp
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import com.google.common.truth.Truth.assertThat
+import org.junit.Assert.assertThrows
 import org.junit.Rule
 import org.junit.Test
 import org.junit.runner.RunWith
@@ -209,4 +216,215 @@
             }
         }
     }
+
+    @Test
+    fun elementIsReusedInSameSceneAndBetweenScenes() {
+        var currentScene by mutableStateOf(TestScenes.SceneA)
+        var sceneCState by mutableStateOf(0)
+        var sceneDState by mutableStateOf(0)
+        val key = TestElements.Foo
+        var nullableLayoutImpl: SceneTransitionLayoutImpl? = null
+
+        rule.setContent {
+            SceneTransitionLayoutForTesting(
+                currentScene = currentScene,
+                onChangeScene = { currentScene = it },
+                transitions = remember { transitions {} },
+                state = remember { SceneTransitionLayoutState(currentScene) },
+                edgeDetector = DefaultEdgeDetector,
+                modifier = Modifier,
+                transitionInterceptionThreshold = 0f,
+                onLayoutImpl = { nullableLayoutImpl = it },
+            ) {
+                scene(TestScenes.SceneA) { /* Nothing */}
+                scene(TestScenes.SceneB) { Box(Modifier.element(key)) }
+                scene(TestScenes.SceneC) {
+                    when (sceneCState) {
+                        0 -> Row(Modifier.element(key)) {}
+                        1 -> Column(Modifier.element(key)) {}
+                        else -> {
+                            /* Nothing */
+                        }
+                    }
+                }
+                scene(TestScenes.SceneD) {
+                    // We should be able to extract the modifier before assigning it to different
+                    // nodes.
+                    val childModifier = Modifier.element(key)
+                    when (sceneDState) {
+                        0 -> Row(childModifier) {}
+                        1 -> Column(childModifier) {}
+                        else -> {
+                            /* Nothing */
+                        }
+                    }
+                }
+            }
+        }
+
+        assertThat(nullableLayoutImpl).isNotNull()
+        val layoutImpl = nullableLayoutImpl!!
+
+        // Scene A: no elements in the elements map.
+        rule.waitForIdle()
+        assertThat(layoutImpl.elements).isEmpty()
+
+        // Scene B: element is in the map.
+        currentScene = TestScenes.SceneB
+        rule.waitForIdle()
+
+        assertThat(layoutImpl.elements.keys).containsExactly(key)
+        val element = layoutImpl.elements.getValue(key)
+        assertThat(element.sceneValues.keys).containsExactly(TestScenes.SceneB)
+
+        // Scene C, state 0: the same element is reused.
+        currentScene = TestScenes.SceneC
+        sceneCState = 0
+        rule.waitForIdle()
+
+        assertThat(layoutImpl.elements.keys).containsExactly(key)
+        assertThat(layoutImpl.elements.getValue(key)).isSameInstanceAs(element)
+        assertThat(element.sceneValues.keys).containsExactly(TestScenes.SceneC)
+
+        // Scene C, state 1: the same element is reused.
+        sceneCState = 1
+        rule.waitForIdle()
+
+        assertThat(layoutImpl.elements.keys).containsExactly(key)
+        assertThat(layoutImpl.elements.getValue(key)).isSameInstanceAs(element)
+        assertThat(element.sceneValues.keys).containsExactly(TestScenes.SceneC)
+
+        // Scene D, state 0: the same element is reused.
+        currentScene = TestScenes.SceneD
+        sceneDState = 0
+        rule.waitForIdle()
+
+        assertThat(layoutImpl.elements.keys).containsExactly(key)
+        assertThat(layoutImpl.elements.getValue(key)).isSameInstanceAs(element)
+        assertThat(element.sceneValues.keys).containsExactly(TestScenes.SceneD)
+
+        // Scene D, state 1: the same element is reused.
+        sceneDState = 1
+        rule.waitForIdle()
+
+        assertThat(layoutImpl.elements.keys).containsExactly(key)
+        assertThat(layoutImpl.elements.getValue(key)).isSameInstanceAs(element)
+        assertThat(element.sceneValues.keys).containsExactly(TestScenes.SceneD)
+
+        // Scene D, state 2: the element is removed from the map.
+        sceneDState = 2
+        rule.waitForIdle()
+
+        assertThat(element.sceneValues).isEmpty()
+        assertThat(layoutImpl.elements).isEmpty()
+    }
+
+    @Test
+    fun throwsExceptionWhenElementIsComposedMultipleTimes() {
+        val key = TestElements.Foo
+
+        assertThrows(IllegalStateException::class.java) {
+            rule.setContent {
+                TestSceneScope {
+                    Column {
+                        Box(Modifier.element(key))
+                        Box(Modifier.element(key))
+                    }
+                }
+            }
+        }
+    }
+
+    @Test
+    fun throwsExceptionWhenElementIsComposedMultipleTimes_childModifier() {
+        val key = TestElements.Foo
+
+        assertThrows(IllegalStateException::class.java) {
+            rule.setContent {
+                TestSceneScope {
+                    Column {
+                        val childModifier = Modifier.element(key)
+                        Box(childModifier)
+                        Box(childModifier)
+                    }
+                }
+            }
+        }
+    }
+
+    @Test
+    fun throwsExceptionWhenElementIsComposedMultipleTimes_childModifier_laterDuplication() {
+        val key = TestElements.Foo
+
+        assertThrows(IllegalStateException::class.java) {
+            var nElements by mutableStateOf(1)
+            rule.setContent {
+                TestSceneScope {
+                    Column {
+                        val childModifier = Modifier.element(key)
+                        repeat(nElements) { Box(childModifier) }
+                    }
+                }
+            }
+
+            nElements = 2
+            rule.waitForIdle()
+        }
+    }
+
+    @Test
+    fun throwsExceptionWhenElementIsComposedMultipleTimes_updatedNode() {
+        assertThrows(IllegalStateException::class.java) {
+            var key by mutableStateOf(TestElements.Foo)
+            rule.setContent {
+                TestSceneScope {
+                    Column {
+                        Box(Modifier.element(key))
+                        Box(Modifier.element(TestElements.Bar))
+                    }
+                }
+            }
+
+            key = TestElements.Bar
+            rule.waitForIdle()
+        }
+    }
+
+    @Test
+    fun elementModifierSupportsUpdates() {
+        var key by mutableStateOf(TestElements.Foo)
+        var nullableLayoutImpl: SceneTransitionLayoutImpl? = null
+
+        rule.setContent {
+            SceneTransitionLayoutForTesting(
+                currentScene = TestScenes.SceneA,
+                onChangeScene = {},
+                transitions = remember { transitions {} },
+                state = remember { SceneTransitionLayoutState(TestScenes.SceneA) },
+                edgeDetector = DefaultEdgeDetector,
+                modifier = Modifier,
+                transitionInterceptionThreshold = 0f,
+                onLayoutImpl = { nullableLayoutImpl = it },
+            ) {
+                scene(TestScenes.SceneA) { Box(Modifier.element(key)) }
+            }
+        }
+
+        assertThat(nullableLayoutImpl).isNotNull()
+        val layoutImpl = nullableLayoutImpl!!
+
+        // There is only Foo in the elements map.
+        assertThat(layoutImpl.elements.keys).containsExactly(TestElements.Foo)
+        val fooElement = layoutImpl.elements.getValue(TestElements.Foo)
+        assertThat(fooElement.sceneValues.keys).containsExactly(TestScenes.SceneA)
+
+        key = TestElements.Bar
+
+        // There is only Bar in the elements map and foo scene values was cleaned up.
+        rule.waitForIdle()
+        assertThat(layoutImpl.elements.keys).containsExactly(TestElements.Bar)
+        val barElement = layoutImpl.elements.getValue(TestElements.Bar)
+        assertThat(barElement.sceneValues.keys).containsExactly(TestScenes.SceneA)
+        assertThat(fooElement.sceneValues).isEmpty()
+    }
 }
diff --git a/packages/SystemUI/compose/scene/tests/utils/src/com/android/compose/animation/scene/TestValues.kt b/packages/SystemUI/compose/scene/tests/utils/src/com/android/compose/animation/scene/TestValues.kt
index b4c393e..b83705a 100644
--- a/packages/SystemUI/compose/scene/tests/utils/src/com/android/compose/animation/scene/TestValues.kt
+++ b/packages/SystemUI/compose/scene/tests/utils/src/com/android/compose/animation/scene/TestValues.kt
@@ -26,6 +26,7 @@
     val SceneA = SceneKey("SceneA")
     val SceneB = SceneKey("SceneB")
     val SceneC = SceneKey("SceneC")
+    val SceneD = SceneKey("SceneD")
 }
 
 /** Element keys that can be reused by tests. */