Expose more functionality from hwtrust

Instead of having parsing code in remote_prov_utils and hwtrust, we
should expose more of the values that are parsed in hwtrust and use
that interface to implement the verification of a CSR.

The aim of this CL is to preserve the functionality that previously
existed in remote_prov_utils and, also, add more tests.

Test: atest libkeymint_remote_prov_support_test
Change-Id: Id5408a425f28ea99052ba954c34441ed9307a5d2
diff --git a/security/keymint/support/include/remote_prov/remote_prov_utils.h b/security/keymint/support/include/remote_prov/remote_prov_utils.h
index 6cb00f2..9035d79 100644
--- a/security/keymint/support/include/remote_prov/remote_prov_utils.h
+++ b/security/keymint/support/include/remote_prov/remote_prov_utils.h
@@ -29,6 +29,11 @@
 using bytevec = std::vector<uint8_t>;
 using namespace cppcose;
 
+constexpr std::string_view kErrorChallengeMismatch = "challenges do not match";
+constexpr std::string_view kErrorUdsCertsAreRequired = "UdsCerts are required";
+constexpr std::string_view kErrorKeysToSignMismatch = "KeysToSign do not match";
+constexpr std::string_view kErrorDiceChainIsDegenerate = "DICE chain is degenerate";
+
 extern bytevec kTestMacKey;
 
 // The Google root key for the Endpoint Encryption Key chain, encoded as COSE_Sign1
diff --git a/security/keymint/support/remote_prov_utils.cpp b/security/keymint/support/remote_prov_utils.cpp
index fdc0f28..ccb2768 100644
--- a/security/keymint/support/remote_prov_utils.cpp
+++ b/security/keymint/support/remote_prov_utils.cpp
@@ -123,37 +123,6 @@
     return std::make_tuple(std::move(pubX), std::move(pubY));
 }
 
-ErrMsgOr<bytevec> getRawPublicKey(const EVP_PKEY_Ptr& pubKey) {
-    if (pubKey.get() == nullptr) {
-        return "pkey is null.";
-    }
-    int keyType = EVP_PKEY_base_id(pubKey.get());
-    switch (keyType) {
-        case EVP_PKEY_EC: {
-            int nid = EVP_PKEY_bits(pubKey.get()) == 384 ? NID_secp384r1 : NID_X9_62_prime256v1;
-            auto ecKey = EC_KEY_Ptr(EVP_PKEY_get1_EC_KEY(pubKey.get()));
-            if (ecKey.get() == nullptr) {
-                return "Failed to get ec key";
-          }
-          return ecKeyGetPublicKey(ecKey.get(), nid);
-        }
-        case EVP_PKEY_ED25519: {
-            bytevec rawPubKey;
-            size_t rawKeySize = 0;
-            if (!EVP_PKEY_get_raw_public_key(pubKey.get(), NULL, &rawKeySize)) {
-                return "Failed to get raw public key.";
-            }
-            rawPubKey.resize(rawKeySize);
-            if (!EVP_PKEY_get_raw_public_key(pubKey.get(), rawPubKey.data(), &rawKeySize)) {
-                return "Failed to get raw public key.";
-            }
-            return rawPubKey;
-        }
-        default:
-            return "Unknown key type.";
-    }
-}
-
 ErrMsgOr<std::tuple<bytevec, bytevec>> generateEc256KeyPair() {
     auto ec_key = EC_KEY_Ptr(EC_KEY_new());
     if (ec_key.get() == nullptr) {
@@ -166,7 +135,7 @@
     }
 
     if (EC_KEY_set_group(ec_key.get(), group.get()) != 1 ||
-        EC_KEY_generate_key(ec_key.get()) != 1 || EC_KEY_check_key(ec_key.get()) < 0) {
+        EC_KEY_generate_key(ec_key.get()) != 1 || EC_KEY_check_key(ec_key.get()) != 1) {
         return "Error generating key";
     }
 
@@ -331,17 +300,22 @@
     return chain.encode();
 }
 
