NNAPI VTS: Add validation for Priority

This CL also cleans up the validation framework code.

Bug: 67828197
Test: mma
Test: VtsHalNeuralnetworksV1_*TargetTest
Change-Id: I84661fb2b8204148788d10425ca0ac986158b15f
diff --git a/neuralnetworks/1.2/vts/functional/ValidateModel.cpp b/neuralnetworks/1.2/vts/functional/ValidateModel.cpp
index 30530be..7451f09 100644
--- a/neuralnetworks/1.2/vts/functional/ValidateModel.cpp
+++ b/neuralnetworks/1.2/vts/functional/ValidateModel.cpp
@@ -29,6 +29,8 @@
 using V1_1::ExecutionPreference;
 using HidlToken = hidl_array<uint8_t, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>;
 
+using PrepareModelMutation = std::function<void(Model*, ExecutionPreference*)>;
+
 ///////////////////////// UTILITY FUNCTIONS /////////////////////////
 
 static void validateGetSupportedOperations(const sp<IDevice>& device, const std::string& message,
@@ -67,16 +69,19 @@
 }
 
 // Primary validation function. This function will take a valid model, apply a
-// mutation to it to invalidate the model, then pass it to interface calls that
-// use the model. Note that the model here is passed by value, and any mutation
-// to the model does not leave this function.
-static void validate(const sp<IDevice>& device, const std::string& message, Model model,
-                     const std::function<void(Model*)>& mutation,
-                     ExecutionPreference preference = ExecutionPreference::FAST_SINGLE_ANSWER) {
-    mutation(&model);
+// mutation to invalidate either the model or the execution preference, then
+// pass the mutated arguments to supportedOperations and/or prepareModel,
+// invoking a method only if it would be called with an invalid argument.
+static void validate(const sp<IDevice>& device, const std::string& message,
+                     const Model& originalModel, const PrepareModelMutation& mutate) {
+    Model model = originalModel;
+    ExecutionPreference preference = ExecutionPreference::FAST_SINGLE_ANSWER;
+    mutate(&model, &preference);
+
     if (validExecutionPreference(preference)) {
         validateGetSupportedOperations(device, message, model);
     }
+
     validatePrepareModel(device, message, model, preference);
 }
 
@@ -115,9 +120,11 @@
             const std::string message = "mutateOperandTypeTest: operand " +
                                         std::to_string(operand) + " set to value " +
                                         std::to_string(invalidOperandType);
-            validate(device, message, model, [operand, invalidOperandType](Model* model) {
-                model->operands[operand].type = static_cast<OperandType>(invalidOperandType);
-            });
+            validate(device, message, model,
+                     [operand, invalidOperandType](Model* model, ExecutionPreference*) {
+                         model->operands[operand].type =
+                                 static_cast<OperandType>(invalidOperandType);
+                     });
         }
     }
 }
@@ -155,9 +162,10 @@
         }
         const std::string message = "mutateOperandRankTest: operand " + std::to_string(operand) +
                                     " has rank of " + std::to_string(invalidRank);
-        validate(device, message, model, [operand, invalidRank](Model* model) {
-            model->operands[operand].dimensions = std::vector<uint32_t>(invalidRank, 0);
-        });
+        validate(device, message, model,
+                 [operand, invalidRank](Model* model, ExecutionPreference*) {
+                     model->operands[operand].dimensions = std::vector<uint32_t>(invalidRank, 0);
+                 });
     }
 }
 
@@ -192,9 +200,10 @@
         const float invalidScale = getInvalidScale(model.operands[operand].type);
         const std::string message = "mutateOperandScaleTest: operand " + std::to_string(operand) +
                                     " has scale of " + std::to_string(invalidScale);
-        validate(device, message, model, [operand, invalidScale](Model* model) {
-            model->operands[operand].scale = invalidScale;
-        });
+        validate(device, message, model,
+                 [operand, invalidScale](Model* model, ExecutionPreference*) {
+                     model->operands[operand].scale = invalidScale;
+                 });
     }
 }
 
@@ -234,9 +243,10 @@
             const std::string message = "mutateOperandZeroPointTest: operand " +
                                         std::to_string(operand) + " has zero point of " +
                                         std::to_string(invalidZeroPoint);
-            validate(device, message, model, [operand, invalidZeroPoint](Model* model) {
-                model->operands[operand].zeroPoint = invalidZeroPoint;
-            });
+            validate(device, message, model,
+                     [operand, invalidZeroPoint](Model* model, ExecutionPreference*) {
+                         model->operands[operand].zeroPoint = invalidZeroPoint;
+                     });
         }
     }
 }
@@ -386,9 +396,10 @@
             const std::string message = "mutateOperationOperandTypeTest: operand " +
                                         std::to_string(operand) + " set to type " +
                                         toString(invalidOperandType);
-            validate(device, message, model, [operand, invalidOperandType](Model* model) {
-                mutateOperand(&model->operands[operand], invalidOperandType);
-            });
+            validate(device, message, model,
+                     [operand, invalidOperandType](Model* model, ExecutionPreference*) {
+                         mutateOperand(&model->operands[operand], invalidOperandType);
+                     });
         }
     }
 }
@@ -407,10 +418,11 @@
             const std::string message = "mutateOperationTypeTest: operation " +
                                         std::to_string(operation) + " set to value " +
                                         std::to_string(invalidOperationType);
