Add missing validation for NN canonical types
Bug: 177669661
Test: mma
Test: NeuralNetworksTest_static
Change-Id: Ic05c177f61a906a69bf82ff9c4d5bb8b0556d5ca
Merged-In: Ic05c177f61a906a69bf82ff9c4d5bb8b0556d5ca
(cherry picked from commit 08ee3f9287811e9087a5263c3176ce1439f70c2c)
diff --git a/neuralnetworks/1.2/utils/src/Conversions.cpp b/neuralnetworks/1.2/utils/src/Conversions.cpp
index 2c45583..29945b7 100644
--- a/neuralnetworks/1.2/utils/src/Conversions.cpp
+++ b/neuralnetworks/1.2/utils/src/Conversions.cpp
@@ -37,6 +37,8 @@
#include <type_traits>
#include <utility>
+#include "Utils.h"
+
namespace {
template <typename Type>
@@ -45,50 +47,23 @@
}
using HalDuration = std::chrono::duration<uint64_t, std::micro>;
-constexpr auto kVersion = android::nn::Version::ANDROID_Q;
-constexpr uint64_t kNoTiming = std::numeric_limits<uint64_t>::max();
} // namespace
namespace android::nn {
namespace {
-constexpr bool validOperandType(OperandType operandType) {
- switch (operandType) {
- case OperandType::FLOAT32:
- case OperandType::INT32:
- case OperandType::UINT32:
- case OperandType::TENSOR_FLOAT32:
- case OperandType::TENSOR_INT32:
- case OperandType::TENSOR_QUANT8_ASYMM:
- case OperandType::BOOL:
- case OperandType::TENSOR_QUANT16_SYMM:
- case OperandType::TENSOR_FLOAT16:
- case OperandType::TENSOR_BOOL8:
- case OperandType::FLOAT16:
- case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
- case OperandType::TENSOR_QUANT16_ASYMM:
- case OperandType::TENSOR_QUANT8_SYMM:
- case OperandType::OEM:
- case OperandType::TENSOR_OEM_BYTE:
- return true;
- default:
- break;
- }
- return isExtension(operandType);
-}
-
using hardware::hidl_handle;
using hardware::hidl_vec;
template <typename Input>
-using unvalidatedConvertOutput =
+using UnvalidatedConvertOutput =
std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
template <typename Type>
-GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
+GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
const hidl_vec<Type>& arguments) {
- std::vector<unvalidatedConvertOutput<Type>> canonical;
+ std::vector<UnvalidatedConvertOutput<Type>> canonical;
canonical.reserve(arguments.size());
for (const auto& argument : arguments) {
canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
@@ -97,29 +72,16 @@
}
template <typename Type>
-GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
- const hidl_vec<Type>& arguments) {
- return unvalidatedConvertVec(arguments);
-}
-
-template <typename Type>
-decltype(nn::unvalidatedConvert(std::declval<Type>())) validatedConvert(const Type& halObject) {
+GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
- const auto maybeVersion = validate(canonical);
- if (!maybeVersion.has_value()) {
- return error() << maybeVersion.error();
- }
- const auto version = maybeVersion.value();
- if (version > kVersion) {
- return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
- }
+ NN_TRY(hal::V1_2::utils::compliantVersion(canonical));
return canonical;
}
template <typename Type>
-GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> validatedConvert(
+GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
const hidl_vec<Type>& arguments) {
- std::vector<unvalidatedConvertOutput<Type>> canonical;
+ std::vector<UnvalidatedConvertOutput<Type>> canonical;
canonical.reserve(arguments.size());
for (const auto& argument : arguments) {
canonical.push_back(NN_TRY(validatedConvert(argument)));
@@ -145,8 +107,7 @@
const bool validOperandTypes = std::all_of(
capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
[](const hal::V1_2::Capabilities::OperandPerformance& operandPerformance) {
- const auto maybeType = unvalidatedConvert(operandPerformance.type);
- return !maybeType.has_value() ? false : validOperandType(maybeType.value());
+ return validatedConvert(operandPerformance.type).has_value();
});
if (!validOperandTypes) {
return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
@@ -275,6 +236,7 @@
GeneralResult<Timing> unvalidatedConvert(const hal::V1_2::Timing& timing) {
constexpr uint64_t kMaxTiming = std::chrono::floor<HalDuration>(Duration::max()).count();
constexpr auto convertTiming = [](uint64_t halTiming) -> OptionalDuration {
+ constexpr uint64_t kNoTiming = std::numeric_limits<uint64_t>::max();
if (halTiming == kNoTiming) {
return {};
}
@@ -378,25 +340,19 @@
}
template <typename Input>
-using unvalidatedConvertOutput =
+using UnvalidatedConvertOutput =
std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
template <typename Type>
-nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
+nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
const std::vector<Type>& arguments) {
- hidl_vec<unvalidatedConvertOutput<Type>> halObject(arguments.size());
+ hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
halObject[i] = NN_TRY(unvalidatedConvert(arguments[i]));
}
return halObject;
}
-template <typename Type>
-nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
- const std::vector<Type>& arguments) {
- return unvalidatedConvertVec(arguments);
-}
-
nn::GeneralResult<Operand::ExtraParams> makeExtraParams(nn::Operand::NoParams /*noParams*/) {
return Operand::ExtraParams{};
}
@@ -416,22 +372,15 @@
}
template <typename Type>
-decltype(utils::unvalidatedConvert(std::declval<Type>())) validatedConvert(const Type& canonical) {
- const auto maybeVersion = nn::validate(canonical);
- if (!maybeVersion.has_value()) {
- return nn::error() << maybeVersion.error();
- }
- const auto version = maybeVersion.value();
- if (version > kVersion) {
- return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
- }
- return utils::unvalidatedConvert(canonical);
+nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
+ NN_TRY(compliantVersion(canonical));
+ return unvalidatedConvert(canonical);
}
template <typename Type>
-nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> validatedConvert(
+nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> validatedConvert(
const std::vector<Type>& arguments) {
- hidl_vec<unvalidatedConvertOutput<Type>> halObject(arguments.size());
+ hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
halObject[i] = NN_TRY(validatedConvert(arguments[i]));
}
@@ -469,7 +418,7 @@
capabilities.operandPerformance.asVector().end(),
std::back_inserter(operandPerformance),
[](const nn::Capabilities::OperandPerformance& operandPerformance) {
- return nn::validOperandType(operandPerformance.type);
+ return compliantVersion(operandPerformance.type).has_value();
});
return Capabilities{
@@ -570,6 +519,7 @@
nn::GeneralResult<Timing> unvalidatedConvert(const nn::Timing& timing) {
constexpr auto convertTiming = [](nn::OptionalDuration canonicalTiming) -> uint64_t {
+ constexpr uint64_t kNoTiming = std::numeric_limits<uint64_t>::max();
if (!canonicalTiming.has_value()) {
return kNoTiming;
}