Merge "Mark service-compos as system_ext_specific" into main
diff --git a/build/debian/build.sh b/build/debian/build.sh
index 3d3820a..2177b17 100755
--- a/build/debian/build.sh
+++ b/build/debian/build.sh
@@ -37,6 +37,7 @@
apt update
DEBIAN_FRONTEND=noninteractive \
apt install --no-install-recommends --assume-yes \
+ binfmt-support \
ca-certificates \
debsums \
dosfstools \
@@ -49,10 +50,11 @@
python3-marshmallow \
python3-pytest \
python3-yaml \
+ qemu-system-arm \
+ qemu-user-static \
qemu-utils \
udev \
- qemu-system-arm \
- qemu-user-static
+
sed -i s/losetup\ -f/losetup\ -P\ -f/g /usr/sbin/fai-diskimage
sed -i 's/wget \$/wget -t 0 \$/g' /usr/share/debootstrap/functions
diff --git a/build/debian/fai_config/hooks/extrbase.BASE b/build/debian/fai_config/hooks/extrbase.BASE
deleted file mode 100755
index 05d1e96..0000000
--- a/build/debian/fai_config/hooks/extrbase.BASE
+++ /dev/null
@@ -1,6 +0,0 @@
-#!/bin/sh
-set -euE
-
-touch "${LOGDIR}/skip.extrbase"
-
-debootstrap --verbose --variant minbase --arch "$DEBOOTSTRAP_ARCH" "$SUITE" "$FAI_ROOT" "$DEBOOTSTRAP_MIRROR"
diff --git a/build/debian/fai_config/hooks/partition.ARM64 b/build/debian/fai_config/hooks/partition.ARM64
deleted file mode 100755
index b3b603b..0000000
--- a/build/debian/fai_config/hooks/partition.ARM64
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/bin/sh
-set -eu
-touch $LOGDIR/skip.partition
-
-set -- $disklist
-device=/dev/$1
-
-wait_for_device() {
- for s in $(seq 10); do
- if [ -e "$1" ]; then
- break
- fi
- sleep 1
- done
-}
-
-sfdisk "$device" << EOF
-label: gpt
-unit: sectors
-
-# EFI system
-p15 : start=2048, size=260096, type="EFI System", uuid=${PARTUUID_ESP}
-# Linux
-p1 : start=262144, type="Linux root (ARM-64)", uuid=${PARTUUID_ROOT}
-EOF
-
-file=$(losetup -O BACK-FILE ${device} | tail -1)
-
-root_offset=$(parted -m ${device} unit B print | awk -F '[B:]' '/1:/{ print $2 }')
-root_size=$( parted -m ${device} unit B print | awk -F '[B:]' '/1:/{ print $6 }')
-efi_offset=$( parted -m ${device} unit B print | awk -F '[B:]' '/15:/{ print $2 }')
-efi_size=$( parted -m ${device} unit B print | awk -F '[B:]' '/15:/{ print $6 }')
-device_root=$(losetup -o ${root_offset} --sizelimit ${root_size} --show -f ${file})
-device_efi=$(losetup -o ${efi_offset} --sizelimit ${efi_size} --show -f ${file})
-rm -f ${device}p1
-rm -f ${device}p15
-ln -sf ${device_root} ${device}p1
-ln -sf ${device_efi} ${device}p15
-
-ls -al /dev/loop*
-losetup -a -l
-parted ${device} unit B print
-
-partprobe "$device"
-
-wait_for_device "$device_root"
-mkfs.ext4 -U "$FSUUID_ROOT" "$device_root"
-tune2fs -c 0 -i 0 "$device_root"
-
-wait_for_device "$device_efi"
-mkfs.vfat "$device_efi"
-
-parted ${device} unit B print
diff --git a/build/debian/forwarder_guest/Cargo.toml b/build/debian/forwarder_guest/Cargo.toml
new file mode 100644
index 0000000..e70dcd4
--- /dev/null
+++ b/build/debian/forwarder_guest/Cargo.toml
@@ -0,0 +1,11 @@
+[package]
+name = "forwarder_guest"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+clap = { version = "4.5.19", features = ["derive"] }
+forwarder = { path = "../../../libs/libforwarder" }
+poll_token_derive = "0.1.0"
+remain = "0.2.14"
+vmm-sys-util = "0.12.1"
diff --git a/build/debian/forwarder_guest/src/main.rs b/build/debian/forwarder_guest/src/main.rs
new file mode 100644
index 0000000..6ebd4ef
--- /dev/null
+++ b/build/debian/forwarder_guest/src/main.rs
@@ -0,0 +1,123 @@
+// Copyright 2024 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Copied from ChromiumOS with relicensing:
+// src/platform2/vm_tools/chunnel/src/bin/chunnel.rs
+
+//! Guest-side stream socket forwarder
+
+use std::fmt;
+use std::result;
+
+use clap::Parser;
+use forwarder::forwarder::{ForwarderError, ForwarderSession};
+use forwarder::stream::{StreamSocket, StreamSocketError};
+use poll_token_derive::PollToken;
+use vmm_sys_util::poll::{PollContext, PollToken};
+
/// Errors that abort the guest-side forwarder binary.
///
/// Variants must stay alphabetically ordered: `#[remain::sorted]` enforces it at compile time.
#[remain::sorted]
#[derive(Debug)]
enum Error {
 /// Connecting the local or the remote stream socket failed.
 ConnectSocket(StreamSocketError),
 /// Forwarding data between the two sockets failed.
 Forward(ForwarderError),
 /// Registering a socket with the poll context failed.
 PollContextAdd(vmm_sys_util::errno::Error),
 /// Unregistering a socket (after it reached EOF) from the poll context failed.
 PollContextDelete(vmm_sys_util::errno::Error),
 /// Creating the poll context failed.
 PollContextNew(vmm_sys_util::errno::Error),
 /// Waiting on the poll context failed.
 PollWait(vmm_sys_util::errno::Error),
}

