Draw gainmaps in HDR

Test: manual
Bug: 266628247
Change-Id: Iad0234089913637f9cddcb39d8fc32badabf3865
diff --git a/libs/hwui/Android.bp b/libs/hwui/Android.bp
index eeed226..9e3f115 100644
--- a/libs/hwui/Android.bp
+++ b/libs/hwui/Android.bp
@@ -511,6 +511,7 @@
         "canvas/CanvasOpBuffer.cpp",
         "canvas/CanvasOpRasterizer.cpp",
         "effects/StretchEffect.cpp",
+        "effects/GainmapRenderer.cpp",
         "pipeline/skia/HolePunch.cpp",
         "pipeline/skia/SkiaDisplayList.cpp",
         "pipeline/skia/SkiaRecordingCanvas.cpp",
diff --git a/libs/hwui/RecordingCanvas.cpp b/libs/hwui/RecordingCanvas.cpp
index 430e69e..659aec0 100644
--- a/libs/hwui/RecordingCanvas.cpp
+++ b/libs/hwui/RecordingCanvas.cpp
@@ -19,9 +19,9 @@
 #include <GrRecordingContext.h>
 #include <SkMesh.h>
 #include <hwui/Paint.h>
+#include <log/log.h>
 
 #include <experimental/type_traits>
-#include <log/log.h>
 #include <utility>
 
 #include "SkAndroidFrameworkUtils.h"
@@ -45,7 +45,8 @@
 #include "SkVertices.h"
 #include "Tonemapper.h"
 #include "VectorDrawable.h"
-#include "include/gpu/GpuTypes.h" // from Skia
+#include "effects/GainmapRenderer.h"
+#include "include/gpu/GpuTypes.h"  // from Skia
 #include "include/gpu/GrDirectContext.h"
 #include "pipeline/skia/AnimatedDrawables.h"
 #include "pipeline/skia/FunctorDrawable.h"
