Add missing validation for NN canonical types
Bug: 177669661
Test: mma
Test: NeuralNetworksTest_static
Change-Id: Ic05c177f61a906a69bf82ff9c4d5bb8b0556d5ca
Merged-In: Ic05c177f61a906a69bf82ff9c4d5bb8b0556d5ca
(cherry picked from commit 08ee3f9287811e9087a5263c3176ce1439f70c2c)
diff --git a/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/Utils.h b/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/Utils.h
index 3233114..09691b6 100644
--- a/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/Utils.h
+++ b/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/Utils.h
@@ -22,19 +22,25 @@
#include <android-base/logging.h>
#include <android/hardware/neuralnetworks/1.2/types.h>
#include <nnapi/Result.h>
+#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
+#include <nnapi/Validation.h>
#include <nnapi/hal/1.0/Conversions.h>
#include <nnapi/hal/1.1/Conversions.h>
+#include <nnapi/hal/1.1/Utils.h>
+#include <nnapi/hal/HandleError.h>
#include <limits>
namespace android::hardware::neuralnetworks::V1_2::utils {
using CacheToken = hidl_array<uint8_t, static_cast<size_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>;
+using V1_1::utils::kDefaultExecutionPreference;
constexpr auto kDefaultMesaureTiming = MeasureTiming::NO;
constexpr auto kNoTiming = Timing{.timeOnDevice = std::numeric_limits<uint64_t>::max(),
.timeInDriver = std::numeric_limits<uint64_t>::max()};
+constexpr auto kVersion = nn::Version::ANDROID_Q;
template <typename Type>
nn::Result<void> validate(const Type& halObject) {
@@ -55,6 +61,15 @@
}
template <typename Type>
+nn::GeneralResult<void> compliantVersion(const Type& canonical) {  // Succeeds only if nn::validate(canonical) passes and its reported version is <= kVersion.
+ const auto requiredVersion = NN_TRY(hal::utils::makeGeneralFailure(nn::validate(canonical)));
+ if (requiredVersion > kVersion) {  // object relies on features newer than kVersion
+ return NN_ERROR() << "Insufficient version: " << requiredVersion << " vs required " << kVersion;
+ }
+ return {};  // compliant with kVersion
+}
+
+template <typename Type>
auto convertFromNonCanonical(const Type& nonCanonicalObject)
-> decltype(convert(nn::convert(nonCanonicalObject).value())) {
return convert(NN_TRY(nn::convert(nonCanonicalObject)));
diff --git a/neuralnetworks/1.2/utils/src/Conversions.cpp b/neuralnetworks/1.2/utils/src/Conversions.cpp
index 2c45583..29945b7 100644
--- a/neuralnetworks/1.2/utils/src/Conversions.cpp
+++ b/neuralnetworks/1.2/utils/src/Conversions.cpp
@@ -37,6 +37,8 @@
#include <type_traits>
#include <utility>
+#include "Utils.h"
+
namespace {
template <typename Type>
@@ -45,50 +47,23 @@
}
using HalDuration = std::chrono::duration<uint64_t, std::micro>;
-constexpr auto kVersion = android::nn::Version::ANDROID_Q;
-constexpr uint64_t kNoTiming = std::numeric_limits<uint64_t>::max();
} // namespace
namespace android::nn {
namespace {
-constexpr bool validOperandType(OperandType operandType) {
- switch (operandType) {
- case OperandType::FLOAT32:
- case OperandType::INT32:
- case OperandType::UINT32:
- case OperandType::TENSOR_FLOAT32:
- case OperandType::TENSOR_INT32:
- case OperandType::TENSOR_QUANT8_ASYMM:
- case OperandType::BOOL:
- case OperandType::TENSOR_QUANT16_SYMM:
- case OperandType::TENSOR_FLOAT16:
- case OperandType::TENSOR_BOOL8:
- case OperandType::FLOAT16:
- case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
- case OperandType::TENSOR_QUANT16_ASYMM:
- case OperandType::TENSOR_QUANT8_SYMM:
- case OperandType::OEM:
- case OperandType::TENSOR_OEM_BYTE:
- return true;
- default:
- break;
- }
- return isExtension(operandType);
-}
-
using hardware::hidl_handle;
using hardware::hidl_vec;
template <typename Input>
-using unvalidatedConvertOutput =
+using UnvalidatedConvertOutput =
std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
template <typename Type>
-GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
+GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
const hidl_vec<Type>& arguments) {
- std::vector<unvalidatedConvertOutput<Type>> canonical;
+ std::vector<UnvalidatedConvertOutput<Type>> canonical;
canonical.reserve(arguments.size());
for (const auto& argument : arguments) {
canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
@@ -97,29 +72,16 @@
}
template <typename Type>
-GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
- const hidl_vec<Type>& arguments) {
- return unvalidatedConvertVec(arguments);
-}
-
-template <typename Type>
-decltype(nn::unvalidatedConvert(std::declval<Type>())) validatedConvert(const Type& halObject) {
+GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
- const auto maybeVersion = validate(canonical);
- if (!maybeVersion.has_value()) {
- return error() << maybeVersion.error();
- }
- const auto version = maybeVersion.value();
- if (version > kVersion) {
- return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
- }
+ NN_TRY(hal::V1_2::utils::compliantVersion(canonical));
return canonical;
}
template <typename Type>
-GeneralResult<std::vector<unvalidatedConvertOutput<Type>>> validatedConvert(
+GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> validatedConvert(
const hidl_vec<Type>& arguments) {
- std::vector<unvalidatedConvertOutput<Type>> canonical;
+ std::vector<UnvalidatedConvertOutput<Type>> canonical;
canonical.reserve(arguments.size());
for (const auto& argument : arguments) {
canonical.push_back(NN_TRY(validatedConvert(argument)));
@@ -145,8 +107,7 @@
const bool validOperandTypes = std::all_of(
capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
[](const hal::V1_2::Capabilities::OperandPerformance& operandPerformance) {
- const auto maybeType = unvalidatedConvert(operandPerformance.type);
- return !maybeType.has_value() ? false : validOperandType(maybeType.value());
+ return validatedConvert(operandPerformance.type).has_value();
});
if (!validOperandTypes) {
return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
@@ -275,6 +236,7 @@
GeneralResult<Timing> unvalidatedConvert(const hal::V1_2::Timing& timing) {
constexpr uint64_t kMaxTiming = std::chrono::floor<HalDuration>(Duration::max()).count();
constexpr auto convertTiming = [](uint64_t halTiming) -> OptionalDuration {
+ constexpr uint64_t kNoTiming = std::numeric_limits<uint64_t>::max();
if (halTiming == kNoTiming) {
return {};
}
@@ -378,25 +340,19 @@
}
template <typename Input>
-using unvalidatedConvertOutput =
+using UnvalidatedConvertOutput =
std::decay_t<decltype(unvalidatedConvert(std::declval<Input>()).value())>;
template <typename Type>
-nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
+nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
const std::vector<Type>& arguments) {
- hidl_vec<unvalidatedConvertOutput<Type>> halObject(arguments.size());
+ hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
halObject[i] = NN_TRY(unvalidatedConvert(arguments[i]));
}
return halObject;
}
-template <typename Type>
-nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> unvalidatedConvert(
- const std::vector<Type>& arguments) {
- return unvalidatedConvertVec(arguments);
-}
-
nn::GeneralResult<Operand::ExtraParams> makeExtraParams(nn::Operand::NoParams /*noParams*/) {
return Operand::ExtraParams{};
}
@@ -416,22 +372,15 @@
}
template <typename Type>
-decltype(utils::unvalidatedConvert(std::declval<Type>())) validatedConvert(const Type& canonical) {
- const auto maybeVersion = nn::validate(canonical);
- if (!maybeVersion.has_value()) {
- return nn::error() << maybeVersion.error();
- }
- const auto version = maybeVersion.value();
- if (version > kVersion) {
- return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
- }
- return utils::unvalidatedConvert(canonical);
+nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
+ NN_TRY(compliantVersion(canonical));
+ return unvalidatedConvert(canonical);
}
template <typename Type>
-nn::GeneralResult<hidl_vec<unvalidatedConvertOutput<Type>>> validatedConvert(
+nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> validatedConvert(
const std::vector<Type>& arguments) {
- hidl_vec<unvalidatedConvertOutput<Type>> halObject(arguments.size());
+ hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
halObject[i] = NN_TRY(validatedConvert(arguments[i]));
}
@@ -469,7 +418,7 @@
capabilities.operandPerformance.asVector().end(),
std::back_inserter(operandPerformance),
[](const nn::Capabilities::OperandPerformance& operandPerformance) {
- return nn::validOperandType(operandPerformance.type);
+ return compliantVersion(operandPerformance.type).has_value();
});
return Capabilities{
@@ -570,6 +519,7 @@
nn::GeneralResult<Timing> unvalidatedConvert(const nn::Timing& timing) {
constexpr auto convertTiming = [](nn::OptionalDuration canonicalTiming) -> uint64_t {
+ constexpr uint64_t kNoTiming = std::numeric_limits<uint64_t>::max();
if (!canonicalTiming.has_value()) {
return kNoTiming;
}