/// Result type used throughout this binary, with `Error` as the error type.
type Result<T> = result::Result<T, Error>;
+
+impl fmt::Display for Error {
+ #[remain::check]
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ use self::Error::*;
+
+ #[remain::sorted]
+ match self {
+ ConnectSocket(e) => write!(f, "failed to connect socket: {}", e),
+ Forward(e) => write!(f, "failed to forward traffic: {}", e),
+ PollContextAdd(e) => write!(f, "failed to add fd to poll context: {}", e),
+ PollContextDelete(e) => write!(f, "failed to delete fd from poll context: {}", e),
+ PollContextNew(e) => write!(f, "failed to create poll context: {}", e),
+ PollWait(e) => write!(f, "failed to wait for poll: {}", e),
+ }
+ }
+}
+
+fn run_forwarder(local_stream: StreamSocket, remote_stream: StreamSocket) -> Result<()> {
+ #[derive(PollToken)]
+ enum Token {
+ LocalStreamReadable,
+ RemoteStreamReadable,
+ }
+ let poll_ctx: PollContext<Token> = PollContext::new().map_err(Error::PollContextNew)?;
+ poll_ctx.add(&local_stream, Token::LocalStreamReadable).map_err(Error::PollContextAdd)?;
+ poll_ctx.add(&remote_stream, Token::RemoteStreamReadable).map_err(Error::PollContextAdd)?;
+
+ let mut forwarder = ForwarderSession::new(local_stream, remote_stream);
+
+ loop {
+ let events = poll_ctx.wait().map_err(Error::PollWait)?;
+
+ for event in events.iter_readable() {
+ match event.token() {
+ Token::LocalStreamReadable => {
+ let shutdown = forwarder.forward_from_local().map_err(Error::Forward)?;
+ if shutdown {
+ poll_ctx
+ .delete(forwarder.local_stream())
+ .map_err(Error::PollContextDelete)?;
+ }
+ }
+ Token::RemoteStreamReadable => {
+ let shutdown = forwarder.forward_from_remote().map_err(Error::Forward)?;
+ if shutdown {
+ poll_ctx
+ .delete(forwarder.remote_stream())
+ .map_err(Error::PollContextDelete)?;
+ }
+ }
+ }
+ }
+ if forwarder.is_shut_down() {
+ return Ok(());
+ }
+ }
+}
+
+#[derive(Parser)]
+/// Flags for running command
+pub struct Args {
+ /// Local socket address
+ #[arg(long)]
+ #[arg(alias = "local")]
+ local_sockaddr: String,
+
+ /// Remote socket address
+ #[arg(long)]
+ #[arg(alias = "remote")]
+ remote_sockaddr: String,
+}
+
+// TODO(b/370897694): Support forwarding for datagram socket
+fn main() -> Result<()> {
+ let args = Args::parse();
+
+ let local_stream = StreamSocket::connect(&args.local_sockaddr).map_err(Error::ConnectSocket)?;
+ let remote_stream =
+ StreamSocket::connect(&args.remote_sockaddr).map_err(Error::ConnectSocket)?;
+
+ run_forwarder(local_stream, remote_stream)
+}
diff --git a/build/debian/kokoro/gcp_ubuntu_docker/build.sh b/build/debian/kokoro/gcp_ubuntu_docker/build.sh
index fb2a1a3..4598d1c 100644
--- a/build/debian/kokoro/gcp_ubuntu_docker/build.sh
+++ b/build/debian/kokoro/gcp_ubuntu_docker/build.sh
@@ -4,4 +4,6 @@
cd "${KOKORO_ARTIFACTS_DIR}/git/avf/build/debian/"
sudo losetup -D
+grep vmx /proc/cpuinfo || true  # Diagnostic: logs cpuinfo lines mentioning vmx; "|| true" forces success when none match.
sudo ./build.sh
+cp image.raw "${KOKORO_ARTIFACTS_DIR}"  # Quoted like the cd above; collected by define_artifacts (regex "image.raw").
diff --git a/build/debian/kokoro/gcp_ubuntu_docker/continuous.cfg b/build/debian/kokoro/gcp_ubuntu_docker/continuous.cfg
index d92031e..111096d 100644
--- a/build/debian/kokoro/gcp_ubuntu_docker/continuous.cfg
+++ b/build/debian/kokoro/gcp_ubuntu_docker/continuous.cfg
@@ -5,3 +5,9 @@
# Location of the bash script. Should have value <git_on_borg_scm.name>/<path_from_repository_root>.
# git_on_borg_scm.name is specified in the job configuration (next section).
build_file: "avf/build/debian/kokoro/gcp_ubuntu_docker/build.sh"
+
+action {
+ define_artifacts {
+ regex: "image.raw"
+ }
+}
diff --git a/build/debian/kokoro/gcp_ubuntu_docker/presubmit.cfg b/build/debian/kokoro/gcp_ubuntu_docker/hourly.cfg
similarity index 85%
rename from build/debian/kokoro/gcp_ubuntu_docker/presubmit.cfg
rename to build/debian/kokoro/gcp_ubuntu_docker/hourly.cfg
index d92031e..111096d 100644
--- a/build/debian/kokoro/gcp_ubuntu_docker/presubmit.cfg
+++ b/build/debian/kokoro/gcp_ubuntu_docker/hourly.cfg
@@ -5,3 +5,9 @@
# Location of the bash script. Should have value <git_on_borg_scm.name>/<path_from_repository_root>.
# git_on_borg_scm.name is specified in the job configuration (next section).
build_file: "avf/build/debian/kokoro/gcp_ubuntu_docker/build.sh"
+
+action {
+ define_artifacts {
+ regex: "image.raw"
+ }
+}
diff --git a/guest/pvmfw/avb/Android.bp b/guest/pvmfw/avb/Android.bp
index 558152d..f97a713 100644
--- a/guest/pvmfw/avb/Android.bp
+++ b/guest/pvmfw/avb/Android.bp
@@ -43,6 +43,7 @@
":test_image_with_duplicated_capability",
":test_image_with_rollback_index_5",
":test_image_with_multiple_capabilities",
+ ":test_image_with_all_capabilities",
":unsigned_test_image",
],
prefer_rlib: true,
@@ -218,3 +219,17 @@
},
],
}
+
+avb_add_hash_footer {
+ name: "test_image_with_all_capabilities",
+ src: ":unsigned_test_image",
+ partition_name: "boot",
+ private_key: ":pvmfw_sign_key",
+ salt: "4231",
+ props: [
+ {
+ name: "com.android.virt.cap",
+ value: "remote_attest|secretkeeper_protection|supports_uefi_boot",
+ },
+ ],
+}
diff --git a/guest/pvmfw/avb/src/verify.rs b/guest/pvmfw/avb/src/verify.rs
index 038b1d6..bd700ce 100644
--- a/guest/pvmfw/avb/src/verify.rs
+++ b/guest/pvmfw/avb/src/verify.rs
@@ -70,6 +70,11 @@
RemoteAttest,
/// Secretkeeper protected secrets.
SecretkeeperProtection,
+ /// UEFI support for booting guest kernel.
+ SupportsUefiBoot,
+ /// (internal)
+ #[allow(non_camel_case_types)] // TODO: Use mem::variant_count once stable.
+ _VARIANT_COUNT,
}
impl Capability {
@@ -77,6 +82,9 @@
const REMOTE_ATTEST: &'static [u8] = b"remote_attest";
const SECRETKEEPER_PROTECTION: &'static [u8] = b"secretkeeper_protection";
const SEPARATOR: u8 = b'|';
+ const SUPPORTS_UEFI_BOOT: &'static [u8] = b"supports_uefi_boot";
+ /// Number of supported capabilities.
+ pub const COUNT: usize = Self::_VARIANT_COUNT as usize;
/// Returns the capabilities indicated in `descriptor`, or error if the descriptor has
/// unexpected contents.
@@ -91,6 +99,7 @@
let cap = match v {
Self::REMOTE_ATTEST => Self::RemoteAttest,
Self::SECRETKEEPER_PROTECTION => Self::SecretkeeperProtection,
+ Self::SUPPORTS_UEFI_BOOT => Self::SupportsUefiBoot,
_ => return Err(PvmfwVerifyError::UnknownVbmetaProperty),
};
if res.contains(&cap) {
diff --git a/guest/pvmfw/avb/tests/api_test.rs b/guest/pvmfw/avb/tests/api_test.rs
index 8683e69..01c13d4 100644
--- a/guest/pvmfw/avb/tests/api_test.rs
+++ b/guest/pvmfw/avb/tests/api_test.rs
@@ -38,6 +38,7 @@
const TEST_IMG_WITH_INITRD_AND_NON_INITRD_DESC_PATH: &str =
"test_image_with_initrd_and_non_initrd_desc.img";
const TEST_IMG_WITH_MULTIPLE_CAPABILITIES: &str = "test_image_with_multiple_capabilities.img";
+const TEST_IMG_WITH_ALL_CAPABILITIES: &str = "test_image_with_all_capabilities.img";
const UNSIGNED_TEST_IMG_PATH: &str = "unsigned_test.img";
const RANDOM_FOOTER_POS: usize = 30;
@@ -418,3 +419,22 @@
assert!(verified_boot_data.has_capability(Capability::SecretkeeperProtection));
Ok(())
}
+
+/// Verifies `test_image_with_all_capabilities.img`, whose `com.android.virt.cap` property lists
+/// every supported capability, and checks that each one is reported by `has_capability`.
+#[test]
+fn payload_with_all_capabilities() -> Result<()> {
+ let public_key = load_trusted_public_key()?;
+ let verified_boot_data = verify_payload(
+ &fs::read(TEST_IMG_WITH_ALL_CAPABILITIES)?,
+ /* initrd= */ None,
+ &public_key,
+ )
+ .map_err(|e| anyhow!("Verification failed. Error: {}", e))?;
+
+ assert!(verified_boot_data.has_capability(Capability::RemoteAttest));
+ assert!(verified_boot_data.has_capability(Capability::SecretkeeperProtection));
+ assert!(verified_boot_data.has_capability(Capability::SupportsUefiBoot));
+ // Fail if this test doesn't actually cover all supported capabilities.
+ // When a Capability variant is added, extend the image's props in Android.bp and the asserts above.
+ assert_eq!(Capability::COUNT, 3);
+
+ Ok(())
+}
diff --git a/libs/libforwarder/Android.bp b/libs/libforwarder/Android.bp
new file mode 100644
index 0000000..48307e7
--- /dev/null
+++ b/libs/libforwarder/Android.bp
@@ -0,0 +1,15 @@
+package {
+ default_applicable_licenses: ["Android-Apache-2.0"],
+}
+
+rust_library {
+ name: "libforwarder",
+ crate_name: "forwarder",
+ edition: "2021",
+ srcs: ["src/lib.rs"],
+ rustlibs: [
+ "liblibc",
+ "libvsock",
+ ],
+ proc_macros: ["libremain"],
+}
diff --git a/libs/libforwarder/Cargo.toml b/libs/libforwarder/Cargo.toml
new file mode 100644
index 0000000..9f3f341
--- /dev/null
+++ b/libs/libforwarder/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "forwarder"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+libc = "0.2.159"
+remain = "0.2.14"
+vsock = "0.5.1"
+
+# Needed by the unit tests in src/stream.rs, which use tempfile::TempDir.
+[dev-dependencies]
+tempfile = "3.10.1"
diff --git a/libs/libforwarder/src/forwarder.rs b/libs/libforwarder/src/forwarder.rs
new file mode 100644
index 0000000..3600ab2
--- /dev/null
+++ b/libs/libforwarder/src/forwarder.rs
@@ -0,0 +1,170 @@
+// Copyright 2024 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Copied from ChromiumOS with relicensing:
+// src/platform2/vm_tools/chunnel/src/forwarder.rs
+
//! This module contains the forwarding mechanism between stream sockets.
+
+use std::fmt;
+use std::io::{self, Read, Write};
+use std::result;
+
+use crate::stream::StreamSocket;
+
// This was picked arbitrarily. crosvm doesn't yet use VIRTIO_NET_F_MTU, so there's no reason to
// opt for massive 65535 byte frames.
/// Size in bytes of the read buffer used by one forwarding step; a single step moves at most
/// this many bytes from one stream to the other.
const MAX_FRAME_SIZE: usize = 8192;

/// Errors that can be encountered by a ForwarderSession.
///
/// Each variant wraps the `io::Error` reported by the failing stream operation.
#[remain::sorted]
#[derive(Debug)]
pub enum ForwarderError {
 /// An io::Error was encountered while reading from a stream.
 ReadFromStream(io::Error),
 /// An io::Error was encountered while shutting down writes on a stream.
 ShutDownStream(io::Error),
 /// An io::Error was encountered while writing to a stream.
 WriteToStream(io::Error),
}

/// Result type for forwarding operations, with `ForwarderError` as the error type.
type Result<T> = result::Result<T, ForwarderError>;
+
+impl fmt::Display for ForwarderError {
+ #[remain::check]
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ use self::ForwarderError::*;
+
+ #[remain::sorted]
+ match self {
+ ReadFromStream(e) => write!(f, "failed to read from stream: {}", e),
+ ShutDownStream(e) => write!(f, "failed to shut down stream: {}", e),
+ WriteToStream(e) => write!(f, "failed to write to stream: {}", e),
+ }
+ }
+}
+
/// A ForwarderSession owns stream sockets that it forwards traffic between.
///
/// Dropping the session drops both `StreamSocket`s, which close their descriptors.
pub struct ForwarderSession {
 /// Local end; EOF read here causes writes on `remote` to be shut down.
 local: StreamSocket,
 /// Remote end; EOF read here causes writes on `local` to be shut down.
 remote: StreamSocket,
}
+
+fn forward(from_stream: &mut StreamSocket, to_stream: &mut StreamSocket) -> Result<bool> {
+ let mut buf = [0u8; MAX_FRAME_SIZE];
+
+ let count = from_stream.read(&mut buf).map_err(ForwarderError::ReadFromStream)?;
+ if count == 0 {
+ to_stream.shut_down_write().map_err(ForwarderError::ShutDownStream)?;
+ return Ok(true);
+ }
+
+ to_stream.write_all(&buf[..count]).map_err(ForwarderError::WriteToStream)?;
+ Ok(false)
+}
+
+impl ForwarderSession {
+ /// Creates a forwarder session from a local and remote stream socket.
+ pub fn new(local: StreamSocket, remote: StreamSocket) -> Self {
+ ForwarderSession { local, remote }
+ }
+
+ /// Forwards traffic from the local socket to the remote socket.
+ /// Returns true if the local socket has reached EOF and the
+ /// remote socket has been shut down for further writes.
+ pub fn forward_from_local(&mut self) -> Result<bool> {
+ forward(&mut self.local, &mut self.remote)
+ }
+
+ /// Forwards traffic from the remote socket to the local socket.
+ /// Returns true if the remote socket has reached EOF and the
+ /// local socket has been shut down for further writes.
+ pub fn forward_from_remote(&mut self) -> Result<bool> {
+ forward(&mut self.remote, &mut self.local)
+ }
+
+ /// Returns a reference to the local stream socket.
+ pub fn local_stream(&self) -> &StreamSocket {
+ &self.local
+ }
+
+ /// Returns a reference to the remote stream socket.
+ pub fn remote_stream(&self) -> &StreamSocket {
+ &self.remote
+ }
+
+ /// Returns true if both sockets are completely shut down and the session can be dropped.
+ pub fn is_shut_down(&self) -> bool {
+ self.local.is_shut_down() && self.remote.is_shut_down()
+ }
+}
+
// Unit tests for ForwarderSession, driven over in-process unix socket pairs.
#[cfg(test)]
mod tests {
 use super::*;
 use std::io::{Read, Write};
 use std::net::Shutdown;
 use std::os::unix::net::UnixStream;

