Add TFLite model for motion prediction.

This model generates probabilistic motion predictions based on a
sequence of relative input movements. The input movements are converted
into polar coordinates (distance and angle) based on an axis that
follows the current path. This ensures that neither the orientation of
the device nor that of the inputs affects the predictions. The orientation
of the input device is also transformed to be relative to the path axis.
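
A minimal sketch of this conversion (hypothetical helper names; the
actual implementation lives in TfLiteMotionPredictor):

    #include <cmath>
    #include <utility>

    struct Point { float x, y; };

    // Illustrative only: polar (r, phi) of the step `from -> to`, with
    // phi measured against the direction of the previous step `axis`.
    std::pair<float, float> toPolar(Point axis, Point from, Point to) {
        const float dx = to.x - from.x;
        const float dy = to.y - from.y;
        const float r = std::hypot(dx, dy);
        const float phi = std::atan2(dy, dx) - std::atan2(axis.y, axis.x);
        return {r, phi};
    }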

The test cases that verify the model's efficacy are consolidated in CTS.

Bug: 167946763
Test: atest libinput_tests
PiperOrigin-RevId: 492068340
Change-Id: Icd8d90bd5a7ce79c699bfdb6367a4cbd8130441a
diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp
index 0fa0f12..0f889e8 100644
--- a/libs/input/MotionPredictor.cpp
+++ b/libs/input/MotionPredictor.cpp
@@ -18,118 +18,196 @@
 
 #include <input/MotionPredictor.h>
 
+#include <cinttypes>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <span>
+#include <string>
+#include <vector>
+
+#include <android-base/strings.h>
+#include <android/input.h>
+#include <log/log.h>
+
+#include <attestation/HmacKeyManager.h>
+#include <input/TfLiteMotionPredictor.h>
+
+namespace android {
+namespace {
+
+const char DEFAULT_MODEL_PATH[] = "/system/etc/motion_predictor_model.fb";
+const int64_t PREDICTION_INTERVAL_NANOS =
+        12500000 / 3; // ~4.17 ms. TODO(b/266747937): Get this from the model.
+
 /**
  * Log debug messages about predictions.
  * Enable this via "adb shell setprop log.tag.MotionPredictor DEBUG"
  */
-static bool isDebug() {
+bool isDebug() {
     return __android_log_is_loggable(ANDROID_LOG_DEBUG, LOG_TAG, ANDROID_LOG_INFO);
 }
 
-namespace android {
+// Converts a polar (r, phi) prediction to a Cartesian (x, y) point, where phi is
+// measured relative to the axis from axisFrom to axisTo and the result is offset from axisTo.
+TfLiteMotionPredictorSample::Point convertPrediction(
+        const TfLiteMotionPredictorSample::Point& axisFrom,
+        const TfLiteMotionPredictorSample::Point& axisTo, float r, float phi) {
+    const TfLiteMotionPredictorSample::Point axis = axisTo - axisFrom;
+    const float axis_phi = std::atan2(axis.y, axis.x);
+    const float x_delta = r * std::cos(axis_phi + phi);
+    const float y_delta = r * std::sin(axis_phi + phi);
+    return {.x = axisTo.x + x_delta, .y = axisTo.y + y_delta};
+}
+
+} // namespace
 
 // --- MotionPredictor ---
 
-MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos,
+MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos, const char* modelPath,
                                  std::function<bool()> checkMotionPredictionEnabled)
       : mPredictionTimestampOffsetNanos(predictionTimestampOffsetNanos),
-        mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)) {}
+        mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)),
+        mModel(TfLiteMotionPredictorModel::create(modelPath == nullptr ? DEFAULT_MODEL_PATH
+                                                                       : modelPath)) {}
 
 void MotionPredictor::record(const MotionEvent& event) {
-    mEvents.push_back({});
-    mEvents.back().copyFrom(&event, /*keepHistory=*/true);
-    if (mEvents.size() > 2) {
-        // Just need 2 samples in order to extrapolate
-        mEvents.erase(mEvents.begin());
+    if (!isPredictionAvailable(event.getDeviceId(), event.getSource())) {
+        ALOGE("Prediction not supported for device %d's %s source", event.getDeviceId(),
+              inputEventSourceToString(event.getSource()).c_str());
+        return;
     }
+
+    TfLiteMotionPredictorBuffers& buffers =
+            mDeviceBuffers.try_emplace(event.getDeviceId(), mModel->inputLength()).first->second;
+
+    const int32_t action = event.getActionMasked();
+    if (action == AMOTION_EVENT_ACTION_UP) {
+        ALOGD_IF(isDebug(), "End of event stream");
+        buffers.reset();
+        return;
+    } else if (action != AMOTION_EVENT_ACTION_DOWN && action != AMOTION_EVENT_ACTION_MOVE) {
+        ALOGD_IF(isDebug(), "Skipping unsupported %s action",
+                 MotionEvent::actionToString(action).c_str());
+        return;
+    }
+
+    if (event.getPointerCount() != 1) {
+        ALOGD_IF(isDebug(), "Prediction not supported for multiple pointers");
+        return;
+    }
+
+    const int32_t toolType = event.getPointerProperties(0)->toolType;
+    if (toolType != AMOTION_EVENT_TOOL_TYPE_STYLUS) {
+        ALOGD_IF(isDebug(), "Prediction not supported for non-stylus tool: %s",
+                 motionToolTypeToString(toolType));
+        return;
+    }
+
+    for (size_t i = 0; i <= event.getHistorySize(); ++i) {
+        if (event.isResampled(0, i)) {
+            continue;
+        }
+        const PointerCoords* coords = event.getHistoricalRawPointerCoords(0, i);
+        buffers.pushSample(event.getHistoricalEventTime(i),
+                           {
+                                   .position.x = coords->getAxisValue(AMOTION_EVENT_AXIS_X),
+                                   .position.y = coords->getAxisValue(AMOTION_EVENT_AXIS_Y),
+                                   .pressure = event.getHistoricalPressure(0, i),
+                                   .tilt = event.getHistoricalAxisValue(AMOTION_EVENT_AXIS_TILT, 0,
+                                                                        i),
+                                   .orientation = event.getHistoricalOrientation(0, i),
+                           });
+    }
+
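+    // Keep the latest event for this device; predict() uses it to source the
+    // event properties (device ID, source, transforms, etc.) of the generated
+    // prediction events.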
+    mLastEvents.try_emplace(event.getDeviceId())
+            .first->second.copyFrom(&event, /*keepHistory=*/false);
 }
 
-/**
- * This is an example implementation that should be replaced with the actual prediction.
- * The returned MotionEvent should be similar to the incoming MotionEvent, except for the
- * fields that are predicted:
- *
- * 1) event.getEventTime
- * 2) event.getPointerCoords
- *
- * The returned event should not contain any of the real, existing data. It should only
- * contain the predicted samples.
- */
 std::vector<std::unique_ptr<MotionEvent>> MotionPredictor::predict(nsecs_t timestamp) {
-    if (mEvents.size() < 2) {
-        return {};
-    }
+    std::vector<std::unique_ptr<MotionEvent>> predictions;
 
-    const MotionEvent& event = mEvents.back();
-    if (!isPredictionAvailable(event.getDeviceId(), event.getSource())) {
-        return {};
-    }
-
-    std::unique_ptr<MotionEvent> prediction = std::make_unique<MotionEvent>();
-    std::vector<PointerCoords> futureCoords;
-    const nsecs_t futureTime = timestamp + mPredictionTimestampOffsetNanos;
-    const nsecs_t currentTime = event.getEventTime();
-    const MotionEvent& previous = mEvents.rbegin()[1];
-    const nsecs_t oldTime = previous.getEventTime();
-    if (currentTime == oldTime) {
-        // This can happen if it's an ACTION_POINTER_DOWN event, for example.
-        return {}; // prevent division by zero.
-    }
-
-    for (size_t i = 0; i < event.getPointerCount(); i++) {
-        const int32_t pointerId = event.getPointerId(i);
-        const PointerCoords* currentPointerCoords = event.getRawPointerCoords(i);
-        const float currentX = currentPointerCoords->getAxisValue(AMOTION_EVENT_AXIS_X);
-        const float currentY = currentPointerCoords->getAxisValue(AMOTION_EVENT_AXIS_Y);
-
-        PointerCoords coords;
-        coords.clear();
-
-        ssize_t index = previous.findPointerIndex(pointerId);
-        if (index >= 0) {
-            // We have old data for this pointer. Compute the prediction.
-            const PointerCoords* oldPointerCoords = previous.getRawPointerCoords(index);
-            const float oldX = oldPointerCoords->getAxisValue(AMOTION_EVENT_AXIS_X);
-            const float oldY = oldPointerCoords->getAxisValue(AMOTION_EVENT_AXIS_Y);
-
-            // Let's do a linear interpolation while waiting for a real model
-            const float scale =
-                    static_cast<float>(futureTime - currentTime) / (currentTime - oldTime);
-            const float futureX = currentX + (currentX - oldX) * scale;
-            const float futureY = currentY + (currentY - oldY) * scale;
-
-            coords.setAxisValue(AMOTION_EVENT_AXIS_X, futureX);
-            coords.setAxisValue(AMOTION_EVENT_AXIS_Y, futureY);
-            ALOGD_IF(isDebug(),
-                     "Prediction by %.1f ms, (%.1f, %.1f), (%.1f, %.1f) --> (%.1f, %.1f)",
-                     (futureTime - event.getEventTime()) * 1E-6, oldX, oldY, currentX, currentY,
-                     futureX, futureY);
+    for (const auto& [deviceId, buffer] : mDeviceBuffers) {
+        if (!buffer.isReady()) {
+            continue;
         }
 
-        futureCoords.push_back(coords);
+        buffer.copyTo(*mModel);
+        LOG_ALWAYS_FATAL_IF(!mModel->invoke());
+
+        // Read out the predictions.
+        const std::span<const float> predictedR = mModel->outputR();
+        const std::span<const float> predictedPhi = mModel->outputPhi();
+        const std::span<const float> predictedPressure = mModel->outputPressure();
+
+        TfLiteMotionPredictorSample::Point axisFrom = buffer.axisFrom().position;
+        TfLiteMotionPredictorSample::Point axisTo = buffer.axisTo().position;
+
+        if (isDebug()) {
+            ALOGD("deviceId: %d", deviceId);
+            ALOGD("axisFrom: %f, %f", axisFrom.x, axisFrom.y);
+            ALOGD("axisTo: %f, %f", axisTo.x, axisTo.y);
+            ALOGD("mInputR: %s", base::Join(mModel->inputR(), ", ").c_str());
+            ALOGD("mInputPhi: %s", base::Join(mModel->inputPhi(), ", ").c_str());
+            ALOGD("mInputPressure: %s", base::Join(mModel->inputPressure(), ", ").c_str());
+            ALOGD("mInputTilt: %s", base::Join(mModel->inputTilt(), ", ").c_str());
+            ALOGD("mInputOrientation: %s", base::Join(mModel->inputOrientation(), ", ").c_str());
+            ALOGD("predictedR: %s", base::Join(predictedR, ", ").c_str());
+            ALOGD("predictedPhi: %s", base::Join(predictedPhi, ", ").c_str());
+            ALOGD("predictedPressure: %s", base::Join(predictedPressure, ", ").c_str());
+        }
+
+        const MotionEvent& event = mLastEvents[deviceId];
+        bool hasPredictions = false;
+        std::unique_ptr<MotionEvent> prediction = std::make_unique<MotionEvent>();
+        int64_t predictionTime = buffer.lastTimestamp();
+        const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos;
+
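+        // Convert each predicted polar offset to a Cartesian point by chaining:
+        // the last two input points define the first axis, and each converted
+        // prediction then serves as the axis for the next one.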
+        for (size_t i = 0; i < predictedR.size() && predictionTime <= futureTime; ++i) {
+            const TfLiteMotionPredictorSample::Point point =
+                    convertPrediction(axisFrom, axisTo, predictedR[i], predictedPhi[i]);
+            // TODO(b/266747654): Stop predictions if confidence is < some threshold.
+
+            ALOGD_IF(isDebug(), "prediction %d: %f, %f", i, point.x, point.y);
+            PointerCoords coords;
+            coords.clear();
+            coords.setAxisValue(AMOTION_EVENT_AXIS_X, point.x);
+            coords.setAxisValue(AMOTION_EVENT_AXIS_Y, point.y);
+            // TODO(b/266747654): Stop predictions if predicted pressure is < some threshold.
+            coords.setAxisValue(AMOTION_EVENT_AXIS_PRESSURE, predictedPressure[i]);
+
+            predictionTime += PREDICTION_INTERVAL_NANOS;
+            if (i == 0) {
+                hasPredictions = true;
+                prediction->initialize(InputEvent::nextId(), event.getDeviceId(), event.getSource(),
+                                       event.getDisplayId(), INVALID_HMAC,
+                                       AMOTION_EVENT_ACTION_MOVE, event.getActionButton(),
+                                       event.getFlags(), event.getEdgeFlags(), event.getMetaState(),
+                                       event.getButtonState(), event.getClassification(),
+                                       event.getTransform(), event.getXPrecision(),
+                                       event.getYPrecision(), event.getRawXCursorPosition(),
+                                       event.getRawYCursorPosition(), event.getRawTransform(),
+                                       event.getDownTime(), predictionTime, event.getPointerCount(),
+                                       event.getPointerProperties(), &coords);
+            } else {
+                prediction->addSample(predictionTime, &coords);
+            }
+
+            axisFrom = axisTo;
+            axisTo = point;
+        }
+        // TODO(b/266747511): Interpolate to futureTime?
+        if (hasPredictions) {
+            predictions.push_back(std::move(prediction));
+        }
     }
-
-    /**
-     * The process of adding samples is different for the first and subsequent samples:
-     * 1. Add the first sample via 'initialize' as below
-     * 2. Add subsequent samples via 'addSample'
-     */
-    prediction->initialize(event.getId(), event.getDeviceId(), event.getSource(),
-                           event.getDisplayId(), event.getHmac(), event.getAction(),
-                           event.getActionButton(), event.getFlags(), event.getEdgeFlags(),
-                           event.getMetaState(), event.getButtonState(), event.getClassification(),
-                           event.getTransform(), event.getXPrecision(), event.getYPrecision(),
-                           event.getRawXCursorPosition(), event.getRawYCursorPosition(),
-                           event.getRawTransform(), event.getDownTime(), futureTime,
-                           event.getPointerCount(), event.getPointerProperties(),
-                           futureCoords.data());
-
-    // To add more predicted samples, use 'addSample':
-    prediction->addSample(futureTime + 1, futureCoords.data());
-
-    std::vector<std::unique_ptr<MotionEvent>> out;
-    out.push_back(std::move(prediction));
-    return out;
+    return predictions;
 }
 
 bool MotionPredictor::isPredictionAvailable(int32_t /*deviceId*/, int32_t source) {