+bool maybeOverrideAllowAnyMode(bool allowAnyMode) {
+    // Use ro.build.type instead of ro.debuggable because ro.debuggable=1 for VTS testing
+    std::string build_type = ::android::base::GetProperty("ro.build.type", "");
+    if (!build_type.empty() && build_type != "user") {
+        return true;
+    }
+    return allowAnyMode;
+}
+
 ErrMsgOr<std::vector<BccEntryData>> validateBcc(const cppbor::Array* bcc,
                                                 hwtrust::DiceChain::Kind kind, bool allowAnyMode,
                                                 bool allowDegenerate,
                                                 const std::string& instanceName) {
     auto encodedBcc = bcc->encode();
 
-    // Use ro.build.type instead of ro.debuggable because ro.debuggable=1 for VTS testing
-    std::string build_type = ::android::base::GetProperty("ro.build.type", "");
-    if (!build_type.empty() && build_type != "user") {
-        allowAnyMode = true;
-    }
+    allowAnyMode = maybeOverrideAllowAnyMode(allowAnyMode);
 
     auto chain =
             hwtrust::DiceChain::Verify(encodedBcc, kind, allowAnyMode, deviceSuffix(instanceName));
@@ -779,230 +753,6 @@
                                /*isFactory=*/false, allowAnyMode);
 }
 
