Separate creation and starting of VMs.

Bug: 199127239
Test: atest VirtualizationTestCases
Change-Id: I2cb436c2acd6b4830aab0a044ed03fb688459fe0
diff --git a/compos/common/compos_client.rs b/compos/common/compos_client.rs
index dd8e54f..22304f1 100644
--- a/compos/common/compos_client.rs
+++ b/compos/common/compos_client.rs
@@ -82,7 +82,7 @@
         )
         .context("Failed to find VirtualizationService")?;
 
-        let vm = service.startVm(&config, Some(&log_fd)).context("Failed to start VM")?;
+        let vm = service.createVm(&config, Some(&log_fd)).context("Failed to create VM")?;
         let vm_state = Arc::new(VmStateMonitor::default());
 
         let vm_state_clone = Arc::clone(&vm_state);
@@ -98,6 +98,8 @@
         );
         vm.registerCallback(&callback)?;
 
+        vm.start()?;
+
         let cid = vm_state.wait_for_start()?;
 
         // TODO: Use onPayloadReady to avoid this
diff --git a/compos/compos_key_cmd/compos_key_cmd.cpp b/compos/compos_key_cmd/compos_key_cmd.cpp
index 874a208..eb11d92 100644
--- a/compos/compos_key_cmd/compos_key_cmd.cpp
+++ b/compos/compos_key_cmd/compos_key_cmd.cpp
@@ -213,7 +213,7 @@
         appConfig.memoryMib = 0; // Use default
 
         LOG(INFO) << "Starting VM";
-        auto status = service->startVm(config, logFd, &mVm);
+        auto status = service->createVm(config, logFd, &mVm);
         if (!status.isOk()) {
             return Error() << status.getDescription();
         }
@@ -224,7 +224,7 @@
             return Error() << status.getDescription();
         }
 
-        LOG(INFO) << "Started VM with cid = " << cid;
+        LOG(INFO) << "Created VM with CID = " << cid;
 
         // We need to use this rather than std::make_shared to make sure the
         // embedded weak_ptr is initialised.
@@ -235,6 +235,12 @@
             return Error() << status.getDescription();
         }
 
+        status = mVm->start();
+        if (!status.isOk()) {
+            return Error() << status.getDescription();
+        }
+        LOG(INFO) << "Started VM";
+
         if (!mCallback->waitForStarted()) {
             return Error() << "VM Payload failed to start";
         }
