Add RpcServer::shutdown.

shutdown() makes any currently running join() return, and waits until it has.

After this CL, only one thread may be inside join() at a time.

Test: binderLibTest
Change-Id: I5f1abbb39ee42a8f94b7394a702a152701537e7e
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp
index 59659bd..e31aea0 100644
--- a/libs/binder/RpcServer.cpp
+++ b/libs/binder/RpcServer.cpp
@@ -16,19 +16,21 @@
 
 #define LOG_TAG "RpcServer"
 
+#include <poll.h>
 #include <sys/socket.h>
 #include <sys/un.h>
 
 #include <thread>
 #include <vector>
 
+#include <android-base/macros.h>
 #include <android-base/scopeguard.h>
 #include <binder/Parcel.h>
 #include <binder/RpcServer.h>
 #include <log/log.h>
-#include "RpcState.h"
 
 #include "RpcSocketAddress.h"
+#include "RpcState.h"
 #include "RpcWireFormat.h"
 
 namespace android {
@@ -99,7 +101,10 @@
 
+// Stores |threads| as mMaxThreads. Aborts if |threads| is 0 or while join() is running.
 void RpcServer::setMaxThreads(size_t threads) {
     LOG_ALWAYS_FATAL_IF(threads <= 0, "RpcServer is useless without threads");
-    LOG_ALWAYS_FATAL_IF(mStarted, "must be called before started");
+    // join() updates mJoinThreadRunning under mLock, so it must be read under mLock too.
+    std::lock_guard<std::mutex> _l(mLock);
+    LOG_ALWAYS_FATAL_IF(mJoinThreadRunning, "Cannot set max threads while running");
     mMaxThreads = threads;
 }
 
@@ -126,16 +128,71 @@
     return ret;
 }
 
+// Creates a trigger backed by a new pipe. Returns nullptr if the pipe cannot be created.
+std::unique_ptr<RpcServer::FdTrigger> RpcServer::FdTrigger::make() {
+    auto ret = std::make_unique<RpcServer::FdTrigger>();
+    if (!android::base::Pipe(&ret->mRead, &ret->mWrite)) return nullptr;
+    return ret;
+}
+
+// Closes the write end of the pipe; poll() on the read end then reports POLLHUP.
+void RpcServer::FdTrigger::trigger() {
+    mWrite.reset();
+}
+
+// Accepts connections until shutdown() is called from another thread; only one thread may be in
+// join() at a time. If poll() fails (other than EINTR), join() logs and returns instead of
+// spinning; shutdown() must still be called before join() can be entered again.
 void RpcServer::join() {
-    while (true) {
-        (void)acceptOne();
+    LOG_ALWAYS_FATAL_IF(!mAgreedExperimental, "no!");
+
+    {
+        std::lock_guard<std::mutex> _l(mLock);
+        LOG_ALWAYS_FATAL_IF(!mServer.ok(), "RpcServer must be setup to join.");
+        LOG_ALWAYS_FATAL_IF(mShutdownTrigger != nullptr, "Already joined");
+        mJoinThreadRunning = true;
+        mShutdownTrigger = FdTrigger::make();
+        LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr, "Cannot create join signaler");
     }
+
+    // mShutdownTrigger is read without mLock below. This is safe because shutdown() only resets
+    // it after seeing mJoinThreadRunning == false, which is set after this loop exits.
+    while (true) {
+        pollfd pfd[]{{.fd = mServer.get(), .events = POLLIN, .revents = 0},
+                     {.fd = mShutdownTrigger->readFd().get(), .events = POLLHUP, .revents = 0}};
+        int ret = TEMP_FAILURE_RETRY(poll(pfd, arraysize(pfd), -1));
+        if (ret < 0) {
+            // EINTR is retried above; retrying any other error would busy-loop and flood the log.
+            ALOGE("Could not poll socket: %s", strerror(errno));
+            break;
+        }
+        if (ret == 0) {
+            continue;
+        }
+        if (pfd[1].revents & POLLHUP) {
+            LOG_RPC_DETAIL("join() exiting because shutdown requested.");
+            break;
+        }
+
+        (void)acceptOneNoCheck();
+    }
+
+    {
+        std::lock_guard<std::mutex> _l(mLock);
+        mJoinThreadRunning = false;
+        // Notify while holding mLock so this thread is done with mShutdownCv before shutdown()
+        // can reacquire the lock and return to a caller that may then destroy the server.
+        mShutdownCv.notify_all();
+    }
 }
 
 bool RpcServer::acceptOne() {
     LOG_ALWAYS_FATAL_IF(!mAgreedExperimental, "no!");
-    LOG_ALWAYS_FATAL_IF(!hasServer(), "RpcServer must be setup to join.");
+    LOG_ALWAYS_FATAL_IF(!hasServer(), "RpcServer must be setup to acceptOne.");
+    return acceptOneNoCheck();
+}
 
+bool RpcServer::acceptOneNoCheck() {
     unique_fd clientFd(
             TEMP_FAILURE_RETRY(accept4(mServer.get(), nullptr, nullptr /*length*/, SOCK_CLOEXEC)));
 
@@ -156,6 +203,18 @@
     return true;
 }
 
