Screenshot a single display per invocation

The actual per-source display selection is left as TODOs; this change puts
the structure in place behind a flag.

Bug: 362720389
Test: atest TakeScreenshotExecutorTest
Flag: com.android.systemui.screenshot_multidisplay_focus_change

Change-Id: I4f0b4f21b1f93e02ccb14546157c79ffad685397
diff --git a/packages/SystemUI/aconfig/systemui.aconfig b/packages/SystemUI/aconfig/systemui.aconfig
index f8383d9..5942ff3 100644
--- a/packages/SystemUI/aconfig/systemui.aconfig
+++ b/packages/SystemUI/aconfig/systemui.aconfig
@@ -633,6 +633,13 @@
 }
 
 flag {
+    name: "screenshot_multidisplay_focus_change"
+    namespace: "systemui"
+    description: "Capture only one display per screenshot, selected by the request source"
+    bug: "362720389"
+}
+
+flag {
    name: "run_fingerprint_detect_on_dismissible_keyguard"
    namespace: "systemui"
    description: "Run fingerprint detect instead of authenticate if the keyguard is dismissible."
diff --git a/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotExecutor.kt b/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotExecutor.kt
index 448f7c4..38608d0 100644
--- a/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotExecutor.kt
+++ b/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotExecutor.kt
@@ -20,9 +20,11 @@
 import android.os.Trace
 import android.util.Log
 import android.view.Display
+import android.view.WindowManager.ScreenshotSource
 import android.view.WindowManager.TAKE_SCREENSHOT_PROVIDED_IMAGE
 import com.android.internal.logging.UiEventLogger
 import com.android.internal.util.ScreenshotRequest
+import com.android.systemui.Flags.screenshotMultidisplayFocusChange
 import com.android.systemui.dagger.SysUISingleton
 import com.android.systemui.dagger.qualifiers.Application
 import com.android.systemui.display.data.repository.DisplayRepository
@@ -40,7 +42,7 @@
     suspend fun executeScreenshots(
         screenshotRequest: ScreenshotRequest,
         onSaved: (Uri?) -> Unit,
-        requestCallback: RequestCallback
+        requestCallback: RequestCallback,
     )
 
     fun onCloseSystemDialogsReceived()
@@ -52,7 +54,7 @@
     fun executeScreenshotsAsync(
         screenshotRequest: ScreenshotRequest,
         onSaved: Consumer<Uri?>,
-        requestCallback: RequestCallback
+        requestCallback: RequestCallback,
     )
 }
 
@@ -60,7 +62,7 @@
     fun handleScreenshot(
         screenshot: ScreenshotData,
         finisher: Consumer<Uri?>,
-        requestCallback: RequestCallback
+        requestCallback: RequestCallback,
     )
 }
 
