libbinder: reverse connections

When connecting to an RPC server as a client, you can request that the
client serve a threadpool over reverse connections, so that it can receive
callbacks from the server.

Future considerations:
- starting threads dynamically (likely very, very soon after this CL)
- combining threadpools (as needed)

Bug: 185167543
Test: binderRpcTest
Change-Id: I992959e963ebc1b3da2f89fdb6c1ec625cb51af4
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index ccf7f89..a3efa56 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -59,6 +59,17 @@
     return sp<RpcSession>::make();
 }
 
+void RpcSession::setMaxReverseConnections(size_t connections) {
+    std::lock_guard<std::mutex> _l(mMutex);
+    // Reverse connections are created right after the client connections.
+    LOG_ALWAYS_FATAL_IF(mClientConnections.size() != 0,
+                        "Must setup reverse connections before setting up client connections, "
+                        "but already has %zu clients",
+                        mClientConnections.size());
+    // written under mMutex so the check above and this update are atomic
+    mMaxReverseConnections = connections;
+}
+
 bool RpcSession::setupUnixDomainClient(const char* path) {
     return setupSocketClient(UnixSocketAddress(path));
 }
@@ -99,6 +110,20 @@
     return state()->getMaxThreads(connection.fd(), sp<RpcSession>::fromExisting(this), maxThreads);
 }
 
+bool RpcSession::shutdown() {
+    std::unique_lock<std::mutex> _l(mMutex);
+    LOG_ALWAYS_FATAL_IF(mForServer != nullptr, "Can only shut down client session");
+    LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr, "Shutdown trigger not installed");
+    LOG_ALWAYS_FATAL_IF(mShutdownListener == nullptr, "Shutdown listener not installed");
+
+    mShutdownTrigger->trigger();
+    // The listener is only signalled from join(); with no threads, waiting would hang.
+    if (!mThreads.empty()) mShutdownListener->waitForShutdown(_l);
+    mState->terminate();
+    LOG_ALWAYS_FATAL_IF(!mThreads.empty(), "Shutdown failed");
+    return true;
+}
+
 status_t RpcSession::transact(const sp<IBinder>& binder, uint32_t code, const Parcel& data,
                               Parcel* reply, uint32_t flags) {
     ExclusiveConnection connection(sp<RpcSession>::fromExisting(this),
@@ -179,6 +204,24 @@
     return OK;
 }
 
+void RpcSession::WaitForShutdownListener::onSessionLockedAllServerThreadsEnded(
+        const sp<RpcSession>& /*session*/) {
+    mShutdown = true;  // read by waitForShutdown(); waiters are woken per ended thread
+}
+
+void RpcSession::WaitForShutdownListener::onSessionServerThreadEnded() {
+    mCv.notify_all();
+}
+
+void RpcSession::WaitForShutdownListener::waitForShutdown(std::unique_lock<std::mutex>& lock) {
+    while (!mShutdown) {
+        const std::cv_status status = mCv.wait_for(lock, std::chrono::seconds(1));
+        if (status == std::cv_status::timeout) {
+            ALOGE("Waiting for RpcSession to shut down (1s w/o progress).");
+        }
+    }
+}
+
 void RpcSession::preJoin(std::thread thread) {
     LOG_ALWAYS_FATAL_IF(thread.get_id() != std::this_thread::get_id(), "Must own this thread");
 
@@ -188,14 +231,14 @@
     }
 }
 
