pvmfw: Optimize MMIO guard map

Optimize MMIO guard mapping of device pages by initially mapping them
as invalid so that the first access produces a translation fault. When
a translation fault happens, MMIO guard map the faulting page and only
then map it as valid.
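
For context, these translation faults are taken by pvmfw's synchronous
exception handler and forwarded to MemoryTracker::handle_mmio_fault()
below. A rough sketch of that dispatch (the handler shape and the
MEMORY global are illustrative assumptions, not code from this patch):

    fn handle_sync_exception(esr: u64, far: usize) {
        const ESR_EC_SHIFT: u64 = 26;
        const EC_DABT_CURRENT_EL: u64 = 0b100101; // Data abort at the current EL.
        let ec = (esr >> ESR_EC_SHIFT) & 0x3f;
        // DFSC bits [5:0] = 0b0001xx encode a translation fault (levels 0-3).
        let is_translation_fault = (esr & 0x3f) & !0b11 == 0b000100;
        if ec == EC_DABT_CURRENT_EL && is_translation_fault {
            // Guard-map the faulting page and mark its PTE valid, then return
            // so that the faulting access is retried.
            MEMORY.lock().as_mut().unwrap().handle_mmio_fault(far).unwrap();
        } else {
            panic!("Unhandled sync exception: ESR={esr:#x}, FAR={far:#x}");
        }
    }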

Bug: 245267332
Test: atest MicrodroidTestApp

Change-Id: I81128d7efec1249a9a7da988ec098b29936338ef
diff --git a/pvmfw/src/memory.rs b/pvmfw/src/memory.rs
index 0656321..1a2b4b7 100644
--- a/pvmfw/src/memory.rs
+++ b/pvmfw/src/memory.rs
@@ -16,8 +16,9 @@
 
 #![deny(unsafe_op_in_unsafe_fn)]
 
-use crate::helpers::{self, page_4kb_of, RangeExt, SIZE_4KB, SIZE_4MB};
+use crate::helpers::{self, page_4kb_of, RangeExt, PVMFW_PAGE_SIZE, SIZE_4MB};
 use crate::mmu;
+use aarch64_paging::paging::{Attributes, Descriptor, MemoryRegion as VaRange};
 use alloc::alloc::alloc_zeroed;
 use alloc::alloc::dealloc;
 use alloc::alloc::handle_alloc_error;
@@ -110,12 +111,16 @@
     Overlaps,
     /// Region couldn't be mapped.
     FailedToMap,
+    /// Region couldn't be unmapped.
+    FailedToUnmap,
     /// Error from the interaction with the hypervisor.
     Hypervisor(hyp::Error),
     /// Failure to set `SHARED_MEMORY`.
     SharedMemorySetFailure,
     /// Failure to set `SHARED_POOL`.
     SharedPoolSetFailure,
+    /// Invalid page table entry.
+    InvalidPte,
 }
 
 impl fmt::Display for MemoryTrackerError {
@@ -128,9 +133,11 @@
             Self::OutOfRange => write!(f, "Region is out of the tracked memory address space"),
             Self::Overlaps => write!(f, "New region overlaps with tracked regions"),
             Self::FailedToMap => write!(f, "Failed to map the new region"),
+            Self::FailedToUnmap => write!(f, "Failed to unmap the region"),
             Self::Hypervisor(e) => e.fmt(f),
             Self::SharedMemorySetFailure => write!(f, "Failed to set SHARED_MEMORY"),
             Self::SharedPoolSetFailure => write!(f, "Failed to set SHARED_POOL"),
+            Self::InvalidPte => write!(f, "Page table entry is not valid"),
         }
     }
 }
@@ -279,15 +286,11 @@
             return Err(MemoryTrackerError::Full);
         }
 
-        self.page_table.map_device(&range).map_err(|e| {
+        self.page_table.map_device_lazy(&range).map_err(|e| {
             error!("Error during MMIO device mapping: {e}");
             MemoryTrackerError::FailedToMap
         })?;
 
-        for page_base in page_iterator(&range) {
-            get_hypervisor().mmio_guard_map(page_base)?;
-        }
-
         if self.mmio_regions.try_push(range).is_some() {
             return Err(MemoryTrackerError::Full);
         }
@@ -322,13 +325,12 @@
     /// Unmaps all tracked MMIO regions from the MMIO guard.
     ///
     /// Note that they are not unmapped from the page table.
-    pub fn mmio_unmap_all(&self) -> Result<()> {
-        for region in &self.mmio_regions {
-            for page_base in page_iterator(region) {
-                get_hypervisor().mmio_guard_unmap(page_base)?;
-            }
+    pub fn mmio_unmap_all(&mut self) -> Result<()> {
+        for range in &self.mmio_regions {
+            self.page_table
+                .modify_range(range, &mmio_guard_unmap_page)
+                .map_err(|_| MemoryTrackerError::FailedToUnmap)?;
         }
-
         Ok(())
     }
 
@@ -372,6 +374,19 @@
     pub fn unshare_all_memory(&mut self) {
         drop(SHARED_MEMORY.lock().take());
     }
+
+    /// Handles a translation fault on a page flagged for lazy MMIO mapping by MMIO guard mapping
+    /// it and then making its page table entry valid. Breaks a block entry apart if required.
+    pub fn handle_mmio_fault(&mut self, addr: usize) -> Result<()> {
+        let page_range = page_4kb_of(addr)..page_4kb_of(addr) + PVMFW_PAGE_SIZE;
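+        // Check that the page was lazily mapped as a device and that its PTE is still invalid.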
+        self.page_table
+            .modify_range(&page_range, &verify_lazy_mapped_block)
+            .map_err(|_| MemoryTrackerError::InvalidPte)?;
+        get_hypervisor().mmio_guard_map(page_range.start)?;
+        // Maps a single device page, breaking up block mappings if necessary.
+        self.page_table.map_device(&page_range).map_err(|_| MemoryTrackerError::FailedToMap)
+    }
 }
 
 impl Drop for MemoryTracker {
@@ -429,11 +443,6 @@
     Ok(())
 }
 
-/// Returns an iterator which yields the base address of each 4 KiB page within the given range.
-fn page_iterator(range: &MemoryRange) -> impl Iterator<Item = usize> {
-    (page_4kb_of(range.start)..range.end).step_by(SIZE_4KB)
-}
-
 /// Returns the intermediate physical address corresponding to the given virtual address.
 ///
 /// As we use identity mapping for everything, this is just a cast, but it's useful to use it to be
@@ -449,3 +458,66 @@
 pub fn phys_to_virt(paddr: usize) -> NonNull<u8> {
     NonNull::new(paddr as _).unwrap()
 }
+
+/// Checks whether a PTE at the given level is a page or block descriptor.
+#[inline]
+fn is_leaf_pte(flags: &Attributes, level: usize) -> bool {
+    const LEAF_PTE_LEVEL: usize = 3;
+    if flags.contains(Attributes::TABLE_OR_PAGE) {
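+        // Bit set: a table descriptor at levels 0..=2, or a page descriptor (leaf) at level 3.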
+        level == LEAF_PTE_LEVEL
+    } else {
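+        // Bit clear: a block descriptor, a leaf that can only occur below level 3.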
+        level < LEAF_PTE_LEVEL
+    }
+}
+
+/// Checks that a leaf PTE is flagged for lazy MMIO mapping and not yet marked valid.
+fn verify_lazy_mapped_block(
+    _range: &VaRange,
+    desc: &mut Descriptor,
+    level: usize,
+) -> result::Result<(), ()> {
+    let flags = desc.flags().expect("Unsupported PTE flags set");
+    if !is_leaf_pte(&flags, level) {
+        return Ok(()); // Skip table PTEs as they aren't tagged with MMIO_LAZY_MAP_FLAG.
+    }
+    if flags.contains(mmu::MMIO_LAZY_MAP_FLAG) && !flags.contains(Attributes::VALID) {
+        Ok(())
+    } else {
+        Err(())
+    }
+}
+
+/// MMIO guard unmaps a page if its PTE is valid, i.e. it was accessed and therefore guard mapped.
+fn mmio_guard_unmap_page(
+    va_range: &VaRange,
+    desc: &mut Descriptor,
+    level: usize,
+) -> result::Result<(), ()> {
+    let flags = desc.flags().expect("Unsupported PTE flags set");
+    // This function is called on an address range that corresponds to a device. Only pages that
+    // have been accessed (read from or written to) carry the VALID flag and are MMIO guard
+    // mapped. Invalid pages can therefore be skipped here: they were never MMIO guard mapped in
+    // the first place.
+    if is_leaf_pte(&flags, level) && flags.contains(Attributes::VALID) {
+        assert!(
+            flags.contains(mmu::MMIO_LAZY_MAP_FLAG),
+            "Attempting MMIO guard unmap for non-device pages"
+        );
+        assert_eq!(
+            va_range.len(),
+            PVMFW_PAGE_SIZE,
+            "Failed to break down block mapping before MMIO guard mapping"
+        );
+        let page_base = va_range.start().0;
+        assert_eq!(page_base % PVMFW_PAGE_SIZE, 0);
+        // Since mmio_guard_unmap takes an IPA, if pvmfw ever moves to non-identity address
+        // mapping, page_base would need to be converted to an IPA. However, since 0x0 is a
+        // valid MMIO address, we don't use virt_to_phys here and just pass page_base instead.
+        get_hypervisor().mmio_guard_unmap(page_base).map_err(|e| {
+            error!("Error MMIO guard unmapping: {e}");
+        })?;
+    }
+    Ok(())
+}
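
The patch references mmu::MMIO_LAZY_MAP_FLAG and PageTable::map_device_lazy()
without showing their pvmfw/src/mmu.rs side. A sketch of what those
definitions could look like, assuming aarch64_paging's software-defined
attribute bits and a map_range() helper on pvmfw's PageTable wrapper (the
DEVICE_LAZY/DEVICE constants are illustrative names):

    use aarch64_paging::paging::Attributes;

    /// Software-defined bit tagging device PTEs whose MMIO guard mapping is lazy.
    pub const MMIO_LAZY_MAP_FLAG: Attributes = Attributes::SWFLAG_0;

    // Lazy device mappings omit VALID, so the first access takes a translation
    // fault; map_device() later installs the same attributes plus VALID.
    const DEVICE_LAZY: Attributes =
        MMIO_LAZY_MAP_FLAG.union(Attributes::DEVICE_NGNRE).union(Attributes::EXECUTE_NEVER);
    const DEVICE: Attributes = DEVICE_LAZY.union(Attributes::VALID);

    impl PageTable {
        pub fn map_device_lazy(&mut self, range: &MemoryRange) -> Result<(), MapError> {
            self.map_range(range, DEVICE_LAZY)
        }

        pub fn map_device(&mut self, range: &MemoryRange) -> Result<(), MapError> {
            self.map_range(range, DEVICE)
        }
    }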