Make non-main displays take headless screenshots.

Instead of expecting ScreenshotController to function headlessly,
separate the small bit of code that does the actual saving into its own
class and use that instead of ScreenshotController when headless
operation is needed for secondary displays.

Existing "no-UI" functionality in ScreenshotController is untouched for
now to minimize the scope of the change.

Bug: 339424226
Test: atest TakeScreenshotExecutorTest
Test: manual testing with secondary display.
Flag: EXEMPT bugfix
Change-Id: I3abc3e72f43649e114358ab7aa6d14e807fc3bf2
diff --git a/packages/SystemUI/src/com/android/systemui/screenshot/HeadlessScreenshotHandler.kt b/packages/SystemUI/src/com/android/systemui/screenshot/HeadlessScreenshotHandler.kt
new file mode 100644
index 0000000..6730d2d
--- /dev/null
+++ b/packages/SystemUI/src/com/android/systemui/screenshot/HeadlessScreenshotHandler.kt
@@ -0,0 +1,114 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.systemui.screenshot
+
+import android.net.Uri
+import android.os.UserManager
+import android.util.Log
+import android.view.WindowManager
+import com.android.internal.logging.UiEventLogger
+import com.android.systemui.dagger.qualifiers.Main
+import com.android.systemui.res.R
+import com.google.common.util.concurrent.ListenableFuture
+import java.util.UUID
+import java.util.concurrent.Executor
+import java.util.concurrent.Executors
+import java.util.function.Consumer
+import javax.inject.Inject
+
+/**
+ * A ScreenshotHandler that just saves the screenshot and calls back as appropriate, with no UI.
+ *
+ * Basically, ScreenshotController with all the UI bits ripped out.
+ */
+class HeadlessScreenshotHandler
+@Inject
+constructor(
+    private val imageExporter: ImageExporter,
+    @Main private val mainExecutor: Executor,
+    private val imageCapture: ImageCapture,
+    private val userManager: UserManager,
+    private val uiEventLogger: UiEventLogger,
+    private val notificationsControllerFactory: ScreenshotNotificationsController.Factory,
+) : ScreenshotHandler {
+
+    override fun handleScreenshot(
+        screenshot: ScreenshotData,
+        finisher: Consumer<Uri?>,
+        requestCallback: TakeScreenshotService.RequestCallback
+    ) {
+        if (screenshot.type == WindowManager.TAKE_SCREENSHOT_FULLSCREEN) {
+            screenshot.bitmap = imageCapture.captureDisplay(screenshot.displayId, crop = null)
+        }
+
+        if (screenshot.bitmap == null) {
+            Log.e(TAG, "handleScreenshot: Screenshot bitmap was null")
+            notificationsControllerFactory
+                .create(screenshot.displayId)
+                .notifyScreenshotError(R.string.screenshot_failed_to_capture_text)
+            requestCallback.reportError()
+            return
+        }
+
+        val future: ListenableFuture<ImageExporter.Result> =
+            imageExporter.export(
+                Executors.newSingleThreadExecutor(),
+                UUID.randomUUID(),
+                screenshot.bitmap,
+                screenshot.getUserOrDefault(),
+                screenshot.displayId
+            )
+        future.addListener(
+            {
+                try {
+                    val result = future.get()
+                    Log.d(TAG, "Saved screenshot: $result")
+                    logScreenshotResultStatus(result.uri, screenshot)
+                    finisher.accept(result.uri)
+                    requestCallback.onFinish()
+                } catch (e: Exception) {
+                    Log.d(TAG, "Failed to store screenshot", e)
+                    finisher.accept(null)
+                    requestCallback.reportError()
+                }
+            },
+            mainExecutor
+        )
+    }
+
+    private fun logScreenshotResultStatus(uri: Uri?, screenshot: ScreenshotData) {
+        if (uri == null) {
+            uiEventLogger.log(ScreenshotEvent.SCREENSHOT_NOT_SAVED, 0, screenshot.packageNameString)
+            notificationsControllerFactory
+                .create(screenshot.displayId)
+                .notifyScreenshotError(R.string.screenshot_failed_to_save_text)
+        } else {
+            uiEventLogger.log(ScreenshotEvent.SCREENSHOT_SAVED, 0, screenshot.packageNameString)
+            if (userManager.isManagedProfile(screenshot.getUserOrDefault().identifier)) {
+                uiEventLogger.log(
+                    ScreenshotEvent.SCREENSHOT_SAVED_TO_WORK_PROFILE,
+                    0,
+                    screenshot.packageNameString
+                )
+            }
+        }
+    }
+
+    companion object {
+        const val TAG = "HeadlessScreenshotHandler"
+    }
+}
diff --git a/packages/SystemUI/src/com/android/systemui/screenshot/ScreenshotController.java b/packages/SystemUI/src/com/android/systemui/screenshot/ScreenshotController.java
index e8dfac8..c87b1f5 100644
--- a/packages/SystemUI/src/com/android/systemui/screenshot/ScreenshotController.java
+++ b/packages/SystemUI/src/com/android/systemui/screenshot/ScreenshotController.java
@@ -101,7 +101,7 @@
 /**
  * Controls the state and flow for screenshots.
  */
