Validate during NN conversions by default -- hal
This change renames all `convert` functions to `unvalidatedConvert`.
This change also introduces new `convert` functions that act only on the
types that appear in the NN HIDL methods directly. These new `convert`
functions perform validation. Specifically, if either the source or
destination value is invalid, then the conversion fails.
Bug: 160667419
Test: mma
Test: NeuralNetworksTest_static
Change-Id: I492956ff60ad1466c67893993d28cdd6f3860708
Merged-In: I492956ff60ad1466c67893993d28cdd6f3860708
(cherry picked from commit 32acc0614402a35eed3407116ec359f4fdb60ecc)
diff --git a/neuralnetworks/1.0/utils/include/nnapi/hal/1.0/Conversions.h b/neuralnetworks/1.0/utils/include/nnapi/hal/1.0/Conversions.h
index fb77cb2..d3d933b 100644
--- a/neuralnetworks/1.0/utils/include/nnapi/hal/1.0/Conversions.h
+++ b/neuralnetworks/1.0/utils/include/nnapi/hal/1.0/Conversions.h
@@ -24,20 +24,28 @@
namespace android::nn {
-GeneralResult<OperandType> convert(const hal::V1_0::OperandType& operandType);
-GeneralResult<OperationType> convert(const hal::V1_0::OperationType& operationType);
-GeneralResult<Operand::LifeTime> convert(const hal::V1_0::OperandLifeTime& lifetime);
-GeneralResult<DeviceStatus> convert(const hal::V1_0::DeviceStatus& deviceStatus);
-GeneralResult<Capabilities::PerformanceInfo> convert(
+GeneralResult<OperandType> unvalidatedConvert(const hal::V1_0::OperandType& operandType);
+GeneralResult<OperationType> unvalidatedConvert(const hal::V1_0::OperationType& operationType);
+GeneralResult<Operand::LifeTime> unvalidatedConvert(const hal::V1_0::OperandLifeTime& lifetime);
+GeneralResult<DeviceStatus> unvalidatedConvert(const hal::V1_0::DeviceStatus& deviceStatus);
+GeneralResult<Capabilities::PerformanceInfo> unvalidatedConvert(
const hal::V1_0::PerformanceInfo& performanceInfo);
+GeneralResult<Capabilities> unvalidatedConvert(const hal::V1_0::Capabilities& capabilities);
+GeneralResult<DataLocation> unvalidatedConvert(const hal::V1_0::DataLocation& location);
+GeneralResult<Operand> unvalidatedConvert(const hal::V1_0::Operand& operand);
+GeneralResult<Operation> unvalidatedConvert(const hal::V1_0::Operation& operation);
+GeneralResult<Model::OperandValues> unvalidatedConvert(
+ const hardware::hidl_vec<uint8_t>& operandValues);
+GeneralResult<Memory> unvalidatedConvert(const hardware::hidl_memory& memory);
+GeneralResult<Model> unvalidatedConvert(const hal::V1_0::Model& model);
+GeneralResult<Request::Argument> unvalidatedConvert(
+ const hal::V1_0::RequestArgument& requestArgument);
+GeneralResult<Request> unvalidatedConvert(const hal::V1_0::Request& request);
+GeneralResult<ErrorStatus> unvalidatedConvert(const hal::V1_0::ErrorStatus& status);
+
+GeneralResult<DeviceStatus> convert(const hal::V1_0::DeviceStatus& deviceStatus);
GeneralResult<Capabilities> convert(const hal::V1_0::Capabilities& capabilities);
-GeneralResult<DataLocation> convert(const hal::V1_0::DataLocation& location);
-GeneralResult<Operand> convert(const hal::V1_0::Operand& operand);
-GeneralResult<Operation> convert(const hal::V1_0::Operation& operation);
-GeneralResult<Model::OperandValues> convert(const hardware::hidl_vec<uint8_t>& operandValues);
-GeneralResult<Memory> convert(const hardware::hidl_memory& memory);
GeneralResult<Model> convert(const hal::V1_0::Model& model);
-GeneralResult<Request::Argument> convert(const hal::V1_0::RequestArgument& requestArgument);
GeneralResult<Request> convert(const hal::V1_0::Request& request);
GeneralResult<ErrorStatus> convert(const hal::V1_0::ErrorStatus& status);
@@ -45,21 +53,28 @@
namespace android::hardware::neuralnetworks::V1_0::utils {
-nn::GeneralResult<OperandType> convert(const nn::OperandType& operandType);
-nn::GeneralResult<OperationType> convert(const nn::OperationType& operationType);
-nn::GeneralResult<OperandLifeTime> convert(const nn::Operand::LifeTime& lifetime);
-nn::GeneralResult<DeviceStatus> convert(const nn::DeviceStatus& deviceStatus);
-nn::GeneralResult<PerformanceInfo> convert(
+nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType);
+nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType);
+nn::GeneralResult<OperandLifeTime> unvalidatedConvert(const nn::Operand::LifeTime& lifetime);
+nn::GeneralResult<DeviceStatus> unvalidatedConvert(const nn::DeviceStatus& deviceStatus);
+nn::GeneralResult<PerformanceInfo> unvalidatedConvert(
const nn::Capabilities::PerformanceInfo& performanceInfo);
+nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capabilities);
+nn::GeneralResult<DataLocation> unvalidatedConvert(const nn::DataLocation& location);
+nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand);
+nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation);
+nn::GeneralResult<hidl_vec<uint8_t>> unvalidatedConvert(
+ const nn::Model::OperandValues& operandValues);
+nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::Memory& memory);
+nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model);
+nn::GeneralResult<RequestArgument> unvalidatedConvert(const nn::Request::Argument& requestArgument);
+nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::Request::MemoryPool& memoryPool);
+nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request);
+nn::GeneralResult<ErrorStatus> unvalidatedConvert(const nn::ErrorStatus& status);
+
+nn::GeneralResult<DeviceStatus> convert(const nn::DeviceStatus& deviceStatus);
nn::GeneralResult<Capabilities> convert(const nn::Capabilities& capabilities);
-nn::GeneralResult<DataLocation> convert(const nn::DataLocation& location);
-nn::GeneralResult<Operand> convert(const nn::Operand& operand);
-nn::GeneralResult<Operation> convert(const nn::Operation& operation);
-nn::GeneralResult<hidl_vec<uint8_t>> convert(const nn::Model::OperandValues& operandValues);
-nn::GeneralResult<hidl_memory> convert(const nn::Memory& memory);
nn::GeneralResult<Model> convert(const nn::Model& model);
-nn::GeneralResult<RequestArgument> convert(const nn::Request::Argument& requestArgument);
-nn::GeneralResult<hidl_memory> convert(const nn::Request::MemoryPool& memoryPool);
nn::GeneralResult<Request> convert(const nn::Request& request);
nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& status);
diff --git a/neuralnetworks/1.0/utils/include/nnapi/hal/1.0/Utils.h b/neuralnetworks/1.0/utils/include/nnapi/hal/1.0/Utils.h
index baa2b95..4cec545 100644
--- a/neuralnetworks/1.0/utils/include/nnapi/hal/1.0/Utils.h
+++ b/neuralnetworks/1.0/utils/include/nnapi/hal/1.0/Utils.h
@@ -22,25 +22,16 @@
#include <android-base/logging.h>
#include <android/hardware/neuralnetworks/1.0/types.h>
#include <nnapi/Result.h>
-#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
-#include <nnapi/Validation.h>
namespace android::hardware::neuralnetworks::V1_0::utils {
-constexpr auto kVersion = nn::Version::ANDROID_OC_MR1;
-
template <typename Type>
nn::Result<void> validate(const Type& halObject) {
const auto maybeCanonical = nn::convert(halObject);
if (!maybeCanonical.has_value()) {
return nn::error() << maybeCanonical.error().message;
}
- const auto version = NN_TRY(nn::validate(maybeCanonical.value()));
- if (version > utils::kVersion) {
- return NN_ERROR() << "Insufficient version: " << version << " vs required "
- << utils::kVersion;
- }
return {};
}
@@ -53,21 +44,6 @@
return result.has_value();
}
-template <typename Type>
-decltype(nn::convert(std::declval<Type>())) validatedConvertToCanonical(const Type& halObject) {
- auto canonical = NN_TRY(nn::convert(halObject));
- const auto maybeVersion = nn::validate(canonical);
- if (!maybeVersion.has_value()) {
- return nn::error() << maybeVersion.error();
- }
- const auto version = maybeVersion.value();
- if (version > utils::kVersion) {
- return NN_ERROR() << "Insufficient version: " << version << " vs required "
- << utils::kVersion;
- }
- return canonical;
-}
-
} // namespace android::hardware::neuralnetworks::V1_0::utils
#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_0_UTILS_H
diff --git a/neuralnetworks/1.0/utils/src/Callbacks.cpp b/neuralnetworks/1.0/utils/src/Callbacks.cpp
index f286bcc..b1259c3 100644
--- a/neuralnetworks/1.0/utils/src/Callbacks.cpp
+++ b/neuralnetworks/1.0/utils/src/Callbacks.cpp
@@ -45,8 +45,7 @@
Return<void> PreparedModelCallback::notify(ErrorStatus status,
const sp<IPreparedModel>& preparedModel) {
if (status != ErrorStatus::NONE) {
- const auto canonical =
- validatedConvertToCanonical(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
+ const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
notifyInternal(NN_ERROR(canonical) << "preparedModel failed with " << toString(status));
} else if (preparedModel == nullptr) {
notifyInternal(NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
@@ -73,8 +72,7 @@
Return<void> ExecutionCallback::notify(ErrorStatus status) {
if (status != ErrorStatus::NONE) {
- const auto canonical =
- validatedConvertToCanonical(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
+ const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
notifyInternal(NN_ERROR(canonical) << "execute failed with " << toString(status));
} else {
notifyInternal({});
diff --git a/neuralnetworks/1.0/utils/src/Conversions.cpp b/neuralnetworks/1.0/utils/src/Conversions.cpp
index 6cf9073..fde7346 100644
--- a/neuralnetworks/1.0/utils/src/Conversions.cpp
+++ b/neuralnetworks/1.0/utils/src/Conversions.cpp
@@ -22,7 +22,9 @@
#include <nnapi/OperationTypes.h>
#include <nnapi/Result.h>
#include <nnapi/SharedMemory.h>
+#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
+#include <nnapi/Validation.h>
#include <nnapi/hal/CommonUtils.h>
#include <algorithm>
@@ -40,6 +42,8 @@
return static_cast<std::underlying_type_t<Type>>(value);
}
+constexpr auto kVersion = android::nn::Version::ANDROID_OC_MR1;
+
} // namespace
namespace android::nn {
@@ -49,37 +53,53 @@
using hardware::hidl_vec;
template <typename Input>
-using ConvertOutput = std::decay_t<decltype(convert(std::declval<Input>()).value())>;
+using unvalidatedConvertOutput =
+ std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
template <typename Type>
-GeneralResult<std::vector<ConvertOutput<Type>>> convert(const hidl_vec<Type>& arguments) {
- std::vector<ConvertOutput<Type>> canonical;
+GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
+ const hidl_vec<Type>& arguments) {
+ std::vector<unvalidatedConvertOutput<Type>> canonical;
canonical.reserve(arguments.size());
for (const auto& argument : arguments) {
- canonical.push_back(NN_TRY(nn::convert(argument)));
+ canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
+ }
+ return canonical;
+}
+
+template <typename Type>
+decltype(nn::unvalidatedConvert(std::declval<Type>())) validatedConvert(const Type& halObject) {
+ auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
+ const auto maybeVersion = validate(canonical);
+ if (!maybeVersion.has_value()) {
+ return error() << maybeVersion.error();
+ }
+ const auto version = maybeVersion.value();
+ if (version > kVersion) {
+ return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
}
return canonical;
}
} // anonymous namespace
-GeneralResult<OperandType> convert(const hal::V1_0::OperandType& operandType) {
+GeneralResult<OperandType> unvalidatedConvert(const hal::V1_0::OperandType& operandType) {
return static_cast<OperandType>(operandType);
}
-GeneralResult<OperationType> convert(const hal::V1_0::OperationType& operationType) {
+GeneralResult<OperationType> unvalidatedConvert(const hal::V1_0::OperationType& operationType) {
return static_cast<OperationType>(operationType);
}
-GeneralResult<Operand::LifeTime> convert(const hal::V1_0::OperandLifeTime& lifetime) {
+GeneralResult<Operand::LifeTime> unvalidatedConvert(const hal::V1_0::OperandLifeTime& lifetime) {
return static_cast<Operand::LifeTime>(lifetime);
}
-GeneralResult<DeviceStatus> convert(const hal::V1_0::DeviceStatus& deviceStatus) {
+GeneralResult<DeviceStatus> unvalidatedConvert(const hal::V1_0::DeviceStatus& deviceStatus) {
return static_cast<DeviceStatus>(deviceStatus);
}
-GeneralResult<Capabilities::PerformanceInfo> convert(
+GeneralResult<Capabilities::PerformanceInfo> unvalidatedConvert(
const hal::V1_0::PerformanceInfo& performanceInfo) {
return Capabilities::PerformanceInfo{
.execTime = performanceInfo.execTime,
@@ -87,9 +107,10 @@
};
}
-GeneralResult<Capabilities> convert(const hal::V1_0::Capabilities& capabilities) {
- const auto quantized8Performance = NN_TRY(convert(capabilities.quantized8Performance));
- const auto float32Performance = NN_TRY(convert(capabilities.float32Performance));
+GeneralResult<Capabilities> unvalidatedConvert(const hal::V1_0::Capabilities& capabilities) {
+ const auto quantized8Performance =
+ NN_TRY(unvalidatedConvert(capabilities.quantized8Performance));
+ const auto float32Performance = NN_TRY(unvalidatedConvert(capabilities.float32Performance));
auto table = hal::utils::makeQuantized8PerformanceConsistentWithP(float32Performance,
quantized8Performance);
@@ -101,7 +122,7 @@
};
}
-GeneralResult<DataLocation> convert(const hal::V1_0::DataLocation& location) {
+GeneralResult<DataLocation> unvalidatedConvert(const hal::V1_0::DataLocation& location) {
return DataLocation{
.poolIndex = location.poolIndex,
.offset = location.offset,
@@ -109,35 +130,35 @@
};
}
-GeneralResult<Operand> convert(const hal::V1_0::Operand& operand) {
+GeneralResult<Operand> unvalidatedConvert(const hal::V1_0::Operand& operand) {
return Operand{
- .type = NN_TRY(convert(operand.type)),
+ .type = NN_TRY(unvalidatedConvert(operand.type)),
.dimensions = operand.dimensions,
.scale = operand.scale,
.zeroPoint = operand.zeroPoint,
- .lifetime = NN_TRY(convert(operand.lifetime)),
- .location = NN_TRY(convert(operand.location)),
+ .lifetime = NN_TRY(unvalidatedConvert(operand.lifetime)),
+ .location = NN_TRY(unvalidatedConvert(operand.location)),
};
}
-GeneralResult<Operation> convert(const hal::V1_0::Operation& operation) {
+GeneralResult<Operation> unvalidatedConvert(const hal::V1_0::Operation& operation) {
return Operation{
- .type = NN_TRY(convert(operation.type)),
+ .type = NN_TRY(unvalidatedConvert(operation.type)),
.inputs = operation.inputs,
.outputs = operation.outputs,
};
}
-GeneralResult<Model::OperandValues> convert(const hidl_vec<uint8_t>& operandValues) {
+GeneralResult<Model::OperandValues> unvalidatedConvert(const hidl_vec<uint8_t>& operandValues) {
return Model::OperandValues(operandValues.data(), operandValues.size());
}
-GeneralResult<Memory> convert(const hidl_memory& memory) {
+GeneralResult<Memory> unvalidatedConvert(const hidl_memory& memory) {
return createSharedMemoryFromHidlMemory(memory);
}
-GeneralResult<Model> convert(const hal::V1_0::Model& model) {
- auto operations = NN_TRY(convert(model.operations));
+GeneralResult<Model> unvalidatedConvert(const hal::V1_0::Model& model) {
+ auto operations = NN_TRY(unvalidatedConvert(model.operations));
// Verify number of consumers.
const auto numberOfConsumers =
@@ -152,7 +173,7 @@
}
auto main = Model::Subgraph{
- .operands = NN_TRY(convert(model.operands)),
+ .operands = NN_TRY(unvalidatedConvert(model.operands)),
.operations = std::move(operations),
.inputIndexes = model.inputIndexes,
.outputIndexes = model.outputIndexes,
@@ -160,35 +181,35 @@
return Model{
.main = std::move(main),
- .operandValues = NN_TRY(convert(model.operandValues)),
- .pools = NN_TRY(convert(model.pools)),
+ .operandValues = NN_TRY(unvalidatedConvert(model.operandValues)),
+ .pools = NN_TRY(unvalidatedConvert(model.pools)),
};
}
-GeneralResult<Request::Argument> convert(const hal::V1_0::RequestArgument& argument) {
+GeneralResult<Request::Argument> unvalidatedConvert(const hal::V1_0::RequestArgument& argument) {
const auto lifetime = argument.hasNoValue ? Request::Argument::LifeTime::NO_VALUE
: Request::Argument::LifeTime::POOL;
return Request::Argument{
.lifetime = lifetime,
- .location = NN_TRY(convert(argument.location)),
+ .location = NN_TRY(unvalidatedConvert(argument.location)),
.dimensions = argument.dimensions,
};
}
-GeneralResult<Request> convert(const hal::V1_0::Request& request) {
- auto memories = NN_TRY(convert(request.pools));
+GeneralResult<Request> unvalidatedConvert(const hal::V1_0::Request& request) {
+ auto memories = NN_TRY(unvalidatedConvert(request.pools));
std::vector<Request::MemoryPool> pools;
pools.reserve(memories.size());
std::move(memories.begin(), memories.end(), std::back_inserter(pools));
return Request{
- .inputs = NN_TRY(convert(request.inputs)),
- .outputs = NN_TRY(convert(request.outputs)),
+ .inputs = NN_TRY(unvalidatedConvert(request.inputs)),
+ .outputs = NN_TRY(unvalidatedConvert(request.outputs)),
.pools = std::move(pools),
};
}
-GeneralResult<ErrorStatus> convert(const hal::V1_0::ErrorStatus& status) {
+GeneralResult<ErrorStatus> unvalidatedConvert(const hal::V1_0::ErrorStatus& status) {
switch (status) {
case hal::V1_0::ErrorStatus::NONE:
case hal::V1_0::ErrorStatus::DEVICE_UNAVAILABLE:
@@ -201,46 +222,81 @@
<< "Invalid ErrorStatus " << underlyingType(status);
}
+GeneralResult<DeviceStatus> convert(const hal::V1_0::DeviceStatus& deviceStatus) {
+ return validatedConvert(deviceStatus);
+}
+
+GeneralResult<Capabilities> convert(const hal::V1_0::Capabilities& capabilities) {
+ return validatedConvert(capabilities);
+}
+
+GeneralResult<Model> convert(const hal::V1_0::Model& model) {
+ return validatedConvert(model);
+}
+
+GeneralResult<Request> convert(const hal::V1_0::Request& request) {
+ return validatedConvert(request);
+}
+
+GeneralResult<ErrorStatus> convert(const hal::V1_0::ErrorStatus& status) {
+ return validatedConvert(status);
+}
+
} // namespace android::nn
namespace android::hardware::neuralnetworks::V1_0::utils {
namespace {
template <typename Input>
-using ConvertOutput = std::decay_t<decltype(convert(std::declval<Input>()).value())>;
+using unvalidatedConvertOutput =
+ std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
template <typename Type>
-nn::GeneralResult<hidl_vec<ConvertOutput<Type>>> convert(const std::vector<Type>& arguments) {
- hidl_vec<ConvertOutput<Type>> halObject(arguments.size());
+nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
+ const std::vector<Type>& arguments) {
+ hidl_vec<unvalidatedConvertOutput<Type>> halObject(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
- halObject[i] = NN_TRY(utils::convert(arguments[i]));
+ halObject[i] = NN_TRY(utils::unvalidatedConvert(arguments[i]));
}
return halObject;
}
+template <typename Type>
+decltype(utils::unvalidatedConvert(std::declval<Type>())) validatedConvert(const Type& canonical) {
+ const auto maybeVersion = nn::validate(canonical);
+ if (!maybeVersion.has_value()) {
+ return nn::error() << maybeVersion.error();
+ }
+ const auto version = maybeVersion.value();
+ if (version > kVersion) {
+ return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
+ }
+ return utils::unvalidatedConvert(canonical);
+}
+
} // anonymous namespace
-nn::GeneralResult<OperandType> convert(const nn::OperandType& operandType) {
+nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) {
return static_cast<OperandType>(operandType);
}
-nn::GeneralResult<OperationType> convert(const nn::OperationType& operationType) {
+nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) {
return static_cast<OperationType>(operationType);
}
-nn::GeneralResult<OperandLifeTime> convert(const nn::Operand::LifeTime& lifetime) {
+nn::GeneralResult<OperandLifeTime> unvalidatedConvert(const nn::Operand::LifeTime& lifetime) {
if (lifetime == nn::Operand::LifeTime::POINTER) {
return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
- << "Model cannot be converted because it contains pointer-based memory";
+ << "Model cannot be converted because it contains pointer-based memory";
}
return static_cast<OperandLifeTime>(lifetime);
}
-nn::GeneralResult<DeviceStatus> convert(const nn::DeviceStatus& deviceStatus) {
+nn::GeneralResult<DeviceStatus> unvalidatedConvert(const nn::DeviceStatus& deviceStatus) {
return static_cast<DeviceStatus>(deviceStatus);
}
-nn::GeneralResult<PerformanceInfo> convert(
+nn::GeneralResult<PerformanceInfo> unvalidatedConvert(
const nn::Capabilities::PerformanceInfo& performanceInfo) {
return PerformanceInfo{
.execTime = performanceInfo.execTime,
@@ -248,16 +304,16 @@
};
}
-nn::GeneralResult<Capabilities> convert(const nn::Capabilities& capabilities) {
+nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capabilities) {
return Capabilities{
- .float32Performance = NN_TRY(convert(
+ .float32Performance = NN_TRY(unvalidatedConvert(
capabilities.operandPerformance.lookup(nn::OperandType::TENSOR_FLOAT32))),
- .quantized8Performance = NN_TRY(convert(
+ .quantized8Performance = NN_TRY(unvalidatedConvert(
capabilities.operandPerformance.lookup(nn::OperandType::TENSOR_QUANT8_ASYMM))),
};
}
-nn::GeneralResult<DataLocation> convert(const nn::DataLocation& location) {
+nn::GeneralResult<DataLocation> unvalidatedConvert(const nn::DataLocation& location) {
return DataLocation{
.poolIndex = location.poolIndex,
.offset = location.offset,
@@ -265,42 +321,43 @@
};
}
-nn::GeneralResult<Operand> convert(const nn::Operand& operand) {
+nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) {
return Operand{
- .type = NN_TRY(convert(operand.type)),
+ .type = NN_TRY(unvalidatedConvert(operand.type)),
.dimensions = operand.dimensions,
.numberOfConsumers = 0,
.scale = operand.scale,
.zeroPoint = operand.zeroPoint,
- .lifetime = NN_TRY(convert(operand.lifetime)),
- .location = NN_TRY(convert(operand.location)),
+ .lifetime = NN_TRY(unvalidatedConvert(operand.lifetime)),
+ .location = NN_TRY(unvalidatedConvert(operand.location)),
};
}
-nn::GeneralResult<Operation> convert(const nn::Operation& operation) {
+nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation) {
return Operation{
- .type = NN_TRY(convert(operation.type)),
+ .type = NN_TRY(unvalidatedConvert(operation.type)),
.inputs = operation.inputs,
.outputs = operation.outputs,
};
}
-nn::GeneralResult<hidl_vec<uint8_t>> convert(const nn::Model::OperandValues& operandValues) {
+nn::GeneralResult<hidl_vec<uint8_t>> unvalidatedConvert(
+ const nn::Model::OperandValues& operandValues) {
return hidl_vec<uint8_t>(operandValues.data(), operandValues.data() + operandValues.size());
}
-nn::GeneralResult<hidl_memory> convert(const nn::Memory& memory) {
+nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::Memory& memory) {
return hidl_memory(memory.name, NN_TRY(hal::utils::hidlHandleFromSharedHandle(memory.handle)),
memory.size);
}
-nn::GeneralResult<Model> convert(const nn::Model& model) {
+nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
if (!hal::utils::hasNoPointerData(model)) {
return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
- << "Mdoel cannot be converted because it contains pointer-based memory";
+ << "Model cannot be converted because it contains pointer-based memory";
}
- auto operands = NN_TRY(convert(model.main.operands));
+ auto operands = NN_TRY(unvalidatedConvert(model.main.operands));
// Update number of consumers.
const auto numberOfConsumers =
@@ -312,45 +369,46 @@
return Model{
.operands = std::move(operands),
- .operations = NN_TRY(convert(model.main.operations)),
+ .operations = NN_TRY(unvalidatedConvert(model.main.operations)),
.inputIndexes = model.main.inputIndexes,
.outputIndexes = model.main.outputIndexes,
- .operandValues = NN_TRY(convert(model.operandValues)),
- .pools = NN_TRY(convert(model.pools)),
+ .operandValues = NN_TRY(unvalidatedConvert(model.operandValues)),
+ .pools = NN_TRY(unvalidatedConvert(model.pools)),
};
}
-nn::GeneralResult<RequestArgument> convert(const nn::Request::Argument& requestArgument) {
+nn::GeneralResult<RequestArgument> unvalidatedConvert(
+ const nn::Request::Argument& requestArgument) {
if (requestArgument.lifetime == nn::Request::Argument::LifeTime::POINTER) {
return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
- << "Request cannot be converted because it contains pointer-based memory";
+ << "Request cannot be converted because it contains pointer-based memory";
}
const bool hasNoValue = requestArgument.lifetime == nn::Request::Argument::LifeTime::NO_VALUE;
return RequestArgument{
.hasNoValue = hasNoValue,
- .location = NN_TRY(convert(requestArgument.location)),
+ .location = NN_TRY(unvalidatedConvert(requestArgument.location)),
.dimensions = requestArgument.dimensions,
};
}
-nn::GeneralResult<hidl_memory> convert(const nn::Request::MemoryPool& memoryPool) {
- return convert(std::get<nn::Memory>(memoryPool));
+nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::Request::MemoryPool& memoryPool) {
+ return unvalidatedConvert(std::get<nn::Memory>(memoryPool));
}
-nn::GeneralResult<Request> convert(const nn::Request& request) {
+nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request) {
if (!hal::utils::hasNoPointerData(request)) {
return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
- << "Request cannot be converted because it contains pointer-based memory";
+ << "Request cannot be converted because it contains pointer-based memory";
}
return Request{
- .inputs = NN_TRY(convert(request.inputs)),
- .outputs = NN_TRY(convert(request.outputs)),
- .pools = NN_TRY(convert(request.pools)),
+ .inputs = NN_TRY(unvalidatedConvert(request.inputs)),
+ .outputs = NN_TRY(unvalidatedConvert(request.outputs)),
+ .pools = NN_TRY(unvalidatedConvert(request.pools)),
};
}
-nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& status) {
+nn::GeneralResult<ErrorStatus> unvalidatedConvert(const nn::ErrorStatus& status) {
switch (status) {
case nn::ErrorStatus::NONE:
case nn::ErrorStatus::DEVICE_UNAVAILABLE:
@@ -363,4 +421,24 @@
}
}
+nn::GeneralResult<DeviceStatus> convert(const nn::DeviceStatus& deviceStatus) {
+ return validatedConvert(deviceStatus);
+}
+
+nn::GeneralResult<Capabilities> convert(const nn::Capabilities& capabilities) {
+ return validatedConvert(capabilities);
+}
+
+nn::GeneralResult<Model> convert(const nn::Model& model) {
+ return validatedConvert(model);
+}
+
+nn::GeneralResult<Request> convert(const nn::Request& request) {
+ return validatedConvert(request);
+}
+
+nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& status) {
+ return validatedConvert(status);
+}
+
} // namespace android::hardware::neuralnetworks::V1_0::utils
diff --git a/neuralnetworks/1.0/utils/src/Device.cpp b/neuralnetworks/1.0/utils/src/Device.cpp
index 671416b..ab3f5af 100644
--- a/neuralnetworks/1.0/utils/src/Device.cpp
+++ b/neuralnetworks/1.0/utils/src/Device.cpp
@@ -48,11 +48,10 @@
<< "uninitialized";
const auto cb = [&result](ErrorStatus status, const Capabilities& capabilities) {
if (status != ErrorStatus::NONE) {
- const auto canonical =
- validatedConvertToCanonical(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
+ const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
result = NN_ERROR(canonical) << "getCapabilities failed with " << toString(status);
} else {
- result = validatedConvertToCanonical(capabilities);
+ result = nn::convert(capabilities);
}
};
@@ -135,8 +134,7 @@
<< "uninitialized";
auto cb = [&result, &model](ErrorStatus status, const hidl_vec<bool>& supportedOperations) {
if (status != ErrorStatus::NONE) {
- const auto canonical =
- validatedConvertToCanonical(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
+ const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
result = NN_ERROR(canonical)
<< "getSupportedOperations failed with " << toString(status);
} else if (supportedOperations.size() != model.main.operations.size()) {
@@ -172,8 +170,7 @@
const auto ret = kDevice->prepareModel(hidlModel, cb);
const auto status = NN_TRY(hal::utils::handleTransportError(ret));
if (status != ErrorStatus::NONE) {
- const auto canonical =
- validatedConvertToCanonical(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
+ const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
return NN_ERROR(canonical) << "prepareModel failed with " << toString(status);
}
diff --git a/neuralnetworks/1.0/utils/src/PreparedModel.cpp b/neuralnetworks/1.0/utils/src/PreparedModel.cpp
index 11ccbe3..80f885a 100644
--- a/neuralnetworks/1.0/utils/src/PreparedModel.cpp
+++ b/neuralnetworks/1.0/utils/src/PreparedModel.cpp
@@ -70,8 +70,7 @@
const auto status =
NN_TRY(hal::utils::makeExecutionFailure(hal::utils::handleTransportError(ret)));
if (status != ErrorStatus::NONE) {
- const auto canonical =
- validatedConvertToCanonical(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
+ const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
return NN_ERROR(canonical) << "execute failed with " << toString(status);
}