blob: 45bc005e9f6c8e6efb484bc386ece744c5ba5cc9 [file] [log] [blame]
Lev Proleev6b6dfcd2020-11-11 18:28:50 +00001/*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "Conversions.h"
18
19#include <aidl/android/hardware/common/NativeHandle.h>
20#include <android-base/logging.h>
Lev Proleev900c28a2021-01-26 19:40:20 +000021#include <android-base/unique_fd.h>
22#include <android/binder_auto_utils.h>
Michael Butlerab2f4822021-02-08 00:05:07 -080023#include <android/hardware_buffer.h>
24#include <cutils/native_handle.h>
Lev Proleev6b6dfcd2020-11-11 18:28:50 +000025#include <nnapi/OperandTypes.h>
26#include <nnapi/OperationTypes.h>
27#include <nnapi/Result.h>
28#include <nnapi/SharedMemory.h>
29#include <nnapi/TypeUtils.h>
30#include <nnapi/Types.h>
31#include <nnapi/Validation.h>
32#include <nnapi/hal/CommonUtils.h>
33#include <nnapi/hal/HandleError.h>
Michael Butlerab2f4822021-02-08 00:05:07 -080034#include <vndk/hardware_buffer.h>
Lev Proleev6b6dfcd2020-11-11 18:28:50 +000035
36#include <algorithm>
37#include <chrono>
38#include <functional>
39#include <iterator>
40#include <limits>
41#include <type_traits>
42#include <utility>
43
// Returns a GeneralResult error from the enclosing function when `value` is negative.
// `while` (rather than `if`) is used so the call site can stream additional error
// context after the macro: VERIFY_NON_NEGATIVE(x) << "message"; — the streamed text
// attaches to the NN_ERROR() being returned.
#define VERIFY_NON_NEGATIVE(value) \
    while (UNLIKELY(value < 0)) return NN_ERROR()

// Same idiom as VERIFY_NON_NEGATIVE, but fails when `value` exceeds
// std::numeric_limits<int32_t>::max() (used before narrowing casts to int32_t).
#define VERIFY_LE_INT32_MAX(value) \
    while (UNLIKELY(value > std::numeric_limits<int32_t>::max())) return NN_ERROR()
Lev Proleev6b6dfcd2020-11-11 18:28:50 +000049
namespace {
// Returns the value of a scoped enum converted to its underlying integer type.
template <typename Type>
constexpr std::underlying_type_t<Type> underlyingType(Type value) {
    return static_cast<std::underlying_type_t<Type>>(value);
}

// Highest canonical NNAPI version this AIDL conversion layer supports.
constexpr auto kVersion = android::nn::Version::ANDROID_S;
// Sentinel used in aidl_hal::Timing fields meaning "no timing information".
constexpr int64_t kNoTiming = -1;

}  // namespace
60
61namespace android::nn {
62namespace {
63
Michael Butlerfadeb8a2021-02-07 00:11:13 -080064using ::aidl::android::hardware::common::NativeHandle;
65
// Returns true for canonical operand types representable over this AIDL interface,
// false for legacy OEM types (not carried by AIDL), and defers to nn::isExtension()
// for out-of-enum values. The switch deliberately has no `default:` so the compiler
// warns when a new OperandType enumerator is added but not handled here.
constexpr auto validOperandType(nn::OperandType operandType) {
    switch (operandType) {
        case nn::OperandType::FLOAT32:
        case nn::OperandType::INT32:
        case nn::OperandType::UINT32:
        case nn::OperandType::TENSOR_FLOAT32:
        case nn::OperandType::TENSOR_INT32:
        case nn::OperandType::TENSOR_QUANT8_ASYMM:
        case nn::OperandType::BOOL:
        case nn::OperandType::TENSOR_QUANT16_SYMM:
        case nn::OperandType::TENSOR_FLOAT16:
        case nn::OperandType::TENSOR_BOOL8:
        case nn::OperandType::FLOAT16:
        case nn::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
        case nn::OperandType::TENSOR_QUANT16_ASYMM:
        case nn::OperandType::TENSOR_QUANT8_SYMM:
        case nn::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
        case nn::OperandType::SUBGRAPH:
            return true;
        case nn::OperandType::OEM:
        case nn::OperandType::TENSOR_OEM_BYTE:
            return false;
    }
    // Values outside the enum are valid only if they encode an extension type.
    return nn::isExtension(operandType);
}
91
// The canonical type produced by unvalidatedConvert() for a given AIDL input type.
template <typename Input>
using UnvalidatedConvertOutput =
        std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;

// Element-wise unvalidatedConvert over a vector; fails fast on the first
// element that cannot be converted.
template <typename Type>
GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
        const std::vector<Type>& arguments) {
    std::vector<UnvalidatedConvertOutput<Type>> canonical;
    canonical.reserve(arguments.size());
    for (const auto& argument : arguments) {
        canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
    }
    return canonical;
}

// Overload so vector payloads participate in unvalidatedConvert name lookup.
template <typename Type>
GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
        const std::vector<Type>& arguments) {
    return unvalidatedConvertVec(arguments);
}
112
113template <typename Type>
114GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
115 auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
116 const auto maybeVersion = validate(canonical);
117 if (!maybeVersion.has_value()) {
118 return error() << maybeVersion.error();
119 }
120 const auto version = maybeVersion.value();
121 if (version > kVersion) {
122 return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
123 }
124 return canonical;
125}
126
127template <typename Type>
128GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
129 const std::vector<Type>& arguments) {
130 std::vector<UnvalidatedConvertOutput<Type>> canonical;
131 canonical.reserve(arguments.size());
132 for (const auto& argument : arguments) {
133 canonical.push_back(NN_TRY(validatedConvert(argument)));
134 }
135 return canonical;
136}
137
Michael Butlerab2f4822021-02-08 00:05:07 -0800138GeneralResult<Handle> unvalidatedConvertHelper(const NativeHandle& aidlNativeHandle) {
139 std::vector<base::unique_fd> fds;
140 fds.reserve(aidlNativeHandle.fds.size());
141 for (const auto& fd : aidlNativeHandle.fds) {
Lev Proleev900c28a2021-01-26 19:40:20 +0000142 auto duplicatedFd = NN_TRY(dupFd(fd.get()));
143 fds.emplace_back(duplicatedFd.release());
Michael Butlerab2f4822021-02-08 00:05:07 -0800144 }
145
146 return Handle{.fds = std::move(fds), .ints = aidlNativeHandle.ints};
147}
148
149struct NativeHandleDeleter {
150 void operator()(native_handle_t* handle) const {
151 if (handle) {
152 native_handle_close(handle);
153 native_handle_delete(handle);
154 }
155 }
156};
157
158using UniqueNativeHandle = std::unique_ptr<native_handle_t, NativeHandleDeleter>;
159
// Clones an AIDL NativeHandle into a freshly allocated native_handle_t.
// All fds are duplicated up front (so a mid-loop failure leaks nothing into
// the result); ownership of the duplicates is then transferred into the
// returned handle, which closes and frees them via NativeHandleDeleter.
static GeneralResult<UniqueNativeHandle> nativeHandleFromAidlHandle(const NativeHandle& handle) {
    std::vector<base::unique_fd> fds;
    fds.reserve(handle.fds.size());
    for (const auto& fd : handle.fds) {
        auto duplicatedFd = NN_TRY(dupFd(fd.get()));
        fds.emplace_back(duplicatedFd.release());
    }

    // native_handle_create takes int counts; guard the narrowing casts.
    constexpr size_t kIntMax = std::numeric_limits<int>::max();
    CHECK_LE(handle.fds.size(), kIntMax);
    CHECK_LE(handle.ints.size(), kIntMax);
    native_handle_t* nativeHandle = native_handle_create(static_cast<int>(handle.fds.size()),
                                                         static_cast<int>(handle.ints.size()));
    if (nativeHandle == nullptr) {
        return NN_ERROR() << "Failed to create native_handle";
    }
    // Hand fd ownership to the native_handle: after release() the unique_fds
    // no longer close anything; the handle's deleter is now responsible.
    for (size_t i = 0; i < fds.size(); ++i) {
        nativeHandle->data[i] = fds[i].release();
    }
    // The ints are stored immediately after the fds in the data[] array.
    std::copy(handle.ints.begin(), handle.ints.end(), &nativeHandle->data[nativeHandle->numFds]);

    return UniqueNativeHandle(nativeHandle);
}
183
Lev Proleev6b6dfcd2020-11-11 18:28:50 +0000184} // anonymous namespace
185
// AIDL -> canonical enum conversions. Operand/operation types reject negative
// values (negative space is not used by the AIDL HAL); device type and
// priority are direct casts whose range is checked later by validate().
GeneralResult<OperandType> unvalidatedConvert(const aidl_hal::OperandType& operandType) {
    VERIFY_NON_NEGATIVE(underlyingType(operandType)) << "Negative operand types are not allowed.";
    return static_cast<OperandType>(operandType);
}

GeneralResult<OperationType> unvalidatedConvert(const aidl_hal::OperationType& operationType) {
    VERIFY_NON_NEGATIVE(underlyingType(operationType))
            << "Negative operation types are not allowed.";
    return static_cast<OperationType>(operationType);
}

GeneralResult<DeviceType> unvalidatedConvert(const aidl_hal::DeviceType& deviceType) {
    return static_cast<DeviceType>(deviceType);
}

GeneralResult<Priority> unvalidatedConvert(const aidl_hal::Priority& priority) {
    return static_cast<Priority>(priority);
}
204
// Converts AIDL Capabilities to canonical form. Operand types are pre-screened
// with validOperandType() because the OperandPerformanceTable only accepts
// known-valid types; the table itself is built via makeGeneralFailure so a
// create() error surfaces as GENERAL_FAILURE.
GeneralResult<Capabilities> unvalidatedConvert(const aidl_hal::Capabilities& capabilities) {
    const bool validOperandTypes = std::all_of(
            capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
            [](const aidl_hal::OperandPerformance& operandPerformance) {
                const auto maybeType = unvalidatedConvert(operandPerformance.type);
                return !maybeType.has_value() ? false : validOperandType(maybeType.value());
            });
    if (!validOperandTypes) {
        return NN_ERROR() << "Invalid OperandType when unvalidatedConverting OperandPerformance in "
                             "Capabilities";
    }

    auto operandPerformance = NN_TRY(unvalidatedConvert(capabilities.operandPerformance));
    auto table = NN_TRY(hal::utils::makeGeneralFailure(
            Capabilities::OperandPerformanceTable::create(std::move(operandPerformance)),
            nn::ErrorStatus::GENERAL_FAILURE));

    return Capabilities{
            .relaxedFloat32toFloat16PerformanceScalar = NN_TRY(
                    unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar)),
            .relaxedFloat32toFloat16PerformanceTensor = NN_TRY(
                    unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor)),
            .operandPerformance = std::move(table),
            .ifPerformance = NN_TRY(unvalidatedConvert(capabilities.ifPerformance)),
            .whilePerformance = NN_TRY(unvalidatedConvert(capabilities.whilePerformance)),
    };
}
232
233GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
234 const aidl_hal::OperandPerformance& operandPerformance) {
235 return Capabilities::OperandPerformance{
236 .type = NN_TRY(unvalidatedConvert(operandPerformance.type)),
237 .info = NN_TRY(unvalidatedConvert(operandPerformance.info)),
238 };
239}
240
241GeneralResult<Capabilities::PerformanceInfo> unvalidatedConvert(
242 const aidl_hal::PerformanceInfo& performanceInfo) {
243 return Capabilities::PerformanceInfo{
244 .execTime = performanceInfo.execTime,
245 .powerUsage = performanceInfo.powerUsage,
246 };
247}
248
// Converts an AIDL DataLocation (int64 fields) to the canonical uint32-based
// DataLocation, rejecting negative values and values above UINT32_MAX before
// the narrowing casts.
GeneralResult<DataLocation> unvalidatedConvert(const aidl_hal::DataLocation& location) {
    VERIFY_NON_NEGATIVE(location.poolIndex) << "DataLocation: pool index must not be negative";
    VERIFY_NON_NEGATIVE(location.offset) << "DataLocation: offset must not be negative";
    VERIFY_NON_NEGATIVE(location.length) << "DataLocation: length must not be negative";
    VERIFY_NON_NEGATIVE(location.padding) << "DataLocation: padding must not be negative";
    if (location.offset > std::numeric_limits<uint32_t>::max()) {
        return NN_ERROR() << "DataLocation: offset must be <= std::numeric_limits<uint32_t>::max()";
    }
    if (location.length > std::numeric_limits<uint32_t>::max()) {
        return NN_ERROR() << "DataLocation: length must be <= std::numeric_limits<uint32_t>::max()";
    }
    if (location.padding > std::numeric_limits<uint32_t>::max()) {
        return NN_ERROR()
               << "DataLocation: padding must be <= std::numeric_limits<uint32_t>::max()";
    }
    return DataLocation{
            .poolIndex = static_cast<uint32_t>(location.poolIndex),
            .offset = static_cast<uint32_t>(location.offset),
            .length = static_cast<uint32_t>(location.length),
            .padding = static_cast<uint32_t>(location.padding),
    };
}
271
272GeneralResult<Operation> unvalidatedConvert(const aidl_hal::Operation& operation) {
273 return Operation{
274 .type = NN_TRY(unvalidatedConvert(operation.type)),
275 .inputs = NN_TRY(toUnsigned(operation.inputs)),
276 .outputs = NN_TRY(toUnsigned(operation.outputs)),
277 };
278}
279
// Direct cast: AIDL OperandLifeTime and canonical Operand::LifeTime use the
// same enumerator values; validate() later checks the range.
GeneralResult<Operand::LifeTime> unvalidatedConvert(
        const aidl_hal::OperandLifeTime& operandLifeTime) {
    return static_cast<Operand::LifeTime>(operandLifeTime);
}

// Converts an AIDL Operand field by field; scale/zeroPoint carry over as-is.
GeneralResult<Operand> unvalidatedConvert(const aidl_hal::Operand& operand) {
    return Operand{
            .type = NN_TRY(unvalidatedConvert(operand.type)),
            .dimensions = NN_TRY(toUnsigned(operand.dimensions)),
            .scale = operand.scale,
            .zeroPoint = operand.zeroPoint,
            .lifetime = NN_TRY(unvalidatedConvert(operand.lifetime)),
            .location = NN_TRY(unvalidatedConvert(operand.location)),
            .extraParams = NN_TRY(unvalidatedConvert(operand.extraParams)),
    };
}
296
// Converts optional AIDL extra params to the canonical ExtraParams variant:
// absent -> NoParams, channelQuant -> converted quant params, extension ->
// raw bytes copied through. Unknown union tags are an error.
GeneralResult<Operand::ExtraParams> unvalidatedConvert(
        const std::optional<aidl_hal::OperandExtraParams>& optionalExtraParams) {
    if (!optionalExtraParams.has_value()) {
        return Operand::NoParams{};
    }
    const auto& extraParams = optionalExtraParams.value();
    using Tag = aidl_hal::OperandExtraParams::Tag;
    switch (extraParams.getTag()) {
        case Tag::channelQuant:
            return unvalidatedConvert(extraParams.get<Tag::channelQuant>());
        case Tag::extension:
            return extraParams.get<Tag::extension>();
    }
    return NN_ERROR() << "Unrecognized Operand::ExtraParams tag: "
                      << underlyingType(extraParams.getTag());
}
313
// Converts per-channel quantization parameters; the channel dimension is
// transported as a signed int and must be non-negative before the cast.
GeneralResult<Operand::SymmPerChannelQuantParams> unvalidatedConvert(
        const aidl_hal::SymmPerChannelQuantParams& symmPerChannelQuantParams) {
    VERIFY_NON_NEGATIVE(symmPerChannelQuantParams.channelDim)
            << "Per-channel quantization channel dimension must not be negative.";
    return Operand::SymmPerChannelQuantParams{
            .scales = symmPerChannelQuantParams.scales,
            .channelDim = static_cast<uint32_t>(symmPerChannelQuantParams.channelDim),
    };
}
323
// Converts an AIDL Model: main graph, referenced subgraphs, constant operand
// values, memory pools, and extension prefixes.
GeneralResult<Model> unvalidatedConvert(const aidl_hal::Model& model) {
    return Model{
            .main = NN_TRY(unvalidatedConvert(model.main)),
            .referenced = NN_TRY(unvalidatedConvert(model.referenced)),
            .operandValues = NN_TRY(unvalidatedConvert(model.operandValues)),
            .pools = NN_TRY(unvalidatedConvert(model.pools)),
            .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
            .extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix)),
    };
}

// Converts one subgraph: its operands, operations, and signed I/O indexes
// (converted to unsigned).
GeneralResult<Model::Subgraph> unvalidatedConvert(const aidl_hal::Subgraph& subgraph) {
    return Model::Subgraph{
            .operands = NN_TRY(unvalidatedConvert(subgraph.operands)),
            .operations = NN_TRY(unvalidatedConvert(subgraph.operations)),
            .inputIndexes = NN_TRY(toUnsigned(subgraph.inputIndexes)),
            .outputIndexes = NN_TRY(toUnsigned(subgraph.outputIndexes)),
    };
}
343
// Copies an extension name/prefix pair through unchanged.
GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
        const aidl_hal::ExtensionNameAndPrefix& extensionNameAndPrefix) {
    return Model::ExtensionNameAndPrefix{
            .name = extensionNameAndPrefix.name,
            .prefix = extensionNameAndPrefix.prefix,
    };
}

// Converts an AIDL Extension: name plus its operand type descriptors.
GeneralResult<Extension> unvalidatedConvert(const aidl_hal::Extension& extension) {
    return Extension{
            .name = extension.name,
            .operandTypes = NN_TRY(unvalidatedConvert(extension.operandTypes)),
    };
}
358
// Converts extension operand type information; byteSize travels as a signed
// int and must be non-negative before the cast to uint32_t.
GeneralResult<Extension::OperandTypeInformation> unvalidatedConvert(
        const aidl_hal::ExtensionOperandTypeInformation& operandTypeInformation) {
    VERIFY_NON_NEGATIVE(operandTypeInformation.byteSize)
            << "Extension operand type byte size must not be negative";
    return Extension::OperandTypeInformation{
            .type = operandTypeInformation.type,
            .isTensor = operandTypeInformation.isTensor,
            .byteSize = static_cast<uint32_t>(operandTypeInformation.byteSize),
    };
}

// Converts an AIDL OutputShape (signed dimensions -> unsigned).
GeneralResult<OutputShape> unvalidatedConvert(const aidl_hal::OutputShape& outputShape) {
    return OutputShape{
            .dimensions = NN_TRY(toUnsigned(outputShape.dimensions)),
            .isSufficient = outputShape.isSufficient,
    };
}
376
377GeneralResult<MeasureTiming> unvalidatedConvert(bool measureTiming) {
378 return measureTiming ? MeasureTiming::YES : MeasureTiming::NO;
379}
380
// Rounds `value` up to the nearest multiple of `multiple`. `multiple` must be
// nonzero; the result wraps modulo 2^32 if value + multiple - 1 overflows.
static uint32_t roundUpToMultiple(uint32_t value, uint32_t multiple) {
    const uint32_t quotientRoundedUp = (value + multiple - 1) / multiple;
    return quotientRoundedUp * multiple;
}
384
Michael Butlerfadeb8a2021-02-07 00:11:13 -0800385GeneralResult<SharedMemory> unvalidatedConvert(const aidl_hal::Memory& memory) {
Lev Proleev6b6dfcd2020-11-11 18:28:50 +0000386 VERIFY_NON_NEGATIVE(memory.size) << "Memory size must not be negative";
Lev Proleev900c28a2021-01-26 19:40:20 +0000387 if (memory.size > std::numeric_limits<size_t>::max()) {
Michael Butlerab2f4822021-02-08 00:05:07 -0800388 return NN_ERROR() << "Memory: size must be <= std::numeric_limits<size_t>::max()";
389 }
390
391 if (memory.name != "hardware_buffer_blob") {
392 return std::make_shared<const Memory>(Memory{
393 .handle = NN_TRY(unvalidatedConvertHelper(memory.handle)),
Lev Proleev900c28a2021-01-26 19:40:20 +0000394 .size = static_cast<size_t>(memory.size),
Michael Butlerab2f4822021-02-08 00:05:07 -0800395 .name = memory.name,
396 });
397 }
398
399 const auto size = static_cast<uint32_t>(memory.size);
400 const auto format = AHARDWAREBUFFER_FORMAT_BLOB;
401 const auto usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN;
402 const uint32_t width = size;
403 const uint32_t height = 1; // height is always 1 for BLOB mode AHardwareBuffer.
404 const uint32_t layers = 1; // layers is always 1 for BLOB mode AHardwareBuffer.
405
406 const UniqueNativeHandle handle = NN_TRY(nativeHandleFromAidlHandle(memory.handle));
407 const native_handle_t* nativeHandle = handle.get();
408
409 // AHardwareBuffer_createFromHandle() might fail because an allocator
410 // expects a specific stride value. In that case, we try to guess it by
411 // aligning the width to small powers of 2.
412 // TODO(b/174120849): Avoid stride assumptions.
413 AHardwareBuffer* hardwareBuffer = nullptr;
414 status_t status = UNKNOWN_ERROR;
415 for (uint32_t alignment : {1, 4, 32, 64, 128, 2, 8, 16}) {
416 const uint32_t stride = roundUpToMultiple(width, alignment);
417 AHardwareBuffer_Desc desc{
418 .width = width,
419 .height = height,
420 .layers = layers,
421 .format = format,
422 .usage = usage,
423 .stride = stride,
424 };
425 status = AHardwareBuffer_createFromHandle(&desc, nativeHandle,
426 AHARDWAREBUFFER_CREATE_FROM_HANDLE_METHOD_CLONE,
427 &hardwareBuffer);
428 if (status == NO_ERROR) {
429 break;
430 }
431 }
432 if (status != NO_ERROR) {
433 return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
434 << "Can't create AHardwareBuffer from handle. Error: " << status;
435 }
436
Michael Butlerfadeb8a2021-02-07 00:11:13 -0800437 return std::make_shared<const Memory>(Memory{
Michael Butlerab2f4822021-02-08 00:05:07 -0800438 .handle = HardwareBufferHandle(hardwareBuffer, /*takeOwnership=*/true),
Lev Proleev900c28a2021-01-26 19:40:20 +0000439 .size = static_cast<size_t>(memory.size),
Lev Proleev6b6dfcd2020-11-11 18:28:50 +0000440 .name = memory.name,
Michael Butlerfadeb8a2021-02-07 00:11:13 -0800441 });
Lev Proleev6b6dfcd2020-11-11 18:28:50 +0000442}
443
// Converts AIDL Timing (int64 durations, -1 meaning "no timing") to canonical
// Timing with OptionalDuration fields. Values below -1 are invalid.
GeneralResult<Timing> unvalidatedConvert(const aidl_hal::Timing& timing) {
    if (timing.timeInDriver < -1) {
        return NN_ERROR() << "Timing: timeInDriver must not be less than -1";
    }
    if (timing.timeOnDevice < -1) {
        return NN_ERROR() << "Timing: timeOnDevice must not be less than -1";
    }
    // kNoTiming (-1) maps to an empty OptionalDuration; any other (non-negative)
    // value becomes a Duration.
    constexpr auto convertTiming = [](int64_t halTiming) -> OptionalDuration {
        if (halTiming == kNoTiming) {
            return {};
        }
        return nn::Duration(static_cast<uint64_t>(halTiming));
    };
    return Timing{.timeOnDevice = convertTiming(timing.timeOnDevice),
                  .timeInDriver = convertTiming(timing.timeInDriver)};
}
460
// Copies the raw constant-operand byte blob into canonical OperandValues.
GeneralResult<Model::OperandValues> unvalidatedConvert(const std::vector<uint8_t>& operandValues) {
    return Model::OperandValues(operandValues.data(), operandValues.size());
}

