Merge "Drop the flag dice_changes in pvmfw and RKP VM" into main
diff --git a/guest/pvmfw/src/fdt.rs b/guest/pvmfw/src/fdt.rs
index 29212f9..818d342 100644
--- a/guest/pvmfw/src/fdt.rs
+++ b/guest/pvmfw/src/fdt.rs
@@ -112,6 +112,24 @@
     Ok(None)
 }
 
+/// Reads `/avf/untrusted/instance-id`; `Ok(None)` if the node or property is absent.
+pub fn read_instance_id(fdt: &Fdt) -> libfdt::Result<Option<&[u8]>> {
+    read_avf_untrusted_prop(fdt, c"instance-id")
+}
+
+/// Reads `/avf/untrusted/defer-rollback-protection`; `Ok(None)` if node or property is absent.
+pub fn read_defer_rollback_protection(fdt: &Fdt) -> libfdt::Result<Option<&[u8]>> {
+    read_avf_untrusted_prop(fdt, c"defer-rollback-protection")
+}
+
+fn read_avf_untrusted_prop<'a>(fdt: &'a Fdt, prop: &CStr) -> libfdt::Result<Option<&'a [u8]>> {
+    if let Some(node) = fdt.node(c"/avf/untrusted")? {
+        node.getprop(prop) // Value borrows from `fdt`; `None` if the property is unset.
+    } else {
+        Ok(None) // A missing /avf/untrusted node is reported like a missing property.
+    }
+}
+
 fn patch_initrd_range(fdt: &mut Fdt, initrd_range: &Range<usize>) -> libfdt::Result<()> {
     let start = u32::try_from(initrd_range.start).unwrap();
     let end = u32::try_from(initrd_range.end).unwrap();
diff --git a/guest/pvmfw/src/main.rs b/guest/pvmfw/src/main.rs
index 0a3dca6..afa64e0 100644
--- a/guest/pvmfw/src/main.rs
+++ b/guest/pvmfw/src/main.rs
@@ -35,22 +35,20 @@
 use crate::bcc::Bcc;
 use crate::dice::PartialInputs;
 use crate::entry::RebootReason;
-use crate::fdt::{modify_for_next_stage, sanitize_device_tree};
+use crate::fdt::{modify_for_next_stage, read_instance_id, sanitize_device_tree};
 use crate::rollback::perform_rollback_protection;
 use alloc::borrow::Cow;
 use alloc::boxed::Box;
 use bssl_avf::Digester;
 use diced_open_dice::{bcc_handover_parse, DiceArtifacts, DiceContext, Hidden, VM_KEY_ALGORITHM};
-use libfdt::{Fdt, FdtNode};
+use libfdt::Fdt;
 use log::{debug, error, info, trace, warn};
 use pvmfw_avb::verify_payload;
 use pvmfw_avb::DebugLevel;
 use pvmfw_embedded_key::PUBLIC_KEY;
-use vmbase::fdt::pci::{PciError, PciInfo};
 use vmbase::heap;
-use vmbase::memory::{flush, init_shared_pool, SIZE_4KB};
+use vmbase::memory::{flush, SIZE_4KB};
 use vmbase::rand;
-use vmbase::virtio::pci;
 
 fn main<'a>(
     untrusted_fdt: &mut Fdt,
@@ -77,8 +75,6 @@
     })?;
     trace!("BCC: {bcc_handover:x?}");
 
