Allocate BARs for all VirtIO PCI devices.

Bug: 237249346
Test: Ran pVM firmware manually.
Change-Id: I5bf4707b9f9edb020b96fa3db740c24a4c4b789b
diff --git a/pvmfw/Android.bp b/pvmfw/Android.bp
index b6c115b..4218fae 100644
--- a/pvmfw/Android.bp
+++ b/pvmfw/Android.bp
@@ -20,6 +20,7 @@
         "liblog_rust_nostd",
         "libpvmfw_embedded_key",
         "libtinyvec_nostd",
+        "libvirtio_drivers",
         "libvmbase",
     ],
     apex_available: ["com.android.virt"],
diff --git a/pvmfw/src/entry.rs b/pvmfw/src/entry.rs
index d307759..2763e80 100644
--- a/pvmfw/src/entry.rs
+++ b/pvmfw/src/entry.rs
@@ -48,6 +48,8 @@
     InvalidRamdisk,
     /// Failed to verify the payload.
     PayloadVerificationError,
+    /// Error interacting with a VirtIO PCI device.
+    PciError,
 }
 
 main!(start);
diff --git a/pvmfw/src/main.rs b/pvmfw/src/main.rs
index 1767e24..79b6f57 100644
--- a/pvmfw/src/main.rs
+++ b/pvmfw/src/main.rs
@@ -32,7 +32,12 @@
 mod pci;
 mod smccc;
 
-use crate::{avb::PUBLIC_KEY, entry::RebootReason, memory::MemoryTracker, pci::PciInfo};
+use crate::{
+    avb::PUBLIC_KEY,
+    entry::RebootReason,
+    memory::MemoryTracker,
+    pci::{allocate_all_virtio_bars, PciError, PciInfo, PciMemory32Allocator},
+};
 use ::avb::verify_image;
 use dice::bcc;
 use libfdt::Fdt;
@@ -56,9 +61,15 @@
     trace!("BCC: {bcc:x?}");
 
     // Set up PCI bus for VirtIO devices.
-    let pci_info = PciInfo::from_fdt(fdt)?;
-    info!("PCI: {:#x?}", pci_info);
+    let pci_info = PciInfo::from_fdt(fdt).map_err(handle_pci_error)?;
+    debug!("PCI: {:#x?}", pci_info);
     pci_info.map(memory)?;
+    let mut bar_allocator = PciMemory32Allocator::new(&pci_info);
+    debug!("Allocator: {:#x?}", bar_allocator);
+    // Safety: This is the only place where we call make_pci_root, and this main function is only
+    // called once.
+    let mut pci_root = unsafe { pci_info.make_pci_root() };
+    allocate_all_virtio_bars(&mut pci_root, &mut bar_allocator).map_err(handle_pci_error)?;
 
     verify_image(signed_kernel, PUBLIC_KEY).map_err(|e| {
         error!("Failed to verify the payload: {e}");
@@ -67,3 +78,24 @@
     info!("Starting payload...");
     Ok(())
 }
+
+/// Logs the given PCI error and returns the appropriate `RebootReason`.
+fn handle_pci_error(e: PciError) -> RebootReason {
+    error!("{}", e);
+    match e {
+        PciError::FdtErrorPci(_)
+        | PciError::FdtNoPci
+        | PciError::FdtErrorReg(_)
+        | PciError::FdtMissingReg
+        | PciError::FdtRegEmpty
+        | PciError::FdtRegMissingSize
+        | PciError::CamWrongSize(_)
+        | PciError::FdtErrorRanges(_)
+        | PciError::FdtMissingRanges
+        | PciError::RangeAddressMismatch { .. }
+        | PciError::NoSuitableRange => RebootReason::InvalidFdt,
+        PciError::BarInfoFailed(_)
+        | PciError::BarAllocationFailed { .. }
+        | PciError::UnsupportedBarType(_) => RebootReason::PciError,
+    }
+}
diff --git a/pvmfw/src/pci.rs b/pvmfw/src/pci.rs
index 3e6915a..d971c7b 100644
--- a/pvmfw/src/pci.rs
+++ b/pvmfw/src/pci.rs
@@ -18,13 +18,78 @@
     entry::RebootReason,
     memory::{MemoryRange, MemoryTracker},
 };
