Merge "InputController: Wait for associated display sync for input devices" into main
diff --git a/services/companion/java/com/android/server/companion/virtual/InputController.java b/services/companion/java/com/android/server/companion/virtual/InputController.java
index 8962bf0..1b49f18e 100644
--- a/services/companion/java/com/android/server/companion/virtual/InputController.java
+++ b/services/companion/java/com/android/server/companion/virtual/InputController.java
@@ -692,32 +692,46 @@
 
         private int mInputDeviceId = IInputConstants.INVALID_INPUT_DEVICE_ID;
 
-        WaitForDevice(String deviceName, int vendorId, int productId) {
+        WaitForDevice(String deviceName, int vendorId, int productId, int associatedDisplayId) {
             mListener = new InputManager.InputDeviceListener() {
                 @Override
                 public void onInputDeviceAdded(int deviceId) {
-                    final InputDevice device = InputManagerGlobal.getInstance().getInputDevice(
-                            deviceId);
-                    Objects.requireNonNull(device, "Newly added input device was null.");
-                    if (!device.getName().equals(deviceName)) {
-                        return;
-                    }
-                    final InputDeviceIdentifier id = device.getIdentifier();
-                    if (id.getVendorId() != vendorId || id.getProductId() != productId) {
-                        return;
-                    }
-                    mInputDeviceId = deviceId;
-                    mDeviceAddedLatch.countDown();
+                    // An add is handled like a change: both finish the wait on a match.
+                    onInputDeviceChanged(deviceId);
                 }
 
                 @Override
                 public void onInputDeviceRemoved(int deviceId) {
-
                 }
 
                 @Override
                 public void onInputDeviceChanged(int deviceId) {
+                    if (isMatchingDevice(deviceId)) {
+                        mInputDeviceId = deviceId;
+                        mDeviceAddedLatch.countDown();
+                    }
+                }
 
+                /**
+                 * Returns whether the device with the given id has the expected name,
+                 * vendor id, product id and associated display id.
+                 */
+                private boolean isMatchingDevice(int deviceId) {
+                    final InputDevice device = InputManagerGlobal.getInstance().getInputDevice(
+                            deviceId);
+                    if (device == null) {
+                        // The lookup can come back empty (e.g. the device is already gone);
+                        // such a device cannot be the one being waited for.
+                        return false;
+                    }
+                    if (!device.getName().equals(deviceName)) {
+                        return false;
+                    }
+                    final InputDeviceIdentifier id = device.getIdentifier();
+                    if (id.getVendorId() != vendorId || id.getProductId() != productId) {
+                        return false;
+                    }
+                    return device.getAssociatedDisplayId() == associatedDisplayId;
                 }
             };
             InputManagerGlobal.getInstance().registerInputDeviceListener(mListener, mHandler);
@@ -799,7 +813,7 @@
         final int inputDeviceId;
 
         setUniqueIdAssociation(displayId, phys);
-        try (WaitForDevice waiter = new WaitForDevice(deviceName, vendorId, productId)) {
+        try (WaitForDevice waiter = new WaitForDevice(deviceName, vendorId, productId, displayId)) {
             ptr = deviceOpener.get();
             // See INVALID_PTR in libs/input/VirtualInputDevice.cpp.
             if (ptr == 0) {
diff --git a/services/tests/servicestests/src/com/android/server/companion/virtual/InputControllerTest.java b/services/tests/servicestests/src/com/android/server/companion/virtual/InputControllerTest.java
index 5943832..07e6ab2 100644
--- a/services/tests/servicestests/src/com/android/server/companion/virtual/InputControllerTest.java
+++ b/services/tests/servicestests/src/com/android/server/companion/virtual/InputControllerTest.java
@@ -19,11 +19,10 @@
 import static com.google.common.truth.Truth.assertWithMessage;
 
 import static org.junit.Assert.assertThrows;
-import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.ArgumentMatchers.startsWith;
-import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.verify;
 
 import android.hardware.display.DisplayManagerInternal;
@@ -86,9 +85,8 @@
         LocalServices.removeServiceForTest(InputManagerInternal.class);
         LocalServices.addService(InputManagerInternal.class, mInputManagerInternalMock);
 
-        final DisplayInfo displayInfo = new DisplayInfo();
-        displayInfo.uniqueId = "uniqueId";
-        doReturn(displayInfo).when(mDisplayManagerInternalMock).getDisplayInfo(anyInt());
+        setUpDisplay(1 /* displayId */);
+        setUpDisplay(2 /* displayId */);
         LocalServices.removeServiceForTest(DisplayManagerInternal.class);
         LocalServices.addService(DisplayManagerInternal.class, mDisplayManagerInternalMock);
 
@@ -100,6 +98,21 @@
                 threadVerifier);
     }
 
+    /**
+     * Stubs the DisplayManagerInternal mock to return a DisplayInfo with unique id
+     * {@code "uniqueId:" + displayId} for the given display, and registers that unique id
+     * with the input manager mock helper.
+     */
+    void setUpDisplay(int displayId) {
+        final String displayUniqueId = "uniqueId:" + displayId;
+        mInputManagerMockHelper.addDisplayIdMapping(displayUniqueId, displayId);
+        doAnswer(invocation -> {
+            final DisplayInfo info = new DisplayInfo();
+            info.uniqueId = displayUniqueId;
+            return info;
+        }).when(mDisplayManagerInternalMock).getDisplayInfo(eq(displayId));
+    }
+
     @After
     public void tearDown() {
         mInputManagerMockHelper.tearDown();
diff --git a/services/tests/servicestests/src/com/android/server/companion/virtual/InputManagerMockHelper.java b/services/tests/servicestests/src/com/android/server/companion/virtual/InputManagerMockHelper.java
index 3722247..74e854e4 100644
--- a/services/tests/servicestests/src/com/android/server/companion/virtual/InputManagerMockHelper.java
+++ b/services/tests/servicestests/src/com/android/server/companion/virtual/InputManagerMockHelper.java
@@ -20,7 +20,6 @@
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.notNull;
 import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.when;
 
 import android.hardware.input.IInputDevicesChangedListener;
@@ -28,12 +27,15 @@
 import android.hardware.input.InputManagerGlobal;
 import android.os.RemoteException;
 import android.testing.TestableLooper;
+import android.view.Display;
 import android.view.InputDevice;
 
 import org.mockito.invocation.InvocationOnMock;
 
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.stream.IntStream;
 
@@ -49,6 +51,10 @@
     private final InputManagerGlobal.TestSession mInputManagerGlobalSession;
     private final List<InputDevice> mDevices = new ArrayList<>();
     private IInputDevicesChangedListener mDevicesChangedListener;
+    private final Map<String /* uniqueId */, Integer /* displayId */> mDisplayIdMapping =
+            new HashMap<>();
+    private final Map<String /* phys */, String /* uniqueId */> mUniqueIdAssociation =
+            new HashMap<>();
 
     InputManagerMockHelper(TestableLooper testableLooper,
             InputController.NativeWrapper nativeWrapperMock, IInputManager iInputManagerMock)
@@ -73,8 +79,10 @@
         when(mIInputManagerMock.getInputDeviceIds()).thenReturn(new int[0]);
         doAnswer(inv -> mDevices.get(inv.getArgument(0)))
                 .when(mIInputManagerMock).getInputDevice(anyInt());
-        doNothing().when(mIInputManagerMock).addUniqueIdAssociation(anyString(), anyString());
-        doNothing().when(mIInputManagerMock).removeUniqueIdAssociation(anyString());
+        doAnswer(inv -> mUniqueIdAssociation.put(inv.getArgument(0), inv.getArgument(1))).when(
+                mIInputManagerMock).addUniqueIdAssociation(anyString(), anyString());
+        doAnswer(inv -> mUniqueIdAssociation.remove(inv.getArgument(0))).when(
+                mIInputManagerMock).removeUniqueIdAssociation(anyString());
 
         // Set a new instance of InputManager for testing that uses the IInputManager mock as the
         // interface to the server.
@@ -87,17 +95,32 @@
         }
     }
 
+    /**
+     * Records the display id that a display unique id maps to. The mapping is consulted when
+     * an input device is opened, to choose that device's associated display id.
+     */
+    public void addDisplayIdMapping(String uniqueId, int displayId) {
+        mDisplayIdMapping.put(uniqueId, displayId);
+    }
+
     private long handleNativeOpenInputDevice(InvocationOnMock inv) {
         Objects.requireNonNull(mDevicesChangedListener,
                 "InputController did not register an InputDevicesChangedListener.");
 
+        // Resolve phys -> display unique id -> display id, falling back to
+        // INVALID_DISPLAY when no mapping is found.
+        final String phys = inv.getArgument(3);
+        final String associatedUniqueId = mUniqueIdAssociation.get(phys);
+        final int associatedDisplayId =
+                mDisplayIdMapping.getOrDefault(associatedUniqueId, Display.INVALID_DISPLAY);
         final InputDevice device = new InputDevice.Builder()
                 .setId(mDevices.size())
                 .setName(inv.getArgument(0))
                 .setVendorId(inv.getArgument(1))
                 .setProductId(inv.getArgument(2))
-                .setDescriptor(inv.getArgument(3))
+                .setDescriptor(phys)
                 .setExternal(true)
+                .setAssociatedDisplayId(associatedDisplayId)
                 .build();
 
         mDevices.add(device);
diff --git a/services/tests/servicestests/src/com/android/server/companion/virtual/VirtualDeviceManagerServiceTest.java b/services/tests/servicestests/src/com/android/server/companion/virtual/VirtualDeviceManagerServiceTest.java
index 5442af8..157e893 100644
--- a/services/tests/servicestests/src/com/android/server/companion/virtual/VirtualDeviceManagerServiceTest.java
+++ b/services/tests/servicestests/src/com/android/server/companion/virtual/VirtualDeviceManagerServiceTest.java
@@ -39,6 +39,7 @@
 import static org.mockito.ArgumentMatchers.nullable;
 import static org.mockito.Mockito.argThat;
 import static org.mockito.Mockito.clearInvocations;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
@@ -2015,6 +2016,13 @@
                 eq(virtualDevice), any(), any())).thenReturn(displayId);
         virtualDevice.createVirtualDisplay(VIRTUAL_DISPLAY_CONFIG, mVirtualDisplayCallback,
                 NONBLOCKED_APP_PACKAGE_NAME);
+        final String uniqueId = UNIQUE_ID + displayId;
+        doAnswer(inv -> {
+            final DisplayInfo displayInfo = new DisplayInfo();
+            displayInfo.uniqueId = uniqueId;
+            return displayInfo;
+        }).when(mDisplayManagerInternalMock).getDisplayInfo(eq(displayId));
+        mInputManagerMockHelper.addDisplayIdMapping(uniqueId, displayId);
     }
 
     private ComponentName getPermissionDialogComponent() {