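binder: Add tests for using pre-signed certificates.

Move the server/client helpers out of the RpcTransportTest fixture into
a standalone RpcTransportTestUtils class that takes its parameters
explicitly, so that they can be shared with the new
RpcTransportTlsKeyTest. The new test generates a key pair and a
self-signed certificate, serializes and deserializes them in each
combination of PEM/DER certificate and key formats, and passes the
result to the server as an RpcAuthPreSigned.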

Test: binderRpcTest
Fixes: 199344157
Change-Id: I0f9d8ce3d4fadecd197d87393f689bdeb35dbc56
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index 2fd1a2a..0e7e259 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -31,6 +31,7 @@
 #include <binder/ProcessState.h>
 #include <binder/RpcServer.h>
 #include <binder/RpcSession.h>
+#include <binder/RpcTlsUtils.h>
 #include <binder/RpcTransport.h>
 #include <binder/RpcTransportRaw.h>
 #include <binder/RpcTransportTls.h>
@@ -1439,37 +1440,10 @@
 INSTANTIATE_TEST_CASE_P(BinderRpc, BinderRpcSimple, ::testing::ValuesIn(RpcSecurityValues()),
                         BinderRpcSimple::PrintTestParam);
 
-class RpcTransportTest
-      : public ::testing::TestWithParam<
-                std::tuple<SocketType, RpcSecurity, std::optional<RpcCertificateFormat>>> {
+class RpcTransportTestUtils {
 public:
+    using Param = std::tuple<SocketType, RpcSecurity, std::optional<RpcCertificateFormat>>;
     using ConnectToServer = std::function<base::unique_fd()>;
-    static inline std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) {
-        auto [socketType, rpcSecurity, certificateFormat] = info.param;
-        auto ret = PrintToString(socketType) + "_" + newFactory(rpcSecurity)->toCString();
-        if (certificateFormat.has_value()) ret += "_" + PrintToString(*certificateFormat);
-        return ret;
-    }
-    static std::vector<ParamType> getRpcTranportTestParams() {
-        std::vector<RpcTransportTest::ParamType> ret;
-        for (auto socketType : testSocketTypes(false /* hasPreconnected */)) {
-            for (auto rpcSecurity : RpcSecurityValues()) {
-                switch (rpcSecurity) {
-                    case RpcSecurity::RAW: {
-                        ret.emplace_back(socketType, rpcSecurity, std::nullopt);
-                    } break;
-                    case RpcSecurity::TLS: {
-                        ret.emplace_back(socketType, rpcSecurity, RpcCertificateFormat::PEM);
-                        ret.emplace_back(socketType, rpcSecurity, RpcCertificateFormat::DER);
-                    } break;
-                }
-            }
-        }
-        return ret;
-    }
-    void TearDown() override {
-        for (auto& server : mServers) server->shutdownAndWait();
-    }
 
     // A server that handles client socket connections.
     class Server {
@@ -1477,8 +1451,10 @@
         explicit Server() {}
         Server(Server&&) = default;
         ~Server() { shutdownAndWait(); }
-        [[nodiscard]] AssertionResult setUp() {
-            auto [socketType, rpcSecurity, certificateFormat] = GetParam();
+        [[nodiscard]] AssertionResult setUp(
+                const Param& param,
+                std::unique_ptr<RpcAuth> auth = std::make_unique<RpcAuthSelfSigned>()) {
+            auto [socketType, rpcSecurity, certificateFormat] = param;
             auto rpcServer = RpcServer::make(newFactory(rpcSecurity));
             rpcServer->iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction();
             switch (socketType) {
@@ -1529,7 +1505,7 @@
             }
             mFd = rpcServer->releaseServer();
             if (!mFd.ok()) return AssertionFailure() << "releaseServer returns invalid fd";
-            mCtx = newFactory(rpcSecurity, mCertVerifier)->newServerCtx();
+            mCtx = newFactory(rpcSecurity, mCertVerifier, std::move(auth))->newServerCtx();
             if (mCtx == nullptr) return AssertionFailure() << "newServerCtx";
             mSetup = true;
             return AssertionSuccess();
@@ -1608,8 +1584,8 @@
     public:
         explicit Client(ConnectToServer connectToServer) : mConnectToServer(connectToServer) {}
         Client(Client&&) = default;
-        [[nodiscard]] AssertionResult setUp() {
-            auto [socketType, rpcSecurity, certificateFormat] = GetParam();
+        [[nodiscard]] AssertionResult setUp(const Param& param) {
+            auto [socketType, rpcSecurity, certificateFormat] = param;
             mFdTrigger = FdTrigger::make();
             mCtx = newFactory(rpcSecurity, mCertVerifier)->newClientCtx();
             if (mCtx == nullptr) return AssertionFailure() << "newClientCtx";
@@ -1662,8 +1638,9 @@
 
     // Make A trust B.
     template <typename A, typename B>
-    status_t trust(A* a, B* b) {
-        auto [socketType, rpcSecurity, certificateFormat] = GetParam();
+    static status_t trust(RpcSecurity rpcSecurity,
+                          std::optional<RpcCertificateFormat> certificateFormat, const A& a,
+                          const B& b) {
         if (rpcSecurity != RpcSecurity::TLS) return OK;
         LOG_ALWAYS_FATAL_IF(!certificateFormat.has_value());
         auto bCert = b->getCtx()->getCertificate(*certificateFormat);
@@ -1671,15 +1648,53 @@
     }
 
     static constexpr const char* kMessage = "hello";
-    std::vector<std::unique_ptr<Server>> mServers;
+};
+
+// Test fixture parameterized over (socket type, RPC security, optional certificate format).
+// Thin wrapper around RpcTransportTestUtils; trust() takes its settings from GetParam().
+class RpcTransportTest : public testing::TestWithParam<RpcTransportTestUtils::Param> {
+public:
+    using Server = RpcTransportTestUtils::Server;
+    using Client = RpcTransportTestUtils::Client;
+    static std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) {
+        auto [socketType, rpcSecurity, certificateFormat] = info.param;
+        auto ret = PrintToString(socketType) + "_" + newFactory(rpcSecurity)->toCString();
+        if (certificateFormat.has_value()) ret += "_" + PrintToString(*certificateFormat);
+        return ret;
+    }
+    // Pairs every socket type with RAW (no certificate format) and with TLS in both the PEM
+    // and the DER certificate formats.
+    static std::vector<ParamType> getRpcTranportTestParams() {
+        std::vector<ParamType> ret;
+        for (auto socketType : testSocketTypes(false /* hasPreconnected */)) {
+            for (auto rpcSecurity : RpcSecurityValues()) {
+                switch (rpcSecurity) {
+                    case RpcSecurity::RAW: {
+                        ret.emplace_back(socketType, rpcSecurity, std::nullopt);
+                    } break;
+                    case RpcSecurity::TLS: {
+                        ret.emplace_back(socketType, rpcSecurity, RpcCertificateFormat::PEM);
+                        ret.emplace_back(socketType, rpcSecurity, RpcCertificateFormat::DER);
+                    } break;
+                }
+            }
+        }
+        return ret;
+    }
+    // Make A trust B, using the security and certificate format of the current parameter.
+    template <typename A, typename B>
+    status_t trust(const A& a, const B& b) {
+        const auto& [socketType, rpcSecurity, certificateFormat] = GetParam();
+        return RpcTransportTestUtils::trust(rpcSecurity, certificateFormat, a, b);
+    }
 };
 
 TEST_P(RpcTransportTest, GoodCertificate) {
-    auto server = mServers.emplace_back(std::make_unique<Server>()).get();
-    ASSERT_TRUE(server->setUp());
+    auto server = std::make_unique<Server>();
+    ASSERT_TRUE(server->setUp(GetParam()));
 
     Client client(server->getConnectToServerFn());
-    ASSERT_TRUE(client.setUp());
+    ASSERT_TRUE(client.setUp(GetParam()));
 
     ASSERT_EQ(OK, trust(&client, server));
     ASSERT_EQ(OK, trust(server, &client));
@@ -1689,13 +1699,13 @@
 }
 
 TEST_P(RpcTransportTest, MultipleClients) {
-    auto server = mServers.emplace_back(std::make_unique<Server>()).get();
-    ASSERT_TRUE(server->setUp());
+    auto server = std::make_unique<Server>();
+    ASSERT_TRUE(server->setUp(GetParam()));
 
     std::vector<Client> clients;
     for (int i = 0; i < 2; i++) {
         auto& client = clients.emplace_back(server->getConnectToServerFn());
-        ASSERT_TRUE(client.setUp());
+        ASSERT_TRUE(client.setUp(GetParam()));
         ASSERT_EQ(OK, trust(&client, server));
         ASSERT_EQ(OK, trust(server, &client));
     }
@@ -1707,11 +1717,11 @@
 TEST_P(RpcTransportTest, UntrustedServer) {
     auto [socketType, rpcSecurity, certificateFormat] = GetParam();
 
-    auto untrustedServer = mServers.emplace_back(std::make_unique<Server>()).get();
-    ASSERT_TRUE(untrustedServer->setUp());
+    auto untrustedServer = std::make_unique<Server>();
+    ASSERT_TRUE(untrustedServer->setUp(GetParam()));
 
     Client client(untrustedServer->getConnectToServerFn());
-    ASSERT_TRUE(client.setUp());
+    ASSERT_TRUE(client.setUp(GetParam()));
 
     ASSERT_EQ(OK, trust(untrustedServer, &client));
 
@@ -1724,14 +1734,14 @@
 }
 TEST_P(RpcTransportTest, MaliciousServer) {
     auto [socketType, rpcSecurity, certificateFormat] = GetParam();
-    auto validServer = mServers.emplace_back(std::make_unique<Server>()).get();
-    ASSERT_TRUE(validServer->setUp());
+    auto validServer = std::make_unique<Server>();
+    ASSERT_TRUE(validServer->setUp(GetParam()));
 
-    auto maliciousServer = mServers.emplace_back(std::make_unique<Server>()).get();
-    ASSERT_TRUE(maliciousServer->setUp());
+    auto maliciousServer = std::make_unique<Server>();
+    ASSERT_TRUE(maliciousServer->setUp(GetParam()));
 
     Client client(maliciousServer->getConnectToServerFn());
-    ASSERT_TRUE(client.setUp());
+    ASSERT_TRUE(client.setUp(GetParam()));
 
     ASSERT_EQ(OK, trust(&client, validServer));
     ASSERT_EQ(OK, trust(validServer, &client));
@@ -1747,11 +1757,11 @@
 
 TEST_P(RpcTransportTest, UntrustedClient) {
     auto [socketType, rpcSecurity, certificateFormat] = GetParam();
-    auto server = mServers.emplace_back(std::make_unique<Server>()).get();
-    ASSERT_TRUE(server->setUp());
+    auto server = std::make_unique<Server>();
+    ASSERT_TRUE(server->setUp(GetParam()));
 
     Client client(server->getConnectToServerFn());
-    ASSERT_TRUE(client.setUp());
+    ASSERT_TRUE(client.setUp(GetParam()));
 
     ASSERT_EQ(OK, trust(&client, server));
 
@@ -1766,13 +1776,13 @@
 
 TEST_P(RpcTransportTest, MaliciousClient) {
     auto [socketType, rpcSecurity, certificateFormat] = GetParam();
-    auto server = mServers.emplace_back(std::make_unique<Server>()).get();
-    ASSERT_TRUE(server->setUp());
+    auto server = std::make_unique<Server>();
+    ASSERT_TRUE(server->setUp(GetParam()));
 
     Client validClient(server->getConnectToServerFn());
-    ASSERT_TRUE(validClient.setUp());
+    ASSERT_TRUE(validClient.setUp(GetParam()));
     Client maliciousClient(server->getConnectToServerFn());
-    ASSERT_TRUE(maliciousClient.setUp());
+    ASSERT_TRUE(maliciousClient.setUp(GetParam()));
 
     ASSERT_EQ(OK, trust(&validClient, server));
     ASSERT_EQ(OK, trust(&maliciousClient, server));
@@ -1790,7 +1800,7 @@
     std::condition_variable writeCv;
     bool shouldContinueWriting = false;
     auto serverPostConnect = [&](RpcTransport* serverTransport, FdTrigger* fdTrigger) {
-        std::string message(kMessage);
+        std::string message(RpcTransportTestUtils::kMessage);
         auto status =
                 serverTransport->interruptableWriteFully(fdTrigger, message.data(), message.size());
         if (status != OK) return AssertionFailure() << statusToString(status);
@@ -1810,12 +1820,12 @@
         return AssertionSuccess();
     };
 
-    auto server = mServers.emplace_back(std::make_unique<Server>()).get();
-    ASSERT_TRUE(server->setUp());
+    auto server = std::make_unique<Server>();
+    ASSERT_TRUE(server->setUp(GetParam()));
 
     // Set up client
     Client client(server->getConnectToServerFn());
-    ASSERT_TRUE(client.setUp());
+    ASSERT_TRUE(client.setUp(GetParam()));
 
     // Exchange keys
     ASSERT_EQ(OK, trust(&client, server));
@@ -1828,7 +1838,7 @@
     ASSERT_TRUE(client.setUpTransport());
     // read the first message. This ensures that server has finished handshake and start handling
     // client fd. Server thread should pause at writeCv.wait_for().
-    ASSERT_TRUE(client.readMessage(kMessage));
+    ASSERT_TRUE(client.readMessage(RpcTransportTestUtils::kMessage));
     // Trigger server shutdown after server starts handling client FD. This ensures that the second
     // write is on an FdTrigger that has been shut down.
     server->shutdown();
@@ -1848,6 +1858,70 @@
                         ::testing::ValuesIn(RpcTransportTest::getRpcTranportTestParams()),
                         RpcTransportTest::PrintParamInfo);
 
+// Parameterized over socket type, certificate format and key format; always uses TLS.
+class RpcTransportTlsKeyTest
+      : public testing::TestWithParam<std::tuple<SocketType, RpcCertificateFormat, RpcKeyFormat>> {
+public:
+    // Make A trust B over TLS, using the certificate format of the current parameter.
+    template <typename A, typename B>
+    status_t trust(const A& a, const B& b) {
+        const auto& [socketType, certificateFormat, keyFormat] = GetParam();
+        return RpcTransportTestUtils::trust(RpcSecurity::TLS, certificateFormat, a, b);
+    }
+    static std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) {
+        auto [socketType, certificateFormat, keyFormat] = info.param;
+        auto ret = PrintToString(socketType) + "_certificate_" + PrintToString(certificateFormat) +
+                "_key_" + PrintToString(keyFormat);
+        return ret;
+    }
+};
+
+// Generates a key pair and a self-signed certificate, round-trips both through serialization in
+// the parameterized key/certificate formats, and hands the deserialized result to the server as
+// an RpcAuthPreSigned before running a client against it.
+TEST_P(RpcTransportTlsKeyTest, PreSignedCertificate) {
+    auto [socketType, certificateFormat, keyFormat] = GetParam();
+
+    std::vector<uint8_t> pkeyData, certData;
+    {
+        auto pkey = makeKeyPairForSelfSignedCert();
+        ASSERT_NE(nullptr, pkey);
+        auto cert = makeSelfSignedCert(pkey.get(), kCertValidSeconds);
+        ASSERT_NE(nullptr, cert);
+        pkeyData = serializeUnencryptedPrivatekey(pkey.get(), keyFormat);
+        certData = serializeCertificate(cert.get(), certificateFormat);
+    }
+
+    // Report deserialization failures here rather than passing null pointers to
+    // RpcAuthPreSigned.
+    auto desPkey = deserializeUnencryptedPrivatekey(pkeyData, keyFormat);
+    ASSERT_NE(nullptr, desPkey);
+    auto desCert = deserializeCertificate(certData, certificateFormat);
+    ASSERT_NE(nullptr, desCert);
+    auto auth = std::make_unique<RpcAuthPreSigned>(std::move(desPkey), std::move(desCert));
+    auto utilsParam =
+            std::make_tuple(socketType, RpcSecurity::TLS, std::make_optional(certificateFormat));
+
+    auto server = std::make_unique<RpcTransportTestUtils::Server>();
+    ASSERT_TRUE(server->setUp(utilsParam, std::move(auth)));
+
+    RpcTransportTestUtils::Client client(server->getConnectToServerFn());
+    ASSERT_TRUE(client.setUp(utilsParam));
+
+    ASSERT_EQ(OK, trust(&client, server));
+    ASSERT_EQ(OK, trust(server, &client));
+
+    server->start();
+    client.run();
+}
+
+INSTANTIATE_TEST_CASE_P(
+        BinderRpc, RpcTransportTlsKeyTest,
+        testing::Combine(testing::ValuesIn(testSocketTypes(false /* hasPreconnected*/)),
+                         testing::Values(RpcCertificateFormat::PEM, RpcCertificateFormat::DER),
+                         testing::Values(RpcKeyFormat::PEM, RpcKeyFormat::DER)),
+        RpcTransportTlsKeyTest::PrintParamInfo);
+
 } // namespace android
 
 int main(int argc, char** argv) {