Merge "[rkpvm] Decide whether VM is protected using FDT property" into main
diff --git a/guest/rialto/src/fdt.rs b/guest/rialto/src/fdt.rs
index b220f41..e97a262 100644
--- a/guest/rialto/src/fdt.rs
+++ b/guest/rialto/src/fdt.rs
@@ -29,3 +29,10 @@
let node = fdt.node(cstr!("/avf"))?.ok_or(FdtError::NotFound)?;
node.getprop(cstr!("vendor_hashtree_descriptor_root_digest"))
}
+
+/// Returns whether the node from `fdt.chosen()` has an `avf,strict-boot` property.
+///
+/// A missing node is reported as `Ok(false)`; `chosen()`/`getprop()` errors are propagated.
+pub(crate) fn read_is_strict_boot(fdt: &Fdt) -> libfdt::Result<bool> {
+    fdt.chosen()?.map_or(Ok(false), |node| Ok(node.getprop(cstr!("avf,strict-boot"))?.is_some()))
+}
diff --git a/guest/rialto/src/main.rs b/guest/rialto/src/main.rs
index 701a287..930f4e8 100644
--- a/guest/rialto/src/main.rs
+++ b/guest/rialto/src/main.rs
@@ -26,7 +26,7 @@
use crate::communication::VsockStream;
use crate::error::{Error, Result};
-use crate::fdt::{read_dice_range_from, read_vendor_hashtree_root_digest};
+use crate::fdt::{read_dice_range_from, read_is_strict_boot, read_vendor_hashtree_root_digest};
use alloc::boxed::Box;
use bssl_sys::CRYPTO_library_init;
use ciborium_io::Write;
@@ -58,16 +58,15 @@
},
};
-fn host_addr() -> VsockAddr {
- VsockAddr { cid: VMADDR_CID_HOST, port: vm_type().port() }
+fn host_addr(fdt: &libfdt::Fdt) -> Result<VsockAddr> {
+ Ok(VsockAddr { cid: VMADDR_CID_HOST, port: vm_type(fdt)?.port() })
}
-fn vm_type() -> VmType {
- // Use MMIO support to determine whether the VM is protected.
- if get_mmio_guard().is_some() {
- VmType::ProtectedVm
+fn vm_type(fdt: &libfdt::Fdt) -> Result<VmType> {
+ if read_is_strict_boot(fdt)? {
+ Ok(VmType::ProtectedVm)
} else {
- VmType::NonProtectedVm
+ Ok(VmType::NonProtectedVm)
}
}
@@ -143,7 +142,7 @@
unsafe {
CRYPTO_library_init();
}
- let bcc_handover: Box<dyn DiceArtifacts> = match vm_type() {
+ let bcc_handover: Box<dyn DiceArtifacts> = match vm_type(fdt)? {
VmType::ProtectedVm => {
let dice_range = read_dice_range_from(fdt)?;
info!("DICE range: {dice_range:#x?}");
@@ -178,7 +177,7 @@
let request_context =
RequestContext { dice_artifacts: bcc_handover.as_ref(), vendor_hashtree_root_digest };
- let mut vsock_stream = VsockStream::new(socket_device, host_addr())?;
+ let mut vsock_stream = VsockStream::new(socket_device, host_addr(fdt)?)?;
while let ServiceVmRequest::Process(req) = vsock_stream.read_request()? {
info!("Received request: {}", req.name());
let response = process_request(req, &request_context);