binder: RpcSession limit outgoing threads

Similar to the number of incoming threads,
the number of outgoing threads can be limited via
RpcSession::setMaxOutgoingThreads(). The limit defaults to
kDefaultMaxOutgoingThreads (10), so a client only ever instantiates
min(maxOutgoingThreads, remoteMaxThreads) outgoing threads.

Test: binderRpcTest
Bug: 194225767

Change-Id: I15686bae4317d0ced5af999f3a3d21f9a03037e1
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index 486b67b..9eef3e8 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -90,6 +90,20 @@
     return mMaxIncomingThreads;
 }
 
+void RpcSession::setMaxOutgoingThreads(size_t threads) {
+    std::lock_guard<std::mutex> _l(mMutex);
+    LOG_ALWAYS_FATAL_IF(!mConnections.mOutgoing.empty() || !mConnections.mIncoming.empty(),
+                        "Must set max outgoing threads before setting up connections, but has %zu "
+                        "client(s) and %zu server(s)",
+                        mConnections.mOutgoing.size(), mConnections.mIncoming.size());
+    mMaxOutgoingThreads = threads;
+}
+
+size_t RpcSession::getMaxOutgoingThreads() {
+    std::lock_guard<std::mutex> _l(mMutex);
+    return mMaxOutgoingThreads;
+}
+
 bool RpcSession::setProtocolVersion(uint32_t version) {
     if (version >= RPC_WIRE_PROTOCOL_VERSION_NEXT &&
         version != RPC_WIRE_PROTOCOL_VERSION_EXPERIMENTAL) {
@@ -473,6 +487,12 @@
         return status;
     }
 
+    size_t outgoingThreads = std::min(numThreadsAvailable, mMaxOutgoingThreads);
+    ALOGI_IF(outgoingThreads != numThreadsAvailable,
+             "Server hints client to start %zu outgoing threads, but client will only start %zu "
+             "because it is preconfigured to start at most %zu outgoing threads.",
+             numThreadsAvailable, outgoingThreads, mMaxOutgoingThreads);
+
     // TODO(b/189955605): we should add additional sessions dynamically
     // instead of all at once - the other side should be responsible for setting
     // up additional connections. We need to create at least one (unless 0 are
@@ -480,7 +500,10 @@
     // any requests at all.
 
     // we've already setup one client
-    for (size_t i = 0; i + 1 < numThreadsAvailable; i++) {
+    LOG_RPC_DETAIL("RpcSession::setupClient() instantiating %zu outgoing (server max: %zu) and %zu "
+                   "incoming threads",
+                   outgoingThreads, numThreadsAvailable, mMaxIncomingThreads);
+    for (size_t i = 0; i + 1 < outgoingThreads; i++) {
         if (status_t status = connectAndInit(mId, false /*incoming*/); status != OK) return status;
     }
 
diff --git a/libs/binder/include/binder/RpcSession.h b/libs/binder/include/binder/RpcSession.h
index e64572b..f5505da 100644
--- a/libs/binder/include/binder/RpcSession.h
+++ b/libs/binder/include/binder/RpcSession.h
@@ -50,6 +50,8 @@
  */
 class RpcSession final : public virtual RefBase {
 public:
+    static constexpr size_t kDefaultMaxOutgoingThreads = 10;
+
     // Create an RpcSession with default configuration (raw sockets).
     static sp<RpcSession> make();
 
@@ -72,6 +74,18 @@
     size_t getMaxIncomingThreads();
 
     /**
+     * Set the maximum number of outgoing threads allowed to be made.
+     * By default, this is |kDefaultMaxOutgoingThreads|. This must be called before any connection
+     * (incoming or outgoing) is set up on this session; calling it afterwards aborts.
+     *
+     * This limits the number of outgoing threads on top of the remote peer setting. This RpcSession
+     * will only instantiate |min(maxOutgoingThreads, remoteMaxThreads)| outgoing threads, where
+     * |remoteMaxThreads| can be retrieved from the remote peer via |getRemoteMaxThreads()|.
+     */
+    void setMaxOutgoingThreads(size_t threads);
+    size_t getMaxOutgoingThreads();
+
+    /**
      * By default, the minimum of the supported versions of the client and the
      * server will be used. Usually, this API should only be used for debugging.
      */
@@ -308,6 +322,7 @@
     std::mutex mMutex; // for all below
 
     size_t mMaxIncomingThreads = 0;
+    size_t mMaxOutgoingThreads = kDefaultMaxOutgoingThreads;
     std::optional<uint32_t> mProtocolVersion;
 
     std::condition_variable mAvailableConnectionCv; // for mWaitingThreads
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index 2011801..d68c6ff 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -481,6 +481,7 @@
         size_t numThreads = 1;
         size_t numSessions = 1;
         size_t numIncomingConnections = 0;
+        size_t numOutgoingConnections = SIZE_MAX;
     };
 
     static inline std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) {
@@ -614,6 +615,7 @@
 
         for (const auto& session : sessions) {
             session->setMaxIncomingThreads(options.numIncomingConnections);
+            session->setMaxOutgoingThreads(options.numOutgoingConnections);
 
             switch (socketType) {
                 case SocketType::PRECONNECTED:
@@ -655,6 +657,9 @@
 
         return ret;
     }
+
+    void testThreadPoolOverSaturated(sp<IBinderRpcTest> iface, size_t numCalls,
+                                     size_t sleepMs = 500);
 };
 
 TEST_P(BinderRpc, Ping) {
@@ -996,28 +1001,39 @@
     for (auto& t : ts) t.join();
 }
 
-TEST_P(BinderRpc, ThreadPoolOverSaturated) {
-    constexpr size_t kNumThreads = 10;
-    constexpr size_t kNumCalls = kNumThreads + 3;
-    constexpr size_t kSleepMs = 500;
-
-    auto proc = createRpcTestSocketServerProcess({.numThreads = kNumThreads});
-
+void BinderRpc::testThreadPoolOverSaturated(sp<IBinderRpcTest> iface, size_t numCalls,
+                                            size_t sleepMs) {
     size_t epochMsBefore = epochMillis();
 
     std::vector<std::thread> ts;
-    for (size_t i = 0; i < kNumCalls; i++) {
-        ts.push_back(std::thread([&] { proc.rootIface->sleepMs(kSleepMs); }));
+    for (size_t i = 0; i < numCalls; i++) {
+        ts.push_back(std::thread([&] { iface->sleepMs(sleepMs); }));
     }
 
     for (auto& t : ts) t.join();
 
     size_t epochMsAfter = epochMillis();
 
-    EXPECT_GE(epochMsAfter, epochMsBefore + 2 * kSleepMs);
+    EXPECT_GE(epochMsAfter, epochMsBefore + 2 * sleepMs);
 
     // Potential flake, but make sure calls are handled in parallel.
-    EXPECT_LE(epochMsAfter, epochMsBefore + 3 * kSleepMs);
+    EXPECT_LE(epochMsAfter, epochMsBefore + 3 * sleepMs);
+}
+
+TEST_P(BinderRpc, ThreadPoolOverSaturated) {
+    constexpr size_t kNumThreads = 10;
+    constexpr size_t kNumCalls = kNumThreads + 3;
+    auto proc = createRpcTestSocketServerProcess({.numThreads = kNumThreads});
+    testThreadPoolOverSaturated(proc.rootIface, kNumCalls);
+}
+
+TEST_P(BinderRpc, ThreadPoolLimitOutgoing) {
+    constexpr size_t kNumThreads = 20;
+    constexpr size_t kNumOutgoingConnections = 10;
+    constexpr size_t kNumCalls = kNumOutgoingConnections + 3;
+    auto proc = createRpcTestSocketServerProcess(
+            {.numThreads = kNumThreads, .numOutgoingConnections = kNumOutgoingConnections});
+    testThreadPoolOverSaturated(proc.rootIface, kNumCalls);
 }
 
 TEST_P(BinderRpc, ThreadingStressTest) {