NestedDraggable exposes the number of pointers down when starting

Bug: 378470603
Test: atest NestedDraggableTest
Flag: EXEMPT new API not used anywhere yet
Change-Id: Ie5a8a3d3bf3db0ebc14f2fec39a1fea691eeba57
diff --git a/packages/SystemUI/compose/core/src/com/android/compose/gesture/NestedDraggable.kt b/packages/SystemUI/compose/core/src/com/android/compose/gesture/NestedDraggable.kt
index 9fe85b7..029b9cd 100644
--- a/packages/SystemUI/compose/core/src/com/android/compose/gesture/NestedDraggable.kt
+++ b/packages/SystemUI/compose/core/src/com/android/compose/gesture/NestedDraggable.kt
@@ -38,6 +38,8 @@
 import androidx.compose.ui.input.pointer.PointerInputChange
 import androidx.compose.ui.input.pointer.PointerInputScope
 import androidx.compose.ui.input.pointer.SuspendingPointerInputModifierNode
+import androidx.compose.ui.input.pointer.changedToDownIgnoreConsumed
+import androidx.compose.ui.input.pointer.changedToUpIgnoreConsumed
 import androidx.compose.ui.input.pointer.positionChange
 import androidx.compose.ui.input.pointer.util.VelocityTracker
 import androidx.compose.ui.input.pointer.util.addPointerInputChange
@@ -50,6 +52,7 @@
 import androidx.compose.ui.unit.IntSize
 import androidx.compose.ui.unit.Velocity
 import androidx.compose.ui.util.fastAny
+import androidx.compose.ui.util.fastSumBy
 import com.android.compose.modifiers.thenIf
 import kotlin.math.sign
 import kotlinx.coroutines.CoroutineScope
