blob: 1a337adf1303ab3c98cbd98de670c9dc54a0fb30 [file] [log] [blame]
Philip Quinn8f953ab2022-12-06 15:37:07 -08001/*
2 * Copyright (C) 2023 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "TfLiteMotionPredictor"
18#include <input/TfLiteMotionPredictor.h>
19
20#include <algorithm>
21#include <cmath>
22#include <cstddef>
23#include <cstdint>
24#include <fstream>
25#include <ios>
26#include <iterator>
27#include <memory>
28#include <span>
29#include <string>
30#include <type_traits>
31#include <utility>
32
33#define ATRACE_TAG ATRACE_TAG_INPUT
34#include <cutils/trace.h>
35#include <log/log.h>
36
37#include "tensorflow/lite/core/api/error_reporter.h"
38#include "tensorflow/lite/interpreter.h"
39#include "tensorflow/lite/kernels/register.h"
40#include "tensorflow/lite/model.h"
41
42namespace android {
43namespace {
44
// Key of the model signature whose runner is used for inference.
constexpr const char* SIGNATURE_KEY = "serving_default";

// Names of the input tensors exposed by the SIGNATURE_KEY signature.
constexpr const char* INPUT_R = "r";
constexpr const char* INPUT_PHI = "phi";
constexpr const char* INPUT_PRESSURE = "pressure";
constexpr const char* INPUT_TILT = "tilt";
constexpr const char* INPUT_ORIENTATION = "orientation";

// Names of the output tensors exposed by the SIGNATURE_KEY signature.
constexpr const char* OUTPUT_R = "r";
constexpr const char* OUTPUT_PHI = "phi";
constexpr const char* OUTPUT_PRESSURE = "pressure";
58
59// A TFLite ErrorReporter that logs to logcat.
60class LoggingErrorReporter : public tflite::ErrorReporter {
61public:
62 int Report(const char* format, va_list args) override {
63 return LOG_PRI_VA(ANDROID_LOG_ERROR, LOG_TAG, format, args);
64 }
65};
66
67// Searches a runner for an input tensor.
68TfLiteTensor* findInputTensor(const char* name, tflite::SignatureRunner* runner) {
69 TfLiteTensor* tensor = runner->input_tensor(name);
70 LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find input tensor '%s'", name);
71 return tensor;
72}
73
74// Searches a runner for an output tensor.
75const TfLiteTensor* findOutputTensor(const char* name, tflite::SignatureRunner* runner) {
76 const TfLiteTensor* tensor = runner->output_tensor(name);
77 LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find output tensor '%s'", name);
78 return tensor;
79}
80
81// Returns the buffer for a tensor of type T.
82template <typename T>
83std::span<T> getTensorBuffer(typename std::conditional<std::is_const<T>::value, const TfLiteTensor*,
84 TfLiteTensor*>::type tensor) {
85 LOG_ALWAYS_FATAL_IF(!tensor);
86
87 const TfLiteType type = tflite::typeToTfLiteType<typename std::remove_cv<T>::type>();
88 LOG_ALWAYS_FATAL_IF(tensor->type != type, "Unexpected type for '%s' tensor: %s (expected %s)",
89 tensor->name, TfLiteTypeGetName(tensor->type), TfLiteTypeGetName(type));
90
91 LOG_ALWAYS_FATAL_IF(!tensor->data.data);
92 return {reinterpret_cast<T*>(tensor->data.data),
93 static_cast<typename std::span<T>::index_type>(tensor->bytes / sizeof(T))};
94}
95
96// Verifies that a tensor exists and has an underlying buffer of type T.
97template <typename T>
98void checkTensor(const TfLiteTensor* tensor) {
99 LOG_ALWAYS_FATAL_IF(!tensor);
100
101 const auto buffer = getTensorBuffer<const T>(tensor);
102 LOG_ALWAYS_FATAL_IF(buffer.empty(), "No buffer for tensor '%s'", tensor->name);
103}
104
105} // namespace
106
107TfLiteMotionPredictorBuffers::TfLiteMotionPredictorBuffers(size_t inputLength) {
108 LOG_ALWAYS_FATAL_IF(inputLength == 0, "Buffer input size must be greater than 0");
109 mInputR.resize(inputLength);
110 mInputPhi.resize(inputLength);
111 mInputPressure.resize(inputLength);
112 mInputTilt.resize(inputLength);
113 mInputOrientation.resize(inputLength);
114}
115
116void TfLiteMotionPredictorBuffers::reset() {
117 std::fill(mInputR.begin(), mInputR.end(), 0);
118 std::fill(mInputPhi.begin(), mInputPhi.end(), 0);
119 std::fill(mInputPressure.begin(), mInputPressure.end(), 0);
120 std::fill(mInputTilt.begin(), mInputTilt.end(), 0);
121 std::fill(mInputOrientation.begin(), mInputOrientation.end(), 0);
122 mAxisFrom.reset();
123 mAxisTo.reset();
124}
125
126void TfLiteMotionPredictorBuffers::copyTo(TfLiteMotionPredictorModel& model) const {
127 LOG_ALWAYS_FATAL_IF(mInputR.size() != model.inputLength(),
128 "Buffer length %zu doesn't match model input length %zu", mInputR.size(),
129 model.inputLength());
130 LOG_ALWAYS_FATAL_IF(!isReady(), "Buffers are incomplete");
131
132 std::copy(mInputR.begin(), mInputR.end(), model.inputR().begin());
133 std::copy(mInputPhi.begin(), mInputPhi.end(), model.inputPhi().begin());
134 std::copy(mInputPressure.begin(), mInputPressure.end(), model.inputPressure().begin());
135 std::copy(mInputTilt.begin(), mInputTilt.end(), model.inputTilt().begin());
136 std::copy(mInputOrientation.begin(), mInputOrientation.end(), model.inputOrientation().begin());
137}
138
139void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp,
140 const TfLiteMotionPredictorSample sample) {
141 // Convert the sample (x, y) into polar (r, φ) based on a reference axis
142 // from the preceding two points (mAxisFrom/mAxisTo).
143
144 mTimestamp = timestamp;
145
146 if (!mAxisTo) { // First point.
147 mAxisTo = sample;
148 return;
149 }
150
151 // Vector from the last point to the current sample point.
152 const TfLiteMotionPredictorSample::Point v = sample.position - mAxisTo->position;
153
154 const float r = std::hypot(v.x, v.y);
155 float phi = 0;
156 float orientation = 0;
157
158 // Ignore the sample if there is no movement. These samples can occur when there's change to a
159 // property other than the coordinates and pollute the input to the model.
160 if (r == 0) {
161 return;
162 }
163
164 if (!mAxisFrom) { // Second point.
165 // We can only determine the distance from the first point, and not any
166 // angle. However, if the second point forms an axis, the orientation can
167 // be transformed relative to that axis.
168 const float axisPhi = std::atan2(v.y, v.x);
169 // A MotionEvent's orientation is measured clockwise from the vertical
170 // axis, but axisPhi is measured counter-clockwise from the horizontal
171 // axis.
172 orientation = M_PI_2 - sample.orientation - axisPhi;
173 } else {
174 const TfLiteMotionPredictorSample::Point axis = mAxisTo->position - mAxisFrom->position;
175 const float axisPhi = std::atan2(axis.y, axis.x);
176 phi = std::atan2(v.y, v.x) - axisPhi;
177
178 if (std::hypot(axis.x, axis.y) > 0) {
179 // See note above.
180 orientation = M_PI_2 - sample.orientation - axisPhi;
181 }
182 }
183
184 // Update the axis for the next point.
185 mAxisFrom = mAxisTo;
186 mAxisTo = sample;
187
188 // Push the current sample onto the end of the input buffers.
189 mInputR.erase(mInputR.begin());
190 mInputPhi.erase(mInputPhi.begin());
191 mInputPressure.erase(mInputPressure.begin());
192 mInputTilt.erase(mInputTilt.begin());
193 mInputOrientation.erase(mInputOrientation.begin());
194
195 mInputR.push_back(r);
196 mInputPhi.push_back(phi);
197 mInputPressure.push_back(sample.pressure);
198 mInputTilt.push_back(sample.tilt);
199 mInputOrientation.push_back(orientation);
200}
201
202std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create(
203 const char* modelPath) {
204 std::ifstream f(modelPath, std::ios::binary);
205 LOG_ALWAYS_FATAL_IF(!f, "Could not read model from %s", modelPath);
206
207 std::string data;
208 data.assign(std::istreambuf_iterator<char>(f), std::istreambuf_iterator<char>());
209
210 return std::unique_ptr<TfLiteMotionPredictorModel>(
211 new TfLiteMotionPredictorModel(std::move(data)));
212}
213
214TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(std::string model)
215 : mFlatBuffer(std::move(model)) {
216 mErrorReporter = std::make_unique<LoggingErrorReporter>();
217 mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer.data(),
218 mFlatBuffer.length(),
219 /*extra_verifier=*/nullptr,
220 mErrorReporter.get());
221 LOG_ALWAYS_FATAL_IF(!mModel);
222
223 tflite::ops::builtin::BuiltinOpResolver resolver;
224 tflite::InterpreterBuilder builder(*mModel, resolver);
225
226 if (builder(&mInterpreter) != kTfLiteOk || !mInterpreter) {
227 LOG_ALWAYS_FATAL("Failed to build interpreter");
228 }
229
230 mRunner = mInterpreter->GetSignatureRunner(SIGNATURE_KEY);
231 LOG_ALWAYS_FATAL_IF(!mRunner, "Failed to find runner for signature '%s'", SIGNATURE_KEY);
232
233 allocateTensors();
234}
235
236void TfLiteMotionPredictorModel::allocateTensors() {
237 if (mRunner->AllocateTensors() != kTfLiteOk) {
238 LOG_ALWAYS_FATAL("Failed to allocate tensors");
239 }
240
241 attachInputTensors();
242 attachOutputTensors();
243
244 checkTensor<float>(mInputR);
245 checkTensor<float>(mInputPhi);
246 checkTensor<float>(mInputPressure);
247 checkTensor<float>(mInputTilt);
248 checkTensor<float>(mInputOrientation);
249 checkTensor<float>(mOutputR);
250 checkTensor<float>(mOutputPhi);
251 checkTensor<float>(mOutputPressure);
252
253 const auto checkInputTensorSize = [this](const TfLiteTensor* tensor) {
254 const size_t size = getTensorBuffer<const float>(tensor).size();
255 LOG_ALWAYS_FATAL_IF(size != inputLength(),
256 "Tensor '%s' length %zu does not match input length %zu", tensor->name,
257 size, inputLength());
258 };
259
260 checkInputTensorSize(mInputR);
261 checkInputTensorSize(mInputPhi);
262 checkInputTensorSize(mInputPressure);
263 checkInputTensorSize(mInputTilt);
264 checkInputTensorSize(mInputOrientation);
265}
266
267void TfLiteMotionPredictorModel::attachInputTensors() {
268 mInputR = findInputTensor(INPUT_R, mRunner);
269 mInputPhi = findInputTensor(INPUT_PHI, mRunner);
270 mInputPressure = findInputTensor(INPUT_PRESSURE, mRunner);
271 mInputTilt = findInputTensor(INPUT_TILT, mRunner);
272 mInputOrientation = findInputTensor(INPUT_ORIENTATION, mRunner);
273}
274
275void TfLiteMotionPredictorModel::attachOutputTensors() {
276 mOutputR = findOutputTensor(OUTPUT_R, mRunner);
277 mOutputPhi = findOutputTensor(OUTPUT_PHI, mRunner);
278 mOutputPressure = findOutputTensor(OUTPUT_PRESSURE, mRunner);
279}
280
281bool TfLiteMotionPredictorModel::invoke() {
282 ATRACE_BEGIN("TfLiteMotionPredictorModel::invoke");
283 TfLiteStatus result = mRunner->Invoke();
284 ATRACE_END();
285
286 if (result != kTfLiteOk) {
287 return false;
288 }
289
290 // Invoke() might reallocate tensors, so they need to be reattached.
291 attachInputTensors();
292 attachOutputTensors();
293
294 if (outputR().size() != outputPhi().size() || outputR().size() != outputPressure().size()) {
295 LOG_ALWAYS_FATAL("Output size mismatch: (r: %zu, phi: %zu, pressure: %zu)",
296 outputR().size(), outputPhi().size(), outputPressure().size());
297 }
298
299 return true;
300}
301
302size_t TfLiteMotionPredictorModel::inputLength() const {
303 return getTensorBuffer<const float>(mInputR).size();
304}
305
306std::span<float> TfLiteMotionPredictorModel::inputR() {
307 return getTensorBuffer<float>(mInputR);
308}
309
310std::span<float> TfLiteMotionPredictorModel::inputPhi() {
311 return getTensorBuffer<float>(mInputPhi);
312}
313
314std::span<float> TfLiteMotionPredictorModel::inputPressure() {
315 return getTensorBuffer<float>(mInputPressure);
316}
317
318std::span<float> TfLiteMotionPredictorModel::inputTilt() {
319 return getTensorBuffer<float>(mInputTilt);
320}
321
322std::span<float> TfLiteMotionPredictorModel::inputOrientation() {
323 return getTensorBuffer<float>(mInputOrientation);
324}
325
326std::span<const float> TfLiteMotionPredictorModel::outputR() const {
327 return getTensorBuffer<const float>(mOutputR);
328}
329
330std::span<const float> TfLiteMotionPredictorModel::outputPhi() const {
331 return getTensorBuffer<const float>(mOutputPhi);
332}
333
334std::span<const float> TfLiteMotionPredictorModel::outputPressure() const {
335 return getTensorBuffer<const float>(mOutputPressure);
336}
337
338} // namespace android