Refactor JerkTracker and MotionPredictor for better testing.
Changes include renaming forgetFactor to alpha.
Test: atest libinput_tests
Bug: 266747654
Bug: 353161308
Flag: com.android.input.flags.enable_prediction_pruning_via_jerk_thresholding
Change-Id: Icd056d36a3d7894c6c9b4b957233002ad961a9a1
diff --git a/data/etc/input/motion_predictor_config.xml b/data/etc/input/motion_predictor_config.xml
index 14540ec..f593eda 100644
--- a/data/etc/input/motion_predictor_config.xml
+++ b/data/etc/input/motion_predictor_config.xml
@@ -38,7 +38,8 @@
<low-jerk>1.5</low-jerk>
<high-jerk>2.0</high-jerk>
- <!-- The forget factor in the first-order IIR filter for jerk smoothing -->
- <jerk-forget-factor>0.25</jerk-forget-factor>
+ <!-- The alpha in the first-order IIR filter for jerk smoothing. An alpha
+ of 1 results in no smoothing.-->
+ <jerk-alpha>0.25</jerk-alpha>
</motion-predictor>
diff --git a/include/input/MotionPredictor.h b/include/input/MotionPredictor.h
index 2f1ef86..200c301 100644
--- a/include/input/MotionPredictor.h
+++ b/include/input/MotionPredictor.h
@@ -43,7 +43,9 @@
class JerkTracker {
public:
// Initialize the tracker. If normalizedDt is true, assume that each sample pushed has dt=1.
- JerkTracker(bool normalizedDt);
+ // alpha is the coefficient of the first-order IIR filter for jerk. A factor of 1 results
+ // in no smoothing.
+ JerkTracker(bool normalizedDt, float alpha);
// Add a position to the tracker and update derivative estimates.
void pushSample(int64_t timestamp, float xPos, float yPos);
@@ -56,15 +58,10 @@
// acceleration) and has the units of d^3p/dt^3.
std::optional<float> jerkMagnitude() const;
- // forgetFactor is the coefficient of the first-order IIR filter for jerk. A factor of 1 results
- // in no smoothing.
- void setForgetFactor(float forgetFactor);
- float getForgetFactor() const;
-
private:
const bool mNormalizedDt;
// Coefficient of first-order IIR filter to smooth jerk calculation.
- float mForgetFactor = 1;
+ const float mAlpha;
RingBuffer<int64_t> mTimestamps{4};
std::array<float, 4> mXDerivatives{}; // [x, x', x'', x''']
@@ -124,11 +121,6 @@
bool isPredictionAvailable(int32_t deviceId, int32_t source);
- /**
- * Currently used to expose config constants in testing.
- */
- const TfLiteMotionPredictorModel::Config& getModelConfig();
-
private:
const nsecs_t mPredictionTimestampOffsetNanos;
const std::function<bool()> mCheckMotionPredictionEnabled;
@@ -137,15 +129,17 @@
std::unique_ptr<TfLiteMotionPredictorBuffers> mBuffers;
std::optional<MotionEvent> mLastEvent;
- // mJerkTracker assumes normalized dt = 1 between recorded samples because
- // the underlying mModel input also assumes fixed-interval samples.
- // Normalized dt as 1 is also used to correspond with the similar Jank
- // implementation from the JetPack MotionPredictor implementation.
- JerkTracker mJerkTracker{true};
- std::optional<MotionPredictorMetricsManager> mMetricsManager;
+ std::unique_ptr<JerkTracker> mJerkTracker;
+
+ std::unique_ptr<MotionPredictorMetricsManager> mMetricsManager;
const ReportAtomFunction mReportAtomFunction;
+
+ // Initialize prediction model and associated objects.
+ // Called during lazy initialization.
+ // TODO: b/210158587 Consider removing lazy initialization.
+ void initializeObjects();
};
} // namespace android
diff --git a/include/input/TfLiteMotionPredictor.h b/include/input/TfLiteMotionPredictor.h
index 08a4330..49e909e 100644
--- a/include/input/TfLiteMotionPredictor.h
+++ b/include/input/TfLiteMotionPredictor.h
@@ -112,7 +112,7 @@
float highJerk = 0;
// Coefficient for the first-order IIR filter for jerk calculation.
- float jerkForgetFactor = 1;
+ float jerkAlpha = 1;
};
// Creates a model from an encoded Flatbuffer model.
diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp
index 9c70535..c61d394 100644
--- a/libs/input/MotionPredictor.cpp
+++ b/libs/input/MotionPredictor.cpp
@@ -72,7 +72,8 @@
// --- JerkTracker ---
-JerkTracker::JerkTracker(bool normalizedDt) : mNormalizedDt(normalizedDt) {}
+JerkTracker::JerkTracker(bool normalizedDt, float alpha)
+ : mNormalizedDt(normalizedDt), mAlpha(alpha) {}
void JerkTracker::pushSample(int64_t timestamp, float xPos, float yPos) {
// If we previously had full samples, we have a previous jerk calculation
@@ -122,7 +123,7 @@
float newJerkMagnitude = std::hypot(newXDerivatives[3], newYDerivatives[3]);
ALOGD_IF(isDebug(), "raw jerk: %f", newJerkMagnitude);
if (applySmoothing) {
- mJerkMagnitude = mJerkMagnitude + (mForgetFactor * (newJerkMagnitude - mJerkMagnitude));
+ mJerkMagnitude = mJerkMagnitude + (mAlpha * (newJerkMagnitude - mJerkMagnitude));
} else {
mJerkMagnitude = newJerkMagnitude;
}
@@ -143,14 +144,6 @@
return std::nullopt;
}
-void JerkTracker::setForgetFactor(float forgetFactor) {
- mForgetFactor = forgetFactor;
-}
-
-float JerkTracker::getForgetFactor() const {
- return mForgetFactor;
-}
-
// --- MotionPredictor ---
MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos,
@@ -160,6 +153,24 @@
mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)),
mReportAtomFunction(reportAtomFunction) {}
+void MotionPredictor::initializeObjects() {
+ mModel = TfLiteMotionPredictorModel::create();
+ LOG_ALWAYS_FATAL_IF(!mModel);
+
+ // mJerkTracker assumes normalized dt = 1 between recorded samples because
+ // the underlying mModel input also assumes fixed-interval samples.
+ // Normalized dt as 1 is also used to correspond with the similar Jank
+ // implementation from the JetPack MotionPredictor implementation.
+ mJerkTracker = std::make_unique<JerkTracker>(/*normalizedDt=*/true, mModel->config().jerkAlpha);
+
+ mBuffers = std::make_unique<TfLiteMotionPredictorBuffers>(mModel->inputLength());
+
+ mMetricsManager =
+ std::make_unique<MotionPredictorMetricsManager>(mModel->config().predictionInterval,
+ mModel->outputLength(),
+ mReportAtomFunction);
+}
+
android::base::Result<void> MotionPredictor::record(const MotionEvent& event) {
if (mLastEvent && mLastEvent->getDeviceId() != event.getDeviceId()) {
// We still have an active gesture for another device. The provided MotionEvent is not
@@ -176,29 +187,18 @@
return {};
}
- // Initialise the model now that it's likely to be used.
if (!mModel) {
- mModel = TfLiteMotionPredictorModel::create();
- LOG_ALWAYS_FATAL_IF(!mModel);
- mJerkTracker.setForgetFactor(mModel->config().jerkForgetFactor);
- }
-
- if (!mBuffers) {
- mBuffers = std::make_unique<TfLiteMotionPredictorBuffers>(mModel->inputLength());
+ initializeObjects();
}
// Pass input event to the MetricsManager.
- if (!mMetricsManager) {
- mMetricsManager.emplace(mModel->config().predictionInterval, mModel->outputLength(),
- mReportAtomFunction);
- }
mMetricsManager->onRecord(event);
const int32_t action = event.getActionMasked();
if (action == AMOTION_EVENT_ACTION_UP || action == AMOTION_EVENT_ACTION_CANCEL) {
ALOGD_IF(isDebug(), "End of event stream");
mBuffers->reset();
- mJerkTracker.reset();
+ mJerkTracker->reset();
mLastEvent.reset();
return {};
} else if (action != AMOTION_EVENT_ACTION_DOWN && action != AMOTION_EVENT_ACTION_MOVE) {
@@ -233,9 +233,9 @@
0, i),
.orientation = event.getHistoricalOrientation(0, i),
});
- mJerkTracker.pushSample(event.getHistoricalEventTime(i),
- coords->getAxisValue(AMOTION_EVENT_AXIS_X),
- coords->getAxisValue(AMOTION_EVENT_AXIS_Y));
+ mJerkTracker->pushSample(event.getHistoricalEventTime(i),
+ coords->getAxisValue(AMOTION_EVENT_AXIS_X),
+ coords->getAxisValue(AMOTION_EVENT_AXIS_Y));
}
if (!mLastEvent) {
@@ -283,7 +283,7 @@
int64_t predictionTime = mBuffers->lastTimestamp();
const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos;
- const float jerkMagnitude = mJerkTracker.jerkMagnitude().value_or(0);
+ const float jerkMagnitude = mJerkTracker->jerkMagnitude().value_or(0);
const float fractionKept =
1 - normalizeRange(jerkMagnitude, mModel->config().lowJerk, mModel->config().highJerk);
// float to ensure proper division below.
@@ -379,12 +379,4 @@
return true;
}
-const TfLiteMotionPredictorModel::Config& MotionPredictor::getModelConfig() {
- if (!mModel) {
- mModel = TfLiteMotionPredictorModel::create();
- LOG_ALWAYS_FATAL_IF(!mModel);
- }
- return mModel->config();
-}
-
} // namespace android
diff --git a/libs/input/TfLiteMotionPredictor.cpp b/libs/input/TfLiteMotionPredictor.cpp
index b401c98..5250a9d 100644
--- a/libs/input/TfLiteMotionPredictor.cpp
+++ b/libs/input/TfLiteMotionPredictor.cpp
@@ -283,7 +283,7 @@
.distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"),
.lowJerk = parseXMLFloat(*configRoot, "low-jerk"),
.highJerk = parseXMLFloat(*configRoot, "high-jerk"),
- .jerkForgetFactor = parseXMLFloat(*configRoot, "jerk-forget-factor"),
+ .jerkAlpha = parseXMLFloat(*configRoot, "jerk-alpha"),
};
return std::unique_ptr<TfLiteMotionPredictorModel>(
diff --git a/libs/input/tests/MotionPredictor_test.cpp b/libs/input/tests/MotionPredictor_test.cpp
index 5bd5794..106e686 100644
--- a/libs/input/tests/MotionPredictor_test.cpp
+++ b/libs/input/tests/MotionPredictor_test.cpp
@@ -70,7 +70,7 @@
}
TEST(JerkTrackerTest, JerkReadiness) {
- JerkTracker jerkTracker(true);
+ JerkTracker jerkTracker(/*normalizedDt=*/true, /*alpha=*/1);
EXPECT_FALSE(jerkTracker.jerkMagnitude());
jerkTracker.pushSample(/*timestamp=*/0, 20, 50);
EXPECT_FALSE(jerkTracker.jerkMagnitude());
@@ -87,8 +87,8 @@
}
TEST(JerkTrackerTest, JerkCalculationNormalizedDtTrue) {
- JerkTracker jerkTracker(true);
- jerkTracker.setForgetFactor(.5);
+ const float alpha = .5;
+ JerkTracker jerkTracker(/*normalizedDt=*/true, alpha);
jerkTracker.pushSample(/*timestamp=*/0, 20, 50);
jerkTracker.pushSample(/*timestamp=*/1, 25, 53);
jerkTracker.pushSample(/*timestamp=*/2, 30, 60);
@@ -119,14 +119,13 @@
* y'': 3 -> -15
* y''': -18
*/
- const float newJerk = (1 - jerkTracker.getForgetFactor()) * std::hypot(10, -1) +
- jerkTracker.getForgetFactor() * std::hypot(-50, -18);
+ const float newJerk = (1 - alpha) * std::hypot(10, -1) + alpha * std::hypot(-50, -18);
EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), newJerk);
}
TEST(JerkTrackerTest, JerkCalculationNormalizedDtFalse) {
- JerkTracker jerkTracker(false);
- jerkTracker.setForgetFactor(.5);
+ const float alpha = .5;
+ JerkTracker jerkTracker(/*normalizedDt=*/false, alpha);
jerkTracker.pushSample(/*timestamp=*/0, 20, 50);
jerkTracker.pushSample(/*timestamp=*/10, 25, 53);
jerkTracker.pushSample(/*timestamp=*/20, 30, 60);
@@ -157,13 +156,12 @@
* y'': .03 -> -.125 (delta above, divide by 10)
* y''': -.0155 (delta above, divide by 10)
*/
- const float newJerk = (1 - jerkTracker.getForgetFactor()) * std::hypot(.01, -.001) +
- jerkTracker.getForgetFactor() * std::hypot(-.0375, -.0155);
+ const float newJerk = (1 - alpha) * std::hypot(.01, -.001) + alpha * std::hypot(-.0375, -.0155);
EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), newJerk);
}
TEST(JerkTrackerTest, JerkCalculationAfterReset) {
- JerkTracker jerkTracker(true);
+ JerkTracker jerkTracker(/*normalizedDt=*/true, /*alpha=*/1);
jerkTracker.pushSample(/*timestamp=*/0, 20, 50);
jerkTracker.pushSample(/*timestamp=*/1, 25, 53);
jerkTracker.pushSample(/*timestamp=*/2, 30, 60);
@@ -297,8 +295,11 @@
MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
[]() { return true /*enable prediction*/; });
+ // Create another instance of TfLiteMotionPredictorModel to read config details.
+ std::unique_ptr<TfLiteMotionPredictorModel> testTfLiteModel =
+ TfLiteMotionPredictorModel::create();
const float mediumJerk =
- (predictor.getModelConfig().lowJerk + predictor.getModelConfig().highJerk) / 2;
+ (testTfLiteModel->config().lowJerk + testTfLiteModel->config().highJerk) / 2;
const float a = 3; // initial acceleration
const float b = 4; // initial velocity
const float c = 5; // initial position