blob: 728a8e1e39c4be4c3d9026bda5dfabccc4f7fe48 [file] [log] [blame]
Philip Quinn8f953ab2022-12-06 15:37:07 -08001/*
2 * Copyright (C) 2023 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#pragma once
18
19#include <array>
20#include <cstddef>
21#include <cstdint>
22#include <memory>
23#include <optional>
24#include <span>
Philip Quinn9b8926e2023-01-31 14:50:02 -080025
Philip Quinncb3229a2023-02-08 22:50:59 -080026#include <android-base/mapped_file.h>
Philip Quinn9b8926e2023-01-31 14:50:02 -080027#include <input/RingBuffer.h>
Philip Quinnf84fa492023-06-26 14:15:15 -070028#include <utils/Timers.h>
Philip Quinn8f953ab2022-12-06 15:37:07 -080029
30#include <tensorflow/lite/core/api/error_reporter.h>
31#include <tensorflow/lite/interpreter.h>
32#include <tensorflow/lite/model.h>
33#include <tensorflow/lite/signature_runner.h>
34
35namespace android {
36
// A single motion sample, as pushed into TfLiteMotionPredictorBuffers::pushSample().
//
// Every field defaults to zero so a default-constructed sample never carries indeterminate
// values. The type remains an aggregate, so brace and designated initialization work as before.
struct TfLiteMotionPredictorSample {
    // The untransformed AMOTION_EVENT_AXIS_X and AMOTION_EVENT_AXIS_Y of the sample.
    struct Point {
        float x = 0.0f;
        float y = 0.0f;
    } position;
    // The AMOTION_EVENT_AXIS_PRESSURE, _TILT, and _ORIENTATION.
    float pressure = 0.0f;
    float tilt = 0.0f;
    float orientation = 0.0f;
};
48
49inline TfLiteMotionPredictorSample::Point operator-(const TfLiteMotionPredictorSample::Point& lhs,
50 const TfLiteMotionPredictorSample::Point& rhs) {
51 return {.x = lhs.x - rhs.x, .y = lhs.y - rhs.y};
52}
53
54class TfLiteMotionPredictorModel;
55
56// Buffer storage for a TfLiteMotionPredictorModel.
57class TfLiteMotionPredictorBuffers {
58public:
59 // Creates buffer storage for a model with the given input length.
60 TfLiteMotionPredictorBuffers(size_t inputLength);
61
62 // Adds a motion sample to the buffers.
63 void pushSample(int64_t timestamp, TfLiteMotionPredictorSample sample);
64
65 // Returns true if the buffers are complete enough to generate a prediction.
66 bool isReady() const {
67 // Predictions can't be applied unless there are at least two points to determine
68 // the direction to apply them in.
69 return mAxisFrom && mAxisTo;
70 }
71
72 // Resets all buffers to their initial state.
73 void reset();
74
75 // Copies the buffers to those of a model for prediction.
76 void copyTo(TfLiteMotionPredictorModel& model) const;
77
78 // Returns the current axis of the buffer's samples. Only valid if isReady().
79 TfLiteMotionPredictorSample axisFrom() const { return *mAxisFrom; }
80 TfLiteMotionPredictorSample axisTo() const { return *mAxisTo; }
81
82 // Returns the timestamp of the last sample.
83 int64_t lastTimestamp() const { return mTimestamp; }
84
85private:
86 int64_t mTimestamp = 0;
87
Philip Quinn9b8926e2023-01-31 14:50:02 -080088 RingBuffer<float> mInputR;
89 RingBuffer<float> mInputPhi;
90 RingBuffer<float> mInputPressure;
91 RingBuffer<float> mInputTilt;
92 RingBuffer<float> mInputOrientation;
Philip Quinn8f953ab2022-12-06 15:37:07 -080093
94 // The samples defining the current polar axis.
95 std::optional<TfLiteMotionPredictorSample> mAxisFrom;
96 std::optional<TfLiteMotionPredictorSample> mAxisTo;
97};
98
99// A TFLite model for generating motion predictions.
100class TfLiteMotionPredictorModel {
101public:
Philip Quinn107ce702023-07-14 13:07:13 -0700102 struct Config {
103 // The time between predictions.
104 nsecs_t predictionInterval = 0;
105 // The noise floor for predictions.
106 // Distances (r) less than this should be discarded as noise.
107 float distanceNoiseFloor = 0;
Derek Wuaaa47312024-03-26 15:53:44 -0700108
109 // Low and high jerk thresholds (with normalized dt = 1) for predictions.
110 // High jerk means more predictions will be pruned, vice versa for low.
111 float lowJerk = 0;
112 float highJerk = 0;
Philip Quinn107ce702023-07-14 13:07:13 -0700113 };
114
Philip Quinn8f953ab2022-12-06 15:37:07 -0800115 // Creates a model from an encoded Flatbuffer model.
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -0800116 static std::unique_ptr<TfLiteMotionPredictorModel> create();
Philip Quinn8f953ab2022-12-06 15:37:07 -0800117
Philip Quinnda6a4482023-02-07 10:09:57 -0800118 ~TfLiteMotionPredictorModel();
119
Philip Quinn8f953ab2022-12-06 15:37:07 -0800120 // Returns the length of the model's input buffers.
121 size_t inputLength() const;
122
Cody Heinerdbd14eb2023-03-30 18:41:45 -0700123 // Returns the length of the model's output buffers.
124 size_t outputLength() const;
125
Philip Quinn107ce702023-07-14 13:07:13 -0700126 const Config& config() const { return mConfig; }
Philip Quinnf84fa492023-06-26 14:15:15 -0700127
Philip Quinn8f953ab2022-12-06 15:37:07 -0800128 // Executes the model.
129 // Returns true if the model successfully executed and the output tensors can be read.
130 bool invoke();
131
132 // Returns mutable buffers to the input tensors of inputLength() elements.
133 std::span<float> inputR();
134 std::span<float> inputPhi();
135 std::span<float> inputPressure();
136 std::span<float> inputOrientation();
137 std::span<float> inputTilt();
138
139 // Returns immutable buffers to the output tensors of identical length. Only valid after a
140 // successful call to invoke().
141 std::span<const float> outputR() const;
142 std::span<const float> outputPhi() const;
143 std::span<const float> outputPressure() const;
144
145private:
Philip Quinnf84fa492023-06-26 14:15:15 -0700146 explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model,
Philip Quinn107ce702023-07-14 13:07:13 -0700147 Config config);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800148
149 void allocateTensors();
150 void attachInputTensors();
151 void attachOutputTensors();
152
153 TfLiteTensor* mInputR = nullptr;
154 TfLiteTensor* mInputPhi = nullptr;
155 TfLiteTensor* mInputPressure = nullptr;
156 TfLiteTensor* mInputTilt = nullptr;
157 TfLiteTensor* mInputOrientation = nullptr;
158
159 const TfLiteTensor* mOutputR = nullptr;
160 const TfLiteTensor* mOutputPhi = nullptr;
161 const TfLiteTensor* mOutputPressure = nullptr;
162
Philip Quinncb3229a2023-02-08 22:50:59 -0800163 std::unique_ptr<android::base::MappedFile> mFlatBuffer;
Philip Quinn8f953ab2022-12-06 15:37:07 -0800164 std::unique_ptr<tflite::ErrorReporter> mErrorReporter;
165 std::unique_ptr<tflite::FlatBufferModel> mModel;
166 std::unique_ptr<tflite::Interpreter> mInterpreter;
167 tflite::SignatureRunner* mRunner = nullptr;
Philip Quinnf84fa492023-06-26 14:15:15 -0700168
Philip Quinn107ce702023-07-14 13:07:13 -0700169 const Config mConfig = {};
Philip Quinn8f953ab2022-12-06 15:37:07 -0800170};
171
172} // namespace android