blob: 3b061d1cf12459befa1e4d7f8a6c53ecbc2d49f1 [file] [log] [blame]
Philip Quinn8f953ab2022-12-06 15:37:07 -08001/*
2 * Copyright (C) 2023 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#define LOG_TAG "TfLiteMotionPredictor"
#include <input/TfLiteMotionPredictor.h>

#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <span>
#include <string>
#include <type_traits>
#include <utility>

#include <android-base/file.h>
#include <android-base/logging.h>
#include <android-base/mapped_file.h>
#include <android-base/unique_fd.h>
#define ATRACE_TAG ATRACE_TAG_INPUT
#include <cutils/trace.h>
#include <log/log.h>

#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/mutable_op_resolver.h"
Philip Quinn8f953ab2022-12-06 15:37:07 -080046
47namespace android {
48namespace {
49
// Name of the model signature whose runner is used for inference.
constexpr const char* SIGNATURE_KEY = "serving_default";

// Names of the tensors the model reads. Each one is looked up in the signature runner.
constexpr const char* INPUT_R = "r";
constexpr const char* INPUT_PHI = "phi";
constexpr const char* INPUT_PRESSURE = "pressure";
constexpr const char* INPUT_TILT = "tilt";
constexpr const char* INPUT_ORIENTATION = "orientation";

// Names of the tensors the model produces.
constexpr const char* OUTPUT_R = "r";
constexpr const char* OUTPUT_PHI = "phi";
constexpr const char* OUTPUT_PRESSURE = "pressure";
63
Siarhei Vishniakouc065d7b2023-03-02 14:06:29 -080064// Ideally, we would just use std::filesystem::exists here, but it requires libc++fs, which causes
65// build issues in other parts of the system.
66#if defined(__ANDROID__)
// Returns true when stat() succeeds on `filename`, i.e. the path names an existing,
// reachable file system entry. The retrieved metadata itself is discarded.
bool fileExists(const char* filename) {
    struct stat info {};
    const int status = stat(filename, &info);
    return status == 0;
}
70}
71#endif
72
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -080073std::string getModelPath() {
74#if defined(__ANDROID__)
Siarhei Vishniakouc065d7b2023-03-02 14:06:29 -080075 static const char* oemModel = "/vendor/etc/motion_predictor_model.fb";
76 if (fileExists(oemModel)) {
77 return oemModel;
78 }
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -080079 return "/system/etc/motion_predictor_model.fb";
80#else
81 return base::GetExecutableDirectory() + "/motion_predictor_model.fb";
82#endif
83}
84
Philip Quinn8f953ab2022-12-06 15:37:07 -080085// A TFLite ErrorReporter that logs to logcat.
86class LoggingErrorReporter : public tflite::ErrorReporter {
87public:
88 int Report(const char* format, va_list args) override {
89 return LOG_PRI_VA(ANDROID_LOG_ERROR, LOG_TAG, format, args);
90 }
91};
92
93// Searches a runner for an input tensor.
94TfLiteTensor* findInputTensor(const char* name, tflite::SignatureRunner* runner) {
95 TfLiteTensor* tensor = runner->input_tensor(name);
96 LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find input tensor '%s'", name);
97 return tensor;
98}
99
100// Searches a runner for an output tensor.
101const TfLiteTensor* findOutputTensor(const char* name, tflite::SignatureRunner* runner) {
102 const TfLiteTensor* tensor = runner->output_tensor(name);
103 LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find output tensor '%s'", name);
104 return tensor;
105}
106
107// Returns the buffer for a tensor of type T.
108template <typename T>
109std::span<T> getTensorBuffer(typename std::conditional<std::is_const<T>::value, const TfLiteTensor*,
110 TfLiteTensor*>::type tensor) {
111 LOG_ALWAYS_FATAL_IF(!tensor);
112
113 const TfLiteType type = tflite::typeToTfLiteType<typename std::remove_cv<T>::type>();
114 LOG_ALWAYS_FATAL_IF(tensor->type != type, "Unexpected type for '%s' tensor: %s (expected %s)",
115 tensor->name, TfLiteTypeGetName(tensor->type), TfLiteTypeGetName(type));
116
117 LOG_ALWAYS_FATAL_IF(!tensor->data.data);
118 return {reinterpret_cast<T*>(tensor->data.data),
119 static_cast<typename std::span<T>::index_type>(tensor->bytes / sizeof(T))};
120}
121
122// Verifies that a tensor exists and has an underlying buffer of type T.
123template <typename T>
124void checkTensor(const TfLiteTensor* tensor) {
125 LOG_ALWAYS_FATAL_IF(!tensor);
126
127 const auto buffer = getTensorBuffer<const T>(tensor);
128 LOG_ALWAYS_FATAL_IF(buffer.empty(), "No buffer for tensor '%s'", tensor->name);
129}
130
Philip Quinnda6a4482023-02-07 10:09:57 -0800131std::unique_ptr<tflite::OpResolver> createOpResolver() {
132 auto resolver = std::make_unique<tflite::MutableOpResolver>();
133 resolver->AddBuiltin(::tflite::BuiltinOperator_CONCATENATION,
134 ::tflite::ops::builtin::Register_CONCATENATION());
135 resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
136 ::tflite::ops::builtin::Register_FULLY_CONNECTED());
137 return resolver;
138}
139
Philip Quinn8f953ab2022-12-06 15:37:07 -0800140} // namespace
141
Philip Quinn9b8926e2023-01-31 14:50:02 -0800142TfLiteMotionPredictorBuffers::TfLiteMotionPredictorBuffers(size_t inputLength)
143 : mInputR(inputLength, 0),
144 mInputPhi(inputLength, 0),
145 mInputPressure(inputLength, 0),
146 mInputTilt(inputLength, 0),
147 mInputOrientation(inputLength, 0) {
Philip Quinn8f953ab2022-12-06 15:37:07 -0800148 LOG_ALWAYS_FATAL_IF(inputLength == 0, "Buffer input size must be greater than 0");
Philip Quinn8f953ab2022-12-06 15:37:07 -0800149}
150
151void TfLiteMotionPredictorBuffers::reset() {
152 std::fill(mInputR.begin(), mInputR.end(), 0);
153 std::fill(mInputPhi.begin(), mInputPhi.end(), 0);
154 std::fill(mInputPressure.begin(), mInputPressure.end(), 0);
155 std::fill(mInputTilt.begin(), mInputTilt.end(), 0);
156 std::fill(mInputOrientation.begin(), mInputOrientation.end(), 0);
157 mAxisFrom.reset();
158 mAxisTo.reset();
159}
160
161void TfLiteMotionPredictorBuffers::copyTo(TfLiteMotionPredictorModel& model) const {
162 LOG_ALWAYS_FATAL_IF(mInputR.size() != model.inputLength(),
163 "Buffer length %zu doesn't match model input length %zu", mInputR.size(),
164 model.inputLength());
165 LOG_ALWAYS_FATAL_IF(!isReady(), "Buffers are incomplete");
166
167 std::copy(mInputR.begin(), mInputR.end(), model.inputR().begin());
168 std::copy(mInputPhi.begin(), mInputPhi.end(), model.inputPhi().begin());
169 std::copy(mInputPressure.begin(), mInputPressure.end(), model.inputPressure().begin());
170 std::copy(mInputTilt.begin(), mInputTilt.end(), model.inputTilt().begin());
171 std::copy(mInputOrientation.begin(), mInputOrientation.end(), model.inputOrientation().begin());
172}
173
174void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp,
175 const TfLiteMotionPredictorSample sample) {
176 // Convert the sample (x, y) into polar (r, φ) based on a reference axis
177 // from the preceding two points (mAxisFrom/mAxisTo).
178
179 mTimestamp = timestamp;
180
181 if (!mAxisTo) { // First point.
182 mAxisTo = sample;
183 return;
184 }
185
186 // Vector from the last point to the current sample point.
187 const TfLiteMotionPredictorSample::Point v = sample.position - mAxisTo->position;
188
189 const float r = std::hypot(v.x, v.y);
190 float phi = 0;
191 float orientation = 0;
192
193 // Ignore the sample if there is no movement. These samples can occur when there's change to a
194 // property other than the coordinates and pollute the input to the model.
195 if (r == 0) {
196 return;
197 }
198
199 if (!mAxisFrom) { // Second point.
200 // We can only determine the distance from the first point, and not any
201 // angle. However, if the second point forms an axis, the orientation can
202 // be transformed relative to that axis.
203 const float axisPhi = std::atan2(v.y, v.x);
204 // A MotionEvent's orientation is measured clockwise from the vertical
205 // axis, but axisPhi is measured counter-clockwise from the horizontal
206 // axis.
207 orientation = M_PI_2 - sample.orientation - axisPhi;
208 } else {
209 const TfLiteMotionPredictorSample::Point axis = mAxisTo->position - mAxisFrom->position;
210 const float axisPhi = std::atan2(axis.y, axis.x);
211 phi = std::atan2(v.y, v.x) - axisPhi;
212
213 if (std::hypot(axis.x, axis.y) > 0) {
214 // See note above.
215 orientation = M_PI_2 - sample.orientation - axisPhi;
216 }
217 }
218
219 // Update the axis for the next point.
220 mAxisFrom = mAxisTo;
221 mAxisTo = sample;
222
223 // Push the current sample onto the end of the input buffers.
Philip Quinn9b8926e2023-01-31 14:50:02 -0800224 mInputR.pushBack(r);
225 mInputPhi.pushBack(phi);
226 mInputPressure.pushBack(sample.pressure);
227 mInputTilt.pushBack(sample.tilt);
228 mInputOrientation.pushBack(orientation);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800229}
230
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -0800231std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create() {
232 const std::string modelPath = getModelPath();
Siarhei Vishniakouc065d7b2023-03-02 14:06:29 -0800233 android::base::unique_fd fd(open(modelPath.c_str(), O_RDONLY));
Philip Quinncb3229a2023-02-08 22:50:59 -0800234 if (fd == -1) {
235 PLOG(FATAL) << "Could not read model from " << modelPath;
236 }
Philip Quinn8f953ab2022-12-06 15:37:07 -0800237
Philip Quinncb3229a2023-02-08 22:50:59 -0800238 const off_t fdSize = lseek(fd, 0, SEEK_END);
239 if (fdSize == -1) {
240 PLOG(FATAL) << "Failed to determine file size";
241 }
242
243 std::unique_ptr<android::base::MappedFile> modelBuffer =
244 android::base::MappedFile::FromFd(fd, /*offset=*/0, fdSize, PROT_READ);
245 if (!modelBuffer) {
246 PLOG(FATAL) << "Failed to mmap model";
247 }
Philip Quinn8f953ab2022-12-06 15:37:07 -0800248
249 return std::unique_ptr<TfLiteMotionPredictorModel>(
Philip Quinncb3229a2023-02-08 22:50:59 -0800250 new TfLiteMotionPredictorModel(std::move(modelBuffer)));
Philip Quinn8f953ab2022-12-06 15:37:07 -0800251}
252
Philip Quinncb3229a2023-02-08 22:50:59 -0800253TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(
254 std::unique_ptr<android::base::MappedFile> model)
Philip Quinn8f953ab2022-12-06 15:37:07 -0800255 : mFlatBuffer(std::move(model)) {
Philip Quinncb3229a2023-02-08 22:50:59 -0800256 CHECK(mFlatBuffer);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800257 mErrorReporter = std::make_unique<LoggingErrorReporter>();
Philip Quinncb3229a2023-02-08 22:50:59 -0800258 mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(),
259 mFlatBuffer->size(),
Philip Quinn8f953ab2022-12-06 15:37:07 -0800260 /*extra_verifier=*/nullptr,
261 mErrorReporter.get());
262 LOG_ALWAYS_FATAL_IF(!mModel);
263
Philip Quinnda6a4482023-02-07 10:09:57 -0800264 auto resolver = createOpResolver();
265 tflite::InterpreterBuilder builder(*mModel, *resolver);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800266
267 if (builder(&mInterpreter) != kTfLiteOk || !mInterpreter) {
268 LOG_ALWAYS_FATAL("Failed to build interpreter");
269 }
270
271 mRunner = mInterpreter->GetSignatureRunner(SIGNATURE_KEY);
272 LOG_ALWAYS_FATAL_IF(!mRunner, "Failed to find runner for signature '%s'", SIGNATURE_KEY);
273
274 allocateTensors();
275}
276
Philip Quinnda6a4482023-02-07 10:09:57 -0800277TfLiteMotionPredictorModel::~TfLiteMotionPredictorModel() {}
278
Philip Quinn8f953ab2022-12-06 15:37:07 -0800279void TfLiteMotionPredictorModel::allocateTensors() {
280 if (mRunner->AllocateTensors() != kTfLiteOk) {
281 LOG_ALWAYS_FATAL("Failed to allocate tensors");
282 }
283
284 attachInputTensors();
285 attachOutputTensors();
286
287 checkTensor<float>(mInputR);
288 checkTensor<float>(mInputPhi);
289 checkTensor<float>(mInputPressure);
290 checkTensor<float>(mInputTilt);
291 checkTensor<float>(mInputOrientation);
292 checkTensor<float>(mOutputR);
293 checkTensor<float>(mOutputPhi);
294 checkTensor<float>(mOutputPressure);
295
296 const auto checkInputTensorSize = [this](const TfLiteTensor* tensor) {
297 const size_t size = getTensorBuffer<const float>(tensor).size();
298 LOG_ALWAYS_FATAL_IF(size != inputLength(),
299 "Tensor '%s' length %zu does not match input length %zu", tensor->name,
300 size, inputLength());
301 };
302
303 checkInputTensorSize(mInputR);
304 checkInputTensorSize(mInputPhi);
305 checkInputTensorSize(mInputPressure);
306 checkInputTensorSize(mInputTilt);
307 checkInputTensorSize(mInputOrientation);
308}
309
310void TfLiteMotionPredictorModel::attachInputTensors() {
311 mInputR = findInputTensor(INPUT_R, mRunner);
312 mInputPhi = findInputTensor(INPUT_PHI, mRunner);
313 mInputPressure = findInputTensor(INPUT_PRESSURE, mRunner);
314 mInputTilt = findInputTensor(INPUT_TILT, mRunner);
315 mInputOrientation = findInputTensor(INPUT_ORIENTATION, mRunner);
316}
317
318void TfLiteMotionPredictorModel::attachOutputTensors() {
319 mOutputR = findOutputTensor(OUTPUT_R, mRunner);
320 mOutputPhi = findOutputTensor(OUTPUT_PHI, mRunner);
321 mOutputPressure = findOutputTensor(OUTPUT_PRESSURE, mRunner);
322}
323
324bool TfLiteMotionPredictorModel::invoke() {
325 ATRACE_BEGIN("TfLiteMotionPredictorModel::invoke");
326 TfLiteStatus result = mRunner->Invoke();
327 ATRACE_END();
328
329 if (result != kTfLiteOk) {
330 return false;
331 }
332
333 // Invoke() might reallocate tensors, so they need to be reattached.
334 attachInputTensors();
335 attachOutputTensors();
336
337 if (outputR().size() != outputPhi().size() || outputR().size() != outputPressure().size()) {
338 LOG_ALWAYS_FATAL("Output size mismatch: (r: %zu, phi: %zu, pressure: %zu)",
339 outputR().size(), outputPhi().size(), outputPressure().size());
340 }
341
342 return true;
343}
344
345size_t TfLiteMotionPredictorModel::inputLength() const {
346 return getTensorBuffer<const float>(mInputR).size();
347}
348
349std::span<float> TfLiteMotionPredictorModel::inputR() {
350 return getTensorBuffer<float>(mInputR);
351}
352
353std::span<float> TfLiteMotionPredictorModel::inputPhi() {
354 return getTensorBuffer<float>(mInputPhi);
355}
356
357std::span<float> TfLiteMotionPredictorModel::inputPressure() {
358 return getTensorBuffer<float>(mInputPressure);
359}
360
361std::span<float> TfLiteMotionPredictorModel::inputTilt() {
362 return getTensorBuffer<float>(mInputTilt);
363}
364
365std::span<float> TfLiteMotionPredictorModel::inputOrientation() {
366 return getTensorBuffer<float>(mInputOrientation);
367}
368
369std::span<const float> TfLiteMotionPredictorModel::outputR() const {
370 return getTensorBuffer<const float>(mOutputR);
371}
372
373std::span<const float> TfLiteMotionPredictorModel::outputPhi() const {
374 return getTensorBuffer<const float>(mOutputPhi);
375}
376
377std::span<const float> TfLiteMotionPredictorModel::outputPressure() const {
378 return getTensorBuffer<const float>(mOutputPressure);
379}
380
381} // namespace android