fix incorrect casting for child color filters

Flag: com.android.graphics.hwui.flags.runtime_color_filters_blenders

Test: atest
CtsUiRenderingTestCases:RuntimeShaderTests
CtsUiRenderingTestCases:RuntimeColorFilterTests
CtsUiRenderingTestCases:RuntimeXfermodeTests

Bug: b/358126864 b/379193391

Change-Id: Ic0a92c18075e1fd9f07080c2b9a7a795e541cdee
diff --git a/graphics/java/android/graphics/RuntimeColorFilter.java b/graphics/java/android/graphics/RuntimeColorFilter.java
index d112f71..a64acfe 100644
--- a/graphics/java/android/graphics/RuntimeColorFilter.java
+++ b/graphics/java/android/graphics/RuntimeColorFilter.java
@@ -280,7 +280,8 @@
         if (colorFilter == null) {
             throw new NullPointerException("The colorFilter parameter must not be null");
         }
-        nativeUpdateChild(getNativeInstance(), filterName, colorFilter.getNativeInstance());
+        nativeUpdateInputColorFilter(getNativeInstance(), filterName,
+                colorFilter.getNativeInstance());
     }
 
     /**
@@ -318,5 +319,6 @@
             long colorFilter, String uniformName, int value1, int value2, int value3,
             int value4, int count);
     private static native void nativeUpdateChild(long colorFilter, String childName, long child);
-
+    private static native void nativeUpdateInputColorFilter(long colorFilter, String childName,
+            long inputFilter);
 }
diff --git a/graphics/java/android/graphics/RuntimeShader.java b/graphics/java/android/graphics/RuntimeShader.java
index 6316c1f..3543e99 100644
--- a/graphics/java/android/graphics/RuntimeShader.java
+++ b/graphics/java/android/graphics/RuntimeShader.java
@@ -264,6 +264,9 @@
      * enable better heap tracking & tooling support
      */
     private ArrayMap<String, Shader> mShaderUniforms = new ArrayMap<>();
