Add validation tests for NNAPI Burst serialized format
This CL adds the following two types of validation tests on the NNAPI
Burst serialized format:
(1) it directly modifies the serialized data (invalidating it) to ensure
that vendor driver services properly validates the serialized
request
(2) it ensures that vendor driver services properly fail when the result
channel is not large enough to return the data
This CL additionally includes miscellaneous cleanups:
(1) having a generic "validateEverything" function
(2) moving the "prepareModel" function that's common across
validateRequest and validateBurst to a common area
Fixes: 129779280
Bug: 129157135
Test: mma
Test: VtsHalNeuralnetworksV1_2TargetTest (with sample-all)
Change-Id: Ib90fe7f662824de17db5a254a8c501855e45f6bd
Merged-In: Ib90fe7f662824de17db5a254a8c501855e45f6bd
(cherry picked from commit 20f28a24e908d54f4708ad17943154fb61a4c770)
diff --git a/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.cpp b/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.cpp
index 8883057..31638c4 100644
--- a/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.cpp
+++ b/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.cpp
@@ -68,6 +68,11 @@
::testing::VtsHalHidlTargetTestBase::TearDown();
}
+void ValidationTest::validateEverything(const Model& model, const std::vector<Request>& request) {
+ validateModel(model);
+ validateRequests(model, request);
+}
+
} // namespace functional
} // namespace vts
} // namespace V1_0
diff --git a/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.h b/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.h
index d4c114d..559d678 100644
--- a/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.h
+++ b/neuralnetworks/1.0/vts/functional/VtsHalNeuralnetworks.h
@@ -63,8 +63,11 @@
// Tag for the validation tests
class ValidationTest : public NeuralnetworksHidlTest {
protected:
- void validateModel(const Model& model);
- void validateRequests(const Model& model, const std::vector<Request>& request);
+ void validateEverything(const Model& model, const std::vector<Request>& request);
+
+ private:
+ void validateModel(const Model& model);
+ void validateRequests(const Model& model, const std::vector<Request>& request);
};
// Tag for the generated tests
diff --git a/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.cpp b/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.cpp
index 224a51d..11fa693 100644
--- a/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.cpp
+++ b/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.cpp
@@ -68,6 +68,11 @@
::testing::VtsHalHidlTargetTestBase::TearDown();
}
+void ValidationTest::validateEverything(const Model& model, const std::vector<Request>& request) {
+ validateModel(model);
+ validateRequests(model, request);
+}
+
} // namespace functional
} // namespace vts
} // namespace V1_1
diff --git a/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.h b/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.h
index 1c8c0e1..cea2b54 100644
--- a/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.h
+++ b/neuralnetworks/1.1/vts/functional/VtsHalNeuralnetworks.h
@@ -72,8 +72,11 @@
// Tag for the validation tests
class ValidationTest : public NeuralnetworksHidlTest {
protected:
- void validateModel(const Model& model);
- void validateRequests(const Model& model, const std::vector<Request>& request);
+ void validateEverything(const Model& model, const std::vector<Request>& request);
+
+ private:
+ void validateModel(const Model& model);
+ void validateRequests(const Model& model, const std::vector<Request>& request);
};
// Tag for the generated tests
diff --git a/neuralnetworks/1.2/vts/functional/Android.bp b/neuralnetworks/1.2/vts/functional/Android.bp
index 891b414..6c26820 100644
--- a/neuralnetworks/1.2/vts/functional/Android.bp
+++ b/neuralnetworks/1.2/vts/functional/Android.bp
@@ -20,6 +20,7 @@
defaults: ["VtsHalNeuralNetworksTargetTestDefaults"],
srcs: [
"GeneratedTestsV1_0.cpp",
+ "ValidateBurst.cpp",
],
cflags: [
"-DNN_TEST_DYNAMIC_OUTPUT_SHAPE"
@@ -32,6 +33,7 @@
defaults: ["VtsHalNeuralNetworksTargetTestDefaults"],
srcs: [
"GeneratedTestsV1_1.cpp",
+ "ValidateBurst.cpp",
],
cflags: [
"-DNN_TEST_DYNAMIC_OUTPUT_SHAPE"
@@ -46,6 +48,7 @@
"BasicTests.cpp",
"CompilationCachingTests.cpp",
"GeneratedTests.cpp",
+ "ValidateBurst.cpp",
],
cflags: [
"-DNN_TEST_DYNAMIC_OUTPUT_SHAPE"
@@ -58,6 +61,7 @@
srcs: [
"BasicTests.cpp",
"GeneratedTests.cpp",
+ "ValidateBurst.cpp",
],
cflags: [
"-DNN_TEST_DYNAMIC_OUTPUT_SHAPE",
diff --git a/neuralnetworks/1.2/vts/functional/ValidateBurst.cpp b/neuralnetworks/1.2/vts/functional/ValidateBurst.cpp
new file mode 100644
index 0000000..386c141
--- /dev/null
+++ b/neuralnetworks/1.2/vts/functional/ValidateBurst.cpp
@@ -0,0 +1,333 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "neuralnetworks_hidl_hal_test"
+
+#include "VtsHalNeuralnetworks.h"
+
+#include "Callbacks.h"
+#include "ExecutionBurstController.h"
+#include "ExecutionBurstServer.h"
+#include "TestHarness.h"
+#include "Utils.h"
+
+#include <android-base/logging.h>
+
+namespace android {
+namespace hardware {
+namespace neuralnetworks {
+namespace V1_2 {
+namespace vts {
+namespace functional {
+
+using ::android::nn::ExecutionBurstController;
+using ::android::nn::RequestChannelSender;
+using ::android::nn::ResultChannelReceiver;
+using ExecutionBurstCallback = ::android::nn::ExecutionBurstController::ExecutionBurstCallback;
+
+constexpr size_t kExecutionBurstChannelLength = 1024;
+constexpr size_t kExecutionBurstChannelSmallLength = 8;
+
+///////////////////////// UTILITY FUNCTIONS /////////////////////////
+
+static bool badTiming(Timing timing) {
+ return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX;
+}
+
+static void createBurst(const sp<IPreparedModel>& preparedModel, const sp<IBurstCallback>& callback,
+ std::unique_ptr<RequestChannelSender>* sender,
+ std::unique_ptr<ResultChannelReceiver>* receiver,
+ sp<IBurstContext>* context) {
+ ASSERT_NE(nullptr, preparedModel.get());
+ ASSERT_NE(nullptr, sender);
+ ASSERT_NE(nullptr, receiver);
+ ASSERT_NE(nullptr, context);
+
+ // create FMQ objects
+ auto [fmqRequestChannel, fmqRequestDescriptor] =
+ RequestChannelSender::create(kExecutionBurstChannelLength, /*blocking=*/true);
+ auto [fmqResultChannel, fmqResultDescriptor] =
+ ResultChannelReceiver::create(kExecutionBurstChannelLength, /*blocking=*/true);
+ ASSERT_NE(nullptr, fmqRequestChannel.get());
+ ASSERT_NE(nullptr, fmqResultChannel.get());
+ ASSERT_NE(nullptr, fmqRequestDescriptor);
+ ASSERT_NE(nullptr, fmqResultDescriptor);
+
+ // configure burst
+ ErrorStatus errorStatus;
+ sp<IBurstContext> burstContext;
+ const Return<void> ret = preparedModel->configureExecutionBurst(
+ callback, *fmqRequestDescriptor, *fmqResultDescriptor,
+ [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) {
+ errorStatus = status;
+ burstContext = context;
+ });
+ ASSERT_TRUE(ret.isOk());
+ ASSERT_EQ(ErrorStatus::NONE, errorStatus);
+ ASSERT_NE(nullptr, burstContext.get());
+
+ // return values
+ *sender = std::move(fmqRequestChannel);
+ *receiver = std::move(fmqResultChannel);
+ *context = burstContext;
+}
+
+static void createBurstWithResultChannelLength(
+ const sp<IPreparedModel>& preparedModel,
+ std::shared_ptr<ExecutionBurstController>* controller, size_t resultChannelLength) {
+ ASSERT_NE(nullptr, preparedModel.get());
+ ASSERT_NE(nullptr, controller);
+
+ // create FMQ objects
+ auto [fmqRequestChannel, fmqRequestDescriptor] =
+ RequestChannelSender::create(kExecutionBurstChannelLength, /*blocking=*/true);
+ auto [fmqResultChannel, fmqResultDescriptor] =
+ ResultChannelReceiver::create(resultChannelLength, /*blocking=*/true);
+ ASSERT_NE(nullptr, fmqRequestChannel.get());
+ ASSERT_NE(nullptr, fmqResultChannel.get());
+ ASSERT_NE(nullptr, fmqRequestDescriptor);
+ ASSERT_NE(nullptr, fmqResultDescriptor);
+
+ // configure burst
+ sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
+ ErrorStatus errorStatus;
+ sp<IBurstContext> burstContext;
+ const Return<void> ret = preparedModel->configureExecutionBurst(
+ callback, *fmqRequestDescriptor, *fmqResultDescriptor,
+ [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) {
+ errorStatus = status;
+ burstContext = context;
+ });
+ ASSERT_TRUE(ret.isOk());
+ ASSERT_EQ(ErrorStatus::NONE, errorStatus);
+ ASSERT_NE(nullptr, burstContext.get());
+
+ // return values
+ *controller = std::make_shared<ExecutionBurstController>(
+ std::move(fmqRequestChannel), std::move(fmqResultChannel), burstContext, callback);
+}
+
+// Primary validation function. This function will take a valid serialized
+// request, apply a mutation to it to invalidate the serialized request, then
+// pass it to interface calls that use the serialized request. Note that the
+// serialized request here is passed by value, and any mutation to the
+// serialized request does not leave this function.
+static void validate(RequestChannelSender* sender, ResultChannelReceiver* receiver,
+ const std::string& message, std::vector<FmqRequestDatum> serialized,
+ const std::function<void(std::vector<FmqRequestDatum>*)>& mutation) {
+ mutation(&serialized);
+
+ // skip if packet is too large to send
+ if (serialized.size() > kExecutionBurstChannelLength) {
+ return;
+ }
+
+ SCOPED_TRACE(message);
+
+ // send invalid packet
+ sender->sendPacket(serialized);
+
+ // receive error
+ auto results = receiver->getBlocking();
+ ASSERT_TRUE(results.has_value());
+ const auto [status, outputShapes, timing] = std::move(*results);
+ EXPECT_NE(ErrorStatus::NONE, status);
+ EXPECT_EQ(0u, outputShapes.size());
+ EXPECT_TRUE(badTiming(timing));
+}
+
+static std::vector<FmqRequestDatum> createUniqueDatum() {
+ const FmqRequestDatum::PacketInformation packetInformation = {
+ /*.packetSize=*/10, /*.numberOfInputOperands=*/10, /*.numberOfOutputOperands=*/10,
+ /*.numberOfPools=*/10};
+ const FmqRequestDatum::OperandInformation operandInformation = {
+ /*.hasNoValue=*/false, /*.location=*/{}, /*.numberOfDimensions=*/10};
+ const int32_t invalidPoolIdentifier = std::numeric_limits<int32_t>::max();
+ std::vector<FmqRequestDatum> unique(7);
+ unique[0].packetInformation(packetInformation);
+ unique[1].inputOperandInformation(operandInformation);
+ unique[2].inputOperandDimensionValue(0);
+ unique[3].outputOperandInformation(operandInformation);
+ unique[4].outputOperandDimensionValue(0);
+ unique[5].poolIdentifier(invalidPoolIdentifier);
+ unique[6].measureTiming(MeasureTiming::YES);
+ return unique;
+}
+
+static const std::vector<FmqRequestDatum>& getUniqueDatum() {
+ static const std::vector<FmqRequestDatum> unique = createUniqueDatum();
+ return unique;
+}
+
+///////////////////////// REMOVE DATUM ////////////////////////////////////
+
+static void removeDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
+ const std::vector<FmqRequestDatum>& serialized) {
+ for (size_t index = 0; index < serialized.size(); ++index) {
+ const std::string message = "removeDatum: removed datum at index " + std::to_string(index);
+ validate(sender, receiver, message, serialized,
+ [index](std::vector<FmqRequestDatum>* serialized) {
+ serialized->erase(serialized->begin() + index);
+ });
+ }
+}
+
+///////////////////////// ADD DATUM ////////////////////////////////////
+
+static void addDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
+ const std::vector<FmqRequestDatum>& serialized) {
+ const std::vector<FmqRequestDatum>& extra = getUniqueDatum();
+ for (size_t index = 0; index <= serialized.size(); ++index) {
+ for (size_t type = 0; type < extra.size(); ++type) {
+ const std::string message = "addDatum: added datum type " + std::to_string(type) +
+ " at index " + std::to_string(index);
+ validate(sender, receiver, message, serialized,
+ [index, type, &extra](std::vector<FmqRequestDatum>* serialized) {
+ serialized->insert(serialized->begin() + index, extra[type]);
+ });
+ }
+ }
+}
+
+///////////////////////// MUTATE DATUM ////////////////////////////////////
+
+static bool interestingCase(const FmqRequestDatum& lhs, const FmqRequestDatum& rhs) {
+ using Discriminator = FmqRequestDatum::hidl_discriminator;
+
+ const bool differentValues = (lhs != rhs);
+ const bool sameSumType = (lhs.getDiscriminator() == rhs.getDiscriminator());
+ const auto discriminator = rhs.getDiscriminator();
+ const bool isDimensionValue = (discriminator == Discriminator::inputOperandDimensionValue ||
+ discriminator == Discriminator::outputOperandDimensionValue);
+
+ return differentValues && !(sameSumType && isDimensionValue);
+}
+
+static void mutateDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
+ const std::vector<FmqRequestDatum>& serialized) {
+ const std::vector<FmqRequestDatum>& change = getUniqueDatum();
+ for (size_t index = 0; index < serialized.size(); ++index) {
+ for (size_t type = 0; type < change.size(); ++type) {
+ if (interestingCase(serialized[index], change[type])) {
+ const std::string message = "mutateDatum: changed datum at index " +
+ std::to_string(index) + " to datum type " +
+ std::to_string(type);
+ validate(sender, receiver, message, serialized,
+ [index, type, &change](std::vector<FmqRequestDatum>* serialized) {
+ (*serialized)[index] = change[type];
+ });
+ }
+ }
+ }
+}
+
+///////////////////////// BURST VALIATION TESTS ////////////////////////////////////
+
+static void validateBurstSerialization(const sp<IPreparedModel>& preparedModel,
+ const std::vector<Request>& requests) {
+ // create burst
+ std::unique_ptr<RequestChannelSender> sender;
+ std::unique_ptr<ResultChannelReceiver> receiver;
+ sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
+ sp<IBurstContext> context;
+ ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
+ ASSERT_NE(nullptr, sender.get());
+ ASSERT_NE(nullptr, receiver.get());
+ ASSERT_NE(nullptr, context.get());
+
+ // validate each request
+ for (const Request& request : requests) {
+ // load memory into callback slots
+ std::vector<intptr_t> keys(request.pools.size());
+ for (size_t i = 0; i < keys.size(); ++i) {
+ keys[i] = reinterpret_cast<intptr_t>(&request.pools[i]);
+ }
+ const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
+
+ // ensure slot std::numeric_limits<int32_t>::max() doesn't exist (for
+ // subsequent slot validation testing)
+ const auto maxElement = std::max_element(slots.begin(), slots.end());
+ ASSERT_NE(slots.end(), maxElement);
+ ASSERT_NE(std::numeric_limits<int32_t>::max(), *maxElement);
+
+ // serialize the request
+ const auto serialized = ::android::nn::serialize(request, MeasureTiming::YES, slots);
+
+ // validations
+ removeDatumTest(sender.get(), receiver.get(), serialized);
+ addDatumTest(sender.get(), receiver.get(), serialized);
+ mutateDatumTest(sender.get(), receiver.get(), serialized);
+ }
+}
+
+static void validateBurstFmqLength(const sp<IPreparedModel>& preparedModel,
+ const std::vector<Request>& requests) {
+ // create regular burst
+ std::shared_ptr<ExecutionBurstController> controllerRegular;
+ ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(preparedModel, &controllerRegular,
+ kExecutionBurstChannelLength));
+ ASSERT_NE(nullptr, controllerRegular.get());
+
+ // create burst with small output channel
+ std::shared_ptr<ExecutionBurstController> controllerSmall;
+ ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(preparedModel, &controllerSmall,
+ kExecutionBurstChannelSmallLength));
+ ASSERT_NE(nullptr, controllerSmall.get());
+
+ // validate each request
+ for (const Request& request : requests) {
+ // load memory into callback slots
+ std::vector<intptr_t> keys(request.pools.size());
+ for (size_t i = 0; i < keys.size(); ++i) {
+ keys[i] = reinterpret_cast<intptr_t>(&request.pools[i]);
+ }
+
+ // collect serialized result by running regular burst
+ const auto [status1, outputShapes1, timing1] =
+ controllerRegular->compute(request, MeasureTiming::NO, keys);
+
+ // skip test if synchronous output isn't useful
+ const std::vector<FmqResultDatum> serialized =
+ ::android::nn::serialize(status1, outputShapes1, timing1);
+ if (status1 != ErrorStatus::NONE ||
+ serialized.size() <= kExecutionBurstChannelSmallLength) {
+ continue;
+ }
+
+ // by this point, execution should fail because the result channel isn't
+ // large enough to return the serialized result
+ const auto [status2, outputShapes2, timing2] =
+ controllerSmall->compute(request, MeasureTiming::NO, keys);
+ EXPECT_NE(ErrorStatus::NONE, status2);
+ EXPECT_EQ(0u, outputShapes2.size());
+ EXPECT_TRUE(badTiming(timing2));
+ }
+}
+
+///////////////////////////// ENTRY POINT //////////////////////////////////
+
+void ValidationTest::validateBurst(const sp<IPreparedModel>& preparedModel,
+ const std::vector<Request>& requests) {
+ ASSERT_NO_FATAL_FAILURE(validateBurstSerialization(preparedModel, requests));
+ ASSERT_NO_FATAL_FAILURE(validateBurstFmqLength(preparedModel, requests));
+}
+
+} // namespace functional
+} // namespace vts
+} // namespace V1_2
+} // namespace neuralnetworks
+} // namespace hardware
+} // namespace android
diff --git a/neuralnetworks/1.2/vts/functional/ValidateRequest.cpp b/neuralnetworks/1.2/vts/functional/ValidateRequest.cpp
index 870d017..9703c2d 100644
--- a/neuralnetworks/1.2/vts/functional/ValidateRequest.cpp
+++ b/neuralnetworks/1.2/vts/functional/ValidateRequest.cpp
@@ -35,9 +35,7 @@
namespace functional {
using ::android::hardware::neuralnetworks::V1_2::implementation::ExecutionCallback;
-using ::android::hardware::neuralnetworks::V1_2::implementation::PreparedModelCallback;
using ::android::hidl::memory::V1_0::IMemory;
-using HidlToken = hidl_array<uint8_t, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>;
using test_helper::for_all;
using test_helper::MixedTyped;
using test_helper::MixedTypedExample;
@@ -48,55 +46,6 @@
return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX;
}
-static void createPreparedModel(const sp<IDevice>& device, const Model& model,
- sp<IPreparedModel>* preparedModel) {
- ASSERT_NE(nullptr, preparedModel);
-
- // see if service can handle model
- bool fullySupportsModel = false;
- Return<void> supportedOpsLaunchStatus = device->getSupportedOperations_1_2(
- model, [&fullySupportsModel](ErrorStatus status, const hidl_vec<bool>& supported) {
- ASSERT_EQ(ErrorStatus::NONE, status);
- ASSERT_NE(0ul, supported.size());
- fullySupportsModel =
- std::all_of(supported.begin(), supported.end(), [](bool valid) { return valid; });
- });
- ASSERT_TRUE(supportedOpsLaunchStatus.isOk());
-
- // launch prepare model
- sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
- ASSERT_NE(nullptr, preparedModelCallback.get());
- Return<ErrorStatus> prepareLaunchStatus = device->prepareModel_1_2(
- model, ExecutionPreference::FAST_SINGLE_ANSWER, hidl_vec<hidl_handle>(),
- hidl_vec<hidl_handle>(), HidlToken(), preparedModelCallback);
- ASSERT_TRUE(prepareLaunchStatus.isOk());
- ASSERT_EQ(ErrorStatus::NONE, static_cast<ErrorStatus>(prepareLaunchStatus));
-
- // retrieve prepared model
- preparedModelCallback->wait();
- ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
- *preparedModel = getPreparedModel_1_2(preparedModelCallback);
-
- // The getSupportedOperations_1_2 call returns a list of operations that are
- // guaranteed not to fail if prepareModel_1_2 is called, and
- // 'fullySupportsModel' is true i.f.f. the entire model is guaranteed.
- // If a driver has any doubt that it can prepare an operation, it must
- // return false. So here, if a driver isn't sure if it can support an
- // operation, but reports that it successfully prepared the model, the test
- // can continue.
- if (!fullySupportsModel && prepareReturnStatus != ErrorStatus::NONE) {
- ASSERT_EQ(nullptr, preparedModel->get());
- LOG(INFO) << "NN VTS: Unable to test Request validation because vendor service cannot "
- "prepare model that it does not support.";
- std::cout << "[ ] Unable to test Request validation because vendor service "
- "cannot prepare model that it does not support."
- << std::endl;
- return;
- }
- ASSERT_EQ(ErrorStatus::NONE, prepareReturnStatus);
- ASSERT_NE(nullptr, preparedModel->get());
-}
-
// Primary validation function. This function will take a valid request, apply a
// mutation to it to invalidate the request, then pass it to interface calls
// that use the request. Note that the request here is passed by value, and any
@@ -316,14 +265,8 @@
return requests;
}
-void ValidationTest::validateRequests(const Model& model, const std::vector<Request>& requests) {
- // create IPreparedModel
- sp<IPreparedModel> preparedModel;
- ASSERT_NO_FATAL_FAILURE(createPreparedModel(device, model, &preparedModel));
- if (preparedModel == nullptr) {
- return;
- }
-
+void ValidationTest::validateRequests(const sp<IPreparedModel>& preparedModel,
+ const std::vector<Request>& requests) {
// validate each request
for (const Request& request : requests) {
removeInputTest(preparedModel, request);
diff --git a/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.cpp b/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.cpp
index 4728c28..93182f1 100644
--- a/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.cpp
+++ b/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.cpp
@@ -18,6 +18,10 @@
#include "VtsHalNeuralnetworks.h"
+#include <android-base/logging.h>
+
+#include "Callbacks.h"
+
namespace android {
namespace hardware {
namespace neuralnetworks {
@@ -25,6 +29,61 @@
namespace vts {
namespace functional {
+using ::android::hardware::neuralnetworks::V1_2::implementation::ExecutionCallback;
+using ::android::hardware::neuralnetworks::V1_2::implementation::PreparedModelCallback;
+using HidlToken = hidl_array<uint8_t, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>;
+using V1_1::ExecutionPreference;
+
+// internal helper function
+static void createPreparedModel(const sp<IDevice>& device, const Model& model,
+ sp<IPreparedModel>* preparedModel) {
+ ASSERT_NE(nullptr, preparedModel);
+
+ // see if service can handle model
+ bool fullySupportsModel = false;
+ Return<void> supportedOpsLaunchStatus = device->getSupportedOperations_1_2(
+ model, [&fullySupportsModel](ErrorStatus status, const hidl_vec<bool>& supported) {
+ ASSERT_EQ(ErrorStatus::NONE, status);
+ ASSERT_NE(0ul, supported.size());
+ fullySupportsModel = std::all_of(supported.begin(), supported.end(),
+ [](bool valid) { return valid; });
+ });
+ ASSERT_TRUE(supportedOpsLaunchStatus.isOk());
+
+ // launch prepare model
+ sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
+ ASSERT_NE(nullptr, preparedModelCallback.get());
+ Return<ErrorStatus> prepareLaunchStatus = device->prepareModel_1_2(
+ model, ExecutionPreference::FAST_SINGLE_ANSWER, hidl_vec<hidl_handle>(),
+ hidl_vec<hidl_handle>(), HidlToken(), preparedModelCallback);
+ ASSERT_TRUE(prepareLaunchStatus.isOk());
+ ASSERT_EQ(ErrorStatus::NONE, static_cast<ErrorStatus>(prepareLaunchStatus));
+
+ // retrieve prepared model
+ preparedModelCallback->wait();
+ ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
+ *preparedModel = getPreparedModel_1_2(preparedModelCallback);
+
+ // The getSupportedOperations_1_2 call returns a list of operations that are
+ // guaranteed not to fail if prepareModel_1_2 is called, and
+ // 'fullySupportsModel' is true i.f.f. the entire model is guaranteed.
+ // If a driver has any doubt that it can prepare an operation, it must
+ // return false. So here, if a driver isn't sure if it can support an
+ // operation, but reports that it successfully prepared the model, the test
+ // can continue.
+ if (!fullySupportsModel && prepareReturnStatus != ErrorStatus::NONE) {
+ ASSERT_EQ(nullptr, preparedModel->get());
+ LOG(INFO) << "NN VTS: Unable to test Request validation because vendor service cannot "
+ "prepare model that it does not support.";
+ std::cout << "[ ] Unable to test Request validation because vendor service "
+ "cannot prepare model that it does not support."
+ << std::endl;
+ return;
+ }
+ ASSERT_EQ(ErrorStatus::NONE, prepareReturnStatus);
+ ASSERT_NE(nullptr, preparedModel->get());
+}
+
// A class for test environment setup
NeuralnetworksHidlEnvironment::NeuralnetworksHidlEnvironment() {}
@@ -68,6 +127,20 @@
::testing::VtsHalHidlTargetTestBase::TearDown();
}
+void ValidationTest::validateEverything(const Model& model, const std::vector<Request>& request) {
+ validateModel(model);
+
+ // create IPreparedModel
+ sp<IPreparedModel> preparedModel;
+ ASSERT_NO_FATAL_FAILURE(createPreparedModel(device, model, &preparedModel));
+ if (preparedModel == nullptr) {
+ return;
+ }
+
+ validateRequests(preparedModel, request);
+ validateBurst(preparedModel, request);
+}
+
sp<IPreparedModel> getPreparedModel_1_2(
const sp<V1_2::implementation::PreparedModelCallback>& callback) {
sp<V1_0::IPreparedModel> preparedModelV1_0 = callback->getPreparedModel();
diff --git a/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.h b/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.h
index 404eec0..36e73a4 100644
--- a/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.h
+++ b/neuralnetworks/1.2/vts/functional/VtsHalNeuralnetworks.h
@@ -72,8 +72,14 @@
// Tag for the validation tests
class ValidationTest : public NeuralnetworksHidlTest {
protected:
- void validateModel(const Model& model);
- void validateRequests(const Model& model, const std::vector<Request>& request);
+ void validateEverything(const Model& model, const std::vector<Request>& request);
+
+ private:
+ void validateModel(const Model& model);
+ void validateRequests(const sp<IPreparedModel>& preparedModel,
+ const std::vector<Request>& requests);
+ void validateBurst(const sp<IPreparedModel>& preparedModel,
+ const std::vector<Request>& requests);
};
// Tag for the generated tests