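rkp_factory_extraction_tool: append DRM CSRs

After writing a CSR for every declared IRemotelyProvisionedComponent
instance, also write one for each remotely provisioned component returned
by android::mediadrm::getDrmRemotelyProvisionedComponents().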

Bug: 286556950
Test: rkp_factory_extraction_tool
Change-Id: I9fe2898c53012c6cd640e4504ca4d882481ea2a9
diff --git a/provisioner/rkp_factory_extraction_tool.cpp b/provisioner/rkp_factory_extraction_tool.cpp
index 5ba777e..5765e05 100644
--- a/provisioner/rkp_factory_extraction_tool.cpp
+++ b/provisioner/rkp_factory_extraction_tool.cpp
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include <aidl/android/hardware/drm/IDrmFactory.h>
 #include <aidl/android/hardware/security/keymint/IRemotelyProvisionedComponent.h>
 #include <android/binder_manager.h>
 #include <cppbor.h>
@@ -26,8 +27,10 @@
 #include <string>
 #include <vector>
 
+#include "DrmRkpAdapter.h"
 #include "rkp_factory_extraction_lib.h"
 
+using aidl::android::hardware::drm::IDrmFactory;
 using aidl::android::hardware::security::keymint::IRemotelyProvisionedComponent;
 using aidl::android::hardware::security::keymint::remote_prov::jsonEncodeCsrWithBuild;
 
@@ -47,6 +50,10 @@
 constexpr std::string_view kBuildPlusCsr = "build+csr";  // Text-encoded (JSON) build
                                                          // fingerprint plus CSR.
 
+std::string getFullServiceName(const char* descriptor, const char* name) {
+    return std::string(descriptor) + "/" + name;  // "<descriptor>/<name>"
+}
+
 void writeOutput(const std::string instance_name, const Array& csr) {
     if (FLAGS_output_format == kBinaryCsrOutput) {
         auto bytes = csr.encode();
@@ -67,12 +74,21 @@
     }
 }
 
+// Gets a CSR from |irpc| and writes it out under |name|; exits the process if none was built.
+void getCsrForIRpc(const char* descriptor, const char* name, IRemotelyProvisionedComponent* irpc) {
+    auto [request, errMsg] = getCsr(name, irpc, FLAGS_self_test);
+    if (!request) {
+        const auto fullName = getFullServiceName(descriptor, name);
+        std::cerr << "Unable to build CSR for '" << fullName << "': " << errMsg << std::endl;
+        exit(-1);
+    }
+    writeOutput(name, *request);
+}
+
 // Callback for AServiceManager_forEachDeclaredInstance that writes out a CSR
 // for every IRemotelyProvisionedComponent.
 void getCsrForInstance(const char* name, void* /*context*/) {
-    const std::vector<uint8_t> challenge = generateChallenge();
-
-    auto fullName = std::string(IRemotelyProvisionedComponent::descriptor) + "/" + name;
+    auto fullName = getFullServiceName(IRemotelyProvisionedComponent::descriptor, name);
     AIBinder* rkpAiBinder = AServiceManager_getService(fullName.c_str());
     ::ndk::SpAIBinder rkp_binder(rkpAiBinder);
     auto rkp_service = IRemotelyProvisionedComponent::fromBinder(rkp_binder);
@@ -81,13 +97,7 @@
         exit(-1);
     }
 
-    auto [request, errMsg] = getCsr(name, rkp_service.get(), FLAGS_self_test);
-    if (!request) {
-        std::cerr << "Unable to build CSR for '" << fullName << ": " << errMsg << std::endl;
-        exit(-1);
-    }
-
-    writeOutput(std::string(name), *request);
+    getCsrForIRpc(IRemotelyProvisionedComponent::descriptor, name, rkp_service.get());
 }
 
 }  // namespace
@@ -98,5 +108,10 @@
     AServiceManager_forEachDeclaredInstance(IRemotelyProvisionedComponent::descriptor,
                                             /*context=*/nullptr, getCsrForInstance);
 
+    // Append CSRs for the DRM remotely provisioned components as well.
+    for (const auto& [name, irpc] : android::mediadrm::getDrmRemotelyProvisionedComponents()) {
+        getCsrForIRpc(IDrmFactory::descriptor, name.c_str(), irpc.get());
+    }
+
     return 0;
 }