blob: fbb71065408c803a8d4d6906670a7972682eed8d [file] [log] [blame]
Philip Quinn8f953ab2022-12-06 15:37:07 -08001/*
2 * Copyright (C) 2023 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "TfLiteMotionPredictor"
18#include <input/TfLiteMotionPredictor.h>
19
20#include <algorithm>
21#include <cmath>
22#include <cstddef>
23#include <cstdint>
24#include <fstream>
25#include <ios>
26#include <iterator>
27#include <memory>
28#include <span>
29#include <string>
30#include <type_traits>
31#include <utility>
32
33#define ATRACE_TAG ATRACE_TAG_INPUT
34#include <cutils/trace.h>
35#include <log/log.h>
36
37#include "tensorflow/lite/core/api/error_reporter.h"
Philip Quinnda6a4482023-02-07 10:09:57 -080038#include "tensorflow/lite/core/api/op_resolver.h"
Philip Quinn8f953ab2022-12-06 15:37:07 -080039#include "tensorflow/lite/interpreter.h"
Philip Quinnda6a4482023-02-07 10:09:57 -080040#include "tensorflow/lite/kernels/builtin_op_kernels.h"
Philip Quinn8f953ab2022-12-06 15:37:07 -080041#include "tensorflow/lite/model.h"
Philip Quinnda6a4482023-02-07 10:09:57 -080042#include "tensorflow/lite/mutable_op_resolver.h"
Philip Quinn8f953ab2022-12-06 15:37:07 -080043
44namespace android {
45namespace {
46
// Key of the model signature that is looked up when building the signature runner.
constexpr char SIGNATURE_KEY[] = "serving_default";

// Names under which the input tensors are looked up.
constexpr char INPUT_R[] = "r";
constexpr char INPUT_PHI[] = "phi";
constexpr char INPUT_PRESSURE[] = "pressure";
constexpr char INPUT_TILT[] = "tilt";
constexpr char INPUT_ORIENTATION[] = "orientation";

// Names under which the output tensors are looked up.
constexpr char OUTPUT_R[] = "r";
constexpr char OUTPUT_PHI[] = "phi";
constexpr char OUTPUT_PRESSURE[] = "pressure";
60
61// A TFLite ErrorReporter that logs to logcat.
62class LoggingErrorReporter : public tflite::ErrorReporter {
63public:
64 int Report(const char* format, va_list args) override {
65 return LOG_PRI_VA(ANDROID_LOG_ERROR, LOG_TAG, format, args);
66 }
67};
68
69// Searches a runner for an input tensor.
70TfLiteTensor* findInputTensor(const char* name, tflite::SignatureRunner* runner) {
71 TfLiteTensor* tensor = runner->input_tensor(name);
72 LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find input tensor '%s'", name);
73 return tensor;
74}
75
76// Searches a runner for an output tensor.
77const TfLiteTensor* findOutputTensor(const char* name, tflite::SignatureRunner* runner) {
78 const TfLiteTensor* tensor = runner->output_tensor(name);
79 LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find output tensor '%s'", name);
80 return tensor;
81}
82
83// Returns the buffer for a tensor of type T.
84template <typename T>
85std::span<T> getTensorBuffer(typename std::conditional<std::is_const<T>::value, const TfLiteTensor*,
86 TfLiteTensor*>::type tensor) {
87 LOG_ALWAYS_FATAL_IF(!tensor);
88
89 const TfLiteType type = tflite::typeToTfLiteType<typename std::remove_cv<T>::type>();
90 LOG_ALWAYS_FATAL_IF(tensor->type != type, "Unexpected type for '%s' tensor: %s (expected %s)",
91 tensor->name, TfLiteTypeGetName(tensor->type), TfLiteTypeGetName(type));
92
93 LOG_ALWAYS_FATAL_IF(!tensor->data.data);
94 return {reinterpret_cast<T*>(tensor->data.data),
95 static_cast<typename std::span<T>::index_type>(tensor->bytes / sizeof(T))};
96}
97
98// Verifies that a tensor exists and has an underlying buffer of type T.
99template <typename T>
100void checkTensor(const TfLiteTensor* tensor) {
101 LOG_ALWAYS_FATAL_IF(!tensor);
102
103 const auto buffer = getTensorBuffer<const T>(tensor);
104 LOG_ALWAYS_FATAL_IF(buffer.empty(), "No buffer for tensor '%s'", tensor->name);
105}
106
Philip Quinnda6a4482023-02-07 10:09:57 -0800107std::unique_ptr<tflite::OpResolver> createOpResolver() {
108 auto resolver = std::make_unique<tflite::MutableOpResolver>();
109 resolver->AddBuiltin(::tflite::BuiltinOperator_CONCATENATION,
110 ::tflite::ops::builtin::Register_CONCATENATION());
111 resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
112 ::tflite::ops::builtin::Register_FULLY_CONNECTED());
113 return resolver;
114}
115
Philip Quinn8f953ab2022-12-06 15:37:07 -0800116} // namespace
117
Philip Quinn9b8926e2023-01-31 14:50:02 -0800118TfLiteMotionPredictorBuffers::TfLiteMotionPredictorBuffers(size_t inputLength)
119 : mInputR(inputLength, 0),
120 mInputPhi(inputLength, 0),
121 mInputPressure(inputLength, 0),
122 mInputTilt(inputLength, 0),
123 mInputOrientation(inputLength, 0) {
Philip Quinn8f953ab2022-12-06 15:37:07 -0800124 LOG_ALWAYS_FATAL_IF(inputLength == 0, "Buffer input size must be greater than 0");
Philip Quinn8f953ab2022-12-06 15:37:07 -0800125}
126
127void TfLiteMotionPredictorBuffers::reset() {
128 std::fill(mInputR.begin(), mInputR.end(), 0);
129 std::fill(mInputPhi.begin(), mInputPhi.end(), 0);
130 std::fill(mInputPressure.begin(), mInputPressure.end(), 0);
131 std::fill(mInputTilt.begin(), mInputTilt.end(), 0);
132 std::fill(mInputOrientation.begin(), mInputOrientation.end(), 0);
133 mAxisFrom.reset();
134 mAxisTo.reset();
135}
136
137void TfLiteMotionPredictorBuffers::copyTo(TfLiteMotionPredictorModel& model) const {
138 LOG_ALWAYS_FATAL_IF(mInputR.size() != model.inputLength(),
139 "Buffer length %zu doesn't match model input length %zu", mInputR.size(),
140 model.inputLength());
141 LOG_ALWAYS_FATAL_IF(!isReady(), "Buffers are incomplete");
142
143 std::copy(mInputR.begin(), mInputR.end(), model.inputR().begin());
144 std::copy(mInputPhi.begin(), mInputPhi.end(), model.inputPhi().begin());
145 std::copy(mInputPressure.begin(), mInputPressure.end(), model.inputPressure().begin());
146 std::copy(mInputTilt.begin(), mInputTilt.end(), model.inputTilt().begin());
147 std::copy(mInputOrientation.begin(), mInputOrientation.end(), model.inputOrientation().begin());
148}
149
150void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp,
151 const TfLiteMotionPredictorSample sample) {
152 // Convert the sample (x, y) into polar (r, φ) based on a reference axis
153 // from the preceding two points (mAxisFrom/mAxisTo).
154
155 mTimestamp = timestamp;
156
157 if (!mAxisTo) { // First point.
158 mAxisTo = sample;
159 return;
160 }
161
162 // Vector from the last point to the current sample point.
163 const TfLiteMotionPredictorSample::Point v = sample.position - mAxisTo->position;
164
165 const float r = std::hypot(v.x, v.y);
166 float phi = 0;
167 float orientation = 0;
168
169 // Ignore the sample if there is no movement. These samples can occur when there's change to a
170 // property other than the coordinates and pollute the input to the model.
171 if (r == 0) {
172 return;
173 }
174
175 if (!mAxisFrom) { // Second point.
176 // We can only determine the distance from the first point, and not any
177 // angle. However, if the second point forms an axis, the orientation can
178 // be transformed relative to that axis.
179 const float axisPhi = std::atan2(v.y, v.x);
180 // A MotionEvent's orientation is measured clockwise from the vertical
181 // axis, but axisPhi is measured counter-clockwise from the horizontal
182 // axis.
183 orientation = M_PI_2 - sample.orientation - axisPhi;
184 } else {
185 const TfLiteMotionPredictorSample::Point axis = mAxisTo->position - mAxisFrom->position;
186 const float axisPhi = std::atan2(axis.y, axis.x);
187 phi = std::atan2(v.y, v.x) - axisPhi;
188
189 if (std::hypot(axis.x, axis.y) > 0) {
190 // See note above.
191 orientation = M_PI_2 - sample.orientation - axisPhi;
192 }
193 }
194
195 // Update the axis for the next point.
196 mAxisFrom = mAxisTo;
197 mAxisTo = sample;
198
199 // Push the current sample onto the end of the input buffers.
Philip Quinn9b8926e2023-01-31 14:50:02 -0800200 mInputR.pushBack(r);
201 mInputPhi.pushBack(phi);
202 mInputPressure.pushBack(sample.pressure);
203 mInputTilt.pushBack(sample.tilt);
204 mInputOrientation.pushBack(orientation);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800205}
206
207std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create(
208 const char* modelPath) {
209 std::ifstream f(modelPath, std::ios::binary);
210 LOG_ALWAYS_FATAL_IF(!f, "Could not read model from %s", modelPath);
211
212 std::string data;
213 data.assign(std::istreambuf_iterator<char>(f), std::istreambuf_iterator<char>());
214
215 return std::unique_ptr<TfLiteMotionPredictorModel>(
216 new TfLiteMotionPredictorModel(std::move(data)));
217}
218
219TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(std::string model)
220 : mFlatBuffer(std::move(model)) {
221 mErrorReporter = std::make_unique<LoggingErrorReporter>();
222 mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer.data(),
223 mFlatBuffer.length(),
224 /*extra_verifier=*/nullptr,
225 mErrorReporter.get());
226 LOG_ALWAYS_FATAL_IF(!mModel);
227
Philip Quinnda6a4482023-02-07 10:09:57 -0800228 auto resolver = createOpResolver();
229 tflite::InterpreterBuilder builder(*mModel, *resolver);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800230
231 if (builder(&mInterpreter) != kTfLiteOk || !mInterpreter) {
232 LOG_ALWAYS_FATAL("Failed to build interpreter");
233 }
234
235 mRunner = mInterpreter->GetSignatureRunner(SIGNATURE_KEY);
236 LOG_ALWAYS_FATAL_IF(!mRunner, "Failed to find runner for signature '%s'", SIGNATURE_KEY);
237
238 allocateTensors();
239}
240
Philip Quinnda6a4482023-02-07 10:09:57 -0800241TfLiteMotionPredictorModel::~TfLiteMotionPredictorModel() {}
242
Philip Quinn8f953ab2022-12-06 15:37:07 -0800243void TfLiteMotionPredictorModel::allocateTensors() {
244 if (mRunner->AllocateTensors() != kTfLiteOk) {
245 LOG_ALWAYS_FATAL("Failed to allocate tensors");
246 }
247
248 attachInputTensors();
249 attachOutputTensors();
250
251 checkTensor<float>(mInputR);
252 checkTensor<float>(mInputPhi);
253 checkTensor<float>(mInputPressure);
254 checkTensor<float>(mInputTilt);
255 checkTensor<float>(mInputOrientation);
256 checkTensor<float>(mOutputR);
257 checkTensor<float>(mOutputPhi);
258 checkTensor<float>(mOutputPressure);
259
260 const auto checkInputTensorSize = [this](const TfLiteTensor* tensor) {
261 const size_t size = getTensorBuffer<const float>(tensor).size();
262 LOG_ALWAYS_FATAL_IF(size != inputLength(),
263 "Tensor '%s' length %zu does not match input length %zu", tensor->name,
264 size, inputLength());
265 };
266
267 checkInputTensorSize(mInputR);
268 checkInputTensorSize(mInputPhi);
269 checkInputTensorSize(mInputPressure);
270 checkInputTensorSize(mInputTilt);
271 checkInputTensorSize(mInputOrientation);
272}
273
274void TfLiteMotionPredictorModel::attachInputTensors() {
275 mInputR = findInputTensor(INPUT_R, mRunner);
276 mInputPhi = findInputTensor(INPUT_PHI, mRunner);
277 mInputPressure = findInputTensor(INPUT_PRESSURE, mRunner);
278 mInputTilt = findInputTensor(INPUT_TILT, mRunner);
279 mInputOrientation = findInputTensor(INPUT_ORIENTATION, mRunner);
280}
281
282void TfLiteMotionPredictorModel::attachOutputTensors() {
283 mOutputR = findOutputTensor(OUTPUT_R, mRunner);
284 mOutputPhi = findOutputTensor(OUTPUT_PHI, mRunner);
285 mOutputPressure = findOutputTensor(OUTPUT_PRESSURE, mRunner);
286}
287
288bool TfLiteMotionPredictorModel::invoke() {
289 ATRACE_BEGIN("TfLiteMotionPredictorModel::invoke");
290 TfLiteStatus result = mRunner->Invoke();
291 ATRACE_END();
292
293 if (result != kTfLiteOk) {
294 return false;
295 }
296
297 // Invoke() might reallocate tensors, so they need to be reattached.
298 attachInputTensors();
299 attachOutputTensors();
300
301 if (outputR().size() != outputPhi().size() || outputR().size() != outputPressure().size()) {
302 LOG_ALWAYS_FATAL("Output size mismatch: (r: %zu, phi: %zu, pressure: %zu)",
303 outputR().size(), outputPhi().size(), outputPressure().size());
304 }
305
306 return true;
307}
308
309size_t TfLiteMotionPredictorModel::inputLength() const {
310 return getTensorBuffer<const float>(mInputR).size();
311}
312
313std::span<float> TfLiteMotionPredictorModel::inputR() {
314 return getTensorBuffer<float>(mInputR);
315}
316
317std::span<float> TfLiteMotionPredictorModel::inputPhi() {
318 return getTensorBuffer<float>(mInputPhi);
319}
320
321std::span<float> TfLiteMotionPredictorModel::inputPressure() {
322 return getTensorBuffer<float>(mInputPressure);
323}
324
325std::span<float> TfLiteMotionPredictorModel::inputTilt() {
326 return getTensorBuffer<float>(mInputTilt);
327}
328
329std::span<float> TfLiteMotionPredictorModel::inputOrientation() {
330 return getTensorBuffer<float>(mInputOrientation);
331}
332
333std::span<const float> TfLiteMotionPredictorModel::outputR() const {
334 return getTensorBuffer<const float>(mOutputR);
335}
336
337std::span<const float> TfLiteMotionPredictorModel::outputPhi() const {
338 return getTensorBuffer<const float>(mOutputPhi);
339}
340
341std::span<const float> TfLiteMotionPredictorModel::outputPressure() const {
342 return getTensorBuffer<const float>(mOutputPressure);
343}
344
345} // namespace android