Merge "rpc_binder: Specify CID for vsock RpcServer" am: 80170d5f74

Original change: https://android-review.googlesource.com/c/platform/frameworks/native/+/2309126

Change-Id: I36cf174c23c7956a61f01ffec543209d8aa3264e
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp
index 0820cd1..fedc1d9 100644
--- a/libs/binder/RpcServer.cpp
+++ b/libs/binder/RpcServer.cpp
@@ -70,11 +70,8 @@
     return setupSocketServer(UnixSocketAddress(path));
 }
 
-status_t RpcServer::setupVsockServer(unsigned int port) {
-    // realizing value w/ this type at compile time to avoid ubsan abort
-    constexpr unsigned int kAnyCid = VMADDR_CID_ANY;
-
-    return setupSocketServer(VsockSocketAddress(kAnyCid, port));
+status_t RpcServer::setupVsockServer(unsigned int bindCid, unsigned int port) {
+    return setupSocketServer(VsockSocketAddress(bindCid, port));
 }
 
 status_t RpcServer::setupInetServer(const char* address, unsigned int port,
@@ -157,6 +154,12 @@
     mRootObjectFactory = std::move(makeObject);
 }
 
+void RpcServer::setConnectionFilter(std::function<bool(const void*, size_t)>&& filter) {
+    RpcMutexLockGuard _l(mLock);
+    LOG_ALWAYS_FATAL_IF(mShutdownTrigger != nullptr, "Already joined");
+    mConnectionFilter = std::move(filter);
+}
+
 sp<IBinder> RpcServer::getRootObject() {
     RpcMutexLockGuard _l(mLock);
     bool hasWeak = mRootObjectWeak.unsafe_get();
@@ -242,13 +245,19 @@
         if (mAcceptFn(*this, &clientSocket) != OK) {
             continue;
         }
+        // Log before getpeername() so the accepted fd is traced even if it is dropped below.
+        LOG_RPC_DETAIL("accept on fd %d yields fd %d", mServer.fd.get(), clientSocket.fd.get());
+
         if (getpeername(clientSocket.fd.get(), reinterpret_cast<sockaddr*>(addr.data()),
                         &addrLen)) {
             ALOGE("Could not getpeername socket: %s", strerror(errno));
             continue;
         }
 
-        LOG_RPC_DETAIL("accept on fd %d yields fd %d", mServer.fd.get(), clientSocket.fd.get());
+        if (mConnectionFilter != nullptr && !mConnectionFilter(addr.data(), addrLen)) {
+            ALOGE("Dropped client connection fd %d", clientSocket.fd.get());
+            continue;
+        }
 
         {
             RpcMutexLockGuard _l(mLock);
diff --git a/libs/binder/include/binder/RpcServer.h b/libs/binder/include/binder/RpcServer.h
index 4ad0a47..25193a3 100644
--- a/libs/binder/include/binder/RpcServer.h
+++ b/libs/binder/include/binder/RpcServer.h
@@ -81,9 +81,9 @@
     [[nodiscard]] status_t setupRawSocketServer(base::unique_fd socket_fd);
 
     /**
-     * Creates an RPC server at the current port.
+     * Creates an RPC server binding to the given CID at the given port.
      */
-    [[nodiscard]] status_t setupVsockServer(unsigned int port);
+    [[nodiscard]] status_t setupVsockServer(unsigned int bindCid, unsigned int port);
 
     /**
      * Creates an RPC server at the current port using IPv4.
@@ -171,6 +171,16 @@
     sp<IBinder> getRootObject();
 
     /**
+     * Set optional filter of incoming connections based on the peer's address.
+     *
+     * Takes one argument: a callable that is invoked on each accept()-ed
+     * connection and returns false if the connection should be dropped.
+     * Must be called before join(). The callable receives the peer's sockaddr
+     * and its length, as described for setPerSessionRootObject().
+     */
+    void setConnectionFilter(std::function<bool(const void*, size_t)>&& filter);
+
+    /**
      * See RpcTransportCtx::getCertificate
      */
     std::vector<uint8_t> getCertificate(RpcCertificateFormat);
@@ -253,6 +263,7 @@
     sp<IBinder> mRootObject;
     wp<IBinder> mRootObjectWeak;
     std::function<sp<IBinder>(const void*, size_t)> mRootObjectFactory;
+    std::function<bool(const void*, size_t)> mConnectionFilter;
     std::map<std::vector<uint8_t>, sp<RpcSession>> mSessions;
     std::unique_ptr<FdTrigger> mShutdownTrigger;
     RpcConditionVariable mShutdownCv;
diff --git a/libs/binder/include_rpc_unstable/binder_rpc_unstable.hpp b/libs/binder/include_rpc_unstable/binder_rpc_unstable.hpp
index f08bde8..3ec049e 100644
--- a/libs/binder/include_rpc_unstable/binder_rpc_unstable.hpp
+++ b/libs/binder/include_rpc_unstable/binder_rpc_unstable.hpp
@@ -25,9 +25,13 @@
 struct ARpcServer;
 
 // Starts an RPC server on a given port and a given root IBinder object.
+// The server only accepts connections from the given CID, with two special values:
+//   VMADDR_CID_ANY   accepts connections from any client (no CID filtering);
+//   VMADDR_CID_LOCAL binds only to the local vsock interface (no CID filtering).
 // Returns an opaque handle to the running server instance, or null if the server
 // could not be started.
-[[nodiscard]] ARpcServer* ARpcServer_newVsock(AIBinder* service, unsigned int port);
+[[nodiscard]] ARpcServer* ARpcServer_newVsock(AIBinder* service, unsigned int cid,
+                                              unsigned int port);
 
 // Starts a Unix domain RPC server with a given init-managed Unix domain `name`
 // and a given root IBinder object.
diff --git a/libs/binder/libbinder_rpc_unstable.cpp b/libs/binder/libbinder_rpc_unstable.cpp
index f55c779..88f8c94 100644
--- a/libs/binder/libbinder_rpc_unstable.cpp
+++ b/libs/binder/libbinder_rpc_unstable.cpp
@@ -51,21 +51,26 @@
     ref->decStrong(ref);
 }
 
+static unsigned int cidFromStructAddr(const void* addr, size_t addrlen) {
+    LOG_ALWAYS_FATAL_IF(addrlen < sizeof(sockaddr_vm), "sockaddr is truncated");
+    const sockaddr_vm* vaddr = reinterpret_cast<const sockaddr_vm*>(addr);
+    LOG_ALWAYS_FATAL_IF(vaddr->svm_family != AF_VSOCK, "address is not a vsock");
+    return vaddr->svm_cid;
+}
+
 extern "C" {
 
 bool RunVsockRpcServerWithFactory(AIBinder* (*factory)(unsigned int cid, void* context),
                                   void* factoryContext, unsigned int port) {
     auto server = RpcServer::make();
-    if (status_t status = server->setupVsockServer(port); status != OK) {
+    if (status_t status = server->setupVsockServer(VMADDR_CID_ANY, port); status != OK) {
         LOG(ERROR) << "Failed to set up vsock server with port " << port
                    << " error: " << statusToString(status).c_str();
         return false;
     }
     server->setPerSessionRootObject([=](const void* addr, size_t addrlen) {
-        LOG_ALWAYS_FATAL_IF(addrlen < sizeof(sockaddr_vm), "sockaddr is truncated");
-        const sockaddr_vm* vaddr = reinterpret_cast<const sockaddr_vm*>(addr);
-        LOG_ALWAYS_FATAL_IF(vaddr->svm_family != AF_VSOCK, "address is not a vsock");
-        return AIBinder_toPlatformBinder(factory(vaddr->svm_cid, factoryContext));
+        unsigned int cid = cidFromStructAddr(addr, addrlen);
+        return AIBinder_toPlatformBinder(factory(cid, factoryContext));
     });
 
     server->join();
@@ -75,13 +80,30 @@
     return true;
 }
 
-ARpcServer* ARpcServer_newVsock(AIBinder* service, unsigned int port) {
+ARpcServer* ARpcServer_newVsock(AIBinder* service, unsigned int cid, unsigned int port) {
     auto server = RpcServer::make();
-    if (status_t status = server->setupVsockServer(port); status != OK) {
+
+    unsigned int bindCid = VMADDR_CID_ANY; // bind to the remote interface
+    if (cid == VMADDR_CID_LOCAL) {
+        bindCid = VMADDR_CID_LOCAL; // bind to the local interface
+        cid = VMADDR_CID_ANY;       // local-only bind already limits peers; skip the filter
+    }
+
+    if (status_t status = server->setupVsockServer(bindCid, port); status != OK) {
         LOG(ERROR) << "Failed to set up vsock server with port " << port
                    << " error: " << statusToString(status).c_str();
         return nullptr;
     }
+    if (cid != VMADDR_CID_ANY) {
+        server->setConnectionFilter([=](const void* addr, size_t addrlen) {
+            unsigned int remoteCid = cidFromStructAddr(addr, addrlen);
+            if (cid != remoteCid) {
+                LOG(ERROR) << "Rejected vsock connection from CID " << remoteCid;
+                return false;
+            }
+            return true;
+        });
+    }
     server->setRootObject(AIBinder_toPlatformBinder(service));
     return createRpcServerHandle(server);
 }
diff --git a/libs/binder/rust/rpcbinder/src/server.rs b/libs/binder/rust/rpcbinder/src/server.rs
index 42f5567..d5f1219 100644
--- a/libs/binder/rust/rpcbinder/src/server.rs
+++ b/libs/binder/rust/rpcbinder/src/server.rs
@@ -41,14 +41,19 @@
 
 impl RpcServer {
     /// Creates a binder RPC server, serving the supplied binder service implementation on the given
-    /// vsock port.
-    pub fn new_vsock(mut service: SpIBinder, port: u32) -> Result<RpcServer, Error> {
+    /// vsock port. Only connections from the given CID are accepted.
+    ///
+    /// Set `cid` to `libc::VMADDR_CID_ANY` to accept connections from any client.
+    /// Set `cid` to `libc::VMADDR_CID_LOCAL` to only bind to the local vsock interface.
+    pub fn new_vsock(mut service: SpIBinder, cid: u32, port: u32) -> Result<RpcServer, Error> {
         let service = service.as_native_mut();
 
         // SAFETY: Service ownership is transferring to the server and won't be valid afterward.
         // Plus the binder objects are threadsafe.
         unsafe {
-            Self::checked_from_ptr(binder_rpc_unstable_bindgen::ARpcServer_newVsock(service, port))
+            Self::checked_from_ptr(binder_rpc_unstable_bindgen::ARpcServer_newVsock(
+                service, cid, port,
+            ))
         }
     }
 
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index 02aa45f..739c217 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -1342,7 +1342,7 @@
                 } break;
                 case SocketType::VSOCK: {
                     auto port = allocateVsockPort();
-                    auto status = rpcServer->setupVsockServer(port);
+                    auto status = rpcServer->setupVsockServer(VMADDR_CID_LOCAL, port);
                     if (status != OK) {
                         return AssertionFailure() << "setupVsockServer: " << statusToString(status);
                     }
diff --git a/libs/binder/tests/binderRpcTestService.cpp b/libs/binder/tests/binderRpcTestService.cpp
index 995e761..cc9726b 100644
--- a/libs/binder/tests/binderRpcTestService.cpp
+++ b/libs/binder/tests/binderRpcTestService.cpp
@@ -58,7 +58,7 @@
             CHECK_EQ(OK, server->setupRawSocketServer(std::move(socketFd)));
             break;
         case SocketType::VSOCK:
-            CHECK_EQ(OK, server->setupVsockServer(serverConfig.vsockPort));
+            CHECK_EQ(OK, server->setupVsockServer(VMADDR_CID_LOCAL, serverConfig.vsockPort));
             break;
         case SocketType::INET: {
             CHECK_EQ(OK, server->setupInetServer(kLocalInetAddress, 0, &outPort));