blob: 486d01bc5124b9487cf4bb3bfc0f9cd85c7567c6 [file] [log] [blame]
Lev Proleev6b6dfcd2020-11-11 18:28:50 +00001/*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "Conversions.h"
18
19#include <aidl/android/hardware/common/NativeHandle.h>
20#include <android-base/logging.h>
21#include <nnapi/OperandTypes.h>
22#include <nnapi/OperationTypes.h>
23#include <nnapi/Result.h>
24#include <nnapi/SharedMemory.h>
25#include <nnapi/TypeUtils.h>
26#include <nnapi/Types.h>
27#include <nnapi/Validation.h>
28#include <nnapi/hal/CommonUtils.h>
29#include <nnapi/hal/HandleError.h>
30
31#include <algorithm>
32#include <chrono>
33#include <functional>
34#include <iterator>
35#include <limits>
36#include <type_traits>
37#include <utility>
38
// Returns NN_ERROR() from the enclosing function when `value` is negative.
// Implemented as a `while` (whose body always returns, so it runs at most
// once) so that callers can stream extra context onto the error:
//     VERIFY_NON_NEGATIVE(x) << "x must not be negative";
#define VERIFY_NON_NEGATIVE(value) \
    while (UNLIKELY(value < 0)) return NN_ERROR()
41
namespace {

// Casts a scoped or unscoped enum value to its underlying integral type.
template <typename Type>
constexpr std::underlying_type_t<Type> underlyingType(Type value) {
    return static_cast<std::underlying_type_t<Type>>(value);
}

// Feature level targeted by this conversion library; validatedConvert rejects
// any object whose validated version exceeds this.
constexpr auto kVersion = android::nn::Version::ANDROID_S;

}  // namespace
52
53namespace android::nn {
54namespace {
55
Michael Butlerfadeb8a2021-02-07 00:11:13 -080056using ::aidl::android::hardware::common::NativeHandle;
57
Lev Proleev6b6dfcd2020-11-11 18:28:50 +000058constexpr auto validOperandType(nn::OperandType operandType) {
59 switch (operandType) {
60 case nn::OperandType::FLOAT32:
61 case nn::OperandType::INT32:
62 case nn::OperandType::UINT32:
63 case nn::OperandType::TENSOR_FLOAT32:
64 case nn::OperandType::TENSOR_INT32:
65 case nn::OperandType::TENSOR_QUANT8_ASYMM:
66 case nn::OperandType::BOOL:
67 case nn::OperandType::TENSOR_QUANT16_SYMM:
68 case nn::OperandType::TENSOR_FLOAT16:
69 case nn::OperandType::TENSOR_BOOL8:
70 case nn::OperandType::FLOAT16:
71 case nn::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
72 case nn::OperandType::TENSOR_QUANT16_ASYMM:
73 case nn::OperandType::TENSOR_QUANT8_SYMM:
74 case nn::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
75 case nn::OperandType::SUBGRAPH:
76 return true;
77 case nn::OperandType::OEM:
78 case nn::OperandType::TENSOR_OEM_BYTE:
79 return false;
80 }
81 return nn::isExtension(operandType);
82}
83
// Deduces the canonical type produced by unvalidatedConvert for a given HAL
// input type.
template <typename Input>
using UnvalidatedConvertOutput =
        std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;

// Converts each element of a HAL vector to canonical form, propagating the
// first conversion failure.
template <typename Type>
GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
        const std::vector<Type>& arguments) {
    std::vector<UnvalidatedConvertOutput<Type>> canonical;
    canonical.reserve(arguments.size());
    for (const auto& argument : arguments) {
        canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
    }
    return canonical;
}

// Overload so generic code can call unvalidatedConvert directly on vectors.
template <typename Type>
GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
        const std::vector<Type>& arguments) {
    return unvalidatedConvertVec(arguments);
}
104
// Converts a HAL object to canonical form, then checks that the result is
// valid and that its validated version does not exceed kVersion (ANDROID_S).
template <typename Type>
GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
    auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
    const auto maybeVersion = validate(canonical);
    if (!maybeVersion.has_value()) {
        return error() << maybeVersion.error();
    }
    const auto version = maybeVersion.value();
    if (version > kVersion) {
        return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
    }
    return canonical;
}

// Vector overload of validatedConvert; propagates the first failure.
template <typename Type>
GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
        const std::vector<Type>& arguments) {
    std::vector<UnvalidatedConvertOutput<Type>> canonical;
    canonical.reserve(arguments.size());
    for (const auto& argument : arguments) {
        canonical.push_back(NN_TRY(validatedConvert(argument)));
    }
    return canonical;
}
129
130} // anonymous namespace
131
// AIDL operand/operation types are signed integers; negative values are
// invalid, so they are rejected before the cast to the canonical enum.
GeneralResult<OperandType> unvalidatedConvert(const aidl_hal::OperandType& operandType) {
    VERIFY_NON_NEGATIVE(underlyingType(operandType)) << "Negative operand types are not allowed.";
    return static_cast<OperandType>(operandType);
}

GeneralResult<OperationType> unvalidatedConvert(const aidl_hal::OperationType& operationType) {
    VERIFY_NON_NEGATIVE(underlyingType(operationType))
            << "Negative operation types are not allowed.";
    return static_cast<OperationType>(operationType);
}
142
// Direct value-preserving casts; range checking is deferred to
// validate()/validatedConvert where applicable.
GeneralResult<DeviceType> unvalidatedConvert(const aidl_hal::DeviceType& deviceType) {
    return static_cast<DeviceType>(deviceType);
}

GeneralResult<Priority> unvalidatedConvert(const aidl_hal::Priority& priority) {
    return static_cast<Priority>(priority);
}
150
GeneralResult<Capabilities> unvalidatedConvert(const aidl_hal::Capabilities& capabilities) {
    // Pre-check every operand type so a bad entry yields a clear error here,
    // rather than being silently dropped when the performance table is built.
    const bool validOperandTypes = std::all_of(
            capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
            [](const aidl_hal::OperandPerformance& operandPerformance) {
                const auto maybeType = unvalidatedConvert(operandPerformance.type);
                return !maybeType.has_value() ? false : validOperandType(maybeType.value());
            });
    if (!validOperandTypes) {
        return NN_ERROR() << "Invalid OperandType when unvalidatedConverting OperandPerformance in "
                             "Capabilities";
    }

    // Convert the per-operand-type performance entries and build the lookup
    // table; table creation failures are mapped to GENERAL_FAILURE.
    auto operandPerformance = NN_TRY(unvalidatedConvert(capabilities.operandPerformance));
    auto table = NN_TRY(hal::utils::makeGeneralFailure(
            Capabilities::OperandPerformanceTable::create(std::move(operandPerformance)),
            nn::ErrorStatus::GENERAL_FAILURE));

    return Capabilities{
            .relaxedFloat32toFloat16PerformanceScalar = NN_TRY(
                    unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar)),
            .relaxedFloat32toFloat16PerformanceTensor = NN_TRY(
                    unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor)),
            .operandPerformance = std::move(table),
            .ifPerformance = NN_TRY(unvalidatedConvert(capabilities.ifPerformance)),
            .whilePerformance = NN_TRY(unvalidatedConvert(capabilities.whilePerformance)),
    };
}
178
// Field-wise conversion of one operand-type performance entry.
GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
        const aidl_hal::OperandPerformance& operandPerformance) {
    return Capabilities::OperandPerformance{
            .type = NN_TRY(unvalidatedConvert(operandPerformance.type)),
            .info = NN_TRY(unvalidatedConvert(operandPerformance.info)),
    };
}

// PerformanceInfo fields are plain floats; copied through unchanged.
GeneralResult<Capabilities::PerformanceInfo> unvalidatedConvert(
        const aidl_hal::PerformanceInfo& performanceInfo) {
    return Capabilities::PerformanceInfo{
            .execTime = performanceInfo.execTime,
            .powerUsage = performanceInfo.powerUsage,
    };
}
194
// AIDL DataLocation fields are signed; the canonical struct uses uint32_t.
// Reject negative values and values too large for uint32_t before casting.
GeneralResult<DataLocation> unvalidatedConvert(const aidl_hal::DataLocation& location) {
    VERIFY_NON_NEGATIVE(location.poolIndex) << "DataLocation: pool index must not be negative";
    VERIFY_NON_NEGATIVE(location.offset) << "DataLocation: offset must not be negative";
    VERIFY_NON_NEGATIVE(location.length) << "DataLocation: length must not be negative";
    if (location.offset > std::numeric_limits<uint32_t>::max()) {
        return NN_ERROR() << "DataLocation: offset must be <= std::numeric_limits<uint32_t>::max()";
    }
    if (location.length > std::numeric_limits<uint32_t>::max()) {
        return NN_ERROR() << "DataLocation: length must be <= std::numeric_limits<uint32_t>::max()";
    }
    return DataLocation{
            .poolIndex = static_cast<uint32_t>(location.poolIndex),
            .offset = static_cast<uint32_t>(location.offset),
            .length = static_cast<uint32_t>(location.length),
    };
}
211
// Converts an operation; input/output indexes are signed in AIDL and must be
// non-negative (enforced by toUnsigned).
GeneralResult<Operation> unvalidatedConvert(const aidl_hal::Operation& operation) {
    return Operation{
            .type = NN_TRY(unvalidatedConvert(operation.type)),
            .inputs = NN_TRY(toUnsigned(operation.inputs)),
            .outputs = NN_TRY(toUnsigned(operation.outputs)),
    };
}

// Direct cast; lifetime values are validated later via validate().
GeneralResult<Operand::LifeTime> unvalidatedConvert(
        const aidl_hal::OperandLifeTime& operandLifeTime) {
    return static_cast<Operand::LifeTime>(operandLifeTime);
}
224
// Field-wise conversion of an operand; scale/zeroPoint are copied through,
// all other fields go through their respective converters.
GeneralResult<Operand> unvalidatedConvert(const aidl_hal::Operand& operand) {
    return Operand{
            .type = NN_TRY(unvalidatedConvert(operand.type)),
            .dimensions = NN_TRY(toUnsigned(operand.dimensions)),
            .scale = operand.scale,
            .zeroPoint = operand.zeroPoint,
            .lifetime = NN_TRY(unvalidatedConvert(operand.lifetime)),
            .location = NN_TRY(unvalidatedConvert(operand.location)),
            .extraParams = NN_TRY(unvalidatedConvert(operand.extraParams)),
    };
}
236
// Maps the optional AIDL extra-params union onto the canonical variant:
// absent -> NoParams, channelQuant -> SymmPerChannelQuantParams,
// extension -> raw extension bytes. An unrecognized tag is an error.
GeneralResult<Operand::ExtraParams> unvalidatedConvert(
        const std::optional<aidl_hal::OperandExtraParams>& optionalExtraParams) {
    if (!optionalExtraParams.has_value()) {
        return Operand::NoParams{};
    }
    const auto& extraParams = optionalExtraParams.value();
    using Tag = aidl_hal::OperandExtraParams::Tag;
    switch (extraParams.getTag()) {
        case Tag::channelQuant:
            return unvalidatedConvert(extraParams.get<Tag::channelQuant>());
        case Tag::extension:
            return extraParams.get<Tag::extension>();
    }
    // Reached only for a tag value outside the union's enumeration.
    return NN_ERROR() << "Unrecognized Operand::ExtraParams tag: "
                      << underlyingType(extraParams.getTag());
}
253
// The channel dimension is signed in AIDL; reject negatives before widening
// to the canonical uint32_t.
GeneralResult<Operand::SymmPerChannelQuantParams> unvalidatedConvert(
        const aidl_hal::SymmPerChannelQuantParams& symmPerChannelQuantParams) {
    VERIFY_NON_NEGATIVE(symmPerChannelQuantParams.channelDim)
            << "Per-channel quantization channel dimension must not be negative.";
    return Operand::SymmPerChannelQuantParams{
            .scales = symmPerChannelQuantParams.scales,
            .channelDim = static_cast<uint32_t>(symmPerChannelQuantParams.channelDim),
    };
}
263
// Field-wise conversion of a whole model: main graph, referenced subgraphs,
// inline operand values, memory pools, and extension prefixes.
GeneralResult<Model> unvalidatedConvert(const aidl_hal::Model& model) {
    return Model{
            .main = NN_TRY(unvalidatedConvert(model.main)),
            .referenced = NN_TRY(unvalidatedConvert(model.referenced)),
            .operandValues = NN_TRY(unvalidatedConvert(model.operandValues)),
            .pools = NN_TRY(unvalidatedConvert(model.pools)),
            .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
            .extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix)),
    };
}

// Converts one subgraph; I/O indexes are signed in AIDL and must be
// non-negative (enforced by toUnsigned).
GeneralResult<Model::Subgraph> unvalidatedConvert(const aidl_hal::Subgraph& subgraph) {
    return Model::Subgraph{
            .operands = NN_TRY(unvalidatedConvert(subgraph.operands)),
            .operations = NN_TRY(unvalidatedConvert(subgraph.operations)),
            .inputIndexes = NN_TRY(toUnsigned(subgraph.inputIndexes)),
            .outputIndexes = NN_TRY(toUnsigned(subgraph.outputIndexes)),
    };
}
283
// Name and prefix are copied through unchanged.
GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
        const aidl_hal::ExtensionNameAndPrefix& extensionNameAndPrefix) {
    return Model::ExtensionNameAndPrefix{
            .name = extensionNameAndPrefix.name,
            .prefix = extensionNameAndPrefix.prefix,
    };
}

// Converts an extension and its declared operand type information.
GeneralResult<Extension> unvalidatedConvert(const aidl_hal::Extension& extension) {
    return Extension{
            .name = extension.name,
            .operandTypes = NN_TRY(unvalidatedConvert(extension.operandTypes)),
    };
}
298
// byteSize is signed in AIDL; reject negatives before widening to uint32_t.
GeneralResult<Extension::OperandTypeInformation> unvalidatedConvert(
        const aidl_hal::ExtensionOperandTypeInformation& operandTypeInformation) {
    VERIFY_NON_NEGATIVE(operandTypeInformation.byteSize)
            << "Extension operand type byte size must not be negative";
    return Extension::OperandTypeInformation{
            .type = operandTypeInformation.type,
            .isTensor = operandTypeInformation.isTensor,
            .byteSize = static_cast<uint32_t>(operandTypeInformation.byteSize),
    };
}
309
// Converts an output shape; dimensions are signed in AIDL and must be
// non-negative (enforced by toUnsigned).
GeneralResult<OutputShape> unvalidatedConvert(const aidl_hal::OutputShape& outputShape) {
    return OutputShape{
            .dimensions = NN_TRY(toUnsigned(outputShape.dimensions)),
            .isSufficient = outputShape.isSufficient,
    };
}

// AIDL carries "measure timing" as a bool; map it onto the canonical enum.
GeneralResult<MeasureTiming> unvalidatedConvert(bool measureTiming) {
    return measureTiming ? MeasureTiming::YES : MeasureTiming::NO;
}
320
Michael Butlerfadeb8a2021-02-07 00:11:13 -0800321GeneralResult<SharedMemory> unvalidatedConvert(const aidl_hal::Memory& memory) {
Lev Proleev6b6dfcd2020-11-11 18:28:50 +0000322 VERIFY_NON_NEGATIVE(memory.size) << "Memory size must not be negative";
Michael Butlerfadeb8a2021-02-07 00:11:13 -0800323 return std::make_shared<const Memory>(Memory{
Lev Proleev6b6dfcd2020-11-11 18:28:50 +0000324 .handle = NN_TRY(unvalidatedConvert(memory.handle)),
325 .size = static_cast<uint32_t>(memory.size),
326 .name = memory.name,
Michael Butlerfadeb8a2021-02-07 00:11:13 -0800327 });
Lev Proleev6b6dfcd2020-11-11 18:28:50 +0000328}
329
// Copies the raw inline operand value bytes into the canonical container.
GeneralResult<Model::OperandValues> unvalidatedConvert(const std::vector<uint8_t>& operandValues) {
    return Model::OperandValues(operandValues.data(), operandValues.size());
}

// Dimensions are signed in AIDL and must be non-negative.
GeneralResult<BufferDesc> unvalidatedConvert(const aidl_hal::BufferDesc& bufferDesc) {
    return BufferDesc{.dimensions = NN_TRY(toUnsigned(bufferDesc.dimensions))};
}
337
// Model/IO indexes are signed in AIDL; reject negatives before widening.
GeneralResult<BufferRole> unvalidatedConvert(const aidl_hal::BufferRole& bufferRole) {
    VERIFY_NON_NEGATIVE(bufferRole.modelIndex) << "BufferRole: modelIndex must not be negative";
    VERIFY_NON_NEGATIVE(bufferRole.ioIndex) << "BufferRole: ioIndex must not be negative";
    return BufferRole{
            .modelIndex = static_cast<uint32_t>(bufferRole.modelIndex),
            .ioIndex = static_cast<uint32_t>(bufferRole.ioIndex),
            .frequency = bufferRole.frequency,
    };
}
347
// Field-wise conversion of a request: input/output arguments and pools.
GeneralResult<Request> unvalidatedConvert(const aidl_hal::Request& request) {
    return Request{
            .inputs = NN_TRY(unvalidatedConvert(request.inputs)),
            .outputs = NN_TRY(unvalidatedConvert(request.outputs)),
            .pools = NN_TRY(unvalidatedConvert(request.pools)),
    };
}

// An AIDL argument with hasNoValue set maps to NO_VALUE; otherwise the value
// lives in one of the request's pools (POOL lifetime).
GeneralResult<Request::Argument> unvalidatedConvert(const aidl_hal::RequestArgument& argument) {
    const auto lifetime = argument.hasNoValue ? Request::Argument::LifeTime::NO_VALUE
                                              : Request::Argument::LifeTime::POOL;
    return Request::Argument{
            .lifetime = lifetime,
            .location = NN_TRY(unvalidatedConvert(argument.location)),
            .dimensions = NN_TRY(toUnsigned(argument.dimensions)),
    };
}
365
// Maps the AIDL memory-pool union: a `pool` tag holds a Memory to convert,
// a `token` tag holds a (non-negative) memory-domain token. Any other tag
// value is an error.
GeneralResult<Request::MemoryPool> unvalidatedConvert(
        const aidl_hal::RequestMemoryPool& memoryPool) {
    using Tag = aidl_hal::RequestMemoryPool::Tag;
    switch (memoryPool.getTag()) {
        case Tag::pool:
            return unvalidatedConvert(memoryPool.get<Tag::pool>());
        case Tag::token: {
            const auto token = memoryPool.get<Tag::token>();
            VERIFY_NON_NEGATIVE(token) << "Memory pool token must not be negative";
            return static_cast<Request::MemoryDomainToken>(token);
        }
    }
    // Reached only for a tag value outside the union's enumeration.
    return NN_ERROR() << "Invalid Request::MemoryPool tag " << underlyingType(memoryPool.getTag());
}
380
// Only the statuses enumerated here are accepted from the HAL; any other
// value (e.g. from a newer interface) is reported as a conversion error.
GeneralResult<ErrorStatus> unvalidatedConvert(const aidl_hal::ErrorStatus& status) {
    switch (status) {
        case aidl_hal::ErrorStatus::NONE:
        case aidl_hal::ErrorStatus::DEVICE_UNAVAILABLE:
        case aidl_hal::ErrorStatus::GENERAL_FAILURE:
        case aidl_hal::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
        case aidl_hal::ErrorStatus::INVALID_ARGUMENT:
        case aidl_hal::ErrorStatus::MISSED_DEADLINE_TRANSIENT:
        case aidl_hal::ErrorStatus::MISSED_DEADLINE_PERSISTENT:
        case aidl_hal::ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
        case aidl_hal::ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
            return static_cast<ErrorStatus>(status);
    }
    return NN_ERROR() << "Invalid ErrorStatus " << underlyingType(status);
}
396
// Direct cast; values are checked by validate() in the convert() path.
GeneralResult<ExecutionPreference> unvalidatedConvert(
        const aidl_hal::ExecutionPreference& executionPreference) {
    return static_cast<ExecutionPreference>(executionPreference);
}
401
Michael Butlerfadeb8a2021-02-07 00:11:13 -0800402GeneralResult<SharedHandle> unvalidatedConvert(const NativeHandle& aidlNativeHandle) {
Lev Proleev6b6dfcd2020-11-11 18:28:50 +0000403 std::vector<base::unique_fd> fds;
404 fds.reserve(aidlNativeHandle.fds.size());
405 for (const auto& fd : aidlNativeHandle.fds) {
406 int dupFd = dup(fd.get());
407 if (dupFd == -1) {
408 // TODO(b/120417090): is ANEURALNETWORKS_UNEXPECTED_NULL the correct error to return
409 // here?
410 return NN_ERROR() << "Failed to dup the fd";
411 }
412 fds.emplace_back(dupFd);
413 }
414
415 return std::make_shared<const Handle>(Handle{
416 .fds = std::move(fds),
417 .ints = aidlNativeHandle.ints,
418 });
419}
420
// Public entry point: unvalidated conversion plus version/validity checking.
GeneralResult<ExecutionPreference> convert(
        const aidl_hal::ExecutionPreference& executionPreference) {
    return validatedConvert(executionPreference);
}
425
Michael Butlerfadeb8a2021-02-07 00:11:13 -0800426GeneralResult<SharedMemory> convert(const aidl_hal::Memory& operand) {
Lev Proleev6b6dfcd2020-11-11 18:28:50 +0000427 return validatedConvert(operand);
428}
429
GeneralResult<Model> convert(const aidl_hal::Model& model) {
    return validatedConvert(model);
}

// NOTE(review): Operand, OperandType, RequestMemoryPool, and the Operation
// vector deliberately skip validation (unvalidatedConvert) while the other
// overloads validate — presumably because callers convert these as fragments
// of a larger object that is validated as a whole; confirm with call sites.
GeneralResult<Operand> convert(const aidl_hal::Operand& operand) {
    return unvalidatedConvert(operand);
}

GeneralResult<OperandType> convert(const aidl_hal::OperandType& operandType) {
    return unvalidatedConvert(operandType);
}

GeneralResult<Priority> convert(const aidl_hal::Priority& priority) {
    return validatedConvert(priority);
}

GeneralResult<Request::MemoryPool> convert(const aidl_hal::RequestMemoryPool& memoryPool) {
    return unvalidatedConvert(memoryPool);
}

GeneralResult<Request> convert(const aidl_hal::Request& request) {
    return validatedConvert(request);
}

GeneralResult<std::vector<Operation>> convert(const std::vector<aidl_hal::Operation>& operations) {
    return unvalidatedConvert(operations);
}
457
Michael Butlerfadeb8a2021-02-07 00:11:13 -0800458GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories) {
Lev Proleev6b6dfcd2020-11-11 18:28:50 +0000459 return validatedConvert(memories);
460}
461
462GeneralResult<std::vector<uint32_t>> toUnsigned(const std::vector<int32_t>& vec) {
463 if (!std::all_of(vec.begin(), vec.end(), [](int32_t v) { return v >= 0; })) {
464 return NN_ERROR() << "Negative value passed to conversion from signed to unsigned";
465 }
466 return std::vector<uint32_t>(vec.begin(), vec.end());
467}
468
469} // namespace android::nn
470
471namespace aidl::android::hardware::neuralnetworks::utils {
472namespace {
473
474template <typename Input>
475using UnvalidatedConvertOutput =
476 std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
477
478template <typename Type>
479nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
480 const std::vector<Type>& arguments) {
481 std::vector<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
482 for (size_t i = 0; i < arguments.size(); ++i) {
483 halObject[i] = NN_TRY(unvalidatedConvert(arguments[i]));
484 }
485 return halObject;
486}
487
// Checks that a canonical object is valid for (at most) kVersion, then
// converts it to its HAL representation.
template <typename Type>
nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
    const auto maybeVersion = nn::validate(canonical);
    if (!maybeVersion.has_value()) {
        return nn::error() << maybeVersion.error();
    }
    const auto version = maybeVersion.value();
    if (version > kVersion) {
        return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
    }
    return utils::unvalidatedConvert(canonical);
}
500
501template <typename Type>
502nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
503 const std::vector<Type>& arguments) {
504 std::vector<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
505 for (size_t i = 0; i < arguments.size(); ++i) {
506 halObject[i] = NN_TRY(validatedConvert(arguments[i]));
507 }
508 return halObject;
509}
510
511} // namespace
512
// Deep-copies a canonical Handle into an AIDL NativeHandle by dup()-ing
// every file descriptor; fails if any dup() fails. Ints are copied verbatim.
nn::GeneralResult<common::NativeHandle> unvalidatedConvert(const nn::SharedHandle& sharedHandle) {
    common::NativeHandle aidlNativeHandle;
    aidlNativeHandle.fds.reserve(sharedHandle->fds.size());
    for (const auto& fd : sharedHandle->fds) {
        int dupFd = dup(fd.get());
        if (dupFd == -1) {
            // TODO(b/120417090): is ANEURALNETWORKS_UNEXPECTED_NULL the correct error to return
            // here?
            return NN_ERROR() << "Failed to dup the fd";
        }
        aidlNativeHandle.fds.emplace_back(dupFd);
    }
    aidlNativeHandle.ints = sharedHandle->ints;
    return aidlNativeHandle;
}
528
Michael Butlerfadeb8a2021-02-07 00:11:13 -0800529nn::GeneralResult<Memory> unvalidatedConvert(const nn::SharedMemory& memory) {
530 CHECK(memory != nullptr);
531 if (memory->size > std::numeric_limits<int64_t>::max()) {
Lev Proleev6b6dfcd2020-11-11 18:28:50 +0000532 return NN_ERROR() << "Memory size doesn't fit into int64_t.";
533 }
534 return Memory{
Michael Butlerfadeb8a2021-02-07 00:11:13 -0800535 .handle = NN_TRY(unvalidatedConvert(memory->handle)),
536 .size = static_cast<int64_t>(memory->size),
537 .name = memory->name,
Lev Proleev6b6dfcd2020-11-11 18:28:50 +0000538 };
539}
540
// Maps a canonical ErrorStatus onto the AIDL enum. Statuses that have no
// AIDL representation collapse to GENERAL_FAILURE rather than failing the
// conversion.
nn::GeneralResult<ErrorStatus> unvalidatedConvert(const nn::ErrorStatus& errorStatus) {
    switch (errorStatus) {
        case nn::ErrorStatus::NONE:
        case nn::ErrorStatus::DEVICE_UNAVAILABLE:
        case nn::ErrorStatus::GENERAL_FAILURE:
        case nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
        case nn::ErrorStatus::INVALID_ARGUMENT:
        case nn::ErrorStatus::MISSED_DEADLINE_TRANSIENT:
        case nn::ErrorStatus::MISSED_DEADLINE_PERSISTENT:
        case nn::ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
        case nn::ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
            return static_cast<ErrorStatus>(errorStatus);
        default:
            return ErrorStatus::GENERAL_FAILURE;
    }
}
557
// Converts an output shape; canonical dimensions are unsigned and must fit
// into the signed AIDL representation (enforced by toSigned).
nn::GeneralResult<OutputShape> unvalidatedConvert(const nn::OutputShape& outputShape) {
    return OutputShape{.dimensions = NN_TRY(toSigned(outputShape.dimensions)),
                       .isSufficient = outputShape.isSufficient};
}
562
Michael Butlerfadeb8a2021-02-07 00:11:13 -0800563nn::GeneralResult<Memory> convert(const nn::SharedMemory& memory) {
Lev Proleev6b6dfcd2020-11-11 18:28:50 +0000564 return validatedConvert(memory);
565}
566
// Public entry points: validation plus conversion to the HAL representation.
nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& errorStatus) {
    return validatedConvert(errorStatus);
}

nn::GeneralResult<std::vector<OutputShape>> convert(
        const std::vector<nn::OutputShape>& outputShapes) {
    return validatedConvert(outputShapes);
}
575
576nn::GeneralResult<std::vector<int32_t>> toSigned(const std::vector<uint32_t>& vec) {
577 if (!std::all_of(vec.begin(), vec.end(),
578 [](uint32_t v) { return v <= std::numeric_limits<int32_t>::max(); })) {
579 return NN_ERROR() << "Vector contains a value that doesn't fit into int32_t.";
580 }
581 return std::vector<int32_t>(vec.begin(), vec.end());
582}
583
584} // namespace aidl::android::hardware::neuralnetworks::utils