Validate during NN conversions by default -- hal
This change renames all `convert` functions to `unvalidatedConvert`.
This change also introduces new `convert` functions that act only on the
types that appear in the NN HIDL methods directly. These new `convert`
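functions perform validation. Specifically, the canonical value (the
destination when converting from HIDL, the source when converting to
HIDL) is validated, and the conversion fails if that value is invalid or
requires a newer version than the HAL version being converted (e.g.
ANDROID_R for 1.3).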
Bug: 160667419
Test: mma
Test: NeuralNetworksTest_static
Change-Id: I492956ff60ad1466c67893993d28cdd6f3860708
Merged-In: I492956ff60ad1466c67893993d28cdd6f3860708
(cherry picked from commit 32acc0614402a35eed3407116ec359f4fdb60ecc)
diff --git a/neuralnetworks/1.3/utils/src/Conversions.cpp b/neuralnetworks/1.3/utils/src/Conversions.cpp
index 0dc0785..949dd0d 100644
--- a/neuralnetworks/1.3/utils/src/Conversions.cpp
+++ b/neuralnetworks/1.3/utils/src/Conversions.cpp
@@ -24,6 +24,7 @@
#include <nnapi/SharedMemory.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
+#include <nnapi/Validation.h>
#include <nnapi/hal/1.0/Conversions.h>
#include <nnapi/hal/1.2/Conversions.h>
#include <nnapi/hal/CommonUtils.h>
@@ -44,6 +45,8 @@
return static_cast<std::underlying_type_t<Type>>(value);
}
+constexpr auto kVersion = android::nn::Version::ANDROID_R;
+
} // namespace
namespace android::nn {
@@ -77,110 +80,140 @@
using hardware::hidl_vec;
template <typename Input>
-using ConvertOutput = std::decay_t<decltype(convert(std::declval<Input>()).value())>;
+using UnvalidatedConvertOutput =
+ std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
template <typename Type>
-GeneralResult<std::vector<ConvertOutput<Type>>> convertVec(const hidl_vec<Type>& arguments) {
- std::vector<ConvertOutput<Type>> canonical;
+GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
+ const hidl_vec<Type>& arguments) {
+ std::vector<UnvalidatedConvertOutput<Type>> canonical;
canonical.reserve(arguments.size());
for (const auto& argument : arguments) {
- canonical.push_back(NN_TRY(nn::convert(argument)));
+ canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
}
return canonical;
}
template <typename Type>
-GeneralResult<std::vector<ConvertOutput<Type>>> convert(const hidl_vec<Type>& arguments) {
- return convertVec(arguments);
+GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
+ const hidl_vec<Type>& arguments) {
+ return unvalidatedConvertVec(arguments);
+}
+
+template <typename Type>
+decltype(nn::unvalidatedConvert(std::declval<Type>())) validatedConvert(const Type& halObject) {
+ auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
+ const auto maybeVersion = validate(canonical);
+ if (!maybeVersion.has_value()) {
+ return error() << maybeVersion.error();
+ }
+ const auto version = maybeVersion.value();
+ if (version > kVersion) {
+ return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
+ }
+ return canonical;
+}
+
+template <typename Type>
+GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
+ const hidl_vec<Type>& arguments) {
+ std::vector<UnvalidatedConvertOutput<Type>> canonical;
+ canonical.reserve(arguments.size());
+ for (const auto& argument : arguments) {
+ canonical.push_back(NN_TRY(validatedConvert(argument)));
+ }
+ return canonical;
}
} // anonymous namespace
-GeneralResult<OperandType> convert(const hal::V1_3::OperandType& operandType) {
+GeneralResult<OperandType> unvalidatedConvert(const hal::V1_3::OperandType& operandType) {
return static_cast<OperandType>(operandType);
}
-GeneralResult<OperationType> convert(const hal::V1_3::OperationType& operationType) {
+GeneralResult<OperationType> unvalidatedConvert(const hal::V1_3::OperationType& operationType) {
return static_cast<OperationType>(operationType);
}
-GeneralResult<Priority> convert(const hal::V1_3::Priority& priority) {
+GeneralResult<Priority> unvalidatedConvert(const hal::V1_3::Priority& priority) {
return static_cast<Priority>(priority);
}
-GeneralResult<Capabilities> convert(const hal::V1_3::Capabilities& capabilities) {
+GeneralResult<Capabilities> unvalidatedConvert(const hal::V1_3::Capabilities& capabilities) {
const bool validOperandTypes = std::all_of(
capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
[](const hal::V1_3::Capabilities::OperandPerformance& operandPerformance) {
- const auto maybeType = convert(operandPerformance.type);
+ const auto maybeType = unvalidatedConvert(operandPerformance.type);
return !maybeType.has_value() ? false : validOperandType(maybeType.value());
});
if (!validOperandTypes) {
return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
- << "Invalid OperandType when converting OperandPerformance in Capabilities";
+ << "Invalid OperandType when converting OperandPerformance in "
+ "Capabilities";
}
- auto operandPerformance = NN_TRY(convert(capabilities.operandPerformance));
+ auto operandPerformance = NN_TRY(unvalidatedConvert(capabilities.operandPerformance));
auto table = NN_TRY(hal::utils::makeGeneralFailure(
Capabilities::OperandPerformanceTable::create(std::move(operandPerformance)),
nn::ErrorStatus::GENERAL_FAILURE));
return Capabilities{
- .relaxedFloat32toFloat16PerformanceScalar =
- NN_TRY(convert(capabilities.relaxedFloat32toFloat16PerformanceScalar)),
- .relaxedFloat32toFloat16PerformanceTensor =
- NN_TRY(convert(capabilities.relaxedFloat32toFloat16PerformanceTensor)),
+ .relaxedFloat32toFloat16PerformanceScalar = NN_TRY(
+ unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar)),
+ .relaxedFloat32toFloat16PerformanceTensor = NN_TRY(
+ unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor)),
.operandPerformance = std::move(table),
- .ifPerformance = NN_TRY(convert(capabilities.ifPerformance)),
- .whilePerformance = NN_TRY(convert(capabilities.whilePerformance)),
+ .ifPerformance = NN_TRY(unvalidatedConvert(capabilities.ifPerformance)),
+ .whilePerformance = NN_TRY(unvalidatedConvert(capabilities.whilePerformance)),
};
}
-GeneralResult<Capabilities::OperandPerformance> convert(
+GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
const hal::V1_3::Capabilities::OperandPerformance& operandPerformance) {
return Capabilities::OperandPerformance{
- .type = NN_TRY(convert(operandPerformance.type)),
- .info = NN_TRY(convert(operandPerformance.info)),
+ .type = NN_TRY(unvalidatedConvert(operandPerformance.type)),
+ .info = NN_TRY(unvalidatedConvert(operandPerformance.info)),
};
}
-GeneralResult<Operation> convert(const hal::V1_3::Operation& operation) {
+GeneralResult<Operation> unvalidatedConvert(const hal::V1_3::Operation& operation) {
return Operation{
- .type = NN_TRY(convert(operation.type)),
+ .type = NN_TRY(unvalidatedConvert(operation.type)),
.inputs = operation.inputs,
.outputs = operation.outputs,
};
}
-GeneralResult<Operand::LifeTime> convert(const hal::V1_3::OperandLifeTime& operandLifeTime) {
+GeneralResult<Operand::LifeTime> unvalidatedConvert(
+ const hal::V1_3::OperandLifeTime& operandLifeTime) {
return static_cast<Operand::LifeTime>(operandLifeTime);
}
-GeneralResult<Operand> convert(const hal::V1_3::Operand& operand) {
+GeneralResult<Operand> unvalidatedConvert(const hal::V1_3::Operand& operand) {
return Operand{
- .type = NN_TRY(convert(operand.type)),
+ .type = NN_TRY(unvalidatedConvert(operand.type)),
.dimensions = operand.dimensions,
.scale = operand.scale,
.zeroPoint = operand.zeroPoint,
- .lifetime = NN_TRY(convert(operand.lifetime)),
- .location = NN_TRY(convert(operand.location)),
- .extraParams = NN_TRY(convert(operand.extraParams)),
+ .lifetime = NN_TRY(unvalidatedConvert(operand.lifetime)),
+ .location = NN_TRY(unvalidatedConvert(operand.location)),
+ .extraParams = NN_TRY(unvalidatedConvert(operand.extraParams)),
};
}
-GeneralResult<Model> convert(const hal::V1_3::Model& model) {
+GeneralResult<Model> unvalidatedConvert(const hal::V1_3::Model& model) {
return Model{
- .main = NN_TRY(convert(model.main)),
- .referenced = NN_TRY(convert(model.referenced)),
- .operandValues = NN_TRY(convert(model.operandValues)),
- .pools = NN_TRY(convert(model.pools)),
+ .main = NN_TRY(unvalidatedConvert(model.main)),
+ .referenced = NN_TRY(unvalidatedConvert(model.referenced)),
+ .operandValues = NN_TRY(unvalidatedConvert(model.operandValues)),
+ .pools = NN_TRY(unvalidatedConvert(model.pools)),
.relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
- .extensionNameToPrefix = NN_TRY(convert(model.extensionNameToPrefix)),
+ .extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix)),
};
}
-GeneralResult<Model::Subgraph> convert(const hal::V1_3::Subgraph& subgraph) {
- auto operations = NN_TRY(convert(subgraph.operations));
+GeneralResult<Model::Subgraph> unvalidatedConvert(const hal::V1_3::Subgraph& subgraph) {
+ auto operations = NN_TRY(unvalidatedConvert(subgraph.operations));
// Verify number of consumers.
const auto numberOfConsumers =
@@ -196,18 +229,18 @@
}
return Model::Subgraph{
- .operands = NN_TRY(convert(subgraph.operands)),
+ .operands = NN_TRY(unvalidatedConvert(subgraph.operands)),
.operations = std::move(operations),
.inputIndexes = subgraph.inputIndexes,
.outputIndexes = subgraph.outputIndexes,
};
}
-GeneralResult<BufferDesc> convert(const hal::V1_3::BufferDesc& bufferDesc) {
+GeneralResult<BufferDesc> unvalidatedConvert(const hal::V1_3::BufferDesc& bufferDesc) {
return BufferDesc{.dimensions = bufferDesc.dimensions};
}
-GeneralResult<BufferRole> convert(const hal::V1_3::BufferRole& bufferRole) {
+GeneralResult<BufferRole> unvalidatedConvert(const hal::V1_3::BufferRole& bufferRole) {
return BufferRole{
.modelIndex = bufferRole.modelIndex,
.ioIndex = bufferRole.ioIndex,
@@ -215,15 +248,16 @@
};
}
-GeneralResult<Request> convert(const hal::V1_3::Request& request) {
+GeneralResult<Request> unvalidatedConvert(const hal::V1_3::Request& request) {
return Request{
- .inputs = NN_TRY(convert(request.inputs)),
- .outputs = NN_TRY(convert(request.outputs)),
- .pools = NN_TRY(convert(request.pools)),
+ .inputs = NN_TRY(unvalidatedConvert(request.inputs)),
+ .outputs = NN_TRY(unvalidatedConvert(request.outputs)),
+ .pools = NN_TRY(unvalidatedConvert(request.pools)),
};
}
-GeneralResult<Request::MemoryPool> convert(const hal::V1_3::Request::MemoryPool& memoryPool) {
+GeneralResult<Request::MemoryPool> unvalidatedConvert(
+ const hal::V1_3::Request::MemoryPool& memoryPool) {
using Discriminator = hal::V1_3::Request::MemoryPool::hidl_discriminator;
switch (memoryPool.getDiscriminator()) {
case Discriminator::hidlMemory:
@@ -236,12 +270,14 @@
<< underlyingType(memoryPool.getDiscriminator());
}
-GeneralResult<OptionalTimePoint> convert(const hal::V1_3::OptionalTimePoint& optionalTimePoint) {
+GeneralResult<OptionalTimePoint> unvalidatedConvert(
+ const hal::V1_3::OptionalTimePoint& optionalTimePoint) {
constexpr auto kTimePointMaxCount = TimePoint::max().time_since_epoch().count();
const auto makeTimePoint = [](uint64_t count) -> GeneralResult<OptionalTimePoint> {
if (count > kTimePointMaxCount) {
return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
- << "Unable to convert OptionalTimePoint because the count exceeds the max";
+ << "Unable to convert OptionalTimePoint because the count exceeds "
+ "the max";
}
const auto nanoseconds = std::chrono::nanoseconds{count};
return TimePoint{nanoseconds};
@@ -259,13 +295,14 @@
<< underlyingType(optionalTimePoint.getDiscriminator());
}
-GeneralResult<OptionalTimeoutDuration> convert(
+GeneralResult<OptionalTimeoutDuration> unvalidatedConvert(
const hal::V1_3::OptionalTimeoutDuration& optionalTimeoutDuration) {
constexpr auto kTimeoutDurationMaxCount = TimeoutDuration::max().count();
const auto makeTimeoutDuration = [](uint64_t count) -> GeneralResult<OptionalTimeoutDuration> {
if (count > kTimeoutDurationMaxCount) {
return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
- << "Unable to convert OptionalTimeoutDuration because the count exceeds the max";
+ << "Unable to convert OptionalTimeoutDuration because the count "
+ "exceeds the max";
}
return TimeoutDuration{count};
};
@@ -282,7 +319,7 @@
<< underlyingType(optionalTimeoutDuration.getDiscriminator());
}
-GeneralResult<ErrorStatus> convert(const hal::V1_3::ErrorStatus& status) {
+GeneralResult<ErrorStatus> unvalidatedConvert(const hal::V1_3::ErrorStatus& status) {
switch (status) {
case hal::V1_3::ErrorStatus::NONE:
case hal::V1_3::ErrorStatus::DEVICE_UNAVAILABLE:
@@ -299,9 +336,50 @@
<< "Invalid ErrorStatus " << underlyingType(status);
}
+GeneralResult<Priority> convert(const hal::V1_3::Priority& priority) {
+ return validatedConvert(priority);
+}
+
+GeneralResult<Capabilities> convert(const hal::V1_3::Capabilities& capabilities) {
+ return validatedConvert(capabilities);
+}
+
+GeneralResult<Model> convert(const hal::V1_3::Model& model) {
+ return validatedConvert(model);
+}
+
+GeneralResult<BufferDesc> convert(const hal::V1_3::BufferDesc& bufferDesc) {
+ return validatedConvert(bufferDesc);
+}
+
+GeneralResult<Request> convert(const hal::V1_3::Request& request) {
+ return validatedConvert(request);
+}
+
+GeneralResult<OptionalTimePoint> convert(const hal::V1_3::OptionalTimePoint& optionalTimePoint) {
+ return validatedConvert(optionalTimePoint);
+}
+
+GeneralResult<OptionalTimeoutDuration> convert(
+ const hal::V1_3::OptionalTimeoutDuration& optionalTimeoutDuration) {
+ return validatedConvert(optionalTimeoutDuration);
+}
+
+GeneralResult<ErrorStatus> convert(const hal::V1_3::ErrorStatus& errorStatus) {
+ return validatedConvert(errorStatus);
+}
+
+GeneralResult<SharedHandle> convert(const hardware::hidl_handle& handle) {
+ return validatedConvert(handle);
+}
+
+GeneralResult<Memory> convert(const hardware::hidl_memory& memory) {
+ return validatedConvert(memory);
+}
+
GeneralResult<std::vector<BufferRole>> convert(
const hardware::hidl_vec<hal::V1_3::BufferRole>& bufferRoles) {
- return convertVec(bufferRoles);
+ return validatedConvert(bufferRoles);
}
} // namespace android::nn
@@ -309,58 +387,67 @@
namespace android::hardware::neuralnetworks::V1_3::utils {
namespace {
-using utils::convert;
+using utils::unvalidatedConvert;
-nn::GeneralResult<V1_0::PerformanceInfo> convert(
+nn::GeneralResult<V1_0::PerformanceInfo> unvalidatedConvert(
const nn::Capabilities::PerformanceInfo& performanceInfo) {
- return V1_0::utils::convert(performanceInfo);
+ return V1_0::utils::unvalidatedConvert(performanceInfo);
}
-nn::GeneralResult<V1_0::DataLocation> convert(const nn::DataLocation& dataLocation) {
- return V1_0::utils::convert(dataLocation);
+nn::GeneralResult<V1_0::DataLocation> unvalidatedConvert(const nn::DataLocation& dataLocation) {
+ return V1_0::utils::unvalidatedConvert(dataLocation);
}
-nn::GeneralResult<hidl_vec<uint8_t>> convert(const nn::Model::OperandValues& operandValues) {
- return V1_0::utils::convert(operandValues);
+nn::GeneralResult<hidl_vec<uint8_t>> unvalidatedConvert(
+ const nn::Model::OperandValues& operandValues) {
+ return V1_0::utils::unvalidatedConvert(operandValues);
}
-nn::GeneralResult<hidl_memory> convert(const nn::Memory& memory) {
- return V1_0::utils::convert(memory);
+nn::GeneralResult<hidl_handle> unvalidatedConvert(const nn::SharedHandle& handle) {
+ return V1_2::utils::unvalidatedConvert(handle);
}
-nn::GeneralResult<V1_0::RequestArgument> convert(const nn::Request::Argument& argument) {
- return V1_0::utils::convert(argument);
+nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::Memory& memory) {
+ return V1_0::utils::unvalidatedConvert(memory);
}
-nn::GeneralResult<V1_2::Operand::ExtraParams> convert(const nn::Operand::ExtraParams& extraParams) {
- return V1_2::utils::convert(extraParams);
+nn::GeneralResult<V1_0::RequestArgument> unvalidatedConvert(const nn::Request::Argument& argument) {
+ return V1_0::utils::unvalidatedConvert(argument);
}
-nn::GeneralResult<V1_2::Model::ExtensionNameAndPrefix> convert(
+nn::GeneralResult<V1_2::Operand::ExtraParams> unvalidatedConvert(
+ const nn::Operand::ExtraParams& extraParams) {
+ return V1_2::utils::unvalidatedConvert(extraParams);
+}
+
+nn::GeneralResult<V1_2::Model::ExtensionNameAndPrefix> unvalidatedConvert(
const nn::Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
- return V1_2::utils::convert(extensionNameAndPrefix);
+ return V1_2::utils::unvalidatedConvert(extensionNameAndPrefix);
}
template <typename Input>
-using ConvertOutput = std::decay_t<decltype(convert(std::declval<Input>()).value())>;
+using UnvalidatedConvertOutput =
+ std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
template <typename Type>
-nn::GeneralResult<hidl_vec<ConvertOutput<Type>>> convertVec(const std::vector<Type>& arguments) {
- hidl_vec<ConvertOutput<Type>> halObject(arguments.size());
+nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
+ const std::vector<Type>& arguments) {
+ hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
- halObject[i] = NN_TRY(convert(arguments[i]));
+ halObject[i] = NN_TRY(unvalidatedConvert(arguments[i]));
}
return halObject;
}
template <typename Type>
-nn::GeneralResult<hidl_vec<ConvertOutput<Type>>> convert(const std::vector<Type>& arguments) {
- return convertVec(arguments);
+nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
+ const std::vector<Type>& arguments) {
+ return unvalidatedConvertVec(arguments);
}
nn::GeneralResult<Request::MemoryPool> makeMemoryPool(const nn::Memory& memory) {
Request::MemoryPool ret;
- ret.hidlMemory(NN_TRY(convert(memory)));
+ ret.hidlMemory(NN_TRY(unvalidatedConvert(memory)));
return ret;
}
@@ -374,21 +461,46 @@
return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE) << "Unable to make memory pool from IBuffer";
}
+using utils::unvalidatedConvert;
+
+template <typename Type>
+decltype(unvalidatedConvert(std::declval<Type>())) validatedConvert(const Type& canonical) {
+ const auto maybeVersion = nn::validate(canonical);
+ if (!maybeVersion.has_value()) {
+ return nn::error() << maybeVersion.error();
+ }
+ const auto version = maybeVersion.value();
+ if (version > kVersion) {
+ return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
+ }
+ return unvalidatedConvert(canonical);
+}
+
+template <typename Type>
+nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> validatedConvert(
+ const std::vector<Type>& arguments) {
+ hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
+ for (size_t i = 0; i < arguments.size(); ++i) {
+ halObject[i] = NN_TRY(validatedConvert(arguments[i]));
+ }
+ return halObject;
+}
+
} // anonymous namespace
-nn::GeneralResult<OperandType> convert(const nn::OperandType& operandType) {
+nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) {
return static_cast<OperandType>(operandType);
}
-nn::GeneralResult<OperationType> convert(const nn::OperationType& operationType) {
+nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) {
return static_cast<OperationType>(operationType);
}
-nn::GeneralResult<Priority> convert(const nn::Priority& priority) {
+nn::GeneralResult<Priority> unvalidatedConvert(const nn::Priority& priority) {
return static_cast<Priority>(priority);
}
-nn::GeneralResult<Capabilities> convert(const nn::Capabilities& capabilities) {
+nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capabilities) {
std::vector<nn::Capabilities::OperandPerformance> operandPerformance;
operandPerformance.reserve(capabilities.operandPerformance.asVector().size());
std::copy_if(capabilities.operandPerformance.asVector().begin(),
@@ -399,71 +511,72 @@
});
return Capabilities{
- .relaxedFloat32toFloat16PerformanceScalar =
- NN_TRY(convert(capabilities.relaxedFloat32toFloat16PerformanceScalar)),
- .relaxedFloat32toFloat16PerformanceTensor =
- NN_TRY(convert(capabilities.relaxedFloat32toFloat16PerformanceTensor)),
- .operandPerformance = NN_TRY(convert(operandPerformance)),
- .ifPerformance = NN_TRY(convert(capabilities.ifPerformance)),
- .whilePerformance = NN_TRY(convert(capabilities.whilePerformance)),
+ .relaxedFloat32toFloat16PerformanceScalar = NN_TRY(
+ unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceScalar)),
+ .relaxedFloat32toFloat16PerformanceTensor = NN_TRY(
+ unvalidatedConvert(capabilities.relaxedFloat32toFloat16PerformanceTensor)),
+ .operandPerformance = NN_TRY(unvalidatedConvert(operandPerformance)),
+ .ifPerformance = NN_TRY(unvalidatedConvert(capabilities.ifPerformance)),
+ .whilePerformance = NN_TRY(unvalidatedConvert(capabilities.whilePerformance)),
};
}
-nn::GeneralResult<Capabilities::OperandPerformance> convert(
+nn::GeneralResult<Capabilities::OperandPerformance> unvalidatedConvert(
const nn::Capabilities::OperandPerformance& operandPerformance) {
return Capabilities::OperandPerformance{
- .type = NN_TRY(convert(operandPerformance.type)),
- .info = NN_TRY(convert(operandPerformance.info)),
+ .type = NN_TRY(unvalidatedConvert(operandPerformance.type)),
+ .info = NN_TRY(unvalidatedConvert(operandPerformance.info)),
};
}
-nn::GeneralResult<Operation> convert(const nn::Operation& operation) {
+nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation) {
return Operation{
- .type = NN_TRY(convert(operation.type)),
+ .type = NN_TRY(unvalidatedConvert(operation.type)),
.inputs = operation.inputs,
.outputs = operation.outputs,
};
}
-nn::GeneralResult<OperandLifeTime> convert(const nn::Operand::LifeTime& operandLifeTime) {
+nn::GeneralResult<OperandLifeTime> unvalidatedConvert(
+ const nn::Operand::LifeTime& operandLifeTime) {
if (operandLifeTime == nn::Operand::LifeTime::POINTER) {
return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
<< "Model cannot be converted because it contains pointer-based memory";
}
return static_cast<OperandLifeTime>(operandLifeTime);
}
-nn::GeneralResult<Operand> convert(const nn::Operand& operand) {
+nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) {
return Operand{
- .type = NN_TRY(convert(operand.type)),
+ .type = NN_TRY(unvalidatedConvert(operand.type)),
.dimensions = operand.dimensions,
.numberOfConsumers = 0,
.scale = operand.scale,
.zeroPoint = operand.zeroPoint,
- .lifetime = NN_TRY(convert(operand.lifetime)),
- .location = NN_TRY(convert(operand.location)),
- .extraParams = NN_TRY(convert(operand.extraParams)),
+ .lifetime = NN_TRY(unvalidatedConvert(operand.lifetime)),
+ .location = NN_TRY(unvalidatedConvert(operand.location)),
+ .extraParams = NN_TRY(unvalidatedConvert(operand.extraParams)),
};
}
-nn::GeneralResult<Model> convert(const nn::Model& model) {
+nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
if (!hal::utils::hasNoPointerData(model)) {
return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
<< "Model cannot be converted because it contains pointer-based memory";
}
return Model{
- .main = NN_TRY(convert(model.main)),
- .referenced = NN_TRY(convert(model.referenced)),
- .operandValues = NN_TRY(convert(model.operandValues)),
- .pools = NN_TRY(convert(model.pools)),
+ .main = NN_TRY(unvalidatedConvert(model.main)),
+ .referenced = NN_TRY(unvalidatedConvert(model.referenced)),
+ .operandValues = NN_TRY(unvalidatedConvert(model.operandValues)),
+ .pools = NN_TRY(unvalidatedConvert(model.pools)),
.relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
- .extensionNameToPrefix = NN_TRY(convert(model.extensionNameToPrefix)),
+ .extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix)),
};
}
-nn::GeneralResult<Subgraph> convert(const nn::Model::Subgraph& subgraph) {
- auto operands = NN_TRY(convert(subgraph.operands));
+nn::GeneralResult<Subgraph> unvalidatedConvert(const nn::Model::Subgraph& subgraph) {
+ auto operands = NN_TRY(unvalidatedConvert(subgraph.operands));
// Update number of consumers.
const auto numberOfConsumers =
@@ -475,17 +588,17 @@
return Subgraph{
.operands = std::move(operands),
- .operations = NN_TRY(convert(subgraph.operations)),
+ .operations = NN_TRY(unvalidatedConvert(subgraph.operations)),
.inputIndexes = subgraph.inputIndexes,
.outputIndexes = subgraph.outputIndexes,
};
}
-nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc) {
+nn::GeneralResult<BufferDesc> unvalidatedConvert(const nn::BufferDesc& bufferDesc) {
return BufferDesc{.dimensions = bufferDesc.dimensions};
}
-nn::GeneralResult<BufferRole> convert(const nn::BufferRole& bufferRole) {
+nn::GeneralResult<BufferRole> unvalidatedConvert(const nn::BufferRole& bufferRole) {
return BufferRole{
.modelIndex = bufferRole.modelIndex,
.ioIndex = bufferRole.ioIndex,
@@ -493,30 +606,33 @@
};
}
-nn::GeneralResult<Request> convert(const nn::Request& request) {
+nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request) {
if (!hal::utils::hasNoPointerData(request)) {
return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
<< "Request cannot be converted because it contains pointer-based memory";
}
return Request{
- .inputs = NN_TRY(convert(request.inputs)),
- .outputs = NN_TRY(convert(request.outputs)),
- .pools = NN_TRY(convert(request.pools)),
+ .inputs = NN_TRY(unvalidatedConvert(request.inputs)),
+ .outputs = NN_TRY(unvalidatedConvert(request.outputs)),
+ .pools = NN_TRY(unvalidatedConvert(request.pools)),
};
}
-nn::GeneralResult<Request::MemoryPool> convert(const nn::Request::MemoryPool& memoryPool) {
+nn::GeneralResult<Request::MemoryPool> unvalidatedConvert(
+ const nn::Request::MemoryPool& memoryPool) {
return std::visit([](const auto& o) { return makeMemoryPool(o); }, memoryPool);
}
-nn::GeneralResult<OptionalTimePoint> convert(const nn::OptionalTimePoint& optionalTimePoint) {
+nn::GeneralResult<OptionalTimePoint> unvalidatedConvert(
+ const nn::OptionalTimePoint& optionalTimePoint) {
OptionalTimePoint ret;
if (optionalTimePoint.has_value()) {
const auto count = optionalTimePoint.value().time_since_epoch().count();
if (count < 0) {
return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
- << "Unable to convert OptionalTimePoint because time since epoch count is "
+ << "Unable to convert OptionalTimePoint because time since epoch "
+ "count is "
"negative";
}
ret.nanosecondsSinceEpoch(count);
@@ -524,21 +640,22 @@
return ret;
}
-nn::GeneralResult<OptionalTimeoutDuration> convert(
+nn::GeneralResult<OptionalTimeoutDuration> unvalidatedConvert(
const nn::OptionalTimeoutDuration& optionalTimeoutDuration) {
OptionalTimeoutDuration ret;
if (optionalTimeoutDuration.has_value()) {
const auto count = optionalTimeoutDuration.value().count();
if (count < 0) {
return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
- << "Unable to convert OptionalTimeoutDuration because count is negative";
+ << "Unable to convert OptionalTimeoutDuration because count is "
+ "negative";
}
ret.nanoseconds(count);
}
return ret;
}
-nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& errorStatus) {
+nn::GeneralResult<ErrorStatus> unvalidatedConvert(const nn::ErrorStatus& errorStatus) {
switch (errorStatus) {
case nn::ErrorStatus::NONE:
case nn::ErrorStatus::DEVICE_UNAVAILABLE:
@@ -555,8 +672,49 @@
}
}
+nn::GeneralResult<Priority> convert(const nn::Priority& priority) {
+ return validatedConvert(priority);
+}
+
+nn::GeneralResult<Capabilities> convert(const nn::Capabilities& capabilities) {
+ return validatedConvert(capabilities);
+}
+
+nn::GeneralResult<Model> convert(const nn::Model& model) {
+ return validatedConvert(model);
+}
+
+nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc) {
+ return validatedConvert(bufferDesc);
+}
+
+nn::GeneralResult<Request> convert(const nn::Request& request) {
+ return validatedConvert(request);
+}
+
+nn::GeneralResult<OptionalTimePoint> convert(const nn::OptionalTimePoint& optionalTimePoint) {
+ return validatedConvert(optionalTimePoint);
+}
+
+nn::GeneralResult<OptionalTimeoutDuration> convert(
+ const nn::OptionalTimeoutDuration& optionalTimeoutDuration) {
+ return validatedConvert(optionalTimeoutDuration);
+}
+
+nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& errorStatus) {
+ return validatedConvert(errorStatus);
+}
+
+nn::GeneralResult<hidl_handle> convert(const nn::SharedHandle& handle) {
+ return validatedConvert(handle);
+}
+
+nn::GeneralResult<hidl_memory> convert(const nn::Memory& memory) {
+ return validatedConvert(memory);
+}
+
nn::GeneralResult<hidl_vec<BufferRole>> convert(const std::vector<nn::BufferRole>& bufferRoles) {
- return convertVec(bufferRoles);
+ return validatedConvert(bufferRoles);
}
} // namespace android::hardware::neuralnetworks::V1_3::utils