Merge "pvmfw: Return MemorySlices from main_wrapper" into main
diff --git a/guest/pvmfw/src/entry.rs b/guest/pvmfw/src/entry.rs
index 7c46515..862fb1d 100644
--- a/guest/pvmfw/src/entry.rs
+++ b/guest/pvmfw/src/entry.rs
@@ -15,10 +15,9 @@
 //! Low-level entry and exit points of pvmfw.
 
 use crate::config;
-use crate::memory;
+use crate::memory::MemorySlices;
 use core::arch::asm;
 use core::mem::size_of;
-use core::ops::Range;
 use core::slice;
 use log::error;
 use log::warn;
@@ -88,14 +87,14 @@
 
     let reboot_reason = match main_wrapper(fdt_address, payload_start, payload_size) {
         Err(r) => r,
-        Ok((next_stage, bcc)) => match next_stage {
-            NextStage::LinuxBootWithUart(ep) => jump_to_payload(fdt_address, ep, bcc),
+        Ok((next_stage, slices)) => match next_stage {
+            NextStage::LinuxBootWithUart(ep) => jump_to_payload(ep, &slices),
             NextStage::LinuxBoot(ep) => {
                 if let Err(e) = unshare_uart() {
                     error!("Failed to unmap UART: {e}");
                     RebootReason::InternalError
                 } else {
-                    jump_to_payload(fdt_address, ep, bcc)
+                    jump_to_payload(ep, &slices)
                 }
             }
         },
@@ -112,11 +111,11 @@
 ///
 /// Provide the abstractions necessary for start() to abort the pVM boot and for main() to run with
 /// the assumption that its environment has been properly configured.
-fn main_wrapper(
+fn main_wrapper<'a>(
     fdt: usize,
     payload: usize,
     payload_size: usize,
-) -> Result<(NextStage, Range<usize>), RebootReason> {
+) -> Result<(NextStage, MemorySlices<'a>), RebootReason> {
     // Limitations in this function:
     // - only access MMIO once (and while) it has been mapped and configured
     // - only perform logging once the logger has been initialized
@@ -136,7 +135,7 @@
 
     let config_entries = appended.get_entries();
 
-    let slices = memory::MemorySlices::new(fdt, payload, payload_size)?;
+    let mut slices = MemorySlices::new(fdt, payload, payload_size)?;
 
     // This wrapper allows main() to be blissfully ignorant of platform details.
     let (next_bcc, debuggable_payload) = crate::main(
@@ -148,6 +147,7 @@
         config_entries.vm_dtbo,
         config_entries.vm_ref_dt,
     )?;
+    slices.add_dice_chain(next_bcc);
     // Keep UART MMIO_GUARD-ed for debuggable payloads, to enable earlycon.
     let keep_uart = cfg!(debuggable_vms_improvements) && debuggable_payload;
 
@@ -162,7 +162,7 @@
 
     let next_stage = select_next_stage(slices.kernel, keep_uart);
 
-    Ok((next_stage, next_bcc))
+    Ok((next_stage, slices))
 }
 
 fn select_next_stage(kernel: &[u8], keep_uart: bool) -> NextStage {
@@ -173,7 +173,16 @@
     }
 }
 
-fn jump_to_payload(fdt_address: usize, payload_start: usize, bcc: Range<usize>) -> ! {
+fn jump_to_payload(entrypoint: usize, slices: &MemorySlices) -> ! {
+    let fdt_address = slices.fdt.as_ptr() as usize;
+    let bcc = slices
+        .dice_chain
+        .map(|slice| {
+            let r = slice.as_ptr_range();
+            (r.start as usize)..(r.end as usize)
+        })
+        .expect("Missing DICE chain");
+
     deactivate_dynamic_page_tables();
 
     const ASM_STP_ALIGN: usize = size_of::<u64>() * 2;
@@ -313,7 +322,7 @@
             eh_stack_end = in(reg) u64::try_from(eh_stack.end.0).unwrap(),
             dcache_line_size = in(reg) u64::try_from(min_dcache_line_size()).unwrap(),
             in("x0") u64::try_from(fdt_address).unwrap(),
-            in("x30") u64::try_from(payload_start).unwrap(),
+            in("x30") u64::try_from(entrypoint).unwrap(),
             options(noreturn),
         );
     };
diff --git a/guest/pvmfw/src/main.rs b/guest/pvmfw/src/main.rs
index d04db06..a28a039 100644
--- a/guest/pvmfw/src/main.rs
+++ b/guest/pvmfw/src/main.rs
@@ -40,7 +40,6 @@
 use alloc::borrow::Cow;
 use alloc::boxed::Box;
 use bssl_avf::Digester;
-use core::ops::Range;
 use cstr::cstr;
 use diced_open_dice::{bcc_handover_parse, DiceArtifacts, DiceContext, Hidden, VM_KEY_ALGORITHM};
 use libfdt::{Fdt, FdtNode};
@@ -54,7 +53,7 @@
 use vmbase::rand;
 use vmbase::virtio::pci;
 
-fn main(
+fn main<'a>(
     untrusted_fdt: &mut Fdt,
     signed_kernel: &[u8],
     ramdisk: Option<&[u8]>,
@@ -62,7 +61,7 @@
     mut debug_policy: Option<&[u8]>,
     vm_dtbo: Option<&mut [u8]>,
     vm_ref_dt: Option<&[u8]>,
-) -> Result<(Range<usize>, bool), RebootReason> {
+) -> Result<(&'a [u8], bool), RebootReason> {
     info!("pVM firmware");
     debug!("FDT: {:?}", untrusted_fdt.as_ptr());
     debug!("Signed kernel: {:?} ({:#x} bytes)", signed_kernel.as_ptr(), signed_kernel.len());
@@ -201,13 +200,7 @@
     })?;
 
     info!("Starting payload...");
-
-    let bcc_range = {
-        let r = next_bcc.as_ptr_range();
-        (r.start as usize)..(r.end as usize)
-    };
-
-    Ok((bcc_range, debuggable))
+    Ok((next_bcc, debuggable))
 }
 
 // Get the "salt" which is one of the input for DICE derivation.
diff --git a/guest/pvmfw/src/memory.rs b/guest/pvmfw/src/memory.rs
index d2f63b5..a663008 100644
--- a/guest/pvmfw/src/memory.rs
+++ b/guest/pvmfw/src/memory.rs
@@ -31,6 +31,7 @@
     pub fdt: &'a mut libfdt::Fdt,
     pub kernel: &'a [u8],
     pub ramdisk: Option<&'a [u8]>,
+    pub dice_chain: Option<&'a [u8]>,
 }
 
 impl<'a> MemorySlices<'a> {
@@ -111,6 +112,12 @@
             None
         };
 
-        Ok(Self { fdt: untrusted_fdt, kernel, ramdisk })
+        let dice_chain = None;
+
+        Ok(Self { fdt: untrusted_fdt, kernel, ramdisk, dice_chain })
+    }
+
+    pub fn add_dice_chain(&mut self, dice_chain: &'a [u8]) {
+        self.dice_chain = Some(dice_chain)
     }
 }