Add HAL interface for compilation and execution hints
The following AIDL types are added:
- TokenValuePair
- PrepareModelConfig
- ExecutionConfig
The following AIDL methods are added:
- IDevice::prepareModelWithConfig
- IPreparedModel::executeSynchronouslyWithConfig
- IPreparedModel::executeFencedWithConfig
- IBurst::executeSynchronouslyWithConfig
The compilation and execution hints are stored as a list of
token-value pairs as part of the PrepareModelConfig / ExecutionConfig.
The PrepareModelConfig / ExecutionConfig parcelables are introduced so
that future extensions to the compilation- and execution-related
interfaces are easier to make.
It is the driver's responsibility to validate the hints, and the
driver is allowed to ignore them.
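
As an illustrative sketch (not part of this change), a client of the
canonical nn::IDevice interface would attach hints as shown below. The
token, value, and extension name are hypothetical, and model,
preference, priority, deadline, modelCache, dataCache, and token are
assumed to be in scope:

  // Hypothetical vendor hint: token 0 with a single-byte payload.
  const std::vector<nn::TokenValuePair> hints = {
          nn::TokenValuePair{.token = 0, .value = {1}}};
  // Maps the hint token's extension prefix to a vendor extension name.
  const std::vector<nn::ExtensionNameAndPrefix> extensionNameToPrefix = {
          nn::ExtensionNameAndPrefix{.name = "com.android.nn_test", .prefix = 1}};
  // The two new trailing parameters carry the hints to the driver.
  const auto preparedModel = device->prepareModel(
          model, preference, priority, deadline, modelCache, dataCache,
          token, hints, extensionNameToPrefix);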
Bug: 203248587
Test: neuralnetworks_utils_hal_aidl_test
Change-Id: I98240fd75089fc85cdfcaa0be28aab8a6f0dfca5
Merged-In: I98240fd75089fc85cdfcaa0be28aab8a6f0dfca5
(cherry picked from commit 0e671f3edb9d2c78658a4ef4169e3211e3f9bb00)
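
The utils code below dispatches on the reported feature level: the new
*WithConfig methods are called only when the driver is at
FEATURE_LEVEL_8 or above; otherwise the pre-existing methods are used
and the hints are dropped. A condensed sketch of that pattern,
mirroring Burst::executeInternal from the diff (aidlHints and
aidlExtensionPrefix are the already-converted AIDL vectors):

  if (kFeatureLevel.level >= nn::Version::Level::FEATURE_LEVEL_8) {
      // Bundle the existing arguments plus the new hints into an
      // ExecutionConfig and call the new method.
      const auto ret = kBurst->executeSynchronouslyWithConfig(
              request, memoryIdentifierTokens,
              {measure, loopTimeoutDuration, std::move(aidlHints),
               std::move(aidlExtensionPrefix)},
              deadline, &executionResult);
      HANDLE_ASTATUS(ret) << "execute failed";
  } else {
      // Older drivers keep receiving the original call.
      const auto ret = kBurst->executeSynchronously(
              request, memoryIdentifierTokens, measure, deadline,
              loopTimeoutDuration, &executionResult);
      HANDLE_ASTATUS(ret) << "execute failed";
  }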
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Burst.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Burst.h
index 0cc78d4..f2e6e75 100644
--- a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Burst.h
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Burst.h
@@ -86,10 +86,12 @@
GUARDED_BY(mMutex);
};
+ // featureLevel is for testing purposes.
static nn::GeneralResult<std::shared_ptr<const Burst>> create(
- std::shared_ptr<aidl_hal::IBurst> burst);
+ std::shared_ptr<aidl_hal::IBurst> burst, nn::Version featureLevel);
- Burst(PrivateConstructorTag tag, std::shared_ptr<aidl_hal::IBurst> burst);
+ Burst(PrivateConstructorTag tag, std::shared_ptr<aidl_hal::IBurst> burst,
+ nn::Version featureLevel);
// See IBurst::cacheMemory for information.
OptionalCacheHold cacheMemory(const nn::SharedMemory& memory) const override;
@@ -97,23 +99,29 @@
// See IBurst::execute for information.
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
- const nn::OptionalTimePoint& deadline,
- const nn::OptionalDuration& loopTimeoutDuration) const override;
+ const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
// See IBurst::createReusableExecution for information.
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
- const nn::OptionalDuration& loopTimeoutDuration) const override;
+ const nn::OptionalDuration& loopTimeoutDuration,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> executeInternal(
const aidl_hal::Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
bool measure, int64_t deadline, int64_t loopTimeoutDuration,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
const hal::utils::RequestRelocation& relocation) const;
private:
mutable std::atomic_flag mExecutionInFlight = ATOMIC_FLAG_INIT;
const std::shared_ptr<aidl_hal::IBurst> kBurst;
const std::shared_ptr<MemoryCache> kMemoryCache;
+ const nn::Version kFeatureLevel;
};
} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Conversions.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Conversions.h
index 477b311..af58715 100644
--- a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Conversions.h
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Conversions.h
@@ -46,6 +46,10 @@
#include <aidl/android/hardware/neuralnetworks/SymmPerChannelQuantParams.h>
#include <aidl/android/hardware/neuralnetworks/Timing.h>
+#ifdef NN_AIDL_V4_OR_ABOVE
+#include <aidl/android/hardware/neuralnetworks/TokenValuePair.h>
+#endif // NN_AIDL_V4_OR_ABOVE
+
#include <android/binder_auto_utils.h>
#include <nnapi/Result.h>
#include <nnapi/Types.h>
@@ -74,7 +78,7 @@
const aidl_hal::SymmPerChannelQuantParams& symmPerChannelQuantParams);
GeneralResult<Operation> unvalidatedConvert(const aidl_hal::Operation& operation);
GeneralResult<Model> unvalidatedConvert(const aidl_hal::Model& model);
-GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
+GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
const aidl_hal::ExtensionNameAndPrefix& extensionNameAndPrefix);
GeneralResult<Model::OperandValues> unvalidatedConvert(const std::vector<uint8_t>& operandValues);
GeneralResult<Model::Subgraph> unvalidatedConvert(const aidl_hal::Subgraph& subgraph);
@@ -97,6 +101,10 @@
const aidl_hal::ExtensionOperandTypeInformation& operandTypeInformation);
GeneralResult<SharedHandle> unvalidatedConvert(const ndk::ScopedFileDescriptor& handle);
+#ifdef NN_AIDL_V4_OR_ABOVE
+GeneralResult<TokenValuePair> unvalidatedConvert(const aidl_hal::TokenValuePair& tokenValuePair);
+#endif // NN_AIDL_V4_OR_ABOVE
+
GeneralResult<std::vector<Operation>> unvalidatedConvert(
const std::vector<aidl_hal::Operation>& operations);
@@ -116,6 +124,14 @@
GeneralResult<std::vector<Extension>> convert(const std::vector<aidl_hal::Extension>& extension);
GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories);
+GeneralResult<std::vector<ExtensionNameAndPrefix>> convert(
+ const std::vector<aidl_hal::ExtensionNameAndPrefix>& extensionNameAndPrefix);
+
+#ifdef NN_AIDL_V4_OR_ABOVE
+GeneralResult<std::vector<TokenValuePair>> convert(
+ const std::vector<aidl_hal::TokenValuePair>& metaData);
+#endif // NN_AIDL_V4_OR_ABOVE
+
GeneralResult<std::vector<OutputShape>> convert(
const std::vector<aidl_hal::OutputShape>& outputShapes);
GeneralResult<std::vector<SharedHandle>> convert(
@@ -152,7 +168,7 @@
nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(
const nn::Model::OperandValues& operandValues);
nn::GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
- const nn::Model::ExtensionNameAndPrefix& extensionNameToPrefix);
+ const nn::ExtensionNameAndPrefix& extensionNameToPrefix);
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model);
nn::GeneralResult<Priority> unvalidatedConvert(const nn::Priority& priority);
nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request);
@@ -166,6 +182,10 @@
nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capabilities);
nn::GeneralResult<Extension> unvalidatedConvert(const nn::Extension& extension);
+#ifdef NN_AIDL_V4_OR_ABOVE
+nn::GeneralResult<TokenValuePair> unvalidatedConvert(const nn::TokenValuePair& tokenValuePair);
+#endif // NN_AIDL_V4_OR_ABOVE
+
nn::GeneralResult<std::vector<uint8_t>> convert(const nn::CacheToken& cacheToken);
nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc);
nn::GeneralResult<DeviceType> convert(const nn::DeviceType& deviceType);
@@ -190,6 +210,13 @@
nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
const std::vector<nn::SyncFence>& syncFences);
nn::GeneralResult<std::vector<Extension>> convert(const std::vector<nn::Extension>& extensions);
+nn::GeneralResult<std::vector<ExtensionNameAndPrefix>> convert(
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix);
+
+#ifdef NN_AIDL_V4_OR_ABOVE
+nn::GeneralResult<std::vector<TokenValuePair>> convert(
+ const std::vector<nn::TokenValuePair>& metaData);
+#endif // NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<std::vector<int32_t>> toSigned(const std::vector<uint32_t>& vec);
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Device.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Device.h
index d558f66..615c6de 100644
--- a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Device.h
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Device.h
@@ -42,6 +42,7 @@
struct PrivateConstructorTag {};
public:
+ // featureLevel is for testing purposes.
static nn::GeneralResult<std::shared_ptr<const Device>> create(
std::string name, std::shared_ptr<aidl_hal::IDevice> device, nn::Version featureLevel);
@@ -67,8 +68,9 @@
nn::GeneralResult<nn::SharedPreparedModel> prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
- const std::vector<nn::SharedHandle>& dataCache,
- const nn::CacheToken& token) const override;
+ const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCache(
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/HalInterfaces.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/HalInterfaces.h
index 205d428..cacdc26 100644
--- a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/HalInterfaces.h
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/HalInterfaces.h
@@ -63,7 +63,9 @@
#ifdef NN_AIDL_V4_OR_ABOVE
#include <aidl/android/hardware/neuralnetworks/BnExecution.h>
+#include <aidl/android/hardware/neuralnetworks/ExecutionConfig.h>
#include <aidl/android/hardware/neuralnetworks/IExecution.h>
+#include <aidl/android/hardware/neuralnetworks/PrepareModelConfig.h>
#endif // NN_AIDL_V4_OR_ABOVE
namespace android::nn {
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/InvalidDevice.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/InvalidDevice.h
index e66507a..9375c1d 100644
--- a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/InvalidDevice.h
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/InvalidDevice.h
@@ -53,6 +53,9 @@
const std::vector<ndk::ScopedFileDescriptor>& dataCache,
const std::vector<uint8_t>& token,
const std::shared_ptr<IPreparedModelCallback>& callback) override;
+ ndk::ScopedAStatus prepareModelWithConfig(
+ const Model& model, const PrepareModelConfig& config,
+ const std::shared_ptr<IPreparedModelCallback>& callback) override;
ndk::ScopedAStatus prepareModelFromCache(
int64_t deadline, const std::vector<ndk::ScopedFileDescriptor>& modelCache,
const std::vector<ndk::ScopedFileDescriptor>& dataCache,
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/PreparedModel.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/PreparedModel.h
index 24cd681..cb6a85b 100644
--- a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/PreparedModel.h
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/PreparedModel.h
@@ -40,6 +40,7 @@
struct PrivateConstructorTag {};
public:
+ // featureLevel is for testing purposes.
static nn::GeneralResult<std::shared_ptr<const PreparedModel>> create(
std::shared_ptr<aidl_hal::IPreparedModel> preparedModel, nn::Version featureLevel);
@@ -49,18 +50,23 @@
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
const nn::Request& request, nn::MeasureTiming measure,
- const nn::OptionalTimePoint& deadline,
- const nn::OptionalDuration& loopTimeoutDuration) const override;
+ const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> executeFenced(
const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
const nn::OptionalDuration& loopTimeoutDuration,
- const nn::OptionalDuration& timeoutDurationAfterFence) const override;
+ const nn::OptionalDuration& timeoutDurationAfterFence,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedExecution> createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
- const nn::OptionalDuration& loopTimeoutDuration) const override;
+ const nn::OptionalDuration& loopTimeoutDuration,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
nn::GeneralResult<nn::SharedBurst> configureExecutionBurst() const override;
@@ -68,6 +74,8 @@
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> executeInternal(
const Request& request, bool measure, int64_t deadline, int64_t loopTimeoutDuration,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
const hal::utils::RequestRelocation& relocation) const;
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
@@ -75,6 +83,8 @@
const std::vector<ndk::ScopedFileDescriptor>& waitFor, bool measure,
int64_t deadline, int64_t loopTimeoutDuration,
int64_t timeoutDurationAfterFence,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
const hal::utils::RequestRelocation& relocation) const;
private:
diff --git a/neuralnetworks/aidl/utils/src/Burst.cpp b/neuralnetworks/aidl/utils/src/Burst.cpp
index fb00b26..6c7aa88 100644
--- a/neuralnetworks/aidl/utils/src/Burst.cpp
+++ b/neuralnetworks/aidl/utils/src/Burst.cpp
@@ -43,12 +43,16 @@
static nn::GeneralResult<std::shared_ptr<const BurstExecution>> create(
std::shared_ptr<const Burst> burst, Request request,
std::vector<int64_t> memoryIdentifierTokens, bool measure, int64_t loopTimeoutDuration,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
hal::utils::RequestRelocation relocation,
std::vector<Burst::OptionalCacheHold> cacheHolds);
BurstExecution(PrivateConstructorTag tag, std::shared_ptr<const Burst> burst, Request request,
std::vector<int64_t> memoryIdentifierTokens, bool measure,
- int64_t loopTimeoutDuration, hal::utils::RequestRelocation relocation,
+ int64_t loopTimeoutDuration, const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
+ hal::utils::RequestRelocation relocation,
std::vector<Burst::OptionalCacheHold> cacheHolds);
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> compute(
@@ -64,6 +68,8 @@
const std::vector<int64_t> kMemoryIdentifierTokens;
const bool kMeasure;
const int64_t kLoopTimeoutDuration;
+ const std::vector<nn::TokenValuePair> kHints;
+ const std::vector<nn::ExtensionNameAndPrefix> kExtensionNameToPrefix;
const hal::utils::RequestRelocation kRelocation;
const std::vector<Burst::OptionalCacheHold> kCacheHolds;
};
@@ -149,17 +155,20 @@
}
nn::GeneralResult<std::shared_ptr<const Burst>> Burst::create(
- std::shared_ptr<aidl_hal::IBurst> burst) {
+ std::shared_ptr<aidl_hal::IBurst> burst, nn::Version featureLevel) {
if (burst == nullptr) {
return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
<< "aidl_hal::utils::Burst::create must have non-null burst";
}
- return std::make_shared<const Burst>(PrivateConstructorTag{}, std::move(burst));
+ return std::make_shared<const Burst>(PrivateConstructorTag{}, std::move(burst), featureLevel);
}
-Burst::Burst(PrivateConstructorTag /*tag*/, std::shared_ptr<aidl_hal::IBurst> burst)
- : kBurst(std::move(burst)), kMemoryCache(std::make_shared<MemoryCache>(kBurst)) {
+Burst::Burst(PrivateConstructorTag /*tag*/, std::shared_ptr<aidl_hal::IBurst> burst,
+ nn::Version featureLevel)
+ : kBurst(std::move(burst)),
+ kMemoryCache(std::make_shared<MemoryCache>(kBurst)),
+ kFeatureLevel(featureLevel) {
CHECK(kBurst != nullptr);
}
@@ -170,8 +179,9 @@
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::execute(
const nn::Request& request, nn::MeasureTiming measure,
- const nn::OptionalTimePoint& deadline,
- const nn::OptionalDuration& loopTimeoutDuration) const {
+ const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -200,14 +210,14 @@
memoryIdentifierTokens.push_back(-1);
}
CHECK_EQ(requestInShared.pools.size(), memoryIdentifierTokens.size());
-
return executeInternal(aidlRequest, memoryIdentifierTokens, aidlMeasure, aidlDeadline,
- aidlLoopTimeoutDuration, relocation);
+ aidlLoopTimeoutDuration, hints, extensionNameToPrefix, relocation);
}
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> Burst::executeInternal(
const Request& request, const std::vector<int64_t>& memoryIdentifierTokens, bool measure,
- int64_t deadline, int64_t loopTimeoutDuration,
+ int64_t deadline, int64_t loopTimeoutDuration, const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
const hal::utils::RequestRelocation& relocation) const {
// Ensure that at most one execution is in flight at any given time.
const bool alreadyInFlight = mExecutionInFlight.test_and_set();
@@ -221,9 +231,21 @@
}
ExecutionResult executionResult;
- const auto ret = kBurst->executeSynchronously(request, memoryIdentifierTokens, measure,
- deadline, loopTimeoutDuration, &executionResult);
- HANDLE_ASTATUS(ret) << "execute failed";
+ if (kFeatureLevel.level >= nn::Version::Level::FEATURE_LEVEL_8) {
+ auto aidlHints = NN_TRY(convert(hints));
+ auto aidlExtensionPrefix = NN_TRY(convert(extensionNameToPrefix));
+ const auto ret = kBurst->executeSynchronouslyWithConfig(
+ request, memoryIdentifierTokens,
+ {measure, loopTimeoutDuration, std::move(aidlHints),
+ std::move(aidlExtensionPrefix)},
+ deadline, &executionResult);
+ HANDLE_ASTATUS(ret) << "execute failed";
+ } else {
+ const auto ret =
+ kBurst->executeSynchronously(request, memoryIdentifierTokens, measure, deadline,
+ loopTimeoutDuration, &executionResult);
+ HANDLE_ASTATUS(ret) << "execute failed";
+ }
if (!executionResult.outputSufficientSize) {
auto canonicalOutputShapes =
nn::convert(executionResult.outputShapes).value_or(std::vector<nn::OutputShape>{});
@@ -241,7 +263,9 @@
nn::GeneralResult<nn::SharedExecution> Burst::createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
- const nn::OptionalDuration& loopTimeoutDuration) const {
+ const nn::OptionalDuration& loopTimeoutDuration,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -272,12 +296,15 @@
return BurstExecution::create(shared_from_this(), std::move(aidlRequest),
std::move(memoryIdentifierTokens), aidlMeasure,
- aidlLoopTimeoutDuration, std::move(relocation), std::move(holds));
+ aidlLoopTimeoutDuration, hints, extensionNameToPrefix,
+ std::move(relocation), std::move(holds));
}
nn::GeneralResult<std::shared_ptr<const BurstExecution>> BurstExecution::create(
std::shared_ptr<const Burst> burst, Request request,
std::vector<int64_t> memoryIdentifierTokens, bool measure, int64_t loopTimeoutDuration,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
hal::utils::RequestRelocation relocation,
std::vector<Burst::OptionalCacheHold> cacheHolds) {
if (burst == nullptr) {
@@ -286,13 +313,15 @@
return std::make_shared<const BurstExecution>(
PrivateConstructorTag{}, std::move(burst), std::move(request),
- std::move(memoryIdentifierTokens), measure, loopTimeoutDuration, std::move(relocation),
- std::move(cacheHolds));
+ std::move(memoryIdentifierTokens), measure, loopTimeoutDuration, hints,
+ extensionNameToPrefix, std::move(relocation), std::move(cacheHolds));
}
BurstExecution::BurstExecution(PrivateConstructorTag /*tag*/, std::shared_ptr<const Burst> burst,
Request request, std::vector<int64_t> memoryIdentifierTokens,
bool measure, int64_t loopTimeoutDuration,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
hal::utils::RequestRelocation relocation,
std::vector<Burst::OptionalCacheHold> cacheHolds)
: kBurst(std::move(burst)),
@@ -300,6 +329,8 @@
kMemoryIdentifierTokens(std::move(memoryIdentifierTokens)),
kMeasure(measure),
kLoopTimeoutDuration(loopTimeoutDuration),
+ kHints(hints),
+ kExtensionNameToPrefix(extensionNameToPrefix),
kRelocation(std::move(relocation)),
kCacheHolds(std::move(cacheHolds)) {}
@@ -307,7 +338,8 @@
const nn::OptionalTimePoint& deadline) const {
const auto aidlDeadline = NN_TRY(convert(deadline));
return kBurst->executeInternal(kRequest, kMemoryIdentifierTokens, kMeasure, aidlDeadline,
- kLoopTimeoutDuration, kRelocation);
+ kLoopTimeoutDuration, kHints, kExtensionNameToPrefix,
+ kRelocation);
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
diff --git a/neuralnetworks/aidl/utils/src/Conversions.cpp b/neuralnetworks/aidl/utils/src/Conversions.cpp
index 113d2da..eb28db7 100644
--- a/neuralnetworks/aidl/utils/src/Conversions.cpp
+++ b/neuralnetworks/aidl/utils/src/Conversions.cpp
@@ -302,9 +302,9 @@
};
}
-GeneralResult<Model::ExtensionNameAndPrefix> unvalidatedConvert(
+GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
const aidl_hal::ExtensionNameAndPrefix& extensionNameAndPrefix) {
- return Model::ExtensionNameAndPrefix{
+ return ExtensionNameAndPrefix{
.name = extensionNameAndPrefix.name,
.prefix = extensionNameAndPrefix.prefix,
};
@@ -506,6 +506,12 @@
return std::make_shared<const Handle>(std::move(duplicatedFd));
}
+#ifdef NN_AIDL_V4_OR_ABOVE
+GeneralResult<TokenValuePair> unvalidatedConvert(const aidl_hal::TokenValuePair& tokenValuePair) {
+ return TokenValuePair{.token = tokenValuePair.token, .value = tokenValuePair.value};
+}
+#endif // NN_AIDL_V4_OR_ABOVE
+
GeneralResult<Capabilities> convert(const aidl_hal::Capabilities& capabilities) {
return validatedConvert(capabilities);
}
@@ -562,6 +568,17 @@
GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories) {
return validatedConvert(memories);
}
+GeneralResult<std::vector<ExtensionNameAndPrefix>> convert(
+ const std::vector<aidl_hal::ExtensionNameAndPrefix>& extensionNameAndPrefix) {
+ return unvalidatedConvert(extensionNameAndPrefix);
+}
+
+#ifdef NN_AIDL_V4_OR_ABOVE
+GeneralResult<std::vector<TokenValuePair>> convert(
+ const std::vector<aidl_hal::TokenValuePair>& metaData) {
+ return validatedConvert(metaData);
+}
+#endif // NN_AIDL_V4_OR_ABOVE
GeneralResult<std::vector<OutputShape>> convert(
const std::vector<aidl_hal::OutputShape>& outputShapes) {
@@ -942,7 +959,7 @@
}
nn::GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
- const nn::Model::ExtensionNameAndPrefix& extensionNameToPrefix) {
+ const nn::ExtensionNameAndPrefix& extensionNameToPrefix) {
return ExtensionNameAndPrefix{
.name = extensionNameToPrefix.name,
.prefix = extensionNameToPrefix.prefix,
@@ -1055,6 +1072,11 @@
return Extension{.name = extension.name,
.operandTypes = NN_TRY(unvalidatedConvert(extension.operandTypes))};
}
+#ifdef NN_AIDL_V4_OR_ABOVE
+nn::GeneralResult<TokenValuePair> unvalidatedConvert(const nn::TokenValuePair& tokenValuePair) {
+ return TokenValuePair{.token = tokenValuePair.token, .value = tokenValuePair.value};
+}
+#endif // NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<std::vector<uint8_t>> convert(const nn::CacheToken& cacheToken) {
return validatedConvert(cacheToken);
@@ -1134,6 +1156,17 @@
const std::vector<nn::SyncFence>& syncFences) {
return validatedConvert(syncFences);
}
+nn::GeneralResult<std::vector<ExtensionNameAndPrefix>> convert(
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) {
+ return unvalidatedConvert(extensionNameToPrefix);
+}
+
+#ifdef NN_AIDL_V4_OR_ABOVE
+nn::GeneralResult<std::vector<TokenValuePair>> convert(
+ const std::vector<nn::TokenValuePair>& metaData) {
+ return validatedConvert(metaData);
+}
+#endif // NN_AIDL_V4_OR_ABOVE
nn::GeneralResult<std::vector<Extension>> convert(const std::vector<nn::Extension>& extensions) {
return validatedConvert(extensions);
diff --git a/neuralnetworks/aidl/utils/src/Device.cpp b/neuralnetworks/aidl/utils/src/Device.cpp
index bad10ed..f3f4fdb 100644
--- a/neuralnetworks/aidl/utils/src/Device.cpp
+++ b/neuralnetworks/aidl/utils/src/Device.cpp
@@ -215,7 +215,9 @@
nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModel(
const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
- const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
+ const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
// Ensure that model is ready for IPC.
std::optional<nn::Model> maybeModelInShared;
const nn::Model& modelInShared =
@@ -225,17 +227,28 @@
const auto aidlPreference = NN_TRY(convert(preference));
const auto aidlPriority = NN_TRY(convert(priority));
const auto aidlDeadline = NN_TRY(convert(deadline));
- const auto aidlModelCache = NN_TRY(convert(modelCache));
- const auto aidlDataCache = NN_TRY(convert(dataCache));
+ auto aidlModelCache = NN_TRY(convert(modelCache));
+ auto aidlDataCache = NN_TRY(convert(dataCache));
const auto aidlToken = NN_TRY(convert(token));
const auto cb = ndk::SharedRefBase::make<PreparedModelCallback>(kFeatureLevel);
const auto scoped = kDeathHandler.protectCallback(cb.get());
+ if (kFeatureLevel.level >= nn::Version::Level::FEATURE_LEVEL_8) {
+ auto aidlHints = NN_TRY(convert(hints));
+ auto aidlExtensionPrefix = NN_TRY(convert(extensionNameToPrefix));
+ const auto ret = kDevice->prepareModelWithConfig(
+ aidlModel,
+ {aidlPreference, aidlPriority, aidlDeadline, std::move(aidlModelCache),
+ std::move(aidlDataCache), aidlToken, std::move(aidlHints),
+ std::move(aidlExtensionPrefix)},
+ cb);
+ HANDLE_ASTATUS(ret) << "prepareModel failed";
+ return cb->get();
+ }
const auto ret = kDevice->prepareModel(aidlModel, aidlPreference, aidlPriority, aidlDeadline,
aidlModelCache, aidlDataCache, aidlToken, cb);
HANDLE_ASTATUS(ret) << "prepareModel failed";
-
return cb->get();
}
diff --git a/neuralnetworks/aidl/utils/src/Execution.cpp b/neuralnetworks/aidl/utils/src/Execution.cpp
index c4add63..2fd88af 100644
--- a/neuralnetworks/aidl/utils/src/Execution.cpp
+++ b/neuralnetworks/aidl/utils/src/Execution.cpp
@@ -63,7 +63,7 @@
ExecutionWithCachedRequest::compute(const nn::OptionalTimePoint& deadline) const {
const auto aidlDeadline = NN_TRY(convert(deadline));
return kPreparedModel->executeInternal(kRequest, kMeasure, aidlDeadline, kLoopTimeoutDuration,
- kRelocation);
+ {}, {}, kRelocation);
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
@@ -73,9 +73,9 @@
const auto aidlWaitFor = NN_TRY(convert(waitFor));
const auto aidlDeadline = NN_TRY(convert(deadline));
const auto aidlTimeoutDurationAfterFence = NN_TRY(convert(timeoutDurationAfterFence));
- return kPreparedModel->executeFencedInternal(kRequest, aidlWaitFor, kMeasure, aidlDeadline,
- kLoopTimeoutDuration,
- aidlTimeoutDurationAfterFence, kRelocation);
+ return kPreparedModel->executeFencedInternal(
+ kRequest, aidlWaitFor, kMeasure, aidlDeadline, kLoopTimeoutDuration,
+ aidlTimeoutDurationAfterFence, {}, {}, kRelocation);
}
nn::GeneralResult<std::shared_ptr<const Execution>> Execution::create(
diff --git a/neuralnetworks/aidl/utils/src/InvalidDevice.cpp b/neuralnetworks/aidl/utils/src/InvalidDevice.cpp
index c9d9955..33270ff 100644
--- a/neuralnetworks/aidl/utils/src/InvalidDevice.cpp
+++ b/neuralnetworks/aidl/utils/src/InvalidDevice.cpp
@@ -167,6 +167,31 @@
return ndk::ScopedAStatus::ok();
}
+ndk::ScopedAStatus InvalidDevice::prepareModelWithConfig(
+ const Model& model, const PrepareModelConfig& config,
+ const std::shared_ptr<IPreparedModelCallback>& callback) {
+ if (!utils::valid(config.extensionNameToPrefix)) {
+ callback->notify(ErrorStatus::INVALID_ARGUMENT, nullptr);
+ return toAStatus(ErrorStatus::INVALID_ARGUMENT, "Invalid extensionNameToPrefix");
+ }
+ for (const auto& hint : config.compilationHints) {
+ auto result = std::find_if(config.extensionNameToPrefix.begin(),
+ config.extensionNameToPrefix.end(),
+ [&hint](const ExtensionNameAndPrefix& extension) {
+ uint16_t prefix = static_cast<uint32_t>(hint.token) >>
+ IDevice::EXTENSION_TYPE_LOW_BITS_TYPE;
+ return prefix == extension.prefix;
+ });
+ if (result == config.extensionNameToPrefix.end()) {
+ callback->notify(ErrorStatus::INVALID_ARGUMENT, nullptr);
+ return toAStatus(ErrorStatus::INVALID_ARGUMENT,
+ "Invalid token for compilation hints: " + std::to_string(hint.token));
+ }
+ }
+ return prepareModel(model, config.preference, config.priority, config.deadlineNs,
+ config.modelCache, config.dataCache, config.cacheToken, callback);
+}
+
ndk::ScopedAStatus InvalidDevice::prepareModelFromCache(
int64_t /*deadline*/, const std::vector<ndk::ScopedFileDescriptor>& /*modelCache*/,
const std::vector<ndk::ScopedFileDescriptor>& /*dataCache*/,
diff --git a/neuralnetworks/aidl/utils/src/PreparedModel.cpp b/neuralnetworks/aidl/utils/src/PreparedModel.cpp
index 6d1de56..7e3a31c 100644
--- a/neuralnetworks/aidl/utils/src/PreparedModel.cpp
+++ b/neuralnetworks/aidl/utils/src/PreparedModel.cpp
@@ -128,8 +128,9 @@
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> PreparedModel::execute(
const nn::Request& request, nn::MeasureTiming measure,
- const nn::OptionalTimePoint& deadline,
- const nn::OptionalDuration& loopTimeoutDuration) const {
+ const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -141,30 +142,46 @@
const auto aidlMeasure = NN_TRY(convert(measure));
const auto aidlDeadline = NN_TRY(convert(deadline));
const auto aidlLoopTimeoutDuration = NN_TRY(convert(loopTimeoutDuration));
- return executeInternal(aidlRequest, aidlMeasure, aidlDeadline, aidlLoopTimeoutDuration,
- relocation);
+ return executeInternal(aidlRequest, aidlMeasure, aidlDeadline, aidlLoopTimeoutDuration, hints,
+ extensionNameToPrefix, relocation);
}
nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>
PreparedModel::executeInternal(const Request& request, bool measure, int64_t deadline,
int64_t loopTimeoutDuration,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
const hal::utils::RequestRelocation& relocation) const {
if (relocation.input) {
relocation.input->flush();
}
ExecutionResult executionResult;
- const auto ret = kPreparedModel->executeSynchronously(request, measure, deadline,
- loopTimeoutDuration, &executionResult);
- HANDLE_ASTATUS(ret) << "executeSynchronously failed";
+ if (kFeatureLevel.level >= nn::Version::Level::FEATURE_LEVEL_8) {
+ auto aidlHints = NN_TRY(convert(hints));
+ auto aidlExtensionPrefix = NN_TRY(convert(extensionNameToPrefix));
+ const auto ret = kPreparedModel->executeSynchronouslyWithConfig(
+ request,
+ {measure, loopTimeoutDuration, std::move(aidlHints),
+ std::move(aidlExtensionPrefix)},
+ deadline, &executionResult);
+ HANDLE_ASTATUS(ret) << "executeSynchronouslyWithConfig failed";
+ } else {
+ const auto ret = kPreparedModel->executeSynchronously(
+ request, measure, deadline, loopTimeoutDuration, &executionResult);
+ HANDLE_ASTATUS(ret) << "executeSynchronously failed";
+ }
return handleExecutionResult(executionResult, relocation);
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
-PreparedModel::executeFenced(const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
- nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
- const nn::OptionalDuration& loopTimeoutDuration,
- const nn::OptionalDuration& timeoutDurationAfterFence) const {
+PreparedModel::executeFenced(
+ const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
+ nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
+ const nn::OptionalDuration& loopTimeoutDuration,
+ const nn::OptionalDuration& timeoutDurationAfterFence,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -179,31 +196,45 @@
const auto aidlLoopTimeoutDuration = NN_TRY(convert(loopTimeoutDuration));
const auto aidlTimeoutDurationAfterFence = NN_TRY(convert(timeoutDurationAfterFence));
return executeFencedInternal(aidlRequest, aidlWaitFor, aidlMeasure, aidlDeadline,
- aidlLoopTimeoutDuration, aidlTimeoutDurationAfterFence,
- relocation);
+ aidlLoopTimeoutDuration, aidlTimeoutDurationAfterFence, hints,
+ extensionNameToPrefix, relocation);
}
nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
-PreparedModel::executeFencedInternal(const Request& request,
- const std::vector<ndk::ScopedFileDescriptor>& waitFor,
- bool measure, int64_t deadline, int64_t loopTimeoutDuration,
- int64_t timeoutDurationAfterFence,
- const hal::utils::RequestRelocation& relocation) const {
+PreparedModel::executeFencedInternal(
+ const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor, bool measure,
+ int64_t deadline, int64_t loopTimeoutDuration, int64_t timeoutDurationAfterFence,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix,
+ const hal::utils::RequestRelocation& relocation) const {
if (relocation.input) {
relocation.input->flush();
}
FencedExecutionResult result;
- const auto ret =
- kPreparedModel->executeFenced(request, waitFor, measure, deadline, loopTimeoutDuration,
- timeoutDurationAfterFence, &result);
- HANDLE_ASTATUS(ret) << "executeFenced failed";
+ if (kFeatureLevel.level >= nn::Version::Level::FEATURE_LEVEL_8) {
+ auto aidlHints = NN_TRY(convert(hints));
+ auto aidlExtensionPrefix = NN_TRY(convert(extensionNameToPrefix));
+ const auto ret = kPreparedModel->executeFencedWithConfig(
+ request, waitFor,
+ {measure, loopTimeoutDuration, std::move(aidlHints),
+ std::move(aidlExtensionPrefix)},
+ deadline, timeoutDurationAfterFence, &result);
+ HANDLE_ASTATUS(ret) << "executeFencedWithConfig failed";
+ } else {
+ const auto ret = kPreparedModel->executeFenced(request, waitFor, measure, deadline,
+ loopTimeoutDuration,
+ timeoutDurationAfterFence, &result);
+ HANDLE_ASTATUS(ret) << "executeFenced failed";
+ }
return handleFencedExecutionResult(result, relocation);
}
nn::GeneralResult<nn::SharedExecution> PreparedModel::createReusableExecution(
const nn::Request& request, nn::MeasureTiming measure,
- const nn::OptionalDuration& loopTimeoutDuration) const {
+ const nn::OptionalDuration& loopTimeoutDuration,
+ const std::vector<nn::TokenValuePair>& hints,
+ const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const {
// Ensure that request is ready for IPC.
std::optional<nn::Request> maybeRequestInShared;
hal::utils::RequestRelocation relocation;
@@ -217,8 +248,14 @@
if (kFeatureLevel.level >= nn::Version::Level::FEATURE_LEVEL_8) {
std::shared_ptr<IExecution> execution;
+ auto aidlHints = NN_TRY(convert(hints));
+ auto aidlExtensionPrefix = NN_TRY(convert(extensionNameToPrefix));
+
const auto ret = kPreparedModel->createReusableExecution(
- aidlRequest, aidlMeasure, aidlLoopTimeoutDuration, &execution);
+ aidlRequest,
+ {aidlMeasure, aidlLoopTimeoutDuration, std::move(aidlHints),
+ std::move(aidlExtensionPrefix)},
+ &execution);
HANDLE_ASTATUS(ret) << "createReusableExecution failed";
return Execution::create(std::move(execution), std::move(relocation));
}
@@ -232,7 +269,7 @@
std::shared_ptr<IBurst> burst;
const auto ret = kPreparedModel->configureExecutionBurst(&burst);
HANDLE_ASTATUS(ret) << "configureExecutionBurst failed";
- return Burst::create(std::move(burst));
+ return Burst::create(std::move(burst), kFeatureLevel);
}
std::any PreparedModel::getUnderlyingResource() const {
diff --git a/neuralnetworks/aidl/utils/test/DeviceTest.cpp b/neuralnetworks/aidl/utils/test/DeviceTest.cpp
index fb13af8..73727b3 100644
--- a/neuralnetworks/aidl/utils/test/DeviceTest.cpp
+++ b/neuralnetworks/aidl/utils/test/DeviceTest.cpp
@@ -61,7 +61,6 @@
.powerUsage = std::numeric_limits<float>::max()};
constexpr NumberOfCacheFiles kNumberOfCacheFiles = {.numModelCache = nn::kMaxNumberOfCacheFiles - 1,
.numDataCache = nn::kMaxNumberOfCacheFiles};
-
constexpr auto makeStatusOk = [] { return ndk::ScopedAStatus::ok(); };
std::shared_ptr<MockDevice> createMockDevice() {
@@ -124,6 +123,18 @@
};
}
+const std::vector<nn::TokenValuePair> kHints = {nn::TokenValuePair{.token = 0, .value = {1}}};
+const std::vector<nn::ExtensionNameAndPrefix> kExtensionNameToPrefix = {
+ nn::ExtensionNameAndPrefix{.name = "com.android.nn_test", .prefix = 1}};
+auto makePreparedModelWithConfigReturn(ErrorStatus launchStatus, ErrorStatus returnStatus,
+ const std::shared_ptr<MockPreparedModel>& preparedModel) {
+ return [launchStatus, returnStatus, preparedModel](
+ const Model& /*model*/, const PrepareModelConfig& /*config*/,
+ const std::shared_ptr<IPreparedModelCallback>& cb) -> ndk::ScopedAStatus {
+ return makePreparedModelReturnImpl(launchStatus, returnStatus, preparedModel, cb);
+ };
+}
+
auto makePreparedModelFromCacheReturn(ErrorStatus launchStatus, ErrorStatus returnStatus,
const std::shared_ptr<MockPreparedModel>& preparedModel) {
return [launchStatus, returnStatus, preparedModel](
@@ -560,6 +571,8 @@
}
TEST_P(DeviceTest, prepareModel) {
+ if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
+
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -571,7 +584,7 @@
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
- nn::Priority::DEFAULT, {}, {}, {}, {});
+ nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -580,6 +593,8 @@
}
TEST_P(DeviceTest, prepareModelLaunchError) {
+ if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
+
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -590,7 +605,7 @@
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
- nn::Priority::DEFAULT, {}, {}, {}, {});
+ nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -598,6 +613,8 @@
}
TEST_P(DeviceTest, prepareModelReturnError) {
+ if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
+
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -608,7 +625,7 @@
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
- nn::Priority::DEFAULT, {}, {}, {}, {});
+ nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -616,6 +633,8 @@
}
TEST_P(DeviceTest, prepareModelNullptrError) {
+ if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
+
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -626,7 +645,7 @@
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
- nn::Priority::DEFAULT, {}, {}, {}, {});
+ nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -634,6 +653,8 @@
}
TEST_P(DeviceTest, prepareModelTransportFailure) {
+ if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
+
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -643,7 +664,7 @@
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
- nn::Priority::DEFAULT, {}, {}, {}, {});
+ nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -651,6 +672,8 @@
}
TEST_P(DeviceTest, prepareModelDeadObject) {
+ if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
+
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -660,7 +683,7 @@
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
- nn::Priority::DEFAULT, {}, {}, {}, {});
+ nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -668,6 +691,8 @@
}
TEST_P(DeviceTest, prepareModelAsyncCrash) {
+ if (kVersion.level > nn::Version::Level::FEATURE_LEVEL_7) return;
+
// setup test
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
@@ -681,7 +706,157 @@
// run test
const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
- nn::Priority::DEFAULT, {}, {}, {}, {});
+ nn::Priority::DEFAULT, {}, {}, {}, {}, {}, {});
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST_P(DeviceTest, prepareModelWithConfig) {
+ if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
+
+ // setup call
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice, kVersion).value();
+ const auto mockPreparedModel = MockPreparedModel::create();
+ EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelWithConfigReturn(ErrorStatus::NONE, ErrorStatus::NONE,
+ mockPreparedModel)));
+
+ // run test
+ const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
+ kExtensionNameToPrefix);
+
+ // verify result
+ ASSERT_TRUE(result.has_value())
+ << "Failed with " << result.error().code << ": " << result.error().message;
+ EXPECT_NE(result.value(), nullptr);
+}
+
+TEST_P(DeviceTest, prepareModelWithConfigLaunchError) {
+ if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
+
+ // setup call
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice, kVersion).value();
+ EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelWithConfigReturn(
+ ErrorStatus::GENERAL_FAILURE, ErrorStatus::GENERAL_FAILURE, nullptr)));
+
+ // run test
+ const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
+ kExtensionNameToPrefix);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST_P(DeviceTest, prepareModelWithConfigReturnError) {
+ if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
+
+ // setup call
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice, kVersion).value();
+ EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelWithConfigReturn(
+ ErrorStatus::NONE, ErrorStatus::GENERAL_FAILURE, nullptr)));
+
+ // run test
+ const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
+ kExtensionNameToPrefix);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST_P(DeviceTest, prepareModelWithConfigNullptrError) {
+ if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
+
+ // setup call
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice, kVersion).value();
+ EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelWithConfigReturn(ErrorStatus::NONE, ErrorStatus::NONE,
+ nullptr)));
+
+ // run test
+ const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
+ kExtensionNameToPrefix);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST_P(DeviceTest, prepareModelWithConfigTransportFailure) {
+ if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
+
+ // setup call
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice, kVersion).value();
+ EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ // run test
+ const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
+ kExtensionNameToPrefix);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST_P(DeviceTest, prepareModelWithConfigDeadObject) {
+ if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
+
+ // setup call
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice, kVersion).value();
+ EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ // run test
+ const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
+ kExtensionNameToPrefix);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST_P(DeviceTest, prepareModelWithConfigAsyncCrash) {
+ if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
+
+ // setup test
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice, kVersion).value();
+ const auto ret = [&device]() {
+ DeathMonitor::serviceDied(device->getDeathMonitor());
+ return ndk::ScopedAStatus::ok();
+ };
+ EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(ret));
+
+ // run test
+ const auto result = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {}, kHints,
+ kExtensionNameToPrefix);
// verify result
ASSERT_FALSE(result.has_value());
diff --git a/neuralnetworks/aidl/utils/test/MockBurst.h b/neuralnetworks/aidl/utils/test/MockBurst.h
index 5083bbd..4cf60b6 100644
--- a/neuralnetworks/aidl/utils/test/MockBurst.h
+++ b/neuralnetworks/aidl/utils/test/MockBurst.h
@@ -32,6 +32,10 @@
bool measureTiming, int64_t deadline, int64_t loopTimeoutDuration,
ExecutionResult* executionResult),
(override));
+ MOCK_METHOD(ndk::ScopedAStatus, executeSynchronouslyWithConfig,
+ (const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
+ const ExecutionConfig& config, int64_t deadline, ExecutionResult* executionResult),
+ (override));
MOCK_METHOD(ndk::ScopedAStatus, releaseMemoryResource, (int64_t memoryIdentifierToken),
(override));
};
diff --git a/neuralnetworks/aidl/utils/test/MockDevice.h b/neuralnetworks/aidl/utils/test/MockDevice.h
index 3a28d55..47b8346 100644
--- a/neuralnetworks/aidl/utils/test/MockDevice.h
+++ b/neuralnetworks/aidl/utils/test/MockDevice.h
@@ -50,6 +50,10 @@
const std::vector<uint8_t>& token,
const std::shared_ptr<IPreparedModelCallback>& callback),
(override));
+ MOCK_METHOD(ndk::ScopedAStatus, prepareModelWithConfig,
+ (const Model& model, const PrepareModelConfig& config,
+ const std::shared_ptr<IPreparedModelCallback>& callback),
+ (override));
MOCK_METHOD(ndk::ScopedAStatus, prepareModelFromCache,
(int64_t deadline, const std::vector<ndk::ScopedFileDescriptor>& modelCache,
const std::vector<ndk::ScopedFileDescriptor>& dataCache,
diff --git a/neuralnetworks/aidl/utils/test/MockPreparedModel.h b/neuralnetworks/aidl/utils/test/MockPreparedModel.h
index 0ed9af9..318acc2 100644
--- a/neuralnetworks/aidl/utils/test/MockPreparedModel.h
+++ b/neuralnetworks/aidl/utils/test/MockPreparedModel.h
@@ -40,10 +40,19 @@
bool measureTiming, int64_t deadline, int64_t loopTimeoutDuration,
int64_t duration, FencedExecutionResult* fencedExecutionResult),
(override));
+ MOCK_METHOD(ndk::ScopedAStatus, executeSynchronouslyWithConfig,
+ (const Request& request, const ExecutionConfig& config, int64_t deadline,
+ ExecutionResult* executionResult),
+ (override));
+ MOCK_METHOD(ndk::ScopedAStatus, executeFencedWithConfig,
+ (const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor,
+ const ExecutionConfig& config, int64_t deadline, int64_t duration,
+ FencedExecutionResult* fencedExecutionResult),
+ (override));
MOCK_METHOD(ndk::ScopedAStatus, configureExecutionBurst, (std::shared_ptr<IBurst> * burst),
(override));
MOCK_METHOD(ndk::ScopedAStatus, createReusableExecution,
- (const Request& request, bool measureTiming, int64_t loopTimeoutDuration,
+ (const Request& request, const ExecutionConfig& config,
std::shared_ptr<IExecution>* execution),
(override));
};
diff --git a/neuralnetworks/aidl/utils/test/PreparedModelTest.cpp b/neuralnetworks/aidl/utils/test/PreparedModelTest.cpp
index 8cfb7c1..bf6136d 100644
--- a/neuralnetworks/aidl/utils/test/PreparedModelTest.cpp
+++ b/neuralnetworks/aidl/utils/test/PreparedModelTest.cpp
@@ -70,6 +70,21 @@
class PreparedModelTest : public VersionedAidlUtilsTestBase {};
+const std::vector<nn::TokenValuePair> kHints = {nn::TokenValuePair{.token = 0, .value = {1}}};
+const std::vector<nn::ExtensionNameAndPrefix> kExtensionNameToPrefix = {
+ nn::ExtensionNameAndPrefix{.name = "com.android.nn_test", .prefix = 1}};
+auto makeFencedExecutionWithConfigResult(
+ const std::shared_ptr<MockFencedExecutionCallback>& callback) {
+ return [callback](const Request& /*request*/,
+ const std::vector<ndk::ScopedFileDescriptor>& /*waitFor*/,
+ const ExecutionConfig& /*config*/, int64_t /*deadline*/, int64_t /*duration*/,
+ FencedExecutionResult* fencedExecutionResult) {
+ *fencedExecutionResult = FencedExecutionResult{.callback = callback,
+ .syncFence = ndk::ScopedFileDescriptor(-1)};
+ return ndk::ScopedAStatus::ok();
+ };
+}
+
} // namespace
TEST_P(PreparedModelTest, invalidPreparedModel) {
@@ -82,6 +97,8 @@
}
TEST_P(PreparedModelTest, executeSync) {
+ if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
+
// setup call
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -96,7 +113,7 @@
DoAll(SetArgPointee<4>(mockExecutionResult), InvokeWithoutArgs(makeStatusOk)));
// run test
- const auto result = preparedModel->execute({}, {}, {}, {});
+ const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
EXPECT_TRUE(result.has_value())
@@ -104,6 +121,8 @@
}
TEST_P(PreparedModelTest, executeSyncError) {
+ if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
+
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -112,7 +131,7 @@
.WillOnce(Invoke(makeGeneralFailure));
// run test
- const auto result = preparedModel->execute({}, {}, {}, {});
+ const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -120,6 +139,8 @@
}
TEST_P(PreparedModelTest, executeSyncTransportFailure) {
+ if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
+
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -128,7 +149,7 @@
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
- const auto result = preparedModel->execute({}, {}, {}, {});
+ const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -136,6 +157,8 @@
}
TEST_P(PreparedModelTest, executeSyncDeadObject) {
+ if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
+
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -144,7 +167,7 @@
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
- const auto result = preparedModel->execute({}, {}, {}, {});
+ const auto result = preparedModel->execute({}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -152,6 +175,8 @@
}
TEST_P(PreparedModelTest, executeFenced) {
+ if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
+
// setup call
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -165,7 +190,7 @@
.WillOnce(Invoke(makeFencedExecutionResult(mockCallback)));
// run test
- const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+ const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -181,6 +206,8 @@
}
TEST_P(PreparedModelTest, executeFencedCallbackError) {
+ if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
+
// setup call
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -195,7 +222,7 @@
.WillOnce(Invoke(makeFencedExecutionResult(mockCallback)));
// run test
- const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+ const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -211,6 +238,8 @@
}
TEST_P(PreparedModelTest, executeFencedError) {
+ if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
+
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -219,7 +248,7 @@
.WillOnce(InvokeWithoutArgs(makeGeneralFailure));
// run test
- const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+ const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -227,6 +256,8 @@
}
TEST_P(PreparedModelTest, executeFencedTransportFailure) {
+ if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
+
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -235,7 +266,7 @@
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// run test
- const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+ const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -243,6 +274,8 @@
}
TEST_P(PreparedModelTest, executeFencedDeadObject) {
+ if (kVersion.level >= nn::Version::Level::FEATURE_LEVEL_8) return;
+
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
@@ -251,7 +284,7 @@
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// run test
- const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+ const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -276,7 +309,7 @@
DoAll(SetArgPointee<4>(mockExecutionResult), InvokeWithoutArgs(makeStatusOk)));
// create execution
- const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+ const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -300,7 +333,7 @@
.WillOnce(Invoke(makeGeneralFailure));
// create execution
- const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+ const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -322,7 +355,7 @@
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// create execution
- const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+ const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -344,7 +377,7 @@
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// create execution
- const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+ const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -372,7 +405,7 @@
.WillRepeatedly(Invoke(makeFencedExecutionResult(mockCallback)));
// create execution
- const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+ const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -410,7 +443,7 @@
.WillOnce(Invoke(makeFencedExecutionResult(mockCallback)));
// create execution
- const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+ const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -440,7 +473,7 @@
.WillOnce(InvokeWithoutArgs(makeGeneralFailure));
// create execution
- const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+ const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -462,7 +495,7 @@
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
// create execution
- const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+ const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -484,7 +517,7 @@
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
// create execution
- const auto createResult = preparedModel->createReusableExecution({}, {}, {});
+ const auto createResult = preparedModel->createReusableExecution({}, {}, {}, {}, {});
ASSERT_TRUE(createResult.has_value())
<< "Failed with " << createResult.error().code << ": " << createResult.error().message;
ASSERT_NE(createResult.value(), nullptr);
@@ -495,6 +528,206 @@
EXPECT_EQ(computeResult.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
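+// The tests below exercise the FEATURE_LEVEL_8 paths: execute() and
+// executeFenced() are expected to dispatch to the *WithConfig methods,
+// forwarding kHints and kExtensionNameToPrefix, while older feature levels
+// skip these tests via the guard at the top of each test.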
+TEST_P(PreparedModelTest, executeSyncWithConfig) {
+ if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
+
+ // setup call
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
+ const auto mockExecutionResult = ExecutionResult{
+ .outputSufficientSize = true,
+ .outputShapes = {},
+ .timing = kNoTiming,
+ };
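+ // execute() should land in executeSynchronouslyWithConfig; the mocked
+ // ExecutionResult is written back through the fourth (out) argument.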
+ EXPECT_CALL(*mockPreparedModel, executeSynchronouslyWithConfig(_, _, _, _))
+ .Times(1)
+ .WillOnce(
+ DoAll(SetArgPointee<3>(mockExecutionResult), InvokeWithoutArgs(makeStatusOk)));
+
+ // run test
+ const auto result = preparedModel->execute({}, {}, {}, {}, kHints, kExtensionNameToPrefix);
+
+ // verify result
+ EXPECT_TRUE(result.has_value())
+ << "Failed with " << result.error().code << ": " << result.error().message;
+}
+
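+// As with the plain execute() tests, the WithConfig entry points are checked
+// against the usual failure modes: general failure, transport failure, and a
+// dead binder object.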
+TEST_P(PreparedModelTest, executeSyncWithConfigError) {
+ if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
+
+ // setup test
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
+ EXPECT_CALL(*mockPreparedModel, executeSynchronouslyWithConfig(_, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makeGeneralFailure));
+
+ // run test
+ const auto result = preparedModel->execute({}, {}, {}, {}, kHints, kExtensionNameToPrefix);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST_P(PreparedModelTest, executeSyncWithConfigTransportFailure) {
+ if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
+
+ // setup test
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
+ EXPECT_CALL(*mockPreparedModel, executeSynchronouslyWithConfig(_, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ // run test
+ const auto result = preparedModel->execute({}, {}, {}, {}, kHints, kExtensionNameToPrefix);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST_P(PreparedModelTest, executeSyncWithConfigDeadObject) {
+ if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
+
+ // setup test
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
+ EXPECT_CALL(*mockPreparedModel, executeSynchronouslyWithConfig(_, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ // run test
+ const auto result = preparedModel->execute({}, {}, {}, {}, kHints, kExtensionNameToPrefix);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
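+// Fenced variants: executeFenced() should route to executeFencedWithConfig(),
+// with results surfaced through the fenced-execution callback.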
+TEST_P(PreparedModelTest, executeFencedWithConfig) {
+ if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
+
+ // setup call
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
+ const auto mockCallback = MockFencedExecutionCallback::create();
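+ // The callback reports launch/fence timing and a status through its three
+ // out parameters.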
+ EXPECT_CALL(*mockCallback, getExecutionInfo(_, _, _))
+ .Times(1)
+ .WillOnce(DoAll(SetArgPointee<0>(kNoTiming), SetArgPointee<1>(kNoTiming),
+ SetArgPointee<2>(ErrorStatus::NONE), Invoke(makeStatusOk)));
+ EXPECT_CALL(*mockPreparedModel, executeFencedWithConfig(_, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makeFencedExecutionWithConfigResult(mockCallback)));
+
+ // run test
+ const auto result =
+ preparedModel->executeFenced({}, {}, {}, {}, {}, {}, kHints, kExtensionNameToPrefix);
+
+ // verify result
+ ASSERT_TRUE(result.has_value())
+ << "Failed with " << result.error().code << ": " << result.error().message;
+ const auto& [syncFence, callback] = result.value();
+ EXPECT_EQ(syncFence.syncWait({}), nn::SyncFence::FenceState::SIGNALED);
+ ASSERT_NE(callback, nullptr);
+
+ // get results from callback
+ const auto callbackResult = callback();
+ ASSERT_TRUE(callbackResult.has_value()) << "Failed with " << callbackResult.error().code << ": "
+ << callbackResult.error().message;
+}
+
+TEST_P(PreparedModelTest, executeFencedWithConfigCallbackError) {
+ if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
+
+ // setup call
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
+ const auto mockCallback = MockFencedExecutionCallback::create();
+ EXPECT_CALL(*mockCallback, getExecutionInfo(_, _, _))
+ .Times(1)
+ .WillOnce(DoAll(SetArgPointee<0>(kNoTiming), SetArgPointee<1>(kNoTiming),
+ SetArgPointee<2>(ErrorStatus::GENERAL_FAILURE),
+ Invoke(makeStatusOk)));
+ EXPECT_CALL(*mockPreparedModel, executeFencedWithConfig(_, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makeFencedExecutionWithConfigResult(mockCallback)));
+
+ // run test
+ const auto result =
+ preparedModel->executeFenced({}, {}, {}, {}, {}, {}, kHints, kExtensionNameToPrefix);
+
+ // verify result
+ ASSERT_TRUE(result.has_value())
+ << "Failed with " << result.error().code << ": " << result.error().message;
+ const auto& [syncFence, callback] = result.value();
+ EXPECT_NE(syncFence.syncWait({}), nn::SyncFence::FenceState::ACTIVE);
+ ASSERT_NE(callback, nullptr);
+
+ // verify callback failure
+ const auto callbackResult = callback();
+ ASSERT_FALSE(callbackResult.has_value());
+ EXPECT_EQ(callbackResult.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST_P(PreparedModelTest, executeFencedWithConfigError) {
+ if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
+
+ // setup test
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
+ EXPECT_CALL(*mockPreparedModel, executeFencedWithConfig(_, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ // run test
+ const auto result =
+ preparedModel->executeFenced({}, {}, {}, {}, {}, {}, kHints, kExtensionNameToPrefix);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST_P(PreparedModelTest, executeFencedWithConfigTransportFailure) {
+ if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
+
+ // setup test
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
+ EXPECT_CALL(*mockPreparedModel, executeFencedWithConfig(_, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ // run test
+ const auto result =
+ preparedModel->executeFenced({}, {}, {}, {}, {}, {}, kHints, kExtensionNameToPrefix);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST_P(PreparedModelTest, executeFencedWithConfigDeadObject) {
+ if (kVersion.level < nn::Version::Level::FEATURE_LEVEL_8) return;
+
+ // setup test
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
+ EXPECT_CALL(*mockPreparedModel, executeFencedWithConfig(_, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ // run test
+ const auto result =
+ preparedModel->executeFenced({}, {}, {}, {}, {}, {}, kHints, kExtensionNameToPrefix);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
TEST_P(PreparedModelTest, configureExecutionBurst) {
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
@@ -567,13 +800,13 @@
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
const auto mockExecution = ndk::SharedRefBase::make<MockExecution>();
- EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _, _))
+ EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _))
.Times(1)
- .WillOnce(DoAll(SetArgPointee<3>(mockExecution), Invoke(makeStatusOk)));
+ .WillOnce(DoAll(SetArgPointee<2>(mockExecution), Invoke(makeStatusOk)));
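+ // The mock now expects three arguments, with the reusable execution
+ // returned through the third (out) argument rather than the fourth.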
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
// run test
- const auto result = preparedModel->createReusableExecution({}, {}, {});
+ const auto result = preparedModel->createReusableExecution({}, {}, {}, {}, {});
// verify result
ASSERT_TRUE(result.has_value())
@@ -586,13 +819,13 @@
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
- EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _, _))
+ EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeGeneralFailure));
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
// run test
- const auto result = preparedModel->createReusableExecution({}, {}, {});
+ const auto result = preparedModel->createReusableExecution({}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -604,13 +837,13 @@
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
- EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _, _))
+ EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
// run test
- const auto result = preparedModel->createReusableExecution({}, {}, {});
+ const auto result = preparedModel->createReusableExecution({}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());
@@ -622,13 +855,13 @@
// setup test
const auto mockPreparedModel = MockPreparedModel::create();
- EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _, _))
+ EXPECT_CALL(*mockPreparedModel, createReusableExecution(_, _, _))
.Times(1)
.WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
const auto preparedModel = PreparedModel::create(mockPreparedModel, kVersion).value();
// run test
- const auto result = preparedModel->createReusableExecution({}, {}, {});
+ const auto result = preparedModel->createReusableExecution({}, {}, {}, {}, {});
// verify result
ASSERT_FALSE(result.has_value());