InputDispatcher_test: Verify all traced events match exactly

Bug: 210460522
Test: atest inputflinger_tests
Change-Id: I2ab660ed0a6888c23bc711fb8494385c22b3c404
diff --git a/include/input/Input.h b/include/input/Input.h
index 7b253a5..a84dcfc 100644
--- a/include/input/Input.h
+++ b/include/input/Input.h
@@ -180,7 +180,8 @@
  * Declare a concrete type for the NDK's input event forward declaration.
  */
 struct AInputEvent {
-    virtual ~AInputEvent() { }
+    virtual ~AInputEvent() {}
+    bool operator==(const AInputEvent&) const = default;
 };
 
 /*
@@ -545,6 +546,8 @@
 
     static int32_t nextId();
 
+    bool operator==(const InputEvent&) const = default;
+
 protected:
     void initialize(int32_t id, DeviceId deviceId, uint32_t source, int32_t displayId,
                     std::array<uint8_t, 32> hmac);
@@ -598,6 +601,8 @@
 
     static const char* actionToString(int32_t action);
 
+    bool operator==(const KeyEvent&) const = default;
+
 protected:
     int32_t mAction;
     int32_t mFlags;
@@ -917,6 +922,9 @@
     // The rounding precision for transformed motion events.
     static constexpr float ROUNDING_PRECISION = 0.001f;
 
+    bool operator==(const MotionEvent&) const;
+    inline bool operator!=(const MotionEvent& o) const { return !(*this == o); };
+
 protected:
     int32_t mAction;
     int32_t mActionButton;
diff --git a/libs/input/Input.cpp b/libs/input/Input.cpp
index 0d29b4d..e831d24 100644
--- a/libs/input/Input.cpp
+++ b/libs/input/Input.cpp
@@ -1007,6 +1007,33 @@
     return out;
 }
 
+bool MotionEvent::operator==(const android::MotionEvent& o) const {
+    // We use NaN values to represent invalid cursor positions. Since NaN values are not equal
+    // to themselves according to IEEE 754, we cannot use the default equality operator to compare
+    // MotionEvents. Therefore we define a custom equality operator with special handling for NaNs.
+    // clang-format off
+    return InputEvent::operator==(static_cast<const InputEvent&>(o)) &&
+            mAction == o.mAction &&
+            mActionButton == o.mActionButton &&
+            mFlags == o.mFlags &&
+            mEdgeFlags == o.mEdgeFlags &&
+            mMetaState == o.mMetaState &&
+            mButtonState == o.mButtonState &&
+            mClassification == o.mClassification &&
+            mTransform == o.mTransform &&
+            mXPrecision == o.mXPrecision &&
+            mYPrecision == o.mYPrecision &&
+            ((std::isnan(mRawXCursorPosition) && std::isnan(o.mRawXCursorPosition)) ||
+                mRawXCursorPosition == o.mRawXCursorPosition) &&
+            ((std::isnan(mRawYCursorPosition) && std::isnan(o.mRawYCursorPosition)) ||
+                mRawYCursorPosition == o.mRawYCursorPosition) &&
+            mRawTransform == o.mRawTransform && mDownTime == o.mDownTime &&
+            mPointerProperties == o.mPointerProperties &&
+            mSampleEventTimes == o.mSampleEventTimes &&
+            mSamplePointerCoords == o.mSamplePointerCoords;
+    // clang-format on
+}
+
 std::ostream& operator<<(std::ostream& out, const MotionEvent& event) {
     out << "MotionEvent { action=" << MotionEvent::actionToString(event.getAction());
     if (event.getActionButton() != 0) {
diff --git a/services/inputflinger/tests/FakeInputTracingBackend.cpp b/services/inputflinger/tests/FakeInputTracingBackend.cpp
index f4a06f7..b7af356 100644
--- a/services/inputflinger/tests/FakeInputTracingBackend.cpp
+++ b/services/inputflinger/tests/FakeInputTracingBackend.cpp
@@ -37,6 +37,30 @@
     return std::visit([](const auto& event) { return event.id; }, v);
 }
 
+MotionEvent toInputEvent(
+        const trace::TracedMotionEvent& e,
+        const trace::InputTracingBackendInterface::WindowDispatchArgs& dispatchArgs,
+        const std::array<uint8_t, 32>& hmac) {
+    MotionEvent traced;
+    traced.initialize(e.id, e.deviceId, e.source, e.displayId, hmac, e.action, e.actionButton,
+                      dispatchArgs.resolvedFlags, e.edgeFlags, e.metaState, e.buttonState,
+                      e.classification, dispatchArgs.transform, e.xPrecision, e.yPrecision,
+                      e.xCursorPosition, e.yCursorPosition, dispatchArgs.rawTransform, e.downTime,
+                      e.eventTime, e.pointerProperties.size(), e.pointerProperties.data(),
+                      e.pointerCoords.data());
+    return traced;
+}
+
+KeyEvent toInputEvent(const trace::TracedKeyEvent& e,
+                      const trace::InputTracingBackendInterface::WindowDispatchArgs& dispatchArgs,
+                      const std::array<uint8_t, 32>& hmac) {
+    KeyEvent traced;
+    traced.initialize(e.id, e.deviceId, e.source, e.displayId, hmac, e.action,
+                      dispatchArgs.resolvedFlags, e.keyCode, e.scanCode, e.metaState, e.repeatCount,
+                      e.downTime, e.eventTime);
+    return traced;
+}
+
 } // namespace
 
 // --- VerifyingTrace ---
@@ -55,6 +79,7 @@
     std::unique_lock lock(mLock);
     base::ScopedLockAssertion assumeLocked(mLock);
 
+    // Poll for all expected events to be traced, and keep track of the latest poll result.
     base::Result<void> result;
     mEventTracedCondition.wait_for(lock, TRACE_TIMEOUT, [&]() REQUIRES(mLock) {
         for (const auto& [expectedEvent, windowId] : mExpectedEvents) {
@@ -101,12 +126,39 @@
                          });
     if (tracedDispatchesIt == mTracedWindowDispatches.end()) {
         msg << "Expected dispatch of event with ID 0x" << std::hex << expectedEvent.getId()
-            << " to window with ID 0x" << expectedWindowId << " to be traced, but it was not."
-            << "\nExpected event: " << expectedEvent;
+            << " to window with ID 0x" << expectedWindowId << " to be traced, but it was not.\n"
+            << "Expected event: " << expectedEvent;
         return error(msg);
     }
 
-    return {};
+    // Verify that the traced event matches the expected event exactly.
+    return std::visit(
+            [&](const auto& traced) -> base::Result<void> {
+                Event tracedEvent;
+                using T = std::decay_t<decltype(traced)>;
+                if constexpr (std::is_same_v<Event, MotionEvent> &&
+                              std::is_same_v<T, trace::TracedMotionEvent>) {
+                    tracedEvent =
+                            toInputEvent(traced, *tracedDispatchesIt, expectedEvent.getHmac());
+                } else if constexpr (std::is_same_v<Event, KeyEvent> &&
+                                     std::is_same_v<T, trace::TracedKeyEvent>) {
+                    tracedEvent =
+                            toInputEvent(traced, *tracedDispatchesIt, expectedEvent.getHmac());
+                } else {
+                    msg << "Received the wrong event type!\n"
+                        << "Expected event: " << expectedEvent;
+                    return error(msg);
+                }
+
+                const auto result = testing::internal::CmpHelperEQ("expectedEvent", "tracedEvent",
+                                                                   expectedEvent, tracedEvent);
+                if (!result) {
+                    msg << result.failure_message();
+                    return error(msg);
+                }
+                return {};
+            },
+            tracedEventsIt->second);
 }
 
 // --- FakeInputTracingBackend ---
@@ -114,7 +166,7 @@
 void FakeInputTracingBackend::traceKeyEvent(const trace::TracedKeyEvent& event) const {
     {
         std::scoped_lock lock(mTrace->mLock);
-        mTrace->mTracedEvents.emplace(event.id);
+        mTrace->mTracedEvents.emplace(event.id, event);
     }
     mTrace->mEventTracedCondition.notify_all();
 }
@@ -122,7 +174,7 @@
 void FakeInputTracingBackend::traceMotionEvent(const trace::TracedMotionEvent& event) const {
     {
         std::scoped_lock lock(mTrace->mLock);
-        mTrace->mTracedEvents.emplace(event.id);
+        mTrace->mTracedEvents.emplace(event.id, event);
     }
     mTrace->mEventTracedCondition.notify_all();
 }
diff --git a/services/inputflinger/tests/FakeInputTracingBackend.h b/services/inputflinger/tests/FakeInputTracingBackend.h
index 40ca3a2..108419a 100644
--- a/services/inputflinger/tests/FakeInputTracingBackend.h
+++ b/services/inputflinger/tests/FakeInputTracingBackend.h
@@ -25,7 +25,7 @@
 #include <condition_variable>
 #include <memory>
 #include <mutex>
-#include <unordered_set>
+#include <unordered_map>
 #include <vector>
 
 namespace android::inputdispatcher {
@@ -58,7 +58,7 @@
 private:
     std::mutex mLock;
     std::condition_variable mEventTracedCondition;
-    std::unordered_set<uint32_t /*eventId*/> mTracedEvents GUARDED_BY(mLock);
+    std::unordered_map<uint32_t /*eventId*/, trace::TracedEvent> mTracedEvents GUARDED_BY(mLock);
     using WindowDispatchArgs = trace::InputTracingBackendInterface::WindowDispatchArgs;
     std::vector<WindowDispatchArgs> mTracedWindowDispatches GUARDED_BY(mLock);
     std::vector<std::pair<std::variant<KeyEvent, MotionEvent>, int32_t /*windowId*/>>