libbinder: reverse connections

When connecting to an RPC server as a client, you can request to also serve a
threadpool so that you can receive callbacks from that server.

Future considerations:
- starting threads dynamically (likely very, very soon after this CL)
- combining threadpools (as needed)

Bug: 185167543
Test: binderRpcTest
Change-Id: I992959e963ebc1b3da2f89fdb6c1ec625cb51af4
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp
index 77cae83..b146bb0 100644
--- a/libs/binder/RpcServer.cpp
+++ b/libs/binder/RpcServer.cpp
@@ -239,15 +239,16 @@
     // It must be set before this thread is started
     LOG_ALWAYS_FATAL_IF(server->mShutdownTrigger == nullptr);
 
-    int32_t id;
-    status_t status =
-            server->mShutdownTrigger->interruptableReadFully(clientFd.get(), &id, sizeof(id));
+    RpcConnectionHeader header{}; // zeroed so the fields stay defined if the read below fails
+    status_t status = server->mShutdownTrigger->interruptableReadFully(clientFd.get(), &header,
+                                                                       sizeof(header));
     bool idValid = status == OK;
     if (!idValid) {
         ALOGE("Failed to read ID for client connecting to RPC server: %s",
               statusToString(status).c_str());
         // still need to cleanup before we can return
     }
+    bool reverse = idValid && (header.options & RPC_CONNECTION_OPTION_REVERSE);
 
     std::thread thisThread;
     sp<RpcSession> session;
@@ -269,24 +270,37 @@
             return;
         }
 
-        if (id == RPC_SESSION_ID_NEW) {
+        if (header.sessionId == RPC_SESSION_ID_NEW) {
+            if (reverse) {
+                ALOGE("Cannot create a new session with a reverse connection, would leak");
+                return;
+            }
+
             LOG_ALWAYS_FATAL_IF(server->mSessionIdCounter >= INT32_MAX, "Out of session IDs");
             server->mSessionIdCounter++;
 
             session = RpcSession::make();
-            session->setForServer(wp<RpcServer>(server), server->mSessionIdCounter,
-                                  server->mShutdownTrigger);
+            session->setForServer(server,
+                                  sp<RpcServer::EventListener>::fromExisting(
+                                          static_cast<RpcServer::EventListener*>(server.get())),
+                                  server->mSessionIdCounter, server->mShutdownTrigger);
 
             server->mSessions[server->mSessionIdCounter] = session;
         } else {
-            auto it = server->mSessions.find(id);
+            auto it = server->mSessions.find(header.sessionId);
             if (it == server->mSessions.end()) {
-                ALOGE("Cannot add thread, no record of session with ID %d", id);
+                ALOGE("Cannot add thread, no record of session with ID %d", header.sessionId);
                 return;
             }
             session = it->second;
         }
 
+        if (reverse) {
+            LOG_ALWAYS_FATAL_IF(!session->addClientConnection(std::move(clientFd)),
+                                "server state must already be initialized");
+            return;
+        }
+
         detachGuard.Disable();
         session->preJoin(std::move(thisThread));
     }
@@ -294,7 +308,7 @@
     // avoid strong cycle
     server = nullptr;
 
-    session->join(std::move(clientFd));
+    RpcSession::join(std::move(session), std::move(clientFd));
 }
 
 bool RpcServer::setupSocketServer(const RpcSocketAddress& addr) {
@@ -341,8 +355,7 @@
     (void)mSessions.erase(it);
 }
 
-void RpcServer::onSessionServerThreadEnded(const sp<RpcSession>& session) {
-    (void)session;
+void RpcServer::onSessionServerThreadEnded() {
     mShutdownCv.notify_all();
 }
 
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index ccf7f89..a3efa56 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -59,6 +59,17 @@
     return sp<RpcSession>::make();
 }
 
+void RpcSession::setMaxReverseConnections(size_t connections) {
+    // mMaxReverseConnections is declared under mMutex ("for all below"), so the
+    // precondition check and the write both happen while holding the lock.
+    std::lock_guard<std::mutex> _l(mMutex);
+    LOG_ALWAYS_FATAL_IF(mClientConnections.size() != 0,
+                        "Must setup reverse connections before setting up client connections, "
+                        "but already has %zu clients",
+                        mClientConnections.size());
+    mMaxReverseConnections = connections;
+}
+
 bool RpcSession::setupUnixDomainClient(const char* path) {
     return setupSocketClient(UnixSocketAddress(path));
 }
