Add TFLite model for motion prediction.
This model generates probabilistic motion predictions based on a
sequence of relative input movements. The input movements are converted
into polar coordinates (distance and angle) based on an axis that
follows the current path. This ensures that the orientation of the
device and of the inputs does not affect the predictions. The orientation
of the input device is also transformed to be relative to the path axis.
The test cases verifying model efficacy are consolidated into CTS.
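As a rough illustration (not part of this change; the names below are
hypothetical), each new point reduces to a distance and an angle measured
against the direction of the last path segment:

    // Hypothetical sketch of the path-relative polar conversion described
    // above. Polar and toPolar() are illustrative names only.
    #include <cmath>

    struct Polar {
        float r;    // distance travelled from the last point
        float phi;  // angle relative to the current path axis
    };

    Polar toPolar(float axisFromX, float axisFromY, float axisToX, float axisToY,
                  float nextX, float nextY) {
        // Direction of the most recent path segment.
        const float axisAngle = std::atan2(axisToY - axisFromY, axisToX - axisFromX);
        const float dx = nextX - axisToX;
        const float dy = nextY - axisToY;
        // The distance is rotation-invariant, and the angle is measured against
        // the path axis, so rotating the device or the inputs leaves both
        // features unchanged.
        return {.r = std::hypot(dx, dy), .phi = std::atan2(dy, dx) - axisAngle};
    }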
Bug: 167946763
Test: atest libinput_tests
PiperOrigin-RevId: 492068340
Change-Id: Icd8d90bd5a7ce79c699bfdb6367a4cbd8130441a
diff --git a/include/input/MotionPredictor.h b/include/input/MotionPredictor.h
index 045e61b..3fae4e6 100644
--- a/include/input/MotionPredictor.h
+++ b/include/input/MotionPredictor.h
@@ -16,9 +16,15 @@
#pragma once
+#include <cstdint>
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+
#include <android-base/thread_annotations.h>
#include <android/sysprop/InputProperties.sysprop.h>
#include <input/Input.h>
+#include <input/TfLiteMotionPredictor.h>
namespace android {
@@ -28,48 +34,51 @@
/**
* Given a set of MotionEvents for the current gesture, predict the motion. The returned MotionEvent
- * contains a set of samples in the future, up to "presentation time + offset".
+ * contains a set of samples in the future.
*
* The typical usage is like this:
*
* MotionPredictor predictor(offset = MY_OFFSET);
- * predictor.setExpectedPresentationTimeNanos(NEXT_PRESENT_TIME);
* predictor.record(DOWN_MOTION_EVENT);
* predictor.record(MOVE_MOTION_EVENT);
- * prediction = predictor.predict();
+ * prediction = predictor.predict(futureTime);
*
- * The presentation time should be set some time before calling .predict(). It could be set before
- * or after the recorded motion events. Must be done on every frame.
- *
- * The resulting motion event will have eventTime <= (NEXT_PRESENT_TIME + MY_OFFSET). It might
- * contain historical data, which are additional samples from the latest recorded MotionEvent's
- * eventTime to the NEXT_PRESENT_TIME + MY_OFFSET.
+ * The resulting motion event will have eventTime <= (futureTime + MY_OFFSET). It might contain
+ * historical data, which are additional samples from the latest recorded MotionEvent's eventTime
+ * to the futureTime + MY_OFFSET.
*
* The offset is used to provide additional flexibility to the caller, in case the default present
* time (typically provided by the choreographer) does not account for some delays, or to simply
- * reduce the aggressiveness of the prediction. Offset can be both positive and negative.
+ * reduce the aggressiveness of the prediction. Offset can be positive or negative.
*/
class MotionPredictor {
public:
/**
* Parameters:
* predictionTimestampOffsetNanos: additional, constant shift to apply to the target
- * presentation time. The prediction will target the time t=(presentationTime +
+ * prediction time. The prediction will target the time t=(predictionTime +
* predictionTimestampOffsetNanos).
*
+ * modelPath: filesystem path to a TfLiteMotionPredictorModel flatbuffer, or nullptr to use the
+ * default model path.
+ *
* checkEnableMotionPrediction: the function to check whether the prediction should run. Used to
* provide an additional way of turning prediction on and off. Can be toggled at runtime.
*/
- MotionPredictor(nsecs_t predictionTimestampOffsetNanos,
+ MotionPredictor(nsecs_t predictionTimestampOffsetNanos, const char* modelPath = nullptr,
std::function<bool()> checkEnableMotionPrediction = isMotionPredictionEnabled);
void record(const MotionEvent& event);
std::vector<std::unique_ptr<MotionEvent>> predict(nsecs_t timestamp);
bool isPredictionAvailable(int32_t deviceId, int32_t source);
private:
- std::vector<MotionEvent> mEvents;
const nsecs_t mPredictionTimestampOffsetNanos;
const std::function<bool()> mCheckMotionPredictionEnabled;
+
+ std::unique_ptr<TfLiteMotionPredictorModel> mModel;
+ // Buffers/events for each device seen by record().
+ std::unordered_map</*deviceId*/ int32_t, TfLiteMotionPredictorBuffers> mDeviceBuffers;
+ std::unordered_map</*deviceId*/ int32_t, MotionEvent> mLastEvents;
};
} // namespace android
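A minimal caller-side sketch of the updated API, assuming a valid stream of
MotionEvents; downEvent, moveEvent, and futureTime are placeholders:

    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0,
                              /*modelPath=*/nullptr);  // nullptr selects the default model
    predictor.record(downEvent);
    predictor.record(moveEvent);
    // Request predicted samples up to futureTime (+ offset); events are
    // produced per device seen by record().
    std::vector<std::unique_ptr<MotionEvent>> predictions = predictor.predict(futureTime);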
diff --git a/include/input/TfLiteMotionPredictor.h b/include/input/TfLiteMotionPredictor.h
new file mode 100644
index 0000000..ff0f51c
--- /dev/null
+++ b/include/input/TfLiteMotionPredictor.h
@@ -0,0 +1,149 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <span>
+#include <string>
+#include <vector>
+
+#include <tensorflow/lite/core/api/error_reporter.h>
+#include <tensorflow/lite/interpreter.h>
+#include <tensorflow/lite/model.h>
+#include <tensorflow/lite/signature_runner.h>
+
+namespace android {
+
+struct TfLiteMotionPredictorSample {
+ // The untransformed AMOTION_EVENT_AXIS_X and AMOTION_EVENT_AXIS_Y of the sample.
+ struct Point {
+ float x;
+ float y;
+ } position;
+ // The AMOTION_EVENT_AXIS_PRESSURE, _TILT, and _ORIENTATION.
+ float pressure;
+ float tilt;
+ float orientation;
+};
+
+inline TfLiteMotionPredictorSample::Point operator-(const TfLiteMotionPredictorSample::Point& lhs,
+ const TfLiteMotionPredictorSample::Point& rhs) {
+ return {.x = lhs.x - rhs.x, .y = lhs.y - rhs.y};
+}
+
+class TfLiteMotionPredictorModel;
+
+// Buffer storage for a TfLiteMotionPredictorModel.
+class TfLiteMotionPredictorBuffers {
+public:
+ // Creates buffer storage for a model with the given input length.
+ explicit TfLiteMotionPredictorBuffers(size_t inputLength);
+
+ // Adds a motion sample to the buffers.
+ void pushSample(int64_t timestamp, TfLiteMotionPredictorSample sample);
+
+ // Returns true if the buffers are complete enough to generate a prediction.
+ bool isReady() const {
+ // Predictions can't be applied unless there are at least two points to determine
+ // the direction to apply them in.
+ return mAxisFrom && mAxisTo;
+ }
+
+ // Resets all buffers to their initial state.
+ void reset();
+
+ // Copies the buffers to those of a model for prediction.
+ void copyTo(TfLiteMotionPredictorModel& model) const;
+
+ // Returns the current axis of the buffer's samples. Only valid if isReady().
+ TfLiteMotionPredictorSample axisFrom() const { return *mAxisFrom; }
+ TfLiteMotionPredictorSample axisTo() const { return *mAxisTo; }
+
+ // Returns the timestamp of the last sample.
+ int64_t lastTimestamp() const { return mTimestamp; }
+
+private:
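+ // Timestamp of the most recently pushed sample.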
+ int64_t mTimestamp = 0;
+
+ std::vector<float> mInputR;
+ std::vector<float> mInputPhi;
+ std::vector<float> mInputPressure;
+ std::vector<float> mInputTilt;
+ std::vector<float> mInputOrientation;
+
+ // The samples defining the current polar axis.
+ std::optional<TfLiteMotionPredictorSample> mAxisFrom;
+ std::optional<TfLiteMotionPredictorSample> mAxisTo;
+};
+
+// A TFLite model for generating motion predictions.
+class TfLiteMotionPredictorModel {
+public:
+ // Creates a model from the encoded flatbuffer model file at the given path.
+ static std::unique_ptr<TfLiteMotionPredictorModel> create(const char* modelPath);
+
+ // Returns the length of the model's input buffers.
+ size_t inputLength() const;
+
+ // Executes the model.
+ // Returns true if the model successfully executed and the output tensors can be read.
+ bool invoke();
+
+ // Returns mutable views of the input tensors, each holding inputLength() elements.
+ std::span<float> inputR();
+ std::span<float> inputPhi();
+ std::span<float> inputPressure();
+ std::span<float> inputOrientation();
+ std::span<float> inputTilt();
+
+ // Returns immutable views of the output tensors, which all have the same length. Only valid
+ // after a successful call to invoke().
+ std::span<const float> outputR() const;
+ std::span<const float> outputPhi() const;
+ std::span<const float> outputPressure() const;
+
+private:
+ explicit TfLiteMotionPredictorModel(std::string model);
+
+ void allocateTensors();
+ void attachInputTensors();
+ void attachOutputTensors();
+
+ TfLiteTensor* mInputR = nullptr;
+ TfLiteTensor* mInputPhi = nullptr;
+ TfLiteTensor* mInputPressure = nullptr;
+ TfLiteTensor* mInputTilt = nullptr;
+ TfLiteTensor* mInputOrientation = nullptr;
+
+ const TfLiteTensor* mOutputR = nullptr;
+ const TfLiteTensor* mOutputPhi = nullptr;
+ const TfLiteTensor* mOutputPressure = nullptr;
+
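+ // Serialized model content; it must outlive mModel, which reads directly from it.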
+ std::string mFlatBuffer;
+ std::unique_ptr<tflite::ErrorReporter> mErrorReporter;
+ std::unique_ptr<tflite::FlatBufferModel> mModel;
+ std::unique_ptr<tflite::Interpreter> mInterpreter;
+ tflite::SignatureRunner* mRunner = nullptr;
+};
+
+} // namespace android
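A rough end-to-end sketch of how the buffers and model are intended to fit
together; the model path and sample values are placeholders, and error
handling is elided:

    // Load the model and size the input buffers to match it.
    std::unique_ptr<TfLiteMotionPredictorModel> model =
            TfLiteMotionPredictorModel::create("/path/to/model.tflite");
    TfLiteMotionPredictorBuffers buffers(model->inputLength());

    // Push at least two samples so that a path axis can be established.
    buffers.pushSample(/*timestamp=*/0, {.position = {.x = 100.f, .y = 200.f},
                                         .pressure = 1.f, .tilt = 0.f, .orientation = 0.f});
    buffers.pushSample(/*timestamp=*/8000000, {.position = {.x = 105.f, .y = 201.f},
                                               .pressure = 1.f, .tilt = 0.f, .orientation = 0.f});

    if (buffers.isReady()) {
        buffers.copyTo(*model);  // fill the model's input tensors
        if (model->invoke()) {
            // Predicted polar offsets, relative to the axis from axisFrom()/axisTo().
            std::span<const float> r = model->outputR();
            std::span<const float> phi = model->outputPhi();
        }
    }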