Merge "libbinder: finalize connect/server APIs" am: a98286eb6b

Original change: https://android-review.googlesource.com/c/platform/frameworks/native/+/1683484

Change-Id: I715b1b49cd88cfcb293ff52239abac05265e4dfe
diff --git a/libs/binder/RpcConnection.cpp b/libs/binder/RpcConnection.cpp
index 2502d1b..ee5f508 100644
--- a/libs/binder/RpcConnection.cpp
+++ b/libs/binder/RpcConnection.cpp
@@ -139,8 +139,8 @@
     return setupSocketServer(UnixSocketAddress(path));
 }
 
-bool RpcConnection::addUnixDomainClient(const char* path) {
-    return addSocketClient(UnixSocketAddress(path));
+bool RpcConnection::setupUnixDomainClient(const char* path) { // one-time; see setupSocketClient
+    return setupSocketClient(UnixSocketAddress(path));
 }
 
 #ifdef __BIONIC__
@@ -171,8 +171,8 @@
     return setupSocketServer(VsockSocketAddress(kAnyCid, port));
 }
 
-bool RpcConnection::addVsockClient(unsigned int cid, unsigned int port) {
-    return addSocketClient(VsockSocketAddress(cid, port));
+bool RpcConnection::setupVsockClient(unsigned int cid, unsigned int port) { // one-time setup
+    return setupSocketClient(VsockSocketAddress(cid, port));
 }
 
 #endif // __BIONIC__
@@ -240,12 +240,12 @@
     return false;
 }
 
