Update motion prediction model.
Input events with no movement (r = 0) are now included in the buffer
so that the model can accurately determine when the input device has
become stationary, and a noise floor is added to suppress spurious
predictions while the device is stationary.
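
A conceptual sketch of the two changes (hypothetical names; the real logic
is in libs/input/TfLiteMotionPredictor.cpp and libs/input/MotionPredictor.cpp
below):

    // Zero-movement samples are now buffered instead of dropped, so the
    // model can observe the device going stationary.
    void recordSample(float r, float phi) {
        // Previously: if (r == 0) return;
        inputR.pushBack(r);  // now buffered even when r == 0
        inputPhi.pushBack(phi);
    }

    // Predicted distances below the configured noise floor are treated
    // as noise and suppressed.
    bool isNoise(float predictedR, float distanceNoiseFloor) {
        return predictedR < distanceNoiseFloor;
    }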
Benchmark results:
Old:
timeRecordAndPredict_mean (ns): 17990
timeRecordAndPredict_median (ns): 18024
timeRecordAndPredict_min (ns): 17606
timeRecordAndPredict_standardDeviation: 345
New:
timeRecordAndPredict_mean (ns): 38394
timeRecordAndPredict_median (ns): 38476
timeRecordAndPredict_min (ns): 38083
timeRecordAndPredict_standardDeviation: 187
Bug: 288354672
PiperOrigin-RevId: 549064247
Test: predictions are visible in the motionprediction test app
Test: atest CtsInputTestCases
Test: atest MotionPredictorBenchmark MotionPredictorTest
Test: atest --host libinput_tests
Change-Id: I6c3917591323d7117c4ee2e91abf6c6004178f19
diff --git a/data/etc/input/motion_predictor_config.xml b/data/etc/input/motion_predictor_config.xml
index 03dfd63..39772ae 100644
--- a/data/etc/input/motion_predictor_config.xml
+++ b/data/etc/input/motion_predictor_config.xml
@@ -16,5 +16,20 @@
<motion-predictor>
<!-- The time interval (ns) between the model's predictions. -->
<prediction-interval>4166666</prediction-interval> <!-- 4.167 ms = ~240 Hz -->
+ <!-- The noise floor (px) for predicted distances.
+
+ As the model is trained stochastically, there is some expected minimum
+ variability in its output. This can be a UX issue when the input device
+ is moving slowly and the variability is large relative to the magnitude
+ of the motion. In these cases, it is better to inhibit the prediction,
+ rather than show noisy predictions (and there is little benefit to
+ prediction anyway).
+
+ The value for this parameter should be at least close to the maximum
+ predicted distance when the input device is held stationary (i.e. the
+ expected minimum variability), and perhaps a little larger to address
+ the UX issue described above.
+ -->
+ <distance-noise-floor>0.2</distance-noise-floor>
</motion-predictor>
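
A hypothetical offline helper for choosing this value, following the
guidance in the config comment above (take the maximum predicted distance
observed while the device is held stationary and pad it a little; the
margin default is a made-up assumption, not part of this change):

    #include <algorithm>
    #include <vector>

    // Suggest a <distance-noise-floor> value (px) from predicted distances
    // collected while the input device was held stationary.
    float suggestNoiseFloor(const std::vector<float>& stationaryPredictedR,
                            float margin = 1.25f) {
        if (stationaryPredictedR.empty()) return 0.f;
        // "perhaps a little larger" per the config comment above.
        return margin * *std::max_element(stationaryPredictedR.begin(),
                                          stationaryPredictedR.end());
    }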
diff --git a/data/etc/input/motion_predictor_model.tflite b/data/etc/input/motion_predictor_model.tflite
index 10b3c8b..45fc162 100644
--- a/data/etc/input/motion_predictor_model.tflite
+++ b/data/etc/input/motion_predictor_model.tflite
Binary files differ
diff --git a/include/input/TfLiteMotionPredictor.h b/include/input/TfLiteMotionPredictor.h
index fbd6026..2edc138 100644
--- a/include/input/TfLiteMotionPredictor.h
+++ b/include/input/TfLiteMotionPredictor.h
@@ -99,6 +99,14 @@
// A TFLite model for generating motion predictions.
class TfLiteMotionPredictorModel {
public:
+ struct Config {
+ // The time between predictions.
+ nsecs_t predictionInterval = 0;
+ // The noise floor for predictions.
+ // Distances (r) less than this should be discarded as noise.
+ float distanceNoiseFloor = 0;
+ };
+
// Creates a model from an encoded Flatbuffer model.
static std::unique_ptr<TfLiteMotionPredictorModel> create();
@@ -110,8 +118,7 @@
// Returns the length of the model's output buffers.
size_t outputLength() const;
- // Returns the time interval between predictions.
- nsecs_t predictionInterval() const { return mPredictionInterval; }
+ const Config& config() const { return mConfig; }
// Executes the model.
// Returns true if the model successfully executed and the output tensors can be read.
@@ -132,7 +139,7 @@
private:
explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model,
- nsecs_t predictionInterval);
+ Config config);
void allocateTensors();
void attachInputTensors();
@@ -154,7 +161,7 @@
std::unique_ptr<tflite::Interpreter> mInterpreter;
tflite::SignatureRunner* mRunner = nullptr;
- const nsecs_t mPredictionInterval = 0;
+ const Config mConfig = {};
};
} // namespace android
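
Call sites migrate from the removed predictionInterval() accessor to the
new config() accessor; an illustrative before/after (model is assumed to
be a TfLiteMotionPredictorModel*):

    // Before this change:
    //   nsecs_t interval = model->predictionInterval();
    // After, both parameters are read through one accessor:
    nsecs_t interval = model->config().predictionInterval;
    float noiseFloor = model->config().distanceNoiseFloor;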
diff --git a/libs/input/MotionPredictor.cpp b/libs/input/MotionPredictor.cpp
index 68e6888..c2ea35c 100644
--- a/libs/input/MotionPredictor.cpp
+++ b/libs/input/MotionPredictor.cpp
@@ -138,7 +138,8 @@
// Pass input event to the MetricsManager.
if (!mMetricsManager) {
mMetricsManager =
- std::make_optional<MotionPredictorMetricsManager>(mModel->predictionInterval(),
+ std::make_optional<MotionPredictorMetricsManager>(mModel->config()
+ .predictionInterval,
mModel->outputLength());
}
mMetricsManager->onRecord(event);
@@ -184,8 +185,18 @@
const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos;
for (int i = 0; i < predictedR.size() && predictionTime <= futureTime; ++i) {
- // TODO(b/266747654): Stop predictions if confidence and/or predicted pressure are below
- // some thresholds.
+ if (predictedR[i] < mModel->config().distanceNoiseFloor) {
+ // Stop predicting when the predicted output is below the model's noise floor.
+ //
+ // We assume that all subsequent predictions in the batch are unreliable because later
+ // predictions are conditional on earlier predictions, and a state of noise is not a
+ // good basis for prediction.
+ //
+ // The UX trade-off is that this potentially sacrifices some predictions when the input
+ // device starts to speed up, but avoids producing noisy predictions as it slows down.
+ break;
+ }
+ // TODO(b/266747654): Stop predictions if confidence is < some threshold.
const TfLiteMotionPredictorSample::Point predictedPoint =
convertPrediction(axisFrom, axisTo, predictedR[i], predictedPhi[i]);
@@ -197,7 +208,7 @@
coords.setAxisValue(AMOTION_EVENT_AXIS_Y, predictedPoint.y);
coords.setAxisValue(AMOTION_EVENT_AXIS_PRESSURE, predictedPressure[i]);
- predictionTime += mModel->predictionInterval();
+ predictionTime += mModel->config().predictionInterval;
if (i == 0) {
hasPredictions = true;
prediction->initialize(InputEvent::nextId(), event.getDeviceId(), event.getSource(),
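
The break (rather than a per-sample skip) follows from predictions being
conditional on earlier predictions, as the comment above notes. An
illustrative autoregressive rollout, with a hypothetical step() standing
in for the TFLite model:

    #include <cstddef>
    #include <vector>

    struct Sample { float r, phi; };
    Sample step(const Sample& prev);  // hypothetical one-step predictor

    std::vector<Sample> rollout(Sample state, size_t horizon, float noiseFloor) {
        std::vector<Sample> out;
        for (size_t i = 0; i < horizon; ++i) {
            state = step(state);              // prediction i feeds prediction i+1
            if (state.r < noiseFloor) break;  // later steps would inherit the noise
            out.push_back(state);
        }
        return out;
    }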
diff --git a/libs/input/TfLiteMotionPredictor.cpp b/libs/input/TfLiteMotionPredictor.cpp
index 9f4aaa8..5984b4d3 100644
--- a/libs/input/TfLiteMotionPredictor.cpp
+++ b/libs/input/TfLiteMotionPredictor.cpp
@@ -100,6 +100,16 @@
return value;
}
+float parseXMLFloat(const tinyxml2::XMLElement& configRoot, const char* elementName) {
+ const tinyxml2::XMLElement* element = configRoot.FirstChildElement(elementName);
+ LOG_ALWAYS_FATAL_IF(!element, "Could not find '%s' element", elementName);
+
+ float value = 0;
+ LOG_ALWAYS_FATAL_IF(element->QueryFloatText(&value) != tinyxml2::XML_SUCCESS,
+ "Failed to parse %s: %s", elementName, element->GetText());
+ return value;
+}
+
// A TFLite ErrorReporter that logs to logcat.
class LoggingErrorReporter : public tflite::ErrorReporter {
public:
@@ -152,6 +162,7 @@
::tflite::ops::builtin::Register_CONCATENATION());
resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
::tflite::ops::builtin::Register_FULLY_CONNECTED());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_GELU, ::tflite::ops::builtin::Register_GELU());
return resolver;
}
@@ -208,13 +219,7 @@
float phi = 0;
float orientation = 0;
- // Ignore the sample if there is no movement. These samples can occur when there's change to a
- // property other than the coordinates and pollute the input to the model.
- if (r == 0) {
- return;
- }
-
- if (!mAxisFrom) { // Second point.
+ if (!mAxisFrom && r > 0) { // Second point.
// We can only determine the distance from the first point, and not any
// angle. However, if the second point forms an axis, the orientation can
// be transformed relative to that axis.
@@ -235,8 +240,10 @@
}
// Update the axis for the next point.
- mAxisFrom = mAxisTo;
- mAxisTo = sample;
+ if (r > 0) {
+ mAxisFrom = mAxisTo;
+ mAxisTo = sample;
+ }
// Push the current sample onto the end of the input buffers.
mInputR.pushBack(r);
@@ -272,15 +279,18 @@
// Parse configuration file.
const tinyxml2::XMLElement* configRoot = configDocument.FirstChildElement("motion-predictor");
LOG_ALWAYS_FATAL_IF(!configRoot);
- const nsecs_t predictionInterval = parseXMLInt64(*configRoot, "prediction-interval");
+ Config config{
+ .predictionInterval = parseXMLInt64(*configRoot, "prediction-interval"),
+ .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"),
+ };
return std::unique_ptr<TfLiteMotionPredictorModel>(
- new TfLiteMotionPredictorModel(std::move(modelBuffer), predictionInterval));
+ new TfLiteMotionPredictorModel(std::move(modelBuffer), std::move(config)));
}
TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(
- std::unique_ptr<android::base::MappedFile> model, nsecs_t predictionInterval)
- : mFlatBuffer(std::move(model)), mPredictionInterval(predictionInterval) {
+ std::unique_ptr<android::base::MappedFile> model, Config config)
+ : mFlatBuffer(std::move(model)), mConfig(std::move(config)) {
CHECK(mFlatBuffer);
mErrorReporter = std::make_unique<LoggingErrorReporter>();
mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(),
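
For reference, a standalone sketch of the tinyxml2 calls used by
parseXMLFloat above (error handling simplified; the real code aborts via
LOG_ALWAYS_FATAL_IF):

    #include <cstdio>
    #include <tinyxml2.h>

    int main() {
        const char* xml =
                "<motion-predictor>"
                "<prediction-interval>4166666</prediction-interval>"
                "<distance-noise-floor>0.2</distance-noise-floor>"
                "</motion-predictor>";
        tinyxml2::XMLDocument doc;
        if (doc.Parse(xml) != tinyxml2::XML_SUCCESS) return 1;
        const tinyxml2::XMLElement* root =
                doc.FirstChildElement("motion-predictor");
        const tinyxml2::XMLElement* floorElement =
                root ? root->FirstChildElement("distance-noise-floor") : nullptr;
        float floor = 0;
        if (floorElement &&
            floorElement->QueryFloatText(&floor) == tinyxml2::XML_SUCCESS) {
            std::printf("distance-noise-floor: %.2f px\n", floor);
        }
        return 0;
    }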
diff --git a/libs/input/tests/MotionPredictor_test.cpp b/libs/input/tests/MotionPredictor_test.cpp
index 7a62f5e..4ac7ae9 100644
--- a/libs/input/tests/MotionPredictor_test.cpp
+++ b/libs/input/tests/MotionPredictor_test.cpp
@@ -72,11 +72,20 @@
ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_TOUCHSCREEN));
}
+TEST(MotionPredictorTest, StationaryNoiseFloor) {
+ MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/1,
+ []() { return true /*enable prediction*/; });
+ predictor.record(getMotionEvent(DOWN, 0, 1, 30ms));
+ predictor.record(getMotionEvent(MOVE, 0, 1, 35ms)); // No movement.
+ std::unique_ptr<MotionEvent> predicted = predictor.predict(40 * NSEC_PER_MSEC);
+ ASSERT_EQ(nullptr, predicted);
+}
+
TEST(MotionPredictorTest, Offset) {
MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/1,
[]() { return true /*enable prediction*/; });
predictor.record(getMotionEvent(DOWN, 0, 1, 30ms));
- predictor.record(getMotionEvent(MOVE, 0, 2, 35ms));
+ predictor.record(getMotionEvent(MOVE, 0, 5, 35ms)); // Move enough to overcome the noise floor.
std::unique_ptr<MotionEvent> predicted = predictor.predict(40 * NSEC_PER_MSEC);
ASSERT_NE(nullptr, predicted);
ASSERT_GE(predicted->getEventTime(), 41);