Merge "Revert^2 "Enable SVE when available in AVF"" into main
diff --git a/guest/pvmfw/src/instance.rs b/guest/pvmfw/src/instance.rs
index bb07f74..bbc58ed 100644
--- a/guest/pvmfw/src/instance.rs
+++ b/guest/pvmfw/src/instance.rs
@@ -26,7 +26,10 @@
 use diced_open_dice::Hidden;
 use log::trace;
 use uuid::Uuid;
-use virtio_drivers::transport::{pci::bus::PciRoot, DeviceType, Transport};
+use virtio_drivers::transport::{
+    pci::bus::{ConfigurationAccess, PciRoot},
+    DeviceType, Transport,
+};
 use vmbase::util::ceiling_div;
 use vmbase::virtio::pci::{PciTransportIterator, VirtIOBlk};
 use vmbase::virtio::HalImpl;
@@ -99,7 +102,7 @@
 /// pvmfw in the instance.img as well as index corresponding to empty header which can be used to
 /// record instance data with `record_instance_entry`.
 pub(crate) fn get_recorded_entry(
-    pci_root: &mut PciRoot,
+    pci_root: &mut PciRoot<impl ConfigurationAccess>,
     secret: &[u8],
 ) -> Result<(Option<EntryBody>, Partition, usize)> {
     let mut instance_img = find_instance_img(pci_root)?;
@@ -175,8 +178,8 @@
     }
 }
 