-public class ScreenshotController {
+public class ScreenshotController implements ScreenshotHandler {
     private static final String TAG = logTag(ScreenshotController.class);
 
     /**
@@ -351,7 +351,8 @@
         mShowUIOnExternalDisplay = showUIOnExternalDisplay;
     }
 
-    void handleScreenshot(ScreenshotData screenshot, Consumer<Uri> finisher,
+    @Override
+    public void handleScreenshot(ScreenshotData screenshot, Consumer<Uri> finisher,
             RequestCallback requestCallback) {
         Assert.isMainThread();
 
diff --git a/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotExecutor.kt b/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotExecutor.kt
index 3c3797b..2699657 100644
--- a/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotExecutor.kt
+++ b/packages/SystemUI/src/com/android/systemui/screenshot/TakeScreenshotExecutor.kt
@@ -1,3 +1,19 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package com.android.systemui.screenshot
 
 import android.net.Uri
@@ -7,12 +23,12 @@
 import android.view.WindowManager.TAKE_SCREENSHOT_PROVIDED_IMAGE
 import com.android.internal.logging.UiEventLogger
 import com.android.internal.util.ScreenshotRequest
-import com.android.systemui.Flags.screenshotShelfUi2
 import com.android.systemui.dagger.SysUISingleton
 import com.android.systemui.dagger.qualifiers.Application
 import com.android.systemui.display.data.repository.DisplayRepository
 import com.android.systemui.res.R
 import com.android.systemui.screenshot.ScreenshotEvent.SCREENSHOT_CAPTURE_FAILED
+import com.android.systemui.screenshot.ScreenshotEvent.SCREENSHOT_DISMISSED_OTHER
 import com.android.systemui.screenshot.TakeScreenshotService.RequestCallback
 import java.util.function.Consumer
 import javax.inject.Inject
@@ -26,9 +42,13 @@
         onSaved: (Uri?) -> Unit,
         requestCallback: RequestCallback
     )
+
     fun onCloseSystemDialogsReceived()
+
     fun removeWindows()
+
     fun onDestroy()
+
     fun executeScreenshotsAsync(
         screenshotRequest: ScreenshotRequest,
         onSaved: Consumer<Uri?>,
@@ -36,6 +56,14 @@
     )
 }
 
+interface ScreenshotHandler {
+    fun handleScreenshot(
+        screenshot: ScreenshotData,
+        finisher: Consumer<Uri?>,
+        requestCallback: RequestCallback
+    )
+}
+
 /**
  * Receives the signal to take a screenshot from [TakeScreenshotService], and calls back with the
  * result.
@@ -52,10 +80,10 @@
     private val screenshotRequestProcessor: ScreenshotRequestProcessor,
     private val uiEventLogger: UiEventLogger,
     private val screenshotNotificationControllerFactory: ScreenshotNotificationsController.Factory,
+    private val headlessScreenshotHandler: HeadlessScreenshotHandler,
 ) : TakeScreenshotExecutor {
-
     private val displays = displayRepository.displays
-    private val screenshotControllers = mutableMapOf<Int, ScreenshotController>()
+    private var screenshotController: ScreenshotController? = null
     private val notificationControllers = mutableMapOf<Int, ScreenshotNotificationsController>()
 
     /**
@@ -73,9 +101,15 @@
         val resultCallbackWrapper = MultiResultCallbackWrapper(requestCallback)
         displays.forEach { display ->
             val displayId = display.displayId
+            var screenshotHandler: ScreenshotHandler =
+                if (displayId == Display.DEFAULT_DISPLAY) {
+                    getScreenshotController(display)
+                } else {
+                    headlessScreenshotHandler
+                }
             Log.d(TAG, "Executing screenshot for display $displayId")
             dispatchToController(
-                display = display,
+                screenshotHandler,
                 rawScreenshotData = ScreenshotData.fromRequest(screenshotRequest, displayId),
                 onSaved =
                     if (displayId == Display.DEFAULT_DISPLAY) {
@@ -88,7 +122,7 @@
 
     /** All logging should be triggered only by this method. */
     private suspend fun dispatchToController(
-        display: Display,
+        screenshotHandler: ScreenshotHandler,
         rawScreenshotData: ScreenshotData,
         onSaved: (Uri?) -> Unit,
         callback: RequestCallback
@@ -102,13 +136,12 @@
                     logScreenshotRequested(rawScreenshotData)
                     onFailedScreenshotRequest(rawScreenshotData, callback)
                 }
-                .getOrNull()
-                ?: return
+                .getOrNull() ?: return
 
         logScreenshotRequested(screenshotData)
         Log.d(TAG, "Screenshot request: $screenshotData")
         try {
-            getScreenshotController(display).handleScreenshot(screenshotData, onSaved, callback)
+            screenshotHandler.handleScreenshot(screenshotData, onSaved, callback)
         } catch (e: IllegalStateException) {
             Log.e(TAG, "Error while ScreenshotController was handling ScreenshotData!", e)
             onFailedScreenshotRequest(screenshotData, callback)
@@ -140,44 +173,32 @@
 
     private suspend fun getDisplaysToScreenshot(requestType: Int): List<Display> {
         val allDisplays = displays.first()
-        return if (requestType == TAKE_SCREENSHOT_PROVIDED_IMAGE || screenshotShelfUi2()) {
-            // If this is a provided image or using the shelf UI, just screenshot th default display
+        return if (requestType == TAKE_SCREENSHOT_PROVIDED_IMAGE) {
+            // If this is a provided image just screenshot th default display
             allDisplays.filter { it.displayId == Display.DEFAULT_DISPLAY }
         } else {
             allDisplays.filter { it.type in ALLOWED_DISPLAY_TYPES }
         }
     }
 
-    /** Propagates the close system dialog signal to all controllers. */
+    /** Propagates the close system dialog signal to the ScreenshotController. */
     override fun onCloseSystemDialogsReceived() {
-        screenshotControllers.forEach { (_, screenshotController) ->
-            if (!screenshotController.isPendingSharedTransition) {
-                screenshotController.requestDismissal(ScreenshotEvent.SCREENSHOT_DISMISSED_OTHER)
-            }
+        if (screenshotController?.isPendingSharedTransition == false) {
+            screenshotController?.requestDismissal(SCREENSHOT_DISMISSED_OTHER)
         }
     }
 
     /** Removes all screenshot related windows. */
     override fun removeWindows() {
-        screenshotControllers.forEach { (_, screenshotController) ->
-            screenshotController.removeWindow()
-        }
+        screenshotController?.removeWindow()
     }
 
     /**
      * Destroys the executor. Afterwards, this class is not expected to work as intended anymore.
      */
     override fun onDestroy() {
-        screenshotControllers.forEach { (_, screenshotController) ->
-            screenshotController.onDestroy()
-        }
-        screenshotControllers.clear()
-    }
-
-    private fun getScreenshotController(display: Display): ScreenshotController {
-        return screenshotControllers.computeIfAbsent(display.displayId) {
-            screenshotControllerFactory.create(display, /* showUIOnExternalDisplay= */ false)
-        }
+        screenshotController?.onDestroy()
+        screenshotController = null
     }
 
     private fun getNotificationController(id: Int): ScreenshotNotificationsController {
@@ -197,6 +218,12 @@
         }
     }
 
+    private fun getScreenshotController(display: Display): ScreenshotController {
+        val controller = screenshotController ?: screenshotControllerFactory.create(display, false)
+        screenshotController = controller
+        return controller
+    }
+
     /**
      * Returns a [RequestCallback] that wraps [originalCallback].
      *
diff --git a/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotExecutorTest.kt b/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotExecutorTest.kt
index ec5589e..0b81b5e 100644
--- a/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotExecutorTest.kt
+++ b/packages/SystemUI/tests/src/com/android/systemui/screenshot/TakeScreenshotExecutorTest.kt
@@ -3,9 +3,6 @@
 import android.content.ComponentName
 import android.graphics.Bitmap
 import android.net.Uri
-import android.platform.test.annotations.DisableFlags
-import android.platform.test.annotations.EnableFlags
-import android.testing.AndroidTestingRunner
 import android.view.Display
 import android.view.Display.TYPE_EXTERNAL
 import android.view.Display.TYPE_INTERNAL
@@ -18,7 +15,6 @@
 import androidx.test.filters.SmallTest
 import com.android.internal.logging.testing.UiEventLoggerFake
 import com.android.internal.util.ScreenshotRequest
-import com.android.systemui.Flags
 import com.android.systemui.SysuiTestCase
 import com.android.systemui.display.data.repository.FakeDisplayRepository
 import com.android.systemui.display.data.repository.display
@@ -26,7 +22,6 @@
 import com.android.systemui.util.mockito.eq
 import com.android.systemui.util.mockito.kotlinArgumentCaptor as ArgumentCaptor
 import com.android.systemui.util.mockito.mock
-import com.android.systemui.util.mockito.nullable
 import com.android.systemui.util.mockito.whenever
 import com.google.common.truth.Truth.assertThat
 import java.lang.IllegalStateException
@@ -47,8 +42,7 @@
 @SmallTest
 class TakeScreenshotExecutorTest : SysuiTestCase() {
 
-    private val controller0 = mock<ScreenshotController>()
-    private val controller1 = mock<ScreenshotController>()
+    private val controller = mock<ScreenshotController>()
     private val notificationsController0 = mock<ScreenshotNotificationsController>()
     private val notificationsController1 = mock<ScreenshotNotificationsController>()
     private val controllerFactory = mock<ScreenshotController.Factory>()
@@ -60,6 +54,7 @@
     private val topComponent = ComponentName(mContext, TakeScreenshotExecutorTest::class.java)
     private val testScope = TestScope(UnconfinedTestDispatcher())
     private val eventLogger = UiEventLoggerFake()
+    private val headlessHandler = mock<HeadlessScreenshotHandler>()
 
     private val screenshotExecutor =
         TakeScreenshotExecutorImpl(
@@ -68,20 +63,18 @@
             testScope,
             requestProcessor,
             eventLogger,
-            notificationControllerFactory
+            notificationControllerFactory,
+            headlessHandler,
         )
 
     @Before
     fun setUp() {
-        whenever(controllerFactory.create(any(), any())).thenAnswer {
-            if (it.getArgument<Display>(0).displayId == 0) controller0 else controller1
-        }
+        whenever(controllerFactory.create(any(), any())).thenReturn(controller)
         whenever(notificationControllerFactory.create(eq(0))).thenReturn(notificationsController0)
         whenever(notificationControllerFactory.create(eq(1))).thenReturn(notificationsController1)
     }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_severalDisplays_callsControllerForEachOne() =
         testScope.runTest {
             val internalDisplay = display(TYPE_INTERNAL, id = 0)
@@ -91,14 +84,14 @@
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
 
             verify(controllerFactory).create(eq(internalDisplay), any())
-            verify(controllerFactory).create(eq(externalDisplay), any())
+            verify(controllerFactory, never()).create(eq(externalDisplay), any())
 
             val capturer = ArgumentCaptor<ScreenshotData>()
 
-            verify(controller0).handleScreenshot(capturer.capture(), any(), any())
+            verify(controller).handleScreenshot(capturer.capture(), any(), any())
             assertThat(capturer.value.displayId).isEqualTo(0)
             // OnSaved callback should be different.
-            verify(controller1).handleScreenshot(capturer.capture(), any(), any())
+            verify(headlessHandler).handleScreenshot(capturer.capture(), any(), any())
             assertThat(capturer.value.displayId).isEqualTo(1)
 
             assertThat(eventLogger.numLogs()).isEqualTo(2)
@@ -113,32 +106,6 @@
         }
 
     @Test
-    @EnableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
-    fun executeScreenshots_severalDisplaysShelfUi_justCallsOne() =
-        testScope.runTest {
-            val internalDisplay = display(TYPE_INTERNAL, id = 0)
-            val externalDisplay = display(TYPE_EXTERNAL, id = 1)
-            setDisplays(internalDisplay, externalDisplay)
-            val onSaved = { _: Uri? -> }
-            screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
-
-            verify(controllerFactory).create(eq(internalDisplay), any())
-
-            val capturer = ArgumentCaptor<ScreenshotData>()
-
-            verify(controller0).handleScreenshot(capturer.capture(), any(), any())
-            assertThat(capturer.value.displayId).isEqualTo(0)
-
-            assertThat(eventLogger.numLogs()).isEqualTo(1)
-            assertThat(eventLogger.get(0).eventId)
-                .isEqualTo(ScreenshotEvent.SCREENSHOT_REQUESTED_KEY_OTHER.id)
-            assertThat(eventLogger.get(0).packageName).isEqualTo(topComponent.packageName)
-
-            screenshotExecutor.onDestroy()
-        }
-
-    @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_providedImageType_callsOnlyDefaultDisplayController() =
         testScope.runTest {
             val internalDisplay = display(TYPE_INTERNAL, id = 0)
@@ -156,10 +123,10 @@
 
             val capturer = ArgumentCaptor<ScreenshotData>()
 
-            verify(controller0).handleScreenshot(capturer.capture(), any(), any())
+            verify(controller).handleScreenshot(capturer.capture(), any(), any())
             assertThat(capturer.value.displayId).isEqualTo(0)
             // OnSaved callback should be different.
-            verify(controller1, never()).handleScreenshot(any(), any(), any())
+            verify(headlessHandler, never()).handleScreenshot(any(), any(), any())
 
             assertThat(eventLogger.numLogs()).isEqualTo(1)
             assertThat(eventLogger.get(0).eventId)
@@ -170,7 +137,6 @@
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_onlyVirtualDisplays_noInteractionsWithControllers() =
         testScope.runTest {
             setDisplays(display(TYPE_VIRTUAL, id = 0), display(TYPE_VIRTUAL, id = 1))
@@ -178,14 +144,14 @@
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
 
             verifyNoMoreInteractions(controllerFactory)
+            verify(headlessHandler, never()).handleScreenshot(any(), any(), any())
             screenshotExecutor.onDestroy()
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_allowedTypes_allCaptured() =
         testScope.runTest {
-            whenever(controllerFactory.create(any(), any())).thenReturn(controller0)
+            whenever(controllerFactory.create(any(), any())).thenReturn(controller)
 
             setDisplays(
                 display(TYPE_INTERNAL, id = 0),
@@ -196,12 +162,12 @@
             val onSaved = { _: Uri? -> }
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
 
-            verify(controller0, times(4)).handleScreenshot(any(), any(), any())
+            verify(controller, times(1)).handleScreenshot(any(), any(), any())
+            verify(headlessHandler, times(3)).handleScreenshot(any(), any(), any())
             screenshotExecutor.onDestroy()
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_reportsOnFinishedOnlyWhenBothFinished() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -211,8 +177,8 @@
             val capturer0 = ArgumentCaptor<TakeScreenshotService.RequestCallback>()
             val capturer1 = ArgumentCaptor<TakeScreenshotService.RequestCallback>()
 
-            verify(controller0).handleScreenshot(any(), any(), capturer0.capture())
-            verify(controller1).handleScreenshot(any(), any(), capturer1.capture())
+            verify(controller).handleScreenshot(any(), any(), capturer0.capture())
+            verify(headlessHandler).handleScreenshot(any(), any(), capturer1.capture())
 
             verify(callback, never()).onFinish()
 
@@ -227,7 +193,6 @@
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_oneFinishesOtherFails_reportFailsOnlyAtTheEnd() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -237,8 +202,8 @@
             val capturer0 = ArgumentCaptor<TakeScreenshotService.RequestCallback>()
             val capturer1 = ArgumentCaptor<TakeScreenshotService.RequestCallback>()
 
-            verify(controller0).handleScreenshot(any(), any(), capturer0.capture())
-            verify(controller1).handleScreenshot(any(), nullable(), capturer1.capture())
+            verify(controller).handleScreenshot(any(), any(), capturer0.capture())
+            verify(headlessHandler).handleScreenshot(any(), any(), capturer1.capture())
 
             verify(callback, never()).onFinish()
 
@@ -255,7 +220,6 @@
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_allDisplaysFail_reportsFail() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -265,8 +229,8 @@
             val capturer0 = ArgumentCaptor<TakeScreenshotService.RequestCallback>()
             val capturer1 = ArgumentCaptor<TakeScreenshotService.RequestCallback>()
 
-            verify(controller0).handleScreenshot(any(), any(), capturer0.capture())
-            verify(controller1).handleScreenshot(any(), any(), capturer1.capture())
+            verify(controller).handleScreenshot(any(), any(), capturer0.capture())
+            verify(headlessHandler).handleScreenshot(any(), any(), capturer1.capture())
 
             verify(callback, never()).onFinish()
 
@@ -283,7 +247,6 @@
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun onDestroy_propagatedToControllers() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -291,59 +254,50 @@
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
 
             screenshotExecutor.onDestroy()
-            verify(controller0).onDestroy()
-            verify(controller1).onDestroy()
+            verify(controller).onDestroy()
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
-    fun removeWindows_propagatedToControllers() =
+    fun removeWindows_propagatedToController() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
             val onSaved = { _: Uri? -> }
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
 
             screenshotExecutor.removeWindows()
-            verify(controller0).removeWindow()
-            verify(controller1).removeWindow()
+            verify(controller).removeWindow()
 
             screenshotExecutor.onDestroy()
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
-    fun onCloseSystemDialogsReceived_propagatedToControllers() =
+    fun onCloseSystemDialogsReceived_propagatedToController() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
             val onSaved = { _: Uri? -> }
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
 
             screenshotExecutor.onCloseSystemDialogsReceived()
-            verify(controller0).requestDismissal(any())
-            verify(controller1).requestDismissal(any())
+            verify(controller).requestDismissal(any())
 
             screenshotExecutor.onDestroy()
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
-    fun onCloseSystemDialogsReceived_someControllerHavePendingTransitions() =
+    fun onCloseSystemDialogsReceived_controllerHasPendingTransitions() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
-            whenever(controller0.isPendingSharedTransition).thenReturn(true)
-            whenever(controller1.isPendingSharedTransition).thenReturn(false)
+            whenever(controller.isPendingSharedTransition).thenReturn(true)
             val onSaved = { _: Uri? -> }
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
 
             screenshotExecutor.onCloseSystemDialogsReceived()
-            verify(controller0, never()).requestDismissal(any())
-            verify(controller1).requestDismissal(any())
+            verify(controller, never()).requestDismissal(any())
 
             screenshotExecutor.onDestroy()
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_controllerCalledWithRequestProcessorReturnValue() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0))
@@ -358,14 +312,13 @@
                 .isEqualTo(ScreenshotData.fromRequest(screenshotRequest))
 
             val capturer = ArgumentCaptor<ScreenshotData>()
-            verify(controller0).handleScreenshot(capturer.capture(), any(), any())
+            verify(controller).handleScreenshot(capturer.capture(), any(), any())
             assertThat(capturer.value).isEqualTo(toBeReturnedByProcessor)
 
             screenshotExecutor.onDestroy()
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_errorFromProcessor_logsScreenshotRequested() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -383,7 +336,6 @@
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_errorFromProcessor_logsUiError() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -401,7 +353,6 @@
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_errorFromProcessorOnDefaultDisplay_showsErrorNotification() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
@@ -428,14 +379,13 @@
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_errorFromScreenshotController_reportsRequested() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
             val onSaved = { _: Uri? -> }
-            whenever(controller0.handleScreenshot(any(), any(), any()))
+            whenever(controller.handleScreenshot(any(), any(), any()))
                 .thenThrow(IllegalStateException::class.java)
-            whenever(controller1.handleScreenshot(any(), any(), any()))
+            whenever(headlessHandler.handleScreenshot(any(), any(), any()))
                 .thenThrow(IllegalStateException::class.java)
 
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
@@ -449,14 +399,13 @@
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_errorFromScreenshotController_reportsError() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
             val onSaved = { _: Uri? -> }
-            whenever(controller0.handleScreenshot(any(), any(), any()))
+            whenever(controller.handleScreenshot(any(), any(), any()))
                 .thenThrow(IllegalStateException::class.java)
-            whenever(controller1.handleScreenshot(any(), any(), any()))
+            whenever(headlessHandler.handleScreenshot(any(), any(), any()))
                 .thenThrow(IllegalStateException::class.java)
 
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
@@ -470,14 +419,13 @@
         }
 
     @Test
-    @DisableFlags(Flags.FLAG_SCREENSHOT_SHELF_UI2)
     fun executeScreenshots_errorFromScreenshotController_showsErrorNotification() =
         testScope.runTest {
             setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
             val onSaved = { _: Uri? -> }
-            whenever(controller0.handleScreenshot(any(), any(), any()))
+            whenever(controller.handleScreenshot(any(), any(), any()))
                 .thenThrow(IllegalStateException::class.java)
-            whenever(controller1.handleScreenshot(any(), any(), any()))
+            whenever(headlessHandler.handleScreenshot(any(), any(), any()))
                 .thenThrow(IllegalStateException::class.java)
 
             screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)
@@ -496,7 +444,7 @@
                 assertThat(it).isNull()
                 onSavedCallCount += 1
             }
-            whenever(controller0.handleScreenshot(any(), any(), any())).thenAnswer {
+            whenever(controller.handleScreenshot(any(), any(), any())).thenAnswer {
                 (it.getArgument(1) as Consumer<Uri?>).accept(null)
             }
 
@@ -525,6 +473,7 @@
         var processed: ScreenshotData? = null
         var toReturn: ScreenshotData? = null
         var shouldThrowException = false
+
         override suspend fun process(screenshot: ScreenshotData): ScreenshotData {
             if (shouldThrowException) throw RequestProcessorException("")
             processed = screenshot