Merge "Only register VirtualDeviceListener when a11y proxies are registered" into main
diff --git a/services/accessibility/java/com/android/server/accessibility/AccessibilityManagerService.java b/services/accessibility/java/com/android/server/accessibility/AccessibilityManagerService.java
index 87f9cf1..d575102 100644
--- a/services/accessibility/java/com/android/server/accessibility/AccessibilityManagerService.java
+++ b/services/accessibility/java/com/android/server/accessibility/AccessibilityManagerService.java
@@ -64,7 +64,6 @@
 import android.app.RemoteAction;
 import android.app.admin.DevicePolicyManager;
 import android.appwidget.AppWidgetManagerInternal;
-import android.companion.virtual.VirtualDeviceManager;
 import android.content.ActivityNotFoundException;
 import android.content.BroadcastReceiver;
 import android.content.ComponentName;
@@ -1071,18 +1070,7 @@
         mContext.registerReceiverAsUser(receiver, UserHandle.ALL, filter, null, mMainHandler,
                 Context.RECEIVER_EXPORTED);
 
-        if (android.companion.virtual.flags.Flags.vdmPublicApis()) {
-            VirtualDeviceManager vdm = mContext.getSystemService(VirtualDeviceManager.class);
-            if (vdm != null) {
-                vdm.registerVirtualDeviceListener(mContext.getMainExecutor(),
-                        new VirtualDeviceManager.VirtualDeviceListener() {
-                            @Override
-                            public void onVirtualDeviceClosed(int deviceId) {
-                                mProxyManager.clearConnections(deviceId);
-                            }
-                        });
-            }
-        } else {
+        if (!android.companion.virtual.flags.Flags.vdmPublicApis()) {
             final BroadcastReceiver virtualDeviceReceiver = new BroadcastReceiver() {
                 @Override
                 public void onReceive(Context context, Intent intent) {
diff --git a/services/accessibility/java/com/android/server/accessibility/ProxyManager.java b/services/accessibility/java/com/android/server/accessibility/ProxyManager.java
index ed77476..2032a50 100644
--- a/services/accessibility/java/com/android/server/accessibility/ProxyManager.java
+++ b/services/accessibility/java/com/android/server/accessibility/ProxyManager.java
@@ -101,6 +101,8 @@
     private VirtualDeviceManagerInternal.AppsOnVirtualDeviceListener
             mAppsOnVirtualDeviceListener;
 
+    private VirtualDeviceManager.VirtualDeviceListener mVirtualDeviceListener;
+
     /**
      * Callbacks into AccessibilityManagerService.
      */
@@ -189,6 +191,9 @@
                     }
                 }
             }
+            if (mProxyA11yServiceConnections.size() == 1) {
+                registerVirtualDeviceListener();
+            }
         }
 
         // If the client dies, make sure to remove the connection.
@@ -210,6 +215,47 @@
         connection.initializeServiceInterface(client);
     }
 
+    /**
+     * Registers {@link #mVirtualDeviceListener} with {@link VirtualDeviceManager} so that a
+     * virtual device's proxy connections are cleared when that device is closed.
+     *
+     * <p>Called when the first proxy connection is added. Does nothing if the VDM public APIs
+     * flag is disabled or {@link VirtualDeviceManager} is unavailable.
+     */
+    private void registerVirtualDeviceListener() {
+        VirtualDeviceManager vdm = mContext.getSystemService(VirtualDeviceManager.class);
+        if (vdm == null || !android.companion.virtual.flags.Flags.vdmPublicApis()) {
+            return;
+        }
+        if (mVirtualDeviceListener == null) {
+            mVirtualDeviceListener = new VirtualDeviceManager.VirtualDeviceListener() {
+                @Override
+                public void onVirtualDeviceClosed(int deviceId) {
+                    clearConnections(deviceId);
+                }
+            };
+        }
+
+        vdm.registerVirtualDeviceListener(mContext.getMainExecutor(), mVirtualDeviceListener);
+    }
+
+    /**
+     * Unregisters the listener added by {@link #registerVirtualDeviceListener()}.
+     *
+     * <p>Called when the last proxy connection is removed. Does nothing if no listener was ever
+     * created (the flag was disabled or {@link VirtualDeviceManager} was unavailable when the
+     * first proxy was registered), so a null listener is never passed to VirtualDeviceManager.
+     */
+    private void unregisterVirtualDeviceListener() {
+        VirtualDeviceManager vdm = mContext.getSystemService(VirtualDeviceManager.class);
+        // The listener is only created once registerVirtualDeviceListener() has passed its flag
+        // check, so its presence (not the flag's current value) decides whether to unregister.
+        if (vdm == null || mVirtualDeviceListener == null) {
+            return;
+        }
+        vdm.unregisterVirtualDeviceListener(mVirtualDeviceListener);
+    }
+
     /**
      * Unregister the proxy based on display id.
      */
