[rialto] Enable the host and service VM vsock connection

This CL enables communication between the host and the service VM
over vsock. It also verifies the connection in the integration
tests by sending a message from the host to rialto, which rialto
reverses and sends back.

Test: atest rialto_test
Bug: 274441673
Change-Id: I3e1f4f48c2d8b7fb1b211e0830ff07b5291d4410
diff --git a/rialto/src/communication.rs b/rialto/src/communication.rs
new file mode 100644
index 0000000..f00393d
--- /dev/null
+++ b/rialto/src/communication.rs
@@ -0,0 +1,85 @@
+// Copyright 2023, The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Support for communication between rialto and the host.
+
+use crate::error::{Error, Result};
+use log::info;
+use virtio_drivers::{
+    self,
+    device::socket::{
+        SingleConnectionManager, SocketError, VirtIOSocket, VsockAddr, VsockEventType,
+    },
+    transport::Transport,
+    Hal,
+};
+
+const MAX_RECV_BUFFER_SIZE_BYTES: usize = 64;
+
+pub struct DataChannel<H: Hal, T: Transport> {
+    connection_manager: SingleConnectionManager<H, T>,
+}
+
+impl<H: Hal, T: Transport> From<VirtIOSocket<H, T>> for DataChannel<H, T> {
+    fn from(socket_device_driver: VirtIOSocket<H, T>) -> Self {
+        Self { connection_manager: SingleConnectionManager::new(socket_device_driver) }
+    }
+}
+
+impl<H: Hal, T: Transport> DataChannel<H, T> {
+    /// Connects to the given destination.
+    pub fn connect(&mut self, destination: VsockAddr) -> virtio_drivers::Result {
+        // Use the same port on rialto and host for convenience.
+        self.connection_manager.connect(destination, destination.port)?;
+        self.connection_manager.wait_for_connect()?;
+        info!("Connected to the destination {destination:?}");
+        Ok(())
+    }
+
+    /// Waits for a single incoming request and sends back a reply.
+    pub fn handle_incoming_request(&mut self) -> Result<()> {
+        let mut buffer = [0u8; MAX_RECV_BUFFER_SIZE_BYTES];
+
+        // TODO(b/274441673): Handle the scenario when the given buffer is too short.
+        let len = self.wait_for_recv(&mut buffer).map_err(Error::ReceivingDataFailed)?;
+
+        // TODO(b/291732060): Implement the communication protocol.
+        // Just reverse the received message for now.
+        buffer[..len].reverse();
+        self.connection_manager.send(&buffer[..len])?;
+        Ok(())
+    }
+
+    fn wait_for_recv(&mut self, buffer: &mut [u8]) -> virtio_drivers::Result<usize> {
+        loop {
+            match self.connection_manager.wait_for_recv(buffer)?.event_type {
+                VsockEventType::Disconnected { .. } => {
+                    return Err(SocketError::ConnectionFailed.into())
+                }
+                VsockEventType::Received { length, .. } => return Ok(length),
+                VsockEventType::Connected
+                | VsockEventType::ConnectionRequest
+                | VsockEventType::CreditRequest
+                | VsockEventType::CreditUpdate => {}
+            }
+        }
+    }
+
+    /// Shuts down the data channel.
+    pub fn force_close(&mut self) -> virtio_drivers::Result {
+        self.connection_manager.force_close()?;
+        info!("Connection shutdown.");
+        Ok(())
+    }
+}
diff --git a/rialto/src/error.rs b/rialto/src/error.rs
index 0c1e25d..461870b 100644
--- a/rialto/src/error.rs
+++ b/rialto/src/error.rs
@@ -41,6 +41,10 @@
     VirtIOSocketCreationFailed(virtio_drivers::Error),
     /// Missing socket device.
     MissingVirtIOSocketDevice,
+    /// Failed VirtIO driver operation.
+    VirtIODriverOperationFailed(virtio_drivers::Error),
+    /// Failed to receive data.
+    ReceivingDataFailed(virtio_drivers::Error),
 }
 
 impl fmt::Display for Error {
@@ -58,6 +62,10 @@
                 write!(f, "Failed to create VirtIO Socket device: {e}")
             }
             Self::MissingVirtIOSocketDevice => write!(f, "Missing VirtIO Socket device."),
+            Self::VirtIODriverOperationFailed(e) => {
+                write!(f, "Failed VirtIO driver operation: {e}")
+            }
+            Self::ReceivingDataFailed(e) => write!(f, "Failed to receive data: {e}"),
         }
     }
 }
@@ -91,3 +99,9 @@
         Self::MemoryOperationFailed(e)
     }
 }
+
+impl From<virtio_drivers::Error> for Error {
+    fn from(e: virtio_drivers::Error) -> Self {
+        Self::VirtIODriverOperationFailed(e)
+    }
+}
diff --git a/rialto/src/main.rs b/rialto/src/main.rs
index bbc9997..5c6649a 100644
--- a/rialto/src/main.rs
+++ b/rialto/src/main.rs
@@ -17,11 +17,13 @@
 #![no_main]
 #![no_std]
 
+mod communication;
 mod error;
 mod exceptions;
 
 extern crate alloc;
 
+use crate::communication::DataChannel;
 use crate::error::{Error, Result};
 use core::num::NonZeroUsize;
 use core::slice;
@@ -30,6 +32,7 @@
 use libfdt::FdtError;
 use log::{debug, error, info};
 use virtio_drivers::{
+    device::socket::VsockAddr,
     transport::{pci::bus::PciRoot, DeviceType, Transport},
     Hal,
 };
@@ -46,6 +49,20 @@
     },
 };
 
+fn host_addr() -> VsockAddr {
+    const PROTECTED_VM_PORT: u32 = 5679;
+    const NON_PROTECTED_VM_PORT: u32 = 5680;
+    const VMADDR_CID_HOST: u64 = 2;
+
+    let port = if is_protected_vm() { PROTECTED_VM_PORT } else { NON_PROTECTED_VM_PORT };
+    VsockAddr { cid: VMADDR_CID_HOST, port }
+}
+
+fn is_protected_vm() -> bool {
+    // The VM is treated as protected when the MMIO guard is available.
+    get_mmio_guard().is_some()
+}
+
 fn new_page_table() -> Result<PageTable> {
     let mut page_table = PageTable::default();
 
@@ -119,6 +136,12 @@
     debug!("PCI root: {pci_root:#x?}");
     let socket_device = find_socket_device::<HalImpl>(&mut pci_root)?;
     debug!("Found socket device: guest cid = {:?}", socket_device.guest_cid());
+
+    let mut data_channel = DataChannel::from(socket_device);
+    data_channel.connect(host_addr())?;
+    data_channel.handle_incoming_request()?;
+    data_channel.force_close()?;
+
     Ok(())
 }