Refactor SkImageFilter usage to cache results.

If an SkImageFilter is used, create an image
snapshot with the filter applied and cache it,
so the filter is not re-computed on each draw
invocation.

Bug: 188450217
Test: Re-ran CTS tests

Change-Id: Ib790669e14ada9d4ebbfac958d699e2b5242f2d7
diff --git a/libs/hwui/RenderNode.cpp b/libs/hwui/RenderNode.cpp
index 9a9e6d4..332f7e6 100644
--- a/libs/hwui/RenderNode.cpp
+++ b/libs/hwui/RenderNode.cpp
@@ -306,11 +306,17 @@
         info.damageAccumulator->popTransform();
         syncProperties();
 
-        const StretchEffect& stagingStretch =
-            mProperties.layerProperties().getStretchEffect();
+        auto& layerProperties = mProperties.layerProperties();
+        const StretchEffect& stagingStretch = layerProperties.getStretchEffect();
         if (stagingStretch.isEmpty()) {
             mStretchMask.clear();
         }
+
+        if (layerProperties.getImageFilter() == nullptr) {
+            mSnapshotResult.snapshot = nullptr;
+            mTargetImageFilter = nullptr;
+        }
+
         // We could try to be clever and only re-damage if the matrix changed.
         // However, we don't need to worry about that. The cost of over-damaging
         // here is only going to be a single additional map rect of this node
@@ -321,6 +327,58 @@
     }
 }
 
+/**
+ * Returns a snapshot of this node's layer with imageFilter applied, if any.
+ *
+ * Applying an SkImageFilter is expensive, so the filtered result is cached in
+ * mSnapshotResult and reused while the filter, the clip bounds and the layer
+ * contents (tracked via the surface's generation ID) are all unchanged.
+ * Returns std::nullopt if this node has no layer surface.
+ */
+std::optional<RenderNode::SnapshotResult> RenderNode::updateSnapshotIfRequired(
+    GrRecordingContext* context,
+    const SkImageFilter* imageFilter,
+    const SkIRect& clipBounds
+) {
+    auto* layerSurface = getLayerSurface();
+    if (layerSurface == nullptr) {
+        return std::nullopt;
+    }
+
+    sk_sp<SkImage> snapshot = layerSurface->makeImageSnapshot();
+    const auto subset = SkIRect::MakeWH(properties().getWidth(),
+                                        properties().getHeight());
+    // Changes whenever the surface's contents change; a filtered snapshot built
+    // from an older generation is stale and must be regenerated.
+    const uint32_t layerSurfaceGenerationId = layerSurface->generationID();
+    // If we don't have an ImageFilter just return the snapshot
+    if (imageFilter == nullptr) {
+        mSnapshotResult.snapshot = snapshot;
+        mSnapshotResult.outSubset = subset;
+        mSnapshotResult.outOffset = SkIPoint::Make(0, 0);
+        mImageFilterClipBounds = clipBounds;
+        mTargetImageFilter = nullptr;
+        mTargetImageFilterLayerSurfaceGenerationId = 0;
+    } else if (mSnapshotResult.snapshot == nullptr ||
+        imageFilter != mTargetImageFilter.get() ||
+        mImageFilterClipBounds != clipBounds ||
+        mTargetImageFilterLayerSurfaceGenerationId != layerSurfaceGenerationId) {
+        // Otherwise create a new snapshot with the given filter and snapshot
+        mSnapshotResult.snapshot =
+                snapshot->makeWithFilter(context,
+                                         imageFilter,
+                                         subset,
+                                         clipBounds,
+                                         &mSnapshotResult.outSubset,
+                                         &mSnapshotResult.outOffset);
+        mTargetImageFilter = sk_ref_sp(imageFilter);
+        mImageFilterClipBounds = clipBounds;
+        mTargetImageFilterLayerSurfaceGenerationId = layerSurfaceGenerationId;
+    }
+
+    return mSnapshotResult;
+}
+
 void RenderNode::syncDisplayList(TreeObserver& observer, TreeInfo* info) {
     // Make sure we inc first so that we don't fluctuate between 0 and 1,
     // which would thrash the layer cache
@@ -411,6 +469,9 @@
     if (hasLayer()) {
         this->setLayerSurface(nullptr);
     }
+    mSnapshotResult.snapshot = nullptr;
+    mTargetImageFilter = nullptr;
+    mTargetImageFilterLayerSurfaceGenerationId = 0;
     if (mDisplayList) {
         mDisplayList.updateChildren([](RenderNode* child) { child->destroyLayers(); });
     }
diff --git a/libs/hwui/RenderNode.h b/libs/hwui/RenderNode.h
index 6a0b1aa..8595b6e 100644
--- a/libs/hwui/RenderNode.h
+++ b/libs/hwui/RenderNode.h
@@ -345,6 +345,27 @@
         return mSkiaLayer.get() ? mSkiaLayer->layerSurface.get() : nullptr;
     }
 
