blob: 8d10ff56b0b9041afc540084ec7fdab7dbf1b068 [file] [log] [blame]
Philip Quinn8f953ab2022-12-06 15:37:07 -08001/*
2 * Copyright (C) 2023 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "TfLiteMotionPredictor"
18#include <input/TfLiteMotionPredictor.h>
19
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
23
Philip Quinn8f953ab2022-12-06 15:37:07 -080024#include <algorithm>
25#include <cmath>
26#include <cstddef>
27#include <cstdint>
Philip Quinn8f953ab2022-12-06 15:37:07 -080028#include <memory>
29#include <span>
Philip Quinn8f953ab2022-12-06 15:37:07 -080030#include <type_traits>
31#include <utility>
32
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -080033#include <android-base/file.h>
Philip Quinncb3229a2023-02-08 22:50:59 -080034#include <android-base/logging.h>
35#include <android-base/mapped_file.h>
Philip Quinn8f953ab2022-12-06 15:37:07 -080036#define ATRACE_TAG ATRACE_TAG_INPUT
37#include <cutils/trace.h>
38#include <log/log.h>
39
40#include "tensorflow/lite/core/api/error_reporter.h"
Philip Quinnda6a4482023-02-07 10:09:57 -080041#include "tensorflow/lite/core/api/op_resolver.h"
Philip Quinn8f953ab2022-12-06 15:37:07 -080042#include "tensorflow/lite/interpreter.h"
Philip Quinnda6a4482023-02-07 10:09:57 -080043#include "tensorflow/lite/kernels/builtin_op_kernels.h"
Philip Quinn8f953ab2022-12-06 15:37:07 -080044#include "tensorflow/lite/model.h"
Philip Quinnda6a4482023-02-07 10:09:57 -080045#include "tensorflow/lite/mutable_op_resolver.h"
Philip Quinn8f953ab2022-12-06 15:37:07 -080046
47namespace android {
48namespace {
49
// Key of the model signature whose runner drives inference.
constexpr char SIGNATURE_KEY[]{"serving_default"};

// Names of the tensors the signature consumes.
constexpr char INPUT_R[]{"r"};
constexpr char INPUT_PHI[]{"phi"};
constexpr char INPUT_PRESSURE[]{"pressure"};
constexpr char INPUT_TILT[]{"tilt"};
constexpr char INPUT_ORIENTATION[]{"orientation"};

// Names of the tensors the signature produces.
constexpr char OUTPUT_R[]{"r"};
constexpr char OUTPUT_PHI[]{"phi"};
constexpr char OUTPUT_PRESSURE[]{"pressure"};
63
Siarhei Vishniakouc065d7b2023-03-02 14:06:29 -080064// Ideally, we would just use std::filesystem::exists here, but it requires libc++fs, which causes
65// build issues in other parts of the system.
66#if defined(__ANDROID__)
67bool fileExists(const char* filename) {
68 struct stat buffer;
69 return stat(filename, &buffer) == 0;
70}
71#endif
72
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -080073std::string getModelPath() {
74#if defined(__ANDROID__)
Siarhei Vishniakouc065d7b2023-03-02 14:06:29 -080075 static const char* oemModel = "/vendor/etc/motion_predictor_model.fb";
76 if (fileExists(oemModel)) {
77 return oemModel;
78 }
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -080079 return "/system/etc/motion_predictor_model.fb";
80#else
81 return base::GetExecutableDirectory() + "/motion_predictor_model.fb";
82#endif
83}
84
Philip Quinn8f953ab2022-12-06 15:37:07 -080085// A TFLite ErrorReporter that logs to logcat.
86class LoggingErrorReporter : public tflite::ErrorReporter {
87public:
88 int Report(const char* format, va_list args) override {
89 return LOG_PRI_VA(ANDROID_LOG_ERROR, LOG_TAG, format, args);
90 }
91};
92
93// Searches a runner for an input tensor.
94TfLiteTensor* findInputTensor(const char* name, tflite::SignatureRunner* runner) {
95 TfLiteTensor* tensor = runner->input_tensor(name);
96 LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find input tensor '%s'", name);
97 return tensor;
98}
99
100// Searches a runner for an output tensor.
101const TfLiteTensor* findOutputTensor(const char* name, tflite::SignatureRunner* runner) {
102 const TfLiteTensor* tensor = runner->output_tensor(name);
103 LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find output tensor '%s'", name);
104 return tensor;
105}
106
107// Returns the buffer for a tensor of type T.
108template <typename T>
109std::span<T> getTensorBuffer(typename std::conditional<std::is_const<T>::value, const TfLiteTensor*,
110 TfLiteTensor*>::type tensor) {
111 LOG_ALWAYS_FATAL_IF(!tensor);
112
113 const TfLiteType type = tflite::typeToTfLiteType<typename std::remove_cv<T>::type>();
114 LOG_ALWAYS_FATAL_IF(tensor->type != type, "Unexpected type for '%s' tensor: %s (expected %s)",
115 tensor->name, TfLiteTypeGetName(tensor->type), TfLiteTypeGetName(type));
116
117 LOG_ALWAYS_FATAL_IF(!tensor->data.data);
Ryan Prichard841b07c2023-10-05 14:52:00 -0700118 return std::span<T>(reinterpret_cast<T*>(tensor->data.data), tensor->bytes / sizeof(T));
Philip Quinn8f953ab2022-12-06 15:37:07 -0800119}
120
121// Verifies that a tensor exists and has an underlying buffer of type T.
122template <typename T>
123void checkTensor(const TfLiteTensor* tensor) {
124 LOG_ALWAYS_FATAL_IF(!tensor);
125
126 const auto buffer = getTensorBuffer<const T>(tensor);
127 LOG_ALWAYS_FATAL_IF(buffer.empty(), "No buffer for tensor '%s'", tensor->name);
128}
129
Philip Quinnda6a4482023-02-07 10:09:57 -0800130std::unique_ptr<tflite::OpResolver> createOpResolver() {
131 auto resolver = std::make_unique<tflite::MutableOpResolver>();
132 resolver->AddBuiltin(::tflite::BuiltinOperator_CONCATENATION,
133 ::tflite::ops::builtin::Register_CONCATENATION());
134 resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
135 ::tflite::ops::builtin::Register_FULLY_CONNECTED());
136 return resolver;
137}
138
Philip Quinn8f953ab2022-12-06 15:37:07 -0800139} // namespace
140
Philip Quinn9b8926e2023-01-31 14:50:02 -0800141TfLiteMotionPredictorBuffers::TfLiteMotionPredictorBuffers(size_t inputLength)
142 : mInputR(inputLength, 0),
143 mInputPhi(inputLength, 0),
144 mInputPressure(inputLength, 0),
145 mInputTilt(inputLength, 0),
146 mInputOrientation(inputLength, 0) {
Philip Quinn8f953ab2022-12-06 15:37:07 -0800147 LOG_ALWAYS_FATAL_IF(inputLength == 0, "Buffer input size must be greater than 0");
Philip Quinn8f953ab2022-12-06 15:37:07 -0800148}
149
150void TfLiteMotionPredictorBuffers::reset() {
151 std::fill(mInputR.begin(), mInputR.end(), 0);
152 std::fill(mInputPhi.begin(), mInputPhi.end(), 0);
153 std::fill(mInputPressure.begin(), mInputPressure.end(), 0);
154 std::fill(mInputTilt.begin(), mInputTilt.end(), 0);
155 std::fill(mInputOrientation.begin(), mInputOrientation.end(), 0);
156 mAxisFrom.reset();
157 mAxisTo.reset();
158}
159
160void TfLiteMotionPredictorBuffers::copyTo(TfLiteMotionPredictorModel& model) const {
161 LOG_ALWAYS_FATAL_IF(mInputR.size() != model.inputLength(),
162 "Buffer length %zu doesn't match model input length %zu", mInputR.size(),
163 model.inputLength());
164 LOG_ALWAYS_FATAL_IF(!isReady(), "Buffers are incomplete");
165
166 std::copy(mInputR.begin(), mInputR.end(), model.inputR().begin());
167 std::copy(mInputPhi.begin(), mInputPhi.end(), model.inputPhi().begin());
168 std::copy(mInputPressure.begin(), mInputPressure.end(), model.inputPressure().begin());
169 std::copy(mInputTilt.begin(), mInputTilt.end(), model.inputTilt().begin());
170 std::copy(mInputOrientation.begin(), mInputOrientation.end(), model.inputOrientation().begin());
171}
172
173void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp,
174 const TfLiteMotionPredictorSample sample) {
175 // Convert the sample (x, y) into polar (r, φ) based on a reference axis
176 // from the preceding two points (mAxisFrom/mAxisTo).
177
178 mTimestamp = timestamp;
179
180 if (!mAxisTo) { // First point.
181 mAxisTo = sample;
182 return;
183 }
184
185 // Vector from the last point to the current sample point.
186 const TfLiteMotionPredictorSample::Point v = sample.position - mAxisTo->position;
187
188 const float r = std::hypot(v.x, v.y);
189 float phi = 0;
190 float orientation = 0;
191
192 // Ignore the sample if there is no movement. These samples can occur when there's change to a
193 // property other than the coordinates and pollute the input to the model.
194 if (r == 0) {
195 return;
196 }
197
198 if (!mAxisFrom) { // Second point.
199 // We can only determine the distance from the first point, and not any
200 // angle. However, if the second point forms an axis, the orientation can
201 // be transformed relative to that axis.
202 const float axisPhi = std::atan2(v.y, v.x);
203 // A MotionEvent's orientation is measured clockwise from the vertical
204 // axis, but axisPhi is measured counter-clockwise from the horizontal
205 // axis.
206 orientation = M_PI_2 - sample.orientation - axisPhi;
207 } else {
208 const TfLiteMotionPredictorSample::Point axis = mAxisTo->position - mAxisFrom->position;
209 const float axisPhi = std::atan2(axis.y, axis.x);
210 phi = std::atan2(v.y, v.x) - axisPhi;
211
212 if (std::hypot(axis.x, axis.y) > 0) {
213 // See note above.
214 orientation = M_PI_2 - sample.orientation - axisPhi;
215 }
216 }
217
218 // Update the axis for the next point.
219 mAxisFrom = mAxisTo;
220 mAxisTo = sample;
221
222 // Push the current sample onto the end of the input buffers.
Philip Quinn9b8926e2023-01-31 14:50:02 -0800223 mInputR.pushBack(r);
224 mInputPhi.pushBack(phi);
225 mInputPressure.pushBack(sample.pressure);
226 mInputTilt.pushBack(sample.tilt);
227 mInputOrientation.pushBack(orientation);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800228}
229
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -0800230std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create() {
231 const std::string modelPath = getModelPath();
Siarhei Vishniakouc065d7b2023-03-02 14:06:29 -0800232 android::base::unique_fd fd(open(modelPath.c_str(), O_RDONLY));
Philip Quinncb3229a2023-02-08 22:50:59 -0800233 if (fd == -1) {
234 PLOG(FATAL) << "Could not read model from " << modelPath;
235 }
Philip Quinn8f953ab2022-12-06 15:37:07 -0800236
Philip Quinncb3229a2023-02-08 22:50:59 -0800237 const off_t fdSize = lseek(fd, 0, SEEK_END);
238 if (fdSize == -1) {
239 PLOG(FATAL) << "Failed to determine file size";
240 }
241
242 std::unique_ptr<android::base::MappedFile> modelBuffer =
243 android::base::MappedFile::FromFd(fd, /*offset=*/0, fdSize, PROT_READ);
244 if (!modelBuffer) {
245 PLOG(FATAL) << "Failed to mmap model";
246 }
Philip Quinn8f953ab2022-12-06 15:37:07 -0800247
248 return std::unique_ptr<TfLiteMotionPredictorModel>(
Philip Quinncb3229a2023-02-08 22:50:59 -0800249 new TfLiteMotionPredictorModel(std::move(modelBuffer)));
Philip Quinn8f953ab2022-12-06 15:37:07 -0800250}
251
Philip Quinncb3229a2023-02-08 22:50:59 -0800252TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(
253 std::unique_ptr<android::base::MappedFile> model)
Philip Quinn8f953ab2022-12-06 15:37:07 -0800254 : mFlatBuffer(std::move(model)) {
Philip Quinncb3229a2023-02-08 22:50:59 -0800255 CHECK(mFlatBuffer);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800256 mErrorReporter = std::make_unique<LoggingErrorReporter>();
Philip Quinncb3229a2023-02-08 22:50:59 -0800257 mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(),
258 mFlatBuffer->size(),
Philip Quinn8f953ab2022-12-06 15:37:07 -0800259 /*extra_verifier=*/nullptr,
260 mErrorReporter.get());
261 LOG_ALWAYS_FATAL_IF(!mModel);
262
Philip Quinnda6a4482023-02-07 10:09:57 -0800263 auto resolver = createOpResolver();
264 tflite::InterpreterBuilder builder(*mModel, *resolver);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800265
266 if (builder(&mInterpreter) != kTfLiteOk || !mInterpreter) {
267 LOG_ALWAYS_FATAL("Failed to build interpreter");
268 }
269
270 mRunner = mInterpreter->GetSignatureRunner(SIGNATURE_KEY);
271 LOG_ALWAYS_FATAL_IF(!mRunner, "Failed to find runner for signature '%s'", SIGNATURE_KEY);
272
273 allocateTensors();
274}
275
Philip Quinnda6a4482023-02-07 10:09:57 -0800276TfLiteMotionPredictorModel::~TfLiteMotionPredictorModel() {}
277
Philip Quinn8f953ab2022-12-06 15:37:07 -0800278void TfLiteMotionPredictorModel::allocateTensors() {
279 if (mRunner->AllocateTensors() != kTfLiteOk) {
280 LOG_ALWAYS_FATAL("Failed to allocate tensors");
281 }
282
283 attachInputTensors();
284 attachOutputTensors();
285
286 checkTensor<float>(mInputR);
287 checkTensor<float>(mInputPhi);
288 checkTensor<float>(mInputPressure);
289 checkTensor<float>(mInputTilt);
290 checkTensor<float>(mInputOrientation);
291 checkTensor<float>(mOutputR);
292 checkTensor<float>(mOutputPhi);
293 checkTensor<float>(mOutputPressure);
294
295 const auto checkInputTensorSize = [this](const TfLiteTensor* tensor) {
296 const size_t size = getTensorBuffer<const float>(tensor).size();
297 LOG_ALWAYS_FATAL_IF(size != inputLength(),
298 "Tensor '%s' length %zu does not match input length %zu", tensor->name,
299 size, inputLength());
300 };
301
302 checkInputTensorSize(mInputR);
303 checkInputTensorSize(mInputPhi);
304 checkInputTensorSize(mInputPressure);
305 checkInputTensorSize(mInputTilt);
306 checkInputTensorSize(mInputOrientation);
307}
308
309void TfLiteMotionPredictorModel::attachInputTensors() {
310 mInputR = findInputTensor(INPUT_R, mRunner);
311 mInputPhi = findInputTensor(INPUT_PHI, mRunner);
312 mInputPressure = findInputTensor(INPUT_PRESSURE, mRunner);
313 mInputTilt = findInputTensor(INPUT_TILT, mRunner);
314 mInputOrientation = findInputTensor(INPUT_ORIENTATION, mRunner);
315}
316
317void TfLiteMotionPredictorModel::attachOutputTensors() {
318 mOutputR = findOutputTensor(OUTPUT_R, mRunner);
319 mOutputPhi = findOutputTensor(OUTPUT_PHI, mRunner);
320 mOutputPressure = findOutputTensor(OUTPUT_PRESSURE, mRunner);
321}
322
323bool TfLiteMotionPredictorModel::invoke() {
324 ATRACE_BEGIN("TfLiteMotionPredictorModel::invoke");
325 TfLiteStatus result = mRunner->Invoke();
326 ATRACE_END();
327
328 if (result != kTfLiteOk) {
329 return false;
330 }
331
332 // Invoke() might reallocate tensors, so they need to be reattached.
333 attachInputTensors();
334 attachOutputTensors();
335
336 if (outputR().size() != outputPhi().size() || outputR().size() != outputPressure().size()) {
337 LOG_ALWAYS_FATAL("Output size mismatch: (r: %zu, phi: %zu, pressure: %zu)",
338 outputR().size(), outputPhi().size(), outputPressure().size());
339 }
340
341 return true;
342}
343
344size_t TfLiteMotionPredictorModel::inputLength() const {
345 return getTensorBuffer<const float>(mInputR).size();
346}
347
Cody Heinerdbd14eb2023-03-30 18:41:45 -0700348size_t TfLiteMotionPredictorModel::outputLength() const {
349 return getTensorBuffer<const float>(mOutputR).size();
350}
351
Philip Quinn8f953ab2022-12-06 15:37:07 -0800352std::span<float> TfLiteMotionPredictorModel::inputR() {
353 return getTensorBuffer<float>(mInputR);
354}
355
356std::span<float> TfLiteMotionPredictorModel::inputPhi() {
357 return getTensorBuffer<float>(mInputPhi);
358}
359
360std::span<float> TfLiteMotionPredictorModel::inputPressure() {
361 return getTensorBuffer<float>(mInputPressure);
362}
363
364std::span<float> TfLiteMotionPredictorModel::inputTilt() {
365 return getTensorBuffer<float>(mInputTilt);
366}
367
368std::span<float> TfLiteMotionPredictorModel::inputOrientation() {
369 return getTensorBuffer<float>(mInputOrientation);
370}
371
372std::span<const float> TfLiteMotionPredictorModel::outputR() const {
373 return getTensorBuffer<const float>(mOutputR);
374}
375
376std::span<const float> TfLiteMotionPredictorModel::outputPhi() const {
377 return getTensorBuffer<const float>(mOutputPhi);
378}
379
380std::span<const float> TfLiteMotionPredictorModel::outputPressure() const {
381 return getTensorBuffer<const float>(mOutputPressure);
382}
383
384} // namespace android