binder: Add utils for (de)serializing key pairs.

Test: RpcTlsUtilsTest
Bug: 199344157
Change-Id: I2ce1dee8c057f99a92a03bceff62648b02d20d0c
diff --git a/libs/binder/RpcTlsUtils.cpp b/libs/binder/RpcTlsUtils.cpp
index 483cc7c..f3ca02a 100644
--- a/libs/binder/RpcTlsUtils.cpp
+++ b/libs/binder/RpcTlsUtils.cpp
@@ -25,54 +25,87 @@
 
 namespace {
 
-bssl::UniquePtr<X509> fromPem(const std::vector<uint8_t>& cert) {
-    if (cert.size() > std::numeric_limits<int>::max()) return nullptr;
-    bssl::UniquePtr<BIO> certBio(BIO_new_mem_buf(cert.data(), static_cast<int>(cert.size())));
-    return bssl::UniquePtr<X509>(PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
+static_assert(sizeof(unsigned char) == sizeof(uint8_t));
+
+// Reads one PEM object from `data` using `fn` (e.g. PEM_read_bio_X509); nullptr on failure.
+template <typename PemReadBioFn, typename T = std::remove_pointer_t<std::invoke_result_t<
+                  PemReadBioFn, BIO*, std::nullptr_t, std::nullptr_t, std::nullptr_t>>>
+bssl::UniquePtr<T> fromPem(const std::vector<uint8_t>& data, PemReadBioFn fn) {
+    if (data.size() > std::numeric_limits<int>::max()) return nullptr; // BIO lengths are int
+    bssl::UniquePtr<BIO> memBio(BIO_new_mem_buf(data.data(), static_cast<int>(data.size())));
+    return bssl::UniquePtr<T>(fn(memBio.get(), nullptr, nullptr, nullptr)); // no password cb
 }
 
-bssl::UniquePtr<X509> fromDer(const std::vector<uint8_t>& cert) {
-    if (cert.size() > std::numeric_limits<long>::max()) return nullptr;
-    const unsigned char* data = cert.data();
-    auto expectedEnd = data + cert.size();
-    bssl::UniquePtr<X509> ret(d2i_X509(nullptr, &data, static_cast<long>(cert.size())));
-    if (data != expectedEnd) {
-        ALOGE("%s: %td bytes remaining!", __PRETTY_FUNCTION__, expectedEnd - data);
+template <typename D2iFn, // e.g. d2i_X509 or d2i_AutoPrivateKey
+          typename T = std::remove_pointer_t<
+                  std::invoke_result_t<D2iFn, std::nullptr_t, const unsigned char**, long>>>
+bssl::UniquePtr<T> fromDer(const std::vector<uint8_t>& data, D2iFn fn) {
+    if (data.size() > std::numeric_limits<long>::max()) return nullptr; // d2i takes a long
+    const auto length = static_cast<long>(data.size());
+    const unsigned char* cursor = data.data(); // on success fn advances this past the object
+    bssl::UniquePtr<T> ret(fn(nullptr, &cursor, length));
+    if (const unsigned char* end = data.data() + data.size(); cursor != end) { // reject leftovers
+        ALOGE("%s: %td bytes remaining!", __PRETTY_FUNCTION__, end - cursor);
         return nullptr;
     }
     return ret;
 }
 
+template <typename T, typename WriteBioFn = int (*)(BIO*, T*)> // e.g. PEM_write_bio_X509
+std::vector<uint8_t> serialize(T* object, WriteBioFn writeBio) { // object -> bytes via mem BIO
+    bssl::UniquePtr<BIO> bio(BIO_new(BIO_s_mem()));
+    TEST_AND_RETURN({}, writeBio(bio.get(), object));
+    const uint8_t* data;
+    size_t len;
+    TEST_AND_RETURN({}, BIO_mem_contents(bio.get(), &data, &len)); // points into bio's buffer
+    return std::vector<uint8_t>(data, data + len); // copy out before bio is freed at scope exit
+}
+
 } // namespace
 
-bssl::UniquePtr<X509> deserializeCertificate(const std::vector<uint8_t>& cert,
+bssl::UniquePtr<X509> deserializeCertificate(const std::vector<uint8_t>& data,
                                              RpcCertificateFormat format) {
     switch (format) {
         case RpcCertificateFormat::PEM:
-            return fromPem(cert);
+            return fromPem(data, PEM_read_bio_X509); // nullptr if too large or unparsable
         case RpcCertificateFormat::DER:
-            return fromDer(cert);
+            return fromDer(data, d2i_X509); // also nullptr if trailing bytes remain
     }
     LOG_ALWAYS_FATAL("Unsupported format %d", static_cast<int>(format));
 }
 
 std::vector<uint8_t> serializeCertificate(X509* x509, RpcCertificateFormat format) {
-    bssl::UniquePtr<BIO> certBio(BIO_new(BIO_s_mem()));
     switch (format) {
-        case RpcCertificateFormat::PEM: {
-            TEST_AND_RETURN({}, PEM_write_bio_X509(certBio.get(), x509));
-        } break;
-        case RpcCertificateFormat::DER: {
-            TEST_AND_RETURN({}, i2d_X509_bio(certBio.get(), x509));
-        } break;
-        default: {
-            LOG_ALWAYS_FATAL("Unsupported format %d", static_cast<int>(format));
-        }
+        case RpcCertificateFormat::PEM: // base64 text form
+            return serialize(x509, PEM_write_bio_X509);
+        case RpcCertificateFormat::DER: // raw binary form
+            return serialize(x509, i2d_X509_bio);
     }
-    const uint8_t* data;
-    size_t len;
-    TEST_AND_RETURN({}, BIO_mem_contents(certBio.get(), &data, &len));
-    return std::vector<uint8_t>(data, data + len);
+    LOG_ALWAYS_FATAL("Unsupported format %d", static_cast<int>(format)); // no case matched
+}
+
+bssl::UniquePtr<EVP_PKEY> deserializeUnencryptedPrivatekey(const std::vector<uint8_t>& data,
+                                                           RpcKeyFormat format) {
+    switch (format) {
+        case RpcKeyFormat::PEM: // fromPem passes no password callback: unencrypted keys only
+            return fromPem(data, PEM_read_bio_PrivateKey);
+        case RpcKeyFormat::DER: // d2i_AutoPrivateKey detects the key type from the encoding
+            return fromDer(data, d2i_AutoPrivateKey);
+    }
+    LOG_ALWAYS_FATAL("Unsupported format %d", static_cast<int>(format));
+}
+
+std::vector<uint8_t> serializeUnencryptedPrivatekey(EVP_PKEY* pkey, RpcKeyFormat format) {
+    switch (format) {
+        case RpcKeyFormat::PEM: // enc == nullptr: the key is written without encryption
+            return serialize(pkey, [](BIO* bio, EVP_PKEY* key) {
+                return PEM_write_bio_PrivateKey(bio, key, nullptr /* enc */, nullptr /* kstr */,
+                                                0 /* klen */, nullptr /* cb */, nullptr /* u */);
+            });
+        case RpcKeyFormat::DER:
+            return serialize(pkey, i2d_PrivateKey_bio);
+    }
+    LOG_ALWAYS_FATAL("Unsupported format %d", static_cast<int>(format));
 }
 
 } // namespace android
diff --git a/libs/binder/include/binder/RpcKeyFormat.h b/libs/binder/include/binder/RpcKeyFormat.h
new file mode 100644
index 0000000..5099c2e
--- /dev/null
+++ b/libs/binder/include/binder/RpcKeyFormat.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Formats for serializing TLS private keys.
+
+#pragma once
+
+#include <string>
+
+namespace android {
+
+enum class RpcKeyFormat {
+    PEM, // base64 text with BEGIN/END armor lines
+    DER, // raw binary ASN.1 DER encoding
+};
+
+static inline std::string PrintToString(RpcKeyFormat format) { // e.g. gtest parameter names
+    switch (format) {
+        case RpcKeyFormat::PEM:
+            return "PEM";
+        case RpcKeyFormat::DER:
+            return "DER";
+    }
+    // No `default:` above, so -Wswitch flags any enumerator added later.
+    return "<unknown>";
+}
+
+} // namespace android
diff --git a/libs/binder/include_tls/binder/RpcTlsUtils.h b/libs/binder/include_tls/binder/RpcTlsUtils.h
index 8d07835..591926b 100644
--- a/libs/binder/include_tls/binder/RpcTlsUtils.h
+++ b/libs/binder/include_tls/binder/RpcTlsUtils.h
@@ -23,12 +23,20 @@
 #include <openssl/ssl.h>
 
 #include <binder/RpcCertificateFormat.h>
+#include <binder/RpcKeyFormat.h>
 
 namespace android {
 
-bssl::UniquePtr<X509> deserializeCertificate(const std::vector<uint8_t>& cert,
+bssl::UniquePtr<X509> deserializeCertificate(const std::vector<uint8_t>& data,
                                              RpcCertificateFormat format);
 
 std::vector<uint8_t> serializeCertificate(X509* x509, RpcCertificateFormat format);
 
+// Deserialize an unencrypted private key from `data`; returns nullptr on failure.
+bssl::UniquePtr<EVP_PKEY> deserializeUnencryptedPrivatekey(const std::vector<uint8_t>& data,
+                                                           RpcKeyFormat format);
+
+// Serialize a private key in unencrypted form; returns an empty vector on failure.
+std::vector<uint8_t> serializeUnencryptedPrivatekey(EVP_PKEY* pkey, RpcKeyFormat format);
+
 } // namespace android
diff --git a/libs/binder/tests/Android.bp b/libs/binder/tests/Android.bp
index 6f3c6e2..23c1b14 100644
--- a/libs/binder/tests/Android.bp
+++ b/libs/binder/tests/Android.bp
@@ -173,6 +173,37 @@
     require_root: true,
 }
 
+cc_test {
+    name: "RpcTlsUtilsTest",
+    host_supported: true,
+    target: {
+        darwin: {
+            enabled: false,
+        },
+        android: {
+            test_suites: ["vts"], // Android-target builds are also listed for VTS
+        },
+    },
+    defaults: [
+        "binder_test_defaults",
+        "libbinder_tls_shared_deps",
+    ],
+    srcs: [
+        "RpcAuthTesting.cpp", // key pair / self-signed cert helpers (RpcAuthTesting.h)
+        "RpcTlsUtilsTest.cpp",
+    ],
+    shared_libs: [
+        "libbinder",
+        "libbase",
+        "libutils",
+        "liblog",
+    ],
+    static_libs: [
+        "libbinder_tls_static", // NOTE(review): presumably provides RpcTlsUtils; confirm
+    ],
+    test_suites: ["general-tests", "device-tests"],
+}
+
 cc_benchmark {
     name: "binderRpcBenchmark",
     defaults: ["binder_test_defaults"],
diff --git a/libs/binder/tests/RpcTlsUtilsTest.cpp b/libs/binder/tests/RpcTlsUtilsTest.cpp
new file mode 100644
index 0000000..9b3078d
--- /dev/null
+++ b/libs/binder/tests/RpcTlsUtilsTest.cpp
@@ -0,0 +1,115 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <binder/RpcTlsUtils.h>
+#include <gtest/gtest.h>
+
+#include "RpcAuthTesting.h"
+
+namespace android {
+
+static std::string toDebugString(EVP_PKEY* pkey) { // human-readable dump for failure messages
+    bssl::UniquePtr<BIO> bio(BIO_new(BIO_s_mem()));
+    int res = EVP_PKEY_print_public(bio.get(), pkey, 2, nullptr);
+    std::string buf = "\nEVP_PKEY_print_public -> " + std::to_string(res) + "\n";
+    if (BIO_write(bio.get(), buf.data(), static_cast<int>(buf.size())) <= 0) return {};
+    res = EVP_PKEY_print_private(bio.get(), pkey, 2, nullptr);
+    buf = "\nEVP_PKEY_print_private -> " + std::to_string(res);
+    if (BIO_write(bio.get(), buf.data(), static_cast<int>(buf.size())) <= 0) return {};
+    const uint8_t* data;
+    size_t len;
+    if (!BIO_mem_contents(bio.get(), &data, &len)) return {};
+    return std::string(reinterpret_cast<const char*>(data), len);
+}
+
+class RpcTlsUtilsKeyTest : public testing::TestWithParam<RpcKeyFormat> {
+public:
+    static std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) {
+        return PrintToString(info.param);
+    }
+};
+
+TEST_P(RpcTlsUtilsKeyTest, Test) { // private key serialize -> deserialize round trip
+    auto pkey = makeKeyPairForSelfSignedCert();
+    ASSERT_NE(nullptr, pkey);
+    auto pkeyData = serializeUnencryptedPrivatekey(pkey.get(), GetParam());
+    auto deserializedPkey = deserializeUnencryptedPrivatekey(pkeyData, GetParam());
+    ASSERT_NE(nullptr, deserializedPkey);
+    EXPECT_EQ(1, EVP_PKEY_cmp(pkey.get(), deserializedPkey.get()))
+            << "expected: " << toDebugString(pkey.get())
+            << "\nactual: " << toDebugString(deserializedPkey.get());
+}
+
+INSTANTIATE_TEST_SUITE_P(RpcTlsUtilsTest, RpcTlsUtilsKeyTest,
+                         testing::Values(RpcKeyFormat::PEM, RpcKeyFormat::DER),
+                         RpcTlsUtilsKeyTest::PrintParamInfo);
+
+class RpcTlsUtilsCertTest : public testing::TestWithParam<RpcCertificateFormat> {
+public:
+    static std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) {
+        return PrintToString(info.param);
+    }
+};
+
+TEST_P(RpcTlsUtilsCertTest, Test) { // certificate serialize -> deserialize round trip
+    auto pkey = makeKeyPairForSelfSignedCert();
+    ASSERT_NE(nullptr, pkey);
+    // Make certificate from the original key in memory
+    auto cert = makeSelfSignedCert(pkey.get(), kCertValidSeconds);
+    ASSERT_NE(nullptr, cert);
+    auto certData = serializeCertificate(cert.get(), GetParam());
+    auto deserializedCert = deserializeCertificate(certData, GetParam());
+    ASSERT_NE(nullptr, deserializedCert);
+    EXPECT_EQ(0, X509_cmp(cert.get(), deserializedCert.get()));
+}
+
+INSTANTIATE_TEST_SUITE_P(RpcTlsUtilsTest, RpcTlsUtilsCertTest,
+                         testing::Values(RpcCertificateFormat::PEM, RpcCertificateFormat::DER),
+                         RpcTlsUtilsCertTest::PrintParamInfo);
+
+class RpcTlsUtilsKeyAndCertTest
+      : public testing::TestWithParam<std::tuple<RpcKeyFormat, RpcCertificateFormat>> {
+public:
+    static std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) {
+        auto [keyFormat, certificateFormat] = info.param;
+        return "key_" + PrintToString(keyFormat) + "_cert_" + PrintToString(certificateFormat);
+    }
+};
+
+TEST_P(RpcTlsUtilsKeyAndCertTest, TestCertFromDeserializedKey) {
+    auto [keyFormat, certificateFormat] = GetParam();
+    auto pkey = makeKeyPairForSelfSignedCert();
+    ASSERT_NE(nullptr, pkey);
+    auto pkeyData = serializeUnencryptedPrivatekey(pkey.get(), keyFormat);
+    auto deserializedPkey = deserializeUnencryptedPrivatekey(pkeyData, keyFormat);
+    ASSERT_NE(nullptr, deserializedPkey);
+
+    // Make certificate from deserialized key loaded from bytes
+    auto cert = makeSelfSignedCert(deserializedPkey.get(), kCertValidSeconds);
+    ASSERT_NE(nullptr, cert);
+    auto certData = serializeCertificate(cert.get(), certificateFormat);
+    auto deserializedCert = deserializeCertificate(certData, certificateFormat);
+    ASSERT_NE(nullptr, deserializedCert);
+    EXPECT_EQ(0, X509_cmp(cert.get(), deserializedCert.get()));
+}
+
+INSTANTIATE_TEST_SUITE_P(RpcTlsUtilsTest, RpcTlsUtilsKeyAndCertTest,
+                         testing::Combine(testing::Values(RpcKeyFormat::PEM, RpcKeyFormat::DER),
+                                          testing::Values(RpcCertificateFormat::PEM,
+                                                          RpcCertificateFormat::DER)),
+                         RpcTlsUtilsKeyAndCertTest::PrintParamInfo);
+
+} // namespace android