Merge changes from topic "binder-tls-trigger" am: 53195f9d63 am: dd326395b5 am: 1fbb5b6fc4 am: 5fccda1352
Original change: https://android-review.googlesource.com/c/platform/frameworks/native/+/1825993
Change-Id: I57a6b2267c01f34ce75fc40b9c840dbe2f1b5511
diff --git a/libs/binder/RpcTransportTls.cpp b/libs/binder/RpcTransportTls.cpp
index d40cfc8..63f9339 100644
--- a/libs/binder/RpcTransportTls.cpp
+++ b/libs/binder/RpcTransportTls.cpp
@@ -347,7 +347,7 @@
ALOGE("%s: %s", __PRETTY_FUNCTION__, ret.error().message().c_str());
return ret.error().code() == 0 ? UNKNOWN_ERROR : -ret.error().code();
}
- return OK;
+ return *ret ? -ECANCELED : OK;
}
status_t RpcTransportTls::interruptableWriteFully(FdTrigger* fdTrigger, const void* data,
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index a4e37ad..2fd63a3 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -53,6 +53,7 @@
#include "RpcCertificateVerifierSimple.h"
using namespace std::chrono_literals;
+using namespace std::placeholders;
using testing::AssertionFailure;
using testing::AssertionResult;
using testing::AssertionSuccess;
@@ -1444,7 +1445,7 @@
PrintToString(certificateFormat);
}
void TearDown() override {
- for (auto& server : mServers) server->shutdown();
+ for (auto& server : mServers) server->shutdownAndWait();
}
// A server that handles client socket connections.
@@ -1452,7 +1453,7 @@
public:
explicit Server() {}
Server(Server&&) = default;
- ~Server() { shutdown(); }
+ ~Server() { shutdownAndWait(); }
[[nodiscard]] AssertionResult setUp() {
auto [socketType, rpcSecurity, certificateFormat] = GetParam();
auto rpcServer = RpcServer::make(newFactory(rpcSecurity));
@@ -1536,17 +1537,17 @@
ASSERT_TRUE(acceptedFd.ok());
auto serverTransport = mCtx->newTransport(std::move(acceptedFd), mFdTrigger.get());
if (serverTransport == nullptr) return; // handshake failed
- std::string message(kMessage);
- ASSERT_EQ(OK,
- serverTransport->interruptableWriteFully(mFdTrigger.get(), message.data(),
- message.size()));
+ ASSERT_TRUE(mPostConnect(serverTransport.get(), mFdTrigger.get()));
}
- void shutdown() {
- mFdTrigger->trigger();
- if (mThread != nullptr) {
- mThread->join();
- mThread = nullptr;
- }
+ void shutdownAndWait() {
+ shutdown();
+ join();
+ }
+ void shutdown() { mFdTrigger->trigger(); }
+
+ void setPostConnect(
+ std::function<AssertionResult(RpcTransport*, FdTrigger* fdTrigger)> fn) {
+ mPostConnect = std::move(fn);
}
private:
@@ -1558,6 +1559,26 @@
std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier =
std::make_shared<RpcCertificateVerifierSimple>();
bool mSetup = false;
+ // The function invoked after connection and handshake. By default, it is
+ // |defaultPostConnect| that sends |kMessage| to the client.
+ std::function<AssertionResult(RpcTransport*, FdTrigger* fdTrigger)> mPostConnect =
+ Server::defaultPostConnect;
+
+ void join() {
+ if (mThread != nullptr) {
+ mThread->join();
+ mThread = nullptr;
+ }
+ }
+
+ static AssertionResult defaultPostConnect(RpcTransport* serverTransport,
+ FdTrigger* fdTrigger) {
+ std::string message(kMessage);
+ auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(),
+ message.size());
+ if (status != OK) return AssertionFailure() << statusToString(status);
+ return AssertionSuccess();
+ }
};
class Client {
@@ -1566,8 +1587,6 @@
Client(Client&&) = default;
[[nodiscard]] AssertionResult setUp() {
auto [socketType, rpcSecurity, certificateFormat] = GetParam();
- mFd = mConnectToServer();
- if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server";
mFdTrigger = FdTrigger::make();
mCtx = newFactory(rpcSecurity, mCertVerifier)->newClientCtx();
if (mCtx == nullptr) return AssertionFailure() << "newClientCtx";
@@ -1577,24 +1596,39 @@
std::shared_ptr<RpcCertificateVerifierSimple> getCertVerifier() const {
return mCertVerifier;
}
+        // connect() to the server and do the handshake. On failure, the returned result carries
+        // the reason, so ASSERT_TRUE(client.setUpTransport()) reports why it failed.
+        [[nodiscard]] AssertionResult setUpTransport() {
+            mFd = mConnectToServer();
+            if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server";
+            mClientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get());
+            if (mClientTransport == nullptr) {
+                return AssertionFailure() << "newTransport returns nullptr";
+            }
+            return AssertionSuccess();
+        }
+ AssertionResult readMessage(const std::string& expectedMessage = kMessage) {
+ LOG_ALWAYS_FATAL_IF(mClientTransport == nullptr, "setUpTransport not called or failed");
+ std::string readMessage(expectedMessage.size(), '\0');
+ status_t readStatus =
+ mClientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(),
+ readMessage.size());
+ if (readStatus != OK) {
+ return AssertionFailure() << statusToString(readStatus);
+ }
+ if (readMessage != expectedMessage) {
+ return AssertionFailure()
+ << "Expected " << expectedMessage << ", actual " << readMessage;
+ }
+ return AssertionSuccess();
+ }
void run(bool handshakeOk = true, bool readOk = true) {
- auto clientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get());
- if (clientTransport == nullptr) {
+ if (!setUpTransport()) {
ASSERT_FALSE(handshakeOk) << "newTransport returns nullptr, but it shouldn't";
return;
}
ASSERT_TRUE(handshakeOk) << "newTransport does not return nullptr, but it should";
- std::string expectedMessage(kMessage);
- std::string readMessage(expectedMessage.size(), '\0');
- status_t readStatus =
- clientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(),
- readMessage.size());
- if (readOk) {
- ASSERT_EQ(OK, readStatus);
- ASSERT_EQ(readMessage, expectedMessage);
- } else {
- ASSERT_NE(OK, readStatus);
- }
+ ASSERT_EQ(readOk, readMessage());
}
private:
@@ -1604,6 +1634,7 @@
std::unique_ptr<RpcTransportCtx> mCtx;
std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier =
std::make_shared<RpcCertificateVerifierSimple>();
+ std::unique_ptr<RpcTransport> mClientTransport;
};
// Make A trust B.
@@ -1729,6 +1760,86 @@
maliciousClient.run(true, readOk);
}
+TEST_P(RpcTransportTest, Trigger) {
+    std::string msg2 = ", world!";
+    std::mutex writeMutex;
+    std::condition_variable writeCv;
+    bool shouldContinueWriting = false;
+    // Runs on the server thread after connection and handshake. It captures the locals above by
+    // reference, so the server thread must be joined before they go out of scope.
+    auto serverPostConnect = [&](RpcTransport* serverTransport, FdTrigger* fdTrigger) {
+        std::string message(kMessage);
+        auto status =
+                serverTransport->interruptableWriteFully(fdTrigger, message.data(), message.size());
+        if (status != OK) return AssertionFailure() << statusToString(status);
+
+        {
+            std::unique_lock<std::mutex> lock(writeMutex);
+            if (!writeCv.wait_for(lock, 3s, [&] { return shouldContinueWriting; })) {
+                return AssertionFailure() << "write barrier not cleared in time!";
+            }
+        }
+
+        status = serverTransport->interruptableWriteFully(fdTrigger, msg2.data(), msg2.size());
+        if (status != -ECANCELED) {
+            return AssertionFailure() << "When FdTrigger is shut down, interruptableWriteFully "
+                                         "should return -ECANCELED, but it is "
+                                      << statusToString(status);
+        }
+        return AssertionSuccess();
+    };
+
+    auto server = mServers.emplace_back(std::make_unique<Server>()).get();
+    ASSERT_TRUE(server->setUp());
+
+    // Set up client
+    Client client(server->getConnectToServerFn());
+    ASSERT_TRUE(client.setUp());
+
+    // Exchange keys
+    ASSERT_EQ(OK, trust(&client, server));
+    ASSERT_EQ(OK, trust(server, &client));
+
+    server->setPostConnect(serverPostConnect);
+
+    // Opens the write barrier that |serverPostConnect| waits on before its second write.
+    auto releaseWriteBarrier = [&] {
+        {
+            std::lock_guard<std::mutex> lock(writeMutex);
+            shouldContinueWriting = true;
+        }
+        writeCv.notify_all();
+    };
+    // On every exit path, including an early return from a failed ASSERT_* below, shut the
+    // server down, release the write barrier and join the server thread while the locals that
+    // |serverPostConnect| references are still alive. TearDown() joins only after this test body
+    // has returned, when the server thread could still be blocked on |writeCv| or using |msg2|.
+    struct ServerThreadJoiner {
+        std::function<void()> onExit;
+        ~ServerThreadJoiner() { onExit(); }
+    } serverThreadJoiner{[&] {
+        server->shutdown();
+        releaseWriteBarrier();
+        server->shutdownAndWait();
+    }};
+
+    // Start server
+    server->start();
+    // connect() to server and do handshake
+    ASSERT_TRUE(client.setUpTransport());
+    // Read the first message. This confirms that the server has finished the handshake and
+    // started handling the client fd. The server thread should now be waiting on |writeCv|.
+    ASSERT_TRUE(client.readMessage(kMessage));
+    // Trigger server shutdown after the server starts handling the client fd. This ensures that
+    // the second write is on an FdTrigger that has been shut down.
+    server->shutdown();
+    // Let the server thread continue and attempt to write the second message.
+    releaseWriteBarrier();
+    // Shutdown was triggered first, so the server's second write should fail with -ECANCELED
+    // (see |serverPostConnect|), and the client's second read fails with DEAD_OBJECT.
+    ASSERT_FALSE(client.readMessage(msg2));
+}
+
std::vector<RpcCertificateFormat> testRpcCertificateFormats() {
return {
RpcCertificateFormat::PEM,