@@ -65,9 +68,10 @@
 interface NestedDraggable {
     /**
      * Called when a drag is started in the given [position] (*before* dragging the touch slop) and
-     * in the direction given by [sign].
+     * in the direction given by [sign], with the given number of [pointersDown] when the touch slop
+     * was detected.
      */
-    fun onDragStarted(position: Offset, sign: Float): Controller
+    fun onDragStarted(position: Offset, sign: Float, pointersDown: Int): Controller
 
     /**
      * Whether this draggable should consume any scroll amount with the given [sign] coming from a
@@ -170,6 +174,9 @@
      */
     private var lastFirstDown: Offset? = null
 
+    /** The number of pointers down. */
+    private var pointersDownCount = 0
+
     init {
         delegate(nestedScrollModifierNode(this, nestedScrollDispatcher))
     }
@@ -234,6 +241,11 @@
 
         awaitEachGesture {
             val down = awaitFirstDown(requireUnconsumed = false)
+            check(down.position == lastFirstDown) {
+                "Position from detectDrags() is not the same as position in trackDownPosition()"
+            }
+            check(pointersDownCount == 1) { "pointersDownCount is equal to $pointersDownCount" }
+
             var overSlop = 0f
             val onTouchSlopReached = { change: PointerInputChange, over: Float ->
                 change.consume()
@@ -276,10 +288,13 @@
 
             if (drag != null) {
                 velocityTracker.resetTracking()
-
                 val sign = (drag.position - down.position).toFloat().sign
+                check(pointersDownCount > 0) { "pointersDownCount is equal to $pointersDownCount" }
                 val wrappedController =
-                    WrappedController(coroutineScope, draggable.onDragStarted(down.position, sign))
+                    WrappedController(
+                        coroutineScope,
+                        draggable.onDragStarted(down.position, sign, pointersDownCount),
+                    )
                 if (overSlop != 0f) {
                     onDrag(wrappedController, drag, overSlop, velocityTracker)
                 }
@@ -424,7 +439,22 @@
      */
 
     private suspend fun PointerInputScope.trackDownPosition() {
-        awaitEachGesture { lastFirstDown = awaitFirstDown(requireUnconsumed = false).position }
+        awaitEachGesture {
+            val down = awaitFirstDown(requireUnconsumed = false)
+            lastFirstDown = down.position
+            pointersDownCount = 1
+
+            do {
+                pointersDownCount +=
+                    awaitPointerEvent().changes.fastSumBy { change ->
+                        when {
+                            change.changedToDownIgnoreConsumed() -> 1
+                            change.changedToUpIgnoreConsumed() -> -1
+                            else -> 0
+                        }
+                    }
+            } while (pointersDownCount > 0)
+        }
     }
 
     override fun onPreScroll(available: Offset, source: NestedScrollSource): Offset {
@@ -451,8 +481,14 @@
         val sign = offset.sign
         if (nestedScrollController == null && draggable.shouldConsumeNestedScroll(sign)) {
             val startedPosition = checkNotNull(lastFirstDown) { "lastFirstDown is not set" }
+
+            // TODO(b/382665591): Replace this by check(pointersDownCount > 0).
+            val pointersDown = pointersDownCount.coerceAtLeast(1)
             nestedScrollController =
-                WrappedController(coroutineScope, draggable.onDragStarted(startedPosition, sign))
+                WrappedController(
+                    coroutineScope,
+                    draggable.onDragStarted(startedPosition, sign, pointersDown),
+                )
         }
 
         val controller = nestedScrollController ?: return Offset.Zero
diff --git a/packages/SystemUI/compose/core/tests/src/com/android/compose/gesture/NestedDraggableTest.kt b/packages/SystemUI/compose/core/tests/src/com/android/compose/gesture/NestedDraggableTest.kt
index fd3902f..735ab68 100644
--- a/packages/SystemUI/compose/core/tests/src/com/android/compose/gesture/NestedDraggableTest.kt
+++ b/packages/SystemUI/compose/core/tests/src/com/android/compose/gesture/NestedDraggableTest.kt
@@ -41,6 +41,7 @@
 import androidx.compose.ui.test.swipeLeft
 import androidx.compose.ui.unit.Velocity
 import com.google.common.truth.Truth.assertThat
+import kotlin.math.ceil
 import kotlinx.coroutines.awaitCancellation
 import org.junit.Ignore
 import org.junit.Rule
@@ -383,6 +384,79 @@
         assertThat(draggable.onDragStoppedCalled).isTrue()
     }
 
+    @Test
+    fun pointersDown() {
+        val draggable = TestDraggable()
+        val touchSlop =
+            rule.setContentWithTouchSlop {
+                Box(Modifier.fillMaxSize().nestedDraggable(draggable, orientation))
+            }
+
+        (1..5).forEach { nDown ->
+            rule.onRoot().performTouchInput {
+                repeat(nDown) { pointerId -> down(pointerId, center) }
+
+                moveBy(pointerId = 0, touchSlop.toOffset())
+            }
+
+            assertThat(draggable.onDragStartedPointersDown).isEqualTo(nDown)
+
+            rule.onRoot().performTouchInput {
+                repeat(nDown) { pointerId -> up(pointerId = pointerId) }
+            }
+        }
+    }
+
+    @Test
+    fun pointersDown_nestedScroll() {
+        val draggable = TestDraggable()
+        val touchSlop =
+            rule.setContentWithTouchSlop {
+                Box(
+                    Modifier.fillMaxSize()
+                        .nestedDraggable(draggable, orientation)
+                        .nestedScrollable(rememberScrollState())
+                )
+            }
+
+        (1..5).forEach { nDown ->
+            rule.onRoot().performTouchInput {
+                repeat(nDown) { pointerId -> down(pointerId, center) }
+
+                moveBy(pointerId = 0, (touchSlop + 1f).toOffset())
+            }
+
+            assertThat(draggable.onDragStartedPointersDown).isEqualTo(nDown)
+
+            rule.onRoot().performTouchInput {
+                repeat(nDown) { pointerId -> up(pointerId = pointerId) }
+            }
+        }
+    }
+
+    @Test
+    fun pointersDown_downThenUpThenDown() {
+        val draggable = TestDraggable()
+        val touchSlop =
+            rule.setContentWithTouchSlop {
+                Box(Modifier.fillMaxSize().nestedDraggable(draggable, orientation))
+            }
+
+        val slopThird = ceil(touchSlop / 3f).toOffset()
+        rule.onRoot().performTouchInput {
+            repeat(5) { down(pointerId = it, center) } // + 5
+            moveBy(pointerId = 0, slopThird)
+
+            listOf(2, 3).forEach { up(pointerId = it) } // - 2
+            moveBy(pointerId = 0, slopThird)
+
+            listOf(5, 6, 7).forEach { down(pointerId = it, center) } // + 3
+            moveBy(pointerId = 0, slopThird)
+        }
+
+        assertThat(draggable.onDragStartedPointersDown).isEqualTo(6)
+    }
+
     private fun ComposeContentTestRule.setContentWithTouchSlop(
         content: @Composable () -> Unit
     ): Float {
@@ -413,12 +487,18 @@
 
         var onDragStartedPosition = Offset.Zero
         var onDragStartedSign = 0f
+        var onDragStartedPointersDown = 0
         var onDragDelta = 0f
 
-        override fun onDragStarted(position: Offset, sign: Float): NestedDraggable.Controller {
+        override fun onDragStarted(
+            position: Offset,
+            sign: Float,
+            pointersDown: Int,
+        ): NestedDraggable.Controller {
             onDragStartedCalled = true
             onDragStartedPosition = position
             onDragStartedSign = sign
+            onDragStartedPointersDown = pointersDown
             onDragDelta = 0f
 
             onDragStarted.invoke(position, sign)