blob: 691e87c3669a95c5564fba10c21270faad9c40e4 [file] [log] [blame]
Philip Quinn8f953ab2022-12-06 15:37:07 -08001/*
2 * Copyright (C) 2023 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "TfLiteMotionPredictor"
18#include <input/TfLiteMotionPredictor.h>
19
Philip Quinncb3229a2023-02-08 22:50:59 -080020#include <fcntl.h>
21#include <sys/mman.h>
22#include <unistd.h>
23
Philip Quinn8f953ab2022-12-06 15:37:07 -080024#include <algorithm>
25#include <cmath>
26#include <cstddef>
27#include <cstdint>
Philip Quinn8f953ab2022-12-06 15:37:07 -080028#include <memory>
29#include <span>
Philip Quinn8f953ab2022-12-06 15:37:07 -080030#include <type_traits>
31#include <utility>
32
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -080033#include <android-base/file.h>
Philip Quinncb3229a2023-02-08 22:50:59 -080034#include <android-base/logging.h>
35#include <android-base/mapped_file.h>
Philip Quinn8f953ab2022-12-06 15:37:07 -080036#define ATRACE_TAG ATRACE_TAG_INPUT
37#include <cutils/trace.h>
38#include <log/log.h>
39
40#include "tensorflow/lite/core/api/error_reporter.h"
Philip Quinnda6a4482023-02-07 10:09:57 -080041#include "tensorflow/lite/core/api/op_resolver.h"
Philip Quinn8f953ab2022-12-06 15:37:07 -080042#include "tensorflow/lite/interpreter.h"
Philip Quinnda6a4482023-02-07 10:09:57 -080043#include "tensorflow/lite/kernels/builtin_op_kernels.h"
Philip Quinn8f953ab2022-12-06 15:37:07 -080044#include "tensorflow/lite/model.h"
Philip Quinnda6a4482023-02-07 10:09:57 -080045#include "tensorflow/lite/mutable_op_resolver.h"
Philip Quinn8f953ab2022-12-06 15:37:07 -080046
47namespace android {
48namespace {
49
// Name of the model signature whose runner drives inference.
constexpr char SIGNATURE_KEY[]{"serving_default"};

// Names of the signature's input tensors.
constexpr char INPUT_R[]{"r"};
constexpr char INPUT_PHI[]{"phi"};
constexpr char INPUT_PRESSURE[]{"pressure"};
constexpr char INPUT_TILT[]{"tilt"};
constexpr char INPUT_ORIENTATION[]{"orientation"};

// Names of the signature's output tensors.
constexpr char OUTPUT_R[]{"r"};
constexpr char OUTPUT_PHI[]{"phi"};
constexpr char OUTPUT_PRESSURE[]{"pressure"};
63
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -080064std::string getModelPath() {
65#if defined(__ANDROID__)
66 return "/system/etc/motion_predictor_model.fb";
67#else
68 return base::GetExecutableDirectory() + "/motion_predictor_model.fb";
69#endif
70}
71
Philip Quinn8f953ab2022-12-06 15:37:07 -080072// A TFLite ErrorReporter that logs to logcat.
73class LoggingErrorReporter : public tflite::ErrorReporter {
74public:
75 int Report(const char* format, va_list args) override {
76 return LOG_PRI_VA(ANDROID_LOG_ERROR, LOG_TAG, format, args);
77 }
78};
79
80// Searches a runner for an input tensor.
81TfLiteTensor* findInputTensor(const char* name, tflite::SignatureRunner* runner) {
82 TfLiteTensor* tensor = runner->input_tensor(name);
83 LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find input tensor '%s'", name);
84 return tensor;
85}
86
87// Searches a runner for an output tensor.
88const TfLiteTensor* findOutputTensor(const char* name, tflite::SignatureRunner* runner) {
89 const TfLiteTensor* tensor = runner->output_tensor(name);
90 LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find output tensor '%s'", name);
91 return tensor;
92}
93
94// Returns the buffer for a tensor of type T.
95template <typename T>
96std::span<T> getTensorBuffer(typename std::conditional<std::is_const<T>::value, const TfLiteTensor*,
97 TfLiteTensor*>::type tensor) {
98 LOG_ALWAYS_FATAL_IF(!tensor);
99
100 const TfLiteType type = tflite::typeToTfLiteType<typename std::remove_cv<T>::type>();
101 LOG_ALWAYS_FATAL_IF(tensor->type != type, "Unexpected type for '%s' tensor: %s (expected %s)",
102 tensor->name, TfLiteTypeGetName(tensor->type), TfLiteTypeGetName(type));
103
104 LOG_ALWAYS_FATAL_IF(!tensor->data.data);
105 return {reinterpret_cast<T*>(tensor->data.data),
106 static_cast<typename std::span<T>::index_type>(tensor->bytes / sizeof(T))};
107}
108
109// Verifies that a tensor exists and has an underlying buffer of type T.
110template <typename T>
111void checkTensor(const TfLiteTensor* tensor) {
112 LOG_ALWAYS_FATAL_IF(!tensor);
113
114 const auto buffer = getTensorBuffer<const T>(tensor);
115 LOG_ALWAYS_FATAL_IF(buffer.empty(), "No buffer for tensor '%s'", tensor->name);
116}
117
Philip Quinnda6a4482023-02-07 10:09:57 -0800118std::unique_ptr<tflite::OpResolver> createOpResolver() {
119 auto resolver = std::make_unique<tflite::MutableOpResolver>();
120 resolver->AddBuiltin(::tflite::BuiltinOperator_CONCATENATION,
121 ::tflite::ops::builtin::Register_CONCATENATION());
122 resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
123 ::tflite::ops::builtin::Register_FULLY_CONNECTED());
124 return resolver;
125}
126
Philip Quinn8f953ab2022-12-06 15:37:07 -0800127} // namespace
128
Philip Quinn9b8926e2023-01-31 14:50:02 -0800129TfLiteMotionPredictorBuffers::TfLiteMotionPredictorBuffers(size_t inputLength)
130 : mInputR(inputLength, 0),
131 mInputPhi(inputLength, 0),
132 mInputPressure(inputLength, 0),
133 mInputTilt(inputLength, 0),
134 mInputOrientation(inputLength, 0) {
Philip Quinn8f953ab2022-12-06 15:37:07 -0800135 LOG_ALWAYS_FATAL_IF(inputLength == 0, "Buffer input size must be greater than 0");
Philip Quinn8f953ab2022-12-06 15:37:07 -0800136}
137
138void TfLiteMotionPredictorBuffers::reset() {
139 std::fill(mInputR.begin(), mInputR.end(), 0);
140 std::fill(mInputPhi.begin(), mInputPhi.end(), 0);
141 std::fill(mInputPressure.begin(), mInputPressure.end(), 0);
142 std::fill(mInputTilt.begin(), mInputTilt.end(), 0);
143 std::fill(mInputOrientation.begin(), mInputOrientation.end(), 0);
144 mAxisFrom.reset();
145 mAxisTo.reset();
146}
147
148void TfLiteMotionPredictorBuffers::copyTo(TfLiteMotionPredictorModel& model) const {
149 LOG_ALWAYS_FATAL_IF(mInputR.size() != model.inputLength(),
150 "Buffer length %zu doesn't match model input length %zu", mInputR.size(),
151 model.inputLength());
152 LOG_ALWAYS_FATAL_IF(!isReady(), "Buffers are incomplete");
153
154 std::copy(mInputR.begin(), mInputR.end(), model.inputR().begin());
155 std::copy(mInputPhi.begin(), mInputPhi.end(), model.inputPhi().begin());
156 std::copy(mInputPressure.begin(), mInputPressure.end(), model.inputPressure().begin());
157 std::copy(mInputTilt.begin(), mInputTilt.end(), model.inputTilt().begin());
158 std::copy(mInputOrientation.begin(), mInputOrientation.end(), model.inputOrientation().begin());
159}
160
161void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp,
162 const TfLiteMotionPredictorSample sample) {
163 // Convert the sample (x, y) into polar (r, φ) based on a reference axis
164 // from the preceding two points (mAxisFrom/mAxisTo).
165
166 mTimestamp = timestamp;
167
168 if (!mAxisTo) { // First point.
169 mAxisTo = sample;
170 return;
171 }
172
173 // Vector from the last point to the current sample point.
174 const TfLiteMotionPredictorSample::Point v = sample.position - mAxisTo->position;
175
176 const float r = std::hypot(v.x, v.y);
177 float phi = 0;
178 float orientation = 0;
179
180 // Ignore the sample if there is no movement. These samples can occur when there's change to a
181 // property other than the coordinates and pollute the input to the model.
182 if (r == 0) {
183 return;
184 }
185
186 if (!mAxisFrom) { // Second point.
187 // We can only determine the distance from the first point, and not any
188 // angle. However, if the second point forms an axis, the orientation can
189 // be transformed relative to that axis.
190 const float axisPhi = std::atan2(v.y, v.x);
191 // A MotionEvent's orientation is measured clockwise from the vertical
192 // axis, but axisPhi is measured counter-clockwise from the horizontal
193 // axis.
194 orientation = M_PI_2 - sample.orientation - axisPhi;
195 } else {
196 const TfLiteMotionPredictorSample::Point axis = mAxisTo->position - mAxisFrom->position;
197 const float axisPhi = std::atan2(axis.y, axis.x);
198 phi = std::atan2(v.y, v.x) - axisPhi;
199
200 if (std::hypot(axis.x, axis.y) > 0) {
201 // See note above.
202 orientation = M_PI_2 - sample.orientation - axisPhi;
203 }
204 }
205
206 // Update the axis for the next point.
207 mAxisFrom = mAxisTo;
208 mAxisTo = sample;
209
210 // Push the current sample onto the end of the input buffers.
Philip Quinn9b8926e2023-01-31 14:50:02 -0800211 mInputR.pushBack(r);
212 mInputPhi.pushBack(phi);
213 mInputPressure.pushBack(sample.pressure);
214 mInputTilt.pushBack(sample.tilt);
215 mInputOrientation.pushBack(orientation);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800216}
217
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -0800218std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create() {
219 const std::string modelPath = getModelPath();
220 const int fd = open(modelPath.c_str(), O_RDONLY);
Philip Quinncb3229a2023-02-08 22:50:59 -0800221 if (fd == -1) {
222 PLOG(FATAL) << "Could not read model from " << modelPath;
223 }
Philip Quinn8f953ab2022-12-06 15:37:07 -0800224
Philip Quinncb3229a2023-02-08 22:50:59 -0800225 const off_t fdSize = lseek(fd, 0, SEEK_END);
226 if (fdSize == -1) {
227 PLOG(FATAL) << "Failed to determine file size";
228 }
229
230 std::unique_ptr<android::base::MappedFile> modelBuffer =
231 android::base::MappedFile::FromFd(fd, /*offset=*/0, fdSize, PROT_READ);
232 if (!modelBuffer) {
233 PLOG(FATAL) << "Failed to mmap model";
234 }
235 if (close(fd) == -1) {
236 PLOG(FATAL) << "Failed to close model fd";
237 }
Philip Quinn8f953ab2022-12-06 15:37:07 -0800238
239 return std::unique_ptr<TfLiteMotionPredictorModel>(
Philip Quinncb3229a2023-02-08 22:50:59 -0800240 new TfLiteMotionPredictorModel(std::move(modelBuffer)));
Philip Quinn8f953ab2022-12-06 15:37:07 -0800241}
242
Philip Quinncb3229a2023-02-08 22:50:59 -0800243TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(
244 std::unique_ptr<android::base::MappedFile> model)
Philip Quinn8f953ab2022-12-06 15:37:07 -0800245 : mFlatBuffer(std::move(model)) {
Philip Quinncb3229a2023-02-08 22:50:59 -0800246 CHECK(mFlatBuffer);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800247 mErrorReporter = std::make_unique<LoggingErrorReporter>();
Philip Quinncb3229a2023-02-08 22:50:59 -0800248 mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(),
249 mFlatBuffer->size(),
Philip Quinn8f953ab2022-12-06 15:37:07 -0800250 /*extra_verifier=*/nullptr,
251 mErrorReporter.get());
252 LOG_ALWAYS_FATAL_IF(!mModel);
253
Philip Quinnda6a4482023-02-07 10:09:57 -0800254 auto resolver = createOpResolver();
255 tflite::InterpreterBuilder builder(*mModel, *resolver);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800256
257 if (builder(&mInterpreter) != kTfLiteOk || !mInterpreter) {
258 LOG_ALWAYS_FATAL("Failed to build interpreter");
259 }
260
261 mRunner = mInterpreter->GetSignatureRunner(SIGNATURE_KEY);
262 LOG_ALWAYS_FATAL_IF(!mRunner, "Failed to find runner for signature '%s'", SIGNATURE_KEY);
263
264 allocateTensors();
265}
266
Philip Quinnda6a4482023-02-07 10:09:57 -0800267TfLiteMotionPredictorModel::~TfLiteMotionPredictorModel() {}
268
Philip Quinn8f953ab2022-12-06 15:37:07 -0800269void TfLiteMotionPredictorModel::allocateTensors() {
270 if (mRunner->AllocateTensors() != kTfLiteOk) {
271 LOG_ALWAYS_FATAL("Failed to allocate tensors");
272 }
273
274 attachInputTensors();
275 attachOutputTensors();
276
277 checkTensor<float>(mInputR);
278 checkTensor<float>(mInputPhi);
279 checkTensor<float>(mInputPressure);
280 checkTensor<float>(mInputTilt);
281 checkTensor<float>(mInputOrientation);
282 checkTensor<float>(mOutputR);
283 checkTensor<float>(mOutputPhi);
284 checkTensor<float>(mOutputPressure);
285
286 const auto checkInputTensorSize = [this](const TfLiteTensor* tensor) {
287 const size_t size = getTensorBuffer<const float>(tensor).size();
288 LOG_ALWAYS_FATAL_IF(size != inputLength(),
289 "Tensor '%s' length %zu does not match input length %zu", tensor->name,
290 size, inputLength());
291 };
292
293 checkInputTensorSize(mInputR);
294 checkInputTensorSize(mInputPhi);
295 checkInputTensorSize(mInputPressure);
296 checkInputTensorSize(mInputTilt);
297 checkInputTensorSize(mInputOrientation);
298}
299
300void TfLiteMotionPredictorModel::attachInputTensors() {
301 mInputR = findInputTensor(INPUT_R, mRunner);
302 mInputPhi = findInputTensor(INPUT_PHI, mRunner);
303 mInputPressure = findInputTensor(INPUT_PRESSURE, mRunner);
304 mInputTilt = findInputTensor(INPUT_TILT, mRunner);
305 mInputOrientation = findInputTensor(INPUT_ORIENTATION, mRunner);
306}
307
308void TfLiteMotionPredictorModel::attachOutputTensors() {
309 mOutputR = findOutputTensor(OUTPUT_R, mRunner);
310 mOutputPhi = findOutputTensor(OUTPUT_PHI, mRunner);
311 mOutputPressure = findOutputTensor(OUTPUT_PRESSURE, mRunner);
312}
313
314bool TfLiteMotionPredictorModel::invoke() {
315 ATRACE_BEGIN("TfLiteMotionPredictorModel::invoke");
316 TfLiteStatus result = mRunner->Invoke();
317 ATRACE_END();
318
319 if (result != kTfLiteOk) {
320 return false;
321 }
322
323 // Invoke() might reallocate tensors, so they need to be reattached.
324 attachInputTensors();
325 attachOutputTensors();
326
327 if (outputR().size() != outputPhi().size() || outputR().size() != outputPressure().size()) {
328 LOG_ALWAYS_FATAL("Output size mismatch: (r: %zu, phi: %zu, pressure: %zu)",
329 outputR().size(), outputPhi().size(), outputPressure().size());
330 }
331
332 return true;
333}
334
335size_t TfLiteMotionPredictorModel::inputLength() const {
336 return getTensorBuffer<const float>(mInputR).size();
337}
338
339std::span<float> TfLiteMotionPredictorModel::inputR() {
340 return getTensorBuffer<float>(mInputR);
341}
342
343std::span<float> TfLiteMotionPredictorModel::inputPhi() {
344 return getTensorBuffer<float>(mInputPhi);
345}
346
347std::span<float> TfLiteMotionPredictorModel::inputPressure() {
348 return getTensorBuffer<float>(mInputPressure);
349}
350
351std::span<float> TfLiteMotionPredictorModel::inputTilt() {
352 return getTensorBuffer<float>(mInputTilt);
353}
354
355std::span<float> TfLiteMotionPredictorModel::inputOrientation() {
356 return getTensorBuffer<float>(mInputOrientation);
357}
358
359std::span<const float> TfLiteMotionPredictorModel::outputR() const {
360 return getTensorBuffer<const float>(mOutputR);
361}
362
363std::span<const float> TfLiteMotionPredictorModel::outputPhi() const {
364 return getTensorBuffer<const float>(mOutputPhi);
365}
366
367std::span<const float> TfLiteMotionPredictorModel::outputPressure() const {
368 return getTensorBuffer<const float>(mOutputPressure);
369}
370
371} // namespace android