diff --git a/javalib/src/android/system/virtualmachine/VirtualMachine.java b/javalib/src/android/system/virtualmachine/VirtualMachine.java
index 1a752e3..2da7ecb 100644
--- a/javalib/src/android/system/virtualmachine/VirtualMachine.java
+++ b/javalib/src/android/system/virtualmachine/VirtualMachine.java
@@ -236,6 +236,8 @@
         try {
             if (mVirtualMachine != null) {
                 switch (mVirtualMachine.getState()) {
+                    case VirtualMachineState.NOT_STARTED:
+                        return Status.STOPPED;
                     case VirtualMachineState.STARTING:
                     case VirtualMachineState.STARTED:
                     case VirtualMachineState.READY:
@@ -308,7 +310,7 @@
             android.system.virtualizationservice.VirtualMachineConfig vmConfigParcel =
                     android.system.virtualizationservice.VirtualMachineConfig.appConfig(appConfig);
 
-            mVirtualMachine = service.startVm(vmConfigParcel, mConsoleWriter);
+            mVirtualMachine = service.createVm(vmConfigParcel, mConsoleWriter);
             mVirtualMachine.registerCallback(
                     new IVirtualMachineCallback.Stub() {
                         @Override
@@ -359,7 +361,7 @@
                                 }
                             },
                             0);
-
+            mVirtualMachine.start();
         } catch (IOException e) {
             throw new VirtualMachineException(e);
         } catch (RemoteException e) {
diff --git a/tests/vsock_test.cc b/tests/vsock_test.cc
index a594e6d..480d05a 100644
--- a/tests/vsock_test.cc
+++ b/tests/vsock_test.cc
@@ -85,14 +85,17 @@
 
     VirtualMachineConfig config(std::move(raw_config));
     sp<IVirtualMachine> vm;
-    status = virtualization_service->startVm(config, std::nullopt, &vm);
-    ASSERT_TRUE(status.isOk()) << "Error starting VM: " << status;
+    status = virtualization_service->createVm(config, std::nullopt, &vm);
+    ASSERT_TRUE(status.isOk()) << "Error creating VM: " << status;
 
     int32_t cid;
     status = vm->getCid(&cid);
     ASSERT_TRUE(status.isOk()) << "Error getting CID: " << status;
     LOG(INFO) << "VM starting with CID " << cid;
 
+    status = vm->start();
+    ASSERT_TRUE(status.isOk()) << "Error starting VM: " << status;
+
     LOG(INFO) << "Accepting connection...";
     struct sockaddr_vm client_sa;
     socklen_t client_sa_len = sizeof(client_sa);
diff --git a/virtualizationservice/aidl/android/system/virtualizationservice/IVirtualMachine.aidl b/virtualizationservice/aidl/android/system/virtualizationservice/IVirtualMachine.aidl
index 3c89cd7..6562159 100644
--- a/virtualizationservice/aidl/android/system/virtualizationservice/IVirtualMachine.aidl
+++ b/virtualizationservice/aidl/android/system/virtualizationservice/IVirtualMachine.aidl
@@ -34,6 +34,9 @@
      */
     void registerCallback(IVirtualMachineCallback callback);
 
+    /** Starts running the VM. */
+    void start();
+
     /** Open a vsock connection to the CID of the VM on the given port. */
     ParcelFileDescriptor connectVsock(int port);
 }
diff --git a/virtualizationservice/aidl/android/system/virtualizationservice/IVirtualizationService.aidl b/virtualizationservice/aidl/android/system/virtualizationservice/IVirtualizationService.aidl
index 7c4b897..8be7331 100644
--- a/virtualizationservice/aidl/android/system/virtualizationservice/IVirtualizationService.aidl
+++ b/virtualizationservice/aidl/android/system/virtualizationservice/IVirtualizationService.aidl
@@ -22,10 +22,10 @@
 
 interface IVirtualizationService {
     /**
-     * Start the VM with the given config file, and return a handle to it. If `logFd` is provided
-     * then console logs from the VM will be sent to it.
+     * Create the VM with the given config file, and return a handle to it ready to start it. If
+     * `logFd` is provided then console logs from the VM will be sent to it.
      */
-    IVirtualMachine startVm(
+    IVirtualMachine createVm(
             in VirtualMachineConfig config, in @nullable ParcelFileDescriptor logFd);
 
     /**
diff --git a/virtualizationservice/aidl/android/system/virtualizationservice/VirtualMachineState.aidl b/virtualizationservice/aidl/android/system/virtualizationservice/VirtualMachineState.aidl
index 621887f..b1aebfd 100644
--- a/virtualizationservice/aidl/android/system/virtualizationservice/VirtualMachineState.aidl
+++ b/virtualizationservice/aidl/android/system/virtualizationservice/VirtualMachineState.aidl
@@ -21,6 +21,10 @@
 @Backing(type="int")
 enum VirtualMachineState {
     /**
+     * The VM has been created but not yet started.
+     */
+    NOT_STARTED = 0,
+    /**
      * The VM is running, but the payload has not yet started.
      */
     STARTING = 1,
diff --git a/virtualizationservice/src/aidl.rs b/virtualizationservice/src/aidl.rs
index ad89ba5..2c1011b 100644
--- a/virtualizationservice/src/aidl.rs
+++ b/virtualizationservice/src/aidl.rs
@@ -15,7 +15,7 @@
 //! Implementation of the AIDL interface of the VirtualizationService.
 
 use crate::composite::make_composite_image;
-use crate::crosvm::{CrosvmConfig, DiskFile, PayloadState, VmInstance};
+use crate::crosvm::{CrosvmConfig, DiskFile, PayloadState, VmInstance, VmState};
 use crate::payload::add_microdroid_images;
 use crate::{Cid, FIRST_GUEST_CID};
 
@@ -85,10 +85,11 @@
 impl Interface for VirtualizationService {}
 
 impl IVirtualizationService for VirtualizationService {
-    /// Create and start a new VM with the given configuration, assigning it the next available CID.
+    /// Creates (but does not start) a new VM with the given configuration, assigning it the next
+    /// available CID.
     ///
     /// Returns a binder `IVirtualMachine` object referring to it, as a handle for the client.
-    fn startVm(
+    fn createVm(
         &self,
         config: &VirtualMachineConfig,
         log_fd: Option<&ParcelFileDescriptor>,
@@ -174,20 +175,22 @@
             log_fd,
             indirect_files,
         };
-        let instance = VmInstance::start(
-            crosvm_config,
-            temporary_directory,
-            requester_uid,
-            requester_sid,
-            requester_debug_pid,
-        )
-        .map_err(|e| {
-            error!("Failed to start VM with config {:?}: {}", config, e);
-            new_binder_exception(
-                ExceptionCode::SERVICE_SPECIFIC,
-                format!("Failed to start VM: {}", e),
+        let instance = Arc::new(
+            VmInstance::new(
+                crosvm_config,
+                temporary_directory,
+                requester_uid,
+                requester_sid,
+                requester_debug_pid,
             )
-        })?;
+            .map_err(|e| {
+                error!("Failed to create VM with config {:?}: {}", config, e);
+                new_binder_exception(
+                    ExceptionCode::SERVICE_SPECIFIC,
+                    format!("Failed to create VM: {}", e),
+                )
+            })?,
+        );
         state.add_vm(Arc::downgrade(&instance));
         Ok(VirtualMachine::create(instance))
     }
@@ -513,8 +516,8 @@
                 }
             }
         } else {
-            error!("Missing SID on startVm");
-            Err(new_binder_exception(ExceptionCode::SECURITY, "Missing SID on startVm"))
+            error!("Missing SID on createVm");
+            Err(new_binder_exception(ExceptionCode::SECURITY, "Missing SID on createVm"))
         }
     })
 }
