Merge "Add TFLite model for motion prediction."
diff --git a/data/etc/input/Android.bp b/data/etc/input/Android.bp
new file mode 100644
index 0000000..90f3c6b
--- /dev/null
+++ b/data/etc/input/Android.bp
@@ -0,0 +1,14 @@
+package {
+    default_applicable_licenses: ["frameworks_native_license"],
+}
+
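+// Exposed as a filegroup so tests can package the model file (see libs/input/tests/Android.bp).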
+filegroup {
+    name: "motion_predictor_model.fb",
+    srcs: ["motion_predictor_model.fb"],
+}
+
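+// Installs the model to /system/etc/motion_predictor_model.fb.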
+prebuilt_etc {
+    name: "motion_predictor_model_prebuilt",
+    filename_from_src: true,
+    src: "motion_predictor_model.fb",
+}
diff --git a/data/etc/input/motion_predictor_model.fb b/data/etc/input/motion_predictor_model.fb
new file mode 100644
index 0000000..10b3c8b
--- /dev/null
+++ b/data/etc/input/motion_predictor_model.fb
Binary files differ
diff --git a/include/input/MotionPredictor.h b/include/input/MotionPredictor.h
index 045e61b..3fae4e6 100644
--- a/include/input/MotionPredictor.h
+++ b/include/input/MotionPredictor.h
@@ -16,9 +16,15 @@
 
 #pragma once
 
+#include <cstdint>
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+
 #include <android-base/thread_annotations.h>
 #include <android/sysprop/InputProperties.sysprop.h>
 #include <input/Input.h>
+#include <input/TfLiteMotionPredictor.h>
 
 namespace android {
 
@@ -28,48 +34,51 @@
 
 /**
  * Given a set of MotionEvents for the current gesture, predict the motion. The returned MotionEvent
- * contains a set of samples in the future, up to "presentation time + offset".
+ * contains a set of samples in the future.
  *
  * The typical usage is like this:
  *
  * MotionPredictor predictor(offset = MY_OFFSET);
- * predictor.setExpectedPresentationTimeNanos(NEXT_PRESENT_TIME);
  * predictor.record(DOWN_MOTION_EVENT);
  * predictor.record(MOVE_MOTION_EVENT);
- * prediction = predictor.predict();
+ * prediction = predictor.predict(futureTime);
  *
- * The presentation time should be set some time before calling .predict(). It could be set before
- * or after the recorded motion events. Must be done on every frame.
- *
- * The resulting motion event will have eventTime <= (NEXT_PRESENT_TIME + MY_OFFSET). It might
- * contain historical data, which are additional samples from the latest recorded MotionEvent's
- * eventTime to the NEXT_PRESENT_TIME + MY_OFFSET.
+ * The resulting motion event will have eventTime <= (futureTime + MY_OFFSET). It might contain
+ * historical data, which are additional samples from the latest recorded MotionEvent's eventTime
+ * to the futureTime + MY_OFFSET.
  *
  * The offset is used to provide additional flexibility to the caller, in case the default present
  * time (typically provided by the choreographer) does not account for some delays, or to simply
- * reduce the aggressiveness of the prediction. Offset can be both positive and negative.
+ * reduce the aggressiveness of the prediction. Offset can be positive or negative.
  */
 class MotionPredictor {
 public:
     /**
      * Parameters:
      * predictionTimestampOffsetNanos: additional, constant shift to apply to the target
-     * presentation time. The prediction will target the time t=(presentationTime +
+     * prediction time. The prediction will target the time t=(prediction time +
      * predictionTimestampOffsetNanos).
      *
+     * modelPath: filesystem path to a TfLiteMotionPredictorModel flatbuffer, or nullptr to use the
+     * default model path.
+     *
     * checkEnableMotionPrediction: the function to check whether the prediction should run. Used to
      * provide an additional way of turning prediction on and off. Can be toggled at runtime.
      */
-    MotionPredictor(nsecs_t predictionTimestampOffsetNanos,
+    MotionPredictor(nsecs_t predictionTimestampOffsetNanos, const char* modelPath = nullptr,
                     std::function<bool()> checkEnableMotionPrediction = isMotionPredictionEnabled);
     void record(const MotionEvent& event);
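+
+    // Generates a prediction for each device that has a predictable gesture in progress. Each
+    // returned MotionEvent contains only predicted samples, starting after the last recorded
+    // sample and extending up to approximately (timestamp + predictionTimestampOffsetNanos).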
     std::vector<std::unique_ptr<MotionEvent>> predict(nsecs_t timestamp);
     bool isPredictionAvailable(int32_t deviceId, int32_t source);
 
 private:
-    std::vector<MotionEvent> mEvents;
     const nsecs_t mPredictionTimestampOffsetNanos;
     const std::function<bool()> mCheckMotionPredictionEnabled;
+
+    std::unique_ptr<TfLiteMotionPredictorModel> mModel;
+    // Buffers/events for each device seen by record().
+    std::unordered_map</*deviceId*/ int32_t, TfLiteMotionPredictorBuffers> mDeviceBuffers;
+    std::unordered_map</*deviceId*/ int32_t, MotionEvent> mLastEvents;
 };
 
 } // namespace android
diff --git a/include/input/TfLiteMotionPredictor.h b/include/input/TfLiteMotionPredictor.h
new file mode 100644
index 0000000..ff0f51c
--- /dev/null
+++ b/include/input/TfLiteMotionPredictor.h
@@ -0,0 +1,147 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <span>
+#include <string>
+#include <vector>
+
+#include <tensorflow/lite/core/api/error_reporter.h>
+#include <tensorflow/lite/interpreter.h>
+#include <tensorflow/lite/model.h>
+#include <tensorflow/lite/signature_runner.h>
+
+namespace android {
+
+struct TfLiteMotionPredictorSample {
+    // The untransformed AMOTION_EVENT_AXIS_X and AMOTION_EVENT_AXIS_Y of the sample.
+    struct Point {
+        float x;
+        float y;
+    } position;
+    // The AMOTION_EVENT_AXIS_PRESSURE, _TILT, and _ORIENTATION.
+    float pressure;
+    float tilt;
+    float orientation;
+};
+
+inline TfLiteMotionPredictorSample::Point operator-(const TfLiteMotionPredictorSample::Point& lhs,
+                                                    const TfLiteMotionPredictorSample::Point& rhs) {
+    return {.x = lhs.x - rhs.x, .y = lhs.y - rhs.y};
+}
+
+class TfLiteMotionPredictorModel;
+
+// Buffer storage for a TfLiteMotionPredictorModel.
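+//
+// Samples are converted to the model's polar (r, phi) input representation and kept in a
+// sliding window of the most recent inputLength() samples. Typical usage: pushSample() each
+// incoming event, then, once isReady(), copyTo() the model and invoke() it.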
+class TfLiteMotionPredictorBuffers {
+public:
+    // Creates buffer storage for a model with the given input length.
+    TfLiteMotionPredictorBuffers(size_t inputLength);
+
+    // Adds a motion sample to the buffers.
+    void pushSample(int64_t timestamp, TfLiteMotionPredictorSample sample);
+
+    // Returns true if the buffers are complete enough to generate a prediction.
+    bool isReady() const {
+        // Predictions can't be applied unless there are at least two points to determine
+        // the direction to apply them in.
+        return mAxisFrom && mAxisTo;
+    }
+
+    // Resets all buffers to their initial state.
+    void reset();
+
+    // Copies the buffers to those of a model for prediction.
+    void copyTo(TfLiteMotionPredictorModel& model) const;
+
+    // Returns the endpoints of the current axis of the buffer's samples. Only valid if isReady().
+    TfLiteMotionPredictorSample axisFrom() const { return *mAxisFrom; }
+    TfLiteMotionPredictorSample axisTo() const { return *mAxisTo; }
+
+    // Returns the timestamp of the last sample.
+    int64_t lastTimestamp() const { return mTimestamp; }
+
+private:
+    int64_t mTimestamp = 0;
+
+    std::vector<float> mInputR;
+    std::vector<float> mInputPhi;
+    std::vector<float> mInputPressure;
+    std::vector<float> mInputTilt;
+    std::vector<float> mInputOrientation;
+
+    // The samples defining the current polar axis.
+    std::optional<TfLiteMotionPredictorSample> mAxisFrom;
+    std::optional<TfLiteMotionPredictorSample> mAxisTo;
+};
+
+// A TFLite model for generating motion predictions.
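+//
+// Write the input tensors, call invoke(), then read the output tensors. The underlying
+// interpreter is not synchronized, so each instance must be used from one thread at a time.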
+class TfLiteMotionPredictorModel {
+public:
+    // Creates a model from an encoded Flatbuffer model.
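+    // Fatally aborts if the model cannot be read or parsed.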
+    static std::unique_ptr<TfLiteMotionPredictorModel> create(const char* modelPath);
+
+    // Returns the length of the model's input buffers.
+    size_t inputLength() const;
+
+    // Executes the model.
+    // Returns true if the model successfully executed and the output tensors can be read.
+    bool invoke();
+
+    // Returns mutable views of the input tensors, each of inputLength() elements.
+    std::span<float> inputR();
+    std::span<float> inputPhi();
+    std::span<float> inputPressure();
+    std::span<float> inputOrientation();
+    std::span<float> inputTilt();
+
+    // Returns immutable views of the output tensors, which all have the same length. Only valid
+    // after a successful call to invoke().
+    std::span<const float> outputR() const;
+    std::span<const float> outputPhi() const;
+    std::span<const float> outputPressure() const;
+
+private:
+    explicit TfLiteMotionPredictorModel(std::string model);
+
+    void allocateTensors();
+    void attachInputTensors();
+    void attachOutputTensors();
+
+    TfLiteTensor* mInputR = nullptr;
+    TfLiteTensor* mInputPhi = nullptr;
+    TfLiteTensor* mInputPressure = nullptr;
+    TfLiteTensor* mInputTilt = nullptr;
+    TfLiteTensor* mInputOrientation = nullptr;
+
+    const TfLiteTensor* mOutputR = nullptr;
+    const TfLiteTensor* mOutputPhi = nullptr;
+    const TfLiteTensor* mOutputPressure = nullptr;
+
+    std::string mFlatBuffer;
+    std::unique_ptr<tflite::ErrorReporter> mErrorReporter;
+    std::unique_ptr<tflite::FlatBufferModel> mModel;
+    std::unique_ptr<tflite::Interpreter> mInterpreter;
+    tflite::SignatureRunner* mRunner = nullptr;
+};
+
+} // namespace android
diff --git a/libs/input/Android.bp b/libs/input/Android.bp
index 8f41cc1..83392ec 100644
--- a/libs/input/Android.bp
+++ b/libs/input/Android.bp
@@ -41,6 +41,7 @@
         "-Wall",
         "-Wextra",
         "-Werror",
+        "-Wno-unused-parameter",
     ],
     srcs: [
         "Input.cpp",
@@ -52,13 +53,18 @@
         "MotionPredictor.cpp",
         "PrintTools.cpp",
         "PropertyMap.cpp",
+        "TfLiteMotionPredictor.cpp",
         "TouchVideoFrame.cpp",
         "VelocityControl.cpp",
         "VelocityTracker.cpp",
         "VirtualKeyMap.cpp",
     ],
 
-    header_libs: ["jni_headers"],
+    header_libs: [
+        "flatbuffer_headers",
+        "jni_headers",
+        "tensorflow_headers",
+    ],
     export_header_lib_headers: ["jni_headers"],
 
     shared_libs: [
@@ -67,6 +73,7 @@
         "liblog",
         "libPlatformProperties",
         "libvintf",
+        "libtflite",
     ],
 
     static_libs: [
@@ -103,6 +110,10 @@
             sanitize: {
                 misc_undefined: ["integer"],
             },
+
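+            // Ensure the model file is installed to the device with the library.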
+            required: [
+                "motion_predictor_model_prebuilt",
+            ],
         },
         host: {
             shared: {
diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp
index 0fa0f12..0f889e8 100644
--- a/libs/input/MotionPredictor.cpp
+++ b/libs/input/MotionPredictor.cpp
@@ -18,118 +18,188 @@
 
 #include <input/MotionPredictor.h>
 
+#include <cinttypes>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include <android-base/strings.h>
+#include <android/input.h>
+#include <log/log.h>
+
+#include <attestation/HmacKeyManager.h>
+#include <input/TfLiteMotionPredictor.h>
+
+namespace android {
+namespace {
+
+const char DEFAULT_MODEL_PATH[] = "/system/etc/motion_predictor_model.fb";
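+// Time between two predicted samples: 12.5 ms / 3 ≈ 4.17 ms.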
+const int64_t PREDICTION_INTERVAL_NANOS =
+        12500000 / 3; // TODO(b/266747937): Get this from the model.
+
 /**
  * Log debug messages about predictions.
  * Enable this via "adb shell setprop log.tag.MotionPredictor DEBUG"
  */
-static bool isDebug() {
+bool isDebug() {
     return __android_log_is_loggable(ANDROID_LOG_DEBUG, LOG_TAG, ANDROID_LOG_INFO);
 }
 
-namespace android {
+// Converts a polar (r, phi) prediction to a Cartesian (x, y) point, where phi is measured
+// relative to the axis running from axisFrom to axisTo.
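+// For example, with axisFrom=(0,0) and axisTo=(1,0), (r=1, phi=0) extends the axis to (2,0),
+// while (r=1, phi=pi/2) turns left to (1,1).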
+TfLiteMotionPredictorSample::Point convertPrediction(
+        const TfLiteMotionPredictorSample::Point& axisFrom,
+        const TfLiteMotionPredictorSample::Point& axisTo, float r, float phi) {
+    const TfLiteMotionPredictorSample::Point axis = axisTo - axisFrom;
+    const float axis_phi = std::atan2(axis.y, axis.x);
+    const float x_delta = r * std::cos(axis_phi + phi);
+    const float y_delta = r * std::sin(axis_phi + phi);
+    return {.x = axisTo.x + x_delta, .y = axisTo.y + y_delta};
+}
+
+} // namespace
 
 // --- MotionPredictor ---
 
-MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos,
+MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos, const char* modelPath,
                                  std::function<bool()> checkMotionPredictionEnabled)
       : mPredictionTimestampOffsetNanos(predictionTimestampOffsetNanos),
-        mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)) {}
+        mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)),
+        mModel(TfLiteMotionPredictorModel::create(modelPath == nullptr ? DEFAULT_MODEL_PATH
+                                                                       : modelPath)) {}
 
 void MotionPredictor::record(const MotionEvent& event) {
-    mEvents.push_back({});
-    mEvents.back().copyFrom(&event, /*keepHistory=*/true);
-    if (mEvents.size() > 2) {
-        // Just need 2 samples in order to extrapolate
-        mEvents.erase(mEvents.begin());
+    if (!isPredictionAvailable(event.getDeviceId(), event.getSource())) {
+        ALOGE("Prediction not supported for device %d's %s source", event.getDeviceId(),
+              inputEventSourceToString(event.getSource()).c_str());
+        return;
     }
+
+    TfLiteMotionPredictorBuffers& buffers =
+            mDeviceBuffers.try_emplace(event.getDeviceId(), mModel->inputLength()).first->second;
+
+    const int32_t action = event.getActionMasked();
+    if (action == AMOTION_EVENT_ACTION_UP) {
+        ALOGD_IF(isDebug(), "End of event stream");
+        buffers.reset();
+        return;
+    } else if (action != AMOTION_EVENT_ACTION_DOWN && action != AMOTION_EVENT_ACTION_MOVE) {
+        ALOGD_IF(isDebug(), "Skipping unsupported %s action",
+                 MotionEvent::actionToString(action).c_str());
+        return;
+    }
+
+    if (event.getPointerCount() != 1) {
+        ALOGD_IF(isDebug(), "Prediction not supported for multiple pointers");
+        return;
+    }
+
+    const int32_t toolType = event.getPointerProperties(0)->toolType;
+    if (toolType != AMOTION_EVENT_TOOL_TYPE_STYLUS) {
+        ALOGD_IF(isDebug(), "Prediction not supported for non-stylus tool: %s",
+                 motionToolTypeToString(toolType));
+        return;
+    }
+
+    for (size_t i = 0; i <= event.getHistorySize(); ++i) {
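+        // Skip samples added by resampling; the model should only see the points actually
+        // reported by the device.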
+        if (event.isResampled(0, i)) {
+            continue;
+        }
+        const PointerCoords* coords = event.getHistoricalRawPointerCoords(0, i);
+        buffers.pushSample(event.getHistoricalEventTime(i),
+                           {
+                                   .position.x = coords->getAxisValue(AMOTION_EVENT_AXIS_X),
+                                   .position.y = coords->getAxisValue(AMOTION_EVENT_AXIS_Y),
+                                   .pressure = event.getHistoricalPressure(0, i),
+                                   .tilt = event.getHistoricalAxisValue(AMOTION_EVENT_AXIS_TILT, 0,
+                                                                        i),
+                                   .orientation = event.getHistoricalOrientation(0, i),
+                           });
+    }
+
+    mLastEvents.try_emplace(event.getDeviceId())
+            .first->second.copyFrom(&event, /*keepHistory=*/false);
 }
 
-/**
- * This is an example implementation that should be replaced with the actual prediction.
- * The returned MotionEvent should be similar to the incoming MotionEvent, except for the
- * fields that are predicted:
- *
- * 1) event.getEventTime
- * 2) event.getPointerCoords
- *
- * The returned event should not contain any of the real, existing data. It should only
- * contain the predicted samples.
- */
 std::vector<std::unique_ptr<MotionEvent>> MotionPredictor::predict(nsecs_t timestamp) {
-    if (mEvents.size() < 2) {
-        return {};
-    }
+    std::vector<std::unique_ptr<MotionEvent>> predictions;
 
-    const MotionEvent& event = mEvents.back();
-    if (!isPredictionAvailable(event.getDeviceId(), event.getSource())) {
-        return {};
-    }
-
-    std::unique_ptr<MotionEvent> prediction = std::make_unique<MotionEvent>();
-    std::vector<PointerCoords> futureCoords;
-    const nsecs_t futureTime = timestamp + mPredictionTimestampOffsetNanos;
-    const nsecs_t currentTime = event.getEventTime();
-    const MotionEvent& previous = mEvents.rbegin()[1];
-    const nsecs_t oldTime = previous.getEventTime();
-    if (currentTime == oldTime) {
-        // This can happen if it's an ACTION_POINTER_DOWN event, for example.
-        return {}; // prevent division by zero.
-    }
-
-    for (size_t i = 0; i < event.getPointerCount(); i++) {
-        const int32_t pointerId = event.getPointerId(i);
-        const PointerCoords* currentPointerCoords = event.getRawPointerCoords(i);
-        const float currentX = currentPointerCoords->getAxisValue(AMOTION_EVENT_AXIS_X);
-        const float currentY = currentPointerCoords->getAxisValue(AMOTION_EVENT_AXIS_Y);
-
-        PointerCoords coords;
-        coords.clear();
-
-        ssize_t index = previous.findPointerIndex(pointerId);
-        if (index >= 0) {
-            // We have old data for this pointer. Compute the prediction.
-            const PointerCoords* oldPointerCoords = previous.getRawPointerCoords(index);
-            const float oldX = oldPointerCoords->getAxisValue(AMOTION_EVENT_AXIS_X);
-            const float oldY = oldPointerCoords->getAxisValue(AMOTION_EVENT_AXIS_Y);
-
-            // Let's do a linear interpolation while waiting for a real model
-            const float scale =
-                    static_cast<float>(futureTime - currentTime) / (currentTime - oldTime);
-            const float futureX = currentX + (currentX - oldX) * scale;
-            const float futureY = currentY + (currentY - oldY) * scale;
-
-            coords.setAxisValue(AMOTION_EVENT_AXIS_X, futureX);
-            coords.setAxisValue(AMOTION_EVENT_AXIS_Y, futureY);
-            ALOGD_IF(isDebug(),
-                     "Prediction by %.1f ms, (%.1f, %.1f), (%.1f, %.1f) --> (%.1f, %.1f)",
-                     (futureTime - event.getEventTime()) * 1E-6, oldX, oldY, currentX, currentY,
-                     futureX, futureY);
+    for (const auto& [deviceId, buffer] : mDeviceBuffers) {
+        if (!buffer.isReady()) {
+            continue;
         }
 
-        futureCoords.push_back(coords);
+        buffer.copyTo(*mModel);
+        LOG_ALWAYS_FATAL_IF(!mModel->invoke());
+
+        // Read out the predictions.
+        const std::span<const float> predictedR = mModel->outputR();
+        const std::span<const float> predictedPhi = mModel->outputPhi();
+        const std::span<const float> predictedPressure = mModel->outputPressure();
+
+        TfLiteMotionPredictorSample::Point axisFrom = buffer.axisFrom().position;
+        TfLiteMotionPredictorSample::Point axisTo = buffer.axisTo().position;
+
+        if (isDebug()) {
+            ALOGD("deviceId: %d", deviceId);
+            ALOGD("axisFrom: %f, %f", axisFrom.x, axisFrom.y);
+            ALOGD("axisTo: %f, %f", axisTo.x, axisTo.y);
+            ALOGD("mInputR: %s", base::Join(mModel->inputR(), ", ").c_str());
+            ALOGD("mInputPhi: %s", base::Join(mModel->inputPhi(), ", ").c_str());
+            ALOGD("mInputPressure: %s", base::Join(mModel->inputPressure(), ", ").c_str());
+            ALOGD("mInputTilt: %s", base::Join(mModel->inputTilt(), ", ").c_str());
+            ALOGD("mInputOrientation: %s", base::Join(mModel->inputOrientation(), ", ").c_str());
+            ALOGD("predictedR: %s", base::Join(predictedR, ", ").c_str());
+            ALOGD("predictedPhi: %s", base::Join(predictedPhi, ", ").c_str());
+            ALOGD("predictedPressure: %s", base::Join(predictedPressure, ", ").c_str());
+        }
+
+        const MotionEvent& event = mLastEvents[deviceId];
+        bool hasPredictions = false;
+        std::unique_ptr<MotionEvent> prediction = std::make_unique<MotionEvent>();
+        int64_t predictionTime = buffer.lastTimestamp();
+        const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos;
+
+        for (size_t i = 0; i < predictedR.size() && predictionTime <= futureTime; ++i) {
+            const TfLiteMotionPredictorSample::Point point =
+                    convertPrediction(axisFrom, axisTo, predictedR[i], predictedPhi[i]);
+            // TODO(b/266747654): Stop predictions if confidence is < some threshold.
+
+            ALOGD_IF(isDebug(), "prediction %d: %f, %f", i, point.x, point.y);
+            PointerCoords coords;
+            coords.clear();
+            coords.setAxisValue(AMOTION_EVENT_AXIS_X, point.x);
+            coords.setAxisValue(AMOTION_EVENT_AXIS_Y, point.y);
+            // TODO(b/266747654): Stop predictions if predicted pressure is < some threshold.
+            coords.setAxisValue(AMOTION_EVENT_AXIS_PRESSURE, predictedPressure[i]);
+
+            predictionTime += PREDICTION_INTERVAL_NANOS;
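+            // The first prediction seeds the MotionEvent via initialize(); subsequent
+            // predictions are appended to it as samples via addSample().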
+            if (i == 0) {
+                hasPredictions = true;
+                prediction->initialize(InputEvent::nextId(), event.getDeviceId(), event.getSource(),
+                                       event.getDisplayId(), INVALID_HMAC,
+                                       AMOTION_EVENT_ACTION_MOVE, event.getActionButton(),
+                                       event.getFlags(), event.getEdgeFlags(), event.getMetaState(),
+                                       event.getButtonState(), event.getClassification(),
+                                       event.getTransform(), event.getXPrecision(),
+                                       event.getYPrecision(), event.getRawXCursorPosition(),
+                                       event.getRawYCursorPosition(), event.getRawTransform(),
+                                       event.getDownTime(), predictionTime, event.getPointerCount(),
+                                       event.getPointerProperties(), &coords);
+            } else {
+                prediction->addSample(predictionTime, &coords);
+            }
+
+            axisFrom = axisTo;
+            axisTo = point;
+        }
+        // TODO(b/266747511): Interpolate to futureTime?
+        if (hasPredictions) {
+            predictions.push_back(std::move(prediction));
+        }
     }
-
-    /**
-     * The process of adding samples is different for the first and subsequent samples:
-     * 1. Add the first sample via 'initialize' as below
-     * 2. Add subsequent samples via 'addSample'
-     */
-    prediction->initialize(event.getId(), event.getDeviceId(), event.getSource(),
-                           event.getDisplayId(), event.getHmac(), event.getAction(),
-                           event.getActionButton(), event.getFlags(), event.getEdgeFlags(),
-                           event.getMetaState(), event.getButtonState(), event.getClassification(),
-                           event.getTransform(), event.getXPrecision(), event.getYPrecision(),
-                           event.getRawXCursorPosition(), event.getRawYCursorPosition(),
-                           event.getRawTransform(), event.getDownTime(), futureTime,
-                           event.getPointerCount(), event.getPointerProperties(),
-                           futureCoords.data());
-
-    // To add more predicted samples, use 'addSample':
-    prediction->addSample(futureTime + 1, futureCoords.data());
-
-    std::vector<std::unique_ptr<MotionEvent>> out;
-    out.push_back(std::move(prediction));
-    return out;
+    return predictions;
 }
 
 bool MotionPredictor::isPredictionAvailable(int32_t /*deviceId*/, int32_t source) {
diff --git a/libs/input/TfLiteMotionPredictor.cpp b/libs/input/TfLiteMotionPredictor.cpp
new file mode 100644
index 0000000..1a337ad
--- /dev/null
+++ b/libs/input/TfLiteMotionPredictor.cpp
@@ -0,0 +1,338 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "TfLiteMotionPredictor"
+#include <input/TfLiteMotionPredictor.h>
+
+#include <algorithm>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <fstream>
+#include <ios>
+#include <iterator>
+#include <memory>
+#include <span>
+#include <string>
+#include <type_traits>
+#include <utility>
+
+#define ATRACE_TAG ATRACE_TAG_INPUT
+#include <cutils/trace.h>
+#include <log/log.h>
+
+#include "tensorflow/lite/core/api/error_reporter.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/model.h"
+
+namespace android {
+namespace {
+
+constexpr char SIGNATURE_KEY[] = "serving_default";
+
+// Input tensor names.
+constexpr char INPUT_R[] = "r";
+constexpr char INPUT_PHI[] = "phi";
+constexpr char INPUT_PRESSURE[] = "pressure";
+constexpr char INPUT_TILT[] = "tilt";
+constexpr char INPUT_ORIENTATION[] = "orientation";
+
+// Output tensor names.
+constexpr char OUTPUT_R[] = "r";
+constexpr char OUTPUT_PHI[] = "phi";
+constexpr char OUTPUT_PRESSURE[] = "pressure";
+
+// A TFLite ErrorReporter that logs to logcat.
+class LoggingErrorReporter : public tflite::ErrorReporter {
+public:
+    int Report(const char* format, va_list args) override {
+        return LOG_PRI_VA(ANDROID_LOG_ERROR, LOG_TAG, format, args);
+    }
+};
+
+// Searches a runner for an input tensor.
+TfLiteTensor* findInputTensor(const char* name, tflite::SignatureRunner* runner) {
+    TfLiteTensor* tensor = runner->input_tensor(name);
+    LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find input tensor '%s'", name);
+    return tensor;
+}
+
+// Searches a runner for an output tensor.
+const TfLiteTensor* findOutputTensor(const char* name, tflite::SignatureRunner* runner) {
+    const TfLiteTensor* tensor = runner->output_tensor(name);
+    LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find output tensor '%s'", name);
+    return tensor;
+}
+
+// Returns the buffer for a tensor of type T.
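+// A const-qualified T selects the const TfLiteTensor* parameter type (used for output tensors).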
+template <typename T>
+std::span<T> getTensorBuffer(typename std::conditional<std::is_const<T>::value, const TfLiteTensor*,
+                                                       TfLiteTensor*>::type tensor) {
+    LOG_ALWAYS_FATAL_IF(!tensor);
+
+    const TfLiteType type = tflite::typeToTfLiteType<typename std::remove_cv<T>::type>();
+    LOG_ALWAYS_FATAL_IF(tensor->type != type, "Unexpected type for '%s' tensor: %s (expected %s)",
+                        tensor->name, TfLiteTypeGetName(tensor->type), TfLiteTypeGetName(type));
+
+    LOG_ALWAYS_FATAL_IF(!tensor->data.data);
+    return {reinterpret_cast<T*>(tensor->data.data),
+            static_cast<typename std::span<T>::index_type>(tensor->bytes / sizeof(T))};
+}
+
+// Verifies that a tensor exists and has an underlying buffer of type T.
+template <typename T>
+void checkTensor(const TfLiteTensor* tensor) {
+    LOG_ALWAYS_FATAL_IF(!tensor);
+
+    const auto buffer = getTensorBuffer<const T>(tensor);
+    LOG_ALWAYS_FATAL_IF(buffer.empty(), "No buffer for tensor '%s'", tensor->name);
+}
+
+} // namespace
+
+TfLiteMotionPredictorBuffers::TfLiteMotionPredictorBuffers(size_t inputLength) {
+    LOG_ALWAYS_FATAL_IF(inputLength == 0, "Buffer input size must be greater than 0");
+    mInputR.resize(inputLength);
+    mInputPhi.resize(inputLength);
+    mInputPressure.resize(inputLength);
+    mInputTilt.resize(inputLength);
+    mInputOrientation.resize(inputLength);
+}
+
+void TfLiteMotionPredictorBuffers::reset() {
+    std::fill(mInputR.begin(), mInputR.end(), 0);
+    std::fill(mInputPhi.begin(), mInputPhi.end(), 0);
+    std::fill(mInputPressure.begin(), mInputPressure.end(), 0);
+    std::fill(mInputTilt.begin(), mInputTilt.end(), 0);
+    std::fill(mInputOrientation.begin(), mInputOrientation.end(), 0);
+    mAxisFrom.reset();
+    mAxisTo.reset();
+}
+
+void TfLiteMotionPredictorBuffers::copyTo(TfLiteMotionPredictorModel& model) const {
+    LOG_ALWAYS_FATAL_IF(mInputR.size() != model.inputLength(),
+                        "Buffer length %zu doesn't match model input length %zu", mInputR.size(),
+                        model.inputLength());
+    LOG_ALWAYS_FATAL_IF(!isReady(), "Buffers are incomplete");
+
+    std::copy(mInputR.begin(), mInputR.end(), model.inputR().begin());
+    std::copy(mInputPhi.begin(), mInputPhi.end(), model.inputPhi().begin());
+    std::copy(mInputPressure.begin(), mInputPressure.end(), model.inputPressure().begin());
+    std::copy(mInputTilt.begin(), mInputTilt.end(), model.inputTilt().begin());
+    std::copy(mInputOrientation.begin(), mInputOrientation.end(), model.inputOrientation().begin());
+}
+
+void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp,
+                                              const TfLiteMotionPredictorSample sample) {
+    // Convert the sample (x, y) into polar (r, φ) based on a reference axis
+    // from the preceding two points (mAxisFrom/mAxisTo).
+
+    mTimestamp = timestamp;
+
+    if (!mAxisTo) { // First point.
+        mAxisTo = sample;
+        return;
+    }
+
+    // Vector from the last point to the current sample point.
+    const TfLiteMotionPredictorSample::Point v = sample.position - mAxisTo->position;
+
+    const float r = std::hypot(v.x, v.y);
+    float phi = 0;
+    float orientation = 0;
+
+    // Ignore the sample if there is no movement. Such samples occur when a property other than
+    // the coordinates changes, and they would pollute the model's input.
+    if (r == 0) {
+        return;
+    }
+
+    if (!mAxisFrom) { // Second point.
+        // We can only determine the distance from the first point, and not any
+        // angle. However, if the second point forms an axis, the orientation can
+        // be transformed relative to that axis.
+        const float axisPhi = std::atan2(v.y, v.x);
+        // A MotionEvent's orientation is measured clockwise from the vertical
+        // axis, but axisPhi is measured counter-clockwise from the horizontal
+        // axis.
+        orientation = M_PI_2 - sample.orientation - axisPhi;
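+        // (M_PI_2 - orientation converts to counter-clockwise-from-horizontal; subtracting
+        // axisPhi then expresses it relative to the axis.)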
+    } else {
+        const TfLiteMotionPredictorSample::Point axis = mAxisTo->position - mAxisFrom->position;
+        const float axisPhi = std::atan2(axis.y, axis.x);
+        phi = std::atan2(v.y, v.x) - axisPhi;
+
+        if (std::hypot(axis.x, axis.y) > 0) {
+            // See note above.
+            orientation = M_PI_2 - sample.orientation - axisPhi;
+        }
+    }
+
+    // Update the axis for the next point.
+    mAxisFrom = mAxisTo;
+    mAxisTo = sample;
+
+    // Push the current sample onto the end of the input buffers.
+    mInputR.erase(mInputR.begin());
+    mInputPhi.erase(mInputPhi.begin());
+    mInputPressure.erase(mInputPressure.begin());
+    mInputTilt.erase(mInputTilt.begin());
+    mInputOrientation.erase(mInputOrientation.begin());
+
+    mInputR.push_back(r);
+    mInputPhi.push_back(phi);
+    mInputPressure.push_back(sample.pressure);
+    mInputTilt.push_back(sample.tilt);
+    mInputOrientation.push_back(orientation);
+}
+
+std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create(
+        const char* modelPath) {
+    std::ifstream f(modelPath, std::ios::binary);
+    LOG_ALWAYS_FATAL_IF(!f, "Could not read model from %s", modelPath);
+
+    std::string data;
+    data.assign(std::istreambuf_iterator<char>(f), std::istreambuf_iterator<char>());
+
+    return std::unique_ptr<TfLiteMotionPredictorModel>(
+            new TfLiteMotionPredictorModel(std::move(data)));
+}
+
+TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(std::string model)
+      : mFlatBuffer(std::move(model)) {
+    mErrorReporter = std::make_unique<LoggingErrorReporter>();
+    mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer.data(),
+                                                               mFlatBuffer.length(),
+                                                               /*extra_verifier=*/nullptr,
+                                                               mErrorReporter.get());
+    LOG_ALWAYS_FATAL_IF(!mModel);
+
+    tflite::ops::builtin::BuiltinOpResolver resolver;
+    tflite::InterpreterBuilder builder(*mModel, resolver);
+
+    if (builder(&mInterpreter) != kTfLiteOk || !mInterpreter) {
+        LOG_ALWAYS_FATAL("Failed to build interpreter");
+    }
+
+    mRunner = mInterpreter->GetSignatureRunner(SIGNATURE_KEY);
+    LOG_ALWAYS_FATAL_IF(!mRunner, "Failed to find runner for signature '%s'", SIGNATURE_KEY);
+
+    allocateTensors();
+}
+
+void TfLiteMotionPredictorModel::allocateTensors() {
+    if (mRunner->AllocateTensors() != kTfLiteOk) {
+        LOG_ALWAYS_FATAL("Failed to allocate tensors");
+    }
+
+    attachInputTensors();
+    attachOutputTensors();
+
+    checkTensor<float>(mInputR);
+    checkTensor<float>(mInputPhi);
+    checkTensor<float>(mInputPressure);
+    checkTensor<float>(mInputTilt);
+    checkTensor<float>(mInputOrientation);
+    checkTensor<float>(mOutputR);
+    checkTensor<float>(mOutputPhi);
+    checkTensor<float>(mOutputPressure);
+
+    const auto checkInputTensorSize = [this](const TfLiteTensor* tensor) {
+        const size_t size = getTensorBuffer<const float>(tensor).size();
+        LOG_ALWAYS_FATAL_IF(size != inputLength(),
+                            "Tensor '%s' length %zu does not match input length %zu", tensor->name,
+                            size, inputLength());
+    };
+
+    checkInputTensorSize(mInputR);
+    checkInputTensorSize(mInputPhi);
+    checkInputTensorSize(mInputPressure);
+    checkInputTensorSize(mInputTilt);
+    checkInputTensorSize(mInputOrientation);
+}
+
+void TfLiteMotionPredictorModel::attachInputTensors() {
+    mInputR = findInputTensor(INPUT_R, mRunner);
+    mInputPhi = findInputTensor(INPUT_PHI, mRunner);
+    mInputPressure = findInputTensor(INPUT_PRESSURE, mRunner);
+    mInputTilt = findInputTensor(INPUT_TILT, mRunner);
+    mInputOrientation = findInputTensor(INPUT_ORIENTATION, mRunner);
+}
+
+void TfLiteMotionPredictorModel::attachOutputTensors() {
+    mOutputR = findOutputTensor(OUTPUT_R, mRunner);
+    mOutputPhi = findOutputTensor(OUTPUT_PHI, mRunner);
+    mOutputPressure = findOutputTensor(OUTPUT_PRESSURE, mRunner);
+}
+
+bool TfLiteMotionPredictorModel::invoke() {
+    ATRACE_BEGIN("TfLiteMotionPredictorModel::invoke");
+    TfLiteStatus result = mRunner->Invoke();
+    ATRACE_END();
+
+    if (result != kTfLiteOk) {
+        return false;
+    }
+
+    // Invoke() might reallocate tensors, so they need to be reattached.
+    attachInputTensors();
+    attachOutputTensors();
+
+    if (outputR().size() != outputPhi().size() || outputR().size() != outputPressure().size()) {
+        LOG_ALWAYS_FATAL("Output size mismatch: (r: %zu, phi: %zu, pressure: %zu)",
+                         outputR().size(), outputPhi().size(), outputPressure().size());
+    }
+
+    return true;
+}
+
+size_t TfLiteMotionPredictorModel::inputLength() const {
+    return getTensorBuffer<const float>(mInputR).size();
+}
+
+std::span<float> TfLiteMotionPredictorModel::inputR() {
+    return getTensorBuffer<float>(mInputR);
+}
+
+std::span<float> TfLiteMotionPredictorModel::inputPhi() {
+    return getTensorBuffer<float>(mInputPhi);
+}
+
+std::span<float> TfLiteMotionPredictorModel::inputPressure() {
+    return getTensorBuffer<float>(mInputPressure);
+}
+
+std::span<float> TfLiteMotionPredictorModel::inputTilt() {
+    return getTensorBuffer<float>(mInputTilt);
+}
+
+std::span<float> TfLiteMotionPredictorModel::inputOrientation() {
+    return getTensorBuffer<float>(mInputOrientation);
+}
+
+std::span<const float> TfLiteMotionPredictorModel::outputR() const {
+    return getTensorBuffer<const float>(mOutputR);
+}
+
+std::span<const float> TfLiteMotionPredictorModel::outputPhi() const {
+    return getTensorBuffer<const float>(mOutputPhi);
+}
+
+std::span<const float> TfLiteMotionPredictorModel::outputPressure() const {
+    return getTensorBuffer<const float>(mOutputPressure);
+}
+
+} // namespace android
diff --git a/libs/input/tests/Android.bp b/libs/input/tests/Android.bp
index e2c0860..916a8f2 100644
--- a/libs/input/tests/Android.bp
+++ b/libs/input/tests/Android.bp
@@ -10,6 +10,7 @@
 
 cc_test {
     name: "libinput_tests",
+    cpp_std: "c++20",
     host_supported: true,
     srcs: [
         "IdGenerator_test.cpp",
@@ -18,12 +19,18 @@
         "InputEvent_test.cpp",
         "InputPublisherAndConsumer_test.cpp",
         "MotionPredictor_test.cpp",
+        "TfLiteMotionPredictor_test.cpp",
         "TouchResampling_test.cpp",
         "TouchVideoFrame_test.cpp",
         "VelocityTracker_test.cpp",
         "VerifiedInputEvent_test.cpp",
     ],
+    header_libs: [
+        "flatbuffer_headers",
+        "tensorflow_headers",
+    ],
     static_libs: [
+        "libgmock",
         "libgui_window_info_static",
         "libinput",
         "libui-types",
@@ -32,6 +39,7 @@
         "-Wall",
         "-Wextra",
         "-Werror",
+        "-Wno-unused-parameter",
     ],
     shared_libs: [
         "libbase",
@@ -39,10 +47,14 @@
         "libcutils",
         "liblog",
         "libPlatformProperties",
+        "libtflite",
         "libutils",
         "libvintf",
     ],
-    data: ["data/*"],
+    data: [
+        "data/*",
+        ":motion_predictor_model.fb",
+    ],
     test_options: {
         unit_test: true,
     },
diff --git a/libs/input/tests/MotionPredictor_test.cpp b/libs/input/tests/MotionPredictor_test.cpp
index d2b59a1..ce87c86 100644
--- a/libs/input/tests/MotionPredictor_test.cpp
+++ b/libs/input/tests/MotionPredictor_test.cpp
@@ -14,17 +14,36 @@
  * limitations under the License.
  */
 
+#include <chrono>
+
+#include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include <gui/constants.h>
 #include <input/Input.h>
 #include <input/MotionPredictor.h>
 
+using namespace std::literals::chrono_literals;
+
 namespace android {
 
+using ::testing::IsEmpty;
+using ::testing::SizeIs;
+using ::testing::UnorderedElementsAre;
+
+const char MODEL_PATH[] =
+#if defined(__ANDROID__)
+        "/system/etc/motion_predictor_model.fb";
+#else
+        "motion_predictor_model.fb";
+#endif
+
 constexpr int32_t DOWN = AMOTION_EVENT_ACTION_DOWN;
 constexpr int32_t MOVE = AMOTION_EVENT_ACTION_MOVE;
+constexpr int32_t UP = AMOTION_EVENT_ACTION_UP;
+constexpr nsecs_t NSEC_PER_MSEC = 1'000'000;
 
-static MotionEvent getMotionEvent(int32_t action, float x, float y, nsecs_t eventTime) {
+static MotionEvent getMotionEvent(int32_t action, float x, float y,
+                                  std::chrono::nanoseconds eventTime, int32_t deviceId = 0) {
     MotionEvent event;
     constexpr size_t pointerCount = 1;
     std::vector<PointerProperties> pointerProperties;
@@ -33,6 +52,7 @@
         PointerProperties properties;
         properties.clear();
         properties.id = i;
+        properties.toolType = AMOTION_EVENT_TOOL_TYPE_STYLUS;
         pointerProperties.push_back(properties);
         PointerCoords coords;
         coords.clear();
@@ -42,73 +62,93 @@
     }
 
     ui::Transform identityTransform;
-    event.initialize(InputEvent::nextId(), /*deviceId=*/0, AINPUT_SOURCE_STYLUS,
-                     ADISPLAY_ID_DEFAULT, {0}, action, /*actionButton=*/0, /*flags=*/0,
-                     AMOTION_EVENT_EDGE_FLAG_NONE, AMETA_NONE, /*buttonState=*/0,
-                     MotionClassification::NONE, identityTransform, /*xPrecision=*/0.1,
+    event.initialize(InputEvent::nextId(), deviceId, AINPUT_SOURCE_STYLUS, ADISPLAY_ID_DEFAULT, {0},
+                     action, /*actionButton=*/0, /*flags=*/0, AMOTION_EVENT_EDGE_FLAG_NONE,
+                     AMETA_NONE, /*buttonState=*/0, MotionClassification::NONE, identityTransform,
+                     /*xPrecision=*/0.1,
                      /*yPrecision=*/0.2, /*xCursorPosition=*/280, /*yCursorPosition=*/540,
-                     identityTransform, /*downTime=*/100, eventTime, pointerCount,
+                     identityTransform, /*downTime=*/100, eventTime.count(), pointerCount,
                      pointerProperties.data(), pointerCoords.data());
     return event;
 }
 
-/**
- * A linear motion should be predicted to be linear in the future
- */
-TEST(MotionPredictorTest, LinearPrediction) {
-    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
-                              []() { return true /*enable prediction*/; });
-
-    predictor.record(getMotionEvent(DOWN, 0, 1, 0));
-    predictor.record(getMotionEvent(MOVE, 1, 3, 10));
-    predictor.record(getMotionEvent(MOVE, 2, 5, 20));
-    predictor.record(getMotionEvent(MOVE, 3, 7, 30));
-    std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40);
-    ASSERT_EQ(1u, predicted.size());
-    ASSERT_EQ(predicted[0]->getX(0), 4);
-    ASSERT_EQ(predicted[0]->getY(0), 9);
-}
-
-/**
- * A still motion should be predicted to remain still
- */
-TEST(MotionPredictorTest, StationaryPrediction) {
-    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
-                              []() { return true /*enable prediction*/; });
-
-    predictor.record(getMotionEvent(DOWN, 0, 1, 0));
-    predictor.record(getMotionEvent(MOVE, 0, 1, 10));
-    predictor.record(getMotionEvent(MOVE, 0, 1, 20));
-    predictor.record(getMotionEvent(MOVE, 0, 1, 30));
-    std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40);
-    ASSERT_EQ(1u, predicted.size());
-    ASSERT_EQ(predicted[0]->getX(0), 0);
-    ASSERT_EQ(predicted[0]->getY(0), 1);
-}
-
 TEST(MotionPredictorTest, IsPredictionAvailable) {
-    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, MODEL_PATH,
                               []() { return true /*enable prediction*/; });
     ASSERT_TRUE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_STYLUS));
     ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_TOUCHSCREEN));
 }
 
 TEST(MotionPredictorTest, Offset) {
-    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/1,
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/1, MODEL_PATH,
                               []() { return true /*enable prediction*/; });
-    predictor.record(getMotionEvent(DOWN, 0, 1, 30));
-    predictor.record(getMotionEvent(MOVE, 0, 1, 35));
-    std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40);
+    predictor.record(getMotionEvent(DOWN, 0, 1, 30ms));
+    predictor.record(getMotionEvent(MOVE, 0, 2, 35ms));
+    std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40 * NSEC_PER_MSEC);
     ASSERT_EQ(1u, predicted.size());
     ASSERT_GE(predicted[0]->getEventTime(), 41);
 }
 
+TEST(MotionPredictorTest, FollowsGesture) {
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, MODEL_PATH,
+                              []() { return true /*enable prediction*/; });
+
+    // MOVE without a DOWN is ignored.
+    predictor.record(getMotionEvent(MOVE, 1, 3, 10ms));
+    EXPECT_THAT(predictor.predict(20 * NSEC_PER_MSEC), IsEmpty());
+
+    predictor.record(getMotionEvent(DOWN, 2, 5, 20ms));
+    predictor.record(getMotionEvent(MOVE, 2, 7, 30ms));
+    predictor.record(getMotionEvent(MOVE, 3, 9, 40ms));
+    EXPECT_THAT(predictor.predict(50 * NSEC_PER_MSEC), SizeIs(1));
+
+    predictor.record(getMotionEvent(UP, 4, 11, 50ms));
+    EXPECT_THAT(predictor.predict(20 * NSEC_PER_MSEC), IsEmpty());
+}
+
+TEST(MotionPredictorTest, MultipleDevicesTracked) {
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, MODEL_PATH,
+                              []() { return true /*enable prediction*/; });
+
+    predictor.record(getMotionEvent(DOWN, 1, 3, 0ms, /*deviceId=*/0));
+    predictor.record(getMotionEvent(MOVE, 1, 3, 10ms, /*deviceId=*/0));
+    predictor.record(getMotionEvent(MOVE, 2, 5, 20ms, /*deviceId=*/0));
+    predictor.record(getMotionEvent(MOVE, 3, 7, 30ms, /*deviceId=*/0));
+
+    predictor.record(getMotionEvent(DOWN, 100, 300, 0ms, /*deviceId=*/1));
+    predictor.record(getMotionEvent(MOVE, 100, 300, 10ms, /*deviceId=*/1));
+    predictor.record(getMotionEvent(MOVE, 200, 500, 20ms, /*deviceId=*/1));
+    predictor.record(getMotionEvent(MOVE, 300, 700, 30ms, /*deviceId=*/1));
+
+    {
+        std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40 * NSEC_PER_MSEC);
+        ASSERT_EQ(2u, predicted.size());
+
+        // Order of the returned vector is not guaranteed.
+        std::vector<int32_t> seenDeviceIds;
+        for (const auto& prediction : predicted) {
+            seenDeviceIds.push_back(prediction->getDeviceId());
+        }
+        EXPECT_THAT(seenDeviceIds, UnorderedElementsAre(0, 1));
+    }
+
+    // End the gesture for device 0.
+    predictor.record(getMotionEvent(UP, 4, 9, 40ms, /*deviceId=*/0));
+    predictor.record(getMotionEvent(MOVE, 400, 900, 40ms, /*deviceId=*/1));
+
+    {
+        std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40 * NSEC_PER_MSEC);
+        ASSERT_EQ(1u, predicted.size());
+        ASSERT_EQ(predicted[0]->getDeviceId(), 1);
+    }
+}
+
 TEST(MotionPredictorTest, FlagDisablesPrediction) {
-    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, MODEL_PATH,
                               []() { return false /*disable prediction*/; });
-    predictor.record(getMotionEvent(DOWN, 0, 1, 30));
-    predictor.record(getMotionEvent(MOVE, 0, 1, 35));
-    std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40);
+    predictor.record(getMotionEvent(DOWN, 0, 1, 30ms));
+    predictor.record(getMotionEvent(MOVE, 0, 1, 35ms));
+    std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40 * NSEC_PER_MSEC);
     ASSERT_EQ(0u, predicted.size());
     ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_STYLUS));
     ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_TOUCHSCREEN));
diff --git a/libs/input/tests/TfLiteMotionPredictor_test.cpp b/libs/input/tests/TfLiteMotionPredictor_test.cpp
new file mode 100644
index 0000000..454f2aa
--- /dev/null
+++ b/libs/input/tests/TfLiteMotionPredictor_test.cpp
@@ -0,0 +1,179 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <cmath>
+#include <fstream>
+#include <ios>
+#include <iterator>
+#include <string>
+
+#include <android-base/file.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <input/TfLiteMotionPredictor.h>
+
+namespace android {
+namespace {
+
+using ::testing::Each;
+using ::testing::ElementsAre;
+using ::testing::FloatNear;
+
+std::string getModelPath() {
+#if defined(__ANDROID__)
+    return "/system/etc/motion_predictor_model.fb";
+#else
+    return base::GetExecutableDirectory() + "/motion_predictor_model.fb";
+#endif
+}
+
+TEST(TfLiteMotionPredictorTest, BuffersReadiness) {
+    TfLiteMotionPredictorBuffers buffers(/*inputLength=*/5);
+    ASSERT_FALSE(buffers.isReady());
+
+    buffers.pushSample(/*timestamp=*/0, {.position = {.x = 100, .y = 100}});
+    ASSERT_FALSE(buffers.isReady());
+
+    buffers.pushSample(/*timestamp=*/1, {.position = {.x = 100, .y = 100}});
+    ASSERT_FALSE(buffers.isReady());
+
+    // Two samples with distinct positions are required.
+    buffers.pushSample(/*timestamp=*/2, {.position = {.x = 100, .y = 110}});
+    ASSERT_TRUE(buffers.isReady());
+
+    buffers.reset();
+    ASSERT_FALSE(buffers.isReady());
+}
+
+TEST(TfLiteMotionPredictorTest, BuffersRecentData) {
+    TfLiteMotionPredictorBuffers buffers(/*inputLength=*/5);
+
+    buffers.pushSample(/*timestamp=*/1, {.position = {.x = 100, .y = 200}});
+    ASSERT_EQ(buffers.lastTimestamp(), 1);
+
+    buffers.pushSample(/*timestamp=*/2, {.position = {.x = 150, .y = 250}});
+    ASSERT_EQ(buffers.lastTimestamp(), 2);
+    ASSERT_TRUE(buffers.isReady());
+    ASSERT_EQ(buffers.axisFrom().position.x, 100);
+    ASSERT_EQ(buffers.axisFrom().position.y, 200);
+    ASSERT_EQ(buffers.axisTo().position.x, 150);
+    ASSERT_EQ(buffers.axisTo().position.y, 250);
+
+    // Position doesn't change, so neither do the axes.
+    buffers.pushSample(/*timestamp=*/3, {.position = {.x = 150, .y = 250}});
+    ASSERT_EQ(buffers.lastTimestamp(), 3);
+    ASSERT_TRUE(buffers.isReady());
+    ASSERT_EQ(buffers.axisFrom().position.x, 100);
+    ASSERT_EQ(buffers.axisFrom().position.y, 200);
+    ASSERT_EQ(buffers.axisTo().position.x, 150);
+    ASSERT_EQ(buffers.axisTo().position.y, 250);
+
+    buffers.pushSample(/*timestamp=*/4, {.position = {.x = 180, .y = 280}});
+    ASSERT_EQ(buffers.lastTimestamp(), 4);
+    ASSERT_TRUE(buffers.isReady());
+    ASSERT_EQ(buffers.axisFrom().position.x, 150);
+    ASSERT_EQ(buffers.axisFrom().position.y, 250);
+    ASSERT_EQ(buffers.axisTo().position.x, 180);
+    ASSERT_EQ(buffers.axisTo().position.y, 280);
+}
+
+TEST(TfLiteMotionPredictorTest, BuffersCopyTo) {
+    std::unique_ptr<TfLiteMotionPredictorModel> model =
+            TfLiteMotionPredictorModel::create(getModelPath().c_str());
+    TfLiteMotionPredictorBuffers buffers(model->inputLength());
+
+    buffers.pushSample(/*timestamp=*/1,
+                       {.position = {.x = 10, .y = 10},
+                        .pressure = 0,
+                        .orientation = 0,
+                        .tilt = 0.2});
+    buffers.pushSample(/*timestamp=*/2,
+                       {.position = {.x = 10, .y = 50},
+                        .pressure = 0.4,
+                        .orientation = M_PI / 4,
+                        .tilt = 0.3});
+    buffers.pushSample(/*timestamp=*/3,
+                       {.position = {.x = 30, .y = 50},
+                        .pressure = 0.5,
+                        .orientation = -M_PI / 4,
+                        .tilt = 0.4});
+    buffers.pushSample(/*timestamp=*/3,
+                       {.position = {.x = 30, .y = 60},
+                        .pressure = 0,
+                        .orientation = 0,
+                        .tilt = 0.5});
+    buffers.copyTo(*model);
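+
+    // The first sample only seeds the axis, so the buffers should hold three entries: step
+    // distances r = (40, 20, 10) and turn angles phi = (0, -pi/2, pi/2).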
+
+    const int zeroPadding = model->inputLength() - 3;
+    ASSERT_GE(zeroPadding, 0);
+
+    EXPECT_THAT(model->inputR().subspan(0, zeroPadding), Each(0));
+    EXPECT_THAT(model->inputPhi().subspan(0, zeroPadding), Each(0));
+    EXPECT_THAT(model->inputPressure().subspan(0, zeroPadding), Each(0));
+    EXPECT_THAT(model->inputTilt().subspan(0, zeroPadding), Each(0));
+    EXPECT_THAT(model->inputOrientation().subspan(0, zeroPadding), Each(0));
+
+    EXPECT_THAT(model->inputR().subspan(zeroPadding), ElementsAre(40, 20, 10));
+    EXPECT_THAT(model->inputPhi().subspan(zeroPadding), ElementsAre(0, -M_PI / 2, M_PI / 2));
+    EXPECT_THAT(model->inputPressure().subspan(zeroPadding), ElementsAre(0.4, 0.5, 0));
+    EXPECT_THAT(model->inputTilt().subspan(zeroPadding), ElementsAre(0.3, 0.4, 0.5));
+    EXPECT_THAT(model->inputOrientation().subspan(zeroPadding),
+                ElementsAre(FloatNear(-M_PI / 4, 1e-5), FloatNear(M_PI / 4, 1e-5),
+                            FloatNear(M_PI / 2, 1e-5)));
+}
+
+TEST(TfLiteMotionPredictorTest, ModelInputOutputLength) {
+    std::unique_ptr<TfLiteMotionPredictorModel> model =
+            TfLiteMotionPredictorModel::create(getModelPath().c_str());
+    ASSERT_GT(model->inputLength(), 0u);
+
+    const int inputLength = model->inputLength();
+    ASSERT_EQ(inputLength, model->inputR().size());
+    ASSERT_EQ(inputLength, model->inputPhi().size());
+    ASSERT_EQ(inputLength, model->inputPressure().size());
+    ASSERT_EQ(inputLength, model->inputOrientation().size());
+    ASSERT_EQ(inputLength, model->inputTilt().size());
+
+    ASSERT_TRUE(model->invoke());
+
+    ASSERT_EQ(model->outputR().size(), model->outputPhi().size());
+    ASSERT_EQ(model->outputR().size(), model->outputPressure().size());
+}
+
+TEST(TfLiteMotionPredictorTest, ModelOutput) {
+    std::unique_ptr<TfLiteMotionPredictorModel> model =
+            TfLiteMotionPredictorModel::create(getModelPath().c_str());
+    TfLiteMotionPredictorBuffers buffers(model->inputLength());
+
+    buffers.pushSample(/*timestamp=*/1, {.position = {.x = 100, .y = 200}, .pressure = 0.2});
+    buffers.pushSample(/*timestamp=*/2, {.position = {.x = 150, .y = 250}, .pressure = 0.4});
+    buffers.pushSample(/*timestamp=*/3, {.position = {.x = 180, .y = 280}, .pressure = 0.6});
+    buffers.copyTo(*model);
+
+    ASSERT_TRUE(model->invoke());
+
+    // The actual model output is implementation-defined, but it should at least be non-zero and
+    // non-NaN.
+    const auto is_valid = [](float value) { return !std::isnan(value) && value != 0; };
+    ASSERT_TRUE(std::all_of(model->outputR().begin(), model->outputR().end(), is_valid));
+    ASSERT_TRUE(std::all_of(model->outputPhi().begin(), model->outputPhi().end(), is_valid));
+    ASSERT_TRUE(
+            std::all_of(model->outputPressure().begin(), model->outputPressure().end(), is_valid));
+}
+
+} // namespace
+} // namespace android