Merge changes Ia242c505,I5f641294 into main
* changes:
[rialto] Move request processing to a separate module
[rialto] Add communication protocol library for host and rialto
diff --git a/libs/service_vm_comm/Android.bp b/libs/service_vm_comm/Android.bp
new file mode 100644
index 0000000..18397c5
--- /dev/null
+++ b/libs/service_vm_comm/Android.bp
@@ -0,0 +1,36 @@
+package { // File-wide settings for the modules declared below.
+ default_applicable_licenses: ["Android-Apache-2.0"],
+}
+
+rust_defaults { // Settings shared by both library variants that list this module in `defaults`.
+ name: "libservice_vm_comm_defaults",
+ crate_name: "service_vm_comm",
+ srcs: ["src/lib.rs"],
+ prefer_rlib: true,
+ apex_available: [
+ "com.android.virt",
+ ],
+}
+
+rust_library_rlib { // Variant built without the "std" feature, so src/lib.rs compiles as no_std.
+ name: "libservice_vm_comm_nostd",
+ defaults: ["libservice_vm_comm_defaults"],
+ no_stdlibs: true,
+ stdlibs: [
+ "libcore.rust_sysroot",
+ ],
+ rustlibs: [
+ "libserde_nostd",
+ ],
+}
+
+rust_library { // Variant built with the "std" feature, so src/lib.rs does not apply no_std.
+ name: "libservice_vm_comm",
+ defaults: ["libservice_vm_comm_defaults"],
+ rustlibs: [
+ "libserde",
+ ],
+ features: [
+ "std",
+ ],
+}
diff --git a/libs/service_vm_comm/src/lib.rs b/libs/service_vm_comm/src/lib.rs
new file mode 100644
index 0000000..c3d3ed5
--- /dev/null
+++ b/libs/service_vm_comm/src/lib.rs
@@ -0,0 +1,24 @@
+// Copyright 2023, The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! This library contains the communication protocol used between the host
+//! and the service VM.
+
+#![cfg_attr(not(feature = "std"), no_std)] // no_std unless the "std" feature is enabled.
+
+extern crate alloc; // The `message` module takes `Vec` from `alloc`.
+
+mod message;
+
+pub use message::{Request, Response}; // The message types are the crate's public API.
diff --git a/libs/service_vm_comm/src/message.rs b/libs/service_vm_comm/src/message.rs
new file mode 100644
index 0000000..ebbefcb
--- /dev/null
+++ b/libs/service_vm_comm/src/message.rs
@@ -0,0 +1,39 @@
+// Copyright 2023, The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! This module contains the request and response definitions exchanged
+//! between the host and the service VM.
+
+use alloc::vec::Vec;
+
+use serde::{Deserialize, Serialize};
+
+/// Represents a request to be sent to the service VM.
+///
+/// Each request variant has a matching [`Response`] variant.
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub enum Request {
+ /// Asks the service VM to reverse the order of the bytes in the given byte array.
+ /// Currently this is only used for testing.
+ Reverse(Vec<u8>),
+}
+
+/// Represents a response to a request sent to the service VM.
+///
+/// Each response variant answers the request variant of the same name.
+#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
+pub enum Response {
+ /// The bytes of the corresponding `Request::Reverse` payload, in reverse order.
+ Reverse(Vec<u8>),
+}
diff --git a/rialto/Android.bp b/rialto/Android.bp
index 7ac7136..33c24eb 100644
--- a/rialto/Android.bp
+++ b/rialto/Android.bp
@@ -9,10 +9,13 @@
defaults: ["vmbase_ffi_defaults"],
rustlibs: [
"libaarch64_paging",
+ "libciborium_io_nostd",
+ "libciborium_nostd",
"libhyp",
"libfdtpci",
"liblibfdt",
"liblog_rust_nostd",
+ "libservice_vm_comm_nostd",
"libvirtio_drivers",
"libvmbase",
],
@@ -98,9 +101,11 @@
"android.system.virtualizationservice-rust",
"libandroid_logger",
"libanyhow",
+ "libciborium",
"liblibc",
"liblog_rust",
"libnix",
+ "libservice_vm_comm",
"libvmclient",
"libvsock",
],
diff --git a/rialto/src/communication.rs b/rialto/src/communication.rs
index f00393d..858ccfb 100644
--- a/rialto/src/communication.rs
+++ b/rialto/src/communication.rs
@@ -14,72 +14,174 @@
//! Supports for the communication between rialto and host.
-use crate::error::{Error, Result};
+use crate::error::Result;
+use ciborium_io::{Read, Write};
+use core::hint::spin_loop;
+use core::result;
use log::info;
+use service_vm_comm::{Request, Response};
use virtio_drivers::{
self,
device::socket::{
- SingleConnectionManager, SocketError, VirtIOSocket, VsockAddr, VsockEventType,
+ SocketError, VirtIOSocket, VsockAddr, VsockConnectionManager, VsockEventType,
},
transport::Transport,
Hal,
};
-const MAX_RECV_BUFFER_SIZE_BYTES: usize = 64;
-
-pub struct DataChannel<H: Hal, T: Transport> {
- connection_manager: SingleConnectionManager<H, T>,
+pub struct VsockStream<H: Hal, T: Transport> { // Vsock connection to one peer; implements `Read` and `Write`.
+ connection_manager: VsockConnectionManager<H, T>,
+ /// Peer address. Rialto uses the peer's port number as its own port too, for convenience.
+ peer_addr: VsockAddr,
+}
-impl<H: Hal, T: Transport> From<VirtIOSocket<H, T>> for DataChannel<H, T> {
- fn from(socket_device_driver: VirtIOSocket<H, T>) -> Self {
- Self { connection_manager: SingleConnectionManager::new(socket_device_driver) }
+impl<H: Hal, T: Transport> VsockStream<H, T> {
+ pub fn new(
+ socket_device_driver: VirtIOSocket<H, T>,
+ peer_addr: VsockAddr,
+ ) -> virtio_drivers::Result<Self> {
+ let mut vsock_stream = Self {
+ connection_manager: VsockConnectionManager::new(socket_device_driver),
+ peer_addr,
+ };
+ vsock_stream.connect()?; // Connect eagerly: a stream is only returned once connected.
+ Ok(vsock_stream)
+ }
-}
-impl<H: Hal, T: Transport> DataChannel<H, T> {
- /// Connects to the given destination.
- pub fn connect(&mut self, destination: VsockAddr) -> virtio_drivers::Result {
- // Use the same port on rialto and host for convenience.
- self.connection_manager.connect(destination, destination.port)?;
- self.connection_manager.wait_for_connect()?;
- info!("Connected to the destination {destination:?}");
+ fn connect(&mut self) -> virtio_drivers::Result {
+ self.connection_manager.connect(self.peer_addr, self.peer_addr.port)?; // Same port on both ends.
+ self.wait_for_connect()?; // Spin until the connection is established or fails.
+ info!("Connected to the peer {:?}", self.peer_addr);
Ok(())
}
- /// Processes the received requests and sends back a reply.
- pub fn handle_incoming_request(&mut self) -> Result<()> {
- let mut buffer = [0u8; MAX_RECV_BUFFER_SIZE_BYTES];
-
- // TODO(b/274441673): Handle the scenario when the given buffer is too short.
- let len = self.wait_for_recv(&mut buffer).map_err(Error::ReceivingDataFailed)?;
-
- // TODO(b/291732060): Implement the communication protocol.
- // Just reverse the received message for now.
- buffer[..len].reverse();
- self.connection_manager.send(&buffer[..len])?;
- Ok(())
- }
-
- fn wait_for_recv(&mut self, buffer: &mut [u8]) -> virtio_drivers::Result<usize> {
+ fn wait_for_connect(&mut self) -> virtio_drivers::Result {
loop {
- match self.connection_manager.wait_for_recv(buffer)?.event_type {
- VsockEventType::Disconnected { .. } => {
- return Err(SocketError::ConnectionFailed.into())
+ if let Some(event) = self.poll_event_from_peer()? {
+ match event {
+ VsockEventType::Connected => return Ok(()),
+ VsockEventType::Disconnected { .. } => {
+ return Err(SocketError::ConnectionFailed.into())
+ }
+ // We shouldn't receive the following events before the connection is
+ // established.
+ VsockEventType::ConnectionRequest | VsockEventType::Received { .. } => {
+ return Err(SocketError::InvalidOperation.into())
+ }
+ // We can receive credit requests and updates at any time.
+ // These can be ignored as the connection manager handles them in poll().
+ VsockEventType::CreditRequest | VsockEventType::CreditUpdate => {}
}
- VsockEventType::Received { length, .. } => return Ok(length),
- VsockEventType::Connected
- | VsockEventType::ConnectionRequest
- | VsockEventType::CreditRequest
- | VsockEventType::CreditUpdate => {}
+ } else {
+ spin_loop(); // No event available yet: busy-wait, then poll again.
}
}
}
+ pub fn read_request(&mut self) -> Result<Request> {
+ Ok(ciborium::from_reader(self)?) // Decodes one `Request` read from this stream.
+ }
+
+ pub fn write_response(&mut self, response: &Response) -> Result<()> {
+ Ok(ciborium::into_writer(response, self)?) // Encodes `response` and writes it to this stream.
+ }
+
/// Shuts down the data channel.
- pub fn force_close(&mut self) -> virtio_drivers::Result {
- self.connection_manager.force_close()?;
+ pub fn shutdown(&mut self) -> virtio_drivers::Result {
+ self.connection_manager.force_close(self.peer_addr, self.peer_addr.port)?; // Same port on both ends.
info!("Connection shutdown.");
Ok(())
}
+
+ fn recv(&mut self, buffer: &mut [u8]) -> virtio_drivers::Result<usize> {
+ self.connection_manager.recv(self.peer_addr, self.peer_addr.port, buffer) // NOTE(review): 0 appears to mean "nothing buffered yet" - confirm.
+ }
+
+ fn wait_for_send(&mut self, buffer: &[u8]) -> virtio_drivers::Result {
+ const INSUFFICIENT_BUFFER_SPACE_ERROR: virtio_drivers::Error =
+ virtio_drivers::Error::SocketDeviceError(SocketError::InsufficientBufferSpaceInPeer);
+ loop { // Retry the send until it succeeds or fails with a different error.
+ match self.connection_manager.send(self.peer_addr, self.peer_addr.port, buffer) {
+ Ok(_) => return Ok(()),
+ Err(INSUFFICIENT_BUFFER_SPACE_ERROR) => {
+ self.poll()?; // Process pending events (e.g. credit updates) before retrying.
+ }
+ Err(e) => return Err(e),
+ }
+ }
+ }
+
+ fn wait_for_recv(&mut self) -> virtio_drivers::Result {
+ loop { // Spin until a `Received` event is polled; errors from `poll()` end the wait.
+ match self.poll()? {
+ Some(VsockEventType::Received { .. }) => return Ok(()),
+ _ => spin_loop(), // No event, or a credit event: keep polling.
+ }
+ }
+ }
+
+ /// Polls the rx queue once the connection with the peer is established. Disconnection
+ /// and connection events are turned into errors; the other events are handled inside
+ /// the connection manager and returned to the caller.
+ fn poll(&mut self) -> virtio_drivers::Result<Option<VsockEventType>> {
+ if let Some(event) = self.poll_event_from_peer()? {
+ match event {
+ VsockEventType::Disconnected { .. } => Err(SocketError::ConnectionFailed.into()),
+ VsockEventType::Connected | VsockEventType::ConnectionRequest => {
+ Err(SocketError::InvalidOperation.into())
+ }
+ // When there is a received event, the received data is buffered in the
+ // connection manager's internal receive buffer, so we don't need to do
+ // anything here.
+ // The credit requests and updates are also handled inside the connection
+ // manager.
+ VsockEventType::Received { .. }
+ | VsockEventType::CreditRequest
+ | VsockEventType::CreditUpdate => Ok(Some(event)),
+ }
+ } else {
+ Ok(None)
+ }
+ }
+
+ fn poll_event_from_peer(&mut self) -> virtio_drivers::Result<Option<VsockEventType>> {
+ Ok(self.connection_manager.poll()?.map(|event| {
+ assert_eq!(event.source, self.peer_addr); // Panics on an event from any other source.
+ assert_eq!(event.destination.port, self.peer_addr.port); // Panics on any other destination port.
+ event.event_type
+ }))
+ }
+}
+
+impl<H: Hal, T: Transport> Read for VsockStream<H, T> {
+ type Error = virtio_drivers::Error;
+
+ fn read_exact(&mut self, data: &mut [u8]) -> result::Result<(), Self::Error> {
+ let mut start = 0; // Number of bytes of `data` filled so far.
+ while start < data.len() {
+ let len = self.recv(&mut data[start..])?;
+ let len = if len == 0 {
+ self.wait_for_recv()?; // Nothing was available: spin until a `Received` event.
+ self.recv(&mut data[start..])?
+ } else {
+ len
+ };
+ start += len;
+ }
+ Ok(())
+ }
+}
+
+impl<H: Hal, T: Transport> Write for VsockStream<H, T> {
+ type Error = virtio_drivers::Error;
+
+ fn write_all(&mut self, data: &[u8]) -> result::Result<(), Self::Error> {
+ self.wait_for_send(data) // Sent immediately; nothing is buffered in this type.
+ }
+
+ fn flush(&mut self) -> result::Result<(), Self::Error> {
+ // TODO(b/293411448): Optimize the data sending by saving the data to write
+ // in a local buffer and then flushing only when the buffer is full.
+ Ok(())
+ }
}
diff --git a/rialto/src/error.rs b/rialto/src/error.rs
index 461870b..23667ed 100644
--- a/rialto/src/error.rs
+++ b/rialto/src/error.rs
@@ -23,7 +23,10 @@
pub type Result<T> = result::Result<T, Error>;
-#[derive(Clone, Debug)]
+type CiboriumSerError = ciborium::ser::Error<virtio_drivers::Error>; // Parameter matches VsockStream's `Write::Error`.
+type CiboriumDeError = ciborium::de::Error<virtio_drivers::Error>; // Parameter matches VsockStream's `Read::Error`.
+
+#[derive(Debug)]
pub enum Error {
/// Hypervisor error.
Hypervisor(HypervisorError),
@@ -43,8 +46,10 @@
MissingVirtIOSocketDevice,
/// Failed VirtIO driver operation.
VirtIODriverOperationFailed(virtio_drivers::Error),
- /// Failed to receive data.
- ReceivingDataFailed(virtio_drivers::Error),
+ /// Failed to serialize.
+ SerializationFailed(CiboriumSerError),
+ /// Failed to deserialize.
+ DeserializationFailed(CiboriumDeError),
}
impl fmt::Display for Error {
@@ -65,7 +70,8 @@
Self::VirtIODriverOperationFailed(e) => {
write!(f, "Failed VirtIO driver operation: {e}")
}
- Self::ReceivingDataFailed(e) => write!(f, "Failed to receive data: {e}"),
+ Self::SerializationFailed(e) => write!(f, "Failed to serialize: {e}"),
+ Self::DeserializationFailed(e) => write!(f, "Failed to deserialize: {e}"),
}
}
}
@@ -105,3 +111,15 @@
Self::VirtIODriverOperationFailed(e)
}
}
+
+impl From<CiboriumSerError> for Error { // Lets `?` turn serialization failures into `Error`.
+ fn from(e: CiboriumSerError) -> Self {
+ Self::SerializationFailed(e)
+ }
+}
+
+impl From<CiboriumDeError> for Error { // Lets `?` turn deserialization failures into `Error`.
+ fn from(e: CiboriumDeError) -> Self {
+ Self::DeserializationFailed(e)
+ }
+}
diff --git a/rialto/src/main.rs b/rialto/src/main.rs
index 5c6649a..a8338ca 100644
--- a/rialto/src/main.rs
+++ b/rialto/src/main.rs
@@ -20,10 +20,11 @@
mod communication;
mod error;
mod exceptions;
+mod requests;
extern crate alloc;
-use crate::communication::DataChannel;
+use crate::communication::VsockStream;
use crate::error::{Error, Result};
use core::num::NonZeroUsize;
use core::slice;
@@ -137,10 +138,10 @@
let socket_device = find_socket_device::<HalImpl>(&mut pci_root)?;
debug!("Found socket device: guest cid = {:?}", socket_device.guest_cid());
- let mut data_channel = DataChannel::from(socket_device);
- data_channel.connect(host_addr())?;
- data_channel.handle_incoming_request()?;
- data_channel.force_close()?;
+ let mut vsock_stream = VsockStream::new(socket_device, host_addr())?;
+ let response = requests::process_request(vsock_stream.read_request()?);
+ vsock_stream.write_response(&response)?;
+ vsock_stream.shutdown()?;
Ok(())
}
diff --git a/rialto/src/requests/api.rs b/rialto/src/requests/api.rs
new file mode 100644
index 0000000..11fdde4
--- /dev/null
+++ b/rialto/src/requests/api.rs
@@ -0,0 +1,31 @@
+// Copyright 2023, The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! This module contains the main API for the request processing module.
+
+use alloc::vec::Vec;
+use service_vm_comm::{Request, Response};
+
+/// Processes a request and returns the corresponding response.
+/// This function serves as the entry point for the request processing
+/// module; each `Request` variant is answered by the same-named `Response` variant.
+pub fn process_request(request: Request) -> Response {
+ match request {
+ Request::Reverse(v) => Response::Reverse(reverse(v)),
+ }
+}
+
+fn reverse(payload: Vec<u8>) -> Vec<u8> { // Consumes `payload` and returns its bytes in reverse order.
+ payload.into_iter().rev().collect()
+}
diff --git a/rialto/src/requests/mod.rs b/rialto/src/requests/mod.rs
new file mode 100644
index 0000000..ca22777
--- /dev/null
+++ b/rialto/src/requests/mod.rs
@@ -0,0 +1,19 @@
+// Copyright 2023, The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! This module contains functions for the request processing.
+
+mod api; // Private; only the item re-exported below is visible to the rest of the crate.
+
+pub use api::process_request;
diff --git a/rialto/tests/test.rs b/rialto/tests/test.rs
index 8089016..2bd8968 100644
--- a/rialto/tests/test.rs
+++ b/rialto/tests/test.rs
@@ -22,10 +22,11 @@
},
binder::{ParcelFileDescriptor, ProcessState},
};
-use anyhow::{anyhow, bail, Context, Error};
+use anyhow::{anyhow, bail, Context, Result};
use log::info;
+use service_vm_comm::{Request, Response};
use std::fs::File;
-use std::io::{self, BufRead, BufReader, Read, Write};
+use std::io::{self, BufRead, BufReader, BufWriter, Write};
use std::os::unix::io::FromRawFd;
use std::panic;
use std::thread;
@@ -44,7 +45,7 @@
const INSTANCE_IMG_SIZE: i64 = 1024 * 1024; // 1MB
#[test]
-fn boot_rialto_in_protected_vm_successfully() -> Result<(), Error> {
+fn boot_rialto_in_protected_vm_successfully() -> Result<()> {
boot_rialto_successfully(
SIGNED_RIALTO_PATH,
true, // protected_vm
@@ -52,14 +53,14 @@
}
#[test]
-fn boot_rialto_in_unprotected_vm_successfully() -> Result<(), Error> {
+fn boot_rialto_in_unprotected_vm_successfully() -> Result<()> {
boot_rialto_successfully(
UNSIGNED_RIALTO_PATH,
false, // protected_vm
)
}
-fn boot_rialto_successfully(rialto_path: &str, protected_vm: bool) -> Result<(), Error> {
+fn boot_rialto_successfully(rialto_path: &str, protected_vm: bool) -> Result<()> {
android_logger::init_once(
android_logger::Config::default().with_tag("rialto").with_min_level(log::Level::Debug),
);
@@ -169,27 +170,31 @@
Ok(writer)
}
-fn try_check_socket_connection(port: u32) -> Result<(), Error> {
+fn try_check_socket_connection(port: u32) -> Result<()> {
info!("Setting up the listening socket on port {port}...");
let listener = VsockListener::bind_with_cid_port(VMADDR_CID_HOST, port)?;
info!("Listening on port {port}...");
- let Some(Ok(mut vsock_stream)) = listener.incoming().next() else {
- bail!("Failed to get vsock_stream");
- };
+ let mut vsock_stream =
+ listener.incoming().next().ok_or_else(|| anyhow!("Failed to get vsock_stream"))??;
info!("Accepted connection {:?}", vsock_stream);
-
- let message = "Hello from host";
- vsock_stream.write_all(message.as_bytes())?;
- vsock_stream.flush()?;
- info!("Sent message: {:?}.", message);
-
- let mut buffer = vec![0u8; 30];
vsock_stream.set_read_timeout(Some(Duration::from_millis(1_000)))?;
- let len = vsock_stream.read(&mut buffer)?;
- assert_eq!(message.len(), len);
- buffer[..len].reverse();
- assert_eq!(message.as_bytes(), &buffer[..len]);
+ const WRITE_BUFFER_CAPACITY: usize = 512;
+ let mut buffer = BufWriter::with_capacity(WRITE_BUFFER_CAPACITY, vsock_stream.clone()); // Writes use a clone; the original reads the response.
+
+ // TODO(b/292080257): Test with a message longer than the receiver's buffer capacity
+ // (1024 bytes) once the guest virtio-vsock driver fixes the credit update in recv().
+ let message = "abc".repeat(166); // 498-byte payload; presumably sized to stay under that limit once encoded - confirm.
+ let request = Request::Reverse(message.as_bytes().to_vec());
+ ciborium::into_writer(&request, &mut buffer)?;
+ buffer.flush()?; // Push any bytes still held by the BufWriter to the socket.
+ info!("Sent request: {request:?}.");
+
+ let response: Response = ciborium::from_reader(&mut vsock_stream)?;
+ info!("Received response: {response:?}.");
+
+ let expected_response: Vec<u8> = message.as_bytes().iter().rev().cloned().collect(); // Request payload, reversed.
+ assert_eq!(Response::Reverse(expected_response), response);
Ok(())
}