binder: Add FD support to RPC Binder

Bug: 185909244
Test: TH
Change-Id: Ic4fc1b1edfe9d69984e785553cd1aaca97a07da3
diff --git a/libs/binder/Parcel.cpp b/libs/binder/Parcel.cpp
index e67dd7b..8739105 100644
--- a/libs/binder/Parcel.cpp
+++ b/libs/binder/Parcel.cpp
@@ -40,6 +40,7 @@
 #include <binder/Status.h>
 #include <binder/TextOutput.h>
 
+#include <android-base/scopeguard.h>
 #include <cutils/ashmem.h>
 #include <cutils/compiler.h>
 #include <utils/Flattenable.h>
@@ -152,6 +153,10 @@
     ALOGE("Invalid object type 0x%08x", obj.hdr.type);
 }
 
+static int toRawFd(const std::variant<base::unique_fd, base::borrowed_fd>& v) {
+    return std::visit([](const auto& fd) { return fd.get(); }, v);  // active alternative's get()
+}
+
 Parcel::RpcFields::RpcFields(const sp<RpcSession>& session) : mSession(session) {
     LOG_ALWAYS_FATAL_IF(mSession == nullptr);
 }
@@ -530,6 +535,76 @@
                 }
             }
         }
+    } else {
+        // RPC parcels: record the object positions that fall inside the
+        // appended range and give this parcel its own duplicate of every FD
+        // those objects reference.
+        auto* rpcFields = maybeRpcFields();
+        LOG_ALWAYS_FATAL_IF(rpcFields == nullptr);
+        auto* otherRpcFields = parcel->maybeRpcFields();
+        if (otherRpcFields == nullptr) {
+            return BAD_TYPE;
+        }
+        if (rpcFields->mSession != otherRpcFields->mSession) {
+            return BAD_TYPE;
+        }
+
+        const size_t savedDataPos = mDataPos;
+        base::ScopeGuard scopeGuard = [&]() { mDataPos = savedDataPos; };
+
+        rpcFields->mObjectPositions.reserve(otherRpcFields->mObjectPositions.size());
+        if (otherRpcFields->mFds != nullptr) {
+            if (rpcFields->mFds == nullptr) {
+                rpcFields->mFds = std::make_unique<decltype(rpcFields->mFds)::element_type>();
+            }
+            rpcFields->mFds->reserve(otherRpcFields->mFds->size());
+        }
+        for (size_t i = 0; i < otherRpcFields->mObjectPositions.size(); i++) {
+            const binder_size_t objPos = otherRpcFields->mObjectPositions[i];
+            if (offset <= objPos && objPos < offset + len) {
+                size_t newDataPos = objPos - offset + startPos;
+                rpcFields->mObjectPositions.push_back(newDataPos);
+
+                mDataPos = newDataPos;
+                int32_t objectType;
+                if (status_t status = readInt32(&objectType); status != OK) {
+                    return status;
+                }
+                if (objectType != RpcFields::TYPE_NATIVE_FILE_DESCRIPTOR) {
+                    continue;
+                }
+
+                if (!mAllowFds) {
+                    return FDS_NOT_ALLOWED;
+                }
+
+                // Read FD, duplicate, and add to list.
+                int32_t fdIndex;
+                if (status_t status = readInt32(&fdIndex); status != OK) {
+                    return status;
+                }
+                // The source may list FD objects without holding a matching FD
+                // list entry (readFileDescriptor() guards against the same
+                // thing), so validate instead of dereferencing blindly.
+                if (otherRpcFields->mFds == nullptr || fdIndex < 0 ||
+                    static_cast<size_t>(fdIndex) >= otherRpcFields->mFds->size()) {
+                    ALOGE("RPC Parcel contains invalid file descriptor index. index=%d "
+                          "fd_count=%zu",
+                          fdIndex, otherRpcFields->mFds ? otherRpcFields->mFds->size() : 0);
+                    return BAD_VALUE;
+                }
+                const auto& oldFd = otherRpcFields->mFds->at(fdIndex);
+                // To match kernel binder behavior, we always dup, even if the
+                // FD was unowned in the source parcel.
+                rpcFields->mFds->emplace_back(
+                        base::unique_fd(fcntl(toRawFd(oldFd), F_DUPFD_CLOEXEC, 0)));
+                // Fixup the index in the data.
+                mDataPos = newDataPos + 4;
+                if (status_t status = writeInt32(rpcFields->mFds->size() - 1); status != OK) {
+                    return status;
+                }
+            }
+        }
     }
 
     return err;
