Context Hub default HAL: Use std::thread for endpoint callbacks

This ensures the endpoint callbacks are invoked from a separate thread,
after the original initiating function has returned.

This CL also updates the VtsAidlHalContextHubTargetTest to handle
the async callbacks.

Bug: 380335353
Change-Id: I29d932f8a4d8989c06cfa6007368a424c963c91f
Flag: TEST_ONLY
Test: atest VtsAidlHalContextHubTargetTest
diff --git a/contexthub/aidl/default/ContextHub.cpp b/contexthub/aidl/default/ContextHub.cpp
index c1af0a3..7eb51d0 100644
--- a/contexthub/aidl/default/ContextHub.cpp
+++ b/contexthub/aidl/default/ContextHub.cpp
@@ -15,6 +15,7 @@
  */
 
 #include "contexthub-impl/ContextHub.h"
+#include "aidl/android/hardware/contexthub/IContextHubCallback.h"
 
 #ifndef LOG_TAG
 #define LOG_TAG "CHRE"
@@ -22,6 +23,8 @@
 
 #include <inttypes.h>
 #include <log/log.h>
+#include <optional>
+#include <thread>
 
 using ::ndk::ScopedAStatus;
 
@@ -62,6 +65,9 @@
         },
 };
 
+//! Mutex used to ensure callbacks are called after the initial function returns.
+std::mutex gCallbackMutex;
+
 }  // anonymous namespace
 
 ScopedAStatus ContextHub::getContextHubs(std::vector<ContextHubInfo>* out_contextHubInfos) {
@@ -297,8 +303,8 @@
         mMaxValidSessionId = in_size;
     }
 
-    _aidl_return->at(0) = 0;
-    _aidl_return->at(1) = in_size;
+    (*_aidl_return)[0] = 0;
+    (*_aidl_return)[1] = in_size;
     return ScopedAStatus::ok();
 };
 
@@ -308,7 +314,7 @@
     // We are not calling onCloseEndpointSession on failure because the remote endpoints (our
     // mock endpoints) always accept the session.
 
-    std::shared_ptr<IEndpointCallback> callback = nullptr;
+    std::weak_ptr<IEndpointCallback> callback;
     {
         std::unique_lock<std::mutex> lock(mEndpointMutex);
         if (in_sessionId > mMaxValidSessionId) {
@@ -355,23 +361,27 @@
                 .serviceDescriptor = in_serviceDescriptor,
         });
 
-        if (mEndpointCallback != nullptr) {
-            callback = mEndpointCallback;
+        if (mEndpointCallback == nullptr) {
+            return ScopedAStatus::ok();
         }
+        callback = mEndpointCallback;
     }
 
-    if (callback != nullptr) {
-        callback->onEndpointSessionOpenComplete(in_sessionId);
-    }
+    std::unique_lock<std::mutex> lock(gCallbackMutex);
+    std::thread{[callback, in_sessionId]() {
+        std::unique_lock<std::mutex> lock(gCallbackMutex);
+        if (auto cb = callback.lock(); cb != nullptr) {
+            cb->onEndpointSessionOpenComplete(in_sessionId);
+        }
+    }}.detach();
     return ScopedAStatus::ok();
 };
 
 ScopedAStatus ContextHub::sendMessageToEndpoint(int32_t in_sessionId, const Message& in_msg) {
-    bool foundSession = false;
-    std::shared_ptr<IEndpointCallback> callback = nullptr;
+    std::weak_ptr<IEndpointCallback> callback;
     {
         std::unique_lock<std::mutex> lock(mEndpointMutex);
-
+        bool foundSession = false;
         for (const EndpointSession& session : mEndpointSessions) {
             if (session.sessionId == in_sessionId) {
                 foundSession = true;
@@ -379,27 +389,38 @@
             }
         }
 
-        if (mEndpointCallback != nullptr) {
-            callback = mEndpointCallback;
-        }
-    }
-
-    if (!foundSession) {
-        ALOGE("sendMessageToEndpoint: session ID %" PRId32 " is invalid", in_sessionId);
-        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
-    }
-
-    if (callback != nullptr) {
-        if (in_msg.flags & Message::FLAG_REQUIRES_DELIVERY_STATUS) {
-            MessageDeliveryStatus msgStatus = {};
-            msgStatus.messageSequenceNumber = in_msg.sequenceNumber;
-            msgStatus.errorCode = ErrorCode::OK;
-            callback->onMessageDeliveryStatusReceived(in_sessionId, msgStatus);
+        if (!foundSession) {
+            ALOGE("sendMessageToEndpoint: session ID %" PRId32 " is invalid", in_sessionId);
+            return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
         }
 
-        // Echo the message back
-        callback->onMessageReceived(in_sessionId, in_msg);
+        if (mEndpointCallback == nullptr) {
+            return ScopedAStatus::ok();
+        }
+        callback = mEndpointCallback;
     }
+
+    std::unique_lock<std::mutex> lock(gCallbackMutex);
+    if ((in_msg.flags & Message::FLAG_REQUIRES_DELIVERY_STATUS) != 0) {
+        MessageDeliveryStatus msgStatus = {};
+        msgStatus.messageSequenceNumber = in_msg.sequenceNumber;
+        msgStatus.errorCode = ErrorCode::OK;
+
+        std::thread{[callback, in_sessionId, msgStatus]() {
+            std::unique_lock<std::mutex> lock(gCallbackMutex);
+            if (auto cb = callback.lock(); cb != nullptr) {
+                cb->onMessageDeliveryStatusReceived(in_sessionId, msgStatus);
+            }
+        }}.detach();
+    }
+
+    // Echo the message back
+    std::thread{[callback, in_sessionId, in_msg]() {
+        std::unique_lock<std::mutex> lock(gCallbackMutex);
+        if (auto cb = callback.lock(); cb != nullptr) {
+            cb->onMessageReceived(in_sessionId, in_msg);
+        }
+    }}.detach();
     return ScopedAStatus::ok();
 };