[refactoring] Add VmType enum to service_vm_comm

to improve the readability of the code.

Bug: 299089107
Test: atest rialto_test
Change-Id: Ic8fec11701e90815e83c01b1914cb363604bff50
diff --git a/libs/service_vm_comm/src/lib.rs b/libs/service_vm_comm/src/lib.rs
index 6f03209..3b53b63 100644
--- a/libs/service_vm_comm/src/lib.rs
+++ b/libs/service_vm_comm/src/lib.rs
@@ -23,4 +23,4 @@
 mod vsock;
 
 pub use message::{EcdsaP256KeyPair, GenerateCertificateRequestParams, Request, Response};
-pub use vsock::host_port;
+pub use vsock::VmType;
diff --git a/libs/service_vm_comm/src/vsock.rs b/libs/service_vm_comm/src/vsock.rs
index fd6f088..aa7166d 100644
--- a/libs/service_vm_comm/src/vsock.rs
+++ b/libs/service_vm_comm/src/vsock.rs
@@ -14,14 +14,34 @@
 
 //! Vsock setup shared between the host and the service VM.
 
-/// Returns the host port number for the given VM protection state.
-pub fn host_port(is_protected_vm: bool) -> u32 {
-    const PROTECTED_VM_PORT: u32 = 5679;
-    const NON_PROTECTED_VM_PORT: u32 = 5680;
+const PROTECTED_VM_PORT: u32 = 5679;
+const NON_PROTECTED_VM_PORT: u32 = 5680;
 
-    if is_protected_vm {
-        PROTECTED_VM_PORT
-    } else {
-        NON_PROTECTED_VM_PORT
+/// VM Type.
+#[derive(Clone, Copy, Debug)]
+pub enum VmType {
+    /// Protected VM.
+    ProtectedVm,
+
+    /// NonProtectev VM.
+    NonProtectedVm,
+}
+
+impl VmType {
+    /// Returns the port number used for the vsock communication between
+    /// the host and the service VM.
+    pub fn port(&self) -> u32 {
+        match self {
+            Self::ProtectedVm => PROTECTED_VM_PORT,
+            Self::NonProtectedVm => NON_PROTECTED_VM_PORT,
+        }
+    }
+
+    /// Returns whether it is a protected VM.
+    pub fn is_protected(&self) -> bool {
+        match self {
+            Self::ProtectedVm => true,
+            Self::NonProtectedVm => false,
+        }
     }
 }
diff --git a/rialto/src/main.rs b/rialto/src/main.rs
index 0ecbe9d..d777b2d 100644
--- a/rialto/src/main.rs
+++ b/rialto/src/main.rs
@@ -33,6 +33,7 @@
 use hyp::{get_mem_sharer, get_mmio_guard};
 use libfdt::FdtError;
 use log::{debug, error, info};
+use service_vm_comm::VmType;
 use virtio_drivers::{
     device::socket::{VsockAddr, VMADDR_CID_HOST},
     transport::{pci::bus::PciRoot, DeviceType, Transport},
@@ -52,12 +53,16 @@
 };
 
 fn host_addr() -> VsockAddr {
-    VsockAddr { cid: VMADDR_CID_HOST, port: service_vm_comm::host_port(is_protected_vm()) }
+    VsockAddr { cid: VMADDR_CID_HOST, port: vm_type().port() }
 }
 
-fn is_protected_vm() -> bool {
+fn vm_type() -> VmType {
     // Use MMIO support to determine whether the VM is protected.
-    get_mmio_guard().is_some()
+    if get_mmio_guard().is_some() {
+        VmType::ProtectedVm
+    } else {
+        VmType::NonProtectedVm
+    }
 }
 
 fn new_page_table() -> Result<PageTable> {
diff --git a/rialto/tests/test.rs b/rialto/tests/test.rs
index e9bdab6..f7df217 100644
--- a/rialto/tests/test.rs
+++ b/rialto/tests/test.rs
@@ -24,7 +24,7 @@
 };
 use anyhow::{anyhow, bail, Context, Result};
 use log::info;
-use service_vm_comm::{host_port, Request, Response};
+use service_vm_comm::{Request, Response, VmType};
 use std::fs::File;
 use std::io::{self, BufRead, BufReader, BufWriter, Write};
 use std::os::unix::io::FromRawFd;
@@ -41,21 +41,15 @@
 
 #[test]
 fn boot_rialto_in_protected_vm_successfully() -> Result<()> {
-    boot_rialto_successfully(
-        SIGNED_RIALTO_PATH,
-        true, // protected_vm
-    )
+    boot_rialto_successfully(SIGNED_RIALTO_PATH, VmType::ProtectedVm)
 }
 
 #[test]
 fn boot_rialto_in_unprotected_vm_successfully() -> Result<()> {
-    boot_rialto_successfully(
-        UNSIGNED_RIALTO_PATH,
-        false, // protected_vm
-    )
+    boot_rialto_successfully(UNSIGNED_RIALTO_PATH, VmType::NonProtectedVm)
 }
 
-fn boot_rialto_successfully(rialto_path: &str, protected_vm: bool) -> Result<()> {
+fn boot_rialto_successfully(rialto_path: &str, vm_type: VmType) -> Result<()> {
     android_logger::init_once(
         android_logger::Config::default().with_tag("rialto").with_min_level(log::Level::Debug),
     );
@@ -76,30 +70,31 @@
     let console = android_log_fd()?;
     let log = android_log_fd()?;
 
-    let disks = if protected_vm {
-        let instance_img = File::options()
-            .create(true)
-            .read(true)
-            .write(true)
-            .truncate(true)
-            .open(INSTANCE_IMG_PATH)?;
-        let instance_img = ParcelFileDescriptor::new(instance_img);
+    let disks = match vm_type {
+        VmType::ProtectedVm => {
+            let instance_img = File::options()
+                .create(true)
+                .read(true)
+                .write(true)
+                .truncate(true)
+                .open(INSTANCE_IMG_PATH)?;
+            let instance_img = ParcelFileDescriptor::new(instance_img);
 
-        service
-            .initializeWritablePartition(
-                &instance_img,
-                INSTANCE_IMG_SIZE,
-                PartitionType::ANDROID_VM_INSTANCE,
-            )
-            .context("Failed to initialize instange.img")?;
-        let writable_partitions = vec![Partition {
-            label: "vm-instance".to_owned(),
-            image: Some(instance_img),
-            writable: true,
-        }];
-        vec![DiskImage { image: None, partitions: writable_partitions, writable: true }]
-    } else {
-        vec![]
+            service
+                .initializeWritablePartition(
+                    &instance_img,
+                    INSTANCE_IMG_SIZE,
+                    PartitionType::ANDROID_VM_INSTANCE,
+                )
+                .context("Failed to initialize instange.img")?;
+            let writable_partitions = vec![Partition {
+                label: "vm-instance".to_owned(),
+                image: Some(instance_img),
+                writable: true,
+            }];
+            vec![DiskImage { image: None, partitions: writable_partitions, writable: true }]
+        }
+        VmType::NonProtectedVm => vec![],
     };
 
     let config = VirtualMachineConfig::RawConfig(VirtualMachineRawConfig {
@@ -109,7 +104,7 @@
         params: None,
         bootloader: Some(ParcelFileDescriptor::new(rialto)),
         disks,
-        protectedVm: protected_vm,
+        protectedVm: vm_type.is_protected(),
         memoryMib: 300,
         cpuTopology: CpuTopology::ONE_CPU,
         platformVersion: "~1.0".to_string(),
@@ -126,8 +121,8 @@
     )
     .context("Failed to create VM")?;
 
-    let port = host_port(protected_vm);
-    let check_socket_handle = thread::spawn(move || try_check_socket_connection(port).unwrap());
+    let check_socket_handle =
+        thread::spawn(move || try_check_socket_connection(vm_type.port()).unwrap());
 
     vm.start().context("Failed to start VM")?;