Add recovery code to NN ResilientPreparedModel and *Buffer

Prior to this CL, ResilientPreparedModel and ResilientBuffer were
passthrough interfaces that just forwarded calls to the underlying
interface object. This CL implements the full recovery mechanism for
these two classes. However, because we do not want to enable this
functionality in the NN runtime yet, ResilientDevice hides the paths
that create ResilientPreparedModel and ResilientBuffer behind an #if
until we want to enable those paths.

Bug: N/A
Test: mma
Change-Id: Idfe8093c63c7ba2f16c995eec872d150696e7a08
diff --git a/neuralnetworks/utils/common/include/nnapi/hal/ResilientBuffer.h b/neuralnetworks/utils/common/include/nnapi/hal/ResilientBuffer.h
index 9d5e3e6..d2c2469 100644
--- a/neuralnetworks/utils/common/include/nnapi/hal/ResilientBuffer.h
+++ b/neuralnetworks/utils/common/include/nnapi/hal/ResilientBuffer.h
@@ -42,7 +42,7 @@
                              nn::SharedBuffer buffer);
 
     nn::SharedBuffer getBuffer() const;
-    nn::SharedBuffer recover(const nn::IBuffer* failingBuffer, bool blocking) const;
+    nn::GeneralResult<nn::SharedBuffer> recover(const nn::IBuffer* failingBuffer) const;
 
     nn::Request::MemoryDomainToken getToken() const override;
 
diff --git a/neuralnetworks/utils/common/include/nnapi/hal/ResilientPreparedModel.h b/neuralnetworks/utils/common/include/nnapi/hal/ResilientPreparedModel.h
index faae673..9b8d924 100644
--- a/neuralnetworks/utils/common/include/nnapi/hal/ResilientPreparedModel.h
+++ b/neuralnetworks/utils/common/include/nnapi/hal/ResilientPreparedModel.h
@@ -43,8 +43,8 @@
                                     nn::SharedPreparedModel preparedModel);
 
     nn::SharedPreparedModel getPreparedModel() const;
-    nn::SharedPreparedModel recover(const nn::IPreparedModel* failingPreparedModel,
-                                    bool blocking) const;
+    nn::GeneralResult<nn::SharedPreparedModel> recover(
+            const nn::IPreparedModel* failingPreparedModel) const;
 
     nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
             const nn::Request& request, nn::MeasureTiming measure,
diff --git a/neuralnetworks/utils/common/src/ResilientBuffer.cpp b/neuralnetworks/utils/common/src/ResilientBuffer.cpp
index cf5496a..47abbe2 100644
--- a/neuralnetworks/utils/common/src/ResilientBuffer.cpp
+++ b/neuralnetworks/utils/common/src/ResilientBuffer.cpp
@@ -20,6 +20,7 @@
 #include <android-base/thread_annotations.h>
 #include <nnapi/IBuffer.h>
 #include <nnapi/Result.h>
+#include <nnapi/TypeUtils.h>
 #include <nnapi/Types.h>
 
 #include <functional>
