blob: d8c7f51eeb9d006677ddd643511f053b9b9a8a07 [file] [log] [blame]
Seungjae Yooa55e7732024-09-25 12:59:57 +09001// Copyright 2024 The Android Open Source Project
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Copied from ChromiumOS with relicensing:
16// src/platform2/vm_tools/chunnel/src/stream.rs
17
//! This module provides an abstraction over various stream socket types.
19
Seungjae Yooa55e7732024-09-25 12:59:57 +090020use std::fmt;
21use std::io;
22use std::net::TcpStream;
23use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
24use std::os::unix::net::UnixStream;
25use std::result;
26
27use libc::{self, c_void, shutdown, EPIPE, SHUT_WR};
28use vsock::VsockAddr;
29use vsock::VsockStream;
30
31/// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form
32/// "vsock:cid:port".
33pub fn parse_vsock_addr(addr: &str) -> result::Result<VsockAddr, io::Error> {
34 let components: Vec<&str> = addr.split(':').collect();
35 if components.len() != 3 || components[0] != "vsock" {
36 return Err(io::Error::from_raw_os_error(libc::EINVAL));
37 }
38
39 Ok(VsockAddr::new(
Seungjae Yoo82f68b32024-09-25 13:03:31 +090040 components[1].parse().map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?,
41 components[2].parse().map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?,
Seungjae Yooa55e7732024-09-25 12:59:57 +090042 ))
43}
44
/// StreamSocket provides a generic abstraction around any connection-oriented stream socket.
/// The socket will be closed when StreamSocket is dropped, but writes to the socket can also
/// be shut down manually.
#[derive(Debug)]
pub struct StreamSocket {
    /// The socket descriptor owned by this value; it is closed on drop.
    fd: RawFd,
    /// Set once writes are shut down, either explicitly or after a write failed with `EPIPE`.
    shut_down: bool,
}
52
53impl StreamSocket {
54 /// Connects to the given socket address. Supported socket types are vsock, unix, and TCP.
55 pub fn connect(sockaddr: &str) -> result::Result<StreamSocket, StreamSocketError> {
56 const UNIX_PREFIX: &str = "unix:";
57 const VSOCK_PREFIX: &str = "vsock:";
58
59 if sockaddr.starts_with(VSOCK_PREFIX) {
60 let addr = parse_vsock_addr(sockaddr)
61 .map_err(|e| StreamSocketError::ConnectVsock(sockaddr.to_string(), e))?;
62 let vsock_stream = VsockStream::connect(&addr)
63 .map_err(|e| StreamSocketError::ConnectVsock(sockaddr.to_string(), e))?;
64 Ok(vsock_stream.into())
65 } else if sockaddr.starts_with(UNIX_PREFIX) {
66 let (_prefix, sock_path) = sockaddr.split_at(UNIX_PREFIX.len());
67 let unix_stream = UnixStream::connect(sock_path)
68 .map_err(|e| StreamSocketError::ConnectUnix(sockaddr.to_string(), e))?;
69 Ok(unix_stream.into())
70 } else {
71 // Assume this is a TCP stream.
72 let tcp_stream = TcpStream::connect(sockaddr)
73 .map_err(|e| StreamSocketError::ConnectTcp(sockaddr.to_string(), e))?;
74 Ok(tcp_stream.into())
75 }
76 }
77
78 /// Shuts down writes to the socket using shutdown(2).
79 pub fn shut_down_write(&mut self) -> io::Result<()> {
Seungjae Yoo82f68b32024-09-25 13:03:31 +090080 // SAFETY:
Seungjae Yooa55e7732024-09-25 12:59:57 +090081 // Safe because no memory is modified and the return value is checked.
82 let ret = unsafe { shutdown(self.fd, SHUT_WR) };
83 if ret < 0 {
84 return Err(io::Error::last_os_error());
85 }
86
87 self.shut_down = true;
88 Ok(())
89 }
90
91 /// Returns true if the socket has been shut down for writes, false otherwise.
92 pub fn is_shut_down(&self) -> bool {
93 self.shut_down
94 }
95}
96
97impl io::Read for StreamSocket {
98 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
Seungjae Yoo82f68b32024-09-25 13:03:31 +090099 // SAFETY:
Seungjae Yooa55e7732024-09-25 12:59:57 +0900100 // Safe because this will only modify the contents of |buf| and we check the return value.
101 let ret = unsafe { libc::read(self.fd, buf.as_mut_ptr() as *mut c_void, buf.len()) };
102 if ret < 0 {
103 return Err(io::Error::last_os_error());
104 }
105
106 Ok(ret as usize)
107 }
108}
109
110impl io::Write for StreamSocket {
111 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
Seungjae Yoo82f68b32024-09-25 13:03:31 +0900112 // SAFETY:
Seungjae Yooa55e7732024-09-25 12:59:57 +0900113 // Safe because this doesn't modify any memory and we check the return value.
114 let ret = unsafe { libc::write(self.fd, buf.as_ptr() as *const c_void, buf.len()) };
115 if ret < 0 {
116 // If a write causes EPIPE then the socket is shut down for writes.
117 let err = io::Error::last_os_error();
118 if let Some(errno) = err.raw_os_error() {
119 if errno == EPIPE {
120 self.shut_down = true
121 }
122 }
123
124 return Err(err);
125 }
126
127 Ok(ret as usize)
128 }
129
130 fn flush(&mut self) -> io::Result<()> {
131 // No buffered data so nothing to do.
132 Ok(())
133 }
134}
135
136impl AsRawFd for StreamSocket {
137 fn as_raw_fd(&self) -> RawFd {
138 self.fd
139 }
140}
141
142impl From<TcpStream> for StreamSocket {
143 fn from(stream: TcpStream) -> Self {
Seungjae Yoo82f68b32024-09-25 13:03:31 +0900144 StreamSocket { fd: stream.into_raw_fd(), shut_down: false }
Seungjae Yooa55e7732024-09-25 12:59:57 +0900145 }
146}
147
148impl From<UnixStream> for StreamSocket {
149 fn from(stream: UnixStream) -> Self {
Seungjae Yoo82f68b32024-09-25 13:03:31 +0900150 StreamSocket { fd: stream.into_raw_fd(), shut_down: false }
Seungjae Yooa55e7732024-09-25 12:59:57 +0900151 }
152}
153
154impl From<VsockStream> for StreamSocket {
155 fn from(stream: VsockStream) -> Self {
Seungjae Yoo82f68b32024-09-25 13:03:31 +0900156 StreamSocket { fd: stream.into_raw_fd(), shut_down: false }
Seungjae Yooa55e7732024-09-25 12:59:57 +0900157 }
158}
159
160impl FromRawFd for StreamSocket {
161 unsafe fn from_raw_fd(fd: RawFd) -> Self {
Seungjae Yoo82f68b32024-09-25 13:03:31 +0900162 StreamSocket { fd, shut_down: false }
Seungjae Yooa55e7732024-09-25 12:59:57 +0900163 }
164}
165
166impl Drop for StreamSocket {
167 fn drop(&mut self) {
Seungjae Yoo82f68b32024-09-25 13:03:31 +0900168 // SAFETY:
Seungjae Yooa55e7732024-09-25 12:59:57 +0900169 // Safe because this doesn't modify any memory and we are the only
170 // owner of the file descriptor.
171 unsafe { libc::close(self.fd) };
172 }
173}
174
175/// Error enums for StreamSocket.
176#[remain::sorted]
177#[derive(Debug)]
178pub enum StreamSocketError {
Seungjae Yoo82f68b32024-09-25 13:03:31 +0900179 /// Error on connecting TCP socket.
Seungjae Yooa55e7732024-09-25 12:59:57 +0900180 ConnectTcp(String, io::Error),
Seungjae Yoo82f68b32024-09-25 13:03:31 +0900181 /// Error on connecting unix socket.
Seungjae Yooa55e7732024-09-25 12:59:57 +0900182 ConnectUnix(String, io::Error),
Seungjae Yoo82f68b32024-09-25 13:03:31 +0900183 /// Error on connecting vsock socket.
Seungjae Yooa55e7732024-09-25 12:59:57 +0900184 ConnectVsock(String, io::Error),
185}
186
187impl fmt::Display for StreamSocketError {
188 #[remain::check]
189 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
190 use self::StreamSocketError::*;
191
192 #[remain::sorted]
193 match self {
194 ConnectTcp(sockaddr, e) => {
195 write!(f, "failed to connect to TCP sockaddr {}: {}", sockaddr, e)
196 }
197 ConnectUnix(sockaddr, e) => {
198 write!(f, "failed to connect to unix sockaddr {}: {}", sockaddr, e)
199 }
200 ConnectVsock(sockaddr, e) => {
201 write!(f, "failed to connect to vsock sockaddr {}: {}", sockaddr, e)
202 }
203 }
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::os::unix::net::{UnixListener, UnixStream};
    use tempfile::TempDir;

    #[test]
    fn sock_connect_tcp() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let sockaddr = format!("127.0.0.1:{}", listener.local_addr().unwrap().port());

        let _stream = StreamSocket::connect(&sockaddr).unwrap();
    }

    #[test]
    fn sock_connect_unix() {
        let tempdir = TempDir::new().unwrap();
        // `Path::join` already returns an owned `PathBuf`; no need to copy the path first.
        let path = tempdir.path().join("test.sock");
        let _listener = UnixListener::bind(&path).unwrap();

        let unix_addr = format!("unix:{}", path.to_str().unwrap());
        let _stream = StreamSocket::connect(&unix_addr).unwrap();
    }

    #[test]
    fn invalid_sockaddr() {
        assert!(StreamSocket::connect("this is not a valid sockaddr").is_err());
    }

    #[test]
    fn parse_valid_vsock_addr() {
        let addr = parse_vsock_addr("vsock:3:5000").unwrap();
        assert_eq!(addr.cid(), 3);
        assert_eq!(addr.port(), 5000);
    }

    #[test]
    fn parse_invalid_vsock_addr() {
        // Wrong prefix, wrong field count, and non-numeric fields must all be EINVAL.
        for addr in [
            "",
            "vsock",
            "vsock:3",
            "vsock:3:5000:1",
            "unix:3:5000",
            "vsock:x:5000",
            "vsock:3:",
        ] {
            match parse_vsock_addr(addr) {
                Ok(_) => panic!("{} should not parse as a vsock address", addr),
                Err(e) => assert_eq!(e.raw_os_error(), Some(libc::EINVAL), "{}", addr),
            }
        }
    }

    #[test]
    fn shut_down_write() {
        let (unix_stream, _dummy) = UnixStream::pair().unwrap();
        let mut stream: StreamSocket = unix_stream.into();

        stream.write_all(b"hello").unwrap();

        stream.shut_down_write().unwrap();

        assert!(stream.is_shut_down());
        assert!(stream.write(b"goodbye").is_err());
    }

    #[test]
    fn read_from_shut_down_sock() {
        let (unix_stream1, unix_stream2) = UnixStream::pair().unwrap();
        let mut stream1: StreamSocket = unix_stream1.into();
        let mut stream2: StreamSocket = unix_stream2.into();

        stream1.shut_down_write().unwrap();

        // Reads from the other end of the socket should now return EOF.
        let mut buf = Vec::new();
        assert_eq!(stream2.read_to_end(&mut buf).unwrap(), 0);
    }
}