libbinder: dynamically accept clients

The server listens on a single port and adds clients dynamically.

The server looks like this:

    while True:
        accept client
        read client id
        if new id:
            create new rpc connection
        else:
            attach thread to existing rpc connection
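
For reference, the client side of this handshake (setupSocketClient
below) is roughly:

    connect first socket, send RPC_CONNECTION_ID_NEW
    learn the assigned id (see readId and
        RPC_SPECIAL_TRANSACT_GET_CONNECTION_ID)
    for each additional thread:
        connect another socket, send the assigned id

The id is sent as a raw int32_t immediately after connect.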

Roadmap:
- have clients add connections only when needed (currently they are
  all added at initialization time) - when this change is made, the
  server will also need to enforce the max threads per client.
- allow RpcConnection to create reverse connections with a
  threadpool to serve calls in the other direction
- replace connection IDs with something like TLS
- add access controls for who can connect to whom in the pKVM context

Bug: 185167543
Test: binderRpcTest
Change-Id: I510d23a50cf839c39bc8107c1b0dae24dee3bc7b
diff --git a/libs/binder/RpcConnection.cpp b/libs/binder/RpcConnection.cpp
index 95eba87..4b3a53f 100644
--- a/libs/binder/RpcConnection.cpp
+++ b/libs/binder/RpcConnection.cpp
@@ -133,6 +133,21 @@
     return OK;
 }
 
+void RpcConnection::startThread(unique_fd client) {
+    // holding the lock here ensures the thread cannot look itself up in
+    // mThreads before it is inserted below
+    std::lock_guard<std::mutex> _l(mSocketMutex);
+    sp<RpcConnection> holdThis = sp<RpcConnection>::fromExisting(this);
+    int fd = client.release();
+    auto thread = std::thread([=] {
+        holdThis->join(unique_fd(fd));
+        {
+            std::lock_guard<std::mutex> _l(holdThis->mSocketMutex);
+            auto it = holdThis->mThreads.find(std::this_thread::get_id());
+            LOG_ALWAYS_FATAL_IF(it == holdThis->mThreads.end(),
+                                "Could not find thread.");
+            // detach before erasing - destroying a joinable std::thread
+            // calls std::terminate
+            it->second.detach();
+            holdThis->mThreads.erase(it);
+        }
+    });
+    mThreads[thread.get_id()] = std::move(thread);
+}
+
 void RpcConnection::join(unique_fd client) {
     // must be registered to allow arbitrary client code executing commands to
     // be able to do nested calls (we can't only read from it)
@@ -164,7 +179,7 @@
                             mClients.size());
     }
 
-    if (!setupOneSocketClient(addr)) return false;
+    if (!setupOneSocketClient(addr, RPC_CONNECTION_ID_NEW)) return false;
 
     // TODO(b/185167543): we should add additional connections dynamically
     // instead of all at once.
@@ -186,7 +201,7 @@
     for (size_t i = 0; i + 1 < numThreadsAvailable; i++) {
         // TODO(b/185167543): avoid race w/ accept4 not being called on server
         for (size_t tries = 0; tries < 5; tries++) {
-            if (setupOneSocketClient(addr)) break;
+            if (setupOneSocketClient(addr, mId.value())) break;
             usleep(10000);
         }
     }
@@ -194,7 +209,7 @@
     return true;
 }
 
-bool RpcConnection::setupOneSocketClient(const RpcSocketAddress& addr) {
+bool RpcConnection::setupOneSocketClient(const RpcSocketAddress& addr, int32_t id) {
     unique_fd serverFd(
             TEMP_FAILURE_RETRY(socket(addr.addr()->sa_family, SOCK_STREAM | SOCK_CLOEXEC, 0)));
     if (serverFd == -1) {
@@ -209,6 +224,13 @@
         return false;
     }
 
+    // handshake: tell the server which existing connection to attach this
+    // socket to, or RPC_CONNECTION_ID_NEW to start a new connection
+    if (sizeof(id) != TEMP_FAILURE_RETRY(write(serverFd.get(), &id, sizeof(id)))) {
+        int savedErrno = errno;
+        ALOGE("Could not write id to socket at %s: %s", addr.toString().c_str(),
+              strerror(savedErrno));
+        return false;
+    }
+
     LOG_RPC_DETAIL("Socket at %s client with fd %d", addr.toString().c_str(), serverFd.get());
 
     addClient(std::move(serverFd));
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp
index 5f024ca..4df12ce 100644
--- a/libs/binder/RpcServer.cpp
+++ b/libs/binder/RpcServer.cpp
@@ -126,40 +126,61 @@
     {
         std::lock_guard<std::mutex> _l(mLock);
         LOG_ALWAYS_FATAL_IF(mServer.get() == -1, "RpcServer must be setup to join.");
-        // TODO(b/185167543): support more than one client at once
-        mConnection = RpcConnection::make();
-        mConnection->setForServer(sp<RpcServer>::fromExisting(this), 42 /*placeholder id*/);
-
-        mStarted = true;
-            for (size_t i = 0; i < mMaxThreads; i++) {
-                pool.push_back(std::thread([=] {
-                    // TODO(b/185167543): do this dynamically, instead of from a static number
-                    // of threads
-                    unique_fd clientFd(TEMP_FAILURE_RETRY(
-                            accept4(mServer.get(), nullptr, 0 /*length*/, SOCK_CLOEXEC)));
-                    if (clientFd < 0) {
-                        // If this log becomes confusing, should save more state from
-                        // setupUnixDomainServer in order to output here.
-                        ALOGE("Could not accept4 socket: %s", strerror(errno));
-                        return;
-                    }
-
-                    LOG_RPC_DETAIL("accept4 on fd %d yields fd %d", mServer.get(), clientFd.get());
-
-                    mConnection->join(std::move(clientFd));
-                }));
-            }
     }
 
-    // TODO(b/185167543): don't waste extra thread for join, and combine threads
-    // between clients
-    for (auto& t : pool) t.join();
+    while (true) {
+        unique_fd clientFd(
+                TEMP_FAILURE_RETRY(accept4(mServer.get(), nullptr, 0 /*length*/, SOCK_CLOEXEC)));
+
+        if (clientFd < 0) {
+            ALOGE("Could not accept4 socket: %s", strerror(errno));
+            continue;
+        }
+        LOG_RPC_DETAIL("accept4 on fd %d yields fd %d", mServer.get(), clientFd.get());
+
+        // TODO(b/183988761): cannot trust this simple ID
+        LOG_ALWAYS_FATAL_IF(!mAgreedExperimental, "no!");
+        int32_t id;
+        if (sizeof(id) != TEMP_FAILURE_RETRY(read(clientFd.get(), &id, sizeof(id)))) {
+            ALOGE("Could not read ID from fd %d", clientFd.get());
+            continue;
+        }
+
+        {
+            std::lock_guard<std::mutex> _l(mLock);
+
+            sp<RpcConnection> connection;
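+            // either create a new connection for this client, or attach
+            // another thread to the connection it already has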
+            if (id == RPC_CONNECTION_ID_NEW) {
+                // new client!
+                LOG_ALWAYS_FATAL_IF(mConnectionIdCounter >= INT32_MAX, "Out of connection IDs");
+                mConnectionIdCounter++;
+
+                connection = RpcConnection::make();
+                connection->setForServer(wp<RpcServer>::fromExisting(this), mConnectionIdCounter);
+
+                mConnections[mConnectionIdCounter] = connection;
+            } else {
+                auto it = mConnections.find(id);
+                if (it == mConnections.end()) {
+                    ALOGE("Cannot add thread, no record of connection with ID %d", id);
+                    continue;
+                }
+                connection = it->second;
+            }
+
+            connection->startThread(std::move(clientFd));
+        }
+    }
 }
 
 std::vector<sp<RpcConnection>> RpcServer::listConnections() {
     std::lock_guard<std::mutex> _l(mLock);
-    if (mConnection == nullptr) return {};
-    return {mConnection};
+    std::vector<sp<RpcConnection>> connections;
+    for (auto& [id, connection] : mConnections) {
+        (void)id;
+        connections.push_back(connection);
+    }
+    return connections;
 }
 
 bool RpcServer::setupSocketServer(const RpcSocketAddress& addr) {
diff --git a/libs/binder/RpcWireFormat.h b/libs/binder/RpcWireFormat.h
index 56af0d3..a7e8a52 100644
--- a/libs/binder/RpcWireFormat.h
+++ b/libs/binder/RpcWireFormat.h
@@ -51,6 +51,8 @@
     RPC_SPECIAL_TRANSACT_GET_CONNECTION_ID = 2,
 };
 
+// sent by a client in place of a connection ID to request a new connection
+constexpr int32_t RPC_CONNECTION_ID_NEW = -1;
+
 // serialization is like:
 // |RpcWireHeader|struct designated by 'command'| (over and over again)
 
diff --git a/libs/binder/include/binder/RpcConnection.h b/libs/binder/include/binder/RpcConnection.h
index 7e31e8a..87984d7 100644
--- a/libs/binder/include/binder/RpcConnection.h
+++ b/libs/binder/include/binder/RpcConnection.h
@@ -21,7 +21,9 @@
 #include <utils/Errors.h>
 #include <utils/RefBase.h>
 
+#include <map>
 #include <optional>
+#include <thread>
 #include <vector>
 
 // WARNING: This is a feature which is still in development, and it is subject
@@ -113,6 +115,7 @@
 
     status_t readId();
 
+    void startThread(base::unique_fd client);
     void join(base::unique_fd client);
 
     struct ConnectionSocket : public RefBase {
@@ -124,7 +127,7 @@
     };
 
     bool setupSocketClient(const RpcSocketAddress& address);
-    bool setupOneSocketClient(const RpcSocketAddress& address);
+    bool setupOneSocketClient(const RpcSocketAddress& address, int32_t connectionId);
     void addClient(base::unique_fd fd);
     void setForServer(const wp<RpcServer>& server, int32_t connectionId);
     sp<ConnectionSocket> assignServerToThisThread(base::unique_fd fd);
@@ -179,11 +182,18 @@
     std::unique_ptr<RpcState> mState;
 
     std::mutex mSocketMutex;           // for all below
+
     std::condition_variable mSocketCv; // for mWaitingThreads
     size_t mWaitingThreads = 0;
     size_t mClientsOffset = 0; // hint index into clients, ++ when sending an async transaction
     std::vector<sp<ConnectionSocket>> mClients;
     std::vector<sp<ConnectionSocket>> mServers;
+
+    // TODO(b/185167543): use for reverse connections (allow client to also
+    // serve calls on a connection).
+    // TODO(b/185167543): allow sharing between different connections in a
+    // process? (or combine with mServers)
+    std::map<std::thread::id, std::thread> mThreads;
 };
 
 } // namespace android
diff --git a/libs/binder/include/binder/RpcServer.h b/libs/binder/include/binder/RpcServer.h
index 5535d8a..81ea3a7 100644
--- a/libs/binder/include/binder/RpcServer.h
+++ b/libs/binder/include/binder/RpcServer.h
@@ -97,6 +97,8 @@
 
     /**
      * You must have at least one client connection before calling this.
+     *
+     * TODO(b/185167543): way to shut down?
      */
     void join();
 
@@ -120,7 +122,8 @@
 
     std::mutex mLock; // for below
     sp<IBinder> mRootObject;
-    sp<RpcConnection> mConnection;
+    std::map<int32_t, sp<RpcConnection>> mConnections; // keyed by connection ID
+    int32_t mConnectionIdCounter = 0;
 };
 
 } // namespace android
diff --git a/libs/binder/tests/IBinderRpcTest.aidl b/libs/binder/tests/IBinderRpcTest.aidl
index 2bdb264..814e094 100644
--- a/libs/binder/tests/IBinderRpcTest.aidl
+++ b/libs/binder/tests/IBinderRpcTest.aidl
@@ -18,8 +18,8 @@
     oneway void sendString(@utf8InCpp String str);
     @utf8InCpp String doubleString(@utf8InCpp String str);
 
-    // number of known RPC binders to process, RpcState::countBinders
-    int countBinders();
+    // number of RPC binders known to the process, per connection (RpcState::countBinders)
+    int[] countBinders();
 
     // Caller sends server, callee pings caller's server and returns error code.
     int pingMe(IBinder binder);
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index d23df8e..50bff91 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -88,24 +88,21 @@
         *strstr = str + str;
         return Status::ok();
     }
-    Status countBinders(int32_t* out) override {
+    Status countBinders(std::vector<int32_t>* out) override {
         sp<RpcServer> spServer = server.promote();
         if (spServer == nullptr) {
             return Status::fromExceptionCode(Status::EX_NULL_POINTER);
         }
-        size_t count = 0;
+        out->clear();
         for (auto connection : spServer->listConnections()) {
-            count += connection->state()->countBinders();
-        }
-        // help debugging if we don't have one binder (this call is always made
-        // in this test when exactly one binder is held, which is held only to
-        // call this method - all other binders should be cleaned up)
-        if (count != 1) {
-            for (auto connection : spServer->listConnections()) {
+            size_t count = connection->state()->countBinders();
+            if (count != 1) {
+                // this is only called when exactly one binder should remain
+                // held, so dump state to aid debugging
                 connection->state()->dump();
             }
+            out->push_back(count);
         }
-        *out = count;
         return Status::ok();
     }
     Status pingMe(const sp<IBinder>& binder, int32_t* out) override {
@@ -232,25 +229,33 @@
     // reference to process hosting a socket server
     Process host;
 
-    // client connection object associated with other process
-    sp<RpcConnection> connection;
+    struct ConnectionInfo {
+        sp<RpcConnection> connection;
+        sp<IBinder> root;
+    };
 
-    // pre-fetched root object
-    sp<IBinder> rootBinder;
-
-    // whether connection should be invalidated by end of run
-    bool expectInvalid = false;
+    // client connection objects associated with other process
+    // each one represents a separate connection
+    std::vector<ConnectionInfo> connections;
 
     ProcessConnection(ProcessConnection&&) = default;
     ~ProcessConnection() {
-        rootBinder = nullptr;
-        EXPECT_NE(nullptr, connection);
-        EXPECT_NE(nullptr, connection->state());
-        EXPECT_EQ(0, connection->state()->countBinders()) << (connection->state()->dump(), "dump:");
+        for (auto& connection : connections) {
+            connection.root = nullptr;
+        }
 
-        wp<RpcConnection> weakConnection = connection;
-        connection = nullptr;
-        EXPECT_EQ(nullptr, weakConnection.promote()) << "Leaked connection";
+        for (auto& info : connections) {
+            sp<RpcConnection>& connection = info.connection;
+
+            EXPECT_NE(nullptr, connection);
+            EXPECT_NE(nullptr, connection->state());
+            EXPECT_EQ(0, connection->state()->countBinders())
+                    << (connection->state()->dump(), "dump:");
+
+            wp<RpcConnection> weakConnection = connection;
+            connection = nullptr;
+            EXPECT_EQ(nullptr, weakConnection.promote()) << "Leaked connection";
+        }
     }
 };
 
@@ -259,19 +264,25 @@
 struct BinderRpcTestProcessConnection {
     ProcessConnection proc;
 
-    // pre-fetched root object
+    // pre-fetched root object (for first connection)
     sp<IBinder> rootBinder;
 
-    // pre-casted root object
+    // pre-casted root object (for first connection)
     sp<IBinderRpcTest> rootIface;
 
+    // whether connection should be invalidated by end of run
+    bool expectInvalid = false;
+
     BinderRpcTestProcessConnection(BinderRpcTestProcessConnection&&) = default;
     ~BinderRpcTestProcessConnection() {
-        if (!proc.expectInvalid) {
-            int32_t remoteBinders = 0;
-            EXPECT_OK(rootIface->countBinders(&remoteBinders));
-            // should only be the root binder object, iface
-            EXPECT_EQ(remoteBinders, 1);
+        if (!expectInvalid) {
+            std::vector<int32_t> remoteCounts;
+            // calling countBinders() over any one connection counts across all connections
+            EXPECT_OK(rootIface->countBinders(&remoteCounts));
+            EXPECT_EQ(remoteCounts.size(), proc.connections.size());
+            for (auto remoteCount : remoteCounts) {
+                EXPECT_EQ(remoteCount, 1);
+            }
         }
 
         rootIface = nullptr;
@@ -306,7 +317,10 @@
     // This creates a new process serving an interface on a certain number of
     // threads.
     ProcessConnection createRpcTestSocketServerProcess(
-            size_t numThreads, const std::function<void(const sp<RpcServer>&)>& configure) {
+            size_t numThreads, size_t numConnections,
+            const std::function<void(const sp<RpcServer>&)>& configure) {
+        CHECK_GE(numConnections, 1) << "Must have at least one connection to a server";
+
         SocketType socketType = GetParam();
 
         std::string addr = allocateSocketAddress();
@@ -346,7 +360,6 @@
 
                     server->join();
                 }),
-                .connection = RpcConnection::make(),
         };
 
         unsigned int inetPort = 0;
@@ -356,35 +369,37 @@
             CHECK_NE(0, inetPort);
         }
 
-        // create remainder of connections
-        for (size_t tries = 0; tries < 10; tries++) {
-            usleep(10000);
-            switch (socketType) {
-                case SocketType::UNIX:
-                    if (ret.connection->setupUnixDomainClient(addr.c_str())) goto success;
-                    break;
+        for (size_t i = 0; i < numConnections; i++) {
+            sp<RpcConnection> connection = RpcConnection::make();
+            for (size_t tries = 0; tries < 10; tries++) {
+                usleep(10000);
+                switch (socketType) {
+                    case SocketType::UNIX:
+                        if (connection->setupUnixDomainClient(addr.c_str())) goto success;
+                        break;
 #ifdef __BIONIC__
-                case SocketType::VSOCK:
-                    if (ret.connection->setupVsockClient(VMADDR_CID_LOCAL, vsockPort)) goto success;
-                    break;
+                    case SocketType::VSOCK:
+                        if (connection->setupVsockClient(VMADDR_CID_LOCAL, vsockPort)) goto success;
+                        break;
 #endif // __BIONIC__
-                case SocketType::INET:
-                    if (ret.connection->setupInetClient("127.0.0.1", inetPort)) goto success;
-                    break;
-                default:
-                    LOG_ALWAYS_FATAL("Unknown socket type");
+                    case SocketType::INET:
+                        if (connection->setupInetClient("127.0.0.1", inetPort)) goto success;
+                        break;
+                    default:
+                        LOG_ALWAYS_FATAL("Unknown socket type");
+                }
             }
+            LOG_ALWAYS_FATAL("Could not connect");
+        success:
+            ret.connections.push_back({connection, connection->getRootObject()});
         }
-        LOG_ALWAYS_FATAL("Could not connect");
-    success:
-
-        ret.rootBinder = ret.connection->getRootObject();
         return ret;
     }
 
-    BinderRpcTestProcessConnection createRpcTestSocketServerProcess(size_t numThreads) {
+    BinderRpcTestProcessConnection createRpcTestSocketServerProcess(size_t numThreads,
+                                                                    size_t numConnections = 1) {
         BinderRpcTestProcessConnection ret{
-                .proc = createRpcTestSocketServerProcess(numThreads,
+                .proc = createRpcTestSocketServerProcess(numThreads, numConnections,
                                                          [&](const sp<RpcServer>& server) {
                                                              sp<MyBinderRpcTest> service =
                                                                      new MyBinderRpcTest;
@@ -393,7 +408,7 @@
                                                          }),
         };
 
-        ret.rootBinder = ret.proc.rootBinder;
+        ret.rootBinder = ret.proc.connections.at(0).root;
         ret.rootIface = interface_cast<IBinderRpcTest>(ret.rootBinder);
 
         return ret;
@@ -401,16 +416,12 @@
 };
 
 TEST_P(BinderRpc, RootObjectIsNull) {
-    auto proc = createRpcTestSocketServerProcess(1, [](const sp<RpcServer>& server) {
+    auto proc = createRpcTestSocketServerProcess(1, 1, [](const sp<RpcServer>& server) {
         // this is the default, but to be explicit
         server->setRootObject(nullptr);
     });
 
-    // retrieved by getRootObject when process is created above
-    EXPECT_EQ(nullptr, proc.rootBinder);
-
-    // make sure we can retrieve it again (process doesn't crash)
-    EXPECT_EQ(nullptr, proc.connection->getRootObject());
+    EXPECT_EQ(nullptr, proc.connections.at(0).root);
 }
 
 TEST_P(BinderRpc, Ping) {
@@ -425,6 +436,14 @@
     EXPECT_EQ(IBinderRpcTest::descriptor, proc.rootBinder->getInterfaceDescriptor());
 }
 
+TEST_P(BinderRpc, MultipleConnections) {
+    auto proc = createRpcTestSocketServerProcess(1 /*threads*/, 5 /*connections*/);
+    for (auto connection : proc.proc.connections) {
+        ASSERT_NE(nullptr, connection.root);
+        EXPECT_EQ(OK, connection.root->pingBinder());
+    }
+}
+
 TEST_P(BinderRpc, TransactionsMustBeMarkedRpc) {
     auto proc = createRpcTestSocketServerProcess(1);
     Parcel data;
@@ -572,6 +591,15 @@
               proc1.rootIface->repeatBinder(proc2.rootBinder, &outBinder).transactionError());
 }
 
+TEST_P(BinderRpc, CannotMixBindersBetweenTwoConnectionsToTheSameServer) {
+    auto proc = createRpcTestSocketServerProcess(1 /*threads*/, 2 /*connections*/);
+
+    sp<IBinder> outBinder;
+    EXPECT_EQ(INVALID_OPERATION,
+              proc.rootIface->repeatBinder(proc.proc.connections.at(1).root, &outBinder)
+                      .transactionError());
+}
+
 TEST_P(BinderRpc, CannotSendRegularBinderOverSocketBinder) {
     auto proc = createRpcTestSocketServerProcess(1);
 
@@ -856,7 +884,7 @@
         EXPECT_EQ(DEAD_OBJECT, proc.rootIface->die(doDeathCleanup).transactionError())
                 << "Do death cleanup: " << doDeathCleanup;
 
-        proc.proc.expectInvalid = true;
+        proc.expectInvalid = true;
     }
 }