contexthub: Make client registration API safer

Modifies the registration API for the new endpoint session capabilities
to return a dedicated IEndpointCommunication interface tied to the
client's registration. All endpoint/session APIs have been moved to the
new interface. Because a client is now identified by the interface
object it holds rather than by its pid, this prevents security issues
caused by pid reuse.

Adds an unregister() API to IEndpointCommunication that a client calls
to signal that its associated hub will no longer be active. Subsequent
calls to that IEndpointCommunication instance fail with
EX_ILLEGAL_STATE.

Bug: 384897139
Flag: android.chre.flags.offload_api
Test: Builds
Change-Id: Ibd8a28e96315ca146d62413b5337ece9912d14c4
diff --git a/contexthub/aidl/default/ContextHub.cpp b/contexthub/aidl/default/ContextHub.cpp
index 19d9639..433617e 100644
--- a/contexthub/aidl/default/ContextHub.cpp
+++ b/contexthub/aidl/default/ContextHub.cpp
@@ -150,10 +150,16 @@
 
 ScopedAStatus ContextHub::setTestMode(bool enable) {
     if (enable) {
-        std::unique_lock<std::mutex> lock(mEndpointMutex);
-        mEndpoints.clear();
-        mEndpointSessions.clear();
-        mEndpointCallback = nullptr;
+        // Entering test mode drops every registered host hub. Deactivate each one first so a
+        // client still holding its IEndpointCommunication gets EX_ILLEGAL_STATE from it.
+        std::lock_guard lock(mHostHubsLock);
+        for (auto& [id, hub] : mIdToHostHub) {
+            hub->mActive = false;
+        }
+        mIdToHostHub.clear();
+        // No active hub owns a session ID range any more; reclaim the whole ID space so that
+        // repeated test-mode cycles cannot exhaust it.
+        mNextSessionIdBase = 0;
     }
     return ScopedAStatus::ok();
 }
@@ -227,7 +228,28 @@
     return ScopedAStatus::ok();
 };
 
