Merge "[attestation] Add API to check AVF RKP Hal presence in VM Attestation" into main
diff --git a/libs/vmconfig/src/lib.rs b/libs/vmconfig/src/lib.rs
index 50f3c8e..907e0d3 100644
--- a/libs/vmconfig/src/lib.rs
+++ b/libs/vmconfig/src/lib.rs
@@ -15,6 +15,7 @@
 //! Struct for VM configuration with JSON (de)serialization and AIDL parcelables
 
 use android_system_virtualizationservice::{
+    aidl::android::system::virtualizationservice::CpuTopology::CpuTopology,
     aidl::android::system::virtualizationservice::DiskImage::DiskImage as AidlDiskImage,
     aidl::android::system::virtualizationservice::Partition::Partition as AidlPartition,
     aidl::android::system::virtualizationservice::VirtualMachineRawConfig::VirtualMachineRawConfig,
@@ -54,6 +55,8 @@
     /// The amount of RAM to give the VM, in MiB.
     #[serde(default)]
     pub memory_mib: Option<NonZeroU32>,
+    /// The CPU topology: either "one_cpu" (the default) or "match_host".
+    pub cpu_topology: Option<String>,
     /// Version or range of versions of the virtual platform that this config is compatible with.
     /// The format follows SemVer (https://semver.org).
     pub platform_version: VersionReq,
@@ -96,7 +99,12 @@
         } else {
             0
         };
-
+        let cpu_topology = match self.cpu_topology.as_deref() {
+            None => CpuTopology::ONE_CPU,
+            Some("one_cpu") => CpuTopology::ONE_CPU,
+            Some("match_host") => CpuTopology::MATCH_HOST,
+            Some(cpu_topology) => bail!("Invalid cpu topology {}", cpu_topology),
+        };
         Ok(VirtualMachineRawConfig {
             kernel: maybe_open_parcel_file(&self.kernel, false)?,
             initrd: maybe_open_parcel_file(&self.initrd, false)?,
@@ -105,6 +113,7 @@
             disks: self.disks.iter().map(DiskImage::to_parcelable).collect::<Result<_, Error>>()?,
             protectedVm: self.protected,
             memoryMib: memory_mib,
+            cpuTopology: cpu_topology,
             platformVersion: self.platform_version.to_string(),
             devices: self
                 .devices
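
For readers skimming the hunk above: the new optional `cpu_topology` JSON field maps onto the AIDL `CpuTopology` enum, defaulting to a single vCPU when absent. Below is a minimal, self-contained sketch of that mapping; the plain enum and the `parse_cpu_topology` helper are illustrative stand-ins, not part of this change (the real code returns the AIDL-generated parcelable).

```rust
/// Stand-in for the AIDL-generated `CpuTopology` parcelable (sketch only).
#[derive(Debug, PartialEq)]
enum CpuTopology {
    OneCpu,
    MatchHost,
}

/// Hypothetical helper mirroring the `match` added in `to_parcelable` above.
fn parse_cpu_topology(value: Option<&str>) -> Result<CpuTopology, String> {
    match value {
        // Omitting the field keeps the previous behavior: a single vCPU.
        None | Some("one_cpu") => Ok(CpuTopology::OneCpu),
        Some("match_host") => Ok(CpuTopology::MatchHost),
        Some(other) => Err(format!("Invalid cpu topology {other}")),
    }
}

fn main() {
    assert_eq!(parse_cpu_topology(None), Ok(CpuTopology::OneCpu));
    assert_eq!(parse_cpu_topology(Some("match_host")), Ok(CpuTopology::MatchHost));
    assert!(parse_cpu_topology(Some("two_cpus")).is_err());
}
```
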
diff --git a/service_vm/manager/src/lib.rs b/service_vm/manager/src/lib.rs
index 8dedec5..3f2550c 100644
--- a/service_vm/manager/src/lib.rs
+++ b/service_vm/manager/src/lib.rs
@@ -32,7 +32,7 @@
 use std::io::{self, BufRead, BufReader, BufWriter, Write};
 use std::os::unix::io::FromRawFd;
 use std::path::{Path, PathBuf};
-use std::sync::{Condvar, Mutex, MutexGuard};
+use std::sync::{Condvar, Mutex};
 use std::thread;
 use std::time::Duration;
 use vmclient::{DeathReason, VmInstance};
@@ -48,40 +48,78 @@
 const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
 
 lazy_static! {
-    static ref SERVICE_VM_STATE: State = State::default();
+    static ref PENDING_REQUESTS: AtomicCounter = AtomicCounter::default();
+    static ref SERVICE_VM: Mutex<Option<ServiceVm>> = Mutex::new(None);
+    static ref SERVICE_VM_SHUTDOWN: Condvar = Condvar::new();
 }
 
-/// The running state of the Service VM.
+/// Mutex-guarded counter paired with a condition variable, used to wait for the
+/// counter to become positive within a timeout.
 #[derive(Debug, Default)]
-struct State {
-    is_running: Mutex<bool>,
-    stopped: Condvar,
+struct AtomicCounter {
+    num: Mutex<usize>,
+    num_increased: Condvar,
 }
 
-impl State {
-    fn wait_until_no_service_vm_running(&self) -> Result<MutexGuard<'_, bool>> {
-        // The real timeout can be longer than 10 seconds since the time to acquire
-        // is_running mutex is not counted in the 10 seconds.
-        let (guard, wait_result) = self
-            .stopped
-            .wait_timeout_while(
-                self.is_running.lock().unwrap(),
-                Duration::from_secs(10),
-                |&mut is_running| is_running,
-            )
+impl AtomicCounter {
+    /// Returns true if the counter is positive, or becomes positive within the given timeout.
+    fn is_positive_within_timeout(&self, timeout: Duration) -> bool {
+        let (guard, _wait_result) = self
+            .num_increased
+            .wait_timeout_while(self.num.lock().unwrap(), timeout, |&mut x| x == 0)
             .unwrap();
-        ensure!(
-            !wait_result.timed_out(),
-            "Timed out while waiting for the running service VM to stop."
-        );
-        Ok(guard)
+        *guard > 0
     }
 
-    fn notify_service_vm_shutdown(&self) {
-        let mut is_running_guard = self.is_running.lock().unwrap();
-        *is_running_guard = false;
-        self.stopped.notify_one();
+    fn increment(&self) {
+        let mut num = self.num.lock().unwrap();
+        *num = num.checked_add(1).unwrap();
+        self.num_increased.notify_all();
     }
+
+    fn decrement(&self) {
+        let mut num = self.num.lock().unwrap();
+        *num = num.checked_sub(1).unwrap();
+    }
+}
+
+/// Processes the request in the service VM, starting the VM lazily if it is not already running.
+pub fn process_request(request: Request) -> Result<Response> {
+    PENDING_REQUESTS.increment();
+    let result = process_request_in_service_vm(request);
+    PENDING_REQUESTS.decrement();
+    thread::spawn(stop_service_vm_if_idle);
+    result
+}
+
+fn process_request_in_service_vm(request: Request) -> Result<Response> {
+    let mut service_vm = SERVICE_VM.lock().unwrap();
+    if service_vm.is_none() {
+        *service_vm = Some(ServiceVm::start()?);
+    }
+    service_vm.as_mut().unwrap().process_request(request)
+}
+
+fn stop_service_vm_if_idle() {
+    if PENDING_REQUESTS.is_positive_within_timeout(Duration::from_secs(1)) {
+        info!("Service VM has pending requests, keeping it running.");
+    } else {
+        info!("Service VM is idle, shutting it down.");
+        *SERVICE_VM.lock().unwrap() = None;
+        SERVICE_VM_SHUTDOWN.notify_all();
+    }
+}
+
+/// Waits until the service VM shuts down.
+/// This is only used in tests and in the test-only attestation key path.
+pub fn wait_until_service_vm_shuts_down() -> Result<()> {
+    const WAIT_FOR_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
+
+    let (_guard, wait_result) = SERVICE_VM_SHUTDOWN
+        .wait_timeout_while(SERVICE_VM.lock().unwrap(), WAIT_FOR_SHUTDOWN_TIMEOUT, |x| x.is_some())
+        .unwrap();
+    ensure!(!wait_result.timed_out(), "Service VM didn't shut down within the timeout");
+    Ok(())
 }
 
 /// Service VM.
@@ -94,17 +132,12 @@
 impl ServiceVm {
     /// Starts the service VM and returns its instance.
     /// The same instance image is used for different VMs.
-    /// At any given time,  only one service should be running. If a service VM is
-    /// already running, this function will start the service VM once the running one
-    /// shuts down.
+    /// TODO(b/27593612): Remove instance image usage for Service VM.
     pub fn start() -> Result<Self> {
-        let mut is_running_guard = SERVICE_VM_STATE.wait_until_no_service_vm_running()?;
-
         let instance_img_path = Path::new(VIRT_DATA_DIR).join(INSTANCE_IMG_NAME);
         let vm = protected_vm_instance(instance_img_path)?;
 
         let vm = Self::start_vm(vm, VmType::ProtectedVm)?;
-        *is_running_guard = true;
         Ok(vm)
     }
 
@@ -174,7 +207,6 @@
             Ok(reason) => info!("Exit the service VM successfully: {reason:?}"),
             Err(e) => warn!("Service VM shutdown request failed '{e:?}', killing it."),
         }
-        SERVICE_VM_STATE.notify_service_vm_shutdown();
     }
 }
 
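The lifecycle introduced above works as follows: every request increments `PENDING_REQUESTS`, the VM is started lazily under the `SERVICE_VM` lock, and after each request a spawned thread waits up to one second for new work before shutting the VM down. Here is a minimal, self-contained sketch of that pattern; a `String` stands in for the VM and `Arc`s replace the `lazy_static` globals, so nothing below is the production API.

```rust
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::Duration;

/// Counter guarded by a mutex, with a condvar to wait for it to become positive.
#[derive(Default)]
struct AtomicCounter {
    num: Mutex<usize>,
    num_increased: Condvar,
}

impl AtomicCounter {
    fn is_positive_within_timeout(&self, timeout: Duration) -> bool {
        let (guard, _wait_result) = self
            .num_increased
            .wait_timeout_while(self.num.lock().unwrap(), timeout, |&mut n| n == 0)
            .unwrap();
        *guard > 0
    }

    fn increment(&self) {
        *self.num.lock().unwrap() += 1;
        self.num_increased.notify_all();
    }

    fn decrement(&self) {
        *self.num.lock().unwrap() -= 1;
    }
}

fn main() {
    let pending = Arc::new(AtomicCounter::default());
    let vm: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));

    // One simulated request: lazily "start" the VM, do the work, then mark done.
    pending.increment();
    vm.lock().unwrap().get_or_insert_with(|| "service vm".to_string());
    pending.decrement();

    // The idle watcher waits briefly for new work; none arrives, so it tears down.
    let (pending2, vm2) = (Arc::clone(&pending), Arc::clone(&vm));
    thread::spawn(move || {
        if !pending2.is_positive_within_timeout(Duration::from_millis(100)) {
            *vm2.lock().unwrap() = None;
        }
    })
    .join()
    .unwrap();

    assert!(vm.lock().unwrap().is_none());
}
```

The grace period is what keeps back-to-back requests from paying the VM boot cost each time: a request arriving while the watcher sits in `is_positive_within_timeout` bumps the counter, the wait returns true, and the VM is left running.
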
diff --git a/virtualizationservice/src/aidl.rs b/virtualizationservice/src/aidl.rs
index 5ddb8c3..05f3cf6 100644
--- a/virtualizationservice/src/aidl.rs
+++ b/virtualizationservice/src/aidl.rs
@@ -269,6 +269,13 @@
             .context("Failed to generate ECDSA P-256 key pair for testing")
             .with_log()
             .or_service_specific_exception(-1)?;
+        // Wait until the service VM shuts down, so that a fresh service VM will be started
+        // when the key generated in the current session is later used for attestation.
+        // This ensures that different service VM sessions use the same KEK for the key blob.
+        service_vm_manager::wait_until_service_vm_shuts_down()
+            .context("Failed to wait until the service VM shuts down")
+            .with_log()
+            .or_service_specific_exception(-1)?;
         match res {
             Response::GenerateEcdsaP256KeyPair(key_pair) => {
                 FAKE_PROVISIONED_KEY_BLOB_FOR_TESTING
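
The wait added above reuses the same `Condvar::wait_timeout_while` idiom as the manager: block until the `Mutex<Option<ServiceVm>>` holds `None`, or bail out on timeout. A self-contained sketch of just that primitive follows, with a `u32` slot standing in for `SERVICE_VM` and only standard-library APIs assumed (const `Mutex`/`Condvar` initializers, stable since Rust 1.63).

```rust
use std::sync::{Condvar, Mutex};
use std::thread;
use std::time::Duration;

// Stand-ins for SERVICE_VM and SERVICE_VM_SHUTDOWN.
static SLOT: Mutex<Option<u32>> = Mutex::new(Some(7));
static SHUTDOWN: Condvar = Condvar::new();

fn main() {
    // Another thread clears the slot, playing the role of stop_service_vm_if_idle.
    thread::spawn(|| {
        thread::sleep(Duration::from_millis(50));
        *SLOT.lock().unwrap() = None;
        SHUTDOWN.notify_all();
    });

    // Block until the slot is empty, or give up after the timeout.
    let (guard, wait_result) = SHUTDOWN
        .wait_timeout_while(SLOT.lock().unwrap(), Duration::from_secs(5), |x| x.is_some())
        .unwrap();
    assert!(!wait_result.timed_out());
    assert!(guard.is_none());
}
```
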
diff --git a/virtualizationservice/src/rkpvm.rs b/virtualizationservice/src/rkpvm.rs
index 67ba740..6898921 100644
--- a/virtualizationservice/src/rkpvm.rs
+++ b/virtualizationservice/src/rkpvm.rs
@@ -21,28 +21,25 @@
 use service_vm_comm::{
     ClientVmAttestationParams, GenerateCertificateRequestParams, Request, Response,
 };
-use service_vm_manager::ServiceVm;
+use service_vm_manager::process_request;
 
 pub(crate) fn request_attestation(
     csr: Vec<u8>,
     remotely_provisioned_key_blob: Vec<u8>,
     remotely_provisioned_cert: Vec<u8>,
 ) -> Result<Vec<u8>> {
-    let mut vm = ServiceVm::start()?;
-
     let params =
         ClientVmAttestationParams { csr, remotely_provisioned_key_blob, remotely_provisioned_cert };
     let request = Request::RequestClientVmAttestation(params);
-    match vm.process_request(request).context("Failed to process request")? {
+    match process_request(request).context("Failed to process request")? {
         Response::RequestClientVmAttestation(cert) => Ok(cert),
         other => bail!("Incorrect response type {other:?}"),
     }
 }
 
 pub(crate) fn generate_ecdsa_p256_key_pair() -> Result<Response> {
-    let mut vm = ServiceVm::start()?;
     let request = Request::GenerateEcdsaP256KeyPair;
-    vm.process_request(request).context("Failed to process request")
+    process_request(request).context("Failed to process request")
 }
 
 pub(crate) fn generate_certificate_request(
@@ -55,6 +52,5 @@
     };
     let request = Request::GenerateCertificateRequest(params);
 
-    let mut vm = ServiceVm::start()?;
-    vm.process_request(request).context("Failed to process request")
+    process_request(request).context("Failed to process request")
 }