Merge "pvmfw: Improve memory sharing with the host"
diff --git a/pvmfw/src/entry.rs b/pvmfw/src/entry.rs
index 999baee..8e3e47b 100644
--- a/pvmfw/src/entry.rs
+++ b/pvmfw/src/entry.rs
@@ -112,13 +112,18 @@
             RebootReason::InvalidFdt
         })?;
 
-        if !get_hypervisor().has_cap(HypervisorCap::DYNAMIC_MEM_SHARE) {
+        if get_hypervisor().has_cap(HypervisorCap::DYNAMIC_MEM_SHARE) {
+            memory.init_dynamic_shared_pool().map_err(|e| {
+                error!("Failed to initialize dynamically shared pool: {e}");
+                RebootReason::InternalError
+            })?;
+        } else {
             let range = info.swiotlb_info.fixed_range().ok_or_else(|| {
                 error!("Pre-shared pool range not specified in swiotlb node");
                 RebootReason::InvalidFdt
             })?;
 
-            memory.init_shared_pool(range).map_err(|e| {
+            memory.init_static_shared_pool(range).map_err(|e| {
                 error!("Failed to initialize pre-shared pool {e}");
                 RebootReason::InvalidFdt
             })?;
@@ -261,6 +266,8 @@
         error!("Failed to unshare MMIO ranges: {e}");
         RebootReason::InternalError
     })?;
+    // Call unshare_all_memory here (instead of relying on the dtor) while UART is still mapped.
+    MEMORY.lock().as_mut().unwrap().unshare_all_memory();
     get_hypervisor().mmio_guard_unmap(console::BASE_ADDRESS).map_err(|e| {
         error!("Failed to unshare the UART: {e}");
         RebootReason::InternalError
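
For reviewers skimming the hunk above: the new behaviour boils down to "use the hypervisor's MEM_SHARE on demand when the DYNAMIC_MEM_SHARE capability is present, otherwise require a pre-shared swiotlb range from the device tree". A self-contained mock of that decision follows; Pool, pick_pool() and the bool/Option parameters are made-up stand-ins for HypervisorCap::DYNAMIC_MEM_SHARE and swiotlb_info.fixed_range(), not pvmfw types.

    // Illustrative mock only: names below are hypothetical, not pvmfw APIs.
    use std::ops::Range;

    enum Pool {
        Dynamic,              // grown on demand through the hypervisor's MEM_SHARE
        Static(Range<usize>), // pre-shared swiotlb range taken from the device tree
    }

    fn pick_pool(has_mem_share: bool, swiotlb_range: Option<Range<usize>>) -> Result<Pool, &'static str> {
        if has_mem_share {
            Ok(Pool::Dynamic)
        } else {
            // Without MEM_SHARE, the host must have pre-shared a fixed range for us.
            swiotlb_range.map(Pool::Static).ok_or("swiotlb node has no fixed range")
        }
    }

    fn main() {
        assert!(matches!(pick_pool(true, None), Ok(Pool::Dynamic)));
        assert!(pick_pool(false, None).is_err()); // maps to RebootReason::InvalidFdt above
        if let Ok(Pool::Static(range)) = pick_pool(false, Some(0x8000..0x10000)) {
            assert_eq!(range.len(), 0x8000);
        } else {
            panic!("expected the pre-shared pool");
        }
    }
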
diff --git a/pvmfw/src/helpers.rs b/pvmfw/src/helpers.rs
index a6f0dd5..c230784 100644
--- a/pvmfw/src/helpers.rs
+++ b/pvmfw/src/helpers.rs
@@ -98,6 +98,7 @@
 /// Aligns the given address to the given alignment, if it is a power of two.
 ///
 /// Returns `None` if the alignment isn't a power of two.
+#[allow(dead_code)] // Currently unused but might be needed again.
 pub const fn align_down(addr: usize, alignment: usize) -> Option<usize> {
     if !alignment.is_power_of_two() {
         None
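
This hunk keeps align_down() around (its last in-tree users, share_range() and unshare_range() in memory.rs, are removed below) while silencing the dead-code warning. For reference, the documented contract sketched as a runnable stand-alone; the helper's real body is not shown in the diff, so the mask-based else branch below is an assumption.

    // Stand-alone sketch of align_down()'s documented behaviour; only the
    // power-of-two check is visible in the hunk above, the mask is assumed.
    const fn align_down(addr: usize, alignment: usize) -> Option<usize> {
        if !alignment.is_power_of_two() {
            None
        } else {
            Some(addr & !(alignment - 1))
        }
    }

    fn main() {
        assert_eq!(align_down(0x1234, 0x1000), Some(0x1000)); // rounded down to a 4 KiB boundary
        assert_eq!(align_down(0x3000, 0x1000), Some(0x3000)); // already-aligned addresses are unchanged
        assert_eq!(align_down(0x1234, 0x300), None);          // 0x300 is not a power of two
    }
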
diff --git a/pvmfw/src/memory.rs b/pvmfw/src/memory.rs
index 7e8423a..cfd6b0a 100644
--- a/pvmfw/src/memory.rs
+++ b/pvmfw/src/memory.rs
@@ -16,14 +16,15 @@
 
 #![deny(unsafe_op_in_unsafe_fn)]
 
-use crate::helpers::{self, align_down, page_4kb_of, RangeExt, SIZE_4KB, SIZE_4MB};
+use crate::helpers::{self, page_4kb_of, RangeExt, SIZE_4KB, SIZE_4MB};
 use crate::mmu;
 use alloc::alloc::alloc_zeroed;
 use alloc::alloc::dealloc;
 use alloc::alloc::handle_alloc_error;
 use alloc::boxed::Box;
+use alloc::vec::Vec;
+use buddy_system_allocator::Heap;
 use buddy_system_allocator::LockedHeap;
-use core::alloc::GlobalAlloc as _;
 use core::alloc::Layout;
 use core::cmp::max;
 use core::cmp::min;
@@ -112,6 +113,8 @@
     FailedToMap,
     /// Error from the interaction with the hypervisor.
     Hypervisor(hyp::Error),
+    /// Failure to set `SHARED_MEMORY`.
+    SharedMemorySetFailure,
     /// Failure to set `SHARED_POOL`.
     SharedPoolSetFailure,
 }
@@ -127,6 +130,7 @@
             Self::Overlaps => write!(f, "New region overlaps with tracked regions"),
             Self::FailedToMap => write!(f, "Failed to map the new region"),
             Self::Hypervisor(e) => e.fmt(f),
+            Self::SharedMemorySetFailure => write!(f, "Failed to set SHARED_MEMORY"),
             Self::SharedPoolSetFailure => write!(f, "Failed to set SHARED_POOL"),
         }
     }
@@ -141,6 +145,62 @@
 type Result<T> = result::Result<T, MemoryTrackerError>;
 
 static SHARED_POOL: OnceBox<LockedHeap<32>> = OnceBox::new();
+static SHARED_MEMORY: SpinMutex<Option<MemorySharer>> = SpinMutex::new(None);
+
+/// Allocates memory on the heap and shares it with the host.
+///
+/// Unshares all pages when dropped.
+pub struct MemorySharer {
+    granule: usize,
+    shared_regions: Vec<(usize, Layout)>,
+}
+
+impl MemorySharer {
+    const INIT_CAP: usize = 10;
+
+    pub fn new(granule: usize) -> Self {
+        assert!(granule.is_power_of_two());
+        Self { granule, shared_regions: Vec::with_capacity(Self::INIT_CAP) }
+    }
+
+    /// Get from the global allocator a granule-aligned region that suits `hint` and share it.
+    pub fn refill(&mut self, pool: &mut Heap<32>, hint: Layout) {
+        let layout = hint.align_to(self.granule).unwrap().pad_to_align();
+        assert_ne!(layout.size(), 0);
+        // SAFETY - layout has non-zero size.
+        let Some(shared) = NonNull::new(unsafe { alloc_zeroed(layout) }) else {
+            handle_alloc_error(layout);
+        };
+
+        let base = shared.as_ptr() as usize;
+        let end = base.checked_add(layout.size()).unwrap();
+        trace!("Sharing memory region {:#x?}", base..end);
+        for vaddr in (base..end).step_by(self.granule) {
+            let vaddr = NonNull::new(vaddr as *mut _).unwrap();
+            get_hypervisor().mem_share(virt_to_phys(vaddr).try_into().unwrap()).unwrap();
+        }
+        self.shared_regions.push((base, layout));
+
+        // SAFETY - The underlying memory range is owned by self and reserved for this pool.
+        unsafe { pool.add_to_heap(base, end) };
+    }
+}
+
+impl Drop for MemorySharer {
+    fn drop(&mut self) {
+        while let Some((base, layout)) = self.shared_regions.pop() {
+            let end = base.checked_add(layout.size()).unwrap();
+            trace!("Unsharing memory region {:#x?}", base..end);
+            for vaddr in (base..end).step_by(self.granule) {
+                let vaddr = NonNull::new(vaddr as *mut _).unwrap();
+                get_hypervisor().mem_unshare(virt_to_phys(vaddr).try_into().unwrap()).unwrap();
+            }
+
+            // SAFETY - The region was obtained from alloc_zeroed() with the recorded layout.
+            unsafe { dealloc(base as *mut _, layout) };
+        }
+    }
+}
 
 impl MemoryTracker {
     const CAPACITY: usize = 5;
@@ -274,14 +334,29 @@
         Ok(())
     }
 
-    /// Initialize a separate heap for shared memory allocations.
+    /// Initialize the shared heap to dynamically share memory from the global allocator.
+    pub fn init_dynamic_shared_pool(&mut self) -> Result<()> {
+        let granule = get_hypervisor().memory_protection_granule()?;
+        let previous = SHARED_MEMORY.lock().replace(MemorySharer::new(granule));
+        if previous.is_some() {
+            return Err(MemoryTrackerError::SharedMemorySetFailure);
+        }
+
+        SHARED_POOL
+            .set(Box::new(LockedHeap::empty()))
+            .map_err(|_| MemoryTrackerError::SharedPoolSetFailure)?;
+
+        Ok(())
+    }
+
+    /// Initialize the shared heap from a static region of memory.
     ///
     /// Some hypervisors such as Gunyah do not support a MemShare API for guest
     /// to share its memory with host. Instead they allow host to designate part
     /// of guest memory as "shared" ahead of guest starting its execution. The
     /// shared memory region is indicated in swiotlb node. On such platforms use
     /// a separate heap to allocate buffers that can be shared with host.
-    pub fn init_shared_pool(&mut self, range: Range<usize>) -> Result<()> {
+    pub fn init_static_shared_pool(&mut self, range: Range<usize>) -> Result<()> {
         let size = NonZeroUsize::new(range.len()).unwrap();
         let range = self.alloc_mut(range.start, size)?;
         let shared_pool = LockedHeap::<32>::new();
@@ -298,6 +373,11 @@
 
         Ok(())
     }
+
+    /// Unshares any memory that may have been shared.
+    pub fn unshare_all_memory(&mut self) {
+        drop(SHARED_MEMORY.lock().take());
+    }
 }
 
 impl Drop for MemoryTracker {
@@ -311,73 +391,37 @@
                 MemoryType::ReadOnly => {}
             }
         }
+        self.unshare_all_memory()
     }
 }
 
-/// Gives the KVM host read, write and execute permissions on the given memory range. If the range
-/// is not aligned with the memory protection granule then it will be extended on either end to
-/// align.
-fn share_range(range: &MemoryRange, granule: usize) -> hyp::Result<()> {
-    trace!("Sharing memory region {range:#x?}");
-    for base in (align_down(range.start, granule)
-        .expect("Memory protection granule was not a power of two")..range.end)
-        .step_by(granule)
-    {
-        get_hypervisor().mem_share(base as u64)?;
-    }
-    Ok(())
-}
-
-/// Removes permission from the KVM host to access the given memory range which was previously
-/// shared. If the range is not aligned with the memory protection granule then it will be extended
-/// on either end to align.
-fn unshare_range(range: &MemoryRange, granule: usize) -> hyp::Result<()> {
-    trace!("Unsharing memory region {range:#x?}");
-    for base in (align_down(range.start, granule)
-        .expect("Memory protection granule was not a power of two")..range.end)
-        .step_by(granule)
-    {
-        get_hypervisor().mem_unshare(base as u64)?;
-    }
-    Ok(())
-}
-
 /// Allocates a memory range of at least the given size that is shared with
 /// host. Returns a pointer to the buffer.
 ///
 /// It will be aligned to the memory sharing granule size supported by the hypervisor.
 pub fn alloc_shared(layout: Layout) -> hyp::Result<NonNull<u8>> {
     assert_ne!(layout.size(), 0);
-    let granule = get_hypervisor().memory_protection_granule()?;
-    let layout = layout.align_to(granule).unwrap().pad_to_align();
-    if let Some(shared_pool) = SHARED_POOL.get() {
-        // SAFETY - layout has a non-zero size.
-        let buffer = unsafe { shared_pool.alloc_zeroed(layout) };
-
-        let Some(buffer) = NonNull::new(buffer) else {
-            handle_alloc_error(layout);
-        };
-
-        trace!("Allocated shared buffer at {buffer:?} with {layout:?}");
-        return Ok(buffer);
-    }
-
-    // SAFETY - layout has a non-zero size.
-    let buffer = unsafe { alloc_zeroed(layout) };
-
-    let Some(buffer) = NonNull::new(buffer) else {
+    let Some(buffer) = try_shared_alloc(layout) else {
         handle_alloc_error(layout);
     };
 
-    let paddr = virt_to_phys(buffer);
-    // If share_range fails then we will leak the allocation, but that seems better than having it
-    // be reused while maybe still partially shared with the host.
-    share_range(&(paddr..paddr + layout.size()), granule)?;
-
-    trace!("Allocated shared memory at {buffer:?} with {layout:?}");
+    trace!("Allocated shared buffer at {buffer:?} with {layout:?}");
     Ok(buffer)
 }
 
+fn try_shared_alloc(layout: Layout) -> Option<NonNull<u8>> {
+    let mut shared_pool = SHARED_POOL.get().unwrap().lock();
+
+    if let Ok(buffer) = shared_pool.alloc(layout) {
+        Some(buffer)
+    } else if let Some(shared_memory) = SHARED_MEMORY.lock().as_mut() {
+        shared_memory.refill(&mut shared_pool, layout);
+        shared_pool.alloc(layout).ok()
+    } else {
+        None
+    }
+}
+
 /// Unshares and deallocates a memory range which was previously allocated by `alloc_shared`.
 ///
 /// The size passed in must be the size passed to the original `alloc_shared` call.
@@ -387,24 +431,9 @@
 /// The memory must have been allocated by `alloc_shared` with the same size, and not yet
 /// deallocated.
 pub unsafe fn dealloc_shared(vaddr: NonNull<u8>, layout: Layout) -> hyp::Result<()> {
-    let granule = get_hypervisor().memory_protection_granule()?;
-    let layout = layout.align_to(granule).unwrap().pad_to_align();
-    if let Some(shared_pool) = SHARED_POOL.get() {
-        // Safe because the memory was allocated by `alloc_shared` above using
-        // the same allocator, and the layout is the same as was used then.
-        unsafe { shared_pool.dealloc(vaddr.as_ptr(), layout) };
+    SHARED_POOL.get().unwrap().lock().dealloc(vaddr, layout);
 
-        trace!("Deallocated shared buffer at {vaddr:?} with {layout:?}");
-        return Ok(());
-    }
-
-    let paddr = virt_to_phys(vaddr);
-    unshare_range(&(paddr..paddr + layout.size()), granule)?;
-    // Safe because the memory was allocated by `alloc_shared` above using the same allocator, and
-    // the layout is the same as was used then.
-    unsafe { dealloc(vaddr.as_ptr(), layout) };
-
-    trace!("Deallocated shared memory at {vaddr:?} with {layout:?}");
+    trace!("Deallocated shared buffer at {vaddr:?} with {layout:?}");
     Ok(())
 }
 
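
To summarize the dynamic path added above: MemorySharer::refill() rounds the allocation hint up to whole protection granules (hint.align_to(self.granule).unwrap().pad_to_align()), shares the resulting region granule by granule and donates it to the buddy pool, and try_shared_alloc() retries the pool allocation after such a refill. The following self-contained sketch walks through the Layout rounding and the refill-then-retry pattern against buddy_system_allocator; the Vec and the 4096-byte granule are stand-ins for the shared region and get_hypervisor().memory_protection_granule().

    use buddy_system_allocator::Heap;
    use core::alloc::Layout;

    fn main() {
        // Stand-in for get_hypervisor().memory_protection_granule().
        let granule = 4096usize;

        // Step 1: round the caller's hint up to whole granules, as refill() does.
        let hint = Layout::from_size_align(100, 8).unwrap();
        let layout = hint.align_to(granule).unwrap().pad_to_align();
        assert_eq!(layout.size(), granule); // a 100-byte request grows to one full granule

        // Step 2: donate a (mock) freshly shared region to the pool and retry the
        // allocation, mirroring try_shared_alloc()'s refill-and-retry branch.
        let mut pool = Heap::<32>::empty();
        assert!(pool.alloc(hint).is_err()); // empty pool: the first attempt fails

        let backing = vec![0u64; layout.size() / 8]; // owned stand-in for the shared region
        let base = backing.as_ptr() as usize;
        // SAFETY - `backing` outlives every use of `pool` and is only handed out through it.
        unsafe { pool.add_to_heap(base, base + layout.size()) };

        let buffer = pool.alloc(hint).expect("the pool was just refilled");
        pool.dealloc(buffer, hint);
    }
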
diff --git a/pvmfw/src/virtio/hal.rs b/pvmfw/src/virtio/hal.rs
index ec0b9d8..51567cd 100644
--- a/pvmfw/src/virtio/hal.rs
+++ b/pvmfw/src/virtio/hal.rs
@@ -43,6 +43,9 @@
     fn dma_alloc(pages: usize, _direction: BufferDirection) -> (PhysAddr, NonNull<u8>) {
         let vaddr = alloc_shared(dma_layout(pages))
             .expect("Failed to allocate and share VirtIO DMA range with host");
+        // TODO(ptosi): Move this zeroing to virtio_drivers, if it silently wants a zeroed region.
+        // SAFETY - vaddr points to a region allocated for the caller so is safe to access.
+        unsafe { core::ptr::write_bytes(vaddr.as_ptr(), 0, dma_layout(pages).size()) };
         let paddr = virt_to_phys(vaddr);
         (paddr, vaddr)
     }
@@ -83,8 +86,6 @@
     unsafe fn share(buffer: NonNull<[u8]>, direction: BufferDirection) -> PhysAddr {
         let size = buffer.len();
 
-        // TODO: Copy to a pre-shared region rather than allocating and sharing each time.
-        // Allocate a range of pages, copy the buffer if necessary, and share the new range instead.
         let bounce = alloc_shared(bb_layout(size))
             .expect("Failed to allocate and share VirtIO bounce buffer with host");
         let paddr = virt_to_phys(bounce);
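
The explicit zeroing added to dma_alloc() above is needed because alloc_shared() now serves chunks recycled from the shared pool, which, unlike the old alloc_zeroed() path, are not guaranteed to be clean. A minimal stand-alone illustration of that write_bytes() call, with a plain array standing in for the region returned by alloc_shared():

    use core::ptr::NonNull;

    fn main() {
        // Stand-in for a recycled chunk handed out by alloc_shared(): stale, not zeroed.
        let mut buf = [0xAAu8; 64];
        let len = buf.len();
        let vaddr = NonNull::new(buf.as_mut_ptr()).unwrap();

        // Same call as in the hunk above: zero `len` bytes starting at vaddr.
        // SAFETY - vaddr points into `buf`, which we own and which is `len` bytes long.
        unsafe { core::ptr::write_bytes(vaddr.as_ptr(), 0, len) };

        assert!(buf.iter().all(|&b| b == 0));
    }
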