[rialto] Enable the host and service VM vsock connection

This CL enables communication between the host and the
service VM. It also verifies the connection in the integration
tests by sending a message from the host to rialto, which
reverses it and sends it back.

Test: atest rialto_test
Bug: 274441673
Change-Id: I3e1f4f48c2d8b7fb1b211e0830ff07b5291d4410
diff --git a/rialto/Android.bp b/rialto/Android.bp
index 1840278..ed9a284 100644
--- a/rialto/Android.bp
+++ b/rialto/Android.bp
@@ -101,6 +101,7 @@
         "liblog_rust",
         "libnix",
         "libvmclient",
+        "libvsock",
     ],
     data: [
         ":rialto_bin",
diff --git a/rialto/AndroidTest.xml b/rialto/AndroidTest.xml
new file mode 100644
index 0000000..43c4c90
--- /dev/null
+++ b/rialto/AndroidTest.xml
@@ -0,0 +1,35 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!-- Copyright (C) 2023 The Android Open Source Project
+
+     Licensed under the Apache License, Version 2.0 (the "License");
+     you may not use this file except in compliance with the License.
+     You may obtain a copy of the License at
+
+          http://www.apache.org/licenses/LICENSE-2.0
+
+     Unless required by applicable law or agreed to in writing, software
+     distributed under the License is distributed on an "AS IS" BASIS,
+     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+     See the License for the specific language governing permissions and
+     limitations under the License.
+-->
+
+<configuration description="Config for rialto_test">
+  <!--
+    We need root privileges to bypass SELinux, because the shell domain cannot create
+    vsock sockets. Otherwise, we hit the following error:
+
+    avc:  denied  { create } for  scontext=u:r:shell:s0 tcontext=u:r:shell:s0
+     tclass=vsock_socket permissive=0
+  -->
+  <target_preparer class="com.android.tradefed.targetprep.RootTargetPreparer"/>
+
+  <target_preparer class="com.android.tradefed.targetprep.PushFilePreparer">
+    <option name="push-file" key="rialto_test" value="/data/local/tmp/rialto_test" />
+  </target_preparer>
+
+  <test class="com.android.tradefed.testtype.rust.RustBinaryTest" >
+    <option name="test-device-path" value="/data/local/tmp" />
+    <option name="module-name" value="rialto_test" />
+  </test>
+</configuration>
\ No newline at end of file
diff --git a/rialto/src/communication.rs b/rialto/src/communication.rs
new file mode 100644
index 0000000..f00393d
--- /dev/null
+++ b/rialto/src/communication.rs
@@ -0,0 +1,109 @@
+// Copyright 2023, The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Supports the communication between rialto and the host over a virtio-vsock
+//! connection.
+
+use crate::error::{Error, Result};
+use log::info;
+use virtio_drivers::{
+    self,
+    device::socket::{
+        SingleConnectionManager, SocketError, VirtIOSocket, VsockAddr, VsockEventType,
+    },
+    transport::Transport,
+    Hal,
+};
+
+/// Size of the buffer used to receive a single request from the host.
+const MAX_RECV_BUFFER_SIZE_BYTES: usize = 64;
+
+/// A data channel between rialto and the host, backed by a single vsock connection.
+pub struct DataChannel<H: Hal, T: Transport> {
+    connection_manager: SingleConnectionManager<H, T>,
+}
+
+impl<H: Hal, T: Transport> From<VirtIOSocket<H, T>> for DataChannel<H, T> {
+    /// Wraps the socket device driver in a connection manager.
+    ///
+    /// Call `DataChannel::connect` to establish the connection.
+    fn from(socket_device_driver: VirtIOSocket<H, T>) -> Self {
+        Self { connection_manager: SingleConnectionManager::new(socket_device_driver) }
+    }
+}
+
+impl<H: Hal, T: Transport> DataChannel<H, T> {
+    /// Connects to the given destination and waits for the connection to be established.
+    ///
+    /// Rialto uses the destination port as its own local port.
+    ///
+    /// # Errors
+    ///
+    /// Returns the driver error if sending the connection request or waiting for it fails.
+    pub fn connect(&mut self, destination: VsockAddr) -> virtio_drivers::Result {
+        // Use the same port on rialto and host for convenience.
+        self.connection_manager.connect(destination, destination.port)?;
+        self.connection_manager.wait_for_connect()?;
+        info!("Connected to the destination {destination:?}");
+        Ok(())
+    }
+
+    /// Receives a single request and sends back a reply.
+    ///
+    /// The request must fit in a buffer of `MAX_RECV_BUFFER_SIZE_BYTES` bytes. For now, the
+    /// reply is simply the received bytes in reverse order.
+    ///
+    /// # Errors
+    ///
+    /// Returns `Error::ReceivingDataFailed` if receiving fails or the peer disconnects, and
+    /// `Error::VirtIODriverOperationFailed` if sending the reply fails.
+    pub fn handle_incoming_request(&mut self) -> Result<()> {
+        let mut buffer = [0u8; MAX_RECV_BUFFER_SIZE_BYTES];
+
+        // TODO(b/274441673): Handle the scenario when the given buffer is too short.
+        let len = self.wait_for_recv(&mut buffer).map_err(Error::ReceivingDataFailed)?;
+
+        // TODO(b/291732060): Implement the communication protocol.
+        // Just reverse the received message for now.
+        buffer[..len].reverse();
+        self.connection_manager.send(&buffer[..len])?;
+        Ok(())
+    }
+
+    /// Waits until data is received into `buffer` and returns the number of bytes received.
+    ///
+    /// Connection and credit events are skipped; a disconnection is reported as
+    /// `SocketError::ConnectionFailed`.
+    fn wait_for_recv(&mut self, buffer: &mut [u8]) -> virtio_drivers::Result<usize> {
+        loop {
+            match self.connection_manager.wait_for_recv(buffer)?.event_type {
+                VsockEventType::Disconnected { .. } => {
+                    return Err(SocketError::ConnectionFailed.into())
+                }
+                VsockEventType::Received { length, .. } => return Ok(length),
+                VsockEventType::Connected
+                | VsockEventType::ConnectionRequest
+                | VsockEventType::CreditRequest
+                | VsockEventType::CreditUpdate => {}
+            }
+        }
+    }
+
+    /// Shuts down the data channel by force-closing the underlying connection.
+    pub fn force_close(&mut self) -> virtio_drivers::Result {
+        self.connection_manager.force_close()?;
+        info!("Connection shutdown.");
+        Ok(())
+    }
+}
diff --git a/rialto/src/error.rs b/rialto/src/error.rs
index 0c1e25d..461870b 100644
--- a/rialto/src/error.rs
+++ b/rialto/src/error.rs
@@ -41,6 +41,10 @@
     VirtIOSocketCreationFailed(virtio_drivers::Error),
     /// Missing socket device.
     MissingVirtIOSocketDevice,
+    /// Failed VirtIO driver operation.
+    VirtIODriverOperationFailed(virtio_drivers::Error),
+    /// Failed to receive data.
+    ReceivingDataFailed(virtio_drivers::Error),
 }
 
 impl fmt::Display for Error {
@@ -58,6 +62,10 @@
                 write!(f, "Failed to create VirtIO Socket device: {e}")
             }
             Self::MissingVirtIOSocketDevice => write!(f, "Missing VirtIO Socket device."),
