binder: implement simple TLS verification for testing

Implements RpcCertificateVerifierSimple with an algorithm
that treats all certificates as leaf certificates.

Fixes existing tests to set certificates properly. Also
adds a test that checks that bad certificates are rejected.

Also adds RpcCertificateUtils, which includes functions for
(de)serializing certificates. These utility functions are useful
for implementing RpcCertificateVerifier.

Test: binderRpcTest
Bug: 195166979
Fixes: 196422181
Fixes: 198833574
Change-Id: I6c1f0f88fe5bc712f3890426d6da26c9ad046d79
diff --git a/libs/binder/Android.bp b/libs/binder/Android.bp
index b0d7478..9bca1f3 100644
--- a/libs/binder/Android.bp
+++ b/libs/binder/Android.bp
@@ -262,6 +262,7 @@
     ],
     srcs: [
         "RpcTransportTls.cpp",
+        "RpcCertificateUtils.cpp",
     ],
 }
 
diff --git a/libs/binder/RpcCertificateUtils.cpp b/libs/binder/RpcCertificateUtils.cpp
new file mode 100644
index 0000000..d0ea3c7
--- /dev/null
+++ b/libs/binder/RpcCertificateUtils.cpp
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "RpcCertificateUtils"
+#include <log/log.h>
+
+#include <limits>
+
+#include <binder/RpcCertificateUtils.h>
+
+#include "Utils.h"
+
+namespace android {
+
+namespace {
+
+// Parses a single PEM-encoded X509 certificate. Returns nullptr on any failure.
+bssl::UniquePtr<X509> fromPem(const std::vector<uint8_t>& cert) {
+    // Guard the narrowing cast to int below.
+    if (cert.size() > std::numeric_limits<int>::max()) return nullptr;
+    bssl::UniquePtr<BIO> certBio(BIO_new_mem_buf(cert.data(), static_cast<int>(cert.size())));
+    if (certBio == nullptr) return nullptr;
+    return bssl::UniquePtr<X509>(PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
+}
+
+} // namespace
+
+// Parses |cert| according to |format|. Returns nullptr if |cert| cannot be parsed.
+// Aborts on an unsupported |format|.
+bssl::UniquePtr<X509> deserializeCertificate(const std::vector<uint8_t>& cert,
+                                             CertificateFormat format) {
+    switch (format) {
+        case CertificateFormat::PEM:
+            return fromPem(cert);
+    }
+    LOG_ALWAYS_FATAL("Unsupported format %d", static_cast<int>(format));
+}
+
+// Serializes |x509| according to |format|. The caller keeps ownership of |x509|. Returns an
+// empty vector if |x509| is null or serialization fails. Aborts on an unsupported |format|.
+std::vector<uint8_t> serializeCertificate(X509* x509, CertificateFormat format) {
+    if (x509 == nullptr) {
+        ALOGE("Cannot serialize a null certificate");
+        return {};
+    }
+    bssl::UniquePtr<BIO> certBio(BIO_new(BIO_s_mem()));
+    TEST_AND_RETURN({}, certBio != nullptr);
+    switch (format) {
+        case CertificateFormat::PEM: {
+            TEST_AND_RETURN({}, PEM_write_bio_X509(certBio.get(), x509));
+        } break;
+        default: {
+            LOG_ALWAYS_FATAL("Unsupported format %d", static_cast<int>(format));
+        }
+    }
+    // |data| points into certBio's buffer; copy it out before certBio is freed.
+    const uint8_t* data;
+    size_t len;
+    TEST_AND_RETURN({}, BIO_mem_contents(certBio.get(), &data, &len));
+    return std::vector<uint8_t>(data, data + len);
+}
+
+} // namespace android
diff --git a/libs/binder/RpcTransportTls.cpp b/libs/binder/RpcTransportTls.cpp
index c42ea9a..2a1dffd 100644
--- a/libs/binder/RpcTransportTls.cpp
+++ b/libs/binder/RpcTransportTls.cpp
@@ -22,6 +22,7 @@
 #include <openssl/bn.h>
 #include <openssl/ssl.h>
 
+#include <binder/RpcCertificateUtils.h>
 #include <binder/RpcTransportTls.h>
 
 #include "FdTrigger.h"
@@ -459,9 +460,9 @@
     std::shared_ptr<RpcCertificateVerifier> mCertVerifier;
 };
 