-            validate(device, message, model, [operation, invalidOperationType](Model* model) {
-                model->operations[operation].type =
-                        static_cast<OperationType>(invalidOperationType);
-            });
+            validate(device, message, model,
+                     [operation, invalidOperationType](Model* model, ExecutionPreference*) {
+                         model->operations[operation].type =
+                                 static_cast<OperationType>(invalidOperationType);
+                     });
         }
     }
 }
@@ -424,9 +436,10 @@
             const std::string message = "mutateOperationInputOperandIndexTest: operation " +
                                         std::to_string(operation) + " input " +
                                         std::to_string(input);
-            validate(device, message, model, [operation, input, invalidOperand](Model* model) {
-                model->operations[operation].inputs[input] = invalidOperand;
-            });
+            validate(device, message, model,
+                     [operation, input, invalidOperand](Model* model, ExecutionPreference*) {
+                         model->operations[operation].inputs[input] = invalidOperand;
+                     });
         }
     }
 }
@@ -440,9 +453,10 @@
             const std::string message = "mutateOperationOutputOperandIndexTest: operation " +
                                         std::to_string(operation) + " output " +
                                         std::to_string(output);
-            validate(device, message, model, [operation, output, invalidOperand](Model* model) {
-                model->operations[operation].outputs[output] = invalidOperand;
-            });
+            validate(device, message, model,
+                     [operation, output, invalidOperand](Model* model, ExecutionPreference*) {
+                         model->operations[operation].outputs[output] = invalidOperand;
+                     });
         }
     }
 }
@@ -503,7 +517,7 @@
         }
         const std::string message = "removeOperandTest: operand " + std::to_string(operand);
         validate(device, message, model,
-                 [operand](Model* model) { removeOperand(model, operand); });
+                 [operand](Model* model, ExecutionPreference*) { removeOperand(model, operand); });
     }
 }
 
@@ -519,8 +533,9 @@
 static void removeOperationTest(const sp<IDevice>& device, const Model& model) {
     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
         const std::string message = "removeOperationTest: operation " + std::to_string(operation);
-        validate(device, message, model,
-                 [operation](Model* model) { removeOperation(model, operation); });
+        validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
+            removeOperation(model, operation);
+        });
     }
 }
 
@@ -601,11 +616,12 @@
             const std::string message = "removeOperationInputTest: operation " +
                                         std::to_string(operation) + ", input " +
                                         std::to_string(input);
-            validate(device, message, model, [operation, input](Model* model) {
-                uint32_t operand = model->operations[operation].inputs[input];
-                model->operands[operand].numberOfConsumers--;
-                hidl_vec_removeAt(&model->operations[operation].inputs, input);
-            });
+            validate(device, message, model,
+                     [operation, input](Model* model, ExecutionPreference*) {
+                         uint32_t operand = model->operations[operation].inputs[input];
+                         model->operands[operand].numberOfConsumers--;
+                         hidl_vec_removeAt(&model->operations[operation].inputs, input);
+                     });
         }
     }
 }
@@ -618,9 +634,10 @@
             const std::string message = "removeOperationOutputTest: operation " +
                                         std::to_string(operation) + ", output " +
                                         std::to_string(output);
-            validate(device, message, model, [operation, output](Model* model) {
-                hidl_vec_removeAt(&model->operations[operation].outputs, output);
-            });
+            validate(device, message, model,
+                     [operation, output](Model* model, ExecutionPreference*) {
+                         hidl_vec_removeAt(&model->operations[operation].outputs, output);
+                     });
         }
     }
 }
@@ -651,7 +668,7 @@
             continue;
         }
         const std::string message = "addOperationInputTest: operation " + std::to_string(operation);
-        validate(device, message, model, [operation](Model* model) {
+        validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
             uint32_t index = addOperand(model, OperandLifeTime::MODEL_INPUT);
             hidl_vec_push_back(&model->operations[operation].inputs, index);
             hidl_vec_push_back(&model->inputIndexes, index);
@@ -665,7 +682,7 @@
     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
         const std::string message =
                 "addOperationOutputTest: operation " + std::to_string(operation);
-        validate(device, message, model, [operation](Model* model) {
+        validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
             uint32_t index = addOperand(model, OperandLifeTime::MODEL_OUTPUT);
             hidl_vec_push_back(&model->operations[operation].outputs, index);
             hidl_vec_push_back(&model->outputIndexes, index);
@@ -681,12 +698,13 @@
 };
 
 static void mutateExecutionPreferenceTest(const sp<IDevice>& device, const Model& model) {
-    for (int32_t preference : invalidExecutionPreferences) {
+    for (int32_t invalidPreference : invalidExecutionPreferences) {
         const std::string message =
-                "mutateExecutionPreferenceTest: preference " + std::to_string(preference);
-        validate(
-                device, message, model, [](Model*) {},
-                static_cast<ExecutionPreference>(preference));
+                "mutateExecutionPreferenceTest: preference " + std::to_string(invalidPreference);
+        validate(device, message, model,
+                 [invalidPreference](Model*, ExecutionPreference* preference) {
+                     *preference = static_cast<ExecutionPreference>(invalidPreference);
+                 });
     }
 }