libbinder: RPC allow RpcSession to be reusable

Previously, there were two ways setup could fail:
- very quickly (e.g. cannot create an fd to poll)
- after some delay (e.g. the second connection is messed up)

In either case, an error is returned from the setup* functions. However,
in the second case, if setup* was called again, it would result in an
abort. When connections fail for unrelated reasons, this sometimes
causes aborts in existing tests.
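
With this change, a failed setup resets the session so it can simply be
retried on the same object. A minimal sketch of the intended pattern
(setupUnixDomainClient is one of the setup* entry points; the socket
path here is illustrative):

    #include <binder/RpcSession.h>

    using android::OK;
    using android::RpcSession;
    using android::sp;
    using android::status_t;

    void connectWithRetry() {
        sp<RpcSession> session = RpcSession::make();
        status_t status = session->setupUnixDomainClient("/dev/socket/example");
        if (status != OK) {
            // Previously, this second call could hit a LOG_ALWAYS_FATAL_IF
            // on leftover state from the failed attempt; now the session is
            // reset on failure, so retrying on the same object is safe.
            status = session->setupUnixDomainClient("/dev/socket/example");
        }
    }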

Two alternatives were considered:
- switch to a factory-type setup - this seems a bit heavy, chiefly
  because typically only one RpcSession is needed, so it is annoying to
  have to create both a factory and an object.
- disallow setup* from being called multiple times - this breaks some
  of our tests, and it adds work to clients.
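
The approach taken instead: setupClient installs a ScopeGuard that, on
failure, tears down any partially constructed state (threads, shutdown
trigger, binder state, connection lists) and restores the session to a
freshly constructed one; the guard is disabled once setup succeeds. In
isolation, the pattern looks like this (doStepOne/doStepTwo are
placeholders, not real functions):

    #include <android-base/scopeguard.h>
    #include <utils/Errors.h>

    android::status_t doStepOne();  // placeholder
    android::status_t doStepTwo();  // placeholder

    android::status_t setupExample() {
        // Runs on every early return below, unless Disable() is reached.
        auto cleanup = android::base::ScopeGuard([&] {
            // restore the object to a freshly constructed state here
        });
        if (android::status_t s = doStepOne(); s != android::OK) return s;
        if (android::status_t s = doStepTwo(); s != android::OK) return s;
        cleanup.Disable();  // success: keep everything
        return android::OK;
    }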

Bug: 200737956
Test: manual
Change-Id: Ia6a69a7d2ca6c6835844cd9a90c7d24646a83526
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index 65f6bc6..4465b8e 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -28,6 +28,7 @@
 
 #include <android-base/hex.h>
 #include <android-base/macros.h>
+#include <android-base/scopeguard.h>
 #include <android_runtime/vm.h>
 #include <binder/BpBinder.h>
 #include <binder/Parcel.h>
@@ -54,13 +55,13 @@
 RpcSession::RpcSession(std::unique_ptr<RpcTransportCtx> ctx) : mCtx(std::move(ctx)) {
     LOG_RPC_DETAIL("RpcSession created %p", this);
 
-    mState = std::make_unique<RpcState>();
+    mRpcBinderState = std::make_unique<RpcState>();
 }
 RpcSession::~RpcSession() {
     LOG_RPC_DETAIL("RpcSession destroyed %p", this);
 
     std::lock_guard<std::mutex> _l(mMutex);
-    LOG_ALWAYS_FATAL_IF(mIncomingConnections.size() != 0,
+    LOG_ALWAYS_FATAL_IF(mThreadState.mIncomingConnections.size() != 0,
                         "Should not be able to destroy a session with servers in use.");
 }
 
@@ -77,10 +78,12 @@
 
 void RpcSession::setMaxThreads(size_t threads) {
     std::lock_guard<std::mutex> _l(mMutex);
-    LOG_ALWAYS_FATAL_IF(!mOutgoingConnections.empty() || !mIncomingConnections.empty(),
+    LOG_ALWAYS_FATAL_IF(!mThreadState.mOutgoingConnections.empty() ||
+                                !mThreadState.mIncomingConnections.empty(),
                         "Must set max threads before setting up connections, but has %zu client(s) "
                         "and %zu server(s)",
-                        mOutgoingConnections.size(), mIncomingConnections.size());
+                        mThreadState.mOutgoingConnections.size(),
+                        mThreadState.mIncomingConnections.size());
     mMaxThreads = threads;
 }
 
@@ -194,11 +197,11 @@
         LOG_ALWAYS_FATAL_IF(mShutdownListener == nullptr, "Shutdown listener not installed");
         mShutdownListener->waitForShutdown(_l, sp<RpcSession>::fromExisting(this));
 
-        LOG_ALWAYS_FATAL_IF(!mThreads.empty(), "Shutdown failed");
+        LOG_ALWAYS_FATAL_IF(!mThreadState.mThreads.empty(), "Shutdown failed");
     }
 
     _l.unlock();
-    mState->clear();
+    mRpcBinderState->clear();
 
     return true;
 }
@@ -260,11 +263,11 @@
 
 void RpcSession::WaitForShutdownListener::waitForShutdown(std::unique_lock<std::mutex>& lock,
                                                           const sp<RpcSession>& session) {
-    while (session->mIncomingConnections.size() > 0) {
+    while (session->mThreadState.mIncomingConnections.size() > 0) {
         if (std::cv_status::timeout == mCv.wait_for(lock, std::chrono::seconds(1))) {
             ALOGE("Waiting for RpcSession to shut down (1s w/o progress): %zu incoming connections "
                   "still.",
-                  session->mIncomingConnections.size());
+                  session->mThreadState.mIncomingConnections.size());
         }
     }
 }
@@ -274,7 +277,7 @@
 
     {
         std::lock_guard<std::mutex> _l(mMutex);
-        mThreads[thread.get_id()] = std::move(thread);
+        mThreadState.mThreads[thread.get_id()] = std::move(thread);
     }
 }
 
@@ -289,7 +292,8 @@
     if (connection == nullptr) {
         status = DEAD_OBJECT;
     } else {
-        status = mState->readConnectionInit(connection, sp<RpcSession>::fromExisting(this));
+        status =
+                mRpcBinderState->readConnectionInit(connection, sp<RpcSession>::fromExisting(this));
     }
 
     return PreJoinSetupResult{
@@ -376,10 +380,10 @@
     sp<RpcSession::EventListener> listener;
     {
         std::lock_guard<std::mutex> _l(session->mMutex);
-        auto it = session->mThreads.find(std::this_thread::get_id());
-        LOG_ALWAYS_FATAL_IF(it == session->mThreads.end());
+        auto it = session->mThreadState.mThreads.find(std::this_thread::get_id());
+        LOG_ALWAYS_FATAL_IF(it == session->mThreadState.mThreads.end());
         it->second.detach();
-        session->mThreads.erase(it);
+        session->mThreadState.mThreads.erase(it);
 
         listener = session->mEventListener.promote();
     }
@@ -410,12 +414,34 @@
                                                               bool incoming)>& connectAndInit) {
     {
         std::lock_guard<std::mutex> _l(mMutex);
-        LOG_ALWAYS_FATAL_IF(mOutgoingConnections.size() != 0,
+        LOG_ALWAYS_FATAL_IF(mThreadState.mOutgoingConnections.size() != 0,
                             "Must only setup session once, but already has %zu clients",
-                            mOutgoingConnections.size());
+                            mThreadState.mOutgoingConnections.size());
     }
+
     if (auto status = initShutdownTrigger(); status != OK) return status;
 
+    auto oldProtocolVersion = mProtocolVersion;
+    auto cleanup = base::ScopeGuard([&] {
+        // if any threads are started, shut them down
+        (void)shutdownAndWait(true);
+
+        mShutdownListener = nullptr;
+        mEventListener.clear();
+
+        mId.clear();
+
+        mShutdownTrigger = nullptr;
+        mRpcBinderState = std::make_unique<RpcState>();
+
+        // protocol version may have been downgraded - if we reuse this object
+        // to connect to another server, force that server to request a
+        // downgrade again
+        mProtocolVersion = oldProtocolVersion;
+
+        mThreadState = {};
+    });
+
     if (status_t status = connectAndInit({}, false /*incoming*/); status != OK) return status;
 
     {
@@ -464,6 +490,8 @@
         if (status_t status = connectAndInit(mId, true /*incoming*/); status != OK) return status;
     }
 
+    cleanup.Disable();
+
     return OK;
 }
 
@@ -634,12 +662,12 @@
         std::lock_guard<std::mutex> _l(mMutex);
         connection->rpcTransport = std::move(rpcTransport);
         connection->exclusiveTid = gettid();
-        mOutgoingConnections.push_back(connection);
+        mThreadState.mOutgoingConnections.push_back(connection);
     }
 
     status_t status = OK;
     if (init) {
-        mState->sendConnectionInit(connection, sp<RpcSession>::fromExisting(this));
+        mRpcBinderState->sendConnectionInit(connection, sp<RpcSession>::fromExisting(this));
     }
 
     {
@@ -671,9 +699,9 @@
         std::unique_ptr<RpcTransport> rpcTransport) {
     std::lock_guard<std::mutex> _l(mMutex);
 
-    if (mIncomingConnections.size() >= mMaxThreads) {
+    if (mThreadState.mIncomingConnections.size() >= mMaxThreads) {
         ALOGE("Cannot add thread to session with %zu threads (max is set to %zu)",
-              mIncomingConnections.size(), mMaxThreads);
+              mThreadState.mIncomingConnections.size(), mMaxThreads);
         return nullptr;
     }
 
@@ -681,7 +709,7 @@
     // happens when new connections are still being established as part of a
     // very short-lived session which shuts down after it already started
     // accepting new connections.
-    if (mIncomingConnections.size() < mMaxIncomingConnections) {
+    if (mThreadState.mIncomingConnections.size() < mThreadState.mMaxIncomingConnections) {
         return nullptr;
     }
 
@@ -689,18 +717,19 @@
     session->rpcTransport = std::move(rpcTransport);
     session->exclusiveTid = gettid();
 
-    mIncomingConnections.push_back(session);
-    mMaxIncomingConnections = mIncomingConnections.size();
+    mThreadState.mIncomingConnections.push_back(session);
+    mThreadState.mMaxIncomingConnections = mThreadState.mIncomingConnections.size();
 
     return session;
 }
 
 bool RpcSession::removeIncomingConnection(const sp<RpcConnection>& connection) {
     std::unique_lock<std::mutex> _l(mMutex);
-    if (auto it = std::find(mIncomingConnections.begin(), mIncomingConnections.end(), connection);
-        it != mIncomingConnections.end()) {
-        mIncomingConnections.erase(it);
-        if (mIncomingConnections.size() == 0) {
+    if (auto it = std::find(mThreadState.mIncomingConnections.begin(),
+                            mThreadState.mIncomingConnections.end(), connection);
+        it != mThreadState.mIncomingConnections.end()) {
+        mThreadState.mIncomingConnections.erase(it);
+        if (mThreadState.mIncomingConnections.size() == 0) {
             sp<EventListener> listener = mEventListener.promote();
             if (listener) {
                 _l.unlock();
@@ -725,7 +754,7 @@
     pid_t tid = gettid();
     std::unique_lock<std::mutex> _l(session->mMutex);
 
-    session->mWaitingThreads++;
+    session->mThreadState.mWaitingThreads++;
     while (true) {
         sp<RpcConnection> exclusive;
         sp<RpcConnection> available;
@@ -733,8 +762,8 @@
         // CHECK FOR DEDICATED CLIENT SOCKET
         //
         // A server/looper should always use a dedicated connection if available
-        findConnection(tid, &exclusive, &available, session->mOutgoingConnections,
-                       session->mOutgoingConnectionsOffset);
+        findConnection(tid, &exclusive, &available, session->mThreadState.mOutgoingConnections,
+                       session->mThreadState.mOutgoingConnectionsOffset);
 
         // WARNING: this assumes a server cannot request its client to send
         // a transaction, as mIncomingConnections is excluded below.
@@ -747,8 +776,9 @@
         // command. So, we move to considering the second available thread
         // for subsequent calls.
         if (use == ConnectionUse::CLIENT_ASYNC && (exclusive != nullptr || available != nullptr)) {
-            session->mOutgoingConnectionsOffset = (session->mOutgoingConnectionsOffset + 1) %
-                    session->mOutgoingConnections.size();
+            session->mThreadState.mOutgoingConnectionsOffset =
+                    (session->mThreadState.mOutgoingConnectionsOffset + 1) %
+                    session->mThreadState.mOutgoingConnections.size();
         }
 
         // USE SERVING SOCKET (e.g. nested transaction)
@@ -756,7 +786,7 @@
             sp<RpcConnection> exclusiveIncoming;
             // server connections are always assigned to a thread
             findConnection(tid, &exclusiveIncoming, nullptr /*available*/,
-                           session->mIncomingConnections, 0 /* index hint */);
+                           session->mThreadState.mIncomingConnections, 0 /* index hint */);
 
             // asynchronous calls cannot be nested, we currently allow ref count
             // calls to be nested (so that you can use this without having extra
@@ -785,19 +815,20 @@
             break;
         }
 
-        if (session->mOutgoingConnections.size() == 0) {
+        if (session->mThreadState.mOutgoingConnections.size() == 0) {
             ALOGE("Session has no client connections. This is required for an RPC server to make "
                   "any non-nested (e.g. oneway or on another thread) calls. Use: %d. Server "
                   "connections: %zu",
-                  static_cast<int>(use), session->mIncomingConnections.size());
+                  static_cast<int>(use), session->mThreadState.mIncomingConnections.size());
             return WOULD_BLOCK;
         }
 
         LOG_RPC_DETAIL("No available connections (have %zu clients and %zu servers). Waiting...",
-                       session->mOutgoingConnections.size(), session->mIncomingConnections.size());
+                       session->mThreadState.mOutgoingConnections.size(),
+                       session->mThreadState.mIncomingConnections.size());
         session->mAvailableConnectionCv.wait(_l);
     }
-    session->mWaitingThreads--;
+    session->mThreadState.mWaitingThreads--;
 
     return OK;
 }
@@ -836,7 +867,7 @@
     if (!mReentrant && mConnection != nullptr) {
         std::unique_lock<std::mutex> _l(mSession->mMutex);
         mConnection->exclusiveTid = std::nullopt;
-        if (mSession->mWaitingThreads > 0) {
+        if (mSession->mThreadState.mWaitingThreads > 0) {
             _l.unlock();
             mSession->mAvailableConnectionCv.notify_one();
         }