binder: implement simple TLS verification for testing
Implements RpcCertificateVerifierSimple with an algorithm
that treats all certificates as leaf certificates.
Fixes existing tests to set certificates properly. Also
adds tests that check that bad certificates are rejected.
Also adds RpcCertificateUtils, which includes functions for
(de)serializing certificates. These utility functions are useful
for implementing RpcCertificateVerifier.
Test: binderRpcTest
Bug: 195166979
Fixes: 196422181
Fixes: 198833574
Change-Id: I6c1f0f88fe5bc712f3890426d6da26c9ad046d79
diff --git a/libs/binder/Android.bp b/libs/binder/Android.bp
index b0d7478..9bca1f3 100644
--- a/libs/binder/Android.bp
+++ b/libs/binder/Android.bp
@@ -262,6 +262,7 @@
],
srcs: [
"RpcTransportTls.cpp",
+ "RpcCertificateUtils.cpp",
],
}
diff --git a/libs/binder/RpcCertificateUtils.cpp b/libs/binder/RpcCertificateUtils.cpp
new file mode 100644
index 0000000..d0ea3c7
--- /dev/null
+++ b/libs/binder/RpcCertificateUtils.cpp
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "RpcCertificateUtils"
+#include <log/log.h>
+
+#include <binder/RpcCertificateUtils.h>
+
+#include <limits>
+
+#include "Utils.h"
+
+namespace android {
+
+namespace {
+
+// Parses one PEM-encoded X509 certificate from |cert|. Returns nullptr if |cert| is too large
+// for BIO_new_mem_buf() (which takes an int length), if the BIO cannot be allocated, or on error.
+bssl::UniquePtr<X509> fromPem(const std::vector<uint8_t>& cert) {
+    if (cert.size() > static_cast<size_t>(std::numeric_limits<int>::max())) return nullptr;
+    bssl::UniquePtr<BIO> certBio(BIO_new_mem_buf(cert.data(), static_cast<int>(cert.size())));
+    if (certBio == nullptr) return nullptr;
+    return bssl::UniquePtr<X509>(PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
+}
+
+// Encodes |x509| as PEM. Returns an empty vector if |x509| is null or if encoding fails.
+std::vector<uint8_t> toPem(X509* x509) {
+    if (x509 == nullptr) return {};
+    bssl::UniquePtr<BIO> certBio(BIO_new(BIO_s_mem()));
+    if (certBio == nullptr) return {};
+    TEST_AND_RETURN({}, PEM_write_bio_X509(certBio.get(), x509));
+    const uint8_t* data;
+    size_t len;
+    // |data| points into the buffer owned by |certBio|; copy it out before the BIO is freed.
+    TEST_AND_RETURN({}, BIO_mem_contents(certBio.get(), &data, &len));
+    return std::vector<uint8_t>(data, data + len);
+}
+
+} // namespace
+
+// Returns nullptr if |cert| cannot be parsed. Aborts if |format| is not supported.
+bssl::UniquePtr<X509> deserializeCertificate(const std::vector<uint8_t>& cert,
+                                             CertificateFormat format) {
+    switch (format) {
+        case CertificateFormat::PEM:
+            return fromPem(cert);
+    }
+    LOG_ALWAYS_FATAL("Unsupported format %d", static_cast<int>(format));
+}
+
+// Returns an empty vector on failure. Aborts if |format| is not supported.
+std::vector<uint8_t> serializeCertificate(X509* x509, CertificateFormat format) {
+    // No default case, so that -Wswitch flags this switch when a new format is added.
+    switch (format) {
+        case CertificateFormat::PEM:
+            return toPem(x509);
+    }
+    LOG_ALWAYS_FATAL("Unsupported format %d", static_cast<int>(format));
+}
+
+} // namespace android
diff --git a/libs/binder/RpcTransportTls.cpp b/libs/binder/RpcTransportTls.cpp
index c42ea9a..2a1dffd 100644
--- a/libs/binder/RpcTransportTls.cpp
+++ b/libs/binder/RpcTransportTls.cpp
@@ -22,6 +22,7 @@
#include <openssl/bn.h>
#include <openssl/ssl.h>
+#include <binder/RpcCertificateUtils.h>
#include <binder/RpcTransportTls.h>
#include "FdTrigger.h"
@@ -459,9 +460,9 @@
std::shared_ptr<RpcCertificateVerifier> mCertVerifier;
};
-std::vector<uint8_t> RpcTransportCtxTls::getCertificate(CertificateFormat) const {
- // TODO(b/195166979): return certificate here
- return {};
+std::vector<uint8_t> RpcTransportCtxTls::getCertificate(CertificateFormat format) const {
+    // SSL_CTX_get0_certificate() returns a pointer borrowed from mCtx; mCtx keeps ownership.
+    return serializeCertificate(SSL_CTX_get0_certificate(mCtx.get()), format);
}
// Verify by comparing the leaf of peer certificate with every certificate in
diff --git a/libs/binder/include/binder/CertificateFormat.h b/libs/binder/include/binder/CertificateFormat.h
index 4f7e71e..d33ee7e 100644
--- a/libs/binder/include/binder/CertificateFormat.h
+++ b/libs/binder/include/binder/CertificateFormat.h
@@ -18,6 +18,8 @@
#pragma once
+#include <string>
+
namespace android {
enum class CertificateFormat {
@@ -25,4 +27,13 @@
// TODO(b/195166979): support other formats, e.g. DER
};
+// Returns a human-readable name for |format|, or "<unknown>" for a value with no name.
+static inline std::string PrintToString(CertificateFormat format) {
+    switch (format) { // no default case, so that -Wswitch flags formats added later
+        case CertificateFormat::PEM:
+            return "PEM";
+    }
+    return "<unknown>";
+}
+
} // namespace android
diff --git a/libs/binder/include_tls/binder/RpcCertificateUtils.h b/libs/binder/include_tls/binder/RpcCertificateUtils.h
new file mode 100644
index 0000000..b7c1849
--- /dev/null
+++ b/libs/binder/include_tls/binder/RpcCertificateUtils.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utilities for serializing and deserializing X509 certificates.
+
+#pragma once
+
+#include <cstdint>
+#include <vector>
+
+#include <openssl/ssl.h>
+
+#include <binder/CertificateFormat.h>
+
+namespace android {
+
+// Parses |cert|, encoded in |format|, into an X509 object.
+// Returns nullptr if |cert| cannot be parsed. Aborts if |format| is not supported.
+bssl::UniquePtr<X509> deserializeCertificate(const std::vector<uint8_t>& cert,
+                                             CertificateFormat format);
+
+// Encodes |x509| in |format|. Does not take ownership of |x509|.
+// Returns an empty vector on failure. Aborts if |format| is not supported.
+std::vector<uint8_t> serializeCertificate(X509* x509, CertificateFormat format);
+
+} // namespace android
diff --git a/libs/binder/tests/Android.bp b/libs/binder/tests/Android.bp
index a9bc15d..1968058 100644
--- a/libs/binder/tests/Android.bp
+++ b/libs/binder/tests/Android.bp
@@ -120,9 +120,12 @@
host_supported: true,
unstable: true,
srcs: [
+ "BinderRpcTestClientInfo.aidl",
+ "BinderRpcTestServerInfo.aidl",
"IBinderRpcCallback.aidl",
"IBinderRpcSession.aidl",
"IBinderRpcTest.aidl",
+ "ParcelableCertificateData.aidl",
],
backend: {
java: {
diff --git a/libs/binder/tests/BinderRpcTestClientInfo.aidl b/libs/binder/tests/BinderRpcTestClientInfo.aidl
new file mode 100644
index 0000000..b4baebc
--- /dev/null
+++ b/libs/binder/tests/BinderRpcTestClientInfo.aidl
@@ -0,0 +1,21 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import ParcelableCertificateData;
+
+parcelable BinderRpcTestClientInfo { // Sent by the test client to the test server process.
+ ParcelableCertificateData[] certs; // One per client session; with TLS the server trusts each.
+}
diff --git a/libs/binder/tests/BinderRpcTestServerInfo.aidl b/libs/binder/tests/BinderRpcTestServerInfo.aidl
new file mode 100644
index 0000000..00dc0bc
--- /dev/null
+++ b/libs/binder/tests/BinderRpcTestServerInfo.aidl
@@ -0,0 +1,22 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import ParcelableCertificateData;
+
+parcelable BinderRpcTestServerInfo { // Sent by the test server process to the test client.
+ long port; // Port reported by the server; the client uses it for INET connections.
+ ParcelableCertificateData cert; // The server's certificate; with TLS the client trusts it.
+}
diff --git a/libs/binder/tests/ParcelableCertificateData.aidl b/libs/binder/tests/ParcelableCertificateData.aidl
new file mode 100644
index 0000000..38c382e
--- /dev/null
+++ b/libs/binder/tests/ParcelableCertificateData.aidl
@@ -0,0 +1,19 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+parcelable ParcelableCertificateData { // A serialized certificate (the tests use PEM).
+ byte[] data; // Encoded certificate bytes, as returned by getCertificate().
+}
diff --git a/libs/binder/tests/RpcCertificateVerifierSimple.cpp b/libs/binder/tests/RpcCertificateVerifierSimple.cpp
index 68e7c65..0cf0e1c 100644
--- a/libs/binder/tests/RpcCertificateVerifierSimple.cpp
+++ b/libs/binder/tests/RpcCertificateVerifierSimple.cpp
@@ -16,12 +16,32 @@
#define LOG_TAG "RpcCertificateVerifierSimple"
#include <log/log.h>
+#include <binder/RpcCertificateUtils.h>
+
#include "RpcCertificateVerifierSimple.h"
namespace android {
-status_t RpcCertificateVerifierSimple::verify(const X509*, uint8_t*) {
- // TODO(b/195166979): implement this
+// Accepts |peerCert| only if it compares equal (X509_cmp) to one of the trusted certificates.
+status_t RpcCertificateVerifierSimple::verify(const X509* peerCert, uint8_t* outAlert) {
+    std::lock_guard<std::mutex> guard(mMutex);
+    for (const bssl::UniquePtr<X509>& candidate : mTrustedPeerCertificates) {
+        if (X509_cmp(candidate.get(), peerCert) == 0) return OK;
+    }
+    // Not a trusted leaf: report "certificate unknown" to the peer.
+    *outAlert = SSL_AD_CERTIFICATE_UNKNOWN;
+    return PERMISSION_DENIED;
+}
+
+status_t RpcCertificateVerifierSimple::addTrustedPeerCertificate(CertificateFormat format,
+                                                                 const std::vector<uint8_t>& cert) {
+    bssl::UniquePtr<X509> parsed = deserializeCertificate(cert, format);
+    if (!parsed) {
+        ALOGE("Certificate is not in the proper format %s", PrintToString(format).c_str());
+        return BAD_VALUE;
+    }
+    std::lock_guard<std::mutex> guard(mMutex); // parsing above does not need the lock
+    mTrustedPeerCertificates.emplace_back(std::move(parsed));
return OK;
}
diff --git a/libs/binder/tests/RpcCertificateVerifierSimple.h b/libs/binder/tests/RpcCertificateVerifierSimple.h
index aff5c7c..02aa3c6 100644
--- a/libs/binder/tests/RpcCertificateVerifierSimple.h
+++ b/libs/binder/tests/RpcCertificateVerifierSimple.h
@@ -16,14 +16,38 @@
#pragma once
+#include <mutex>
+#include <string_view>
+#include <vector>
+
+#include <openssl/ssl.h>
+
+#include <binder/CertificateFormat.h>
#include <binder/RpcCertificateVerifier.h>
namespace android {
// A simple certificate verifier for testing.
+// Keeps a list of trusted leaf certificates; certificate chains are not supported. A peer is
+// accepted only if the certificate it presents matches (X509_cmp) one of the trusted ones.
+// All APIs are thread-safe. However, if verify() and addTrustedPeerCertificate() are called
+// simultaneously in different threads, it is not deterministic whether verify() will use the
+// certificate being added.
class RpcCertificateVerifierSimple : public RpcCertificateVerifier {
public:
status_t verify(const X509*, uint8_t*) override;
+
+ // Adds a trusted peer certificate. Peers presenting this certificate are accepted.
+ // Returns BAD_VALUE if |cert| cannot be parsed as |format|.
+ // Callers must ensure that RpcTransportCtx::newTransport() is called only after all trusted
+ // peer certificates are added. Otherwise, RpcTransports created earlier may not trust peer
+ // certificates added later.
+ [[nodiscard]] status_t addTrustedPeerCertificate(CertificateFormat format,
+ const std::vector<uint8_t>& cert);
+
+private:
+ std::mutex mMutex; // guards mTrustedPeerCertificates
+ std::vector<bssl::UniquePtr<X509>> mTrustedPeerCertificates;
};
} // namespace android
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index dbf8899..8c805bb 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -14,6 +14,8 @@
* limitations under the License.
*/
+#include <BinderRpcTestClientInfo.h>
+#include <BinderRpcTestServerInfo.h>
#include <BnBinderRpcCallback.h>
#include <BnBinderRpcSession.h>
#include <BnBinderRpcTest.h>
@@ -40,15 +42,20 @@
#include <thread>
#include <type_traits>
+#include <poll.h>
#include <sys/prctl.h>
#include <unistd.h>
+#include "../FdTrigger.h"
#include "../RpcSocketAddress.h" // for testing preconnected clients
#include "../RpcState.h" // for debugging
#include "../vm_sockets.h" // for VMADDR_*
#include "RpcCertificateVerifierSimple.h"
using namespace std::chrono_literals;
+using testing::AssertionFailure;
+using testing::AssertionResult;
+using testing::AssertionSuccess;
namespace android {
@@ -68,7 +75,6 @@
case RpcSecurity::RAW:
return RpcTransportCtxFactoryRaw::make();
case RpcSecurity::TLS: {
- // TODO(b/198833574): exchange keys and set proper verifier
if (verifier == nullptr) {
verifier = std::make_shared<RpcCertificateVerifierSimple>();
}
@@ -311,14 +317,17 @@
class Process {
public:
Process(Process&&) = default;
- Process(const std::function<void(android::base::borrowed_fd /* writeEnd */)>& f) {
- android::base::unique_fd writeEnd;
- CHECK(android::base::Pipe(&mReadEnd, &writeEnd)) << strerror(errno);
+ Process(const std::function<void(android::base::borrowed_fd /* writeEnd */,
+ android::base::borrowed_fd /* readEnd */)>& f) {
+ android::base::unique_fd childWriteEnd;
+ android::base::unique_fd childReadEnd;
+ CHECK(android::base::Pipe(&mReadEnd, &childWriteEnd)) << strerror(errno);
+ CHECK(android::base::Pipe(&childReadEnd, &mWriteEnd)) << strerror(errno);
if (0 == (mPid = fork())) {
// racey: assume parent doesn't crash before this is set
prctl(PR_SET_PDEATHSIG, SIGHUP);
- f(writeEnd);
+ f(childWriteEnd, childReadEnd);
exit(0);
}
@@ -329,16 +338,20 @@
}
}
android::base::borrowed_fd readEnd() { return mReadEnd; }
+ android::base::borrowed_fd writeEnd() { return mWriteEnd; }
private:
pid_t mPid = 0;
android::base::unique_fd mReadEnd;
+ android::base::unique_fd mWriteEnd;
};
static std::string allocateSocketAddress() {
static size_t id = 0;
std::string temp = getenv("TMPDIR") ?: "/tmp";
- return temp + "/binderRpcTest_" + std::to_string(id++);
+    std::string path = temp + "/binderRpcTest_" + std::to_string(id++);
+    unlink(path.c_str()); // remove any file already at this path before handing it out
+    return path;
};
static unsigned int allocateVsockPort() {
@@ -441,16 +454,17 @@
}
}
-static base::unique_fd connectToUds(const char* addrStr) {
- UnixSocketAddress addr(addrStr);
+static base::unique_fd connectTo(const RpcSocketAddress& addr) {
base::unique_fd serverFd(
TEMP_FAILURE_RETRY(socket(addr.addr()->sa_family, SOCK_STREAM | SOCK_CLOEXEC, 0)));
int savedErrno = errno;
- CHECK(serverFd.ok()) << "Could not create socket " << addrStr << ": " << strerror(savedErrno);
+    CHECK(serverFd.ok()) << "Could not create socket " << addr.toString() << ": "
+                         << strerror(savedErrno);
if (0 != TEMP_FAILURE_RETRY(connect(serverFd.get(), addr.addr(), addr.addrSize()))) {
int savedErrno = errno;
- LOG(FATAL) << "Could not connect to socket " << addrStr << ": " << strerror(savedErrno);
+        LOG(ERROR) << "Could not connect to socket " << addr.toString() << ": " << strerror(savedErrno);
+        return base::unique_fd{}; // not fatal: callers check ok() and may try another address
}
return serverFd;
}
@@ -468,6 +482,37 @@
return PrintToString(type) + "_" + newFactory(security)->toCString();
}
+    static inline void writeString(android::base::borrowed_fd fd, std::string_view str) {
+        const uint64_t size = str.size(); // host-endian length prefix, read back by readString()
+        CHECK(android::base::WriteFully(fd, &size, sizeof(size)));
+        CHECK(android::base::WriteFully(fd, str.data(), str.size()));
+    }
+
+    static inline std::string readString(android::base::borrowed_fd fd) {
+        uint64_t size; // host-endian length prefix, as written by writeString()
+        CHECK(android::base::ReadFully(fd, &size, sizeof(size)));
+        std::string buffer(size, '\0');
+        CHECK(android::base::ReadFully(fd, buffer.data(), buffer.size()));
+        return buffer;
+    }
+
+    static inline void writeToFd(android::base::borrowed_fd fd, const Parcelable& parcelable) {
+        Parcel parcel;
+        CHECK_EQ(OK, parcelable.writeToParcel(&parcel));
+        std::string_view bytes(reinterpret_cast<const char*>(parcel.data()), parcel.dataSize());
+        writeString(fd, bytes); // view of the parcel's buffer; no copy into a std::string
+    }
+
+    template <typename T> // Reads a parcelable of type T written by writeToFd().
+    static inline T readFromFd(android::base::borrowed_fd fd) {
+        const std::string bytes = readString(fd);
+        Parcel parcel;
+        CHECK_EQ(OK, parcel.setData(reinterpret_cast<const uint8_t*>(bytes.data()), bytes.size()));
+        T result; // default-constructed, then filled in by readFromParcel()
+        CHECK_EQ(OK, result.readFromParcel(&parcel));
+        return result;
+    }
+
// This creates a new process serving an interface on a certain number of
// threads.
ProcessSession createRpcTestSocketServerProcess(
@@ -479,11 +524,12 @@
unsigned int vsockPort = allocateVsockPort();
std::string addr = allocateSocketAddress();
- unlink(addr.c_str());
auto ret = ProcessSession{
- .host = Process([&](android::base::borrowed_fd writeEnd) {
- sp<RpcServer> server = RpcServer::make(newFactory(rpcSecurity));
+ .host = Process([&](android::base::borrowed_fd writeEnd,
+ android::base::borrowed_fd readEnd) {
+ auto certVerifier = std::make_shared<RpcCertificateVerifierSimple>();
+ sp<RpcServer> server = RpcServer::make(newFactory(rpcSecurity, certVerifier));
server->iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction();
server->setMaxThreads(options.numThreads);
@@ -508,7 +554,19 @@
LOG_ALWAYS_FATAL("Unknown socket type");
}
- CHECK(android::base::WriteFully(writeEnd, &outPort, sizeof(outPort)));
+ BinderRpcTestServerInfo serverInfo;
+ serverInfo.port = static_cast<int64_t>(outPort);
+ serverInfo.cert.data = server->getCertificate(CertificateFormat::PEM);
+ writeToFd(writeEnd, serverInfo);
+ auto clientInfo = readFromFd<BinderRpcTestClientInfo>(readEnd);
+
+ if (rpcSecurity == RpcSecurity::TLS) {
+ for (const auto& clientCert : clientInfo.certs) {
+ CHECK_EQ(OK,
+ certVerifier->addTrustedPeerCertificate(CertificateFormat::PEM,
+ clientCert.data));
+ }
+ }
configure(server);
@@ -519,23 +577,40 @@
}),
};
- // always read socket, so that we have waited for the server to start
- unsigned int outPort = 0;
- CHECK(android::base::ReadFully(ret.host.readEnd(), &outPort, sizeof(outPort)));
+ std::vector<sp<RpcSession>> sessions;
+ auto certVerifier = std::make_shared<RpcCertificateVerifierSimple>();
+ for (size_t i = 0; i < options.numSessions; i++) {
+ sessions.emplace_back(RpcSession::make(newFactory(rpcSecurity, certVerifier)));
+ }
+
+ auto serverInfo = readFromFd<BinderRpcTestServerInfo>(ret.host.readEnd());
+ BinderRpcTestClientInfo clientInfo;
+ for (const auto& session : sessions) {
+ auto& parcelableCert = clientInfo.certs.emplace_back();
+ parcelableCert.data = session->getCertificate(CertificateFormat::PEM);
+ }
+ writeToFd(ret.host.writeEnd(), clientInfo);
+
+ CHECK_LE(serverInfo.port, std::numeric_limits<unsigned int>::max());
if (socketType == SocketType::INET) {
- CHECK_NE(0, outPort);
+ CHECK_NE(0, serverInfo.port);
+ }
+
+ if (rpcSecurity == RpcSecurity::TLS) {
+ const auto& serverCert = serverInfo.cert.data;
+ CHECK_EQ(OK,
+ certVerifier->addTrustedPeerCertificate(CertificateFormat::PEM, serverCert));
}
status_t status;
- for (size_t i = 0; i < options.numSessions; i++) {
- sp<RpcSession> session = RpcSession::make(newFactory(rpcSecurity));
+ for (const auto& session : sessions) {
session->setMaxThreads(options.numIncomingConnections);
switch (socketType) {
case SocketType::PRECONNECTED:
status = session->setupPreconnectedClient({}, [=]() {
- return connectToUds(addr.c_str());
+ return connectTo(UnixSocketAddress(addr.c_str()));
});
if (status == OK) goto success;
break;
@@ -548,7 +623,7 @@
if (status == OK) goto success;
break;
case SocketType::INET:
- status = session->setupInetClient("127.0.0.1", outPort);
+ status = session->setupInetClient("127.0.0.1", serverInfo.port);
if (status == OK) goto success;
break;
default:
@@ -1221,8 +1296,10 @@
return status == OK;
}
-static std::vector<SocketType> testSocketTypes() {
- std::vector<SocketType> ret = {SocketType::PRECONNECTED, SocketType::UNIX, SocketType::INET};
+static std::vector<SocketType> testSocketTypes(bool hasPreconnected = true) {
+ std::vector<SocketType> ret = {SocketType::UNIX, SocketType::INET};
+
+ if (hasPreconnected) ret.push_back(SocketType::PRECONNECTED);
static bool hasVsockLoopback = testSupportVsockLoopback();
@@ -1291,7 +1368,6 @@
TEST_P(BinderRpcSimple, Shutdown) {
auto addr = allocateSocketAddress();
- unlink(addr.c_str());
auto server = RpcServer::make(newFactory(GetParam()));
server->iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction();
ASSERT_EQ(OK, server->setupUnixDomainServer(addr.c_str()));
@@ -1356,6 +1432,313 @@
INSTANTIATE_TEST_CASE_P(BinderRpc, BinderRpcSimple, ::testing::ValuesIn(RpcSecurityValues()),
BinderRpcSimple::PrintTestParam);
+class RpcTransportTest
+    : public ::testing::TestWithParam<std::tuple<SocketType, RpcSecurity, CertificateFormat>> {
+public:
+    using ConnectToServer = std::function<base::unique_fd()>;
+    static inline std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) {
+        auto [socketType, rpcSecurity, certificateFormat] = info.param;
+        return PrintToString(socketType) + "_" + newFactory(rpcSecurity)->toCString() + "_" +
+                PrintToString(certificateFormat);
+    }
+    void TearDown() override {
+        for (auto& server : mServers) server->shutdown();
+    }
+
+    // A server that accepts connections and sends kMessage over each successful handshake.
+    class Server {
+    public:
+        Server() = default;
+        // Not movable: threads started by start() hold |this|, and ~Server() uses mFdTrigger.
+        Server(Server&&) = delete;
+        ~Server() { shutdown(); }
+        [[nodiscard]] AssertionResult setUp() {
+            auto [socketType, rpcSecurity, certificateFormat] = GetParam();
+            auto rpcServer = RpcServer::make(newFactory(rpcSecurity));
+            rpcServer->iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction();
+            switch (socketType) {
+                case SocketType::PRECONNECTED: {
+                    return AssertionFailure() << "Not supported by this test";
+                } break;
+                case SocketType::UNIX: {
+                    auto addr = allocateSocketAddress();
+                    auto status = rpcServer->setupUnixDomainServer(addr.c_str());
+                    if (status != OK) {
+                        return AssertionFailure()
+                                << "setupUnixDomainServer: " << statusToString(status);
+                    }
+                    mConnectToServer = [addr] {
+                        return connectTo(UnixSocketAddress(addr.c_str()));
+                    };
+                } break;
+                case SocketType::VSOCK: {
+                    auto port = allocateVsockPort();
+                    auto status = rpcServer->setupVsockServer(port);
+                    if (status != OK) {
+                        return AssertionFailure() << "setupVsockServer: " << statusToString(status);
+                    }
+                    mConnectToServer = [port] {
+                        return connectTo(VsockSocketAddress(VMADDR_CID_LOCAL, port));
+                    };
+                } break;
+                case SocketType::INET: {
+                    unsigned int port;
+                    auto status = rpcServer->setupInetServer(kLocalInetAddress, 0, &port);
+                    if (status != OK) {
+                        return AssertionFailure() << "setupInetServer: " << statusToString(status);
+                    }
+                    mConnectToServer = [port] {
+                        const char* addr = kLocalInetAddress;
+                        auto aiStart = InetSocketAddress::getAddrInfo(addr, port);
+                        if (aiStart == nullptr) return base::unique_fd{};
+                        for (auto ai = aiStart.get(); ai != nullptr; ai = ai->ai_next) {
+                            auto fd = connectTo(
+                                    InetSocketAddress(ai->ai_addr, ai->ai_addrlen, addr, port));
+                            if (fd.ok()) return fd;
+                        }
+                        ALOGE("None of the socket address resolved for %s:%u can be connected",
+                              addr, port);
+                        return base::unique_fd{};
+                    };
+                }
+            }
+            mFd = rpcServer->releaseServer();
+            if (!mFd.ok()) return AssertionFailure() << "releaseServer returns invalid fd";
+            mCtx = newFactory(rpcSecurity, mCertVerifier)->newServerCtx();
+            if (mCtx == nullptr) return AssertionFailure() << "newServerCtx";
+            mSetup = true;
+            return AssertionSuccess();
+        }
+        RpcTransportCtx* getCtx() const { return mCtx.get(); }
+        std::shared_ptr<RpcCertificateVerifierSimple> getCertVerifier() const {
+            return mCertVerifier;
+        }
+        ConnectToServer getConnectToServerFn() const { return mConnectToServer; }
+        void start() {
+            LOG_ALWAYS_FATAL_IF(!mSetup, "Call Server::setUp first!");
+            mThread = std::make_unique<std::thread>(&Server::run, this);
+        }
+        void run() {
+            LOG_ALWAYS_FATAL_IF(!mSetup, "Call Server::setUp first!");
+
+            std::vector<std::thread> threads;
+            while (OK == mFdTrigger->triggerablePoll(mFd, POLLIN)) {
+                base::unique_fd acceptedFd(
+                        TEMP_FAILURE_RETRY(accept4(mFd.get(), nullptr, nullptr /*length*/,
+                                                   SOCK_CLOEXEC | SOCK_NONBLOCK)));
+                threads.emplace_back(&Server::handleOne, this, std::move(acceptedFd));
+            }
+
+            for (auto& thread : threads) thread.join();
+        }
+        void handleOne(android::base::unique_fd acceptedFd) {
+            ASSERT_TRUE(acceptedFd.ok());
+            auto serverTransport = mCtx->newTransport(std::move(acceptedFd), mFdTrigger.get());
+            if (serverTransport == nullptr) return; // handshake failed
+            std::string message(kMessage);
+            ASSERT_EQ(OK,
+                      serverTransport->interruptableWriteFully(mFdTrigger.get(), message.data(),
+                                                               message.size()));
+        }
+        void shutdown() {
+            mFdTrigger->trigger();
+            if (mThread != nullptr) {
+                mThread->join();
+                mThread = nullptr;
+            }
+        }
+
+    private:
+        std::unique_ptr<std::thread> mThread;
+        ConnectToServer mConnectToServer;
+        std::unique_ptr<FdTrigger> mFdTrigger = FdTrigger::make();
+        base::unique_fd mFd;
+        std::unique_ptr<RpcTransportCtx> mCtx;
+        std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier =
+                std::make_shared<RpcCertificateVerifierSimple>();
+        bool mSetup = false;
+    };
+
+    class Client {
+    public:
+        explicit Client(ConnectToServer connectToServer) : mConnectToServer(std::move(connectToServer)) {}
+        Client(Client&&) = default;
+        [[nodiscard]] AssertionResult setUp() {
+            auto [socketType, rpcSecurity, certificateFormat] = GetParam();
+            mFd = mConnectToServer();
+            if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server";
+            mCtx = newFactory(rpcSecurity, mCertVerifier)->newClientCtx();
+            if (mCtx == nullptr) return AssertionFailure() << "newClientCtx";
+            return AssertionSuccess();
+        }
+        RpcTransportCtx* getCtx() const { return mCtx.get(); }
+        std::shared_ptr<RpcCertificateVerifierSimple> getCertVerifier() const {
+            return mCertVerifier;
+        }
+        void run(bool handshakeOk = true, bool readOk = true) {
+            auto clientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get());
+            if (clientTransport == nullptr) {
+                ASSERT_FALSE(handshakeOk) << "newTransport returns nullptr, but it shouldn't";
+                return;
+            }
+            ASSERT_TRUE(handshakeOk) << "newTransport does not return nullptr, but it should";
+            std::string expectedMessage(kMessage);
+            std::string readMessage(expectedMessage.size(), '\0');
+            status_t readStatus =
+                    clientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(),
+                                                            readMessage.size());
+            if (readOk) {
+                ASSERT_EQ(OK, readStatus);
+                ASSERT_EQ(readMessage, expectedMessage);
+            } else {
+                ASSERT_NE(OK, readStatus);
+            }
+        }
+
+    private:
+        ConnectToServer mConnectToServer;
+        base::unique_fd mFd;
+        std::unique_ptr<FdTrigger> mFdTrigger = FdTrigger::make();
+        std::unique_ptr<RpcTransportCtx> mCtx;
+        std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier =
+                std::make_shared<RpcCertificateVerifierSimple>();
+    };
+
+    // Makes |a| trust the certificate of |b|. No-op unless the test parameter is TLS.
+    template <typename A, typename B>
+    status_t trust(A* a, B* b) {
+        auto [socketType, rpcSecurity, certificateFormat] = GetParam();
+        if (rpcSecurity != RpcSecurity::TLS) return OK;
+        auto bCert = b->getCtx()->getCertificate(certificateFormat);
+        return a->getCertVerifier()->addTrustedPeerCertificate(certificateFormat, bCert);
+    }
+
+    static constexpr const char* kMessage = "hello";
+    std::vector<std::unique_ptr<Server>> mServers;
+};
+
+TEST_P(RpcTransportTest, GoodCertificate) {
+    Server* server = mServers.emplace_back(std::make_unique<Server>()).get();
+    ASSERT_TRUE(server->setUp());
+
+    Client client(server->getConnectToServerFn());
+    ASSERT_TRUE(client.setUp());
+
+    // Each side trusts the other (trust() is a no-op for RAW), so the exchange must succeed.
+    ASSERT_EQ(OK, trust(&client, server));
+    ASSERT_EQ(OK, trust(server, &client));
+    server->start();
+    client.run();
+}
+
+TEST_P(RpcTransportTest, MultipleClients) {
+    Server* server = mServers.emplace_back(std::make_unique<Server>()).get();
+    ASSERT_TRUE(server->setUp());
+    // Every client and the server trust each other before the server starts accepting.
+    std::vector<Client> clients;
+    for (int i = 0; i < 2; ++i) {
+        Client& newClient = clients.emplace_back(server->getConnectToServerFn());
+        ASSERT_TRUE(newClient.setUp());
+        ASSERT_EQ(OK, trust(&newClient, server));
+        ASSERT_EQ(OK, trust(server, &newClient));
+    }
+
+    server->start();
+    for (Client& c : clients) c.run();
+}
+
+TEST_P(RpcTransportTest, UntrustedServer) {
+    const RpcSecurity rpcSecurity = std::get<RpcSecurity>(GetParam());
+
+    Server* untrustedServer = mServers.emplace_back(std::make_unique<Server>()).get();
+    ASSERT_TRUE(untrustedServer->setUp());
+
+    Client client(untrustedServer->getConnectToServerFn());
+    ASSERT_TRUE(client.setUp());
+
+    // The server trusts the client, but the client never trusts the server.
+    ASSERT_EQ(OK, trust(untrustedServer, &client));
+    untrustedServer->start();
+
+    // With TLS the client must reject the server's certificate during the handshake. With RAW
+    // there is no certificate to check, so the handshake succeeds.
+    const bool handshakeOk = rpcSecurity != RpcSecurity::TLS;
+    client.run(handshakeOk);
+}
+
+TEST_P(RpcTransportTest, MaliciousServer) {
+    const RpcSecurity rpcSecurity = std::get<RpcSecurity>(GetParam());
+    Server* validServer = mServers.emplace_back(std::make_unique<Server>()).get();
+    ASSERT_TRUE(validServer->setUp());
+    Server* maliciousServer = mServers.emplace_back(std::make_unique<Server>()).get();
+    ASSERT_TRUE(maliciousServer->setUp());
+
+    // The client connects to maliciousServer, but only trusts validServer.
+    Client client(maliciousServer->getConnectToServerFn());
+    ASSERT_TRUE(client.setUp());
+
+    ASSERT_EQ(OK, trust(&client, validServer));
+    ASSERT_EQ(OK, trust(validServer, &client));
+    ASSERT_EQ(OK, trust(maliciousServer, &client));
+    maliciousServer->start();
+
+    // With TLS the client must reject maliciousServer's certificate, since it only trusts
+    // validServer's. With RAW there is no certificate to check, so the handshake succeeds.
+    const bool handshakeOk = rpcSecurity != RpcSecurity::TLS;
+    client.run(handshakeOk);
+}
+
+TEST_P(RpcTransportTest, UntrustedClient) {
+    const RpcSecurity rpcSecurity = std::get<RpcSecurity>(GetParam());
+    Server* server = mServers.emplace_back(std::make_unique<Server>()).get();
+    ASSERT_TRUE(server->setUp());
+
+    Client client(server->getConnectToServerFn());
+    ASSERT_TRUE(client.setUp());
+
+    // The client trusts the server, but the server never trusts the client.
+    ASSERT_EQ(OK, trust(&client, server));
+    server->start();
+
+    // With TLS the client can verify the server, so its side of the handshake succeeds. The
+    // server then fails to verify the client and drops the connection, so the client reads
+    // nothing. With RAW there is no verification at all, so the read succeeds.
+    const bool readOk = rpcSecurity != RpcSecurity::TLS;
+    client.run(true /* handshakeOk */, readOk);
+}
+
+TEST_P(RpcTransportTest, MaliciousClient) {
+    auto [socketType, rpcSecurity, certificateFormat] = GetParam();
+    auto server = mServers.emplace_back(std::make_unique<Server>()).get();
+    ASSERT_TRUE(server->setUp());
+
+    Client validClient(server->getConnectToServerFn());
+    ASSERT_TRUE(validClient.setUp());
+    Client maliciousClient(server->getConnectToServerFn());
+    ASSERT_TRUE(maliciousClient.setUp());
+
+    ASSERT_EQ(OK, trust(&validClient, server));
+    ASSERT_EQ(OK, trust(&maliciousClient, server));
+    ASSERT_EQ(OK, trust(server, &validClient)); // the server trusts validClient only
+    server->start();
+
+    // As in UntrustedClient: with TLS the client-side handshake succeeds, but the read fails.
+    bool readOk = rpcSecurity != RpcSecurity::TLS;
+    maliciousClient.run(true, readOk);
+}
+
+static std::vector<CertificateFormat> testCertificateFormats() {
+    return {
+            CertificateFormat::PEM, // TODO(b/195166979): add other formats, e.g. DER, once supported
+    };
+}
+
+INSTANTIATE_TEST_CASE_P(BinderRpc, RpcTransportTest, // every socket type x security x format
+ ::testing::Combine(::testing::ValuesIn(testSocketTypes(false)), // Server::setUp() rejects PRECONNECTED
+ ::testing::ValuesIn(RpcSecurityValues()),
+ ::testing::ValuesIn(testCertificateFormats())),
+ RpcTransportTest::PrintParamInfo);
+
} // namespace android
int main(int argc, char** argv) {