Merge "pvmfw: Refactor get_or_generate_instance_salt" into main
diff --git a/pvmfw/src/instance.rs b/pvmfw/src/instance.rs
index 6daadd9..43c7442 100644
--- a/pvmfw/src/instance.rs
+++ b/pvmfw/src/instance.rs
@@ -27,7 +27,6 @@
 use log::trace;
 use uuid::Uuid;
 use virtio_drivers::transport::{pci::bus::PciRoot, DeviceType, Transport};
-use vmbase::rand;
 use vmbase::util::ceiling_div;
 use vmbase::virtio::pci::{PciTransportIterator, VirtIOBlk};
 use vmbase::virtio::HalImpl;
@@ -38,8 +37,6 @@
 pub enum Error {
     /// Unexpected I/O error while accessing the underlying disk.
     FailedIo(gpt::Error),
-    /// Failed to generate a random salt to be stored.
-    FailedSaltGeneration(rand::Error),
     /// Impossible to create a new instance.img entry.
     InstanceImageFull,
     /// Badly formatted instance.img header block.
@@ -66,7 +63,6 @@
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         match self {
             Self::FailedIo(e) => write!(f, "Failed I/O to disk: {e}"),
-            Self::FailedSaltGeneration(e) => write!(f, "Failed to generate salt: {e}"),
             Self::InstanceImageFull => write!(f, "Failed to obtain a free instance.img partition"),
             Self::InvalidInstanceImageHeader => write!(f, "instance.img header is invalid"),
             Self::MissingInstanceImage => write!(f, "Failed to find the instance.img partition"),
@@ -93,27 +89,31 @@
 
 pub type Result<T> = core::result::Result<T, Error>;
 
-pub fn get_or_generate_instance_salt(
+/// Builds, from `secret`, the AEAD context protecting the pvmfw instance.img entry: a 32-byte key
+/// is derived with HKDF-SHA-512 (empty salt, `vm-instance` label) and used for AES-256-GCM.
+fn aead_ctx_from_secret(secret: &[u8]) -> Result<AeadContext> {
+    let derived_key = hkdf::<32>(secret, /* salt= */ &[], b"vm-instance", Digester::sha512())?;
+    let cipher = Aead::aes_256_gcm_randnonce();
+    let ctx = AeadContext::new(cipher, derived_key.as_slice(), /* tag_len */ None)?;
+    Ok(ctx)
+}
+
+/// Reads the pvmfw entry from instance.img, returning `None` for it if none is recorded yet.
+/// Also returns the instance.img `Partition` and the index of the entry's header block (a free
+/// header slot when no entry exists), both of which `record_instance_entry` takes as arguments.
+pub(crate) fn get_recorded_entry(
     pci_root: &mut PciRoot,
-    dice_inputs: &PartialInputs,
     secret: &[u8],
-) -> Result<(bool, Hidden)> {
+) -> Result<(Option<EntryBody>, Partition, usize)> {
     let mut instance_img = find_instance_img(pci_root)?;
 
     let entry = locate_entry(&mut instance_img)?;
     trace!("Found pvmfw instance.img entry: {entry:?}");
 
-    let key = hkdf::<32>(secret, /* salt= */ &[], b"vm-instance", Digester::sha512())?;
-    let tag_len = None;
-    let aead_ctx = AeadContext::new(Aead::aes_256_gcm_randnonce(), key.as_slice(), tag_len)?;
-    let ad = &[];
-    // The nonce is generated internally for `aes_256_gcm_randnonce`, so no additional
-    // nonce is required.
-    let nonce = &[];
-
-    let mut blk = [0; BLK_SIZE];
     match entry {
         PvmfwEntry::Existing { header_index, payload_size } => {
+            let aead_ctx = aead_ctx_from_secret(secret)?;
+            let mut blk = [0; BLK_SIZE];
             if payload_size > blk.len() {
                 // We currently only support single-blk entries.
                 return Err(Error::UnsupportedEntrySize(payload_size));
@@ -123,52 +119,41 @@
 
             let payload = &blk[..payload_size];
             let mut entry = [0; size_of::<EntryBody>()];
-            let decrypted = aead_ctx.open(payload, nonce, ad, &mut entry)?;
-
+            // The nonce is generated internally for `aes_256_gcm_randnonce`, so no additional
+            // nonce is required.
+            let decrypted =
+                aead_ctx.open(payload, /* nonce */ &[], /* ad */ &[], &mut entry)?;
             let body = EntryBody::read_from(decrypted).unwrap();
-            if dice_inputs.rkp_vm_marker {
-                // The RKP VM is allowed to run if it has passed the verified boot check and
-                // contains the expected version in its AVB footer.
-                // The comparison below with the previous boot information is skipped to enable the
-                // simultaneous update of the pvmfw and RKP VM.
-                // For instance, when both the pvmfw and RKP VM are updated, the code hash of the
-                // RKP VM will differ from the one stored in the instance image. In this case, the
-                // RKP VM is still allowed to run.
-                // This ensures that the updated RKP VM will retain the same CDIs in the next stage.
-                return Ok((false, body.salt));
-            }
-            if body.code_hash != dice_inputs.code_hash {
-                Err(Error::RecordedCodeHashMismatch)
-            } else if body.auth_hash != dice_inputs.auth_hash {
-                Err(Error::RecordedAuthHashMismatch)
-            } else if body.mode() != dice_inputs.mode {
-                Err(Error::RecordedDiceModeMismatch)
-            } else {
-                Ok((false, body.salt))
-            }
+            Ok((Some(body), instance_img, header_index))
         }
-        PvmfwEntry::New { header_index } => {
-            let salt = rand::random_array().map_err(Error::FailedSaltGeneration)?;
-            let body = EntryBody::new(dice_inputs, &salt);
-
-            // We currently only support single-blk entries.
-            let plaintext = body.as_bytes();
-            assert!(plaintext.len() + aead_ctx.aead().max_overhead() < blk.len());
-            let encrypted = aead_ctx.seal(plaintext, nonce, ad, &mut blk)?;
-            let payload_size = encrypted.len();
-            let payload_index = header_index + 1;
-            instance_img.write_block(payload_index, &blk).map_err(Error::FailedIo)?;
-
-            let header = EntryHeader::new(PvmfwEntry::UUID, payload_size);
-            header.write_to_prefix(blk.as_mut_slice()).unwrap();
-            blk[header.as_bytes().len()..].fill(0);
-            instance_img.write_block(header_index, &blk).map_err(Error::FailedIo)?;
-
-            Ok((true, salt))
-        }
+        PvmfwEntry::New { header_index } => Ok((None, instance_img, header_index)),
     }
 }
 
+/// Encrypts `body` with a key derived from `secret` and writes it to `instance_img` as the pvmfw
+/// entry: the payload goes to block `header_index + 1` and its header to block `header_index`.
+pub(crate) fn record_instance_entry(
+    body: &EntryBody,
+    secret: &[u8],
+    instance_img: &mut Partition,
+    header_index: usize,
+) -> Result<()> {
+    let aead_ctx = aead_ctx_from_secret(secret)?;
+    let plaintext = body.as_bytes();
+    // We currently only support single-blk entries.
+    let mut blk = [0; BLK_SIZE];
+    assert!(plaintext.len() + aead_ctx.aead().max_overhead() < blk.len());
+    let payload_size = aead_ctx.seal(plaintext, /* nonce */ &[], /* ad */ &[], &mut blk)?.len();
+    instance_img.write_block(header_index + 1, &blk).map_err(Error::FailedIo)?;
+
+    // Reuse the buffer for the header block, zeroing everything past the header itself.
+    let header = EntryHeader::new(PvmfwEntry::UUID, payload_size);
+    header.write_to_prefix(blk.as_mut_slice()).unwrap();
+    blk[header.as_bytes().len()..].fill(0);
+    instance_img.write_block(header_index, &blk).map_err(Error::FailedIo)?;
+    Ok(())
+}
+
 #[derive(FromZeroes, FromBytes)]
 #[repr(C, packed)]
 struct Header {
@@ -276,15 +261,15 @@
 
 #[derive(AsBytes, FromZeroes, FromBytes)]
 #[repr(C)]
-struct EntryBody {
-    code_hash: Hash,
-    auth_hash: Hash,
-    salt: Hidden,
+pub(crate) struct EntryBody {
+    pub code_hash: Hash,
+    pub auth_hash: Hash,
+    pub salt: Hidden,
     mode: u8,
 }
 
 impl EntryBody {
-    fn new(dice_inputs: &PartialInputs, salt: &Hidden) -> Self {
+    pub(crate) fn new(dice_inputs: &PartialInputs, salt: &Hidden) -> Self {
         let mode = match dice_inputs.mode {
             DiceMode::kDiceModeNotInitialized => 0,
             DiceMode::kDiceModeNormal => 1,
@@ -300,7 +285,7 @@
         }
     }
 
-    fn mode(&self) -> DiceMode {
+    pub(crate) fn mode(&self) -> DiceMode {
         match self.mode {
             1 => DiceMode::kDiceModeNormal,
             2 => DiceMode::kDiceModeDebug,
diff --git a/pvmfw/src/main.rs b/pvmfw/src/main.rs
index f80bae1..12d63d5 100644
--- a/pvmfw/src/main.rs
+++ b/pvmfw/src/main.rs
@@ -37,7 +37,9 @@
 use crate::entry::RebootReason;
 use crate::fdt::modify_for_next_stage;
 use crate::helpers::GUEST_PAGE_SIZE;
-use crate::instance::get_or_generate_instance_salt;
+use crate::instance::EntryBody;
+use crate::instance::Error as InstanceError;
+use crate::instance::{get_recorded_entry, record_instance_entry};
 use alloc::borrow::Cow;
 use alloc::boxed::Box;
 use core::ops::Range;
@@ -150,11 +152,43 @@
         error!("Failed to compute partial DICE inputs: {e:?}");
         RebootReason::InternalError
     })?;
-    let (new_instance, salt) = get_or_generate_instance_salt(&mut pci_root, &dice_inputs, cdi_seal)
-        .map_err(|e| {
-            error!("Failed to get instance.img salt: {e}");
+
+    let (recorded_entry, mut instance_img, header_index) =
+        get_recorded_entry(&mut pci_root, cdi_seal).map_err(|e| {
+            error!("Failed to get entry from instance.img: {e}");
             RebootReason::InternalError
         })?;
+    let (new_instance, salt) = if let Some(entry) = recorded_entry {
+        // An entry already exists: reuse its salt, after checking that the current DICE
+        // measurements still match the recorded ones.
+        // The check is skipped for the RKP VM, which is allowed to run once it has passed the
+        // verified boot check and contains the expected version in its AVB footer. This enables
+        // the simultaneous update of the pvmfw and RKP VM: the code hash of an updated RKP VM
+        // differs from the one stored in the instance image, yet it must still be allowed to run
+        // and retain the same CDIs in the next stage.
+        if !dice_inputs.rkp_vm_marker {
+            ensure_dice_measurements_match_entry(&dice_inputs, &entry).map_err(|e| {
+                error!(
+                    "Dice measurements do not match recorded entry. \
+                    This may be because of update: {e}"
+                );
+                RebootReason::InternalError
+            })?;
+        }
+        (false, entry.salt)
+    } else {
+        // No entry recorded yet: generate a fresh salt and persist it with the measurements.
+        let salt = rand::random_array().map_err(|e| {
+            error!("Failed to generate instance.img salt: {e}");
+            RebootReason::InternalError
+        })?;
+        let entry = EntryBody::new(&dice_inputs, &salt);
+        record_instance_entry(&entry, cdi_seal, &mut instance_img, header_index).map_err(|e| {
+            error!("Failed to record entry in instance.img: {e}");
+            RebootReason::InternalError
+        })?;
+        (true, salt)
+    };
     trace!("Got salt from instance.img: {salt:x?}");
 
     let new_bcc_handover = if cfg!(dice_changes) {
@@ -207,6 +241,24 @@
     Ok(bcc_range)
 }
 
+/// Checks that the code hash, auth hash and DICE mode recorded in `entry` match `dice_inputs`,
+/// reporting the first mismatch found (in that order).
+fn ensure_dice_measurements_match_entry(
+    dice_inputs: &PartialInputs,
+    entry: &EntryBody,
+) -> Result<(), InstanceError> {
+    if entry.code_hash != dice_inputs.code_hash {
+        return Err(InstanceError::RecordedCodeHashMismatch);
+    }
+    if entry.auth_hash != dice_inputs.auth_hash {
+        return Err(InstanceError::RecordedAuthHashMismatch);
+    }
+    if entry.mode() != dice_inputs.mode {
+        return Err(InstanceError::RecordedDiceModeMismatch);
+    }
+    Ok(())
+}
+
 /// Logs the given PCI error and returns the appropriate `RebootReason`.
 fn handle_pci_error(e: PciError) -> RebootReason {
     error!("{}", e);