-    let cdi_seal = bcc_handover.cdi_seal();
-
     let bcc = Bcc::new(bcc_handover.bcc()).map_err(|e| {
         error!("{e}");
         RebootReason::InvalidBcc
@@ -102,19 +98,8 @@
     }
 
     let guest_page_size = verified_boot_data.page_size.unwrap_or(SIZE_4KB);
-    let fdt_info = sanitize_device_tree(untrusted_fdt, vm_dtbo, vm_ref_dt, guest_page_size)?;
+    let _ = sanitize_device_tree(untrusted_fdt, vm_dtbo, vm_ref_dt, guest_page_size)?;
     let fdt = untrusted_fdt; // DT has now been sanitized.
-    let pci_info = PciInfo::from_fdt(fdt).map_err(handle_pci_error)?;
-    debug!("PCI: {:#x?}", pci_info);
-    // Set up PCI bus for VirtIO devices.
-    let mut pci_root = pci::initialize(pci_info).map_err(|e| {
-        error!("Failed to initialize PCI: {e}");
-        RebootReason::InternalError
-    })?;
-    init_shared_pool(fdt_info.swiotlb_info.fixed_range()).map_err(|e| {
-        error!("Failed to initialize shared pool: {e}");
-        RebootReason::InternalError
-    })?;
 
     let next_bcc_size = guest_page_size;
     let next_bcc = heap::aligned_boxed_slice(next_bcc_size, guest_page_size).ok_or_else(|| {
@@ -129,13 +114,12 @@
         RebootReason::InternalError
     })?;
 
-    let instance_hash = Some(salt_from_instance_id(fdt)?);
+    let instance_hash = salt_from_instance_id(fdt)?; // `None` if the DT has no instance-id.
     let (new_instance, salt, defer_rollback_protection) = perform_rollback_protection(
         fdt,
         &verified_boot_data,
         &dice_inputs,
-        &mut pci_root,
-        cdi_seal,
+        bcc_handover.cdi_seal(),
         instance_hash,
     )?;
     trace!("Got salt for instance: {salt:x?}");
@@ -204,8 +188,14 @@
 
 // Get the "salt" which is one of the input for DICE derivation.
 // This provides differentiation of secrets for different VM instances with same payloads.
-fn salt_from_instance_id(fdt: &Fdt) -> Result<Hidden, RebootReason> {
-    let id = instance_id(fdt)?;
+fn salt_from_instance_id(fdt: &Fdt) -> Result<Option<Hidden>, RebootReason> {
+    let Some(id) = read_instance_id(fdt).map_err(|e| {
+        error!("Failed to get instance-id in DT: {e}");
+        RebootReason::InvalidFdt
+    })?
+    else { // No instance-id: the instance.img path may fall back to its recorded salt.
+        return Ok(None); // NOTE(review): RBP branches that unwrap() this must not get None.
+    };
     let salt = Digester::sha512()
         .digest(&[&b"InstanceId:"[..], id].concat())
         .map_err(|e| {
@@ -214,46 +204,5 @@
         })?
         .try_into()
         .map_err(|_| RebootReason::InternalError)?;
-    Ok(salt)
-}
-
-fn instance_id(fdt: &Fdt) -> Result<&[u8], RebootReason> {
-    let node = avf_untrusted_node(fdt)?;
-    let id = node.getprop(c"instance-id").map_err(|e| {
-        error!("Failed to get instance-id in DT: {e}");
-        RebootReason::InvalidFdt
-    })?;
-    id.ok_or_else(|| {
-        error!("Missing instance-id");
-        RebootReason::InvalidFdt
-    })
-}
-
-fn avf_untrusted_node(fdt: &Fdt) -> Result<FdtNode, RebootReason> {
-    let node = fdt.node(c"/avf/untrusted").map_err(|e| {
-        error!("Failed to get /avf/untrusted node: {e}");
-        RebootReason::InvalidFdt
-    })?;
-    node.ok_or_else(|| {
-        error!("/avf/untrusted node is missing in DT");
-        RebootReason::InvalidFdt
-    })
-}
-
-/// Logs the given PCI error and returns the appropriate `RebootReason`.
-fn handle_pci_error(e: PciError) -> RebootReason {
-    error!("{}", e);
-    match e {
-        PciError::FdtErrorPci(_)
-        | PciError::FdtNoPci
-        | PciError::FdtErrorReg(_)
-        | PciError::FdtMissingReg
-        | PciError::FdtRegEmpty
-        | PciError::FdtRegMissingSize
-        | PciError::CamWrongSize(_)
-        | PciError::FdtErrorRanges(_)
-        | PciError::FdtMissingRanges
-        | PciError::RangeAddressMismatch { .. }
-        | PciError::NoSuitableRange => RebootReason::InvalidFdt,
-    }
+    Ok(Some(salt))
 }
diff --git a/guest/pvmfw/src/rollback.rs b/guest/pvmfw/src/rollback.rs
index f7723d7..74b2cd8 100644
--- a/guest/pvmfw/src/rollback.rs
+++ b/guest/pvmfw/src/rollback.rs
@@ -16,16 +16,20 @@
 
 use crate::dice::PartialInputs;
 use crate::entry::RebootReason;
+use crate::fdt::read_defer_rollback_protection;
 use crate::instance::EntryBody;
 use crate::instance::Error as InstanceError;
 use crate::instance::{get_recorded_entry, record_instance_entry};
 use diced_open_dice::Hidden;
-use libfdt::{Fdt, FdtNode};
+use libfdt::Fdt;
 use log::{error, info};
 use pvmfw_avb::Capability;
 use pvmfw_avb::VerifiedBootData;
 use virtio_drivers::transport::pci::bus::PciRoot;
+use vmbase::fdt::{pci::PciInfo, SwiotlbInfo};
+use vmbase::memory::init_shared_pool;
 use vmbase::rand;
+use vmbase::virtio::pci;
 
 /// Performs RBP based on the input payload, current DICE chain, and host-controlled platform.
 ///
@@ -37,7 +41,6 @@
     fdt: &Fdt,
     verified_boot_data: &VerifiedBootData,
     dice_inputs: &PartialInputs,
-    pci_root: &mut PciRoot,
     cdi_seal: &[u8],
     instance_hash: Option<Hidden>,
 ) -> Result<(bool, Hidden, bool), RebootReason> {
@@ -53,7 +56,7 @@
         skip_rollback_protection()?;
         Ok((false, instance_hash.unwrap(), false))
     } else {
-        perform_legacy_rollback_protection(dice_inputs, pci_root, cdi_seal, instance_hash)
+        perform_legacy_rollback_protection(fdt, dice_inputs, cdi_seal, instance_hash)
     }
 }
 
