Merge "Use the injected device id for events that are a11y trusted" into sc-dev
diff --git a/include/input/Input.h b/include/input/Input.h
index d4defa8..7d936ba 100644
--- a/include/input/Input.h
+++ b/include/input/Input.h
@@ -24,6 +24,7 @@
  */
 
 #include <android/input.h>
+#include <android/os/IInputConstants.h>
 #include <math.h>
 #include <stdint.h>
 #include <ui/Transform.h>
@@ -220,6 +221,11 @@
 
     POLICY_FLAG_RAW_MASK = 0x0000ffff,
 
+    POLICY_FLAG_INPUTFILTER_TRUSTED = android::os::IInputConstants::POLICY_FLAG_INPUTFILTER_TRUSTED,
+
+    POLICY_FLAG_INJECTED_FROM_ACCESSIBILITY =
+            android::os::IInputConstants::POLICY_FLAG_INJECTED_FROM_ACCESSIBILITY,
+
     /* These flags are set by the input dispatcher. */
 
     // Indicates that the input event was injected.
diff --git a/include/input/InputDevice.h b/include/input/InputDevice.h
index 7f0324a..1955104 100644
--- a/include/input/InputDevice.h
+++ b/include/input/InputDevice.h
@@ -318,6 +318,8 @@
         const std::string& name, InputDeviceConfigurationFileType type);
 
 enum ReservedInputDeviceId : int32_t {
+    // Device id assigned to input events generated inside an accessibility service
+    ACCESSIBILITY_DEVICE_ID = -2,
     // Device id of a special "virtual" keyboard that is always present.
     VIRTUAL_KEYBOARD_ID = -1,
     // Device id of the "built-in" keyboard if there is one.
diff --git a/libs/input/android/os/IInputConstants.aidl b/libs/input/android/os/IInputConstants.aidl
index 4b90844..3038d9d 100644
--- a/libs/input/android/os/IInputConstants.aidl
+++ b/libs/input/android/os/IInputConstants.aidl
@@ -40,4 +40,18 @@
      * available.
      */
     const int INVALID_INPUT_EVENT_ID = 0;
+
+    /**
+     * The injected event was originally sent from InputDispatcher. Most likely, the journey of the
+     * event looked as follows:
+     * InputDispatcherPolicyInterface::filterInputEvent -> InputFilter.java::onInputEvent ->
+     * InputFilter.java::sendInputEvent -> InputDispatcher::injectInputEvent, without being modified
+     * along the way.
+     */
+    const int POLICY_FLAG_INPUTFILTER_TRUSTED = 0x10000;
+
+    /**
+     * The input event was injected from an accessibility service.
+     */
+    const int POLICY_FLAG_INJECTED_FROM_ACCESSIBILITY = 0x20000;
 }
diff --git a/services/inputflinger/dispatcher/InputDispatcher.cpp b/services/inputflinger/dispatcher/InputDispatcher.cpp
index cf433c0..d2b8739 100644
--- a/services/inputflinger/dispatcher/InputDispatcher.cpp
+++ b/services/inputflinger/dispatcher/InputDispatcher.cpp
@@ -3787,7 +3787,7 @@
         if (shouldSendKeyToInputFilterLocked(args)) {
             mLock.unlock();
 
-            policyFlags |= POLICY_FLAG_FILTERED;
+            policyFlags |= POLICY_FLAG_FILTERED | POLICY_FLAG_INPUTFILTER_TRUSTED;
             if (!mPolicy->filterInputEvent(&event, policyFlags)) {
                 return; // event was consumed by the filter
             }
@@ -4009,6 +4009,19 @@
         policyFlags |= POLICY_FLAG_TRUSTED;
     }
 