-ScopedAStatus ContextHub::registerEndpoint(const EndpointInfo& in_endpoint) {
+// Registers a host hub described by in_hubInfo and returns, via _aidl_return, the
+// IEndpointCommunication object the client must use for all endpoint and session calls.
+// Fails with EX_ILLEGAL_STATE if a hub with the same hubId is already registered.
+ScopedAStatus ContextHub::registerEndpointHub(
+        const std::shared_ptr<IEndpointCallback>& in_callback, const HubInfo& in_hubInfo,
+        std::shared_ptr<IEndpointCommunication>* _aidl_return) {
+    std::lock_guard lock(mHostHubsLock);
+    if (mIdToHostHub.count(in_hubInfo.hubId)) {
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
+    auto hub = ndk::SharedRefBase::make<HubInterface>(*this, in_callback, in_hubInfo);
+    mIdToHostHub.insert({in_hubInfo.hubId, hub});
+    *_aidl_return = std::move(hub);
+    return ScopedAStatus::ok();
+}
+
+// Registers in_endpoint on this hub. Fails with EX_ILLEGAL_STATE once the hub is inactive
+// (after unregister() or setTestMode(true)).
+ScopedAStatus ContextHub::HubInterface::registerEndpoint(const EndpointInfo& in_endpoint) {
+    if (!mActive) {
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
     std::unique_lock<std::mutex> lock(mEndpointMutex);
 
     for (const EndpointInfo& endpoint : mEndpoints) {
@@ -240,7 +257,12 @@
     return ScopedAStatus::ok();
 };
 
-ScopedAStatus ContextHub::unregisterEndpoint(const EndpointInfo& in_endpoint) {
+// Unregisters in_endpoint from this hub. Fails with EX_ILLEGAL_STATE once the hub is inactive
+// (after unregister() or setTestMode(true)).
+ScopedAStatus ContextHub::HubInterface::unregisterEndpoint(const EndpointInfo& in_endpoint) {
+    if (!mActive) {
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
     std::unique_lock<std::mutex> lock(mEndpointMutex);
 
     for (auto it = mEndpoints.begin(); it != mEndpoints.end(); ++it) {
@@ -252,41 +272,53 @@
     return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
 };
 
-ScopedAStatus ContextHub::registerEndpointCallback(
-        const std::shared_ptr<IEndpointCallback>& in_callback) {
-    std::unique_lock<std::mutex> lock(mEndpointMutex);
-
-    mEndpointCallback = in_callback;
-    return ScopedAStatus::ok();
-};
-
-ScopedAStatus ContextHub::requestSessionIdRange(int32_t in_size,
-                                                std::array<int32_t, 2>* _aidl_return) {
+// Reserves in_size consecutive session IDs for this hub out of the ID space shared by all hubs
+// and returns the inclusive range [first, last]. A repeated call replaces this hub's range.
+ScopedAStatus ContextHub::HubInterface::requestSessionIdRange(
+        int32_t in_size, std::array<int32_t, 2>* _aidl_return) {
+    if (!mActive) {
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
     constexpr int32_t kMaxSize = 1024;
-    if (in_size > kMaxSize || _aidl_return == nullptr) {
+    // A non-positive size would move mNextSessionIdBase backwards, re-issuing IDs that already
+    // belong to another hub, and would not describe a valid range.
+    if (in_size < 1 || in_size > kMaxSize || _aidl_return == nullptr) {
         return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
     }
 
+    uint16_t base = 0;
     {
-        std::lock_guard<std::mutex> lock(mEndpointMutex);
-        mMaxValidSessionId = in_size;
+        std::lock_guard lock(mHal.mHostHubsLock);
+        // Require base + in_size <= USHRT_MAX so the next base stays representable and never
+        // wraps back to 0 (which would re-issue IDs). This leaves ID USHRT_MAX unused.
+        if (static_cast<int32_t>(USHRT_MAX) - mHal.mNextSessionIdBase < in_size) {
+            return ScopedAStatus::fromServiceSpecificError(EX_CONTEXT_HUB_UNSPECIFIED);
+        }
+        base = mHal.mNextSessionIdBase;
+        mHal.mNextSessionIdBase += in_size;
     }
 
-    (*_aidl_return)[0] = 0;
-    (*_aidl_return)[1] = in_size;
+    {
+        std::lock_guard<std::mutex> lock(mEndpointMutex);
+        (*_aidl_return)[0] = mBaseSessionId = base;
+        (*_aidl_return)[1] = mMaxSessionId = base + (in_size - 1);
+    }
     return ScopedAStatus::ok();
 };
 
-ScopedAStatus ContextHub::openEndpointSession(
+ScopedAStatus ContextHub::HubInterface::openEndpointSession(
         int32_t in_sessionId, const EndpointId& in_destination, const EndpointId& in_initiator,
         const std::optional<std::string>& in_serviceDescriptor) {
+    if (!mActive) {
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
     // We are not calling onCloseEndpointSession on failure because the remote endpoints (our
     // mock endpoints) always accept the session.
 
     std::weak_ptr<IEndpointCallback> callback;
     {
         std::unique_lock<std::mutex> lock(mEndpointMutex);
-        if (in_sessionId > mMaxValidSessionId) {
+        if (in_sessionId < mBaseSessionId || in_sessionId > mMaxSessionId) {
             ALOGE("openEndpointSession: session ID %" PRId32 " is invalid", in_sessionId);
             return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
         }
@@ -346,7 +372,13 @@
     return ScopedAStatus::ok();
 };
 
-ScopedAStatus ContextHub::sendMessageToEndpoint(int32_t in_sessionId, const Message& in_msg) {
+// Handles a message from this hub's client on session in_sessionId. Fails with
+// EX_ILLEGAL_STATE once the hub is inactive (after unregister() or setTestMode(true)).
+ScopedAStatus ContextHub::HubInterface::sendMessageToEndpoint(int32_t in_sessionId,
+                                                              const Message& in_msg) {
+    if (!mActive) {
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
     std::weak_ptr<IEndpointCallback> callback;
     {
         std::unique_lock<std::mutex> lock(mEndpointMutex);
@@ -393,12 +423,22 @@
     return ScopedAStatus::ok();
 };
 
-ScopedAStatus ContextHub::sendMessageDeliveryStatusToEndpoint(
+// Delivery statuses from the client are accepted and ignored (both parameters are unused);
+// only the active check can fail.
+ScopedAStatus ContextHub::HubInterface::sendMessageDeliveryStatusToEndpoint(
         int32_t /* in_sessionId */, const MessageDeliveryStatus& /* in_msgStatus */) {
+    if (!mActive) {
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
     return ScopedAStatus::ok();
 };
 
-ScopedAStatus ContextHub::closeEndpointSession(int32_t in_sessionId, Reason /* in_reason */) {
+// Closes session in_sessionId of this hub; in_reason is ignored.
+ScopedAStatus ContextHub::HubInterface::closeEndpointSession(int32_t in_sessionId,
+                                                             Reason /* in_reason */) {
+    if (!mActive) {
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
     std::unique_lock<std::mutex> lock(mEndpointMutex);
 
     for (auto it = mEndpointSessions.begin(); it != mEndpointSessions.end(); ++it) {
@@ -411,8 +448,28 @@
     return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
 };
 
-ScopedAStatus ContextHub::endpointSessionOpenComplete(int32_t /* in_sessionId */) {
+// No-op apart from the active check; in_sessionId is unused by this implementation.
+ScopedAStatus ContextHub::HubInterface::endpointSessionOpenComplete(int32_t /* in_sessionId */) {
+    if (!mActive) {
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
     return ScopedAStatus::ok();
 };
 
+// Deactivates this hub and removes it from the HAL's hub registry. Afterwards every other
+// IEndpointCommunication call on this object fails with EX_ILLEGAL_STATE, as does a second
+// unregister().
+ScopedAStatus ContextHub::HubInterface::unregister() {
+    // Flip mActive while holding mHostHubsLock, the lock setTestMode() holds when it deactivates
+    // hubs. Otherwise setTestMode() followed by a registerEndpointHub() reusing this hub ID could
+    // run between the flag flip and the erase, and the erase would then remove the newly
+    // registered hub instead of this one.
+    std::lock_guard lock(mHal.mHostHubsLock);
+    if (!mActive.exchange(false)) {
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
+    mHal.mIdToHostHub.erase(kInfo.hubId);
+    return ScopedAStatus::ok();
+}
+
 }  // namespace aidl::android::hardware::contexthub