blob: 46bf24350fb5894b22602be24ce44a01344df45e [file] [log] [blame]
Michael Butlercf22a572017-09-22 13:26:12 -07001#include "Callbacks.h"
2#include <android-base/logging.h>
3
4namespace android {
5namespace hardware {
6namespace neuralnetworks {
7namespace V1_0 {
8namespace implementation {
9
10CallbackBase::CallbackBase() : mNotified(false) {}
11
12CallbackBase::~CallbackBase() {
13 // Note that we cannot call CallbackBase::join_thread from here:
14 // CallbackBase is intended to be reference counted, and it is possible that
15 // the reference count drops to zero in the bound thread, causing the
16 // bound thread to call this destructor. If a thread tries to join
17 // itself, it throws an exception, producing a message like the
18 // following:
19 //
20 // terminating with uncaught exception of type std::__1::system_error:
21 // thread::join failed: Resource deadlock would occur
22}
23
24void CallbackBase::wait() {
25 std::unique_lock<std::mutex> lock(mMutex);
26 mCondition.wait(lock, [this]{return mNotified;});
27 join_thread_locked();
28}
29
30bool CallbackBase::on_finish(std::function<bool(void)> post_work) {
31 std::lock_guard<std::mutex> lock(mMutex);
32 if (mPostWork != nullptr) {
33 LOG(ERROR) << "CallbackBase::on_finish -- a post-work function has already been bound to "
34 "this callback object";
35 return false;
36 }
37 if (post_work == nullptr) {
38 LOG(ERROR) << "CallbackBase::on_finish -- the new post-work function is invalid";
39 return false;
40 }
41 mPostWork = std::move(post_work);
42 return true;
43}
44
45bool CallbackBase::bind_thread(std::thread&& asyncThread) {
46 std::lock_guard<std::mutex> lock(mMutex);
47 if (mThread.joinable()) {
48 LOG(ERROR) << "CallbackBase::bind_thread -- a thread has already been bound to this "
49 "callback object";
50 return false;
51 }
52 if (!asyncThread.joinable()) {
53 LOG(ERROR) << "CallbackBase::bind_thread -- the new thread is not joinable";
54 return false;
55 }
56 mThread = std::move(asyncThread);
57 return true;
58}
59
60void CallbackBase::join_thread() {
61 std::lock_guard<std::mutex> lock(mMutex);
62 join_thread_locked();
63}
64
65void CallbackBase::notify() {
66 {
67 std::lock_guard<std::mutex> lock(mMutex);
68 mNotified = true;
69 if (mPostWork != nullptr) {
70 bool success = mPostWork();
71 if (!success) {
72 LOG(ERROR) << "CallbackBase::notify -- post work failed";
73 }
74 }
75 }
76 mCondition.notify_all();
77}
78
79void CallbackBase::join_thread_locked() {
80 if (mThread.joinable()) {
81 mThread.join();
82 }
83}
84
85PreparedModelCallback::PreparedModelCallback() :
86 mErrorStatus(ErrorStatus::GENERAL_FAILURE), mPreparedModel(nullptr) {}
87
88PreparedModelCallback::~PreparedModelCallback() {}
89
90Return<void> PreparedModelCallback::notify(ErrorStatus errorStatus,
91 const sp<IPreparedModel>& preparedModel) {
92 mErrorStatus = errorStatus;
93 mPreparedModel = preparedModel;
94 CallbackBase::notify();
95 return Void();
96}
97
98ErrorStatus PreparedModelCallback::getStatus() {
99 wait();
100 return mErrorStatus;
101}
102
103sp<IPreparedModel> PreparedModelCallback::getPreparedModel() {
104 wait();
105 return mPreparedModel;
106}
107
108ExecutionCallback::ExecutionCallback() : mErrorStatus(ErrorStatus::GENERAL_FAILURE) {}
109
110ExecutionCallback::~ExecutionCallback() {}
111
112Return<void> ExecutionCallback::notify(ErrorStatus errorStatus) {
113 mErrorStatus = errorStatus;
114 CallbackBase::notify();
115 return Void();
116}
117
118ErrorStatus ExecutionCallback::getStatus() {
119 wait();
120 return mErrorStatus;
121}
122
123} // namespace implementation
124} // namespace V1_0
125} // namespace neuralnetworks
126} // namespace hardware
127} // namespace android