// Converts an AIDL BufferDesc (signed dimensions -> unsigned).
GeneralResult<BufferDesc> unvalidatedConvert(const aidl_hal::BufferDesc& bufferDesc) {
    return BufferDesc{.dimensions = NN_TRY(toUnsigned(bufferDesc.dimensions))};
}

// Converts an AIDL BufferRole; indexes travel as signed ints and must be
// non-negative before the casts.
GeneralResult<BufferRole> unvalidatedConvert(const aidl_hal::BufferRole& bufferRole) {
    VERIFY_NON_NEGATIVE(bufferRole.modelIndex) << "BufferRole: modelIndex must not be negative";
    VERIFY_NON_NEGATIVE(bufferRole.ioIndex) << "BufferRole: ioIndex must not be negative";
    return BufferRole{
            .modelIndex = static_cast<uint32_t>(bufferRole.modelIndex),
            .ioIndex = static_cast<uint32_t>(bufferRole.ioIndex),
            .probability = bufferRole.probability,
    };
}
478
// Converts an AIDL Request: input/output arguments plus memory pools.
GeneralResult<Request> unvalidatedConvert(const aidl_hal::Request& request) {
    return Request{
            .inputs = NN_TRY(unvalidatedConvert(request.inputs)),
            .outputs = NN_TRY(unvalidatedConvert(request.outputs)),
            .pools = NN_TRY(unvalidatedConvert(request.pools)),
    };
}