+            Self::VirtIODriverOperationFailed(e) => {
+                write!(f, "Failed VirtIO driver operation: {e}")
+            }
+            Self::ReceivingDataFailed(e) => write!(f, "Failed to receive data: {e}"),
         }
     }
 }
@@ -91,3 +99,11 @@
         Self::MemoryOperationFailed(e)
     }
 }
+
+/// Wraps a VirtIO driver error in `Error::VirtIODriverOperationFailed`, so that `?` can be
+/// used on VirtIO driver calls.
+impl From<virtio_drivers::Error> for Error {
+    fn from(e: virtio_drivers::Error) -> Self {
+        Self::VirtIODriverOperationFailed(e)
+    }
+}
diff --git a/rialto/src/main.rs b/rialto/src/main.rs
index bbc9997..5c6649a 100644
--- a/rialto/src/main.rs
+++ b/rialto/src/main.rs
@@ -17,11 +17,13 @@
 #![no_main]
 #![no_std]
 
+mod communication;
 mod error;
 mod exceptions;
 
 extern crate alloc;
 
+use crate::communication::DataChannel;
 use crate::error::{Error, Result};
 use core::num::NonZeroUsize;
 use core::slice;
@@ -30,6 +32,7 @@
 use libfdt::FdtError;
 use log::{debug, error, info};
 use virtio_drivers::{
+    device::socket::VsockAddr,
     transport::{pci::bus::PciRoot, DeviceType, Transport},
     Hal,
 };
@@ -46,6 +49,26 @@
     },
 };
 
