libbinder: transaction includes refcount to binder

Prevents the case where one thread is making a transaction on a binder
while another thread clears the ref to that binder (this is mainly a
problem with oneway transactions). The binder driver also does this
implicitly, but it was missing from the RPC binder implementation.

Bug: 183140903
Test: binderRpcTest
Change-Id: I4f59ad6094f90e5c95af5febea2780bed29d4c88
diff --git a/libs/binder/BpBinder.cpp b/libs/binder/BpBinder.cpp
index 1dcb94c..5e44a0f 100644
--- a/libs/binder/BpBinder.cpp
+++ b/libs/binder/BpBinder.cpp
@@ -273,7 +273,8 @@
 
         status_t status;
         if (CC_UNLIKELY(isRpcBinder())) {
-            status = rpcSession()->transact(rpcAddress(), code, data, reply, flags);
+            status = rpcSession()->transact(sp<IBinder>::fromExisting(this), code, data, reply,
+                                            flags);
         } else {
             status = IPCThreadState::self()->transact(binderHandle(), code, data, reply, flags);
         }
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index 7c458c1..d5ce324 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -100,12 +100,12 @@
     return state()->getMaxThreads(connection.fd(), sp<RpcSession>::fromExisting(this), maxThreads);
 }
 
-status_t RpcSession::transact(const RpcAddress& address, uint32_t code, const Parcel& data,
+status_t RpcSession::transact(const sp<IBinder>& binder, uint32_t code, const Parcel& data,
                               Parcel* reply, uint32_t flags) {
     ExclusiveConnection connection(sp<RpcSession>::fromExisting(this),
                                    (flags & IBinder::FLAG_ONEWAY) ? ConnectionUse::CLIENT_ASYNC
                                                                   : ConnectionUse::CLIENT);
-    return state()->transact(connection.fd(), address, code, data,
+    return state()->transact(connection.fd(), binder, code, data,
                              sp<RpcSession>::fromExisting(this), reply, flags);
 }
 
diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp
index 6483486..d40fef6 100644
--- a/libs/binder/RpcState.cpp
+++ b/libs/binder/RpcState.cpp
@@ -253,8 +253,8 @@
     data.markForRpc(session);
     Parcel reply;
 
-    status_t status = transact(fd, RpcAddress::zero(), RPC_SPECIAL_TRANSACT_GET_ROOT, data, session,
-                               &reply, 0);
+    status_t status = transactAddress(fd, RpcAddress::zero(), RPC_SPECIAL_TRANSACT_GET_ROOT, data,
+                                      session, &reply, 0);
     if (status != OK) {
         ALOGE("Error getting root object: %s", statusToString(status).c_str());
         return nullptr;
@@ -269,8 +269,8 @@
     data.markForRpc(session);
     Parcel reply;
 
-    status_t status = transact(fd, RpcAddress::zero(), RPC_SPECIAL_TRANSACT_GET_MAX_THREADS, data,
-                               session, &reply, 0);
+    status_t status = transactAddress(fd, RpcAddress::zero(), RPC_SPECIAL_TRANSACT_GET_MAX_THREADS,
+                                      data, session, &reply, 0);
     if (status != OK) {
         ALOGE("Error getting max threads: %s", statusToString(status).c_str());
         return status;
@@ -294,8 +294,8 @@
     data.markForRpc(session);
     Parcel reply;
 
-    status_t status = transact(fd, RpcAddress::zero(), RPC_SPECIAL_TRANSACT_GET_SESSION_ID, data,
-                               session, &reply, 0);
+    status_t status = transactAddress(fd, RpcAddress::zero(), RPC_SPECIAL_TRANSACT_GET_SESSION_ID,
+                                      data, session, &reply, 0);
     if (status != OK) {
         ALOGE("Error getting session ID: %s", statusToString(status).c_str());
         return status;
@@ -309,9 +309,31 @@
     return OK;
 }
 
-status_t RpcState::transact(const base::unique_fd& fd, const RpcAddress& address, uint32_t code,
+status_t RpcState::transact(const base::unique_fd& fd, const sp<IBinder>& binder, uint32_t code,
                             const Parcel& data, const sp<RpcSession>& session, Parcel* reply,
                             uint32_t flags) {
+    if (!data.isForRpc()) {
+        ALOGE("Refusing to send RPC with parcel not crafted for RPC");
+        return BAD_TYPE;
+    }
+
+    if (data.objectsCount() != 0) {
+        ALOGE("Parcel at %p has attached objects but is being used in an RPC call", &data);
+        return BAD_TYPE;
+    }
+
+    RpcAddress address = RpcAddress::zero();
+    if (status_t status = onBinderLeaving(session, binder, &address); status != OK) return status;
+
+    return transactAddress(fd, address, code, data, session, reply, flags);
+}
+
+status_t RpcState::transactAddress(const base::unique_fd& fd, const RpcAddress& address,
+                                   uint32_t code, const Parcel& data, const sp<RpcSession>& session,
+                                   Parcel* reply, uint32_t flags) {
+    LOG_ALWAYS_FATAL_IF(!data.isForRpc());
+    LOG_ALWAYS_FATAL_IF(data.objectsCount() != 0);
+
     uint64_t asyncNumber = 0;
 
     if (!address.isZero()) {
@@ -326,16 +348,6 @@
         }
     }
 
-    if (!data.isForRpc()) {
-        ALOGE("Refusing to send RPC with parcel not crafted for RPC");
-        return BAD_TYPE;
-    }
-
-    if (data.objectsCount() != 0) {
-        ALOGE("Parcel at %p has attached objects but is being used in an RPC call", &data);
-        return BAD_TYPE;
-    }
-
     RpcWireTransaction transaction{
             .address = address.viewRawEmbedded(),
             .code = code,
@@ -509,7 +521,7 @@
         return DEAD_OBJECT;
     }
 
-    return processTransactInternal(fd, session, std::move(transactionData));
+    return processTransactInternal(fd, session, std::move(transactionData), nullptr /*targetRef*/);
 }
 
 static void do_nothing_to_transact_data(Parcel* p, const uint8_t* data, size_t dataSize,
@@ -522,7 +534,7 @@
 }
 
 status_t RpcState::processTransactInternal(const base::unique_fd& fd, const sp<RpcSession>& session,
-                                           CommandData transactionData) {
+                                           CommandData transactionData, sp<IBinder>&& targetRef) {
     if (transactionData.size() < sizeof(RpcWireTransaction)) {
         ALOGE("Expecting %zu but got %zu bytes for RpcWireTransaction. Terminating!",
               sizeof(RpcWireTransaction), transactionData.size());
@@ -538,45 +550,49 @@
     status_t replyStatus = OK;
     sp<IBinder> target;
     if (!addr.isZero()) {
-        std::lock_guard<std::mutex> _l(mNodeMutex);
-
-        auto it = mNodeForAddress.find(addr);
-        if (it == mNodeForAddress.end()) {
-            ALOGE("Unknown binder address %s.", addr.toString().c_str());
-            replyStatus = BAD_VALUE;
+        if (!targetRef) {
+            target = onBinderEntering(session, addr);
         } else {
-            target = it->second.binder.promote();
-            if (target == nullptr) {
-                // This can happen if the binder is remote in this process, and
-                // another thread has called the last decStrong on this binder.
-                // However, for local binders, it indicates a misbehaving client
-                // (any binder which is being transacted on should be holding a
-                // strong ref count), so in either case, terminating the
-                // session.
-                ALOGE("While transacting, binder has been deleted at address %s. Terminating!",
+            target = targetRef;
+        }
+
+        if (target == nullptr) {
+            // This can happen if the binder is remote in this process, and
+            // another thread has called the last decStrong on this binder.
+            // However, for local binders, it indicates a misbehaving client
+            // (any binder which is being transacted on should be holding a
+            // strong ref count), so in either case, terminating the
+            // session.
+            ALOGE("While transacting, binder has been deleted at address %s. Terminating!",
+                  addr.toString().c_str());
+            terminate();
+            replyStatus = BAD_VALUE;
+        } else if (target->localBinder() == nullptr) {
+            ALOGE("Unknown binder address or non-local binder, not address %s. Terminating!",
+                  addr.toString().c_str());
+            terminate();
+            replyStatus = BAD_VALUE;
+        } else if (transaction->flags & IBinder::FLAG_ONEWAY) {
+            std::lock_guard<std::mutex> _l(mNodeMutex);
+            auto it = mNodeForAddress.find(addr);
+            if (it->second.binder.promote() != target) {
+                ALOGE("Binder became invalid during transaction. Bad client? %s",
                       addr.toString().c_str());
-                terminate();
                 replyStatus = BAD_VALUE;
-            } else if (target->localBinder() == nullptr) {
-                ALOGE("Transactions can only go to local binders, not address %s. Terminating!",
-                      addr.toString().c_str());
-                terminate();
-                replyStatus = BAD_VALUE;
-            } else if (transaction->flags & IBinder::FLAG_ONEWAY) {
-                if (transaction->asyncNumber != it->second.asyncNumber) {
-                    // we need to process some other asynchronous transaction
-                    // first
-                    // TODO(b/183140903): limit enqueues/detect overfill for bad client
-                    // TODO(b/183140903): detect when an object is deleted when it still has
-                    //        pending async transactions
-                    it->second.asyncTodo.push(BinderNode::AsyncTodo{
-                            .data = std::move(transactionData),
-                            .asyncNumber = transaction->asyncNumber,
-                    });
-                    LOG_RPC_DETAIL("Enqueuing %" PRId64 " on %s", transaction->asyncNumber,
-                                   addr.toString().c_str());
-                    return OK;
-                }
+            } else if (transaction->asyncNumber != it->second.asyncNumber) {
+                // we need to process some other asynchronous transaction
+                // first
+                // TODO(b/183140903): limit enqueues/detect overfill for bad client
+                // TODO(b/183140903): detect when an object is deleted when it still has
+                //        pending async transactions
+                it->second.asyncTodo.push(BinderNode::AsyncTodo{
+                        .ref = target,
+                        .data = std::move(transactionData),
+                        .asyncNumber = transaction->asyncNumber,
+                });
+                LOG_RPC_DETAIL("Enqueuing %" PRId64 " on %s", transaction->asyncNumber,
+                               addr.toString().c_str());
+                return OK;
             }
         }
     }
@@ -670,13 +686,17 @@
                                it->second.asyncNumber, addr.toString().c_str());
 
                 // justification for const_cast (consider avoiding priority_queue):
-                // - AsyncTodo operator< doesn't depend on 'data' object
+                // - AsyncTodo operator< doesn't depend on 'data' or 'ref' objects
                 // - gotta go fast
-                CommandData data = std::move(
-                        const_cast<BinderNode::AsyncTodo&>(it->second.asyncTodo.top()).data);
+                auto& todo = const_cast<BinderNode::AsyncTodo&>(it->second.asyncTodo.top());
+
+                CommandData nextData = std::move(todo.data);
+                sp<IBinder> nextRef = std::move(todo.ref);
+
                 it->second.asyncTodo.pop();
                 _l.unlock();
-                return processTransactInternal(fd, session, std::move(data));
+                return processTransactInternal(fd, session, std::move(nextData),
+                                               std::move(nextRef));
             }
         }
         return OK;
diff --git a/libs/binder/RpcState.h b/libs/binder/RpcState.h
index f913925..78e8997 100644
--- a/libs/binder/RpcState.h
+++ b/libs/binder/RpcState.h
@@ -58,9 +58,13 @@
     status_t getSessionId(const base::unique_fd& fd, const sp<RpcSession>& session,
                           int32_t* sessionIdOut);
 
-    [[nodiscard]] status_t transact(const base::unique_fd& fd, const RpcAddress& address,
+    [[nodiscard]] status_t transact(const base::unique_fd& fd, const sp<IBinder>& address,
                                     uint32_t code, const Parcel& data,
                                     const sp<RpcSession>& session, Parcel* reply, uint32_t flags);
+    [[nodiscard]] status_t transactAddress(const base::unique_fd& fd, const RpcAddress& address,
+                                           uint32_t code, const Parcel& data,
+                                           const sp<RpcSession>& session, Parcel* reply,
+                                           uint32_t flags);
     [[nodiscard]] status_t sendDecStrong(const base::unique_fd& fd, const RpcAddress& address);
     [[nodiscard]] status_t getAndExecuteCommand(const base::unique_fd& fd,
                                                 const sp<RpcSession>& session);
@@ -129,7 +133,8 @@
                                            const RpcWireHeader& command);
     [[nodiscard]] status_t processTransactInternal(const base::unique_fd& fd,
                                                    const sp<RpcSession>& session,
-                                                   CommandData transactionData);
+                                                   CommandData transactionData,
+                                                   sp<IBinder>&& targetRef);
     [[nodiscard]] status_t processDecStrong(const base::unique_fd& fd,
                                             const sp<RpcSession>& session,
                                             const RpcWireHeader& command);
@@ -165,6 +170,7 @@
 
         // async transaction queue, _only_ for local binder
         struct AsyncTodo {
+            sp<IBinder> ref;
             CommandData data;
             uint64_t asyncNumber = 0;
 
diff --git a/libs/binder/include/binder/RpcSession.h b/libs/binder/include/binder/RpcSession.h
index d46f275..7c7feaa 100644
--- a/libs/binder/include/binder/RpcSession.h
+++ b/libs/binder/include/binder/RpcSession.h
@@ -83,7 +83,7 @@
      */
     status_t getRemoteMaxThreads(size_t* maxThreads);
 
-    [[nodiscard]] status_t transact(const RpcAddress& address, uint32_t code, const Parcel& data,
+    [[nodiscard]] status_t transact(const sp<IBinder>& binder, uint32_t code, const Parcel& data,
                                     Parcel* reply, uint32_t flags);
     [[nodiscard]] status_t sendDecStrong(const RpcAddress& address);
 
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index e10fe2f..9b2d88d 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -843,8 +843,7 @@
     constexpr size_t kReallyLongTimeMs = 100;
     constexpr size_t kSleepMs = kReallyLongTimeMs * 5;
 
-    // more than one thread, just so this doesn't deadlock
-    auto proc = createRpcTestSocketServerProcess(2);
+    auto proc = createRpcTestSocketServerProcess(1);
 
     size_t epochMsBefore = epochMillis();
 
@@ -876,6 +875,14 @@
     size_t epochMsAfter = epochMillis();
 
     EXPECT_GT(epochMsAfter, epochMsBefore + kSleepMs * kNumSleeps);
+
+    // pending oneway transactions hold ref, make sure we read data on all
+    // sockets
+    std::vector<std::thread> threads;
+    for (size_t i = 0; i < 1 + kNumExtraServerThreads; i++) {
+        threads.push_back(std::thread([&] { EXPECT_OK(proc.rootIface->sleepMs(250)); }));
+    }
+    for (auto& t : threads) t.join();
 }
 
 TEST_P(BinderRpc, Die) {