Add multi-device checks to AccessibilityInputFilterInputTest

Prior to this CL, AccessibilityInputFilter assumed that only a single
input device could be active at a time. However, with the recently
added multi-device input feature, that assumption no longer holds.

As a result, InputFilter now receives multi-device streams from
InputDispatcher. This is problematic, because InputFilter gets confused
and starts processing the events from all devices at the same time.
Components like TouchExplorer don't look at the device id, and quickly
get into an inconsistent state.

This causes some events to be dropped by InputFilter, and some to be
sent back to InputDispatcher. In some situations, this leads to an
inconsistent input stream being injected back into the dispatcher,
which could lead to a crash.

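In this CL, we add a multiplexer to AccessibilityInputFilter so that
only one input device can be active at a time. When a new device
starts a gesture (ACTION_DOWN, ACTION_HOVER_ENTER or ACTION_HOVER_MOVE),
the gesture of the previously active device is canceled and the new
device becomes active. All other MotionEvents from inactive devices are
ignored. This behaviour is guarded by the handle_multi_device_input
flag. EventDispatcher is also updated so that the events it generates
keep the device id and source of the original event.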

This CL also adds tests for the multi-device behaviour of
AccessibilityInputFilter, both with none of the a11y features enabled
and with all of them enabled.

Bug: 310014874
Test: atest FrameworksServicesTests:AccessibilityInputFilterInputTest
Change-Id: Iac18fd671787d880aca1f6fbfd17d822e5ae8ee6
diff --git a/services/accessibility/accessibility.aconfig b/services/accessibility/accessibility.aconfig
index 015c35e..7c48550 100644
--- a/services/accessibility/accessibility.aconfig
+++ b/services/accessibility/accessibility.aconfig
@@ -66,6 +66,16 @@
 }
 
 flag {
+    name: "handle_multi_device_input"
+    namespace: "accessibility"
+    description: "Select a single active device when a multi-device stream is received by AccessibilityInputFilter"
+    bug: "310014874"
+    metadata {
+        purpose: PURPOSE_BUGFIX
+    }
+}
+
+flag {
     name: "pinch_zoom_zero_min_span"
     namespace: "accessibility"
     description: "Whether to set min span of ScaleGestureDetector to zero."
diff --git a/services/accessibility/java/com/android/server/accessibility/AccessibilityInputFilter.java b/services/accessibility/java/com/android/server/accessibility/AccessibilityInputFilter.java
index abcd8e2..16119d11 100644
--- a/services/accessibility/java/com/android/server/accessibility/AccessibilityInputFilter.java
+++ b/services/accessibility/java/com/android/server/accessibility/AccessibilityInputFilter.java
@@ -25,6 +25,7 @@
 import android.content.Context;
 import android.graphics.Region;
 import android.os.PowerManager;
+import android.os.SystemClock;
 import android.provider.Settings;
 import android.util.Slog;
 import android.util.SparseArray;
@@ -35,6 +36,8 @@
 import android.view.InputFilter;
 import android.view.KeyEvent;
 import android.view.MotionEvent;
+import android.view.MotionEvent.PointerCoords;
+import android.view.MotionEvent.PointerProperties;
 import android.view.accessibility.AccessibilityEvent;
 
 import com.android.server.LocalServices;
@@ -203,6 +206,62 @@
 
     private EventStreamState mKeyboardStreamState;
 
+    /**
+     * The last MotionEvent emitted from the input device that's currently active. This is used to
+     * keep track of which input device is currently active, and also to generate the cancel event
+     * if a new device becomes active.
+     */
+    private MotionEvent mLastActiveDeviceMotionEvent = null;
+
+    /**
+     * Creates a synthetic event that terminates the gesture that {@code event} belongs to.
+     * Hover events are terminated with ACTION_HOVER_EXIT, all other events with ACTION_CANCEL.
+     * If {@code event} is ACTION_POINTER_UP, the lifted pointer is left out of the result.
+     * All pointer axes and properties are copied (not just x/y), so the terminating event
+     * reflects the last known state of every remaining pointer.
+     *
+     * @throws IllegalArgumentException if {@code event} already ends its gesture
+     */
+    private static MotionEvent cancelMotion(MotionEvent event) {
+        final int actionMasked = event.getActionMasked();
+        if (actionMasked == MotionEvent.ACTION_CANCEL
+                || actionMasked == MotionEvent.ACTION_HOVER_EXIT
+                || actionMasked == MotionEvent.ACTION_UP) {
+            throw new IllegalArgumentException("Can't cancel " + event);
+        }
+        final int action = (actionMasked == MotionEvent.ACTION_HOVER_ENTER
+                || actionMasked == MotionEvent.ACTION_HOVER_MOVE)
+                ? MotionEvent.ACTION_HOVER_EXIT
+                : MotionEvent.ACTION_CANCEL;
+        // On ACTION_POINTER_UP, the pointer at the action index is going away: drop it.
+        final int skippedIndex =
+                actionMasked == MotionEvent.ACTION_POINTER_UP ? event.getActionIndex() : -1;
+        final int pointerCount =
+                skippedIndex >= 0 ? event.getPointerCount() - 1 : event.getPointerCount();
+        final PointerProperties[] properties = new PointerProperties[pointerCount];
+        final PointerCoords[] coords = new PointerCoords[pointerCount];
+        int newPointerIndex = 0;
+        for (int i = 0; i < event.getPointerCount(); i++) {
+            if (i == skippedIndex) {
+                continue;
+            }
+            final PointerCoords c = new PointerCoords();
+            event.getPointerCoords(i, c);
+            coords[newPointerIndex] = c;
+            final PointerProperties p = new PointerProperties();
+            event.getPointerProperties(i, p);
+            properties[newPointerIndex] = p;
+            newPointerIndex++;
+        }
+
+        return MotionEvent.obtain(event.getDownTime(), SystemClock.uptimeMillis(), action,
+                pointerCount, properties, coords,
+                event.getMetaState(), event.getButtonState(),
+                event.getXPrecision(), event.getYPrecision(), event.getDeviceId(),
+                event.getEdgeFlags(), event.getSource(), event.getDisplayId(), event.getFlags(),
+                event.getClassification());
+    }
+
     AccessibilityInputFilter(Context context, AccessibilityManagerService service) {
         this(context, service, new SparseArray<>(0));
     }
@@ -260,6 +319,17 @@
                     AccessibilityTrace.FLAGS_INPUT_FILTER,
                     "event=" + event + ";policyFlags=" + policyFlags);
         }
+        if (Flags.handleMultiDeviceInput()) {
+            if (!shouldProcessMultiDeviceEvent(event, policyFlags)) {
+                // We are only allowing a single device to be active at a time.
+                return;
+            }
+        }
+
+        onInputEventInternal(event, policyFlags);
+    }
+
+    private void onInputEventInternal(InputEvent event, int policyFlags) {
         if (mEventHandler.size() == 0) {
             if (DEBUG) Slog.d(TAG, "No mEventHandler for event " + event);
             super.onInputEvent(event, policyFlags);
@@ -353,6 +423,63 @@
         }
     }
 
+    /**
+     * Multiplexes motion events so that only a single input device is active at a time.
+     *
+     * A gesture-starting event (DOWN, HOVER_ENTER, HOVER_MOVE) makes its device the active
+     * one, first dispatching a synthetic cancel for the previously active device's stream.
+     * MOVE, POINTER_DOWN/UP, UP, CANCEL and HOVER_EXIT pass only if they come from the active
+     * device; any other action is dropped only if a different device is active.
+     *
+     * @return false if {@code event} should be dropped; non-motion events always return true
+     */
+    boolean shouldProcessMultiDeviceEvent(InputEvent event, int policyFlags) {
+        if (!(event instanceof MotionEvent motion)) {
+            return true;
+        }
+        final MotionEvent last = mLastActiveDeviceMotionEvent;
+        final boolean hasActiveDevice = last != null;
+        final boolean fromActiveDevice =
+                hasActiveDevice && last.getDeviceId() == motion.getDeviceId();
+        switch (motion.getActionMasked()) {
+            case MotionEvent.ACTION_DOWN:
+            case MotionEvent.ACTION_HOVER_ENTER:
+            case MotionEvent.ACTION_HOVER_MOVE:
+                if (hasActiveDevice && !fromActiveDevice) {
+                    // A new device started a gesture: terminate the old device's stream first.
+                    onInputEventInternal(cancelMotion(last), policyFlags);
+                }
+                setLastActiveDeviceMotionEvent(MotionEvent.obtain(motion));
+                return true;
+            case MotionEvent.ACTION_MOVE:
+            case MotionEvent.ACTION_POINTER_DOWN:
+            case MotionEvent.ACTION_POINTER_UP:
+                if (fromActiveDevice) {
+                    setLastActiveDeviceMotionEvent(MotionEvent.obtain(motion));
+                }
+                return fromActiveDevice;
+            case MotionEvent.ACTION_UP:
+            case MotionEvent.ACTION_CANCEL:
+            case MotionEvent.ACTION_HOVER_EXIT:
+                if (fromActiveDevice) {
+                    // The active gesture is over; any device may start the next one.
+                    setLastActiveDeviceMotionEvent(null);
+                }
+                return fromActiveDevice;
+            default:
+                // Drop only if the event comes from a device other than the active one.
+                return !hasActiveDevice || fromActiveDevice;
+        }
+    }
+
+    /** Replaces the stored copy of the active device's last event, recycling the old copy. */
+    private void setLastActiveDeviceMotionEvent(MotionEvent copy) {
+        if (mLastActiveDeviceMotionEvent != null) {
+            mLastActiveDeviceMotionEvent.recycle();
+        }
+        mLastActiveDeviceMotionEvent = copy;
+    }
+
     private void processMotionEvent(EventStreamState state, MotionEvent event, int policyFlags) {
         if (!state.shouldProcessScroll() && event.getActionMasked() == MotionEvent.ACTION_SCROLL) {
             super.onInputEvent(event, policyFlags);
diff --git a/services/accessibility/java/com/android/server/accessibility/gestures/EventDispatcher.java b/services/accessibility/java/com/android/server/accessibility/gestures/EventDispatcher.java
index b6223c7..bf9202f1b 100644
--- a/services/accessibility/java/com/android/server/accessibility/gestures/EventDispatcher.java
+++ b/services/accessibility/java/com/android/server/accessibility/gestures/EventDispatcher.java
@@ -106,11 +106,30 @@
                 return;
             }
         }
+        final long downTime;
         if (action == MotionEvent.ACTION_DOWN) {
-            event.setDownTime(event.getEventTime());
+            downTime = event.getEventTime();
         } else {
-            event.setDownTime(mState.getLastInjectedDownEventTime());
+            downTime = mState.getLastInjectedDownEventTime();
         }
+
+        // The only way to change device id of the motion event is by re-creating the whole thing
+        final PointerProperties[] properties = new PointerProperties[event.getPointerCount()];
+        final PointerCoords[] coords = new PointerCoords[event.getPointerCount()];
+        for (int i = 0; i < event.getPointerCount(); i++) {
+            final PointerCoords c = new PointerCoords();
+            event.getPointerCoords(i, c);
+            coords[i] = c;
+            final PointerProperties p = new PointerProperties();
+            event.getPointerProperties(i, p);
+            properties[i] = p;
+        }
+        event = MotionEvent.obtain(downTime, event.getEventTime(), event.getAction(),
+                event.getPointerCount(), properties, coords,
+                event.getMetaState(), event.getButtonState(),
+                event.getXPrecision(), event.getYPrecision(), rawEvent.getDeviceId(),
+                event.getEdgeFlags(), rawEvent.getSource(), event.getDisplayId(), event.getFlags(),
+                event.getClassification());
         // If the user is long pressing but the long pressing pointer
         // was not exactly over the accessibility focused item we need
         // to remap the location of that pointer so the user does not
diff --git a/services/tests/servicestests/src/com/android/server/accessibility/AccessibilityInputFilterInputTest.kt b/services/tests/servicestests/src/com/android/server/accessibility/AccessibilityInputFilterInputTest.kt
index 52c7d8d..5c8c6bb 100644
--- a/services/tests/servicestests/src/com/android/server/accessibility/AccessibilityInputFilterInputTest.kt
+++ b/services/tests/servicestests/src/com/android/server/accessibility/AccessibilityInputFilterInputTest.kt
@@ -17,31 +17,38 @@
 
 import android.hardware.display.DisplayManagerGlobal
 import android.os.SystemClock
+import android.platform.test.annotations.RequiresFlagsEnabled
+import android.platform.test.flag.junit.CheckFlagsRule
+import android.platform.test.flag.junit.DeviceFlagsValueProvider
 import android.view.Display
 import android.view.Display.DEFAULT_DISPLAY
 import android.view.DisplayAdjustments
 import android.view.DisplayInfo
 import android.view.IInputFilterHost
+import android.view.InputDevice.SOURCE_STYLUS
 import android.view.InputDevice.SOURCE_TOUCHSCREEN
 import android.view.InputEvent
 import android.view.MotionEvent
+import android.view.MotionEvent.ACTION_CANCEL
 import android.view.MotionEvent.ACTION_DOWN
-import android.view.MotionEvent.ACTION_MOVE
-import android.view.MotionEvent.ACTION_UP
 import android.view.MotionEvent.ACTION_HOVER_ENTER
 import android.view.MotionEvent.ACTION_HOVER_EXIT
 import android.view.MotionEvent.ACTION_HOVER_MOVE
+import android.view.MotionEvent.ACTION_MOVE
+import android.view.MotionEvent.ACTION_UP
 import android.view.WindowManagerPolicyConstants.FLAG_PASS_TO_USER
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import androidx.test.platform.app.InstrumentationRegistry
 import com.android.cts.input.inputeventmatchers.withDeviceId
 import com.android.cts.input.inputeventmatchers.withMotionAction
+import com.android.cts.input.inputeventmatchers.withSource
 import com.android.server.LocalServices
 import com.android.server.accessibility.magnification.MagnificationProcessor
 import com.android.server.wm.WindowManagerInternal
 import java.util.concurrent.LinkedBlockingQueue
 import org.hamcrest.Matchers.allOf
 import org.junit.After
+import org.junit.Assert.assertEquals
 import org.junit.Before
 import org.junit.Rule
 import org.junit.Test
@@ -92,12 +99,17 @@
                 or AccessibilityInputFilter.FLAG_FEATURE_TRIGGERED_SCREEN_MAGNIFIER
                 or AccessibilityInputFilter.FLAG_FEATURE_INJECT_MOTION_EVENTS
                 or AccessibilityInputFilter.FLAG_FEATURE_FILTER_KEY_EVENTS)
+        const val STYLUS_SOURCE = SOURCE_STYLUS or SOURCE_TOUCHSCREEN
     }
 
     @Rule
     @JvmField
     val mocks: MockitoRule = MockitoJUnit.rule()
 
+    @Rule
+    @JvmField
+    val mCheckFlagsRule: CheckFlagsRule = DeviceFlagsValueProvider.createCheckFlagsRule()
+
     @Mock
     private lateinit var mockA11yController: WindowManagerInternal.AccessibilityControllerInternal
 
@@ -115,6 +127,9 @@
     private lateinit var ams: AccessibilityManagerService
     private lateinit var a11yInputFilter: AccessibilityInputFilter
     private val touchDeviceId = 1
+    private val fromTouchScreen = allOf(withDeviceId(touchDeviceId), withSource(SOURCE_TOUCHSCREEN))
+    private val stylusDeviceId = 2
+    private val fromStylus = allOf(withDeviceId(stylusDeviceId), withSource(STYLUS_SOURCE))
 
     @Before
     fun setUp() {
@@ -156,23 +171,14 @@
         enableFeatures(0)
 
         val downTime = SystemClock.uptimeMillis()
-        val downEvent = createMotionEvent(
-            ACTION_DOWN, downTime, downTime, SOURCE_TOUCHSCREEN, touchDeviceId)
-        send(downEvent)
-        verifier.assertReceivedMotion(
-            allOf(withMotionAction(ACTION_DOWN), withDeviceId(touchDeviceId)))
+        sendTouchEvent(ACTION_DOWN, downTime, downTime)
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_DOWN)))
 
-        val moveEvent = createMotionEvent(
-            ACTION_MOVE, downTime, SystemClock.uptimeMillis(), SOURCE_TOUCHSCREEN, touchDeviceId)
-        send(moveEvent)
-        verifier.assertReceivedMotion(
-            allOf(withMotionAction(ACTION_MOVE), withDeviceId(touchDeviceId)))
+        sendTouchEvent(ACTION_MOVE, downTime, SystemClock.uptimeMillis())
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_MOVE)))
 
-        val upEvent = createMotionEvent(
-            ACTION_UP, downTime, SystemClock.uptimeMillis(), SOURCE_TOUCHSCREEN, touchDeviceId)
-        send(upEvent)
-        verifier.assertReceivedMotion(
-            allOf(withMotionAction(ACTION_UP), withDeviceId(touchDeviceId)))
+        sendTouchEvent(ACTION_UP, downTime, SystemClock.uptimeMillis())
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_UP)))
 
         verifier.assertNoEvents()
     }
@@ -186,28 +192,91 @@
         enableFeatures(ALL_A11Y_FEATURES)
 
         val downTime = SystemClock.uptimeMillis()
-        val downEvent = createMotionEvent(
-            ACTION_DOWN, downTime, downTime, SOURCE_TOUCHSCREEN, touchDeviceId)
-        send(MotionEvent.obtain(downEvent))
-
+        sendTouchEvent(ACTION_DOWN, downTime, downTime)
         // DOWN event gets transformed to HOVER_ENTER
-        verifier.assertReceivedMotion(
-            allOf(withMotionAction(ACTION_HOVER_ENTER), withDeviceId(touchDeviceId)))
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_HOVER_ENTER)))
 
         // MOVE becomes HOVER_MOVE
-        val moveEvent = createMotionEvent(
-            ACTION_MOVE, downTime, SystemClock.uptimeMillis(), SOURCE_TOUCHSCREEN, touchDeviceId)
-        send(moveEvent)
-        verifier.assertReceivedMotion(
-            allOf(withMotionAction(ACTION_HOVER_MOVE), withDeviceId(touchDeviceId)))
+        sendTouchEvent(ACTION_MOVE, downTime, SystemClock.uptimeMillis())
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_HOVER_MOVE)))
 
         // UP becomes HOVER_EXIT
-        val upEvent = createMotionEvent(
-            ACTION_UP, downTime, SystemClock.uptimeMillis(), SOURCE_TOUCHSCREEN, touchDeviceId)
-        send(upEvent)
+        sendTouchEvent(ACTION_UP, downTime, SystemClock.uptimeMillis())
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_HOVER_EXIT)))
 
-        verifier.assertReceivedMotion(
-            allOf(withMotionAction(ACTION_HOVER_EXIT), withDeviceId(touchDeviceId)))
+        verifier.assertNoEvents()
+    }
+
+    /**
+     * Enable all a11y features and send a touchscreen stream of DOWN -> CANCEL -> DOWN events.
+     * These get converted into HOVER_ENTER -> HOVER_EXIT -> HOVER_ENTER events by the input filter.
+     */
+    @Test
+    fun testTouchDownCancelDownWithAllA11yFeatures() {
+        enableFeatures(ALL_A11Y_FEATURES)
+
+        val downTime = SystemClock.uptimeMillis()
+        sendTouchEvent(ACTION_DOWN, downTime, downTime)
+        // DOWN event gets transformed to HOVER_ENTER
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_HOVER_ENTER)))
+
+        // CANCEL becomes HOVER_EXIT
+        sendTouchEvent(ACTION_CANCEL, downTime, SystemClock.uptimeMillis())
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_HOVER_EXIT)))
+
+        // DOWN again! New hover is expected
+        val newDownTime = SystemClock.uptimeMillis()
+        sendTouchEvent(ACTION_DOWN, newDownTime, newDownTime)
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_HOVER_ENTER)))
+
+        verifier.assertNoEvents()
+    }
+
+    /**
+     * Enable all a11y features and send a stylus stream of DOWN -> CANCEL -> DOWN events.
+     * These get converted into HOVER_ENTER -> HOVER_EXIT -> HOVER_ENTER events by the input filter.
+     * Same as testTouchDownCancelDownWithAllA11yFeatures, but for stylus events.
+     */
+    @Test
+    fun testStylusDownCancelDownWithAllA11yFeatures() {
+        enableFeatures(ALL_A11Y_FEATURES)
+
+        val downTime = SystemClock.uptimeMillis()
+        sendStylusEvent(ACTION_DOWN, downTime, downTime)
+        // DOWN event gets transformed to HOVER_ENTER
+        verifier.assertReceivedMotion(allOf(fromStylus, withMotionAction(ACTION_HOVER_ENTER)))
+
+        // CANCEL becomes HOVER_EXIT
+        sendStylusEvent(ACTION_CANCEL, downTime, SystemClock.uptimeMillis())
+        verifier.assertReceivedMotion(allOf(fromStylus, withMotionAction(ACTION_HOVER_EXIT)))
+
+        // DOWN again! New hover is expected
+        val newDownTime = SystemClock.uptimeMillis()
+        sendStylusEvent(ACTION_DOWN, newDownTime, newDownTime)
+        verifier.assertReceivedMotion(allOf(fromStylus, withMotionAction(ACTION_HOVER_ENTER)))
+
+        verifier.assertNoEvents()
+    }
+
+    /**
+     * Enable all a11y features, send a stylus DOWN -> CANCEL stream, then start a touch gesture.
+     */
+    @Test
+    fun testStylusThenTouch() {
+        enableFeatures(ALL_A11Y_FEATURES)
+
+        val downTime = SystemClock.uptimeMillis()
+        sendStylusEvent(ACTION_DOWN, downTime, downTime)
+        // DOWN event gets transformed to HOVER_ENTER
+        verifier.assertReceivedMotion(allOf(fromStylus, withMotionAction(ACTION_HOVER_ENTER)))
+
+        // CANCEL becomes HOVER_EXIT
+        sendStylusEvent(ACTION_CANCEL, downTime, SystemClock.uptimeMillis())
+        verifier.assertReceivedMotion(allOf(fromStylus, withMotionAction(ACTION_HOVER_EXIT)))
+        // Stylus gesture is over, so the touch DOWN starts a new hover
+        val newDownTime = SystemClock.uptimeMillis()
+        sendTouchEvent(ACTION_DOWN, newDownTime, newDownTime)
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_HOVER_ENTER)))
 
         verifier.assertNoEvents()
     }
@@ -223,26 +292,18 @@
         enableFeatures(ALL_A11Y_FEATURES)
 
         val downTime = SystemClock.uptimeMillis()
-        val downEvent = createMotionEvent(
-            ACTION_DOWN, downTime, downTime, SOURCE_TOUCHSCREEN, touchDeviceId)
-        send(MotionEvent.obtain(downEvent))
+        sendTouchEvent(ACTION_DOWN, downTime, downTime)
 
         // DOWN event gets transformed to HOVER_ENTER
-        verifier.assertReceivedMotion(
-            allOf(withMotionAction(ACTION_HOVER_ENTER), withDeviceId(touchDeviceId)))
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_HOVER_ENTER)))
         verifier.assertNoEvents()
 
         enableFeatures(0)
-        verifier.assertReceivedMotion(
-            allOf(withMotionAction(ACTION_HOVER_EXIT), withDeviceId(touchDeviceId)))
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_HOVER_EXIT)))
         verifier.assertNoEvents()
 
-        val moveEvent = createMotionEvent(
-            ACTION_MOVE, downTime, SystemClock.uptimeMillis(), SOURCE_TOUCHSCREEN, touchDeviceId)
-        send(moveEvent)
-        val upEvent = createMotionEvent(
-            ACTION_UP, downTime, SystemClock.uptimeMillis(), SOURCE_TOUCHSCREEN, touchDeviceId)
-        send(upEvent)
+        sendTouchEvent(ACTION_MOVE, downTime, SystemClock.uptimeMillis())
+        sendTouchEvent(ACTION_UP, downTime, SystemClock.uptimeMillis())
         // As the original gesture continues, no additional events should be getting sent by the
         // filter because the HOVER_EXIT above already effectively finished the current gesture and
         // the DOWN event was never sent to the host.
@@ -250,10 +311,148 @@
         // Bug: the down event was swallowed, so the remainder of the gesture should be swallowed
         // too. However, the MOVE and UP events are currently passed back to the dispatcher.
         // TODO(b/310014874) - ensure a11y sends consistent input streams to the dispatcher
-        verifier.assertReceivedMotion(
-            allOf(withMotionAction(ACTION_MOVE), withDeviceId(touchDeviceId)))
-        verifier.assertReceivedMotion(
-            allOf(withMotionAction(ACTION_UP), withDeviceId(touchDeviceId)))
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_MOVE)))
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_UP)))
+
+        verifier.assertNoEvents()
+    }
+
+    /**
+     * Check multi-device behaviour when all a11y features are disabled. The events should pass
+     * through unmodified, but only from the active device, i.e. the device that most recently
+     * started a gesture. A new device starting a gesture cancels the previous device's gesture,
+     * and subsequent events from the now-inactive device are dropped. In this test, a
+     * touchscreen event stream and a stylus event stream are injected, interleaved.
+     */
+    @Test
+    @RequiresFlagsEnabled(Flags.FLAG_HANDLE_MULTI_DEVICE_INPUT)
+    fun testMultiDeviceEventsWithoutA11yFeatures() {
+        enableFeatures(0)
+
+        val touchDownTime = SystemClock.uptimeMillis()
+
+        // Touch device - ACTION_DOWN
+        sendTouchEvent(ACTION_DOWN, touchDownTime, touchDownTime)
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_DOWN)))
+
+        // Stylus device - ACTION_DOWN
+        val stylusDownTime = SystemClock.uptimeMillis()
+        sendStylusEvent(ACTION_DOWN, stylusDownTime, stylusDownTime)
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_CANCEL)))
+        verifier.assertReceivedMotion(allOf(fromStylus, withMotionAction(ACTION_DOWN)))
+
+        // Touch device - ACTION_MOVE
+        sendTouchEvent(ACTION_MOVE, touchDownTime, SystemClock.uptimeMillis())
+        // Touch event is dropped
+        verifier.assertNoEvents()
+
+        // Stylus device - ACTION_MOVE
+        sendStylusEvent(ACTION_MOVE, stylusDownTime, SystemClock.uptimeMillis())
+        verifier.assertReceivedMotion(allOf(fromStylus, withMotionAction(ACTION_MOVE)))
+
+        // Touch device - ACTION_UP
+        sendTouchEvent(ACTION_UP, touchDownTime, SystemClock.uptimeMillis())
+        // Touch event is dropped
+        verifier.assertNoEvents()
+
+        // Stylus device - ACTION_UP
+        sendStylusEvent(ACTION_UP, stylusDownTime, SystemClock.uptimeMillis())
+        verifier.assertReceivedMotion(allOf(fromStylus, withMotionAction(ACTION_UP)))
+
+        verifier.assertNoEvents()
+    }
+
+    /**
+     * Check multi-device behaviour when all a11y features are enabled. The events should be
+     * modified accordingly, like DOWN events getting converted to hovers.
+     * Only a single device should be active (the latest device to start a new gesture).
+     * In this test, we are injecting a touchscreen event stream and a stylus event stream,
+     * interleaved.
+     */
+    @Test
+    @RequiresFlagsEnabled(Flags.FLAG_HANDLE_MULTI_DEVICE_INPUT)
+    fun testMultiDeviceEventsWithAllA11yFeatures() {
+        enableFeatures(ALL_A11Y_FEATURES)
+
+        // Touch device - ACTION_DOWN
+        val touchDownTime = SystemClock.uptimeMillis()
+        sendTouchEvent(ACTION_DOWN, touchDownTime, touchDownTime)
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_HOVER_ENTER)))
+
+        // Stylus device - ACTION_DOWN
+        val stylusDownTime = SystemClock.uptimeMillis()
+        sendStylusEvent(ACTION_DOWN, stylusDownTime, stylusDownTime)
+        // Touch is canceled and stylus is started
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_HOVER_EXIT)))
+        verifier.assertReceivedMotion(allOf(fromStylus, withMotionAction(ACTION_HOVER_ENTER)))
+
+        // Touch device - ACTION_MOVE
+        sendTouchEvent(ACTION_MOVE, touchDownTime, SystemClock.uptimeMillis())
+        // Stylus is active now; touch is ignored
+        verifier.assertNoEvents()
+
+        // Stylus device - ACTION_MOVE
+        sendStylusEvent(ACTION_MOVE, stylusDownTime, SystemClock.uptimeMillis())
+        verifier.assertReceivedMotion(allOf(fromStylus, withMotionAction(ACTION_HOVER_MOVE)))
+
+        // Touch device - ACTION_UP
+        sendTouchEvent(ACTION_UP, touchDownTime, SystemClock.uptimeMillis())
+        // Stylus is still active; touch is ignored
+        verifier.assertNoEvents()
+
+        sendStylusEvent(ACTION_UP, stylusDownTime, SystemClock.uptimeMillis())
+        verifier.assertReceivedMotion(allOf(fromStylus, withMotionAction(ACTION_HOVER_EXIT)))
+
+        // Now stylus is done, and a new touch gesture will work!
+        val newTouchDownTime = SystemClock.uptimeMillis()
+        sendTouchEvent(ACTION_DOWN, newTouchDownTime, newTouchDownTime)
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_HOVER_ENTER)))
+
+        verifier.assertNoEvents()
+    }
+
+    /**
+     * Check multi-device behaviour when all a11y features are enabled. The events should be
+     * modified accordingly, like DOWN events getting converted to hovers.
+     * Only a single device should be active at a given time. The touch gesture starts and ends
+     * while the stylus gesture is in progress; the most recent device to start a gesture wins.
+     */
+    @Test
+    @RequiresFlagsEnabled(Flags.FLAG_HANDLE_MULTI_DEVICE_INPUT)
+    fun testStylusWithTouchInTheMiddle() {
+        enableFeatures(ALL_A11Y_FEATURES)
+
+        // Stylus device - ACTION_DOWN
+        val stylusDownTime = SystemClock.uptimeMillis()
+        sendStylusEvent(ACTION_DOWN, stylusDownTime, stylusDownTime)
+        verifier.assertReceivedMotion(allOf(fromStylus, withMotionAction(ACTION_HOVER_ENTER)))
+
+        // Touch device - ACTION_DOWN
+        val touchDownTime = SystemClock.uptimeMillis()
+        sendTouchEvent(ACTION_DOWN, touchDownTime, touchDownTime)
+        // Touch DOWN causes stylus to get canceled
+        verifier.assertReceivedMotion(allOf(fromStylus, withMotionAction(ACTION_HOVER_EXIT)))
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_HOVER_ENTER)))
+
+        // Touch device - ACTION_MOVE
+        sendTouchEvent(ACTION_MOVE, touchDownTime, SystemClock.uptimeMillis())
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_HOVER_MOVE)))
+
+        sendStylusEvent(ACTION_MOVE, stylusDownTime, SystemClock.uptimeMillis())
+        // Stylus is ignored because touch is active now
+        verifier.assertNoEvents()
+
+        sendTouchEvent(ACTION_UP, touchDownTime, SystemClock.uptimeMillis())
+        verifier.assertReceivedMotion(allOf(fromTouchScreen, withMotionAction(ACTION_HOVER_EXIT)))
+
+        sendStylusEvent(ACTION_UP, stylusDownTime, SystemClock.uptimeMillis())
+        // The UP stylus event is also ignored
+        verifier.assertNoEvents()
+
+        // Now stylus works again, because touch gesture is finished
+        val newStylusDownTime = SystemClock.uptimeMillis()
+        sendStylusEvent(ACTION_DOWN, newStylusDownTime, newStylusDownTime)
+        verifier.assertReceivedMotion(allOf(fromStylus, withMotionAction(ACTION_HOVER_ENTER)))
 
         verifier.assertNoEvents()
     }
@@ -264,6 +463,20 @@
         return display
     }
 
+    /** Sends a touchscreen event; for ACTION_DOWN, [eventTime] must equal [downTime]. */
+    private fun sendTouchEvent(action: Int, downTime: Long, eventTime: Long) =
+        sendMotionEvent(action, downTime, eventTime, SOURCE_TOUCHSCREEN, touchDeviceId)
+
+    /** Sends a stylus event; for ACTION_DOWN, [eventTime] must equal [downTime]. */
+    private fun sendStylusEvent(action: Int, downTime: Long, eventTime: Long) =
+        sendMotionEvent(action, downTime, eventTime, STYLUS_SOURCE, stylusDeviceId)
+
+    private fun sendMotionEvent(
+        action: Int, downTime: Long, eventTime: Long, source: Int, deviceId: Int) {
+        if (action == ACTION_DOWN) assertEquals(downTime, eventTime)
+        send(createMotionEvent(action, downTime, eventTime, source, deviceId))
+    }
+
     private fun send(event: InputEvent) {
         // We need to make a copy of the event before sending it to the filter, because the filter
         // will recycle it, but the caller of this function might want to still be able to use