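binder: Add FD support to RPC Binder

Parcels sent over RPC Binder can now carry file descriptors. The object
table is filled from the Parcel's RPC fields instead of always being
empty, and FDs are passed as ancillary data on the transport when the
session's FileDescriptorTransportMode is UNIX. Outgoing transactions
and replies are validated first: FDs are rejected with FDS_NOT_ALLOWED
when the mode is NONE, at most 253 FDs may be sent per message over a
unix domain socket, and attached objects require a protocol version
with explicit parcel sizes. A reply that fails validation is replaced
with an empty Parcel and the error is forwarded to the caller.

Bug: 185909244
Test: TH
Change-Id: Ic4fc1b1edfe9d69984e785553cd1aaca97a07da3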
diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp
index 419df86..2b1d0e5 100644
--- a/libs/binder/RpcState.cpp
+++ b/libs/binder/RpcState.cpp
@@ -21,6 +21,7 @@
 #include <android-base/hex.h>
 #include <android-base/macros.h>
 #include <android-base/scopeguard.h>
+#include <android-base/stringprintf.h>
 #include <binder/BpBinder.h>
 #include <binder/IPCThreadState.h>
 #include <binder/RpcServer.h>
@@ -36,6 +37,7 @@
 namespace android {
 
 using base::ScopeGuard;
+using base::StringPrintf;
 
 #if RPC_FLAKE_PRONE
 void rpcMaybeWaitToFlake() {
@@ -50,6 +52,15 @@
 }
 #endif

+static bool enableAncillaryFds(RpcSession::FileDescriptorTransportMode mode) {
+    switch (mode) {  // whether transport reads (rpcRec) accept FDs sent as ancillary data
+        case RpcSession::FileDescriptorTransportMode::NONE:
+            return false;  // no FD transport configured
+        case RpcSession::FileDescriptorTransportMode::UNIX:
+            return true;  // FDs arrive as unix domain socket ancillary data
+    }  // no default case, so the compiler can flag unhandled modes
+}
+
 RpcState::RpcState() {}
 RpcState::~RpcState() {}

@@ -310,9 +321,11 @@
     mData.reset(new (std::nothrow) uint8_t[size]);
 }

-status_t RpcState::rpcSend(const sp<RpcSession::RpcConnection>& connection,
-                           const sp<RpcSession>& session, const char* what, iovec* iovs, int niovs,
-                           const std::optional<android::base::function_ref<status_t()>>& altPoll) {
+status_t RpcState::rpcSend(  // ancillaryFds (may be null) goes to the transport with iovs
+        const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
+        const char* what, iovec* iovs, int niovs,
+        const std::optional<android::base::function_ref<status_t()>>& altPoll,
+        const std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) {
     for (int i = 0; i < niovs; i++) {
         LOG_RPC_DETAIL("Sending %s (part %d of %d) on RpcTransport %p: %s",
                        what, i + 1, niovs, connection->rpcTransport.get(),
@@ -321,7 +334,8 @@

     if (status_t status =
                 connection->rpcTransport->interruptableWriteFully(session->mShutdownTrigger.get(),
-                                                                  iovs, niovs, altPoll);
+                                                                  iovs, niovs, altPoll,
+                                                                  ancillaryFds);  // may be null
         status != OK) {
         LOG_RPC_DETAIL("Failed to write %s (%d iovs) on RpcTransport %p, error: %s", what, niovs,
                        connection->rpcTransport.get(), statusToString(status).c_str());
@@ -334,9 +348,9 @@

 status_t RpcState::rpcRec(const sp<RpcSession::RpcConnection>& connection,
                           const sp<RpcSession>& session, const char* what, iovec* iovs, int niovs) {
-    if (status_t status =
-                connection->rpcTransport->interruptableReadFully(session->mShutdownTrigger.get(),
-                                                                 iovs, niovs, std::nullopt);
+    if (status_t status = connection->rpcTransport->interruptableReadFully(
+                session->mShutdownTrigger.get(), iovs, niovs, std::nullopt,
+                enableAncillaryFds(session->getFileDescriptorTransportMode()));  // accept FDs?
         status != OK) {
         LOG_RPC_DETAIL("Failed to read %s (%d iovs) on RpcTransport %p, error: %s", what, niovs,
                        connection->rpcTransport.get(), statusToString(status).c_str());
@@ -449,20 +463,12 @@
 status_t RpcState::transact(const sp<RpcSession::RpcConnection>& connection,
                             const sp<IBinder>& binder, uint32_t code, const Parcel& data,
                             const sp<RpcSession>& session, Parcel* reply, uint32_t flags) {
-    if (!data.isForRpc()) {
-        ALOGE("Refusing to send RPC with parcel not crafted for RPC call on binder %p code "
-              "%" PRIu32,
-              binder.get(), code);
-        return BAD_TYPE;
+    std::string errorMsg;  // set by validateParcel() when it fails
+    if (status_t status = validateParcel(session, data, &errorMsg); status != OK) {
+        ALOGE("Refusing to send RPC on binder %p code %" PRIu32 ": Parcel %p failed validation: %s",
+              binder.get(), code, &data, errorMsg.c_str());
+        return status;
     }
-
-    if (data.objectsCount() != 0) {
-        ALOGE("Parcel at %p has attached objects but is being used in an RPC call on binder %p "
-              "code %" PRIu32,
-              &data, binder.get(), code);
-        return BAD_TYPE;
-    }
-
     uint64_t address;
     if (status_t status = onBinderLeaving(session, binder, &address); status != OK) return status;

@@ -494,9 +500,11 @@
         }
     }

-    // objectTable always empty for now. Will be populated from `data` soon.
-    std::vector<uint32_t> objectTable;
-    Span<const uint32_t> objectTableSpan = {objectTable.data(), objectTable.size()};
+    auto* rpcFields = data.maybeRpcFields();
+    LOG_ALWAYS_FATAL_IF(rpcFields == nullptr);  // validateParcel() rejected non-RPC Parcels
+
+    Span<const uint32_t> objectTableSpan = Span<const uint32_t>{rpcFields->mObjectPositions.data(),
+                                                                rpcFields->mObjectPositions.size()};

     uint32_t bodySize;
     LOG_ALWAYS_FATAL_IF(__builtin_add_overflow(sizeof(RpcWireTransaction), data.dataSize(),
@@ -532,25 +540,25 @@
             {const_cast<uint8_t*>(data.data()), data.dataSize()},
             objectTableSpan.toIovec(),
     };
-    if (status_t status = rpcSend(connection, session, "transaction", iovs, arraysize(iovs),
-                                  [&] {
-                                      if (waitUs > kWaitLogUs) {
-                                          ALOGE("Cannot send command, trying to process pending "
-                                                "refcounts. Waiting %zuus. Too "
-                                                "many oneway calls?",
-                                                waitUs);
-                                      }
+    if (status_t status = rpcSend(
+                connection, session, "transaction", iovs, arraysize(iovs),
+                [&] {
+                    if (waitUs > kWaitLogUs) {
+                        ALOGE("Cannot send command, trying to process pending refcounts. Waiting "
+                              "%zuus. Too many oneway calls?",
+                              waitUs);
+                    }

-                                      if (waitUs > 0) {
-                                          usleep(waitUs);
-                                          waitUs = std::min(kWaitMaxUs, waitUs * 2);
-                                      } else {
-                                          waitUs = 1;
-                                      }
+                    if (waitUs > 0) {
+                        usleep(waitUs);
+                        waitUs = std::min(kWaitMaxUs, waitUs * 2);
+                    } else {
+                        waitUs = 1;
+                    }

-                                      return drainCommands(connection, session,
-                                                           CommandType::CONTROL_ONLY);
-                                  });
+                    return drainCommands(connection, session, CommandType::CONTROL_ONLY);
+                },
+                rpcFields->mFds.get());  // FDs (may be null) sent with the transaction
         status != OK) {
         // TODO(b/167966510): need to undo onBinderLeaving - we know the
         // refcount isn't successfully transferred.
@@ -617,18 +625,37 @@
     if (status_t status = rpcRec(connection, session, "reply body", iovs, arraysize(iovs));
         status != OK)
         return status;
+
+    // Drain FDs received as ancillary data first, so they are closed even on an error reply.
+    std::vector<base::unique_fd> pendingFds;
+    if (status_t status = connection->rpcTransport->consumePendingAncillaryData(&pendingFds);
+        status != OK) {
+        return status;
+    }
+
     if (rpcReply.status != OK) return rpcReply.status;

     Span<const uint8_t> parcelSpan = {data.data(), data.size()};
+    Span<const uint32_t> objectTableSpan;  // empty for protocols without explicit parcel size
     if (session->getProtocolVersion().value() >=
         RPC_WIRE_PROTOCOL_VERSION_RPC_HEADER_FEATURE_EXPLICIT_PARCEL_SIZE) {
         Span<const uint8_t> objectTableBytes = parcelSpan.splitOff(rpcReply.parcelDataSize);
-        LOG_ALWAYS_FATAL_IF(objectTableBytes.size > 0, "Non-empty object table not supported yet.");
+        std::optional<Span<const uint32_t>> maybeSpan =
+                objectTableBytes.reinterpret<const uint32_t>();
+        if (!maybeSpan.has_value()) {
+            ALOGE("Bad object table size inferred from RpcWireReply. Saw bodySize=%" PRId32
+                  " sizeofHeader=%zu parcelSize=%" PRId32 " objectTableBytesSize=%zu. Terminating!",
+                  command.bodySize, rpcReplyWireSize, rpcReply.parcelDataSize,
+                  objectTableBytes.size);
+            return BAD_VALUE;
+        }
+        objectTableSpan = *maybeSpan;
     }

     data.release();
-    reply->rpcSetDataReference(session, parcelSpan.data, parcelSpan.size, cleanup_reply_data);
-    return OK;
+    return reply->rpcSetDataReference(session, parcelSpan.data, parcelSpan.size,
+                                      objectTableSpan.data, objectTableSpan.size,
+                                      std::move(pendingFds), cleanup_reply_data);  // FDs move into *reply
 }

 status_t RpcState::sendDecStrongToTarget(const sp<RpcSession::RpcConnection>& connection,
@@ -842,14 +869,31 @@
     reply.markForRpc(session);

     if (replyStatus == OK) {
+        // Drain FDs received as ancillary data with this transaction; they are handed to `data`.
+        std::vector<base::unique_fd> pendingFds;
+        if (status_t status = connection->rpcTransport->consumePendingAncillaryData(&pendingFds);
+            status != OK) {
+            return status;
+        }
+
         Span<const uint8_t> parcelSpan = {transaction->data,
                                           transactionData.size() -
                                                   offsetof(RpcWireTransaction, data)};
+        Span<const uint32_t> objectTableSpan;  // set below only for explicit-parcel-size protocols
         if (session->getProtocolVersion().value() >=
             RPC_WIRE_PROTOCOL_VERSION_RPC_HEADER_FEATURE_EXPLICIT_PARCEL_SIZE) {
             Span<const uint8_t> objectTableBytes = parcelSpan.splitOff(transaction->parcelDataSize);
-            LOG_ALWAYS_FATAL_IF(objectTableBytes.size > 0,
-                                "Non-empty object table not supported yet.");
+            std::optional<Span<const uint32_t>> maybeSpan =
+                    objectTableBytes.reinterpret<const uint32_t>();
+            if (!maybeSpan.has_value()) {
+                ALOGE("Bad object table size inferred from RpcWireTransaction. Saw bodySize=%zu "
+                      "sizeofHeader=%zu parcelSize=%" PRId32
+                      " objectTableBytesSize=%zu. Terminating!",
+                      transactionData.size(), sizeof(RpcWireTransaction),
+                      transaction->parcelDataSize, objectTableBytes.size);
+                return BAD_VALUE;
+            }
+            objectTableSpan = *maybeSpan;
         }

         Parcel data;
@@ -857,47 +901,50 @@
         // only holds onto it for the duration of this function call. Parcel will be
         // deleted before the 'transactionData' object.

-        data.rpcSetDataReference(session, parcelSpan.data, parcelSpan.size,
-                                 do_nothing_to_transact_data);
+        replyStatus = data.rpcSetDataReference(session, parcelSpan.data, parcelSpan.size,
+                                               objectTableSpan.data, objectTableSpan.size,
+                                               std::move(pendingFds), do_nothing_to_transact_data);

-        if (target) {
-            bool origAllowNested = connection->allowNested;
-            connection->allowNested = !oneway;
+        if (replyStatus == OK) {  // dispatch only if rpcSetDataReference() succeeded
+            if (target) {
+                bool origAllowNested = connection->allowNested;
+                connection->allowNested = !oneway;  // nesting allowed only for two-way calls

-            replyStatus = target->transact(transaction->code, data, &reply, transaction->flags);
+                replyStatus = target->transact(transaction->code, data, &reply, transaction->flags);

-            connection->allowNested = origAllowNested;
-        } else {
-            LOG_RPC_DETAIL("Got special transaction %u", transaction->code);
+                connection->allowNested = origAllowNested;  // restore previous setting
+            } else {
+                LOG_RPC_DETAIL("Got special transaction %u", transaction->code);

-            switch (transaction->code) {
-                case RPC_SPECIAL_TRANSACT_GET_MAX_THREADS: {
-                    replyStatus = reply.writeInt32(session->getMaxIncomingThreads());
-                    break;
-                }
-                case RPC_SPECIAL_TRANSACT_GET_SESSION_ID: {
-                    // for client connections, this should always report the value
-                    // originally returned from the server, so this is asserting
-                    // that it exists
-                    replyStatus = reply.writeByteVector(session->mId);
-                    break;
-                }
-                default: {
-                    sp<RpcServer> server = session->server();
-                    if (server) {
-                        switch (transaction->code) {
-                            case RPC_SPECIAL_TRANSACT_GET_ROOT: {
-                                sp<IBinder> root = session->mSessionSpecificRootObject
-                                        ?: server->getRootObject();
-                                replyStatus = reply.writeStrongBinder(root);
-                                break;
+                switch (transaction->code) {
+                    case RPC_SPECIAL_TRANSACT_GET_MAX_THREADS: {
+                        replyStatus = reply.writeInt32(session->getMaxIncomingThreads());
+                        break;
+                    }
+                    case RPC_SPECIAL_TRANSACT_GET_SESSION_ID: {
+                        // for client connections, this should always report the value
+                        // originally returned from the server, so this is asserting
+                        // that it exists
+                        replyStatus = reply.writeByteVector(session->mId);
+                        break;
+                    }
+                    default: {
+                        sp<RpcServer> server = session->server();
+                        if (server) {
+                            switch (transaction->code) {
+                                case RPC_SPECIAL_TRANSACT_GET_ROOT: {
+                                    sp<IBinder> root = session->mSessionSpecificRootObject
+                                            ?: server->getRootObject();
+                                    replyStatus = reply.writeStrongBinder(root);
+                                    break;
+                                }
+                                default: {
+                                    replyStatus = UNKNOWN_TRANSACTION;
+                                }
                             }
-                            default: {
-                                replyStatus = UNKNOWN_TRANSACTION;
-                            }
+                        } else {
+                            ALOGE("Special command sent, but no server object attached.");
                         }
-                    } else {
-                        ALOGE("Special command sent, but no server object attached.");
                     }
                 }
             }
@@ -969,11 +1016,22 @@
         replyStatus = flushExcessBinderRefs(session, addr, target);
     }

+    std::string errorMsg;  // set by validateParcel() when it fails
+    if (status_t status = validateParcel(session, reply, &errorMsg); status != OK) {
+        ALOGE("Reply Parcel failed validation: %s", errorMsg.c_str());
+        // Forward the error to the caller: send an empty RPC Parcel with the failure status.
+        reply.freeData();
+        reply.markForRpc(session);
+        replyStatus = status;
+    }
+
+    auto* rpcFields = reply.maybeRpcFields();
+    LOG_ALWAYS_FATAL_IF(rpcFields == nullptr);  // reply is valid or was re-marked for RPC above
+
     const size_t rpcReplyWireSize = RpcWireReply::wireSize(session->getProtocolVersion().value());

-    // objectTable always empty for now. Will be populated from `reply` soon.
-    std::vector<uint32_t> objectTable;
-    Span<const uint32_t> objectTableSpan = {objectTable.data(), objectTable.size()};
+    Span<const uint32_t> objectTableSpan = Span<const uint32_t>{rpcFields->mObjectPositions.data(),
+                                                                rpcFields->mObjectPositions.size()};

     uint32_t bodySize;
     LOG_ALWAYS_FATAL_IF(__builtin_add_overflow(rpcReplyWireSize, reply.dataSize(), &bodySize) ||
@@ -998,7 +1056,8 @@
             {const_cast<uint8_t*>(reply.data()), reply.dataSize()},
             objectTableSpan.toIovec(),
     };
-    return rpcSend(connection, session, "reply", iovs, arraysize(iovs), std::nullopt);
+    return rpcSend(connection, session, "reply", iovs, arraysize(iovs), std::nullopt,
+                   rpcFields->mFds.get());  // reply FDs (may be null)
 }

 status_t RpcState::processDecStrong(const sp<RpcSession::RpcConnection>& connection,
@@ -1055,6 +1114,50 @@
     return OK;
 }
 
+status_t RpcState::validateParcel(const sp<RpcSession>& session, const Parcel& parcel,
+                                  std::string* errorMsg) {  // on failure, *errorMsg says why
+    auto* rpcFields = parcel.maybeRpcFields();
+    if (rpcFields == nullptr) {  // not marked for RPC
+        *errorMsg = "Parcel not crafted for RPC call";
+        return BAD_TYPE;
+    }
+
+    if (rpcFields->mSession != session) {  // Parcel was marked for another session
+        *errorMsg = "Parcel's session doesn't match";
+        return BAD_TYPE;
+    }
+
+    uint32_t protocolVersion = session->getProtocolVersion().value();
+    if (protocolVersion < RPC_WIRE_PROTOCOL_VERSION_RPC_HEADER_FEATURE_EXPLICIT_PARCEL_SIZE &&
+        !rpcFields->mObjectPositions.empty()) {  // older wire formats carry no object table
+        *errorMsg = StringPrintf("Parcel has attached objects but the session's protocol version "
+                                 "(%" PRIu32 ") is too old, must be at least %" PRIu32,
+                                 protocolVersion,
+                                 RPC_WIRE_PROTOCOL_VERSION_RPC_HEADER_FEATURE_EXPLICIT_PARCEL_SIZE);
+        return BAD_VALUE;
+    }
+
+    if (rpcFields->mFds && !rpcFields->mFds->empty()) {  // mFds may be null
+        switch (session->getFileDescriptorTransportMode()) {
+            case RpcSession::FileDescriptorTransportMode::NONE:
+                *errorMsg =
+                        "Parcel has file descriptors, but no file descriptor transport is enabled";
+                return FDS_NOT_ALLOWED;
+            case RpcSession::FileDescriptorTransportMode::UNIX: {
+                constexpr size_t kMaxFdsPerMsg = 253;  // Linux SCM_MAX_FD; see unix(7)
+                if (rpcFields->mFds->size() > kMaxFdsPerMsg) {
+                    *errorMsg = StringPrintf("Too many file descriptors in Parcel for unix "
+                                             "domain socket: %zu (max is %zu)",
+                                             rpcFields->mFds->size(), kMaxFdsPerMsg);
+                    return BAD_VALUE;
+                }
+            }
+        }  // no default case, so the compiler can flag unhandled modes
+    }
+
+    return OK;
+}
+
 sp<IBinder> RpcState::tryEraseNode(std::map<uint64_t, BinderNode>::iterator& it) {
     sp<IBinder> ref;