blob: 38ccc62156d029a10b55219516f58186965cf4ed [file] [log] [blame]
Michael Butler95331512020-12-18 20:53:55 -08001/*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "ResilientBurst.h"
18
19#include <android-base/logging.h>
20#include <android-base/thread_annotations.h>
21#include <nnapi/IBurst.h>
22#include <nnapi/Result.h>
23#include <nnapi/TypeUtils.h>
24#include <nnapi/Types.h>
25
26#include <functional>
27#include <memory>
28#include <mutex>
29#include <optional>
30#include <utility>
31
32namespace android::hardware::neuralnetworks::utils {
33namespace {
34
35template <typename FnType>
36auto protect(const ResilientBurst& resilientBurst, const FnType& fn)
37 -> decltype(fn(*resilientBurst.getBurst())) {
38 auto burst = resilientBurst.getBurst();
39 auto result = fn(*burst);
40
41 // Immediately return if burst is not dead.
42 if (result.has_value() || result.error().code != nn::ErrorStatus::DEAD_OBJECT) {
43 return result;
44 }
45
46 // Attempt recovery and return if it fails.
47 auto maybeBurst = resilientBurst.recover(burst.get());
48 if (!maybeBurst.has_value()) {
49 auto [resultErrorMessage, resultErrorCode, resultOutputShapes] = std::move(result).error();
50 const auto& [recoveryErrorMessage, recoveryErrorCode] = maybeBurst.error();
51 return nn::error(resultErrorCode, std::move(resultOutputShapes))
52 << resultErrorMessage << ", and failed to recover dead burst object with error "
53 << recoveryErrorCode << ": " << recoveryErrorMessage;
54 }
55 burst = std::move(maybeBurst).value();
56
57 return fn(*burst);
58}
59
60} // namespace
61
62nn::GeneralResult<std::shared_ptr<const ResilientBurst>> ResilientBurst::create(Factory makeBurst) {
63 if (makeBurst == nullptr) {
64 return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
65 << "utils::ResilientBurst::create must have non-empty makeBurst";
66 }
67 auto burst = NN_TRY(makeBurst());
68 CHECK(burst != nullptr);
69 return std::make_shared<ResilientBurst>(PrivateConstructorTag{}, std::move(makeBurst),
70 std::move(burst));
71}
72
73ResilientBurst::ResilientBurst(PrivateConstructorTag /*tag*/, Factory makeBurst,
74 nn::SharedBurst burst)
75 : kMakeBurst(std::move(makeBurst)), mBurst(std::move(burst)) {
76 CHECK(kMakeBurst != nullptr);
77 CHECK(mBurst != nullptr);
78}
79
80nn::SharedBurst ResilientBurst::getBurst() const {
81 std::lock_guard guard(mMutex);
82 return mBurst;
83}
84
85nn::GeneralResult<nn::SharedBurst> ResilientBurst::recover(const nn::IBurst* failingBurst) const {
86 std::lock_guard guard(mMutex);
87
88 // Another caller updated the failing burst.
89 if (mBurst.get() != failingBurst) {
90 return mBurst;
91 }
92
93 mBurst = NN_TRY(kMakeBurst());
94 return mBurst;
95}
96
Michael Butlerfadeb8a2021-02-07 00:11:13 -080097ResilientBurst::OptionalCacheHold ResilientBurst::cacheMemory(
98 const nn::SharedMemory& memory) const {
Michael Butler95331512020-12-18 20:53:55 -080099 return getBurst()->cacheMemory(memory);
100}
101
102nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> ResilientBurst::execute(
Michael Butler8414a6e2021-03-10 18:41:05 -0800103 const nn::Request& request, nn::MeasureTiming measure,
104 const nn::OptionalTimePoint& deadline,
105 const nn::OptionalDuration& loopTimeoutDuration) const {
106 const auto fn = [&request, measure, deadline, loopTimeoutDuration](const nn::IBurst& burst) {
107 return burst.execute(request, measure, deadline, loopTimeoutDuration);
Michael Butler95331512020-12-18 20:53:55 -0800108 };
109 return protect(*this, fn);
110}
111
112} // namespace android::hardware::neuralnetworks::utils