Use android::base::function_ref isntead of std::function

This removes an allocation from the binder RPC calls.

Test: atest binderAllocationLimits
Bug: 230625474
Change-Id: I70ebb4e320323149c3c66809f1077cbf332c07ef
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp
index d63c3f1..c67b70a 100644
--- a/libs/binder/RpcServer.cpp
+++ b/libs/binder/RpcServer.cpp
@@ -289,7 +289,8 @@
     RpcConnectionHeader header;
     if (status == OK) {
         iovec iov{&header, sizeof(header)};
-        status = client->interruptableReadFully(server->mShutdownTrigger.get(), &iov, 1, {});
+        status = client->interruptableReadFully(server->mShutdownTrigger.get(), &iov, 1,
+                                                std::nullopt);
         if (status != OK) {
             ALOGE("Failed to read ID for client connecting to RPC server: %s",
                   statusToString(status).c_str());
@@ -303,8 +304,8 @@
             if (header.sessionIdSize == kSessionIdBytes) {
                 sessionId.resize(header.sessionIdSize);
                 iovec iov{sessionId.data(), sessionId.size()};
-                status =
-                        client->interruptableReadFully(server->mShutdownTrigger.get(), &iov, 1, {});
+                status = client->interruptableReadFully(server->mShutdownTrigger.get(), &iov, 1,
+                                                        std::nullopt);
                 if (status != OK) {
                     ALOGE("Failed to read session ID for client connecting to RPC server: %s",
                           statusToString(status).c_str());
@@ -334,7 +335,8 @@
             };
 
             iovec iov{&response, sizeof(response)};
-            status = client->interruptableWriteFully(server->mShutdownTrigger.get(), &iov, 1, {});
+            status = client->interruptableWriteFully(server->mShutdownTrigger.get(), &iov, 1,
+                                                     std::nullopt);
             if (status != OK) {
                 ALOGE("Failed to send new session response: %s", statusToString(status).c_str());
                 // still need to cleanup before we can return
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index 6ae5357..5c35dd0 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -615,7 +615,7 @@
 
     iovec headerIov{&header, sizeof(header)};
     auto sendHeaderStatus =
-            server->interruptableWriteFully(mShutdownTrigger.get(), &headerIov, 1, {});
+            server->interruptableWriteFully(mShutdownTrigger.get(), &headerIov, 1, std::nullopt);
     if (sendHeaderStatus != OK) {
         ALOGE("Could not write connection header to socket: %s",
               statusToString(sendHeaderStatus).c_str());
@@ -625,8 +625,8 @@
     if (sessionId.size() > 0) {
         iovec sessionIov{const_cast<void*>(static_cast<const void*>(sessionId.data())),
                          sessionId.size()};
-        auto sendSessionIdStatus =
-                server->interruptableWriteFully(mShutdownTrigger.get(), &sessionIov, 1, {});
+        auto sendSessionIdStatus = server->interruptableWriteFully(mShutdownTrigger.get(),
+                                                                   &sessionIov, 1, std::nullopt);
         if (sendSessionIdStatus != OK) {
             ALOGE("Could not write session ID ('%s') to socket: %s",
                   base::HexString(sessionId.data(), sessionId.size()).c_str(),
diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp
index 4ef9cd8..c4d28c2 100644
--- a/libs/binder/RpcState.cpp
+++ b/libs/binder/RpcState.cpp
@@ -311,7 +311,7 @@
 
 status_t RpcState::rpcSend(const sp<RpcSession::RpcConnection>& connection,
                            const sp<RpcSession>& session, const char* what, iovec* iovs, int niovs,
-                           const std::function<status_t()>& altPoll) {
+                           const std::optional<android::base::function_ref<status_t()>>& altPoll) {
     for (int i = 0; i < niovs; i++) {
         LOG_RPC_DETAIL("Sending %s (part %d of %d) on RpcTransport %p: %s",
                        what, i + 1, niovs, connection->rpcTransport.get(),
@@ -335,7 +335,7 @@
                           const sp<RpcSession>& session, const char* what, iovec* iovs, int niovs) {
     if (status_t status =
                 connection->rpcTransport->interruptableReadFully(session->mShutdownTrigger.get(),
-                                                                 iovs, niovs, {});
+                                                                 iovs, niovs, std::nullopt);
         status != OK) {
         LOG_RPC_DETAIL("Failed to read %s (%d iovs) on RpcTransport %p, error: %s", what, niovs,
                        connection->rpcTransport.get(), statusToString(status).c_str());
@@ -369,7 +369,7 @@
             .msg = RPC_CONNECTION_INIT_OKAY,
     };
     iovec iov{&init, sizeof(init)};
-    return rpcSend(connection, session, "connection init", &iov, 1);
+    return rpcSend(connection, session, "connection init", &iov, 1, std::nullopt);
 }
 
 status_t RpcState::readConnectionInit(const sp<RpcSession::RpcConnection>& connection,
@@ -516,31 +516,32 @@
 
     // Oneway calls have no sync point, so if many are sent before, whether this
     // is a twoway or oneway transaction, they may have filled up the socket.
-    // So, make sure we drain them before polling.
-    std::function<status_t()> drainRefs = [&] {
-        if (waitUs > kWaitLogUs) {
-            ALOGE("Cannot send command, trying to process pending refcounts. Waiting %zuus. Too "
-                  "many oneway calls?",
-                  waitUs);
-        }
-
-        if (waitUs > 0) {
-            usleep(waitUs);
-            waitUs = std::min(kWaitMaxUs, waitUs * 2);
-        } else {
-            waitUs = 1;
-        }
-
-        return drainCommands(connection, session, CommandType::CONTROL_ONLY);
-    };
+    // So, make sure we drain them before polling
 
     iovec iovs[]{
             {&command, sizeof(RpcWireHeader)},
             {&transaction, sizeof(RpcWireTransaction)},
             {const_cast<uint8_t*>(data.data()), data.dataSize()},
     };
-    if (status_t status =
-                rpcSend(connection, session, "transaction", iovs, arraysize(iovs), drainRefs);
+    if (status_t status = rpcSend(connection, session, "transaction", iovs, arraysize(iovs),
+                                  [&] {
+                                      if (waitUs > kWaitLogUs) {
+                                          ALOGE("Cannot send command, trying to process pending "
+                                                "refcounts. Waiting %zuus. Too "
+                                                "many oneway calls?",
+                                                waitUs);
+                                      }
+
+                                      if (waitUs > 0) {
+                                          usleep(waitUs);
+                                          waitUs = std::min(kWaitMaxUs, waitUs * 2);
+                                      } else {
+                                          waitUs = 1;
+                                      }
+
+                                      return drainCommands(connection, session,
+                                                           CommandType::CONTROL_ONLY);
+                                  });
         status != OK) {
         // TODO(b/167966510): need to undo onBinderLeaving - we know the
         // refcount isn't successfully transferred.
@@ -643,7 +644,7 @@
             .bodySize = sizeof(RpcDecStrong),
     };
     iovec iovs[]{{&cmd, sizeof(cmd)}, {&body, sizeof(body)}};
-    return rpcSend(connection, session, "dec ref", iovs, arraysize(iovs));
+    return rpcSend(connection, session, "dec ref", iovs, arraysize(iovs), std::nullopt);
 }
 
 status_t RpcState::getAndExecuteCommand(const sp<RpcSession::RpcConnection>& connection,
@@ -958,7 +959,7 @@
             {&rpcReply, sizeof(RpcWireReply)},
             {const_cast<uint8_t*>(reply.data()), reply.dataSize()},
     };
-    return rpcSend(connection, session, "reply", iovs, arraysize(iovs));
+    return rpcSend(connection, session, "reply", iovs, arraysize(iovs), std::nullopt);
 }
 
 status_t RpcState::processDecStrong(const sp<RpcSession::RpcConnection>& connection,
diff --git a/libs/binder/RpcState.h b/libs/binder/RpcState.h
index f4a0894..9cbe187 100644
--- a/libs/binder/RpcState.h
+++ b/libs/binder/RpcState.h
@@ -178,9 +178,10 @@
         size_t mSize;
     };
 
-    [[nodiscard]] status_t rpcSend(const sp<RpcSession::RpcConnection>& connection,
-                                   const sp<RpcSession>& session, const char* what, iovec* iovs,
-                                   int niovs, const std::function<status_t()>& altPoll = nullptr);
+    [[nodiscard]] status_t rpcSend(
+            const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
+            const char* what, iovec* iovs, int niovs,
+            const std::optional<android::base::function_ref<status_t()>>& altPoll);
     [[nodiscard]] status_t rpcRec(const sp<RpcSession::RpcConnection>& connection,
                                   const sp<RpcSession>& session, const char* what, iovec* iovs,
                                   int niovs);
diff --git a/libs/binder/RpcTransportRaw.cpp b/libs/binder/RpcTransportRaw.cpp
index f5cc413..f9b73fc 100644
--- a/libs/binder/RpcTransportRaw.cpp
+++ b/libs/binder/RpcTransportRaw.cpp
@@ -52,9 +52,10 @@
     }
 
     template <typename SendOrReceive>
-    status_t interruptableReadOrWrite(FdTrigger* fdTrigger, iovec* iovs, int niovs,
-                                      SendOrReceive sendOrReceiveFun, const char* funName,
-                                      int16_t event, const std::function<status_t()>& altPoll) {
+    status_t interruptableReadOrWrite(
+            FdTrigger* fdTrigger, iovec* iovs, int niovs, SendOrReceive sendOrReceiveFun,
+            const char* funName, int16_t event,
+            const std::optional<android::base::function_ref<status_t()>>& altPoll) {
         MAYBE_WAIT_IN_FLAKE_MODE;
 
         if (niovs < 0) {
@@ -129,7 +130,7 @@
             }
 
             if (altPoll) {
-                if (status_t status = altPoll(); status != OK) return status;
+                if (status_t status = (*altPoll)(); status != OK) return status;
                 if (fdTrigger->isTriggered()) {
                     return DEAD_OBJECT;
                 }
@@ -142,14 +143,16 @@
         }
     }
 
-    status_t interruptableWriteFully(FdTrigger* fdTrigger, iovec* iovs, int niovs,
-                                     const std::function<status_t()>& altPoll) override {
+    status_t interruptableWriteFully(
+            FdTrigger* fdTrigger, iovec* iovs, int niovs,
+            const std::optional<android::base::function_ref<status_t()>>& altPoll) override {
         return interruptableReadOrWrite(fdTrigger, iovs, niovs, sendmsg, "sendmsg", POLLOUT,
                                         altPoll);
     }
 
-    status_t interruptableReadFully(FdTrigger* fdTrigger, iovec* iovs, int niovs,
-                                    const std::function<status_t()>& altPoll) override {
+    status_t interruptableReadFully(
+            FdTrigger* fdTrigger, iovec* iovs, int niovs,
+            const std::optional<android::base::function_ref<status_t()>>& altPoll) override {
         return interruptableReadOrWrite(fdTrigger, iovs, niovs, recvmsg, "recvmsg", POLLIN,
                                         altPoll);
     }
diff --git a/libs/binder/RpcTransportTls.cpp b/libs/binder/RpcTransportTls.cpp
index 85c7655..ad5cb0f 100644
--- a/libs/binder/RpcTransportTls.cpp
+++ b/libs/binder/RpcTransportTls.cpp
@@ -181,9 +181,10 @@
     // |sslError| should be from Ssl::getError().
     // If |sslError| is WANT_READ / WANT_WRITE, poll for POLLIN / POLLOUT respectively. Otherwise
     // return error. Also return error if |fdTrigger| is triggered before or during poll().
-    status_t pollForSslError(android::base::borrowed_fd fd, int sslError, FdTrigger* fdTrigger,
-                             const char* fnString, int additionalEvent,
-                             const std::function<status_t()>& altPoll) {
+    status_t pollForSslError(
+            android::base::borrowed_fd fd, int sslError, FdTrigger* fdTrigger, const char* fnString,
+            int additionalEvent,
+            const std::optional<android::base::function_ref<status_t()>>& altPoll) {
         switch (sslError) {
             case SSL_ERROR_WANT_READ:
                 return handlePoll(POLLIN | additionalEvent, fd, fdTrigger, fnString, altPoll);
@@ -198,10 +199,11 @@
     bool mHandled = false;
 
     status_t handlePoll(int event, android::base::borrowed_fd fd, FdTrigger* fdTrigger,
-                        const char* fnString, const std::function<status_t()>& altPoll) {
+                        const char* fnString,
+                        const std::optional<android::base::function_ref<status_t()>>& altPoll) {
         status_t ret;
         if (altPoll) {
-            ret = altPoll();
+            ret = (*altPoll)();
             if (fdTrigger->isTriggered()) ret = DEAD_OBJECT;
         } else {
             ret = fdTrigger->triggerablePoll(fd, event);
@@ -278,10 +280,12 @@
     RpcTransportTls(android::base::unique_fd socket, Ssl ssl)
           : mSocket(std::move(socket)), mSsl(std::move(ssl)) {}
     status_t pollRead(void) override;
-    status_t interruptableWriteFully(FdTrigger* fdTrigger, iovec* iovs, int niovs,
-                                     const std::function<status_t()>& altPoll) override;
-    status_t interruptableReadFully(FdTrigger* fdTrigger, iovec* iovs, int niovs,
-                                    const std::function<status_t()>& altPoll) override;
+    status_t interruptableWriteFully(
+            FdTrigger* fdTrigger, iovec* iovs, int niovs,
+            const std::optional<android::base::function_ref<status_t()>>& altPoll) override;
+    status_t interruptableReadFully(
+            FdTrigger* fdTrigger, iovec* iovs, int niovs,
+            const std::optional<android::base::function_ref<status_t()>>& altPoll) override;
 
 private:
     android::base::unique_fd mSocket;
@@ -307,8 +311,9 @@
     return OK;
 }
 
-status_t RpcTransportTls::interruptableWriteFully(FdTrigger* fdTrigger, iovec* iovs, int niovs,
-                                                  const std::function<status_t()>& altPoll) {
+status_t RpcTransportTls::interruptableWriteFully(
+        FdTrigger* fdTrigger, iovec* iovs, int niovs,
+        const std::optional<android::base::function_ref<status_t()>>& altPoll) {
     MAYBE_WAIT_IN_FLAKE_MODE;
 
     if (niovs < 0) return BAD_VALUE;
@@ -349,8 +354,9 @@
     return OK;
 }
 
-status_t RpcTransportTls::interruptableReadFully(FdTrigger* fdTrigger, iovec* iovs, int niovs,
-                                                 const std::function<status_t()>& altPoll) {
+status_t RpcTransportTls::interruptableReadFully(
+        FdTrigger* fdTrigger, iovec* iovs, int niovs,
+        const std::optional<android::base::function_ref<status_t()>>& altPoll) {
     MAYBE_WAIT_IN_FLAKE_MODE;
 
     if (niovs < 0) return BAD_VALUE;
@@ -415,8 +421,8 @@
             return false;
         }
         int sslError = ssl->getError(ret);
-        status_t pollStatus =
-                errorQueue.pollForSslError(fd, sslError, fdTrigger, "SSL_do_handshake", 0, {});
+        status_t pollStatus = errorQueue.pollForSslError(fd, sslError, fdTrigger,
+                                                         "SSL_do_handshake", 0, std::nullopt);
         if (pollStatus != OK) return false;
     }
 }
diff --git a/libs/binder/include/binder/RpcTransport.h b/libs/binder/include/binder/RpcTransport.h
index 2c864f8..ee4b548 100644
--- a/libs/binder/include/binder/RpcTransport.h
+++ b/libs/binder/include/binder/RpcTransport.h
@@ -20,8 +20,10 @@
 
 #include <functional>
 #include <memory>
+#include <optional>
 #include <string>
 
+#include <android-base/function_ref.h>
 #include <android-base/unique_fd.h>
 #include <utils/Errors.h>
 
@@ -65,10 +67,10 @@
      */
     [[nodiscard]] virtual status_t interruptableWriteFully(
             FdTrigger *fdTrigger, iovec *iovs, int niovs,
-            const std::function<status_t()> &altPoll) = 0;
+            const std::optional<android::base::function_ref<status_t()>> &altPoll) = 0;
     [[nodiscard]] virtual status_t interruptableReadFully(
             FdTrigger *fdTrigger, iovec *iovs, int niovs,
-            const std::function<status_t()> &altPoll) = 0;
+            const std::optional<android::base::function_ref<status_t()>> &altPoll) = 0;
 
 protected:
     RpcTransport() = default;
diff --git a/libs/binder/tests/binderAllocationLimits.cpp b/libs/binder/tests/binderAllocationLimits.cpp
index b14ec55..dd1a8c3 100644
--- a/libs/binder/tests/binderAllocationLimits.cpp
+++ b/libs/binder/tests/binderAllocationLimits.cpp
@@ -208,8 +208,8 @@
         totalBytes += bytes;
     });
     CHECK_EQ(OK, remoteBinder->pingBinder());
-    EXPECT_EQ(mallocs, 4);
-    EXPECT_EQ(totalBytes, 108);
+    EXPECT_EQ(mallocs, 3);
+    EXPECT_EQ(totalBytes, 60);
 }
 
 int main(int argc, char** argv) {
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index f85756f..4161a7a 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -1681,7 +1681,8 @@
                                                   FdTrigger* fdTrigger) {
             std::string message(kMessage);
             iovec messageIov{message.data(), message.size()};
-            auto status = serverTransport->interruptableWriteFully(fdTrigger, &messageIov, 1, {});
+            auto status = serverTransport->interruptableWriteFully(fdTrigger, &messageIov, 1,
+                                                                   std::nullopt);
             if (status != OK) return AssertionFailure() << statusToString(status);
             return AssertionSuccess();
         }
@@ -1713,8 +1714,9 @@
             LOG_ALWAYS_FATAL_IF(mClientTransport == nullptr, "setUpTransport not called or failed");
             std::string readMessage(expectedMessage.size(), '\0');
             iovec readMessageIov{readMessage.data(), readMessage.size()};
-            status_t readStatus = mClientTransport->interruptableReadFully(mFdTrigger.get(),
-                                                                           &readMessageIov, 1, {});
+            status_t readStatus =
+                    mClientTransport->interruptableReadFully(mFdTrigger.get(), &readMessageIov, 1,
+                                                             std::nullopt);
             if (readStatus != OK) {
                 return AssertionFailure() << statusToString(readStatus);
             }
@@ -1909,7 +1911,8 @@
     auto serverPostConnect = [&](RpcTransport* serverTransport, FdTrigger* fdTrigger) {
         std::string message(RpcTransportTestUtils::kMessage);
         iovec messageIov{message.data(), message.size()};
-        auto status = serverTransport->interruptableWriteFully(fdTrigger, &messageIov, 1, {});
+        auto status =
+                serverTransport->interruptableWriteFully(fdTrigger, &messageIov, 1, std::nullopt);
         if (status != OK) return AssertionFailure() << statusToString(status);
 
         {
@@ -1920,7 +1923,7 @@
         }
 
         iovec msg2Iov{msg2.data(), msg2.size()};
-        status = serverTransport->interruptableWriteFully(fdTrigger, &msg2Iov, 1, {});
+        status = serverTransport->interruptableWriteFully(fdTrigger, &msg2Iov, 1, std::nullopt);
         if (status != DEAD_OBJECT)
             return AssertionFailure() << "When FdTrigger is shut down, interruptableWriteFully "
                                          "should return DEAD_OBJECT, but it is "