Merge "Context Hub default HAL: Use std::thread for endpoint callbacks" into main
diff --git a/contexthub/aidl/default/ContextHub.cpp b/contexthub/aidl/default/ContextHub.cpp
index c1af0a3..7eb51d0 100644
--- a/contexthub/aidl/default/ContextHub.cpp
+++ b/contexthub/aidl/default/ContextHub.cpp
@@ -15,6 +15,7 @@
*/
#include "contexthub-impl/ContextHub.h"
+#include "aidl/android/hardware/contexthub/IContextHubCallback.h"
#ifndef LOG_TAG
#define LOG_TAG "CHRE"
@@ -22,6 +23,8 @@
#include <inttypes.h>
#include <log/log.h>
+#include <optional>
+#include <thread>
using ::ndk::ScopedAStatus;
@@ -62,6 +65,9 @@
},
};
+//! Held by the spawning call until it returns and taken by every callback thread, so callbacks run after that call.
+std::mutex gCallbackMutex;
+
} // anonymous namespace
ScopedAStatus ContextHub::getContextHubs(std::vector<ContextHubInfo>* out_contextHubInfos) {
@@ -297,8 +303,8 @@
mMaxValidSessionId = in_size;
}
- _aidl_return->at(0) = 0;
- _aidl_return->at(1) = in_size;
+    (*_aidl_return)[0] = 0;  // NOTE(review): unchecked, unlike the removed at(); assumes *_aidl_return holds two elements - confirm.
+    (*_aidl_return)[1] = in_size;
return ScopedAStatus::ok();
};
@@ -308,7 +314,7 @@
// We are not calling onCloseEndpointSession on failure because the remote endpoints (our
// mock endpoints) always accept the session.
- std::shared_ptr<IEndpointCallback> callback = nullptr;
+    std::weak_ptr<IEndpointCallback> callback;  // Weak: the detached callback thread must not keep the client callback alive.
{
std::unique_lock<std::mutex> lock(mEndpointMutex);
if (in_sessionId > mMaxValidSessionId) {
@@ -355,23 +361,27 @@
.serviceDescriptor = in_serviceDescriptor,
});
- if (mEndpointCallback != nullptr) {
- callback = mEndpointCallback;
+        if (mEndpointCallback == nullptr) {  // No callback registered: nothing to notify.
+            return ScopedAStatus::ok();
}
+        callback = mEndpointCallback;  // Captured under mEndpointMutex; used only weakly afterwards.
}
- if (callback != nullptr) {
- callback->onEndpointSessionOpenComplete(in_sessionId);
- }
+    std::unique_lock<std::mutex> lock(gCallbackMutex);  // Held until return; the thread below must take it first.
+    std::thread{[callback, in_sessionId]() {
+        std::unique_lock<std::mutex> lock(gCallbackMutex);
+        if (auto cb = callback.lock(); cb != nullptr) {  // Skipped if the callback object has since been destroyed.
+            cb->onEndpointSessionOpenComplete(in_sessionId);
+        }
+    }}.detach();  // Fire-and-forget: the thread holds copies of everything it uses.
return ScopedAStatus::ok();
};
ScopedAStatus ContextHub::sendMessageToEndpoint(int32_t in_sessionId, const Message& in_msg) {
- bool foundSession = false;
- std::shared_ptr<IEndpointCallback> callback = nullptr;
+    std::weak_ptr<IEndpointCallback> callback;  // Weak: queued callback threads must not extend its lifetime.
{
std::unique_lock<std::mutex> lock(mEndpointMutex);
-
+        bool foundSession = false;
for (const EndpointSession& session : mEndpointSessions) {
if (session.sessionId == in_sessionId) {
foundSession = true;
@@ -379,27 +389,45 @@
}
}
- if (mEndpointCallback != nullptr) {
- callback = mEndpointCallback;
- }
- }
-
- if (!foundSession) {
- ALOGE("sendMessageToEndpoint: session ID %" PRId32 " is invalid", in_sessionId);
- return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
- }
-
- if (callback != nullptr) {
- if (in_msg.flags & Message::FLAG_REQUIRES_DELIVERY_STATUS) {
- MessageDeliveryStatus msgStatus = {};
- msgStatus.messageSequenceNumber = in_msg.sequenceNumber;
- msgStatus.errorCode = ErrorCode::OK;
- callback->onMessageDeliveryStatusReceived(in_sessionId, msgStatus);
+        if (!foundSession) {
+            ALOGE("sendMessageToEndpoint: session ID %" PRId32 " is invalid", in_sessionId);
+            return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
}
- // Echo the message back
- callback->onMessageReceived(in_sessionId, in_msg);
+        if (mEndpointCallback == nullptr) {
+            return ScopedAStatus::ok();
+        }
+        callback = mEndpointCallback;
}
+
+    // Build the optional delivery status up front so both callbacks can be sent from one thread.
+    std::optional<MessageDeliveryStatus> deliveryStatus;
+    if ((in_msg.flags & Message::FLAG_REQUIRES_DELIVERY_STATUS) != 0) {
+        MessageDeliveryStatus msgStatus = {};
+        msgStatus.messageSequenceNumber = in_msg.sequenceNumber;
+        msgStatus.errorCode = ErrorCode::OK;
+        deliveryStatus = msgStatus;
+    }
+
+    // Held until this function returns; the callback thread must take it first, so no callback
+    // runs before this call has returned.
+    std::unique_lock<std::mutex> lock(gCallbackMutex);
+
+    // A single thread delivers the status (if requested) and then the echo, so their order is
+    // fixed. Separate threads would contend for gCallbackMutex and could run in either order.
+    std::thread{[callback, in_sessionId, in_msg, deliveryStatus]() {
+        std::unique_lock<std::mutex> lock(gCallbackMutex);
+        auto cb = callback.lock();
+        if (cb == nullptr) {
+            return;  // The callback object no longer exists; nothing to notify.
+        }
+        if (deliveryStatus.has_value()) {
+            cb->onMessageDeliveryStatusReceived(in_sessionId, *deliveryStatus);
+        }
+
+        // Echo the message back
+        cb->onMessageReceived(in_sessionId, in_msg);
+    }}.detach();
return ScopedAStatus::ok();
};
diff --git a/contexthub/aidl/vts/VtsAidlHalContextHubTargetTest.cpp b/contexthub/aidl/vts/VtsAidlHalContextHubTargetTest.cpp
index 090d4fe..02f0653 100644
--- a/contexthub/aidl/vts/VtsAidlHalContextHubTargetTest.cpp
+++ b/contexthub/aidl/vts/VtsAidlHalContextHubTargetTest.cpp
@@ -492,7 +492,11 @@
}
Status onMessageReceived(int32_t /* sessionId */, const Message& message) override {
- mMessages.push_back(message);
+        {
+            std::unique_lock<std::mutex> lock(mMutex);  // Called from another thread; guard mMessages.
+            mMessages.push_back(message);
+        }
+        mCondVar.notify_one();  // After unlocking, so the woken waiter can take mMutex at once.
return Status::ok();
}
@@ -513,21 +517,30 @@
}
Status onEndpointSessionOpenComplete(int32_t /* sessionId */) override {
- mWasOnEndpointSessionOpenCompleteCalled = true;
+        {
+            std::unique_lock<std::mutex> lock(mMutex);  // Called from another thread; guard the flag.
+            mWasOnEndpointSessionOpenCompleteCalled = true;
+        }
+        mCondVar.notify_one();  // After unlocking, so the woken waiter can take mMutex at once.
return Status::ok();
}
- std::vector<Message> getMessages() { return mMessages; }
-
bool wasOnEndpointSessionOpenCompleteCalled() {
return mWasOnEndpointSessionOpenCompleteCalled;
}
+
void resetWasOnEndpointSessionOpenCompleteCalled() {
mWasOnEndpointSessionOpenCompleteCalled = false;
}
+    std::mutex& getMutex() { return mMutex; }  // Guards mMessages and the open-complete flag; the accessors do not lock it, so hold it while calling them.
+    std::condition_variable& getCondVar() { return mCondVar; }  // Notified after each callback updates state; wait with a predicate.
+    std::vector<Message> getMessages() { return mMessages; }  // Returns a copy; caller must hold getMutex() (this does not lock).
+
 private:
