Implement full canonical Burst in NN util code

Bug: 180492058
Bug: 177267324
Test: mma
Test: presubmit
Change-Id: I5018f6cf2dbaf705f74f4f46318142c64433e19d
Merged-In: I5018f6cf2dbaf705f74f4f46318142c64433e19d
(cherry picked from commit acff4063b63c04cbb28af87eab61e9a1fa70980a)
diff --git a/neuralnetworks/1.2/utils/src/ExecutionBurstServer.cpp b/neuralnetworks/1.2/utils/src/ExecutionBurstServer.cpp
index 022548d..50af881 100644
--- a/neuralnetworks/1.2/utils/src/ExecutionBurstServer.cpp
+++ b/neuralnetworks/1.2/utils/src/ExecutionBurstServer.cpp
@@ -17,8 +17,19 @@
 #define LOG_TAG "ExecutionBurstServer"
 
 #include "ExecutionBurstServer.h"
+#include "Conversions.h"
+#include "ExecutionBurstUtils.h"
 
 #include <android-base/logging.h>
+#include <nnapi/IBurst.h>
+#include <nnapi/Result.h>
+#include <nnapi/TypeUtils.h>
+#include <nnapi/Types.h>
+#include <nnapi/Validation.h>
+#include <nnapi/hal/1.0/Conversions.h>
+#include <nnapi/hal/HandleError.h>
+#include <nnapi/hal/ProtectCallback.h>
+#include <nnapi/hal/TransferValue.h>
 
 #include <algorithm>
 #include <cstring>
@@ -29,134 +40,146 @@
 #include <utility>
 #include <vector>
 
-#include "ExecutionBurstUtils.h"
-#include "HalInterfaces.h"
 #include "Tracing.h"
 
-namespace android::nn {
+namespace android::hardware::neuralnetworks::V1_2::utils {
 namespace {
 
-// DefaultBurstExecutorWithCache adapts an IPreparedModel so that it can be
-// used as an IBurstExecutorWithCache. Specifically, the cache simply stores the
-// hidl_memory object, and the execution forwards calls to the provided
-// IPreparedModel's "executeSynchronously" method. With this class, hidl_memory
-// must be mapped and unmapped for each execution.
-class DefaultBurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
-  public:
-    DefaultBurstExecutorWithCache(V1_2::IPreparedModel* preparedModel)
-        : mpPreparedModel(preparedModel) {}
+using neuralnetworks::utils::makeExecutionFailure;
 
-    bool isCacheEntryPresent(int32_t slot) const override {
-        const auto it = mMemoryCache.find(slot);
-        return (it != mMemoryCache.end()) && it->second.valid();
+constexpr V1_2::Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
+                                    std::numeric_limits<uint64_t>::max()};  // sent on failure
+
+nn::GeneralResult<std::vector<nn::SharedMemory>> getMemoriesCallback(
+        V1_0::ErrorStatus status, const hidl_vec<hidl_memory>& memories) {
+    HANDLE_HAL_STATUS(status) << "getting burst memories failed with " << toString(status);
+    std::vector<nn::SharedMemory> canonicalMemories;  // filled in the order of memories
+    canonicalMemories.reserve(memories.size());
+    for (const auto& memory : memories) {
+        canonicalMemories.push_back(NN_TRY(nn::convert(memory)));  // HIDL -> canonical
     }
-
-    void addCacheEntry(const hardware::hidl_memory& memory, int32_t slot) override {
-        mMemoryCache[slot] = memory;
-    }
-
-    void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }
-
-    std::tuple<V1_0::ErrorStatus, hardware::hidl_vec<V1_2::OutputShape>, V1_2::Timing> execute(
-            const V1_0::Request& request, const std::vector<int32_t>& slots,
-            V1_2::MeasureTiming measure) override {
-        // convert slots to pools
-        hardware::hidl_vec<hardware::hidl_memory> pools(slots.size());
-        std::transform(slots.begin(), slots.end(), pools.begin(),
-                       [this](int32_t slot) { return mMemoryCache[slot]; });
-
-        // create full request
-        V1_0::Request fullRequest = request;
-        fullRequest.pools = std::move(pools);
-
-        // setup execution
-        V1_0::ErrorStatus returnedStatus = V1_0::ErrorStatus::GENERAL_FAILURE;
-        hardware::hidl_vec<V1_2::OutputShape> returnedOutputShapes;
-        V1_2::Timing returnedTiming;
-        auto cb = [&returnedStatus, &returnedOutputShapes, &returnedTiming](
-                          V1_0::ErrorStatus status,
-                          const hardware::hidl_vec<V1_2::OutputShape>& outputShapes,
-                          const V1_2::Timing& timing) {
-            returnedStatus = status;
-            returnedOutputShapes = outputShapes;
-            returnedTiming = timing;
-        };
-
-        // execute
-        const hardware::Return<void> ret =
-                mpPreparedModel->executeSynchronously(fullRequest, measure, cb);
-        if (!ret.isOk() || returnedStatus != V1_0::ErrorStatus::NONE) {
-            LOG(ERROR) << "IPreparedModelAdapter::execute -- Error executing";
-            return {returnedStatus, std::move(returnedOutputShapes), kNoTiming};
-        }
-
-        return std::make_tuple(returnedStatus, std::move(returnedOutputShapes), returnedTiming);
-    }
-
-  private:
-    V1_2::IPreparedModel* const mpPreparedModel;
-    std::map<int32_t, hardware::hidl_memory> mMemoryCache;
-};
+    return canonicalMemories;  // NOTE(review): assumes the macros above return early on error
+}
 
 }  // anonymous namespace
 
