rkp_factory_extraction_tool: Add support for IRPC v3

Bug: 235265072
Test: rkp_factory_extraction_tool -self_test
Change-Id: Ie776411a32d446b53cb3dfe73a24f60e1eab6506
diff --git a/provisioner/rkp_factory_extraction_lib.cpp b/provisioner/rkp_factory_extraction_lib.cpp
index 77d032b..d85e85f 100644
--- a/provisioner/rkp_factory_extraction_lib.cpp
+++ b/provisioner/rkp_factory_extraction_lib.cpp
@@ -46,11 +46,14 @@
 using aidl::android::hardware::security::keymint::remote_prov::getProdEekChain;
 using aidl::android::hardware::security::keymint::remote_prov::jsonEncodeCsrWithBuild;
 using aidl::android::hardware::security::keymint::remote_prov::parseAndValidateFactoryDeviceInfo;
+using aidl::android::hardware::security::keymint::remote_prov::verifyFactoryCsr;
 using aidl::android::hardware::security::keymint::remote_prov::verifyFactoryProtectedData;
 
 using namespace cppbor;
 using namespace cppcose;
 
+constexpr size_t kVersionWithoutSuperencryption = 3;  // HAL versions >= this use the V3 CSR path
+
 std::string toBase64(const std::vector<uint8_t>& buffer) {
     size_t base64Length;
     int rc = EVP_EncodedLength(&base64Length, buffer.size());
@@ -97,11 +100,11 @@
     return challenge;
 }
 