@@ -584,7 +646,7 @@
 bool Parcel::hasFileDescriptors() const
 {
     if (const auto* rpcFields = maybeRpcFields()) {
-        return false;
+        return rpcFields->mFds != nullptr && !rpcFields->mFds->empty();
     }
     auto* kernelFields = maybeKernelFields();
     if (!kernelFields->mFdsKnown) {
@@ -621,34 +683,31 @@
 std::vector<int> Parcel::debugReadAllFileDescriptors() const {
     std::vector<int> ret;
 
-    const auto* kernelFields = maybeKernelFields();
-    if (kernelFields == nullptr) {
-        return ret;
+    if (const auto* kernelFields = maybeKernelFields()) {
+        size_t initPosition = dataPosition();  // put back once the scan is done
+        for (size_t i = 0; i < kernelFields->mObjectsSize; i++) {
+            binder_size_t offset = kernelFields->mObjects[i];
+            const flat_binder_object* flat =
+                    reinterpret_cast<const flat_binder_object*>(mData + offset);
+            if (flat->hdr.type != BINDER_TYPE_FD) continue;  // only FD objects are collected
+
+            setDataPosition(offset);
+
+            int fd = readFileDescriptor();
+            LOG_ALWAYS_FATAL_IF(fd == -1);
+            ret.push_back(fd);
+        }
+        setDataPosition(initPosition);
+    } else if (const auto* rpcFields = maybeRpcFields(); rpcFields && rpcFields->mFds) {
+        for (const auto& fd : *rpcFields->mFds) {
+            ret.push_back(toRawFd(fd));  // every FD-list entry, in list order
+        }
     }
 
-    size_t initPosition = dataPosition();
-    for (size_t i = 0; i < kernelFields->mObjectsSize; i++) {
-        binder_size_t offset = kernelFields->mObjects[i];
-        const flat_binder_object* flat =
-                reinterpret_cast<const flat_binder_object*>(mData + offset);
-        if (flat->hdr.type != BINDER_TYPE_FD) continue;
-
-        setDataPosition(offset);
-
-        int fd = readFileDescriptor();
-        LOG_ALWAYS_FATAL_IF(fd == -1);
-        ret.push_back(fd);
-    }
-
-    setDataPosition(initPosition);
     return ret;
 }
 
 status_t Parcel::hasFileDescriptorsInRange(size_t offset, size_t len, bool* result) const {
-    const auto* kernelFields = maybeKernelFields();
-    if (kernelFields == nullptr) {
-        return BAD_TYPE;
-    }
     if (len > INT32_MAX || offset > INT32_MAX) {
         // Don't accept size_t values which may have come from an inadvertent conversion from a
         // negative int.
@@ -659,20 +718,33 @@
         return BAD_VALUE;
     }
     *result = false;
-    for (size_t i = 0; i < kernelFields->mObjectsSize; i++) {
-        size_t pos = kernelFields->mObjects[i];
-        if (pos < offset) continue;
-        if (pos + sizeof(flat_binder_object) > offset + len) {
-            if (kernelFields->mObjectsSorted) {
+    if (const auto* kernelFields = maybeKernelFields()) {
+        for (size_t i = 0; i < kernelFields->mObjectsSize; i++) {
+            size_t pos = kernelFields->mObjects[i];
+            if (pos < offset) continue;
+            if (pos + sizeof(flat_binder_object) > offset + len) {
+                if (kernelFields->mObjectsSorted) {
+                    break;
+                } else {
+                    continue;
+                }
+            }
+            const flat_binder_object* flat =
+                    reinterpret_cast<const flat_binder_object*>(mData + pos);
+            if (flat->hdr.type == BINDER_TYPE_FD) {
+                *result = true;
                 break;
-            } else {
-                continue;
             }
         }
-        const flat_binder_object* flat = reinterpret_cast<const flat_binder_object*>(mData + pos);
-        if (flat->hdr.type == BINDER_TYPE_FD) {
-            *result = true;
-            break;
+    } else if (const auto* rpcFields = maybeRpcFields()) {
+        for (uint32_t pos : rpcFields->mObjectPositions) {
+            if (offset <= pos && pos < limit) {
+                const auto* type = reinterpret_cast<const RpcFields::ObjectType*>(mData + pos);
+                if (*type == RpcFields::TYPE_NATIVE_FILE_DESCRIPTOR) {
+                    *result = true;
+                    break;
+                }
+            }
         }
     }
     return NO_ERROR;
@@ -1293,11 +1365,40 @@
     return err;
 }
 
-status_t Parcel::writeFileDescriptor(int fd, bool takeOwnership)
-{
-    if (isForRpc()) {
-        ALOGE("Cannot write file descriptor to remote binder.");
-        return BAD_TYPE;
+status_t Parcel::writeFileDescriptor(int fd, bool takeOwnership) {
+    if (auto* rpcFields = maybeRpcFields()) {
+        std::variant<base::unique_fd, base::borrowed_fd> fdVariant;
+        if (takeOwnership) {
+            fdVariant = base::unique_fd(fd);
+        } else {
+            fdVariant = base::borrowed_fd(fd);
+        }
+        if (!mAllowFds) {
+            return FDS_NOT_ALLOWED;
+        }
+        switch (rpcFields->mSession->getFileDescriptorTransportMode()) {
+            case RpcSession::FileDescriptorTransportMode::NONE: {
+                return FDS_NOT_ALLOWED;
+            }
+            case RpcSession::FileDescriptorTransportMode::UNIX: {
+                if (rpcFields->mFds == nullptr) {
+                    rpcFields->mFds = std::make_unique<decltype(rpcFields->mFds)::element_type>();
+                }
+                size_t dataPos = mDataPos;
+                if (dataPos > UINT32_MAX) {
+                    return NO_MEMORY;
+                }
+                if (status_t err = writeInt32(RpcFields::TYPE_NATIVE_FILE_DESCRIPTOR); err != OK) {
+                    return err;
+                }
+                if (status_t err = writeInt32(rpcFields->mFds->size()); err != OK) {
+                    return err;
+                }
+                rpcFields->mObjectPositions.push_back(dataPos);
+                rpcFields->mFds->push_back(std::move(fdVariant));
+                return OK;
+            }
+        }
     }
 
     flat_binder_object obj;
@@ -2038,8 +2139,31 @@
     return h;
 }
 
-int Parcel::readFileDescriptor() const
-{
+int Parcel::readFileDescriptor() const {
+    if (const auto* rpcFields = maybeRpcFields()) {
+        if (!std::binary_search(rpcFields->mObjectPositions.begin(),
+                                rpcFields->mObjectPositions.end(), mDataPos)) {
+            ALOGW("Attempt to read file descriptor from Parcel %p at offset %zu that is not in the "
+                  "object list",
+                  this, mDataPos);
+            return BAD_TYPE;
+        }
+
+        int32_t objectType = readInt32();
+        if (objectType != RpcFields::TYPE_NATIVE_FILE_DESCRIPTOR) {
+            return BAD_TYPE;
+        }
+
+        int32_t fdIndex = readInt32();
+        if (rpcFields->mFds == nullptr || fdIndex < 0 ||
+            static_cast<size_t>(fdIndex) >= rpcFields->mFds->size()) {
+            ALOGE("RPC Parcel contains invalid file descriptor index. index=%d fd_count=%zu",
+                  fdIndex, rpcFields->mFds ? rpcFields->mFds->size() : 0);
+            return BAD_VALUE;
+        }
+        return toRawFd(rpcFields->mFds->at(fdIndex));
+    }
+
     const flat_binder_object* flat = readObject(true);
 
     if (flat && flat->hdr.type == BINDER_TYPE_FD) {
@@ -2049,8 +2173,7 @@
     return BAD_TYPE;
 }
 
-int Parcel::readParcelFileDescriptor() const
-{
+int Parcel::readParcelFileDescriptor() const {
     int32_t hasComm = readInt32();
     int fd = readFileDescriptor();
     if (hasComm != 0) {
@@ -2270,22 +2393,22 @@
 }
 
 void Parcel::closeFileDescriptors() {
-    auto* kernelFields = maybeKernelFields();
-    if (kernelFields == nullptr) {
-        return;
-    }
-    size_t i = kernelFields->mObjectsSize;
-    if (i > 0) {
-        //ALOGI("Closing file descriptors for %zu objects...", i);
-    }
-    while (i > 0) {
-        i--;
-        const flat_binder_object* flat =
-                reinterpret_cast<flat_binder_object*>(mData + kernelFields->mObjects[i]);
-        if (flat->hdr.type == BINDER_TYPE_FD) {
-            //ALOGI("Closing fd: %ld", flat->handle);
-            close(flat->handle);
+    if (auto* kernelFields = maybeKernelFields()) {
+        size_t i = kernelFields->mObjectsSize;
+        if (i > 0) {
+            // ALOGI("Closing file descriptors for %zu objects...", i);
         }
+        while (i > 0) {  // walk the object table from the back
+            i--;
+            const flat_binder_object* flat =
+                    reinterpret_cast<flat_binder_object*>(mData + kernelFields->mObjects[i]);
+            if (flat->hdr.type == BINDER_TYPE_FD) {
+                // ALOGI("Closing fd: %ld", flat->handle);
+                close(flat->handle);  // for FD objects, handle holds the descriptor
+            }
+        }
+    } else if (auto* rpcFields = maybeRpcFields()) {
+        rpcFields->mFds.reset();  // drops the FD list; mObjectPositions is left as-is
     }
 }
 
@@ -2363,8 +2486,11 @@
     scanForFds();
 }
 
-void Parcel::rpcSetDataReference(const sp<RpcSession>& session, const uint8_t* data,
-                                 size_t dataSize, release_func relFunc) {
+status_t Parcel::rpcSetDataReference(const sp<RpcSession>& session, const uint8_t* data,
+                                     size_t dataSize, const uint32_t* objectTable,
+                                     size_t objectTableSize,
+                                     std::vector<base::unique_fd> ancillaryFds,
+                                     release_func relFunc) {
     // this code uses 'mOwner == nullptr' to understand whether it owns memory
     LOG_ALWAYS_FATAL_IF(relFunc == nullptr, "must provide cleanup function");
 
@@ -2373,9 +2499,44 @@
     freeData();
     markForRpc(session);
 
+    auto* rpcFields = maybeRpcFields();
+    LOG_ALWAYS_FATAL_IF(rpcFields == nullptr); // guaranteed by markForRpc.
+
     mData = const_cast<uint8_t*>(data);
     mDataSize = mDataCapacity = dataSize;
     mOwner = relFunc;
+
+    if (objectTableSize != ancillaryFds.size()) {
+        ALOGE("objectTableSize=%zu ancillaryFds.size=%zu", objectTableSize, ancillaryFds.size());
+        freeData(); // don't leak mData
+        return BAD_VALUE;
+    }
+
+    // Object positions are later dereferenced as-is (for example by
+    // hasFileDescriptorsInRange()), so reject any entry whose type tag would
+    // not fit inside the data buffer.
+    for (size_t i = 0; i < objectTableSize; i++) {
+        if (dataSize < sizeof(RpcFields::ObjectType) ||
+            objectTable[i] > dataSize - sizeof(RpcFields::ObjectType)) {
+            ALOGE("object position %zu is out of range (parcel size is %zu)",
+                  static_cast<size_t>(objectTable[i]), dataSize);
+            freeData(); // don't leak mData
+            return BAD_VALUE;
+        }
+    }
+
+    rpcFields->mObjectPositions.reserve(objectTableSize);
+    for (size_t i = 0; i < objectTableSize; i++) {
+        rpcFields->mObjectPositions.push_back(objectTable[i]);
+    }
+    if (!ancillaryFds.empty()) {
+        rpcFields->mFds = std::make_unique<decltype(rpcFields->mFds)::element_type>();
+        rpcFields->mFds->reserve(ancillaryFds.size());
+        for (auto& fd : ancillaryFds) {
+            rpcFields->mFds->push_back(std::move(fd));
+        }
+    }
+
+    return OK;
 }
 
 void Parcel::print(TextOutput& to, uint32_t /*flags*/) const
@@ -2558,6 +2707,9 @@
         kernelFields->mObjectsSorted = false;
         kernelFields->mHasFds = false;
         kernelFields->mFdsKnown = true;
+    } else if (auto* rpcFields = maybeRpcFields()) {
+        rpcFields->mObjectPositions.clear();
+        rpcFields->mFds.reset();
     }
     mAllowFds = true;
 
@@ -2573,17 +2725,26 @@
     }
 
     auto* kernelFields = maybeKernelFields();
+    auto* rpcFields = maybeRpcFields();
 
     // If shrinking, first adjust for any objects that appear
     // after the new data size.
-    size_t objectsSize = kernelFields ? kernelFields->mObjectsSize : 0;
-    if (kernelFields && desired < mDataSize) {
+    size_t objectsSize =
+            kernelFields ? kernelFields->mObjectsSize : rpcFields->mObjectPositions.size();
+    if (desired < mDataSize) {
         if (desired == 0) {
             objectsSize = 0;
         } else {
-            while (objectsSize > 0) {
-                if (kernelFields->mObjects[objectsSize - 1] < desired) break;
-                objectsSize--;
+            if (kernelFields) {
+                while (objectsSize > 0) {
+                    if (kernelFields->mObjects[objectsSize - 1] < desired) break;
+                    objectsSize--;
+                }
+            } else {
+                while (objectsSize > 0) {
+                    if (rpcFields->mObjectPositions[objectsSize - 1] < desired) break;
+                    objectsSize--;
+                }
             }
         }
     }
@@ -2604,7 +2765,7 @@
         }
         binder_size_t* objects = nullptr;
 
-        if (objectsSize) {
+        if (kernelFields && objectsSize) {
             objects = (binder_size_t*)calloc(objectsSize, sizeof(binder_size_t));
             if (!objects) {
                 free(data);
@@ -2620,6 +2781,12 @@
             acquireObjects();
             kernelFields->mObjectsSize = oldObjectsSize;
         }
+        if (rpcFields) {
+            if (status_t status = truncateRpcObjects(objectsSize); status != OK) {
+                free(data);
+                return status;
+            }
+        }
 
         if (mData) {
             memcpy(data, mData, mDataSize < desired ? mDataSize : desired);
@@ -2678,6 +2845,11 @@
             kernelFields->mNextObjectHint = 0;
             kernelFields->mObjectsSorted = false;
         }
+        if (rpcFields) {
+            if (status_t status = truncateRpcObjects(objectsSize); status != OK) {
+                return status;
+            }
+        }
 
         // We own the data, so we can just do a realloc().
         if (desired > mDataCapacity) {
@@ -2734,6 +2906,52 @@
     return NO_ERROR;
 }
 
+// Drops entries from the back of the RPC object table until only
+// newObjectsSize remain. For each dropped FD object, the FD list entry it
+// indexes is erased as well. Returns BAD_VALUE if a dropped object does not fit
+// inside the parcel data or carries an FD index outside the FD list.
+status_t Parcel::truncateRpcObjects(size_t newObjectsSize) {
+    auto* rpcFields = maybeRpcFields();
+    if (newObjectsSize == 0) {
+        rpcFields->mObjectPositions.clear();
+        if (rpcFields->mFds) {
+            rpcFields->mFds->clear();
+        }
+        return OK;
+    }
+    while (rpcFields->mObjectPositions.size() > newObjectsSize) {
+        const size_t pos = rpcFields->mObjectPositions.back();
+        // Check that the type tag lies inside the buffer before reading it; the
+        // object table may have been supplied through rpcSetDataReference().
+        if (mDataSize < sizeof(RpcFields::ObjectType) ||
+            pos > mDataSize - sizeof(RpcFields::ObjectType)) {
+            ALOGE("RPC Parcel object position out of range. pos=%zu data_size=%zu", pos,
+                  mDataSize);
+            return BAD_VALUE;
+        }
+        rpcFields->mObjectPositions.pop_back();
+        const auto type = *reinterpret_cast<const RpcFields::ObjectType*>(mData + pos);
+        if (type == RpcFields::TYPE_NATIVE_FILE_DESCRIPTOR) {
+            // An FD object is the type tag followed by an int32_t index.
+            const size_t fdIndexPos = pos + sizeof(RpcFields::ObjectType);
+            if (mDataSize - fdIndexPos < sizeof(int32_t)) {
+                ALOGE("RPC Parcel FD object is truncated. pos=%zu data_size=%zu", pos, mDataSize);
+                return BAD_VALUE;
+            }
+            const auto fdIndex = *reinterpret_cast<const int32_t*>(mData + fdIndexPos);
+            if (rpcFields->mFds == nullptr || fdIndex < 0 ||
+                static_cast<size_t>(fdIndex) >= rpcFields->mFds->size()) {
+                ALOGE("RPC Parcel contains invalid file descriptor index. index=%d fd_count=%zu",
+                      fdIndex, rpcFields->mFds ? rpcFields->mFds->size() : 0);
+                return BAD_VALUE;
+            }
+            // In practice, this always removes the last element.
+            rpcFields->mFds->erase(rpcFields->mFds->begin() + fdIndex);
+        }
+    }
+    return OK;
+}
+
 void Parcel::initState()
 {
     LOG_ALLOC("Parcel %p: initState", this);