[rkp] Verify the MAC of the public keys to sign in RKP HAL

Bug: 299256925
Test: atest rialto_test
Change-Id: I9d7ff281166e5acbe47936fa103cbe6c5fa2c2da
diff --git a/rialto/src/requests/pub_key.rs b/rialto/src/requests/pub_key.rs
index 84373ce..b45c117 100644
--- a/rialto/src/requests/pub_key.rs
+++ b/rialto/src/requests/pub_key.rs
@@ -19,11 +19,27 @@
 use bssl_ffi::EVP_sha256;
 use bssl_ffi::HMAC;
 use core::result;
-use coset::{iana, CborSerializable, CoseKey, CoseMac0Builder, HeaderBuilder};
+use coset::{iana, CborSerializable, CoseKey, CoseMac0, CoseMac0Builder, HeaderBuilder};
 use service_vm_comm::{BoringSSLApiName, RequestProcessingError};
 
 type Result<T> = result::Result<T, RequestProcessingError>;
 
+/// Verifies the MAC of the given public key.
+/// TODO(b/299256925): Return the validated public key.
+pub fn validate_public_key(maced_public_key: &[u8], hmac_key: &[u8]) -> Result<()> {
+    let cose_mac = CoseMac0::from_slice(maced_public_key)?;
+    cose_mac.verify_tag(&[], |tag, data| verify_tag(tag, data, hmac_key))
+}
+
+fn verify_tag(tag: &[u8], data: &[u8], hmac_key: &[u8]) -> Result<()> {
+    let computed_tag = hmac_sha256(hmac_key, data)?;
+    if tag == computed_tag {
+        Ok(())
+    } else {
+        Err(RequestProcessingError::InvalidMac)
+    }
+}
+
 /// Returns the MACed public key.
 pub fn build_maced_public_key(public_key: CoseKey, hmac_key: &[u8]) -> Result<Vec<u8>> {
     const ALGO: iana::Algorithm = iana::Algorithm::HMAC_256_256;
diff --git a/rialto/src/requests/rkp.rs b/rialto/src/requests/rkp.rs
index 58e054f..9b3e569 100644
--- a/rialto/src/requests/rkp.rs
+++ b/rialto/src/requests/rkp.rs
@@ -16,7 +16,7 @@
 //! service VM via the RKP (Remote Key Provisioning) server.
 
 use super::ec_key::EcKey;
-use super::pub_key::build_maced_public_key;
+use super::pub_key::{build_maced_public_key, validate_public_key};
 use alloc::vec::Vec;
 use core::result;
 use diced_open_dice::DiceArtifacts;
@@ -40,9 +40,14 @@
 }
 
 pub(super) fn generate_certificate_request(
-    _params: GenerateCertificateRequestParams,
+    params: GenerateCertificateRequestParams,
     _dice_artifacts: &dyn DiceArtifacts,
 ) -> Result<Vec<u8>> {
+    // TODO(b/300590857): Derive the HMAC key from the DICE sealing CDI.
+    let hmac_key = [];
+    for key_to_sign in params.keys_to_sign {
+        validate_public_key(&key_to_sign, &hmac_key)?;
+    }
     // TODO(b/299256925): Generate the certificate request
     Ok(Vec::new())
 }
diff --git a/rialto/tests/test.rs b/rialto/tests/test.rs
index e975bbf..c9d68ed 100644
--- a/rialto/tests/test.rs
+++ b/rialto/tests/test.rs
@@ -49,8 +49,8 @@
     let mut vm = start_service_vm(vm_type)?;
 
     check_processing_reverse_request(&mut vm)?;
-    check_processing_generating_key_pair_request(&mut vm)?;
-    check_processing_generating_certificate_request(&mut vm)?;
+    let maced_public_key = check_processing_generating_key_pair_request(&mut vm)?;
+    check_processing_generating_certificate_request(&mut vm, maced_public_key)?;
     Ok(())
 }
 
@@ -68,7 +68,7 @@
     Ok(())
 }
 
-fn check_processing_generating_key_pair_request(vm: &mut ServiceVm) -> Result<()> {
+fn check_processing_generating_key_pair_request(vm: &mut ServiceVm) -> Result<Vec<u8>> {
     let request = Request::GenerateEcdsaP256KeyPair;
 
     let response = vm.process_request(request)?;
@@ -77,9 +77,9 @@
     match response {
         Response::GenerateEcdsaP256KeyPair(EcdsaP256KeyPair { maced_public_key, .. }) => {
             assert_array_has_nonzero(&maced_public_key[..]);
-            Ok(())
+            Ok(maced_public_key)
         }
-        _ => bail!("Incorrect response type"),
+        _ => bail!("Incorrect response type: {response:?}"),
     }
 }
 
@@ -87,8 +87,14 @@
     assert!(v.iter().any(|&x| x != 0))
 }
 
-fn check_processing_generating_certificate_request(vm: &mut ServiceVm) -> Result<()> {
-    let params = GenerateCertificateRequestParams { keys_to_sign: vec![], challenge: vec![] };
+fn check_processing_generating_certificate_request(
+    vm: &mut ServiceVm,
+    maced_public_key: Vec<u8>,
+) -> Result<()> {
+    let params = GenerateCertificateRequestParams {
+        keys_to_sign: vec![maced_public_key],
+        challenge: vec![],
+    };
     let request = Request::GenerateCertificateRequest(params);
 
     let response = vm.process_request(request)?;
@@ -96,7 +102,7 @@
 
     match response {
         Response::GenerateCertificateRequest(_) => Ok(()),
-        _ => bail!("Incorrect response type"),
+        _ => bail!("Incorrect response type: {response:?}"),
     }
 }