-CborResult<Array> composeCertificateRequest(const ProtectedData& protectedData,
-                                            const DeviceInfo& verifiedDeviceInfo,
-                                            const std::vector<uint8_t>& challenge,
-                                            const std::vector<uint8_t>& keysToSignMac,
-                                            IRemotelyProvisionedComponent* provisionable) {
+CborResult<Array> composeCertificateRequestV1(const ProtectedData& protectedData,
+                                              const DeviceInfo& verifiedDeviceInfo,
+                                              const std::vector<uint8_t>& challenge,
+                                              const std::vector<uint8_t>& keysToSignMac,
+                                              IRemotelyProvisionedComponent* provisionable) {
     Array macedKeysToSign = Array()
                                 .add(Map().add(1, 5).encode())  // alg: hmac-sha256
                                 .add(Map())                     // empty unprotected headers
@@ -131,7 +134,7 @@
     return {std::move(certificateRequest), ""};
 }
 
-CborResult<Array> getCsr(std::string_view componentName, IRemotelyProvisionedComponent* irpc) {
+CborResult<Array> getCsrV1(std::string_view componentName, IRemotelyProvisionedComponent* irpc) {
     std::vector<uint8_t> keysToSignMac;
     std::vector<MacedPublicKey> emptyKeys;
     DeviceInfo verifiedDeviceInfo;
@@ -154,11 +157,11 @@
                   << "'. Error code: " << status.getServiceSpecificError() << "." << std::endl;
         exit(-1);
     }
-    return composeCertificateRequest(protectedData, verifiedDeviceInfo, challenge, keysToSignMac,
-                                     irpc);
+    return composeCertificateRequestV1(protectedData, verifiedDeviceInfo, challenge, keysToSignMac,
+                                       irpc);
 }
 
-void selfTestGetCsr(std::string_view componentName, IRemotelyProvisionedComponent* irpc) {
+void selfTestGetCsrV1(std::string_view componentName, IRemotelyProvisionedComponent* irpc) {
     std::vector<uint8_t> keysToSignMac;
     std::vector<MacedPublicKey> emptyKeys;
     DeviceInfo verifiedDeviceInfo;
@@ -192,4 +195,86 @@
                                              hwInfo.supportedEekCurve, irpc, challenge);
 
     std::cout << "Self test successful." << std::endl;
-}
\ No newline at end of file
+}
+
+CborResult<Array> composeCertificateRequestV3(const std::vector<uint8_t>& csr) {
+    auto [parsedCsr, _, csrErrMsg] = cppbor::parse(csr);  // decode the raw bytes as CBOR
+    if (!parsedCsr) {  // nothing was parsed; forward the parser's message
+        return {nullptr, csrErrMsg};
+    }
+    if (!parsedCsr->asArray()) {  // the top-level item must be an array
+        return {nullptr, "CSR is not a CBOR array."};
+    }
+
+    return {std::unique_ptr<Array>(parsedCsr.release()->asArray()), ""};  // transfer ownership
+}
+
+CborResult<cppbor::Array> getCsrV3(std::string_view componentName,
+                                   IRemotelyProvisionedComponent* irpc) {
+    std::vector<uint8_t> csr;  // output buffer for the HAL call below
+    std::vector<MacedPublicKey> emptyKeys;  // no keys are submitted for signing
+    const std::vector<uint8_t> challenge = generateChallenge();
+
+    auto status = irpc->generateCertificateRequestV2(emptyKeys, challenge, &csr);
+    if (!status.isOk()) {  // HAL failure is fatal: report it and terminate the process
+        std::cerr << "Bundle extraction failed for '" << componentName
+                  << "'. Error code: " << status.getServiceSpecificError() << "." << std::endl;
+        exit(-1);
+    }
+
+    return composeCertificateRequestV3(csr);  // parse the returned bytes into a CBOR array
+}
+
+void selfTestGetCsrV3(std::string_view componentName, IRemotelyProvisionedComponent* irpc) {
+    std::vector<uint8_t> csr;  // output buffer for the HAL call below
+    std::vector<MacedPublicKey> emptyKeys;  // no keys are submitted for signing
+    const std::vector<uint8_t> challenge = generateChallenge();
+
+    auto status = irpc->generateCertificateRequestV2(emptyKeys, challenge, &csr);
+    if (!status.isOk()) {  // HAL failure is fatal: report it and terminate the process
+        std::cerr << "Bundle extraction failed for '" << componentName
+                  << "'. Error code: " << status.getServiceSpecificError() << "." << std::endl;
+        exit(-1);
+    }
+
+    auto result = verifyFactoryCsr(/*keysToSign=*/cppbor::Array(), csr, irpc, challenge);
+    if (!result) {  // verification failed: print the reported reason and terminate
+        std::cerr << "Self test failed for '" << componentName
+                  << "'. Error message: " << result.message() << "." << std::endl;
+        exit(-1);
+    }
+
+    std::cout << "Self test successful." << std::endl;  // reached only if nothing above exited
+}
+
+CborResult<Array> getCsr(std::string_view componentName, IRemotelyProvisionedComponent* irpc) {
+    RpcHardwareInfo hwInfo;
+    auto status = irpc->getHardwareInfo(&hwInfo);  // needed for the HAL's versionNumber
+    if (!status.isOk()) {  // without the version no CSR flow can be chosen; terminate
+        std::cerr << "Failed to get hardware info for '" << componentName
+                  << "'. Error code: " << status.getServiceSpecificError() << "." << std::endl;
+        exit(-1);
+    }
+
+    if (hwInfo.versionNumber < kVersionWithoutSuperencryption) {  // HAL older than v3
+        return getCsrV1(componentName, irpc);
+    } else {  // v3 or newer
+        return getCsrV3(componentName, irpc);
+    }
+}
+
+void selfTestGetCsr(std::string_view componentName, IRemotelyProvisionedComponent* irpc) {
+    RpcHardwareInfo hwInfo;
+    auto status = irpc->getHardwareInfo(&hwInfo);  // needed for the HAL's versionNumber
+    if (!status.isOk()) {  // without the version no self-test flow can be chosen; terminate
+        std::cerr << "Failed to get hardware info for '" << componentName
+                  << "'. Error code: " << status.getServiceSpecificError() << "." << std::endl;
+        exit(-1);
+    }
+
+    if (hwInfo.versionNumber < kVersionWithoutSuperencryption) {  // HAL older than v3
+        selfTestGetCsrV1(componentName, irpc);
+    } else {  // v3 or newer
+        selfTestGetCsrV3(componentName, irpc);
+    }
+}
diff --git a/provisioner/rkp_factory_extraction_lib.h b/provisioner/rkp_factory_extraction_lib.h
index a803582..a218338 100644
--- a/provisioner/rkp_factory_extraction_lib.h
+++ b/provisioner/rkp_factory_extraction_lib.h
@@ -25,7 +25,8 @@
 #include <string_view>
 #include <vector>
 
-constexpr size_t kChallengeSize = 16;
+// Challenge size must be between 32 and 64 bytes inclusive.
+constexpr size_t kChallengeSize = 64;
 
 // Contains a the result of an operation that should return cborData on success.
 // Returns an an error message and null cborData on error.
