blob: 2edc138f67b126baf7a3c1c07c3f889b06881d17 [file] [log] [blame]
Philip Quinn8f953ab2022-12-06 15:37:07 -08001/*
2 * Copyright (C) 2023 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#pragma once
18
19#include <array>
20#include <cstddef>
21#include <cstdint>
22#include <memory>
23#include <optional>
24#include <span>
Philip Quinn9b8926e2023-01-31 14:50:02 -080025
Philip Quinncb3229a2023-02-08 22:50:59 -080026#include <android-base/mapped_file.h>
Philip Quinn9b8926e2023-01-31 14:50:02 -080027#include <input/RingBuffer.h>
Philip Quinnf84fa492023-06-26 14:15:15 -070028#include <utils/Timers.h>
Philip Quinn8f953ab2022-12-06 15:37:07 -080029
30#include <tensorflow/lite/core/api/error_reporter.h>
31#include <tensorflow/lite/interpreter.h>
32#include <tensorflow/lite/model.h>
33#include <tensorflow/lite/signature_runner.h>
34
35namespace android {
36
// A single motion sample fed to the motion prediction model.
//
// Every field has a zero default so that a default-constructed sample never carries
// indeterminate values. The type is still an aggregate, so brace initialization and
// designated initialization (e.g. `{.position = {.x = 1, .y = 2}, .pressure = 0.5f}`)
// keep working unchanged.
struct TfLiteMotionPredictorSample {
    // The untransformed AMOTION_EVENT_AXIS_X and AMOTION_EVENT_AXIS_Y of the sample.
    struct Point {
        float x = 0;
        float y = 0;
    } position;
    // The AMOTION_EVENT_AXIS_PRESSURE, _TILT, and _ORIENTATION.
    float pressure = 0;
    float tilt = 0;
    float orientation = 0;
};
48
49inline TfLiteMotionPredictorSample::Point operator-(const TfLiteMotionPredictorSample::Point& lhs,
50 const TfLiteMotionPredictorSample::Point& rhs) {
51 return {.x = lhs.x - rhs.x, .y = lhs.y - rhs.y};
52}
53
54class TfLiteMotionPredictorModel;
55
56// Buffer storage for a TfLiteMotionPredictorModel.
57class TfLiteMotionPredictorBuffers {
58public:
59 // Creates buffer storage for a model with the given input length.
60 TfLiteMotionPredictorBuffers(size_t inputLength);
61
62 // Adds a motion sample to the buffers.
63 void pushSample(int64_t timestamp, TfLiteMotionPredictorSample sample);
64
65 // Returns true if the buffers are complete enough to generate a prediction.
66 bool isReady() const {
67 // Predictions can't be applied unless there are at least two points to determine
68 // the direction to apply them in.
69 return mAxisFrom && mAxisTo;
70 }
71
72 // Resets all buffers to their initial state.
73 void reset();
74
75 // Copies the buffers to those of a model for prediction.
76 void copyTo(TfLiteMotionPredictorModel& model) const;
77
78 // Returns the current axis of the buffer's samples. Only valid if isReady().
79 TfLiteMotionPredictorSample axisFrom() const { return *mAxisFrom; }
80 TfLiteMotionPredictorSample axisTo() const { return *mAxisTo; }
81
82 // Returns the timestamp of the last sample.
83 int64_t lastTimestamp() const { return mTimestamp; }
84
85private:
86 int64_t mTimestamp = 0;
87
Philip Quinn9b8926e2023-01-31 14:50:02 -080088 RingBuffer<float> mInputR;
89 RingBuffer<float> mInputPhi;
90 RingBuffer<float> mInputPressure;
91 RingBuffer<float> mInputTilt;
92 RingBuffer<float> mInputOrientation;
Philip Quinn8f953ab2022-12-06 15:37:07 -080093
94 // The samples defining the current polar axis.
95 std::optional<TfLiteMotionPredictorSample> mAxisFrom;
96 std::optional<TfLiteMotionPredictorSample> mAxisTo;
97};
98
99// A TFLite model for generating motion predictions.
100class TfLiteMotionPredictorModel {
101public:
Philip Quinn107ce702023-07-14 13:07:13 -0700102 struct Config {
103 // The time between predictions.
104 nsecs_t predictionInterval = 0;
105 // The noise floor for predictions.
106 // Distances (r) less than this should be discarded as noise.
107 float distanceNoiseFloor = 0;
108 };
109
Philip Quinn8f953ab2022-12-06 15:37:07 -0800110 // Creates a model from an encoded Flatbuffer model.
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -0800111 static std::unique_ptr<TfLiteMotionPredictorModel> create();
Philip Quinn8f953ab2022-12-06 15:37:07 -0800112
Philip Quinnda6a4482023-02-07 10:09:57 -0800113 ~TfLiteMotionPredictorModel();
114
Philip Quinn8f953ab2022-12-06 15:37:07 -0800115 // Returns the length of the model's input buffers.
116 size_t inputLength() const;
117
Cody Heinerdbd14eb2023-03-30 18:41:45 -0700118 // Returns the length of the model's output buffers.
119 size_t outputLength() const;
120
Philip Quinn107ce702023-07-14 13:07:13 -0700121 const Config& config() const { return mConfig; }
Philip Quinnf84fa492023-06-26 14:15:15 -0700122
Philip Quinn8f953ab2022-12-06 15:37:07 -0800123 // Executes the model.
124 // Returns true if the model successfully executed and the output tensors can be read.
125 bool invoke();
126
127 // Returns mutable buffers to the input tensors of inputLength() elements.
128 std::span<float> inputR();
129 std::span<float> inputPhi();
130 std::span<float> inputPressure();
131 std::span<float> inputOrientation();
132 std::span<float> inputTilt();
133
134 // Returns immutable buffers to the output tensors of identical length. Only valid after a
135 // successful call to invoke().
136 std::span<const float> outputR() const;
137 std::span<const float> outputPhi() const;
138 std::span<const float> outputPressure() const;
139
140private:
Philip Quinnf84fa492023-06-26 14:15:15 -0700141 explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model,
Philip Quinn107ce702023-07-14 13:07:13 -0700142 Config config);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800143
144 void allocateTensors();
145 void attachInputTensors();
146 void attachOutputTensors();
147
148 TfLiteTensor* mInputR = nullptr;
149 TfLiteTensor* mInputPhi = nullptr;
150 TfLiteTensor* mInputPressure = nullptr;
151 TfLiteTensor* mInputTilt = nullptr;
152 TfLiteTensor* mInputOrientation = nullptr;
153
154 const TfLiteTensor* mOutputR = nullptr;
155 const TfLiteTensor* mOutputPhi = nullptr;
156 const TfLiteTensor* mOutputPressure = nullptr;
157
Philip Quinncb3229a2023-02-08 22:50:59 -0800158 std::unique_ptr<android::base::MappedFile> mFlatBuffer;
Philip Quinn8f953ab2022-12-06 15:37:07 -0800159 std::unique_ptr<tflite::ErrorReporter> mErrorReporter;
160 std::unique_ptr<tflite::FlatBufferModel> mModel;
161 std::unique_ptr<tflite::Interpreter> mInterpreter;
162 tflite::SignatureRunner* mRunner = nullptr;
Philip Quinnf84fa492023-06-26 14:15:15 -0700163
Philip Quinn107ce702023-07-14 13:07:13 -0700164 const Config mConfig = {};
Philip Quinn8f953ab2022-12-06 15:37:07 -0800165};
166
167} // namespace android