libbinder: shut down session threads

This is the last piece needed to completely shut down servers (in
preparation for adding threadpools to server callbacks, which actually
need to be shut down during normal usage).

Bug: 185167543
Test: binderRpcTest

Change-Id: I20d6ac16c58fe6801545fa7be178518201fe075d
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp
index e3bf2a5..bff5543 100644
--- a/libs/binder/RpcServer.cpp
+++ b/libs/binder/RpcServer.cpp
@@ -192,10 +192,10 @@
     }
 
     mShutdownTrigger->trigger();
-    while (mJoinThreadRunning || !mConnectingThreads.empty()) {
+    while (mJoinThreadRunning || !mConnectingThreads.empty() || !mSessions.empty()) {
         ALOGI("Waiting for RpcServer to shut down. Join thread running: %d, Connecting threads: "
-              "%zu",
-              mJoinThreadRunning, mConnectingThreads.size());
+              "%zu, Sessions: %zu",
+              mJoinThreadRunning, mConnectingThreads.size(), mSessions.size());
         mShutdownCv.wait(_l);
     }
 
@@ -278,7 +278,8 @@
             server->mSessionIdCounter++;
 
             session = RpcSession::make();
-            session->setForServer(wp<RpcServer>(server), server->mSessionIdCounter);
+            session->setForServer(wp<RpcServer>(server), server->mSessionIdCounter,
+                                  server->mShutdownTrigger);
 
             server->mSessions[server->mSessionIdCounter] = session;
         } else {
@@ -344,6 +345,11 @@
     (void)mSessions.erase(it);
 }
 
+void RpcServer::onSessionThreadEnding(const sp<RpcSession>& /*session*/) {
+    // Just a wake-up: whoever waits on mShutdownCv re-checks its own exit condition.
+    mShutdownCv.notify_all();
+}
+
 bool RpcServer::hasServer() {
     LOG_ALWAYS_FATAL_IF(!mAgreedExperimental, "no!");
     std::lock_guard<std::mutex> _l(mLock);
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index 9f26a33..7c458c1 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -207,12 +207,19 @@
     LOG_ALWAYS_FATAL_IF(!removeServerConnection(connection),
                         "bad state: connection object guaranteed to be in list");
 
+    sp<RpcServer> server;
     {
         std::lock_guard<std::mutex> _l(mMutex);
         auto it = mThreads.find(std::this_thread::get_id());
         LOG_ALWAYS_FATAL_IF(it == mThreads.end());
         it->second.detach();
         mThreads.erase(it);
+        // Promote while holding mMutex; the server is called only after the lock is released.
+        server = mForServer.promote();
+    }
+
+    if (server != nullptr) {
+        server->onSessionThreadEnding(sp<RpcSession>::fromExisting(this));
     }
 }
 
@@ -314,14 +321,25 @@
 
 void RpcSession::addClientConnection(unique_fd fd) {
     std::lock_guard<std::mutex> _l(mMutex);
+    // Sessions without a server-provided trigger (see setForServer()) create their own.
+    if (mShutdownTrigger == nullptr) {
+        mShutdownTrigger = FdTrigger::make();
+    }
+
     sp<RpcConnection> session = sp<RpcConnection>::make();
     session->fd = std::move(fd);
     mClientConnections.push_back(session);
 }
 