@@ -50,4 +51,4 @@
 // Generates a test certificate chain and validates it, exiting the process on error.
 void selfTestGetCsr(
     std::string_view componentName,
-    aidl::android::hardware::security::keymint::IRemotelyProvisionedComponent* irpc);
\ No newline at end of file
+    aidl::android::hardware::security::keymint::IRemotelyProvisionedComponent* irpc);
diff --git a/provisioner/rkp_factory_extraction_lib_test.cpp b/provisioner/rkp_factory_extraction_lib_test.cpp
index b27b717..05509b3 100644
--- a/provisioner/rkp_factory_extraction_lib_test.cpp
+++ b/provisioner/rkp_factory_extraction_lib_test.cpp
@@ -72,6 +72,10 @@
                  const std::vector<uint8_t>& in_challenge, DeviceInfo* out_deviceInfo,
                  ProtectedData* out_protectedData, std::vector<uint8_t>* _aidl_return),
                 (override));
+    MOCK_METHOD(ScopedAStatus, generateCertificateRequestV2,
+                (const std::vector<MacedPublicKey>& in_keysToSign,
+                 const std::vector<uint8_t>& in_challenge, std::vector<uint8_t>* _aidl_return),
+                (override));  // lets tests script the CSR bytes returned via _aidl_return
     MOCK_METHOD(ScopedAStatus, getInterfaceVersion, (int32_t * _aidl_return), (override));
     MOCK_METHOD(ScopedAStatus, getInterfaceHash, (std::string * _aidl_return), (override));
 };
@@ -221,3 +225,35 @@
     EXPECT_THAT(actualMacedKeys->get(2)->asNull(), NotNull());
     EXPECT_THAT(actualMacedKeys->get(3)->asBstr(), Pointee(Eq(Bstr(kFakeMac))));
 }
+
+TEST(LibRkpFactoryExtractionTests, GetCsrWithV3Hal) {
+    const std::vector<uint8_t> kCsr = Array()  // canned four-element CSR the mock hands back
+                                          .add(3 /* version */)
+                                          .add(Map() /* UdsCerts */)
+                                          .add(Array() /* DiceCertChain */)
+                                          .add(Array() /* SignedData */)
+                                          .encode();
+    std::vector<uint8_t> challenge;  // NOTE(review): saved via SaveArg but never asserted on
+
+    // Set up mock, then call getCsr
+    auto mockRpc = SharedRefBase::make<MockIRemotelyProvisionedComponent>();
+    EXPECT_CALL(*mockRpc, getHardwareInfo(NotNull())).WillRepeatedly([](RpcHardwareInfo* hwInfo) {
+        hwInfo->versionNumber = 3;  // report a v3 HAL so getCsr takes the V3 path
+        return ScopedAStatus::ok();
+    });
+    EXPECT_CALL(*mockRpc,
+                generateCertificateRequestV2(IsEmpty(),   // keysToSign
+                                             _,           // challenge
+                                             NotNull()))  // _aidl_return
+        .WillOnce(DoAll(SaveArg<1>(&challenge), SetArgPointee<2>(kCsr),
+                        Return(ByMove(ScopedAStatus::ok()))));  // hand back kCsr with OK status
+
+    auto [csr, csrErrMsg] = getCsr("mock component name", mockRpc.get());  // code under test
+    ASSERT_THAT(csr, NotNull()) << csrErrMsg;
+    ASSERT_THAT(csr, Pointee(Property(&Array::size, Eq(4))));  // same arity as kCsr
+
+    EXPECT_THAT(csr->get(0 /* version */), Pointee(Eq(Uint(3))));
+    EXPECT_THAT(csr->get(1)->asMap(), NotNull());  // UdsCerts slot
+    EXPECT_THAT(csr->get(2)->asArray(), NotNull());  // DiceCertChain slot
+    EXPECT_THAT(csr->get(3)->asArray(), NotNull());  // SignedData slot
+}
diff --git a/provisioner/rkp_factory_extraction_tool.cpp b/provisioner/rkp_factory_extraction_tool.cpp
index 0fe7d74..2aeabe0 100644
--- a/provisioner/rkp_factory_extraction_tool.cpp
+++ b/provisioner/rkp_factory_extraction_tool.cpp
@@ -47,8 +47,6 @@
 constexpr std::string_view kBuildPlusCsr = "build+csr";  // Text-encoded (JSON) build
                                                          // fingerprint plus CSR.
 
-constexpr size_t kChallengeSize = 16;
-
 void writeOutput(const std::string instance_name, const Array& csr) {
     if (FLAGS_output_format == kBinaryCsrOutput) {
         auto bytes = csr.encode();