binder: Use RpcTransport
- after accept() / connect(), call sslAccept() /
sslConnect(), respectively.
- replace ::send() / ::recv() with RpcTransport::
send() / recv() / peek() accordingly.
Also refactor binderRpcTest to prepare for TLS implementation.
Test: TH
Test: binderRpcTest
Bug: 190868302
Change-Id: I809345c59a467cd219ebcec7a9db3a3b7776a601
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp
index 62ea187..6ab25e0 100644
--- a/libs/binder/RpcServer.cpp
+++ b/libs/binder/RpcServer.cpp
@@ -26,6 +26,7 @@
#include <android-base/scopeguard.h>
#include <binder/Parcel.h>
#include <binder/RpcServer.h>
+#include <binder/RpcTransportRaw.h>
#include <log/log.h>
#include "RpcSocketAddress.h"
@@ -37,13 +38,17 @@
using base::ScopeGuard;
using base::unique_fd;
-RpcServer::RpcServer() {}
+RpcServer::RpcServer(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory)
+ : mRpcTransportCtxFactory(std::move(rpcTransportCtxFactory)) {}
RpcServer::~RpcServer() {
(void)shutdown();
}
-sp<RpcServer> RpcServer::make() {
- return sp<RpcServer>::make();
+sp<RpcServer> RpcServer::make(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory) {
+ // Default is without TLS.
+ if (rpcTransportCtxFactory == nullptr)
+ rpcTransportCtxFactory = RpcTransportCtxFactoryRaw::make();
+ return sp<RpcServer>::make(std::move(rpcTransportCtxFactory));
}
void RpcServer::iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction() {
@@ -154,6 +159,10 @@
mJoinThreadRunning = true;
mShutdownTrigger = RpcSession::FdTrigger::make();
LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr, "Cannot create join signaler");
+
+ mCtx = mRpcTransportCtxFactory->newServerCtx();
+ LOG_ALWAYS_FATAL_IF(mCtx == nullptr, "Unable to create RpcTransportCtx with %s sockets",
+ mRpcTransportCtxFactory->toCString());
}
status_t status;
@@ -220,6 +229,7 @@
LOG_RPC_DETAIL("Finished waiting on shutdown.");
mShutdownTrigger = nullptr;
+ mCtx = nullptr;
return true;
}
@@ -245,14 +255,29 @@
// mShutdownTrigger can only be cleared once connection threads have joined.
// It must be set before this thread is started
LOG_ALWAYS_FATAL_IF(server->mShutdownTrigger == nullptr);
+ LOG_ALWAYS_FATAL_IF(server->mCtx == nullptr);
+
+ status_t status = OK;
+
+ int clientFdForLog = clientFd.get();
+ auto client = server->mCtx->newTransport(std::move(clientFd));
+ if (client == nullptr) {
+ ALOGE("Dropping accept4()-ed socket because sslAccept fails");
+ status = DEAD_OBJECT;
+ // still need to cleanup before we can return
+ } else {
+ LOG_RPC_DETAIL("Created RpcTransport %p for client fd %d", client.get(), clientFdForLog);
+ }
RpcConnectionHeader header;
- status_t status = server->mShutdownTrigger->interruptableReadFully(clientFd.get(), &header,
- sizeof(header));
- if (status != OK) {
- ALOGE("Failed to read ID for client connecting to RPC server: %s",
- statusToString(status).c_str());
- // still need to cleanup before we can return
+ if (status == OK) {
+ status = server->mShutdownTrigger->interruptableReadFully(client.get(), &header,
+ sizeof(header));
+ if (status != OK) {
+ ALOGE("Failed to read ID for client connecting to RPC server: %s",
+ statusToString(status).c_str());
+ // still need to cleanup before we can return
+ }
}
bool incoming = false;
@@ -272,7 +297,7 @@
.version = protocolVersion,
};
- status = server->mShutdownTrigger->interruptableWriteFully(clientFd.get(), &response,
+ status = server->mShutdownTrigger->interruptableWriteFully(client.get(), &response,
sizeof(response));
if (status != OK) {
ALOGE("Failed to send new session response: %s", statusToString(status).c_str());
@@ -342,7 +367,7 @@
}
if (incoming) {
- LOG_ALWAYS_FATAL_IF(!session->addOutgoingConnection(std::move(clientFd), true),
+ LOG_ALWAYS_FATAL_IF(!session->addOutgoingConnection(std::move(client), true),
"server state must already be initialized");
return;
}
@@ -351,7 +376,7 @@
session->preJoinThreadOwnership(std::move(thisThread));
}
- auto setupResult = session->preJoinSetup(std::move(clientFd));
+ auto setupResult = session->preJoinSetup(std::move(client));
// avoid strong cycle
server = nullptr;
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index 254b99c..6a22913 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -30,6 +30,7 @@
#include <android_runtime/vm.h>
#include <binder/Parcel.h>
#include <binder/RpcServer.h>
+#include <binder/RpcTransportRaw.h>
#include <binder/Stability.h>
#include <jni.h>
#include <utils/String8.h>
@@ -46,7 +47,8 @@
using base::unique_fd;
-RpcSession::RpcSession() {
+RpcSession::RpcSession(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory)
+ : mRpcTransportCtxFactory(std::move(rpcTransportCtxFactory)) {
LOG_RPC_DETAIL("RpcSession created %p", this);
mState = std::make_unique<RpcState>();
@@ -59,8 +61,11 @@
"Should not be able to destroy a session with servers in use.");
}
-sp<RpcSession> RpcSession::make() {
- return sp<RpcSession>::make();
+sp<RpcSession> RpcSession::make(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory) {
+ // Default is without TLS.
+ if (rpcTransportCtxFactory == nullptr)
+ rpcTransportCtxFactory = RpcTransportCtxFactoryRaw::make();
+ return sp<RpcSession>::make(std::move(rpcTransportCtxFactory));
}
void RpcSession::setMaxThreads(size_t threads) {
@@ -122,6 +127,7 @@
}
bool RpcSession::addNullDebuggingClient() {
+ // Note: only works on raw sockets.
unique_fd serverFd(TEMP_FAILURE_RETRY(open("/dev/null", O_WRONLY | O_CLOEXEC)));
if (serverFd == -1) {
@@ -129,7 +135,17 @@
return false;
}
- return addOutgoingConnection(std::move(serverFd), false);
+ auto ctx = mRpcTransportCtxFactory->newClientCtx();
+ if (ctx == nullptr) {
+ ALOGE("Unable to create RpcTransportCtx for null debugging client");
+ return false;
+ }
+ auto server = ctx->newTransport(std::move(serverFd));
+ if (server == nullptr) {
+ ALOGE("Unable to set up RpcTransport");
+ return false;
+ }
+ return addOutgoingConnection(std::move(server), false);
}
sp<IBinder> RpcSession::getRootObject() {
@@ -205,6 +221,10 @@
return mWrite == -1;
}
+status_t RpcSession::FdTrigger::triggerablePoll(RpcTransport* rpcTransport, int16_t event) {
+ return triggerablePoll(rpcTransport->pollSocket(), event);
+}
+
status_t RpcSession::FdTrigger::triggerablePoll(base::borrowed_fd fd, int16_t event) {
while (true) {
pollfd pfd[]{{.fd = fd.get(), .events = static_cast<int16_t>(event), .revents = 0},
@@ -223,28 +243,30 @@
}
}
-status_t RpcSession::FdTrigger::interruptableWriteFully(base::borrowed_fd fd, const void* data,
- size_t size) {
+status_t RpcSession::FdTrigger::interruptableWriteFully(RpcTransport* rpcTransport,
+ const void* data, size_t size) {
const uint8_t* buffer = reinterpret_cast<const uint8_t*>(data);
const uint8_t* end = buffer + size;
MAYBE_WAIT_IN_FLAKE_MODE;
status_t status;
- while ((status = triggerablePoll(fd, POLLOUT)) == OK) {
- ssize_t writeSize = TEMP_FAILURE_RETRY(send(fd.get(), buffer, end - buffer, MSG_NOSIGNAL));
- if (writeSize == 0) return DEAD_OBJECT;
-
- if (writeSize < 0) {
- return -errno;
+ while ((status = triggerablePoll(rpcTransport, POLLOUT)) == OK) {
+ auto writeSize = rpcTransport->send(buffer, end - buffer);
+ if (!writeSize.ok()) {
+ LOG_RPC_DETAIL("RpcTransport::send(): %s", writeSize.error().message().c_str());
+ return writeSize.error().code() == 0 ? UNKNOWN_ERROR : -writeSize.error().code();
}
- buffer += writeSize;
+
+ if (*writeSize == 0) return DEAD_OBJECT;
+
+ buffer += *writeSize;
if (buffer == end) return OK;
}
return status;
}
-status_t RpcSession::FdTrigger::interruptableReadFully(base::borrowed_fd fd, void* data,
+status_t RpcSession::FdTrigger::interruptableReadFully(RpcTransport* rpcTransport, void* data,
size_t size) {
uint8_t* buffer = reinterpret_cast<uint8_t*>(data);
uint8_t* end = buffer + size;
@@ -252,14 +274,16 @@
MAYBE_WAIT_IN_FLAKE_MODE;
status_t status;
- while ((status = triggerablePoll(fd, POLLIN)) == OK) {
- ssize_t readSize = TEMP_FAILURE_RETRY(recv(fd.get(), buffer, end - buffer, MSG_NOSIGNAL));
- if (readSize == 0) return DEAD_OBJECT; // EOF
-
- if (readSize < 0) {
- return -errno;
+ while ((status = triggerablePoll(rpcTransport, POLLIN)) == OK) {
+ auto readSize = rpcTransport->recv(buffer, end - buffer);
+ if (!readSize.ok()) {
+ LOG_RPC_DETAIL("RpcTransport::recv(): %s", readSize.error().message().c_str());
+ return readSize.error().code() == 0 ? UNKNOWN_ERROR : -readSize.error().code();
}
- buffer += readSize;
+
+ if (*readSize == 0) return DEAD_OBJECT; // EOF
+
+ buffer += *readSize;
if (buffer == end) return OK;
}
return status;
@@ -312,10 +336,11 @@
}
}
-RpcSession::PreJoinSetupResult RpcSession::preJoinSetup(base::unique_fd fd) {
+RpcSession::PreJoinSetupResult RpcSession::preJoinSetup(
+ std::unique_ptr<RpcTransport> rpcTransport) {
// must be registered to allow arbitrary client code executing commands to
// be able to do nested calls (we can't only read from it)
- sp<RpcConnection> connection = assignIncomingConnectionToThisThread(std::move(fd));
+ sp<RpcConnection> connection = assignIncomingConnectionToThisThread(std::move(rpcTransport));
status_t status;
@@ -520,6 +545,22 @@
strerror(savedErrno));
return false;
}
+ LOG_RPC_DETAIL("Socket at %s client with fd %d", addr.toString().c_str(), serverFd.get());
+
+ auto ctx = mRpcTransportCtxFactory->newClientCtx();
+ if (ctx == nullptr) {
+ ALOGE("Unable to create client RpcTransportCtx with %s sockets",
+ mRpcTransportCtxFactory->toCString());
+ return false;
+ }
+ auto server = ctx->newTransport(std::move(serverFd));
+ if (server == nullptr) {
+ ALOGE("Unable to set up RpcTransport for %s", addr.toString().c_str());
+ return false;
+ }
+
+ LOG_RPC_DETAIL("Socket at %s client with RpcTransport %p", addr.toString().c_str(),
+ server.get());
RpcConnectionHeader header{
.version = mProtocolVersion.value_or(RPC_WIRE_PROTOCOL_VERSION),
@@ -529,19 +570,24 @@
if (incoming) header.options |= RPC_CONNECTION_OPTION_INCOMING;
- if (sizeof(header) != TEMP_FAILURE_RETRY(write(serverFd.get(), &header, sizeof(header)))) {
- int savedErrno = errno;
+ auto sentHeader = server->send(&header, sizeof(header));
+ if (!sentHeader.ok()) {
ALOGE("Could not write connection header to socket at %s: %s", addr.toString().c_str(),
- strerror(savedErrno));
+ sentHeader.error().message().c_str());
+ return false;
+ }
+ if (*sentHeader != sizeof(header)) {
+ ALOGE("Could not write connection header to socket at %s: sent %zd bytes, expected %zd",
+ addr.toString().c_str(), *sentHeader, sizeof(header));
return false;
}
- LOG_RPC_DETAIL("Socket at %s client with fd %d", addr.toString().c_str(), serverFd.get());
+ LOG_RPC_DETAIL("Socket at %s client: header sent", addr.toString().c_str());
if (incoming) {
- return addIncomingConnection(std::move(serverFd));
+ return addIncomingConnection(std::move(server));
} else {
- return addOutgoingConnection(std::move(serverFd), true);
+ return addOutgoingConnection(std::move(server), true);
}
}
@@ -549,7 +595,7 @@
return false;
}
-bool RpcSession::addIncomingConnection(unique_fd fd) {
+bool RpcSession::addIncomingConnection(std::unique_ptr<RpcTransport> rpcTransport) {
std::mutex mutex;
std::condition_variable joinCv;
std::unique_lock<std::mutex> lock(mutex);
@@ -558,13 +604,13 @@
bool ownershipTransferred = false;
thread = std::thread([&]() {
std::unique_lock<std::mutex> threadLock(mutex);
- unique_fd movedFd = std::move(fd);
+ std::unique_ptr<RpcTransport> movedRpcTransport = std::move(rpcTransport);
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
sp<RpcSession> session = thiz;
session->preJoinThreadOwnership(std::move(thread));
// only continue once we have a response or the connection fails
- auto setupResult = session->preJoinSetup(std::move(movedFd));
+ auto setupResult = session->preJoinSetup(std::move(movedRpcTransport));
ownershipTransferred = true;
threadLock.unlock();
@@ -578,7 +624,7 @@
return true;
}
-bool RpcSession::addOutgoingConnection(unique_fd fd, bool init) {
+bool RpcSession::addOutgoingConnection(std::unique_ptr<RpcTransport> rpcTransport, bool init) {
sp<RpcConnection> connection = sp<RpcConnection>::make();
{
std::lock_guard<std::mutex> _l(mMutex);
@@ -591,7 +637,7 @@
if (mShutdownTrigger == nullptr) return false;
}
- connection->fd = std::move(fd);
+ connection->rpcTransport = std::move(rpcTransport);
connection->exclusiveTid = gettid();
mOutgoingConnections.push_back(connection);
}
@@ -626,7 +672,8 @@
return true;
}
-sp<RpcSession::RpcConnection> RpcSession::assignIncomingConnectionToThisThread(unique_fd fd) {
+sp<RpcSession::RpcConnection> RpcSession::assignIncomingConnectionToThisThread(
+ std::unique_ptr<RpcTransport> rpcTransport) {
std::lock_guard<std::mutex> _l(mMutex);
// Don't accept any more connections, some have shutdown. Usually this
@@ -638,7 +685,7 @@
}
sp<RpcConnection> session = sp<RpcConnection>::make();
- session->fd = std::move(fd);
+ session->rpcTransport = std::move(rpcTransport);
session->exclusiveTid = gettid();
mIncomingConnections.push_back(session);
diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp
index 36c03c5..6563bc8 100644
--- a/libs/binder/RpcState.cpp
+++ b/libs/binder/RpcState.cpp
@@ -273,7 +273,7 @@
status_t RpcState::rpcSend(const sp<RpcSession::RpcConnection>& connection,
const sp<RpcSession>& session, const char* what, const void* data,
size_t size) {
- LOG_RPC_DETAIL("Sending %s on fd %d: %s", what, connection->fd.get(),
+ LOG_RPC_DETAIL("Sending %s on RpcTransport %p: %s", what, connection->rpcTransport.get(),
android::base::HexString(data, size).c_str());
if (size > std::numeric_limits<ssize_t>::max()) {
@@ -282,11 +282,12 @@
return BAD_VALUE;
}
- if (status_t status = session->mShutdownTrigger->interruptableWriteFully(connection->fd.get(),
- data, size);
+ if (status_t status =
+ session->mShutdownTrigger->interruptableWriteFully(connection->rpcTransport.get(),
+ data, size);
status != OK) {
- LOG_RPC_DETAIL("Failed to write %s (%zu bytes) on fd %d, error: %s", what, size,
- connection->fd.get(), statusToString(status).c_str());
+ LOG_RPC_DETAIL("Failed to write %s (%zu bytes) on RpcTransport %p, error: %s", what, size,
+ connection->rpcTransport.get(), statusToString(status).c_str());
(void)session->shutdownAndWait(false);
return status;
}
@@ -304,14 +305,15 @@
}
if (status_t status =
- session->mShutdownTrigger->interruptableReadFully(connection->fd.get(), data, size);
+ session->mShutdownTrigger->interruptableReadFully(connection->rpcTransport.get(),
+ data, size);
status != OK) {
- LOG_RPC_DETAIL("Failed to read %s (%zu bytes) on fd %d, error: %s", what, size,
- connection->fd.get(), statusToString(status).c_str());
+ LOG_RPC_DETAIL("Failed to read %s (%zu bytes) on RpcTransport %p, error: %s", what, size,
+ connection->rpcTransport.get(), statusToString(status).c_str());
return status;
}
- LOG_RPC_DETAIL("Received %s on fd %d: %s", what, connection->fd.get(),
+ LOG_RPC_DETAIL("Received %s on RpcTransport %p: %s", what, connection->rpcTransport.get(),
android::base::HexString(data, size).c_str());
return OK;
}
@@ -490,7 +492,8 @@
return status;
if (flags & IBinder::FLAG_ONEWAY) {
- LOG_RPC_DETAIL("Oneway command, so no longer waiting on %d", connection->fd.get());
+ LOG_RPC_DETAIL("Oneway command, so no longer waiting on RpcTransport %p",
+ connection->rpcTransport.get());
// Do not wait on result.
// However, too many oneway calls may cause refcounts to build up and fill up the socket,
@@ -585,7 +588,7 @@
status_t RpcState::getAndExecuteCommand(const sp<RpcSession::RpcConnection>& connection,
const sp<RpcSession>& session, CommandType type) {
- LOG_RPC_DETAIL("getAndExecuteCommand on fd %d", connection->fd.get());
+ LOG_RPC_DETAIL("getAndExecuteCommand on RpcTransport %p", connection->rpcTransport.get());
RpcWireHeader command;
if (status_t status = rpcRec(connection, session, "command header", &command, sizeof(command));
@@ -598,8 +601,7 @@
status_t RpcState::drainCommands(const sp<RpcSession::RpcConnection>& connection,
const sp<RpcSession>& session, CommandType type) {
uint8_t buf;
- while (0 < TEMP_FAILURE_RETRY(
- recv(connection->fd.get(), &buf, sizeof(buf), MSG_PEEK | MSG_DONTWAIT))) {
+ while (connection->rpcTransport->peek(&buf, sizeof(buf)).value_or(-1) > 0) {
status_t status = getAndExecuteCommand(connection, session, type);
if (status != OK) return status;
}
diff --git a/libs/binder/include/binder/RpcServer.h b/libs/binder/include/binder/RpcServer.h
index 40ff78c..0fdab9c 100644
--- a/libs/binder/include/binder/RpcServer.h
+++ b/libs/binder/include/binder/RpcServer.h
@@ -19,6 +19,7 @@
#include <binder/IBinder.h>
#include <binder/RpcAddress.h>
#include <binder/RpcSession.h>
+#include <binder/RpcTransport.h>
#include <utils/Errors.h>
#include <utils/RefBase.h>
@@ -47,7 +48,8 @@
*/
class RpcServer final : public virtual RefBase, private RpcSession::EventListener {
public:
- static sp<RpcServer> make();
+ static sp<RpcServer> make(
+ std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory = nullptr);
/**
* This represents a session for responses, e.g.:
@@ -161,7 +163,7 @@
private:
friend sp<RpcServer>;
- RpcServer();
+ explicit RpcServer(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory);
void onSessionAllIncomingThreadsEnded(const sp<RpcSession>& session) override;
void onSessionIncomingThreadEnded() override;
@@ -169,6 +171,7 @@
static void establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd);
bool setupSocketServer(const RpcSocketAddress& address);
+ const std::unique_ptr<RpcTransportCtxFactory> mRpcTransportCtxFactory;
bool mAgreedExperimental = false;
size_t mMaxThreads = 1;
std::optional<uint32_t> mProtocolVersion;
@@ -183,6 +186,7 @@
std::map<RpcAddress, sp<RpcSession>> mSessions;
std::unique_ptr<RpcSession::FdTrigger> mShutdownTrigger;
std::condition_variable mShutdownCv;
+ std::unique_ptr<RpcTransportCtx> mCtx;
};
} // namespace android
diff --git a/libs/binder/include/binder/RpcSession.h b/libs/binder/include/binder/RpcSession.h
index 1f7c029..e3d6bba 100644
--- a/libs/binder/include/binder/RpcSession.h
+++ b/libs/binder/include/binder/RpcSession.h
@@ -18,6 +18,8 @@
#include <android-base/unique_fd.h>
#include <binder/IBinder.h>
#include <binder/RpcAddress.h>
+#include <binder/RpcSession.h>
+#include <binder/RpcTransport.h>
#include <utils/Errors.h>
#include <utils/RefBase.h>
@@ -36,6 +38,7 @@
class RpcServer;
class RpcSocketAddress;
class RpcState;
+class RpcTransport;
constexpr uint32_t RPC_WIRE_PROTOCOL_VERSION_NEXT = 0;
constexpr uint32_t RPC_WIRE_PROTOCOL_VERSION_EXPERIMENTAL = 0xF0000000;
@@ -48,7 +51,8 @@
*/
class RpcSession final : public virtual RefBase {
public:
- static sp<RpcSession> make();
+ static sp<RpcSession> make(
+ std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory = nullptr);
/**
* Set the maximum number of threads allowed to be made (for things like callbacks).
@@ -142,7 +146,7 @@
friend sp<RpcSession>;
friend RpcServer;
friend RpcState;
- RpcSession();
+ explicit RpcSession(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory);
/** This is not a pipe. */
struct FdTrigger {
@@ -178,10 +182,12 @@
* true - succeeded in completely processing 'size'
* false - interrupted (failure or trigger)
*/
- status_t interruptableReadFully(base::borrowed_fd fd, void* data, size_t size);
- status_t interruptableWriteFully(base::borrowed_fd fd, const void* data, size_t size);
+ status_t interruptableReadFully(RpcTransport* rpcTransport, void* data, size_t size);
+ status_t interruptableWriteFully(RpcTransport* rpcTransport, const void* data, size_t size);
private:
+ status_t triggerablePoll(RpcTransport* rpcTransport, int16_t event);
+
base::unique_fd mWrite;
base::unique_fd mRead;
};
@@ -204,7 +210,7 @@
};
struct RpcConnection : public RefBase {
- base::unique_fd fd;
+ std::unique_ptr<RpcTransport> rpcTransport;
// whether this or another thread is currently using this fd to make
// or receive transactions.
@@ -230,19 +236,20 @@
// Status of setup
status_t status;
};
- PreJoinSetupResult preJoinSetup(base::unique_fd fd);
+ PreJoinSetupResult preJoinSetup(std::unique_ptr<RpcTransport> rpcTransport);
// join on thread passed to preJoinThreadOwnership
static void join(sp<RpcSession>&& session, PreJoinSetupResult&& result);
[[nodiscard]] bool setupSocketClient(const RpcSocketAddress& address);
[[nodiscard]] bool setupOneSocketConnection(const RpcSocketAddress& address,
const RpcAddress& sessionId, bool server);
- [[nodiscard]] bool addIncomingConnection(base::unique_fd fd);
- [[nodiscard]] bool addOutgoingConnection(base::unique_fd fd, bool init);
+ [[nodiscard]] bool addIncomingConnection(std::unique_ptr<RpcTransport> rpcTransport);
+ [[nodiscard]] bool addOutgoingConnection(std::unique_ptr<RpcTransport> rpcTransport, bool init);
[[nodiscard]] bool setForServer(const wp<RpcServer>& server,
const wp<RpcSession::EventListener>& eventListener,
const RpcAddress& sessionId);
- sp<RpcConnection> assignIncomingConnectionToThisThread(base::unique_fd fd);
+ sp<RpcConnection> assignIncomingConnectionToThisThread(
+ std::unique_ptr<RpcTransport> rpcTransport);
[[nodiscard]] bool removeIncomingConnection(const sp<RpcConnection>& connection);
enum class ConnectionUse {
@@ -275,6 +282,8 @@
bool mReentrant = false;
};
+ const std::unique_ptr<RpcTransportCtxFactory> mRpcTransportCtxFactory;
+
// On the other side of a session, for each of mOutgoingConnections here, there should
// be one of mIncomingConnections on the other side (and vice versa).
//
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index d5786bc..386183d 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -29,6 +29,8 @@
#include <binder/ProcessState.h>
#include <binder/RpcServer.h>
#include <binder/RpcSession.h>
+#include <binder/RpcTransport.h>
+#include <binder/RpcTransportRaw.h>
#include <gtest/gtest.h>
#include <chrono>
@@ -50,6 +52,21 @@
static_assert(RPC_WIRE_PROTOCOL_VERSION + 1 == RPC_WIRE_PROTOCOL_VERSION_NEXT ||
RPC_WIRE_PROTOCOL_VERSION == RPC_WIRE_PROTOCOL_VERSION_EXPERIMENTAL);
+enum class RpcSecurity { RAW };
+
+static inline std::vector<RpcSecurity> RpcSecurityValues() {
+ return {RpcSecurity::RAW};
+}
+
+static inline std::unique_ptr<RpcTransportCtxFactory> newFactory(RpcSecurity rpcSecurity) {
+ switch (rpcSecurity) {
+ case RpcSecurity::RAW:
+ return RpcTransportCtxFactoryRaw::make();
+ default:
+ LOG_ALWAYS_FATAL("Unknown RpcSecurity %d", rpcSecurity);
+ }
+}
+
TEST(BinderRpcParcel, EntireParcelFormatted) {
Parcel p;
p.writeInt32(3);
@@ -57,10 +74,17 @@
EXPECT_DEATH(p.markForBinder(sp<BBinder>::make()), "");
}
-TEST(BinderRpc, SetExternalServer) {
+class BinderRpcSimple : public ::testing::TestWithParam<RpcSecurity> {
+public:
+ static std::string PrintTestParam(const ::testing::TestParamInfo<ParamType>& info) {
+ return newFactory(info.param)->toCString();
+ }
+};
+
+TEST_P(BinderRpcSimple, SetExternalServerTest) {
base::unique_fd sink(TEMP_FAILURE_RETRY(open("/dev/null", O_RDWR)));
int sinkFd = sink.get();
- auto server = RpcServer::make();
+ auto server = RpcServer::make(newFactory(GetParam()));
server->iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction();
ASSERT_FALSE(server->hasServer());
ASSERT_TRUE(server->setupExternalServer(std::move(sink)));
@@ -388,8 +412,8 @@
VSOCK,
INET,
};
-static inline std::string PrintSocketType(const testing::TestParamInfo<SocketType>& info) {
- switch (info.param) {
+static inline std::string PrintToString(SocketType socketType) {
+ switch (socketType) {
case SocketType::UNIX:
return "unix_domain_socket";
case SocketType::VSOCK:
@@ -402,7 +426,7 @@
}
}
-class BinderRpc : public ::testing::TestWithParam<SocketType> {
+class BinderRpc : public ::testing::TestWithParam<std::tuple<SocketType, RpcSecurity>> {
public:
struct Options {
size_t numThreads = 1;
@@ -410,13 +434,19 @@
size_t numIncomingConnections = 0;
};
+ static inline std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) {
+ auto [type, security] = info.param;
+ return PrintToString(type) + "_" + newFactory(security)->toCString();
+ }
+
// This creates a new process serving an interface on a certain number of
// threads.
ProcessSession createRpcTestSocketServerProcess(
const Options& options, const std::function<void(const sp<RpcServer>&)>& configure) {
CHECK_GE(options.numSessions, 1) << "Must have at least one session to a server";
- SocketType socketType = GetParam();
+ SocketType socketType = std::get<0>(GetParam());
+ RpcSecurity rpcSecurity = std::get<1>(GetParam());
unsigned int vsockPort = allocateVsockPort();
std::string addr = allocateSocketAddress();
@@ -424,7 +454,7 @@
auto ret = ProcessSession{
.host = Process([&](android::base::borrowed_fd writeEnd) {
- sp<RpcServer> server = RpcServer::make();
+ sp<RpcServer> server = RpcServer::make(newFactory(rpcSecurity));
server->iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction();
server->setMaxThreads(options.numThreads);
@@ -466,7 +496,7 @@
}
for (size_t i = 0; i < options.numSessions; i++) {
- sp<RpcSession> session = RpcSession::make();
+ sp<RpcSession> session = RpcSession::make(newFactory(rpcSecurity));
session->setMaxThreads(options.numIncomingConnections);
switch (socketType) {
@@ -1130,13 +1160,14 @@
}
static bool testSupportVsockLoopback() {
+ // We don't need to enable TLS to know if vsock is supported.
unsigned int vsockPort = allocateVsockPort();
- sp<RpcServer> server = RpcServer::make();
+ sp<RpcServer> server = RpcServer::make(RpcTransportCtxFactoryRaw::make());
server->iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction();
CHECK(server->setupVsockServer(vsockPort));
server->start();
- sp<RpcSession> session = RpcSession::make();
+ sp<RpcSession> session = RpcSession::make(RpcTransportCtxFactoryRaw::make());
bool okay = session->setupVsockClient(VMADDR_CID_LOCAL, vsockPort);
while (!server->shutdown()) usleep(10000);
ALOGE("Detected vsock loopback supported: %d", okay);
@@ -1155,10 +1186,13 @@
return ret;
}
-INSTANTIATE_TEST_CASE_P(PerSocket, BinderRpc, ::testing::ValuesIn(testSocketTypes()),
- PrintSocketType);
+INSTANTIATE_TEST_CASE_P(PerSocket, BinderRpc,
+ ::testing::Combine(::testing::ValuesIn(testSocketTypes()),
+ ::testing::ValuesIn(RpcSecurityValues())),
+ BinderRpc::PrintParamInfo);
-class BinderRpcServerRootObject : public ::testing::TestWithParam<std::tuple<bool, bool>> {};
+class BinderRpcServerRootObject
+ : public ::testing::TestWithParam<std::tuple<bool, bool, RpcSecurity>> {};
TEST_P(BinderRpcServerRootObject, WeakRootObject) {
using SetFn = std::function<void(RpcServer*, sp<IBinder>)>;
@@ -1166,8 +1200,8 @@
return isStrong ? SetFn(&RpcServer::setRootObject) : SetFn(&RpcServer::setRootObjectWeak);
};
- auto server = RpcServer::make();
- auto [isStrong1, isStrong2] = GetParam();
+ auto [isStrong1, isStrong2, rpcSecurity] = GetParam();
+ auto server = RpcServer::make(newFactory(rpcSecurity));
auto binder1 = sp<BBinder>::make();
IBinder* binderRaw1 = binder1.get();
setRootObject(isStrong1)(server.get(), binder1);
@@ -1184,7 +1218,8 @@
}
INSTANTIATE_TEST_CASE_P(BinderRpc, BinderRpcServerRootObject,
- ::testing::Combine(::testing::Bool(), ::testing::Bool()));
+ ::testing::Combine(::testing::Bool(), ::testing::Bool(),
+ ::testing::ValuesIn(RpcSecurityValues())));
class OneOffSignal {
public:
@@ -1207,10 +1242,10 @@
bool mValue = false;
};
-TEST(BinderRpc, Shutdown) {
+TEST_P(BinderRpcSimple, Shutdown) {
auto addr = allocateSocketAddress();
unlink(addr.c_str());
- auto server = RpcServer::make();
+ auto server = RpcServer::make(newFactory(GetParam()));
server->iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction();
ASSERT_TRUE(server->setupUnixDomainServer(addr.c_str()));
auto joinEnds = std::make_shared<OneOffSignal>();
@@ -1271,6 +1306,9 @@
ASSERT_EQ(OK, rpcBinder->pingBinder());
}
+INSTANTIATE_TEST_CASE_P(BinderRpc, BinderRpcSimple, ::testing::ValuesIn(RpcSecurityValues()),
+ BinderRpcSimple::PrintTestParam);
+
} // namespace android
int main(int argc, char** argv) {
diff --git a/libs/binder/tests/parcel_fuzzer/random_parcel.cpp b/libs/binder/tests/parcel_fuzzer/random_parcel.cpp
index 92fdc72..a2472f8 100644
--- a/libs/binder/tests/parcel_fuzzer/random_parcel.cpp
+++ b/libs/binder/tests/parcel_fuzzer/random_parcel.cpp
@@ -19,6 +19,7 @@
#include <android-base/logging.h>
#include <binder/IServiceManager.h>
#include <binder/RpcSession.h>
+#include <binder/RpcTransportRaw.h>
#include <fuzzbinder/random_fd.h>
#include <utils/String16.h>
@@ -35,7 +36,7 @@
void fillRandomParcel(Parcel* p, FuzzedDataProvider&& provider) {
if (provider.ConsumeBool()) {
- auto session = sp<RpcSession>::make();
+ auto session = RpcSession::make(RpcTransportCtxFactoryRaw::make());
CHECK(session->addNullDebuggingClient());
p->markForRpc(session);
fillRandomParcelData(p, std::move(provider));