Update NN VTS callback objects
The VTS Callback files are a subset of the Callback files in
frameworks/ml/nn/runtime/Callbacks.*. This CL syncs the VTS copies with
the runtime implementations, removing the functionality that is not
needed in VTS (e.g. CallbackBase's thread binding and post-work hook).
Fixes: 132322149
Test: mma
Test: VtsHalNeuralnetworksV1_0TargetTest
Test: VtsHalNeuralnetworksV1_1TargetTest
Test: VtsHalNeuralnetworksV1_2TargetTest
Change-Id: I114ce7f3b6c3d58de0196e9508209614d0a73e11
diff --git a/neuralnetworks/1.2/vts/functional/Callbacks.cpp b/neuralnetworks/1.2/vts/functional/Callbacks.cpp
index cfaf91d..a607a08 100644
--- a/neuralnetworks/1.2/vts/functional/Callbacks.cpp
+++ b/neuralnetworks/1.2/vts/functional/Callbacks.cpp
@@ -14,160 +14,128 @@
* limitations under the License.
*/
+#define LOG_TAG "Callbacks"
+
#include "1.2/Callbacks.h"
+
#include <android-base/logging.h>
-namespace android {
-namespace hardware {
-namespace neuralnetworks {
-namespace V1_2 {
-namespace implementation {
+#include <limits>
-CallbackBase::CallbackBase() : mNotified(false) {}
+namespace android::hardware::neuralnetworks::V1_2::implementation {
-CallbackBase::~CallbackBase() {
- // Note that we cannot call CallbackBase::join_thread from here:
- // CallbackBase is intended to be reference counted, and it is possible that
- // the reference count drops to zero in the bound thread, causing the
- // bound thread to call this destructor. If a thread tries to join
- // itself, it throws an exception, producing a message like the
- // following:
- //
- // terminating with uncaught exception of type std::__1::system_error:
- // thread::join failed: Resource deadlock would occur
-}
+constexpr Timing kNoTiming = {.timeOnDevice = std::numeric_limits<uint64_t>::max(),
+ .timeInDriver = std::numeric_limits<uint64_t>::max()};
-void CallbackBase::wait() {
- std::unique_lock<std::mutex> lock(mMutex);
- mCondition.wait(lock, [this] { return mNotified; });
- join_thread_locked();
-}
-
-bool CallbackBase::on_finish(std::function<bool(void)> post_work) {
- std::lock_guard<std::mutex> lock(mMutex);
- if (mPostWork != nullptr) {
- LOG(ERROR) << "CallbackBase::on_finish -- a post-work function has already been bound to "
- "this callback object";
- return false;
- }
- if (post_work == nullptr) {
- LOG(ERROR) << "CallbackBase::on_finish -- the new post-work function is invalid";
- return false;
- }
- mPostWork = std::move(post_work);
- return true;
-}
-
-bool CallbackBase::bind_thread(std::thread&& asyncThread) {
- std::lock_guard<std::mutex> lock(mMutex);
- if (mThread.joinable()) {
- LOG(ERROR) << "CallbackBase::bind_thread -- a thread has already been bound to this "
- "callback object";
- return false;
- }
- if (!asyncThread.joinable()) {
- LOG(ERROR) << "CallbackBase::bind_thread -- the new thread is not joinable";
- return false;
- }
- mThread = std::move(asyncThread);
- return true;
-}
-
-void CallbackBase::join_thread() {
- std::lock_guard<std::mutex> lock(mMutex);
- join_thread_locked();
-}
-
-void CallbackBase::notify() {
- {
- std::lock_guard<std::mutex> lock(mMutex);
- mNotified = true;
- if (mPostWork != nullptr) {
- bool success = mPostWork();
- if (!success) {
- LOG(ERROR) << "CallbackBase::notify -- post work failed";
- }
- }
- }
- mCondition.notify_all();
-}
-
-void CallbackBase::join_thread_locked() {
- if (mThread.joinable()) {
- mThread.join();
- }
-}
-
-PreparedModelCallback::PreparedModelCallback()
- : mErrorStatus(ErrorStatus::GENERAL_FAILURE), mPreparedModel(nullptr) {}
-
-PreparedModelCallback::~PreparedModelCallback() {}
+// PreparedModelCallback methods begin here
Return<void> PreparedModelCallback::notify(ErrorStatus errorStatus,
const sp<V1_0::IPreparedModel>& preparedModel) {
- mErrorStatus = errorStatus;
- mPreparedModel = preparedModel;
- CallbackBase::notify();
+ {
+ std::lock_guard<std::mutex> hold(mMutex);
+
+ // quick-return if object has already been notified
+ if (mNotified) {
+ return Void();
+ }
+
+ // store results and mark as notified
+ mErrorStatus = errorStatus;
+ mPreparedModel = preparedModel;
+ mNotified = true;
+ }
+
+ mCondition.notify_all();
return Void();
}
Return<void> PreparedModelCallback::notify_1_2(ErrorStatus errorStatus,
const sp<V1_2::IPreparedModel>& preparedModel) {
- mErrorStatus = errorStatus;
- mPreparedModel = preparedModel;
- CallbackBase::notify();
- return Void();
+ return notify(errorStatus, preparedModel);
}
-ErrorStatus PreparedModelCallback::getStatus() {
+void PreparedModelCallback::wait() const {
+ std::unique_lock<std::mutex> lock(mMutex);
+ mCondition.wait(lock, [this] { return mNotified; });
+}
+
+ErrorStatus PreparedModelCallback::getStatus() const {
wait();
return mErrorStatus;
}
-sp<V1_0::IPreparedModel> PreparedModelCallback::getPreparedModel() {
+sp<V1_0::IPreparedModel> PreparedModelCallback::getPreparedModel() const {
wait();
return mPreparedModel;
}
-ExecutionCallback::ExecutionCallback() : mErrorStatus(ErrorStatus::GENERAL_FAILURE) {}
-
-ExecutionCallback::~ExecutionCallback() {}
+// ExecutionCallback methods begin here
Return<void> ExecutionCallback::notify(ErrorStatus errorStatus) {
- mErrorStatus = errorStatus;
- mOutputShapes = {};
- mTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
- CallbackBase::notify();
+ notifyInternal(errorStatus, {}, kNoTiming);
return Void();
}
Return<void> ExecutionCallback::notify_1_2(ErrorStatus errorStatus,
const hidl_vec<OutputShape>& outputShapes,
const Timing& timing) {
- mErrorStatus = errorStatus;
- mOutputShapes = outputShapes;
- mTiming = timing;
- CallbackBase::notify();
+ if (errorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
+ // outputShapes must not be empty if OUTPUT_INSUFFICIENT_SIZE.
+ if (outputShapes.size() == 0) {
+ LOG(ERROR) << "Notified with empty output shape vector when OUTPUT_INSUFFICIENT_SIZE";
+ notifyInternal(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
+ return Void();
+ }
+ } else if (errorStatus != ErrorStatus::NONE) {
+ // outputShapes must be empty if errorStatus is neither NONE nor OUTPUT_INSUFFICIENT_SIZE.
+ if (outputShapes.size() != 0) {
+ LOG(ERROR) << "Notified with non-empty output shape vector when error status is "
+ "neither NONE nor OUTPUT_INSUFFICIENT_SIZE";
+ notifyInternal(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
+ return Void();
+ }
+ }
+ notifyInternal(errorStatus, outputShapes, timing);
return Void();
}
-ErrorStatus ExecutionCallback::getStatus() {
+void ExecutionCallback::wait() const {
+ std::unique_lock<std::mutex> lock(mMutex);
+ mCondition.wait(lock, [this] { return mNotified; });
+}
+
+ErrorStatus ExecutionCallback::getStatus() const {
wait();
return mErrorStatus;
}
-const std::vector<OutputShape>& ExecutionCallback::getOutputShapes() {
+const std::vector<OutputShape>& ExecutionCallback::getOutputShapes() const {
wait();
return mOutputShapes;
}
-Timing ExecutionCallback::getTiming() {
+Timing ExecutionCallback::getTiming() const {
wait();
return mTiming;
}
-} // namespace implementation
-} // namespace V1_2
-} // namespace neuralnetworks
-} // namespace hardware
-} // namespace android
+void ExecutionCallback::notifyInternal(ErrorStatus errorStatus,
+ const hidl_vec<OutputShape>& outputShapes,
+ const Timing& timing) {
+ {
+ std::lock_guard<std::mutex> hold(mMutex);
+
+ // quick-return if object has already been notified
+ if (mNotified) {
+ return;
+ }
+
+ mErrorStatus = errorStatus;
+ mOutputShapes = outputShapes;
+ mTiming = timing;
+ mNotified = true;
+ }
+ mCondition.notify_all();
+}
+
+} // namespace android::hardware::neuralnetworks::V1_2::implementation