 /// Checks that data and EOF flow in both directions through a ForwarderSession, and that the
 /// session only reports itself shut down after both directions have reached EOF.
 #[test]
 fn forward_unix() {
 // Local streams.
 let (mut london, folkestone) = UnixStream::pair().unwrap();
 // Remote streams.
 let (coquelles, mut paris) = UnixStream::pair().unwrap();

 // Connect the local and remote sockets via the chunnel.
 // ("chunnel" is the ChromiumOS tool this code was copied from; here it is the session.)
 let mut forwarder = ForwarderSession::new(folkestone.into(), coquelles.into());

 // Put some traffic in from London.
 let greeting = b"hello";
 london.write_all(greeting).unwrap();

 // Expect forwarding from the local end not to have reached EOF.
 assert!(!forwarder.forward_from_local().unwrap());
 let mut salutation = [0u8; 8];
 let count = paris.read(&mut salutation).unwrap();
 assert_eq!(greeting.len(), count);
 assert_eq!(greeting, &salutation[..count]);

 // Shut the local socket down. The forwarder should detect this and perform a shutdown,
 // which will manifest as an EOF when reading.
 london.shutdown(Shutdown::Write).unwrap();
 assert!(forwarder.forward_from_local().unwrap());
 assert_eq!(paris.read(&mut salutation).unwrap(), 0);

 // Don't consider the forwarder shut down until both ends are.
 assert!(!forwarder.is_shut_down());

 // Forward traffic from the remote end.
 let salutation = b"bonjour";
 paris.write_all(salutation).unwrap();

 // Expect forwarding from the remote end not to have reached EOF.
 assert!(!forwarder.forward_from_remote().unwrap());
 let mut greeting = [0u8; 8];
 let count = london.read(&mut greeting).unwrap();
 assert_eq!(salutation.len(), count);
 assert_eq!(salutation, &greeting[..count]);

 // Shut the remote socket down. The forwarder should detect this and perform a shutdown,
 // which will manifest as an EOF when reading.
 paris.shutdown(Shutdown::Write).unwrap();
 assert!(forwarder.forward_from_remote().unwrap());
 assert_eq!(london.read(&mut greeting).unwrap(), 0);

 // The forwarder should now be considered shut down.
 assert!(forwarder.is_shut_down());
 }
}
diff --git a/libs/libforwarder/src/lib.rs b/libs/libforwarder/src/lib.rs
new file mode 100644
index 0000000..bcce689
--- /dev/null
+++ b/libs/libforwarder/src/lib.rs
@@ -0,0 +1,21 @@
+// Copyright 2024 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Copied from ChromiumOS with relicensing:
+// src/platform2/vm_tools/chunnel/src/lib.rs
+
+//! Library for stream socket forwarding.
+
+pub mod forwarder;
+pub mod stream;
diff --git a/libs/libforwarder/src/stream.rs b/libs/libforwarder/src/stream.rs
new file mode 100644
index 0000000..d8c7f51
--- /dev/null
+++ b/libs/libforwarder/src/stream.rs
@@ -0,0 +1,263 @@
+// Copyright 2024 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Copied from ChromiumOS with relicensing:
+// src/platform2/vm_tools/chunnel/src/stream.rs
+
//! This module provides an abstraction over various stream socket types.
+
+use std::fmt;
+use std::io;
+use std::net::TcpStream;
+use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
+use std::os::unix::net::UnixStream;
+use std::result;
+
+use libc::{self, c_void, shutdown, EPIPE, SHUT_WR};
+use vsock::VsockAddr;
+use vsock::VsockStream;
+
+/// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form
+/// "vsock:cid:port".
+pub fn parse_vsock_addr(addr: &str) -> result::Result<VsockAddr, io::Error> {
+ let components: Vec<&str> = addr.split(':').collect();
+ if components.len() != 3 || components[0] != "vsock" {
+ return Err(io::Error::from_raw_os_error(libc::EINVAL));
+ }
+
+ Ok(VsockAddr::new(
+ components[1].parse().map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?,
+ components[2].parse().map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?,
+ ))
+}
+
/// StreamSocket provides a generic abstraction around any connection-oriented stream socket.
/// The socket will be closed when StreamSocket is dropped, but writes to the socket can also
/// be shut down manually.
pub struct StreamSocket {
 /// Owned socket descriptor; closed when the StreamSocket is dropped.
 fd: RawFd,
 /// True once writes are known to be shut down: set by a successful `shut_down_write()` or
 /// when a write fails with `EPIPE`.
 shut_down: bool,
}
+
+impl StreamSocket {
+ /// Connects to the given socket address. Supported socket types are vsock, unix, and TCP.
+ pub fn connect(sockaddr: &str) -> result::Result<StreamSocket, StreamSocketError> {
+ const UNIX_PREFIX: &str = "unix:";
+ const VSOCK_PREFIX: &str = "vsock:";
+
+ if sockaddr.starts_with(VSOCK_PREFIX) {
+ let addr = parse_vsock_addr(sockaddr)
+ .map_err(|e| StreamSocketError::ConnectVsock(sockaddr.to_string(), e))?;
+ let vsock_stream = VsockStream::connect(&addr)
+ .map_err(|e| StreamSocketError::ConnectVsock(sockaddr.to_string(), e))?;
+ Ok(vsock_stream.into())
+ } else if sockaddr.starts_with(UNIX_PREFIX) {
+ let (_prefix, sock_path) = sockaddr.split_at(UNIX_PREFIX.len());
+ let unix_stream = UnixStream::connect(sock_path)
+ .map_err(|e| StreamSocketError::ConnectUnix(sockaddr.to_string(), e))?;
+ Ok(unix_stream.into())
+ } else {
+ // Assume this is a TCP stream.
+ let tcp_stream = TcpStream::connect(sockaddr)
+ .map_err(|e| StreamSocketError::ConnectTcp(sockaddr.to_string(), e))?;
+ Ok(tcp_stream.into())
+ }
+ }
+
+ /// Shuts down writes to the socket using shutdown(2).
+ pub fn shut_down_write(&mut self) -> io::Result<()> {
+ // SAFETY:
+ // Safe because no memory is modified and the return value is checked.
+ let ret = unsafe { shutdown(self.fd, SHUT_WR) };
+ if ret < 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ self.shut_down = true;
+ Ok(())
+ }
+
+ /// Returns true if the socket has been shut down for writes, false otherwise.
+ pub fn is_shut_down(&self) -> bool {
+ self.shut_down
+ }
+}
+
+impl io::Read for StreamSocket {
+ fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+ // SAFETY:
+ // Safe because this will only modify the contents of |buf| and we check the return value.
+ let ret = unsafe { libc::read(self.fd, buf.as_mut_ptr() as *mut c_void, buf.len()) };
+ if ret < 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ Ok(ret as usize)
+ }
+}
+
+impl io::Write for StreamSocket {
+ fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+ // SAFETY:
+ // Safe because this doesn't modify any memory and we check the return value.
+ let ret = unsafe { libc::write(self.fd, buf.as_ptr() as *const c_void, buf.len()) };
+ if ret < 0 {
+ // If a write causes EPIPE then the socket is shut down for writes.
+ let err = io::Error::last_os_error();
+ if let Some(errno) = err.raw_os_error() {
+ if errno == EPIPE {
+ self.shut_down = true
+ }
+ }
+
+ return Err(err);
+ }
+
+ Ok(ret as usize)
+ }
+
+ fn flush(&mut self) -> io::Result<()> {
+ // No buffered data so nothing to do.
+ Ok(())
+ }
+}
+
+impl AsRawFd for StreamSocket {
+ fn as_raw_fd(&self) -> RawFd {
+ self.fd
+ }
+}
+
+impl From<TcpStream> for StreamSocket {
+ fn from(stream: TcpStream) -> Self {
+ StreamSocket { fd: stream.into_raw_fd(), shut_down: false }
+ }
+}
+
+impl From<UnixStream> for StreamSocket {
+ fn from(stream: UnixStream) -> Self {
+ StreamSocket { fd: stream.into_raw_fd(), shut_down: false }
+ }
+}
+
+impl From<VsockStream> for StreamSocket {
+ fn from(stream: VsockStream) -> Self {
+ StreamSocket { fd: stream.into_raw_fd(), shut_down: false }
+ }
+}
+
+impl FromRawFd for StreamSocket {
+ unsafe fn from_raw_fd(fd: RawFd) -> Self {
+ StreamSocket { fd, shut_down: false }
+ }
+}
+
+impl Drop for StreamSocket {
+ fn drop(&mut self) {
+ // SAFETY:
+ // Safe because this doesn't modify any memory and we are the only
+ // owner of the file descriptor.
+ unsafe { libc::close(self.fd) };
+ }
+}
+
/// Errors returned by `StreamSocket::connect`.
///
/// Every variant carries the address string passed to `connect` and the `io::Error` that
/// caused the failure.
#[remain::sorted]
#[derive(Debug)]
pub enum StreamSocketError {
 /// Error on connecting TCP socket.
 ConnectTcp(String, io::Error),
 /// Error on connecting unix socket.
 ConnectUnix(String, io::Error),
 /// Error on connecting vsock socket (also used when the vsock address fails to parse).
 ConnectVsock(String, io::Error),
}
+
+impl fmt::Display for StreamSocketError {
+ #[remain::check]
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ use self::StreamSocketError::*;
+
+ #[remain::sorted]
+ match self {
+ ConnectTcp(sockaddr, e) => {
+ write!(f, "failed to connect to TCP sockaddr {}: {}", sockaddr, e)
+ }
+ ConnectUnix(sockaddr, e) => {
+ write!(f, "failed to connect to unix sockaddr {}: {}", sockaddr, e)
+ }
+ ConnectVsock(sockaddr, e) => {
+ write!(f, "failed to connect to vsock sockaddr {}: {}", sockaddr, e)
+ }
+ }
+ }
+}
+
// Unit tests for StreamSocket. They use the `tempfile` crate, which must be declared as a
// dev-dependency of this crate for `cargo test` to build them.
#[cfg(test)]
mod tests {
 use super::*;
 use std::io::{Read, Write};
 use std::net::TcpListener;
 use std::os::unix::net::{UnixListener, UnixStream};
 use tempfile::TempDir;