@@ -99,6 +110,20 @@
     return state()->getMaxThreads(connection.fd(), sp<RpcSession>::fromExisting(this), maxThreads);
 }
 
+bool RpcSession::shutdown() {
+    std::unique_lock<std::mutex> _l(mMutex);
+    LOG_ALWAYS_FATAL_IF(mForServer.promote() != nullptr, "Can only shut down client session");
+    LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr, "Shutdown trigger not installed");
+    LOG_ALWAYS_FATAL_IF(mShutdownListener == nullptr, "Shutdown listener not installed");
+
+    mShutdownTrigger->trigger();
+    // The listener only fires as server (reverse) threads end; with none, waiting would hang.
+    if (!mThreads.empty()) mShutdownListener->waitForShutdown(_l);
+    mState->terminate();
+    LOG_ALWAYS_FATAL_IF(!mThreads.empty(), "Shutdown failed");
+    return true;
+}
+
 status_t RpcSession::transact(const sp<IBinder>& binder, uint32_t code, const Parcel& data,
                               Parcel* reply, uint32_t flags) {
     ExclusiveConnection connection(sp<RpcSession>::fromExisting(this),
@@ -179,6 +204,24 @@
     return OK;
 }
 
+void RpcSession::WaitForShutdownListener::onSessionLockedAllServerThreadsEnded(
+        const sp<RpcSession>& session) {
+    (void)session;
+    mShutdown = true;
+}
+
+void RpcSession::WaitForShutdownListener::onSessionServerThreadEnded() {
+    mCv.notify_all();
+}
+
+void RpcSession::WaitForShutdownListener::waitForShutdown(std::unique_lock<std::mutex>& lock) {
+    while (!mShutdown) {
+        if (std::cv_status::timeout == mCv.wait_for(lock, std::chrono::seconds(1))) {
+            ALOGE("Waiting for RpcSession to shut down (1s w/o progress).");
+        }
+    }
+}
+
 void RpcSession::preJoin(std::thread thread) {
     LOG_ALWAYS_FATAL_IF(thread.get_id() != std::this_thread::get_id(), "Must own this thread");
 
@@ -188,14 +231,13 @@
     }
 }
 
