Update ColorFilter API to be backed by mutable native objects

Move the native ColorFilter implementation off of Skia's
SkColorFilter and back it instead with a mutable intermediate
object that can be inspected and that lazily creates the
SkColorFilter when one is needed.

Bug: 264559422
Test: re-ran CtsUiRenderingTestCases
Change-Id: I9ec056084f00e72632c86bdf88376b1307e8ef74
diff --git a/libs/hwui/ColorFilter.h b/libs/hwui/ColorFilter.h
new file mode 100644
index 0000000..1a5b938
--- /dev/null
+++ b/libs/hwui/ColorFilter.h
@@ -0,0 +1,117 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef COLORFILTER_H_
+#define COLORFILTER_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "GraphicsJNI.h"
+#include "SkColorFilter.h"
+#include "SkiaWrapper.h"
+
+namespace android {
+namespace uirenderer {
+
+/**
+ * Native, mutable description of a color filter, handed to Java as an opaque jlong handle.
+ * The SkColorFilter it describes is built lazily by getInstance(); subclasses that can be
+ * mutated call discardInstance() so the next getInstance() reflects their new state.
+ */
+class ColorFilter : public SkiaWrapper<SkColorFilter> {
+public:
+    /** Converts a handle held by Java back into the native object. */
+    static ColorFilter* fromJava(jlong handle) { return reinterpret_cast<ColorFilter*>(handle); }
+
+protected:
+    ColorFilter() = default;
+};
+
+/** Immutable filter built with SkColorFilters::Blend(color, mode). */
+class BlendModeColorFilter : public ColorFilter {
+public:
+    BlendModeColorFilter(SkColor color, SkBlendMode mode) : mColor(color), mMode(mode) {}
+
+private:
+    sk_sp<SkColorFilter> createInstance() override { return SkColorFilters::Blend(mColor, mMode); }
+
+    const SkColor mColor;
+    const SkBlendMode mMode;
+};
+
+/**
+ * Filter built with SkColorFilters::Lighting(mul, add). Each setter discards the cached
+ * SkColorFilter so the next getInstance() call picks up the new value.
+ */
+class LightingFilter : public ColorFilter {
+public:
+    LightingFilter(SkColor mul, SkColor add) : mMul(mul), mAdd(add) {}
+
+    void setMul(SkColor mul) {
+        mMul = mul;
+        discardInstance();
+    }
+
+    void setAdd(SkColor add) {
+        mAdd = add;
+        discardInstance();
+    }
+
+private:
+    sk_sp<SkColorFilter> createInstance() override { return SkColorFilters::Lighting(mMul, mAdd); }
+
+    SkColor mMul;
+    SkColor mAdd;
+};
+
+/**
+ * Filter built with SkColorFilters::Matrix from a 4x5 row-major matrix of kMatrixSize floats.
+ * setMatrix() discards the cached SkColorFilter so the next getInstance() call uses the
+ * new matrix.
+ */
+class ColorMatrixColorFilter : public ColorFilter {
+public:
+    static constexpr size_t kMatrixSize = 20;
+
+    explicit ColorMatrixColorFilter(std::vector<float>&& matrix) : mMatrix(std::move(matrix)) {}
+
+    void setMatrix(std::vector<float>&& matrix) {
+        mMatrix = std::move(matrix);
+        discardInstance();
+    }
+
+private:
+    sk_sp<SkColorFilter> createInstance() override {
+        // SkColorFilters::Matrix reads kMatrixSize floats unconditionally; a shorter vector
+        // would be read out of bounds, so produce no filter instead.
+        if (mMatrix.size() < kMatrixSize) {
+            return nullptr;
+        }
+        return SkColorFilters::Matrix(mMatrix.data());
+    }
+
+    std::vector<float> mMatrix;
+};
+
+}  // namespace uirenderer
+}  // namespace android
+
+#endif  // COLORFILTER_H_
diff --git a/libs/hwui/SkiaWrapper.h b/libs/hwui/SkiaWrapper.h
new file mode 100644
index 0000000..bd0e35a
--- /dev/null
+++ b/libs/hwui/SkiaWrapper.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SKIA_WRAPPER_H_
+#define SKIA_WRAPPER_H_
+
+#include <SkRefCnt.h>
+#include <utils/RefBase.h>
+
+#include <cstdint>
+
+namespace android::uirenderer {
+
+/**
+ * Base class that lazily creates, caches and hands out a Skia object of type T.
+ *
+ * Subclasses implement createInstance(). The cached object is rebuilt on the next
+ * getInstance() call after discardInstance() has been called, or when
+ * shouldDiscardInstance() returns true. Each rebuild increments the generation id, which
+ * callers can compare to notice that a new instance was created.
+ *
+ * Nothing here is synchronized; callers sharing an object across threads must serialize access.
+ */
+template <typename T>
+class SkiaWrapper : public VirtualLightRefBase {
+public:
+    /**
+     * Returns the cached instance, (re)creating it first when there is none or when
+     * shouldDiscardInstance() asks for a rebuild. Returns nullptr if createInstance() does.
+     */
+    sk_sp<T> getInstance() {
+        if (mInstance != nullptr && shouldDiscardInstance()) {
+            mInstance = nullptr;
+        }
+
+        if (mInstance == nullptr) {
+            mInstance = createInstance();
+            mGenerationId++;
+        }
+        return mInstance;
+    }
+
+    /** Override to request a rebuild on the next getInstance(); the default never does. */
+    virtual bool shouldDiscardInstance() const { return false; }
+
+    /** Releases the cached instance so the next getInstance() call rebuilds it. */
+    void discardInstance() { mInstance = nullptr; }
+
+    /** Number of createInstance() calls made by getInstance(); 0 until the first one. */
+    [[nodiscard]] int32_t getGenerationId() const { return mGenerationId; }
+
+protected:
+    /** Builds a new Skia object from the subclass's current state. */
+    virtual sk_sp<T> createInstance() = 0;
+
+private:
+    sk_sp<T> mInstance = nullptr;
+    int32_t mGenerationId = 0;
+};
+
+}  // namespace android::uirenderer
+
+#endif  // SKIA_WRAPPER_H_
diff --git a/libs/hwui/jni/AnimatedImageDrawable.cpp b/libs/hwui/jni/AnimatedImageDrawable.cpp
index a7f5aa83..90b1da8 100644
--- a/libs/hwui/jni/AnimatedImageDrawable.cpp
+++ b/libs/hwui/jni/AnimatedImageDrawable.cpp
@@ -14,10 +14,6 @@
  * limitations under the License.
  */
 
