libbinder: allow externally created connections

We want to allow:
- VM B process S has server using RpcServer normally
- VM A process V calls socket() and connect()
- VM A process V passes connected socket to VM A process C
- VM A process C talks to S over socket directly

Where:
- V = virtmanager process
- S = server process (e.g. compOS)
- C = client process (e.g. installd)

This way, within a VM, only one process needs permission to connect to
services, and through that process we can control access to VM services.

Bug: 193801719
Test: binderRpcTest
Change-Id: I5af60fd24354bce21d79676231172b6357b0c1a3
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index 6a22913..572595d 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -126,6 +126,17 @@
     return false;
 }
 
+bool RpcSession::setupPreconnectedClient(unique_fd fd, std::function<unique_fd()>&& request) {
+    return setupClient([&](const RpcAddress& sessionId, bool incoming) {
+        // 'fd' serves the first connection (a std::move'd from unique_fd becomes -1, i.e.
+        // !ok()); every later connection comes from 'request'. An empty 'request' makes setup
+        // fail here instead of invoking an empty std::function (std::bad_function_call).
+        if (!fd.ok() && request) fd = request();
+        if (!fd.ok()) return false;
+        return initAndAddConnection(std::move(fd), sessionId, incoming);
+    });
+}
+
 bool RpcSession::addNullDebuggingClient() {
     // Note: only works on raw sockets.
     unique_fd serverFd(TEMP_FAILURE_RETRY(open("/dev/null", O_WRONLY | O_CLOEXEC)));
@@ -464,7 +475,8 @@
     return server;
 }
 
