Only register VirtualDeviceListener when a11y proxies are registered

Immediately registering VirtualDeviceListener during the
AccessibilityManagerService start can lead to issues as this depends on
the timing of VirtualDeviceManagerService starting.

Instead, only register VirtualDeviceListener when a a11y proxy is
registered.

Bug: 302519290
Test: atest ProxyManagerTest AccessibilityDisplayProxyTest
Change-Id: I80b5213107107de358adcf5d8193fa438ffe55c0
diff --git a/services/accessibility/java/com/android/server/accessibility/AccessibilityManagerService.java b/services/accessibility/java/com/android/server/accessibility/AccessibilityManagerService.java
index 87f9cf1..d575102 100644
--- a/services/accessibility/java/com/android/server/accessibility/AccessibilityManagerService.java
+++ b/services/accessibility/java/com/android/server/accessibility/AccessibilityManagerService.java
@@ -64,7 +64,6 @@
 import android.app.RemoteAction;
 import android.app.admin.DevicePolicyManager;
 import android.appwidget.AppWidgetManagerInternal;
-import android.companion.virtual.VirtualDeviceManager;
 import android.content.ActivityNotFoundException;
 import android.content.BroadcastReceiver;
 import android.content.ComponentName;
@@ -1071,18 +1070,7 @@
         mContext.registerReceiverAsUser(receiver, UserHandle.ALL, filter, null, mMainHandler,
                 Context.RECEIVER_EXPORTED);
 
-        if (android.companion.virtual.flags.Flags.vdmPublicApis()) {
-            VirtualDeviceManager vdm = mContext.getSystemService(VirtualDeviceManager.class);
-            if (vdm != null) {
-                vdm.registerVirtualDeviceListener(mContext.getMainExecutor(),
-                        new VirtualDeviceManager.VirtualDeviceListener() {
-                            @Override
-                            public void onVirtualDeviceClosed(int deviceId) {
-                                mProxyManager.clearConnections(deviceId);
-                            }
-                        });
-            }
-        } else {
+        if (!android.companion.virtual.flags.Flags.vdmPublicApis()) {
             final BroadcastReceiver virtualDeviceReceiver = new BroadcastReceiver() {
                 @Override
                 public void onReceive(Context context, Intent intent) {
diff --git a/services/accessibility/java/com/android/server/accessibility/ProxyManager.java b/services/accessibility/java/com/android/server/accessibility/ProxyManager.java
index ed77476..2032a50 100644
--- a/services/accessibility/java/com/android/server/accessibility/ProxyManager.java
+++ b/services/accessibility/java/com/android/server/accessibility/ProxyManager.java
@@ -101,6 +101,8 @@
     private VirtualDeviceManagerInternal.AppsOnVirtualDeviceListener
             mAppsOnVirtualDeviceListener;
 
+    private VirtualDeviceManager.VirtualDeviceListener mVirtualDeviceListener;
+
     /**
      * Callbacks into AccessibilityManagerService.
      */
@@ -189,6 +191,9 @@
                     }
                 }
             }
+            if (mProxyA11yServiceConnections.size() == 1) {
+                registerVirtualDeviceListener();
+            }
         }
 
         // If the client dies, make sure to remove the connection.
@@ -210,6 +215,31 @@
         connection.initializeServiceInterface(client);
     }
 
+    private void registerVirtualDeviceListener() {
+        VirtualDeviceManager vdm = mContext.getSystemService(VirtualDeviceManager.class);
+        if (vdm == null || !android.companion.virtual.flags.Flags.vdmPublicApis()) {
+            return;
+        }
+        if (mVirtualDeviceListener == null) {
+            mVirtualDeviceListener = new VirtualDeviceManager.VirtualDeviceListener() {
+                @Override
+                public void onVirtualDeviceClosed(int deviceId) {
+                    clearConnections(deviceId);
+                }
+            };
+        }
+
+        vdm.registerVirtualDeviceListener(mContext.getMainExecutor(), mVirtualDeviceListener);
+    }
+
+    private void unregisterVirtualDeviceListener() {
+        VirtualDeviceManager vdm = mContext.getSystemService(VirtualDeviceManager.class);
+        if (vdm == null || !android.companion.virtual.flags.Flags.vdmPublicApis()) {
+            return;
+        }
+        vdm.unregisterVirtualDeviceListener(mVirtualDeviceListener);
+    }
+
     /**
      * Unregister the proxy based on display id.
      */
