libbinder: Stricter protocol and code for receiving FDs

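This is a slight change to the wire protocol. Out-of-band FDs must now
be sent along with the command header bytes; FDs attached to any other
bytes are no longer collected by the receiver (they are dropped and
closed by the OS).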

The code changes exploit that by only using the more complex `recvmsg`
call when reading the command header. Additionally, we explicitly pass
around the list of FDs so that there is no risk of accumulating them.

The same (somewhat ugly) vector type is used everywhere now so that
there is only one allocation to capture the FDs and pass them to the
`Parcel` object.

Test: binderRpcTest
Bug: 185909244
Change-Id: I1f55995ca82338ab9716fb2246c954ac8b16cfe5
diff --git a/libs/binder/Parcel.cpp b/libs/binder/Parcel.cpp
index 8739105..0354382 100644
--- a/libs/binder/Parcel.cpp
+++ b/libs/binder/Parcel.cpp
@@ -2486,11 +2486,11 @@
     scanForFds();
 }
 
-status_t Parcel::rpcSetDataReference(const sp<RpcSession>& session, const uint8_t* data,
-                                     size_t dataSize, const uint32_t* objectTable,
-                                     size_t objectTableSize,
-                                     std::vector<base::unique_fd> ancillaryFds,
-                                     release_func relFunc) {
+status_t Parcel::rpcSetDataReference(
+        const sp<RpcSession>& session, const uint8_t* data, size_t dataSize,
+        const uint32_t* objectTable, size_t objectTableSize,
+        std::vector<std::variant<base::unique_fd, base::borrowed_fd>>&& ancillaryFds,
+        release_func relFunc) {
     // this code uses 'mOwner == nullptr' to understand whether it owns memory
     LOG_ALWAYS_FATAL_IF(relFunc == nullptr, "must provide cleanup function");
 
@@ -2518,10 +2518,7 @@
     }
     if (!ancillaryFds.empty()) {
         rpcFields->mFds = std::make_unique<decltype(rpcFields->mFds)::element_type>();
-        rpcFields->mFds->reserve(ancillaryFds.size());
-        for (auto& fd : ancillaryFds) {
-            rpcFields->mFds->push_back(std::move(fd));
-        }
+        *rpcFields->mFds = std::move(ancillaryFds);
     }
 
     return OK;
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp
index 3bb21ad..f83bb5e 100644
--- a/libs/binder/RpcServer.cpp
+++ b/libs/binder/RpcServer.cpp
@@ -300,7 +300,7 @@
     if (status == OK) {
         iovec iov{&header, sizeof(header)};
         status = client->interruptableReadFully(server->mShutdownTrigger.get(), &iov, 1,
-                                                std::nullopt, /*enableAncillaryFds=*/false);
+                                                std::nullopt, /*ancillaryFds=*/nullptr);
         if (status != OK) {
             ALOGE("Failed to read ID for client connecting to RPC server: %s",
                   statusToString(status).c_str());
@@ -315,7 +315,7 @@
                 sessionId.resize(header.sessionIdSize);
                 iovec iov{sessionId.data(), sessionId.size()};
                 status = client->interruptableReadFully(server->mShutdownTrigger.get(), &iov, 1,
-                                                        std::nullopt, /*enableAncillaryFds=*/false);
+                                                        std::nullopt, /*ancillaryFds=*/nullptr);
                 if (status != OK) {
                     ALOGE("Failed to read session ID for client connecting to RPC server: %s",
                           statusToString(status).c_str());
diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp
index 2b1d0e5..8f104bc 100644
--- a/libs/binder/RpcState.cpp
+++ b/libs/binder/RpcState.cpp
@@ -346,11 +346,14 @@
     return OK;
 }
 
-status_t RpcState::rpcRec(const sp<RpcSession::RpcConnection>& connection,
-                          const sp<RpcSession>& session, const char* what, iovec* iovs, int niovs) {
-    if (status_t status = connection->rpcTransport->interruptableReadFully(
-                session->mShutdownTrigger.get(), iovs, niovs, std::nullopt,
-                enableAncillaryFds(session->getFileDescriptorTransportMode()));
+status_t RpcState::rpcRec(
+        const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
+        const char* what, iovec* iovs, int niovs,
+        std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) {
+    if (status_t status =
+                connection->rpcTransport->interruptableReadFully(session->mShutdownTrigger.get(),
+                                                                 iovs, niovs, std::nullopt,
+                                                                 ancillaryFds);
         status != OK) {
         LOG_RPC_DETAIL("Failed to read %s (%d iovs) on RpcTransport %p, error: %s", what, niovs,
                        connection->rpcTransport.get(), statusToString(status).c_str());
@@ -370,7 +373,7 @@
                                           const sp<RpcSession>& session, uint32_t* version) {
     RpcNewSessionResponse response;
     iovec iov{&response, sizeof(response)};
-    if (status_t status = rpcRec(connection, session, "new session response", &iov, 1);
+    if (status_t status = rpcRec(connection, session, "new session response", &iov, 1, nullptr);
         status != OK) {
         return status;
     }
@@ -391,7 +394,8 @@
                                       const sp<RpcSession>& session) {
     RpcOutgoingConnectionInit init;
     iovec iov{&init, sizeof(init)};
-    if (status_t status = rpcRec(connection, session, "connection init", &iov, 1); status != OK)
+    if (status_t status = rpcRec(connection, session, "connection init", &iov, 1, nullptr);
+        status != OK)
         return status;
 
     static_assert(sizeof(init.msg) == sizeof(RPC_CONNECTION_INIT_OKAY));
@@ -589,18 +593,26 @@
 
 status_t RpcState::waitForReply(const sp<RpcSession::RpcConnection>& connection,
                                 const sp<RpcSession>& session, Parcel* reply) {
+    std::vector<std::variant<base::unique_fd, base::borrowed_fd>> ancillaryFds;
     RpcWireHeader command;
     while (true) {
         iovec iov{&command, sizeof(command)};
-        if (status_t status = rpcRec(connection, session, "command header (for reply)", &iov, 1);
+        if (status_t status = rpcRec(connection, session, "command header (for reply)", &iov, 1,
+                                     enableAncillaryFds(session->getFileDescriptorTransportMode())
+                                             ? &ancillaryFds
+                                             : nullptr);
             status != OK)
             return status;
 
         if (command.command == RPC_COMMAND_REPLY) break;
 
-        if (status_t status = processCommand(connection, session, command, CommandType::ANY);
+        if (status_t status = processCommand(connection, session, command, CommandType::ANY,
+                                             std::move(ancillaryFds));
             status != OK)
             return status;
+
+        // Reset to avoid spurious use-after-move warning from clang-tidy.
+        ancillaryFds = decltype(ancillaryFds)();
     }
 
     const size_t rpcReplyWireSize = RpcWireReply::wireSize(session->getProtocolVersion().value());
@@ -622,17 +634,10 @@
             {&rpcReply, rpcReplyWireSize},
             {data.data(), data.size()},
     };
-    if (status_t status = rpcRec(connection, session, "reply body", iovs, arraysize(iovs));
+    if (status_t status = rpcRec(connection, session, "reply body", iovs, arraysize(iovs), nullptr);
         status != OK)
         return status;
 
-    // Check if the reply came with any ancillary data.
-    std::vector<base::unique_fd> pendingFds;
-    if (status_t status = connection->rpcTransport->consumePendingAncillaryData(&pendingFds);
-        status != OK) {
-        return status;
-    }
-
     if (rpcReply.status != OK) return rpcReply.status;
 
     Span<const uint8_t> parcelSpan = {data.data(), data.size()};
@@ -655,7 +660,7 @@
     data.release();
     return reply->rpcSetDataReference(session, parcelSpan.data, parcelSpan.size,
                                       objectTableSpan.data, objectTableSpan.size,
-                                      std::move(pendingFds), cleanup_reply_data);
+                                      std::move(ancillaryFds), cleanup_reply_data);
 }
 
 status_t RpcState::sendDecStrongToTarget(const sp<RpcSession::RpcConnection>& connection,
@@ -698,13 +703,17 @@
                                         const sp<RpcSession>& session, CommandType type) {
     LOG_RPC_DETAIL("getAndExecuteCommand on RpcTransport %p", connection->rpcTransport.get());
 
+    std::vector<std::variant<base::unique_fd, base::borrowed_fd>> ancillaryFds;
     RpcWireHeader command;
     iovec iov{&command, sizeof(command)};
-    if (status_t status = rpcRec(connection, session, "command header (for server)", &iov, 1);
+    if (status_t status =
+                rpcRec(connection, session, "command header (for server)", &iov, 1,
+                       enableAncillaryFds(session->getFileDescriptorTransportMode()) ? &ancillaryFds
+                                                                                     : nullptr);
         status != OK)
         return status;
 
-    return processCommand(connection, session, command, type);
+    return processCommand(connection, session, command, type, std::move(ancillaryFds));
 }
 
 status_t RpcState::drainCommands(const sp<RpcSession::RpcConnection>& connection,
@@ -720,9 +729,10 @@
     return OK;
 }
 
-status_t RpcState::processCommand(const sp<RpcSession::RpcConnection>& connection,
-                                  const sp<RpcSession>& session, const RpcWireHeader& command,
-                                  CommandType type) {
+status_t RpcState::processCommand(
+        const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
+        const RpcWireHeader& command, CommandType type,
+        std::vector<std::variant<base::unique_fd, base::borrowed_fd>>&& ancillaryFds) {
     IPCThreadState* kernelBinderState = IPCThreadState::selfOrNull();
     IPCThreadState::SpGuard spGuard{
             .address = __builtin_frame_address(0),
@@ -741,7 +751,7 @@
     switch (command.command) {
         case RPC_COMMAND_TRANSACT:
             if (type != CommandType::ANY) return BAD_TYPE;
-            return processTransact(connection, session, command);
+            return processTransact(connection, session, command, std::move(ancillaryFds));
         case RPC_COMMAND_DEC_STRONG:
             return processDecStrong(connection, session, command);
     }
@@ -755,8 +765,10 @@
     (void)session->shutdownAndWait(false);
     return DEAD_OBJECT;
 }
-status_t RpcState::processTransact(const sp<RpcSession::RpcConnection>& connection,
-                                   const sp<RpcSession>& session, const RpcWireHeader& command) {
+status_t RpcState::processTransact(
+        const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
+        const RpcWireHeader& command,
+        std::vector<std::variant<base::unique_fd, base::borrowed_fd>>&& ancillaryFds) {
     LOG_ALWAYS_FATAL_IF(command.command != RPC_COMMAND_TRANSACT, "command: %d", command.command);
 
     CommandData transactionData(command.bodySize);
@@ -764,10 +776,12 @@
         return NO_MEMORY;
     }
     iovec iov{transactionData.data(), transactionData.size()};
-    if (status_t status = rpcRec(connection, session, "transaction body", &iov, 1); status != OK)
+    if (status_t status = rpcRec(connection, session, "transaction body", &iov, 1, nullptr);
+        status != OK)
         return status;
 
-    return processTransactInternal(connection, session, std::move(transactionData));
+    return processTransactInternal(connection, session, std::move(transactionData),
+                                   std::move(ancillaryFds));
 }
 
 static void do_nothing_to_transact_data(Parcel* p, const uint8_t* data, size_t dataSize,
@@ -779,9 +793,10 @@
     (void)objectsCount;
 }
 
-status_t RpcState::processTransactInternal(const sp<RpcSession::RpcConnection>& connection,
-                                           const sp<RpcSession>& session,
-                                           CommandData transactionData) {
+status_t RpcState::processTransactInternal(
+        const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
+        CommandData transactionData,
+        std::vector<std::variant<base::unique_fd, base::borrowed_fd>>&& ancillaryFds) {
     // for 'recursive' calls to this, we have already read and processed the
     // binder from the transaction data and taken reference counts into account,
     // so it is cached here.
@@ -869,13 +884,6 @@
     reply.markForRpc(session);
 
     if (replyStatus == OK) {
-        // Check if the transaction came with any ancillary data.
-        std::vector<base::unique_fd> pendingFds;
-        if (status_t status = connection->rpcTransport->consumePendingAncillaryData(&pendingFds);
-            status != OK) {
-            return status;
-        }
-
         Span<const uint8_t> parcelSpan = {transaction->data,
                                           transactionData.size() -
                                                   offsetof(RpcWireTransaction, data)};
@@ -901,9 +909,12 @@
         // only holds onto it for the duration of this function call. Parcel will be
         // deleted before the 'transactionData' object.
 
-        replyStatus = data.rpcSetDataReference(session, parcelSpan.data, parcelSpan.size,
-                                               objectTableSpan.data, objectTableSpan.size,
-                                               std::move(pendingFds), do_nothing_to_transact_data);
+        replyStatus =
+                data.rpcSetDataReference(session, parcelSpan.data, parcelSpan.size,
+                                         objectTableSpan.data, objectTableSpan.size,
+                                         std::move(ancillaryFds), do_nothing_to_transact_data);
+        // Reset to avoid spurious use-after-move warning from clang-tidy.
+        ancillaryFds = std::remove_reference<decltype(ancillaryFds)>::type();
 
         if (replyStatus == OK) {
             if (target) {
@@ -1073,7 +1084,8 @@
 
     RpcDecStrong body;
     iovec iov{&body, sizeof(RpcDecStrong)};
-    if (status_t status = rpcRec(connection, session, "dec ref body", &iov, 1); status != OK)
+    if (status_t status = rpcRec(connection, session, "dec ref body", &iov, 1, nullptr);
+        status != OK)
         return status;
 
     uint64_t addr = RpcWireAddress::toRaw(body.address);
diff --git a/libs/binder/RpcState.h b/libs/binder/RpcState.h
index 08a314e..b452a99 100644
--- a/libs/binder/RpcState.h
+++ b/libs/binder/RpcState.h
@@ -184,21 +184,25 @@
             const std::optional<android::base::function_ref<status_t()>>& altPoll,
             const std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds =
                     nullptr);
-    [[nodiscard]] status_t rpcRec(const sp<RpcSession::RpcConnection>& connection,
-                                  const sp<RpcSession>& session, const char* what, iovec* iovs,
-                                  int niovs);
+    [[nodiscard]] status_t rpcRec(
+            const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
+            const char* what, iovec* iovs, int niovs,
+            std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds = nullptr);
 
     [[nodiscard]] status_t waitForReply(const sp<RpcSession::RpcConnection>& connection,
                                         const sp<RpcSession>& session, Parcel* reply);
-    [[nodiscard]] status_t processCommand(const sp<RpcSession::RpcConnection>& connection,
-                                          const sp<RpcSession>& session,
-                                          const RpcWireHeader& command, CommandType type);
-    [[nodiscard]] status_t processTransact(const sp<RpcSession::RpcConnection>& connection,
-                                           const sp<RpcSession>& session,
-                                           const RpcWireHeader& command);
-    [[nodiscard]] status_t processTransactInternal(const sp<RpcSession::RpcConnection>& connection,
-                                                   const sp<RpcSession>& session,
-                                                   CommandData transactionData);
+    [[nodiscard]] status_t processCommand(
+            const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
+            const RpcWireHeader& command, CommandType type,
+            std::vector<std::variant<base::unique_fd, base::borrowed_fd>>&& ancillaryFds);
+    [[nodiscard]] status_t processTransact(
+            const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
+            const RpcWireHeader& command,
+            std::vector<std::variant<base::unique_fd, base::borrowed_fd>>&& ancillaryFds);
+    [[nodiscard]] status_t processTransactInternal(
+            const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
+            CommandData transactionData,
+            std::vector<std::variant<base::unique_fd, base::borrowed_fd>>&& ancillaryFds);
     [[nodiscard]] status_t processDecStrong(const sp<RpcSession::RpcConnection>& connection,
                                             const sp<RpcSession>& session,
                                             const RpcWireHeader& command);
diff --git a/libs/binder/RpcTransportRaw.cpp b/libs/binder/RpcTransportRaw.cpp
index d9059e9..7cc58cd 100644
--- a/libs/binder/RpcTransportRaw.cpp
+++ b/libs/binder/RpcTransportRaw.cpp
@@ -204,9 +204,9 @@
     status_t interruptableReadFully(
             FdTrigger* fdTrigger, iovec* iovs, int niovs,
             const std::optional<android::base::function_ref<status_t()>>& altPoll,
-            bool enableAncillaryFds) override {
+            std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) override {
         auto recv = [&](iovec* iovs, int niovs) -> ssize_t {
-            if (enableAncillaryFds) {
+            if (ancillaryFds != nullptr) {
                 int fdBuffer[kMaxFdsPerMsg];
                 alignas(struct cmsghdr) char msgControlBuf[CMSG_SPACE(sizeof(fdBuffer))];
 
@@ -228,10 +228,12 @@
                         // NOTE: It is tempting to reinterpret_cast, but cmsg(3) explicitly asks
                         // application devs to memcpy the data to ensure memory alignment.
                         size_t dataLen = cmsg->cmsg_len - CMSG_LEN(0);
+                        LOG_ALWAYS_FATAL_IF(dataLen > sizeof(fdBuffer)); // sanity check
                         memcpy(fdBuffer, CMSG_DATA(cmsg), dataLen);
                         size_t fdCount = dataLen / sizeof(int);
+                        ancillaryFds->reserve(ancillaryFds->size() + fdCount);
                         for (size_t i = 0; i < fdCount; i++) {
-                            mFdsPendingRead.emplace_back(fdBuffer[i]);
+                            ancillaryFds->emplace_back(base::unique_fd(fdBuffer[i]));
                         }
                         break;
                     }
@@ -256,18 +258,8 @@
         return interruptableReadOrWrite(fdTrigger, iovs, niovs, recv, "recvmsg", POLLIN, altPoll);
     }
 
-    status_t consumePendingAncillaryData(std::vector<base::unique_fd>* fds) override {
-        fds->reserve(fds->size() + mFdsPendingRead.size());
-        for (auto& fd : mFdsPendingRead) {
-            fds->emplace_back(std::move(fd));
-        }
-        mFdsPendingRead.clear();
-        return OK;
-    }
-
 private:
     base::unique_fd mSocket;
-    std::vector<base::unique_fd> mFdsPendingRead;
 };
 
 // RpcTransportCtx with TLS disabled.
diff --git a/libs/binder/RpcTransportTls.cpp b/libs/binder/RpcTransportTls.cpp
index 7783111..09b5c17 100644
--- a/libs/binder/RpcTransportTls.cpp
+++ b/libs/binder/RpcTransportTls.cpp
@@ -288,12 +288,7 @@
     status_t interruptableReadFully(
             FdTrigger* fdTrigger, iovec* iovs, int niovs,
             const std::optional<android::base::function_ref<status_t()>>& altPoll,
-            bool enableAncillaryFds) override;
-
-    status_t consumePendingAncillaryData(std::vector<base::unique_fd>* fds) override {
-        (void)fds;
-        return OK;
-    }
+            std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) override;
 
 private:
     android::base::unique_fd mSocket;
@@ -368,8 +363,8 @@
 status_t RpcTransportTls::interruptableReadFully(
         FdTrigger* fdTrigger, iovec* iovs, int niovs,
         const std::optional<android::base::function_ref<status_t()>>& altPoll,
-        bool enableAncillaryFds) {
-    (void)enableAncillaryFds;
+        std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) {
+    (void)ancillaryFds;
 
     MAYBE_WAIT_IN_FLAKE_MODE;
 
diff --git a/libs/binder/RpcWireFormat.h b/libs/binder/RpcWireFormat.h
index 13989e5..ff1b01a 100644
--- a/libs/binder/RpcWireFormat.h
+++ b/libs/binder/RpcWireFormat.h
@@ -109,6 +109,10 @@
 
 // serialization is like:
 // |RpcWireHeader|struct desginated by 'command'| (over and over again)
+//
+// When file descriptors are included in out-of-band data (e.g. in unix domain
+// sockets), they are always paired with the RpcWireHeader bytes of the
+// transaction or reply the file descriptors belong to.
 
 struct RpcWireHeader {
     uint32_t command; // RPC_COMMAND_*
diff --git a/libs/binder/include/binder/Parcel.h b/libs/binder/include/binder/Parcel.h
index 68a4aef..32b0ded 100644
--- a/libs/binder/include/binder/Parcel.h
+++ b/libs/binder/include/binder/Parcel.h
@@ -609,10 +609,11 @@
     void ipcSetDataReference(const uint8_t* data, size_t dataSize, const binder_size_t* objects,
                              size_t objectsCount, release_func relFunc);
     // Takes ownership even when an error is returned.
-    status_t rpcSetDataReference(const sp<RpcSession>& session, const uint8_t* data,
-                                 size_t dataSize, const uint32_t* objectTable,
-                                 size_t objectTableSize, std::vector<base::unique_fd> ancillaryFds,
-                                 release_func relFunc);
+    status_t rpcSetDataReference(
+            const sp<RpcSession>& session, const uint8_t* data, size_t dataSize,
+            const uint32_t* objectTable, size_t objectTableSize,
+            std::vector<std::variant<base::unique_fd, base::borrowed_fd>>&& ancillaryFds,
+            release_func relFunc);
 
     status_t            finishWrite(size_t len);
     void                releaseObjects();
diff --git a/libs/binder/include/binder/RpcTransport.h b/libs/binder/include/binder/RpcTransport.h
index 80f5a32..5197ef9 100644
--- a/libs/binder/include/binder/RpcTransport.h
+++ b/libs/binder/include/binder/RpcTransport.h
@@ -63,12 +63,10 @@
      * to read/write data. If this returns an error, that error is returned from
      * this function.
      *
-     * ancillaryFds - FDs to be sent via UNIX domain dockets or Trusty IPC.
-     *
-     * enableAncillaryFds - Whether to check for FDs in the ancillary data and
-     * queue for them for use in `consumePendingAncillaryData`. If false and FDs
-     * are received, they will be silently dropped (and closed) by the operating
-     * system.
+     * ancillaryFds - FDs to be sent via UNIX domain sockets or Trusty IPC. When
+     * reading, received FDs are appended; if `ancillaryFds` is null, they are
+     * silently dropped and closed (by the OS). Appended values are always
+     * unique_fd; the variant type is used to avoid extra copies elsewhere.
      *
      * Return:
      *   OK - succeeded in completely processing 'size'
@@ -81,13 +79,7 @@
     [[nodiscard]] virtual status_t interruptableReadFully(
             FdTrigger *fdTrigger, iovec *iovs, int niovs,
             const std::optional<android::base::function_ref<status_t()>> &altPoll,
-            bool enableAncillaryFds) = 0;
-
-    // Consume the ancillary data that was accumulated from previous
-    // `interruptableReadFully` calls.
-    //
-    // Appends to `fds`.
-    virtual status_t consumePendingAncillaryData(std::vector<base::unique_fd> *fds) = 0;
+            std::vector<std::variant<base::unique_fd, base::borrowed_fd>> *ancillaryFds) = 0;
 
 protected:
     RpcTransport() = default;
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index 3b1fc82..c8b724b 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -2000,7 +2000,7 @@
             iovec readMessageIov{readMessage.data(), readMessage.size()};
             status_t readStatus =
                     mClientTransport->interruptableReadFully(mFdTrigger.get(), &readMessageIov, 1,
-                                                             std::nullopt, false);
+                                                             std::nullopt, nullptr);
             if (readStatus != OK) {
                 return AssertionFailure() << statusToString(readStatus);
             }