Merge "Refactor `isHdrDataspace` function." into udc-qpr-dev
diff --git a/cmds/installd/otapreopt_chroot.cpp b/cmds/installd/otapreopt_chroot.cpp
index c86993c..c40caf5 100644
--- a/cmds/installd/otapreopt_chroot.cpp
+++ b/cmds/installd/otapreopt_chroot.cpp
@@ -19,9 +19,12 @@
 #include <sys/mount.h>
 #include <sys/stat.h>
 #include <sys/wait.h>
+#include <unistd.h>
 
+#include <algorithm>
 #include <array>
 #include <fstream>
+#include <iostream>
 #include <sstream>
 
 #include <android-base/file.h>
@@ -29,6 +32,7 @@
 #include <android-base/macros.h>
 #include <android-base/scopeguard.h>
 #include <android-base/stringprintf.h>
+#include <android-base/strings.h>
 #include <android-base/unique_fd.h>
 #include <libdm/dm.h>
 #include <selinux/android.h>
@@ -37,7 +41,7 @@
 #include "otapreopt_utils.h"
 
 #ifndef LOG_TAG
-#define LOG_TAG "otapreopt"
+#define LOG_TAG "otapreopt_chroot"
 #endif
 
 using android::base::StringPrintf;
@@ -49,20 +53,22 @@
 // so just try the possibilities one by one.
 static constexpr std::array kTryMountFsTypes = {"ext4", "erofs"};
 
-static void CloseDescriptor(int fd) {
-    if (fd >= 0) {
-        int result = close(fd);
-        UNUSED(result);  // Ignore result. Printing to logcat will open a new descriptor
-                         // that we do *not* want.
-    }
-}
-
 static void CloseDescriptor(const char* descriptor_string) {
     int fd = -1;
     std::istringstream stream(descriptor_string);
     stream >> fd;
     if (!stream.fail()) {
-        CloseDescriptor(fd);
+        if (fd >= 0) {  // A negative value cannot name an open descriptor; nothing to close.
+            if (close(fd) < 0) {
+                PLOG(ERROR) << "Failed to close " << fd;
+            }
+        }
+    }
+}
+
+static void SetCloseOnExec(int fd) {
+    if (fcntl(fd, F_SETFD, FD_CLOEXEC) < 0) {  // Replaces fd flags; FD_CLOEXEC is the only one.
+        PLOG(ERROR) << "Failed to set FD_CLOEXEC on " << fd;
     }
 }
 
@@ -129,24 +135,39 @@
 }
 
 // Entry for otapreopt_chroot. Expected parameters are:
