[rialto] Find virtio socket device through PCI bus
Bug: 284462758
Test: atest rialto_test
Change-Id: I62836d42b6afba2beea37ef57e8745f03e6d0c3d
diff --git a/rialto/src/error.rs b/rialto/src/error.rs
index c326566..0c1e25d 100644
--- a/rialto/src/error.rs
+++ b/rialto/src/error.rs
@@ -37,6 +37,10 @@
MemoryOperationFailed(MemoryTrackerError),
/// Failed to initialize PCI.
PciInitializationFailed(pci::PciError),
+ /// Failed to create VirtIO Socket device.
+ VirtIOSocketCreationFailed(virtio_drivers::Error),
+ /// Missing socket device.
+ MissingVirtIOSocketDevice,
}
impl fmt::Display for Error {
@@ -50,6 +54,10 @@
Self::InvalidPci(e) => write!(f, "Invalid PCI: {e}"),
Self::MemoryOperationFailed(e) => write!(f, "Failed memory operation: {e}"),
Self::PciInitializationFailed(e) => write!(f, "Failed to initialize PCI: {e}"),
+ Self::VirtIOSocketCreationFailed(e) => {
+ write!(f, "Failed to create VirtIO Socket device: {e}")
+ }
+ Self::MissingVirtIOSocketDevice => write!(f, "Missing VirtIO Socket device."),
}
}
}
diff --git a/rialto/src/main.rs b/rialto/src/main.rs
index 3e0485d..bbc9997 100644
--- a/rialto/src/main.rs
+++ b/rialto/src/main.rs
@@ -29,14 +29,21 @@
use hyp::{get_mem_sharer, get_mmio_guard};
use libfdt::FdtError;
use log::{debug, error, info};
+use virtio_drivers::{
+ transport::{pci::bus::PciRoot, DeviceType, Transport},
+ Hal,
+};
use vmbase::{
configure_heap,
fdt::SwiotlbInfo,
layout::{self, crosvm},
main,
- memory::{MemoryTracker, PageTable, MEMORY, PAGE_SIZE, SIZE_64KB},
+ memory::{MemoryTracker, PageTable, MEMORY, PAGE_SIZE, SIZE_128KB},
power::reboot,
- virtio::pci,
+ virtio::{
+ pci::{self, PciTransportIterator, VirtIOSocket},
+ HalImpl,
+ },
};
fn new_page_table() -> Result<PageTable> {
@@ -107,12 +114,23 @@
let pci_info = PciInfo::from_fdt(fdt)?;
debug!("PCI: {pci_info:#x?}");
- let pci_root = pci::initialize(pci_info, MEMORY.lock().as_mut().unwrap())
+ let mut pci_root = pci::initialize(pci_info, MEMORY.lock().as_mut().unwrap())
.map_err(Error::PciInitializationFailed)?;
debug!("PCI root: {pci_root:#x?}");
+ let socket_device = find_socket_device::<HalImpl>(&mut pci_root)?;
+ debug!("Found socket device: guest cid = {:?}", socket_device.guest_cid());
Ok(())
}
+fn find_socket_device<T: Hal>(pci_root: &mut PciRoot) -> Result<VirtIOSocket<T>> {
+ PciTransportIterator::<T>::new(pci_root)
+ .find(|t| DeviceType::Socket == t.device_type())
+ .map(VirtIOSocket::<T>::new)
+ .transpose()
+ .map_err(Error::VirtIOSocketCreationFailed)?
+ .ok_or(Error::MissingVirtIOSocketDevice)
+}
+
fn try_unshare_all_memory() -> Result<()> {
info!("Starting unsharing memory...");
@@ -147,4 +165,4 @@
}
main!(main);
-configure_heap!(SIZE_64KB);
+configure_heap!(SIZE_128KB);