-ErrMsgOr<X509_Ptr> parseX509Cert(const std::vector<uint8_t>& cert) {
-    CRYPTO_BUFFER_Ptr certBuf(CRYPTO_BUFFER_new(cert.data(), cert.size(), nullptr));
-    if (!certBuf.get()) {
-        return "Failed to create crypto buffer.";
-    }
-    X509_Ptr result(X509_parse_from_buffer(certBuf.get()));
-    if (!result.get()) {
-        return "Failed to parse certificate.";
-    }
-    return result;
-}
-
-std::string getX509IssuerName(const X509_Ptr& cert) {
-    char* name = X509_NAME_oneline(X509_get_issuer_name(cert.get()), nullptr, 0);
-    std::string result(name);
-    OPENSSL_free(name);
-    return result;
-}
-
-std::string getX509SubjectName(const X509_Ptr& cert) {
-    char* name = X509_NAME_oneline(X509_get_subject_name(cert.get()), nullptr, 0);
-    std::string result(name);
-    OPENSSL_free(name);
-    return result;
-}
-
-// Validates the certificate chain and returns the leaf public key.
-ErrMsgOr<bytevec> validateCertChain(const cppbor::Array& chain) {
-    bytevec rawPubKey;
-    for (size_t i = 0; i < chain.size(); ++i) {
-        // Root must be self-signed.
-        size_t signingCertIndex = (i > 0) ? i - 1 : i;
-        auto& keyCertItem = chain[i];
-        auto& signingCertItem = chain[signingCertIndex];
-        if (!keyCertItem || !keyCertItem->asBstr()) {
-            return "Key certificate must be a Bstr.";
-        }
-        if (!signingCertItem || !signingCertItem->asBstr()) {
-            return "Signing certificate must be a Bstr.";
-        }
-
-        auto keyCert = parseX509Cert(keyCertItem->asBstr()->value());
-        if (!keyCert) {
-            return keyCert.message();
-        }
-        auto signingCert = parseX509Cert(signingCertItem->asBstr()->value());
-        if (!signingCert) {
-            return signingCert.message();
-        }
-
-        EVP_PKEY_Ptr pubKey(X509_get_pubkey(keyCert->get()));
-        if (!pubKey.get()) {
-            return "Failed to get public key.";
-        }
-        EVP_PKEY_Ptr signingPubKey(X509_get_pubkey(signingCert->get()));
-        if (!signingPubKey.get()) {
-            return "Failed to get signing public key.";
-        }
-
-        if (!X509_verify(keyCert->get(), signingPubKey.get())) {
-            return "Verification of certificate " + std::to_string(i) +
-                   " faile. OpenSSL error string: " + ERR_error_string(ERR_get_error(), NULL);
-        }
-
-        auto certIssuer = getX509IssuerName(*keyCert);
-        auto signerSubj = getX509SubjectName(*signingCert);
-        if (certIssuer != signerSubj) {
-            return "Certificate " + std::to_string(i) + " has wrong issuer. Signer subject is " +
-                   signerSubj + " Issuer subject is " + certIssuer;
-        }
-        if (i == chain.size() - 1) {
-            auto key = getRawPublicKey(pubKey);
-            if (!key) return key.moveMessage();
-            rawPubKey = key.moveValue();
-        }
-    }
-    return rawPubKey;
-}
-
-std::optional<std::string> validateUdsCerts(const cppbor::Map& udsCerts,
-                                            const bytevec& udsCoseKeyBytes) {
-    for (const auto& [signerName, udsCertChain] : udsCerts) {
-        if (!signerName || !signerName->asTstr()) {
-            return "Signer Name must be a Tstr.";
-        }
-        if (!udsCertChain || !udsCertChain->asArray()) {
-            return "UDS certificate chain must be an Array.";
-        }
-        if (udsCertChain->asArray()->size() < 2) {
-            return "UDS certificate chain must have at least two entries: root and leaf.";
-        }
-
-        auto leafPubKey = validateCertChain(*udsCertChain->asArray());
-        if (!leafPubKey) {
-            return leafPubKey.message();
-        }
-        auto coseKey = CoseKey::parse(udsCoseKeyBytes);
-        if (!coseKey) {
-            return coseKey.moveMessage();
-        }
-        auto curve = coseKey->getIntValue(CoseKey::CURVE);
-        if (!curve) {
-            return "CoseKey must contain curve.";
-        }
-        bytevec udsPub;
-        if (curve == CoseKeyCurve::P256 || curve == CoseKeyCurve::P384) {
-            auto pubKey = coseKey->getEcPublicKey();
-            if (!pubKey) {
-                return pubKey.moveMessage();
-            }
-            // convert public key to uncompressed form by prepending 0x04 at begin.
-            pubKey->insert(pubKey->begin(), 0x04);
-            udsPub = pubKey.moveValue();
-        } else if (curve == CoseKeyCurve::ED25519) {
-            auto& pubkey = coseKey->getMap().get(cppcose::CoseKey::PUBKEY_X);
-            if (!pubkey || !pubkey->asBstr()) {
-                return "Invalid public key.";
-            }
-            udsPub = pubkey->asBstr()->value();
-        } else {
-            return "Unknown curve.";
-        }
-        if (*leafPubKey != udsPub) {
-            return "Leaf public key in UDS certificate chain doesn't match UDS public key.";
-        }
-    }
-    return std::nullopt;
-}
-
-ErrMsgOr<std::unique_ptr<cppbor::Array>> parseAndValidateCsrPayload(
-        const cppbor::Array& keysToSign, const std::vector<uint8_t>& csrPayload,
-        const RpcHardwareInfo& rpcHardwareInfo, bool isFactory) {
-    auto [parsedCsrPayload, _, errMsg] = cppbor::parse(csrPayload);
-    if (!parsedCsrPayload) {
-        return errMsg;
-    }
-
-    std::unique_ptr<cppbor::Array> parsed(parsedCsrPayload.release()->asArray());
-    if (!parsed) {
-        return "CSR payload is not a CBOR array.";
-    }
-
-    if (parsed->size() != 4U) {
-        return "CSR payload must contain version, certificate type, device info, keys. "
-               "However, the parsed CSR payload has " +
-               std::to_string(parsed->size()) + " entries.";
-    }
-
-    auto signedVersion = parsed->get(0)->asUint();
-    auto signedCertificateType = parsed->get(1)->asTstr();
-    auto signedDeviceInfo = parsed->get(2)->asMap();
-    auto signedKeys = parsed->get(3)->asArray();
-
-    if (!signedVersion || signedVersion->value() != 3U) {
-        return "CSR payload version must be an unsigned integer and must be equal to 3.";
-    }
-    if (!signedCertificateType) {
-        // Certificate type is allowed to be extendend by vendor, i.e. we can't
-        // enforce its value.
-        return "Certificate type must be a Tstr.";
-    }
-    if (!signedDeviceInfo) {
-        return "Device info must be an Map.";
-    }
-    if (!signedKeys) {
-        return "Keys must be an Array.";
-    }
-
-    auto result =
-            parseAndValidateDeviceInfo(signedDeviceInfo->encode(), rpcHardwareInfo, isFactory);
-    if (!result) {
-        return result.message();
-    }
-
-    if (signedKeys->encode() != keysToSign.encode()) {
-        return "Signed keys do not match.";
-    }
-
-    return std::move(parsed);
-}
-
-ErrMsgOr<bytevec> parseAndValidateAuthenticatedRequestSignedPayload(
-        const std::vector<uint8_t>& signedPayload, const std::vector<uint8_t>& challenge) {
-    auto [parsedSignedPayload, _, errMsg] = cppbor::parse(signedPayload);
-    if (!parsedSignedPayload) {
-        return errMsg;
-    }
-    if (!parsedSignedPayload->asArray()) {
-        return "SignedData payload is not a CBOR array.";
-    }
-    if (parsedSignedPayload->asArray()->size() != 2U) {
-        return "SignedData payload must contain the challenge and request. However, the parsed "
-               "SignedData payload has " +
-               std::to_string(parsedSignedPayload->asArray()->size()) + " entries.";
-    }
-
-    auto signedChallenge = parsedSignedPayload->asArray()->get(0)->asBstr();
-    auto signedRequest = parsedSignedPayload->asArray()->get(1)->asBstr();
-
-    if (!signedChallenge) {
-        return "Challenge must be a Bstr.";
-    }
-
-    if (challenge.size() > 64) {
-        return "Challenge size must be between 0 and 64 bytes inclusive. "
-               "However, challenge is " +
-               std::to_string(challenge.size()) + " bytes long.";
-    }
-
-    auto challengeBstr = cppbor::Bstr(challenge);
-    if (*signedChallenge != challengeBstr) {
-        return "Signed challenge does not match."
-               "\n  Actual: " +
-               cppbor::prettyPrint(signedChallenge->asBstr(), 64 /* maxBStrSize */) +
-               "\nExpected: " + cppbor::prettyPrint(&challengeBstr, 64 /* maxBStrSize */);
-    }
-
-    if (!signedRequest) {
-        return "Request must be a Bstr.";
-    }
-
-    return signedRequest->value();
-}
-
 ErrMsgOr<hwtrust::DiceChain::Kind> getDiceChainKind() {
     int vendor_api_level = ::android::base::GetIntProperty("ro.vendor.api_level", -1);
     if (vendor_api_level <= __ANDROID_API_T__) {
@@ -1018,87 +768,8 @@
     }
 }
 
