Use std::function instead of base::function_ref

Bug: 302723053
Test: mma
Change-Id: I2a437363135a26dd8459e643ecf87a274da3d26e
diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp
index aa26b91..09032dd 100644
--- a/libs/binder/RpcState.cpp
+++ b/libs/binder/RpcState.cpp
@@ -359,7 +359,7 @@
 status_t RpcState::rpcSend(
         const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
         const char* what, iovec* iovs, int niovs,
-        const std::optional<android::base::function_ref<status_t()>>& altPoll,
+        const std::optional<SmallFunction<status_t()>>& altPoll,
         const std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) {
     for (int i = 0; i < niovs; i++) {
         LOG_RPC_DETAIL("Sending %s (part %d of %d) on RpcTransport %p: %s",
@@ -604,25 +604,24 @@
             {const_cast<uint8_t*>(data.data()), data.dataSize()},
             objectTableSpan.toIovec(),
     };
-    if (status_t status = rpcSend(
-                connection, session, "transaction", iovs, arraysize(iovs),
-                [&] {
-                    if (waitUs > kWaitLogUs) {
-                        ALOGE("Cannot send command, trying to process pending refcounts. Waiting "
-                              "%zuus. Too many oneway calls?",
-                              waitUs);
-                    }
+    auto altPoll = [&] {
+        if (waitUs > kWaitLogUs) {
+            ALOGE("Cannot send command, trying to process pending refcounts. Waiting "
+                  "%zuus. Too many oneway calls?",
+                  waitUs);
+        }
 
-                    if (waitUs > 0) {
-                        usleep(waitUs);
-                        waitUs = std::min(kWaitMaxUs, waitUs * 2);
-                    } else {
-                        waitUs = 1;
-                    }
+        if (waitUs > 0) {
+            usleep(waitUs);
+            waitUs = std::min(kWaitMaxUs, waitUs * 2);
+        } else {
+            waitUs = 1;
+        }
 
-                    return drainCommands(connection, session, CommandType::CONTROL_ONLY);
-                },
-                rpcFields->mFds.get());
+        return drainCommands(connection, session, CommandType::CONTROL_ONLY);
+    };
+    if (status_t status = rpcSend(connection, session, "transaction", iovs, arraysize(iovs),
+                                  std::ref(altPoll), rpcFields->mFds.get());
         status != OK) {
         // rpcSend calls shutdownAndWait, so all refcounts should be reset. If we ever tolerate
         // errors here, then we may need to undo the binder-sent counts for the transaction as
diff --git a/libs/binder/RpcState.h b/libs/binder/RpcState.h
index 1fe71a5..2a954e6 100644
--- a/libs/binder/RpcState.h
+++ b/libs/binder/RpcState.h
@@ -16,6 +16,7 @@
 #pragma once
 
 #include <android-base/unique_fd.h>
+#include <binder/Functional.h>
 #include <binder/IBinder.h>
 #include <binder/Parcel.h>
 #include <binder/RpcSession.h>
@@ -190,7 +191,7 @@
     [[nodiscard]] status_t rpcSend(
             const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
             const char* what, iovec* iovs, int niovs,
-            const std::optional<android::base::function_ref<status_t()>>& altPoll,
+            const std::optional<binder::impl::SmallFunction<status_t()>>& altPoll,
             const std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds =
                     nullptr);
     [[nodiscard]] status_t rpcRec(
diff --git a/libs/binder/RpcTransportRaw.cpp b/libs/binder/RpcTransportRaw.cpp
index c089811..ffa3151 100644
--- a/libs/binder/RpcTransportRaw.cpp
+++ b/libs/binder/RpcTransportRaw.cpp
@@ -29,6 +29,8 @@
 
 namespace android {
 
+using namespace android::binder::impl;
+
 // RpcTransport with TLS disabled.
 class RpcTransportRaw : public RpcTransport {
 public:
@@ -54,7 +56,7 @@
 
     status_t interruptableWriteFully(
             FdTrigger* fdTrigger, iovec* iovs, int niovs,
-            const std::optional<android::base::function_ref<status_t()>>& altPoll,
+            const std::optional<SmallFunction<status_t()>>& altPoll,
             const std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds)
             override {
         bool sentFds = false;
@@ -70,7 +72,7 @@
 
     status_t interruptableReadFully(
             FdTrigger* fdTrigger, iovec* iovs, int niovs,
-            const std::optional<android::base::function_ref<status_t()>>& altPoll,
+            const std::optional<SmallFunction<status_t()>>& altPoll,
             std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) override {
         auto recv = [&](iovec* iovs, int niovs) -> ssize_t {
             return binder::os::receiveMessageFromSocket(mSocket, iovs, niovs, ancillaryFds);
diff --git a/libs/binder/RpcTransportTipcAndroid.cpp b/libs/binder/RpcTransportTipcAndroid.cpp
index 0c81d83..cf0360f 100644
--- a/libs/binder/RpcTransportTipcAndroid.cpp
+++ b/libs/binder/RpcTransportTipcAndroid.cpp
@@ -26,6 +26,7 @@
 #include "RpcState.h"
 #include "RpcTransportUtils.h"
 
+using namespace android::binder::impl;
 using android::base::Error;
 using android::base::Result;
 
@@ -75,7 +76,7 @@
 
     status_t interruptableWriteFully(
             FdTrigger* fdTrigger, iovec* iovs, int niovs,
-            const std::optional<android::base::function_ref<status_t()>>& altPoll,
+            const std::optional<SmallFunction<status_t()>>& altPoll,
             const std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds)
             override {
         auto writeFn = [&](iovec* iovs, size_t niovs) -> ssize_t {
@@ -93,7 +94,7 @@
 
     status_t interruptableReadFully(
             FdTrigger* fdTrigger, iovec* iovs, int niovs,
-            const std::optional<android::base::function_ref<status_t()>>& altPoll,
+            const std::optional<SmallFunction<status_t()>>& altPoll,
             std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* /*ancillaryFds*/)
             override {
         auto readFn = [&](iovec* iovs, size_t niovs) -> ssize_t {
diff --git a/libs/binder/RpcTransportTls.cpp b/libs/binder/RpcTransportTls.cpp
index efb09e9..2cbfd8d 100644
--- a/libs/binder/RpcTransportTls.cpp
+++ b/libs/binder/RpcTransportTls.cpp
@@ -38,6 +38,9 @@
 #endif
 
 namespace android {
+
+using namespace android::binder::impl;
+
 namespace {
 
 // Implement BIO for socket that ignores SIGPIPE.
@@ -181,10 +184,9 @@
     // |sslError| should be from Ssl::getError().
     // If |sslError| is WANT_READ / WANT_WRITE, poll for POLLIN / POLLOUT respectively. Otherwise
     // return error. Also return error if |fdTrigger| is triggered before or during poll().
-    status_t pollForSslError(
-            const android::RpcTransportFd& fd, int sslError, FdTrigger* fdTrigger,
-            const char* fnString, int additionalEvent,
-            const std::optional<android::base::function_ref<status_t()>>& altPoll) {
+    status_t pollForSslError(const android::RpcTransportFd& fd, int sslError, FdTrigger* fdTrigger,
+                             const char* fnString, int additionalEvent,
+                             const std::optional<SmallFunction<status_t()>>& altPoll) {
         switch (sslError) {
             case SSL_ERROR_WANT_READ:
                 return handlePoll(POLLIN | additionalEvent, fd, fdTrigger, fnString, altPoll);
@@ -200,7 +202,7 @@
 
     status_t handlePoll(int event, const android::RpcTransportFd& fd, FdTrigger* fdTrigger,
                         const char* fnString,
-                        const std::optional<android::base::function_ref<status_t()>>& altPoll) {
+                        const std::optional<SmallFunction<status_t()>>& altPoll) {
         status_t ret;
         if (altPoll) {
             ret = (*altPoll)();
@@ -284,12 +286,12 @@
     status_t pollRead(void) override;
     status_t interruptableWriteFully(
             FdTrigger* fdTrigger, iovec* iovs, int niovs,
-            const std::optional<android::base::function_ref<status_t()>>& altPoll,
+            const std::optional<SmallFunction<status_t()>>& altPoll,
             const std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds)
             override;
     status_t interruptableReadFully(
             FdTrigger* fdTrigger, iovec* iovs, int niovs,
-            const std::optional<android::base::function_ref<status_t()>>& altPoll,
+            const std::optional<SmallFunction<status_t()>>& altPoll,
             std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) override;
 
     bool isWaiting() override { return mSocket.isInPollingState(); };
@@ -320,7 +322,7 @@
 
 status_t RpcTransportTls::interruptableWriteFully(
         FdTrigger* fdTrigger, iovec* iovs, int niovs,
-        const std::optional<android::base::function_ref<status_t()>>& altPoll,
+        const std::optional<SmallFunction<status_t()>>& altPoll,
         const std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) {
     (void)ancillaryFds;
 
@@ -366,7 +368,7 @@
 
 status_t RpcTransportTls::interruptableReadFully(
         FdTrigger* fdTrigger, iovec* iovs, int niovs,
-        const std::optional<android::base::function_ref<status_t()>>& altPoll,
+        const std::optional<SmallFunction<status_t()>>& altPoll,
         std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) {
     (void)ancillaryFds;
 
diff --git a/libs/binder/RpcTransportUtils.h b/libs/binder/RpcTransportUtils.h
index 32f0db8..a0e502e 100644
--- a/libs/binder/RpcTransportUtils.h
+++ b/libs/binder/RpcTransportUtils.h
@@ -27,7 +27,7 @@
 status_t interruptableReadOrWrite(
         const android::RpcTransportFd& socket, FdTrigger* fdTrigger, iovec* iovs, int niovs,
         SendOrReceive sendOrReceiveFun, const char* funName, int16_t event,
-        const std::optional<android::base::function_ref<status_t()>>& altPoll) {
+        const std::optional<binder::impl::SmallFunction<status_t()>>& altPoll) {
     MAYBE_WAIT_IN_FLAKE_MODE;
 
     if (niovs < 0) {
diff --git a/libs/binder/include/binder/Functional.h b/libs/binder/include/binder/Functional.h
index 058f833..08e3b21 100644
--- a/libs/binder/include/binder/Functional.h
+++ b/libs/binder/include/binder/Functional.h
@@ -38,4 +38,13 @@
     return {reinterpret_cast<void*>(true), std::bind(f)};
 }
 
+template <typename T>  // T is the call signature, e.g. SmallFunction<status_t()>.
+class SmallFunction : public std::function<T> {  // std::function plus a check on the callable.
+public:
+    template <typename F>  // Deliberately non-explicit: callers pass lambdas / std::ref directly.
+    SmallFunction(F&& f) : std::function<T>(static_cast<F&&>(f)) {  // Forward f; don't copy rvalues.
+        assert_small_callable<F>();  // Presumably a compile-time size check on F; confirm above.
+    }
+};
+
 } // namespace android::binder::impl
diff --git a/libs/binder/include/binder/RpcTransport.h b/libs/binder/include/binder/RpcTransport.h
index 6db9ad9..115a173 100644
--- a/libs/binder/include/binder/RpcTransport.h
+++ b/libs/binder/include/binder/RpcTransport.h
@@ -25,10 +25,10 @@
 #include <variant>
 #include <vector>
 
-#include <android-base/function_ref.h>
 #include <android-base/unique_fd.h>
 #include <utils/Errors.h>
 
+#include <binder/Functional.h>
 #include <binder/RpcCertificateFormat.h>
 #include <binder/RpcThreads.h>
 
@@ -85,13 +85,13 @@
      *   error - interrupted (failure or trigger)
      */
     [[nodiscard]] virtual status_t interruptableWriteFully(
-            FdTrigger *fdTrigger, iovec *iovs, int niovs,
-            const std::optional<android::base::function_ref<status_t()>> &altPoll,
-            const std::vector<std::variant<base::unique_fd, base::borrowed_fd>> *ancillaryFds) = 0;
+            FdTrigger* fdTrigger, iovec* iovs, int niovs,
+            const std::optional<binder::impl::SmallFunction<status_t()>>& altPoll,
+            const std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) = 0;
     [[nodiscard]] virtual status_t interruptableReadFully(
-            FdTrigger *fdTrigger, iovec *iovs, int niovs,
-            const std::optional<android::base::function_ref<status_t()>> &altPoll,
-            std::vector<std::variant<base::unique_fd, base::borrowed_fd>> *ancillaryFds) = 0;
+            FdTrigger* fdTrigger, iovec* iovs, int niovs,
+            const std::optional<binder::impl::SmallFunction<status_t()>>& altPoll,
+            std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) = 0;
 
     /**
      *  Check whether any threads are blocked while polling the transport
diff --git a/libs/binder/trusty/RpcTransportTipcTrusty.cpp b/libs/binder/trusty/RpcTransportTipcTrusty.cpp
index 692f82d..6bb45e2 100644
--- a/libs/binder/trusty/RpcTransportTipcTrusty.cpp
+++ b/libs/binder/trusty/RpcTransportTipcTrusty.cpp
@@ -29,6 +29,8 @@
 
 namespace android {
 
+using namespace android::binder::impl;
+
 // RpcTransport for Trusty.
 class RpcTransportTipcTrusty : public RpcTransport {
 public:
@@ -45,7 +47,7 @@
 
     status_t interruptableWriteFully(
             FdTrigger* /*fdTrigger*/, iovec* iovs, int niovs,
-            const std::optional<android::base::function_ref<status_t()>>& /*altPoll*/,
+            const std::optional<SmallFunction<status_t()>>& /*altPoll*/,
             const std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds)
             override {
         if (niovs < 0) {
@@ -115,7 +117,7 @@
 
     status_t interruptableReadFully(
             FdTrigger* /*fdTrigger*/, iovec* iovs, int niovs,
-            const std::optional<android::base::function_ref<status_t()>>& /*altPoll*/,
+            const std::optional<SmallFunction<status_t()>>& /*altPoll*/,
             std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) override {
         if (niovs < 0) {
             return BAD_VALUE;