Merge "Fix Race in ScreenshotController" into 24D1-dev
diff --git a/packages/SystemUI/src/com/android/systemui/screenshot/ScreenshotController.java b/packages/SystemUI/src/com/android/systemui/screenshot/ScreenshotController.java
index dfd742c..71363b2 100644
--- a/packages/SystemUI/src/com/android/systemui/screenshot/ScreenshotController.java
+++ b/packages/SystemUI/src/com/android/systemui/screenshot/ScreenshotController.java
@@ -47,7 +47,6 @@
 import android.graphics.Bitmap;
 import android.graphics.Insets;
 import android.graphics.Rect;
-import android.hardware.display.DisplayManager;
 import android.net.Uri;
 import android.os.Process;
 import android.os.UserHandle;
@@ -209,8 +208,7 @@
     @Nullable
     private final ScreenshotSoundController mScreenshotSoundController;
     private final PhoneWindow mWindow;
-    private final DisplayManager mDisplayManager;
-    private final int mDisplayId;
+    private final Display mDisplay;
     private final ScrollCaptureExecutor mScrollCaptureExecutor;
     private final ScreenshotNotificationSmartActionsProvider
             mScreenshotNotificationSmartActionsProvider;
@@ -250,7 +248,6 @@
     @AssistedInject
     ScreenshotController(
             Context context,
-            DisplayManager displayManager,
             WindowManager windowManager,
             FeatureFlags flags,
             ScreenshotViewProxy.Factory viewProxyFactory,
@@ -272,12 +269,14 @@
             AssistContentRequester assistContentRequester,
             MessageContainerController messageContainerController,
             Provider<ScreenshotSoundController> screenshotSoundController,
-            @Assisted int displayId,
+            @Assisted Display display,
             @Assisted boolean showUIOnExternalDisplay
     ) {
         mScreenshotSmartActions = screenshotSmartActions;
+        mWindowManager = windowManager;
         mActionsProviderFactory = actionsProviderFactory;
-        mNotificationsController = screenshotNotificationsControllerFactory.create(displayId);
+        mNotificationsController = screenshotNotificationsControllerFactory.create(
+                display.getDisplayId());
         mUiEventLogger = uiEventLogger;
         mImageExporter = imageExporter;
         mImageCapture = imageCapture;
@@ -290,12 +289,8 @@
 
         mScreenshotHandler = timeoutHandler;
         mScreenshotHandler.setDefaultTimeoutMillis(SCREENSHOT_CORNER_DEFAULT_TIMEOUT_MILLIS);
-
-
-        mDisplayId = displayId;
-        mDisplayManager = displayManager;
-        mWindowManager = windowManager;
-        final Context displayContext = context.createDisplayContext(getDisplay());
+        mDisplay = display;
+        final Context displayContext = context.createDisplayContext(display);
         mContext = (WindowContext) displayContext.createWindowContext(TYPE_SCREENSHOT, null);
         mFlags = flags;
         mActionIntentExecutor = actionIntentExecutor;
@@ -303,7 +298,7 @@
         mMessageContainerController = messageContainerController;
         mAssistContentRequester = assistContentRequester;
 
-        mViewProxy = viewProxyFactory.getProxy(mContext, mDisplayId);
+        mViewProxy = viewProxyFactory.getProxy(mContext, mDisplay.getDisplayId());
 
         mScreenshotHandler.setOnTimeoutRunnable(() -> {
             if (DEBUG_UI) {
@@ -329,7 +324,7 @@
                 });
 
         // Sound is only reproduced from the controller of the default display.
-        if (displayId == Display.DEFAULT_DISPLAY) {
+        if (display.getDisplayId() == Display.DEFAULT_DISPLAY) {
             mScreenshotSoundController = screenshotSoundController.get();
         } else {
             mScreenshotSoundController = null;
@@ -357,7 +352,7 @@
         if (screenshot.getType() == WindowManager.TAKE_SCREENSHOT_FULLSCREEN) {
             Rect bounds = getFullScreenRect();
             screenshot.setBitmap(
-                    mImageCapture.captureDisplay(mDisplayId, bounds));
+                    mImageCapture.captureDisplay(mDisplay.getDisplayId(), bounds));
             screenshot.setScreenBounds(bounds);
         }
 
@@ -460,7 +455,7 @@
     }
 
     private boolean shouldShowUi() {
-        return mDisplayId == Display.DEFAULT_DISPLAY || mShowUIOnExternalDisplay;
+        return mDisplay.getDisplayId() == Display.DEFAULT_DISPLAY || mShowUIOnExternalDisplay;
     }
 
     void prepareViewForNewScreenshot(@NonNull ScreenshotData screenshot, String oldPackageName) {
@@ -623,7 +618,7 @@
 
     private void requestScrollCapture(UserHandle owner) {
         mScrollCaptureExecutor.requestScrollCapture(
-                mDisplayId,
+                mDisplay.getDisplayId(),
                 mWindow.getDecorView().getWindowToken(),
                 (response) -> {
                     mUiEventLogger.log(ScreenshotEvent.SCREENSHOT_LONG_SCREENSHOT_IMPRESSION,
@@ -646,7 +641,8 @@
         }
         mUiEventLogger.log(ScreenshotEvent.SCREENSHOT_LONG_SCREENSHOT_REQUESTED, 0,
                 response.getPackageName());
-        Bitmap newScreenshot = mImageCapture.captureDisplay(mDisplayId, getFullScreenRect());
+        Bitmap newScreenshot = mImageCapture.captureDisplay(mDisplay.getDisplayId(),
+                getFullScreenRect());
         if (newScreenshot == null) {
             Log.e(TAG, "Failed to capture current screenshot for scroll transition!");
             return;
@@ -824,7 +820,8 @@
     private void saveScreenshotInBackground(
             ScreenshotData screenshot, UUID requestId, Consumer<Uri> finisher) {
         ListenableFuture<ImageExporter.Result> future = mImageExporter.export(mBgExecutor,
-                requestId, screenshot.getBitmap(), screenshot.getUserOrDefault(), mDisplayId);
+                requestId, screenshot.getBitmap(), screenshot.getUserOrDefault(),
+                mDisplay.getDisplayId());
         future.addListener(() -> {
             try {
                 ImageExporter.Result result = future.get();
@@ -866,7 +863,7 @@
         data.mActionsReadyListener = actionsReadyListener;
         data.mQuickShareActionsReadyListener = quickShareActionsReadyListener;
         data.owner = owner;
-        data.displayId = mDisplayId;
+        data.displayId = mDisplay.getDisplayId();
 
         if (mSaveInBgTask != null) {
             // just log success/failure for the pre-existing screenshot
@@ -991,13 +988,9 @@
         }
     }
 
-    private Display getDisplay() {
-        return mDisplayManager.getDisplay(mDisplayId);
-    }
-
     private Rect getFullScreenRect() {
         DisplayMetrics displayMetrics = new DisplayMetrics();
-        getDisplay().getRealMetrics(displayMetrics);
+        mDisplay.getRealMetrics(displayMetrics);
         return new Rect(0, 0, displayMetrics.widthPixels, displayMetrics.heightPixels);
     }
 
@@ -1033,10 +1026,10 @@
         /**
-         * Creates an instance of the controller for that specific displayId.
+         * Creates an instance of the controller for that specific display.
          *
-         * @param displayId:               display to capture
-         * @param showUIOnExternalDisplay: Whether the UI should be shown if this is an external
-         *                                 display.
+         * @param display                 Display to capture.
+         * @param showUIOnExternalDisplay Whether the UI should be shown if this is an external
+         *                                display.
          */
-        ScreenshotController create(int displayId, boolean showUIOnExternalDisplay);
+        ScreenshotController create(Display display, boolean showUIOnExternalDisplay);
     }
 }
diff --git a/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotExecutor.kt b/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotExecutor.kt
index a9179bf..5feac80 100644
--- a/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotExecutor.kt
+++ b/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotExecutor.kt
@@ -52,11 +52,13 @@
         onSaved: (Uri?) -> Unit,
         requestCallback: RequestCallback
     ) {
-        val displayIds = getDisplaysToScreenshot(screenshotRequest.type)
+        val displays = getDisplaysToScreenshot(screenshotRequest.type)
         val resultCallbackWrapper = MultiResultCallbackWrapper(requestCallback)
-        displayIds.forEach { displayId: Int ->
+        displays.forEach { display: Display ->
+            val displayId = display.displayId
             Log.d(TAG, "Executing screenshot for display $displayId")
             dispatchToController(
+                display,
                 rawScreenshotData = ScreenshotData.fromRequest(screenshotRequest, displayId),
                 onSaved =
                     if (displayId == Display.DEFAULT_DISPLAY) {
@@ -69,6 +71,7 @@
 
     /** All logging should be triggered only by this method. */
     private suspend fun dispatchToController(
+        display: Display,
         rawScreenshotData: ScreenshotData,
         onSaved: (Uri?) -> Unit,
         callback: RequestCallback
@@ -88,8 +91,7 @@
         logScreenshotRequested(screenshotData)
         Log.d(TAG, "Screenshot request: $screenshotData")
         try {
-            getScreenshotController(screenshotData.displayId)
-                .handleScreenshot(screenshotData, onSaved, callback)
+            getScreenshotController(display).handleScreenshot(screenshotData, onSaved, callback)
         } catch (e: IllegalStateException) {
             Log.e(TAG, "Error while ScreenshotController was handling ScreenshotData!", e)
             onFailedScreenshotRequest(screenshotData, callback)
@@ -119,12 +121,13 @@
         callback.reportError()
     }
 
-    private suspend fun getDisplaysToScreenshot(requestType: Int): List<Int> {
+    private suspend fun getDisplaysToScreenshot(requestType: Int): List<Display> {
+        val allDisplays = displays.first()
         return if (requestType == TAKE_SCREENSHOT_PROVIDED_IMAGE) {
             // If this is a provided image, let's show the UI on the default display only.
-            listOf(Display.DEFAULT_DISPLAY)
+            allDisplays.filter { it.displayId == Display.DEFAULT_DISPLAY }
         } else {
-            displays.first().filter { it.type in ALLOWED_DISPLAY_TYPES }.map { it.displayId }
+            allDisplays.filter { it.type in ALLOWED_DISPLAY_TYPES }
         }
     }
 
@@ -158,9 +161,9 @@
         screenshotControllers.clear()
     }
 
-    private fun getScreenshotController(id: Int): ScreenshotController {
-        return screenshotControllers.computeIfAbsent(id) {
-            screenshotControllerFactory.create(id, /* showUIOnExternalDisplay= */ false)
+    private fun getScreenshotController(display: Display): ScreenshotController {
+        return screenshotControllers.computeIfAbsent(display.displayId) {
+            screenshotControllerFactory.create(display, /* showUIOnExternalDisplay= */ false)
         }
     }
 
diff --git a/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotService.java b/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotService.java
index 9cf347b..c03ba65 100644
--- a/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotService.java
+++ b/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotService.java
@@ -37,6 +37,7 @@
 import android.content.Context;
 import android.content.Intent;
 import android.content.IntentFilter;
+import android.hardware.display.DisplayManager;
 import android.net.Uri;
 import android.os.Handler;
 import android.os.IBinder;
@@ -116,7 +117,8 @@
             UiEventLogger uiEventLogger,
             ScreenshotNotificationsController.Factory notificationsControllerFactory,
             Context context, @Background Executor bgExecutor, FeatureFlags featureFlags,
-            RequestProcessor processor, Provider<TakeScreenshotExecutor> takeScreenshotExecutor) {
+            RequestProcessor processor, Provider<TakeScreenshotExecutor> takeScreenshotExecutor,
+            DisplayManager displayManager) {
         if (DEBUG_SERVICE) {
             Log.d(TAG, "new " + this);
         }
@@ -134,7 +136,8 @@
             mScreenshot = null;
         } else {
             mScreenshot = screenshotControllerFactory.create(
-                    Display.DEFAULT_DISPLAY, /* showUIOnExternalDisplay= */ false);
+                    displayManager.getDisplay(
+                            Display.DEFAULT_DISPLAY), /* showUIOnExternalDisplay= */ false);
         }
     }
 
diff --git a/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotExecutorTest.kt b/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotExecutorTest.kt
index 22312f4..0dda41f 100644
--- a/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotExecutorTest.kt
+++ b/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotExecutorTest.kt
@@ -69,8 +69,9 @@
 
     @Before
     fun setUp() {
-        whenever(controllerFactory.create(eq(0), any())).thenReturn(controller0)
-        whenever(controllerFactory.create(eq(1), any())).thenReturn(controller1)
+        whenever(controllerFactory.create(any(), any())).thenAnswer {
+            if (it.getArgument<Display>(0).displayId == 0) controller0 else controller1
+        }
         whenever(notificationControllerFactory.create(eq(0))).thenReturn(notificationsController0)
         whenever(notificationControllerFactory.create(eq(1))).thenReturn(notificationsController1)
     }
@@ -78,12 +79,14 @@
     @Test
     fun executeScreenshots_severalDisplays_callsControllerForEachOne() =
         testScope.runTest {
-            setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
+            val internalDisplay = display(TYPE_INTERNAL, id = 0)
+            val externalDisplay = display(TYPE_EXTERNAL, id = 1)
+            setDisplays(internalDisplay, externalDisplay)
             val onSaved = { _: Uri? -> }
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
 
-            verify(controllerFactory).create(eq(0), any())
-            verify(controllerFactory).create(eq(1), any())
+            verify(controllerFactory).create(eq(internalDisplay), any())
+            verify(controllerFactory).create(eq(externalDisplay), any())
 
             val capturer = ArgumentCaptor<ScreenshotData>()
 
@@ -107,7 +110,9 @@
     @Test
     fun executeScreenshots_providedImageType_callsOnlyDefaultDisplayController() =
         testScope.runTest {
-            setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
+            val internalDisplay = display(TYPE_INTERNAL, id = 0)
+            val externalDisplay = display(TYPE_EXTERNAL, id = 1)
+            setDisplays(internalDisplay, externalDisplay)
             val onSaved = { _: Uri? -> }
             screenshotExecutor.executeScreenshots(
                 createScreenshotRequest(TAKE_SCREENSHOT_PROVIDED_IMAGE),
@@ -115,8 +120,8 @@
                 callback
             )
 
-            verify(controllerFactory).create(eq(0), any())
-            verify(controllerFactory, never()).create(eq(1), any())
+            verify(controllerFactory).create(eq(internalDisplay), any())
+            verify(controllerFactory, never()).create(eq(externalDisplay), any())
 
             val capturer = ArgumentCaptor<ScreenshotData>()
 
@@ -473,6 +478,7 @@
         var processed: ScreenshotData? = null
         var toReturn: ScreenshotData? = null
         var shouldThrowException = false
+
         override suspend fun process(screenshot: ScreenshotData): ScreenshotData {
             if (shouldThrowException) throw RequestProcessorException("")
             processed = screenshot
diff --git a/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotServiceTest.kt b/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotServiceTest.kt
index f3809aa..b6163512 100644
--- a/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotServiceTest.kt
+++ b/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotServiceTest.kt
@@ -20,6 +20,7 @@
 import android.app.admin.DevicePolicyResources.Strings.SystemUi.SCREENSHOT_BLOCKED_BY_ADMIN
 import android.app.admin.DevicePolicyResourcesManager
 import android.content.ComponentName
+import android.hardware.display.DisplayManager
 import android.os.UserHandle
 import android.os.UserManager
 import android.testing.AndroidTestingRunner
@@ -38,6 +39,7 @@
 import com.android.systemui.util.mockito.any
 import com.android.systemui.util.mockito.eq
 import com.android.systemui.util.mockito.mock
+import com.android.systemui.util.mockito.nullable
 import com.android.systemui.util.mockito.whenever
 import java.util.function.Consumer
 import org.junit.Assert.assertEquals
@@ -68,6 +70,7 @@
     private val notificationsControllerFactory = mock<ScreenshotNotificationsController.Factory>()
     private val notificationsController = mock<ScreenshotNotificationsController>()
     private val callback = mock<RequestCallback>()
+    private val displayManager = mock<DisplayManager>()
 
     private val eventLogger = UiEventLoggerFake()
     private val flags = FakeFeatureFlags()
@@ -87,7 +90,7 @@
             )
             .thenReturn(false)
         whenever(userManager.isUserUnlocked).thenReturn(true)
-        whenever(controllerFactory.create(any(), any())).thenReturn(controller)
+        whenever(controllerFactory.create(nullable<Display>(), any())).thenReturn(controller)
         whenever(notificationsControllerFactory.create(any())).thenReturn(notificationsController)
 
         // Stub request processor as a synchronous no-op for tests with the flag enabled
@@ -331,6 +334,7 @@
                 flags,
                 requestProcessor,
                 { takeScreenshotExecutor },
+                displayManager,
             )
         service.attach(
             mContext,
diff --git a/packages/SystemUI/tests/utils/src/com/android/systemui/display/data/repository/FakeDisplayRepository.kt b/packages/SystemUI/tests/utils/src/com/android/systemui/display/data/repository/FakeDisplayRepository.kt
index d8098b7..133719c 100644
--- a/packages/SystemUI/tests/utils/src/com/android/systemui/display/data/repository/FakeDisplayRepository.kt
+++ b/packages/SystemUI/tests/utils/src/com/android/systemui/display/data/repository/FakeDisplayRepository.kt
@@ -22,6 +22,7 @@
 import org.mockito.Mockito.`when` as whenever
 
 /** Creates a mock display. */
+@JvmOverloads
 fun display(
     type: Int,
     flags: Int = 0,