Fix ElementNode update/recycling

This CL fixes a subtle bug in the element update that would happen when
an ElementNode is updated with an Element that is different than the
previous one but with the same key, which can happen when the node is
recycled inside a lazy layout.

Bug: 308961608
Test: ElementTest
Flag: N/A
Change-Id: Iba8601627b23ebbd33cadb8ac2f1c619c403a1f8
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
index fb8083b..31604a6 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/Element.kt
@@ -213,22 +213,11 @@
     override fun onDetach() {
         super.onDetach()
         removeNodeFromSceneValues()
+        maybePruneMaps(layoutImpl, element, sceneValues)
     }
 
     private fun removeNodeFromSceneValues() {
         sceneValues.nodes.remove(this)
-
-        // If element is not composed from this scene anymore, remove the scene values. This works
-        // because [onAttach] is called before [onDetach], so if an element is moved from the UI
-        // tree we will first add the new code location then remove the old one.
-        if (sceneValues.nodes.isEmpty()) {
-            element.sceneValues.remove(sceneValues.scene)
-        }
-
-        // If the element is not composed in any scene, remove it from the elements map.
-        if (element.sceneValues.isEmpty()) {
-            layoutImpl.elements.remove(element.key)
-        }
     }
 
     fun update(
@@ -237,12 +226,16 @@
         element: Element,
         sceneValues: Element.TargetValues,
     ) {
+        check(layoutImpl == this.layoutImpl && scene == this.scene)
         removeNodeFromSceneValues()
-        this.layoutImpl = layoutImpl
-        this.scene = scene
+
+        val prevElement = this.element
+        val prevSceneValues = this.sceneValues
         this.element = element
         this.sceneValues = sceneValues
+
         addNodeToSceneValues()
+        maybePruneMaps(layoutImpl, prevElement, prevSceneValues)
     }
 
     override fun ContentDrawScope.draw() {
@@ -261,6 +254,28 @@
             }
         }
     }
+
+    companion object {
+        private fun maybePruneMaps(
+            layoutImpl: SceneTransitionLayoutImpl,
+            element: Element,
+            sceneValues: Element.TargetValues,
+        ) {
+            // If element is not composed from this scene anymore, remove the scene values. This
+            // works because [onAttach] is called before [onDetach], so if an element is moved from
+            // the UI tree we will first add the new code location then remove the old one.
+            if (
+                sceneValues.nodes.isEmpty() && element.sceneValues[sceneValues.scene] == sceneValues
+            ) {
+                element.sceneValues.remove(sceneValues.scene)
+
+                // If the element is not composed in any scene, remove it from the elements map.
+                if (element.sceneValues.isEmpty() && layoutImpl.elements[element.key] == element) {
+                    layoutImpl.elements.remove(element.key)
+                }
+            }
+        }
+    }
 }
 
 private fun shouldDrawElement(
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
index ce3e1db..439dc00 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
@@ -17,16 +17,21 @@
 package com.android.compose.animation.scene
 
 import androidx.compose.animation.core.tween
+import androidx.compose.foundation.ExperimentalFoundationApi
 import androidx.compose.foundation.layout.Box
 import androidx.compose.foundation.layout.Column
 import androidx.compose.foundation.layout.Row
+import androidx.compose.foundation.layout.fillMaxSize
 import androidx.compose.foundation.layout.offset
 import androidx.compose.foundation.layout.size
+import androidx.compose.foundation.pager.HorizontalPager
+import androidx.compose.foundation.pager.PagerState
 import androidx.compose.runtime.Composable
 import androidx.compose.runtime.SideEffect
 import androidx.compose.runtime.getValue
 import androidx.compose.runtime.mutableStateOf
 import androidx.compose.runtime.remember
+import androidx.compose.runtime.rememberCoroutineScope
 import androidx.compose.runtime.setValue
 import androidx.compose.ui.ExperimentalComposeUiApi
 import androidx.compose.ui.Modifier
@@ -36,6 +41,9 @@
 import androidx.compose.ui.unit.dp
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import com.google.common.truth.Truth.assertThat
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.test.runTest
 import org.junit.Assert.assertThrows
 import org.junit.Rule
 import org.junit.Test
@@ -430,6 +438,97 @@
     }
 
     @Test
