Merge "libbinder: RPC prevent oneway cmd buildup"
diff --git a/include/android/multinetwork.h b/include/android/multinetwork.h
index 424299d..fa77593 100644
--- a/include/android/multinetwork.h
+++ b/include/android/multinetwork.h
@@ -83,7 +83,7 @@
  *
  * To clear a previous process binding, invoke with NETWORK_UNSPECIFIED.
  *
- * This is the equivalent of: [android.net.ConnectivityManager#setProcessDefaultNetwork()](https://developer.android.com/reference/android/net/ConnectivityManager.html#setProcessDefaultNetwork(android.net.Network))
+ * This is the equivalent of: [android.net.ConnectivityManager#bindProcessToNetwork()](https://developer.android.com/reference/android/net/ConnectivityManager.html#bindProcessToNetwork(android.net.Network))
  *
  * Available since API level 23.
  */
@@ -91,6 +91,19 @@
 
 
 /**
+ * Gets the |network| bound to the current process, as per android_setprocnetwork.
+ *
+ * This is the equivalent of: [android.net.ConnectivityManager#getBoundNetworkForProcess()](https://developer.android.com/reference/android/net/ConnectivityManager.html#getBoundNetworkForProcess())
+ * Returns 0 on success, or -1 with errno set to EINVAL if a null pointer is
+ * passed in.
+ *
+ *
+ * Available since API level 31.
+ */
+int android_getprocnetwork(net_handle_t *network) __INTRODUCED_IN(31);
+
+
+/**
  * Perform hostname resolution via the DNS servers associated with |network|.
  *
  * All arguments (apart from |network|) are used identically as those passed
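
Not part of the patch, but for context: a minimal usage sketch pairing the new getter with the existing android_setprocnetwork(). The helper name and control flow here are illustrative only.

```cpp
#include <android/multinetwork.h>

#include <errno.h>
#include <stdio.h>
#include <string.h>

// Illustrative only: read the network bound to this process and clear the
// binding if one is set. android_getprocnetwork requires API level 31.
static void clear_process_network_if_bound() {
    net_handle_t network;
    if (android_getprocnetwork(&network) != 0) {
        // Per the doc comment above, this only fails with errno set to EINVAL
        // when a null pointer is passed in.
        fprintf(stderr, "android_getprocnetwork: %s\n", strerror(errno));
        return;
    }
    if (network != NETWORK_UNSPECIFIED) {
        // Clear the previous process binding, as described for
        // android_setprocnetwork above.
        (void)android_setprocnetwork(NETWORK_UNSPECIFIED);
    }
}
```
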
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp
index b146bb0..2f378da 100644
--- a/libs/binder/RpcServer.cpp
+++ b/libs/binder/RpcServer.cpp
@@ -153,7 +153,22 @@
 
     status_t status;
     while ((status = mShutdownTrigger->triggerablePollRead(mServer)) == OK) {
-        (void)acceptOne();
+        unique_fd clientFd(TEMP_FAILURE_RETRY(
+                accept4(mServer.get(), nullptr, nullptr /*length*/, SOCK_CLOEXEC)));
+
+        if (clientFd < 0) {
+            ALOGE("Could not accept4 socket: %s", strerror(errno));
+            continue;
+        }
+        LOG_RPC_DETAIL("accept4 on fd %d yields fd %d", mServer.get(), clientFd.get());
+
+        {
+            std::lock_guard<std::mutex> _l(mLock);
+            std::thread thread =
+                    std::thread(&RpcServer::establishConnection, sp<RpcServer>::fromExisting(this),
+                                std::move(clientFd));
+            mConnectingThreads[thread.get_id()] = std::move(thread);
+        }
     }
     LOG_RPC_DETAIL("RpcServer::join exiting with %s", statusToString(status).c_str());
 
@@ -164,26 +179,6 @@
     mShutdownCv.notify_all();
 }
 
-bool RpcServer::acceptOne() {
-    unique_fd clientFd(
-            TEMP_FAILURE_RETRY(accept4(mServer.get(), nullptr, nullptr /*length*/, SOCK_CLOEXEC)));
-
-    if (clientFd < 0) {
-        ALOGE("Could not accept4 socket: %s", strerror(errno));
-        return false;
-    }
-    LOG_RPC_DETAIL("accept4 on fd %d yields fd %d", mServer.get(), clientFd.get());
-
-    {
-        std::lock_guard<std::mutex> _l(mLock);
-        std::thread thread = std::thread(&RpcServer::establishConnection,
-                                         sp<RpcServer>::fromExisting(this), std::move(clientFd));
-        mConnectingThreads[thread.get_id()] = std::move(thread);
-    }
-
-    return true;
-}
-
 bool RpcServer::shutdown() {
     std::unique_lock<std::mutex> _l(mLock);
     if (mShutdownTrigger == nullptr) {
@@ -280,6 +275,7 @@
             server->mSessionIdCounter++;
 
             session = RpcSession::make();
+            session->setMaxThreads(server->mMaxThreads);
             session->setForServer(server,
                                   sp<RpcServer::EventListener>::fromExisting(
                                           static_cast<RpcServer::EventListener*>(server.get())),
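
With the change above, every session created in establishConnection() inherits the server's thread cap. Roughly, server-side bring-up looks like the sketch below; it assumes the public RpcServer entry points on this branch (the experimental-API acknowledgement call, setMaxThreads, setRootObject, setupUnixDomainServer, join) and a hypothetical socket path.

```cpp
#include <binder/IBinder.h>
#include <binder/RpcServer.h>

using android::IBinder;
using android::RpcServer;
using android::sp;

// Illustrative bring-up only; the socket path is hypothetical.
void serveOverUnixSocket(const sp<IBinder>& root) {
    sp<RpcServer> server = RpcServer::make();
    server->iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction();
    server->setMaxThreads(4);  // each accepted session now inherits this cap
    server->setRootObject(root);
    if (!server->setupUnixDomainServer("/dev/example_rpc_socket")) return;
    server->join();  // accept loop; spawns a connection thread per accept4()
}
```
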
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index a2fe3b9..c563377 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -59,15 +59,18 @@
     return sp<RpcSession>::make();
 }
 
-void RpcSession::setMaxReverseConnections(size_t connections) {
-    {
-        std::lock_guard<std::mutex> _l(mMutex);
-        LOG_ALWAYS_FATAL_IF(mClientConnections.size() != 0,
-                            "Must setup reverse connections before setting up client connections, "
-                            "but already has %zu clients",
-                            mClientConnections.size());
-    }
-    mMaxReverseConnections = connections;
+void RpcSession::setMaxThreads(size_t threads) {
+    std::lock_guard<std::mutex> _l(mMutex);
+    LOG_ALWAYS_FATAL_IF(!mClientConnections.empty() || !mServerConnections.empty(),
+                        "Must set max threads before setting up connections, but has %zu client(s) "
+                        "and %zu server(s)",
+                        mClientConnections.size(), mServerConnections.size());
+    mMaxThreads = threads;
+}
+
+size_t RpcSession::getMaxThreads() {
+    std::lock_guard<std::mutex> _l(mMutex);
+    return mMaxThreads;
 }
 
 bool RpcSession::setupUnixDomainClient(const char* path) {
@@ -310,7 +313,7 @@
     // requested to be set) in order to allow the other side to reliably make
     // any requests at all.
 
-    for (size_t i = 0; i < mMaxReverseConnections; i++) {
+    for (size_t i = 0; i < mMaxThreads; i++) {
         if (!setupOneSocketConnection(addr, mId.value(), true /*reverse*/)) return false;
     }
 
diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp
index f1cbe0d..1f97293 100644
--- a/libs/binder/RpcState.cpp
+++ b/libs/binder/RpcState.cpp
@@ -352,35 +352,35 @@
         }
     }
 
+    LOG_ALWAYS_FATAL_IF(std::numeric_limits<int32_t>::max() - sizeof(RpcWireHeader) -
+                                        sizeof(RpcWireTransaction) <
+                                data.dataSize(),
+                        "Too much data %zu", data.dataSize());
+
+    RpcWireHeader command{
+            .command = RPC_COMMAND_TRANSACT,
+            .bodySize = static_cast<uint32_t>(sizeof(RpcWireTransaction) + data.dataSize()),
+    };
     RpcWireTransaction transaction{
             .address = address.viewRawEmbedded(),
             .code = code,
             .flags = flags,
             .asyncNumber = asyncNumber,
     };
-
-    CommandData transactionData(sizeof(RpcWireTransaction) + data.dataSize());
+    CommandData transactionData(sizeof(RpcWireHeader) + sizeof(RpcWireTransaction) +
+                                data.dataSize());
     if (!transactionData.valid()) {
         return NO_MEMORY;
     }
 
-    memcpy(transactionData.data() + 0, &transaction, sizeof(RpcWireTransaction));
-    memcpy(transactionData.data() + sizeof(RpcWireTransaction), data.data(), data.dataSize());
+    memcpy(transactionData.data() + 0, &command, sizeof(RpcWireHeader));
+    memcpy(transactionData.data() + sizeof(RpcWireHeader), &transaction,
+           sizeof(RpcWireTransaction));
+    memcpy(transactionData.data() + sizeof(RpcWireHeader) + sizeof(RpcWireTransaction), data.data(),
+           data.dataSize());
 
-    if (transactionData.size() > std::numeric_limits<uint32_t>::max()) {
-        ALOGE("Transaction size too big %zu", transactionData.size());
-        return BAD_VALUE;
-    }
-
-    RpcWireHeader command{
-            .command = RPC_COMMAND_TRANSACT,
-            .bodySize = static_cast<uint32_t>(transactionData.size()),
-    };
-
-    if (status_t status = rpcSend(fd, "transact header", &command, sizeof(command)); status != OK)
-        return status;
     if (status_t status =
-                rpcSend(fd, "command body", transactionData.data(), transactionData.size());
+                rpcSend(fd, "transaction", transactionData.data(), transactionData.size());
         status != OK)
         return status;
 
@@ -642,34 +642,34 @@
         } else {
             LOG_RPC_DETAIL("Got special transaction %u", transaction->code);
 
-            sp<RpcServer> server = session->server().promote();
-            if (server) {
-                // special case for 'zero' address (special server commands)
-                switch (transaction->code) {
-                    case RPC_SPECIAL_TRANSACT_GET_ROOT: {
-                        replyStatus = reply.writeStrongBinder(server->getRootObject());
-                        break;
-                    }
-                    case RPC_SPECIAL_TRANSACT_GET_MAX_THREADS: {
-                        replyStatus = reply.writeInt32(server->getMaxThreads());
-                        break;
-                    }
-                    case RPC_SPECIAL_TRANSACT_GET_SESSION_ID: {
-                        // only sessions w/ services can be the source of a
-                        // session ID (so still guarded by non-null server)
-                        //
-                        // sessions associated with servers must have an ID
-                        // (hence abort)
-                        int32_t id = session->mId.value();
-                        replyStatus = reply.writeInt32(id);
-                        break;
-                    }
-                    default: {
-                        replyStatus = UNKNOWN_TRANSACTION;
+            switch (transaction->code) {
+                case RPC_SPECIAL_TRANSACT_GET_MAX_THREADS: {
+                    replyStatus = reply.writeInt32(session->getMaxThreads());
+                    break;
+                }
+                case RPC_SPECIAL_TRANSACT_GET_SESSION_ID: {
+                    // for client connections, this should always report the value
+                    // originally returned from the server
+                    int32_t id = session->mId.value();
+                    replyStatus = reply.writeInt32(id);
+                    break;
+                }
+                default: {
+                    sp<RpcServer> server = session->server().promote();
+                    if (server) {
+                        switch (transaction->code) {
+                            case RPC_SPECIAL_TRANSACT_GET_ROOT: {
+                                replyStatus = reply.writeStrongBinder(server->getRootObject());
+                                break;
+                            }
+                            default: {
+                                replyStatus = UNKNOWN_TRANSACTION;
+                            }
+                        }
+                    } else {
+                        ALOGE("Special command sent, but no server object attached.");
                     }
                 }
-            } else {
-                ALOGE("Special command sent, but no server object attached.");
             }
         }
     }
@@ -728,35 +728,29 @@
         return OK;
     }
 
+    LOG_ALWAYS_FATAL_IF(std::numeric_limits<int32_t>::max() - sizeof(RpcWireHeader) -
+                                        sizeof(RpcWireReply) <
+                                reply.dataSize(),
+                        "Too much data for reply %zu", reply.dataSize());
+
+    RpcWireHeader cmdReply{
+            .command = RPC_COMMAND_REPLY,
+            .bodySize = static_cast<uint32_t>(sizeof(RpcWireReply) + reply.dataSize()),
+    };
     RpcWireReply rpcReply{
             .status = replyStatus,
     };
 
-    CommandData replyData(sizeof(RpcWireReply) + reply.dataSize());
+    CommandData replyData(sizeof(RpcWireHeader) + sizeof(RpcWireReply) + reply.dataSize());
     if (!replyData.valid()) {
         return NO_MEMORY;
     }
-    memcpy(replyData.data() + 0, &rpcReply, sizeof(RpcWireReply));
-    memcpy(replyData.data() + sizeof(RpcWireReply), reply.data(), reply.dataSize());
+    memcpy(replyData.data() + 0, &cmdReply, sizeof(RpcWireHeader));
+    memcpy(replyData.data() + sizeof(RpcWireHeader), &rpcReply, sizeof(RpcWireReply));
+    memcpy(replyData.data() + sizeof(RpcWireHeader) + sizeof(RpcWireReply), reply.data(),
+           reply.dataSize());
 
-    if (replyData.size() > std::numeric_limits<uint32_t>::max()) {
-        ALOGE("Reply size too big %zu", transactionData.size());
-        terminate();
-        return BAD_VALUE;
-    }
-
-    RpcWireHeader cmdReply{
-            .command = RPC_COMMAND_REPLY,
-            .bodySize = static_cast<uint32_t>(replyData.size()),
-    };
-
-    if (status_t status = rpcSend(fd, "reply header", &cmdReply, sizeof(RpcWireHeader));
-        status != OK)
-        return status;
-    if (status_t status = rpcSend(fd, "reply body", replyData.data(), replyData.size());
-        status != OK)
-        return status;
-    return OK;
+    return rpcSend(fd, "reply", replyData.data(), replyData.size());
 }
 
 status_t RpcState::processDecStrong(const base::unique_fd& fd, const sp<RpcSession>& session,
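
Both the transact and the reply paths above now build one contiguous buffer (RpcWireHeader, then the wire struct, then the parcel payload) and hand it to a single rpcSend() call, rather than sending the header and body separately. The standalone sketch below shows just that offset arithmetic, with placeholder structs standing in for the real wire types.

```cpp
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <vector>

// Placeholder wire structs; the real definitions live in libbinder's wire format header.
struct FakeHeader { uint32_t command; uint32_t bodySize; };
struct FakeTransaction { uint32_t code; uint32_t flags; };

// Pack header + transaction + payload into one buffer, mirroring the layout
// that the patch writes with a single rpcSend().
std::vector<uint8_t> packTransaction(const FakeTransaction& txn, const uint8_t* payload,
                                     size_t payloadSize) {
    // Same guard as the patch: the body size must fit the header's uint32 field.
    if (std::numeric_limits<int32_t>::max() - sizeof(FakeHeader) - sizeof(FakeTransaction) <
        payloadSize) {
        std::abort();
    }
    FakeHeader header{
            .command = 1,  // stand-in value, not the real RPC_COMMAND_TRANSACT
            .bodySize = static_cast<uint32_t>(sizeof(FakeTransaction) + payloadSize),
    };
    std::vector<uint8_t> buf(sizeof(FakeHeader) + sizeof(FakeTransaction) + payloadSize);
    std::memcpy(buf.data(), &header, sizeof(header));
    std::memcpy(buf.data() + sizeof(header), &txn, sizeof(txn));
    std::memcpy(buf.data() + sizeof(header) + sizeof(txn), payload, payloadSize);
    return buf;  // written out in one go instead of header + body writes
}
```
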
diff --git a/libs/binder/include/binder/RpcServer.h b/libs/binder/include/binder/RpcServer.h
index 0082ec3..98db221 100644
--- a/libs/binder/include/binder/RpcServer.h
+++ b/libs/binder/include/binder/RpcServer.h
@@ -160,7 +160,6 @@
 
     static void establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd);
     bool setupSocketServer(const RpcSocketAddress& address);
-    [[nodiscard]] bool acceptOne();
 
     bool mAgreedExperimental = false;
     size_t mMaxThreads = 1;
diff --git a/libs/binder/include/binder/RpcSession.h b/libs/binder/include/binder/RpcSession.h
index 9d314e4..a6bc1a9 100644
--- a/libs/binder/include/binder/RpcSession.h
+++ b/libs/binder/include/binder/RpcSession.h
@@ -47,16 +47,17 @@
     static sp<RpcSession> make();
 
     /**
-     * Set the maximum number of reverse connections allowed to be made (for
-     * things like callbacks). By default, this is 0. This must be called before
-     * setting up this connection as a client.
+     * Set the maximum number of threads allowed to be made (for things like callbacks).
+     * By default, this is 0. This must be called before setting up this connection as a client.
+     * Server sessions will inherit this value from RpcServer.
      *
      * If this is called, 'shutdown' on this session must also be called.
      * Otherwise, a threadpool will leak.
      *
      * TODO(b/185167543): start these dynamically
      */
-    void setMaxReverseConnections(size_t connections);
+    void setMaxThreads(size_t threads);
+    size_t getMaxThreads();
 
     /**
      * This should be called once per thread, matching 'join' in the remote
@@ -257,7 +258,7 @@
 
     std::mutex mMutex; // for all below
 
-    size_t mMaxReverseConnections = 0;
+    size_t mMaxThreads = 0;
 
     std::condition_variable mAvailableConnectionCv; // for mWaitingThreads
     size_t mWaitingThreads = 0;
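
On the client side, the renamed setter keeps the contract the removed setMaxReverseConnections() had: it must run before the connection is set up, and a session that uses it must eventually be shut down. A minimal sketch, assuming RpcSession::getRootObject() and a hypothetical socket path:

```cpp
#include <binder/IBinder.h>
#include <binder/RpcSession.h>

using android::IBinder;
using android::RpcSession;
using android::sp;

// Illustrative client-side usage only; the socket path is hypothetical.
sp<IBinder> connectWithCallbacks() {
    sp<RpcSession> session = RpcSession::make();
    session->setMaxThreads(2);  // reverse connections for callbacks; must precede setup
    if (!session->setupUnixDomainClient("/dev/example_rpc_socket")) return nullptr;
    sp<IBinder> root = session->getRootObject();
    // Per the doc comment above, once setMaxThreads() has been used, 'shutdown'
    // must eventually be called on this session or its threadpool will leak.
    return root;
}
```
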
diff --git a/libs/binder/ndk/tests/Android.bp b/libs/binder/ndk/tests/Android.bp
index bb51bf0..ede4873 100644
--- a/libs/binder/ndk/tests/Android.bp
+++ b/libs/binder/ndk/tests/Android.bp
@@ -95,7 +95,7 @@
         "libbinder_ndk",
         "libutils",
     ],
-    test_suites: ["general-tests", "vts"],
+    test_suites: ["general-tests"],
     require_root: true,
 }
 
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index 601ac6a..0a970fb 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -446,7 +446,7 @@
 
         for (size_t i = 0; i < numSessions; i++) {
             sp<RpcSession> session = RpcSession::make();
-            session->setMaxReverseConnections(numReverseConnections);
+            session->setMaxThreads(numReverseConnections);
 
             switch (socketType) {
                 case SocketType::UNIX: