[Lut HAL backend] implementation 3rd patch.

- Interpret the LUTs and pass them into the shader.

Bug: 329472856
Test: builds
Flag: EXEMPT no flag needed
Change-Id: I005600593f4a369130bf8bcaea69300758b5ae03
diff --git a/libs/gui/Android.bp b/libs/gui/Android.bp
index 80e148b..1e33abb 100644
--- a/libs/gui/Android.bp
+++ b/libs/gui/Android.bp
@@ -274,6 +274,7 @@
         "LayerMetadata.cpp",
         "LayerStatePermissions.cpp",
         "LayerState.cpp",
+        "DisplayLuts.cpp",
         "OccupancyTracker.cpp",
         "StreamSplitter.cpp",
         "ScreenCaptureResults.cpp",
diff --git a/libs/gui/DisplayLuts.cpp b/libs/gui/DisplayLuts.cpp
new file mode 100644
index 0000000..8042976
--- /dev/null
+++ b/libs/gui/DisplayLuts.cpp
@@ -0,0 +1,99 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gui/DisplayLuts.h>
+#include <private/gui/ParcelUtils.h>
+
+namespace android::gui {
+
+// Reads dimension, size and samplingKey, in the order Entry::writeToParcel() emits them.
+status_t DisplayLuts::Entry::readFromParcel(const android::Parcel* parcel) {
+    if (parcel == nullptr) {
+        ALOGE("%s: Null parcel", __func__);
+        return BAD_VALUE;
+    }
+
+    SAFE_PARCEL(parcel->readInt32, &dimension);
+    SAFE_PARCEL(parcel->readInt32, &size);
+    SAFE_PARCEL(parcel->readInt32, &samplingKey);
+
+    return OK;
+}
+
+// Writes dimension, size and samplingKey as three consecutive int32 values.
+status_t DisplayLuts::Entry::writeToParcel(android::Parcel* parcel) const {
+    if (parcel == nullptr) {
+        ALOGE("%s: Null parcel", __func__);
+        return BAD_VALUE;
+    }
+
+    SAFE_PARCEL(parcel->writeInt32, dimension);
+    SAFE_PARCEL(parcel->writeInt32, size);
+    SAFE_PARCEL(parcel->writeInt32, samplingKey);
+
+    return OK;
+}
+
+// Restores the LUT file descriptor, the offsets and the per-LUT entries from |parcel|.
+// Returns BAD_VALUE for a null parcel and for an entry count that is negative or larger than
+// what the remaining parcel data could hold; individual reads go through SAFE_PARCEL.
+status_t DisplayLuts::readFromParcel(const android::Parcel* parcel) {
+    if (parcel == nullptr) {
+        ALOGE("%s: Null parcel", __func__);
+        return BAD_VALUE;
+    }
+
+    SAFE_PARCEL(parcel->readUniqueFileDescriptor, &fd);
+    SAFE_PARCEL(parcel->readInt32Vector, &offsets);
+    int32_t numLutProperties = 0;
+    SAFE_PARCEL(parcel->readInt32, &numLutProperties);
+
+    // The count comes from the sender, so it must not drive reserve() unchecked. Each entry
+    // is written as three int32 values, hence the bytes still unread bound the real count.
+    constexpr size_t kEntryParcelSize = 3 * sizeof(int32_t);
+    if (numLutProperties < 0 ||
+        static_cast<size_t>(numLutProperties) > parcel->dataAvail() / kEntryParcelSize) {
+        ALOGE("%s: Invalid lut properties count %d", __func__, numLutProperties);
+        return BAD_VALUE;
+    }
+
+    // Replace, rather than append to, whatever an earlier read left in the vector.
+    lutProperties.clear();
+    lutProperties.reserve(static_cast<size_t>(numLutProperties));
+    for (int32_t i = 0; i < numLutProperties; i++) {
+        lutProperties.emplace_back();
+        SAFE_PARCEL(lutProperties.back().readFromParcel, parcel);
+    }
+    return OK;
+}
+
+// Writes the LUT file descriptor, the offsets, the number of entries and then every entry,
+// which is the layout readFromParcel() expects.
+status_t DisplayLuts::writeToParcel(android::Parcel* parcel) const {
+    if (parcel == nullptr) {
+        ALOGE("%s: Null parcel", __func__);
+        return BAD_VALUE;
+    }
+
+    SAFE_PARCEL(parcel->writeUniqueFileDescriptor, fd);
+    SAFE_PARCEL(parcel->writeInt32Vector, offsets);
+    SAFE_PARCEL(parcel->writeInt32, static_cast<int32_t>(lutProperties.size()));
+    for (const auto& entry : lutProperties) {
+        SAFE_PARCEL(entry.writeToParcel, parcel);
+    }
+    return OK;
+}
+} // namespace android::gui
diff --git a/libs/gui/LayerState.cpp b/libs/gui/LayerState.cpp
index 4b53134..139764a 100644
--- a/libs/gui/LayerState.cpp
+++ b/libs/gui/LayerState.cpp
@@ -203,6 +203,12 @@
         SAFE_PARCEL(output.writeParcelable, *bufferReleaseChannel);
     }
 