@@ -92,17 +95,18 @@
 
 /// Performs RBP using instance.img where updates require clearing old entries, causing new CDIs.
 fn perform_legacy_rollback_protection(
+    fdt: &Fdt, // Used to bring up the PCI device backing instance.img.
     dice_inputs: &PartialInputs,
-    pci_root: &mut PciRoot,
     cdi_seal: &[u8],
     instance_hash: Option<Hidden>,
 ) -> Result<(bool, Hidden, bool), RebootReason> {
     info!("Fallback to instance.img based rollback checks");
-    let (recorded_entry, mut instance_img, header_index) = get_recorded_entry(pci_root, cdi_seal)
-        .map_err(|e| {
-        error!("Failed to get entry from instance.img: {e}");
-        RebootReason::InternalError
-    })?;
+    let mut pci_root = initialize_instance_img_device(fdt)?; // PCI set up lazily for instance.img.
+    let (recorded_entry, mut instance_img, header_index) =
+        get_recorded_entry(&mut pci_root, cdi_seal).map_err(|e| {
+            error!("Failed to get entry from instance.img: {e}");
+            RebootReason::InternalError
+        })?;
     let (new_instance, salt) = if let Some(entry) = recorded_entry {
         check_dice_measurements_match_entry(dice_inputs, &entry)?;
         let salt = instance_hash.unwrap_or(entry.salt);
@@ -155,24 +159,34 @@
 }
 
 fn should_defer_rollback_protection(fdt: &Fdt) -> Result<bool, RebootReason> {
-    let node = avf_untrusted_node(fdt)?;
-    let defer_rbp = node
-        .getprop(c"defer-rollback-protection")
-        .map_err(|e| {
-            error!("Failed to get defer-rollback-protection property in DT: {e}");
-            RebootReason::InvalidFdt
-        })?
-        .is_some();
-    Ok(defer_rbp)
-}
-
-fn avf_untrusted_node(fdt: &Fdt) -> Result<FdtNode, RebootReason> {
-    let node = fdt.node(c"/avf/untrusted").map_err(|e| {
-        error!("Failed to get /avf/untrusted node: {e}");
+    let defer_rbp = read_defer_rollback_protection(fdt).map_err(|e| {
+        error!("Failed to get defer-rollback-protection property in DT: {e}");
         RebootReason::InvalidFdt
     })?;
-    node.ok_or_else(|| {
-        error!("/avf/untrusted node is missing in DT");
+    Ok(defer_rbp.is_some()) // Only the property's presence matters, not its value.
+}
+
+/// Brings up PCI and the swiotlb shared pool needed to reach the instance.img VirtIO-blk device.
+fn initialize_instance_img_device(fdt: &Fdt) -> Result<PciRoot, RebootReason> {
+    let pci_info = PciInfo::from_fdt(fdt).map_err(|e| {
+        error!("Failed to detect PCI from DT: {e}");
         RebootReason::InvalidFdt
-    })
+    })?;
+    let swiotlb_range = SwiotlbInfo::new_from_fdt(fdt)
+        .map_err(|e| {
+            error!("Failed to detect swiotlb from DT: {e}");
+            RebootReason::InvalidFdt
+        })?
+        .and_then(|info| info.fixed_range()); // `None` if no swiotlb info or no fixed range.
+
+    let pci_root = pci::initialize(pci_info).map_err(|e| {
+        error!("Failed to initialize PCI: {e}");
+        RebootReason::InternalError
+    })?;
+    init_shared_pool(swiotlb_range).map_err(|e| {
+        error!("Failed to initialize shared pool: {e}");
+        RebootReason::InternalError
+    })?;
+
+    Ok(pci_root)
 }