+    /**
+     * Layer snapshot (possibly with an SkImageFilter applied) together with
+     * the subset and offset reported by SkImage::makeWithFilter, which are
+     * needed to draw it. Without a filter these are the node's bounds and
+     * a zero offset.
+     */
+    struct SnapshotResult {
+        sk_sp<SkImage> snapshot;
+        SkIRect outSubset{};
+        SkIPoint outOffset{};
+    };
+
+    /**
+     * Returns the layer snapshot with imageFilter applied (if non-null),
+     * reusing the cached result while the filter, clip bounds and layer
+     * contents are unchanged. Returns std::nullopt if there is no layer.
+     */
+    std::optional<SnapshotResult> updateSnapshotIfRequired(GrRecordingContext* context,
+                                            const SkImageFilter* imageFilter,
+                                            const SkIRect& clipBounds);
+
     skiapipeline::SkiaLayer* getSkiaLayer() const { return mSkiaLayer.get(); }
 
     /**
@@ -375,6 +396,29 @@
      */
     std::unique_ptr<skiapipeline::SkiaLayer> mSkiaLayer;
 
+    /**
+     * SkImageFilter used to create the mSnapshotResult
+     */
+    sk_sp<SkImageFilter> mTargetImageFilter;
+
+    /**
+     * Clip bounds used to create the mSnapshotResult
+     */
+    SkIRect mImageFilterClipBounds{};
+
+    /**
+     * Generation ID of the layer surface the filtered mSnapshotResult was
+     * created from. When the layer is redrawn the surface's generation ID
+     * changes, which invalidates the cached filtered snapshot.
+     */
+    uint32_t mTargetImageFilterLayerSurfaceGenerationId = 0;
+
+    /**
+     * Result of the most recent snapshot with additional metadata used to
+     * determine how to draw the contents
+     */
+    SnapshotResult mSnapshotResult;
+
     struct ClippedOutlineCache {
         // keys
         uint32_t outlineID = 0;
diff --git a/libs/hwui/effects/StretchEffect.cpp b/libs/hwui/effects/StretchEffect.cpp
index 807fb75..43f805d 100644
--- a/libs/hwui/effects/StretchEffect.cpp
+++ b/libs/hwui/effects/StretchEffect.cpp
@@ -188,7 +188,8 @@
 static const float INTERPOLATION_STRENGTH_VALUE = 0.7f;
 
 sk_sp<SkShader> StretchEffect::getShader(float width, float height,
-                                         const sk_sp<SkImage>& snapshotImage) const {
+                                         const sk_sp<SkImage>& snapshotImage,
+                                         const SkMatrix* matrix) const {
     if (isEmpty()) {
         return nullptr;
     }
@@ -206,8 +207,9 @@
         mBuilder = std::make_unique<SkRuntimeShaderBuilder>(getStretchEffect());
     }
 
-    mBuilder->child("uContentTexture") = snapshotImage->makeShader(
-            SkTileMode::kClamp, SkTileMode::kClamp, SkSamplingOptions(SkFilterMode::kLinear));
+    mBuilder->child("uContentTexture") =
+            snapshotImage->makeShader(SkTileMode::kClamp, SkTileMode::kClamp,
+                                      SkSamplingOptions(SkFilterMode::kLinear), matrix);
     mBuilder->uniform("uInterpolationStrength").set(&INTERPOLATION_STRENGTH_VALUE, 1);
     mBuilder->uniform("uStretchAffectedDistX").set(&width, 1);
     mBuilder->uniform("uStretchAffectedDistY").set(&height, 1);
diff --git a/libs/hwui/effects/StretchEffect.h b/libs/hwui/effects/StretchEffect.h
index 64fb2bf..25777c2 100644
--- a/libs/hwui/effects/StretchEffect.h
+++ b/libs/hwui/effects/StretchEffect.h
@@ -93,8 +93,8 @@
      */
     float computeStretchedPositionY(float normalizedY) const;
 
-    sk_sp<SkShader> getShader(float width, float height,
-                              const sk_sp<SkImage>& snapshotImage) const;
+    sk_sp<SkShader> getShader(float width, float height, const sk_sp<SkImage>& snapshotImage,
+                              const SkMatrix* matrix) const;
 
     float maxStretchAmountX = 0;
     float maxStretchAmountY = 0;
diff --git a/libs/hwui/pipeline/skia/RenderNodeDrawable.cpp b/libs/hwui/pipeline/skia/RenderNodeDrawable.cpp
index d7546d8..7556af9 100644
--- a/libs/hwui/pipeline/skia/RenderNodeDrawable.cpp
+++ b/libs/hwui/pipeline/skia/RenderNodeDrawable.cpp
@@ -171,17 +171,14 @@
     displayList->mProjectedOutline = nullptr;
 }
 