-void RpcSession::join(unique_fd client) {
+// Serves |client| on the calling thread; |session| is held until this thread's cleanup ends.
+void RpcSession::join(sp<RpcSession>&& session, unique_fd client) {
     // must be registered to allow arbitrary client code executing commands to
     // be able to do nested calls (we can't only read from it)
-    sp<RpcConnection> connection = assignServerToThisThread(std::move(client));
+    sp<RpcConnection> connection = session->assignServerToThisThread(std::move(client));
 
     while (true) {
-        status_t error =
-                state()->getAndExecuteCommand(connection->fd, sp<RpcSession>::fromExisting(this));
+        status_t error = session->state()->getAndExecuteCommand(connection->fd, session);
 
         if (error != OK) {
             LOG_RPC_DETAIL("Binder connection thread closing w/ status %s",
@@ -204,22 +247,23 @@
         }
     }
 
-    LOG_ALWAYS_FATAL_IF(!removeServerConnection(connection),
-                        "bad state: connection object guaranteed to be in list");
-
-    sp<RpcServer> server;
+    sp<RpcSession::EventListener> listener;
     {
-        std::lock_guard<std::mutex> _l(mMutex);
-        auto it = mThreads.find(std::this_thread::get_id());
-        LOG_ALWAYS_FATAL_IF(it == mThreads.end());
+        std::lock_guard<std::mutex> _l(session->mMutex);
+        auto it = session->mThreads.find(std::this_thread::get_id());
+        LOG_ALWAYS_FATAL_IF(it == session->mThreads.end());
         it->second.detach();
-        mThreads.erase(it);
+        session->mThreads.erase(it);
 
-        server = mForServer.promote();
+        listener = session->mEventListener.promote();
     }
 
-    if (server != nullptr) {
-        server->onSessionServerThreadEnded(sp<RpcSession>::fromExisting(this));
+    // Done after thread cleanup: the last removal is what lets shutdown() stop waiting.
+    LOG_ALWAYS_FATAL_IF(!session->removeServerConnection(connection),
+                        "bad state: connection object guaranteed to be in list");
+    session = nullptr;
+    if (listener != nullptr) {
+        listener->onSessionServerThreadEnded();
     }
 }
 
@@ -235,7 +279,7 @@
                             mClientConnections.size());
     }
 
-    if (!setupOneSocketClient(addr, RPC_SESSION_ID_NEW)) return false;
+    if (!setupOneSocketConnection(addr, RPC_SESSION_ID_NEW, false /*reverse*/)) return false;
 
     // TODO(b/185167543): we should add additional sessions dynamically
     // instead of all at once.
@@ -256,13 +300,23 @@
     // we've already setup one client
     for (size_t i = 0; i + 1 < numThreadsAvailable; i++) {
         // TODO(b/185167543): shutdown existing connections?
-        if (!setupOneSocketClient(addr, mId.value())) return false;
+        if (!setupOneSocketConnection(addr, mId.value(), false /*reverse*/)) return false;
+    }
+
+    // TODO(b/185167543): reverse connections should be added dynamically
+    // instead of all at once - the other side should be responsible for
+    // requesting additional ones. At least one is needed (unless 0 were
+    // requested) so that the other side can reliably make any requests at
+    // all.
+
+    for (size_t i = 0; i < mMaxReverseConnections; i++) {
+        if (!setupOneSocketConnection(addr, mId.value(), true /*reverse*/)) return false;
     }
 
     return true;
 }
 
