Merge "Support to get EC public key from the UdsCertchain." into stage-aosp-udc-ts-dev
diff --git a/security/keymint/support/remote_prov_utils.cpp b/security/keymint/support/remote_prov_utils.cpp
index 3cb783c..c9c3e4d 100644
--- a/security/keymint/support/remote_prov_utils.cpp
+++ b/security/keymint/support/remote_prov_utils.cpp
@@ -115,6 +115,36 @@
     return std::make_tuple(std::move(pubX), std::move(pubY));
 }
 
+// Returns the raw public key bytes of `pubKey`: for EC keys whatever ecKeyGetPublicKey()
+// returns (NOTE(review): presumably the uncompressed 0x04||X||Y point - confirm), for
+// Ed25519 the key from EVP_PKEY_get_raw_public_key(). Errors on null or other key types.
+ErrMsgOr<bytevec> getRawPublicKey(const EVP_PKEY_Ptr& pubKey) {
+    if (pubKey.get() == nullptr) return "pkey is null.";
+    switch (EVP_PKEY_base_id(pubKey.get())) {
+        case EVP_PKEY_EC: {
+            auto ecKey = EC_KEY_Ptr(EVP_PKEY_get1_EC_KEY(pubKey.get()));
+            if (ecKey.get() == nullptr) return "Failed to get ec key";
+            return ecKeyGetPublicKey(ecKey.get());
+        }
+        case EVP_PKEY_ED25519: {
+            // First call (null output buffer) only reports the key length.
+            size_t rawKeySize = 0;
+            if (!EVP_PKEY_get_raw_public_key(pubKey.get(), nullptr, &rawKeySize)) {
+                return "Failed to get raw public key.";
+            }
+            bytevec rawPubKey;
+            rawPubKey.resize(rawKeySize);
+            if (!EVP_PKEY_get_raw_public_key(pubKey.get(), rawPubKey.data(), &rawKeySize)) {
+                return "Failed to get raw public key.";
+            }
+            rawPubKey.resize(rawKeySize);  // Keep only the bytes actually written.
+            return rawPubKey;
+        }
+        default:
+            return "Unknown key type.";
+    }
+}
+
 ErrMsgOr<std::tuple<bytevec, bytevec>> generateEc256KeyPair() {
     auto ec_key = EC_KEY_Ptr(EC_KEY_new());
     if (ec_key.get() == nullptr) {
@@ -706,11 +736,10 @@
 
 // Validates the certificate chain and returns the leaf public key.
 ErrMsgOr<bytevec> validateCertChain(const cppbor::Array& chain) {
-    uint8_t rawPubKey[64];
-    size_t rawPubKeySize = sizeof(rawPubKey);
+    bytevec rawPubKey;  // Leaf public key; left empty if `chain` has no entries.
     for (size_t i = 0; i < chain.size(); ++i) {
         // Root must be self-signed.
-        size_t signingCertIndex = (i > 1) ? i - 1 : i;
+        size_t signingCertIndex = (i > 0) ? i - 1 : i;  // cert i is signed by cert i - 1
         auto& keyCertItem = chain[i];
         auto& signingCertItem = chain[signingCertIndex];
         if (!keyCertItem || !keyCertItem->asBstr()) {
@@ -724,7 +753,7 @@
         if (!keyCert) {
             return keyCert.message();
         }
-        auto signingCert = parseX509Cert(keyCertItem->asBstr()->value());
+        auto signingCert = parseX509Cert(signingCertItem->asBstr()->value());
         if (!signingCert) {
             return signingCert.message();
         }
@@ -749,17 +778,16 @@
             return "Certificate " + std::to_string(i) + " has wrong issuer. Signer subject is " +
                    signerSubj + " Issuer subject is " + certIssuer;
         }
-
-        rawPubKeySize = sizeof(rawPubKey);
-        if (!EVP_PKEY_get_raw_public_key(pubKey.get(), rawPubKey, &rawPubKeySize)) {
-            return "Failed to get raw public key.";
+        if (i == chain.size() - 1) {  // Only the leaf (last) certificate's key is returned.
+            auto key = getRawPublicKey(pubKey);
+            if (!key) return key.moveMessage();  // Propagate the extraction error.
+            rawPubKey = key.moveValue();
         }
     }
-
-    return bytevec(rawPubKey, rawPubKey + rawPubKeySize);
+    return rawPubKey;
 }
 
-std::string validateUdsCerts(const cppbor::Map& udsCerts, const bytevec& udsPub) {
+std::string validateUdsCerts(const cppbor::Map& udsCerts, const bytevec& udsCoseKeyBytes) {
     for (const auto& [signerName, udsCertChain] : udsCerts) {
         if (!signerName || !signerName->asTstr()) {
             return "Signer Name must be a Tstr.";
@@ -775,8 +803,31 @@
         if (!leafPubKey) {
             return leafPubKey.message();
         }
+        auto coseKey = CoseKey::parse(udsCoseKeyBytes);
+        if (!coseKey) return coseKey.moveMessage();
+
+        auto curve = coseKey->getIntValue(CoseKey::CURVE);
+        if (!curve) {
+            return "CoseKey must contain curve.";
+        }
+        bytevec udsPub;
+        if (curve == CoseKeyCurve::P256 || curve == CoseKeyCurve::P384) {
+            auto pubKey = coseKey->getEcPublicKey();
+            if (!pubKey) return pubKey.moveMessage();
+            // convert public key to uncompressed form by prepending 0x04 at begin.
+            pubKey->insert(pubKey->begin(), 0x04);
+            udsPub = pubKey.moveValue();
+        } else if (curve == CoseKeyCurve::ED25519) {
+            auto& pubkey = coseKey->getMap().get(cppcose::CoseKey::PUBKEY_X);
+            if (!pubkey || !pubkey->asBstr()) {
+                return "Invalid public key.";
+            }
+            udsPub = pubkey->asBstr()->value();
+        } else {
+            return "Unknown curve.";
+        }
         if (*leafPubKey != udsPub) {
-            return "Leaf public key in UDS certificat chain doesn't match UDS public key.";
+            return "Leaf public key in UDS certificate chain doesn't match UDS public key.";
         }
     }
     return "";