-ErrMsgOr<bytevec> parseAndValidateAuthenticatedRequest(const std::vector<uint8_t>& request,
-                                                       const std::vector<uint8_t>& challenge,
-                                                       const std::string& instanceName,
-                                                       bool allowAnyMode = false,
-                                                       bool allowDegenerate = true,
-                                                       bool requireUdsCerts = false) {
-    auto [parsedRequest, _, csrErrMsg] = cppbor::parse(request);
-    if (!parsedRequest) {
-        return csrErrMsg;
-    }
-    if (!parsedRequest->asArray()) {
-        return "AuthenticatedRequest is not a CBOR array.";
-    }
-    if (parsedRequest->asArray()->size() != 4U) {
-        return "AuthenticatedRequest must contain version, UDS certificates, DICE chain, and "
-               "signed data. However, the parsed AuthenticatedRequest has " +
-               std::to_string(parsedRequest->asArray()->size()) + " entries.";
-    }
-
-    auto version = parsedRequest->asArray()->get(0)->asUint();
-    auto udsCerts = parsedRequest->asArray()->get(1)->asMap();
-    auto diceCertChain = parsedRequest->asArray()->get(2)->asArray();
-    auto signedData = parsedRequest->asArray()->get(3)->asArray();
-
-    if (!version || version->value() != 1U) {
-        return "AuthenticatedRequest version must be an unsigned integer and must be equal to 1.";
-    }
-
-    if (!udsCerts) {
-        return "AuthenticatedRequest UdsCerts must be a Map.";
-    }
-    if (requireUdsCerts && udsCerts->size() == 0) {
-        return "AuthenticatedRequest UdsCerts must not be empty.";
-    }
-    if (!diceCertChain) {
-        return "AuthenticatedRequest DiceCertChain must be an Array.";
-    }
-    if (!signedData) {
-        return "AuthenticatedRequest SignedData must be an Array.";
-    }
-
-    // DICE chain is [ pubkey, + DiceChainEntry ].
-    auto diceChainKind = getDiceChainKind();
-    if (!diceChainKind) {
-        return diceChainKind.message();
-    }
-
-    auto diceContents =
-            validateBcc(diceCertChain, *diceChainKind, allowAnyMode, allowDegenerate, instanceName);
-    if (!diceContents) {
-        return diceContents.message() + "\n" + prettyPrint(diceCertChain);
-    }
-
-    if (!diceCertChain->get(0)->asMap()) {
-        return "AuthenticatedRequest The first entry in DiceCertChain must be a Map.";
-    }
-    auto udsPub = diceCertChain->get(0)->asMap()->encode();
-    auto error = validateUdsCerts(*udsCerts, udsPub);
-    if (error) {
-        return *error;
-    }
-
-    if (diceContents->empty()) {
-        return "AuthenticatedRequest DiceContents must not be empty.";
-    }
-    auto& kmDiceKey = diceContents->back().pubKey;
-    auto signedPayload = verifyAndParseCoseSign1(signedData, kmDiceKey, /*aad=*/{});
-    if (!signedPayload) {
-        return signedPayload.message();
-    }
-
-    auto payload = parseAndValidateAuthenticatedRequestSignedPayload(*signedPayload, challenge);
-    if (!payload) {
-        return payload.message();
-    }
-
-    return payload;
-}
-
 ErrMsgOr<std::unique_ptr<cppbor::Array>> verifyCsr(
-        const cppbor::Array& keysToSign, const std::vector<uint8_t>& csr,
+        const cppbor::Array& keysToSign, const std::vector<uint8_t>& encodedCsr,
         const RpcHardwareInfo& rpcHardwareInfo, const std::string& instanceName,
         const std::vector<uint8_t>& challenge, bool isFactory, bool allowAnyMode = false,
         bool allowDegenerate = true, bool requireUdsCerts = false) {
@@ -1108,14 +779,68 @@
                ") does not match expected version (3).";
     }
 
