Merge "libbinder: shutdown session threads"
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp
index e3bf2a5..bff5543 100644
--- a/libs/binder/RpcServer.cpp
+++ b/libs/binder/RpcServer.cpp
@@ -192,10 +192,10 @@
     }
 
     mShutdownTrigger->trigger();
-    while (mJoinThreadRunning || !mConnectingThreads.empty()) {
+    while (mJoinThreadRunning || !mConnectingThreads.empty() || !mSessions.empty()) {
         ALOGI("Waiting for RpcServer to shut down. Join thread running: %d, Connecting threads: "
-              "%zu",
-              mJoinThreadRunning, mConnectingThreads.size());
+              "%zu, Sessions: %zu",
+              mJoinThreadRunning, mConnectingThreads.size(), mSessions.size());
         mShutdownCv.wait(_l);
     }
 
@@ -278,7 +278,8 @@
             server->mSessionIdCounter++;
 
             session = RpcSession::make();
-            session->setForServer(wp<RpcServer>(server), server->mSessionIdCounter);
+            session->setForServer(wp<RpcServer>(server), server->mSessionIdCounter,
+                                  server->mShutdownTrigger);
 
             server->mSessions[server->mSessionIdCounter] = session;
         } else {
@@ -344,6 +345,11 @@
     (void)mSessions.erase(it);
 }
 
+void RpcServer::onSessionThreadEnding(const sp<RpcSession>& session) {  // internal notification hook
+    (void)session;  // the session argument is not otherwise used in this body
+    mShutdownCv.notify_all();  // wake every waiter on mShutdownCv so it re-checks its condition
+}
+
 bool RpcServer::hasServer() {
     LOG_ALWAYS_FATAL_IF(!mAgreedExperimental, "no!");
     std::lock_guard<std::mutex> _l(mLock);
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index 9f26a33..7c458c1 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -207,12 +207,19 @@
     LOG_ALWAYS_FATAL_IF(!removeServerConnection(connection),
                         "bad state: connection object guaranteed to be in list");
 
+    sp<RpcServer> server;
     {
         std::lock_guard<std::mutex> _l(mMutex);
         auto it = mThreads.find(std::this_thread::get_id());
         LOG_ALWAYS_FATAL_IF(it == mThreads.end());
         it->second.detach();
         mThreads.erase(it);
+
+        server = mForServer.promote();
+    }
+
+    if (server != nullptr) {
+        server->onSessionThreadEnding(sp<RpcSession>::fromExisting(this));
     }
 }
 
@@ -314,14 +321,25 @@
 
 void RpcSession::addClientConnection(unique_fd fd) {
     std::lock_guard<std::mutex> _l(mMutex);
+
+    if (mShutdownTrigger == nullptr) {  // no trigger yet: create one; later calls reuse it
+        mShutdownTrigger = FdTrigger::make();  // presumably the non-server case - TODO confirm
+    }
+
     sp<RpcConnection> session = sp<RpcConnection>::make();
     session->fd = std::move(fd);
     mClientConnections.push_back(session);
 }
 