// Converts a request argument. AIDL has no CONSTANT lifetime for arguments:
// hasNoValue selects NO_VALUE, everything else references a pool.
GeneralResult<Request::Argument> unvalidatedConvert(const aidl_hal::RequestArgument& argument) {
    const auto lifetime = argument.hasNoValue ? Request::Argument::LifeTime::NO_VALUE
                                              : Request::Argument::LifeTime::POOL;
    return Request::Argument{
            .lifetime = lifetime,
            .location = NN_TRY(unvalidatedConvert(argument.location)),
            .dimensions = NN_TRY(toUnsigned(argument.dimensions)),
    };
}
496
// Converts a request memory pool union: either an inline Memory or a memory
// domain token (which must be non-negative). Unknown tags are an error.
GeneralResult<Request::MemoryPool> unvalidatedConvert(
        const aidl_hal::RequestMemoryPool& memoryPool) {
    using Tag = aidl_hal::RequestMemoryPool::Tag;
    switch (memoryPool.getTag()) {
        case Tag::pool:
            return unvalidatedConvert(memoryPool.get<Tag::pool>());
        case Tag::token: {
            const auto token = memoryPool.get<Tag::token>();
            VERIFY_NON_NEGATIVE(token) << "Memory pool token must not be negative";
            return static_cast<Request::MemoryDomainToken>(token);
        }
    }
    return NN_ERROR() << "Invalid Request::MemoryPool tag " << underlyingType(memoryPool.getTag());
}
511
// Converts an AIDL ErrorStatus, accepting only known enumerators; any other
// value is itself a conversion error.
GeneralResult<ErrorStatus> unvalidatedConvert(const aidl_hal::ErrorStatus& status) {
    switch (status) {
        case aidl_hal::ErrorStatus::NONE:
        case aidl_hal::ErrorStatus::DEVICE_UNAVAILABLE:
        case aidl_hal::ErrorStatus::GENERAL_FAILURE:
        case aidl_hal::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
        case aidl_hal::ErrorStatus::INVALID_ARGUMENT:
        case aidl_hal::ErrorStatus::MISSED_DEADLINE_TRANSIENT:
        case aidl_hal::ErrorStatus::MISSED_DEADLINE_PERSISTENT:
        case aidl_hal::ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
        case aidl_hal::ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
            return static_cast<ErrorStatus>(status);
    }
    return NN_ERROR() << "Invalid ErrorStatus " << underlyingType(status);
}
527
// Direct cast; range is checked later by validate().
GeneralResult<ExecutionPreference> unvalidatedConvert(
        const aidl_hal::ExecutionPreference& executionPreference) {
    return static_cast<ExecutionPreference>(executionPreference);
}

