Merge "Make non-main displays take headless screenshots." into main
diff --git a/packages/SystemUI/src/com/android/systemui/screenshot/HeadlessScreenshotHandler.kt b/packages/SystemUI/src/com/android/systemui/screenshot/HeadlessScreenshotHandler.kt
new file mode 100644
index 0000000..6730d2d
--- /dev/null
+++ b/packages/SystemUI/src/com/android/systemui/screenshot/HeadlessScreenshotHandler.kt
@@ -0,0 +1,138 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.systemui.screenshot
+
+import android.net.Uri
+import android.os.UserManager
+import android.util.Log
+import android.view.WindowManager
+import com.android.internal.logging.UiEventLogger
+import com.android.systemui.dagger.qualifiers.Main
+import com.android.systemui.res.R
+import com.google.common.util.concurrent.ListenableFuture
+import java.util.UUID
+import java.util.concurrent.Executor
+import java.util.concurrent.Executors
+import java.util.function.Consumer
+import javax.inject.Inject
+
+/**
+ * A ScreenshotHandler that just saves the screenshot and calls back as appropriate, with no UI.
+ *
+ * Basically, ScreenshotController with all the UI bits ripped out. [TakeScreenshotExecutorImpl]
+ * uses it for every display other than the default one.
+ */
+class HeadlessScreenshotHandler
+@Inject
+constructor(
+    private val imageExporter: ImageExporter,
+    @Main private val mainExecutor: Executor,
+    private val imageCapture: ImageCapture,
+    private val userManager: UserManager,
+    private val uiEventLogger: UiEventLogger,
+    private val notificationsControllerFactory: ScreenshotNotificationsController.Factory,
+) : ScreenshotHandler {
+
+    /**
+     * Captures (for full-screen requests) and exports [screenshot], then logs the outcome. No
+     * screenshot UI is shown.
+     *
+     * @param screenshot the request; for [WindowManager.TAKE_SCREENSHOT_FULLSCREEN] its bitmap is
+     *   captured here, otherwise the bitmap already attached to it is used.
+     * @param finisher receives the exported [Uri] (possibly null) once the export completes, or
+     *   null if getting or handling the export result throws. Not invoked without a bitmap.
+     * @param requestCallback receives `onFinish()` after the export result is handled, or
+     *   `reportError()` when there is no bitmap or handling the export result throws.
+     */
+    override fun handleScreenshot(
+        screenshot: ScreenshotData,
+        finisher: Consumer<Uri?>,
+        requestCallback: TakeScreenshotService.RequestCallback
+    ) {
+        if (screenshot.type == WindowManager.TAKE_SCREENSHOT_FULLSCREEN) {
+            screenshot.bitmap = imageCapture.captureDisplay(screenshot.displayId, crop = null)
+        }
+
+        val bitmap = screenshot.bitmap
+        if (bitmap == null) {
+            Log.e(TAG, "handleScreenshot: Screenshot bitmap was null")
+            notifyError(screenshot.displayId, R.string.screenshot_failed_to_capture_text)
+            requestCallback.reportError()
+            return
+        }
+
+        // A fresh executor is created for every request, so it must be shut down once the export
+        // has completed; otherwise each screenshot would leave an idle worker thread behind.
+        val exportExecutor = Executors.newSingleThreadExecutor()
+        val future: ListenableFuture<ImageExporter.Result> =
+            imageExporter.export(
+                exportExecutor,
+                UUID.randomUUID(),
+                bitmap,
+                screenshot.getUserOrDefault(),
+                screenshot.displayId
+            )
+        future.addListener(
+            {
+                try {
+                    val result = future.get()
+                    Log.d(TAG, "Saved screenshot: $result")
+                    logScreenshotResultStatus(result.uri, screenshot)
+                    finisher.accept(result.uri)
+                    requestCallback.onFinish()
+                } catch (e: Exception) {
+                    Log.e(TAG, "Failed to store screenshot", e)
+                    finisher.accept(null)
+                    requestCallback.reportError()
+                } finally {
+                    // The export future has completed, so the export no longer needs this executor.
+                    exportExecutor.shutdown()
+                }
+            },
+            mainExecutor
+        )
+    }
+
+    /**
+     * Logs whether [screenshot] was saved, plus an extra event when it went to a managed profile.
+     * When [uri] is null it also posts a "failed to save" error for the screenshot's display.
+     */
+    private fun logScreenshotResultStatus(uri: Uri?, screenshot: ScreenshotData) {
+        if (uri == null) {
+            uiEventLogger.log(ScreenshotEvent.SCREENSHOT_NOT_SAVED, 0, screenshot.packageNameString)
+            notifyError(screenshot.displayId, R.string.screenshot_failed_to_save_text)
+            return
+        }
+        uiEventLogger.log(ScreenshotEvent.SCREENSHOT_SAVED, 0, screenshot.packageNameString)
+        if (userManager.isManagedProfile(screenshot.getUserOrDefault().identifier)) {
+            uiEventLogger.log(
+                ScreenshotEvent.SCREENSHOT_SAVED_TO_WORK_PROFILE,
+                0,
+                screenshot.packageNameString
+            )
+        }
+    }
+
+    /** Shows error [messageId] via the notifications controller created for [displayId]. */
+    private fun notifyError(displayId: Int, messageId: Int) {
+        notificationsControllerFactory.create(displayId).notifyScreenshotError(messageId)
+    }
+
+    companion object {
+        const val TAG = "HeadlessScreenshotHandler"
+    }
+}
diff --git a/packages/SystemUI/src/com/android/systemui/screenshot/ScreenshotController.java b/packages/SystemUI/src/com/android/systemui/screenshot/ScreenshotController.java
index e8dfac8..c87b1f5 100644
--- a/packages/SystemUI/src/com/android/systemui/screenshot/ScreenshotController.java
+++ b/packages/SystemUI/src/com/android/systemui/screenshot/ScreenshotController.java
@@ -101,7 +101,7 @@
 /**
  * Controls the state and flow for screenshots.
  */
