Add jerk-thresholded pruning of motion predictions.

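When enable_prediction_pruning_via_jerk_thresholding is enabled,
MotionPredictor::predict() truncates its output based on the current jerk
magnitude: the higher the jerk between the model config's low-jerk and
high-jerk thresholds, the more of the prediction window is pruned.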

Test: atest libinput_tests --host with a local flag override that turns on
enable_prediction_pruning_via_jerk_thresholding.
Test: atest CtsInputTestCases
Test: atest MotionPredictorBenchmark MotionPredictorTest
Bug: 266747654

Change-Id: I11eb1972246468a1f3824656f5ac57e01e0359cd
diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp
index 77292d4..5b61d39 100644
--- a/libs/input/MotionPredictor.cpp
+++ b/libs/input/MotionPredictor.cpp
@@ -18,6 +18,7 @@
 
 #include <input/MotionPredictor.h>
 
+#include <algorithm>
 #include <array>
 #include <cinttypes>
 #include <cmath>
@@ -62,6 +63,11 @@
     return {.x = axisTo.x + x_delta, .y = axisTo.y + y_delta};
 }
 
+// Linearly maps x from [min, max] onto [0, 1], clamping the result. Expects min < max.
+float normalizeRange(float x, float min, float max) {
+    return std::min(1.0f, std::max(0.0f, (x - min) / (max - min)));
+}
+
 } // namespace
 
 // --- JerkTracker ---
@@ -255,6 +261,17 @@
     int64_t predictionTime = mBuffers->lastTimestamp();
     const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos;
 
+    const float jerkMagnitude = mJerkTracker.jerkMagnitude().value_or(0);
+    const float fractionKept =
+            1 - normalizeRange(jerkMagnitude, mModel->config().lowJerk, mModel->config().highJerk);
+    // float to ensure proper division below.
+    const float predictionTimeWindow = futureTime - predictionTime;
+    const int maxNumPredictions = static_cast<int>(
+            std::ceil(predictionTimeWindow / mModel->config().predictionInterval * fractionKept));
+    ALOGD_IF(isDebug(),
+             "jerk (d^3p/normalizedDt^3): %f, fraction of prediction window pruned: %f, max number "
+             "of predictions: %d",
+             jerkMagnitude, 1 - fractionKept, maxNumPredictions);
     for (size_t i = 0; i < static_cast<size_t>(predictedR.size()) && predictionTime <= futureTime;
          ++i) {
         if (predictedR[i] < mModel->config().distanceNoiseFloor) {
@@ -269,13 +286,12 @@
             break;
         }
         if (input_flags::enable_prediction_pruning_via_jerk_thresholding()) {
-            // TODO(b/266747654): Stop predictions if confidence is < some threshold
-            // Arbitrarily high pruning index, will correct once jerk thresholding is implemented.
-            const size_t upperBoundPredictionIndex = std::numeric_limits<size_t>::max();
-            if (i > upperBoundPredictionIndex) {
+            if (i >= static_cast<size_t>(maxNumPredictions)) {
                 break;
             }
         }
+        // TODO(b/266747654): Stop predictions if confidence is < some
+        // threshold. Currently predictions are pruned via jerk thresholding.
 
         const TfLiteMotionPredictorSample::Point predictedPoint =
                 convertPrediction(axisFrom, axisTo, predictedR[i], predictedPhi[i]);
diff --git a/libs/input/TfLiteMotionPredictor.cpp b/libs/input/TfLiteMotionPredictor.cpp
index d17476e..b843a4b 100644
--- a/libs/input/TfLiteMotionPredictor.cpp
+++ b/libs/input/TfLiteMotionPredictor.cpp
@@ -281,6 +281,8 @@
     Config config{
             .predictionInterval = parseXMLInt64(*configRoot, "prediction-interval"),
             .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"),
+            .lowJerk = parseXMLFloat(*configRoot, "low-jerk"),
+            .highJerk = parseXMLFloat(*configRoot, "high-jerk"),
     };
 
     return std::unique_ptr<TfLiteMotionPredictorModel>(
diff --git a/libs/input/tests/Android.bp b/libs/input/tests/Android.bp
index e67a65a..ee140b7 100644
--- a/libs/input/tests/Android.bp
+++ b/libs/input/tests/Android.bp
@@ -36,6 +36,7 @@
         "tensorflow_headers",
     ],
     static_libs: [
+        "libflagtest",
         "libgmock",
         "libgui_window_info_static",
         "libinput",
diff --git a/libs/input/tests/MotionPredictor_test.cpp b/libs/input/tests/MotionPredictor_test.cpp
index f74874c..dc38fef 100644
--- a/libs/input/tests/MotionPredictor_test.cpp
+++ b/libs/input/tests/MotionPredictor_test.cpp
@@ -14,9 +14,12 @@
  * limitations under the License.
  */
 
+// TODO(b/331815574): Decouple this test from assumed config values.
 #include <chrono>
 #include <cmath>
 
+#include <com_android_input_flags.h>
+#include <flag_macros.h>
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include <gui/constants.h>
@@ -197,18 +200,14 @@
 TEST(MotionPredictorTest, FollowsGesture) {
     MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
                               []() { return true /*enable prediction*/; });
+    predictor.record(getMotionEvent(DOWN, 3.75, 3, 20ms));  // Low-jerk motion along x.
+    predictor.record(getMotionEvent(MOVE, 4.8, 3, 30ms));
+    predictor.record(getMotionEvent(MOVE, 6.2, 3, 40ms));
+    predictor.record(getMotionEvent(MOVE, 8, 3, 50ms));
+    EXPECT_NE(nullptr, predictor.predict(90 * NSEC_PER_MSEC));  // Prediction mid-gesture.
 
-    // MOVE without a DOWN is ignored.
-    predictor.record(getMotionEvent(MOVE, 1, 3, 10ms));
-    EXPECT_EQ(nullptr, predictor.predict(20 * NSEC_PER_MSEC));
-
-    predictor.record(getMotionEvent(DOWN, 2, 5, 20ms));
-    predictor.record(getMotionEvent(MOVE, 2, 7, 30ms));
-    predictor.record(getMotionEvent(MOVE, 3, 9, 40ms));
-    EXPECT_NE(nullptr, predictor.predict(50 * NSEC_PER_MSEC));
-
-    predictor.record(getMotionEvent(UP, 4, 11, 50ms));
-    EXPECT_EQ(nullptr, predictor.predict(20 * NSEC_PER_MSEC));
+    predictor.record(getMotionEvent(UP, 10.25, 3, 60ms));
+    EXPECT_EQ(nullptr, predictor.predict(100 * NSEC_PER_MSEC));  // No prediction after UP.
 }
 
 TEST(MotionPredictorTest, MultipleDevicesNotSupported) {
@@ -250,6 +249,63 @@
     ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_TOUCHSCREEN));
 }
 
+TEST_WITH_FLAGS(
+        MotionPredictorTest, LowJerkNoPruning,
+        REQUIRES_FLAGS_ENABLED(ACONFIG_FLAG(com::android::input::flags,
+                                            enable_prediction_pruning_via_jerk_thresholding))) {
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
+                              []() { return true /*enable prediction*/; });
+
+    // Jerk is low (0.05 normalized), so no predictions are expected to be pruned.
+    predictor.record(getMotionEvent(DOWN, 2, 7, 20ms));
+    predictor.record(getMotionEvent(MOVE, 2.75, 7, 30ms));
+    predictor.record(getMotionEvent(MOVE, 3.8, 7, 40ms));
+    predictor.record(getMotionEvent(MOVE, 5.2, 7, 50ms));
+    predictor.record(getMotionEvent(MOVE, 7, 7, 60ms));
+    std::unique_ptr<MotionEvent> predicted = predictor.predict(90 * NSEC_PER_MSEC);
+    ASSERT_NE(nullptr, predicted);  // Fatal: the next line dereferences `predicted`.
+    EXPECT_EQ(static_cast<size_t>(5), predicted->getHistorySize() + 1);
+}
+
+TEST_WITH_FLAGS(
+        MotionPredictorTest, HighJerkPredictionsPruned,
+        REQUIRES_FLAGS_ENABLED(ACONFIG_FLAG(com::android::input::flags,
+                                            enable_prediction_pruning_via_jerk_thresholding))) {
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
+                              []() { return true /*enable prediction*/; });
+
+    // Jerk is incredibly high, so the whole prediction window is expected to be pruned.
+    predictor.record(getMotionEvent(DOWN, 0, 5, 20ms));
+    predictor.record(getMotionEvent(MOVE, 0, 70, 30ms));
+    predictor.record(getMotionEvent(MOVE, 0, 139, 40ms));
+    predictor.record(getMotionEvent(MOVE, 0, 1421, 50ms));
+    predictor.record(getMotionEvent(MOVE, 0, 41233, 60ms));
+    // With every prediction pruned, predict() should not produce an event at all.
+    EXPECT_EQ(nullptr, predictor.predict(90 * NSEC_PER_MSEC));
+}
+
+TEST_WITH_FLAGS(
+        MotionPredictorTest, MediumJerkPredictionsSomePruned,
+        REQUIRES_FLAGS_ENABLED(ACONFIG_FLAG(com::android::input::flags,
+                                            enable_prediction_pruning_via_jerk_thresholding))) {
+    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
+                              []() { return true /*enable prediction*/; });
+
+    // Jerk is medium (1.5 normalized, halfway between the lowJerk and highJerk config values).
+    predictor.record(getMotionEvent(DOWN, 0, 4, 20ms));
+    predictor.record(getMotionEvent(MOVE, 0, 6.25, 30ms));
+    predictor.record(getMotionEvent(MOVE, 0, 9.4, 40ms));
+    predictor.record(getMotionEvent(MOVE, 0, 13.6, 50ms));
+    predictor.record(getMotionEvent(MOVE, 0, 19, 60ms));
+    std::unique_ptr<MotionEvent> predicted = predictor.predict(82 * NSEC_PER_MSEC);
+    ASSERT_NE(nullptr, predicted);  // Fatal: the lines below dereference `predicted`.
+    // Halfway between lowJerk and highJerk means half of the prediction window
+    // is pruned. predict() rounds the number of kept predictions up (std::ceil),
+    // so with the current model config 3 predictions are output: 2 historical
+    // samples plus the event's current sample.
+    EXPECT_EQ(static_cast<size_t>(3), predicted->getHistorySize() + 1);
+}
+
 using AtomFields = MotionPredictorMetricsManager::AtomFields;
 using ReportAtomFunction = MotionPredictorMetricsManager::ReportAtomFunction;