// Deep-copies an AIDL NativeHandle into a shared canonical handle.
GeneralResult<SharedHandle> unvalidatedConvert(const NativeHandle& aidlNativeHandle) {
    return std::make_shared<const Handle>(NN_TRY(unvalidatedConvertHelper(aidlNativeHandle)));
}
536
// Duplicates the sync fence fd and wraps it in a canonical SyncFence; the
// caller's ScopedFileDescriptor retains its own fd.
GeneralResult<SyncFence> unvalidatedConvert(const ndk::ScopedFileDescriptor& syncFence) {
    auto duplicatedFd = NN_TRY(dupFd(syncFence.get()));
    return SyncFence::create(std::move(duplicatedFd));
}
541
// Public convert() entry points. Each forwards either to validatedConvert()
// (convert + validate against kVersion) for top-level objects, or directly to
// unvalidatedConvert() for types that are validated as part of an enclosing
// structure (e.g. Operand within Model, a memory pool within Request).
GeneralResult<Capabilities> convert(const aidl_hal::Capabilities& capabilities) {
    return validatedConvert(capabilities);
}

GeneralResult<DeviceType> convert(const aidl_hal::DeviceType& deviceType) {
    return validatedConvert(deviceType);
}

GeneralResult<ErrorStatus> convert(const aidl_hal::ErrorStatus& errorStatus) {
    return validatedConvert(errorStatus);
}