-bool RpcConnection::addInetClient(const char* addr, unsigned int port) {
+bool RpcConnection::setupInetClient(const char* addr, unsigned int port) {
     auto aiStart = GetAddrInfo(addr, port);
     if (aiStart == nullptr) return false;
     for (auto ai = aiStart.get(); ai != nullptr; ai = ai->ai_next) {
         InetSocketAddress socketAddress(ai->ai_addr, ai->ai_addrlen, addr, port);
-        if (addSocketClient(socketAddress)) return true;
+        if (setupSocketClient(socketAddress)) return true;
     }
     ALOGE("None of the socket address resolved for %s:%u can be added as inet client.", addr, port);
     return false;
@@ -268,6 +268,11 @@
     return state()->getRootObject(socket.fd(), sp<RpcConnection>::fromExisting(this));
 }
 
+status_t RpcConnection::getMaxThreads(size_t* maxThreads) { // queries the remote side
+    ExclusiveSocket socket(sp<RpcConnection>::fromExisting(this), SocketUse::CLIENT);
+    return state()->getMaxThreads(socket.fd(), sp<RpcConnection>::fromExisting(this), maxThreads);
+}
+
 status_t RpcConnection::transact(const RpcAddress& address, uint32_t code, const Parcel& data,
                                  Parcel* reply, uint32_t flags) {
     ExclusiveSocket socket(sp<RpcConnection>::fromExisting(this),
@@ -348,7 +353,60 @@
     return true;
 }
 
-bool RpcConnection::addSocketClient(const SocketAddress& addr) {
+// Sets up this connection as a client of the server at `addr`. One socket is
+// opened first and used to ask the server for its thread count (getMaxThreads);
+// then one more socket is opened per remaining server thread. May only be
+// called once per connection (enforced fatally below).
+//
+// Returns false if the first socket cannot be opened or the thread count
+// cannot be queried. NOTE(review): in the latter case the first socket has
+// presumably already been registered in mClients, so calling this again would
+// trip the "setup once" check - confirm before retrying around this call.
+// Failing to open an additional socket is logged but not fatal; the connection
+// then simply has fewer sockets than the server has threads.
+bool RpcConnection::setupSocketClient(const SocketAddress& addr) {
+    {
+        std::lock_guard<std::mutex> _l(mSocketMutex);
+        LOG_ALWAYS_FATAL_IF(mClients.size() != 0,
+                            "Must only setup connection once, but already has %zu clients",
+                            mClients.size());
+    }
+
+    if (!setupOneSocketClient(addr)) return false;
+
+    // TODO(b/185167543): we should add additional connections dynamically
+    // instead of all at once.
+    // TODO(b/186470974): first risk of blocking
+    size_t numThreadsAvailable = 0;
+    if (status_t status = getMaxThreads(&numThreadsAvailable); status != OK) {
+        ALOGE("Could not get max threads after initial connection to %s: %s",
+              addr.toString().c_str(), statusToString(status).c_str());
+        return false;
+    }
+
+    // Retry policy for each additional socket (see the accept4 TODO below).
+    constexpr size_t kMaxTries = 5;
+    constexpr useconds_t kRetryDelayUs = 10000;
+
+    // we've already setup one client
+    for (size_t i = 0; i + 1 < numThreadsAvailable; i++) {
+        // TODO(b/185167543): avoid race w/ accept4 not being called on server
+        bool connected = false;
+        for (size_t tries = 0; !connected && tries < kMaxTries; tries++) {
+            if (tries > 0) usleep(kRetryDelayUs); // no delay before the first attempt
+            connected = setupOneSocketClient(addr);
+        }
+        if (!connected) {
+            // Best effort: keep going with the sockets that did connect.
+            ALOGE("Could not set up additional connection to %s after %zu tries",
+                  addr.toString().c_str(), kMaxTries);
+        }
+    }
+
+    return true;
+}
+
+bool RpcConnection::setupOneSocketClient(const SocketAddress& addr) {
     unique_fd serverFd(
             TEMP_FAILURE_RETRY(socket(addr.addr()->sa_family, SOCK_STREAM | SOCK_CLOEXEC, 0)));
     if (serverFd == -1) {
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp
index 9a0be92..8f2805f 100644
--- a/libs/binder/RpcServer.cpp
+++ b/libs/binder/RpcServer.cpp
@@ -19,6 +19,7 @@
 #include <sys/socket.h>
 #include <sys/un.h>
 
+#include <thread>
 #include <vector>
 
 #include <binder/Parcel.h>
@@ -41,16 +42,19 @@
     mAgreedExperimental = true;
 }
 
-sp<RpcConnection> RpcServer::addClientConnection() {
-    LOG_ALWAYS_FATAL_IF(!mAgreedExperimental, "no!");
-
-    auto connection = RpcConnection::make();
-    connection->setForServer(sp<RpcServer>::fromExisting(this));
+void RpcServer::setMaxThreads(size_t threads) {
+    LOG_ALWAYS_FATAL_IF(threads == 0, "RpcServer is useless without threads");
     {
+        // this lock should only ever be needed in the error case
         std::lock_guard<std::mutex> _l(mLock);
-        mConnections.push_back(connection);
+        LOG_ALWAYS_FATAL_IF(!mConnections.empty(),
+                            "Must specify max threads before creating a connection");
     }
-    return connection;
+    mMaxThreads = threads; // sent to clients as an int32 via GET_MAX_THREADS
+}
+
+size_t RpcServer::getMaxThreads() {
+    return mMaxThreads;
 }
 
 void RpcServer::setRootObject(const sp<IBinder>& binder) {
@@ -63,4 +67,44 @@
     return mRootObject;
 }
 
+// Creates the connection for one more client. Only allowed before join() has
+// started serving (TODO(b/185167543): support dynamically added clients).
+sp<RpcConnection> RpcServer::addClientConnection() {
+    LOG_ALWAYS_FATAL_IF(!mAgreedExperimental, "no!");
+
+    auto connection = RpcConnection::make();
+    connection->setForServer(sp<RpcServer>::fromExisting(this));
+    {
+        std::lock_guard<std::mutex> _l(mLock);
+        LOG_ALWAYS_FATAL_IF(mStarted,
+                            "currently only supports adding client connections at creation time");
+        mConnections.push_back(connection);
+    }
+    return connection;
+}
+
+// Serves every client connection added so far: spawns mMaxThreads threads per
+// connection, each running RpcConnection::join(), and blocks until they all
+// return. Once called, addClientConnection() is no longer allowed.
+void RpcServer::join() {
+    std::vector<std::thread> pool;
+    {
+        std::lock_guard<std::mutex> _l(mLock);
+        LOG_ALWAYS_FATAL_IF(mConnections.empty(),
+                            "Must add at least one client connection before calling join");
+        mStarted = true;
+        pool.reserve(mConnections.size() * mMaxThreads);
+        for (const sp<RpcConnection>& connection : mConnections) {
+            for (size_t i = 0; i < mMaxThreads; i++) {
+                // copy the sp so each thread owns a strong reference
+                pool.emplace_back([connection] { connection->join(); });
+            }
+        }
+    }
+
+    // TODO(b/185167543): don't waste extra thread for join, and combine threads
+    // between clients
+    for (auto& t : pool) t.join();
+}
+
 } // namespace android
diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp
index d934136..6bfcc42 100644
--- a/libs/binder/RpcState.cpp
+++ b/libs/binder/RpcState.cpp
@@ -248,6 +248,31 @@
     return reply.readStrongBinder();
 }
 
+// Queries the remote side for its thread count through the special zero address.
+status_t RpcState::getMaxThreads(const base::unique_fd& fd, const sp<RpcConnection>& connection,
+                                 size_t* maxThreads) { // written only on OK
+    Parcel request;
+    request.markForRpc(connection);
+    Parcel response;
+
+    if (status_t err = transact(fd, RpcAddress::zero(), RPC_SPECIAL_TRANSACT_GET_MAX_THREADS,
+                                request, connection, &response, 0);
+        err != OK) {
+        ALOGE("Error getting max threads: %s", statusToString(err).c_str());
+        return err;
+    }
+
+    int32_t remoteThreads = 0;
+    if (status_t err = response.readInt32(&remoteThreads); err != OK) return err;
+    if (remoteThreads > 0) {
+        *maxThreads = static_cast<size_t>(remoteThreads);
+        return OK;
+    }
+
+    ALOGE("Error invalid max threads: %d", remoteThreads);
+    return BAD_VALUE;
+}
+
 status_t RpcState::transact(const base::unique_fd& fd, const RpcAddress& address, uint32_t code,
                             const Parcel& data, const sp<RpcConnection>& connection, Parcel* reply,
                             uint32_t flags) {
@@ -516,23 +541,25 @@
             replyStatus = target->transact(transaction->code, data, &reply, transaction->flags);
         } else {
             LOG_RPC_DETAIL("Got special transaction %u", transaction->code);
-            // special case for 'zero' address (special server commands)
-            switch (transaction->code) {
-                case RPC_SPECIAL_TRANSACT_GET_ROOT: {
-                    sp<IBinder> root;
-                    sp<RpcServer> server = connection->server().promote();
-                    if (server) {
-                        root = server->getRootObject();
-                    } else {
-                        ALOGE("Root object requested, but no server attached.");
-                    }
 
-                    replyStatus = reply.writeStrongBinder(root);
-                    break;
+            sp<RpcServer> server = connection->server().promote();
+            if (server) {
+                // special case for 'zero' address (special server commands)
+                switch (transaction->code) {
+                    case RPC_SPECIAL_TRANSACT_GET_ROOT: {
+                        replyStatus = reply.writeStrongBinder(server->getRootObject());
+                        break;
+                    }
+                    case RPC_SPECIAL_TRANSACT_GET_MAX_THREADS: {
+                        replyStatus = reply.writeInt32(server->getMaxThreads());
+                        break;
+                    }
+                    default: {
+                        replyStatus = UNKNOWN_TRANSACTION;
+                    }
                 }
-                default: {
-                    replyStatus = UNKNOWN_TRANSACTION;
-                }
+            } else {
+                ALOGE("Special command sent, but no server object attached.");
             }
         }
     }
diff --git a/libs/binder/RpcState.h b/libs/binder/RpcState.h
index f4f5151..1cfa406 100644
--- a/libs/binder/RpcState.h
+++ b/libs/binder/RpcState.h
@@ -51,6 +51,8 @@
     ~RpcState();
 
     sp<IBinder> getRootObject(const base::unique_fd& fd, const sp<RpcConnection>& connection);
+    status_t getMaxThreads(const base::unique_fd& fd, const sp<RpcConnection>& connection,
+                           size_t* maxThreadsOut);
 
     [[nodiscard]] status_t transact(const base::unique_fd& fd, const RpcAddress& address,
                                     uint32_t code, const Parcel& data,
diff --git a/libs/binder/RpcWireFormat.h b/libs/binder/RpcWireFormat.h
index 60ec6c9..cc7cacb 100644
--- a/libs/binder/RpcWireFormat.h
+++ b/libs/binder/RpcWireFormat.h
@@ -47,6 +47,7 @@
  */
 enum : uint32_t {
     RPC_SPECIAL_TRANSACT_GET_ROOT = 0,
+    RPC_SPECIAL_TRANSACT_GET_MAX_THREADS = 1, // reply: int32 thread count (> 0)
 };
 
 // serialization is like:
diff --git a/libs/binder/include/binder/RpcConnection.h b/libs/binder/include/binder/RpcConnection.h
index 09aed13..3a2d8e5 100644
--- a/libs/binder/include/binder/RpcConnection.h
+++ b/libs/binder/include/binder/RpcConnection.h
@@ -59,7 +59,7 @@
      * This should be called once per thread, matching 'join' in the remote
      * process.
      */
-    [[nodiscard]] bool addUnixDomainClient(const char* path);
+    [[nodiscard]] bool setupUnixDomainClient(const char* path);
 
 #ifdef __BIONIC__
     /**
@@ -70,7 +70,7 @@
     /**
      * Connects to an RPC server at the CVD & port.
      */
-    [[nodiscard]] bool addVsockClient(unsigned int cvd, unsigned int port);
+    [[nodiscard]] bool setupVsockClient(unsigned int cvd, unsigned int port);
 #endif // __BIONIC__
 
     /**
@@ -87,7 +87,7 @@
     /**
      * Connects to an RPC server at the given address and port.
      */
-    [[nodiscard]] bool addInetClient(const char* addr, unsigned int port);
+    [[nodiscard]] bool setupInetClient(const char* addr, unsigned int port);
 
     /**
      * For debugging!
@@ -104,16 +104,16 @@
      */
     sp<IBinder> getRootObject();
 
+    /**
+     * Query the other side of the connection for the maximum number of threads
+     * it supports (maximum number of concurrent non-nested synchronous transactions)
+     */
+    status_t getMaxThreads(size_t* maxThreads);
+
     [[nodiscard]] status_t transact(const RpcAddress& address, uint32_t code, const Parcel& data,
                                     Parcel* reply, uint32_t flags);
     [[nodiscard]] status_t sendDecStrong(const RpcAddress& address);
 
-    /**
-     * Adds a server thread accepting connections. Must be called after
-     * setup*Server.
-     */
-    void join();
-
     ~RpcConnection();
 
     void setForServer(const wp<RpcServer>& server);
@@ -132,8 +132,11 @@
 
 private:
     friend sp<RpcConnection>;
+    friend RpcServer;
     RpcConnection();
 
+    void join();
+
     struct ConnectionSocket : public RefBase {
         base::unique_fd fd;
 
@@ -143,7 +146,8 @@
     };
 
     bool setupSocketServer(const SocketAddress& address);
-    bool addSocketClient(const SocketAddress& address);
+    bool setupSocketClient(const SocketAddress& address);
+    bool setupOneSocketClient(const SocketAddress& address);
     void addClient(base::unique_fd&& fd);
     sp<ConnectionSocket> assignServerToThisThread(base::unique_fd&& fd);
     bool removeServerSocket(const sp<ConnectionSocket>& socket);
diff --git a/libs/binder/include/binder/RpcServer.h b/libs/binder/include/binder/RpcServer.h
index a665fad..9247128 100644
--- a/libs/binder/include/binder/RpcServer.h
+++ b/libs/binder/include/binder/RpcServer.h
@@ -40,6 +40,24 @@
     void iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction();
 
     /**
+     * This must be called before adding a client connection.
+     *
+     * If this is not specified, this will be a single-threaded server.
+     *
+     * TODO(b/185167543): these are currently created per client, but these
+     * should be shared.
+     */
+    void setMaxThreads(size_t threads);
+    size_t getMaxThreads();
+
+    /**
+     * The root object can be retrieved by any client, without any
+     * authentication. TODO(b/183988761)
+     */
+    void setRootObject(const sp<IBinder>& binder);
+    sp<IBinder> getRootObject();
+
+    /**
      * Setup a static connection, when the number of clients are known.
      *
      * Each call to this function corresponds to a different client, and clients
@@ -50,15 +68,9 @@
     sp<RpcConnection> addClientConnection();
 
     /**
-     * The root object can be retrieved by any client, without any
-     * authentication. TODO(b/183988761)
+     * You must have at least one client connection before calling this.
      */
-    void setRootObject(const sp<IBinder>& binder);
-
-    /**
-     * Root object set with setRootObject
-     */
-    sp<IBinder> getRootObject();
+    void join();
 
     ~RpcServer();
 
@@ -67,8 +79,10 @@
     RpcServer();
 
     bool mAgreedExperimental = false;
+    bool mStarted = false; // TODO(b/185167543): support dynamically added clients
+    size_t mMaxThreads = 1;
 
-    std::mutex mLock;
+    std::mutex mLock; // for below
     sp<IBinder> mRootObject;
     std::vector<sp<RpcConnection>> mConnections; // per-client
 };
diff --git a/libs/binder/tests/binderRpcBenchmark.cpp b/libs/binder/tests/binderRpcBenchmark.cpp
index 7c82226..b3282ff 100644
--- a/libs/binder/tests/binderRpcBenchmark.cpp
+++ b/libs/binder/tests/binderRpcBenchmark.cpp
@@ -127,12 +127,12 @@
         sp<RpcConnection> connection = server->addClientConnection();
         CHECK(connection->setupUnixDomainServer(addr.c_str()));
 
-        connection->join();
+        server->join();
     }).detach();
 
     for (size_t tries = 0; tries < 5; tries++) {
         usleep(10000);
-        if (gConnection->addUnixDomainClient(addr.c_str())) goto success;
+        if (gConnection->setupUnixDomainClient(addr.c_str())) goto success;
     }
     LOG(FATAL) << "Could not connect.";
 success:
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index ce69ea2..f3ec904 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -298,8 +298,6 @@
     ProcessConnection createRpcTestSocketServerProcess(
             size_t numThreads,
             const std::function<void(const sp<RpcServer>&, const sp<RpcConnection>&)>& configure) {
-        CHECK_GT(numThreads, 0);
-
         SocketType socketType = GetParam();
 
         std::string addr = allocateSocketAddress();
@@ -312,6 +310,7 @@
                     sp<RpcServer> server = RpcServer::make();
 
                     server->iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction();
+                    server->setMaxThreads(numThreads);
 
                     // server supporting one client on one socket
                     sp<RpcConnection> connection = server->addClientConnection();
@@ -339,13 +338,7 @@
 
                     configure(server, connection);
 
-                    // accept 'numThreads' connections
-                    std::vector<std::thread> pool;
-                    for (size_t i = 0; i + 1 < numThreads; i++) {
-                        pool.push_back(std::thread([=] { connection->join(); }));
-                    }
-                    connection->join();
-                    for (auto& t : pool) t.join();
+                    server->join();
                 }),
                 .connection = RpcConnection::make(),
         };
