Merge "Refactor obtaining display in SF main thread" into main
diff --git a/services/surfaceflinger/SurfaceFlinger.cpp b/services/surfaceflinger/SurfaceFlinger.cpp
index 5f81cd4..d9f7804 100644
--- a/services/surfaceflinger/SurfaceFlinger.cpp
+++ b/services/surfaceflinger/SurfaceFlinger.cpp
@@ -8220,36 +8220,66 @@
     futureFence.get();
 }
 
-ftl::SharedFuture<FenceResult> SurfaceFlinger::captureScreenshot(
-        RenderAreaBuilderVariant renderAreaBuilder, GetLayerSnapshotsFunction getLayerSnapshots,
-        const std::shared_ptr<renderengine::ExternalTexture>& buffer, bool regionSampling,
-        bool grayscale, bool isProtected, const sp<IScreenCaptureListener>& captureListener) {
-    ATRACE_CALL();
-
-    auto takeScreenshotFn = [=, this, renderAreaBuilder = std::move(renderAreaBuilder)]() REQUIRES(
-                                    kMainThreadContext) mutable -> ftl::SharedFuture<FenceResult> {
-        // LayerSnapshots must be obtained from the main thread.
-        auto layers = getLayerSnapshots();
-
+const sp<const DisplayDevice> SurfaceFlinger::getRenderAreaDisplay(
+        RenderAreaBuilderVariant& renderAreaBuilder, OutputCompositionState& state) {
+    sp<const DisplayDevice> display = nullptr;  // stays null only if every lookup below fails
+    {
+        Mutex::Autolock lock(mStateLock);  // covers the display lookup and the state copy
         if (auto* layerRenderAreaBuilder =
                     std::get_if<LayerRenderAreaBuilder>(&renderAreaBuilder)) {
             // LayerSnapshotBuilder should only be accessed from the main thread.
-            frontend::LayerSnapshot* snapshot =
+            const frontend::LayerSnapshot* snapshot =
                     mLayerSnapshotBuilder.getSnapshot(layerRenderAreaBuilder->layer->getSequence());
             if (!snapshot) {
                 ALOGW("Couldn't find layer snapshot for %d",
                       layerRenderAreaBuilder->layer->getSequence());
             } else {
                 layerRenderAreaBuilder->setLayerSnapshot(*snapshot);
+                display = findDisplay(
+                        [layerStack = snapshot->outputFilter.layerStack](const auto& display) {
+                            return display.getLayerStack() == layerStack;
+                        });  // pick the display whose layer stack matches the snapshot's
             }
+        } else if (auto* displayRenderAreaBuilder =
+                           std::get_if<DisplayRenderAreaBuilder>(&renderAreaBuilder)) {
+            display = displayRenderAreaBuilder->displayWeak.promote();  // null falls through
         }
 
-        if (FlagManager::getInstance().ce_fence_promise()) {
-            for (auto& [layer, layerFE] : layers) {
-                attachReleaseFenceFutureToLayer(layer, layerFE.get(), ui::INVALID_LAYER_STACK);
-            }
+        if (display == nullptr) {
+            display = getDefaultDisplayDeviceLocked();  // last resort: the default display
         }
 
+        if (display != nullptr) {
+            state = display->getCompositionDisplay()->getState();  // copy into the out-param
+        }
+    }
+    return display;  // may still be null; `state` is left untouched in that case
+}
+
+// Returns getLayerSnapshotsFn()'s result; with ce_fence_promise set, fences are attached first.
+std::vector<std::pair<Layer*, sp<android::LayerFE>>>
+SurfaceFlinger::getLayerSnapshotsFromMainThread(GetLayerSnapshotsFunction getLayerSnapshotsFn) {
+    auto snapshots = getLayerSnapshotsFn();
+    if (!FlagManager::getInstance().ce_fence_promise()) return snapshots;
+    for (auto& entry : snapshots) {
+        attachReleaseFenceFutureToLayer(entry.first, entry.second.get(), ui::INVALID_LAYER_STACK);
+    }
+    return snapshots;
+}
+
+ftl::SharedFuture<FenceResult> SurfaceFlinger::captureScreenshot(
+        RenderAreaBuilderVariant renderAreaBuilder, GetLayerSnapshotsFunction getLayerSnapshotsFn,
+        const std::shared_ptr<renderengine::ExternalTexture>& buffer, bool regionSampling,
+        bool grayscale, bool isProtected, const sp<IScreenCaptureListener>& captureListener) {
+    ATRACE_CALL();
+
+    auto takeScreenshotFn = [=, this, renderAreaBuilder = std::move(renderAreaBuilder)]() REQUIRES(
+                                    kMainThreadContext) mutable -> ftl::SharedFuture<FenceResult> {
+        auto layers = getLayerSnapshotsFromMainThread(getLayerSnapshotsFn);
+
+        OutputCompositionState state;
+        const auto display = getRenderAreaDisplay(renderAreaBuilder, state);
+
         ScreenCaptureResults captureResults;
         std::unique_ptr<const RenderArea> renderArea =
                 std::visit([](auto&& arg) -> std::unique_ptr<RenderArea> { return arg.build(); },
@@ -8266,7 +8296,7 @@
 
         ftl::SharedFuture<FenceResult> renderFuture =
                 renderScreenImpl(std::move(renderArea), buffer, regionSampling, grayscale,
-                                 isProtected, captureResults, layers);
+                                 isProtected, captureResults, display, state, layers);
 
         if (captureListener) {
             // Defer blocking on renderFuture back to the Binder thread.
@@ -8296,18 +8326,10 @@
 }
 
 ftl::SharedFuture<FenceResult> SurfaceFlinger::renderScreenImpl(
-        std::unique_ptr<const RenderArea> renderArea, GetLayerSnapshotsFunction getLayerSnapshots,
-        const std::shared_ptr<renderengine::ExternalTexture>& buffer, bool regionSampling,
-        bool grayscale, bool isProtected, ScreenCaptureResults& captureResults) {
-    auto layers = getLayerSnapshots();
-    return renderScreenImpl(std::move(renderArea), buffer, regionSampling, grayscale, isProtected,
-                            captureResults, layers);
-}
-
-ftl::SharedFuture<FenceResult> SurfaceFlinger::renderScreenImpl(
         std::unique_ptr<const RenderArea> renderArea,
         const std::shared_ptr<renderengine::ExternalTexture>& buffer, bool regionSampling,
         bool grayscale, bool isProtected, ScreenCaptureResults& captureResults,
+        const sp<const DisplayDevice> display, const OutputCompositionState& state,
         std::vector<std::pair<Layer*, sp<android::LayerFE>>>& layers) {
     ATRACE_CALL();
 
@@ -8334,60 +8356,36 @@
     const bool enableLocalTonemapping = FlagManager::getInstance().local_tonemap_screenshots() &&
             !renderArea->getHintForSeamlessTransition();
 
-    {
-        Mutex::Autolock lock(mStateLock);
-        const DisplayDevice* display = nullptr;
-        if (parent) {
-            const frontend::LayerSnapshot* snapshot =
-                    mLayerSnapshotBuilder.getSnapshot(parent->sequence);
-            if (snapshot) {
-                display = findDisplay([layerStack = snapshot->outputFilter.layerStack](
-                                              const auto& display) {
-                              return display.getLayerStack() == layerStack;
-                          }).get();
-            }
-        }
+    if (display != nullptr) {
+        captureResults.capturedDataspace =
+                pickBestDataspace(requestedDataspace, display.get(),
+                                  captureResults.capturedHdrLayers,
+                                  renderArea->getHintForSeamlessTransition());
+        sdrWhitePointNits = state.sdrWhitePointNits;
 
-        if (display == nullptr) {
-            display = renderArea->getDisplayDevice().get();
-        }
-
-        if (display == nullptr) {
-            display = getDefaultDisplayDeviceLocked().get();
-        }
-
-        if (display != nullptr) {
-            const auto& state = display->getCompositionDisplay()->getState();
-            captureResults.capturedDataspace =
-                    pickBestDataspace(requestedDataspace, display, captureResults.capturedHdrLayers,
-                                      renderArea->getHintForSeamlessTransition());
-            sdrWhitePointNits = state.sdrWhitePointNits;
-
-            if (!captureResults.capturedHdrLayers) {
-                displayBrightnessNits = sdrWhitePointNits;
-            } else {
-                displayBrightnessNits = state.displayBrightnessNits;
-
-                if (!enableLocalTonemapping) {
-                    // Only clamp the display brightness if this is not a seamless transition.
-                    // Otherwise for seamless transitions it's important to match the current
-                    // display state as the buffer will be shown under these same conditions, and we
-                    // want to avoid any flickers
-                    if (sdrWhitePointNits > 1.0f && !renderArea->getHintForSeamlessTransition()) {
-                        // Restrict the amount of HDR "headroom" in the screenshot to avoid
-                        // over-dimming the SDR portion. 2.0 chosen by experimentation
-                        constexpr float kMaxScreenshotHeadroom = 2.0f;
-                        displayBrightnessNits = std::min(sdrWhitePointNits * kMaxScreenshotHeadroom,
-                                                         displayBrightnessNits);
-                    }
+        if (!captureResults.capturedHdrLayers) {
+            displayBrightnessNits = sdrWhitePointNits;
+        } else {
+            displayBrightnessNits = state.displayBrightnessNits;
+            if (!enableLocalTonemapping) {
+                // Only clamp the display brightness if this is not a seamless transition.
+                // Otherwise for seamless transitions it's important to match the current
+                // display state as the buffer will be shown under these same conditions, and we
+                // want to avoid any flickers
+                if (sdrWhitePointNits > 1.0f && !renderArea->getHintForSeamlessTransition()) {
+                    // Restrict the amount of HDR "headroom" in the screenshot to avoid
+                    // over-dimming the SDR portion. 2.0 chosen by experimentation
+                    constexpr float kMaxScreenshotHeadroom = 2.0f;
+                    displayBrightnessNits = std::min(sdrWhitePointNits * kMaxScreenshotHeadroom,
+                                                     displayBrightnessNits);
                 }
             }
+        }
 
-            // Screenshots leaving the device should be colorimetric
-            if (requestedDataspace == ui::Dataspace::UNKNOWN &&
-                renderArea->getHintForSeamlessTransition()) {
-                renderIntent = state.renderIntent;
-            }
+        // Screenshots leaving the device should be colorimetric
+        if (requestedDataspace == ui::Dataspace::UNKNOWN &&
+            renderArea->getHintForSeamlessTransition()) {
+            renderIntent = state.renderIntent;
         }
     }
 
diff --git a/services/surfaceflinger/SurfaceFlinger.h b/services/surfaceflinger/SurfaceFlinger.h
index 8a39016..209d9bc 100644
--- a/services/surfaceflinger/SurfaceFlinger.h
+++ b/services/surfaceflinger/SurfaceFlinger.h
@@ -898,25 +898,26 @@
                              ui::Size bufferSize, ui::PixelFormat, bool allowProtected,
                              bool grayscale, const sp<IScreenCaptureListener>&);
 
+    using OutputCompositionState = compositionengine::impl::OutputCompositionState;
+
+    const sp<const DisplayDevice> getRenderAreaDisplay(RenderAreaBuilderVariant& renderAreaBuilder,
+                                                       OutputCompositionState& state)
+            REQUIRES(kMainThreadContext);
+
+    std::vector<std::pair<Layer*, sp<android::LayerFE>>> getLayerSnapshotsFromMainThread(
+            GetLayerSnapshotsFunction getLayerSnapshotsFn) REQUIRES(kMainThreadContext);
+
     ftl::SharedFuture<FenceResult> captureScreenshot(
             RenderAreaBuilderVariant, GetLayerSnapshotsFunction,
             const std::shared_ptr<renderengine::ExternalTexture>&, bool regionSampling,
             bool grayscale, bool isProtected, const sp<IScreenCaptureListener>&);
 
-    // Overloaded version of renderScreenImpl that is used when layer snapshots have
-    // not yet been captured, and thus cannot yet be passed in as a parameter.
-    // Needed for TestableSurfaceFlinger.
-    ftl::SharedFuture<FenceResult> renderScreenImpl(
-            std::unique_ptr<const RenderArea>, GetLayerSnapshotsFunction,
-            const std::shared_ptr<renderengine::ExternalTexture>&, bool regionSampling,
-            bool grayscale, bool isProtected, ScreenCaptureResults&) EXCLUDES(mStateLock)
-            REQUIRES(kMainThreadContext);
-
     ftl::SharedFuture<FenceResult> renderScreenImpl(
             std::unique_ptr<const RenderArea>,
             const std::shared_ptr<renderengine::ExternalTexture>&, bool regionSampling,
             bool grayscale, bool isProtected, ScreenCaptureResults&,
-            std::vector<std::pair<Layer*, sp<android::LayerFE>>>& layers) EXCLUDES(mStateLock)
+            const sp<const DisplayDevice> display, const OutputCompositionState& state,
+            std::vector<std::pair<Layer*, sp<android::LayerFE>>>& layers)
             REQUIRES(kMainThreadContext);
 
     // If the uid provided is not UNSET_UID, the traverse will skip any layers that don't have a
diff --git a/services/surfaceflinger/tests/unittests/CompositionTest.cpp b/services/surfaceflinger/tests/unittests/CompositionTest.cpp
index 0ddddbd..98f1687 100644
--- a/services/surfaceflinger/tests/unittests/CompositionTest.cpp
+++ b/services/surfaceflinger/tests/unittests/CompositionTest.cpp
@@ -215,7 +215,7 @@
                                                                       HAL_PIXEL_FORMAT_RGBA_8888, 1,
                                                                       usage);
 
-    auto future = mFlinger.renderScreenImpl(std::move(renderArea), getLayerSnapshots,
+    auto future = mFlinger.renderScreenImpl(mDisplay, std::move(renderArea), getLayerSnapshots,
                                             mCaptureScreenBuffer, regionSampling);
     ASSERT_TRUE(future.valid());
     const auto fenceResult = future.get();
diff --git a/services/surfaceflinger/tests/unittests/TestableSurfaceFlinger.h b/services/surfaceflinger/tests/unittests/TestableSurfaceFlinger.h
index 265f804..1783e17 100644
--- a/services/surfaceflinger/tests/unittests/TestableSurfaceFlinger.h
+++ b/services/surfaceflinger/tests/unittests/TestableSurfaceFlinger.h
@@ -485,16 +485,21 @@
         return mFlinger->setPowerModeInternal(display, mode);
     }
 
-    auto renderScreenImpl(std::unique_ptr<const RenderArea> renderArea,
+    auto renderScreenImpl(const sp<DisplayDevice> display,  // source of the composition state
+                          std::unique_ptr<const RenderArea> renderArea,
                           SurfaceFlinger::GetLayerSnapshotsFunction traverseLayers,
                           const std::shared_ptr<renderengine::ExternalTexture>& buffer,
                           bool regionSampling) {
+        Mutex::Autolock lock(mFlinger->mStateLock);  // held for the rest of this helper
+        ftl::FakeGuard guard(kMainThreadContext);  // callees are REQUIRES(kMainThreadContext)
+        // Gather what renderScreenImpl takes as arguments: composition state and layer snapshots.
         ScreenCaptureResults captureResults;
-        return FTL_FAKE_GUARD(kMainThreadContext,
-                              mFlinger->renderScreenImpl(std::move(renderArea), traverseLayers,
-                                                         buffer, regionSampling,
-                                                         false /* grayscale */,
-                                                         false /* isProtected */, captureResults));
+        SurfaceFlinger::OutputCompositionState state = display->getCompositionDisplay()->getState();
+        auto layers = mFlinger->getLayerSnapshotsFromMainThread(traverseLayers);
+        // NOTE(review): display is dereferenced above without a null check - confirm callers.
+        return mFlinger->renderScreenImpl(std::move(renderArea), buffer, regionSampling,
+                                          false /* grayscale */, false /* isProtected */,
+                                          captureResults, display, state, layers);
     }
 
     auto traverseLayersInLayerStack(ui::LayerStack layerStack, int32_t uid,