blob: 9f4aaa8337db1554a3ceaf9c6dad1efe9d94ac7a [file] [log] [blame]
Philip Quinn8f953ab2022-12-06 15:37:07 -08001/*
2 * Copyright (C) 2023 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "TfLiteMotionPredictor"
#include <input/TfLiteMotionPredictor.h>

#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <span>
#include <type_traits>
#include <utility>

#include <android-base/file.h>
#include <android-base/logging.h>
#include <android-base/mapped_file.h>
#define ATRACE_TAG ATRACE_TAG_INPUT
#include <cutils/trace.h>
#include <log/log.h>
#include <utils/Timers.h>

#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/mutable_op_resolver.h"

#include "tinyxml2.h"
49
Philip Quinn8f953ab2022-12-06 15:37:07 -080050namespace android {
51namespace {
52
// Key of the model signature whose runner drives inference.
constexpr char SIGNATURE_KEY[]{"serving_default"};

// Names of the tensors fed into the model signature.
constexpr char INPUT_R[]{"r"};
constexpr char INPUT_PHI[]{"phi"};
constexpr char INPUT_PRESSURE[]{"pressure"};
constexpr char INPUT_TILT[]{"tilt"};
constexpr char INPUT_ORIENTATION[]{"orientation"};

// Names of the tensors produced by the model signature.
constexpr char OUTPUT_R[]{"r"};
constexpr char OUTPUT_PHI[]{"phi"};
constexpr char OUTPUT_PRESSURE[]{"pressure"};
66
// Ideally, we would just use std::filesystem::exists here, but it requires libc++fs, which causes
// build issues in other parts of the system.
#if defined(__ANDROID__)
// Returns true when stat(2) succeeds for |filename|. A false result means the path does not exist
// or could not be stat'ed (e.g. a permission error on a parent directory).
// stat() and struct stat are declared in <sys/stat.h>, which must be included explicitly rather
// than relied upon transitively.
bool fileExists(const char* filename) {
    struct stat buffer {};
    return stat(filename, &buffer) == 0;
}
#endif
75
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -080076std::string getModelPath() {
77#if defined(__ANDROID__)
Philip Quinnf84fa492023-06-26 14:15:15 -070078 static const char* oemModel = "/vendor/etc/motion_predictor_model.tflite";
Siarhei Vishniakouc065d7b2023-03-02 14:06:29 -080079 if (fileExists(oemModel)) {
80 return oemModel;
81 }
Philip Quinnf84fa492023-06-26 14:15:15 -070082 return "/system/etc/motion_predictor_model.tflite";
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -080083#else
Philip Quinnf84fa492023-06-26 14:15:15 -070084 return base::GetExecutableDirectory() + "/motion_predictor_model.tflite";
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -080085#endif
86}
87
Philip Quinnf84fa492023-06-26 14:15:15 -070088std::string getConfigPath() {
89 // The config file should be alongside the model file.
90 return base::Dirname(getModelPath()) + "/motion_predictor_config.xml";
91}
92
93int64_t parseXMLInt64(const tinyxml2::XMLElement& configRoot, const char* elementName) {
94 const tinyxml2::XMLElement* element = configRoot.FirstChildElement(elementName);
95 LOG_ALWAYS_FATAL_IF(!element, "Could not find '%s' element", elementName);
96
97 int64_t value = 0;
98 LOG_ALWAYS_FATAL_IF(element->QueryInt64Text(&value) != tinyxml2::XML_SUCCESS,
99 "Failed to parse %s: %s", elementName, element->GetText());
100 return value;
101}
102
Philip Quinn8f953ab2022-12-06 15:37:07 -0800103// A TFLite ErrorReporter that logs to logcat.
104class LoggingErrorReporter : public tflite::ErrorReporter {
105public:
106 int Report(const char* format, va_list args) override {
107 return LOG_PRI_VA(ANDROID_LOG_ERROR, LOG_TAG, format, args);
108 }
109};
110
111// Searches a runner for an input tensor.
112TfLiteTensor* findInputTensor(const char* name, tflite::SignatureRunner* runner) {
113 TfLiteTensor* tensor = runner->input_tensor(name);
114 LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find input tensor '%s'", name);
115 return tensor;
116}
117
118// Searches a runner for an output tensor.
119const TfLiteTensor* findOutputTensor(const char* name, tflite::SignatureRunner* runner) {
120 const TfLiteTensor* tensor = runner->output_tensor(name);
121 LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find output tensor '%s'", name);
122 return tensor;
123}
124
125// Returns the buffer for a tensor of type T.
126template <typename T>
127std::span<T> getTensorBuffer(typename std::conditional<std::is_const<T>::value, const TfLiteTensor*,
128 TfLiteTensor*>::type tensor) {
129 LOG_ALWAYS_FATAL_IF(!tensor);
130
131 const TfLiteType type = tflite::typeToTfLiteType<typename std::remove_cv<T>::type>();
132 LOG_ALWAYS_FATAL_IF(tensor->type != type, "Unexpected type for '%s' tensor: %s (expected %s)",
133 tensor->name, TfLiteTypeGetName(tensor->type), TfLiteTypeGetName(type));
134
135 LOG_ALWAYS_FATAL_IF(!tensor->data.data);
136 return {reinterpret_cast<T*>(tensor->data.data),
137 static_cast<typename std::span<T>::index_type>(tensor->bytes / sizeof(T))};
138}
139
140// Verifies that a tensor exists and has an underlying buffer of type T.
141template <typename T>
142void checkTensor(const TfLiteTensor* tensor) {
143 LOG_ALWAYS_FATAL_IF(!tensor);
144
145 const auto buffer = getTensorBuffer<const T>(tensor);
146 LOG_ALWAYS_FATAL_IF(buffer.empty(), "No buffer for tensor '%s'", tensor->name);
147}
148
Philip Quinnda6a4482023-02-07 10:09:57 -0800149std::unique_ptr<tflite::OpResolver> createOpResolver() {
150 auto resolver = std::make_unique<tflite::MutableOpResolver>();
151 resolver->AddBuiltin(::tflite::BuiltinOperator_CONCATENATION,
152 ::tflite::ops::builtin::Register_CONCATENATION());
153 resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
154 ::tflite::ops::builtin::Register_FULLY_CONNECTED());
155 return resolver;
156}
157
Philip Quinn8f953ab2022-12-06 15:37:07 -0800158} // namespace
159
Philip Quinn9b8926e2023-01-31 14:50:02 -0800160TfLiteMotionPredictorBuffers::TfLiteMotionPredictorBuffers(size_t inputLength)
161 : mInputR(inputLength, 0),
162 mInputPhi(inputLength, 0),
163 mInputPressure(inputLength, 0),
164 mInputTilt(inputLength, 0),
165 mInputOrientation(inputLength, 0) {
Philip Quinn8f953ab2022-12-06 15:37:07 -0800166 LOG_ALWAYS_FATAL_IF(inputLength == 0, "Buffer input size must be greater than 0");
Philip Quinn8f953ab2022-12-06 15:37:07 -0800167}
168
169void TfLiteMotionPredictorBuffers::reset() {
170 std::fill(mInputR.begin(), mInputR.end(), 0);
171 std::fill(mInputPhi.begin(), mInputPhi.end(), 0);
172 std::fill(mInputPressure.begin(), mInputPressure.end(), 0);
173 std::fill(mInputTilt.begin(), mInputTilt.end(), 0);
174 std::fill(mInputOrientation.begin(), mInputOrientation.end(), 0);
175 mAxisFrom.reset();
176 mAxisTo.reset();
177}
178
179void TfLiteMotionPredictorBuffers::copyTo(TfLiteMotionPredictorModel& model) const {
180 LOG_ALWAYS_FATAL_IF(mInputR.size() != model.inputLength(),
181 "Buffer length %zu doesn't match model input length %zu", mInputR.size(),
182 model.inputLength());
183 LOG_ALWAYS_FATAL_IF(!isReady(), "Buffers are incomplete");
184
185 std::copy(mInputR.begin(), mInputR.end(), model.inputR().begin());
186 std::copy(mInputPhi.begin(), mInputPhi.end(), model.inputPhi().begin());
187 std::copy(mInputPressure.begin(), mInputPressure.end(), model.inputPressure().begin());
188 std::copy(mInputTilt.begin(), mInputTilt.end(), model.inputTilt().begin());
189 std::copy(mInputOrientation.begin(), mInputOrientation.end(), model.inputOrientation().begin());
190}
191
192void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp,
193 const TfLiteMotionPredictorSample sample) {
194 // Convert the sample (x, y) into polar (r, φ) based on a reference axis
195 // from the preceding two points (mAxisFrom/mAxisTo).
196
197 mTimestamp = timestamp;
198
199 if (!mAxisTo) { // First point.
200 mAxisTo = sample;
201 return;
202 }
203
204 // Vector from the last point to the current sample point.
205 const TfLiteMotionPredictorSample::Point v = sample.position - mAxisTo->position;
206
207 const float r = std::hypot(v.x, v.y);
208 float phi = 0;
209 float orientation = 0;
210
211 // Ignore the sample if there is no movement. These samples can occur when there's change to a
212 // property other than the coordinates and pollute the input to the model.
213 if (r == 0) {
214 return;
215 }
216
217 if (!mAxisFrom) { // Second point.
218 // We can only determine the distance from the first point, and not any
219 // angle. However, if the second point forms an axis, the orientation can
220 // be transformed relative to that axis.
221 const float axisPhi = std::atan2(v.y, v.x);
222 // A MotionEvent's orientation is measured clockwise from the vertical
223 // axis, but axisPhi is measured counter-clockwise from the horizontal
224 // axis.
225 orientation = M_PI_2 - sample.orientation - axisPhi;
226 } else {
227 const TfLiteMotionPredictorSample::Point axis = mAxisTo->position - mAxisFrom->position;
228 const float axisPhi = std::atan2(axis.y, axis.x);
229 phi = std::atan2(v.y, v.x) - axisPhi;
230
231 if (std::hypot(axis.x, axis.y) > 0) {
232 // See note above.
233 orientation = M_PI_2 - sample.orientation - axisPhi;
234 }
235 }
236
237 // Update the axis for the next point.
238 mAxisFrom = mAxisTo;
239 mAxisTo = sample;
240
241 // Push the current sample onto the end of the input buffers.
Philip Quinn9b8926e2023-01-31 14:50:02 -0800242 mInputR.pushBack(r);
243 mInputPhi.pushBack(phi);
244 mInputPressure.pushBack(sample.pressure);
245 mInputTilt.pushBack(sample.tilt);
246 mInputOrientation.pushBack(orientation);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800247}
248
Siarhei Vishniakoufd0a68e2023-02-28 13:25:36 -0800249std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create() {
250 const std::string modelPath = getModelPath();
Siarhei Vishniakouc065d7b2023-03-02 14:06:29 -0800251 android::base::unique_fd fd(open(modelPath.c_str(), O_RDONLY));
Philip Quinncb3229a2023-02-08 22:50:59 -0800252 if (fd == -1) {
253 PLOG(FATAL) << "Could not read model from " << modelPath;
254 }
Philip Quinn8f953ab2022-12-06 15:37:07 -0800255
Philip Quinncb3229a2023-02-08 22:50:59 -0800256 const off_t fdSize = lseek(fd, 0, SEEK_END);
257 if (fdSize == -1) {
258 PLOG(FATAL) << "Failed to determine file size";
259 }
260
261 std::unique_ptr<android::base::MappedFile> modelBuffer =
262 android::base::MappedFile::FromFd(fd, /*offset=*/0, fdSize, PROT_READ);
263 if (!modelBuffer) {
264 PLOG(FATAL) << "Failed to mmap model";
265 }
Philip Quinn8f953ab2022-12-06 15:37:07 -0800266
Philip Quinnf84fa492023-06-26 14:15:15 -0700267 const std::string configPath = getConfigPath();
268 tinyxml2::XMLDocument configDocument;
269 LOG_ALWAYS_FATAL_IF(configDocument.LoadFile(configPath.c_str()) != tinyxml2::XML_SUCCESS,
270 "Failed to load config file from %s", configPath.c_str());
271
272 // Parse configuration file.
273 const tinyxml2::XMLElement* configRoot = configDocument.FirstChildElement("motion-predictor");
274 LOG_ALWAYS_FATAL_IF(!configRoot);
275 const nsecs_t predictionInterval = parseXMLInt64(*configRoot, "prediction-interval");
276
Philip Quinn8f953ab2022-12-06 15:37:07 -0800277 return std::unique_ptr<TfLiteMotionPredictorModel>(
Philip Quinnf84fa492023-06-26 14:15:15 -0700278 new TfLiteMotionPredictorModel(std::move(modelBuffer), predictionInterval));
Philip Quinn8f953ab2022-12-06 15:37:07 -0800279}
280
Philip Quinncb3229a2023-02-08 22:50:59 -0800281TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(
Philip Quinnf84fa492023-06-26 14:15:15 -0700282 std::unique_ptr<android::base::MappedFile> model, nsecs_t predictionInterval)
283 : mFlatBuffer(std::move(model)), mPredictionInterval(predictionInterval) {
Philip Quinncb3229a2023-02-08 22:50:59 -0800284 CHECK(mFlatBuffer);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800285 mErrorReporter = std::make_unique<LoggingErrorReporter>();
Philip Quinncb3229a2023-02-08 22:50:59 -0800286 mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(),
287 mFlatBuffer->size(),
Philip Quinn8f953ab2022-12-06 15:37:07 -0800288 /*extra_verifier=*/nullptr,
289 mErrorReporter.get());
290 LOG_ALWAYS_FATAL_IF(!mModel);
291
Philip Quinnda6a4482023-02-07 10:09:57 -0800292 auto resolver = createOpResolver();
293 tflite::InterpreterBuilder builder(*mModel, *resolver);
Philip Quinn8f953ab2022-12-06 15:37:07 -0800294
295 if (builder(&mInterpreter) != kTfLiteOk || !mInterpreter) {
296 LOG_ALWAYS_FATAL("Failed to build interpreter");
297 }
298
299 mRunner = mInterpreter->GetSignatureRunner(SIGNATURE_KEY);
300 LOG_ALWAYS_FATAL_IF(!mRunner, "Failed to find runner for signature '%s'", SIGNATURE_KEY);
301
302 allocateTensors();
303}
304
Philip Quinnda6a4482023-02-07 10:09:57 -0800305TfLiteMotionPredictorModel::~TfLiteMotionPredictorModel() {}
306
Philip Quinn8f953ab2022-12-06 15:37:07 -0800307void TfLiteMotionPredictorModel::allocateTensors() {
308 if (mRunner->AllocateTensors() != kTfLiteOk) {
309 LOG_ALWAYS_FATAL("Failed to allocate tensors");
310 }
311
312 attachInputTensors();
313 attachOutputTensors();
314
315 checkTensor<float>(mInputR);
316 checkTensor<float>(mInputPhi);
317 checkTensor<float>(mInputPressure);
318 checkTensor<float>(mInputTilt);
319 checkTensor<float>(mInputOrientation);
320 checkTensor<float>(mOutputR);
321 checkTensor<float>(mOutputPhi);
322 checkTensor<float>(mOutputPressure);
323
324 const auto checkInputTensorSize = [this](const TfLiteTensor* tensor) {
325 const size_t size = getTensorBuffer<const float>(tensor).size();
326 LOG_ALWAYS_FATAL_IF(size != inputLength(),
327 "Tensor '%s' length %zu does not match input length %zu", tensor->name,
328 size, inputLength());
329 };
330
331 checkInputTensorSize(mInputR);
332 checkInputTensorSize(mInputPhi);
333 checkInputTensorSize(mInputPressure);
334 checkInputTensorSize(mInputTilt);
335 checkInputTensorSize(mInputOrientation);
336}
337
338void TfLiteMotionPredictorModel::attachInputTensors() {
339 mInputR = findInputTensor(INPUT_R, mRunner);
340 mInputPhi = findInputTensor(INPUT_PHI, mRunner);
341 mInputPressure = findInputTensor(INPUT_PRESSURE, mRunner);
342 mInputTilt = findInputTensor(INPUT_TILT, mRunner);
343 mInputOrientation = findInputTensor(INPUT_ORIENTATION, mRunner);
344}
345
346void TfLiteMotionPredictorModel::attachOutputTensors() {
347 mOutputR = findOutputTensor(OUTPUT_R, mRunner);
348 mOutputPhi = findOutputTensor(OUTPUT_PHI, mRunner);
349 mOutputPressure = findOutputTensor(OUTPUT_PRESSURE, mRunner);
350}
351
352bool TfLiteMotionPredictorModel::invoke() {
353 ATRACE_BEGIN("TfLiteMotionPredictorModel::invoke");
354 TfLiteStatus result = mRunner->Invoke();
355 ATRACE_END();
356
357 if (result != kTfLiteOk) {
358 return false;
359 }
360
361 // Invoke() might reallocate tensors, so they need to be reattached.
362 attachInputTensors();
363 attachOutputTensors();
364
365 if (outputR().size() != outputPhi().size() || outputR().size() != outputPressure().size()) {
366 LOG_ALWAYS_FATAL("Output size mismatch: (r: %zu, phi: %zu, pressure: %zu)",
367 outputR().size(), outputPhi().size(), outputPressure().size());
368 }
369
370 return true;
371}
372
373size_t TfLiteMotionPredictorModel::inputLength() const {
374 return getTensorBuffer<const float>(mInputR).size();
375}
376
Cody Heinerdbd14eb2023-03-30 18:41:45 -0700377size_t TfLiteMotionPredictorModel::outputLength() const {
378 return getTensorBuffer<const float>(mOutputR).size();
379}
380
Philip Quinn8f953ab2022-12-06 15:37:07 -0800381std::span<float> TfLiteMotionPredictorModel::inputR() {
382 return getTensorBuffer<float>(mInputR);
383}
384
385std::span<float> TfLiteMotionPredictorModel::inputPhi() {
386 return getTensorBuffer<float>(mInputPhi);
387}
388
389std::span<float> TfLiteMotionPredictorModel::inputPressure() {
390 return getTensorBuffer<float>(mInputPressure);
391}
392
393std::span<float> TfLiteMotionPredictorModel::inputTilt() {
394 return getTensorBuffer<float>(mInputTilt);
395}
396
397std::span<float> TfLiteMotionPredictorModel::inputOrientation() {
398 return getTensorBuffer<float>(mInputOrientation);
399}
400
401std::span<const float> TfLiteMotionPredictorModel::outputR() const {
402 return getTensorBuffer<const float>(mOutputR);
403}
404
405std::span<const float> TfLiteMotionPredictorModel::outputPhi() const {
406 return getTensorBuffer<const float>(mOutputPhi);
407}
408
409std::span<const float> TfLiteMotionPredictorModel::outputPressure() const {
410 return getTensorBuffer<const float>(mOutputPressure);
411}
412
413} // namespace android