@@ -258,6 +304,9 @@
                 deviceId = mProxyA11yServiceConnections.get(displayId).getDeviceId();
                 mProxyA11yServiceConnections.remove(displayId);
                 removedFromConnections = true;
+                if (mProxyA11yServiceConnections.size() == 0) {
+                    unregisterVirtualDeviceListener();
+                }
             }
         }
 
diff --git a/services/tests/servicestests/src/com/android/server/accessibility/ProxyManagerTest.java b/services/tests/servicestests/src/com/android/server/accessibility/ProxyManagerTest.java
index 3808f30..bfaf4959 100644
--- a/services/tests/servicestests/src/com/android/server/accessibility/ProxyManagerTest.java
+++ b/services/tests/servicestests/src/com/android/server/accessibility/ProxyManagerTest.java
@@ -29,6 +29,8 @@
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.timeout;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
@@ -38,6 +40,7 @@
 import android.accessibilityservice.IAccessibilityServiceClient;
 import android.accessibilityservice.IAccessibilityServiceConnection;
 import android.accessibilityservice.MagnificationConfig;
+import android.companion.virtual.IVirtualDeviceListener;
 import android.companion.virtual.IVirtualDeviceManager;
 import android.companion.virtual.VirtualDeviceManager;
 import android.content.ComponentName;
@@ -50,6 +53,7 @@
 import android.platform.test.annotations.RequiresFlagsEnabled;
 import android.platform.test.flag.junit.CheckFlagsRule;
 import android.platform.test.flag.junit.DeviceFlagsValueProvider;
+import android.platform.test.flag.junit.SetFlagsRule;
 import android.util.ArraySet;
 import android.view.KeyEvent;
 import android.view.MotionEvent;
@@ -74,6 +78,7 @@
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
@@ -94,6 +99,9 @@
     @Rule
     public final CheckFlagsRule mCheckFlagsRule = DeviceFlagsValueProvider.createCheckFlagsRule();
 
+    @Rule
+    public SetFlagsRule mSetFlagsRule = new SetFlagsRule();
+
     @Mock private Context mMockContext;
     @Mock private AccessibilitySecurityPolicy mMockSecurityPolicy;
     @Mock private AccessibilityWindowManager mMockA11yWindowManager;
@@ -114,6 +122,8 @@
 
     @Before
     public void setup() throws RemoteException {
+        mSetFlagsRule.initAllFlagsToReleaseConfigDefault();
+
         MockitoAnnotations.initMocks(this);
         final Resources resources = InstrumentationRegistry.getContext().getResources();
 
@@ -121,6 +131,8 @@
                 resources.getDimensionPixelSize(R.dimen.accessibility_focus_highlight_stroke_width);
         mFocusColorDefaultValue = resources.getColor(R.color.accessibility_focus_highlight_color);
         when(mMockContext.getResources()).thenReturn(resources);
+        when(mMockContext.getMainExecutor())
+                .thenReturn(InstrumentationRegistry.getTargetContext().getMainExecutor());
 
         when(mMockVirtualDeviceManagerInternal.getDeviceIdsForUid(anyInt())).thenReturn(
                 new ArraySet(Set.of(DEVICE_ID)));
@@ -416,6 +428,108 @@
         assertThat(focusStrokeWidth).isEqualTo(mFocusStrokeWidthDefaultValue);
     }
 