@@ -258,6 +288,9 @@
                 deviceId = mProxyA11yServiceConnections.get(displayId).getDeviceId();
                 mProxyA11yServiceConnections.remove(displayId);
                 removedFromConnections = true;
+                if (mProxyA11yServiceConnections.size() == 0) {
+                    unregisterVirtualDeviceListener();
+                }
             }
         }
 
diff --git a/services/tests/servicestests/src/com/android/server/accessibility/ProxyManagerTest.java b/services/tests/servicestests/src/com/android/server/accessibility/ProxyManagerTest.java
index 3808f30..bfaf4959 100644
--- a/services/tests/servicestests/src/com/android/server/accessibility/ProxyManagerTest.java
+++ b/services/tests/servicestests/src/com/android/server/accessibility/ProxyManagerTest.java
@@ -29,6 +29,8 @@
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.timeout;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
@@ -38,6 +40,7 @@
 import android.accessibilityservice.IAccessibilityServiceClient;
 import android.accessibilityservice.IAccessibilityServiceConnection;
 import android.accessibilityservice.MagnificationConfig;
+import android.companion.virtual.IVirtualDeviceListener;
 import android.companion.virtual.IVirtualDeviceManager;
 import android.companion.virtual.VirtualDeviceManager;
 import android.content.ComponentName;
@@ -50,6 +53,7 @@
 import android.platform.test.annotations.RequiresFlagsEnabled;
 import android.platform.test.flag.junit.CheckFlagsRule;
 import android.platform.test.flag.junit.DeviceFlagsValueProvider;
+import android.platform.test.flag.junit.SetFlagsRule;
 import android.util.ArraySet;
 import android.view.KeyEvent;
 import android.view.MotionEvent;
@@ -74,6 +78,7 @@
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
@@ -94,6 +99,9 @@
     @Rule
     public final CheckFlagsRule mCheckFlagsRule = DeviceFlagsValueProvider.createCheckFlagsRule();
 
+    @Rule
+    public SetFlagsRule mSetFlagsRule = new SetFlagsRule();
+
     @Mock private Context mMockContext;
     @Mock private AccessibilitySecurityPolicy mMockSecurityPolicy;
     @Mock private AccessibilityWindowManager mMockA11yWindowManager;
@@ -114,6 +122,8 @@
 
     @Before
     public void setup() throws RemoteException {
+        mSetFlagsRule.initAllFlagsToReleaseConfigDefault();
+
         MockitoAnnotations.initMocks(this);
         final Resources resources = InstrumentationRegistry.getContext().getResources();
 
@@ -121,6 +131,8 @@
                 resources.getDimensionPixelSize(R.dimen.accessibility_focus_highlight_stroke_width);
         mFocusColorDefaultValue = resources.getColor(R.color.accessibility_focus_highlight_color);
         when(mMockContext.getResources()).thenReturn(resources);
+        when(mMockContext.getMainExecutor())
+                .thenReturn(InstrumentationRegistry.getTargetContext().getMainExecutor());
 
         when(mMockVirtualDeviceManagerInternal.getDeviceIdsForUid(anyInt())).thenReturn(
                 new ArraySet(Set.of(DEVICE_ID)));
@@ -416,6 +428,101 @@
         assertThat(focusStrokeWidth).isEqualTo(mFocusStrokeWidthDefaultValue);
     }
 
