Merge "pvmfw: refactor: Expose DICE inputs to main()"
diff --git a/pvmfw/avb/src/verify.rs b/pvmfw/avb/src/verify.rs
index 67658fd..b03506c 100644
--- a/pvmfw/avb/src/verify.rs
+++ b/pvmfw/avb/src/verify.rs
@@ -23,13 +23,15 @@
 
 /// Verified data returned when the payload verification succeeds.
 #[derive(Debug, PartialEq, Eq)]
-pub struct VerifiedBootData {
+pub struct VerifiedBootData<'a> {
     /// DebugLevel of the VM.
     pub debug_level: DebugLevel,
     /// Kernel digest.
     pub kernel_digest: Digest,
     /// Initrd digest if initrd exists.
     pub initrd_digest: Option<Digest>,
+    /// Trusted public key the payload was verified against.
+    pub public_key: &'a [u8],
 }
 
 /// This enum corresponds to the `DebugLevel` in `VirtualMachineConfig`.
@@ -94,11 +96,11 @@
 }
 
 /// Verifies the payload (signed kernel + initrd) against the trusted public key.
-pub fn verify_payload(
+pub fn verify_payload<'a>(
     kernel: &[u8],
     initrd: Option<&[u8]>,
-    trusted_public_key: &[u8],
-) -> Result<VerifiedBootData, AvbSlotVerifyError> {
+    trusted_public_key: &'a [u8],
+) -> Result<VerifiedBootData<'a>, AvbSlotVerifyError> {
     let mut payload = Payload::new(kernel, initrd, trusted_public_key);
     let mut ops = Ops::from(&mut payload);
     let kernel_verify_result = ops.verify_partition(PartitionName::Kernel.as_cstr())?;
@@ -119,6 +121,7 @@
             debug_level: DebugLevel::None,
             kernel_digest: kernel_descriptor.digest,
             initrd_digest: None,
+            public_key: trusted_public_key,
         });
     }
 
@@ -142,5 +145,6 @@
         debug_level,
         kernel_digest: kernel_descriptor.digest,
         initrd_digest: Some(initrd_descriptor.digest),
+        public_key: trusted_public_key,
     })
 }
diff --git a/pvmfw/avb/tests/api_test.rs b/pvmfw/avb/tests/api_test.rs
index 1d7369d..78f274a 100644
--- a/pvmfw/avb/tests/api_test.rs
+++ b/pvmfw/avb/tests/api_test.rs
@@ -53,16 +53,21 @@
 
 #[test]
 fn payload_expecting_no_initrd_passes_verification_with_no_initrd() -> Result<()> {
+    let public_key = load_trusted_public_key()?;
     let verified_boot_data = verify_payload(
         &fs::read(TEST_IMG_WITH_ONE_HASHDESC_PATH)?,
         /*initrd=*/ None,
-        &load_trusted_public_key()?,
+        &public_key,
     )
     .map_err(|e| anyhow!("Verification failed. Error: {}", e))?;
 
     let kernel_digest = hash(&[&hex::decode("1111")?, &fs::read(UNSIGNED_TEST_IMG_PATH)?]);
-    let expected_boot_data =
-        VerifiedBootData { debug_level: DebugLevel::None, kernel_digest, initrd_digest: None };
+    let expected_boot_data = VerifiedBootData {
+        debug_level: DebugLevel::None,
+        kernel_digest,
+        initrd_digest: None,
+        public_key: &public_key,
+    };
     assert_eq!(expected_boot_data, verified_boot_data);
 
     Ok(())
diff --git a/pvmfw/avb/tests/utils.rs b/pvmfw/avb/tests/utils.rs
index 9942b98..6713846 100644
--- a/pvmfw/avb/tests/utils.rs
+++ b/pvmfw/avb/tests/utils.rs
@@ -102,16 +102,21 @@
     initrd_salt: &[u8],
     expected_debug_level: DebugLevel,
 ) -> Result<()> {
+    let public_key = load_trusted_public_key()?;
     let kernel = load_latest_signed_kernel()?;
-    let verified_boot_data = verify_payload(&kernel, Some(initrd), &load_trusted_public_key()?)
+    let verified_boot_data = verify_payload(&kernel, Some(initrd), &public_key)
         .map_err(|e| anyhow!("Verification failed. Error: {}", e))?;
 
     let footer = extract_avb_footer(&kernel)?;
     let kernel_digest =
         hash(&[&hash(&[b"bootloader"]), &kernel[..usize::try_from(footer.original_image_size)?]]);
     let initrd_digest = Some(hash(&[&hash(&[initrd_salt]), initrd]));
-    let expected_boot_data =
-        VerifiedBootData { debug_level: expected_debug_level, kernel_digest, initrd_digest };
+    let expected_boot_data = VerifiedBootData {
+        debug_level: expected_debug_level,
+        kernel_digest,
+        initrd_digest,
+        public_key: &public_key,
+    };
     assert_eq!(expected_boot_data, verified_boot_data);
 
     Ok(())
diff --git a/pvmfw/src/dice.rs b/pvmfw/src/dice.rs
index e354666..9c5f59a 100644
--- a/pvmfw/src/dice.rs
+++ b/pvmfw/src/dice.rs
@@ -19,7 +19,7 @@
 use core::ffi::CStr;
 use core::mem::size_of;
 use core::slice;
-use dice::bcc::Handover;
+
 use dice::Config;
 use dice::DiceMode;
 use dice::InputValues;
@@ -42,35 +42,46 @@
     hash(&digests)
 }
 
