Merge changes from topic "libbinder-rpc-perf-poll"

* changes:
  libbinder: RPC handle builtup refcounts
  libbinder: RPC avoid poll
  libbinder: RPC simpl transactAddressInternal
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp
index 4c61a59..ba2920e 100644
--- a/libs/binder/RpcServer.cpp
+++ b/libs/binder/RpcServer.cpp
@@ -278,7 +278,7 @@
     RpcConnectionHeader header;
     if (status == OK) {
         status = client->interruptableReadFully(server->mShutdownTrigger.get(), &header,
-                                                sizeof(header));
+                                                sizeof(header), {});
         if (status != OK) {
             ALOGE("Failed to read ID for client connecting to RPC server: %s",
                   statusToString(status).c_str());
@@ -291,7 +291,7 @@
         if (header.sessionIdSize > 0) {
             sessionId.resize(header.sessionIdSize);
             status = client->interruptableReadFully(server->mShutdownTrigger.get(),
-                                                    sessionId.data(), sessionId.size());
+                                                    sessionId.data(), sessionId.size(), {});
             if (status != OK) {
                 ALOGE("Failed to read session ID for client connecting to RPC server: %s",
                       statusToString(status).c_str());
@@ -316,7 +316,7 @@
             };
 
             status = client->interruptableWriteFully(server->mShutdownTrigger.get(), &response,
-                                                     sizeof(response));
+                                                     sizeof(response), {});
             if (status != OK) {
                 ALOGE("Failed to send new session response: %s", statusToString(status).c_str());
                 // still need to cleanup before we can return
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index dafb339..65f6bc6 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -560,7 +560,7 @@
     }
 
     auto sendHeaderStatus =