-    auto csrPayload = parseAndValidateAuthenticatedRequest(
-            csr, challenge, instanceName, allowAnyMode, allowDegenerate, requireUdsCerts);
-
-    if (!csrPayload) {
-        return csrPayload.message();
+    auto diceChainKind = getDiceChainKind();
+    if (!diceChainKind) {
+        return diceChainKind.message();
     }
 
-    return parseAndValidateCsrPayload(keysToSign, *csrPayload, rpcHardwareInfo, isFactory);
+    allowAnyMode = maybeOverrideAllowAnyMode(allowAnyMode);
+
+    auto csr = hwtrust::Csr::validate(encodedCsr, *diceChainKind, isFactory, allowAnyMode,
+                                      deviceSuffix(instanceName));
+
+    if (!csr.ok()) {
+        return csr.error().message();
+    }
+
+    if (!allowDegenerate) {
+        auto diceChain = csr->getDiceChain();
+        if (!diceChain.ok()) {
+            return diceChain.error().message();
+        }
+
+        if (!diceChain->IsProper()) {
+            return kErrorDiceChainIsDegenerate;
+        }
+    }
+
+    if (requireUdsCerts && !csr->hasUdsCerts()) {
+        return kErrorUdsCertsAreRequired;
+    }
+
+    auto equalChallenges = csr->compareChallenge(challenge);
+    if (!equalChallenges.ok()) {
+        return equalChallenges.error().message();
+    }
+
+    if (!*equalChallenges) {
+        return kErrorChallengeMismatch;
+    }
+
+    auto equalKeysToSign = csr->compareKeysToSign(keysToSign.encode());
+    if (!equalKeysToSign.ok()) {
+        return equalKeysToSign.error().message();
+    }
+
+    if (!*equalKeysToSign) {
+        return kErrorKeysToSignMismatch;
+    }
+
+    auto csrPayload = csr->getCsrPayload();
+    if (!csrPayload) {
+        return csrPayload.error().message();
+    }
+
+    auto [csrPayloadDecoded, _, errMsg] = cppbor::parse(*csrPayload);
+    if (!csrPayloadDecoded) {
+        return errMsg;
+    }
+
+    if (!csrPayloadDecoded->asArray()) {
+        return "CSR payload is not an array.";
+    }
+
+    return std::unique_ptr<cppbor::Array>(csrPayloadDecoded.release()->asArray());
 }
 
 ErrMsgOr<std::unique_ptr<cppbor::Array>> verifyFactoryCsr(
@@ -1143,8 +868,8 @@
         return diceChainKind.message();
     }
 
-    auto csr = hwtrust::Csr::validate(encodedCsr, *diceChainKind, false /*allowAnyMode*/,
-                                      deviceSuffix(instanceName));
+    auto csr = hwtrust::Csr::validate(encodedCsr, *diceChainKind, false /*isFactory*/,
+                                      false /*allowAnyMode*/, deviceSuffix(instanceName));
     if (!csr.ok()) {
         return csr.error().message();
     }
