binder: Add FD support to RPC Binder

Bug: 185909244
Test: TH
Change-Id: Ic4fc1b1edfe9d69984e785553cd1aaca97a07da3
diff --git a/libs/binder/RpcTransportRaw.cpp b/libs/binder/RpcTransportRaw.cpp
index f9b73fc..d9059e9 100644
--- a/libs/binder/RpcTransportRaw.cpp
+++ b/libs/binder/RpcTransportRaw.cpp
@@ -18,6 +18,7 @@
 #include <log/log.h>
 
 #include <poll.h>
+#include <stddef.h>
 
 #include <binder/RpcTransportRaw.h>
 
@@ -28,6 +29,9 @@
 
 namespace {
 
+// Max FDs in one SCM_RIGHTS message: Linux caps unix sockets at 253 (SCM_MAX_FD).
+constexpr size_t kMaxFdsPerMsg = 253;
+
 // RpcTransport with TLS disabled.
 class RpcTransportRaw : public RpcTransport {
 public:
@@ -85,15 +89,7 @@
 
         bool havePolled = false;
         while (true) {
-            msghdr msg{
-                    .msg_iov = iovs,
-                    // posix uses int, glibc uses size_t.  niovs is a
-                    // non-negative int and can be cast to either.
-                    .msg_iovlen = static_cast<decltype(msg.msg_iovlen)>(niovs),
-            };
-            ssize_t processSize =
-                    TEMP_FAILURE_RETRY(sendOrReceiveFun(mSocket.get(), &msg, MSG_NOSIGNAL));
-
+            ssize_t processSize = sendOrReceiveFun(iovs, niovs);  // -1 and errno on failure
             if (processSize < 0) {
                 int savedErrno = errno;
 
@@ -145,20 +141,152 @@
 
     status_t interruptableWriteFully(
             FdTrigger* fdTrigger, iovec* iovs, int niovs,
-            const std::optional<android::base::function_ref<status_t()>>& altPoll) override {
-        return interruptableReadOrWrite(fdTrigger, iovs, niovs, sendmsg, "sendmsg", POLLOUT,
-                                        altPoll);
+            const std::optional<android::base::function_ref<status_t()>>& altPoll,
+            const std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds)
+            override {
+        // FDs are attached only to the first sendmsg() that transfers data; if that write is
+        // partial, the retries for the remaining bytes must not send the FDs a second time.
+        bool sentFds = false;
+        auto send = [&](iovec* iovs, int niovs) -> ssize_t {
+            if (ancillaryFds != nullptr && !ancillaryFds->empty() && !sentFds) {
+                if (ancillaryFds->size() > kMaxFdsPerMsg) {
+                    // This shouldn't happen because we check the FD count in RpcState.
+                    ALOGE("Saw too many file descriptors in RpcTransportRaw: %zu (max is %zu). "
+                          "Aborting session.",
+                          ancillaryFds->size(), kMaxFdsPerMsg);
+                    errno = EINVAL;
+                    return -1;
+                }
+
+                // CMSG_DATA is not necessarily aligned, so we copy the FDs into a buffer and then
+                // use memcpy.
+                int fds[kMaxFdsPerMsg];
+                for (size_t i = 0; i < ancillaryFds->size(); i++) {
+                    fds[i] = std::visit([](const auto& fd) { return fd.get(); },
+                                        ancillaryFds->at(i));
+                }
+                const size_t fdsByteSize = sizeof(int) * ancillaryFds->size();
+
+                alignas(struct cmsghdr) char msgControlBuf[CMSG_SPACE(sizeof(int) * kMaxFdsPerMsg)];
+
+                msghdr msg{
+                        .msg_iov = iovs,
+                        .msg_iovlen = static_cast<decltype(msg.msg_iovlen)>(niovs),
+                        .msg_control = msgControlBuf,
+                        .msg_controllen = sizeof(msgControlBuf),
+                };
+
+                cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+                cmsg->cmsg_level = SOL_SOCKET;
+                cmsg->cmsg_type = SCM_RIGHTS;
+                cmsg->cmsg_len = CMSG_LEN(fdsByteSize);
+                memcpy(CMSG_DATA(cmsg), fds, fdsByteSize);
+
+                // Only advertise the control space this single cmsg actually occupies.
+                msg.msg_controllen = CMSG_SPACE(fdsByteSize);
+
+                // MSG_CMSG_CLOEXEC is a recvmsg()-only flag (the receiver decides close-on-exec),
+                // so sendmsg() only needs MSG_NOSIGNAL, same as the FD-less path below.
+                ssize_t processedSize =
+                        TEMP_FAILURE_RETRY(sendmsg(mSocket.get(), &msg, MSG_NOSIGNAL));
+                if (processedSize > 0) {
+                    sentFds = true;
+                }
+                return processedSize;
+            }
+
+            msghdr msg{
+                    .msg_iov = iovs,
+                    // posix uses int, glibc uses size_t.  niovs is a
+                    // non-negative int and can be cast to either.
+                    .msg_iovlen = static_cast<decltype(msg.msg_iovlen)>(niovs),
+            };
+            return TEMP_FAILURE_RETRY(sendmsg(mSocket.get(), &msg, MSG_NOSIGNAL));
+        };
+        return interruptableReadOrWrite(fdTrigger, iovs, niovs, send, "sendmsg", POLLOUT, altPoll);
     }
 
     status_t interruptableReadFully(
             FdTrigger* fdTrigger, iovec* iovs, int niovs,
-            const std::optional<android::base::function_ref<status_t()>>& altPoll) override {
-        return interruptableReadOrWrite(fdTrigger, iovs, niovs, recvmsg, "recvmsg", POLLIN,
-                                        altPoll);
+            const std::optional<android::base::function_ref<status_t()>>& altPoll,
+            bool enableAncillaryFds) override {
+        auto recv = [&](iovec* iovs, int niovs) -> ssize_t {
+            if (enableAncillaryFds) {
+                int fdBuffer[kMaxFdsPerMsg];
+                alignas(struct cmsghdr) char msgControlBuf[CMSG_SPACE(sizeof(fdBuffer))];
+
+                msghdr msg{
+                        .msg_iov = iovs,
+                        .msg_iovlen = static_cast<decltype(msg.msg_iovlen)>(niovs),
+                        .msg_control = msgControlBuf,
+                        .msg_controllen = sizeof(msgControlBuf),
+                };
+                // MSG_CMSG_CLOEXEC sets close-on-exec on every FD received via SCM_RIGHTS
+                // atomically, so they cannot leak into processes exec'd from this one.
+                ssize_t processSize = TEMP_FAILURE_RETRY(
+                        recvmsg(mSocket.get(), &msg, MSG_NOSIGNAL | MSG_CMSG_CLOEXEC));
+                if (processSize < 0) {
+                    return -1;
+                }
+
+                for (cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr;
+                     cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+                    if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
+                        // Validate the length before copying: a bogus cmsg_len must neither
+                        // underflow dataLen nor overflow fdBuffer.
+                        if (cmsg->cmsg_len < CMSG_LEN(0) ||
+                            cmsg->cmsg_len - CMSG_LEN(0) > sizeof(fdBuffer)) {
+                            ALOGE("Invalid SCM_RIGHTS cmsg_len %zu. Aborting session.",
+                                  static_cast<size_t>(cmsg->cmsg_len));
+                            errno = EPIPE;
+                            return -1;
+                        }
+                        // NOTE: It is tempting to reinterpret_cast, but cmsg(3) explicitly asks
+                        // application devs to memcpy the data to ensure memory alignment.
+                        size_t dataLen = cmsg->cmsg_len - CMSG_LEN(0);
+                        memcpy(fdBuffer, CMSG_DATA(cmsg), dataLen);
+                        size_t fdCount = dataLen / sizeof(int);
+                        mFdsPendingRead.reserve(mFdsPendingRead.size() + fdCount);
+                        for (size_t i = 0; i < fdCount; i++) {
+                            mFdsPendingRead.emplace_back(fdBuffer[i]);
+                        }
+                        break;
+                    }
+                }
+
+                if (msg.msg_flags & MSG_CTRUNC) {
+                    ALOGE("msg was truncated. Aborting session.");
+                    errno = EPIPE;
+                    return -1;
+                }
+
+                return processSize;
+            }
+            msghdr msg{
+                    .msg_iov = iovs,
+                    // posix uses int, glibc uses size_t.  niovs is a
+                    // non-negative int and can be cast to either.
+                    .msg_iovlen = static_cast<decltype(msg.msg_iovlen)>(niovs),
+            };
+            return TEMP_FAILURE_RETRY(recvmsg(mSocket.get(), &msg, MSG_NOSIGNAL));
+        };
+        return interruptableReadOrWrite(fdTrigger, iovs, niovs, recv, "recvmsg", POLLIN, altPoll);
+    }
+
+    // Moves all FDs received so far onto the end of *fds (in arrival order) and forgets them.
+    status_t consumePendingAncillaryData(std::vector<base::unique_fd>* fds) override {
+        fds->reserve(fds->size() + mFdsPendingRead.size());
+        for (auto& fd : mFdsPendingRead) {
+            fds->emplace_back(std::move(fd));
+        }
+        mFdsPendingRead.clear();
+        return OK;
     }
 
 private:
     base::unique_fd mSocket;
+    // FDs received via SCM_RIGHTS that consumePendingAncillaryData() has not yet handed out.
+    std::vector<base::unique_fd> mFdsPendingRead;
 };
 
 // RpcTransportCtx with TLS disabled.