Merge "Context Hub default HAL: Use std::thread for endpoint callbacks" into main
diff --git a/contexthub/aidl/default/ContextHub.cpp b/contexthub/aidl/default/ContextHub.cpp
index c1af0a3..7eb51d0 100644
--- a/contexthub/aidl/default/ContextHub.cpp
+++ b/contexthub/aidl/default/ContextHub.cpp
@@ -15,6 +15,7 @@
  */
 
 #include "contexthub-impl/ContextHub.h"
+#include "aidl/android/hardware/contexthub/IContextHubCallback.h"
 
 #ifndef LOG_TAG
 #define LOG_TAG "CHRE"
@@ -22,6 +23,8 @@
 
 #include <inttypes.h>
 #include <log/log.h>
+#include <optional>
+#include <thread>
 
 using ::ndk::ScopedAStatus;
 
@@ -62,6 +65,9 @@
         },
 };
 
+//! Delays each callback thread until its spawning method has finished; also serializes callbacks.
+std::mutex gCallbackMutex;
+
 }  // anonymous namespace
 
 ScopedAStatus ContextHub::getContextHubs(std::vector<ContextHubInfo>* out_contextHubInfos) {
@@ -297,8 +303,8 @@
         mMaxValidSessionId = in_size;
     }
 
-    _aidl_return->at(0) = 0;
-    _aidl_return->at(1) = in_size;
+    (*_aidl_return)[0] = 0;
+    (*_aidl_return)[1] = in_size;
     return ScopedAStatus::ok();
 };
 
@@ -308,7 +314,7 @@
     // We are not calling onCloseEndpointSession on failure because the remote endpoints (our
     // mock endpoints) always accept the session.
 
-    std::shared_ptr<IEndpointCallback> callback = nullptr;
+    std::weak_ptr<IEndpointCallback> callback;
     {
         std::unique_lock<std::mutex> lock(mEndpointMutex);
         if (in_sessionId > mMaxValidSessionId) {
@@ -355,23 +361,27 @@
                 .serviceDescriptor = in_serviceDescriptor,
         });
 
-        if (mEndpointCallback != nullptr) {
-            callback = mEndpointCallback;
+        if (mEndpointCallback == nullptr) {
+            return ScopedAStatus::ok();
         }
+        callback = mEndpointCallback;
     }
 
-    if (callback != nullptr) {
-        callback->onEndpointSessionOpenComplete(in_sessionId);
-    }
+    std::unique_lock<std::mutex> lock(gCallbackMutex);
+    std::thread{[callback, in_sessionId]() {
+        std::unique_lock<std::mutex> lock(gCallbackMutex);
+        if (auto cb = callback.lock(); cb != nullptr) {
+            cb->onEndpointSessionOpenComplete(in_sessionId);
+        }
+    }}.detach();
     return ScopedAStatus::ok();
 };
 
 ScopedAStatus ContextHub::sendMessageToEndpoint(int32_t in_sessionId, const Message& in_msg) {
-    bool foundSession = false;
-    std::shared_ptr<IEndpointCallback> callback = nullptr;
+    std::weak_ptr<IEndpointCallback> callback;
     {
         std::unique_lock<std::mutex> lock(mEndpointMutex);
-
+        bool foundSession = false;
         for (const EndpointSession& session : mEndpointSessions) {
             if (session.sessionId == in_sessionId) {
                 foundSession = true;
@@ -379,27 +389,38 @@
             }
         }
 
-        if (mEndpointCallback != nullptr) {
-            callback = mEndpointCallback;
-        }
-    }
-
-    if (!foundSession) {
-        ALOGE("sendMessageToEndpoint: session ID %" PRId32 " is invalid", in_sessionId);
-        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
-    }
-
-    if (callback != nullptr) {
-        if (in_msg.flags & Message::FLAG_REQUIRES_DELIVERY_STATUS) {
-            MessageDeliveryStatus msgStatus = {};
-            msgStatus.messageSequenceNumber = in_msg.sequenceNumber;
-            msgStatus.errorCode = ErrorCode::OK;
-            callback->onMessageDeliveryStatusReceived(in_sessionId, msgStatus);
+        if (!foundSession) {
+            ALOGE("sendMessageToEndpoint: session ID %" PRId32 " is invalid", in_sessionId);
+            return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
         }
 
-        // Echo the message back
-        callback->onMessageReceived(in_sessionId, in_msg);
+        if (mEndpointCallback == nullptr) {
+            return ScopedAStatus::ok();
+        }
+        callback = mEndpointCallback;
     }
+
+    // Held until this method returns; the callback thread also takes it, so callbacks start
+    // only after this call has finished. One thread sends both callbacks so that the delivery
+    // status (if requested) is always reported before the echoed message.
+    std::unique_lock<std::mutex> lock(gCallbackMutex);
+    std::thread{[callback, in_sessionId, in_msg]() {
+        std::unique_lock<std::mutex> lock(gCallbackMutex);
+        std::shared_ptr<IEndpointCallback> cb = callback.lock();
+        if (cb == nullptr) {
+            return;
+        }
+
+        if ((in_msg.flags & Message::FLAG_REQUIRES_DELIVERY_STATUS) != 0) {
+            MessageDeliveryStatus msgStatus = {};
+            msgStatus.messageSequenceNumber = in_msg.sequenceNumber;
+            msgStatus.errorCode = ErrorCode::OK;
+            cb->onMessageDeliveryStatusReceived(in_sessionId, msgStatus);
+        }
+
+        // Echo the message back
+        cb->onMessageReceived(in_sessionId, in_msg);
+    }}.detach();
     return ScopedAStatus::ok();
 };
 
diff --git a/contexthub/aidl/vts/VtsAidlHalContextHubTargetTest.cpp b/contexthub/aidl/vts/VtsAidlHalContextHubTargetTest.cpp
index 090d4fe..02f0653 100644
--- a/contexthub/aidl/vts/VtsAidlHalContextHubTargetTest.cpp
+++ b/contexthub/aidl/vts/VtsAidlHalContextHubTargetTest.cpp
@@ -492,7 +492,11 @@
     }
 
     Status onMessageReceived(int32_t /* sessionId */, const Message& message) override {
-        mMessages.push_back(message);
+        {
+            std::unique_lock<std::mutex> lock(mMutex);  // Guards mMessages.
+            mMessages.push_back(message);
+        }
+        mCondVar.notify_one();  // Sent after unlocking so the woken waiter need not block.
         return Status::ok();
     }
 
