blob: 1b949121f5d89b281da837a23c5b9866543e8d45 [file] [log] [blame]
// Copyright 2023, The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Supports for the communication between rialto and host.
use crate::error::Result;
use ciborium_io::{Read, Write};
use core::hint::spin_loop;
use core::mem;
use core::result;
use log::info;
use service_vm_comm::{Response, ServiceVmRequest};
use tinyvec::ArrayVec;
use virtio_drivers::{
self,
device::socket::{
SocketError, VirtIOSocket, VsockAddr, VsockConnectionManager, VsockEventType,
},
transport::Transport,
Hal,
};
const WRITE_BUF_CAPACITY: usize = 512;
pub struct VsockStream<H: Hal, T: Transport> {
connection_manager: VsockConnectionManager<H, T>,
/// Peer address. The same port is used on rialto and peer for convenience.
peer_addr: VsockAddr,
write_buf: ArrayVec<[u8; WRITE_BUF_CAPACITY]>,
}
impl<H: Hal, T: Transport> VsockStream<H, T> {
pub fn new(
socket_device_driver: VirtIOSocket<H, T>,
peer_addr: VsockAddr,
) -> virtio_drivers::Result<Self> {
let mut vsock_stream = Self {
connection_manager: VsockConnectionManager::new(socket_device_driver),
peer_addr,
write_buf: ArrayVec::default(),
};
vsock_stream.connect()?;
Ok(vsock_stream)
}
fn connect(&mut self) -> virtio_drivers::Result {
self.connection_manager.connect(self.peer_addr, self.peer_addr.port)?;
self.wait_for_connect()?;
info!("Connected to the peer {:?}", self.peer_addr);
Ok(())
}
fn wait_for_connect(&mut self) -> virtio_drivers::Result {
loop {
if let Some(event) = self.poll_event_from_peer()? {
match event {
VsockEventType::Connected => return Ok(()),
VsockEventType::Disconnected { .. } => {
return Err(SocketError::ConnectionFailed.into())
}
// We shouldn't receive the following event before the connection is
// established.
VsockEventType::ConnectionRequest | VsockEventType::Received { .. } => {
return Err(SocketError::InvalidOperation.into())
}
// We can receive credit requests and updates at any time.
// This can be ignored as the connection manager handles them in poll().
VsockEventType::CreditRequest | VsockEventType::CreditUpdate => {}
}
} else {
spin_loop();
}
}
}
pub fn read_request(&mut self) -> Result<ServiceVmRequest> {
Ok(ciborium::from_reader(self)?)
}
pub fn write_response(&mut self, response: &Response) -> Result<()> {
Ok(ciborium::into_writer(response, self)?)
}
/// Shuts down the data channel.
pub fn shutdown(&mut self) -> virtio_drivers::Result {
self.connection_manager.force_close(self.peer_addr, self.peer_addr.port)?;
info!("Connection shutdown.");
Ok(())
}
fn recv(&mut self, buffer: &mut [u8]) -> virtio_drivers::Result<usize> {
let bytes_read =
self.connection_manager.recv(self.peer_addr, self.peer_addr.port, buffer)?;
let buffer_available_bytes = self
.connection_manager
.recv_buffer_available_bytes(self.peer_addr, self.peer_addr.port)?;
if buffer_available_bytes == 0 && bytes_read > 0 {
self.connection_manager.update_credit(self.peer_addr, self.peer_addr.port)?;
}
Ok(bytes_read)
}
fn wait_for_send(&mut self, buffer: &[u8]) -> virtio_drivers::Result {
const INSUFFICIENT_BUFFER_SPACE_ERROR: virtio_drivers::Error =
virtio_drivers::Error::SocketDeviceError(SocketError::InsufficientBufferSpaceInPeer);
loop {
match self.connection_manager.send(self.peer_addr, self.peer_addr.port, buffer) {
Ok(_) => return Ok(()),
Err(INSUFFICIENT_BUFFER_SPACE_ERROR) => {
self.poll()?;
}
Err(e) => return Err(e),
}
}
}
fn wait_for_recv(&mut self) -> virtio_drivers::Result {
loop {
match self.poll()? {
Some(VsockEventType::Received { .. }) => return Ok(()),
_ => spin_loop(),
}
}
}
/// Polls the rx queue after the connection is established with the peer, this function
/// rejects some invalid events. The valid events are handled inside the connection
/// manager.
fn poll(&mut self) -> virtio_drivers::Result<Option<VsockEventType>> {
if let Some(event) = self.poll_event_from_peer()? {
match event {
VsockEventType::Disconnected { .. } => Err(SocketError::ConnectionFailed.into()),
VsockEventType::Connected | VsockEventType::ConnectionRequest => {
Err(SocketError::InvalidOperation.into())
}
// When there is a received event, the received data is buffered in the
// connection manager's internal receive buffer, so we don't need to do
// anything here.
// The credit request and updates also handled inside the connection
// manager.
VsockEventType::Received { .. }
| VsockEventType::CreditRequest
| VsockEventType::CreditUpdate => Ok(Some(event)),
}
} else {
Ok(None)
}
}
fn poll_event_from_peer(&mut self) -> virtio_drivers::Result<Option<VsockEventType>> {
Ok(self.connection_manager.poll()?.map(|event| {
assert_eq!(event.source, self.peer_addr);
assert_eq!(event.destination.port, self.peer_addr.port);
event.event_type
}))
}
}
impl<H: Hal, T: Transport> Read for VsockStream<H, T> {
type Error = virtio_drivers::Error;
fn read_exact(&mut self, data: &mut [u8]) -> result::Result<(), Self::Error> {
let mut start = 0;
while start < data.len() {
let len = self.recv(&mut data[start..])?;
let len = if len == 0 {
self.wait_for_recv()?;
self.recv(&mut data[start..])?
} else {
len
};
start += len;
}
Ok(())
}
}
impl<H: Hal, T: Transport> Write for VsockStream<H, T> {
type Error = virtio_drivers::Error;
fn write_all(&mut self, data: &[u8]) -> result::Result<(), Self::Error> {
if data.len() >= self.write_buf.capacity() - self.write_buf.len() {
self.flush()?;
if data.len() >= self.write_buf.capacity() {
self.wait_for_send(data)?;
return Ok(());
}
}
self.write_buf.extend_from_slice(data);
Ok(())
}
fn flush(&mut self) -> result::Result<(), Self::Error> {
if !self.write_buf.is_empty() {
// We need to take the memory from self.write_buf to a temporary
// buffer to avoid borrowing `*self` as mutable and immutable on
// the same time in `self.wait_for_send(&self.write_buf)`.
let buffer = mem::take(&mut self.write_buf);
self.wait_for_send(&buffer)?;
}
Ok(())
}
}