-public class ScreenshotController {
+public class ScreenshotController implements ScreenshotHandler {
     private static final String TAG = logTag(ScreenshotController.class);
 
     /**
@@ -351,7 +351,8 @@
         mShowUIOnExternalDisplay = showUIOnExternalDisplay;
     }
 
-    void handleScreenshot(ScreenshotData screenshot, Consumer<Uri> finisher,
+    @Override
+    public void handleScreenshot(ScreenshotData screenshot, Consumer<Uri> finisher,
             RequestCallback requestCallback) {
         Assert.isMainThread();
 
diff --git a/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotExecutor.kt b/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotExecutor.kt
index 3c3797b..2699657 100644
--- a/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotExecutor.kt
+++ b/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotExecutor.kt
@@ -1,3 +1,19 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package com.android.systemui.screenshot
 
 import android.net.Uri
@@ -7,12 +23,12 @@
 import android.view.WindowManager.TAKE_SCREENSHOT_PROVIDED_IMAGE
 import com.android.internal.logging.UiEventLogger
 import com.android.internal.util.ScreenshotRequest
-import com.android.systemui.Flags.screenshotShelfUi2
 import com.android.systemui.dagger.SysUISingleton
 import com.android.systemui.dagger.qualifiers.Application
 import com.android.systemui.display.data.repository.DisplayRepository
 import com.android.systemui.res.R
 import com.android.systemui.screenshot.ScreenshotEvent.SCREENSHOT_CAPTURE_FAILED
+import com.android.systemui.screenshot.ScreenshotEvent.SCREENSHOT_DISMISSED_OTHER
 import com.android.systemui.screenshot.TakeScreenshotService.RequestCallback
 import java.util.function.Consumer
 import javax.inject.Inject
@@ -26,9 +42,13 @@
         onSaved: (Uri?) -> Unit,
         requestCallback: RequestCallback
     )
+
     fun onCloseSystemDialogsReceived()
+
     fun removeWindows()
+
     fun onDestroy()
+
     fun executeScreenshotsAsync(
         screenshotRequest: ScreenshotRequest,
         onSaved: Consumer<Uri?>,
@@ -36,6 +56,19 @@
     )
 }
 
+/** Handles a screenshot request for a single display, reporting back through callbacks. */
+interface ScreenshotHandler {
+    /**
+     * Processes [screenshot]. Implementations pass the resulting [Uri] (which may be null) to
+     * [finisher] and report completion or failure through [requestCallback].
+     */
+    fun handleScreenshot(
+        screenshot: ScreenshotData,
+        finisher: Consumer<Uri?>,
+        requestCallback: RequestCallback
+    )
+}
+
 /**
  * Receives the signal to take a screenshot from [TakeScreenshotService], and calls back with the
  * result.
@@ -52,10 +80,10 @@
     private val screenshotRequestProcessor: ScreenshotRequestProcessor,
     private val uiEventLogger: UiEventLogger,
     private val screenshotNotificationControllerFactory: ScreenshotNotificationsController.Factory,
+    private val headlessScreenshotHandler: HeadlessScreenshotHandler,
 ) : TakeScreenshotExecutor {
-
     private val displays = displayRepository.displays
-    private val screenshotControllers = mutableMapOf<Int, ScreenshotController>()
+    private var screenshotController: ScreenshotController? = null
     private val notificationControllers = mutableMapOf<Int, ScreenshotNotificationsController>()
 
     /**
@@ -73,9 +101,15 @@
         val resultCallbackWrapper = MultiResultCallbackWrapper(requestCallback)
         displays.forEach { display ->
             val displayId = display.displayId
+            val screenshotHandler: ScreenshotHandler =
+                if (displayId == Display.DEFAULT_DISPLAY) {
+                    getScreenshotController(display)
+                } else {
+                    headlessScreenshotHandler
+                }
             Log.d(TAG, "Executing screenshot for display $displayId")
             dispatchToController(
-                display = display,
+                screenshotHandler,
                 rawScreenshotData = ScreenshotData.fromRequest(screenshotRequest, displayId),
                 onSaved =
                     if (displayId == Display.DEFAULT_DISPLAY) {
@@ -88,7 +122,7 @@
 
     /** All logging should be triggered only by this method. */
     private suspend fun dispatchToController(
-        display: Display,
+        screenshotHandler: ScreenshotHandler,
         rawScreenshotData: ScreenshotData,
         onSaved: (Uri?) -> Unit,
         callback: RequestCallback
@@ -102,13 +136,12 @@
                     logScreenshotRequested(rawScreenshotData)
                     onFailedScreenshotRequest(rawScreenshotData, callback)
                 }
-                .getOrNull()
-                ?: return
+                .getOrNull() ?: return
 
         logScreenshotRequested(screenshotData)
         Log.d(TAG, "Screenshot request: $screenshotData")
         try {
-            getScreenshotController(display).handleScreenshot(screenshotData, onSaved, callback)
+            screenshotHandler.handleScreenshot(screenshotData, onSaved, callback)
         } catch (e: IllegalStateException) {
             Log.e(TAG, "Error while ScreenshotController was handling ScreenshotData!", e)
             onFailedScreenshotRequest(screenshotData, callback)
@@ -140,44 +173,32 @@
 
     private suspend fun getDisplaysToScreenshot(requestType: Int): List<Display> {
         val allDisplays = displays.first()
-        return if (requestType == TAKE_SCREENSHOT_PROVIDED_IMAGE || screenshotShelfUi2()) {
-            // If this is a provided image or using the shelf UI, just screenshot th default display
+        return if (requestType == TAKE_SCREENSHOT_PROVIDED_IMAGE) {
+            // If this is a provided image, just screenshot the default display.
             allDisplays.filter { it.displayId == Display.DEFAULT_DISPLAY }
         } else {
             allDisplays.filter { it.type in ALLOWED_DISPLAY_TYPES }
         }
     }
 
-    /** Propagates the close system dialog signal to all controllers. */
+    /** Requests dismissal of the ScreenshotController unless a shared transition is pending. */
     override fun onCloseSystemDialogsReceived() {
-        screenshotControllers.forEach { (_, screenshotController) ->
-            if (!screenshotController.isPendingSharedTransition) {
-                screenshotController.requestDismissal(ScreenshotEvent.SCREENSHOT_DISMISSED_OTHER)
-            }
+        screenshotController?.takeUnless { it.isPendingSharedTransition }?.let { controller ->
+            controller.requestDismissal(SCREENSHOT_DISMISSED_OTHER)
         }
     }
 
     /** Removes all screenshot related windows. */
     override fun removeWindows() {
-        screenshotControllers.forEach { (_, screenshotController) ->
-            screenshotController.removeWindow()
-        }
+        screenshotController?.removeWindow()
     }
 
     /**
      * Destroys the executor. Afterwards, this class is not expected to work as intended anymore.
      */
     override fun onDestroy() {
-        screenshotControllers.forEach { (_, screenshotController) ->
-            screenshotController.onDestroy()
-        }
-        screenshotControllers.clear()
-    }
-
-    private fun getScreenshotController(display: Display): ScreenshotController {
-        return screenshotControllers.computeIfAbsent(display.displayId) {
-            screenshotControllerFactory.create(display, /* showUIOnExternalDisplay= */ false)
-        }
+        screenshotController?.onDestroy()
+        screenshotController = null
     }
 
     private fun getNotificationController(id: Int): ScreenshotNotificationsController {
@@ -197,6 +218,12 @@
         }
     }
 
+    /** Returns the cached [ScreenshotController], creating one for [display] on first use. */
+    private fun getScreenshotController(display: Display): ScreenshotController =
+        screenshotController
+            ?: screenshotControllerFactory.create(display, /* showUIOnExternalDisplay= */ false)
+                .also { screenshotController = it }
+
     /**
      * Returns a [RequestCallback] that wraps [originalCallback].
      *
diff --git a/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotExecutorTest.kt b/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotExecutorTest.kt
index ec5589e..0b81b5e 100644
--- a/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotExecutorTest.kt
+++ b/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotExecutorTest.kt
@@ -3,9 +3,6 @@
 import android.content.ComponentName
 import android.graphics.Bitmap
 import android.net.Uri
-import android.platform.test.annotations.DisableFlags
-import android.platform.test.annotations.EnableFlags
-import android.testing.AndroidTestingRunner
 import android.view.Display
 import android.view.Display.TYPE_EXTERNAL
 import android.view.Display.TYPE_INTERNAL
@@ -18,7 +15,6 @@
 import androidx.test.filters.SmallTest
 import com.android.internal.logging.testing.UiEventLoggerFake
 import com.android.internal.util.ScreenshotRequest
-import com.android.systemui.Flags
 import com.android.systemui.SysuiTestCase
 import com.android.systemui.display.data.repository.FakeDisplayRepository
 import com.android.systemui.display.data.repository.display
@@ -26,7 +22,6 @@
 import com.android.systemui.util.mockito.eq
 import com.android.systemui.util.mockito.kotlinArgumentCaptor as ArgumentCaptor
 import com.android.systemui.util.mockito.mock
-import com.android.systemui.util.mockito.nullable
 import com.android.systemui.util.mockito.whenever
 import com.google.common.truth.Truth.assertThat
 import java.lang.IllegalStateException
@@ -47,8 +42,7 @@
 @SmallTest
 class TakeScreenshotExecutorTest : SysuiTestCase() {
 
-    private val controller0 = mock<ScreenshotController>()
-    private val controller1 = mock<ScreenshotController>()
+    private val controller = mock<ScreenshotController>()
     private val notificationsController0 = mock<ScreenshotNotificationsController>()
     private val notificationsController1 = mock<ScreenshotNotificationsController>()
     private val controllerFactory = mock<ScreenshotController.Factory>()
@@ -60,6 +54,7 @@
     private val topComponent = ComponentName(mContext, TakeScreenshotExecutorTest::class.java)
     private val testScope = TestScope(UnconfinedTestDispatcher())
     private val eventLogger = UiEventLoggerFake()
+    private val headlessHandler = mock<HeadlessScreenshotHandler>()
 
     private val screenshotExecutor =
         TakeScreenshotExecutorImpl(
@@ -68,20 +63,18 @@
             testScope,
             requestProcessor,
             eventLogger,
-            notificationControllerFactory
+            notificationControllerFactory,
+            headlessHandler,
         )
 
     @Before
     fun setUp() {
-        whenever(controllerFactory.create(any(), any())).thenAnswer {
-            if (it.getArgument<Display>(0).displayId == 0) controller0 else controller1
-        }
+        whenever(controllerFactory.create(any(), any())).thenReturn(controller)
         whenever(notificationControllerFactory.create(eq(0))).thenReturn(notificationsController0)
         whenever(notificationControllerFactory.create(eq(1))).thenReturn(notificationsController1)
     }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_severalDisplays_callsControllerForEachOne() =
         testScope.runTest {
             val internalDisplay = display(TYPE_INTERNAL, id = 0)
@@ -91,14 +84,14 @@
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
 
             verify(controllerFactory).create(eq(internalDisplay), any())
-            verify(controllerFactory).create(eq(externalDisplay), any())
+            verify(controllerFactory, never()).create(eq(externalDisplay), any())
 
             val capturer = ArgumentCaptor<ScreenshotData>()
 
-            verify(controller0).handleScreenshot(capturer.capture(), any(), any())
+            verify(controller).handleScreenshot(capturer.capture(), any(), any())
             assertThat(capturer.value.displayId).isEqualTo(0)
             // OnSaved callback should be different.
-            verify(controller1).handleScreenshot(capturer.capture(), any(), any())
+            verify(headlessHandler).handleScreenshot(capturer.capture(), any(), any())
             assertThat(capturer.value.displayId).isEqualTo(1)
 
             assertThat(eventLogger.numLogs()).isEqualTo(2)
@@ -113,32 +106,6 @@
         }
 
     @Test
-    @EnableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
-    fun executeScreenshots_severalDisplaysShelfUi_justCallsOne() =
-        testScope.runTest {
-            val internalDisplay = display(TYPE_INTERNAL, id = 0)
-            val externalDisplay = display(TYPE_EXTERNAL, id = 1)
-            setDisplays(internalDisplay, externalDisplay)
-            val onSaved = { _: Uri? -> }
-            screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
-
-            verify(controllerFactory).create(eq(internalDisplay), any())
-
-            val capturer = ArgumentCaptor<ScreenshotData>()
-
-            verify(controller0).handleScreenshot(capturer.capture(), any(), any())
-            assertThat(capturer.value.displayId).isEqualTo(0)
-
-            assertThat(eventLogger.numLogs()).isEqualTo(1)
-            assertThat(eventLogger.get(0).eventId)
-                .isEqualTo(ScreenshotEvent.SCREENSHOT_REQUESTED_KEY_OTHER.id)
-            assertThat(eventLogger.get(0).packageName).isEqualTo(topComponent.packageName)
-
-            screenshotExecutor.onDestroy()
-        }
-
-    @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_providedImageType_callsOnlyDefaultDisplayController() =
         testScope.runTest {
             val internalDisplay = display(TYPE_INTERNAL, id = 0)
@@ -156,10 +123,10 @@
 
             val capturer = ArgumentCaptor<ScreenshotData>()
 
-            verify(controller0).handleScreenshot(capturer.capture(), any(), any())
+            verify(controller).handleScreenshot(capturer.capture(), any(), any())
             assertThat(capturer.value.displayId).isEqualTo(0)
             // OnSaved callback should be different.
-            verify(controller1, never()).handleScreenshot(any(), any(), any())
+            verify(headlessHandler, never()).handleScreenshot(any(), any(), any())
 
             assertThat(eventLogger.numLogs()).isEqualTo(1)
             assertThat(eventLogger.get(0).eventId)
@@ -170,7 +137,6 @@
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_onlyVirtualDisplays_noInteractionsWithControllers() =
         testScope.runTest {
             setDisplays(display(TYPE_VIRTUAL, id = 0), display(TYPE_VIRTUAL, id = 1))
@@ -178,14 +144,14 @@
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
 
             verifyNoMoreInteractions(controllerFactory)
+            verify(headlessHandler, never()).handleScreenshot(any(), any(), any())
             screenshotExecutor.onDestroy()
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_allowedTypes_allCaptured() =
         testScope.runTest {
-            whenever(controllerFactory.create(any(), any())).thenReturn(controller0)
+            whenever(controllerFactory.create(any(), any())).thenReturn(controller)
 
             setDisplays(
                 display(TYPE_INTERNAL, id = 0),
@@ -196,12 +162,12 @@
             val onSaved = { _: Uri? -> }
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
 
-            verify(controller0, times(4)).handleScreenshot(any(), any(), any())
+            verify(controller, times(1)).handleScreenshot(any(), any(), any())
+            verify(headlessHandler, times(3)).handleScreenshot(any(), any(), any())
             screenshotExecutor.onDestroy()
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_reportsOnFinishedOnlyWhenBothFinished() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -211,8 +177,8 @@
             val capturer0 = ArgumentCaptor<TakeScreenshotService.RequestCallback>()
             val capturer1 = ArgumentCaptor<TakeScreenshotService.RequestCallback>()
 
-            verify(controller0).handleScreenshot(any(), any(), capturer0.capture())
-            verify(controller1).handleScreenshot(any(), any(), capturer1.capture())
+            verify(controller).handleScreenshot(any(), any(), capturer0.capture())
+            verify(headlessHandler).handleScreenshot(any(), any(), capturer1.capture())
 
             verify(callback, never()).onFinish()
 
@@ -227,7 +193,6 @@
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_oneFinishesOtherFails_reportFailsOnlyAtTheEnd() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -237,8 +202,8 @@
             val capturer0 = ArgumentCaptor<TakeScreenshotService.RequestCallback>()
             val capturer1 = ArgumentCaptor<TakeScreenshotService.RequestCallback>()
 
-            verify(controller0).handleScreenshot(any(), any(), capturer0.capture())
-            verify(controller1).handleScreenshot(any(), nullable(), capturer1.capture())
+            verify(controller).handleScreenshot(any(), any(), capturer0.capture())
+            verify(headlessHandler).handleScreenshot(any(), any(), capturer1.capture())
 
             verify(callback, never()).onFinish()
 
@@ -255,7 +220,6 @@
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_allDisplaysFail_reportsFail() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -265,8 +229,8 @@
             val capturer0 = ArgumentCaptor<TakeScreenshotService.RequestCallback>()
             val capturer1 = ArgumentCaptor<TakeScreenshotService.RequestCallback>()
 
-            verify(controller0).handleScreenshot(any(), any(), capturer0.capture())
-            verify(controller1).handleScreenshot(any(), any(), capturer1.capture())
+            verify(controller).handleScreenshot(any(), any(), capturer0.capture())
+            verify(headlessHandler).handleScreenshot(any(), any(), capturer1.capture())
 
             verify(callback, never()).onFinish()
 
@@ -283,7 +247,6 @@
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun onDestroy_propagatedToControllers() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -291,59 +254,50 @@
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
 
             screenshotExecutor.onDestroy()
-            verify(controller0).onDestroy()
-            verify(controller1).onDestroy()
+            verify(controller).onDestroy()
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
-    fun removeWindows_propagatedToControllers() =
+    fun removeWindows_propagatedToController() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
             val onSaved = { _: Uri? -> }
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
 
             screenshotExecutor.removeWindows()
-            verify(controller0).removeWindow()
-            verify(controller1).removeWindow()
+            verify(controller).removeWindow()
 
             screenshotExecutor.onDestroy()
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
-    fun onCloseSystemDialogsReceived_propagatedToControllers() =
+    fun onCloseSystemDialogsReceived_propagatedToController() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
             val onSaved = { _: Uri? -> }
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
 
             screenshotExecutor.onCloseSystemDialogsReceived()
-            verify(controller0).requestDismissal(any())
-            verify(controller1).requestDismissal(any())
+            verify(controller).requestDismissal(any())
 
             screenshotExecutor.onDestroy()
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
-    fun onCloseSystemDialogsReceived_someControllerHavePendingTransitions() =
+    fun onCloseSystemDialogsReceived_controllerHasPendingTransitions() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
-            whenever(controller0.isPendingSharedTransition).thenReturn(true)
-            whenever(controller1.isPendingSharedTransition).thenReturn(false)
+            whenever(controller.isPendingSharedTransition).thenReturn(true)
             val onSaved = { _: Uri? -> }
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
 
             screenshotExecutor.onCloseSystemDialogsReceived()
-            verify(controller0, never()).requestDismissal(any())
-            verify(controller1).requestDismissal(any())
+            verify(controller, never()).requestDismissal(any())
 
             screenshotExecutor.onDestroy()
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_controllerCalledWithRequestProcessorReturnValue() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0))
@@ -358,14 +312,13 @@
                 .isEqualTo(ScreenshotData.fromRequest(screenshotRequest))
 
             val capturer = ArgumentCaptor<ScreenshotData>()
-            verify(controller0).handleScreenshot(capturer.capture(), any(), any())
+            verify(controller).handleScreenshot(capturer.capture(), any(), any())
             assertThat(capturer.value).isEqualTo(toBeReturnedByProcessor)
 
             screenshotExecutor.onDestroy()
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_errorFromProcessor_logsScreenshotRequested() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -383,7 +336,6 @@
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_errorFromProcessor_logsUiError() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -401,7 +353,6 @@
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_errorFromProcessorOnDefaultDisplay_showsErrorNotification() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -428,14 +379,13 @@
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_errorFromScreenshotController_reportsRequested() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
             val onSaved = { _: Uri? -> }
-            whenever(controller0.handleScreenshot(any(), any(), any()))
+            whenever(controller.handleScreenshot(any(), any(), any()))
                 .thenThrow(IllegalStateException::class.java)
-            whenever(controller1.handleScreenshot(any(), any(), any()))
+            whenever(headlessHandler.handleScreenshot(any(), any(), any()))
                 .thenThrow(IllegalStateException::class.java)
 
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
@@ -449,14 +399,13 @@
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_errorFromScreenshotController_reportsError() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
             val onSaved = { _: Uri? -> }
-            whenever(controller0.handleScreenshot(any(), any(), any()))
+            whenever(controller.handleScreenshot(any(), any(), any()))
                 .thenThrow(IllegalStateException::class.java)
-            whenever(controller1.handleScreenshot(any(), any(), any()))
+            whenever(headlessHandler.handleScreenshot(any(), any(), any()))
                 .thenThrow(IllegalStateException::class.java)
 
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
@@ -470,14 +419,13 @@
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_errorFromScreenshotController_showsErrorNotification() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
             val onSaved = { _: Uri? -> }
-            whenever(controller0.handleScreenshot(any(), any(), any()))
+            whenever(controller.handleScreenshot(any(), any(), any()))
                 .thenThrow(IllegalStateException::class.java)
-            whenever(controller1.handleScreenshot(any(), any(), any()))
+            whenever(headlessHandler.handleScreenshot(any(), any(), any()))
                 .thenThrow(IllegalStateException::class.java)
 
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
@@ -496,7 +444,7 @@
                 assertThat(it).isNull()
                 onSavedCallCount += 1
             }
-            whenever(controller0.handleScreenshot(any(), any(), any())).thenAnswer {
+            whenever(controller.handleScreenshot(any(), any(), any())).thenAnswer {
                 (it.getArgument(1) as Consumer<Uri?>).accept(null)
             }
 
@@ -525,6 +473,7 @@
         var processed: ScreenshotData? = null
         var toReturn: ScreenshotData? = null
         var shouldThrowException = false
+
         override suspend fun process(screenshot: ScreenshotData): ScreenshotData {
             if (shouldThrowException) throw RequestProcessorException("")
             processed = screenshot