Merge "[rkpvm] Decide whether VM is protected using FDT property" into main
diff --git a/guest/rialto/src/fdt.rs b/guest/rialto/src/fdt.rs
index b220f41..e97a262 100644
--- a/guest/rialto/src/fdt.rs
+++ b/guest/rialto/src/fdt.rs
@@ -29,3 +29,10 @@
     let node = fdt.node(cstr!("/avf"))?.ok_or(FdtError::NotFound)?;
     node.getprop(cstr!("vendor_hashtree_descriptor_root_digest"))
 }
+
+pub(crate) fn read_is_strict_boot(fdt: &Fdt) -> libfdt::Result<bool> { // is avf,strict-boot set?
+    match fdt.chosen()? { // read from the chosen node; errors propagate
+        Some(node) => Ok(node.getprop(cstr!("avf,strict-boot"))?.is_some()), // value ignored
+        None => Ok(false), // no chosen node => false
+    }
+}
diff --git a/guest/rialto/src/main.rs b/guest/rialto/src/main.rs
index 701a287..930f4e8 100644
--- a/guest/rialto/src/main.rs
+++ b/guest/rialto/src/main.rs
@@ -26,7 +26,7 @@
 
 use crate::communication::VsockStream;
 use crate::error::{Error, Result};
-use crate::fdt::{read_dice_range_from, read_vendor_hashtree_root_digest};
+use crate::fdt::{read_dice_range_from, read_is_strict_boot, read_vendor_hashtree_root_digest};
 use alloc::boxed::Box;
 use bssl_sys::CRYPTO_library_init;
 use ciborium_io::Write;
@@ -58,16 +58,15 @@
     },
 };
 
-fn host_addr() -> VsockAddr {
-    VsockAddr { cid: VMADDR_CID_HOST, port: vm_type().port() }
+fn host_addr(fdt: &libfdt::Fdt) -> Result<VsockAddr> { // host CID; port from vm_type(fdt)
+    Ok(VsockAddr { cid: VMADDR_CID_HOST, port: vm_type(fdt)?.port() }) // vm_type errors propagate
 }
 
-fn vm_type() -> VmType {
-    // Use MMIO support to determine whether the VM is protected.
-    if get_mmio_guard().is_some() {
-        VmType::ProtectedVm
+fn vm_type(fdt: &libfdt::Fdt) -> Result<VmType> { // ProtectedVm iff avf,strict-boot is set
+    if read_is_strict_boot(fdt)? { // FDT read errors are returned, not defaulted
+        Ok(VmType::ProtectedVm)
     } else {
-        VmType::NonProtectedVm
+        Ok(VmType::NonProtectedVm) // also when there is no chosen node
     }
 }
 
@@ -143,7 +142,7 @@
     unsafe {
         CRYPTO_library_init();
     }
-    let bcc_handover: Box<dyn DiceArtifacts> = match vm_type() {
+    let bcc_handover: Box<dyn DiceArtifacts> = match vm_type(fdt)? {
         VmType::ProtectedVm => {
             let dice_range = read_dice_range_from(fdt)?;
             info!("DICE range: {dice_range:#x?}");
@@ -178,7 +177,7 @@
     let request_context =
         RequestContext { dice_artifacts: bcc_handover.as_ref(), vendor_hashtree_root_digest };
 
-    let mut vsock_stream = VsockStream::new(socket_device, host_addr())?;
+    let mut vsock_stream = VsockStream::new(socket_device, host_addr(fdt)?)?;
     while let ServiceVmRequest::Process(req) = vsock_stream.read_request()? {
         info!("Received request: {}", req.name());
         let response = process_request(req, &request_context);