@@ -29,6 +30,34 @@
 #include <vector>
 
 namespace android::hardware::neuralnetworks::utils {
+namespace {
+
+template <typename FnType>
+auto protect(const ResilientBuffer& resilientBuffer, const FnType& fn)
+        -> decltype(fn(*resilientBuffer.getBuffer())) {
+    auto buffer = resilientBuffer.getBuffer();
+    auto result = fn(*buffer);
+
+    // Immediately return if device is not dead.
+    if (result.has_value() || result.error().code != nn::ErrorStatus::DEAD_OBJECT) {
+        return result;
+    }
+
+    // Attempt recovery and return if it fails.
+    auto maybeBuffer = resilientBuffer.recover(buffer.get());
+    if (!maybeBuffer.has_value()) {
+        const auto& [resultErrorMessage, resultErrorCode] = result.error();
+        const auto& [recoveryErrorMessage, recoveryErrorCode] = maybeBuffer.error();
+        return nn::error(resultErrorCode)
+               << resultErrorMessage << ", and failed to recover dead buffer with error "
+               << recoveryErrorCode << ": " << recoveryErrorMessage;
+    }
+    buffer = std::move(maybeBuffer).value();
+
+    return fn(*buffer);
+}
+
+}  // namespace
 
 nn::GeneralResult<std::shared_ptr<const ResilientBuffer>> ResilientBuffer::create(
         Factory makeBuffer) {
@@ -53,9 +82,16 @@
     std::lock_guard guard(mMutex);
     return mBuffer;
 }
-nn::SharedBuffer ResilientBuffer::recover(const nn::IBuffer* /*failingBuffer*/,
-                                          bool /*blocking*/) const {
+nn::GeneralResult<nn::SharedBuffer> ResilientBuffer::recover(
+        const nn::IBuffer* failingBuffer) const {
     std::lock_guard guard(mMutex);
+
+    // Another caller updated the failing prepared model.
+    if (mBuffer.get() != failingBuffer) {
+        return mBuffer;
+    }
+
+    mBuffer = NN_TRY(kMakeBuffer());
     return mBuffer;
 }
 
@@ -64,12 +100,16 @@
 }
 
 nn::GeneralResult<void> ResilientBuffer::copyTo(const nn::Memory& dst) const {
-    return getBuffer()->copyTo(dst);
+    const auto fn = [&dst](const nn::IBuffer& buffer) { return buffer.copyTo(dst); };
+    return protect(*this, fn);
 }
 
 nn::GeneralResult<void> ResilientBuffer::copyFrom(const nn::Memory& src,
                                                   const nn::Dimensions& dimensions) const {
-    return getBuffer()->copyFrom(src, dimensions);
+    const auto fn = [&src, &dimensions](const nn::IBuffer& buffer) {
+        return buffer.copyFrom(src, dimensions);
+    };
+    return protect(*this, fn);
 }
 
 }  // namespace android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/utils/common/src/ResilientDevice.cpp b/neuralnetworks/utils/common/src/ResilientDevice.cpp
index 6ad3fad..2023c9a 100644
--- a/neuralnetworks/utils/common/src/ResilientDevice.cpp
+++ b/neuralnetworks/utils/common/src/ResilientDevice.cpp
@@ -180,6 +180,7 @@
         const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
         nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
         const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
+#if 0
     auto self = shared_from_this();
     ResilientPreparedModel::Factory makePreparedModel = [device = std::move(self), model,
                                                          preference, priority, deadline, modelCache,
@@ -188,29 +189,41 @@
                                             dataCache, token);
     };
     return ResilientPreparedModel::create(std::move(makePreparedModel));