+    const bool hasLuts = (luts != nullptr);
+    SAFE_PARCEL(output.writeBool, hasLuts);
+    if (hasLuts) {
+        SAFE_PARCEL(output.writeParcelable, *luts);
+    }
+
     return NO_ERROR;
 }
 
@@ -358,6 +364,15 @@
         SAFE_PARCEL(input.readParcelable, bufferReleaseChannel.get());
     }
 
+    bool hasLuts;
+    SAFE_PARCEL(input.readBool, &hasLuts);
+    if (hasLuts) {
+        luts = std::make_shared<gui::DisplayLuts>();
+        SAFE_PARCEL(input.readParcelable, luts.get());
+    } else {
+        luts = nullptr;
+    }
+
     return NO_ERROR;
 }
 
diff --git a/libs/gui/SurfaceComposerClient.cpp b/libs/gui/SurfaceComposerClient.cpp
index 3260c53..807f850 100644
--- a/libs/gui/SurfaceComposerClient.cpp
+++ b/libs/gui/SurfaceComposerClient.cpp
@@ -1971,9 +1971,13 @@
         return *this;
     }
 
-    s->luts = std::make_shared<gui::DisplayLuts>(base::unique_fd(dup(lutFd.get())), offsets,
-                                                 dimensions, sizes, samplingKeys);
     s->what |= layer_state_t::eLutsChanged;
+    if (lutFd.ok()) {
+        s->luts = std::make_shared<gui::DisplayLuts>(base::unique_fd(dup(lutFd.get())), offsets,
+                                                     dimensions, sizes, samplingKeys);
+    } else {
+        s->luts = nullptr;
+    }
 
     registerSurfaceControlForCallback(sc);
     return *this;
diff --git a/libs/gui/include/gui/DisplayLuts.h b/libs/gui/include/gui/DisplayLuts.h
index 16a360d..ab86ac4 100644
--- a/libs/gui/include/gui/DisplayLuts.h
+++ b/libs/gui/include/gui/DisplayLuts.h
@@ -16,16 +16,29 @@
 #pragma once
 
 #include <android-base/unique_fd.h>
