Merge "Keep VirtualDevice-created VirtualDisplays awake"
diff --git a/services/companion/java/com/android/server/companion/virtual/VirtualDeviceImpl.java b/services/companion/java/com/android/server/companion/virtual/VirtualDeviceImpl.java
index 5c8fb2e..95b9e58 100644
--- a/services/companion/java/com/android/server/companion/virtual/VirtualDeviceImpl.java
+++ b/services/companion/java/com/android/server/companion/virtual/VirtualDeviceImpl.java
@@ -45,10 +45,12 @@
 import android.hardware.input.VirtualTouchEvent;
 import android.os.Binder;
 import android.os.IBinder;
+import android.os.PowerManager;
 import android.os.RemoteException;
 import android.os.ResultReceiver;
 import android.os.UserHandle;
 import android.os.UserManager;
+import android.util.ArrayMap;
 import android.util.ArraySet;
 import android.util.Slog;
 import android.util.SparseArray;
@@ -59,6 +61,7 @@
 
 import java.io.FileDescriptor;
 import java.io.PrintWriter;
+import java.util.Map;
 import java.util.Set;
 
 
@@ -78,6 +81,9 @@
     private final OnDeviceCloseListener mListener;
     private final IBinder mAppToken;
     private final VirtualDeviceParams mParams;