-fn find_instance_img(pci_root: &mut PciRoot) -> Result<Partition> {
-    for transport in PciTransportIterator::<HalImpl>::new(pci_root)
+fn find_instance_img(pci_root: &mut PciRoot<impl ConfigurationAccess>) -> Result<Partition> {
+    for transport in PciTransportIterator::<HalImpl, _>::new(pci_root)
         .filter(|t| DeviceType::Block == t.device_type())
     {
         let device =
diff --git a/guest/pvmfw/src/rollback.rs b/guest/pvmfw/src/rollback.rs
index 74b2cd8..1d84c5b 100644
--- a/guest/pvmfw/src/rollback.rs
+++ b/guest/pvmfw/src/rollback.rs
@@ -25,7 +25,7 @@
 use log::{error, info};
 use pvmfw_avb::Capability;
 use pvmfw_avb::VerifiedBootData;
-use virtio_drivers::transport::pci::bus::PciRoot;
+use virtio_drivers::transport::pci::bus::{ConfigurationAccess, PciRoot};
 use vmbase::fdt::{pci::PciInfo, SwiotlbInfo};
 use vmbase::memory::init_shared_pool;
 use vmbase::rand;
@@ -167,7 +167,9 @@
 }
 
 /// Set up PCI bus and VirtIO-blk device containing the instance.img partition.
-fn initialize_instance_img_device(fdt: &Fdt) -> Result<PciRoot, RebootReason> {
+fn initialize_instance_img_device(
+    fdt: &Fdt,
+) -> Result<PciRoot<impl ConfigurationAccess>, RebootReason> {
     let pci_info = PciInfo::from_fdt(fdt).map_err(|e| {
         error!("Failed to detect PCI from DT: {e}");
         RebootReason::InvalidFdt
diff --git a/guest/rialto/src/communication.rs b/guest/rialto/src/communication.rs
index 1b94912..6f5a59e 100644
--- a/guest/rialto/src/communication.rs
+++ b/guest/rialto/src/communication.rs
@@ -67,7 +67,7 @@
                 match event {
                     VsockEventType::Connected => return Ok(()),
                     VsockEventType::Disconnected { .. } => {
-                        return Err(SocketError::ConnectionFailed.into())
+                        return Err(SocketError::NotConnected.into())
                     }
                     // We shouldn't receive the following event before the connection is
                     // established.
@@ -141,7 +141,7 @@
     fn poll(&mut self) -> virtio_drivers::Result<Option<VsockEventType>> {
         if let Some(event) = self.poll_event_from_peer()? {
             match event {
-                VsockEventType::Disconnected { .. } => Err(SocketError::ConnectionFailed.into()),
+                VsockEventType::Disconnected { .. } => Err(SocketError::NotConnected.into()),
                 VsockEventType::Connected | VsockEventType::ConnectionRequest => {
                     Err(SocketError::InvalidOperation.into())
                 }
diff --git a/guest/rialto/src/main.rs b/guest/rialto/src/main.rs
index 04d18be..c3d3604 100644
--- a/guest/rialto/src/main.rs
+++ b/guest/rialto/src/main.rs
@@ -38,7 +38,10 @@
 use service_vm_requests::{process_request, RequestContext};
 use virtio_drivers::{
     device::socket::{VsockAddr, VMADDR_CID_HOST},
-    transport::{pci::bus::PciRoot, DeviceType, Transport},
+    transport::{
+        pci::bus::{ConfigurationAccess, PciRoot},
+        DeviceType, Transport,
+    },
     Hal,
 };
 use vmbase::{
@@ -123,7 +126,6 @@
     let pci_info = PciInfo::from_fdt(fdt)?;
     debug!("PCI: {pci_info:#x?}");
     let mut pci_root = pci::initialize(pci_info).map_err(Error::PciInitializationFailed)?;
-    debug!("PCI root: {pci_root:#x?}");
     let socket_device = find_socket_device::<HalImpl>(&mut pci_root)?;
     debug!("Found socket device: guest cid = {:?}", socket_device.guest_cid());
     let vendor_hashtree_root_digest = read_vendor_hashtree_root_digest(fdt)?;
@@ -143,8 +145,10 @@
     Ok(())
 }
 
-fn find_socket_device<T: Hal>(pci_root: &mut PciRoot) -> Result<VirtIOSocket<T>> {
-    PciTransportIterator::<T>::new(pci_root)
+fn find_socket_device<T: Hal>(
+    pci_root: &mut PciRoot<impl ConfigurationAccess>,
+) -> Result<VirtIOSocket<T>> {
+    PciTransportIterator::<T, _>::new(pci_root)
         .find(|t| DeviceType::Socket == t.device_type())
         .map(VirtIOSocket::<T>::new)
         .transpose()
diff --git a/guest/vmbase_example/src/pci.rs b/guest/vmbase_example/src/pci.rs
index 32ab9f6..1e87682 100644
--- a/guest/vmbase_example/src/pci.rs
+++ b/guest/vmbase_example/src/pci.rs
@@ -20,7 +20,10 @@
 use virtio_drivers::{
     device::console::VirtIOConsole,
     transport::{
-        pci::{bus::PciRoot, PciTransport},
+        pci::{
+            bus::{ConfigurationAccess, PciRoot},
+            PciTransport,
+        },
         DeviceType, Transport,
     },
     BufferDirection, Error, Hal, PhysAddr, PAGE_SIZE,
@@ -33,11 +36,11 @@
 /// The size in sectors of the test block device we expect.
 const EXPECTED_SECTOR_COUNT: usize = 4;
 
-pub fn check_pci(pci_root: &mut PciRoot) {
+pub fn check_pci(pci_root: &mut PciRoot<impl ConfigurationAccess>) {
     let mut checked_virtio_device_count = 0;
     let mut block_device_count = 0;
     let mut socket_device_count = 0;
-    for mut transport in PciTransportIterator::<HalImpl>::new(pci_root) {
+    for mut transport in PciTransportIterator::<HalImpl, _>::new(pci_root) {
         info!(
             "Detected virtio PCI device with device type {:?}, features {:#018x}",
             transport.device_type(),
@@ -104,7 +107,10 @@
 fn check_virtio_console_device(transport: PciTransport) {
     let mut console = VirtIOConsole::<HalImpl, PciTransport>::new(transport)
         .expect("Failed to create VirtIO console driver");
-    info!("Found console device: {:?}", console.info());
+    info!(
+        "Found console device with size {:?}",
+        console.size().expect("Failed to get size of VirtIO console device")
+    );
     for &c in b"Hello VirtIO console\n" {
         console.send(c).expect("Failed to send character to VirtIO console device");
     }
diff --git a/libs/libvmbase/src/fdt/pci.rs b/libs/libvmbase/src/fdt/pci.rs
index 44ad455..b526f3d 100644
--- a/libs/libvmbase/src/fdt/pci.rs
+++ b/libs/libvmbase/src/fdt/pci.rs
@@ -18,7 +18,7 @@
 use libfdt::{AddressRange, Fdt, FdtError, FdtNode};
 use log::debug;
 use thiserror::Error;
-use virtio_drivers::transport::pci::bus::{Cam, PciRoot};
+use virtio_drivers::transport::pci::bus::{Cam, ConfigurationAccess, MmioCam, PciRoot};
 
 /// PCI MMIO configuration region size.
 const PCI_CFG_SIZE: usize = 0x100_0000;
@@ -94,10 +94,10 @@
     /// To prevent concurrent access, only one `PciRoot` should exist in the program. Thus this
     /// method must only be called once, and there must be no other `PciRoot` constructed using the
     /// same CAM.
-    pub unsafe fn make_pci_root(&self) -> PciRoot {
+    pub unsafe fn make_pci_root(&self) -> PciRoot<impl ConfigurationAccess> {
         // SAFETY: We trust that the FDT gave us a valid MMIO base address for the CAM. The caller
         // guarantees to only call us once, so there are no other references to it.
-        unsafe { PciRoot::new(self.cam_range.start as *mut u8, Cam::MmioCam) }
+        PciRoot::new(unsafe { MmioCam::new(self.cam_range.start as *mut u8, Cam::MmioCam) })
     }
 }
 
diff --git a/libs/libvmbase/src/virtio/pci.rs b/libs/libvmbase/src/virtio/pci.rs
index ec89b6b..591ae54 100644
--- a/libs/libvmbase/src/virtio/pci.rs
+++ b/libs/libvmbase/src/virtio/pci.rs
@@ -26,7 +26,7 @@
 use virtio_drivers::{
     device::{blk, socket},
     transport::pci::{
-        bus::{BusDeviceIterator, PciRoot},
+        bus::{BusDeviceIterator, ConfigurationAccess, PciRoot},
         virtio_device_type, PciTransport,
     },
     Hal,
@@ -66,7 +66,7 @@
 /// 3. Creates and returns a `PciRoot`.
 ///
 /// This must only be called once and after having switched to the dynamic page tables.
-pub fn initialize(pci_info: PciInfo) -> Result<PciRoot, PciError> {
+pub fn initialize(pci_info: PciInfo) -> Result<PciRoot<impl ConfigurationAccess>, PciError> {
     PCI_INFO.set(Box::new(pci_info.clone())).map_err(|_| PciError::DuplicateInitialization)?;
 
     let cam_start = pci_info.cam_range.start;
@@ -90,21 +90,21 @@
 pub type VirtIOSocket<T> = socket::VirtIOSocket<T, PciTransport>;
 
 /// An iterator that iterates over the PCI transport for each device.
-pub struct PciTransportIterator<'a, T: Hal> {
-    pci_root: &'a mut PciRoot,
-    bus: BusDeviceIterator,
+pub struct PciTransportIterator<'a, T: Hal, C: ConfigurationAccess> {
+    pci_root: &'a mut PciRoot<C>,
+    bus: BusDeviceIterator<C>,
     _hal: PhantomData<T>,
 }
 
-impl<'a, T: Hal> PciTransportIterator<'a, T> {
+impl<'a, T: Hal, C: ConfigurationAccess> PciTransportIterator<'a, T, C> {
     /// Creates a new iterator.
-    pub fn new(pci_root: &'a mut PciRoot) -> Self {
+    pub fn new(pci_root: &'a mut PciRoot<C>) -> Self {
         let bus = pci_root.enumerate_bus(0);
         Self { pci_root, bus, _hal: PhantomData }
     }
 }
 
-impl<'a, T: Hal> Iterator for PciTransportIterator<'a, T> {
+impl<'a, T: Hal, C: ConfigurationAccess> Iterator for PciTransportIterator<'a, T, C> {
     type Item = PciTransport;
 
     fn next(&mut self) -> Option<Self::Item> {
@@ -121,7 +121,7 @@
             };
             debug!("  VirtIO {:?}", virtio_type);
 
-            return PciTransport::new::<T>(self.pci_root, device_function).ok();
+            return PciTransport::new::<T, C>(self.pci_root, device_function).ok();
         }
     }
 }