+#include <binder/Parcel.h>
+#include <binder/Parcelable.h>
 #include <vector>
 
 namespace android::gui {
 
-struct DisplayLuts {
+// LUT data shared through a file descriptor, together with the offsets and the per-LUT
+// properties needed to read it.
+struct DisplayLuts : public Parcelable {
 public:
-    struct Entry {
+    // Properties of a single LUT: its dimension, its size and its sampling key.
+    struct Entry : public Parcelable {
+        // Zero-initialize every field so that a default-constructed entry, e.g. one whose
+        // parcel read fails half-way, never carries indeterminate values.
+        Entry() : dimension(0), size(0), samplingKey(0) {}
+        Entry(int32_t lutDimension, int32_t lutSize, int32_t lutSamplingKey)
+              : dimension(lutDimension), size(lutSize), samplingKey(lutSamplingKey) {}
         int32_t dimension;
         int32_t size;
         int32_t samplingKey;
+
+        status_t writeToParcel(android::Parcel* parcel) const override;
+        status_t readFromParcel(const android::Parcel* parcel) override;
     };
 
     DisplayLuts() {}
@@ -42,7 +55,11 @@
         }
     }
 
-    base::unique_fd& getLutFileDescriptor() { return fd; }
+    status_t writeToParcel(android::Parcel* parcel) const override;
+    status_t readFromParcel(const android::Parcel* parcel) override;
+
+    // Read-only access to the descriptor of the LUT data; check ok() before using it.
+    const base::unique_fd& getLutFileDescriptor() const { return fd; }
 
     std::vector<Entry> lutProperties;
     std::vector<int32_t> offsets;
diff --git a/libs/renderengine/Android.bp b/libs/renderengine/Android.bp
index d248ea0..7f207f0 100644
--- a/libs/renderengine/Android.bp
+++ b/libs/renderengine/Android.bp
@@ -105,6 +105,7 @@
         "skia/filters/KawaseBlurDualFilter.cpp",
         "skia/filters/KawaseBlurFilter.cpp",
         "skia/filters/LinearEffect.cpp",
+        "skia/filters/LutShader.cpp",
         "skia/filters/MouriMap.cpp",
         "skia/filters/StretchShaderFactory.cpp",
         "skia/filters/EdgeExtensionShaderFactory.cpp",
diff --git a/libs/renderengine/skia/SkiaRenderEngine.cpp b/libs/renderengine/skia/SkiaRenderEngine.cpp
index ec9d3ef..5c46c91 100644
--- a/libs/renderengine/skia/SkiaRenderEngine.cpp
+++ b/libs/renderengine/skia/SkiaRenderEngine.cpp
@@ -543,6 +543,10 @@
         }
     }
 
