Re-organize NNAPI Burst utility classes
This change:
* Renames ExecutionBurstController to Burst in 1.2/utils
* Renames ExecutionBurstUtils to BurstUtils in 1.2/utils
* Renames ExecutionBurstServer to Burst in common/adapter
Bug: N/A
Test: mma
Change-Id: Ibd460229887c8c9cd23ebc6ee61da37c7c820288
diff --git a/neuralnetworks/1.2/utils/src/Burst.cpp b/neuralnetworks/1.2/utils/src/Burst.cpp
new file mode 100644
index 0000000..e0a23f1
--- /dev/null
+++ b/neuralnetworks/1.2/utils/src/Burst.cpp
@@ -0,0 +1,473 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Burst.h"
+#include "BurstUtils.h"
+
+#include <android-base/logging.h>
+#include <android-base/scopeguard.h>
+#include <android-base/thread_annotations.h>
+#include <nnapi/IBurst.h>
+#include <nnapi/IPreparedModel.h>
+#include <nnapi/Result.h>
+#include <nnapi/TypeUtils.h>
+#include <nnapi/Types.h>
+#include <nnapi/Validation.h>
+#include <nnapi/hal/1.0/Conversions.h>
+#include <nnapi/hal/1.0/HandleError.h>
+#include <nnapi/hal/1.0/ProtectCallback.h>
+#include <nnapi/hal/CommonUtils.h>
+#include <nnapi/hal/TransferValue.h>
+
+#include <algorithm>
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <string>
+#include <thread>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "Callbacks.h"
+#include "Conversions.h"
+#include "Tracing.h"
+#include "Utils.h"
+
+namespace android::hardware::neuralnetworks::V1_2::utils {
+namespace {
+
+class BurstExecution final : public nn::IExecution,
+ public std::enable_shared_from_this<BurstExecution> {
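+    // Tag type that limits construction to the create() factory function while keeping the
+    // constructor public so that std::make_shared can still be used.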
+ struct PrivateConstructorTag {};
+
+ public:
+ static nn::GeneralResult<std::shared_ptr<const BurstExecution>> create(
+ std::shared_ptr<const Burst> controller, std::vector<FmqRequestDatum> request,
+ hal::utils::RequestRelocation relocation,
+ std::vector<Burst::OptionalCacheHold> cacheHolds);
+
+ BurstExecution(PrivateConstructorTag tag, std::shared_ptr<const Burst> controller,
+ std::vector<FmqRequestDatum> request, hal::utils::RequestRelocation relocation,
+ std::vector<Burst::OptionalCacheHold> cacheHolds);
+
+ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> compute(
+ const nn::OptionalTimePoint& deadline) const override;
+
+ nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> computeFenced(
+ const std::vector<nn::SyncFence>& waitFor, const nn::OptionalTimePoint& deadline,
+ const nn::OptionalDuration& timeoutDurationAfterFence) const override;
+
+ private:
+ const std::shared_ptr<const Burst> kController;
+ const std::vector<FmqRequestDatum> kRequest;
+ const hal::utils::RequestRelocation kRelocation;
+ const std::vector<Burst::OptionalCacheHold> kCacheHolds;
+};
+
+nn::GeneralResult<sp<IBurstContext>> executionBurstResultCallback(
+ V1_0::ErrorStatus status, const sp<IBurstContext>& burstContext) {
+ HANDLE_STATUS_HIDL(status) << "IPreparedModel::configureExecutionBurst failed with status "
+ << toString(status);
+ if (burstContext == nullptr) {
+ return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
+ << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
+ }
+ return burstContext;
+}
+
+nn::GeneralResult<hidl_vec<hidl_memory>> getMemoriesHelper(
+ const hidl_vec<int32_t>& slots, const std::shared_ptr<Burst::MemoryCache>& memoryCache) {
+ hidl_vec<hidl_memory> memories(slots.size());
+ for (size_t i = 0; i < slots.size(); ++i) {
+ const int32_t slot = slots[i];
+ const auto memory = NN_TRY(memoryCache->getMemory(slot));
+ memories[i] = NN_TRY(V1_0::utils::unvalidatedConvert(memory));
+ if (!memories[i].valid()) {
+ return NN_ERROR() << "memory at slot " << slot << " is invalid";
+ }
+ }
+ return memories;
+}
+
+} // namespace
+
+// MemoryCache methods
+
+Burst::MemoryCache::MemoryCache() {
+ constexpr size_t kPreallocatedCount = 1024;
+ std::vector<int32_t> freeSlotsSpace;
+ freeSlotsSpace.reserve(kPreallocatedCount);
+ mFreeSlots = std::stack<int32_t, std::vector<int32_t>>(std::move(freeSlotsSpace));
+ mMemoryCache.reserve(kPreallocatedCount);
+ mCacheCleaner.reserve(kPreallocatedCount);
+}
+
+void Burst::MemoryCache::setBurstContext(sp<IBurstContext> burstContext) {
+ std::lock_guard guard(mMutex);
+ mBurstContext = std::move(burstContext);
+}
+
+std::pair<int32_t, Burst::MemoryCache::SharedCleanup> Burst::MemoryCache::cacheMemory(
+ const nn::SharedMemory& memory) {
+ std::unique_lock lock(mMutex);
+ base::ScopedLockAssertion lockAssert(mMutex);
+
+ // Use existing cache entry if (1) the Memory object is in the cache and (2) the cache entry is
+ // not currently being freed.
+ auto iter = mMemoryIdToSlot.find(memory);
+ while (iter != mMemoryIdToSlot.end()) {
+ const int32_t slot = iter->second;
+ if (auto cleaner = mCacheCleaner.at(slot).lock()) {
+ return std::make_pair(slot, std::move(cleaner));
+ }
+
+ // If the code reaches this point, the Memory object was in the cache, but is currently
+ // being destroyed. This code waits until the cache entry has been freed, then loops to
+ // ensure the cache entry has been freed or has been made present by another thread.
+ mCond.wait(lock);
+ iter = mMemoryIdToSlot.find(memory);
+ }
+
+ // Allocate a new cache entry.
+ const int32_t slot = allocateSlotLocked();
+ mMemoryIdToSlot[memory] = slot;
+ mMemoryCache[slot] = memory;
+
+ // Create reference-counted self-cleaning cache object.
+ auto self = weak_from_this();
+ Task cleanup = [memory, memoryCache = std::move(self)] {
+ if (const auto lock = memoryCache.lock()) {
+ lock->freeMemory(memory);
+ }
+ };
+ auto cleaner = std::make_shared<const Cleanup>(std::move(cleanup));
+ mCacheCleaner[slot] = cleaner;
+
+ return std::make_pair(slot, std::move(cleaner));
+}
+
+nn::GeneralResult<nn::SharedMemory> Burst::MemoryCache::getMemory(int32_t slot) {
+ std::lock_guard guard(mMutex);
+ if (slot < 0 || static_cast<size_t>(slot) >= mMemoryCache.size()) {
+ return NN_ERROR() << "Invalid slot: " << slot << " vs " << mMemoryCache.size();
+ }
+ return mMemoryCache[slot];
+}
+
+void Burst::MemoryCache::freeMemory(const nn::SharedMemory& memory) {
+ {
+ std::lock_guard guard(mMutex);
+ const int32_t slot = mMemoryIdToSlot.at(memory);
+ if (mBurstContext) {
+ const auto ret = mBurstContext->freeMemory(slot);
+ if (!ret.isOk()) {
+                LOG(ERROR) << "IBurstContext::freeMemory failed: " << ret.description();
+ }
+ }
+ mMemoryIdToSlot.erase(memory);
+ mMemoryCache[slot] = {};
+ mCacheCleaner[slot].reset();
+ mFreeSlots.push(slot);
+ }
+ mCond.notify_all();
+}
+
+int32_t Burst::MemoryCache::allocateSlotLocked() {
+ constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max();
+
+ // If there is a free slot, use it.
+ if (!mFreeSlots.empty()) {
+ const int32_t slot = mFreeSlots.top();
+ mFreeSlots.pop();
+ return slot;
+ }
+
+ // Use a slot for the first time.
+ CHECK_LT(mMemoryCache.size(), kMaxNumberOfSlots) << "Exceeded maximum number of slots!";
+ const int32_t slot = static_cast<int32_t>(mMemoryCache.size());
+ mMemoryCache.emplace_back();
+ mCacheCleaner.emplace_back();
+
+ return slot;
+}
+
+// ExecutionBurstCallback methods
+
+Burst::ExecutionBurstCallback::ExecutionBurstCallback(
+ const std::shared_ptr<MemoryCache>& memoryCache)
+ : kMemoryCache(memoryCache) {
+ CHECK(memoryCache != nullptr);
+}
+
+Return<void> Burst::ExecutionBurstCallback::getMemories(const hidl_vec<int32_t>& slots,
+ getMemories_cb cb) {
+ const auto memoryCache = kMemoryCache.lock();
+ if (memoryCache == nullptr) {
+ LOG(ERROR) << "Burst::ExecutionBurstCallback::getMemories called after the MemoryCache has "
+ "been freed";
+ cb(V1_0::ErrorStatus::GENERAL_FAILURE, {});
+ return Void();
+ }
+
+ const auto maybeMemories = getMemoriesHelper(slots, memoryCache);
+ if (!maybeMemories.has_value()) {
+ const auto& [message, code] = maybeMemories.error();
+ LOG(ERROR) << "Burst::ExecutionBurstCallback::getMemories failed with " << code << ": "
+ << message;
+ cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
+ return Void();
+ }
+
+ cb(V1_0::ErrorStatus::NONE, maybeMemories.value());
+ return Void();
+}
+
+// Burst methods
+
+nn::GeneralResult<std::shared_ptr<const Burst>> Burst::create(
+ nn::SharedPreparedModel preparedModel, const sp<V1_2::IPreparedModel>& hidlPreparedModel,
+ std::chrono::microseconds pollingTimeWindow) {
+ // check inputs
+ if (preparedModel == nullptr || hidlPreparedModel == nullptr) {
+ return NN_ERROR() << "Burst::create passed a nullptr";
+ }
+
+ // create FMQ objects
+ auto [requestChannelSender, requestChannelDescriptor] =
+ NN_TRY(RequestChannelSender::create(kExecutionBurstChannelLength));
+ auto [resultChannelReceiver, resultChannelDescriptor] =
+ NN_TRY(ResultChannelReceiver::create(kExecutionBurstChannelLength, pollingTimeWindow));
+
+ // check FMQ objects
+ CHECK(requestChannelSender != nullptr);
+ CHECK(requestChannelDescriptor != nullptr);
+ CHECK(resultChannelReceiver != nullptr);
+ CHECK(resultChannelDescriptor != nullptr);
+
+ // create memory cache
+ auto memoryCache = std::make_shared<MemoryCache>();
+
+ // create callback object
+ auto burstCallback = sp<ExecutionBurstCallback>::make(memoryCache);
+ auto cb = hal::utils::CallbackValue(executionBurstResultCallback);
+
+ // configure burst
+ const Return<void> ret = hidlPreparedModel->configureExecutionBurst(
+ burstCallback, *requestChannelDescriptor, *resultChannelDescriptor, cb);
+ HANDLE_TRANSPORT_FAILURE(ret);
+
+ auto burstContext = NN_TRY(cb.take());
+ memoryCache->setBurstContext(burstContext);
+
+ // create death handler object
+ auto deathHandler = NN_TRY(neuralnetworks::utils::DeathHandler::create(burstContext));
+ deathHandler.protectCallbackForLifetimeOfDeathHandler(requestChannelSender.get());
+ deathHandler.protectCallbackForLifetimeOfDeathHandler(resultChannelReceiver.get());
+
+ // make and return controller
+ return std::make_shared<const Burst>(
+ PrivateConstructorTag{}, std::move(preparedModel), std::move(requestChannelSender),
+ std::move(resultChannelReceiver), std::move(burstCallback), std::move(burstContext),
+ std::move(memoryCache), std::move(deathHandler));
+}
+
+Burst::Burst(PrivateConstructorTag /*tag*/, nn::SharedPreparedModel preparedModel,
+ std::unique_ptr<RequestChannelSender> requestChannelSender,
+ std::unique_ptr<ResultChannelReceiver> resultChannelReceiver,
+ sp<ExecutionBurstCallback> callback, sp<IBurstContext> burstContext,
+ std::shared_ptr<MemoryCache> memoryCache,
+ neuralnetworks::utils::DeathHandler deathHandler)
+ : kPreparedModel(std::move(preparedModel)),
+ mRequestChannelSender(std::move(requestChannelSender)),
+ mResultChannelReceiver(std::move(resultChannelReceiver)),
+ mBurstCallback(std::move(callback)),
+ mBurstContext(std::move(burstContext)),
+ mMemoryCache(std::move(memoryCache)),
+ kDeathHandler(std::move(deathHandler)) {}
+
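+// Caches `memory` in the burst's slot table and returns a hold on the entry; the memory stays
+// cached (and its slot stays valid) until every hold returned for it has been released.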
+Burst::OptionalCacheHold Burst::cacheMemory(const nn::SharedMemory& memory) const {
+ auto [slot, hold] = mMemoryCache->cacheMemory(memory);
+ return hold;
+}
+
+nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::execute(
+ const nn::Request& request, nn::MeasureTiming measure,
+ const nn::OptionalTimePoint& deadline,
+ const nn::OptionalDuration& loopTimeoutDuration) const {
+    // This is the first point when we know an execution is occurring, so begin to collect
+    // systraces. Note that the first point at which the burst server (adapter::Burst) can begin
+    // collecting systraces is when the RequestChannelReceiver realizes there is data in the FMQ,
+    // so the server collects systraces at different points in the code.
+ NNTRACE_RT(NNTRACE_PHASE_EXECUTION, "Burst::execute");
+
+ // if the request is valid but of a higher version than what's supported in burst execution,
+ // fall back to another execution path
+ if (const auto version = NN_TRY(nn::validate(request)); version > nn::Version::ANDROID_Q) {
+        // fall back to another execution path because burst execution cannot handle this version
+ return kPreparedModel->execute(request, measure, deadline, loopTimeoutDuration);
+ }
+
+ // ensure that request is ready for IPC
+ std::optional<nn::Request> maybeRequestInShared;
+ hal::utils::RequestRelocation relocation;
+ const nn::Request& requestInShared = NN_TRY(hal::utils::convertRequestFromPointerToShared(
+ &request, nn::kDefaultRequestMemoryAlignment, nn::kMinMemoryPadding,
+ &maybeRequestInShared, &relocation));
+
+ // clear pools field of request, as they will be provided via slots
+ const auto requestWithoutPools = nn::Request{
+ .inputs = requestInShared.inputs, .outputs = requestInShared.outputs, .pools = {}};
+ auto hidlRequest = NN_TRY(V1_0::utils::unvalidatedConvert(requestWithoutPools));
+ const auto hidlMeasure = NN_TRY(convert(measure));
+
+ std::vector<int32_t> slots;
+ std::vector<OptionalCacheHold> holds;
+ slots.reserve(requestInShared.pools.size());
+ holds.reserve(requestInShared.pools.size());
+ for (const auto& memoryPool : requestInShared.pools) {
+ auto [slot, hold] = mMemoryCache->cacheMemory(std::get<nn::SharedMemory>(memoryPool));
+ slots.push_back(slot);
+ holds.push_back(std::move(hold));
+ }
+
+ // send request packet
+ const auto requestPacket = serialize(hidlRequest, hidlMeasure, slots);
+ const auto fallback = [this, &request, measure, &deadline, &loopTimeoutDuration] {
+ return kPreparedModel->execute(request, measure, deadline, loopTimeoutDuration);
+ };
+ return executeInternal(requestPacket, relocation, fallback);
+}
+
+// See IBurst::createReusableExecution for information on this method.
+nn::GeneralResult<nn::SharedExecution> Burst::createReusableExecution(
+ const nn::Request& request, nn::MeasureTiming measure,
+ const nn::OptionalDuration& loopTimeoutDuration) const {
+ NNTRACE_RT(NNTRACE_PHASE_EXECUTION, "Burst::createReusableExecution");
+
+ // if the request is valid but of a higher version than what's supported in burst execution,
+ // fall back to another execution path
+ if (const auto version = NN_TRY(nn::validate(request)); version > nn::Version::ANDROID_Q) {
+        // fall back to another execution path because burst execution cannot handle this version
+ return kPreparedModel->createReusableExecution(request, measure, loopTimeoutDuration);
+ }
+
+ // ensure that request is ready for IPC
+ std::optional<nn::Request> maybeRequestInShared;
+ hal::utils::RequestRelocation relocation;
+ const nn::Request& requestInShared = NN_TRY(hal::utils::convertRequestFromPointerToShared(
+ &request, nn::kDefaultRequestMemoryAlignment, nn::kMinMemoryPadding,
+ &maybeRequestInShared, &relocation));
+
+ // clear pools field of request, as they will be provided via slots
+ const auto requestWithoutPools = nn::Request{
+ .inputs = requestInShared.inputs, .outputs = requestInShared.outputs, .pools = {}};
+ auto hidlRequest = NN_TRY(V1_0::utils::unvalidatedConvert(requestWithoutPools));
+ const auto hidlMeasure = NN_TRY(convert(measure));
+
+ std::vector<int32_t> slots;
+ std::vector<OptionalCacheHold> holds;
+ slots.reserve(requestInShared.pools.size());
+ holds.reserve(requestInShared.pools.size());
+ for (const auto& memoryPool : requestInShared.pools) {
+ auto [slot, hold] = mMemoryCache->cacheMemory(std::get<nn::SharedMemory>(memoryPool));
+ slots.push_back(slot);
+ holds.push_back(std::move(hold));
+ }
+
+ const auto requestPacket = serialize(hidlRequest, hidlMeasure, slots);
+ return BurstExecution::create(shared_from_this(), std::move(requestPacket),
+ std::move(relocation), std::move(holds));
+}
+
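+// Sends the pre-serialized FMQ request packet to the burst server and blocks on the result queue
+// for the output shapes and timing; if the send fails and a fallback is provided, it is invoked.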
+nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::executeInternal(
+ const std::vector<FmqRequestDatum>& requestPacket,
+ const hal::utils::RequestRelocation& relocation, FallbackFunction fallback) const {
+ NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "Burst::executeInternal");
+
+ // Ensure that at most one execution is in flight at any given time.
+ const bool alreadyInFlight = mExecutionInFlight.test_and_set();
+ if (alreadyInFlight) {
+ return NN_ERROR() << "IBurst already has an execution in flight";
+ }
+ const auto guard = base::make_scope_guard([this] { mExecutionInFlight.clear(); });
+
+ if (relocation.input) {
+ relocation.input->flush();
+ }
+
+ // send request packet
+ const auto sendStatus = mRequestChannelSender->sendPacket(requestPacket);
+ if (!sendStatus.ok()) {
+ // fallback to another execution path if the packet could not be sent
+ if (fallback) {
+ return fallback();
+ }
+ return NN_ERROR() << "Error sending FMQ packet: " << sendStatus.error();
+ }
+
+ // get result packet
+ const auto [status, outputShapes, timing] = NN_TRY(mResultChannelReceiver->getBlocking());
+
+ if (relocation.output) {
+ relocation.output->flush();
+ }
+ return executionCallback(status, outputShapes, timing);
+}
+
+nn::GeneralResult<std::shared_ptr<const BurstExecution>> BurstExecution::create(
+ std::shared_ptr<const Burst> controller, std::vector<FmqRequestDatum> request,
+ hal::utils::RequestRelocation relocation,
+ std::vector<Burst::OptionalCacheHold> cacheHolds) {
+ if (controller == nullptr) {
+ return NN_ERROR() << "V1_2::utils::BurstExecution::create must have non-null controller";
+ }
+
+ return std::make_shared<const BurstExecution>(PrivateConstructorTag{}, std::move(controller),
+ std::move(request), std::move(relocation),
+ std::move(cacheHolds));
+}
+
+BurstExecution::BurstExecution(PrivateConstructorTag /*tag*/,
+ std::shared_ptr<const Burst> controller,
+ std::vector<FmqRequestDatum> request,
+ hal::utils::RequestRelocation relocation,
+ std::vector<Burst::OptionalCacheHold> cacheHolds)
+ : kController(std::move(controller)),
+ kRequest(std::move(request)),
+ kRelocation(std::move(relocation)),
+ kCacheHolds(std::move(cacheHolds)) {}
+
+nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> BurstExecution::compute(
+ const nn::OptionalTimePoint& /*deadline*/) const {
+ return kController->executeInternal(kRequest, kRelocation, /*fallback=*/nullptr);
+}
+
+nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
+BurstExecution::computeFenced(const std::vector<nn::SyncFence>& /*waitFor*/,
+ const nn::OptionalTimePoint& /*deadline*/,
+ const nn::OptionalDuration& /*timeoutDurationAfterFence*/) const {
+ return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
+ << "IExecution::computeFenced is not supported on burst object";
+}
+
+} // namespace android::hardware::neuralnetworks::V1_2::utils