blob: 87cd0e4afe23c915b16964e94bcba087ef99f668 [file] [log] [blame]
Michael Butler7a9d6092021-03-10 21:57:13 -08001/*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "Burst.h"
18
19#include "Conversions.h"
20#include "Utils.h"
21
22#include <android-base/logging.h>
23#include <android/binder_auto_utils.h>
24#include <nnapi/IBurst.h>
Xusong Wangb2e80852021-03-23 15:07:10 -070025#include <nnapi/IExecution.h>
Michael Butler7a9d6092021-03-10 21:57:13 -080026#include <nnapi/Result.h>
27#include <nnapi/TypeUtils.h>
28#include <nnapi/Types.h>
29#include <nnapi/hal/HandleError.h>
30
31#include <memory>
32#include <mutex>
33#include <optional>
34#include <utility>
35
36namespace aidl::android::hardware::neuralnetworks::utils {
37namespace {
38
Xusong Wangb2e80852021-03-23 15:07:10 -070039class BurstExecution final : public nn::IExecution,
40 public std::enable_shared_from_this<BurstExecution> {
41 struct PrivateConstructorTag {};
42
43 public:
44 static nn::GeneralResult<std::shared_ptr<const BurstExecution>> create(
45 std::shared_ptr<const Burst> burst, Request request,
46 std::vector<int64_t> memoryIdentifierTokens, bool measure, int64_t loopTimeoutDuration,
47 hal::utils::RequestRelocation relocation,
48 std::vector<Burst::OptionalCacheHold> cacheHolds);
49
50 BurstExecution(PrivateConstructorTag tag, std::shared_ptr<const Burst> burst, Request request,
51 std::vector<int64_t> memoryIdentifierTokens, bool measure,
52 int64_t loopTimeoutDuration, hal::utils::RequestRelocation relocation,
53 std::vector<Burst::OptionalCacheHold> cacheHolds);
54
55 nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> compute(
56 const nn::OptionalTimePoint& deadline) const override;
57
58 nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> computeFenced(
59 const std::vector<nn::SyncFence>& waitFor, const nn::OptionalTimePoint& deadline,
60 const nn::OptionalDuration& timeoutDurationAfterFence) const override;
61
62 private:
63 const std::shared_ptr<const Burst> kBurst;
64 const Request kRequest;
65 const std::vector<int64_t>& kMemoryIdentifierTokens;
66 const bool kMeasure;
67 const int64_t kLoopTimeoutDuration;
68 const hal::utils::RequestRelocation kRelocation;
69 const std::vector<Burst::OptionalCacheHold> kCacheHolds;
70};
71
Michael Butler7a9d6092021-03-10 21:57:13 -080072nn::GeneralResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> convertExecutionResults(
73 const std::vector<OutputShape>& outputShapes, const Timing& timing) {
74 return std::make_pair(NN_TRY(nn::convert(outputShapes)), NN_TRY(nn::convert(timing)));
75}
76
77} // namespace
78
79Burst::MemoryCache::MemoryCache(std::shared_ptr<aidl_hal::IBurst> burst)
80 : kBurst(std::move(burst)) {}
81
82std::pair<int64_t, Burst::MemoryCache::SharedCleanup> Burst::MemoryCache::getOrCacheMemory(
83 const nn::SharedMemory& memory) {
84 std::lock_guard lock(mMutex);
85
86 // Get the cache payload or create it (with default values) if it does not exist.
87 auto& cachedPayload = mCache[memory];
88 {
89 const auto& [identifier, maybeCleaner] = cachedPayload;
90 // If cache payload already exists, reuse it.
91 if (auto cleaner = maybeCleaner.lock()) {
92 return std::make_pair(identifier, std::move(cleaner));
93 }
94 }
95
96 // If the code reaches this point, the cached payload either did not exist or expired prior to
97 // this call.
98
99 // Allocate a new identifier.
100 CHECK_LT(mUnusedIdentifier, std::numeric_limits<int64_t>::max());
101 const int64_t identifier = mUnusedIdentifier++;
102
103 // Create reference-counted self-cleaning cache object.
104 auto self = weak_from_this();
105 Task cleanup = [memory, identifier, maybeMemoryCache = std::move(self)] {
106 if (const auto memoryCache = maybeMemoryCache.lock()) {
107 memoryCache->tryFreeMemory(memory, identifier);
108 }
109 };
110 auto cleaner = std::make_shared<const Cleanup>(std::move(cleanup));
111
112 // Store the result in the cache and return it.
113 auto result = std::make_pair(identifier, std::move(cleaner));
114 cachedPayload = result;
115 return result;
116}
117
118std::optional<std::pair<int64_t, Burst::MemoryCache::SharedCleanup>>
119Burst::MemoryCache::getMemoryIfAvailable(const nn::SharedMemory& memory) {
120 std::lock_guard lock(mMutex);
121
122 // Get the existing cached entry if it exists.
123 const auto iter = mCache.find(memory);
124 if (iter != mCache.end()) {
125 const auto& [identifier, maybeCleaner] = iter->second;
126 if (auto cleaner = maybeCleaner.lock()) {
127 return std::make_pair(identifier, std::move(cleaner));
128 }
129 }
130
131 // If the code reaches this point, the cached payload did not exist or was actively being
132 // deleted.
133 return std::nullopt;
134}
135
136void Burst::MemoryCache::tryFreeMemory(const nn::SharedMemory& memory, int64_t identifier) {
137 {
138 std::lock_guard guard(mMutex);
139 // Remove the cached memory and payload if it is present but expired. Note that it may not
140 // be present or may not be expired because another thread may have removed or cached the
141 // same memory object before the current thread locked mMutex in tryFreeMemory.
142 const auto iter = mCache.find(memory);
143 if (iter != mCache.end()) {
144 if (std::get<WeakCleanup>(iter->second).expired()) {
145 mCache.erase(iter);
146 }
147 }
148 }
149 kBurst->releaseMemoryResource(identifier);
150}
151
152nn::GeneralResult<std::shared_ptr<const Burst>> Burst::create(
153 std::shared_ptr<aidl_hal::IBurst> burst) {
154 if (burst == nullptr) {
155 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
156 << "aidl_hal::utils::Burst::create must have non-null burst";
157 }
158
159 return std::make_shared<const Burst>(PrivateConstructorTag{}, std::move(burst));
160}
161
162Burst::Burst(PrivateConstructorTag /*tag*/, std::shared_ptr<aidl_hal::IBurst> burst)
163 : kBurst(std::move(burst)), kMemoryCache(std::make_shared<MemoryCache>(kBurst)) {
164 CHECK(kBurst != nullptr);
165}
166
167Burst::OptionalCacheHold Burst::cacheMemory(const nn::SharedMemory& memory) const {
168 auto [identifier, hold] = kMemoryCache->getOrCacheMemory(memory);
169 return hold;
170}
171
172nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::execute(
173 const nn::Request& request, nn::MeasureTiming measure,
174 const nn::OptionalTimePoint& deadline,
175 const nn::OptionalDuration& loopTimeoutDuration) const {
Michael Butler7a9d6092021-03-10 21:57:13 -0800176 // Ensure that request is ready for IPC.
177 std::optional<nn::Request> maybeRequestInShared;
Xusong Wang5f6bedb2021-03-03 16:20:37 -0800178 hal::utils::RequestRelocation relocation;
179 const nn::Request& requestInShared =
180 NN_TRY(hal::utils::makeExecutionFailure(hal::utils::convertRequestFromPointerToShared(
Xusong Wange3d0dad2021-05-07 14:13:22 -0700181 &request, nn::kDefaultRequestMemoryAlignment, nn::kDefaultRequestMemoryPadding,
182 &maybeRequestInShared, &relocation)));
Michael Butler7a9d6092021-03-10 21:57:13 -0800183
184 const auto aidlRequest = NN_TRY(hal::utils::makeExecutionFailure(convert(requestInShared)));
185 const auto aidlMeasure = NN_TRY(hal::utils::makeExecutionFailure(convert(measure)));
186 const auto aidlDeadline = NN_TRY(hal::utils::makeExecutionFailure(convert(deadline)));
187 const auto aidlLoopTimeoutDuration =
188 NN_TRY(hal::utils::makeExecutionFailure(convert(loopTimeoutDuration)));
189
190 std::vector<int64_t> memoryIdentifierTokens;
191 std::vector<OptionalCacheHold> holds;
Xusong Wangb2e80852021-03-23 15:07:10 -0700192 memoryIdentifierTokens.reserve(requestInShared.pools.size());
193 holds.reserve(requestInShared.pools.size());
194 for (const auto& memoryPool : requestInShared.pools) {
Michael Butler7a9d6092021-03-10 21:57:13 -0800195 if (const auto* memory = std::get_if<nn::SharedMemory>(&memoryPool)) {
196 if (auto cached = kMemoryCache->getMemoryIfAvailable(*memory)) {
197 auto& [identifier, hold] = *cached;
198 memoryIdentifierTokens.push_back(identifier);
199 holds.push_back(std::move(hold));
200 continue;
201 }
202 }
203 memoryIdentifierTokens.push_back(-1);
204 }
Xusong Wangb2e80852021-03-23 15:07:10 -0700205 CHECK_EQ(requestInShared.pools.size(), memoryIdentifierTokens.size());
206
207 return executeInternal(aidlRequest, memoryIdentifierTokens, aidlMeasure, aidlDeadline,
208 aidlLoopTimeoutDuration, relocation);
209}
210
211nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::executeInternal(
212 const Request& request, const std::vector<int64_t>& memoryIdentifierTokens, bool measure,
213 int64_t deadline, int64_t loopTimeoutDuration,
214 const hal::utils::RequestRelocation& relocation) const {
215 // Ensure that at most one execution is in flight at any given time.
216 const bool alreadyInFlight = mExecutionInFlight.test_and_set();
217 if (alreadyInFlight) {
218 return NN_ERROR() << "IBurst already has an execution in flight";
219 }
220 const auto guard = ::android::base::make_scope_guard([this] { mExecutionInFlight.clear(); });
Michael Butler7a9d6092021-03-10 21:57:13 -0800221
Xusong Wang5f6bedb2021-03-03 16:20:37 -0800222 if (relocation.input) {
223 relocation.input->flush();
224 }
225
Michael Butler7a9d6092021-03-10 21:57:13 -0800226 ExecutionResult executionResult;
Xusong Wangb2e80852021-03-23 15:07:10 -0700227 const auto ret = kBurst->executeSynchronously(request, memoryIdentifierTokens, measure,
228 deadline, loopTimeoutDuration, &executionResult);
Michael Butler7a9d6092021-03-10 21:57:13 -0800229 HANDLE_ASTATUS(ret) << "execute failed";
230 if (!executionResult.outputSufficientSize) {
231 auto canonicalOutputShapes =
232 nn::convert(executionResult.outputShapes).value_or(std::vector<nn::OutputShape>{});
233 return NN_ERROR(nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, std::move(canonicalOutputShapes))
234 << "execution failed with " << nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
235 }
236 auto [outputShapes, timing] = NN_TRY(hal::utils::makeExecutionFailure(
237 convertExecutionResults(executionResult.outputShapes, executionResult.timing)));
238
Xusong Wang5f6bedb2021-03-03 16:20:37 -0800239 if (relocation.output) {
240 relocation.output->flush();
241 }
Michael Butler7a9d6092021-03-10 21:57:13 -0800242 return std::make_pair(std::move(outputShapes), timing);
243}
244
Xusong Wangb2e80852021-03-23 15:07:10 -0700245nn::GeneralResult<nn::SharedExecution> Burst::createReusableExecution(
246 const nn::Request& request, nn::MeasureTiming measure,
247 const nn::OptionalDuration& loopTimeoutDuration) const {
248 // Ensure that request is ready for IPC.
249 std::optional<nn::Request> maybeRequestInShared;
250 hal::utils::RequestRelocation relocation;
251 const nn::Request& requestInShared = NN_TRY(hal::utils::convertRequestFromPointerToShared(
Xusong Wange3d0dad2021-05-07 14:13:22 -0700252 &request, nn::kDefaultRequestMemoryAlignment, nn::kDefaultRequestMemoryPadding,
253 &maybeRequestInShared, &relocation));
Xusong Wangb2e80852021-03-23 15:07:10 -0700254
255 auto aidlRequest = NN_TRY(convert(requestInShared));
256 const auto aidlMeasure = NN_TRY(convert(measure));
257 const auto aidlLoopTimeoutDuration = NN_TRY(convert(loopTimeoutDuration));
258
259 std::vector<int64_t> memoryIdentifierTokens;
260 std::vector<OptionalCacheHold> holds;
261 memoryIdentifierTokens.reserve(requestInShared.pools.size());
262 holds.reserve(requestInShared.pools.size());
263 for (const auto& memoryPool : requestInShared.pools) {
264 if (const auto* memory = std::get_if<nn::SharedMemory>(&memoryPool)) {
265 if (auto cached = kMemoryCache->getMemoryIfAvailable(*memory)) {
266 auto& [identifier, hold] = *cached;
267 memoryIdentifierTokens.push_back(identifier);
268 holds.push_back(std::move(hold));
269 continue;
270 }
271 }
272 memoryIdentifierTokens.push_back(-1);
273 }
274 CHECK_EQ(requestInShared.pools.size(), memoryIdentifierTokens.size());
275
276 return BurstExecution::create(shared_from_this(), std::move(aidlRequest),
277 std::move(memoryIdentifierTokens), aidlMeasure,
278 aidlLoopTimeoutDuration, std::move(relocation), std::move(holds));
279}
280
281nn::GeneralResult<std::shared_ptr<const BurstExecution>> BurstExecution::create(
282 std::shared_ptr<const Burst> burst, Request request,
283 std::vector<int64_t> memoryIdentifierTokens, bool measure, int64_t loopTimeoutDuration,
284 hal::utils::RequestRelocation relocation,
285 std::vector<Burst::OptionalCacheHold> cacheHolds) {
286 if (burst == nullptr) {
287 return NN_ERROR() << "aidl::utils::BurstExecution::create must have non-null burst";
288 }
289
290 return std::make_shared<const BurstExecution>(
291 PrivateConstructorTag{}, std::move(burst), std::move(request),
292 std::move(memoryIdentifierTokens), measure, loopTimeoutDuration, std::move(relocation),
293 std::move(cacheHolds));
294}
295
296BurstExecution::BurstExecution(PrivateConstructorTag /*tag*/, std::shared_ptr<const Burst> burst,
297 Request request, std::vector<int64_t> memoryIdentifierTokens,
298 bool measure, int64_t loopTimeoutDuration,
299 hal::utils::RequestRelocation relocation,
300 std::vector<Burst::OptionalCacheHold> cacheHolds)
301 : kBurst(std::move(burst)),
302 kRequest(std::move(request)),
303 kMemoryIdentifierTokens(std::move(memoryIdentifierTokens)),
304 kMeasure(measure),
305 kLoopTimeoutDuration(loopTimeoutDuration),
306 kRelocation(std::move(relocation)),
307 kCacheHolds(std::move(cacheHolds)) {}
308
309nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> BurstExecution::compute(
310 const nn::OptionalTimePoint& deadline) const {
311 const auto aidlDeadline = NN_TRY(hal::utils::makeExecutionFailure(convert(deadline)));
312 return kBurst->executeInternal(kRequest, kMemoryIdentifierTokens, kMeasure, aidlDeadline,
313 kLoopTimeoutDuration, kRelocation);
314}
315
316nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
317BurstExecution::computeFenced(const std::vector<nn::SyncFence>& /*waitFor*/,
318 const nn::OptionalTimePoint& /*deadline*/,
319 const nn::OptionalDuration& /*timeoutDurationAfterFence*/) const {
320 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
321 << "IExecution::computeFenced is not supported on burst object";
322}
323
Michael Butler7a9d6092021-03-10 21:57:13 -0800324} // namespace aidl::android::hardware::neuralnetworks::utils