-void RpcSession::setForServer(const wp<RpcServer>& server, int32_t sessionId) {
+void RpcSession::setForServer(const wp<RpcServer>& server, int32_t sessionId,
+                              const std::shared_ptr<FdTrigger>& shutdownTrigger) {
+    LOG_ALWAYS_FATAL_IF(mForServer.unsafe_get() != nullptr);
+    LOG_ALWAYS_FATAL_IF(mShutdownTrigger != nullptr);
+    LOG_ALWAYS_FATAL_IF(shutdownTrigger == nullptr);
+    // Bind once: a session belongs to one server and shares that server's shutdown trigger.
     mId = sessionId;
     mForServer = server;
+    mShutdownTrigger = shutdownTrigger;
 }
 
 sp<RpcSession::RpcConnection> RpcSession::assignServerToThisThread(unique_fd fd) {
diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp
index 230de6f..6483486 100644
--- a/libs/binder/RpcState.cpp
+++ b/libs/binder/RpcState.cpp
@@ -229,30 +229,22 @@
     return true;
 }
 
-bool RpcState::rpcRec(const base::unique_fd& fd, const char* what, void* data, size_t size) {
+bool RpcState::rpcRec(const base::unique_fd& fd, const sp<RpcSession>& session, const char* what,
+                      void* data, size_t size) {
     if (size > std::numeric_limits<ssize_t>::max()) {
         ALOGE("Cannot rec %s at size %zu (too big)", what, size);
         terminate();
         return false;
     }
 
-    ssize_t recd = TEMP_FAILURE_RETRY(recv(fd.get(), data, size, MSG_WAITALL | MSG_NOSIGNAL));
-
-    if (recd < 0 || recd != static_cast<ssize_t>(size)) {
-        terminate();
-
-        if (recd == 0 && errno == 0) {
-            LOG_RPC_DETAIL("No more data when trying to read %s on fd %d", what, fd.get());
-            return false;
-        }
-
-        ALOGE("Failed to read %s (received %zd of %zu bytes) on fd %d, error: %s", what, recd, size,
-              fd.get(), strerror(errno));
+    if (status_t status = session->mShutdownTrigger->interruptableReadFully(fd.get(), data, size);
+        status != OK) { // may be a normal peer EOF or shutdown interrupt, so not ALOGE
+        LOG_RPC_DETAIL("Failed to read %s (%zu bytes) on fd %d, error: %s", what, size,
+                       fd.get(), statusToString(status).c_str());
         return false;
-    } else {
-        LOG_RPC_DETAIL("Received %s on fd %d: %s", what, fd.get(), hexString(data, size).c_str());
     }
 
+    LOG_RPC_DETAIL("Received %s on fd %d: %s", what, fd.get(), hexString(data, size).c_str());
     return true;
 }
 
@@ -398,7 +390,7 @@
                                 Parcel* reply) {
     RpcWireHeader command;
     while (true) {
-        if (!rpcRec(fd, "command header", &command, sizeof(command))) {
+        if (!rpcRec(fd, session, "command header", &command, sizeof(command))) {
             return DEAD_OBJECT;
         }
 
@@ -413,7 +405,7 @@
         return NO_MEMORY;
     }
 
-    if (!rpcRec(fd, "reply body", data.data(), command.bodySize)) {
+    if (!rpcRec(fd, session, "reply body", data.data(), command.bodySize)) {
         return DEAD_OBJECT;
     }
 
@@ -465,7 +457,7 @@
     LOG_RPC_DETAIL("getAndExecuteCommand on fd %d", fd.get());
 
     RpcWireHeader command;
-    if (!rpcRec(fd, "command header", &command, sizeof(command))) {
+    if (!rpcRec(fd, session, "command header", &command, sizeof(command))) {
         return DEAD_OBJECT;
     }
 
@@ -493,7 +485,7 @@
         case RPC_COMMAND_TRANSACT:
             return processTransact(fd, session, command);
         case RPC_COMMAND_DEC_STRONG:
-            return processDecStrong(fd, command);
+            return processDecStrong(fd, session, command);
     }
 
     // We should always know the version of the opposing side, and since the
@@ -513,7 +505,7 @@
     if (!transactionData.valid()) {
         return NO_MEMORY;
     }
-    if (!rpcRec(fd, "transaction body", transactionData.data(), transactionData.size())) {
+    if (!rpcRec(fd, session, "transaction body", transactionData.data(), transactionData.size())) {
         return DEAD_OBJECT;
     }
 
@@ -626,7 +618,7 @@
                         //
                         // sessions associated with servers must have an ID
                         // (hence abort)
-                        int32_t id = session->getPrivateAccessorForId().get().value();
+                        int32_t id = session->mId.value();
                         replyStatus = reply.writeInt32(id);
                         break;
                     }
@@ -721,14 +713,15 @@
     return OK;
 }
 
-status_t RpcState::processDecStrong(const base::unique_fd& fd, const RpcWireHeader& command) {
+status_t RpcState::processDecStrong(const base::unique_fd& fd, const sp<RpcSession>& session,
+                                    const RpcWireHeader& command) {
     LOG_ALWAYS_FATAL_IF(command.command != RPC_COMMAND_DEC_STRONG, "command: %d", command.command);
 
     CommandData commandData(command.bodySize);
     if (!commandData.valid()) {
         return NO_MEMORY;
     }
-    if (!rpcRec(fd, "dec ref body", commandData.data(), commandData.size())) {
+    if (!rpcRec(fd, session, "dec ref body", commandData.data(), commandData.size())) {
         return DEAD_OBJECT;
     }
 
diff --git a/libs/binder/RpcState.h b/libs/binder/RpcState.h
index 31f8a22..f913925 100644
--- a/libs/binder/RpcState.h
+++ b/libs/binder/RpcState.h
@@ -117,7 +117,8 @@
 
     [[nodiscard]] bool rpcSend(const base::unique_fd& fd, const char* what, const void* data,
                                size_t size);
-    [[nodiscard]] bool rpcRec(const base::unique_fd& fd, const char* what, void* data, size_t size);
+    [[nodiscard]] bool rpcRec(const base::unique_fd& fd, const sp<RpcSession>& session,
+                              const char* what, void* data, size_t size);
 
     [[nodiscard]] status_t waitForReply(const base::unique_fd& fd, const sp<RpcSession>& session,
                                         Parcel* reply);
@@ -130,6 +131,7 @@
                                                    const sp<RpcSession>& session,
                                                    CommandData transactionData);
     [[nodiscard]] status_t processDecStrong(const base::unique_fd& fd,
+                                            const sp<RpcSession>& session,
                                             const RpcWireHeader& command);
 
     struct BinderNode {
diff --git a/libs/binder/include/binder/RpcServer.h b/libs/binder/include/binder/RpcServer.h
index 50770f1..178459d 100644
--- a/libs/binder/include/binder/RpcServer.h
+++ b/libs/binder/include/binder/RpcServer.h
@@ -150,6 +150,7 @@
     // internal use only
 
     void onSessionTerminating(const sp<RpcSession>& session);
+    void onSessionThreadEnding(const sp<RpcSession>& session);
 
 private:
     friend sp<RpcServer>;
@@ -171,7 +172,7 @@
     wp<IBinder> mRootObjectWeak;
     std::map<int32_t, sp<RpcSession>> mSessions;
     int32_t mSessionIdCounter = 0;
-    std::unique_ptr<RpcSession::FdTrigger> mShutdownTrigger;
+    std::shared_ptr<RpcSession::FdTrigger> mShutdownTrigger; // shared with server sessions
     std::condition_variable mShutdownCv;
 };
 
diff --git a/libs/binder/include/binder/RpcSession.h b/libs/binder/include/binder/RpcSession.h
index d6b796f..d46f275 100644
--- a/libs/binder/include/binder/RpcSession.h
+++ b/libs/binder/include/binder/RpcSession.h
@@ -94,27 +94,16 @@
     // internal only
     const std::unique_ptr<RpcState>& state() { return mState; }
 
-    class PrivateAccessorForId {
-    private:
-        friend class RpcSession;
-        friend class RpcState;
-        explicit PrivateAccessorForId(const RpcSession* session) : mSession(session) {}
-
-        const std::optional<int32_t> get() { return mSession->mId; }
-
-        const RpcSession* mSession;
-    };
-    PrivateAccessorForId getPrivateAccessorForId() const { return PrivateAccessorForId(this); }
-
 private:
-    friend PrivateAccessorForId;
     friend sp<RpcSession>;
     friend RpcServer;
+    friend RpcState;
     RpcSession();
 
     /** This is not a pipe. */
     struct FdTrigger {
         static std::unique_ptr<FdTrigger> make();
+
         /**
          * poll() on this fd for POLLHUP to get notification when trigger is called
          */
@@ -167,7 +156,8 @@
     bool setupSocketClient(const RpcSocketAddress& address);
     bool setupOneSocketClient(const RpcSocketAddress& address, int32_t sessionId);
     void addClientConnection(base::unique_fd fd);
-    void setForServer(const wp<RpcServer>& server, int32_t sessionId);
+    void setForServer(const wp<RpcServer>& server, int32_t sessionId,
+                      const std::shared_ptr<FdTrigger>& shutdownTrigger);
     sp<RpcConnection> assignServerToThisThread(base::unique_fd fd);
     bool removeServerConnection(const sp<RpcConnection>& connection);
 
@@ -218,6 +208,8 @@
     // TODO(b/183988761): this shouldn't be guessable
     std::optional<int32_t> mId;
 
+    // Server sessions share their RpcServer's trigger; client sessions create their own.
+    std::shared_ptr<FdTrigger> mShutdownTrigger;
     std::unique_ptr<RpcState> mState;
 
     std::mutex mMutex; // for all below