libbinder: RPC handle built-up refcounts
Generally, in the binder RPC wire protocol, the client and the server do
not both write data into a socket at the same time. However, in the case
of async transactions, this happens in an unbounded way: a client may
send many oneway transactions while the server sends back refcounting
information related to those transactions (which we process lazily).
To prevent this from building up, when sending a transaction, if we are
unable to write it, drain that reference-counting information instead of
waiting (see the sketch below).
Bug: 182940634
Test: binderRpcTest (no longer deadlocks in OnewayStressTest)
Test: manually check that 'drainCommands' happens in both raw and tls cases
during this test (confirming we actually get coverage)
Change-Id: I82039d6188196261b22316e95d8e180c4c33ae73
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp
index 5733993..cf95fda 100644
--- a/libs/binder/RpcServer.cpp
+++ b/libs/binder/RpcServer.cpp
@@ -275,7 +275,7 @@
RpcConnectionHeader header;
if (status == OK) {
status = client->interruptableReadFully(server->mShutdownTrigger.get(), &header,
- sizeof(header));
+ sizeof(header), {});
if (status != OK) {
ALOGE("Failed to read ID for client connecting to RPC server: %s",
statusToString(status).c_str());
@@ -288,7 +288,7 @@
if (header.sessionIdSize > 0) {
sessionId.resize(header.sessionIdSize);
status = client->interruptableReadFully(server->mShutdownTrigger.get(),
- sessionId.data(), sessionId.size());
+ sessionId.data(), sessionId.size(), {});
if (status != OK) {
ALOGE("Failed to read session ID for client connecting to RPC server: %s",
statusToString(status).c_str());
@@ -313,7 +313,7 @@
};
status = client->interruptableWriteFully(server->mShutdownTrigger.get(), &response,
- sizeof(response));
+ sizeof(response), {});
if (status != OK) {
ALOGE("Failed to send new session response: %s", statusToString(status).c_str());
// still need to cleanup before we can return
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index dafb339..65f6bc6 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -560,7 +560,7 @@
}
auto sendHeaderStatus =
- server->interruptableWriteFully(mShutdownTrigger.get(), &header, sizeof(header));
+ server->interruptableWriteFully(mShutdownTrigger.get(), &header, sizeof(header), {});
if (sendHeaderStatus != OK) {
ALOGE("Could not write connection header to socket: %s",
statusToString(sendHeaderStatus).c_str());
@@ -570,7 +570,7 @@
if (sessionId.size() > 0) {
auto sendSessionIdStatus =
server->interruptableWriteFully(mShutdownTrigger.get(), sessionId.data(),
- sessionId.size());
+ sessionId.size(), {});
if (sendSessionIdStatus != OK) {
ALOGE("Could not write session ID ('%s') to socket: %s",
base::HexString(sessionId.data(), sessionId.size()).c_str(),
diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp
index 5bf6bb7..a3079fe 100644
--- a/libs/binder/RpcState.cpp
+++ b/libs/binder/RpcState.cpp
@@ -307,7 +307,7 @@
status_t RpcState::rpcSend(const sp<RpcSession::RpcConnection>& connection,
const sp<RpcSession>& session, const char* what, const void* data,
- size_t size) {
+ size_t size, const std::function<status_t()>& altPoll) {
LOG_RPC_DETAIL("Sending %s on RpcTransport %p: %s", what, connection->rpcTransport.get(),
android::base::HexString(data, size).c_str());
@@ -319,7 +319,7 @@
if (status_t status =
connection->rpcTransport->interruptableWriteFully(session->mShutdownTrigger.get(),
- data, size);
+ data, size, altPoll);
status != OK) {
LOG_RPC_DETAIL("Failed to write %s (%zu bytes) on RpcTransport %p, error: %s", what, size,
connection->rpcTransport.get(), statusToString(status).c_str());
@@ -341,7 +341,7 @@
if (status_t status =
connection->rpcTransport->interruptableReadFully(session->mShutdownTrigger.get(),
- data, size);
+ data, size, {});
status != OK) {
LOG_RPC_DETAIL("Failed to read %s (%zu bytes) on RpcTransport %p, error: %s", what, size,
connection->rpcTransport.get(), statusToString(status).c_str());
@@ -519,21 +519,44 @@
memcpy(transactionData.data() + sizeof(RpcWireHeader) + sizeof(RpcWireTransaction), data.data(),
data.dataSize());
+ constexpr size_t kWaitMaxUs = 1000000;
+ constexpr size_t kWaitLogUs = 10000;
+ size_t waitUs = 0;
+
+    // Oneway calls have no sync point, so if many oneway calls were sent
+    // before this transaction (whether it is twoway or oneway), they may have
+    // filled up the socket. So, make sure we drain pending commands before
+    // polling.
+ std::function<status_t()> drainRefs = [&] {
+ if (waitUs > kWaitLogUs) {
+ ALOGE("Cannot send command, trying to process pending refcounts. Waiting %zuus. Too "
+ "many oneway calls?",
+ waitUs);
+ }
+
+ if (waitUs > 0) {
+ usleep(waitUs);
+ waitUs = std::min(kWaitMaxUs, waitUs * 2);
+ } else {
+ waitUs = 1;
+ }
+
+ return drainCommands(connection, session, CommandType::CONTROL_ONLY);
+ };
+
if (status_t status = rpcSend(connection, session, "transaction", transactionData.data(),
- transactionData.size());
- status != OK)
+ transactionData.size(), drainRefs);
+ status != OK) {
// TODO(b/167966510): need to undo onBinderLeaving - we know the
// refcount isn't successfully transferred.
return status;
+ }
if (flags & IBinder::FLAG_ONEWAY) {
LOG_RPC_DETAIL("Oneway command, so no longer waiting on RpcTransport %p",
connection->rpcTransport.get());
// Do not wait on result.
- // However, too many oneway calls may cause refcounts to build up and fill up the socket,
- // so process those.
- return drainCommands(connection, session, CommandType::CONTROL_ONLY);
+ return OK;
}
LOG_ALWAYS_FATAL_IF(reply == nullptr, "Reply parcel must be used for synchronous transaction.");
diff --git a/libs/binder/RpcState.h b/libs/binder/RpcState.h
index 42e95e0..50de22b 100644
--- a/libs/binder/RpcState.h
+++ b/libs/binder/RpcState.h
@@ -177,7 +177,8 @@
[[nodiscard]] status_t rpcSend(const sp<RpcSession::RpcConnection>& connection,
const sp<RpcSession>& session, const char* what,
- const void* data, size_t size);
+ const void* data, size_t size,
+ const std::function<status_t()>& altPoll = nullptr);
[[nodiscard]] status_t rpcRec(const sp<RpcSession::RpcConnection>& connection,
const sp<RpcSession>& session, const char* what, void* data,
size_t size);
diff --git a/libs/binder/RpcTransportRaw.cpp b/libs/binder/RpcTransportRaw.cpp
index a22bc6f..7669518 100644
--- a/libs/binder/RpcTransportRaw.cpp
+++ b/libs/binder/RpcTransportRaw.cpp
@@ -46,7 +46,7 @@
template <typename Buffer, typename SendOrReceive>
status_t interruptableReadOrWrite(FdTrigger* fdTrigger, Buffer buffer, size_t size,
SendOrReceive sendOrReceiveFun, const char* funName,
- int16_t event) {
+ int16_t event, const std::function<status_t()>& altPoll) {
const Buffer end = buffer + size;
MAYBE_WAIT_IN_FLAKE_MODE;
@@ -57,9 +57,8 @@
return DEAD_OBJECT;
}
- bool first = true;
- status_t status;
- do {
+ bool havePolled = false;
+ while (true) {
ssize_t processSize = TEMP_FAILURE_RETRY(
sendOrReceiveFun(mSocket.get(), buffer, end - buffer, MSG_NOSIGNAL));
@@ -68,7 +67,8 @@
// Still return the error on later passes, since it would expose
// a problem with polling
- if (!first || (first && savedErrno != EAGAIN && savedErrno != EWOULDBLOCK)) {
+ if (havePolled ||
+ (!havePolled && savedErrno != EAGAIN && savedErrno != EWOULDBLOCK)) {
LOG_RPC_DETAIL("RpcTransport %s(): %s", funName, strerror(savedErrno));
return -savedErrno;
}
@@ -81,19 +81,30 @@
}
}
- if (first) first = false;
- } while ((status = fdTrigger->triggerablePoll(mSocket.get(), event)) == OK);
- return status;
+ if (altPoll) {
+ if (status_t status = altPoll(); status != OK) return status;
+ if (fdTrigger->isTriggered()) {
+ return DEAD_OBJECT;
+ }
+ } else {
+ if (status_t status = fdTrigger->triggerablePoll(mSocket.get(), event);
+ status != OK)
+ return status;
+ if (!havePolled) havePolled = true;
+ }
+ }
}
- status_t interruptableWriteFully(FdTrigger* fdTrigger, const void* data, size_t size) override {
+ status_t interruptableWriteFully(FdTrigger* fdTrigger, const void* data, size_t size,
+ const std::function<status_t()>& altPoll) override {
return interruptableReadOrWrite(fdTrigger, reinterpret_cast<const uint8_t*>(data), size,
- send, "send", POLLOUT);
+ send, "send", POLLOUT, altPoll);
}
- status_t interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size) override {
+ status_t interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size,
+ const std::function<status_t()>& altPoll) override {
return interruptableReadOrWrite(fdTrigger, reinterpret_cast<uint8_t*>(data), size, recv,
- "recv", POLLIN);
+ "recv", POLLIN, altPoll);
}
private:
diff --git a/libs/binder/RpcTransportTls.cpp b/libs/binder/RpcTransportTls.cpp
index f8cd71d..7f810b1 100644
--- a/libs/binder/RpcTransportTls.cpp
+++ b/libs/binder/RpcTransportTls.cpp
@@ -169,12 +169,13 @@
// If |sslError| is WANT_READ / WANT_WRITE, poll for POLLIN / POLLOUT respectively. Otherwise
// return error. Also return error if |fdTrigger| is triggered before or during poll().
status_t pollForSslError(android::base::borrowed_fd fd, int sslError, FdTrigger* fdTrigger,
- const char* fnString, int additionalEvent = 0) {
+ const char* fnString, int additionalEvent,
+ const std::function<status_t()>& altPoll) {
switch (sslError) {
case SSL_ERROR_WANT_READ:
- return handlePoll(POLLIN | additionalEvent, fd, fdTrigger, fnString);
+ return handlePoll(POLLIN | additionalEvent, fd, fdTrigger, fnString, altPoll);
case SSL_ERROR_WANT_WRITE:
- return handlePoll(POLLOUT | additionalEvent, fd, fdTrigger, fnString);
+ return handlePoll(POLLOUT | additionalEvent, fd, fdTrigger, fnString, altPoll);
case SSL_ERROR_SYSCALL: {
auto queue = toString();
LOG_TLS_DETAIL("%s(): %s. Treating as DEAD_OBJECT. Error queue: %s", fnString,
@@ -194,11 +195,17 @@
bool mHandled = false;
status_t handlePoll(int event, android::base::borrowed_fd fd, FdTrigger* fdTrigger,
- const char* fnString) {
- status_t ret = fdTrigger->triggerablePoll(fd, event);
+ const char* fnString, const std::function<status_t()>& altPoll) {
+ status_t ret;
+ if (altPoll) {
+ ret = altPoll();
+ if (fdTrigger->isTriggered()) ret = DEAD_OBJECT;
+ } else {
+ ret = fdTrigger->triggerablePoll(fd, event);
+ }
+
if (ret != OK && ret != DEAD_OBJECT) {
- ALOGE("triggerablePoll error while poll()-ing after %s(): %s", fnString,
- statusToString(ret).c_str());
+            ALOGE("poll error after %s(): %s", fnString, statusToString(ret).c_str());
}
clear();
return ret;
@@ -268,8 +275,10 @@
RpcTransportTls(android::base::unique_fd socket, Ssl ssl)
: mSocket(std::move(socket)), mSsl(std::move(ssl)) {}
Result<size_t> peek(void* buf, size_t size) override;
- status_t interruptableWriteFully(FdTrigger* fdTrigger, const void* data, size_t size) override;
- status_t interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size) override;
+ status_t interruptableWriteFully(FdTrigger* fdTrigger, const void* data, size_t size,
+ const std::function<status_t()>& altPoll) override;
+ status_t interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size,
+ const std::function<status_t()>& altPoll) override;
private:
android::base::unique_fd mSocket;
@@ -295,7 +304,8 @@
}
status_t RpcTransportTls::interruptableWriteFully(FdTrigger* fdTrigger, const void* data,
- size_t size) {
+ size_t size,
+ const std::function<status_t()>& altPoll) {
auto buffer = reinterpret_cast<const uint8_t*>(data);
const uint8_t* end = buffer + size;
@@ -317,8 +327,8 @@
int sslError = mSsl.getError(writeSize);
// TODO(b/195788248): BIO should contain the FdTrigger, and send(2) / recv(2) should be
// triggerablePoll()-ed. Then additionalEvent is no longer necessary.
- status_t pollStatus =
- errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger, "SSL_write", POLLIN);
+ status_t pollStatus = errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger,
+ "SSL_write", POLLIN, altPoll);
if (pollStatus != OK) return pollStatus;
// Do not advance buffer. Try SSL_write() again.
}
@@ -326,7 +336,8 @@
return OK;
}
-status_t RpcTransportTls::interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size) {
+status_t RpcTransportTls::interruptableReadFully(FdTrigger* fdTrigger, void* data, size_t size,
+ const std::function<status_t()>& altPoll) {
auto buffer = reinterpret_cast<uint8_t*>(data);
uint8_t* end = buffer + size;
@@ -350,8 +361,8 @@
return DEAD_OBJECT;
}
int sslError = mSsl.getError(readSize);
- status_t pollStatus =
- errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger, "SSL_read");
+ status_t pollStatus = errorQueue.pollForSslError(mSocket.get(), sslError, fdTrigger,
+ "SSL_read", 0, altPoll);
if (pollStatus != OK) return pollStatus;
// Do not advance buffer. Try SSL_read() again.
}
@@ -382,7 +393,7 @@
}
int sslError = ssl->getError(ret);
status_t pollStatus =
- errorQueue.pollForSslError(fd, sslError, fdTrigger, "SSL_do_handshake");
+ errorQueue.pollForSslError(fd, sslError, fdTrigger, "SSL_do_handshake", 0, {});
if (pollStatus != OK) return false;
}
}
diff --git a/libs/binder/include/binder/RpcTransport.h b/libs/binder/include/binder/RpcTransport.h
index 4fe2324..db8b5e9 100644
--- a/libs/binder/include/binder/RpcTransport.h
+++ b/libs/binder/include/binder/RpcTransport.h
@@ -18,6 +18,7 @@
#pragma once
+#include <functional>
#include <memory>
#include <string>
@@ -43,14 +44,20 @@
/**
* Read (or write), but allow to be interrupted by a trigger.
*
+     * altPoll - function to call instead of polling when a read/write needs to
+     * wait for the socket. If it returns an error, that error is returned from
+     * this function.
+ *
* Return:
* OK - succeeded in completely processing 'size'
* error - interrupted (failure or trigger)
*/
- [[nodiscard]] virtual status_t interruptableWriteFully(FdTrigger *fdTrigger, const void *buf,
- size_t size) = 0;
- [[nodiscard]] virtual status_t interruptableReadFully(FdTrigger *fdTrigger, void *buf,
- size_t size) = 0;
+ [[nodiscard]] virtual status_t interruptableWriteFully(
+ FdTrigger *fdTrigger, const void *buf, size_t size,
+ const std::function<status_t()> &altPoll) = 0;
+ [[nodiscard]] virtual status_t interruptableReadFully(
+ FdTrigger *fdTrigger, void *buf, size_t size,
+ const std::function<status_t()> &altPoll) = 0;
protected:
RpcTransport() = default;
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index aad0162..ed477d6 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -1574,7 +1574,7 @@
FdTrigger* fdTrigger) {
std::string message(kMessage);
auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(),
- message.size());
+ message.size(), {});
if (status != OK) return AssertionFailure() << statusToString(status);
return AssertionSuccess();
}
@@ -1607,7 +1607,7 @@
std::string readMessage(expectedMessage.size(), '\0');
status_t readStatus =
mClientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(),
- readMessage.size());
+ readMessage.size(), {});
if (readStatus != OK) {
return AssertionFailure() << statusToString(readStatus);
}
@@ -1801,8 +1801,8 @@
bool shouldContinueWriting = false;
auto serverPostConnect = [&](RpcTransport* serverTransport, FdTrigger* fdTrigger) {
std::string message(RpcTransportTestUtils::kMessage);
- auto status =
- serverTransport->interruptableWriteFully(fdTrigger, message.data(), message.size());
+ auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(),
+ message.size(), {});
if (status != OK) return AssertionFailure() << statusToString(status);
{
@@ -1812,7 +1812,7 @@
}
}
- status = serverTransport->interruptableWriteFully(fdTrigger, msg2.data(), msg2.size());
+ status = serverTransport->interruptableWriteFully(fdTrigger, msg2.data(), msg2.size(), {});
if (status != DEAD_OBJECT)
return AssertionFailure() << "When FdTrigger is shut down, interruptableWriteFully "
"should return DEAD_OBJECT, but it is "