-void RpcSession::join(unique_fd client) {
+void RpcSession::join(sp<RpcSession>&& session, unique_fd client) {
     // must be registered to allow arbitrary client code executing commands to
     // be able to do nested calls (we can't only read from it)
-    sp<RpcConnection> connection = assignServerToThisThread(std::move(client));
+    sp<RpcConnection> connection = session->assignServerToThisThread(std::move(client));
 
     while (true) {
-        status_t error =
-                state()->getAndExecuteCommand(connection->fd, sp<RpcSession>::fromExisting(this));
+        status_t error = session->state()->getAndExecuteCommand(connection->fd, session);
 
         if (error != OK) {
             LOG_RPC_DETAIL("Binder connection thread closing w/ status %s",
@@ -204,22 +246,24 @@
         }
     }
 
-    LOG_ALWAYS_FATAL_IF(!removeServerConnection(connection),
+    LOG_ALWAYS_FATAL_IF(!session->removeServerConnection(connection),
                         "bad state: connection object guaranteed to be in list");
 
-    sp<RpcServer> server;
+    sp<RpcSession::EventListener> listener;
     {
-        std::lock_guard<std::mutex> _l(mMutex);
-        auto it = mThreads.find(std::this_thread::get_id());
-        LOG_ALWAYS_FATAL_IF(it == mThreads.end());
+        std::lock_guard<std::mutex> _l(session->mMutex);
+        auto it = session->mThreads.find(std::this_thread::get_id());
+        LOG_ALWAYS_FATAL_IF(it == session->mThreads.end());
         it->second.detach();
-        mThreads.erase(it);
+        session->mThreads.erase(it);
 
-        server = mForServer.promote();
+        listener = session->mEventListener.promote();
     }
 
-    if (server != nullptr) {
-        server->onSessionServerThreadEnded(sp<RpcSession>::fromExisting(this));
+    session = nullptr;
+
+    if (listener != nullptr) {
+        listener->onSessionServerThreadEnded();
     }
 }
 
@@ -235,7 +279,7 @@
                             mClientConnections.size());
     }
 
-    if (!setupOneSocketClient(addr, RPC_SESSION_ID_NEW)) return false;
+    if (!setupOneSocketConnection(addr, RPC_SESSION_ID_NEW, false /*reverse*/)) return false;
 
     // TODO(b/185167543): we should add additional sessions dynamically
     // instead of all at once.
@@ -256,13 +300,23 @@
     // we've already setup one client
     for (size_t i = 0; i + 1 < numThreadsAvailable; i++) {
         // TODO(b/185167543): shutdown existing connections?
-        if (!setupOneSocketClient(addr, mId.value())) return false;
+        if (!setupOneSocketConnection(addr, mId.value(), false /*reverse*/)) return false;
+    }
+
+    // TODO(b/185167543): we should add additional sessions dynamically
+    // instead of all at once - the other side should be responsible for setting
+    // up additional connections. We need to create at least one (unless 0 are
+    // requested to be set) in order to allow the other side to reliably make
+    // any requests at all.
+
+    for (size_t i = 0; i < mMaxReverseConnections; i++) {
+        if (!setupOneSocketConnection(addr, mId.value(), true /*reverse*/)) return false;
     }
 
     return true;
 }
 
-bool RpcSession::setupOneSocketClient(const RpcSocketAddress& addr, int32_t id) {
+bool RpcSession::setupOneSocketConnection(const RpcSocketAddress& addr, int32_t id, bool reverse) {
     for (size_t tries = 0; tries < 5; tries++) {
         if (tries > 0) usleep(10000);
 
@@ -286,16 +340,47 @@
             return false;
         }
 
-        if (sizeof(id) != TEMP_FAILURE_RETRY(write(serverFd.get(), &id, sizeof(id)))) {
+        RpcConnectionHeader header{
+                .sessionId = id,
+        };
+        if (reverse) header.options |= RPC_CONNECTION_OPTION_REVERSE;
+
+        if (sizeof(header) != TEMP_FAILURE_RETRY(write(serverFd.get(), &header, sizeof(header)))) {
             int savedErrno = errno;
-            ALOGE("Could not write id to socket at %s: %s", addr.toString().c_str(),
+            ALOGE("Could not write connection header to socket at %s: %s", addr.toString().c_str(),
                   strerror(savedErrno));
             return false;
         }
 
         LOG_RPC_DETAIL("Socket at %s client with fd %d", addr.toString().c_str(), serverFd.get());
 
-        return addClientConnection(std::move(serverFd));
+        if (reverse) {
+            std::mutex mutex;
+            std::condition_variable joinCv;
+            std::unique_lock<std::mutex> lock(mutex);
+            std::thread thread;
+            sp<RpcSession> thiz = sp<RpcSession>::fromExisting(this);
+            bool ownershipTransferred = false;
+            thread = std::thread([&]() {
+                std::unique_lock<std::mutex> threadLock(mutex);
+                unique_fd fd = std::move(serverFd);
+                // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
+                sp<RpcSession> session = thiz;
+                session->preJoin(std::move(thread));
+                ownershipTransferred = true;
+                joinCv.notify_one();
+
+                threadLock.unlock();
+                // do not use & vars below
+
+                RpcSession::join(std::move(session), std::move(fd));
+            });
+            joinCv.wait(lock, [&] { return ownershipTransferred; });
+            LOG_ALWAYS_FATAL_IF(!ownershipTransferred);
+            return true;
+        } else {
+            return addClientConnection(std::move(serverFd));
+        }
     }
 
     ALOGE("Ran out of retries to connect to %s", addr.toString().c_str());
@@ -305,8 +390,11 @@
 bool RpcSession::addClientConnection(unique_fd fd) {
     std::lock_guard<std::mutex> _l(mMutex);
 
+    // First client connection and setForServer was never called, so this is a
+    // client session: create its own shutdown trigger and shutdown listener.
     if (mShutdownTrigger == nullptr) {
         mShutdownTrigger = FdTrigger::make();
+        mEventListener = mShutdownListener = sp<WaitForShutdownListener>::make();
         if (mShutdownTrigger == nullptr) return false;
     }
 
@@ -316,14 +404,19 @@
     return true;
 }
 
-void RpcSession::setForServer(const wp<RpcServer>& server, int32_t sessionId,
+void RpcSession::setForServer(const wp<RpcServer>& server, const wp<EventListener>& eventListener,
+                              int32_t sessionId,
                               const std::shared_ptr<FdTrigger>& shutdownTrigger) {
-    LOG_ALWAYS_FATAL_IF(mForServer.unsafe_get() != nullptr);
+    LOG_ALWAYS_FATAL_IF(mForServer != nullptr);
+    LOG_ALWAYS_FATAL_IF(server == nullptr);
+    LOG_ALWAYS_FATAL_IF(mEventListener != nullptr);
+    LOG_ALWAYS_FATAL_IF(eventListener == nullptr);
     LOG_ALWAYS_FATAL_IF(mShutdownTrigger != nullptr);
     LOG_ALWAYS_FATAL_IF(shutdownTrigger == nullptr);
 
     mId = sessionId;
     mForServer = server;
+    mEventListener = eventListener;
     mShutdownTrigger = shutdownTrigger;
 }
 
@@ -343,9 +436,9 @@
         it != mServerConnections.end()) {
         mServerConnections.erase(it);
         if (mServerConnections.size() == 0) {
-            sp<RpcServer> server = mForServer.promote();
-            if (server) {
-                server->onSessionLockedAllServerThreadsEnded(sp<RpcSession>::fromExisting(this));
+            sp<EventListener> listener = mEventListener.promote();
+            if (listener) {
+                listener->onSessionLockedAllServerThreadsEnded(sp<RpcSession>::fromExisting(this));
             }
         }
         return true;
@@ -405,6 +498,8 @@
             break;
         }
 
+        // TODO(b/185167543): this should return an error, rather than crash a
+        // server
         // in regular binder, this would usually be a deadlock :)
         LOG_ALWAYS_FATAL_IF(mSession->mClientConnections.size() == 0,
                             "Session has no client connections. This is required for an RPC server "
diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp
index 2cad2ae..2f6b1b3 100644
--- a/libs/binder/RpcState.cpp
+++ b/libs/binder/RpcState.cpp
@@ -383,6 +383,7 @@
         return status;
 
     if (flags & IBinder::FLAG_ONEWAY) {
+        LOG_RPC_DETAIL("Oneway command, so no longer waiting on %d", fd.get());
         return OK; // do not wait for result
     }
 
diff --git a/libs/binder/RpcState.h b/libs/binder/RpcState.h
index 8a0610e..aacb530 100644
--- a/libs/binder/RpcState.h
+++ b/libs/binder/RpcState.h
@@ -86,7 +86,6 @@
     size_t countBinders();
     void dump();
 
-private:
     /**
      * Called when reading or writing data to a session fails to clean up
      * data associated with the session in order to cleanup binders.
@@ -105,6 +104,7 @@
      */
     void terminate();
 
+private:
     // Alternative to std::vector<uint8_t> that doesn't abort on allocation failure and caps
     // large allocations to avoid being requested from allocating too much data.
     struct CommandData {
diff --git a/libs/binder/RpcWireFormat.h b/libs/binder/RpcWireFormat.h
index c5fa008..649c1ee 100644
--- a/libs/binder/RpcWireFormat.h
+++ b/libs/binder/RpcWireFormat.h
@@ -20,6 +20,18 @@
 #pragma clang diagnostic push
 #pragma clang diagnostic error "-Wpadded"
 
+constexpr int32_t RPC_SESSION_ID_NEW = -1;
+
+enum : uint8_t {
+    RPC_CONNECTION_OPTION_REVERSE = 0x1,
+};
+
+struct RpcConnectionHeader {
+    int32_t sessionId;
+    uint8_t options;
+    uint8_t reserved[3];
+};
+
 enum : uint32_t {
     /**
      * follows is RpcWireTransaction, if flags != oneway, reply w/ RPC_COMMAND_REPLY expected
@@ -51,8 +63,6 @@
     RPC_SPECIAL_TRANSACT_GET_SESSION_ID = 2,
 };
 
-constexpr int32_t RPC_SESSION_ID_NEW = -1;
-
 // serialization is like:
 // |RpcWireHeader|struct desginated by 'command'| (over and over again)
 
diff --git a/libs/binder/include/binder/RpcServer.h b/libs/binder/include/binder/RpcServer.h
index 8ad5821..0082ec3 100644
--- a/libs/binder/include/binder/RpcServer.h
+++ b/libs/binder/include/binder/RpcServer.h
@@ -44,7 +44,7 @@
  *     }
  *     server->join();
  */
-class RpcServer final : public virtual RefBase {
+class RpcServer final : public virtual RefBase, private RpcSession::EventListener {
 public:
     static sp<RpcServer> make();
 
@@ -151,15 +151,13 @@
 
     ~RpcServer();
 
-    // internal use only
-
-    void onSessionLockedAllServerThreadsEnded(const sp<RpcSession>& session);
-    void onSessionServerThreadEnded(const sp<RpcSession>& session);
-
 private:
     friend sp<RpcServer>;
     RpcServer();
 
+    void onSessionLockedAllServerThreadsEnded(const sp<RpcSession>& session) override;
+    void onSessionServerThreadEnded() override;
+
     static void establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd);
     bool setupSocketServer(const RpcSocketAddress& address);
     [[nodiscard]] bool acceptOne();
diff --git a/libs/binder/include/binder/RpcSession.h b/libs/binder/include/binder/RpcSession.h
index eadf0f8..9d314e4 100644
--- a/libs/binder/include/binder/RpcSession.h
+++ b/libs/binder/include/binder/RpcSession.h
@@ -47,6 +47,18 @@
     static sp<RpcSession> make();
 
     /**
+     * Set the maximum number of reverse connections allowed to be made (for
+     * things like callbacks). By default, this is 0. This must be called before
+     * setting up this connection as a client.
+     *
+     * If this is called, 'shutdown' on this session must also be called.
+     * Otherwise, a threadpool will leak.
+     *
+     * TODO(b/185167543): start these dynamically
+     */
+    void setMaxReverseConnections(size_t connections);
+
+    /**
      * This should be called once per thread, matching 'join' in the remote
      * process.
      */
@@ -83,6 +95,16 @@
      */
     status_t getRemoteMaxThreads(size_t* maxThreads);
 
+    /**
+     * Shuts down the service. Only works for client sessions (server-side
+     * sessions currently only support shutting down the entire server).
+     *
+     * Warning: this is currently not active/nice (the server isn't told we're
+     * shutting down). Being nicer to the server could potentially make it
+     * reclaim resources faster.
+     */
+    [[nodiscard]] bool shutdown();
+
     [[nodiscard]] status_t transact(const sp<IBinder>& binder, uint32_t code, const Parcel& data,
                                     Parcel* reply, uint32_t flags);
     [[nodiscard]] status_t sendDecStrong(const RpcAddress& address);
@@ -138,12 +160,29 @@
         base::unique_fd mRead;
     };
 
+    class EventListener : public virtual RefBase {
+    public:
+        virtual void onSessionLockedAllServerThreadsEnded(const sp<RpcSession>& session) = 0;
+        virtual void onSessionServerThreadEnded() = 0;
+    };
+
+    class WaitForShutdownListener : public EventListener {
+    public:
+        void onSessionLockedAllServerThreadsEnded(const sp<RpcSession>& session) override;
+        void onSessionServerThreadEnded() override;
+        void waitForShutdown(std::unique_lock<std::mutex>& lock);
+
+    private:
+        std::condition_variable mCv;
+        bool mShutdown = false;
+    };
+
     status_t readId();
 
     // transfer ownership of thread
     void preJoin(std::thread thread);
     // join on thread passed to preJoin
-    void join(base::unique_fd client);
+    static void join(sp<RpcSession>&& session, base::unique_fd client);
 
     struct RpcConnection : public RefBase {
         base::unique_fd fd;
@@ -153,13 +192,15 @@
         std::optional<pid_t> exclusiveTid;
     };
 
-    bool setupSocketClient(const RpcSocketAddress& address);
-    bool setupOneSocketClient(const RpcSocketAddress& address, int32_t sessionId);
-    bool addClientConnection(base::unique_fd fd);
-    void setForServer(const wp<RpcServer>& server, int32_t sessionId,
+    [[nodiscard]] bool setupSocketClient(const RpcSocketAddress& address);
+    [[nodiscard]] bool setupOneSocketConnection(const RpcSocketAddress& address, int32_t sessionId,
+                                                bool reverse);
+    [[nodiscard]] bool addClientConnection(base::unique_fd fd);
+    void setForServer(const wp<RpcServer>& server,
+                      const wp<RpcSession::EventListener>& eventListener, int32_t sessionId,
                       const std::shared_ptr<FdTrigger>& shutdownTrigger);
     sp<RpcConnection> assignServerToThisThread(base::unique_fd fd);
-    bool removeServerConnection(const sp<RpcConnection>& connection);
+    [[nodiscard]] bool removeServerConnection(const sp<RpcConnection>& connection);
 
     enum class ConnectionUse {
         CLIENT,
@@ -204,6 +245,8 @@
     // serve calls to the server at all times (e.g. if it hosts a callback)
 
     wp<RpcServer> mForServer; // maybe null, for client sessions
+    sp<WaitForShutdownListener> mShutdownListener; // used for client sessions
+    wp<EventListener> mEventListener; // mForServer if server, mShutdownListener if client
 
     // TODO(b/183988761): this shouldn't be guessable
     std::optional<int32_t> mId;
@@ -214,6 +257,8 @@
 
     std::mutex mMutex; // for all below
 
+    size_t mMaxReverseConnections = 0;
+
     std::condition_variable mAvailableConnectionCv; // for mWaitingThreads
     size_t mWaitingThreads = 0;
     // hint index into clients, ++ when sending an async transaction
@@ -221,8 +266,6 @@
     std::vector<sp<RpcConnection>> mClientConnections;
     std::vector<sp<RpcConnection>> mServerConnections;
 
-    // TODO(b/185167543): use for reverse sessions (allow client to also
-    // serve calls on a session).
     // TODO(b/185167543): allow sharing between different sessions in a
     // process? (or combine with mServerConnections)
     std::map<std::thread::id, std::thread> mThreads;
diff --git a/libs/binder/tests/Android.bp b/libs/binder/tests/Android.bp
index 9cf433d..c7c899f 100644
--- a/libs/binder/tests/Android.bp
+++ b/libs/binder/tests/Android.bp
@@ -118,6 +118,7 @@
     host_supported: true,
     unstable: true,
     srcs: [
+        "IBinderRpcCallback.aidl",
         "IBinderRpcSession.aidl",
         "IBinderRpcTest.aidl",
     ],
diff --git a/libs/binder/tests/IBinderRpcCallback.aidl b/libs/binder/tests/IBinderRpcCallback.aidl
new file mode 100644
index 0000000..0336961
--- /dev/null
+++ b/libs/binder/tests/IBinderRpcCallback.aidl
@@ -0,0 +1,20 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+interface IBinderRpcCallback {
+    void sendCallback(@utf8InCpp String str);
+    oneway void sendOnewayCallback(@utf8InCpp String str);
+}
diff --git a/libs/binder/tests/IBinderRpcTest.aidl b/libs/binder/tests/IBinderRpcTest.aidl
index 646bcc6..b0c8b2d 100644
--- a/libs/binder/tests/IBinderRpcTest.aidl
+++ b/libs/binder/tests/IBinderRpcTest.aidl
@@ -54,6 +54,8 @@
     void sleepMs(int ms);
     oneway void sleepMsAsync(int ms);
 
+    void doCallback(IBinderRpcCallback callback, boolean isOneway, boolean delayed, @utf8InCpp String value);
+
     void die(boolean cleanup);
     void scheduleShutdown();
 
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index efc70e6..80708df 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include <BnBinderRpcCallback.h>
 #include <BnBinderRpcSession.h>
 #include <BnBinderRpcTest.h>
 #include <aidl/IBinderRpcTest.h>
@@ -34,6 +35,7 @@
 #include <cstdlib>
 #include <iostream>
 #include <thread>
+#include <type_traits>
 
 #include <sys/prctl.h>
 #include <unistd.h>
@@ -89,6 +91,22 @@
 };
 std::atomic<int32_t> MyBinderRpcSession::gNum;
 
+class MyBinderRpcCallback : public BnBinderRpcCallback {
+    // Records every received value and wakes a waiter on mCv (oneway shares this path).
+    Status sendCallback(const std::string& value) override {
+        std::unique_lock _l(mMutex);
+        mValues.push_back(value);
+        _l.unlock();
+        mCv.notify_one();
+        return Status::ok();
+    }
+    Status sendOnewayCallback(const std::string& value) override { return sendCallback(value); }
+public:
+    std::mutex mMutex;
+    std::condition_variable mCv;
+    std::vector<std::string> mValues;
+};
+
 class MyBinderRpcTest : public BnBinderRpcTest {
 public:
     wp<RpcServer> server;
@@ -187,6 +205,27 @@
         return sleepMs(ms);
     }
 
+    Status doCallback(const sp<IBinderRpcCallback>& callback, bool oneway, bool delayed,
+                      const std::string& value) override {
+        if (callback == nullptr) {
+            return Status::fromExceptionCode(Status::EX_NULL_POINTER);
+        }
+
+        if (delayed) {
+            // Hold a strong ref: the detached thread may run after this call returns.
+            std::thread([self = sp<MyBinderRpcTest>::fromExisting(this), callback, oneway, value]() {
+                ALOGE("Executing delayed callback: '%s'", value.c_str());
+                (void)self->doCallback(callback, oneway, false, value);
+            }).detach();
+            return Status::ok();
+        }
+        if (oneway) {
+            return callback->sendOnewayCallback(value);
+        }
+
+        return callback->sendCallback(value);
+    }
+
     Status die(bool cleanup) override {
         if (cleanup) {
             exit(1);
@@ -308,6 +347,9 @@
 
     BinderRpcTestProcessSession(BinderRpcTestProcessSession&&) = default;
     ~BinderRpcTestProcessSession() {
+        EXPECT_NE(nullptr, rootIface);
+        if (rootIface == nullptr) return;
+
         if (!expectAlreadyShutdown) {
             std::vector<int32_t> remoteCounts;
             // calling over any sessions counts across all sessions
@@ -348,7 +390,7 @@
     // This creates a new process serving an interface on a certain number of
     // threads.
     ProcessSession createRpcTestSocketServerProcess(
-            size_t numThreads, size_t numSessions,
+            size_t numThreads, size_t numSessions, size_t numReverseConnections,
             const std::function<void(const sp<RpcServer>&)>& configure) {
         CHECK_GE(numSessions, 1) << "Must have at least one session to a server";
 
@@ -404,6 +446,8 @@
 
         for (size_t i = 0; i < numSessions; i++) {
             sp<RpcSession> session = RpcSession::make();
+            session->setMaxReverseConnections(numReverseConnections);
+
             switch (socketType) {
                 case SocketType::UNIX:
                     if (session->setupUnixDomainClient(addr.c_str())) goto success;
@@ -425,9 +469,11 @@
     }
 
     BinderRpcTestProcessSession createRpcTestSocketServerProcess(size_t numThreads,
-                                                                 size_t numSessions = 1) {
+                                                                 size_t numSessions = 1,
+                                                                 size_t numReverseConnections = 0) {
         BinderRpcTestProcessSession ret{
                 .proc = createRpcTestSocketServerProcess(numThreads, numSessions,
+                                                         numReverseConnections,
                                                          [&](const sp<RpcServer>& server) {
                                                              sp<MyBinderRpcTest> service =
                                                                      new MyBinderRpcTest;
@@ -895,6 +941,38 @@
     for (auto& t : threads) t.join();
 }
 
+TEST_P(BinderRpc, Callbacks) {
+    const static std::string kTestString = "good afternoon!";
+
+    for (bool oneway : {true, false}) {
+        for (bool delayed : {true, false}) {
+            auto proc = createRpcTestSocketServerProcess(1, 1, 1);
+            auto cb = sp<MyBinderRpcCallback>::make();
+
+            EXPECT_OK(proc.rootIface->doCallback(cb, oneway, delayed, kTestString));
+
+            using std::literals::chrono_literals::operator""s;
+            std::unique_lock<std::mutex> _l(cb->mMutex);
+            cb->mCv.wait_for(_l, 1s, [&] { return !cb->mValues.empty(); });
+
+            EXPECT_EQ(cb->mValues.size(), 1) << "oneway: " << oneway << " delayed: " << delayed;
+            if (cb->mValues.empty()) continue;
+            EXPECT_EQ(cb->mValues.at(0), kTestString)
+                    << "oneway: " << oneway << " delayed: " << delayed;
+
+            // since we are severing the connection, we need to go ahead and
+            // tell the server to shutdown and exit so that waitpid won't hang
+            EXPECT_OK(proc.rootIface->scheduleShutdown());
+
+            // since this session has a reverse connection w/ a threadpool, we
+            // need to manually shut it down
+            EXPECT_TRUE(proc.proc.sessions.at(0).session->shutdown());
+
+            proc.expectAlreadyShutdown = true;
+        }
+    }
+}
+
 TEST_P(BinderRpc, Die) {
     for (bool doDeathCleanup : {true, false}) {
         auto proc = createRpcTestSocketServerProcess(1);