blob: 10510d675c87264d3c470c6535485e56551d4bb4 [file] [log] [blame]
Philip Quinn8f953ab2022-12-06 15:37:07 -08001/*
2 * Copyright (C) 2023 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "TfLiteMotionPredictor"
18#include <input/TfLiteMotionPredictor.h>
19
Philip Quinncb3229a2023-02-08 22:50:59 -080020#include <fcntl.h>
21#include <sys/mman.h>
22#include <unistd.h>
23
Philip Quinn8f953ab2022-12-06 15:37:07 -080024#include <algorithm>
25#include <cmath>
26#include <cstddef>
27#include <cstdint>
Philip Quinn8f953ab2022-12-06 15:37:07 -080028#include <memory>
29#include <span>
Philip Quinn8f953ab2022-12-06 15:37:07 -080030#include <type_traits>
31#include <utility>
32
Philip Quinncb3229a2023-02-08 22:50:59 -080033#include <android-base/logging.h>
34#include <android-base/mapped_file.h>
Philip Quinn8f953ab2022-12-06 15:37:07 -080035#define ATRACE_TAG ATRACE_TAG_INPUT
36#include <cutils/trace.h>
37#include <log/log.h>
38
39#include "tensorflow/lite/core/api/error_reporter.h"
Philip Quinnda6a4482023-02-07 10:09:57 -080040#include "tensorflow/lite/core/api/op_resolver.h"
Philip Quinn8f953ab2022-12-06 15:37:07 -080041#include "tensorflow/lite/interpreter.h"
Philip Quinnda6a4482023-02-07 10:09:57 -080042#include "tensorflow/lite/kernels/builtin_op_kernels.h"
Philip Quinn8f953ab2022-12-06 15:37:07 -080043#include "tensorflow/lite/model.h"
Philip Quinnda6a4482023-02-07 10:09:57 -080044#include "tensorflow/lite/mutable_op_resolver.h"
Philip Quinn8f953ab2022-12-06 15:37:07 -080045
46namespace android {
47namespace {
48
// Signature whose inputs/outputs this predictor binds to.
constexpr const char* SIGNATURE_KEY = "serving_default";

// Tensor names looked up on the signature runner's input side.
constexpr const char* INPUT_R = "r";
constexpr const char* INPUT_PHI = "phi";
constexpr const char* INPUT_PRESSURE = "pressure";
constexpr const char* INPUT_TILT = "tilt";
constexpr const char* INPUT_ORIENTATION = "orientation";

// Tensor names looked up on the signature runner's output side.
constexpr const char* OUTPUT_R = "r";
constexpr const char* OUTPUT_PHI = "phi";
constexpr const char* OUTPUT_PRESSURE = "pressure";
62
63// A TFLite ErrorReporter that logs to logcat.
64class LoggingErrorReporter : public tflite::ErrorReporter {
65public:
66 int Report(const char* format, va_list args) override {
67 return LOG_PRI_VA(ANDROID_LOG_ERROR, LOG_TAG, format, args);
68 }
69};
70
71// Searches a runner for an input tensor.
72TfLiteTensor* findInputTensor(const char* name, tflite::SignatureRunner* runner) {
73 TfLiteTensor* tensor = runner->input_tensor(name);
74 LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find input tensor '%s'", name);
75 return tensor;
76}
77
78// Searches a runner for an output tensor.
79const TfLiteTensor* findOutputTensor(const char* name, tflite::SignatureRunner* runner) {
80 const TfLiteTensor* tensor = runner->output_tensor(name);
81 LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find output tensor '%s'", name);
82 return tensor;
83}
84
85// Returns the buffer for a tensor of type T.
86template <typename T>
87std::span<T> getTensorBuffer(typename std::conditional<std::is_const<T>::value, const TfLiteTensor*,
88 TfLiteTensor*>::type tensor) {
89 LOG_ALWAYS_FATAL_IF(!tensor);
90
91 const TfLiteType type = tflite::typeToTfLiteType<typename std::remove_cv<T>::type>();
92 LOG_ALWAYS_FATAL_IF(tensor->type != type, "Unexpected type for '%s' tensor: %s (expected %s)",
93 tensor->name, TfLiteTypeGetName(tensor->type), TfLiteTypeGetName(type));
94
95 LOG_ALWAYS_FATAL_IF(!tensor->data.data);
96 return {reinterpret_cast<T*>(tensor->data.data),
97 static_cast<typename std::span<T>::index_type>(tensor->bytes / sizeof(T))};
98}
99
100// Verifies that a tensor exists and has an underlying buffer of type T.
101template <typename T>
102void checkTensor(const TfLiteTensor* tensor) {
103 LOG_ALWAYS_FATAL_IF(!tensor);
104
105 const auto buffer = getTensorBuffer<const T>(tensor);
106 LOG_ALWAYS_FATAL_IF(buffer.empty(), "No buffer for tensor '%s'", tensor->name);
107}
108
Philip Quinnda6a4482023-02-07 10:09:57 -0800109std::unique_ptr<tflite::OpResolver> createOpResolver() {
110 auto resolver = std::make_unique<tflite::MutableOpResolver>();
111 resolver->AddBuiltin(::tflite::BuiltinOperator_CONCATENATION,
112 ::tflite::ops::builtin::Register_CONCATENATION());
113 resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
114 ::tflite::ops::builtin::Register_FULLY_CONNECTED());
115 return resolver;
116}
117
Philip Quinn8f953ab2022-12-06 15:37:07 -0800118} // namespace
119
Philip Quinn9b8926e2023-01-31 14:50:02 -0800120TfLiteMotionPredictorBuffers::TfLiteMotionPredictorBuffers(size_t inputLength)
121 : mInputR(inputLength, 0),
122 mInputPhi(inputLength, 0),
123 mInputPressure(inputLength, 0),
124 mInputTilt(inputLength, 0),
125 mInputOrientation(inputLength, 0) {
Philip Quinn8f953ab2022-12-06 15:37:07 -0800126 LOG_ALWAYS_FATAL_IF(inputLength == 0, "Buffer input size must be greater than 0");
Philip Quinn8f953ab2022-12-06 15:37:07 -0800127}
128
129void TfLiteMotionPredictorBuffers::reset() {
130 std::fill(mInputR.begin(), mInputR.end(), 0);
131 std::fill(mInputPhi.begin(), mInputPhi.end(), 0);
132 std::fill(mInputPressure.begin(), mInputPressure.end(), 0);
133 std::fill(mInputTilt.begin(), mInputTilt.end(), 0);
134 std::fill(mInputOrientation.begin(), mInputOrientation.end(), 0);
135 mAxisFrom.reset();
136 mAxisTo.reset();
137}
138
139void TfLiteMotionPredictorBuffers::copyTo(TfLiteMotionPredictorModel& model) const {
140 LOG_ALWAYS_FATAL_IF(mInputR.size() != model.inputLength(),
141 "Buffer length %zu doesn't match model input length %zu", mInputR.size(),
142 model.inputLength());
143 LOG_ALWAYS_FATAL_IF(!isReady(), "Buffers are incomplete");
144
145 std::copy(mInputR.begin(), mInputR.end(), model.inputR().begin());
146 std::copy(mInputPhi.begin(), mInputPhi.end(), model.inputPhi().begin());
147 std::copy(mInputPressure.begin(), mInputPressure.end(), model.inputPressure().begin());
148 std::copy(mInputTilt.begin(), mInputTilt.end(), model.inputTilt().begin());
149 std::copy(mInputOrientation.begin(), mInputOrientation.end(), model.inputOrientation().begin());
150}
151
152void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp,
153 const TfLiteMotionPredictorSample sample) {
154 // Convert the sample (x, y) into polar (r, φ) based on a reference axis
155 // from the preceding two points (mAxisFrom/mAxisTo).
156
157 mTimestamp = timestamp;
158
159 if (!mAxisTo) { // First point.
160 mAxisTo = sample;
161 return;
162 }
163
164 // Vector from the last point to the current sample point.
165 const TfLiteMotionPredictorSample::Point v = sample.position - mAxisTo->position;
166
167 const float r = std::hypot(v.x, v.y);
168 float phi = 0;
169 float orientation = 0;
170
171 // Ignore the sample if there is no movement. These samples can occur when there's change to a
172 // property other than the coordinates and pollute the input to the model.
173 if (r == 0) {
174 return;
175 }
176
177 if (!mAxisFrom) { // Second point.
178 // We can only determine the distance from the first point, and not any
179 // angle. However, if the second point forms an axis, the orientation can
180 // be transformed relative to that axis.
181 const float axisPhi = std::atan2(v.y, v.x);
182 // A MotionEvent's orientation is measured clockwise from the vertical
183 // axis, but axisPhi is measured counter-clockwise from the horizontal
184 // axis.
185 orientation = M_PI_2 - sample.orientation - axisPhi;
186 } else {
187 const TfLiteMotionPredictorSample::Point axis = mAxisTo->position - mAxisFrom->position;
188 const float axisPhi = std::atan2(axis.y, axis.x);
189 phi = std::atan2(v.y, v.x) - axisPhi;
190
191 if (std::hypot(axis.x, axis.y) > 0) {
192 // See note above.
193 orientation = M_PI_2 - sample.orientation - axisPhi;
194 }
195 }
196
197 // Update the axis for the next point.
198 mAxisFrom = mAxisTo;
199 mAxisTo = sample;
200
201 // Push the current sample onto the end of the input buffers.
Philip Quinn9b8926e2023-01-31 14:50:02 -0800202 mInputR.pushBack(r);
203 mInputPhi.pushBack(phi);
204 mInputPressure.pushBack(sample.pressure);
205 mInputTilt.pushBack(sample.tilt);
206 mInputOrientation.pushBack(orientation);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800207}
208
209std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create(
210 const char* modelPath) {
Philip Quinncb3229a2023-02-08 22:50:59 -0800211 const int fd = open(modelPath, O_RDONLY);
212 if (fd == -1) {
213 PLOG(FATAL) << "Could not read model from " << modelPath;
214 }
Philip Quinn8f953ab2022-12-06 15:37:07 -0800215
Philip Quinncb3229a2023-02-08 22:50:59 -0800216 const off_t fdSize = lseek(fd, 0, SEEK_END);
217 if (fdSize == -1) {
218 PLOG(FATAL) << "Failed to determine file size";
219 }
220
221 std::unique_ptr<android::base::MappedFile> modelBuffer =
222 android::base::MappedFile::FromFd(fd, /*offset=*/0, fdSize, PROT_READ);
223 if (!modelBuffer) {
224 PLOG(FATAL) << "Failed to mmap model";
225 }
226 if (close(fd) == -1) {
227 PLOG(FATAL) << "Failed to close model fd";
228 }
Philip Quinn8f953ab2022-12-06 15:37:07 -0800229
230 return std::unique_ptr<TfLiteMotionPredictorModel>(
Philip Quinncb3229a2023-02-08 22:50:59 -0800231 new TfLiteMotionPredictorModel(std::move(modelBuffer)));
Philip Quinn8f953ab2022-12-06 15:37:07 -0800232}
233
Philip Quinncb3229a2023-02-08 22:50:59 -0800234TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(
235 std::unique_ptr<android::base::MappedFile> model)
Philip Quinn8f953ab2022-12-06 15:37:07 -0800236 : mFlatBuffer(std::move(model)) {
Philip Quinncb3229a2023-02-08 22:50:59 -0800237 CHECK(mFlatBuffer);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800238 mErrorReporter = std::make_unique<LoggingErrorReporter>();
Philip Quinncb3229a2023-02-08 22:50:59 -0800239 mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(),
240 mFlatBuffer->size(),
Philip Quinn8f953ab2022-12-06 15:37:07 -0800241 /*extra_verifier=*/nullptr,
242 mErrorReporter.get());
243 LOG_ALWAYS_FATAL_IF(!mModel);
244
Philip Quinnda6a4482023-02-07 10:09:57 -0800245 auto resolver = createOpResolver();
246 tflite::InterpreterBuilder builder(*mModel, *resolver);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800247
248 if (builder(&mInterpreter) != kTfLiteOk || !mInterpreter) {
249 LOG_ALWAYS_FATAL("Failed to build interpreter");
250 }
251
252 mRunner = mInterpreter->GetSignatureRunner(SIGNATURE_KEY);
253 LOG_ALWAYS_FATAL_IF(!mRunner, "Failed to find runner for signature '%s'", SIGNATURE_KEY);
254
255 allocateTensors();
256}
257
Philip Quinnda6a4482023-02-07 10:09:57 -0800258TfLiteMotionPredictorModel::~TfLiteMotionPredictorModel() {}
259
Philip Quinn8f953ab2022-12-06 15:37:07 -0800260void TfLiteMotionPredictorModel::allocateTensors() {
261 if (mRunner->AllocateTensors() != kTfLiteOk) {
262 LOG_ALWAYS_FATAL("Failed to allocate tensors");
263 }
264
265 attachInputTensors();
266 attachOutputTensors();
267
268 checkTensor<float>(mInputR);
269 checkTensor<float>(mInputPhi);
270 checkTensor<float>(mInputPressure);
271 checkTensor<float>(mInputTilt);
272 checkTensor<float>(mInputOrientation);
273 checkTensor<float>(mOutputR);
274 checkTensor<float>(mOutputPhi);
275 checkTensor<float>(mOutputPressure);
276
277 const auto checkInputTensorSize = [this](const TfLiteTensor* tensor) {
278 const size_t size = getTensorBuffer<const float>(tensor).size();
279 LOG_ALWAYS_FATAL_IF(size != inputLength(),
280 "Tensor '%s' length %zu does not match input length %zu", tensor->name,
281 size, inputLength());
282 };
283
284 checkInputTensorSize(mInputR);
285 checkInputTensorSize(mInputPhi);
286 checkInputTensorSize(mInputPressure);
287 checkInputTensorSize(mInputTilt);
288 checkInputTensorSize(mInputOrientation);
289}
290
291void TfLiteMotionPredictorModel::attachInputTensors() {
292 mInputR = findInputTensor(INPUT_R, mRunner);
293 mInputPhi = findInputTensor(INPUT_PHI, mRunner);
294 mInputPressure = findInputTensor(INPUT_PRESSURE, mRunner);
295 mInputTilt = findInputTensor(INPUT_TILT, mRunner);
296 mInputOrientation = findInputTensor(INPUT_ORIENTATION, mRunner);
297}
298
299void TfLiteMotionPredictorModel::attachOutputTensors() {
300 mOutputR = findOutputTensor(OUTPUT_R, mRunner);
301 mOutputPhi = findOutputTensor(OUTPUT_PHI, mRunner);
302 mOutputPressure = findOutputTensor(OUTPUT_PRESSURE, mRunner);
303}
304
305bool TfLiteMotionPredictorModel::invoke() {
306 ATRACE_BEGIN("TfLiteMotionPredictorModel::invoke");
307 TfLiteStatus result = mRunner->Invoke();
308 ATRACE_END();
309
310 if (result != kTfLiteOk) {
311 return false;
312 }
313
314 // Invoke() might reallocate tensors, so they need to be reattached.
315 attachInputTensors();
316 attachOutputTensors();
317
318 if (outputR().size() != outputPhi().size() || outputR().size() != outputPressure().size()) {
319 LOG_ALWAYS_FATAL("Output size mismatch: (r: %zu, phi: %zu, pressure: %zu)",
320 outputR().size(), outputPhi().size(), outputPressure().size());
321 }
322
323 return true;
324}
325
326size_t TfLiteMotionPredictorModel::inputLength() const {
327 return getTensorBuffer<const float>(mInputR).size();
328}
329
330std::span<float> TfLiteMotionPredictorModel::inputR() {
331 return getTensorBuffer<float>(mInputR);
332}
333
334std::span<float> TfLiteMotionPredictorModel::inputPhi() {
335 return getTensorBuffer<float>(mInputPhi);
336}
337
338std::span<float> TfLiteMotionPredictorModel::inputPressure() {
339 return getTensorBuffer<float>(mInputPressure);
340}
341
342std::span<float> TfLiteMotionPredictorModel::inputTilt() {
343 return getTensorBuffer<float>(mInputTilt);
344}
345
346std::span<float> TfLiteMotionPredictorModel::inputOrientation() {
347 return getTensorBuffer<float>(mInputOrientation);
348}
349
350std::span<const float> TfLiteMotionPredictorModel::outputR() const {
351 return getTensorBuffer<const float>(mOutputR);
352}
353
354std::span<const float> TfLiteMotionPredictorModel::outputPhi() const {
355 return getTensorBuffer<const float>(mOutputPhi);
356}
357
358std::span<const float> TfLiteMotionPredictorModel::outputPressure() const {
359 return getTensorBuffer<const float>(mOutputPressure);
360}
361
362} // namespace android