Merge changes from topic "aosp-nnapi-burst-compat-lib"
* changes:
Implement full canonical Burst in NN util code
Add canonical types adapters for NNAPI AIDL interface
Add missing NNAPI HIDL interface mock tests
diff --git a/neuralnetworks/1.2/utils/Android.bp b/neuralnetworks/1.2/utils/Android.bp
index 2921141..41281ee 100644
--- a/neuralnetworks/1.2/utils/Android.bp
+++ b/neuralnetworks/1.2/utils/Android.bp
@@ -27,7 +27,6 @@
name: "neuralnetworks_utils_hal_1_2",
defaults: ["neuralnetworks_utils_defaults"],
srcs: ["src/*"],
- exclude_srcs: ["src/ExecutionBurst*"],
local_include_dirs: ["include/nnapi/hal/1.2/"],
export_include_dirs: ["include"],
cflags: ["-Wthread-safety"],
@@ -41,10 +40,16 @@
"android.hardware.neuralnetworks@1.0",
"android.hardware.neuralnetworks@1.1",
"android.hardware.neuralnetworks@1.2",
+ "libfmq",
],
export_static_lib_headers: [
"neuralnetworks_utils_hal_common",
],
+ product_variables: {
+ debuggable: { // eng and userdebug builds
+ cflags: ["-DNN_DEBUGGABLE"],
+ },
+ },
}
cc_test {
diff --git a/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/Conversions.h b/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/Conversions.h
index 6fd1337..272cee7 100644
--- a/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/Conversions.h
+++ b/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/Conversions.h
@@ -52,6 +52,7 @@
GeneralResult<Model> convert(const hal::V1_2::Model& model);
GeneralResult<MeasureTiming> convert(const hal::V1_2::MeasureTiming& measureTiming);
GeneralResult<Timing> convert(const hal::V1_2::Timing& timing);
+GeneralResult<SharedMemory> convert(const hardware::hidl_memory& memory);
GeneralResult<std::vector<Extension>> convert(
const hardware::hidl_vec<hal::V1_2::Extension>& extensions);
diff --git a/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstController.h b/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstController.h
index 5356a91..6b6fc71 100644
--- a/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstController.h
+++ b/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstController.h
@@ -14,23 +14,28 @@
* limitations under the License.
*/
-#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_EXECUTION_BURST_CONTROLLER_H
-#define ANDROID_FRAMEWORKS_ML_NN_COMMON_EXECUTION_BURST_CONTROLLER_H
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_EXECUTION_BURST_CONTROLLER_H
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_EXECUTION_BURST_CONTROLLER_H
#include "ExecutionBurstUtils.h"
-#include <android-base/macros.h>
+#include <android-base/thread_annotations.h>
#include <android/hardware/neuralnetworks/1.0/types.h>
-#include <android/hardware/neuralnetworks/1.1/types.h>
#include <android/hardware/neuralnetworks/1.2/IBurstCallback.h>
#include <android/hardware/neuralnetworks/1.2/IBurstContext.h>
#include <android/hardware/neuralnetworks/1.2/IPreparedModel.h>
#include <android/hardware/neuralnetworks/1.2/types.h>
#include <fmq/MessageQueue.h>
#include <hidl/MQDescriptor.h>
+#include <nnapi/IBurst.h>
+#include <nnapi/IPreparedModel.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/ProtectCallback.h>
#include <atomic>
#include <chrono>
+#include <functional>
#include <map>
#include <memory>
#include <mutex>
@@ -39,147 +44,145 @@
#include <utility>
#include <vector>
-namespace android::nn {
+namespace android::hardware::neuralnetworks::V1_2::utils {
/**
- * The ExecutionBurstController class manages both the serialization and
- * deserialization of data across FMQ, making it appear to the runtime as a
- * regular synchronous inference. Additionally, this class manages the burst's
- * memory cache.
+ * The ExecutionBurstController class manages both the serialization and deserialization of data
+ * across FMQ, making it appear to the runtime as a regular synchronous inference. Additionally,
+ * this class manages the burst's memory cache.
*/
-class ExecutionBurstController {
- DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstController);
+class ExecutionBurstController final : public nn::IBurst {
+ struct PrivateConstructorTag {};
public:
+ using FallbackFunction =
+ std::function<nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>(
+ const nn::Request&, nn::MeasureTiming)>;
+
/**
- * NN runtime burst callback object and memory cache.
+ * NN runtime memory cache.
*
- * ExecutionBurstCallback associates a hidl_memory object with a slot number
- * to be passed across FMQ. The ExecutionBurstServer can use this callback
- * to retrieve this hidl_memory corresponding to the slot via HIDL.
+ * MemoryCache associates a Memory object with a slot number to be passed across FMQ. The
+ * ExecutionBurstServer can use the ExecutionBurstCallback, which is backed by this cache, to
+ * retrieve a hidl_memory corresponding to the slot via HIDL.
*
- * Whenever a hidl_memory object is copied, it will duplicate the underlying
- * file descriptor. Because the NN runtime currently copies the hidl_memory
- * on each execution, it is difficult to associate hidl_memory objects with
- * previously cached hidl_memory objects. For this reason, callers of this
- * class must pair each hidl_memory object with an associated key. For
- * efficiency, if two hidl_memory objects represent the same underlying
- * buffer, they must use the same key.
+ * Whenever a hidl_memory object is copied, it will duplicate the underlying file descriptor.
+ * Because the NN runtime may copy memory objects on each execution, it is difficult to
+ * associate a copied memory object with a previously cached one. For this reason, this cache
+ * identifies entries by the nn::SharedMemory object passed to cacheMemory (see
+ * mMemoryIdToSlot). For efficiency, callers should pass the same nn::SharedMemory object each
+ * time the same underlying buffer is used, so that the existing cache slot can be reused.
+ *
+ * This class is thread-safe.
*/
- class ExecutionBurstCallback : public hardware::neuralnetworks::V1_2::IBurstCallback {
- DISALLOW_COPY_AND_ASSIGN(ExecutionBurstCallback);
+ class MemoryCache : public std::enable_shared_from_this<MemoryCache> {
+ struct PrivateConstructorTag {};
public:
- ExecutionBurstCallback() = default;
+ using Task = std::function<void()>;
+ using Cleanup = base::ScopeGuard<Task>;
+ using SharedCleanup = std::shared_ptr<const Cleanup>;
+ using WeakCleanup = std::weak_ptr<const Cleanup>;
- hardware::Return<void> getMemories(const hardware::hidl_vec<int32_t>& slots,
- getMemories_cb cb) override;
+ // Custom constructor to pre-allocate cache sizes.
+ MemoryCache();
/**
- * This function performs one of two different actions:
- * 1) If a key corresponding to a memory resource is unrecognized by the
- * ExecutionBurstCallback object, the ExecutionBurstCallback object
- * will allocate a slot, bind the memory to the slot, and return the
- * slot identifier.
- * 2) If a key corresponding to a memory resource is recognized by the
- * ExecutionBurstCallback object, the ExecutionBurstCallback object
- * will return the existing slot identifier.
+ * Add a burst context to the MemoryCache object.
*
- * @param memories Memory resources used in an inference.
- * @param keys Unique identifiers where each element corresponds to a
- * memory resource element in "memories".
- * @return Unique slot identifiers where each returned slot element
- * corresponds to a memory resource element in "memories".
+ * If this method is called, it must be called before MemoryCache::cacheMemory or
+ * MemoryCache::getMemory is used.
+ *
+ * @param burstContext Burst context to be added to the MemoryCache object.
*/
- std::vector<int32_t> getSlots(const hardware::hidl_vec<hardware::hidl_memory>& memories,
- const std::vector<intptr_t>& keys);
+ void setBurstContext(sp<IBurstContext> burstContext);
- /*
- * This function performs two different actions:
- * 1) Removes an entry from the cache (if present), including the local
- * storage of the hidl_memory object. Note that this call does not
- * free any corresponding hidl_memory object in ExecutionBurstServer,
- * which is separately freed via IBurstContext::freeMemory.
- * 2) Return whether a cache entry was removed and which slot was removed if
- * found. If the key did not to correspond to any entry in the cache, a
- * slot number of 0 is returned. The slot number and whether the entry
- * existed is useful so the same slot can be freed in the
- * ExecutionBurstServer's cache via IBurstContext::freeMemory.
+ /**
+ * Cache a memory object in the MemoryCache object.
+ *
+ * @param memory Memory object to be cached while the returned `SharedCleanup` is alive.
+ * @return A pair of (1) a unique identifier for the cache entry and (2) a ref-counted
+ * "hold" object which preserves the cache entry as long as the hold object is alive.
*/
- std::pair<bool, int32_t> freeMemory(intptr_t key);
+ std::pair<int32_t, SharedCleanup> cacheMemory(const nn::SharedMemory& memory);
+
+ /**
+ * Get the memory object corresponding to a slot identifier.
+ *
+ * @param slot Slot which identifies the memory object to retrieve.
+ * @return The memory object corresponding to slot, otherwise GeneralError.
+ */
+ nn::GeneralResult<nn::SharedMemory> getMemory(int32_t slot);
private:
- int32_t getSlotLocked(const hardware::hidl_memory& memory, intptr_t key);
- int32_t allocateSlotLocked();
+ void freeMemory(const nn::SharedMemory& memory);
+ int32_t allocateSlotLocked() REQUIRES(mMutex);
std::mutex mMutex;
- std::stack<int32_t, std::vector<int32_t>> mFreeSlots;
- std::map<intptr_t, int32_t> mMemoryIdToSlot;
- std::vector<hardware::hidl_memory> mMemoryCache;
+ std::condition_variable mCond;
+ sp<IBurstContext> mBurstContext GUARDED_BY(mMutex);
+ std::stack<int32_t, std::vector<int32_t>> mFreeSlots GUARDED_BY(mMutex);
+ std::map<nn::SharedMemory, int32_t> mMemoryIdToSlot GUARDED_BY(mMutex);
+ std::vector<nn::SharedMemory> mMemoryCache GUARDED_BY(mMutex);
+ std::vector<WeakCleanup> mCacheCleaner GUARDED_BY(mMutex);
+ };
+
+ /**
+ * HIDL Callback class to pass memory objects to the Burst server when given corresponding
+ * slots.
+ */
+ class ExecutionBurstCallback : public IBurstCallback {
+ public:
+ // Precondition: memoryCache must be non-null.
+ explicit ExecutionBurstCallback(const std::shared_ptr<MemoryCache>& memoryCache);
+
+ // See IBurstCallback::getMemories for information on this method.
+ Return<void> getMemories(const hidl_vec<int32_t>& slots, getMemories_cb cb) override;
+
+ private:
+ const std::weak_ptr<MemoryCache> kMemoryCache;
};
/**
* Creates a burst controller on a prepared model.
*
- * Prefer this over ExecutionBurstController's constructor.
- *
* @param preparedModel Model prepared for execution to execute on.
- * @param pollingTimeWindow How much time (in microseconds) the
- * ExecutionBurstController is allowed to poll the FMQ before waiting on
- * the blocking futex. Polling may result in lower latencies at the
- * potential cost of more power usage.
+ * @param pollingTimeWindow How much time (in microseconds) the ExecutionBurstController is
+ * allowed to poll the FMQ before waiting on the blocking futex. Polling may result in lower
+ * latencies at the potential cost of more power usage.
* @return ExecutionBurstController Execution burst controller object.
*/
- static std::unique_ptr<ExecutionBurstController> create(
- const sp<hardware::neuralnetworks::V1_2::IPreparedModel>& preparedModel,
+ static nn::GeneralResult<std::shared_ptr<const ExecutionBurstController>> create(
+ const sp<IPreparedModel>& preparedModel, FallbackFunction fallback,
std::chrono::microseconds pollingTimeWindow);
- // prefer calling ExecutionBurstController::create
- ExecutionBurstController(const std::shared_ptr<RequestChannelSender>& requestChannelSender,
- const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
- const sp<hardware::neuralnetworks::V1_2::IBurstContext>& burstContext,
- const sp<ExecutionBurstCallback>& callback,
- const sp<hardware::hidl_death_recipient>& deathHandler = nullptr);
+ ExecutionBurstController(PrivateConstructorTag tag, FallbackFunction fallback,
+ std::unique_ptr<RequestChannelSender> requestChannelSender,
+ std::unique_ptr<ResultChannelReceiver> resultChannelReceiver,
+ sp<ExecutionBurstCallback> callback, sp<IBurstContext> burstContext,
+ std::shared_ptr<MemoryCache> memoryCache,
+ neuralnetworks::utils::DeathHandler deathHandler);
- // explicit destructor to unregister the death recipient
- ~ExecutionBurstController();
+ // See IBurst::cacheMemory for information on this method.
+ OptionalCacheHold cacheMemory(const nn::SharedMemory& memory) const override;
- /**
- * Execute a request on a model.
- *
- * @param request Arguments to be executed on a model.
- * @param measure Whether to collect timing measurements, either YES or NO
- * @param memoryIds Identifiers corresponding to each memory object in the
- * request's pools.
- * @return A tuple of:
- * - result code of the execution
- * - dynamic output shapes from the execution
- * - any execution time measurements of the execution
- * - whether or not a failed burst execution should be re-run using a
- * different path (e.g., IPreparedModel::executeSynchronously)
- */
- std::tuple<int, std::vector<hardware::neuralnetworks::V1_2::OutputShape>,
- hardware::neuralnetworks::V1_2::Timing, bool>
- compute(const hardware::neuralnetworks::V1_0::Request& request,
- hardware::neuralnetworks::V1_2::MeasureTiming measure,
- const std::vector<intptr_t>& memoryIds);
-
- /**
- * Propagate a user's freeing of memory to the service.
- *
- * @param key Key corresponding to the memory object.
- */
- void freeMemory(intptr_t key);
+ // See IBurst::execute for information on this method.
+ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
+ const nn::Request& request, nn::MeasureTiming measure) const override;
private:
- std::mutex mMutex;
- const std::shared_ptr<RequestChannelSender> mRequestChannelSender;
- const std::shared_ptr<ResultChannelReceiver> mResultChannelReceiver;
- const sp<hardware::neuralnetworks::V1_2::IBurstContext> mBurstContext;
- const sp<ExecutionBurstCallback> mMemoryCache;
- const sp<hardware::hidl_death_recipient> mDeathHandler;
+ mutable std::atomic_flag mExecutionInFlight = ATOMIC_FLAG_INIT;
+ const FallbackFunction kFallback;
+ const std::unique_ptr<RequestChannelSender> mRequestChannelSender;
+ const std::unique_ptr<ResultChannelReceiver> mResultChannelReceiver;
+ const sp<ExecutionBurstCallback> mBurstCallback;
+ const sp<IBurstContext> mBurstContext;
+ const std::shared_ptr<MemoryCache> mMemoryCache;
+ // `kDeathHandler` must come after `mRequestChannelSender` and `mResultChannelReceiver` because
+ // it holds references to both objects.
+ const neuralnetworks::utils::DeathHandler kDeathHandler;
};
-} // namespace android::nn
+} // namespace android::hardware::neuralnetworks::V1_2::utils
-#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_EXECUTION_BURST_CONTROLLER_H
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_EXECUTION_BURST_CONTROLLER_H
diff --git a/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstServer.h b/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstServer.h
index 2e109b2..f7926f5 100644
--- a/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstServer.h
+++ b/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstServer.h
@@ -14,19 +14,22 @@
* limitations under the License.
*/
-#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_EXECUTION_BURST_SERVER_H
-#define ANDROID_FRAMEWORKS_ML_NN_COMMON_EXECUTION_BURST_SERVER_H
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_EXECUTION_BURST_SERVER_H
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_EXECUTION_BURST_SERVER_H
#include "ExecutionBurstUtils.h"
-#include <android-base/macros.h>
+#include <android-base/thread_annotations.h>
#include <android/hardware/neuralnetworks/1.0/types.h>
-#include <android/hardware/neuralnetworks/1.1/types.h>
#include <android/hardware/neuralnetworks/1.2/IBurstCallback.h>
#include <android/hardware/neuralnetworks/1.2/IPreparedModel.h>
#include <android/hardware/neuralnetworks/1.2/types.h>
#include <fmq/MessageQueue.h>
#include <hidl/MQDescriptor.h>
+#include <nnapi/IBurst.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/ProtectCallback.h>
#include <atomic>
#include <chrono>
@@ -36,84 +39,61 @@
#include <tuple>
#include <vector>
-namespace android::nn {
+namespace android::hardware::neuralnetworks::V1_2::utils {
/**
- * The ExecutionBurstServer class is responsible for waiting for and
- * deserializing a request object from a FMQ, performing the inference, and
- * serializing the result back across another FMQ.
+ * The ExecutionBurstServer class is responsible for waiting for and deserializing a request object
+ * from a FMQ, performing the inference, and serializing the result back across another FMQ.
*/
-class ExecutionBurstServer : public hardware::neuralnetworks::V1_2::IBurstContext {
- DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstServer);
+class ExecutionBurstServer : public IBurstContext {
+ struct PrivateConstructorTag {};
public:
/**
- * IBurstExecutorWithCache is a callback object passed to
- * ExecutionBurstServer's factory function that is used to perform an
- * execution. Because some memory resources are needed across multiple
- * executions, this object also contains a local cache that can directly be
- * used in the execution.
+ * Class to cache the memory objects for a burst object.
*
- * ExecutionBurstServer will never access its IBurstExecutorWithCache object
- * with concurrent calls.
+ * This class is thread-safe.
*/
- class IBurstExecutorWithCache {
- DISALLOW_COPY_AND_ASSIGN(IBurstExecutorWithCache);
-
+ class MemoryCache {
public:
- IBurstExecutorWithCache() = default;
- virtual ~IBurstExecutorWithCache() = default;
+ // Precondition: burstExecutor != nullptr
+ // Precondition: burstCallback != nullptr
+ MemoryCache(nn::SharedBurst burstExecutor, sp<IBurstCallback> burstCallback);
/**
- * Checks if a cache entry specified by a slot is present in the cache.
+ * Get the cached memory objects corresponding to provided slot identifiers.
*
- * @param slot Identifier of the cache entry.
- * @return 'true' if the cache entry is present in the cache, 'false'
- * otherwise.
+ * If the slot entry is not present in the cache, this class will use IBurstCallback to
+ * retrieve those entries that are not present in the cache, then cache them.
+ *
+ * @param slots Identifiers of memory objects to be retrieved.
+ * @return A vector where each element is the memory object and a ref-counted cache "hold"
+ * object to preserve the cache entry of the IBurst object as long as the "hold" object
+ * is alive, otherwise GeneralError. Each element of the vector corresponds to the
+ * element of slots.
*/
- virtual bool isCacheEntryPresent(int32_t slot) const = 0;
+ nn::GeneralResult<std::vector<std::pair<nn::SharedMemory, nn::IBurst::OptionalCacheHold>>>
+ getCacheEntries(const std::vector<int32_t>& slots);
/**
- * Adds an entry specified by a slot to the cache.
+ * Remove an entry from the cache.
*
- * The caller of this function must ensure that the cache entry that is
- * being added is not already present in the cache. This can be checked
- * via isCacheEntryPresent.
- *
- * @param memory Memory resource to be cached.
- * @param slot Slot identifier corresponding to the memory resource.
+ * @param slot Identifier of the memory object to be removed from the cache.
*/
- virtual void addCacheEntry(const hardware::hidl_memory& memory, int32_t slot) = 0;
+ void removeCacheEntry(int32_t slot);
- /**
- * Removes an entry specified by a slot from the cache.
- *
- * If the cache entry corresponding to the slot number does not exist,
- * the call does nothing.
- *
- * @param slot Slot identifier corresponding to the memory resource.
- */
- virtual void removeCacheEntry(int32_t slot) = 0;
+ private:
+ nn::GeneralResult<void> ensureCacheEntriesArePresentLocked(
+ const std::vector<int32_t>& slots) REQUIRES(mMutex);
+ nn::GeneralResult<std::pair<nn::SharedMemory, nn::IBurst::OptionalCacheHold>>
+ getCacheEntryLocked(int32_t slot) REQUIRES(mMutex);
+ void addCacheEntryLocked(int32_t slot, nn::SharedMemory memory) REQUIRES(mMutex);
- /**
- * Perform an execution.
- *
- * @param request Request object with inputs and outputs specified.
- * Request::pools is empty, and DataLocation::poolIndex instead
- * refers to the 'slots' argument as if it were Request::pools.
- * @param slots Slots corresponding to the cached memory entries to be
- * used.
- * @param measure Whether timing information is requested for the
- * execution.
- * @return Result of the execution, including the status of the
- * execution, dynamic output shapes, and any timing information.
- */
- virtual std::tuple<hardware::neuralnetworks::V1_0::ErrorStatus,
- hardware::hidl_vec<hardware::neuralnetworks::V1_2::OutputShape>,
- hardware::neuralnetworks::V1_2::Timing>
- execute(const hardware::neuralnetworks::V1_0::Request& request,
- const std::vector<int32_t>& slots,
- hardware::neuralnetworks::V1_2::MeasureTiming measure) = 0;
+ std::mutex mMutex;
+ std::map<int32_t, std::pair<nn::SharedMemory, nn::IBurst::OptionalCacheHold>> mCache
+ GUARDED_BY(mMutex);
+ nn::SharedBurst kBurstExecutor;
+ const sp<IBurstCallback> kBurstCallback;
};
/**
@@ -124,85 +104,52 @@
* 2) Execute a model with the given information
* 3) Send the result to the created FMQ
*
- * @param callback Callback used to retrieve memories corresponding to
- * unrecognized slots.
- * @param requestChannel Input FMQ channel through which the client passes the
- * request to the service.
- * @param resultChannel Output FMQ channel from which the client can retrieve
- * the result of the execution.
- * @param executorWithCache Object which maintains a local cache of the
- * memory pools and executes using the cached memory pools.
- * @param pollingTimeWindow How much time (in microseconds) the
- * ExecutionBurstServer is allowed to poll the FMQ before waiting on
- * the blocking futex. Polling may result in lower latencies at the
- * potential cost of more power usage.
- * @result IBurstContext Handle to the burst context.
- */
- static sp<ExecutionBurstServer> create(
- const sp<hardware::neuralnetworks::V1_2::IBurstCallback>& callback,
- const FmqRequestDescriptor& requestChannel, const FmqResultDescriptor& resultChannel,
- std::shared_ptr<IBurstExecutorWithCache> executorWithCache,
- std::chrono::microseconds pollingTimeWindow = std::chrono::microseconds{0});
-
- /**
- * Create automated context to manage FMQ-based executions.
- *
- * This function is intended to be used by a service to automatically:
- * 1) Receive data from a provided FMQ
- * 2) Execute a model with the given information
- * 3) Send the result to the created FMQ
- *
- * @param callback Callback used to retrieve memories corresponding to
- * unrecognized slots.
- * @param requestChannel Input FMQ channel through which the client passes the
- * request to the service.
- * @param resultChannel Output FMQ channel from which the client can retrieve
- * the result of the execution.
- * @param preparedModel PreparedModel that the burst object was created from.
- * IPreparedModel::executeSynchronously will be used to perform the
+ * @param callback Callback used to retrieve memories corresponding to unrecognized slots.
+ * @param requestChannel Input FMQ channel through which the client passes the request to the
+ * service.
+ * @param resultChannel Output FMQ channel from which the client can retrieve the result of the
* execution.
- * @param pollingTimeWindow How much time (in microseconds) the
- * ExecutionBurstServer is allowed to poll the FMQ before waiting on
- * the blocking futex. Polling may result in lower latencies at the
- * potential cost of more power usage.
- * @result IBurstContext Handle to the burst context.
+ * @param burstExecutor Object which maintains a local cache of the memory pools and executes
+ * using the cached memory pools.
+ * @param pollingTimeWindow How much time (in microseconds) the ExecutionBurstServer is allowed
+ * to poll the FMQ before waiting on the blocking futex. Polling may result in lower
+ * latencies at the potential cost of more power usage.
+ * @return IBurstContext Handle to the burst context.
*/
- static sp<ExecutionBurstServer> create(
- const sp<hardware::neuralnetworks::V1_2::IBurstCallback>& callback,
- const FmqRequestDescriptor& requestChannel, const FmqResultDescriptor& resultChannel,
- hardware::neuralnetworks::V1_2::IPreparedModel* preparedModel,
+ static nn::GeneralResult<sp<ExecutionBurstServer>> create(
+ const sp<IBurstCallback>& callback,
+ const MQDescriptorSync<FmqRequestDatum>& requestChannel,
+ const MQDescriptorSync<FmqResultDatum>& resultChannel, nn::SharedBurst burstExecutor,
std::chrono::microseconds pollingTimeWindow = std::chrono::microseconds{0});
- ExecutionBurstServer(const sp<hardware::neuralnetworks::V1_2::IBurstCallback>& callback,
+ ExecutionBurstServer(PrivateConstructorTag tag, const sp<IBurstCallback>& callback,
std::unique_ptr<RequestChannelReceiver> requestChannel,
std::unique_ptr<ResultChannelSender> resultChannel,
- std::shared_ptr<IBurstExecutorWithCache> cachedExecutor);
+ nn::SharedBurst burstExecutor);
~ExecutionBurstServer();
- // Used by the NN runtime to preemptively remove any stored memory.
- hardware::Return<void> freeMemory(int32_t slot) override;
+ // Used by the NN runtime to preemptively remove any stored memory. See
+ // IBurstContext::freeMemory for more information.
+ Return<void> freeMemory(int32_t slot) override;
private:
- // Ensures all cache entries contained in mExecutorWithCache are present in
- // the cache. If they are not present, they are retrieved (via
- // IBurstCallback::getMemories) and added to mExecutorWithCache.
- //
- // This method is locked via mMutex when it is called.
- void ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots);
-
- // Work loop that will continue processing execution requests until the
- // ExecutionBurstServer object is freed.
+ // Work loop that will continue processing execution requests until the ExecutionBurstServer
+ // object is freed.
void task();
+ nn::ExecutionResult<std::pair<hidl_vec<OutputShape>, Timing>> execute(
+ const V1_0::Request& requestWithoutPools, const std::vector<int32_t>& slotsOfPools,
+ MeasureTiming measure);
+
std::thread mWorker;
- std::mutex mMutex;
std::atomic<bool> mTeardown{false};
- const sp<hardware::neuralnetworks::V1_2::IBurstCallback> mCallback;
+ const sp<IBurstCallback> mCallback;
const std::unique_ptr<RequestChannelReceiver> mRequestChannelReceiver;
const std::unique_ptr<ResultChannelSender> mResultChannelSender;
- const std::shared_ptr<IBurstExecutorWithCache> mExecutorWithCache;
+ const nn::SharedBurst mBurstExecutor;
+ MemoryCache mMemoryCache;
};
-} // namespace android::nn
+} // namespace android::hardware::neuralnetworks::V1_2::utils
-#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_EXECUTION_BURST_SERVER_H
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_EXECUTION_BURST_SERVER_H
diff --git a/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstUtils.h b/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstUtils.h
index 8a41591..c662bc3 100644
--- a/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstUtils.h
+++ b/neuralnetworks/1.2/utils/include/nnapi/hal/1.2/ExecutionBurstUtils.h
@@ -18,15 +18,16 @@
#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_EXECUTION_BURST_UTILS_H
#include <android/hardware/neuralnetworks/1.0/types.h>
-#include <android/hardware/neuralnetworks/1.1/types.h>
#include <android/hardware/neuralnetworks/1.2/types.h>
#include <fmq/MessageQueue.h>
#include <hidl/MQDescriptor.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/ProtectCallback.h>
#include <atomic>
#include <chrono>
#include <memory>
-#include <optional>
#include <tuple>
#include <utility>
#include <vector>
@@ -38,159 +39,139 @@
*/
constexpr const size_t kExecutionBurstChannelLength = 1024;
-using FmqRequestDescriptor = MQDescriptorSync<FmqRequestDatum>;
-using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
+/**
+ * Get how long the burst controller should poll while waiting for results to be returned.
+ *
+ * This time can be affected by the property "debug.nn.burst-controller-polling-window".
+ *
+ * @return Polling time in microseconds.
+ */
+std::chrono::microseconds getBurstControllerPollingTimeWindow();
+
+/**
+ * Get how long the burst server should poll while waiting for a request to be received.
+ *
+ * This time can be affected by the property "debug.nn.burst-server-polling-window".
+ *
+ * @return Polling time in microseconds.
+ */
+std::chrono::microseconds getBurstServerPollingTimeWindow();
/**
* Function to serialize a request.
*
- * Prefer calling RequestChannelSender::send.
- *
* @param request Request object without the pool information.
* @param measure Whether to collect timing information for the execution.
- * @param memoryIds Slot identifiers corresponding to memory resources for the
- * request.
+ * @param slots Slot identifiers corresponding to memory resources for the request.
* @return Serialized FMQ request data.
*/
-std::vector<hardware::neuralnetworks::V1_2::FmqRequestDatum> serialize(
- const hardware::neuralnetworks::V1_0::Request& request,
- hardware::neuralnetworks::V1_2::MeasureTiming measure, const std::vector<int32_t>& slots);
+std::vector<FmqRequestDatum> serialize(const V1_0::Request& request, MeasureTiming measure,
+ const std::vector<int32_t>& slots);
/**
* Deserialize the FMQ request data.
*
- * The three resulting fields are the Request object (where Request::pools is
- * empty), slot identifiers (which are stand-ins for Request::pools), and
- * whether timing information must be collected for the run.
+ * The three resulting fields are the Request object (where Request::pools is empty), slot
+ * identifiers (which are stand-ins for Request::pools), and whether timing information must be
+ * collected for the run.
*
* @param data Serialized FMQ request data.
- * @return Request object if successfully deserialized, std::nullopt otherwise.
+ * @return Request object if successfully deserialized, otherwise an error message.
*/
-std::optional<std::tuple<hardware::neuralnetworks::V1_0::Request, std::vector<int32_t>,
- hardware::neuralnetworks::V1_2::MeasureTiming>>
-deserialize(const std::vector<hardware::neuralnetworks::V1_2::FmqRequestDatum>& data);
+nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, MeasureTiming>> deserialize(
+ const std::vector<FmqRequestDatum>& data);
/**
* Function to serialize results.
*
- * Prefer calling ResultChannelSender::send.
- *
* @param errorStatus Status of the execution.
* @param outputShapes Dynamic shapes of the output tensors.
* @param timing Timing information of the execution.
* @return Serialized FMQ result data.
*/
-std::vector<hardware::neuralnetworks::V1_2::FmqResultDatum> serialize(
- hardware::neuralnetworks::V1_0::ErrorStatus errorStatus,
- const std::vector<hardware::neuralnetworks::V1_2::OutputShape>& outputShapes,
- hardware::neuralnetworks::V1_2::Timing timing);
+std::vector<FmqResultDatum> serialize(V1_0::ErrorStatus errorStatus,
+ const std::vector<OutputShape>& outputShapes, Timing timing);
/**
* Deserialize the FMQ result data.
*
- * The three resulting fields are the status of the execution, the dynamic
- * shapes of the output tensors, and the timing information of the execution.
+ * The three resulting fields are the status of the execution, the dynamic shapes of the output
+ * tensors, and the timing information of the execution.
*
* @param data Serialized FMQ result data.
- * @return Result object if successfully deserialized, std::nullopt otherwise.
+ * @return Result object if successfully deserialized, otherwise an error message.
*/
-std::optional<std::tuple<hardware::neuralnetworks::V1_0::ErrorStatus,
- std::vector<hardware::neuralnetworks::V1_2::OutputShape>,
- hardware::neuralnetworks::V1_2::Timing>>
-deserialize(const std::vector<hardware::neuralnetworks::V1_2::FmqResultDatum>& data);
+nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<OutputShape>, Timing>> deserialize(
+ const std::vector<FmqResultDatum>& data);
/**
- * Convert result code to error status.
- *
- * @param resultCode Result code to be converted.
- * @return ErrorStatus Resultant error status.
+ * RequestChannelSender is responsible for serializing the request packet of information, sending
+ * it on the request channel, and signaling that the data is available.
*/
-hardware::neuralnetworks::V1_0::ErrorStatus legacyConvertResultCodeToErrorStatus(int resultCode);
-
-/**
- * RequestChannelSender is responsible for serializing the result packet of
- * information, sending it on the result channel, and signaling that the data is
- * available.
- */
-class RequestChannelSender {
- using FmqRequestDescriptor =
- hardware::MQDescriptorSync<hardware::neuralnetworks::V1_2::FmqRequestDatum>;
- using FmqRequestChannel =
- hardware::MessageQueue<hardware::neuralnetworks::V1_2::FmqRequestDatum,
- hardware::kSynchronizedReadWrite>;
+class RequestChannelSender final : public neuralnetworks::utils::IProtectedCallback {
+ struct PrivateConstructorTag {};
public:
/**
* Create the sending end of a request channel.
*
- * Prefer this call over the constructor.
- *
* @param channelLength Number of elements in the FMQ.
- * @return A pair of ResultChannelReceiver and the FMQ descriptor on
- * successful creation, both nullptr otherwise.
+ * @return A pair of RequestChannelSender and the FMQ descriptor on successful creation,
+ * GeneralError otherwise.
*/
- static std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*> create(
- size_t channelLength);
+ static nn::GeneralResult<std::pair<std::unique_ptr<RequestChannelSender>,
+ const MQDescriptorSync<FmqRequestDatum>*>>
+ create(size_t channelLength);
/**
* Send the request to the channel.
*
* @param request Request object without the pool information.
* @param measure Whether to collect timing information for the execution.
- * @param memoryIds Slot identifiers corresponding to memory resources for
- * the request.
- * @return 'true' on successful send, 'false' otherwise.
+ * @param slots Slot identifiers corresponding to memory resources for the request.
+ * @return An empty `Result` on successful send, otherwise an error message.
*/
- bool send(const hardware::neuralnetworks::V1_0::Request& request,
- hardware::neuralnetworks::V1_2::MeasureTiming measure,
- const std::vector<int32_t>& slots);
+ nn::Result<void> send(const V1_0::Request& request, MeasureTiming measure,
+ const std::vector<int32_t>& slots);
/**
- * Method to mark the channel as invalid, causing all future calls to
- * RequestChannelSender::send to immediately return false without attempting
- * to send a message across the FMQ.
+ * Method to mark the channel as invalid, causing all future calls to RequestChannelSender::send
+ * to immediately return an error without attempting to send a message across the FMQ.
*/
- void invalidate();
+ void notifyAsDeadObject() override;
// prefer calling RequestChannelSender::send
- bool sendPacket(const std::vector<hardware::neuralnetworks::V1_2::FmqRequestDatum>& packet);
+ nn::Result<void> sendPacket(const std::vector<FmqRequestDatum>& packet);
- RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel);
+ RequestChannelSender(PrivateConstructorTag tag, size_t channelLength);
private:
- const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
+ MessageQueue<FmqRequestDatum, kSynchronizedReadWrite> mFmqRequestChannel;
std::atomic<bool> mValid{true};
};
/**
- * RequestChannelReceiver is responsible for waiting on the channel until the
- * packet is available, extracting the packet from the channel, and
- * deserializing the packet.
+ * RequestChannelReceiver is responsible for waiting on the channel until the packet is available,
+ * extracting the packet from the channel, and deserializing the packet.
*
- * Because the receiver can wait on a packet that may never come (e.g., because
- * the sending side of the packet has been closed), this object can be
- * invalidated, unblocking the receiver.
+ * Because the receiver can wait on a packet that may never come (e.g., because the sending side of
+ * the packet has been closed), this object can be invalidated, unblocking the receiver.
*/
-class RequestChannelReceiver {
- using FmqRequestChannel =
- hardware::MessageQueue<hardware::neuralnetworks::V1_2::FmqRequestDatum,
- hardware::kSynchronizedReadWrite>;
+class RequestChannelReceiver final {
+ struct PrivateConstructorTag {};
public:
/**
* Create the receiving end of a request channel.
*
- * Prefer this call over the constructor.
- *
* @param requestChannel Descriptor for the request channel.
- * @param pollingTimeWindow How much time (in microseconds) the
- * RequestChannelReceiver is allowed to poll the FMQ before waiting on
- * the blocking futex. Polling may result in lower latencies at the
- * potential cost of more power usage.
+ * @param pollingTimeWindow How much time (in microseconds) the RequestChannelReceiver is
+ * allowed to poll the FMQ before waiting on the blocking futex. Polling may result in lower
+ * latencies at the potential cost of more power usage.
- * @return RequestChannelReceiver on successful creation, nullptr otherwise.
+ * @return RequestChannelReceiver on successful creation, GeneralError otherwise.
*/
- static std::unique_ptr<RequestChannelReceiver> create(
- const FmqRequestDescriptor& requestChannel,
+ static nn::GeneralResult<std::unique_ptr<RequestChannelReceiver>> create(
+ const MQDescriptorSync<FmqRequestDatum>& requestChannel,
std::chrono::microseconds pollingTimeWindow);
/**
@@ -200,49 +181,45 @@
* 1) The packet has been retrieved, or
* 2) The receiver has been invalidated
*
- * @return Request object if successfully received, std::nullopt if error or
- * if the receiver object was invalidated.
+ * @return Request object if successfully received, an appropriate message if error or if the
+ * receiver object was invalidated.
*/
- std::optional<std::tuple<hardware::neuralnetworks::V1_0::Request, std::vector<int32_t>,
- hardware::neuralnetworks::V1_2::MeasureTiming>>
- getBlocking();
+ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, MeasureTiming>> getBlocking();
/**
- * Method to mark the channel as invalid, unblocking any current or future
- * calls to RequestChannelReceiver::getBlocking.
+ * Method to mark the channel as invalid, unblocking any current or future calls to
+ * RequestChannelReceiver::getBlocking.
*/
void invalidate();
- RequestChannelReceiver(std::unique_ptr<FmqRequestChannel> fmqRequestChannel,
+ // Only callable with PrivateConstructorTag, which is private to this class; use create() instead.
+ RequestChannelReceiver(PrivateConstructorTag tag,
+ const MQDescriptorSync<FmqRequestDatum>& requestChannel,
std::chrono::microseconds pollingTimeWindow);
private:
- std::optional<std::vector<hardware::neuralnetworks::V1_2::FmqRequestDatum>> getPacketBlocking();
+ nn::Result<std::vector<FmqRequestDatum>> getPacketBlocking();
- const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
+ MessageQueue<FmqRequestDatum, kSynchronizedReadWrite> mFmqRequestChannel;
std::atomic<bool> mTeardown{false};
const std::chrono::microseconds kPollingTimeWindow;
};
/**
- * ResultChannelSender is responsible for serializing the result packet of
- * information, sending it on the result channel, and signaling that the data is
- * available.
+ * ResultChannelSender is responsible for serializing the result packet of information, sending it
+ * on the result channel, and signaling that the data is available.
*/
-class ResultChannelSender {
- using FmqResultChannel = hardware::MessageQueue<hardware::neuralnetworks::V1_2::FmqResultDatum,
- hardware::kSynchronizedReadWrite>;
+class ResultChannelSender final {
+ struct PrivateConstructorTag {};
public:
/**
* Create the sending end of a result channel.
*
- * Prefer this call over the constructor.
- *
* @param resultChannel Descriptor for the result channel.
- * @return ResultChannelSender on successful creation, nullptr otherwise.
+ * @return ResultChannelSender on successful creation, GeneralError otherwise.
*/
- static std::unique_ptr<ResultChannelSender> create(const FmqResultDescriptor& resultChannel);
+ static nn::GeneralResult<std::unique_ptr<ResultChannelSender>> create(
+ const MQDescriptorSync<FmqResultDatum>& resultChannel);
/**
* Send the result to the channel.
@@ -250,52 +227,44 @@
* @param errorStatus Status of the execution.
* @param outputShapes Dynamic shapes of the output tensors.
* @param timing Timing information of the execution.
- * @return 'true' on successful send, 'false' otherwise.
*/
- bool send(hardware::neuralnetworks::V1_0::ErrorStatus errorStatus,
- const std::vector<hardware::neuralnetworks::V1_2::OutputShape>& outputShapes,
- hardware::neuralnetworks::V1_2::Timing timing);
+ void send(V1_0::ErrorStatus errorStatus, const std::vector<OutputShape>& outputShapes,
+ Timing timing);
// prefer calling ResultChannelSender::send
- bool sendPacket(const std::vector<hardware::neuralnetworks::V1_2::FmqResultDatum>& packet);
+ void sendPacket(const std::vector<FmqResultDatum>& packet);
- ResultChannelSender(std::unique_ptr<FmqResultChannel> fmqResultChannel);
+ // Only callable with PrivateConstructorTag, which is private to this class; use create() instead.
+ ResultChannelSender(PrivateConstructorTag tag,
+ const MQDescriptorSync<FmqResultDatum>& resultChannel);
private:
- const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
+ MessageQueue<FmqResultDatum, kSynchronizedReadWrite> mFmqResultChannel;
};
/**
- * ResultChannelReceiver is responsible for waiting on the channel until the
- * packet is available, extracting the packet from the channel, and
- * deserializing the packet.
+ * ResultChannelReceiver is responsible for waiting on the channel until the packet is available,
+ * extracting the packet from the channel, and deserializing the packet.
*
- * Because the receiver can wait on a packet that may never come (e.g., because
- * the sending side of the packet has been closed), this object can be
- * invalidated, unblocking the receiver.
+ * Because the receiver can wait on a packet that may never come (e.g., because the sending side of
+ * the packet has been closed), this object can be invalidated, unblocking the receiver.
*/
-class ResultChannelReceiver {
- using FmqResultDescriptor =
- hardware::MQDescriptorSync<hardware::neuralnetworks::V1_2::FmqResultDatum>;
- using FmqResultChannel = hardware::MessageQueue<hardware::neuralnetworks::V1_2::FmqResultDatum,
- hardware::kSynchronizedReadWrite>;
+class ResultChannelReceiver final : public neuralnetworks::utils::IProtectedCallback {
+ struct PrivateConstructorTag {};
public:
/**
* Create the receiving end of a result channel.
*
- * Prefer this call over the constructor.
- *
* @param channelLength Number of elements in the FMQ.
- * @param pollingTimeWindow How much time (in microseconds) the
- * ResultChannelReceiver is allowed to poll the FMQ before waiting on
- * the blocking futex. Polling may result in lower latencies at the
- * potential cost of more power usage.
- * @return A pair of ResultChannelReceiver and the FMQ descriptor on
- * successful creation, both nullptr otherwise.
+ * @param pollingTimeWindow How much time (in microseconds) the ResultChannelReceiver is allowed
+ * to poll the FMQ before waiting on the blocking futex. Polling may result in lower
+ * latencies at the potential cost of more power usage.
+ * @return A pair of ResultChannelReceiver and the FMQ descriptor on successful creation, or
+ * GeneralError otherwise.
*/
- static std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*> create(
- size_t channelLength, std::chrono::microseconds pollingTimeWindow);
+ static nn::GeneralResult<std::pair<std::unique_ptr<ResultChannelReceiver>,
+ const MQDescriptorSync<FmqResultDatum>*>>
+ create(size_t channelLength, std::chrono::microseconds pollingTimeWindow);
/**
* Get the result from the channel.
@@ -304,28 +273,25 @@
* 1) The packet has been retrieved, or
* 2) The receiver has been invalidated
*
- * @return Result object if successfully received, std::nullopt if error or
+ * @return Result object if successfully received, otherwise an appropriate message if error or
* if the receiver object was invalidated.
*/
- std::optional<std::tuple<hardware::neuralnetworks::V1_0::ErrorStatus,
- std::vector<hardware::neuralnetworks::V1_2::OutputShape>,
- hardware::neuralnetworks::V1_2::Timing>>
- getBlocking();
+ nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<OutputShape>, Timing>> getBlocking();
/**
- * Method to mark the channel as invalid, unblocking any current or future
- * calls to ResultChannelReceiver::getBlocking.
+ * Method to mark the channel as invalid, unblocking any current or future calls to
+ * ResultChannelReceiver::getBlocking.
*/
- void invalidate();
+ void notifyAsDeadObject() override;
// prefer calling ResultChannelReceiver::getBlocking
- std::optional<std::vector<hardware::neuralnetworks::V1_2::FmqResultDatum>> getPacketBlocking();
+ nn::Result<std::vector<FmqResultDatum>> getPacketBlocking();
- ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,
+ // Only callable with PrivateConstructorTag, which is private to this class; use create() instead.
+ ResultChannelReceiver(PrivateConstructorTag tag, size_t channelLength,
std::chrono::microseconds pollingTimeWindow);
private:
- const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
+ MessageQueue<FmqResultDatum, kSynchronizedReadWrite> mFmqResultChannel;
std::atomic<bool> mValid{true};
const std::chrono::microseconds kPollingTimeWindow;
};
diff --git a/neuralnetworks/1.2/utils/src/Conversions.cpp b/neuralnetworks/1.2/utils/src/Conversions.cpp
index 86a417a..2c45583 100644
--- a/neuralnetworks/1.2/utils/src/Conversions.cpp
+++ b/neuralnetworks/1.2/utils/src/Conversions.cpp
@@ -331,6 +331,10 @@
return validatedConvert(timing);
}
+// Converts a HIDL memory object to canonical SharedMemory through validatedConvert, like the
+// other convert() overloads in this file.
+GeneralResult<SharedMemory> convert(const hardware::hidl_memory& memory) {
+ return validatedConvert(memory);
+}
+
GeneralResult<std::vector<Extension>> convert(const hidl_vec<hal::V1_2::Extension>& extensions) {
return validatedConvert(extensions);
}
diff --git a/neuralnetworks/1.2/utils/src/ExecutionBurstController.cpp b/neuralnetworks/1.2/utils/src/ExecutionBurstController.cpp
index 2265861..eedf591 100644
--- a/neuralnetworks/1.2/utils/src/ExecutionBurstController.cpp
+++ b/neuralnetworks/1.2/utils/src/ExecutionBurstController.cpp
@@ -17,283 +17,321 @@
#define LOG_TAG "ExecutionBurstController"
#include "ExecutionBurstController.h"
+#include "ExecutionBurstUtils.h"
#include <android-base/logging.h>
+#include <android-base/thread_annotations.h>
+#include <nnapi/IBurst.h>
+#include <nnapi/IPreparedModel.h>
+#include <nnapi/Result.h>
+#include <nnapi/TypeUtils.h>
+#include <nnapi/Types.h>
+#include <nnapi/Validation.h>
+#include <nnapi/hal/1.0/Conversions.h>
+#include <nnapi/hal/HandleError.h>
+#include <nnapi/hal/ProtectCallback.h>
+#include <nnapi/hal/TransferValue.h>
#include <algorithm>
#include <cstring>
#include <limits>
#include <memory>
#include <string>
+#include <thread>
#include <tuple>
#include <utility>
#include <vector>
-#include "ExecutionBurstUtils.h"
-#include "HalInterfaces.h"
+#include "Callbacks.h"
+#include "Conversions.h"
#include "Tracing.h"
#include "Utils.h"
-namespace android::nn {
+namespace android::hardware::neuralnetworks::V1_2::utils {
namespace {
-class BurstContextDeathHandler : public hardware::hidl_death_recipient {
- public:
- using Callback = std::function<void()>;
-
- BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) {
- CHECK(onDeathCallback != nullptr);
+// Result callback for IPreparedModel::configureExecutionBurst. A failing `status` is reported
+// through HANDLE_HAL_STATUS, a null `burstContext` becomes GENERAL_FAILURE, and otherwise the
+// burst context is returned.
+nn::GeneralResult<sp<IBurstContext>> executionBurstResultCallback(
+ V1_0::ErrorStatus status, const sp<IBurstContext>& burstContext) {
+ HANDLE_HAL_STATUS(status) << "IPreparedModel::configureExecutionBurst failed with status "
+ << toString(status);
+ if (burstContext == nullptr) {
+ return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
+ << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
}
-
- void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
- LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!";
- mOnDeathCallback();
- }
-
- private:
- const Callback mOnDeathCallback;
-};
-
-} // anonymous namespace
-
-hardware::Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
- const hardware::hidl_vec<int32_t>& slots, getMemories_cb cb) {
- std::lock_guard<std::mutex> guard(mMutex);
-
- // get all memories
- hardware::hidl_vec<hardware::hidl_memory> memories(slots.size());
- std::transform(slots.begin(), slots.end(), memories.begin(), [this](int32_t slot) {
- return slot < mMemoryCache.size() ? mMemoryCache[slot] : hardware::hidl_memory{};
- });
-
- // ensure all memories are valid
- if (!std::all_of(memories.begin(), memories.end(),
- [](const hardware::hidl_memory& memory) { return memory.valid(); })) {
- cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
- return hardware::Void();
- }
-
- // return successful
- cb(V1_0::ErrorStatus::NONE, std::move(memories));
- return hardware::Void();
+ return burstContext;
}
-std::vector<int32_t> ExecutionBurstController::ExecutionBurstCallback::getSlots(
- const hardware::hidl_vec<hardware::hidl_memory>& memories,
- const std::vector<intptr_t>& keys) {
- std::lock_guard<std::mutex> guard(mMutex);
-
- // retrieve (or bind) all slots corresponding to memories
- std::vector<int32_t> slots;
- slots.reserve(memories.size());
- for (size_t i = 0; i < memories.size(); ++i) {
- slots.push_back(getSlotLocked(memories[i], keys[i]));
+// Resolves each entry of `slots` through memoryCache->getMemory() and converts it to a
+// hidl_memory. Fails on the first slot whose lookup or conversion fails, or whose converted
+// hidl_memory is not valid().
+nn::GeneralResult<hidl_vec<hidl_memory>> getMemoriesHelper(
+ const hidl_vec<int32_t>& slots,
+ const std::shared_ptr<ExecutionBurstController::MemoryCache>& memoryCache) {
+ hidl_vec<hidl_memory> memories(slots.size());
+ for (size_t i = 0; i < slots.size(); ++i) {
+ const int32_t slot = slots[i];
+ const auto memory = NN_TRY(memoryCache->getMemory(slot));
+ memories[i] = NN_TRY(V1_0::utils::unvalidatedConvert(memory));
+ if (!memories[i].valid()) {
+ return NN_ERROR() << "memory at slot " << slot << " is invalid";
+ }
}
- return slots;
+ return memories;
}
-std::pair<bool, int32_t> ExecutionBurstController::ExecutionBurstCallback::freeMemory(
- intptr_t key) {
- std::lock_guard<std::mutex> guard(mMutex);
+} // namespace
- auto iter = mMemoryIdToSlot.find(key);
- if (iter == mMemoryIdToSlot.end()) {
- return {false, 0};
- }
- const int32_t slot = iter->second;
- mMemoryIdToSlot.erase(key);
- mMemoryCache[slot] = {};
- mFreeSlots.push(slot);
- return {true, slot};
+// MemoryCache methods
+
+// Reserves room for kPreallocatedCount slots in the free-slot stack, the memory table, and the
+// cleaner table so the first slot allocations do not have to grow these containers.
+ExecutionBurstController::MemoryCache::MemoryCache() {
+ constexpr size_t kPreallocatedCount = 1024;
+ // std::stack has no reserve(), so reserve the underlying vector and move it into the stack.
+ std::vector<int32_t> freeSlotsSpace;
+ freeSlotsSpace.reserve(kPreallocatedCount);
+ mFreeSlots = std::stack<int32_t, std::vector<int32_t>>(std::move(freeSlotsSpace));
+ mMemoryCache.reserve(kPreallocatedCount);
+ mCacheCleaner.reserve(kPreallocatedCount);
}
-int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked(
- const hardware::hidl_memory& memory, intptr_t key) {
- auto iter = mMemoryIdToSlot.find(key);
- if (iter == mMemoryIdToSlot.end()) {
- const int32_t slot = allocateSlotLocked();
- mMemoryIdToSlot[key] = slot;
- mMemoryCache[slot] = memory;
- return slot;
- } else {
+// Stores the burst context so that freeMemory can tell the service when a slot is released.
+void ExecutionBurstController::MemoryCache::setBurstContext(sp<IBurstContext> burstContext) {
+ std::lock_guard guard(mMutex);
+ mBurstContext = std::move(burstContext);
+}
+
+// Returns the slot assigned to `memory` together with a shared cleanup handle, reusing the
+// existing slot when `memory` is already cached. The handle is the reference-counted
+// self-cleaning object built below: once its last reference goes away, its task calls
+// freeMemory(memory) if this MemoryCache is still alive.
+std::pair<int32_t, ExecutionBurstController::MemoryCache::SharedCleanup>
+ExecutionBurstController::MemoryCache::cacheMemory(const nn::SharedMemory& memory) {
+ std::unique_lock lock(mMutex);
+ base::ScopedLockAssertion lockAssert(mMutex);
+
+ // Use existing cache entry if (1) the Memory object is in the cache and (2) the cache entry is
+ // not currently being freed.
+ auto iter = mMemoryIdToSlot.find(memory);
+ while (iter != mMemoryIdToSlot.end()) {
const int32_t slot = iter->second;
- return slot;
+ if (auto cleaner = mCacheCleaner.at(slot).lock()) {
+ return std::make_pair(slot, std::move(cleaner));
+ }
+
+ // If the code reaches this point, the Memory object was in the cache, but is currently
+ // being destroyed. This code waits until the cache entry has been freed, then loops to
+ // ensure the cache entry has been freed or has been made present by another thread.
+ mCond.wait(lock);
+ iter = mMemoryIdToSlot.find(memory);
}
+
+ // Allocate a new cache entry.
+ const int32_t slot = allocateSlotLocked();
+ mMemoryIdToSlot[memory] = slot;
+ mMemoryCache[slot] = memory;
+
+ // Create reference-counted self-cleaning cache object.
+ auto self = weak_from_this();
+ Task cleanup = [memory, memoryCache = std::move(self)] {
+ if (const auto lock = memoryCache.lock()) {
+ lock->freeMemory(memory);
+ }
+ };
+ auto cleaner = std::make_shared<const Cleanup>(std::move(cleanup));
+ mCacheCleaner[slot] = cleaner;
+
+ return std::make_pair(slot, std::move(cleaner));
}
-int32_t ExecutionBurstController::ExecutionBurstCallback::allocateSlotLocked() {
+// Returns the memory cached at `slot`. Fails if `slot` is outside the cache or if the slot
+// currently holds no memory (it was released by freeMemory).
+nn::GeneralResult<nn::SharedMemory> ExecutionBurstController::MemoryCache::getMemory(int32_t slot) {
+ std::lock_guard guard(mMutex);
+ if (slot < 0 || static_cast<size_t>(slot) >= mMemoryCache.size()) {
+ return NN_ERROR() << "Invalid slot: " << slot << " vs " << mMemoryCache.size();
+ }
+ // freeMemory resets a released slot to an empty SharedMemory; report that explicitly instead
+ // of handing a null memory to the caller.
+ if (mMemoryCache[slot] == nullptr) {
+ return NN_ERROR() << "Slot " << slot << " does not hold a cached memory";
+ }
+ return mMemoryCache[slot];
+}
+
+// Releases the cache entry for `memory` and makes its slot available for reuse. This is invoked
+// by the self-cleaning handle created in cacheMemory once its last reference is dropped.
+void ExecutionBurstController::MemoryCache::freeMemory(const nn::SharedMemory& memory) {
+ {
+ std::lock_guard guard(mMutex);
+ const int32_t slot = mMemoryIdToSlot.at(memory);
+ if (mBurstContext) {
+ // Notify the service while holding mMutex, before the slot is pushed onto mFreeSlots,
+ // so cacheMemory cannot hand the slot out again before the free has been issued.
+ // The HIDL Return is checked explicitly: a failed Return that is never checked is
+ // fatal when destroyed, which would crash the client if the service has died.
+ const auto ret = mBurstContext->freeMemory(slot);
+ if (!ret.isOk()) {
+ LOG(ERROR) << "IBurstContext::freeMemory failed: " << ret.description();
+ }
+ }
+ mMemoryIdToSlot.erase(memory);
+ mMemoryCache[slot] = {};
+ mCacheCleaner[slot].reset();
+ mFreeSlots.push(slot);
+ }
+ // Wake any cacheMemory call that is waiting for this entry to finish being freed.
+ mCond.notify_all();
+}
+
+// Returns an unused slot, preferring a previously released one. Callers (e.g. cacheMemory)
+// hold mMutex. mMemoryCache and mCacheCleaner are grown together so they stay the same size.
+int32_t ExecutionBurstController::MemoryCache::allocateSlotLocked() {
constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max();
- // if there is a free slot, use it
- if (mFreeSlots.size() > 0) {
+ // If there is a free slot, use it.
+ if (!mFreeSlots.empty()) {
const int32_t slot = mFreeSlots.top();
mFreeSlots.pop();
return slot;
}
- // otherwise use a slot for the first time
- CHECK(mMemoryCache.size() < kMaxNumberOfSlots) << "Exceeded maximum number of slots!";
+ // Use a slot for the first time.
+ CHECK_LT(mMemoryCache.size(), kMaxNumberOfSlots) << "Exceeded maximum number of slots!";
const int32_t slot = static_cast<int32_t>(mMemoryCache.size());
mMemoryCache.emplace_back();
+ mCacheCleaner.emplace_back();
return slot;
}
-std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
- const sp<V1_2::IPreparedModel>& preparedModel,
+// ExecutionBurstCallback methods
+
+// Keeps only a weak reference to `memoryCache` (kMemoryCache is lock()ed in getMemories), so the
+// callback does not extend the cache's lifetime.
+ExecutionBurstController::ExecutionBurstCallback::ExecutionBurstCallback(
+ const std::shared_ptr<MemoryCache>& memoryCache)
+ : kMemoryCache(memoryCache) {
+ CHECK(memoryCache != nullptr);
+}
+
+// Replies with the memories cached at `slots` (presumably requested by the driver service).
+// Replies GENERAL_FAILURE if the MemoryCache no longer exists and INVALID_ARGUMENT if any slot
+// cannot be resolved to a valid hidl_memory.
+Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
+ const hidl_vec<int32_t>& slots, getMemories_cb cb) {
+ const auto memoryCache = kMemoryCache.lock();
+ if (memoryCache == nullptr) {
+ LOG(ERROR) << "ExecutionBurstController::ExecutionBurstCallback::getMemories called after "
+ "the MemoryCache has been freed";
+ cb(V1_0::ErrorStatus::GENERAL_FAILURE, {});
+ return Void();
+ }
+
+ const auto maybeMemories = getMemoriesHelper(slots, memoryCache);
+ if (!maybeMemories.has_value()) {
+ const auto& [message, code] = maybeMemories.error();
+ LOG(ERROR) << "ExecutionBurstController::ExecutionBurstCallback::getMemories failed with "
+ << code << ": " << message;
+ cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
+ return Void();
+ }
+
+ cb(V1_0::ErrorStatus::NONE, maybeMemories.value());
+ return Void();
+}
+
+// ExecutionBurstController methods
+
+// Creates the request/result FMQ channels, the MemoryCache and the IBurstCallback, configures the
+// burst on `preparedModel`, and registers both channel endpoints with a DeathHandler for the
+// returned IBurstContext. `fallback` is kept so execute() can use it.
+nn::GeneralResult<std::shared_ptr<const ExecutionBurstController>> ExecutionBurstController::create(
+ const sp<V1_2::IPreparedModel>& preparedModel, FallbackFunction fallback,
std::chrono::microseconds pollingTimeWindow) {
// check inputs
if (preparedModel == nullptr) {
- LOG(ERROR) << "ExecutionBurstController::create passed a nullptr";
- return nullptr;
+ return NN_ERROR() << "ExecutionBurstController::create passed a nullptr";
}
- // create callback object
- sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
-
// create FMQ objects
- auto [requestChannelSenderTemp, requestChannelDescriptor] =
- RequestChannelSender::create(kExecutionBurstChannelLength);
- auto [resultChannelReceiverTemp, resultChannelDescriptor] =
- ResultChannelReceiver::create(kExecutionBurstChannelLength, pollingTimeWindow);
- std::shared_ptr<RequestChannelSender> requestChannelSender =
- std::move(requestChannelSenderTemp);
- std::shared_ptr<ResultChannelReceiver> resultChannelReceiver =
- std::move(resultChannelReceiverTemp);
+ auto [requestChannelSender, requestChannelDescriptor] =
+ NN_TRY(RequestChannelSender::create(kExecutionBurstChannelLength));
+ auto [resultChannelReceiver, resultChannelDescriptor] =
+ NN_TRY(ResultChannelReceiver::create(kExecutionBurstChannelLength, pollingTimeWindow));
// check FMQ objects
- if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor ||
- !resultChannelDescriptor) {
- LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue";
- return nullptr;
- }
+ CHECK(requestChannelSender != nullptr);
+ CHECK(requestChannelDescriptor != nullptr);
+ CHECK(resultChannelReceiver != nullptr);
+ CHECK(resultChannelDescriptor != nullptr);
+
+ // create memory cache
+ auto memoryCache = std::make_shared<MemoryCache>();
+
+ // create callback object
+ auto burstCallback = sp<ExecutionBurstCallback>::make(memoryCache);
+ auto cb = hal::utils::CallbackValue(executionBurstResultCallback);
// configure burst
- V1_0::ErrorStatus errorStatus;
- sp<IBurstContext> burstContext;
- const hardware::Return<void> ret = preparedModel->configureExecutionBurst(
- callback, *requestChannelDescriptor, *resultChannelDescriptor,
- [&errorStatus, &burstContext](V1_0::ErrorStatus status,
- const sp<IBurstContext>& context) {
- errorStatus = status;
- burstContext = context;
- });
+ const Return<void> ret = preparedModel->configureExecutionBurst(
+ burstCallback, *requestChannelDescriptor, *resultChannelDescriptor, cb);
+ HANDLE_TRANSPORT_FAILURE(ret);
- // check burst
- if (!ret.isOk()) {
- LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with description "
- << ret.description();
- return nullptr;
- }
- if (errorStatus != V1_0::ErrorStatus::NONE) {
- LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with status "
- << toString(errorStatus);
- return nullptr;
- }
- if (burstContext == nullptr) {
- LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
- return nullptr;
- }
+ auto burstContext = NN_TRY(cb.take());
+ memoryCache->setBurstContext(burstContext);
// create death handler object
- BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender,
- resultChannelReceiver] {
- requestChannelSender->invalidate();
- resultChannelReceiver->invalidate();
- };
- const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback);
-
- // linkToDeath registers a callback that will be invoked on service death to
- // proactively handle service crashes. If the linkToDeath call fails,
- // asynchronous calls are susceptible to hangs if the service crashes before
- // providing the response.
- const hardware::Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0);
- if (!deathHandlerRet.isOk() || deathHandlerRet != true) {
- LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient "
- "for the IBurstContext object.";
- return nullptr;
- }
+ // Both endpoints implement notifyAsDeadObject; presumably the DeathHandler calls it on service
+ // death so a blocked send/getBlocking does not hang. Raw pointers are registered, and the
+ // objects themselves are moved into the controller below.
+ auto deathHandler = NN_TRY(neuralnetworks::utils::DeathHandler::create(burstContext));
+ deathHandler.protectCallbackForLifetimeOfDeathHandler(requestChannelSender.get());
+ deathHandler.protectCallbackForLifetimeOfDeathHandler(resultChannelReceiver.get());
// make and return controller
- return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver,
- burstContext, callback, deathHandler);
+ return std::make_shared<const ExecutionBurstController>(
+ PrivateConstructorTag{}, std::move(fallback), std::move(requestChannelSender),
+ std::move(resultChannelReceiver), std::move(burstCallback), std::move(burstContext),
+ std::move(memoryCache), std::move(deathHandler));
}
+// Takes ownership of every burst resource assembled by create(); the PrivateConstructorTag
+// parameter restricts direct construction to code that can name the private tag.
ExecutionBurstController::ExecutionBurstController(
- const std::shared_ptr<RequestChannelSender>& requestChannelSender,
- const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
- const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback,
- const sp<hardware::hidl_death_recipient>& deathHandler)
- : mRequestChannelSender(requestChannelSender),
- mResultChannelReceiver(resultChannelReceiver),
- mBurstContext(burstContext),
- mMemoryCache(callback),
- mDeathHandler(deathHandler) {}
+ PrivateConstructorTag /*tag*/, FallbackFunction fallback,
+ std::unique_ptr<RequestChannelSender> requestChannelSender,
+ std::unique_ptr<ResultChannelReceiver> resultChannelReceiver,
+ sp<ExecutionBurstCallback> callback, sp<IBurstContext> burstContext,
+ std::shared_ptr<MemoryCache> memoryCache, neuralnetworks::utils::DeathHandler deathHandler)
+ : kFallback(std::move(fallback)),
+ mRequestChannelSender(std::move(requestChannelSender)),
+ mResultChannelReceiver(std::move(resultChannelReceiver)),
+ mBurstCallback(std::move(callback)),
+ mBurstContext(std::move(burstContext)),
+ mMemoryCache(std::move(memoryCache)),
+ kDeathHandler(std::move(deathHandler)) {}
-ExecutionBurstController::~ExecutionBurstController() {
- // It is safe to ignore any errors resulting from this unlinkToDeath call
- // because the ExecutionBurstController object is already being destroyed
- // and its underlying IBurstContext object is no longer being used by the NN
- // runtime.
- if (mDeathHandler) {
- mBurstContext->unlinkToDeath(mDeathHandler).isOk();
+// Caches `memory` in the burst's MemoryCache and returns the hold that keeps its slot allocated;
+// the slot number itself is not needed by callers of this method.
+ExecutionBurstController::OptionalCacheHold ExecutionBurstController::cacheMemory(
+ const nn::SharedMemory& memory) const {
+ auto [slot, hold] = mMemoryCache->cacheMemory(memory);
+ return hold;
+}
+
+// Executes `request` over the FMQ channels. Uses kFallback (when set) if the request needs
+// features newer than ANDROID_Q or if the request packet cannot be sent. Only one execution may
+// be in flight at a time; a concurrent call fails immediately.
+nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>
+ExecutionBurstController::execute(const nn::Request& request, nn::MeasureTiming measure) const {
+ // This is the first point when we know an execution is occurring, so begin to collect
+ // systraces. Note that the first point we can begin collecting systraces in
+ // ExecutionBurstServer is when the RequestChannelReceiver realizes there is data in the FMQ, so
+ // ExecutionBurstServer collects systraces at different points in the code.
+ NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::execute");
+
+ // if the request is valid but of a higher version than what's supported in burst execution,
+ // fall back to another execution path
+ if (const auto version = NN_TRY(hal::utils::makeExecutionFailure(nn::validate(request)));
+ version > nn::Version::ANDROID_Q) {
+ // the burst protocol cannot express this request, so use the fallback path if one exists
+ if (kFallback) {
+ return kFallback(request, measure);
+ }
+ return NN_ERROR() << "Request object has features not supported by IBurst::execute";
}
-}
-static std::tuple<int, std::vector<V1_2::OutputShape>, V1_2::Timing, bool> getExecutionResult(
- V1_0::ErrorStatus status, std::vector<V1_2::OutputShape> outputShapes, V1_2::Timing timing,
- bool fallback) {
- auto [n, checkedOutputShapes, checkedTiming] =
- getExecutionResult(convertToV1_3(status), std::move(outputShapes), timing);
- return {n, convertToV1_2(checkedOutputShapes), convertToV1_2(checkedTiming), fallback};
-}
+ // clear pools field of request, as they will be provided via slots
+ const auto requestWithoutPools =
+ nn::Request{.inputs = request.inputs, .outputs = request.outputs, .pools = {}};
+ auto hidlRequest = NN_TRY(
+ hal::utils::makeExecutionFailure(V1_0::utils::unvalidatedConvert(requestWithoutPools)));
+ const auto hidlMeasure = NN_TRY(hal::utils::makeExecutionFailure(convert(measure)));
-std::tuple<int, std::vector<V1_2::OutputShape>, V1_2::Timing, bool>
-ExecutionBurstController::compute(const V1_0::Request& request, V1_2::MeasureTiming measure,
- const std::vector<intptr_t>& memoryIds) {
- // This is the first point when we know an execution is occurring, so begin
- // to collect systraces. Note that the first point we can begin collecting
- // systraces in ExecutionBurstServer is when the RequestChannelReceiver
- // realizes there is data in the FMQ, so ExecutionBurstServer collects
- // systraces at different points in the code.
- NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");
+ // Ensure that at most one execution is in flight at any given time.
+ const bool alreadyInFlight = mExecutionInFlight.test_and_set();
+ if (alreadyInFlight) {
+ return NN_ERROR() << "IBurst already has an execution in flight";
+ }
+ const auto guard = base::make_scope_guard([this] { mExecutionInFlight.clear(); });
- std::lock_guard<std::mutex> guard(mMutex);
+ // `holds` keeps every pool's slot allocated until execute() returns. Each pool is expected to be
+ // an nn::SharedMemory here; memory domain tokens presumably fail the version check above.
+ std::vector<int32_t> slots;
+ std::vector<OptionalCacheHold> holds;
+ slots.reserve(request.pools.size());
+ holds.reserve(request.pools.size());
+ for (const auto& memoryPool : request.pools) {
+ auto [slot, hold] = mMemoryCache->cacheMemory(std::get<nn::SharedMemory>(memoryPool));
+ slots.push_back(slot);
+ holds.push_back(std::move(hold));
+ }
// send request packet
- const std::vector<int32_t> slots = mMemoryCache->getSlots(request.pools, memoryIds);
- const bool success = mRequestChannelSender->send(request, measure, slots);
- if (!success) {
- LOG(ERROR) << "Error sending FMQ packet";
- // only use fallback execution path if the packet could not be sent
- return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming12,
- /*fallback=*/true);
+ const auto sendStatus = mRequestChannelSender->send(hidlRequest, hidlMeasure, slots);
+ if (!sendStatus.ok()) {
+ // fallback to another execution path if the packet could not be sent
+ if (kFallback) {
+ return kFallback(request, measure);
+ }
+ return NN_ERROR() << "Error sending FMQ packet: " << sendStatus.error();
}
// get result packet
- const auto result = mResultChannelReceiver->getBlocking();
- if (!result) {
- LOG(ERROR) << "Error retrieving FMQ packet";
- // only use fallback execution path if the packet could not be sent
- return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming12,
- /*fallback=*/false);
- }
-
- // unpack results and return (only use fallback execution path if the
- // packet could not be sent)
- auto [status, outputShapes, timing] = std::move(*result);
- return getExecutionResult(status, std::move(outputShapes), timing, /*fallback=*/false);
+ const auto [status, outputShapes, timing] =
+ NN_TRY(hal::utils::makeExecutionFailure(mResultChannelReceiver->getBlocking()));
+ return executionCallback(status, outputShapes, timing);
}
-void ExecutionBurstController::freeMemory(intptr_t key) {
- std::lock_guard<std::mutex> guard(mMutex);
-
- bool valid;
- int32_t slot;
- std::tie(valid, slot) = mMemoryCache->freeMemory(key);
- if (valid) {
- mBurstContext->freeMemory(slot).isOk();
- }
-}
-
-} // namespace android::nn
+} // namespace android::hardware::neuralnetworks::V1_2::utils
diff --git a/neuralnetworks/1.2/utils/src/ExecutionBurstServer.cpp b/neuralnetworks/1.2/utils/src/ExecutionBurstServer.cpp
index 022548d..50af881 100644
--- a/neuralnetworks/1.2/utils/src/ExecutionBurstServer.cpp
+++ b/neuralnetworks/1.2/utils/src/ExecutionBurstServer.cpp
@@ -17,8 +17,19 @@
#define LOG_TAG "ExecutionBurstServer"
#include "ExecutionBurstServer.h"
+#include "Conversions.h"
+#include "ExecutionBurstUtils.h"
#include <android-base/logging.h>
+#include <nnapi/IBurst.h>
+#include <nnapi/Result.h>
+#include <nnapi/TypeUtils.h>
+#include <nnapi/Types.h>
+#include <nnapi/Validation.h>
+#include <nnapi/hal/1.0/Conversions.h>
+#include <nnapi/hal/HandleError.h>
+#include <nnapi/hal/ProtectCallback.h>
+#include <nnapi/hal/TransferValue.h>
#include <algorithm>
#include <cstring>
@@ -29,134 +40,146 @@
#include <utility>
#include <vector>
-#include "ExecutionBurstUtils.h"
-#include "HalInterfaces.h"
#include "Tracing.h"
-namespace android::nn {
+namespace android::hardware::neuralnetworks::V1_2::utils {
namespace {
-// DefaultBurstExecutorWithCache adapts an IPreparedModel so that it can be
-// used as an IBurstExecutorWithCache. Specifically, the cache simply stores the
-// hidl_memory object, and the execution forwards calls to the provided
-// IPreparedModel's "executeSynchronously" method. With this class, hidl_memory
-// must be mapped and unmapped for each execution.
-class DefaultBurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
- public:
- DefaultBurstExecutorWithCache(V1_2::IPreparedModel* preparedModel)
- : mpPreparedModel(preparedModel) {}
+using neuralnetworks::utils::makeExecutionFailure;
- bool isCacheEntryPresent(int32_t slot) const override {
- const auto it = mMemoryCache.find(slot);
- return (it != mMemoryCache.end()) && it->second.valid();
+// Timing reported when no measurement is available: both fields are set to UINT64_MAX.
+constexpr V1_2::Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
+ std::numeric_limits<uint64_t>::max()};
+
+// Callback body for IBurstCallback::getMemories: converts the returned hidl_memory objects into
+// canonical nn::SharedMemory objects, preserving their order. A status other than NONE is turned
+// into an error by HANDLE_HAL_STATUS (assumed to early-return; see nnapi/hal/HandleError.h), and
+// the first memory that fails nn::convert makes the whole call fail via NN_TRY.
+nn::GeneralResult<std::vector<nn::SharedMemory>> getMemoriesCallback(
+ V1_0::ErrorStatus status, const hidl_vec<hidl_memory>& memories) {
+ HANDLE_HAL_STATUS(status) << "getting burst memories failed with " << toString(status);
+ std::vector<nn::SharedMemory> canonicalMemories;
+ canonicalMemories.reserve(memories.size());
+ for (const auto& memory : memories) {
+ canonicalMemories.push_back(NN_TRY(nn::convert(memory)));
}
-
- void addCacheEntry(const hardware::hidl_memory& memory, int32_t slot) override {
- mMemoryCache[slot] = memory;
- }
-
- void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }
-
- std::tuple<V1_0::ErrorStatus, hardware::hidl_vec<V1_2::OutputShape>, V1_2::Timing> execute(
- const V1_0::Request& request, const std::vector<int32_t>& slots,
- V1_2::MeasureTiming measure) override {
- // convert slots to pools
- hardware::hidl_vec<hardware::hidl_memory> pools(slots.size());
- std::transform(slots.begin(), slots.end(), pools.begin(),
- [this](int32_t slot) { return mMemoryCache[slot]; });
-
- // create full request
- V1_0::Request fullRequest = request;
- fullRequest.pools = std::move(pools);
-
- // setup execution
- V1_0::ErrorStatus returnedStatus = V1_0::ErrorStatus::GENERAL_FAILURE;
- hardware::hidl_vec<V1_2::OutputShape> returnedOutputShapes;
- V1_2::Timing returnedTiming;
- auto cb = [&returnedStatus, &returnedOutputShapes, &returnedTiming](
- V1_0::ErrorStatus status,
- const hardware::hidl_vec<V1_2::OutputShape>& outputShapes,
- const V1_2::Timing& timing) {
- returnedStatus = status;
- returnedOutputShapes = outputShapes;
- returnedTiming = timing;
- };
-
- // execute
- const hardware::Return<void> ret =
- mpPreparedModel->executeSynchronously(fullRequest, measure, cb);
- if (!ret.isOk() || returnedStatus != V1_0::ErrorStatus::NONE) {
- LOG(ERROR) << "IPreparedModelAdapter::execute -- Error executing";
- return {returnedStatus, std::move(returnedOutputShapes), kNoTiming};
- }
-
- return std::make_tuple(returnedStatus, std::move(returnedOutputShapes), returnedTiming);
- }
-
- private:
- V1_2::IPreparedModel* const mpPreparedModel;
- std::map<int32_t, hardware::hidl_memory> mMemoryCache;
-};
+ return canonicalMemories;
+}
} // anonymous namespace
+// Cache of client memories keyed by burst slot. kBurstExecutor is asked to cache each memory when
+// it is added (addCacheEntryLocked), and kBurstCallback is used to fetch the memories of slots
+// that are not cached yet (ensureCacheEntriesArePresentLocked). Both must be non-null; this is
+// enforced with CHECK, which aborts on violation.
+ExecutionBurstServer::MemoryCache::MemoryCache(nn::SharedBurst burstExecutor,
+ sp<IBurstCallback> burstCallback)
+ : kBurstExecutor(std::move(burstExecutor)), kBurstCallback(std::move(burstCallback)) {
+ CHECK(kBurstExecutor != nullptr);
+ CHECK(kBurstCallback != nullptr);
+}
+
+// Returns a copy of the cached (memory, cache hold) pair for every entry of `slots`, in the same
+// order and including duplicates. Slots that are not cached yet are first fetched from the client.
+// The whole operation runs under mMutex, so a concurrent removeCacheEntry() (which also takes
+// mMutex) cannot evict entries part-way through.
+nn::GeneralResult<std::vector<std::pair<nn::SharedMemory, nn::IBurst::OptionalCacheHold>>>
+ExecutionBurstServer::MemoryCache::getCacheEntries(const std::vector<int32_t>& slots) {
+ std::lock_guard guard(mMutex);
+ NN_TRY(ensureCacheEntriesArePresentLocked(slots));
+
+ std::vector<std::pair<nn::SharedMemory, nn::IBurst::OptionalCacheHold>> results;
+ results.reserve(slots.size());
+ for (int32_t slot : slots) {
+ results.push_back(NN_TRY(getCacheEntryLocked(slot)));
+ }
+
+ return results;
+}
+
+// Makes sure every slot in `slots` has a cache entry. Must be called with mMutex held (the
+// `slotIsKnown` lambda is annotated REQUIRES(mMutex)). All missing slots are requested from the
+// client in a single IBurstCallback::getMemories call, and the i-th returned memory is stored
+// under the i-th missing slot. Any failure (a transport failure, a bad status or failed
+// conversion reported through the callback, or a count mismatch) is returned before any new entry
+// is added.
+nn::GeneralResult<void> ExecutionBurstServer::MemoryCache::ensureCacheEntriesArePresentLocked(
+ const std::vector<int32_t>& slots) {
+ const auto slotIsKnown = [this](int32_t slot)
+ REQUIRES(mMutex) { return mCache.count(slot) > 0; };
+
+ // find unique unknown slots
+ // (sort + unique removes duplicates; remove_if then drops slots that are already cached)
+ std::vector<int32_t> unknownSlots = slots;
+ std::sort(unknownSlots.begin(), unknownSlots.end());
+ auto unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlots.end());
+ unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
+ unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
+
+ // quick-exit if all slots are known
+ if (unknownSlots.empty()) {
+ return {};
+ }
+
+ auto cb = neuralnetworks::utils::CallbackValue(getMemoriesCallback);
+
+ const auto ret = kBurstCallback->getMemories(unknownSlots, cb);
+ HANDLE_TRANSPORT_FAILURE(ret);
+
+ auto returnedMemories = NN_TRY(cb.take());
+
+ // the client must return exactly one memory per requested slot
+ if (returnedMemories.size() != unknownSlots.size()) {
+ return NN_ERROR()
+ << "ExecutionBurstServer::MemoryCache::ensureCacheEntriesArePresentLocked: Error "
+ "retrieving memories -- count mismatch between requested memories ("
+ << unknownSlots.size() << ") and returned memories (" << returnedMemories.size()
+ << ")";
+ }
+
+ // add memories to unknown slots
+ for (size_t i = 0; i < unknownSlots.size(); ++i) {
+ addCacheEntryLocked(unknownSlots[i], std::move(returnedMemories[i]));
+ }
+
+ return {};
+}
+
+// Looks up a single slot and returns a copy of its (memory, cache hold) pair, or an error if the
+// slot is not cached. It does not lock; callers must already hold mMutex, as getCacheEntries does.
+nn::GeneralResult<std::pair<nn::SharedMemory, nn::IBurst::OptionalCacheHold>>
+ExecutionBurstServer::MemoryCache::getCacheEntryLocked(int32_t slot) {
+ if (const auto iter = mCache.find(slot); iter != mCache.end()) {
+ return iter->second;
+ }
+ return NN_ERROR()
+ << "ExecutionBurstServer::MemoryCache::getCacheEntryLocked failed because slot " << slot
+ << " is not present in the cache";
+}
+
+// Asks the burst executor to cache `memory`, then stores the memory together with the returned
+// hold under `slot`. It does not lock; callers must already hold mMutex.
+// NOTE(review): if mCache is a std::map/unordered_map, emplace() does not overwrite an existing
+// entry, so this is only suitable for slots that are not cached yet. It is presumably the stored
+// hold that keeps the executor-side cache entry alive until the slot is erased; confirm against
+// nn::IBurst::cacheMemory.
+void ExecutionBurstServer::MemoryCache::addCacheEntryLocked(int32_t slot, nn::SharedMemory memory) {
+ auto hold = kBurstExecutor->cacheMemory(memory);
+ mCache.emplace(slot, std::make_pair(std::move(memory), std::move(hold)));
+}
+
+// Takes mMutex and erases the cache entry for `slot`, if there is one.
+void ExecutionBurstServer::MemoryCache::removeCacheEntry(int32_t slot) {
+ std::lock_guard guard(mMutex);
+ mCache.erase(slot);
+}
+
// ExecutionBurstServer methods
-sp<ExecutionBurstServer> ExecutionBurstServer::create(
+// Validates the arguments, opens both FMQs from their descriptors, and constructs the server.
+// The constructor starts the worker thread immediately. Returns an error if `callback` or
+// `burstExecutor` is null, or if either FMQ cannot be created.
+nn::GeneralResult<sp<ExecutionBurstServer>> ExecutionBurstServer::create(
const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
- const MQDescriptorSync<FmqResultDatum>& resultChannel,
- std::shared_ptr<IBurstExecutorWithCache> executorWithCache,
+ const MQDescriptorSync<FmqResultDatum>& resultChannel, nn::SharedBurst burstExecutor,
std::chrono::microseconds pollingTimeWindow) {
// check inputs
- if (callback == nullptr || executorWithCache == nullptr) {
- LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
- return nullptr;
+ if (callback == nullptr || burstExecutor == nullptr) {
+ return NN_ERROR() << "ExecutionBurstServer::create passed a nullptr";
}
// create FMQ objects
- std::unique_ptr<RequestChannelReceiver> requestChannelReceiver =
- RequestChannelReceiver::create(requestChannel, pollingTimeWindow);
- std::unique_ptr<ResultChannelSender> resultChannelSender =
- ResultChannelSender::create(resultChannel);
+ auto requestChannelReceiver =
+ NN_TRY(RequestChannelReceiver::create(requestChannel, pollingTimeWindow));
+ auto resultChannelSender = NN_TRY(ResultChannelSender::create(resultChannel));
// check FMQ objects
- if (!requestChannelReceiver || !resultChannelSender) {
- LOG(ERROR) << "ExecutionBurstServer::create failed to create FastMessageQueue";
- return nullptr;
- }
+ // NN_TRY above has already returned on failure; a successful create() is expected to yield a
+ // non-null object, so a null here indicates a programming error.
+ CHECK(requestChannelReceiver != nullptr);
+ CHECK(resultChannelSender != nullptr);
// make and return context
- return new ExecutionBurstServer(callback, std::move(requestChannelReceiver),
- std::move(resultChannelSender), std::move(executorWithCache));
+ return sp<ExecutionBurstServer>::make(PrivateConstructorTag{}, callback,
+ std::move(requestChannelReceiver),
+ std::move(resultChannelSender), std::move(burstExecutor));
}
-sp<ExecutionBurstServer> ExecutionBurstServer::create(
- const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
- const MQDescriptorSync<FmqResultDatum>& resultChannel, V1_2::IPreparedModel* preparedModel,
- std::chrono::microseconds pollingTimeWindow) {
- // check relevant input
- if (preparedModel == nullptr) {
- LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
- return nullptr;
- }
-
- // adapt IPreparedModel to have caching
- const std::shared_ptr<DefaultBurstExecutorWithCache> preparedModelAdapter =
- std::make_shared<DefaultBurstExecutorWithCache>(preparedModel);
-
- // make and return context
- return ExecutionBurstServer::create(callback, requestChannel, resultChannel,
- preparedModelAdapter, pollingTimeWindow);
-}
-
-ExecutionBurstServer::ExecutionBurstServer(
- const sp<IBurstCallback>& callback, std::unique_ptr<RequestChannelReceiver> requestChannel,
- std::unique_ptr<ResultChannelSender> resultChannel,
- std::shared_ptr<IBurstExecutorWithCache> executorWithCache)
+// Intended to be reached only through create() (hence PrivateConstructorTag). The worker thread
+// running task() is started in the body, after every member has been initialized.
+// NOTE(review): mMemoryCache is initialized from mBurstExecutor and mCallback. This is only correct
+// if both members are declared before mMemoryCache in the class definition; confirm in the header.
+ExecutionBurstServer::ExecutionBurstServer(PrivateConstructorTag /*tag*/,
+ const sp<IBurstCallback>& callback,
+ std::unique_ptr<RequestChannelReceiver> requestChannel,
+ std::unique_ptr<ResultChannelSender> resultChannel,
+ nn::SharedBurst burstExecutor)
: mCallback(callback),
mRequestChannelReceiver(std::move(requestChannel)),
mResultChannelSender(std::move(resultChannel)),
- mExecutorWithCache(std::move(executorWithCache)) {
+ mBurstExecutor(std::move(burstExecutor)),
+ mMemoryCache(mBurstExecutor, mCallback) {
// TODO: highly document the threading behavior of this class
mWorker = std::thread([this] { task(); });
}
@@ -170,51 +193,9 @@
mWorker.join();
}
-hardware::Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
- std::lock_guard<std::mutex> hold(mMutex);
- mExecutorWithCache->removeCacheEntry(slot);
- return hardware::Void();
-}
-
-void ExecutionBurstServer::ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots) {
- const auto slotIsKnown = [this](int32_t slot) {
- return mExecutorWithCache->isCacheEntryPresent(slot);
- };
-
- // find unique unknown slots
- std::vector<int32_t> unknownSlots = slots;
- auto unknownSlotsEnd = unknownSlots.end();
- std::sort(unknownSlots.begin(), unknownSlotsEnd);
- unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
- unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
- unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
-
- // quick-exit if all slots are known
- if (unknownSlots.empty()) {
- return;
- }
-
- V1_0::ErrorStatus errorStatus = V1_0::ErrorStatus::GENERAL_FAILURE;
- std::vector<hardware::hidl_memory> returnedMemories;
- auto cb = [&errorStatus, &returnedMemories](
- V1_0::ErrorStatus status,
- const hardware::hidl_vec<hardware::hidl_memory>& memories) {
- errorStatus = status;
- returnedMemories = memories;
- };
-
- const hardware::Return<void> ret = mCallback->getMemories(unknownSlots, cb);
-
- if (!ret.isOk() || errorStatus != V1_0::ErrorStatus::NONE ||
- returnedMemories.size() != unknownSlots.size()) {
- LOG(ERROR) << "Error retrieving memories";
- return;
- }
-
- // add memories to unknown slots
- for (size_t i = 0; i < unknownSlots.size(); ++i) {
- mExecutorWithCache->addCacheEntry(returnedMemories[i], unknownSlots[i]);
- }
+// HIDL-facing method (returns Return<void>): forgets the memory cached for `slot`.
+// MemoryCache::removeCacheEntry takes its own lock, so no server-level mutex is needed here.
+Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
+ mMemoryCache.removeCacheEntry(slot);
+ return Void();
}
void ExecutionBurstServer::task() {
@@ -223,38 +204,65 @@
// receive request
auto arguments = mRequestChannelReceiver->getBlocking();
- // if the request packet was not properly received, return a generic
- // error and skip the execution
+ // if the request packet was not properly received, return a generic error and skip the
+ // execution
//
- // if the burst is being torn down, skip the execution so the "task"
- // function can end
- if (!arguments) {
+ // if the burst is being torn down, skip the execution so the "task" function can end
+ if (!arguments.has_value()) {
if (!mTeardown) {
mResultChannelSender->send(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
}
continue;
}
- // otherwise begin tracing execution
- NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
- "ExecutionBurstServer getting memory, executing, and returning results");
+ // unpack the arguments; types are Request, std::vector<int32_t>, and MeasureTiming,
+ // respectively
+ const auto [requestWithoutPools, slotsOfPools, measure] = std::move(arguments).value();
- // unpack the arguments; types are Request, std::vector<int32_t>, and
- // MeasureTiming, respectively
- const auto [requestWithoutPools, slotsOfPools, measure] = std::move(*arguments);
-
- // ensure executor with cache has required memory
- std::lock_guard<std::mutex> hold(mMutex);
- ensureCacheEntriesArePresentLocked(slotsOfPools);
-
- // perform computation; types are ErrorStatus, hidl_vec<OutputShape>,
- // and Timing, respectively
- const auto [errorStatus, outputShapes, returnedTiming] =
- mExecutorWithCache->execute(requestWithoutPools, slotsOfPools, measure);
+ auto result = execute(requestWithoutPools, slotsOfPools, measure);
// return result
- mResultChannelSender->send(errorStatus, outputShapes, returnedTiming);
+ if (result.has_value()) {
+ const auto& [outputShapes, timing] = result.value();
+ mResultChannelSender->send(V1_0::ErrorStatus::NONE, outputShapes, timing);
+ } else {
+ const auto& [message, code, outputShapes] = result.error();
+ LOG(ERROR) << "IBurst::execute failed with " << code << ": " << message;
+ mResultChannelSender->send(convert(code).value(), convert(outputShapes).value(),
+ kNoTiming);
+ }
}
}
-} // namespace android::nn
+// Runs one request received over the FMQ:
+// 1) resolves `slotsOfPools` to memories through mMemoryCache, fetching unknown ones;
+// 2) converts `requestWithoutPools` and appends those memories as its pools;
+// 3) validates the completed request and converts `measure`;
+// 4) executes on mBurstExecutor and converts the output shapes and timing back to V1_2 types.
+// Failures in every step other than the execution itself are wrapped with makeExecutionFailure.
+nn::ExecutionResult<std::pair<hidl_vec<OutputShape>, Timing>> ExecutionBurstServer::execute(
+ const V1_0::Request& requestWithoutPools, const std::vector<int32_t>& slotsOfPools,
+ MeasureTiming measure) {
+ NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
+ "ExecutionBurstServer getting memory, executing, and returning results");
+
+ // ensure executor with cache has required memory
+ const auto cacheEntries =
+ NN_TRY(makeExecutionFailure(mMemoryCache.getCacheEntries(slotsOfPools)));
+
+ // convert request, populating its pools
+ // This code performs an unvalidated convert because the request object without its pools is
+ // invalid because it is incomplete. Instead, the validation is performed after the memory pools
+ // have been added to the request.
+ auto canonicalRequest =
+ NN_TRY(makeExecutionFailure(nn::unvalidatedConvert(requestWithoutPools)));
+ CHECK(canonicalRequest.pools.empty());
+ // pools[i] is the memory for slotsOfPools[i]; getCacheEntries preserves slot order
+ std::transform(cacheEntries.begin(), cacheEntries.end(),
+ std::back_inserter(canonicalRequest.pools),
+ [](const auto& cacheEntry) { return cacheEntry.first; });
+ NN_TRY(makeExecutionFailure(validate(canonicalRequest)));
+
+ nn::MeasureTiming canonicalMeasure = NN_TRY(makeExecutionFailure(nn::convert(measure)));
+
+ const auto [outputShapes, timing] =
+ NN_TRY(mBurstExecutor->execute(canonicalRequest, canonicalMeasure));
+
+ return std::make_pair(NN_TRY(makeExecutionFailure(convert(outputShapes))),
+ NN_TRY(makeExecutionFailure(convert(timing))));
+}
+
+} // namespace android::hardware::neuralnetworks::V1_2::utils
diff --git a/neuralnetworks/1.2/utils/src/ExecutionBurstUtils.cpp b/neuralnetworks/1.2/utils/src/ExecutionBurstUtils.cpp
index f0275f9..ca3a52c 100644
--- a/neuralnetworks/1.2/utils/src/ExecutionBurstUtils.cpp
+++ b/neuralnetworks/1.2/utils/src/ExecutionBurstUtils.cpp
@@ -19,11 +19,15 @@
#include "ExecutionBurstUtils.h"
#include <android-base/logging.h>
+#include <android-base/properties.h>
#include <android/hardware/neuralnetworks/1.0/types.h>
#include <android/hardware/neuralnetworks/1.1/types.h>
#include <android/hardware/neuralnetworks/1.2/types.h>
#include <fmq/MessageQueue.h>
#include <hidl/MQDescriptor.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/ProtectCallback.h>
#include <atomic>
#include <chrono>
@@ -39,84 +43,97 @@
constexpr V1_2::Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
std::numeric_limits<uint64_t>::max()};
+// Returns the FMQ polling time window.
+// - Debuggable builds (NN_DEBUGGABLE, defined for eng/userdebug in Android.bp) read the integer
+//   system property `property`. base::GetIntProperty returns the default (0) when the property is
+//   unset, not an integer, or below kMinPollingTimeWindow.
+// - All other builds always return the default.
+std::chrono::microseconds getPollingTimeWindow(const std::string& property) {
+ constexpr int32_t kDefaultPollingTimeWindow = 0;
+#ifdef NN_DEBUGGABLE
+ constexpr int32_t kMinPollingTimeWindow = 0;
+ const int32_t selectedPollingTimeWindow =
+ base::GetIntProperty(property, kDefaultPollingTimeWindow, kMinPollingTimeWindow);
+ return std::chrono::microseconds(selectedPollingTimeWindow);
+#else
+ // `property` is only consulted on debuggable builds
+ (void)property;
+ return std::chrono::microseconds(kDefaultPollingTimeWindow);
+#endif // NN_DEBUGGABLE
+}
+
+} // namespace
+
+// Polling window read from "debug.nn.burst-controller-polling-window". The property is only
+// honored on debuggable builds; otherwise the default of 0 is returned.
+std::chrono::microseconds getBurstControllerPollingTimeWindow() {
+ return getPollingTimeWindow("debug.nn.burst-controller-polling-window");
+}
+
+// Polling window read from "debug.nn.burst-server-polling-window". The property is only honored
+// on debuggable builds; otherwise the default of 0 is returned.
+std::chrono::microseconds getBurstServerPollingTimeWindow() {
+ return getPollingTimeWindow("debug.nn.burst-server-polling-window");
}
// serialize a request into a packet
std::vector<FmqRequestDatum> serialize(const V1_0::Request& request, V1_2::MeasureTiming measure,
const std::vector<int32_t>& slots) {
// count how many elements need to be sent for a request
- size_t count = 2 + request.inputs.size() + request.outputs.size() + request.pools.size();
+ size_t count = 2 + request.inputs.size() + request.outputs.size() + slots.size();
for (const auto& input : request.inputs) {
count += input.dimensions.size();
}
for (const auto& output : request.outputs) {
count += output.dimensions.size();
}
+ CHECK_LE(count, std::numeric_limits<uint32_t>::max());
// create buffer to temporarily store elements
std::vector<FmqRequestDatum> data;
data.reserve(count);
// package packetInfo
- {
- FmqRequestDatum datum;
- datum.packetInformation(
- {/*.packetSize=*/static_cast<uint32_t>(count),
- /*.numberOfInputOperands=*/static_cast<uint32_t>(request.inputs.size()),
- /*.numberOfOutputOperands=*/static_cast<uint32_t>(request.outputs.size()),
- /*.numberOfPools=*/static_cast<uint32_t>(request.pools.size())});
- data.push_back(datum);
- }
+ data.emplace_back();
+ data.back().packetInformation(
+ {.packetSize = static_cast<uint32_t>(count),
+ .numberOfInputOperands = static_cast<uint32_t>(request.inputs.size()),
+ .numberOfOutputOperands = static_cast<uint32_t>(request.outputs.size()),
+ .numberOfPools = static_cast<uint32_t>(slots.size())});
// package input data
for (const auto& input : request.inputs) {
// package operand information
- FmqRequestDatum datum;
- datum.inputOperandInformation(
- {/*.hasNoValue=*/input.hasNoValue,
- /*.location=*/input.location,
- /*.numberOfDimensions=*/static_cast<uint32_t>(input.dimensions.size())});
- data.push_back(datum);
+ data.emplace_back();
+ data.back().inputOperandInformation(
+ {.hasNoValue = input.hasNoValue,
+ .location = input.location,
+ .numberOfDimensions = static_cast<uint32_t>(input.dimensions.size())});
// package operand dimensions
for (uint32_t dimension : input.dimensions) {
- FmqRequestDatum datum;
- datum.inputOperandDimensionValue(dimension);
- data.push_back(datum);
+ data.emplace_back();
+ data.back().inputOperandDimensionValue(dimension);
}
}
// package output data
for (const auto& output : request.outputs) {
// package operand information
- FmqRequestDatum datum;
- datum.outputOperandInformation(
- {/*.hasNoValue=*/output.hasNoValue,
- /*.location=*/output.location,
- /*.numberOfDimensions=*/static_cast<uint32_t>(output.dimensions.size())});
- data.push_back(datum);
+ data.emplace_back();
+ data.back().outputOperandInformation(
+ {.hasNoValue = output.hasNoValue,
+ .location = output.location,
+ .numberOfDimensions = static_cast<uint32_t>(output.dimensions.size())});
// package operand dimensions
for (uint32_t dimension : output.dimensions) {
- FmqRequestDatum datum;
- datum.outputOperandDimensionValue(dimension);
- data.push_back(datum);
+ data.emplace_back();
+ data.back().outputOperandDimensionValue(dimension);
}
}
// package pool identifier
for (int32_t slot : slots) {
- FmqRequestDatum datum;
- datum.poolIdentifier(slot);
- data.push_back(datum);
+ data.emplace_back();
+ data.back().poolIdentifier(slot);
}
// package measureTiming
- {
- FmqRequestDatum datum;
- datum.measureTiming(measure);
- data.push_back(datum);
- }
+ data.emplace_back();
+ data.back().measureTiming(measure);
+
+ CHECK_EQ(data.size(), count);
// return packet
return data;
@@ -137,46 +154,38 @@
data.reserve(count);
// package packetInfo
- {
- FmqResultDatum datum;
- datum.packetInformation({/*.packetSize=*/static_cast<uint32_t>(count),
- /*.errorStatus=*/errorStatus,
- /*.numberOfOperands=*/static_cast<uint32_t>(outputShapes.size())});
- data.push_back(datum);
- }
+ data.emplace_back();
+ data.back().packetInformation({.packetSize = static_cast<uint32_t>(count),
+ .errorStatus = errorStatus,
+ .numberOfOperands = static_cast<uint32_t>(outputShapes.size())});
// package output shape data
for (const auto& operand : outputShapes) {
// package operand information
- FmqResultDatum::OperandInformation info{};
- info.isSufficient = operand.isSufficient;
- info.numberOfDimensions = static_cast<uint32_t>(operand.dimensions.size());
-
- FmqResultDatum datum;
- datum.operandInformation(info);
- data.push_back(datum);
+ data.emplace_back();
+ data.back().operandInformation(
+ {.isSufficient = operand.isSufficient,
+ .numberOfDimensions = static_cast<uint32_t>(operand.dimensions.size())});
// package operand dimensions
for (uint32_t dimension : operand.dimensions) {
- FmqResultDatum datum;
- datum.operandDimensionValue(dimension);
- data.push_back(datum);
+ data.emplace_back();
+ data.back().operandDimensionValue(dimension);
}
}
// package executionTiming
- {
- FmqResultDatum datum;
- datum.executionTiming(timing);
- data.push_back(datum);
- }
+ data.emplace_back();
+ data.back().executionTiming(timing);
+
+ CHECK_EQ(data.size(), count);
// return result
return data;
}
// deserialize request
-std::optional<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>> deserialize(
+nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>> deserialize(
const std::vector<FmqRequestDatum>& data) {
using discriminator = FmqRequestDatum::hidl_discriminator;
@@ -184,8 +193,7 @@
// validate packet information
if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return std::nullopt;
+ return NN_ERROR() << "FMQ Request packet ill-formed";
}
// unpackage packet information
@@ -198,8 +206,7 @@
// verify packet size
if (data.size() != packetSize) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return std::nullopt;
+ return NN_ERROR() << "FMQ Request packet ill-formed";
}
// unpackage input operands
@@ -208,8 +215,7 @@
for (size_t operand = 0; operand < numberOfInputOperands; ++operand) {
// validate input operand information
if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return std::nullopt;
+ return NN_ERROR() << "FMQ Request packet ill-formed";
}
// unpackage operand information
@@ -226,8 +232,7 @@
for (size_t i = 0; i < numberOfDimensions; ++i) {
// validate dimension
if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return std::nullopt;
+ return NN_ERROR() << "FMQ Request packet ill-formed";
}
// unpackage dimension
@@ -240,7 +245,7 @@
// store result
inputs.push_back(
- {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
+ {.hasNoValue = hasNoValue, .location = location, .dimensions = dimensions});
}
// unpackage output operands
@@ -249,8 +254,7 @@
for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) {
// validate output operand information
if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return std::nullopt;
+ return NN_ERROR() << "FMQ Request packet ill-formed";
}
// unpackage operand information
@@ -267,8 +271,7 @@
for (size_t i = 0; i < numberOfDimensions; ++i) {
// validate dimension
if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return std::nullopt;
+ return NN_ERROR() << "FMQ Request packet ill-formed";
}
// unpackage dimension
@@ -281,7 +284,7 @@
// store result
outputs.push_back(
- {/*.hasNoValue=*/hasNoValue, /*.location=*/location, /*.dimensions=*/dimensions});
+ {.hasNoValue = hasNoValue, .location = location, .dimensions = dimensions});
}
// unpackage pools
@@ -290,8 +293,7 @@
for (size_t pool = 0; pool < numberOfPools; ++pool) {
// validate input operand information
if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return std::nullopt;
+ return NN_ERROR() << "FMQ Request packet ill-formed";
}
// unpackage operand information
@@ -304,8 +306,7 @@
// validate measureTiming
if (data[index].getDiscriminator() != discriminator::measureTiming) {
- LOG(ERROR) << "FMQ Request packet ill-formed";
- return std::nullopt;
+ return NN_ERROR() << "FMQ Request packet ill-formed";
}
// unpackage measureTiming
@@ -314,27 +315,23 @@
// validate packet information
if (index != packetSize) {
- LOG(ERROR) << "FMQ Result packet ill-formed";
- return std::nullopt;
+ // every datum must have been consumed; leftover data means a malformed *request* packet
+ return NN_ERROR() << "FMQ Request packet ill-formed";
}
// return request
- V1_0::Request request = {/*.inputs=*/inputs, /*.outputs=*/outputs, /*.pools=*/{}};
+ V1_0::Request request = {.inputs = inputs, .outputs = outputs, .pools = {}};
return std::make_tuple(std::move(request), std::move(slots), measure);
}
// deserialize a packet into the result
-std::optional<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::Timing>>
-deserialize(const std::vector<FmqResultDatum>& data) {
+nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::Timing>> deserialize(
+ const std::vector<FmqResultDatum>& data) {
using discriminator = FmqResultDatum::hidl_discriminator;
-
- std::vector<V1_2::OutputShape> outputShapes;
size_t index = 0;
// validate packet information
if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
- LOG(ERROR) << "FMQ Result packet ill-formed";
- return std::nullopt;
+ return NN_ERROR() << "FMQ Result packet ill-formed";
}
// unpackage packet information
@@ -346,16 +343,16 @@
// verify packet size
if (data.size() != packetSize) {
- LOG(ERROR) << "FMQ Result packet ill-formed";
- return std::nullopt;
+ return NN_ERROR() << "FMQ Result packet ill-formed";
}
// unpackage operands
+ std::vector<V1_2::OutputShape> outputShapes;
+ outputShapes.reserve(numberOfOperands);
for (size_t operand = 0; operand < numberOfOperands; ++operand) {
// validate operand information
if (data[index].getDiscriminator() != discriminator::operandInformation) {
- LOG(ERROR) << "FMQ Result packet ill-formed";
- return std::nullopt;
+ return NN_ERROR() << "FMQ Result packet ill-formed";
}
// unpackage operand information
@@ -370,8 +367,7 @@
for (size_t i = 0; i < numberOfDimensions; ++i) {
// validate dimension
if (data[index].getDiscriminator() != discriminator::operandDimensionValue) {
- LOG(ERROR) << "FMQ Result packet ill-formed";
- return std::nullopt;
+ return NN_ERROR() << "FMQ Result packet ill-formed";
}
// unpackage dimension
@@ -383,13 +379,12 @@
}
// store result
- outputShapes.push_back({/*.dimensions=*/dimensions, /*.isSufficient=*/isSufficient});
+ outputShapes.push_back({.dimensions = dimensions, .isSufficient = isSufficient});
}
// validate execution timing
if (data[index].getDiscriminator() != discriminator::executionTiming) {
- LOG(ERROR) << "FMQ Result packet ill-formed";
- return std::nullopt;
+ return NN_ERROR() << "FMQ Result packet ill-formed";
}
// unpackage execution timing
@@ -398,123 +393,113 @@
// validate packet information
if (index != packetSize) {
- LOG(ERROR) << "FMQ Result packet ill-formed";
- return std::nullopt;
+ return NN_ERROR() << "FMQ Result packet ill-formed";
}
// return result
return std::make_tuple(errorStatus, std::move(outputShapes), timing);
}
-V1_0::ErrorStatus legacyConvertResultCodeToErrorStatus(int resultCode) {
- return convertToV1_0(convertResultCodeToErrorStatus(resultCode));
-}
-
// RequestChannelSender methods
-std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*>
+// Creates a sender that owns a new FMQ of `channelLength` elements, and returns it together with
+// the queue's descriptor, presumably to be handed to the receiving side.
+// NOTE(review): the descriptor comes from the sender's own queue via getDesc(), so it is presumably
+// valid only while the returned sender is alive; confirm against fmq MessageQueue::getDesc.
+nn::GeneralResult<
+ std::pair<std::unique_ptr<RequestChannelSender>, const MQDescriptorSync<FmqRequestDatum>*>>
RequestChannelSender::create(size_t channelLength) {
- std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
- std::make_unique<FmqRequestChannel>(channelLength, /*confEventFlag=*/true);
- if (!fmqRequestChannel->isValid()) {
- LOG(ERROR) << "Unable to create RequestChannelSender";
- return {nullptr, nullptr};
+ auto requestChannelSender =
+ std::make_unique<RequestChannelSender>(PrivateConstructorTag{}, channelLength);
+ if (!requestChannelSender->mFmqRequestChannel.isValid()) {
+ return NN_ERROR() << "Unable to create RequestChannelSender";
}
- const FmqRequestDescriptor* descriptor = fmqRequestChannel->getDesc();
- return std::make_pair(std::make_unique<RequestChannelSender>(std::move(fmqRequestChannel)),
- descriptor);
+ const MQDescriptorSync<FmqRequestDatum>* descriptor =
+ requestChannelSender->mFmqRequestChannel.getDesc();
+ return std::make_pair(std::move(requestChannelSender), descriptor);
}
-RequestChannelSender::RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel)
- : mFmqRequestChannel(std::move(fmqRequestChannel)) {}
+// Builds the FMQ with an event flag word configured. The queue's validity is not checked here;
+// create() checks mFmqRequestChannel.isValid().
+RequestChannelSender::RequestChannelSender(PrivateConstructorTag /*tag*/, size_t channelLength)
+ : mFmqRequestChannel(channelLength, /*configureEventFlagWord=*/true) {}
-bool RequestChannelSender::send(const V1_0::Request& request, V1_2::MeasureTiming measure,
- const std::vector<int32_t>& slots) {
+// Serializes `request` together with `measure` and writes the packet to the FMQ.
+// request.pools is not serialized; the pools are identified by `slots` instead.
+nn::Result<void> RequestChannelSender::send(const V1_0::Request& request,
+ V1_2::MeasureTiming measure,
+ const std::vector<int32_t>& slots) {
const std::vector<FmqRequestDatum> serialized = serialize(request, measure, slots);
return sendPacket(serialized);
}
-bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
+// Writes an already-serialized packet to the FMQ. Returns an error in any of these cases:
+// - the channel was marked dead by notifyAsDeadObject();
+// - the packet is larger than the space currently available to write;
+// - the blocking write fails.
+nn::Result<void> RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
if (!mValid) {
- return false;
+ return NN_ERROR() << "FMQ object is invalid";
}
- if (packet.size() > mFmqRequestChannel->availableToWrite()) {
- LOG(ERROR)
- << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ";
- return false;
+ if (packet.size() > mFmqRequestChannel.availableToWrite()) {
+ return NN_ERROR()
+ << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ";
}
- // Always send the packet with "blocking" because this signals the futex and
- // unblocks the consumer if it is waiting on the futex.
- return mFmqRequestChannel->writeBlocking(packet.data(), packet.size());
+ // Always send the packet with "blocking" because this signals the futex and unblocks the
+ // consumer if it is waiting on the futex.
+ const bool success = mFmqRequestChannel.writeBlocking(packet.data(), packet.size());
+ if (!success) {
+ return NN_ERROR()
+ << "RequestChannelSender::sendPacket -- FMQ's writeBlocking returned an error";
+ }
+
+ return {};
}
-void RequestChannelSender::invalidate() {
+// Marks the channel as unusable, so every later sendPacket() call returns an error. Only mValid is
+// cleared; the FMQ itself is left untouched.
+void RequestChannelSender::notifyAsDeadObject() {
mValid = false;
}
// RequestChannelReceiver methods
-std::unique_ptr<RequestChannelReceiver> RequestChannelReceiver::create(
- const FmqRequestDescriptor& requestChannel, std::chrono::microseconds pollingTimeWindow) {
- std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
- std::make_unique<FmqRequestChannel>(requestChannel);
+nn::GeneralResult<std::unique_ptr<RequestChannelReceiver>> RequestChannelReceiver::create(
+ const MQDescriptorSync<FmqRequestDatum>& requestChannel,
+ std::chrono::microseconds pollingTimeWindow) {
+ auto requestChannelReceiver = std::make_unique<RequestChannelReceiver>(
+ PrivateConstructorTag{}, requestChannel, pollingTimeWindow);
- if (!fmqRequestChannel->isValid()) {
- LOG(ERROR) << "Unable to create RequestChannelReceiver";
- return nullptr;
+ if (!requestChannelReceiver->mFmqRequestChannel.isValid()) {
+ return NN_ERROR() << "Unable to create RequestChannelReceiver";
}
- if (fmqRequestChannel->getEventFlagWord() == nullptr) {
- LOG(ERROR)
- << "RequestChannelReceiver::create was passed an MQDescriptor without an EventFlag";
- return nullptr;
+ if (requestChannelReceiver->mFmqRequestChannel.getEventFlagWord() == nullptr) {
+ return NN_ERROR()
+ << "RequestChannelReceiver::create was passed an MQDescriptor without an EventFlag";
}
- return std::make_unique<RequestChannelReceiver>(std::move(fmqRequestChannel),
- pollingTimeWindow);
+ return requestChannelReceiver;
}
-RequestChannelReceiver::RequestChannelReceiver(std::unique_ptr<FmqRequestChannel> fmqRequestChannel,
- std::chrono::microseconds pollingTimeWindow)
- : mFmqRequestChannel(std::move(fmqRequestChannel)), kPollingTimeWindow(pollingTimeWindow) {}
+RequestChannelReceiver::RequestChannelReceiver(
+ PrivateConstructorTag /*tag*/, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
+ std::chrono::microseconds pollingTimeWindow)
+ : mFmqRequestChannel(requestChannel), kPollingTimeWindow(pollingTimeWindow) {}
-std::optional<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
+nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
RequestChannelReceiver::getBlocking() {
- const auto packet = getPacketBlocking();
- if (!packet) {
- return std::nullopt;
- }
-
- return deserialize(*packet);
+ const auto packet = NN_TRY(getPacketBlocking());
+ return deserialize(packet);
}
void RequestChannelReceiver::invalidate() {
mTeardown = true;
// force unblock
- // ExecutionBurstServer is by default waiting on a request packet. If the
- // client process destroys its burst object, the server may still be waiting
- // on the futex. This force unblock wakes up any thread waiting on the
- // futex.
- // TODO: look for a different/better way to signal/notify the futex to wake
- // up any thread waiting on it
- FmqRequestDatum datum;
- datum.packetInformation({/*.packetSize=*/0, /*.numberOfInputOperands=*/0,
- /*.numberOfOutputOperands=*/0, /*.numberOfPools=*/0});
- mFmqRequestChannel->writeBlocking(&datum, 1);
+ // ExecutionBurstServer is by default waiting on a request packet. If the client process
+ // destroys its burst object, the server may still be waiting on the futex. This force unblock
+ // wakes up any thread waiting on the futex.
+ const auto data = serialize(V1_0::Request{}, V1_2::MeasureTiming::NO, {});
+ mFmqRequestChannel.writeBlocking(data.data(), data.size());
}
-std::optional<std::vector<FmqRequestDatum>> RequestChannelReceiver::getPacketBlocking() {
+nn::Result<std::vector<FmqRequestDatum>> RequestChannelReceiver::getPacketBlocking() {
if (mTeardown) {
- return std::nullopt;
+ return NN_ERROR() << "FMQ object is being torn down";
}
- // First spend time polling if results are available in FMQ instead of
- // waiting on the futex. Polling is more responsive (yielding lower
- // latencies), but can take up more power, so only poll for a limited period
- // of time.
+ // First spend time polling if results are available in FMQ instead of waiting on the futex.
+ // Polling is more responsive (yielding lower latencies), but can take up more power, so only
+ // poll for a limited period of time.
auto& getCurrentTime = std::chrono::high_resolution_clock::now;
const auto timeToStopPolling = getCurrentTime() + kPollingTimeWindow;
@@ -522,173 +507,144 @@
while (getCurrentTime() < timeToStopPolling) {
// if class is being torn down, immediately return
if (mTeardown.load(std::memory_order_relaxed)) {
- return std::nullopt;
+ return NN_ERROR() << "FMQ object is being torn down";
}
- // Check if data is available. If it is, immediately retrieve it and
- // return.
- const size_t available = mFmqRequestChannel->availableToRead();
+ // Check if data is available. If it is, immediately retrieve it and return.
+ const size_t available = mFmqRequestChannel.availableToRead();
if (available > 0) {
- // This is the first point when we know an execution is occurring,
- // so begin to collect systraces. Note that a similar systrace does
- // not exist at the corresponding point in
- // ResultChannelReceiver::getPacketBlocking because the execution is
- // already in flight.
- NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
- "ExecutionBurstServer getting packet");
std::vector<FmqRequestDatum> packet(available);
- const bool success = mFmqRequestChannel->read(packet.data(), available);
+ const bool success = mFmqRequestChannel.readBlocking(packet.data(), available);
if (!success) {
- LOG(ERROR) << "Error receiving packet";
- return std::nullopt;
+ return NN_ERROR() << "Error receiving packet";
}
- return std::make_optional(std::move(packet));
+ return packet;
}
}
- // If we get to this point, we either stopped polling because it was taking
- // too long or polling was not allowed. Instead, perform a blocking call
- // which uses a futex to save power.
+ // If we get to this point, we either stopped polling because it was taking too long or polling
+ // was not allowed. Instead, perform a blocking call which uses a futex to save power.
// wait for request packet and read first element of request packet
FmqRequestDatum datum;
- bool success = mFmqRequestChannel->readBlocking(&datum, 1);
-
- // This is the first point when we know an execution is occurring, so begin
- // to collect systraces. Note that a similar systrace does not exist at the
- // corresponding point in ResultChannelReceiver::getPacketBlocking because
- // the execution is already in flight.
- NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstServer getting packet");
+ bool success = mFmqRequestChannel.readBlocking(&datum, 1);
// retrieve remaining elements
- // NOTE: all of the data is already available at this point, so there's no
- // need to do a blocking wait to wait for more data. This is known because
- // in FMQ, all writes are published (made available) atomically. Currently,
- // the producer always publishes the entire packet in one function call, so
- // if the first element of the packet is available, the remaining elements
- // are also available.
- const size_t count = mFmqRequestChannel->availableToRead();
+ // NOTE: all of the data is already available at this point, so there's no need to do a blocking
+ // wait to wait for more data. This is known because in FMQ, all writes are published (made
+ // available) atomically. Currently, the producer always publishes the entire packet in one
+ // function call, so if the first element of the packet is available, the remaining elements are
+ // also available.
+ const size_t count = mFmqRequestChannel.availableToRead();
std::vector<FmqRequestDatum> packet(count + 1);
std::memcpy(&packet.front(), &datum, sizeof(datum));
- success &= mFmqRequestChannel->read(packet.data() + 1, count);
+ success &= mFmqRequestChannel.read(packet.data() + 1, count);
// terminate loop
if (mTeardown) {
- return std::nullopt;
+ return NN_ERROR() << "FMQ object is being torn down";
}
// ensure packet was successfully received
if (!success) {
- LOG(ERROR) << "Error receiving packet";
- return std::nullopt;
+ return NN_ERROR() << "Error receiving packet";
}
- return std::make_optional(std::move(packet));
+ return packet;
}
// ResultChannelSender methods
-std::unique_ptr<ResultChannelSender> ResultChannelSender::create(
- const FmqResultDescriptor& resultChannel) {
- std::unique_ptr<FmqResultChannel> fmqResultChannel =
- std::make_unique<FmqResultChannel>(resultChannel);
+// Builds a ResultChannelSender over the FMQ described by `resultChannel`. Fails if the FMQ built
+// from the descriptor is invalid or has no EventFlag word.
+nn::GeneralResult<std::unique_ptr<ResultChannelSender>> ResultChannelSender::create(
+ const MQDescriptorSync<FmqResultDatum>& resultChannel) {
+ auto resultChannelSender =
+ std::make_unique<ResultChannelSender>(PrivateConstructorTag{}, resultChannel);
- if (!fmqResultChannel->isValid()) {
- LOG(ERROR) << "Unable to create RequestChannelSender";
- return nullptr;
+ if (!resultChannelSender->mFmqResultChannel.isValid()) {
+ return NN_ERROR() << "Unable to create ResultChannelSender";
}
- if (fmqResultChannel->getEventFlagWord() == nullptr) {
- LOG(ERROR) << "ResultChannelSender::create was passed an MQDescriptor without an EventFlag";
- return nullptr;
+ if (resultChannelSender->mFmqResultChannel.getEventFlagWord() == nullptr) {
+ return NN_ERROR()
+ << "ResultChannelSender::create was passed an MQDescriptor without an EventFlag";
}
- return std::make_unique<ResultChannelSender>(std::move(fmqResultChannel));
+ return resultChannelSender;
}
-ResultChannelSender::ResultChannelSender(std::unique_ptr<FmqResultChannel> fmqResultChannel)
- : mFmqResultChannel(std::move(fmqResultChannel)) {}
+ResultChannelSender::ResultChannelSender(PrivateConstructorTag /*tag*/,
+ const MQDescriptorSync<FmqResultDatum>& resultChannel)
+ : mFmqResultChannel(resultChannel) {}
-bool ResultChannelSender::send(V1_0::ErrorStatus errorStatus,
+void ResultChannelSender::send(V1_0::ErrorStatus errorStatus,
const std::vector<V1_2::OutputShape>& outputShapes,
V1_2::Timing timing) {
const std::vector<FmqResultDatum> serialized = serialize(errorStatus, outputShapes, timing);
- return sendPacket(serialized);
+ sendPacket(serialized);
}
-bool ResultChannelSender::sendPacket(const std::vector<FmqResultDatum>& packet) {
- if (packet.size() > mFmqResultChannel->availableToWrite()) {
+void ResultChannelSender::sendPacket(const std::vector<FmqResultDatum>& packet) {
+ if (packet.size() > mFmqResultChannel.availableToWrite()) {
LOG(ERROR)
<< "ResultChannelSender::sendPacket -- packet size exceeds size available in FMQ";
const std::vector<FmqResultDatum> errorPacket =
serialize(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
- // Always send the packet with "blocking" because this signals the futex
- // and unblocks the consumer if it is waiting on the futex.
- return mFmqResultChannel->writeBlocking(errorPacket.data(), errorPacket.size());
+ // Always send the packet with "blocking" because this signals the futex and unblocks the
+ // consumer if it is waiting on the futex.
+ mFmqResultChannel.writeBlocking(errorPacket.data(), errorPacket.size());
+ } else {
+ // Always send the packet with "blocking" because this signals the futex and unblocks the
+ // consumer if it is waiting on the futex.
+ mFmqResultChannel.writeBlocking(packet.data(), packet.size());
}
-
- // Always send the packet with "blocking" because this signals the futex and
- // unblocks the consumer if it is waiting on the futex.
- return mFmqResultChannel->writeBlocking(packet.data(), packet.size());
}
// ResultChannelReceiver methods
-std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*>
+nn::GeneralResult<
+ std::pair<std::unique_ptr<ResultChannelReceiver>, const MQDescriptorSync<FmqResultDatum>*>>
ResultChannelReceiver::create(size_t channelLength, std::chrono::microseconds pollingTimeWindow) {
- std::unique_ptr<FmqResultChannel> fmqResultChannel =
- std::make_unique<FmqResultChannel>(channelLength, /*confEventFlag=*/true);
- if (!fmqResultChannel->isValid()) {
- LOG(ERROR) << "Unable to create ResultChannelReceiver";
- return {nullptr, nullptr};
+ auto resultChannelReceiver = std::make_unique<ResultChannelReceiver>(
+ PrivateConstructorTag{}, channelLength, pollingTimeWindow);
+ if (!resultChannelReceiver->mFmqResultChannel.isValid()) {
+ return NN_ERROR() << "Unable to create ResultChannelReceiver";
}
- const FmqResultDescriptor* descriptor = fmqResultChannel->getDesc();
- return std::make_pair(
- std::make_unique<ResultChannelReceiver>(std::move(fmqResultChannel), pollingTimeWindow),
- descriptor);
+ const MQDescriptorSync<FmqResultDatum>* descriptor =
+ resultChannelReceiver->mFmqResultChannel.getDesc();
+ return std::make_pair(std::move(resultChannelReceiver), descriptor);
}
-ResultChannelReceiver::ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,
+ResultChannelReceiver::ResultChannelReceiver(PrivateConstructorTag /*tag*/, size_t channelLength,
std::chrono::microseconds pollingTimeWindow)
- : mFmqResultChannel(std::move(fmqResultChannel)), kPollingTimeWindow(pollingTimeWindow) {}
+ : mFmqResultChannel(channelLength, /*configureEventFlagWord=*/true),
+ kPollingTimeWindow(pollingTimeWindow) {}
-std::optional<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::Timing>>
+nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::Timing>>
ResultChannelReceiver::getBlocking() {
- const auto packet = getPacketBlocking();
- if (!packet) {
- return std::nullopt;
- }
-
- return deserialize(*packet);
+ const auto packet = NN_TRY(getPacketBlocking());
+ return deserialize(packet);
}
-void ResultChannelReceiver::invalidate() {
+void ResultChannelReceiver::notifyAsDeadObject() {
mValid = false;
// force unblock
- // ExecutionBurstController waits on a result packet after sending a
- // request. If the driver containing ExecutionBurstServer crashes, the
- // controller may be waiting on the futex. This force unblock wakes up any
- // thread waiting on the futex.
- // TODO: look for a different/better way to signal/notify the futex to
- // wake up any thread waiting on it
- FmqResultDatum datum;
- datum.packetInformation({/*.packetSize=*/0,
- /*.errorStatus=*/V1_0::ErrorStatus::GENERAL_FAILURE,
- /*.numberOfOperands=*/0});
- mFmqResultChannel->writeBlocking(&datum, 1);
+ // ExecutionBurstController waits on a result packet after sending a request. If the driver
+ // containing ExecutionBurstServer crashes, the controller may be waiting on the futex. This
+ // force unblock wakes up any thread waiting on the futex.
+ const auto data = serialize(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
+ mFmqResultChannel.writeBlocking(data.data(), data.size());
}
-std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
+nn::Result<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
if (!mValid) {
- return std::nullopt;
+ return NN_ERROR() << "FMQ object is invalid";
}
- // First spend time polling if results are available in FMQ instead of
- // waiting on the futex. Polling is more responsive (yielding lower
- // latencies), but can take up more power, so only poll for a limited period
- // of time.
+ // First spend time polling if results are available in FMQ instead of waiting on the futex.
+ // Polling is more responsive (yielding lower latencies), but can take up more power, so only
+ // poll for a limited period of time.
auto& getCurrentTime = std::chrono::high_resolution_clock::now;
const auto timeToStopPolling = getCurrentTime() + kPollingTimeWindow;
@@ -696,54 +652,49 @@
while (getCurrentTime() < timeToStopPolling) {
// if class is being torn down, immediately return
if (!mValid.load(std::memory_order_relaxed)) {
- return std::nullopt;
+ return NN_ERROR() << "FMQ object is invalid";
}
- // Check if data is available. If it is, immediately retrieve it and
- // return.
- const size_t available = mFmqResultChannel->availableToRead();
+ // Check if data is available. If it is, immediately retrieve it and return.
+ const size_t available = mFmqResultChannel.availableToRead();
if (available > 0) {
std::vector<FmqResultDatum> packet(available);
- const bool success = mFmqResultChannel->read(packet.data(), available);
+ const bool success = mFmqResultChannel.readBlocking(packet.data(), available);
if (!success) {
- LOG(ERROR) << "Error receiving packet";
- return std::nullopt;
+ return NN_ERROR() << "Error receiving packet";
}
- return std::make_optional(std::move(packet));
+ return packet;
}
}
- // If we get to this point, we either stopped polling because it was taking
- // too long or polling was not allowed. Instead, perform a blocking call
- // which uses a futex to save power.
+ // If we get to this point, we either stopped polling because it was taking too long or polling
+ // was not allowed. Instead, perform a blocking call which uses a futex to save power.
// wait for result packet and read first element of result packet
FmqResultDatum datum;
- bool success = mFmqResultChannel->readBlocking(&datum, 1);
+ bool success = mFmqResultChannel.readBlocking(&datum, 1);
// retrieve remaining elements
- // NOTE: all of the data is already available at this point, so there's no
- // need to do a blocking wait to wait for more data. This is known because
- // in FMQ, all writes are published (made available) atomically. Currently,
- // the producer always publishes the entire packet in one function call, so
- // if the first element of the packet is available, the remaining elements
- // are also available.
- const size_t count = mFmqResultChannel->availableToRead();
+ // NOTE: all of the data is already available at this point, so there's no need to do a blocking
+ // wait to wait for more data. This is known because in FMQ, all writes are published (made
+ // available) atomically. Currently, the producer always publishes the entire packet in one
+ // function call, so if the first element of the packet is available, the remaining elements are
+ // also available.
+ const size_t count = mFmqResultChannel.availableToRead();
std::vector<FmqResultDatum> packet(count + 1);
std::memcpy(&packet.front(), &datum, sizeof(datum));
- success &= mFmqResultChannel->read(packet.data() + 1, count);
+ success &= mFmqResultChannel.read(packet.data() + 1, count);
if (!mValid) {
- return std::nullopt;
+ return NN_ERROR() << "FMQ object is invalid";
}
// ensure packet was successfully received
if (!success) {
- LOG(ERROR) << "Error receiving packet";
- return std::nullopt;
+ return NN_ERROR() << "Error receiving packet";
}
- return std::make_optional(std::move(packet));
+ return packet;
}
} // namespace android::hardware::neuralnetworks::V1_2::utils
diff --git a/neuralnetworks/1.2/utils/src/PreparedModel.cpp b/neuralnetworks/1.2/utils/src/PreparedModel.cpp
index 6841c5e..71a4ea8 100644
--- a/neuralnetworks/1.2/utils/src/PreparedModel.cpp
+++ b/neuralnetworks/1.2/utils/src/PreparedModel.cpp
@@ -18,6 +18,8 @@
#include "Callbacks.h"
#include "Conversions.h"
+#include "ExecutionBurstController.h"
+#include "ExecutionBurstUtils.h"
#include "Utils.h"
#include <android/hardware/neuralnetworks/1.0/types.h>
@@ -27,12 +29,12 @@
#include <nnapi/IPreparedModel.h>
#include <nnapi/Result.h>
#include <nnapi/Types.h>
-#include <nnapi/hal/1.0/Burst.h>
#include <nnapi/hal/1.0/Conversions.h>
#include <nnapi/hal/CommonUtils.h>
#include <nnapi/hal/HandleError.h>
#include <nnapi/hal/ProtectCallback.h>
+#include <chrono>
#include <memory>
#include <tuple>
#include <utility>
@@ -119,7 +121,14 @@
}
nn::GeneralResult<nn::SharedBurst> PreparedModel::configureExecutionBurst() const {
- return V1_0::utils::Burst::create(shared_from_this());
+ auto self = shared_from_this(); // shared_ptr moved into the fallback so it keeps this model alive
+ auto fallback = [preparedModel = std::move(self)](const nn::Request& request,
+ nn::MeasureTiming measure)
+ -> nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> {
+ return preparedModel->execute(request, measure, {}, {}); // NOTE(review): trailing args value-initialized, presumably no deadline / loop timeout — confirm
+ };
+ const auto pollingTimeWindow = getBurstControllerPollingTimeWindow(); // NOTE(review): presumably the poll-before-futex window for the burst FMQs — confirm
+ return ExecutionBurstController::create(kPreparedModel, std::move(fallback), pollingTimeWindow); // see ExecutionBurstController for when `fallback` is invoked
}
std::any PreparedModel::getUnderlyingResource() const {
diff --git a/neuralnetworks/1.2/utils/test/DeviceTest.cpp b/neuralnetworks/1.2/utils/test/DeviceTest.cpp
index 9c8adde..215d44c 100644
--- a/neuralnetworks/1.2/utils/test/DeviceTest.cpp
+++ b/neuralnetworks/1.2/utils/test/DeviceTest.cpp
@@ -772,7 +772,7 @@
EXPECT_NE(result.value(), nullptr);
}
-TEST(DeviceTest, prepareModelFromCacheError) {
+TEST(DeviceTest, prepareModelFromCacheLaunchError) {
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice).value();
@@ -790,6 +790,23 @@
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
+TEST(DeviceTest, prepareModelFromCacheReturnError) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelFromCacheReturn(
+ V1_0::ErrorStatus::NONE, V1_0::ErrorStatus::GENERAL_FAILURE, nullptr)));
+
+ // run test
+ const auto result = device->prepareModelFromCache({}, {}, {}, {});
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
TEST(DeviceTest, prepareModelFromCacheNullptrError) {
// setup call
const auto mockDevice = createMockDevice();
diff --git a/neuralnetworks/1.3/utils/Android.bp b/neuralnetworks/1.3/utils/Android.bp
index 2b1dcc4..28c036a 100644
--- a/neuralnetworks/1.3/utils/Android.bp
+++ b/neuralnetworks/1.3/utils/Android.bp
@@ -42,6 +42,7 @@
"android.hardware.neuralnetworks@1.1",
"android.hardware.neuralnetworks@1.2",
"android.hardware.neuralnetworks@1.3",
+ "libfmq",
],
export_static_lib_headers: [
"neuralnetworks_utils_hal_common",
diff --git a/neuralnetworks/1.3/utils/include/nnapi/hal/1.3/Conversions.h b/neuralnetworks/1.3/utils/include/nnapi/hal/1.3/Conversions.h
index 8e1cdb8..b677c62 100644
--- a/neuralnetworks/1.3/utils/include/nnapi/hal/1.3/Conversions.h
+++ b/neuralnetworks/1.3/utils/include/nnapi/hal/1.3/Conversions.h
@@ -59,7 +59,6 @@
GeneralResult<ErrorStatus> convert(const hal::V1_3::ErrorStatus& errorStatus);
GeneralResult<SharedHandle> convert(const hardware::hidl_handle& handle);
-GeneralResult<SharedMemory> convert(const hardware::hidl_memory& memory);
GeneralResult<std::vector<BufferRole>> convert(
const hardware::hidl_vec<hal::V1_3::BufferRole>& bufferRoles);
diff --git a/neuralnetworks/1.3/utils/src/Conversions.cpp b/neuralnetworks/1.3/utils/src/Conversions.cpp
index 320c74c..9788fe1 100644
--- a/neuralnetworks/1.3/utils/src/Conversions.cpp
+++ b/neuralnetworks/1.3/utils/src/Conversions.cpp
@@ -352,10 +352,6 @@
return validatedConvert(handle);
}
-GeneralResult<SharedMemory> convert(const hardware::hidl_memory& memory) {
- return validatedConvert(memory);
-}
-
GeneralResult<std::vector<BufferRole>> convert(
const hardware::hidl_vec<hal::V1_3::BufferRole>& bufferRoles) {
return validatedConvert(bufferRoles);
diff --git a/neuralnetworks/1.3/utils/src/PreparedModel.cpp b/neuralnetworks/1.3/utils/src/PreparedModel.cpp
index 725e4f5..64275a3 100644
--- a/neuralnetworks/1.3/utils/src/PreparedModel.cpp
+++ b/neuralnetworks/1.3/utils/src/PreparedModel.cpp
@@ -29,8 +29,9 @@
#include <nnapi/Result.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
-#include <nnapi/hal/1.0/Burst.h>
#include <nnapi/hal/1.2/Conversions.h>
+#include <nnapi/hal/1.2/ExecutionBurstController.h>
+#include <nnapi/hal/1.2/ExecutionBurstUtils.h>
#include <nnapi/hal/CommonUtils.h>
#include <nnapi/hal/HandleError.h>
#include <nnapi/hal/ProtectCallback.h>
@@ -199,7 +200,15 @@
}
nn::GeneralResult<nn::SharedBurst> PreparedModel::configureExecutionBurst() const {
- return V1_0::utils::Burst::create(shared_from_this());
+ auto self = shared_from_this(); // shared_ptr moved into the fallback so it keeps this model alive
+ auto fallback = [preparedModel = std::move(self)](const nn::Request& request,
+ nn::MeasureTiming measure)
+ -> nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> {
+ return preparedModel->execute(request, measure, {}, {}); // NOTE(review): trailing args value-initialized, presumably no deadline / loop timeout — confirm
+ };
+ const auto pollingTimeWindow = V1_2::utils::getBurstControllerPollingTimeWindow(); // NOTE(review): presumably the poll-before-futex window for the burst FMQs — confirm
+ return V1_2::utils::ExecutionBurstController::create(kPreparedModel, std::move(fallback), // reuses the 1.2 burst controller for this 1.3 model
+ pollingTimeWindow);
}
std::any PreparedModel::getUnderlyingResource() const {
diff --git a/neuralnetworks/1.3/utils/test/DeviceTest.cpp b/neuralnetworks/1.3/utils/test/DeviceTest.cpp
index f260990..2d1b2f2 100644
--- a/neuralnetworks/1.3/utils/test/DeviceTest.cpp
+++ b/neuralnetworks/1.3/utils/test/DeviceTest.cpp
@@ -794,7 +794,7 @@
EXPECT_NE(result.value(), nullptr);
}
-TEST(DeviceTest, prepareModelFromCacheError) {
+TEST(DeviceTest, prepareModelFromCacheLaunchError) {
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice).value();
@@ -812,6 +812,23 @@
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
+TEST(DeviceTest, prepareModelFromCacheReturnError) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ EXPECT_CALL(*mockDevice, prepareModelFromCache_1_3(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelFromCacheReturn(
+ V1_3::ErrorStatus::NONE, V1_3::ErrorStatus::GENERAL_FAILURE, nullptr)));
+
+ // run test
+ const auto result = device->prepareModelFromCache({}, {}, {}, {});
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
TEST(DeviceTest, prepareModelFromCacheNullptrError) {
// setup call
const auto mockDevice = createMockDevice();
diff --git a/neuralnetworks/TEST_MAPPING b/neuralnetworks/TEST_MAPPING
index 5d168d2..d296828 100644
--- a/neuralnetworks/TEST_MAPPING
+++ b/neuralnetworks/TEST_MAPPING
@@ -16,6 +16,9 @@
"name": "neuralnetworks_utils_hal_1_3_test"
},
{
+ "name": "neuralnetworks_utils_hal_aidl_test"
+ },
+ {
"name": "VtsHalNeuralnetworksV1_0TargetTest",
"options": [
{
diff --git a/neuralnetworks/aidl/utils/Android.bp b/neuralnetworks/aidl/utils/Android.bp
index 2673cae..476dac9 100644
--- a/neuralnetworks/aidl/utils/Android.bp
+++ b/neuralnetworks/aidl/utils/Android.bp
@@ -29,10 +29,12 @@
srcs: ["src/*"],
local_include_dirs: ["include/nnapi/hal/aidl/"],
export_include_dirs: ["include"],
+ cflags: ["-Wthread-safety"],
static_libs: [
"libarect",
"neuralnetworks_types",
"neuralnetworks_utils_hal_common",
+ "neuralnetworks_utils_hal_1_0",
],
shared_libs: [
"android.hardware.neuralnetworks-V1-ndk_platform",
@@ -41,3 +43,38 @@
"libnativewindow",
],
}
+
+cc_test {
+ name: "neuralnetworks_utils_hal_aidl_test",
+ defaults: ["neuralnetworks_utils_defaults"],
+ srcs: [
+ "test/*.cpp",
+ ],
+ static_libs: [
+ "android.hardware.common-V2-ndk_platform",
+ "android.hardware.neuralnetworks-V1-ndk_platform",
+ "libgmock",
+ "libneuralnetworks_common",
+ "neuralnetworks_types",
+ "neuralnetworks_utils_hal_aidl",
+ "neuralnetworks_utils_hal_common",
+ ],
+ shared_libs: [
+ "android.hidl.allocator@1.0",
+ "libbase",
+ "libbinder_ndk",
+ "libcutils",
+ "libhidlbase",
+ "libhidlmemory",
+ "liblog",
+ "libnativewindow",
+ "libutils",
+ ],
+ cflags: [
+ /* GMOCK defines functions for printing all MOCK_DEVICE arguments and
+ * MockDevice contains a string pointer which triggers a warning in the
+ * base logging library. */
+ "-Wno-user-defined-warnings",
+ ],
+ test_suites: ["general-tests"],
+}
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Buffer.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Buffer.h
new file mode 100644
index 0000000..46190c4
--- /dev/null
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Buffer.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_BUFFER_H
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_BUFFER_H
+
+#include <aidl/android/hardware/neuralnetworks/IBuffer.h>
+#include <nnapi/IBuffer.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/CommonUtils.h>
+#include <memory>
+
+// See hardware/interfaces/neuralnetworks/utils/README.md for more information on AIDL interface
+// lifetimes across processes.
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// Class that adapts aidl_hal::IBuffer to nn::IBuffer.
+class Buffer final : public nn::IBuffer {
+ struct PrivateConstructorTag {};
+
+ public:
+ static nn::GeneralResult<std::shared_ptr<const Buffer>> create(
+ std::shared_ptr<aidl_hal::IBuffer> buffer, nn::Request::MemoryDomainToken token);
+
+ Buffer(PrivateConstructorTag tag, std::shared_ptr<aidl_hal::IBuffer> buffer,
+ nn::Request::MemoryDomainToken token);
+
+ nn::Request::MemoryDomainToken getToken() const override;
+
+ nn::GeneralResult<void> copyTo(const nn::SharedMemory& dst) const override;
+ nn::GeneralResult<void> copyFrom(const nn::SharedMemory& src,
+ const nn::Dimensions& dimensions) const override;
+
+ private:
+ const std::shared_ptr<aidl_hal::IBuffer> kBuffer;
+ const nn::Request::MemoryDomainToken kToken;
+};
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_BUFFER_H
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Callbacks.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Callbacks.h
new file mode 100644
index 0000000..8651912
--- /dev/null
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Callbacks.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_CALLBACKS_H
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_CALLBACKS_H
+
+#include <aidl/android/hardware/neuralnetworks/BnPreparedModelCallback.h>
+#include <aidl/android/hardware/neuralnetworks/IDevice.h>
+#include <nnapi/IPreparedModel.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/CommonUtils.h>
+#include <nnapi/hal/TransferValue.h>
+#include <nnapi/hal/aidl/ProtectCallback.h>
+
+// See hardware/interfaces/neuralnetworks/utils/README.md for more information on AIDL interface
+// lifetimes across processes and for protecting asynchronous calls across AIDL.
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// An AIDL callback class to receive the results of IDevice::prepareModel* asynchronously.
+class PreparedModelCallback final : public BnPreparedModelCallback,
+ public hal::utils::IProtectedCallback {
+ public:
+ using Data = nn::GeneralResult<nn::SharedPreparedModel>;
+
+ ndk::ScopedAStatus notify(ErrorStatus status,
+ const std::shared_ptr<IPreparedModel>& preparedModel) override;
+
+ void notifyAsDeadObject() override;
+
+ Data get();
+
+ private:
+ hal::utils::TransferValue<Data> mData;
+};
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_CALLBACKS_H
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Conversions.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Conversions.h
index 1b2f69c..4922a6e 100644
--- a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Conversions.h
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Conversions.h
@@ -46,6 +46,7 @@
#include <aidl/android/hardware/neuralnetworks/SymmPerChannelQuantParams.h>
#include <aidl/android/hardware/neuralnetworks/Timing.h>
+#include <android/binder_auto_utils.h>
#include <nnapi/Result.h>
#include <nnapi/Types.h>
#include <nnapi/hal/CommonUtils.h>
@@ -96,7 +97,11 @@
const aidl_hal::ExtensionOperandTypeInformation& operandTypeInformation);
GeneralResult<SharedHandle> unvalidatedConvert(
const ::aidl::android::hardware::common::NativeHandle& handle);
+GeneralResult<SyncFence> unvalidatedConvert(const ndk::ScopedFileDescriptor& syncFence);
+GeneralResult<Capabilities> convert(const aidl_hal::Capabilities& capabilities);
+GeneralResult<DeviceType> convert(const aidl_hal::DeviceType& deviceType);
+GeneralResult<ErrorStatus> convert(const aidl_hal::ErrorStatus& errorStatus);
GeneralResult<ExecutionPreference> convert(
const aidl_hal::ExecutionPreference& executionPreference);
GeneralResult<SharedMemory> convert(const aidl_hal::Memory& memory);
@@ -106,9 +111,14 @@
GeneralResult<Priority> convert(const aidl_hal::Priority& priority);
GeneralResult<Request::MemoryPool> convert(const aidl_hal::RequestMemoryPool& memoryPool);
GeneralResult<Request> convert(const aidl_hal::Request& request);
+GeneralResult<Timing> convert(const aidl_hal::Timing& timing);
+GeneralResult<SyncFence> convert(const ndk::ScopedFileDescriptor& syncFence);
+GeneralResult<std::vector<Extension>> convert(const std::vector<aidl_hal::Extension>& extension);
GeneralResult<std::vector<Operation>> convert(const std::vector<aidl_hal::Operation>& outputShapes);
GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories);
+GeneralResult<std::vector<OutputShape>> convert(
+ const std::vector<aidl_hal::OutputShape>& outputShapes);
GeneralResult<std::vector<uint32_t>> toUnsigned(const std::vector<int32_t>& vec);
@@ -118,14 +128,62 @@
namespace nn = ::android::nn;
+nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(const nn::CacheToken& cacheToken);
+nn::GeneralResult<BufferDesc> unvalidatedConvert(const nn::BufferDesc& bufferDesc);
+nn::GeneralResult<BufferRole> unvalidatedConvert(const nn::BufferRole& bufferRole);
+nn::GeneralResult<bool> unvalidatedConvert(const nn::MeasureTiming& measureTiming);
nn::GeneralResult<Memory> unvalidatedConvert(const nn::SharedMemory& memory);
nn::GeneralResult<OutputShape> unvalidatedConvert(const nn::OutputShape& outputShape);
nn::GeneralResult<ErrorStatus> unvalidatedConvert(const nn::ErrorStatus& errorStatus);
+nn::GeneralResult<ExecutionPreference> unvalidatedConvert(
+ const nn::ExecutionPreference& executionPreference);
+nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType);
+nn::GeneralResult<OperandLifeTime> unvalidatedConvert(const nn::Operand::LifeTime& operandLifeTime);
+nn::GeneralResult<DataLocation> unvalidatedConvert(const nn::DataLocation& location);
+nn::GeneralResult<std::optional<OperandExtraParams>> unvalidatedConvert(
+ const nn::Operand::ExtraParams& extraParams);
+nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand);
+nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType);
+nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation);
+nn::GeneralResult<Subgraph> unvalidatedConvert(const nn::Model::Subgraph& subgraph);
+nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(
+ const nn::Model::OperandValues& operandValues);
+nn::GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
+ const nn::Model::ExtensionNameAndPrefix& extensionNameToPrefix);
+nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model);
+nn::GeneralResult<Priority> unvalidatedConvert(const nn::Priority& priority);
+nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request);
+nn::GeneralResult<RequestArgument> unvalidatedConvert(const nn::Request::Argument& requestArgument);
+nn::GeneralResult<RequestMemoryPool> unvalidatedConvert(const nn::Request::MemoryPool& memoryPool);
+nn::GeneralResult<Timing> unvalidatedConvert(const nn::Timing& timing);
+nn::GeneralResult<int64_t> unvalidatedConvert(const nn::Duration& duration);
+nn::GeneralResult<int64_t> unvalidatedConvert(const nn::OptionalDuration& optionalDuration);
+nn::GeneralResult<int64_t> unvalidatedConvert(const nn::OptionalTimePoint& optionalTimePoint);
+nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvert(const nn::SyncFence& syncFence);
+nn::GeneralResult<common::NativeHandle> unvalidatedConvert(const nn::SharedHandle& sharedHandle);
+nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvertCache(
+ const nn::SharedHandle& handle);
+nn::GeneralResult<std::vector<uint8_t>> convert(const nn::CacheToken& cacheToken);
+nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc);
+nn::GeneralResult<bool> convert(const nn::MeasureTiming& measureTiming);
nn::GeneralResult<Memory> convert(const nn::SharedMemory& memory);
nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& errorStatus);
+nn::GeneralResult<ExecutionPreference> convert(const nn::ExecutionPreference& executionPreference);
+nn::GeneralResult<Model> convert(const nn::Model& model);
+nn::GeneralResult<Priority> convert(const nn::Priority& priority);
+nn::GeneralResult<Request> convert(const nn::Request& request);
+nn::GeneralResult<Timing> convert(const nn::Timing& timing);
+nn::GeneralResult<int64_t> convert(const nn::OptionalDuration& optionalDuration);
+nn::GeneralResult<int64_t> convert(const nn::OptionalTimePoint& optionalTimePoint);
+
+nn::GeneralResult<std::vector<BufferRole>> convert(const std::vector<nn::BufferRole>& bufferRoles);
nn::GeneralResult<std::vector<OutputShape>> convert(
const std::vector<nn::OutputShape>& outputShapes);
+nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
+ const std::vector<nn::SharedHandle>& handles);
+nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
+ const std::vector<nn::SyncFence>& syncFences);
nn::GeneralResult<std::vector<int32_t>> toSigned(const std::vector<uint32_t>& vec);
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Device.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Device.h
new file mode 100644
index 0000000..eb194e3
--- /dev/null
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Device.h
@@ -0,0 +1,106 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_DEVICE_H
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_DEVICE_H
+
+#include <aidl/android/hardware/neuralnetworks/IDevice.h>
+#include <nnapi/IBuffer.h>
+#include <nnapi/IDevice.h>
+#include <nnapi/OperandTypes.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/CommonUtils.h>
+#include <nnapi/hal/aidl/ProtectCallback.h>
+
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+// See hardware/interfaces/neuralnetworks/utils/README.md for more information on AIDL interface
+// lifetimes across processes and for protecting asynchronous calls across AIDL.
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// Class that adapts aidl_hal::IDevice to nn::IDevice.
+class Device final : public nn::IDevice {
+ struct PrivateConstructorTag {};
+
+ public:
+ // Creates a Device named `name` that adapts `device`; this is the intended entry point.
+ static nn::GeneralResult<std::shared_ptr<const Device>> create(
+ std::string name, std::shared_ptr<aidl_hal::IDevice> device);
+
+ // Public only so create() can construct the object (e.g. via std::make_shared);
+ // PrivateConstructorTag is private to Device, which discourages other callers.
+ Device(PrivateConstructorTag tag, std::string name, std::string versionString,
+ nn::DeviceType deviceType, std::vector<nn::Extension> extensions,
+ nn::Capabilities capabilities, std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded,
+ std::shared_ptr<aidl_hal::IDevice> device, DeathHandler deathHandler);
+
+ // Device metadata; presumably returns the values held in the const k* members below.
+ const std::string& getName() const override;
+ const std::string& getVersionString() const override;
+ nn::Version getFeatureLevel() const override;
+ nn::DeviceType getType() const override;
+ bool isUpdatable() const override;
+ const std::vector<nn::Extension>& getSupportedExtensions() const override;
+ const nn::Capabilities& getCapabilities() const override;
+ std::pair<uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const override;
+
+ // The nn::IDevice operations below presumably forward to kDevice over AIDL.
+ nn::GeneralResult<void> wait() const override;
+
+ nn::GeneralResult<std::vector<bool>> getSupportedOperations(
+ const nn::Model& model) const override;
+
+ nn::GeneralResult<nn::SharedPreparedModel> prepareModel(
+ const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
+ nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
+ const std::vector<nn::SharedHandle>& dataCache,
+ const nn::CacheToken& token) const override;
+
+ nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCache(
+ nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
+ const std::vector<nn::SharedHandle>& dataCache,
+ const nn::CacheToken& token) const override;
+
+ nn::GeneralResult<nn::SharedBuffer> allocate(
+ const nn::BufferDesc& desc, const std::vector<nn::SharedPreparedModel>& preparedModels,
+ const std::vector<nn::BufferRole>& inputRoles,
+ const std::vector<nn::BufferRole>& outputRoles) const override;
+
+ // Non-owning; presumably points at kDeathHandler's DeathMonitor, so must not outlive this.
+ DeathMonitor* getDeathMonitor() const;
+
+ private:
+ const std::string kName;
+ const std::string kVersionString;
+ const nn::DeviceType kDeviceType;
+ const std::vector<nn::Extension> kExtensions;
+ const nn::Capabilities kCapabilities;
+ const std::pair<uint32_t, uint32_t> kNumberOfCacheFilesNeeded;
+ const std::shared_ptr<aidl_hal::IDevice> kDevice;
+ // Presumably keeps a death notification registered on kDevice for this Device's lifetime.
+ const DeathHandler kDeathHandler;
+};
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_DEVICE_H
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/PreparedModel.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/PreparedModel.h
new file mode 100644
index 0000000..9b28588
--- /dev/null
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/PreparedModel.h
@@ -0,0 +1,75 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_PREPARED_MODEL_H
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_PREPARED_MODEL_H
+
+#include <aidl/android/hardware/neuralnetworks/IPreparedModel.h>
+#include <nnapi/IPreparedModel.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/CommonUtils.h>
+#include <nnapi/hal/aidl/ProtectCallback.h>
+
+#include <any>
+#include <memory>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+// See hardware/interfaces/neuralnetworks/utils/README.md for more information on AIDL interface
+// lifetimes across processes and for protecting asynchronous calls across AIDL.
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// Class that adapts aidl_hal::IPreparedModel to nn::IPreparedModel.
+class PreparedModel final : public nn::IPreparedModel,
+ public std::enable_shared_from_this<PreparedModel> {
+ struct PrivateConstructorTag {};
+
+ public:
+ // Creates a PreparedModel adapting `preparedModel`; this is the intended entry point.
+ static nn::GeneralResult<std::shared_ptr<const PreparedModel>> create(
+ std::shared_ptr<aidl_hal::IPreparedModel> preparedModel);
+
+ // Public only so create() can construct the object; PrivateConstructorTag is private.
+ PreparedModel(PrivateConstructorTag tag,
+ std::shared_ptr<aidl_hal::IPreparedModel> preparedModel);
+
+ // nn::IPreparedModel operations; presumably forwarded to kPreparedModel over AIDL.
+ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
+ const nn::Request& request, nn::MeasureTiming measure,
+ const nn::OptionalTimePoint& deadline,
+ const nn::OptionalDuration& loopTimeoutDuration) const override;
+
+ nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> executeFenced(
+ const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
+ nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
+ const nn::OptionalDuration& loopTimeoutDuration,
+ const nn::OptionalDuration& timeoutDurationAfterFence) const override;
+
+ nn::GeneralResult<nn::SharedBurst> configureExecutionBurst() const override;
+
+ // Presumably returns kPreparedModel (the underlying AIDL object) wrapped in std::any.
+ std::any getUnderlyingResource() const override;
+
+ private:
+ const std::shared_ptr<aidl_hal::IPreparedModel> kPreparedModel;
+};
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_PREPARED_MODEL_H
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h
new file mode 100644
index 0000000..ab1108c
--- /dev/null
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h
@@ -0,0 +1,88 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_PROTECT_CALLBACK_H
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_PROTECT_CALLBACK_H
+
+#include <android-base/scopeguard.h>
+#include <android-base/thread_annotations.h>
+#include <android/binder_interface_utils.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/CommonUtils.h>
+#include <nnapi/hal/ProtectCallback.h>
+
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <vector>
+
+// See hardware/interfaces/neuralnetworks/utils/README.md for more information on AIDL interface
+// lifetimes across processes and for protecting asynchronous calls across AIDL.
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// Thread safe class: mObjects is only accessed while holding mMutex (see GUARDED_BY).
+// Presumably tracks the IProtectedCallback objects to notify when the watched object dies.
+class DeathMonitor final {
+ public:
+ // Presumably the AIBinder death-recipient entry point; `cookie` is assumed to be a DeathMonitor*.
+ static void serviceDied(void* cookie);
+ // Presumably notifies every registered callback that the remote object has died.
+ void serviceDied();
+ // Precondition: `killable` must be non-null.
+ void add(hal::utils::IProtectedCallback* killable) const;
+ // Precondition: `killable` must be non-null.
+ void remove(hal::utils::IProtectedCallback* killable) const;
+
+ private:
+ mutable std::mutex mMutex;
+ mutable std::vector<hal::utils::IProtectedCallback*> mObjects GUARDED_BY(mMutex);
+};
+
+// Presumably links a death recipient to `object` in create() and unlinks it in the destructor,
+// reporting deaths to the shared DeathMonitor. Movable, but not copyable or assignable.
+class DeathHandler final {
+ public:
+ static nn::GeneralResult<DeathHandler> create(std::shared_ptr<ndk::ICInterface> object);
+
+ DeathHandler(const DeathHandler&) = delete;
+ DeathHandler(DeathHandler&&) noexcept = default;
+ DeathHandler& operator=(const DeathHandler&) = delete;
+ DeathHandler& operator=(DeathHandler&&) noexcept = delete;
+ ~DeathHandler();
+
+ using Cleanup = std::function<void()>;
+ // Precondition: `killable` must be non-null.
+ // The returned guard must outlive the protected call; its cleanup presumably removes `killable`.
+ [[nodiscard]] ::android::base::ScopeGuard<Cleanup> protectCallback(
+ hal::utils::IProtectedCallback* killable) const;
+
+ std::shared_ptr<DeathMonitor> getDeathMonitor() const { return kDeathMonitor; }
+
+ private:
+ DeathHandler(std::shared_ptr<ndk::ICInterface> object,
+ ndk::ScopedAIBinder_DeathRecipient deathRecipient,
+ std::shared_ptr<DeathMonitor> deathMonitor);
+
+ std::shared_ptr<ndk::ICInterface> kObject;
+ ndk::ScopedAIBinder_DeathRecipient kDeathRecipient;
+ std::shared_ptr<DeathMonitor> kDeathMonitor;
+};
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_PROTECT_CALLBACK_H
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Service.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Service.h
new file mode 100644
index 0000000..b4587ac
--- /dev/null
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Service.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_SERVICE_H
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_SERVICE_H
+
+#include <nnapi/IDevice.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/CommonUtils.h>
+
+#include <string>
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+// Presumably looks up the AIDL IDevice instance `name` and adapts it with utils::Device.
+nn::GeneralResult<nn::SharedDevice> getDevice(const std::string& name);
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_SERVICE_H
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Utils.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Utils.h
index 79b511d..58dcfe3 100644
--- a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Utils.h
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Utils.h
@@ -23,6 +23,7 @@
#include <nnapi/Result.h>
#include <nnapi/Types.h>
#include <nnapi/Validation.h>
+#include <nnapi/hal/HandleError.h>
namespace aidl::android::hardware::neuralnetworks::utils {
@@ -52,6 +53,12 @@
nn::GeneralResult<RequestMemoryPool> clone(const RequestMemoryPool& requestPool);
nn::GeneralResult<Model> clone(const Model& model);
+nn::GeneralResult<void> handleTransportError(const ndk::ScopedAStatus& ret);
+// Early-returns NN_ERROR (extend via <<) if handleTransportError(ret) fails; unique local name.
+#define HANDLE_ASTATUS(ret)                                                                 \
+    for (const auto nnAStatusResult_ = handleTransportError(ret); !nnAStatusResult_.ok();) \
+    return NN_ERROR(nnAStatusResult_.error().code) << nnAStatusResult_.error().message << ": "
+
} // namespace aidl::android::hardware::neuralnetworks::utils
#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_H
diff --git a/neuralnetworks/aidl/utils/src/Buffer.cpp b/neuralnetworks/aidl/utils/src/Buffer.cpp
new file mode 100644
index 0000000..c729a68
--- /dev/null
+++ b/neuralnetworks/aidl/utils/src/Buffer.cpp
@@ -0,0 +1,78 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Buffer.h"
+
+#include <nnapi/IPreparedModel.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+
+#include "Conversions.h"
+#include "Utils.h"
+#include "nnapi/hal/aidl/Conversions.h"
+
+#include <memory>
+#include <utility>
+
+// See hardware/interfaces/neuralnetworks/utils/README.md for more information on AIDL interface
+// lifetimes across processes.
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+nn::GeneralResult<std::shared_ptr<const Buffer>> Buffer::create(
+        std::shared_ptr<aidl_hal::IBuffer> buffer, nn::Request::MemoryDomainToken token) {
+    // Validate here so the constructor's CHECKs never fire for objects built via create().
+    if (!buffer) {
+        return NN_ERROR() << "aidl_hal::utils::Buffer::create must have non-null buffer";
+    }
+    if (token == static_cast<nn::Request::MemoryDomainToken>(0)) {
+        return NN_ERROR() << "aidl_hal::utils::Buffer::create must have non-zero token";
+    }
+    return std::make_shared<const Buffer>(PrivateConstructorTag{}, std::move(buffer), token);
+}
+
+Buffer::Buffer(PrivateConstructorTag /*tag*/, std::shared_ptr<aidl_hal::IBuffer> buffer,
+ nn::Request::MemoryDomainToken token)
+ : kBuffer(std::move(buffer)), kToken(token) {
+ CHECK(kBuffer != nullptr);  // create() rejects a null buffer before constructing
+ CHECK(kToken != static_cast<nn::Request::MemoryDomainToken>(0));  // create() rejects 0 too
+}
+
+nn::Request::MemoryDomainToken Buffer::getToken() const {
+ return kToken;  // non-zero, as enforced by create() and the constructor's CHECK
+}
+
+nn::GeneralResult<void> Buffer::copyTo(const nn::SharedMemory& dst) const {
+    // Convert the canonical destination into its AIDL form; conversion errors propagate.
+    const auto aidlDst = NN_TRY(convert(dst));
+    // HANDLE_ASTATUS returns early, with context, if handleTransportError reports a failure.
+    const auto transportStatus = kBuffer->copyTo(aidlDst);
+    HANDLE_ASTATUS(transportStatus) << "IBuffer::copyTo failed";
+    return {};
+}
+
+nn::GeneralResult<void> Buffer::copyFrom(const nn::SharedMemory& src,
+                                         const nn::Dimensions& dimensions) const {
+    // Both conversions can fail; NN_TRY propagates the first error unchanged.
+    const auto aidlSrc = NN_TRY(convert(src));
+    const auto aidlDimensions = NN_TRY(toSigned(dimensions));  // AIDL dimensions are int32
+    const auto transportStatus = kBuffer->copyFrom(aidlSrc, aidlDimensions);
+    HANDLE_ASTATUS(transportStatus) << "IBuffer::copyFrom failed";
+
+    return {};
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/src/Callbacks.cpp b/neuralnetworks/aidl/utils/src/Callbacks.cpp
new file mode 100644
index 0000000..8055665
--- /dev/null
+++ b/neuralnetworks/aidl/utils/src/Callbacks.cpp
@@ -0,0 +1,61 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Callbacks.h"
+
+#include "Conversions.h"
+#include "PreparedModel.h"
+#include "ProtectCallback.h"
+#include "Utils.h"
+
+#include <nnapi/IPreparedModel.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+
+#include <utility>
+
+// See hardware/interfaces/neuralnetworks/utils/README.md for more information on AIDL interface
+// lifetimes across processes and for protecting asynchronous calls across AIDL.
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+namespace {
+
+// Converts the results of IDevice::prepareModel* to the NN canonical format: a failing `status`
+// becomes an error via HANDLE_HAL_STATUS; otherwise `preparedModel` is adapted with
+// PreparedModel::create, whose error (if any) is propagated by NN_TRY.
+nn::GeneralResult<nn::SharedPreparedModel> prepareModelCallback(
+ ErrorStatus status, const std::shared_ptr<IPreparedModel>& preparedModel) {
+ HANDLE_HAL_STATUS(status) << "model preparation failed with " << toString(status);
+ return NN_TRY(PreparedModel::create(preparedModel));
+}
+
+} // namespace
+
+ndk::ScopedAStatus PreparedModelCallback::notify(
+ ErrorStatus status, const std::shared_ptr<IPreparedModel>& preparedModel) {
+ mData.put(prepareModelCallback(status, preparedModel));
+ return ndk::ScopedAStatus::ok();  // failures reach get() through mData, not the binder status
+}
+
+void PreparedModelCallback::notifyAsDeadObject() {  // presumably invoked by DeathMonitor
+ mData.put(NN_ERROR(nn::ErrorStatus::DEAD_OBJECT) << "Dead object");
+}
+
+PreparedModelCallback::Data PreparedModelCallback::get() {
+ return mData.take();  // presumably blocks until notify() or notifyAsDeadObject() has run
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/src/Conversions.cpp b/neuralnetworks/aidl/utils/src/Conversions.cpp
index db3504b..5d9c55b 100644
--- a/neuralnetworks/aidl/utils/src/Conversions.cpp
+++ b/neuralnetworks/aidl/utils/src/Conversions.cpp
@@ -18,6 +18,8 @@
#include <aidl/android/hardware/common/NativeHandle.h>
#include <android-base/logging.h>
+#include <android-base/unique_fd.h>
+#include <android/binder_auto_utils.h>
#include <android/hardware_buffer.h>
#include <cutils/native_handle.h>
#include <nnapi/OperandTypes.h>
@@ -42,14 +44,17 @@
#define VERIFY_NON_NEGATIVE(value) \
while (UNLIKELY(value < 0)) return NN_ERROR()
-namespace {
+#define VERIFY_LE_INT32_MAX(value) /* early-returns NN_ERROR() if value > INT32_MAX */ \
+    while (UNLIKELY((value) > std::numeric_limits<int32_t>::max())) return NN_ERROR()
+namespace {
template <typename Type>
constexpr std::underlying_type_t<Type> underlyingType(Type value) {
return static_cast<std::underlying_type_t<Type>>(value);
}
constexpr auto kVersion = android::nn::Version::ANDROID_S;
+constexpr int64_t kNoTiming = -1;  // Timing field sentinel; converted to an empty OptionalDuration
} // namespace
@@ -134,13 +139,8 @@
std::vector<base::unique_fd> fds;
fds.reserve(aidlNativeHandle.fds.size());
for (const auto& fd : aidlNativeHandle.fds) {
- const int dupFd = dup(fd.get());
- if (dupFd == -1) {
- // TODO(b/120417090): is ANEURALNETWORKS_UNEXPECTED_NULL the correct error to return
- // here?
- return NN_ERROR() << "Failed to dup the fd";
- }
- fds.emplace_back(dupFd);
+ auto duplicatedFd = NN_TRY(dupFd(fd.get()));
+ fds.emplace_back(duplicatedFd.release());
}
return Handle{.fds = std::move(fds), .ints = aidlNativeHandle.ints};
@@ -157,16 +157,12 @@
using UniqueNativeHandle = std::unique_ptr<native_handle_t, NativeHandleDeleter>;
-static nn::GeneralResult<UniqueNativeHandle> nativeHandleFromAidlHandle(
- const NativeHandle& handle) {
+static GeneralResult<UniqueNativeHandle> nativeHandleFromAidlHandle(const NativeHandle& handle) {
std::vector<base::unique_fd> fds;
fds.reserve(handle.fds.size());
for (const auto& fd : handle.fds) {
- const int dupFd = dup(fd.get());
- if (dupFd == -1) {
- return NN_ERROR() << "Failed to dup the fd";
- }
- fds.emplace_back(dupFd);
+ auto duplicatedFd = NN_TRY(dupFd(fd.get()));
+ fds.emplace_back(duplicatedFd.release());
}
constexpr size_t kIntMax = std::numeric_limits<int>::max();
@@ -382,14 +378,14 @@
GeneralResult<SharedMemory> unvalidatedConvert(const aidl_hal::Memory& memory) {
VERIFY_NON_NEGATIVE(memory.size) << "Memory size must not be negative";
- if (memory.size > std::numeric_limits<uint32_t>::max()) {
+ if (memory.size > std::numeric_limits<size_t>::max()) {
return NN_ERROR() << "Memory: size must be <= std::numeric_limits<size_t>::max()";
}
if (memory.name != "hardware_buffer_blob") {
return std::make_shared<const Memory>(Memory{
.handle = NN_TRY(unvalidatedConvertHelper(memory.handle)),
- .size = static_cast<uint32_t>(memory.size),
+ .size = static_cast<size_t>(memory.size),
.name = memory.name,
});
}
@@ -434,11 +430,28 @@
return std::make_shared<const Memory>(Memory{
.handle = HardwareBufferHandle(hardwareBuffer, /*takeOwnership=*/true),
- .size = static_cast<uint32_t>(memory.size),
+ .size = static_cast<size_t>(memory.size),
.name = memory.name,
});
}
+GeneralResult<Timing> unvalidatedConvert(const aidl_hal::Timing& timing) {
+    // kNoTiming (-1) is the only permitted negative value for either field.
+    if (timing.timeInDriver < kNoTiming) {
+        return NN_ERROR() << "Timing: timeInDriver must not be less than -1";
+    }
+    if (timing.timeOnDevice < kNoTiming) {
+        return NN_ERROR() << "Timing: timeOnDevice must not be less than -1";
+    }
+    const auto toOptionalDuration = [](int64_t halDuration) -> OptionalDuration {
+        return halDuration == kNoTiming
+                       ? OptionalDuration{}
+                       : OptionalDuration{nn::Duration(static_cast<uint64_t>(halDuration))};
+    };
+    return Timing{.timeOnDevice = toOptionalDuration(timing.timeOnDevice),
+                  .timeInDriver = toOptionalDuration(timing.timeInDriver)};
+}
+
GeneralResult<Model::OperandValues> unvalidatedConvert(const std::vector<uint8_t>& operandValues) {
return Model::OperandValues(operandValues.data(), operandValues.size());
}
@@ -515,6 +528,23 @@
return std::make_shared<const Handle>(NN_TRY(unvalidatedConvertHelper(aidlNativeHandle)));
}
+GeneralResult<SyncFence> unvalidatedConvert(const ndk::ScopedFileDescriptor& syncFence) {
+    // Duplicate the descriptor so the new SyncFence does not take ownership of syncFence's fd.
+    return SyncFence::create(NN_TRY(dupFd(syncFence.get())));
+}
+
+GeneralResult<Capabilities> convert(const aidl_hal::Capabilities& capabilities) {
+ return validatedConvert(capabilities);  // validating entry point (cf. unvalidatedConvert)
+}
+
+GeneralResult<DeviceType> convert(const aidl_hal::DeviceType& deviceType) {
+ return validatedConvert(deviceType);  // validating entry point (cf. unvalidatedConvert)
+}
+
+GeneralResult<ErrorStatus> convert(const aidl_hal::ErrorStatus& errorStatus) {
+ return validatedConvert(errorStatus);  // validating entry point (cf. unvalidatedConvert)
+}
+
GeneralResult<ExecutionPreference> convert(
const aidl_hal::ExecutionPreference& executionPreference) {
return validatedConvert(executionPreference);
@@ -548,6 +578,18 @@
return validatedConvert(request);
}
+GeneralResult<Timing> convert(const aidl_hal::Timing& timing) {
+ return validatedConvert(timing);  // validating entry point (cf. unvalidatedConvert)
+}
+
+GeneralResult<SyncFence> convert(const ndk::ScopedFileDescriptor& syncFence) {
+ return unvalidatedConvert(syncFence);  // note: no validatedConvert step for sync fences
+}
+
+GeneralResult<std::vector<Extension>> convert(const std::vector<aidl_hal::Extension>& extension) {
+ return validatedConvert(extension);  // validating entry point for the whole vector
+}
+
GeneralResult<std::vector<Operation>> convert(const std::vector<aidl_hal::Operation>& operations) {
return unvalidatedConvert(operations);
}
@@ -556,6 +598,11 @@
return validatedConvert(memories);
}
+GeneralResult<std::vector<OutputShape>> convert(
+ const std::vector<aidl_hal::OutputShape>& outputShapes) {
+ return validatedConvert(outputShapes);  // validating entry point for the whole vector
+}
+
GeneralResult<std::vector<uint32_t>> toUnsigned(const std::vector<int32_t>& vec) {
if (!std::all_of(vec.begin(), vec.end(), [](int32_t v) { return v >= 0; })) {
return NN_ERROR() << "Negative value passed to conversion from signed to unsigned";
@@ -575,14 +622,21 @@
template <typename Type>
nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
const std::vector<Type>& arguments) {
- std::vector<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
- for (size_t i = 0; i < arguments.size(); ++i) {
- halObject[i] = NN_TRY(unvalidatedConvert(arguments[i]));
+ std::vector<UnvalidatedConvertOutput<Type>> halObject;
+ halObject.reserve(arguments.size());  // unlike vector(n), needs no default-constructible type
+ for (const auto& argument : arguments) {
+ halObject.push_back(NN_TRY(unvalidatedConvert(argument)));
}
return halObject;
}
template <typename Type>
+nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
+ const std::vector<Type>& arguments) {
+ return unvalidatedConvertVec(arguments);  // element-wise; fails on the first bad element
+}
+
+template <typename Type>
nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
const auto maybeVersion = nn::validate(canonical);
if (!maybeVersion.has_value()) {
@@ -609,29 +663,29 @@
common::NativeHandle aidlNativeHandle;
aidlNativeHandle.fds.reserve(handle.fds.size());
for (const auto& fd : handle.fds) {
- const int dupFd = dup(fd.get());
- if (dupFd == -1) {
- // TODO(b/120417090): is ANEURALNETWORKS_UNEXPECTED_NULL the correct error to return
- // here?
- return NN_ERROR() << "Failed to dup the fd";
- }
- aidlNativeHandle.fds.emplace_back(dupFd);
+ auto duplicatedFd = NN_TRY(nn::dupFd(fd.get()));
+ aidlNativeHandle.fds.emplace_back(duplicatedFd.release());
}
aidlNativeHandle.ints = handle.ints;
return aidlNativeHandle;
}
+// Helper template for std::visit: merges several lambdas into one overload set.
+template <class... Ts>
+struct overloaded : Ts... {
+ using Ts::operator()...;
+};
+template <class... Ts>
+overloaded(Ts...)->overloaded<Ts...>;  // C++17 deduction guide so overloaded{...} deduces Ts
+
static nn::GeneralResult<common::NativeHandle> aidlHandleFromNativeHandle(
const native_handle_t& handle) {
common::NativeHandle aidlNativeHandle;
aidlNativeHandle.fds.reserve(handle.numFds);
for (int i = 0; i < handle.numFds; ++i) {
- const int dupFd = dup(handle.data[i]);
- if (dupFd == -1) {
- return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE) << "Failed to dup the fd";
- }
- aidlNativeHandle.fds.emplace_back(dupFd);
+ auto duplicatedFd = NN_TRY(nn::dupFd(handle.data[i]));
+ aidlNativeHandle.fds.emplace_back(duplicatedFd.release());
}
aidlNativeHandle.ints = std::vector<int>(&handle.data[handle.numFds],
@@ -642,6 +696,30 @@
} // namespace
+// Copies the cache token's bytes into a plain byte vector.
+nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(const nn::CacheToken& cacheToken) {
+    std::vector<uint8_t> tokenBytes(cacheToken.begin(), cacheToken.end());
+    return tokenBytes;
+}
+
+// Converts a buffer descriptor; fails if any dimension does not fit in int32_t.
+nn::GeneralResult<BufferDesc> unvalidatedConvert(const nn::BufferDesc& bufferDesc) {
+    auto signedDimensions = NN_TRY(toSigned(bufferDesc.dimensions));
+    return BufferDesc{.dimensions = std::move(signedDimensions)};
+}
+
+// Converts a buffer role. Both indices must fit in int32_t; frequency is copied unchanged.
+nn::GeneralResult<BufferRole> unvalidatedConvert(const nn::BufferRole& bufferRole) {
+    VERIFY_LE_INT32_MAX(bufferRole.modelIndex)
+            << "BufferRole: modelIndex must be <= std::numeric_limits<int32_t>::max()";
+    VERIFY_LE_INT32_MAX(bufferRole.ioIndex)
+            << "BufferRole: ioIndex must be <= std::numeric_limits<int32_t>::max()";
+    const auto aidlModelIndex = static_cast<int32_t>(bufferRole.modelIndex);
+    const auto aidlIoIndex = static_cast<int32_t>(bufferRole.ioIndex);
+    return BufferRole{
+            .modelIndex = aidlModelIndex,
+            .ioIndex = aidlIoIndex,
+            .frequency = bufferRole.frequency,
+    };
+}
+
+// The AIDL interface encodes MeasureTiming as a bool: true only for MeasureTiming::YES.
+nn::GeneralResult<bool> unvalidatedConvert(const nn::MeasureTiming& measureTiming) {
+    const bool shouldMeasure = (measureTiming == nn::MeasureTiming::YES);
+    return shouldMeasure;
+}
+
nn::GeneralResult<common::NativeHandle> unvalidatedConvert(const nn::SharedHandle& sharedHandle) {
CHECK(sharedHandle != nullptr);
return unvalidatedConvert(*sharedHandle);
@@ -707,6 +785,230 @@
.isSufficient = outputShape.isSufficient};
}
+// The three enum conversions below are direct static_casts. They rely on the canonical and AIDL
+// enumerators having identical numeric values; that is not checked here.
+nn::GeneralResult<ExecutionPreference> unvalidatedConvert(
+        const nn::ExecutionPreference& executionPreference) {
+    const auto aidlPreference = static_cast<ExecutionPreference>(executionPreference);
+    return aidlPreference;
+}
+
+nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) {
+    const auto aidlOperandType = static_cast<OperandType>(operandType);
+    return aidlOperandType;
+}
+
+nn::GeneralResult<OperandLifeTime> unvalidatedConvert(
+        const nn::Operand::LifeTime& operandLifeTime) {
+    const auto aidlLifeTime = static_cast<OperandLifeTime>(operandLifeTime);
+    return aidlLifeTime;
+}
+
+// Converts a data location. poolIndex must fit in int32_t; offset and length are cast to int64_t.
+nn::GeneralResult<DataLocation> unvalidatedConvert(const nn::DataLocation& location) {
+    VERIFY_LE_INT32_MAX(location.poolIndex)
+            << "DataLocation: pool index must be <= std::numeric_limits<int32_t>::max()";
+    const auto aidlPoolIndex = static_cast<int32_t>(location.poolIndex);
+    const auto aidlOffset = static_cast<int64_t>(location.offset);
+    const auto aidlLength = static_cast<int64_t>(location.length);
+    return DataLocation{.poolIndex = aidlPoolIndex, .offset = aidlOffset, .length = aidlLength};
+}
+
+// Maps the canonical extra-params variant onto the optional AIDL union:
+//   NoParams                  -> std::nullopt
+//   SymmPerChannelQuantParams -> channelQuant (fails if channelDim does not fit in int32_t)
+//   ExtensionParams           -> extension
+nn::GeneralResult<std::optional<OperandExtraParams>> unvalidatedConvert(
+ const nn::Operand::ExtraParams& extraParams) {
+ return std::visit(
+ overloaded{
+ [](const nn::Operand::NoParams&)
+ -> nn::GeneralResult<std::optional<OperandExtraParams>> {
+ return std::nullopt;
+ },
+ [](const nn::Operand::SymmPerChannelQuantParams& symmPerChannelQuantParams)
+ -> nn::GeneralResult<std::optional<OperandExtraParams>> {
+ if (symmPerChannelQuantParams.channelDim >
+ std::numeric_limits<int32_t>::max()) {
+ // Using explicit type conversion because std::optional in successful
+ // result confuses the compiler.
+ return (NN_ERROR() << "symmPerChannelQuantParams.channelDim must be <= "
+ "std::numeric_limits<int32_t>::max(), received: "
+ << symmPerChannelQuantParams.channelDim)
+ .
+ operator nn::GeneralResult<std::optional<OperandExtraParams>>();
+ }
+ return OperandExtraParams::make<OperandExtraParams::Tag::channelQuant>(
+ SymmPerChannelQuantParams{
+ .scales = symmPerChannelQuantParams.scales,
+ .channelDim = static_cast<int32_t>(
+ symmPerChannelQuantParams.channelDim),
+ });
+ },
+ [](const nn::Operand::ExtensionParams& extensionParams)
+ -> nn::GeneralResult<std::optional<OperandExtraParams>> {
+ // Extension parameters are passed through to the AIDL union unchanged.
+ return OperandExtraParams::make<OperandExtraParams::Tag::extension>(
+ extensionParams);
+ },
+ },
+ extraParams);
+}
+
+// Converts one operand. The nested conversions run in field order and the first failure is
+// returned; dimensions must fit in int32_t. scale and zeroPoint are copied unchanged.
+nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) {
+    auto aidlType = NN_TRY(unvalidatedConvert(operand.type));
+    auto aidlDimensions = NN_TRY(toSigned(operand.dimensions));
+    auto aidlLifetime = NN_TRY(unvalidatedConvert(operand.lifetime));
+    auto aidlLocation = NN_TRY(unvalidatedConvert(operand.location));
+    auto aidlExtraParams = NN_TRY(unvalidatedConvert(operand.extraParams));
+    return Operand{
+            .type = aidlType,
+            .dimensions = std::move(aidlDimensions),
+            .scale = operand.scale,
+            .zeroPoint = operand.zeroPoint,
+            .lifetime = aidlLifetime,
+            .location = std::move(aidlLocation),
+            .extraParams = std::move(aidlExtraParams),
+    };
+}
+
+// Direct static_cast; relies on canonical and AIDL operation types sharing numeric values.
+nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) {
+    const auto aidlOperationType = static_cast<OperationType>(operationType);
+    return aidlOperationType;
+}
+
+// Converts one operation; input and output operand indexes must fit in int32_t.
+nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation) {
+    auto aidlType = NN_TRY(unvalidatedConvert(operation.type));
+    auto aidlInputs = NN_TRY(toSigned(operation.inputs));
+    auto aidlOutputs = NN_TRY(toSigned(operation.outputs));
+    return Operation{
+            .type = aidlType,
+            .inputs = std::move(aidlInputs),
+            .outputs = std::move(aidlOutputs),
+    };
+}
+
+// Converts a subgraph: operands and operations element-wise, then the signed index lists.
+nn::GeneralResult<Subgraph> unvalidatedConvert(const nn::Model::Subgraph& subgraph) {
+    auto aidlOperands = NN_TRY(unvalidatedConvert(subgraph.operands));
+    auto aidlOperations = NN_TRY(unvalidatedConvert(subgraph.operations));
+    auto aidlInputIndexes = NN_TRY(toSigned(subgraph.inputIndexes));
+    auto aidlOutputIndexes = NN_TRY(toSigned(subgraph.outputIndexes));
+    return Subgraph{
+            .operands = std::move(aidlOperands),
+            .operations = std::move(aidlOperations),
+            .inputIndexes = std::move(aidlInputIndexes),
+            .outputIndexes = std::move(aidlOutputIndexes),
+    };
+}
+
+// Copies the model's constant operand bytes into a byte vector.
+nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(
+        const nn::Model::OperandValues& operandValues) {
+    const auto* const first = operandValues.data();
+    const auto* const last = first + operandValues.size();
+    return std::vector<uint8_t>(first, last);
+}
+
+// Copies an extension name/prefix mapping field by field.
+nn::GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
+        const nn::Model::ExtensionNameAndPrefix& extensionNameToPrefix) {
+    const auto& source = extensionNameToPrefix;
+    return ExtensionNameAndPrefix{.name = source.name, .prefix = source.prefix};
+}
+
+// Converts a whole model. Sub-conversions run in field order and the first failure is
+// returned; relaxComputationFloat32toFloat16 is copied unchanged.
+nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
+    auto aidlMain = NN_TRY(unvalidatedConvert(model.main));
+    auto aidlReferenced = NN_TRY(unvalidatedConvert(model.referenced));
+    auto aidlOperandValues = NN_TRY(unvalidatedConvert(model.operandValues));
+    auto aidlPools = NN_TRY(unvalidatedConvert(model.pools));
+    auto aidlExtensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix));
+    return Model{
+            .main = std::move(aidlMain),
+            .referenced = std::move(aidlReferenced),
+            .operandValues = std::move(aidlOperandValues),
+            .pools = std::move(aidlPools),
+            .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
+            .extensionNameToPrefix = std::move(aidlExtensionNameToPrefix),
+    };
+}
+
+// Direct static_cast; relies on canonical and AIDL priorities sharing numeric values.
+nn::GeneralResult<Priority> unvalidatedConvert(const nn::Priority& priority) {
+    const auto aidlPriority = static_cast<Priority>(priority);
+    return aidlPriority;
+}
+
+// Converts a request: input arguments, output arguments, then memory pools.
+nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request) {
+    auto aidlInputs = NN_TRY(unvalidatedConvert(request.inputs));
+    auto aidlOutputs = NN_TRY(unvalidatedConvert(request.outputs));
+    auto aidlPools = NN_TRY(unvalidatedConvert(request.pools));
+    return Request{
+            .inputs = std::move(aidlInputs),
+            .outputs = std::move(aidlOutputs),
+            .pools = std::move(aidlPools),
+    };
+}
+
+// Converts one request argument. POINTER-lifetime arguments are rejected with INVALID_ARGUMENT;
+// a NO_VALUE lifetime is encoded through the hasNoValue flag.
+nn::GeneralResult<RequestArgument> unvalidatedConvert(
+        const nn::Request::Argument& requestArgument) {
+    using LifeTime = nn::Request::Argument::LifeTime;
+    if (requestArgument.lifetime == LifeTime::POINTER) {
+        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
+               << "Request cannot be unvalidatedConverted because it contains pointer-based memory";
+    }
+    auto aidlLocation = NN_TRY(unvalidatedConvert(requestArgument.location));
+    auto aidlDimensions = NN_TRY(toSigned(requestArgument.dimensions));
+    return RequestArgument{
+            .hasNoValue = (requestArgument.lifetime == LifeTime::NO_VALUE),
+            .location = std::move(aidlLocation),
+            .dimensions = std::move(aidlDimensions),
+    };
+}
+
+// Maps a canonical memory pool onto the AIDL RequestMemoryPool union:
+//   SharedMemory      -> pool (converted Memory)
+//   MemoryDomainToken -> token (the token's underlying integer value)
+//   SharedBuffer      -> error: an IBuffer cannot be represented as a request memory pool
+nn::GeneralResult<RequestMemoryPool> unvalidatedConvert(const nn::Request::MemoryPool& memoryPool) {
+ return std::visit(
+ overloaded{
+ [](const nn::SharedMemory& memory) -> nn::GeneralResult<RequestMemoryPool> {
+ return RequestMemoryPool::make<RequestMemoryPool::Tag::pool>(
+ NN_TRY(unvalidatedConvert(memory)));
+ },
+ [](const nn::Request::MemoryDomainToken& token)
+ -> nn::GeneralResult<RequestMemoryPool> {
+ return RequestMemoryPool::make<RequestMemoryPool::Tag::token>(
+ underlyingType(token));
+ },
+ [](const nn::SharedBuffer& /*buffer*/) {
+ // Explicit conversion so this lambda's return type matches the other alternatives.
+ return (NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
+ << "Unable to make memory pool from IBuffer")
+ .
+ operator nn::GeneralResult<RequestMemoryPool>();
+ },
+ },
+ memoryPool);
+}
+
+// Converts both optional timing durations to their int64_t AIDL encoding.
+nn::GeneralResult<Timing> unvalidatedConvert(const nn::Timing& timing) {
+    const auto aidlTimeOnDevice = NN_TRY(unvalidatedConvert(timing.timeOnDevice));
+    const auto aidlTimeInDriver = NN_TRY(unvalidatedConvert(timing.timeInDriver));
+    return Timing{.timeOnDevice = aidlTimeOnDevice, .timeInDriver = aidlTimeInDriver};
+}
+
+// Converts a duration to nanoseconds as int64_t, saturating at int64_t's maximum instead of
+// wrapping when the count does not fit.
+nn::GeneralResult<int64_t> unvalidatedConvert(const nn::Duration& duration) {
+    constexpr auto kMaxNanoseconds = std::numeric_limits<int64_t>::max();
+    const uint64_t nanoseconds = duration.count();
+    return nanoseconds > kMaxNanoseconds ? kMaxNanoseconds : static_cast<int64_t>(nanoseconds);
+}
+
+// An absent duration is encoded as kNoTiming; otherwise the duration itself is converted.
+nn::GeneralResult<int64_t> unvalidatedConvert(const nn::OptionalDuration& optionalDuration) {
+    if (optionalDuration.has_value()) {
+        return unvalidatedConvert(*optionalDuration);
+    }
+    return kNoTiming;
+}
+
+// An absent time point is encoded as kNoTiming; otherwise its time since epoch is converted.
+nn::GeneralResult<int64_t> unvalidatedConvert(const nn::OptionalTimePoint& optionalTimePoint) {
+    if (optionalTimePoint.has_value()) {
+        return unvalidatedConvert(optionalTimePoint->time_since_epoch());
+    }
+    return kNoTiming;
+}
+
+// Duplicates the fence's file descriptor; the returned ScopedFileDescriptor owns the duplicate.
+nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvert(const nn::SyncFence& syncFence) {
+    const auto fenceFd = syncFence.getFd();
+    auto ownedDuplicate = NN_TRY(nn::dupFd(fenceFd));
+    return ndk::ScopedFileDescriptor(ownedDuplicate.release());
+}
+
+// Converts a single cache handle into an owned duplicate of its only file descriptor.
+// The handle must be non-null (enforced by CHECK). It must carry no ints and exactly one fd;
+// otherwise an error is returned.
+nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvertCache(
+        const nn::SharedHandle& handle) {
+    CHECK(handle != nullptr);
+    // The errors must be returned. Merely constructing NN_ERROR() discards them and falls
+    // through, which would reach fds.front() on an empty handle.
+    if (!handle->ints.empty()) {
+        return NN_ERROR() << "Cache handle must not contain ints";
+    }
+    if (handle->fds.size() != 1) {
+        return NN_ERROR() << "Cache handle must contain exactly one fd but contains "
+                          << handle->fds.size();
+    }
+    auto duplicatedFd = NN_TRY(nn::dupFd(handle->fds.front().get()));
+    return ndk::ScopedFileDescriptor(duplicatedFd.release());
+}
+
+// A cache token has no validation step; its bytes are converted directly.
+nn::GeneralResult<std::vector<uint8_t>> convert(const nn::CacheToken& cacheToken) {
+ return unvalidatedConvert(cacheToken);
+}
+
+// Runs nn::validate on the descriptor via validatedConvert before converting it.
+nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc) {
+ return validatedConvert(bufferDesc);
+}
+
+// Runs nn::validate via validatedConvert, then maps MeasureTiming to its AIDL bool.
+nn::GeneralResult<bool> convert(const nn::MeasureTiming& measureTiming) {
+ return validatedConvert(measureTiming);
+}
+
+// Runs nn::validate via validatedConvert, then converts the shared memory.
nn::GeneralResult<Memory> convert(const nn::SharedMemory& memory) {
return validatedConvert(memory);
}
@@ -715,11 +1017,62 @@
return validatedConvert(errorStatus);
}
+// Public validated conversions: each runs nn::validate on its argument via validatedConvert,
+// then delegates to the matching unvalidatedConvert overload.
+nn::GeneralResult<ExecutionPreference> convert(const nn::ExecutionPreference& executionPreference) {
+ return validatedConvert(executionPreference);
+}
+
+nn::GeneralResult<Model> convert(const nn::Model& model) {
+ return validatedConvert(model);
+}
+
+nn::GeneralResult<Priority> convert(const nn::Priority& priority) {
+ return validatedConvert(priority);
+}
+
+// Pointer-based request arguments are rejected by the underlying unvalidatedConvert.
+nn::GeneralResult<Request> convert(const nn::Request& request) {
+ return validatedConvert(request);
+}
+
+nn::GeneralResult<Timing> convert(const nn::Timing& timing) {
+ return validatedConvert(timing);
+}
+
+nn::GeneralResult<int64_t> convert(const nn::OptionalDuration& optionalDuration) {
+ return validatedConvert(optionalDuration);
+}
+
+// Validates an optional time point (e.g. an execution deadline) and converts it to its int64_t
+// AIDL encoding. The parameter was previously misnamed `outputShapes`, a copy-paste leftover from
+// the OutputShape overload.
+nn::GeneralResult<int64_t> convert(const nn::OptionalTimePoint& optionalTimePoint) {
+    return validatedConvert(optionalTimePoint);
+}
+
+// Validates the whole list of buffer roles via validatedConvert, then converts each one.
+nn::GeneralResult<std::vector<BufferRole>> convert(const std::vector<nn::BufferRole>& bufferRoles) {
+ return validatedConvert(bufferRoles);
+}
+
+// Validates the whole list of output shapes via validatedConvert, then converts each one.
nn::GeneralResult<std::vector<OutputShape>> convert(
const std::vector<nn::OutputShape>& outputShapes) {
return validatedConvert(outputShapes);
}
+// Validates the cache handles and checks that the version they require is supported at kVersion.
+// Each handle's single fd is then duplicated via unvalidatedConvertCache.
+nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
+        const std::vector<nn::SharedHandle>& cacheHandles) {
+    const auto requiredVersion =
+            NN_TRY(hal::utils::makeGeneralFailure(nn::validate(cacheHandles)));
+    if (requiredVersion > kVersion) {
+        return NN_ERROR() << "Insufficient version: " << requiredVersion << " vs required "
+                          << kVersion;
+    }
+    std::vector<ndk::ScopedFileDescriptor> fileDescriptors;
+    fileDescriptors.reserve(cacheHandles.size());
+    for (const nn::SharedHandle& handle : cacheHandles) {
+        auto fileDescriptor = NN_TRY(unvalidatedConvertCache(handle));
+        fileDescriptors.push_back(std::move(fileDescriptor));
+    }
+    return fileDescriptors;
+}
+
+// Duplicates each sync fence's fd, element by element. No separate validation step is applied.
+nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
+        const std::vector<nn::SyncFence>& syncFences) {
+    return unvalidatedConvertVec(syncFences);
+}
+
nn::GeneralResult<std::vector<int32_t>> toSigned(const std::vector<uint32_t>& vec) {
if (!std::all_of(vec.begin(), vec.end(),
[](uint32_t v) { return v <= std::numeric_limits<int32_t>::max(); })) {
diff --git a/neuralnetworks/aidl/utils/src/Device.cpp b/neuralnetworks/aidl/utils/src/Device.cpp
new file mode 100644
index 0000000..02ca861
--- /dev/null
+++ b/neuralnetworks/aidl/utils/src/Device.cpp
@@ -0,0 +1,294 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Device.h"
+
+#include "Buffer.h"
+#include "Callbacks.h"
+#include "Conversions.h"
+#include "PreparedModel.h"
+#include "ProtectCallback.h"
+#include "Utils.h"
+
+#include <aidl/android/hardware/neuralnetworks/IDevice.h>
+#include <android/binder_auto_utils.h>
+#include <android/binder_interface_utils.h>
+#include <nnapi/IBuffer.h>
+#include <nnapi/IDevice.h>
+#include <nnapi/IPreparedModel.h>
+#include <nnapi/OperandTypes.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/CommonUtils.h>
+
+#include <any>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <vector>
+
+// See hardware/interfaces/neuralnetworks/utils/README.md for more information on AIDL interface
+// lifetimes across processes and for protecting asynchronous calls across AIDL.
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+namespace {
+
+// Extracts the AIDL IPreparedModel that backs each canonical prepared model. Fails with
+// INVALID_ARGUMENT if any model's underlying resource is not a
+// std::shared_ptr<aidl_hal::IPreparedModel>.
+nn::GeneralResult<std::vector<std::shared_ptr<IPreparedModel>>> convert(
+        const std::vector<nn::SharedPreparedModel>& preparedModels) {
+    std::vector<std::shared_ptr<IPreparedModel>> aidlPreparedModels;
+    aidlPreparedModels.reserve(preparedModels.size());
+    for (const auto& preparedModel : preparedModels) {
+        const std::any resource = preparedModel->getUnderlyingResource();
+        const auto* const aidlPreparedModel =
+                std::any_cast<std::shared_ptr<aidl_hal::IPreparedModel>>(&resource);
+        if (aidlPreparedModel == nullptr) {
+            return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
+                   << "Unable to convert from nn::IPreparedModel to aidl_hal::IPreparedModel";
+        }
+        aidlPreparedModels.push_back(*aidlPreparedModel);
+    }
+    return aidlPreparedModels;
+}
+
+// Each helper below issues one IDevice query, maps a transport/service failure to an error,
+// and converts the AIDL result to its canonical type. `device` must be non-null.
+nn::GeneralResult<nn::Capabilities> getCapabilitiesFrom(IDevice* device) {
+    CHECK(device != nullptr);
+    Capabilities aidlCapabilities;
+    const auto ret = device->getCapabilities(&aidlCapabilities);
+    HANDLE_ASTATUS(ret) << "getCapabilities failed";
+    return nn::convert(aidlCapabilities);
+}
+
+// The version string is returned as-is; no conversion is needed.
+nn::GeneralResult<std::string> getVersionStringFrom(aidl_hal::IDevice* device) {
+    CHECK(device != nullptr);
+    std::string versionString;
+    const auto ret = device->getVersionString(&versionString);
+    HANDLE_ASTATUS(ret) << "getVersionString failed";
+    return versionString;
+}
+
+nn::GeneralResult<nn::DeviceType> getDeviceTypeFrom(aidl_hal::IDevice* device) {
+    CHECK(device != nullptr);
+    DeviceType aidlDeviceType;
+    const auto ret = device->getType(&aidlDeviceType);
+    HANDLE_ASTATUS(ret) << "getDeviceType failed";
+    return nn::convert(aidlDeviceType);
+}
+
+nn::GeneralResult<std::vector<nn::Extension>> getSupportedExtensionsFrom(
+        aidl_hal::IDevice* device) {
+    CHECK(device != nullptr);
+    std::vector<Extension> aidlExtensions;
+    const auto ret = device->getSupportedExtensions(&aidlExtensions);
+    HANDLE_ASTATUS(ret) << "getExtensions failed";
+    return nn::convert(aidlExtensions);
+}
+
+// Queries the driver's caching requirements and checks that neither count is negative or above
+// nn::kMaxNumberOfCacheFiles. Returns the counts in canonical order (numModelCache,
+// numDataCache), which matches the (modelCache, dataCache) argument order of
+// prepareModel/prepareModelFromCache. `device` must be non-null.
+nn::GeneralResult<std::pair<uint32_t, uint32_t>> getNumberOfCacheFilesNeededFrom(
+        aidl_hal::IDevice* device) {
+    CHECK(device != nullptr);
+    NumberOfCacheFiles numberOfCacheFiles;
+    const auto ret = device->getNumberOfCacheFilesNeeded(&numberOfCacheFiles);
+    HANDLE_ASTATUS(ret) << "getNumberOfCacheFilesNeeded failed";
+
+    if (numberOfCacheFiles.numDataCache < 0 || numberOfCacheFiles.numModelCache < 0) {
+        return NN_ERROR() << "Driver reported negative numer of cache files needed";
+    }
+    if (static_cast<uint32_t>(numberOfCacheFiles.numModelCache) > nn::kMaxNumberOfCacheFiles) {
+        return NN_ERROR() << "getNumberOfCacheFilesNeeded returned numModelCache files greater "
+                             "than allowed max ("
+                          << numberOfCacheFiles.numModelCache << " vs "
+                          << nn::kMaxNumberOfCacheFiles << ")";
+    }
+    if (static_cast<uint32_t>(numberOfCacheFiles.numDataCache) > nn::kMaxNumberOfCacheFiles) {
+        return NN_ERROR() << "getNumberOfCacheFilesNeeded returned numDataCache files greater "
+                             "than allowed max ("
+                          << numberOfCacheFiles.numDataCache << " vs " << nn::kMaxNumberOfCacheFiles
+                          << ")";
+    }
+    // Model-cache count first: returning the data count first swaps the two values for every
+    // caller of Device::getNumberOfCacheFilesNeeded.
+    const auto numModelCache = static_cast<uint32_t>(numberOfCacheFiles.numModelCache);
+    const auto numDataCache = static_cast<uint32_t>(numberOfCacheFiles.numDataCache);
+    return std::make_pair(numModelCache, numDataCache);
+}
+
+} // namespace
+
+// Creates a canonical Device wrapper around an AIDL IDevice. All static device information is
+// queried once, up front, in a fixed order; any failure aborts creation. A DeathHandler is
+// created for the device so that pending callbacks can be protected against service death.
+nn::GeneralResult<std::shared_ptr<const Device>> Device::create(
+        std::string name, std::shared_ptr<aidl_hal::IDevice> device) {
+    if (name.empty()) {
+        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
+               << "aidl_hal::utils::Device::create must have non-empty name";
+    }
+    if (device == nullptr) {
+        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
+               << "aidl_hal::utils::Device::create must have non-null device";
+    }
+
+    aidl_hal::IDevice* const rawDevice = device.get();
+    auto versionString = NN_TRY(getVersionStringFrom(rawDevice));
+    const auto deviceType = NN_TRY(getDeviceTypeFrom(rawDevice));
+    auto extensions = NN_TRY(getSupportedExtensionsFrom(rawDevice));
+    auto capabilities = NN_TRY(getCapabilitiesFrom(rawDevice));
+    const auto numberOfCacheFilesNeeded = NN_TRY(getNumberOfCacheFilesNeededFrom(rawDevice));
+
+    auto deathHandler = NN_TRY(DeathHandler::create(device));
+    return std::make_shared<const Device>(
+            PrivateConstructorTag{}, std::move(name), std::move(versionString), deviceType,
+            std::move(extensions), std::move(capabilities), numberOfCacheFilesNeeded,
+            std::move(device), std::move(deathHandler));
+}
+
+// Stores the device information gathered by Device::create. PrivateConstructorTag presumably
+// restricts construction to create() (TODO confirm in Device.h). All members are moved or copied
+// in as-is.
+Device::Device(PrivateConstructorTag /*tag*/, std::string name, std::string versionString,
+ nn::DeviceType deviceType, std::vector<nn::Extension> extensions,
+ nn::Capabilities capabilities,
+ std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded,
+ std::shared_ptr<aidl_hal::IDevice> device, DeathHandler deathHandler)
+ : kName(std::move(name)),
+ kVersionString(std::move(versionString)),
+ kDeviceType(deviceType),
+ kExtensions(std::move(extensions)),
+ kCapabilities(std::move(capabilities)),
+ kNumberOfCacheFilesNeeded(numberOfCacheFilesNeeded),
+ kDevice(std::move(device)),
+ kDeathHandler(std::move(deathHandler)) {}
+
+// Returns the name passed to Device::create.
+const std::string& Device::getName() const {
+ return kName;
+}
+
+// Returns the version string queried from the driver in Device::create.
+const std::string& Device::getVersionString() const {
+ return kVersionString;
+}
+
+// AIDL-backed devices always report the ANDROID_S feature level here.
+nn::Version Device::getFeatureLevel() const {
+ return nn::Version::ANDROID_S;
+}
+
+// Returns the device type queried from the driver in Device::create.
+nn::DeviceType Device::getType() const {
+ return kDeviceType;
+}
+
+// Always false for this wrapper.
+bool Device::isUpdatable() const {
+ return false;
+}
+
+// Returns the extensions queried from the driver in Device::create.
+const std::vector<nn::Extension>& Device::getSupportedExtensions() const {
+ return kExtensions;
+}
+
+// Returns the capabilities queried from the driver in Device::create.
+const nn::Capabilities& Device::getCapabilities() const {
+ return kCapabilities;
+}
+
+// Returns the cache-file counts queried once in Device::create.
+std::pair<uint32_t, uint32_t> Device::getNumberOfCacheFilesNeeded() const {
+ return kNumberOfCacheFilesNeeded;
+}
+
+// Pings the device's binder; succeeds once the remote service answers the ping.
+nn::GeneralResult<void> Device::wait() const {
+    const auto binder = kDevice->asBinder();
+    const auto ret = ndk::ScopedAStatus::fromStatus(AIBinder_ping(binder.get()));
+    HANDLE_ASTATUS(ret) << "ping failed";
+    return {};
+}
+
+// Asks the driver which operations of `model` it supports. Returns one flag per operation, as
+// reported by the driver.
+nn::GeneralResult<std::vector<bool>> Device::getSupportedOperations(const nn::Model& model) const {
+    // Pointer-backed operand data must be moved to shared memory before the model can cross IPC;
+    // the copy (if one is needed) is held in `modelCopy`.
+    std::optional<nn::Model> modelCopy;
+    const nn::Model& sharedModel =
+            NN_TRY(hal::utils::flushDataFromPointerToShared(&model, &modelCopy));
+    const auto aidlModel = NN_TRY(convert(sharedModel));
+
+    std::vector<bool> supportedFlags;
+    const auto ret = kDevice->getSupportedOperations(aidlModel, &supportedFlags);
+    HANDLE_ASTATUS(ret) << "getSupportedOperations failed";
+    return supportedFlags;
+}
+
+// Converts all arguments to AIDL and calls IDevice::prepareModel. It then blocks on the callback
+// for the prepared model. The callback is registered with the death handler for the duration of
+// the call, so a dying service does not leave it waiting.
+nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModel(
+        const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
+        nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
+        const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
+    // Make pointer-backed model data IPC-ready (copied into `modelCopy` if needed).
+    std::optional<nn::Model> modelCopy;
+    const nn::Model& sharedModel =
+            NN_TRY(hal::utils::flushDataFromPointerToShared(&model, &modelCopy));
+
+    const auto aidlModel = NN_TRY(convert(sharedModel));
+    const auto aidlPreference = NN_TRY(convert(preference));
+    const auto aidlPriority = NN_TRY(convert(priority));
+    const auto aidlDeadline = NN_TRY(convert(deadline));
+    const auto aidlModelCache = NN_TRY(convert(modelCache));
+    const auto aidlDataCache = NN_TRY(convert(dataCache));
+    const auto aidlToken = NN_TRY(convert(token));
+
+    const auto callback = ndk::SharedRefBase::make<PreparedModelCallback>();
+    const auto deathGuard = kDeathHandler.protectCallback(callback.get());
+
+    const auto ret = kDevice->prepareModel(aidlModel, aidlPreference, aidlPriority, aidlDeadline,
+                                           aidlModelCache, aidlDataCache, aidlToken, callback);
+    HANDLE_ASTATUS(ret) << "prepareModel failed";
+
+    return callback->get();
+}
+
+// Like prepareModel, but asks the driver to restore the prepared model from the given cache
+// files and token instead of compiling a model.
+nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModelFromCache(
+        nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
+        const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
+    const auto aidlDeadline = NN_TRY(convert(deadline));
+    const auto aidlModelCache = NN_TRY(convert(modelCache));
+    const auto aidlDataCache = NN_TRY(convert(dataCache));
+    const auto aidlToken = NN_TRY(convert(token));
+
+    const auto callback = ndk::SharedRefBase::make<PreparedModelCallback>();
+    const auto deathGuard = kDeathHandler.protectCallback(callback.get());
+
+    const auto ret = kDevice->prepareModelFromCache(aidlDeadline, aidlModelCache, aidlDataCache,
+                                                    aidlToken, callback);
+    HANDLE_ASTATUS(ret) << "prepareModelFromCache failed";
+
+    return callback->get();
+}
+
+// Allocates a driver-managed buffer via IDevice::allocate and wraps it in a canonical Buffer.
+// A negative token from the driver is reported as an error.
+nn::GeneralResult<nn::SharedBuffer> Device::allocate(
+        const nn::BufferDesc& desc, const std::vector<nn::SharedPreparedModel>& preparedModels,
+        const std::vector<nn::BufferRole>& inputRoles,
+        const std::vector<nn::BufferRole>& outputRoles) const {
+    const auto aidlDesc = NN_TRY(convert(desc));
+    const auto aidlPreparedModels = NN_TRY(convert(preparedModels));
+    const auto aidlInputRoles = NN_TRY(convert(inputRoles));
+    const auto aidlOutputRoles = NN_TRY(convert(outputRoles));
+
+    // IDevice::allocate takes each prepared model wrapped in an IPreparedModelParcel.
+    std::vector<IPreparedModelParcel> parcels;
+    parcels.reserve(aidlPreparedModels.size());
+    for (const auto& aidlPreparedModel : aidlPreparedModels) {
+        parcels.push_back({.preparedModel = aidlPreparedModel});
+    }
+
+    DeviceBuffer allocation;
+    const auto ret =
+            kDevice->allocate(aidlDesc, parcels, aidlInputRoles, aidlOutputRoles, &allocation);
+    HANDLE_ASTATUS(ret) << "IDevice::allocate failed";
+
+    if (allocation.token < 0) {
+        return NN_ERROR() << "IDevice::allocate returned negative token";
+    }
+    const auto memoryDomainToken = static_cast<nn::Request::MemoryDomainToken>(allocation.token);
+    return Buffer::create(allocation.buffer, memoryDomainToken);
+}
+
+// Returns a non-owning pointer to the death monitor held by this device's DeathHandler.
+DeathMonitor* Device::getDeathMonitor() const {
+    const auto& deathMonitor = kDeathHandler.getDeathMonitor();
+    return deathMonitor.get();
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/src/PreparedModel.cpp b/neuralnetworks/aidl/utils/src/PreparedModel.cpp
new file mode 100644
index 0000000..aee4d90
--- /dev/null
+++ b/neuralnetworks/aidl/utils/src/PreparedModel.cpp
@@ -0,0 +1,172 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PreparedModel.h"
+
+#include "Callbacks.h"
+#include "Conversions.h"
+#include "ProtectCallback.h"
+#include "Utils.h"
+
+#include <android/binder_auto_utils.h>
+#include <nnapi/IPreparedModel.h>
+#include <nnapi/Result.h>
+#include <nnapi/TypeUtils.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/1.0/Burst.h>
+#include <nnapi/hal/CommonUtils.h>
+#include <nnapi/hal/HandleError.h>
+
+#include <memory>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+// See hardware/interfaces/neuralnetworks/utils/README.md for more information on AIDL interface
+// lifetimes across processes and for protecting asynchronous calls across AIDL.
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+namespace {
+
+// Converts the AIDL output shapes and timing of a synchronous execution to canonical types.
+// The conversions run into named locals so their order is fixed. Inside a single std::make_pair
+// call the evaluation order of the NN_TRY arguments is unspecified, so the reported error would
+// be compiler-dependent if both conversions failed.
+nn::GeneralResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> convertExecutionResults(
+        const std::vector<OutputShape>& outputShapes, const Timing& timing) {
+    auto canonicalOutputShapes = NN_TRY(nn::convert(outputShapes));
+    const auto canonicalTiming = NN_TRY(nn::convert(timing));
+    return std::make_pair(std::move(canonicalOutputShapes), canonicalTiming);
+}
+
+// Fails with the driver-reported status if it is not NONE. Otherwise converts the launched and
+// fenced timings, in that order, to canonical types.
+nn::GeneralResult<std::pair<nn::Timing, nn::Timing>> convertFencedExecutionResults(
+        ErrorStatus status, const aidl_hal::Timing& timingLaunched,
+        const aidl_hal::Timing& timingFenced) {
+    HANDLE_HAL_STATUS(status) << "fenced execution callback info failed with " << toString(status);
+    const auto canonicalTimingLaunched = NN_TRY(nn::convert(timingLaunched));
+    const auto canonicalTimingFenced = NN_TRY(nn::convert(timingFenced));
+    return std::make_pair(canonicalTimingLaunched, canonicalTimingFenced);
+}
+
+} // namespace
+
+// Wraps an AIDL IPreparedModel in a canonical PreparedModel; a null model is rejected.
+nn::GeneralResult<std::shared_ptr<const PreparedModel>> PreparedModel::create(
+        std::shared_ptr<aidl_hal::IPreparedModel> preparedModel) {
+    if (!preparedModel) {
+        return NN_ERROR()
+               << "aidl_hal::utils::PreparedModel::create must have non-null preparedModel";
+    }
+    auto wrapper =
+            std::make_shared<const PreparedModel>(PrivateConstructorTag{}, std::move(preparedModel));
+    return wrapper;
+}
+
+// Takes ownership of the (already null-checked) AIDL prepared model.
+PreparedModel::PreparedModel(PrivateConstructorTag /*tag*/,
+                             std::shared_ptr<aidl_hal::IPreparedModel> preparedModel)
+    : kPreparedModel{std::move(preparedModel)} {}
+
+// Runs a synchronous execution through IPreparedModel::executeSynchronously and converts the
+// output shapes and timing back to canonical types. Argument-conversion failures are reported
+// as execution failures via makeExecutionFailure.
+nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> PreparedModel::execute(
+ const nn::Request& request, nn::MeasureTiming measure,
+ const nn::OptionalTimePoint& deadline,
+ const nn::OptionalDuration& loopTimeoutDuration) const {
+ // Ensure that request is ready for IPC.
+ std::optional<nn::Request> maybeRequestInShared;
+ const nn::Request& requestInShared = NN_TRY(hal::utils::makeExecutionFailure(
+ hal::utils::flushDataFromPointerToShared(&request, &maybeRequestInShared)));
+
+ const auto aidlRequest = NN_TRY(hal::utils::makeExecutionFailure(convert(requestInShared)));
+ const auto aidlMeasure = NN_TRY(hal::utils::makeExecutionFailure(convert(measure)));
+ const auto aidlDeadline = NN_TRY(hal::utils::makeExecutionFailure(convert(deadline)));
+ const auto aidlLoopTimeoutDuration =
+ NN_TRY(hal::utils::makeExecutionFailure(convert(loopTimeoutDuration)));
+
+ ExecutionResult executionResult;
+ const auto ret = kPreparedModel->executeSynchronously(
+ aidlRequest, aidlMeasure, aidlDeadline, aidlLoopTimeoutDuration, &executionResult);
+ HANDLE_ASTATUS(ret) << "executeSynchronously failed";
+ // Insufficient output buffers: report OUTPUT_INSUFFICIENT_SIZE together with the driver's output
+ // shapes (an empty list if they cannot be converted). NOTE(review): this early return skips the
+ // unflushDataFromSharedToPointer call below.
+ if (!executionResult.outputSufficientSize) {
+ auto canonicalOutputShapes =
+ nn::convert(executionResult.outputShapes).value_or(std::vector<nn::OutputShape>{});
+ return NN_ERROR(nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, std::move(canonicalOutputShapes))
+ << "execution failed with " << nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
+ }
+ auto [outputShapes, timing] = NN_TRY(hal::utils::makeExecutionFailure(
+ convertExecutionResults(executionResult.outputShapes, executionResult.timing)));
+
+ // Copy results from the shared-memory copy made above (if any) back to the caller's
+ // pointer-based buffers.
+ NN_TRY(hal::utils::makeExecutionFailure(
+ hal::utils::unflushDataFromSharedToPointer(request, maybeRequestInShared)));
+
+ return std::make_pair(std::move(outputShapes), timing);
+}
+
+// Launches a fenced execution via IPreparedModel::executeFenced. Returns the resulting sync fence
+// together with a callback that later queries the execution's error status and timings.
+nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
+PreparedModel::executeFenced(const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
+ nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
+ const nn::OptionalDuration& loopTimeoutDuration,
+ const nn::OptionalDuration& timeoutDurationAfterFence) const {
+ // Ensure that request is ready for IPC.
+ std::optional<nn::Request> maybeRequestInShared;
+ const nn::Request& requestInShared =
+ NN_TRY(hal::utils::flushDataFromPointerToShared(&request, &maybeRequestInShared));
+
+ const auto aidlRequest = NN_TRY(convert(requestInShared));
+ const auto aidlWaitFor = NN_TRY(convert(waitFor));
+ const auto aidlMeasure = NN_TRY(convert(measure));
+ const auto aidlDeadline = NN_TRY(convert(deadline));
+ const auto aidlLoopTimeoutDuration = NN_TRY(convert(loopTimeoutDuration));
+ const auto aidlTimeoutDurationAfterFence = NN_TRY(convert(timeoutDurationAfterFence));
+
+ FencedExecutionResult result;
+ const auto ret = kPreparedModel->executeFenced(aidlRequest, aidlWaitFor, aidlMeasure,
+ aidlDeadline, aidlLoopTimeoutDuration,
+ aidlTimeoutDurationAfterFence, &result);
+ HANDLE_ASTATUS(ret) << "executeFenced failed";
+
+ // A returned fence fd of -1 is represented as an already-signaled fence.
+ auto resultSyncFence = nn::SyncFence::createAsSignaled();
+ if (result.syncFence.get() != -1) {
+ resultSyncFence = NN_TRY(nn::convert(result.syncFence));
+ }
+
+ // A null execution-info callback from the driver is treated as a failure.
+ auto callback = result.callback;
+ if (callback == nullptr) {
+ return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE) << "callback is null";
+ }
+
+ // If executeFenced required the request memory to be moved into shared memory, block here until
+ // the fenced execution has completed and flush the memory back.
+ if (maybeRequestInShared.has_value()) {
+ const auto state = resultSyncFence.syncWait({});
+ if (state != nn::SyncFence::FenceState::SIGNALED) {
+ return NN_ERROR() << "syncWait failed with " << state;
+ }
+ NN_TRY(hal::utils::unflushDataFromSharedToPointer(request, maybeRequestInShared));
+ }
+
+ // Create callback which can be used to retrieve the execution error status and timings.
+ // The lambda captures its own copy of the AIDL callback handle by value.
+ nn::ExecuteFencedInfoCallback resultCallback =
+ [callback]() -> nn::GeneralResult<std::pair<nn::Timing, nn::Timing>> {
+ ErrorStatus errorStatus;
+ Timing timingLaunched;
+ Timing timingFenced;
+ const auto ret = callback->getExecutionInfo(&timingLaunched, &timingFenced, &errorStatus);
+ HANDLE_ASTATUS(ret) << "fenced execution callback getExecutionInfo failed";
+ return convertFencedExecutionResults(errorStatus, timingLaunched, timingFenced);
+ };
+
+ return std::make_pair(std::move(resultSyncFence), std::move(resultCallback));
+}
+
+// Creates a burst object around this prepared model using the HIDL 1.0 utils Burst class.
+nn::GeneralResult<nn::SharedBurst> PreparedModel::configureExecutionBurst() const {
+    auto self = shared_from_this();
+    return hal::V1_0::utils::Burst::create(std::move(self));
+}
+
+// Exposes the wrapped AIDL prepared model, type-erased as
+// std::shared_ptr<aidl_hal::IPreparedModel>.
+std::any PreparedModel::getUnderlyingResource() const {
+    return std::any(std::shared_ptr<aidl_hal::IPreparedModel>(kPreparedModel));
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/src/ProtectCallback.cpp b/neuralnetworks/aidl/utils/src/ProtectCallback.cpp
new file mode 100644
index 0000000..124641c
--- /dev/null
+++ b/neuralnetworks/aidl/utils/src/ProtectCallback.cpp
@@ -0,0 +1,112 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ProtectCallback.h"
+
+#include <android-base/logging.h>
+#include <android-base/scopeguard.h>
+#include <android-base/thread_annotations.h>
+#include <android/binder_auto_utils.h>
+#include <android/binder_interface_utils.h>
+#include <nnapi/Result.h>
+#include <nnapi/hal/ProtectCallback.h>
+
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <vector>
+
+#include "Utils.h"
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// Notifies every registered callback that the remote service died. The mutex is held for the
+// whole traversal so add()/remove() cannot mutate the list concurrently.
+// NOTE(review): callbacks run with mMutex held; a callback that re-enters add()/remove() on this
+// monitor would deadlock — confirm notifyAsDeadObject() never does so.
+void DeathMonitor::serviceDied() {
+    std::lock_guard guard(mMutex);
+    for (auto* const killable : mObjects) {
+        killable->notifyAsDeadObject();
+    }
+}
+
+// Binder death-recipient trampoline. The cookie is the DeathMonitor* that DeathHandler::create
+// passes to AIBinder_linkToDeath.
+void DeathMonitor::serviceDied(void* cookie) {
+    static_cast<DeathMonitor*>(cookie)->serviceDied();
+}
+
+// Registers `killable` to be notified by serviceDied(). Aborts on a null pointer.
+void DeathMonitor::add(hal::utils::IProtectedCallback* killable) const {
+    CHECK(killable != nullptr);
+    const std::lock_guard lock(mMutex);
+    mObjects.emplace_back(killable);
+}
+
+// Deregisters every registration of `killable`. Aborts on a null pointer; removing a callback
+// that is not registered is a no-op.
+void DeathMonitor::remove(hal::utils::IProtectedCallback* killable) const {
+    CHECK(killable != nullptr);
+    std::lock_guard guard(mMutex);
+    // Erase-remove idiom: std::remove compacts the surviving elements to the front and returns the
+    // new logical end, and everything from there to end() must be erased. Erasing only the single
+    // element at that iterator is undefined behavior when `killable` is absent (the iterator is
+    // end()), and leaves unspecified stale pointers behind when it was registered more than once.
+    mObjects.erase(std::remove(mObjects.begin(), mObjects.end(), killable), mObjects.end());
+}
+
+// Builds a DeathHandler for `object`: allocates the DeathMonitor (used as the death-recipient
+// cookie) and a death recipient, and links them to the binder when it is remote.
+// Returns INVALID_ARGUMENT for a null object, or the mapped transport error if linking fails.
+nn::GeneralResult<DeathHandler> DeathHandler::create(std::shared_ptr<ndk::ICInterface> object) {
+    if (object == nullptr) {
+        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
+               << "utils::DeathHandler::create must have non-null object";
+    }
+
+    auto monitor = std::make_shared<DeathMonitor>();
+    ndk::ScopedAIBinder_DeathRecipient recipient(
+            AIBinder_DeathRecipient_new(DeathMonitor::serviceDied));
+
+    // AIBinder_linkToDeath does nothing for a local binder and returns STATUS_INVALID_OPERATION.
+    // Local binders are only used in tests, where that is not an error, so only remote binders
+    // are linked.
+    if (object->isRemote()) {
+        const auto ret = ndk::ScopedAStatus::fromStatus(AIBinder_linkToDeath(
+                object->asBinder().get(), recipient.get(), monitor.get()));
+        HANDLE_ASTATUS(ret) << "AIBinder_linkToDeath failed";
+    }
+
+    return DeathHandler(std::move(object), std::move(recipient), std::move(monitor));
+}
+
+// Takes ownership of the linked object, its death recipient, and the DeathMonitor used as the
+// recipient cookie. DeathHandler::create validates `object` before calling this; the CHECKs
+// below abort the process if any of the three is null.
+DeathHandler::DeathHandler(std::shared_ptr<ndk::ICInterface> object,
+ ndk::ScopedAIBinder_DeathRecipient deathRecipient,
+ std::shared_ptr<DeathMonitor> deathMonitor)
+ : kObject(std::move(object)),
+ kDeathRecipient(std::move(deathRecipient)),
+ kDeathMonitor(std::move(deathMonitor)) {
+ CHECK(kObject != nullptr);
+ CHECK(kDeathRecipient.get() != nullptr);
+ CHECK(kDeathMonitor != nullptr);
+}
+
+// Unlinks the death recipient that DeathHandler::create registered. Failures are logged,
+// not propagated, because a destructor cannot report them.
+DeathHandler::~DeathHandler() {
+    // The constructor CHECKs all three members are non-null, so null members presumably mean
+    // this handler was moved from and owns no death link.
+    if (kObject == nullptr || kDeathRecipient.get() == nullptr || kDeathMonitor == nullptr) {
+        return;
+    }
+    // create() only calls AIBinder_linkToDeath for remote binders, so unlink symmetrically.
+    // Unlinking a recipient that was never linked returns an error status, which would otherwise
+    // be logged as an ERROR every time a handler for a local (test) binder is destroyed.
+    if (!kObject->isRemote()) {
+        return;
+    }
+    const auto ret = ndk::ScopedAStatus::fromStatus(AIBinder_unlinkToDeath(
+            kObject->asBinder().get(), kDeathRecipient.get(), kDeathMonitor.get()));
+    const auto maybeSuccess = handleTransportError(ret);
+    if (!maybeSuccess.ok()) {
+        LOG(ERROR) << maybeSuccess.error().message;
+    }
+}
+
+// Registers `killable` with the death monitor for the lifetime of the returned guard. When the
+// guard runs, the callback is deregistered. Aborts on a null pointer.
+[[nodiscard]] ::android::base::ScopeGuard<DeathHandler::Cleanup> DeathHandler::protectCallback(
+        hal::utils::IProtectedCallback* killable) const {
+    CHECK(killable != nullptr);
+    kDeathMonitor->add(killable);
+
+    // The guard captures its own shared_ptr to the monitor, so deregistration does not depend on
+    // this DeathHandler still being alive when the guard runs.
+    auto deregister = [monitor = kDeathMonitor, killable] { monitor->remove(killable); };
+    return ::android::base::make_scope_guard(std::move(deregister));
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/src/Service.cpp b/neuralnetworks/aidl/utils/src/Service.cpp
new file mode 100644
index 0000000..5ec6ded
--- /dev/null
+++ b/neuralnetworks/aidl/utils/src/Service.cpp
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Service.h"
+
+#include <android/binder_auto_utils.h>
+#include <android/binder_manager.h>
+
+#include <nnapi/IDevice.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/ResilientDevice.h>
+#include <string>
+
+#include "Device.h"
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// Returns a ResilientDevice for the AIDL IDevice service registered as `name`. The factory it is
+// given looks the service up with AServiceManager_getService when `blocking` is true, and with
+// AServiceManager_checkService otherwise. A missing service is reported as an error.
+nn::GeneralResult<nn::SharedDevice> getDevice(const std::string& name) {
+    hal::utils::ResilientDevice::Factory makeDevice =
+            [name](bool blocking) -> nn::GeneralResult<nn::SharedDevice> {
+        AIBinder* const binder = blocking ? AServiceManager_getService(name.c_str())
+                                          : AServiceManager_checkService(name.c_str());
+        auto service = IDevice::fromBinder(ndk::SpAIBinder(binder));
+        if (service == nullptr) {
+            const char* const lookup =
+                    blocking ? "AServiceManager_getService" : "AServiceManager_checkService";
+            return NN_ERROR() << lookup << " returned nullptr";
+        }
+        return Device::create(name, std::move(service));
+    };
+
+    return hal::utils::ResilientDevice::create(std::move(makeDevice));
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/src/Utils.cpp b/neuralnetworks/aidl/utils/src/Utils.cpp
index 8d00e59..95516c8 100644
--- a/neuralnetworks/aidl/utils/src/Utils.cpp
+++ b/neuralnetworks/aidl/utils/src/Utils.cpp
@@ -16,13 +16,12 @@
#include "Utils.h"
+#include <android/binder_status.h>
#include <nnapi/Result.h>
namespace aidl::android::hardware::neuralnetworks::utils {
namespace {
-using ::android::nn::GeneralResult;
-
template <typename Type>
nn::GeneralResult<std::vector<Type>> cloneVec(const std::vector<Type>& arguments) {
std::vector<Type> clonedObjects;
@@ -34,13 +33,13 @@
}
template <typename Type>
-GeneralResult<std::vector<Type>> clone(const std::vector<Type>& arguments) {
+nn::GeneralResult<std::vector<Type>> clone(const std::vector<Type>& arguments) {
return cloneVec(arguments);
}
} // namespace
-GeneralResult<Memory> clone(const Memory& memory) {
+nn::GeneralResult<Memory> clone(const Memory& memory) {
common::NativeHandle nativeHandle;
nativeHandle.ints = memory.handle.ints;
nativeHandle.fds.reserve(memory.handle.fds.size());
@@ -58,7 +57,7 @@
};
}
-GeneralResult<RequestMemoryPool> clone(const RequestMemoryPool& requestPool) {
+nn::GeneralResult<RequestMemoryPool> clone(const RequestMemoryPool& requestPool) {
using Tag = RequestMemoryPool::Tag;
switch (requestPool.getTag()) {
case Tag::pool:
@@ -70,10 +69,10 @@
// compiler.
return (NN_ERROR() << "Unrecognized request pool tag: " << requestPool.getTag())
.
- operator GeneralResult<RequestMemoryPool>();
+ operator nn::GeneralResult<RequestMemoryPool>();
}
-GeneralResult<Request> clone(const Request& request) {
+nn::GeneralResult<Request> clone(const Request& request) {
return Request{
.inputs = request.inputs,
.outputs = request.outputs,
@@ -81,7 +80,7 @@
};
}
-GeneralResult<Model> clone(const Model& model) {
+nn::GeneralResult<Model> clone(const Model& model) {
return Model{
.main = model.main,
.referenced = model.referenced,
@@ -92,4 +91,20 @@
};
}
+// Maps a binder status to a canonical NN result:
+//   - OK                          -> success
+//   - STATUS_DEAD_OBJECT          -> ErrorStatus::DEAD_OBJECT
+//   - EX_SERVICE_SPECIFIC         -> the service-specific code cast to nn::ErrorStatus
+//   - any other exception/status  -> ErrorStatus::GENERAL_FAILURE
+// NOTE(review): the service-specific code is cast without range validation; this assumes drivers
+// only send valid ErrorStatus values — confirm.
+nn::GeneralResult<void> handleTransportError(const ndk::ScopedAStatus& ret) {
+    // An OK status never carries STATUS_DEAD_OBJECT, so testing success first is equivalent.
+    if (ret.isOk()) {
+        return {};
+    }
+
+    if (ret.getStatus() == STATUS_DEAD_OBJECT) {
+        return nn::error(nn::ErrorStatus::DEAD_OBJECT)
+               << "Binder transaction returned STATUS_DEAD_OBJECT: " << ret.getDescription();
+    }
+
+    const bool isServiceSpecific = ret.getExceptionCode() == EX_SERVICE_SPECIFIC;
+    if (isServiceSpecific) {
+        const auto status = static_cast<nn::ErrorStatus>(ret.getServiceSpecificError());
+        return nn::error(status) << ret.getMessage();
+    }
+
+    return nn::error(nn::ErrorStatus::GENERAL_FAILURE)
+           << "Binder transaction returned exception: " << ret.getDescription();
+}
+
} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/test/BufferTest.cpp b/neuralnetworks/aidl/utils/test/BufferTest.cpp
new file mode 100644
index 0000000..9736160
--- /dev/null
+++ b/neuralnetworks/aidl/utils/test/BufferTest.cpp
@@ -0,0 +1,212 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MockBuffer.h"
+
+#include <aidl/android/hardware/neuralnetworks/ErrorStatus.h>
+#include <aidl/android/hardware/neuralnetworks/IBuffer.h>
+#include <android/binder_auto_utils.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <nnapi/IBuffer.h>
+#include <nnapi/SharedMemory.h>
+#include <nnapi/TypeUtils.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/aidl/Buffer.h>
+
+#include <functional>
+#include <memory>
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+namespace {
+
+using ::testing::_;
+using ::testing::Invoke;
+using ::testing::InvokeWithoutArgs;
+using ::testing::Return;
+
+// Four-byte shared memory region passed to copyTo/copyFrom.
+const auto kMemory = nn::createSharedMemory(4).value();
+// Null IBuffer, used to check that Buffer::create rejects it.
+const std::shared_ptr<IBuffer> kInvalidBuffer = nullptr;
+// Token 0 is expected to be rejected by Buffer::create; token 1 is expected to be accepted.
+constexpr nn::Request::MemoryDomainToken kInvalidToken{0};
+constexpr nn::Request::MemoryDomainToken kToken{1};
+
+// Canned binder statuses returned by the mocked IBuffer methods.
+constexpr auto makeStatusOk = []() -> ndk::ScopedAStatus { return ndk::ScopedAStatus::ok(); };
+
+// Service-specific GENERAL_FAILURE reported by the driver.
+constexpr auto makeGeneralFailure = []() -> ndk::ScopedAStatus {
+    const auto code = static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE);
+    return ndk::ScopedAStatus::fromServiceSpecificError(code);
+};
+// Transport-level failure that is not a dead object.
+constexpr auto makeGeneralTransportFailure = []() -> ndk::ScopedAStatus {
+    const binder_status_t status = STATUS_NO_MEMORY;
+    return ndk::ScopedAStatus::fromStatus(status);
+};
+// Transport failure signalling that the remote process died.
+constexpr auto makeDeadObjectFailure = []() -> ndk::ScopedAStatus {
+    const binder_status_t status = STATUS_DEAD_OBJECT;
+    return ndk::ScopedAStatus::fromStatus(status);
+};
+
+} // namespace
+
+// Buffer::create must reject a null IBuffer.
+TEST(BufferTest, invalidBuffer) {
+ // run test
+ const auto result = Buffer::create(kInvalidBuffer, kToken);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+// Buffer::create must reject token 0 (kInvalidToken) even when given a valid IBuffer.
+TEST(BufferTest, invalidToken) {
+ // setup call
+ const auto mockBuffer = MockBuffer::create();
+
+ // run test
+ const auto result = Buffer::create(mockBuffer, kInvalidToken);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+// A successfully created Buffer reports the token it was created with.
+TEST(BufferTest, create) {
+ // setup call
+ const auto mockBuffer = MockBuffer::create();
+ const auto buffer = Buffer::create(mockBuffer, kToken).value();
+
+ // run test
+ const auto token = buffer->getToken();
+
+ // verify result
+ EXPECT_EQ(token, kToken);
+}
+
+// copyTo succeeds when the driver returns an OK status.
+TEST(BufferTest, copyTo) {
+ // setup call
+ const auto mockBuffer = MockBuffer::create();
+ const auto buffer = Buffer::create(mockBuffer, kToken).value();
+ EXPECT_CALL(*mockBuffer, copyTo(_)).Times(1).WillOnce(InvokeWithoutArgs(makeStatusOk));
+
+ // run test
+ const auto result = buffer->copyTo(kMemory);
+
+ // verify result
+ EXPECT_TRUE(result.has_value()) << result.error().message;
+}
+
+// A service-specific GENERAL_FAILURE from the driver surfaces as GENERAL_FAILURE.
+TEST(BufferTest, copyToError) {
+ // setup test
+ const auto mockBuffer = MockBuffer::create();
+ const auto buffer = Buffer::create(mockBuffer, kToken).value();
+ EXPECT_CALL(*mockBuffer, copyTo(_)).Times(1).WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ // run test
+ const auto result = buffer->copyTo(kMemory);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+// A non-dead-object transport failure (STATUS_NO_MEMORY) surfaces as GENERAL_FAILURE.
+TEST(BufferTest, copyToTransportFailure) {
+ // setup test
+ const auto mockBuffer = MockBuffer::create();
+ const auto buffer = Buffer::create(mockBuffer, kToken).value();
+ EXPECT_CALL(*mockBuffer, copyTo(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ // run test
+ const auto result = buffer->copyTo(kMemory);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+// STATUS_DEAD_OBJECT from the transport surfaces as DEAD_OBJECT.
+TEST(BufferTest, copyToDeadObject) {
+ // setup test
+ const auto mockBuffer = MockBuffer::create();
+ const auto buffer = Buffer::create(mockBuffer, kToken).value();
+ EXPECT_CALL(*mockBuffer, copyTo(_)).Times(1).WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ // run test
+ const auto result = buffer->copyTo(kMemory);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+// copyFrom succeeds when the driver returns an OK status.
+TEST(BufferTest, copyFrom) {
+    // setup call
+    const auto mockBuffer = MockBuffer::create();
+    const auto buffer = Buffer::create(mockBuffer, kToken).value();
+    EXPECT_CALL(*mockBuffer, copyFrom(_, _)).Times(1).WillOnce(InvokeWithoutArgs(makeStatusOk));
+
+    // run test
+    const auto result = buffer->copyFrom(kMemory, {});
+
+    // verify result
+    // The streamed message is only evaluated when the assertion fails, so reading error() here is
+    // safe. It matches the copyTo test's diagnostics.
+    EXPECT_TRUE(result.has_value()) << result.error().message;
+}
+
+// A service-specific GENERAL_FAILURE from the driver surfaces as GENERAL_FAILURE.
+TEST(BufferTest, copyFromError) {
+ // setup test
+ const auto mockBuffer = MockBuffer::create();
+ const auto buffer = Buffer::create(mockBuffer, kToken).value();
+ EXPECT_CALL(*mockBuffer, copyFrom(_, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ // run test
+ const auto result = buffer->copyFrom(kMemory, {});
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+// A non-dead-object transport failure (STATUS_NO_MEMORY) surfaces as GENERAL_FAILURE.
+TEST(BufferTest, copyFromTransportFailure) {
+ // setup test
+ const auto mockBuffer = MockBuffer::create();
+ const auto buffer = Buffer::create(mockBuffer, kToken).value();
+ EXPECT_CALL(*mockBuffer, copyFrom(_, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ // run test
+ const auto result = buffer->copyFrom(kMemory, {});
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+// STATUS_DEAD_OBJECT from the transport surfaces as DEAD_OBJECT.
+TEST(BufferTest, copyFromDeadObject) {
+ // setup test
+ const auto mockBuffer = MockBuffer::create();
+ const auto buffer = Buffer::create(mockBuffer, kToken).value();
+ EXPECT_CALL(*mockBuffer, copyFrom(_, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ // run test
+ const auto result = buffer->copyFrom(kMemory, {});
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/test/DeviceTest.cpp b/neuralnetworks/aidl/utils/test/DeviceTest.cpp
new file mode 100644
index 0000000..e53b0a8
--- /dev/null
+++ b/neuralnetworks/aidl/utils/test/DeviceTest.cpp
@@ -0,0 +1,861 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MockBuffer.h"
+#include "MockDevice.h"
+#include "MockPreparedModel.h"
+
+#include <aidl/android/hardware/neuralnetworks/BnDevice.h>
+#include <android/binder_auto_utils.h>
+#include <android/binder_status.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <nnapi/IDevice.h>
+#include <nnapi/TypeUtils.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/aidl/Device.h>
+
+#include <functional>
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+namespace {
+
+namespace nn = ::android::nn;
+using ::testing::_;
+using ::testing::DoAll;
+using ::testing::Invoke;
+using ::testing::InvokeWithoutArgs;
+using ::testing::SetArgPointee;
+
+// Minimal single-operation model: RELU from a 1-element TENSOR_FLOAT32 subgraph input to a
+// 1-element TENSOR_FLOAT32 subgraph output.
+const nn::Model kSimpleModel = {
+        .main = {.operands = {{.type = nn::OperandType::TENSOR_FLOAT32,
+                               .dimensions = {1},
+                               .lifetime = nn::Operand::LifeTime::SUBGRAPH_INPUT},
+                              {.type = nn::OperandType::TENSOR_FLOAT32,
+                               .dimensions = {1},
+                               .lifetime = nn::Operand::LifeTime::SUBGRAPH_OUTPUT}},
+                 .operations = {{.type = nn::OperationType::RELU, .inputs = {0}, .outputs = {1}}},
+                 .inputIndexes = {0},
+                 .outputIndexes = {1}}};
+
+// Device name used for successful Device::create calls; also returned by the mocked
+// getVersionString.
+const std::string kName = "Google-MockV1";
+// Empty name, expected to be rejected by Device::create.
+const std::string kInvalidName = "";
+// Null device, expected to be rejected by Device::create.
+const std::shared_ptr<BnDevice> kInvalidDevice;
+// Performance entry with both fields set to the largest float.
+constexpr PerformanceInfo kNoPerformanceInfo = {.execTime = std::numeric_limits<float>::max(),
+                                                .powerUsage = std::numeric_limits<float>::max()};
+// Cache-file counts exactly at the canonical maximum (the largest accepted values).
+constexpr NumberOfCacheFiles kNumberOfCacheFiles = {.numModelCache = nn::kMaxNumberOfCacheFiles,
+                                                    .numDataCache = nn::kMaxNumberOfCacheFiles};
+
+// Successful binder status returned by mocked calls.
+constexpr auto makeStatusOk = [] { return ndk::ScopedAStatus::ok(); };
+
+// Builds a MockDevice whose metadata queries (version string, type, extensions, cache-file
+// counts, capabilities) succeed with fixed values. Each default action fills the out-parameter
+// and returns an OK status. Each method is also expected AnyNumber() times to suppress gmock's
+// uninteresting-call warnings; tests may override any of them with their own EXPECT_CALL.
+std::shared_ptr<MockDevice> createMockDevice() {
+    const auto mockDevice = MockDevice::create();
+    const Capabilities capabilities = {
+            .relaxedFloat32toFloat16PerformanceScalar = kNoPerformanceInfo,
+            .relaxedFloat32toFloat16PerformanceTensor = kNoPerformanceInfo,
+            .ifPerformance = kNoPerformanceInfo,
+            .whilePerformance = kNoPerformanceInfo,
+    };
+
+    ON_CALL(*mockDevice, getVersionString(_))
+            .WillByDefault(DoAll(SetArgPointee<0>(kName), InvokeWithoutArgs(makeStatusOk)));
+    EXPECT_CALL(*mockDevice, getVersionString(_)).Times(testing::AnyNumber());
+
+    ON_CALL(*mockDevice, getType(_))
+            .WillByDefault(
+                    DoAll(SetArgPointee<0>(DeviceType::OTHER), InvokeWithoutArgs(makeStatusOk)));
+    EXPECT_CALL(*mockDevice, getType(_)).Times(testing::AnyNumber());
+
+    ON_CALL(*mockDevice, getSupportedExtensions(_))
+            .WillByDefault(DoAll(SetArgPointee<0>(std::vector<Extension>{}),
+                                 InvokeWithoutArgs(makeStatusOk)));
+    EXPECT_CALL(*mockDevice, getSupportedExtensions(_)).Times(testing::AnyNumber());
+
+    ON_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
+            .WillByDefault(
+                    DoAll(SetArgPointee<0>(kNumberOfCacheFiles), InvokeWithoutArgs(makeStatusOk)));
+    EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_)).Times(testing::AnyNumber());
+
+    ON_CALL(*mockDevice, getCapabilities(_))
+            .WillByDefault(DoAll(SetArgPointee<0>(capabilities), InvokeWithoutArgs(makeStatusOk)));
+    EXPECT_CALL(*mockDevice, getCapabilities(_)).Times(testing::AnyNumber());
+
+    return mockDevice;
+}
+
+// Shared body of the prepareModel* mock actions. It delivers `returnStatus` and `preparedModel`
+// through the callback, then reports `launchStatus` as the synchronous binder status: OK for
+// ErrorStatus::NONE, otherwise a service-specific error carrying that status.
+constexpr auto makePreparedModelReturnImpl =
+        [](ErrorStatus launchStatus, ErrorStatus returnStatus,
+           const std::shared_ptr<MockPreparedModel>& preparedModel,
+           const std::shared_ptr<IPreparedModelCallback>& cb) {
+            cb->notify(returnStatus, preparedModel);
+            return launchStatus == ErrorStatus::NONE
+                           ? ndk::ScopedAStatus::ok()
+                           : ndk::ScopedAStatus::fromServiceSpecificError(
+                                     static_cast<int32_t>(launchStatus));
+        };
+
+// Mock action for IDevice::prepareModel. The request arguments are ignored.
+auto makePreparedModelReturn(ErrorStatus launchStatus, ErrorStatus returnStatus,
+                             const std::shared_ptr<MockPreparedModel>& preparedModel) {
+    auto action = [launchStatus, returnStatus, preparedModel](
+                          const Model& /*model*/, ExecutionPreference /*preference*/,
+                          Priority /*priority*/, const int64_t& /*deadline*/,
+                          const std::vector<ndk::ScopedFileDescriptor>& /*modelCache*/,
+                          const std::vector<ndk::ScopedFileDescriptor>& /*dataCache*/,
+                          const std::vector<uint8_t>& /*token*/,
+                          const std::shared_ptr<IPreparedModelCallback>& cb) -> ndk::ScopedAStatus {
+        return makePreparedModelReturnImpl(launchStatus, returnStatus, preparedModel, cb);
+    };
+    return action;
+}
+
+// Mock action for IDevice::prepareModelFromCache. The request arguments are ignored.
+auto makePreparedModelFromCacheReturn(ErrorStatus launchStatus, ErrorStatus returnStatus,
+                                      const std::shared_ptr<MockPreparedModel>& preparedModel) {
+    auto action = [launchStatus, returnStatus, preparedModel](
+                          const int64_t& /*deadline*/,
+                          const std::vector<ndk::ScopedFileDescriptor>& /*modelCache*/,
+                          const std::vector<ndk::ScopedFileDescriptor>& /*dataCache*/,
+                          const std::vector<uint8_t>& /*token*/,
+                          const std::shared_ptr<IPreparedModelCallback>& cb) -> ndk::ScopedAStatus {
+        return makePreparedModelReturnImpl(launchStatus, returnStatus, preparedModel, cb);
+    };
+    return action;
+}
+
+// Service-specific GENERAL_FAILURE reported by the driver.
+constexpr auto makeGeneralFailure = []() -> ndk::ScopedAStatus {
+    constexpr auto kCode = static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE);
+    return ndk::ScopedAStatus::fromServiceSpecificError(kCode);
+};
+// Transport-level failure that is not a dead object.
+constexpr auto makeGeneralTransportFailure = []() -> ndk::ScopedAStatus {
+    const binder_status_t status = STATUS_NO_MEMORY;
+    return ndk::ScopedAStatus::fromStatus(status);
+};
+// Transport failure signalling that the remote process died.
+constexpr auto makeDeadObjectFailure = []() -> ndk::ScopedAStatus {
+    const binder_status_t status = STATUS_DEAD_OBJECT;
+    return ndk::ScopedAStatus::fromStatus(status);
+};
+
+} // namespace
+
+// Device::create must reject an empty service name.
+TEST(DeviceTest, invalidName) {
+ // run test
+ const auto device = MockDevice::create();
+ const auto result = Device::create(kInvalidName, device);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::INVALID_ARGUMENT);
+}
+
+// Device::create must reject a null IDevice.
+TEST(DeviceTest, invalidDevice) {
+ // run test
+ const auto result = Device::create(kName, kInvalidDevice);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::INVALID_ARGUMENT);
+}
+
+// A service-specific GENERAL_FAILURE from getVersionString makes Device::create fail with
+// GENERAL_FAILURE.
+TEST(DeviceTest, getVersionStringError) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getVersionString(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ // run test
+ const auto result = Device::create(kName, mockDevice);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+// A non-dead-object transport failure from getVersionString maps to GENERAL_FAILURE.
+TEST(DeviceTest, getVersionStringTransportFailure) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getVersionString(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ // run test
+ const auto result = Device::create(kName, mockDevice);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+// STATUS_DEAD_OBJECT from getVersionString maps to DEAD_OBJECT.
+TEST(DeviceTest, getVersionStringDeadObject) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getVersionString(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ // run test
+ const auto result = Device::create(kName, mockDevice);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+// A service-specific GENERAL_FAILURE from getType makes Device::create fail with GENERAL_FAILURE.
+TEST(DeviceTest, getTypeError) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getType(_)).Times(1).WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ // run test
+ const auto result = Device::create(kName, mockDevice);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+// A non-dead-object transport failure from getType maps to GENERAL_FAILURE.
+TEST(DeviceTest, getTypeTransportFailure) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getType(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ // run test
+ const auto result = Device::create(kName, mockDevice);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+// STATUS_DEAD_OBJECT from getType maps to DEAD_OBJECT.
+TEST(DeviceTest, getTypeDeadObject) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getType(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ // run test
+ const auto result = Device::create(kName, mockDevice);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+// A service-specific GENERAL_FAILURE from getSupportedExtensions makes Device::create fail with
+// GENERAL_FAILURE.
+TEST(DeviceTest, getSupportedExtensionsError) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getSupportedExtensions(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ // run test
+ const auto result = Device::create(kName, mockDevice);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+// A non-dead-object transport failure from getSupportedExtensions maps to GENERAL_FAILURE.
+TEST(DeviceTest, getSupportedExtensionsTransportFailure) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getSupportedExtensions(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ // run test
+ const auto result = Device::create(kName, mockDevice);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+// STATUS_DEAD_OBJECT from getSupportedExtensions maps to DEAD_OBJECT.
+TEST(DeviceTest, getSupportedExtensionsDeadObject) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getSupportedExtensions(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ // run test
+ const auto result = Device::create(kName, mockDevice);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+// A service-specific GENERAL_FAILURE from getNumberOfCacheFilesNeeded makes Device::create fail
+// with GENERAL_FAILURE.
+TEST(DeviceTest, getNumberOfCacheFilesNeededError) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ // run test
+ const auto result = Device::create(kName, mockDevice);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+// Device::create must reject a device that reports more data cache files than
+// nn::kMaxNumberOfCacheFiles. The model cache count stays at the limit so only the data count
+// is out of range. (The fields were previously swapped relative to the test name.)
+TEST(DeviceTest, dataCacheFilesExceedsSpecifiedMax) {
+    // setup test
+    const auto mockDevice = createMockDevice();
+    EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
+            .Times(1)
+            .WillOnce(DoAll(SetArgPointee<0>(NumberOfCacheFiles{
+                                    .numModelCache = nn::kMaxNumberOfCacheFiles,
+                                    .numDataCache = nn::kMaxNumberOfCacheFiles + 1}),
+                            InvokeWithoutArgs(makeStatusOk)));
+
+    // run test
+    const auto result = Device::create(kName, mockDevice);
+
+    // verify result
+    ASSERT_FALSE(result.has_value());
+    EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+// Device::create must reject a device that reports more model cache files than
+// nn::kMaxNumberOfCacheFiles. The data cache count stays at the limit so only the model count
+// is out of range.
+TEST(DeviceTest, modelCacheFilesExceedsSpecifiedMax) {
+    // setup test
+    const auto mockDevice = createMockDevice();
+    EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
+            .Times(1)
+            .WillOnce(DoAll(SetArgPointee<0>(NumberOfCacheFiles{
+                                    .numModelCache = nn::kMaxNumberOfCacheFiles + 1,
+                                    .numDataCache = nn::kMaxNumberOfCacheFiles}),
+                            InvokeWithoutArgs(makeStatusOk)));
+
+    // run test
+    const auto result = Device::create(kName, mockDevice);
+
+    // verify result
+    ASSERT_FALSE(result.has_value());
+    EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+// A non-dead-object transport failure from getNumberOfCacheFilesNeeded maps to GENERAL_FAILURE.
+TEST(DeviceTest, getNumberOfCacheFilesNeededTransportFailure) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ // run test
+ const auto result = Device::create(kName, mockDevice);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+// STATUS_DEAD_OBJECT from getNumberOfCacheFilesNeeded maps to DEAD_OBJECT.
+TEST(DeviceTest, getNumberOfCacheFilesNeededDeadObject) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ // run test
+ const auto result = Device::create(kName, mockDevice);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+// A service-specific GENERAL_FAILURE from getCapabilities makes Device::create fail with
+// GENERAL_FAILURE.
+TEST(DeviceTest, getCapabilitiesError) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getCapabilities(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ // run test
+ const auto result = Device::create(kName, mockDevice);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+// A non-dead-object transport failure from getCapabilities maps to GENERAL_FAILURE.
+TEST(DeviceTest, getCapabilitiesTransportFailure) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getCapabilities(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ // run test
+ const auto result = Device::create(kName, mockDevice);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+// STATUS_DEAD_OBJECT from getCapabilities maps to DEAD_OBJECT.
+TEST(DeviceTest, getCapabilitiesDeadObject) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getCapabilities(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ // run test
+ const auto result = Device::create(kName, mockDevice);
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+// The device reports the name it was created with.
+TEST(DeviceTest, getName) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+
+ // run test
+ const auto& name = device->getName();
+
+ // verify result
+ EXPECT_EQ(name, kName);
+}
+
+// An AIDL-backed device reports feature level nn::Version::ANDROID_S.
+TEST(DeviceTest, getFeatureLevel) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+
+ // run test
+ const auto featureLevel = device->getFeatureLevel();
+
+ // verify result
+ EXPECT_EQ(featureLevel, nn::Version::ANDROID_S);
+}
+
+// Device::create queries each metadata method exactly once (enforced by Times(1)). Later
+// getters must therefore be served from cached values without further IDevice calls.
+TEST(DeviceTest, getCachedData) {
+ // setup call
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getVersionString(_)).Times(1);
+ EXPECT_CALL(*mockDevice, getType(_)).Times(1);
+ EXPECT_CALL(*mockDevice, getSupportedExtensions(_)).Times(1);
+ EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_)).Times(1);
+ EXPECT_CALL(*mockDevice, getCapabilities(_)).Times(1);
+
+ const auto result = Device::create(kName, mockDevice);
+ ASSERT_TRUE(result.has_value())
+ << "Failed with " << result.error().code << ": " << result.error().message;
+ const auto& device = result.value();
+
+ // run test and verify results
+ // Each getter is called twice; any extra IDevice call would violate the Times(1) expectations.
+ EXPECT_EQ(device->getVersionString(), device->getVersionString());
+ EXPECT_EQ(device->getType(), device->getType());
+ EXPECT_EQ(device->getSupportedExtensions(), device->getSupportedExtensions());
+ EXPECT_EQ(device->getNumberOfCacheFilesNeeded(), device->getNumberOfCacheFilesNeeded());
+ EXPECT_EQ(device->getCapabilities(), device->getCapabilities());
+}
+
+// getSupportedOperations returns one flag per model operation, forwarding the driver's answer
+// (all true here).
+TEST(DeviceTest, getSupportedOperations) {
+    // setup call
+    const auto mockDevice = createMockDevice();
+    const auto device = Device::create(kName, mockDevice).value();
+    EXPECT_CALL(*mockDevice, getSupportedOperations(_, _))
+            .Times(1)
+            .WillOnce(DoAll(
+                    SetArgPointee<1>(std::vector<bool>(kSimpleModel.main.operations.size(), true)),
+                    InvokeWithoutArgs(makeStatusOk)));
+
+    // run test
+    const auto result = device->getSupportedOperations(kSimpleModel);
+
+    // verify result
+    ASSERT_TRUE(result.has_value())
+            << "Failed with " << result.error().code << ": " << result.error().message;
+    const auto& supportedOperations = result.value();
+    EXPECT_EQ(supportedOperations.size(), kSimpleModel.main.operations.size());
+    // Qualify Each explicitly instead of relying on argument-dependent lookup through IsTrue().
+    EXPECT_THAT(supportedOperations, testing::Each(testing::IsTrue()));
+}
+
+TEST(DeviceTest, getSupportedOperationsError) {
+ // The driver rejects the query with a GENERAL_FAILURE status.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ EXPECT_CALL(*mockDevice, getSupportedOperations(_, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ const auto supportResult = device->getSupportedOperations(kSimpleModel);
+
+ // The driver's error code must be surfaced unchanged.
+ ASSERT_FALSE(supportResult.has_value());
+ const auto& error = supportResult.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, getSupportedOperationsTransportFailure) {
+ // The binder call itself fails with a (non-dead-object) transport error.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ EXPECT_CALL(*mockDevice, getSupportedOperations(_, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ const auto supportResult = device->getSupportedOperations(kSimpleModel);
+
+ // Transport errors are expected to be reported as GENERAL_FAILURE.
+ ASSERT_FALSE(supportResult.has_value());
+ const auto& error = supportResult.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, getSupportedOperationsDeadObject) {
+ // The binder call reports that the remote object is dead.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ EXPECT_CALL(*mockDevice, getSupportedOperations(_, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ const auto supportResult = device->getSupportedOperations(kSimpleModel);
+
+ // A dead object must be reported as DEAD_OBJECT, not as a generic failure.
+ ASSERT_FALSE(supportResult.has_value());
+ const auto& error = supportResult.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST(DeviceTest, prepareModel) {
+ // Both statuses given to makePreparedModelReturn are NONE and a mock prepared model is
+ // supplied, so preparation is expected to succeed.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ const auto mockPreparedModel = MockPreparedModel::create();
+ EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelReturn(ErrorStatus::NONE, ErrorStatus::NONE,
+ mockPreparedModel)));
+
+ const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {});
+
+ // A non-null canonical prepared model must be returned.
+ ASSERT_TRUE(prepared.has_value())
+ << "Failed with " << prepared.error().code << ": " << prepared.error().message;
+ EXPECT_NE(prepared.value(), nullptr);
+}
+
+TEST(DeviceTest, prepareModelLaunchError) {
+ // Both statuses given to makePreparedModelReturn are GENERAL_FAILURE (the first is
+ // presumably the launch status -- see the helper) and no model is supplied.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelReturn(ErrorStatus::GENERAL_FAILURE,
+ ErrorStatus::GENERAL_FAILURE, nullptr)));
+
+ const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {});
+
+ ASSERT_FALSE(prepared.has_value());
+ const auto& error = prepared.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, prepareModelReturnError) {
+ // The first status given to makePreparedModelReturn is NONE and the second is
+ // GENERAL_FAILURE (presumably the status delivered through the callback).
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelReturn(ErrorStatus::NONE,
+ ErrorStatus::GENERAL_FAILURE, nullptr)));
+
+ const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {});
+
+ ASSERT_FALSE(prepared.has_value());
+ const auto& error = prepared.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, prepareModelNullptrError) {
+ // Both statuses report success, but the prepared model supplied is null.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(
+ Invoke(makePreparedModelReturn(ErrorStatus::NONE, ErrorStatus::NONE, nullptr)));
+
+ const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {});
+
+ // A "successful" null model must still be rejected.
+ ASSERT_FALSE(prepared.has_value());
+ const auto& error = prepared.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, prepareModelTransportFailure) {
+ // The prepareModel binder call fails with a (non-dead-object) transport error.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {});
+
+ ASSERT_FALSE(prepared.has_value());
+ const auto& error = prepared.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, prepareModelDeadObject) {
+ // The prepareModel binder call reports a dead remote object.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {});
+
+ ASSERT_FALSE(prepared.has_value());
+ const auto& error = prepared.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST(DeviceTest, prepareModelAsyncCrash) {
+ // Simulate the driver dying mid-call: the death monitor is notified and the call returns
+ // OK, but the prepared-model callback is never invoked.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ const auto crashDuringCall = [&device]() {
+ DeathMonitor::serviceDied(device->getDeathMonitor());
+ return ndk::ScopedAStatus::ok();
+ };
+ EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(crashDuringCall));
+
+ const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {});
+
+ // The caller must not hang, and must see DEAD_OBJECT.
+ ASSERT_FALSE(prepared.has_value());
+ const auto& error = prepared.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST(DeviceTest, prepareModelFromCache) {
+ // Both statuses are NONE and a mock prepared model is supplied, so preparation from cache
+ // is expected to succeed.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ const auto mockPreparedModel = MockPreparedModel::create();
+ EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelFromCacheReturn(ErrorStatus::NONE, ErrorStatus::NONE,
+ mockPreparedModel)));
+
+ const auto prepared = device->prepareModelFromCache({}, {}, {}, {});
+
+ ASSERT_TRUE(prepared.has_value())
+ << "Failed with " << prepared.error().code << ": " << prepared.error().message;
+ EXPECT_NE(prepared.value(), nullptr);
+}
+
+TEST(DeviceTest, prepareModelFromCacheLaunchError) {
+ // Both statuses given to the helper are GENERAL_FAILURE and no model is supplied.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelFromCacheReturn(
+ ErrorStatus::GENERAL_FAILURE, ErrorStatus::GENERAL_FAILURE, nullptr)));
+
+ const auto prepared = device->prepareModelFromCache({}, {}, {}, {});
+
+ ASSERT_FALSE(prepared.has_value());
+ const auto& error = prepared.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, prepareModelFromCacheReturnError) {
+ // The first status is NONE and the second is GENERAL_FAILURE (presumably the status
+ // delivered through the callback).
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelFromCacheReturn(
+ ErrorStatus::NONE, ErrorStatus::GENERAL_FAILURE, nullptr)));
+
+ const auto prepared = device->prepareModelFromCache({}, {}, {}, {});
+
+ ASSERT_FALSE(prepared.has_value());
+ const auto& error = prepared.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, prepareModelFromCacheNullptrError) {
+ // Both statuses report success, but the prepared model supplied is null.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelFromCacheReturn(ErrorStatus::NONE, ErrorStatus::NONE,
+ nullptr)));
+
+ const auto prepared = device->prepareModelFromCache({}, {}, {}, {});
+
+ // A "successful" null model must still be rejected.
+ ASSERT_FALSE(prepared.has_value());
+ const auto& error = prepared.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, prepareModelFromCacheTransportFailure) {
+ // The prepareModelFromCache binder call fails with a (non-dead-object) transport error.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ const auto prepared = device->prepareModelFromCache({}, {}, {}, {});
+
+ ASSERT_FALSE(prepared.has_value());
+ const auto& error = prepared.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, prepareModelFromCacheDeadObject) {
+ // The prepareModelFromCache binder call reports a dead remote object.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ const auto prepared = device->prepareModelFromCache({}, {}, {}, {});
+
+ ASSERT_FALSE(prepared.has_value());
+ const auto& error = prepared.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST(DeviceTest, prepareModelFromCacheAsyncCrash) {
+ // Simulate the driver dying mid-call: the death monitor is notified and the call returns
+ // OK, but the prepared-model callback is never invoked.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ const auto crashDuringCall = [&device]() {
+ DeathMonitor::serviceDied(device->getDeathMonitor());
+ return ndk::ScopedAStatus::ok();
+ };
+ EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(crashDuringCall));
+
+ const auto prepared = device->prepareModelFromCache({}, {}, {}, {});
+
+ // The caller must not hang, and must see DEAD_OBJECT.
+ ASSERT_FALSE(prepared.has_value());
+ const auto& error = prepared.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST(DeviceTest, allocate) {
+ // The driver hands back a mock buffer with token 1 and an OK status.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ const auto driverBuffer = DeviceBuffer{.buffer = MockBuffer::create(), .token = 1};
+ EXPECT_CALL(*mockDevice, allocate(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(DoAll(SetArgPointee<4>(driverBuffer), InvokeWithoutArgs(makeStatusOk)));
+
+ const auto allocated = device->allocate({}, {}, {}, {});
+
+ // A non-null canonical buffer must be returned.
+ ASSERT_TRUE(allocated.has_value())
+ << "Failed with " << allocated.error().code << ": " << allocated.error().message;
+ EXPECT_NE(allocated.value(), nullptr);
+}
+
+TEST(DeviceTest, allocateError) {
+ // The driver rejects the allocation with a GENERAL_FAILURE status.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ EXPECT_CALL(*mockDevice, allocate(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ const auto allocated = device->allocate({}, {}, {}, {});
+
+ ASSERT_FALSE(allocated.has_value());
+ const auto& error = allocated.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, allocateTransportFailure) {
+ // The allocate binder call fails with a (non-dead-object) transport error.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ EXPECT_CALL(*mockDevice, allocate(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ const auto allocated = device->allocate({}, {}, {}, {});
+
+ ASSERT_FALSE(allocated.has_value());
+ const auto& error = allocated.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, allocateDeadObject) {
+ // The allocate binder call reports a dead remote object.
+ const auto mockDevice = createMockDevice();
+ const auto device = Device::create(kName, mockDevice).value();
+ EXPECT_CALL(*mockDevice, allocate(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ const auto allocated = device->allocate({}, {}, {}, {});
+
+ ASSERT_FALSE(allocated.has_value());
+ const auto& error = allocated.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/test/MockBuffer.h b/neuralnetworks/aidl/utils/test/MockBuffer.h
new file mode 100644
index 0000000..5746176
--- /dev/null
+++ b/neuralnetworks/aidl/utils/test/MockBuffer.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_BUFFER
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_BUFFER
+
+#include <aidl/android/hardware/neuralnetworks/BnBuffer.h>
+#include <android/binder_interface_utils.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <hidl/Status.h>
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// gMock implementation of the AIDL buffer interface (derives from the NDK stub BnBuffer).
+class MockBuffer final : public BnBuffer {
+ public:
+ // Creates a MockBuffer whose lifetime is managed through ndk::SharedRefBase.
+ static std::shared_ptr<MockBuffer> create() { return ndk::SharedRefBase::make<MockBuffer>(); }
+
+ MOCK_METHOD(ndk::ScopedAStatus, copyTo, (const Memory& dst), (override));
+ MOCK_METHOD(ndk::ScopedAStatus, copyFrom,
+ (const Memory& src, const std::vector<int32_t>& dimensions), (override));
+};
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_BUFFER
diff --git a/neuralnetworks/aidl/utils/test/MockDevice.h b/neuralnetworks/aidl/utils/test/MockDevice.h
new file mode 100644
index 0000000..9b35bf8
--- /dev/null
+++ b/neuralnetworks/aidl/utils/test/MockDevice.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_DEVICE
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_DEVICE
+
+#include <aidl/android/hardware/neuralnetworks/BnDevice.h>
+#include <android/binder_auto_utils.h>
+#include <android/binder_interface_utils.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// gMock implementation of the AIDL device interface (derives from the NDK stub BnDevice).
+class MockDevice final : public BnDevice {
+ public:
+ // Creates a MockDevice whose lifetime is managed through ndk::SharedRefBase.
+ static std::shared_ptr<MockDevice> create() { return ndk::SharedRefBase::make<MockDevice>(); }
+
+ MOCK_METHOD(ndk::ScopedAStatus, allocate,
+ (const BufferDesc& desc, const std::vector<IPreparedModelParcel>& preparedModels,
+ const std::vector<BufferRole>& inputRoles,
+ const std::vector<BufferRole>& outputRoles, DeviceBuffer* deviceBuffer),
+ (override));
+ MOCK_METHOD(ndk::ScopedAStatus, getCapabilities, (Capabilities * capabilities), (override));
+ MOCK_METHOD(ndk::ScopedAStatus, getNumberOfCacheFilesNeeded,
+ (NumberOfCacheFiles * numberOfCacheFiles), (override));
+ MOCK_METHOD(ndk::ScopedAStatus, getSupportedExtensions, (std::vector<Extension> * extensions),
+ (override));
+ MOCK_METHOD(ndk::ScopedAStatus, getSupportedOperations,
+ (const Model& model, std::vector<bool>* supportedOperations), (override));
+ MOCK_METHOD(ndk::ScopedAStatus, getType, (DeviceType * deviceType), (override));
+ MOCK_METHOD(ndk::ScopedAStatus, getVersionString, (std::string * version), (override));
+ MOCK_METHOD(ndk::ScopedAStatus, prepareModel,
+ (const Model& model, ExecutionPreference preference, Priority priority,
+ int64_t deadline, const std::vector<ndk::ScopedFileDescriptor>& modelCache,
+ const std::vector<ndk::ScopedFileDescriptor>& dataCache,
+ const std::vector<uint8_t>& token,
+ const std::shared_ptr<IPreparedModelCallback>& callback),
+ (override));
+ MOCK_METHOD(ndk::ScopedAStatus, prepareModelFromCache,
+ (int64_t deadline, const std::vector<ndk::ScopedFileDescriptor>& modelCache,
+ const std::vector<ndk::ScopedFileDescriptor>& dataCache,
+ const std::vector<uint8_t>& token,
+ const std::shared_ptr<IPreparedModelCallback>& callback),
+ (override));
+};
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_DEVICE
diff --git a/neuralnetworks/aidl/utils/test/MockFencedExecutionCallback.h b/neuralnetworks/aidl/utils/test/MockFencedExecutionCallback.h
new file mode 100644
index 0000000..463e1c9
--- /dev/null
+++ b/neuralnetworks/aidl/utils/test/MockFencedExecutionCallback.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_FENCED_EXECUTION_CALLBACK
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_FENCED_EXECUTION_CALLBACK
+
+#include <aidl/android/hardware/neuralnetworks/BnFencedExecutionCallback.h>
+#include <android/binder_auto_utils.h>
+#include <android/binder_interface_utils.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <hidl/Status.h>
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// gMock implementation of the AIDL fenced-execution callback interface (derives from the NDK
+// stub BnFencedExecutionCallback).
+class MockFencedExecutionCallback final : public BnFencedExecutionCallback {
+ public:
+ // Creates a MockFencedExecutionCallback whose lifetime is managed through
+ // ndk::SharedRefBase.
+ static std::shared_ptr<MockFencedExecutionCallback> create();
+
+ // IFencedExecutionCallback methods below. The AIDL interface is not split into HAL versions,
+ // so there is no "V1_3" section here.
+ MOCK_METHOD(ndk::ScopedAStatus, getExecutionInfo,
+ (Timing * timingLaunched, Timing* timingFenced, ErrorStatus* errorStatus),
+ (override));
+};
+
+inline std::shared_ptr<MockFencedExecutionCallback> MockFencedExecutionCallback::create() {
+ return ndk::SharedRefBase::make<MockFencedExecutionCallback>();
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_FENCED_EXECUTION_CALLBACK
diff --git a/neuralnetworks/aidl/utils/test/MockPreparedModel.h b/neuralnetworks/aidl/utils/test/MockPreparedModel.h
new file mode 100644
index 0000000..545b491
--- /dev/null
+++ b/neuralnetworks/aidl/utils/test/MockPreparedModel.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_PREPARED_MODEL
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_PREPARED_MODEL
+
+#include <aidl/android/hardware/neuralnetworks/BnPreparedModel.h>
+#include <android/binder_interface_utils.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <hidl/HidlSupport.h>
+#include <hidl/Status.h>
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// gMock implementation of the AIDL prepared-model interface (derives from the NDK stub
+// BnPreparedModel).
+class MockPreparedModel final : public BnPreparedModel {
+ public:
+ // Creates a MockPreparedModel whose lifetime is managed through ndk::SharedRefBase.
+ static std::shared_ptr<MockPreparedModel> create() {
+ return ndk::SharedRefBase::make<MockPreparedModel>();
+ }
+
+ MOCK_METHOD(ndk::ScopedAStatus, executeSynchronously,
+ (const Request& request, bool measureTiming, int64_t deadline,
+ int64_t loopTimeoutDuration, ExecutionResult* executionResult),
+ (override));
+ MOCK_METHOD(ndk::ScopedAStatus, executeFenced,
+ (const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor,
+ bool measureTiming, int64_t deadline, int64_t loopTimeoutDuration,
+ int64_t duration, FencedExecutionResult* fencedExecutionResult),
+ (override));
+};
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_PREPARED_MODEL
diff --git a/neuralnetworks/aidl/utils/test/PreparedModelTest.cpp b/neuralnetworks/aidl/utils/test/PreparedModelTest.cpp
new file mode 100644
index 0000000..7e28861
--- /dev/null
+++ b/neuralnetworks/aidl/utils/test/PreparedModelTest.cpp
@@ -0,0 +1,272 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MockFencedExecutionCallback.h"
+#include "MockPreparedModel.h"
+
+#include <aidl/android/hardware/neuralnetworks/IFencedExecutionCallback.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <nnapi/IPreparedModel.h>
+#include <nnapi/TypeUtils.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/aidl/PreparedModel.h>
+
+#include <functional>
+#include <memory>
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+namespace {
+
+using ::testing::_;
+using ::testing::DoAll;
+using ::testing::Invoke;
+using ::testing::InvokeWithoutArgs;
+using ::testing::SetArgPointee;
+
+// A null AIDL prepared model; PreparedModel::create is expected to reject it.
+const std::shared_ptr<IPreparedModel> kInvalidPreparedModel;
+// Timing with both durations set to -1, used by these tests as "no timing information".
+constexpr auto kNoTiming = Timing{.timeOnDevice = -1, .timeInDriver = -1};
+
+// Nullary status factories, used as mock actions.
+constexpr auto makeStatusOk = []() -> ndk::ScopedAStatus { return ndk::ScopedAStatus::ok(); };
+
+constexpr auto makeGeneralFailure = []() -> ndk::ScopedAStatus {
+ const auto serviceError = static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE);
+ return ndk::ScopedAStatus::fromServiceSpecificError(serviceError);
+};
+constexpr auto makeGeneralTransportFailure = []() -> ndk::ScopedAStatus {
+ return ndk::ScopedAStatus::fromStatus(STATUS_NO_MEMORY);
+};
+constexpr auto makeDeadObjectFailure = []() -> ndk::ScopedAStatus {
+ return ndk::ScopedAStatus::fromStatus(STATUS_DEAD_OBJECT);
+};
+
+// Builds an executeFenced mock action. The action ignores its inputs, fills the out-parameter
+// with `callback` and a sync fence wrapping fd -1, and returns OK.
+auto makeFencedExecutionResult(const std::shared_ptr<MockFencedExecutionCallback>& callback) {
+ return [fencedCallback = callback](const Request& /*request*/,
+ const std::vector<ndk::ScopedFileDescriptor>& /*waitFor*/,
+ bool /*measureTiming*/, int64_t /*deadline*/,
+ int64_t /*loopTimeoutDuration*/, int64_t /*duration*/,
+ FencedExecutionResult* out) {
+ *out = FencedExecutionResult{.callback = fencedCallback,
+ .syncFence = ndk::ScopedFileDescriptor(-1)};
+ return ndk::ScopedAStatus::ok();
+ };
+}
+
+} // namespace
+
+TEST(PreparedModelTest, invalidPreparedModel) {
+ // Wrapping a null AIDL prepared model must fail.
+ const auto created = PreparedModel::create(kInvalidPreparedModel);
+
+ ASSERT_FALSE(created.has_value());
+ const auto& error = created.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(PreparedModelTest, executeSync) {
+ // The driver completes a synchronous execution with sufficient outputs and no timing.
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+ const auto driverResult = ExecutionResult{
+ .outputSufficientSize = true,
+ .outputShapes = {},
+ .timing = kNoTiming,
+ };
+ EXPECT_CALL(*mockPreparedModel, executeSynchronously(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(DoAll(SetArgPointee<4>(driverResult), InvokeWithoutArgs(makeStatusOk)));
+
+ const auto executed = preparedModel->execute({}, {}, {}, {});
+
+ EXPECT_TRUE(executed.has_value())
+ << "Failed with " << executed.error().code << ": " << executed.error().message;
+}
+
+TEST(PreparedModelTest, executeSyncError) {
+ // setup test: the driver rejects the execution with a service-specific GENERAL_FAILURE.
+ // makeGeneralFailure takes no arguments, so it is wrapped in InvokeWithoutArgs, as in the
+ // other failure tests.
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+ EXPECT_CALL(*mockPreparedModel, executeSynchronously(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ // run test
+ const auto result = preparedModel->execute({}, {}, {}, {});
+
+ // verify result
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(PreparedModelTest, executeSyncTransportFailure) {
+ // The executeSynchronously binder call fails with a transport error (STATUS_NO_MEMORY).
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+ EXPECT_CALL(*mockPreparedModel, executeSynchronously(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ const auto executed = preparedModel->execute({}, {}, {}, {});
+
+ ASSERT_FALSE(executed.has_value());
+ const auto& error = executed.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(PreparedModelTest, executeSyncDeadObject) {
+ // The executeSynchronously binder call reports STATUS_DEAD_OBJECT.
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+ EXPECT_CALL(*mockPreparedModel, executeSynchronously(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ const auto executed = preparedModel->execute({}, {}, {}, {});
+
+ ASSERT_FALSE(executed.has_value());
+ const auto& error = executed.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST(PreparedModelTest, executeFenced) {
+ // setup call: the fenced execution succeeds, and the callback later reports no timing and
+ // ErrorStatus::NONE. makeStatusOk is nullary, so it is wrapped in InvokeWithoutArgs, as in
+ // the rest of the file.
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+ const auto mockCallback = MockFencedExecutionCallback::create();
+ EXPECT_CALL(*mockCallback, getExecutionInfo(_, _, _))
+ .Times(1)
+ .WillOnce(DoAll(SetArgPointee<0>(kNoTiming), SetArgPointee<1>(kNoTiming),
+ SetArgPointee<2>(ErrorStatus::NONE), InvokeWithoutArgs(makeStatusOk)));
+ EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makeFencedExecutionResult(mockCallback)));
+
+ // run test
+ const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+
+ // verify result
+ ASSERT_TRUE(result.has_value())
+ << "Failed with " << result.error().code << ": " << result.error().message;
+ const auto& [syncFence, callback] = result.value();
+ EXPECT_EQ(syncFence.syncWait({}), nn::SyncFence::FenceState::SIGNALED);
+ ASSERT_NE(callback, nullptr);
+
+ // get results from callback
+ const auto callbackResult = callback();
+ ASSERT_TRUE(callbackResult.has_value()) << "Failed with " << callbackResult.error().code << ": "
+ << callbackResult.error().message;
+}
+
+TEST(PreparedModelTest, executeFencedCallbackError) {
+ // setup call: the fenced execution launches successfully, but the callback reports
+ // GENERAL_FAILURE. DoAll is already an action, so it is passed to WillOnce directly rather
+ // than wrapped in Invoke.
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+ const auto mockCallback = MockFencedExecutionCallback::create();
+ EXPECT_CALL(*mockCallback, getExecutionInfo(_, _, _))
+ .Times(1)
+ .WillOnce(DoAll(SetArgPointee<0>(kNoTiming), SetArgPointee<1>(kNoTiming),
+ SetArgPointee<2>(ErrorStatus::GENERAL_FAILURE),
+ InvokeWithoutArgs(makeStatusOk)));
+ EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makeFencedExecutionResult(mockCallback)));
+
+ // run test
+ const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+
+ // verify result
+ ASSERT_TRUE(result.has_value())
+ << "Failed with " << result.error().code << ": " << result.error().message;
+ const auto& [syncFence, callback] = result.value();
+ // The driver returns the same fd -1 fence as in executeFenced, so it must be SIGNALED here
+ // too. The error is only reported through the callback.
+ EXPECT_EQ(syncFence.syncWait({}), nn::SyncFence::FenceState::SIGNALED);
+ ASSERT_NE(callback, nullptr);
+
+ // verify callback failure
+ const auto callbackResult = callback();
+ ASSERT_FALSE(callbackResult.has_value());
+ EXPECT_EQ(callbackResult.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(PreparedModelTest, executeFencedError) {
+ // The driver rejects the fenced execution with a service-specific GENERAL_FAILURE.
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+ EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ const auto launched = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+
+ ASSERT_FALSE(launched.has_value());
+ const auto& error = launched.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(PreparedModelTest, executeFencedTransportFailure) {
+ // The executeFenced binder call fails with a transport error (STATUS_NO_MEMORY).
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+ EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ const auto launched = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+
+ ASSERT_FALSE(launched.has_value());
+ const auto& error = launched.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(PreparedModelTest, executeFencedDeadObject) {
+ // The executeFenced binder call reports STATUS_DEAD_OBJECT.
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+ EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ const auto launched = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+
+ ASSERT_FALSE(launched.has_value());
+ const auto& error = launched.error();
+ EXPECT_EQ(error.code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+// TODO: test burst execution if/when it is added to nn::IPreparedModel.
+
+TEST(PreparedModelTest, getUnderlyingResource) {
+ // Wrap a mock AIDL prepared model in the canonical adapter.
+ const auto mockPreparedModel = MockPreparedModel::create();
+ const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+
+ // The adapter must expose the very same AIDL object as a std::shared_ptr<IPreparedModel>.
+ const auto resource = preparedModel->getUnderlyingResource();
+ const auto* underlying = std::any_cast<std::shared_ptr<IPreparedModel>>(&resource);
+ ASSERT_NE(underlying, nullptr);
+ EXPECT_EQ(underlying->get(), mockPreparedModel.get());
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/utils/README.md b/neuralnetworks/utils/README.md
index 45ca0b4..87b3f9f 100644
--- a/neuralnetworks/utils/README.md
+++ b/neuralnetworks/utils/README.md
@@ -49,7 +49,9 @@
(i.e., not as a nested class) or used in a subsequent version of the NN HAL. Prefer using `convert`
over `unvalidatedConvert`.
-# HIDL Interface Lifetimes across Processes
+# Interface Lifetimes across Processes
+
+## HIDL
Some notes about HIDL interface objects and lifetimes across processes:
@@ -68,7 +70,20 @@
If the process which created the HIDL interface object dies, any call on this object from another
process will result in a HIDL transport error with the code `DEAD_OBJECT`.
-# Protecting Asynchronous Calls across HIDL
+## AIDL
+
+We use the NDK backend for AIDL interfaces. Lifetimes are handled in generally the same way as for
+HIDL, with the following differences:
+* Interfaces inherit from `ndk::ICInterface`, which inherits from `ndk::SharedRefBase`. The latter
+ is an analog of `::android::RefBase` using `std::shared_ptr` for reference counting.
+* AIDL calls return `ndk::ScopedAStatus`, which wraps fields of types `binder_status_t` and
+ `binder_exception_t`. If the call is made on a dead object, it returns an
+ `ndk::ScopedAStatus` with exception `EX_TRANSACTION_FAILED` and binder status
+ `STATUS_DEAD_OBJECT`.
+
+# Protecting Asynchronous Calls
+
+## Across HIDL
Some notes about asynchronous calls across HIDL:
@@ -95,3 +110,17 @@
driver process has died, and `DeathHandler` will unblock any thread waiting on the results of an
`IProtectedCallback` callback object that may otherwise not be signaled. In order for this to work,
the `IProtectedCallback` object must have been registered via `DeathHandler::protectCallback()`.
+
+## Across AIDL
+
+We use the NDK backend for AIDL interfaces. Asynchronous calls are handled in generally the same
+way as for HIDL, with the following differences:
+* AIDL calls return `ndk::ScopedAStatus`, which wraps fields of types `binder_status_t` and
+ `binder_exception_t`. If the call is made on a dead object, it returns an
+ `ndk::ScopedAStatus` with exception `EX_TRANSACTION_FAILED` and binder status
+ `STATUS_DEAD_OBJECT`.
+* The AIDL interface does not contain an asynchronous `IPreparedModel::execute`.
+* Service death is handled using an `AIBinder_DeathRecipient` object, which is linked to an interface
+ object using `AIBinder_linkToDeath`. `nnapi/hal/aidl/ProtectCallback.h` provides a `DeathHandler`
+ object that is a direct analog of the HIDL `DeathHandler`, but is implemented with libbinder_ndk
+ objects.
diff --git a/neuralnetworks/utils/common/Android.bp b/neuralnetworks/utils/common/Android.bp
index 6162fe8..2ed1e40 100644
--- a/neuralnetworks/utils/common/Android.bp
+++ b/neuralnetworks/utils/common/Android.bp
@@ -35,8 +35,10 @@
"neuralnetworks_types",
],
shared_libs: [
+ "android.hardware.neuralnetworks-V1-ndk_platform",
"libhidlbase",
"libnativewindow",
+ "libbinder_ndk",
],
}
diff --git a/neuralnetworks/utils/common/include/nnapi/hal/CommonUtils.h b/neuralnetworks/utils/common/include/nnapi/hal/CommonUtils.h
index 2f6112a..8fe6b90 100644
--- a/neuralnetworks/utils/common/include/nnapi/hal/CommonUtils.h
+++ b/neuralnetworks/utils/common/include/nnapi/hal/CommonUtils.h
@@ -32,6 +32,8 @@
// Shorthands
namespace aidl::android::hardware::neuralnetworks {
namespace aidl_hal = ::aidl::android::hardware::neuralnetworks;
+namespace hal = ::android::hardware::neuralnetworks;  // same `hal::` shorthand as the 1.x utils (e.g. hal::V1_2)
+namespace nn = ::android::nn;  // canonical NNAPI types, e.g. nn::ErrorStatus
} // namespace aidl::android::hardware::neuralnetworks
// Shorthands
diff --git a/neuralnetworks/utils/common/include/nnapi/hal/ProtectCallback.h b/neuralnetworks/utils/common/include/nnapi/hal/ProtectCallback.h
index c921885..05110bc 100644
--- a/neuralnetworks/utils/common/include/nnapi/hal/ProtectCallback.h
+++ b/neuralnetworks/utils/common/include/nnapi/hal/ProtectCallback.h
@@ -56,7 +56,7 @@
// Thread safe class
class DeathRecipient final : public hidl_death_recipient {
public:
- void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override;
+ void serviceDied(uint64_t cookie, const wp<hidl::base::V1_0::IBase>& who) override;
// Precondition: `killable` must be non-null.
void add(IProtectedCallback* killable) const;
// Precondition: `killable` must be non-null.
@@ -64,6 +64,7 @@
private:
mutable std::mutex mMutex;
+ mutable bool mIsDeadObject GUARDED_BY(mMutex) = false;  // latched once death is reported; add() then notifies immediately
mutable std::vector<IProtectedCallback*> mObjects GUARDED_BY(mMutex);
};
@@ -78,14 +79,21 @@
~DeathHandler();
using Cleanup = std::function<void()>;
+ using Hold = base::ScopeGuard<Cleanup>;
+
// Precondition: `killable` must be non-null.
- [[nodiscard]] base::ScopeGuard<Cleanup> protectCallback(IProtectedCallback* killable) const;
+ // `killable` must outlive the return value `Hold`.
+ [[nodiscard]] Hold protectCallback(IProtectedCallback* killable) const;
+
+ // Precondition: `killable` must be non-null.
+ // `killable` must outlive the `DeathHandler`.
+ void protectCallbackForLifetimeOfDeathHandler(IProtectedCallback* killable) const;
private:
DeathHandler(sp<hidl::base::V1_0::IBase> object, sp<DeathRecipient> deathRecipient);
- sp<hidl::base::V1_0::IBase> kObject;
- sp<DeathRecipient> kDeathRecipient;
+ sp<hidl::base::V1_0::IBase> mObject;
+ sp<DeathRecipient> mDeathRecipient;
};
} // namespace android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/utils/common/src/ProtectCallback.cpp b/neuralnetworks/utils/common/src/ProtectCallback.cpp
index abe4cb6..18e1f3b 100644
--- a/neuralnetworks/utils/common/src/ProtectCallback.cpp
+++ b/neuralnetworks/utils/common/src/ProtectCallback.cpp
@@ -35,19 +35,25 @@
std::lock_guard guard(mMutex);
std::for_each(mObjects.begin(), mObjects.end(),
[](IProtectedCallback* killable) { killable->notifyAsDeadObject(); });
+ mObjects.clear();
+ mIsDeadObject = true;
}
void DeathRecipient::add(IProtectedCallback* killable) const {
CHECK(killable != nullptr);
std::lock_guard guard(mMutex);
- mObjects.push_back(killable);
+ if (mIsDeadObject) {  // death already reported: a later notification may never arrive
+ killable->notifyAsDeadObject();  // invoked with mMutex held, as on the death-notification path
+ } else {
+ mObjects.push_back(killable);
+ }
}
void DeathRecipient::remove(IProtectedCallback* killable) const {
CHECK(killable != nullptr);
std::lock_guard guard(mMutex);
- const auto removedIter = std::remove(mObjects.begin(), mObjects.end(), killable);
- mObjects.erase(removedIter);
+ const auto newEnd = std::remove(mObjects.begin(), mObjects.end(), killable);
+ mObjects.erase(newEnd, mObjects.end());  // erase the whole tail: erase(newEnd) alone drops at most one element and is UB when nothing matched
}
nn::GeneralResult<DeathHandler> DeathHandler::create(sp<hidl::base::V1_0::IBase> object) {
@@ -67,19 +73,16 @@
}
DeathHandler::DeathHandler(sp<hidl::base::V1_0::IBase> object, sp<DeathRecipient> deathRecipient)
- : kObject(std::move(object)), kDeathRecipient(std::move(deathRecipient)) {
- CHECK(kObject != nullptr);
- CHECK(kDeathRecipient != nullptr);
+ : mObject(std::move(object)), mDeathRecipient(std::move(deathRecipient)) {
+ CHECK(mObject != nullptr);  // non-null invariant; the destructor's null check presumably covers moved-from handlers
+ CHECK(mDeathRecipient != nullptr);
}
DeathHandler::~DeathHandler() {
- if (kObject != nullptr && kDeathRecipient != nullptr) {
- const auto ret = kObject->unlinkToDeath(kDeathRecipient);
- const auto maybeSuccess = handleTransportError(ret);
- if (!maybeSuccess.has_value()) {
- LOG(ERROR) << maybeSuccess.error().message;
- } else if (!maybeSuccess.value()) {
- LOG(ERROR) << "IBase::linkToDeath returned false";
+ if (mObject != nullptr && mDeathRecipient != nullptr) {  // presumably null only when moved-from; the constructor CHECKs both
+ const auto successful = mObject->unlinkToDeath(mDeathRecipient).isOk();  // transport status only; the returned bool is not inspected
+ if (!successful) {
+ LOG(ERROR) << "IBase::unlinkToDeath failed";
}
}
}
@@ -87,9 +90,14 @@
[[nodiscard]] base::ScopeGuard<DeathHandler::Cleanup> DeathHandler::protectCallback(
IProtectedCallback* killable) const {
CHECK(killable != nullptr);
- kDeathRecipient->add(killable);
+ mDeathRecipient->add(killable);  // notifies immediately if death was already reported
return base::make_scope_guard(
- [deathRecipient = kDeathRecipient, killable] { deathRecipient->remove(killable); });
+ [deathRecipient = mDeathRecipient, killable] { deathRecipient->remove(killable); });  // sp copy keeps the recipient alive for the guard
+}
+
+void DeathHandler::protectCallbackForLifetimeOfDeathHandler(IProtectedCallback* killable) const {
+ CHECK(killable != nullptr);
+ mDeathRecipient->add(killable);  // never removed: no guard is returned, so `killable` must outlive this DeathHandler
}
} // namespace android::hardware::neuralnetworks::utils