+#else
+    return prepareModelInternal(model, preference, priority, deadline, modelCache, dataCache,
+                                token);
+#endif
 }
 
 nn::GeneralResult<nn::SharedPreparedModel> ResilientDevice::prepareModelFromCache(
         nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
         const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
+#if 0
     auto self = shared_from_this();
     ResilientPreparedModel::Factory makePreparedModel = [device = std::move(self), deadline,
                                                          modelCache, dataCache, token] {
         return device->prepareModelFromCacheInternal(deadline, modelCache, dataCache, token);
     };
     return ResilientPreparedModel::create(std::move(makePreparedModel));
+#else
+    return prepareModelFromCacheInternal(deadline, modelCache, dataCache, token);
+#endif
 }
 
 nn::GeneralResult<nn::SharedBuffer> ResilientDevice::allocate(
         const nn::BufferDesc& desc, const std::vector<nn::SharedPreparedModel>& preparedModels,
         const std::vector<nn::BufferRole>& inputRoles,
         const std::vector<nn::BufferRole>& outputRoles) const {
+#if 0
     auto self = shared_from_this();
     ResilientBuffer::Factory makeBuffer = [device = std::move(self), desc, preparedModels,
                                            inputRoles, outputRoles] {
         return device->allocateInternal(desc, preparedModels, inputRoles, outputRoles);
     };
     return ResilientBuffer::create(std::move(makeBuffer));
+#else
+    return allocateInternal(desc, preparedModels, inputRoles, outputRoles);
+#endif
 }
 
 bool ResilientDevice::isValidInternal() const {
@@ -225,8 +238,8 @@
     if (!isValidInternal()) {
         return std::make_shared<const InvalidPreparedModel>();
     }
-    const auto fn = [&model, preference, priority, deadline, &modelCache, &dataCache,
-                     token](const nn::IDevice& device) {
+    const auto fn = [&model, preference, priority, &deadline, &modelCache, &dataCache,
+                     &token](const nn::IDevice& device) {
         return device.prepareModel(model, preference, priority, deadline, modelCache, dataCache,
                                    token);
     };
@@ -239,7 +252,7 @@
     if (!isValidInternal()) {
         return std::make_shared<const InvalidPreparedModel>();
     }
-    const auto fn = [deadline, &modelCache, &dataCache, token](const nn::IDevice& device) {
+    const auto fn = [&deadline, &modelCache, &dataCache, &token](const nn::IDevice& device) {
         return device.prepareModelFromCache(deadline, modelCache, dataCache, token);
     };
     return protect(*this, fn, /*blocking=*/false);
diff --git a/neuralnetworks/utils/common/src/ResilientPreparedModel.cpp b/neuralnetworks/utils/common/src/ResilientPreparedModel.cpp
index b8acee1..667df2b 100644
--- a/neuralnetworks/utils/common/src/ResilientPreparedModel.cpp
+++ b/neuralnetworks/utils/common/src/ResilientPreparedModel.cpp
@@ -20,15 +20,45 @@
 #include <android-base/thread_annotations.h>
 #include <nnapi/IPreparedModel.h>
 #include <nnapi/Result.h>
+#include <nnapi/TypeUtils.h>
 #include <nnapi/Types.h>
 
 #include <functional>
 #include <memory>
 #include <mutex>
+#include <sstream>
 #include <utility>
 #include <vector>
 
 namespace android::hardware::neuralnetworks::utils {
+namespace {
+
+template <typename FnType>
+auto protect(const ResilientPreparedModel& resilientPreparedModel, const FnType& fn)
+        -> decltype(fn(*resilientPreparedModel.getPreparedModel())) {
+    auto preparedModel = resilientPreparedModel.getPreparedModel();
+    auto result = fn(*preparedModel);
+
+    // Immediately return if prepared model is not dead.
+    if (result.has_value() || result.error().code != nn::ErrorStatus::DEAD_OBJECT) {
+        return result;
+    }
+
+    // Attempt recovery and return if it fails.
+    auto maybePreparedModel = resilientPreparedModel.recover(preparedModel.get());
+    if (!maybePreparedModel.has_value()) {
+        const auto& [message, code] = maybePreparedModel.error();
+        std::ostringstream oss;
+        oss << ", and failed to recover dead prepared model with error " << code << ": " << message;
+        result.error().message += oss.str();
+        return result;
+    }
+    preparedModel = std::move(maybePreparedModel).value();
+
+    return fn(*preparedModel);
+}
+
+}  // namespace
 
 nn::GeneralResult<std::shared_ptr<const ResilientPreparedModel>> ResilientPreparedModel::create(
         Factory makePreparedModel) {
@@ -55,9 +85,16 @@
     return mPreparedModel;
 }
 
-nn::SharedPreparedModel ResilientPreparedModel::recover(
-        const nn::IPreparedModel* /*failingPreparedModel*/, bool /*blocking*/) const {
+nn::GeneralResult<nn::SharedPreparedModel> ResilientPreparedModel::recover(
+        const nn::IPreparedModel* failingPreparedModel) const {
     std::lock_guard guard(mMutex);
+
+    // Another caller updated the failing prepared model.
+    if (mPreparedModel.get() != failingPreparedModel) {
+        return mPreparedModel;
+    }
+
+    mPreparedModel = NN_TRY(kMakePreparedModel());
     return mPreparedModel;
 }
 
@@ -65,7 +102,11 @@
 ResilientPreparedModel::execute(const nn::Request& request, nn::MeasureTiming measure,
                                 const nn::OptionalTimePoint& deadline,
                                 const nn::OptionalDuration& loopTimeoutDuration) const {
-    return getPreparedModel()->execute(request, measure, deadline, loopTimeoutDuration);
+    const auto fn = [&request, measure, &deadline,
+                     &loopTimeoutDuration](const nn::IPreparedModel& preparedModel) {
+        return preparedModel.execute(request, measure, deadline, loopTimeoutDuration);
+    };
+    return protect(*this, fn);
 }
 
 nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
@@ -75,8 +116,12 @@
                                       const nn::OptionalTimePoint& deadline,
                                       const nn::OptionalDuration& loopTimeoutDuration,
                                       const nn::OptionalDuration& timeoutDurationAfterFence) const {
-    return getPreparedModel()->executeFenced(request, waitFor, measure, deadline,
-                                             loopTimeoutDuration, timeoutDurationAfterFence);
+    const auto fn = [&request, &waitFor, measure, &deadline, &loopTimeoutDuration,
+                     &timeoutDurationAfterFence](const nn::IPreparedModel& preparedModel) {
+        return preparedModel.executeFenced(request, waitFor, measure, deadline, loopTimeoutDuration,
+                                           timeoutDurationAfterFence);
+    };
+    return protect(*this, fn);
 }
 
 std::any ResilientPreparedModel::getUnderlyingResource() const {