@@ -589,12 +592,16 @@
         Ok(())
     }
 
+    fn start(&self) -> binder::Result<()> {
+        self.instance.start().map_err(|e| {
+            error!("Error starting VM with CID {}: {:?}", self.instance.cid, e);
+            new_binder_exception(ExceptionCode::SERVICE_SPECIFIC, e.to_string())
+        })
+    }
+
     fn connectVsock(&self, port: i32) -> binder::Result<ParcelFileDescriptor> {
-        if !self.instance.running() {
-            return Err(new_binder_exception(
-                ExceptionCode::SERVICE_SPECIFIC,
-                "VM is no longer running",
-            ));
+        if !matches!(&*self.instance.vm_state.lock().unwrap(), VmState::Running { .. }) {
+            return Err(new_binder_exception(ExceptionCode::SERVICE_SPECIFIC, "VM is not running"));
         }
         let stream =
             VsockStream::connect_with_cid_port(self.instance.cid, port as u32).map_err(|e| {
@@ -734,15 +741,16 @@
 
 /// Gets the `VirtualMachineState` of the given `VmInstance`.
 fn get_state(instance: &VmInstance) -> VirtualMachineState {
-    if instance.running() {
-        match instance.payload_state() {
+    match &*instance.vm_state.lock().unwrap() {
+        VmState::NotStarted { .. } => VirtualMachineState::NOT_STARTED,
+        VmState::Running { .. } => match instance.payload_state() {
             PayloadState::Starting => VirtualMachineState::STARTING,
             PayloadState::Started => VirtualMachineState::STARTED,
             PayloadState::Ready => VirtualMachineState::READY,
             PayloadState::Finished => VirtualMachineState::FINISHED,
-        }
-    } else {
-        VirtualMachineState::DEAD
+        },
+        VmState::Dead => VirtualMachineState::DEAD,
+        VmState::Failed => VirtualMachineState::DEAD,
     }
 }
 
diff --git a/virtualizationservice/src/crosvm.rs b/virtualizationservice/src/crosvm.rs
index 4b6a351..38e5bf3 100644
--- a/virtualizationservice/src/crosvm.rs
+++ b/virtualizationservice/src/crosvm.rs
@@ -21,11 +21,11 @@
 use log::{debug, error, info};
 use shared_child::SharedChild;
 use std::fs::{remove_dir_all, File};
+use std::mem;
 use std::num::NonZeroU32;
 use std::os::unix::io::{AsRawFd, RawFd};
 use std::path::PathBuf;
 use std::process::Command;
-use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::{Arc, Mutex};
 use std::thread;
 use vsock::VsockStream;
@@ -66,11 +66,55 @@
     Finished,
 }
 
-/// Information about a particular instance of a VM which is running.
+/// The current state of the VM itself.
+#[derive(Debug)]
+pub enum VmState {
+    /// The VM has not yet tried to start.
+    NotStarted {
+        ///The configuration needed to start the VM, if it has not yet been started.
+        config: CrosvmConfig,
+    },
+    /// The VM has been started.
+    Running {
+        /// The crosvm child process.
+        child: Arc<SharedChild>,
+    },
+    /// The VM died or was killed.
+    Dead,
+    /// The VM failed to start.
+    Failed,
+}
+
+impl VmState {
+    /// Tries to start the VM, if it is in the `NotStarted` state.
+    ///
+    /// Returns an error if the VM is in the wrong state, or fails to start.
+    fn start(&mut self, instance: Arc<VmInstance>) -> Result<(), Error> {
+        let state = mem::replace(self, VmState::Failed);
+        if let VmState::NotStarted { config } = state {
+            // If this fails and returns an error, `self` will be left in the `Failed` state.
+            let child = Arc::new(run_vm(config)?);
+
+            let child_clone = child.clone();
+            thread::spawn(move || {
+                instance.monitor(child_clone);
+            });
+
+            // If it started correctly, update the state.
+            *self = VmState::Running { child };
+            Ok(())
+        } else {
+            *self = state;
+            bail!("VM already started or failed")
+        }
+    }
+}
+
+/// Information about a particular instance of a VM which may be running.
 #[derive(Debug)]
 pub struct VmInstance {
-    /// The crosvm child process.
-    child: SharedChild,
+    /// The current state of the VM.
+    pub vm_state: Mutex<VmState>,
     /// The CID assigned to the VM for vsock communication.
     pub cid: Cid,
     /// Whether the VM is a protected VM.
@@ -84,8 +128,6 @@
     /// The PID of the process which requested the VM. Note that this process may no longer exist
     /// and the PID may have been reused for a different process, so this should not be trusted.
     pub requester_debug_pid: i32,
-    /// Whether the VM is still running.
-    running: AtomicBool,
     /// Callbacks to clients of the VM.
     pub callbacks: VirtualMachineCallbacks,
     /// Input/output stream of the payload run in the VM.
@@ -95,69 +137,53 @@
 }
 
 impl VmInstance {
-    /// Create a new `VmInstance` for the given process.
-    fn new(
-        child: SharedChild,
-        cid: Cid,
-        protected: bool,
-        temporary_directory: PathBuf,
-        requester_uid: u32,
-        requester_sid: String,
-        requester_debug_pid: i32,
-    ) -> VmInstance {
-        VmInstance {
-            child,
-            cid,
-            protected,
-            temporary_directory,
-            requester_uid,
-            requester_sid,
-            requester_debug_pid,
-            running: AtomicBool::new(true),
-            callbacks: Default::default(),
-            stream: Mutex::new(None),
-            payload_state: Mutex::new(PayloadState::Starting),
-        }
-    }
-
-    /// Start an instance of `crosvm` to manage a new VM. The `crosvm` instance will be killed when
-    /// the `VmInstance` is dropped.
-    pub fn start(
+    /// Validates the given config and creates a new `VmInstance` but doesn't start running it.
+    pub fn new(
         config: CrosvmConfig,
         temporary_directory: PathBuf,
         requester_uid: u32,
         requester_sid: String,
         requester_debug_pid: i32,
-    ) -> Result<Arc<VmInstance>, Error> {
+    ) -> Result<VmInstance, Error> {
+        validate_config(&config)?;
         let cid = config.cid;
         let protected = config.protected;
-        let child = run_vm(config)?;
-        let instance = Arc::new(VmInstance::new(
-            child,
+        Ok(VmInstance {
+            vm_state: Mutex::new(VmState::NotStarted { config }),
             cid,
             protected,
             temporary_directory,
             requester_uid,
             requester_sid,
             requester_debug_pid,
-        ));
-
-        let instance_clone = instance.clone();
-        thread::spawn(move || {
-            instance_clone.monitor();
-        });
-
-        Ok(instance)
+            callbacks: Default::default(),
+            stream: Mutex::new(None),
+            payload_state: Mutex::new(PayloadState::Starting),
+        })
     }
 
-    /// Wait for the crosvm child process to finish, then mark the VM as no longer running and call
-    /// any callbacks.
-    fn monitor(&self) {
-        match self.child.wait() {
+    /// Starts an instance of `crosvm` to manage the VM. The `crosvm` instance will be killed when
+    /// the `VmInstance` is dropped.
+    pub fn start(self: &Arc<Self>) -> Result<(), Error> {
+        self.vm_state.lock().unwrap().start(self.clone())
+    }
+
+    /// Waits for the crosvm child process to finish, then marks the VM as no longer running and
+    /// calls any callbacks.
+    ///
+    /// This takes a separate reference to the `SharedChild` rather than using the one in
+    /// `self.vm_state` to avoid holding the lock on `vm_state` while it is running.
+    fn monitor(&self, child: Arc<SharedChild>) {
+        match child.wait() {
             Err(e) => error!("Error waiting for crosvm instance to die: {}", e),
             Ok(status) => info!("crosvm exited with status {}", status),
         }
-        self.running.store(false, Ordering::Release);
+
+        let mut vm_state = self.vm_state.lock().unwrap();
+        *vm_state = VmState::Dead;
+        // Ensure that the mutex is released before calling the callbacks.
+        drop(vm_state);
+
         self.callbacks.callback_on_died(self.cid);
 
         // Delete temporary files.
@@ -166,11 +192,6 @@
         }
     }
 
-    /// Return whether `crosvm` is still running the VM.
-    pub fn running(&self) -> bool {
-        self.running.load(Ordering::Acquire)
-    }
-
     /// Returns the last reported state of the VM payload.
     pub fn payload_state(&self) -> PayloadState {
         *self.payload_state.lock().unwrap()
@@ -189,11 +210,14 @@
         }
     }
 
-    /// Kill the crosvm instance.
+    /// Kills the crosvm instance, if it is running.
     pub fn kill(&self) {
-        // TODO: Talk to crosvm to shutdown cleanly.
-        if let Err(e) = self.child.kill() {
-            error!("Error killing crosvm instance: {}", e);
+        let vm_state = &*self.vm_state.lock().unwrap();
+        if let VmState::Running { child } = vm_state {
+            // TODO: Talk to crosvm to shutdown cleanly.
+            if let Err(e) = child.kill() {
+                error!("Error killing crosvm instance: {}", e);
+            }
         }
     }
 }
diff --git a/vm/src/run.rs b/vm/src/run.rs
index ccb4085..0d34a97 100644
--- a/vm/src/run.rs
+++ b/vm/src/run.rs
@@ -16,15 +16,16 @@
 
 use crate::create_partition::command_create_partition;
 use crate::sync::AtomicFlag;
-use android_system_virtualizationservice::aidl::android::system::virtualizationservice::IVirtualizationService::IVirtualizationService;
-use android_system_virtualizationservice::aidl::android::system::virtualizationservice::IVirtualMachine::IVirtualMachine;
 use android_system_virtualizationservice::aidl::android::system::virtualizationservice::IVirtualMachineCallback::{
     BnVirtualMachineCallback, IVirtualMachineCallback,
 };
 use android_system_virtualizationservice::aidl::android::system::virtualizationservice::{
+    IVirtualMachine::IVirtualMachine,
+    IVirtualizationService::IVirtualizationService,
     PartitionType::PartitionType,
     VirtualMachineAppConfig::VirtualMachineAppConfig,
     VirtualMachineConfig::VirtualMachineConfig,
+    VirtualMachineState::VirtualMachineState,
 };
 use android_system_virtualizationservice::binder::{
     BinderFeatures, DeathRecipient, IBinder, ParcelFileDescriptor, Strong,
@@ -100,6 +101,18 @@
     )
 }
 
+fn state_to_str(vm_state: VirtualMachineState) -> &'static str {
+    match vm_state {
+        VirtualMachineState::NOT_STARTED => "NOT_STARTED",
+        VirtualMachineState::STARTING => "STARTING",
+        VirtualMachineState::STARTED => "STARTED",
+        VirtualMachineState::READY => "READY",
+        VirtualMachineState::FINISHED => "FINISHED",
+        VirtualMachineState::DEAD => "DEAD",
+        _ => "(invalid state)",
+    }
+}
+
 fn run(
     service: Strong<dyn IVirtualizationService>,
     config: &VirtualMachineConfig,
@@ -117,10 +130,17 @@
     } else {
         Some(ParcelFileDescriptor::new(duplicate_stdout()?))
     };
-    let vm = service.startVm(config, stdout.as_ref()).context("Failed to start VM")?;
+    let vm = service.createVm(config, stdout.as_ref()).context("Failed to create VM")?;
 
     let cid = vm.getCid().context("Failed to get CID")?;
-    println!("Started VM from {} with CID {}.", config_path, cid);
+    println!(
+        "Created VM from {} with CID {}, state is {}.",
+        config_path,
+        cid,
+        state_to_str(vm.getState()?)
+    );
+    vm.start()?;
+    println!("Started VM, state now {}.", state_to_str(vm.getState()?));
 
     if daemonize {
         // Pass the VM reference back to VirtualizationService and have it hold it in the