diff --git a/security/keymint/support/remote_prov_utils_test.cpp b/security/keymint/support/remote_prov_utils_test.cpp
index a01762e..4ea8bd8 100644
--- a/security/keymint/support/remote_prov_utils_test.cpp
+++ b/security/keymint/support/remote_prov_utils_test.cpp
@@ -81,6 +81,23 @@
         0x50, 0x12, 0x82, 0x37, 0xfe, 0xa4, 0x07, 0xc3, 0xd5, 0xc3, 0x78, 0xcc, 0xf9, 0xef, 0xe1,
         0x95, 0x38, 0x9f, 0xb0, 0x79, 0x16, 0x4c, 0x4a, 0x23, 0xc4, 0xdc, 0x35, 0x4e, 0x0f};
 
+inline const std::vector<uint8_t> kKeysToSignForCsrWithDegenerateDiceChain{
+        0x82, 0xa6, 0x01, 0x02, 0x03, 0x26, 0x20, 0x01, 0x21, 0x58, 0x20, 0x1d, 0x94, 0xf2, 0x27,
+        0xc3, 0x70, 0x01, 0xde, 0x3c, 0xaf, 0x6f, 0xfd, 0x78, 0x08, 0x37, 0x39, 0x21, 0xdd, 0x46,
+        0x6f, 0x08, 0x4f, 0x77, 0xf7, 0x80, 0x34, 0x30, 0x74, 0x78, 0x69, 0xeb, 0xb1, 0x22, 0x58,
+        0x20, 0x6b, 0x71, 0xd7, 0x7f, 0x0e, 0x51, 0xb2, 0xc9, 0x3d, 0x1a, 0xa0, 0xe8, 0x7a, 0x0d,
+        0x57, 0xfc, 0x91, 0xd0, 0x68, 0xf9, 0x33, 0x5f, 0x80, 0x29, 0x00, 0x80, 0x98, 0x78, 0x63,
+        0x5b, 0x30, 0x24, 0x23, 0x58, 0x20, 0x09, 0x83, 0xa6, 0x5a, 0xbb, 0x3a, 0xf8, 0x90, 0x88,
+        0x87, 0x16, 0x37, 0xb4, 0xe7, 0x11, 0x9b, 0xcc, 0xbb, 0x15, 0x82, 0xa9, 0x97, 0xa5, 0xad,
+        0xa9, 0x85, 0x39, 0x30, 0x55, 0x46, 0x99, 0xc6, 0xa6, 0x01, 0x02, 0x03, 0x26, 0x20, 0x01,
+        0x21, 0x58, 0x20, 0xa8, 0xaa, 0x4b, 0x63, 0x86, 0xf6, 0x5c, 0xe4, 0x28, 0xda, 0x26, 0x3f,
+        0x9a, 0x42, 0x6e, 0xb9, 0x2b, 0x4d, 0x5a, 0x49, 0x4c, 0x5f, 0x1a, 0xa2, 0x5f, 0xd4, 0x8f,
+        0x84, 0xd7, 0x25, 0xe4, 0x6c, 0x22, 0x58, 0x20, 0x6b, 0xef, 0xde, 0xd6, 0x04, 0x58, 0x12,
+        0xdb, 0xf8, 0x90, 0x2c, 0x9c, 0xe0, 0x5e, 0x43, 0xbc, 0xcf, 0x22, 0x01, 0x4d, 0x5c, 0x0c,
+        0x86, 0x7b, 0x66, 0xd2, 0xa1, 0xfc, 0x69, 0x8a, 0x91, 0xfc, 0x23, 0x58, 0x20, 0x31, 0xaf,
+        0x30, 0x85, 0x1f, 0x2a, 0x82, 0xe1, 0x9c, 0xda, 0xe5, 0x68, 0xed, 0x79, 0xc1, 0x35, 0x1a,
+        0x02, 0xb4, 0x8a, 0xd2, 0x4c, 0xc4, 0x70, 0x6b, 0x88, 0x98, 0x23, 0x9e, 0xb3, 0x52, 0xb1};
+
 inline const std::vector<uint8_t> kCsrWithDegenerateDiceChain{
         0x85, 0x01, 0xa0, 0x82, 0xa5, 0x01, 0x02, 0x03, 0x26, 0x20, 0x01, 0x21, 0x58, 0x20, 0xf2,
         0xc6, 0x50, 0xd2, 0x42, 0x59, 0xe0, 0x4e, 0x7b, 0xc0, 0x75, 0x41, 0xa2, 0xe9, 0xd0, 0xe8,
@@ -172,7 +189,7 @@
         0xc9, 0x0a};
 
 inline const std::vector<uint8_t> kCsrWithUdsCerts{
-        0x84, 0x01, 0xa1, 0x70, 0x74, 0x65, 0x73, 0x74, 0x2d, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x72,
+        0x85, 0x01, 0xa1, 0x70, 0x74, 0x65, 0x73, 0x74, 0x2d, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x72,
         0x2d, 0x6e, 0x61, 0x6d, 0x65, 0x82, 0x59, 0x01, 0x6c, 0x30, 0x82, 0x01, 0x68, 0x30, 0x82,
         0x01, 0x1a, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x01, 0x7b, 0x30, 0x05, 0x06, 0x03, 0x2b,
         0x65, 0x70, 0x30, 0x2b, 0x31, 0x15, 0x30, 0x13, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x0c,
@@ -634,14 +651,87 @@
     ASSERT_FALSE(*result) << "DICE Chain is proper";
 }
 