+// Asks a running join() to exit and blocks until it has. Returns false (without waiting) if join()
+// has not been called since construction or since the last successful shutdown().
+bool RpcServer::shutdown() {
+    LOG_ALWAYS_FATAL_IF(!mAgreedExperimental, "no!");
+    std::unique_lock<std::mutex> _l(mLock);
+    if (!mShutdownTrigger) return false;
+    mShutdownTrigger->trigger();
+    mShutdownCv.wait(_l, [this] { return !mJoinThreadRunning; });
+    mShutdownTrigger.reset();
+    return true;
+}
+
 std::vector<sp<RpcSession>> RpcServer::listSessions() {
     std::lock_guard<std::mutex> _l(mLock);
     std::vector<sp<RpcSession>> sessions;
diff --git a/libs/binder/include/binder/RpcServer.h b/libs/binder/include/binder/RpcServer.h
index 8f0c6fd..4973400 100644
--- a/libs/binder/include/binder/RpcServer.h
+++ b/libs/binder/include/binder/RpcServer.h
@@ -119,11 +119,22 @@
     /**
      * You must have at least one client session before calling this.
      *
-     * TODO(b/185167543): way to shut down?
+     * To make join() return, call shutdown() from another thread.
+     *
+     * Only one thread may be inside join() at a time; a second concurrent call is fatal.
      */
     void join();
 
     /**
+     * Stops a running join() and waits for it to return. Returns true on success, or false if
+     * there is no join() to stop (e.g. join() was never called). Because it waits for join() to
+     * exit, it must not be called from the thread that is inside join().
+     *
+     * TODO(b/185167543): wait for sessions to shut down as well
+     */
+    [[nodiscard]] bool shutdown();
+
+    /**
      * Accept one connection on this server. You must have at least one client
      * session before calling this.
      */
@@ -142,14 +153,31 @@
     void onSessionTerminating(const sp<RpcSession>& session);
 
 private:
+    /** One-shot signal: trigger() closes a pipe's write end so readFd() reports POLLHUP. */
+    struct FdTrigger {
+        static std::unique_ptr<FdTrigger> make(); // returns nullptr if creating the pipe fails
+        /**
+         * poll() this fd for POLLHUP to be notified once trigger() has been called.
+         */
+        base::borrowed_fd readFd() const { return mRead; }
+        /**
+         * Close the write end of the pipe so that the read end receives POLLHUP.
+         */
+        void trigger();
+
+    private:
+        base::unique_fd mWrite;
+        base::unique_fd mRead;
+    };
+
     friend sp<RpcServer>;
     RpcServer();
 
     void establishConnection(sp<RpcServer>&& session, base::unique_fd clientFd);
     bool setupSocketServer(const RpcSocketAddress& address);
+    [[nodiscard]] bool acceptOneNoCheck(); // acceptOne() minus its precondition checks
 
     bool mAgreedExperimental = false;
-    bool mStarted = false; // TODO(b/185167543): support dynamically added clients
     size_t mMaxThreads = 1;
     base::unique_fd mServer; // socket we are accepting sessions on
 
@@ -159,6 +187,9 @@
     wp<IBinder> mRootObjectWeak;
     std::map<int32_t, sp<RpcSession>> mSessions;
     int32_t mSessionIdCounter = 0;
+    bool mJoinThreadRunning = false; // set and cleared by join() under mLock
+    std::unique_ptr<FdTrigger> mShutdownTrigger; // created by join(), reset by shutdown()
+    std::condition_variable mShutdownCv; // notified by join() when its accept loop ends
 };
 
 } // namespace android
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index a96deb5..fb0ffdb 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -40,6 +40,8 @@
 #include "../RpcState.h"   // for debugging
 #include "../vm_sockets.h" // for VMADDR_*
 
+using namespace std::chrono_literals;
+
 namespace android {
 
 TEST(BinderRpcParcel, EntireParcelFormatted) {
@@ -970,6 +972,54 @@
 INSTANTIATE_TEST_CASE_P(BinderRpc, BinderRpcServerRootObject,
                         ::testing::Combine(::testing::Bool(), ::testing::Bool()));
 
+// One-shot latch: after notify(), every pending and subsequent wait() returns true.
+class OneOffSignal {
+public:
+    // Returns true if notify() was called before or within |duration|, false on timeout.
+    template <typename R, typename P>
+    bool wait(std::chrono::duration<R, P> duration) {
+        std::unique_lock<std::mutex> guard(mMutex);
+        return mCv.wait_for(guard, duration, [this] { return mNotified; });
+    }
+    void notify() {
+        std::unique_lock<std::mutex> guard(mMutex);
+        mNotified = true;
+        guard.unlock(); // waiters wake without immediately contending for mMutex
+        mCv.notify_all();
+    }
+private:
+    std::mutex mMutex;
+    std::condition_variable mCv;
+    bool mNotified = false;
+};
+
+TEST(BinderRpc, Shutdown) {
+    auto addr = allocateSocketAddress();
+    unlink(addr.c_str());
+    auto server = RpcServer::make();
+    server->iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction();
+    ASSERT_TRUE(server->setupUnixDomainServer(addr.c_str()));
+
+    // The join thread is detached so a hung join() cannot block the rest of the suite. Since it
+    // may outlive this test body, it only touches state shared through the captured pointers.
+    auto joinEnds = std::make_shared<OneOffSignal>();
+    std::thread joinThread([server, joinEnds] {
+        server->join();
+        joinEnds->notify();
+    });
+    joinThread.detach();
+
+    // shutdown() returns false until join() has started, so retry every 300ms for up to 3s.
+    bool stopped = false;
+    for (int attempt = 0; !stopped && attempt < 10; ++attempt) {
+        usleep(300 * 1000);
+        stopped = server->shutdown();
+    }
+    ASSERT_TRUE(stopped) << "server->shutdown() never returns true";
+    ASSERT_TRUE(joinEnds->wait(2s))
+            << "After server->shutdown() returns true, join() did not stop after 2s";
+}
+
 } // namespace android
 
 int main(int argc, char** argv) {