@@ -75,7 +77,7 @@
 @Inject
 constructor(
     private val interactiveScreenshotHandlerFactory: InteractiveScreenshotHandler.Factory,
-    displayRepository: DisplayRepository,
+    private val displayRepository: DisplayRepository,
     @Application private val mainScope: CoroutineScope,
     private val screenshotRequestProcessor: ScreenshotRequestProcessor,
     private val uiEventLogger: UiEventLogger,
@@ -95,31 +97,44 @@
     override suspend fun executeScreenshots(
         screenshotRequest: ScreenshotRequest,
         onSaved: (Uri?) -> Unit,
-        requestCallback: RequestCallback
+        requestCallback: RequestCallback,
     ) {
-        val displays = getDisplaysToScreenshot(screenshotRequest.type)
-        val resultCallbackWrapper = MultiResultCallbackWrapper(requestCallback)
-        if (displays.isEmpty()) {
-            Log.wtf(TAG, "No displays found for screenshot.")
-        }
-        displays.forEach { display ->
-            val displayId = display.displayId
-            var screenshotHandler: ScreenshotHandler =
-                if (displayId == Display.DEFAULT_DISPLAY) {
-                    getScreenshotController(display)
-                } else {
-                    headlessScreenshotHandler
-                }
-            Log.d(TAG, "Executing screenshot for display $displayId")
+        if (screenshotMultidisplayFocusChange()) {
+            val display = getDisplayToScreenshot(screenshotRequest)
+            val screenshotHandler = getScreenshotController(display)
             dispatchToController(
                 screenshotHandler,
-                rawScreenshotData = ScreenshotData.fromRequest(screenshotRequest, displayId),
-                onSaved =
-                    if (displayId == Display.DEFAULT_DISPLAY) {
-                        onSaved
-                    } else { _ -> },
-                callback = resultCallbackWrapper.createCallbackForId(displayId)
+                ScreenshotData.fromRequest(screenshotRequest, display.displayId),
+                onSaved,
+                requestCallback,
             )
+        } else {
+            val displays = getDisplaysToScreenshot(screenshotRequest.type)
+            val resultCallbackWrapper = MultiResultCallbackWrapper(requestCallback)
+            if (displays.isEmpty()) {
+                Log.wtf(TAG, "No displays found for screenshot.")
+            }
+            // Non-default displays go to headlessScreenshotHandler and never receive onSaved.
+            displays.forEach { display ->
+                val displayId = display.displayId
+                val screenshotHandler: ScreenshotHandler =
+                    if (displayId == Display.DEFAULT_DISPLAY) {
+                        getScreenshotController(display)
+                    } else {
+                        headlessScreenshotHandler
+                    }
+
+                Log.d(TAG, "Executing screenshot for display $displayId")
+                dispatchToController(
+                    screenshotHandler,
+                    rawScreenshotData = ScreenshotData.fromRequest(screenshotRequest, displayId),
+                    onSaved =
+                        if (displayId == Display.DEFAULT_DISPLAY) {
+                            onSaved
+                        } else { _ -> },
+                    callback = resultCallbackWrapper.createCallbackForId(displayId),
+                )
+            }
         }
     }
 
@@ -128,7 +143,7 @@
         screenshotHandler: ScreenshotHandler,
         rawScreenshotData: ScreenshotData,
         onSaved: (Uri?) -> Unit,
-        callback: RequestCallback
+        callback: RequestCallback,
     ) {
         // Let's wait before logging "screenshot requested", as we should log the processed
         // ScreenshotData.
@@ -160,13 +175,13 @@
         uiEventLogger.log(
             ScreenshotEvent.getScreenshotSource(screenshotData.source),
             0,
-            screenshotData.packageNameString
+            screenshotData.packageNameString,
         )
     }
 
     private fun onFailedScreenshotRequest(
         screenshotData: ScreenshotData,
-        callback: RequestCallback
+        callback: RequestCallback,
     ) {
         uiEventLogger.log(SCREENSHOT_CAPTURE_FAILED, 0, screenshotData.packageNameString)
         getNotificationController(screenshotData.displayId)
@@ -184,6 +199,31 @@
         }
     }
 