GeneralResult<ExecutionPreference> convert(
        const aidl_hal::ExecutionPreference& executionPreference) {
    return validatedConvert(executionPreference);
}

GeneralResult<SharedMemory> convert(const aidl_hal::Memory& operand) {
    return validatedConvert(operand);
}

GeneralResult<Model> convert(const aidl_hal::Model& model) {
    return validatedConvert(model);
}

GeneralResult<Operand> convert(const aidl_hal::Operand& operand) {
    return unvalidatedConvert(operand);
}

GeneralResult<OperandType> convert(const aidl_hal::OperandType& operandType) {
    return unvalidatedConvert(operandType);
}

GeneralResult<Priority> convert(const aidl_hal::Priority& priority) {
    return validatedConvert(priority);
}

GeneralResult<Request::MemoryPool> convert(const aidl_hal::RequestMemoryPool& memoryPool) {
    return unvalidatedConvert(memoryPool);
}

GeneralResult<Request> convert(const aidl_hal::Request& request) {
    return validatedConvert(request);
}

GeneralResult<Timing> convert(const aidl_hal::Timing& timing) {
    return validatedConvert(timing);
}

GeneralResult<SyncFence> convert(const ndk::ScopedFileDescriptor& syncFence) {
    return unvalidatedConvert(syncFence);
}

GeneralResult<std::vector<Extension>> convert(const std::vector<aidl_hal::Extension>& extension) {
    return validatedConvert(extension);
}

GeneralResult<std::vector<Operation>> convert(const std::vector<aidl_hal::Operation>& operations) {
    return unvalidatedConvert(operations);
}

GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories) {
    return validatedConvert(memories);
}

GeneralResult<std::vector<OutputShape>> convert(
        const std::vector<aidl_hal::OutputShape>& outputShapes) {
    return validatedConvert(outputShapes);
}
611
Lev Proleev6b6dfcd2020-11-11 18:28:50 +0000612GeneralResult<std::vector<uint32_t>> toUnsigned(const std::vector<int32_t>& vec) {
613 if (!std::all_of(vec.begin(), vec.end(), [](int32_t v) { return v >= 0; })) {
614 return NN_ERROR() << "Negative value passed to conversion from signed to unsigned";
615 }
616 return std::vector<uint32_t>(vec.begin(), vec.end());
617}
618
619} // namespace android::nn
620
621namespace aidl::android::hardware::neuralnetworks::utils {
622namespace {
623
// The AIDL type produced by utils::unvalidatedConvert() for a canonical input.
template <typename Input>
using UnvalidatedConvertOutput =
        std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;

// Element-wise canonical -> AIDL conversion over a vector; the first failing
// element aborts the whole conversion.
template <typename Type>
nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
        const std::vector<Type>& arguments) {
    std::vector<UnvalidatedConvertOutput<Type>> halObject;
    halObject.reserve(arguments.size());
    for (const auto& argument : arguments) {
        halObject.push_back(NN_TRY(unvalidatedConvert(argument)));
    }
    return halObject;
}

// Overload so vector payloads participate in unvalidatedConvert name lookup.
template <typename Type>
nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
        const std::vector<Type>& arguments) {
    return unvalidatedConvertVec(arguments);
}
644
645template <typename Type>
Lev Proleev6b6dfcd2020-11-11 18:28:50 +0000646nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
647 const auto maybeVersion = nn::validate(canonical);
648 if (!maybeVersion.has_value()) {
649 return nn::error() << maybeVersion.error();
650 }
651 const auto version = maybeVersion.value();
652 if (version > kVersion) {
653 return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
654 }
655 return utils::unvalidatedConvert(canonical);
656}
657
658template <typename Type>
659nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
660 const std::vector<Type>& arguments) {
661 std::vector<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
662 for (size_t i = 0; i < arguments.size(); ++i) {
663 halObject[i] = NN_TRY(validatedConvert(arguments[i]));
664 }
665 return halObject;
666}
667
Michael Butlerab2f4822021-02-08 00:05:07 -0800668nn::GeneralResult<common::NativeHandle> unvalidatedConvert(const nn::Handle& handle) {
Lev Proleev6b6dfcd2020-11-11 18:28:50 +0000669 common::NativeHandle aidlNativeHandle;
Michael Butlerab2f4822021-02-08 00:05:07 -0800670 aidlNativeHandle.fds.reserve(handle.fds.size());
671 for (const auto& fd : handle.fds) {
Lev Proleev900c28a2021-01-26 19:40:20 +0000672 auto duplicatedFd = NN_TRY(nn::dupFd(fd.get()));
673 aidlNativeHandle.fds.emplace_back(duplicatedFd.release());
Lev Proleev6b6dfcd2020-11-11 18:28:50 +0000674 }
Michael Butlerab2f4822021-02-08 00:05:07 -0800675 aidlNativeHandle.ints = handle.ints;
Lev Proleev6b6dfcd2020-11-11 18:28:50 +0000676 return aidlNativeHandle;
677}
678
// Helper template for std::visit: inherits operator() from each lambda passed
// in, producing an overload set; the deduction guide lets callers write
// overloaded{...} without explicit template arguments (C++17 idiom).
template <class... Ts>
struct overloaded : Ts... {
    using Ts::operator()...;
};
template <class... Ts>
overloaded(Ts...)->overloaded<Ts...>;
686
// Deep-copies a native_handle_t into an AIDL NativeHandle: the first numFds
// entries of data[] are duplicated as fds, the following numInts entries are
// copied as ints. The input handle is not consumed.
static nn::GeneralResult<common::NativeHandle> aidlHandleFromNativeHandle(
        const native_handle_t& handle) {
    common::NativeHandle aidlNativeHandle;

    aidlNativeHandle.fds.reserve(handle.numFds);
    for (int i = 0; i < handle.numFds; ++i) {
        auto duplicatedFd = NN_TRY(nn::dupFd(handle.data[i]));
        aidlNativeHandle.fds.emplace_back(duplicatedFd.release());
    }

    // Ints are stored directly after the fds in the data[] array.
    aidlNativeHandle.ints = std::vector<int>(&handle.data[handle.numFds],
                                             &handle.data[handle.numFds + handle.numInts]);

    return aidlNativeHandle;
}
702
703} // namespace
704
// Copies a fixed-size cache token into the byte vector AIDL transports.
nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(const nn::CacheToken& cacheToken) {
    return std::vector<uint8_t>(cacheToken.begin(), cacheToken.end());
}

