[rialto] Find virtio socket device through PCI bus

Bug: 284462758
Test: atest rialto_test
Change-Id: I62836d42b6afba2beea37ef57e8745f03e6d0c3d
diff --git a/rialto/src/error.rs b/rialto/src/error.rs
index c326566..0c1e25d 100644
--- a/rialto/src/error.rs
+++ b/rialto/src/error.rs
@@ -37,6 +37,10 @@
     MemoryOperationFailed(MemoryTrackerError),
     /// Failed to initialize PCI.
     PciInitializationFailed(pci::PciError),
+    /// Failed to create VirtIO Socket device.
+    VirtIOSocketCreationFailed(virtio_drivers::Error),
+    /// Missing socket device.
+    MissingVirtIOSocketDevice,
 }
 
 impl fmt::Display for Error {
@@ -50,6 +54,10 @@
             Self::InvalidPci(e) => write!(f, "Invalid PCI: {e}"),
             Self::MemoryOperationFailed(e) => write!(f, "Failed memory operation: {e}"),
             Self::PciInitializationFailed(e) => write!(f, "Failed to initialize PCI: {e}"),
+            Self::VirtIOSocketCreationFailed(e) => {
+                write!(f, "Failed to create VirtIO Socket device: {e}")
+            }
+            Self::MissingVirtIOSocketDevice => write!(f, "Missing VirtIO Socket device."),
         }
     }
 }
diff --git a/rialto/src/main.rs b/rialto/src/main.rs
index 3e0485d..bbc9997 100644
--- a/rialto/src/main.rs
+++ b/rialto/src/main.rs
@@ -29,14 +29,21 @@
 use hyp::{get_mem_sharer, get_mmio_guard};
 use libfdt::FdtError;
 use log::{debug, error, info};
+use virtio_drivers::{
+    transport::{pci::bus::PciRoot, DeviceType, Transport},
+    Hal,
+};
 use vmbase::{
     configure_heap,
     fdt::SwiotlbInfo,
     layout::{self, crosvm},
     main,
-    memory::{MemoryTracker, PageTable, MEMORY, PAGE_SIZE, SIZE_64KB},
+    memory::{MemoryTracker, PageTable, MEMORY, PAGE_SIZE, SIZE_128KB},
     power::reboot,
-    virtio::pci,
+    virtio::{
+        pci::{self, PciTransportIterator, VirtIOSocket},
+        HalImpl,
+    },
 };
 
 fn new_page_table() -> Result<PageTable> {
@@ -107,12 +114,23 @@
 
     let pci_info = PciInfo::from_fdt(fdt)?;
     debug!("PCI: {pci_info:#x?}");
-    let pci_root = pci::initialize(pci_info, MEMORY.lock().as_mut().unwrap())
+    let mut pci_root = pci::initialize(pci_info, MEMORY.lock().as_mut().unwrap())
         .map_err(Error::PciInitializationFailed)?;
     debug!("PCI root: {pci_root:#x?}");
+    let socket_device = find_socket_device::<HalImpl>(&mut pci_root)?;
+    debug!("Found socket device: guest cid = {:?}", socket_device.guest_cid());
     Ok(())
 }
 
+fn find_socket_device<T: Hal>(pci_root: &mut PciRoot) -> Result<VirtIOSocket<T>> {
+    PciTransportIterator::<T>::new(pci_root)
+        .find(|t| DeviceType::Socket == t.device_type())
+        .map(VirtIOSocket::<T>::new)
+        .transpose()
+        .map_err(Error::VirtIOSocketCreationFailed)?
+        .ok_or(Error::MissingVirtIOSocketDevice)
+}
+
 fn try_unshare_all_memory() -> Result<()> {
     info!("Starting unsharing memory...");
 
@@ -147,4 +165,4 @@
 }
 
 main!(main);
-configure_heap!(SIZE_64KB);
+configure_heap!(SIZE_128KB);