-bool RpcSession::setupOneSocketClient(const RpcSocketAddress& addr, int32_t id) {
+bool RpcSession::setupOneSocketConnection(const RpcSocketAddress& addr, int32_t id, bool reverse) {
     for (size_t tries = 0; tries < 5; tries++) {
         if (tries > 0) usleep(10000);
 
@@ -286,16 +340,47 @@
             return false;
         }
 
-        if (sizeof(id) != TEMP_FAILURE_RETRY(write(serverFd.get(), &id, sizeof(id)))) {
+        RpcConnectionHeader header{
+                .sessionId = id,
+        };
+        if (reverse) header.options |= RPC_CONNECTION_OPTION_REVERSE;
+
+        if (sizeof(header) != TEMP_FAILURE_RETRY(write(serverFd.get(), &header, sizeof(header)))) {
             int savedErrno = errno;
-            ALOGE("Could not write id to socket at %s: %s", addr.toString().c_str(),
+            ALOGE("Could not write connection header to socket at %s: %s", addr.toString().c_str(),
                   strerror(savedErrno));
             return false;
         }
 
         LOG_RPC_DETAIL("Socket at %s client with fd %d", addr.toString().c_str(), serverFd.get());
 
-        return addClientConnection(std::move(serverFd));
+        if (!reverse) {
+            return addClientConnection(std::move(serverFd));
+        }
+
+        // Reverse connection: a new thread takes the fd and a strong ref to this session and
+        // serves commands on it. Block until that thread has taken ownership of its handle.
+        std::mutex handoffMutex;
+        std::condition_variable handoffCv;
+        std::unique_lock<std::mutex> handoffLock(handoffMutex);
+        std::thread serverThread;
+        sp<RpcSession> self = sp<RpcSession>::fromExisting(this);
+        bool handedOff = false;
+        serverThread = std::thread([&]() {
+            std::unique_lock<std::mutex> threadLock(handoffMutex);
+            unique_fd fd = std::move(serverFd);
+            // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
+            sp<RpcSession> session = self;
+            session->preJoin(std::move(serverThread));
+            handedOff = true;
+            handoffCv.notify_one();
+            threadLock.unlock();
+            // captured locals may be destroyed from here on; do not use them below
+            RpcSession::join(std::move(session), std::move(fd));
+        });
+        handoffCv.wait(handoffLock, [&] { return handedOff; });
+        LOG_ALWAYS_FATAL_IF(!handedOff);
+        return true;
     }
 
     ALOGE("Ran out of retries to connect to %s", addr.toString().c_str());
@@ -305,8 +390,11 @@
 bool RpcSession::addClientConnection(unique_fd fd) {
     std::lock_guard<std::mutex> _l(mMutex);
 
+    // No trigger yet: setForServer() was not called, so this is the first connection
+    // of a client session - install its shutdown trigger and shutdown listener.
     if (mShutdownTrigger == nullptr) {
         mShutdownTrigger = FdTrigger::make();
+        mEventListener = mShutdownListener = sp<WaitForShutdownListener>::make();
         if (mShutdownTrigger == nullptr) return false;
     }
 
@@ -316,14 +404,19 @@
     return true;
 }
 
-void RpcSession::setForServer(const wp<RpcServer>& server, int32_t sessionId,
+void RpcSession::setForServer(const wp<RpcServer>& server, const wp<EventListener>& eventListener,
+                              int32_t sessionId,
                               const std::shared_ptr<FdTrigger>& shutdownTrigger) {
-    LOG_ALWAYS_FATAL_IF(mForServer.unsafe_get() != nullptr);
+    LOG_ALWAYS_FATAL_IF(mForServer != nullptr);
+    LOG_ALWAYS_FATAL_IF(server == nullptr);
+    LOG_ALWAYS_FATAL_IF(mEventListener != nullptr);
+    LOG_ALWAYS_FATAL_IF(eventListener == nullptr);
     LOG_ALWAYS_FATAL_IF(mShutdownTrigger != nullptr);
     LOG_ALWAYS_FATAL_IF(shutdownTrigger == nullptr);
 
     mId = sessionId;
     mForServer = server;
+    mEventListener = eventListener;  // notified as this session's server threads end
     mShutdownTrigger = shutdownTrigger;
 }
 
@@ -343,9 +436,9 @@
         it != mServerConnections.end()) {
         mServerConnections.erase(it);
         if (mServerConnections.size() == 0) {
-            sp<RpcServer> server = mForServer.promote();
-            if (server) {
-                server->onSessionLockedAllServerThreadsEnded(sp<RpcSession>::fromExisting(this));
+            sp<EventListener> listener = mEventListener.promote();  // null if already destroyed
+            if (listener) {
+                listener->onSessionLockedAllServerThreadsEnded(sp<RpcSession>::fromExisting(this));
             }
         }
         return true;
@@ -405,6 +498,8 @@
             break;
         }
 
+        // TODO(b/185167543): return an error here instead of crashing the
+        // server when the session has no client connections.
         // in regular binder, this would usually be a deadlock :)
         LOG_ALWAYS_FATAL_IF(mSession->mClientConnections.size() == 0,
                             "Session has no client connections. This is required for an RPC server "