Merge "Add recovery code to NN ResilientPreparedModel and *Buffer"
diff --git a/neuralnetworks/utils/common/include/nnapi/hal/ResilientBuffer.h b/neuralnetworks/utils/common/include/nnapi/hal/ResilientBuffer.h
index 9d5e3e6..d2c2469 100644
--- a/neuralnetworks/utils/common/include/nnapi/hal/ResilientBuffer.h
+++ b/neuralnetworks/utils/common/include/nnapi/hal/ResilientBuffer.h
@@ -42,7 +42,7 @@
                              nn::SharedBuffer buffer);
 
     nn::SharedBuffer getBuffer() const;
-    nn::SharedBuffer recover(const nn::IBuffer* failingBuffer, bool blocking) const;
+    nn::GeneralResult<nn::SharedBuffer> recover(const nn::IBuffer* failingBuffer) const;
 
     nn::Request::MemoryDomainToken getToken() const override;
 
diff --git a/neuralnetworks/utils/common/include/nnapi/hal/ResilientPreparedModel.h b/neuralnetworks/utils/common/include/nnapi/hal/ResilientPreparedModel.h
index faae673..9b8d924 100644
--- a/neuralnetworks/utils/common/include/nnapi/hal/ResilientPreparedModel.h
+++ b/neuralnetworks/utils/common/include/nnapi/hal/ResilientPreparedModel.h
@@ -43,8 +43,8 @@
                                     nn::SharedPreparedModel preparedModel);
 
     nn::SharedPreparedModel getPreparedModel() const;
-    nn::SharedPreparedModel recover(const nn::IPreparedModel* failingPreparedModel,
-                                    bool blocking) const;
+    nn::GeneralResult<nn::SharedPreparedModel> recover(
+            const nn::IPreparedModel* failingPreparedModel) const;
 
     nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
             const nn::Request& request, nn::MeasureTiming measure,
diff --git a/neuralnetworks/utils/common/src/ResilientBuffer.cpp b/neuralnetworks/utils/common/src/ResilientBuffer.cpp
index cf5496a..47abbe2 100644
--- a/neuralnetworks/utils/common/src/ResilientBuffer.cpp
+++ b/neuralnetworks/utils/common/src/ResilientBuffer.cpp
@@ -20,6 +20,7 @@
 #include <android-base/thread_annotations.h>
 #include <nnapi/IBuffer.h>
 #include <nnapi/Result.h>
+#include <nnapi/TypeUtils.h>
 #include <nnapi/Types.h>
 
 #include <functional>
@@ -29,6 +30,34 @@
 #include <vector>
 
 namespace android::hardware::neuralnetworks::utils {
+namespace {
+
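+// Invokes `fn` on the buffer held by `resilientBuffer`. If the call fails with
+// DEAD_OBJECT, recovers the buffer and retries the call exactly once.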
+template <typename FnType>
+auto protect(const ResilientBuffer& resilientBuffer, const FnType& fn)
+        -> decltype(fn(*resilientBuffer.getBuffer())) {
+    auto buffer = resilientBuffer.getBuffer();
+    auto result = fn(*buffer);
+
+    // Immediately return if the buffer is not dead.
+    if (result.has_value() || result.error().code != nn::ErrorStatus::DEAD_OBJECT) {
+        return result;
+    }
+
+    // Attempt recovery; if it also fails, return both errors combined.
+    auto maybeBuffer = resilientBuffer.recover(buffer.get());
+    if (!maybeBuffer.has_value()) {
+        const auto& [resultErrorMessage, resultErrorCode] = result.error();
+        const auto& [recoveryErrorMessage, recoveryErrorCode] = maybeBuffer.error();
+        return nn::error(resultErrorCode)
+               << resultErrorMessage << ", and failed to recover dead buffer with error "
+               << recoveryErrorCode << ": " << recoveryErrorMessage;
+    }
+    buffer = std::move(maybeBuffer).value();
+
+    return fn(*buffer);
+}
+
+}  // namespace
 
 nn::GeneralResult<std::shared_ptr<const ResilientBuffer>> ResilientBuffer::create(
         Factory makeBuffer) {
@@ -53,9 +82,16 @@
     std::lock_guard guard(mMutex);
     return mBuffer;
 }
-nn::SharedBuffer ResilientBuffer::recover(const nn::IBuffer* /*failingBuffer*/,
-                                          bool /*blocking*/) const {
+nn::GeneralResult<nn::SharedBuffer> ResilientBuffer::recover(
+        const nn::IBuffer* failingBuffer) const {
     std::lock_guard guard(mMutex);
+
+    // Another caller already replaced the failing buffer.
+    if (mBuffer.get() != failingBuffer) {
+        return mBuffer;
+    }
+
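+    // Recreate the buffer and cache it for subsequent callers.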
+    mBuffer = NN_TRY(kMakeBuffer());
     return mBuffer;
 }
 
@@ -64,12 +100,16 @@
 }
 
 nn::GeneralResult<void> ResilientBuffer::copyTo(const nn::Memory& dst) const {
-    return getBuffer()->copyTo(dst);
+    const auto fn = [&dst](const nn::IBuffer& buffer) { return buffer.copyTo(dst); };
+    return protect(*this, fn);
 }
 
 nn::GeneralResult<void> ResilientBuffer::copyFrom(const nn::Memory& src,
                                                   const nn::Dimensions& dimensions) const {
-    return getBuffer()->copyFrom(src, dimensions);
+    const auto fn = [&src, &dimensions](const nn::IBuffer& buffer) {
+        return buffer.copyFrom(src, dimensions);
+    };
+    return protect(*this, fn);
 }
 
 }  // namespace android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/utils/common/src/ResilientDevice.cpp b/neuralnetworks/utils/common/src/ResilientDevice.cpp
index 6ad3fad..2023c9a 100644
--- a/neuralnetworks/utils/common/src/ResilientDevice.cpp
+++ b/neuralnetworks/utils/common/src/ResilientDevice.cpp
@@ -180,6 +180,7 @@
         const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
         nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
         const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
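+    // Resilient wrapper creation is disabled for now; the prepared model is
+    // returned unwrapped.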
+#if 0
     auto self = shared_from_this();
     ResilientPreparedModel::Factory makePreparedModel = [device = std::move(self), model,
                                                          preference, priority, deadline, modelCache,
@@ -188,29 +189,41 @@
                                             dataCache, token);
     };
     return ResilientPreparedModel::create(std::move(makePreparedModel));
+#else
+    return prepareModelInternal(model, preference, priority, deadline, modelCache, dataCache,
+                                token);
+#endif
 }
 
 nn::GeneralResult<nn::SharedPreparedModel> ResilientDevice::prepareModelFromCache(
         nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
         const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
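+    // As above, wrapper creation is disabled for now.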
+#if 0
     auto self = shared_from_this();
     ResilientPreparedModel::Factory makePreparedModel = [device = std::move(self), deadline,
                                                          modelCache, dataCache, token] {
         return device->prepareModelFromCacheInternal(deadline, modelCache, dataCache, token);
     };
     return ResilientPreparedModel::create(std::move(makePreparedModel));
+#else
+    return prepareModelFromCacheInternal(deadline, modelCache, dataCache, token);
+#endif
 }
 
 nn::GeneralResult<nn::SharedBuffer> ResilientDevice::allocate(
         const nn::BufferDesc& desc, const std::vector<nn::SharedPreparedModel>& preparedModels,
         const std::vector<nn::BufferRole>& inputRoles,
         const std::vector<nn::BufferRole>& outputRoles) const {
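+    // As above, the buffer is returned unwrapped for now.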
+#if 0
     auto self = shared_from_this();
     ResilientBuffer::Factory makeBuffer = [device = std::move(self), desc, preparedModels,
                                            inputRoles, outputRoles] {
         return device->allocateInternal(desc, preparedModels, inputRoles, outputRoles);
     };
     return ResilientBuffer::create(std::move(makeBuffer));
+#else
+    return allocateInternal(desc, preparedModels, inputRoles, outputRoles);
+#endif
 }
 
 bool ResilientDevice::isValidInternal() const {
@@ -225,8 +238,8 @@
     if (!isValidInternal()) {
         return std::make_shared<const InvalidPreparedModel>();
     }
-    const auto fn = [&model, preference, priority, deadline, &modelCache, &dataCache,
-                     token](const nn::IDevice& device) {
+    const auto fn = [&model, preference, priority, &deadline, &modelCache, &dataCache,
+                     &token](const nn::IDevice& device) {
         return device.prepareModel(model, preference, priority, deadline, modelCache, dataCache,
                                    token);
     };
@@ -239,7 +252,7 @@
     if (!isValidInternal()) {
         return std::make_shared<const InvalidPreparedModel>();
     }
-    const auto fn = [deadline, &modelCache, &dataCache, token](const nn::IDevice& device) {
+    const auto fn = [&deadline, &modelCache, &dataCache, &token](const nn::IDevice& device) {
         return device.prepareModelFromCache(deadline, modelCache, dataCache, token);
     };
     return protect(*this, fn, /*blocking=*/false);
diff --git a/neuralnetworks/utils/common/src/ResilientPreparedModel.cpp b/neuralnetworks/utils/common/src/ResilientPreparedModel.cpp
index b8acee1..667df2b 100644
--- a/neuralnetworks/utils/common/src/ResilientPreparedModel.cpp
+++ b/neuralnetworks/utils/common/src/ResilientPreparedModel.cpp
@@ -20,15 +20,45 @@
 #include <android-base/thread_annotations.h>
 #include <nnapi/IPreparedModel.h>
 #include <nnapi/Result.h>
+#include <nnapi/TypeUtils.h>
 #include <nnapi/Types.h>
 
 #include <functional>
 #include <memory>
 #include <mutex>
+#include <sstream>
 #include <utility>
 #include <vector>
 
 namespace android::hardware::neuralnetworks::utils {
+namespace {
+
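+// Invokes `fn` on the prepared model held by `resilientPreparedModel`. If the
+// call fails with DEAD_OBJECT, recovers the prepared model and retries the
+// call exactly once.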
+template <typename FnType>
+auto protect(const ResilientPreparedModel& resilientPreparedModel, const FnType& fn)
+        -> decltype(fn(*resilientPreparedModel.getPreparedModel())) {
+    auto preparedModel = resilientPreparedModel.getPreparedModel();
+    auto result = fn(*preparedModel);
+
+    // Immediately return if the prepared model is not dead.
+    if (result.has_value() || result.error().code != nn::ErrorStatus::DEAD_OBJECT) {
+        return result;
+    }
+
+    // Attempt recovery; if it also fails, append its error to the original one.
+    auto maybePreparedModel = resilientPreparedModel.recover(preparedModel.get());
+    if (!maybePreparedModel.has_value()) {
+        const auto& [message, code] = maybePreparedModel.error();
+        std::ostringstream oss;
+        oss << ", and failed to recover dead prepared model with error " << code << ": " << message;
+        result.error().message += oss.str();
+        return result;
+    }
+    preparedModel = std::move(maybePreparedModel).value();
+
+    return fn(*preparedModel);
+}
+
+}  // namespace
 
 nn::GeneralResult<std::shared_ptr<const ResilientPreparedModel>> ResilientPreparedModel::create(
         Factory makePreparedModel) {
@@ -55,9 +85,16 @@
     return mPreparedModel;
 }
 
-nn::SharedPreparedModel ResilientPreparedModel::recover(
-        const nn::IPreparedModel* /*failingPreparedModel*/, bool /*blocking*/) const {
+nn::GeneralResult<nn::SharedPreparedModel> ResilientPreparedModel::recover(
+        const nn::IPreparedModel* failingPreparedModel) const {
     std::lock_guard guard(mMutex);
+
+    // Another caller already replaced the failing prepared model.
+    if (mPreparedModel.get() != failingPreparedModel) {
+        return mPreparedModel;
+    }
+
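+    // Recreate the prepared model and cache it for subsequent callers.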
+    mPreparedModel = NN_TRY(kMakePreparedModel());
     return mPreparedModel;
 }
 
@@ -65,7 +102,11 @@
 ResilientPreparedModel::execute(const nn::Request& request, nn::MeasureTiming measure,
                                 const nn::OptionalTimePoint& deadline,
                                 const nn::OptionalDuration& loopTimeoutDuration) const {
-    return getPreparedModel()->execute(request, measure, deadline, loopTimeoutDuration);
+    const auto fn = [&request, measure, &deadline,
+                     &loopTimeoutDuration](const nn::IPreparedModel& preparedModel) {
+        return preparedModel.execute(request, measure, deadline, loopTimeoutDuration);
+    };
+    return protect(*this, fn);
 }
 
 nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
@@ -75,8 +116,12 @@
                                       const nn::OptionalTimePoint& deadline,
                                       const nn::OptionalDuration& loopTimeoutDuration,
                                       const nn::OptionalDuration& timeoutDurationAfterFence) const {
-    return getPreparedModel()->executeFenced(request, waitFor, measure, deadline,
-                                             loopTimeoutDuration, timeoutDurationAfterFence);
+    const auto fn = [&request, &waitFor, measure, &deadline, &loopTimeoutDuration,
+                     &timeoutDurationAfterFence](const nn::IPreparedModel& preparedModel) {
+        return preparedModel.executeFenced(request, waitFor, measure, deadline, loopTimeoutDuration,
+                                           timeoutDurationAfterFence);
+    };
+    return protect(*this, fn);
 }
 
 std::any ResilientPreparedModel::getUnderlyingResource() const {