contexthub: Make client registration API safer

Modifies the registration API for the new endpoint session capabilities:
registerEndpointHub() returns a dedicated IEndpointCommunication
interface tied to the client registration, and all endpoint/session
APIs have been moved to that interface. This prevents security issues
resulting from pid re-use.

Adds an unregister() API to IEndpointCommunication for a client to
signal that its associated hub will no longer be active. Subsequent
calls on that IEndpointCommunication instance fail with
EX_ILLEGAL_STATE.

Bug: 384897139
Flag: android.chre.flags.offload_api
Test: Builds
Change-Id: Ibd8a28e96315ca146d62413b5337ece9912d14c4
diff --git a/contexthub/aidl/default/ContextHub.cpp b/contexthub/aidl/default/ContextHub.cpp
index 19d9639..433617e 100644
--- a/contexthub/aidl/default/ContextHub.cpp
+++ b/contexthub/aidl/default/ContextHub.cpp
@@ -150,10 +150,11 @@
 
 ScopedAStatus ContextHub::setTestMode(bool enable) {
     if (enable) {
-        std::unique_lock<std::mutex> lock(mEndpointMutex);
-        mEndpoints.clear();
-        mEndpointSessions.clear();
-        mEndpointCallback = nullptr;
+        std::lock_guard lock(mHostHubsLock);
+        for (auto& [id, hub] : mIdToHostHub) {
+            hub->mActive = false;  // HubInterface calls check this flag and then fail
+        }
+        mIdToHostHub.clear();  // frees every hub ID for a fresh registerEndpointHub()
     }
     return ScopedAStatus::ok();
 }
@@ -227,7 +228,23 @@
     return ScopedAStatus::ok();
 };
 
