Add a jerk tracker to MotionPredictor
This is a no-op
Test: atest libinput_tests
Bug: 266747654
Change-Id: I0f67a1d8ea59a6851afa6f523b6e00b2281745d0
diff --git a/include/input/MotionPredictor.h b/include/input/MotionPredictor.h
index 3b6e401..f715039 100644
--- a/include/input/MotionPredictor.h
+++ b/include/input/MotionPredictor.h
@@ -16,6 +16,7 @@
#pragma once
+#include <array>
#include <cstdint>
#include <memory>
#include <mutex>
@@ -28,6 +29,7 @@
#include <android/sysprop/InputProperties.sysprop.h>
#include <input/Input.h>
#include <input/MotionPredictorMetricsManager.h>
+#include <input/RingBuffer.h>
#include <input/TfLiteMotionPredictor.h>
#include <utils/Timers.h> // for nsecs_t
@@ -37,6 +39,31 @@
return sysprop::InputProperties::enable_motion_prediction().value_or(true);
}
+// Tracker to calculate jerk from motion position samples.
+class JerkTracker {
+public:
+ // Initialize the tracker. If normalizedDt is true, assume that each sample pushed has dt=1.
+ JerkTracker(bool normalizedDt);
+
+ // Add a position to the tracker and update derivative estimates.
+ void pushSample(int64_t timestamp, float xPos, float yPos);
+
+ // Reset JerkTracker for a new motion input.
+ void reset();
+
+ // Return last jerk calculation, if enough samples have been collected.
+ // Jerk is defined as the 3rd derivative of position (change in
+ // acceleration) and has the units of d^3p/dt^3.
+ std::optional<float> jerkMagnitude() const;
+
+private:
+ const bool mNormalizedDt;
+
+ RingBuffer<int64_t> mTimestamps{4};
+ std::array<float, 4> mXDerivatives{}; // [x, x', x'', x''']
+ std::array<float, 4> mYDerivatives{}; // [y, y', y'', y''']
+};
+
/**
* Given a set of MotionEvents for the current gesture, predict the motion. The returned MotionEvent
* contains a set of samples in the future.
@@ -97,6 +124,11 @@
std::unique_ptr<TfLiteMotionPredictorBuffers> mBuffers;
std::optional<MotionEvent> mLastEvent;
+ // mJerkTracker assumes normalized dt = 1 between recorded samples because
+ // the underlying mModel input also assumes fixed-interval samples.
+ // Normalized dt as 1 is also used to correspond with the similar Jank
+ // implementation from the JetPack MotionPredictor implementation.
+ JerkTracker mJerkTracker{true};
std::optional<MotionPredictorMetricsManager> mMetricsManager;
diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp
index 1df88dd..77292d4 100644
--- a/libs/input/MotionPredictor.cpp
+++ b/libs/input/MotionPredictor.cpp
@@ -18,12 +18,15 @@
#include <input/MotionPredictor.h>
+#include <array>
#include <cinttypes>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
+#include <optional>
#include <string>
+#include <utility>
#include <vector>
#include <android-base/logging.h>
@@ -61,6 +64,66 @@
} // namespace
+// --- JerkTracker ---
+
+JerkTracker::JerkTracker(bool normalizedDt) : mNormalizedDt(normalizedDt) {}
+
+void JerkTracker::pushSample(int64_t timestamp, float xPos, float yPos) {
+ mTimestamps.pushBack(timestamp);
+ const int numSamples = mTimestamps.size();
+
+ std::array<float, 4> newXDerivatives;
+ std::array<float, 4> newYDerivatives;
+
+ /**
+ * Diagram showing the calculation of higher order derivatives of sample x3
+ * collected at time=t3.
+ * Terms in parentheses are not stored (and not needed for calculations)
+ * t0 ----- t1 ----- t2 ----- t3
+ * (x0)-----(x1) ----- x2 ----- x3
+ * (x'0) --- x'1 --- x'2
+ * x''0 - x''1
+ * x'''0
+ *
+ * In this example:
+ * x'2 = (x3 - x2) / (t3 - t2)
+ * x''1 = (x'2 - x'1) / (t2 - t1)
+ * x'''0 = (x''1 - x''0) / (t1 - t0)
+ * Therefore, timestamp history is needed to calculate higher order derivatives,
+ * compared to just the last calculated derivative sample.
+ *
+ * If mNormalizedDt = true, then dt = 1 and the division is moot.
+ */
+ for (int i = 0; i < numSamples; ++i) {
+ if (i == 0) {
+ newXDerivatives[i] = xPos;
+ newYDerivatives[i] = yPos;
+ } else {
+ newXDerivatives[i] = newXDerivatives[i - 1] - mXDerivatives[i - 1];
+ newYDerivatives[i] = newYDerivatives[i - 1] - mYDerivatives[i - 1];
+ if (!mNormalizedDt) {
+ const float dt = mTimestamps[numSamples - i] - mTimestamps[numSamples - i - 1];
+ newXDerivatives[i] = newXDerivatives[i] / dt;
+ newYDerivatives[i] = newYDerivatives[i] / dt;
+ }
+ }
+ }
+
+ std::swap(newXDerivatives, mXDerivatives);
+ std::swap(newYDerivatives, mYDerivatives);
+}
+
+void JerkTracker::reset() {
+ mTimestamps.clear();
+}
+
+std::optional<float> JerkTracker::jerkMagnitude() const {
+ if (mTimestamps.size() == mTimestamps.capacity()) {
+ return std::hypot(mXDerivatives[3], mYDerivatives[3]);
+ }
+ return std::nullopt;
+}
+
// --- MotionPredictor ---
MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos,
@@ -107,6 +170,7 @@
if (action == AMOTION_EVENT_ACTION_UP || action == AMOTION_EVENT_ACTION_CANCEL) {
ALOGD_IF(isDebug(), "End of event stream");
mBuffers->reset();
+ mJerkTracker.reset();
mLastEvent.reset();
return {};
} else if (action != AMOTION_EVENT_ACTION_DOWN && action != AMOTION_EVENT_ACTION_MOVE) {
@@ -141,6 +205,9 @@
0, i),
.orientation = event.getHistoricalOrientation(0, i),
});
+ mJerkTracker.pushSample(event.getHistoricalEventTime(i),
+ coords->getAxisValue(AMOTION_EVENT_AXIS_X),
+ coords->getAxisValue(AMOTION_EVENT_AXIS_Y));
}
if (!mLastEvent) {
diff --git a/libs/input/tests/MotionPredictor_test.cpp b/libs/input/tests/MotionPredictor_test.cpp
index 3343114..f74874c 100644
--- a/libs/input/tests/MotionPredictor_test.cpp
+++ b/libs/input/tests/MotionPredictor_test.cpp
@@ -15,6 +15,7 @@
*/
#include <chrono>
+#include <cmath>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -65,6 +66,108 @@
return event;
}
+TEST(JerkTrackerTest, JerkReadiness) {
+ JerkTracker jerkTracker(true);
+ EXPECT_FALSE(jerkTracker.jerkMagnitude());
+ jerkTracker.pushSample(/*timestamp=*/0, 20, 50);
+ EXPECT_FALSE(jerkTracker.jerkMagnitude());
+ jerkTracker.pushSample(/*timestamp=*/1, 25, 53);
+ EXPECT_FALSE(jerkTracker.jerkMagnitude());
+ jerkTracker.pushSample(/*timestamp=*/2, 30, 60);
+ EXPECT_FALSE(jerkTracker.jerkMagnitude());
+ jerkTracker.pushSample(/*timestamp=*/3, 35, 70);
+ EXPECT_TRUE(jerkTracker.jerkMagnitude());
+ jerkTracker.reset();
+ EXPECT_FALSE(jerkTracker.jerkMagnitude());
+ jerkTracker.pushSample(/*timestamp=*/4, 30, 60);
+ EXPECT_FALSE(jerkTracker.jerkMagnitude());
+}
+
+TEST(JerkTrackerTest, JerkCalculationNormalizedDtTrue) {
+ JerkTracker jerkTracker(true);
+ jerkTracker.pushSample(/*timestamp=*/0, 20, 50);
+ jerkTracker.pushSample(/*timestamp=*/1, 25, 53);
+ jerkTracker.pushSample(/*timestamp=*/2, 30, 60);
+ jerkTracker.pushSample(/*timestamp=*/3, 45, 70);
+ /**
+ * Jerk derivative table
+ * x: 20 25 30 45
+ * x': 5 5 15
+ * x'': 0 10
+ * x''': 10
+ *
+ * y: 50 53 60 70
+ * y': 3 7 10
+ * y'': 4 3
+ * y''': -1
+ */
+ EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), std::hypot(10, -1));
+ jerkTracker.pushSample(/*timestamp=*/4, 20, 65);
+ /**
+ * (continuing from above table)
+ * x: 45 -> 20
+ * x': 15 -> -25
+ * x'': 10 -> -40
+ * x''': -50
+ *
+ * y: 70 -> 65
+ * y': 10 -> -5
+ * y'': 3 -> -15
+ * y''': -18
+ */
+ EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), std::hypot(-50, -18));
+}
+
+TEST(JerkTrackerTest, JerkCalculationNormalizedDtFalse) {
+ JerkTracker jerkTracker(false);
+ jerkTracker.pushSample(/*timestamp=*/0, 20, 50);
+ jerkTracker.pushSample(/*timestamp=*/10, 25, 53);
+ jerkTracker.pushSample(/*timestamp=*/20, 30, 60);
+ jerkTracker.pushSample(/*timestamp=*/30, 45, 70);
+ /**
+ * Jerk derivative table
+ * x: 20 25 30 45
+ * x': .5 .5 1.5
+ * x'': 0 .1
+ * x''': .01
+ *
+ * y: 50 53 60 70
+ * y': .3 .7 1
+ * y'': .04 .03
+ * y''': -.001
+ */
+ EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), std::hypot(.01, -.001));
+ jerkTracker.pushSample(/*timestamp=*/50, 20, 65);
+ /**
+ * (continuing from above table)
+ * x: 45 -> 20
+ * x': 1.5 -> -1.25 (delta above, divide by 20)
+ * x'': .1 -> -.275 (delta above, divide by 10)
+ * x''': -.0375 (delta above, divide by 10)
+ *
+ * y: 70 -> 65
+ * y': 1 -> -.25 (delta above, divide by 20)
+ * y'': .03 -> -.125 (delta above, divide by 10)
+ * y''': -.0155 (delta above, divide by 10)
+ */
+ EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), std::hypot(-.0375, -.0155));
+}
+
+TEST(JerkTrackerTest, JerkCalculationAfterReset) {
+ JerkTracker jerkTracker(true);
+ jerkTracker.pushSample(/*timestamp=*/0, 20, 50);
+ jerkTracker.pushSample(/*timestamp=*/1, 25, 53);
+ jerkTracker.pushSample(/*timestamp=*/2, 30, 60);
+ jerkTracker.pushSample(/*timestamp=*/3, 45, 70);
+ jerkTracker.pushSample(/*timestamp=*/4, 20, 65);
+ jerkTracker.reset();
+ jerkTracker.pushSample(/*timestamp=*/5, 20, 50);
+ jerkTracker.pushSample(/*timestamp=*/6, 25, 53);
+ jerkTracker.pushSample(/*timestamp=*/7, 30, 60);
+ jerkTracker.pushSample(/*timestamp=*/8, 45, 70);
+ EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), std::hypot(10, -1));
+}
+
TEST(MotionPredictorTest, IsPredictionAvailable) {
MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
[]() { return true /*enable prediction*/; });