Update motion prediction model.
Input events with no movement (r = 0) are now included in the buffer
so that the model can accurately determine when the input device has
become stationary. A distance noise floor is also added to prevent
spurious predictions while the device is stationary.
Benchmark results:
Old:
timeRecordAndPredict_mean (ns): 17990
timeRecordAndPredict_median (ns): 18024
timeRecordAndPredict_min (ns): 17606
timeRecordAndPredict_standardDeviation: 345
New:
timeRecordAndPredict_mean (ns): 38394
timeRecordAndPredict_median (ns): 38476
timeRecordAndPredict_min (ns): 38083
timeRecordAndPredict_standardDeviation: 187
Bug: 288354672
PiperOrigin-RevId: 549064247
Test: predictions are visible in the motionprediction test app
Test: atest CtsInputTestCases
Test: atest MotionPredictorBenchmark MotionPredictorTest
Test: atest --host libinput_tests
Change-Id: I6c3917591323d7117c4ee2e91abf6c6004178f19
diff --git a/libs/input/TfLiteMotionPredictor.cpp b/libs/input/TfLiteMotionPredictor.cpp
index 9f4aaa8..5984b4d3 100644
--- a/libs/input/TfLiteMotionPredictor.cpp
+++ b/libs/input/TfLiteMotionPredictor.cpp
@@ -100,6 +100,16 @@
return value;
}
+// Parses the text of the child element named `elementName` under `configRoot` as a float.
+// Aborts via LOG_ALWAYS_FATAL_IF when the element is missing or its text cannot be parsed.
+float parseXMLFloat(const tinyxml2::XMLElement& configRoot, const char* elementName) {
+    const tinyxml2::XMLElement* element = configRoot.FirstChildElement(elementName);
+    LOG_ALWAYS_FATAL_IF(!element, "Could not find '%s' element", elementName);
+
+    float value = 0;
+    const bool parsed = element->QueryFloatText(&value) == tinyxml2::XML_SUCCESS;
+    // GetText() returns nullptr when the element has no text child (e.g. an empty element),
+    // which is one of the ways parsing fails; never hand a null pointer to the %s conversion.
+    const char* text = element->GetText();
+    LOG_ALWAYS_FATAL_IF(!parsed, "Failed to parse %s: %s", elementName,
+                        text ? text : "<empty>");
+    return value;
+}
+
// A TFLite ErrorReporter that logs to logcat.
class LoggingErrorReporter : public tflite::ErrorReporter {
public:
@@ -152,6 +162,7 @@
::tflite::ops::builtin::Register_CONCATENATION());
resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
::tflite::ops::builtin::Register_FULLY_CONNECTED());
+ resolver->AddBuiltin(::tflite::BuiltinOperator_GELU, ::tflite::ops::builtin::Register_GELU());
return resolver;
}
@@ -208,13 +219,7 @@
float phi = 0;
float orientation = 0;
- // Ignore the sample if there is no movement. These samples can occur when there's change to a
- // property other than the coordinates and pollute the input to the model.
- if (r == 0) {
- return;
- }
-
- if (!mAxisFrom) { // Second point.
+ if (!mAxisFrom && r > 0) { // Second point.
// We can only determine the distance from the first point, and not any
// angle. However, if the second point forms an axis, the orientation can
// be transformed relative to that axis.
@@ -235,8 +240,10 @@
}
// Update the axis for the next point.
- mAxisFrom = mAxisTo;
- mAxisTo = sample;
+ if (r > 0) {
+ mAxisFrom = mAxisTo;
+ mAxisTo = sample;
+ }
// Push the current sample onto the end of the input buffers.
mInputR.pushBack(r);
@@ -272,15 +279,18 @@
// Parse configuration file.
const tinyxml2::XMLElement* configRoot = configDocument.FirstChildElement("motion-predictor");
LOG_ALWAYS_FATAL_IF(!configRoot);
- const nsecs_t predictionInterval = parseXMLInt64(*configRoot, "prediction-interval");
+ Config config{
+ .predictionInterval = parseXMLInt64(*configRoot, "prediction-interval"),
+ .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"),
+ };
return std::unique_ptr<TfLiteMotionPredictorModel>(
- new TfLiteMotionPredictorModel(std::move(modelBuffer), predictionInterval));
+ new TfLiteMotionPredictorModel(std::move(modelBuffer), std::move(config)));
}
TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(
- std::unique_ptr<android::base::MappedFile> model, nsecs_t predictionInterval)
- : mFlatBuffer(std::move(model)), mPredictionInterval(predictionInterval) {
+ std::unique_ptr<android::base::MappedFile> model, Config config)
+ : mFlatBuffer(std::move(model)), mConfig(std::move(config)) {
CHECK(mFlatBuffer);
mErrorReporter = std::make_unique<LoggingErrorReporter>();
mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(),