-            server->interruptableWriteFully(mShutdownTrigger.get(), &header, sizeof(header));
+            server->interruptableWriteFully(mShutdownTrigger.get(), &header, sizeof(header), {});
     if (sendHeaderStatus != OK) {
         ALOGE("Could not write connection header to socket: %s",
               statusToString(sendHeaderStatus).c_str());
@@ -570,7 +570,7 @@
     if (sessionId.size() > 0) {
         auto sendSessionIdStatus =
                 server->interruptableWriteFully(mShutdownTrigger.get(), sessionId.data(),
-                                                sessionId.size());
+                                                sessionId.size(), {});
         if (sendSessionIdStatus != OK) {
             ALOGE("Could not write session ID ('%s') to socket: %s",
                   base::HexString(sessionId.data(), sessionId.size()).c_str(),
diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp
index 3ff13bc..86cc91c 100644
--- a/libs/binder/RpcState.cpp
+++ b/libs/binder/RpcState.cpp
@@ -307,7 +307,7 @@
 
 status_t RpcState::rpcSend(const sp<RpcSession::RpcConnection>& connection,
                            const sp<RpcSession>& session, const char* what, const void* data,
-                           size_t size) {
+                           size_t size, const std::function<status_t()>& altPoll) {
     LOG_RPC_DETAIL("Sending %s on RpcTransport %p: %s", what, connection->rpcTransport.get(),
                    android::base::HexString(data, size).c_str());
 
@@ -319,7 +319,7 @@
 
     if (status_t status =
                 connection->rpcTransport->interruptableWriteFully(session->mShutdownTrigger.get(),
-                                                                  data, size);
+                                                                  data, size, altPoll);
         status != OK) {
         LOG_RPC_DETAIL("Failed to write %s (%zu bytes) on RpcTransport %p, error: %s", what, size,
                        connection->rpcTransport.get(), statusToString(status).c_str());
@@ -341,7 +341,7 @@
 
     if (status_t status =
                 connection->rpcTransport->interruptableReadFully(session->mShutdownTrigger.get(),
-                                                                 data, size);
+                                                                 data, size, {});
         status != OK) {
         LOG_RPC_DETAIL("Failed to read %s (%zu bytes) on RpcTransport %p, error: %s", what, size,
                        connection->rpcTransport.get(), statusToString(status).c_str());
@@ -523,21 +523,44 @@
     memcpy(transactionData.data() + sizeof(RpcWireHeader) + sizeof(RpcWireTransaction), data.data(),
            data.dataSize());
 
+    constexpr size_t kWaitMaxUs = 1000000;
+    constexpr size_t kWaitLogUs = 10000;
+    size_t waitUs = 0;
+
+    // Oneway calls have no sync point, so if many were sent before this
+    // transaction (whether it is twoway or oneway), they may have filled up
+    // the socket. So, make sure we drain them instead of polling.
+    std::function<status_t()> drainRefs = [&] {
+        if (waitUs > kWaitLogUs) {
+            ALOGE("Cannot send command, trying to process pending refcounts. Waiting %zuus. Too "
+                  "many oneway calls?",
+                  waitUs);
+        }
+
+        if (waitUs > 0) {
+            usleep(waitUs);
+            waitUs = std::min(kWaitMaxUs, waitUs * 2);
+        } else {
+            waitUs = 1;
+        }
+
+        return drainCommands(connection, session, CommandType::CONTROL_ONLY);
+    };
+
     if (status_t status = rpcSend(connection, session, "transaction", transactionData.data(),
-                                  transactionData.size());
-        status != OK)
+                                  transactionData.size(), drainRefs);
+        status != OK) {
         // TODO(b/167966510): need to undo onBinderLeaving - we know the
         // refcount isn't successfully transferred.
         return status;
+    }
 
     if (flags & IBinder::FLAG_ONEWAY) {
         LOG_RPC_DETAIL("Oneway command, so no longer waiting on RpcTransport %p",
                        connection->rpcTransport.get());
 
         // Do not wait on result.
-        // However, too many oneway calls may cause refcounts to build up and fill up the socket,
-        // so process those.
-        return drainCommands(connection, session, CommandType::CONTROL_ONLY);
+        return OK;
     }
 
     LOG_ALWAYS_FATAL_IF(reply == nullptr, "Reply parcel must be used for synchronous transaction.");
@@ -723,7 +746,7 @@
     // for 'recursive' calls to this, we have already read and processed the
     // binder from the transaction data and taken reference counts into account,
     // so it is cached here.
-    sp<IBinder> targetRef;
+    sp<IBinder> target;
 processTransactInternalTailCall:
 
     if (transactionData.size() < sizeof(RpcWireTransaction)) {
@@ -738,12 +761,9 @@
     bool oneway = transaction->flags & IBinder::FLAG_ONEWAY;
 
     status_t replyStatus = OK;
-    sp<IBinder> target;
     if (addr != 0) {
-        if (!targetRef) {
+        if (!target) {
             replyStatus = onBinderEntering(session, addr, &target);
-        } else {
-            target = targetRef;
         }
 
         if (replyStatus != OK) {
@@ -910,7 +930,8 @@
 
                 // reset up arguments
                 transactionData = std::move(todo.data);
-                targetRef = std::move(todo.ref);
+                LOG_ALWAYS_FATAL_IF(target != todo.ref,
+                                    "async list should be associated with a binder");
 
                 it->second.asyncTodo.pop();
                 goto processTransactInternalTailCall;
diff --git a/libs/binder/RpcState.h b/libs/binder/RpcState.h
index 42e95e0..50de22b 100644
--- a/libs/binder/RpcState.h
+++ b/libs/binder/RpcState.h
@@ -177,7 +177,8 @@
 
     [[nodiscard]] status_t rpcSend(const sp<RpcSession::RpcConnection>& connection,
                                    const sp<RpcSession>& session, const char* what,
-                                   const void* data, size_t size);
+                                   const void* data, size_t size,
+                                   const std::function<status_t()>& altPoll = nullptr);
     [[nodiscard]] status_t rpcRec(const sp<RpcSession::RpcConnection>& connection,
                                   const sp<RpcSession>& session, const char* what, void* data,
                                   size_t size);
diff --git a/libs/binder/RpcTransportRaw.cpp b/libs/binder/RpcTransportRaw.cpp
index 41f4a9f..7669518 100644
--- a/libs/binder/RpcTransportRaw.cpp
+++ b/libs/binder/RpcTransportRaw.cpp
@@ -43,56 +43,72 @@
         return ret;
     }
 
-    status_t interruptableWriteFully(FdTrigger* fdTrigger, const void* data, size_t size) override {
-        const uint8_t* buffer = reinterpret_cast<const uint8_t*>(data);
-        const uint8_t* end = buffer + size;
+    template <typename Buffer, typename SendOrReceive>
+    status_t interruptableReadOrWrite(FdTrigger* fdTrigger, Buffer buffer, size_t size,
+                                      SendOrReceive sendOrReceiveFun, const char* funName,
+                                      int16_t event, const std::function<status_t()>& altPoll) {
+        const Buffer end = buffer + size;
 
         MAYBE_WAIT_IN_FLAKE_MODE;
 
-        status_t status;
-        while ((status = fdTrigger->triggerablePoll(mSocket.get(), POLLOUT)) == OK) {
-            ssize_t writeSize =
-                    TEMP_FAILURE_RETRY(::send(mSocket.get(), buffer, end - buffer, MSG_NOSIGNAL));
-            if (writeSize < 0) {
+        // Since we didn't poll, we need to manually check to see if it was triggered. Otherwise, we
+        // may never know we should be shutting down.
+        if (fdTrigger->isTriggered()) {
+            return DEAD_OBJECT;
+        }
+
+        bool havePolled = false;
+        while (true) {
+            ssize_t processSize = TEMP_FAILURE_RETRY(
+                    sendOrReceiveFun(mSocket.get(), buffer, end - buffer, MSG_NOSIGNAL));
+
+            if (processSize < 0) {
                 int savedErrno = errno;
-                LOG_RPC_DETAIL("RpcTransport send(): %s", strerror(savedErrno));
-                return -savedErrno;
+
+                // EAGAIN/EWOULDBLOCK is tolerated only before the first poll; after polling
+                // it is returned too, since it would expose a problem with polling
+                if (havePolled ||
+                    (!havePolled && savedErrno != EAGAIN && savedErrno != EWOULDBLOCK)) {
+                    LOG_RPC_DETAIL("RpcTransport %s(): %s", funName, strerror(savedErrno));
+                    return -savedErrno;
+                }
+            } else if (processSize == 0) {
+                return DEAD_OBJECT;
+            } else {
+                buffer += processSize;
+                if (buffer == end) {
+                    return OK;
+                }
             }
 
-            if (writeSize == 0) return DEAD_OBJECT;
-
-            buffer += writeSize;
-            if (buffer == end) return OK;
+            if (altPoll) {
+                if (status_t status = altPoll(); status != OK) return status;
+                if (fdTrigger->isTriggered()) {
+                    return DEAD_OBJECT;
+                }
+            } else {
+                if (status_t status = fdTrigger->triggerablePoll(mSocket.get(), event);
+                    status != OK)
+                    return status;
+                if (!havePolled) havePolled = true;
+            }
         }
-        return status;
     }
 
-    status_t interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size) override {
-        uint8_t* buffer = reinterpret_cast<uint8_t*>(data);
-        uint8_t* end = buffer + size;
+    status_t interruptableWriteFully(FdTrigger* fdTrigger, const void* data, size_t size,
+                                     const std::function<status_t()>& altPoll) override {
+        return interruptableReadOrWrite(fdTrigger, reinterpret_cast<const uint8_t*>(data), size,
+                                        send, "send", POLLOUT, altPoll);
+    }
 
-        MAYBE_WAIT_IN_FLAKE_MODE;
-
-        status_t status;
-        while ((status = fdTrigger->triggerablePoll(mSocket.get(), POLLIN)) == OK) {
-            ssize_t readSize =
-                    TEMP_FAILURE_RETRY(::recv(mSocket.get(), buffer, end - buffer, MSG_NOSIGNAL));
-            if (readSize < 0) {
-                int savedErrno = errno;
-                LOG_RPC_DETAIL("RpcTransport recv(): %s", strerror(savedErrno));
-                return -savedErrno;
-            }
-
-            if (readSize == 0) return DEAD_OBJECT; // EOF
-
-            buffer += readSize;
-            if (buffer == end) return OK;
-        }
-        return status;
+    status_t interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size,
+                                    const std::function<status_t()>& altPoll) override {
+        return interruptableReadOrWrite(fdTrigger, reinterpret_cast<uint8_t*>(data), size, recv,
+                                        "recv", POLLIN, altPoll);
     }
 
 private:
-    android::base::unique_fd mSocket;
+    base::unique_fd mSocket;
 };
 
 // RpcTransportCtx with TLS disabled.
diff --git a/libs/binder/RpcTransportTls.cpp b/libs/binder/RpcTransportTls.cpp
index f8cd71d..7f810b1 100644
--- a/libs/binder/RpcTransportTls.cpp
+++ b/libs/binder/RpcTransportTls.cpp
@@ -169,12 +169,13 @@
     // If |sslError| is WANT_READ / WANT_WRITE, poll for POLLIN / POLLOUT respectively. Otherwise
     // return error. Also return error if |fdTrigger| is triggered before or during poll().
     status_t pollForSslError(android::base::borrowed_fd fd, int sslError, FdTrigger* fdTrigger,
-                             const char* fnString, int additionalEvent = 0) {
+                             const char* fnString, int additionalEvent,
+                             const std::function<status_t()>& altPoll) {
         switch (sslError) {
             case SSL_ERROR_WANT_READ:
-                return handlePoll(POLLIN | additionalEvent, fd, fdTrigger, fnString);
+                return handlePoll(POLLIN | additionalEvent, fd, fdTrigger, fnString, altPoll);
             case SSL_ERROR_WANT_WRITE:
-                return handlePoll(POLLOUT | additionalEvent, fd, fdTrigger, fnString);
+                return handlePoll(POLLOUT | additionalEvent, fd, fdTrigger, fnString, altPoll);
             case SSL_ERROR_SYSCALL: {
                 auto queue = toString();
                 LOG_TLS_DETAIL("%s(): %s. Treating as DEAD_OBJECT. Error queue: %s", fnString,
@@ -194,11 +195,17 @@
     bool mHandled = false;
 
     status_t handlePoll(int event, android::base::borrowed_fd fd, FdTrigger* fdTrigger,
-                        const char* fnString) {
-        status_t ret = fdTrigger->triggerablePoll(fd, event);
+                        const char* fnString, const std::function<status_t()>& altPoll) {
+        status_t ret;
+        if (altPoll) {
+            ret = altPoll();
+            if (fdTrigger->isTriggered()) ret = DEAD_OBJECT;
+        } else {
+            ret = fdTrigger->triggerablePoll(fd, event);
+        }
+
         if (ret != OK && ret != DEAD_OBJECT) {
-            ALOGE("triggerablePoll error while poll()-ing after %s(): %s", fnString,
-                  statusToString(ret).c_str());
+            ALOGE("poll error while after %s(): %s", fnString, statusToString(ret).c_str());
         }
         clear();
         return ret;
@@ -268,8 +275,10 @@
     RpcTransportTls(android::base::unique_fd socket, Ssl ssl)
           : mSocket(std::move(socket)), mSsl(std::move(ssl)) {}
     Result<size_t> peek(void* buf, size_t size) override;
-    status_t interruptableWriteFully(FdTrigger* fdTrigger, const void* data, size_t size) override;
-    status_t interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size) override;
+    status_t interruptableWriteFully(FdTrigger* fdTrigger, const void* data, size_t size,
+                                     const std::function<status_t()>& altPoll) override;
+    status_t interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size,
+                                    const std::function<status_t()>& altPoll) override;
 
 private:
     android::base::unique_fd mSocket;
@@ -295,7 +304,8 @@
 }
 
 status_t RpcTransportTls::interruptableWriteFully(FdTrigger* fdTrigger, const void* data,
-                                                  size_t size) {
+                                                  size_t size,
+                                                  const std::function<status_t()>& altPoll) {
     auto buffer = reinterpret_cast<const uint8_t*>(data);
     const uint8_t* end = buffer + size;
 
@@ -317,8 +327,8 @@
         int sslError = mSsl.getError(writeSize);
         // TODO(b/195788248): BIO should contain the FdTrigger, and send(2) / recv(2) should be
         //   triggerablePoll()-ed. Then additionalEvent is no longer necessary.
-        status_t pollStatus =
-                errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger, "SSL_write", POLLIN);
+        status_t pollStatus = errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger,
+                                                         "SSL_write", POLLIN, altPoll);
         if (pollStatus != OK) return pollStatus;
         // Do not advance buffer. Try SSL_write() again.
     }
@@ -326,7 +336,8 @@
     return OK;
 }
 
-status_t RpcTransportTls::interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size) {
+status_t RpcTransportTls::interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size,
+                                                 const std::function<status_t()>& altPoll) {
     auto buffer = reinterpret_cast<uint8_t*>(data);
     uint8_t* end = buffer + size;
 
@@ -350,8 +361,8 @@
             return DEAD_OBJECT;
         }
         int sslError = mSsl.getError(readSize);
-        status_t pollStatus =
-                errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger, "SSL_read");
+        status_t pollStatus = errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger,
+                                                         "SSL_read", 0, altPoll);
         if (pollStatus != OK) return pollStatus;
         // Do not advance buffer. Try SSL_read() again.
     }
