Michael Butler | f6b2d1a | 2020-12-19 14:44:35 -0800 | [diff] [blame] | 1 | /* |
| 2 | * Copyright (C) 2019 The Android Open Source Project |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | */ |
| 16 | |
| 17 | #define LOG_TAG "ExecutionBurstController" |
| 18 | |
| 19 | #include "ExecutionBurstController.h" |
| 20 | |
| 21 | #include <android-base/logging.h> |
| 22 | |
| 23 | #include <algorithm> |
| 24 | #include <cstring> |
| 25 | #include <limits> |
| 26 | #include <memory> |
| 27 | #include <string> |
| 28 | #include <tuple> |
| 29 | #include <utility> |
| 30 | #include <vector> |
| 31 | |
Michael Butler | 8fc4896 | 2021-01-08 17:21:27 -0800 | [diff] [blame] | 32 | #include "ExecutionBurstUtils.h" |
Michael Butler | f6b2d1a | 2020-12-19 14:44:35 -0800 | [diff] [blame] | 33 | #include "HalInterfaces.h" |
| 34 | #include "Tracing.h" |
| 35 | #include "Utils.h" |
| 36 | |
| 37 | namespace android::nn { |
| 38 | namespace { |
| 39 | |
Michael Butler | f6b2d1a | 2020-12-19 14:44:35 -0800 | [diff] [blame] | 40 | class BurstContextDeathHandler : public hardware::hidl_death_recipient { |
| 41 | public: |
| 42 | using Callback = std::function<void()>; |
| 43 | |
| 44 | BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) { |
| 45 | CHECK(onDeathCallback != nullptr); |
| 46 | } |
| 47 | |
| 48 | void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override { |
| 49 | LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!"; |
| 50 | mOnDeathCallback(); |
| 51 | } |
| 52 | |
| 53 | private: |
| 54 | const Callback mOnDeathCallback; |
| 55 | }; |
| 56 | |
| 57 | } // anonymous namespace |
| 58 | |
Michael Butler | f6b2d1a | 2020-12-19 14:44:35 -0800 | [diff] [blame] | 59 | hardware::Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories( |
| 60 | const hardware::hidl_vec<int32_t>& slots, getMemories_cb cb) { |
| 61 | std::lock_guard<std::mutex> guard(mMutex); |
| 62 | |
| 63 | // get all memories |
| 64 | hardware::hidl_vec<hardware::hidl_memory> memories(slots.size()); |
| 65 | std::transform(slots.begin(), slots.end(), memories.begin(), [this](int32_t slot) { |
| 66 | return slot < mMemoryCache.size() ? mMemoryCache[slot] : hardware::hidl_memory{}; |
| 67 | }); |
| 68 | |
| 69 | // ensure all memories are valid |
| 70 | if (!std::all_of(memories.begin(), memories.end(), |
| 71 | [](const hardware::hidl_memory& memory) { return memory.valid(); })) { |
| 72 | cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {}); |
| 73 | return hardware::Void(); |
| 74 | } |
| 75 | |
| 76 | // return successful |
| 77 | cb(V1_0::ErrorStatus::NONE, std::move(memories)); |
| 78 | return hardware::Void(); |
| 79 | } |
| 80 | |
| 81 | std::vector<int32_t> ExecutionBurstController::ExecutionBurstCallback::getSlots( |
| 82 | const hardware::hidl_vec<hardware::hidl_memory>& memories, |
| 83 | const std::vector<intptr_t>& keys) { |
| 84 | std::lock_guard<std::mutex> guard(mMutex); |
| 85 | |
| 86 | // retrieve (or bind) all slots corresponding to memories |
| 87 | std::vector<int32_t> slots; |
| 88 | slots.reserve(memories.size()); |
| 89 | for (size_t i = 0; i < memories.size(); ++i) { |
| 90 | slots.push_back(getSlotLocked(memories[i], keys[i])); |
| 91 | } |
| 92 | return slots; |
| 93 | } |
| 94 | |
| 95 | std::pair<bool, int32_t> ExecutionBurstController::ExecutionBurstCallback::freeMemory( |
| 96 | intptr_t key) { |
| 97 | std::lock_guard<std::mutex> guard(mMutex); |
| 98 | |
| 99 | auto iter = mMemoryIdToSlot.find(key); |
| 100 | if (iter == mMemoryIdToSlot.end()) { |
| 101 | return {false, 0}; |
| 102 | } |
| 103 | const int32_t slot = iter->second; |
| 104 | mMemoryIdToSlot.erase(key); |
| 105 | mMemoryCache[slot] = {}; |
| 106 | mFreeSlots.push(slot); |
| 107 | return {true, slot}; |
| 108 | } |
| 109 | |
| 110 | int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked( |
| 111 | const hardware::hidl_memory& memory, intptr_t key) { |
| 112 | auto iter = mMemoryIdToSlot.find(key); |
| 113 | if (iter == mMemoryIdToSlot.end()) { |
| 114 | const int32_t slot = allocateSlotLocked(); |
| 115 | mMemoryIdToSlot[key] = slot; |
| 116 | mMemoryCache[slot] = memory; |
| 117 | return slot; |
| 118 | } else { |
| 119 | const int32_t slot = iter->second; |
| 120 | return slot; |
| 121 | } |
| 122 | } |
| 123 | |
| 124 | int32_t ExecutionBurstController::ExecutionBurstCallback::allocateSlotLocked() { |
| 125 | constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max(); |
| 126 | |
| 127 | // if there is a free slot, use it |
| 128 | if (mFreeSlots.size() > 0) { |
| 129 | const int32_t slot = mFreeSlots.top(); |
| 130 | mFreeSlots.pop(); |
| 131 | return slot; |
| 132 | } |
| 133 | |
| 134 | // otherwise use a slot for the first time |
| 135 | CHECK(mMemoryCache.size() < kMaxNumberOfSlots) << "Exceeded maximum number of slots!"; |
| 136 | const int32_t slot = static_cast<int32_t>(mMemoryCache.size()); |
| 137 | mMemoryCache.emplace_back(); |
| 138 | |
| 139 | return slot; |
| 140 | } |
| 141 | |
| 142 | std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create( |
| 143 | const sp<V1_2::IPreparedModel>& preparedModel, |
| 144 | std::chrono::microseconds pollingTimeWindow) { |
| 145 | // check inputs |
| 146 | if (preparedModel == nullptr) { |
| 147 | LOG(ERROR) << "ExecutionBurstController::create passed a nullptr"; |
| 148 | return nullptr; |
| 149 | } |
| 150 | |
| 151 | // create callback object |
| 152 | sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback(); |
| 153 | |
| 154 | // create FMQ objects |
| 155 | auto [requestChannelSenderTemp, requestChannelDescriptor] = |
| 156 | RequestChannelSender::create(kExecutionBurstChannelLength); |
| 157 | auto [resultChannelReceiverTemp, resultChannelDescriptor] = |
| 158 | ResultChannelReceiver::create(kExecutionBurstChannelLength, pollingTimeWindow); |
| 159 | std::shared_ptr<RequestChannelSender> requestChannelSender = |
| 160 | std::move(requestChannelSenderTemp); |
| 161 | std::shared_ptr<ResultChannelReceiver> resultChannelReceiver = |
| 162 | std::move(resultChannelReceiverTemp); |
| 163 | |
| 164 | // check FMQ objects |
| 165 | if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor || |
| 166 | !resultChannelDescriptor) { |
| 167 | LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue"; |
| 168 | return nullptr; |
| 169 | } |
| 170 | |
| 171 | // configure burst |
| 172 | V1_0::ErrorStatus errorStatus; |
| 173 | sp<IBurstContext> burstContext; |
| 174 | const hardware::Return<void> ret = preparedModel->configureExecutionBurst( |
| 175 | callback, *requestChannelDescriptor, *resultChannelDescriptor, |
| 176 | [&errorStatus, &burstContext](V1_0::ErrorStatus status, |
| 177 | const sp<IBurstContext>& context) { |
| 178 | errorStatus = status; |
| 179 | burstContext = context; |
| 180 | }); |
| 181 | |
| 182 | // check burst |
| 183 | if (!ret.isOk()) { |
| 184 | LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with description " |
| 185 | << ret.description(); |
| 186 | return nullptr; |
| 187 | } |
| 188 | if (errorStatus != V1_0::ErrorStatus::NONE) { |
| 189 | LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with status " |
| 190 | << toString(errorStatus); |
| 191 | return nullptr; |
| 192 | } |
| 193 | if (burstContext == nullptr) { |
| 194 | LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst"; |
| 195 | return nullptr; |
| 196 | } |
| 197 | |
| 198 | // create death handler object |
| 199 | BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender, |
| 200 | resultChannelReceiver] { |
| 201 | requestChannelSender->invalidate(); |
| 202 | resultChannelReceiver->invalidate(); |
| 203 | }; |
| 204 | const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback); |
| 205 | |
| 206 | // linkToDeath registers a callback that will be invoked on service death to |
| 207 | // proactively handle service crashes. If the linkToDeath call fails, |
| 208 | // asynchronous calls are susceptible to hangs if the service crashes before |
| 209 | // providing the response. |
| 210 | const hardware::Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0); |
| 211 | if (!deathHandlerRet.isOk() || deathHandlerRet != true) { |
| 212 | LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient " |
| 213 | "for the IBurstContext object."; |
| 214 | return nullptr; |
| 215 | } |
| 216 | |
| 217 | // make and return controller |
| 218 | return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver, |
| 219 | burstContext, callback, deathHandler); |
| 220 | } |
| 221 | |
| 222 | ExecutionBurstController::ExecutionBurstController( |
| 223 | const std::shared_ptr<RequestChannelSender>& requestChannelSender, |
| 224 | const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver, |
| 225 | const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback, |
| 226 | const sp<hardware::hidl_death_recipient>& deathHandler) |
| 227 | : mRequestChannelSender(requestChannelSender), |
| 228 | mResultChannelReceiver(resultChannelReceiver), |
| 229 | mBurstContext(burstContext), |
| 230 | mMemoryCache(callback), |
| 231 | mDeathHandler(deathHandler) {} |
| 232 | |
| 233 | ExecutionBurstController::~ExecutionBurstController() { |
| 234 | // It is safe to ignore any errors resulting from this unlinkToDeath call |
| 235 | // because the ExecutionBurstController object is already being destroyed |
| 236 | // and its underlying IBurstContext object is no longer being used by the NN |
| 237 | // runtime. |
| 238 | if (mDeathHandler) { |
| 239 | mBurstContext->unlinkToDeath(mDeathHandler).isOk(); |
| 240 | } |
| 241 | } |
| 242 | |
| 243 | static std::tuple<int, std::vector<V1_2::OutputShape>, V1_2::Timing, bool> getExecutionResult( |
| 244 | V1_0::ErrorStatus status, std::vector<V1_2::OutputShape> outputShapes, V1_2::Timing timing, |
| 245 | bool fallback) { |
| 246 | auto [n, checkedOutputShapes, checkedTiming] = |
| 247 | getExecutionResult(convertToV1_3(status), std::move(outputShapes), timing); |
| 248 | return {n, convertToV1_2(checkedOutputShapes), convertToV1_2(checkedTiming), fallback}; |
| 249 | } |
| 250 | |
| 251 | std::tuple<int, std::vector<V1_2::OutputShape>, V1_2::Timing, bool> |
| 252 | ExecutionBurstController::compute(const V1_0::Request& request, V1_2::MeasureTiming measure, |
| 253 | const std::vector<intptr_t>& memoryIds) { |
| 254 | // This is the first point when we know an execution is occurring, so begin |
| 255 | // to collect systraces. Note that the first point we can begin collecting |
| 256 | // systraces in ExecutionBurstServer is when the RequestChannelReceiver |
| 257 | // realizes there is data in the FMQ, so ExecutionBurstServer collects |
| 258 | // systraces at different points in the code. |
| 259 | NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute"); |
| 260 | |
| 261 | std::lock_guard<std::mutex> guard(mMutex); |
| 262 | |
| 263 | // send request packet |
| 264 | const std::vector<int32_t> slots = mMemoryCache->getSlots(request.pools, memoryIds); |
| 265 | const bool success = mRequestChannelSender->send(request, measure, slots); |
| 266 | if (!success) { |
| 267 | LOG(ERROR) << "Error sending FMQ packet"; |
| 268 | // only use fallback execution path if the packet could not be sent |
| 269 | return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming12, |
| 270 | /*fallback=*/true); |
| 271 | } |
| 272 | |
| 273 | // get result packet |
| 274 | const auto result = mResultChannelReceiver->getBlocking(); |
| 275 | if (!result) { |
| 276 | LOG(ERROR) << "Error retrieving FMQ packet"; |
| 277 | // only use fallback execution path if the packet could not be sent |
| 278 | return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming12, |
| 279 | /*fallback=*/false); |
| 280 | } |
| 281 | |
| 282 | // unpack results and return (only use fallback execution path if the |
| 283 | // packet could not be sent) |
| 284 | auto [status, outputShapes, timing] = std::move(*result); |
| 285 | return getExecutionResult(status, std::move(outputShapes), timing, /*fallback=*/false); |
| 286 | } |
| 287 | |
| 288 | void ExecutionBurstController::freeMemory(intptr_t key) { |
| 289 | std::lock_guard<std::mutex> guard(mMutex); |
| 290 | |
| 291 | bool valid; |
| 292 | int32_t slot; |
| 293 | std::tie(valid, slot) = mMemoryCache->freeMemory(key); |
| 294 | if (valid) { |
| 295 | mBurstContext->freeMemory(slot).isOk(); |
| 296 | } |
| 297 | } |
| 298 | |
| 299 | } // namespace android::nn |