/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#include "Conversions.h"
18
19#include <android-base/logging.h>
20#include <android/hardware/neuralnetworks/1.0/types.h>
21#include <nnapi/OperandTypes.h>
22#include <nnapi/OperationTypes.h>
23#include <nnapi/Result.h>
24#include <nnapi/SharedMemory.h>
25#include <nnapi/Types.h>
26#include <nnapi/hal/CommonUtils.h>
27
28#include <algorithm>
29#include <functional>
30#include <iterator>
31#include <memory>
32#include <type_traits>
33#include <utility>
34#include <variant>
35
namespace {

// Returns the integral value underlying an enum (scoped or unscoped) constant.
template <typename Type>
constexpr std::underlying_type_t<Type> underlyingType(Type value) {
    using Underlying = std::underlying_type_t<Type>;
    return static_cast<Underlying>(value);
}

}  // namespace
44
45namespace android::nn {
46namespace {
47
48using hardware::hidl_memory;
49using hardware::hidl_vec;
50
51template <typename Input>
52using ConvertOutput = std::decay_t<decltype(convert(std::declval<Input>()).value())>;
53
54template <typename Type>
55Result<std::vector<ConvertOutput<Type>>> convert(const hidl_vec<Type>& arguments) {
56 std::vector<ConvertOutput<Type>> canonical;
57 canonical.reserve(arguments.size());
58 for (const auto& argument : arguments) {
59 canonical.push_back(NN_TRY(nn::convert(argument)));
60 }
61 return canonical;
62}
63
64} // anonymous namespace
65
66Result<OperandType> convert(const hal::V1_0::OperandType& operandType) {
67 return static_cast<OperandType>(operandType);
68}
69
70Result<OperationType> convert(const hal::V1_0::OperationType& operationType) {
71 return static_cast<OperationType>(operationType);
72}
73
74Result<Operand::LifeTime> convert(const hal::V1_0::OperandLifeTime& lifetime) {
75 return static_cast<Operand::LifeTime>(lifetime);
76}
77
78Result<DeviceStatus> convert(const hal::V1_0::DeviceStatus& deviceStatus) {
79 return static_cast<DeviceStatus>(deviceStatus);
80}
81
82Result<Capabilities::PerformanceInfo> convert(const hal::V1_0::PerformanceInfo& performanceInfo) {
83 return Capabilities::PerformanceInfo{
84 .execTime = performanceInfo.execTime,
85 .powerUsage = performanceInfo.powerUsage,
86 };
87}
88
89Result<Capabilities> convert(const hal::V1_0::Capabilities& capabilities) {
90 const auto quantized8Performance = NN_TRY(convert(capabilities.quantized8Performance));
91 const auto float32Performance = NN_TRY(convert(capabilities.float32Performance));
92
93 auto table = hal::utils::makeQuantized8PerformanceConsistentWithP(float32Performance,
94 quantized8Performance);
95
96 return Capabilities{
97 .relaxedFloat32toFloat16PerformanceScalar = float32Performance,
98 .relaxedFloat32toFloat16PerformanceTensor = float32Performance,
99 .operandPerformance = std::move(table),
100 };
101}
102
103Result<DataLocation> convert(const hal::V1_0::DataLocation& location) {
104 return DataLocation{
105 .poolIndex = location.poolIndex,
106 .offset = location.offset,
107 .length = location.length,
108 };
109}
110
111Result<Operand> convert(const hal::V1_0::Operand& operand) {
112 return Operand{
113 .type = NN_TRY(convert(operand.type)),
114 .dimensions = operand.dimensions,
115 .scale = operand.scale,
116 .zeroPoint = operand.zeroPoint,
117 .lifetime = NN_TRY(convert(operand.lifetime)),
118 .location = NN_TRY(convert(operand.location)),
119 };
120}
121
122Result<Operation> convert(const hal::V1_0::Operation& operation) {
123 return Operation{
124 .type = NN_TRY(convert(operation.type)),
125 .inputs = operation.inputs,
126 .outputs = operation.outputs,
127 };
128}
129
130Result<Model::OperandValues> convert(const hidl_vec<uint8_t>& operandValues) {
131 return Model::OperandValues(operandValues.data(), operandValues.size());
132}
133
// Wraps a HIDL memory object in a canonical shared-memory handle; all the
// work is delegated to createSharedMemoryFromHidlMemory.
Result<Memory> convert(const hidl_memory& memory) {
    return createSharedMemoryFromHidlMemory(memory);
}
137
138Result<Model> convert(const hal::V1_0::Model& model) {
139 auto operations = NN_TRY(convert(model.operations));
140
141 // Verify number of consumers.
142 const auto numberOfConsumers =
143 hal::utils::countNumberOfConsumers(model.operands.size(), operations);
144 CHECK(model.operands.size() == numberOfConsumers.size());
145 for (size_t i = 0; i < model.operands.size(); ++i) {
146 if (model.operands[i].numberOfConsumers != numberOfConsumers[i]) {
147 return NN_ERROR() << "Invalid numberOfConsumers for operand " << i << ", expected "
148 << numberOfConsumers[i] << " but found "
149 << model.operands[i].numberOfConsumers;
150 }
151 }
152
153 auto main = Model::Subgraph{
154 .operands = NN_TRY(convert(model.operands)),
155 .operations = std::move(operations),
156 .inputIndexes = model.inputIndexes,
157 .outputIndexes = model.outputIndexes,
158 };
159
160 return Model{
161 .main = std::move(main),
162 .operandValues = NN_TRY(convert(model.operandValues)),
163 .pools = NN_TRY(convert(model.pools)),
164 };
165}
166
167Result<Request::Argument> convert(const hal::V1_0::RequestArgument& argument) {
168 const auto lifetime = argument.hasNoValue ? Request::Argument::LifeTime::NO_VALUE
169 : Request::Argument::LifeTime::POOL;
170 return Request::Argument{
171 .lifetime = lifetime,
172 .location = NN_TRY(convert(argument.location)),
173 .dimensions = argument.dimensions,
174 };
175}
176
177Result<Request> convert(const hal::V1_0::Request& request) {
178 auto memories = NN_TRY(convert(request.pools));
179 std::vector<Request::MemoryPool> pools;
180 pools.reserve(memories.size());
181 std::move(memories.begin(), memories.end(), std::back_inserter(pools));
182
183 return Request{
184 .inputs = NN_TRY(convert(request.inputs)),
185 .outputs = NN_TRY(convert(request.outputs)),
186 .pools = std::move(pools),
187 };
188}
189
// Converts a 1.0 HAL error status to the canonical enum. Only the five values
// defined by the 1.0 HAL are accepted; any other wire value makes the
// conversion itself fail rather than being mapped to a canonical status.
Result<ErrorStatus> convert(const hal::V1_0::ErrorStatus& status) {
    switch (status) {
        case hal::V1_0::ErrorStatus::NONE:
        case hal::V1_0::ErrorStatus::DEVICE_UNAVAILABLE:
        case hal::V1_0::ErrorStatus::GENERAL_FAILURE:
        case hal::V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
        case hal::V1_0::ErrorStatus::INVALID_ARGUMENT:
            // Enum values coincide, so a direct cast is sufficient.
            return static_cast<ErrorStatus>(status);
    }
    return NN_ERROR() << "Invalid ErrorStatus " << underlyingType(status);
}
201
202} // namespace android::nn
203
204namespace android::hardware::neuralnetworks::V1_0::utils {
205namespace {
206
207template <typename Input>
208using ConvertOutput = std::decay_t<decltype(convert(std::declval<Input>()).value())>;
209
210template <typename Type>
211nn::Result<hidl_vec<ConvertOutput<Type>>> convert(const std::vector<Type>& arguments) {
212 hidl_vec<ConvertOutput<Type>> halObject(arguments.size());
213 for (size_t i = 0; i < arguments.size(); ++i) {
214 halObject[i] = NN_TRY(utils::convert(arguments[i]));
215 }
216 return halObject;
217}
218
219} // anonymous namespace
220
221nn::Result<OperandType> convert(const nn::OperandType& operandType) {
222 return static_cast<OperandType>(operandType);
223}
224
225nn::Result<OperationType> convert(const nn::OperationType& operationType) {
226 return static_cast<OperationType>(operationType);
227}
228
229nn::Result<OperandLifeTime> convert(const nn::Operand::LifeTime& lifetime) {
230 if (lifetime == nn::Operand::LifeTime::POINTER) {
231 return NN_ERROR() << "Model cannot be converted because it contains pointer-based memory";
232 }
233 return static_cast<OperandLifeTime>(lifetime);
234}
235
236nn::Result<DeviceStatus> convert(const nn::DeviceStatus& deviceStatus) {
237 return static_cast<DeviceStatus>(deviceStatus);
238}
239
240nn::Result<PerformanceInfo> convert(const nn::Capabilities::PerformanceInfo& performanceInfo) {
241 return PerformanceInfo{
242 .execTime = performanceInfo.execTime,
243 .powerUsage = performanceInfo.powerUsage,
244 };
245}
246
247nn::Result<Capabilities> convert(const nn::Capabilities& capabilities) {
248 return Capabilities{
249 .float32Performance = NN_TRY(convert(
250 capabilities.operandPerformance.lookup(nn::OperandType::TENSOR_FLOAT32))),
251 .quantized8Performance = NN_TRY(convert(
252 capabilities.operandPerformance.lookup(nn::OperandType::TENSOR_QUANT8_ASYMM))),
253 };
254}
255
256nn::Result<DataLocation> convert(const nn::DataLocation& location) {
257 return DataLocation{
258 .poolIndex = location.poolIndex,
259 .offset = location.offset,
260 .length = location.length,
261 };
262}
263
264nn::Result<Operand> convert(const nn::Operand& operand) {
265 return Operand{
266 .type = NN_TRY(convert(operand.type)),
267 .dimensions = operand.dimensions,
268 .numberOfConsumers = 0,
269 .scale = operand.scale,
270 .zeroPoint = operand.zeroPoint,
271 .lifetime = NN_TRY(convert(operand.lifetime)),
272 .location = NN_TRY(convert(operand.location)),
273 };
274}
275
276nn::Result<Operation> convert(const nn::Operation& operation) {
277 return Operation{
278 .type = NN_TRY(convert(operation.type)),
279 .inputs = operation.inputs,
280 .outputs = operation.outputs,
281 };
282}
283
284nn::Result<hidl_vec<uint8_t>> convert(const nn::Model::OperandValues& operandValues) {
285 return hidl_vec<uint8_t>(operandValues.data(), operandValues.data() + operandValues.size());
286}
287
// Converts a canonical Memory to a hidl_memory that owns its own handle.
nn::Result<hidl_memory> convert(const nn::Memory& memory) {
    // This hidl_memory wraps the canonical memory's native handle without
    // taking ownership of it.
    const auto hidlMemory = hidl_memory(memory.name, memory.handle->handle(), memory.size);
    // Copy memory to force the native_handle_t to be copied.
    // NOTE(review): this relies on hidl_memory's copy constructor cloning the
    // wrapped native_handle_t so the returned object is independent of
    // `memory`'s lifetime — confirm against libhidl's hidl_memory semantics.
    auto copiedMemory = hidlMemory;
    return copiedMemory;
}
294
295nn::Result<Model> convert(const nn::Model& model) {
296 if (!hal::utils::hasNoPointerData(model)) {
297 return NN_ERROR() << "Mdoel cannot be converted because it contains pointer-based memory";
298 }
299
300 auto operands = NN_TRY(convert(model.main.operands));
301
302 // Update number of consumers.
303 const auto numberOfConsumers =
304 hal::utils::countNumberOfConsumers(operands.size(), model.main.operations);
305 CHECK(operands.size() == numberOfConsumers.size());
306 for (size_t i = 0; i < operands.size(); ++i) {
307 operands[i].numberOfConsumers = numberOfConsumers[i];
308 }
309
310 return Model{
311 .operands = std::move(operands),
312 .operations = NN_TRY(convert(model.main.operations)),
313 .inputIndexes = model.main.inputIndexes,
314 .outputIndexes = model.main.outputIndexes,
315 .operandValues = NN_TRY(convert(model.operandValues)),
316 .pools = NN_TRY(convert(model.pools)),
317 };
318}
319
320nn::Result<RequestArgument> convert(const nn::Request::Argument& requestArgument) {
321 if (requestArgument.lifetime == nn::Request::Argument::LifeTime::POINTER) {
322 return NN_ERROR() << "Request cannot be converted because it contains pointer-based memory";
323 }
324 const bool hasNoValue = requestArgument.lifetime == nn::Request::Argument::LifeTime::NO_VALUE;
325 return RequestArgument{
326 .hasNoValue = hasNoValue,
327 .location = NN_TRY(convert(requestArgument.location)),
328 .dimensions = requestArgument.dimensions,
329 };
330}
331
332nn::Result<hidl_memory> convert(const nn::Request::MemoryPool& memoryPool) {
333 return convert(std::get<nn::Memory>(memoryPool));
334}
335
336nn::Result<Request> convert(const nn::Request& request) {
337 if (!hal::utils::hasNoPointerData(request)) {
338 return NN_ERROR() << "Request cannot be converted because it contains pointer-based memory";
339 }
340
341 return Request{
342 .inputs = NN_TRY(convert(request.inputs)),
343 .outputs = NN_TRY(convert(request.outputs)),
344 .pools = NN_TRY(convert(request.pools)),
345 };
346}
347
// Converts a canonical error status to the 1.0 HAL enum. Statuses introduced
// after 1.0 have no HIDL 1.0 counterpart, so they are mapped to
// GENERAL_FAILURE instead of failing the conversion.
nn::Result<ErrorStatus> convert(const nn::ErrorStatus& status) {
    switch (status) {
        case nn::ErrorStatus::NONE:
        case nn::ErrorStatus::DEVICE_UNAVAILABLE:
        case nn::ErrorStatus::GENERAL_FAILURE:
        case nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
        case nn::ErrorStatus::INVALID_ARGUMENT:
            // These five values exist in both enums with matching numeric
            // values, so a direct cast is sufficient.
            return static_cast<ErrorStatus>(status);
        default:
            return ErrorStatus::GENERAL_FAILURE;
    }
}
360
361} // namespace android::hardware::neuralnetworks::V1_0::utils