blob: b20f6ae8e1ee1c27ffa64211afcf1c749239ae97 [file] [log] [blame]
Michael Butler7a9d6092021-03-10 21:57:13 -08001/*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "Burst.h"
18
19#include "Conversions.h"
20#include "Utils.h"
21
22#include <android-base/logging.h>
23#include <android/binder_auto_utils.h>
24#include <nnapi/IBurst.h>
25#include <nnapi/Result.h>
26#include <nnapi/TypeUtils.h>
27#include <nnapi/Types.h>
28#include <nnapi/hal/HandleError.h>
29
30#include <memory>
31#include <mutex>
32#include <optional>
33#include <utility>
34
35namespace aidl::android::hardware::neuralnetworks::utils {
36namespace {
37
38nn::GeneralResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> convertExecutionResults(
39 const std::vector<OutputShape>& outputShapes, const Timing& timing) {
40 return std::make_pair(NN_TRY(nn::convert(outputShapes)), NN_TRY(nn::convert(timing)));
41}
42
43} // namespace
44
45Burst::MemoryCache::MemoryCache(std::shared_ptr<aidl_hal::IBurst> burst)
46 : kBurst(std::move(burst)) {}
47
48std::pair<int64_t, Burst::MemoryCache::SharedCleanup> Burst::MemoryCache::getOrCacheMemory(
49 const nn::SharedMemory& memory) {
50 std::lock_guard lock(mMutex);
51
52 // Get the cache payload or create it (with default values) if it does not exist.
53 auto& cachedPayload = mCache[memory];
54 {
55 const auto& [identifier, maybeCleaner] = cachedPayload;
56 // If cache payload already exists, reuse it.
57 if (auto cleaner = maybeCleaner.lock()) {
58 return std::make_pair(identifier, std::move(cleaner));
59 }
60 }
61
62 // If the code reaches this point, the cached payload either did not exist or expired prior to
63 // this call.
64
65 // Allocate a new identifier.
66 CHECK_LT(mUnusedIdentifier, std::numeric_limits<int64_t>::max());
67 const int64_t identifier = mUnusedIdentifier++;
68
69 // Create reference-counted self-cleaning cache object.
70 auto self = weak_from_this();
71 Task cleanup = [memory, identifier, maybeMemoryCache = std::move(self)] {
72 if (const auto memoryCache = maybeMemoryCache.lock()) {
73 memoryCache->tryFreeMemory(memory, identifier);
74 }
75 };
76 auto cleaner = std::make_shared<const Cleanup>(std::move(cleanup));
77
78 // Store the result in the cache and return it.
79 auto result = std::make_pair(identifier, std::move(cleaner));
80 cachedPayload = result;
81 return result;
82}
83
84std::optional<std::pair<int64_t, Burst::MemoryCache::SharedCleanup>>
85Burst::MemoryCache::getMemoryIfAvailable(const nn::SharedMemory& memory) {
86 std::lock_guard lock(mMutex);
87
88 // Get the existing cached entry if it exists.
89 const auto iter = mCache.find(memory);
90 if (iter != mCache.end()) {
91 const auto& [identifier, maybeCleaner] = iter->second;
92 if (auto cleaner = maybeCleaner.lock()) {
93 return std::make_pair(identifier, std::move(cleaner));
94 }
95 }
96
97 // If the code reaches this point, the cached payload did not exist or was actively being
98 // deleted.
99 return std::nullopt;
100}
101
102void Burst::MemoryCache::tryFreeMemory(const nn::SharedMemory& memory, int64_t identifier) {
103 {
104 std::lock_guard guard(mMutex);
105 // Remove the cached memory and payload if it is present but expired. Note that it may not
106 // be present or may not be expired because another thread may have removed or cached the
107 // same memory object before the current thread locked mMutex in tryFreeMemory.
108 const auto iter = mCache.find(memory);
109 if (iter != mCache.end()) {
110 if (std::get<WeakCleanup>(iter->second).expired()) {
111 mCache.erase(iter);
112 }
113 }
114 }
115 kBurst->releaseMemoryResource(identifier);
116}
117
118nn::GeneralResult<std::shared_ptr<const Burst>> Burst::create(
119 std::shared_ptr<aidl_hal::IBurst> burst) {
120 if (burst == nullptr) {
121 return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
122 << "aidl_hal::utils::Burst::create must have non-null burst";
123 }
124
125 return std::make_shared<const Burst>(PrivateConstructorTag{}, std::move(burst));
126}
127
128Burst::Burst(PrivateConstructorTag /*tag*/, std::shared_ptr<aidl_hal::IBurst> burst)
129 : kBurst(std::move(burst)), kMemoryCache(std::make_shared<MemoryCache>(kBurst)) {
130 CHECK(kBurst != nullptr);
131}
132
133Burst::OptionalCacheHold Burst::cacheMemory(const nn::SharedMemory& memory) const {
134 auto [identifier, hold] = kMemoryCache->getOrCacheMemory(memory);
135 return hold;
136}
137
138nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::execute(
139 const nn::Request& request, nn::MeasureTiming measure,
140 const nn::OptionalTimePoint& deadline,
141 const nn::OptionalDuration& loopTimeoutDuration) const {
142 // Ensure that at most one execution is in flight at any given time.
143 const bool alreadyInFlight = mExecutionInFlight.test_and_set();
144 if (alreadyInFlight) {
145 return NN_ERROR() << "IBurst already has an execution in flight";
146 }
147 const auto guard = ::android::base::make_scope_guard([this] { mExecutionInFlight.clear(); });
148
149 // Ensure that request is ready for IPC.
150 std::optional<nn::Request> maybeRequestInShared;
Xusong Wang5f6bedb2021-03-03 16:20:37 -0800151 hal::utils::RequestRelocation relocation;
152 const nn::Request& requestInShared =
153 NN_TRY(hal::utils::makeExecutionFailure(hal::utils::convertRequestFromPointerToShared(
154 &request, &maybeRequestInShared, &relocation)));
Michael Butler7a9d6092021-03-10 21:57:13 -0800155
156 const auto aidlRequest = NN_TRY(hal::utils::makeExecutionFailure(convert(requestInShared)));
157 const auto aidlMeasure = NN_TRY(hal::utils::makeExecutionFailure(convert(measure)));
158 const auto aidlDeadline = NN_TRY(hal::utils::makeExecutionFailure(convert(deadline)));
159 const auto aidlLoopTimeoutDuration =
160 NN_TRY(hal::utils::makeExecutionFailure(convert(loopTimeoutDuration)));
161
162 std::vector<int64_t> memoryIdentifierTokens;
163 std::vector<OptionalCacheHold> holds;
164 memoryIdentifierTokens.reserve(request.pools.size());
165 holds.reserve(request.pools.size());
166 for (const auto& memoryPool : request.pools) {
167 if (const auto* memory = std::get_if<nn::SharedMemory>(&memoryPool)) {
168 if (auto cached = kMemoryCache->getMemoryIfAvailable(*memory)) {
169 auto& [identifier, hold] = *cached;
170 memoryIdentifierTokens.push_back(identifier);
171 holds.push_back(std::move(hold));
172 continue;
173 }
174 }
175 memoryIdentifierTokens.push_back(-1);
176 }
177 CHECK_EQ(request.pools.size(), memoryIdentifierTokens.size());
178
Xusong Wang5f6bedb2021-03-03 16:20:37 -0800179 if (relocation.input) {
180 relocation.input->flush();
181 }
182
Michael Butler7a9d6092021-03-10 21:57:13 -0800183 ExecutionResult executionResult;
184 const auto ret =
185 kBurst->executeSynchronously(aidlRequest, memoryIdentifierTokens, aidlMeasure,
186 aidlDeadline, aidlLoopTimeoutDuration, &executionResult);
187 HANDLE_ASTATUS(ret) << "execute failed";
188 if (!executionResult.outputSufficientSize) {
189 auto canonicalOutputShapes =
190 nn::convert(executionResult.outputShapes).value_or(std::vector<nn::OutputShape>{});
191 return NN_ERROR(nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, std::move(canonicalOutputShapes))
192 << "execution failed with " << nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
193 }
194 auto [outputShapes, timing] = NN_TRY(hal::utils::makeExecutionFailure(
195 convertExecutionResults(executionResult.outputShapes, executionResult.timing)));
196
Xusong Wang5f6bedb2021-03-03 16:20:37 -0800197 if (relocation.output) {
198 relocation.output->flush();
199 }
Michael Butler7a9d6092021-03-10 21:57:13 -0800200 return std::make_pair(std::move(outputShapes), timing);
201}
202
203} // namespace aidl::android::hardware::neuralnetworks::utils