[refactoring] Add VmType enum to service_vm_comm
to improve the readability of the code.
Bug: 299089107
Test: atest rialto_test
Change-Id: Ic8fec11701e90815e83c01b1914cb363604bff50
diff --git a/libs/service_vm_comm/src/lib.rs b/libs/service_vm_comm/src/lib.rs
index 6f03209..3b53b63 100644
--- a/libs/service_vm_comm/src/lib.rs
+++ b/libs/service_vm_comm/src/lib.rs
@@ -23,4 +23,4 @@
mod vsock;
pub use message::{EcdsaP256KeyPair, GenerateCertificateRequestParams, Request, Response};
-pub use vsock::host_port;
+pub use vsock::VmType;
diff --git a/libs/service_vm_comm/src/vsock.rs b/libs/service_vm_comm/src/vsock.rs
index fd6f088..aa7166d 100644
--- a/libs/service_vm_comm/src/vsock.rs
+++ b/libs/service_vm_comm/src/vsock.rs
@@ -14,14 +14,34 @@
//! Vsock setup shared between the host and the service VM.
-/// Returns the host port number for the given VM protection state.
-pub fn host_port(is_protected_vm: bool) -> u32 {
- const PROTECTED_VM_PORT: u32 = 5679;
- const NON_PROTECTED_VM_PORT: u32 = 5680;
+const PROTECTED_VM_PORT: u32 = 5679;
+const NON_PROTECTED_VM_PORT: u32 = 5680;
- if is_protected_vm {
- PROTECTED_VM_PORT
- } else {
- NON_PROTECTED_VM_PORT
+/// VM Type.
+#[derive(Clone, Copy, Debug)]
+pub enum VmType {
+ /// Protected VM.
+ ProtectedVm,
+
+ /// NonProtectev VM.
+ NonProtectedVm,
+}
+
+impl VmType {
+ /// Returns the port number used for the vsock communication between
+ /// the host and the service VM.
+ pub fn port(&self) -> u32 {
+ match self {
+ Self::ProtectedVm => PROTECTED_VM_PORT,
+ Self::NonProtectedVm => NON_PROTECTED_VM_PORT,
+ }
+ }
+
+ /// Returns whether it is a protected VM.
+ pub fn is_protected(&self) -> bool {
+ match self {
+ Self::ProtectedVm => true,
+ Self::NonProtectedVm => false,
+ }
}
}
diff --git a/rialto/src/main.rs b/rialto/src/main.rs
index 0ecbe9d..d777b2d 100644
--- a/rialto/src/main.rs
+++ b/rialto/src/main.rs
@@ -33,6 +33,7 @@
use hyp::{get_mem_sharer, get_mmio_guard};
use libfdt::FdtError;
use log::{debug, error, info};
+use service_vm_comm::VmType;
use virtio_drivers::{
device::socket::{VsockAddr, VMADDR_CID_HOST},
transport::{pci::bus::PciRoot, DeviceType, Transport},
@@ -52,12 +53,16 @@
};
fn host_addr() -> VsockAddr {
- VsockAddr { cid: VMADDR_CID_HOST, port: service_vm_comm::host_port(is_protected_vm()) }
+ VsockAddr { cid: VMADDR_CID_HOST, port: vm_type().port() }
}
-fn is_protected_vm() -> bool {
+fn vm_type() -> VmType {
// Use MMIO support to determine whether the VM is protected.
- get_mmio_guard().is_some()
+ if get_mmio_guard().is_some() {
+ VmType::ProtectedVm
+ } else {
+ VmType::NonProtectedVm
+ }
}
fn new_page_table() -> Result<PageTable> {
diff --git a/rialto/tests/test.rs b/rialto/tests/test.rs
index e9bdab6..f7df217 100644
--- a/rialto/tests/test.rs
+++ b/rialto/tests/test.rs
@@ -24,7 +24,7 @@
};
use anyhow::{anyhow, bail, Context, Result};
use log::info;
-use service_vm_comm::{host_port, Request, Response};
+use service_vm_comm::{Request, Response, VmType};
use std::fs::File;
use std::io::{self, BufRead, BufReader, BufWriter, Write};
use std::os::unix::io::FromRawFd;
@@ -41,21 +41,15 @@
#[test]
fn boot_rialto_in_protected_vm_successfully() -> Result<()> {
- boot_rialto_successfully(
- SIGNED_RIALTO_PATH,
- true, // protected_vm
- )
+ boot_rialto_successfully(SIGNED_RIALTO_PATH, VmType::ProtectedVm)
}
#[test]
fn boot_rialto_in_unprotected_vm_successfully() -> Result<()> {
- boot_rialto_successfully(
- UNSIGNED_RIALTO_PATH,
- false, // protected_vm
- )
+ boot_rialto_successfully(UNSIGNED_RIALTO_PATH, VmType::NonProtectedVm)
}
-fn boot_rialto_successfully(rialto_path: &str, protected_vm: bool) -> Result<()> {
+fn boot_rialto_successfully(rialto_path: &str, vm_type: VmType) -> Result<()> {
android_logger::init_once(
android_logger::Config::default().with_tag("rialto").with_min_level(log::Level::Debug),
);
@@ -76,30 +70,31 @@
let console = android_log_fd()?;
let log = android_log_fd()?;
- let disks = if protected_vm {
- let instance_img = File::options()
- .create(true)
- .read(true)
- .write(true)
- .truncate(true)
- .open(INSTANCE_IMG_PATH)?;
- let instance_img = ParcelFileDescriptor::new(instance_img);
+ let disks = match vm_type {
+ VmType::ProtectedVm => {
+ let instance_img = File::options()
+ .create(true)
+ .read(true)
+ .write(true)
+ .truncate(true)
+ .open(INSTANCE_IMG_PATH)?;
+ let instance_img = ParcelFileDescriptor::new(instance_img);
- service
- .initializeWritablePartition(
- &instance_img,
- INSTANCE_IMG_SIZE,
- PartitionType::ANDROID_VM_INSTANCE,
- )
- .context("Failed to initialize instange.img")?;
- let writable_partitions = vec![Partition {
- label: "vm-instance".to_owned(),
- image: Some(instance_img),
- writable: true,
- }];
- vec![DiskImage { image: None, partitions: writable_partitions, writable: true }]
- } else {
- vec![]
+ service
+ .initializeWritablePartition(
+ &instance_img,
+ INSTANCE_IMG_SIZE,
+ PartitionType::ANDROID_VM_INSTANCE,
+ )
+ .context("Failed to initialize instange.img")?;
+ let writable_partitions = vec![Partition {
+ label: "vm-instance".to_owned(),
+ image: Some(instance_img),
+ writable: true,
+ }];
+ vec![DiskImage { image: None, partitions: writable_partitions, writable: true }]
+ }
+ VmType::NonProtectedVm => vec![],
};
let config = VirtualMachineConfig::RawConfig(VirtualMachineRawConfig {
@@ -109,7 +104,7 @@
params: None,
bootloader: Some(ParcelFileDescriptor::new(rialto)),
disks,
- protectedVm: protected_vm,
+ protectedVm: vm_type.is_protected(),
memoryMib: 300,
cpuTopology: CpuTopology::ONE_CPU,
platformVersion: "~1.0".to_string(),
@@ -126,8 +121,8 @@
)
.context("Failed to create VM")?;
- let port = host_port(protected_vm);
- let check_socket_handle = thread::spawn(move || try_check_socket_connection(port).unwrap());
+ let check_socket_handle =
+ thread::spawn(move || try_check_socket_connection(vm_type.port()).unwrap());
vm.start().context("Failed to start VM")?;