-bool RpcSession::setupSocketClient(const RpcSocketAddress& addr) {
+bool RpcSession::setupClient(
+        const std::function<bool(const RpcAddress& sessionId, bool incoming)>& connectAndInit) {
     {
         std::lock_guard<std::mutex> _l(mMutex);
         LOG_ALWAYS_FATAL_IF(mOutgoingConnections.size() != 0,
@@ -472,7 +484,7 @@
                             mOutgoingConnections.size());
     }
 
-    if (!setupOneSocketConnection(addr, RpcAddress::zero(), false /*incoming*/)) return false;
+    if (!connectAndInit(RpcAddress::zero(), false /*incoming*/)) return false;
 
     {
         ExclusiveConnection connection;
@@ -491,37 +503,42 @@
     // TODO(b/186470974): first risk of blocking
     size_t numThreadsAvailable;
     if (status_t status = getRemoteMaxThreads(&numThreadsAvailable); status != OK) {
-        ALOGE("Could not get max threads after initial session to %s: %s", addr.toString().c_str(),
+        ALOGE("Could not get max threads after initial session setup: %s",
               statusToString(status).c_str());
         return false;
     }
 
     if (status_t status = readId(); status != OK) {
-        ALOGE("Could not get session id after initial session to %s; %s", addr.toString().c_str(),
+        ALOGE("Could not get session id after initial session setup: %s",
               statusToString(status).c_str());
         return false;
     }
 
-    // we've already setup one client
-    for (size_t i = 0; i + 1 < numThreadsAvailable; i++) {
-        // TODO(b/189955605): shutdown existing connections?
-        if (!setupOneSocketConnection(addr, mId.value(), false /*incoming*/)) return false;
-    }
-
     // TODO(b/189955605): we should add additional sessions dynamically
     // instead of all at once - the other side should be responsible for setting
     // up additional connections. We need to create at least one (unless 0 are
     // requested to be set) in order to allow the other side to reliably make
     // any requests at all.
 
+    // we've already setup one client
+    for (size_t i = 0; i + 1 < numThreadsAvailable; i++) {
+        if (!connectAndInit(mId.value(), false /*incoming*/)) return false;
+    }
+
     for (size_t i = 0; i < mMaxThreads; i++) {
-        if (!setupOneSocketConnection(addr, mId.value(), true /*incoming*/)) return false;
+        if (!connectAndInit(mId.value(), true /*incoming*/)) return false;
     }
 
     return true;
 }
 
-bool RpcSession::setupOneSocketConnection(const RpcSocketAddress& addr, const RpcAddress& id,
+bool RpcSession::setupSocketClient(const RpcSocketAddress& addr) {
+    return setupClient([this, &addr](const RpcAddress& id, bool incoming) {
+        return setupOneSocketConnection(addr, id, incoming); // one new connection per call
+    });
+}
+
+bool RpcSession::setupOneSocketConnection(const RpcSocketAddress& addr, const RpcAddress& sessionId,
                                           bool incoming) {
     for (size_t tries = 0; tries < 5; tries++) {
         if (tries > 0) usleep(10000);
@@ -547,54 +564,57 @@
         }
         LOG_RPC_DETAIL("Socket at %s client with fd %d", addr.toString().c_str(), serverFd.get());
 
-        auto ctx = mRpcTransportCtxFactory->newClientCtx();
-        if (ctx == nullptr) {
-            ALOGE("Unable to create client RpcTransportCtx with %s sockets",
-                  mRpcTransportCtxFactory->toCString());
-            return false;
-        }
-        auto server = ctx->newTransport(std::move(serverFd));
-        if (server == nullptr) {
-            ALOGE("Unable to set up RpcTransport for %s", addr.toString().c_str());
-            return false;
-        }
-
-        LOG_RPC_DETAIL("Socket at %s client with RpcTransport %p", addr.toString().c_str(),
-                       server.get());
-
-        RpcConnectionHeader header{
-                .version = mProtocolVersion.value_or(RPC_WIRE_PROTOCOL_VERSION),
-                .options = 0,
-        };
-        memcpy(&header.sessionId, &id.viewRawEmbedded(), sizeof(RpcWireAddress));
-
-        if (incoming) header.options |= RPC_CONNECTION_OPTION_INCOMING;
-
-        auto sentHeader = server->send(&header, sizeof(header));
-        if (!sentHeader.ok()) {
-            ALOGE("Could not write connection header to socket at %s: %s", addr.toString().c_str(),
-                  sentHeader.error().message().c_str());
-            return false;
-        }
-        if (*sentHeader != sizeof(header)) {
-            ALOGE("Could not write connection header to socket at %s: sent %zd bytes, expected %zd",
-                  addr.toString().c_str(), *sentHeader, sizeof(header));
-            return false;
-        }
-
-        LOG_RPC_DETAIL("Socket at %s client: header sent", addr.toString().c_str());
-
-        if (incoming) {
-            return addIncomingConnection(std::move(server));
-        } else {
-            return addOutgoingConnection(std::move(server), true);
-        }
+        return initAndAddConnection(std::move(serverFd), sessionId, incoming);
     }
 
     ALOGE("Ran out of retries to connect to %s", addr.toString().c_str());
     return false;
 }
 
+// Wraps a connected 'fd' in a client RpcTransport and sends the RpcConnectionHeader (with
+// 'sessionId', and RPC_CONNECTION_OPTION_INCOMING if 'incoming'); returns false on failure,
+// else the result of registering it via addIncomingConnection/addOutgoingConnection.
+bool RpcSession::initAndAddConnection(unique_fd fd, const RpcAddress& sessionId, bool incoming) {
+    auto ctx = mRpcTransportCtxFactory->newClientCtx();
+    if (ctx == nullptr) {
+        ALOGE("Unable to create client RpcTransportCtx with %s sockets",
+              mRpcTransportCtxFactory->toCString());
+        return false;
+    }
+    auto server = ctx->newTransport(std::move(fd));
+    if (server == nullptr) {
+        ALOGE("Unable to set up RpcTransport in %s context", mRpcTransportCtxFactory->toCString());
+        return false;
+    }
+
+    LOG_RPC_DETAIL("Client connection with RpcTransport %p", server.get());
+
+    RpcConnectionHeader header{
+            .version = mProtocolVersion.value_or(RPC_WIRE_PROTOCOL_VERSION),
+            .options = 0,
+    };
+    memcpy(&header.sessionId, &sessionId.viewRawEmbedded(), sizeof(RpcWireAddress));
+
+    if (incoming) header.options |= RPC_CONNECTION_OPTION_INCOMING;
+
+    auto sentHeader = server->send(&header, sizeof(header));
+    if (!sentHeader.ok()) {
+        ALOGE("Could not write connection header to socket: %s",
+              sentHeader.error().message().c_str());
+        return false;
+    }
+    if (*sentHeader != sizeof(header)) {
+        ALOGE("Could not write connection header to socket: sent %zd bytes, expected %zd",
+              *sentHeader, sizeof(header));
+        return false;
+    }
+
+    LOG_RPC_DETAIL("Client connection: header sent");
+
+    if (incoming) return addIncomingConnection(std::move(server));
+    return addOutgoingConnection(std::move(server), true /*init*/);
+}
+
 bool RpcSession::addIncomingConnection(std::unique_ptr<RpcTransport> rpcTransport) {
     std::mutex mutex;
     std::condition_variable joinCv;
diff --git a/libs/binder/include/binder/RpcSession.h b/libs/binder/include/binder/RpcSession.h
index e3d6bba..0b46606 100644
--- a/libs/binder/include/binder/RpcSession.h
+++ b/libs/binder/include/binder/RpcSession.h
@@ -91,6 +91,19 @@
     [[nodiscard]] bool setupInetClient(const char* addr, unsigned int port);
 
     /**
+     * Starts talking to an RPC server which has already been connected to. This
+     * is expected to be used when another process has permission to connect to
+     * a binder RPC service, but this process only has permission to talk to
+     * that service.
+     *
+     * 'fd' (unless -1) is used for the first connection; 'request' is called to
+     * get a connected fd for every other connection the session sets up.
+     * For future compatibility, 'request' should not reference any stack data.
+     */
+    [[nodiscard]] bool setupPreconnectedClient(base::unique_fd fd,
+                                               std::function<base::unique_fd()>&& request);
+
+    /**
      * For debugging!
      *
      * Sets up an empty connection. All queries to this connection which require a
@@ -240,9 +253,13 @@
     // join on thread passed to preJoinThreadOwnership
     static void join(sp<RpcSession>&& session, PreJoinSetupResult&& result);
 
+    [[nodiscard]] bool setupClient(
+            const std::function<bool(const RpcAddress& sessionId, bool incoming)>& connectAndInit);
     [[nodiscard]] bool setupSocketClient(const RpcSocketAddress& address);
     [[nodiscard]] bool setupOneSocketConnection(const RpcSocketAddress& address,
-                                                const RpcAddress& sessionId, bool server);
+                                                const RpcAddress& sessionId, bool incoming);
+    [[nodiscard]] bool initAndAddConnection(base::unique_fd fd, const RpcAddress& sessionId,
+                                            bool incoming);
     [[nodiscard]] bool addIncomingConnection(std::unique_ptr<RpcTransport> rpcTransport);
     [[nodiscard]] bool addOutgoingConnection(std::unique_ptr<RpcTransport> rpcTransport, bool init);
     [[nodiscard]] bool setForServer(const wp<RpcServer>& server,
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index 386183d..6c56a4d 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -42,6 +42,7 @@
 #include <sys/prctl.h>
 #include <unistd.h>
 
+#include "../RpcSocketAddress.h" // for testing preconnected clients
 #include "../RpcState.h"   // for debugging
 #include "../vm_sockets.h" // for VMADDR_*
 
@@ -408,12 +409,15 @@
 };
 
 enum class SocketType {
+    PRECONNECTED, // UNIX server; client uses setupPreconnectedClient + connectToUds
     UNIX,
     VSOCK,
     INET,
 };
 static inline std::string PrintToString(SocketType socketType) {
     switch (socketType) {
+        case SocketType::PRECONNECTED:
+            return "preconnected_uds";
         case SocketType::UNIX:
             return "unix_domain_socket";
         case SocketType::VSOCK:
@@ -426,6 +430,20 @@
     }
 }
 
+// Test helper: opens a socket and connect()s it to the UDS at 'addrStr'; aborts on failure.
+static base::unique_fd connectToUds(const char* addrStr) {
+    UnixSocketAddress addr(addrStr);
+    base::unique_fd serverFd(
+            TEMP_FAILURE_RETRY(socket(addr.addr()->sa_family, SOCK_STREAM | SOCK_CLOEXEC, 0)));
+    int savedErrno = errno;
+    CHECK(serverFd.ok()) << "Could not create socket " << addrStr << ": " << strerror(savedErrno);
+    if (0 != TEMP_FAILURE_RETRY(connect(serverFd.get(), addr.addr(), addr.addrSize()))) {
+        savedErrno = errno; // reuse the outer variable instead of shadowing it
+        LOG(FATAL) << "Could not connect to socket " << addrStr << ": " << strerror(savedErrno);
+    }
+    return serverFd;
+}
+
 class BinderRpc : public ::testing::TestWithParam<std::tuple<SocketType, RpcSecurity>> {
 public:
     struct Options {
@@ -462,6 +480,8 @@
                     unsigned int outPort = 0;
 
                     switch (socketType) {
+                        case SocketType::PRECONNECTED:
+                            [[fallthrough]];
                         case SocketType::UNIX:
                             CHECK(server->setupUnixDomainServer(addr.c_str())) << addr;
                             break;
@@ -500,6 +520,12 @@
             session->setMaxThreads(options.numIncomingConnections);
 
             switch (socketType) {
+                case SocketType::PRECONNECTED:
+                    if (session->setupPreconnectedClient({}, [=]() {
+                            return connectToUds(addr.c_str());
+                        }))
+                        goto success;
+                    break;
                 case SocketType::UNIX:
                     if (session->setupUnixDomainClient(addr.c_str())) goto success;
                     break;
@@ -1175,7 +1201,7 @@
 }
 
 static std::vector<SocketType> testSocketTypes() {
-    std::vector<SocketType> ret = {SocketType::UNIX, SocketType::INET};
+    std::vector<SocketType> ret = {SocketType::PRECONNECTED, SocketType::UNIX, SocketType::INET};
 
     static bool hasVsockLoopback = testSupportVsockLoopback();