vmbase: Harden MMIO_GUARD with RAII

Introduce an MmioSharer that tracks shared MMIO regions separately from
the page tables and use its Drop implementation to ensure that all MMIO
is unshared when the MemoryTracker owning it is dropped.
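
The cleanup pattern boils down to the standalone sketch below (Sharer and
the println! are illustrative stand-ins, not vmbase or hypervisor APIs):

    use std::collections::BTreeSet;

    struct Sharer {
        frames: BTreeSet<usize>,
    }

    impl Sharer {
        fn unshare_all(&mut self) {
            while let Some(base) = self.frames.pop_first() {
                println!("unshare {base:#x}"); // stand-in for the MMIO guard call
            }
        }
    }

    impl Drop for Sharer {
        fn drop(&mut self) {
            // Cleanup runs even if the owner never calls unshare_all() itself.
            self.unshare_all();
        }
    }

    fn main() {
        let mut sharer = Sharer { frames: BTreeSet::new() };
        sharer.frames.insert(0x9000_0000);
    } // sharer dropped here: every tracked frame gets unshared automatically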

Test: m libpvmfw libvmbase_example librialto
Change-Id: Ib3c9acb83a19a9f1989fe77021796443ced5da73
diff --git a/vmbase/src/memory/error.rs b/vmbase/src/memory/error.rs
index 1af8f8c..901621d 100644
--- a/vmbase/src/memory/error.rs
+++ b/vmbase/src/memory/error.rs
@@ -49,6 +49,8 @@
     FlushRegionFailed,
     /// Failed to set PTE dirty state.
     SetPteDirtyFailed,
+    /// Attempted to MMIO_GUARD_MAP the same region more than once.
+    DuplicateMmioShare(usize),
 }
 
 impl fmt::Display for MemoryTrackerError {
@@ -68,6 +70,9 @@
             Self::InvalidPte => write!(f, "Page table entry is not valid"),
             Self::FlushRegionFailed => write!(f, "Failed to flush memory region"),
             Self::SetPteDirtyFailed => write!(f, "Failed to set PTE dirty state"),
+            Self::DuplicateMmioShare(addr) => {
+                write!(f, "Attempted to share the same MMIO region at {addr:#x} twice")
+            }
         }
     }
 }
diff --git a/vmbase/src/memory/shared.rs b/vmbase/src/memory/shared.rs
index 6e70e6a..d44d58a 100644
--- a/vmbase/src/memory/shared.rs
+++ b/vmbase/src/memory/shared.rs
@@ -17,16 +17,18 @@
 use super::dbm::{flush_dirty_range, mark_dirty_block, set_dbm_enabled};
 use super::error::MemoryTrackerError;
 use super::page_table::{PageTable, MMIO_LAZY_MAP_FLAG};
-use super::util::{page_4kb_of, virt_to_phys};
+use super::util::virt_to_phys;
 use crate::dsb;
 use crate::exceptions::HandleExceptionError;
 use crate::hyp::{self, get_mem_sharer, get_mmio_guard, MMIO_GUARD_GRANULE_SIZE};
+use crate::util::unchecked_align_down;
 use crate::util::RangeExt as _;
 use aarch64_paging::paging::{
-    Attributes, Descriptor, MemoryRegion as VaRange, VirtualAddress, BITS_PER_LEVEL, PAGE_SIZE,
+    Attributes, Descriptor, MemoryRegion as VaRange, VirtualAddress, PAGE_SIZE,
 };
 use alloc::alloc::{alloc_zeroed, dealloc, handle_alloc_error};
 use alloc::boxed::Box;
+use alloc::collections::BTreeSet;
 use alloc::vec::Vec;
 use buddy_system_allocator::{FrameAllocator, LockedFrameAllocator};
 use core::alloc::Layout;
@@ -78,6 +80,7 @@
     mmio_regions: ArrayVec<[MemoryRange; MemoryTracker::MMIO_CAPACITY]>,
     mmio_range: MemoryRange,
     payload_range: Option<MemoryRange>,
+    mmio_sharer: MmioSharer,
 }
 
 impl MemoryTracker {
@@ -114,6 +117,7 @@
             mmio_regions: ArrayVec::new(),
             mmio_range,
             payload_range: payload_range.map(|r| r.start.0..r.end.0),
+            mmio_sharer: MmioSharer::new().unwrap(),
         }
     }
 
@@ -251,13 +255,8 @@
 
     /// Unshares any MMIO region previously shared with the MMIO guard.
     pub fn unshare_all_mmio(&mut self) -> Result<()> {
-        if get_mmio_guard().is_some() {
-            for range in &self.mmio_regions {
-                self.page_table
-                    .walk_range(&get_va_range(range), &mmio_guard_unmap_page)
-                    .map_err(|_| MemoryTrackerError::FailedToUnmap)?;
-            }
-        }
+        self.mmio_sharer.unshare_all();
+
         Ok(())
     }
 
@@ -319,14 +318,8 @@
     /// Handles translation fault for blocks flagged for lazy MMIO mapping by enabling the page
     /// table entry and MMIO guard mapping the block. Breaks apart a block entry if required.
     fn handle_mmio_fault(&mut self, addr: VirtualAddress) -> Result<()> {
-        let page_start = VirtualAddress(page_4kb_of(addr.0));
-        assert_eq!(page_start.0 % MMIO_GUARD_GRANULE_SIZE, 0);
-        const_assert_eq!(MMIO_GUARD_GRANULE_SIZE, PAGE_SIZE); // For good measure.
-        let page_range: VaRange = (page_start..page_start + PAGE_SIZE).into();
-
-        let mmio_guard = get_mmio_guard().unwrap();
-        mmio_guard.map(page_start.0)?;
-        self.map_lazy_mmio_as_valid(&page_range)?;
+        let shared_range = self.mmio_sharer.share(addr)?;
+        self.map_lazy_mmio_as_valid(&shared_range)?;
 
         Ok(())
     }
@@ -386,6 +379,63 @@
     }
 }
 
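+/// Tracks the MMIO granules shared with the host so that they can all be unshared again later.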
+struct MmioSharer {
+    granule: usize,
+    frames: BTreeSet<usize>,
+}
+
+impl MmioSharer {
+    fn new() -> Result<Self> {
+        let granule = MMIO_GUARD_GRANULE_SIZE;
+        const_assert_eq!(MMIO_GUARD_GRANULE_SIZE, PAGE_SIZE); // For good measure.
+        let frames = BTreeSet::new();
+
+        // Allows safely calling util::unchecked_align_down().
+        assert!(granule.is_power_of_two());
+
+        Ok(Self { granule, frames })
+    }
+
+    /// Shares the granule-aligned MMIO region containing addr (which is not validated as MMIO).
+    fn share(&mut self, addr: VirtualAddress) -> Result<VaRange> {
+        // This can't use virt_to_phys() since 0x0 is a valid MMIO address and we are ID-mapped.
+        let phys = addr.0;
+        let base = unchecked_align_down(phys, self.granule);
+
+        if self.frames.contains(&base) {
+            return Err(MemoryTrackerError::DuplicateMmioShare(base));
+        }
+
+        if let Some(mmio_guard) = get_mmio_guard() {
+            mmio_guard.map(base)?;
+        }
+
+        let inserted = self.frames.insert(base);
+        assert!(inserted);
+
+        let base_va = VirtualAddress(base);
+        Ok((base_va..base_va + self.granule).into())
+    }
+
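+    /// Unshares all the frames recorded so far and clears the tracking set.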
+    fn unshare_all(&mut self) {
+        let Some(mmio_guard) = get_mmio_guard() else {
+            return self.frames.clear();
+        };
+
+        while let Some(base) = self.frames.pop_first() {
+            mmio_guard.unmap(base).unwrap();
+        }
+    }
+}
+
+impl Drop for MmioSharer {
+    fn drop(&mut self) {
+        self.unshare_all();
+    }
+}
+
 /// Allocates a memory range of at least the given size and alignment that is shared with the host.
 /// Returns a pointer to the buffer.
 pub(crate) fn alloc_shared(layout: Layout) -> hyp::Result<NonNull<u8>> {
@@ -489,41 +537,6 @@
     }
 }
 
-/// MMIO guard unmaps page
-fn mmio_guard_unmap_page(
-    va_range: &VaRange,
-    desc: &Descriptor,
-    level: usize,
-) -> result::Result<(), ()> {
-    let flags = desc.flags().expect("Unsupported PTE flags set");
-    // This function will be called on an address range that corresponds to a device. Only if a
-    // page has been accessed (written to or read from), will it contain the VALID flag and be MMIO
-    // guard mapped. Therefore, we can skip unmapping invalid pages, they were never MMIO guard
-    // mapped anyway.
-    if flags.contains(Attributes::VALID) {
-        assert!(
-            flags.contains(MMIO_LAZY_MAP_FLAG),
-            "Attempting MMIO guard unmap for non-device pages"
-        );
-        const MMIO_GUARD_GRANULE_SHIFT: u32 = MMIO_GUARD_GRANULE_SIZE.ilog2() - PAGE_SIZE.ilog2();
-        const MMIO_GUARD_GRANULE_LEVEL: usize =
-            3 - (MMIO_GUARD_GRANULE_SHIFT as usize / BITS_PER_LEVEL);
-        assert_eq!(
-            level, MMIO_GUARD_GRANULE_LEVEL,
-            "Failed to break down block mapping before MMIO guard mapping"
-        );
-        let page_base = va_range.start().0;
-        assert_eq!(page_base % MMIO_GUARD_GRANULE_SIZE, 0);
-        // Since mmio_guard_map takes IPAs, if pvmfw moves non-ID address mapping, page_base
-        // should be converted to IPA. However, since 0x0 is a valid MMIO address, we don't use
-        // virt_to_phys here, and just pass page_base instead.
-        get_mmio_guard().unwrap().unmap(page_base).map_err(|e| {
-            error!("Error MMIO guard unmapping: {e}");
-        })?;
-    }
-    Ok(())
-}
-
 /// Handles a translation fault with the given fault address register (FAR).
 #[inline]
 pub fn handle_translation_fault(far: VirtualAddress) -> result::Result<(), HandleExceptionError> {