blob: f00393d255cdc0051b620ddf55619fe46f8dca43 [file] [log] [blame]
Alice Wang4e082c32023-07-11 07:41:50 +00001// Copyright 2023, The Android Open Source Project
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Support for the communication between rialto and the host.
16
17use crate::error::{Error, Result};
18use log::info;
19use virtio_drivers::{
20 self,
21 device::socket::{
22 SingleConnectionManager, SocketError, VirtIOSocket, VsockAddr, VsockEventType,
23 },
24 transport::Transport,
25 Hal,
26};
27
/// Size, in bytes, of the buffer that a single incoming request is received into.
const MAX_RECV_BUFFER_SIZE_BYTES: usize = 0x40;
29
30pub struct DataChannel<H: Hal, T: Transport> {
31 connection_manager: SingleConnectionManager<H, T>,
32}
33
34impl<H: Hal, T: Transport> From<VirtIOSocket<H, T>> for DataChannel<H, T> {
35 fn from(socket_device_driver: VirtIOSocket<H, T>) -> Self {
36 Self { connection_manager: SingleConnectionManager::new(socket_device_driver) }
37 }
38}
39
40impl<H: Hal, T: Transport> DataChannel<H, T> {
41 /// Connects to the given destination.
42 pub fn connect(&mut self, destination: VsockAddr) -> virtio_drivers::Result {
43 // Use the same port on rialto and host for convenience.
44 self.connection_manager.connect(destination, destination.port)?;
45 self.connection_manager.wait_for_connect()?;
46 info!("Connected to the destination {destination:?}");
47 Ok(())
48 }
49
50 /// Processes the received requests and sends back a reply.
51 pub fn handle_incoming_request(&mut self) -> Result<()> {
52 let mut buffer = [0u8; MAX_RECV_BUFFER_SIZE_BYTES];
53
54 // TODO(b/274441673): Handle the scenario when the given buffer is too short.
55 let len = self.wait_for_recv(&mut buffer).map_err(Error::ReceivingDataFailed)?;
56
57 // TODO(b/291732060): Implement the communication protocol.
58 // Just reverse the received message for now.
59 buffer[..len].reverse();
60 self.connection_manager.send(&buffer[..len])?;
61 Ok(())
62 }
63
64 fn wait_for_recv(&mut self, buffer: &mut [u8]) -> virtio_drivers::Result<usize> {
65 loop {
66 match self.connection_manager.wait_for_recv(buffer)?.event_type {
67 VsockEventType::Disconnected { .. } => {
68 return Err(SocketError::ConnectionFailed.into())
69 }
70 VsockEventType::Received { length, .. } => return Ok(length),
71 VsockEventType::Connected
72 | VsockEventType::ConnectionRequest
73 | VsockEventType::CreditRequest
74 | VsockEventType::CreditUpdate => {}
75 }
76 }
77 }
78
79 /// Shuts down the data channel.
80 pub fn force_close(&mut self) -> virtio_drivers::Result {
81 self.connection_manager.force_close()?;
82 info!("Connection shutdown.");
83 Ok(())
84 }
85}