Merge "pvmfw: Introduce NextStage" into main
diff --git a/guest/pvmfw/src/entry.rs b/guest/pvmfw/src/entry.rs
index e987dfd..7c46515 100644
--- a/guest/pvmfw/src/entry.rs
+++ b/guest/pvmfw/src/entry.rs
@@ -74,22 +74,36 @@
 configure_heap!(SIZE_128KB);
 limit_stack_size!(SIZE_4KB * 12);
 
+#[derive(Debug)]
+enum NextStage {
+    LinuxBoot(usize), // payload entry address; start() unshares the UART before jumping to it
+    LinuxBootWithUart(usize), // payload entry address; start() jumps without unsharing the UART
+}
+
 /// Entry point for pVM firmware.
 pub fn start(fdt_address: u64, payload_start: u64, payload_size: u64, _arg3: u64) {
-    // Limitations in this function:
-    // - can't access non-pvmfw memory (only statically-mapped memory)
-    // - can't access MMIO (except the console, already configured by vmbase)
+    let fdt_address = fdt_address.try_into().unwrap(); // u64 -> usize; panics if it can't fit
+    let payload_start = payload_start.try_into().unwrap();
+    let payload_size = payload_size.try_into().unwrap();
 
-    match main_wrapper(fdt_address as usize, payload_start as usize, payload_size as usize) {
-        Ok((entry, bcc, keep_uart)) => {
-            jump_to_payload(fdt_address, entry.try_into().unwrap(), bcc, keep_uart)
-        }
-        Err(e) => {
-            const REBOOT_REASON_CONSOLE: usize = 1;
-            console_writeln!(REBOOT_REASON_CONSOLE, "{}", e.as_avf_reboot_string());
-            reboot()
-        }
-    }
+    let reboot_reason = match main_wrapper(fdt_address, payload_start, payload_size) {
+        Err(r) => r, // main_wrapper() failed: report its reason below
+        Ok((next_stage, bcc)) => match next_stage {
+            NextStage::LinuxBootWithUart(ep) => jump_to_payload(fdt_address, ep, bcc),
+            NextStage::LinuxBoot(ep) => {
+                if let Err(e) = unshare_uart() {
+                    error!("Failed to unmap UART: {e}");
+                    RebootReason::InternalError
+                } else {
+                    jump_to_payload(fdt_address, ep, bcc) // never returns (`-> !`)
+                }
+            }
+        },
+    }; // only failure paths yield a value here; jump_to_payload() never returns
+
+    const REBOOT_REASON_CONSOLE: usize = 1;
+    console_writeln!(REBOOT_REASON_CONSOLE, "{}", reboot_reason.as_avf_reboot_string());
+    reboot()
 
     // if we reach this point and return, vmbase::entry::rust_entry() will call power::shutdown().
 }
@@ -102,7 +116,7 @@
     fdt: usize,
     payload: usize,
     payload_size: usize,
-) -> Result<(usize, Range<usize>, bool), RebootReason> {
+) -> Result<(NextStage, Range<usize>), RebootReason> {
     // Limitations in this function:
     // - only access MMIO once (and while) it has been mapped and configured
     // - only perform logging once the logger has been initialized
@@ -146,14 +160,20 @@
     })?;
     unshare_all_memory();
 
-    Ok((slices.kernel.as_ptr() as usize, next_bcc, keep_uart))
+    let next_stage = select_next_stage(slices.kernel, keep_uart);
+
+    Ok((next_stage, next_bcc))
 }
 
-fn jump_to_payload(fdt_address: u64, payload_start: u64, bcc: Range<usize>, keep_uart: bool) -> ! {
-    if !keep_uart {
-        unshare_uart().unwrap();
+fn select_next_stage(kernel: &[u8], keep_uart: bool) -> NextStage {
+    let entry = kernel.as_ptr() as usize; // address of the first byte of the kernel slice
+    match keep_uart {
+        true => NextStage::LinuxBootWithUart(entry),
+        false => NextStage::LinuxBoot(entry),
     }
+}
 
+fn jump_to_payload(fdt_address: usize, payload_start: usize, bcc: Range<usize>) -> ! {
     deactivate_dynamic_page_tables();
 
     const ASM_STP_ALIGN: usize = size_of::<u64>() * 2;
@@ -292,8 +312,8 @@
             eh_stack = in(reg) u64::try_from(eh_stack.start.0).unwrap(),
             eh_stack_end = in(reg) u64::try_from(eh_stack.end.0).unwrap(),
             dcache_line_size = in(reg) u64::try_from(min_dcache_line_size()).unwrap(),
-            in("x0") fdt_address,
-            in("x30") payload_start,
+            in("x0") u64::try_from(fdt_address).unwrap(),
+            in("x30") u64::try_from(payload_start).unwrap(),
             options(noreturn),
         );
     };