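InputDispatcher: Improve pointer transform mapping in InputTarget

InputTarget now keeps its pointer transforms as a list of
(transform, pointer ID set) pairs instead of one transform slot per
pointer ID. Pointers added with an equal transform are merged into the
same entry, and the full set of pointer IDs is derived from that list.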

Bug: 316355518
Bug: 210460522
Test: atest inputflinger_tests
Change-Id: I14f77d6138ad2458649650fd09cef9253a669363
diff --git a/services/inputflinger/dispatcher/InputTarget.cpp b/services/inputflinger/dispatcher/InputTarget.cpp
index 28e3c6d..35ad858 100644
--- a/services/inputflinger/dispatcher/InputTarget.cpp
+++ b/services/inputflinger/dispatcher/InputTarget.cpp
@@ -26,6 +26,15 @@
 
 namespace android::inputdispatcher {
 
+namespace {
+
+const static ui::Transform kIdentityTransform{};
+
+}
+
+InputTarget::InputTarget(const std::shared_ptr<Connection>& connection, ftl::Flags<Flags> flags)
+      : connection(connection), flags(flags) {}
+
 void InputTarget::addPointers(std::bitset<MAX_POINTER_ID + 1> newPointerIds,
                               const ui::Transform& transform) {
     // The pointerIds can be empty, but still a valid InputTarget. This can happen when there is no
@@ -36,31 +45,46 @@
     }
 
     // Ensure that the new set of pointers doesn't overlap with the current set of pointers.
-    if ((pointerIds & newPointerIds).any()) {
+    if ((getPointerIds() & newPointerIds).any()) {
         LOG(FATAL) << __func__ << " - overlap with incoming pointers "
                    << bitsetToString(newPointerIds) << " in " << *this;
     }
 
-    pointerIds |= newPointerIds;
-    for (size_t i = 0; i < newPointerIds.size(); i++) {
-        if (!newPointerIds.test(i)) {
-            continue;
+    for (auto& [existingTransform, existingPointers] : mPointerTransforms) {
+        if (transform == existingTransform) {
+            existingPointers |= newPointerIds;
+            return;
         }
-        pointerTransforms[i] = transform;
     }
+    mPointerTransforms.emplace_back(transform, newPointerIds);
 }
 
 void InputTarget::setDefaultPointerTransform(const ui::Transform& transform) {
-    pointerIds.reset();
-    pointerTransforms[0] = transform;
+    mPointerTransforms = {{transform, {}}};
 }
 
 bool InputTarget::useDefaultPointerTransform() const {
-    return pointerIds.none();
+    return mPointerTransforms.size() <= 1;
 }
 
 const ui::Transform& InputTarget::getDefaultPointerTransform() const {
-    return pointerTransforms[0];
+    if (!useDefaultPointerTransform()) {
+        LOG(FATAL) << __func__ << ": Not using default pointer transform";
+    }
+    return mPointerTransforms.size() == 1 ? mPointerTransforms[0].first : kIdentityTransform;
+}
+
+const ui::Transform& InputTarget::getTransformForPointer(int32_t pointerId) const {
+    for (const auto& [transform, ids] : mPointerTransforms) {
+        if (ids.test(pointerId)) {
+            return transform;
+        }
+    }
+
+    LOG(FATAL) << __func__
+               << ": Cannot get transform: The following Pointer ID does not exist in target: "
+               << pointerId;
+    return kIdentityTransform;
 }
 
 std::string InputTarget::getPointerInfoString() const {
@@ -71,17 +95,21 @@
         return out;
     }
 
-    for (uint32_t i = 0; i < pointerIds.size(); i++) {
-        if (!pointerIds.test(i)) {
-            continue;
-        }
-
-        const std::string name = "pointerId " + std::to_string(i) + ":";
-        pointerTransforms[i].dump(out, name.c_str(), "        ");
+    for (const auto& [transform, ids] : mPointerTransforms) {
+        const std::string name = "pointerIds " + bitsetToString(ids) + ":";
+        transform.dump(out, name.c_str(), "        ");
     }
     return out;
 }
 
+std::bitset<MAX_POINTER_ID + 1> InputTarget::getPointerIds() const {
+    PointerIds allIds;
+    for (const auto& [_, ids] : mPointerTransforms) {
+        allIds |= ids;
+    }
+    return allIds;
+}
+
 std::ostream& operator<<(std::ostream& out, const InputTarget& target) {
     out << "{connection=";
     if (target.connection != nullptr) {