+    if (graphicBuffer && parameters.layer.luts) {
+        shader = mLutShader.lutShader(shader, parameters.layer.luts);
+    }
+
     if (parameters.requiresLinearEffect) {
         const auto format = targetBuffer != nullptr
                 ? std::optional<ui::PixelFormat>(
diff --git a/libs/renderengine/skia/SkiaRenderEngine.h b/libs/renderengine/skia/SkiaRenderEngine.h
index b5f8898..7be4c25 100644
--- a/libs/renderengine/skia/SkiaRenderEngine.h
+++ b/libs/renderengine/skia/SkiaRenderEngine.h
@@ -39,6 +39,7 @@
 #include "filters/BlurFilter.h"
 #include "filters/EdgeExtensionShaderFactory.h"
 #include "filters/LinearEffect.h"
+#include "filters/LutShader.h"
 #include "filters/StretchShaderFactory.h"
 
 class SkData;
@@ -184,6 +185,7 @@
 
     StretchShaderFactory mStretchShaderFactory;
     EdgeExtensionShaderFactory mEdgeExtensionShaderFactory;
+    LutShader mLutShader;
 
     sp<Fence> mLastDrawFence;
     BlurFilter* mBlurFilter = nullptr;
diff --git a/libs/renderengine/skia/filters/LutShader.cpp b/libs/renderengine/skia/filters/LutShader.cpp
new file mode 100644
index 0000000..cea46ef
--- /dev/null
+++ b/libs/renderengine/skia/filters/LutShader.cpp
@@ -0,0 +1,322 @@
+/*
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "LutShader.h"
+
+#include <SkTileMode.h>
+#include <common/trace.h>
+#include <cutils/ashmem.h>
+#include <math/half.h>
+#include <sys/mman.h>
+
+#include <cerrno>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <vector>
+
+#include "include/core/SkColorSpace.h"
+#include "src/core/SkColorFilterPriv.h"
+
+using aidl::android::hardware::graphics::composer3::LutProperties;
+
+namespace android {
+namespace renderengine {
+namespace skia {
+
+namespace {
+
+// Computes how many floats the LUT described by |entry| occupies in the shared buffer: |size|
+// for a 1D LUT, size^3 * 3 (one plane each for R, G and B) otherwise. Returns false when the
+// size is not positive or too large for the result to be meaningful.
+bool lutFloatCount(const gui::DisplayLuts::Entry& entry, int64_t* outCount) {
+    const int64_t edge = entry.size;
+    if (edge <= 0) {
+        return false;
+    }
+    if (entry.dimension == 1) {
+        *outCount = edge;
+        return true;
+    }
+    // Bounding edge^2 by INT32_MAX keeps edge^3 * 3 far away from int64_t overflow.
+    if (edge * edge > std::numeric_limits<int32_t>::max()) {
+        return false;
+    }
+    *outCount = edge * edge * edge * 3;
+    return true;
+}
+
+} // namespace
+
+// SkSL program that applies one LUT to the color produced by the `image` child shader.
+//   lut       - the LUT as an N x 1 image; entry i is sampled at x = i + 0.5
+//   size      - number of LUT entries along one axis
+//   key       - sampling key: 0 indexes with the RGB channels, 1 with max(R, G, B)
+//   dimension - 1 for a 1D LUT of gains, 3 for a 3D LUT of output colors
+// Any other dimension/key combination returns the input color untouched.
+static const SkString kShader = SkString(R"(
+    uniform shader image;
+    uniform shader lut;
+    uniform int size;
+    uniform int key;
+    uniform int dimension;
+    vec4 main(vec2 xy) {
+        float4 rgba = image.eval(xy);
+        float3 linear = toLinearSrgb(rgba.rgb);
+        if (dimension == 1) {
+            // RGB
+            if (key == 0) {
+                float indexR = linear.r * float(size - 1);
+                float indexG = linear.g * float(size - 1);
+                float indexB = linear.b * float(size - 1);
+                float gainR = lut.eval(vec2(indexR, 0.0) + 0.5).r;
+                float gainG = lut.eval(vec2(indexG, 0.0) + 0.5).r;
+                float gainB = lut.eval(vec2(indexB, 0.0) + 0.5).r;
+                return float4(linear.r * gainR, linear.g * gainG, linear.b * gainB, rgba.a);
+            // MAX_RGB
+            } else if (key == 1) {
+                float maxRGB = max(linear.r, max(linear.g, linear.b));
+                float index = maxRGB * float(size - 1);
+                float gain = lut.eval(vec2(index, 0.0) + 0.5).r;
+                return float4(linear * gain, rgba.a);
+            }
+        } else if (dimension == 3) {
+            if (key == 0) {
+                // position inside the lattice, measured in LUT entries
+                float3 pos = clamp(linear, 0.0, 1.0) * float(size - 1);
+
+                // lower corner of the enclosing cell, kept one short of the last entry so
+                // that the "+ 1" neighbors below stay inside the lattice
+                float3 cell = clamp(floor(pos), 0.0, max(float(size - 2), 0.0));
+                float3 weight = pos - cell;
+                int x = int(cell.r);
+                int y = int(cell.g);
+                int z = int(cell.b);
+
+                int i000 = x + y * size + z * size * size;
+                int i100 = i000 + 1;
+                int i010 = i000 + size;
+                int i110 = i000 + size + 1;
+                int i001 = i000 + size * size;
+                int i101 = i000 + size * size + 1;
+                int i011 = i000 + size * size + size;
+                int i111 = i000 + size * size + size + 1;
+
+                //TODO(b/377984618): support Tetrahedral interpolation
+                // perform trilinear interpolation; the lut child is addressed in pixels,
+                // so the entry index is used directly and + 0.5 selects the texel center
+                float3 c00 = mix(lut.eval(vec2(float(i000), 0.0) + 0.5).rgb,
+                                 lut.eval(vec2(float(i100), 0.0) + 0.5).rgb, weight.r);
+                float3 c01 = mix(lut.eval(vec2(float(i001), 0.0) + 0.5).rgb,
+                                 lut.eval(vec2(float(i101), 0.0) + 0.5).rgb, weight.r);
+                float3 c10 = mix(lut.eval(vec2(float(i010), 0.0) + 0.5).rgb,
+                                 lut.eval(vec2(float(i110), 0.0) + 0.5).rgb, weight.r);
+                float3 c11 = mix(lut.eval(vec2(float(i011), 0.0) + 0.5).rgb,
+                                 lut.eval(vec2(float(i111), 0.0) + 0.5).rgb, weight.r);
+
+                float3 c0 = mix(c00, c10, weight.g);
+                float3 c1 = mix(c01, c11, weight.g);
+
+                float3 val = mix(c0, c1, weight.b);
+
+                return float4(val, rgba.a);
+            }
+        }
+        return rgba;
+    })");
+
+// Wraps |input| in a runtime shader that applies a single LUT.
+//   buffers     - all LUT data as floats, as copied out of the shared memory region
+//   offset      - index in |buffers| of the first float of this LUT
+//   length      - number of entries of this LUT (per channel for a 3D LUT)
+//   dimension   - LutProperties::Dimension value; ONE_D is 1D, anything else is read as 3D
+//   size        - number of entries along one axis, forwarded to the shader
+//   samplingKey - forwarded to the shader as `key`
+// Returns |input| unchanged when the requested range does not fit inside |buffers|.
+sk_sp<SkShader> LutShader::generateLutShader(sk_sp<SkShader> input,
+                                             const std::vector<float>& buffers,
+                                             const int32_t offset, const int32_t length,
+                                             const int32_t dimension, const int32_t size,
+                                             const int32_t samplingKey) {
+    SFTRACE_NAME("lut shader");
+    const bool is1d =
+            static_cast<LutProperties::Dimension>(dimension) == LutProperties::Dimension::ONE_D;
+
+    // The offsets and sizes originate from another process; refuse to read past the data.
+    const int64_t channels = is1d ? 1 : 3;
+    const int64_t end = static_cast<int64_t>(offset) + static_cast<int64_t>(length) * channels;
+    if (offset < 0 || length <= 0 || end > static_cast<int64_t>(buffers.size())) {
+        ALOGE("%s: lut range (offset %d, length %d) does not fit the buffer", __func__, offset,
+              length);
+        return input;
+    }
+
+    /**
+     * 1D Lut(rgba)
+     * (R0, 0, 0, 0)
+     * (R1, 0, 0, 0)
+     * ...
+     *
+     * 3D Lut
+     * (R0, G0, B0, 0)
+     * (R1, G1, B1, 0)
+     * ...
+     */
+    const size_t first = static_cast<size_t>(offset);
+    const size_t count = static_cast<size_t>(length);
+    std::vector<half> buffer(count * 4); // 4 is for RGBA
+    for (size_t i = 0; i < count; i++) {
+        // a 3D LUT is stored planar: all R values, then all G values, then all B values
+        buffer[i * 4] = half(buffers[first + i]);
+        buffer[i * 4 + 1] = half(is1d ? 0.0f : buffers[first + count + i]);
+        buffer[i * 4 + 2] = half(is1d ? 0.0f : buffers[first + count * 2 + i]);
+        buffer[i * 4 + 3] = half(0.0f);
+    }
+
+    // One RGBA_F16 pixel per LUT entry, i.e. exactly the |length| * 4 halfs filled in above.
+    // NOTE(review): a 3D LUT becomes a single very wide row; confirm that width is acceptable
+    // for the backend that ends up sampling this image.
+    const SkImageInfo info =
+            SkImageInfo::Make(length, 1, kRGBA_F16_SkColorType, kPremul_SkAlphaType);
+    SkBitmap bitmap;
+    if (!bitmap.installPixels(info, buffer.data(), info.minRowBytes())) {
+        LOG_ALWAYS_FATAL("unable to install pixels");
+    }
+
+    // NOTE(review): |buffer| dies with this function, so this relies on RasterFromBitmap
+    // copying the pixels of a mutable bitmap - confirm against the Skia documentation.
+    sk_sp<SkImage> lutImage = SkImages::RasterFromBitmap(bitmap);
+    if (lutImage == nullptr) {
+        ALOGE("%s: unable to create the lut image", __func__);
+        return input;
+    }
+
+    mBuilder->child("image") = input;
+    mBuilder->child("lut") =
+            lutImage->makeRawShader(SkTileMode::kClamp, SkTileMode::kClamp,
+                                    is1d ? SkSamplingOptions(SkFilterMode::kLinear)
+                                         : SkSamplingOptions());
+
+    mBuilder->uniform("size") = static_cast<int>(size);
+    mBuilder->uniform("key") = static_cast<int>(samplingKey);
+    mBuilder->uniform("dimension") = static_cast<int>(dimension);
+    return mBuilder->makeShader();
+}
+
+// Applies every LUT carried by |displayLuts| on top of |input| and returns the result, which
+// is also stored back into |input|. The LUT data is read from the shared memory behind the
+// file descriptor of |displayLuts|. |input| is returned unchanged when there is no valid file
+// descriptor or when the LUT description is inconsistent.
+sk_sp<SkShader> LutShader::lutShader(sk_sp<SkShader>& input,
+                                     std::shared_ptr<gui::DisplayLuts> displayLuts) {
+    if (input == nullptr || displayLuts == nullptr) {
+        return input;
+    }
+
+    if (mBuilder == nullptr) {
+        const static SkRuntimeEffect::Result instance = SkRuntimeEffect::MakeForShader(kShader);
+        if (instance.effect == nullptr) {
+            ALOGE("%s: unable to compile the lut shader: %s", __func__,
+                  instance.errorText.c_str());
+            return input;
+        }
+        mBuilder = std::make_unique<SkRuntimeShaderBuilder>(instance.effect);
+    }
+
+    const auto& fd = displayLuts->getLutFileDescriptor();
+    if (!fd.ok()) {
+        return input;
+    }
+
+    const auto& offsets = displayLuts->offsets;
+    const auto& lutProperties = displayLuts->lutProperties;
+    if (offsets.empty() || offsets.size() != lutProperties.size()) {
+        ALOGE("%s: %zu offsets do not match %zu lut properties", __func__, offsets.size(),
+              lutProperties.size());
+        return input;
+    }
+
+    // The last offset plus the length of the last LUT gives the number of floats to map.
+    int64_t lastLength = 0;
+    if (offsets.back() < 0 || !lutFloatCount(lutProperties.back(), &lastLength)) {
+        ALOGE("%s: invalid offset or size for the last lut", __func__);
+        return input;
+    }
+    const int64_t fullLength = offsets.back() + lastLength;
+    // generateLutShader() addresses the data with int32_t, and the byte size must fit size_t.
+    if (fullLength > std::numeric_limits<int32_t>::max() ||
+        static_cast<uint64_t>(fullLength) > std::numeric_limits<size_t>::max() / sizeof(float)) {
+        ALOGE("%s: lut data is too large", __func__);
+        return input;
+    }
+    const size_t bufferSize = static_cast<size_t>(fullLength) * sizeof(float);
+
+    // decode the shared memory of luts; the mapping is only ever read
+    // NOTE(review): nothing here checks that the region behind |fd| really holds bufferSize
+    // bytes - confirm that the producer guarantees it.
+    void* mapping = mmap(nullptr, bufferSize, PROT_READ, MAP_SHARED, fd.get(), 0);
+    if (mapping == MAP_FAILED) {
+        ALOGE("%s: mmap failed: %s", __func__, strerror(errno));
+        return input;
+    }
+    const float* lutData = static_cast<const float*>(mapping);
+    const std::vector<float> buffers(lutData, lutData + fullLength);
+    munmap(mapping, bufferSize);
+
+    // de-gamma the image without changing the primaries
+    sk_sp<SkColorSpace> baseColorSpace;
+    sk_sp<SkColorSpace> linearColorSpace;
+    SkImage* baseImage = input->isAImage(static_cast<SkMatrix*>(nullptr),
+                                         static_cast<SkTileMode*>(nullptr));
+    if (baseImage) {
+        baseColorSpace =
+                baseImage->colorSpace() ? baseImage->refColorSpace() : SkColorSpace::MakeSRGB();
+        linearColorSpace = baseColorSpace->makeLinearGamma();
+        input = input->makeWithColorFilter(
+                SkColorFilterPriv::MakeColorSpaceXform(baseColorSpace, linearColorSpace));
+    }
+
+    for (size_t i = 0; i < offsets.size(); i++) {
+        const bool isLast = (i + 1 == offsets.size());
+        const int64_t next = isLast ? fullLength : offsets[i + 1];
+        int64_t bufferSizePerLut = next - offsets[i];
+        // divide by 3 for 3d Lut because of 3 (RGB) channels
+        if (static_cast<LutProperties::Dimension>(lutProperties[i].dimension) ==
+            LutProperties::Dimension::THREE_D) {
+            bufferSizePerLut /= 3;
+        }
+        if (bufferSizePerLut <= 0 || bufferSizePerLut > fullLength) {
+            ALOGE("%s: skipping lut %zu, its offsets are inconsistent", __func__, i);
+            continue;
+        }
+        input = generateLutShader(input, buffers, offsets[i],
+                                  static_cast<int32_t>(bufferSizePerLut),
+                                  lutProperties[i].dimension, lutProperties[i].size,
+                                  lutProperties[i].samplingKey);
+    }
+
+    // re-gamma with the color spaces captured above: |input| has been wrapped since then,
+    // so asking it for its image again would not undo the de-gamma step
+    if (baseColorSpace) {
+        input = input->makeWithColorFilter(
+                SkColorFilterPriv::MakeColorSpaceXform(linearColorSpace, baseColorSpace));
+    }
+    return input;
+}
+
+} // namespace skia
+} // namespace renderengine
+} // namespace android
diff --git a/libs/renderengine/skia/filters/LutShader.h b/libs/renderengine/skia/filters/LutShader.h
new file mode 100644
index 0000000..c157904
--- /dev/null
+++ b/libs/renderengine/skia/filters/LutShader.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <SkBitmap.h>
+#include <SkImage.h>
+#include <SkRuntimeEffect.h>
+
+#include <aidl/android/hardware/graphics/composer3/LutProperties.h>
+#include <gui/DisplayLuts.h>
+
+#include <memory>
+#include <vector>
+
+namespace android {
+namespace renderengine {
+namespace skia {
+
+// Builds Skia shaders that apply the LUTs described by a gui::DisplayLuts to the output of
+// another shader.
+class LutShader {
+public:
+    // Applies the LUTs of |displayLuts| on top of |input|; the result is returned and also
+    // stored back into |input|. Without a valid LUT file descriptor |input| is returned as is.
+    sk_sp<SkShader> lutShader(sk_sp<SkShader>& input,
+                              std::shared_ptr<gui::DisplayLuts> displayLuts);
+
+private:
+    // Wraps |input| in a shader applying the single LUT that starts at |offset| in |buffers|
+    // and holds |length| entries; |dimension|, |size| and |samplingKey| describe that LUT.
+    sk_sp<SkShader> generateLutShader(sk_sp<SkShader> input, const std::vector<float>& buffers,
+                                      const int32_t offset, const int32_t length,
+                                      const int32_t dimension, const int32_t size,
+                                      const int32_t samplingKey);
+
+    // Builder for the LUT runtime effect; created on the first call to lutShader().
+    std::unique_ptr<SkRuntimeShaderBuilder> mBuilder;
+};
+
+} // namespace skia
+} // namespace renderengine
+} // namespace android