+    private ArrayMap<String, ColorFilter> mColorFilterUniforms = new ArrayMap<>();
+    private ArrayMap<String, RuntimeXfermode> mXfermodeUniforms = new ArrayMap<>();
+
 
     /**
      * Creates a new RuntimeShader.
@@ -544,8 +547,10 @@
         if (colorFilter == null) {
             throw new NullPointerException("The colorFilter parameter must not be null");
         }
-        nativeUpdateChild(mNativeInstanceRuntimeShaderBuilder, filterName,
+        mColorFilterUniforms.put(filterName, colorFilter);
+        nativeUpdateColorFilter(mNativeInstanceRuntimeShaderBuilder, filterName,
                 colorFilter.getNativeInstance());
+        discardNativeInstance();
     }
 
     /**
@@ -563,8 +568,10 @@
         if (xfermode == null) {
             throw new NullPointerException("The xfermode parameter must not be null");
         }
+        mXfermodeUniforms.put(xfermodeName, xfermode);
         nativeUpdateChild(mNativeInstanceRuntimeShaderBuilder, xfermodeName,
                 xfermode.createNativeInstance());
+        discardNativeInstance();
     }
 
 
@@ -594,6 +601,8 @@
             int value4, int count);
     private static native void nativeUpdateShader(
             long shaderBuilder, String shaderName, long shader);
+    private static native void nativeUpdateColorFilter(
+            long shaderBuilder, String colorFilterName, long colorFilter);
     private static native void nativeUpdateChild(
             long shaderBuilder, String childName, long child);
 }
diff --git a/graphics/java/android/graphics/RuntimeXfermode.java b/graphics/java/android/graphics/RuntimeXfermode.java
index 51d97a4..c8a0b1a 100644
--- a/graphics/java/android/graphics/RuntimeXfermode.java
+++ b/graphics/java/android/graphics/RuntimeXfermode.java
@@ -285,7 +285,8 @@
         if (colorFilter == null) {
             throw new NullPointerException("The colorFilter parameter must not be null");
         }
-        nativeUpdateChild(mBuilderNativeInstance, filterName, colorFilter.getNativeInstance());
+        nativeUpdateColorFilter(mBuilderNativeInstance, filterName,
+                colorFilter.getNativeInstance());
     }
 
     /**
@@ -325,5 +326,6 @@
             long builder, String uniformName, int value1, int value2, int value3,
             int value4, int count);
     private static native void nativeUpdateChild(long builder, String childName, long child);
+    private static native void nativeUpdateColorFilter(long builder, String childName, long filter);
 
 }
diff --git a/libs/hwui/jni/ColorFilter.cpp b/libs/hwui/jni/ColorFilter.cpp
index 20301d2..1c6d886 100644
--- a/libs/hwui/jni/ColorFilter.cpp
+++ b/libs/hwui/jni/ColorFilter.cpp
@@ -163,6 +163,20 @@
             filter->updateChild(env, name.c_str(), child);
         }
     }
+
+    static void RuntimeColorFilter_updateInputColorFilter(JNIEnv* env, jobject,
+                                                          jlong colorFilterPtr, jstring childName,
+                                                          jlong childFilterPtr) {
+        auto* filter = reinterpret_cast<RuntimeColorFilter*>(colorFilterPtr);
+        ScopedUtfChars name(env, childName);
+        auto* child = reinterpret_cast<ColorFilter*>(childFilterPtr);
+        if (filter && child) {
+            auto childInput = child->getInstance();
+            if (childInput) {
+                filter->updateChild(env, name.c_str(), childInput.release());
+            }
+        }
+    }
 };
 
 static const JNINativeMethod colorfilter_methods[] = {
@@ -193,7 +207,9 @@
         {"nativeUpdateUniforms", "(JLjava/lang/String;IIIII)V",
          (void*)ColorFilterGlue::RuntimeColorFilter_updateUniformsInts},
         {"nativeUpdateChild", "(JLjava/lang/String;J)V",
-         (void*)ColorFilterGlue::RuntimeColorFilter_updateChild}};
+         (void*)ColorFilterGlue::RuntimeColorFilter_updateChild},
+        {"nativeUpdateInputColorFilter", "(JLjava/lang/String;J)V",
+         (void*)ColorFilterGlue::RuntimeColorFilter_updateInputColorFilter}};
 
 int register_android_graphics_ColorFilter(JNIEnv* env) {
     android::RegisterMethodsOrDie(env, "android/graphics/ColorFilter", colorfilter_methods,
diff --git a/libs/hwui/jni/RuntimeXfermode.cpp b/libs/hwui/jni/RuntimeXfermode.cpp
index c1c8964..17bee8f 100644
--- a/libs/hwui/jni/RuntimeXfermode.cpp
+++ b/libs/hwui/jni/RuntimeXfermode.cpp
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include "ColorFilter.h"
 #include "GraphicsJNI.h"
 #include "RuntimeEffectUtils.h"
 #include "SkBlender.h"
@@ -93,6 +94,19 @@
     }
 }
 
+static void RuntimeXfermode_updateColorFilter(JNIEnv* env, jobject, jlong builderPtr,
+                                              jstring childName, jlong colorFilterPtr) {
+    auto* builder = reinterpret_cast<SkRuntimeEffectBuilder*>(builderPtr);
+    ScopedUtfChars name(env, childName);
+    auto* child = reinterpret_cast<ColorFilter*>(colorFilterPtr);
+    if (child) {
+        auto childInput = child->getInstance();
+        if (childInput) {
+            UpdateChild(env, builder, name.c_str(), childInput.release());
+        }
+    }
+}
+
 static const JNINativeMethod gRuntimeXfermodeMethods[] = {
         {"nativeGetFinalizer", "()J", (void*)RuntimeXfermode_getNativeFinalizer},
         {"nativeCreateBlenderBuilder", "(Ljava/lang/String;)J",
@@ -107,6 +121,8 @@
         {"nativeUpdateUniforms", "(JLjava/lang/String;IIIII)V",
          (void*)RuntimeXfermode_updateIntUniforms},
         {"nativeUpdateChild", "(JLjava/lang/String;J)V", (void*)RuntimeXfermode_updateChild},
+        {"nativeUpdateColorFilter", "(JLjava/lang/String;J)V",
+         (void*)RuntimeXfermode_updateColorFilter},
 };
 
 int register_android_graphics_RuntimeXfermode(JNIEnv* env) {
diff --git a/libs/hwui/jni/Shader.cpp b/libs/hwui/jni/Shader.cpp
index 018c2b13..eadb9de 100644
--- a/libs/hwui/jni/Shader.cpp
+++ b/libs/hwui/jni/Shader.cpp
@@ -1,5 +1,6 @@
 #include <vector>
 
+#include "ColorFilter.h"
 #include "Gainmap.h"
 #include "GraphicsJNI.h"
 #include "RuntimeEffectUtils.h"
@@ -331,6 +332,15 @@
     builder->child(name.c_str()) = sk_ref_sp(shader);
 }
 
+static void RuntimeShader_updateColorFilter(JNIEnv* env, jobject, jlong shaderBuilder,
+                                            jstring jUniformName, jlong colorFilterHandle) {
+    SkRuntimeShaderBuilder* builder = reinterpret_cast<SkRuntimeShaderBuilder*>(shaderBuilder);
+    ScopedUtfChars name(env, jUniformName);
+    auto* childEffect = reinterpret_cast<ColorFilter*>(colorFilterHandle);
+
+    UpdateChild(env, builder, name.c_str(), childEffect->getInstance().release());
+}
+
 static void RuntimeShader_updateChild(JNIEnv* env, jobject, jlong shaderBuilder,
                                       jstring jUniformName, jlong childHandle) {
     SkRuntimeShaderBuilder* builder = reinterpret_cast<SkRuntimeShaderBuilder*>(shaderBuilder);
@@ -380,6 +390,8 @@
         {"nativeUpdateUniforms", "(JLjava/lang/String;IIIII)V",
          (void*)RuntimeShader_updateIntUniforms},
         {"nativeUpdateShader", "(JLjava/lang/String;J)V", (void*)RuntimeShader_updateShader},
+        {"nativeUpdateColorFilter", "(JLjava/lang/String;J)V",
+         (void*)RuntimeShader_updateColorFilter},
         {"nativeUpdateChild", "(JLjava/lang/String;J)V", (void*)RuntimeShader_updateChild},
 };