-ScopedAStatus ContextHub::registerEndpoint(const EndpointInfo& in_endpoint) {
+ScopedAStatus ContextHub::registerEndpointHub(
+        const std::shared_ptr<IEndpointCallback>& in_callback, const HubInfo& in_hubInfo,
+        std::shared_ptr<IEndpointCommunication>* _aidl_return) {
+    std::lock_guard lock(mHostHubsLock);
+    if (mIdToHostHub.find(in_hubInfo.hubId) != mIdToHostHub.end()) {  // ID already taken
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
+    auto hostHub = ndk::SharedRefBase::make<HubInterface>(*this, in_callback, in_hubInfo);
+    mIdToHostHub.emplace(in_hubInfo.hubId, hostHub);  // HAL keeps its own reference
+    *_aidl_return = std::move(hostHub);
+    return ScopedAStatus::ok();
+}
+
+ScopedAStatus ContextHub::HubInterface::registerEndpoint(const EndpointInfo& in_endpoint) {
+    if (!mActive) {
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
     std::unique_lock<std::mutex> lock(mEndpointMutex);
 
     for (const EndpointInfo& endpoint : mEndpoints) {
@@ -240,7 +257,10 @@
     return ScopedAStatus::ok();
 };
 
-ScopedAStatus ContextHub::unregisterEndpoint(const EndpointInfo& in_endpoint) {
+ScopedAStatus ContextHub::HubInterface::unregisterEndpoint(const EndpointInfo& in_endpoint) {
+    if (!mActive) {
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
     std::unique_lock<std::mutex> lock(mEndpointMutex);
 
     for (auto it = mEndpoints.begin(); it != mEndpoints.end(); ++it) {
@@ -252,41 +272,47 @@
     return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
 };
 
-ScopedAStatus ContextHub::registerEndpointCallback(
-        const std::shared_ptr<IEndpointCallback>& in_callback) {
-    std::unique_lock<std::mutex> lock(mEndpointMutex);
-
-    mEndpointCallback = in_callback;
-    return ScopedAStatus::ok();
-};
-
-ScopedAStatus ContextHub::requestSessionIdRange(int32_t in_size,
-                                                std::array<int32_t, 2>* _aidl_return) {
+ScopedAStatus ContextHub::HubInterface::requestSessionIdRange(
+        int32_t in_size, std::array<int32_t, 2>* _aidl_return) {
+    if (!mActive) {
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
     constexpr int32_t kMaxSize = 1024;
-    if (in_size > kMaxSize || _aidl_return == nullptr) {
+    if (in_size <= 0 || in_size > kMaxSize || _aidl_return == nullptr) {  // no empty ranges
         return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
     }
 
+    uint16_t base = 0;  // first ID of the range reserved for this hub
     {
-        std::lock_guard<std::mutex> lock(mEndpointMutex);
-        mMaxValidSessionId = in_size;
+        std::lock_guard lock(mHal.mHostHubsLock);
+        if (static_cast<int32_t>(USHRT_MAX) - mHal.mNextSessionIdBase + 1 < in_size) {
+            return ScopedAStatus::fromServiceSpecificError(EX_CONTEXT_HUB_UNSPECIFIED);
+        }
+        base = mHal.mNextSessionIdBase;
+        mHal.mNextSessionIdBase += in_size;  // advance the HAL-wide cursor past this range
     }
 
-    (*_aidl_return)[0] = 0;
-    (*_aidl_return)[1] = in_size;
+    {
+        std::lock_guard<std::mutex> lock(mEndpointMutex);
+        (*_aidl_return)[0] = mBaseSessionId = base;
+        (*_aidl_return)[1] = mMaxSessionId = base + (in_size - 1);  // inclusive upper bound
+    }
     return ScopedAStatus::ok();
 };
 
-ScopedAStatus ContextHub::openEndpointSession(
+ScopedAStatus ContextHub::HubInterface::openEndpointSession(
         int32_t in_sessionId, const EndpointId& in_destination, const EndpointId& in_initiator,
         const std::optional<std::string>& in_serviceDescriptor) {
+    if (!mActive) {
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
     // We are not calling onCloseEndpointSession on failure because the remote endpoints (our
     // mock endpoints) always accept the session.
 
     std::weak_ptr<IEndpointCallback> callback;
     {
         std::unique_lock<std::mutex> lock(mEndpointMutex);
-        if (in_sessionId > mMaxValidSessionId) {
+        if (in_sessionId < mBaseSessionId || in_sessionId > mMaxSessionId) {
             ALOGE("openEndpointSession: session ID %" PRId32 " is invalid", in_sessionId);
             return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
         }
@@ -346,7 +372,11 @@
     return ScopedAStatus::ok();
 };
 
-ScopedAStatus ContextHub::sendMessageToEndpoint(int32_t in_sessionId, const Message& in_msg) {
+ScopedAStatus ContextHub::HubInterface::sendMessageToEndpoint(int32_t in_sessionId,
+                                                              const Message& in_msg) {
+    if (!mActive) {
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
     std::weak_ptr<IEndpointCallback> callback;
     {
         std::unique_lock<std::mutex> lock(mEndpointMutex);
@@ -393,12 +423,19 @@
     return ScopedAStatus::ok();
 };
 
-ScopedAStatus ContextHub::sendMessageDeliveryStatusToEndpoint(
+ScopedAStatus ContextHub::HubInterface::sendMessageDeliveryStatusToEndpoint(
         int32_t /* in_sessionId */, const MessageDeliveryStatus& /* in_msgStatus */) {
+    if (!mActive) {  // hub was deactivated by unregister() or setTestMode(true)
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
     return ScopedAStatus::ok();
 };
 
-ScopedAStatus ContextHub::closeEndpointSession(int32_t in_sessionId, Reason /* in_reason */) {
+ScopedAStatus ContextHub::HubInterface::closeEndpointSession(int32_t in_sessionId,
+                                                             Reason /* in_reason */) {
+    if (!mActive) {
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
     std::unique_lock<std::mutex> lock(mEndpointMutex);
 
     for (auto it = mEndpointSessions.begin(); it != mEndpointSessions.end(); ++it) {
@@ -411,8 +448,20 @@
     return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
 };
 
-ScopedAStatus ContextHub::endpointSessionOpenComplete(int32_t /* in_sessionId */) {
+ScopedAStatus ContextHub::HubInterface::endpointSessionOpenComplete(int32_t /* in_sessionId */) {
+    if (!mActive) {  // hub was deactivated by unregister() or setTestMode(true)
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
     return ScopedAStatus::ok();
 };
 
+ScopedAStatus ContextHub::HubInterface::unregister() {
+    std::lock_guard lock(mHal.mHostHubsLock);  // held across the flag flip and the erase
+    if (!mActive.exchange(false)) {  // already unregistered, or reset by setTestMode(true)
+        return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+    }
+    mHal.mIdToHostHub.erase(kInfo.hubId);
+    return ScopedAStatus::ok();
+}
+
 }  // namespace aidl::android::hardware::contexthub
diff --git a/contexthub/aidl/default/include/contexthub-impl/ContextHub.h b/contexthub/aidl/default/include/contexthub-impl/ContextHub.h
index 6da8bf2..65e84bb 100644
--- a/contexthub/aidl/default/include/contexthub-impl/ContextHub.h
+++ b/contexthub/aidl/default/include/contexthub-impl/ContextHub.h
@@ -17,7 +17,9 @@
 #pragma once
 
 #include <aidl/android/hardware/contexthub/BnContextHub.h>
+#include <aidl/android/hardware/contexthub/BnEndpointCommunication.h>
 
+#include <atomic>
 #include <mutex>
 #include <unordered_set>
 #include <vector>
@@ -56,54 +58,79 @@
 
     ::ndk::ScopedAStatus getHubs(std::vector<HubInfo>* _aidl_return) override;
     ::ndk::ScopedAStatus getEndpoints(std::vector<EndpointInfo>* _aidl_return) override;
-    ::ndk::ScopedAStatus registerEndpoint(const EndpointInfo& in_endpoint) override;
-    ::ndk::ScopedAStatus unregisterEndpoint(const EndpointInfo& in_endpoint) override;
-    ::ndk::ScopedAStatus registerEndpointCallback(
-            const std::shared_ptr<IEndpointCallback>& in_callback) override;
-    ::ndk::ScopedAStatus requestSessionIdRange(int32_t in_size,
-                                               std::array<int32_t, 2>* _aidl_return) override;
-    ::ndk::ScopedAStatus openEndpointSession(
-            int32_t in_sessionId, const EndpointId& in_destination, const EndpointId& in_initiator,
-            const std::optional<std::string>& in_serviceDescriptor) override;
-    ::ndk::ScopedAStatus sendMessageToEndpoint(int32_t in_sessionId,
-                                               const Message& in_msg) override;
-    ::ndk::ScopedAStatus sendMessageDeliveryStatusToEndpoint(
-            int32_t in_sessionId, const MessageDeliveryStatus& in_msgStatus) override;
-    ::ndk::ScopedAStatus closeEndpointSession(int32_t in_sessionId, Reason in_reason) override;
-    ::ndk::ScopedAStatus endpointSessionOpenComplete(int32_t in_sessionId) override;
+    ::ndk::ScopedAStatus registerEndpointHub(
+            const std::shared_ptr<IEndpointCallback>& in_callback, const HubInfo& in_hubInfo,
+            std::shared_ptr<IEndpointCommunication>* _aidl_return) override;
 
   private:
-    struct EndpointSession {
-        int32_t sessionId;
-        EndpointId initiator;
-        EndpointId peer;
-        std::optional<std::string> serviceDescriptor;
+    class HubInterface : public BnEndpointCommunication {
+      public:
+        HubInterface(ContextHub& hal, const std::shared_ptr<IEndpointCallback>& in_callback,
+                     const HubInfo& in_hubInfo)
+            : mHal(hal), mEndpointCallback(in_callback), kInfo(in_hubInfo) {}
+        ~HubInterface() = default;
+
+        ::ndk::ScopedAStatus registerEndpoint(const EndpointInfo& in_endpoint) override;
+        ::ndk::ScopedAStatus unregisterEndpoint(const EndpointInfo& in_endpoint) override;
+        ::ndk::ScopedAStatus requestSessionIdRange(int32_t in_size,
+                                                   std::array<int32_t, 2>* _aidl_return) override;
+        ::ndk::ScopedAStatus openEndpointSession(
+                int32_t in_sessionId, const EndpointId& in_destination,
+                const EndpointId& in_initiator,
+                const std::optional<std::string>& in_serviceDescriptor) override;
+        ::ndk::ScopedAStatus sendMessageToEndpoint(int32_t in_sessionId,
+                                                   const Message& in_msg) override;
+        ::ndk::ScopedAStatus sendMessageDeliveryStatusToEndpoint(
+                int32_t in_sessionId, const MessageDeliveryStatus& in_msgStatus) override;
+        ::ndk::ScopedAStatus closeEndpointSession(int32_t in_sessionId, Reason in_reason) override;
+        ::ndk::ScopedAStatus endpointSessionOpenComplete(int32_t in_sessionId) override;
+        ::ndk::ScopedAStatus unregister() override;
+
+      private:
+        friend class ContextHub;
+
+        struct EndpointSession {
+            int32_t sessionId;
+            EndpointId initiator;
+            EndpointId peer;
+            std::optional<std::string> serviceDescriptor;
+        };
+
+        //! Finds an endpoint in the range defined by the endpoints
+        //! @return whether the endpoint was found
+        template <typename Iter>
+        bool findEndpoint(const EndpointId& target, const Iter& begin, const Iter& end) {
+            for (auto iter = begin; iter != end; ++iter) {
+                if (iter->id.id == target.id && iter->id.hubId == target.hubId) {
+                    return true;
+                }
+            }
+            return false;
+        }
+
+        //! Endpoint storage and information
+        ContextHub& mHal;  // HAL instance that created this hub in registerEndpointHub()
+        std::shared_ptr<IEndpointCallback> mEndpointCallback;
+        const HubInfo kInfo;  // kInfo.hubId is this hub's key in mHal.mIdToHostHub
+
+        std::atomic<bool> mActive = true;  // false after unregister() or setTestMode(true)
+
+        std::mutex mEndpointMutex;
+        std::vector<EndpointInfo> mEndpoints;
+        std::vector<EndpointSession> mEndpointSessions;
+        uint16_t mBaseSessionId = 1;  // base > max: empty range until requestSessionIdRange()
+        uint16_t mMaxSessionId = 0;   // inclusive upper bound of the reserved range
     };
 
     static constexpr uint32_t kMockHubId = 0;
 
-    //! Finds an endpoint in the range defined by the endpoints
-    //! @return whether the endpoint was found
-    template <typename Iter>
-    bool findEndpoint(const EndpointId& target, const Iter& begin, const Iter& end) {
-        for (auto iter = begin; iter != end; ++iter) {
-            if (iter->id.id == target.id && iter->id.hubId == target.hubId) {
-                return true;
-            }
-        }
-        return false;
-    }
-
     std::shared_ptr<IContextHubCallback> mCallback;
 
     std::unordered_set<char16_t> mConnectedHostEndpoints;
 
-    //! Endpoint storage and information
-    std::mutex mEndpointMutex;
-    std::vector<EndpointInfo> mEndpoints;
-    std::vector<EndpointSession> mEndpointSessions;
-    std::shared_ptr<IEndpointCallback> mEndpointCallback;
-    int32_t mMaxValidSessionId = 0;
+    std::mutex mHostHubsLock;
+    std::unordered_map<int64_t, std::shared_ptr<HubInterface>> mIdToHostHub;
+    int32_t mNextSessionIdBase = 0;
 };
 
 }  // namespace contexthub