-void RpcSession::setForServer(const wp<RpcServer>& server, int32_t sessionId) {
+void RpcSession::setForServer(const wp<RpcServer>& server, int32_t sessionId,
+                              const std::shared_ptr<FdTrigger>& shutdownTrigger) {
+    LOG_ALWAYS_FATAL_IF(mForServer.unsafe_get() != nullptr);  // a server is already attached
+    LOG_ALWAYS_FATAL_IF(mShutdownTrigger != nullptr);  // this session already has a trigger
+    LOG_ALWAYS_FATAL_IF(shutdownTrigger == nullptr);  // caller must supply a trigger
+
     mId = sessionId;
     mForServer = server;
+    mShutdownTrigger = shutdownTrigger;  // copies the shared_ptr, so ownership is shared
 }
 
 sp<RpcSession::RpcConnection> RpcSession::assignServerToThisThread(unique_fd fd) {
diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp
index 230de6f..6483486 100644
--- a/libs/binder/RpcState.cpp
+++ b/libs/binder/RpcState.cpp
@@ -229,30 +229,22 @@
     return true;
 }
 
-bool RpcState::rpcRec(const base::unique_fd& fd, const char* what, void* data, size_t size) {
+bool RpcState::rpcRec(const base::unique_fd& fd, const sp<RpcSession>& session, const char* what,
+                      void* data, size_t size) {  // assumes session->mShutdownTrigger is set - TODO confirm
     if (size > std::numeric_limits<ssize_t>::max()) {
         ALOGE("Cannot rec %s at size %zu (too big)", what, size);
         terminate();
         return false;
     }
 
-    ssize_t recd = TEMP_FAILURE_RETRY(recv(fd.get(), data, size, MSG_WAITALL | MSG_NOSIGNAL));
-
-    if (recd < 0 || recd != static_cast<ssize_t>(size)) {
-        terminate();
-
-        if (recd == 0 && errno == 0) {
-            LOG_RPC_DETAIL("No more data when trying to read %s on fd %d", what, fd.get());
-            return false;
-        }
-
-        ALOGE("Failed to read %s (received %zd of %zu bytes) on fd %d, error: %s", what, recd, size,
-              fd.get(), strerror(errno));
+    if (status_t status = session->mShutdownTrigger->interruptableReadFully(fd.get(), data, size);
+        status != OK) {  // NOTE(review): the removed recv() path called terminate() here; confirm
+        ALOGE("Failed to read %s (%zu bytes) on fd %d, error: %s", what, size, fd.get(),
+              statusToString(status).c_str());
         return false;
-    } else {
-        LOG_RPC_DETAIL("Received %s on fd %d: %s", what, fd.get(), hexString(data, size).c_str());
     }
 
+    LOG_RPC_DETAIL("Received %s on fd %d: %s", what, fd.get(), hexString(data, size).c_str());
     return true;
 }
 
@@ -398,7 +390,7 @@
                                 Parcel* reply) {
     RpcWireHeader command;
     while (true) {
-        if (!rpcRec(fd, "command header", &command, sizeof(command))) {
+        if (!rpcRec(fd, session, "command header", &command, sizeof(command))) {
             return DEAD_OBJECT;
         }
 
@@ -413,7 +405,7 @@
         return NO_MEMORY;
     }
 
-    if (!rpcRec(fd, "reply body", data.data(), command.bodySize)) {
+    if (!rpcRec(fd, session, "reply body", data.data(), command.bodySize)) {
         return DEAD_OBJECT;
     }
 
@@ -465,7 +457,7 @@
     LOG_RPC_DETAIL("getAndExecuteCommand on fd %d", fd.get());
 
     RpcWireHeader command;
-    if (!rpcRec(fd, "command header", &command, sizeof(command))) {
+    if (!rpcRec(fd, session, "command header", &command, sizeof(command))) {
         return DEAD_OBJECT;
     }
 
@@ -493,7 +485,7 @@
         case RPC_COMMAND_TRANSACT:
             return processTransact(fd, session, command);
         case RPC_COMMAND_DEC_STRONG:
-            return processDecStrong(fd, command);
+            return processDecStrong(fd, session, command);
     }
 
     // We should always know the version of the opposing side, and since the
@@ -513,7 +505,7 @@
     if (!transactionData.valid()) {
         return NO_MEMORY;
     }
-    if (!rpcRec(fd, "transaction body", transactionData.data(), transactionData.size())) {
+    if (!rpcRec(fd, session, "transaction body", transactionData.data(), transactionData.size())) {
         return DEAD_OBJECT;
     }
 
@@ -626,7 +618,7 @@
                         //
                         // sessions associated with servers must have an ID
                         // (hence abort)
-                        int32_t id = session->getPrivateAccessorForId().get().value();
+                        int32_t id = session->mId.value();
                         replyStatus = reply.writeInt32(id);
                         break;
                     }
@@ -721,14 +713,15 @@
     return OK;
 }
 
-status_t RpcState::processDecStrong(const base::unique_fd& fd, const RpcWireHeader& command) {
+status_t RpcState::processDecStrong(const base::unique_fd& fd, const sp<RpcSession>& session,
+                                    const RpcWireHeader& command) {
     LOG_ALWAYS_FATAL_IF(command.command != RPC_COMMAND_DEC_STRONG, "command: %d", command.command);
 
     CommandData commandData(command.bodySize);
     if (!commandData.valid()) {
         return NO_MEMORY;
     }
-    if (!rpcRec(fd, "dec ref body", commandData.data(), commandData.size())) {
+    if (!rpcRec(fd, session, "dec ref body", commandData.data(), commandData.size())) {
         return DEAD_OBJECT;
     }
 
diff --git a/libs/binder/RpcState.h b/libs/binder/RpcState.h
index 31f8a22..f913925 100644
--- a/libs/binder/RpcState.h
+++ b/libs/binder/RpcState.h
@@ -117,7 +117,8 @@
 
     [[nodiscard]] bool rpcSend(const base::unique_fd& fd, const char* what, const void* data,
                                size_t size);
-    [[nodiscard]] bool rpcRec(const base::unique_fd& fd, const char* what, void* data, size_t size);
+    [[nodiscard]] bool rpcRec(const base::unique_fd& fd, const sp<RpcSession>& session,
+                              const char* what, void* data, size_t size);
 
     [[nodiscard]] status_t waitForReply(const base::unique_fd& fd, const sp<RpcSession>& session,
                                         Parcel* reply);
@@ -130,6 +131,7 @@
                                                    const sp<RpcSession>& session,
                                                    CommandData transactionData);
     [[nodiscard]] status_t processDecStrong(const base::unique_fd& fd,
+                                            const sp<RpcSession>& session,
                                             const RpcWireHeader& command);
 
     struct BinderNode {
diff --git a/libs/binder/include/binder/RpcServer.h b/libs/binder/include/binder/RpcServer.h
index 50770f1..178459d 100644
--- a/libs/binder/include/binder/RpcServer.h
+++ b/libs/binder/include/binder/RpcServer.h
@@ -150,6 +150,7 @@
     // internal use only
 
     void onSessionTerminating(const sp<RpcSession>& session);
+    void onSessionThreadEnding(const sp<RpcSession>& session);  // notifies mShutdownCv waiters
 
 private:
     friend sp<RpcServer>;
@@ -171,7 +172,7 @@
     wp<IBinder> mRootObjectWeak;
     std::map<int32_t, sp<RpcSession>> mSessions;
     int32_t mSessionIdCounter = 0;
-    std::unique_ptr<RpcSession::FdTrigger> mShutdownTrigger;
+    std::shared_ptr<RpcSession::FdTrigger> mShutdownTrigger;  // also passed to each new session
     std::condition_variable mShutdownCv;
 };
 
diff --git a/libs/binder/include/binder/RpcSession.h b/libs/binder/include/binder/RpcSession.h
index d6b796f..d46f275 100644
--- a/libs/binder/include/binder/RpcSession.h
+++ b/libs/binder/include/binder/RpcSession.h
@@ -94,27 +94,16 @@
     // internal only
     const std::unique_ptr<RpcState>& state() { return mState; }
 
-    class PrivateAccessorForId {
-    private:
-        friend class RpcSession;
-        friend class RpcState;
-        explicit PrivateAccessorForId(const RpcSession* session) : mSession(session) {}
-
-        const std::optional<int32_t> get() { return mSession->mId; }
-
-        const RpcSession* mSession;
-    };
-    PrivateAccessorForId getPrivateAccessorForId() const { return PrivateAccessorForId(this); }
-
 private:
-    friend PrivateAccessorForId;
     friend sp<RpcSession>;
     friend RpcServer;
+    friend RpcState;
     RpcSession();
 
     /** This is not a pipe. */
     struct FdTrigger {
         static std::unique_ptr<FdTrigger> make();
+
         /**
          * poll() on this fd for POLLHUP to get notification when trigger is called
          */
@@ -167,7 +156,8 @@
     bool setupSocketClient(const RpcSocketAddress& address);
     bool setupOneSocketClient(const RpcSocketAddress& address, int32_t sessionId);
     void addClientConnection(base::unique_fd fd);
-    void setForServer(const wp<RpcServer>& server, int32_t sessionId);
+    void setForServer(const wp<RpcServer>& server, int32_t sessionId,
+                      const std::shared_ptr<FdTrigger>& shutdownTrigger);
     sp<RpcConnection> assignServerToThisThread(base::unique_fd fd);
     bool removeServerConnection(const sp<RpcConnection>& connection);
 
@@ -218,6 +208,8 @@
     // TODO(b/183988761): this shouldn't be guessable
     std::optional<int32_t> mId;
 
+    std::shared_ptr<FdTrigger> mShutdownTrigger;  // set by setForServer() or addClientConnection()
+
     std::unique_ptr<RpcState> mState;
 
     std::mutex mMutex; // for all below