-//   [cmd] [status-fd] [target-slot] "dexopt" [dexopt-params]
-// The file descriptor denoted by status-fd will be closed. The rest of the parameters will
-// be passed on to otapreopt in the chroot.
+//
+//   [cmd] [status-fd] [target-slot-suffix]
+//
+// The file descriptor denoted by status-fd will be closed. Dexopt commands in
+// the form
+//
+//   "dexopt" [dexopt-params]
+//
+// are then read from stdin until EOF and passed on to /system/bin/otapreopt one
+// by one. After each call a line with the current command count is written to
+// stdout and flushed.
 static int otapreopt_chroot(const int argc, char **arg) {
     // Validate arguments
-    // We need the command, status channel and target slot, at a minimum.
-    if(argc < 3) {
-        PLOG(ERROR) << "Not enough arguments.";
+    if (argc == 2 && std::string_view(arg[1]) == "--version") {
+        // Accept a single --version flag, to allow the script to tell this binary
+        // from the earlier one.
+        std::cout << "2" << std::endl;
+        return 0;
+    }
+    if (argc != 3) {
+        LOG(ERROR) << "Wrong number of arguments: " << argc;
         exit(208);
     }
-    // Close all file descriptors. They are coming from the caller, we do not want to pass them
-    // on across our fork/exec into a different domain.
-    // 1) Default descriptors.
-    CloseDescriptor(STDIN_FILENO);
-    CloseDescriptor(STDOUT_FILENO);
-    CloseDescriptor(STDERR_FILENO);
-    // 2) The status channel.
-    CloseDescriptor(arg[1]);
+    const char* status_fd = arg[1];
+    const char* slot_suffix = arg[2];
+
+    // Set FD_CLOEXEC on standard fds. They come from the caller, and we do not
+    // want to pass them on across our fork/exec into a different domain.
+    SetCloseOnExec(STDIN_FILENO);
+    SetCloseOnExec(STDOUT_FILENO);
+    SetCloseOnExec(STDERR_FILENO);
+    // Close the status channel.
+    CloseDescriptor(status_fd);
 
     // We need to run the otapreopt tool from the postinstall partition. As such, set up a
     // mount namespace and change root.
@@ -185,20 +206,20 @@
     //  2) We're in a mount namespace here, so when we die, this will be cleaned up.
     //  3) Ignore errors. Printing anything at this stage will open a file descriptor
     //     for logging.
-    if (!ValidateTargetSlotSuffix(arg[2])) {
-        LOG(ERROR) << "Target slot suffix not legal: " << arg[2];
+    if (!ValidateTargetSlotSuffix(slot_suffix)) {
+        LOG(ERROR) << "Target slot suffix not legal: " << slot_suffix;
         exit(207);
     }
-    TryExtraMount("vendor", arg[2], "/postinstall/vendor");
+    TryExtraMount("vendor", slot_suffix, "/postinstall/vendor");
 
     // Try to mount the product partition. update_engine doesn't do this for us, but we
     // want it for product APKs. Same notes as vendor above.
-    TryExtraMount("product", arg[2], "/postinstall/product");
+    TryExtraMount("product", slot_suffix, "/postinstall/product");
 
     // Try to mount the system_ext partition. update_engine doesn't do this for
     // us, but we want it for system_ext APKs. Same notes as vendor and product
     // above.
-    TryExtraMount("system_ext", arg[2], "/postinstall/system_ext");
+    TryExtraMount("system_ext", slot_suffix, "/postinstall/system_ext");
 
     constexpr const char* kPostInstallLinkerconfig = "/postinstall/linkerconfig";
     // Try to mount /postinstall/linkerconfig. we will set it up after performing the chroot
@@ -329,30 +350,45 @@
         exit(218);
     }
 
-    // Now go on and run otapreopt.
+    // Now go on and read dexopt lines from stdin and pass them on to otapreopt.
 
-    // Incoming:  cmd + status-fd + target-slot + cmd...      | Incoming | = argc
-    // Outgoing:  cmd             + target-slot + cmd...      | Outgoing | = argc - 1
-    std::vector<std::string> cmd;
-    cmd.reserve(argc);
-    cmd.push_back("/system/bin/otapreopt");
+    // Maximum accepted length of a single dexopt command line, excluding the newline.
+    constexpr size_t kMaxCommandLength = 1000;
 
-    // The first parameter is the status file descriptor, skip.
-    for (size_t i = 2; i < static_cast<size_t>(argc); ++i) {
-        cmd.push_back(arg[i]);
+    int count = 1;
+    // std::getline() always consumes a whole line, however long, and does not store the
+    // newline. An over-long line is therefore skipped without ending the loop, and a last
+    // line lacking a trailing newline is passed on intact.
+    for (std::string line; std::getline(std::cin, line); ++count) {
+        if (line.size() > kMaxCommandLength) {
+            LOG(ERROR) << "Command exceeds max length " << kMaxCommandLength
+                       << " - skipped: " << line;
+            continue;
+        }
+
+        std::vector<std::string> tokenized_line = android::base::Tokenize(line, " ");
+        if (tokenized_line.empty()) {
+            LOG(WARNING) << "Command " << count << " is empty - skipped";
+            continue;
+        }
+
+        std::vector<std::string> cmd{"/system/bin/otapreopt", slot_suffix};
+        std::move(tokenized_line.begin(), tokenized_line.end(), std::back_inserter(cmd));
+
+        LOG(INFO) << "Command " << count << ": " << android::base::Join(cmd, " ");
+
+        // Fork and execute otapreopt in its own process.
+        std::string error_msg;
+        bool exec_result = Exec(cmd, &error_msg);
+        if (!exec_result) {
+            LOG(ERROR) << "Running otapreopt failed: " << error_msg;
+        }
+
+        // Print the count to stdout and flush to indicate progress.
+        std::cout << count << std::endl;
     }
 
-    // Fork and execute otapreopt in its own process.
-    std::string error_msg;
-    bool exec_result = Exec(cmd, &error_msg);
-    if (!exec_result) {
-        LOG(ERROR) << "Running otapreopt failed: " << error_msg;
-    }
-
-    if (!exec_result) {
-        exit(213);
-    }
-
+    LOG(INFO) << "No more dexopt commands";
     return 0;
 }
 
diff --git a/cmds/installd/otapreopt_script.sh b/cmds/installd/otapreopt_script.sh
index db5c34e..28bd793 100644
--- a/cmds/installd/otapreopt_script.sh
+++ b/cmds/installd/otapreopt_script.sh
@@ -16,7 +16,9 @@
 # limitations under the License.
 #
 
-# This script will run as a postinstall step to drive otapreopt.
+# This script runs as a postinstall step to drive otapreopt. It comes with the
+# OTA package, but runs /system/bin/otapreopt_chroot in the (old) active system
+# image. See system/extras/postinst/postinst.sh for some docs.
 
 TARGET_SLOT="$1"
 STATUS_FD="$2"
@@ -31,12 +33,11 @@
 
 BOOT_COMPLETE=$(getprop $BOOT_PROPERTY_NAME)
 if [ "$BOOT_COMPLETE" != "1" ] ; then
-  echo "Error: boot-complete not detected."
+  echo "$0: Error: boot-complete not detected."
   # We must return 0 to not block sideload.
   exit 0
 fi
 
-
 # Compute target slot suffix.
 # TODO: Once bootctl is not restricted, we should query from there. Or get this from
 #       update_engine as a parameter.
@@ -45,45 +46,64 @@
 elif [ "$TARGET_SLOT" = "1" ] ; then
   TARGET_SLOT_SUFFIX="_b"
 else
-  echo "Unknown target slot $TARGET_SLOT"
+  echo "$0: Unknown target slot $TARGET_SLOT"
   exit 1
 fi
 
+if [ "$(/system/bin/otapreopt_chroot --version)" != 2 ]; then
+  # We require an updated chroot wrapper that reads dexopt commands from stdin.
+  # Even if we kept compat with the old binary, the OTA preopt wouldn't work due
+  # to missing sepolicy rules, so there's no use spending time trying to dexopt
+  # (b/291974157).
+  echo "$0: Current system image is too old to work with OTA preopt - skipping."
+  exit 0
+fi
 
 PREPARE=$(cmd otadexopt prepare)
 # Note: Ignore preparation failures. Step and done will fail and exit this.
 #       This is necessary to support suspends - the OTA service will keep
 #       the state around for us.
 
-PROGRESS=$(cmd otadexopt progress)
-print -u${STATUS_FD} "global_progress $PROGRESS"
-
-i=0
-while ((i<MAXIMUM_PACKAGES)) ; do
+# Create an array with all dexopt commands in advance, to know how many there are.
+otadexopt_cmds=()
+while (( ${#otadexopt_cmds[@]} < MAXIMUM_PACKAGES )) ; do
   DONE=$(cmd otadexopt done)
   if [ "$DONE" = "OTA complete." ] ; then
     break
   fi
-
-  DEXOPT_PARAMS=$(cmd otadexopt next)
-
-  /system/bin/otapreopt_chroot $STATUS_FD $TARGET_SLOT_SUFFIX $DEXOPT_PARAMS >&- 2>&-
-
-  PROGRESS=$(cmd otadexopt progress)
-  print -u${STATUS_FD} "global_progress $PROGRESS"
-
-  sleep 1
-  i=$((i+1))
+  otadexopt_cmds+=("$(cmd otadexopt next)")
 done
 
 DONE=$(cmd otadexopt done)
+cmd otadexopt cleanup
+
+echo "$0: Using streaming otapreopt_chroot on ${#otadexopt_cmds[@]} packages"
+
+function print_otadexopt_cmds {
+  # Use printf rather than print, which would expand backslash escapes in the commands.
+  for dexopt_cmd in "${otadexopt_cmds[@]}" ; do
+    printf '%s\n' "$dexopt_cmd"
+  done
+}
+
+function report_progress {
+  while read -r count ; do
+    # mksh can't do floating point arithmetic, so emulate a fixed point calculation.
+    (( permilles = 1000 * count / ${#otadexopt_cmds[@]} ))
+    printf 'global_progress %d.%03d\n' $((permilles / 1000)) $((permilles % 1000)) >&${STATUS_FD}
+  done
+}
+
+print_otadexopt_cmds | \
+  /system/bin/otapreopt_chroot $STATUS_FD $TARGET_SLOT_SUFFIX | \
+  report_progress
+
 if [ "$DONE" = "OTA incomplete." ] ; then
-  echo "Incomplete."
+  echo "$0: Incomplete."
 else
-  echo "Complete or error."
+  echo "$0: Complete or error."
 fi
 
 print -u${STATUS_FD} "global_progress 1.0"
-cmd otadexopt cleanup
 
 exit 0
diff --git a/include/input/MotionPredictorMetricsManager.h b/include/input/MotionPredictorMetricsManager.h
index 6284f07..12e50ba 100644
--- a/include/input/MotionPredictorMetricsManager.h
+++ b/include/input/MotionPredictorMetricsManager.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2023 The Android Open Source Project
+ * Copyright 2023 The Android Open Source Project
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,23 +14,196 @@
  * limitations under the License.
  */
 
-#include <utils/Timers.h>
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <limits>
+#include <optional>
+#include <vector>
+
+#include <input/Input.h> // for MotionEvent
+#include <input/RingBuffer.h>
+#include <utils/Timers.h> // for nsecs_t
+
+#include "Eigen/Core"
 
 namespace android {
 
 /**
  * Class to handle computing and reporting metrics for MotionPredictor.
  *
- * Currently an empty implementation, containing only the API.
+ * The public API provides two methods: `onRecord` and `onPredict`, which expect to receive the
+ * MotionEvents from the corresponding methods in MotionPredictor.
+ *
+ * This class stores AggregatedStrokeMetrics, updating them as new MotionEvents are passed in. When
+ * onRecord receives an UP or CANCEL event, this indicates the end of the stroke, and the final
+ * AtomFields are computed and reported to the stats library.
+ *
+ * If mMockLoggedAtomFields is set, the batch of AtomFields that are reported to the stats library
+ * for one stroke are also stored in mMockLoggedAtomFields at the time they're reported.
  */
 class MotionPredictorMetricsManager {
 public:
     // Note: the MetricsManager assumes that the input interval equals the prediction interval.
-    MotionPredictorMetricsManager(nsecs_t /*predictionInterval*/, size_t /*maxNumPredictions*/) {}
+    MotionPredictorMetricsManager(nsecs_t predictionInterval, size_t maxNumPredictions);
 
-    void onRecord(const MotionEvent& /*inputEvent*/) {}
+    // This method should be called once for each call to MotionPredictor::record, receiving the
+    // forwarded MotionEvent argument.
+    void onRecord(const MotionEvent& inputEvent);
 
-    void onPredict(const MotionEvent& /*predictionEvent*/) {}
+    // This method should be called once for each call to MotionPredictor::predict, receiving the
+    // MotionEvent that will be returned by MotionPredictor::predict.
+    void onPredict(const MotionEvent& predictionEvent);
+
+    // Simple structs to hold relevant touch input information. Public so they can be used in tests.
+
+    struct TouchPoint {
+        Eigen::Vector2f position; // (y, x) in pixels
+        float pressure;
+    };
+
+    struct GroundTruthPoint : TouchPoint {
+        nsecs_t timestamp;
+    };
+
+    struct PredictionPoint : TouchPoint {
+        // The timestamp of the last ground truth point when the prediction was made.
+        nsecs_t originTimestamp;
+
+        nsecs_t targetTimestamp;
+
+        // Order by targetTimestamp when sorting.
+        bool operator<(const PredictionPoint& other) const {
+            return this->targetTimestamp < other.targetTimestamp;
+        }
+    };
+
+    // Metrics aggregated so far for the current stroke. These are not the final fields to be
+    // reported in the atom (see AtomFields below), but rather an intermediate representation of the
+    // data that can be conveniently aggregated and from which the atom fields can be derived later.
+    //
+    // Displacement units are in pixels.
+    //
+    // "Along-trajectory error" is the dot product of the prediction error with the unit vector
+    // pointing towards the ground truth point whose timestamp corresponds to the prediction
+    // target timestamp, originating from the preceding ground truth point.
+    //
+    // "Off-trajectory error" is the component of the prediction error orthogonal to the
+    // "along-trajectory" unit vector described above.
+    //
+    // "High-velocity" errors are errors that are only accumulated when the velocity between the
+    // most recent two input events exceeds a certain threshold.
+    //
+    // "Scale-invariant errors" are the errors produced when the path length of the stroke is
+    // scaled to 1. (In other words, the error distances are normalized by the path length.)
+    struct AggregatedStrokeMetrics {
+        // General errors
+        float alongTrajectoryErrorSum = 0;
+        float alongTrajectorySumSquaredErrors = 0;
+        float offTrajectorySumSquaredErrors = 0;
+        float pressureSumSquaredErrors = 0;
+        size_t generalErrorsCount = 0;
+
+        // High-velocity errors
+        float highVelocityAlongTrajectorySse = 0;
+        float highVelocityOffTrajectorySse = 0;
+        size_t highVelocityErrorsCount = 0;
+
+        // Scale-invariant errors
+        float scaleInvariantAlongTrajectorySse = 0;
+        float scaleInvariantOffTrajectorySse = 0;
+        size_t scaleInvariantErrorsCount = 0;
+    };
+
+    // In order to explicitly indicate "no relevant data" for a metric, we report this
+    // large-magnitude negative sentinel value. (Most metrics are non-negative, so this value is
+    // completely unobtainable. For along-trajectory error mean, which can be negative, the
+    // magnitude makes it unobtainable in practice.)
+    // constexpr (implicitly inline) so ODR-uses need no out-of-line definition.
+    static constexpr int NO_DATA_SENTINEL = std::numeric_limits<int32_t>::min();
+
+    // Final metrics reported in the atom.
+    struct AtomFields {
+        int deltaTimeBucketMilliseconds = 0;
+
+        // General errors
+        int alongTrajectoryErrorMeanMillipixels = NO_DATA_SENTINEL;
+        int alongTrajectoryErrorStdMillipixels = NO_DATA_SENTINEL;
+        int offTrajectoryRmseMillipixels = NO_DATA_SENTINEL;
+        int pressureRmseMilliunits = NO_DATA_SENTINEL;
+
+        // High-velocity errors
+        int highVelocityAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
+        int highVelocityOffTrajectoryRmse = NO_DATA_SENTINEL;   // millipixels
+
+        // Scale-invariant errors
+        int scaleInvariantAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
+        int scaleInvariantOffTrajectoryRmse = NO_DATA_SENTINEL;   // millipixels
+    };
+
+    // Allow tests to pass in a mock AtomFields pointer.
+    //
+    // When metrics are reported to the stats library on stroke end, they will also be written to
+    // mockLoggedAtomFields, overwriting existing data. The size of mockLoggedAtomFields will equal
+    // the number of calls to stats_write for that stroke.
+    void setMockLoggedAtomFields(std::vector<AtomFields>* mockLoggedAtomFields) {
+        mMockLoggedAtomFields = mockLoggedAtomFields;
+    }
+
+private:
+    // The interval between consecutive predictions' target timestamps. We assume that the input
+    // interval also equals this value.
+    const nsecs_t mPredictionInterval;
+
+    // The maximum number of input frames into the future the model can predict.
+    // Used to perform time-bucketing of metrics.
+    const size_t mMaxNumPredictions;
+
+    // History of mMaxNumPredictions + 1 ground truth points, used to compute scale-invariant
+    // error. (Also, the last two points are used to compute the ground truth trajectory.)
+    RingBuffer<GroundTruthPoint> mRecentGroundTruthPoints;
+
+    // Predictions having a targetTimestamp after the most recent ground truth point's timestamp.
+    // Invariant: sorted in ascending order of targetTimestamp.
+    std::vector<PredictionPoint> mRecentPredictions;
+
+    // Containers for the intermediate representation of stroke metrics and the final atom fields.
+    // These are indexed by the number of input frames into the future being predicted minus one,
+    // and always have size mMaxNumPredictions.
+    std::vector<AggregatedStrokeMetrics> mAggregatedMetrics;
+    std::vector<AtomFields> mAtomFields;
+
+    // Non-owning pointer to the location of mock AtomFields. If present, will be filled with the
+    // values reported to stats_write on each batch of reported metrics.
+    //
+    // This pointer must remain valid as long as the MotionPredictorMetricsManager exists.
+    std::vector<AtomFields>* mMockLoggedAtomFields = nullptr;
+
+    // Helper methods for the implementation of onRecord and onPredict.
+
+    // Clears stored ground truth and prediction points, as well as all stored metrics for the
+    // current stroke.
+    void clearStrokeData();
+
+    // Adds the new ground truth point to mRecentGroundTruthPoints, removes outdated predictions
+    // from mRecentPredictions, and updates the aggregated metrics to include the recent
+    // predictions that fuzzily match with the new ground truth point.
+    void incorporateNewGroundTruth(const GroundTruthPoint& groundTruthPoint);
+
+    // Given a new prediction with targetTimestamp matching the latest ground truth point's
+    // timestamp, computes the corresponding metrics and updates mAggregatedMetrics.
+    void updateAggregatedMetrics(const PredictionPoint& predictionPoint);
+
+    // Computes the atom fields from the values in mAggregatedMetrics and stores them in mAtomFields.
+    void computeAtomFields();
+
+    // Reports the metrics given by the current data in mAtomFields:
+    //  • If on an Android device, reports the metrics to stats_write.
+    //  • If mMockLoggedAtomFields is present, it will be overwritten with logged metrics, with one
+    //    AtomFields element per call to stats_write.
+    void reportMetrics();
 };
 
 } // namespace android
diff --git a/include/input/VelocityTracker.h b/include/input/VelocityTracker.h
index da97c3e..4257cb5 100644
--- a/include/input/VelocityTracker.h
+++ b/include/input/VelocityTracker.h
@@ -45,6 +45,7 @@
         INT2 = 8,
         LEGACY = 9,
         MAX = LEGACY,
+        ftl_last = LEGACY,
     };
 
     struct Estimator {
@@ -95,8 +96,6 @@
     // TODO(b/32830165): support axis-specific strategies.
     VelocityTracker(const Strategy strategy = Strategy::DEFAULT);
 
-    ~VelocityTracker();
-
     /** Return true if the axis is supported for velocity tracking, false otherwise. */
     static bool isAxisSupported(int32_t axis);
 
diff --git a/libs/binder/IActivityManager.cpp b/libs/binder/IActivityManager.cpp
index f2b4a6e..7d6ae00 100644
--- a/libs/binder/IActivityManager.cpp
+++ b/libs/binder/IActivityManager.cpp
@@ -193,8 +193,7 @@
         status_t err = remote()->transact(LOG_FGS_API_BEGIN_TRANSACTION, data, &reply,
                                           IBinder::FLAG_ONEWAY);
         if (err != NO_ERROR || ((err = reply.readExceptionCode()) != NO_ERROR)) {
-            ALOGD("FGS Logger Transaction failed");
-            ALOGD("%d", err);
+            ALOGD("%s: FGS Logger Transaction failed, %d", __func__, err);
             return err;
         }
         return NO_ERROR;
@@ -209,8 +208,7 @@
         status_t err =
                 remote()->transact(LOG_FGS_API_END_TRANSACTION, data, &reply, IBinder::FLAG_ONEWAY);
         if (err != NO_ERROR || ((err = reply.readExceptionCode()) != NO_ERROR)) {
-            ALOGD("FGS Logger Transaction failed");
-            ALOGD("%d", err);
+            ALOGD("%s: FGS Logger Transaction failed, %d", __func__, err);
             return err;
         }
         return NO_ERROR;
@@ -224,11 +222,10 @@
         data.writeInt32(state);
         data.writeInt32(appUid);
         data.writeInt32(appPid);
-        status_t err = remote()->transact(LOG_FGS_API_BEGIN_TRANSACTION, data, &reply,
+        status_t err = remote()->transact(LOG_FGS_API_STATE_CHANGED_TRANSACTION, data, &reply,
                                           IBinder::FLAG_ONEWAY);
         if (err != NO_ERROR || ((err = reply.readExceptionCode()) != NO_ERROR)) {
-            ALOGD("FGS Logger Transaction failed");
-            ALOGD("%d", err);
+            ALOGD("%s: FGS Logger Transaction failed, %d", __func__, err);
             return err;
         }
         return NO_ERROR;
diff --git a/libs/input/Android.bp b/libs/input/Android.bp
index 8a17d8a..022dfad 100644
--- a/libs/input/Android.bp
+++ b/libs/input/Android.bp
@@ -185,6 +185,7 @@
         "KeyCharacterMap.cpp",
         "KeyLayoutMap.cpp",
         "MotionPredictor.cpp",
+        "MotionPredictorMetricsManager.cpp",
         "PrintTools.cpp",
         "PropertyMap.cpp",
         "TfLiteMotionPredictor.cpp",
@@ -198,9 +199,13 @@
     header_libs: [
         "flatbuffer_headers",
         "jni_headers",
+        "libeigen",
         "tensorflow_headers",
     ],
-    export_header_lib_headers: ["jni_headers"],
+    export_header_lib_headers: [
+        "jni_headers",
+        "libeigen",
+    ],
 
     generated_headers: [
         "cxx-bridge-header",
@@ -260,6 +265,10 @@
             shared_libs: [
                 "libutils",
                 "libbinder",
+                // Stats logging library and its dependencies.
+                "libstatslog_libinput",
+                "libstatsbootstrap",
+                "android.os.statsbootstrap_aidl-cpp",
             ],
 
             static_libs: [
@@ -311,6 +320,43 @@
     },
 }
 
+// Use the bootstrap version of the stats logging library.
+// libinput is used by bootstrap processes (which start early in boot), and thus can't use the
+// normal `libstatslog` because that requires `libstatssocket`, which is only available later.
+cc_library {
+    name: "libstatslog_libinput",
+    generated_sources: ["statslog_libinput.cpp"],
+    generated_headers: ["statslog_libinput.h"],
+    export_generated_headers: ["statslog_libinput.h"],
+    shared_libs: [
+        "libbinder",
+        "libstatsbootstrap",
+        "libutils",
+        "android.os.statsbootstrap_aidl-cpp",
+    ],
+}
+
+genrule {
+    name: "statslog_libinput.h",
+    tools: ["stats-log-api-gen"],
+    cmd: "$(location stats-log-api-gen) --header $(genDir)/statslog_libinput.h --module libinput" +
+        " --namespace android,stats,libinput --bootstrap",
+    out: [
+        "statslog_libinput.h",
+    ],
+}
+
+genrule {
+    name: "statslog_libinput.cpp",
+    tools: ["stats-log-api-gen"],
+    cmd: "$(location stats-log-api-gen) --cpp $(genDir)/statslog_libinput.cpp --module libinput" +
+        " --namespace android,stats,libinput --importHeader statslog_libinput.h" +
+        " --bootstrap",
+    out: [
+        "statslog_libinput.cpp",
+    ],
+}
+
 cc_defaults {
     name: "libinput_fuzz_defaults",
     cpp_std: "c++20",
diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp
index c2ea35c..f7ca5e7 100644
--- a/libs/input/MotionPredictor.cpp
+++ b/libs/input/MotionPredictor.cpp
@@ -137,10 +137,7 @@
 
     // Pass input event to the MetricsManager.
     if (!mMetricsManager) {
-        mMetricsManager =
-                std::make_optional<MotionPredictorMetricsManager>(mModel->config()
-                                                                          .predictionInterval,
-                                                                  mModel->outputLength());
+        mMetricsManager.emplace(mModel->config().predictionInterval, mModel->outputLength());
     }
     mMetricsManager->onRecord(event);
 
diff --git a/libs/input/MotionPredictorMetricsManager.cpp b/libs/input/MotionPredictorMetricsManager.cpp
new file mode 100644
index 0000000..67b1032
--- /dev/null
+++ b/libs/input/MotionPredictorMetricsManager.cpp
@@ -0,0 +1,373 @@
+/*
+ * Copyright 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "MotionPredictorMetricsManager"
+
+#include <input/MotionPredictorMetricsManager.h>
+
+#include <algorithm>
+
+#include <android-base/logging.h>
+
+#include "Eigen/Core"
+#include "Eigen/Geometry"
+
+#ifdef __ANDROID__
+#include <statslog_libinput.h>
+#endif
+
+namespace android {
+namespace {
+
+// Time unit conversions.
+constexpr int NANOS_PER_MILLIS = 1'000'000;
+constexpr int NANOS_PER_SECOND = 1'000 * NANOS_PER_MILLIS;
+
+// Speed, in pixels per second, above which errors are also aggregated into the "high-velocity"
+// metrics. Chosen experimentally to separate fast (semi-sloppy) handwriting from careful, medium
+// to slow handwriting.
+constexpr float HIGH_VELOCITY_THRESHOLD = 1100.0f;
+
+// Added to every path length used for scale-invariant errors so the divisor is never zero.
+constexpr float PATH_LENGTH_EPSILON = 0.001;
+
+} // namespace
+
+MotionPredictorMetricsManager::MotionPredictorMetricsManager(
+        nsecs_t predictionInterval, size_t maxNumPredictions)
+      : mPredictionInterval(predictionInterval),
+        mMaxNumPredictions(maxNumPredictions),
+        mRecentGroundTruthPoints(maxNumPredictions + 1), // One more point than there are buckets.
+        mAggregatedMetrics(maxNumPredictions),           // One entry per prediction time bucket.
+        mAtomFields(maxNumPredictions) {}                // One reported atom per time bucket.
+
+void MotionPredictorMetricsManager::onRecord(const MotionEvent& inputEvent) {
+    // Only pointer 0 is tracked. Positions are stored as (y, x).
+    const PointerCoords* coords = inputEvent.getRawPointerCoords(/*pointerIndex=*/0);
+    LOG_ALWAYS_FATAL_IF(coords == nullptr);
+    const Eigen::Vector2f position{coords->getY(), coords->getX()};
+    const float pressure = inputEvent.getPressure(/*pointerIndex=*/0);
+    const GroundTruthPoint groundTruthPoint{{.position = position, .pressure = pressure},
+                                            .timestamp = inputEvent.getEventTime()};
+
+    const int32_t action = inputEvent.getActionMasked();
+    if (action == AMOTION_EVENT_ACTION_DOWN) {
+        // A new stroke begins: discard all state from the previous one.
+        clearStrokeData();
+        incorporateNewGroundTruth(groundTruthPoint);
+        return;
+    }
+    if (action == AMOTION_EVENT_ACTION_MOVE) {
+        incorporateNewGroundTruth(groundTruthPoint);
+        return;
+    }
+    if (action != AMOTION_EVENT_ACTION_UP && action != AMOTION_EVENT_ACTION_CANCEL) {
+        return; // Other actions do not affect the metrics.
+    }
+
+    // The stroke has ended. Predictions are only meaningful given at least two ground truth
+    // points, so nothing is computed or reported for shorter strokes.
+    if (mRecentGroundTruthPoints.size() < 2) {
+        return;
+    }
+    computeAtomFields();
+    reportMetrics();
+}
+
+// Adds new predictions to mRecentPredictions and maintains the invariant that elements are
+// sorted in ascending order of targetTimestamp. Does nothing before any ground truth is recorded.
+void MotionPredictorMetricsManager::onPredict(const MotionEvent& predictionEvent) {
+    if (mRecentGroundTruthPoints.size() == 0) {
+        return; // No origin timestamp is available, and back() requires a non-empty buffer.
+    }
+    const nsecs_t originTimestamp = mRecentGroundTruthPoints.back().timestamp;
+    for (size_t i = 0; i < predictionEvent.getHistorySize() + 1; ++i) {
+        const PointerCoords* coords =
+                predictionEvent.getHistoricalRawPointerCoords(/*pointerIndex=*/0, i);
+        LOG_ALWAYS_FATAL_IF(coords == nullptr);
+        const float pressure = predictionEvent.getHistoricalPressure(/*pointerIndex=*/0, i);
+        mRecentPredictions.push_back(
+                PredictionPoint{{.position = Eigen::Vector2f{coords->getY(), coords->getX()},
+                                 .pressure = pressure},
+                                .originTimestamp = originTimestamp,
+                                .targetTimestamp = predictionEvent.getHistoricalEventTime(i)});
+    }
+    std::sort(mRecentPredictions.begin(), mRecentPredictions.end());
+}
+
+void MotionPredictorMetricsManager::clearStrokeData() {
+    mRecentGroundTruthPoints.clear(); // Forget the stroke's ground truth history...
+    mRecentPredictions.clear();       // ...and any predictions still awaiting a match.
+    for (auto& metrics : mAggregatedMetrics) metrics = AggregatedStrokeMetrics{};
+    for (auto& fields : mAtomFields) fields = AtomFields{};
+}
+
+void MotionPredictorMetricsManager::incorporateNewGroundTruth(
+        const GroundTruthPoint& groundTruthPoint) {
+    // The ring buffer evicts its oldest point when it is already full.
+    mRecentGroundTruthPoints.pushBack(groundTruthPoint);
+
+    // Ground truth and prediction timestamps may not line up exactly, so a prediction is
+    // associated with this ground truth point if its target time lies strictly inside the
+    // window (timestamp - delta, timestamp + delta).
+    const nsecs_t fuzzyDelta = mPredictionInterval / 4;
+    const nsecs_t windowStart = groundTruthPoint.timestamp - fuzzyDelta;
+    const nsecs_t windowEnd = groundTruthPoint.timestamp + fuzzyDelta;
+
+    // Predictions are sorted by targetTimestamp, so everything before the first prediction that
+    // is still inside or after the window is outdated: neither this nor any later ground truth
+    // point can match it.
+    const auto isCurrent = [windowStart](const PredictionPoint& prediction) {
+        return prediction.targetTimestamp > windowStart;
+    };
+    mRecentPredictions.erase(mRecentPredictions.begin(),
+                             std::find_if(mRecentPredictions.begin(), mRecentPredictions.end(),
+                                          isCurrent));
+
+    // Update the metrics with every prediction that falls inside the window.
+    for (const PredictionPoint& prediction : mRecentPredictions) {
+        if (prediction.targetTimestamp > windowStart && prediction.targetTimestamp < windowEnd) {
+            updateAggregatedMetrics(prediction);
+        }
+    }
+}
+
+// Folds one prediction, matched to the newest ground truth point, into its time bucket.
+void MotionPredictorMetricsManager::updateAggregatedMetrics(
+        const PredictionPoint& predictionPoint) {
+    if (mRecentGroundTruthPoints.size() < 2) {
+        return;
+    }
+
+    const GroundTruthPoint& latestGroundTruthPoint = mRecentGroundTruthPoints.back();
+    const GroundTruthPoint& previousGroundTruthPoint =
+            mRecentGroundTruthPoints[mRecentGroundTruthPoints.size() - 2];
+    // Calculate prediction error vector.
+    const Eigen::Vector2f groundTruthTrajectory =
+            latestGroundTruthPoint.position - previousGroundTruthPoint.position;
+    const Eigen::Vector2f predictionTrajectory =
+            predictionPoint.position - previousGroundTruthPoint.position;
+    const Eigen::Vector2f predictionError = predictionTrajectory - groundTruthTrajectory;
+
+    // By default, prediction error counts fully as both off-trajectory and along-trajectory error.
+    // This serves as the fallback when the two most recent ground truth points are equal.
+    const float predictionErrorNorm = predictionError.norm();
+    float alongTrajectoryError = predictionErrorNorm;
+    float offTrajectoryError = predictionErrorNorm;
+    if (groundTruthTrajectory.squaredNorm() > 0) {
+        // Rotate the prediction error vector by minus the angle of the ground truth trajectory:
+        // component 0 is then the along-trajectory error and component 1 the off-trajectory error.
+        const float theta = std::atan2(groundTruthTrajectory[1], groundTruthTrajectory[0]);
+        const Eigen::Vector2f rotatedPredictionError = Eigen::Rotation2Df(-theta) * predictionError;
+        alongTrajectoryError = rotatedPredictionError[0];
+        offTrajectoryError = rotatedPredictionError[1];
+    }
+
+    // Time bucket index: the multiple of mPredictionInterval nearest to how far into the future
+    // the prediction reaches, minus one. Skip predictions outside the tracked buckets (or NaN):
+    // casting a negative float to size_t is undefined behavior, and a too-large index would
+    // write past the end of mAggregatedMetrics.
+    const float timestampDeltaFloat =
+            static_cast<float>(predictionPoint.targetTimestamp - predictionPoint.originTimestamp);
+    const float bucket = std::round(timestampDeltaFloat / mPredictionInterval - 1);
+    if (!(bucket >= 0 && bucket < static_cast<float>(mAggregatedMetrics.size()))) return;
+    const size_t tIndex = static_cast<size_t>(bucket);
+
+    // Aggregate values into "general errors".
+    mAggregatedMetrics[tIndex].alongTrajectoryErrorSum += alongTrajectoryError;
+    mAggregatedMetrics[tIndex].alongTrajectorySumSquaredErrors +=
+            alongTrajectoryError * alongTrajectoryError;
+    mAggregatedMetrics[tIndex].offTrajectorySumSquaredErrors +=
+            offTrajectoryError * offTrajectoryError;
+    const float pressureError = predictionPoint.pressure - latestGroundTruthPoint.pressure;
+    mAggregatedMetrics[tIndex].pressureSumSquaredErrors += pressureError * pressureError;
+    ++mAggregatedMetrics[tIndex].generalErrorsCount;
+
+    // Aggregate values into high-velocity metrics, if we are in one of the last two time buckets
+    // and the velocity is above the threshold. Velocity here is measured in pixels per second.
+    const nsecs_t elapsed = latestGroundTruthPoint.timestamp - previousGroundTruthPoint.timestamp;
+    const float velocity =
+            groundTruthTrajectory.norm() / (static_cast<float>(elapsed) / NANOS_PER_SECOND);
+    if ((tIndex + 2 >= mMaxNumPredictions) && (velocity > HIGH_VELOCITY_THRESHOLD)) {
+        mAggregatedMetrics[tIndex].highVelocityAlongTrajectorySse +=
+                alongTrajectoryError * alongTrajectoryError;
+        mAggregatedMetrics[tIndex].highVelocityOffTrajectorySse +=
+                offTrajectoryError * offTrajectoryError;
+        ++mAggregatedMetrics[tIndex].highVelocityErrorsCount;
+    }
+
+    // Compute path length for scale-invariant errors.
+    float pathLength = 0;
+    for (size_t i = 1; i < mRecentGroundTruthPoints.size(); ++i) {
+        pathLength += (mRecentGroundTruthPoints[i].position -
+                       mRecentGroundTruthPoints[i - 1].position).norm();
+    }
+    // Avoid overweighting early-stroke errors: scale the path length up to a full history by
+    // filling missing segments with the average length ("- 1": endpoint -> segment counts).
+    pathLength *= static_cast<float>(mRecentGroundTruthPoints.capacity() - 1) /
+            (mRecentGroundTruthPoints.size() - 1);
+    pathLength += PATH_LENGTH_EPSILON; // Ensure path length is nonzero (>= PATH_LENGTH_EPSILON).
+
+    // Compute and aggregate scale-invariant errors.
+    const float scaleInvariantAlongTrajectoryError = alongTrajectoryError / pathLength;
+    const float scaleInvariantOffTrajectoryError = offTrajectoryError / pathLength;
+    mAggregatedMetrics[tIndex].scaleInvariantAlongTrajectorySse +=
+            scaleInvariantAlongTrajectoryError * scaleInvariantAlongTrajectoryError;
+    mAggregatedMetrics[tIndex].scaleInvariantOffTrajectorySse +=
+            scaleInvariantOffTrajectoryError * scaleInvariantOffTrajectoryError;
+    ++mAggregatedMetrics[tIndex].scaleInvariantErrorsCount;
+}
+
+// Converts the per-bucket aggregates into the integral milli-unit values reported in atoms.
+void MotionPredictorMetricsManager::computeAtomFields() {
+    for (size_t i = 0; i < mAggregatedMetrics.size(); ++i) {
+        if (mAggregatedMetrics[i].generalErrorsCount == 0) {
+            // We have not received data corresponding to metrics for this time bucket.
+            continue;
+        }
+
+        mAtomFields[i].deltaTimeBucketMilliseconds =
+                static_cast<int>(mPredictionInterval / NANOS_PER_MILLIS * (i + 1));
+
+        // Note: we need the "* 1000"s below because we report values in integral milli-units.
+
+        { // General errors: reported for every time bucket.
+            const float alongTrajectoryErrorMean = mAggregatedMetrics[i].alongTrajectoryErrorSum /
+                    mAggregatedMetrics[i].generalErrorsCount;
+            mAtomFields[i].alongTrajectoryErrorMeanMillipixels =
+                    static_cast<int>(alongTrajectoryErrorMean * 1000);
+
+            const float alongTrajectoryMse = mAggregatedMetrics[i].alongTrajectorySumSquaredErrors /
+                    mAggregatedMetrics[i].generalErrorsCount;
+            // Take the max with 0 to avoid negative values caused by numerical instability.
+            const float alongTrajectoryErrorVariance =
+                    std::max(0.0f,
+                             alongTrajectoryMse -
+                                     alongTrajectoryErrorMean * alongTrajectoryErrorMean);
+            const float alongTrajectoryErrorStd = std::sqrt(alongTrajectoryErrorVariance);
+            mAtomFields[i].alongTrajectoryErrorStdMillipixels =
+                    static_cast<int>(alongTrajectoryErrorStd * 1000);
+
+            LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[i].offTrajectorySumSquaredErrors < 0,
+                                "mAggregatedMetrics[%zu].offTrajectorySumSquaredErrors = %f should "
+                                "not be negative",
+                                i, mAggregatedMetrics[i].offTrajectorySumSquaredErrors);
+            const float offTrajectoryRmse =
+                    std::sqrt(mAggregatedMetrics[i].offTrajectorySumSquaredErrors /
+                              mAggregatedMetrics[i].generalErrorsCount);
+            mAtomFields[i].offTrajectoryRmseMillipixels =
+                    static_cast<int>(offTrajectoryRmse * 1000);
+
+            LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[i].pressureSumSquaredErrors < 0,
+                                "mAggregatedMetrics[%zu].pressureSumSquaredErrors = %f should not "
+                                "be negative",
+                                i, mAggregatedMetrics[i].pressureSumSquaredErrors);
+            const float pressureRmse = std::sqrt(mAggregatedMetrics[i].pressureSumSquaredErrors /
+                                                 mAggregatedMetrics[i].generalErrorsCount);
+            mAtomFields[i].pressureRmseMilliunits = static_cast<int>(pressureRmse * 1000);
+        }
+
+        // High-velocity errors: reported only for the last two time buckets, if they have data.
+        if ((i + 2 >= mMaxNumPredictions) && (mAggregatedMetrics[i].highVelocityErrorsCount > 0)) {
+            LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[i].highVelocityAlongTrajectorySse < 0,
+                                "mAggregatedMetrics[%zu].highVelocityAlongTrajectorySse = %f "
+                                "should not be negative",
+                                i, mAggregatedMetrics[i].highVelocityAlongTrajectorySse);
+            const float alongTrajectoryRmse =
+                    std::sqrt(mAggregatedMetrics[i].highVelocityAlongTrajectorySse /
+                              mAggregatedMetrics[i].highVelocityErrorsCount);
+            mAtomFields[i].highVelocityAlongTrajectoryRmse =
+                    static_cast<int>(alongTrajectoryRmse * 1000);
+
+            LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[i].highVelocityOffTrajectorySse < 0,
+                                "mAggregatedMetrics[%zu].highVelocityOffTrajectorySse = %f should "
+                                "not be negative",
+                                i, mAggregatedMetrics[i].highVelocityOffTrajectorySse);
+            const float offTrajectoryRmse =
+                    std::sqrt(mAggregatedMetrics[i].highVelocityOffTrajectorySse /
+                              mAggregatedMetrics[i].highVelocityErrorsCount);
+            mAtomFields[i].highVelocityOffTrajectoryRmse =
+                    static_cast<int>(offTrajectoryRmse * 1000);
+        }
+
+        // Scale-invariant errors: reported only for the last time bucket, as an average over all
+        // time buckets that received data. Empty buckets are skipped: ground truth can match the
+        // last bucket's predictions without matching those of every earlier bucket.
+        if (i + 1 == mMaxNumPredictions) {
+            float alongTrajectoryRmseSum = 0;
+            float offTrajectoryRmseSum = 0;
+            size_t bucketsWithData = 0;
+            for (size_t j = 0; j < mAggregatedMetrics.size(); ++j) {
+                if (mAggregatedMetrics[j].scaleInvariantErrorsCount == 0) {
+                    continue;
+                }
+                ++bucketsWithData;
+                LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[j].scaleInvariantAlongTrajectorySse < 0,
+                                    "mAggregatedMetrics[%zu].scaleInvariantAlongTrajectorySse = %f "
+                                    "should not be negative",
+                                    j, mAggregatedMetrics[j].scaleInvariantAlongTrajectorySse);
+                alongTrajectoryRmseSum +=
+                        std::sqrt(mAggregatedMetrics[j].scaleInvariantAlongTrajectorySse /
+                                  mAggregatedMetrics[j].scaleInvariantErrorsCount);
+
+                LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[j].scaleInvariantOffTrajectorySse < 0,
+                                    "mAggregatedMetrics[%zu].scaleInvariantOffTrajectorySse = %f "
+                                    "should not be negative",
+                                    j, mAggregatedMetrics[j].scaleInvariantOffTrajectorySse);
+                offTrajectoryRmseSum +=
+                        std::sqrt(mAggregatedMetrics[j].scaleInvariantOffTrajectorySse /
+                                  mAggregatedMetrics[j].scaleInvariantErrorsCount);
+            }
+            if (bucketsWithData == 0) continue; // Defensive: bucket i itself has data.
+
+            const float averageAlongTrajectoryRmse = alongTrajectoryRmseSum / bucketsWithData;
+            mAtomFields.back().scaleInvariantAlongTrajectoryRmse =
+                    static_cast<int>(averageAlongTrajectoryRmse * 1000);
+
+            const float averageOffTrajectoryRmse = offTrajectoryRmseSum / bucketsWithData;
+            mAtomFields.back().scaleInvariantOffTrajectoryRmse =
+                    static_cast<int>(averageOffTrajectoryRmse * 1000);
+        }
+    }
+}
+
+void MotionPredictorMetricsManager::reportMetrics() {
+    // stats_write is only called on Android targets (not supported on host).
+#ifdef __ANDROID__
+    // Log one atom per time bucket; stylus vendor and product IDs are always reported as 0.
+    for (const AtomFields& fields : mAtomFields) {
+        android::stats::libinput::
+                stats_write(android::stats::libinput::STYLUS_PREDICTION_METRICS_REPORTED,
+                            /*stylus_vendor_id=*/0,
+                            /*stylus_product_id=*/0, fields.deltaTimeBucketMilliseconds,
+                            fields.alongTrajectoryErrorMeanMillipixels,
+                            fields.alongTrajectoryErrorStdMillipixels,
+                            fields.offTrajectoryRmseMillipixels, fields.pressureRmseMilliunits,
+                            fields.highVelocityAlongTrajectoryRmse,
+                            fields.highVelocityOffTrajectoryRmse,
+                            fields.scaleInvariantAlongTrajectoryRmse,
+                            fields.scaleInvariantOffTrajectoryRmse);
+    }
+#endif
+
+    // Copy the computed fields into the mock sink, if one has been installed.
+    if (mMockLoggedAtomFields == nullptr) {
+        return;
+    }
+    *mMockLoggedAtomFields = mAtomFields;
+}
+
+} // namespace android
diff --git a/libs/input/VelocityTracker.cpp b/libs/input/VelocityTracker.cpp
index 8551e5f..078109a 100644
--- a/libs/input/VelocityTracker.cpp
+++ b/libs/input/VelocityTracker.cpp
@@ -16,7 +16,9 @@
 
 #define LOG_TAG "VelocityTracker"
 
+#include <android-base/logging.h>
 #include <array>
+#include <ftl/enum.h>
 #include <inttypes.h>
 #include <limits.h>
 #include <math.h>
@@ -145,27 +147,19 @@
 VelocityTracker::VelocityTracker(const Strategy strategy)
       : mLastEventTime(0), mCurrentPointerIdBits(0), mOverrideStrategy(strategy) {}
 
-VelocityTracker::~VelocityTracker() {
-}
-
 bool VelocityTracker::isAxisSupported(int32_t axis) {
     return DEFAULT_STRATEGY_BY_AXIS.find(axis) != DEFAULT_STRATEGY_BY_AXIS.end();
 }
 
 void VelocityTracker::configureStrategy(int32_t axis) {
     const bool isDifferentialAxis = DIFFERENTIAL_AXES.find(axis) != DIFFERENTIAL_AXES.end();
-
-    std::unique_ptr<VelocityTrackerStrategy> createdStrategy;
-    if (mOverrideStrategy != VelocityTracker::Strategy::DEFAULT) {
-        createdStrategy = createStrategy(mOverrideStrategy, /*deltaValues=*/isDifferentialAxis);
+    if (isDifferentialAxis || mOverrideStrategy == VelocityTracker::Strategy::DEFAULT) {
+        // Per-axis default strategy; overrides are not allowed for differential axes, for now.
+        mConfiguredStrategies[axis] = createStrategy(DEFAULT_STRATEGY_BY_AXIS.at(axis),
+                                                     /*deltaValues=*/isDifferentialAxis);
     } else {
-        createdStrategy = createStrategy(DEFAULT_STRATEGY_BY_AXIS.at(axis),
-                                         /*deltaValues=*/isDifferentialAxis);
+        mConfiguredStrategies[axis] = createStrategy(mOverrideStrategy, /*deltaValues=*/false);
     }
-
-    LOG_ALWAYS_FATAL_IF(createdStrategy == nullptr,
-                        "Could not create velocity tracker strategy for axis '%" PRId32 "'!", axis);
-    mConfiguredStrategies[axis] = std::move(createdStrategy);
 }
 
 std::unique_ptr<VelocityTrackerStrategy> VelocityTracker::createStrategy(
@@ -213,6 +207,9 @@
         default:
             break;
     }
+    LOG(FATAL) << "Invalid strategy: " << ftl::enum_string(strategy)
+               << ", deltaValues = " << deltaValues;
+
     return nullptr;
 }
 
diff --git a/libs/input/input_verifier.rs b/libs/input/input_verifier.rs
index 2e05a63..dd2ac4c 100644
--- a/libs/input/input_verifier.rs
+++ b/libs/input/input_verifier.rs
@@ -32,6 +32,7 @@
 use log::info;
 
 #[cxx::bridge(namespace = "android::input")]
+#[allow(unsafe_op_in_unsafe_fn)]
 mod ffi {
     #[namespace = "android"]
     unsafe extern "C++" {
diff --git a/libs/input/tests/Android.bp b/libs/input/tests/Android.bp
index 86b996b..e7224ff 100644
--- a/libs/input/tests/Android.bp
+++ b/libs/input/tests/Android.bp
@@ -20,6 +20,7 @@
         "InputPublisherAndConsumer_test.cpp",
         "InputVerifier_test.cpp",
         "MotionPredictor_test.cpp",
+        "MotionPredictorMetricsManager_test.cpp",
         "RingBuffer_test.cpp",
         "TfLiteMotionPredictor_test.cpp",
         "TouchResampling_test.cpp",
@@ -52,13 +53,6 @@
             undefined: true,
         },
     },
-    target: {
-        host: {
-            sanitize: {
-                address: true,
-            },
-        },
-    },
     shared_libs: [
         "libbase",
         "libbinder",
@@ -77,6 +71,21 @@
         unit_test: true,
     },
     test_suites: ["device-tests"],
+    target: {
+        host: {
+            sanitize: {
+                address: true,
+            },
+        },
+        android: {
+            static_libs: [
+                // Stats logging library and its dependencies.
+                "libstatslog_libinput",
+                "libstatsbootstrap",
+                "android.os.statsbootstrap_aidl-cpp",
+            ],
+        },
+    },
 }
 
 // NOTE: This is a compile time test, and does not need to be