+    /** Returns the single display to screenshot, chosen based on the request's source. */
+    private suspend fun getDisplayToScreenshot(screenshotRequest: ScreenshotRequest): Display {
+        return when (screenshotRequest.source) {
+            // TODO(b/367394043): Overview requests should use a display ID provided in
+            //  ScreenshotRequest.
+            ScreenshotSource.SCREENSHOT_OVERVIEW -> defaultDisplay()
+
+            // Key chord and vendor gesture occur on the device itself, so screenshot the device's
+            // display
+            ScreenshotSource.SCREENSHOT_KEY_CHORD,
+            ScreenshotSource.SCREENSHOT_VENDOR_GESTURE -> defaultDisplay()
+
+            // All other invocations use the focused display
+            else -> focusedDisplay()
+        }
+    }
+
+    // TODO(b/367394043): Determine the focused display here; until then this falls back to the
+    //  default display.
+    private suspend fun focusedDisplay(): Display = defaultDisplay()
+
+    /** Returns the default display; throws [IllegalStateException] if it can't be found. */
+    private suspend fun defaultDisplay(): Display =
+        displayRepository.getDisplay(Display.DEFAULT_DISPLAY) ?: error("Can't find default display")
+
     /** Propagates the close system dialog signal to the ScreenshotController. */
     override fun onCloseSystemDialogsReceived() {
         if (screenshotController?.isPendingSharedTransition() == false) {
@@ -214,7 +254,7 @@
     override fun executeScreenshotsAsync(
         screenshotRequest: ScreenshotRequest,
         onSaved: Consumer<Uri?>,
-        requestCallback: RequestCallback
+        requestCallback: RequestCallback,
     ) {
         mainScope.launch {
             executeScreenshots(screenshotRequest, { uri -> onSaved.accept(uri) }, requestCallback)
@@ -235,9 +275,7 @@
      * - If any finished with an error, [reportError] of [originalCallback] is called
      * - Otherwise, [onFinish] is called.
      */
-    private class MultiResultCallbackWrapper(
-        private val originalCallback: RequestCallback,
-    ) {
+    private class MultiResultCallbackWrapper(private val originalCallback: RequestCallback) {
         private val idsPending = mutableSetOf<Int>()
         private val idsWithErrors = mutableSetOf<Int>()
 
@@ -290,7 +328,7 @@
                 Display.TYPE_EXTERNAL,
                 Display.TYPE_INTERNAL,
                 Display.TYPE_OVERLAY,
-                Display.TYPE_WIFI
+                Display.TYPE_WIFI,
             )
     }
 }
diff --git a/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotExecutorTest.kt b/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotExecutorTest.kt
index a295981..15705fb 100644
--- a/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotExecutorTest.kt
+++ b/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotExecutorTest.kt
@@ -3,6 +3,8 @@
 import android.content.ComponentName
 import android.graphics.Bitmap
 import android.net.Uri
+import android.platform.test.annotations.DisableFlags
+import android.platform.test.annotations.EnableFlags
 import android.view.Display
 import android.view.Display.TYPE_EXTERNAL
 import android.view.Display.TYPE_INTERNAL
@@ -15,6 +17,7 @@
 import androidx.test.filters.SmallTest
 import com.android.internal.logging.testing.UiEventLoggerFake
 import com.android.internal.util.ScreenshotRequest
+import com.android.systemui.Flags
 import com.android.systemui.SysuiTestCase
 import com.android.systemui.display.data.repository.FakeDisplayRepository
 import com.android.systemui.display.data.repository.display
@@ -75,6 +78,7 @@
     }
 
     @Test
+    @DisableFlags(Flags.FLAG_SCREENSHOT_MULTIDISPLAY_FOCUS_CHANGE)
     fun executeScreenshots_severalDisplays_callsControllerForEachOne() =
         testScope.runTest {
             val internalDisplay = display(TYPE_INTERNAL, id = 0)
@@ -106,6 +110,7 @@
         }
 
     @Test
+    @DisableFlags(Flags.FLAG_SCREENSHOT_MULTIDISPLAY_FOCUS_CHANGE)
     fun executeScreenshots_providedImageType_callsOnlyDefaultDisplayController() =
         testScope.runTest {
             val internalDisplay = display(TYPE_INTERNAL, id = 0)
@@ -115,7 +120,7 @@
             screenshotExecutor.executeScreenshots(
                 createScreenshotRequest(TAKE_SCREENSHOT_PROVIDED_IMAGE),
                 onSaved,
-                callback
+                callback,
             )
 
             verify(controllerFactory).create(eq(internalDisplay))
@@ -137,6 +142,7 @@
         }
 
     @Test
+    @DisableFlags(Flags.FLAG_SCREENSHOT_MULTIDISPLAY_FOCUS_CHANGE)
     fun executeScreenshots_onlyVirtualDisplays_noInteractionsWithControllers() =
         testScope.runTest {
             setDisplays(display(TYPE_VIRTUAL, id = 0), display(TYPE_VIRTUAL, id = 1))
@@ -149,6 +155,7 @@
         }
 
     @Test
+    @DisableFlags(Flags.FLAG_SCREENSHOT_MULTIDISPLAY_FOCUS_CHANGE)
     fun executeScreenshots_allowedTypes_allCaptured() =
         testScope.runTest {
             whenever(controllerFactory.create(any())).thenReturn(controller)
@@ -157,7 +164,7 @@
                 display(TYPE_INTERNAL, id = 0),
                 display(TYPE_EXTERNAL, id = 1),
                 display(TYPE_OVERLAY, id = 2),
-                display(TYPE_WIFI, id = 3)
+                display(TYPE_WIFI, id = 3),
             )
             val onSaved = { _: Uri? -> }
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
@@ -168,6 +175,7 @@
         }
 
     @Test
+    @DisableFlags(Flags.FLAG_SCREENSHOT_MULTIDISPLAY_FOCUS_CHANGE)
     fun executeScreenshots_reportsOnFinishedOnlyWhenBothFinished() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -193,6 +201,7 @@
         }
 
     @Test
+    @DisableFlags(Flags.FLAG_SCREENSHOT_MULTIDISPLAY_FOCUS_CHANGE)
     fun executeScreenshots_oneFinishesOtherFails_reportFailsOnlyAtTheEnd() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -220,6 +229,7 @@
         }
 
     @Test