@@ -358,29 +351,26 @@
         }
 
         // create remainder of connections
-        for (size_t i = 0; i < numThreads; i++) {
-            for (size_t tries = 0; tries < 5; tries++) {
-                usleep(10000);
-                switch (socketType) {
-                    case SocketType::UNIX:
-                        if (ret.connection->addUnixDomainClient(addr.c_str())) goto success;
-                        break;
+        for (size_t tries = 0; tries < 10; tries++) {
+            usleep(10000);
+            switch (socketType) {
+                case SocketType::UNIX:
+                    if (ret.connection->setupUnixDomainClient(addr.c_str())) goto success;
+                    break;
 #ifdef __BIONIC__
-                    case SocketType::VSOCK:
-                        if (ret.connection->addVsockClient(VMADDR_CID_LOCAL, vsockPort))
-                            goto success;
-                        break;
+                case SocketType::VSOCK:
+                    if (ret.connection->setupVsockClient(VMADDR_CID_LOCAL, vsockPort)) goto success;
+                    break;
 #endif // __BIONIC__
-                    case SocketType::INET:
-                        if (ret.connection->addInetClient("127.0.0.1", inetPort)) goto success;
-                        break;
-                    default:
-                        LOG_ALWAYS_FATAL("Unknown socket type");
-                }
+                case SocketType::INET:
+                    if (ret.connection->setupInetClient("127.0.0.1", inetPort)) goto success;
+                    break;
+                default:
+                    LOG_ALWAYS_FATAL("Unknown socket type");
             }
-            LOG_ALWAYS_FATAL("Could not connect");
-        success:;
         }
+        LOG_ALWAYS_FATAL("Could not connect");
+    success:
 
         ret.rootBinder = ret.connection->getRootObject();
         return ret;