diff --git a/libs/input/tests/MotionPredictorMetricsManager_test.cpp b/libs/input/tests/MotionPredictorMetricsManager_test.cpp
new file mode 100644
index 0000000..b420a5a
--- /dev/null
+++ b/libs/input/tests/MotionPredictorMetricsManager_test.cpp
@@ -0,0 +1,972 @@
+/*
+ * Copyright 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <input/MotionPredictor.h>
+
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <numeric>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <input/InputEventBuilders.h>
+#include <utils/Timers.h> // for nsecs_t
+
+#include "Eigen/Core"
+#include "Eigen/Geometry"
+
+namespace android {
+namespace {
+
+using ::testing::FloatNear;
+using ::testing::Matches;
+using AtomFields = MotionPredictorMetricsManager::AtomFields;
+using GroundTruthPoint = MotionPredictorMetricsManager::GroundTruthPoint;
+using PredictionPoint = MotionPredictorMetricsManager::PredictionPoint;
+
+constexpr int NANOS_PER_MILLIS = 1'000'000;
+
+// Test configuration: initial timestamp of 1 s, up to 5 predictions spaced 1 / (240 Hz) apart.
+constexpr nsecs_t TEST_INITIAL_TIMESTAMP = 1'000'000'000;
+constexpr size_t TEST_MAX_NUM_PREDICTIONS = 5;
+constexpr nsecs_t TEST_PREDICTION_INTERVAL_NANOS = 1'000'000'000 / 240;
+constexpr int NO_DATA_SENTINEL = MotionPredictorMetricsManager::NO_DATA_SENTINEL;
+
+// Matches an Eigen::Vector2f `arg` whose components are each within `epsilon` of the
+// corresponding components of the Eigen::Vector2f `target`.
+MATCHER_P2(Vector2fNear, target, epsilon, "") {
+    for (int component = 0; component < 2; ++component) {
+        if (!Matches(FloatNear(target[component], epsilon))(arg[component])) return false;
+    }
+    return true;
+}
+
+// Matches a PredictionPoint `arg` against `target`: position and pressure must be within
+// `epsilon`, and both timestamps must be equal. Otherwise the first mismatching field is
+// described to the result listener.
+MATCHER_P2(PredictionPointNear, target, epsilon, "") {
+    const auto& actual = arg;
+    if (!Matches(Vector2fNear(target.position, epsilon))(actual.position)) {
+        *result_listener << "Position mismatch. Actual: (" << actual.position[0] << ", "
+                         << actual.position[1] << "), expected: (" << target.position[0] << ", "
+                         << target.position[1] << ")";
+        return false;
+    }
+    if (!Matches(FloatNear(target.pressure, epsilon))(actual.pressure)) {
+        *result_listener << "Pressure mismatch. Actual: " << actual.pressure
+                         << ", expected: " << target.pressure;
+        return false;
+    }
+    if (actual.originTimestamp != target.originTimestamp) {
+        *result_listener << "Origin timestamp mismatch. Actual: " << actual.originTimestamp
+                         << ", expected: " << target.originTimestamp;
+        return false;
+    }
+    if (actual.targetTimestamp != target.targetTimestamp) {
+        *result_listener << "Target timestamp mismatch. Actual: " << actual.targetTimestamp
+                         << ", expected: " << target.targetTimestamp;
+        return false;
+    }
+    return true;
+}
+
+// --- Mathematical helper functions. ---
+// Arithmetic mean of `values`.
+template <typename T>
+T average(std::vector<T> values) {
+    return std::accumulate(values.cbegin(), values.cend(), T{}) / static_cast<T>(values.size());
+}
+
+// Population standard deviation of `values`: sqrt(max(0, sum(x^2 - mean^2)) / n).
+template <typename T>
+T standardDeviation(std::vector<T> values) {
+    const T mean = average(values);
+    const T sumOfDifferences = std::accumulate(
+            values.cbegin(), values.cend(), T{},
+            [mean](T acc, T value) { return acc + (value * value - mean * mean); });
+    // Clamp at 0: rounding can make the sum slightly negative.
+    return std::sqrt(std::max(T{}, sumOfDifferences) / static_cast<T>(values.size()));
+}
+
+// Root-mean-square of `errors`.
+template <typename T>
+T rmse(std::vector<T> errors) {
+    // inner_product(e, e) accumulates e[0]*e[0] + e[1]*e[1] + ... in order, starting from T{}.
+    const T sumOfSquares =
+            std::inner_product(errors.cbegin(), errors.cend(), errors.cbegin(), T{});
+    return std::sqrt(sumOfSquares / static_cast<T>(errors.size()));
+}
+
+TEST(MathematicalHelperFunctionTest, Average) {
+    const std::vector<float> oneToTen{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; // Sum 55 over 10 values.
+    EXPECT_EQ(5.5f, average(oneToTen));
+}
+
+TEST(MathematicalHelperFunctionTest, StandardDeviation) {
+    // Population standard deviation, cross-checked with https://www.calculator.net/standard-deviation-calculator.html?numberinputs=10%2C+12%2C+23%2C+23%2C+16%2C+23%2C+21%2C+16
+    const std::vector<float> samples{10, 12, 23, 23, 16, 23, 21, 16};
+    EXPECT_FLOAT_EQ(4.8989794855664f, standardDeviation(samples));
+}
+
+TEST(MathematicalHelperFunctionTest, Rmse) {
+    const std::vector<float> errors{1, 5, 7, 7, 8, 20}; // Squares sum to 588; 588 / 6 = 98.
+    EXPECT_FLOAT_EQ(9.899494937f, rmse(errors));
+}
+
+// --- MotionEvent-related helper functions. ---
+
+// Creates a stylus ACTION_MOVE MotionEvent with a single sample matching groundTruthPoint.
+// Positions are stored as (y, x): position[1] is the x coordinate, position[0] the y coordinate.
+MotionEvent makeMotionEvent(const GroundTruthPoint& groundTruthPoint) {
+    const float pointerX = groundTruthPoint.position[1];
+    const float pointerY = groundTruthPoint.position[0];
+    PointerBuilder stylus(/*id=*/0, ToolType::STYLUS);
+    stylus.x(pointerX).y(pointerY).axis(AMOTION_EVENT_AXIS_PRESSURE, groundTruthPoint.pressure);
+
+    MotionEventBuilder eventBuilder(/*action=*/AMOTION_EVENT_ACTION_MOVE,
+                                    /*source=*/AINPUT_SOURCE_CLASS_POINTER);
+    eventBuilder.eventTime(groundTruthPoint.timestamp);
+    eventBuilder.pointer(stylus);
+    return eventBuilder.build();
+}
+
+// Creates a MotionEvent corresponding to the given sequence of PredictionPoints: a stylus
+// ACTION_MOVE event whose samples are the prediction points in order, each at its
+// targetTimestamp. `predictionPoints` must not be empty.
+MotionEvent makeMotionEvent(const std::vector<PredictionPoint>& predictionPoints) {
+    const PredictionPoint& first = predictionPoints.front();
+    MotionEvent predictionEvent =
+            MotionEventBuilder(/*action=*/AMOTION_EVENT_ACTION_MOVE,
+                               /*source=*/AINPUT_SOURCE_CLASS_POINTER)
+                    .eventTime(first.targetTimestamp)
+                    .pointer(PointerBuilder(/*id=*/0, ToolType::STYLUS)
+                                     .x(first.position[1])
+                                     .y(first.position[0])
+                                     .axis(AMOTION_EVENT_AXIS_PRESSURE, first.pressure))
+                    .build();
+    // Every further point becomes an additional sample; positions are stored as (y, x).
+    for (auto it = predictionPoints.begin() + 1; it != predictionPoints.end(); ++it) {
+        PointerCoords coords = PointerBuilder(/*id=*/0, ToolType::STYLUS)
+                                       .x(it->position[1])
+                                       .y(it->position[0])
+                                       .axis(AMOTION_EVENT_AXIS_PRESSURE, it->pressure)
+                                       .buildCoords();
+        predictionEvent.addSample(it->targetTimestamp, &coords);
+    }
+    return predictionEvent;
+}
+
+// Creates a stylus ACTION_UP MotionEvent marking the end of a stroke. No coordinates or event
+// time are set explicitly, so the builders' defaults apply.
+MotionEvent makeLiftMotionEvent() {
+    MotionEventBuilder liftBuilder(/*action=*/AMOTION_EVENT_ACTION_UP,
+                                   /*source=*/AINPUT_SOURCE_CLASS_POINTER);
+    return liftBuilder.pointer(PointerBuilder(/*id=*/0, ToolType::STYLUS)).build();
+}
+
+TEST(MakeMotionEventTest, MakeGroundTruthMotionEvent) {
+    const GroundTruthPoint groundTruthPoint{{.position = Eigen::Vector2f(10.0f, 20.0f),
+                                             .pressure = 0.6f},
+                                            .timestamp = TEST_INITIAL_TIMESTAMP};
+    const MotionEvent event = makeMotionEvent(groundTruthPoint);
+
+    // One pointer with a single sample: "history size" counts the samples beyond the first.
+    ASSERT_EQ(1u, event.getPointerCount());
+    ASSERT_EQ(0u, event.getHistorySize());
+    // Positions are stored as (y, x).
+    const PointerCoords* coords = event.getRawPointerCoords(0);
+    EXPECT_EQ(groundTruthPoint.position[0], coords->getY());
+    EXPECT_EQ(groundTruthPoint.position[1], coords->getX());
+    EXPECT_EQ(groundTruthPoint.pressure, coords->getAxisValue(AMOTION_EVENT_AXIS_PRESSURE));
+    EXPECT_EQ(AMOTION_EVENT_ACTION_MOVE, event.getAction());
+}
+
+TEST(MakeMotionEventTest, MakePredictionMotionEvent) {
+    const nsecs_t originTimestamp = TEST_INITIAL_TIMESTAMP;
+    const std::vector<PredictionPoint>
+            predictionPoints{{{.position = Eigen::Vector2f(10.0f, 20.0f), .pressure = 0.6f},
+                              .originTimestamp = originTimestamp,
+                              .targetTimestamp = originTimestamp + 5 * NANOS_PER_MILLIS},
+                             {{.position = Eigen::Vector2f(11.0f, 22.0f), .pressure = 0.5f},
+                              .originTimestamp = originTimestamp,
+                              .targetTimestamp = originTimestamp + 10 * NANOS_PER_MILLIS},
+                             {{.position = Eigen::Vector2f(12.0f, 24.0f), .pressure = 0.4f},
+                              .originTimestamp = originTimestamp,
+                              .targetTimestamp = originTimestamp + 15 * NANOS_PER_MILLIS}};
+    const MotionEvent predictionMotionEvent = makeMotionEvent(predictionPoints);
+
+    ASSERT_EQ(1u, predictionMotionEvent.getPointerCount());
+    // Note: a MotionEvent's "history size" is one less than its number of samples, so this checks
+    // that every prediction point became exactly one sample.
+    ASSERT_EQ(predictionPoints.size(), predictionMotionEvent.getHistorySize() + 1);
+    // The action is a property of the whole event, not of individual samples, so check it once
+    // rather than on every loop iteration.
+    EXPECT_EQ(AMOTION_EVENT_ACTION_MOVE, predictionMotionEvent.getAction());
+    for (size_t i = 0; i < predictionPoints.size(); ++i) {
+        SCOPED_TRACE(testing::Message() << "i = " << i);
+        // Bind by reference: the coords live inside predictionMotionEvent, no copy is needed.
+        const PointerCoords& coords = *predictionMotionEvent.getHistoricalRawPointerCoords(
+                /*pointerIndex=*/0, /*historicalIndex=*/i);
+        EXPECT_EQ(predictionPoints[i].position[0], coords.getY());
+        EXPECT_EQ(predictionPoints[i].position[1], coords.getX());
+        EXPECT_EQ(predictionPoints[i].pressure, coords.getAxisValue(AMOTION_EVENT_AXIS_PRESSURE));
+        // Note: originTimestamp is discarded when converting PredictionPoint to MotionEvent.
+        EXPECT_EQ(predictionPoints[i].targetTimestamp,
+                  predictionMotionEvent.getHistoricalEventTime(i));
+    }
+}
+
+TEST(MakeMotionEventTest, MakeLiftMotionEvent) {
+    const MotionEvent lift = makeLiftMotionEvent();
+
+    // A lift is a single-pointer, single-sample event (history size == sample count - 1).
+    ASSERT_EQ(1u, lift.getPointerCount());
+    ASSERT_EQ(0u, lift.getHistorySize());
+    EXPECT_EQ(AMOTION_EVENT_ACTION_UP, lift.getAction());
+}
+
+// --- Ground-truth-generation helper functions. ---
+
+// Returns numPoints copies of groundTruthPoint with unchanged position and pressure. The first copy
+// keeps groundTruthPoint.timestamp, and each later copy is TEST_PREDICTION_INTERVAL_NANOS after
+// the previous one.
+std::vector<GroundTruthPoint> generateConstantGroundTruthPoints(
+        const GroundTruthPoint& groundTruthPoint, size_t numPoints) {
+    std::vector<GroundTruthPoint> groundTruthPoints;
+    // The final size is known up front, so allocate once instead of growing incrementally.
+    groundTruthPoints.reserve(numPoints);
+    nsecs_t timestamp = groundTruthPoint.timestamp;
+    for (size_t i = 0; i < numPoints; ++i) {
+        groundTruthPoints.emplace_back(groundTruthPoint);
+        groundTruthPoints.back().timestamp = timestamp;
+        timestamp += TEST_PREDICTION_INTERVAL_NANOS;
+    }
+    return groundTruthPoints;
+}
+
+// Generates numPoints ground truth points that travel along a circular arc (a straight line when
+// turningAngle is 0). Every point has pressure 0. Timestamps start at TEST_INITIAL_TIMESTAMP and
+// increase by TEST_PREDICTION_INTERVAL_NANOS per point.
+//
+// Parameters:
+//  • initialPosition: position of the first point, in (y, x) order.
+//  • initialAngle: direction of travel from the first point to the second, in radians.
+//  • velocity: distance travelled between consecutive points.
+//  • turningAngle: change in direction of travel applied after each step, in radians.
+//  • numPoints: number of points to generate; 0 yields an empty vector.
+//
+// This function uses the coordinate system (y, x), with +y pointing downwards and +x pointing
+// rightwards. Angles are measured counterclockwise from down (+y).
+std::vector<GroundTruthPoint> generateCircularArcGroundTruthPoints(Eigen::Vector2f initialPosition,
+                                                                   float initialAngle,
+                                                                   float velocity,
+                                                                   float turningAngle,
+                                                                   size_t numPoints) {
+    std::vector<GroundTruthPoint> groundTruthPoints;
+    if (numPoints == 0) {
+        return groundTruthPoints;
+    }
+    groundTruthPoints.reserve(numPoints);
+    // Create first point.
+    groundTruthPoints.push_back({{.position = initialPosition, .pressure = 0.0f},
+                                 .timestamp = TEST_INITIAL_TIMESTAMP});
+    float trajectoryAngle = initialAngle; // measured counterclockwise from +y axis.
+    for (size_t i = 1; i < numPoints; ++i) {
+        // Unit vector in the current direction of travel: (cos θ, sin θ) in (y, x) order.
+        const Eigen::Vector2f trajectory =
+                Eigen::Rotation2D(trajectoryAngle) * Eigen::Vector2f(1, 0);
+        groundTruthPoints.push_back(
+                {{.position = groundTruthPoints.back().position + velocity * trajectory,
+                  .pressure = 0.0f},
+                 .timestamp = groundTruthPoints.back().timestamp + TEST_PREDICTION_INTERVAL_NANOS});
+        trajectoryAngle += turningAngle;
+    }
+    return groundTruthPoints;
+}
+
+TEST(GenerateConstantGroundTruthPointsTest, BasicTest) {
+    const GroundTruthPoint groundTruthPoint{{.position = Eigen::Vector2f(10, 20), .pressure = 0.3f},
+                                            .timestamp = TEST_INITIAL_TIMESTAMP};
+    const std::vector<GroundTruthPoint> groundTruthPoints =
+            generateConstantGroundTruthPoints(groundTruthPoint, /*numPoints=*/3);
+
+    ASSERT_EQ(3u, groundTruthPoints.size());
+    for (size_t i = 0; i < groundTruthPoints.size(); ++i) {
+        SCOPED_TRACE(testing::Message() << "i = " << i);
+        // Position and pressure are copied unchanged into every generated point.
+        EXPECT_EQ(groundTruthPoints[i].position, groundTruthPoint.position);
+        EXPECT_EQ(groundTruthPoints[i].pressure, groundTruthPoint.pressure);
+        // The first point keeps the template timestamp; later timestamps strictly increase.
+        if (i == 0) {
+            EXPECT_EQ(groundTruthPoints[i].timestamp, groundTruthPoint.timestamp);
+        } else {
+            EXPECT_GT(groundTruthPoints[i].timestamp, groundTruthPoints[i - 1].timestamp);
+        }
+    }
+}
+
+TEST(GenerateCircularArcGroundTruthTest, StraightLineUpwards) {
+    // Heading up (-y, an angle of π from +y) with no turning moves one unit up per step.
+    const std::vector<GroundTruthPoint> groundTruthPoints = generateCircularArcGroundTruthPoints(
+            /*initialPosition=*/Eigen::Vector2f(0, 0),
+            /*initialAngle=*/M_PI,
+            /*velocity=*/1.0f,
+            /*turningAngle=*/0.0f,
+            /*numPoints=*/3);
+
+    const std::vector<Eigen::Vector2f> expectedPositions{Eigen::Vector2f(0, 0),
+                                                         Eigen::Vector2f(-1, 0),
+                                                         Eigen::Vector2f(-2, 0)};
+    ASSERT_EQ(expectedPositions.size(), groundTruthPoints.size());
+    for (size_t i = 0; i < expectedPositions.size(); ++i) {
+        SCOPED_TRACE(testing::Message() << "i = " << i);
+        EXPECT_THAT(groundTruthPoints[i].position, Vector2fNear(expectedPositions[i], 1e-6));
+        // Timestamps must strictly increase between consecutive ground truth points.
+        if (i > 0) {
+            EXPECT_GT(groundTruthPoints[i].timestamp, groundTruthPoints[i - 1].timestamp);
+        }
+    }
+}
+
+TEST(GenerateCircularArcGroundTruthTest, CounterclockwiseSquare) {
+    // Trace a counterclockwise unit square, initially heading right (+x) and turning a quarter
+    // turn after each step: right, up, left, down, ending where it started.
+    const std::vector<GroundTruthPoint> groundTruthPoints = generateCircularArcGroundTruthPoints(
+            /*initialPosition=*/Eigen::Vector2f(10, 100),
+            /*initialAngle=*/M_PI_2,
+            /*velocity=*/1.0f,
+            /*turningAngle=*/M_PI_2,
+            /*numPoints=*/5);
+
+    const std::vector<Eigen::Vector2f> expectedPositions{Eigen::Vector2f(10, 100),
+                                                         Eigen::Vector2f(10, 101),
+                                                         Eigen::Vector2f(9, 101),
+                                                         Eigen::Vector2f(9, 100),
+                                                         Eigen::Vector2f(10, 100)};
+    ASSERT_EQ(expectedPositions.size(), groundTruthPoints.size());
+    for (size_t i = 0; i < expectedPositions.size(); ++i) {
+        SCOPED_TRACE(testing::Message() << "i = " << i);
+        EXPECT_THAT(groundTruthPoints[i].position, Vector2fNear(expectedPositions[i], 1e-6));
+    }
+}
+
+// --- Prediction-generation helper functions. ---
+
+// Creates TEST_MAX_NUM_PREDICTIONS predictions that all carry the position and pressure of the
+// given GroundTruthPoint. Every prediction's originTimestamp is groundTruthPoint.timestamp. The
+// target timestamps start one TEST_PREDICTION_INTERVAL_NANOS after it and are spaced
+// TEST_PREDICTION_INTERVAL_NANOS apart.
+std::vector<PredictionPoint> generateConstantPredictions(const GroundTruthPoint& groundTruthPoint) {
+    std::vector<PredictionPoint> predictions;
+    // The number of predictions is fixed, so allocate once.
+    predictions.reserve(TEST_MAX_NUM_PREDICTIONS);
+    nsecs_t predictionTimestamp = groundTruthPoint.timestamp + TEST_PREDICTION_INTERVAL_NANOS;
+    for (size_t i = 0; i < TEST_MAX_NUM_PREDICTIONS; ++i) {
+        predictions.push_back(PredictionPoint{{.position = groundTruthPoint.position,
+                                               .pressure = groundTruthPoint.pressure},
+                                              .originTimestamp = groundTruthPoint.timestamp,
+                                              .targetTimestamp = predictionTimestamp});
+        predictionTimestamp += TEST_PREDICTION_INTERVAL_NANOS;
+    }
+    return predictions;
+}
+
+// Generates TEST_MAX_NUM_PREDICTIONS predictions from the given most recent two ground truth points
+// by linear extrapolation of position and pressure.
+//
+// Each successive prediction adds the full (secondGroundTruth - firstGroundTruth) position and
+// pressure delta once. This happens regardless of how much time separates the two ground truth
+// points; the delta is applied per prediction step, not scaled by time.
+//
+// All predictions have originTimestamp = secondGroundTruth.timestamp. Their target timestamps
+// start one TEST_PREDICTION_INTERVAL_NANOS after that and are spaced TEST_PREDICTION_INTERVAL_NANOS
+// apart.
+std::vector<PredictionPoint> generatePredictionsByLinearExtrapolation(
+        const GroundTruthPoint& firstGroundTruth, const GroundTruthPoint& secondGroundTruth) {
+    // Precompute deltas.
+    const Eigen::Vector2f trajectory = secondGroundTruth.position - firstGroundTruth.position;
+    const float deltaPressure = secondGroundTruth.pressure - firstGroundTruth.pressure;
+    // Compute predictions.
+    std::vector<PredictionPoint> predictions;
+    predictions.reserve(TEST_MAX_NUM_PREDICTIONS);
+    Eigen::Vector2f predictionPosition = secondGroundTruth.position;
+    float predictionPressure = secondGroundTruth.pressure;
+    nsecs_t predictionTargetTimestamp = secondGroundTruth.timestamp;
+    for (size_t i = 0; i < TEST_MAX_NUM_PREDICTIONS; ++i) {
+        predictionPosition += trajectory;
+        predictionPressure += deltaPressure;
+        predictionTargetTimestamp += TEST_PREDICTION_INTERVAL_NANOS;
+        predictions.push_back(
+                PredictionPoint{{.position = predictionPosition, .pressure = predictionPressure},
+                                .originTimestamp = secondGroundTruth.timestamp,
+                                .targetTimestamp = predictionTargetTimestamp});
+    }
+    return predictions;
+}
+
+TEST(GeneratePredictionsTest, GenerateConstantPredictions) {
+    const GroundTruthPoint groundTruthPoint{{.position = Eigen::Vector2f(10, 20), .pressure = 0.3f},
+                                            .timestamp = TEST_INITIAL_TIMESTAMP};
+    const std::vector<PredictionPoint> predictionPoints =
+            generateConstantPredictions(groundTruthPoint);
+
+    ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, predictionPoints.size());
+    // Every prediction repeats the ground truth values and originates at its timestamp. Each
+    // target is one prediction interval further out than the previous one.
+    nsecs_t expectedTargetTimestamp = groundTruthPoint.timestamp;
+    for (size_t i = 0; i < predictionPoints.size(); ++i) {
+        SCOPED_TRACE(testing::Message() << "i = " << i);
+        const PredictionPoint& prediction = predictionPoints[i];
+        expectedTargetTimestamp += TEST_PREDICTION_INTERVAL_NANOS;
+        EXPECT_THAT(prediction.position, Vector2fNear(groundTruthPoint.position, 1e-6));
+        EXPECT_THAT(prediction.pressure, FloatNear(groundTruthPoint.pressure, 1e-6));
+        EXPECT_EQ(prediction.originTimestamp, groundTruthPoint.timestamp);
+        EXPECT_EQ(prediction.targetTimestamp, expectedTargetTimestamp);
+    }
+}
+
+TEST(GeneratePredictionsTest, LinearExtrapolationFromTwoPoints) {
+    const nsecs_t initialTimestamp = TEST_INITIAL_TIMESTAMP;
+    const std::vector<PredictionPoint> predictionPoints = generatePredictionsByLinearExtrapolation(
+            GroundTruthPoint{{.position = Eigen::Vector2f(100, 200), .pressure = 0.9f},
+                             .timestamp = initialTimestamp},
+            GroundTruthPoint{{.position = Eigen::Vector2f(105, 190), .pressure = 0.8f},
+                             .timestamp = initialTimestamp + TEST_PREDICTION_INTERVAL_NANOS});
+
+    ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, predictionPoints.size());
+
+    // The ground truth moved by (+5, -10) and its pressure dropped by 0.1, so each prediction step
+    // repeats that change starting from the second ground truth point.
+    const nsecs_t originTimestamp = initialTimestamp + TEST_PREDICTION_INTERVAL_NANOS;
+    const std::vector<Eigen::Vector2f> expectedPositions{Eigen::Vector2f(110, 180),
+                                                         Eigen::Vector2f(115, 170),
+                                                         Eigen::Vector2f(120, 160),
+                                                         Eigen::Vector2f(125, 150),
+                                                         Eigen::Vector2f(130, 140)};
+    const std::vector<float> expectedPressures{0.7f, 0.6f, 0.5f, 0.4f, 0.3f};
+    for (size_t i = 0; i < expectedPositions.size(); ++i) {
+        SCOPED_TRACE(testing::Message() << "i = " << i);
+        const PredictionPoint expected{{.position = expectedPositions[i],
+                                        .pressure = expectedPressures[i]},
+                                       .originTimestamp = originTimestamp,
+                                       .targetTimestamp = originTimestamp +
+                                               static_cast<nsecs_t>(i + 1) *
+                                                       TEST_PREDICTION_INTERVAL_NANOS};
+        EXPECT_THAT(predictionPoints[i], PredictionPointNear(expected, 0.001));
+    }
+}
+
+// Generates predictions by linear extrapolation for each consecutive pair of ground truth points
+// (see generatePredictionsByLinearExtrapolation for how each set is built). Returns a vector of
+// vectors of prediction points. The first index is the source ground truth index, and the second
+// is the prediction target index.
+//
+// The returned vector has size equal to the input vector, including for empty input. When it is
+// non-empty, its first element is always empty because the first ground truth point has no
+// predecessor to extrapolate from.
+std::vector<std::vector<PredictionPoint>> generateAllPredictionsByLinearExtrapolation(
+        const std::vector<GroundTruthPoint>& groundTruthPoints) {
+    std::vector<std::vector<PredictionPoint>> allPredictions;
+    // Keep the "same size as the input" contract: no ground truth means no prediction sets.
+    if (groundTruthPoints.empty()) {
+        return allPredictions;
+    }
+    allPredictions.reserve(groundTruthPoints.size());
+    allPredictions.emplace_back();
+    for (size_t i = 1; i < groundTruthPoints.size(); ++i) {
+        allPredictions.push_back(generatePredictionsByLinearExtrapolation(groundTruthPoints[i - 1],
+                                                                          groundTruthPoints[i]));
+    }
+    return allPredictions;
+}
+
+TEST(GeneratePredictionsTest, GenerateAllPredictions) {
+    const nsecs_t initialTimestamp = TEST_INITIAL_TIMESTAMP;
+    const std::vector<GroundTruthPoint>
+            groundTruthPoints{GroundTruthPoint{{.position = Eigen::Vector2f(0, 0),
+                                                .pressure = 0.5f},
+                                               .timestamp = initialTimestamp},
+                              GroundTruthPoint{{.position = Eigen::Vector2f(1, -1),
+                                                .pressure = 0.51f},
+                                               .timestamp = initialTimestamp +
+                                                       2 * TEST_PREDICTION_INTERVAL_NANOS},
+                              GroundTruthPoint{{.position = Eigen::Vector2f(2, -2),
+                                                .pressure = 0.52f},
+                                               .timestamp = initialTimestamp +
+                                                       3 * TEST_PREDICTION_INTERVAL_NANOS}};
+
+    const std::vector<std::vector<PredictionPoint>> allPredictions =
+            generateAllPredictionsByLinearExtrapolation(groundTruthPoints);
+
+    // Shape: one prediction set per ground truth point. The first set is empty and every later
+    // set is full.
+    ASSERT_EQ(groundTruthPoints.size(), allPredictions.size());
+    EXPECT_TRUE(allPredictions[0].empty());
+    for (size_t i = 1; i < allPredictions.size(); ++i) {
+        EXPECT_EQ(TEST_MAX_NUM_PREDICTIONS, allPredictions[i].size());
+    }
+
+    // Predictions from the first pair, (0, 0) -> (1, -1) with pressure 0.5 -> 0.51, keep adding
+    // that per-step delta onto the second point.
+    const std::vector<PredictionPoint>& firstPairPredictions = allPredictions[1];
+    const std::vector<Eigen::Vector2f> expectedPositions{Eigen::Vector2f(2, -2),
+                                                         Eigen::Vector2f(3, -3),
+                                                         Eigen::Vector2f(4, -4),
+                                                         Eigen::Vector2f(5, -5),
+                                                         Eigen::Vector2f(6, -6)};
+    const std::vector<float> expectedPressures{0.52f, 0.53f, 0.54f, 0.55f, 0.56f};
+    for (size_t j = 0; j < expectedPositions.size(); ++j) {
+        SCOPED_TRACE(testing::Message() << "j = " << j);
+        EXPECT_THAT(firstPairPredictions[j].position, Vector2fNear(expectedPositions[j], 1e-9));
+        EXPECT_FLOAT_EQ(expectedPressures[j], firstPairPredictions[j].pressure);
+    }
+}
+
+// --- Prediction error helper functions. ---
+
+// Position error metrics aggregated over all predictions that fall into one prediction time
+// bucket, as produced by computeGeneralPositionErrors(). Errors are in the same units as the
+// ground truth / prediction positions.
+struct GeneralPositionErrors {
+    // Mean of the signed error components along the ground truth trajectory direction.
+    float alongTrajectoryErrorMean;
+    // Standard deviation of the signed along-trajectory error components.
+    float alongTrajectoryErrorStd;
+    // Root-mean-square of the error components perpendicular to the ground truth trajectory.
+    float offTrajectoryRmse;
+};
+
+// Inputs:
+//  • Vector of ground truth points
+//  • Vector of vectors of prediction points, where the first index is the source ground truth
+//    index, and the second is the prediction target index.
+//
+// Returns a vector of GeneralPositionErrors, indexed by prediction time delta bucket.
+//
+// Each prediction's position error (prediction minus target ground truth) is projected onto two
+// directions. The first is the unit direction of the ground truth trajectory arriving at the
+// target point, which gives the along-trajectory error. The second is that direction rotated by
+// 90 degrees, which gives the off-trajectory error.
+//
+// If the ground truth did not move into the target point, the direction is the zero vector
+// (Eigen's normalized() returns a zero vector unchanged), so both errors for that point are 0.
+// Predictions whose target lies beyond the last ground truth point are ignored.
+std::vector<GeneralPositionErrors> computeGeneralPositionErrors(
+        const std::vector<GroundTruthPoint>& groundTruthPoints,
+        const std::vector<std::vector<PredictionPoint>>& predictionPoints) {
+    // Aggregate errors by time bucket (prediction target index).
+    std::vector<GeneralPositionErrors> generalPositionErrors;
+    generalPositionErrors.reserve(TEST_MAX_NUM_PREDICTIONS);
+    for (size_t predictionTargetIndex = 0; predictionTargetIndex < TEST_MAX_NUM_PREDICTIONS;
+         ++predictionTargetIndex) {
+        std::vector<float> alongTrajectoryErrors;
+        std::vector<float> offTrajectoryErrors;
+        for (size_t sourceGroundTruthIndex = 1; sourceGroundTruthIndex < groundTruthPoints.size();
+             ++sourceGroundTruthIndex) {
+            const size_t targetGroundTruthIndex =
+                    sourceGroundTruthIndex + predictionTargetIndex + 1;
+            // Only include errors for points with a ground truth value.
+            if (targetGroundTruthIndex < groundTruthPoints.size()) {
+                const Eigen::Vector2f trajectory =
+                        (groundTruthPoints[targetGroundTruthIndex].position -
+                         groundTruthPoints[targetGroundTruthIndex - 1].position)
+                                .normalized();
+                const Eigen::Vector2f orthogonalTrajectory =
+                        Eigen::Rotation2Df(M_PI_2) * trajectory;
+                const Eigen::Vector2f positionError =
+                        predictionPoints[sourceGroundTruthIndex][predictionTargetIndex].position -
+                        groundTruthPoints[targetGroundTruthIndex].position;
+                alongTrajectoryErrors.push_back(positionError.dot(trajectory));
+                offTrajectoryErrors.push_back(positionError.dot(orthogonalTrajectory));
+            }
+        }
+        generalPositionErrors.push_back(
+                {.alongTrajectoryErrorMean = average(alongTrajectoryErrors),
+                 .alongTrajectoryErrorStd = standardDeviation(alongTrajectoryErrors),
+                 .offTrajectoryRmse = rmse(offTrajectoryErrors)});
+    }
+    return generalPositionErrors;
+}
+
+// For each prediction time bucket (prediction target index), computes the root-mean-square of the
+// differences between predicted pressures and the pressures of the ground truth points they target.
+//
+// Inputs:
+//  • groundTruthPoints: chronologically-ordered ground truth points.
+//  • predictionPoints: predictionPoints[s][k] is the k-th prediction made from source ground truth
+//    index s; it targets ground truth index s + k + 1.
+//
+// Returns one RMSE per time bucket (TEST_MAX_NUM_PREDICTIONS entries). Source index 0 is skipped.
+// Predictions whose target index lies past the last ground truth point are ignored.
+std::vector<float> computePressureRmses(
+        const std::vector<GroundTruthPoint>& groundTruthPoints,
+        const std::vector<std::vector<PredictionPoint>>& predictionPoints) {
+    std::vector<float> pressureRmses;
+    for (size_t bucket = 0; bucket < TEST_MAX_NUM_PREDICTIONS; ++bucket) {
+        std::vector<float> bucketErrors;
+        // The target index grows with the source index, so once a target would fall past the end
+        // of the ground truth, every later source would too; stop there.
+        for (size_t source = 1; source + bucket + 1 < groundTruthPoints.size(); ++source) {
+            const size_t target = source + bucket + 1;
+            bucketErrors.push_back(predictionPoints[source][bucket].pressure -
+                                   groundTruthPoints[target].pressure);
+        }
+        pressureRmses.push_back(rmse(bucketErrors));
+    }
+    return pressureRmses;
+}
+
+TEST(ErrorComputationHelperTest, ComputeGeneralPositionErrorsSimpleTest) {
+    // Start with TEST_MAX_NUM_PREDICTIONS + 2 stationary points at the origin, then move points 3-6
+    // so that the trajectory starts moving and turns.
+    std::vector<GroundTruthPoint> groundTruthPoints =
+            generateConstantGroundTruthPoints(GroundTruthPoint{{.position = Eigen::Vector2f(0, 0),
+                                                                .pressure = 0.0f},
+                                                               .timestamp = TEST_INITIAL_TIMESTAMP},
+                                              /*numPoints=*/TEST_MAX_NUM_PREDICTIONS + 2);
+    groundTruthPoints[3].position = Eigen::Vector2f(1, 0);
+    groundTruthPoints[4].position = Eigen::Vector2f(1, 1);
+    groundTruthPoints[5].position = Eigen::Vector2f(1, 3);
+    groundTruthPoints[6].position = Eigen::Vector2f(2, 3);
+
+    const std::vector<std::vector<PredictionPoint>> predictionPoints =
+            generateAllPredictionsByLinearExtrapolation(groundTruthPoints);
+
+    // The generated predictions look like:
+    //
+    // |    Source  |         Target Ground Truth Index          |
+    // |     Index  |   2    |   3    |   4    |   5    |   6    |
+    // |------------|--------|--------|--------|--------|--------|
+    // |          1 | (0, 0) | (0, 0) | (0, 0) | (0, 0) | (0, 0) |
+    // |          2 |        | (0, 0) | (0, 0) | (0, 0) | (0, 0) |
+    // |          3 |        |        | (2, 0) | (3, 0) | (4, 0) |
+    // |          4 |        |        |        | (1, 2) | (1, 3) |
+    // |          5 |        |        |        |        | (1, 5) |
+    // |---------------------------------------------------------|
+    // |               Actual Ground Truth Values                |
+    // |  Position  | (0, 0) | (1, 0) | (1, 1) | (1, 3) | (2, 3) |
+    // |  Previous  | (0, 0) | (0, 0) | (1, 0) | (1, 1) | (1, 3) |
+    //
+    // "Previous" is the ground truth point just before the target. The trajectory direction for
+    // each target is (Position - Previous).
+    //
+    // Note: this table organizes prediction targets by target ground truth index. Metrics are
+    // aggregated across points with the same prediction time bucket index, which is different.
+    // Each down-right diagonal from this table gives us points from a unique time bucket.
+
+    // Expected per-bucket errors, read off the table above. The first time bucket is the long
+    // diagonal, and each subsequent bucket is the next diagonal up and to the right.
+    const std::vector<std::vector<float>> expectedAlongTrajectoryErrors{{0, -1, -1, -1, -1},
+                                                                        {-1, -1, -3, -1},
+                                                                        {-1, -3, 2},
+                                                                        {-3, -2},
+                                                                        {-2}};
+    const std::vector<std::vector<float>> expectedOffTrajectoryErrors{{0, 0, 1, 0, 2},
+                                                                      {0, 1, 2, 0},
+                                                                      {1, 1, 3},
+                                                                      {1, 3},
+                                                                      {3}};
+
+    const std::vector<GeneralPositionErrors> generalPositionErrors =
+            computeGeneralPositionErrors(groundTruthPoints, predictionPoints);
+
+    ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, generalPositionErrors.size());
+    for (size_t i = 0; i < generalPositionErrors.size(); ++i) {
+        SCOPED_TRACE(testing::Message() << "i = " << i);
+        const GeneralPositionErrors& actual = generalPositionErrors[i];
+        EXPECT_FLOAT_EQ(average(expectedAlongTrajectoryErrors[i]), actual.alongTrajectoryErrorMean);
+        EXPECT_FLOAT_EQ(standardDeviation(expectedAlongTrajectoryErrors[i]),
+                        actual.alongTrajectoryErrorStd);
+        EXPECT_FLOAT_EQ(rmse(expectedOffTrajectoryErrors[i]), actual.offTrajectoryRmse);
+    }
+}
+
+TEST(ErrorComputationHelperTest, ComputePressureRmsesSimpleTest) {
+    // Ground truth pressures are {0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5}: a single step up at index 4.
+    // (TEST_MAX_NUM_PREDICTIONS + 2 points are needed to exercise every prediction time bucket.)
+    std::vector<GroundTruthPoint> groundTruthPoints =
+            generateConstantGroundTruthPoints(GroundTruthPoint{{.position = Eigen::Vector2f(0, 0),
+                                                                .pressure = 0.0f},
+                                                               .timestamp = TEST_INITIAL_TIMESTAMP},
+                                              /*numPoints=*/TEST_MAX_NUM_PREDICTIONS + 2);
+    for (size_t index = 4; index < groundTruthPoints.size(); ++index) {
+        groundTruthPoints[index].pressure = 0.5f;
+    }
+
+    const std::vector<std::vector<PredictionPoint>> predictionPoints =
+            generateAllPredictionsByLinearExtrapolation(groundTruthPoints);
+    const std::vector<float> pressureRmses =
+            computePressureRmses(groundTruthPoints, predictionPoints);
+
+    // Expected signed pressure errors (prediction - ground truth), one row per time bucket.
+    const std::vector<std::vector<float>> expectedPressureErrors{
+            {0.0f, 0.0f, -0.5f, 0.5f, 0.0f},
+            {0.0f, -0.5f, -0.5f, 1.0f},
+            {-0.5f, -0.5f, -0.5f},
+            {-0.5f, -0.5f},
+            {-0.5f}};
+
+    ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, pressureRmses.size());
+    for (size_t i = 0; i < expectedPressureErrors.size(); ++i) {
+        SCOPED_TRACE(testing::Message() << "i = " << i);
+        EXPECT_FLOAT_EQ(rmse(expectedPressureErrors[i]), pressureRmses[i]);
+    }
+}
+
+// --- MotionPredictorMetricsManager tests. ---
+
+// Runs a MotionPredictorMetricsManager over one stroke and captures what it logs.
+//
+// groundTruthPoints and predictionPoints must have the same length, at least 2, with:
+//  • groundTruthPoints: chronologically-ordered ground truth points.
+//  • predictionPoints[i]: the predictions made from groundTruthPoints[i].
+//     - predictionPoints[0] must be empty: no predictions are expected until 2 ground truth points
+//       have been received.
+//     - The last element is never sent, because no later ground truth exists to compare it
+//       against. It may therefore be empty.
+//     - Every other element must hold exactly TEST_MAX_NUM_PREDICTIONS predictions. To cover every
+//       prediction bucket, supply at least TEST_MAX_NUM_PREDICTIONS such sets, i.e. at least
+//       TEST_MAX_NUM_PREDICTIONS + 2 points in total.
+//
+// On return, outAtomFields holds whatever the manager logged. The stroke is ended with a lift
+// event, which triggers the logging call.
+//
+// Returns void so that gtest ASSERT_* macros can be used for input validation.
+void runMetricsManager(const std::vector<GroundTruthPoint>& groundTruthPoints,
+                       const std::vector<std::vector<PredictionPoint>>& predictionPoints,
+                       std::vector<AtomFields>& outAtomFields) {
+    MotionPredictorMetricsManager metricsManager(TEST_PREDICTION_INTERVAL_NANOS,
+                                                 TEST_MAX_NUM_PREDICTIONS);
+    metricsManager.setMockLoggedAtomFields(&outAtomFields);
+
+    // Validate structure of groundTruthPoints and predictionPoints.
+    ASSERT_EQ(predictionPoints.size(), groundTruthPoints.size());
+    ASSERT_GE(groundTruthPoints.size(), 2u);
+    ASSERT_EQ(predictionPoints[0].size(), 0u);
+    const size_t lastIndex = groundTruthPoints.size() - 1;
+    for (size_t i = 1; i < lastIndex; ++i) {
+        SCOPED_TRACE(testing::Message() << "i = " << i);
+        ASSERT_EQ(predictionPoints[i].size(), TEST_MAX_NUM_PREDICTIONS);
+    }
+
+    // First point: ground truth only, since nothing has been predicted yet.
+    metricsManager.onRecord(makeMotionEvent(groundTruthPoints[0]));
+    // Interior points: the ground truth, followed by the predictions made from it.
+    for (size_t i = 1; i < lastIndex; ++i) {
+        metricsManager.onRecord(makeMotionEvent(groundTruthPoints[i]));
+        metricsManager.onPredict(makeMotionEvent(predictionPoints[i]));
+    }
+    // Last point: ground truth only; its predictions would have nothing to be compared against.
+    metricsManager.onRecord(makeMotionEvent(groundTruthPoints[lastIndex]));
+    // Send a stroke-end event to trigger the logging call.
+    metricsManager.onRecord(makeLiftMotionEvent());
+}
+
+// Vacuous test:
+//  • Input: a stroke consisting of one ground truth point and a lift, with no predictions.
+//  • Expectation: nothing is logged.
+TEST(MotionPredictorMetricsManagerTest, NoPredictions) {
+    std::vector<AtomFields> mockLoggedAtomFields;
+    MotionPredictorMetricsManager metricsManager(TEST_PREDICTION_INTERVAL_NANOS,
+                                                 TEST_MAX_NUM_PREDICTIONS);
+    metricsManager.setMockLoggedAtomFields(&mockLoggedAtomFields);
+
+    const GroundTruthPoint onlyPoint{{.position = Eigen::Vector2f(0, 0), .pressure = 0},
+                                     .timestamp = 0};
+    metricsManager.onRecord(makeMotionEvent(onlyPoint));
+    metricsManager.onRecord(makeLiftMotionEvent());
+
+    // The vector started out empty; any logged metric would have been appended to it.
+    EXPECT_TRUE(mockLoggedAtomFields.empty());
+}
+
+// Perfect predictions test:
+//  • Input: stationary ground truth (only timestamps advance) and exactly matching predictions.
+//  • Expectation: all error metrics should be zero, or NO_DATA_SENTINEL for "unreported" metrics.
+//    (For example, scale-invariant errors are only reported for the final time bucket.)
+TEST(MotionPredictorMetricsManagerTest, ConstantGroundTruthPerfectPredictions) {
+    GroundTruthPoint groundTruthPoint{{.position = Eigen::Vector2f(10.0f, 20.0f), .pressure = 0.6f},
+                                      .timestamp = TEST_INITIAL_TIMESTAMP};
+
+    // Per runMetricsManager: TEST_MAX_NUM_PREDICTIONS + 2 points; point 0 has no predictions.
+    std::vector<GroundTruthPoint> groundTruthPoints;
+    std::vector<std::vector<PredictionPoint>> predictionPoints;
+    for (size_t i = 0; i < TEST_MAX_NUM_PREDICTIONS + 2; ++i) {
+        groundTruthPoints.push_back(groundTruthPoint);
+        predictionPoints.push_back(i > 0 ? generateConstantPredictions(groundTruthPoint)
+                                         : std::vector<PredictionPoint>{});
+        groundTruthPoint.timestamp += TEST_PREDICTION_INTERVAL_NANOS;
+    }
+
+    std::vector<AtomFields> atomFields;
+    runMetricsManager(groundTruthPoints, predictionPoints, atomFields);
+
+    ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, atomFields.size());
+    // Check that errors are all zero, or NO_DATA_SENTINEL for unreported metrics.
+    for (size_t i = 0; i < atomFields.size(); ++i) {
+        SCOPED_TRACE(testing::Message() << "i = " << i);
+        const AtomFields& atom = atomFields[i];
+        const nsecs_t deltaTimeBucketNanos = TEST_PREDICTION_INTERVAL_NANOS * (i + 1);
+        EXPECT_EQ(deltaTimeBucketNanos / NANOS_PER_MILLIS, atom.deltaTimeBucketMilliseconds);
+        // General errors: reported for every time bucket.
+        EXPECT_EQ(0, atom.alongTrajectoryErrorMeanMillipixels);
+        EXPECT_EQ(0, atom.alongTrajectoryErrorStdMillipixels);
+        EXPECT_EQ(0, atom.offTrajectoryRmseMillipixels);
+        EXPECT_EQ(0, atom.pressureRmseMilliunits);
+        // High-velocity errors: reported only for the last two time buckets.
+        // However, this data has zero velocity, so these metrics should all be NO_DATA_SENTINEL.
+        EXPECT_EQ(NO_DATA_SENTINEL, atom.highVelocityAlongTrajectoryRmse);
+        EXPECT_EQ(NO_DATA_SENTINEL, atom.highVelocityOffTrajectoryRmse);
+        // Scale-invariant errors: reported only for the last time bucket.
+        if (i + 1 == atomFields.size()) {
+            EXPECT_EQ(0, atom.scaleInvariantAlongTrajectoryRmse);
+            EXPECT_EQ(0, atom.scaleInvariantOffTrajectoryRmse);
+        } else {
+            EXPECT_EQ(NO_DATA_SENTINEL, atom.scaleInvariantAlongTrajectoryRmse);
+            EXPECT_EQ(NO_DATA_SENTINEL, atom.scaleInvariantOffTrajectoryRmse);
+        }
+    }
+}
+
+TEST(MotionPredictorMetricsManagerTest, QuadraticPressureLinearPredictions) {
+    // Generate ground truth points at the origin, one prediction interval apart, whose pressure
+    // grows quadratically with the point index:
+    //   pressure(i) = initialPressure + quadraticCoefficient * i^2.
+    const float initialPressure = 0.5f;
+    const float quadraticCoefficient = 0.01f;
+    std::vector<GroundTruthPoint> groundTruthPoints;
+    nsecs_t timestamp = TEST_INITIAL_TIMESTAMP;
+    // As described in the runMetricsManager comment, we should have TEST_MAX_NUM_PREDICTIONS + 2
+    // ground truth points.
+    for (size_t i = 0; i < TEST_MAX_NUM_PREDICTIONS + 2; ++i) {
+        const float pressure = initialPressure + quadraticCoefficient * static_cast<float>(i * i);
+        groundTruthPoints.push_back(
+                GroundTruthPoint{{.position = Eigen::Vector2f(0, 0), .pressure = pressure},
+                                 .timestamp = timestamp});
+        timestamp += TEST_PREDICTION_INTERVAL_NANOS;
+    }
+
+    // Note: the first index is the source ground truth index, and the second is the prediction
+    // target index.
+    std::vector<std::vector<PredictionPoint>> predictionPoints =
+            generateAllPredictionsByLinearExtrapolation(groundTruthPoints);
+
+    const std::vector<float> pressureErrors =
+            computePressureRmses(groundTruthPoints, predictionPoints);
+
+    // Run test.
+    std::vector<AtomFields> atomFields;
+    runMetricsManager(groundTruthPoints, predictionPoints, atomFields);
+
+    // Check logged metrics match expectations.
+    ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, atomFields.size());
+    for (size_t i = 0; i < atomFields.size(); ++i) {
+        SCOPED_TRACE(testing::Message() << "i = " << i);
+        const AtomFields& atom = atomFields[i];
+        // Check time bucket delta matches expectation based on index and prediction interval.
+        const nsecs_t deltaTimeBucketNanos = TEST_PREDICTION_INTERVAL_NANOS * (i + 1);
+        EXPECT_EQ(deltaTimeBucketNanos / NANOS_PER_MILLIS, atom.deltaTimeBucketMilliseconds);
+        // Reference RMSE in milliunits; ±1 tolerates the static_cast<int> truncation.
+        EXPECT_NEAR(static_cast<int>(1000 * pressureErrors[i]), atom.pressureRmseMilliunits, 1);
+    }
+}
+
+TEST(MotionPredictorMetricsManagerTest, QuadraticPositionLinearPredictionsGeneralErrors) {
+    // Generate ground truth points.
+    //
+    // Each position component is an independent quadratic function of the point index, starting
+    // from initialPosition (x decreases and y increases, per the signs of quadraticCoefficients).
+    const Eigen::Vector2f initialPosition(200, 300);
+    const Eigen::Vector2f quadraticCoefficients(-2, 3);
+    std::vector<GroundTruthPoint> groundTruthPoints;
+    nsecs_t timestamp = TEST_INITIAL_TIMESTAMP;
+    // As described in the runMetricsManager comment, we should have TEST_MAX_NUM_PREDICTIONS + 2
+    // ground truth points.
+    for (size_t i = 0; i < TEST_MAX_NUM_PREDICTIONS + 2; ++i) {
+        const Eigen::Vector2f position =
+                initialPosition + quadraticCoefficients * static_cast<float>(i * i);
+        groundTruthPoints.push_back(
+                GroundTruthPoint{{.position = position, .pressure = 0.5}, .timestamp = timestamp});
+        timestamp += TEST_PREDICTION_INTERVAL_NANOS;
+    }
+
+    // Note: the first index is the source ground truth index, and the second is the prediction
+    // target index.
+    std::vector<std::vector<PredictionPoint>> predictionPoints =
+            generateAllPredictionsByLinearExtrapolation(groundTruthPoints);
+
+    std::vector<GeneralPositionErrors> generalPositionErrors =
+            computeGeneralPositionErrors(groundTruthPoints, predictionPoints);
+
+    // Run test.
+    std::vector<AtomFields> atomFields;
+    runMetricsManager(groundTruthPoints, predictionPoints, atomFields);
+
+    // Check logged metrics match expectations.
+    ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, atomFields.size());
+    for (size_t i = 0; i < atomFields.size(); ++i) {
+        SCOPED_TRACE(testing::Message() << "i = " << i);
+        const AtomFields& atom = atomFields[i];
+        // Check time bucket delta matches expectation based on index and prediction interval.
+        const nsecs_t deltaTimeBucketNanos = TEST_PREDICTION_INTERVAL_NANOS * (i + 1);
+        EXPECT_EQ(deltaTimeBucketNanos / NANOS_PER_MILLIS, atom.deltaTimeBucketMilliseconds);
+        // Reference errors in millipixels; ±1 tolerates the static_cast<int> truncation.
+        EXPECT_NEAR(static_cast<int>(1000 * generalPositionErrors[i].alongTrajectoryErrorMean),
+                    atom.alongTrajectoryErrorMeanMillipixels, 1);
+        EXPECT_NEAR(static_cast<int>(1000 * generalPositionErrors[i].alongTrajectoryErrorStd),
+                    atom.alongTrajectoryErrorStdMillipixels, 1);
+        EXPECT_NEAR(static_cast<int>(1000 * generalPositionErrors[i].offTrajectoryRmse),
+                    atom.offTrajectoryRmseMillipixels, 1);
+    }
+}
+
+// Counterclockwise regular octagonal section test:
+//  • Input – ground truth: constantly-spaced input events starting at a trajectory pointing exactly
+//    rightwards, and rotating by 45° counterclockwise after each input.
+//  • Input – predictions: simple linear extrapolations of previous two ground truth points.
+//
+// The code below uses the following terminology to distinguish references to ground truth events:
+//  • Source ground truth: the most recent ground truth point received at the time the prediction
+//    was made.
+//  • Target ground truth: the ground truth event that the prediction was attempting to match.
+TEST(MotionPredictorMetricsManagerTest, CounterclockwiseOctagonGroundTruthLinearPredictions) {
+    // Select a stroke velocity that exceeds the high-velocity threshold of 1100 px/sec.
+    // For an input rate of 240 Hz, 1100 px/sec * (1/240) sec/input ≈ 4.58 pixels per input.
+    const float strokeVelocity = 10; // pixels per input
+
+    // As described in the runMetricsManager comment, we should have TEST_MAX_NUM_PREDICTIONS + 2
+    // ground truth points.
+    std::vector<GroundTruthPoint> groundTruthPoints = generateCircularArcGroundTruthPoints(
+            /*initialPosition=*/Eigen::Vector2f(100, 100),
+            /*initialAngle=*/M_PI_2,
+            /*velocity=*/strokeVelocity,
+            /*turningAngle=*/-M_PI_4,
+            /*numPoints=*/TEST_MAX_NUM_PREDICTIONS + 2);
+
+    std::vector<std::vector<PredictionPoint>> predictionPoints =
+            generateAllPredictionsByLinearExtrapolation(groundTruthPoints);
+
+    std::vector<GeneralPositionErrors> generalPositionErrors =
+            computeGeneralPositionErrors(groundTruthPoints, predictionPoints);
+
+    // Run test.
+    std::vector<AtomFields> atomFields;
+    runMetricsManager(groundTruthPoints, predictionPoints, atomFields);
+
+    // Check logged metrics match expectations.
+    ASSERT_EQ(TEST_MAX_NUM_PREDICTIONS, atomFields.size());
+    for (size_t i = 0; i < atomFields.size(); ++i) {
+        SCOPED_TRACE(testing::Message() << "i = " << i);
+        const AtomFields& atom = atomFields[i];
+        const nsecs_t deltaTimeBucketNanos = TEST_PREDICTION_INTERVAL_NANOS * (i + 1);
+        EXPECT_EQ(deltaTimeBucketNanos / NANOS_PER_MILLIS, atom.deltaTimeBucketMilliseconds);
+
+        // General errors: reported for every time bucket.
+        EXPECT_NEAR(static_cast<int>(1000 * generalPositionErrors[i].alongTrajectoryErrorMean),
+                    atom.alongTrajectoryErrorMeanMillipixels, 1);
+        // We allow for some floating point error in standard deviation (0.02 pixels).
+        EXPECT_NEAR(1000 * generalPositionErrors[i].alongTrajectoryErrorStd,
+                    atom.alongTrajectoryErrorStdMillipixels, 20);
+        // All position errors are equal, so the standard deviation should be approximately zero.
+        EXPECT_NEAR(0, atom.alongTrajectoryErrorStdMillipixels, 20);
+        // Absolute value for RMSE, since it must be non-negative.
+        EXPECT_NEAR(static_cast<int>(1000 * std::abs(generalPositionErrors[i].offTrajectoryRmse)),
+                    atom.offTrajectoryRmseMillipixels, 1);
+
+        // High-velocity errors: reported only for the last two time buckets.
+        //
+        // Since our input stroke velocity is chosen to be above the high-velocity threshold, all
+        // data contributes to high-velocity errors, and thus high-velocity errors should be equal
+        // to general errors (where reported).
+        //
+        // As above, use absolute value for RMSE, since it must be non-negative.
+        if (i + 2 >= atomFields.size()) {
+            EXPECT_NEAR(static_cast<int>(
+                                1000 * std::abs(generalPositionErrors[i].alongTrajectoryErrorMean)),
+                        atom.highVelocityAlongTrajectoryRmse, 1);
+            EXPECT_NEAR(static_cast<int>(1000 *
+                                         std::abs(generalPositionErrors[i].offTrajectoryRmse)),
+                        atom.highVelocityOffTrajectoryRmse, 1);
+        } else {
+            EXPECT_EQ(NO_DATA_SENTINEL, atom.highVelocityAlongTrajectoryRmse);
+            EXPECT_EQ(NO_DATA_SENTINEL, atom.highVelocityOffTrajectoryRmse);
+        }
+
+        // Scale-invariant errors: reported only for the last time bucket, where the reported value
+        // is the aggregation across all time buckets.
+        //
+        // The MetricsManager stores mMaxNumPredictions recent ground truth segments. Our ground
+        // truth segments here all have a length of strokeVelocity, so we can convert general errors
+        // to scale-invariant errors by dividing by `strokeVelocity * TEST_MAX_NUM_PREDICTIONS`.
+        //
+        // As above, use absolute value for RMSE, since it must be non-negative.
+        if (i + 1 == atomFields.size()) {
+            const float pathLength = strokeVelocity * TEST_MAX_NUM_PREDICTIONS;
+            std::vector<float> alongTrajectoryAbsoluteErrors;
+            std::vector<float> offTrajectoryAbsoluteErrors;
+            for (size_t j = 0; j < TEST_MAX_NUM_PREDICTIONS; ++j) {
+                alongTrajectoryAbsoluteErrors.push_back(
+                        std::abs(generalPositionErrors[j].alongTrajectoryErrorMean));
+                offTrajectoryAbsoluteErrors.push_back(
+                        std::abs(generalPositionErrors[j].offTrajectoryRmse));
+            }
+            EXPECT_NEAR(static_cast<int>(1000 * average(alongTrajectoryAbsoluteErrors) /
+                                         pathLength),
+                        atom.scaleInvariantAlongTrajectoryRmse, 1);
+            EXPECT_NEAR(static_cast<int>(1000 * average(offTrajectoryAbsoluteErrors) / pathLength),
+                        atom.scaleInvariantOffTrajectoryRmse, 1);
+        } else {
+            EXPECT_EQ(NO_DATA_SENTINEL, atom.scaleInvariantAlongTrajectoryRmse);
+            EXPECT_EQ(NO_DATA_SENTINEL, atom.scaleInvariantOffTrajectoryRmse);
+        }
+    }
+}
+
+} // namespace
+} // namespace android
diff --git a/libs/input/tests/VelocityTracker_test.cpp b/libs/input/tests/VelocityTracker_test.cpp
index ae72109..73f25cc 100644
--- a/libs/input/tests/VelocityTracker_test.cpp
+++ b/libs/input/tests/VelocityTracker_test.cpp
@@ -282,6 +282,11 @@
         const std::vector<std::pair<std::chrono::nanoseconds, float>>& motions,
         std::optional<float> targetVelocity) {
     checkVelocity(computeVelocity(strategy, motions, AMOTION_EVENT_AXIS_SCROLL), targetVelocity);
+    // The strategy LSQ2 is not compatible with AXIS_SCROLL. In those situations, we should fall
+    // back to a strategy that supports differential axes.
+    checkVelocity(computeVelocity(VelocityTracker::Strategy::LSQ2, motions,
+                                  AMOTION_EVENT_AXIS_SCROLL),
+                  targetVelocity);
 }
 
 static void computeAndCheckQuadraticEstimate(const std::vector<PlanarMotionEventEntry>& motions,
diff --git a/opengl/libs/EGL/egl_platform_entries.cpp b/opengl/libs/EGL/egl_platform_entries.cpp
index 88001b2..440eb17 100644
--- a/opengl/libs/EGL/egl_platform_entries.cpp
+++ b/opengl/libs/EGL/egl_platform_entries.cpp
@@ -49,6 +49,7 @@
 #include "egl_trace.h"
 
 using namespace android;
+using PixelFormat = aidl::android::hardware::graphics::common::PixelFormat;
 
 // ----------------------------------------------------------------------------
 
@@ -406,7 +407,7 @@
 // ----------------------------------------------------------------------------
 
 // Translates EGL color spaces to Android data spaces.
-static android_dataspace dataSpaceFromEGLColorSpace(EGLint colorspace) {
+static android_dataspace dataSpaceFromEGLColorSpace(EGLint colorspace, PixelFormat pixelFormat) {
     if (colorspace == EGL_GL_COLORSPACE_LINEAR_KHR) {
         return HAL_DATASPACE_UNKNOWN;
     } else if (colorspace == EGL_GL_COLORSPACE_SRGB_KHR) {
@@ -424,7 +425,13 @@
     } else if (colorspace == EGL_GL_COLORSPACE_BT2020_HLG_EXT) {
         return static_cast<android_dataspace>(HAL_DATASPACE_BT2020_HLG);
     } else if (colorspace == EGL_GL_COLORSPACE_BT2020_LINEAR_EXT) {
-        return HAL_DATASPACE_BT2020_LINEAR;
+        if (pixelFormat == PixelFormat::RGBA_FP16) {
+            return static_cast<android_dataspace>(HAL_DATASPACE_STANDARD_BT2020 |
+                                                  HAL_DATASPACE_TRANSFER_LINEAR |
+                                                  HAL_DATASPACE_RANGE_EXTENDED);
+        } else {
+            return HAL_DATASPACE_BT2020_LINEAR;
+        }
     } else if (colorspace == EGL_GL_COLORSPACE_BT2020_PQ_EXT) {
         return HAL_DATASPACE_BT2020_PQ;
     }
@@ -573,8 +580,6 @@
     newList.push_back(EGL_NONE);
 }
 
-using PixelFormat = aidl::android::hardware::graphics::common::PixelFormat;
-
 // Gets the native pixel format corrsponding to the passed EGLConfig.
 void getNativePixelFormat(EGLDisplay dpy, egl_connection_t* cnx, EGLConfig config,
                           PixelFormat* format) {
@@ -714,7 +719,7 @@
             return setError(EGL_BAD_NATIVE_WINDOW, EGL_NO_SURFACE);
         }
 
-        android_dataspace dataSpace = dataSpaceFromEGLColorSpace(colorSpace);
+        android_dataspace dataSpace = dataSpaceFromEGLColorSpace(colorSpace, format);
         // Set dataSpace even if it could be HAL_DATASPACE_UNKNOWN.
         // HAL_DATASPACE_UNKNOWN is the default value, but it may have changed
         // at this point.
diff --git a/services/inputflinger/reader/EventHub.cpp b/services/inputflinger/reader/EventHub.cpp
index 4d0e13e..44e80a7 100644
--- a/services/inputflinger/reader/EventHub.cpp
+++ b/services/inputflinger/reader/EventHub.cpp
@@ -2403,6 +2403,7 @@
 
         // See if this device has any stylus buttons that we would want to fuse with touch data.
         if (!device->classes.any(InputDeviceClass::TOUCH | InputDeviceClass::TOUCH_MT) &&
+            !device->classes.any(InputDeviceClass::ALPHAKEY) &&
             std::any_of(STYLUS_BUTTON_KEYCODES.begin(), STYLUS_BUTTON_KEYCODES.end(),
                         [&](int32_t keycode) { return device->hasKeycodeLocked(keycode); })) {
             device->classes |= InputDeviceClass::EXTERNAL_STYLUS;
diff --git a/services/inputflinger/reader/InputReader.cpp b/services/inputflinger/reader/InputReader.cpp
index 08600b2..7f63355 100644
--- a/services/inputflinger/reader/InputReader.cpp
+++ b/services/inputflinger/reader/InputReader.cpp
@@ -77,7 +77,7 @@
       : mContext(this),
         mEventHub(eventHub),
         mPolicy(policy),
-        mQueuedListener(listener),
+        mNextListener(listener),
         mGlobalMetaState(AMETA_NONE),
         mLedMetaState(AMETA_NONE),
         mGeneration(1),
@@ -140,7 +140,7 @@
         mReaderIsAliveCondition.notify_all();
 
         if (!events.empty()) {
-            notifyArgs += processEventsLocked(events.data(), events.size());
+            mPendingArgs += processEventsLocked(events.data(), events.size());
         }
 
         if (mNextTimeout != LLONG_MAX) {
@@ -150,16 +150,18 @@
                     ALOGD("Timeout expired, latency=%0.3fms", (now - mNextTimeout) * 0.000001f);
                 }
                 mNextTimeout = LLONG_MAX;
-                notifyArgs += timeoutExpiredLocked(now);
+                mPendingArgs += timeoutExpiredLocked(now);
             }
         }
 
         if (oldGeneration != mGeneration) {
             inputDevicesChanged = true;
             inputDevices = getInputDevicesLocked();
-            notifyArgs.emplace_back(
+            mPendingArgs.emplace_back(
                     NotifyInputDevicesChangedArgs{mContext.getNextId(), inputDevices});
         }
+
+        std::swap(notifyArgs, mPendingArgs);
     } // release lock
 
     // Send out a message that the describes the changed input devices.
@@ -175,8 +177,6 @@
         }
     }
 