 /// A plain "host:port" address (no prefix) is connected over TCP.
 #[test]
 fn sock_connect_tcp() {
 // Port 0 lets the OS pick a free port; read it back to build the address.
 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
 let sockaddr = format!("127.0.0.1:{}", listener.local_addr().unwrap().port());

 let _stream = StreamSocket::connect(&sockaddr).unwrap();
 }

 /// A "unix:<path>" address is connected as a unix domain socket.
 #[test]
 fn sock_connect_unix() {
 let tempdir = TempDir::new().unwrap();
 let path = tempdir.path().to_owned().join("test.sock");
 let _listener = UnixListener::bind(&path).unwrap();

 let unix_addr = format!("unix:{}", path.to_str().unwrap());
 let _stream = StreamSocket::connect(&unix_addr).unwrap();
 }

 /// A string that is not a valid address for any transport fails to connect.
 #[test]
 fn invalid_sockaddr() {
 assert!(StreamSocket::connect("this is not a valid sockaddr").is_err());
 }

 /// After shut_down_write, the socket reports itself shut down and further writes fail.
 #[test]
 fn shut_down_write() {
 let (unix_stream, _dummy) = UnixStream::pair().unwrap();
 let mut stream: StreamSocket = unix_stream.into();

 stream.write_all(b"hello").unwrap();

 stream.shut_down_write().unwrap();

 assert!(stream.is_shut_down());
 assert!(stream.write(b"goodbye").is_err());
 }

 /// Shutting down writes on one end is observed as EOF on the other end.
 #[test]
 fn read_from_shut_down_sock() {
 let (unix_stream1, unix_stream2) = UnixStream::pair().unwrap();
 let mut stream1: StreamSocket = unix_stream1.into();
 let mut stream2: StreamSocket = unix_stream2.into();

 stream1.shut_down_write().unwrap();

 // Reads from the other end of the socket should now return EOF.
 let mut buf = Vec::new();
 assert_eq!(stream2.read_to_end(&mut buf).unwrap(), 0);
 }
}