Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 1 | /* |
| 2 | * Copyright (C) 2023 The Android Open Source Project |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | */ |
| 16 | |
| 17 | #define LOG_TAG "TfLiteMotionPredictor" |
| 18 | #include <input/TfLiteMotionPredictor.h> |
| 19 | |
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
| 23 | |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 24 | #include <algorithm> |
| 25 | #include <cmath> |
| 26 | #include <cstddef> |
| 27 | #include <cstdint> |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 28 | #include <memory> |
| 29 | #include <span> |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 30 | #include <type_traits> |
| 31 | #include <utility> |
| 32 | |
Siarhei Vishniakou | fd0a68e | 2023-02-28 13:25:36 -0800 | [diff] [blame] | 33 | #include <android-base/file.h> |
Philip Quinn | cb3229a | 2023-02-08 22:50:59 -0800 | [diff] [blame] | 34 | #include <android-base/logging.h> |
| 35 | #include <android-base/mapped_file.h> |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 36 | #define ATRACE_TAG ATRACE_TAG_INPUT |
| 37 | #include <cutils/trace.h> |
| 38 | #include <log/log.h> |
Philip Quinn | f84fa49 | 2023-06-26 14:15:15 -0700 | [diff] [blame] | 39 | #include <utils/Timers.h> |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 40 | |
| 41 | #include "tensorflow/lite/core/api/error_reporter.h" |
Philip Quinn | da6a448 | 2023-02-07 10:09:57 -0800 | [diff] [blame] | 42 | #include "tensorflow/lite/core/api/op_resolver.h" |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 43 | #include "tensorflow/lite/interpreter.h" |
Philip Quinn | da6a448 | 2023-02-07 10:09:57 -0800 | [diff] [blame] | 44 | #include "tensorflow/lite/kernels/builtin_op_kernels.h" |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 45 | #include "tensorflow/lite/model.h" |
Philip Quinn | da6a448 | 2023-02-07 10:09:57 -0800 | [diff] [blame] | 46 | #include "tensorflow/lite/mutable_op_resolver.h" |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 47 | |
Philip Quinn | f84fa49 | 2023-06-26 14:15:15 -0700 | [diff] [blame] | 48 | #include "tinyxml2.h" |
| 49 | |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 50 | namespace android { |
| 51 | namespace { |
| 52 | |
// Key of the model signature whose runner is used for inference.
constexpr const char* SIGNATURE_KEY = "serving_default";

// Names of the input tensors looked up in that signature.
constexpr const char* INPUT_R = "r";
constexpr const char* INPUT_PHI = "phi";
constexpr const char* INPUT_PRESSURE = "pressure";
constexpr const char* INPUT_TILT = "tilt";
constexpr const char* INPUT_ORIENTATION = "orientation";

// Names of the output tensors looked up in that signature.
constexpr const char* OUTPUT_R = "r";
constexpr const char* OUTPUT_PHI = "phi";
constexpr const char* OUTPUT_PRESSURE = "pressure";
| 66 | |
// Ideally, we would just use std::filesystem::exists here, but it requires libc++fs, which causes
// build issues in other parts of the system.
#if defined(__ANDROID__)
// Returns true if stat() succeeds for `filename`, i.e. the path exists and is reachable by this
// process. Note that this is also true for directories, not only regular files.
// Requires <sys/stat.h> for struct stat / stat(); do not rely on a transitive include.
bool fileExists(const char* filename) {
    struct stat buffer {};
    return stat(filename, &buffer) == 0;
}
#endif
| 75 | |
Siarhei Vishniakou | fd0a68e | 2023-02-28 13:25:36 -0800 | [diff] [blame] | 76 | std::string getModelPath() { |
| 77 | #if defined(__ANDROID__) |
Philip Quinn | f84fa49 | 2023-06-26 14:15:15 -0700 | [diff] [blame] | 78 | static const char* oemModel = "/vendor/etc/motion_predictor_model.tflite"; |
Siarhei Vishniakou | c065d7b | 2023-03-02 14:06:29 -0800 | [diff] [blame] | 79 | if (fileExists(oemModel)) { |
| 80 | return oemModel; |
| 81 | } |
Philip Quinn | f84fa49 | 2023-06-26 14:15:15 -0700 | [diff] [blame] | 82 | return "/system/etc/motion_predictor_model.tflite"; |
Siarhei Vishniakou | fd0a68e | 2023-02-28 13:25:36 -0800 | [diff] [blame] | 83 | #else |
Philip Quinn | f84fa49 | 2023-06-26 14:15:15 -0700 | [diff] [blame] | 84 | return base::GetExecutableDirectory() + "/motion_predictor_model.tflite"; |
Siarhei Vishniakou | fd0a68e | 2023-02-28 13:25:36 -0800 | [diff] [blame] | 85 | #endif |
| 86 | } |
| 87 | |
Philip Quinn | f84fa49 | 2023-06-26 14:15:15 -0700 | [diff] [blame] | 88 | std::string getConfigPath() { |
| 89 | // The config file should be alongside the model file. |
| 90 | return base::Dirname(getModelPath()) + "/motion_predictor_config.xml"; |
| 91 | } |
| 92 | |
| 93 | int64_t parseXMLInt64(const tinyxml2::XMLElement& configRoot, const char* elementName) { |
| 94 | const tinyxml2::XMLElement* element = configRoot.FirstChildElement(elementName); |
| 95 | LOG_ALWAYS_FATAL_IF(!element, "Could not find '%s' element", elementName); |
| 96 | |
| 97 | int64_t value = 0; |
| 98 | LOG_ALWAYS_FATAL_IF(element->QueryInt64Text(&value) != tinyxml2::XML_SUCCESS, |
| 99 | "Failed to parse %s: %s", elementName, element->GetText()); |
| 100 | return value; |
| 101 | } |
| 102 | |
Philip Quinn | 107ce70 | 2023-07-14 13:07:13 -0700 | [diff] [blame] | 103 | float parseXMLFloat(const tinyxml2::XMLElement& configRoot, const char* elementName) { |
| 104 | const tinyxml2::XMLElement* element = configRoot.FirstChildElement(elementName); |
| 105 | LOG_ALWAYS_FATAL_IF(!element, "Could not find '%s' element", elementName); |
| 106 | |
| 107 | float value = 0; |
| 108 | LOG_ALWAYS_FATAL_IF(element->QueryFloatText(&value) != tinyxml2::XML_SUCCESS, |
| 109 | "Failed to parse %s: %s", elementName, element->GetText()); |
| 110 | return value; |
| 111 | } |
| 112 | |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 113 | // A TFLite ErrorReporter that logs to logcat. |
| 114 | class LoggingErrorReporter : public tflite::ErrorReporter { |
| 115 | public: |
| 116 | int Report(const char* format, va_list args) override { |
| 117 | return LOG_PRI_VA(ANDROID_LOG_ERROR, LOG_TAG, format, args); |
| 118 | } |
| 119 | }; |
| 120 | |
| 121 | // Searches a runner for an input tensor. |
| 122 | TfLiteTensor* findInputTensor(const char* name, tflite::SignatureRunner* runner) { |
| 123 | TfLiteTensor* tensor = runner->input_tensor(name); |
| 124 | LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find input tensor '%s'", name); |
| 125 | return tensor; |
| 126 | } |
| 127 | |
| 128 | // Searches a runner for an output tensor. |
| 129 | const TfLiteTensor* findOutputTensor(const char* name, tflite::SignatureRunner* runner) { |
| 130 | const TfLiteTensor* tensor = runner->output_tensor(name); |
| 131 | LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find output tensor '%s'", name); |
| 132 | return tensor; |
| 133 | } |
| 134 | |
| 135 | // Returns the buffer for a tensor of type T. |
| 136 | template <typename T> |
| 137 | std::span<T> getTensorBuffer(typename std::conditional<std::is_const<T>::value, const TfLiteTensor*, |
| 138 | TfLiteTensor*>::type tensor) { |
| 139 | LOG_ALWAYS_FATAL_IF(!tensor); |
| 140 | |
| 141 | const TfLiteType type = tflite::typeToTfLiteType<typename std::remove_cv<T>::type>(); |
| 142 | LOG_ALWAYS_FATAL_IF(tensor->type != type, "Unexpected type for '%s' tensor: %s (expected %s)", |
| 143 | tensor->name, TfLiteTypeGetName(tensor->type), TfLiteTypeGetName(type)); |
| 144 | |
| 145 | LOG_ALWAYS_FATAL_IF(!tensor->data.data); |
Ryan Prichard | 5a8af50 | 2023-08-31 00:00:47 -0700 | [diff] [blame] | 146 | return std::span<T>(reinterpret_cast<T*>(tensor->data.data), tensor->bytes / sizeof(T)); |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 147 | } |
| 148 | |
| 149 | // Verifies that a tensor exists and has an underlying buffer of type T. |
| 150 | template <typename T> |
| 151 | void checkTensor(const TfLiteTensor* tensor) { |
| 152 | LOG_ALWAYS_FATAL_IF(!tensor); |
| 153 | |
| 154 | const auto buffer = getTensorBuffer<const T>(tensor); |
| 155 | LOG_ALWAYS_FATAL_IF(buffer.empty(), "No buffer for tensor '%s'", tensor->name); |
| 156 | } |
| 157 | |
Philip Quinn | da6a448 | 2023-02-07 10:09:57 -0800 | [diff] [blame] | 158 | std::unique_ptr<tflite::OpResolver> createOpResolver() { |
| 159 | auto resolver = std::make_unique<tflite::MutableOpResolver>(); |
| 160 | resolver->AddBuiltin(::tflite::BuiltinOperator_CONCATENATION, |
| 161 | ::tflite::ops::builtin::Register_CONCATENATION()); |
| 162 | resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED, |
| 163 | ::tflite::ops::builtin::Register_FULLY_CONNECTED()); |
Philip Quinn | 107ce70 | 2023-07-14 13:07:13 -0700 | [diff] [blame] | 164 | resolver->AddBuiltin(::tflite::BuiltinOperator_GELU, ::tflite::ops::builtin::Register_GELU()); |
Philip Quinn | da6a448 | 2023-02-07 10:09:57 -0800 | [diff] [blame] | 165 | return resolver; |
| 166 | } |
| 167 | |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 168 | } // namespace |
| 169 | |
Philip Quinn | 9b8926e | 2023-01-31 14:50:02 -0800 | [diff] [blame] | 170 | TfLiteMotionPredictorBuffers::TfLiteMotionPredictorBuffers(size_t inputLength) |
| 171 | : mInputR(inputLength, 0), |
| 172 | mInputPhi(inputLength, 0), |
| 173 | mInputPressure(inputLength, 0), |
| 174 | mInputTilt(inputLength, 0), |
| 175 | mInputOrientation(inputLength, 0) { |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 176 | LOG_ALWAYS_FATAL_IF(inputLength == 0, "Buffer input size must be greater than 0"); |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 177 | } |
| 178 | |
| 179 | void TfLiteMotionPredictorBuffers::reset() { |
| 180 | std::fill(mInputR.begin(), mInputR.end(), 0); |
| 181 | std::fill(mInputPhi.begin(), mInputPhi.end(), 0); |
| 182 | std::fill(mInputPressure.begin(), mInputPressure.end(), 0); |
| 183 | std::fill(mInputTilt.begin(), mInputTilt.end(), 0); |
| 184 | std::fill(mInputOrientation.begin(), mInputOrientation.end(), 0); |
| 185 | mAxisFrom.reset(); |
| 186 | mAxisTo.reset(); |
| 187 | } |
| 188 | |
| 189 | void TfLiteMotionPredictorBuffers::copyTo(TfLiteMotionPredictorModel& model) const { |
| 190 | LOG_ALWAYS_FATAL_IF(mInputR.size() != model.inputLength(), |
| 191 | "Buffer length %zu doesn't match model input length %zu", mInputR.size(), |
| 192 | model.inputLength()); |
| 193 | LOG_ALWAYS_FATAL_IF(!isReady(), "Buffers are incomplete"); |
| 194 | |
| 195 | std::copy(mInputR.begin(), mInputR.end(), model.inputR().begin()); |
| 196 | std::copy(mInputPhi.begin(), mInputPhi.end(), model.inputPhi().begin()); |
| 197 | std::copy(mInputPressure.begin(), mInputPressure.end(), model.inputPressure().begin()); |
| 198 | std::copy(mInputTilt.begin(), mInputTilt.end(), model.inputTilt().begin()); |
| 199 | std::copy(mInputOrientation.begin(), mInputOrientation.end(), model.inputOrientation().begin()); |
| 200 | } |
| 201 | |
| 202 | void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp, |
| 203 | const TfLiteMotionPredictorSample sample) { |
| 204 | // Convert the sample (x, y) into polar (r, φ) based on a reference axis |
| 205 | // from the preceding two points (mAxisFrom/mAxisTo). |
| 206 | |
| 207 | mTimestamp = timestamp; |
| 208 | |
| 209 | if (!mAxisTo) { // First point. |
| 210 | mAxisTo = sample; |
| 211 | return; |
| 212 | } |
| 213 | |
| 214 | // Vector from the last point to the current sample point. |
| 215 | const TfLiteMotionPredictorSample::Point v = sample.position - mAxisTo->position; |
| 216 | |
| 217 | const float r = std::hypot(v.x, v.y); |
| 218 | float phi = 0; |
| 219 | float orientation = 0; |
| 220 | |
Philip Quinn | 107ce70 | 2023-07-14 13:07:13 -0700 | [diff] [blame] | 221 | if (!mAxisFrom && r > 0) { // Second point. |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 222 | // We can only determine the distance from the first point, and not any |
| 223 | // angle. However, if the second point forms an axis, the orientation can |
| 224 | // be transformed relative to that axis. |
| 225 | const float axisPhi = std::atan2(v.y, v.x); |
| 226 | // A MotionEvent's orientation is measured clockwise from the vertical |
| 227 | // axis, but axisPhi is measured counter-clockwise from the horizontal |
| 228 | // axis. |
| 229 | orientation = M_PI_2 - sample.orientation - axisPhi; |
| 230 | } else { |
| 231 | const TfLiteMotionPredictorSample::Point axis = mAxisTo->position - mAxisFrom->position; |
| 232 | const float axisPhi = std::atan2(axis.y, axis.x); |
| 233 | phi = std::atan2(v.y, v.x) - axisPhi; |
| 234 | |
| 235 | if (std::hypot(axis.x, axis.y) > 0) { |
| 236 | // See note above. |
| 237 | orientation = M_PI_2 - sample.orientation - axisPhi; |
| 238 | } |
| 239 | } |
| 240 | |
| 241 | // Update the axis for the next point. |
Philip Quinn | 107ce70 | 2023-07-14 13:07:13 -0700 | [diff] [blame] | 242 | if (r > 0) { |
| 243 | mAxisFrom = mAxisTo; |
| 244 | mAxisTo = sample; |
| 245 | } |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 246 | |
| 247 | // Push the current sample onto the end of the input buffers. |
Philip Quinn | 9b8926e | 2023-01-31 14:50:02 -0800 | [diff] [blame] | 248 | mInputR.pushBack(r); |
| 249 | mInputPhi.pushBack(phi); |
| 250 | mInputPressure.pushBack(sample.pressure); |
| 251 | mInputTilt.pushBack(sample.tilt); |
| 252 | mInputOrientation.pushBack(orientation); |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 253 | } |
| 254 | |
Siarhei Vishniakou | fd0a68e | 2023-02-28 13:25:36 -0800 | [diff] [blame] | 255 | std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create() { |
| 256 | const std::string modelPath = getModelPath(); |
Siarhei Vishniakou | c065d7b | 2023-03-02 14:06:29 -0800 | [diff] [blame] | 257 | android::base::unique_fd fd(open(modelPath.c_str(), O_RDONLY)); |
Philip Quinn | cb3229a | 2023-02-08 22:50:59 -0800 | [diff] [blame] | 258 | if (fd == -1) { |
| 259 | PLOG(FATAL) << "Could not read model from " << modelPath; |
| 260 | } |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 261 | |
Philip Quinn | cb3229a | 2023-02-08 22:50:59 -0800 | [diff] [blame] | 262 | const off_t fdSize = lseek(fd, 0, SEEK_END); |
| 263 | if (fdSize == -1) { |
| 264 | PLOG(FATAL) << "Failed to determine file size"; |
| 265 | } |
| 266 | |
| 267 | std::unique_ptr<android::base::MappedFile> modelBuffer = |
| 268 | android::base::MappedFile::FromFd(fd, /*offset=*/0, fdSize, PROT_READ); |
| 269 | if (!modelBuffer) { |
| 270 | PLOG(FATAL) << "Failed to mmap model"; |
| 271 | } |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 272 | |
Philip Quinn | f84fa49 | 2023-06-26 14:15:15 -0700 | [diff] [blame] | 273 | const std::string configPath = getConfigPath(); |
| 274 | tinyxml2::XMLDocument configDocument; |
| 275 | LOG_ALWAYS_FATAL_IF(configDocument.LoadFile(configPath.c_str()) != tinyxml2::XML_SUCCESS, |
| 276 | "Failed to load config file from %s", configPath.c_str()); |
| 277 | |
| 278 | // Parse configuration file. |
| 279 | const tinyxml2::XMLElement* configRoot = configDocument.FirstChildElement("motion-predictor"); |
| 280 | LOG_ALWAYS_FATAL_IF(!configRoot); |
Philip Quinn | 107ce70 | 2023-07-14 13:07:13 -0700 | [diff] [blame] | 281 | Config config{ |
| 282 | .predictionInterval = parseXMLInt64(*configRoot, "prediction-interval"), |
| 283 | .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"), |
Derek Wu | aaa4731 | 2024-03-26 15:53:44 -0700 | [diff] [blame] | 284 | .lowJerk = parseXMLFloat(*configRoot, "low-jerk"), |
| 285 | .highJerk = parseXMLFloat(*configRoot, "high-jerk"), |
Philip Quinn | 107ce70 | 2023-07-14 13:07:13 -0700 | [diff] [blame] | 286 | }; |
Philip Quinn | f84fa49 | 2023-06-26 14:15:15 -0700 | [diff] [blame] | 287 | |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 288 | return std::unique_ptr<TfLiteMotionPredictorModel>( |
Philip Quinn | 107ce70 | 2023-07-14 13:07:13 -0700 | [diff] [blame] | 289 | new TfLiteMotionPredictorModel(std::move(modelBuffer), std::move(config))); |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 290 | } |
| 291 | |
Philip Quinn | cb3229a | 2023-02-08 22:50:59 -0800 | [diff] [blame] | 292 | TfLiteMotionPredictorModel::TfLiteMotionPredictorModel( |
Philip Quinn | 107ce70 | 2023-07-14 13:07:13 -0700 | [diff] [blame] | 293 | std::unique_ptr<android::base::MappedFile> model, Config config) |
| 294 | : mFlatBuffer(std::move(model)), mConfig(std::move(config)) { |
Philip Quinn | cb3229a | 2023-02-08 22:50:59 -0800 | [diff] [blame] | 295 | CHECK(mFlatBuffer); |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 296 | mErrorReporter = std::make_unique<LoggingErrorReporter>(); |
Philip Quinn | cb3229a | 2023-02-08 22:50:59 -0800 | [diff] [blame] | 297 | mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(), |
| 298 | mFlatBuffer->size(), |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 299 | /*extra_verifier=*/nullptr, |
| 300 | mErrorReporter.get()); |
| 301 | LOG_ALWAYS_FATAL_IF(!mModel); |
| 302 | |
Philip Quinn | da6a448 | 2023-02-07 10:09:57 -0800 | [diff] [blame] | 303 | auto resolver = createOpResolver(); |
| 304 | tflite::InterpreterBuilder builder(*mModel, *resolver); |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 305 | |
| 306 | if (builder(&mInterpreter) != kTfLiteOk || !mInterpreter) { |
| 307 | LOG_ALWAYS_FATAL("Failed to build interpreter"); |
| 308 | } |
| 309 | |
| 310 | mRunner = mInterpreter->GetSignatureRunner(SIGNATURE_KEY); |
| 311 | LOG_ALWAYS_FATAL_IF(!mRunner, "Failed to find runner for signature '%s'", SIGNATURE_KEY); |
| 312 | |
| 313 | allocateTensors(); |
| 314 | } |
| 315 | |
Philip Quinn | da6a448 | 2023-02-07 10:09:57 -0800 | [diff] [blame] | 316 | TfLiteMotionPredictorModel::~TfLiteMotionPredictorModel() {} |
| 317 | |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 318 | void TfLiteMotionPredictorModel::allocateTensors() { |
| 319 | if (mRunner->AllocateTensors() != kTfLiteOk) { |
| 320 | LOG_ALWAYS_FATAL("Failed to allocate tensors"); |
| 321 | } |
| 322 | |
| 323 | attachInputTensors(); |
| 324 | attachOutputTensors(); |
| 325 | |
| 326 | checkTensor<float>(mInputR); |
| 327 | checkTensor<float>(mInputPhi); |
| 328 | checkTensor<float>(mInputPressure); |
| 329 | checkTensor<float>(mInputTilt); |
| 330 | checkTensor<float>(mInputOrientation); |
| 331 | checkTensor<float>(mOutputR); |
| 332 | checkTensor<float>(mOutputPhi); |
| 333 | checkTensor<float>(mOutputPressure); |
| 334 | |
| 335 | const auto checkInputTensorSize = [this](const TfLiteTensor* tensor) { |
| 336 | const size_t size = getTensorBuffer<const float>(tensor).size(); |
| 337 | LOG_ALWAYS_FATAL_IF(size != inputLength(), |
| 338 | "Tensor '%s' length %zu does not match input length %zu", tensor->name, |
| 339 | size, inputLength()); |
| 340 | }; |
| 341 | |
| 342 | checkInputTensorSize(mInputR); |
| 343 | checkInputTensorSize(mInputPhi); |
| 344 | checkInputTensorSize(mInputPressure); |
| 345 | checkInputTensorSize(mInputTilt); |
| 346 | checkInputTensorSize(mInputOrientation); |
| 347 | } |
| 348 | |
| 349 | void TfLiteMotionPredictorModel::attachInputTensors() { |
| 350 | mInputR = findInputTensor(INPUT_R, mRunner); |
| 351 | mInputPhi = findInputTensor(INPUT_PHI, mRunner); |
| 352 | mInputPressure = findInputTensor(INPUT_PRESSURE, mRunner); |
| 353 | mInputTilt = findInputTensor(INPUT_TILT, mRunner); |
| 354 | mInputOrientation = findInputTensor(INPUT_ORIENTATION, mRunner); |
| 355 | } |
| 356 | |
| 357 | void TfLiteMotionPredictorModel::attachOutputTensors() { |
| 358 | mOutputR = findOutputTensor(OUTPUT_R, mRunner); |
| 359 | mOutputPhi = findOutputTensor(OUTPUT_PHI, mRunner); |
| 360 | mOutputPressure = findOutputTensor(OUTPUT_PRESSURE, mRunner); |
| 361 | } |
| 362 | |
| 363 | bool TfLiteMotionPredictorModel::invoke() { |
| 364 | ATRACE_BEGIN("TfLiteMotionPredictorModel::invoke"); |
| 365 | TfLiteStatus result = mRunner->Invoke(); |
| 366 | ATRACE_END(); |
| 367 | |
| 368 | if (result != kTfLiteOk) { |
| 369 | return false; |
| 370 | } |
| 371 | |
| 372 | // Invoke() might reallocate tensors, so they need to be reattached. |
| 373 | attachInputTensors(); |
| 374 | attachOutputTensors(); |
| 375 | |
| 376 | if (outputR().size() != outputPhi().size() || outputR().size() != outputPressure().size()) { |
| 377 | LOG_ALWAYS_FATAL("Output size mismatch: (r: %zu, phi: %zu, pressure: %zu)", |
| 378 | outputR().size(), outputPhi().size(), outputPressure().size()); |
| 379 | } |
| 380 | |
| 381 | return true; |
| 382 | } |
| 383 | |
| 384 | size_t TfLiteMotionPredictorModel::inputLength() const { |
| 385 | return getTensorBuffer<const float>(mInputR).size(); |
| 386 | } |
| 387 | |
Cody Heiner | dbd14eb | 2023-03-30 18:41:45 -0700 | [diff] [blame] | 388 | size_t TfLiteMotionPredictorModel::outputLength() const { |
| 389 | return getTensorBuffer<const float>(mOutputR).size(); |
| 390 | } |
| 391 | |
Philip Quinn | 8f953ab | 2022-12-06 15:37:07 -0800 | [diff] [blame] | 392 | std::span<float> TfLiteMotionPredictorModel::inputR() { |
| 393 | return getTensorBuffer<float>(mInputR); |
| 394 | } |
| 395 | |
| 396 | std::span<float> TfLiteMotionPredictorModel::inputPhi() { |
| 397 | return getTensorBuffer<float>(mInputPhi); |
| 398 | } |
| 399 | |
| 400 | std::span<float> TfLiteMotionPredictorModel::inputPressure() { |
| 401 | return getTensorBuffer<float>(mInputPressure); |
| 402 | } |
| 403 | |
| 404 | std::span<float> TfLiteMotionPredictorModel::inputTilt() { |
| 405 | return getTensorBuffer<float>(mInputTilt); |
| 406 | } |
| 407 | |
| 408 | std::span<float> TfLiteMotionPredictorModel::inputOrientation() { |
| 409 | return getTensorBuffer<float>(mInputOrientation); |
| 410 | } |
| 411 | |
| 412 | std::span<const float> TfLiteMotionPredictorModel::outputR() const { |
| 413 | return getTensorBuffer<const float>(mOutputR); |
| 414 | } |
| 415 | |
| 416 | std::span<const float> TfLiteMotionPredictorModel::outputPhi() const { |
| 417 | return getTensorBuffer<const float>(mOutputPhi); |
| 418 | } |
| 419 | |
| 420 | std::span<const float> TfLiteMotionPredictorModel::outputPressure() const { |
| 421 | return getTensorBuffer<const float>(mOutputPressure); |
| 422 | } |
| 423 | |
| 424 | } // namespace android |