std::vector<Message> mMessages;
+    std::mutex mMutex;  // Not recursive: do not lock it again while the test already holds it.
+    std::condition_variable mCondVar;
bool mWasOnEndpointSessionOpenCompleteCalled = false;
};
@@ -690,14 +703,12 @@
EXPECT_GE(range[1] - range[0] + 1, requestedRange);
// Open the session
- cb->resetWasOnEndpointSessionOpenCompleteCalled();
int32_t sessionId = range[1] + 10; // invalid
EXPECT_FALSE(contextHub
->openEndpointSession(sessionId, destinationEndpoint->id,
initiatorEndpoint.id,
/* in_serviceDescriptor= */ String16("ECHO"))
.isOk());
- EXPECT_FALSE(cb->wasOnEndpointSessionOpenCompleteCalled());
}
TEST_P(ContextHubAidl, OpenEndpointSessionAndSendMessageEchoesBack) {
@@ -710,6 +721,11 @@
EXPECT_TRUE(status.isOk());
}
+    // HAL callbacks arrive asynchronously and lock cb's mutex before updating its state.
+    // Holding it from here on means no update can be missed: it is only released inside wait_for.
+    std::unique_lock<std::mutex> lock(cb->getMutex());
+    constexpr auto kCallbackTimeout = std::chrono::seconds(5);
+
// Register the endpoint
EndpointInfo initiatorEndpoint;
initiatorEndpoint.id.id = 8;
@@ -750,6 +766,10 @@
initiatorEndpoint.id,
/* in_serviceDescriptor= */ String16("ECHO"))
.isOk());
+    // Predicate guards against spurious wakeups; the timeout fails the test instead of hanging.
+    ASSERT_TRUE(cb->getCondVar().wait_for(lock, kCallbackTimeout, [&] {
+        return cb->wasOnEndpointSessionOpenCompleteCalled();
+    }));
EXPECT_TRUE(cb->wasOnEndpointSessionOpenCompleteCalled());
// Send the message
@@ -760,6 +780,8 @@
ASSERT_TRUE(contextHub->sendMessageToEndpoint(sessionId, message).isOk());
// Check for echo
+    ASSERT_TRUE(cb->getCondVar().wait_for(lock, kCallbackTimeout,
+                                          [&] { return !cb->getMessages().empty(); }));
EXPECT_FALSE(cb->getMessages().empty());
EXPECT_EQ(cb->getMessages().back().content.back(), 42);
}