Seungjae Yoo | a55e773 | 2024-09-25 12:59:57 +0900 | [diff] [blame] | 1 | // Copyright 2024 The Android Open Source Project |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | |
| 15 | // Copied from ChromiumOS with relicensing: |
| 16 | // src/platform2/vm_tools/chunnel/src/stream.rs |
| 17 | |
//! This module provides an abstraction over various stream socket types.
| 19 | |
Seungjae Yoo | a55e773 | 2024-09-25 12:59:57 +0900 | [diff] [blame] | 20 | use std::fmt; |
| 21 | use std::io; |
| 22 | use std::net::TcpStream; |
| 23 | use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; |
| 24 | use std::os::unix::net::UnixStream; |
| 25 | use std::result; |
| 26 | |
| 27 | use libc::{self, c_void, shutdown, EPIPE, SHUT_WR}; |
| 28 | use vsock::VsockAddr; |
| 29 | use vsock::VsockStream; |
| 30 | |
| 31 | /// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form |
| 32 | /// "vsock:cid:port". |
| 33 | pub fn parse_vsock_addr(addr: &str) -> result::Result<VsockAddr, io::Error> { |
| 34 | let components: Vec<&str> = addr.split(':').collect(); |
| 35 | if components.len() != 3 || components[0] != "vsock" { |
| 36 | return Err(io::Error::from_raw_os_error(libc::EINVAL)); |
| 37 | } |
| 38 | |
| 39 | Ok(VsockAddr::new( |
Seungjae Yoo | 82f68b3 | 2024-09-25 13:03:31 +0900 | [diff] [blame] | 40 | components[1].parse().map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?, |
| 41 | components[2].parse().map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?, |
Seungjae Yoo | a55e773 | 2024-09-25 12:59:57 +0900 | [diff] [blame] | 42 | )) |
| 43 | } |
| 44 | |
/// StreamSocket provides a generic abstraction around any connection-oriented stream socket.
/// The socket will be closed when StreamSocket is dropped, but writes to the socket can also
/// be shut down manually.
#[derive(Debug)]
pub struct StreamSocket {
    /// The owned file descriptor; it is closed when the `StreamSocket` is dropped.
    fd: RawFd,
    /// Set once writes are known to be shut down, either through `shut_down_write` or
    /// because a `write` failed with `EPIPE`.
    shut_down: bool,
}
| 52 | |
| 53 | impl StreamSocket { |
| 54 | /// Connects to the given socket address. Supported socket types are vsock, unix, and TCP. |
| 55 | pub fn connect(sockaddr: &str) -> result::Result<StreamSocket, StreamSocketError> { |
| 56 | const UNIX_PREFIX: &str = "unix:"; |
| 57 | const VSOCK_PREFIX: &str = "vsock:"; |
| 58 | |
| 59 | if sockaddr.starts_with(VSOCK_PREFIX) { |
| 60 | let addr = parse_vsock_addr(sockaddr) |
| 61 | .map_err(|e| StreamSocketError::ConnectVsock(sockaddr.to_string(), e))?; |
| 62 | let vsock_stream = VsockStream::connect(&addr) |
| 63 | .map_err(|e| StreamSocketError::ConnectVsock(sockaddr.to_string(), e))?; |
| 64 | Ok(vsock_stream.into()) |
| 65 | } else if sockaddr.starts_with(UNIX_PREFIX) { |
| 66 | let (_prefix, sock_path) = sockaddr.split_at(UNIX_PREFIX.len()); |
| 67 | let unix_stream = UnixStream::connect(sock_path) |
| 68 | .map_err(|e| StreamSocketError::ConnectUnix(sockaddr.to_string(), e))?; |
| 69 | Ok(unix_stream.into()) |
| 70 | } else { |
| 71 | // Assume this is a TCP stream. |
| 72 | let tcp_stream = TcpStream::connect(sockaddr) |
| 73 | .map_err(|e| StreamSocketError::ConnectTcp(sockaddr.to_string(), e))?; |
| 74 | Ok(tcp_stream.into()) |
| 75 | } |
| 76 | } |
| 77 | |
| 78 | /// Shuts down writes to the socket using shutdown(2). |
| 79 | pub fn shut_down_write(&mut self) -> io::Result<()> { |
Seungjae Yoo | 82f68b3 | 2024-09-25 13:03:31 +0900 | [diff] [blame] | 80 | // SAFETY: |
Seungjae Yoo | a55e773 | 2024-09-25 12:59:57 +0900 | [diff] [blame] | 81 | // Safe because no memory is modified and the return value is checked. |
| 82 | let ret = unsafe { shutdown(self.fd, SHUT_WR) }; |
| 83 | if ret < 0 { |
| 84 | return Err(io::Error::last_os_error()); |
| 85 | } |
| 86 | |
| 87 | self.shut_down = true; |
| 88 | Ok(()) |
| 89 | } |
| 90 | |
| 91 | /// Returns true if the socket has been shut down for writes, false otherwise. |
| 92 | pub fn is_shut_down(&self) -> bool { |
| 93 | self.shut_down |
| 94 | } |
| 95 | } |
| 96 | |
| 97 | impl io::Read for StreamSocket { |
| 98 | fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { |
Seungjae Yoo | 82f68b3 | 2024-09-25 13:03:31 +0900 | [diff] [blame] | 99 | // SAFETY: |
Seungjae Yoo | a55e773 | 2024-09-25 12:59:57 +0900 | [diff] [blame] | 100 | // Safe because this will only modify the contents of |buf| and we check the return value. |
| 101 | let ret = unsafe { libc::read(self.fd, buf.as_mut_ptr() as *mut c_void, buf.len()) }; |
| 102 | if ret < 0 { |
| 103 | return Err(io::Error::last_os_error()); |
| 104 | } |
| 105 | |
| 106 | Ok(ret as usize) |
| 107 | } |
| 108 | } |
| 109 | |
| 110 | impl io::Write for StreamSocket { |
| 111 | fn write(&mut self, buf: &[u8]) -> io::Result<usize> { |
Seungjae Yoo | 82f68b3 | 2024-09-25 13:03:31 +0900 | [diff] [blame] | 112 | // SAFETY: |
Seungjae Yoo | a55e773 | 2024-09-25 12:59:57 +0900 | [diff] [blame] | 113 | // Safe because this doesn't modify any memory and we check the return value. |
| 114 | let ret = unsafe { libc::write(self.fd, buf.as_ptr() as *const c_void, buf.len()) }; |
| 115 | if ret < 0 { |
| 116 | // If a write causes EPIPE then the socket is shut down for writes. |
| 117 | let err = io::Error::last_os_error(); |
| 118 | if let Some(errno) = err.raw_os_error() { |
| 119 | if errno == EPIPE { |
| 120 | self.shut_down = true |
| 121 | } |
| 122 | } |
| 123 | |
| 124 | return Err(err); |
| 125 | } |
| 126 | |
| 127 | Ok(ret as usize) |
| 128 | } |
| 129 | |
| 130 | fn flush(&mut self) -> io::Result<()> { |
| 131 | // No buffered data so nothing to do. |
| 132 | Ok(()) |
| 133 | } |
| 134 | } |
| 135 | |
| 136 | impl AsRawFd for StreamSocket { |
| 137 | fn as_raw_fd(&self) -> RawFd { |
| 138 | self.fd |
| 139 | } |
| 140 | } |
| 141 | |
| 142 | impl From<TcpStream> for StreamSocket { |
| 143 | fn from(stream: TcpStream) -> Self { |
Seungjae Yoo | 82f68b3 | 2024-09-25 13:03:31 +0900 | [diff] [blame] | 144 | StreamSocket { fd: stream.into_raw_fd(), shut_down: false } |
Seungjae Yoo | a55e773 | 2024-09-25 12:59:57 +0900 | [diff] [blame] | 145 | } |
| 146 | } |
| 147 | |
| 148 | impl From<UnixStream> for StreamSocket { |
| 149 | fn from(stream: UnixStream) -> Self { |
Seungjae Yoo | 82f68b3 | 2024-09-25 13:03:31 +0900 | [diff] [blame] | 150 | StreamSocket { fd: stream.into_raw_fd(), shut_down: false } |
Seungjae Yoo | a55e773 | 2024-09-25 12:59:57 +0900 | [diff] [blame] | 151 | } |
| 152 | } |
| 153 | |
| 154 | impl From<VsockStream> for StreamSocket { |
| 155 | fn from(stream: VsockStream) -> Self { |
Seungjae Yoo | 82f68b3 | 2024-09-25 13:03:31 +0900 | [diff] [blame] | 156 | StreamSocket { fd: stream.into_raw_fd(), shut_down: false } |
Seungjae Yoo | a55e773 | 2024-09-25 12:59:57 +0900 | [diff] [blame] | 157 | } |
| 158 | } |
| 159 | |
| 160 | impl FromRawFd for StreamSocket { |
| 161 | unsafe fn from_raw_fd(fd: RawFd) -> Self { |
Seungjae Yoo | 82f68b3 | 2024-09-25 13:03:31 +0900 | [diff] [blame] | 162 | StreamSocket { fd, shut_down: false } |
Seungjae Yoo | a55e773 | 2024-09-25 12:59:57 +0900 | [diff] [blame] | 163 | } |
| 164 | } |
| 165 | |
| 166 | impl Drop for StreamSocket { |
| 167 | fn drop(&mut self) { |
Seungjae Yoo | 82f68b3 | 2024-09-25 13:03:31 +0900 | [diff] [blame] | 168 | // SAFETY: |
Seungjae Yoo | a55e773 | 2024-09-25 12:59:57 +0900 | [diff] [blame] | 169 | // Safe because this doesn't modify any memory and we are the only |
| 170 | // owner of the file descriptor. |
| 171 | unsafe { libc::close(self.fd) }; |
| 172 | } |
| 173 | } |
| 174 | |
| 175 | /// Error enums for StreamSocket. |
| 176 | #[remain::sorted] |
| 177 | #[derive(Debug)] |
| 178 | pub enum StreamSocketError { |
Seungjae Yoo | 82f68b3 | 2024-09-25 13:03:31 +0900 | [diff] [blame] | 179 | /// Error on connecting TCP socket. |
Seungjae Yoo | a55e773 | 2024-09-25 12:59:57 +0900 | [diff] [blame] | 180 | ConnectTcp(String, io::Error), |
Seungjae Yoo | 82f68b3 | 2024-09-25 13:03:31 +0900 | [diff] [blame] | 181 | /// Error on connecting unix socket. |
Seungjae Yoo | a55e773 | 2024-09-25 12:59:57 +0900 | [diff] [blame] | 182 | ConnectUnix(String, io::Error), |
Seungjae Yoo | 82f68b3 | 2024-09-25 13:03:31 +0900 | [diff] [blame] | 183 | /// Error on connecting vsock socket. |
Seungjae Yoo | a55e773 | 2024-09-25 12:59:57 +0900 | [diff] [blame] | 184 | ConnectVsock(String, io::Error), |
| 185 | } |
| 186 | |
| 187 | impl fmt::Display for StreamSocketError { |
| 188 | #[remain::check] |
| 189 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
| 190 | use self::StreamSocketError::*; |
| 191 | |
| 192 | #[remain::sorted] |
| 193 | match self { |
| 194 | ConnectTcp(sockaddr, e) => { |
| 195 | write!(f, "failed to connect to TCP sockaddr {}: {}", sockaddr, e) |
| 196 | } |
| 197 | ConnectUnix(sockaddr, e) => { |
| 198 | write!(f, "failed to connect to unix sockaddr {}: {}", sockaddr, e) |
| 199 | } |
| 200 | ConnectVsock(sockaddr, e) => { |
| 201 | write!(f, "failed to connect to vsock sockaddr {}: {}", sockaddr, e) |
| 202 | } |
| 203 | } |
| 204 | } |
| 205 | } |
| 206 | |
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::os::unix::net::{UnixListener, UnixStream};
    use tempfile::TempDir;

    #[test]
    fn sock_connect_tcp() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let sockaddr = format!("127.0.0.1:{}", listener.local_addr().unwrap().port());

        let _stream = StreamSocket::connect(&sockaddr).unwrap();
    }

    #[test]
    fn sock_connect_unix() {
        let tempdir = TempDir::new().unwrap();
        // `Path::join` already returns an owned `PathBuf`, so no intermediate copy is needed.
        let path = tempdir.path().join("test.sock");
        let _listener = UnixListener::bind(&path).unwrap();

        let unix_addr = format!("unix:{}", path.to_str().unwrap());
        let _stream = StreamSocket::connect(&unix_addr).unwrap();
    }

    #[test]
    fn invalid_sockaddr() {
        assert!(StreamSocket::connect("this is not a valid sockaddr").is_err());
    }

    #[test]
    fn parse_vsock_addr_accepts_well_formed() {
        assert!(parse_vsock_addr("vsock:3:1234").is_ok());
    }

    #[test]
    fn parse_vsock_addr_rejects_malformed() {
        let malformed = [
            "vsock:3",        // too few fields
            "vsock:3:1234:5", // too many fields
            "vsok:3:1234",    // wrong scheme
            "vsock:x:1234",   // non-numeric cid
            "vsock:3:port",   // non-numeric port
            "vsock::",        // empty cid and port
        ];
        for addr in malformed.iter() {
            match parse_vsock_addr(addr) {
                Ok(_) => panic!("{:?} should be rejected", addr),
                Err(e) => assert_eq!(e.raw_os_error(), Some(libc::EINVAL), "addr: {:?}", addr),
            }
        }
    }

    #[test]
    fn malformed_vsock_sockaddr_reports_vsock_error() {
        let addr = "vsock:not-a-cid:1234";
        match StreamSocket::connect(addr) {
            Err(StreamSocketError::ConnectVsock(reported, e)) => {
                assert_eq!(reported, addr);
                assert_eq!(e.raw_os_error(), Some(libc::EINVAL));
            }
            Err(e) => panic!("unexpected error: {}", e),
            Ok(_) => panic!("malformed vsock address should not connect"),
        }
    }

    #[test]
    fn shut_down_write() {
        let (unix_stream, _dummy) = UnixStream::pair().unwrap();
        let mut stream: StreamSocket = unix_stream.into();

        stream.write_all(b"hello").unwrap();

        stream.shut_down_write().unwrap();

        assert!(stream.is_shut_down());
        assert!(stream.write(b"goodbye").is_err());
    }

    #[test]
    fn read_from_shut_down_sock() {
        let (unix_stream1, unix_stream2) = UnixStream::pair().unwrap();
        let mut stream1: StreamSocket = unix_stream1.into();
        let mut stream2: StreamSocket = unix_stream2.into();

        stream1.shut_down_write().unwrap();

        // Reads from the other end of the socket should now return EOF.
        let mut buf = Vec::new();
        assert_eq!(stream2.read_to_end(&mut buf).unwrap(), 0);
    }
}