blob: f00393d255cdc0051b620ddf55619fe46f8dca43 [file] [log] [blame]
// Copyright 2023, The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Support for the communication between rialto and host.
use crate::error::{Error, Result};
use log::info;
use virtio_drivers::{
self,
device::socket::{
SingleConnectionManager, SocketError, VirtIOSocket, VsockAddr, VsockEventType,
},
transport::Transport,
Hal,
};
/// Size in bytes of the buffer that `DataChannel::handle_incoming_request` receives into.
const MAX_RECV_BUFFER_SIZE_BYTES: usize = 64;
/// Data channel between rialto and host, built on a virtio socket
/// [`SingleConnectionManager`].
///
/// Create one from a [`VirtIOSocket`] driver through the `From` implementation.
pub struct DataChannel<H, T>
where
    H: Hal,
    T: Transport,
{
    /// Connection manager through which this channel connects, sends and receives.
    connection_manager: SingleConnectionManager<H, T>,
}
impl<H, T> From<VirtIOSocket<H, T>> for DataChannel<H, T>
where
    H: Hal,
    T: Transport,
{
    /// Builds a data channel by wrapping the socket device driver in a
    /// [`SingleConnectionManager`].
    fn from(socket: VirtIOSocket<H, T>) -> Self {
        let connection_manager = SingleConnectionManager::new(socket);
        Self { connection_manager }
    }
}
impl<H: Hal, T: Transport> DataChannel<H, T> {
/// Connects to the given destination.
pub fn connect(&mut self, destination: VsockAddr) -> virtio_drivers::Result {
// Use the same port on rialto and host for convenience.
self.connection_manager.connect(destination, destination.port)?;
self.connection_manager.wait_for_connect()?;
info!("Connected to the destination {destination:?}");
Ok(())
}
/// Processes the received requests and sends back a reply.
pub fn handle_incoming_request(&mut self) -> Result<()> {
let mut buffer = [0u8; MAX_RECV_BUFFER_SIZE_BYTES];
// TODO(b/274441673): Handle the scenario when the given buffer is too short.
let len = self.wait_for_recv(&mut buffer).map_err(Error::ReceivingDataFailed)?;
// TODO(b/291732060): Implement the communication protocol.
// Just reverse the received message for now.
buffer[..len].reverse();
self.connection_manager.send(&buffer[..len])?;
Ok(())
}
fn wait_for_recv(&mut self, buffer: &mut [u8]) -> virtio_drivers::Result<usize> {
loop {
match self.connection_manager.wait_for_recv(buffer)?.event_type {
VsockEventType::Disconnected { .. } => {
return Err(SocketError::ConnectionFailed.into())
}
VsockEventType::Received { length, .. } => return Ok(length),
VsockEventType::Connected
| VsockEventType::ConnectionRequest
| VsockEventType::CreditRequest
| VsockEventType::CreditUpdate => {}
}
}
}
/// Shuts down the data channel.
pub fn force_close(&mut self) -> virtio_drivers::Result {
self.connection_manager.force_close()?;
info!("Connection shutdown.");
Ok(())
}
}