[rkp_factory_extraction_tool] adding requireUdsCerts flag
Pass in a command-line argument that is a comma-delimited
list of RPC instance names for which the presence
of the UDS cert chain will be enforced in the Csr that
is defined in
hardware/interfaces/security/rkp/aidl/android/hardware/security/keymint/generateCertificateRequestV2.cddl
Bug: 366147625
Test: m rkp_factory_extraction_tool
Test: m librkp_factory_extraction_test
Test: atest system/security/provisioner
Change-Id: Idb3e81dd9f92fa446ebf23c3a08083fda5ed9eac
diff --git a/provisioner/rkp_factory_extraction_lib.cpp b/provisioner/rkp_factory_extraction_lib.cpp
index 37b81fd..d4aed45 100644
--- a/provisioner/rkp_factory_extraction_lib.cpp
+++ b/provisioner/rkp_factory_extraction_lib.cpp
@@ -25,7 +25,6 @@
#include <cstring>
#include <iterator>
#include <keymaster/cppcose/cppcose.h>
-#include <openssl/base64.h>
#include <remote_prov/remote_prov_utils.h>
#include <sys/random.h>
@@ -33,6 +32,7 @@
#include <optional>
#include <string>
#include <string_view>
+#include <unordered_set>
#include <vector>
#include "cppbor_parse.h"
@@ -42,6 +42,7 @@
using aidl::android::hardware::security::keymint::MacedPublicKey;
using aidl::android::hardware::security::keymint::ProtectedData;
using aidl::android::hardware::security::keymint::RpcHardwareInfo;
+using aidl::android::hardware::security::keymint::remote_prov::BccEntryData;
using aidl::android::hardware::security::keymint::remote_prov::EekChain;
using aidl::android::hardware::security::keymint::remote_prov::generateEekChain;
using aidl::android::hardware::security::keymint::remote_prov::getProdEekChain;
@@ -50,35 +51,13 @@
using aidl::android::hardware::security::keymint::remote_prov::verifyFactoryCsr;
using aidl::android::hardware::security::keymint::remote_prov::verifyFactoryProtectedData;
-using namespace cppbor;
-using namespace cppcose;
+using cppbor::Array;
+using cppbor::Map;
+using cppbor::Null;
+template <class T> using ErrMsgOr = cppcose::ErrMsgOr<T>;
constexpr size_t kVersionWithoutSuperencryption = 3;
-std::string toBase64(const std::vector<uint8_t>& buffer) {
- size_t base64Length;
- int rc = EVP_EncodedLength(&base64Length, buffer.size());
- if (!rc) {
- std::cerr << "Error getting base64 length. Size overflow?" << std::endl;
- exit(-1);
- }
-
- std::string base64(base64Length, ' ');
- rc = EVP_EncodeBlock(reinterpret_cast<uint8_t*>(base64.data()), buffer.data(), buffer.size());
- ++rc; // Account for NUL, which BoringSSL does not for some reason.
- if (rc != base64Length) {
- std::cerr << "Error writing base64. Expected " << base64Length
- << " bytes to be written, but " << rc << " bytes were actually written."
- << std::endl;
- exit(-1);
- }
-
- // BoringSSL automatically adds a NUL -- remove it from the string data
- base64.pop_back();
-
- return base64;
-}
-
std::vector<uint8_t> generateChallenge() {
std::vector<uint8_t> challenge(kChallengeSize);
@@ -90,7 +69,8 @@
if (errno == EINTR) {
continue;
} else {
- std::cerr << errno << ": " << strerror(errno) << std::endl;
+ std::cerr << "generateChallenge: getrandom returned an error with errno " << errno
+ << ": " << strerror(errno) << ". Exiting..." << std::endl;
exit(-1);
}
}
@@ -118,7 +98,7 @@
return {nullptr, parsedVerifiedDeviceInfo.moveMessage()};
}
- auto [parsedProtectedData, ignore2, errMsg] = parse(protectedData.protectedData);
+ auto [parsedProtectedData, ignore2, errMsg] = cppbor::parse(protectedData.protectedData);
if (!parsedProtectedData) {
std::cerr << "Error parsing protected data: '" << errMsg << "'" << std::endl;
return {nullptr, errMsg};
@@ -145,7 +125,7 @@
if (!status.isOk()) {
std::cerr << "Failed to get hardware info for '" << componentName
<< "'. Description: " << status.getDescription() << "." << std::endl;
- exit(-1);
+ return {nullptr, status.getDescription()};
}
const std::vector<uint8_t> eek = getProdEekChain(hwInfo.supportedEekCurve);
@@ -156,13 +136,14 @@
if (!status.isOk()) {
std::cerr << "Bundle extraction failed for '" << componentName
<< "'. Description: " << status.getDescription() << "." << std::endl;
- exit(-1);
+ return {nullptr, status.getDescription()};
}
return composeCertificateRequestV1(protectedData, verifiedDeviceInfo, challenge, keysToSignMac,
irpc);
}
-void selfTestGetCsrV1(std::string_view componentName, IRemotelyProvisionedComponent* irpc) {
+std::optional<std::string> selfTestGetCsrV1(std::string_view componentName,
+ IRemotelyProvisionedComponent* irpc) {
std::vector<uint8_t> keysToSignMac;
std::vector<MacedPublicKey> emptyKeys;
DeviceInfo verifiedDeviceInfo;
@@ -172,14 +153,14 @@
if (!status.isOk()) {
std::cerr << "Failed to get hardware info for '" << componentName
<< "'. Description: " << status.getDescription() << "." << std::endl;
- exit(-1);
+ return status.getDescription();
}
const std::vector<uint8_t> eekId = {0, 1, 2, 3, 4, 5, 6, 7};
ErrMsgOr<EekChain> eekChain = generateEekChain(hwInfo.supportedEekCurve, /*length=*/3, eekId);
if (!eekChain) {
std::cerr << "Error generating test EEK certificate chain: " << eekChain.message();
- exit(-1);
+ return eekChain.message();
}
const std::vector<uint8_t> challenge = generateChallenge();
status = irpc->generateCertificateRequest(
@@ -188,7 +169,7 @@
if (!status.isOk()) {
std::cerr << "Error generating test cert chain for '" << componentName
<< "'. Description: " << status.getDescription() << "." << std::endl;
- exit(-1);
+ return status.getDescription();
}
auto result = verifyFactoryProtectedData(
@@ -198,8 +179,9 @@
if (!result) {
std::cerr << "Self test failed for IRemotelyProvisionedComponent '" << componentName
<< "'. Error message: '" << result.message() << "'." << std::endl;
- exit(-1);
+ return result.message();
}
+ return std::nullopt;
}
CborResult<Array> composeCertificateRequestV3(const std::vector<uint8_t>& csr) {
@@ -223,9 +205,8 @@
return {std::unique_ptr<Array>(parsedCsr.release()->asArray()), ""};
}
-CborResult<cppbor::Array> getCsrV3(std::string_view componentName,
- IRemotelyProvisionedComponent* irpc, bool selfTest,
- bool allowDegenerate) {
+CborResult<Array> getCsrV3(std::string_view componentName, IRemotelyProvisionedComponent* irpc,
+ bool selfTest, bool allowDegenerate, bool requireUdsCerts) {
std::vector<uint8_t> csr;
std::vector<MacedPublicKey> emptyKeys;
const std::vector<uint8_t> challenge = generateChallenge();
@@ -234,16 +215,17 @@
if (!status.isOk()) {
std::cerr << "Bundle extraction failed for '" << componentName
<< "'. Description: " << status.getDescription() << "." << std::endl;
- exit(-1);
+ return {nullptr, status.getDescription()};
}
if (selfTest) {
- auto result = verifyFactoryCsr(/*keysToSign=*/cppbor::Array(), csr, irpc,
- std::string(componentName), challenge, allowDegenerate);
+ auto result =
+ verifyFactoryCsr(/*keysToSign=*/cppbor::Array(), csr, irpc, std::string(componentName),
+ challenge, allowDegenerate, requireUdsCerts);
if (!result) {
std::cerr << "Self test failed for IRemotelyProvisionedComponent '" << componentName
<< "'. Error message: '" << result.message() << "'." << std::endl;
- exit(-1);
+ return {nullptr, result.message()};
}
}
@@ -251,35 +233,35 @@
}
CborResult<Array> getCsr(std::string_view componentName, IRemotelyProvisionedComponent* irpc,
- bool selfTest, bool allowDegenerate) {
+ bool selfTest, bool allowDegenerate, bool requireUdsCerts) {
RpcHardwareInfo hwInfo;
auto status = irpc->getHardwareInfo(&hwInfo);
if (!status.isOk()) {
std::cerr << "Failed to get hardware info for '" << componentName
<< "'. Description: " << status.getDescription() << "." << std::endl;
- exit(-1);
+ return {nullptr, status.getDescription()};
}
if (hwInfo.versionNumber < kVersionWithoutSuperencryption) {
if (selfTest) {
- selfTestGetCsrV1(componentName, irpc);
+ auto errMsg = selfTestGetCsrV1(componentName, irpc);
+ if (errMsg) {
+ return {nullptr, *errMsg};
+ }
}
return getCsrV1(componentName, irpc);
} else {
- return getCsrV3(componentName, irpc, selfTest, allowDegenerate);
+ return getCsrV3(componentName, irpc, selfTest, allowDegenerate, requireUdsCerts);
}
}
-bool isRemoteProvisioningSupported(IRemotelyProvisionedComponent* irpc) {
- RpcHardwareInfo hwInfo;
- auto status = irpc->getHardwareInfo(&hwInfo);
- if (status.isOk()) {
- return true;
+std::unordered_set<std::string> parseCommaDelimited(const std::string& input) {
+ std::stringstream ss(input);
+ std::unordered_set<std::string> result;
+ while (ss.good()) {
+ std::string name;
+ std::getline(ss, name, ',');
+ result.insert(name);
}
- if (status.getExceptionCode() == EX_UNSUPPORTED_OPERATION) {
- return false;
- }
- std::cerr << "Unexpected error when getting hardware info. Description: "
- << status.getDescription() << "." << std::endl;
- exit(-1);
-}
+ return result;
+}
\ No newline at end of file
diff --git a/provisioner/rkp_factory_extraction_lib.h b/provisioner/rkp_factory_extraction_lib.h
index 94bd751..2c1e2ff 100644
--- a/provisioner/rkp_factory_extraction_lib.h
+++ b/provisioner/rkp_factory_extraction_lib.h
@@ -23,8 +23,12 @@
#include <memory>
#include <string>
#include <string_view>
+#include <unordered_set>
#include <vector>
+// Parse a comma-delimited string.
+std::unordered_set<std::string> parseCommaDelimited(const std::string& input);
+
// Challenge size must be between 32 and 64 bytes inclusive.
constexpr size_t kChallengeSize = 64;
@@ -35,9 +39,6 @@
std::string errMsg;
};
-// Return `buffer` encoded as a base64 string.
-std::string toBase64(const std::vector<uint8_t>& buffer);
-
// Generate a random challenge containing `kChallengeSize` bytes.
std::vector<uint8_t> generateChallenge();
@@ -47,13 +48,4 @@
CborResult<cppbor::Array>
getCsr(std::string_view componentName,
aidl::android::hardware::security::keymint::IRemotelyProvisionedComponent* irpc,
- bool selfTest, bool allowDegenerate);
-
-// Generates a test certificate chain and validates it, exiting the process on error.
-void selfTestGetCsr(
- std::string_view componentName,
- aidl::android::hardware::security::keymint::IRemotelyProvisionedComponent* irpc);
-
-// Returns true if the given IRemotelyProvisionedComponent supports remote provisioning.
-bool isRemoteProvisioningSupported(
- aidl::android::hardware::security::keymint::IRemotelyProvisionedComponent* irpc);
+ bool selfTest, bool allowDegenerate, bool requireUdsCerts);
\ No newline at end of file
diff --git a/provisioner/rkp_factory_extraction_lib_test.cpp b/provisioner/rkp_factory_extraction_lib_test.cpp
index 247c508..746ce41 100644
--- a/provisioner/rkp_factory_extraction_lib_test.cpp
+++ b/provisioner/rkp_factory_extraction_lib_test.cpp
@@ -25,6 +25,8 @@
#include <android-base/properties.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
+#include <openssl/base64.h>
+#include <remote_prov/MockIRemotelyProvisionedComponent.h>
#include <cstdint>
#include <memory>
@@ -60,26 +62,29 @@
} // namespace cppbor
-class MockIRemotelyProvisionedComponent : public IRemotelyProvisionedComponentDefault {
- public:
- MOCK_METHOD(ScopedAStatus, getHardwareInfo, (RpcHardwareInfo * _aidl_return), (override));
- MOCK_METHOD(ScopedAStatus, generateEcdsaP256KeyPair,
- (bool in_testMode, MacedPublicKey* out_macedPublicKey,
- std::vector<uint8_t>* _aidl_return),
- (override));
- MOCK_METHOD(ScopedAStatus, generateCertificateRequest,
- (bool in_testMode, const std::vector<MacedPublicKey>& in_keysToSign,
- const std::vector<uint8_t>& in_endpointEncryptionCertChain,
- const std::vector<uint8_t>& in_challenge, DeviceInfo* out_deviceInfo,
- ProtectedData* out_protectedData, std::vector<uint8_t>* _aidl_return),
- (override));
- MOCK_METHOD(ScopedAStatus, generateCertificateRequestV2,
- (const std::vector<MacedPublicKey>& in_keysToSign,
- const std::vector<uint8_t>& in_challenge, std::vector<uint8_t>* _aidl_return),
- (override));
- MOCK_METHOD(ScopedAStatus, getInterfaceVersion, (int32_t * _aidl_return), (override));
- MOCK_METHOD(ScopedAStatus, getInterfaceHash, (std::string * _aidl_return), (override));
-};
+std::string toBase64(const std::vector<uint8_t>& buffer) {
+ size_t base64Length;
+ int rc = EVP_EncodedLength(&base64Length, buffer.size());
+ if (!rc) {
+ std::cerr << "Error getting base64 length. Size overflow?" << std::endl;
+ exit(-1);
+ }
+
+ std::string base64(base64Length, ' ');
+ rc = EVP_EncodeBlock(reinterpret_cast<uint8_t*>(base64.data()), buffer.data(), buffer.size());
+ ++rc; // Account for NUL, which BoringSSL does not for some reason.
+ if (rc != base64Length) {
+ std::cerr << "Error writing base64. Expected " << base64Length
+ << " bytes to be written, but " << rc << " bytes were actually written."
+ << std::endl;
+ exit(-1);
+ }
+
+ // BoringSSL automatically adds a NUL -- remove it from the string data
+ base64.pop_back();
+
+ return base64;
+}
TEST(LibRkpFactoryExtractionTests, ToBase64) {
std::vector<uint8_t> input(UINT8_MAX + 1);
@@ -87,7 +92,7 @@
input[i] = i;
}
- // Test three lengths so we get all the different paddding options
+ // Test three lengths so we get all the different padding options
EXPECT_EQ("AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4"
"vMDEyMzQ1Njc4OTo7PD0+P0BBQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV"
"5fYGFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6e3x9fn+AgYKDhIWGh4iJiouMj"
@@ -160,7 +165,7 @@
std::vector<uint8_t> challenge;
// Set up mock, then call getSCsr
- auto mockRpc = SharedRefBase::make<MockIRemotelyProvisionedComponent>();
+ auto mockRpc = SharedRefBase::make<remote_prov::MockIRemotelyProvisionedComponent>();
EXPECT_CALL(*mockRpc, getHardwareInfo(NotNull())).WillRepeatedly([](RpcHardwareInfo* hwInfo) {
hwInfo->versionNumber = 2;
return ScopedAStatus::ok();
@@ -180,8 +185,9 @@
SetArgPointee<6>(kFakeMac), //
Return(ByMove(ScopedAStatus::ok())))); //
- auto [csr, csrErrMsg] = getCsr("mock component name", mockRpc.get(),
- /*selfTest=*/false, /*allowDegenerate=*/true);
+ auto [csr, csrErrMsg] =
+ getCsr("mock component name", mockRpc.get(),
+ /*selfTest=*/false, /*allowDegenerate=*/true, /*requireUdsCerts=*/false);
ASSERT_THAT(csr, NotNull()) << csrErrMsg;
ASSERT_THAT(csr->asArray(), Pointee(Property(&Array::size, Eq(4))));
@@ -230,7 +236,7 @@
TEST(LibRkpFactoryExtractionTests, GetCsrWithV3Hal) {
const std::vector<uint8_t> kCsr = Array()
- .add(3 /* version */)
+ .add(1 /* version */)
.add(Map() /* UdsCerts */)
.add(Array() /* DiceCertChain */)
.add(Array() /* SignedData */)
@@ -238,7 +244,7 @@
std::vector<uint8_t> challenge;
// Set up mock, then call getCsr
- auto mockRpc = SharedRefBase::make<MockIRemotelyProvisionedComponent>();
+ auto mockRpc = SharedRefBase::make<remote_prov::MockIRemotelyProvisionedComponent>();
EXPECT_CALL(*mockRpc, getHardwareInfo(NotNull())).WillRepeatedly([](RpcHardwareInfo* hwInfo) {
hwInfo->versionNumber = 3;
return ScopedAStatus::ok();
@@ -250,12 +256,13 @@
.WillOnce(DoAll(SaveArg<1>(&challenge), SetArgPointee<2>(kCsr),
Return(ByMove(ScopedAStatus::ok()))));
- auto [csr, csrErrMsg] = getCsr("mock component name", mockRpc.get(),
- /*selfTest=*/false, /*allowDegenerate=*/true);
+ auto [csr, csrErrMsg] =
+ getCsr("mock component name", mockRpc.get(),
+ /*selfTest=*/false, /*allowDegenerate=*/true, /*requireUdsCerts=*/false);
ASSERT_THAT(csr, NotNull()) << csrErrMsg;
ASSERT_THAT(csr, Pointee(Property(&Array::size, Eq(5))));
- EXPECT_THAT(csr->get(0 /* version */), Pointee(Eq(Uint(3))));
+ EXPECT_THAT(csr->get(0 /* version */), Pointee(Eq(Uint(1))));
EXPECT_THAT(csr->get(1)->asMap(), NotNull());
EXPECT_THAT(csr->get(2)->asArray(), NotNull());
EXPECT_THAT(csr->get(3)->asArray(), NotNull());
@@ -266,3 +273,72 @@
const Tstr fingerprint(android::base::GetProperty("ro.build.fingerprint", ""));
EXPECT_THAT(*unverifedDeviceInfo->get("fingerprint")->asTstr(), Eq(fingerprint));
}
+
+TEST(LibRkpFactoryExtractionTests, requireUdsCerts) {
+ const std::vector<uint8_t> kCsr = Array()
+ .add(1 /* version */)
+ .add(Map() /* UdsCerts */)
+ .add(Array() /* DiceCertChain */)
+ .add(Array() /* SignedData */)
+ .encode();
+ std::vector<uint8_t> challenge;
+
+ // Set up mock, then call getCsr
+ auto mockRpc = SharedRefBase::make<remote_prov::MockIRemotelyProvisionedComponent>();
+ EXPECT_CALL(*mockRpc, getHardwareInfo(NotNull())).WillRepeatedly([](RpcHardwareInfo* hwInfo) {
+ hwInfo->versionNumber = 3;
+ return ScopedAStatus::ok();
+ });
+ EXPECT_CALL(*mockRpc,
+ generateCertificateRequestV2(IsEmpty(), // keysToSign
+ _, // challenge
+ NotNull())) // _aidl_return
+ .WillOnce(DoAll(SaveArg<1>(&challenge), SetArgPointee<2>(kCsr),
+ Return(ByMove(ScopedAStatus::ok()))));
+
+ auto [csr, csrErrMsg] =
+ getCsr("mock component name", mockRpc.get(),
+ /*selfTest=*/true, /*allowDegenerate=*/false, /*requireUdsCerts=*/true);
+ ASSERT_EQ(csr, nullptr);
+ ASSERT_THAT(csrErrMsg, testing::HasSubstr("UdsCerts must not be empty"));
+}
+
+TEST(LibRkpFactoryExtractionTests, dontRequireUdsCerts) {
+ const std::vector<uint8_t> kCsr = Array()
+ .add(1 /* version */)
+ .add(Map() /* UdsCerts */)
+ .add(Array() /* DiceCertChain */)
+ .add(Array() /* SignedData */)
+ .encode();
+ std::vector<uint8_t> challenge;
+
+ // Set up mock, then call getCsr
+ auto mockRpc = SharedRefBase::make<remote_prov::MockIRemotelyProvisionedComponent>();
+ EXPECT_CALL(*mockRpc, getHardwareInfo(NotNull())).WillRepeatedly([](RpcHardwareInfo* hwInfo) {
+ hwInfo->versionNumber = 3;
+ return ScopedAStatus::ok();
+ });
+ EXPECT_CALL(*mockRpc,
+ generateCertificateRequestV2(IsEmpty(), // keysToSign
+ _, // challenge
+ NotNull())) // _aidl_return
+ .WillOnce(DoAll(SaveArg<1>(&challenge), SetArgPointee<2>(kCsr),
+ Return(ByMove(ScopedAStatus::ok()))));
+
+ auto [csr, csrErrMsg] =
+ getCsr("mock component name", mockRpc.get(),
+ /*selfTest=*/true, /*allowDegenerate=*/false, /*requireUdsCerts=*/false);
+ ASSERT_EQ(csr, nullptr);
+ ASSERT_THAT(csrErrMsg, testing::Not(testing::HasSubstr("UdsCerts must not be empty")));
+}
+
+TEST(LibRkpFactoryExtractionTests, parseCommaDelimitedString) {
+ const auto& rpcNames = "default,avf,,default,Strongbox,strongbox,,";
+ const auto& rpcSet = parseCommaDelimited(rpcNames);
+
+ ASSERT_EQ(rpcSet.size(), 5);
+ ASSERT_TRUE(rpcSet.count("default") == 1);
+ ASSERT_TRUE(rpcSet.count("avf") == 1);
+ ASSERT_TRUE(rpcSet.count("strongbox") == 1);
+ ASSERT_TRUE(rpcSet.count("Strongbox") == 1);
+}
\ No newline at end of file
diff --git a/provisioner/rkp_factory_extraction_tool.cpp b/provisioner/rkp_factory_extraction_tool.cpp
index c0f6beb..eaa0acc 100644
--- a/provisioner/rkp_factory_extraction_tool.cpp
+++ b/provisioner/rkp_factory_extraction_tool.cpp
@@ -26,6 +26,7 @@
#include <future>
#include <string>
+#include <unordered_set>
#include <vector>
#include "DrmRkpAdapter.h"
@@ -33,10 +34,10 @@
using aidl::android::hardware::drm::IDrmFactory;
using aidl::android::hardware::security::keymint::IRemotelyProvisionedComponent;
+using aidl::android::hardware::security::keymint::RpcHardwareInfo;
+using aidl::android::hardware::security::keymint::remote_prov::deviceSuffix;
using aidl::android::hardware::security::keymint::remote_prov::jsonEncodeCsrWithBuild;
-
-using namespace cppbor;
-using namespace cppcose;
+using aidl::android::hardware::security::keymint::remote_prov::RKPVM_INSTANCE_NAME;
DEFINE_string(output_format, "build+csr", "How to format the output. Defaults to 'build+csr'.");
DEFINE_bool(self_test, true,
@@ -47,6 +48,10 @@
"If true, self_test validation will allow degenerate DICE chains in the CSR.");
DEFINE_string(serialno_prop, "ro.serialno",
"The property of getting serial number. Defaults to 'ro.serialno'.");
+DEFINE_string(require_uds_certs, "",
+ "The comma-delimited names of remotely provisioned "
+ "components whose UDS certificate chains are required to be present in the CSR. "
+ "Example: avf,default,strongbox");
namespace {
@@ -59,15 +64,15 @@
return std::string(descriptor) + "/" + name;
}
-void writeOutput(const std::string instance_name, const Array& csr) {
+void writeOutput(const std::string instance_name, const cppbor::Array& csr) {
if (FLAGS_output_format == kBinaryCsrOutput) {
auto bytes = csr.encode();
std::copy(bytes.begin(), bytes.end(), std::ostream_iterator<char>(std::cout));
} else if (FLAGS_output_format == kBuildPlusCsr) {
auto [json, error] = jsonEncodeCsrWithBuild(instance_name, csr, FLAGS_serialno_prop);
if (!error.empty()) {
- std::cerr << "Error JSON encoding the output: " << error;
- exit(1);
+ std::cerr << "Error JSON encoding the output: " << error << std::endl;
+ exit(-1);
}
std::cout << json << std::endl;
} else {
@@ -75,20 +80,28 @@
std::cerr << "Valid formats:" << std::endl;
std::cerr << " " << kBinaryCsrOutput << std::endl;
std::cerr << " " << kBuildPlusCsr << std::endl;
- exit(1);
+ exit(-1);
}
}
-void getCsrForIRpc(const char* descriptor, const char* name, IRemotelyProvisionedComponent* irpc) {
+void getCsrForIRpc(const char* descriptor, const char* name, IRemotelyProvisionedComponent* irpc,
+ bool requireUdsCerts) {
// AVF RKP HAL is not always supported, so we need to check if it is supported before
// generating the CSR.
- if (std::string(name) == "avf" && !isRemoteProvisioningSupported(irpc)) {
- return;
+ if (std::string(name) == deviceSuffix(RKPVM_INSTANCE_NAME)) {
+ RpcHardwareInfo hwInfo;
+ auto status = irpc->getHardwareInfo(&hwInfo);
+ if (!status.isOk()) {
+ return;
+ }
}
- auto [request, errMsg] = getCsr(name, irpc, FLAGS_self_test, FLAGS_allow_degenerate);
- auto fullName = getFullServiceName(descriptor, name);
+
+ auto [request, errMsg] =
+ getCsr(name, irpc, FLAGS_self_test, FLAGS_allow_degenerate, requireUdsCerts);
if (!request) {
- std::cerr << "Unable to build CSR for '" << fullName << ": " << errMsg << std::endl;
+ auto fullName = getFullServiceName(descriptor, name);
+ std::cerr << "Unable to build CSR for '" << fullName << "': " << errMsg << ", exiting."
+ << std::endl;
exit(-1);
}
@@ -97,23 +110,33 @@
// Callback for AServiceManager_forEachDeclaredInstance that writes out a CSR
// for every IRemotelyProvisionedComponent.
-void getCsrForInstance(const char* name, void* /*context*/) {
+void getCsrForInstance(const char* name, void* context) {
auto fullName = getFullServiceName(IRemotelyProvisionedComponent::descriptor, name);
- std::future<AIBinder*> wait_for_service_func =
+ std::future<AIBinder*> waitForServiceFunc =
std::async(std::launch::async, AServiceManager_waitForService, fullName.c_str());
- if (wait_for_service_func.wait_for(std::chrono::seconds(10)) == std::future_status::timeout) {
- std::cerr << "Wait for service timed out after 10 seconds: " << fullName;
+ if (waitForServiceFunc.wait_for(std::chrono::seconds(10)) == std::future_status::timeout) {
+ std::cerr << "Wait for service timed out after 10 seconds: '" << fullName << "', exiting."
+ << std::endl;
exit(-1);
}
- AIBinder* rkpAiBinder = wait_for_service_func.get();
+ AIBinder* rkpAiBinder = waitForServiceFunc.get();
::ndk::SpAIBinder rkp_binder(rkpAiBinder);
- auto rkp_service = IRemotelyProvisionedComponent::fromBinder(rkp_binder);
- if (!rkp_service) {
- std::cerr << "Unable to get binder object for '" << fullName << "', skipping.";
+ auto rkpService = IRemotelyProvisionedComponent::fromBinder(rkp_binder);
+ if (!rkpService) {
+ std::cerr << "Unable to get binder object for '" << fullName << "', exiting." << std::endl;
exit(-1);
}
- getCsrForIRpc(IRemotelyProvisionedComponent::descriptor, name, rkp_service.get());
+ if (context == nullptr) {
+ std::cerr << "Unable to get context for '" << fullName << "', exiting." << std::endl;
+ exit(-1);
+ }
+
+ auto requireUdsCertsRpcNames = static_cast<std::unordered_set<std::string>*>(context);
+ auto requireUdsCerts = requireUdsCertsRpcNames->count(name) != 0;
+ requireUdsCertsRpcNames->erase(name);
+ getCsrForIRpc(IRemotelyProvisionedComponent::descriptor, name, rkpService.get(),
+ requireUdsCerts);
}
} // namespace
@@ -121,12 +144,21 @@
int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
+ auto requireUdsCertsRpcNames = parseCommaDelimited(FLAGS_require_uds_certs);
+
AServiceManager_forEachDeclaredInstance(IRemotelyProvisionedComponent::descriptor,
- /*context=*/nullptr, getCsrForInstance);
+ &requireUdsCertsRpcNames, getCsrForInstance);
// Append drm csr's
- for (auto const& e : android::mediadrm::getDrmRemotelyProvisionedComponents()) {
- getCsrForIRpc(IDrmFactory::descriptor, e.first.c_str(), e.second.get());
+ for (auto const& [name, irpc] : android::mediadrm::getDrmRemotelyProvisionedComponents()) {
+ auto requireUdsCerts = requireUdsCertsRpcNames.count(name) != 0;
+ requireUdsCertsRpcNames.erase(name);
+ getCsrForIRpc(IDrmFactory::descriptor, name.c_str(), irpc.get(), requireUdsCerts);
+ }
+
+ for (auto const& rpcName : requireUdsCertsRpcNames) {
+ std::cerr << "WARNING: You requested to enforce the presence of UDS Certs for '" << rpcName
+ << "', but no Remotely Provisioned Component had that name." << std::endl;
}
return 0;