+TEST(RemoteProvUtils, csrHasUdsCerts) {
+    auto csr = hwtrust::Csr::validate(kCsrWithUdsCerts, hwtrust::DiceChain::Kind::kVsr16,
+                                      false /*isFactory*/, false /*allowAnyMode*/,
+                                      deviceSuffix(DEFAULT_INSTANCE_NAME));
+    ASSERT_TRUE(csr.ok()) << csr.error().message();
+    ASSERT_TRUE(csr->hasUdsCerts());
+}
+
+TEST(RemoteProvUtils, csrDoesntHaveUdsCerts) {
+    auto csr = hwtrust::Csr::validate(kCsrWithoutUdsCerts, hwtrust::DiceChain::Kind::kVsr16,
+                                      false /*isFactory*/, false /*allowAnyMode*/,
+                                      deviceSuffix(DEFAULT_INSTANCE_NAME));
+    ASSERT_TRUE(csr.ok()) << csr.error().message();
+    ASSERT_FALSE(csr->hasUdsCerts());
+}
+
+TEST(RemoteProvUtils, csrHasCorrectChallenge) {
+    auto csr = hwtrust::Csr::validate(kCsrWithoutUdsCerts, hwtrust::DiceChain::Kind::kVsr16,
+                                      false /*isFactory*/, false /*allowAnyMode*/,
+                                      deviceSuffix(DEFAULT_INSTANCE_NAME));
+    ASSERT_TRUE(csr.ok()) << csr.error().message();
+
+    auto equal = csr->compareChallenge(kChallenge);
+    ASSERT_TRUE(equal.ok()) << equal.error().message();
+
+    ASSERT_TRUE(*equal) << kErrorChallengeMismatch;
+
+    auto zeroes = std::vector<uint8_t>(32, 0);
+    auto notEqual = csr->compareChallenge(zeroes);
+    ASSERT_TRUE(notEqual.ok()) << notEqual.error().message();
+
+    ASSERT_FALSE(*notEqual) << "ERROR: challenges are not different";
+}
+
+TEST(RemoteProvUtils, csrHasCorrectKeysToSign) {
+    auto csr = hwtrust::Csr::validate(kCsrWithoutUdsCerts, hwtrust::DiceChain::Kind::kVsr16,
+                                      false /*isFactory*/, false /*allowAnyMode*/,
+                                      deviceSuffix(DEFAULT_INSTANCE_NAME));
+    ASSERT_TRUE(csr.ok()) << csr.error().message();
+
+    auto equal = csr->compareKeysToSign(kKeysToSignForCsrWithoutUdsCerts);
+    ASSERT_TRUE(equal.ok()) << equal.error().message();
+    ASSERT_TRUE(*equal) << kErrorKeysToSignMismatch;
+
+    auto zeroes = std::vector<uint8_t>(kKeysToSignForCsrWithoutUdsCerts.size(), 0);
+    auto notEqual = csr->compareKeysToSign(zeroes);
+    ASSERT_TRUE(notEqual.ok()) << notEqual.error().message();
+    ASSERT_FALSE(*notEqual) << kErrorKeysToSignMismatch;
+}
+
+TEST(RemoteProvUtilsTest, allowDegenerateDiceChainWhenDegenerate) {
+    auto [keysToSignPtr, _, errMsg] = cppbor::parse(kKeysToSignForCsrWithDegenerateDiceChain);
+    ASSERT_TRUE(keysToSignPtr) << "Error: " << errMsg;
+
+    const auto keysToSign = keysToSignPtr->asArray();
+    auto csr = verifyFactoryCsr(*keysToSign, kCsrWithDegenerateDiceChain, kRpcHardwareInfo,
+                                DEFAULT_INSTANCE_NAME, kChallenge,
+                                /*allowDegenerate=*/true, /*requireUdsCerts=*/false);
+    ASSERT_TRUE(csr) << csr.message();
+}
+
+TEST(RemoteProvUtilsTest, disallowDegenerateDiceChainWhenDegenerate) {
+    auto [keysToSignPtr, _, errMsg] = cppbor::parse(kKeysToSignForCsrWithDegenerateDiceChain);
+    ASSERT_TRUE(keysToSignPtr) << "Error: " << errMsg;
+
+    const auto keysToSign = keysToSignPtr->asArray();
+    auto csr = verifyFactoryCsr(*keysToSign, kCsrWithDegenerateDiceChain, kRpcHardwareInfo,
+                                DEFAULT_INSTANCE_NAME, kChallenge,
+                                /*allowDegenerate=*/false, /*requireUdsCerts=*/false);
+    ASSERT_FALSE(csr);
+    ASSERT_THAT(csr.message(), testing::HasSubstr(kErrorDiceChainIsDegenerate));
+}
+
 TEST(RemoteProvUtilsTest, requireUdsCertsWhenPresent) {
     auto [keysToSignPtr, _, errMsg] = cppbor::parse(kKeysToSignForCsrWithUdsCerts);
     ASSERT_TRUE(keysToSignPtr) << "Error: " << errMsg;
 
     const auto keysToSign = keysToSignPtr->asArray();
-    auto csr =
-            verifyFactoryCsr(*keysToSign, kCsrWithUdsCerts, kRpcHardwareInfo, "default", kChallenge,
-                             /*allowDegenerate=*/false, /*requireUdsCerts=*/true);
+    auto csr = verifyFactoryCsr(*keysToSign, kCsrWithUdsCerts, kRpcHardwareInfo,
+                                DEFAULT_INSTANCE_NAME, kChallenge,
+                                /*allowDegenerate=*/false, /*requireUdsCerts=*/true);
     ASSERT_TRUE(csr) << csr.message();
 }
 