-    notifyAll(std::move(notifyArgs));
-
     // Flush queued events out to the listener.
     // This must happen outside of the lock because the listener could potentially call
     // back into the InputReader's methods, such as getScanCodeState, or become blocked
@@ -184,7 +184,9 @@
     // resulting in a deadlock.  This situation is actually quite plausible because the
     // listener is actually the input dispatcher, which calls into the window manager,
     // which occasionally calls into the input reader.
-    mQueuedListener.flush();
+    for (const NotifyArgs& args : notifyArgs) {
+        mNextListener.notify(args);
+    }
 }
 
 std::list<NotifyArgs> InputReader::processEventsLocked(const RawEvent* rawEvents, size_t count) {
@@ -236,8 +238,8 @@
     InputDeviceIdentifier identifier = mEventHub->getDeviceIdentifier(eventHubId);
     std::shared_ptr<InputDevice> device = createDeviceLocked(eventHubId, identifier);
 
-    notifyAll(device->configure(when, mConfig, /*changes=*/{}));
-    notifyAll(device->reset(when));
+    mPendingArgs += device->configure(when, mConfig, /*changes=*/{});
+    mPendingArgs += device->reset(when);
 
     if (device->isIgnored()) {
         ALOGI("Device added: id=%d, eventHubId=%d, name='%s', descriptor='%s' "
@@ -310,12 +312,10 @@
         notifyExternalStylusPresenceChangedLocked();
     }
 
-    std::list<NotifyArgs> resetEvents;
     if (device->hasEventHubDevices()) {
-        resetEvents += device->configure(when, mConfig, /*changes=*/{});
+        mPendingArgs += device->configure(when, mConfig, /*changes=*/{});
     }
-    resetEvents += device->reset(when);
-    notifyAll(std::move(resetEvents));
+    mPendingArgs += device->reset(when);
 }
 
 std::shared_ptr<InputDevice> InputReader::createDeviceLocked(
@@ -387,7 +387,7 @@
     updateGlobalMetaStateLocked();
 
     // Enqueue configuration changed.
-    mQueuedListener.notifyConfigurationChanged({mContext.getNextId(), when});
+    mPendingArgs.emplace_back(NotifyConfigurationChangedArgs{mContext.getNextId(), when});
 }
 
 void InputReader::refreshConfigurationLocked(ConfigurationChanges changes) {
@@ -409,7 +409,7 @@
     } else {
         for (auto& devicePair : mDevices) {
             std::shared_ptr<InputDevice>& device = devicePair.second;
-            notifyAll(device->configure(now, mConfig, changes));
+            mPendingArgs += device->configure(now, mConfig, changes);
         }
     }
 
@@ -419,18 +419,13 @@
                   "There was no change in the pointer capture state.");
         } else {
             mCurrentPointerCaptureRequest = mConfig.pointerCaptureRequest;
-            mQueuedListener.notifyPointerCaptureChanged(
-                    {mContext.getNextId(), now, mCurrentPointerCaptureRequest});
+            mPendingArgs.emplace_back(
+                    NotifyPointerCaptureChangedArgs{mContext.getNextId(), now,
+                                                    mCurrentPointerCaptureRequest});
         }
     }
 }
 