// Converts a canonical BufferDesc (unsigned dimensions -> signed, rejecting
// values above INT32_MAX inside toSigned).
nn::GeneralResult<BufferDesc> unvalidatedConvert(const nn::BufferDesc& bufferDesc) {
    return BufferDesc{.dimensions = NN_TRY(toSigned(bufferDesc.dimensions))};
}
712
// Converts a canonical BufferRole to AIDL; the uint32 indexes must fit in the
// int32 fields AIDL uses.
nn::GeneralResult<BufferRole> unvalidatedConvert(const nn::BufferRole& bufferRole) {
    VERIFY_LE_INT32_MAX(bufferRole.modelIndex)
            << "BufferRole: modelIndex must be <= std::numeric_limits<int32_t>::max()";
    VERIFY_LE_INT32_MAX(bufferRole.ioIndex)
            << "BufferRole: ioIndex must be <= std::numeric_limits<int32_t>::max()";
    return BufferRole{
            .modelIndex = static_cast<int32_t>(bufferRole.modelIndex),
            .ioIndex = static_cast<int32_t>(bufferRole.ioIndex),
            .probability = bufferRole.probability,
    };
}
724
// Canonical MeasureTiming maps onto the bool AIDL transports.
nn::GeneralResult<bool> unvalidatedConvert(const nn::MeasureTiming& measureTiming) {
    return measureTiming == nn::MeasureTiming::YES;
}

