Merge changes I8e0076d3,Id39dab35 into main

* changes:
  SF: Add property to skip boot animation
  SF: Remove StartPropertySetThread
diff --git a/cmds/dumpstate/dumpstate.cpp b/cmds/dumpstate/dumpstate.cpp
index 8d37aac..6b9a0a0 100644
--- a/cmds/dumpstate/dumpstate.cpp
+++ b/cmds/dumpstate/dumpstate.cpp
@@ -3280,6 +3280,12 @@
     // duration is logged into MYLOG instead.
     PrintHeader();
 
+    bool system_trace_exists = access(SYSTEM_TRACE_SNAPSHOT, F_OK) == 0;
+    if (options_->use_predumped_ui_data && !system_trace_exists) {
+        MYLOGW("Ignoring 'use predumped data' flag because no predumped data is available");
+        options_->use_predumped_ui_data = false;
+    }
+
     std::future<std::string> snapshot_system_trace;
 
     bool is_dumpstate_restricted =
@@ -4191,14 +4197,14 @@
 }
 
 int read_file_as_long(const char *path, long int *output) {
-    int fd = TEMP_FAILURE_RETRY(open(path, O_RDONLY | O_NONBLOCK | O_CLOEXEC));
-    if (fd < 0) {
+    android::base::unique_fd fd(TEMP_FAILURE_RETRY(open(path, O_RDONLY | O_NONBLOCK | O_CLOEXEC)));
+    if (fd.get() < 0) {
         int err = errno;
         MYLOGE("Error opening file descriptor for %s: %s\n", path, strerror(err));
         return -1;
     }
     char buffer[50];
-    ssize_t bytes_read = TEMP_FAILURE_RETRY(read(fd, buffer, sizeof(buffer)));
+    ssize_t bytes_read = TEMP_FAILURE_RETRY(read(fd.get(), buffer, sizeof(buffer)));
     if (bytes_read == -1) {
         MYLOGE("Error reading file %s: %s\n", path, strerror(errno));
         return -2;
diff --git a/cmds/evemu-record/main.rs b/cmds/evemu-record/main.rs
index c30c00f..db3fd77 100644
--- a/cmds/evemu-record/main.rs
+++ b/cmds/evemu-record/main.rs
@@ -120,7 +120,7 @@
     fn print_in_8_byte_chunks(
         output: &mut impl Write,
         prefix: &str,
-        data: &Vec<u8>,
+        data: &[u8],
     ) -> Result<(), io::Error> {
         for (i, byte) in data.iter().enumerate() {
             if i % 8 == 0 {
diff --git a/cmds/installd/otapreopt_script.sh b/cmds/installd/otapreopt_script.sh
index 28bd793..ae7d8e0 100644
--- a/cmds/installd/otapreopt_script.sh
+++ b/cmds/installd/otapreopt_script.sh
@@ -50,6 +50,12 @@
   exit 1
 fi
 
+if pm art on-ota-staged --slot "$TARGET_SLOT_SUFFIX"; then
+  # Handled by Pre-reboot Dexopt.
+  exit 0
+fi
+echo "Pre-reboot Dexopt not enabled. Fall back to otapreopt."
+
 if [ "$(/system/bin/otapreopt_chroot --version)" != 2 ]; then
   # We require an updated chroot wrapper that reads dexopt commands from stdin.
   # Even if we kept compat with the old binary, the OTA preopt wouldn't work due
diff --git a/data/etc/Android.bp b/data/etc/Android.bp
index 8b2842a..da232a5 100644
--- a/data/etc/Android.bp
+++ b/data/etc/Android.bp
@@ -287,6 +287,12 @@
 }
 
 prebuilt_etc {
+    name: "android.hardware.telephony.data.prebuilt.xml",
+    src: "android.hardware.telephony.data.xml",
+    defaults: ["frameworks_native_data_etc_defaults"],
+}
+
+prebuilt_etc {
     name: "android.hardware.telephony.gsm.prebuilt.xml",
     src: "android.hardware.telephony.gsm.xml",
     defaults: ["frameworks_native_data_etc_defaults"],
@@ -329,12 +335,24 @@
 }
 
 prebuilt_etc {
+    name: "android.hardware.vulkan.level-1.prebuilt.xml",
+    src: "android.hardware.vulkan.level-1.xml",
+    defaults: ["frameworks_native_data_etc_defaults"],
+}
+
+prebuilt_etc {
     name: "android.hardware.vulkan.version-1_0_3.prebuilt.xml",
     src: "android.hardware.vulkan.version-1_0_3.xml",
     defaults: ["frameworks_native_data_etc_defaults"],
 }
 
 prebuilt_etc {
+    name: "android.hardware.vulkan.version-1_3.prebuilt.xml",
+    src: "android.hardware.vulkan.version-1_3.xml",
+    defaults: ["frameworks_native_data_etc_defaults"],
+}
+
+prebuilt_etc {
     name: "android.hardware.wifi.prebuilt.xml",
     src: "android.hardware.wifi.xml",
     defaults: ["frameworks_native_data_etc_defaults"],
@@ -353,6 +371,12 @@
 }
 
 prebuilt_etc {
+    name: "android.software.contextualsearch.prebuilt.xml",
+    src: "android.software.contextualsearch.xml",
+    defaults: ["frameworks_native_data_etc_defaults"],
+}
+
+prebuilt_etc {
     name: "android.software.device_id_attestation.prebuilt.xml",
     src: "android.software.device_id_attestation.xml",
     defaults: ["frameworks_native_data_etc_defaults"],
diff --git a/include/android/surface_control.h b/include/android/surface_control.h
index 321737e..bf1f2e9 100644
--- a/include/android/surface_control.h
+++ b/include/android/surface_control.h
@@ -532,7 +532,7 @@
  * using this API for formats that encode an HDR/SDR ratio as part of generating the buffer.
  *
  * @param surface_control The layer whose extended range brightness is being specified
- * @param currentBufferRatio The current hdr/sdr ratio of the current buffer as represented as
+ * @param currentBufferRatio The current HDR/SDR ratio of the current buffer as represented as
  *                           peakHdrBrightnessInNits / targetSdrWhitePointInNits. For example if the
  *                           buffer was rendered with a target SDR whitepoint of 100nits and a max
  *                           display brightness of 200nits, this should be set to 2.0f.
@@ -546,7 +546,7 @@
  *
  *                           Must be finite && >= 1.0f
  *
- * @param desiredRatio The desired hdr/sdr ratio as represented as peakHdrBrightnessInNits /
+ * @param desiredRatio The desired HDR/SDR ratio as represented as peakHdrBrightnessInNits /
  *                     targetSdrWhitePointInNits. This can be used to communicate the max desired
  *                     brightness range. This is similar to the "max luminance" value in other
  *                     HDR metadata formats, but represented as a ratio of the target SDR whitepoint
@@ -579,13 +579,13 @@
                                             float desiredRatio) __INTRODUCED_IN(__ANDROID_API_U__);
 
 /**
- * Sets the desired hdr headroom for the layer. See: ASurfaceTransaction_setExtendedRangeBrightness,
+ * Sets the desired HDR headroom for the layer. See: ASurfaceTransaction_setExtendedRangeBrightness,
  * prefer using this API for formats that conform to HDR standards like HLG or HDR10, that do not
  * communicate a HDR/SDR ratio as part of generating the buffer.
  *
- * @param surface_control The layer whose desired hdr headroom is being specified
+ * @param surface_control The layer whose desired HDR headroom is being specified
  *
- * @param desiredHeadroom The desired hdr/sdr ratio as represented as peakHdrBrightnessInNits /
+ * @param desiredHeadroom The desired HDR/SDR ratio as represented as peakHdrBrightnessInNits /
  *                        targetSdrWhitePointInNits. This can be used to communicate the max
  *                        desired brightness range of the panel. The system may not be able to, or
  *                        may choose not to, deliver the requested range.
diff --git a/include/input/Input.h b/include/input/Input.h
index 374254f..ddc3768 100644
--- a/include/input/Input.h
+++ b/include/input/Input.h
@@ -662,10 +662,6 @@
 
     inline void setActionButton(int32_t button) { mActionButton = button; }
 
-    inline float getXOffset() const { return mTransform.tx(); }
-
-    inline float getYOffset() const { return mTransform.ty(); }
-
     inline const ui::Transform& getTransform() const { return mTransform; }
 
     std::optional<ui::Rotation> getSurfaceRotation() const;
@@ -880,6 +876,22 @@
 
     void offsetLocation(float xOffset, float yOffset);
 
+    /**
+     * Get the X offset of this motion event relative to the origin of the raw coordinate space.
+     *
+     * In practice, this is the delta that was added to the raw screen coordinates (i.e. in logical
+     * display space) to adjust for the absolute position of the containing windows and views.
+     */
+    float getRawXOffset() const;
+
+    /**
+     * Get the Y offset of this motion event relative to the origin of the raw coordinate space.
+     *
+     * In practice, this is the delta that was added to the raw screen coordinates (i.e. in logical
+     * display space) to adjust for the absolute position of the containing windows and views.
+     */
+    float getRawYOffset() const;
+
     void scale(float globalScaleFactor);
 
     // Set 3x3 perspective matrix transformation.
diff --git a/include/input/InputTransport.h b/include/input/InputTransport.h
index 42dcd3c..aca4b62 100644
--- a/include/input/InputTransport.h
+++ b/include/input/InputTransport.h
@@ -670,13 +670,6 @@
     status_t sendUnchainedFinishedSignal(uint32_t seq, bool handled);
 
     static void rewriteMessage(TouchState& state, InputMessage& msg);
-    static void initializeKeyEvent(KeyEvent* event, const InputMessage* msg);
-    static void initializeMotionEvent(MotionEvent* event, const InputMessage* msg);
-    static void initializeFocusEvent(FocusEvent* event, const InputMessage* msg);
-    static void initializeCaptureEvent(CaptureEvent* event, const InputMessage* msg);
-    static void initializeDragEvent(DragEvent* event, const InputMessage* msg);
-    static void initializeTouchModeEvent(TouchModeEvent* event, const InputMessage* msg);
-    static void addSample(MotionEvent* event, const InputMessage* msg);
     static bool canAddSample(const Batch& batch, const InputMessage* msg);
     static ssize_t findSampleNoLaterThan(const Batch& batch, nsecs_t time);
 
diff --git a/libs/binder/Parcel.cpp b/libs/binder/Parcel.cpp
index 2dd310e..35cea81 100644
--- a/libs/binder/Parcel.cpp
+++ b/libs/binder/Parcel.cpp
@@ -2976,14 +2976,14 @@
         return continueWrite(desired);
     }
 
+    releaseObjects();
+
     uint8_t* data = reallocZeroFree(mData, mDataCapacity, desired, mDeallocZero);
     if (!data && desired > mDataCapacity) {
         mError = NO_MEMORY;
         return NO_MEMORY;
     }
 
-    releaseObjects();
-
     if (data || desired == 0) {
         LOG_ALLOC("Parcel %p: restart from %zu to %zu capacity", this, mDataCapacity, desired);
         if (mDataCapacity > desired) {
diff --git a/libs/binder/tests/parcel_fuzzer/binder.cpp b/libs/binder/tests/parcel_fuzzer/binder.cpp
index 5c280f4..e378b86 100644
--- a/libs/binder/tests/parcel_fuzzer/binder.cpp
+++ b/libs/binder/tests/parcel_fuzzer/binder.cpp
@@ -115,6 +115,14 @@
         p.setDataPosition(pos);
         FUZZ_LOG() << "setDataPosition done";
     },
+    [] (const ::android::Parcel& p, FuzzedDataProvider& provider) {
+        size_t len = provider.ConsumeIntegralInRange<size_t>(0, 1024); // at most 1 KiB of data
+        std::vector<uint8_t> bytes = provider.ConsumeBytes<uint8_t>(len);
+        FUZZ_LOG() << "about to setData: " <<(bytes.data() ? HexString(bytes.data(), bytes.size()) : "null");
+        // const_cast: ops get a const Parcel&. TODO: allow all read and write operations
+        (*const_cast<::android::Parcel*>(&p)).setData(bytes.data(), bytes.size());
+        FUZZ_LOG() << "setData done";
+    },
     PARCEL_READ_NO_STATUS(size_t, allowFds),
     PARCEL_READ_NO_STATUS(size_t, hasFileDescriptors),
     PARCEL_READ_NO_STATUS(std::vector<android::sp<android::IBinder>>, debugReadAllStrongBinders),
diff --git a/libs/gui/Choreographer.cpp b/libs/gui/Choreographer.cpp
index 4518b67..54290cd 100644
--- a/libs/gui/Choreographer.cpp
+++ b/libs/gui/Choreographer.cpp
@@ -143,9 +143,9 @@
 void Choreographer::postFrameCallbackDelayed(AChoreographer_frameCallback cb,
                                              AChoreographer_frameCallback64 cb64,
                                              AChoreographer_vsyncCallback vsyncCallback, void* data,
-                                             nsecs_t delay) {
+                                             nsecs_t delay, CallbackType callbackType) {
     nsecs_t now = systemTime(SYSTEM_TIME_MONOTONIC);
-    FrameCallback callback{cb, cb64, vsyncCallback, data, now + delay};
+    FrameCallback callback{cb, cb64, vsyncCallback, data, now + delay, callbackType};
     {
         std::lock_guard<std::mutex> _l{mLock};
         mFrameCallbacks.push(callback);
@@ -285,18 +285,8 @@
     }
 }
 
-void Choreographer::dispatchVsync(nsecs_t timestamp, PhysicalDisplayId, uint32_t,
-                                  VsyncEventData vsyncEventData) {
-    std::vector<FrameCallback> callbacks{};
-    {
-        std::lock_guard<std::mutex> _l{mLock};
-        nsecs_t now = systemTime(SYSTEM_TIME_MONOTONIC);
-        while (!mFrameCallbacks.empty() && mFrameCallbacks.top().dueTime < now) {
-            callbacks.push_back(mFrameCallbacks.top());
-            mFrameCallbacks.pop();
-        }
-    }
-    mLastVsyncEventData = vsyncEventData;
+void Choreographer::dispatchCallbacks(const std::vector<FrameCallback>& callbacks,
+                                      VsyncEventData vsyncEventData, nsecs_t timestamp) {
     for (const auto& cb : callbacks) {
         if (cb.vsyncCallback != nullptr) {
             ATRACE_FORMAT("AChoreographer_vsyncCallback %" PRId64,
@@ -319,6 +309,34 @@
     }
 }
 
+void Choreographer::dispatchVsync(nsecs_t timestamp, PhysicalDisplayId, uint32_t,
+                                  VsyncEventData vsyncEventData) {
+    std::vector<FrameCallback> animationCallbacks{}; // Every due callback that is not input.
+    std::vector<FrameCallback> inputCallbacks{};     // Due CALLBACK_INPUT callbacks.
+    { // Drain due callbacks under mLock; they are dispatched after it is released.
+        std::lock_guard<std::mutex> _l{mLock};
+        nsecs_t now = systemTime(SYSTEM_TIME_MONOTONIC);
+        while (!mFrameCallbacks.empty() && mFrameCallbacks.top().dueTime < now) {
+            if (mFrameCallbacks.top().callbackType == CALLBACK_INPUT) {
+                inputCallbacks.push_back(mFrameCallbacks.top());
+            } else {
+                animationCallbacks.push_back(mFrameCallbacks.top());
+            }
+            mFrameCallbacks.pop();
+        }
+    }
+    mLastVsyncEventData = vsyncEventData; // NOTE(review): assigned outside mLock, as before.
+    // Run all due CALLBACK_INPUT callbacks before the rest; both groups get the same vsync data.
+    {
+        ATRACE_FORMAT("CALLBACK_INPUT");
+        dispatchCallbacks(inputCallbacks, vsyncEventData, timestamp);
+    }
+    {
+        ATRACE_FORMAT("CALLBACK_ANIMATION");
+        dispatchCallbacks(animationCallbacks, vsyncEventData, timestamp);
+    }
+}
+
 void Choreographer::dispatchHotplug(nsecs_t, PhysicalDisplayId displayId, bool connected) {
     ALOGV("choreographer %p ~ received hotplug event (displayId=%s, connected=%s), ignoring.", this,
           to_string(displayId).c_str(), toString(connected));
diff --git a/libs/gui/include/gui/Choreographer.h b/libs/gui/include/gui/Choreographer.h
index 55a7aa7..fc79b03 100644
--- a/libs/gui/include/gui/Choreographer.h
+++ b/libs/gui/include/gui/Choreographer.h
@@ -28,12 +28,18 @@
 namespace android {
 using gui::VsyncEventData;
 
+enum CallbackType : int8_t {
+    CALLBACK_INPUT,     // Dispatched before every other due callback on a vsync.
+    CALLBACK_ANIMATION, // Dispatched after all due CALLBACK_INPUT callbacks.
+};
+
 struct FrameCallback {
     AChoreographer_frameCallback callback;
     AChoreographer_frameCallback64 callback64;
     AChoreographer_vsyncCallback vsyncCallback;
     void* data;
     nsecs_t dueTime;
+    CallbackType callbackType;
 
     inline bool operator<(const FrameCallback& rhs) const {
         // Note that this is intentionally flipped because we want callbacks due sooner to be at
@@ -78,7 +84,7 @@
     void postFrameCallbackDelayed(AChoreographer_frameCallback cb,
                                   AChoreographer_frameCallback64 cb64,
                                   AChoreographer_vsyncCallback vsyncCallback, void* data,
-                                  nsecs_t delay);
+                                  nsecs_t delay, CallbackType callbackType);
     void registerRefreshRateCallback(AChoreographer_refreshRateCallback cb, void* data)
             EXCLUDES(gChoreographers.lock);
     void unregisterRefreshRateCallback(AChoreographer_refreshRateCallback cb, void* data);
@@ -109,6 +115,8 @@
 
     void dispatchVsync(nsecs_t timestamp, PhysicalDisplayId displayId, uint32_t count,
                        VsyncEventData vsyncEventData) override;
+    void dispatchCallbacks(const std::vector<FrameCallback>&, VsyncEventData vsyncEventData,
+                           nsecs_t timestamp);
     void dispatchHotplug(nsecs_t timestamp, PhysicalDisplayId displayId, bool connected) override;
     void dispatchHotplugConnectionError(nsecs_t timestamp, int32_t connectionError) override;
     void dispatchModeChanged(nsecs_t timestamp, PhysicalDisplayId displayId, int32_t modeId,
diff --git a/libs/gui/tests/Android.bp b/libs/gui/tests/Android.bp
index e606b99..0f16f71 100644
--- a/libs/gui/tests/Android.bp
+++ b/libs/gui/tests/Android.bp
@@ -30,6 +30,7 @@
         "BLASTBufferQueue_test.cpp",
         "BufferItemConsumer_test.cpp",
         "BufferQueue_test.cpp",
+        "Choreographer_test.cpp",
         "CompositorTiming_test.cpp",
         "CpuConsumer_test.cpp",
         "EndToEndNativeInputTest.cpp",
@@ -61,6 +62,7 @@
         "libSurfaceFlingerProp",
         "libGLESv1_CM",
         "libinput",
+        "libnativedisplay",
     ],
 
     static_libs: [
diff --git a/libs/gui/tests/Choreographer_test.cpp b/libs/gui/tests/Choreographer_test.cpp
new file mode 100644
index 0000000..2ac2550
--- /dev/null
+++ b/libs/gui/tests/Choreographer_test.cpp
@@ -0,0 +1,96 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "Choreographer_test"
+
+#include <android-base/stringprintf.h>
+#include <android/choreographer.h>
+#include <gtest/gtest.h>
+#include <gui/Choreographer.h>
+#include <utils/Looper.h>
+#include <atomic>
+#include <chrono>
+#include <future>
+#include <string>
+
+namespace android {
+class ChoreographerTest : public ::testing::Test {};
+
+// Records the frame time reported to a vsync callback and the monotonic time it ran.
+struct VsyncCallback {
+    std::atomic<bool> completePromise{false};
+    std::chrono::nanoseconds frameTime{0LL};
+    std::chrono::nanoseconds receivedCallbackTime{0LL};
+
+    void onVsyncCallback(const AChoreographerFrameCallbackData* callbackData) {
+        frameTime = std::chrono::nanoseconds{
+                AChoreographerFrameCallbackData_getFrameTimeNanos(callbackData)};
+        receivedCallbackTime = std::chrono::nanoseconds{systemTime(SYSTEM_TIME_MONOTONIC)};
+        completePromise.store(true);
+    }
+
+    bool callbackReceived() { return completePromise.load(); }
+};
+
+// AChoreographer_vsyncCallback adapter: forwards to the VsyncCallback passed as `data`.
+static void vsyncCallback(const AChoreographerFrameCallbackData* callbackData, void* data) {
+    VsyncCallback* cb = static_cast<VsyncCallback*>(data);
+    cb->onVsyncCallback(callbackData);
+}
+
+TEST_F(ChoreographerTest, InputCallbackBeforeAnimation) {
+    sp<Looper> looper = Looper::prepare(0);
+    Choreographer* choreographer = Choreographer::getForThread();
+    VsyncCallback animationCb;
+    VsyncCallback inputCb;
+
+    // Post the animation callback first (earlier due time) so that running the input callback
+    // first can only come from the callback type, not from posting order.
+    choreographer->postFrameCallbackDelayed(nullptr, nullptr, vsyncCallback, &animationCb, 0,
+                                            CALLBACK_ANIMATION);
+    choreographer->postFrameCallbackDelayed(nullptr, nullptr, vsyncCallback, &inputCb, 0,
+                                            CALLBACK_INPUT);
+
+    // systemTime() is in nanoseconds, so the overall deadline must be expressed in ns too.
+    constexpr nsecs_t kTimeoutNs = std::chrono::nanoseconds(std::chrono::seconds(3)).count();
+    nsecs_t startTime = systemTime(SYSTEM_TIME_MONOTONIC);
+    nsecs_t currTime;
+    int pollResult;
+    do {
+        pollResult = looper->pollOnce(16);
+        currTime = systemTime(SYSTEM_TIME_MONOTONIC);
+        // A POLL_TIMEOUT only means no vsync arrived in this 16ms slice; keep polling until
+        // both callbacks have run, the looper reports an error, or the deadline passes.
+    } while (!(inputCb.callbackReceived() && animationCb.callbackReceived()) &&
+             pollResult != Looper::POLL_ERROR && (currTime - startTime < kTimeoutNs));
+
+    ASSERT_TRUE(inputCb.callbackReceived()) << "did not receive input callback";
+    ASSERT_TRUE(animationCb.callbackReceived()) << "did not receive animation callback";
+
+    ASSERT_EQ(inputCb.frameTime, animationCb.frameTime)
+            << android::base::StringPrintf("input and animation callback frame times don't match. "
+                                           "inputFrameTime=%lld  animationFrameTime=%lld",
+                                           inputCb.frameTime.count(),
+                                           animationCb.frameTime.count());
+
+    ASSERT_LT(inputCb.receivedCallbackTime, animationCb.receivedCallbackTime)
+            << android::base::StringPrintf("input callback was not called first. "
+                                           "inputCallbackTime=%lld  animationCallbackTime=%lld",
+                                           inputCb.receivedCallbackTime.count(),
+                                           animationCb.receivedCallbackTime.count());
+}
+
+} // namespace android
diff --git a/libs/input/Android.bp b/libs/input/Android.bp
index 8b69339..0171d74 100644
--- a/libs/input/Android.bp
+++ b/libs/input/Android.bp
@@ -154,6 +154,10 @@
     ],
     generated_sources: ["libinput_cxx_bridge_code"],
 
+    lto: {
+        never: true,
+    },
+
     shared_libs: [
         "libbase",
     ],
diff --git a/libs/input/Input.cpp b/libs/input/Input.cpp
index d58fb42..ff9d9a9 100644
--- a/libs/input/Input.cpp
+++ b/libs/input/Input.cpp
@@ -751,6 +751,18 @@
     mTransform.set(currXOffset + xOffset, currYOffset + yOffset);
 }
 
+float MotionEvent::getRawXOffset() const {
+    // Map the raw-space origin through mTransform * mRawTransform^-1; report its x-coordinate.
+    const ui::Transform rawToTransformed = mTransform * mRawTransform.inverse();
+    return rawToTransformed.tx();
+}
+
+float MotionEvent::getRawYOffset() const {
+    // Map the raw-space origin through mTransform * mRawTransform^-1; report its y-coordinate.
+    const ui::Transform rawToTransformed = mTransform * mRawTransform.inverse();
+    return rawToTransformed.ty();
+}
+
 void MotionEvent::scale(float globalScaleFactor) {
     mTransform.set(mTransform.tx() * globalScaleFactor, mTransform.ty() * globalScaleFactor);
     mRawTransform.set(mRawTransform.tx() * globalScaleFactor,
diff --git a/libs/input/InputTransport.cpp b/libs/input/InputTransport.cpp
index e49f4eb..b3a36eb 100644
--- a/libs/input/InputTransport.cpp
+++ b/libs/input/InputTransport.cpp
@@ -26,10 +26,13 @@
 
 #include <com_android_input_flags.h>
 #include <input/InputTransport.h>
+#include <input/PrintTools.h>
 #include <input/TraceTools.h>
 
 namespace input_flags = com::android::input::flags;
 
+namespace android {
+
 namespace {
 
 /**
@@ -110,21 +113,76 @@
     return newFd;
 }
 
-} // namespace
+void initializeKeyEvent(KeyEvent& event, const InputMessage& msg) {
+    // Populate the KeyEvent from the msg.body.key fields.
+    const auto& key = msg.body.key;
+    event.initialize(key.eventId, key.deviceId, key.source, key.displayId, key.hmac, key.action,
+                     key.flags, key.keyCode, key.scanCode, key.metaState, key.repeatCount,
+                     key.downTime, key.eventTime);
+}
 
-using android::base::Result;
-using android::base::StringPrintf;
+void initializeFocusEvent(FocusEvent& event, const InputMessage& msg) {
+    event.initialize(msg.body.focus.eventId, msg.body.focus.hasFocus);
+}
 
-namespace android {
+void initializeCaptureEvent(CaptureEvent& event, const InputMessage& msg) {
+    event.initialize(msg.body.capture.eventId, msg.body.capture.pointerCaptureEnabled);
+}
+
+void initializeDragEvent(DragEvent& event, const InputMessage& msg) {
+    const auto& drag = msg.body.drag;
+    event.initialize(drag.eventId, drag.x, drag.y, drag.isExiting);
+}
+
+void initializeMotionEvent(MotionEvent& event, const InputMessage& msg) {
+    const auto& motion = msg.body.motion;
+    const uint32_t pointerCount = motion.pointerCount;
+    PointerProperties pointerProperties[pointerCount];
+    PointerCoords pointerCoords[pointerCount];
+    for (uint32_t i = 0; i < pointerCount; i++) {
+        pointerProperties[i] = motion.pointers[i].properties;
+        pointerCoords[i] = motion.pointers[i].coords;
+    }
+
+    // Each transform is sent as six values (its top two rows); the bottom row is fixed to 0,0,1.
+    ui::Transform transform;
+    transform.set({motion.dsdx, motion.dtdx, motion.tx, motion.dtdy, motion.dsdy, motion.ty,
+                   0, 0, 1});
+    ui::Transform displayTransform;
+    displayTransform.set({motion.dsdxRaw, motion.dtdxRaw, motion.txRaw, motion.dtdyRaw,
+                          motion.dsdyRaw, motion.tyRaw, 0, 0, 1});
+    event.initialize(motion.eventId, motion.deviceId, motion.source, motion.displayId,
+                     motion.hmac, motion.action, motion.actionButton, motion.flags,
+                     motion.edgeFlags, motion.metaState, motion.buttonState, motion.classification,
+                     transform, motion.xPrecision, motion.yPrecision, motion.xCursorPosition,
+                     motion.yCursorPosition, displayTransform, motion.downTime, motion.eventTime,
+                     pointerCount, pointerProperties, pointerCoords);
+}
+
+void addSample(MotionEvent& event, const InputMessage& msg) {
+    const auto& motion = msg.body.motion;
+    const uint32_t pointerCount = motion.pointerCount;
+    PointerCoords pointerCoords[pointerCount];
+    for (uint32_t i = 0; i < pointerCount; i++) {
+        pointerCoords[i] = motion.pointers[i].coords;
+    }
+    // OR the incoming meta state into the event's existing one rather than replacing it.
+    event.setMetaState(event.getMetaState() | motion.metaState);
+    event.addSample(motion.eventTime, pointerCoords);
+}
+
+void initializeTouchModeEvent(TouchModeEvent& event, const InputMessage& msg) {
+    event.initialize(msg.body.touchMode.eventId, msg.body.touchMode.isInTouchMode);
+}
 
 // Socket buffer size.  The default is typically about 128KB, which is much larger than
 // we really need.  So we make it smaller.  It just needs to be big enough to hold
 // a few dozen large multi-finger motion events in the case where an application gets
 // behind processing touches.
-static const size_t SOCKET_BUFFER_SIZE = 32 * 1024;
+static constexpr size_t SOCKET_BUFFER_SIZE = 32 * 1024;
 
 // Nanoseconds per milliseconds.
-static const nsecs_t NANOS_PER_MS = 1000000;
+static constexpr nsecs_t NANOS_PER_MS = 1000000;
 
 // Latency added during resampling.  A few milliseconds doesn't hurt much but
 // reduces the impact of mispredicted touch positions.
@@ -157,32 +215,28 @@
  * Crash if the events that are getting sent to the InputPublisher are inconsistent.
  * Enable this via "adb shell setprop log.tag.InputTransportVerifyEvents DEBUG"
  */
-static bool verifyEvents() {
+bool verifyEvents() {
     return input_flags::enable_outbound_event_verification() ||
             __android_log_is_loggable(ANDROID_LOG_DEBUG, LOG_TAG "VerifyEvents", ANDROID_LOG_INFO);
 }
 
-template<typename T>
-inline static T min(const T& a, const T& b) {
-    return a < b ? a : b;
-}
-
-inline static float lerp(float a, float b, float alpha) {
+inline float lerp(float a, float b, float alpha) { // a when alpha == 0, b when alpha == 1
     return a + alpha * (b - a);
 }
 
-inline static bool isPointerEvent(int32_t source) {
+inline bool isPointerEvent(int32_t source) { // all AINPUT_SOURCE_CLASS_POINTER bits set
     return (source & AINPUT_SOURCE_CLASS_POINTER) == AINPUT_SOURCE_CLASS_POINTER;
 }
 
-inline static const char* toString(bool value) {
-    return value ? "true" : "false";
-}
-
-static bool shouldResampleTool(ToolType toolType) {
+bool shouldResampleTool(ToolType toolType) { // only FINGER and UNKNOWN tools are resampled
     return toolType == ToolType::FINGER || toolType == ToolType::UNKNOWN;
 }
 
+} // namespace
+
+using android::base::Result;
+using android::base::StringPrintf;
+
 // --- InputMessage ---
 
 bool InputMessage::isValid(size_t actualSize) const {
@@ -902,7 +956,7 @@
                 KeyEvent* keyEvent = factory->createKeyEvent();
                 if (!keyEvent) return NO_MEMORY;
 
-                initializeKeyEvent(keyEvent, &mMsg);
+                initializeKeyEvent(*keyEvent, mMsg);
                 *outSeq = mMsg.header.seq;
                 *outEvent = keyEvent;
                 ALOGD_IF(DEBUG_TRANSPORT_CONSUMER,
@@ -965,7 +1019,7 @@
                 if (!motionEvent) return NO_MEMORY;
 
                 updateTouchState(mMsg);
-                initializeMotionEvent(motionEvent, &mMsg);
+                initializeMotionEvent(*motionEvent, mMsg);
                 *outSeq = mMsg.header.seq;
                 *outEvent = motionEvent;
 
@@ -987,7 +1041,7 @@
                 FocusEvent* focusEvent = factory->createFocusEvent();
                 if (!focusEvent) return NO_MEMORY;
 
-                initializeFocusEvent(focusEvent, &mMsg);
+                initializeFocusEvent(*focusEvent, mMsg);
                 *outSeq = mMsg.header.seq;
                 *outEvent = focusEvent;
                 break;
@@ -997,7 +1051,7 @@
                 CaptureEvent* captureEvent = factory->createCaptureEvent();
                 if (!captureEvent) return NO_MEMORY;
 
-                initializeCaptureEvent(captureEvent, &mMsg);
+                initializeCaptureEvent(*captureEvent, mMsg);
                 *outSeq = mMsg.header.seq;
                 *outEvent = captureEvent;
                 break;
@@ -1007,7 +1061,7 @@
                 DragEvent* dragEvent = factory->createDragEvent();
                 if (!dragEvent) return NO_MEMORY;
 
-                initializeDragEvent(dragEvent, &mMsg);
+                initializeDragEvent(*dragEvent, mMsg);
                 *outSeq = mMsg.header.seq;
                 *outEvent = dragEvent;
                 break;
@@ -1017,7 +1071,7 @@
                 TouchModeEvent* touchModeEvent = factory->createTouchModeEvent();
                 if (!touchModeEvent) return NO_MEMORY;
 
-                initializeTouchModeEvent(touchModeEvent, &mMsg);
+                initializeTouchModeEvent(*touchModeEvent, mMsg);
                 *outSeq = mMsg.header.seq;
                 *outEvent = touchModeEvent;
                 break;
@@ -1079,9 +1133,9 @@
             seqChain.seq = msg.header.seq;
             seqChain.chain = chain;
             mSeqChains.push_back(seqChain);
-            addSample(motionEvent, &msg);
+            addSample(*motionEvent, msg);
         } else {
-            initializeMotionEvent(motionEvent, &msg);
+            initializeMotionEvent(*motionEvent, msg);
         }
         chain = msg.header.seq;
     }
@@ -1262,7 +1316,7 @@
                      delta);
             return;
         }
-        nsecs_t maxPredict = current->eventTime + min(delta / 2, RESAMPLE_MAX_PREDICTION);
+        nsecs_t maxPredict = current->eventTime + std::min(delta / 2, RESAMPLE_MAX_PREDICTION);
         if (sampleTime > maxPredict) {
             ALOGD_IF(debugResampling(),
                      "Sample time is too far in the future, adjusting prediction "
@@ -1465,69 +1519,6 @@
     return -1;
 }
 
-void InputConsumer::initializeKeyEvent(KeyEvent* event, const InputMessage* msg) {
-    event->initialize(msg->body.key.eventId, msg->body.key.deviceId, msg->body.key.source,
-                      msg->body.key.displayId, msg->body.key.hmac, msg->body.key.action,
-                      msg->body.key.flags, msg->body.key.keyCode, msg->body.key.scanCode,
-                      msg->body.key.metaState, msg->body.key.repeatCount, msg->body.key.downTime,
-                      msg->body.key.eventTime);
-}
-
-void InputConsumer::initializeFocusEvent(FocusEvent* event, const InputMessage* msg) {
-    event->initialize(msg->body.focus.eventId, msg->body.focus.hasFocus);
-}
-
-void InputConsumer::initializeCaptureEvent(CaptureEvent* event, const InputMessage* msg) {
-    event->initialize(msg->body.capture.eventId, msg->body.capture.pointerCaptureEnabled);
-}
-
-void InputConsumer::initializeDragEvent(DragEvent* event, const InputMessage* msg) {
-    event->initialize(msg->body.drag.eventId, msg->body.drag.x, msg->body.drag.y,
-                      msg->body.drag.isExiting);
-}
-
-void InputConsumer::initializeMotionEvent(MotionEvent* event, const InputMessage* msg) {
-    uint32_t pointerCount = msg->body.motion.pointerCount;
-    PointerProperties pointerProperties[pointerCount];
-    PointerCoords pointerCoords[pointerCount];
-    for (uint32_t i = 0; i < pointerCount; i++) {
-        pointerProperties[i] = msg->body.motion.pointers[i].properties;
-        pointerCoords[i] = msg->body.motion.pointers[i].coords;
-    }
-
-    ui::Transform transform;
-    transform.set({msg->body.motion.dsdx, msg->body.motion.dtdx, msg->body.motion.tx,
-                   msg->body.motion.dtdy, msg->body.motion.dsdy, msg->body.motion.ty, 0, 0, 1});
-    ui::Transform displayTransform;
-    displayTransform.set({msg->body.motion.dsdxRaw, msg->body.motion.dtdxRaw,
-                          msg->body.motion.txRaw, msg->body.motion.dtdyRaw,
-                          msg->body.motion.dsdyRaw, msg->body.motion.tyRaw, 0, 0, 1});
-    event->initialize(msg->body.motion.eventId, msg->body.motion.deviceId, msg->body.motion.source,
-                      msg->body.motion.displayId, msg->body.motion.hmac, msg->body.motion.action,
-                      msg->body.motion.actionButton, msg->body.motion.flags,
-                      msg->body.motion.edgeFlags, msg->body.motion.metaState,
-                      msg->body.motion.buttonState, msg->body.motion.classification, transform,
-                      msg->body.motion.xPrecision, msg->body.motion.yPrecision,
-                      msg->body.motion.xCursorPosition, msg->body.motion.yCursorPosition,
-                      displayTransform, msg->body.motion.downTime, msg->body.motion.eventTime,
-                      pointerCount, pointerProperties, pointerCoords);
-}
-
-void InputConsumer::initializeTouchModeEvent(TouchModeEvent* event, const InputMessage* msg) {
-    event->initialize(msg->body.touchMode.eventId, msg->body.touchMode.isInTouchMode);
-}
-
-void InputConsumer::addSample(MotionEvent* event, const InputMessage* msg) {
-    uint32_t pointerCount = msg->body.motion.pointerCount;
-    PointerCoords pointerCoords[pointerCount];
-    for (uint32_t i = 0; i < pointerCount; i++) {
-        pointerCoords[i] = msg->body.motion.pointers[i].coords;
-    }
-
-    event->setMetaState(event->getMetaState() | msg->body.motion.metaState);
-    event->addSample(msg->body.motion.eventTime, pointerCoords);
-}
-
 bool InputConsumer::canAddSample(const Batch& batch, const InputMessage *msg) {
     const InputMessage& head = batch.samples[0];
     uint32_t pointerCount = msg->body.motion.pointerCount;
diff --git a/libs/input/MotionPredictorMetricsManager.cpp b/libs/input/MotionPredictorMetricsManager.cpp
index 0412d08..6872af2 100644
--- a/libs/input/MotionPredictorMetricsManager.cpp
+++ b/libs/input/MotionPredictorMetricsManager.cpp
@@ -113,7 +113,12 @@
 // Adds new predictions to mRecentPredictions and maintains the invariant that elements are
 // sorted in ascending order of targetTimestamp.
 void MotionPredictorMetricsManager::onPredict(const MotionEvent& predictionEvent) {
-    for (size_t i = 0; i < predictionEvent.getHistorySize() + 1; ++i) {
+    const size_t numPredictions = predictionEvent.getHistorySize() + 1;
+    if (numPredictions > mMaxNumPredictions) {
+        LOG(WARNING) << "numPredictions (" << numPredictions << ") > mMaxNumPredictions ("
+                     << mMaxNumPredictions << "). Ignoring extra predictions in metrics.";
+    }
+    for (size_t i = 0; (i < numPredictions) && (i < mMaxNumPredictions); ++i) {
         // Convert MotionEvent to PredictionPoint.
         const PointerCoords* coords =
                 predictionEvent.getHistoricalRawPointerCoords(/*pointerIndex=*/0, i);
@@ -325,42 +330,44 @@
             mAtomFields[i].highVelocityOffTrajectoryRmse =
                     static_cast<int>(offTrajectoryRmse * 1000);
         }
+    }
 
-        // Scale-invariant errors: reported only for the last time bucket, where the values
-        // represent an average across all time buckets.
-        if (i + 1 == mMaxNumPredictions) {
-            // Compute error averages.
-            float alongTrajectoryRmseSum = 0;
-            float offTrajectoryRmseSum = 0;
-            for (size_t j = 0; j < mAggregatedMetrics.size(); ++j) {
-                // If we have general errors (checked above), we should always also have
-                // scale-invariant errors.
-                LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[j].scaleInvariantErrorsCount == 0,
-                                    "mAggregatedMetrics[%zu].scaleInvariantErrorsCount is 0", j);
-
-                LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[j].scaleInvariantAlongTrajectorySse < 0,
-                                    "mAggregatedMetrics[%zu].scaleInvariantAlongTrajectorySse = %f "
-                                    "should not be negative",
-                                    j, mAggregatedMetrics[j].scaleInvariantAlongTrajectorySse);
-                alongTrajectoryRmseSum +=
-                        std::sqrt(mAggregatedMetrics[j].scaleInvariantAlongTrajectorySse /
-                                  mAggregatedMetrics[j].scaleInvariantErrorsCount);
-
-                LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[j].scaleInvariantOffTrajectorySse < 0,
-                                    "mAggregatedMetrics[%zu].scaleInvariantOffTrajectorySse = %f "
-                                    "should not be negative",
-                                    j, mAggregatedMetrics[j].scaleInvariantOffTrajectorySse);
-                offTrajectoryRmseSum +=
-                        std::sqrt(mAggregatedMetrics[j].scaleInvariantOffTrajectorySse /
-                                  mAggregatedMetrics[j].scaleInvariantErrorsCount);
+    // Scale-invariant errors: the average scale-invariant error across all time buckets
+    // is reported in the last time bucket.
+    {
+        // Compute error averages.
+        float alongTrajectoryRmseSum = 0;
+        float offTrajectoryRmseSum = 0;
+        int bucket_count = 0;
+        for (size_t j = 0; j < mAggregatedMetrics.size(); ++j) {
+            if (mAggregatedMetrics[j].scaleInvariantErrorsCount == 0) {
+                continue;
             }
 
-            const float averageAlongTrajectoryRmse =
-                    alongTrajectoryRmseSum / mAggregatedMetrics.size();
+            LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[j].scaleInvariantAlongTrajectorySse < 0,
+                                "mAggregatedMetrics[%zu].scaleInvariantAlongTrajectorySse = %f "
+                                "should not be negative",
+                                j, mAggregatedMetrics[j].scaleInvariantAlongTrajectorySse);
+            alongTrajectoryRmseSum +=
+                    std::sqrt(mAggregatedMetrics[j].scaleInvariantAlongTrajectorySse /
+                              mAggregatedMetrics[j].scaleInvariantErrorsCount);
+
+            LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[j].scaleInvariantOffTrajectorySse < 0,
+                                "mAggregatedMetrics[%zu].scaleInvariantOffTrajectorySse = %f "
+                                "should not be negative",
+                                j, mAggregatedMetrics[j].scaleInvariantOffTrajectorySse);
+            offTrajectoryRmseSum += std::sqrt(mAggregatedMetrics[j].scaleInvariantOffTrajectorySse /
+                                              mAggregatedMetrics[j].scaleInvariantErrorsCount);
+
+            ++bucket_count;
+        }
+
+        if (bucket_count > 0) {
+            const float averageAlongTrajectoryRmse = alongTrajectoryRmseSum / bucket_count;
             mAtomFields.back().scaleInvariantAlongTrajectoryRmse =
                     static_cast<int>(averageAlongTrajectoryRmse * 1000);
 
-            const float averageOffTrajectoryRmse = offTrajectoryRmseSum / mAggregatedMetrics.size();
+            const float averageOffTrajectoryRmse = offTrajectoryRmseSum / bucket_count;
             mAtomFields.back().scaleInvariantOffTrajectoryRmse =
                     static_cast<int>(averageOffTrajectoryRmse * 1000);
         }
diff --git a/libs/input/tests/InputChannel_test.cpp b/libs/input/tests/InputChannel_test.cpp
index 60feb53..02d4c07 100644
--- a/libs/input/tests/InputChannel_test.cpp
+++ b/libs/input/tests/InputChannel_test.cpp
@@ -16,8 +16,6 @@
 
 #include <array>
 
-#include "TestHelpers.h"
-
 #include <unistd.h>
 #include <time.h>
 #include <errno.h>
diff --git a/libs/input/tests/InputEvent_test.cpp b/libs/input/tests/InputEvent_test.cpp
index 540766d..0df06b7 100644
--- a/libs/input/tests/InputEvent_test.cpp
+++ b/libs/input/tests/InputEvent_test.cpp
@@ -371,8 +371,10 @@
     ASSERT_EQ(AMOTION_EVENT_BUTTON_PRIMARY, event->getButtonState());
     ASSERT_EQ(MotionClassification::NONE, event->getClassification());
     EXPECT_EQ(mTransform, event->getTransform());
-    ASSERT_EQ(X_OFFSET, event->getXOffset());
-    ASSERT_EQ(Y_OFFSET, event->getYOffset());
+    ASSERT_NEAR((-RAW_X_OFFSET / RAW_X_SCALE) * X_SCALE + X_OFFSET, event->getRawXOffset(),
+                EPSILON);
+    ASSERT_NEAR((-RAW_Y_OFFSET / RAW_Y_SCALE) * Y_SCALE + Y_OFFSET, event->getRawYOffset(),
+                EPSILON);
     ASSERT_EQ(2.0f, event->getXPrecision());
     ASSERT_EQ(2.1f, event->getYPrecision());
     ASSERT_EQ(ARBITRARY_DOWN_TIME, event->getDownTime());
@@ -709,22 +711,26 @@
 TEST_F(MotionEventTest, OffsetLocation) {
     MotionEvent event;
     initializeEventWithHistory(&event);
+    const float xOffset = event.getRawXOffset();
+    const float yOffset = event.getRawYOffset();
 
     event.offsetLocation(5.0f, -2.0f);
 
-    ASSERT_EQ(X_OFFSET + 5.0f, event.getXOffset());
-    ASSERT_EQ(Y_OFFSET - 2.0f, event.getYOffset());
+    ASSERT_EQ(xOffset + 5.0f, event.getRawXOffset());
+    ASSERT_EQ(yOffset - 2.0f, event.getRawYOffset());
 }
 
 TEST_F(MotionEventTest, Scale) {
     MotionEvent event;
     initializeEventWithHistory(&event);
     const float unscaledOrientation = event.getOrientation(0);
+    const float unscaledXOffset = event.getRawXOffset();
+    const float unscaledYOffset = event.getRawYOffset();
 
     event.scale(2.0f);
 
-    ASSERT_EQ(X_OFFSET * 2, event.getXOffset());
-    ASSERT_EQ(Y_OFFSET * 2, event.getYOffset());
+    ASSERT_EQ(unscaledXOffset * 2, event.getRawXOffset());
+    ASSERT_EQ(unscaledYOffset * 2, event.getRawYOffset());
 
     ASSERT_NEAR((RAW_X_OFFSET + 210 * RAW_X_SCALE) * 2, event.getRawX(0), EPSILON);
     ASSERT_NEAR((RAW_Y_OFFSET + 211 * RAW_Y_SCALE) * 2, event.getRawY(0), EPSILON);
diff --git a/libs/input/tests/InputPublisherAndConsumer_test.cpp b/libs/input/tests/InputPublisherAndConsumer_test.cpp
index 3543020..b5fab49 100644
--- a/libs/input/tests/InputPublisherAndConsumer_test.cpp
+++ b/libs/input/tests/InputPublisherAndConsumer_test.cpp
@@ -14,8 +14,6 @@
  * limitations under the License.
  */
 
-#include "TestHelpers.h"
-
 #include <attestation/HmacKeyManager.h>
 #include <gtest/gtest.h>
 #include <gui/constants.h>
@@ -135,8 +133,10 @@
     EXPECT_EQ(args.buttonState, motionEvent.getButtonState());
     EXPECT_EQ(args.classification, motionEvent.getClassification());
     EXPECT_EQ(args.transform, motionEvent.getTransform());
-    EXPECT_EQ(args.xOffset, motionEvent.getXOffset());
-    EXPECT_EQ(args.yOffset, motionEvent.getYOffset());
+    EXPECT_NEAR((-args.rawXOffset / args.rawXScale) * args.xScale + args.xOffset,
+                motionEvent.getRawXOffset(), EPSILON);
+    EXPECT_NEAR((-args.rawYOffset / args.rawYScale) * args.yScale + args.yOffset,
+                motionEvent.getRawYOffset(), EPSILON);
     EXPECT_EQ(args.xPrecision, motionEvent.getXPrecision());
     EXPECT_EQ(args.yPrecision, motionEvent.getYPrecision());
     EXPECT_NEAR(args.xCursorPosition, motionEvent.getRawXCursorPosition(), EPSILON);
diff --git a/libs/input/tests/MotionPredictorMetricsManager_test.cpp b/libs/input/tests/MotionPredictorMetricsManager_test.cpp
index 31cc145..cc41eeb 100644
--- a/libs/input/tests/MotionPredictorMetricsManager_test.cpp
+++ b/libs/input/tests/MotionPredictorMetricsManager_test.cpp
@@ -238,14 +238,17 @@
 
 // --- Ground-truth-generation helper functions. ---
 
+// Generates numPoints ground truth points with values equal to those of the given
+// GroundTruthPoint, and with consecutive timestamps separated by the given inputInterval.
 std::vector<GroundTruthPoint> generateConstantGroundTruthPoints(
-        const GroundTruthPoint& groundTruthPoint, size_t numPoints) {
+        const GroundTruthPoint& groundTruthPoint, size_t numPoints,
+        nsecs_t inputInterval = TEST_PREDICTION_INTERVAL_NANOS) {
     std::vector<GroundTruthPoint> groundTruthPoints;
     nsecs_t timestamp = groundTruthPoint.timestamp;
     for (size_t i = 0; i < numPoints; ++i) {
         groundTruthPoints.emplace_back(groundTruthPoint);
         groundTruthPoints.back().timestamp = timestamp;
-        timestamp += TEST_PREDICTION_INTERVAL_NANOS;
+        timestamp += inputInterval;
     }
     return groundTruthPoints;
 }
@@ -280,7 +283,8 @@
     const GroundTruthPoint groundTruthPoint{{.position = Eigen::Vector2f(10, 20), .pressure = 0.3f},
                                             .timestamp = TEST_INITIAL_TIMESTAMP};
     const std::vector<GroundTruthPoint> groundTruthPoints =
-            generateConstantGroundTruthPoints(groundTruthPoint, /*numPoints=*/3);
+            generateConstantGroundTruthPoints(groundTruthPoint, /*numPoints=*/3,
+                                              /*inputInterval=*/10);
 
     ASSERT_EQ(3u, groundTruthPoints.size());
     // First point.
@@ -290,11 +294,11 @@
     // Second point.
     EXPECT_EQ(groundTruthPoints[1].position, groundTruthPoint.position);
     EXPECT_EQ(groundTruthPoints[1].pressure, groundTruthPoint.pressure);
-    EXPECT_GT(groundTruthPoints[1].timestamp, groundTruthPoints[0].timestamp);
+    EXPECT_EQ(groundTruthPoints[1].timestamp, groundTruthPoint.timestamp + 10);
     // Third point.
     EXPECT_EQ(groundTruthPoints[2].position, groundTruthPoint.position);
     EXPECT_EQ(groundTruthPoints[2].pressure, groundTruthPoint.pressure);
-    EXPECT_GT(groundTruthPoints[2].timestamp, groundTruthPoints[1].timestamp);
+    EXPECT_EQ(groundTruthPoints[2].timestamp, groundTruthPoint.timestamp + 20);
 }
 
 TEST(GenerateCircularArcGroundTruthTest, StraightLineUpwards) {
@@ -333,16 +337,19 @@
 
 // --- Prediction-generation helper functions. ---
 
-// Creates a sequence of predictions with values equal to those of the given GroundTruthPoint.
-std::vector<PredictionPoint> generateConstantPredictions(const GroundTruthPoint& groundTruthPoint) {
+// Generates TEST_MAX_NUM_PREDICTIONS predictions with values equal to those of the given
+// GroundTruthPoint, and with consecutive timestamps separated by the given predictionInterval.
+std::vector<PredictionPoint> generateConstantPredictions(
+        const GroundTruthPoint& groundTruthPoint,
+        nsecs_t predictionInterval = TEST_PREDICTION_INTERVAL_NANOS) {
     std::vector<PredictionPoint> predictions;
-    nsecs_t predictionTimestamp = groundTruthPoint.timestamp + TEST_PREDICTION_INTERVAL_NANOS;
+    nsecs_t predictionTimestamp = groundTruthPoint.timestamp + predictionInterval;
     for (size_t j = 0; j < TEST_MAX_NUM_PREDICTIONS; ++j) {
         predictions.push_back(PredictionPoint{{.position = groundTruthPoint.position,
                                                .pressure = groundTruthPoint.pressure},
                                               .originTimestamp = groundTruthPoint.timestamp,
                                               .targetTimestamp = predictionTimestamp});
-        predictionTimestamp += TEST_PREDICTION_INTERVAL_NANOS;
+        predictionTimestamp += predictionInterval;
     }
     return predictions;
 }
@@ -375,8 +382,9 @@
 TEST(GeneratePredictionsTest, GenerateConstantPredictions) {
     const GroundTruthPoint groundTruthPoint{{.position = Eigen::Vector2f(10, 20), .pressure = 0.3f},
                                             .timestamp = TEST_INITIAL_TIMESTAMP};
+    const nsecs_t predictionInterval = 10;
     const std::vector<PredictionPoint> predictionPoints =
-            generateConstantPredictions(groundTruthPoint);
+            generateConstantPredictions(groundTruthPoint, predictionInterval);
 
     ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, predictionPoints.size());
     for (size_t i = 0; i < predictionPoints.size(); ++i) {
@@ -385,8 +393,7 @@
         EXPECT_THAT(predictionPoints[i].pressure, FloatNear(groundTruthPoint.pressure, 1e-6));
         EXPECT_EQ(predictionPoints[i].originTimestamp, groundTruthPoint.timestamp);
         EXPECT_EQ(predictionPoints[i].targetTimestamp,
-                  groundTruthPoint.timestamp +
-                          static_cast<nsecs_t>(i + 1) * TEST_PREDICTION_INTERVAL_NANOS);
+                  TEST_INITIAL_TIMESTAMP + static_cast<nsecs_t>(i + 1) * predictionInterval);
     }
 }
 
@@ -678,12 +685,9 @@
 //  • groundTruthPoints: chronologically-ordered ground truth points, with at least 2 elements.
 //  • predictionPoints: the first index points to a vector of predictions corresponding to the
 //    source ground truth point with the same index.
-//     - The first element should be empty, because there are not expected to be predictions until
-//       we have received 2 ground truth points.
-//     - The last element may be empty, because there will be no future ground truth points to
-//       associate with those predictions (if not empty, it will be ignored).
+//     - For empty prediction vectors, MetricsManager::onPredict will not be called.
 //     - To test all prediction buckets, there should be at least TEST_MAX_NUM_PREDICTIONS non-empty
-//       prediction sets (that is, excluding the first and last). Thus, groundTruthPoints and
+//       prediction vectors (that is, excluding the first and last). Thus, groundTruthPoints and
 //       predictionPoints should have size at least TEST_MAX_NUM_PREDICTIONS + 2.
 //
 // When the function returns, outReportedAtomFields will contain the reported AtomFields.
@@ -697,19 +701,12 @@
                                                  createMockReportAtomFunction(
                                                          outReportedAtomFields));
 
-    // Validate structure of groundTruthPoints and predictionPoints.
-    ASSERT_EQ(predictionPoints.size(), groundTruthPoints.size());
     ASSERT_GE(groundTruthPoints.size(), 2u);
-    ASSERT_EQ(predictionPoints[0].size(), 0u);
-    for (size_t i = 1; i + 1 < predictionPoints.size(); ++i) {
-        SCOPED_TRACE(testing::Message() << "i = " << i);
-        ASSERT_EQ(predictionPoints[i].size(), TEST_MAX_NUM_PREDICTIONS);
-    }
+    ASSERT_EQ(predictionPoints.size(), groundTruthPoints.size());
 
-    // Pass ground truth points and predictions (for all except first and last ground truth).
     for (size_t i = 0; i < groundTruthPoints.size(); ++i) {
         metricsManager.onRecord(makeMotionEvent(groundTruthPoints[i]));
-        if ((i > 0) && (i + 1 < predictionPoints.size())) {
+        if (!predictionPoints[i].empty()) {
             metricsManager.onPredict(makeMotionEvent(predictionPoints[i]));
         }
     }
@@ -738,7 +735,7 @@
 // Perfect predictions test:
 //  • Input: constant input events, perfect predictions matching the input events.
 //  • Expectation: all error metrics should be zero, or NO_DATA_SENTINEL for "unreported" metrics.
-//    (For example, scale-invariant errors are only reported for the final time bucket.)
+//    (For example, scale-invariant errors are only reported for the last time bucket.)
 TEST(MotionPredictorMetricsManagerTest, ConstantGroundTruthPerfectPredictions) {
     GroundTruthPoint groundTruthPoint{{.position = Eigen::Vector2f(10.0f, 20.0f), .pressure = 0.6f},
                                       .timestamp = TEST_INITIAL_TIMESTAMP};
@@ -977,5 +974,35 @@
     }
 }
 
+// Robustness test:
+//  • Input: input events separated by a significantly greater time interval than the interval
+//    between predictions.
+//  • Expectation: the MetricsManager should not crash in this case. (No assertions are made about
+//    the resulting metrics.)
+//
+// In practice, this scenario could arise either if the input and prediction intervals are
+// mismatched, or if input events are missing (dropped or skipped for some reason).
+TEST(MotionPredictorMetricsManagerTest, MismatchedInputAndPredictionInterval) {
+    // Create two ground truth points separated by MAX_NUM_PREDICTIONS * PREDICTION_INTERVAL,
+    // so that the second ground truth point corresponds to the last prediction bucket. This
+    // ensures that the scale-invariant error codepath will be run, giving full code coverage.
+    GroundTruthPoint groundTruthPoint{{.position = Eigen::Vector2f(0.0f, 0.0f), .pressure = 0.5f},
+                                      .timestamp = TEST_INITIAL_TIMESTAMP};
+    const nsecs_t inputInterval = TEST_MAX_NUM_PREDICTIONS * TEST_PREDICTION_INTERVAL_NANOS;
+    const std::vector<GroundTruthPoint> groundTruthPoints =
+            generateConstantGroundTruthPoints(groundTruthPoint, /*numPoints=*/2, inputInterval);
+
+    // Create predictions separated by the prediction interval.
+    std::vector<std::vector<PredictionPoint>> predictionPoints;
+    for (size_t i = 0; i < groundTruthPoints.size(); ++i) {
+        predictionPoints.push_back(
+                generateConstantPredictions(groundTruthPoints[i], TEST_PREDICTION_INTERVAL_NANOS));
+    }
+
+    // Test that we can run the MetricsManager without crashing.
+    std::vector<AtomFields> reportedAtomFields;
+    runMetricsManager(groundTruthPoints, predictionPoints, reportedAtomFields);
+}
+
 } // namespace
 } // namespace android
diff --git a/libs/input/tests/TestHelpers.h b/libs/input/tests/TestHelpers.h
deleted file mode 100644
index 343d81f..0000000
--- a/libs/input/tests/TestHelpers.h
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Copyright (C) 2010 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef TESTHELPERS_H
-#define TESTHELPERS_H
-
-#include <unistd.h>
-
-#include <utils/threads.h>
-
-namespace android {
-
-class Pipe {
-public:
-    int sendFd;
-    int receiveFd;
-
-    Pipe() {
-        int fds[2];
-        ::pipe(fds);
-
-        receiveFd = fds[0];
-        sendFd = fds[1];
-    }
-
-    ~Pipe() {
-        if (sendFd != -1) {
-            ::close(sendFd);
-        }
-
-        if (receiveFd != -1) {
-            ::close(receiveFd);
-        }
-    }
-
-    status_t writeSignal() {
-        ssize_t nWritten = ::write(sendFd, "*", 1);
-        return nWritten == 1 ? 0 : -errno;
-    }
-
-    status_t readSignal() {
-        char buf[1];
-        ssize_t nRead = ::read(receiveFd, buf, 1);
-        return nRead == 1 ? 0 : nRead == 0 ? -EPIPE : -errno;
-    }
-};
-
-class DelayedTask : public Thread {
-    int mDelayMillis;
-
-public:
-    explicit DelayedTask(int delayMillis) : mDelayMillis(delayMillis) { }
-
-protected:
-    virtual ~DelayedTask() { }
-
-    virtual void doTask() = 0;
-
-    virtual bool threadLoop() {
-        usleep(mDelayMillis * 1000);
-        doTask();
-        return false;
-    }
-};
-
-} // namespace android
-
-#endif // TESTHELPERS_H
diff --git a/libs/input/tests/TouchResampling_test.cpp b/libs/input/tests/TouchResampling_test.cpp
index 1cb7f7b..0b0bb63 100644
--- a/libs/input/tests/TouchResampling_test.cpp
+++ b/libs/input/tests/TouchResampling_test.cpp
@@ -14,8 +14,6 @@
  * limitations under the License.
  */
 
-#include "TestHelpers.h"
-
 #include <chrono>
 #include <vector>
 
diff --git a/libs/math/include/math/mat4.h b/libs/math/include/math/mat4.h
index 6119ba7..c630d97 100644
--- a/libs/math/include/math/mat4.h
+++ b/libs/math/include/math/mat4.h
@@ -34,6 +34,14 @@
 #define CONSTEXPR
 #endif
 
+#ifdef _WIN32
+// windows.h contains obsolete defines of 'near' and 'far' for systems using
+// legacy 16 bit pointers. Undefine them to avoid conflicting with the usage of
+// 'near' and 'far' in this file.
+#undef near
+#undef far
+#endif
+
 namespace android {
 // -------------------------------------------------------------------------------------
 namespace details {
diff --git a/libs/nativedisplay/AChoreographer.cpp b/libs/nativedisplay/AChoreographer.cpp
index 8f005a5..bed31e2 100644
--- a/libs/nativedisplay/AChoreographer.cpp
+++ b/libs/nativedisplay/AChoreographer.cpp
@@ -148,29 +148,31 @@
 void AChoreographer_postFrameCallback(AChoreographer* choreographer,
                                       AChoreographer_frameCallback callback, void* data) {
     AChoreographer_to_Choreographer(choreographer)
-            ->postFrameCallbackDelayed(callback, nullptr, nullptr, data, 0);
+            ->postFrameCallbackDelayed(callback, nullptr, nullptr, data, 0, CALLBACK_ANIMATION);
 }
 void AChoreographer_postFrameCallbackDelayed(AChoreographer* choreographer,
                                              AChoreographer_frameCallback callback, void* data,
                                              long delayMillis) {
     AChoreographer_to_Choreographer(choreographer)
-            ->postFrameCallbackDelayed(callback, nullptr, nullptr, data, ms2ns(delayMillis));
+            ->postFrameCallbackDelayed(callback, nullptr, nullptr, data, ms2ns(delayMillis),
+                                       CALLBACK_ANIMATION);
 }
 void AChoreographer_postVsyncCallback(AChoreographer* choreographer,
                                       AChoreographer_vsyncCallback callback, void* data) {
     AChoreographer_to_Choreographer(choreographer)
-            ->postFrameCallbackDelayed(nullptr, nullptr, callback, data, 0);
+            ->postFrameCallbackDelayed(nullptr, nullptr, callback, data, 0, CALLBACK_ANIMATION);
 }
 void AChoreographer_postFrameCallback64(AChoreographer* choreographer,
                                         AChoreographer_frameCallback64 callback, void* data) {
     AChoreographer_to_Choreographer(choreographer)
-            ->postFrameCallbackDelayed(nullptr, callback, nullptr, data, 0);
+            ->postFrameCallbackDelayed(nullptr, callback, nullptr, data, 0, CALLBACK_ANIMATION);
 }
 void AChoreographer_postFrameCallbackDelayed64(AChoreographer* choreographer,
                                                AChoreographer_frameCallback64 callback, void* data,
                                                uint32_t delayMillis) {
     AChoreographer_to_Choreographer(choreographer)
-            ->postFrameCallbackDelayed(nullptr, callback, nullptr, data, ms2ns(delayMillis));
+            ->postFrameCallbackDelayed(nullptr, callback, nullptr, data, ms2ns(delayMillis),
+                                       CALLBACK_ANIMATION);
 }
 void AChoreographer_registerRefreshRateCallback(AChoreographer* choreographer,
                                                 AChoreographer_refreshRateCallback callback,
diff --git a/libs/renderengine/Android.bp b/libs/renderengine/Android.bp
index b501d40..09d7cb5 100644
--- a/libs/renderengine/Android.bp
+++ b/libs/renderengine/Android.bp
@@ -87,6 +87,7 @@
         "skia/SkiaRenderEngine.cpp",
         "skia/SkiaGLRenderEngine.cpp",
         "skia/SkiaVkRenderEngine.cpp",
+        "skia/VulkanInterface.cpp",
         "skia/debug/CaptureTimer.cpp",
         "skia/debug/CommonPool.cpp",
         "skia/debug/SkiaCapture.cpp",
diff --git a/libs/renderengine/skia/SkiaVkRenderEngine.cpp b/libs/renderengine/skia/SkiaVkRenderEngine.cpp
index eb7a9d5..f2f1b5d 100644
--- a/libs/renderengine/skia/SkiaVkRenderEngine.cpp
+++ b/libs/renderengine/skia/SkiaVkRenderEngine.cpp
@@ -25,6 +25,7 @@
 #include <GrContextOptions.h>
 #include <vk/GrVkExtensions.h>
 #include <vk/GrVkTypes.h>
+#include <include/gpu/ganesh/vk/GrVkBackendSemaphore.h>
 #include <include/gpu/ganesh/vk/GrVkDirectContext.h>
 
 #include <android-base/stringprintf.h>
@@ -32,11 +33,8 @@
 #include <sync/sync.h>
 #include <utils/Trace.h>
 
-#include <cstdint>
 #include <memory>
-#include <sstream>
 #include <string>
-#include <vector>
 
 #include <vulkan/vulkan.h>
 #include "log/log_main.h"
@@ -44,619 +42,19 @@
 namespace android {
 namespace renderengine {
 
-struct VulkanFuncs {
-    PFN_vkCreateSemaphore vkCreateSemaphore = nullptr;
-    PFN_vkImportSemaphoreFdKHR vkImportSemaphoreFdKHR = nullptr;
-    PFN_vkGetSemaphoreFdKHR vkGetSemaphoreFdKHR = nullptr;
-    PFN_vkDestroySemaphore vkDestroySemaphore = nullptr;
-
-    PFN_vkDeviceWaitIdle vkDeviceWaitIdle = nullptr;
-    PFN_vkDestroyDevice vkDestroyDevice = nullptr;
-    PFN_vkDestroyInstance vkDestroyInstance = nullptr;
-};
-
-// Ref-Count a semaphore
-struct DestroySemaphoreInfo {
-    VkSemaphore mSemaphore;
-    // We need to make sure we don't delete the VkSemaphore until it is done being used by both Skia
-    // (including by the GPU) and inside SkiaVkRenderEngine. So we always start with two refs, one
-    // owned by Skia and one owned by the SkiaVkRenderEngine. The refs are decremented each time
-    // delete_semaphore* is called with this object. Skia will call destroy_semaphore* once it is
-    // done with the semaphore and the GPU has finished work on the semaphore. SkiaVkRenderEngine
-    // calls delete_semaphore* after sending the semaphore to Skia and exporting it if need be.
-    int mRefs = 2;
-
-    DestroySemaphoreInfo(VkSemaphore semaphore) : mSemaphore(semaphore) {}
-};
-
-namespace {
-void onVkDeviceFault(void* callbackContext, const std::string& description,
-                     const std::vector<VkDeviceFaultAddressInfoEXT>& addressInfos,
-                     const std::vector<VkDeviceFaultVendorInfoEXT>& vendorInfos,
-                     const std::vector<std::byte>& vendorBinaryData);
-} // anonymous namespace
-
-struct VulkanInterface {
-    bool initialized = false;
-    VkInstance instance;
-    VkPhysicalDevice physicalDevice;
-    VkDevice device;
-    VkQueue queue;
-    int queueIndex;
-    uint32_t apiVersion;
-    GrVkExtensions grExtensions;
-    VkPhysicalDeviceFeatures2* physicalDeviceFeatures2 = nullptr;
-    VkPhysicalDeviceSamplerYcbcrConversionFeatures* samplerYcbcrConversionFeatures = nullptr;
-    VkPhysicalDeviceProtectedMemoryFeatures* protectedMemoryFeatures = nullptr;
-    VkPhysicalDeviceFaultFeaturesEXT* deviceFaultFeatures = nullptr;
-    GrVkGetProc grGetProc;
-    bool isProtected;
-    bool isRealtimePriority;
-
-    VulkanFuncs funcs;
-
-    std::vector<std::string> instanceExtensionNames;
-    std::vector<std::string> deviceExtensionNames;
-
-    GrVkBackendContext getBackendContext() {
-        GrVkBackendContext backendContext;
-        backendContext.fInstance = instance;
-        backendContext.fPhysicalDevice = physicalDevice;
-        backendContext.fDevice = device;
-        backendContext.fQueue = queue;
-        backendContext.fGraphicsQueueIndex = queueIndex;
-        backendContext.fMaxAPIVersion = apiVersion;
-        backendContext.fVkExtensions = &grExtensions;
-        backendContext.fDeviceFeatures2 = physicalDeviceFeatures2;
-        backendContext.fGetProc = grGetProc;
-        backendContext.fProtectedContext = isProtected ? GrProtected::kYes : GrProtected::kNo;
-        backendContext.fDeviceLostContext = this; // VulkanInterface is long-lived
-        backendContext.fDeviceLostProc = onVkDeviceFault;
-        return backendContext;
-    };
-
-    VkSemaphore createExportableSemaphore() {
-        VkExportSemaphoreCreateInfo exportInfo;
-        exportInfo.sType = VK_STRUCTURE_TYPE_EXPORT_SEMAPHORE_CREATE_INFO;
-        exportInfo.pNext = nullptr;
-        exportInfo.handleTypes = VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT;
-
-        VkSemaphoreCreateInfo semaphoreInfo;
-        semaphoreInfo.sType = VK_STRUCTURE_TYPE_SEMAPHORE_CREATE_INFO;
-        semaphoreInfo.pNext = &exportInfo;
-        semaphoreInfo.flags = 0;
-
-        VkSemaphore semaphore;
-        VkResult err = funcs.vkCreateSemaphore(device, &semaphoreInfo, nullptr, &semaphore);
-        if (VK_SUCCESS != err) {
-            ALOGE("%s: failed to create semaphore. err %d\n", __func__, err);
-            return VK_NULL_HANDLE;
-        }
-
-        return semaphore;
-    }
-
-    // syncFd cannot be <= 0
-    VkSemaphore importSemaphoreFromSyncFd(int syncFd) {
-        VkSemaphoreCreateInfo semaphoreInfo;
-        semaphoreInfo.sType = VK_STRUCTURE_TYPE_SEMAPHORE_CREATE_INFO;
-        semaphoreInfo.pNext = nullptr;
-        semaphoreInfo.flags = 0;
-
-        VkSemaphore semaphore;
-        VkResult err = funcs.vkCreateSemaphore(device, &semaphoreInfo, nullptr, &semaphore);
-        if (VK_SUCCESS != err) {
-            ALOGE("%s: failed to create import semaphore", __func__);
-            return VK_NULL_HANDLE;
-        }
-
-        VkImportSemaphoreFdInfoKHR importInfo;
-        importInfo.sType = VK_STRUCTURE_TYPE_IMPORT_SEMAPHORE_FD_INFO_KHR;
-        importInfo.pNext = nullptr;
-        importInfo.semaphore = semaphore;
-        importInfo.flags = VK_SEMAPHORE_IMPORT_TEMPORARY_BIT;
-        importInfo.handleType = VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT;
-        importInfo.fd = syncFd;
-
-        err = funcs.vkImportSemaphoreFdKHR(device, &importInfo);
-        if (VK_SUCCESS != err) {
-            funcs.vkDestroySemaphore(device, semaphore, nullptr);
-            ALOGE("%s: failed to import semaphore", __func__);
-            return VK_NULL_HANDLE;
-        }
-
-        return semaphore;
-    }
-
-    int exportSemaphoreSyncFd(VkSemaphore semaphore) {
-        int res;
-
-        VkSemaphoreGetFdInfoKHR getFdInfo;
-        getFdInfo.sType = VK_STRUCTURE_TYPE_SEMAPHORE_GET_FD_INFO_KHR;
-        getFdInfo.pNext = nullptr;
-        getFdInfo.semaphore = semaphore;
-        getFdInfo.handleType = VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT;
-
-        VkResult err = funcs.vkGetSemaphoreFdKHR(device, &getFdInfo, &res);
-        if (VK_SUCCESS != err) {
-            ALOGE("%s: failed to export semaphore, err: %d", __func__, err);
-            return -1;
-        }
-        return res;
-    }
-
-    void destroySemaphore(VkSemaphore semaphore) {
-        funcs.vkDestroySemaphore(device, semaphore, nullptr);
-    }
-};
-
-namespace {
-void onVkDeviceFault(void* callbackContext, const std::string& description,
-                     const std::vector<VkDeviceFaultAddressInfoEXT>& addressInfos,
-                     const std::vector<VkDeviceFaultVendorInfoEXT>& vendorInfos,
-                     const std::vector<std::byte>& vendorBinaryData) {
-    VulkanInterface* interface = static_cast<VulkanInterface*>(callbackContext);
-    const std::string protectedStr = interface->isProtected ? "protected" : "non-protected";
-    // The final crash string should contain as much differentiating info as possible, up to 1024
-    // bytes. As this final message is constructed, the same information is also dumped to the logs
-    // but in a more verbose format. Building the crash string is unsightly, so the clearer logging
-    // statement is always placed first to give context.
-    ALOGE("VK_ERROR_DEVICE_LOST (%s context): %s", protectedStr.c_str(), description.c_str());
-    std::stringstream crashMsg;
-    crashMsg << "VK_ERROR_DEVICE_LOST (" << protectedStr;
-
-    if (!addressInfos.empty()) {
-        ALOGE("%zu VkDeviceFaultAddressInfoEXT:", addressInfos.size());
-        crashMsg << ", " << addressInfos.size() << " address info (";
-        for (VkDeviceFaultAddressInfoEXT addressInfo : addressInfos) {
-            ALOGE(" addressType:       %d", (int)addressInfo.addressType);
-            ALOGE("  reportedAddress:  %" PRIu64, addressInfo.reportedAddress);
-            ALOGE("  addressPrecision: %" PRIu64, addressInfo.addressPrecision);
-            crashMsg << addressInfo.addressType << ":"
-                     << addressInfo.reportedAddress << ":"
-                     << addressInfo.addressPrecision << ", ";
-        }
-        crashMsg.seekp(-2, crashMsg.cur); // Move back to overwrite trailing ", "
-        crashMsg << ")";
-    }
-
-    if (!vendorInfos.empty()) {
-        ALOGE("%zu VkDeviceFaultVendorInfoEXT:", vendorInfos.size());
-        crashMsg << ", " << vendorInfos.size() << " vendor info (";
-        for (VkDeviceFaultVendorInfoEXT vendorInfo : vendorInfos) {
-            ALOGE(" description:      %s", vendorInfo.description);
-            ALOGE("  vendorFaultCode: %" PRIu64, vendorInfo.vendorFaultCode);
-            ALOGE("  vendorFaultData: %" PRIu64, vendorInfo.vendorFaultData);
-            // Omit descriptions for individual vendor info structs in the crash string, as the
-            // fault code and fault data fields should be enough for clustering, and the verbosity
-            // isn't worth it. Additionally, vendors may just set the general description field of
-            // the overall fault to the description of the first element in this list, and that
-            // overall description will be placed at the end of the crash string.
-            crashMsg << vendorInfo.vendorFaultCode << ":"
-                     << vendorInfo.vendorFaultData << ", ";
-        }
-        crashMsg.seekp(-2, crashMsg.cur); // Move back to overwrite trailing ", "
-        crashMsg << ")";
-    }
-
-    if (!vendorBinaryData.empty()) {
-        // TODO: b/322830575 - Log in base64, or dump directly to a file that gets put in bugreports
-        ALOGE("%zu bytes of vendor-specific binary data (please notify Android's Core Graphics"
-              " Stack team if you observe this message).",
-              vendorBinaryData.size());
-        crashMsg << ", " << vendorBinaryData.size() << " bytes binary";
-    }
-
-    crashMsg << "): " << description;
-    LOG_ALWAYS_FATAL("%s", crashMsg.str().c_str());
-};
-} // anonymous namespace
-
-static GrVkGetProc sGetProc = [](const char* proc_name, VkInstance instance, VkDevice device) {
-    if (device != VK_NULL_HANDLE) {
-        return vkGetDeviceProcAddr(device, proc_name);
-    }
-    return vkGetInstanceProcAddr(instance, proc_name);
-};
-
-#define BAIL(fmt, ...)                                          \
-    {                                                           \
-        ALOGE("%s: " fmt ", bailing", __func__, ##__VA_ARGS__); \
-        return interface;                                       \
-    }
-
-#define CHECK_NONNULL(expr)       \
-    if ((expr) == nullptr) {      \
-        BAIL("[%s] null", #expr); \
-    }
-
-#define VK_CHECK(expr)                              \
-    if ((expr) != VK_SUCCESS) {                     \
-        BAIL("[%s] failed. err = %d", #expr, expr); \
-        return interface;                           \
-    }
-
-#define VK_GET_PROC(F)                                                           \
-    PFN_vk##F vk##F = (PFN_vk##F)vkGetInstanceProcAddr(VK_NULL_HANDLE, "vk" #F); \
-    CHECK_NONNULL(vk##F)
-#define VK_GET_INST_PROC(instance, F)                                      \
-    PFN_vk##F vk##F = (PFN_vk##F)vkGetInstanceProcAddr(instance, "vk" #F); \
-    CHECK_NONNULL(vk##F)
-#define VK_GET_DEV_PROC(device, F)                                     \
-    PFN_vk##F vk##F = (PFN_vk##F)vkGetDeviceProcAddr(device, "vk" #F); \
-    CHECK_NONNULL(vk##F)
-
-VulkanInterface initVulkanInterface(bool protectedContent = false) {
-    const nsecs_t timeBefore = systemTime();
-    VulkanInterface interface;
-
-    VK_GET_PROC(EnumerateInstanceVersion);
-    uint32_t instanceVersion;
-    VK_CHECK(vkEnumerateInstanceVersion(&instanceVersion));
-
-    if (instanceVersion < VK_MAKE_VERSION(1, 1, 0)) {
-        return interface;
-    }
-
-    const VkApplicationInfo appInfo = {
-            VK_STRUCTURE_TYPE_APPLICATION_INFO, nullptr, "surfaceflinger", 0, "android platform", 0,
-            VK_MAKE_VERSION(1, 1, 0),
-    };
-
-    VK_GET_PROC(EnumerateInstanceExtensionProperties);
-
-    uint32_t extensionCount = 0;
-    VK_CHECK(vkEnumerateInstanceExtensionProperties(nullptr, &extensionCount, nullptr));
-    std::vector<VkExtensionProperties> instanceExtensions(extensionCount);
-    VK_CHECK(vkEnumerateInstanceExtensionProperties(nullptr, &extensionCount,
-                                                    instanceExtensions.data()));
-    std::vector<const char*> enabledInstanceExtensionNames;
-    enabledInstanceExtensionNames.reserve(instanceExtensions.size());
-    interface.instanceExtensionNames.reserve(instanceExtensions.size());
-    for (const auto& instExt : instanceExtensions) {
-        enabledInstanceExtensionNames.push_back(instExt.extensionName);
-        interface.instanceExtensionNames.push_back(instExt.extensionName);
-    }
-
-    const VkInstanceCreateInfo instanceCreateInfo = {
-            VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO,
-            nullptr,
-            0,
-            &appInfo,
-            0,
-            nullptr,
-            (uint32_t)enabledInstanceExtensionNames.size(),
-            enabledInstanceExtensionNames.data(),
-    };
-
-    VK_GET_PROC(CreateInstance);
-    VkInstance instance;
-    VK_CHECK(vkCreateInstance(&instanceCreateInfo, nullptr, &instance));
-
-    VK_GET_INST_PROC(instance, DestroyInstance);
-    interface.funcs.vkDestroyInstance = vkDestroyInstance;
-    VK_GET_INST_PROC(instance, EnumeratePhysicalDevices);
-    VK_GET_INST_PROC(instance, EnumerateDeviceExtensionProperties);
-    VK_GET_INST_PROC(instance, GetPhysicalDeviceProperties2);
-    VK_GET_INST_PROC(instance, GetPhysicalDeviceExternalSemaphoreProperties);
-    VK_GET_INST_PROC(instance, GetPhysicalDeviceQueueFamilyProperties2);
-    VK_GET_INST_PROC(instance, GetPhysicalDeviceFeatures2);
-    VK_GET_INST_PROC(instance, CreateDevice);
-
-    uint32_t physdevCount;
-    VK_CHECK(vkEnumeratePhysicalDevices(instance, &physdevCount, nullptr));
-    if (physdevCount == 0) {
-        BAIL("Could not find any physical devices");
-    }
-
-    physdevCount = 1;
-    VkPhysicalDevice physicalDevice;
-    VkResult enumeratePhysDevsErr =
-            vkEnumeratePhysicalDevices(instance, &physdevCount, &physicalDevice);
-    if (enumeratePhysDevsErr != VK_SUCCESS && VK_INCOMPLETE != enumeratePhysDevsErr) {
-        BAIL("vkEnumeratePhysicalDevices failed with non-VK_INCOMPLETE error: %d",
-             enumeratePhysDevsErr);
-    }
-
-    VkPhysicalDeviceProperties2 physDevProps = {
-            VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2,
-            0,
-            {},
-    };
-    VkPhysicalDeviceProtectedMemoryProperties protMemProps = {
-            VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROTECTED_MEMORY_PROPERTIES,
-            0,
-            {},
-    };
-
-    if (protectedContent) {
-        physDevProps.pNext = &protMemProps;
-    }
-
-    vkGetPhysicalDeviceProperties2(physicalDevice, &physDevProps);
-    if (physDevProps.properties.apiVersion < VK_MAKE_VERSION(1, 1, 0)) {
-        BAIL("Could not find a Vulkan 1.1+ physical device");
-    }
-
-    if (physDevProps.properties.deviceType == VK_PHYSICAL_DEVICE_TYPE_CPU) {
-        // TODO: b/326633110 - SkiaVK is not working correctly on swiftshader path.
-        BAIL("CPU implementations of Vulkan is not supported");
-    }
-
-    // Check for syncfd support. Bail if we cannot both import and export them.
-    VkPhysicalDeviceExternalSemaphoreInfo semInfo = {
-            VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_EXTERNAL_SEMAPHORE_INFO,
-            nullptr,
-            VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT,
-    };
-    VkExternalSemaphoreProperties semProps = {
-            VK_STRUCTURE_TYPE_EXTERNAL_SEMAPHORE_PROPERTIES, nullptr, 0, 0, 0,
-    };
-    vkGetPhysicalDeviceExternalSemaphoreProperties(physicalDevice, &semInfo, &semProps);
-
-    bool sufficientSemaphoreSyncFdSupport = (semProps.exportFromImportedHandleTypes &
-                                             VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT) &&
-            (semProps.compatibleHandleTypes & VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT) &&
-            (semProps.externalSemaphoreFeatures & VK_EXTERNAL_SEMAPHORE_FEATURE_EXPORTABLE_BIT) &&
-            (semProps.externalSemaphoreFeatures & VK_EXTERNAL_SEMAPHORE_FEATURE_IMPORTABLE_BIT);
-
-    if (!sufficientSemaphoreSyncFdSupport) {
-        BAIL("Vulkan device does not support sufficient external semaphore sync fd features. "
-             "exportFromImportedHandleTypes 0x%x (needed 0x%x) "
-             "compatibleHandleTypes 0x%x (needed 0x%x) "
-             "externalSemaphoreFeatures 0x%x (needed 0x%x) ",
-             semProps.exportFromImportedHandleTypes, VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT,
-             semProps.compatibleHandleTypes, VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT,
-             semProps.externalSemaphoreFeatures,
-             VK_EXTERNAL_SEMAPHORE_FEATURE_EXPORTABLE_BIT |
-                     VK_EXTERNAL_SEMAPHORE_FEATURE_IMPORTABLE_BIT);
-    } else {
-        ALOGD("Vulkan device supports sufficient external semaphore sync fd features. "
-              "exportFromImportedHandleTypes 0x%x (needed 0x%x) "
-              "compatibleHandleTypes 0x%x (needed 0x%x) "
-              "externalSemaphoreFeatures 0x%x (needed 0x%x) ",
-              semProps.exportFromImportedHandleTypes, VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT,
-              semProps.compatibleHandleTypes, VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT,
-              semProps.externalSemaphoreFeatures,
-              VK_EXTERNAL_SEMAPHORE_FEATURE_EXPORTABLE_BIT |
-                      VK_EXTERNAL_SEMAPHORE_FEATURE_IMPORTABLE_BIT);
-    }
-
-    uint32_t queueCount;
-    vkGetPhysicalDeviceQueueFamilyProperties2(physicalDevice, &queueCount, nullptr);
-    if (queueCount == 0) {
-        BAIL("Could not find queues for physical device");
-    }
-
-    std::vector<VkQueueFamilyProperties2> queueProps(queueCount);
-    std::vector<VkQueueFamilyGlobalPriorityPropertiesEXT> queuePriorityProps(queueCount);
-    VkQueueGlobalPriorityKHR queuePriority = VK_QUEUE_GLOBAL_PRIORITY_MEDIUM_KHR;
-    // Even though we don't yet know if the VK_EXT_global_priority extension is available,
-    // we can safely add the request to the pNext chain, and if the extension is not
-    // available, it will be ignored.
-    for (uint32_t i = 0; i < queueCount; ++i) {
-        queuePriorityProps[i].sType = VK_STRUCTURE_TYPE_QUEUE_FAMILY_GLOBAL_PRIORITY_PROPERTIES_EXT;
-        queuePriorityProps[i].pNext = nullptr;
-        queueProps[i].pNext = &queuePriorityProps[i];
-    }
-    vkGetPhysicalDeviceQueueFamilyProperties2(physicalDevice, &queueCount, queueProps.data());
-
-    int graphicsQueueIndex = -1;
-    for (uint32_t i = 0; i < queueCount; ++i) {
-        // Look at potential answers to the VK_EXT_global_priority query.  If answers were
-        // provided, we may adjust the queuePriority.
-        if (queueProps[i].queueFamilyProperties.queueFlags & VK_QUEUE_GRAPHICS_BIT) {
-            for (uint32_t j = 0; j < queuePriorityProps[i].priorityCount; j++) {
-                if (queuePriorityProps[i].priorities[j] > queuePriority) {
-                    queuePriority = queuePriorityProps[i].priorities[j];
-                }
-            }
-            if (queuePriority == VK_QUEUE_GLOBAL_PRIORITY_REALTIME_KHR) {
-                interface.isRealtimePriority = true;
-            }
-            graphicsQueueIndex = i;
-            break;
-        }
-    }
-
-    if (graphicsQueueIndex == -1) {
-        BAIL("Could not find a graphics queue family");
-    }
-
-    uint32_t deviceExtensionCount;
-    VK_CHECK(vkEnumerateDeviceExtensionProperties(physicalDevice, nullptr, &deviceExtensionCount,
-                                                  nullptr));
-    std::vector<VkExtensionProperties> deviceExtensions(deviceExtensionCount);
-    VK_CHECK(vkEnumerateDeviceExtensionProperties(physicalDevice, nullptr, &deviceExtensionCount,
-                                                  deviceExtensions.data()));
-
-    std::vector<const char*> enabledDeviceExtensionNames;
-    enabledDeviceExtensionNames.reserve(deviceExtensions.size());
-    interface.deviceExtensionNames.reserve(deviceExtensions.size());
-    for (const auto& devExt : deviceExtensions) {
-        enabledDeviceExtensionNames.push_back(devExt.extensionName);
-        interface.deviceExtensionNames.push_back(devExt.extensionName);
-    }
-
-    interface.grExtensions.init(sGetProc, instance, physicalDevice,
-                                enabledInstanceExtensionNames.size(),
-                                enabledInstanceExtensionNames.data(),
-                                enabledDeviceExtensionNames.size(),
-                                enabledDeviceExtensionNames.data());
-
-    if (!interface.grExtensions.hasExtension(VK_KHR_EXTERNAL_SEMAPHORE_FD_EXTENSION_NAME, 1)) {
-        BAIL("Vulkan driver doesn't support external semaphore fd");
-    }
-
-    interface.physicalDeviceFeatures2 = new VkPhysicalDeviceFeatures2;
-    interface.physicalDeviceFeatures2->sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
-    interface.physicalDeviceFeatures2->pNext = nullptr;
-
-    interface.samplerYcbcrConversionFeatures = new VkPhysicalDeviceSamplerYcbcrConversionFeatures;
-    interface.samplerYcbcrConversionFeatures->sType =
-            VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SAMPLER_YCBCR_CONVERSION_FEATURES;
-    interface.samplerYcbcrConversionFeatures->pNext = nullptr;
-
-    interface.physicalDeviceFeatures2->pNext = interface.samplerYcbcrConversionFeatures;
-    void** tailPnext = &interface.samplerYcbcrConversionFeatures->pNext;
-
-    if (protectedContent) {
-        interface.protectedMemoryFeatures = new VkPhysicalDeviceProtectedMemoryFeatures;
-        interface.protectedMemoryFeatures->sType =
-                VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROTECTED_MEMORY_FEATURES;
-        interface.protectedMemoryFeatures->pNext = nullptr;
-        *tailPnext = interface.protectedMemoryFeatures;
-        tailPnext = &interface.protectedMemoryFeatures->pNext;
-    }
-
-    if (interface.grExtensions.hasExtension(VK_EXT_DEVICE_FAULT_EXTENSION_NAME, 1)) {
-        interface.deviceFaultFeatures = new VkPhysicalDeviceFaultFeaturesEXT;
-        interface.deviceFaultFeatures->sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FAULT_FEATURES_EXT;
-        interface.deviceFaultFeatures->pNext = nullptr;
-        *tailPnext = interface.deviceFaultFeatures;
-        tailPnext = &interface.deviceFaultFeatures->pNext;
-    }
-
-    vkGetPhysicalDeviceFeatures2(physicalDevice, interface.physicalDeviceFeatures2);
-    // Looks like this would slow things down and we can't depend on it on all platforms
-    interface.physicalDeviceFeatures2->features.robustBufferAccess = VK_FALSE;
-
-    if (protectedContent && !interface.protectedMemoryFeatures->protectedMemory) {
-        BAIL("Protected memory not supported");
-    }
-
-    float queuePriorities[1] = {0.0f};
-    void* queueNextPtr = nullptr;
-
-    VkDeviceQueueGlobalPriorityCreateInfoEXT queuePriorityCreateInfo = {
-            VK_STRUCTURE_TYPE_DEVICE_QUEUE_GLOBAL_PRIORITY_CREATE_INFO_EXT,
-            nullptr,
-            // If queue priority is supported, RE should always have realtime priority.
-            queuePriority,
-    };
-
-    if (interface.grExtensions.hasExtension(VK_EXT_GLOBAL_PRIORITY_EXTENSION_NAME, 2)) {
-        queueNextPtr = &queuePriorityCreateInfo;
-    }
-
-    VkDeviceQueueCreateFlags deviceQueueCreateFlags =
-            (VkDeviceQueueCreateFlags)(protectedContent ? VK_DEVICE_QUEUE_CREATE_PROTECTED_BIT : 0);
-
-    const VkDeviceQueueCreateInfo queueInfo = {
-            VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO,
-            queueNextPtr,
-            deviceQueueCreateFlags,
-            (uint32_t)graphicsQueueIndex,
-            1,
-            queuePriorities,
-    };
-
-    const VkDeviceCreateInfo deviceInfo = {
-            VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO,
-            interface.physicalDeviceFeatures2,
-            0,
-            1,
-            &queueInfo,
-            0,
-            nullptr,
-            (uint32_t)enabledDeviceExtensionNames.size(),
-            enabledDeviceExtensionNames.data(),
-            nullptr,
-    };
-
-    ALOGD("Trying to create Vk device with protectedContent=%d", protectedContent);
-    VkDevice device;
-    VK_CHECK(vkCreateDevice(physicalDevice, &deviceInfo, nullptr, &device));
-    ALOGD("Trying to create Vk device with protectedContent=%d (success)", protectedContent);
-
-    VkQueue graphicsQueue;
-    VK_GET_DEV_PROC(device, GetDeviceQueue2);
-    const VkDeviceQueueInfo2 deviceQueueInfo2 = {VK_STRUCTURE_TYPE_DEVICE_QUEUE_INFO_2, nullptr,
-                                                 deviceQueueCreateFlags,
-                                                 (uint32_t)graphicsQueueIndex, 0};
-    vkGetDeviceQueue2(device, &deviceQueueInfo2, &graphicsQueue);
-
-    VK_GET_DEV_PROC(device, DeviceWaitIdle);
-    VK_GET_DEV_PROC(device, DestroyDevice);
-    interface.funcs.vkDeviceWaitIdle = vkDeviceWaitIdle;
-    interface.funcs.vkDestroyDevice = vkDestroyDevice;
-
-    VK_GET_DEV_PROC(device, CreateSemaphore);
-    VK_GET_DEV_PROC(device, ImportSemaphoreFdKHR);
-    VK_GET_DEV_PROC(device, GetSemaphoreFdKHR);
-    VK_GET_DEV_PROC(device, DestroySemaphore);
-    interface.funcs.vkCreateSemaphore = vkCreateSemaphore;
-    interface.funcs.vkImportSemaphoreFdKHR = vkImportSemaphoreFdKHR;
-    interface.funcs.vkGetSemaphoreFdKHR = vkGetSemaphoreFdKHR;
-    interface.funcs.vkDestroySemaphore = vkDestroySemaphore;
-
-    // At this point, everything's succeeded and we can continue
-    interface.initialized = true;
-    interface.instance = instance;
-    interface.physicalDevice = physicalDevice;
-    interface.device = device;
-    interface.queue = graphicsQueue;
-    interface.queueIndex = graphicsQueueIndex;
-    interface.apiVersion = physDevProps.properties.apiVersion;
-    // grExtensions already constructed
-    // feature pointers already constructed
-    interface.grGetProc = sGetProc;
-    interface.isProtected = protectedContent;
-    // funcs already initialized
-
-    const nsecs_t timeAfter = systemTime();
-    const float initTimeMs = static_cast<float>(timeAfter - timeBefore) / 1.0E6;
-    ALOGD("%s: Success init Vulkan interface in %f ms", __func__, initTimeMs);
-    return interface;
-}
-
-void teardownVulkanInterface(VulkanInterface* interface) {
-    interface->initialized = false;
-
-    if (interface->device != VK_NULL_HANDLE) {
-        interface->funcs.vkDeviceWaitIdle(interface->device);
-        interface->funcs.vkDestroyDevice(interface->device, nullptr);
-        interface->device = VK_NULL_HANDLE;
-    }
-    if (interface->instance != VK_NULL_HANDLE) {
-        interface->funcs.vkDestroyInstance(interface->instance, nullptr);
-        interface->instance = VK_NULL_HANDLE;
-    }
-
-    if (interface->protectedMemoryFeatures) {
-        delete interface->protectedMemoryFeatures;
-    }
-
-    if (interface->samplerYcbcrConversionFeatures) {
-        delete interface->samplerYcbcrConversionFeatures;
-    }
-
-    if (interface->physicalDeviceFeatures2) {
-        delete interface->physicalDeviceFeatures2;
-    }
-
-    if (interface->deviceFaultFeatures) {
-        delete interface->deviceFaultFeatures;
-    }
-
-    interface->samplerYcbcrConversionFeatures = nullptr;
-    interface->physicalDeviceFeatures2 = nullptr;
-    interface->protectedMemoryFeatures = nullptr;
-}
-
-static VulkanInterface sVulkanInterface;
-static VulkanInterface sProtectedContentVulkanInterface;
+static skia::VulkanInterface sVulkanInterface;
+static skia::VulkanInterface sProtectedContentVulkanInterface;
 
 static void sSetupVulkanInterface() {
-    if (!sVulkanInterface.initialized) {
-        sVulkanInterface = initVulkanInterface(false /* no protected content */);
+    if (!sVulkanInterface.isInitialized()) {
+        sVulkanInterface.init(false /* no protected content */);
         // We will have to abort if non-protected VkDevice creation fails (then nothing works).
-        LOG_ALWAYS_FATAL_IF(!sVulkanInterface.initialized,
+        LOG_ALWAYS_FATAL_IF(!sVulkanInterface.isInitialized(),
                             "Could not initialize Vulkan RenderEngine!");
     }
-    if (!sProtectedContentVulkanInterface.initialized) {
-        sProtectedContentVulkanInterface = initVulkanInterface(true /* protected content */);
-        if (!sProtectedContentVulkanInterface.initialized) {
+    if (!sProtectedContentVulkanInterface.isInitialized()) {
+        sProtectedContentVulkanInterface.init(true /* protected content */);
+        if (!sProtectedContentVulkanInterface.isInitialized()) {
             ALOGE("Could not initialize protected content Vulkan RenderEngine.");
         }
     }
@@ -667,12 +65,12 @@
         case GraphicsApi::GL:
             return true;
         case GraphicsApi::VK: {
-            if (!sVulkanInterface.initialized) {
-                sVulkanInterface = initVulkanInterface(false /* no protected content */);
+            if (!sVulkanInterface.isInitialized()) {
+                sVulkanInterface.init(false /* no protected content */);
                 ALOGD("%s: initialized == %s.", __func__,
-                      sVulkanInterface.initialized ? "true" : "false");
+                      sVulkanInterface.isInitialized() ? "true" : "false");
             }
-            return sVulkanInterface.initialized;
+            return sVulkanInterface.isInitialized();
         }
     }
 }
@@ -686,7 +84,7 @@
     std::unique_ptr<SkiaVkRenderEngine> engine(new SkiaVkRenderEngine(args));
     engine->ensureGrContextsCreated();
 
-    if (sVulkanInterface.initialized) {
+    if (sVulkanInterface.isInitialized()) {
         ALOGD("SkiaVkRenderEngine::%s: successfully initialized SkiaVkRenderEngine", __func__);
         return engine;
     } else {
@@ -721,29 +119,17 @@
 }
 
 bool SkiaVkRenderEngine::supportsProtectedContentImpl() const {
-    return sProtectedContentVulkanInterface.initialized;
+    return sProtectedContentVulkanInterface.isInitialized();
 }
 
 bool SkiaVkRenderEngine::useProtectedContextImpl(GrProtected) {
     return true;
 }
 
-static void delete_semaphore(void* semaphore) {
-    DestroySemaphoreInfo* info = reinterpret_cast<DestroySemaphoreInfo*>(semaphore);
-    --info->mRefs;
-    if (!info->mRefs) {
-        sVulkanInterface.destroySemaphore(info->mSemaphore);
-        delete info;
-    }
-}
-
-static void delete_semaphore_protected(void* semaphore) {
-    DestroySemaphoreInfo* info = reinterpret_cast<DestroySemaphoreInfo*>(semaphore);
-    --info->mRefs;
-    if (!info->mRefs) {
-        sProtectedContentVulkanInterface.destroySemaphore(info->mSemaphore);
-        delete info;
-    }
+static void unref_semaphore(void* semaphore) {
+    SkiaVkRenderEngine::DestroySemaphoreInfo* info =
+            reinterpret_cast<SkiaVkRenderEngine::DestroySemaphoreInfo*>(semaphore);
+    info->unref();
 }
 
 static VulkanInterface& getVulkanInterface(bool protectedContext) {
@@ -766,8 +152,7 @@
     base::unique_fd fenceDup(dupedFd);
     VkSemaphore waitSemaphore =
             getVulkanInterface(isProtected()).importSemaphoreFromSyncFd(fenceDup.release());
-    GrBackendSemaphore beSemaphore;
-    beSemaphore.initVulkan(waitSemaphore);
+    GrBackendSemaphore beSemaphore = GrBackendSemaphores::MakeVk(waitSemaphore);
     grContext->wait(1, &beSemaphore, true /* delete after wait */);
 }
 
@@ -775,16 +160,15 @@
     VulkanInterface& vi = getVulkanInterface(isProtected());
     VkSemaphore semaphore = vi.createExportableSemaphore();
 
-    GrBackendSemaphore backendSemaphore;
-    backendSemaphore.initVulkan(semaphore);
+    GrBackendSemaphore backendSemaphore = GrBackendSemaphores::MakeVk(semaphore);
 
     GrFlushInfo flushInfo;
     DestroySemaphoreInfo* destroySemaphoreInfo = nullptr;
     if (semaphore != VK_NULL_HANDLE) {
-        destroySemaphoreInfo = new DestroySemaphoreInfo(semaphore);
+        destroySemaphoreInfo = new DestroySemaphoreInfo(vi, semaphore);
         flushInfo.fNumSemaphores = 1;
         flushInfo.fSignalSemaphores = &backendSemaphore;
-        flushInfo.fFinishedProc = isProtected() ? delete_semaphore_protected : delete_semaphore;
+        flushInfo.fFinishedProc = unref_semaphore;
         flushInfo.fFinishedContext = destroySemaphoreInfo;
     }
     GrSemaphoresSubmitted submitted = grContext->flush(flushInfo);
@@ -804,7 +188,7 @@
 int SkiaVkRenderEngine::getContextPriority() {
     // EGL_CONTEXT_PRIORITY_REALTIME_NV
     constexpr int kRealtimePriority = 0x3357;
-    if (getVulkanInterface(isProtected()).isRealtimePriority) {
+    if (getVulkanInterface(isProtected()).isRealtimePriority()) {
         return kRealtimePriority;
     } else {
         return 0;
@@ -813,21 +197,21 @@
 
 void SkiaVkRenderEngine::appendBackendSpecificInfoToDump(std::string& result) {
     StringAppendF(&result, "\n ------------RE Vulkan----------\n");
-    StringAppendF(&result, "\n Vulkan device initialized: %d\n", sVulkanInterface.initialized);
+    StringAppendF(&result, "\n Vulkan device initialized: %d\n", sVulkanInterface.isInitialized());
     StringAppendF(&result, "\n Vulkan protected device initialized: %d\n",
-                  sProtectedContentVulkanInterface.initialized);
+                  sProtectedContentVulkanInterface.isInitialized());
 
-    if (!sVulkanInterface.initialized) {
+    if (!sVulkanInterface.isInitialized()) { // Extension lists may be partial if init() bailed.
         return;
     }
 
     StringAppendF(&result, "\n Instance extensions:\n");
-    for (const auto& name : sVulkanInterface.instanceExtensionNames) {
+    for (const auto& name : sVulkanInterface.getInstanceExtensionNames()) {
         StringAppendF(&result, "\n %s\n", name.c_str());
     }
 
     StringAppendF(&result, "\n Device extensions:\n");
-    for (const auto& name : sVulkanInterface.deviceExtensionNames) {
+    for (const auto& name : sVulkanInterface.getDeviceExtensionNames()) {
         StringAppendF(&result, "\n %s\n", name.c_str());
     }
 }
diff --git a/libs/renderengine/skia/SkiaVkRenderEngine.h b/libs/renderengine/skia/SkiaVkRenderEngine.h
index 52bc500..ca0dcbf 100644
--- a/libs/renderengine/skia/SkiaVkRenderEngine.h
+++ b/libs/renderengine/skia/SkiaVkRenderEngine.h
@@ -20,6 +20,7 @@
 #include <vk/GrVkBackendContext.h>
 
 #include "SkiaRenderEngine.h"
+#include "VulkanInterface.h"
 
 namespace android {
 namespace renderengine {
@@ -32,6 +33,42 @@
 
     int getContextPriority() override;
 
+    class DestroySemaphoreInfo {
+    public:
+        DestroySemaphoreInfo() = delete;
+        DestroySemaphoreInfo(const DestroySemaphoreInfo&) = delete;
+        DestroySemaphoreInfo& operator=(const DestroySemaphoreInfo&) = delete;
+        DestroySemaphoreInfo& operator=(DestroySemaphoreInfo&&) = delete;
+
+        DestroySemaphoreInfo(VulkanInterface& vulkanInterface, std::vector<VkSemaphore> semaphores)
+              : mVulkanInterface(vulkanInterface), mSemaphores(std::move(semaphores)) {}
+        DestroySemaphoreInfo(VulkanInterface& vulkanInterface, VkSemaphore semaphore)
+              : DestroySemaphoreInfo(vulkanInterface, std::vector<VkSemaphore>{semaphore}) {}
+
+        // Drops one reference; dropping the last one destroys every semaphore and frees `this`.
+        void unref() {
+            if (--mRefs != 0) {
+                return;
+            }
+            for (const VkSemaphore& semaphore : mSemaphores) {
+                mVulkanInterface.destroySemaphore(semaphore);
+            }
+            delete this;
+        }
+
+    private:
+        ~DestroySemaphoreInfo() = default;
+
+        VulkanInterface& mVulkanInterface;
+        std::vector<VkSemaphore> mSemaphores;
+        // Every VkSemaphore must stay alive until both Skia (including the GPU) and
+        // SkiaVkRenderEngine are done with it, so the object starts with two refs: one owned by
+        // Skia and one by SkiaVkRenderEngine. Skia calls unref() once it and the GPU have
+        // finished with the semaphore; SkiaVkRenderEngine calls unref() after sending the
+        // semaphore to Skia and exporting it if need be.
+        int mRefs = 2;
+    };
+
 protected:
     // Implementations of abstract SkiaRenderEngine functions specific to
     // rendering backend
diff --git a/libs/renderengine/skia/VulkanInterface.cpp b/libs/renderengine/skia/VulkanInterface.cpp
new file mode 100644
index 0000000..453cdc1
--- /dev/null
+++ b/libs/renderengine/skia/VulkanInterface.cpp
@@ -0,0 +1,582 @@
+/*
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#undef LOG_TAG
+#define LOG_TAG "RenderEngine"
+
+#include "VulkanInterface.h"
+
+#include <cinttypes>
+#include <include/gpu/GpuTypes.h>
+#include <log/log_main.h>
+#include <sstream>
+#include <unistd.h>
+#include <utils/Timers.h>
+
+namespace android {
+namespace renderengine {
+namespace skia {
+// Packages the handles, extensions and features captured by init() for Skia.
+GrVkBackendContext VulkanInterface::getBackendContext() {
+    GrVkBackendContext backendContext;
+    backendContext.fInstance = mInstance;
+    backendContext.fPhysicalDevice = mPhysicalDevice;
+    backendContext.fDevice = mDevice;
+    backendContext.fQueue = mQueue;
+    backendContext.fGraphicsQueueIndex = mQueueIndex;
+    backendContext.fMaxAPIVersion = mApiVersion;
+    backendContext.fVkExtensions = &mGrExtensions;
+    backendContext.fDeviceFeatures2 = mPhysicalDeviceFeatures2;
+    backendContext.fGetProc = mGrGetProc;
+    backendContext.fProtectedContext = mIsProtected ? Protected::kYes : Protected::kNo;
+    backendContext.fDeviceLostContext = this; // VulkanInterface is long-lived
+    backendContext.fDeviceLostProc = onVkDeviceFault;
+    return backendContext;
+}
+
+VkSemaphore VulkanInterface::createExportableSemaphore() {
+    // The chained export info lets the payload later be exported as a sync fd.
+    VkExportSemaphoreCreateInfo exportInfo = {
+            VK_STRUCTURE_TYPE_EXPORT_SEMAPHORE_CREATE_INFO, nullptr,
+            VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT,
+    };
+    const VkSemaphoreCreateInfo semaphoreInfo = {
+            VK_STRUCTURE_TYPE_SEMAPHORE_CREATE_INFO,
+            &exportInfo,
+            0,
+    };
+
+    VkSemaphore semaphore;
+    const VkResult err = mFuncs.vkCreateSemaphore(mDevice, &semaphoreInfo, nullptr, &semaphore);
+    if (VK_SUCCESS == err) {
+        return semaphore;
+    }
+    ALOGE("%s: failed to create semaphore. err %d\n", __func__, err);
+    return VK_NULL_HANDLE;
+}
+
+// Creates a semaphore and temporarily imports the sync fd payload into it. syncFd must be > 0.
+// Takes ownership of syncFd: a successful import hands it to the driver, otherwise it is closed.
+VkSemaphore VulkanInterface::importSemaphoreFromSyncFd(int syncFd) {
+    VkSemaphoreCreateInfo semaphoreInfo = {VK_STRUCTURE_TYPE_SEMAPHORE_CREATE_INFO, nullptr, 0};
+
+    VkSemaphore semaphore;
+    VkResult err = mFuncs.vkCreateSemaphore(mDevice, &semaphoreInfo, nullptr, &semaphore);
+    if (VK_SUCCESS != err) {
+        ALOGE("%s: failed to create import semaphore", __func__);
+        close(syncFd); // Vulkan never saw the fd, so release it here.
+        return VK_NULL_HANDLE;
+    }
+
+    VkImportSemaphoreFdInfoKHR importInfo;
+    importInfo.sType = VK_STRUCTURE_TYPE_IMPORT_SEMAPHORE_FD_INFO_KHR;
+    importInfo.pNext = nullptr;
+    importInfo.semaphore = semaphore;
+    importInfo.flags = VK_SEMAPHORE_IMPORT_TEMPORARY_BIT;
+    importInfo.handleType = VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT;
+    importInfo.fd = syncFd;
+
+    err = mFuncs.vkImportSemaphoreFdKHR(mDevice, &importInfo);
+    if (VK_SUCCESS != err) {
+        mFuncs.vkDestroySemaphore(mDevice, semaphore, nullptr);
+        ALOGE("%s: failed to import semaphore", __func__);
+        close(syncFd); // A failed import does not transfer fd ownership to the driver.
+        return VK_NULL_HANDLE;
+    }
+
+    return semaphore;
+}
+
+int VulkanInterface::exportSemaphoreSyncFd(VkSemaphore semaphore) {
+    // Exports `semaphore`'s payload as a sync fd; returns -1 if vkGetSemaphoreFdKHR fails.
+    const VkSemaphoreGetFdInfoKHR getFdInfo = {
+            VK_STRUCTURE_TYPE_SEMAPHORE_GET_FD_INFO_KHR, nullptr, semaphore,
+            VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT,
+    };
+
+    int fd = -1;
+    const VkResult err = mFuncs.vkGetSemaphoreFdKHR(mDevice, &getFdInfo, &fd);
+    if (err == VK_SUCCESS) {
+        return fd;
+    }
+    ALOGE("%s: failed to export semaphore, err: %d", __func__, err);
+    return -1;
+}
+
+void VulkanInterface::destroySemaphore(VkSemaphore semaphore) {
+    mFuncs.vkDestroySemaphore(mDevice, semaphore, nullptr); // No custom VkAllocationCallbacks.
+}
+
+void VulkanInterface::onVkDeviceFault(void* callbackContext, const std::string& description,
+                                      const std::vector<VkDeviceFaultAddressInfoEXT>& addressInfos,
+                                      const std::vector<VkDeviceFaultVendorInfoEXT>& vendorInfos,
+                                      const std::vector<std::byte>& vendorBinaryData) {
+    VulkanInterface* interface = static_cast<VulkanInterface*>(callbackContext);
+    const std::string protectedStr = interface->mIsProtected ? "protected" : "non-protected";
+    // Device-lost callback: logs every fault detail, then aborts with a condensed crash string.
+    // That string should carry as much differentiating info as possible, up to 1024 bytes. Each
+    // piece is logged verbosely first and then appended to the crash string in a terser form, so
+    // the clearer logging statement always comes first to give context.
+    ALOGE("VK_ERROR_DEVICE_LOST (%s context): %s", protectedStr.c_str(), description.c_str());
+    std::stringstream crashMsg;
+    crashMsg << "VK_ERROR_DEVICE_LOST (" << protectedStr;
+
+    if (!addressInfos.empty()) {
+        ALOGE("%zu VkDeviceFaultAddressInfoEXT:", addressInfos.size());
+        crashMsg << ", " << addressInfos.size() << " address info (";
+        for (const VkDeviceFaultAddressInfoEXT& addressInfo : addressInfos) {
+            ALOGE(" addressType:       %d", static_cast<int>(addressInfo.addressType));
+            ALOGE("  reportedAddress:  %" PRIu64, addressInfo.reportedAddress);
+            ALOGE("  addressPrecision: %" PRIu64, addressInfo.addressPrecision);
+            crashMsg << addressInfo.addressType << ":" << addressInfo.reportedAddress << ":"
+                     << addressInfo.addressPrecision << ", ";
+        }
+        crashMsg.seekp(-2, std::ios_base::cur); // Move back to overwrite trailing ", "
+        crashMsg << ")";
+    }
+
+    if (!vendorInfos.empty()) {
+        ALOGE("%zu VkDeviceFaultVendorInfoEXT:", vendorInfos.size());
+        crashMsg << ", " << vendorInfos.size() << " vendor info (";
+        for (const VkDeviceFaultVendorInfoEXT& vendorInfo : vendorInfos) {
+            ALOGE(" description:      %s", vendorInfo.description);
+            ALOGE("  vendorFaultCode: %" PRIu64, vendorInfo.vendorFaultCode);
+            ALOGE("  vendorFaultData: %" PRIu64, vendorInfo.vendorFaultData);
+            // Omit descriptions for individual vendor info structs in the crash string, as the
+            // fault code and fault data fields should be enough for clustering, and the verbosity
+            // isn't worth it. Additionally, vendors may just set the general description field of
+            // the overall fault to the description of the first element in this list, and that
+            // overall description will be placed at the end of the crash string.
+            crashMsg << vendorInfo.vendorFaultCode << ":" << vendorInfo.vendorFaultData << ", ";
+        }
+        crashMsg.seekp(-2, std::ios_base::cur); // Move back to overwrite trailing ", "
+        crashMsg << ")";
+    }
+
+    if (!vendorBinaryData.empty()) {
+        // TODO: b/322830575 - Log in base64, or dump directly to a file that gets put in bugreports
+        ALOGE("%zu bytes of vendor-specific binary data (please notify Android's Core Graphics"
+              " Stack team if you observe this message).",
+              vendorBinaryData.size());
+        crashMsg << ", " << vendorBinaryData.size() << " bytes binary";
+    }
+
+    crashMsg << "): " << description;
+    LOG_ALWAYS_FATAL("%s", crashMsg.str().c_str());
+}
+
+static GrVkGetProc sGetProc = [](const char* proc_name, VkInstance instance, VkDevice device) {
+    // Device-level lookup when Skia supplies a device, otherwise instance-level lookup.
+    return device != VK_NULL_HANDLE
+            ? vkGetDeviceProcAddr(device, proc_name)
+            : vkGetInstanceProcAddr(instance, proc_name);
+};
+
+// Logs "<function>: <message>, bailing" and returns from the enclosing (void) function.
+#define BAIL(fmt, ...)                                          \
+    do {                                                        \
+        ALOGE("%s: " fmt ", bailing", __func__, ##__VA_ARGS__); \
+        return;                                                 \
+    } while (0)
+#define CHECK_NONNULL(expr) \
+    do { if ((expr) == nullptr) BAIL("[%s] null", #expr); } while (0)
+
+// Evaluates `expr` exactly once (it may create Vulkan objects) and bails unless VK_SUCCESS.
+#define VK_CHECK(expr)                                           \
+    do {                                                         \
+        const VkResult vkCheckResult = (expr);                   \
+        if (vkCheckResult != VK_SUCCESS)                         \
+            BAIL("[%s] failed. err = %d", #expr, vkCheckResult); \
+    } while (0)
+
+#define VK_GET_PROC(F)                                                           \
+    PFN_vk##F vk##F = (PFN_vk##F)vkGetInstanceProcAddr(VK_NULL_HANDLE, "vk" #F); \
+    CHECK_NONNULL(vk##F)
+#define VK_GET_INST_PROC(instance, F)                                      \
+    PFN_vk##F vk##F = (PFN_vk##F)vkGetInstanceProcAddr(instance, "vk" #F); \
+    CHECK_NONNULL(vk##F)
+#define VK_GET_DEV_PROC(device, F)                                     \
+    PFN_vk##F vk##F = (PFN_vk##F)vkGetDeviceProcAddr(device, "vk" #F); \
+    CHECK_NONNULL(vk##F)
+// Builds the Vulkan instance, device and queue; on failure logs and leaves isInitialized() false.
+void VulkanInterface::init(bool protectedContent) {
+    if (isInitialized()) {
+        ALOGW("Called init on already initialized VulkanInterface");
+        return;
+    }
+
+    const nsecs_t timeBefore = systemTime();
+
+    VK_GET_PROC(EnumerateInstanceVersion);
+    uint32_t instanceVersion;
+    VK_CHECK(vkEnumerateInstanceVersion(&instanceVersion));
+
+    if (instanceVersion < VK_MAKE_VERSION(1, 1, 0)) {
+        BAIL("Vulkan instance version 0x%x is older than 1.1", instanceVersion);
+    }
+
+    const VkApplicationInfo appInfo = {
+            VK_STRUCTURE_TYPE_APPLICATION_INFO, nullptr, "surfaceflinger", 0, "android platform", 0,
+            VK_MAKE_VERSION(1, 1, 0),
+    };
+
+    VK_GET_PROC(EnumerateInstanceExtensionProperties);
+
+    uint32_t extensionCount = 0;
+    VK_CHECK(vkEnumerateInstanceExtensionProperties(nullptr, &extensionCount, nullptr));
+    std::vector<VkExtensionProperties> instanceExtensions(extensionCount);
+    VK_CHECK(vkEnumerateInstanceExtensionProperties(nullptr, &extensionCount,
+                                                    instanceExtensions.data()));
+    std::vector<const char*> enabledInstanceExtensionNames;
+    enabledInstanceExtensionNames.reserve(instanceExtensions.size());
+    mInstanceExtensionNames.reserve(instanceExtensions.size());
+    for (const auto& instExt : instanceExtensions) {
+        enabledInstanceExtensionNames.push_back(instExt.extensionName);
+        mInstanceExtensionNames.push_back(instExt.extensionName);
+    }
+
+    const VkInstanceCreateInfo instanceCreateInfo = {
+            VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO,
+            nullptr,
+            0,
+            &appInfo,
+            0,
+            nullptr,
+            (uint32_t)enabledInstanceExtensionNames.size(),
+            enabledInstanceExtensionNames.data(),
+    };
+
+    VK_GET_PROC(CreateInstance);
+    VkInstance instance;
+    VK_CHECK(vkCreateInstance(&instanceCreateInfo, nullptr, &instance));
+    // NOTE(review): bails below this point do not destroy `instance` or the feature structs.
+    VK_GET_INST_PROC(instance, DestroyInstance);
+    mFuncs.vkDestroyInstance = vkDestroyInstance;
+    VK_GET_INST_PROC(instance, EnumeratePhysicalDevices);
+    VK_GET_INST_PROC(instance, EnumerateDeviceExtensionProperties);
+    VK_GET_INST_PROC(instance, GetPhysicalDeviceProperties2);
+    VK_GET_INST_PROC(instance, GetPhysicalDeviceExternalSemaphoreProperties);
+    VK_GET_INST_PROC(instance, GetPhysicalDeviceQueueFamilyProperties2);
+    VK_GET_INST_PROC(instance, GetPhysicalDeviceFeatures2);
+    VK_GET_INST_PROC(instance, CreateDevice);
+
+    uint32_t physdevCount;
+    VK_CHECK(vkEnumeratePhysicalDevices(instance, &physdevCount, nullptr));
+    if (physdevCount == 0) {
+        BAIL("Could not find any physical devices");
+    }
+
+    physdevCount = 1;
+    VkPhysicalDevice physicalDevice;
+    VkResult enumeratePhysDevsErr =
+            vkEnumeratePhysicalDevices(instance, &physdevCount, &physicalDevice);
+    if (enumeratePhysDevsErr != VK_SUCCESS && VK_INCOMPLETE != enumeratePhysDevsErr) {
+        BAIL("vkEnumeratePhysicalDevices failed with non-VK_INCOMPLETE error: %d",
+             enumeratePhysDevsErr);
+    }
+
+    VkPhysicalDeviceProperties2 physDevProps = {
+            VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2,
+            0,
+            {},
+    };
+    VkPhysicalDeviceProtectedMemoryProperties protMemProps = {
+            VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROTECTED_MEMORY_PROPERTIES,
+            0,
+            {},
+    };
+
+    if (protectedContent) {
+        physDevProps.pNext = &protMemProps;
+    }
+
+    vkGetPhysicalDeviceProperties2(physicalDevice, &physDevProps);
+    if (physDevProps.properties.apiVersion < VK_MAKE_VERSION(1, 1, 0)) {
+        BAIL("Could not find a Vulkan 1.1+ physical device");
+    }
+
+    if (physDevProps.properties.deviceType == VK_PHYSICAL_DEVICE_TYPE_CPU) {
+        // TODO: b/326633110 - SkiaVK is not working correctly on swiftshader path.
+        BAIL("CPU implementations of Vulkan is not supported");
+    }
+
+    // Check for syncfd support. Bail if we cannot both import and export them.
+    VkPhysicalDeviceExternalSemaphoreInfo semInfo = {
+            VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_EXTERNAL_SEMAPHORE_INFO,
+            nullptr,
+            VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT,
+    };
+    VkExternalSemaphoreProperties semProps = {
+            VK_STRUCTURE_TYPE_EXTERNAL_SEMAPHORE_PROPERTIES, nullptr, 0, 0, 0,
+    };
+    vkGetPhysicalDeviceExternalSemaphoreProperties(physicalDevice, &semInfo, &semProps);
+
+    bool sufficientSemaphoreSyncFdSupport = (semProps.exportFromImportedHandleTypes &
+                                             VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT) &&
+            (semProps.compatibleHandleTypes & VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT) &&
+            (semProps.externalSemaphoreFeatures & VK_EXTERNAL_SEMAPHORE_FEATURE_EXPORTABLE_BIT) &&
+            (semProps.externalSemaphoreFeatures & VK_EXTERNAL_SEMAPHORE_FEATURE_IMPORTABLE_BIT);
+
+    if (!sufficientSemaphoreSyncFdSupport) {
+        BAIL("Vulkan device does not support sufficient external semaphore sync fd features. "
+             "exportFromImportedHandleTypes 0x%x (needed 0x%x) "
+             "compatibleHandleTypes 0x%x (needed 0x%x) "
+             "externalSemaphoreFeatures 0x%x (needed 0x%x) ",
+             semProps.exportFromImportedHandleTypes, VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT,
+             semProps.compatibleHandleTypes, VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT,
+             semProps.externalSemaphoreFeatures,
+             VK_EXTERNAL_SEMAPHORE_FEATURE_EXPORTABLE_BIT |
+                     VK_EXTERNAL_SEMAPHORE_FEATURE_IMPORTABLE_BIT);
+    } else {
+        ALOGD("Vulkan device supports sufficient external semaphore sync fd features. "
+              "exportFromImportedHandleTypes 0x%x (needed 0x%x) "
+              "compatibleHandleTypes 0x%x (needed 0x%x) "
+              "externalSemaphoreFeatures 0x%x (needed 0x%x) ",
+              semProps.exportFromImportedHandleTypes, VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT,
+              semProps.compatibleHandleTypes, VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_SYNC_FD_BIT,
+              semProps.externalSemaphoreFeatures,
+              VK_EXTERNAL_SEMAPHORE_FEATURE_EXPORTABLE_BIT |
+                      VK_EXTERNAL_SEMAPHORE_FEATURE_IMPORTABLE_BIT);
+    }
+
+    uint32_t queueCount;
+    vkGetPhysicalDeviceQueueFamilyProperties2(physicalDevice, &queueCount, nullptr);
+    if (queueCount == 0) {
+        BAIL("Could not find queues for physical device");
+    }
+
+    std::vector<VkQueueFamilyProperties2> queueProps(queueCount);
+    std::vector<VkQueueFamilyGlobalPriorityPropertiesEXT> queuePriorityProps(queueCount);
+    VkQueueGlobalPriorityKHR queuePriority = VK_QUEUE_GLOBAL_PRIORITY_MEDIUM_KHR;
+    // Even though we don't yet know if the VK_EXT_global_priority extension is available,
+    // we can safely add the request to the pNext chain, and if the extension is not
+    // available, it will be ignored.
+    for (uint32_t i = 0; i < queueCount; ++i) {
+        queuePriorityProps[i].sType = VK_STRUCTURE_TYPE_QUEUE_FAMILY_GLOBAL_PRIORITY_PROPERTIES_EXT;
+        queuePriorityProps[i].pNext = nullptr;
+        queueProps[i].sType = VK_STRUCTURE_TYPE_QUEUE_FAMILY_PROPERTIES_2;
+        queueProps[i].pNext = &queuePriorityProps[i];
+    }
+    vkGetPhysicalDeviceQueueFamilyProperties2(physicalDevice, &queueCount, queueProps.data());
+    int graphicsQueueIndex = -1;
+    for (uint32_t i = 0; i < queueCount; ++i) {
+        // Look at potential answers to the VK_EXT_global_priority query.  If answers were
+        // provided, we may adjust the queuePriority.
+        if (queueProps[i].queueFamilyProperties.queueFlags & VK_QUEUE_GRAPHICS_BIT) {
+            for (uint32_t j = 0; j < queuePriorityProps[i].priorityCount; j++) {
+                if (queuePriorityProps[i].priorities[j] > queuePriority) {
+                    queuePriority = queuePriorityProps[i].priorities[j];
+                }
+            }
+            if (queuePriority == VK_QUEUE_GLOBAL_PRIORITY_REALTIME_KHR) {
+                mIsRealtimePriority = true;
+            }
+            graphicsQueueIndex = i;
+            break;
+        }
+    }
+
+    if (graphicsQueueIndex == -1) {
+        BAIL("Could not find a graphics queue family");
+    }
+
+    uint32_t deviceExtensionCount;
+    VK_CHECK(vkEnumerateDeviceExtensionProperties(physicalDevice, nullptr, &deviceExtensionCount,
+                                                  nullptr));
+    std::vector<VkExtensionProperties> deviceExtensions(deviceExtensionCount);
+    VK_CHECK(vkEnumerateDeviceExtensionProperties(physicalDevice, nullptr, &deviceExtensionCount,
+                                                  deviceExtensions.data()));
+
+    std::vector<const char*> enabledDeviceExtensionNames;
+    enabledDeviceExtensionNames.reserve(deviceExtensions.size());
+    mDeviceExtensionNames.reserve(deviceExtensions.size());
+    for (const auto& devExt : deviceExtensions) {
+        enabledDeviceExtensionNames.push_back(devExt.extensionName);
+        mDeviceExtensionNames.push_back(devExt.extensionName);
+    }
+
+    mGrExtensions.init(sGetProc, instance, physicalDevice, enabledInstanceExtensionNames.size(),
+                       enabledInstanceExtensionNames.data(), enabledDeviceExtensionNames.size(),
+                       enabledDeviceExtensionNames.data());
+
+    if (!mGrExtensions.hasExtension(VK_KHR_EXTERNAL_SEMAPHORE_FD_EXTENSION_NAME, 1)) {
+        BAIL("Vulkan driver doesn't support external semaphore fd");
+    }
+
+    mPhysicalDeviceFeatures2 = new VkPhysicalDeviceFeatures2;
+    mPhysicalDeviceFeatures2->sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
+    mPhysicalDeviceFeatures2->pNext = nullptr;
+
+    mSamplerYcbcrConversionFeatures = new VkPhysicalDeviceSamplerYcbcrConversionFeatures;
+    mSamplerYcbcrConversionFeatures->sType =
+            VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SAMPLER_YCBCR_CONVERSION_FEATURES;
+    mSamplerYcbcrConversionFeatures->pNext = nullptr;
+
+    mPhysicalDeviceFeatures2->pNext = mSamplerYcbcrConversionFeatures;
+    void** tailPnext = &mSamplerYcbcrConversionFeatures->pNext;
+
+    if (protectedContent) {
+        mProtectedMemoryFeatures = new VkPhysicalDeviceProtectedMemoryFeatures;
+        mProtectedMemoryFeatures->sType =
+                VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROTECTED_MEMORY_FEATURES;
+        mProtectedMemoryFeatures->pNext = nullptr;
+        *tailPnext = mProtectedMemoryFeatures;
+        tailPnext = &mProtectedMemoryFeatures->pNext;
+    }
+
+    if (mGrExtensions.hasExtension(VK_EXT_DEVICE_FAULT_EXTENSION_NAME, 1)) {
+        mDeviceFaultFeatures = new VkPhysicalDeviceFaultFeaturesEXT;
+        mDeviceFaultFeatures->sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FAULT_FEATURES_EXT;
+        mDeviceFaultFeatures->pNext = nullptr;
+        *tailPnext = mDeviceFaultFeatures;
+        tailPnext = &mDeviceFaultFeatures->pNext;
+    }
+
+    vkGetPhysicalDeviceFeatures2(physicalDevice, mPhysicalDeviceFeatures2);
+    // Looks like this would slow things down and we can't depend on it on all platforms
+    mPhysicalDeviceFeatures2->features.robustBufferAccess = VK_FALSE;
+
+    if (protectedContent && !mProtectedMemoryFeatures->protectedMemory) {
+        BAIL("Protected memory not supported");
+    }
+
+    float queuePriorities[1] = {0.0f};
+    void* queueNextPtr = nullptr;
+
+    VkDeviceQueueGlobalPriorityCreateInfoEXT queuePriorityCreateInfo = {
+            VK_STRUCTURE_TYPE_DEVICE_QUEUE_GLOBAL_PRIORITY_CREATE_INFO_EXT,
+            nullptr,
+            // MEDIUM, raised to the highest priority the graphics queue family advertises.
+            queuePriority,
+    };
+
+    if (mGrExtensions.hasExtension(VK_EXT_GLOBAL_PRIORITY_EXTENSION_NAME, 2)) {
+        queueNextPtr = &queuePriorityCreateInfo;
+    }
+
+    VkDeviceQueueCreateFlags deviceQueueCreateFlags =
+            (VkDeviceQueueCreateFlags)(protectedContent ? VK_DEVICE_QUEUE_CREATE_PROTECTED_BIT : 0);
+
+    const VkDeviceQueueCreateInfo queueInfo = {
+            VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO,
+            queueNextPtr,
+            deviceQueueCreateFlags,
+            (uint32_t)graphicsQueueIndex,
+            1,
+            queuePriorities,
+    };
+
+    const VkDeviceCreateInfo deviceInfo = {
+            VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO,
+            mPhysicalDeviceFeatures2,
+            0,
+            1,
+            &queueInfo,
+            0,
+            nullptr,
+            (uint32_t)enabledDeviceExtensionNames.size(),
+            enabledDeviceExtensionNames.data(),
+            nullptr,
+    };
+
+    ALOGD("Trying to create Vk device with protectedContent=%d", protectedContent);
+    VkDevice device;
+    VK_CHECK(vkCreateDevice(physicalDevice, &deviceInfo, nullptr, &device));
+    ALOGD("Trying to create Vk device with protectedContent=%d (success)", protectedContent);
+
+    VkQueue graphicsQueue;
+    VK_GET_DEV_PROC(device, GetDeviceQueue2);
+    const VkDeviceQueueInfo2 deviceQueueInfo2 = {VK_STRUCTURE_TYPE_DEVICE_QUEUE_INFO_2, nullptr,
+                                                 deviceQueueCreateFlags,
+                                                 (uint32_t)graphicsQueueIndex, 0};
+    vkGetDeviceQueue2(device, &deviceQueueInfo2, &graphicsQueue);
+
+    VK_GET_DEV_PROC(device, DeviceWaitIdle);
+    VK_GET_DEV_PROC(device, DestroyDevice);
+    mFuncs.vkDeviceWaitIdle = vkDeviceWaitIdle;
+    mFuncs.vkDestroyDevice = vkDestroyDevice;
+
+    VK_GET_DEV_PROC(device, CreateSemaphore);
+    VK_GET_DEV_PROC(device, ImportSemaphoreFdKHR);
+    VK_GET_DEV_PROC(device, GetSemaphoreFdKHR);
+    VK_GET_DEV_PROC(device, DestroySemaphore);
+    mFuncs.vkCreateSemaphore = vkCreateSemaphore;
+    mFuncs.vkImportSemaphoreFdKHR = vkImportSemaphoreFdKHR;
+    mFuncs.vkGetSemaphoreFdKHR = vkGetSemaphoreFdKHR;
+    mFuncs.vkDestroySemaphore = vkDestroySemaphore;
+
+    // At this point, everything's succeeded and we can continue
+    mInitialized = true;
+    mInstance = instance;
+    mPhysicalDevice = physicalDevice;
+    mDevice = device;
+    mQueue = graphicsQueue;
+    mQueueIndex = graphicsQueueIndex;
+    mApiVersion = physDevProps.properties.apiVersion;
+    // grExtensions already constructed
+    // feature pointers already constructed
+    mGrGetProc = sGetProc;
+    mIsProtected = protectedContent;
+    // mIsRealtimePriority already set above while selecting the graphics queue
+    // funcs already initialized
+
+    const nsecs_t timeAfter = systemTime();
+    const float initTimeMs = static_cast<float>(timeAfter - timeBefore) / 1.0E6;
+    ALOGD("%s: Success init Vulkan interface in %f ms", __func__, initTimeMs);
+}
+
+// TODO: b/293371537 - Iterate on this.
+// Currently unused, but copied over from its original location for potential future use. This
+// should likely be improved to walk the pNext chain of mPhysicalDeviceFeatures2 and free everything
+// like HWUI's VulkanManager. mGrExtensions, mFuncs, mGrGetProc and mIsProtected are not reset.
+void VulkanInterface::teardown() {
+    mInitialized = false;
+
+    // Wait for the device to go idle before destroying it, then destroy the instance.
+    if (mDevice != VK_NULL_HANDLE) {
+        mFuncs.vkDeviceWaitIdle(mDevice);
+        mFuncs.vkDestroyDevice(mDevice, nullptr);
+        mDevice = VK_NULL_HANDLE;
+    }
+    if (mInstance != VK_NULL_HANDLE) {
+        mFuncs.vkDestroyInstance(mInstance, nullptr);
+        mInstance = VK_NULL_HANDLE;
+    }
+
+    // Feature structs are allocated with `new` by init(); deleting nullptr is a no-op.
+    delete mProtectedMemoryFeatures;
+    delete mSamplerYcbcrConversionFeatures;
+    delete mPhysicalDeviceFeatures2;
+    delete mDeviceFaultFeatures;
+    mSamplerYcbcrConversionFeatures = nullptr;
+    mPhysicalDeviceFeatures2 = nullptr;
+    mProtectedMemoryFeatures = nullptr;
+    mDeviceFaultFeatures = nullptr;
+
+    // Reset the rest of the state init() fills in, so a later init() starts clean (it appends
+    // to the extension-name lists and only ever sets mIsRealtimePriority to true).
+    mPhysicalDevice = VK_NULL_HANDLE;
+    mQueue = VK_NULL_HANDLE;
+    mQueueIndex = 0;
+    mApiVersion = 0;
+    mIsRealtimePriority = false;
+    mInstanceExtensionNames.clear();
+    mDeviceExtensionNames.clear();
+}
+
+} // namespace skia
+} // namespace renderengine
+} // namespace android
diff --git a/libs/renderengine/skia/VulkanInterface.h b/libs/renderengine/skia/VulkanInterface.h
new file mode 100644
index 0000000..c3936d9
--- /dev/null
+++ b/libs/renderengine/skia/VulkanInterface.h
@@ -0,0 +1,95 @@
+/*
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <include/gpu/vk/GrVkBackendContext.h>
+#include <include/gpu/vk/GrVkExtensions.h>
+
+#include <vulkan/vulkan.h>
+
+using namespace skgpu;
+
+namespace android {
+namespace renderengine {
+namespace skia {
+
+class VulkanInterface {
+public:
+    // Create an uninitialized interface. Initialize with `init`.
+    VulkanInterface() = default;
+    ~VulkanInterface() = default;
+    VulkanInterface(const VulkanInterface&) = delete;
+    VulkanInterface& operator=(const VulkanInterface&) = delete;
+    VulkanInterface& operator=(VulkanInterface&&) = delete;
+
+    void init(bool protectedContent = false); // On failure, isInitialized() stays false.
+    void teardown(); // Currently unused.
+
+    // TODO: b/293371537 - Graphite variant (external/skia/include/gpu/vk/VulkanBackendContext.h)
+    GrVkBackendContext getBackendContext();
+    VkSemaphore createExportableSemaphore(); // VK_NULL_HANDLE on failure.
+    VkSemaphore importSemaphoreFromSyncFd(int syncFd); // syncFd > 0; VK_NULL_HANDLE on failure.
+    int exportSemaphoreSyncFd(VkSemaphore semaphore); // -1 on failure.
+    void destroySemaphore(VkSemaphore semaphore);
+
+    bool isInitialized() const { return mInitialized; }
+    bool isRealtimePriority() const { return mIsRealtimePriority; }
+    const std::vector<std::string>& getInstanceExtensionNames() { return mInstanceExtensionNames; }
+    const std::vector<std::string>& getDeviceExtensionNames() { return mDeviceExtensionNames; }
+
+private:
+    struct VulkanFuncs {
+        PFN_vkCreateSemaphore vkCreateSemaphore = nullptr;
+        PFN_vkImportSemaphoreFdKHR vkImportSemaphoreFdKHR = nullptr;
+        PFN_vkGetSemaphoreFdKHR vkGetSemaphoreFdKHR = nullptr;
+        PFN_vkDestroySemaphore vkDestroySemaphore = nullptr;
+
+        PFN_vkDeviceWaitIdle vkDeviceWaitIdle = nullptr;
+        PFN_vkDestroyDevice vkDestroyDevice = nullptr;
+        PFN_vkDestroyInstance vkDestroyInstance = nullptr;
+    };
+
+    static void onVkDeviceFault(void* callbackContext, const std::string& description,
+                                const std::vector<VkDeviceFaultAddressInfoEXT>& addressInfos,
+                                const std::vector<VkDeviceFaultVendorInfoEXT>& vendorInfos,
+                                const std::vector<std::byte>& vendorBinaryData);
+
+    bool mInitialized = false;
+    VkInstance mInstance = VK_NULL_HANDLE;
+    VkPhysicalDevice mPhysicalDevice = VK_NULL_HANDLE;
+    VkDevice mDevice = VK_NULL_HANDLE;
+    VkQueue mQueue = VK_NULL_HANDLE;
+    int mQueueIndex = 0;
+    uint32_t mApiVersion = 0;
+    GrVkExtensions mGrExtensions;
+    VkPhysicalDeviceFeatures2* mPhysicalDeviceFeatures2 = nullptr; // Owned; freed by teardown().
+    VkPhysicalDeviceSamplerYcbcrConversionFeatures* mSamplerYcbcrConversionFeatures = nullptr;
+    VkPhysicalDeviceProtectedMemoryFeatures* mProtectedMemoryFeatures = nullptr;
+    VkPhysicalDeviceFaultFeaturesEXT* mDeviceFaultFeatures = nullptr;
+    GrVkGetProc mGrGetProc = nullptr;
+    bool mIsProtected = false;
+    bool mIsRealtimePriority = false; // Set by init() when the queue gets REALTIME priority.
+
+    VulkanFuncs mFuncs;
+
+    std::vector<std::string> mInstanceExtensionNames;
+    std::vector<std::string> mDeviceExtensionNames;
+};
+
+} // namespace skia
+} // namespace renderengine
+} // namespace android
diff --git a/libs/tracing_perfetto/.clang-format b/libs/tracing_perfetto/.clang-format
new file mode 100644
index 0000000..f397454
--- /dev/null
+++ b/libs/tracing_perfetto/.clang-format
@@ -0,0 +1,12 @@
+BasedOnStyle: Google
+AllowShortBlocksOnASingleLine: false
+AllowShortFunctionsOnASingleLine: false
+
+ColumnLimit: 80
+ContinuationIndentWidth: 4
+CommentPragmas: NOLINT:.*
+DerivePointerAlignment: false
+IndentWidth: 2
+PointerAlignment: Left
+UseTab: Never
+PenaltyExcessCharacter: 32
\ No newline at end of file
diff --git a/libs/tracing_perfetto/Android.bp b/libs/tracing_perfetto/Android.bp
new file mode 100644
index 0000000..3a4c869
--- /dev/null
+++ b/libs/tracing_perfetto/Android.bp
@@ -0,0 +1,49 @@
+// Copyright 2024 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package {
+    // See: http://go/android-license-faq
+    // A large-scale-change added 'default_applicable_licenses' to import
+    // all of the 'license_kinds' from "frameworks_native_license"
+    // to get the below license kinds:
+    //   SPDX-license-identifier-Apache-2.0
+    default_applicable_licenses: ["frameworks_native_license"],
+}
+
+cc_library_shared {
+    name: "libtracing_perfetto",
+    export_include_dirs: [
+        "include",
+    ],
+
+    cflags: [
+        "-Wall",
+        "-Werror",
+        "-Wno-enum-compare",
+        "-Wno-unused-function",
+    ],
+
+    srcs: [
+        "tracing_perfetto.cpp",
+        "tracing_perfetto_internal.cpp",
+    ],
+
+    shared_libs: [
+        "libcutils",
+        "libperfetto_c",
+        "android.os.flags-aconfig-cc-host",
+    ],
+
+    host_supported: true,
+}
diff --git a/libs/tracing_perfetto/OWNERS b/libs/tracing_perfetto/OWNERS
new file mode 100644
index 0000000..e2d4b46
--- /dev/null
+++ b/libs/tracing_perfetto/OWNERS
@@ -0,0 +1,2 @@
+zezeozue@google.com
+biswarupp@google.com
\ No newline at end of file
diff --git a/libs/tracing_perfetto/include/trace_categories.h b/libs/tracing_perfetto/include/trace_categories.h
new file mode 100644
index 0000000..6d4168b
--- /dev/null
+++ b/libs/tracing_perfetto/include/trace_categories.h
@@ -0,0 +1,65 @@
+/*
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TRACE_CATEGORIES_H
+#define TRACE_CATEGORIES_H
+
+/**
+ * Keep these in sync with frameworks/base/core/java/android/os/Trace.java.
+ */
+#define TRACE_CATEGORY_ALWAYS (1 << 0)
+#define TRACE_CATEGORY_GRAPHICS (1 << 1)
+#define TRACE_CATEGORY_INPUT (1 << 2)
+#define TRACE_CATEGORY_VIEW (1 << 3)
+#define TRACE_CATEGORY_WEBVIEW (1 << 4)
+#define TRACE_CATEGORY_WINDOW_MANAGER (1 << 5)
+#define TRACE_CATEGORY_ACTIVITY_MANAGER (1 << 6)
+#define TRACE_CATEGORY_SYNC_MANAGER (1 << 7)
+#define TRACE_CATEGORY_AUDIO (1 << 8)
+#define TRACE_CATEGORY_VIDEO (1 << 9)
+#define TRACE_CATEGORY_CAMERA (1 << 10)
+#define TRACE_CATEGORY_HAL (1 << 11)
+#define TRACE_CATEGORY_APP (1 << 12)
+#define TRACE_CATEGORY_RESOURCES (1 << 13)
+#define TRACE_CATEGORY_DALVIK (1 << 14)
+#define TRACE_CATEGORY_RS (1 << 15)
+#define TRACE_CATEGORY_BIONIC (1 << 16)
+#define TRACE_CATEGORY_POWER (1 << 17)
+#define TRACE_CATEGORY_PACKAGE_MANAGER (1 << 18)
+#define TRACE_CATEGORY_SYSTEM_SERVER (1 << 19)
+#define TRACE_CATEGORY_DATABASE (1 << 20)
+#define TRACE_CATEGORY_NETWORK (1 << 21)
+#define TRACE_CATEGORY_ADB (1 << 22)
+#define TRACE_CATEGORY_VIBRATOR (1 << 23)
+#define TRACE_CATEGORY_AIDL (1 << 24)
+#define TRACE_CATEGORY_NNAPI (1 << 25)
+#define TRACE_CATEGORY_RRO (1 << 26)
+#define TRACE_CATEGORY_THERMAL (1 << 27)
+
+// All categories except TRACE_CATEGORY_APP; parenthesized for safe expansion.
+#define TRACE_CATEGORIES                                                      \
+  (TRACE_CATEGORY_ALWAYS | TRACE_CATEGORY_GRAPHICS | TRACE_CATEGORY_INPUT |   \
+   TRACE_CATEGORY_VIEW | TRACE_CATEGORY_WEBVIEW |                             \
+   TRACE_CATEGORY_WINDOW_MANAGER | TRACE_CATEGORY_ACTIVITY_MANAGER |          \
+   TRACE_CATEGORY_SYNC_MANAGER | TRACE_CATEGORY_AUDIO |                       \
+   TRACE_CATEGORY_VIDEO | TRACE_CATEGORY_CAMERA | TRACE_CATEGORY_HAL |        \
+   TRACE_CATEGORY_RESOURCES | TRACE_CATEGORY_DALVIK | TRACE_CATEGORY_RS |     \
+   TRACE_CATEGORY_BIONIC | TRACE_CATEGORY_POWER |                             \
+   TRACE_CATEGORY_PACKAGE_MANAGER | TRACE_CATEGORY_SYSTEM_SERVER |            \
+   TRACE_CATEGORY_DATABASE | TRACE_CATEGORY_NETWORK | TRACE_CATEGORY_ADB |    \
+   TRACE_CATEGORY_VIBRATOR | TRACE_CATEGORY_AIDL | TRACE_CATEGORY_NNAPI |     \
+   TRACE_CATEGORY_RRO | TRACE_CATEGORY_THERMAL)
+#endif  // TRACE_CATEGORIES_H
diff --git a/libs/tracing_perfetto/include/trace_result.h b/libs/tracing_perfetto/include/trace_result.h
new file mode 100644
index 0000000..f7581fc
--- /dev/null
+++ b/libs/tracing_perfetto/include/trace_result.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TRACE_RESULT_H
+#define TRACE_RESULT_H
+
+namespace tracing_perfetto {
+
+enum class Result {  // Returned by the trace* functions of tracing_perfetto.h.
+  SUCCESS,
+  NOT_SUPPORTED,
+  INVALID_INPUT,
+};
+
+}  // namespace tracing_perfetto
+
+#endif  // TRACE_RESULT_H
diff --git a/libs/tracing_perfetto/include/tracing_perfetto.h b/libs/tracing_perfetto/include/tracing_perfetto.h
new file mode 100644
index 0000000..4e3c83f
--- /dev/null
+++ b/libs/tracing_perfetto/include/tracing_perfetto.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TRACING_PERFETTO_H
+#define TRACING_PERFETTO_H
+
+#include <stdint.h>
+
+#include "trace_result.h"
+
+namespace tracing_perfetto {
+
+void registerWithPerfetto(bool test = false);
+
+Result traceBegin(uint64_t category, const char* name);
+
+Result traceEnd(uint64_t category);
+
+Result traceAsyncBegin(uint64_t category, const char* name, int32_t cookie);
+
+Result traceAsyncEnd(uint64_t category, const char* name, int32_t cookie);
+
+Result traceAsyncBeginForTrack(uint64_t category, const char* name,
+                               const char* trackName, int32_t cookie);
+
+Result traceAsyncEndForTrack(uint64_t category, const char* trackName,
+                             int32_t cookie);
+
+Result traceInstant(uint64_t category, const char* name);
+
+Result traceInstantForTrack(uint64_t category, const char* trackName,
+                            const char* name);
+
+Result traceCounter(uint64_t category, const char* name, int64_t value);
+
+uint64_t getEnabledCategories();
+
+}  // namespace tracing_perfetto
+
+#endif  // TRACING_PERFETTO_H
diff --git a/libs/tracing_perfetto/tests/Android.bp b/libs/tracing_perfetto/tests/Android.bp
new file mode 100644
index 0000000..a35b0e0
--- /dev/null
+++ b/libs/tracing_perfetto/tests/Android.bp
@@ -0,0 +1,45 @@
+// Copyright 2024 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package {
+    // See: http://go/android-license-faq
+    // A large-scale-change added 'default_applicable_licenses' to import
+    // all of the 'license_kinds' from "frameworks_native_license"
+    // to get the below license kinds:
+    //   SPDX-license-identifier-Apache-2.0
+    default_applicable_licenses: ["frameworks_native_license"],
+}
+
+cc_test {
+    name: "libtracing_perfetto_tests",
+    static_libs: [
+        "libflagtest",
+        "libgmock",
+    ],
+    cflags: [
+        "-Wall",
+        "-Werror",
+    ],
+    shared_libs: [
+        "android.os.flags-aconfig-cc-host",
+        "libbase",
+        "libperfetto_c",
+        "libtracing_perfetto",
+    ],
+    srcs: [
+        "tracing_perfetto_test.cpp",
+        "utils.cpp",
+    ],
+    test_suites: ["device-tests"],
+}
diff --git a/libs/tracing_perfetto/tests/tracing_perfetto_test.cpp b/libs/tracing_perfetto/tests/tracing_perfetto_test.cpp
new file mode 100644
index 0000000..7716b9a
--- /dev/null
+++ b/libs/tracing_perfetto/tests/tracing_perfetto_test.cpp
@@ -0,0 +1,111 @@
+/*
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tracing_perfetto.h"
+
+#include <thread>
+
+#include <android_os.h>
+#include <flag_macros.h>
+
+#include "gtest/gtest.h"
+#include "perfetto/public/abi/data_source_abi.h"
+#include "perfetto/public/abi/heap_buffer.h"
+#include "perfetto/public/abi/pb_decoder_abi.h"
+#include "perfetto/public/abi/tracing_session_abi.h"
+#include "perfetto/public/abi/track_event_abi.h"
+#include "perfetto/public/data_source.h"
+#include "perfetto/public/pb_decoder.h"
+#include "perfetto/public/producer.h"
+#include "perfetto/public/protos/config/trace_config.pzc.h"
+#include "perfetto/public/protos/trace/interned_data/interned_data.pzc.h"
+#include "perfetto/public/protos/trace/test_event.pzc.h"
+#include "perfetto/public/protos/trace/trace.pzc.h"
+#include "perfetto/public/protos/trace/trace_packet.pzc.h"
+#include "perfetto/public/protos/trace/track_event/debug_annotation.pzc.h"
+#include "perfetto/public/protos/trace/track_event/track_descriptor.pzc.h"
+#include "perfetto/public/protos/trace/track_event/track_event.pzc.h"
+#include "perfetto/public/protos/trace/trigger.pzc.h"
+#include "perfetto/public/te_category_macros.h"
+#include "perfetto/public/te_macros.h"
+#include "perfetto/public/track_event.h"
+#include "trace_categories.h"
+#include "utils.h"
+
+namespace tracing_perfetto {
+
+using ::perfetto::shlib::test_utils::AllFieldsWithId;
+using ::perfetto::shlib::test_utils::FieldView;
+using ::perfetto::shlib::test_utils::IdFieldView;
+using ::perfetto::shlib::test_utils::MsgField;
+using ::perfetto::shlib::test_utils::PbField;
+using ::perfetto::shlib::test_utils::StringField;
+using ::perfetto::shlib::test_utils::TracingSession;
+using ::perfetto::shlib::test_utils::VarIntField;
+using ::testing::_;
+using ::testing::ElementsAre;
+using ::testing::UnorderedElementsAre;
+
+const auto PERFETTO_SDK_TRACING = ACONFIG_FLAG(android::os, perfetto_sdk_tracing);
+
+class TracingPerfettoTest : public ::testing::Test {
+ protected:
+  void SetUp() override {  // Runs before every test in this fixture.
+    registerWithPerfetto(/*test=*/true);
+  }
+};
+
+// TODO(b/303199244): Add tests for all the library functions.
+
+TEST_F_WITH_FLAGS(TracingPerfettoTest, traceInstant,
+                  REQUIRES_FLAGS_ENABLED(PERFETTO_SDK_TRACING)) {
+  TracingSession tracing_session =
+      TracingSession::Builder().set_data_source_name("track_event").Build();
+  ASSERT_EQ(traceInstant(TRACE_CATEGORY_INPUT, ""), Result::SUCCESS);
+  // Stop tracing; the emitted event must use the interned "input" category.
+  tracing_session.StopBlocking();
+  std::vector<uint8_t> data = tracing_session.ReadBlocking();
+  bool found = false;  // Set once a packet carrying a track_event is seen.
+  for (struct PerfettoPbDecoderField trace_field : FieldView(data)) {
+    ASSERT_THAT(trace_field, PbField(perfetto_protos_Trace_packet_field_number,
+                                     MsgField(_)));
+    IdFieldView track_event(
+        trace_field, perfetto_protos_TracePacket_track_event_field_number);
+    if (track_event.size() == 0) {
+      continue;
+    }
+    found = true;
+    IdFieldView cat_iid_fields(
+        track_event.front(),
+        perfetto_protos_TrackEvent_category_iids_field_number);
+    ASSERT_THAT(cat_iid_fields, ElementsAre(VarIntField(_)));
+    uint64_t cat_iid = cat_iid_fields.front().value.integer64;
+    EXPECT_THAT(
+        trace_field,
+        AllFieldsWithId(
+            perfetto_protos_TracePacket_interned_data_field_number,
+            ElementsAre(AllFieldsWithId(
+                perfetto_protos_InternedData_event_categories_field_number,
+                ElementsAre(MsgField(UnorderedElementsAre(
+                    PbField(perfetto_protos_EventCategory_iid_field_number,
+                            VarIntField(cat_iid)),
+                    PbField(perfetto_protos_EventCategory_name_field_number,
+                            StringField("input")))))))));
+  }
+  EXPECT_TRUE(found);
+}
+
+}  // namespace tracing_perfetto
\ No newline at end of file
diff --git a/libs/tracing_perfetto/tests/utils.cpp b/libs/tracing_perfetto/tests/utils.cpp
new file mode 100644
index 0000000..9c42028
--- /dev/null
+++ b/libs/tracing_perfetto/tests/utils.cpp
@@ -0,0 +1,219 @@
+/*
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Copied from //external/perfetto/src/shared_lib/test/utils.cc
+
+#include "utils.h"
+
+#include "perfetto/public/abi/heap_buffer.h"
+#include "perfetto/public/pb_msg.h"
+#include "perfetto/public/pb_utils.h"
+#include "perfetto/public/protos/config/data_source_config.pzc.h"
+#include "perfetto/public/protos/config/trace_config.pzc.h"
+#include "perfetto/public/protos/config/track_event/track_event_config.pzc.h"
+#include "perfetto/public/tracing_session.h"
+
+namespace perfetto {
+namespace shlib {
+namespace test_utils {
+namespace {
+
+std::string ToHexChars(uint8_t val) {
+  // Two uppercase hex digits, high nibble first (e.g. 0x3A -> "3A").
+  static const char kDigits[] = "0123456789ABCDEF";
+  // Each nibble is a value in [0, 15] and indexes kDigits directly.
+  std::string out(2, '0');
+  out[0] = kDigits[(val >> 4) & 0xF];
+  out[1] = kDigits[val & 0xF];
+  return out;
+}
+
+}  // namespace
+
+TracingSession TracingSession::Builder::Build() {
+  struct PerfettoPbMsgWriter writer;
+  struct PerfettoHeapBuffer* hb = PerfettoHeapBufferCreate(&writer.writer);
+  // Serialize a TraceConfig into |hb|, then start a session configured by it.
+  struct perfetto_protos_TraceConfig cfg;
+  PerfettoPbMsgInit(&cfg.msg, &writer);
+  // A single trace buffer of 1024 KB.
+  {
+    struct perfetto_protos_TraceConfig_BufferConfig buffers;
+    perfetto_protos_TraceConfig_begin_buffers(&cfg, &buffers);
+
+    perfetto_protos_TraceConfig_BufferConfig_set_size_kb(&buffers, 1024);
+
+    perfetto_protos_TraceConfig_end_buffers(&cfg, &buffers);
+  }
+  // One data source; category filters, if any, go into its TrackEventConfig.
+  {
+    struct perfetto_protos_TraceConfig_DataSource data_sources;
+    perfetto_protos_TraceConfig_begin_data_sources(&cfg, &data_sources);
+
+    {
+      struct perfetto_protos_DataSourceConfig ds_cfg;
+      perfetto_protos_TraceConfig_DataSource_begin_config(&data_sources,
+                                                          &ds_cfg);
+
+      perfetto_protos_DataSourceConfig_set_cstr_name(&ds_cfg,
+                                                     data_source_name_.c_str());
+      if (!enabled_categories_.empty() || !disabled_categories_.empty()) {
+        perfetto_protos_TrackEventConfig te_cfg;  // Either list suffices.
+        perfetto_protos_DataSourceConfig_begin_track_event_config(&ds_cfg,
+                                                                  &te_cfg);
+        for (const std::string& cat : enabled_categories_) {
+          perfetto_protos_TrackEventConfig_set_enabled_categories(
+              &te_cfg, cat.data(), cat.size());
+        }
+        for (const std::string& cat : disabled_categories_) {
+          perfetto_protos_TrackEventConfig_set_disabled_categories(
+              &te_cfg, cat.data(), cat.size());
+        }
+        perfetto_protos_DataSourceConfig_end_track_event_config(&ds_cfg,
+                                                                &te_cfg);
+      }
+
+      perfetto_protos_TraceConfig_DataSource_end_config(&data_sources, &ds_cfg);
+    }
+
+    perfetto_protos_TraceConfig_end_data_sources(&cfg, &data_sources);
+  }
+  size_t cfg_size = PerfettoStreamWriterGetWrittenSize(&writer.writer);
+  std::unique_ptr<uint8_t[]> ser(new uint8_t[cfg_size]);
+  PerfettoHeapBufferCopyInto(hb, &writer.writer, ser.get(), cfg_size);
+  PerfettoHeapBufferDestroy(hb, &writer.writer);
+  // Create an in-process session, apply the config, and start it.
+  struct PerfettoTracingSessionImpl* ts =
+      PerfettoTracingSessionCreate(PERFETTO_BACKEND_IN_PROCESS);
+
+  PerfettoTracingSessionSetup(ts, ser.get(), cfg_size);
+
+  PerfettoTracingSessionStartBlocking(ts);
+
+  return TracingSession::Adopt(ts);
+}
+
+TracingSession TracingSession::Adopt(struct PerfettoTracingSessionImpl* session) {
+  // Wrap |session| and have its stop callback notify |stopped_|.
+  TracingSession adopted;
+  adopted.session_ = session;
+  adopted.stopped_ = std::make_unique<WaitableEvent>();
+  WaitableEvent* const stop_event = adopted.stopped_.get();
+  auto on_stop = [](struct PerfettoTracingSessionImpl*, void* user_arg) {
+    static_cast<WaitableEvent*>(user_arg)->Notify();
+  };
+  PerfettoTracingSessionSetStopCb(session, on_stop, stop_event);
+  return adopted;
+}
+
+TracingSession::TracingSession(TracingSession&& other) noexcept
+    : session_(other.session_), stopped_(std::move(other.stopped_)) {
+  // Leave |other| empty so that its destructor does nothing.
+  other.session_ = nullptr;
+  other.stopped_.reset();
+}
+
+TracingSession::~TracingSession() {
+  // A moved-from instance has a null session and owns nothing.
+  if (session_ != nullptr) {
+    if (!stopped_->IsNotified()) {
+      PerfettoTracingSessionStopBlocking(session_);
+      stopped_->WaitForNotification();  // Wait for the stop callback.
+    }
+    PerfettoTracingSessionDestroy(session_);
+  }
+}
+
+bool TracingSession::FlushBlocking(uint32_t timeout_ms) {
+  // The callback is heap-allocated and freed by the C trampoline below.
+  WaitableEvent done;
+  bool flush_ok = false;
+  auto* on_flush = new std::function<void(bool)>([&](bool success) {
+    flush_ok = success;
+    done.Notify();
+  });
+  auto trampoline = [](PerfettoTracingSessionImpl*, bool success,
+                       void* user_arg) {
+    auto* callback = static_cast<std::function<void(bool)>*>(user_arg);
+    (*callback)(success);
+    delete callback;
+  };
+  PerfettoTracingSessionFlushAsync(session_, timeout_ms, trampoline, on_flush);
+  done.WaitForNotification();
+  return flush_ok;
+}
+
+void TracingSession::WaitForStopped() {
+  stopped_->WaitForNotification();  // Returns once the stop callback fired.
+}
+
+void TracingSession::StopBlocking() {
+  PerfettoTracingSessionStopBlocking(session_);  // Does not wait on stopped_.
+}
+
+std::vector<uint8_t> TracingSession::ReadBlocking() {
+  // Accumulates every chunk handed to the read callback into one buffer.
+  std::vector<uint8_t> trace;
+  auto append_chunk = [](struct PerfettoTracingSessionImpl*,
+                         const void* chunk, size_t chunk_size, bool,
+                         void* user_arg) {
+    auto* out = static_cast<std::vector<uint8_t>*>(user_arg);
+    const auto* bytes = static_cast<const uint8_t*>(chunk);
+    out->insert(out->end(), bytes, bytes + chunk_size);
+  };
+  PerfettoTracingSessionReadTraceBlocking(session_, append_chunk, &trace);
+  return trace;
+}
+
+}  // namespace test_utils
+}  // namespace shlib
+}  // namespace perfetto
+
+void PrintTo(const PerfettoPbDecoderField& field, std::ostream* pos) {
+  // gtest pretty printer: non-OK statuses print a label, OK fields print
+  // their wire type and value (delimited payloads as quoted hex bytes).
+  std::ostream& out = *pos;
+  const auto status = static_cast<PerfettoPbDecoderStatus>(field.status);
+  if (status != PERFETTO_PB_DECODER_OK) {
+    if (status == PERFETTO_PB_DECODER_ERROR) {
+      out << "MALFORMED PROTOBUF";
+    } else if (status == PERFETTO_PB_DECODER_DONE) {
+      out << "DECODER DONE";
+    }
+    return;
+  }
+  switch (field.wire_type) {
+    case PERFETTO_PB_WIRE_TYPE_DELIMITED: {
+      // Each payload byte as two hex digits plus a trailing space.
+      const auto* const bytes = field.value.delimited.start;
+      out << "\"";
+      for (size_t i = 0; i < field.value.delimited.len; ++i) {
+        out << perfetto::shlib::test_utils::ToHexChars(bytes[i]) << " ";
+      }
+      out << "\"";
+      break;
+    }
+    case PERFETTO_PB_WIRE_TYPE_VARINT:
+      out << "varint: " << field.value.integer64;
+      break;
+    case PERFETTO_PB_WIRE_TYPE_FIXED32:
+      out << "fixed32: " << field.value.integer32;
+      break;
+    case PERFETTO_PB_WIRE_TYPE_FIXED64:
+      out << "fixed64: " << field.value.integer64;
+      break;
+  }
+}
diff --git a/libs/tracing_perfetto/tests/utils.h b/libs/tracing_perfetto/tests/utils.h
new file mode 100644
index 0000000..4353554
--- /dev/null
+++ b/libs/tracing_perfetto/tests/utils.h
@@ -0,0 +1,452 @@
+/*
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Copied from //external/perfetto/src/shared_lib/test/utils.h
+
+#ifndef UTILS_H
+#define UTILS_H
+
+#include <cassert>
+#include <condition_variable>
+#include <cstdint>
+#include <functional>
+#include <iterator>
+#include <memory>
+#include <mutex>
+#include <ostream>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock-matchers.h"
+#include "gmock/gmock-more-matchers.h"
+#include "gtest/gtest-matchers.h"
+#include "gtest/gtest.h"
+#include "perfetto/public/abi/pb_decoder_abi.h"
+#include "perfetto/public/pb_utils.h"
+#include "perfetto/public/tracing_session.h"
+
+// Pretty printer for gtest
+void PrintTo(const PerfettoPbDecoderField& field, std::ostream*);
+
+namespace perfetto {
+namespace shlib {
+namespace test_utils {
+
+class WaitableEvent {  // One-shot event: once notified, it stays notified.
+ public:
+  WaitableEvent() = default;
+  void Notify() {  // Sets the flag and wakes all current waiters.
+    std::unique_lock<std::mutex> lock(m_);
+    notified_ = true;
+    cv_.notify_all();  // notify_one() could leave a second waiter asleep.
+  }
+  bool WaitForNotification() {  // Blocks until notified; returns true.
+    std::unique_lock<std::mutex> lock(m_);
+    cv_.wait(lock, [this] { return notified_; });
+    return notified_;
+  }
+  bool IsNotified() {  // Non-blocking read of the notified flag.
+    std::unique_lock<std::mutex> lock(m_);
+    return notified_;
+  }
+
+ private:
+  std::mutex m_;
+  std::condition_variable cv_;  // Signaled by Notify() while holding m_.
+  bool notified_ = false;       // Guarded by m_.
+};
+
+class TracingSession {
+ public:
+  class Builder {
+   public:
+    Builder() = default;
+    Builder& set_data_source_name(std::string data_source_name) {
+      data_source_name_ = std::move(data_source_name);
+      return *this;
+    }
+    Builder& add_enabled_category(std::string category) {
+      enabled_categories_.push_back(std::move(category));
+      return *this;
+    }
+    Builder& add_disabled_category(std::string category) {
+      disabled_categories_.push_back(std::move(category));
+      return *this;
+    }
+    TracingSession Build();  // Creates and starts an in-process session.
+
+   private:
+    std::string data_source_name_;
+    std::vector<std::string> enabled_categories_;
+    std::vector<std::string> disabled_categories_;
+  };
+  // Wraps an existing session and takes ownership of it.
+  static TracingSession Adopt(struct PerfettoTracingSessionImpl*);
+
+  TracingSession(TracingSession&&) noexcept;  // Leaves the source empty.
+
+  ~TracingSession();  // Stops the session if needed, then destroys it.
+
+  struct PerfettoTracingSessionImpl* session() const {
+    return session_;
+  }
+
+  bool FlushBlocking(uint32_t timeout_ms);  // Returns the flush success flag.
+  void WaitForStopped();
+  void StopBlocking();
+  std::vector<uint8_t> ReadBlocking();  // Whole trace as raw bytes.
+
+ private:
+  TracingSession() = default;  // Use Builder::Build() or Adopt().
+  struct PerfettoTracingSessionImpl* session_ = nullptr;  // Owned.
+  std::unique_ptr<WaitableEvent> stopped_;  // Notified by the stop callback.
+};
+
+template <typename FieldSkipper>
+class FieldViewBase {  // Forward view over a serialized message's fields.
+ public:
+  class Iterator {  // Input iterator; FieldSkipper filters out fields.
+   public:
+    using iterator_category = std::input_iterator_tag;
+    using value_type = const PerfettoPbDecoderField;
+    using pointer = value_type;
+    using reference = value_type;
+    reference operator*() const {  // Re-parses the current field.
+      struct PerfettoPbDecoder decoder;
+      decoder.read_ptr = read_ptr_;
+      decoder.end_ptr = end_ptr_;
+      struct PerfettoPbDecoderField field;
+      do {
+        field = PerfettoPbDecoderParseField(&decoder);
+      } while (field.status == PERFETTO_PB_DECODER_OK &&
+               skipper_.ShouldSkip(field));
+      return field;
+    }
+    Iterator& operator++() {  // Skip this field, then uninteresting ones.
+      struct PerfettoPbDecoder decoder;
+      decoder.read_ptr = read_ptr_;
+      decoder.end_ptr = end_ptr_;
+      PerfettoPbDecoderSkipField(&decoder);
+      read_ptr_ = decoder.read_ptr;
+      AdvanceToFirstInterestingField();
+      return *this;
+    }
+    Iterator operator++(int) {
+      Iterator tmp = *this;
+      ++(*this);
+      return tmp;
+    }
+    // Equality compares read positions only (end_ptr_ is not checked).
+    friend bool operator==(const Iterator& a, const Iterator& b) {
+      return a.read_ptr_ == b.read_ptr_;
+    }
+    friend bool operator!=(const Iterator& a, const Iterator& b) {
+      return a.read_ptr_ != b.read_ptr_;
+    }
+
+   private:
+    Iterator(const uint8_t* read_ptr, const uint8_t* end_ptr,
+             const FieldSkipper& skipper)
+        : read_ptr_(read_ptr), end_ptr_(end_ptr), skipper_(skipper) {
+      AdvanceToFirstInterestingField();
+    }
+    void AdvanceToFirstInterestingField() {
+      struct PerfettoPbDecoder decoder;
+      decoder.read_ptr = read_ptr_;
+      decoder.end_ptr = end_ptr_;
+      struct PerfettoPbDecoderField field;
+      const uint8_t* prev_read_ptr;  // Start of the last parsed field.
+      do {
+        prev_read_ptr = decoder.read_ptr;
+        field = PerfettoPbDecoderParseField(&decoder);
+      } while (field.status == PERFETTO_PB_DECODER_OK &&
+               skipper_.ShouldSkip(field));
+      if (field.status == PERFETTO_PB_DECODER_OK) {
+        read_ptr_ = prev_read_ptr;  // Rewind to the kept field's start.
+      } else {
+        read_ptr_ = decoder.read_ptr;  // DONE/ERROR: where decoding stopped.
+      }
+    }
+    friend class FieldViewBase<FieldSkipper>;
+    const uint8_t* read_ptr_;
+    const uint8_t* end_ptr_;
+    const FieldSkipper& skipper_;  // View's s_; view must outlive this.
+  };
+  using value_type = const PerfettoPbDecoderField;
+  using const_iterator = Iterator;
+  template <typename... Args>
+  explicit FieldViewBase(const uint8_t* begin, const uint8_t* end, Args... args)
+      : begin_(begin), end_(end), s_(args...) {
+  }
+  template <typename... Args>
+  explicit FieldViewBase(const std::vector<uint8_t>& data, Args... args)
+      : FieldViewBase(data.data(), data.data() + data.size(), args...) {
+  }
+  template <typename... Args>
+  explicit FieldViewBase(const struct PerfettoPbDecoderField& field,
+                         Args... args)
+      : s_(args...) {
+    if (field.wire_type != PERFETTO_PB_WIRE_TYPE_DELIMITED) {
+      abort();  // Only length-delimited payloads can be viewed.
+    }
+    begin_ = field.value.delimited.start;
+    end_ = begin_ + field.value.delimited.len;
+  }
+  Iterator begin() const {
+    return Iterator(begin_, end_, s_);
+  }
+  Iterator end() const {
+    return Iterator(end_, end_, s_);
+  }
+  PerfettoPbDecoderField front() const {
+    return *begin();
+  }
+  // size() and ok() each walk the whole message (linear time).
+  size_t size() const {
+    size_t count = 0;
+    for (auto field : *this) {
+      (void)field;
+      count++;
+    }
+    return count;
+  }
+
+  bool ok() const {  // False if any visited field is not OK.
+    for (auto field : *this) {
+      if (field.status != PERFETTO_PB_DECODER_OK) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+ private:
+  const uint8_t* begin_;
+  const uint8_t* end_;
+  FieldSkipper s_;  // Decides which fields iteration visits.
+};
+
+// Pretty printer for gtest: "{", then each field followed by ", ", then "}".
+template <typename FieldSkipper>
+void PrintTo(const FieldViewBase<FieldSkipper>& field_view, std::ostream* pos) {
+  *pos << "{";
+  for (auto it = field_view.begin(), last = field_view.end(); it != last;
+       ++it) {
+    PrintTo(*it, pos);
+    *pos << ", ";
+  }
+  *pos << "}";
+}
+
+class IdFieldSkipper {
+ public:
+  explicit IdFieldSkipper(uint32_t id) : wanted_id_(id) {
+  }
+  explicit IdFieldSkipper(int32_t id) : wanted_id_(static_cast<uint32_t>(id)) {
+  }
+  // Skips every field whose id differs from the one given at construction.
+  bool ShouldSkip(const struct PerfettoPbDecoderField& field) const {
+    return !(field.id == wanted_id_);
+  }
+ private:
+  uint32_t wanted_id_;
+};
+
+class NoFieldSkipper {  // Skipper that keeps every field (used by FieldView).
+ public:
+  NoFieldSkipper() = default;
+  bool ShouldSkip(const struct PerfettoPbDecoderField&) const {
+    return false;  // Never skip anything.
+  }
+};
+
+// View over all the fields of a contiguous serialized protobuf message.
+//
+// Examples:
+// FieldView fields1(msg_begin, msg_end);
+// for (struct PerfettoPbDecoderField field : fields1) {
+//   //...
+// }
+// FieldView fields2(/*PerfettoPbDecoderField*/ nested_field);
+// FieldView fields3(/*std::vector<uint8_t>*/ data);
+// size_t num = fields1.size(); // The number of fields.
+// bool ok = fields1.ok(); // Checks that the message is not malformed.
+using FieldView = FieldViewBase<NoFieldSkipper>;
+
+// Like `FieldView`, but only considers fields with a specific id.
+//
+// Examples:
+//
+// IdFieldView fields(msg_begin, msg_end, id);  // Only fields numbered `id`.
+using IdFieldView = FieldViewBase<IdFieldSkipper>;
+
+// Matches a PerfettoPbDecoderField that decoded OK and has the given id; the
+// extra matcher |m| is applied to the same field.
+//
+// Example:
+// EXPECT_THAT(field, PbField(900, VarIntField(5)));
+template <typename M>
+auto PbField(int32_t id, M m) {
+  auto decoded_ok =
+      testing::Field(&PerfettoPbDecoderField::status, PERFETTO_PB_DECODER_OK);
+  auto has_id = testing::Field(&PerfettoPbDecoderField::id, id);
+  return testing::AllOf(decoded_ok, has_id, m);
+}
+
+// Matches a successfully decoded, length-delimited field whose payload,
+// viewed as a nested message (FieldView), satisfies container matcher |m|.
+// Example: EXPECT_THAT(field, MsgField(ElementsAre(...)));
+template <typename M>
+auto MsgField(M m) {
+  // FieldView(field) aborts unless |field| is length-delimited.
+  auto as_submessage = [](const PerfettoPbDecoderField& field) {
+    return FieldView(field);
+  };
+  return testing::AllOf(
+      testing::Field(&PerfettoPbDecoderField::status, PERFETTO_PB_DECODER_OK),
+      testing::Field(&PerfettoPbDecoderField::wire_type,
+                     PERFETTO_PB_WIRE_TYPE_DELIMITED),
+      testing::ResultOf(as_submessage, m));
+}
+
+// Matches a PerfettoPbDecoderField length delimited field. Accepts a string
+// matcher.
+//
+// Example:
+// PerfettoPbDecoderField field = ...
+// EXPECT_THAT(field, StringField("string"));
+template <typename M>
+auto StringField(M m) {
+  auto f = [](const PerfettoPbDecoderField& field) {
+    return std::string(
+        reinterpret_cast<const char*>(field.value.delimited.start),
+        field.value.delimited.len);
+  };
+  return testing::AllOf(
+      testing::Field(&PerfettoPbDecoderField::status, PERFETTO_PB_DECODER_OK),
+      testing::Field(&PerfettoPbDecoderField::wire_type,
+                     PERFETTO_PB_WIRE_TYPE_DELIMITED),
+      testing::ResultOf(f, m));
+}
+
+// Matches a PerfettoPbDecoderField VarInt field. Accepts an integer matcher.
+//
+// Example:
+// PerfettoPbDecoderField field = ...
+// EXPECT_THAT(field, VarIntField(1));
+template <typename M>
+auto VarIntField(M m) {
+  auto f = [](const PerfettoPbDecoderField& field) {
+    return field.value.integer64;
+  };
+  return testing::AllOf(
+      testing::Field(&PerfettoPbDecoderField::status, PERFETTO_PB_DECODER_OK),
+      testing::Field(&PerfettoPbDecoderField::wire_type,
+                     PERFETTO_PB_WIRE_TYPE_VARINT),
+      testing::ResultOf(f, m));
+}
+
+// Matches a PerfettoPbDecoderField fixed64 field. Accepts an integer matcher.
+//
+// Example:
+// PerfettoPbDecoderField field = ...
+// EXPECT_THAT(field, Fixed64Field(1));
+template <typename M>
+auto Fixed64Field(M m) {
+  auto f = [](const PerfettoPbDecoderField& field) {
+    return field.value.integer64;
+  };
+  return testing::AllOf(
+      testing::Field(&PerfettoPbDecoderField::status, PERFETTO_PB_DECODER_OK),
+      testing::Field(&PerfettoPbDecoderField::wire_type,
+                     PERFETTO_PB_WIRE_TYPE_FIXED64),
+      testing::ResultOf(f, m));
+}
+
+// Matches a PerfettoPbDecoderField fixed32 field. Accepts an integer matcher.
+//
+// Example:
+// PerfettoPbDecoderField field = ...
+// EXPECT_THAT(field, Fixed32Field(1));
+template <typename M>
+auto Fixed32Field(M m) {
+  auto f = [](const PerfettoPbDecoderField& field) {
+    return field.value.integer32;
+  };
+  return testing::AllOf(
+      testing::Field(&PerfettoPbDecoderField::status, PERFETTO_PB_DECODER_OK),
+      testing::Field(&PerfettoPbDecoderField::wire_type,
+                     PERFETTO_PB_WIRE_TYPE_FIXED32),
+      testing::ResultOf(f, m));
+}
+
+// Matches a PerfettoPbDecoderField double field. Accepts a double matcher.
+//
+// Example:
+// PerfettoPbDecoderField field = ...
+// EXPECT_THAT(field, DoubleField(1.0));
+template <typename M>
+auto DoubleField(M m) {
+  auto f = [](const PerfettoPbDecoderField& field) {
+    return field.value.double_val;
+  };
+  return testing::AllOf(
+      testing::Field(&PerfettoPbDecoderField::status, PERFETTO_PB_DECODER_OK),
+      testing::Field(&PerfettoPbDecoderField::wire_type,
+                     PERFETTO_PB_WIRE_TYPE_FIXED64),
+      testing::ResultOf(f, m));
+}
+
+// Matches a PerfettoPbDecoderField float field. Accepts a float matcher.
+//
+// Example:
+// PerfettoPbDecoderField field = ...
+// EXPECT_THAT(field, FloatField(1.0));
+template <typename M>
+auto FloatField(M m) {
+  auto f = [](const PerfettoPbDecoderField& field) {
+    return field.value.float_val;
+  };
+  return testing::AllOf(
+      testing::Field(&PerfettoPbDecoderField::status, PERFETTO_PB_DECODER_OK),
+      testing::Field(&PerfettoPbDecoderField::wire_type,
+                     PERFETTO_PB_WIRE_TYPE_FIXED32),
+      testing::ResultOf(f, m));
+}
+
+// Matches a PerfettoPbDecoderField submessage field, considering only the subfields
+// with the specified id. Accepts a container matcher for those subfields.
+//
+// Example:
+// PerfettoPbDecoderField field = ...
+// EXPECT_THAT(field, AllFieldsWithId(900, ElementsAre(...)));
+template <typename M>
+auto AllFieldsWithId(int32_t id, M m) {
+  auto f = [id](const PerfettoPbDecoderField& field) {
+    return IdFieldView(field, id);
+  };
+  return testing::AllOf(
+      testing::Field(&PerfettoPbDecoderField::status, PERFETTO_PB_DECODER_OK),
+      testing::Field(&PerfettoPbDecoderField::wire_type,
+                     PERFETTO_PB_WIRE_TYPE_DELIMITED),
+      testing::ResultOf(f, m));
+}
+
+}  // namespace test_utils
+}  // namespace shlib
+}  // namespace perfetto
+
+#endif  // UTILS_H
diff --git a/libs/tracing_perfetto/tracing_perfetto.cpp b/libs/tracing_perfetto/tracing_perfetto.cpp
new file mode 100644
index 0000000..c7fb8bd
--- /dev/null
+++ b/libs/tracing_perfetto/tracing_perfetto.cpp
@@ -0,0 +1,122 @@
+/*
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tracing_perfetto.h"
+
+#include <cutils/trace.h>
+
+#include "perfetto/public/te_category_macros.h"
+#include "trace_categories.h"
+#include "tracing_perfetto_internal.h"
+
+namespace tracing_perfetto {
+
+// Every trace* entry point below asks internal::toPerfettoCategory() for a Perfetto
+// category. If one is returned, the event is emitted through the Perfetto SDK and that
+// call's Result is returned; if nullptr is returned, the event is sent to the matching
+// atrace_* function instead and Result::SUCCESS is reported.
+
+void registerWithPerfetto(bool test) {
+  internal::registerWithPerfetto(test);
+}
+
+Result traceBegin(uint64_t category, const char* name) {
+  if (auto* perfettoCategory = internal::toPerfettoCategory(category)) {
+    return internal::perfettoTraceBegin(*perfettoCategory, name);
+  }
+  atrace_begin(category, name);
+  return Result::SUCCESS;
+}
+
+Result traceEnd(uint64_t category) {
+  if (auto* perfettoCategory = internal::toPerfettoCategory(category)) {
+    return internal::perfettoTraceEnd(*perfettoCategory);
+  }
+  atrace_end(category);
+  return Result::SUCCESS;
+}
+
+Result traceAsyncBegin(uint64_t category, const char* name, int32_t cookie) {
+  if (auto* perfettoCategory = internal::toPerfettoCategory(category)) {
+    return internal::perfettoTraceAsyncBegin(*perfettoCategory, name, cookie);
+  }
+  atrace_async_begin(category, name, cookie);
+  return Result::SUCCESS;
+}
+
+Result traceAsyncEnd(uint64_t category, const char* name, int32_t cookie) {
+  if (auto* perfettoCategory = internal::toPerfettoCategory(category)) {
+    return internal::perfettoTraceAsyncEnd(*perfettoCategory, name, cookie);
+  }
+  atrace_async_end(category, name, cookie);
+  return Result::SUCCESS;
+}
+
+Result traceAsyncBeginForTrack(uint64_t category, const char* name,
+                               const char* trackName, int32_t cookie) {
+  if (auto* perfettoCategory = internal::toPerfettoCategory(category)) {
+    return internal::perfettoTraceAsyncBeginForTrack(*perfettoCategory, name,
+                                                     trackName, cookie);
+  }
+  atrace_async_for_track_begin(category, trackName, name, cookie);
+  return Result::SUCCESS;
+}
+
+Result traceAsyncEndForTrack(uint64_t category, const char* trackName,
+                             int32_t cookie) {
+  if (auto* perfettoCategory = internal::toPerfettoCategory(category)) {
+    return internal::perfettoTraceAsyncEndForTrack(*perfettoCategory, trackName,
+                                                   cookie);
+  }
+  atrace_async_for_track_end(category, trackName, cookie);
+  return Result::SUCCESS;
+}
+
+Result traceInstant(uint64_t category, const char* name) {
+  if (auto* perfettoCategory = internal::toPerfettoCategory(category)) {
+    return internal::perfettoTraceInstant(*perfettoCategory, name);
+  }
+  atrace_instant(category, name);
+  return Result::SUCCESS;
+}
+
+Result traceInstantForTrack(uint64_t category, const char* trackName,
+                            const char* name) {
+  if (auto* perfettoCategory = internal::toPerfettoCategory(category)) {
+    return internal::perfettoTraceInstantForTrack(*perfettoCategory, trackName,
+                                                  name);
+  }
+  atrace_instant_for_track(category, trackName, name);
+  return Result::SUCCESS;
+}
+
+Result traceCounter(uint64_t category, const char* name, int64_t value) {
+  if (auto* perfettoCategory = internal::toPerfettoCategory(category)) {
+    return internal::perfettoTraceCounter(*perfettoCategory, name, value);
+  }
+  atrace_int64(category, name, value);
+  return Result::SUCCESS;
+}
+
+// With Perfetto SDK tracing enabled this reports internal::getDefaultCategories();
+// otherwise it reports the tags atrace currently has enabled.
+uint64_t getEnabledCategories() {
+  return internal::isPerfettoSdkTracingEnabled()
+             ? internal::getDefaultCategories()
+             : atrace_get_enabled_tags();
+}
+
+}  // namespace tracing_perfetto
diff --git a/libs/tracing_perfetto/tracing_perfetto_internal.cpp b/libs/tracing_perfetto/tracing_perfetto_internal.cpp
new file mode 100644
index 0000000..58ba428
--- /dev/null
+++ b/libs/tracing_perfetto/tracing_perfetto_internal.cpp
@@ -0,0 +1,242 @@
+/*
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define FRAMEWORK_CATEGORIES(C)                                  \
+  C(always, "always", "Always category")                         \
+  C(graphics, "graphics", "Graphics category")                   \
+  C(input, "input", "Input category")                            \
+  C(view, "view", "View category")                               \
+  C(webview, "webview", "WebView category")                      \
+  C(windowmanager, "wm", "WindowManager category")               \
+  C(activitymanager, "am", "ActivityManager category")           \
+  C(syncmanager, "syncmanager", "SyncManager category")          \
+  C(audio, "audio", "Audio category")                            \
+  C(video, "video", "Video category")                            \
+  C(camera, "camera", "Camera category")                         \
+  C(hal, "hal", "HAL category")                                  \
+  C(app, "app", "App category")                                  \
+  C(resources, "res", "Resources category")                      \
+  C(dalvik, "dalvik", "Dalvik category")                         \
+  C(rs, "rs", "RS category")                                     \
+  C(bionic, "bionic", "Bionic category")                         \
+  C(power, "power", "Power category")                            \
+  C(packagemanager, "packagemanager", "PackageManager category") \
+  C(systemserver, "ss", "System Server category")                \
+  C(database, "database", "Database category")                   \
+  C(network, "network", "Network category")                      \
+  C(adb, "adb", "ADB category")                                  \
+  C(vibrator, "vibrator", "Vibrator category")                   \
+  C(aidl, "aidl", "AIDL category")                               \
+  C(nnapi, "nnapi", "NNAPI category")                            \
+  C(rro, "rro", "RRO category")                                  \
+  C(thermal, "thermal", "Thermal category")
+
+#include "tracing_perfetto_internal.h"
+
+#include <inttypes.h>
+
+#include <mutex>
+
+#include <android_os.h>
+
+#include "perfetto/public/compiler.h"
+#include "perfetto/public/producer.h"
+#include "perfetto/public/te_category_macros.h"
+#include "perfetto/public/te_macros.h"
+#include "perfetto/public/track_event.h"
+#include "trace_categories.h"
+#include "trace_result.h"
+
+namespace tracing_perfetto {
+
+namespace internal {
+
+namespace {
+
+PERFETTO_TE_CATEGORIES_DECLARE(FRAMEWORK_CATEGORIES);
+
+PERFETTO_TE_CATEGORIES_DEFINE(FRAMEWORK_CATEGORIES);
+
+// Maps one TRACE_CATEGORY_* value to its Perfetto category. Any value not listed
+// below has no Perfetto counterpart and maps to nullptr.
+struct PerfettoTeCategory* toCategory(uint64_t inCategory) {
+  switch (inCategory) {
+    case TRACE_CATEGORY_ALWAYS:
+      return &always;
+    case TRACE_CATEGORY_GRAPHICS:
+      return &graphics;
+    case TRACE_CATEGORY_INPUT:
+      return &input;
+    case TRACE_CATEGORY_VIEW:
+      return &view;
+    case TRACE_CATEGORY_WEBVIEW:
+      return &webview;
+    case TRACE_CATEGORY_WINDOW_MANAGER:
+      return &windowmanager;
+    case TRACE_CATEGORY_ACTIVITY_MANAGER:
+      return &activitymanager;
+    case TRACE_CATEGORY_SYNC_MANAGER:
+      return &syncmanager;
+    case TRACE_CATEGORY_AUDIO:
+      return &audio;
+    case TRACE_CATEGORY_VIDEO:
+      return &video;
+    case TRACE_CATEGORY_CAMERA:
+      return &camera;
+    case TRACE_CATEGORY_HAL:
+      return &hal;
+    case TRACE_CATEGORY_APP:
+      return &app;
+    case TRACE_CATEGORY_RESOURCES:
+      return &resources;
+    case TRACE_CATEGORY_DALVIK:
+      return &dalvik;
+    case TRACE_CATEGORY_RS:
+      return &rs;
+    case TRACE_CATEGORY_BIONIC:
+      return &bionic;
+    case TRACE_CATEGORY_POWER:
+      return &power;
+    case TRACE_CATEGORY_PACKAGE_MANAGER:
+      return &packagemanager;
+    case TRACE_CATEGORY_SYSTEM_SERVER:
+      return &systemserver;
+    case TRACE_CATEGORY_DATABASE:
+      return &database;
+    case TRACE_CATEGORY_NETWORK:
+      return &network;
+    case TRACE_CATEGORY_ADB:
+      return &adb;
+    case TRACE_CATEGORY_VIBRATOR:
+      return &vibrator;
+    case TRACE_CATEGORY_AIDL:
+      return &aidl;
+    case TRACE_CATEGORY_NNAPI:
+      return &nnapi;
+    case TRACE_CATEGORY_RRO:
+      return &rro;
+    case TRACE_CATEGORY_THERMAL:
+      return &thermal;
+    default:
+      return nullptr;
+  }
+}
+
+}  // namespace
+
+bool isPerfettoSdkTracingEnabled() {
+  return android::os::perfetto_sdk_tracing();
+}
+
+// Returns the Perfetto category for `category` if SDK tracing is enabled, the value
+// has a Perfetto counterpart, and that category is currently enabled. Otherwise
+// returns nullptr, in which case the tracing_perfetto wrappers fall back to atrace.
+struct PerfettoTeCategory* toPerfettoCategory(uint64_t category) {
+  if (!isPerfettoSdkTracingEnabled()) {
+    return nullptr;
+  }
+
+  struct PerfettoTeCategory* perfettoCategory = toCategory(category);
+  if (perfettoCategory == nullptr) {
+    // No Perfetto counterpart: report "not traceable" instead of dereferencing null.
+    return nullptr;
+  }
+  bool enabled = PERFETTO_UNLIKELY(PERFETTO_ATOMIC_LOAD_EXPLICIT(
+      perfettoCategory->enabled, PERFETTO_MEMORY_ORDER_RELAXED));
+  return enabled ? perfettoCategory : nullptr;
+}
+
+// Initializes Perfetto (in-process backend when `test`, system backend otherwise) and
+// registers the framework categories, at most once; no-op when SDK tracing is off.
+void registerWithPerfetto(bool test) {
+  if (!isPerfettoSdkTracingEnabled()) {
+    return;
+  }
+  static std::once_flag registration;
+  std::call_once(registration, [test]() {
+    struct PerfettoProducerInitArgs args = PERFETTO_PRODUCER_INIT_ARGS_INIT();
+    args.backends = test ? PERFETTO_BACKEND_IN_PROCESS : PERFETTO_BACKEND_SYSTEM;
+    PerfettoProducerInit(args);
+    PerfettoTeInit();
+    PERFETTO_TE_REGISTER_CATEGORIES(FRAMEWORK_CATEGORIES);
+  });
+}
+
+Result perfettoTraceBegin(const struct PerfettoTeCategory& category, const char* name) {
+  PERFETTO_TE(category, PERFETTO_TE_SLICE_BEGIN(name));
+  return Result::SUCCESS;
+}
+
+Result perfettoTraceEnd(const struct PerfettoTeCategory& category) {
+  PERFETTO_TE(category, PERFETTO_TE_SLICE_END());
+  return Result::SUCCESS;
+}
+
+Result perfettoTraceAsyncBeginForTrack(const struct PerfettoTeCategory& category, const char* name,
+                                       const char* trackName, uint64_t cookie) {
+  PERFETTO_TE(
+      category, PERFETTO_TE_SLICE_BEGIN(name),
+      PERFETTO_TE_NAMED_TRACK(trackName, cookie, PerfettoTeProcessTrackUuid()));
+  return Result::SUCCESS;
+}
+
+Result perfettoTraceAsyncEndForTrack(const struct PerfettoTeCategory& category,
+                                     const char* trackName, uint64_t cookie) {
+  PERFETTO_TE(
+      category, PERFETTO_TE_SLICE_END(),
+      PERFETTO_TE_NAMED_TRACK(trackName, cookie, PerfettoTeProcessTrackUuid()));
+  return Result::SUCCESS;
+}
+
+Result perfettoTraceAsyncBegin(const struct PerfettoTeCategory& category, const char* name,
+                               uint64_t cookie) {
+  return perfettoTraceAsyncBeginForTrack(category, name, name, cookie);
+}
+
+Result perfettoTraceAsyncEnd(const struct PerfettoTeCategory& category, const char* name,
+                             uint64_t cookie) {
+  return perfettoTraceAsyncEndForTrack(category, name, cookie);
+}
+
+Result perfettoTraceInstant(const struct PerfettoTeCategory& category, const char* name) {
+  PERFETTO_TE(category, PERFETTO_TE_INSTANT(name));
+  return Result::SUCCESS;
+}
+
+Result perfettoTraceInstantForTrack(const struct PerfettoTeCategory& category,
+                                    const char* trackName, const char* name) {
+  PERFETTO_TE(
+      category, PERFETTO_TE_INSTANT(name),
+      PERFETTO_TE_NAMED_TRACK(trackName, 1, PerfettoTeProcessTrackUuid()));
+  return Result::SUCCESS;
+}
+
+Result perfettoTraceCounter(const struct PerfettoTeCategory& category,
+                            [[maybe_unused]] const char* name, int64_t value) {
+  // NOTE(review): `name` is not attached to the event, so counters with different
+  // names in one category are not distinguished here; confirm this is intended.
+  PERFETTO_TE(category, PERFETTO_TE_COUNTER(),
+              PERFETTO_TE_INT_COUNTER(value));
+  return Result::SUCCESS;
+}
+
+uint64_t getDefaultCategories() {
+  return TRACE_CATEGORIES;
+}
+
+}  // namespace internal
+
+}  // namespace tracing_perfetto
diff --git a/libs/tracing_perfetto/tracing_perfetto_internal.h b/libs/tracing_perfetto/tracing_perfetto_internal.h
new file mode 100644
index 0000000..9a579f1
--- /dev/null
+++ b/libs/tracing_perfetto/tracing_perfetto_internal.h
@@ -0,0 +1,75 @@
+/*
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TRACING_PERFETTO_INTERNAL_H
+#define TRACING_PERFETTO_INTERNAL_H
+
+#include <stdint.h>
+
+#include "include/trace_result.h"
+#include "perfetto/public/te_category_macros.h"
+
+namespace tracing_perfetto {
+
+namespace internal {
+
+// Helpers behind the public tracing_perfetto entry points.
+
+// Whether the Perfetto SDK tracing path is enabled (android::os::perfetto_sdk_tracing()).
+bool isPerfettoSdkTracingEnabled();
+
+// Initializes the Perfetto producer and registers the framework categories, at most
+// once. `test` selects the in-process backend instead of the system backend.
+void registerWithPerfetto(bool test = false);
+
+// Returns the Perfetto category for a TRACE_CATEGORY_* value when SDK tracing and that
+// category are enabled; otherwise nullptr, and callers then fall back to atrace.
+struct PerfettoTeCategory* toPerfettoCategory(uint64_t category);
+
+// Category mask reported by getEnabledCategories() while SDK tracing is enabled.
+uint64_t getDefaultCategories();
+
+// Event emitters; `category` is expected to come from toPerfettoCategory().
+Result perfettoTraceBegin(const struct PerfettoTeCategory& category, const char* name);
+
+Result perfettoTraceEnd(const struct PerfettoTeCategory& category);
+
+// The non-ForTrack async variants use `name` as the track name.
+Result perfettoTraceAsyncBegin(const struct PerfettoTeCategory& category, const char* name,
+                               uint64_t cookie);
+
+Result perfettoTraceAsyncEnd(const struct PerfettoTeCategory& category, const char* name,
+                             uint64_t cookie);
+
+Result perfettoTraceAsyncBeginForTrack(const struct PerfettoTeCategory& category, const char* name,
+                                       const char* trackName, uint64_t cookie);
+
+Result perfettoTraceAsyncEndForTrack(const struct PerfettoTeCategory& category,
+                                     const char* trackName, uint64_t cookie);
+
+Result perfettoTraceInstant(const struct PerfettoTeCategory& category, const char* name);
+
+Result perfettoTraceInstantForTrack(const struct PerfettoTeCategory& category,
+                                    const char* trackName, const char* name);
+
+Result perfettoTraceCounter(const struct PerfettoTeCategory& category, const char* name,
+                            int64_t value);
+
+}  // namespace internal
+
+}  // namespace tracing_perfetto
+
+#endif  // TRACING_PERFETTO_INTERNAL_H
diff --git a/services/inputflinger/PointerChoreographer.cpp b/services/inputflinger/PointerChoreographer.cpp
index 3ac4285..9db3574 100644
--- a/services/inputflinger/PointerChoreographer.cpp
+++ b/services/inputflinger/PointerChoreographer.cpp
@@ -26,6 +26,7 @@
 namespace android {
 
 namespace {
+
 bool isFromMouse(const NotifyMotionArgs& args) {
     return isFromSource(args.source, AINPUT_SOURCE_MOUSE) &&
             args.pointerProperties[0].toolType == ToolType::MOUSE;
@@ -44,13 +45,26 @@
 bool isStylusHoverEvent(const NotifyMotionArgs& args) {
     return isStylusEvent(args.source, args.pointerProperties) && isHoverAction(args.action);
 }
+
+// Forwards a pending pointer display change, if any, to the policy. The callers in this
+// file compute the change while holding mLock and call this only after releasing it, so
+// this notification is delivered without mLock held.
+inline void notifyPointerDisplayChange(std::optional<std::tuple<int32_t, FloatPoint>> change,
+                                       PointerChoreographerPolicyInterface& policy) {
+    if (!change) {
+        return;
+    }
+    const auto& [displayId, cursorPosition] = *change;
+    policy.notifyPointerDisplayIdChanged(displayId, cursorPosition);
+}
+
 } // namespace
 
 // --- PointerChoreographer ---
 
 PointerChoreographer::PointerChoreographer(InputListenerInterface& listener,
                                            PointerChoreographerPolicyInterface& policy)
-      : mTouchControllerConstructor([this]() REQUIRES(mLock) {
+      : mTouchControllerConstructor([this]() {
             return mPolicy.createPointerController(
                     PointerControllerInterface::ControllerType::TOUCH);
         }),
@@ -62,10 +73,16 @@
         mStylusPointerIconEnabled(false) {}
 
 void PointerChoreographer::notifyInputDevicesChanged(const NotifyInputDevicesChangedArgs& args) {
-    std::scoped_lock _l(mLock);
+    PointerDisplayChange pointerDisplayChange;
 
-    mInputDeviceInfos = args.inputDeviceInfos;
-    updatePointerControllersLocked();
+    { // acquire lock
+        std::scoped_lock _l(mLock);
+
+        mInputDeviceInfos = args.inputDeviceInfos;
+        pointerDisplayChange = updatePointerControllersLocked();
+    } // release lock
+
+    notifyPointerDisplayChange(pointerDisplayChange, mPolicy);
     mNextListener.notify(args);
 }
 
@@ -104,7 +121,7 @@
                    << args.dump();
     }
 
-    auto [displayId, pc] = getDisplayIdAndMouseControllerLocked(args.displayId);
+    auto [displayId, pc] = ensureMouseControllerLocked(args.displayId);
 
     const float deltaX = args.pointerCoords[0].getAxisValue(AMOTION_EVENT_AXIS_RELATIVE_X);
     const float deltaY = args.pointerCoords[0].getAxisValue(AMOTION_EVENT_AXIS_RELATIVE_Y);
@@ -124,7 +141,7 @@
 }
 
 NotifyMotionArgs PointerChoreographer::processTouchpadEventLocked(const NotifyMotionArgs& args) {
-    auto [displayId, pc] = getDisplayIdAndMouseControllerLocked(args.displayId);
+    auto [displayId, pc] = ensureMouseControllerLocked(args.displayId);
 
     NotifyMotionArgs newArgs(args);
     newArgs.displayId = displayId;
@@ -308,17 +325,15 @@
     return associatedDisplayId == ADISPLAY_ID_NONE ? mDefaultMouseDisplayId : associatedDisplayId;
 }
 
-std::pair<int32_t, PointerControllerInterface&>
-PointerChoreographer::getDisplayIdAndMouseControllerLocked(int32_t associatedDisplayId) {
+std::pair<int32_t, PointerControllerInterface&> PointerChoreographer::ensureMouseControllerLocked(
+        int32_t associatedDisplayId) {
     const int32_t displayId = getTargetMouseDisplayLocked(associatedDisplayId);
 
-    // Get the mouse pointer controller for the display, or create one if it doesn't exist.
-    auto [it, emplaced] =
-            mMousePointersByDisplay.try_emplace(displayId,
-                                                getMouseControllerConstructor(displayId));
-    if (emplaced) {
-        notifyPointerDisplayIdChangedLocked();
-    }
+    // Despite its name, this does not create a controller: one must already exist for the
+    // target display, otherwise LOG_ALWAYS_FATAL_IF aborts.
+    auto it = mMousePointersByDisplay.find(displayId);
+    LOG_ALWAYS_FATAL_IF(it == mMousePointersByDisplay.end(),
+                        "There is no mouse controller created for display %d", displayId);
 
     return {displayId, *it->second};
 }
@@ -333,7 +346,7 @@
     return mDisplaysWithPointersHidden.find(displayId) == mDisplaysWithPointersHidden.end();
 }
 
-void PointerChoreographer::updatePointerControllersLocked() {
+PointerChoreographer::PointerDisplayChange PointerChoreographer::updatePointerControllersLocked() {
     std::set<int32_t /*displayId*/> mouseDisplaysToKeep;
     std::set<DeviceId> touchDevicesToKeep;
     std::set<DeviceId> stylusDevicesToKeep;
@@ -382,11 +395,15 @@
                 mInputDeviceInfos.end();
     });
 
-    // Notify the policy if there's a change on the pointer display ID.
-    notifyPointerDisplayIdChangedLocked();
+    // Compute any change of the pointer display ID; the callers deliver it to the policy
+    // after releasing mLock.
+    return calculatePointerDisplayChangeToNotify();
 }
 
-void PointerChoreographer::notifyPointerDisplayIdChangedLocked() {
+// Returns the display ID and cursor position to report when the display ID differs from
+// the last one reported (and records it as reported); otherwise returns std::nullopt.
+PointerChoreographer::PointerDisplayChange
+PointerChoreographer::calculatePointerDisplayChangeToNotify() {
     int32_t displayIdToNotify = ADISPLAY_ID_NONE;
     FloatPoint cursorPosition = {0, 0};
     if (const auto it = mMousePointersByDisplay.find(mDefaultMouseDisplayId);
@@ -398,38 +412,51 @@
         displayIdToNotify = pointerController->getDisplayId();
         cursorPosition = pointerController->getPosition();
     }
-
     if (mNotifiedPointerDisplayId == displayIdToNotify) {
-        return;
+        return {};
     }
-    mPolicy.notifyPointerDisplayIdChanged(displayIdToNotify, cursorPosition);
+    // Mark the ID as notified now, under mLock; the caller delivers the notification
+    // after releasing the lock.
     mNotifiedPointerDisplayId = displayIdToNotify;
+    return {{displayIdToNotify, cursorPosition}};
 }
 
 void PointerChoreographer::setDefaultMouseDisplayId(int32_t displayId) {
-    std::scoped_lock _l(mLock);
+    PointerDisplayChange pointerDisplayChange;
 
-    mDefaultMouseDisplayId = displayId;
-    updatePointerControllersLocked();
+    { // acquire lock
+        std::scoped_lock _l(mLock);
+
+        mDefaultMouseDisplayId = displayId;
+        pointerDisplayChange = updatePointerControllersLocked();
+    } // release lock
+
+    notifyPointerDisplayChange(pointerDisplayChange, mPolicy);
 }
 
 void PointerChoreographer::setDisplayViewports(const std::vector<DisplayViewport>& viewports) {
-    std::scoped_lock _l(mLock);
-    for (const auto& viewport : viewports) {
-        const int32_t displayId = viewport.displayId;
-        if (const auto it = mMousePointersByDisplay.find(displayId);
-            it != mMousePointersByDisplay.end()) {
-            it->second->setDisplayViewport(viewport);
-        }
-        for (const auto& [deviceId, stylusPointerController] : mStylusPointersByDevice) {
-            const InputDeviceInfo* info = findInputDeviceLocked(deviceId);
-            if (info && info->getAssociatedDisplayId() == displayId) {
-                stylusPointerController->setDisplayViewport(viewport);
+    PointerDisplayChange pointerDisplayChange;
+
+    { // acquire lock
+        std::scoped_lock _l(mLock);
+        for (const auto& viewport : viewports) {
+            const int32_t displayId = viewport.displayId;
+            if (const auto it = mMousePointersByDisplay.find(displayId);
+                it != mMousePointersByDisplay.end()) {
+                it->second->setDisplayViewport(viewport);
+            }
+            for (const auto& [deviceId, stylusPointerController] : mStylusPointersByDevice) {
+                const InputDeviceInfo* info = findInputDeviceLocked(deviceId);
+                if (info && info->getAssociatedDisplayId() == displayId) {
+                    stylusPointerController->setDisplayViewport(viewport);
+                }
             }
         }
-    }
-    mViewports = viewports;
-    notifyPointerDisplayIdChangedLocked();
+        mViewports = viewports;
+        pointerDisplayChange = calculatePointerDisplayChangeToNotify();
+    } // release lock
+
+    notifyPointerDisplayChange(pointerDisplayChange, mPolicy);
 }
 
 std::optional<DisplayViewport> PointerChoreographer::getViewportForPointerDevice(
@@ -453,21 +478,33 @@
 }
 
 void PointerChoreographer::setShowTouchesEnabled(bool enabled) {
-    std::scoped_lock _l(mLock);
-    if (mShowTouchesEnabled == enabled) {
-        return;
-    }
-    mShowTouchesEnabled = enabled;
-    updatePointerControllersLocked();
+    PointerDisplayChange pointerDisplayChange;
+
+    { // acquire lock
+        std::scoped_lock _l(mLock);
+        if (mShowTouchesEnabled == enabled) {
+            return;
+        }
+        mShowTouchesEnabled = enabled;
+        pointerDisplayChange = updatePointerControllersLocked();
+    } // release lock
+
+    notifyPointerDisplayChange(pointerDisplayChange, mPolicy);
 }
 
 void PointerChoreographer::setStylusPointerIconEnabled(bool enabled) {
-    std::scoped_lock _l(mLock);
-    if (mStylusPointerIconEnabled == enabled) {
-        return;
-    }
-    mStylusPointerIconEnabled = enabled;
-    updatePointerControllersLocked();
+    PointerDisplayChange pointerDisplayChange;
+
+    { // acquire lock
+        std::scoped_lock _l(mLock);
+        if (mStylusPointerIconEnabled == enabled) {
+            return;
+        }
+        mStylusPointerIconEnabled = enabled;
+        pointerDisplayChange = updatePointerControllersLocked();
+    } // release lock
+
+    notifyPointerDisplayChange(pointerDisplayChange, mPolicy);
 }
 
 bool PointerChoreographer::setPointerIcon(
diff --git a/services/inputflinger/PointerChoreographer.h b/services/inputflinger/PointerChoreographer.h
index 6aab3aa..db1488b 100644
--- a/services/inputflinger/PointerChoreographer.h
+++ b/services/inputflinger/PointerChoreographer.h
@@ -109,11 +109,13 @@
     void dump(std::string& dump) override;
 
 private:
-    void updatePointerControllersLocked() REQUIRES(mLock);
-    void notifyPointerDisplayIdChangedLocked() REQUIRES(mLock);
+    using PointerDisplayChange =
+            std::optional<std::tuple<int32_t /*displayId*/, FloatPoint /*cursorPosition*/>>;
+    [[nodiscard]] PointerDisplayChange updatePointerControllersLocked() REQUIRES(mLock);
+    [[nodiscard]] PointerDisplayChange calculatePointerDisplayChangeToNotify() REQUIRES(mLock);
     const DisplayViewport* findViewportByIdLocked(int32_t displayId) const REQUIRES(mLock);
     int32_t getTargetMouseDisplayLocked(int32_t associatedDisplayId) const REQUIRES(mLock);
-    std::pair<int32_t, PointerControllerInterface&> getDisplayIdAndMouseControllerLocked(
+    std::pair<int32_t /*displayId*/, PointerControllerInterface&> ensureMouseControllerLocked(
             int32_t associatedDisplayId) REQUIRES(mLock);
     InputDeviceInfo* findInputDeviceLocked(DeviceId deviceId) REQUIRES(mLock);
     bool canUnfadeOnDisplay(int32_t displayId) REQUIRES(mLock);
diff --git a/services/inputflinger/dispatcher/CancelationOptions.h b/services/inputflinger/dispatcher/CancelationOptions.h
index 83e6a60..9c73f03 100644
--- a/services/inputflinger/dispatcher/CancelationOptions.h
+++ b/services/inputflinger/dispatcher/CancelationOptions.h
@@ -16,6 +16,8 @@
 
 #pragma once
 
+#include "trace/EventTrackerInterface.h"
+
 #include <input/Input.h>
 #include <bitset>
 #include <optional>
@@ -51,7 +53,14 @@
     // The specific pointers to cancel, or nullopt to cancel all pointer events
     std::optional<std::bitset<MAX_POINTER_ID + 1>> pointerIds = std::nullopt;
 
-    CancelationOptions(Mode mode, const char* reason) : mode(mode), reason(reason) {}
+    // Non-owning: the referenced unique_ptr (which may be empty) must outlive this object.
+    const std::unique_ptr<trace::EventTrackerInterface>& traceTracker;
+
+    explicit CancelationOptions(Mode mode, const char* reason,
+                                const std::unique_ptr<trace::EventTrackerInterface>& traceTracker)
+          : mode(mode), reason(reason), traceTracker(traceTracker) {}
+    CancelationOptions(const CancelationOptions&) = delete;
+    CancelationOptions& operator=(const CancelationOptions&) = delete;
 };
 
 } // namespace inputdispatcher
diff --git a/services/inputflinger/dispatcher/Entry.h b/services/inputflinger/dispatcher/Entry.h
index 1298b5d..06d5c7d 100644
--- a/services/inputflinger/dispatcher/Entry.h
+++ b/services/inputflinger/dispatcher/Entry.h
@@ -140,6 +140,7 @@
     mutable InterceptKeyResult interceptKeyResult; // set based on the interception result
     mutable nsecs_t interceptKeyWakeupTime;        // used with INTERCEPT_KEY_RESULT_TRY_AGAIN_LATER
     mutable int32_t flags;
+    // TODO(b/328618922): Refactor key repeat generation to make repeatCount non-mutable.
     mutable int32_t repeatCount;
 
     KeyEntry(int32_t id, std::shared_ptr<InjectionState> injectionState, nsecs_t eventTime,
diff --git a/services/inputflinger/dispatcher/InputDispatcher.cpp b/services/inputflinger/dispatcher/InputDispatcher.cpp
index 9c04a1d..f06caa6 100644
--- a/services/inputflinger/dispatcher/InputDispatcher.cpp
+++ b/services/inputflinger/dispatcher/InputDispatcher.cpp
@@ -103,6 +103,26 @@
     }
 }
 
+// Returns the trace tracker attached to a key or motion entry, calling ensureEventTraced() on
+// the entry first. Any other entry type yields a reference to an empty (null) tracker, so
+// callers must null-check the result before dereferencing it.
+const std::unique_ptr<trace::EventTrackerInterface>& getTraceTracker(const EventEntry& entry) {
+    // Shared empty tracker handed out for entry types that never carry a trace tracker.
+    static const std::unique_ptr<trace::EventTrackerInterface> kEmptyTracker;
+
+    if (entry.type == EventEntry::Type::KEY) {
+        const auto& keyEntry = static_cast<const KeyEntry&>(entry);
+        ensureEventTraced(keyEntry);
+        return keyEntry.traceTracker;
+    }
+    if (entry.type == EventEntry::Type::MOTION) {
+        const auto& motionEntry = static_cast<const MotionEntry&>(entry);
+        ensureEventTraced(motionEntry);
+        return motionEntry.traceTracker;
+    }
+    return kEmptyTracker;
+}
+
 // Temporarily releases a held mutex for the lifetime of the instance.
 // Named to match std::scoped_lock
 class scoped_unlock {
@@ -379,7 +399,8 @@
                                                    const InputTarget& inputTarget,
                                                    std::shared_ptr<const EventEntry> eventEntry,
                                                    ftl::Flags<InputTarget::Flags> inputTargetFlags,
-                                                   int64_t vsyncId) {
+                                                   int64_t vsyncId,
+                                                   trace::InputTracerInterface* tracer) {
     const bool zeroCoords = inputTargetFlags.test(InputTarget::Flags::ZERO_COORDS);
     const sp<WindowInfoHandle> win = inputTarget.windowHandle;
     const std::optional<int32_t> windowId =
@@ -442,6 +463,10 @@
                                           motionEntry.xCursorPosition, motionEntry.yCursorPosition,
                                           motionEntry.downTime, motionEntry.pointerProperties,
                                           pointerCoords);
+    if (tracer) {
+        combinedMotionEntry->traceTracker =
+                tracer->traceDerivedEvent(*combinedMotionEntry, *motionEntry.traceTracker);
+    }
 
     std::unique_ptr<DispatchEntry> dispatchEntry =
             std::make_unique<DispatchEntry>(std::move(combinedMotionEntry), inputTargetFlags,
@@ -656,13 +681,13 @@
 std::vector<TouchedWindow> getHoveringWindowsLocked(const TouchState* oldState,
                                                     const TouchState& newTouchState,
                                                     const MotionEntry& entry) {
-    std::vector<TouchedWindow> out;
     const int32_t maskedAction = MotionEvent::getActionMasked(entry.action);
 
     if (maskedAction == AMOTION_EVENT_ACTION_SCROLL) {
         // ACTION_SCROLL events should not affect the hovering pointer dispatch
         return {};
     }
+    std::vector<TouchedWindow> out;
 
     // We should consider all hovering pointers here. But for now, just use the first one
     const PointerProperties& pointer = entry.pointerProperties[0];
@@ -837,6 +862,30 @@
     }
 }
 
+// Creates a synthetic-event tracker on construction and completes it on destruction.
+class ScopedSyntheticEventTracer {
+public:
+    explicit ScopedSyntheticEventTracer(const std::unique_ptr<trace::InputTracerInterface>& tracer)
+          : mTracer(tracer) {
+        if (mTracer) mEventTracker = mTracer->createTrackerForSyntheticEvent();
+    }
+    ScopedSyntheticEventTracer(const ScopedSyntheticEventTracer&) = delete;
+    ScopedSyntheticEventTracer& operator=(const ScopedSyntheticEventTracer&) = delete;
+    ~ScopedSyntheticEventTracer() {
+        // Check the tracker too, so a scope that never obtained one is not dereferenced.
+        if (mTracer && mEventTracker) mTracer->eventProcessingComplete(*mEventTracker);
+    }
+
+    // The synthetic event's tracker; may be null (e.g. when no tracer is installed).
+    const std::unique_ptr<trace::EventTrackerInterface>& getTracker() const {
+        return mEventTracker;
+    }
+
+private:
+    const std::unique_ptr<trace::InputTracerInterface>& mTracer;
+    std::unique_ptr<trace::EventTrackerInterface> mEventTracker;
+};
+
 } // namespace
 
 // --- InputDispatcher ---
@@ -1147,10 +1196,6 @@
                 dropReason = DropReason::BLOCKED;
             }
             done = dispatchKeyLocked(currentTime, keyEntry, &dropReason, nextWakeupTime);
-            if (done && mTracer) {
-                ensureEventTraced(*keyEntry);
-                mTracer->eventProcessingComplete(*keyEntry->traceTracker);
-            }
             break;
         }
 
@@ -1176,10 +1221,6 @@
                 }
             }
             done = dispatchMotionLocked(currentTime, motionEntry, &dropReason, nextWakeupTime);
-            if (done && mTracer) {
-                ensureEventTraced(*motionEntry);
-                mTracer->eventProcessingComplete(*motionEntry->traceTracker);
-            }
             break;
         }
 
@@ -1205,6 +1246,12 @@
         }
         mLastDropReason = dropReason;
 
+        if (mTracer) {
+            if (auto& traceTracker = getTraceTracker(*mPendingEvent); traceTracker != nullptr) {
+                mTracer->eventProcessingComplete(*traceTracker);
+            }
+        }
+
         releasePendingEventLocked();
         nextWakeupTime = LLONG_MIN; // force next poll to wake up immediately
     }
@@ -1456,8 +1503,9 @@
 
     switch (entry.type) {
         case EventEntry::Type::KEY: {
-            CancelationOptions options(CancelationOptions::Mode::CANCEL_NON_POINTER_EVENTS, reason);
             const KeyEntry& keyEntry = static_cast<const KeyEntry&>(entry);
+            CancelationOptions options(CancelationOptions::Mode::CANCEL_NON_POINTER_EVENTS, reason,
+                                       keyEntry.traceTracker);
             options.displayId = keyEntry.displayId;
             options.deviceId = keyEntry.deviceId;
             synthesizeCancelationEventsForAllConnectionsLocked(options);
@@ -1466,13 +1514,14 @@
         case EventEntry::Type::MOTION: {
             const MotionEntry& motionEntry = static_cast<const MotionEntry&>(entry);
             if (motionEntry.source & AINPUT_SOURCE_CLASS_POINTER) {
-                CancelationOptions options(CancelationOptions::Mode::CANCEL_POINTER_EVENTS, reason);
+                CancelationOptions options(CancelationOptions::Mode::CANCEL_POINTER_EVENTS, reason,
+                                           motionEntry.traceTracker);
                 options.displayId = motionEntry.displayId;
                 options.deviceId = motionEntry.deviceId;
                 synthesizeCancelationEventsForAllConnectionsLocked(options);
             } else {
                 CancelationOptions options(CancelationOptions::Mode::CANCEL_NON_POINTER_EVENTS,
-                                           reason);
+                                           reason, motionEntry.traceTracker);
                 options.displayId = motionEntry.displayId;
                 options.deviceId = motionEntry.deviceId;
                 synthesizeCancelationEventsForAllConnectionsLocked(options);
@@ -1607,7 +1656,9 @@
         resetKeyRepeatLocked();
     }
 
-    CancelationOptions options(CancelationOptions::Mode::CANCEL_ALL_EVENTS, "device was reset");
+    ScopedSyntheticEventTracer traceContext(mTracer);
+    CancelationOptions options(CancelationOptions::Mode::CANCEL_ALL_EVENTS, "device was reset",
+                               traceContext.getTracker());
     options.deviceId = entry.deviceId;
     synthesizeCancelationEventsForAllConnectionsLocked(options);
 
@@ -1996,7 +2047,7 @@
         CancelationOptions::Mode mode(
                 isPointerEvent ? CancelationOptions::Mode::CANCEL_POINTER_EVENTS
                                : CancelationOptions::Mode::CANCEL_NON_POINTER_EVENTS);
-        CancelationOptions options(mode, "input event injection failed");
+        CancelationOptions options(mode, "input event injection failed", entry->traceTracker);
         options.displayId = entry->displayId;
         synthesizeCancelationEventsForMonitorsLocked(options);
         return true;
@@ -2102,8 +2153,9 @@
     if (connection->status != Connection::Status::NORMAL) {
         return;
     }
+    ScopedSyntheticEventTracer traceContext(mTracer);
     CancelationOptions options(CancelationOptions::Mode::CANCEL_ALL_EVENTS,
-                               "application not responding");
+                               "application not responding", traceContext.getTracker());
 
     sp<WindowInfoHandle> windowHandle;
     if (!connection->monitor) {
@@ -2633,19 +2685,14 @@
     {
         std::vector<TouchedWindow> hoveringWindows =
                 getHoveringWindowsLocked(oldState, tempTouchState, entry);
+        // Hardcode to single hovering pointer for now.
+        std::bitset<MAX_POINTER_ID + 1> pointerIds;
+        pointerIds.set(entry.pointerProperties[0].id);
         for (const TouchedWindow& touchedWindow : hoveringWindows) {
-            std::optional<InputTarget> target =
-                    createInputTargetLocked(touchedWindow.windowHandle, touchedWindow.dispatchMode,
-                                            touchedWindow.targetFlags,
-                                            touchedWindow.getDownTimeInTarget(entry.deviceId));
-            if (!target) {
-                continue;
-            }
-            // Hardcode to single hovering pointer for now.
-            std::bitset<MAX_POINTER_ID + 1> pointerIds;
-            pointerIds.set(entry.pointerProperties[0].id);
-            target->addPointers(pointerIds, touchedWindow.windowHandle->getInfo()->transform);
-            targets.push_back(*target);
+            addPointerWindowTargetLocked(touchedWindow.windowHandle, touchedWindow.dispatchMode,
+                                         touchedWindow.targetFlags, pointerIds,
+                                         touchedWindow.getDownTimeInTarget(entry.deviceId),
+                                         targets);
         }
     }
 
@@ -3373,7 +3420,7 @@
     // Enqueue a new dispatch entry onto the outbound queue for this connection.
     std::unique_ptr<DispatchEntry> dispatchEntry =
             createDispatchEntry(mIdGenerator, inputTarget, eventEntry, inputTarget.flags,
-                                mWindowInfosVsyncId);
+                                mWindowInfosVsyncId, mTracer.get());
 
     // Use the eventEntry from dispatchEntry since the entry may have changed and can now be a
     // different EventEntry than what was passed in.
@@ -3456,21 +3503,31 @@
                             usingCoords = pointerInfo->second;
                         }
                     }
-                    // Generate a new MotionEntry with a new eventId using the resolved action and
-                    // flags.
-                    resolvedMotion = std::make_shared<
-                            MotionEntry>(mIdGenerator.nextId(), motionEntry.injectionState,
-                                         motionEntry.eventTime, motionEntry.deviceId,
-                                         motionEntry.source, motionEntry.displayId,
-                                         motionEntry.policyFlags, resolvedAction,
-                                         motionEntry.actionButton, resolvedFlags,
-                                         motionEntry.metaState, motionEntry.buttonState,
-                                         motionEntry.classification, motionEntry.edgeFlags,
-                                         motionEntry.xPrecision, motionEntry.yPrecision,
-                                         motionEntry.xCursorPosition, motionEntry.yCursorPosition,
-                                         motionEntry.downTime,
-                                         usingProperties.value_or(motionEntry.pointerProperties),
-                                         usingCoords.value_or(motionEntry.pointerCoords));
+                    {
+                        // Generate a new MotionEntry with a new eventId using the resolved action
+                        // and flags, and set it as the resolved entry.
+                        auto newEntry = std::make_shared<
+                                MotionEntry>(mIdGenerator.nextId(), motionEntry.injectionState,
+                                             motionEntry.eventTime, motionEntry.deviceId,
+                                             motionEntry.source, motionEntry.displayId,
+                                             motionEntry.policyFlags, resolvedAction,
+                                             motionEntry.actionButton, resolvedFlags,
+                                             motionEntry.metaState, motionEntry.buttonState,
+                                             motionEntry.classification, motionEntry.edgeFlags,
+                                             motionEntry.xPrecision, motionEntry.yPrecision,
+                                             motionEntry.xCursorPosition,
+                                             motionEntry.yCursorPosition, motionEntry.downTime,
+                                             usingProperties.value_or(
+                                                     motionEntry.pointerProperties),
+                                             usingCoords.value_or(motionEntry.pointerCoords));
+                        if (mTracer) {
+                            ensureEventTraced(motionEntry);
+                            newEntry->traceTracker =
+                                    mTracer->traceDerivedEvent(*newEntry,
+                                                               *motionEntry.traceTracker);
+                        }
+                        resolvedMotion = newEntry;
+                    }
                     if (ATRACE_ENABLED()) {
                         std::string message = StringPrintf("Transmute MotionEvent(id=0x%" PRIx32
                                                            ") to MotionEvent(id=0x%" PRIx32 ").",
@@ -3493,9 +3550,14 @@
                 LOG(INFO) << "Canceling pointers for device " << resolvedMotion->deviceId << " in "
                           << connection->getInputChannelName() << " with event "
                           << cancelEvent->getDescription();
+                if (mTracer) {
+                    static_cast<MotionEntry&>(*cancelEvent).traceTracker =
+                            mTracer->traceDerivedEvent(*cancelEvent, *resolvedMotion->traceTracker);
+                }
                 std::unique_ptr<DispatchEntry> cancelDispatchEntry =
                         createDispatchEntry(mIdGenerator, inputTarget, std::move(cancelEvent),
-                                            ftl::Flags<InputTarget::Flags>(), mWindowInfosVsyncId);
+                                            ftl::Flags<InputTarget::Flags>(), mWindowInfosVsyncId,
+                                            mTracer.get());
 
                 // Send these cancel events to the queue before sending the event from the new
                 // device.
@@ -3729,7 +3791,8 @@
                                                   keyEntry.metaState, keyEntry.repeatCount,
                                                   keyEntry.downTime, keyEntry.eventTime);
                 if (mTracer) {
-                    mTracer->traceEventDispatch(*dispatchEntry, keyEntry.traceTracker.get());
+                    ensureEventTraced(keyEntry);
+                    mTracer->traceEventDispatch(*dispatchEntry, *keyEntry.traceTracker);
                 }
                 break;
             }
@@ -3742,7 +3805,8 @@
                 const MotionEntry& motionEntry = static_cast<const MotionEntry&>(eventEntry);
                 status = publishMotionEvent(*connection, *dispatchEntry);
                 if (mTracer) {
-                    mTracer->traceEventDispatch(*dispatchEntry, motionEntry.traceTracker.get());
+                    ensureEventTraced(motionEntry);
+                    mTracer->traceEventDispatch(*dispatchEntry, *motionEntry.traceTracker);
                 }
                 break;
             }
@@ -4121,6 +4185,11 @@
 
         switch (cancelationEventEntry->type) {
             case EventEntry::Type::KEY: {
+                if (mTracer) {
+                    static_cast<KeyEntry&>(*cancelationEventEntry).traceTracker =
+                            mTracer->traceDerivedEvent(*cancelationEventEntry,
+                                                       *options.traceTracker);
+                }
                 const auto& keyEntry = static_cast<const KeyEntry&>(*cancelationEventEntry);
                 if (window) {
                     addWindowTargetLocked(window, InputTarget::DispatchMode::AS_IS,
@@ -4132,6 +4201,11 @@
                 break;
             }
             case EventEntry::Type::MOTION: {
+                if (mTracer) {
+                    static_cast<MotionEntry&>(*cancelationEventEntry).traceTracker =
+                            mTracer->traceDerivedEvent(*cancelationEventEntry,
+                                                       *options.traceTracker);
+                }
                 const auto& motionEntry = static_cast<const MotionEntry&>(*cancelationEventEntry);
                 if (window) {
                     std::bitset<MAX_POINTER_ID + 1> pointerIds;
@@ -4179,6 +4253,9 @@
         }
 
         if (targets.size() != 1) LOG(FATAL) << __func__ << ": InputTarget not created";
+        if (mTracer) {
+            mTracer->dispatchToTargetHint(*options.traceTracker, targets[0]);
+        }
         enqueueDispatchEntryLocked(connection, std::move(cancelationEventEntry), targets[0]);
     }
 
@@ -4190,7 +4267,8 @@
 
 void InputDispatcher::synthesizePointerDownEventsForConnectionLocked(
         const nsecs_t downTime, const std::shared_ptr<Connection>& connection,
-        ftl::Flags<InputTarget::Flags> targetFlags) {
+        ftl::Flags<InputTarget::Flags> targetFlags,
+        const std::unique_ptr<trace::EventTrackerInterface>& traceTracker) {
     if (connection->status != Connection::Status::NORMAL) {
         return;
     }
@@ -4219,6 +4297,10 @@
         std::vector<InputTarget> targets{};
         switch (downEventEntry->type) {
             case EventEntry::Type::MOTION: {
+                if (mTracer) {
+                    static_cast<MotionEntry&>(*downEventEntry).traceTracker =
+                            mTracer->traceDerivedEvent(*downEventEntry, *traceTracker);
+                }
                 const auto& motionEntry = static_cast<const MotionEntry&>(*downEventEntry);
                 if (windowHandle != nullptr) {
                     std::bitset<MAX_POINTER_ID + 1> pointerIds;
@@ -4256,6 +4338,9 @@
         }
 
         if (targets.size() != 1) LOG(FATAL) << __func__ << ": InputTarget not created";
+        if (mTracer) {
+            mTracer->dispatchToTargetHint(*traceTracker, targets[0]);
+        }
         enqueueDispatchEntryLocked(connection, std::move(downEventEntry), targets[0]);
     }
 
@@ -4303,6 +4388,10 @@
                                           originalMotionEntry.xCursorPosition,
                                           originalMotionEntry.yCursorPosition, splitDownTime,
                                           pointerProperties, pointerCoords);
+    if (mTracer) {
+        splitMotionEntry->traceTracker =
+                mTracer->traceDerivedEvent(*splitMotionEntry, *originalMotionEntry.traceTracker);
+    }
 
     return splitMotionEntry;
 }
@@ -4797,6 +4886,10 @@
                                                                         pointerCount));
                 transformMotionEntryForInjectionLocked(*nextInjectedEntry,
                                                        motionEvent.getTransform());
+                if (mTracer) {
+                    nextInjectedEntry->traceTracker =
+                            mTracer->traceInboundEvent(*nextInjectedEntry);
+                }
                 injectedEntries.push(std::move(nextInjectedEntry));
             }
             break;
@@ -5185,6 +5278,7 @@
         }
         LOG(INFO) << "setInputWindows displayId=" << displayId << " " << windowList;
     }
+    ScopedSyntheticEventTracer traceContext(mTracer);
 
     // Check preconditions for new input windows
     for (const sp<WindowInfoHandle>& window : windowInfoHandles) {
@@ -5224,7 +5318,7 @@
     std::optional<FocusResolver::FocusChanges> changes =
             mFocusResolver.setInputWindows(displayId, windowHandles);
     if (changes) {
-        onFocusChangedLocked(*changes, removedFocusedWindowHandle);
+        onFocusChangedLocked(*changes, traceContext.getTracker(), removedFocusedWindowHandle);
     }
 
     std::unordered_map<int32_t, TouchState>::iterator stateIt =
@@ -5237,7 +5331,7 @@
                 LOG(INFO) << "Touched window was removed: " << touchedWindow.windowHandle->getName()
                           << " in display %" << displayId;
                 CancelationOptions options(CancelationOptions::Mode::CANCEL_POINTER_EVENTS,
-                                           "touched window was removed");
+                                           "touched window was removed", traceContext.getTracker());
                 synthesizeCancelationEventsForWindowLocked(touchedWindow.windowHandle, options);
                 // Since we are about to drop the touch, cancel the events for the wallpaper as
                 // well.
@@ -5338,6 +5432,7 @@
     }
     { // acquire lock
         std::scoped_lock _l(mLock);
+        ScopedSyntheticEventTracer traceContext(mTracer);
 
         if (mFocusedDisplayId != displayId) {
             sp<IBinder> oldFocusedWindowToken =
@@ -5350,7 +5445,8 @@
                 }
                 CancelationOptions
                         options(CancelationOptions::Mode::CANCEL_NON_POINTER_EVENTS,
-                                "The display which contains this window no longer has focus.");
+                                "The display which contains this window no longer has focus.",
+                                traceContext.getTracker());
                 options.displayId = ADISPLAY_ID_NONE;
                 synthesizeCancelationEventsForWindowLocked(windowHandle, options);
             }
@@ -5575,19 +5671,22 @@
         }
 
         // Synthesize cancel for old window and down for new window.
+        ScopedSyntheticEventTracer traceContext(mTracer);
         std::shared_ptr<Connection> fromConnection = getConnectionLocked(fromToken);
         std::shared_ptr<Connection> toConnection = getConnectionLocked(toToken);
         if (fromConnection != nullptr && toConnection != nullptr) {
             fromConnection->inputState.mergePointerStateTo(toConnection->inputState);
             CancelationOptions options(CancelationOptions::Mode::CANCEL_POINTER_EVENTS,
-                                       "transferring touch from this window to another window");
+                                       "transferring touch from this window to another window",
+                                       traceContext.getTracker());
             synthesizeCancelationEventsForWindowLocked(fromWindowHandle, options, fromConnection);
             synthesizePointerDownEventsForConnectionLocked(downTimeInTarget, toConnection,
-                                                           newTargetFlags);
+                                                           newTargetFlags,
+                                                           traceContext.getTracker());
 
             // Check if the wallpaper window should deliver the corresponding event.
             transferWallpaperTouch(oldTargetFlags, newTargetFlags, fromWindowHandle, toWindowHandle,
-                                   *state, deviceId, pointers);
+                                   *state, deviceId, pointers, traceContext.getTracker());
         }
     } // release lock
 
@@ -5655,7 +5754,9 @@
         ALOGD("Resetting and dropping all events (%s).", reason);
     }
 
-    CancelationOptions options(CancelationOptions::Mode::CANCEL_ALL_EVENTS, reason);
+    ScopedSyntheticEventTracer traceContext(mTracer);
+    CancelationOptions options(CancelationOptions::Mode::CANCEL_ALL_EVENTS, reason,
+                               traceContext.getTracker());
     synthesizeCancelationEventsForAllConnectionsLocked(options);
 
     resetKeyRepeatLocked();
@@ -6050,12 +6151,13 @@
         return BAD_VALUE;
     }
 
+    ScopedSyntheticEventTracer traceContext(mTracer);
     for (const DeviceId deviceId : deviceIds) {
         TouchState& state = *statePtr;
         TouchedWindow& window = *windowPtr;
         // Send cancel events to all the input channels we're stealing from.
         CancelationOptions options(CancelationOptions::Mode::CANCEL_POINTER_EVENTS,
-                                   "input channel stole pointer stream");
+                                   "input channel stole pointer stream", traceContext.getTracker());
         options.deviceId = deviceId;
         options.displayId = displayId;
         std::vector<PointerProperties> pointers = window.getTouchingPointers(deviceId);
@@ -6479,7 +6581,8 @@
                     CancelationOptions options(CancelationOptions::Mode::CANCEL_FALLBACK_EVENTS,
                                                "application handled the original non-fallback key "
                                                "or is no longer a foreground target, "
-                                               "canceling previously dispatched fallback key");
+                                               "canceling previously dispatched fallback key",
+                                               keyEntry.traceTracker);
                     options.keyCode = *fallbackKeyCode;
                     synthesizeCancelationEventsForWindowLocked(windowHandle, options, connection);
                 }
@@ -6561,7 +6664,8 @@
             const auto windowHandle = getWindowHandleLocked(connection->getToken());
             if (windowHandle != nullptr) {
                 CancelationOptions options(CancelationOptions::Mode::CANCEL_FALLBACK_EVENTS,
-                                           "canceling fallback, policy no longer desires it");
+                                           "canceling fallback, policy no longer desires it",
+                                           keyEntry.traceTracker);
                 options.keyCode = *fallbackKeyCode;
                 synthesizeCancelationEventsForWindowLocked(windowHandle, options, connection);
             }
@@ -6597,6 +6701,10 @@
                                                *fallbackKeyCode, event.getScanCode(),
                                                event.getMetaState(), event.getRepeatCount(),
                                                event.getDownTime());
+            if (mTracer) {
+                newEntry->traceTracker =
+                        mTracer->traceDerivedEvent(*newEntry, *keyEntry.traceTracker);
+            }
             if (DEBUG_OUTBOUND_EVENT_DETAILS) {
                 ALOGD("Unhandled key event: Dispatching fallback key.  "
                       "originalKeyCode=%d, fallbackKeyCode=%d, fallbackMetaState=%08x",
@@ -6695,16 +6803,19 @@
         std::scoped_lock _l(mLock);
         std::optional<FocusResolver::FocusChanges> changes =
                 mFocusResolver.setFocusedWindow(request, getWindowHandlesLocked(request.displayId));
+        ScopedSyntheticEventTracer traceContext(mTracer);
         if (changes) {
-            onFocusChangedLocked(*changes);
+            onFocusChangedLocked(*changes, traceContext.getTracker());
         }
     } // release lock
     // Wake up poll loop since it may need to make new input dispatching choices.
     mLooper->wake();
 }
 
-void InputDispatcher::onFocusChangedLocked(const FocusResolver::FocusChanges& changes,
-                                           const sp<WindowInfoHandle> removedFocusedWindowHandle) {
+void InputDispatcher::onFocusChangedLocked(
+        const FocusResolver::FocusChanges& changes,
+        const std::unique_ptr<trace::EventTrackerInterface>& traceTracker,
+        const sp<WindowInfoHandle> removedFocusedWindowHandle) {
     if (changes.oldFocus) {
         const auto resolvedWindow = removedFocusedWindowHandle != nullptr
                 ? removedFocusedWindowHandle
@@ -6713,7 +6824,7 @@
             LOG(FATAL) << __func__ << ": Previously focused token did not have a window";
         }
         CancelationOptions options(CancelationOptions::Mode::CANCEL_NON_POINTER_EVENTS,
-                                   "focus left window");
+                                   "focus left window", traceTracker);
         synthesizeCancelationEventsForWindowLocked(resolvedWindow, options);
         enqueueFocusEventLocked(changes.oldFocus, /*hasFocus=*/false, changes.reason);
     }
@@ -6866,9 +6977,10 @@
 void InputDispatcher::cancelCurrentTouch() {
     {
         std::scoped_lock _l(mLock);
+        ScopedSyntheticEventTracer traceContext(mTracer);
         ALOGD("Canceling all ongoing pointer gestures on all displays.");
         CancelationOptions options(CancelationOptions::Mode::CANCEL_POINTER_EVENTS,
-                                   "cancel current touch");
+                                   "cancel current touch", traceContext.getTracker());
         synthesizeCancelationEventsForAllConnectionsLocked(options);
 
         mTouchStatesByDisplay.clear();
@@ -6918,12 +7030,12 @@
     }
 }
 
-void InputDispatcher::transferWallpaperTouch(ftl::Flags<InputTarget::Flags> oldTargetFlags,
-                                             ftl::Flags<InputTarget::Flags> newTargetFlags,
-                                             const sp<WindowInfoHandle> fromWindowHandle,
-                                             const sp<WindowInfoHandle> toWindowHandle,
-                                             TouchState& state, int32_t deviceId,
-                                             const std::vector<PointerProperties>& pointers) {
+void InputDispatcher::transferWallpaperTouch(
+        ftl::Flags<InputTarget::Flags> oldTargetFlags,
+        ftl::Flags<InputTarget::Flags> newTargetFlags, const sp<WindowInfoHandle> fromWindowHandle,
+        const sp<WindowInfoHandle> toWindowHandle, TouchState& state, int32_t deviceId,
+        const std::vector<PointerProperties>& pointers,
+        const std::unique_ptr<trace::EventTrackerInterface>& traceTracker) {
     const bool oldHasWallpaper = oldTargetFlags.test(InputTarget::Flags::FOREGROUND) &&
             fromWindowHandle->getInfo()->inputConfig.test(
                     gui::WindowInfo::InputConfig::DUPLICATE_TOUCH_TO_WALLPAPER);
@@ -6941,7 +7053,7 @@
 
     if (oldWallpaper != nullptr) {
         CancelationOptions options(CancelationOptions::Mode::CANCEL_POINTER_EVENTS,
-                                   "transferring touch focus to another window");
+                                   "transferring touch focus to another window", traceTracker);
         state.removeWindowByToken(oldWallpaper->getToken());
         synthesizeCancelationEventsForWindowLocked(oldWallpaper, options);
     }
@@ -6961,7 +7073,7 @@
                     getConnectionLocked(toWindowHandle->getToken());
             toConnection->inputState.mergePointerStateTo(wallpaperConnection->inputState);
             synthesizePointerDownEventsForConnectionLocked(downTimeInTarget, wallpaperConnection,
-                                                           wallpaperFlags);
+                                                           wallpaperFlags, traceTracker);
         }
     }
 }
diff --git a/services/inputflinger/dispatcher/InputDispatcher.h b/services/inputflinger/dispatcher/InputDispatcher.h
index 269bfdd..d6eba64 100644
--- a/services/inputflinger/dispatcher/InputDispatcher.h
+++ b/services/inputflinger/dispatcher/InputDispatcher.h
@@ -628,7 +628,8 @@
 
     void synthesizePointerDownEventsForConnectionLocked(
             const nsecs_t downTime, const std::shared_ptr<Connection>& connection,
-            ftl::Flags<InputTarget::Flags> targetFlags) REQUIRES(mLock);
+            ftl::Flags<InputTarget::Flags> targetFlags,
+            const std::unique_ptr<trace::EventTrackerInterface>& traceTracker) REQUIRES(mLock);
 
     // Splitting motion events across windows. When splitting motion event for a target,
     // splitDownTime refers to the time of first 'down' event on that particular target
@@ -657,6 +658,7 @@
     void doInterceptKeyBeforeDispatchingCommand(const sp<IBinder>& focusedWindowToken,
                                                 const KeyEntry& entry) REQUIRES(mLock);
     void onFocusChangedLocked(const FocusResolver::FocusChanges& changes,
+                              const std::unique_ptr<trace::EventTrackerInterface>& traceTracker,
                               const sp<gui::WindowInfoHandle> removedFocusedWindowHandle = nullptr)
             REQUIRES(mLock);
     void sendFocusChangedCommandLocked(const sp<IBinder>& oldToken, const sp<IBinder>& newToken)
@@ -704,7 +706,9 @@
                                 const sp<android::gui::WindowInfoHandle> fromWindowHandle,
                                 const sp<android::gui::WindowInfoHandle> toWindowHandle,
                                 TouchState& state, int32_t deviceId,
-                                const std::vector<PointerProperties>& pointers) REQUIRES(mLock);
+                                const std::vector<PointerProperties>& pointers,
+                                const std::unique_ptr<trace::EventTrackerInterface>& traceTracker)
+            REQUIRES(mLock);
 
     sp<android::gui::WindowInfoHandle> findWallpaperWindowBelow(
             const sp<android::gui::WindowInfoHandle>& windowHandle) const REQUIRES(mLock);
diff --git a/services/inputflinger/dispatcher/trace/InputTracer.cpp b/services/inputflinger/dispatcher/trace/InputTracer.cpp
index be09013..83ed452 100644
--- a/services/inputflinger/dispatcher/trace/InputTracer.cpp
+++ b/services/inputflinger/dispatcher/trace/InputTracer.cpp
@@ -59,6 +59,16 @@
                           e.downTime,  e.flags,     e.repeatCount};
 }
 
+void writeEventToBackend(const TracedEvent& event, InputTracingBackendInterface& backend) {
+    std::visit(Visitor{[&](const TracedMotionEvent& e) { backend.traceMotionEvent(e); },
+                       [&](const TracedKeyEvent& e) { backend.traceKeyEvent(e); }},
+               event);
+}
+
+inline auto getId(const trace::TracedEvent& v) {
+    return std::visit([](const auto& event) { return event.id; }, v);
+}
+
 } // namespace
 
 // --- InputTracer ---
@@ -67,46 +77,86 @@
       : mBackend(std::move(backend)) {}
 
 std::unique_ptr<EventTrackerInterface> InputTracer::traceInboundEvent(const EventEntry& entry) {
-    TracedEvent traced;
+    // This is a newly traced inbound event. Create a new state to track it and its derived events.
+    auto eventState = std::make_shared<EventState>(*this);
 
     if (entry.type == EventEntry::Type::MOTION) {
         const auto& motion = static_cast<const MotionEntry&>(entry);
-        traced = createTracedEvent(motion);
+        eventState->events.emplace_back(createTracedEvent(motion));
     } else if (entry.type == EventEntry::Type::KEY) {
         const auto& key = static_cast<const KeyEntry&>(entry);
-        traced = createTracedEvent(key);
+        eventState->events.emplace_back(createTracedEvent(key));
     } else {
         LOG(FATAL) << "Cannot trace EventEntry of type: " << ftl::enum_string(entry.type);
     }
 
-    return std::make_unique<EventTrackerImpl>(*this, std::move(traced));
+    return std::make_unique<EventTrackerImpl>(std::move(eventState), /*isDerived=*/false);
+}
+
+std::unique_ptr<EventTrackerInterface> InputTracer::createTrackerForSyntheticEvent() {
+    // Create a new EventState to track events derived from this tracker.
+    return std::make_unique<EventTrackerImpl>(std::make_shared<EventState>(*this),
+                                              /*isDerived=*/false);
 }
 
 void InputTracer::dispatchToTargetHint(const EventTrackerInterface& cookie,
                                        const InputTarget& target) {
-    auto& cookieState = getState(cookie);
-    if (!cookieState) {
-        LOG(FATAL) << "dispatchToTargetHint() should not be called after eventProcessingComplete()";
+    if (isDerivedCookie(cookie)) {
+        LOG(FATAL) << "Event target cannot be updated from a derived cookie.";
+    }
+    auto& eventState = getState(cookie);
+    if (eventState->isEventProcessingComplete) {
+        // TODO(b/210460522): Disallow adding new targets after eventProcessingComplete() is called.
+        return;
     }
     // TODO(b/210460522): Determine if the event is sensitive based on the target.
 }
 
 void InputTracer::eventProcessingComplete(const EventTrackerInterface& cookie) {
-    auto& cookieState = getState(cookie);
-    if (!cookieState) {
+    if (isDerivedCookie(cookie)) {
+        LOG(FATAL) << "Event processing cannot be set from a derived cookie.";
+    }
+    auto& eventState = getState(cookie);
+    if (eventState->isEventProcessingComplete) {
         LOG(FATAL) << "Traced event was already logged. "
                       "eventProcessingComplete() was likely called more than once.";
     }
+    eventState->onEventProcessingComplete();
+}
 
-    std::visit(Visitor{[&](const TracedMotionEvent& e) { mBackend->traceMotionEvent(e); },
-                       [&](const TracedKeyEvent& e) { mBackend->traceKeyEvent(e); }},
-               cookieState->event);
-    cookieState.reset();
+std::unique_ptr<EventTrackerInterface> InputTracer::traceDerivedEvent(
+        const EventEntry& entry, const EventTrackerInterface& originalEventCookie) {
+    // This is an event derived from an already-established event. Use the same state to track
+    // this event too.
+    auto eventState = getState(originalEventCookie);
+
+    if (entry.type == EventEntry::Type::MOTION) {
+        const auto& motion = static_cast<const MotionEntry&>(entry);
+        eventState->events.emplace_back(createTracedEvent(motion));
+    } else if (entry.type == EventEntry::Type::KEY) {
+        const auto& key = static_cast<const KeyEntry&>(entry);
+        eventState->events.emplace_back(createTracedEvent(key));
+    } else {
+        LOG(FATAL) << "Cannot trace EventEntry of type: " << ftl::enum_string(entry.type);
+    }
+
+    if (eventState->isEventProcessingComplete) {
+        // It is possible for a derived event to be dispatched some time after the original event
+        // is dispatched, such as in the case of key fallback events. To account for these cases,
+        // derived events can be traced after the processing is complete for the original event.
+        writeEventToBackend(eventState->events.back(), *mBackend);
+    }
+    return std::make_unique<EventTrackerImpl>(std::move(eventState), /*isDerived=*/true);
 }
 
 void InputTracer::traceEventDispatch(const DispatchEntry& dispatchEntry,
-                                     const EventTrackerInterface* cookie) {
+                                     const EventTrackerInterface& cookie) {
+    auto& eventState = getState(cookie);
     const EventEntry& entry = *dispatchEntry.eventEntry;
+    // TODO(b/328618922): Remove resolved key repeats after making repeatCount non-mutable.
+    // The KeyEntry's repeatCount is mutable and can be modified after an event is initially traced,
+    // so we need to find the repeatCount at the time of dispatching to trace it accurately.
+    int32_t resolvedKeyRepeatCount = 0;
 
     TracedEvent traced;
     if (entry.type == EventEntry::Type::MOTION) {
@@ -114,16 +164,19 @@
         traced = createTracedEvent(motion);
     } else if (entry.type == EventEntry::Type::KEY) {
         const auto& key = static_cast<const KeyEntry&>(entry);
+        resolvedKeyRepeatCount = key.repeatCount;
         traced = createTracedEvent(key);
     } else {
         LOG(FATAL) << "Cannot trace EventEntry of type: " << ftl::enum_string(entry.type);
     }
 
-    if (!cookie) {
-        // This event was not tracked as an inbound event, so trace it now.
-        std::visit(Visitor{[&](const TracedMotionEvent& e) { mBackend->traceMotionEvent(e); },
-                           [&](const TracedKeyEvent& e) { mBackend->traceKeyEvent(e); }},
-                   traced);
+    auto tracedEventIt =
+            std::find_if(eventState->events.begin(), eventState->events.end(),
+                         [&traced](const auto& event) { return getId(traced) == getId(event); });
+    if (tracedEventIt == eventState->events.end()) {
+        LOG(FATAL)
+                << __func__
+                << ": Failed to find a previously traced event that matches the dispatched event";
     }
 
     // The vsyncId only has meaning if the event is targeting a window.
@@ -133,31 +186,38 @@
     mBackend->traceWindowDispatch({std::move(traced), dispatchEntry.deliveryTime,
                                    dispatchEntry.resolvedFlags, dispatchEntry.targetUid, vsyncId,
                                    windowId, dispatchEntry.transform, dispatchEntry.rawTransform,
-                                   /*hmac=*/{}});
+                                   /*hmac=*/{}, resolvedKeyRepeatCount});
 }
 
-std::optional<InputTracer::EventState>& InputTracer::getState(const EventTrackerInterface& cookie) {
+std::shared_ptr<InputTracer::EventState>& InputTracer::getState(
+        const EventTrackerInterface& cookie) {
     return static_cast<const EventTrackerImpl&>(cookie).mState;
 }
 
-// --- InputTracer::EventTrackerImpl ---
+bool InputTracer::isDerivedCookie(const EventTrackerInterface& cookie) {
+    return static_cast<const EventTrackerImpl&>(cookie).mIsDerived;
+}
 
-InputTracer::EventTrackerImpl::EventTrackerImpl(InputTracer& tracer, TracedEvent&& event)
-      : mTracer(tracer), mState(event) {}
+// --- InputTracer::EventState ---
 
-InputTracer::EventTrackerImpl::~EventTrackerImpl() {
-    if (!mState) {
+void InputTracer::EventState::onEventProcessingComplete() {
+    // Write all of the events known so far to the trace.
+    for (const auto& event : events) {
+        writeEventToBackend(event, *tracer.mBackend);
+    }
+    isEventProcessingComplete = true;
+}
+
+InputTracer::EventState::~EventState() {
+    if (isEventProcessingComplete) {
         // This event has already been written to the trace as expected.
         return;
     }
-    // We're still holding on to the state, which means it hasn't yet been written to the trace.
-    // Write it to the trace now.
-    // TODO(b/210460522): Determine why/where the event is being destroyed before
-    //   eventProcessingComplete() is called.
-    std::visit(Visitor{[&](const TracedMotionEvent& e) { mTracer.mBackend->traceMotionEvent(e); },
-                       [&](const TracedKeyEvent& e) { mTracer.mBackend->traceKeyEvent(e); }},
-               mState->event);
-    mState.reset();
+    // The event processing was never marked as complete, so do it now.
+    // We should never end up here in normal operation. However, in tests, it's possible that we
+    // stop and destroy InputDispatcher without waiting for it to finish processing events, at
+    // which point an event (and thus its EventState) may be destroyed before processing finishes.
+    onEventProcessingComplete();
 }
 
 } // namespace android::inputdispatcher::trace::impl
diff --git a/services/inputflinger/dispatcher/trace/InputTracer.h b/services/inputflinger/dispatcher/trace/InputTracer.h
index c8b25c9..529c0fa 100644
--- a/services/inputflinger/dispatcher/trace/InputTracer.h
+++ b/services/inputflinger/dispatcher/trace/InputTracer.h
@@ -42,37 +42,48 @@
     InputTracer& operator=(const InputTracer&) = delete;
 
     std::unique_ptr<EventTrackerInterface> traceInboundEvent(const EventEntry&) override;
+    std::unique_ptr<EventTrackerInterface> createTrackerForSyntheticEvent() override;
     void dispatchToTargetHint(const EventTrackerInterface&, const InputTarget&) override;
     void eventProcessingComplete(const EventTrackerInterface&) override;
-    void traceEventDispatch(const DispatchEntry&, const EventTrackerInterface*) override;
+    std::unique_ptr<EventTrackerInterface> traceDerivedEvent(const EventEntry&,
+                                                             const EventTrackerInterface&) override;
+    void traceEventDispatch(const DispatchEntry&, const EventTrackerInterface&) override;
 
 private:
     std::unique_ptr<InputTracingBackendInterface> mBackend;
 
-    // The state of a tracked event.
+    // The state of a tracked event, shared across all events derived from the original event.
     struct EventState {
-        const TracedEvent event;
+        explicit inline EventState(InputTracer& tracer) : tracer(tracer) {}
+        ~EventState();
+
+        void onEventProcessingComplete();
+
+        InputTracer& tracer;
+        std::vector<TracedEvent> events; // Not vector<const T>: allocator<const T> is ill-formed.
+        bool isEventProcessingComplete{false};
         // TODO(b/210460522): Add additional args for tracking event sensitivity and
         //  dispatch target UIDs.
     };
 
     // Get the event state associated with a tracking cookie.
-    std::optional<EventState>& getState(const EventTrackerInterface&);
+    std::shared_ptr<EventState>& getState(const EventTrackerInterface&);
+    bool isDerivedCookie(const EventTrackerInterface&);
 
     // Implementation of the event tracker cookie. The cookie holds the event state directly for
     // convenience to avoid the overhead of tracking the state separately in InputTracer.
     class EventTrackerImpl : public EventTrackerInterface {
     public:
-        explicit EventTrackerImpl(InputTracer&, TracedEvent&& entry);
-        virtual ~EventTrackerImpl() override;
+        inline EventTrackerImpl(const std::shared_ptr<EventState>& state, bool isDerivedEvent)
+              : mState(state), mIsDerived(isDerivedEvent) {}
+        EventTrackerImpl(const EventTrackerImpl&) = default;
 
     private:
-        InputTracer& mTracer;
-        // This event tracker cookie will only hold the state as long as it has not been written
-        // to the trace. The state is released when the event is written to the trace.
-        mutable std::optional<EventState> mState;
+        mutable std::shared_ptr<EventState> mState;
+        const bool mIsDerived;
 
-        friend std::optional<EventState>& InputTracer::getState(const EventTrackerInterface&);
+        friend std::shared_ptr<EventState>& InputTracer::getState(const EventTrackerInterface&);
+        friend bool InputTracer::isDerivedCookie(const EventTrackerInterface&);
     };
 };
 
diff --git a/services/inputflinger/dispatcher/trace/InputTracerInterface.h b/services/inputflinger/dispatcher/trace/InputTracerInterface.h
index c6cd7de..609d10c 100644
--- a/services/inputflinger/dispatcher/trace/InputTracerInterface.h
+++ b/services/inputflinger/dispatcher/trace/InputTracerInterface.h
@@ -54,6 +54,14 @@
     virtual std::unique_ptr<EventTrackerInterface> traceInboundEvent(const EventEntry&) = 0;
 
     /**
+     * Create a trace tracker for a synthetic event that does not stem from an inbound input event.
+     * This includes things like generating cancellations or down events for various reasons,
+     * such as ANR, pilfering, transfer touch, etc. Any key or motion events generated for this
+     * synthetic event should be traced as a derived event using {@link #traceDerivedEvent}.
+     */
+    virtual std::unique_ptr<EventTrackerInterface> createTrackerForSyntheticEvent() = 0;
+
+    /**
      * Notify the tracer that the traced event will be sent to the given InputTarget.
      * The tracer may change how the event is logged depending on the target. For example,
      * events targeting certain UIDs may be logged as sensitive events.
@@ -76,12 +84,25 @@
     virtual void eventProcessingComplete(const EventTrackerInterface&) = 0;
 
     /**
-     * Trace an input event being successfully dispatched to a window. The dispatched event may
-     * be a previously traced inbound event, or it may be a synthesized event that has not been
-     * previously traced. For inbound events that were previously traced, the EventTracker cookie
-     * must be provided. For events that were not previously traced, the cookie must be null.
+     * Trace an input event that is derived from another event. This is used in cases where an event
+     * is modified from the original, such as when a touch is split across multiple windows, or
+     * when a HOVER_MOVE event is modified to be a HOVER_EXIT, etc. The original event's tracker
+     * must be provided, and a new EventTracker is returned that should be used to track the event's
+     * lifecycle.
+     *
+     * NOTE: The derived tracker cannot be used to change the targets of the original event, meaning
+     * it cannot be used with {@link #dispatchToTargetHint} or {@link #eventProcessingComplete}.
      */
-    virtual void traceEventDispatch(const DispatchEntry&, const EventTrackerInterface*) = 0;
+    virtual std::unique_ptr<EventTrackerInterface> traceDerivedEvent(
+            const EventEntry&, const EventTrackerInterface& originalEventTracker) = 0;
+
+    /**
+     * Trace an input event being successfully dispatched to a window. The dispatched event may
+     * be a previously traced inbound event, or it may be a synthesized event. All dispatched events
+     * must have been previously traced, so the trace tracker associated with the event must be
+     * provided.
+     */
+    virtual void traceEventDispatch(const DispatchEntry&, const EventTrackerInterface&) = 0;
 };
 
 } // namespace android::inputdispatcher::trace
diff --git a/services/inputflinger/dispatcher/trace/InputTracingBackendInterface.h b/services/inputflinger/dispatcher/trace/InputTracingBackendInterface.h
index b0eadfe..94a86b9 100644
--- a/services/inputflinger/dispatcher/trace/InputTracingBackendInterface.h
+++ b/services/inputflinger/dispatcher/trace/InputTracingBackendInterface.h
@@ -98,6 +98,7 @@
         ui::Transform transform;
         ui::Transform rawTransform;
         std::array<uint8_t, 32> hmac;
+        int32_t resolvedKeyRepeatCount;
     };
     virtual void traceWindowDispatch(const WindowDispatchArgs&) = 0;
 };
diff --git a/services/inputflinger/include/PointerChoreographerPolicyInterface.h b/services/inputflinger/include/PointerChoreographerPolicyInterface.h
index 8b47b55..462aedc 100644
--- a/services/inputflinger/include/PointerChoreographerPolicyInterface.h
+++ b/services/inputflinger/include/PointerChoreographerPolicyInterface.h
@@ -25,6 +25,9 @@
  *
  * This is the interface that PointerChoreographer uses to talk to Window Manager and other
  * system components.
+ *
+ * NOTE: In general, the PointerChoreographer must not interact with the policy while
+ * holding any locks.
  */
 class PointerChoreographerPolicyInterface {
 public:
@@ -37,6 +40,9 @@
      * for and runnable on the host, the PointerController implementation must be in a separate
      * library, libinputservice, that has the additional dependencies. The PointerController
      * will be mocked when testing PointerChoreographer.
+     *
+     * Since this is a factory method used to work around dependencies, it will not interact with
+     * other input components and may be called with the PointerChoreographer lock held.
      */
     virtual std::shared_ptr<PointerControllerInterface> createPointerController(
             PointerControllerInterface::ControllerType type) = 0;
diff --git a/services/inputflinger/tests/FakeInputTracingBackend.cpp b/services/inputflinger/tests/FakeInputTracingBackend.cpp
index 4655ee8..08738e3 100644
--- a/services/inputflinger/tests/FakeInputTracingBackend.cpp
+++ b/services/inputflinger/tests/FakeInputTracingBackend.cpp
@@ -56,8 +56,8 @@
                       const std::array<uint8_t, 32>& hmac) {
     KeyEvent traced;
     traced.initialize(e.id, e.deviceId, e.source, e.displayId, hmac, e.action,
-                      dispatchArgs.resolvedFlags, e.keyCode, e.scanCode, e.metaState, e.repeatCount,
-                      e.downTime, e.eventTime);
+                      dispatchArgs.resolvedFlags, e.keyCode, e.scanCode, e.metaState,
+                      dispatchArgs.resolvedKeyRepeatCount, e.downTime, e.eventTime);
     return traced;
 }
 
diff --git a/services/inputflinger/tests/FocusResolver_test.cpp b/services/inputflinger/tests/FocusResolver_test.cpp
index 2ff9c3c..cb8c3cb 100644
--- a/services/inputflinger/tests/FocusResolver_test.cpp
+++ b/services/inputflinger/tests/FocusResolver_test.cpp
@@ -20,6 +20,7 @@
 
 #define ASSERT_FOCUS_CHANGE(_changes, _oldFocus, _newFocus) \
     {                                                       \
+        ASSERT_TRUE(_changes.has_value());                  \
         ASSERT_EQ(_oldFocus, _changes->oldFocus);           \
         ASSERT_EQ(_newFocus, _changes->newFocus);           \
     }
@@ -152,6 +153,38 @@
     ASSERT_FOCUS_CHANGE(changes, /*from*/ invisibleWindowToken, /*to*/ nullptr);
 }
 
+TEST(FocusResolverTest, FocusTransferToMirror) {
+    sp<IBinder> focusableWindowToken = sp<BBinder>::make();
+    auto window = sp<FakeWindowHandle>::make("Window", focusableWindowToken,
+                                             /*focusable=*/true, /*visible=*/true);
+    auto mirror = sp<FakeWindowHandle>::make("Mirror", focusableWindowToken,
+                                             /*focusable=*/true, /*visible=*/true);
+
+    FocusRequest request;
+    request.displayId = 42;
+    request.token = focusableWindowToken;
+    FocusResolver focusResolver;
+    std::optional<FocusResolver::FocusChanges> changes =
+            focusResolver.setFocusedWindow(request, {window, mirror});
+    ASSERT_FOCUS_CHANGE(changes, /*from*/ nullptr, /*to*/ focusableWindowToken);
+
+    // The mirror window now comes on top, and the focus does not change
+    changes = focusResolver.setInputWindows(request.displayId, {mirror, window});
+    ASSERT_FALSE(changes.has_value());
+
+    // The window now comes on top while the mirror is removed, and the focus does not change
+    changes = focusResolver.setInputWindows(request.displayId, {window});
+    ASSERT_FALSE(changes.has_value());
+
+    // The window is removed but the mirror is on top, and focus does not change
+    changes = focusResolver.setInputWindows(request.displayId, {mirror});
+    ASSERT_FALSE(changes.has_value());
+
+    // All windows removed
+    changes = focusResolver.setInputWindows(request.displayId, {});
+    ASSERT_FOCUS_CHANGE(changes, /*from*/ focusableWindowToken, /*to*/ nullptr);
+}
+
 TEST(FocusResolverTest, SetInputWindows) {
     sp<IBinder> focusableWindowToken = sp<BBinder>::make();
     std::vector<sp<WindowInfoHandle>> windows;
@@ -169,6 +202,10 @@
             focusResolver.setFocusedWindow(request, windows);
     ASSERT_EQ(focusableWindowToken, changes->newFocus);
 
+    // When there are no changes to the window, focus does not change
+    changes = focusResolver.setInputWindows(request.displayId, windows);
+    ASSERT_FALSE(changes.has_value());
+
     // Window visibility changes and the window loses focus
     window->setVisible(false);
     changes = focusResolver.setInputWindows(request.displayId, windows);
@@ -380,18 +417,13 @@
     ASSERT_FOCUS_CHANGE(changes, /*from*/ nullptr, /*to*/ windowToken);
     ASSERT_EQ(request.displayId, changes->displayId);
 
-    // Start with a focused window
-    window->setFocusable(true);
-    changes = focusResolver.setInputWindows(request.displayId, windows);
-    ASSERT_FOCUS_CHANGE(changes, /*from*/ nullptr, /*to*/ windowToken);
-
     // When a display is removed, all windows are removed from the display
     // and our focused window loses focus
     changes = focusResolver.setInputWindows(request.displayId, {});
     ASSERT_FOCUS_CHANGE(changes, /*from*/ windowToken, /*to*/ nullptr);
     focusResolver.displayRemoved(request.displayId);
 
-    // When a display is readded, the window does not get focus since the request was cleared.
+    // When a display is re-added, the window does not get focus since the request was cleared.
     changes = focusResolver.setInputWindows(request.displayId, windows);
     ASSERT_FALSE(changes);
 }
diff --git a/services/surfaceflinger/Scheduler/RefreshRateSelector.cpp b/services/surfaceflinger/Scheduler/RefreshRateSelector.cpp
index ffd3463..ad59f1a 100644
--- a/services/surfaceflinger/Scheduler/RefreshRateSelector.cpp
+++ b/services/surfaceflinger/Scheduler/RefreshRateSelector.cpp
@@ -834,13 +834,16 @@
     const bool touchBoostForExplicitExact = [&] {
         if (supportsAppFrameRateOverrideByContent()) {
             // Enable touch boost if there are other layers besides exact
-            return explicitExact + noVoteLayers != layers.size();
+            return explicitExact + noVoteLayers + explicitGteLayers != layers.size();
         } else {
             // Enable touch boost if there are no exact layers
             return explicitExact == 0;
         }
     }();
 
+    const bool touchBoostForCategory =
+            explicitCategoryVoteLayers + noVoteLayers + explicitGteLayers != layers.size();
+
     const auto touchRefreshRates = rankFrameRates(anchorGroup, RefreshRateOrder::Descending);
     using fps_approx_ops::operator<;
 
@@ -851,6 +854,7 @@
     const bool hasInteraction = signals.touch || interactiveLayers > 0;
 
     if (hasInteraction && explicitDefaultVoteLayers == 0 && touchBoostForExplicitExact &&
+        touchBoostForCategory &&
         scores.front().frameRateMode.fps < touchRefreshRates.front().frameRateMode.fps) {
         ALOGV("Touch Boost");
         ATRACE_FORMAT_INSTANT("%s (Touch Boost [late])",
@@ -1554,19 +1558,17 @@
         case FrameRateCategory::High:
             return FpsRange{90_Hz, 120_Hz};
         case FrameRateCategory::Normal:
-            return FpsRange{60_Hz, 90_Hz};
+            return FpsRange{60_Hz, 120_Hz};
         case FrameRateCategory::Low:
-            return FpsRange{30_Hz, 30_Hz};
+            return FpsRange{30_Hz, 120_Hz};
         case FrameRateCategory::HighHint:
         case FrameRateCategory::NoPreference:
         case FrameRateCategory::Default:
             LOG_ALWAYS_FATAL("Should not get fps range for frame rate category: %s",
                              ftl::enum_string(category).c_str());
-            return FpsRange{0_Hz, 0_Hz};
         default:
             LOG_ALWAYS_FATAL("Invalid frame rate category for range: %s",
                              ftl::enum_string(category).c_str());
-            return FpsRange{0_Hz, 0_Hz};
     }
 }
 
diff --git a/services/surfaceflinger/Scheduler/Scheduler.cpp b/services/surfaceflinger/Scheduler/Scheduler.cpp
index 3f91682..d92edb8 100644
--- a/services/surfaceflinger/Scheduler/Scheduler.cpp
+++ b/services/surfaceflinger/Scheduler/Scheduler.cpp
@@ -565,7 +565,7 @@
             }));
 }
 
-void Scheduler::setRenderRate(PhysicalDisplayId id, Fps renderFrameRate) {
+void Scheduler::setRenderRate(PhysicalDisplayId id, Fps renderFrameRate, bool applyImmediately) {
     std::scoped_lock lock(mDisplayLock);
     ftl::FakeGuard guard(kMainThreadContext);
 
@@ -586,7 +586,7 @@
     ALOGV("%s %s (%s)", __func__, to_string(mode.fps).c_str(),
           to_string(mode.modePtr->getVsyncRate()).c_str());
 
-    display.schedulePtr->getTracker().setRenderRate(renderFrameRate);
+    display.schedulePtr->getTracker().setRenderRate(renderFrameRate, applyImmediately);
 }
 
 Fps Scheduler::getNextFrameInterval(PhysicalDisplayId id,
diff --git a/services/surfaceflinger/Scheduler/Scheduler.h b/services/surfaceflinger/Scheduler/Scheduler.h
index 09f75fd..494a91b 100644
--- a/services/surfaceflinger/Scheduler/Scheduler.h
+++ b/services/surfaceflinger/Scheduler/Scheduler.h
@@ -188,7 +188,7 @@
     const VsyncConfiguration& getVsyncConfiguration() const { return *mVsyncConfiguration; }
 
     // Sets the render rate for the scheduler to run at.
-    void setRenderRate(PhysicalDisplayId, Fps);
+    void setRenderRate(PhysicalDisplayId, Fps, bool applyImmediately);
 
     void enableHardwareVsync(PhysicalDisplayId) REQUIRES(kMainThreadContext);
     void disableHardwareVsync(PhysicalDisplayId, bool disallow) REQUIRES(kMainThreadContext);
diff --git a/services/surfaceflinger/Scheduler/VSyncDispatchTimerQueue.cpp b/services/surfaceflinger/Scheduler/VSyncDispatchTimerQueue.cpp
index 84ccf8e..6d6b70d 100644
--- a/services/surfaceflinger/Scheduler/VSyncDispatchTimerQueue.cpp
+++ b/services/surfaceflinger/Scheduler/VSyncDispatchTimerQueue.cpp
@@ -20,7 +20,7 @@
 
 #include <android-base/stringprintf.h>
 #include <ftl/concat.h>
-#include <utils/Trace.h>
+#include <gui/TraceUtils.h>
 #include <log/log_main.h>
 
 #include <scheduler/TimeKeeper.h>
@@ -44,6 +44,17 @@
             TimePoint::fromNs(nextVsyncTime)};
 }
 
+void traceEntry(const VSyncDispatchTimerQueueEntry& entry, nsecs_t now) {
+    if (!ATRACE_ENABLED() || !entry.wakeupTime().has_value() || !entry.targetVsync().has_value()) {
+        return;
+    }
+
+    ftl::Concat trace(ftl::truncated<5>(entry.name()), " alarm in ",
+                      ns2us(*entry.wakeupTime() - now), "us; VSYNC in ",
+                      ns2us(*entry.targetVsync() - now), "us");
+    ATRACE_FORMAT_INSTANT("%s", trace.c_str()); // Entry name is data, never a format string.
+}
+
 } // namespace
 
 VSyncDispatch::~VSyncDispatch() = default;
@@ -87,6 +98,7 @@
 
 ScheduleResult VSyncDispatchTimerQueueEntry::schedule(VSyncDispatch::ScheduleTiming timing,
                                                       VSyncTracker& tracker, nsecs_t now) {
+    ATRACE_NAME("VSyncDispatchTimerQueueEntry::schedule");
     auto nextVsyncTime =
             tracker.nextAnticipatedVSyncTimeFrom(std::max(timing.lastVsync,
                                                           now + timing.workDuration +
@@ -98,6 +110,8 @@
             mArmedInfo && (nextVsyncTime > (mArmedInfo->mActualVsyncTime + mMinVsyncDistance));
     bool const wouldSkipAWakeup =
             mArmedInfo && ((nextWakeupTime > (mArmedInfo->mActualWakeupTime + mMinVsyncDistance)));
+    ATRACE_FORMAT_INSTANT("%s: wouldSkipAVsyncTarget=%d wouldSkipAWakeup=%d", mName.c_str(),
+                          wouldSkipAVsyncTarget, wouldSkipAWakeup);
     if (FlagManager::getInstance().dont_skip_on_early_ro()) {
         if (wouldSkipAVsyncTarget || wouldSkipAWakeup) {
             nextVsyncTime = mArmedInfo->mActualVsyncTime;
@@ -122,7 +136,7 @@
 ScheduleResult VSyncDispatchTimerQueueEntry::addPendingWorkloadUpdate(
         VSyncTracker& tracker, nsecs_t now, VSyncDispatch::ScheduleTiming timing) {
     mWorkloadUpdateInfo = timing;
-    const auto armedInfo = update(tracker, now, timing, mArmedInfo);
+    const auto armedInfo = getArmedInfo(tracker, now, timing, mArmedInfo);
     return {TimePoint::fromNs(armedInfo.mActualWakeupTime),
             TimePoint::fromNs(armedInfo.mActualVsyncTime)};
 }
@@ -140,11 +154,13 @@
     bool const nextVsyncTooClose = mLastDispatchTime &&
             (nextVsyncTime - *mLastDispatchTime + mMinVsyncDistance) <= currentPeriod;
     if (alreadyDispatchedForVsync) {
+        ATRACE_FORMAT_INSTANT("alreadyDispatchedForVsync");
         return tracker.nextAnticipatedVSyncTimeFrom(*mLastDispatchTime + mMinVsyncDistance,
                                                     *mLastDispatchTime);
     }
 
     if (nextVsyncTooClose) {
+        ATRACE_FORMAT_INSTANT("nextVsyncTooClose");
         return tracker.nextAnticipatedVSyncTimeFrom(*mLastDispatchTime + currentPeriod,
                                                     *mLastDispatchTime + currentPeriod);
     }
@@ -152,9 +168,11 @@
     return nextVsyncTime;
 }
 
-auto VSyncDispatchTimerQueueEntry::update(VSyncTracker& tracker, nsecs_t now,
-                                          VSyncDispatch::ScheduleTiming timing,
-                                          std::optional<ArmingInfo> armedInfo) const -> ArmingInfo {
+auto VSyncDispatchTimerQueueEntry::getArmedInfo(VSyncTracker& tracker, nsecs_t now,
+                                                VSyncDispatch::ScheduleTiming timing,
+                                                std::optional<ArmingInfo> armedInfo) const
+        -> ArmingInfo {
+    ATRACE_NAME("VSyncDispatchTimerQueueEntry::getArmedInfo");
     const auto earliestReadyBy = now + timing.workDuration + timing.readyDuration;
     const auto earliestVsync = std::max(earliestReadyBy, timing.lastVsync);
 
@@ -165,29 +183,39 @@
     const auto nextReadyTime = nextVsyncTime - timing.readyDuration;
     const auto nextWakeupTime = nextReadyTime - timing.workDuration;
 
-    bool const wouldSkipAVsyncTarget =
-            armedInfo && (nextVsyncTime > (armedInfo->mActualVsyncTime + mMinVsyncDistance));
-    bool const wouldSkipAWakeup =
-            armedInfo && (nextWakeupTime > (armedInfo->mActualWakeupTime + mMinVsyncDistance));
-    if (FlagManager::getInstance().dont_skip_on_early_ro() &&
-        (wouldSkipAVsyncTarget || wouldSkipAWakeup)) {
-        return *armedInfo;
+    if (FlagManager::getInstance().dont_skip_on_early_ro()) {
+        bool const wouldSkipAVsyncTarget =
+                armedInfo && (nextVsyncTime > (armedInfo->mActualVsyncTime + mMinVsyncDistance));
+        bool const wouldSkipAWakeup =
+                armedInfo && (nextWakeupTime > (armedInfo->mActualWakeupTime + mMinVsyncDistance));
+        ATRACE_FORMAT_INSTANT("%s: wouldSkipAVsyncTarget=%d wouldSkipAWakeup=%d", mName.c_str(),
+                              wouldSkipAVsyncTarget, wouldSkipAWakeup);
+        if (wouldSkipAVsyncTarget || wouldSkipAWakeup) {
+            return *armedInfo;
+        }
     }
 
     return ArmingInfo{nextWakeupTime, nextVsyncTime, nextReadyTime};
 }
 
 void VSyncDispatchTimerQueueEntry::update(VSyncTracker& tracker, nsecs_t now) {
+    ATRACE_NAME("VSyncDispatchTimerQueueEntry::update");
     if (!mArmedInfo && !mWorkloadUpdateInfo) {
         return;
     }
 
     if (mWorkloadUpdateInfo) {
+        const auto workDelta = mWorkloadUpdateInfo->workDuration - mScheduleTiming.workDuration;
+        const auto readyDelta = mWorkloadUpdateInfo->readyDuration - mScheduleTiming.readyDuration;
+        const auto lastVsyncDelta = mWorkloadUpdateInfo->lastVsync - mScheduleTiming.lastVsync;
+        ATRACE_FORMAT_INSTANT("Workload updated workDelta=%" PRId64 " readyDelta=%" PRId64
+                              " lastVsyncDelta=%" PRId64,
+                              workDelta, readyDelta, lastVsyncDelta);
         mScheduleTiming = *mWorkloadUpdateInfo;
         mWorkloadUpdateInfo.reset();
     }
 
-    mArmedInfo = update(tracker, now, mScheduleTiming, mArmedInfo);
+    mArmedInfo = getArmedInfo(tracker, now, mScheduleTiming, mArmedInfo);
 }
 
 void VSyncDispatchTimerQueueEntry::disarm() {
@@ -282,6 +310,7 @@
 
 void VSyncDispatchTimerQueue::rearmTimerSkippingUpdateFor(
         nsecs_t now, CallbackMap::const_iterator skipUpdateIt) {
+    ATRACE_CALL();
     std::optional<nsecs_t> min;
     std::optional<nsecs_t> targetVsync;
     std::optional<std::string_view> nextWakeupName;
@@ -294,7 +323,10 @@
         if (it != skipUpdateIt) {
             callback->update(*mTracker, now);
         }
-        auto const wakeupTime = *callback->wakeupTime();
+
+        traceEntry(*callback, now);
+
+        const auto wakeupTime = *callback->wakeupTime();
         if (!min || *min > wakeupTime) {
             nextWakeupName = callback->name();
             min = wakeupTime;
@@ -303,11 +335,6 @@
     }
 
     if (min && min < mIntendedWakeupTime) {
-        if (ATRACE_ENABLED() && nextWakeupName && targetVsync) {
-            ftl::Concat trace(ftl::truncated<5>(*nextWakeupName), " alarm in ", ns2us(*min - now),
-                              "us; VSYNC in ", ns2us(*targetVsync - now), "us");
-            ATRACE_NAME(trace.c_str());
-        }
         setTimer(*min, now);
     } else {
         ATRACE_NAME("cancel timer");
@@ -316,6 +343,7 @@
 }
 
 void VSyncDispatchTimerQueue::timerCallback() {
+    ATRACE_CALL();
     struct Invocation {
         std::shared_ptr<VSyncDispatchTimerQueueEntry> callback;
         nsecs_t vsyncTimestamp;
@@ -338,8 +366,9 @@
                 continue;
             }
 
-            auto const readyTime = callback->readyTime();
+            traceEntry(*callback, now);
 
+            auto const readyTime = callback->readyTime();
             auto const lagAllowance = std::max(now - mIntendedWakeupTime, static_cast<nsecs_t>(0));
             if (*wakeupTime < mIntendedWakeupTime + mTimerSlack + lagAllowance) {
                 callback->executing();
@@ -353,6 +382,8 @@
     }
 
     for (auto const& invocation : invocations) {
+        ftl::Concat trace(ftl::truncated<5>(invocation.callback->name()));
+        ATRACE_FORMAT("%s: %s", __func__, trace.c_str());
         invocation.callback->callback(invocation.vsyncTimestamp, invocation.wakeupTimestamp,
                                       invocation.deadlineTimestamp);
     }
diff --git a/services/surfaceflinger/Scheduler/VSyncDispatchTimerQueue.h b/services/surfaceflinger/Scheduler/VSyncDispatchTimerQueue.h
index 252c09c..e4ddc03 100644
--- a/services/surfaceflinger/Scheduler/VSyncDispatchTimerQueue.h
+++ b/services/surfaceflinger/Scheduler/VSyncDispatchTimerQueue.h
@@ -91,8 +91,8 @@
     };
 
     nsecs_t adjustVsyncIfNeeded(VSyncTracker& tracker, nsecs_t nextVsyncTime) const;
-    ArmingInfo update(VSyncTracker&, nsecs_t now, VSyncDispatch::ScheduleTiming,
-                      std::optional<ArmingInfo>) const;
+    ArmingInfo getArmedInfo(VSyncTracker&, nsecs_t now, VSyncDispatch::ScheduleTiming,
+                            std::optional<ArmingInfo>) const;
 
     const std::string mName;
     const VSyncDispatch::Callback mCallback;
diff --git a/services/surfaceflinger/Scheduler/VSyncPredictor.cpp b/services/surfaceflinger/Scheduler/VSyncPredictor.cpp
index 8697696..db1930d 100644
--- a/services/surfaceflinger/Scheduler/VSyncPredictor.cpp
+++ b/services/surfaceflinger/Scheduler/VSyncPredictor.cpp
@@ -45,16 +45,28 @@
 
 static auto constexpr kMaxPercent = 100u;
 
+namespace {
+int numVsyncsPerFrame(const ftl::NonNull<DisplayModePtr>& displayModePtr) {
+    const auto idealPeakRefreshPeriod = displayModePtr->getPeakFps().getPeriodNsecs();
+    const auto idealRefreshPeriod = displayModePtr->getVsyncRate().getPeriodNsecs();
+    return static_cast<int>(std::round(static_cast<float>(idealPeakRefreshPeriod) /
+                                       static_cast<float>(idealRefreshPeriod)));
+}
+} // namespace
+
 VSyncPredictor::~VSyncPredictor() = default;
 
-VSyncPredictor::VSyncPredictor(ftl::NonNull<DisplayModePtr> modePtr, size_t historySize,
-                               size_t minimumSamplesForPrediction, uint32_t outlierTolerancePercent)
-      : mId(modePtr->getPhysicalDisplayId()),
+VSyncPredictor::VSyncPredictor(std::unique_ptr<Clock> clock, ftl::NonNull<DisplayModePtr> modePtr,
+                               size_t historySize, size_t minimumSamplesForPrediction,
+                               uint32_t outlierTolerancePercent)
+      : mClock(std::move(clock)),
+        mId(modePtr->getPhysicalDisplayId()),
         mTraceOn(property_get_bool("debug.sf.vsp_trace", false)),
         kHistorySize(historySize),
         kMinimumSamplesForPrediction(minimumSamplesForPrediction),
         kOutlierTolerancePercent(std::min(outlierTolerancePercent, kMaxPercent)),
-        mDisplayModePtr(modePtr) {
+        mDisplayModePtr(modePtr),
+        mNumVsyncsForFrame(numVsyncsPerFrame(mDisplayModePtr)) {
     resetModel();
 }
 
@@ -118,11 +130,8 @@
 }
 
 Period VSyncPredictor::minFramePeriodLocked() const {
-    const auto idealPeakRefreshPeriod = mDisplayModePtr->getPeakFps().getPeriodNsecs();
-    const auto numPeriods = static_cast<int>(std::round(static_cast<float>(idealPeakRefreshPeriod) /
-                                                        static_cast<float>(idealPeriod())));
     const auto slope = mRateMap.find(idealPeriod())->second.slope;
-    return Period::fromNs(slope * numPeriods);
+    return Period::fromNs(slope * mNumVsyncsForFrame);
 }
 
 bool VSyncPredictor::addVsyncTimestamp(nsecs_t timestamp) {
@@ -147,7 +156,7 @@
             mKnownTimestamp = timestamp;
         }
         ATRACE_FORMAT_INSTANT("timestamp rejected. mKnownTimestamp was %.2fms ago",
-            (systemTime() - *mKnownTimestamp) / 1e6f);
+                              (mClock->now() - *mKnownTimestamp) / 1e6f);
         return false;
     }
 
@@ -250,17 +259,6 @@
     return true;
 }
 
-auto VSyncPredictor::getVsyncSequenceLocked(nsecs_t timestamp) const -> VsyncSequence {
-    const auto vsync = snapToVsync(timestamp);
-    if (!mLastVsyncSequence) return {vsync, 0};
-
-    const auto [slope, _] = getVSyncPredictionModelLocked();
-    const auto [lastVsyncTime, lastVsyncSequence] = *mLastVsyncSequence;
-    const auto vsyncSequence = lastVsyncSequence +
-            static_cast<int64_t>(std::round((vsync - lastVsyncTime) / static_cast<float>(slope)));
-    return {vsync, vsyncSequence};
-}
-
 nsecs_t VSyncPredictor::snapToVsync(nsecs_t timePoint) const {
     auto const [slope, intercept] = getVSyncPredictionModelLocked();
 
@@ -298,51 +296,45 @@
 }
 
 nsecs_t VSyncPredictor::nextAnticipatedVSyncTimeFrom(nsecs_t timePoint,
-                                                     std::optional<nsecs_t> lastVsyncOpt) const {
+                                                     std::optional<nsecs_t> lastVsyncOpt) {
     ATRACE_CALL();
     std::lock_guard lock(mMutex);
-    const auto currentPeriod = mRateMap.find(idealPeriod())->second.slope;
-    const auto threshold = currentPeriod / 2;
-    const auto minFramePeriod = minFramePeriodLocked().ns();
-    const auto lastFrameMissed =
-            lastVsyncOpt && std::abs(*lastVsyncOpt - mLastMissedVsync.ns()) < threshold;
-    const nsecs_t baseTime =
-            FlagManager::getInstance().vrr_config() && !lastFrameMissed && lastVsyncOpt
-            ? std::max(timePoint, *lastVsyncOpt + minFramePeriod - threshold)
-            : timePoint;
-    return snapToVsyncAlignedWithRenderRate(baseTime);
-}
 
-nsecs_t VSyncPredictor::snapToVsyncAlignedWithRenderRate(nsecs_t timePoint) const {
-    // update the mLastVsyncSequence for reference point
-    mLastVsyncSequence = getVsyncSequenceLocked(timePoint);
+    const auto now = TimePoint::fromNs(mClock->now());
+    purgeTimelines(now);
 
-    const auto renderRatePhase = [&]() REQUIRES(mMutex) -> int {
-        if (!mRenderRateOpt) return 0;
-        const auto divisor =
-                RefreshRateSelector::getFrameRateDivisor(Fps::fromPeriodNsecs(idealPeriod()),
-                                                         *mRenderRateOpt);
-        if (divisor <= 1) return 0;
-
-        int mod = mLastVsyncSequence->seq % divisor;
-        if (mod == 0) return 0;
-
-        // This is actually a bug fix, but guarded with vrr_config since we found it with this
-        // config
-        if (FlagManager::getInstance().vrr_config()) {
-            if (mod < 0) mod += divisor;
-        }
-
-        return divisor - mod;
-    }();
-
-    if (renderRatePhase == 0) {
-        return mLastVsyncSequence->vsyncTime;
+    if (lastVsyncOpt && *lastVsyncOpt > timePoint) {
+        timePoint = *lastVsyncOpt;
     }
 
-    auto const [slope, intercept] = getVSyncPredictionModelLocked();
-    const auto approximateNextVsync = mLastVsyncSequence->vsyncTime + slope * renderRatePhase;
-    return snapToVsync(approximateNextVsync - slope / 2);
+    const auto model = getVSyncPredictionModelLocked();
+    const auto threshold = model.slope / 2;
+    std::optional<Period> minFramePeriodOpt;
+
+    if (mNumVsyncsForFrame > 1) {
+        minFramePeriodOpt = minFramePeriodLocked();
+    }
+
+    std::optional<TimePoint> vsyncOpt;
+    for (auto& timeline : mTimelines) {
+        vsyncOpt = timeline.nextAnticipatedVSyncTimeFrom(model, minFramePeriodOpt,
+                                                         snapToVsync(timePoint), mMissedVsync,
+                                                         lastVsyncOpt ? snapToVsync(*lastVsyncOpt -
+                                                                                    threshold)
+                                                                      : lastVsyncOpt);
+        if (vsyncOpt) {
+            break;
+        }
+    }
+    LOG_ALWAYS_FATAL_IF(!vsyncOpt);
+
+    if (*vsyncOpt > mLastCommittedVsync) {
+        mLastCommittedVsync = *vsyncOpt;
+        ATRACE_FORMAT_INSTANT("mLastCommittedVsync in %.2fms",
+                              float(mLastCommittedVsync.ns() - mClock->now()) / 1e6f);
+    }
+
+    return vsyncOpt->ns();
 }
 
 /*
@@ -353,39 +345,56 @@
  * isVSyncInPhase(33.3, 30) = false
  * isVSyncInPhase(50.0, 30) = true
  */
-bool VSyncPredictor::isVSyncInPhase(nsecs_t timePoint, Fps frameRate) const {
-    std::lock_guard lock(mMutex);
-    const auto divisor =
-            RefreshRateSelector::getFrameRateDivisor(Fps::fromPeriodNsecs(idealPeriod()),
-                                                     frameRate);
-    return isVSyncInPhaseLocked(timePoint, static_cast<unsigned>(divisor));
-}
-
-bool VSyncPredictor::isVSyncInPhaseLocked(nsecs_t timePoint, unsigned divisor) const {
-    const TimePoint now = TimePoint::now();
-    const auto getTimePointIn = [](TimePoint now, nsecs_t timePoint) -> float {
-        return ticks<std::milli, float>(TimePoint::fromNs(timePoint) - now);
-    };
-    ATRACE_FORMAT("%s timePoint in: %.2f divisor: %zu", __func__, getTimePointIn(now, timePoint),
-                  divisor);
-
-    if (divisor <= 1 || timePoint == 0) {
+bool VSyncPredictor::isVSyncInPhase(nsecs_t timePoint, Fps frameRate) {
+    if (timePoint == 0) {
         return true;
     }
 
-    const nsecs_t period = mRateMap[idealPeriod()].slope;
+    std::lock_guard lock(mMutex);
+    const auto model = getVSyncPredictionModelLocked();
+    const nsecs_t period = model.slope;
     const nsecs_t justBeforeTimePoint = timePoint - period / 2;
-    const auto vsyncSequence = getVsyncSequenceLocked(justBeforeTimePoint);
-    ATRACE_FORMAT_INSTANT("vsync in: %.2f sequence: %" PRId64,
-                          getTimePointIn(now, vsyncSequence.vsyncTime), vsyncSequence.seq);
-    return vsyncSequence.seq % divisor == 0;
+    const auto now = TimePoint::fromNs(mClock->now());
+    const auto vsync = snapToVsync(justBeforeTimePoint);
+
+    purgeTimelines(now);
+
+    for (auto& timeline : mTimelines) {
+        if (timeline.validUntil() && timeline.validUntil()->ns() > vsync) {
+            return timeline.isVSyncInPhase(model, vsync, frameRate);
+        }
+    }
+
+    // The last timeline should always be valid
+    return mTimelines.back().isVSyncInPhase(model, vsync, frameRate);
 }
 
-void VSyncPredictor::setRenderRate(Fps renderRate) {
+void VSyncPredictor::setRenderRate(Fps renderRate, bool applyImmediately) {
     ATRACE_FORMAT("%s %s", __func__, to_string(renderRate).c_str());
     ALOGV("%s %s: RenderRate %s ", __func__, to_string(mId).c_str(), to_string(renderRate).c_str());
     std::lock_guard lock(mMutex);
+    const auto renderPeriodDelta =
+            mRenderRateOpt ? mRenderRateOpt->getPeriodNsecs() - renderRate.getPeriodNsecs() : 0;
     mRenderRateOpt = renderRate;
+    if (applyImmediately) {
+        // Retarget in place. Appending would leave an unfrozen timeline in front that shadows it.
+        while (mTimelines.size() > 1) {
+            mTimelines.pop_front();
+        }
+        mTimelines.front().setRenderRate(renderRate);
+        return;
+    }
+    const bool newRenderRateIsHigher = renderPeriodDelta > renderRate.getPeriodNsecs() &&
+            mLastCommittedVsync.ns() - mClock->now() > 2 * renderRate.getPeriodNsecs();
+    if (newRenderRateIsHigher) {
+        mTimelines.clear();
+        mLastCommittedVsync = TimePoint::fromNs(0);
+    } else {
+        mTimelines.back().freeze(
+                TimePoint::fromNs(mLastCommittedVsync.ns() + mIdealPeriod.ns() / 2));
+    }
+    mTimelines.emplace_back(mLastCommittedVsync, mIdealPeriod, renderRate);
+    purgeTimelines(TimePoint::fromNs(mClock->now()));
 }
 
 void VSyncPredictor::setDisplayModePtr(ftl::NonNull<DisplayModePtr> modePtr) {
@@ -401,6 +410,7 @@
     std::lock_guard lock(mMutex);
 
     mDisplayModePtr = modePtr;
+    mNumVsyncsForFrame = numVsyncsPerFrame(mDisplayModePtr);
     traceInt64("VSP-setPeriod", modePtr->getVsyncRate().getPeriodNsecs());
 
     static constexpr size_t kSizeLimit = 30;
@@ -415,8 +425,14 @@
     clearTimestamps();
 }
 
-void VSyncPredictor::ensureMinFrameDurationIsKept(TimePoint expectedPresentTime,
-                                                  TimePoint lastConfirmedPresentTime) {
+Duration VSyncPredictor::ensureMinFrameDurationIsKept(TimePoint expectedPresentTime,
+                                                      TimePoint lastConfirmedPresentTime) {
+    ATRACE_CALL();
+
+    if (mNumVsyncsForFrame <= 1) {
+        return 0ns;
+    }
+
     const auto currentPeriod = mRateMap.find(idealPeriod())->second.slope;
     const auto threshold = currentPeriod / 2;
     const auto minFramePeriod = minFramePeriodLocked().ns();
@@ -442,17 +458,20 @@
     if (!mPastExpectedPresentTimes.empty()) {
         const auto phase = Duration(mPastExpectedPresentTimes.back() - expectedPresentTime);
         if (phase > 0ns) {
-            if (mLastVsyncSequence) {
-                mLastVsyncSequence->vsyncTime += phase.ns();
+            for (auto& timeline : mTimelines) {
+                timeline.shiftVsyncSequence(phase);
             }
             mPastExpectedPresentTimes.clear();
+            return phase;
         }
     }
+
+    return 0ns;
 }
 
 void VSyncPredictor::onFrameBegin(TimePoint expectedPresentTime,
                                   TimePoint lastConfirmedPresentTime) {
-    ATRACE_CALL();
+    ATRACE_NAME("VSyncPredictor::onFrameBegin");
     std::lock_guard lock(mMutex);
 
     if (!mDisplayModePtr->getVrrConfig()) return;
@@ -482,11 +501,14 @@
         }
     }
 
-    ensureMinFrameDurationIsKept(expectedPresentTime, lastConfirmedPresentTime);
+    const auto phase = ensureMinFrameDurationIsKept(expectedPresentTime, lastConfirmedPresentTime);
+    if (phase > 0ns) {
+        mMissedVsync = {expectedPresentTime, minFramePeriodLocked()};
+    }
 }
 
 void VSyncPredictor::onFrameMissed(TimePoint expectedPresentTime) {
-    ATRACE_CALL();
+    ATRACE_NAME("VSyncPredictor::onFrameMissed");
 
     std::lock_guard lock(mMutex);
     if (!mDisplayModePtr->getVrrConfig()) return;
@@ -496,14 +518,15 @@
     const auto lastConfirmedPresentTime =
             TimePoint::fromNs(expectedPresentTime.ns() + currentPeriod);
 
-    ensureMinFrameDurationIsKept(expectedPresentTime, lastConfirmedPresentTime);
-    mLastMissedVsync = expectedPresentTime;
+    const auto phase = ensureMinFrameDurationIsKept(expectedPresentTime, lastConfirmedPresentTime);
+    if (phase > 0ns) {
+        mMissedVsync = {expectedPresentTime, Duration::fromNs(0)};
+    }
 }
 
 VSyncPredictor::Model VSyncPredictor::getVSyncPredictionModel() const {
     std::lock_guard lock(mMutex);
-    const auto model = VSyncPredictor::getVSyncPredictionModelLocked();
-    return {model.slope, model.intercept};
+    return VSyncPredictor::getVSyncPredictionModelLocked();
 }
 
 VSyncPredictor::Model VSyncPredictor::getVSyncPredictionModelLocked() const {
@@ -524,6 +547,11 @@
         mTimestamps.clear();
         mLastTimestampIndex = 0;
     }
+
+    mTimelines.clear();
+    mLastCommittedVsync = TimePoint::fromNs(0);
+    mIdealPeriod = Period::fromNs(idealPeriod());
+    mTimelines.emplace_back(mLastCommittedVsync, mIdealPeriod, mRenderRateOpt);
 }
 
 bool VSyncPredictor::needsMoreSamples() const {
@@ -547,6 +575,171 @@
                       period / 1e6f, periodInterceptTuple.slope / 1e6f,
                       periodInterceptTuple.intercept);
     }
+    StringAppendF(&result, "\tmTimelines.size()=%zu\n", mTimelines.size());
+}
+
+void VSyncPredictor::purgeTimelines(android::TimePoint now) {
+    // With no vsync committed for several render frames there is no phase left to keep, so
+    // start over. Compare against |now| so this check and the expiry loop use one instant.
+    constexpr int kEnoughFramesToBreakPhase = 5;
+    if (mRenderRateOpt &&
+        mLastCommittedVsync.ns() + mRenderRateOpt->getPeriodNsecs() * kEnoughFramesToBreakPhase <
+                now.ns()) {
+        mTimelines.clear();
+        mLastCommittedVsync = TimePoint::fromNs(0);
+        mTimelines.emplace_back(mLastCommittedVsync, mIdealPeriod, mRenderRateOpt);
+        return;
+    }
+
+    // Pop leading timelines that were frozen and expired before |now|; the last is never frozen.
+    while (mTimelines.size() > 1) {
+        const auto validUntilOpt = mTimelines.front().validUntil();
+        if (!validUntilOpt || *validUntilOpt >= now) break;
+        mTimelines.pop_front();
+    }
+    LOG_ALWAYS_FATAL_IF(mTimelines.empty());
+    LOG_ALWAYS_FATAL_IF(mTimelines.back().validUntil().has_value());
+}
+
+auto VSyncPredictor::VsyncTimeline::makeVsyncSequence(TimePoint knownVsync)
+        -> std::optional<VsyncSequence> {
+    if (knownVsync.ns() == 0) return std::nullopt;
+    return std::make_optional<VsyncSequence>({knownVsync.ns(), 0});
+}
+
+VSyncPredictor::VsyncTimeline::VsyncTimeline(TimePoint knownVsync, Period idealPeriod,
+                                             std::optional<Fps> renderRateOpt)
+      : mIdealPeriod(idealPeriod),
+        mRenderRateOpt(renderRateOpt),
+        mLastVsyncSequence(makeVsyncSequence(knownVsync)) {}
+
+void VSyncPredictor::VsyncTimeline::freeze(TimePoint lastVsync) {
+    LOG_ALWAYS_FATAL_IF(mValidUntil.has_value());
+    ATRACE_FORMAT_INSTANT("renderRate %s valid for %.2f",
+                          mRenderRateOpt ? to_string(*mRenderRateOpt).c_str() : "NA",
+                          float(lastVsync.ns() - TimePoint::now().ns()) / 1e6f);
+    mValidUntil = lastVsync;
+}
+
+std::optional<TimePoint> VSyncPredictor::VsyncTimeline::nextAnticipatedVSyncTimeFrom(
+        Model model, std::optional<Period> minFramePeriodOpt, nsecs_t vsync,
+        MissedVsync missedVsync, std::optional<nsecs_t> lastVsyncOpt) {
+    ATRACE_FORMAT("renderRate %s", mRenderRateOpt ? to_string(*mRenderRateOpt).c_str() : "NA");
+    // Snap to this timeline's render-rate phase, apply fixups; nullopt if past mValidUntil.
+    nsecs_t vsyncTime = snapToVsyncAlignedWithRenderRate(model, vsync);
+    const auto threshold = model.slope / 2;
+    const auto lastFrameMissed =
+            lastVsyncOpt && std::abs(*lastVsyncOpt - missedVsync.vsync.ns()) < threshold;
+    nsecs_t vsyncFixupTime = 0;
+    if (FlagManager::getInstance().vrr_config() && lastFrameMissed) {
+        // The missed frame was the last vsync, so the timeline has already been shifted. The
+        // fixup differs depending on whether the frame was skipped (onFrameMissed) or not
+        // (onFrameBegin); there is no need to shift the vsync timeline again.
+        vsyncTime += missedVsync.fixup.ns();
+        ATRACE_FORMAT_INSTANT("lastFrameMissed");
+    } else if (minFramePeriodOpt) {
+        if (FlagManager::getInstance().vrr_config() && lastVsyncOpt) {
+            // lastVsyncOpt is based on the timeline before it was shifted, so correct it before
+            // comparing it against vsyncTime.
+            if (mLastVsyncSequence->seq > 0) {
+                lastVsyncOpt = snapToVsyncAlignedWithRenderRate(model, *lastVsyncOpt);
+            }
+            const auto vsyncDiff = vsyncTime - *lastVsyncOpt;
+            if (vsyncDiff <= minFramePeriodOpt->ns() - threshold) {
+                vsyncFixupTime = *lastVsyncOpt + minFramePeriodOpt->ns() - vsyncTime;
+                ATRACE_FORMAT_INSTANT("minFramePeriod violation. next in %.2f which is %.2f "
+                                      "from "
+                                      "prev. "
+                                      "adjust by %.2f",
+                                      static_cast<float>(vsyncTime - TimePoint::now().ns()) / 1e6f,
+                                      static_cast<float>(vsyncTime - *lastVsyncOpt) / 1e6f,
+                                      static_cast<float>(vsyncFixupTime) / 1e6f);
+            }
+        }
+        vsyncTime += vsyncFixupTime;
+    }
+
+    ATRACE_FORMAT_INSTANT("vsync in %.2fms", float(vsyncTime - TimePoint::now().ns()) / 1e6f);
+    if (mValidUntil && vsyncTime > mValidUntil->ns()) {
+        ATRACE_FORMAT_INSTANT("no longer valid for vsync in %.2f",
+                              static_cast<float>(vsyncTime - TimePoint::now().ns()) / 1e6f);
+        return std::nullopt;
+    }
+
+    // A positive fixup means vsyncTime was too close to lastVsyncOpt to honor minFramePeriod;
+    // shift the entire vsync timeline by the same amount.
+    if (vsyncFixupTime > 0) {
+        shiftVsyncSequence(Duration::fromNs(vsyncFixupTime));
+    }
+
+    return TimePoint::fromNs(vsyncTime);
+}
+
+auto VSyncPredictor::VsyncTimeline::getVsyncSequenceLocked(Model model, nsecs_t vsync)
+        -> VsyncSequence {
+    if (!mLastVsyncSequence) return {vsync, 0};
+
+    const auto [lastVsyncTime, lastVsyncSequence] = *mLastVsyncSequence;
+    const auto vsyncSequence = lastVsyncSequence +
+            static_cast<int64_t>(std::round((vsync - lastVsyncTime) /
+                                            static_cast<float>(model.slope)));
+    return {vsync, vsyncSequence};
+}
+
+nsecs_t VSyncPredictor::VsyncTimeline::snapToVsyncAlignedWithRenderRate(Model model,
+                                                                        nsecs_t vsync) {
+    // update the mLastVsyncSequence for reference point
+    mLastVsyncSequence = getVsyncSequenceLocked(model, vsync);
+
+    const auto renderRatePhase = [&]() -> int {
+        if (!mRenderRateOpt) return 0;
+        const auto divisor =
+                RefreshRateSelector::getFrameRateDivisor(Fps::fromPeriodNsecs(mIdealPeriod.ns()),
+                                                         *mRenderRateOpt);
+        if (divisor <= 1) return 0;
+
+        int mod = mLastVsyncSequence->seq % divisor;
+        if (mod == 0) return 0;
+
+        // This is actually a bug fix, but guarded with vrr_config since we found it with this
+        // config
+        if (FlagManager::getInstance().vrr_config()) {
+            if (mod < 0) mod += divisor;
+        }
+
+        return divisor - mod;
+    }();
+
+    if (renderRatePhase == 0) {
+        return mLastVsyncSequence->vsyncTime;
+    }
+
+    return mLastVsyncSequence->vsyncTime + model.slope * renderRatePhase;
+}
+
+// Returns true when |vsync| is in phase with |frameRate| on this timeline.
+bool VSyncPredictor::VsyncTimeline::isVSyncInPhase(Model model, nsecs_t vsync, Fps frameRate) {
+    // Sequence numbers advance once per display vsync (model.slope), so the divisor must come
+    // from the display rate; using the render rate would report too many in-phase vsyncs.
+    const Fps displayFps = Fps::fromPeriodNsecs(mIdealPeriod.ns());
+    const auto divisor = RefreshRateSelector::getFrameRateDivisor(displayFps, frameRate);
+    if (divisor <= 1) {
+        return true;
+    }
+
+    const auto vsyncSequence = getVsyncSequenceLocked(model, vsync);
+    // divisor is cast to int to match %d; the previous %zu did not match its vararg type.
+    ATRACE_FORMAT_INSTANT("vsync in: %.2f sequence: %" PRId64 " divisor: %d",
+                          float(vsyncSequence.vsyncTime - TimePoint::now().ns()) / 1e6f,
+                          vsyncSequence.seq, static_cast<int>(divisor));
+    return vsyncSequence.seq % divisor == 0;
+}
+
+void VSyncPredictor::VsyncTimeline::shiftVsyncSequence(Duration phase) {
+    if (mLastVsyncSequence) {
+        ATRACE_FORMAT_INSTANT("adjusting vsync by %.2f", static_cast<float>(phase.ns()) / 1e6f);
+        mLastVsyncSequence->vsyncTime += phase.ns();
+    }
 }
 
 } // namespace android::scheduler
diff --git a/services/surfaceflinger/Scheduler/VSyncPredictor.h b/services/surfaceflinger/Scheduler/VSyncPredictor.h
index 8fd7e60..3ed1d41 100644
--- a/services/surfaceflinger/Scheduler/VSyncPredictor.h
+++ b/services/surfaceflinger/Scheduler/VSyncPredictor.h
@@ -22,6 +22,7 @@
 #include <vector>
 
 #include <android-base/thread_annotations.h>
+#include <scheduler/TimeKeeper.h>
 #include <ui/DisplayId.h>
 
 #include "VSyncTracker.h"
@@ -31,6 +32,7 @@
 class VSyncPredictor : public VSyncTracker {
 public:
     /*
+     * \param [in] Clock The clock abstraction. Useful for unit tests.
      * \param [in] PhysicalDisplayid The display this corresponds to.
      * \param [in] modePtr  The initial display mode
      * \param [in] historySize  The internal amount of entries to store in the model.
@@ -38,13 +40,13 @@
      * predicting. \param [in] outlierTolerancePercent a number 0 to 100 that will be used to filter
      * samples that fall outlierTolerancePercent from an anticipated vsync event.
      */
-    VSyncPredictor(ftl::NonNull<DisplayModePtr> modePtr, size_t historySize,
+    VSyncPredictor(std::unique_ptr<Clock>, ftl::NonNull<DisplayModePtr> modePtr, size_t historySize,
                    size_t minimumSamplesForPrediction, uint32_t outlierTolerancePercent);
     ~VSyncPredictor();
 
     bool addVsyncTimestamp(nsecs_t timestamp) final EXCLUDES(mMutex);
     nsecs_t nextAnticipatedVSyncTimeFrom(nsecs_t timePoint,
-                                         std::optional<nsecs_t> lastVsyncOpt = {}) const final
+                                         std::optional<nsecs_t> lastVsyncOpt = {}) final
             EXCLUDES(mMutex);
     nsecs_t currentPeriod() const final EXCLUDES(mMutex);
     Period minFramePeriod() const final EXCLUDES(mMutex);
@@ -62,11 +64,11 @@
 
     VSyncPredictor::Model getVSyncPredictionModel() const EXCLUDES(mMutex);
 
-    bool isVSyncInPhase(nsecs_t timePoint, Fps frameRate) const final EXCLUDES(mMutex);
+    bool isVSyncInPhase(nsecs_t timePoint, Fps frameRate) final EXCLUDES(mMutex);
 
     void setDisplayModePtr(ftl::NonNull<DisplayModePtr>) final EXCLUDES(mMutex);
 
-    void setRenderRate(Fps) final EXCLUDES(mMutex);
+    void setRenderRate(Fps, bool applyImmediately) final EXCLUDES(mMutex);
 
     void onFrameBegin(TimePoint expectedPresentTime, TimePoint lastConfirmedPresentTime) final
             EXCLUDES(mMutex);
@@ -75,10 +77,44 @@
     void dump(std::string& result) const final EXCLUDES(mMutex);
 
 private:
+    struct VsyncSequence {
+        nsecs_t vsyncTime;
+        int64_t seq;
+    };
+
+    struct MissedVsync {
+        TimePoint vsync;
+        Duration fixup = Duration::fromNs(0);
+    };
+
+    class VsyncTimeline {
+    public:
+        VsyncTimeline(TimePoint knownVsync, Period idealPeriod, std::optional<Fps> renderRateOpt);
+        std::optional<TimePoint> nextAnticipatedVSyncTimeFrom(
+                Model model, std::optional<Period> minFramePeriodOpt, nsecs_t vsyncTime,
+                MissedVsync lastMissedVsync, std::optional<nsecs_t> lastVsyncOpt = {});
+        void freeze(TimePoint lastVsync);
+        std::optional<TimePoint> validUntil() const { return mValidUntil; }
+        bool isVSyncInPhase(Model, nsecs_t vsync, Fps frameRate);
+        void shiftVsyncSequence(Duration phase);
+        void setRenderRate(Fps renderRate) { mRenderRateOpt = renderRate; }
+
+    private:
+        nsecs_t snapToVsyncAlignedWithRenderRate(Model model, nsecs_t vsync);
+        VsyncSequence getVsyncSequenceLocked(Model, nsecs_t vsync);
+        std::optional<VsyncSequence> makeVsyncSequence(TimePoint knownVsync);
+
+        const Period mIdealPeriod = Duration::fromNs(0);
+        std::optional<Fps> mRenderRateOpt;
+        std::optional<TimePoint> mValidUntil;
+        std::optional<VsyncSequence> mLastVsyncSequence;
+    };
+
     VSyncPredictor(VSyncPredictor const&) = delete;
     VSyncPredictor& operator=(VSyncPredictor const&) = delete;
     void clearTimestamps() REQUIRES(mMutex);
 
+    const std::unique_ptr<Clock> mClock;
     const PhysicalDisplayId mId;
 
     inline void traceInt64If(const char* name, int64_t value) const;
@@ -88,16 +124,10 @@
     bool validate(nsecs_t timestamp) const REQUIRES(mMutex);
     Model getVSyncPredictionModelLocked() const REQUIRES(mMutex);
     nsecs_t snapToVsync(nsecs_t timePoint) const REQUIRES(mMutex);
-    nsecs_t snapToVsyncAlignedWithRenderRate(nsecs_t timePoint) const REQUIRES(mMutex);
-    bool isVSyncInPhaseLocked(nsecs_t timePoint, unsigned divisor) const REQUIRES(mMutex);
     Period minFramePeriodLocked() const REQUIRES(mMutex);
-    void ensureMinFrameDurationIsKept(TimePoint, TimePoint) REQUIRES(mMutex);
+    Duration ensureMinFrameDurationIsKept(TimePoint, TimePoint) REQUIRES(mMutex);
+    void purgeTimelines(android::TimePoint now) REQUIRES(mMutex);
 
-    struct VsyncSequence {
-        nsecs_t vsyncTime;
-        int64_t seq;
-    };
-    VsyncSequence getVsyncSequenceLocked(nsecs_t timestamp) const REQUIRES(mMutex);
     nsecs_t idealPeriod() const REQUIRES(mMutex);
 
     bool const mTraceOn;
@@ -115,13 +145,16 @@
     std::vector<nsecs_t> mTimestamps GUARDED_BY(mMutex);
 
     ftl::NonNull<DisplayModePtr> mDisplayModePtr GUARDED_BY(mMutex);
-    std::optional<Fps> mRenderRateOpt GUARDED_BY(mMutex);
-
-    mutable std::optional<VsyncSequence> mLastVsyncSequence GUARDED_BY(mMutex);
+    int mNumVsyncsForFrame GUARDED_BY(mMutex);
 
     std::deque<TimePoint> mPastExpectedPresentTimes GUARDED_BY(mMutex);
 
-    TimePoint mLastMissedVsync GUARDED_BY(mMutex);
+    MissedVsync mMissedVsync GUARDED_BY(mMutex);
+
+    std::deque<VsyncTimeline> mTimelines GUARDED_BY(mMutex);
+    TimePoint mLastCommittedVsync GUARDED_BY(mMutex) = TimePoint::fromNs(0);
+    Period mIdealPeriod GUARDED_BY(mMutex) = Duration::fromNs(0);
+    std::optional<Fps> mRenderRateOpt GUARDED_BY(mMutex);
 };
 
 } // namespace android::scheduler
diff --git a/services/surfaceflinger/Scheduler/VSyncTracker.h b/services/surfaceflinger/Scheduler/VSyncTracker.h
index 37bd4b4..8787cdb 100644
--- a/services/surfaceflinger/Scheduler/VSyncTracker.h
+++ b/services/surfaceflinger/Scheduler/VSyncTracker.h
@@ -56,8 +56,8 @@
      *                          and avoid crossing the minimal frame period of a VRR display.
      * \return                  A prediction of the timestamp of a vsync event.
      */
-    virtual nsecs_t nextAnticipatedVSyncTimeFrom(
-            nsecs_t timePoint, std::optional<nsecs_t> lastVsyncOpt = {}) const = 0;
+    virtual nsecs_t nextAnticipatedVSyncTimeFrom(nsecs_t timePoint,
+                                                 std::optional<nsecs_t> lastVsyncOpt = {}) = 0;
 
     /*
      * The current period of the vsync signal.
@@ -82,7 +82,7 @@
      * \param [in] timePoint  A vsync timestamp
      * \param [in] frameRate  The frame rate to check for
      */
-    virtual bool isVSyncInPhase(nsecs_t timePoint, Fps frameRate) const = 0;
+    virtual bool isVSyncInPhase(nsecs_t timePoint, Fps frameRate) = 0;
 
     /*
      * Sets the active mode of the display which includes the vsync period and other VRR attributes.
@@ -102,8 +102,10 @@
      * when a display is running at 120Hz but the render frame rate is 60Hz.
      *
      * \param [in] Fps   The render rate the tracker should operate at.
+     * \param [in] applyImmediately Whether to apply the new render rate immediately regardless of
+     *                              already committed vsyncs.
      */
-    virtual void setRenderRate(Fps) = 0;
+    virtual void setRenderRate(Fps, bool applyImmediately) = 0;
 
     virtual void onFrameBegin(TimePoint expectedPresentTime,
                               TimePoint lastConfirmedPresentTime) = 0;
diff --git a/services/surfaceflinger/Scheduler/VsyncSchedule.cpp b/services/surfaceflinger/Scheduler/VsyncSchedule.cpp
index 001938c..2fa3318 100644
--- a/services/surfaceflinger/Scheduler/VsyncSchedule.cpp
+++ b/services/surfaceflinger/Scheduler/VsyncSchedule.cpp
@@ -120,8 +120,8 @@
     constexpr size_t kMinSamplesForPrediction = 6;
     constexpr uint32_t kDiscardOutlierPercent = 20;
 
-    return std::make_unique<VSyncPredictor>(modePtr, kHistorySize, kMinSamplesForPrediction,
-                                            kDiscardOutlierPercent);
+    return std::make_unique<VSyncPredictor>(std::make_unique<SystemClock>(), modePtr, kHistorySize,
+                                            kMinSamplesForPrediction, kDiscardOutlierPercent);
 }
 
 VsyncSchedule::DispatchPtr VsyncSchedule::createDispatch(TrackerPtr tracker) {
diff --git a/services/surfaceflinger/Scheduler/VsyncSchedule.h b/services/surfaceflinger/Scheduler/VsyncSchedule.h
index 85cd3e7..881d678 100644
--- a/services/surfaceflinger/Scheduler/VsyncSchedule.h
+++ b/services/surfaceflinger/Scheduler/VsyncSchedule.h
@@ -81,7 +81,7 @@
     bool addResyncSample(TimePoint timestamp, ftl::Optional<Period> hwcVsyncPeriod);
 
     // TODO(b/185535769): Hide behind API.
-    const VsyncTracker& getTracker() const { return *mTracker; }
+    VsyncTracker& getTracker() const { return *mTracker; }
     VsyncTracker& getTracker() { return *mTracker; }
     VsyncController& getController() { return *mController; }
 
diff --git a/services/surfaceflinger/SurfaceFlinger.cpp b/services/surfaceflinger/SurfaceFlinger.cpp
index be3c1d8..30b8953 100644
--- a/services/surfaceflinger/SurfaceFlinger.cpp
+++ b/services/surfaceflinger/SurfaceFlinger.cpp
@@ -1245,8 +1245,8 @@
     switch (display->setDesiredMode(std::move(desiredMode))) {
         case DisplayDevice::DesiredModeAction::InitiateDisplayModeSwitch:
             // DisplayDevice::setDesiredMode updated the render rate, so inform Scheduler.
-            mScheduler->setRenderRate(displayId,
-                                      display->refreshRateSelector().getActiveMode().fps);
+            mScheduler->setRenderRate(displayId, display->refreshRateSelector().getActiveMode().fps,
+                                      /*applyImmediately*/ true);
 
             // Schedule a new frame to initiate the display mode switch.
             scheduleComposite(FrameHint::kNone);
@@ -1267,7 +1267,7 @@
             mScheduler->setModeChangePending(true);
             break;
         case DisplayDevice::DesiredModeAction::InitiateRenderRateSwitch:
-            mScheduler->setRenderRate(displayId, mode.fps);
+            mScheduler->setRenderRate(displayId, mode.fps, /*applyImmediately*/ false);
 
             if (displayId == mActiveDisplayId) {
                 mScheduler->updatePhaseConfiguration(mode.fps);
@@ -1388,7 +1388,7 @@
 
     constexpr bool kAllowToEnable = true;
     mScheduler->resyncToHardwareVsync(displayId, kAllowToEnable, std::move(activeModePtr).take());
-    mScheduler->setRenderRate(displayId, renderFps);
+    mScheduler->setRenderRate(displayId, renderFps, /*applyImmediately*/ true);
 
     if (displayId == mActiveDisplayId) {
         mScheduler->updatePhaseConfiguration(renderFps);
@@ -2736,7 +2736,8 @@
     refreshArgs.forceOutputColorMode = mForceColorMode;
 
     refreshArgs.updatingOutputGeometryThisFrame = mVisibleRegionsDirty;
-    refreshArgs.updatingGeometryThisFrame = mGeometryDirty.exchange(false) || mVisibleRegionsDirty;
+    refreshArgs.updatingGeometryThisFrame = mGeometryDirty.exchange(false) ||
+            mVisibleRegionsDirty || mDrawingState.colorMatrixChanged;
     refreshArgs.internalDisplayRotationFlags = getActiveDisplayRotationFlags();
 
     if (CC_UNLIKELY(mDrawingState.colorMatrixChanged)) {
@@ -4382,7 +4383,8 @@
     // The pacesetter must be registered before EventThread creation below.
     mScheduler->registerDisplay(display->getPhysicalId(), display->holdRefreshRateSelector());
     if (FlagManager::getInstance().vrr_config()) {
-        mScheduler->setRenderRate(display->getPhysicalId(), activeMode.fps);
+        mScheduler->setRenderRate(display->getPhysicalId(), activeMode.fps,
+                                  /*applyImmediately*/ true);
     }
 
     const auto configs = mScheduler->getVsyncConfiguration().getCurrentConfigs();
@@ -4706,7 +4708,14 @@
         return TransactionReadiness::NotReady;
     }
 
-    if (!mScheduler->isVsyncValid(expectedPresentTime, transaction.originUid)) {
+    const auto vsyncId = VsyncId{transaction.frameTimelineInfo.vsyncId};
+
+    // Transactions with VsyncId are already throttled by the vsyncId (i.e. Choreographer issued
+    // the vsyncId according to the frame rate override cadence) so we shouldn't throttle again
+    // when applying the transaction. Otherwise we might throttle older transactions
+    // incorrectly as the frame rate of SF changed before it drained the older transactions.
+    if (ftl::to_underlying(vsyncId) == FrameTimelineInfo::INVALID_VSYNC_ID &&
+        !mScheduler->isVsyncValid(expectedPresentTime, transaction.originUid)) {
         ATRACE_FORMAT("!isVsyncValid expectedPresentTime: %" PRId64 " uid: %d", expectedPresentTime,
                       transaction.originUid);
         return TransactionReadiness::NotReady;
@@ -4714,8 +4723,7 @@
 
     // If the client didn't specify desiredPresentTime, use the vsyncId to determine the
     // expected present time of this transaction.
-    if (transaction.isAutoTimestamp &&
-        frameIsEarly(expectedPresentTime, VsyncId{transaction.frameTimelineInfo.vsyncId})) {
+    if (transaction.isAutoTimestamp && frameIsEarly(expectedPresentTime, vsyncId)) {
         ATRACE_FORMAT("frameIsEarly vsyncId: %" PRId64 " expectedPresentTime: %" PRId64,
                       transaction.frameTimelineInfo.vsyncId, expectedPresentTime);
         return TransactionReadiness::NotReady;
diff --git a/services/surfaceflinger/tests/unittests/RefreshRateSelectorTest.cpp b/services/surfaceflinger/tests/unittests/RefreshRateSelectorTest.cpp
index 0a6e305..fe0e3d1 100644
--- a/services/surfaceflinger/tests/unittests/RefreshRateSelectorTest.cpp
+++ b/services/surfaceflinger/tests/unittests/RefreshRateSelectorTest.cpp
@@ -259,6 +259,44 @@
         config.enableFrameRateOverride = GetParam();
         return TestableRefreshRateSelector(modes, activeModeId, config);
     }
+
+    template <class T>
+    void testFrameRateCategoryWithMultipleLayers(const std::initializer_list<T>& testCases,
+                                                 const TestableRefreshRateSelector& selector) {
+        std::vector<LayerRequirement> layers;
+        for (auto testCase : testCases) {
+            ALOGI("**** %s: Testing desiredFrameRate=%s, frameRateCategory=%s", __func__,
+                  to_string(testCase.desiredFrameRate).c_str(),
+                  ftl::enum_string(testCase.frameRateCategory).c_str());
+
+            if (testCase.desiredFrameRate.isValid()) {
+                std::stringstream ss;
+                ss << to_string(testCase.desiredFrameRate)
+                   << ftl::enum_string(testCase.frameRateCategory) << "ExplicitDefault";
+                LayerRequirement layer = {.name = ss.str(),
+                                          .vote = LayerVoteType::ExplicitDefault,
+                                          .desiredRefreshRate = testCase.desiredFrameRate,
+                                          .weight = 1.f};
+                layers.push_back(layer);
+            }
+
+            if (testCase.frameRateCategory != FrameRateCategory::Default) {
+                std::stringstream ss;
+                ss << "ExplicitCategory (" << ftl::enum_string(testCase.frameRateCategory) << ")";
+                LayerRequirement layer = {.name = ss.str(),
+                                          .vote = LayerVoteType::ExplicitCategory,
+                                          .frameRateCategory = testCase.frameRateCategory,
+                                          .weight = 1.f};
+                layers.push_back(layer);
+            }
+
+            EXPECT_EQ(testCase.expectedFrameRate,
+                      selector.getBestFrameRateMode(layers).modePtr->getPeakFps())
+                    << "Did not get expected frame rate for frameRate="
+                    << to_string(testCase.desiredFrameRate)
+                    << " category=" << ftl::enum_string(testCase.frameRateCategory);
+        }
+    }
 };
 
 RefreshRateSelectorTest::RefreshRateSelectorTest() {
@@ -1542,6 +1580,96 @@
     }
 }
 
+TEST_P(RefreshRateSelectorTest,
+       getBestFrameRateMode_withFrameRateCategoryMultiLayers_30_60_90_120) {
+    auto selector = createSelector(makeModes(kMode30, kMode60, kMode90, kMode120), kModeId60);
+
+    struct Case {
+        // Params
+        Fps desiredFrameRate = 0_Hz;
+        FrameRateCategory frameRateCategory = FrameRateCategory::Default;
+
+        // Expected result
+        Fps expectedFrameRate = 0_Hz;
+    };
+
+    testFrameRateCategoryWithMultipleLayers(
+            std::initializer_list<Case>{
+                    {0_Hz, FrameRateCategory::High, 90_Hz},
+                    {0_Hz, FrameRateCategory::NoPreference, 90_Hz},
+                    {0_Hz, FrameRateCategory::Normal, 90_Hz},
+                    {0_Hz, FrameRateCategory::Normal, 90_Hz},
+                    {0_Hz, FrameRateCategory::NoPreference, 90_Hz},
+            },
+            selector);
+
+    testFrameRateCategoryWithMultipleLayers(
+            std::initializer_list<Case>{
+                    {0_Hz, FrameRateCategory::Normal, 60_Hz},
+                    {0_Hz, FrameRateCategory::High, 90_Hz},
+                    {0_Hz, FrameRateCategory::NoPreference, 90_Hz},
+            },
+            selector);
+
+    testFrameRateCategoryWithMultipleLayers(
+            std::initializer_list<Case>{
+                    {30_Hz, FrameRateCategory::High, 90_Hz},
+                    {24_Hz, FrameRateCategory::High, 120_Hz},
+                    {12_Hz, FrameRateCategory::Normal, 120_Hz},
+                    {30_Hz, FrameRateCategory::NoPreference, 120_Hz},
+
+            },
+            selector);
+
+    testFrameRateCategoryWithMultipleLayers(
+            std::initializer_list<Case>{
+                    {24_Hz, FrameRateCategory::Default, 120_Hz},
+                    {30_Hz, FrameRateCategory::Default, 120_Hz},
+                    {120_Hz, FrameRateCategory::Default, 120_Hz},
+            },
+            selector);
+}
+
+TEST_P(RefreshRateSelectorTest, getBestFrameRateMode_withFrameRateCategoryMultiLayers_60_120) {
+    auto selector = createSelector(makeModes(kMode60, kMode120), kModeId60);
+
+    struct Case {
+        // Params
+        Fps desiredFrameRate = 0_Hz;
+        FrameRateCategory frameRateCategory = FrameRateCategory::Default;
+
+        // Expected result
+        Fps expectedFrameRate = 0_Hz;
+    };
+
+    testFrameRateCategoryWithMultipleLayers(std::initializer_list<
+                                                    Case>{{0_Hz, FrameRateCategory::High, 120_Hz},
+                                                          {0_Hz, FrameRateCategory::NoPreference,
+                                                           120_Hz},
+                                                          {0_Hz, FrameRateCategory::Normal, 120_Hz},
+                                                          {0_Hz, FrameRateCategory::Normal, 120_Hz},
+                                                          {0_Hz, FrameRateCategory::NoPreference,
+                                                           120_Hz}},
+                                            selector);
+
+    testFrameRateCategoryWithMultipleLayers(std::initializer_list<
+                                                    Case>{{24_Hz, FrameRateCategory::High, 120_Hz},
+                                                          {30_Hz, FrameRateCategory::High, 120_Hz},
+                                                          {12_Hz, FrameRateCategory::Normal,
+                                                           120_Hz},
+                                                          {30_Hz, FrameRateCategory::NoPreference,
+                                                           120_Hz}},
+                                            selector);
+
+    testFrameRateCategoryWithMultipleLayers(
+            std::initializer_list<Case>{
+                    {24_Hz, FrameRateCategory::Default, 120_Hz},
+                    {30_Hz, FrameRateCategory::Default, 120_Hz},
+                    {120_Hz, FrameRateCategory::Default, 120_Hz},
+            },
+            selector);
+}
+
 TEST_P(RefreshRateSelectorTest, getBestFrameRateMode_withFrameRateCategory_60_120) {
     auto selector = createSelector(makeModes(kMode60, kMode120), kModeId60);
 
@@ -1665,6 +1793,7 @@
     lr1.frameRateCategory = FrameRateCategory::HighHint;
     lr1.name = "ExplicitCategory HighHint";
     lr2.vote = LayerVoteType::ExplicitExactOrMultiple;
+    lr2.frameRateCategory = FrameRateCategory::Default;
     lr2.desiredRefreshRate = 30_Hz;
     lr2.name = "30Hz ExplicitExactOrMultiple";
     actualRankedFrameRates = selector.getRankedFrameRates(layers);
@@ -1691,6 +1820,153 @@
         EXPECT_EQ(kModeId30, actualRankedFrameRates.ranking.front().frameRateMode.modePtr->getId());
         EXPECT_FALSE(actualRankedFrameRates.consideredSignals.touch);
     }
+
+    lr1.vote = LayerVoteType::ExplicitCategory;
+    lr1.frameRateCategory = FrameRateCategory::HighHint;
+    lr1.name = "ExplicitCategory HighHint";
+    lr2.vote = LayerVoteType::Heuristic;
+    lr2.desiredRefreshRate = 30_Hz;
+    lr2.name = "30Hz Heuristic";
+    actualRankedFrameRates = selector.getRankedFrameRates(layers);
+    // Gets touch boost
+    EXPECT_EQ(120_Hz, actualRankedFrameRates.ranking.front().frameRateMode.fps);
+    EXPECT_EQ(kModeId120, actualRankedFrameRates.ranking.front().frameRateMode.modePtr->getId());
+    EXPECT_TRUE(actualRankedFrameRates.consideredSignals.touch);
+
+    lr1.vote = LayerVoteType::ExplicitCategory;
+    lr1.frameRateCategory = FrameRateCategory::HighHint;
+    lr1.name = "ExplicitCategory HighHint";
+    lr2.vote = LayerVoteType::Min;
+    lr2.name = "Min";
+    actualRankedFrameRates = selector.getRankedFrameRates(layers);
+    // Gets touch boost
+    EXPECT_EQ(120_Hz, actualRankedFrameRates.ranking.front().frameRateMode.fps);
+    EXPECT_EQ(kModeId120, actualRankedFrameRates.ranking.front().frameRateMode.modePtr->getId());
+    EXPECT_TRUE(actualRankedFrameRates.consideredSignals.touch);
+
+    lr1.vote = LayerVoteType::ExplicitCategory;
+    lr1.frameRateCategory = FrameRateCategory::HighHint;
+    lr1.name = "ExplicitCategory HighHint";
+    lr2.vote = LayerVoteType::Max;
+    lr2.name = "Max";
+    actualRankedFrameRates = selector.getRankedFrameRates(layers);
+    // Gets touch boost
+    EXPECT_EQ(120_Hz, actualRankedFrameRates.ranking.front().frameRateMode.fps);
+    EXPECT_EQ(kModeId120, actualRankedFrameRates.ranking.front().frameRateMode.modePtr->getId());
+    EXPECT_FALSE(actualRankedFrameRates.consideredSignals.touch);
+}
+
+TEST_P(RefreshRateSelectorTest, getBestFrameRateMode_withFrameRateCategory_TouchBoost) {
+    auto selector = createSelector(makeModes(kMode24, kMode30, kMode60, kMode120), kModeId60);
+
+    std::vector<LayerRequirement> layers = {{.weight = 1.f}, {.weight = 1.f}};
+    auto& lr1 = layers[0];
+    auto& lr2 = layers[1];
+
+    lr1.vote = LayerVoteType::ExplicitCategory;
+    lr1.frameRateCategory = FrameRateCategory::Normal;
+    lr1.name = "ExplicitCategory Normal";
+    lr2.vote = LayerVoteType::NoVote;
+    lr2.name = "NoVote";
+    auto actualRankedFrameRates = selector.getRankedFrameRates(layers, {.touch = true});
+    EXPECT_FRAME_RATE_MODE(kMode60, 60_Hz, actualRankedFrameRates.ranking.front().frameRateMode);
+    EXPECT_FALSE(actualRankedFrameRates.consideredSignals.touch);
+
+    // No touch boost, for example a game that uses setFrameRate(30, default compatibility).
+    lr1.vote = LayerVoteType::ExplicitCategory;
+    lr1.frameRateCategory = FrameRateCategory::Normal;
+    lr1.name = "ExplicitCategory Normal";
+    lr2.vote = LayerVoteType::ExplicitDefault;
+    lr2.desiredRefreshRate = 30_Hz;
+    lr2.name = "30Hz ExplicitDefault";
+    actualRankedFrameRates = selector.getRankedFrameRates(layers, {.touch = true});
+    EXPECT_FRAME_RATE_MODE(kMode60, 60_Hz, actualRankedFrameRates.ranking.front().frameRateMode);
+    EXPECT_FALSE(actualRankedFrameRates.consideredSignals.touch);
+
+    lr1.vote = LayerVoteType::ExplicitCategory;
+    lr1.frameRateCategory = FrameRateCategory::Normal;
+    lr1.name = "ExplicitCategory Normal";
+    lr2.vote = LayerVoteType::ExplicitCategory;
+    lr2.frameRateCategory = FrameRateCategory::HighHint;
+    lr2.name = "ExplicitCategory HighHint";
+    actualRankedFrameRates = selector.getRankedFrameRates(layers, {.touch = true});
+    EXPECT_FRAME_RATE_MODE(kMode120, 120_Hz, actualRankedFrameRates.ranking.front().frameRateMode);
+    EXPECT_TRUE(actualRankedFrameRates.consideredSignals.touch);
+
+    lr1.vote = LayerVoteType::ExplicitCategory;
+    lr1.frameRateCategory = FrameRateCategory::Normal;
+    lr1.name = "ExplicitCategory Normal";
+    lr2.vote = LayerVoteType::ExplicitCategory;
+    lr2.frameRateCategory = FrameRateCategory::Low;
+    lr2.name = "ExplicitCategory Low";
+    actualRankedFrameRates = selector.getRankedFrameRates(layers, {.touch = true});
+    EXPECT_FRAME_RATE_MODE(kMode60, 60_Hz, actualRankedFrameRates.ranking.front().frameRateMode);
+    EXPECT_FALSE(actualRankedFrameRates.consideredSignals.touch);
+
+    lr1.vote = LayerVoteType::ExplicitCategory;
+    lr1.frameRateCategory = FrameRateCategory::Normal;
+    lr1.name = "ExplicitCategory Normal";
+    lr2.vote = LayerVoteType::ExplicitExactOrMultiple;
+    lr2.frameRateCategory = FrameRateCategory::Default;
+    lr2.desiredRefreshRate = 30_Hz;
+    lr2.name = "30Hz ExplicitExactOrMultiple";
+    actualRankedFrameRates = selector.getRankedFrameRates(layers, {.touch = true});
+    EXPECT_FRAME_RATE_MODE(kMode120, 120_Hz, actualRankedFrameRates.ranking.front().frameRateMode);
+    EXPECT_TRUE(actualRankedFrameRates.consideredSignals.touch);
+
+    lr1.vote = LayerVoteType::ExplicitCategory;
+    lr1.frameRateCategory = FrameRateCategory::Normal;
+    lr1.name = "ExplicitCategory Normal";
+    lr2.vote = LayerVoteType::ExplicitExact;
+    lr2.desiredRefreshRate = 30_Hz;
+    lr2.name = "30Hz ExplicitExact";
+    actualRankedFrameRates = selector.getRankedFrameRates(layers, {.touch = true});
+    if (selector.supportsAppFrameRateOverrideByContent()) {
+        EXPECT_FRAME_RATE_MODE(kMode120, 120_Hz,
+                               actualRankedFrameRates.ranking.front().frameRateMode);
+        EXPECT_TRUE(actualRankedFrameRates.consideredSignals.touch);
+    } else {
+        EXPECT_FRAME_RATE_MODE(kMode30, 30_Hz,
+                               actualRankedFrameRates.ranking.front().frameRateMode);
+        EXPECT_FALSE(actualRankedFrameRates.consideredSignals.touch);
+    }
+
+    lr1.vote = LayerVoteType::ExplicitCategory;
+    lr1.frameRateCategory = FrameRateCategory::Normal;
+    lr1.name = "ExplicitCategory Normal";
+    lr2.vote = LayerVoteType::Min;
+    lr2.name = "Min";
+    actualRankedFrameRates = selector.getRankedFrameRates(layers, {.touch = true});
+    EXPECT_FRAME_RATE_MODE(kMode120, 120_Hz, actualRankedFrameRates.ranking.front().frameRateMode);
+    EXPECT_TRUE(actualRankedFrameRates.consideredSignals.touch);
+
+    lr1.vote = LayerVoteType::ExplicitCategory;
+    lr1.frameRateCategory = FrameRateCategory::Normal;
+    lr1.name = "ExplicitCategory Normal";
+    lr2.vote = LayerVoteType::Max;
+    lr2.name = "Max";
+    actualRankedFrameRates = selector.getRankedFrameRates(layers, {.touch = true});
+    EXPECT_FRAME_RATE_MODE(kMode120, 120_Hz, actualRankedFrameRates.ranking.front().frameRateMode);
+    EXPECT_FALSE(actualRankedFrameRates.consideredSignals.touch);
+
+    lr1.vote = LayerVoteType::ExplicitCategory;
+    lr1.frameRateCategory = FrameRateCategory::Normal;
+    lr1.name = "ExplicitCategory Normal";
+    lr2.vote = LayerVoteType::Heuristic;
+    lr2.name = "30Hz Heuristic";
+    actualRankedFrameRates = selector.getRankedFrameRates(layers, {.touch = true});
+    EXPECT_FRAME_RATE_MODE(kMode120, 120_Hz, actualRankedFrameRates.ranking.front().frameRateMode);
+    EXPECT_TRUE(actualRankedFrameRates.consideredSignals.touch);
+
+    lr1.vote = LayerVoteType::ExplicitCategory;
+    lr1.frameRateCategory = FrameRateCategory::Normal;
+    lr1.name = "ExplicitCategory Normal";
+    lr2.vote = LayerVoteType::ExplicitGte;
+    lr2.desiredRefreshRate = 30_Hz;
+    lr2.name = "30Hz ExplicitGte";
+    actualRankedFrameRates = selector.getRankedFrameRates(layers, {.touch = true});
+    EXPECT_FRAME_RATE_MODE(kMode60, 60_Hz, actualRankedFrameRates.ranking.front().frameRateMode);
+    EXPECT_FALSE(actualRankedFrameRates.consideredSignals.touch);
 }
 
 TEST_P(RefreshRateSelectorTest,
@@ -1725,8 +2001,8 @@
             // These layers cannot change mode due to smoothSwitchOnly, and will definitely use
             // active mode (120Hz).
             {FrameRateCategory::NoPreference, true, 120_Hz, kModeId120},
-            {FrameRateCategory::Low, true, 120_Hz, kModeId120},
-            {FrameRateCategory::Normal, true, 40_Hz, kModeId120},
+            {FrameRateCategory::Low, true, 40_Hz, kModeId120},
+            {FrameRateCategory::Normal, true, 120_Hz, kModeId120},
             {FrameRateCategory::High, true, 120_Hz, kModeId120},
     };
 
diff --git a/services/surfaceflinger/tests/unittests/SchedulerTest.cpp b/services/surfaceflinger/tests/unittests/SchedulerTest.cpp
index 10e2220..d4735c7 100644
--- a/services/surfaceflinger/tests/unittests/SchedulerTest.cpp
+++ b/services/surfaceflinger/tests/unittests/SchedulerTest.cpp
@@ -25,6 +25,7 @@
 #include "Scheduler/EventThread.h"
 #include "Scheduler/RefreshRateSelector.h"
 #include "Scheduler/VSyncPredictor.h"
+#include "Scheduler/VSyncReactor.h"
 #include "TestableScheduler.h"
 #include "TestableSurfaceFlinger.h"
 #include "mock/DisplayHardware/MockDisplayMode.h"
@@ -56,6 +57,11 @@
 using LayerHierarchyBuilder = surfaceflinger::frontend::LayerHierarchyBuilder;
 using RequestedLayerState = surfaceflinger::frontend::RequestedLayerState;
 
+class ZeroClock : public Clock {
+public:
+    nsecs_t now() const override { return 0; }
+};
+
 class SchedulerTest : public testing::Test {
 protected:
     class MockEventThreadConnection : public android::EventThreadConnection {
@@ -563,7 +569,8 @@
                                  hal::VrrConfig{.minFrameIntervalNs = static_cast<int32_t>(
                                                         frameRate.getPeriodNsecs())}));
     std::shared_ptr<VSyncPredictor> vrrTracker =
-            std::make_shared<VSyncPredictor>(kMode, kHistorySize, kMinimumSamplesForPrediction,
+            std::make_shared<VSyncPredictor>(std::make_unique<ZeroClock>(), kMode, kHistorySize,
+                                             kMinimumSamplesForPrediction,
                                              kOutlierTolerancePercent);
     std::shared_ptr<RefreshRateSelector> vrrSelectorPtr =
             std::make_shared<RefreshRateSelector>(makeModes(kMode), kMode->getId());
@@ -576,8 +583,10 @@
 
     scheduler.registerDisplay(kMode->getPhysicalDisplayId(), vrrSelectorPtr, vrrTracker);
     vrrSelectorPtr->setActiveMode(kMode->getId(), frameRate);
-    scheduler.setRenderRate(kMode->getPhysicalDisplayId(), frameRate);
+    scheduler.setRenderRate(kMode->getPhysicalDisplayId(), frameRate, /*applyImmediately*/ false);
     vrrTracker->addVsyncTimestamp(0);
+    // Set 1000 as vsync seq #0
+    vrrTracker->nextAnticipatedVSyncTimeFrom(700);
 
     EXPECT_EQ(Fps::fromPeriodNsecs(1000),
               scheduler.getNextFrameInterval(kMode->getPhysicalDisplayId(),
@@ -587,20 +596,21 @@
                                              TimePoint::fromNs(2000)));
 
     // Not crossing the min frame period
-    EXPECT_EQ(Fps::fromPeriodNsecs(1500),
+    vrrTracker->onFrameBegin(TimePoint::fromNs(2000), TimePoint::fromNs(1500));
+    EXPECT_EQ(Fps::fromPeriodNsecs(1000),
               scheduler.getNextFrameInterval(kMode->getPhysicalDisplayId(),
                                              TimePoint::fromNs(2500)));
     // Change render rate
     frameRate = Fps::fromPeriodNsecs(2000);
     vrrSelectorPtr->setActiveMode(kMode->getId(), frameRate);
-    scheduler.setRenderRate(kMode->getPhysicalDisplayId(), frameRate);
+    scheduler.setRenderRate(kMode->getPhysicalDisplayId(), frameRate, /*applyImmediately*/ false);
 
     EXPECT_EQ(Fps::fromPeriodNsecs(2000),
               scheduler.getNextFrameInterval(kMode->getPhysicalDisplayId(),
-                                             TimePoint::fromNs(2000)));
+                                             TimePoint::fromNs(4500)));
     EXPECT_EQ(Fps::fromPeriodNsecs(2000),
               scheduler.getNextFrameInterval(kMode->getPhysicalDisplayId(),
-                                             TimePoint::fromNs(4000)));
+                                             TimePoint::fromNs(6500)));
 }
 
 TEST_F(SchedulerTest, resyncAllToHardwareVsync) FTL_FAKE_GUARD(kMainThreadContext) {
diff --git a/services/surfaceflinger/tests/unittests/VSyncDispatchRealtimeTest.cpp b/services/surfaceflinger/tests/unittests/VSyncDispatchRealtimeTest.cpp
index d891008..d701a97 100644
--- a/services/surfaceflinger/tests/unittests/VSyncDispatchRealtimeTest.cpp
+++ b/services/surfaceflinger/tests/unittests/VSyncDispatchRealtimeTest.cpp
@@ -48,9 +48,9 @@
     Period minFramePeriod() const final { return Period::fromNs(currentPeriod()); }
     void resetModel() final {}
     bool needsMoreSamples() const final { return false; }
-    bool isVSyncInPhase(nsecs_t, Fps) const final { return false; }
+    bool isVSyncInPhase(nsecs_t, Fps) final { return false; }
     void setDisplayModePtr(ftl::NonNull<DisplayModePtr>) final {}
-    void setRenderRate(Fps) final {}
+    void setRenderRate(Fps, bool) final {}
     void onFrameBegin(TimePoint, TimePoint) final {}
     void onFrameMissed(TimePoint) final {}
     void dump(std::string&) const final {}
@@ -64,7 +64,7 @@
 public:
     FixedRateIdealStubTracker() : StubTracker{toNs(3ms)} {}
 
-    nsecs_t nextAnticipatedVSyncTimeFrom(nsecs_t timePoint, std::optional<nsecs_t>) const final {
+    nsecs_t nextAnticipatedVSyncTimeFrom(nsecs_t timePoint, std::optional<nsecs_t>) final {
         auto const floor = timePoint % mPeriod;
         if (floor == 0) {
             return timePoint;
@@ -77,7 +77,7 @@
 public:
     VRRStubTracker(nsecs_t period) : StubTracker(period) {}
 
-    nsecs_t nextAnticipatedVSyncTimeFrom(nsecs_t time_point, std::optional<nsecs_t>) const final {
+    nsecs_t nextAnticipatedVSyncTimeFrom(nsecs_t time_point, std::optional<nsecs_t>) final {
         std::lock_guard lock(mMutex);
         auto const normalized_to_base = time_point - mBase;
         auto const floor = (normalized_to_base) % mPeriod;
diff --git a/services/surfaceflinger/tests/unittests/VSyncPredictorTest.cpp b/services/surfaceflinger/tests/unittests/VSyncPredictorTest.cpp
index b9f3d70..48707cb 100644
--- a/services/surfaceflinger/tests/unittests/VSyncPredictorTest.cpp
+++ b/services/surfaceflinger/tests/unittests/VSyncPredictorTest.cpp
@@ -75,6 +75,28 @@
     return ftl::as_non_null(createDisplayMode(DisplayModeId(0), refreshRate, kGroup, kResolution,
                                               DEFAULT_DISPLAY_ID));
 }
+
+class TestClock : public Clock {
+public:
+    TestClock() = default;
+
+    nsecs_t now() const override { return mNow; }
+    void setNow(nsecs_t now) { mNow = now; }
+
+private:
+    nsecs_t mNow = 0;
+};
+
+class ClockWrapper : public Clock {
+public:
+    ClockWrapper(std::shared_ptr<Clock> const& clock) : mClock(clock) {}
+
+    nsecs_t now() const { return mClock->now(); }
+
+private:
+    std::shared_ptr<Clock> const mClock;
+};
+
 } // namespace
 
 struct VSyncPredictorTest : testing::Test {
@@ -86,8 +108,10 @@
     static constexpr size_t kOutlierTolerancePercent = 25;
     static constexpr nsecs_t mMaxRoundingError = 100;
 
-    VSyncPredictor tracker{mMode, kHistorySize, kMinimumSamplesForPrediction,
-                           kOutlierTolerancePercent};
+    std::shared_ptr<TestClock> mClock{std::make_shared<TestClock>()};
+
+    VSyncPredictor tracker{std::make_unique<ClockWrapper>(mClock), mMode, kHistorySize,
+                           kMinimumSamplesForPrediction, kOutlierTolerancePercent};
 };
 
 TEST_F(VSyncPredictorTest, reportsAnticipatedPeriod) {
@@ -408,7 +432,8 @@
 // See b/151146131
 TEST_F(VSyncPredictorTest, hasEnoughPrecision) {
     const auto mode = displayMode(mPeriod);
-    VSyncPredictor tracker{mode, 20, kMinimumSamplesForPrediction, kOutlierTolerancePercent};
+    VSyncPredictor tracker{std::make_unique<ClockWrapper>(mClock), mode, 20,
+                           kMinimumSamplesForPrediction, kOutlierTolerancePercent};
     std::vector<nsecs_t> const simulatedVsyncs{840873348817, 840890049444, 840906762675,
                                                840923581635, 840940161584, 840956868096,
                                                840973702473, 840990256277, 841007116851,
@@ -595,44 +620,15 @@
         tracker.addVsyncTimestamp(mNow);
     }
 
-    tracker.setRenderRate(Fps::fromPeriodNsecs(3 * mPeriod));
+    tracker.setRenderRate(Fps::fromPeriodNsecs(3 * mPeriod), /*applyImmediately*/ false);
 
-    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow), Eq(mNow + mPeriod));
-    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 100), Eq(mNow + mPeriod));
-    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 1100), Eq(mNow + 4 * mPeriod));
-    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 2100), Eq(mNow + 4 * mPeriod));
-    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 3100), Eq(mNow + 4 * mPeriod));
-    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 4100), Eq(mNow + 7 * mPeriod));
-    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 5100), Eq(mNow + 7 * mPeriod));
-}
-
-TEST_F(VSyncPredictorTest, setRenderRateOfDivisorIsInPhase) {
-    auto last = mNow;
-    for (auto i = 0u; i < kMinimumSamplesForPrediction; i++) {
-        EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow), Eq(last + mPeriod));
-        mNow += mPeriod;
-        last = mNow;
-        tracker.addVsyncTimestamp(mNow);
-    }
-
-    const auto refreshRate = Fps::fromPeriodNsecs(mPeriod);
-
-    tracker.setRenderRate(refreshRate / 4);
     EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow), Eq(mNow + 3 * mPeriod));
-    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 3 * mPeriod), Eq(mNow + 7 * mPeriod));
-    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 7 * mPeriod), Eq(mNow + 11 * mPeriod));
-
-    tracker.setRenderRate(refreshRate / 2);
-    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow), Eq(mNow + 1 * mPeriod));
-    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 1 * mPeriod), Eq(mNow + 3 * mPeriod));
-    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 3 * mPeriod), Eq(mNow + 5 * mPeriod));
-    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 5 * mPeriod), Eq(mNow + 7 * mPeriod));
-    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 7 * mPeriod), Eq(mNow + 9 * mPeriod));
-    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 9 * mPeriod), Eq(mNow + 11 * mPeriod));
-
-    tracker.setRenderRate(refreshRate / 6);
-    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow), Eq(mNow + 1 * mPeriod));
-    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 1 * mPeriod), Eq(mNow + 7 * mPeriod));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 100), Eq(mNow + 3 * mPeriod));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 1100), Eq(mNow + 3 * mPeriod));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 2100), Eq(mNow + 3 * mPeriod));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 3100), Eq(mNow + 6 * mPeriod));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 4100), Eq(mNow + 6 * mPeriod));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 5100), Eq(mNow + 6 * mPeriod));
 }
 
 TEST_F(VSyncPredictorTest, setRenderRateIsIgnoredIfNotDivisor) {
@@ -644,7 +640,7 @@
         tracker.addVsyncTimestamp(mNow);
     }
 
-    tracker.setRenderRate(Fps::fromPeriodNsecs(3.5f * mPeriod));
+    tracker.setRenderRate(Fps::fromPeriodNsecs(3.5f * mPeriod), /*applyImmediately*/ false);
 
     EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow), Eq(mNow + mPeriod));
     EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 100), Eq(mNow + mPeriod));
@@ -655,6 +651,178 @@
     EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(mNow + 5100), Eq(mNow + 6 * mPeriod));
 }
 
+TEST_F(VSyncPredictorTest, setRenderRateHighIsAppliedImmediately) {
+    SET_FLAG_FOR_TEST(flags::vrr_config, true);
+    // Lowering the render rate waits for committed vsyncs; raising it applies immediately.
+    const int32_t kGroup = 0;
+    const auto kResolution = ui::Size(1920, 1080);
+    const auto vsyncRate = Fps::fromPeriodNsecs(500);
+    const auto minFrameRate = Fps::fromPeriodNsecs(1000);
+    hal::VrrConfig vrrConfig;
+    vrrConfig.minFrameIntervalNs = minFrameRate.getPeriodNsecs();
+    const ftl::NonNull<DisplayModePtr> kMode =
+            ftl::as_non_null(createDisplayModeBuilder(DisplayModeId(0), vsyncRate, kGroup,
+                                                      kResolution, DEFAULT_DISPLAY_ID)
+                                     .setVrrConfig(std::move(vrrConfig))
+                                     .build());
+
+    VSyncPredictor vrrTracker{std::make_unique<ClockWrapper>(mClock), kMode, kHistorySize,
+                              kMinimumSamplesForPrediction, kOutlierTolerancePercent};
+
+    vrrTracker.setRenderRate(Fps::fromPeriodNsecs(1000), /*applyImmediately*/ false);
+    vrrTracker.addVsyncTimestamp(0);
+    EXPECT_EQ(1000, vrrTracker.nextAnticipatedVSyncTimeFrom(700));
+    EXPECT_EQ(2000, vrrTracker.nextAnticipatedVSyncTimeFrom(1000, 1000));
+
+    // Commit to a vsync in the future (6000).
+    EXPECT_EQ(6000, vrrTracker.nextAnticipatedVSyncTimeFrom(5000, 5000));
+    // 2000ns: queries before the committed 6000 keep the old cadence; then 6000 -> 8000.
+    vrrTracker.setRenderRate(Fps::fromPeriodNsecs(2000), /*applyImmediately*/ false);
+    EXPECT_EQ(5000, vrrTracker.nextAnticipatedVSyncTimeFrom(4000, 4000));
+    EXPECT_EQ(6000, vrrTracker.nextAnticipatedVSyncTimeFrom(5000, 5000));
+    EXPECT_EQ(8000, vrrTracker.nextAnticipatedVSyncTimeFrom(6000, 6000));
+
+    EXPECT_EQ(12000, vrrTracker.nextAnticipatedVSyncTimeFrom(10000, 10000));
+    // 3500ns: the new cadence starts after the last committed vsync (12000 -> 15500).
+    vrrTracker.setRenderRate(Fps::fromPeriodNsecs(3500), /*applyImmediately*/ false);
+    EXPECT_EQ(5000, vrrTracker.nextAnticipatedVSyncTimeFrom(4000, 4000));
+    EXPECT_EQ(6000, vrrTracker.nextAnticipatedVSyncTimeFrom(5000, 5000));
+    EXPECT_EQ(8000, vrrTracker.nextAnticipatedVSyncTimeFrom(6000, 6000));
+    EXPECT_EQ(10000, vrrTracker.nextAnticipatedVSyncTimeFrom(8000, 8000));
+    EXPECT_EQ(12000, vrrTracker.nextAnticipatedVSyncTimeFrom(10000, 10000));
+    EXPECT_EQ(15500, vrrTracker.nextAnticipatedVSyncTimeFrom(12000, 12000));
+    EXPECT_EQ(19000, vrrTracker.nextAnticipatedVSyncTimeFrom(15500, 15500));
+    // 2500ns: likewise, it starts after the last committed vsync (19000 -> 21500).
+    vrrTracker.setRenderRate(Fps::fromPeriodNsecs(2500), /*applyImmediately*/ false);
+    EXPECT_EQ(5000, vrrTracker.nextAnticipatedVSyncTimeFrom(4000, 4000));
+    EXPECT_EQ(6000, vrrTracker.nextAnticipatedVSyncTimeFrom(5000, 5000));
+    EXPECT_EQ(8000, vrrTracker.nextAnticipatedVSyncTimeFrom(6000, 6000));
+    EXPECT_EQ(10000, vrrTracker.nextAnticipatedVSyncTimeFrom(8000, 8000));
+    EXPECT_EQ(12000, vrrTracker.nextAnticipatedVSyncTimeFrom(10000, 10000));
+    EXPECT_EQ(15500, vrrTracker.nextAnticipatedVSyncTimeFrom(12000, 12000));
+    EXPECT_EQ(19000, vrrTracker.nextAnticipatedVSyncTimeFrom(15500, 15500));
+    EXPECT_EQ(21500, vrrTracker.nextAnticipatedVSyncTimeFrom(19000, 19000));
+    // Back to 1000ns: the higher rate applies at once (6000 -> 7000 instead of 8000).
+    vrrTracker.setRenderRate(Fps::fromPeriodNsecs(1000), /*applyImmediately*/ false);
+    EXPECT_EQ(5000, vrrTracker.nextAnticipatedVSyncTimeFrom(4000, 4000));
+    EXPECT_EQ(6000, vrrTracker.nextAnticipatedVSyncTimeFrom(5000, 5000));
+    EXPECT_EQ(7000, vrrTracker.nextAnticipatedVSyncTimeFrom(6000, 6000));
+    EXPECT_EQ(9000, vrrTracker.nextAnticipatedVSyncTimeFrom(8000, 8000));
+    EXPECT_EQ(11000, vrrTracker.nextAnticipatedVSyncTimeFrom(10000, 10000));
+    EXPECT_EQ(13000, vrrTracker.nextAnticipatedVSyncTimeFrom(12000, 12000));
+    EXPECT_EQ(17000, vrrTracker.nextAnticipatedVSyncTimeFrom(15500, 15500));
+    EXPECT_EQ(20000, vrrTracker.nextAnticipatedVSyncTimeFrom(19000, 19000));
+}
+
+TEST_F(VSyncPredictorTest, minFramePeriodDoesntApplyWhenSameWithRefreshRate) {
+    SET_FLAG_FOR_TEST(flags::vrr_config, true);
+    // Vsync period equals the min frame interval (1000ns), so lastVsync must not delay frames.
+    const int32_t kGroup = 0;
+    const auto kResolution = ui::Size(1920, 1080);
+    const auto vsyncRate = Fps::fromPeriodNsecs(1000);
+    const auto minFrameRate = Fps::fromPeriodNsecs(1000);
+    hal::VrrConfig vrrConfig;
+    vrrConfig.minFrameIntervalNs = minFrameRate.getPeriodNsecs();
+    const ftl::NonNull<DisplayModePtr> kMode =
+            ftl::as_non_null(createDisplayModeBuilder(DisplayModeId(0), vsyncRate, kGroup,
+                                                      kResolution, DEFAULT_DISPLAY_ID)
+                                     .setVrrConfig(std::move(vrrConfig))
+                                     .build());
+
+    VSyncPredictor vrrTracker{std::make_unique<ClockWrapper>(mClock), kMode, kHistorySize,
+                              kMinimumSamplesForPrediction, kOutlierTolerancePercent};
+
+    vrrTracker.setRenderRate(Fps::fromPeriodNsecs(1000), /*applyImmediately*/ false);
+    vrrTracker.addVsyncTimestamp(0);
+    EXPECT_EQ(1000, vrrTracker.nextAnticipatedVSyncTimeFrom(700));
+    EXPECT_EQ(2000, vrrTracker.nextAnticipatedVSyncTimeFrom(1000, 1000));
+
+    // Even if lastVsync is off due to vsync drift (1700), the prediction is unchanged.
+    EXPECT_EQ(2000, vrrTracker.nextAnticipatedVSyncTimeFrom(1000, 1700));
+}
+
+TEST_F(VSyncPredictorTest, setRenderRateExplicitAppliedImmediately) {
+    SET_FLAG_FOR_TEST(flags::vrr_config, true);
+    // With applyImmediately=true the new 2000ns cadence is used right away (5000, 7000, 9000).
+    const int32_t kGroup = 0;
+    const auto kResolution = ui::Size(1920, 1080);
+    const auto vsyncRate = Fps::fromPeriodNsecs(500);
+    const auto minFrameRate = Fps::fromPeriodNsecs(1000);
+    hal::VrrConfig vrrConfig;
+    vrrConfig.minFrameIntervalNs = minFrameRate.getPeriodNsecs();
+    const ftl::NonNull<DisplayModePtr> kMode =
+            ftl::as_non_null(createDisplayModeBuilder(DisplayModeId(0), vsyncRate, kGroup,
+                                                      kResolution, DEFAULT_DISPLAY_ID)
+                                     .setVrrConfig(std::move(vrrConfig))
+                                     .build());
+
+    VSyncPredictor vrrTracker{std::make_unique<ClockWrapper>(mClock), kMode, kHistorySize,
+                              kMinimumSamplesForPrediction, kOutlierTolerancePercent};
+
+    vrrTracker.setRenderRate(Fps::fromPeriodNsecs(1000), /*applyImmediately*/ false);
+    vrrTracker.addVsyncTimestamp(0);
+    EXPECT_EQ(1000, vrrTracker.nextAnticipatedVSyncTimeFrom(700));
+    EXPECT_EQ(2000, vrrTracker.nextAnticipatedVSyncTimeFrom(1000, 1000));
+
+    // Commit to a vsync in the future (6000).
+    EXPECT_EQ(6000, vrrTracker.nextAnticipatedVSyncTimeFrom(5000, 2000));
+
+    vrrTracker.setRenderRate(Fps::fromPeriodNsecs(2000), /*applyImmediately*/ true);
+    EXPECT_EQ(5000, vrrTracker.nextAnticipatedVSyncTimeFrom(4000));
+    EXPECT_EQ(7000, vrrTracker.nextAnticipatedVSyncTimeFrom(5000, 5000));
+    EXPECT_EQ(9000, vrrTracker.nextAnticipatedVSyncTimeFrom(7000, 7000));
+}
+
+TEST_F(VSyncPredictorTest, selectsClosestVsyncAfterInactivity) {
+    SET_FLAG_FOR_TEST(flags::vrr_config, true);
+    // After a long idle gap, the prediction is the next vsync after the query time (50500).
+    const int32_t kGroup = 0;
+    const auto kResolution = ui::Size(1920, 1080);
+    const auto vsyncRate = Fps::fromPeriodNsecs(500);
+    const auto minFrameRate = Fps::fromPeriodNsecs(1000);
+    hal::VrrConfig vrrConfig;
+    vrrConfig.minFrameIntervalNs = minFrameRate.getPeriodNsecs();
+    const ftl::NonNull<DisplayModePtr> kMode =
+            ftl::as_non_null(createDisplayModeBuilder(DisplayModeId(0), vsyncRate, kGroup,
+                                                      kResolution, DEFAULT_DISPLAY_ID)
+                                     .setVrrConfig(std::move(vrrConfig))
+                                     .build());
+
+    VSyncPredictor vrrTracker{std::make_unique<ClockWrapper>(mClock), kMode, kHistorySize,
+                              kMinimumSamplesForPrediction, kOutlierTolerancePercent};
+
+    vrrTracker.setRenderRate(Fps::fromPeriodNsecs(5000), /*applyImmediately*/ false);
+    vrrTracker.addVsyncTimestamp(0);
+    EXPECT_EQ(5000, vrrTracker.nextAnticipatedVSyncTimeFrom(4700));
+    EXPECT_EQ(10000, vrrTracker.nextAnticipatedVSyncTimeFrom(5000, 5000));
+    // Simulate inactivity: advance the clock far past the last vsync (10000).
+    mClock->setNow(50000);
+    EXPECT_EQ(50500, vrrTracker.nextAnticipatedVSyncTimeFrom(50000, 10000));
+}
+
+TEST_F(VSyncPredictorTest, returnsCorrectVsyncWhenLastIsNot) {
+    SET_FLAG_FOR_TEST(flags::vrr_config, true);
+    // lastVsync=1234 is not an actual vsync; the prediction must still land on one (2000).
+    const int32_t kGroup = 0;
+    const auto kResolution = ui::Size(1920, 1080);
+    const auto vsyncRate = Fps::fromPeriodNsecs(500);
+    const auto minFrameRate = Fps::fromPeriodNsecs(1000);
+    hal::VrrConfig vrrConfig;
+    vrrConfig.minFrameIntervalNs = minFrameRate.getPeriodNsecs();
+    const ftl::NonNull<DisplayModePtr> kMode =
+            ftl::as_non_null(createDisplayModeBuilder(DisplayModeId(0), vsyncRate, kGroup,
+                                                      kResolution, DEFAULT_DISPLAY_ID)
+                                     .setVrrConfig(std::move(vrrConfig))
+                                     .build());
+
+    VSyncPredictor vrrTracker{std::make_unique<ClockWrapper>(mClock), kMode, kHistorySize,
+                              kMinimumSamplesForPrediction, kOutlierTolerancePercent};
+
+    vrrTracker.setRenderRate(Fps::fromPeriodNsecs(1000), /*applyImmediately*/ false);
+    vrrTracker.addVsyncTimestamp(0);
+    EXPECT_EQ(2000, vrrTracker.nextAnticipatedVSyncTimeFrom(1234, 1234));
+}
+
 TEST_F(VSyncPredictorTest, adjustsVrrTimeline) {
     SET_FLAG_FOR_TEST(flags::vrr_config, true);
 
@@ -670,10 +838,10 @@
                                      .setVrrConfig(std::move(vrrConfig))
                                      .build());
 
-    VSyncPredictor vrrTracker{kMode, kHistorySize, kMinimumSamplesForPrediction,
-                              kOutlierTolerancePercent};
+    VSyncPredictor vrrTracker{std::make_unique<ClockWrapper>(mClock), kMode, kHistorySize,
+                              kMinimumSamplesForPrediction, kOutlierTolerancePercent};
 
-    vrrTracker.setRenderRate(minFrameRate);
+    vrrTracker.setRenderRate(minFrameRate, /*applyImmediately*/ false);
     vrrTracker.addVsyncTimestamp(0);
     EXPECT_EQ(1000, vrrTracker.nextAnticipatedVSyncTimeFrom(700));
     EXPECT_EQ(2000, vrrTracker.nextAnticipatedVSyncTimeFrom(1000));
@@ -687,7 +855,95 @@
     vrrTracker.onFrameMissed(TimePoint::fromNs(4500));
     EXPECT_EQ(5000, vrrTracker.nextAnticipatedVSyncTimeFrom(4500, 4500));
     EXPECT_EQ(6000, vrrTracker.nextAnticipatedVSyncTimeFrom(5000, 5000));
+
+    vrrTracker.onFrameBegin(TimePoint::fromNs(7000), TimePoint::fromNs(6500));
+    EXPECT_EQ(10500, vrrTracker.nextAnticipatedVSyncTimeFrom(9000, 7000));
 }
+
+TEST_F(VSyncPredictorTest, adjustsVrrTimelineTwoClients) {
+    SET_FLAG_FOR_TEST(flags::vrr_config, true);
+    // App and SF query one tracker; an SF missed frame shifts the timeline the app sees.
+    const int32_t kGroup = 0;
+    const auto kResolution = ui::Size(1920, 1080);
+    const auto refreshRate = Fps::fromPeriodNsecs(500);
+    const auto minFrameRate = Fps::fromPeriodNsecs(1000);
+    hal::VrrConfig vrrConfig;
+    vrrConfig.minFrameIntervalNs = minFrameRate.getPeriodNsecs();
+    const ftl::NonNull<DisplayModePtr> kMode =
+            ftl::as_non_null(createDisplayModeBuilder(DisplayModeId(0), refreshRate, kGroup,
+                                                      kResolution, DEFAULT_DISPLAY_ID)
+                                     .setVrrConfig(std::move(vrrConfig))
+                                     .build());
+
+    VSyncPredictor vrrTracker{std::make_unique<ClockWrapper>(mClock), kMode, kHistorySize,
+                              kMinimumSamplesForPrediction, kOutlierTolerancePercent};
+
+    vrrTracker.setRenderRate(minFrameRate, /*applyImmediately*/ false);
+    vrrTracker.addVsyncTimestamp(0);
+
+    // App runs ahead
+    EXPECT_EQ(3000, vrrTracker.nextAnticipatedVSyncTimeFrom(2700));
+    EXPECT_EQ(4000, vrrTracker.nextAnticipatedVSyncTimeFrom(3000, 3000));
+    EXPECT_EQ(5000, vrrTracker.nextAnticipatedVSyncTimeFrom(4000, 4000));
+
+    // SF starts to catch up
+    EXPECT_EQ(3000, vrrTracker.nextAnticipatedVSyncTimeFrom(2700));
+    vrrTracker.onFrameBegin(TimePoint::fromNs(3000), TimePoint::fromNs(0));
+
+    // SF misses last frame (3000) and observes that when committing (4000)
+    EXPECT_EQ(6000, vrrTracker.nextAnticipatedVSyncTimeFrom(5000, 5000));
+    EXPECT_EQ(4000, vrrTracker.nextAnticipatedVSyncTimeFrom(3700));
+    vrrTracker.onFrameMissed(TimePoint::fromNs(4000));
+
+    // SF wakes up again instead of the (4000) missed frame
+    EXPECT_EQ(4500, vrrTracker.nextAnticipatedVSyncTimeFrom(4000, 4000));
+    vrrTracker.onFrameBegin(TimePoint::fromNs(4500), TimePoint::fromNs(4500));
+
+    // Timeline shifted. The app needs to get the next frame at (7500) as its last frame (6500) will
+    // be presented at (7500)
+    EXPECT_EQ(7500, vrrTracker.nextAnticipatedVSyncTimeFrom(6000, 6000));
+    EXPECT_EQ(5500, vrrTracker.nextAnticipatedVSyncTimeFrom(4500, 4500));
+    vrrTracker.onFrameBegin(TimePoint::fromNs(5500), TimePoint::fromNs(4500));
+
+    EXPECT_EQ(8500, vrrTracker.nextAnticipatedVSyncTimeFrom(7500, 7500));
+    EXPECT_EQ(6500, vrrTracker.nextAnticipatedVSyncTimeFrom(5500, 5500));
+    vrrTracker.onFrameBegin(TimePoint::fromNs(6500), TimePoint::fromNs(5500));
+}
+
+TEST_F(VSyncPredictorTest, renderRateIsPreservedForCommittedVsyncs) {
+    tracker.addVsyncTimestamp(1000);
+
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(1), Eq(1000));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(5001), Eq(6000));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(6001), Eq(7000));
+    // Lowering the render rate keeps vsyncs already handed out (1000, 6000, 7000).
+    tracker.setRenderRate(Fps::fromPeriodNsecs(2000), /*applyImmediately*/ false);
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(1), Eq(1000));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(5001), Eq(6000));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(6001), Eq(7000));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(7001), Eq(9000));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(8001), Eq(9000));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(9001), Eq(11000));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(10001), Eq(11000));
+    // Same for 3000ns: committed vsyncs up to 11000 keep the 2000ns cadence.
+    tracker.setRenderRate(Fps::fromPeriodNsecs(3000), /*applyImmediately*/ false);
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(1), Eq(1000));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(5001), Eq(6000));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(6001), Eq(7000));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(7001), Eq(9000));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(8001), Eq(9000));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(9001), Eq(11000));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(10001), Eq(11000));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(11001), Eq(14000));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(12001), Eq(14000));
+
+    // Once now() (20000) passes them, committed vsyncs are purged; the 3000ns grid applies.
+    mClock->setNow(20000);
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(1), Eq(2000));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(5001), Eq(8000));
+    EXPECT_THAT(tracker.nextAnticipatedVSyncTimeFrom(6001), Eq(8000));
+}
+
 } // namespace android::scheduler
 
 // TODO(b/129481165): remove the #pragma below and fix conversion issues
diff --git a/services/surfaceflinger/tests/unittests/mock/MockVSyncTracker.h b/services/surfaceflinger/tests/unittests/mock/MockVSyncTracker.h
index 3870983..c311901 100644
--- a/services/surfaceflinger/tests/unittests/mock/MockVSyncTracker.h
+++ b/services/surfaceflinger/tests/unittests/mock/MockVSyncTracker.h
@@ -29,14 +29,14 @@
 
     MOCK_METHOD(bool, addVsyncTimestamp, (nsecs_t), (override));
     MOCK_METHOD(nsecs_t, nextAnticipatedVSyncTimeFrom, (nsecs_t, std::optional<nsecs_t>),
-                (const, override));
+                (override));
     MOCK_METHOD(nsecs_t, currentPeriod, (), (const, override));
     MOCK_METHOD(Period, minFramePeriod, (), (const, override));
     MOCK_METHOD(void, resetModel, (), (override));
     MOCK_METHOD(bool, needsMoreSamples, (), (const, override));
-    MOCK_METHOD(bool, isVSyncInPhase, (nsecs_t, Fps), (const, override));
+    MOCK_METHOD(bool, isVSyncInPhase, (nsecs_t, Fps), (override));
     MOCK_METHOD(void, setDisplayModePtr, (ftl::NonNull<DisplayModePtr>), (override));
-    MOCK_METHOD(void, setRenderRate, (Fps), (override));
+    MOCK_METHOD(void, setRenderRate, (Fps, bool), (override));
     MOCK_METHOD(void, onFrameBegin, (TimePoint, TimePoint), (override));
     MOCK_METHOD(void, onFrameMissed, (TimePoint), (override));
     MOCK_METHOD(void, dump, (std::string&), (const, override));