-void InputReader::notifyAll(std::list<NotifyArgs>&& argsList) {
-    for (const NotifyArgs& args : argsList) {
-        mQueuedListener.notify(args);
-    }
-}
-
 void InputReader::updateGlobalMetaStateLocked() {
     mGlobalMetaState = 0;
 
@@ -690,7 +685,7 @@
 
     InputDevice* device = findInputDeviceLocked(deviceId);
     if (device) {
-        notifyAll(device->vibrate(sequence, repeat, token));
+        mPendingArgs += device->vibrate(sequence, repeat, token);
     }
 }
 
@@ -699,7 +694,7 @@
 
     InputDevice* device = findInputDeviceLocked(deviceId);
     if (device) {
-        notifyAll(device->cancelVibrate(token));
+        mPendingArgs += device->cancelVibrate(token);
     }
 }
 
diff --git a/services/inputflinger/reader/include/InputReader.h b/services/inputflinger/reader/include/InputReader.h
index 01ec7c1..e21715e 100644
--- a/services/inputflinger/reader/include/InputReader.h
+++ b/services/inputflinger/reader/include/InputReader.h
@@ -174,7 +174,14 @@
     // in parallel to passing it to the InputReader.
     std::shared_ptr<EventHubInterface> mEventHub;
     sp<InputReaderPolicyInterface> mPolicy;
