contexthub: Make client registration API safer
Modifies the registration API for new endpoint session capabilities to
return a new dedicated IEndpointCommunication interface tied to the
client registration. All endpoint/session APIs have been moved to the
new interface. This prevents security issues resulting from pid re-use.
Adds an unregister() API to IEndpointCommunication for a client to
signal that its associated hub will no longer be active. Subsequent
calls to that IEndpointCommunication instance will fail with
EX_ILLEGAL_STATE.
Bug: 384897139
Flag: android.chre.flags.offload_api
Test: Builds
Change-Id: Ibd8a28e96315ca146d62413b5337ece9912d14c4
diff --git a/contexthub/aidl/default/ContextHub.cpp b/contexthub/aidl/default/ContextHub.cpp
index 19d9639..433617e 100644
--- a/contexthub/aidl/default/ContextHub.cpp
+++ b/contexthub/aidl/default/ContextHub.cpp
@@ -150,10 +150,11 @@
ScopedAStatus ContextHub::setTestMode(bool enable) {
if (enable) {
- std::unique_lock<std::mutex> lock(mEndpointMutex);
- mEndpoints.clear();
- mEndpointSessions.clear();
- mEndpointCallback = nullptr;
+ std::lock_guard lock(mHostHubsLock);
+ for (auto& [id, hub] : mIdToHostHub) {
+ hub->mActive = false;
+ }
+ mIdToHostHub.clear();
}
return ScopedAStatus::ok();
}
@@ -227,7 +228,23 @@
return ScopedAStatus::ok();
};
-ScopedAStatus ContextHub::registerEndpoint(const EndpointInfo& in_endpoint) {
+// Registers a new host hub and returns the IEndpointCommunication handle bound to it. A hub ID
+// may be registered only once until it is unregistered or test mode clears the registry.
+ScopedAStatus ContextHub::registerEndpointHub(
+ const std::shared_ptr<IEndpointCallback>& in_callback, const HubInfo& in_hubInfo,
+ std::shared_ptr<IEndpointCommunication>* _aidl_return) {
+ std::lock_guard lock(mHostHubsLock);
+ const bool alreadyRegistered = mIdToHostHub.find(in_hubInfo.hubId) != mIdToHostHub.end();
+ if (alreadyRegistered) {
+ return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+ }
+ auto newHub = ndk::SharedRefBase::make<HubInterface>(*this, in_callback, in_hubInfo);
+ mIdToHostHub.emplace(in_hubInfo.hubId, newHub);
+ *_aidl_return = std::move(newHub);
+ return ScopedAStatus::ok();
+}
+
+ScopedAStatus ContextHub::HubInterface::registerEndpoint(const EndpointInfo& in_endpoint) {
+ if (!mActive) {
+ return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+ }
std::unique_lock<std::mutex> lock(mEndpointMutex);
for (const EndpointInfo& endpoint : mEndpoints) {
@@ -240,7 +257,10 @@
return ScopedAStatus::ok();
};
-ScopedAStatus ContextHub::unregisterEndpoint(const EndpointInfo& in_endpoint) {
+ScopedAStatus ContextHub::HubInterface::unregisterEndpoint(const EndpointInfo& in_endpoint) {
+ if (!mActive) {
+ return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+ }
std::unique_lock<std::mutex> lock(mEndpointMutex);
for (auto it = mEndpoints.begin(); it != mEndpoints.end(); ++it) {
@@ -252,41 +272,47 @@
return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
};
-ScopedAStatus ContextHub::registerEndpointCallback(
- const std::shared_ptr<IEndpointCallback>& in_callback) {
- std::unique_lock<std::mutex> lock(mEndpointMutex);
-
- mEndpointCallback = in_callback;
- return ScopedAStatus::ok();
-};
-
-ScopedAStatus ContextHub::requestSessionIdRange(int32_t in_size,
- std::array<int32_t, 2>* _aidl_return) {
+// Reserves a contiguous, inclusive range of in_size session IDs for this hub from the HAL-wide
+// 16-bit session ID space, so ranges handed to different hubs never overlap. On success
+// (*_aidl_return)[0] is the first ID and (*_aidl_return)[1] the last ID of the range.
+ScopedAStatus ContextHub::HubInterface::requestSessionIdRange(
+ int32_t in_size, std::array<int32_t, 2>* _aidl_return) {
+ if (!mActive) {
+ return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+ }
constexpr int32_t kMaxSize = 1024;
- if (in_size > kMaxSize || _aidl_return == nullptr) {
+ // A non-positive size would put the range end below its base (for base 0 it wraps to 65535,
+ // accepting every session ID) and would move mNextSessionIdBase backwards into ranges other
+ // hubs already own.
+ if (in_size <= 0 || in_size > kMaxSize || _aidl_return == nullptr) {
return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
}
+ uint16_t base = 0;
{
- std::lock_guard<std::mutex> lock(mEndpointMutex);
- mMaxValidSessionId = in_size;
+ std::lock_guard lock(mHal.mHostHubsLock);
+ // Fail when [mNextSessionIdBase, mNextSessionIdBase + in_size - 1] would exceed USHRT_MAX.
+ if (static_cast<int32_t>(USHRT_MAX) - mHal.mNextSessionIdBase + 1 < in_size) {
+ return ScopedAStatus::fromServiceSpecificError(EX_CONTEXT_HUB_UNSPECIFIED);
+ }
+ base = mHal.mNextSessionIdBase;
+ mHal.mNextSessionIdBase += in_size;
}
- (*_aidl_return)[0] = 0;
- (*_aidl_return)[1] = in_size;
+ {
+ std::lock_guard<std::mutex> lock(mEndpointMutex);
+ (*_aidl_return)[0] = mBaseSessionId = base;
+ (*_aidl_return)[1] = mMaxSessionId = base + (in_size - 1);
+ }
return ScopedAStatus::ok();
};
-ScopedAStatus ContextHub::openEndpointSession(
+ScopedAStatus ContextHub::HubInterface::openEndpointSession(
int32_t in_sessionId, const EndpointId& in_destination, const EndpointId& in_initiator,
const std::optional<std::string>& in_serviceDescriptor) {
+ if (!mActive) {
+ return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+ }
// We are not calling onCloseEndpointSession on failure because the remote endpoints (our
// mock endpoints) always accept the session.
std::weak_ptr<IEndpointCallback> callback;
{
std::unique_lock<std::mutex> lock(mEndpointMutex);
- if (in_sessionId > mMaxValidSessionId) {
+ if (in_sessionId < mBaseSessionId || in_sessionId > mMaxSessionId) {
ALOGE("openEndpointSession: session ID %" PRId32 " is invalid", in_sessionId);
return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
}
@@ -346,7 +372,11 @@
return ScopedAStatus::ok();
};
-ScopedAStatus ContextHub::sendMessageToEndpoint(int32_t in_sessionId, const Message& in_msg) {
+ScopedAStatus ContextHub::HubInterface::sendMessageToEndpoint(int32_t in_sessionId,
+ const Message& in_msg) {
+ if (!mActive) {
+ return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+ }
std::weak_ptr<IEndpointCallback> callback;
{
std::unique_lock<std::mutex> lock(mEndpointMutex);
@@ -393,12 +423,19 @@
return ScopedAStatus::ok();
};
-ScopedAStatus ContextHub::sendMessageDeliveryStatusToEndpoint(
+// The mock ignores delivery statuses; the call only fails once this hub is inactive.
+ScopedAStatus ContextHub::HubInterface::sendMessageDeliveryStatusToEndpoint(
int32_t /* in_sessionId */, const MessageDeliveryStatus& /* in_msgStatus */) {
+ if (!mActive.load()) return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
return ScopedAStatus::ok();
};
-ScopedAStatus ContextHub::closeEndpointSession(int32_t in_sessionId, Reason /* in_reason */) {
+ScopedAStatus ContextHub::HubInterface::closeEndpointSession(int32_t in_sessionId,
+ Reason /* in_reason */) {
+ if (!mActive) {
+ return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+ }
std::unique_lock<std::mutex> lock(mEndpointMutex);
for (auto it = mEndpointSessions.begin(); it != mEndpointSessions.end(); ++it) {
@@ -411,8 +448,20 @@
return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
};
-ScopedAStatus ContextHub::endpointSessionOpenComplete(int32_t /* in_sessionId */) {
+// The mock takes no action here beyond rejecting calls on an inactive hub.
+ScopedAStatus ContextHub::HubInterface::endpointSessionOpenComplete(int32_t /* in_sessionId */) {
+ if (!mActive.load()) return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
return ScopedAStatus::ok();
};
+// Deactivates this hub and removes it from the HAL registry. Fails with EX_ILLEGAL_STATE if the
+// hub is already inactive (unregistered earlier, or discarded by setTestMode()).
+ScopedAStatus ContextHub::HubInterface::unregister() {
+ // Deactivate while holding the registry lock so this cannot interleave with setTestMode()
+ // clearing the registry and another hub then registering under the same ID.
+ std::lock_guard lock(mHal.mHostHubsLock);
+ if (!mActive.exchange(false)) {
+ return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_STATE);
+ }
+ // Erase by identity rather than by ID alone so a different hub owning this ID is never removed.
+ auto it = mHal.mIdToHostHub.find(kInfo.hubId);
+ if (it != mHal.mIdToHostHub.end() && it->second.get() == this) {
+ mHal.mIdToHostHub.erase(it);
+ }
+ return ScopedAStatus::ok();
+}
+
} // namespace aidl::android::hardware::contexthub
diff --git a/contexthub/aidl/default/include/contexthub-impl/ContextHub.h b/contexthub/aidl/default/include/contexthub-impl/ContextHub.h
index 6da8bf2..65e84bb 100644
--- a/contexthub/aidl/default/include/contexthub-impl/ContextHub.h
+++ b/contexthub/aidl/default/include/contexthub-impl/ContextHub.h
@@ -17,7 +17,9 @@
#pragma once
#include <aidl/android/hardware/contexthub/BnContextHub.h>
+#include <aidl/android/hardware/contexthub/BnEndpointCommunication.h>
+#include <atomic>
#include <mutex>
+#include <unordered_map>
#include <unordered_set>
#include <vector>
@@ -56,54 +58,79 @@
::ndk::ScopedAStatus getHubs(std::vector<HubInfo>* _aidl_return) override;
::ndk::ScopedAStatus getEndpoints(std::vector<EndpointInfo>* _aidl_return) override;
- ::ndk::ScopedAStatus registerEndpoint(const EndpointInfo& in_endpoint) override;
- ::ndk::ScopedAStatus unregisterEndpoint(const EndpointInfo& in_endpoint) override;
- ::ndk::ScopedAStatus registerEndpointCallback(
- const std::shared_ptr<IEndpointCallback>& in_callback) override;
- ::ndk::ScopedAStatus requestSessionIdRange(int32_t in_size,
- std::array<int32_t, 2>* _aidl_return) override;
- ::ndk::ScopedAStatus openEndpointSession(
- int32_t in_sessionId, const EndpointId& in_destination, const EndpointId& in_initiator,
- const std::optional<std::string>& in_serviceDescriptor) override;
- ::ndk::ScopedAStatus sendMessageToEndpoint(int32_t in_sessionId,
- const Message& in_msg) override;
- ::ndk::ScopedAStatus sendMessageDeliveryStatusToEndpoint(
- int32_t in_sessionId, const MessageDeliveryStatus& in_msgStatus) override;
- ::ndk::ScopedAStatus closeEndpointSession(int32_t in_sessionId, Reason in_reason) override;
- ::ndk::ScopedAStatus endpointSessionOpenComplete(int32_t in_sessionId) override;
+ ::ndk::ScopedAStatus registerEndpointHub(
+ const std::shared_ptr<IEndpointCallback>& in_callback, const HubInfo& in_hubInfo,
+ std::shared_ptr<IEndpointCommunication>* _aidl_return) override;
private:
- struct EndpointSession {
- int32_t sessionId;
- EndpointId initiator;
- EndpointId peer;
- std::optional<std::string> serviceDescriptor;
+ //! Per-client endpoint hub returned by registerEndpointHub(). All endpoint and session APIs
+ //! live here; once the hub is deactivated (unregister() or ContextHub::setTestMode()), its
+ //! API calls return EX_ILLEGAL_STATE.
+ class HubInterface : public BnEndpointCommunication {
+ public:
+ HubInterface(ContextHub& hal, const std::shared_ptr<IEndpointCallback>& in_callback,
+ const HubInfo& in_hubInfo)
+ : mHal(hal), mEndpointCallback(in_callback), kInfo(in_hubInfo) {}
+ ~HubInterface() = default;
+
+ ::ndk::ScopedAStatus registerEndpoint(const EndpointInfo& in_endpoint) override;
+ ::ndk::ScopedAStatus unregisterEndpoint(const EndpointInfo& in_endpoint) override;
+ ::ndk::ScopedAStatus requestSessionIdRange(int32_t in_size,
+ std::array<int32_t, 2>* _aidl_return) override;
+ ::ndk::ScopedAStatus openEndpointSession(
+ int32_t in_sessionId, const EndpointId& in_destination,
+ const EndpointId& in_initiator,
+ const std::optional<std::string>& in_serviceDescriptor) override;
+ ::ndk::ScopedAStatus sendMessageToEndpoint(int32_t in_sessionId,
+ const Message& in_msg) override;
+ ::ndk::ScopedAStatus sendMessageDeliveryStatusToEndpoint(
+ int32_t in_sessionId, const MessageDeliveryStatus& in_msgStatus) override;
+ ::ndk::ScopedAStatus closeEndpointSession(int32_t in_sessionId, Reason in_reason) override;
+ ::ndk::ScopedAStatus endpointSessionOpenComplete(int32_t in_sessionId) override;
+ ::ndk::ScopedAStatus unregister() override;
+
+ private:
+ //! ContextHub deactivates hubs (mActive) when entering test mode.
+ friend class ContextHub;
+
+ struct EndpointSession {
+ int32_t sessionId;
+ EndpointId initiator;
+ EndpointId peer;
+ std::optional<std::string> serviceDescriptor;
+ };
+
+ //! Finds an endpoint in the range defined by the endpoints
+ //! @return whether the endpoint was found
+ template <typename Iter>
+ bool findEndpoint(const EndpointId& target, const Iter& begin, const Iter& end) {
+ for (auto iter = begin; iter != end; ++iter) {
+ if (iter->id.id == target.id && iter->id.hubId == target.hubId) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ //! HAL instance that created this hub; holds the hub registry and session ID allocator.
+ ContextHub& mHal;
+ std::shared_ptr<IEndpointCallback> mEndpointCallback;
+ //! Hub information supplied at registration.
+ const HubInfo kInfo;
+
+ //! Cleared by unregister() and ContextHub::setTestMode(); API calls fail once false.
+ std::atomic<bool> mActive = true;
+
+ //! Endpoint storage and information
+ std::mutex mEndpointMutex;
+ std::vector<EndpointInfo> mEndpoints;
+ std::vector<EndpointSession> mEndpointSessions;
+ //! Inclusive session ID range granted by requestSessionIdRange(). Starts as an empty range
+ //! (base > max) so no session ID is accepted before a range has been requested.
+ uint16_t mBaseSessionId = 1;
+ uint16_t mMaxSessionId = 0;
};
static constexpr uint32_t kMockHubId = 0;
- //! Finds an endpoint in the range defined by the endpoints
- //! @return whether the endpoint was found
- template <typename Iter>
- bool findEndpoint(const EndpointId& target, const Iter& begin, const Iter& end) {
- for (auto iter = begin; iter != end; ++iter) {
- if (iter->id.id == target.id && iter->id.hubId == target.hubId) {
- return true;
- }
- }
- return false;
- }
-
std::shared_ptr<IContextHubCallback> mCallback;
std::unordered_set<char16_t> mConnectedHostEndpoints;
- //! Endpoint storage and information
- std::mutex mEndpointMutex;
- std::vector<EndpointInfo> mEndpoints;
- std::vector<EndpointSession> mEndpointSessions;
- std::shared_ptr<IEndpointCallback> mEndpointCallback;
- int32_t mMaxValidSessionId = 0;
+ //! Guards mIdToHostHub and mNextSessionIdBase.
+ std::mutex mHostHubsLock;
+ //! Host hubs registered through registerEndpointHub(), keyed by hub ID.
+ std::unordered_map<int64_t, std::shared_ptr<HubInterface>> mIdToHostHub;
+ //! First session ID of the next range handed out by requestSessionIdRange().
+ int32_t mNextSessionIdBase = 0;
};
} // namespace contexthub