/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <array>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <span>

#include <android-base/mapped_file.h>
#include <input/RingBuffer.h>
#include <utils/Timers.h>

#include <tensorflow/lite/core/api/error_reporter.h>
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/model.h>
#include <tensorflow/lite/signature_runner.h>

namespace android {

// A single motion sample: a position plus the pressure, tilt, and orientation axis values.
struct TfLiteMotionPredictorSample {
    // The untransformed AMOTION_EVENT_AXIS_X and AMOTION_EVENT_AXIS_Y of the sample.
    struct Point {
        float x;
        float y;
    } position;
    // The AMOTION_EVENT_AXIS_PRESSURE, _TILT, and _ORIENTATION.
    float pressure;
    float tilt;
    float orientation;
};

// Returns the component-wise difference (lhs - rhs) of two points.
inline TfLiteMotionPredictorSample::Point operator-(const TfLiteMotionPredictorSample::Point& lhs,
                                                    const TfLiteMotionPredictorSample::Point& rhs) {
    return {.x = lhs.x - rhs.x, .y = lhs.y - rhs.y};
}

54class TfLiteMotionPredictorModel;
55
56// Buffer storage for a TfLiteMotionPredictorModel.
57class TfLiteMotionPredictorBuffers {
58public:
59 // Creates buffer storage for a model with the given input length.
60 TfLiteMotionPredictorBuffers(size_t inputLength);
61
62 // Adds a motion sample to the buffers.
63 void pushSample(int64_t timestamp, TfLiteMotionPredictorSample sample);
64
65 // Returns true if the buffers are complete enough to generate a prediction.
66 bool isReady() const {
67 // Predictions can't be applied unless there are at least two points to determine
68 // the direction to apply them in.
69 return mAxisFrom && mAxisTo;
70 }
71
72 // Resets all buffers to their initial state.
73 void reset();
74
75 // Copies the buffers to those of a model for prediction.
76 void copyTo(TfLiteMotionPredictorModel& model) const;
77
78 // Returns the current axis of the buffer's samples. Only valid if isReady().
79 TfLiteMotionPredictorSample axisFrom() const { return *mAxisFrom; }
80 TfLiteMotionPredictorSample axisTo() const { return *mAxisTo; }
81
82 // Returns the timestamp of the last sample.
83 int64_t lastTimestamp() const { return mTimestamp; }
84
85private:
86 int64_t mTimestamp = 0;
87
Philip Quinn9b8926e2023-01-31 14:50:02 -080088 RingBuffer<float> mInputR;
89 RingBuffer<float> mInputPhi;
90 RingBuffer<float> mInputPressure;
91 RingBuffer<float> mInputTilt;
92 RingBuffer<float> mInputOrientation;
Philip Quinn8f953ab2022-12-06 15:37:07 -080093
94 // The samples defining the current polar axis.
95 std::optional<TfLiteMotionPredictorSample> mAxisFrom;
96 std::optional<TfLiteMotionPredictorSample> mAxisTo;
97};
98
99// A TFLite model for generating motion predictions.
100class TfLiteMotionPredictorModel {
101public:
102 // Creates a model from an encoded Flatbuffer model.
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -0800103 static std::unique_ptr<TfLiteMotionPredictorModel> create();
Philip Quinn8f953ab2022-12-06 15:37:07 -0800104
Philip Quinnda6a4482023-02-07 10:09:57 -0800105 ~TfLiteMotionPredictorModel();
106
Philip Quinn8f953ab2022-12-06 15:37:07 -0800107 // Returns the length of the model's input buffers.
108 size_t inputLength() const;
109
Cody Heinerdbd14eb2023-03-30 18:41:45 -0700110 // Returns the length of the model's output buffers.
111 size_t outputLength() const;
112
Philip Quinnf84fa492023-06-26 14:15:15 -0700113 // Returns the time interval between predictions.
114 nsecs_t predictionInterval() const { return mPredictionInterval; }
115
Philip Quinn8f953ab2022-12-06 15:37:07 -0800116 // Executes the model.
117 // Returns true if the model successfully executed and the output tensors can be read.
118 bool invoke();
119
120 // Returns mutable buffers to the input tensors of inputLength() elements.
121 std::span<float> inputR();
122 std::span<float> inputPhi();
123 std::span<float> inputPressure();
124 std::span<float> inputOrientation();
125 std::span<float> inputTilt();
126
127 // Returns immutable buffers to the output tensors of identical length. Only valid after a
128 // successful call to invoke().
129 std::span<const float> outputR() const;
130 std::span<const float> outputPhi() const;
131 std::span<const float> outputPressure() const;
132
133private:
Philip Quinnf84fa492023-06-26 14:15:15 -0700134 explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model,
135 nsecs_t predictionInterval);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800136
137 void allocateTensors();
138 void attachInputTensors();
139 void attachOutputTensors();
140
141 TfLiteTensor* mInputR = nullptr;
142 TfLiteTensor* mInputPhi = nullptr;
143 TfLiteTensor* mInputPressure = nullptr;
144 TfLiteTensor* mInputTilt = nullptr;
145 TfLiteTensor* mInputOrientation = nullptr;
146
147 const TfLiteTensor* mOutputR = nullptr;
148 const TfLiteTensor* mOutputPhi = nullptr;
149 const TfLiteTensor* mOutputPressure = nullptr;
150
Philip Quinncb3229a2023-02-08 22:50:59 -0800151 std::unique_ptr<android::base::MappedFile> mFlatBuffer;
Philip Quinn8f953ab2022-12-06 15:37:07 -0800152 std::unique_ptr<tflite::ErrorReporter> mErrorReporter;
153 std::unique_ptr<tflite::FlatBufferModel> mModel;
154 std::unique_ptr<tflite::Interpreter> mInterpreter;
155 tflite::SignatureRunner* mRunner = nullptr;
Philip Quinnf84fa492023-06-26 14:15:15 -0700156
157 const nsecs_t mPredictionInterval = 0;
Philip Quinn8f953ab2022-12-06 15:37:07 -0800158};
159
} // namespace android