-    QueuedInputListener mQueuedListener;
+
+    // The next stage that should receive the events generated inside InputReader.
+    InputListenerInterface& mNextListener;
+    // As various events are generated inside InputReader, they are stored inside this list. The
+    // list can only be accessed with the lock, so the events inside it are well-ordered.
+    // Once the reader is done working, these events will be swapped into a temporary storage and
+    // sent to the 'mNextListener' without holding the lock.
+    std::list<NotifyArgs> mPendingArgs GUARDED_BY(mLock);
 
     InputReaderConfiguration mConfig GUARDED_BY(mLock);
 
@@ -242,8 +249,6 @@
     ConfigurationChanges mConfigurationChangesToRefresh GUARDED_BY(mLock);
     void refreshConfigurationLocked(ConfigurationChanges changes) REQUIRES(mLock);
 
-    void notifyAll(std::list<NotifyArgs>&& argsList);
-
     PointerCaptureRequest mCurrentPointerCaptureRequest GUARDED_BY(mLock);
 
     // state queries
diff --git a/services/inputflinger/tests/InputReader_test.cpp b/services/inputflinger/tests/InputReader_test.cpp
index d1c3f7d..9ccd965 100644
--- a/services/inputflinger/tests/InputReader_test.cpp
+++ b/services/inputflinger/tests/InputReader_test.cpp
@@ -1334,19 +1334,8 @@
         mFakePolicy = sp<FakeInputReaderPolicy>::make();
         mFakePointerController = std::make_shared<FakePointerController>();
         mFakePolicy->setPointerController(mFakePointerController);
-        mTestListener = std::make_unique<TestInputListener>(/*eventHappenedTimeout=*/2000ms,
-                                                            /*eventDidNotHappenTimeout=*/30ms);
 
-        mReader = std::make_unique<InputReader>(std::make_shared<EventHub>(), mFakePolicy,
-                                                *mTestListener);
-        ASSERT_EQ(mReader->start(), OK);
-
-        // Since this test is run on a real device, all the input devices connected
-        // to the test device will show up in mReader. We wait for those input devices to
-        // show up before beginning the tests.
-        ASSERT_NO_FATAL_FAILURE(mFakePolicy->assertInputDevicesChanged());
-        ASSERT_NO_FATAL_FAILURE(mTestListener->assertNotifyInputDevicesChangedWasCalled());
-        ASSERT_NO_FATAL_FAILURE(mTestListener->assertNotifyConfigurationChangedWasCalled());
+        setupInputReader();
     }
 
     void TearDown() override {
@@ -1367,6 +1356,22 @@
                                       });
         return it != inputDevices.end() ? std::make_optional(*it) : std::nullopt;
     }
+
+    void setupInputReader() {
+        mTestListener = std::make_unique<TestInputListener>(/*eventHappenedTimeout=*/2000ms,
+                                                            /*eventDidNotHappenTimeout=*/30ms);
+
+        mReader = std::make_unique<InputReader>(std::make_shared<EventHub>(), mFakePolicy,
+                                                *mTestListener);
+        ASSERT_EQ(mReader->start(), OK);
+
+        // This test runs on a real device, so every input device already connected to it will
+        // be reported by mReader. Wait for the initial devices-changed and configuration-changed
+        // notifications before starting the test.
+        ASSERT_NO_FATAL_FAILURE(mFakePolicy->assertInputDevicesChanged());
+        ASSERT_NO_FATAL_FAILURE(mTestListener->assertNotifyInputDevicesChangedWasCalled());
+        ASSERT_NO_FATAL_FAILURE(mTestListener->assertNotifyConfigurationChangedWasCalled());
+    }
 };
 
 TEST_F(InputReaderIntegrationTest, TestInvalidDevice) {
@@ -1476,6 +1481,46 @@
             AllOf(UP, WithKeyCode(AKEYCODE_STYLUS_BUTTON_TERTIARY))));
 }
 
+TEST_F(InputReaderIntegrationTest, KeyboardWithStylusButtons) {
+    std::unique_ptr<UinputKeyboard> keyboard =
+            createUinputDevice<UinputKeyboard>("KeyboardWithStylusButtons", /*productId=*/99,
+                                               std::initializer_list<int>{KEY_Q, KEY_W, KEY_E,
+                                                                          KEY_R, KEY_T, KEY_Y,
+                                                                          BTN_STYLUS, BTN_STYLUS2,
+                                                                          BTN_STYLUS3});
+    ASSERT_NO_FATAL_FAILURE(mFakePolicy->assertInputDevicesChanged());
+
+    const auto device = findDeviceByName(keyboard->getName());
+    ASSERT_TRUE(device.has_value());
+
+    // An alphabetic keyboard that also has stylus buttons must stay a plain alphabetic keyboard.
+    ASSERT_EQ(AINPUT_SOURCE_KEYBOARD, device->getSources())
+            << "Unexpected source " << inputEventSourceToString(device->getSources()).c_str();
+    ASSERT_EQ(AINPUT_KEYBOARD_TYPE_ALPHABETIC, device->getKeyboardType());
+}
+
+TEST_F(InputReaderIntegrationTest, HidUsageKeyboardIsNotAStylus) {
+    // Create a uinput keyboard that can report HID usage codes (the hid-input driver reports them
+    // as the value of EV_MSC MSC_SCAN events), configured here with only the volume keys.
+    std::unique_ptr<UinputKeyboardWithHidUsage> keyboard =
+            createUinputDevice<UinputKeyboardWithHidUsage>(
+                    std::initializer_list<int>{KEY_VOLUMEUP, KEY_VOLUMEDOWN});
+    ASSERT_NO_FATAL_FAILURE(mFakePolicy->assertInputDevicesChanged());
+
+    const auto device = findDeviceByName(keyboard->getName());
+    ASSERT_TRUE(device.has_value());
+
+    ASSERT_EQ(AINPUT_SOURCE_KEYBOARD, device->getSources())
+            << "Unexpected source " << inputEventSourceToString(device->getSources()).c_str();
+
+    // Being able to report HID usage codes must not make the device claim stylus buttons, so
+    // hasKeys() should leave the flag for AKEYCODE_STYLUS_BUTTON_PRIMARY cleared.
+    const std::vector<int> keycodes{AKEYCODE_STYLUS_BUTTON_PRIMARY};
+    uint8_t outFlags[] = {0};
+    ASSERT_TRUE(mReader->hasKeys(device->getId(), AINPUT_SOURCE_KEYBOARD, keycodes, outFlags));
+    ASSERT_EQ(0, outFlags[0]) << "Keyboard should not have stylus button";
+}
+
 /**
  * The Steam controller sends BTN_GEAR_DOWN and BTN_GEAR_UP for the two "paddle" buttons
  * on the back. In this test, we make sure that BTN_GEAR_DOWN / BTN_WHEEL and BTN_GEAR_UP
@@ -1500,7 +1545,7 @@
 
 // --- TouchIntegrationTest ---
 
-class TouchIntegrationTest : public InputReaderIntegrationTest {
+class BaseTouchIntegrationTest : public InputReaderIntegrationTest {
 protected:
     const std::string UNIQUE_ID = "local:0";
 
@@ -1545,7 +1590,55 @@
     InputDeviceInfo mDeviceInfo;
 };
 
-TEST_F(TouchIntegrationTest, MultiTouchDeviceSource) {
+enum class TouchIntegrationTestDisplays { DISPLAY_INTERNAL, DISPLAY_INPUT_PORT, DISPLAY_UNIQUE_ID };
+
+class TouchIntegrationTest : public BaseTouchIntegrationTest,
+                             public testing::WithParamInterface<TouchIntegrationTestDisplays> {
+protected:
+    static constexpr std::optional<uint8_t> DISPLAY_PORT = 0;
+    const std::string INPUT_PORT = "uinput_touch/input0";
+
+    void SetUp() override {
+#if !defined(__ANDROID__)
+        GTEST_SKIP();
+#endif
+        if (GetParam() == TouchIntegrationTestDisplays::DISPLAY_INTERNAL) {
+            BaseTouchIntegrationTest::SetUp();
+            return;
+        }
+
+        // Set up the policy with an input-port or unique-ID association to the display.
+        bool isInputPortAssociation =
+                GetParam() == TouchIntegrationTestDisplays::DISPLAY_INPUT_PORT;
+
+        mFakePolicy = sp<FakeInputReaderPolicy>::make();
+        if (isInputPortAssociation) {
+            mFakePolicy->addInputPortAssociation(INPUT_PORT, DISPLAY_PORT.value());
+        } else {
+            mFakePolicy->addInputUniqueIdAssociation(INPUT_PORT, UNIQUE_ID);
+        }
+        mFakePointerController = std::make_shared<FakePointerController>();
+        mFakePolicy->setPointerController(mFakePointerController);
+
+        ASSERT_NO_FATAL_FAILURE(InputReaderIntegrationTest::setupInputReader());
+
+        mDevice = createUinputDevice<UinputTouchScreen>(Rect(0, 0, DISPLAY_WIDTH, DISPLAY_HEIGHT),
+                                                        INPUT_PORT);
+        ASSERT_NO_FATAL_FAILURE(mFakePolicy->assertInputDevicesChanged());
+
+        // Add a display linked to a physical port or UniqueId.
+        setDisplayInfoAndReconfigure(DISPLAY_ID, DISPLAY_WIDTH, DISPLAY_HEIGHT, ui::ROTATION_0,
+                                     UNIQUE_ID, isInputPortAssociation ? DISPLAY_PORT : NO_PORT,
+                                     ViewportType::INTERNAL);
+        ASSERT_NO_FATAL_FAILURE(mFakePolicy->assertInputDevicesChanged());
+        ASSERT_NO_FATAL_FAILURE(mTestListener->assertNotifyConfigurationChangedWasCalled());
+        const auto info = findDeviceByName(mDevice->getName());
+        ASSERT_TRUE(info);
+        mDeviceInfo = *info;
+    }
+};
+
+TEST_P(TouchIntegrationTest, MultiTouchDeviceSource) {
     // The UinputTouchScreen is an MT device that supports MT_TOOL_TYPE and also supports stylus
     // buttons. It should show up as a touchscreen, stylus, and keyboard (for reporting button
     // presses).
@@ -1553,7 +1646,7 @@
               mDeviceInfo.getSources());
 }
 
-TEST_F(TouchIntegrationTest, InputEvent_ProcessSingleTouch) {
+TEST_P(TouchIntegrationTest, InputEvent_ProcessSingleTouch) {
     NotifyMotionArgs args;
     const Point centerPoint = mDevice->getCenterPoint();
 
@@ -1577,7 +1670,7 @@
     ASSERT_EQ(AMOTION_EVENT_ACTION_UP, args.action);
 }
 
-TEST_F(TouchIntegrationTest, InputEvent_ProcessMultiTouch) {
+TEST_P(TouchIntegrationTest, InputEvent_ProcessMultiTouch) {
     NotifyMotionArgs args;
     const Point centerPoint = mDevice->getCenterPoint();
 
@@ -1633,7 +1726,7 @@
  * palms, and wants to cancel Pointer 1, then it is safe to simply drop POINTER_1_UP event without
  * losing information about non-palm pointers.
  */
