Merge "virtmgr: add CallingPartition enum" into main
diff --git a/android/virtmgr/src/aidl.rs b/android/virtmgr/src/aidl.rs
index 4e17daa..91a05e3 100644
--- a/android/virtmgr/src/aidl.rs
+++ b/android/virtmgr/src/aidl.rs
@@ -420,40 +420,83 @@
     }
 }
 
-fn find_partition(path: Option<&Path>) -> binder::Result<String> {
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+enum CallingPartition {
+    Odm,
+    Product,
+    System,
+    SystemExt,
+    Vendor,
+    Unknown, // path did not resolve to any of the partitions above
+}
+
+impl CallingPartition {
+    fn as_str(&self) -> &'static str {
+        match self {
+            CallingPartition::Odm => "odm",
+            CallingPartition::Product => "product",
+            CallingPartition::System => "system",
+            CallingPartition::SystemExt => "system_ext",
+            CallingPartition::Vendor => "vendor",
+            CallingPartition::Unknown => "[unknown]",
+        }
+    }
+}
+
+impl std::fmt::Display for CallingPartition {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.as_str())
+    }
+}
+
+fn find_partition(path: Option<&Path>) -> binder::Result<CallingPartition> {
     let Some(path) = path else {
-        return Ok("system".to_owned());
+        return Ok(CallingPartition::System);
     };
     if path.starts_with("/system/system_ext/") {
-        return Ok("system_ext".to_owned());
-    } else if path.starts_with("/system/product/") {
-        return Ok("product".to_owned());
+        return Ok(CallingPartition::SystemExt);
     }
-    let mut components = path.components();
-    let Some(std::path::Component::Normal(partition)) = components.nth(1) else {
-        return Err(anyhow!("Can't find partition in '{}'", path.display()))
-            .or_service_specific_exception(-1);
-    };
-
-    // If path is under /apex, find a partition of the preinstalled .apex path
-    if partition == "apex" {
-        let Some(std::path::Component::Normal(apex_name)) = components.next() else {
-            return Err(anyhow!("Can't find apex name for '{}'", path.display()))
+    if path.starts_with("/system/product/") {
+        return Ok(CallingPartition::Product);
+    }
+    let partition = {
+        let mut components = path.components();
+        let Some(std::path::Component::Normal(partition)) = components.nth(1) else {
+            return Err(anyhow!("Can't find partition in '{}'", path.display()))
                 .or_service_specific_exception(-1);
         };
-        let apex_info_list = ApexInfoList::load()
-            .context("Failed to get apex info list")
-            .or_service_specific_exception(-1)?;
-        return apex_info_list
-            .list
-            .iter()
-            .find(|apex_info| apex_info.name.as_str() == apex_name)
-            .map(|apex_info| apex_info.partition.to_lowercase())
-            .ok_or(anyhow!("Can't find apex info for {apex_name:?}"))
-            .or_service_specific_exception(-1);
-    }
 
-    Ok(partition.to_string_lossy().into_owned())
+        // If path is under /apex, find a partition of the preinstalled .apex path
+        if partition == "apex" {
+            let Some(std::path::Component::Normal(apex_name)) = components.next() else {
+                return Err(anyhow!("Can't find apex name for '{}'", path.display()))
+                    .or_service_specific_exception(-1);
+            };
+            let apex_info_list = ApexInfoList::load()
+                .context("Failed to get apex info list")
+                .or_service_specific_exception(-1)?;
+            apex_info_list
+                .list
+                .iter()
+                .find(|apex_info| apex_info.name.as_str() == apex_name)
+                .map(|apex_info| apex_info.partition.to_lowercase())
+                .ok_or_else(|| anyhow!("Can't find apex info for {apex_name:?}"))
+                .or_service_specific_exception(-1)?
+        } else {
+            partition.to_string_lossy().into_owned()
+        }
+    };
+    Ok(match partition.as_str() {
+        "odm" => CallingPartition::Odm,
+        "product" => CallingPartition::Product,
+        "system" => CallingPartition::System,
+        "system_ext" => CallingPartition::SystemExt,
+        "vendor" => CallingPartition::Vendor,
+        _ => {
+            warn!("unknown partition '{partition}' for '{}'", path.display());
+            CallingPartition::Unknown
+        }
+    })
 }
 
 impl VirtualizationService {
@@ -471,7 +514,7 @@
             VirtualMachineConfig::AppConfig(config) => &config.name,
         };
         let calling_partition = find_partition(calling_exe_path)?;
-        let early_vm = find_early_vm_for_partition(&calling_partition, name)
+        let early_vm = find_early_vm_for_partition(calling_partition, name)
             .or_service_specific_exception(-1)?;
         let calling_exe_path = match calling_exe_path {
             Some(path) => path,
@@ -686,7 +729,7 @@
         // Check if files for payloads and bases are NOT coming from /vendor and /odm, as they may
         // have unstable interfaces.
         // TODO(b/316431494): remove once Treble interfaces are stabilized.
-        check_partitions_for_files(config, &find_partition(CALLING_EXE_PATH.as_deref())?)
+        check_partitions_for_files(config, find_partition(CALLING_EXE_PATH.as_deref())?)
             .or_service_specific_exception(-1)?;
 
         let zero_filler_path = temporary_directory.join("zero.img");
@@ -1328,7 +1371,10 @@
     Ok(vm_config)
 }
 
-fn check_partition_for_file(fd: &ParcelFileDescriptor, calling_partition: &str) -> Result<()> {
+fn check_partition_for_file(
+    fd: &ParcelFileDescriptor,
+    calling_partition: CallingPartition, // Copy enum, taken by value instead of &str
+) -> Result<()> {
     let path = format!("/proc/self/fd/{}", fd.as_raw_fd());
     let link = fs::read_link(&path).context(format!("can't read_link {path}"))?;
 
@@ -1339,7 +1385,8 @@
     }
 
     let is_fd_vendor = link.starts_with("/vendor") || link.starts_with("/odm");
-    let is_caller_vendor = calling_partition == "vendor" || calling_partition == "odm";
+    let is_caller_vendor =
+        matches!(calling_partition, CallingPartition::Vendor | CallingPartition::Odm);
 
     if is_fd_vendor != is_caller_vendor {
         bail!("{} can't be used for VM client in {calling_partition}", link.display());
@@ -1350,7 +1397,7 @@
 
 fn check_partitions_for_files(
     config: &VirtualMachineRawConfig,
-    calling_partition: &str,
+    calling_partition: CallingPartition,
 ) -> Result<()> {
     config
         .disks
@@ -2251,15 +2298,15 @@
     early_vm: Vec<EarlyVm>,
 }
 
-static EARLY_VMS_CACHE: LazyLock<Mutex<HashMap<String, Vec<EarlyVm>>>> =
+static EARLY_VMS_CACHE: LazyLock<Mutex<HashMap<CallingPartition, Vec<EarlyVm>>>> =
     LazyLock::new(|| Mutex::new(HashMap::new()));
 
-fn range_for_partition(partition: &str) -> Result<Range<Cid>> {
+fn range_for_partition(partition: CallingPartition) -> Result<Range<Cid>> {
     match partition {
-        "system" => Ok(100..200),
-        "system_ext" | "product" => Ok(200..300),
-        "vendor" | "odm" => Ok(300..400),
-        _ => Err(anyhow!("Early VMs are not supported for {partition}")),
+        CallingPartition::System => Ok(100..200),
+        CallingPartition::SystemExt | CallingPartition::Product => Ok(200..300),
+        CallingPartition::Vendor | CallingPartition::Odm => Ok(300..400),
+        CallingPartition::Unknown => Err(anyhow!("Early VMs are not supported for {partition}")),
     }
 }
 
@@ -2294,10 +2341,10 @@
     Ok(())
 }
 
-fn get_early_vms_in_partition(partition: &str) -> Result<Vec<EarlyVm>> {
+fn get_early_vms_in_partition(partition: CallingPartition) -> Result<Vec<EarlyVm>> {
     let mut cache = EARLY_VMS_CACHE.lock().unwrap();
 
-    if let Some(result) = cache.get(partition) {
+    if let Some(result) = cache.get(&partition) {
         return Ok(result.clone());
     }
 
@@ -2310,10 +2357,10 @@
         }
     }
 
     validate_cid_range(&early_vms, &range_for_partition(partition)?)
         .with_context(|| format!("CID validation for {partition} failed"))?;
 
-    cache.insert(partition.to_owned(), early_vms.clone());
+    cache.insert(partition, early_vms.clone());
 
     Ok(early_vms)
 }
@@ -2336,7 +2383,7 @@
     found_vm.ok_or_else(|| anyhow!("Can't find a VM named '{name}'"))
 }
 
-fn find_early_vm_for_partition(partition: &str, name: &str) -> Result<EarlyVm> {
+fn find_early_vm_for_partition(partition: CallingPartition, name: &str) -> Result<EarlyVm> {
     let early_vms = get_early_vms_in_partition(partition)
         .with_context(|| format!("Failed to get early VMs from {partition}"))?;
 
@@ -2659,7 +2706,7 @@
     fn test_symlink_to_system_ext_supported() -> Result<()> {
         let link_path = Path::new("/system/system_ext/file");
         let partition = find_partition(Some(link_path)).unwrap();
-        assert_eq!("system_ext", partition);
+        assert_eq!(CallingPartition::SystemExt, partition); // via /system/system_ext/ prefix
         Ok(())
     }
 
@@ -2667,7 +2714,7 @@
     fn test_symlink_to_product_supported() -> Result<()> {
         let link_path = Path::new("/system/product/file");
         let partition = find_partition(Some(link_path)).unwrap();
-        assert_eq!("product", partition);
+        assert_eq!(CallingPartition::Product, partition); // via the /system/product/ prefix
         Ok(())
     }