+    // For all injected events, set device id = VIRTUAL_KEYBOARD_ID. The only exception is events
+    // that passed through the InputFilter unmodified; they carry POLICY_FLAG_INPUTFILTER_TRUSTED
+    // and keep the provided device id. If the InputFilter modifies an event in any way, it is
+    // responsible for clearing POLICY_FLAG_INPUTFILTER_TRUSTED. Otherwise, if the injected event
+    // originated from accessibility (POLICY_FLAG_INJECTED_FROM_ACCESSIBILITY), assign
+    // ACCESSIBILITY_DEVICE_ID so it can be distinguished from regular injected events.
+    int32_t resolvedDeviceId = VIRTUAL_KEYBOARD_ID;
+    if (policyFlags & POLICY_FLAG_INPUTFILTER_TRUSTED) {
+        resolvedDeviceId = event->getDeviceId();
+    } else if (policyFlags & POLICY_FLAG_INJECTED_FROM_ACCESSIBILITY) {
+        resolvedDeviceId = ACCESSIBILITY_DEVICE_ID;
+    }
+
     std::queue<std::unique_ptr<EventEntry>> injectedEntries;
     switch (event->getType()) {
         case AINPUT_EVENT_TYPE_KEY: {
@@ -4021,10 +4034,10 @@
             int32_t flags = incomingKey.getFlags();
             int32_t keyCode = incomingKey.getKeyCode();
             int32_t metaState = incomingKey.getMetaState();
-            accelerateMetaShortcuts(VIRTUAL_KEYBOARD_ID, action,
+            accelerateMetaShortcuts(resolvedDeviceId, action,
                                     /*byref*/ keyCode, /*byref*/ metaState);
             KeyEvent keyEvent;
-            keyEvent.initialize(incomingKey.getId(), VIRTUAL_KEYBOARD_ID, incomingKey.getSource(),
+            keyEvent.initialize(incomingKey.getId(), resolvedDeviceId, incomingKey.getSource(),
                                 incomingKey.getDisplayId(), INVALID_HMAC, action, flags, keyCode,
                                 incomingKey.getScanCode(), metaState, incomingKey.getRepeatCount(),
                                 incomingKey.getDownTime(), incomingKey.getEventTime());
@@ -4045,7 +4058,7 @@
             mLock.lock();
             std::unique_ptr<KeyEntry> injectedEntry =
                     std::make_unique<KeyEntry>(incomingKey.getId(), incomingKey.getEventTime(),
-                                               VIRTUAL_KEYBOARD_ID, incomingKey.getSource(),
+                                               resolvedDeviceId, incomingKey.getSource(),
                                                incomingKey.getDisplayId(), policyFlags, action,
                                                flags, keyCode, incomingKey.getScanCode(), metaState,
                                                incomingKey.getRepeatCount(),
@@ -4055,18 +4068,18 @@
         }
 
         case AINPUT_EVENT_TYPE_MOTION: {
-            const MotionEvent* motionEvent = static_cast<const MotionEvent*>(event);
-            int32_t action = motionEvent->getAction();
-            size_t pointerCount = motionEvent->getPointerCount();
-            const PointerProperties* pointerProperties = motionEvent->getPointerProperties();
-            int32_t actionButton = motionEvent->getActionButton();
-            int32_t displayId = motionEvent->getDisplayId();
+            const MotionEvent& motionEvent = static_cast<const MotionEvent&>(*event);
+            int32_t action = motionEvent.getAction();
+            size_t pointerCount = motionEvent.getPointerCount();
+            const PointerProperties* pointerProperties = motionEvent.getPointerProperties();
+            int32_t actionButton = motionEvent.getActionButton();
+            int32_t displayId = motionEvent.getDisplayId();
             if (!validateMotionEvent(action, actionButton, pointerCount, pointerProperties)) {
                 return InputEventInjectionResult::FAILED;
             }
 
             if (!(policyFlags & POLICY_FLAG_FILTERED)) {
-                nsecs_t eventTime = motionEvent->getEventTime();
+                nsecs_t eventTime = motionEvent.getEventTime();
                 android::base::Timer t;
                 mPolicy->interceptMotionBeforeQueueing(displayId, eventTime, /*byref*/ policyFlags);
                 if (t.duration() > SLOW_INTERCEPTION_THRESHOLD) {
@@ -4076,47 +4089,46 @@
             }
 
             mLock.lock();
-            const nsecs_t* sampleEventTimes = motionEvent->getSampleEventTimes();
-            const PointerCoords* samplePointerCoords = motionEvent->getSamplePointerCoords();
+            const nsecs_t* sampleEventTimes = motionEvent.getSampleEventTimes();
+            const PointerCoords* samplePointerCoords = motionEvent.getSamplePointerCoords();
             std::unique_ptr<MotionEntry> injectedEntry =
-                    std::make_unique<MotionEntry>(motionEvent->getId(), *sampleEventTimes,
-                                                  VIRTUAL_KEYBOARD_ID, motionEvent->getSource(),
-                                                  motionEvent->getDisplayId(), policyFlags, action,
-                                                  actionButton, motionEvent->getFlags(),
-                                                  motionEvent->getMetaState(),
-                                                  motionEvent->getButtonState(),
-                                                  motionEvent->getClassification(),
-                                                  motionEvent->getEdgeFlags(),
-                                                  motionEvent->getXPrecision(),
-                                                  motionEvent->getYPrecision(),
-                                                  motionEvent->getRawXCursorPosition(),
-                                                  motionEvent->getRawYCursorPosition(),
-                                                  motionEvent->getDownTime(),
-                                                  uint32_t(pointerCount), pointerProperties,
-                                                  samplePointerCoords, motionEvent->getXOffset(),
-                                                  motionEvent->getYOffset());
+                    std::make_unique<MotionEntry>(motionEvent.getId(), *sampleEventTimes,
+                                                  resolvedDeviceId, motionEvent.getSource(),
+                                                  motionEvent.getDisplayId(), policyFlags, action,
+                                                  actionButton, motionEvent.getFlags(),
+                                                  motionEvent.getMetaState(),
+                                                  motionEvent.getButtonState(),
+                                                  motionEvent.getClassification(),
+                                                  motionEvent.getEdgeFlags(),
+                                                  motionEvent.getXPrecision(),
+                                                  motionEvent.getYPrecision(),
+                                                  motionEvent.getRawXCursorPosition(),
+                                                  motionEvent.getRawYCursorPosition(),
+                                                  motionEvent.getDownTime(), uint32_t(pointerCount),
+                                                  pointerProperties, samplePointerCoords,
+                                                  motionEvent.getXOffset(),
+                                                  motionEvent.getYOffset());
             injectedEntries.push(std::move(injectedEntry));
-            for (size_t i = motionEvent->getHistorySize(); i > 0; i--) {
+            for (size_t i = motionEvent.getHistorySize(); i > 0; i--) {
                 sampleEventTimes += 1;
                 samplePointerCoords += pointerCount;
                 std::unique_ptr<MotionEntry> nextInjectedEntry =
-                        std::make_unique<MotionEntry>(motionEvent->getId(), *sampleEventTimes,
-                                                      VIRTUAL_KEYBOARD_ID, motionEvent->getSource(),
-                                                      motionEvent->getDisplayId(), policyFlags,
-                                                      action, actionButton, motionEvent->getFlags(),
-                                                      motionEvent->getMetaState(),
-                                                      motionEvent->getButtonState(),
-                                                      motionEvent->getClassification(),
-                                                      motionEvent->getEdgeFlags(),
-                                                      motionEvent->getXPrecision(),
-                                                      motionEvent->getYPrecision(),
-                                                      motionEvent->getRawXCursorPosition(),
-                                                      motionEvent->getRawYCursorPosition(),
-                                                      motionEvent->getDownTime(),
+                        std::make_unique<MotionEntry>(motionEvent.getId(), *sampleEventTimes,
+                                                      resolvedDeviceId, motionEvent.getSource(),
+                                                      motionEvent.getDisplayId(), policyFlags,
+                                                      action, actionButton, motionEvent.getFlags(),
+                                                      motionEvent.getMetaState(),
+                                                      motionEvent.getButtonState(),
+                                                      motionEvent.getClassification(),
+                                                      motionEvent.getEdgeFlags(),
+                                                      motionEvent.getXPrecision(),
+                                                      motionEvent.getYPrecision(),
+                                                      motionEvent.getRawXCursorPosition(),
+                                                      motionEvent.getRawYCursorPosition(),
+                                                      motionEvent.getDownTime(),
                                                       uint32_t(pointerCount), pointerProperties,
-                                                      samplePointerCoords,
-                                                      motionEvent->getXOffset(),
-                                                      motionEvent->getYOffset());
+                                                      samplePointerCoords, motionEvent.getXOffset(),
+                                                      motionEvent.getYOffset());
                 injectedEntries.push(std::move(nextInjectedEntry));
             }
             break;
diff --git a/services/inputflinger/tests/InputDispatcher_test.cpp b/services/inputflinger/tests/InputDispatcher_test.cpp
index 93aa6ac..d51acce 100644
--- a/services/inputflinger/tests/InputDispatcher_test.cpp
+++ b/services/inputflinger/tests/InputDispatcher_test.cpp
@@ -473,6 +473,7 @@
                           const sp<InputWindowHandle>& focusedWindow = nullptr) {
         FocusRequest request;
         request.token = window->getToken();
+        request.windowName = window->getName();
         if (focusedWindow) {
             request.focusedToken = focusedWindow->getToken();
         }
@@ -1085,6 +1086,20 @@
         return mInputReceiver->consume();
     }
 
+    MotionEvent* consumeMotion() {
+        InputEvent* event = consume();
+        if (event == nullptr) {
+            ADD_FAILURE() << "Consume failed : no event";
+            return nullptr;
+        }
+        if (event->getType() != AINPUT_EVENT_TYPE_MOTION) {
+            ADD_FAILURE() << "Instead of motion event, got "
+                          << inputEventTypeToString(event->getType());
+            return nullptr;
+        }
+        return static_cast<MotionEvent*>(event);
+    }
+
     void assertNoEvents() {
         if (mInputReceiver == nullptr &&
             mInfo.inputFeatures.test(InputWindowInfo::Feature::NO_INPUT_CHANNEL)) {
@@ -2446,13 +2461,10 @@
                 generateMotionArgs(AMOTION_EVENT_ACTION_MOVE, source, ADISPLAY_ID_DEFAULT);
         mDispatcher->notifyMotion(&motionArgs);
 
-        InputEvent* event = window->consume();
+        MotionEvent* event = window->consumeMotion();
         ASSERT_NE(event, nullptr);
-        ASSERT_EQ(AINPUT_EVENT_TYPE_MOTION, event->getType())
-                << name.c_str() << "expected " << inputEventTypeToString(AINPUT_EVENT_TYPE_MOTION)
-                << " event, got " << inputEventTypeToString(event->getType()) << " event";
 
-        const MotionEvent& motionEvent = static_cast<const MotionEvent&>(*event);
+        const MotionEvent& motionEvent = *event;
         EXPECT_EQ(AMOTION_EVENT_ACTION_MOVE, motionEvent.getAction());
         EXPECT_EQ(motionArgs.pointerCount, motionEvent.getPointerCount());
 
@@ -3118,6 +3130,70 @@
     testNotifyKey(/*expectToBeFiltered*/ false);
 }
 
+class InputFilterInjectionPolicyTest : public InputDispatcherTest {
+protected:
+    virtual void SetUp() override {
+        InputDispatcherTest::SetUp();
+
+        /**
+         * We don't need to enable the input filter to test the injected event policy, but we
+         * enable it here to make the tests more realistic, since this policy only matters when the
+         * input filter is on.
+         */
+        mDispatcher->setInputFilterEnabled(true);
+
+        std::shared_ptr<InputApplicationHandle> application =
+                std::make_shared<FakeApplicationHandle>();
+        mWindow =
+                new FakeWindowHandle(application, mDispatcher, "Test Window", ADISPLAY_ID_DEFAULT);
+
+        mDispatcher->setFocusedApplication(ADISPLAY_ID_DEFAULT, application);
+        mWindow->setFocusable(true);
+        mDispatcher->setInputWindows({{ADISPLAY_ID_DEFAULT, {mWindow}}});
+        setFocusedWindow(mWindow);
+        mWindow->consumeFocusEvent(true);
+    }
+
+    void testInjectedKey(int32_t policyFlags, int32_t injectedDeviceId, int32_t resolvedDeviceId) {
+        KeyEvent event;
+
+        const nsecs_t eventTime = systemTime(SYSTEM_TIME_MONOTONIC);
+        event.initialize(InputEvent::nextId(), injectedDeviceId, AINPUT_SOURCE_KEYBOARD,
+                         ADISPLAY_ID_NONE, INVALID_HMAC, AKEY_EVENT_ACTION_DOWN, 0, AKEYCODE_A,
+                         KEY_A, AMETA_NONE, 0 /*repeatCount*/, eventTime, eventTime);
+        const int32_t additionalPolicyFlags =
+                POLICY_FLAG_PASS_TO_USER | POLICY_FLAG_DISABLE_KEY_REPEAT;
+        ASSERT_EQ(InputEventInjectionResult::SUCCEEDED,
+                  mDispatcher->injectInputEvent(&event, INJECTOR_PID, INJECTOR_UID,
+                                                InputEventInjectionSync::WAIT_FOR_RESULT, 10ms,
+                                                policyFlags | additionalPolicyFlags));
+
+        InputEvent* received = mWindow->consume();
+        ASSERT_NE(nullptr, received);
+        ASSERT_EQ(resolvedDeviceId, received->getDeviceId());
+    }
+
+private:
+    sp<FakeWindowHandle> mWindow;
+};
+
+TEST_F(InputFilterInjectionPolicyTest, TrustedFilteredEvents_KeepOriginalDeviceId) {
+    // We don't need POLICY_FLAG_FILTERED here, but it will be set in practice, so keep it to make
+    // the test more closely resemble the real usage
+    testInjectedKey(POLICY_FLAG_FILTERED | POLICY_FLAG_INPUTFILTER_TRUSTED, 3 /*injectedDeviceId*/,
+                    3 /*resolvedDeviceId*/);
+}
+
+TEST_F(InputFilterInjectionPolicyTest, EventsInjectedFromAccessibility_HaveAccessibilityDeviceId) {
+    testInjectedKey(POLICY_FLAG_FILTERED | POLICY_FLAG_INJECTED_FROM_ACCESSIBILITY,
+                    3 /*injectedDeviceId*/, ACCESSIBILITY_DEVICE_ID /*resolvedDeviceId*/);
+}
+
+TEST_F(InputFilterInjectionPolicyTest, RegularInjectedEvents_ReceiveVirtualDeviceId) {
+    testInjectedKey(0 /*policyFlags*/, 3 /*injectedDeviceId*/,
+                    VIRTUAL_KEYBOARD_ID /*resolvedDeviceId*/);
+}
+
 class InputDispatcherOnPointerDownOutsideFocus : public InputDispatcherTest {
     virtual void SetUp() override {
         InputDispatcherTest::SetUp();