Add Burst tests to NN AIDL HAL VTS

Bug: 180492058
Bug: 177267324
Test: mma
Test: VtsHalNeuralnetworksTargetTest
Change-Id: I1744005cbf750b70b42367b81a2fa6b8f24c1904
Merged-In: I1744005cbf750b70b42367b81a2fa6b8f24c1904
(cherry picked from commit 8b7e8138685678c1e7b1d7de8b06ff0899c61b2d)
diff --git a/neuralnetworks/aidl/vts/functional/ValidateRequest.cpp b/neuralnetworks/aidl/vts/functional/ValidateRequest.cpp
index 3be4c1b..29e2471 100644
--- a/neuralnetworks/aidl/vts/functional/ValidateRequest.cpp
+++ b/neuralnetworks/aidl/vts/functional/ValidateRequest.cpp
@@ -16,7 +16,9 @@
 
 #define LOG_TAG "neuralnetworks_aidl_hal_test"
 
+#include <aidl/android/hardware/neuralnetworks/RequestMemoryPool.h>
 #include <android/binder_auto_utils.h>
+#include <variant>
 
 #include <chrono>
 
@@ -77,6 +79,35 @@
         ASSERT_EQ(static_cast<ErrorStatus>(executeStatus.getServiceSpecificError()),
                   ErrorStatus::INVALID_ARGUMENT);
     }
+
+    // burst
+    {
+        SCOPED_TRACE(message + " [burst]");
+
+        // create burst
+        std::shared_ptr<IBurst> burst;
+        auto ret = preparedModel->configureExecutionBurst(&burst);
+        ASSERT_TRUE(ret.isOk()) << ret.getDescription();
+        ASSERT_NE(nullptr, burst.get());
+
+        // use -1 for all memory identifier tokens
+        const std::vector<int64_t> slots(request.pools.size(), -1);
+
+        ExecutionResult executionResult;
+        const auto executeStatus = burst->executeSynchronously(
+                request, slots, measure, kNoDeadline, kOmittedTimeoutDuration, &executionResult);
+        ASSERT_FALSE(executeStatus.isOk());
+        ASSERT_EQ(executeStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
+        ASSERT_EQ(static_cast<ErrorStatus>(executeStatus.getServiceSpecificError()),
+                  ErrorStatus::INVALID_ARGUMENT);
+    }
+}
+
+std::shared_ptr<IBurst> createBurst(const std::shared_ptr<IPreparedModel>& preparedModel) {
+    // Returns the burst produced by a successful call; any non-OK status yields nullptr.
+    std::shared_ptr<IBurst> result;
+    if (preparedModel->configureExecutionBurst(&result).isOk()) return result;
+    return nullptr;
 }
 
 ///////////////////////// REMOVE INPUT ////////////////////////////////////
@@ -110,6 +141,65 @@
     removeOutputTest(preparedModel, request);
 }
 
+void validateBurst(const std::shared_ptr<IPreparedModel>& preparedModel, const Request& request) {
+    // Checks that burst calls with malformed memory identifier tokens fail with INVALID_ARGUMENT.
+    std::shared_ptr<IBurst> burst;
+    auto ret = preparedModel->configureExecutionBurst(&burst);
+    ASSERT_TRUE(ret.isOk()) << ret.getDescription();
+    ASSERT_NE(nullptr, burst.get());
+
+    // Executes `request` with the given tokens; a failing ASSERT_* only exits this lambda.
+    const auto test = [&burst, &request](const std::vector<int64_t>& slots) {
+        ExecutionResult executionResult;
+        const auto executeStatus =
+                burst->executeSynchronously(request, slots, /*measure=*/false, kNoDeadline,
+                                            kOmittedTimeoutDuration, &executionResult);
+        ASSERT_FALSE(executeStatus.isOk());
+        ASSERT_EQ(executeStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
+        ASSERT_EQ(static_cast<ErrorStatus>(executeStatus.getServiceSpecificError()),
+                  ErrorStatus::INVALID_ARGUMENT);
+    };
+
+    // Tag::pool entries get distinct non-negative tokens; every other pool uses -1.
+    int64_t currentSlot = 0;
+    std::vector<int64_t> slots;
+    slots.reserve(request.pools.size());
+    for (const auto& pool : request.pools) {
+        const bool isMemoryPool = pool.getTag() == RequestMemoryPool::Tag::pool;
+        slots.push_back(isMemoryPool ? currentSlot++ : -1);
+    }
+
+    constexpr int64_t invalidSlot = -2;
+
+    // validate failure when one memory identifier token has an invalid value
+    for (size_t i = 0; i < request.pools.size(); ++i) {
+        SCOPED_TRACE(::testing::Message() << "invalid token at pool index " << i);
+        const int64_t oldSlotValue = slots[i];
+        slots[i] = invalidSlot;
+        test(slots);
+        slots[i] = oldSlotValue;
+    }
+
+    // validate failure when there are fewer memory identifier tokens than request.pools
+    if (!request.pools.empty()) {
+        SCOPED_TRACE("fewer memory identifier tokens than request.pools");
+        test(std::vector<int64_t>(request.pools.size() - 1, -1));
+    }
+
+    // validate failure when there are more memory identifier tokens than request.pools
+    {
+        SCOPED_TRACE("more memory identifier tokens than request.pools");
+        test(std::vector<int64_t>(request.pools.size() + 1, -1));
+    }
+
+    // validate failure when releasing an invalid memory identifier token
+    const auto freeStatus = burst->releaseMemoryResource(invalidSlot);
+    ASSERT_FALSE(freeStatus.isOk());
+    ASSERT_EQ(freeStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
+    ASSERT_EQ(static_cast<ErrorStatus>(freeStatus.getServiceSpecificError()),
+              ErrorStatus::INVALID_ARGUMENT);
+}
+
 void validateRequestFailure(const std::shared_ptr<IPreparedModel>& preparedModel,
                             const Request& request) {
     SCOPED_TRACE("Expecting request to fail [executeSynchronously]");