@@ -332,9 +333,15 @@
 
 struct DrawImage final : Op {
     static const auto kType = Type::DrawImage;
-    DrawImage(sk_sp<const SkImage>&& image, SkScalar x, SkScalar y,
-              const SkSamplingOptions& sampling, const SkPaint* paint, BitmapPalette palette)
-            : image(std::move(image)), x(x), y(y), sampling(sampling), palette(palette) {
+    DrawImage(DrawImagePayload&& payload, SkScalar x, SkScalar y, const SkSamplingOptions& sampling,
+              const SkPaint* paint)
+            : image(std::move(payload.image))
+            , x(x)
+            , y(y)
+            , sampling(sampling)
+            , palette(payload.palette)
+            , gainmap(std::move(payload.gainmapImage))
+            , gainmapInfo(payload.gainmapInfo) {
         if (paint) {
             this->paint = *paint;
         }
@@ -344,19 +351,34 @@
     SkSamplingOptions sampling;
     SkPaint paint;
     BitmapPalette palette;
+    sk_sp<const SkImage> gainmap;
+    SkGainmapInfo gainmapInfo;
+
     void draw(SkCanvas* c, const SkMatrix&) const {
-        SkPaint newPaint = paint;
-        tonemapPaint(image->imageInfo(), c->imageInfo(), -1, newPaint);
-        c->drawImage(image.get(), x, y, sampling, &newPaint);
+        if (gainmap) {
+            SkRect src = SkRect::MakeWH(image->width(), image->height());
+            SkRect dst = SkRect::MakeXYWH(x, y, src.width(), src.height());
+            DrawGainmapBitmap(c, image, src, dst, sampling, &paint,
+                              SkCanvas::kFast_SrcRectConstraint, gainmap, gainmapInfo);
+        } else {
+            SkPaint newPaint = paint;
+            tonemapPaint(image->imageInfo(), c->imageInfo(), -1, newPaint);
+            c->drawImage(image.get(), x, y, sampling, &newPaint);
+        }
     }
 };
 struct DrawImageRect final : Op {
     static const auto kType = Type::DrawImageRect;
-    DrawImageRect(sk_sp<const SkImage>&& image, const SkRect* src, const SkRect& dst,
+    DrawImageRect(DrawImagePayload&& payload, const SkRect* src, const SkRect& dst,
                   const SkSamplingOptions& sampling, const SkPaint* paint,
-                  SkCanvas::SrcRectConstraint constraint, BitmapPalette palette)
-            : image(std::move(image)), dst(dst), sampling(sampling), constraint(constraint)
-            , palette(palette) {
+                  SkCanvas::SrcRectConstraint constraint)
+            : image(std::move(payload.image))
+            , dst(dst)
+            , sampling(sampling)
+            , constraint(constraint)
+            , palette(payload.palette)
+            , gainmap(std::move(payload.gainmapImage))
+            , gainmapInfo(payload.gainmapInfo) {
         this->src = src ? *src : SkRect::MakeIWH(this->image->width(), this->image->height());
         if (paint) {
             this->paint = *paint;
@@ -368,25 +390,32 @@
     SkPaint paint;
     SkCanvas::SrcRectConstraint constraint;
     BitmapPalette palette;
+    sk_sp<const SkImage> gainmap;
+    SkGainmapInfo gainmapInfo;
+
     void draw(SkCanvas* c, const SkMatrix&) const {
-        SkPaint newPaint = paint;
-        tonemapPaint(image->imageInfo(), c->imageInfo(), -1, newPaint);
-        c->drawImageRect(image.get(), src, dst, sampling, &newPaint, constraint);
+        if (gainmap) {
+            DrawGainmapBitmap(c, image, src, dst, sampling, &paint, constraint, gainmap,
+                              gainmapInfo);
+        } else {
+            SkPaint newPaint = paint;
+            tonemapPaint(image->imageInfo(), c->imageInfo(), -1, newPaint);
+            c->drawImageRect(image.get(), src, dst, sampling, &newPaint, constraint);
+        }
     }
 };
 struct DrawImageLattice final : Op {
     static const auto kType = Type::DrawImageLattice;
-    DrawImageLattice(sk_sp<const SkImage>&& image, int xs, int ys, int fs, const SkIRect& src,
-                     const SkRect& dst, SkFilterMode filter, const SkPaint* paint,
-                     BitmapPalette palette)
-            : image(std::move(image))
+    DrawImageLattice(DrawImagePayload&& payload, int xs, int ys, int fs, const SkIRect& src,
+                     const SkRect& dst, SkFilterMode filter, const SkPaint* paint)
+            : image(std::move(payload.image))
             , xs(xs)
             , ys(ys)
             , fs(fs)
             , src(src)
             , dst(dst)
             , filter(filter)
-            , palette(palette) {
+            , palette(payload.palette) {
         if (paint) {
             this->paint = *paint;
         }
@@ -399,6 +428,8 @@
     SkPaint paint;
     BitmapPalette palette;
     void draw(SkCanvas* c, const SkMatrix&) const {
+        // TODO: Support drawing a gainmap 9-patch?
+
         auto xdivs = pod<int>(this, 0), ydivs = pod<int>(this, xs * sizeof(int));
         auto colors = (0 == fs) ? nullptr : pod<SkColor>(this, (xs + ys) * sizeof(int));
         auto flags =
@@ -781,27 +812,25 @@
                                   const SkPaint* paint) {
     this->push<DrawPicture>(0, picture, matrix, paint);
 }
-void DisplayListData::drawImage(sk_sp<const SkImage> image, SkScalar x, SkScalar y,
-                                const SkSamplingOptions& sampling, const SkPaint* paint,
-                                BitmapPalette palette) {
-    this->push<DrawImage>(0, std::move(image), x, y, sampling, paint, palette);
+void DisplayListData::drawImage(DrawImagePayload&& payload, SkScalar x, SkScalar y,
+                                const SkSamplingOptions& sampling, const SkPaint* paint) {
+    this->push<DrawImage>(0, std::move(payload), x, y, sampling, paint);
 }
-void DisplayListData::drawImageRect(sk_sp<const SkImage> image, const SkRect* src,
+void DisplayListData::drawImageRect(DrawImagePayload&& payload, const SkRect* src,
                                     const SkRect& dst, const SkSamplingOptions& sampling,
-                                    const SkPaint* paint, SkCanvas::SrcRectConstraint constraint,
-                                    BitmapPalette palette) {
-    this->push<DrawImageRect>(0, std::move(image), src, dst, sampling, paint, constraint, palette);
+                                    const SkPaint* paint, SkCanvas::SrcRectConstraint constraint) {
+    this->push<DrawImageRect>(0, std::move(payload), src, dst, sampling, paint, constraint);
 }
-void DisplayListData::drawImageLattice(sk_sp<const SkImage> image, const SkCanvas::Lattice& lattice,
-                                       const SkRect& dst, SkFilterMode filter, const SkPaint* paint,
-                                       BitmapPalette palette) {
+void DisplayListData::drawImageLattice(DrawImagePayload&& payload, const SkCanvas::Lattice& lattice,
+                                       const SkRect& dst, SkFilterMode filter,
+                                       const SkPaint* paint) {
     int xs = lattice.fXCount, ys = lattice.fYCount;
     int fs = lattice.fRectTypes ? (xs + 1) * (ys + 1) : 0;
     size_t bytes = (xs + ys) * sizeof(int) + fs * sizeof(SkCanvas::Lattice::RectType) +
                    fs * sizeof(SkColor);
     LOG_FATAL_IF(!lattice.fBounds);
-    void* pod = this->push<DrawImageLattice>(bytes, std::move(image), xs, ys, fs, *lattice.fBounds,
-                                             dst, filter, paint, palette);
+    void* pod = this->push<DrawImageLattice>(bytes, std::move(payload), xs, ys, fs,
+                                             *lattice.fBounds, dst, filter, paint);
     copy_v(pod, lattice.fXDivs, xs, lattice.fYDivs, ys, lattice.fColors, fs, lattice.fRectTypes,
            fs);
 }
@@ -1108,57 +1137,55 @@
     fDL->drawRippleDrawable(params);
 }
 
-void RecordingCanvas::drawImage(const sk_sp<SkImage>& image, SkScalar x, SkScalar y,
-                                const SkSamplingOptions& sampling, const SkPaint* paint,
-                                BitmapPalette palette) {
-    fDL->drawImage(image, x, y, sampling, paint, palette);
+void RecordingCanvas::drawImage(DrawImagePayload&& payload, SkScalar x, SkScalar y,
+                                const SkSamplingOptions& sampling, const SkPaint* paint) {
+    fDL->drawImage(std::move(payload), x, y, sampling, paint);
 }
 
-void RecordingCanvas::drawImageRect(const sk_sp<SkImage>& image, const SkRect& src,
+void RecordingCanvas::drawImageRect(DrawImagePayload&& payload, const SkRect& src,
                                     const SkRect& dst, const SkSamplingOptions& sampling,
-                                    const SkPaint* paint, SrcRectConstraint constraint,
-                                    BitmapPalette palette) {
-    fDL->drawImageRect(image, &src, dst, sampling, paint, constraint, palette);
+                                    const SkPaint* paint, SrcRectConstraint constraint) {
+    fDL->drawImageRect(std::move(payload), &src, dst, sampling, paint, constraint);
 }
 
-void RecordingCanvas::drawImageLattice(const sk_sp<SkImage>& image, const Lattice& lattice,
-                                       const SkRect& dst, SkFilterMode filter, const SkPaint* paint,
-                                       BitmapPalette palette) {
-    if (!image || dst.isEmpty()) {
+void RecordingCanvas::drawImageLattice(DrawImagePayload&& payload, const Lattice& lattice,
+                                       const SkRect& dst, SkFilterMode filter,
+                                       const SkPaint* paint) {
+    if (!payload.image || dst.isEmpty()) {
         return;
     }
 
     SkIRect bounds;
     Lattice latticePlusBounds = lattice;
     if (!latticePlusBounds.fBounds) {
-        bounds = SkIRect::MakeWH(image->width(), image->height());
+        bounds = SkIRect::MakeWH(payload.image->width(), payload.image->height());
         latticePlusBounds.fBounds = &bounds;
     }
 
-    if (SkLatticeIter::Valid(image->width(), image->height(), latticePlusBounds)) {
-        fDL->drawImageLattice(image, latticePlusBounds, dst, filter, paint, palette);
+    if (SkLatticeIter::Valid(payload.image->width(), payload.image->height(), latticePlusBounds)) {
+        fDL->drawImageLattice(std::move(payload), latticePlusBounds, dst, filter, paint);
     } else {
         SkSamplingOptions sampling(filter, SkMipmapMode::kNone);
-        fDL->drawImageRect(image, nullptr, dst, sampling, paint, kFast_SrcRectConstraint, palette);
+        fDL->drawImageRect(std::move(payload), nullptr, dst, sampling, paint,
+                           kFast_SrcRectConstraint);
     }
 }
 
 void RecordingCanvas::onDrawImage2(const SkImage* img, SkScalar x, SkScalar y,
                                    const SkSamplingOptions& sampling, const SkPaint* paint) {
-    fDL->drawImage(sk_ref_sp(img), x, y, sampling, paint, BitmapPalette::Unknown);
+    fDL->drawImage(DrawImagePayload(img), x, y, sampling, paint);
 }
 
 void RecordingCanvas::onDrawImageRect2(const SkImage* img, const SkRect& src, const SkRect& dst,
                                        const SkSamplingOptions& sampling, const SkPaint* paint,
                                        SrcRectConstraint constraint) {
-    fDL->drawImageRect(sk_ref_sp(img), &src, dst, sampling, paint, constraint,
-                       BitmapPalette::Unknown);
+    fDL->drawImageRect(DrawImagePayload(img), &src, dst, sampling, paint, constraint);
 }
 
 void RecordingCanvas::onDrawImageLattice2(const SkImage* img, const SkCanvas::Lattice& lattice,
                                           const SkRect& dst, SkFilterMode filter,
                                           const SkPaint* paint) {
-    fDL->drawImageLattice(sk_ref_sp(img), lattice, dst, filter, paint, BitmapPalette::Unknown);
+    fDL->drawImageLattice(DrawImagePayload(img), lattice, dst, filter, paint);
 }
 
 void RecordingCanvas::onDrawPatch(const SkPoint cubics[12], const SkColor colors[4],
diff --git a/libs/hwui/RecordingCanvas.h b/libs/hwui/RecordingCanvas.h
index b7d4dc9..8409e13 100644
--- a/libs/hwui/RecordingCanvas.h
+++ b/libs/hwui/RecordingCanvas.h
@@ -16,6 +16,14 @@
 
 #pragma once
 
+#include <SkCanvas.h>
+#include <SkCanvasVirtualEnforcer.h>
+#include <SkDrawable.h>
+#include <SkGainmapInfo.h>
+#include <SkNoDrawCanvas.h>
+#include <SkPaint.h>
+#include <SkPath.h>
+#include <SkRect.h>
 #include <SkRuntimeEffect.h>
 #include <log/log.h>
 
@@ -23,13 +31,7 @@
 #include <vector>
 
 #include "CanvasTransform.h"
-#include "SkCanvas.h"
-#include "SkCanvasVirtualEnforcer.h"
-#include "SkDrawable.h"
-#include "SkNoDrawCanvas.h"
-#include "SkPaint.h"
-#include "SkPath.h"
-#include "SkRect.h"
+#include "Gainmap.h"
 #include "hwui/Bitmap.h"
 #include "pipeline/skia/AnimatedDrawables.h"
 #include "utils/AutoMalloc.h"
@@ -64,6 +66,32 @@
 
 static_assert(sizeof(DisplayListOp) == 4);
 
+// Image, palette and optional gainmap image/info bundled together for one image draw call.
+struct DrawImagePayload {
+    explicit DrawImagePayload(Bitmap& bitmap)
+            : image(bitmap.makeImage()), palette(bitmap.palette()) {
+        if (!bitmap.hasGainmap()) return;
+        auto source = bitmap.gainmap();
+        gainmapInfo = source->info;
+        gainmapImage = source->bitmap->makeImage();
+    }
+
+    // Wraps a bare SkImage via sk_ref_sp; the palette is left Unknown and no gainmap is set.
+    explicit DrawImagePayload(const SkImage* image) : image(sk_ref_sp(image)) {}
+
+    ~DrawImagePayload() = default;
+    DrawImagePayload(const DrawImagePayload&) = default;
+    DrawImagePayload(DrawImagePayload&&) = default;
+    DrawImagePayload& operator=(const DrawImagePayload&) = default;
+    DrawImagePayload& operator=(DrawImagePayload&&) = default;
+
+    sk_sp<SkImage> image;
+    BitmapPalette palette = BitmapPalette::Unknown;
+
+    sk_sp<SkImage> gainmapImage;  // only assigned when the source bitmap has a gainmap
+    SkGainmapInfo gainmapInfo;
+};
+
 class RecordingCanvas;
 
 class DisplayListData final {
@@ -122,13 +150,12 @@
 
     void drawTextBlob(const SkTextBlob*, SkScalar, SkScalar, const SkPaint&);
 
-    void drawImage(sk_sp<const SkImage>, SkScalar, SkScalar, const SkSamplingOptions&,
-                   const SkPaint*, BitmapPalette palette);
-    void drawImageNine(sk_sp<const SkImage>, const SkIRect&, const SkRect&, const SkPaint*);
-    void drawImageRect(sk_sp<const SkImage>, const SkRect*, const SkRect&, const SkSamplingOptions&,
-                       const SkPaint*, SkCanvas::SrcRectConstraint, BitmapPalette palette);
-    void drawImageLattice(sk_sp<const SkImage>, const SkCanvas::Lattice&, const SkRect&,
-                          SkFilterMode, const SkPaint*, BitmapPalette);
+    void drawImage(DrawImagePayload&&, SkScalar, SkScalar, const SkSamplingOptions&,
+                   const SkPaint*);
+    void drawImageRect(DrawImagePayload&&, const SkRect*, const SkRect&, const SkSamplingOptions&,
+                       const SkPaint*, SkCanvas::SrcRectConstraint);
+    void drawImageLattice(DrawImagePayload&&, const SkCanvas::Lattice&, const SkRect&, SkFilterMode,
+                          const SkPaint*);
 
     void drawPatch(const SkPoint[12], const SkColor[4], const SkPoint[4], SkBlendMode,
                    const SkPaint&);
@@ -195,14 +222,14 @@
 
     void onDrawTextBlob(const SkTextBlob*, SkScalar, SkScalar, const SkPaint&) override;
 
-    void drawImage(const sk_sp<SkImage>&, SkScalar left, SkScalar top, const SkSamplingOptions&,
-                   const SkPaint* paint, BitmapPalette pallete);
     void drawRippleDrawable(const skiapipeline::RippleDrawableParams& params);
 
-    void drawImageRect(const sk_sp<SkImage>& image, const SkRect& src, const SkRect& dst,
-                       const SkSamplingOptions&, const SkPaint*, SrcRectConstraint, BitmapPalette);
-    void drawImageLattice(const sk_sp<SkImage>& image, const Lattice& lattice, const SkRect& dst,
-                          SkFilterMode, const SkPaint* paint, BitmapPalette palette);
+    void drawImage(DrawImagePayload&&, SkScalar, SkScalar, const SkSamplingOptions&,
+                   const SkPaint*);
+    void drawImageRect(DrawImagePayload&&, const SkRect&, const SkRect&, const SkSamplingOptions&,
+                       const SkPaint*, SrcRectConstraint);
+    void drawImageLattice(DrawImagePayload&&, const Lattice& lattice, const SkRect&, SkFilterMode,
+                          const SkPaint*);
 
     void onDrawImage2(const SkImage*, SkScalar, SkScalar, const SkSamplingOptions&,
                       const SkPaint*) override;
diff --git a/libs/hwui/effects/GainmapRenderer.cpp b/libs/hwui/effects/GainmapRenderer.cpp
new file mode 100644
index 0000000..a544ae8
--- /dev/null
+++ b/libs/hwui/effects/GainmapRenderer.cpp
@@ -0,0 +1,64 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GainmapRenderer.h"
+
+#include <SkGainmapShader.h>
+
+#include "Gainmap.h"
+#include "Rect.h"
+#include "utils/Trace.h"
+
+#ifdef __ANDROID__
+#include "renderthread/CanvasContext.h"
+#endif
+
+namespace android::uirenderer {
+
+using namespace renderthread;
+
+// Draws |image| from |src| into |dst|. On Android, when the active CanvasContext reports a
+// target SDR/HDR ratio above 1 and a gainmap is given, draws through SkGainmapShader instead.
+void DrawGainmapBitmap(SkCanvas* c, const sk_sp<const SkImage>& image, const SkRect& src,
+                       const SkRect& dst, const SkSamplingOptions& sampling, const SkPaint* paint,
+                       SkCanvas::SrcRectConstraint constraint,
+                       const sk_sp<const SkImage>& gainmapImage, const SkGainmapInfo& gainmapInfo) {
+    ATRACE_CALL();
+#ifdef __ANDROID__
+    CanvasContext* context = CanvasContext::getActiveContext();
+    const float targetSdrHdrRatio = context ? context->targetSdrHdrRatio() : 1.f;
+    if (targetSdrHdrRatio > 1.f && gainmapImage) {
+        // |paint| is a pointer the fallback forwards untouched, so tolerate null here as well.
+        SkPaint gainmapPaint = paint ? *paint : SkPaint();
+        // Scale |src| by the gainmap/base size ratio. TODO: Tweak rounding?
+        const float sX = gainmapImage->width() / static_cast<float>(image->width());
+        const float sY = gainmapImage->height() / static_cast<float>(image->height());
+        const SkRect gainmapSrc = SkRect::MakeLTRB(src.fLeft * sX, src.fTop * sY,
+                                                   src.fRight * sX, src.fBottom * sY);
+        // TODO: Temporary workaround for SkGainmapShader::Make not having a const variant
+        sk_sp<SkImage> mutImage = sk_ref_sp(const_cast<SkImage*>(image.get()));
+        sk_sp<SkImage> mutGainmap = sk_ref_sp(const_cast<SkImage*>(gainmapImage.get()));
+        gainmapPaint.setShader(SkGainmapShader::Make(
+                mutImage, src, sampling, mutGainmap, gainmapSrc, sampling, gainmapInfo, dst,
+                targetSdrHdrRatio, c->imageInfo().refColorSpace()));
+        // Note: |constraint| is not used on this path; the shader is drawn as a rect.
+        c->drawRect(dst, gainmapPaint);
+    } else
+#endif
+        c->drawImageRect(image.get(), src, dst, sampling, paint, constraint);
+}
+
+}  // namespace android::uirenderer
\ No newline at end of file
diff --git a/libs/hwui/effects/GainmapRenderer.h b/libs/hwui/effects/GainmapRenderer.h
new file mode 100644
index 0000000..7c56d94
--- /dev/null
+++ b/libs/hwui/effects/GainmapRenderer.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <SkCanvas.h>
+#include <SkGainmapInfo.h>
+#include <SkImage.h>
+#include <SkPaint.h>
+
+#include "hwui/Bitmap.h"
+
+namespace android::uirenderer {
+
+void DrawGainmapBitmap(SkCanvas* c, const sk_sp<const SkImage>& image, const SkRect& src,
+                       const SkRect& dst, const SkSamplingOptions& sampling, const SkPaint* paint,
+                       SkCanvas::SrcRectConstraint constraint,
+                       const sk_sp<const SkImage>& gainmapImage, const SkGainmapInfo& gainmapInfo);
+
+}  // namespace android::uirenderer
diff --git a/libs/hwui/pipeline/skia/SkiaRecordingCanvas.cpp b/libs/hwui/pipeline/skia/SkiaRecordingCanvas.cpp
index c9d79ab..e1c8877 100644
--- a/libs/hwui/pipeline/skia/SkiaRecordingCanvas.cpp
+++ b/libs/hwui/pipeline/skia/SkiaRecordingCanvas.cpp
@@ -208,40 +208,52 @@
     }
 }
 
+void SkiaRecordingCanvas::handleMutableImages(Bitmap& bitmap, DrawImagePayload& payload) {
+    // If the image is unique(), then mRecorder's draw call failed for some reason and nothing
+    // else references it, so storing its raw pointer in mMutableImages would not be safe.
+    SkImage* const baseImage = payload.image.get();
+    if (!bitmap.isImmutable() && baseImage && !baseImage->unique()) {
+        mDisplayList->mMutableImages.push_back(baseImage);
+    }
+
+    if (!bitmap.hasGainmap()) {
+        return;
+    }
+    // Not every DrawImagePayload receiver stores the gainmap (DrawImageLattice does not), so
+    // the same unique() test decides whether the gainmap image was actually recorded.
+    SkImage* const gainmapImage = payload.gainmapImage.get();
+    if (!bitmap.gainmap()->bitmap->isImmutable() && gainmapImage && !gainmapImage->unique()) {
+        mDisplayList->mMutableImages.push_back(gainmapImage);
+    }
+}
+
 void SkiaRecordingCanvas::drawBitmap(Bitmap& bitmap, float left, float top, const Paint* paint) {
-    sk_sp<SkImage> image = bitmap.makeImage();
+    auto payload = DrawImagePayload(bitmap);
 
     applyLooper(
             paint,
             [&](const Paint& p) {
-                mRecorder.drawImage(image, left, top, p.sampling(), &p, bitmap.palette());
+                mRecorder.drawImage(DrawImagePayload(payload), left, top, p.sampling(), &p);
             },
             FilterForImage);
 
-    // if image->unique() is true, then mRecorder.drawImage failed for some reason. It also means
-    // it is not safe to store a raw SkImage pointer, because the image object will be destroyed
-    // when this function ends.
-    if (!bitmap.isImmutable() && image.get() && !image->unique()) {
-        mDisplayList->mMutableImages.push_back(image.get());
-    }
+    handleMutableImages(bitmap, payload);
 }
 
 void SkiaRecordingCanvas::drawBitmap(Bitmap& bitmap, const SkMatrix& matrix, const Paint* paint) {
     SkAutoCanvasRestore acr(&mRecorder, true);
     concat(matrix);
 
-    sk_sp<SkImage> image = bitmap.makeImage();
+    auto payload = DrawImagePayload(bitmap);
 
     applyLooper(
             paint,
             [&](const Paint& p) {
-                mRecorder.drawImage(image, 0, 0, p.sampling(), &p, bitmap.palette());
+                mRecorder.drawImage(DrawImagePayload(payload), 0, 0, p.sampling(), &p);
             },
             FilterForImage);
 
-    if (!bitmap.isImmutable() && image.get() && !image->unique()) {
-        mDisplayList->mMutableImages.push_back(image.get());
-    }
+    handleMutableImages(bitmap, payload);
 }
 
 void SkiaRecordingCanvas::drawBitmap(Bitmap& bitmap, float srcLeft, float srcTop, float srcRight,
@@ -250,20 +262,17 @@
     SkRect srcRect = SkRect::MakeLTRB(srcLeft, srcTop, srcRight, srcBottom);
     SkRect dstRect = SkRect::MakeLTRB(dstLeft, dstTop, dstRight, dstBottom);
 
-    sk_sp<SkImage> image = bitmap.makeImage();
+    auto payload = DrawImagePayload(bitmap);
 
     applyLooper(
             paint,
             [&](const Paint& p) {
-                mRecorder.drawImageRect(image, srcRect, dstRect, p.sampling(), &p,
-                                        SkCanvas::kFast_SrcRectConstraint, bitmap.palette());
+                mRecorder.drawImageRect(DrawImagePayload(payload), srcRect, dstRect, p.sampling(),
+                                        &p, SkCanvas::kFast_SrcRectConstraint);
             },
             FilterForImage);
 
-    if (!bitmap.isImmutable() && image.get() && !image->unique() && !srcRect.isEmpty() &&
-        !dstRect.isEmpty()) {
-        mDisplayList->mMutableImages.push_back(image.get());
-    }
+    handleMutableImages(bitmap, payload);
 }
 
 void SkiaRecordingCanvas::drawNinePatch(Bitmap& bitmap, const Res_png_9patch& chunk, float dstLeft,
@@ -291,7 +300,7 @@
 
     lattice.fBounds = nullptr;
     SkRect dst = SkRect::MakeLTRB(dstLeft, dstTop, dstRight, dstBottom);
-    sk_sp<SkImage> image = bitmap.makeImage();
+    auto payload = DrawImagePayload(bitmap);
 
     // HWUI always draws 9-patches with linear filtering, regardless of the Paint.
     const SkFilterMode filter = SkFilterMode::kLinear;
@@ -299,13 +308,11 @@
     applyLooper(
             paint,
             [&](const SkPaint& p) {
-                mRecorder.drawImageLattice(image, lattice, dst, filter, &p, bitmap.palette());
+                mRecorder.drawImageLattice(DrawImagePayload(payload), lattice, dst, filter, &p);
             },
             FilterForImage);
 
-    if (!bitmap.isImmutable() && image.get() && !image->unique() && !dst.isEmpty()) {
-        mDisplayList->mMutableImages.push_back(image.get());
-    }
+    handleMutableImages(bitmap, payload);
 }
 
 double SkiaRecordingCanvas::drawAnimatedImage(AnimatedImageDrawable* animatedImage) {
diff --git a/libs/hwui/pipeline/skia/SkiaRecordingCanvas.h b/libs/hwui/pipeline/skia/SkiaRecordingCanvas.h
index 7844e2c..3fd8fa3 100644
--- a/libs/hwui/pipeline/skia/SkiaRecordingCanvas.h
+++ b/libs/hwui/pipeline/skia/SkiaRecordingCanvas.h
@@ -102,6 +102,8 @@
      */
     void initDisplayList(uirenderer::RenderNode* renderNode, int width, int height);
 
+    void handleMutableImages(Bitmap& bitmap, DrawImagePayload& payload);
+
     using INHERITED = SkiaCanvas;
 };