Add TFLite model for motion prediction.

This model generates probabilistic motion predictions from a sequence
of relative input movements. The movements are converted into polar
coordinates (distance and angle) relative to an axis that follows the
current path, so that neither the orientation of the device nor that
of the inputs affects the predictions. The orientation of the input
device is likewise transformed to be relative to the path axis.
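
As a rough illustration only (not the model's actual implementation),
the path-relative polar conversion could look like the sketch below;
the struct and function names are hypothetical:

    #include <cmath>

    struct Vec2 { float x, y; };

    // Express the displacement `delta` in polar coordinates (r, phi)
    // relative to the direction of the previous path segment `axis`,
    // so rotating the whole gesture leaves the features unchanged.
    void toPathRelativePolar(Vec2 axis, Vec2 delta, float& r, float& phi) {
        r = std::hypot(delta.x, delta.y);
        const float axisAngle = std::atan2(axis.y, axis.x);
        const float deltaAngle = std::atan2(delta.y, delta.x);
        // Wrap the relative angle into [-pi, pi].
        phi = std::remainder(deltaAngle - axisAngle, 2.0f * float(M_PI));
    }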

The test cases verifying model efficacy are consolidated into CTS.

Bug: 167946763
Test: atest libinput_tests
PiperOrigin-RevId: 492068340
Change-Id: Icd8d90bd5a7ce79c699bfdb6367a4cbd8130441a
diff --git a/libs/input/tests/Android.bp b/libs/input/tests/Android.bp
index e2c0860..916a8f2 100644
--- a/libs/input/tests/Android.bp
+++ b/libs/input/tests/Android.bp
@@ -10,6 +10,7 @@
 
 cc_test {
     name: "libinput_tests",
+    cpp_std: "c++20",
     host_supported: true,
     srcs: [
         "IdGenerator_test.cpp",
@@ -18,12 +19,18 @@
         "InputEvent_test.cpp",
         "InputPublisherAndConsumer_test.cpp",
         "MotionPredictor_test.cpp",
+        "TfLiteMotionPredictor_test.cpp",
         "TouchResampling_test.cpp",
         "TouchVideoFrame_test.cpp",
         "VelocityTracker_test.cpp",
         "VerifiedInputEvent_test.cpp",
     ],
+    header_libs: [
+        "flatbuffer_headers",
+        "tensorflow_headers",
+    ],
     static_libs: [
+        "libgmock",
         "libgui_window_info_static",
         "libinput",
         "libui-types",
@@ -32,6 +39,7 @@
         "-Wall",
         "-Wextra",
         "-Werror",
+        "-Wno-unused-parameter",
     ],
     shared_libs: [
         "libbase",
@@ -39,10 +47,14 @@
         "libcutils",
         "liblog",
         "libPlatformProperties",
+        "libtflite",
         "libutils",
         "libvintf",
     ],
-    data: ["data/*"],
+    data: [
+        "data/*",
+        ":motion_predictor_model.fb",
+    ],
     test_options: {
         unit_test: true,
     },
diff --git a/libs/input/tests/MotionPredictor_test.cpp b/libs/input/tests/MotionPredictor_test.cpp
index d2b59a1..ce87c86 100644
--- a/libs/input/tests/MotionPredictor_test.cpp
+++ b/libs/input/tests/MotionPredictor_test.cpp
@@ -14,17 +14,36 @@
  * limitations under the License.
  */
 
+#include <chrono>
+
+#include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include <gui/constants.h>
 #include <input/Input.h>
 #include <input/MotionPredictor.h>
 
+using namespace std::literals::chrono_literals;
+
 namespace android {
 
+using ::testing::IsEmpty;
+using ::testing::SizeIs;
+using ::testing::UnorderedElementsAre;
+
+const char MODEL_PATH[] =
+#if defined(__ANDROID__)
+        "/system/etc/motion_predictor_model.fb";
+#else
+        "motion_predictor_model.fb";
+#endif
+
 constexpr int32_t DOWN = AMOTION_EVENT_ACTION_DOWN;
 constexpr int32_t MOVE = AMOTION_EVENT_ACTION_MOVE;
+constexpr int32_t UP = AMOTION_EVENT_ACTION_UP;
+constexpr nsecs_t NSEC_PER_MSEC = 1'000'000;
 
-static MotionEvent getMotionEvent(int32_t action, float x, float y, nsecs_t eventTime) {
+static MotionEvent getMotionEvent(int32_t action, float x, float y,
+                                  std::chrono::nanoseconds eventTime, int32_t deviceId = 0) {
     MotionEvent event;
     constexpr size_t pointerCount = 1;
     std::vector<PointerProperties> pointerProperties;
@@ -33,6 +52,7 @@
         PointerProperties properties;
         properties.clear();
         properties.id = i;
+        properties.toolType = AMOTION_EVENT_TOOL_TYPE_STYLUS;
         pointerProperties.push_back(properties);
         PointerCoords coords;
         coords.clear();
@@ -42,73 +62,93 @@
     }
 
     ui::Transform identityTransform;
-    event.initialize(InputEvent::nextId(), /*deviceId=*/0, AINPUT_SOURCE_STYLUS,
-                     ADISPLAY_ID_DEFAULT, {0}, action, /*actionButton=*/0, /*flags=*/0,
-                     AMOTION_EVENT_EDGE_FLAG_NONE, AMETA_NONE, /*buttonState=*/0,
-                     MotionClassification::NONE, identityTransform, /*xPrecision=*/0.1,
+    event.initialize(InputEvent::nextId(), deviceId, AINPUT_SOURCE_STYLUS, ADISPLAY_ID_DEFAULT, {0},
+                     action, /*actionButton=*/0, /*flags=*/0, AMOTION_EVENT_EDGE_FLAG_NONE,
+                     AMETA_NONE, /*buttonState=*/0, MotionClassification::NONE, identityTransform,
+                     /*xPrecision=*/0.1,
                      /*yPrecision=*/0.2, /*xCursorPosition=*/280, /*yCursorPosition=*/540,
-                     identityTransform, /*downTime=*/100, eventTime, pointerCount,
+                     identityTransform, /*downTime=*/100, eventTime.count(), pointerCount,
                      pointerProperties.data(), pointerCoords.data());
     return event;
 }
 
-/**
- * A linear motion should be predicted to be linear in the future
- */
-TEST(MotionPredictorTest, LinearPrediction) {
-    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
-                              []() { return true /*enable prediction*/; });
-
-    predictor.record(getMotionEvent(DOWN, 0, 1, 0));
-    predictor.record(getMotionEvent(MOVE, 1, 3, 10));
-    predictor.record(getMotionEvent(MOVE, 2, 5, 20));
-    predictor.record(getMotionEvent(MOVE, 3, 7, 30));
-    std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40);
-    ASSERT_EQ(1u, predicted.size());
-    ASSERT_EQ(predicted[0]->getX(0), 4);
-    ASSERT_EQ(predicted[0]->getY(0), 9);
-}
-
-/**
- * A still motion should be predicted to remain still
- */
-TEST(MotionPredictorTest, StationaryPrediction) {
-    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
-                              []() { return true /*enable prediction*/; });
-
-    predictor.record(getMotionEvent(DOWN, 0, 1, 0));
-    predictor.record(getMotionEvent(MOVE, 0, 1, 10));
-    predictor.record(getMotionEvent(MOVE, 0, 1, 20));
-    predictor.record(getMotionEvent(MOVE, 0, 1, 30));
-    std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40);
-    ASSERT_EQ(1u, predicted.size());
-    ASSERT_EQ(predicted[0]->getX(0), 0);
-    ASSERT_EQ(predicted[0]->getY(0), 1);
-}
-
 TEST(MotionPredictorTest, IsPredictionAvailable) {
-    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, MODEL_PATH,
                               []() { return true /*enable prediction*/; });
     ASSERT_TRUE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_STYLUS));
     ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_TOUCHSCREEN));
 }
 
 TEST(MotionPredictorTest, Offset) {
-    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/1,
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/1, MODEL_PATH,
                               []() { return true /*enable prediction*/; });
-    predictor.record(getMotionEvent(DOWN, 0, 1, 30));
-    predictor.record(getMotionEvent(MOVE, 0, 1, 35));
-    std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40);
+    predictor.record(getMotionEvent(DOWN, 0, 1, 30ms));
+    predictor.record(getMotionEvent(MOVE, 0, 2, 35ms));
+    std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40 * NSEC_PER_MSEC);
     ASSERT_EQ(1u, predicted.size());
     ASSERT_GE(predicted[0]->getEventTime(), 41);
 }
 
+TEST(MotionPredictorTest, FollowsGesture) {
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, MODEL_PATH,
+                              []() { return true /*enable prediction*/; });
+
+    // MOVE without a DOWN is ignored.
+    predictor.record(getMotionEvent(MOVE, 1, 3, 10ms));
+    EXPECT_THAT(predictor.predict(20 * NSEC_PER_MSEC), IsEmpty());
+
+    predictor.record(getMotionEvent(DOWN, 2, 5, 20ms));
+    predictor.record(getMotionEvent(MOVE, 2, 7, 30ms));
+    predictor.record(getMotionEvent(MOVE, 3, 9, 40ms));
+    EXPECT_THAT(predictor.predict(50 * NSEC_PER_MSEC), SizeIs(1));
+
+    predictor.record(getMotionEvent(UP, 4, 11, 50ms));
+    EXPECT_THAT(predictor.predict(20 * NSEC_PER_MSEC), IsEmpty());
+}
+
+TEST(MotionPredictorTest, MultipleDevicesTracked) {
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, MODEL_PATH,
+                              []() { return true /*enable prediction*/; });
+
+    predictor.record(getMotionEvent(DOWN, 1, 3, 0ms, /*deviceId=*/0));
+    predictor.record(getMotionEvent(MOVE, 1, 3, 10ms, /*deviceId=*/0));
+    predictor.record(getMotionEvent(MOVE, 2, 5, 20ms, /*deviceId=*/0));
+    predictor.record(getMotionEvent(MOVE, 3, 7, 30ms, /*deviceId=*/0));
+
+    predictor.record(getMotionEvent(DOWN, 100, 300, 0ms, /*deviceId=*/1));
+    predictor.record(getMotionEvent(MOVE, 100, 300, 10ms, /*deviceId=*/1));
+    predictor.record(getMotionEvent(MOVE, 200, 500, 20ms, /*deviceId=*/1));
+    predictor.record(getMotionEvent(MOVE, 300, 700, 30ms, /*deviceId=*/1));
+
+    {
+        std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40 * NSEC_PER_MSEC);
+        ASSERT_EQ(2u, predicted.size());
+
+        // Order of the returned vector is not guaranteed.
+        std::vector<int32_t> seenDeviceIds;
+        for (const auto& prediction : predicted) {
+            seenDeviceIds.push_back(prediction->getDeviceId());
+        }
+        EXPECT_THAT(seenDeviceIds, UnorderedElementsAre(0, 1));
+    }
+
+    // End the gesture for device 0.
+    predictor.record(getMotionEvent(UP, 4, 9, 40ms, /*deviceId=*/0));
+    predictor.record(getMotionEvent(MOVE, 400, 900, 40ms, /*deviceId=*/1));
+
+    {
+        std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40 * NSEC_PER_MSEC);
+        ASSERT_EQ(1u, predicted.size());
+        ASSERT_EQ(predicted[0]->getDeviceId(), 1);
+    }
+}
+
 TEST(MotionPredictorTest, FlagDisablesPrediction) {
-    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, MODEL_PATH,
                               []() { return false /*disable prediction*/; });
-    predictor.record(getMotionEvent(DOWN, 0, 1, 30));
-    predictor.record(getMotionEvent(MOVE, 0, 1, 35));
-    std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40);
+    predictor.record(getMotionEvent(DOWN, 0, 1, 30ms));
+    predictor.record(getMotionEvent(MOVE, 0, 1, 35ms));
+    std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40 * NSEC_PER_MSEC);
     ASSERT_EQ(0u, predicted.size());
     ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_STYLUS));
     ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_TOUCHSCREEN));
diff --git a/libs/input/tests/TfLiteMotionPredictor_test.cpp b/libs/input/tests/TfLiteMotionPredictor_test.cpp
new file mode 100644
index 0000000..454f2aa
--- /dev/null
+++ b/libs/input/tests/TfLiteMotionPredictor_test.cpp
@@ -0,0 +1,185 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <cmath>
+#include <fstream>
+#include <ios>
+#include <iterator>
+#include <string>
+
+#include <android-base/file.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <input/TfLiteMotionPredictor.h>
+
+namespace android {
+namespace {
+
+using ::testing::Each;
+using ::testing::ElementsAre;
+using ::testing::FloatNear;
+
+std::string getModelPath() {
+#if defined(__ANDROID__)
+    return "/system/etc/motion_predictor_model.fb";
+#else
+    return base::GetExecutableDirectory() + "/motion_predictor_model.fb";
+#endif
+}
+
+TEST(TfLiteMotionPredictorTest, BuffersReadiness) {
+    TfLiteMotionPredictorBuffers buffers(/*inputLength=*/5);
+    ASSERT_FALSE(buffers.isReady());
+
+    buffers.pushSample(/*timestamp=*/0, {.position = {.x = 100, .y = 100}});
+    ASSERT_FALSE(buffers.isReady());
+
+    buffers.pushSample(/*timestamp=*/1, {.position = {.x = 100, .y = 100}});
+    ASSERT_FALSE(buffers.isReady());
+
+    // Two samples with distinct positions are required.
+    buffers.pushSample(/*timestamp=*/2, {.position = {.x = 100, .y = 110}});
+    ASSERT_TRUE(buffers.isReady());
+
+    buffers.reset();
+    ASSERT_FALSE(buffers.isReady());
+}
+
+TEST(TfLiteMotionPredictorTest, BuffersRecentData) {
+    TfLiteMotionPredictorBuffers buffers(/*inputLength=*/5);
+
+    buffers.pushSample(/*timestamp=*/1, {.position = {.x = 100, .y = 200}});
+    ASSERT_EQ(buffers.lastTimestamp(), 1);
+
+    buffers.pushSample(/*timestamp=*/2, {.position = {.x = 150, .y = 250}});
+    ASSERT_EQ(buffers.lastTimestamp(), 2);
+    ASSERT_TRUE(buffers.isReady());
+    ASSERT_EQ(buffers.axisFrom().position.x, 100);
+    ASSERT_EQ(buffers.axisFrom().position.y, 200);
+    ASSERT_EQ(buffers.axisTo().position.x, 150);
+    ASSERT_EQ(buffers.axisTo().position.y, 250);
+
+    // Position doesn't change, so neither do the axes.
+    buffers.pushSample(/*timestamp=*/3, {.position = {.x = 150, .y = 250}});
+    ASSERT_EQ(buffers.lastTimestamp(), 3);
+    ASSERT_TRUE(buffers.isReady());
+    ASSERT_EQ(buffers.axisFrom().position.x, 100);
+    ASSERT_EQ(buffers.axisFrom().position.y, 200);
+    ASSERT_EQ(buffers.axisTo().position.x, 150);
+    ASSERT_EQ(buffers.axisTo().position.y, 250);
+
+    buffers.pushSample(/*timestamp=*/4, {.position = {.x = 180, .y = 280}});
+    ASSERT_EQ(buffers.lastTimestamp(), 4);
+    ASSERT_TRUE(buffers.isReady());
+    ASSERT_EQ(buffers.axisFrom().position.x, 150);
+    ASSERT_EQ(buffers.axisFrom().position.y, 250);
+    ASSERT_EQ(buffers.axisTo().position.x, 180);
+    ASSERT_EQ(buffers.axisTo().position.y, 280);
+}
+
+TEST(TfLiteMotionPredictorTest, BuffersCopyTo) {
+    std::unique_ptr<TfLiteMotionPredictorModel> model =
+            TfLiteMotionPredictorModel::create(getModelPath().c_str());
+    TfLiteMotionPredictorBuffers buffers(model->inputLength());
+
+    buffers.pushSample(/*timestamp=*/1,
+                       {.position = {.x = 10, .y = 10},
+                        .pressure = 0,
+                        .orientation = 0,
+                        .tilt = 0.2});
+    buffers.pushSample(/*timestamp=*/2,
+                       {.position = {.x = 10, .y = 50},
+                        .pressure = 0.4,
+                        .orientation = M_PI / 4,
+                        .tilt = 0.3});
+    buffers.pushSample(/*timestamp=*/3,
+                       {.position = {.x = 30, .y = 50},
+                        .pressure = 0.5,
+                        .orientation = -M_PI / 4,
+                        .tilt = 0.4});
+    buffers.pushSample(/*timestamp=*/3,
+                       {.position = {.x = 30, .y = 60},
+                        .pressure = 0,
+                        .orientation = 0,
+                        .tilt = 0.5});
+    buffers.copyTo(*model);
+
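+    // Four samples yield three relative movements; any leading input slots
+    // not covered by history are expected to be zero-padded.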
+    const int zeroPadding = model->inputLength() - 3;
+    ASSERT_GE(zeroPadding, 0);
+
+    EXPECT_THAT(model->inputR().subspan(0, zeroPadding), Each(0));
+    EXPECT_THAT(model->inputPhi().subspan(0, zeroPadding), Each(0));
+    EXPECT_THAT(model->inputPressure().subspan(0, zeroPadding), Each(0));
+    EXPECT_THAT(model->inputTilt().subspan(0, zeroPadding), Each(0));
+    EXPECT_THAT(model->inputOrientation().subspan(0, zeroPadding), Each(0));
+
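+    // The four samples yield three deltas: (0, 40), (20, 0) and (0, 10), giving
+    // distances of 40, 20 and 10. The first delta defines the initial axis, so
+    // its angle is 0; each later angle is relative to the previous segment's
+    // direction (-pi/2, then +pi/2).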
+    EXPECT_THAT(model->inputR().subspan(zeroPadding), ElementsAre(40, 20, 10));
+    EXPECT_THAT(model->inputPhi().subspan(zeroPadding), ElementsAre(0, -M_PI / 2, M_PI / 2));
+    EXPECT_THAT(model->inputPressure().subspan(zeroPadding), ElementsAre(0.4, 0.5, 0));
+    EXPECT_THAT(model->inputTilt().subspan(zeroPadding), ElementsAre(0.3, 0.4, 0.5));
+    EXPECT_THAT(model->inputOrientation().subspan(zeroPadding),
+                ElementsAre(FloatNear(-M_PI / 4, 1e-5), FloatNear(M_PI / 4, 1e-5),
+                            FloatNear(M_PI / 2, 1e-5)));
+}
+
+TEST(TfLiteMotionPredictorTest, ModelInputOutputLength) {
+    std::unique_ptr<TfLiteMotionPredictorModel> model =
+            TfLiteMotionPredictorModel::create(getModelPath().c_str());
+    ASSERT_GT(model->inputLength(), 0u);
+
+    const size_t inputLength = model->inputLength();
+    ASSERT_EQ(inputLength, model->inputR().size());
+    ASSERT_EQ(inputLength, model->inputPhi().size());
+    ASSERT_EQ(inputLength, model->inputPressure().size());
+    ASSERT_EQ(inputLength, model->inputOrientation().size());
+    ASSERT_EQ(inputLength, model->inputTilt().size());
+
+    ASSERT_TRUE(model->invoke());
+
+    ASSERT_EQ(model->outputR().size(), model->outputPhi().size());
+    ASSERT_EQ(model->outputR().size(), model->outputPressure().size());
+}
+
+TEST(TfLiteMotionPredictorTest, ModelOutput) {
+    std::unique_ptr<TfLiteMotionPredictorModel> model =
+            TfLiteMotionPredictorModel::create(getModelPath().c_str());
+    TfLiteMotionPredictorBuffers buffers(model->inputLength());
+
+    buffers.pushSample(/*timestamp=*/1, {.position = {.x = 100, .y = 200}, .pressure = 0.2});
+    buffers.pushSample(/*timestamp=*/2, {.position = {.x = 150, .y = 250}, .pressure = 0.4});
+    buffers.pushSample(/*timestamp=*/3, {.position = {.x = 180, .y = 280}, .pressure = 0.6});
+    buffers.copyTo(*model);
+
+    ASSERT_TRUE(model->invoke());
+
+    // The actual model output is implementation-defined, but it should at least be non-zero and
+    // non-NaN.
+    const auto is_valid = [](float value) { return !std::isnan(value) && value != 0; };
+    ASSERT_TRUE(std::all_of(model->outputR().begin(), model->outputR().end(), is_valid));
+    ASSERT_TRUE(std::all_of(model->outputPhi().begin(), model->outputPhi().end(), is_valid));
+    ASSERT_TRUE(
+            std::all_of(model->outputPressure().begin(), model->outputPressure().end(), is_valid));
+}
+
+} // namespace
+} // namespace android