Update NN VTS callback objects

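The VTS Callback files are a subset of the Callback files in
frameworks/ml/nn/runtime/Callbacks.*. This CL syncs the VTS
implementations with the runtime ones, removing functionality that is
not needed in VTS (CallbackBase with its on_finish/bind_thread/
join_thread support) and rejecting inconsistent notify_1_2 results.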

Fixes: 132322149
Test: mma
Test: VtsHalNeuralnetworksV1_0TargetTest
Test: VtsHalNeuralnetworksV1_1TargetTest
Test: VtsHalNeuralnetworksV1_2TargetTest
Change-Id: I114ce7f3b6c3d58de0196e9508209614d0a73e11
diff --git a/neuralnetworks/1.2/vts/functional/Callbacks.cpp b/neuralnetworks/1.2/vts/functional/Callbacks.cpp
index cfaf91d..a607a08 100644
--- a/neuralnetworks/1.2/vts/functional/Callbacks.cpp
+++ b/neuralnetworks/1.2/vts/functional/Callbacks.cpp
@@ -14,160 +14,128 @@
  * limitations under the License.
  */
 
+#define LOG_TAG "Callbacks"
+
 #include "1.2/Callbacks.h"
+
 #include <android-base/logging.h>
 
-namespace android {
-namespace hardware {
-namespace neuralnetworks {
-namespace V1_2 {
-namespace implementation {
+#include <limits>
 
-CallbackBase::CallbackBase() : mNotified(false) {}
+namespace android::hardware::neuralnetworks::V1_2::implementation {
 
-CallbackBase::~CallbackBase() {
-    // Note that we cannot call CallbackBase::join_thread from here:
-    // CallbackBase is intended to be reference counted, and it is possible that
-    // the reference count drops to zero in the bound thread, causing the
-    // bound thread to call this destructor. If a thread tries to join
-    // itself, it throws an exception, producing a message like the
-    // following:
-    //
-    //     terminating with uncaught exception of type std::__1::system_error:
-    //     thread::join failed: Resource deadlock would occur
-}
+constexpr Timing kNoTiming = {.timeOnDevice = std::numeric_limits<uint64_t>::max(),
+                              .timeInDriver = std::numeric_limits<uint64_t>::max()};  // unavailable
 
-void CallbackBase::wait() {
-    std::unique_lock<std::mutex> lock(mMutex);
-    mCondition.wait(lock, [this] { return mNotified; });
-    join_thread_locked();
-}
-
-bool CallbackBase::on_finish(std::function<bool(void)> post_work) {
-    std::lock_guard<std::mutex> lock(mMutex);
-    if (mPostWork != nullptr) {
-        LOG(ERROR) << "CallbackBase::on_finish -- a post-work function has already been bound to "
-                      "this callback object";
-        return false;
-    }
-    if (post_work == nullptr) {
-        LOG(ERROR) << "CallbackBase::on_finish -- the new post-work function is invalid";
-        return false;
-    }
-    mPostWork = std::move(post_work);
-    return true;
-}
-
-bool CallbackBase::bind_thread(std::thread&& asyncThread) {
-    std::lock_guard<std::mutex> lock(mMutex);
-    if (mThread.joinable()) {
-        LOG(ERROR) << "CallbackBase::bind_thread -- a thread has already been bound to this "
-                      "callback object";
-        return false;
-    }
-    if (!asyncThread.joinable()) {
-        LOG(ERROR) << "CallbackBase::bind_thread -- the new thread is not joinable";
-        return false;
-    }
-    mThread = std::move(asyncThread);
-    return true;
-}
-
-void CallbackBase::join_thread() {
-    std::lock_guard<std::mutex> lock(mMutex);
-    join_thread_locked();
-}
-
-void CallbackBase::notify() {
-    {
-        std::lock_guard<std::mutex> lock(mMutex);
-        mNotified = true;
-        if (mPostWork != nullptr) {
-            bool success = mPostWork();
-            if (!success) {
-                LOG(ERROR) << "CallbackBase::notify -- post work failed";
-            }
-        }
-    }
-    mCondition.notify_all();
-}
-
-void CallbackBase::join_thread_locked() {
-    if (mThread.joinable()) {
-        mThread.join();
-    }
-}
-
-PreparedModelCallback::PreparedModelCallback()
-    : mErrorStatus(ErrorStatus::GENERAL_FAILURE), mPreparedModel(nullptr) {}
-
-PreparedModelCallback::~PreparedModelCallback() {}
+// PreparedModelCallback methods begin here: notify*() records one result, getters block until then
 
 Return<void> PreparedModelCallback::notify(ErrorStatus errorStatus,
                                            const sp<V1_0::IPreparedModel>& preparedModel) {
-    mErrorStatus = errorStatus;
-    mPreparedModel = preparedModel;
-    CallbackBase::notify();
+    {
+        std::lock_guard<std::mutex> hold(mMutex);
+
+        // quick-return if already notified: only the first result is kept
+        if (mNotified) {
+            return Void();
+        }
+
+        // store results and mark as notified so that wait() can return
+        mErrorStatus = errorStatus;
+        mPreparedModel = preparedModel;
+        mNotified = true;
+    }
+
+    mCondition.notify_all();  // wake all wait() callers; lock already released
     return Void();
 }
 
 Return<void> PreparedModelCallback::notify_1_2(ErrorStatus errorStatus,
                                                const sp<V1_2::IPreparedModel>& preparedModel) {
-    mErrorStatus = errorStatus;
-    mPreparedModel = preparedModel;
-    CallbackBase::notify();
-    return Void();
+    return notify(errorStatus, preparedModel);  // same one-shot path; sp upcasts to V1_0
 }
 
-ErrorStatus PreparedModelCallback::getStatus() {
+void PreparedModelCallback::wait() const {
+    std::unique_lock<std::mutex> lock(mMutex);
+    mCondition.wait(lock, [this] { return mNotified; });  // blocks until mNotified is set
+}
+
+ErrorStatus PreparedModelCallback::getStatus() const {  // blocks until notified
     wait();
     return mErrorStatus;
 }
 
-sp<V1_0::IPreparedModel> PreparedModelCallback::getPreparedModel() {
+sp<V1_0::IPreparedModel> PreparedModelCallback::getPreparedModel() const {  // waits for notify*()
     wait();
     return mPreparedModel;
 }
 
-ExecutionCallback::ExecutionCallback() : mErrorStatus(ErrorStatus::GENERAL_FAILURE) {}
-
-ExecutionCallback::~ExecutionCallback() {}
+// ExecutionCallback methods begin here: notify*() records one result, getters block until then
 
 Return<void> ExecutionCallback::notify(ErrorStatus errorStatus) {
-    mErrorStatus = errorStatus;
-    mOutputShapes = {};
-    mTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
-    CallbackBase::notify();
+    notifyInternal(errorStatus, {}, kNoTiming);  // this overload carries no shapes or timing
     return Void();
 }
 
 Return<void> ExecutionCallback::notify_1_2(ErrorStatus errorStatus,
                                            const hidl_vec<OutputShape>& outputShapes,
                                            const Timing& timing) {
-    mErrorStatus = errorStatus;
-    mOutputShapes = outputShapes;
-    mTiming = timing;
-    CallbackBase::notify();
+    if (errorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
+        // outputShapes must not be empty if OUTPUT_INSUFFICIENT_SIZE.
+        if (outputShapes.size() == 0) {
+            LOG(ERROR) << "Notified with empty output shape vector when OUTPUT_INSUFFICIENT_SIZE";
+            notifyInternal(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);  // reject malformed result
+            return Void();
+        }
+    } else if (errorStatus != ErrorStatus::NONE) {
+        // outputShapes must be empty if errorStatus is neither NONE nor OUTPUT_INSUFFICIENT_SIZE.
+        if (outputShapes.size() != 0) {
+            LOG(ERROR) << "Notified with non-empty output shape vector when error status is "
+                          "neither NONE nor OUTPUT_INSUFFICIENT_SIZE";
+            notifyInternal(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);  // reject malformed result
+            return Void();
+        }
+    }
+    notifyInternal(errorStatus, outputShapes, timing);  // NONE, or error with valid shapes
     return Void();
 }
 
-ErrorStatus ExecutionCallback::getStatus() {
+void ExecutionCallback::wait() const {
+    std::unique_lock<std::mutex> lock(mMutex);
+    mCondition.wait(lock, [this] { return mNotified; });  // blocks until mNotified is set
+}
+
+ErrorStatus ExecutionCallback::getStatus() const {  // blocks until notified
     wait();
     return mErrorStatus;
 }
 
-const std::vector<OutputShape>& ExecutionCallback::getOutputShapes() {
+const std::vector<OutputShape>& ExecutionCallback::getOutputShapes() const {  // waits for notify*()
     wait();
     return mOutputShapes;
 }
 
-Timing ExecutionCallback::getTiming() {
+Timing ExecutionCallback::getTiming() const {  // kNoTiming if none was reported
     wait();
     return mTiming;
 }
 
-}  // namespace implementation
-}  // namespace V1_2
-}  // namespace neuralnetworks
-}  // namespace hardware
-}  // namespace android
+void ExecutionCallback::notifyInternal(ErrorStatus errorStatus,  // shared by notify and notify_1_2
+                                       const hidl_vec<OutputShape>& outputShapes,
+                                       const Timing& timing) {
+    {
+        std::lock_guard<std::mutex> hold(mMutex);
+
+        // quick-return if already notified: only the first result is kept
+        if (mNotified) {
+            return;
+        }
+        // store results and mark as notified so that wait() can return
+        mErrorStatus = errorStatus;
+        mOutputShapes = outputShapes;
+        mTiming = timing;
+        mNotified = true;
+    }
+    mCondition.notify_all();  // wake all wait() callers; lock already released
+}
+
+}  // namespace android::hardware::neuralnetworks::V1_2::implementation