// Converts a shared canonical handle; the pointer must be non-null.
nn::GeneralResult<common::NativeHandle> unvalidatedConvert(const nn::SharedHandle& sharedHandle) {
    CHECK(sharedHandle != nullptr);
    return unvalidatedConvert(*sharedHandle);
}
733
// Converts a canonical SharedMemory to an AIDL Memory. Plain handle-backed
// memories are deep-copied; AHardwareBuffer-backed memories are converted by
// extracting and cloning the buffer's underlying native handle.
nn::GeneralResult<Memory> unvalidatedConvert(const nn::SharedMemory& memory) {
    CHECK(memory != nullptr);
    // AIDL carries the size as int64_t; reject sizes that would not fit.
    if (memory->size > std::numeric_limits<int64_t>::max()) {
        return NN_ERROR() << "Memory size doesn't fit into int64_t.";
    }
    if (const auto* handle = std::get_if<nn::Handle>(&memory->handle)) {
        return Memory{
                .handle = NN_TRY(unvalidatedConvert(*handle)),
                .size = static_cast<int64_t>(memory->size),
                .name = memory->name,
        };
    }

    // Otherwise the variant holds a HardwareBufferHandle.
    const auto* ahwb = std::get<nn::HardwareBufferHandle>(memory->handle).get();
    AHardwareBuffer_Desc bufferDesc;
    AHardwareBuffer_describe(ahwb, &bufferDesc);

    // Invariants established when the SharedMemory was created: BLOB buffers
    // record their byte size as the width and are named "hardware_buffer_blob";
    // non-BLOB buffers have size 0 and are named "hardware_buffer".
    if (bufferDesc.format == AHARDWAREBUFFER_FORMAT_BLOB) {
        CHECK_EQ(memory->size, bufferDesc.width);
        CHECK_EQ(memory->name, "hardware_buffer_blob");
    } else {
        CHECK_EQ(memory->size, 0u);
        CHECK_EQ(memory->name, "hardware_buffer");
    }

    const native_handle_t* nativeHandle = AHardwareBuffer_getNativeHandle(ahwb);
    if (nativeHandle == nullptr) {
        return NN_ERROR() << "unvalidatedConvert failed because AHardwareBuffer_getNativeHandle "
                             "returned nullptr";
    }

    return Memory{
            .handle = NN_TRY(aidlHandleFromNativeHandle(*nativeHandle)),
            .size = static_cast<int64_t>(memory->size),
            .name = memory->name,
    };
}
771
// Converts a canonical ErrorStatus. Statuses with a defined AIDL counterpart map
// one-to-one via the cast; any other value collapses to GENERAL_FAILURE instead of
// leaking an out-of-range enum value across the HAL boundary.
nn::GeneralResult<ErrorStatus> unvalidatedConvert(const nn::ErrorStatus& errorStatus) {
    switch (errorStatus) {
        case nn::ErrorStatus::NONE:
        case nn::ErrorStatus::DEVICE_UNAVAILABLE:
        case nn::ErrorStatus::GENERAL_FAILURE:
        case nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
        case nn::ErrorStatus::INVALID_ARGUMENT:
        case nn::ErrorStatus::MISSED_DEADLINE_TRANSIENT:
        case nn::ErrorStatus::MISSED_DEADLINE_PERSISTENT:
        case nn::ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
        case nn::ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
            return static_cast<ErrorStatus>(errorStatus);
        default:
            return ErrorStatus::GENERAL_FAILURE;
    }
}
788
// Converts a canonical OutputShape; dimensions are narrowed to int32_t via toSigned.
nn::GeneralResult<OutputShape> unvalidatedConvert(const nn::OutputShape& outputShape) {
    return OutputShape{.dimensions = NN_TRY(toSigned(outputShape.dimensions)),
                       .isSufficient = outputShape.isSufficient};
}
793
// Direct cast; assumes the AIDL and canonical ExecutionPreference enums share
// numeric values.
nn::GeneralResult<ExecutionPreference> unvalidatedConvert(
        const nn::ExecutionPreference& executionPreference) {
    return static_cast<ExecutionPreference>(executionPreference);
}
798
// Direct cast; assumes the AIDL and canonical OperandType enums share numeric values.
nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) {
    return static_cast<OperandType>(operandType);
}
802
// Direct cast; assumes the AIDL OperandLifeTime and canonical Operand::LifeTime enums
// share numeric values.
nn::GeneralResult<OperandLifeTime> unvalidatedConvert(
        const nn::Operand::LifeTime& operandLifeTime) {
    return static_cast<OperandLifeTime>(operandLifeTime);
}
807
// Converts a canonical DataLocation. Only poolIndex needs a range check for its
// narrowing cast to int32_t; offset and length are widened to int64_t unchecked
// (presumably the canonical fields are 32-bit — confirm if they ever widen).
nn::GeneralResult<DataLocation> unvalidatedConvert(const nn::DataLocation& location) {
    VERIFY_LE_INT32_MAX(location.poolIndex)
            << "DataLocation: pool index must be <= std::numeric_limits<int32_t>::max()";
    return DataLocation{
            .poolIndex = static_cast<int32_t>(location.poolIndex),
            .offset = static_cast<int64_t>(location.offset),
            .length = static_cast<int64_t>(location.length),
    };
}
817
// Converts the canonical operand extra-params variant into the optional AIDL union:
// NoParams becomes std::nullopt, per-channel quant params are range-checked and
// converted, and extension params are passed through unchanged.
nn::GeneralResult<std::optional<OperandExtraParams>> unvalidatedConvert(
        const nn::Operand::ExtraParams& extraParams) {
    return std::visit(
            overloaded{
                    [](const nn::Operand::NoParams&)
                            -> nn::GeneralResult<std::optional<OperandExtraParams>> {
                        return std::nullopt;
                    },
                    [](const nn::Operand::SymmPerChannelQuantParams& symmPerChannelQuantParams)
                            -> nn::GeneralResult<std::optional<OperandExtraParams>> {
                        // channelDim is narrowed to int32_t below, so reject values
                        // that would overflow.
                        if (symmPerChannelQuantParams.channelDim >
                            std::numeric_limits<int32_t>::max()) {
                            // Using explicit type conversion because std::optional in successful
                            // result confuses the compiler.
                            return (NN_ERROR() << "symmPerChannelQuantParams.channelDim must be <= "
                                                  "std::numeric_limits<int32_t>::max(), received: "
                                               << symmPerChannelQuantParams.channelDim)
                                    .
                                    operator nn::GeneralResult<std::optional<OperandExtraParams>>();
                        }
                        return OperandExtraParams::make<OperandExtraParams::Tag::channelQuant>(
                                SymmPerChannelQuantParams{
                                        .scales = symmPerChannelQuantParams.scales,
                                        .channelDim = static_cast<int32_t>(
                                                symmPerChannelQuantParams.channelDim),
                                });
                    },
                    [](const nn::Operand::ExtensionParams& extensionParams)
                            -> nn::GeneralResult<std::optional<OperandExtraParams>> {
                        return OperandExtraParams::make<OperandExtraParams::Tag::extension>(
                                extensionParams);
                    },
            },
            extraParams);
}
853
// Converts a canonical Operand field-by-field. Dimensions are narrowed to int32_t,
// and the nested type, lifetime, location, and extra params use their own converters.
nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) {
    return Operand{
            .type = NN_TRY(unvalidatedConvert(operand.type)),
            .dimensions = NN_TRY(toSigned(operand.dimensions)),
            .scale = operand.scale,
            .zeroPoint = operand.zeroPoint,
            .lifetime = NN_TRY(unvalidatedConvert(operand.lifetime)),
            .location = NN_TRY(unvalidatedConvert(operand.location)),
            .extraParams = NN_TRY(unvalidatedConvert(operand.extraParams)),
    };
}
865
// Direct cast; assumes the AIDL and canonical OperationType enums share numeric values.
nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) {
    return static_cast<OperationType>(operationType);
}
869
870nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation) {
871 return Operation{
872 .type = NN_TRY(unvalidatedConvert(operation.type)),
873 .inputs = NN_TRY(toSigned(operation.inputs)),
874 .outputs = NN_TRY(toSigned(operation.outputs)),
875 };
876}
877
// Converts a canonical subgraph: its operands, operations, and the signed
// input/output operand indexes.
nn::GeneralResult<Subgraph> unvalidatedConvert(const nn::Model::Subgraph& subgraph) {
    return Subgraph{
            .operands = NN_TRY(unvalidatedConvert(subgraph.operands)),
            .operations = NN_TRY(unvalidatedConvert(subgraph.operations)),
            .inputIndexes = NN_TRY(toSigned(subgraph.inputIndexes)),
            .outputIndexes = NN_TRY(toSigned(subgraph.outputIndexes)),
    };
}
886
887nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(
888 const nn::Model::OperandValues& operandValues) {
889 return std::vector<uint8_t>(operandValues.data(), operandValues.data() + operandValues.size());
890}
891
// Converts a canonical extension name/prefix pair; both fields copy over directly.
nn::GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
        const nn::Model::ExtensionNameAndPrefix& extensionNameToPrefix) {
    return ExtensionNameAndPrefix{
            .name = extensionNameToPrefix.name,
            .prefix = extensionNameToPrefix.prefix,
    };
}
899
// Converts a canonical Model: the main and referenced subgraphs, the packed constant
// operand values, the backing memory pools, and the extension name/prefix mapping.
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
    return Model{
            .main = NN_TRY(unvalidatedConvert(model.main)),
            .referenced = NN_TRY(unvalidatedConvert(model.referenced)),
            .operandValues = NN_TRY(unvalidatedConvert(model.operandValues)),
            .pools = NN_TRY(unvalidatedConvert(model.pools)),
            .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
            .extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix)),
    };
}
910
// Direct cast; assumes the AIDL and canonical Priority enums share numeric values.
nn::GeneralResult<Priority> unvalidatedConvert(const nn::Priority& priority) {
    return static_cast<Priority>(priority);
}
914
// Converts a canonical Request. Fails if any argument uses pointer-based memory
// (see the RequestArgument converter below).
nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request) {
    return Request{
            .inputs = NN_TRY(unvalidatedConvert(request.inputs)),
            .outputs = NN_TRY(unvalidatedConvert(request.outputs)),
            .pools = NN_TRY(unvalidatedConvert(request.pools)),
    };
}
922
923nn::GeneralResult<RequestArgument> unvalidatedConvert(
924 const nn::Request::Argument& requestArgument) {
925 if (requestArgument.lifetime == nn::Request::Argument::LifeTime::POINTER) {
926 return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
927 << "Request cannot be unvalidatedConverted because it contains pointer-based memory";
928 }
929 const bool hasNoValue = requestArgument.lifetime == nn::Request::Argument::LifeTime::NO_VALUE;
930 return RequestArgument{
931 .hasNoValue = hasNoValue,
932 .location = NN_TRY(unvalidatedConvert(requestArgument.location)),
933 .dimensions = NN_TRY(toSigned(requestArgument.dimensions)),
934 };
935}
936
// Converts a canonical memory pool variant into the AIDL union: shared memory becomes
// a `pool`, a memory-domain token becomes a `token`, and IBuffer-backed pools are not
// representable in this direction and produce an error.
nn::GeneralResult<RequestMemoryPool> unvalidatedConvert(const nn::Request::MemoryPool& memoryPool) {
    return std::visit(
            overloaded{
                    [](const nn::SharedMemory& memory) -> nn::GeneralResult<RequestMemoryPool> {
                        return RequestMemoryPool::make<RequestMemoryPool::Tag::pool>(
                                NN_TRY(unvalidatedConvert(memory)));
                    },
                    [](const nn::Request::MemoryDomainToken& token)
                            -> nn::GeneralResult<RequestMemoryPool> {
                        return RequestMemoryPool::make<RequestMemoryPool::Tag::token>(
                                underlyingType(token));
                    },
                    [](const nn::SharedBuffer& /*buffer*/) {
                        // Explicit conversion-operator call: returning NN_ERROR directly
                        // would give the lambda a different deduced return type.
                        return (NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
                                << "Unable to make memory pool from IBuffer")
                                .
                                operator nn::GeneralResult<RequestMemoryPool>();
                    },
            },
            memoryPool);
}
958
// Converts canonical Timing; each optional duration maps to int64 nanoseconds via
// the duration converters below (kNoTiming when absent).
nn::GeneralResult<Timing> unvalidatedConvert(const nn::Timing& timing) {
    return Timing{
            .timeOnDevice = NN_TRY(unvalidatedConvert(timing.timeOnDevice)),
            .timeInDriver = NN_TRY(unvalidatedConvert(timing.timeInDriver)),
    };
}
965
966nn::GeneralResult<int64_t> unvalidatedConvert(const nn::Duration& duration) {
967 const uint64_t nanoseconds = duration.count();
968 if (nanoseconds > std::numeric_limits<int64_t>::max()) {
969 return std::numeric_limits<int64_t>::max();
970 }
971 return static_cast<int64_t>(nanoseconds);
972}
973
974nn::GeneralResult<int64_t> unvalidatedConvert(const nn::OptionalDuration& optionalDuration) {
975 if (!optionalDuration.has_value()) {
976 return kNoTiming;
977 }
978 return unvalidatedConvert(optionalDuration.value());
979}
980
981nn::GeneralResult<int64_t> unvalidatedConvert(const nn::OptionalTimePoint& optionalTimePoint) {
982 if (!optionalTimePoint.has_value()) {
983 return kNoTiming;
984 }
985 return unvalidatedConvert(optionalTimePoint->time_since_epoch());
986}
987
// Duplicates the sync fence's fd; the returned ScopedFileDescriptor owns the dup,
// leaving the original fence untouched.
nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvert(const nn::SyncFence& syncFence) {
    auto duplicatedFd = NN_TRY(nn::dupFd(syncFence.getFd()));
    return ndk::ScopedFileDescriptor(duplicatedFd.release());
}
992
993nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvertCache(
994 const nn::SharedHandle& handle) {
995 if (handle->ints.size() != 0) {
996 NN_ERROR() << "Cache handle must not contain ints";
997 }
998 if (handle->fds.size() != 1) {
999 NN_ERROR() << "Cache handle must contain exactly one fd but contains "
1000 << handle->fds.size();
1001 }
1002 auto duplicatedFd = NN_TRY(nn::dupFd(handle->fds.front().get()));
1003 return ndk::ScopedFileDescriptor(duplicatedFd.release());
1004}
1005
// Conversion entry point for a cache token.
// NOTE(review): this calls unvalidatedConvert directly while the sibling convert()
// overloads go through validatedConvert — confirm whether cache tokens are
// intentionally exempt from validation.
nn::GeneralResult<std::vector<uint8_t>> convert(const nn::CacheToken& cacheToken) {
    return unvalidatedConvert(cacheToken);
}
1009
// Validated conversion entry point for BufferDesc.
nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc) {
    return validatedConvert(bufferDesc);
}
1013
// Validated conversion entry point for MeasureTiming (AIDL represents it as bool).
nn::GeneralResult<bool> convert(const nn::MeasureTiming& measureTiming) {
    return validatedConvert(measureTiming);
}
1017
// Validated conversion entry point for SharedMemory.
nn::GeneralResult<Memory> convert(const nn::SharedMemory& memory) {
    return validatedConvert(memory);
}
1021
// Validated conversion entry point for ErrorStatus.
nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& errorStatus) {
    return validatedConvert(errorStatus);
}
1025
// Validated conversion entry point for ExecutionPreference.
nn::GeneralResult<ExecutionPreference> convert(const nn::ExecutionPreference& executionPreference) {
    return validatedConvert(executionPreference);
}
1029
// Validated conversion entry point for Model.
nn::GeneralResult<Model> convert(const nn::Model& model) {
    return validatedConvert(model);
}
1033
// Validated conversion entry point for Priority.
nn::GeneralResult<Priority> convert(const nn::Priority& priority) {
    return validatedConvert(priority);
}
1037
// Validated conversion entry point for Request.
nn::GeneralResult<Request> convert(const nn::Request& request) {
    return validatedConvert(request);
}
1041
// Validated conversion entry point for Timing.
nn::GeneralResult<Timing> convert(const nn::Timing& timing) {
    return validatedConvert(timing);
}
1045
// Validated conversion entry point for OptionalDuration (kNoTiming when absent).
nn::GeneralResult<int64_t> convert(const nn::OptionalDuration& optionalDuration) {
    return validatedConvert(optionalDuration);
}
1049
1050nn::GeneralResult<int64_t> convert(const nn::OptionalTimePoint& outputShapes) {
1051 return validatedConvert(outputShapes);
1052}
1053
// Validated conversion entry point for a vector of BufferRoles.
nn::GeneralResult<std::vector<BufferRole>> convert(const std::vector<nn::BufferRole>& bufferRoles) {
    return validatedConvert(bufferRoles);
}
1057
// Validated conversion entry point for a vector of OutputShapes.
nn::GeneralResult<std::vector<OutputShape>> convert(
        const std::vector<nn::OutputShape>& outputShapes) {
    return validatedConvert(outputShapes);
}
1062
// Validates the cache handles against the supported AIDL version, then converts each
// one into a duplicated file descriptor.
nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
        const std::vector<nn::SharedHandle>& cacheHandles) {
    // Validation runs once over the whole vector; per-element conversion then uses
    // the unvalidated cache helper.
    const auto version = NN_TRY(hal::utils::makeGeneralFailure(nn::validate(cacheHandles)));
    if (version > kVersion) {
        return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
    }
    std::vector<ndk::ScopedFileDescriptor> cacheFds;
    cacheFds.reserve(cacheHandles.size());
    for (const auto& cacheHandle : cacheHandles) {
        cacheFds.push_back(NN_TRY(unvalidatedConvertCache(cacheHandle)));
    }
    return cacheFds;
}
1076
// Converts sync fences to duplicated file descriptors.
// NOTE(review): unlike the cache-handle overload above this performs no nn::validate
// call — presumably sync fences have no validation rules; confirm.
nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
        const std::vector<nn::SyncFence>& syncFences) {
    return unvalidatedConvert(syncFences);
}
1081
Lev Proleev6b6dfcd2020-11-11 18:28:50 +00001082nn::GeneralResult<std::vector<int32_t>> toSigned(const std::vector<uint32_t>& vec) {
1083 if (!std::all_of(vec.begin(), vec.end(),
1084 [](uint32_t v) { return v <= std::numeric_limits<int32_t>::max(); })) {
1085 return NN_ERROR() << "Vector contains a value that doesn't fit into int32_t.";
1086 }
1087 return std::vector<int32_t>(vec.begin(), vec.end());
1088}
1089
1090} // namespace aidl::android::hardware::neuralnetworks::utils