libbinder: support server-specific session
When multiple clients connect to a server, we need a way to tell these
clients apart. Having a per-client root object is the easiest way to do
this. (The alternative, using getCalling* as is done in kernel binder,
isn't so great because it requires global/thread-local state; given that
many RpcSession objects can be created, and that these can also be used
in conjunction with kernel binder, it is complicated to figure out
exactly where to call getCalling*.)
Bug: 199259751
Test: binderRpcTest
Change-Id: I5727db618b5ea138bfa19e75ed915f6a6991518e
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp
index 967b8e3..4edc202 100644
--- a/libs/binder/RpcServer.cpp
+++ b/libs/binder/RpcServer.cpp
@@ -127,14 +127,23 @@
void RpcServer::setRootObject(const sp<IBinder>& binder) {
std::lock_guard<std::mutex> _l(mLock);
+ mRootObjectFactory = nullptr; // drop any per-session factory; the fixed root assigned below replaces it
mRootObjectWeak = mRootObject = binder;
}
void RpcServer::setRootObjectWeak(const wp<IBinder>& binder) {
std::lock_guard<std::mutex> _l(mLock);
mRootObject.clear();
+ mRootObjectFactory = nullptr; // drop any per-session factory; only the weak root assigned below stays set
mRootObjectWeak = binder;
}
+void RpcServer::setPerSessionRootObject(
+ std::function<sp<IBinder>(const sockaddr*, socklen_t)>&& makeObject) {
+ std::lock_guard<std::mutex> _l(mLock); // same lock the other root-object setters take
+ mRootObject.clear(); // a factory replaces any previously set strong root...
+ mRootObjectWeak.clear(); // ...and any previously set weak root
+ mRootObjectFactory = std::move(makeObject); // invoked from establishConnection when a new session is created
+}
sp<IBinder> RpcServer::getRootObject() {
std::lock_guard<std::mutex> _l(mLock);
@@ -174,8 +183,14 @@
status_t status;
while ((status = mShutdownTrigger->triggerablePoll(mServer, POLLIN)) == OK) {
- unique_fd clientFd(TEMP_FAILURE_RETRY(
- accept4(mServer.get(), nullptr, nullptr /*length*/, SOCK_CLOEXEC | SOCK_NONBLOCK)));
+ sockaddr_storage addr;
+ socklen_t addrLen = sizeof(addr);
+
+ unique_fd clientFd(
+ TEMP_FAILURE_RETRY(accept4(mServer.get(), reinterpret_cast<sockaddr*>(&addr),
+ &addrLen, SOCK_CLOEXEC | SOCK_NONBLOCK)));
+
+ LOG_ALWAYS_FATAL_IF(addrLen > static_cast<socklen_t>(sizeof(addr)), "Truncated address");
if (clientFd < 0) {
ALOGE("Could not accept4 socket: %s", strerror(errno));
@@ -187,7 +202,7 @@
std::lock_guard<std::mutex> _l(mLock);
std::thread thread =
std::thread(&RpcServer::establishConnection, sp<RpcServer>::fromExisting(this),
- std::move(clientFd));
+ std::move(clientFd), addr, addrLen);
mConnectingThreads[thread.get_id()] = std::move(thread);
}
}
@@ -257,7 +272,8 @@
return mConnectingThreads.size();
}
-void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd) {
+void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd,
+ const sockaddr_storage addr, socklen_t addrLen) {
// TODO(b/183988761): cannot trust this simple ID
LOG_ALWAYS_FATAL_IF(!server->mAgreedExperimental, "no!");
@@ -383,11 +399,23 @@
session = RpcSession::make();
session->setMaxIncomingThreads(server->mMaxThreads);
if (!session->setProtocolVersion(protocolVersion)) return;
+
+ // if null, falls back to server root
+ sp<IBinder> sessionSpecificRoot;
+ if (server->mRootObjectFactory != nullptr) {
+ sessionSpecificRoot =
+ server->mRootObjectFactory(reinterpret_cast<const sockaddr*>(&addr),
+ addrLen);
+ if (sessionSpecificRoot == nullptr) {
+ ALOGE("Warning: server returned null from root object factory");
+ }
+ }
+
if (!session->setForServer(server,
sp<RpcServer::EventListener>::fromExisting(
static_cast<RpcServer::EventListener*>(
server.get())),
- sessionId)) {
+ sessionId, sessionSpecificRoot)) {
ALOGE("Failed to attach server to session");
return;
}
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index 9eef3e8..137411b 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -700,7 +700,8 @@
}
bool RpcSession::setForServer(const wp<RpcServer>& server, const wp<EventListener>& eventListener,
- const std::vector<uint8_t>& sessionId) {
+ const std::vector<uint8_t>& sessionId,
+ const sp<IBinder>& sessionSpecificRoot) {
LOG_ALWAYS_FATAL_IF(mForServer != nullptr);
LOG_ALWAYS_FATAL_IF(server == nullptr);
LOG_ALWAYS_FATAL_IF(mEventListener != nullptr);
@@ -713,6 +714,7 @@
mId = sessionId;
mForServer = server;
mEventListener = eventListener;
+ mSessionSpecificRootObject = sessionSpecificRoot;
return true;
}
diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp
index 9ba64f3..09b3d68 100644
--- a/libs/binder/RpcState.cpp
+++ b/libs/binder/RpcState.cpp
@@ -870,7 +870,9 @@
if (server) {
switch (transaction->code) {
case RPC_SPECIAL_TRANSACT_GET_ROOT: {
- replyStatus = reply.writeStrongBinder(server->getRootObject());
+ sp<IBinder> root = session->mSessionSpecificRootObject
+ ?: server->getRootObject();
+ replyStatus = reply.writeStrongBinder(root);
break;
}
default: {
diff --git a/libs/binder/include/binder/RpcServer.h b/libs/binder/include/binder/RpcServer.h
index fb2cf23..04c6b59 100644
--- a/libs/binder/include/binder/RpcServer.h
+++ b/libs/binder/include/binder/RpcServer.h
@@ -130,6 +130,10 @@
* Holds a weak reference to the root object.
*/
void setRootObjectWeak(const wp<IBinder>& binder);
+ /**
+ * Allows a root object to be created for each session.
+ *
+ * The callback is invoked when a new session is set up. It receives the
+ * socket address and address length filled in when the connection was
+ * accepted, and returns the root object to serve to that session. If it
+ * returns null, an error is logged and root requests for that session
+ * fall back to getRootObject().
+ *
+ * Setting a factory clears any root object previously set with
+ * setRootObject or setRootObjectWeak; calling either of those later
+ * clears the factory.
+ */
+ void setPerSessionRootObject(std::function<sp<IBinder>(const sockaddr*, socklen_t)>&& object);
sp<IBinder> getRootObject();
/**
@@ -179,7 +183,8 @@
void onSessionAllIncomingThreadsEnded(const sp<RpcSession>& session) override;
void onSessionIncomingThreadEnded() override;
- static void establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd);
+ static void establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd,
+ const sockaddr_storage addr, socklen_t addrLen);
status_t setupSocketServer(const RpcSocketAddress& address);
const std::unique_ptr<RpcTransportCtx> mCtx;
@@ -194,6 +199,7 @@
std::map<std::thread::id, std::thread> mConnectingThreads;
sp<IBinder> mRootObject;
wp<IBinder> mRootObjectWeak;
+ std::function<sp<IBinder>(const sockaddr*, socklen_t)> mRootObjectFactory;
std::map<std::vector<uint8_t>, sp<RpcSession>> mSessions;
std::unique_ptr<FdTrigger> mShutdownTrigger;
std::condition_variable mShutdownCv;
diff --git a/libs/binder/include/binder/RpcSession.h b/libs/binder/include/binder/RpcSession.h
index f5505da..f9b733d 100644
--- a/libs/binder/include/binder/RpcSession.h
+++ b/libs/binder/include/binder/RpcSession.h
@@ -256,7 +256,8 @@
bool init);
[[nodiscard]] bool setForServer(const wp<RpcServer>& server,
const wp<RpcSession::EventListener>& eventListener,
- const std::vector<uint8_t>& sessionId);
+ const std::vector<uint8_t>& sessionId,
+ const sp<IBinder>& sessionSpecificRoot);
sp<RpcConnection> assignIncomingConnectionToThisThread(
std::unique_ptr<RpcTransport> rpcTransport);
[[nodiscard]] bool removeIncomingConnection(const sp<RpcConnection>& connection);
@@ -313,6 +314,10 @@
sp<WaitForShutdownListener> mShutdownListener; // used for client sessions
wp<EventListener> mEventListener; // mForServer if server, mShutdownListener if client
+ // session-specific root object (if a different root is used for each
+ // session)
+ sp<IBinder> mSessionSpecificRootObject;
+
std::vector<uint8_t> mId;
std::unique_ptr<FdTrigger> mShutdownTrigger;
diff --git a/libs/binder/tests/IBinderRpcTest.aidl b/libs/binder/tests/IBinderRpcTest.aidl
index 9e10788..fdd02a4 100644
--- a/libs/binder/tests/IBinderRpcTest.aidl
+++ b/libs/binder/tests/IBinderRpcTest.aidl
@@ -18,6 +18,9 @@
oneway void sendString(@utf8InCpp String str);
@utf8InCpp String doubleString(@utf8InCpp String str);
+ // get the port that a client used to connect to this object
+ int getClientPort();
+
// number of known RPC binders to process, RpcState::countBinders by session
int[] countBinders();
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index 8267702..bc23b15 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -174,6 +174,7 @@
class MyBinderRpcTest : public BnBinderRpcTest {
public:
wp<RpcServer> server;
+ int port = 0;
Status sendString(const std::string& str) override {
(void)str;
@@ -183,6 +184,10 @@
*strstr = str + str;
return Status::ok();
}
+ Status getClientPort(int* out) override { // reports the value stored in this object's `port` field (0 unless set)
+ *out = port;
+ return Status::ok();
+ }
Status countBinders(std::vector<int32_t>* out) override {
sp<RpcServer> spServer = server.promote();
if (spServer == nullptr) {
@@ -643,13 +648,39 @@
BinderRpcTestProcessSession createRpcTestSocketServerProcess(const Options& options) {
BinderRpcTestProcessSession ret{
- .proc = createRpcTestSocketServerProcess(options,
- [&](const sp<RpcServer>& server) {
- sp<MyBinderRpcTest> service =
- new MyBinderRpcTest;
- server->setRootObject(service);
- service->server = server;
- }),
+ .proc = createRpcTestSocketServerProcess(
+ options,
+ [&](const sp<RpcServer>& server) {
+ server->setPerSessionRootObject([&](const sockaddr* addr,
+                                      socklen_t len) {
+     // Builds one service object per session and records the port
+     // found in the accepted connection's socket address, so a test
+     // can tell sessions apart.
+     sp<MyBinderRpcTest> service = sp<MyBinderRpcTest>::make();
+     switch (addr->sa_family) {
+         case AF_UNIX:
+             // nothing to save
+             break;
+         case AF_VSOCK:
+             CHECK_EQ(len, sizeof(sockaddr_vm));
+             service->port = reinterpret_cast<const sockaddr_vm*>(addr)
+                                     ->svm_port;
+             break;
+         case AF_INET:
+             CHECK_EQ(len, sizeof(sockaddr_in));
+             // stored as-is (network byte order); only compared for inequality
+             service->port = reinterpret_cast<const sockaddr_in*>(addr)
+                                     ->sin_port;
+             break;
+         case AF_INET6:
+             // the length must match the sockaddr_in6 the address is
+             // read as below, not the smaller sockaddr_in
+             CHECK_EQ(len, sizeof(sockaddr_in6));
+             service->port = reinterpret_cast<const sockaddr_in6*>(addr)
+                                     ->sin6_port;
+             break;
+         default:
+             LOG_ALWAYS_FATAL("Unrecognized address family %d",
+                              addr->sa_family);
+     }
+     // NOTE(review): `server` is captured by reference and this callback is
+     // stored on the server; presumably the referenced sp outlives it — confirm.
+     service->server = server;
+     return service;
+ });
+ }),
};
ret.rootBinder = ret.proc.sessions.at(0).root;
@@ -682,6 +713,27 @@
}
}
+TEST_P(BinderRpc, SeparateRootObject) {
+    // Checks that two sessions to the same server are handed different
+    // root objects, by comparing the client port each root reports.
+    SocketType type = std::get<0>(GetParam());
+    if (type == SocketType::PRECONNECTED || type == SocketType::UNIX) {
+        // we can't get port numbers for unix sockets
+        return;
+    }
+
+    auto proc = createRpcTestSocketServerProcess({.numSessions = 2});
+
+    int port1 = 0;
+    EXPECT_OK(proc.rootIface->getClientPort(&port1));
+
+    sp<IBinderRpcTest> rootIface2 = interface_cast<IBinderRpcTest>(proc.proc.sessions.at(1).root);
+    // initialized like port1 so the comparison below never reads an
+    // indeterminate value if the call does not write it
+    int port2 = 0;
+    EXPECT_OK(rootIface2->getClientPort(&port2));
+
+    // we should have a different IBinderRpcTest object created for each
+    // session, because we use setPerSessionRootObject
+    EXPECT_NE(port1, port2);
+}
+
TEST_P(BinderRpc, TransactionsMustBeMarkedRpc) {
auto proc = createRpcTestSocketServerProcess({});
Parcel data;