+ExecutionBurstServer::MemoryCache::MemoryCache(nn::SharedBurst burstExecutor,
+                                               sp<IBurstCallback> burstCallback)
+    : kBurstExecutor(std::move(burstExecutor)), kBurstCallback(std::move(burstCallback)) {
+    CHECK(kBurstExecutor != nullptr);  // used by addCacheEntryLocked() to cache memories
+    CHECK(kBurstCallback != nullptr);  // used to request memories for unknown slots
+}
+
+nn::GeneralResult<std::vector<std::pair<nn::SharedMemory, nn::IBurst::OptionalCacheHold>>>
+ExecutionBurstServer::MemoryCache::getCacheEntries(const std::vector<int32_t>& slots) {
+    std::lock_guard guard(mMutex);  // held for both the fetch and the lookups below
+    NN_TRY(ensureCacheEntriesArePresentLocked(slots));
+
+    std::vector<std::pair<nn::SharedMemory, nn::IBurst::OptionalCacheHold>> results;
+    results.reserve(slots.size());
+    for (int32_t slot : slots) {  // one result per requested slot, duplicates included
+        results.push_back(NN_TRY(getCacheEntryLocked(slot)));
+    }
+
+    return results;  // results[i] corresponds to slots[i]
+}
+
+nn::GeneralResult<void> ExecutionBurstServer::MemoryCache::ensureCacheEntriesArePresentLocked(
+        const std::vector<int32_t>& slots) {
+    const auto slotIsKnown = [this](int32_t slot)
+                                     REQUIRES(mMutex) { return mCache.count(slot) > 0; };
+
+    // reduce slots to the sorted, de-duplicated set of slots that are not yet in mCache
+    std::vector<int32_t> unknownSlots = slots;
+    std::sort(unknownSlots.begin(), unknownSlots.end());
+    auto unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlots.end());
+    unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
+    unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
+
+    // nothing to fetch: every requested slot is already cached
+    if (unknownSlots.empty()) {
+        return {};
+    }
+
+    auto cb = neuralnetworks::utils::CallbackValue(getMemoriesCallback);
+
+    const auto ret = kBurstCallback->getMemories(unknownSlots, cb);
+    HANDLE_TRANSPORT_FAILURE(ret);
+
+    auto returnedMemories = NN_TRY(cb.take());
+
+    if (returnedMemories.size() != unknownSlots.size()) {  // one memory expected per slot
+        return NN_ERROR()
+               << "ExecutionBurstServer::MemoryCache::ensureCacheEntriesArePresentLocked: Error "
+                  "retrieving memories -- count mismatch between requested memories ("
+               << unknownSlots.size() << ") and returned memories (" << returnedMemories.size()
+               << ")";
+    }
+
+    // cache returnedMemories[i] under unknownSlots[i]; assumes the callback preserves order
+    for (size_t i = 0; i < unknownSlots.size(); ++i) {
+        addCacheEntryLocked(unknownSlots[i], std::move(returnedMemories[i]));
+    }
+
+    return {};
+}
+
+nn::GeneralResult<std::pair<nn::SharedMemory, nn::IBurst::OptionalCacheHold>>
+ExecutionBurstServer::MemoryCache::getCacheEntryLocked(int32_t slot) {
+    if (const auto iter = mCache.find(slot); iter != mCache.end()) {
+        return iter->second;  // copy of the {memory, hold} pair stored for this slot
+    }
+    return NN_ERROR()  // slot is not (or no longer) in mCache
+           << "ExecutionBurstServer::MemoryCache::getCacheEntryLocked failed because slot " << slot
+           << " is not present in the cache";
+}
+
+void ExecutionBurstServer::MemoryCache::addCacheEntryLocked(int32_t slot, nn::SharedMemory memory) {
+    auto hold = kBurstExecutor->cacheMemory(memory);  // stored with the memory below
+    mCache.emplace(slot, std::make_pair(std::move(memory), std::move(hold)));
+}
+
+void ExecutionBurstServer::MemoryCache::removeCacheEntry(int32_t slot) {
+    std::lock_guard guard(mMutex);
+    mCache.erase(slot);  // drops both the memory and its cache hold for this slot
+}
+
 // ExecutionBurstServer methods
 
