blob: 2265861b4148f9fbcb0dc0f96c9768e70d64325a [file] [log] [blame]
/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#define LOG_TAG "ExecutionBurstController"
18
19#include "ExecutionBurstController.h"
20
21#include <android-base/logging.h>
22
23#include <algorithm>
24#include <cstring>
25#include <limits>
26#include <memory>
27#include <string>
28#include <tuple>
29#include <utility>
30#include <vector>
31
Michael Butler8fc48962021-01-08 17:21:27 -080032#include "ExecutionBurstUtils.h"
Michael Butlerf6b2d1a2020-12-19 14:44:35 -080033#include "HalInterfaces.h"
34#include "Tracing.h"
35#include "Utils.h"
36
37namespace android::nn {
38namespace {
39
Michael Butlerf6b2d1a2020-12-19 14:44:35 -080040class BurstContextDeathHandler : public hardware::hidl_death_recipient {
41 public:
42 using Callback = std::function<void()>;
43
44 BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) {
45 CHECK(onDeathCallback != nullptr);
46 }
47
48 void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
49 LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!";
50 mOnDeathCallback();
51 }
52
53 private:
54 const Callback mOnDeathCallback;
55};
56
57} // anonymous namespace
58
Michael Butlerf6b2d1a2020-12-19 14:44:35 -080059hardware::Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
60 const hardware::hidl_vec<int32_t>& slots, getMemories_cb cb) {
61 std::lock_guard<std::mutex> guard(mMutex);
62
63 // get all memories
64 hardware::hidl_vec<hardware::hidl_memory> memories(slots.size());
65 std::transform(slots.begin(), slots.end(), memories.begin(), [this](int32_t slot) {
66 return slot < mMemoryCache.size() ? mMemoryCache[slot] : hardware::hidl_memory{};
67 });
68
69 // ensure all memories are valid
70 if (!std::all_of(memories.begin(), memories.end(),
71 [](const hardware::hidl_memory& memory) { return memory.valid(); })) {
72 cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
73 return hardware::Void();
74 }
75
76 // return successful
77 cb(V1_0::ErrorStatus::NONE, std::move(memories));
78 return hardware::Void();
79}
80
81std::vector<int32_t> ExecutionBurstController::ExecutionBurstCallback::getSlots(
82 const hardware::hidl_vec<hardware::hidl_memory>& memories,
83 const std::vector<intptr_t>& keys) {
84 std::lock_guard<std::mutex> guard(mMutex);
85
86 // retrieve (or bind) all slots corresponding to memories
87 std::vector<int32_t> slots;
88 slots.reserve(memories.size());
89 for (size_t i = 0; i < memories.size(); ++i) {
90 slots.push_back(getSlotLocked(memories[i], keys[i]));
91 }
92 return slots;
93}
94
95std::pair<bool, int32_t> ExecutionBurstController::ExecutionBurstCallback::freeMemory(
96 intptr_t key) {
97 std::lock_guard<std::mutex> guard(mMutex);
98
99 auto iter = mMemoryIdToSlot.find(key);
100 if (iter == mMemoryIdToSlot.end()) {
101 return {false, 0};
102 }
103 const int32_t slot = iter->second;
104 mMemoryIdToSlot.erase(key);
105 mMemoryCache[slot] = {};
106 mFreeSlots.push(slot);
107 return {true, slot};
108}
109
110int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked(
111 const hardware::hidl_memory& memory, intptr_t key) {
112 auto iter = mMemoryIdToSlot.find(key);
113 if (iter == mMemoryIdToSlot.end()) {
114 const int32_t slot = allocateSlotLocked();
115 mMemoryIdToSlot[key] = slot;
116 mMemoryCache[slot] = memory;
117 return slot;
118 } else {
119 const int32_t slot = iter->second;
120 return slot;
121 }
122}
123
124int32_t ExecutionBurstController::ExecutionBurstCallback::allocateSlotLocked() {
125 constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max();
126
127 // if there is a free slot, use it
128 if (mFreeSlots.size() > 0) {
129 const int32_t slot = mFreeSlots.top();
130 mFreeSlots.pop();
131 return slot;
132 }
133
134 // otherwise use a slot for the first time
135 CHECK(mMemoryCache.size() < kMaxNumberOfSlots) << "Exceeded maximum number of slots!";
136 const int32_t slot = static_cast<int32_t>(mMemoryCache.size());
137 mMemoryCache.emplace_back();
138
139 return slot;
140}
141
142std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
143 const sp<V1_2::IPreparedModel>& preparedModel,
144 std::chrono::microseconds pollingTimeWindow) {
145 // check inputs
146 if (preparedModel == nullptr) {
147 LOG(ERROR) << "ExecutionBurstController::create passed a nullptr";
148 return nullptr;
149 }
150
151 // create callback object
152 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
153
154 // create FMQ objects
155 auto [requestChannelSenderTemp, requestChannelDescriptor] =
156 RequestChannelSender::create(kExecutionBurstChannelLength);
157 auto [resultChannelReceiverTemp, resultChannelDescriptor] =
158 ResultChannelReceiver::create(kExecutionBurstChannelLength, pollingTimeWindow);
159 std::shared_ptr<RequestChannelSender> requestChannelSender =
160 std::move(requestChannelSenderTemp);
161 std::shared_ptr<ResultChannelReceiver> resultChannelReceiver =
162 std::move(resultChannelReceiverTemp);
163
164 // check FMQ objects
165 if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor ||
166 !resultChannelDescriptor) {
167 LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue";
168 return nullptr;
169 }
170
171 // configure burst
172 V1_0::ErrorStatus errorStatus;
173 sp<IBurstContext> burstContext;
174 const hardware::Return<void> ret = preparedModel->configureExecutionBurst(
175 callback, *requestChannelDescriptor, *resultChannelDescriptor,
176 [&errorStatus, &burstContext](V1_0::ErrorStatus status,
177 const sp<IBurstContext>& context) {
178 errorStatus = status;
179 burstContext = context;
180 });
181
182 // check burst
183 if (!ret.isOk()) {
184 LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with description "
185 << ret.description();
186 return nullptr;
187 }
188 if (errorStatus != V1_0::ErrorStatus::NONE) {
189 LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with status "
190 << toString(errorStatus);
191 return nullptr;
192 }
193 if (burstContext == nullptr) {
194 LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
195 return nullptr;
196 }
197
198 // create death handler object
199 BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender,
200 resultChannelReceiver] {
201 requestChannelSender->invalidate();
202 resultChannelReceiver->invalidate();
203 };
204 const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback);
205
206 // linkToDeath registers a callback that will be invoked on service death to
207 // proactively handle service crashes. If the linkToDeath call fails,
208 // asynchronous calls are susceptible to hangs if the service crashes before
209 // providing the response.
210 const hardware::Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0);
211 if (!deathHandlerRet.isOk() || deathHandlerRet != true) {
212 LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient "
213 "for the IBurstContext object.";
214 return nullptr;
215 }
216
217 // make and return controller
218 return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver,
219 burstContext, callback, deathHandler);
220}
221
222ExecutionBurstController::ExecutionBurstController(
223 const std::shared_ptr<RequestChannelSender>& requestChannelSender,
224 const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
225 const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback,
226 const sp<hardware::hidl_death_recipient>& deathHandler)
227 : mRequestChannelSender(requestChannelSender),
228 mResultChannelReceiver(resultChannelReceiver),
229 mBurstContext(burstContext),
230 mMemoryCache(callback),
231 mDeathHandler(deathHandler) {}
232
233ExecutionBurstController::~ExecutionBurstController() {
234 // It is safe to ignore any errors resulting from this unlinkToDeath call
235 // because the ExecutionBurstController object is already being destroyed
236 // and its underlying IBurstContext object is no longer being used by the NN
237 // runtime.
238 if (mDeathHandler) {
239 mBurstContext->unlinkToDeath(mDeathHandler).isOk();
240 }
241}
242
243static std::tuple<int, std::vector<V1_2::OutputShape>, V1_2::Timing, bool> getExecutionResult(
244 V1_0::ErrorStatus status, std::vector<V1_2::OutputShape> outputShapes, V1_2::Timing timing,
245 bool fallback) {
246 auto [n, checkedOutputShapes, checkedTiming] =
247 getExecutionResult(convertToV1_3(status), std::move(outputShapes), timing);
248 return {n, convertToV1_2(checkedOutputShapes), convertToV1_2(checkedTiming), fallback};
249}
250
251std::tuple<int, std::vector<V1_2::OutputShape>, V1_2::Timing, bool>
252ExecutionBurstController::compute(const V1_0::Request& request, V1_2::MeasureTiming measure,
253 const std::vector<intptr_t>& memoryIds) {
254 // This is the first point when we know an execution is occurring, so begin
255 // to collect systraces. Note that the first point we can begin collecting
256 // systraces in ExecutionBurstServer is when the RequestChannelReceiver
257 // realizes there is data in the FMQ, so ExecutionBurstServer collects
258 // systraces at different points in the code.
259 NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");
260
261 std::lock_guard<std::mutex> guard(mMutex);
262
263 // send request packet
264 const std::vector<int32_t> slots = mMemoryCache->getSlots(request.pools, memoryIds);
265 const bool success = mRequestChannelSender->send(request, measure, slots);
266 if (!success) {
267 LOG(ERROR) << "Error sending FMQ packet";
268 // only use fallback execution path if the packet could not be sent
269 return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming12,
270 /*fallback=*/true);
271 }
272
273 // get result packet
274 const auto result = mResultChannelReceiver->getBlocking();
275 if (!result) {
276 LOG(ERROR) << "Error retrieving FMQ packet";
277 // only use fallback execution path if the packet could not be sent
278 return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming12,
279 /*fallback=*/false);
280 }
281
282 // unpack results and return (only use fallback execution path if the
283 // packet could not be sent)
284 auto [status, outputShapes, timing] = std::move(*result);
285 return getExecutionResult(status, std::move(outputShapes), timing, /*fallback=*/false);
286}
287
288void ExecutionBurstController::freeMemory(intptr_t key) {
289 std::lock_guard<std::mutex> guard(mMutex);
290
291 bool valid;
292 int32_t slot;
293 std::tie(valid, slot) = mMemoryCache->freeMemory(key);
294 if (valid) {
295 mBurstContext->freeMemory(slot).isOk();
296 }
297}
298
299} // namespace android::nn