-std::vector<uint8_t> RpcTransportCtxTls::getCertificate(CertificateFormat) const {
-    // TODO(b/195166979): return certificate here
-    return {};
+std::vector<uint8_t> RpcTransportCtxTls::getCertificate(CertificateFormat format) const {
+    // SSL_CTX_get0_certificate hands back a pointer borrowed from mCtx; it must not be freed.
+    return serializeCertificate(SSL_CTX_get0_certificate(mCtx.get()), format);
+}
 
 // Verify by comparing the leaf of peer certificate with every certificate in
diff --git a/libs/binder/include/binder/CertificateFormat.h b/libs/binder/include/binder/CertificateFormat.h
index 4f7e71e..d33ee7e 100644
--- a/libs/binder/include/binder/CertificateFormat.h
+++ b/libs/binder/include/binder/CertificateFormat.h
@@ -18,6 +18,8 @@
 
 #pragma once
 
+#include <string>
+
 namespace android {
 
 enum class CertificateFormat {
@@ -25,4 +27,13 @@
     // TODO(b/195166979): support other formats, e.g. DER
 };
 
+inline std::string PrintToString(CertificateFormat format) {
+    switch (format) {
+        case CertificateFormat::PEM:
+            return "PEM";
+    }
+    // No default case, so -Wswitch flags any enumerator added later that is not named above.
+    return "<unknown>";
+}
+
 } // namespace android
diff --git a/libs/binder/include_tls/binder/RpcCertificateUtils.h b/libs/binder/include_tls/binder/RpcCertificateUtils.h
new file mode 100644
index 0000000..b7c1849
--- /dev/null
+++ b/libs/binder/include_tls/binder/RpcCertificateUtils.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utilities for serializing and deserializing X509 certificates.
+
+#pragma once
+
+#include <cstdint>
+#include <vector>
+
+#include <openssl/ssl.h>
+
+#include <binder/CertificateFormat.h>
+
+namespace android {
+
+// Parses |cert|, which must be encoded in |format|, into an X509 certificate.
+// Returns nullptr if |cert| cannot be parsed. Aborts if |format| is not supported.
+bssl::UniquePtr<X509> deserializeCertificate(const std::vector<uint8_t>& cert,
+                                             CertificateFormat format);
+
+// Encodes |x509| in |format|. The caller keeps ownership of |x509|. Returns an empty vector
+// if encoding fails. Aborts if |format| is not supported.
+std::vector<uint8_t> serializeCertificate(X509* x509, CertificateFormat format);
+
+} // namespace android
diff --git a/libs/binder/tests/Android.bp b/libs/binder/tests/Android.bp
index a9bc15d..1968058 100644
--- a/libs/binder/tests/Android.bp
+++ b/libs/binder/tests/Android.bp
@@ -120,9 +120,12 @@
     host_supported: true,
     unstable: true,
     srcs: [
+        "BinderRpcTestClientInfo.aidl",
+        "BinderRpcTestServerInfo.aidl",
         "IBinderRpcCallback.aidl",
         "IBinderRpcSession.aidl",
         "IBinderRpcTest.aidl",
+        "ParcelableCertificateData.aidl",
     ],
     backend: {
         java: {
diff --git a/libs/binder/tests/BinderRpcTestClientInfo.aidl b/libs/binder/tests/BinderRpcTestClientInfo.aidl
new file mode 100644
index 0000000..b4baebc
--- /dev/null
+++ b/libs/binder/tests/BinderRpcTestClientInfo.aidl
@@ -0,0 +1,21 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import ParcelableCertificateData;
+
+parcelable BinderRpcTestClientInfo { // sent by the test to the server process
+    ParcelableCertificateData[] certs; // one certificate per client session
+}
diff --git a/libs/binder/tests/BinderRpcTestServerInfo.aidl b/libs/binder/tests/BinderRpcTestServerInfo.aidl
new file mode 100644
index 0000000..00dc0bc
--- /dev/null
+++ b/libs/binder/tests/BinderRpcTestServerInfo.aidl
@@ -0,0 +1,22 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import ParcelableCertificateData;
+
+parcelable BinderRpcTestServerInfo { // sent by the server process back to the test
+    long port; // port reported by the server process; used by INET clients
+    ParcelableCertificateData cert; // the server's own certificate
+}
diff --git a/libs/binder/tests/ParcelableCertificateData.aidl b/libs/binder/tests/ParcelableCertificateData.aidl
new file mode 100644
index 0000000..38c382e
--- /dev/null
+++ b/libs/binder/tests/ParcelableCertificateData.aidl
@@ -0,0 +1,19 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+parcelable ParcelableCertificateData { // a serialized certificate
+    byte[] data; // certificate bytes, e.g. from RpcTransportCtx::getCertificate()
+}
diff --git a/libs/binder/tests/RpcCertificateVerifierSimple.cpp b/libs/binder/tests/RpcCertificateVerifierSimple.cpp
index 68e7c65..0cf0e1c 100644
--- a/libs/binder/tests/RpcCertificateVerifierSimple.cpp
+++ b/libs/binder/tests/RpcCertificateVerifierSimple.cpp
@@ -16,12 +16,32 @@
 #define LOG_TAG "RpcCertificateVerifierSimple"
 #include <log/log.h>
 
+#include <binder/RpcCertificateUtils.h>
+
 #include "RpcCertificateVerifierSimple.h"
 
 namespace android {
 
-status_t RpcCertificateVerifierSimple::verify(const X509*, uint8_t*) {
-    // TODO(b/195166979): implement this
+status_t RpcCertificateVerifierSimple::verify(const X509* peerCert, uint8_t* outAlert) {
+    std::lock_guard<std::mutex> guard(mMutex);
+    // Accept the peer only if its certificate matches one of the trusted leaf certificates.
+    for (const auto& candidate : mTrustedPeerCertificates) {
+        if (X509_cmp(candidate.get(), peerCert) == 0) return OK;
+    }
+    // No match: hand SSL_AD_CERTIFICATE_UNKNOWN back as the alert and reject.
+    *outAlert = SSL_AD_CERTIFICATE_UNKNOWN;
+    return PERMISSION_DENIED;
+}
+
+status_t RpcCertificateVerifierSimple::addTrustedPeerCertificate(CertificateFormat format,
+                                                                 const std::vector<uint8_t>& cert) {
+    auto parsed = deserializeCertificate(cert, format); // parsed without holding mMutex
+    if (!parsed) {
+        ALOGE("Certificate is not in the proper format %s", PrintToString(format).c_str());
+        return BAD_VALUE;
+    }
+    std::lock_guard<std::mutex> guard(mMutex);
+    mTrustedPeerCertificates.emplace_back(std::move(parsed));
     return OK;
 }
 
diff --git a/libs/binder/tests/RpcCertificateVerifierSimple.h b/libs/binder/tests/RpcCertificateVerifierSimple.h
index aff5c7c..02aa3c6 100644
--- a/libs/binder/tests/RpcCertificateVerifierSimple.h
+++ b/libs/binder/tests/RpcCertificateVerifierSimple.h
@@ -16,14 +16,38 @@
 
 #pragma once
 
+#include <mutex>
+#include <string_view>
+#include <vector>
+
+#include <openssl/ssl.h>
+
+#include <binder/CertificateFormat.h>
 #include <binder/RpcCertificateVerifier.h>
 
 namespace android {
 
 // A simple certificate verifier for testing.
+// Keeps a list of trusted leaf certificates. Certificate chains are not supported.
+//
+// All APIs are thread-safe. However, if verify() and addTrustedPeerCertificate() are called
+// concurrently from different threads, whether verify() sees the certificate being added is
+// not deterministic.
 class RpcCertificateVerifierSimple : public RpcCertificateVerifier {
 public:
     status_t verify(const X509*, uint8_t*) override;
+
+    // Adds a trusted peer certificate; peers presenting it are accepted. Returns BAD_VALUE if
+    // |cert| cannot be parsed as |format|.
+    //
+    // Add all trusted certificates before calling RpcTransportCtx::newTransport(); transports
+    // created earlier may not trust certificates added later.
+    [[nodiscard]] status_t addTrustedPeerCertificate(CertificateFormat format,
+                                                     const std::vector<uint8_t>& cert);
+
+private:
+    std::mutex mMutex; // guards mTrustedPeerCertificates
+    std::vector<bssl::UniquePtr<X509>> mTrustedPeerCertificates;
 };
 
 } // namespace android
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index dbf8899..8c805bb 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -14,6 +14,8 @@
  * limitations under the License.
  */
 
+#include <BinderRpcTestClientInfo.h>
+#include <BinderRpcTestServerInfo.h>
 #include <BnBinderRpcCallback.h>
 #include <BnBinderRpcSession.h>
 #include <BnBinderRpcTest.h>
@@ -40,15 +42,20 @@
 #include <thread>
 #include <type_traits>
 
+#include <poll.h>
 #include <sys/prctl.h>
 #include <unistd.h>
 
+#include "../FdTrigger.h"
 #include "../RpcSocketAddress.h" // for testing preconnected clients
 #include "../RpcState.h"   // for debugging
 #include "../vm_sockets.h" // for VMADDR_*
 #include "RpcCertificateVerifierSimple.h"
 
 using namespace std::chrono_literals;
+using testing::AssertionFailure;
+using testing::AssertionResult;
+using testing::AssertionSuccess;
 
 namespace android {
 
@@ -68,7 +75,6 @@
         case RpcSecurity::RAW:
             return RpcTransportCtxFactoryRaw::make();
         case RpcSecurity::TLS: {
-            // TODO(b/198833574): exchange keys and set proper verifier
             if (verifier == nullptr) {
                 verifier = std::make_shared<RpcCertificateVerifierSimple>();
             }
@@ -311,14 +317,17 @@
 class Process {
 public:
     Process(Process&&) = default;
-    Process(const std::function<void(android::base::borrowed_fd /* writeEnd */)>& f) {
-        android::base::unique_fd writeEnd;
-        CHECK(android::base::Pipe(&mReadEnd, &writeEnd)) << strerror(errno);
+    Process(const std::function<void(android::base::borrowed_fd /* writeEnd */,
+                                     android::base::borrowed_fd /* readEnd */)>& f) {
+        android::base::unique_fd childWriteEnd;
+        android::base::unique_fd childReadEnd;
+        CHECK(android::base::Pipe(&mReadEnd, &childWriteEnd)) << strerror(errno);
+        CHECK(android::base::Pipe(&childReadEnd, &mWriteEnd)) << strerror(errno);
         if (0 == (mPid = fork())) {
             // racey: assume parent doesn't crash before this is set
             prctl(PR_SET_PDEATHSIG, SIGHUP);
 
-            f(writeEnd);
+            f(childWriteEnd, childReadEnd);
 
             exit(0);
         }
@@ -329,16 +338,20 @@
         }
     }
     android::base::borrowed_fd readEnd() { return mReadEnd; }
+    android::base::borrowed_fd writeEnd() { return mWriteEnd; }
 
 private:
     pid_t mPid = 0;
     android::base::unique_fd mReadEnd;
+    android::base::unique_fd mWriteEnd;
 };
 
 static std::string allocateSocketAddress() {
     static size_t id = 0;
     std::string temp = getenv("TMPDIR") ?: "/tmp";
+    std::string path = temp + "/binderRpcTest_" + std::to_string(id++);
+    unlink(path.c_str()); // best effort: clear anything already at this path (e.g. old socket)
+    return path;
 };
 
 static unsigned int allocateVsockPort() {
@@ -441,16 +454,20 @@
     }
 }
 
-static base::unique_fd connectToUds(const char* addrStr) {
-    UnixSocketAddress addr(addrStr);
+// Connects a new stream socket to |addr|; aborts if the socket cannot be created.
+// If connect() fails, logs an error and returns an invalid fd so callers can try elsewhere.
+static base::unique_fd connectTo(const RpcSocketAddress& addr) {
     base::unique_fd serverFd(
             TEMP_FAILURE_RETRY(socket(addr.addr()->sa_family, SOCK_STREAM | SOCK_CLOEXEC, 0)));
     int savedErrno = errno;
-    CHECK(serverFd.ok()) << "Could not create socket " << addrStr << ": " << strerror(savedErrno);
+    CHECK(serverFd.ok()) << "Could not create socket " << addr.toString() << ": "
+                         << strerror(savedErrno);
 
     if (0 != TEMP_FAILURE_RETRY(connect(serverFd.get(), addr.addr(), addr.addrSize()))) {
         int savedErrno = errno;
-        LOG(FATAL) << "Could not connect to socket " << addrStr << ": " << strerror(savedErrno);
+        LOG(ERROR) << "Could not connect to socket " << addr.toString() << ": "
+                   << strerror(savedErrno);
+        return {};
     }
     return serverFd;
 }
@@ -468,6 +482,37 @@
         return PrintToString(type) + "_" + newFactory(security)->toCString();
     }
 
+    static inline void writeString(android::base::borrowed_fd fd, std::string_view str) {
+        const uint64_t size = str.size(); // fixed-width length prefix, then the raw bytes
+        CHECK(android::base::WriteFully(fd, &size, sizeof(size)));
+        CHECK(android::base::WriteFully(fd, str.data(), str.size()));
+    }
+
+    static inline std::string readString(android::base::borrowed_fd fd) {
+        uint64_t size; // length prefix written by writeString()
+        CHECK(android::base::ReadFully(fd, &size, sizeof(size)));
+        std::string payload(size, '\0');
+        CHECK(android::base::ReadFully(fd, payload.data(), size));
+        return payload;
+    }
+
+    static inline void writeToFd(android::base::borrowed_fd fd, const Parcelable& parcelable) {
+        Parcel parcel; // scratch parcel holding the flattened |parcelable|
+        CHECK_EQ(OK, parcelable.writeToParcel(&parcel));
+        writeString(fd, std::string_view(reinterpret_cast<const char*>(parcel.data()),
+                                         parcel.dataSize()));
+    }
+
+    template <typename T>
+    static inline T readFromFd(android::base::borrowed_fd fd) {
+        const std::string blob = readString(fd); // one blob produced by writeToFd()
+        Parcel parcel;
+        CHECK_EQ(OK, parcel.setData(reinterpret_cast<const uint8_t*>(blob.data()), blob.size()));
+        T result;
+        CHECK_EQ(OK, result.readFromParcel(&parcel));
+        return result;
+    }
+
     // This creates a new process serving an interface on a certain number of
     // threads.
     ProcessSession createRpcTestSocketServerProcess(
@@ -479,11 +524,12 @@
 
         unsigned int vsockPort = allocateVsockPort();
         std::string addr = allocateSocketAddress();
-        unlink(addr.c_str());
 
         auto ret = ProcessSession{
-                .host = Process([&](android::base::borrowed_fd writeEnd) {
-                    sp<RpcServer> server = RpcServer::make(newFactory(rpcSecurity));
+                .host = Process([&](android::base::borrowed_fd writeEnd,
+                                    android::base::borrowed_fd readEnd) {
+                    auto certVerifier = std::make_shared<RpcCertificateVerifierSimple>();
+                    sp<RpcServer> server = RpcServer::make(newFactory(rpcSecurity, certVerifier));
 
                     server->iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction();
                     server->setMaxThreads(options.numThreads);
@@ -508,7 +554,19 @@
                             LOG_ALWAYS_FATAL("Unknown socket type");
                     }
 
-                    CHECK(android::base::WriteFully(writeEnd, &outPort, sizeof(outPort)));
+                    BinderRpcTestServerInfo serverInfo;
+                    serverInfo.port = static_cast<int64_t>(outPort);
+                    serverInfo.cert.data = server->getCertificate(CertificateFormat::PEM);
+                    writeToFd(writeEnd, serverInfo);
+                    auto clientInfo = readFromFd<BinderRpcTestClientInfo>(readEnd);
+
+                    if (rpcSecurity == RpcSecurity::TLS) {
+                        for (const auto& clientCert : clientInfo.certs) {
+                            CHECK_EQ(OK,
+                                     certVerifier->addTrustedPeerCertificate(CertificateFormat::PEM,
+                                                                             clientCert.data));
+                        }
+                    }
 
                     configure(server);
 
@@ -519,23 +577,40 @@
                 }),
         };
 
-        // always read socket, so that we have waited for the server to start
-        unsigned int outPort = 0;
-        CHECK(android::base::ReadFully(ret.host.readEnd(), &outPort, sizeof(outPort)));
+        std::vector<sp<RpcSession>> sessions;
+        auto certVerifier = std::make_shared<RpcCertificateVerifierSimple>();
+        for (size_t i = 0; i < options.numSessions; i++) {
+            sessions.emplace_back(RpcSession::make(newFactory(rpcSecurity, certVerifier)));
+        }
+
+        auto serverInfo = readFromFd<BinderRpcTestServerInfo>(ret.host.readEnd());
+        BinderRpcTestClientInfo clientInfo;
+        for (const auto& session : sessions) {
+            auto& parcelableCert = clientInfo.certs.emplace_back();
+            parcelableCert.data = session->getCertificate(CertificateFormat::PEM);
+        }
+        writeToFd(ret.host.writeEnd(), clientInfo);
+
+        CHECK_LE(serverInfo.port, std::numeric_limits<unsigned int>::max());
         if (socketType == SocketType::INET) {
-            CHECK_NE(0, outPort);
+            CHECK_NE(0, serverInfo.port);
+        }
+
+        if (rpcSecurity == RpcSecurity::TLS) {
+            const auto& serverCert = serverInfo.cert.data;
+            CHECK_EQ(OK,
+                     certVerifier->addTrustedPeerCertificate(CertificateFormat::PEM, serverCert));
         }
 
         status_t status;
 
-        for (size_t i = 0; i < options.numSessions; i++) {
-            sp<RpcSession> session = RpcSession::make(newFactory(rpcSecurity));
+        for (const auto& session : sessions) {
             session->setMaxThreads(options.numIncomingConnections);
 
             switch (socketType) {
                 case SocketType::PRECONNECTED:
                     status = session->setupPreconnectedClient({}, [=]() {
-                        return connectToUds(addr.c_str());
+                        return connectTo(UnixSocketAddress(addr.c_str()));
                     });
                     if (status == OK) goto success;
                     break;
@@ -548,7 +623,7 @@
                     if (status == OK) goto success;
                     break;
                 case SocketType::INET:
-                    status = session->setupInetClient("127.0.0.1", outPort);
+                    status = session->setupInetClient("127.0.0.1", serverInfo.port);
                     if (status == OK) goto success;
                     break;
                 default:
@@ -1221,8 +1296,10 @@
     return status == OK;
 }
 
-static std::vector<SocketType> testSocketTypes() {
-    std::vector<SocketType> ret = {SocketType::PRECONNECTED, SocketType::UNIX, SocketType::INET};
+static std::vector<SocketType> testSocketTypes(bool hasPreconnected = true) {
+    std::vector<SocketType> ret = {SocketType::UNIX, SocketType::INET};
+
+    if (hasPreconnected) ret.push_back(SocketType::PRECONNECTED);
 
     static bool hasVsockLoopback = testSupportVsockLoopback();
 
@@ -1291,7 +1368,6 @@
 
 TEST_P(BinderRpcSimple, Shutdown) {
     auto addr = allocateSocketAddress();
-    unlink(addr.c_str());
     auto server = RpcServer::make(newFactory(GetParam()));
     server->iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction();
     ASSERT_EQ(OK, server->setupUnixDomainServer(addr.c_str()));
@@ -1356,6 +1432,313 @@
 INSTANTIATE_TEST_CASE_P(BinderRpc, BinderRpcSimple, ::testing::ValuesIn(RpcSecurityValues()),
                         BinderRpcSimple::PrintTestParam);
 
+class RpcTransportTest
+      : public ::testing::TestWithParam<std::tuple<SocketType, RpcSecurity, CertificateFormat>> {
+public:
+    using ConnectToServer = std::function<base::unique_fd()>;
+    static inline std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) {
+        auto [socketType, rpcSecurity, certificateFormat] = info.param;
+        return PrintToString(socketType) + "_" + newFactory(rpcSecurity)->toCString() + "_" +
+                PrintToString(certificateFormat);
+    }
+    void TearDown() override {
+        for (auto& server : mServers) server->shutdown();
+    }
+
+    // A server that handles client socket connections.
+    class Server {
+    public:
+        explicit Server() {}
+        Server(Server&&) = default;
+        ~Server() { shutdown(); }
+        [[nodiscard]] AssertionResult setUp() {
+            auto [socketType, rpcSecurity, certificateFormat] = GetParam();
+            auto rpcServer = RpcServer::make(newFactory(rpcSecurity));
+            rpcServer->iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction();
+            switch (socketType) {
+                case SocketType::PRECONNECTED: {
+                    return AssertionFailure() << "Not supported by this test";
+                } break;
+                case SocketType::UNIX: {
+                    auto addr = allocateSocketAddress();
+                    auto status = rpcServer->setupUnixDomainServer(addr.c_str());
+                    if (status != OK) {
+                        return AssertionFailure()
+                                << "setupUnixDomainServer: " << statusToString(status);
+                    }
+                    mConnectToServer = [addr] {
+                        return connectTo(UnixSocketAddress(addr.c_str()));
+                    };
+                } break;
+                case SocketType::VSOCK: {
+                    auto port = allocateVsockPort();
+                    auto status = rpcServer->setupVsockServer(port);
+                    if (status != OK) {
+                        return AssertionFailure() << "setupVsockServer: " << statusToString(status);
+                    }
+                    mConnectToServer = [port] {
+                        return connectTo(VsockSocketAddress(VMADDR_CID_LOCAL, port));
+                    };
+                } break;
+                case SocketType::INET: {
+                    unsigned int port;
+                    auto status = rpcServer->setupInetServer(kLocalInetAddress, 0, &port);
+                    if (status != OK) {
+                        return AssertionFailure() << "setupInetServer: " << statusToString(status);
+                    }
+                    mConnectToServer = [port] {
+                        const char* addr = kLocalInetAddress;
+                        auto aiStart = InetSocketAddress::getAddrInfo(addr, port);
+                        if (aiStart == nullptr) return base::unique_fd{};
+                        for (auto ai = aiStart.get(); ai != nullptr; ai = ai->ai_next) {
+                            auto fd = connectTo(
+                                    InetSocketAddress(ai->ai_addr, ai->ai_addrlen, addr, port));
+                            if (fd.ok()) return fd;
+                        }
+                        ALOGE("None of the socket address resolved for %s:%u can be connected",
+                              addr, port);
+                        return base::unique_fd{};
+                    };
+                }
+            }
+            mFd = rpcServer->releaseServer();
+            if (!mFd.ok()) return AssertionFailure() << "releaseServer returns invalid fd";
+            mCtx = newFactory(rpcSecurity, mCertVerifier)->newServerCtx();
+            if (mCtx == nullptr) return AssertionFailure() << "newServerCtx";
+            mSetup = true;
+            return AssertionSuccess();
+        }
+        RpcTransportCtx* getCtx() const { return mCtx.get(); }
+        std::shared_ptr<RpcCertificateVerifierSimple> getCertVerifier() const {
+            return mCertVerifier;
+        }
+        ConnectToServer getConnectToServerFn() { return mConnectToServer; }
+        void start() {
+            LOG_ALWAYS_FATAL_IF(!mSetup, "Call Server::setup first!");
+            mThread = std::make_unique<std::thread>(&Server::run, this);
+        }
+        void run() {
+            LOG_ALWAYS_FATAL_IF(!mSetup, "Call Server::setup first!");
+
+            std::vector<std::thread> threads;
+            while (OK == mFdTrigger->triggerablePoll(mFd, POLLIN)) {
+                base::unique_fd acceptedFd(
+                        TEMP_FAILURE_RETRY(accept4(mFd.get(), nullptr, nullptr /*length*/,
+                                                   SOCK_CLOEXEC | SOCK_NONBLOCK)));
+                threads.emplace_back(&Server::handleOne, this, std::move(acceptedFd));
+            }
+
+            for (auto& thread : threads) thread.join();
+        }
+        void handleOne(android::base::unique_fd acceptedFd) {
+            ASSERT_TRUE(acceptedFd.ok());
+            auto serverTransport = mCtx->newTransport(std::move(acceptedFd), mFdTrigger.get());
+            if (serverTransport == nullptr) return; // handshake failed
+            std::string message(kMessage);
+            ASSERT_EQ(OK,
+                      serverTransport->interruptableWriteFully(mFdTrigger.get(), message.data(),
+                                                               message.size()));
+        }
+        void shutdown() {
+            mFdTrigger->trigger();
+            if (mThread != nullptr) {
+                mThread->join();
+                mThread = nullptr;
+            }
+        }
+
+    private:
+        std::unique_ptr<std::thread> mThread;
+        ConnectToServer mConnectToServer;
+        std::unique_ptr<FdTrigger> mFdTrigger = FdTrigger::make();
+        base::unique_fd mFd;
+        std::unique_ptr<RpcTransportCtx> mCtx;
+        std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier =
+                std::make_shared<RpcCertificateVerifierSimple>();
+        bool mSetup = false;
+    };
+
+    class Client {
+    public:
+        explicit Client(ConnectToServer connectToServer) : mConnectToServer(connectToServer) {}
+        Client(Client&&) = default;
+        [[nodiscard]] AssertionResult setUp() {
+            auto [socketType, rpcSecurity, certificateFormat] = GetParam();
+            mFd = mConnectToServer();
+            if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server";
+            mFdTrigger = FdTrigger::make();
+            mCtx = newFactory(rpcSecurity, mCertVerifier)->newClientCtx();
+            if (mCtx == nullptr) return AssertionFailure() << "newClientCtx";
+            return AssertionSuccess();
+        }
+        RpcTransportCtx* getCtx() const { return mCtx.get(); }
+        std::shared_ptr<RpcCertificateVerifierSimple> getCertVerifier() const {
+            return mCertVerifier;
+        }
+        void run(bool handshakeOk = true, bool readOk = true) {
+            auto clientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get());
+            if (clientTransport == nullptr) {
+                ASSERT_FALSE(handshakeOk) << "newTransport returns nullptr, but it shouldn't";
+                return;
+            }
+            ASSERT_TRUE(handshakeOk) << "newTransport does not return nullptr, but it should";
+            std::string expectedMessage(kMessage);
+            std::string readMessage(expectedMessage.size(), '\0');
+            status_t readStatus =
+                    clientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(),
+                                                            readMessage.size());
+            if (readOk) {
+                ASSERT_EQ(OK, readStatus);
+                ASSERT_EQ(readMessage, expectedMessage);
+            } else {
+                ASSERT_NE(OK, readStatus);
+            }
+        }
+
+    private:
+        ConnectToServer mConnectToServer;
+        base::unique_fd mFd;
+        std::unique_ptr<FdTrigger> mFdTrigger = FdTrigger::make();
+        std::unique_ptr<RpcTransportCtx> mCtx;
+        std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier =
+                std::make_shared<RpcCertificateVerifierSimple>();
+    };
+
+    // Makes `a` trust `b`. With TLS, this fetches b's certificate in the parameterized format
+    // and registers it with a's verifier, returning the verifier's status. Without TLS
+    // nothing is verified, so this is a no-op returning OK. b->getCtx() must be non-null,
+    // i.e. `b` must already be set up.
+    template <typename A, typename B>
+    status_t trust(A* a, B* b) {
+        const auto& [socketType, rpcSecurity, certificateFormat] = GetParam();
+        if (rpcSecurity == RpcSecurity::TLS) {
+            auto peerCert = b->getCtx()->getCertificate(certificateFormat);
+            return a->getCertVerifier()->addTrustedPeerCertificate(certificateFormat, peerCert);
+        }
+        return OK;
+    }
+
+    // Payload the client expects to read after connecting (compared in Client::run).
+    static constexpr const char* kMessage = "hello";
+    // Servers created by each test. They are owned here so the raw pointers the tests hold
+    // stay valid for the whole test.
+    std::vector<std::unique_ptr<Server>> mServers;
+};
+
+// A client and a server that trust each other's certificates complete the handshake, and the
+// client reads the expected message.
+TEST_P(RpcTransportTest, GoodCertificate) {
+    Server* server = mServers.emplace_back(std::make_unique<Server>()).get();
+    ASSERT_TRUE(server->setUp());
+    Client client(server->getConnectToServerFn());
+    ASSERT_TRUE(client.setUp());
+
+    // Mutual trust (a no-op unless the transport uses TLS).
+    ASSERT_EQ(OK, trust(&client, server));
+    ASSERT_EQ(OK, trust(server, &client));
+
+    server->start();
+    client.run(/*handshakeOk=*/true, /*readOk=*/true);
+}
+
+// Several clients, each in mutual trust with the same server, can all connect and read the
+// expected message.
+TEST_P(RpcTransportTest, MultipleClients) {
+    Server* server = mServers.emplace_back(std::make_unique<Server>()).get();
+    ASSERT_TRUE(server->setUp());
+
+    constexpr int kNumClients = 2;
+    std::vector<Client> clients;
+    for (int i = 0; i < kNumClients; ++i) {
+        Client& client = clients.emplace_back(server->getConnectToServerFn());
+        ASSERT_TRUE(client.setUp());
+        ASSERT_EQ(OK, trust(&client, server));
+        ASSERT_EQ(OK, trust(server, &client));
+    }
+
+    server->start();
+    for (Client& client : clients) {
+        client.run();
+    }
+}
+
+// The client is never given the server's certificate. With TLS the client must reject the
+// handshake. With raw sockets there is nothing to verify, so the connection succeeds.
+TEST_P(RpcTransportTest, UntrustedServer) {
+    Server* untrustedServer = mServers.emplace_back(std::make_unique<Server>()).get();
+    ASSERT_TRUE(untrustedServer->setUp());
+
+    Client client(untrustedServer->getConnectToServerFn());
+    ASSERT_TRUE(client.setUp());
+
+    // One-way trust: only the server trusts the client.
+    ASSERT_EQ(OK, trust(untrustedServer, &client));
+
+    untrustedServer->start();
+
+    const auto& [socketType, rpcSecurity, certificateFormat] = GetParam();
+    const bool handshakeOk = rpcSecurity != RpcSecurity::TLS;
+    client.run(handshakeOk);
+}
+
+// The client trusts a valid server but connects to a different, malicious server that trusts
+// the client. With TLS the client must reject the malicious server's certificate. Raw sockets
+// cannot tell the two servers apart, so the handshake succeeds.
+TEST_P(RpcTransportTest, MaliciousServer) {
+    Server* validServer = mServers.emplace_back(std::make_unique<Server>()).get();
+    ASSERT_TRUE(validServer->setUp());
+    Server* maliciousServer = mServers.emplace_back(std::make_unique<Server>()).get();
+    ASSERT_TRUE(maliciousServer->setUp());
+
+    Client client(maliciousServer->getConnectToServerFn());
+    ASSERT_TRUE(client.setUp());
+
+    // The client and validServer trust each other, and maliciousServer also trusts the client.
+    // Only the client's check of the server certificate is left to fail.
+    ASSERT_EQ(OK, trust(&client, validServer));
+    ASSERT_EQ(OK, trust(validServer, &client));
+    ASSERT_EQ(OK, trust(maliciousServer, &client));
+
+    maliciousServer->start();
+
+    const auto& [socketType, rpcSecurity, certificateFormat] = GetParam();
+    const bool handshakeOk = rpcSecurity != RpcSecurity::TLS;
+    client.run(handshakeOk);
+}
+
+// The server is never given the client's certificate. With TLS the client can still verify
+// the server, so its handshake succeeds. The server cannot verify the client and should drop
+// the connection, so the client's read fails. With raw sockets nothing is verified and the
+// read succeeds.
+TEST_P(RpcTransportTest, UntrustedClient) {
+    Server* server = mServers.emplace_back(std::make_unique<Server>()).get();
+    ASSERT_TRUE(server->setUp());
+
+    Client client(server->getConnectToServerFn());
+    ASSERT_TRUE(client.setUp());
+
+    // One-way trust: the client trusts the server, but not the other way round.
+    ASSERT_EQ(OK, trust(&client, server));
+
+    server->start();
+
+    const auto& [socketType, rpcSecurity, certificateFormat] = GetParam();
+    const bool readOk = rpcSecurity != RpcSecurity::TLS;
+    client.run(/*handshakeOk=*/true, readOk);
+}
+
+// Both clients trust the server, but the server trusts only validClient. With TLS the server
+// must reject maliciousClient's certificate, so maliciousClient's read fails even though the
+// server does trust a client. With raw sockets nothing is verified and the read succeeds.
+TEST_P(RpcTransportTest, MaliciousClient) {
+    auto [socketType, rpcSecurity, certificateFormat] = GetParam();
+    auto server = mServers.emplace_back(std::make_unique<Server>()).get();
+    ASSERT_TRUE(server->setUp());
+
+    Client validClient(server->getConnectToServerFn());
+    ASSERT_TRUE(validClient.setUp());
+    Client maliciousClient(server->getConnectToServerFn());
+    ASSERT_TRUE(maliciousClient.setUp());
+
+    // validClient and the server trust each other. maliciousClient trusts the server, but the
+    // server never learns maliciousClient's certificate. The server must trust validClient;
+    // otherwise this test would just duplicate UntrustedClient.
+    ASSERT_EQ(OK, trust(&validClient, server));
+    ASSERT_EQ(OK, trust(server, &validClient));
+    ASSERT_EQ(OK, trust(&maliciousClient, server));
+
+    server->start();
+
+    // See UntrustedClient.
+    bool readOk = rpcSecurity != RpcSecurity::TLS;
+    maliciousClient.run(true, readOk);
+}
+
+// Certificate formats that RpcTransportTest is parameterized over. Currently only PEM.
+std::vector<CertificateFormat> testCertificateFormats() {
+    return {CertificateFormat::PEM};
+}
+
+// Runs every RpcTransportTest for each combination of socket type, RPC security mode and
+// certificate format. NOTE(review): confirm which socket types the `false` argument to
+// testSocketTypes() excludes.
+INSTANTIATE_TEST_CASE_P(
+        BinderRpc, RpcTransportTest,
+        ::testing::Combine(::testing::ValuesIn(testSocketTypes(false)),
+                           ::testing::ValuesIn(RpcSecurityValues()),
+                           ::testing::ValuesIn(testCertificateFormats())),
+        RpcTransportTest::PrintParamInfo);
+
 } // namespace android
 
 int main(int argc, char** argv) {