+    @OptIn(ExperimentalFoundationApi::class)
+    fun elementModifierNodeIsRecycledInLazyLayouts() = runTest {
+        val nPages = 2
+        val pagerState = PagerState(currentPage = 0) { nPages }
+        var nullableLayoutImpl: SceneTransitionLayoutImpl? = null
+
+        // This is how we scroll a pager inside a test, as explained in b/315457147#comment2.
+        lateinit var scrollScope: CoroutineScope
+        fun scrollToPage(page: Int) {
+            var animationFinished by mutableStateOf(false)
+            rule.runOnIdle {
+                scrollScope.launch {
+                    pagerState.scrollToPage(page)
+                    animationFinished = true
+                }
+            }
+            rule.waitUntil(timeoutMillis = 10_000) { animationFinished }
+        }
+
+        rule.setContent {
+            scrollScope = rememberCoroutineScope()
+
+            SceneTransitionLayoutForTesting(
+                currentScene = TestScenes.SceneA,
+                onChangeScene = {},
+                transitions = remember { transitions {} },
+                state = remember { SceneTransitionLayoutState(TestScenes.SceneA) },
+                edgeDetector = DefaultEdgeDetector,
+                modifier = Modifier,
+                transitionInterceptionThreshold = 0f,
+                onLayoutImpl = { nullableLayoutImpl = it },
+            ) {
+                scene(TestScenes.SceneA) {
+                    // The pages are full-size and beyondBoundsPageCount is 0, so at rest only one
+                    // page should be composed.
+                    HorizontalPager(
+                        pagerState,
+                        beyondBoundsPageCount = 0,
+                    ) { page ->
+                        when (page) {
+                            0 -> Box(Modifier.element(TestElements.Foo).fillMaxSize())
+                            1 -> Box(Modifier.fillMaxSize())
+                            else -> error("page $page < nPages $nPages")
+                        }
+                    }
+                }
+            }
+        }
+
+        assertThat(nullableLayoutImpl).isNotNull()
+        val layoutImpl = nullableLayoutImpl!!
+
+        // There is only Foo in the elements map.
+        assertThat(layoutImpl.elements.keys).containsExactly(TestElements.Foo)
+        val element = layoutImpl.elements.getValue(TestElements.Foo)
+        val sceneValues = element.sceneValues
+        assertThat(sceneValues.keys).containsExactly(TestScenes.SceneA)
+
+        // Get the ElementModifier node that should be reused later on when coming back to this
+        // page.
+        val nodes = sceneValues.getValue(TestScenes.SceneA).nodes
+        assertThat(nodes).hasSize(1)
+        val node = nodes.single()
+
+        // Go to the second page.
+        scrollToPage(1)
+        rule.waitForIdle()
+
+        assertThat(nodes).isEmpty()
+        assertThat(sceneValues).isEmpty()
+        assertThat(layoutImpl.elements).isEmpty()
+
+        // Go back to the first page.
+        scrollToPage(0)
+        rule.waitForIdle()
+
+        assertThat(layoutImpl.elements.keys).containsExactly(TestElements.Foo)
+        val newElement = layoutImpl.elements.getValue(TestElements.Foo)
+        val newSceneValues = newElement.sceneValues
+        assertThat(newElement).isNotEqualTo(element)
+        assertThat(newSceneValues).isNotEqualTo(sceneValues)
+        assertThat(newSceneValues.keys).containsExactly(TestScenes.SceneA)
+
+        // The ElementModifier node should be the same as before.
+        val newNodes = newSceneValues.getValue(TestScenes.SceneA).nodes
+        assertThat(newNodes).hasSize(1)
+        val newNode = newNodes.single()
+        assertThat(newNode).isSameInstanceAs(node)
+    }
+
+    @Test
     fun existingElementsDontRecomposeWhenTransitionStateChanges() {
         var fooCompositions = 0