Merge "[rialto] Enable the host and service VM vsock connection" into main
diff --git a/rialto/Android.bp b/rialto/Android.bp
index 1840278..ed9a284 100644
--- a/rialto/Android.bp
+++ b/rialto/Android.bp
@@ -101,6 +101,7 @@
"liblog_rust",
"libnix",
"libvmclient",
+ "libvsock",
],
data: [
":rialto_bin",
diff --git a/rialto/AndroidTest.xml b/rialto/AndroidTest.xml
new file mode 100644
index 0000000..43c4c90
--- /dev/null
+++ b/rialto/AndroidTest.xml
@@ -0,0 +1,35 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!-- Copyright (C) 2023 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<configuration description="Config for rialto_test">
+ <!--
+        We need root privileges to bypass SELinux, because the shell domain cannot create vsock sockets.
+ Otherwise, we hit the following errors:
+
+ avc: denied { create } for scontext=u:r:shell:s0 tcontext=u:r:shell:s0
+ tclass=vsock_socket permissive=0
+ -->
+ <target_preparer class="com.android.tradefed.targetprep.RootTargetPreparer"/>
+
+ <target_preparer class="com.android.tradefed.targetprep.PushFilePreparer">
+ <option name="push-file" key="rialto_test" value="/data/local/tmp/rialto_test" />
+ </target_preparer>
+
+ <test class="com.android.tradefed.testtype.rust.RustBinaryTest" >
+ <option name="test-device-path" value="/data/local/tmp" />
+ <option name="module-name" value="rialto_test" />
+ </test>
+</configuration>
\ No newline at end of file
diff --git a/rialto/src/communication.rs b/rialto/src/communication.rs
new file mode 100644
index 0000000..f00393d
--- /dev/null
+++ b/rialto/src/communication.rs
@@ -0,0 +1,99 @@
+// Copyright 2023, The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Support for communication between rialto and the host.
+
+use crate::error::{Error, Result};
+use log::info;
+use virtio_drivers::{
+    self,
+    device::socket::{
+        SingleConnectionManager, SocketError, VirtIOSocket, VsockAddr, VsockEventType,
+    },
+    transport::Transport,
+    Hal,
+};
+
+/// Size in bytes of the buffer that receives an incoming request.
+const MAX_RECV_BUFFER_SIZE_BYTES: usize = 64;
+
+/// A vsock data channel between rialto and the host, managed as a single connection.
+pub struct DataChannel<H: Hal, T: Transport> {
+    connection_manager: SingleConnectionManager<H, T>,
+}
+
+impl<H: Hal, T: Transport> From<VirtIOSocket<H, T>> for DataChannel<H, T> {
+    /// Wraps the given VirtIO socket device driver in a single-connection manager.
+    fn from(socket_device_driver: VirtIOSocket<H, T>) -> Self {
+        Self { connection_manager: SingleConnectionManager::new(socket_device_driver) }
+    }
+}
+
+impl<H: Hal, T: Transport> DataChannel<H, T> {
+    /// Connects to the given destination and waits for the connection to be established.
+    pub fn connect(&mut self, destination: VsockAddr) -> virtio_drivers::Result {
+        // Use the same port on rialto and host for convenience.
+        self.connection_manager.connect(destination, destination.port)?;
+        self.connection_manager.wait_for_connect()?;
+        info!("Connected to the destination {destination:?}");
+        Ok(())
+    }
+
+    /// Waits for a single incoming request and sends back a reply.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`Error::ReceivingDataFailed`] if receiving fails (including when the peer
+    /// disconnects before sending data), and [`Error::VirtIODriverOperationFailed`] if the
+    /// reply cannot be sent.
+    pub fn handle_incoming_request(&mut self) -> Result<()> {
+        let mut buffer = [0u8; MAX_RECV_BUFFER_SIZE_BYTES];
+
+        // TODO(b/274441673): Handle the scenario when the given buffer is too short.
+        let len = self.wait_for_recv(&mut buffer).map_err(Error::ReceivingDataFailed)?;
+
+        // TODO(b/291732060): Implement the communication protocol.
+        // Just reverse the received message for now.
+        buffer[..len].reverse();
+        self.connection_manager.send(&buffer[..len])?;
+        Ok(())
+    }
+
+    /// Waits until data is received into `buffer` and returns the number of bytes received.
+    ///
+    /// Connection and credit events are skipped; a disconnection is reported as
+    /// [`SocketError::ConnectionFailed`].
+    fn wait_for_recv(&mut self, buffer: &mut [u8]) -> virtio_drivers::Result<usize> {
+        loop {
+            match self.connection_manager.wait_for_recv(buffer)?.event_type {
+                VsockEventType::Disconnected { .. } => {
+                    return Err(SocketError::ConnectionFailed.into());
+                }
+                VsockEventType::Received { length, .. } => return Ok(length),
+                // The remaining events carry no data; keep waiting.
+                VsockEventType::Connected
+                | VsockEventType::ConnectionRequest
+                | VsockEventType::CreditRequest
+                | VsockEventType::CreditUpdate => {}
+            }
+        }
+    }
+
+    /// Shuts down the data channel.
+    pub fn force_close(&mut self) -> virtio_drivers::Result {
+        self.connection_manager.force_close()?;
+        info!("Connection shutdown.");
+        Ok(())
+    }
+}
diff --git a/rialto/src/error.rs b/rialto/src/error.rs
index 0c1e25d..461870b 100644
--- a/rialto/src/error.rs
+++ b/rialto/src/error.rs
@@ -41,6 +41,10 @@
VirtIOSocketCreationFailed(virtio_drivers::Error),
/// Missing socket device.
MissingVirtIOSocketDevice,
+ /// Failed VirtIO driver operation.
+ VirtIODriverOperationFailed(virtio_drivers::Error),
+ /// Failed to receive data.
+ ReceivingDataFailed(virtio_drivers::Error),
}
impl fmt::Display for Error {
@@ -58,6 +62,10 @@
write!(f, "Failed to create VirtIO Socket device: {e}")
}
Self::MissingVirtIOSocketDevice => write!(f, "Missing VirtIO Socket device."),
+ Self::VirtIODriverOperationFailed(e) => {
+ write!(f, "Failed VirtIO driver operation: {e}")
+ }
+ Self::ReceivingDataFailed(e) => write!(f, "Failed to receive data: {e}"),
}
}
}
@@ -91,3 +99,9 @@
Self::MemoryOperationFailed(e)
}
}
+
+impl From<virtio_drivers::Error> for Error {
+    fn from(source: virtio_drivers::Error) -> Self {
+        Self::VirtIODriverOperationFailed(source)
+    }
+}
diff --git a/rialto/src/main.rs b/rialto/src/main.rs
index bbc9997..5c6649a 100644
--- a/rialto/src/main.rs
+++ b/rialto/src/main.rs
@@ -17,11 +17,13 @@
#![no_main]
#![no_std]
+mod communication;
mod error;
mod exceptions;
extern crate alloc;
+use crate::communication::DataChannel;
use crate::error::{Error, Result};
use core::num::NonZeroUsize;
use core::slice;
@@ -30,6 +32,7 @@
use libfdt::FdtError;
use log::{debug, error, info};
use virtio_drivers::{
+ device::socket::VsockAddr,
transport::{pci::bus::PciRoot, DeviceType, Transport},
Hal,
};
@@ -46,6 +49,20 @@
},
};
+fn host_addr() -> VsockAddr {
+    const PROTECTED_VM_PORT: u32 = 5679;
+    const NON_PROTECTED_VM_PORT: u32 = 5680;
+    const VMADDR_CID_HOST: u64 = 2;
+    // Host CID plus a per-VM-type port; the ports must match those the host test listens on.
+    let port = if is_protected_vm() { PROTECTED_VM_PORT } else { NON_PROTECTED_VM_PORT };
+    VsockAddr { cid: VMADDR_CID_HOST, port }
+}
+
+fn is_protected_vm() -> bool {
+    // Treat the VM as protected exactly when an MMIO guard is available.
+    get_mmio_guard().is_some()
+}
+
fn new_page_table() -> Result<PageTable> {
let mut page_table = PageTable::default();
@@ -119,6 +136,12 @@
debug!("PCI root: {pci_root:#x?}");
let socket_device = find_socket_device::<HalImpl>(&mut pci_root)?;
debug!("Found socket device: guest cid = {:?}", socket_device.guest_cid());
+
+ let mut data_channel = DataChannel::from(socket_device);
+ data_channel.connect(host_addr())?;
+ data_channel.handle_incoming_request()?;
+ data_channel.force_close()?;
+
Ok(())
}
diff --git a/rialto/tests/test.rs b/rialto/tests/test.rs
index 2dbd0cb..4ad8eb8 100644
--- a/rialto/tests/test.rs
+++ b/rialto/tests/test.rs
@@ -22,15 +22,21 @@
},
binder::{ParcelFileDescriptor, ProcessState},
};
-use anyhow::{anyhow, Context, Error};
+use anyhow::{anyhow, bail, Context, Error};
use log::info;
use std::fs::File;
-use std::io::{self, BufRead, BufReader};
+use std::io::{self, BufRead, BufReader, Read, Write};
use std::os::unix::io::FromRawFd;
use std::panic;
use std::thread;
use std::time::Duration;
use vmclient::{DeathReason, VmInstance};
+use vsock::{VsockListener, VMADDR_CID_HOST};
+
+// TODO(b/291732060): Move the port numbers to the common library shared between the host
+// and rialto.
+const PROTECTED_VM_PORT: u32 = 5679;
+const NON_PROTECTED_VM_PORT: u32 = 5680;
const SIGNED_RIALTO_PATH: &str = "/data/local/tmp/rialto_test/arm64/rialto.bin";
const UNSIGNED_RIALTO_PATH: &str = "/data/local/tmp/rialto_test/arm64/rialto_unsigned.bin";
@@ -124,6 +130,9 @@
)
.context("Failed to create VM")?;
+ let port = if protected_vm { PROTECTED_VM_PORT } else { NON_PROTECTED_VM_PORT };
+ let check_socket_handle = thread::spawn(move || try_check_socket_connection(port).unwrap());
+
vm.start().context("Failed to start VM")?;
// Wait for VM to finish, and check that it shut down cleanly.
@@ -132,6 +141,15 @@
.ok_or_else(|| anyhow!("Timed out waiting for VM exit"))?;
assert_eq!(death_reason, DeathReason::Shutdown);
+ match check_socket_handle.join() {
+ Ok(_) => {
+ info!(
+ "Received the echoed message. \
+ The socket connection between the host and the service VM works correctly."
+ )
+ }
+ Err(_) => bail!("The socket connection check failed."),
+ }
Ok(())
}
@@ -150,3 +168,28 @@
});
Ok(writer)
}
+
+/// Accepts a vsock connection on `port` and checks that the peer echoes a message reversed.
+fn try_check_socket_connection(port: u32) -> Result<(), Error> {
+    info!("Setting up the listening socket on port {port}...");
+    let listener = VsockListener::bind_with_cid_port(VMADDR_CID_HOST, port)?;
+    info!("Listening on port {port}...");
+
+    let Some(Ok(mut vsock_stream)) = listener.incoming().next() else {
+        bail!("Failed to get vsock_stream");
+    };
+    info!("Accepted connection {:?}", vsock_stream);
+
+    let message = "Hello from host";
+    vsock_stream.write_all(message.as_bytes())?;
+    vsock_stream.flush()?;
+    info!("Sent message: {:?}.", message);
+
+    // A single `read` may return only part of the reply, so read exactly `message.len()` bytes.
+    let mut buffer = vec![0u8; message.len()];
+    vsock_stream.set_read_timeout(Some(Duration::from_millis(1_000)))?;
+    vsock_stream.read_exact(&mut buffer)?;
+    buffer.reverse();
+    assert_eq!(message.as_bytes(), &buffer[..]);
+    Ok(())
+}
diff --git a/vmbase/src/memory/shared.rs b/vmbase/src/memory/shared.rs
index dfa29e4..f4c9f72 100644
--- a/vmbase/src/memory/shared.rs
+++ b/vmbase/src/memory/shared.rs
@@ -341,7 +341,11 @@
/// Allocates a memory range of at least the given size and alignment that is shared with the host.
/// Returns a pointer to the buffer.
pub(crate) fn alloc_shared(layout: Layout) -> hyp::Result<NonNull<u8>> {
- assert_ne!(layout.size(), 0);
+ // TODO(b/291586508): We have temporarily removed the non-zero check for layout.size() to
+ // enable the Rialto socket device to connect or shut down, as the socket driver adds empty
+ // buffers in these scenarios.
+ // We will add the check back once this issue is fixed in the driver.
+
let Some(buffer) = try_shared_alloc(layout) else {
handle_alloc_error(layout);
};