Merge "libbinder: Stricter protocol and code for receiving FDs"
diff --git a/libs/binder/Parcel.cpp b/libs/binder/Parcel.cpp
index 8739105..0354382 100644
--- a/libs/binder/Parcel.cpp
+++ b/libs/binder/Parcel.cpp
@@ -2486,11 +2486,11 @@
     scanForFds();
 }
 
-status_t Parcel::rpcSetDataReference(const sp<RpcSession>& session, const uint8_t* data,
-                                     size_t dataSize, const uint32_t* objectTable,
-                                     size_t objectTableSize,
-                                     std::vector<base::unique_fd> ancillaryFds,
-                                     release_func relFunc) {
+status_t Parcel::rpcSetDataReference(
+        const sp<RpcSession>& session, const uint8_t* data, size_t dataSize,
+        const uint32_t* objectTable, size_t objectTableSize,
+        std::vector<std::variant<base::unique_fd, base::borrowed_fd>>&& ancillaryFds,
+        release_func relFunc) {
     // this code uses 'mOwner == nullptr' to understand whether it owns memory
     LOG_ALWAYS_FATAL_IF(relFunc == nullptr, "must provide cleanup function");
 
@@ -2518,10 +2518,7 @@
     }
     if (!ancillaryFds.empty()) {
         rpcFields->mFds = std::make_unique<decltype(rpcFields->mFds)::element_type>();
-        rpcFields->mFds->reserve(ancillaryFds.size());
-        for (auto& fd : ancillaryFds) {
-            rpcFields->mFds->push_back(std::move(fd));
-        }
+        *rpcFields->mFds = std::move(ancillaryFds);
     }
 
     return OK;
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp
index 3bb21ad..f83bb5e 100644
--- a/libs/binder/RpcServer.cpp
+++ b/libs/binder/RpcServer.cpp
@@ -300,7 +300,7 @@
     if (status == OK) {
         iovec iov{&header, sizeof(header)};
         status = client->interruptableReadFully(server->mShutdownTrigger.get(), &iov, 1,
-                                                std::nullopt, /*enableAncillaryFds=*/false);
+                                                std::nullopt, /*ancillaryFds=*/nullptr);
         if (status != OK) {
             ALOGE("Failed to read ID for client connecting to RPC server: %s",
                   statusToString(status).c_str());
@@ -315,7 +315,7 @@
                 sessionId.resize(header.sessionIdSize);
                 iovec iov{sessionId.data(), sessionId.size()};
                 status = client->interruptableReadFully(server->mShutdownTrigger.get(), &iov, 1,
-                                                        std::nullopt, /*enableAncillaryFds=*/false);
+                                                        std::nullopt, /*ancillaryFds=*/nullptr);
                 if (status != OK) {
                     ALOGE("Failed to read session ID for client connecting to RPC server: %s",
                           statusToString(status).c_str());
diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp
index 2b1d0e5..8f104bc 100644
--- a/libs/binder/RpcState.cpp
+++ b/libs/binder/RpcState.cpp
@@ -346,11 +346,14 @@
     return OK;
 }
 
-status_t RpcState::rpcRec(const sp<RpcSession::RpcConnection>& connection,
-                          const sp<RpcSession>& session, const char* what, iovec* iovs, int niovs) {
-    if (status_t status = connection->rpcTransport->interruptableReadFully(
-                session->mShutdownTrigger.get(), iovs, niovs, std::nullopt,
-                enableAncillaryFds(session->getFileDescriptorTransportMode()));
+status_t RpcState::rpcRec(
+        const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
+        const char* what, iovec* iovs, int niovs,
+        std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) {
+    if (status_t status =
+                connection->rpcTransport->interruptableReadFully(session->mShutdownTrigger.get(),
+                                                                 iovs, niovs, std::nullopt,
+                                                                 ancillaryFds);
         status != OK) {
         LOG_RPC_DETAIL("Failed to read %s (%d iovs) on RpcTransport %p, error: %s", what, niovs,
                        connection->rpcTransport.get(), statusToString(status).c_str());
@@ -370,7 +373,7 @@
                                           const sp<RpcSession>& session, uint32_t* version) {
     RpcNewSessionResponse response;
     iovec iov{&response, sizeof(response)};
-    if (status_t status = rpcRec(connection, session, "new session response", &iov, 1);
+    if (status_t status = rpcRec(connection, session, "new session response", &iov, 1, nullptr);
         status != OK) {
         return status;
     }
@@ -391,7 +394,8 @@
                                       const sp<RpcSession>& session) {
     RpcOutgoingConnectionInit init;
     iovec iov{&init, sizeof(init)};
-    if (status_t status = rpcRec(connection, session, "connection init", &iov, 1); status != OK)
+    if (status_t status = rpcRec(connection, session, "connection init", &iov, 1, nullptr);
+        status != OK)
         return status;
 
     static_assert(sizeof(init.msg) == sizeof(RPC_CONNECTION_INIT_OKAY));
@@ -589,18 +593,26 @@
 
 status_t RpcState::waitForReply(const sp<RpcSession::RpcConnection>& connection,
                                 const sp<RpcSession>& session, Parcel* reply) {
+    std::vector<std::variant<base::unique_fd, base::borrowed_fd>> ancillaryFds;
     RpcWireHeader command;
     while (true) {
         iovec iov{&command, sizeof(command)};
-        if (status_t status = rpcRec(connection, session, "command header (for reply)", &iov, 1);
+        if (status_t status = rpcRec(connection, session, "command header (for reply)", &iov, 1,
+                                     enableAncillaryFds(session->getFileDescriptorTransportMode())
+                                             ? &ancillaryFds
+                                             : nullptr);
             status != OK)
             return status;
 
         if (command.command == RPC_COMMAND_REPLY) break;
 
-        if (status_t status = processCommand(connection, session, command, CommandType::ANY);
+        if (status_t status = processCommand(connection, session, command, CommandType::ANY,
+                                             std::move(ancillaryFds));
             status != OK)
             return status;
+
+        // Reset to avoid spurious use-after-move warning from clang-tidy.
+        ancillaryFds = decltype(ancillaryFds)();
     }
 
     const size_t rpcReplyWireSize = RpcWireReply::wireSize(session->getProtocolVersion().value());
@@ -622,17 +634,10 @@
             {&rpcReply, rpcReplyWireSize},
             {data.data(), data.size()},
     };
-    if (status_t status = rpcRec(connection, session, "reply body", iovs, arraysize(iovs));
+    if (status_t status = rpcRec(connection, session, "reply body", iovs, arraysize(iovs), nullptr);
         status != OK)
         return status;
 
-    // Check if the reply came with any ancillary data.
-    std::vector<base::unique_fd> pendingFds;
-    if (status_t status = connection->rpcTransport->consumePendingAncillaryData(&pendingFds);
-        status != OK) {
-        return status;
-    }
-
     if (rpcReply.status != OK) return rpcReply.status;
 
     Span<const uint8_t> parcelSpan = {data.data(), data.size()};
@@ -655,7 +660,7 @@
     data.release();
     return reply->rpcSetDataReference(session, parcelSpan.data, parcelSpan.size,
                                       objectTableSpan.data, objectTableSpan.size,
-                                      std::move(pendingFds), cleanup_reply_data);
+                                      std::move(ancillaryFds), cleanup_reply_data);
 }
 
 status_t RpcState::sendDecStrongToTarget(const sp<RpcSession::RpcConnection>& connection,
@@ -698,13 +703,17 @@
                                         const sp<RpcSession>& session, CommandType type) {
     LOG_RPC_DETAIL("getAndExecuteCommand on RpcTransport %p", connection->rpcTransport.get());
 
+    std::vector<std::variant<base::unique_fd, base::borrowed_fd>> ancillaryFds;
     RpcWireHeader command;
     iovec iov{&command, sizeof(command)};
-    if (status_t status = rpcRec(connection, session, "command header (for server)", &iov, 1);
+    if (status_t status =
+                rpcRec(connection, session, "command header (for server)", &iov, 1,
+                       enableAncillaryFds(session->getFileDescriptorTransportMode()) ? &ancillaryFds
+                                                                                     : nullptr);
         status != OK)
         return status;
 
-    return processCommand(connection, session, command, type);
+    return processCommand(connection, session, command, type, std::move(ancillaryFds));
 }
 
 status_t RpcState::drainCommands(const sp<RpcSession::RpcConnection>& connection,
@@ -720,9 +729,10 @@
     return OK;
 }
 
-status_t RpcState::processCommand(const sp<RpcSession::RpcConnection>& connection,
-                                  const sp<RpcSession>& session, const RpcWireHeader& command,
-                                  CommandType type) {
+status_t RpcState::processCommand(
+        const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
+        const RpcWireHeader& command, CommandType type,
+        std::vector<std::variant<base::unique_fd, base::borrowed_fd>>&& ancillaryFds) {
     IPCThreadState* kernelBinderState = IPCThreadState::selfOrNull();
     IPCThreadState::SpGuard spGuard{
             .address = __builtin_frame_address(0),
@@ -741,7 +751,7 @@
     switch (command.command) {
         case RPC_COMMAND_TRANSACT:
             if (type != CommandType::ANY) return BAD_TYPE;
-            return processTransact(connection, session, command);
+            return processTransact(connection, session, command, std::move(ancillaryFds));
         case RPC_COMMAND_DEC_STRONG:
             return processDecStrong(connection, session, command);
     }
@@ -755,8 +765,10 @@
     (void)session->shutdownAndWait(false);
     return DEAD_OBJECT;
 }
-status_t RpcState::processTransact(const sp<RpcSession::RpcConnection>& connection,
-                                   const sp<RpcSession>& session, const RpcWireHeader& command) {
+status_t RpcState::processTransact(
+        const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
+        const RpcWireHeader& command,
+        std::vector<std::variant<base::unique_fd, base::borrowed_fd>>&& ancillaryFds) {
     LOG_ALWAYS_FATAL_IF(command.command != RPC_COMMAND_TRANSACT, "command: %d", command.command);
 
     CommandData transactionData(command.bodySize);
@@ -764,10 +776,12 @@
         return NO_MEMORY;
     }
     iovec iov{transactionData.data(), transactionData.size()};
-    if (status_t status = rpcRec(connection, session, "transaction body", &iov, 1); status != OK)
+    if (status_t status = rpcRec(connection, session, "transaction body", &iov, 1, nullptr);
+        status != OK)
         return status;
 
-    return processTransactInternal(connection, session, std::move(transactionData));
+    return processTransactInternal(connection, session, std::move(transactionData),
+                                   std::move(ancillaryFds));
 }
 
 static void do_nothing_to_transact_data(Parcel* p, const uint8_t* data, size_t dataSize,
@@ -779,9 +793,10 @@
     (void)objectsCount;
 }
 
-status_t RpcState::processTransactInternal(const sp<RpcSession::RpcConnection>& connection,
-                                           const sp<RpcSession>& session,
-                                           CommandData transactionData) {
+status_t RpcState::processTransactInternal(
+        const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
+        CommandData transactionData,
+        std::vector<std::variant<base::unique_fd, base::borrowed_fd>>&& ancillaryFds) {
     // for 'recursive' calls to this, we have already read and processed the
     // binder from the transaction data and taken reference counts into account,
     // so it is cached here.
@@ -869,13 +884,6 @@
     reply.markForRpc(session);
 
     if (replyStatus == OK) {
-        // Check if the transaction came with any ancillary data.
-        std::vector<base::unique_fd> pendingFds;
-        if (status_t status = connection->rpcTransport->consumePendingAncillaryData(&pendingFds);
-            status != OK) {
-            return status;
-        }
-
         Span<const uint8_t> parcelSpan = {transaction->data,
                                           transactionData.size() -
                                                   offsetof(RpcWireTransaction, data)};
@@ -901,9 +909,12 @@
         // only holds onto it for the duration of this function call. Parcel will be
         // deleted before the 'transactionData' object.
 
-        replyStatus = data.rpcSetDataReference(session, parcelSpan.data, parcelSpan.size,
-                                               objectTableSpan.data, objectTableSpan.size,
-                                               std::move(pendingFds), do_nothing_to_transact_data);
+        replyStatus =
+                data.rpcSetDataReference(session, parcelSpan.data, parcelSpan.size,
+                                         objectTableSpan.data, objectTableSpan.size,
+                                         std::move(ancillaryFds), do_nothing_to_transact_data);
+        // Reset to avoid spurious use-after-move warning from clang-tidy.
+        ancillaryFds = std::remove_reference<decltype(ancillaryFds)>::type();
 
         if (replyStatus == OK) {
             if (target) {
@@ -1073,7 +1084,8 @@
 
     RpcDecStrong body;
     iovec iov{&body, sizeof(RpcDecStrong)};
-    if (status_t status = rpcRec(connection, session, "dec ref body", &iov, 1); status != OK)
+    if (status_t status = rpcRec(connection, session, "dec ref body", &iov, 1, nullptr);
+        status != OK)
         return status;
 
     uint64_t addr = RpcWireAddress::toRaw(body.address);
diff --git a/libs/binder/RpcState.h b/libs/binder/RpcState.h
index 08a314e..b452a99 100644
--- a/libs/binder/RpcState.h
+++ b/libs/binder/RpcState.h
@@ -184,21 +184,25 @@
             const std::optional<android::base::function_ref<status_t()>>& altPoll,
             const std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds =
                     nullptr);
-    [[nodiscard]] status_t rpcRec(const sp<RpcSession::RpcConnection>& connection,
-                                  const sp<RpcSession>& session, const char* what, iovec* iovs,
-                                  int niovs);
+    [[nodiscard]] status_t rpcRec(
+            const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
+            const char* what, iovec* iovs, int niovs,
+            std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds = nullptr);
 
     [[nodiscard]] status_t waitForReply(const sp<RpcSession::RpcConnection>& connection,
                                         const sp<RpcSession>& session, Parcel* reply);
-    [[nodiscard]] status_t processCommand(const sp<RpcSession::RpcConnection>& connection,
-                                          const sp<RpcSession>& session,
-                                          const RpcWireHeader& command, CommandType type);
-    [[nodiscard]] status_t processTransact(const sp<RpcSession::RpcConnection>& connection,
-                                           const sp<RpcSession>& session,
-                                           const RpcWireHeader& command);
-    [[nodiscard]] status_t processTransactInternal(const sp<RpcSession::RpcConnection>& connection,
-                                                   const sp<RpcSession>& session,
-                                                   CommandData transactionData);
+    [[nodiscard]] status_t processCommand(
+            const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
+            const RpcWireHeader& command, CommandType type,
+            std::vector<std::variant<base::unique_fd, base::borrowed_fd>>&& ancillaryFds);
+    [[nodiscard]] status_t processTransact(
+            const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
+            const RpcWireHeader& command,
+            std::vector<std::variant<base::unique_fd, base::borrowed_fd>>&& ancillaryFds);
+    [[nodiscard]] status_t processTransactInternal(
+            const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
+            CommandData transactionData,
+            std::vector<std::variant<base::unique_fd, base::borrowed_fd>>&& ancillaryFds);
     [[nodiscard]] status_t processDecStrong(const sp<RpcSession::RpcConnection>& connection,
                                             const sp<RpcSession>& session,
                                             const RpcWireHeader& command);
diff --git a/libs/binder/RpcTransportRaw.cpp b/libs/binder/RpcTransportRaw.cpp
index d9059e9..7cc58cd 100644
--- a/libs/binder/RpcTransportRaw.cpp
+++ b/libs/binder/RpcTransportRaw.cpp
@@ -204,9 +204,9 @@
     status_t interruptableReadFully(
             FdTrigger* fdTrigger, iovec* iovs, int niovs,
             const std::optional<android::base::function_ref<status_t()>>& altPoll,
-            bool enableAncillaryFds) override {
+            std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) override {
         auto recv = [&](iovec* iovs, int niovs) -> ssize_t {
-            if (enableAncillaryFds) {
+            if (ancillaryFds != nullptr) {
                 int fdBuffer[kMaxFdsPerMsg];
                 alignas(struct cmsghdr) char msgControlBuf[CMSG_SPACE(sizeof(fdBuffer))];
 
@@ -228,10 +228,12 @@
                         // NOTE: It is tempting to reinterpret_cast, but cmsg(3) explicitly asks
                         // application devs to memcpy the data to ensure memory alignment.
                         size_t dataLen = cmsg->cmsg_len - CMSG_LEN(0);
+                        LOG_ALWAYS_FATAL_IF(dataLen > sizeof(fdBuffer)); // sanity check
                         memcpy(fdBuffer, CMSG_DATA(cmsg), dataLen);
                         size_t fdCount = dataLen / sizeof(int);
+                        ancillaryFds->reserve(ancillaryFds->size() + fdCount);
                         for (size_t i = 0; i < fdCount; i++) {
-                            mFdsPendingRead.emplace_back(fdBuffer[i]);
+                            ancillaryFds->emplace_back(base::unique_fd(fdBuffer[i]));
                         }
                         break;
                     }
@@ -256,18 +258,8 @@
         return interruptableReadOrWrite(fdTrigger, iovs, niovs, recv, "recvmsg", POLLIN, altPoll);
     }
 
-    status_t consumePendingAncillaryData(std::vector<base::unique_fd>* fds) override {
-        fds->reserve(fds->size() + mFdsPendingRead.size());
-        for (auto& fd : mFdsPendingRead) {
-            fds->emplace_back(std::move(fd));
-        }
-        mFdsPendingRead.clear();
-        return OK;
-    }
-
 private:
     base::unique_fd mSocket;
-    std::vector<base::unique_fd> mFdsPendingRead;
 };
 
 // RpcTransportCtx with TLS disabled.
diff --git a/libs/binder/RpcTransportTls.cpp b/libs/binder/RpcTransportTls.cpp
index 7783111..09b5c17 100644
--- a/libs/binder/RpcTransportTls.cpp
+++ b/libs/binder/RpcTransportTls.cpp
@@ -288,12 +288,7 @@
     status_t interruptableReadFully(
             FdTrigger* fdTrigger, iovec* iovs, int niovs,
             const std::optional<android::base::function_ref<status_t()>>& altPoll,
-            bool enableAncillaryFds) override;
-
-    status_t consumePendingAncillaryData(std::vector<base::unique_fd>* fds) override {
-        (void)fds;
-        return OK;
-    }
+            std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) override;
 
 private:
     android::base::unique_fd mSocket;
@@ -368,8 +363,8 @@
 status_t RpcTransportTls::interruptableReadFully(
         FdTrigger* fdTrigger, iovec* iovs, int niovs,
         const std::optional<android::base::function_ref<status_t()>>& altPoll,
-        bool enableAncillaryFds) {
-    (void)enableAncillaryFds;
+        std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) {
+    (void)ancillaryFds;
 
     MAYBE_WAIT_IN_FLAKE_MODE;
 
diff --git a/libs/binder/RpcWireFormat.h b/libs/binder/RpcWireFormat.h
index 13989e5..ff1b01a 100644
--- a/libs/binder/RpcWireFormat.h
+++ b/libs/binder/RpcWireFormat.h
@@ -109,6 +109,10 @@
 
 // serialization is like:
 // |RpcWireHeader|struct desginated by 'command'| (over and over again)
+//
+// When file descriptors are included in out-of-band data (e.g. in unix domain
+// sockets), they are always paired with the RpcWireHeader bytes of the
+// transaction or reply the file descriptors belong to.
 
 struct RpcWireHeader {
     uint32_t command; // RPC_COMMAND_*
diff --git a/libs/binder/include/binder/Parcel.h b/libs/binder/include/binder/Parcel.h
index 68a4aef..32b0ded 100644
--- a/libs/binder/include/binder/Parcel.h
+++ b/libs/binder/include/binder/Parcel.h
@@ -609,10 +609,11 @@
     void ipcSetDataReference(const uint8_t* data, size_t dataSize, const binder_size_t* objects,
                              size_t objectsCount, release_func relFunc);
     // Takes ownership even when an error is returned.
-    status_t rpcSetDataReference(const sp<RpcSession>& session, const uint8_t* data,
-                                 size_t dataSize, const uint32_t* objectTable,
-                                 size_t objectTableSize, std::vector<base::unique_fd> ancillaryFds,
-                                 release_func relFunc);
+    status_t rpcSetDataReference(
+            const sp<RpcSession>& session, const uint8_t* data, size_t dataSize,
+            const uint32_t* objectTable, size_t objectTableSize,
+            std::vector<std::variant<base::unique_fd, base::borrowed_fd>>&& ancillaryFds,
+            release_func relFunc);
 
     status_t            finishWrite(size_t len);
     void                releaseObjects();
diff --git a/libs/binder/include/binder/RpcTransport.h b/libs/binder/include/binder/RpcTransport.h
index 80f5a32..5197ef9 100644
--- a/libs/binder/include/binder/RpcTransport.h
+++ b/libs/binder/include/binder/RpcTransport.h
@@ -63,12 +63,10 @@
      * to read/write data. If this returns an error, that error is returned from
      * this function.
      *
-     * ancillaryFds - FDs to be sent via UNIX domain dockets or Trusty IPC.
-     *
-     * enableAncillaryFds - Whether to check for FDs in the ancillary data and
-     * queue for them for use in `consumePendingAncillaryData`. If false and FDs
-     * are received, they will be silently dropped (and closed) by the operating
-     * system.
+     * ancillaryFds - FDs to be sent via UNIX domain sockets or Trusty IPC. When
+     * reading, if `ancillaryFds` is null, any received FDs will be silently
+     * dropped and closed (by the OS). Appended values will always be unique_fd;
+     * the variant type is used to avoid extra copies elsewhere.
      *
      * Return:
      *   OK - succeeded in completely processing 'size'
@@ -81,13 +79,7 @@
     [[nodiscard]] virtual status_t interruptableReadFully(
             FdTrigger *fdTrigger, iovec *iovs, int niovs,
             const std::optional<android::base::function_ref<status_t()>> &altPoll,
-            bool enableAncillaryFds) = 0;
-
-    // Consume the ancillary data that was accumulated from previous
-    // `interruptableReadFully` calls.
-    //
-    // Appends to `fds`.
-    virtual status_t consumePendingAncillaryData(std::vector<base::unique_fd> *fds) = 0;
+            std::vector<std::variant<base::unique_fd, base::borrowed_fd>> *ancillaryFds) = 0;
 
 protected:
     RpcTransport() = default;
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index 3b1fc82..c8b724b 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -2000,7 +2000,7 @@
             iovec readMessageIov{readMessage.data(), readMessage.size()};
             status_t readStatus =
                     mClientTransport->interruptableReadFully(mFdTrigger.get(), &readMessageIov, 1,
-                                                             std::nullopt, false);
+                                                             std::nullopt, nullptr);
             if (readStatus != OK) {
                 return AssertionFailure() << statusToString(readStatus);
             }