blob: fb00b264e34dc1b872932d62baf51a3ec1bca1bb [file] [log] [blame]
Michael Butler7a9d6092021-03-10 21:57:13 -08001/*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "Burst.h"
18
19#include "Conversions.h"
20#include "Utils.h"
21
22#include <android-base/logging.h>
23#include <android/binder_auto_utils.h>
24#include <nnapi/IBurst.h>
Xusong Wangb2e80852021-03-23 15:07:10 -070025#include <nnapi/IExecution.h>
Michael Butler7a9d6092021-03-10 21:57:13 -080026#include <nnapi/Result.h>
27#include <nnapi/TypeUtils.h>
28#include <nnapi/Types.h>
Michael Butler7a9d6092021-03-10 21:57:13 -080029
30#include <memory>
31#include <mutex>
32#include <optional>
33#include <utility>
34
35namespace aidl::android::hardware::neuralnetworks::utils {
36namespace {
37
Xusong Wangb2e80852021-03-23 15:07:10 -070038class BurstExecution final : public nn::IExecution,
39 public std::enable_shared_from_this<BurstExecution> {
40 struct PrivateConstructorTag {};
41
42 public:
43 static nn::GeneralResult<std::shared_ptr<const BurstExecution>> create(
44 std::shared_ptr<const Burst> burst, Request request,
45 std::vector<int64_t> memoryIdentifierTokens, bool measure, int64_t loopTimeoutDuration,
46 hal::utils::RequestRelocation relocation,
47 std::vector<Burst::OptionalCacheHold> cacheHolds);
48
49 BurstExecution(PrivateConstructorTag tag, std::shared_ptr<const Burst> burst, Request request,
50 std::vector<int64_t> memoryIdentifierTokens, bool measure,
51 int64_t loopTimeoutDuration, hal::utils::RequestRelocation relocation,
52 std::vector<Burst::OptionalCacheHold> cacheHolds);
53
54 nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> compute(
55 const nn::OptionalTimePoint& deadline) const override;
56
57 nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> computeFenced(
58 const std::vector<nn::SyncFence>& waitFor, const nn::OptionalTimePoint& deadline,
59 const nn::OptionalDuration& timeoutDurationAfterFence) const override;
60
61 private:
62 const std::shared_ptr<const Burst> kBurst;
63 const Request kRequest;
Xusong Wang5e045952021-05-18 13:54:11 -070064 const std::vector<int64_t> kMemoryIdentifierTokens;
Xusong Wangb2e80852021-03-23 15:07:10 -070065 const bool kMeasure;
66 const int64_t kLoopTimeoutDuration;
67 const hal::utils::RequestRelocation kRelocation;
68 const std::vector<Burst::OptionalCacheHold> kCacheHolds;
69};
70
Michael Butler7a9d6092021-03-10 21:57:13 -080071nn::GeneralResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> convertExecutionResults(
72 const std::vector<OutputShape>& outputShapes, const Timing& timing) {
73 return std::make_pair(NN_TRY(nn::convert(outputShapes)), NN_TRY(nn::convert(timing)));
74}
75
76} // namespace
77
78Burst::MemoryCache::MemoryCache(std::shared_ptr<aidl_hal::IBurst> burst)
79 : kBurst(std::move(burst)) {}
80
81std::pair<int64_t, Burst::MemoryCache::SharedCleanup> Burst::MemoryCache::getOrCacheMemory(
82 const nn::SharedMemory& memory) {
83 std::lock_guard lock(mMutex);
84
85 // Get the cache payload or create it (with default values) if it does not exist.
86 auto& cachedPayload = mCache[memory];
87 {
88 const auto& [identifier, maybeCleaner] = cachedPayload;
89 // If cache payload already exists, reuse it.
90 if (auto cleaner = maybeCleaner.lock()) {
91 return std::make_pair(identifier, std::move(cleaner));
92 }
93 }
94
95 // If the code reaches this point, the cached payload either did not exist or expired prior to
96 // this call.
97
98 // Allocate a new identifier.
99 CHECK_LT(mUnusedIdentifier, std::numeric_limits<int64_t>::max());
100 const int64_t identifier = mUnusedIdentifier++;
101
102 // Create reference-counted self-cleaning cache object.
103 auto self = weak_from_this();
104 Task cleanup = [memory, identifier, maybeMemoryCache = std::move(self)] {
105 if (const auto memoryCache = maybeMemoryCache.lock()) {
106 memoryCache->tryFreeMemory(memory, identifier);
107 }
108 };
109 auto cleaner = std::make_shared<const Cleanup>(std::move(cleanup));
110
111 // Store the result in the cache and return it.
112 auto result = std::make_pair(identifier, std::move(cleaner));
113 cachedPayload = result;
114 return result;
115}
116
117std::optional<std::pair<int64_t, Burst::MemoryCache::SharedCleanup>>
118Burst::MemoryCache::getMemoryIfAvailable(const nn::SharedMemory& memory) {
119 std::lock_guard lock(mMutex);
120
121 // Get the existing cached entry if it exists.
122 const auto iter = mCache.find(memory);
123 if (iter != mCache.end()) {
124 const auto& [identifier, maybeCleaner] = iter->second;
125 if (auto cleaner = maybeCleaner.lock()) {
126 return std::make_pair(identifier, std::move(cleaner));
127 }
128 }
129
130 // If the code reaches this point, the cached payload did not exist or was actively being
131 // deleted.
132 return std::nullopt;
133}
134
135void Burst::MemoryCache::tryFreeMemory(const nn::SharedMemory& memory, int64_t identifier) {
136 {
137 std::lock_guard guard(mMutex);
138 // Remove the cached memory and payload if it is present but expired. Note that it may not
139 // be present or may not be expired because another thread may have removed or cached the
140 // same memory object before the current thread locked mMutex in tryFreeMemory.
141 const auto iter = mCache.find(memory);
142 if (iter != mCache.end()) {
143 if (std::get<WeakCleanup>(iter->second).expired()) {
144 mCache.erase(iter);
145 }
146 }
147 }
148 kBurst->releaseMemoryResource(identifier);
149}
150
151nn::GeneralResult<std::shared_ptr<const Burst>> Burst::create(
152 std::shared_ptr<aidl_hal::IBurst> burst) {
153 if (burst == nullptr) {
154 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
155 << "aidl_hal::utils::Burst::create must have non-null burst";
156 }
157
158 return std::make_shared<const Burst>(PrivateConstructorTag{}, std::move(burst));
159}
160
161Burst::Burst(PrivateConstructorTag /*tag*/, std::shared_ptr<aidl_hal::IBurst> burst)
162 : kBurst(std::move(burst)), kMemoryCache(std::make_shared<MemoryCache>(kBurst)) {
163 CHECK(kBurst != nullptr);
164}
165
166Burst::OptionalCacheHold Burst::cacheMemory(const nn::SharedMemory& memory) const {
167 auto [identifier, hold] = kMemoryCache->getOrCacheMemory(memory);
168 return hold;
169}
170
171nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::execute(
172 const nn::Request& request, nn::MeasureTiming measure,
173 const nn::OptionalTimePoint& deadline,
174 const nn::OptionalDuration& loopTimeoutDuration) const {
Michael Butler7a9d6092021-03-10 21:57:13 -0800175 // Ensure that request is ready for IPC.
176 std::optional<nn::Request> maybeRequestInShared;
Xusong Wang5f6bedb2021-03-03 16:20:37 -0800177 hal::utils::RequestRelocation relocation;
Michael Butlerff9a5a52021-10-15 16:23:20 -0700178 const nn::Request& requestInShared = NN_TRY(hal::utils::convertRequestFromPointerToShared(
179 &request, nn::kDefaultRequestMemoryAlignment, nn::kDefaultRequestMemoryPadding,
180 &maybeRequestInShared, &relocation));
Michael Butler7a9d6092021-03-10 21:57:13 -0800181
Michael Butlerff9a5a52021-10-15 16:23:20 -0700182 const auto aidlRequest = NN_TRY(convert(requestInShared));
183 const auto aidlMeasure = NN_TRY(convert(measure));
184 const auto aidlDeadline = NN_TRY(convert(deadline));
185 const auto aidlLoopTimeoutDuration = NN_TRY(convert(loopTimeoutDuration));
Michael Butler7a9d6092021-03-10 21:57:13 -0800186
187 std::vector<int64_t> memoryIdentifierTokens;
188 std::vector<OptionalCacheHold> holds;
Xusong Wangb2e80852021-03-23 15:07:10 -0700189 memoryIdentifierTokens.reserve(requestInShared.pools.size());
190 holds.reserve(requestInShared.pools.size());
191 for (const auto& memoryPool : requestInShared.pools) {
Michael Butler7a9d6092021-03-10 21:57:13 -0800192 if (const auto* memory = std::get_if<nn::SharedMemory>(&memoryPool)) {
193 if (auto cached = kMemoryCache->getMemoryIfAvailable(*memory)) {
194 auto& [identifier, hold] = *cached;
195 memoryIdentifierTokens.push_back(identifier);
196 holds.push_back(std::move(hold));
197 continue;
198 }
199 }
200 memoryIdentifierTokens.push_back(-1);
201 }
Xusong Wangb2e80852021-03-23 15:07:10 -0700202 CHECK_EQ(requestInShared.pools.size(), memoryIdentifierTokens.size());
203
204 return executeInternal(aidlRequest, memoryIdentifierTokens, aidlMeasure, aidlDeadline,
205 aidlLoopTimeoutDuration, relocation);
206}
207
208nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::executeInternal(
209 const Request& request, const std::vector<int64_t>& memoryIdentifierTokens, bool measure,
210 int64_t deadline, int64_t loopTimeoutDuration,
211 const hal::utils::RequestRelocation& relocation) const {
212 // Ensure that at most one execution is in flight at any given time.
213 const bool alreadyInFlight = mExecutionInFlight.test_and_set();
214 if (alreadyInFlight) {
215 return NN_ERROR() << "IBurst already has an execution in flight";
216 }
217 const auto guard = ::android::base::make_scope_guard([this] { mExecutionInFlight.clear(); });
Michael Butler7a9d6092021-03-10 21:57:13 -0800218
Xusong Wang5f6bedb2021-03-03 16:20:37 -0800219 if (relocation.input) {
220 relocation.input->flush();
221 }
222
Michael Butler7a9d6092021-03-10 21:57:13 -0800223 ExecutionResult executionResult;
Xusong Wangb2e80852021-03-23 15:07:10 -0700224 const auto ret = kBurst->executeSynchronously(request, memoryIdentifierTokens, measure,
225 deadline, loopTimeoutDuration, &executionResult);
Michael Butler7a9d6092021-03-10 21:57:13 -0800226 HANDLE_ASTATUS(ret) << "execute failed";
227 if (!executionResult.outputSufficientSize) {
228 auto canonicalOutputShapes =
229 nn::convert(executionResult.outputShapes).value_or(std::vector<nn::OutputShape>{});
230 return NN_ERROR(nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, std::move(canonicalOutputShapes))
231 << "execution failed with " << nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
232 }
Michael Butlerff9a5a52021-10-15 16:23:20 -0700233 auto [outputShapes, timing] =
234 NN_TRY(convertExecutionResults(executionResult.outputShapes, executionResult.timing));
Michael Butler7a9d6092021-03-10 21:57:13 -0800235
Xusong Wang5f6bedb2021-03-03 16:20:37 -0800236 if (relocation.output) {
237 relocation.output->flush();
238 }
Michael Butler7a9d6092021-03-10 21:57:13 -0800239 return std::make_pair(std::move(outputShapes), timing);
240}
241
Xusong Wangb2e80852021-03-23 15:07:10 -0700242nn::GeneralResult<nn::SharedExecution> Burst::createReusableExecution(
243 const nn::Request& request, nn::MeasureTiming measure,
244 const nn::OptionalDuration& loopTimeoutDuration) const {
245 // Ensure that request is ready for IPC.
246 std::optional<nn::Request> maybeRequestInShared;
247 hal::utils::RequestRelocation relocation;
248 const nn::Request& requestInShared = NN_TRY(hal::utils::convertRequestFromPointerToShared(
Xusong Wange3d0dad2021-05-07 14:13:22 -0700249 &request, nn::kDefaultRequestMemoryAlignment, nn::kDefaultRequestMemoryPadding,
250 &maybeRequestInShared, &relocation));
Xusong Wangb2e80852021-03-23 15:07:10 -0700251
252 auto aidlRequest = NN_TRY(convert(requestInShared));
253 const auto aidlMeasure = NN_TRY(convert(measure));
254 const auto aidlLoopTimeoutDuration = NN_TRY(convert(loopTimeoutDuration));
255
256 std::vector<int64_t> memoryIdentifierTokens;
257 std::vector<OptionalCacheHold> holds;
258 memoryIdentifierTokens.reserve(requestInShared.pools.size());
259 holds.reserve(requestInShared.pools.size());
260 for (const auto& memoryPool : requestInShared.pools) {
261 if (const auto* memory = std::get_if<nn::SharedMemory>(&memoryPool)) {
262 if (auto cached = kMemoryCache->getMemoryIfAvailable(*memory)) {
263 auto& [identifier, hold] = *cached;
264 memoryIdentifierTokens.push_back(identifier);
265 holds.push_back(std::move(hold));
266 continue;
267 }
268 }
269 memoryIdentifierTokens.push_back(-1);
270 }
271 CHECK_EQ(requestInShared.pools.size(), memoryIdentifierTokens.size());
272
273 return BurstExecution::create(shared_from_this(), std::move(aidlRequest),
274 std::move(memoryIdentifierTokens), aidlMeasure,
275 aidlLoopTimeoutDuration, std::move(relocation), std::move(holds));
276}
277
278nn::GeneralResult<std::shared_ptr<const BurstExecution>> BurstExecution::create(
279 std::shared_ptr<const Burst> burst, Request request,
280 std::vector<int64_t> memoryIdentifierTokens, bool measure, int64_t loopTimeoutDuration,
281 hal::utils::RequestRelocation relocation,
282 std::vector<Burst::OptionalCacheHold> cacheHolds) {
283 if (burst == nullptr) {
284 return NN_ERROR() << "aidl::utils::BurstExecution::create must have non-null burst";
285 }
286
287 return std::make_shared<const BurstExecution>(
288 PrivateConstructorTag{}, std::move(burst), std::move(request),
289 std::move(memoryIdentifierTokens), measure, loopTimeoutDuration, std::move(relocation),
290 std::move(cacheHolds));
291}
292
293BurstExecution::BurstExecution(PrivateConstructorTag /*tag*/, std::shared_ptr<const Burst> burst,
294 Request request, std::vector<int64_t> memoryIdentifierTokens,
295 bool measure, int64_t loopTimeoutDuration,
296 hal::utils::RequestRelocation relocation,
297 std::vector<Burst::OptionalCacheHold> cacheHolds)
298 : kBurst(std::move(burst)),
299 kRequest(std::move(request)),
300 kMemoryIdentifierTokens(std::move(memoryIdentifierTokens)),
301 kMeasure(measure),
302 kLoopTimeoutDuration(loopTimeoutDuration),
303 kRelocation(std::move(relocation)),
304 kCacheHolds(std::move(cacheHolds)) {}
305
306nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> BurstExecution::compute(
307 const nn::OptionalTimePoint& deadline) const {
Michael Butlerff9a5a52021-10-15 16:23:20 -0700308 const auto aidlDeadline = NN_TRY(convert(deadline));
Xusong Wangb2e80852021-03-23 15:07:10 -0700309 return kBurst->executeInternal(kRequest, kMemoryIdentifierTokens, kMeasure, aidlDeadline,
310 kLoopTimeoutDuration, kRelocation);
311}
312
313nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
314BurstExecution::computeFenced(const std::vector<nn::SyncFence>& /*waitFor*/,
315 const nn::OptionalTimePoint& /*deadline*/,
316 const nn::OptionalDuration& /*timeoutDurationAfterFence*/) const {
317 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
318 << "IExecution::computeFenced is not supported on burst object";
319}
320
Michael Butler7a9d6092021-03-10 21:57:13 -0800321} // namespace aidl::android::hardware::neuralnetworks::utils