+    @Test
+    public void testRegisterProxy_registersVirtualDeviceListener() throws RemoteException {
+        mSetFlagsRule.enableFlags(android.companion.virtual.flags.Flags.FLAG_VDM_PUBLIC_APIS);
+        registerProxy(DISPLAY_ID);
+
+        verify(mMockIVirtualDeviceManager, times(1)).registerVirtualDeviceListener(any());
+    }
+
+    @Test
+    public void testRegisterMultipleProxies_registersOneVirtualDeviceListener()
+            throws RemoteException {
+        mSetFlagsRule.enableFlags(android.companion.virtual.flags.Flags.FLAG_VDM_PUBLIC_APIS);
+        registerProxy(DISPLAY_ID);
+        registerProxy(DISPLAY_2_ID);
+
+        verify(mMockIVirtualDeviceManager, times(1)).registerVirtualDeviceListener(any());
+    }
+
+    @Test
+    public void testUnregisterProxy_unregistersVirtualDeviceListener() throws RemoteException {
+        mSetFlagsRule.enableFlags(android.companion.virtual.flags.Flags.FLAG_VDM_PUBLIC_APIS);
+        registerProxy(DISPLAY_ID);
+
+        mProxyManager.unregisterProxy(DISPLAY_ID);
+
+        verify(mMockIVirtualDeviceManager, times(1)).unregisterVirtualDeviceListener(any());
+    }
+
+    @Test
+    public void testUnregisterProxy_onlyUnregistersVirtualDeviceListenerOnLastProxyRemoval()
+            throws RemoteException {
+        mSetFlagsRule.enableFlags(android.companion.virtual.flags.Flags.FLAG_VDM_PUBLIC_APIS);
+        registerProxy(DISPLAY_ID);
+        registerProxy(DISPLAY_2_ID);
+
+        mProxyManager.unregisterProxy(DISPLAY_ID);
+        verify(mMockIVirtualDeviceManager, never()).unregisterVirtualDeviceListener(any());
+
+        mProxyManager.unregisterProxy(DISPLAY_2_ID);
+        verify(mMockIVirtualDeviceManager, times(1)).unregisterVirtualDeviceListener(any());
+    }
+
+    @Test
+    public void testRegisteredProxy_virtualDeviceClosed_proxyClosed()
+            throws RemoteException {
+        mSetFlagsRule.enableFlags(android.companion.virtual.flags.Flags.FLAG_VDM_PUBLIC_APIS);
+        registerProxy(DISPLAY_ID);
+
+        assertThat(mProxyManager.isProxyedDeviceId(DEVICE_ID)).isTrue();
+        assertThat(mProxyManager.isProxyedDisplay(DISPLAY_ID)).isTrue();
+
+        ArgumentCaptor<IVirtualDeviceListener> listenerArgumentCaptor =
+                ArgumentCaptor.forClass(IVirtualDeviceListener.class);
+        verify(mMockIVirtualDeviceManager, times(1))
+                .registerVirtualDeviceListener(listenerArgumentCaptor.capture());
+
+        listenerArgumentCaptor.getValue().onVirtualDeviceClosed(DEVICE_ID);
+
+        verify(mMockProxySystemSupport, timeout(5_000)).removeDeviceIdLocked(DEVICE_ID);
+
+        assertThat(mProxyManager.isProxyedDeviceId(DEVICE_ID)).isFalse();
+        assertThat(mProxyManager.isProxyedDisplay(DISPLAY_ID)).isFalse();
+    }
+
+    @Test
+    public void testRegisteredProxy_unrelatedVirtualDeviceClosed_proxyNotClosed()
+            throws RemoteException {
+        mSetFlagsRule.enableFlags(android.companion.virtual.flags.Flags.FLAG_VDM_PUBLIC_APIS);
+        registerProxy(DISPLAY_ID);
+
+        assertThat(mProxyManager.isProxyedDeviceId(DEVICE_ID)).isTrue();
+        assertThat(mProxyManager.isProxyedDisplay(DISPLAY_ID)).isTrue();
+
+        ArgumentCaptor<IVirtualDeviceListener> listenerArgumentCaptor =
+                ArgumentCaptor.forClass(IVirtualDeviceListener.class);
+        verify(mMockIVirtualDeviceManager, times(1))
+                .registerVirtualDeviceListener(listenerArgumentCaptor.capture());
+
+        listenerArgumentCaptor.getValue().onVirtualDeviceClosed(DEVICE_ID + 1);
+
+        assertThat(mProxyManager.isProxyedDeviceId(DEVICE_ID)).isTrue();
+        assertThat(mProxyManager.isProxyedDisplay(DISPLAY_ID)).isTrue();
+    }
+
+    @Test
+    public void testRegisterProxy_doesNotRegisterVirtualDeviceListener_flagDisabled()
+            throws RemoteException {
+        mSetFlagsRule.disableFlags(android.companion.virtual.flags.Flags.FLAG_VDM_PUBLIC_APIS);
+        registerProxy(DISPLAY_ID);
+        mProxyManager.unregisterProxy(DISPLAY_ID);
+
+        verify(mMockIVirtualDeviceManager, never()).registerVirtualDeviceListener(any());
+        verify(mMockIVirtualDeviceManager, never()).unregisterVirtualDeviceListener(any());
+    }
+
     private void registerProxy(int displayId) {
         try {
             mProxyManager.registerProxy(mMockAccessibilityServiceClient, displayId, anyInt(),