-static bool layerNeedsPaint(const sk_sp<SkImage>& snapshotImage, const LayerProperties& properties,
-                            float alphaMultiplier, SkPaint* paint) {
+static bool layerNeedsPaint(const LayerProperties& properties, float alphaMultiplier,
+                            SkPaint* paint) {
     if (alphaMultiplier < 1.0f || properties.alpha() < 255 ||
         properties.xferMode() != SkBlendMode::kSrcOver || properties.getColorFilter() != nullptr ||
-        properties.getImageFilter() != nullptr || properties.getStretchEffect().requiresLayer()) {
+        properties.getStretchEffect().requiresLayer()) {
         paint->setAlpha(properties.alpha() * alphaMultiplier);
         paint->setBlendMode(properties.xferMode());
         paint->setColorFilter(sk_ref_sp(properties.getColorFilter()));
-
-        sk_sp<SkImageFilter> imageFilter = sk_ref_sp(properties.getImageFilter());
-        paint->setImageFilter(std::move(imageFilter));
         return true;
     }
     return false;
@@ -223,6 +220,9 @@
     // TODO should we let the bound of the drawable do this for us?
     const SkRect bounds = SkRect::MakeWH(properties.getWidth(), properties.getHeight());
     bool quickRejected = properties.getClipToBounds() && canvas->quickReject(bounds);
+    const SkRect clipBounds = canvas->getLocalClipBounds();
+    SkIRect srcBounds = SkIRect::MakeWH(properties.getWidth(), properties.getHeight());
+    SkIPoint offset = SkIPoint::Make(0, 0);
     if (!quickRejected) {
         SkiaDisplayList* displayList = renderNode->getDisplayList().asSkiaDl();
         const LayerProperties& layerProperties = properties.layerProperties();
@@ -230,8 +230,19 @@
         if (renderNode->getLayerSurface() && mComposeLayer) {
             SkASSERT(properties.effectiveLayerType() == LayerType::RenderLayer);
             SkPaint paint;
-            sk_sp<SkImage> snapshotImage = renderNode->getLayerSurface()->makeImageSnapshot();
-            layerNeedsPaint(snapshotImage, layerProperties, alphaMultiplier, &paint);
+            layerNeedsPaint(layerProperties, alphaMultiplier, &paint);
+            const auto snapshotResult = renderNode->updateSnapshotIfRequired(
+                canvas->recordingContext(),
+                layerProperties.getImageFilter(),
+                clipBounds.roundOut()
+            );
+            sk_sp<SkImage> snapshotImage = snapshotResult->snapshot;
+            srcBounds = snapshotResult->outSubset;
+            offset = snapshotResult->outOffset;
+            const auto dstBounds = SkIRect::MakeXYWH(offset.x(),
+                                                     offset.y(),
+                                                     srcBounds.width(),
+                                                     srcBounds.height());
             SkSamplingOptions sampling(SkFilterMode::kLinear);
 
             // surfaces for layers are created on LAYER_SIZE boundaries (which are >= layer size) so
@@ -257,7 +268,8 @@
                     TransformCanvas transformCanvas(canvas, SkBlendMode::kClear);
                     displayList->draw(&transformCanvas);
                 }
-                canvas->drawImageRect(snapshotImage, bounds, bounds, sampling, &paint,
+                canvas->drawImageRect(snapshotImage, SkRect::Make(srcBounds),
+                                      SkRect::Make(dstBounds), sampling, &paint,
                                       SkCanvas::kStrict_SrcRectConstraint);
             } else {
                 // If we do have stretch effects and have hole punches,
@@ -265,6 +277,16 @@
                 // get the corresponding hole punches.
                 // Then apply the stretch to the mask and draw the mask to
                 // the destination
+                // Also if the stretchy container has an ImageFilter applied
+                // to it (i.e. blur) we need to take into account the offset
+                // that will be generated with this result. Ex blurs will "grow"
+                // the source image by the blur radius so we need to translate
+                // the shader by the same amount to render in the same location
+                SkMatrix matrix;
+                matrix.setTranslate(
+                    offset.x() - srcBounds.left(),
+                    offset.y() - srcBounds.top()
+                );
                 if (renderNode->hasHolePunches()) {
                     GrRecordingContext* context = canvas->recordingContext();
                     StretchMask& stretchMask = renderNode->getStretchMask();
@@ -275,11 +297,14 @@
                                      canvas);
                 }
 
-                sk_sp<SkShader> stretchShader = stretch.getShader(bounds.width(),
-                                                                  bounds.height(),
-                                                                  snapshotImage);
-                paint.setShader(stretchShader);
-                canvas->drawRect(bounds, paint);
+                // SkImage::makeWithFilter() can return nullptr, and getShader()
+                // dereferences the image, so only draw when a snapshot exists.
+                if (snapshotImage != nullptr) {
+                    sk_sp<SkShader> stretchShader = stretch.getShader(
+                            bounds.width(), bounds.height(), snapshotImage, &matrix);
+                    paint.setShader(stretchShader);
+                    canvas->drawRect(SkRect::Make(dstBounds), paint);
+                }
             }
 
             if (!renderNode->getSkiaLayer()->hasRenderedSinceRepaint) {
diff --git a/libs/hwui/pipeline/skia/StretchMask.cpp b/libs/hwui/pipeline/skia/StretchMask.cpp
index 1c58c6a..2dbeb3a 100644
--- a/libs/hwui/pipeline/skia/StretchMask.cpp
+++ b/libs/hwui/pipeline/skia/StretchMask.cpp
@@ -59,8 +59,7 @@
     }
 
     sk_sp<SkImage> maskImage = mMaskSurface->makeImageSnapshot();
-    sk_sp<SkShader> maskStretchShader = stretch.getShader(
-        width, height, maskImage);
+    sk_sp<SkShader> maskStretchShader = stretch.getShader(width, height, maskImage, nullptr);
 
     SkPaint maskPaint;
     maskPaint.setShader(maskStretchShader);