+    // Wakelocks keeping this device's VirtualDisplays awake, keyed by display ID. Every access
+    // synchronizes on mVirtualDeviceLock.
+    private final Map<Integer, PowerManager.WakeLock> mPerDisplayWakelocks = new ArrayMap<>();
     private final IVirtualDeviceActivityListener mActivityListener;
 
     private ActivityListener createListenerAdapter(int displayId) {
@@ -206,6 +210,16 @@
 
     @Override // Binder call
     public void close() {
+        synchronized (mVirtualDeviceLock) {
+            // Release wakelocks of VirtualDisplays that were never reported as removed.
+            for (Map.Entry<Integer, PowerManager.WakeLock> entry
+                    : mPerDisplayWakelocks.entrySet()) {
+                Slog.w(TAG, "VirtualDisplay " + entry.getKey() + " owned by UID " + mOwnerUid
+                        + " was not properly released");
+                entry.getValue().release();
+            }
+            mPerDisplayWakelocks.clear();
+        }
         mListener.onClose(mAssociationInfo.getId());
         mAppToken.unlinkToDeath(this, 0);
         mInputController.close();
@@ -383,22 +397,61 @@
     }
 
+    /**
+     * Registers a VirtualDisplay created for this device, and creates and returns the window
+     * policy controller for it. A wakelock keeping the display awake is acquired
+     * asynchronously on the main thread by {@link #addWakeLockForDisplay(int)}.
+     *
+     * @throws IllegalStateException if {@code displayId} is already registered with this device
+     */
     DisplayWindowPolicyController onVirtualDisplayCreatedLocked(int displayId) {
-        if (mVirtualDisplayIds.contains(displayId)) {
-            throw new IllegalStateException(
-                    "Virtual device already have a virtual display with ID " + displayId);
+        synchronized (mVirtualDeviceLock) {
+            if (mVirtualDisplayIds.contains(displayId)) {
+                throw new IllegalStateException(
+                        "Virtual device already has a virtual display with ID " + displayId);
+            }
+            mVirtualDisplayIds.add(displayId);
+
+            // Since we're being called in the middle of the display being created, we post a
+            // task to grab the wakelock instead of doing it synchronously here, to avoid
+            // reentrancy problems. addWakeLockForDisplay() re-checks that the display is still
+            // registered, so a display removed before the task runs never gets a wakelock.
+            mContext.getMainThreadHandler().post(() -> addWakeLockForDisplay(displayId));
+
+            LocalServices.getService(
+                    InputManagerInternal.class).setDisplayEligibilityForPointerCapture(displayId,
+                    false);
+            final GenericWindowPolicyController dwpc =
+                    new GenericWindowPolicyController(FLAG_SECURE,
+                            SYSTEM_FLAG_HIDE_NON_SYSTEM_OVERLAY_WINDOWS,
+                            getAllowedUserHandles(),
+                            mParams.getAllowedActivities(),
+                            mParams.getBlockedActivities(),
+                            createListenerAdapter(displayId));
+            mWindowPolicyControllers.put(displayId, dwpc);
+            return dwpc;
         }
-        mVirtualDisplayIds.add(displayId);
-        LocalServices.getService(InputManagerInternal.class).setDisplayEligibilityForPointerCapture(
-                displayId, false);
-        final GenericWindowPolicyController dwpc =
-                new GenericWindowPolicyController(FLAG_SECURE,
-                        SYSTEM_FLAG_HIDE_NON_SYSTEM_OVERLAY_WINDOWS,
-                        getAllowedUserHandles(),
-                        mParams.getAllowedActivities(),
-                        mParams.getBlockedActivities(),
-                        createListenerAdapter(displayId));
-        mWindowPolicyControllers.put(displayId, dwpc);
-        return dwpc;
+    }
+
+    /**
+     * Acquires a wakelock that keeps {@code displayId} awake, unless the display is no longer
+     * registered with this device or already holds a wakelock. Posted to the main thread by
+     * {@link #onVirtualDisplayCreatedLocked(int)}.
+     */
+    void addWakeLockForDisplay(int displayId) {
+        synchronized (mVirtualDeviceLock) {
+            if (!mVirtualDisplayIds.contains(displayId)
+                    || mPerDisplayWakelocks.containsKey(displayId)) {
+                Slog.e(TAG, "Not creating wakelock for displayId " + displayId);
+                return;
+            }
+            PowerManager powerManager = mContext.getSystemService(PowerManager.class);
+            PowerManager.WakeLock wakeLock = powerManager.newWakeLock(
+                    PowerManager.SCREEN_BRIGHT_WAKE_LOCK
+                            | PowerManager.ACQUIRE_CAUSES_WAKEUP,
+                    TAG + ":" + displayId, displayId);
+            wakeLock.acquire();
+            mPerDisplayWakelocks.put(displayId, wakeLock);
+        }
     }
 
     private ArraySet<UserHandle> getAllowedUserHandles() {
@@ -420,14 +460,29 @@
     }
 
+    /**
+     * Unregisters a VirtualDisplay of this device and releases its wakelock, if one has been
+     * acquired.
+     *
+     * @throws IllegalStateException if {@code displayId} is not registered with this device
+     */
     void onVirtualDisplayRemovedLocked(int displayId) {
-        if (!mVirtualDisplayIds.contains(displayId)) {
-            throw new IllegalStateException(
-                    "Virtual device doesn't have a virtual display with ID " + displayId);
+        synchronized (mVirtualDeviceLock) {
+            if (!mVirtualDisplayIds.contains(displayId)) {
+                throw new IllegalStateException(
+                        "Virtual device doesn't have a virtual display with ID " + displayId);
+            }
+            // The wakelock is absent if the display is removed before the task posted by
+            // onVirtualDisplayCreatedLocked() has acquired it.
+            final PowerManager.WakeLock wakeLock = mPerDisplayWakelocks.remove(displayId);
+            if (wakeLock != null) {
+                wakeLock.release();
+            }
+            mVirtualDisplayIds.remove(displayId);
+            LocalServices.getService(
+                    InputManagerInternal.class).setDisplayEligibilityForPointerCapture(
+                    displayId, true);
+            mWindowPolicyControllers.remove(displayId);
         }
-        mVirtualDisplayIds.remove(displayId);
-        LocalServices.getService(InputManagerInternal.class).setDisplayEligibilityForPointerCapture(
-                displayId, true);
-        mWindowPolicyControllers.remove(displayId);
     }
 
     int getOwnerUid() {
diff --git a/services/tests/servicestests/src/com/android/server/companion/virtual/VirtualDeviceManagerServiceTest.java b/services/tests/servicestests/src/com/android/server/companion/virtual/VirtualDeviceManagerServiceTest.java
index 72100e44..e36263e 100644
--- a/services/tests/servicestests/src/com/android/server/companion/virtual/VirtualDeviceManagerServiceTest.java
+++ b/services/tests/servicestests/src/com/android/server/companion/virtual/VirtualDeviceManagerServiceTest.java
@@ -18,17 +18,21 @@
 
 import static com.google.common.truth.Truth.assertWithMessage;
 
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.ArgumentMatchers.nullable;
 import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 import static org.testng.Assert.assertThrows;
 
 import android.Manifest;
 import android.app.admin.DevicePolicyManager;
+import android.companion.AssociationInfo;
 import android.companion.virtual.IVirtualDeviceActivityListener;
 import android.companion.virtual.VirtualDeviceParams;
 import android.content.Context;
@@ -41,24 +45,35 @@
 import android.hardware.input.VirtualMouseRelativeEvent;
 import android.hardware.input.VirtualMouseScrollEvent;
 import android.hardware.input.VirtualTouchEvent;
+import android.net.MacAddress;
 import android.os.Binder;
+import android.os.Handler;
+import android.os.IBinder;
+import android.os.IPowerManager;
+import android.os.IThermalService;
+import android.os.PowerManager;
+import android.os.RemoteException;
+import android.os.WorkSource;
 import android.platform.test.annotations.Presubmit;
+import android.testing.AndroidTestingRunner;
+import android.testing.TestableLooper;
 import android.view.KeyEvent;
 
 import androidx.test.InstrumentationRegistry;
-import androidx.test.runner.AndroidJUnit4;
 
 import com.android.server.LocalServices;
 
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
 
 @Presubmit
-@RunWith(AndroidJUnit4.class)
+@RunWith(AndroidTestingRunner.class)
+@TestableLooper.RunWithLooper(setAsMainLooper = true)
 public class VirtualDeviceManagerServiceTest {
 
     private static final String DEVICE_NAME = "device name";
@@ -84,6 +99,11 @@
     private InputManagerInternal mInputManagerInternalMock;
     @Mock
     private IVirtualDeviceActivityListener mActivityListener;
+    @Mock
+    private IPowerManager mIPowerManagerMock;
+    @Mock
+    private IThermalService mIThermalServiceMock;
+    private PowerManager mPowerManager;
 
     @Before
     public void setUp() {
@@ -102,10 +122,17 @@
         when(mContext.getSystemService(Context.DEVICE_POLICY_SERVICE)).thenReturn(
                 mDevicePolicyManagerMock);
 
+        mPowerManager = new PowerManager(mContext, mIPowerManagerMock, mIThermalServiceMock,
+                new Handler(TestableLooper.get(this).getLooper()));
+        when(mContext.getSystemService(Context.POWER_SERVICE)).thenReturn(mPowerManager);
+
         mInputController = new InputController(new Object(), mNativeWrapperMock);
+        // close() calls mAssociationInfo.getId(), so a non-null AssociationInfo is required.
+        AssociationInfo associationInfo = new AssociationInfo(1, 0, null,
+                MacAddress.BROADCAST_ADDRESS, "", null, true, false, 0, 0);
         mDeviceImpl = new VirtualDeviceImpl(mContext,
-                /* association info */ null, new Binder(), /* uid */ 0, mInputController,
-                (int associationId) -> {}, mPendingTrampolineCallback, mActivityListener,
+                associationInfo, new Binder(), /* uid */ 0, mInputController,
+                (int associationId) -> {}, mPendingTrampolineCallback, mActivityListener,
                 new VirtualDeviceParams.Builder().build());
     }
 
@@ -118,6 +145,85 @@
     }
 
     @Test
+    public void onVirtualDisplayCreatedLocked_wakeLockIsAcquired() throws RemoteException {
+        final int displayId = 2;
+        mDeviceImpl.onVirtualDisplayCreatedLocked(displayId);
+        verify(mIPowerManagerMock, never()).acquireWakeLock(any(Binder.class), anyInt(),
+                nullable(String.class), nullable(String.class), nullable(WorkSource.class),
+                nullable(String.class), anyInt());
+        TestableLooper.get(this).processAllMessages();
+        verify(mIPowerManagerMock, Mockito.times(1)).acquireWakeLock(any(Binder.class), anyInt(),
+                nullable(String.class), nullable(String.class), nullable(WorkSource.class),
+                nullable(String.class), eq(displayId));
+    }
+
+    @Test
+    public void onVirtualDisplayCreatedLocked_duplicateCalls_onlyOneWakeLockIsAcquired()
+            throws RemoteException {
+        final int displayId = 2;
+        mDeviceImpl.onVirtualDisplayCreatedLocked(displayId);
+        assertThrows(IllegalStateException.class,
+                () -> mDeviceImpl.onVirtualDisplayCreatedLocked(displayId));
+        TestableLooper.get(this).processAllMessages();
+        verify(mIPowerManagerMock, Mockito.times(1)).acquireWakeLock(any(Binder.class), anyInt(),
+                nullable(String.class), nullable(String.class), nullable(WorkSource.class),
+                nullable(String.class), eq(displayId));
+    }
+
+    @Test
+    public void onVirtualDisplayRemovedLocked_unknownDisplayId_throwsException() {
+        final int unknownDisplayId = 999;
+        assertThrows(IllegalStateException.class,
+                () -> mDeviceImpl.onVirtualDisplayRemovedLocked(unknownDisplayId));
+    }
+
+    @Test
+    public void onVirtualDisplayRemovedLocked_wakeLockIsReleased() throws RemoteException {
+        final int displayId = 2;
+        mDeviceImpl.onVirtualDisplayCreatedLocked(displayId);
+        ArgumentCaptor<IBinder> wakeLockCaptor = ArgumentCaptor.forClass(IBinder.class);
+        TestableLooper.get(this).processAllMessages();
+        verify(mIPowerManagerMock, Mockito.times(1)).acquireWakeLock(wakeLockCaptor.capture(),
+                anyInt(),
+                nullable(String.class), nullable(String.class), nullable(WorkSource.class),
+                nullable(String.class), eq(displayId));
+
+        IBinder wakeLock = wakeLockCaptor.getValue();
+        mDeviceImpl.onVirtualDisplayRemovedLocked(displayId);
+        verify(mIPowerManagerMock, Mockito.times(1)).releaseWakeLock(eq(wakeLock), anyInt());
+    }
+
+    @Test
+    public void onVirtualDisplayRemovedLocked_beforeWakeLockTaskRuns_noWakeLockIsAcquired()
+            throws RemoteException {
+        final int displayId = 2;
+        mDeviceImpl.onVirtualDisplayCreatedLocked(displayId);
+        // Remove the display before the posted wakelock task gets a chance to run.
+        mDeviceImpl.onVirtualDisplayRemovedLocked(displayId);
+        TestableLooper.get(this).processAllMessages();
+        verify(mIPowerManagerMock, never()).acquireWakeLock(any(Binder.class), anyInt(),
+                nullable(String.class), nullable(String.class), nullable(WorkSource.class),
+                nullable(String.class), anyInt());
+    }
+
+    @Test
+    public void close_displayNotReleased_wakeLockIsReleased() throws RemoteException {
+        final int displayId = 2;
+        mDeviceImpl.onVirtualDisplayCreatedLocked(displayId);
+        ArgumentCaptor<IBinder> wakeLockCaptor = ArgumentCaptor.forClass(IBinder.class);
+        TestableLooper.get(this).processAllMessages();
+        verify(mIPowerManagerMock, Mockito.times(1)).acquireWakeLock(wakeLockCaptor.capture(),
+                anyInt(),
+                nullable(String.class), nullable(String.class), nullable(WorkSource.class),
+                nullable(String.class), eq(displayId));
+        IBinder wakeLock = wakeLockCaptor.getValue();
+
+        // Close the VirtualDevice without first notifying it of the VirtualDisplay removal.
+        mDeviceImpl.close();
+        verify(mIPowerManagerMock, Mockito.times(1)).releaseWakeLock(eq(wakeLock), anyInt());
+    }
+
+    @Test
     public void createVirtualKeyboard_noDisplay_failsSecurityException() {
         assertThrows(
                 SecurityException.class,