+    @Test
+    public void testRegisterProxy_registersVirtualDeviceListener() throws RemoteException {
+        mSetFlagsRule.enableFlags(android.companion.virtual.flags.Flags.FLAG_VDM_PUBLIC_APIS);
+        registerProxy(DISPLAY_ID);
+
+        verify(mMockIVirtualDeviceManager, times(1)).registerVirtualDeviceListener(any());
+    }
+
+    @Test
+    public void testRegisterMultipleProxies_registersOneVirtualDeviceListener()
+            throws RemoteException {
+        mSetFlagsRule.enableFlags(android.companion.virtual.flags.Flags.FLAG_VDM_PUBLIC_APIS);
+        registerProxy(DISPLAY_ID);
+        registerProxy(DISPLAY_2_ID);
+
+        verify(mMockIVirtualDeviceManager, times(1)).registerVirtualDeviceListener(any());
+    }
+
+    @Test
+    public void testUnregisterProxy_unregistersVirtualDeviceListener() throws RemoteException {
+        mSetFlagsRule.enableFlags(android.companion.virtual.flags.Flags.FLAG_VDM_PUBLIC_APIS);
+        registerProxy(DISPLAY_ID);
+
+        mProxyManager.unregisterProxy(DISPLAY_ID);
+
+        verify(mMockIVirtualDeviceManager, times(1)).unregisterVirtualDeviceListener(any());
+    }
+
+    @Test
+    public void testUnregisterProxy_onlyUnregistersVirtualDeviceListenerOnLastProxyRemoval()
+            throws RemoteException {
+        mSetFlagsRule.enableFlags(android.companion.virtual.flags.Flags.FLAG_VDM_PUBLIC_APIS);
+        registerProxy(DISPLAY_ID);
+        registerProxy(DISPLAY_2_ID);
+
+        mProxyManager.unregisterProxy(DISPLAY_ID);
+        verify(mMockIVirtualDeviceManager, never()).unregisterVirtualDeviceListener(any());
+
+        mProxyManager.unregisterProxy(DISPLAY_2_ID);
+        verify(mMockIVirtualDeviceManager, times(1)).unregisterVirtualDeviceListener(any());
+    }
+
+    @Test
+    public void testRegisteredProxy_virtualDeviceClosed_proxyClosed()
+            throws RemoteException {
+        mSetFlagsRule.enableFlags(android.companion.virtual.flags.Flags.FLAG_VDM_PUBLIC_APIS);
+        registerProxy(DISPLAY_ID);
+
+        assertThat(mProxyManager.isProxyedDeviceId(DEVICE_ID)).isTrue();
+        assertThat(mProxyManager.isProxyedDisplay(DISPLAY_ID)).isTrue();
+
+        ArgumentCaptor<IVirtualDeviceListener> listenerArgumentCaptor =
+                ArgumentCaptor.forClass(IVirtualDeviceListener.class);
+        verify(mMockIVirtualDeviceManager, times(1))
+                .registerVirtualDeviceListener(listenerArgumentCaptor.capture());
+
+        listenerArgumentCaptor.getValue().onVirtualDeviceClosed(DEVICE_ID);
+        // The callback is dispatched on the main executor; wait for it to finish completely
+        // before checking the proxy state.
+        InstrumentationRegistry.getInstrumentation().waitForIdleSync();
+
+        verify(mMockProxySystemSupport, timeout(5_000)).removeDeviceIdLocked(DEVICE_ID);
+
+        assertThat(mProxyManager.isProxyedDeviceId(DEVICE_ID)).isFalse();
+        assertThat(mProxyManager.isProxyedDisplay(DISPLAY_ID)).isFalse();
+    }
+
+    @Test
+    public void testRegisteredProxy_unrelatedVirtualDeviceClosed_proxyNotClosed()
+            throws RemoteException {
+        mSetFlagsRule.enableFlags(android.companion.virtual.flags.Flags.FLAG_VDM_PUBLIC_APIS);
+        registerProxy(DISPLAY_ID);
+
+        assertThat(mProxyManager.isProxyedDeviceId(DEVICE_ID)).isTrue();
+        assertThat(mProxyManager.isProxyedDisplay(DISPLAY_ID)).isTrue();
+
+        ArgumentCaptor<IVirtualDeviceListener> listenerArgumentCaptor =
+                ArgumentCaptor.forClass(IVirtualDeviceListener.class);
+        verify(mMockIVirtualDeviceManager, times(1))
+                .registerVirtualDeviceListener(listenerArgumentCaptor.capture());
+
+        listenerArgumentCaptor.getValue().onVirtualDeviceClosed(DEVICE_ID + 1);
+        // The callback runs asynchronously on the main executor; drain it first, otherwise the
+        // assertions below could pass before the callback has even run.
+        InstrumentationRegistry.getInstrumentation().waitForIdleSync();
+
+        verify(mMockProxySystemSupport, never()).removeDeviceIdLocked(DEVICE_ID);
+        assertThat(mProxyManager.isProxyedDeviceId(DEVICE_ID)).isTrue();
+        assertThat(mProxyManager.isProxyedDisplay(DISPLAY_ID)).isTrue();
+    }
+
+    @Test
+    public void testRegisterProxy_doesNotRegisterVirtualDeviceListener_flagDisabled()
+            throws RemoteException {
+        mSetFlagsRule.disableFlags(android.companion.virtual.flags.Flags.FLAG_VDM_PUBLIC_APIS);
+        registerProxy(DISPLAY_ID);
+        mProxyManager.unregisterProxy(DISPLAY_ID);
+
+        verify(mMockIVirtualDeviceManager, never()).registerVirtualDeviceListener(any());
+        verify(mMockIVirtualDeviceManager, never()).unregisterVirtualDeviceListener(any());
+    }
+
     private void registerProxy(int displayId) {
         try {
             mProxyManager.registerProxy(mMockAccessibilityServiceClient, displayId, anyInt(),