-#include "GraphicsJNI.h"
-#include "ImageDecoder.h"
-#include "Utils.h"
-
 #include <SkAndroidCodec.h>
 #include <SkAnimatedImage.h>
 #include <SkColorFilter.h>
@@ -27,10 +23,15 @@
 #include <SkRect.h>
 #include <SkRefCnt.h>
 #include <hwui/AnimatedImageDrawable.h>
-#include <hwui/ImageDecoder.h>
 #include <hwui/Canvas.h>
+#include <hwui/ImageDecoder.h>
 #include <utils/Looper.h>
 
+#include "ColorFilter.h"
+#include "GraphicsJNI.h"
+#include "ImageDecoder.h"
+#include "Utils.h"
+
 using namespace android;
 
 static jclass gAnimatedImageDrawableClass;
@@ -145,8 +146,9 @@
 static void AnimatedImageDrawable_nSetColorFilter(JNIEnv* env, jobject /*clazz*/, jlong nativePtr,
                                                   jlong nativeFilter) {
     auto* drawable = reinterpret_cast<AnimatedImageDrawable*>(nativePtr);
-    auto* filter = reinterpret_cast<SkColorFilter*>(nativeFilter);
-    drawable->setStagingColorFilter(sk_ref_sp(filter));
+    auto* filter = uirenderer::ColorFilter::fromJava(nativeFilter);
+    sk_sp<SkColorFilter> skColorFilter = filter ? filter->getInstance() : nullptr;
+    drawable->setStagingColorFilter(std::move(skColorFilter));
 }
 
 static jboolean AnimatedImageDrawable_nIsRunning(JNIEnv* env, jobject /*clazz*/, jlong nativePtr) {
diff --git a/libs/hwui/jni/ColorFilter.cpp b/libs/hwui/jni/ColorFilter.cpp
index 4bd7ef4..0b95148 100644
--- a/libs/hwui/jni/ColorFilter.cpp
+++ b/libs/hwui/jni/ColorFilter.cpp
@@ -15,20 +15,21 @@
 ** limitations under the License.
 */
 
-#include "GraphicsJNI.h"
+#include "ColorFilter.h"
 
+#include "GraphicsJNI.h"
 #include "SkBlendMode.h"
-#include "SkColorFilter.h"
-#include "SkColorMatrixFilter.h"
 
 namespace android {
 
 using namespace uirenderer;
 
-class SkColorFilterGlue {
+class ColorFilterGlue {
 public:
-    static void SafeUnref(SkColorFilter* filter) {
-        SkSafeUnref(filter);
+    // Drops the strong reference that the Create* factories take with incStrong(nullptr).
+    static void SafeUnref(ColorFilter* filter) {
+        if (filter == nullptr) return;
+        filter->decStrong(nullptr);
     }
 
     static jlong GetNativeFinalizer(JNIEnv*, jobject) {
@@ -36,41 +37,75 @@
     }
 
     static jlong CreateBlendModeFilter(JNIEnv* env, jobject, jint srcColor, jint modeHandle) {
-        SkBlendMode mode = static_cast<SkBlendMode>(modeHandle);
-        return reinterpret_cast<jlong>(SkColorFilters::Blend(srcColor, mode).release());
+        BlendModeColorFilter* filter =
+                new BlendModeColorFilter(srcColor, static_cast<SkBlendMode>(modeHandle));
+        filter->incStrong(nullptr);  // reference owned by the returned handle
+        return static_cast<jlong>(reinterpret_cast<uintptr_t>(filter));
     }
 
     static jlong CreateLightingFilter(JNIEnv* env, jobject, jint mul, jint add) {
-        return reinterpret_cast<jlong>(SkColorMatrixFilter::MakeLightingFilter(mul, add).release());
+        LightingFilter* filter = new LightingFilter(mul, add);
+        filter->incStrong(nullptr);  // reference owned by the returned handle
+        return static_cast<jlong>(reinterpret_cast<uintptr_t>(filter));
     }
 
-    static jlong CreateColorMatrixFilter(JNIEnv* env, jobject, jfloatArray jarray) {
-        float matrix[20];
-        env->GetFloatArrayRegion(jarray, 0, 20, matrix);
+    static void SetLightingFilterMul(JNIEnv* env, jobject, jlong lightingFilterPtr, jint mul) {
+        if (lightingFilterPtr == 0) {  // no native filter to update
+            return;
+        }
+        reinterpret_cast<LightingFilter*>(lightingFilterPtr)->setMul(mul);
+    }
+
+    static void SetLightingFilterAdd(JNIEnv* env, jobject, jlong lightingFilterPtr, jint add) {
+        if (lightingFilterPtr == 0) {  // no native filter to update
+            return;
+        }
+        reinterpret_cast<LightingFilter*>(lightingFilterPtr)->setAdd(add);
+    }
+
+    static std::vector<float> getMatrixFromJFloatArray(JNIEnv* env, jfloatArray jarray) {
+        constexpr jsize kMatrixSize = 20;  // 4x5 matrix handed to SkColorFilters::Matrix
+        std::vector<float> matrix(kMatrixSize);
+        env->GetFloatArrayRegion(jarray, 0, kMatrixSize, matrix.data());
         // java biases the translates by 255, so undo that before calling skia
         matrix[ 4] *= (1.0f/255);
         matrix[ 9] *= (1.0f/255);
         matrix[14] *= (1.0f/255);
         matrix[19] *= (1.0f/255);
-        return reinterpret_cast<jlong>(SkColorFilters::Matrix(matrix).release());
+        return matrix;
+    }
+
+    static jlong CreateColorMatrixFilter(JNIEnv* env, jobject, jfloatArray jarray) {
+        ColorMatrixColorFilter* filter =
+                new ColorMatrixColorFilter(getMatrixFromJFloatArray(env, jarray));
+        filter->incStrong(nullptr);  // reference owned by the returned handle
+        return static_cast<jlong>(reinterpret_cast<uintptr_t>(filter));
+    }
+
+    static void SetColorMatrix(JNIEnv* env, jobject, jlong colorMatrixColorFilterPtr,
+                               jfloatArray jarray) {
+        if (colorMatrixColorFilterPtr == 0) return;  // no native filter to update
+        auto* target = reinterpret_cast<ColorMatrixColorFilter*>(colorMatrixColorFilterPtr);
+        std::vector<float> matrix = getMatrixFromJFloatArray(env, jarray);
+        target->setMatrix(std::move(matrix));
     }
 };
 
 static const JNINativeMethod colorfilter_methods[] = {
-    {"nativeGetFinalizer", "()J", (void*) SkColorFilterGlue::GetNativeFinalizer }
-};
+        {"nativeGetFinalizer", "()J", (void*)&ColorFilterGlue::GetNativeFinalizer}};

 static const JNINativeMethod blendmode_methods[] = {
-    { "native_CreateBlendModeFilter", "(II)J", (void*) SkColorFilterGlue::CreateBlendModeFilter },
+        {"native_CreateBlendModeFilter", "(II)J", (void*)&ColorFilterGlue::CreateBlendModeFilter},
 };

 static const JNINativeMethod lighting_methods[] = {
-    { "native_CreateLightingFilter", "(II)J", (void*) SkColorFilterGlue::CreateLightingFilter },
-};
+        {"native_CreateLightingFilter", "(II)J", (void*)&ColorFilterGlue::CreateLightingFilter},
+        {"native_SetLightingFilterMul", "(JI)V", (void*)&ColorFilterGlue::SetLightingFilterMul},
+        {"native_SetLightingFilterAdd", "(JI)V", (void*)&ColorFilterGlue::SetLightingFilterAdd}};

 static const JNINativeMethod colormatrix_methods[] = {
-    { "nativeColorMatrixFilter", "([F)J", (void*) SkColorFilterGlue::CreateColorMatrixFilter },
-};
+        {"nativeColorMatrixFilter", "([F)J", (void*)&ColorFilterGlue::CreateColorMatrixFilter},
+        {"nativeSetColorMatrix", "(J[F)V", (void*)&ColorFilterGlue::SetColorMatrix}};
 
 int register_android_graphics_ColorFilter(JNIEnv* env) {
     android::RegisterMethodsOrDie(env, "android/graphics/ColorFilter", colorfilter_methods,
diff --git a/libs/hwui/jni/Paint.cpp b/libs/hwui/jni/Paint.cpp
index 13357fa..ace896d 100644
--- a/libs/hwui/jni/Paint.cpp
+++ b/libs/hwui/jni/Paint.cpp
@@ -18,13 +18,29 @@
 #undef LOG_TAG
 #define LOG_TAG "Paint"
 
-#include <utils/Log.h>
-
-#include "GraphicsJNI.h"
+#include <hwui/BlurDrawLooper.h>
+#include <hwui/MinikinSkia.h>
+#include <hwui/MinikinUtils.h>
+#include <hwui/Paint.h>
+#include <hwui/Typeface.h>
+#include <minikin/GraphemeBreak.h>
+#include <minikin/LocaleList.h>
+#include <minikin/Measurement.h>
+#include <minikin/MinikinPaint.h>
+#include <nativehelper/ScopedPrimitiveArray.h>
 #include <nativehelper/ScopedStringChars.h>
 #include <nativehelper/ScopedUtfChars.h>
-#include <nativehelper/ScopedPrimitiveArray.h>
+#include <unicode/utf16.h>
+#include <utils/Log.h>
 
+#include <cassert>
+#include <cstring>
+#include <memory>
+#include <vector>
+
+#include "ColorFilter.h"
+#include "GraphicsJNI.h"
+#include "SkBlendMode.h"
 #include "SkColorFilter.h"
 #include "SkColorSpace.h"
 #include "SkFont.h"
@@ -35,26 +51,9 @@
 #include "SkPathEffect.h"
 #include "SkPathUtils.h"
 #include "SkShader.h"
-#include "SkBlendMode.h"
 #include "unicode/uloc.h"
 #include "utils/Blur.h"
 
-#include <hwui/BlurDrawLooper.h>
-#include <hwui/MinikinSkia.h>
-#include <hwui/MinikinUtils.h>
-#include <hwui/Paint.h>
-#include <hwui/Typeface.h>
-#include <minikin/GraphemeBreak.h>
-#include <minikin/LocaleList.h>
-#include <minikin/Measurement.h>
-#include <minikin/MinikinPaint.h>
-#include <unicode/utf16.h>
-
-#include <cassert>
-#include <cstring>
-#include <memory>
-#include <vector>
-
 namespace android {
 
 static void getPosTextPath(const SkFont& font, const uint16_t glyphs[], int count,
@@ -821,9 +820,11 @@
 
     static jlong setColorFilter(CRITICAL_JNI_PARAMS_COMMA jlong objHandle, jlong filterHandle) {
         Paint* obj = reinterpret_cast<Paint *>(objHandle);
-        SkColorFilter* filter  = reinterpret_cast<SkColorFilter *>(filterHandle);
-        obj->setColorFilter(sk_ref_sp(filter));
-        return reinterpret_cast<jlong>(obj->getColorFilter());
+        auto* colorFilter = uirenderer::ColorFilter::fromJava(filterHandle);
+        sk_sp<SkColorFilter> skColorFilter = colorFilter ? colorFilter->getInstance() : nullptr;
+        obj->setColorFilter(std::move(skColorFilter));
+        // Hand back the ColorFilter handle the caller passed in, not the Skia object.
+        return filterHandle;
     }
 
     static void setXfermode(CRITICAL_JNI_PARAMS_COMMA jlong paintHandle, jint xfermodeHandle) {
diff --git a/libs/hwui/jni/RenderEffect.cpp b/libs/hwui/jni/RenderEffect.cpp
index f3db170..dcd3fa4 100644
--- a/libs/hwui/jni/RenderEffect.cpp
+++ b/libs/hwui/jni/RenderEffect.cpp
@@ -14,13 +14,13 @@
  * limitations under the License.
  */
 #include "Bitmap.h"
+#include "ColorFilter.h"
 #include "GraphicsJNI.h"
 #include "SkBlendMode.h"
 #include "SkImageFilter.h"
 #include "SkImageFilters.h"
 #include "graphics_jni_helpers.h"
 #include "utils/Blur.h"
-#include <utils/Log.h>
 
 using namespace android::uirenderer;
 
@@ -76,11 +76,13 @@
     jlong colorFilterHandle,
     jlong inputFilterHandle
 ) {
-    auto* colorFilter = reinterpret_cast<const SkColorFilter*>(colorFilterHandle);
+    auto colorFilter = android::uirenderer::ColorFilter::fromJava(colorFilterHandle);
+    auto skColorFilter =
+            colorFilter != nullptr ? colorFilter->getInstance() : sk_sp<SkColorFilter>();
     auto* inputFilter = reinterpret_cast<const SkImageFilter*>(inputFilterHandle);
-    sk_sp<SkImageFilter> colorFilterImageFilter = SkImageFilters::ColorFilter(
-            sk_ref_sp(colorFilter), sk_ref_sp(inputFilter), nullptr);
-   return reinterpret_cast<jlong>(colorFilterImageFilter.release());
+    sk_sp<SkImageFilter> colorFilterImageFilter =
+            SkImageFilters::ColorFilter(skColorFilter, sk_ref_sp(inputFilter), nullptr);
+    return reinterpret_cast<jlong>(colorFilterImageFilter.release());
 }
 
 static jlong createBlendModeEffect(
diff --git a/libs/hwui/jni/android_graphics_drawable_VectorDrawable.cpp b/libs/hwui/jni/android_graphics_drawable_VectorDrawable.cpp
index 9cffceb..ade48f2 100644
--- a/libs/hwui/jni/android_graphics_drawable_VectorDrawable.cpp
+++ b/libs/hwui/jni/android_graphics_drawable_VectorDrawable.cpp
@@ -14,13 +14,13 @@
  * limitations under the License.
  */
 
-#include "GraphicsJNI.h"
+#include <hwui/Paint.h>
 
+#include "ColorFilter.h"
+#include "GraphicsJNI.h"
 #include "PathParser.h"
 #include "VectorDrawable.h"
 
-#include <hwui/Paint.h>
-
 namespace android {
 using namespace uirenderer;
 using namespace uirenderer::VectorDrawable;
@@ -108,8 +108,9 @@
     Canvas* canvas = reinterpret_cast<Canvas*>(canvasPtr);
     SkRect rect;
     GraphicsJNI::jrect_to_rect(env, jrect, &rect);
-    SkColorFilter* colorFilter = reinterpret_cast<SkColorFilter*>(colorFilterPtr);
-    return tree->draw(canvas, colorFilter, rect, needsMirroring, canReuseCache);
+    auto colorFilter = ColorFilter::fromJava(colorFilterPtr);
+    auto skColorFilter = colorFilter != nullptr ? colorFilter->getInstance() : nullptr;
+    return tree->draw(canvas, skColorFilter.get(), rect, needsMirroring, canReuseCache);
 }
 
 /**