Implement partial canonical Burst in NN util code
This CL adds a simple implementation of IBurst that dispatches calls to
an IPreparedModel object and changes
IPreparedModel::configureExecutionBurst to return this new object
(instead of returning an error).
This CL additionally defines an InvalidBurst class, which returns an
error whenever it is used, and a ResilientBurst class, which recovers
an IBurst object after it has died.
Bug: 177267324
Test: mma
Change-Id: I4c7e7ff4e6559aeb5e62c4fa02f2e751fef9d87d
Merged-In: I4c7e7ff4e6559aeb5e62c4fa02f2e751fef9d87d
(cherry picked from commit 44f324fb0d89ed896c9b0566ea632bddcfe69439)
diff --git a/neuralnetworks/utils/common/src/ResilientBurst.cpp b/neuralnetworks/utils/common/src/ResilientBurst.cpp
new file mode 100644
index 0000000..0d3cb33
--- /dev/null
+++ b/neuralnetworks/utils/common/src/ResilientBurst.cpp
@@ -0,0 +1,109 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ResilientBurst.h"
+
+#include <android-base/logging.h>
+#include <android-base/thread_annotations.h>
+#include <nnapi/IBurst.h>
+#include <nnapi/Result.h>
+#include <nnapi/TypeUtils.h>
+#include <nnapi/Types.h>
+
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <optional>
+#include <utility>
+
+namespace android::hardware::neuralnetworks::utils {
+namespace {
+
+// Invokes `fn` on the burst currently held by `resilientBurst`, retrying at most once if that
+// burst has died.
+//
+// - If the first call succeeds, or fails with any status other than DEAD_OBJECT, its result is
+//   returned unchanged.
+// - On DEAD_OBJECT, resilientBurst.recover() is asked for a replacement burst. If recovery
+//   succeeds, `fn` is invoked exactly once more on the replacement and that result is returned.
+// - If recovery fails, the first call's error code and output shapes are returned, with the
+//   recovery error appended to the first call's message.
+template <typename FnType>
+auto protect(const ResilientBurst& resilientBurst, const FnType& fn)
+        -> decltype(fn(*resilientBurst.getBurst())) {
+    nn::SharedBurst burst = resilientBurst.getBurst();
+    auto firstAttempt = fn(*burst);
+
+    const bool burstDied =
+            !firstAttempt.has_value() && firstAttempt.error().code == nn::ErrorStatus::DEAD_OBJECT;
+    if (!burstDied) {
+        return firstAttempt;
+    }
+
+    // recover() compares this pointer against its current burst, so a burst that another caller
+    // has already replaced is not rebuilt a second time.
+    auto recovered = resilientBurst.recover(burst.get());
+    if (recovered.has_value()) {
+        burst = std::move(recovered).value();
+        return fn(*burst);
+    }
+
+    auto [resultErrorMessage, resultErrorCode, resultOutputShapes] =
+            std::move(firstAttempt).error();
+    const auto& [recoveryErrorMessage, recoveryErrorCode] = recovered.error();
+    return nn::error(resultErrorCode, std::move(resultOutputShapes))
+           << resultErrorMessage << ", and failed to recover dead burst object with error "
+           << recoveryErrorCode << ": " << recoveryErrorMessage;
+}
+
+}  // namespace
+
+// Builds a ResilientBurst around `makeBurst`.
+//
+// The factory is invoked once here to produce the initial burst. It is then retained so that
+// recover() can invoke it again later. Returns INVALID_ARGUMENT if `makeBurst` is empty, or
+// propagates the factory's own error if it fails. A factory that succeeds with a null burst is
+// treated as a programming error and aborts via CHECK.
+nn::GeneralResult<std::shared_ptr<const ResilientBurst>> ResilientBurst::create(Factory makeBurst) {
+    if (makeBurst == nullptr) {
+        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
+               << "utils::ResilientBurst::create must have non-empty makeBurst";
+    }
+
+    // Build the first burst eagerly so a failing factory is reported by create() itself.
+    auto initialBurst = NN_TRY(makeBurst());
+    CHECK(initialBurst != nullptr);
+
+    return std::make_shared<ResilientBurst>(PrivateConstructorTag{}, std::move(makeBurst),
+                                            std::move(initialBurst));
+}
+
+ResilientBurst::ResilientBurst(PrivateConstructorTag /*tag*/, Factory makeBurst,
+ nn::SharedBurst burst)
+ : kMakeBurst(std::move(makeBurst)), mBurst(std::move(burst)) {
+ CHECK(kMakeBurst != nullptr);
+ CHECK(mBurst != nullptr);
+}
+
+// Returns a copy of the current burst handle, taken while holding mMutex. The returned handle is
+// independent of any later reassignment of mBurst by recover().
+nn::SharedBurst ResilientBurst::getBurst() const {
+    const std::lock_guard lock(mMutex);
+    return mBurst;
+}
+
+// Replaces the burst that `failingBurst` points to, if it is still the current one.
+//
+// `failingBurst` is the raw pointer the caller observed failing.
+// - If mBurst no longer matches it, another caller has already recovered. The current burst is
+//   returned without invoking the factory again.
+// - Otherwise kMakeBurst is invoked and its result becomes the new current burst. The factory
+//   runs while holding mMutex, so concurrent recoveries are serialized.
+//
+// Returns the factory's error if it fails, or GENERAL_FAILURE if it succeeds with a null burst.
+// In both cases mBurst is left unchanged.
+nn::GeneralResult<nn::SharedBurst> ResilientBurst::recover(const nn::IBurst* failingBurst) const {
+    std::lock_guard guard(mMutex);
+
+    // Another caller updated the failing burst.
+    if (mBurst.get() != failingBurst) {
+        return mBurst;
+    }
+
+    auto burst = NN_TRY(kMakeBurst());
+    // Reject a null burst rather than storing it. The constructor CHECKs that mBurst is
+    // non-null, and callers of getBurst() dereference the result without checking.
+    if (burst == nullptr) {
+        return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
+               << "utils::ResilientBurst::recover received a null burst from makeBurst";
+    }
+    mBurst = std::move(burst);
+    return mBurst;
+}
+
+// Forwards `memory` to whichever burst is current at the time of the call. Unlike execute(),
+// this does not go through protect(), so no recovery is attempted here.
+// NOTE(review): if the burst is later replaced by recover(), the returned hold presumably still
+// refers to the old burst — confirm this is intended.
+ResilientBurst::OptionalCacheHold ResilientBurst::cacheMemory(const nn::Memory& memory) const {
+    const auto currentBurst = getBurst();
+    return currentBurst->cacheMemory(memory);
+}
+
+// Executes `request` on the current burst. If that burst reports DEAD_OBJECT, protect() recovers
+// it and the execution is retried once on the replacement.
+nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> ResilientBurst::execute(
+        const nn::Request& request, nn::MeasureTiming measure) const {
+    return protect(*this, [&request, measure](const nn::IBurst& burst) {
+        return burst.execute(request, measure);
+    });
+}
+
+} // namespace android::hardware::neuralnetworks::utils