[rialto] Add communication protocol library for host and rialto
This cl adds a communication protocol library that includes
Request and Response definitions for communication between the
host and rialto. The Request and Response messages are
(de)serialized in CBOR format and exchanged via vsock, enabling
the handling of partial packets.
In addition, an end-to-end test has been included to cover this
functionality. The test verifies the following steps that
correspond to the intended usage of the library:
- The host sends a Request to rialto serialized in CBOR.
- rialto deserializes the Request, executes it, and prepares a
Response.
- rialto sends the Response to the host serialized in CBOR.
- The host deserializes the Response.
Bug: 291732060
Test: atest rialto_test
Change-Id: I5f6412949e34b2431d060703e6dea1b96c92fde5
diff --git a/libs/service_vm_comm/Android.bp b/libs/service_vm_comm/Android.bp
new file mode 100644
index 0000000..18397c5
--- /dev/null
+++ b/libs/service_vm_comm/Android.bp
@@ -0,0 +1,36 @@
+package {
+ default_applicable_licenses: ["Android-Apache-2.0"],
+}
+
+rust_defaults {
+ name: "libservice_vm_comm_defaults",
+ crate_name: "service_vm_comm",
+ srcs: ["src/lib.rs"],
+ prefer_rlib: true,
+ apex_available: [
+ "com.android.virt",
+ ],
+}
+
+rust_library_rlib {
+ name: "libservice_vm_comm_nostd",
+ defaults: ["libservice_vm_comm_defaults"],
+ no_stdlibs: true,
+ stdlibs: [
+ "libcore.rust_sysroot",
+ ],
+ rustlibs: [
+ "libserde_nostd",
+ ],
+}
+
+rust_library {
+ name: "libservice_vm_comm",
+ defaults: ["libservice_vm_comm_defaults"],
+ rustlibs: [
+ "libserde",
+ ],
+ features: [
+ "std",
+ ],
+}
diff --git a/libs/service_vm_comm/src/lib.rs b/libs/service_vm_comm/src/lib.rs
new file mode 100644
index 0000000..c3d3ed5
--- /dev/null
+++ b/libs/service_vm_comm/src/lib.rs
@@ -0,0 +1,24 @@
+// Copyright 2023, The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! This library contains the communication protocol used between the host
+//! and the service VM.
+
+#![cfg_attr(not(feature = "std"), no_std)]
+
+extern crate alloc;
+
+mod message;
+
+pub use message::{Request, Response};
diff --git a/libs/service_vm_comm/src/message.rs b/libs/service_vm_comm/src/message.rs
new file mode 100644
index 0000000..ebbefcb
--- /dev/null
+++ b/libs/service_vm_comm/src/message.rs
@@ -0,0 +1,39 @@
+// Copyright 2023, The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! This module contains the requests and responses definitions exchanged
+//! between the host and the service VM.
+
+use alloc::vec::Vec;
+
+use serde::{Deserialize, Serialize};
+
+/// Represents a request to be sent to the service VM.
+///
+/// Each request has a corresponding response item.
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub enum Request {
+ /// Reverse the order of the bytes in the provided byte array.
+ /// Currently this is only used for testing.
+ Reverse(Vec<u8>),
+}
+
+/// Represents a response to a request sent to the service VM.
+///
+/// Each response corresponds to a specific request.
+#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+pub enum Response {
+ /// Reverse the order of the bytes in the provided byte array.
+ Reverse(Vec<u8>),
+}
diff --git a/rialto/Android.bp b/rialto/Android.bp
index 7ac7136..33c24eb 100644
--- a/rialto/Android.bp
+++ b/rialto/Android.bp
@@ -9,10 +9,13 @@
defaults: ["vmbase_ffi_defaults"],
rustlibs: [
"libaarch64_paging",
+ "libciborium_io_nostd",
+ "libciborium_nostd",
"libhyp",
"libfdtpci",
"liblibfdt",
"liblog_rust_nostd",
+ "libservice_vm_comm_nostd",
"libvirtio_drivers",
"libvmbase",
],
@@ -98,9 +101,11 @@
"android.system.virtualizationservice-rust",
"libandroid_logger",
"libanyhow",
+ "libciborium",
"liblibc",
"liblog_rust",
"libnix",
+ "libservice_vm_comm",
"libvmclient",
"libvsock",
],
diff --git a/rialto/src/communication.rs b/rialto/src/communication.rs
index f00393d..858ccfb 100644
--- a/rialto/src/communication.rs
+++ b/rialto/src/communication.rs
@@ -14,72 +14,174 @@
//! Supports for the communication between rialto and host.
-use crate::error::{Error, Result};
+use crate::error::Result;
+use ciborium_io::{Read, Write};
+use core::hint::spin_loop;
+use core::result;
use log::info;
+use service_vm_comm::{Request, Response};
use virtio_drivers::{
self,
device::socket::{
- SingleConnectionManager, SocketError, VirtIOSocket, VsockAddr, VsockEventType,
+ SocketError, VirtIOSocket, VsockAddr, VsockConnectionManager, VsockEventType,
},
transport::Transport,
Hal,
};
-const MAX_RECV_BUFFER_SIZE_BYTES: usize = 64;
-
-pub struct DataChannel<H: Hal, T: Transport> {
- connection_manager: SingleConnectionManager<H, T>,
+pub struct VsockStream<H: Hal, T: Transport> {
+ connection_manager: VsockConnectionManager<H, T>,
+ /// Peer address. The same port is used on rialto and peer for convenience.
+ peer_addr: VsockAddr,
}
-impl<H: Hal, T: Transport> From<VirtIOSocket<H, T>> for DataChannel<H, T> {
- fn from(socket_device_driver: VirtIOSocket<H, T>) -> Self {
- Self { connection_manager: SingleConnectionManager::new(socket_device_driver) }
+impl<H: Hal, T: Transport> VsockStream<H, T> {
+ pub fn new(
+ socket_device_driver: VirtIOSocket<H, T>,
+ peer_addr: VsockAddr,
+ ) -> virtio_drivers::Result<Self> {
+ let mut vsock_stream = Self {
+ connection_manager: VsockConnectionManager::new(socket_device_driver),
+ peer_addr,
+ };
+ vsock_stream.connect()?;
+ Ok(vsock_stream)
}
-}
-impl<H: Hal, T: Transport> DataChannel<H, T> {
- /// Connects to the given destination.
- pub fn connect(&mut self, destination: VsockAddr) -> virtio_drivers::Result {
- // Use the same port on rialto and host for convenience.
- self.connection_manager.connect(destination, destination.port)?;
- self.connection_manager.wait_for_connect()?;
- info!("Connected to the destination {destination:?}");
+ fn connect(&mut self) -> virtio_drivers::Result {
+ self.connection_manager.connect(self.peer_addr, self.peer_addr.port)?;
+ self.wait_for_connect()?;
+ info!("Connected to the peer {:?}", self.peer_addr);
Ok(())
}
- /// Processes the received requests and sends back a reply.
- pub fn handle_incoming_request(&mut self) -> Result<()> {
- let mut buffer = [0u8; MAX_RECV_BUFFER_SIZE_BYTES];
-
- // TODO(b/274441673): Handle the scenario when the given buffer is too short.
- let len = self.wait_for_recv(&mut buffer).map_err(Error::ReceivingDataFailed)?;
-
- // TODO(b/291732060): Implement the communication protocol.
- // Just reverse the received message for now.
- buffer[..len].reverse();
- self.connection_manager.send(&buffer[..len])?;
- Ok(())
- }
-
- fn wait_for_recv(&mut self, buffer: &mut [u8]) -> virtio_drivers::Result<usize> {
+ fn wait_for_connect(&mut self) -> virtio_drivers::Result {
loop {
- match self.connection_manager.wait_for_recv(buffer)?.event_type {
- VsockEventType::Disconnected { .. } => {
- return Err(SocketError::ConnectionFailed.into())
+ if let Some(event) = self.poll_event_from_peer()? {
+ match event {
+ VsockEventType::Connected => return Ok(()),
+ VsockEventType::Disconnected { .. } => {
+ return Err(SocketError::ConnectionFailed.into())
+ }
+ // We shouldn't receive the following event before the connection is
+ // established.
+ VsockEventType::ConnectionRequest | VsockEventType::Received { .. } => {
+ return Err(SocketError::InvalidOperation.into())
+ }
+ // We can receive credit requests and updates at any time.
+ // This can be ignored as the connection manager handles them in poll().
+ VsockEventType::CreditRequest | VsockEventType::CreditUpdate => {}
}
- VsockEventType::Received { length, .. } => return Ok(length),
- VsockEventType::Connected
- | VsockEventType::ConnectionRequest
- | VsockEventType::CreditRequest
- | VsockEventType::CreditUpdate => {}
+ } else {
+ spin_loop();
}
}
}
+ pub fn read_request(&mut self) -> Result<Request> {
+ Ok(ciborium::from_reader(self)?)
+ }
+
+ pub fn write_response(&mut self, response: &Response) -> Result<()> {
+ Ok(ciborium::into_writer(response, self)?)
+ }
+
/// Shuts down the data channel.
- pub fn force_close(&mut self) -> virtio_drivers::Result {
- self.connection_manager.force_close()?;
+ pub fn shutdown(&mut self) -> virtio_drivers::Result {
+ self.connection_manager.force_close(self.peer_addr, self.peer_addr.port)?;
info!("Connection shutdown.");
Ok(())
}
+
+ fn recv(&mut self, buffer: &mut [u8]) -> virtio_drivers::Result<usize> {
+ self.connection_manager.recv(self.peer_addr, self.peer_addr.port, buffer)
+ }
+
+ fn wait_for_send(&mut self, buffer: &[u8]) -> virtio_drivers::Result {
+ const INSUFFICIENT_BUFFER_SPACE_ERROR: virtio_drivers::Error =
+ virtio_drivers::Error::SocketDeviceError(SocketError::InsufficientBufferSpaceInPeer);
+ loop {
+ match self.connection_manager.send(self.peer_addr, self.peer_addr.port, buffer) {
+ Ok(_) => return Ok(()),
+ Err(INSUFFICIENT_BUFFER_SPACE_ERROR) => {
+ self.poll()?;
+ }
+ Err(e) => return Err(e),
+ }
+ }
+ }
+
+ fn wait_for_recv(&mut self) -> virtio_drivers::Result {
+ loop {
+ match self.poll()? {
+ Some(VsockEventType::Received { .. }) => return Ok(()),
+ _ => spin_loop(),
+ }
+ }
+ }
+
+ /// Polls the rx queue after the connection is established with the peer, this function
+ /// rejects some invalid events. The valid events are handled inside the connection
+ /// manager.
+ fn poll(&mut self) -> virtio_drivers::Result<Option<VsockEventType>> {
+ if let Some(event) = self.poll_event_from_peer()? {
+ match event {
+ VsockEventType::Disconnected { .. } => Err(SocketError::ConnectionFailed.into()),
+ VsockEventType::Connected | VsockEventType::ConnectionRequest => {
+ Err(SocketError::InvalidOperation.into())
+ }
+ // When there is a received event, the received data is buffered in the
+ // connection manager's internal receive buffer, so we don't need to do
+ // anything here.
+ // The credit request and updates also handled inside the connection
+ // manager.
+ VsockEventType::Received { .. }
+ | VsockEventType::CreditRequest
+ | VsockEventType::CreditUpdate => Ok(Some(event)),
+ }
+ } else {
+ Ok(None)
+ }
+ }
+
+ fn poll_event_from_peer(&mut self) -> virtio_drivers::Result<Option<VsockEventType>> {
+ Ok(self.connection_manager.poll()?.map(|event| {
+ assert_eq!(event.source, self.peer_addr);
+ assert_eq!(event.destination.port, self.peer_addr.port);
+ event.event_type
+ }))
+ }
+}
+
+impl<H: Hal, T: Transport> Read for VsockStream<H, T> {
+ type Error = virtio_drivers::Error;
+
+ fn read_exact(&mut self, data: &mut [u8]) -> result::Result<(), Self::Error> {
+ let mut start = 0;
+ while start < data.len() {
+ let len = self.recv(&mut data[start..])?;
+ let len = if len == 0 {
+ self.wait_for_recv()?;
+ self.recv(&mut data[start..])?
+ } else {
+ len
+ };
+ start += len;
+ }
+ Ok(())
+ }
+}
+
+impl<H: Hal, T: Transport> Write for VsockStream<H, T> {
+ type Error = virtio_drivers::Error;
+
+ fn write_all(&mut self, data: &[u8]) -> result::Result<(), Self::Error> {
+ self.wait_for_send(data)
+ }
+
+ fn flush(&mut self) -> result::Result<(), Self::Error> {
+ // TODO(b/293411448): Optimize the data sending by saving the data to write
+ // in a local buffer and then flushing only when the buffer is full.
+ Ok(())
+ }
}
diff --git a/rialto/src/error.rs b/rialto/src/error.rs
index 461870b..23667ed 100644
--- a/rialto/src/error.rs
+++ b/rialto/src/error.rs
@@ -23,7 +23,10 @@
pub type Result<T> = result::Result<T, Error>;
-#[derive(Clone, Debug)]
+type CiboriumSerError = ciborium::ser::Error<virtio_drivers::Error>;
+type CiboriumDeError = ciborium::de::Error<virtio_drivers::Error>;
+
+#[derive(Debug)]
pub enum Error {
/// Hypervisor error.
Hypervisor(HypervisorError),
@@ -43,8 +46,10 @@
MissingVirtIOSocketDevice,
/// Failed VirtIO driver operation.
VirtIODriverOperationFailed(virtio_drivers::Error),
- /// Failed to receive data.
- ReceivingDataFailed(virtio_drivers::Error),
+ /// Failed to serialize.
+ SerializationFailed(CiboriumSerError),
+ /// Failed to deserialize.
+ DeserializationFailed(CiboriumDeError),
}
impl fmt::Display for Error {
@@ -65,7 +70,8 @@
Self::VirtIODriverOperationFailed(e) => {
write!(f, "Failed VirtIO driver operation: {e}")
}
- Self::ReceivingDataFailed(e) => write!(f, "Failed to receive data: {e}"),
+ Self::SerializationFailed(e) => write!(f, "Failed to serialize: {e}"),
+ Self::DeserializationFailed(e) => write!(f, "Failed to deserialize: {e}"),
}
}
}
@@ -105,3 +111,15 @@
Self::VirtIODriverOperationFailed(e)
}
}
+
+impl From<CiboriumSerError> for Error {
+ fn from(e: CiboriumSerError) -> Self {
+ Self::SerializationFailed(e)
+ }
+}
+
+impl From<CiboriumDeError> for Error {
+ fn from(e: CiboriumDeError) -> Self {
+ Self::DeserializationFailed(e)
+ }
+}
diff --git a/rialto/src/main.rs b/rialto/src/main.rs
index 5c6649a..1c5090a 100644
--- a/rialto/src/main.rs
+++ b/rialto/src/main.rs
@@ -23,7 +23,7 @@
extern crate alloc;
-use crate::communication::DataChannel;
+use crate::communication::VsockStream;
use crate::error::{Error, Result};
use core::num::NonZeroUsize;
use core::slice;
@@ -31,6 +31,7 @@
use hyp::{get_mem_sharer, get_mmio_guard};
use libfdt::FdtError;
use log::{debug, error, info};
+use service_vm_comm::{Request, Response};
use virtio_drivers::{
device::socket::VsockAddr,
transport::{pci::bus::PciRoot, DeviceType, Transport},
@@ -137,10 +138,12 @@
let socket_device = find_socket_device::<HalImpl>(&mut pci_root)?;
debug!("Found socket device: guest cid = {:?}", socket_device.guest_cid());
- let mut data_channel = DataChannel::from(socket_device);
- data_channel.connect(host_addr())?;
- data_channel.handle_incoming_request()?;
- data_channel.force_close()?;
+ let mut vsock_stream = VsockStream::new(socket_device, host_addr())?;
+ let response = match vsock_stream.read_request()? {
+ Request::Reverse(v) => Response::Reverse(v.into_iter().rev().collect()),
+ };
+ vsock_stream.write_response(&response)?;
+ vsock_stream.shutdown()?;
Ok(())
}
diff --git a/rialto/tests/test.rs b/rialto/tests/test.rs
index 8089016..2bd8968 100644
--- a/rialto/tests/test.rs
+++ b/rialto/tests/test.rs
@@ -22,10 +22,11 @@
},
binder::{ParcelFileDescriptor, ProcessState},
};
-use anyhow::{anyhow, bail, Context, Error};
+use anyhow::{anyhow, bail, Context, Result};
use log::info;
+use service_vm_comm::{Request, Response};
use std::fs::File;
-use std::io::{self, BufRead, BufReader, Read, Write};
+use std::io::{self, BufRead, BufReader, BufWriter, Write};
use std::os::unix::io::FromRawFd;
use std::panic;
use std::thread;
@@ -44,7 +45,7 @@
const INSTANCE_IMG_SIZE: i64 = 1024 * 1024; // 1MB
#[test]
-fn boot_rialto_in_protected_vm_successfully() -> Result<(), Error> {
+fn boot_rialto_in_protected_vm_successfully() -> Result<()> {
boot_rialto_successfully(
SIGNED_RIALTO_PATH,
true, // protected_vm
@@ -52,14 +53,14 @@
}
#[test]
-fn boot_rialto_in_unprotected_vm_successfully() -> Result<(), Error> {
+fn boot_rialto_in_unprotected_vm_successfully() -> Result<()> {
boot_rialto_successfully(
UNSIGNED_RIALTO_PATH,
false, // protected_vm
)
}
-fn boot_rialto_successfully(rialto_path: &str, protected_vm: bool) -> Result<(), Error> {
+fn boot_rialto_successfully(rialto_path: &str, protected_vm: bool) -> Result<()> {
android_logger::init_once(
android_logger::Config::default().with_tag("rialto").with_min_level(log::Level::Debug),
);
@@ -169,27 +170,31 @@
Ok(writer)
}
-fn try_check_socket_connection(port: u32) -> Result<(), Error> {
+fn try_check_socket_connection(port: u32) -> Result<()> {
info!("Setting up the listening socket on port {port}...");
let listener = VsockListener::bind_with_cid_port(VMADDR_CID_HOST, port)?;
info!("Listening on port {port}...");
- let Some(Ok(mut vsock_stream)) = listener.incoming().next() else {
- bail!("Failed to get vsock_stream");
- };
+ let mut vsock_stream =
+ listener.incoming().next().ok_or_else(|| anyhow!("Failed to get vsock_stream"))??;
info!("Accepted connection {:?}", vsock_stream);
-
- let message = "Hello from host";
- vsock_stream.write_all(message.as_bytes())?;
- vsock_stream.flush()?;
- info!("Sent message: {:?}.", message);
-
- let mut buffer = vec![0u8; 30];
vsock_stream.set_read_timeout(Some(Duration::from_millis(1_000)))?;
- let len = vsock_stream.read(&mut buffer)?;
- assert_eq!(message.len(), len);
- buffer[..len].reverse();
- assert_eq!(message.as_bytes(), &buffer[..len]);
+ const WRITE_BUFFER_CAPACITY: usize = 512;
+ let mut buffer = BufWriter::with_capacity(WRITE_BUFFER_CAPACITY, vsock_stream.clone());
+
+ // TODO(b/292080257): Test with message longer than the receiver's buffer capacity
+ // 1024 bytes once the guest virtio-vsock driver fixes the credit update in recv().
+ let message = "abc".repeat(166);
+ let request = Request::Reverse(message.as_bytes().to_vec());
+ ciborium::into_writer(&request, &mut buffer)?;
+ buffer.flush()?;
+ info!("Sent request: {request:?}.");
+
+ let response: Response = ciborium::from_reader(&mut vsock_stream)?;
+ info!("Received response: {response:?}.");
+
+ let expected_response: Vec<u8> = message.as_bytes().iter().rev().cloned().collect();
+ assert_eq!(Response::Reverse(expected_response), response);
Ok(())
}