-TEST_F(TouchIntegrationTest, MultiTouch_PointerMoveAndSecondPointerUp) {
+TEST_P(TouchIntegrationTest, MultiTouch_PointerMoveAndSecondPointerUp) {
     NotifyMotionArgs args;
     const Point centerPoint = mDevice->getCenterPoint();
 
@@ -1676,7 +1769,7 @@
  * In this scenario, the movement of the second pointer just prior to liftoff is ignored, and never
  * gets sent to the listener.
  */
-TEST_F(TouchIntegrationTest, MultiTouch_PointerMoveAndSecondPointerMoveAndUp) {
+TEST_P(TouchIntegrationTest, MultiTouch_PointerMoveAndSecondPointerMoveAndUp) {
     NotifyMotionArgs args;
     const Point centerPoint = mDevice->getCenterPoint();
 
@@ -1716,7 +1809,7 @@
     assertReceivedMotion(AMOTION_EVENT_ACTION_MOVE, {centerPoint + Point(5, 5)});
 }
 
-TEST_F(TouchIntegrationTest, InputEvent_ProcessPalm) {
+TEST_P(TouchIntegrationTest, InputEvent_ProcessPalm) {
     NotifyMotionArgs args;
     const Point centerPoint = mDevice->getCenterPoint();
 
@@ -1767,7 +1860,7 @@
     ASSERT_EQ(AMOTION_EVENT_ACTION_UP, args.action);
 }
 
-TEST_F(TouchIntegrationTest, NotifiesPolicyWhenStylusGestureStarted) {
+TEST_P(TouchIntegrationTest, NotifiesPolicyWhenStylusGestureStarted) {
     const Point centerPoint = mDevice->getCenterPoint();
 
     // Send down with the pen tool selected. The policy should be notified of the stylus presence.
@@ -1819,19 +1912,24 @@
     ASSERT_NO_FATAL_FAILURE(mFakePolicy->assertStylusGestureNotified(mDeviceInfo.getId()));
 }
 
+INSTANTIATE_TEST_SUITE_P(TouchIntegrationTestDisplayVariants, TouchIntegrationTest,
+                         testing::Values(TouchIntegrationTestDisplays::DISPLAY_INTERNAL,
+                                         TouchIntegrationTestDisplays::DISPLAY_INPUT_PORT,
+                                         TouchIntegrationTestDisplays::DISPLAY_UNIQUE_ID));
+
 // --- StylusButtonIntegrationTest ---
 
 // Verify the behavior of button presses reported by various kinds of styluses, including buttons
 // reported by the touchscreen's device, by a fused external stylus, and by an un-fused external
 // stylus.
 template <typename UinputStylusDevice>
-class StylusButtonIntegrationTest : public TouchIntegrationTest {
+class StylusButtonIntegrationTest : public BaseTouchIntegrationTest {
 protected:
     void SetUp() override {
 #if !defined(__ANDROID__)
         GTEST_SKIP();
 #endif
-        TouchIntegrationTest::SetUp();
+        BaseTouchIntegrationTest::SetUp();
         mTouchscreen = mDevice.get();
         mTouchscreenInfo = mDeviceInfo;
 
@@ -1869,8 +1967,8 @@
     std::unique_ptr<UinputStylusDevice> mStylusDeviceLifecycleTracker{};
 
     // Hide the base class's device to expose it with a different name for readability.
-    using TouchIntegrationTest::mDevice;
-    using TouchIntegrationTest::mDeviceInfo;
+    using BaseTouchIntegrationTest::mDevice;
+    using BaseTouchIntegrationTest::mDeviceInfo;
 };
 
 using StylusButtonIntegrationTestTypes =
@@ -2122,7 +2220,7 @@
 // Verify the behavior of an external stylus. An external stylus can report pressure or button
 // data independently of the touchscreen, which is then sent as a MotionEvent as part of an
 // ongoing stylus gesture that is being emitted by the touchscreen.
-using ExternalStylusIntegrationTest = TouchIntegrationTest;
+using ExternalStylusIntegrationTest = BaseTouchIntegrationTest;
 
 TEST_F(ExternalStylusIntegrationTest, DISABLED_FusedExternalStylusPressureReported) {
     const Point centerPoint = mDevice->getCenterPoint();
diff --git a/services/inputflinger/tests/UinputDevice.cpp b/services/inputflinger/tests/UinputDevice.cpp
index 97a2614..e8aaa18 100644
--- a/services/inputflinger/tests/UinputDevice.cpp
+++ b/services/inputflinger/tests/UinputDevice.cpp
@@ -157,12 +157,25 @@
     injectEvent(EV_SYN, SYN_REPORT, 0);
 }
 
+// --- UinputKeyboardWithHidUsage ---
+
+UinputKeyboardWithHidUsage::UinputKeyboardWithHidUsage(std::initializer_list<int> keys)
+      : UinputKeyboard(DEVICE_NAME, PRODUCT_ID, keys) {}
+
+void UinputKeyboardWithHidUsage::configureDevice(int fd, uinput_user_dev* device) {
+    UinputKeyboard::configureDevice(fd, device);
+
+    ioctl(fd, UI_SET_EVBIT, EV_MSC);
+    ioctl(fd, UI_SET_MSCBIT, MSC_SCAN);
+}
+
 // --- UinputTouchScreen ---
 
-UinputTouchScreen::UinputTouchScreen(const Rect& size)
+UinputTouchScreen::UinputTouchScreen(const Rect& size, const std::string& physicalPort)
       : UinputKeyboard(DEVICE_NAME, PRODUCT_ID,
                        {BTN_TOUCH, BTN_TOOL_PEN, BTN_STYLUS, BTN_STYLUS2, BTN_STYLUS3}),
-        mSize(size) {}
+        mSize(size),
+        mPhysicalPort(physicalPort) {}
 
 void UinputTouchScreen::configureDevice(int fd, uinput_user_dev* device) {
     UinputKeyboard::configureDevice(fd, device);
@@ -177,6 +190,9 @@
     ioctl(fd, UI_SET_ABSBIT, ABS_MT_TRACKING_ID);
     ioctl(fd, UI_SET_ABSBIT, ABS_MT_TOOL_TYPE);
     ioctl(fd, UI_SET_PROPBIT, INPUT_PROP_DIRECT);
+    if (!mPhysicalPort.empty()) {
+        ioctl(fd, UI_SET_PHYS, mPhysicalPort.c_str());
+    }
 
     device->absmin[ABS_MT_SLOT] = RAW_SLOT_MIN;
     device->absmax[ABS_MT_SLOT] = RAW_SLOT_MAX;
diff --git a/services/inputflinger/tests/UinputDevice.h b/services/inputflinger/tests/UinputDevice.h
index 51e331d..f5507ec 100644
--- a/services/inputflinger/tests/UinputDevice.h
+++ b/services/inputflinger/tests/UinputDevice.h
@@ -165,13 +165,30 @@
     explicit UinputExternalStylusWithPressure();
 };
 
+// --- UinputKeyboardWithHidUsage ---
+// A keyboard that supports EV_MSC MSC_SCAN through which it can report HID usage codes.
+
+class UinputKeyboardWithHidUsage : public UinputKeyboard {
+public:
+    static constexpr const char* DEVICE_NAME = "Test Uinput Keyboard With Usage";
+    static constexpr int16_t PRODUCT_ID = 47;
+
+    template <class D, class... Ts>
+    friend std::unique_ptr<D> createUinputDevice(Ts... args);
+
+protected:
+    explicit UinputKeyboardWithHidUsage(std::initializer_list<int> keys);
+
+    void configureDevice(int fd, uinput_user_dev* device) override;
+};
+
 // --- UinputTouchScreen ---
 
 // A multi-touch touchscreen device with specific size that also supports styluses.
 class UinputTouchScreen : public UinputKeyboard {
 public:
     static constexpr const char* DEVICE_NAME = "Test Uinput Touch Screen";
-    static constexpr int16_t PRODUCT_ID = 47;
+    static constexpr int16_t PRODUCT_ID = 48;
 
     static const int32_t RAW_TOUCH_MIN = 0;
     static const int32_t RAW_TOUCH_MAX = 31;
@@ -197,11 +214,12 @@
     const Point getCenterPoint();
 
 protected:
-    explicit UinputTouchScreen(const Rect& size);
+    explicit UinputTouchScreen(const Rect& size, const std::string& physicalPort = "");
 
 private:
     void configureDevice(int fd, uinput_user_dev* device) override;
     const Rect mSize;
+    const std::string mPhysicalPort;
 };
 
 } // namespace android
diff --git a/services/surfaceflinger/CompositionEngine/include/compositionengine/impl/planner/TexturePool.h b/services/surfaceflinger/CompositionEngine/include/compositionengine/impl/planner/TexturePool.h
index 9f6141a..d607c75 100644
--- a/services/surfaceflinger/CompositionEngine/include/compositionengine/impl/planner/TexturePool.h
+++ b/services/surfaceflinger/CompositionEngine/include/compositionengine/impl/planner/TexturePool.h
@@ -66,7 +66,7 @@
     TexturePool(renderengine::RenderEngine& renderEngine)
           : mRenderEngine(renderEngine), mEnabled(false) {}
 
-    virtual ~TexturePool();
+    virtual ~TexturePool() = default;
 
     // Sets the display size for the texture pool.
     // This will trigger a reallocation for all remaining textures in the pool.
@@ -83,10 +83,11 @@
     // be held by the pool. This is useful when the active display changes.
     void setEnabled(bool enable);
 
-    void dump(std::string& out) const EXCLUDES(mMutex);
+    void dump(std::string& out) const;
 
 protected:
     // Proteted visibility so that they can be used for testing
+    const static constexpr size_t kMinPoolSize = 3;
     const static constexpr size_t kMaxPoolSize = 4;
 
     struct Entry {
@@ -95,20 +96,16 @@
     };
 
     std::deque<Entry> mPool;
-    std::future<std::shared_ptr<renderengine::ExternalTexture>> mGenTextureFuture;
 
 private:
-    std::shared_ptr<renderengine::ExternalTexture> genTexture(ui::Size size);
+    std::shared_ptr<renderengine::ExternalTexture> genTexture();
     // Returns a previously borrowed texture to the pool.
     void returnTexture(std::shared_ptr<renderengine::ExternalTexture>&& texture,
                        const sp<Fence>& fence);
-    void genTextureAsyncIfNeeded() REQUIRES(mMutex);
-    void resetPool() REQUIRES(mMutex);
-    renderengine::RenderEngine& mRenderEngine GUARDED_BY(mRenderEngineMutex);
-    ui::Size mSize GUARDED_BY(mMutex);
+    void allocatePool();
+    renderengine::RenderEngine& mRenderEngine;
+    ui::Size mSize;
     bool mEnabled;
-    mutable std::mutex mMutex;
-    mutable std::mutex mRenderEngineMutex;
 };
 
 } // namespace android::compositionengine::impl::planner
diff --git a/services/surfaceflinger/CompositionEngine/src/planner/TexturePool.cpp b/services/surfaceflinger/CompositionEngine/src/planner/TexturePool.cpp
index 10f58ce..54ecb56 100644
--- a/services/surfaceflinger/CompositionEngine/src/planner/TexturePool.cpp
+++ b/services/surfaceflinger/CompositionEngine/src/planner/TexturePool.cpp
@@ -25,61 +25,31 @@
 
 namespace android::compositionengine::impl::planner {
 
-TexturePool::~TexturePool() {
-    if (mGenTextureFuture.valid()) {
-        mGenTextureFuture.get();
-    }
-}
-
-void TexturePool::resetPool() {
-    if (mGenTextureFuture.valid()) {
-        mGenTextureFuture.get();
-    }
+void TexturePool::allocatePool() {
     mPool.clear();
-    genTextureAsyncIfNeeded();
-}
-
-// Generate a new texture asynchronously so it will not require allocation on the main
-// thread.
-void TexturePool::genTextureAsyncIfNeeded() {
-    if (mEnabled && mSize.isValid() && !mGenTextureFuture.valid()) {
-        mGenTextureFuture = std::async(
-                std::launch::async, [&](ui::Size size) { return genTexture(size); }, mSize);
+    if (mEnabled && mSize.isValid()) {
+        mPool.resize(kMinPoolSize);
+        std::generate_n(mPool.begin(), kMinPoolSize, [&]() {
+            return Entry{genTexture(), nullptr};
+        });
     }
 }
 
 void TexturePool::setDisplaySize(ui::Size size) {
-    std::lock_guard lock(mMutex);
     if (mSize == size) {
         return;
     }
     mSize = size;
-    resetPool();
+    allocatePool();
 }
 
 std::shared_ptr<TexturePool::AutoTexture> TexturePool::borrowTexture() {
     if (mPool.empty()) {
-        std::lock_guard lock(mMutex);
-        std::shared_ptr<TexturePool::AutoTexture> tex;
-        if (mGenTextureFuture.valid()) {
-            tex = std::make_shared<AutoTexture>(*this, mGenTextureFuture.get(), nullptr);
-        } else {
-            tex = std::make_shared<AutoTexture>(*this, genTexture(mSize), nullptr);
-        }
-        // Speculatively generate a new texture, so that the next call does not need
-        // to wait for allocation.
-        genTextureAsyncIfNeeded();
-        return tex;
+        return std::make_shared<AutoTexture>(*this, genTexture(), nullptr);
     }
 
     const auto entry = mPool.front();
     mPool.pop_front();
-    if (mPool.empty()) {
-        std::lock_guard lock(mMutex);
-        // Similiarly generate a new texture when lending out the last entry, so that
-        // the next call does not need to wait for allocation.
-        genTextureAsyncIfNeeded();
-    }
     return std::make_shared<AutoTexture>(*this, entry.texture, entry.fence);
 }
 
@@ -90,8 +60,6 @@
         return;
     }
 
-    std::lock_guard lock(mMutex);
-
     // Or the texture on the floor if the pool is no longer tracking textures of the same size.
     if (static_cast<int32_t>(texture->getBuffer()->getWidth()) != mSize.getWidth() ||
         static_cast<int32_t>(texture->getBuffer()->getHeight()) != mSize.getHeight()) {
@@ -112,14 +80,13 @@
     mPool.push_back({std::move(texture), fence});
 }
 
-std::shared_ptr<renderengine::ExternalTexture> TexturePool::genTexture(ui::Size size) {
-    std::lock_guard lock(mRenderEngineMutex);
-    LOG_ALWAYS_FATAL_IF(!size.isValid(), "Attempted to generate texture with invalid size");
+std::shared_ptr<renderengine::ExternalTexture> TexturePool::genTexture() {
+    LOG_ALWAYS_FATAL_IF(!mSize.isValid(), "Attempted to generate texture with invalid size");
     return std::make_shared<
             renderengine::impl::
                     ExternalTexture>(sp<GraphicBuffer>::
-                                             make(static_cast<uint32_t>(size.getWidth()),
-                                                  static_cast<uint32_t>(size.getHeight()),
+                                             make(static_cast<uint32_t>(mSize.getWidth()),
+                                                  static_cast<uint32_t>(mSize.getHeight()),
                                                   HAL_PIXEL_FORMAT_RGBA_8888, 1U,
                                                   static_cast<uint64_t>(
                                                           GraphicBuffer::USAGE_HW_RENDER |
@@ -133,16 +100,13 @@
 
 void TexturePool::setEnabled(bool enabled) {
     mEnabled = enabled;
-
-    std::lock_guard lock(mMutex);
-    resetPool();
+    allocatePool();
 }
 
 void TexturePool::dump(std::string& out) const {
-    std::lock_guard lock(mMutex);
     base::StringAppendF(&out,
                         "TexturePool (%s) has %zu buffers of size [%" PRId32 ", %" PRId32 "]\n",
                         mEnabled ? "enabled" : "disabled", mPool.size(), mSize.width, mSize.height);
 }
 
 } // namespace android::compositionengine::impl::planner
diff --git a/services/surfaceflinger/CompositionEngine/tests/planner/TexturePoolTest.cpp b/services/surfaceflinger/CompositionEngine/tests/planner/TexturePoolTest.cpp
index 494a9f4..6fc90fe 100644
--- a/services/surfaceflinger/CompositionEngine/tests/planner/TexturePoolTest.cpp
+++ b/services/surfaceflinger/CompositionEngine/tests/planner/TexturePoolTest.cpp
@@ -32,9 +32,9 @@
 public:
     TestableTexturePool(renderengine::RenderEngine& renderEngine) : TexturePool(renderEngine) {}
 
+    size_t getMinPoolSize() const { return kMinPoolSize; }
     size_t getMaxPoolSize() const { return kMaxPoolSize; }
     size_t getPoolSize() const { return mPool.size(); }
-    size_t isGenTextureFutureValid() const { return mGenTextureFuture.valid(); }
 };
 
 struct TexturePoolTest : public testing::Test {
@@ -56,8 +56,16 @@
     TestableTexturePool mTexturePool = TestableTexturePool(mRenderEngine);
 };
 
-TEST_F(TexturePoolTest, preallocatesZeroSizePool) {
-    EXPECT_EQ(mTexturePool.getPoolSize(), 0u);
+TEST_F(TexturePoolTest, preallocatesMinPool) {
+    EXPECT_EQ(mTexturePool.getMinPoolSize(), mTexturePool.getPoolSize());
+}
+
+TEST_F(TexturePoolTest, doesNotAllocateBeyondMinPool) {
+    for (size_t i = 0; i < mTexturePool.getMinPoolSize() + 1; i++) {
+        auto texture = mTexturePool.borrowTexture();
+    }
+
+    EXPECT_EQ(mTexturePool.getMinPoolSize(), mTexturePool.getPoolSize());
 }
 
 TEST_F(TexturePoolTest, cyclesUpToMaxPoolSize) {
@@ -111,10 +119,10 @@
               static_cast<int32_t>(texture->get()->getBuffer()->getHeight()));
     mTexturePool.setDisplaySize(kDisplaySizeTwo);
 
-    EXPECT_EQ(mTexturePool.getPoolSize(), 0u);
+    EXPECT_EQ(mTexturePool.getMinPoolSize(), mTexturePool.getPoolSize());
     texture.reset();
     // When the texture is returned to the pool, the pool now destroys it.
-    EXPECT_EQ(mTexturePool.getPoolSize(), 0u);
+    EXPECT_EQ(mTexturePool.getMinPoolSize(), mTexturePool.getPoolSize());
 
     texture = mTexturePool.borrowTexture();
     EXPECT_EQ(kDisplaySizeTwo.getWidth(),
@@ -124,11 +132,14 @@
 }
 
 TEST_F(TexturePoolTest, freesBuffersWhenDisabled) {
+    EXPECT_EQ(mTexturePool.getPoolSize(), mTexturePool.getMinPoolSize());
+
     std::deque<std::shared_ptr<TexturePool::AutoTexture>> textures;
-    for (size_t i = 0; i < 2; i++) {
+    for (size_t i = 0; i < mTexturePool.getMinPoolSize() - 1; i++) {
         textures.emplace_back(mTexturePool.borrowTexture());
     }
 
+    EXPECT_EQ(mTexturePool.getPoolSize(), 1u);
     mTexturePool.setEnabled(false);
     EXPECT_EQ(mTexturePool.getPoolSize(), 0u);
 
@@ -137,11 +148,12 @@
 }
 
 TEST_F(TexturePoolTest, doesNotHoldBuffersWhenDisabled) {
+    EXPECT_EQ(mTexturePool.getPoolSize(), mTexturePool.getMinPoolSize());
     mTexturePool.setEnabled(false);
     EXPECT_EQ(mTexturePool.getPoolSize(), 0u);
 
     std::deque<std::shared_ptr<TexturePool::AutoTexture>> textures;
-    for (size_t i = 0; i < 2; i++) {
+    for (size_t i = 0; i < mTexturePool.getMinPoolSize() - 1; i++) {
         textures.emplace_back(mTexturePool.borrowTexture());
     }
 
@@ -150,13 +162,12 @@
     EXPECT_EQ(mTexturePool.getPoolSize(), 0u);
 }
 
-TEST_F(TexturePoolTest, genFutureWhenReEnabled) {
+TEST_F(TexturePoolTest, reallocatesWhenReEnabled) {
+    EXPECT_EQ(mTexturePool.getPoolSize(), mTexturePool.getMinPoolSize());
     mTexturePool.setEnabled(false);
     EXPECT_EQ(mTexturePool.getPoolSize(), 0u);
-    EXPECT_FALSE(mTexturePool.isGenTextureFutureValid());
     mTexturePool.setEnabled(true);
-    EXPECT_EQ(mTexturePool.getPoolSize(), 0u);
-    EXPECT_TRUE(mTexturePool.isGenTextureFutureValid());
+    EXPECT_EQ(mTexturePool.getPoolSize(), mTexturePool.getMinPoolSize());
 }
 
 } // namespace
diff --git a/services/surfaceflinger/Scheduler/RefreshRateSelector.cpp b/services/surfaceflinger/Scheduler/RefreshRateSelector.cpp
index c44e22e..6b7d7df 100644
--- a/services/surfaceflinger/Scheduler/RefreshRateSelector.cpp
+++ b/services/surfaceflinger/Scheduler/RefreshRateSelector.cpp
@@ -148,8 +148,8 @@
 } // namespace
 
 auto RefreshRateSelector::createFrameRateModes(
-        std::function<bool(const DisplayMode&)>&& filterModes, const FpsRange& renderRange) const
-        -> std::vector<FrameRateMode> {
+        const Policy& policy, std::function<bool(const DisplayMode&)>&& filterModes,
+        const FpsRange& renderRange) const -> std::vector<FrameRateMode> {
     struct Key {
         Fps fps;
         int32_t group;
@@ -202,11 +202,25 @@
                 ALOGV("%s: including %s (%s)", __func__, to_string(fps).c_str(),
                       to_string(mode->getFps()).c_str());
             } else {
-                // We might need to update the map as we found a lower refresh rate
-                if (isStrictlyLess(mode->getFps(), existingIter->second->second->getFps())) {
+                // If the primary physical range is a single rate, prefer to stay in that rate
+                // even if there is a lower physical refresh rate available. This would cause more
+                // cases to stay within the primary physical range
+                const Fps existingModeFps = existingIter->second->second->getFps();
+                const bool existingModeIsPrimaryRange = policy.primaryRangeIsSingleRate() &&
+                        policy.primaryRanges.physical.includes(existingModeFps);
+                const bool newModeIsPrimaryRange = policy.primaryRangeIsSingleRate() &&
+                        policy.primaryRanges.physical.includes(mode->getFps());
+                if (newModeIsPrimaryRange == existingModeIsPrimaryRange) {
+                    // We might need to update the map as we found a lower refresh rate
+                    if (isStrictlyLess(mode->getFps(), existingModeFps)) {
+                        existingIter->second = it;
+                        ALOGV("%s: changing %s (%s) as we found a lower physical rate", __func__,
+                              to_string(fps).c_str(), to_string(mode->getFps()).c_str());
+                    }
+                } else if (newModeIsPrimaryRange) {
                     existingIter->second = it;
-                    ALOGV("%s: changing %s (%s)", __func__, to_string(fps).c_str(),
-                          to_string(mode->getFps()).c_str());
+                    ALOGV("%s: changing %s (%s) to stay in the primary range", __func__,
+                          to_string(fps).c_str(), to_string(mode->getFps()).c_str());
                 }
             }
         }
@@ -500,10 +514,8 @@
     // If the primary range consists of a single refresh rate then we can only
     // move out the of range if layers explicitly request a different refresh
     // rate.
-    const bool primaryRangeIsSingleRate =
-            isApproxEqual(policy->primaryRanges.physical.min, policy->primaryRanges.physical.max);
-
-    if (!signals.touch && signals.idle && !(primaryRangeIsSingleRate && hasExplicitVoteLayers)) {
+    if (!signals.touch && signals.idle &&
+        !(policy->primaryRangeIsSingleRate() && hasExplicitVoteLayers)) {
         ALOGV("Idle");
         const auto ranking = rankFrameRates(activeMode.getGroup(), RefreshRateOrder::Ascending);
         ATRACE_FORMAT_INSTANT("%s (Idle)", to_string(ranking.front().frameRateMode.fps).c_str());
@@ -577,8 +589,11 @@
                 continue;
             }
 
-            const bool inPrimaryRange = policy->primaryRanges.render.includes(fps);
-            if ((primaryRangeIsSingleRate || !inPrimaryRange) &&
+            const bool inPrimaryPhysicalRange =
+                    policy->primaryRanges.physical.includes(modePtr->getFps());
+            const bool inPrimaryRenderRange = policy->primaryRanges.render.includes(fps);
+            if (((policy->primaryRangeIsSingleRate() && !inPrimaryPhysicalRange) ||
+                 !inPrimaryRenderRange) &&
                 !(layer.focused &&
                   (layer.vote == LayerVoteType::ExplicitDefault ||
                    layer.vote == LayerVoteType::ExplicitExact))) {
@@ -689,7 +704,7 @@
         return score.overallScore == 0;
     });
 
-    if (primaryRangeIsSingleRate) {
+    if (policy->primaryRangeIsSingleRate()) {
         // If we never scored any layers, then choose the rate from the primary
         // range instead of picking a random score from the app range.
         if (noLayerScore) {
@@ -1234,14 +1249,14 @@
                     (supportsFrameRateOverride() || ranges.render.includes(mode.getFps()));
         };
 
-        auto frameRateModes = createFrameRateModes(filterModes, ranges.render);
+        auto frameRateModes = createFrameRateModes(*policy, filterModes, ranges.render);
         if (frameRateModes.empty()) {
             ALOGW("No matching frame rate modes for %s range. policy: %s", rangeName,
                   policy->toString().c_str());
             // TODO(b/292105422): Ideally DisplayManager should not send render ranges smaller than
             // the min supported. See b/292047939.
             //  For not we just ignore the render ranges.
-            frameRateModes = createFrameRateModes(filterModes, {});
+            frameRateModes = createFrameRateModes(*policy, filterModes, {});
         }
         LOG_ALWAYS_FATAL_IF(frameRateModes.empty(),
                             "No matching frame rate modes for %s range even after ignoring the "
diff --git a/services/surfaceflinger/Scheduler/RefreshRateSelector.h b/services/surfaceflinger/Scheduler/RefreshRateSelector.h
index 7af8d03..b25919e 100644
--- a/services/surfaceflinger/Scheduler/RefreshRateSelector.h
+++ b/services/surfaceflinger/Scheduler/RefreshRateSelector.h
@@ -101,6 +101,11 @@
         }
 
         bool operator!=(const Policy& other) const { return !(*this == other); }
+
+        bool primaryRangeIsSingleRate() const {
+            return isApproxEqual(primaryRanges.physical.min, primaryRanges.physical.max);
+        }
+
         std::string toString() const;
     };
 
@@ -468,8 +473,8 @@
     }
 
     std::vector<FrameRateMode> createFrameRateModes(
-            std::function<bool(const DisplayMode&)>&& filterModes, const FpsRange&) const
-            REQUIRES(mLock);
+            const Policy&, std::function<bool(const DisplayMode&)>&& filterModes,
+            const FpsRange&) const REQUIRES(mLock);
 
     // The display modes of the active display. The DisplayModeIterators below are pointers into
     // this container, so must be invalidated whenever the DisplayModes change. The Policy below
diff --git a/services/surfaceflinger/ScreenCaptureOutput.cpp b/services/surfaceflinger/ScreenCaptureOutput.cpp
index 0103843..ef9b457 100644
--- a/services/surfaceflinger/ScreenCaptureOutput.cpp
+++ b/services/surfaceflinger/ScreenCaptureOutput.cpp
@@ -45,10 +45,9 @@
 std::shared_ptr<ScreenCaptureOutput> createScreenCaptureOutput(ScreenCaptureOutputArgs args) {
     std::shared_ptr<ScreenCaptureOutput> output = compositionengine::impl::createOutputTemplated<
             ScreenCaptureOutput, compositionengine::CompositionEngine, const RenderArea&,
-            const compositionengine::Output::ColorProfile&, bool>(args.compositionEngine,
-                                                                  args.renderArea,
-                                                                  args.colorProfile,
-                                                                  args.regionSampling);
+            const compositionengine::Output::ColorProfile&,
+            bool>(args.compositionEngine, args.renderArea, args.colorProfile, args.regionSampling,
+                  args.dimInGammaSpaceForEnhancedScreenshots);
     output->editState().isSecure = args.renderArea.isSecure();
     output->setCompositionEnabled(true);
     output->setLayerFilter({args.layerStack});
@@ -81,8 +80,11 @@
 
 ScreenCaptureOutput::ScreenCaptureOutput(
         const RenderArea& renderArea, const compositionengine::Output::ColorProfile& colorProfile,
-        bool regionSampling)
-      : mRenderArea(renderArea), mColorProfile(colorProfile), mRegionSampling(regionSampling) {}
+        bool regionSampling, bool dimInGammaSpaceForEnhancedScreenshots)
+      : mRenderArea(renderArea),
+        mColorProfile(colorProfile),
+        mRegionSampling(regionSampling),
+        mDimInGammaSpaceForEnhancedScreenshots(dimInGammaSpaceForEnhancedScreenshots) {}
 
 void ScreenCaptureOutput::updateColorProfile(const compositionengine::CompositionRefreshArgs&) {
     auto& outputState = editState();
@@ -95,6 +97,14 @@
     auto clientCompositionDisplay =
             compositionengine::impl::Output::generateClientCompositionDisplaySettings();
     clientCompositionDisplay.clip = mRenderArea.getSourceCrop();
+
+    auto renderIntent = static_cast<ui::RenderIntent>(clientCompositionDisplay.renderIntent);
+    if (mDimInGammaSpaceForEnhancedScreenshots && renderIntent != ui::RenderIntent::COLORIMETRIC &&
+        renderIntent != ui::RenderIntent::TONE_MAP_COLORIMETRIC) {
+        clientCompositionDisplay.dimmingStage =
+                aidl::android::hardware::graphics::composer3::DimmingStage::GAMMA_OETF;
+    }
+
     return clientCompositionDisplay;
 }
 
diff --git a/services/surfaceflinger/ScreenCaptureOutput.h b/services/surfaceflinger/ScreenCaptureOutput.h
index 159c2bf..fc095de 100644
--- a/services/surfaceflinger/ScreenCaptureOutput.h
+++ b/services/surfaceflinger/ScreenCaptureOutput.h
@@ -37,6 +37,7 @@
     float targetBrightness;
     bool regionSampling;
     bool treat170mAsSrgb;
+    bool dimInGammaSpaceForEnhancedScreenshots;
 };
 
 // ScreenCaptureOutput is used to compose a set of layers into a preallocated buffer.
@@ -47,7 +48,7 @@
 public:
     ScreenCaptureOutput(const RenderArea& renderArea,
                         const compositionengine::Output::ColorProfile& colorProfile,
-                        bool regionSampling);
+                        bool regionSampling, bool dimInGammaSpaceForEnhancedScreenshots);
 
     void updateColorProfile(const compositionengine::CompositionRefreshArgs&) override;
 
@@ -63,6 +64,7 @@
     const RenderArea& mRenderArea;
     const compositionengine::Output::ColorProfile& mColorProfile;
     const bool mRegionSampling;
+    const bool mDimInGammaSpaceForEnhancedScreenshots;
 };
 
 std::shared_ptr<ScreenCaptureOutput> createScreenCaptureOutput(ScreenCaptureOutputArgs);
diff --git a/services/surfaceflinger/SurfaceFlinger.cpp b/services/surfaceflinger/SurfaceFlinger.cpp
index 8c0a329..286222e 100644
--- a/services/surfaceflinger/SurfaceFlinger.cpp
+++ b/services/surfaceflinger/SurfaceFlinger.cpp
@@ -451,6 +451,9 @@
     property_get("debug.sf.treat_170m_as_sRGB", value, "0");
     mTreat170mAsSrgb = atoi(value);
 
+    property_get("debug.sf.dim_in_gamma_in_enhanced_screenshots", value, "0");
+    mDimInGammaSpaceForEnhancedScreenshots = atoi(value);
+
     mIgnoreHwcPhysicalDisplayOrientation =
             base::GetBoolProperty("debug.sf.ignore_hwc_physical_display_orientation"s, false);
 
@@ -7481,7 +7484,9 @@
                                         .displayBrightnessNits = displayBrightnessNits,
                                         .targetBrightness = targetBrightness,
                                         .regionSampling = regionSampling,
-                                        .treat170mAsSrgb = mTreat170mAsSrgb});
+                                        .treat170mAsSrgb = mTreat170mAsSrgb,
+                                        .dimInGammaSpaceForEnhancedScreenshots =
+                                                mDimInGammaSpaceForEnhancedScreenshots});
 
         const float colorSaturation = grayscale ? 0 : 1;
         compositionengine::CompositionRefreshArgs refreshArgs{
diff --git a/services/surfaceflinger/SurfaceFlinger.h b/services/surfaceflinger/SurfaceFlinger.h
index e3e72ed..f1989db 100644
--- a/services/surfaceflinger/SurfaceFlinger.h
+++ b/services/surfaceflinger/SurfaceFlinger.h
@@ -323,6 +323,11 @@
     // on this behavior to increase contrast for some media sources.
     bool mTreat170mAsSrgb = false;
 
+    // If true, then screenshots with an enhanced render intent will dim in gamma space.
+    // The purpose is to ensure that screenshots appear correct during system animations for devices
+    // that require that dimming must occur in gamma space.
+    bool mDimInGammaSpaceForEnhancedScreenshots = false;
+
     // Allows to ignore physical orientation provided through hwc API in favour of
     // 'ro.surface_flinger.primary_display_orientation'.
     // TODO(b/246793311): Clean up a temporary property
diff --git a/services/surfaceflinger/tests/unittests/DisplayDevice_InitiateModeChange.cpp b/services/surfaceflinger/tests/unittests/DisplayDevice_InitiateModeChange.cpp
index 60ad7a3..2d87ddd 100644
--- a/services/surfaceflinger/tests/unittests/DisplayDevice_InitiateModeChange.cpp
+++ b/services/surfaceflinger/tests/unittests/DisplayDevice_InitiateModeChange.cpp
@@ -78,7 +78,7 @@
               mDisplay->setDesiredActiveMode(
                       {scheduler::FrameRateMode{90_Hz, kMode90}, Event::None}));
     ASSERT_NE(std::nullopt, mDisplay->getDesiredActiveMode());
-    EXPECT_FRAME_RATE_MODE(kMode90, 90_Hz, mDisplay->getDesiredActiveMode()->modeOpt);
+    EXPECT_FRAME_RATE_MODE(kMode90, 90_Hz, *mDisplay->getDesiredActiveMode()->modeOpt);
     EXPECT_EQ(Event::None, mDisplay->getDesiredActiveMode()->event);
 
     // Setting another mode should be cached but return None
@@ -86,7 +86,7 @@
               mDisplay->setDesiredActiveMode(
                       {scheduler::FrameRateMode{120_Hz, kMode120}, Event::None}));
     ASSERT_NE(std::nullopt, mDisplay->getDesiredActiveMode());
-    EXPECT_FRAME_RATE_MODE(kMode120, 120_Hz, mDisplay->getDesiredActiveMode()->modeOpt);
+    EXPECT_FRAME_RATE_MODE(kMode120, 120_Hz, *mDisplay->getDesiredActiveMode()->modeOpt);
     EXPECT_EQ(Event::None, mDisplay->getDesiredActiveMode()->event);
 }
 
@@ -105,7 +105,7 @@
               mDisplay->setDesiredActiveMode(
                       {scheduler::FrameRateMode{90_Hz, kMode90}, Event::None}));
     ASSERT_NE(std::nullopt, mDisplay->getDesiredActiveMode());
-    EXPECT_FRAME_RATE_MODE(kMode90, 90_Hz, mDisplay->getDesiredActiveMode()->modeOpt);
+    EXPECT_FRAME_RATE_MODE(kMode90, 90_Hz, *mDisplay->getDesiredActiveMode()->modeOpt);
     EXPECT_EQ(Event::None, mDisplay->getDesiredActiveMode()->event);
 
     hal::VsyncPeriodChangeConstraints constraints{
@@ -136,7 +136,7 @@
               mDisplay->setDesiredActiveMode(
                       {scheduler::FrameRateMode{90_Hz, kMode90}, Event::None}));
     ASSERT_NE(std::nullopt, mDisplay->getDesiredActiveMode());
-    EXPECT_FRAME_RATE_MODE(kMode90, 90_Hz, mDisplay->getDesiredActiveMode()->modeOpt);
+    EXPECT_FRAME_RATE_MODE(kMode90, 90_Hz, *mDisplay->getDesiredActiveMode()->modeOpt);
     EXPECT_EQ(Event::None, mDisplay->getDesiredActiveMode()->event);
 
     hal::VsyncPeriodChangeConstraints constraints{
@@ -154,7 +154,7 @@
               mDisplay->setDesiredActiveMode(
                       {scheduler::FrameRateMode{120_Hz, kMode120}, Event::None}));
     ASSERT_NE(std::nullopt, mDisplay->getDesiredActiveMode());
-    EXPECT_FRAME_RATE_MODE(kMode120, 120_Hz, mDisplay->getDesiredActiveMode()->modeOpt);
+    EXPECT_FRAME_RATE_MODE(kMode120, 120_Hz, *mDisplay->getDesiredActiveMode()->modeOpt);
     EXPECT_EQ(Event::None, mDisplay->getDesiredActiveMode()->event);
 
     EXPECT_FRAME_RATE_MODE(kMode90, 90_Hz, *mDisplay->getUpcomingActiveMode().modeOpt);
diff --git a/services/surfaceflinger/tests/unittests/RefreshRateSelectorTest.cpp b/services/surfaceflinger/tests/unittests/RefreshRateSelectorTest.cpp
index aaf55fb..0397b99 100644
--- a/services/surfaceflinger/tests/unittests/RefreshRateSelectorTest.cpp
+++ b/services/surfaceflinger/tests/unittests/RefreshRateSelectorTest.cpp
@@ -3125,5 +3125,69 @@
                       {DisplayModeId(kModeId60), kLowerThanMin, kLowerThanMin}));
 }
 
