NNAPI VTS: Add validation for Priority
This CL also cleans up the validation framework code.
Bug: 67828197
Test: mma
Test: VtsHalNeuralnetworksV1_*TargetTest
Change-Id: I84661fb2b8204148788d10425ca0ac986158b15f
Merged-In: I84661fb2b8204148788d10425ca0ac986158b15f
(cherry picked from commit da1a69288054dcbe9e91943d408a65b3c59c4a09)
diff --git a/neuralnetworks/1.1/vts/functional/ValidateModel.cpp b/neuralnetworks/1.1/vts/functional/ValidateModel.cpp
index 0629a1e..3b6f0f8 100644
--- a/neuralnetworks/1.1/vts/functional/ValidateModel.cpp
+++ b/neuralnetworks/1.1/vts/functional/ValidateModel.cpp
@@ -30,6 +30,8 @@
using V1_0::OperandType;
using V1_0::implementation::PreparedModelCallback;
+using PrepareModelMutation = std::function<void(Model*, ExecutionPreference*)>;
+
///////////////////////// UTILITY FUNCTIONS /////////////////////////
static void validateGetSupportedOperations(const sp<IDevice>& device, const std::string& message,
@@ -67,16 +69,19 @@
}
// Primary validation function. This function will take a valid model, apply a
-// mutation to it to invalidate the model, then pass it to interface calls that
-// use the model. Note that the model here is passed by value, and any mutation
-// to the model does not leave this function.
-static void validate(const sp<IDevice>& device, const std::string& message, Model model,
- const std::function<void(Model*)>& mutation,
- ExecutionPreference preference = ExecutionPreference::FAST_SINGLE_ANSWER) {
- mutation(&model);
+// mutation to invalidate either the model or the execution preference, then
+// pass them to validateGetSupportedOperations (skipped when the execution
+// preference itself is invalid) and to validatePrepareModel.
+static void validate(const sp<IDevice>& device, const std::string& message,
+ const Model& originalModel, const PrepareModelMutation& mutate) {
+ Model model = originalModel;
+ ExecutionPreference preference = ExecutionPreference::FAST_SINGLE_ANSWER;
+ mutate(&model, &preference);
+
if (validExecutionPreference(preference)) {
validateGetSupportedOperations(device, message, model);
}
+
validatePrepareModel(device, message, model, preference);
}
@@ -115,9 +120,11 @@
const std::string message = "mutateOperandTypeTest: operand " +
std::to_string(operand) + " set to value " +
std::to_string(invalidOperandType);
- validate(device, message, model, [operand, invalidOperandType](Model* model) {
- model->operands[operand].type = static_cast<OperandType>(invalidOperandType);
- });
+ validate(device, message, model,
+ [operand, invalidOperandType](Model* model, ExecutionPreference*) {
+ model->operands[operand].type =
+ static_cast<OperandType>(invalidOperandType);
+ });
}
}
}
@@ -144,9 +151,10 @@
const uint32_t invalidRank = getInvalidRank(model.operands[operand].type);
const std::string message = "mutateOperandRankTest: operand " + std::to_string(operand) +
" has rank of " + std::to_string(invalidRank);
- validate(device, message, model, [operand, invalidRank](Model* model) {
- model->operands[operand].dimensions = std::vector<uint32_t>(invalidRank, 0);
- });
+ validate(device, message, model,
+ [operand, invalidRank](Model* model, ExecutionPreference*) {
+ model->operands[operand].dimensions = std::vector<uint32_t>(invalidRank, 0);
+ });
}
}
@@ -173,9 +181,10 @@
const float invalidScale = getInvalidScale(model.operands[operand].type);
const std::string message = "mutateOperandScaleTest: operand " + std::to_string(operand) +
" has scale of " + std::to_string(invalidScale);
- validate(device, message, model, [operand, invalidScale](Model* model) {
- model->operands[operand].scale = invalidScale;
- });
+ validate(device, message, model,
+ [operand, invalidScale](Model* model, ExecutionPreference*) {
+ model->operands[operand].scale = invalidScale;
+ });
}
}
@@ -204,9 +213,10 @@
const std::string message = "mutateOperandZeroPointTest: operand " +
std::to_string(operand) + " has zero point of " +
std::to_string(invalidZeroPoint);
- validate(device, message, model, [operand, invalidZeroPoint](Model* model) {
- model->operands[operand].zeroPoint = invalidZeroPoint;
- });
+ validate(device, message, model,
+ [operand, invalidZeroPoint](Model* model, ExecutionPreference*) {
+ model->operands[operand].zeroPoint = invalidZeroPoint;
+ });
}
}
}
@@ -282,9 +292,10 @@
const std::string message = "mutateOperationOperandTypeTest: operand " +
std::to_string(operand) + " set to type " +
toString(invalidOperandType);
- validate(device, message, model, [operand, invalidOperandType](Model* model) {
- mutateOperand(&model->operands[operand], invalidOperandType);
- });
+ validate(device, message, model,
+ [operand, invalidOperandType](Model* model, ExecutionPreference*) {
+ mutateOperand(&model->operands[operand], invalidOperandType);
+ });
}
}
}
@@ -304,10 +315,11 @@
const std::string message = "mutateOperationTypeTest: operation " +
std::to_string(operation) + " set to value " +
std::to_string(invalidOperationType);
- validate(device, message, model, [operation, invalidOperationType](Model* model) {
- model->operations[operation].type =
- static_cast<OperationType>(invalidOperationType);
- });
+ validate(device, message, model,
+ [operation, invalidOperationType](Model* model, ExecutionPreference*) {
+ model->operations[operation].type =
+ static_cast<OperationType>(invalidOperationType);
+ });
}
}
}
@@ -321,9 +333,10 @@
const std::string message = "mutateOperationInputOperandIndexTest: operation " +
std::to_string(operation) + " input " +
std::to_string(input);
- validate(device, message, model, [operation, input, invalidOperand](Model* model) {
- model->operations[operation].inputs[input] = invalidOperand;
- });
+ validate(device, message, model,
+ [operation, input, invalidOperand](Model* model, ExecutionPreference*) {
+ model->operations[operation].inputs[input] = invalidOperand;
+ });
}
}
}
@@ -337,9 +350,10 @@
const std::string message = "mutateOperationOutputOperandIndexTest: operation " +
std::to_string(operation) + " output " +
std::to_string(output);
- validate(device, message, model, [operation, output, invalidOperand](Model* model) {
- model->operations[operation].outputs[output] = invalidOperand;
- });
+ validate(device, message, model,
+ [operation, output, invalidOperand](Model* model, ExecutionPreference*) {
+ model->operations[operation].outputs[output] = invalidOperand;
+ });
}
}
}
@@ -372,7 +386,7 @@
for (size_t operand = 0; operand < model.operands.size(); ++operand) {
const std::string message = "removeOperandTest: operand " + std::to_string(operand);
validate(device, message, model,
- [operand](Model* model) { removeOperand(model, operand); });
+ [operand](Model* model, ExecutionPreference*) { removeOperand(model, operand); });
}
}
@@ -388,8 +402,9 @@
static void removeOperationTest(const sp<IDevice>& device, const Model& model) {
for (size_t operation = 0; operation < model.operations.size(); ++operation) {
const std::string message = "removeOperationTest: operation " + std::to_string(operation);
- validate(device, message, model,
- [operation](Model* model) { removeOperation(model, operation); });
+ validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
+ removeOperation(model, operation);
+ });
}
}
@@ -409,11 +424,12 @@
const std::string message = "removeOperationInputTest: operation " +
std::to_string(operation) + ", input " +
std::to_string(input);
- validate(device, message, model, [operation, input](Model* model) {
- uint32_t operand = model->operations[operation].inputs[input];
- model->operands[operand].numberOfConsumers--;
- hidl_vec_removeAt(&model->operations[operation].inputs, input);
- });
+ validate(device, message, model,
+ [operation, input](Model* model, ExecutionPreference*) {
+ uint32_t operand = model->operations[operation].inputs[input];
+ model->operands[operand].numberOfConsumers--;
+ hidl_vec_removeAt(&model->operations[operation].inputs, input);
+ });
}
}
}
@@ -426,9 +442,10 @@
const std::string message = "removeOperationOutputTest: operation " +
std::to_string(operation) + ", output " +
std::to_string(output);
- validate(device, message, model, [operation, output](Model* model) {
- hidl_vec_removeAt(&model->operations[operation].outputs, output);
- });
+ validate(device, message, model,
+ [operation, output](Model* model, ExecutionPreference*) {
+ hidl_vec_removeAt(&model->operations[operation].outputs, output);
+ });
}
}
}
@@ -444,7 +461,7 @@
static void addOperationInputTest(const sp<IDevice>& device, const Model& model) {
for (size_t operation = 0; operation < model.operations.size(); ++operation) {
const std::string message = "addOperationInputTest: operation " + std::to_string(operation);
- validate(device, message, model, [operation](Model* model) {
+ validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
uint32_t index = addOperand(model, OperandLifeTime::MODEL_INPUT);
hidl_vec_push_back(&model->operations[operation].inputs, index);
hidl_vec_push_back(&model->inputIndexes, index);
@@ -458,7 +475,7 @@
for (size_t operation = 0; operation < model.operations.size(); ++operation) {
const std::string message =
"addOperationOutputTest: operation " + std::to_string(operation);
- validate(device, message, model, [operation](Model* model) {
+ validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
uint32_t index = addOperand(model, OperandLifeTime::MODEL_OUTPUT);
hidl_vec_push_back(&model->operations[operation].outputs, index);
hidl_vec_push_back(&model->outputIndexes, index);
@@ -474,12 +491,13 @@
};
static void mutateExecutionPreferenceTest(const sp<IDevice>& device, const Model& model) {
- for (int32_t preference : invalidExecutionPreferences) {
+ for (int32_t invalidPreference : invalidExecutionPreferences) {
const std::string message =
- "mutateExecutionPreferenceTest: preference " + std::to_string(preference);
- validate(
- device, message, model, [](Model*) {},
- static_cast<ExecutionPreference>(preference));
+ "mutateExecutionPreferenceTest: preference " + std::to_string(invalidPreference);
+ validate(device, message, model,
+ [invalidPreference](Model*, ExecutionPreference* preference) {
+ *preference = static_cast<ExecutionPreference>(invalidPreference);
+ });
}
}
diff --git a/neuralnetworks/1.1/vts/functional/ValidateRequest.cpp b/neuralnetworks/1.1/vts/functional/ValidateRequest.cpp
index 9684eb2..2914335 100644
--- a/neuralnetworks/1.1/vts/functional/ValidateRequest.cpp
+++ b/neuralnetworks/1.1/vts/functional/ValidateRequest.cpp
@@ -28,15 +28,17 @@
using V1_0::Request;
using V1_0::implementation::ExecutionCallback;
+using ExecutionMutation = std::function<void(Request*)>;
+
///////////////////////// UTILITY FUNCTIONS /////////////////////////
// Primary validation function. This function will take a valid request, apply a
// mutation to it to invalidate the request, then pass it to interface calls
-// that use the request. Note that the request here is passed by value, and any
-// mutation to the request does not leave this function.
+// that use the request.
static void validate(const sp<IPreparedModel>& preparedModel, const std::string& message,
- Request request, const std::function<void(Request*)>& mutation) {
- mutation(&request);
+ const Request& originalRequest, const ExecutionMutation& mutate) {
+ Request request = originalRequest;
+ mutate(&request);
SCOPED_TRACE(message + " [execute]");
sp<ExecutionCallback> executionCallback = new ExecutionCallback();