+/// Returns the vsock address of the host that rialto connects to.
+///
+/// The host listens on a different port depending on whether the VM is protected.
+fn host_addr() -> VsockAddr {
+    // These must match the ports the host listens on in rialto/tests/test.rs.
+    const PROTECTED_VM_PORT: u32 = 5679;
+    const NON_PROTECTED_VM_PORT: u32 = 5680;
+    // Same value as `VMADDR_CID_HOST` in the Linux vsock API.
+    const VMADDR_CID_HOST: u64 = 2;
+
+    let port = if is_protected_vm() { PROTECTED_VM_PORT } else { NON_PROTECTED_VM_PORT };
+    VsockAddr { cid: VMADDR_CID_HOST, port }
+}
+
+/// Returns whether rialto is running in a protected VM.
+fn is_protected_vm() -> bool {
+    // Use MMIO support to determine whether the VM is protected.
+    get_mmio_guard().is_some()
+}
+
 fn new_page_table() -> Result<PageTable> {
     let mut page_table = PageTable::default();
 
@@ -119,6 +142,12 @@
     debug!("PCI root: {pci_root:#x?}");
     let socket_device = find_socket_device::<HalImpl>(&mut pci_root)?;
     debug!("Found socket device: guest cid = {:?}", socket_device.guest_cid());
+
+    let mut data_channel = DataChannel::from(socket_device);
+    data_channel.connect(host_addr())?;
+    data_channel.handle_incoming_request()?;
+    data_channel.force_close()?;
+
     Ok(())
 }
 
diff --git a/rialto/tests/test.rs b/rialto/tests/test.rs
index 98c935d..31a0c55 100644
--- a/rialto/tests/test.rs
+++ b/rialto/tests/test.rs
@@ -22,15 +22,21 @@
     },
     binder::{ParcelFileDescriptor, ProcessState},
 };
-use anyhow::{anyhow, Context, Error};
+use anyhow::{anyhow, bail, Context, Error};
 use log::info;
 use std::fs::File;
-use std::io::{self, BufRead, BufReader};
+use std::io::{self, BufRead, BufReader, Read, Write};
 use std::os::unix::io::FromRawFd;
 use std::panic;
 use std::thread;
 use std::time::Duration;
 use vmclient::{DeathReason, VmInstance};
+use vsock::{VsockListener, VMADDR_CID_HOST};
+
+// TODO(b/291732060): Move the port numbers to the common library shared between the host
+// and rialto.
+const PROTECTED_VM_PORT: u32 = 5679;
+const NON_PROTECTED_VM_PORT: u32 = 5680;
 
 const SIGNED_RIALTO_PATH: &str = "/data/local/tmp/rialto_test/arm64/rialto.bin";
 const UNSIGNED_RIALTO_PATH: &str = "/data/local/tmp/rialto_test/arm64/rialto_unsigned.bin";
@@ -124,6 +130,15 @@
     )
     .context("Failed to create VM")?;
 
+    let port = if protected_vm { PROTECTED_VM_PORT } else { NON_PROTECTED_VM_PORT };
+    // Bind the listening socket before starting the VM so that it is ready by the time rialto
+    // tries to connect to the host.
+    info!("Setting up the listening socket on port {port}...");
+    let listener = VsockListener::bind_with_cid_port(VMADDR_CID_HOST, port)
+        .with_context(|| format!("Failed to bind the vsock listener on port {port}"))?;
+    info!("Listening on port {port}...");
+    let check_socket_handle = thread::spawn(move || try_check_socket_connection(listener).unwrap());
+
     vm.start().context("Failed to start VM")?;
 
     // Wait for VM to finish, and check that it shut down cleanly.
@@ -132,6 +147,15 @@
         .ok_or_else(|| anyhow!("Timed out waiting for VM exit"))?;
     assert_eq!(death_reason, DeathReason::Shutdown);
 
+    match check_socket_handle.join() {
+        Ok(_) => {
+            info!(
+                "Received the echoed message. \
+                   The socket connection between the host and the service VM works correctly."
+            )
+        }
+        Err(_) => bail!("The socket connection check failed."),
+    }
     Ok(())
 }
 
@@ -149,3 +173,29 @@
     });
     Ok(writer)
 }
+
+/// Accepts one connection from rialto on `listener`, sends it a message and checks that the
+/// reply is the same message reversed.
+fn try_check_socket_connection(listener: VsockListener) -> Result<(), Error> {
+    let mut vsock_stream = listener
+        .incoming()
+        .next()
+        .context("No incoming vsock connection")?
+        .context("Failed to accept the vsock connection")?;
+    info!("Accepted connection {:?}", vsock_stream);
+
+    let message = "Hello from host";
+    vsock_stream.write_all(message.as_bytes())?;
+    vsock_stream.flush()?;
+    info!("Sent message: {:?}.", message);
+
+    // A single `read` may return fewer bytes than were sent, so read exactly as many bytes as
+    // the expected reply.
+    let mut buffer = vec![0u8; message.len()];
+    vsock_stream.set_read_timeout(Some(Duration::from_millis(1_000)))?;
+    vsock_stream.read_exact(&mut buffer)?;
+
+    buffer.reverse();
+    assert_eq!(message.as_bytes(), &buffer[..]);
+    Ok(())
+}