+// b/296079213
+TEST_P(RefreshRateSelectorTest, frameRateOverrideInBlockingZone60_120) {
+    auto selector = createSelector(kModes_60_120, kModeId120);
+
+    const FpsRange pinnedTo120{120_Hz, 120_Hz};
+    const FpsRange upTo120{0_Hz, 120_Hz};
+    EXPECT_EQ(SetPolicyResult::Changed,
+              selector.setDisplayManagerPolicy(
+                      {kModeId120, {pinnedTo120, upTo120}, {upTo120, upTo120}}));
+
+    std::vector<LayerRequirement> layers = {{.weight = 1.f}};
+    auto& layer = layers.front();
+    layer.name = "30Hz ExplicitExactOrMultiple";
+    layer.desiredRefreshRate = 30_Hz;
+    layer.vote = LayerVoteType::ExplicitExactOrMultiple;
+
+    // The 120 Hz mode is expected either way; the expected frame rate drops to the layer's
+    // 30 Hz only when frame rate override is enabled.
+    const auto expectedFps = GetParam() == Config::FrameRateOverride::Enabled ? 30_Hz : 120_Hz;
+    EXPECT_FRAME_RATE_MODE(kMode120, expectedFps,
+                           selector.getBestScoredFrameRate(layers).frameRateMode);
+}
+
+TEST_P(RefreshRateSelectorTest, frameRateOverrideInBlockingZone60_90) {
+    auto selector = createSelector(kModes_60_90, kModeId90);
+
+    const FpsRange pinnedTo90{90_Hz, 90_Hz};
+    const FpsRange upTo90{0_Hz, 90_Hz};
+    EXPECT_EQ(SetPolicyResult::Changed,
+              selector.setDisplayManagerPolicy(
+                      {kModeId90, {pinnedTo90, upTo90}, {upTo90, upTo90}}));
+
+    std::vector<LayerRequirement> layers = {{.weight = 1.f}};
+    auto& layer = layers.front();
+    layer.name = "30Hz ExplicitExactOrMultiple";
+    layer.desiredRefreshRate = 30_Hz;
+    layer.vote = LayerVoteType::ExplicitExactOrMultiple;
+
+    // The 90 Hz mode is expected either way; the expected frame rate drops to the layer's
+    // 30 Hz only when frame rate override is enabled.
+    const auto expectedFps = GetParam() == Config::FrameRateOverride::Enabled ? 30_Hz : 90_Hz;
+    EXPECT_FRAME_RATE_MODE(kMode90, expectedFps,
+                           selector.getBestScoredFrameRate(layers).frameRateMode);
+}
+
+TEST_P(RefreshRateSelectorTest, frameRateOverrideInBlockingZone60_90_NonDivisor) {
+    auto selector = createSelector(kModes_60_90, kModeId90);
+
+    const FpsRange pinnedTo90{90_Hz, 90_Hz};
+    const FpsRange upTo90{0_Hz, 90_Hz};
+    EXPECT_EQ(SetPolicyResult::Changed,
+              selector.setDisplayManagerPolicy(
+                      {kModeId90, {pinnedTo90, upTo90}, {upTo90, upTo90}}));
+
+    std::vector<LayerRequirement> layers = {{.weight = 1.f}};
+    auto& layer = layers.front();
+    layer.name = "60Hz ExplicitExactOrMultiple";
+    layer.desiredRefreshRate = 60_Hz;
+    layer.vote = LayerVoteType::ExplicitExactOrMultiple;
+
+    // 60 Hz is not a divisor of 90 Hz; the 90 Hz mode at 90 Hz is expected for every config.
+    EXPECT_FRAME_RATE_MODE(kMode90, 90_Hz, selector.getBestScoredFrameRate(layers).frameRateMode);
+}
+
 } // namespace
 } // namespace android::scheduler
diff --git a/services/surfaceflinger/tests/unittests/mock/MockFrameRateMode.h b/services/surfaceflinger/tests/unittests/mock/MockFrameRateMode.h
index ef9cd9b..4cfdd58 100644
--- a/services/surfaceflinger/tests/unittests/mock/MockFrameRateMode.h
+++ b/services/surfaceflinger/tests/unittests/mock/MockFrameRateMode.h
@@ -19,5 +19,9 @@
 #include <scheduler/FrameRateMode.h>
 
 // Use a C style macro to keep the line numbers printed in gtest
-#define EXPECT_FRAME_RATE_MODE(modePtr, fps, mode) \
-    EXPECT_EQ((scheduler::FrameRateMode{(fps), (modePtr)}), (mode))
+// The parameters must not be named `fps` or `modePtr`: the failure message reads those
+// FrameRateMode members from `actualMode`, and a same-named parameter would replace them.
+#define EXPECT_FRAME_RATE_MODE(expectedModePtr, expectedFps, actualMode)                     \
+    EXPECT_EQ((scheduler::FrameRateMode{(expectedFps), (expectedModePtr)}), (actualMode))     \
+            << "Expected " << (expectedFps) << " (" << (expectedModePtr)->getFps()             \
+            << ") but was " << (actualMode).fps << " (" << (actualMode).modePtr->getFps() << ")"
diff --git a/vulkan/libvulkan/swapchain.cpp b/vulkan/libvulkan/swapchain.cpp
index c28390f..bfb421d 100644
--- a/vulkan/libvulkan/swapchain.cpp
+++ b/vulkan/libvulkan/swapchain.cpp
@@ -532,7 +532,8 @@
     return native_format;
 }
 
-android_dataspace GetNativeDataspace(VkColorSpaceKHR colorspace) {
+android_dataspace GetNativeDataspace(VkColorSpaceKHR colorspace,
+                                     android::PixelFormat pixelFormat) {
     switch (colorspace) {
         case VK_COLOR_SPACE_SRGB_NONLINEAR_KHR:
             return HAL_DATASPACE_V0_SRGB;
@@ -551,7 +552,14 @@
         case VK_COLOR_SPACE_BT709_NONLINEAR_EXT:
             return HAL_DATASPACE_V0_SRGB;
         case VK_COLOR_SPACE_BT2020_LINEAR_EXT:
-            return HAL_DATASPACE_BT2020_LINEAR;
+            if (pixelFormat == HAL_PIXEL_FORMAT_RGBA_FP16) {
+                return static_cast<android_dataspace>(
+                    HAL_DATASPACE_STANDARD_BT2020 |
+                    HAL_DATASPACE_TRANSFER_LINEAR |
+                    HAL_DATASPACE_RANGE_EXTENDED);
+            } else {
+                return HAL_DATASPACE_BT2020_LINEAR;
+            }
         case VK_COLOR_SPACE_HDR10_ST2084_EXT:
             return static_cast<android_dataspace>(
                 HAL_DATASPACE_STANDARD_BT2020 | HAL_DATASPACE_TRANSFER_ST2084 |
@@ -561,9 +569,7 @@
                 HAL_DATASPACE_STANDARD_BT2020 | HAL_DATASPACE_TRANSFER_ST2084 |
                 HAL_DATASPACE_RANGE_FULL);
         case VK_COLOR_SPACE_HDR10_HLG_EXT:
-            return static_cast<android_dataspace>(
-                HAL_DATASPACE_STANDARD_BT2020 | HAL_DATASPACE_TRANSFER_HLG |
-                HAL_DATASPACE_RANGE_FULL);
+            return static_cast<android_dataspace>(HAL_DATASPACE_BT2020_HLG);
         case VK_COLOR_SPACE_ADOBERGB_LINEAR_EXT:
             return static_cast<android_dataspace>(
                 HAL_DATASPACE_STANDARD_ADOBE_RGB |
@@ -1364,7 +1370,7 @@
     android::PixelFormat native_pixel_format =
         GetNativePixelFormat(create_info->imageFormat);
     android_dataspace native_dataspace =
-        GetNativeDataspace(create_info->imageColorSpace);
+        GetNativeDataspace(create_info->imageColorSpace, native_pixel_format);
     if (native_dataspace == HAL_DATASPACE_UNKNOWN) {
         ALOGE(
             "CreateSwapchainKHR(VkSwapchainCreateInfoKHR.imageColorSpace = %d) "