-use core::{ffi::CStr, ops::Range};
-use libfdt::{AddressRange, Fdt, FdtNode};
+use core::{
+    ffi::CStr,
+    fmt::{self, Display, Formatter},
+    ops::Range,
+};
+use libfdt::{AddressRange, Fdt, FdtError, FdtNode};
 use log::{debug, error};
+use virtio_drivers::pci::{
+    bus::{self, BarInfo, Cam, Command, DeviceFunction, MemoryBarType, PciRoot},
+    virtio_device_type,
+};
 
 /// PCI MMIO configuration region size.
 const PCI_CFG_SIZE: usize = 0x100_0000;
 
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum PciError {
+    FdtErrorPci(FdtError),
+    FdtNoPci,
+    FdtErrorReg(FdtError),
+    FdtMissingReg,
+    FdtRegEmpty,
+    FdtRegMissingSize,
+    CamWrongSize(usize),
+    FdtErrorRanges(FdtError),
+    FdtMissingRanges,
+    RangeAddressMismatch { bus_address: u64, cpu_physical: u64 },
+    NoSuitableRange,
+    BarInfoFailed(bus::PciError),
+    BarAllocationFailed { size: u32, device_function: DeviceFunction },
+    UnsupportedBarType(MemoryBarType),
+}
+
+impl Display for PciError {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        match self {
+            Self::FdtErrorPci(e) => write!(f, "Error getting PCI node from FDT: {}", e),
+            Self::FdtNoPci => write!(f, "Failed to find PCI bus in FDT."),
+            Self::FdtErrorReg(e) => write!(f, "Error getting reg property from PCI node: {}", e),
+            Self::FdtMissingReg => write!(f, "PCI node missing reg property."),
+            Self::FdtRegEmpty => write!(f, "Empty reg property on PCI node."),
+            Self::FdtRegMissingSize => write!(f, "PCI reg property missing size."),
+            Self::CamWrongSize(cam_size) => write!(
+                f,
+                "FDT says PCI CAM is {} bytes but we expected {}.",
+                cam_size, PCI_CFG_SIZE
+            ),
+            Self::FdtErrorRanges(e) => {
+                write!(f, "Error getting ranges property from PCI node: {}", e)
+            }
+            Self::FdtMissingRanges => write!(f, "PCI node missing ranges property."),
+            Self::RangeAddressMismatch { bus_address, cpu_physical } => {
+                write!(
+                    f,
+                    "bus address {:#018x} != CPU physical address {:#018x}",
+                    bus_address, cpu_physical
+                )
+            }
+            Self::NoSuitableRange => write!(f, "No suitable PCI memory range found."),
+            Self::BarInfoFailed(e) => write!(f, "Error getting PCI BAR information: {}", e),
+            Self::BarAllocationFailed { size, device_function } => write!(
+                f,
+                "Failed to allocate memory BAR of size {} for PCI device {}.",
+                size, device_function
+            ),
+            Self::UnsupportedBarType(address_type) => {
+                write!(f, "Memory BAR address type {:?} not supported.", address_type)
+            }
+        }
+    }
+}
+
 /// Information about the PCI bus parsed from the device tree.
 #[derive(Debug)]
 pub struct PciInfo {
@@ -36,7 +101,7 @@
 
 impl PciInfo {
     /// Finds the PCI node in the FDT, parses its properties and validates it.
-    pub fn from_fdt(fdt: &Fdt) -> Result<Self, RebootReason> {
+    pub fn from_fdt(fdt: &Fdt) -> Result<Self, PciError> {
         let pci_node = pci_node(fdt)?;
 
         let cam_range = parse_cam_range(&pci_node)?;
@@ -61,48 +126,44 @@
 
         Ok(())
     }
+
+    /// Returns the `PciRoot` for the memory-mapped CAM found in the FDT. The CAM should be mapped
+    /// before this is called, by calling [`PciInfo::map`].
+    ///
+    /// # Safety
+    ///
+    /// To prevent concurrent access, only one `PciRoot` should exist in the program. Thus this
+    /// method must only be called once, and there must be no other `PciRoot` constructed using the
+    /// same CAM.
+    pub unsafe fn make_pci_root(&self) -> PciRoot {
+        PciRoot::new(self.cam_range.start as *mut u8, Cam::MmioCam)
+    }
 }
 
 /// Finds an FDT node with compatible=pci-host-cam-generic.
-fn pci_node(fdt: &Fdt) -> Result<FdtNode, RebootReason> {
+fn pci_node(fdt: &Fdt) -> Result<FdtNode, PciError> {
     fdt.compatible_nodes(CStr::from_bytes_with_nul(b"pci-host-cam-generic\0").unwrap())
-        .map_err(|e| {
-            error!("Failed to find PCI bus in FDT: {}", e);
-            RebootReason::InvalidFdt
-        })?
+        .map_err(PciError::FdtErrorPci)?
         .next()
-        .ok_or(RebootReason::InvalidFdt)
+        .ok_or(PciError::FdtNoPci)
 }
 
 /// Parses the "reg" property of the given PCI FDT node to find the MMIO CAM range.
-fn parse_cam_range(pci_node: &FdtNode) -> Result<MemoryRange, RebootReason> {
+fn parse_cam_range(pci_node: &FdtNode) -> Result<MemoryRange, PciError> {
     let pci_reg = pci_node
         .reg()
-        .map_err(|e| {
-            error!("Error getting reg property from PCI node: {}", e);
-            RebootReason::InvalidFdt
-        })?
-        .ok_or_else(|| {
-            error!("PCI node missing reg property.");
-            RebootReason::InvalidFdt
-        })?
+        .map_err(PciError::FdtErrorReg)?
+        .ok_or(PciError::FdtMissingReg)?
         .next()
-        .ok_or_else(|| {
-            error!("Empty reg property on PCI node.");
-            RebootReason::InvalidFdt
-        })?;
+        .ok_or(PciError::FdtRegEmpty)?;
     let cam_addr = pci_reg.addr as usize;
-    let cam_size = pci_reg.size.ok_or_else(|| {
-        error!("PCI reg property missing size.");
-        RebootReason::InvalidFdt
-    })? as usize;
+    let cam_size = pci_reg.size.ok_or(PciError::FdtRegMissingSize)? as usize;
     debug!("Found PCI CAM at {:#x}-{:#x}", cam_addr, cam_addr + cam_size);
     // Check that the CAM is the size we expect, so we don't later try accessing it beyond its
     // bounds. If it is a different size then something is very wrong and we shouldn't continue to
     // access it; maybe there is some new version of PCI we don't know about.
     if cam_size != PCI_CFG_SIZE {
-        error!("FDT says PCI CAM is {} bytes but we expected {}.", cam_size, PCI_CFG_SIZE);
-        return Err(RebootReason::InvalidFdt);
+        return Err(PciError::CamWrongSize(cam_size));
     }
 
     Ok(cam_addr..cam_addr + cam_size)
@@ -110,20 +171,14 @@
 
 /// Parses the "ranges" property of the given PCI FDT node, and returns the largest suitable range
 /// to use for non-prefetchable 32-bit memory BARs.
-fn parse_ranges(pci_node: &FdtNode) -> Result<Range<u32>, RebootReason> {
+fn parse_ranges(pci_node: &FdtNode) -> Result<Range<u32>, PciError> {
     let mut memory_address = 0;
     let mut memory_size = 0;
 
     for AddressRange { addr: (flags, bus_address), parent_addr: cpu_physical, size } in pci_node
         .ranges::<(u32, u64), u64, u64>()
-        .map_err(|e| {
-            error!("Error getting ranges property from PCI node: {}", e);
-            RebootReason::InvalidFdt
-        })?
-        .ok_or_else(|| {
-            error!("PCI node missing ranges property.");
-            RebootReason::InvalidFdt
-        })?
+        .map_err(PciError::FdtErrorRanges)?
+        .ok_or(PciError::FdtMissingRanges)?
     {
         let flags = PciMemoryFlags(flags);
         let prefetchable = flags.prefetchable();
@@ -145,11 +200,7 @@
             && bus_address + size < u32::MAX.into()
         {
             if bus_address != cpu_physical {
-                error!(
-                    "bus address {:#018x} != CPU physical address {:#018x}",
-                    bus_address, cpu_physical
-                );
-                return Err(RebootReason::InvalidFdt);
+                return Err(PciError::RangeAddressMismatch { bus_address, cpu_physical });
             }
             memory_address = u32::try_from(cpu_physical).unwrap();
             memory_size = u32::try_from(size).unwrap();
@@ -157,8 +208,7 @@
     }
 
     if memory_size == 0 {
-        error!("No suitable PCI memory range found.");
-        return Err(RebootReason::InvalidFdt);
+        return Err(PciError::NoSuitableRange);
     }
 
     Ok(memory_address..memory_address + memory_size)
@@ -196,3 +246,101 @@
         }
     }
 }
+
+/// Allocates BARs for all VirtIO PCI devices.
+pub fn allocate_all_virtio_bars(
+    pci_root: &mut PciRoot,
+    allocator: &mut PciMemory32Allocator,
+) -> Result<(), PciError> {
+    for (device_function, info) in pci_root.enumerate_bus(0) {
+        let (status, command) = pci_root.get_status_command(device_function);
+        debug!(
+            "Found PCI device {} at {}, status {:?} command {:?}",
+            info, device_function, status, command
+        );
+        if let Some(virtio_type) = virtio_device_type(&info) {
+            debug!("  VirtIO {:?}", virtio_type);
+            allocate_bars(pci_root, device_function, allocator)?;
+        }
+    }
+
+    Ok(())
+}
+
+/// Allocates 32-bit memory addresses for PCI BARs.
+#[derive(Debug)]
+pub struct PciMemory32Allocator {
+    /// The start of the available (not yet allocated) address space for PCI BARs.
+    start: u32,
+    /// The end of the available address space.
+    end: u32,
+}
+
+impl PciMemory32Allocator {
+    pub fn new(pci_info: &PciInfo) -> Self {
+        Self { start: pci_info.bar_range.start, end: pci_info.bar_range.end }
+    }
+
+    /// Allocates a 32-bit memory address region for a PCI BAR of the given power-of-2 size.
+    ///
+    /// It will have alignment matching the size. The size must be a power of 2.
+    pub fn allocate_memory_32(&mut self, size: u32) -> Option<u32> {
+        assert!(size.is_power_of_two());
+        let allocated_address = align_up(self.start, size);
+        if allocated_address + size <= self.end {
+            self.start = allocated_address + size;
+            Some(allocated_address)
+        } else {
+            None
+        }
+    }
+}
+
+/// Allocates appropriately-sized memory regions and assigns them to the device's BARs.
+fn allocate_bars(
+    root: &mut PciRoot,
+    device_function: DeviceFunction,
+    allocator: &mut PciMemory32Allocator,
+) -> Result<(), PciError> {
+    let mut bar_index = 0;
+    while bar_index < 6 {
+        let info = root.bar_info(device_function, bar_index).map_err(PciError::BarInfoFailed)?;
+        debug!("BAR {}: {}", bar_index, info);
+        // Ignore I/O bars, as they aren't required for the VirtIO driver.
+        if let BarInfo::Memory { address_type, size, .. } = info {
+            match address_type {
+                _ if size == 0 => {}
+                MemoryBarType::Width32 => {
+                    let address = allocator
+                        .allocate_memory_32(size)
+                        .ok_or(PciError::BarAllocationFailed { size, device_function })?;
+                    debug!("Allocated address {:#010x}", address);
+                    root.set_bar_32(device_function, bar_index, address);
+                }
+                _ => {
+                    return Err(PciError::UnsupportedBarType(address_type));
+                }
+            }
+        }
+
+        bar_index += 1;
+        if info.takes_two_entries() {
+            bar_index += 1;
+        }
+    }
+
+    // Enable the device to use its BARs.
+    root.set_command(
+        device_function,
+        Command::IO_SPACE | Command::MEMORY_SPACE | Command::BUS_MASTER,
+    );
+    let (status, command) = root.get_status_command(device_function);
+    debug!("Allocated BARs and enabled device, status {:?} command {:?}", status, command);
+
+    Ok(())
+}
+
+// TODO: Make the alignment functions in the helpers module generic once const_trait_impl is stable.
+const fn align_up(value: u32, alignment: u32) -> u32 {
+    ((value - 1) | (alignment - 1)) + 1
+}