NNAPI: Add execution preference to prepareModel (HAL)
A model can be prepared in different ways to optimize for different
use cases. This CL propagates the execution preference across the HAL so
that the NN service can better fit the user's needs.
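
For context, a minimal caller-side sketch (illustrative only; the exact
runtime plumbing is outside this CL): the preference chosen through
ANeuralNetworksCompilation_setPreference would be forwarded to the driver
when the model is compiled, e.g.

    sp<PreparedModelCallback> callback = new PreparedModelCallback();
    Return<ErrorStatus> status = device->prepareModel_1_1(
            model, ExecutionPreference::FAST_SINGLE_ANSWER, callback);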
Bug: 77864669
Test: mma
Test: NeuralNetworksTest_static
Test: VtsHalNeuralnetworksV1_1TargetTest
Merged-In: Ib928d510d462f36b6a87d5e81010513db7829fa8
Change-Id: Ib928d510d462f36b6a87d5e81010513db7829fa8
(cherry picked from commit 2504c2fe4f3dc6975165d35a8ea1610d8bc4c477)
diff --git a/neuralnetworks/1.1/vts/functional/ValidateModel.cpp b/neuralnetworks/1.1/vts/functional/ValidateModel.cpp
index 7a20e26..3aa55f8 100644
--- a/neuralnetworks/1.1/vts/functional/ValidateModel.cpp
+++ b/neuralnetworks/1.1/vts/functional/ValidateModel.cpp
@@ -50,13 +50,13 @@
}
static void validatePrepareModel(const sp<IDevice>& device, const std::string& message,
- const V1_1::Model& model) {
+ const V1_1::Model& model, ExecutionPreference preference) {
SCOPED_TRACE(message + " [prepareModel_1_1]");
sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
ASSERT_NE(nullptr, preparedModelCallback.get());
Return<ErrorStatus> prepareLaunchStatus =
- device->prepareModel_1_1(model, preparedModelCallback);
+ device->prepareModel_1_1(model, preference, preparedModelCallback);
ASSERT_TRUE(prepareLaunchStatus.isOk());
ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, static_cast<ErrorStatus>(prepareLaunchStatus));
@@ -67,15 +67,24 @@
ASSERT_EQ(nullptr, preparedModel.get());
}
+static bool validExecutionPreference(ExecutionPreference preference) {
+ return preference == ExecutionPreference::LOW_POWER ||
+ preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
+ preference == ExecutionPreference::SUSTAINED_SPEED;
+}
+
// Primary validation function. This function will take a valid model, apply a
// mutation to it to invalidate the model, then pass it to interface calls that
// use the model. Note that the model here is passed by value, and any mutation
// to the model does not leave this function.
static void validate(const sp<IDevice>& device, const std::string& message, V1_1::Model model,
- const std::function<void(Model*)>& mutation) {
+ const std::function<void(Model*)>& mutation,
+ ExecutionPreference preference = ExecutionPreference::FAST_SINGLE_ANSWER) {
mutation(&model);
- validateGetSupportedOperations(device, message, model);
- validatePrepareModel(device, message, model);
+ if (validExecutionPreference(preference)) {
+ validateGetSupportedOperations(device, message, model);
+ }
+ validatePrepareModel(device, message, model, preference);
}
// Delete element from hidl_vec. hidl_vec doesn't support a "remove" operation,
@@ -486,6 +495,22 @@
}
}
+///////////////////////// VALIDATE EXECUTION PREFERENCE /////////////////////////
+
+static const int32_t invalidExecutionPreferences[] = {
+ static_cast<int32_t>(ExecutionPreference::LOW_POWER) - 1, // below the lower bound
+ static_cast<int32_t>(ExecutionPreference::SUSTAINED_SPEED) + 1, // above the upper bound
+};
+
+static void mutateExecutionPreferenceTest(const sp<IDevice>& device, const V1_1::Model& model) {
+ for (int32_t preference : invalidExecutionPreferences) {
+ const std::string message =
+ "mutateExecutionPreferenceTest: preference " + std::to_string(preference);
+ validate(device, message, model, [](Model*) {},
+ static_cast<ExecutionPreference>(preference));
+ }
+}
+
////////////////////////// ENTRY POINT //////////////////////////////
void ValidationTest::validateModel(const V1_1::Model& model) {
@@ -503,6 +528,7 @@
removeOperationOutputTest(device, model);
addOperationInputTest(device, model);
addOperationOutputTest(device, model);
+ mutateExecutionPreferenceTest(device, model);
}
} // namespace functional
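
Illustrative note (an assumption, not part of this diff): a driver that
implements the new prepareModel_1_1 signature is expected to reject
out-of-range preferences, which is what mutateExecutionPreferenceTest
above exercises. A minimal driver-side sketch, using the hypothetical
names SampleDriver and prepareModelCommon:

    Return<ErrorStatus> SampleDriver::prepareModel_1_1(
            const V1_1::Model& model, ExecutionPreference preference,
            const sp<IPreparedModelCallback>& callback) {
        if (callback.get() == nullptr) {
            return ErrorStatus::INVALID_ARGUMENT;
        }
        // Same bounds check the VTS test expects drivers to perform.
        if (preference != ExecutionPreference::LOW_POWER &&
            preference != ExecutionPreference::FAST_SINGLE_ANSWER &&
            preference != ExecutionPreference::SUSTAINED_SPEED) {
            callback->notify(ErrorStatus::INVALID_ARGUMENT, nullptr);
            return ErrorStatus::INVALID_ARGUMENT;
        }
        // Normal preparation path; `preference` can be used to pick a
        // power/latency/throughput trade-off when compiling the model.
        return prepareModelCommon(model, callback);
    }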