@@ -382,7 +393,7 @@
         }
         int sslError = ssl->getError(ret);
         status_t pollStatus =
-                errorQueue.pollForSslError(fd, sslError, fdTrigger, "SSL_do_handshake");
+                errorQueue.pollForSslError(fd, sslError, fdTrigger, "SSL_do_handshake", 0, {});
         if (pollStatus != OK) return false;
     }
 }
diff --git a/libs/binder/include/binder/RpcTransport.h b/libs/binder/include/binder/RpcTransport.h
index 4fe2324..db8b5e9 100644
--- a/libs/binder/include/binder/RpcTransport.h
+++ b/libs/binder/include/binder/RpcTransport.h
@@ -18,6 +18,7 @@
 
 #pragma once
 
+#include <functional>
 #include <memory>
 #include <string>
 
@@ -43,14 +44,20 @@
     /**
      * Read (or write), but allow to be interrupted by a trigger.
      *
+     * altPoll - function to be called instead of polling, when needing to wait
+     * to read/write data. If this returns an error, that error is returned from
+     * this function.
+     *
      * Return:
      *   OK - succeeded in completely processing 'size'
      *   error - interrupted (failure or trigger)
      */
-    [[nodiscard]] virtual status_t interruptableWriteFully(FdTrigger *fdTrigger, const void *buf,
-                                                           size_t size) = 0;
-    [[nodiscard]] virtual status_t interruptableReadFully(FdTrigger *fdTrigger, void *buf,
-                                                          size_t size) = 0;
+    [[nodiscard]] virtual status_t interruptableWriteFully(
+            FdTrigger *fdTrigger, const void *buf, size_t size,
+            const std::function<status_t()> &altPoll) = 0;
+    [[nodiscard]] virtual status_t interruptableReadFully(
+            FdTrigger *fdTrigger, void *buf, size_t size,
+            const std::function<status_t()> &altPoll) = 0;
 
 protected:
     RpcTransport() = default;
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index a2558f5..a1058bc 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -1573,7 +1573,7 @@
                                                   FdTrigger* fdTrigger) {
             std::string message(kMessage);
             auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(),
-                                                                   message.size());
+                                                                   message.size(), {});
             if (status != OK) return AssertionFailure() << statusToString(status);
             return AssertionSuccess();
         }
@@ -1606,7 +1606,7 @@
             std::string readMessage(expectedMessage.size(), '\0');
             status_t readStatus =
                     mClientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(),
-                                                             readMessage.size());
+                                                             readMessage.size(), {});
             if (readStatus != OK) {
                 return AssertionFailure() << statusToString(readStatus);
             }
@@ -1800,8 +1800,8 @@
     bool shouldContinueWriting = false;
     auto serverPostConnect = [&](RpcTransport* serverTransport, FdTrigger* fdTrigger) {
         std::string message(RpcTransportTestUtils::kMessage);
-        auto status =
-                serverTransport->interruptableWriteFully(fdTrigger, message.data(), message.size());
+        auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(),
+                                                               message.size(), {});
         if (status != OK) return AssertionFailure() << statusToString(status);
 
         {
@@ -1811,7 +1811,7 @@
             }
         }
 
-        status = serverTransport->interruptableWriteFully(fdTrigger, msg2.data(), msg2.size());
+        status = serverTransport->interruptableWriteFully(fdTrigger, msg2.data(), msg2.size(), {});
         if (status != DEAD_OBJECT)
             return AssertionFailure() << "When FdTrigger is shut down, interruptableWriteFully "
                                          "should return DEAD_OBJECT, but it is "