blob: 26025a5026c6acefab2a7de43de4502cb18fe382 [file] [log] [blame]
Michael Butler4b276a72020-08-06 23:22:35 -07001/*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "ResilientDevice.h"
18
19#include "ResilientBuffer.h"
20#include "ResilientPreparedModel.h"
21
22#include <android-base/logging.h>
23#include <nnapi/IBuffer.h>
24#include <nnapi/IDevice.h>
25#include <nnapi/IPreparedModel.h>
26#include <nnapi/Result.h>
27#include <nnapi/TypeUtils.h>
28#include <nnapi/Types.h>
29
30#include <algorithm>
31#include <memory>
32#include <string>
33#include <vector>
34
35namespace android::hardware::neuralnetworks::utils {
36namespace {
37
38template <typename FnType>
39auto protect(const ResilientDevice& resilientDevice, const FnType& fn, bool blocking)
40 -> decltype(fn(*resilientDevice.getDevice())) {
41 auto device = resilientDevice.getDevice();
42 auto result = fn(*device);
43
44 // Immediately return if device is not dead.
45 if (result.has_value() || result.error().code != nn::ErrorStatus::DEAD_OBJECT) {
46 return result;
47 }
48
49 device = resilientDevice.recover(device.get(), blocking);
50 return fn(*device);
51}
52
53} // namespace
54
55nn::GeneralResult<std::shared_ptr<const ResilientDevice>> ResilientDevice::create(
56 Factory makeDevice) {
57 if (makeDevice == nullptr) {
58 return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
59 << "utils::ResilientDevice::create must have non-empty makeDevice";
60 }
61 auto device = NN_TRY(makeDevice(/*blocking=*/true));
62 CHECK(device != nullptr);
63
64 auto name = device->getName();
65 auto versionString = device->getVersionString();
66 auto extensions = device->getSupportedExtensions();
67 auto capabilities = device->getCapabilities();
68
69 return std::make_shared<ResilientDevice>(PrivateConstructorTag{}, std::move(makeDevice),
70 std::move(name), std::move(versionString),
71 std::move(extensions), std::move(capabilities),
72 std::move(device));
73}
74
// Stores the factory, the cached device metadata and the initial device.
//
// The PrivateConstructorTag parameter restricts construction; create() is the caller visible in
// this file. All arguments are moved into the corresponding members.
// Aborts (CHECK) if either the factory or the initial device is empty/null.
ResilientDevice::ResilientDevice(PrivateConstructorTag /*tag*/, Factory makeDevice,
                                 std::string name, std::string versionString,
                                 std::vector<nn::Extension> extensions,
                                 nn::Capabilities capabilities, nn::SharedDevice device)
    : kMakeDevice(std::move(makeDevice)),
      kName(std::move(name)),
      kVersionString(std::move(versionString)),
      kExtensions(std::move(extensions)),
      kCapabilities(std::move(capabilities)),
      mDevice(std::move(device)) {
    // recover() relies on both invariants: it calls kMakeDevice, and callers dereference mDevice.
    CHECK(kMakeDevice != nullptr);
    CHECK(mDevice != nullptr);
}
88
89nn::SharedDevice ResilientDevice::getDevice() const {
90 std::lock_guard guard(mMutex);
91 return mDevice;
92}
93
94nn::SharedDevice ResilientDevice::recover(const nn::IDevice* failingDevice, bool blocking) const {
95 std::lock_guard guard(mMutex);
96
97 // Another caller updated the failing device.
98 if (mDevice.get() != failingDevice) {
99 return mDevice;
100 }
101
102 auto maybeDevice = kMakeDevice(blocking);
103 if (!maybeDevice.has_value()) {
104 const auto& [message, code] = maybeDevice.error();
105 LOG(ERROR) << "Failed to recover dead device with error " << code << ": " << message;
106 return mDevice;
107 }
108 auto device = std::move(maybeDevice).value();
109
110 // TODO(b/173081926): Instead of CHECKing to ensure the cache has not been changed, return an
111 // invalid/"null" IDevice object that always fails.
112 CHECK_EQ(kName, device->getName());
113 CHECK_EQ(kVersionString, device->getVersionString());
114 CHECK(kExtensions == device->getSupportedExtensions());
115 CHECK_EQ(kCapabilities, device->getCapabilities());
116
117 mDevice = std::move(device);
118 return mDevice;
119}
120
121const std::string& ResilientDevice::getName() const {
122 return kName;
123}
124
125const std::string& ResilientDevice::getVersionString() const {
126 return kVersionString;
127}
128
129nn::Version ResilientDevice::getFeatureLevel() const {
130 return getDevice()->getFeatureLevel();
131}
132
133nn::DeviceType ResilientDevice::getType() const {
134 return getDevice()->getType();
135}
136
137const std::vector<nn::Extension>& ResilientDevice::getSupportedExtensions() const {
138 return kExtensions;
139}
140
141const nn::Capabilities& ResilientDevice::getCapabilities() const {
142 return kCapabilities;
143}
144
145std::pair<uint32_t, uint32_t> ResilientDevice::getNumberOfCacheFilesNeeded() const {
146 return getDevice()->getNumberOfCacheFilesNeeded();
147}
148
149nn::GeneralResult<void> ResilientDevice::wait() const {
150 const auto fn = [](const nn::IDevice& device) { return device.wait(); };
151 return protect(*this, fn, /*blocking=*/true);
152}
153
154nn::GeneralResult<std::vector<bool>> ResilientDevice::getSupportedOperations(
155 const nn::Model& model) const {
156 const auto fn = [&model](const nn::IDevice& device) {
157 return device.getSupportedOperations(model);
158 };
159 return protect(*this, fn, /*blocking=*/false);
160}
161
162nn::GeneralResult<nn::SharedPreparedModel> ResilientDevice::prepareModel(
163 const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
Slava Shklyaev49817a02020-10-27 18:44:01 +0000164 nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
165 const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
Michael Butler4b276a72020-08-06 23:22:35 -0700166 auto self = shared_from_this();
167 ResilientPreparedModel::Factory makePreparedModel =
168 [device = std::move(self), model, preference, priority, deadline, modelCache, dataCache,
169 token](bool blocking) -> nn::GeneralResult<nn::SharedPreparedModel> {
170 return device->prepareModelInternal(blocking, model, preference, priority, deadline,
171 modelCache, dataCache, token);
172 };
173 return ResilientPreparedModel::create(std::move(makePreparedModel));
174}
175
176nn::GeneralResult<nn::SharedPreparedModel> ResilientDevice::prepareModelFromCache(
Slava Shklyaev49817a02020-10-27 18:44:01 +0000177 nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
178 const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
Michael Butler4b276a72020-08-06 23:22:35 -0700179 auto self = shared_from_this();
180 ResilientPreparedModel::Factory makePreparedModel =
181 [device = std::move(self), deadline, modelCache, dataCache,
182 token](bool blocking) -> nn::GeneralResult<nn::SharedPreparedModel> {
183 return device->prepareModelFromCacheInternal(blocking, deadline, modelCache, dataCache,
184 token);
185 };
186 return ResilientPreparedModel::create(std::move(makePreparedModel));
187}
188
189nn::GeneralResult<nn::SharedBuffer> ResilientDevice::allocate(
190 const nn::BufferDesc& desc, const std::vector<nn::SharedPreparedModel>& preparedModels,
191 const std::vector<nn::BufferRole>& inputRoles,
192 const std::vector<nn::BufferRole>& outputRoles) const {
193 auto self = shared_from_this();
194 ResilientBuffer::Factory makeBuffer =
195 [device = std::move(self), desc, preparedModels, inputRoles,
196 outputRoles](bool blocking) -> nn::GeneralResult<nn::SharedBuffer> {
197 return device->allocateInternal(blocking, desc, preparedModels, inputRoles, outputRoles);
198 };
199 return ResilientBuffer::create(std::move(makeBuffer));
200}
201
202nn::GeneralResult<nn::SharedPreparedModel> ResilientDevice::prepareModelInternal(
203 bool blocking, const nn::Model& model, nn::ExecutionPreference preference,
204 nn::Priority priority, nn::OptionalTimePoint deadline,
Slava Shklyaev49817a02020-10-27 18:44:01 +0000205 const std::vector<nn::SharedHandle>& modelCache,
206 const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
Michael Butler4b276a72020-08-06 23:22:35 -0700207 const auto fn = [&model, preference, priority, deadline, &modelCache, &dataCache,
208 token](const nn::IDevice& device) {
209 return device.prepareModel(model, preference, priority, deadline, modelCache, dataCache,
210 token);
211 };
212 return protect(*this, fn, blocking);
213}
214
215nn::GeneralResult<nn::SharedPreparedModel> ResilientDevice::prepareModelFromCacheInternal(
216 bool blocking, nn::OptionalTimePoint deadline,
Slava Shklyaev49817a02020-10-27 18:44:01 +0000217 const std::vector<nn::SharedHandle>& modelCache,
218 const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
Michael Butler4b276a72020-08-06 23:22:35 -0700219 const auto fn = [deadline, &modelCache, &dataCache, token](const nn::IDevice& device) {
220 return device.prepareModelFromCache(deadline, modelCache, dataCache, token);
221 };
222 return protect(*this, fn, blocking);
223}
224
225nn::GeneralResult<nn::SharedBuffer> ResilientDevice::allocateInternal(
226 bool blocking, const nn::BufferDesc& desc,
227 const std::vector<nn::SharedPreparedModel>& preparedModels,
228 const std::vector<nn::BufferRole>& inputRoles,
229 const std::vector<nn::BufferRole>& outputRoles) const {
230 const auto fn = [&desc, &preparedModels, &inputRoles, &outputRoles](const nn::IDevice& device) {
231 return device.allocate(desc, preparedModels, inputRoles, outputRoles);
232 };
233 return protect(*this, fn, blocking);
234}
235
236} // namespace android::hardware::neuralnetworks::utils