+    @DisableFlags(Flags.FLAG_SCREENSHOT_MULTIDISPLAY_FOCUS_CHANGE)
     fun executeScreenshots_allDisplaysFail_reportsFail() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -319,6 +329,7 @@
         }
 
     @Test
+    @DisableFlags(Flags.FLAG_SCREENSHOT_MULTIDISPLAY_FOCUS_CHANGE)
     fun executeScreenshots_errorFromProcessor_logsScreenshotRequested() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -336,6 +347,7 @@
         }
 
     @Test
+    @DisableFlags(Flags.FLAG_SCREENSHOT_MULTIDISPLAY_FOCUS_CHANGE)
     fun executeScreenshots_errorFromProcessor_logsUiError() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -379,7 +391,8 @@
         }
 
     @Test
-    fun executeScreenshots_errorFromScreenshotController_reportsRequested() =
+    @DisableFlags(Flags.FLAG_SCREENSHOT_MULTIDISPLAY_FOCUS_CHANGE)
+    fun executeScreenshots_errorFromScreenshotController_multidisplay_reportsRequested() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
             val onSaved = { _: Uri? -> }
@@ -399,7 +412,27 @@
         }
 
     @Test
-    fun executeScreenshots_errorFromScreenshotController_reportsError() =
+    @EnableFlags(Flags.FLAG_SCREENSHOT_MULTIDISPLAY_FOCUS_CHANGE)
+    fun executeScreenshots_errorFromScreenshotController_reportsRequested() =
+        testScope.runTest {
+            // With the flag on only one of the two displays is captured: expect one request event.
+            setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
+            whenever(controller.handleScreenshot(any(), any(), any()))
+                .thenThrow(IllegalStateException::class.java)
+            val noOpOnSaved = { _: Uri? -> }
+            screenshotExecutor.executeScreenshots(createScreenshotRequest(), noOpOnSaved, callback)
+
+            val requestedEvents =
+                eventLogger.logs.filter { log ->
+                    log.eventId == ScreenshotEvent.SCREENSHOT_REQUESTED_KEY_OTHER.id
+                }
+            assertThat(requestedEvents).hasSize(1)
+            screenshotExecutor.onDestroy()
+        }
+
+    @Test
+    @DisableFlags(Flags.FLAG_SCREENSHOT_MULTIDISPLAY_FOCUS_CHANGE)
+    fun executeScreenshots_errorFromScreenshotController_multidisplay_reportsError() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
             val onSaved = { _: Uri? -> }
@@ -419,7 +452,27 @@
         }
 
     @Test
-    fun executeScreenshots_errorFromScreenshotController_showsErrorNotification() =
+    @EnableFlags(Flags.FLAG_SCREENSHOT_MULTIDISPLAY_FOCUS_CHANGE)
+    fun executeScreenshots_errorFromScreenshotController_reportsError() =
+        testScope.runTest {
+            setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
+            val onSaved = { _: Uri? -> }
+            whenever(controller.handleScreenshot(any(), any(), any()))
+                .thenThrow(IllegalStateException::class.java)
+
+            screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
+
+            val screenshotFailed =
+                eventLogger.logs.filter {
+                    it.eventId == ScreenshotEvent.SCREENSHOT_CAPTURE_FAILED.id
+                }
+            assertThat(screenshotFailed).hasSize(1)
+            screenshotExecutor.onDestroy()
+        }
+
+    @Test
+    @DisableFlags(Flags.FLAG_SCREENSHOT_MULTIDISPLAY_FOCUS_CHANGE)
+    fun executeScreenshots_errorFromScreenshotController_multidisplay_showsErrorNotification() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
             val onSaved = { _: Uri? -> }
@@ -436,6 +489,21 @@
         }
 
     @Test
+    @EnableFlags(Flags.FLAG_SCREENSHOT_MULTIDISPLAY_FOCUS_CHANGE)
+    fun executeScreenshots_errorFromScreenshotController_showsErrorNotification() =
+        testScope.runTest {
+            // With the flag on, the failure is expected to surface via display 0's notifications.
+            setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
+            whenever(controller.handleScreenshot(any(), any(), any()))
+                .thenThrow(IllegalStateException::class.java)
+            val noOpOnSaved = { _: Uri? -> }
+            screenshotExecutor.executeScreenshots(createScreenshotRequest(), noOpOnSaved, callback)
+
+            verify(notificationsController0).notifyScreenshotError(any())
+            screenshotExecutor.onDestroy()
+        }
+
+    @Test
     fun executeScreenshots_finisherCalledWithNullUri_succeeds() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0))