Merge "Add support for an existing CompOS cert."
diff --git a/ondevice-signing/CertUtils.cpp b/ondevice-signing/CertUtils.cpp
index 10abfe2..9867f62 100644
--- a/ondevice-signing/CertUtils.cpp
+++ b/ondevice-signing/CertUtils.cpp
@@ -14,6 +14,8 @@
  * limitations under the License.
  */
 
+#include "CertUtils.h"
+
 #include <android-base/logging.h>
 #include <android-base/result.h>
 
@@ -32,6 +34,7 @@
 const char kBasicConstraints[] = "CA:TRUE";
 const char kKeyUsage[] = "critical,keyCertSign,cRLSign,digitalSignature";
 const char kSubjectKeyIdentifier[] = "hash";
+const char kAuthorityKeyIdentifier[] = "keyid:always";
 constexpr int kCertLifetimeSeconds = 10 * 365 * 24 * 60 * 60;
 
 using android::base::Result;
@@ -91,6 +94,21 @@
     return {};
 }
 
+static Result<bssl::UniquePtr<EVP_PKEY>> toRsaPkey(const std::vector<uint8_t>& publicKey) {
+    // Wrap the raw public key bytes in an RSA key with the correct exponent.
+    auto rsaPubkey = getRsa(publicKey);
+    if (!rsaPubkey.ok()) {
+        return rsaPubkey.error();
+    }
+    bssl::UniquePtr<EVP_PKEY> public_key(EVP_PKEY_new());
+    // EVP_PKEY_assign_RSA only takes ownership of the RSA key on success.
+    if (!public_key || !EVP_PKEY_assign_RSA(public_key.get(), rsaPubkey->get())) {
+        return Error() << "Failed to assign key";
+    }
+    (void)rsaPubkey->release();  // Now owned by public_key.
+    return public_key;
+}
+
 Result<void> createSelfSignedCertificate(
     const std::vector<uint8_t>& publicKey,
     const std::function<Result<std::string>(const std::string&)>& signFunction,
@@ -105,17 +123,12 @@
     X509_gmtime_adj(X509_get_notBefore(x509.get()), 0);
     X509_gmtime_adj(X509_get_notAfter(x509.get()), kCertLifetimeSeconds);
 
-    // "publicKey" corresponds to the raw public key bytes - need to create
-    // a new RSA key with the correct exponent.
-    auto rsaPubkey = getRsa(publicKey);
-    if (!rsaPubkey.ok()) {
-        return rsaPubkey.error();
+    auto public_key = toRsaPkey(publicKey);
+    if (!public_key.ok()) {
+        return public_key.error();
     }
 
-    bssl::UniquePtr<EVP_PKEY> public_key(EVP_PKEY_new());
-    EVP_PKEY_assign_RSA(public_key.get(), rsaPubkey->release());
-
-    if (!X509_set_pubkey(x509.get(), public_key.get())) {
+    if (!X509_set_pubkey(x509.get(), public_key.value().get())) {
         return Error() << "Unable to set x509 public key";
     }
 
@@ -136,7 +149,7 @@
     add_ext(x509.get(), NID_basic_constraints, kBasicConstraints);
     add_ext(x509.get(), NID_key_usage, kKeyUsage);
     add_ext(x509.get(), NID_subject_key_identifier, kSubjectKeyIdentifier);
-    add_ext(x509.get(), NID_authority_key_identifier, "keyid:always");
+    add_ext(x509.get(), NID_authority_key_identifier, kAuthorityKeyIdentifier);
 
     bssl::UniquePtr<X509_ALGOR> algor(X509_ALGOR_new());
     if (!algor ||
@@ -201,9 +214,9 @@
     return extractPublicKey(public_key.get());
 }
 
-Result<std::vector<uint8_t>> extractPublicKeyFromX509(const std::vector<uint8_t>& keyData) {
-    auto keyDataBytes = keyData.data();
-    bssl::UniquePtr<X509> decoded_cert(d2i_X509(nullptr, &keyDataBytes, keyData.size()));
+Result<std::vector<uint8_t>> extractPublicKeyFromX509(const std::vector<uint8_t>& derCert) {
+    auto derCertBytes = derCert.data();
+    bssl::UniquePtr<X509> decoded_cert(d2i_X509(nullptr, &derCertBytes, derCert.size()));
     if (decoded_cert.get() == nullptr) {
         return Error() << "Failed to decode X509 certificate.";
     }
@@ -212,7 +225,7 @@
     return extractPublicKey(decoded_pkey.get());
 }
 
-Result<std::vector<uint8_t>> extractPublicKeyFromX509(const std::string& path) {
+static Result<bssl::UniquePtr<X509>> loadX509(const std::string& path) {
     X509* rawCert;
     auto f = fopen(path.c_str(), "re");
     if (f == nullptr) {
@@ -225,7 +238,58 @@
     bssl::UniquePtr<X509> cert(rawCert);
 
     fclose(f);
-    return extractPublicKey(X509_get_pubkey(cert.get()));
+    return cert;
+}
+
+Result<std::vector<uint8_t>> extractPublicKeyFromX509(const std::string& path) {
+    auto cert = loadX509(path);
+    if (!cert.ok()) return cert.error();
+    // X509_get_pubkey returns a new reference; own it so it is freed on return.
+    bssl::UniquePtr<EVP_PKEY> pkey(X509_get_pubkey(cert.value().get()));
+    return extractPublicKey(pkey.get());
+}
+
+Result<CertInfo> verifyAndExtractCertInfoFromX509(const std::string& path,
+                                                  const std::vector<uint8_t>& publicKey) {
+    auto public_key = toRsaPkey(publicKey);  // Key that must have signed the cert.
+    if (!public_key.ok()) {
+        return public_key.error();
+    }
+    // `cert` owns the X509 for the remainder of this function.
+    auto cert = loadX509(path);
+    if (!cert.ok()) {
+        return cert.error();
+    }
+    X509* x509 = cert.value().get();  // Borrowed; cert keeps ownership.
+
+    // Make sure we signed it: the signature must verify under publicKey.
+    if (X509_verify(x509, public_key.value().get()) != 1) {
+        return Error() << "Failed to verify certificate.";
+    }
+
+    bssl::UniquePtr<EVP_PKEY> pkey(X509_get_pubkey(x509));  // Owns the returned reference.
+    auto subject_key = extractPublicKey(pkey.get());
+    if (!subject_key.ok()) {
+        return subject_key.error();
+    }
+
+    // The pointers here are all owned by x509, and each function handles an
+    // error return from the previous call correctly.
+    X509_NAME* name = X509_get_subject_name(x509);
+    int index = X509_NAME_get_index_by_NID(name, NID_commonName, -1);
+    X509_NAME_ENTRY* entry = X509_NAME_get_entry(name, index);
+    ASN1_STRING* asn1cn = X509_NAME_ENTRY_get_data(entry);
+    unsigned char* utf8cn = nullptr;
+    int length = ASN1_STRING_to_UTF8(&utf8cn, asn1cn);
+    if (length < 0) {
+        return Error() << "Failed to read subject CN";
+    }
+    // Take ownership of the buffer allocated by ASN1_STRING_to_UTF8.
+    bssl::UniquePtr<unsigned char> utf8owner(utf8cn);
+    std::string cn(reinterpret_cast<char*>(utf8cn), static_cast<size_t>(length));
+
+    CertInfo cert_info{std::move(cn), std::move(subject_key.value())};
+    return cert_info;
 }
 
 Result<std::vector<uint8_t>> createPkcs7(const std::vector<uint8_t>& signed_digest) {
diff --git a/ondevice-signing/CertUtils.h b/ondevice-signing/CertUtils.h
index 66dff04..d202fbc 100644
--- a/ondevice-signing/CertUtils.h
+++ b/ondevice-signing/CertUtils.h
@@ -18,6 +18,11 @@
 
 #include <android-base/result.h>
 
+struct CertInfo {  // Subject details read from a verified X.509 certificate.
+    std::string subjectCn;            // Subject common name (CN).
+    std::vector<uint8_t> subjectKey;  // Subject public key, same format as extractPublicKeyFromX509.
+};
+
 android::base::Result<void> createSelfSignedCertificate(
     const std::vector<uint8_t>& publicKey,
     const std::function<android::base::Result<std::string>(const std::string&)>& signFunction,
@@ -30,6 +35,9 @@
 extractPublicKeyFromSubjectPublicKeyInfo(const std::vector<uint8_t>& subjectKeyInfo);
 android::base::Result<std::vector<uint8_t>> extractPublicKeyFromX509(const std::string& path);
 
+android::base::Result<CertInfo>  // Fails unless the cert at |path| was signed by |publicKey|.
+verifyAndExtractCertInfoFromX509(const std::string& path, const std::vector<uint8_t>& publicKey);
+
 android::base::Result<void> verifySignature(const std::string& message,
                                             const std::string& signature,
                                             const std::vector<uint8_t>& publicKey);
diff --git a/ondevice-signing/VerityUtils.cpp b/ondevice-signing/VerityUtils.cpp
index 2c4dc6d..25f949c 100644
--- a/ondevice-signing/VerityUtils.cpp
+++ b/ondevice-signing/VerityUtils.cpp
@@ -243,8 +243,8 @@
     return digests;
 }
 
-Result<void> addCertToFsVerityKeyring(const std::string& path) {
-    const char* const argv[] = {kFsVerityInitPath, "--load-extra-key", "fsv_ods"};
+Result<void> addCertToFsVerityKeyring(const std::string& path, const char* keyName) {
+    const char* const argv[] = {kFsVerityInitPath, "--load-extra-key", keyName};
 
     int fd = open(path.c_str(), O_RDONLY | O_CLOEXEC);
     pid_t pid = fork();
diff --git a/ondevice-signing/VerityUtils.h b/ondevice-signing/VerityUtils.h
index 84af319..dca3184 100644
--- a/ondevice-signing/VerityUtils.h
+++ b/ondevice-signing/VerityUtils.h
@@ -20,7 +20,7 @@
 
 #include "SigningKey.h"
 
-android::base::Result<void> addCertToFsVerityKeyring(const std::string& path);
+android::base::Result<void> addCertToFsVerityKeyring(const std::string& path, const char* keyName);
 android::base::Result<std::vector<uint8_t>> createDigest(const std::string& path);
 android::base::Result<std::map<std::string, std::string>>
 verifyAllFilesInVerity(const std::string& path);
diff --git a/ondevice-signing/odsign_main.cpp b/ondevice-signing/odsign_main.cpp
index 0991704..135c4a0 100644
--- a/ondevice-signing/odsign_main.cpp
+++ b/ondevice-signing/odsign_main.cpp
@@ -44,7 +44,6 @@
 
 using OdsignInfo = ::odsign::proto::OdsignInfo;
 
-const std::string kSigningKeyBlob = "/data/misc/odsign/key.blob";
 const std::string kSigningKeyCert = "/data/misc/odsign/key.cert";
 const std::string kOdsignInfo = "/data/misc/odsign/odsign.info";
 const std::string kOdsignInfoSignature = "/data/misc/odsign/odsign.info.signature";
@@ -56,6 +55,10 @@
 static const char* kFsVerityProcPath = "/proc/sys/fs/verity";
 
 static const bool kForceCompilation = false;
+static const bool kUseCompOs = false;  // STOPSHIP if true
+
+static const char* kVirtApexPath = "/apex/com.android.virt";  // Probed by compOsPresent().
+const std::string kCompOsCert = "/data/misc/odsign/compos_key.cert";  // CompOS leaf cert.
 
 static const char* kOdsignVerificationDoneProp = "odsign.verification.done";
 static const char* kOdsignKeyDoneProp = "odsign.key.done";
@@ -64,13 +67,17 @@
 static const char* kOdsignVerificationStatusValid = "1";
 static const char* kOdsignVerificationStatusError = "0";
 
-Result<void> verifyExistingCert(const SigningKey& key) {
+bool compOsPresent() {
+    return !access(kVirtApexPath, F_OK);  // access() yields 0 iff the path exists.
+}
+
+Result<void> verifyExistingRootCert(const SigningKey& key) {
     if (access(kSigningKeyCert.c_str(), F_OK) < 0) {
         return ErrnoError() << "Key certificate not found: " << kSigningKeyCert;
     }
     auto trustedPublicKey = key.getPublicKey();
     if (!trustedPublicKey.ok()) {
-        return Error() << "Failed to retrieve signing public key.";
+        return Error() << "Failed to retrieve signing public key: " << trustedPublicKey.error();
     }
 
     auto publicKeyFromExistingCert = extractPublicKeyFromX509(kSigningKeyCert);
@@ -82,11 +89,12 @@
                        << " does not match signing public key.";
     }
 
-    // At this point, we know the cert matches
+    // At this point, we know the cert is for our key; it's unimportant whether it's
+    // actually self-signed.
     return {};
 }
 
-Result<void> createX509Cert(const SigningKey& key, const std::string& outPath) {
+Result<void> createX509RootCert(const SigningKey& key, const std::string& outPath) {
     auto publicKey = key.getPublicKey();
 
     if (!publicKey.ok()) {
@@ -98,6 +106,32 @@
     return {};
 }
 
+Result<std::vector<uint8_t>> extractPublicKeyFromLeafCert(const SigningKey& key,
+                                                          const std::string& certPath,
+                                                          const std::string& expectedCn) {
+    if (access(certPath.c_str(), F_OK) < 0) {
+        return ErrnoError() << "Certificate not found: " << certPath;
+    }
+    auto trustedPublicKey = key.getPublicKey();
+    if (!trustedPublicKey.ok()) {
+        return Error() << "Failed to retrieve signing public key: " << trustedPublicKey.error();
+    }
+    // The leaf certificate must have been signed by our trusted signing key...
+    auto existingCertInfo = verifyAndExtractCertInfoFromX509(certPath, trustedPublicKey.value());
+    if (!existingCertInfo.ok()) {
+        return Error() << "Failed to verify certificate at " << certPath << ": "
+                       << existingCertInfo.error();
+    }
+    // ...and must have been issued for the expected subject.
+    const auto& actualCn = existingCertInfo.value().subjectCn;
+    if (actualCn != expectedCn) {
+        return Error() << "CN of existing certificate at " << certPath << " is " << actualCn
+                       << ", should be " << expectedCn;
+    }
+
+    return existingCertInfo.value().subjectKey;
+}
+
 art::odrefresh::ExitCode compileArtifacts(bool force) {
     const char* const argv[] = {kOdrefreshPath, force ? "--force-compile" : "--compile"};
     const int exit_code =
@@ -263,7 +297,7 @@
     // by the next boot.
     SetProperty(kOdsignKeyDoneProp, "1");
     if (!signInfo.ok()) {
-        return Error() << signInfo.error().message();
+        return signInfo.error();
     }
     std::map<std::string, std::string> trusted_digests(signInfo->file_hashes().begin(),
                                                        signInfo->file_hashes().end());
@@ -275,7 +309,7 @@
         integrityStatus = verifyIntegrityNoFsVerity(trusted_digests);
     }
     if (!integrityStatus.ok()) {
-        return Error() << integrityStatus.error().message();
+        return integrityStatus.error();
     }
 
     return {};
@@ -310,13 +344,15 @@
         LOG(INFO) << "Device doesn't support fsverity. Falling back to full verification.";
     }
 
+    bool supportsCompOs = kUseCompOs && supportsFsVerity && compOsPresent();
+
     if (supportsFsVerity) {
-        auto existing_cert = verifyExistingCert(*key);
+        auto existing_cert = verifyExistingRootCert(*key);
         if (!existing_cert.ok()) {
             LOG(WARNING) << existing_cert.error().message();
 
             // Try to create a new cert
-            auto new_cert = createX509Cert(*key, kSigningKeyCert);
+            auto new_cert = createX509RootCert(*key, kSigningKeyCert);
             if (!new_cert.ok()) {
                 LOG(ERROR) << "Failed to create X509 certificate: " << new_cert.error().message();
                 // TODO apparently the key become invalid - delete the blob / cert
@@ -325,7 +361,7 @@
         } else {
             LOG(INFO) << "Found and verified existing public key certificate: " << kSigningKeyCert;
         }
-        auto cert_add_result = addCertToFsVerityKeyring(kSigningKeyCert);
+        auto cert_add_result = addCertToFsVerityKeyring(kSigningKeyCert, "fsv_ods");
         if (!cert_add_result.ok()) {
             LOG(ERROR) << "Failed to add certificate to fs-verity keyring: "
                        << cert_add_result.error().message();
@@ -333,6 +369,27 @@
         }
     }
 
+    if (supportsCompOs) {
+        auto compos_key = extractPublicKeyFromLeafCert(*key, kCompOsCert, "CompOS");
+        if (compos_key.ok()) {
+            auto cert_add_result = addCertToFsVerityKeyring(kCompOsCert, "fsv_compos");
+            if (cert_add_result.ok()) {
+                LOG(INFO) << "Added CompOs key to fs-verity keyring";
+            } else {
+                LOG(ERROR) << "Failed to add CompOs certificate to fs-verity keyring: "
+                           << cert_add_result.error().message();
+                // TODO - what do we do now?
+                // return -1;
+            }
+        } else {
+            LOG(ERROR) << "Failed to retrieve key from CompOs certificate: "
+                       << compos_key.error().message();
+            // Best efforts only - nothing we can do if deletion fails.
+            unlink(kCompOsCert.c_str());
+            // TODO - what do we do now?
+        }
+    }
+
     art::odrefresh::ExitCode odrefresh_status = compileArtifacts(kForceCompilation);
     if (odrefresh_status == art::odrefresh::ExitCode::kOkay) {
         LOG(INFO) << "odrefresh said artifacts are VALID";