Add jerk-thresholded pruning of motion predictions.

This changes MotionPredictor::predict() output when the
enable_prediction_pruning_via_jerk_thresholding flag is enabled: predictions
are partially or fully pruned when the recent pointer motion has high jerk.

Test: atest libinput_tests --host with a local flag override turning on
enable_prediction_pruning_via_jerk_thresholding.
Test: atest CtsInputTestCases
Test: atest MotionPredictorBenchmark MotionPredictorTest
Bug: 266747654

Change-Id: I11eb1972246468a1f3824656f5ac57e01e0359cd
diff --git a/libs/input/tests/Android.bp b/libs/input/tests/Android.bp
index e67a65a..ee140b7 100644
--- a/libs/input/tests/Android.bp
+++ b/libs/input/tests/Android.bp
@@ -36,6 +36,7 @@
         "tensorflow_headers",
     ],
     static_libs: [
+        "libflagtest",
         "libgmock",
         "libgui_window_info_static",
         "libinput",
diff --git a/libs/input/tests/MotionPredictor_test.cpp b/libs/input/tests/MotionPredictor_test.cpp
index f74874c..dc38fef 100644
--- a/libs/input/tests/MotionPredictor_test.cpp
+++ b/libs/input/tests/MotionPredictor_test.cpp
@@ -14,9 +14,12 @@
  * limitations under the License.
  */
 
+// TODO(b/331815574): Decouple this test from assumed config values.
 #include <chrono>
 #include <cmath>
 
+#include <com_android_input_flags.h>
+#include <flag_macros.h>
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include <gui/constants.h>
@@ -197,18 +200,14 @@
 TEST(MotionPredictorTest, FollowsGesture) {
     MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
                               []() { return true /*enable prediction*/; });
+    predictor.record(getMotionEvent(DOWN, 3.75, 3, 20ms));  // Gesture accelerating along x.
+    predictor.record(getMotionEvent(MOVE, 4.8, 3, 30ms));
+    predictor.record(getMotionEvent(MOVE, 6.2, 3, 40ms));
+    predictor.record(getMotionEvent(MOVE, 8, 3, 50ms));
+    EXPECT_NE(nullptr, predictor.predict(90 * NSEC_PER_MSEC));  // Predicts mid-gesture.
 
-    // MOVE without a DOWN is ignored.
-    predictor.record(getMotionEvent(MOVE, 1, 3, 10ms));
-    EXPECT_EQ(nullptr, predictor.predict(20 * NSEC_PER_MSEC));
-
-    predictor.record(getMotionEvent(DOWN, 2, 5, 20ms));
-    predictor.record(getMotionEvent(MOVE, 2, 7, 30ms));
-    predictor.record(getMotionEvent(MOVE, 3, 9, 40ms));
-    EXPECT_NE(nullptr, predictor.predict(50 * NSEC_PER_MSEC));
-
-    predictor.record(getMotionEvent(UP, 4, 11, 50ms));
-    EXPECT_EQ(nullptr, predictor.predict(20 * NSEC_PER_MSEC));
+    predictor.record(getMotionEvent(UP, 10.25, 3, 60ms));  // UP ends the gesture.
+    EXPECT_EQ(nullptr, predictor.predict(100 * NSEC_PER_MSEC));  // No prediction after UP.
 }
 
 TEST(MotionPredictorTest, MultipleDevicesNotSupported) {
@@ -250,6 +249,63 @@
     ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_TOUCHSCREEN));
 }
 
+TEST_WITH_FLAGS(
+        MotionPredictorTest, LowJerkNoPruning,
+        REQUIRES_FLAGS_ENABLED(ACONFIG_FLAG(com::android::input::flags,
+                                            enable_prediction_pruning_via_jerk_thresholding))) {
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
+                              []() { return true /*enable prediction*/; });
+
+    // Jerk is low (0.05 normalized), so no predictions should be pruned.
+    predictor.record(getMotionEvent(DOWN, 2, 7, 20ms));
+    predictor.record(getMotionEvent(MOVE, 2.75, 7, 30ms));
+    predictor.record(getMotionEvent(MOVE, 3.8, 7, 40ms));
+    predictor.record(getMotionEvent(MOVE, 5.2, 7, 50ms));
+    predictor.record(getMotionEvent(MOVE, 7, 7, 60ms));
+    std::unique_ptr<MotionEvent> predicted = predictor.predict(90 * NSEC_PER_MSEC);
+    ASSERT_NE(nullptr, predicted);  // Fatal check: predicted is dereferenced below.
+    EXPECT_EQ(static_cast<size_t>(5), predicted->getHistorySize() + 1);
+}
+
+TEST_WITH_FLAGS(
+        MotionPredictorTest, HighJerkPredictionsPruned,
+        REQUIRES_FLAGS_ENABLED(ACONFIG_FLAG(com::android::input::flags,
+                                            enable_prediction_pruning_via_jerk_thresholding))) {
+    const auto alwaysEnablePrediction = []() { return true; };
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, alwaysEnablePrediction);
+
+    // y leaps from 5 to 41233 within 40ms: jerk is extreme, so every prediction is pruned.
+    predictor.record(getMotionEvent(DOWN, 0, 5, 20ms));
+    predictor.record(getMotionEvent(MOVE, 0, 70, 30ms));
+    predictor.record(getMotionEvent(MOVE, 0, 139, 40ms));
+    predictor.record(getMotionEvent(MOVE, 0, 1421, 50ms));
+    predictor.record(getMotionEvent(MOVE, 0, 41233, 60ms));
+    const auto prediction = predictor.predict(90 * NSEC_PER_MSEC);
+    EXPECT_EQ(nullptr, prediction);
+}
+
+TEST_WITH_FLAGS(
+        MotionPredictorTest, MediumJerkPredictionsSomePruned,
+        REQUIRES_FLAGS_ENABLED(ACONFIG_FLAG(com::android::input::flags,
+                                            enable_prediction_pruning_via_jerk_thresholding))) {
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
+                              []() { return true /*enable prediction*/; });
+
+    // Jerk is medium: intended to lie halfway between the low and high jerk thresholds.
+    predictor.record(getMotionEvent(DOWN, 0, 4, 20ms));
+    predictor.record(getMotionEvent(MOVE, 0, 6.25, 30ms));
+    predictor.record(getMotionEvent(MOVE, 0, 9.4, 40ms));
+    predictor.record(getMotionEvent(MOVE, 0, 13.6, 50ms));
+    predictor.record(getMotionEvent(MOVE, 0, 19, 60ms));
+    std::unique_ptr<MotionEvent> predicted = predictor.predict(82 * NSEC_PER_MSEC);
+    ASSERT_NE(nullptr, predicted);  // Fatal check: predicted is dereferenced below.
+    // Halfway between the thresholds means half of the predictions are pruned. If the model
+    // prediction window is close enough to the predict() call time, half of the
+    // model predictions (5/2 -> 2) are output. NOTE(review): the assertion below expects 3,
+    // and y's third differences here are 0.15 (0.05 in LowJerkNoPruning); confirm vs config.
+    EXPECT_EQ(static_cast<size_t>(3), predicted->getHistorySize() + 1);
+}
+
 using AtomFields = MotionPredictorMetricsManager::AtomFields;
 using ReportAtomFunction = MotionPredictorMetricsManager::ReportAtomFunction;