-sp<ExecutionBurstServer> ExecutionBurstServer::create(
+nn::GeneralResult<sp<ExecutionBurstServer>> ExecutionBurstServer::create(
         const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
-        const MQDescriptorSync<FmqResultDatum>& resultChannel,
-        std::shared_ptr<IBurstExecutorWithCache> executorWithCache,
+        const MQDescriptorSync<FmqResultDatum>& resultChannel, nn::SharedBurst burstExecutor,
         std::chrono::microseconds pollingTimeWindow) {
     // check inputs
-    if (callback == nullptr || executorWithCache == nullptr) {
-        LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
-        return nullptr;
+    if (callback == nullptr || burstExecutor == nullptr) {
+        return NN_ERROR() << "ExecutionBurstServer::create passed a nullptr";
     }
 
     // create FMQ objects
-    std::unique_ptr<RequestChannelReceiver> requestChannelReceiver =
-            RequestChannelReceiver::create(requestChannel, pollingTimeWindow);
-    std::unique_ptr<ResultChannelSender> resultChannelSender =
-            ResultChannelSender::create(resultChannel);
+    auto requestChannelReceiver =
+            NN_TRY(RequestChannelReceiver::create(requestChannel, pollingTimeWindow));
+    auto resultChannelSender = NN_TRY(ResultChannelSender::create(resultChannel));
 
     // check FMQ objects
-    if (!requestChannelReceiver || !resultChannelSender) {
-        LOG(ERROR) << "ExecutionBurstServer::create failed to create FastMessageQueue";
-        return nullptr;
-    }
+    CHECK(requestChannelReceiver != nullptr);  // success is assumed to imply non-null
+    CHECK(resultChannelSender != nullptr);  // success is assumed to imply non-null
 
     // make and return context
-    return new ExecutionBurstServer(callback, std::move(requestChannelReceiver),
-                                    std::move(resultChannelSender), std::move(executorWithCache));
+    return sp<ExecutionBurstServer>::make(PrivateConstructorTag{}, callback,
+                                          std::move(requestChannelReceiver),
+                                          std::move(resultChannelSender), std::move(burstExecutor));
 }
 
-sp<ExecutionBurstServer> ExecutionBurstServer::create(
-        const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
-        const MQDescriptorSync<FmqResultDatum>& resultChannel, V1_2::IPreparedModel* preparedModel,
-        std::chrono::microseconds pollingTimeWindow) {
-    // check relevant input
-    if (preparedModel == nullptr) {
-        LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
-        return nullptr;
-    }
-
-    // adapt IPreparedModel to have caching
-    const std::shared_ptr<DefaultBurstExecutorWithCache> preparedModelAdapter =
-            std::make_shared<DefaultBurstExecutorWithCache>(preparedModel);
-
-    // make and return context
-    return ExecutionBurstServer::create(callback, requestChannel, resultChannel,
-                                        preparedModelAdapter, pollingTimeWindow);
-}
-
-ExecutionBurstServer::ExecutionBurstServer(
-        const sp<IBurstCallback>& callback, std::unique_ptr<RequestChannelReceiver> requestChannel,
-        std::unique_ptr<ResultChannelSender> resultChannel,
-        std::shared_ptr<IBurstExecutorWithCache> executorWithCache)
+ExecutionBurstServer::ExecutionBurstServer(PrivateConstructorTag /*tag*/,
+                                           const sp<IBurstCallback>& callback,
+                                           std::unique_ptr<RequestChannelReceiver> requestChannel,
+                                           std::unique_ptr<ResultChannelSender> resultChannel,
+                                           nn::SharedBurst burstExecutor)
     : mCallback(callback),
       mRequestChannelReceiver(std::move(requestChannel)),
       mResultChannelSender(std::move(resultChannel)),
-      mExecutorWithCache(std::move(executorWithCache)) {
+      mBurstExecutor(std::move(burstExecutor)),
+      mMemoryCache(mBurstExecutor, mCallback) {  // members used here must be declared first
     // TODO: highly document the threading behavior of this class
     mWorker = std::thread([this] { task(); });
 }
@@ -170,51 +193,9 @@
     mWorker.join();
 }
 
-hardware::Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
-    std::lock_guard<std::mutex> hold(mMutex);
-    mExecutorWithCache->removeCacheEntry(slot);
-    return hardware::Void();
-}
-
-void ExecutionBurstServer::ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots) {
-    const auto slotIsKnown = [this](int32_t slot) {
-        return mExecutorWithCache->isCacheEntryPresent(slot);
-    };
-
-    // find unique unknown slots
-    std::vector<int32_t> unknownSlots = slots;
-    auto unknownSlotsEnd = unknownSlots.end();
-    std::sort(unknownSlots.begin(), unknownSlotsEnd);
-    unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
-    unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
-    unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
-
-    // quick-exit if all slots are known
-    if (unknownSlots.empty()) {
-        return;
-    }
-
-    V1_0::ErrorStatus errorStatus = V1_0::ErrorStatus::GENERAL_FAILURE;
-    std::vector<hardware::hidl_memory> returnedMemories;
-    auto cb = [&errorStatus, &returnedMemories](
-                      V1_0::ErrorStatus status,
-                      const hardware::hidl_vec<hardware::hidl_memory>& memories) {
-        errorStatus = status;
-        returnedMemories = memories;
-    };
-
-    const hardware::Return<void> ret = mCallback->getMemories(unknownSlots, cb);
-
-    if (!ret.isOk() || errorStatus != V1_0::ErrorStatus::NONE ||
-        returnedMemories.size() != unknownSlots.size()) {
-        LOG(ERROR) << "Error retrieving memories";
-        return;
-    }
-
-    // add memories to unknown slots
-    for (size_t i = 0; i < unknownSlots.size(); ++i) {
-        mExecutorWithCache->addCacheEntry(returnedMemories[i], unknownSlots[i]);
-    }
+Return<void> ExecutionBurstServer::freeMemory(int32_t slot) {
+    mMemoryCache.removeCacheEntry(slot);  // MemoryCache takes its own lock
+    return Void();
 }
 
 void ExecutionBurstServer::task() {
@@ -223,38 +204,65 @@
         // receive request
         auto arguments = mRequestChannelReceiver->getBlocking();
 
-        // if the request packet was not properly received, return a generic
-        // error and skip the execution
+        // if the request packet was not properly received, return a generic error and skip the
+        // execution
         //
-        // if the burst is being torn down, skip the execution so the "task"
-        // function can end
-        if (!arguments) {
+        // if the burst is being torn down, skip the execution so the "task" function can end
+        if (!arguments.has_value()) {
             if (!mTeardown) {
                 mResultChannelSender->send(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
             }
             continue;
         }
 
-        // otherwise begin tracing execution
-        NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
-                     "ExecutionBurstServer getting memory, executing, and returning results");
+        // unpack the arguments; types are Request, std::vector<int32_t>, and MeasureTiming,
+        // respectively
+        const auto [requestWithoutPools, slotsOfPools, measure] = std::move(arguments).value();
 
-        // unpack the arguments; types are Request, std::vector<int32_t>, and
-        // MeasureTiming, respectively
-        const auto [requestWithoutPools, slotsOfPools, measure] = std::move(*arguments);
-
-        // ensure executor with cache has required memory
-        std::lock_guard<std::mutex> hold(mMutex);
-        ensureCacheEntriesArePresentLocked(slotsOfPools);
-
-        // perform computation; types are ErrorStatus, hidl_vec<OutputShape>,
-        // and Timing, respectively
-        const auto [errorStatus, outputShapes, returnedTiming] =
-                mExecutorWithCache->execute(requestWithoutPools, slotsOfPools, measure);
+        auto result = execute(requestWithoutPools, slotsOfPools, measure);
 
         // return result
-        mResultChannelSender->send(errorStatus, outputShapes, returnedTiming);
+        if (result.has_value()) {
+            const auto& [outputShapes, timing] = result.value();
+            mResultChannelSender->send(V1_0::ErrorStatus::NONE, outputShapes, timing);
+        } else {
+            const auto& [message, code, outputShapes] = result.error();
+            LOG(ERROR) << "IBurst::execute failed with " << code << ": " << message;
+            mResultChannelSender->send(convert(code).value(), convert(outputShapes).value(),
+                                       kNoTiming);
+        }
     }
 }
 
-}  // namespace android::nn
+nn::ExecutionResult<std::pair<hidl_vec<OutputShape>, Timing>> ExecutionBurstServer::execute(
+        const V1_0::Request& requestWithoutPools, const std::vector<int32_t>& slotsOfPools,
+        MeasureTiming measure) {
+    NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
+                 "ExecutionBurstServer getting memory, executing, and returning results");
+
+    // fetch the cached memory (and its cache hold) for each slot, loading unknown slots first
+    const auto cacheEntries =
+            NN_TRY(makeExecutionFailure(mMemoryCache.getCacheEntries(slotsOfPools)));
+
+    // Convert the request and populate its pools.
+    // The request is converted without validation because, without its pools, it is incomplete
+    // and therefore invalid. Validation is performed below, once the memory pools have been
+    // added to the request.
+    auto canonicalRequest =
+            NN_TRY(makeExecutionFailure(nn::unvalidatedConvert(requestWithoutPools)));
+    CHECK(canonicalRequest.pools.empty());  // the request must arrive without pools
+    std::transform(cacheEntries.begin(), cacheEntries.end(),
+                   std::back_inserter(canonicalRequest.pools),
+                   [](const auto& cacheEntry) { return cacheEntry.first; });  // the memory
+    NN_TRY(makeExecutionFailure(validate(canonicalRequest)));
+
+    nn::MeasureTiming canonicalMeasure = NN_TRY(makeExecutionFailure(nn::convert(measure)));
+
+    const auto [outputShapes, timing] =  // canonical types; converted back to HIDL below
+            NN_TRY(mBurstExecutor->execute(canonicalRequest, canonicalMeasure));
+
+    return std::make_pair(NN_TRY(makeExecutionFailure(convert(outputShapes))),
+                          NN_TRY(makeExecutionFailure(convert(timing))));
+}
+
+}  // namespace android::hardware::neuralnetworks::V1_2::utils