@@ -513,21 +517,30 @@
     }
 
     Status onEndpointSessionOpenComplete(int32_t /* sessionId */) override {
-        mWasOnEndpointSessionOpenCompleteCalled = true;
+        {
+            std::unique_lock<std::mutex> lock(mMutex);  // Guards the open-complete flag.
+            mWasOnEndpointSessionOpenCompleteCalled = true;
+        }
+        mCondVar.notify_one();  // Sent after unlocking so the woken waiter need not block.
         return Status::ok();
     }
 
-    std::vector<Message> getMessages() { return mMessages; }
-
     bool wasOnEndpointSessionOpenCompleteCalled() {
         return mWasOnEndpointSessionOpenCompleteCalled;
     }
+
     void resetWasOnEndpointSessionOpenCompleteCalled() {
         mWasOnEndpointSessionOpenCompleteCalled = false;
     }
 
+    std::mutex& getMutex() { return mMutex; }  // Guards mMessages and the open-complete flag.
+    std::condition_variable& getCondVar() { return mCondVar; }  // Pair with getMutex().
+    std::vector<Message> getMessages() { return mMessages; }  // Caller should hold getMutex().
+
   private:
     std::vector<Message> mMessages;
+    std::mutex mMutex;
+    std::condition_variable mCondVar;
     bool mWasOnEndpointSessionOpenCompleteCalled = false;
 };
 
@@ -690,14 +703,12 @@
     EXPECT_GE(range[1] - range[0] + 1, requestedRange);
 
     // Open the session
-    cb->resetWasOnEndpointSessionOpenCompleteCalled();
     int32_t sessionId = range[1] + 10;  // invalid
     EXPECT_FALSE(contextHub
                          ->openEndpointSession(sessionId, destinationEndpoint->id,
                                                initiatorEndpoint.id,
                                                /* in_serviceDescriptor= */ String16("ECHO"))
                          .isOk());
-    EXPECT_FALSE(cb->wasOnEndpointSessionOpenCompleteCalled());
 }
 
 TEST_P(ContextHubAidl, OpenEndpointSessionAndSendMessageEchoesBack) {
@@ -710,6 +721,8 @@
         EXPECT_TRUE(status.isOk());
     }
 
+    std::unique_lock<std::mutex> lock(cb->getMutex());
+
     // Register the endpoint
     EndpointInfo initiatorEndpoint;
     initiatorEndpoint.id.id = 8;
@@ -750,6 +763,7 @@
                                               initiatorEndpoint.id,
                                               /* in_serviceDescriptor= */ String16("ECHO"))
                         .isOk());
+    cb->getCondVar().wait(lock, [&cb] { return cb->wasOnEndpointSessionOpenCompleteCalled(); });
     EXPECT_TRUE(cb->wasOnEndpointSessionOpenCompleteCalled());
 
     // Send the message
@@ -760,6 +774,7 @@
     ASSERT_TRUE(contextHub->sendMessageToEndpoint(sessionId, message).isOk());
 
     // Check for echo
+    cb->getCondVar().wait(lock, [&cb] { return !cb->getMessages().empty(); });
     EXPECT_FALSE(cb->getMessages().empty());
     EXPECT_EQ(cb->getMessages().back().content.back(), 42);
 }