@@ -661,13 +751,11 @@
                                 DEFAULT_INSTANCE_NAME, kChallenge, /*allowDegenerate=*/false,
                                 /*requireUdsCerts=*/true);
     ASSERT_FALSE(csr);
-    ASSERT_THAT(csr.message(), testing::HasSubstr("UdsCerts must not be empty"));
+    ASSERT_THAT(csr.message(), testing::HasSubstr(kErrorUdsCertsAreRequired));
 }
 
 TEST(RemoteProvUtilsTest, dontRequireUdsCertsWhenNotPresent) {
-    auto [keysToSignPtr, _, errMsg] = cppbor::parse(
-            kKeysToSignForCsrWithoutUdsCerts.data(),
-            kKeysToSignForCsrWithoutUdsCerts.data() + kKeysToSignForCsrWithoutUdsCerts.size());
+    auto [keysToSignPtr, _, errMsg] = cppbor::parse(kKeysToSignForCsrWithoutUdsCerts);
     ASSERT_TRUE(keysToSignPtr) << "Error: " << errMsg;
 
     const auto* keysToSign = keysToSignPtr->asArray();
diff --git a/security/rkp/aidl/vts/functional/VtsRemotelyProvisionedComponentTests.cpp b/security/rkp/aidl/vts/functional/VtsRemotelyProvisionedComponentTests.cpp
index b9c742a..f0745d7 100644
--- a/security/rkp/aidl/vts/functional/VtsRemotelyProvisionedComponentTests.cpp
+++ b/security/rkp/aidl/vts/functional/VtsRemotelyProvisionedComponentTests.cpp
@@ -999,6 +999,7 @@
 
     std::unique_ptr<cppbor::Array> csrPayload = std::move(*result);
     ASSERT_TRUE(csrPayload);
+    ASSERT_TRUE(csrPayload->size() > 2);
 
     auto deviceInfo = csrPayload->get(2)->asMap();
     ASSERT_TRUE(deviceInfo);