Merge "Map PCI CAM MMIO region and register pages with the MMIO guard."
diff --git a/pvmfw/src/entry.rs b/pvmfw/src/entry.rs
index efbb179..e8f9bb2 100644
--- a/pvmfw/src/entry.rs
+++ b/pvmfw/src/entry.rs
@@ -32,7 +32,7 @@
 use vmbase::{console, layout, logger, main, power::reboot};
 
 #[derive(Debug, Clone)]
-pub(crate) enum RebootReason {
+pub enum RebootReason {
     /// A malformed BCC was received.
     InvalidBcc,
     /// An invalid configuration was appended to pvmfw.
@@ -243,10 +243,15 @@
     let slices = MemorySlices::new(fdt, payload, payload_size, &mut memory)?;
 
     // This wrapper allows main() to be blissfully ignorant of platform details.
-    crate::main(slices.fdt, slices.kernel, slices.ramdisk, bcc)?;
+    crate::main(slices.fdt, slices.kernel, slices.ramdisk, bcc, &mut memory)?;
 
     // TODO: Overwrite BCC before jumping to payload to avoid leaking our sealing key.
 
+    info!("Expecting a bug making MMIO_GUARD_UNMAP return NOT_SUPPORTED on success");
+    memory.mmio_unmap_all().map_err(|e| {
+        error!("Failed to unshare MMIO ranges: {e}");
+        RebootReason::InternalError
+    })?;
     mmio_guard::unmap(console::BASE_ADDRESS).map_err(|e| {
         error!("Failed to unshare the UART: {e}");
         RebootReason::InternalError
diff --git a/pvmfw/src/main.rs b/pvmfw/src/main.rs
index d453e26..e6a158d 100644
--- a/pvmfw/src/main.rs
+++ b/pvmfw/src/main.rs
@@ -29,18 +29,25 @@
 mod memory;
 mod mmio_guard;
 mod mmu;
+mod pci;
 mod smccc;
 
-use crate::entry::RebootReason;
+use crate::{
+    entry::RebootReason,
+    memory::MemoryTracker,
+    pci::{map_cam, pci_node},
+};
 use avb::PUBLIC_KEY;
 use avb_nostd::verify_image;
+use libfdt::Fdt;
 use log::{debug, error, info};
 
 fn main(
-    fdt: &libfdt::Fdt,
+    fdt: &Fdt,
     signed_kernel: &[u8],
     ramdisk: Option<&[u8]>,
     bcc: &[u8],
+    memory: &mut MemoryTracker,
 ) -> Result<(), RebootReason> {
     info!("pVM firmware");
     debug!("FDT: {:?}", fdt as *const libfdt::Fdt);
@@ -51,6 +58,11 @@
         debug!("Ramdisk: None");
     }
     debug!("BCC: {:?} ({:#x} bytes)", bcc.as_ptr(), bcc.len());
+
+    // Set up PCI bus for VirtIO devices.
+    let pci_node = pci_node(fdt)?;
+    map_cam(&pci_node, memory)?;
+
     verify_image(signed_kernel, PUBLIC_KEY).map_err(|e| {
         error!("Failed to verify the payload: {e}");
         RebootReason::PayloadVerificationError
diff --git a/pvmfw/src/memory.rs b/pvmfw/src/memory.rs
index e88fa5b..ca1024d 100644
--- a/pvmfw/src/memory.rs
+++ b/pvmfw/src/memory.rs
@@ -14,7 +14,8 @@
 
 //! Low-level allocation and tracking of main memory.
 
-use crate::helpers;
+use crate::helpers::{self, page_4kb_of, SIZE_4KB};
+use crate::mmio_guard;
 use crate::mmu;
 use core::cmp::max;
 use core::cmp::min;
@@ -43,8 +44,7 @@
 impl MemoryRegion {
     /// True if the instance overlaps with the passed range.
     pub fn overlaps(&self, range: &MemoryRange) -> bool {
-        let our: &MemoryRange = self.as_ref();
-        max(our.start, range.start) < min(our.end, range.end)
+        overlaps(&self.range, range) // Calls the free function below; not recursion.
     }
 
     /// True if the instance is fully contained within the passed range.
@@ -60,11 +60,17 @@
     }
 }
 
+/// True if half-open ranges `a` and `b` share at least one value (never, if either is empty).
+fn overlaps<T: Copy + Ord>(a: &Range<T>, b: &Range<T>) -> bool {
+    max(a.start, b.start) < min(a.end, b.end)
+}
+
 /// Tracks non-overlapping slices of main memory.
 pub struct MemoryTracker {
-    regions: ArrayVec<[MemoryRegion; MemoryTracker::CAPACITY]>,
     total: MemoryRange,
     page_table: mmu::PageTable,
+    regions: ArrayVec<[MemoryRegion; MemoryTracker::CAPACITY]>, // Main-memory regions.
+    mmio_regions: ArrayVec<[MemoryRange; MemoryTracker::MMIO_CAPACITY]>, // Mapped MMIO ranges.
 }
 
 /// Errors for MemoryTracker operations.
@@ -84,6 +90,8 @@
     Overlaps,
     /// Region couldn't be mapped.
     FailedToMap,
+    /// Error from an MMIO guard call.
+    MmioGuard(mmio_guard::Error),
 }
 
 impl fmt::Display for MemoryTrackerError {
@@ -96,14 +104,22 @@
             Self::OutOfRange => write!(f, "Region is out of the tracked memory address space"),
             Self::Overlaps => write!(f, "New region overlaps with tracked regions"),
             Self::FailedToMap => write!(f, "Failed to map the new region"),
+            Self::MmioGuard(e) => e.fmt(f),
         }
     }
 }
 
+impl From<mmio_guard::Error> for MemoryTrackerError {
+    fn from(source: mmio_guard::Error) -> Self {
+        MemoryTrackerError::MmioGuard(source)
+    }
+}
+
 type Result<T> = result::Result<T, MemoryTrackerError>;
 
 impl MemoryTracker {
     const CAPACITY: usize = 5;
+    const MMIO_CAPACITY: usize = 5; // Max MMIO ranges map_mmio_range() can track.
     /// Base of the system's contiguous "main" memory.
     const BASE: usize = 0x8000_0000;
     /// First address that can't be translated by a level 1 TTBR0_EL1.
@@ -111,7 +127,12 @@
 
     /// Create a new instance from an active page table, covering the maximum RAM size.
     pub fn new(page_table: mmu::PageTable) -> Self {
-        Self { total: Self::BASE..Self::MAX_ADDR, page_table, regions: ArrayVec::new() }
+        Self {
+            total: Self::BASE..Self::MAX_ADDR,
+            page_table,
+            regions: ArrayVec::new(),
+            mmio_regions: ArrayVec::new(), // Populated by map_mmio_range().
+        }
     }
 
     /// Resize the total RAM size.
@@ -164,6 +185,36 @@
         self.alloc_range_mut(&(base..(base + size.get())))
     }
 
+    /// Maps `range` as device memory and registers its 4 KiB pages with the MMIO guard. Rejects
+    /// ranges extending into main memory or overlapping a tracked MMIO range, and fails when full.
+    pub fn map_mmio_range(&mut self, range: MemoryRange) -> Result<()> {
+        // MMIO space is below the main memory region.
+        if range.end > self.total.start {
+            return Err(MemoryTrackerError::OutOfRange);
+        }
+        if self.mmio_regions.iter().any(|r| overlaps(r, &range)) {
+            return Err(MemoryTrackerError::Overlaps);
+        }
+        if self.mmio_regions.len() == self.mmio_regions.capacity() {
+            return Err(MemoryTrackerError::Full);
+        }
+
+        self.page_table.map_device(&range).map_err(|e| {
+            error!("Error during MMIO device mapping: {e}");
+            MemoryTrackerError::FailedToMap
+        })?;
+
+        for page_base in page_iterator(&range) {
+            mmio_guard::map(page_base)?; // NOTE(review): earlier pages aren't unmapped on error.
+        }
+
+        if self.mmio_regions.try_push(range).is_some() {
+            return Err(MemoryTrackerError::Full); // Defensive: capacity was checked above.
+        }
+
+        Ok(())
+    }
+
     /// Checks that the given region is within the range of the `MemoryTracker` and doesn't overlap
     /// with any other previously allocated regions, and that the regions ArrayVec has capacity to
     /// add it.
@@ -187,11 +238,24 @@
 
         Ok(self.regions.last().unwrap().as_ref().clone())
     }
+
+    /// Unmaps every page of each tracked MMIO region from the MMIO guard.
+    ///
+    /// Stops at the first error. Regions stay mapped in the page table and remain tracked.
+    pub fn mmio_unmap_all(&self) -> Result<()> {
+        for region in &self.mmio_regions {
+            for page_base in page_iterator(region) {
+                mmio_guard::unmap(page_base)?;
+            }
+        }
+
+        Ok(())
+    }
 }
 
 impl Drop for MemoryTracker {
     fn drop(&mut self) {
-        for region in self.regions.iter() {
+        for region in &self.regions {
             match region.mem_type {
                 MemoryType::ReadWrite => {
                     // TODO: Use page table's dirty bit to only flush pages that were touched.
@@ -202,3 +266,8 @@
         }
     }
 }
+
+/// Yields page base addresses from `page_4kb_of(range.start)` below `range.end`, in 4 KiB steps.
+fn page_iterator(range: &MemoryRange) -> impl Iterator<Item = usize> {
+    (page_4kb_of(range.start)..range.end).step_by(SIZE_4KB)
+}
diff --git a/pvmfw/src/mmio_guard.rs b/pvmfw/src/mmio_guard.rs
index eb6c1fa..28f928f 100644
--- a/pvmfw/src/mmio_guard.rs
+++ b/pvmfw/src/mmio_guard.rs
@@ -105,7 +105,6 @@
     args[0] = ipa;
 
     // TODO(b/251426790): pKVM currently returns NOT_SUPPORTED for SUCCESS.
-    info!("Expecting a bug making MMIO_GUARD_UNMAP return NOT_SUPPORTED on success");
     match smccc::checked_hvc64_expect_zero(VENDOR_HYP_KVM_MMIO_GUARD_UNMAP_FUNC_ID, args) {
         Err(smccc::Error::NotSupported) | Ok(_) => Ok(()),
         x => x,
diff --git a/pvmfw/src/pci.rs b/pvmfw/src/pci.rs
new file mode 100644
index 0000000..7baabed
--- /dev/null
+++ b/pvmfw/src/pci.rs
@@ -0,0 +1,89 @@
+// Copyright 2022, The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Functions to scan the PCI bus for VirtIO devices and allocate BARs.
+
+use crate::{entry::RebootReason, memory::MemoryTracker};
+use core::ffi::CStr;
+use libfdt::{Fdt, FdtNode};
+use log::{debug, error};
+
+/// PCI MMIO configuration region size.
+const PCI_CFG_SIZE: usize = 0x100_0000;
+
+/// Finds the first FDT node with compatible = "pci-host-cam-generic".
+///
+/// Fails with `InvalidFdt` if the FDT can't be searched or contains no such node.
+pub fn pci_node(fdt: &Fdt) -> Result<FdtNode, RebootReason> {
+    fdt.compatible_nodes(CStr::from_bytes_with_nul(b"pci-host-cam-generic\0").unwrap())
+        .map_err(|e| {
+            error!("Failed to find PCI bus in FDT: {}", e);
+            RebootReason::InvalidFdt
+        })?
+        .next()
+        .ok_or_else(|| {
+            error!("No PCI bus found in FDT.");
+            RebootReason::InvalidFdt
+        })
+}
+
+/// Maps the PCI CAM region described by the first `reg` entry of `pci_node` as MMIO.
+///
+/// The entry must have a size, that size must be exactly `PCI_CFG_SIZE`, and its end address must
+/// fit in a `usize`; otherwise `InvalidFdt` is returned. Mapping failures yield `InternalError`.
+pub fn map_cam(pci_node: &FdtNode, memory: &mut MemoryTracker) -> Result<(), RebootReason> {
+    // Parse reg property to find CAM.
+    let pci_reg = pci_node
+        .reg()
+        .map_err(|e| {
+            error!("Error getting reg property from PCI node: {}", e);
+            RebootReason::InvalidFdt
+        })?
+        .ok_or_else(|| {
+            error!("PCI node missing reg property.");
+            RebootReason::InvalidFdt
+        })?
+        .next()
+        .ok_or_else(|| {
+            error!("Empty reg property on PCI node.");
+            RebootReason::InvalidFdt
+        })?;
+    let cam_addr = pci_reg.addr as usize;
+    let cam_size = pci_reg.size.ok_or_else(|| {
+        error!("PCI reg property missing size.");
+        RebootReason::InvalidFdt
+    })? as usize;
+    // Reject entries whose end doesn't fit in a usize, rather than letting `cam_addr + cam_size`
+    // panic (overflow checks on) or wrap around to a bogus, low end address (overflow checks off).
+    let cam_end = cam_addr.checked_add(cam_size).ok_or_else(|| {
+        error!("PCI CAM {:#x}+{:#x} overflows the address space.", cam_addr, cam_size);
+        RebootReason::InvalidFdt
+    })?;
+    debug!("Found PCI CAM at {:#x}-{:#x}", cam_addr, cam_end);
+    // Check that the CAM is the size we expect, so we don't later try accessing it beyond its
+    // bounds. If it is a different size then something is very wrong and we shouldn't continue to
+    // access it; maybe there is some new version of PCI we don't know about.
+    if cam_size != PCI_CFG_SIZE {
+        error!("FDT says PCI CAM is {} bytes but we expected {}.", cam_size, PCI_CFG_SIZE);
+        return Err(RebootReason::InvalidFdt);
+    }
+
+    // Map the CAM as MMIO.
+    memory.map_mmio_range(cam_addr..cam_end).map_err(|e| {
+        error!("Failed to map PCI CAM: {}", e);
+        RebootReason::InternalError
+    })?;
+
+    Ok(())
+}