-/// Derive the VM-specific secrets and certificate through DICE.
-pub fn derive_next_bcc(
-    bcc: &Handover,
-    next_bcc: &mut [u8],
-    verified_boot_data: &VerifiedBootData,
-    authority: &[u8],
-) -> dice::Result<usize> {
-    let code_hash = to_dice_hash(verified_boot_data)?;
-    let auth_hash = hash(authority)?;
-    let mode = to_dice_mode(verified_boot_data.debug_level);
-    let component_name = CStr::from_bytes_with_nul(b"vm_entry\0").unwrap();
-    let mut config_descriptor_buffer = [0; 128];
-    let config_descriptor_size = bcc_format_config_descriptor(
-        Some(component_name),
-        None,  // component_version
-        false, // resettable
-        &mut config_descriptor_buffer,
-    )?;
-    let config = &config_descriptor_buffer[..config_descriptor_size];
+/// Subset of the DICE inputs that is computed from the verified boot data.
+pub struct PartialInputs {
+    code_hash: dice::Hash,
+    auth_hash: dice::Hash,
+    mode: DiceMode,
+}
 
-    let input_values = InputValues::new(
-        code_hash,
-        Config::Descriptor(config),
-        auth_hash,
-        mode,
-        [0u8; HIDDEN_SIZE], // TODO(b/249723852): Get salt from instance.img (virtio-blk) and/or TRNG.
-    );
+impl PartialInputs {
+    /// Computes the code hash, authority hash and DICE mode from the verified boot data.
+    pub fn new(data: &VerifiedBootData) -> dice::Result<Self> {
+        let code_hash = to_dice_hash(data)?;
+        let auth_hash = hash(data.public_key)?;
+        let mode = to_dice_mode(data.debug_level);
 
-    bcc.main_flow(&input_values, next_bcc)
+        Ok(Self { code_hash, auth_hash, mode })
+    }
+
+    /// Completes these inputs with the config descriptor and `salt` to form `InputValues`.
+    pub fn into_input_values(self, salt: &[u8; HIDDEN_SIZE]) -> dice::Result<InputValues> {
+        let component_name = CStr::from_bytes_with_nul(b"vm_entry\0").unwrap();
+        // NOTE(review): `config` borrows this stack-local buffer. If `InputValues` stores a
+        // pointer to the descriptor instead of copying it, the value returned below would
+        // dangle after this function returns; confirm against `InputValues::new`.
+        let mut config_descriptor_buffer = [0; 128];
+        let config_descriptor_size = bcc_format_config_descriptor(
+            Some(component_name),
+            None,  // component_version
+            false, // resettable
+            &mut config_descriptor_buffer,
+        )?;
+        let config = &config_descriptor_buffer[..config_descriptor_size];
+
+        Ok(InputValues::new(
+            self.code_hash,
+            Config::Descriptor(config),
+            self.auth_hash,
+            self.mode,
+            *salt,
+        ))
+    }
 }
 
 /// Flushes data caches over the provided address range.
diff --git a/pvmfw/src/main.rs b/pvmfw/src/main.rs
index be5a16a..f7774e4 100644
--- a/pvmfw/src/main.rs
+++ b/pvmfw/src/main.rs
@@ -38,7 +38,7 @@
 use alloc::boxed::Box;
 
 use crate::{
-    dice::derive_next_bcc,
+    dice::PartialInputs,
     entry::RebootReason,
     fdt::add_dice_node,
     helpers::flush,
@@ -90,13 +90,20 @@
     })?;
     // By leaking the slice, its content will be left behind for the next stage.
     let next_bcc = Box::leak(next_bcc);
-    let next_bcc_size =
-        derive_next_bcc(bcc, next_bcc, &verified_boot_data, PUBLIC_KEY).map_err(|e| {
-            error!("Failed to derive next-stage DICE secrets: {e:?}");
-            RebootReason::SecretDerivationError
-        })?;
-    trace!("Next BCC: {:x?}", bcc::Handover::new(&next_bcc[..next_bcc_size]));
 
+    let dice_inputs = PartialInputs::new(&verified_boot_data).map_err(|e| {
+        error!("Failed to compute partial DICE inputs: {e:?}");
+        RebootReason::InternalError
+    })?;
+    let salt = [0; ::dice::HIDDEN_SIZE]; // TODO(b/249723852): Get from instance.img and/or TRNG.
+    let dice_inputs = dice_inputs.into_input_values(&salt).map_err(|e| {
+        error!("Failed to generate DICE inputs: {e:?}");
+        RebootReason::InternalError
+    })?;
+    let _ = bcc.main_flow(&dice_inputs, next_bcc).map_err(|e| {
+        error!("Failed to derive next-stage DICE secrets: {e:?}");
+        RebootReason::SecretDerivationError
+    })?;
     flush(next_bcc);
 
     add_dice_node(fdt, next_bcc.as_ptr() as usize, NEXT_BCC_SIZE).map_err(|e| {