Merge "Make packages/modules/Virtualization tests available to the platform" into main
diff --git a/android/virtmgr/src/aidl.rs b/android/virtmgr/src/aidl.rs
index 0c5480b..f684a2a 100644
--- a/android/virtmgr/src/aidl.rs
+++ b/android/virtmgr/src/aidl.rs
@@ -802,6 +802,25 @@
 
         let detect_hangup = is_app_config && gdb_port.is_none();
 
+        let custom_memory_backing_files = config
+            .customMemoryBackingFiles
+            .iter()
+            .map(|memory_backing_file| {
+                Ok((
+                    clone_file(
+                        memory_backing_file
+                            .file
+                            .as_ref()
+                            .context("missing CustomMemoryBackingFile FD")
+                            .or_binder_exception(ExceptionCode::ILLEGAL_ARGUMENT)?,
+                    )?
+                    .into(),
+                    memory_backing_file.rangeStart as u64,
+                    memory_backing_file.size as u64,
+                ))
+            })
+            .collect::<binder::Result<_>>()?;
+
         // Actually start the VM.
         let crosvm_config = CrosvmConfig {
             cid,
@@ -846,6 +865,7 @@
             dump_dt_fd,
             enable_hypervisor_specific_auth_method: config.enableHypervisorSpecificAuthMethod,
             instance_id,
+            custom_memory_backing_files,
         };
         let instance = Arc::new(
             VmInstance::new(
diff --git a/android/virtmgr/src/crosvm.rs b/android/virtmgr/src/crosvm.rs
index bb7712e..8500421 100644
--- a/android/virtmgr/src/crosvm.rs
+++ b/android/virtmgr/src/crosvm.rs
@@ -140,6 +140,8 @@
     pub dump_dt_fd: Option<File>,
     pub enable_hypervisor_specific_auth_method: bool,
     pub instance_id: [u8; 64],
+    // (memfd, guest address, size)
+    pub custom_memory_backing_files: Vec<(OwnedFd, u64, u64)>,
 }
 
 #[derive(Debug)]
@@ -1042,8 +1044,8 @@
             // When this mode is enabled, two hypervisor specific IDs are expected to be packed
             // into the instance ID. We extract them here and pass along to crosvm so they can be
             // given to the hypervisor driver via an ioctl.
-            let vm_id = u32::from_le_bytes(config.instance_id[60..64].try_into().unwrap());
-            let pas_id = u16::from_le_bytes(config.instance_id[58..60].try_into().unwrap());
+            let pas_id = u32::from_le_bytes(config.instance_id[60..64].try_into().unwrap());
+            let vm_id = u16::from_le_bytes(config.instance_id[58..60].try_into().unwrap());
             command.arg("--hypervisor").arg(
                 format!("gunyah[device=/dev/gunyah,qcom_trusted_vm_id={vm_id},qcom_trusted_vm_pas_id={pas_id}]"),
             );
@@ -1370,6 +1372,13 @@
         }
     }
 
+    for (fd, addr, size) in config.custom_memory_backing_files {
+        command.arg("--file-backed-mapping").arg(format!(
+            "{},addr={addr:#0x},size={size:#0x},rw,ram",
+            add_preserved_fd(&mut preserved_fds, fd)
+        ));
+    }
+
     debug!("Preserving FDs {:?}", preserved_fds);
     command.preserved_fds(preserved_fds);
 
diff --git a/android/virtualizationservice/aidl/android/system/virtualizationservice/CustomMemoryBackingFile.aidl b/android/virtualizationservice/aidl/android/system/virtualizationservice/CustomMemoryBackingFile.aidl
new file mode 100644
index 0000000..721ad26
--- /dev/null
+++ b/android/virtualizationservice/aidl/android/system/virtualizationservice/CustomMemoryBackingFile.aidl
@@ -0,0 +1,36 @@
+/*
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.system.virtualizationservice;
+
+/** Custom memfd for a subset of guest memory. */
+parcelable CustomMemoryBackingFile {
+    /** The memfd. */
+    ParcelFileDescriptor file;
+
+    /**
+     * Start of range in guest physical address space.
+     *
+     * The value should be interpreted as an unsigned 64 bit integer.
+     */
+    long rangeStart;
+
+    /**
+     * Size of range in bytes.
+     *
+     * The value should be interpreted as an unsigned 64 bit integer.
+     */
+    long size;
+}
diff --git a/android/virtualizationservice/aidl/android/system/virtualizationservice/VirtualMachineRawConfig.aidl b/android/virtualizationservice/aidl/android/system/virtualizationservice/VirtualMachineRawConfig.aidl
index a822423..c5fe982 100644
--- a/android/virtualizationservice/aidl/android/system/virtualizationservice/VirtualMachineRawConfig.aidl
+++ b/android/virtualizationservice/aidl/android/system/virtualizationservice/VirtualMachineRawConfig.aidl
@@ -18,6 +18,7 @@
 import android.system.virtualizationservice.AssignedDevices;
 import android.system.virtualizationservice.AudioConfig;
 import android.system.virtualizationservice.CpuOptions;
+import android.system.virtualizationservice.CustomMemoryBackingFile;
 import android.system.virtualizationservice.DiskImage;
 import android.system.virtualizationservice.DisplayConfig;
 import android.system.virtualizationservice.GpuConfig;
@@ -123,4 +124,7 @@
      * VMs.
      */
     boolean enableHypervisorSpecificAuthMethod;
+
+    /** Custom memfds, each backing a range of guest physical memory. */
+    CustomMemoryBackingFile[] customMemoryBackingFiles;
 }
diff --git a/android/vm/src/run.rs b/android/vm/src/run.rs
index eaf2522..a362b8e 100644
--- a/android/vm/src/run.rs
+++ b/android/vm/src/run.rs
@@ -341,11 +341,10 @@
     } else {
         None
     };
+    let vm = VmInstance::create(service, config, console_out, console_in, log, dump_dt)
+        .context("Failed to create VM")?;
     let callback = Box::new(Callback {});
-    let vm =
-        VmInstance::create(service, config, console_out, console_in, log, dump_dt, Some(callback))
-            .context("Failed to create VM")?;
-    vm.start().context("Failed to start VM")?;
+    vm.start(Some(callback)).context("Failed to start VM")?;
 
     let debug_level = get_debug_level(config).unwrap_or(DebugLevel::NONE);
 
diff --git a/build/debian/vm_config.json.aarch64 b/build/debian/vm_config.json.aarch64
index d41a29c..96254f8 100644
--- a/build/debian/vm_config.json.aarch64
+++ b/build/debian/vm_config.json.aarch64
@@ -34,5 +34,6 @@
     "debuggable": true,
     "console_out": true,
     "console_input_device": "ttyS0",
-    "network": true
+    "network": true,
+    "auto_memory_balloon": true
 }
diff --git a/guest/pvmfw/src/fdt.rs b/guest/pvmfw/src/fdt.rs
index 2e34ee8..59399b3 100644
--- a/guest/pvmfw/src/fdt.rs
+++ b/guest/pvmfw/src/fdt.rs
@@ -29,7 +29,6 @@
 use core::mem::size_of;
 use core::ops::Range;
 use hypervisor_backends::get_device_assigner;
-use hypervisor_backends::get_mem_sharer;
 use libfdt::AddressRange;
 use libfdt::CellIterator;
 use libfdt::Fdt;
@@ -79,48 +78,99 @@
     }
 }
 
+/// Integer property stored as one or two cells, independently of #address-cells/#size-cells.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+enum DeviceTreeInteger {
+    SingleCell(u32),
+    DoubleCell(u64),
+}
+
+impl DeviceTreeInteger {
+    fn read_from(node: &FdtNode, name: &CStr) -> libfdt::Result<Option<Self>> {
+        if let Some(bytes) = node.getprop(name)? {
+            Ok(Some(Self::from_bytes(bytes).ok_or(FdtError::BadValue)?))
+        } else {
+            Ok(None)
+        }
+    }
+
+    fn from_bytes(bytes: &[u8]) -> Option<Self> {
+        if let Some(val) = bytes.try_into().ok().map(u32::from_be_bytes) {
+            return Some(Self::SingleCell(val));
+        } else if let Some(val) = bytes.try_into().ok().map(u64::from_be_bytes) {
+            return Some(Self::DoubleCell(val));
+        }
+        None
+    }
+
+    fn write_to(&self, node: &mut FdtNodeMut, name: &CStr) -> libfdt::Result<()> {
+        match self {
+            Self::SingleCell(value) => node.setprop(name, &value.to_be_bytes()),
+            Self::DoubleCell(value) => node.setprop(name, &value.to_be_bytes()),
+        }
+    }
+}
+
+impl From<DeviceTreeInteger> for usize {
+    fn from(i: DeviceTreeInteger) -> Self {
+        match i {
+            DeviceTreeInteger::SingleCell(v) => v.try_into().unwrap(),
+            DeviceTreeInteger::DoubleCell(v) => v.try_into().unwrap(),
+        }
+    }
+}
+
+/// Returns the pair of integers, `None` if both are absent, or an error if only one is present.
+fn read_two_ints(
+    node: &FdtNode,
+    name_a: &CStr,
+    name_b: &CStr,
+) -> libfdt::Result<Option<(DeviceTreeInteger, DeviceTreeInteger)>> {
+    let a = DeviceTreeInteger::read_from(node, name_a)?;
+    let b = DeviceTreeInteger::read_from(node, name_b)?;
+
+    match (a, b) {
+        (Some(a), Some(b)) => Ok(Some((a, b))),
+        (None, None) => Ok(None),
+        _ => Err(FdtError::NotFound),
+    }
+}
+
 /// Extract from /config the address range containing the pre-loaded kernel.
 ///
 /// Absence of /config is not an error. However, an error is returned if only one of the two
 /// properties is present.
 pub fn read_kernel_range_from(fdt: &Fdt) -> libfdt::Result<Option<Range<usize>>> {
-    let addr = c"kernel-address";
-    let size = c"kernel-size";
-
-    if let Some(config) = fdt.node(c"/config")? {
-        match (config.getprop_u32(addr)?, config.getprop_u32(size)?) {
-            (None, None) => {}
-            (Some(addr), Some(size)) => {
-                let addr = addr as usize;
-                let size = size as usize;
-                return Ok(Some(addr..(addr + size)));
-            }
-            _ => return Err(FdtError::NotFound),
+    if let Some(ref config) = fdt.node(c"/config")? {
+        if let Some((addr, size)) = read_two_ints(config, c"kernel-address", c"kernel-size")? {
+            let (addr, size) = (usize::from(addr), usize::from(size));
+            // 64-bit cells may overflow when summed: reject rather than wrap.
+            return Ok(Some(addr..addr.checked_add(size).ok_or(FdtError::BadValue)?));
         }
     }
-
     Ok(None)
 }
 
+fn read_initrd_range_props(
+    fdt: &Fdt,
+) -> libfdt::Result<Option<(DeviceTreeInteger, DeviceTreeInteger)>> {
+    if let Some(ref chosen) = fdt.chosen()? {
+        read_two_ints(chosen, c"linux,initrd-start", c"linux,initrd-end")
+    } else {
+        Ok(None)
+    }
+}
+
 /// Extract from /chosen the address range containing the pre-loaded ramdisk.
 ///
 /// Absence is not an error as there can be initrd-less VM. However, an error is returned if only
 /// one of the two properties is present.
 pub fn read_initrd_range_from(fdt: &Fdt) -> libfdt::Result<Option<Range<usize>>> {
-    let start = c"linux,initrd-start";
-    let end = c"linux,initrd-end";
-
-    if let Some(chosen) = fdt.chosen()? {
-        match (chosen.getprop_u32(start)?, chosen.getprop_u32(end)?) {
-            (None, None) => {}
-            (Some(start), Some(end)) => {
-                return Ok(Some((start as usize)..(end as usize)));
-            }
-            _ => return Err(FdtError::NotFound),
-        }
+    if let Some((start, end)) = read_initrd_range_props(fdt)? {
+        Ok(Some(usize::from(start)..usize::from(end)))
+    } else {
+        Ok(None)
     }
-
-    Ok(None)
 }
 
 /// Read /avf/untrusted/instance-id, if present.
@@ -141,13 +191,14 @@
     }
 }
 
-fn patch_initrd_range(fdt: &mut Fdt, initrd_range: &Range<usize>) -> libfdt::Result<()> {
-    let start = u32::try_from(initrd_range.start).unwrap();
-    let end = u32::try_from(initrd_range.end).unwrap();
-
+fn patch_initrd_range(
+    fdt: &mut Fdt,
+    start: &DeviceTreeInteger,
+    end: &DeviceTreeInteger,
+) -> libfdt::Result<()> {
     let mut node = fdt.chosen_mut()?.ok_or(FdtError::NotFound)?;
-    node.setprop(c"linux,initrd-start", &start.to_be_bytes())?;
-    node.setprop(c"linux,initrd-end", &end.to_be_bytes())?;
+    start.write_to(&mut node, c"linux,initrd-start")?;
+    end.write_to(&mut node, c"linux,initrd-end")?;
     Ok(())
 }
 
@@ -193,6 +244,11 @@
         );
     }
     let base = range.start;
+    if base % alignment != 0 {
+        error!("Memory base address {:#x} is not aligned to {:#x}", base, alignment);
+        return Err(RebootReason::InvalidFdt);
+    }
+    // For simplicity, force a hardcoded memory base, for now.
     if base != MEM_START {
         error!("Memory base address {:#x} is not {:#x}", base, MEM_START);
         return Err(RebootReason::InvalidFdt);
@@ -902,6 +958,10 @@
             error!("Invalid swiotlb range: addr:{addr:#x} size:{size:#x}");
             return Err(RebootReason::InvalidFdt);
         }
+        if (addr % alignment) != 0 {
+            error!("Swiotlb address {:#x} not aligned to {:#x}", addr, alignment);
+            return Err(RebootReason::InvalidFdt);
+        }
     }
     if let Some(range) = swiotlb_info.fixed_range() {
         if !range.is_within(memory) {
@@ -1016,7 +1076,7 @@
 
 #[derive(Debug)]
 pub struct DeviceTreeInfo {
-    pub initrd_range: Option<Range<usize>>,
+    initrd_range: Option<(DeviceTreeInteger, DeviceTreeInteger)>,
     pub memory_range: Range<usize>,
     bootargs: Option<CString>,
     cpus: ArrayVec<[CpuInfo; DeviceTreeInfo::MAX_CPUS]>,
@@ -1045,6 +1105,7 @@
     vm_dtbo: Option<&mut [u8]>,
     vm_ref_dt: Option<&[u8]>,
     guest_page_size: usize,
+    hyp_page_size: Option<usize>,
 ) -> Result<DeviceTreeInfo, RebootReason> {
     let vm_dtbo = match vm_dtbo {
         Some(vm_dtbo) => Some(VmDtbo::from_mut_slice(vm_dtbo).map_err(|e| {
@@ -1054,7 +1115,7 @@
         None => None,
     };
 
-    let info = parse_device_tree(fdt, vm_dtbo.as_deref(), guest_page_size)?;
+    let info = parse_device_tree(fdt, vm_dtbo.as_deref(), guest_page_size, hyp_page_size)?;
 
     fdt.clone_from(FDT_TEMPLATE).map_err(|e| {
         error!("Failed to instantiate FDT from the template DT: {e}");
@@ -1111,13 +1172,15 @@
     fdt: &Fdt,
     vm_dtbo: Option<&VmDtbo>,
     guest_page_size: usize,
+    hyp_page_size: Option<usize>,
 ) -> Result<DeviceTreeInfo, RebootReason> {
-    let initrd_range = read_initrd_range_from(fdt).map_err(|e| {
+    let initrd_range = read_initrd_range_props(fdt).map_err(|e| {
         error!("Failed to read initrd range from DT: {e}");
         RebootReason::InvalidFdt
     })?;
 
-    let memory_alignment = guest_page_size;
+    // Ensure that MMIO_GUARD can't be used to inadvertently map some memory as MMIO.
+    let memory_alignment = max(hyp_page_size, Some(guest_page_size)).unwrap();
     let memory_range = read_and_validate_memory_range(fdt, memory_alignment)?;
 
     let bootargs = read_bootargs_from(fdt).map_err(|e| {
@@ -1171,22 +1234,17 @@
             error!("Swiotlb info missing from DT");
             RebootReason::InvalidFdt
         })?;
-    let swiotlb_alignment = guest_page_size;
+    // Ensure that MEM_SHARE won't inadvertently map beyond the shared region.
+    let swiotlb_alignment = max(hyp_page_size, Some(guest_page_size)).unwrap();
     validate_swiotlb_info(&swiotlb_info, &memory_range, swiotlb_alignment)?;
 
     let device_assignment = if let Some(vm_dtbo) = vm_dtbo {
         if let Some(hypervisor) = get_device_assigner() {
-            // TODO(ptosi): Cache the (single?) granule once, in vmbase.
-            let granule = get_mem_sharer()
-                .ok_or_else(|| {
-                    error!("No MEM_SHARE found during device assignment validation");
-                    RebootReason::InternalError
-                })?
-                .granule()
-                .map_err(|e| {
-                    error!("Failed to get granule for device assignment validation: {e}");
-                    RebootReason::InternalError
-                })?;
+            let granule = hyp_page_size.ok_or_else(|| {
+                error!("No granule found during device assignment validation");
+                RebootReason::InternalError
+            })?;
+
             DeviceAssignmentInfo::parse(fdt, vm_dtbo, hypervisor, granule).map_err(|e| {
                 error!("Failed to parse device assignment from DT and VM DTBO: {e}");
                 RebootReason::InvalidFdt
@@ -1230,8 +1288,8 @@
 }
 
 fn patch_device_tree(fdt: &mut Fdt, info: &DeviceTreeInfo) -> Result<(), RebootReason> {
-    if let Some(initrd_range) = &info.initrd_range {
-        patch_initrd_range(fdt, initrd_range).map_err(|e| {
+    if let Some((start, end)) = &info.initrd_range {
+        patch_initrd_range(fdt, start, end).map_err(|e| {
             error!("Failed to patch initrd range to DT: {e}");
             RebootReason::InvalidFdt
         })?;
diff --git a/guest/pvmfw/src/main.rs b/guest/pvmfw/src/main.rs
index 9c67be8..9afbcc3 100644
--- a/guest/pvmfw/src/main.rs
+++ b/guest/pvmfw/src/main.rs
@@ -41,6 +41,7 @@
 use alloc::boxed::Box;
 use bssl_avf::Digester;
 use diced_open_dice::{bcc_handover_parse, DiceArtifacts, DiceContext, Hidden, VM_KEY_ALGORITHM};
+use hypervisor_backends::get_mem_sharer;
 use libfdt::Fdt;
 use log::{debug, error, info, trace, warn};
 use pvmfw_avb::verify_payload;
@@ -98,7 +99,17 @@
     }
 
     let guest_page_size = verified_boot_data.page_size.unwrap_or(SIZE_4KB);
-    let _ = sanitize_device_tree(untrusted_fdt, vm_dtbo, vm_ref_dt, guest_page_size)?;
+    // TODO(ptosi): Cache the (single?) granule once, in vmbase.
+    let hyp_page_size = if let Some(mem_sharer) = get_mem_sharer() {
+        Some(mem_sharer.granule().map_err(|e| {
+            error!("Failed to get granule size: {e}");
+            RebootReason::InternalError
+        })?)
+    } else {
+        None
+    };
+    let _ =
+        sanitize_device_tree(untrusted_fdt, vm_dtbo, vm_ref_dt, guest_page_size, hyp_page_size)?;
     let fdt = untrusted_fdt; // DT has now been sanitized.
 
     let next_bcc_size = guest_page_size;
diff --git a/guest/rialto/tests/test.rs b/guest/rialto/tests/test.rs
index d68c568..c650046 100644
--- a/guest/rialto/tests/test.rs
+++ b/guest/rialto/tests/test.rs
@@ -338,7 +338,6 @@
         /* consoleIn */ None,
         log,
         /* dump_dt */ None,
-        None,
     )
     .context("Failed to create VM")
 }
diff --git a/guest/trusty/security_vm/launcher/src/main.rs b/guest/trusty/security_vm/launcher/src/main.rs
index 3c8d599..5273d33 100644
--- a/guest/trusty/security_vm/launcher/src/main.rs
+++ b/guest/trusty/security_vm/launcher/src/main.rs
@@ -102,10 +102,9 @@
         None, // console_out
         None, // log
         None, // dump_dt
-        None, // callback
     )
     .context("Failed to create VM")?;
-    vm.start().context("Failed to start VM")?;
+    vm.start(None /* callback */).context("Failed to start VM")?;
 
     println!("started trusty_security_vm_launcher VM");
     let death_reason = vm.wait_for_death();
diff --git a/libs/libavf/include/android/virtualization.h b/libs/libavf/include/android/virtualization.h
index 8d96fac..4bfe47a 100644
--- a/libs/libavf/include/android/virtualization.h
+++ b/libs/libavf/include/android/virtualization.h
@@ -189,7 +189,7 @@
  * physical memory.
  *
  * \param config a virtual machine config object.
- * \param fd a memfd
+ * \param fd a memfd. Ownership is transferred, even if the function is not successful.
  * \param rangeStart range start of guest memory addresses
  * \param rangeEnd range end of guest memory addresses
  *
diff --git a/libs/libavf/src/lib.rs b/libs/libavf/src/lib.rs
index 3fa1b75..33bd2d7 100644
--- a/libs/libavf/src/lib.rs
+++ b/libs/libavf/src/lib.rs
@@ -24,8 +24,9 @@
 use android_system_virtualizationservice::{
     aidl::android::system::virtualizationservice::{
         AssignedDevices::AssignedDevices, CpuOptions::CpuOptions,
-        CpuOptions::CpuTopology::CpuTopology, DiskImage::DiskImage,
-        IVirtualizationService::IVirtualizationService, VirtualMachineConfig::VirtualMachineConfig,
+        CpuOptions::CpuTopology::CpuTopology, CustomMemoryBackingFile::CustomMemoryBackingFile,
+        DiskImage::DiskImage, IVirtualizationService::IVirtualizationService,
+        VirtualMachineConfig::VirtualMachineConfig,
         VirtualMachineRawConfig::VirtualMachineRawConfig,
     },
     binder::{ParcelFileDescriptor, Strong},
@@ -254,18 +255,35 @@
     0
 }
 
-/// NOT IMPLEMENTED.
+/// Use the specified fd as the backing memfd for a range of the guest physical memory.
 ///
-/// # Returns
-/// It always returns `-ENOTSUP`.
+/// # Safety
+/// `config` must be a pointer returned by `AVirtualMachineRawConfig_create`.
 #[no_mangle]
-pub extern "C" fn AVirtualMachineRawConfig_addCustomMemoryBackingFile(
-    _config: *mut VirtualMachineRawConfig,
-    _fd: c_int,
-    _range_start: u64,
-    _range_end: u64,
+pub unsafe extern "C" fn AVirtualMachineRawConfig_addCustomMemoryBackingFile(
+    config: *mut VirtualMachineRawConfig,
+    fd: c_int,
+    range_start: u64,
+    range_end: u64,
 ) -> c_int {
-    -libc::ENOTSUP
+    // SAFETY: `config` is assumed to be a valid, non-null pointer returned by
+    // AVirtualMachineRawConfig_create. It's the only reference to the object.
+    let config = unsafe { &mut *config };
+
+    let Some(file) = get_file_from_fd(fd) else {
+        return -libc::EINVAL;
+    };
+    let Some(size) = range_end.checked_sub(range_start) else {
+        return -libc::EINVAL;
+    };
+    config.customMemoryBackingFiles.push(CustomMemoryBackingFile {
+        file: Some(ParcelFileDescriptor::new(file)),
+        // AIDL doesn't support unsigned ints, so we've got to reinterpret the bytes into a signed
+        // int.
+        rangeStart: range_start as i64,
+        size: size as i64,
+    });
+    0
 }
 
 /// Add device tree overlay blob
@@ -373,7 +391,7 @@
     let console_in = get_file_from_fd(console_in_fd);
     let log = get_file_from_fd(log_fd);
 
-    match VmInstance::create(service.as_ref(), &config, console_out, console_in, log, None, None) {
+    match VmInstance::create(service.as_ref(), &config, console_out, console_in, log, None) {
         Ok(vm) => {
             // SAFETY: `vm_ptr` is assumed to be a valid, non-null pointer to a mutable raw pointer.
             // `vm` is the only reference here and `vm_ptr` takes ownership.
@@ -398,7 +416,7 @@
     // SAFETY: `vm` is assumed to be a valid, non-null pointer returned by
     // `AVirtualMachine_createRaw`. It's the only reference to the object.
     let vm = unsafe { &*vm };
-    match vm.start() {
+    match vm.start(None) {
         Ok(_) => 0,
         Err(e) => {
             error!("AVirtualMachine_start failed: {e:?}");
diff --git a/libs/libcompos_common/compos_client.rs b/libs/libcompos_common/compos_client.rs
index a52104d..c48c9f6 100644
--- a/libs/libcompos_common/compos_client.rs
+++ b/libs/libcompos_common/compos_client.rs
@@ -150,19 +150,14 @@
 
         // Let logs go to logcat.
         let (console_fd, log_fd) = (None, None);
-        let callback = Box::new(Callback {});
         let instance = VmInstance::create(
-            service,
-            &config,
-            console_fd,
-            /* console_in_fd */ None,
-            log_fd,
+            service, &config, console_fd, /* console_in_fd */ None, log_fd,
             /* dump_dt */ None,
-            Some(callback),
         )
         .context("Failed to create VM")?;
 
-        instance.start()?;
+        let callback = Box::new(Callback {});
+        instance.start(Some(callback))?;
 
         let ready = instance.wait_until_ready(TIMEOUTS.vm_max_time_to_ready);
         if ready == Err(VmWaitError::Finished) && debug_level != DebugLevel::NONE {
diff --git a/libs/libservice_vm_manager/src/lib.rs b/libs/libservice_vm_manager/src/lib.rs
index 667731f..0681a6f 100644
--- a/libs/libservice_vm_manager/src/lib.rs
+++ b/libs/libservice_vm_manager/src/lib.rs
@@ -152,7 +152,7 @@
         let vsock_listener = VsockListener::bind_with_cid_port(VMADDR_CID_HOST, vm_type.port())?;
 
         // Starts the service VM.
-        vm.start().context("Failed to start service VM")?;
+        vm.start(None).context("Failed to start service VM")?;
         info!("Service VM started");
 
         // Accepts the connection from the service VM.
@@ -246,8 +246,7 @@
     let console_in = None;
     let log = Some(android_log_fd()?);
     let dump_dt = None;
-    let callback = None;
-    VmInstance::create(service.as_ref(), &config, console_out, console_in, log, dump_dt, callback)
+    VmInstance::create(service.as_ref(), &config, console_out, console_in, log, dump_dt)
         .context("Failed to create service VM")
 }
 
diff --git a/libs/libvmclient/src/lib.rs b/libs/libvmclient/src/lib.rs
index 8dd3cd3..2c6abb5 100644
--- a/libs/libvmclient/src/lib.rs
+++ b/libs/libvmclient/src/lib.rs
@@ -209,7 +209,6 @@
         console_in: Option<File>,
         log: Option<File>,
         dump_dt: Option<File>,
-        callback: Option<Box<dyn VmCallback + Send + Sync>>,
     ) -> BinderResult<Self> {
         let console_out = console_out.map(ParcelFileDescriptor::new);
         let console_in = console_in.map(ParcelFileDescriptor::new);
@@ -226,20 +225,19 @@
 
         let cid = vm.getCid()?;
 
-        // Register callback before starting VM, in case it dies immediately.
         let state = Arc::new(Monitor::new(VmState::default()));
-        let callback = BnVirtualMachineCallback::new_binder(
-            VirtualMachineCallback { state: state.clone(), client_callback: callback },
-            BinderFeatures::default(),
-        );
-        vm.registerCallback(&callback)?;
         let death_recipient = wait_for_binder_death(&mut vm.as_binder(), state.clone())?;
 
         Ok(Self { vm, cid, state, _death_recipient: death_recipient })
     }
 
     /// Starts the VM.
-    pub fn start(&self) -> BinderResult<()> {
+    pub fn start(&self, callback: Option<Box<dyn VmCallback + Send + Sync>>) -> BinderResult<()> {
+        let callback = BnVirtualMachineCallback::new_binder(
+            VirtualMachineCallback { state: self.state.clone(), client_callback: callback },
+            BinderFeatures::default(),
+        );
+        self.vm.registerCallback(&callback)?;
         self.vm.start()
     }
 
diff --git a/microfuchsia/microfuchsiad/src/instance_starter.rs b/microfuchsia/microfuchsiad/src/instance_starter.rs
index 55c946e..f58a379 100644
--- a/microfuchsia/microfuchsiad/src/instance_starter.rs
+++ b/microfuchsia/microfuchsiad/src/instance_starter.rs
@@ -97,7 +97,6 @@
             console_in,
             /* log= */ None,
             /* dump_dt= */ None,
-            None,
         )
         .context("Failed to create VM")?;
         if let Some(pty) = &pty {
@@ -106,7 +105,7 @@
                 .setHostConsoleName(&pty.follower_name)
                 .context("Setting host console name")?;
         }
-        vm_instance.start().context("Starting VM")?;
+        vm_instance.start(None).context("Starting VM")?;
 
         Ok(MicrofuchsiaInstance {
             _vm_instance: vm_instance,
diff --git a/tests/backcompat_test/src/main.rs b/tests/backcompat_test/src/main.rs
index b0cd042..4d09a89 100644
--- a/tests/backcompat_test/src/main.rs
+++ b/tests/backcompat_test/src/main.rs
@@ -118,10 +118,9 @@
         /* consoleIn */ None,
         None,
         Some(dump_dt),
-        None,
     )
     .context("Failed to create VM")?;
-    vm.start().context("Failed to start VM")?;
+    vm.start(None).context("Failed to start VM")?;
     info!("Started example VM.");
 
     // Wait for VM to finish
diff --git a/tests/early_vm_test/src/main.rs b/tests/early_vm_test/src/main.rs
index a3c80ca..7d630f8 100644
--- a/tests/early_vm_test/src/main.rs
+++ b/tests/early_vm_test/src/main.rs
@@ -96,7 +96,6 @@
         None, // console_out
         None, // log
         None, // dump_dt
-        None, // callback
     )
     .context("Failed to create VM")?;
 
diff --git a/tests/vm_accessor/accessor/src/run.rs b/tests/vm_accessor/accessor/src/run.rs
index 6dcc507..5bdb8f1 100644
--- a/tests/vm_accessor/accessor/src/run.rs
+++ b/tests/vm_accessor/accessor/src/run.rs
@@ -129,10 +129,9 @@
         None,                    /* console_in */
         Some(android_log_fd()?), /* log */
         None,                    /* dump_dt */
-        Some(Box::new(Callback {})),
     )
     .context("Failed to create VM")?;
-    vm.start().context("Failed to start VM")?;
+    vm.start(Some(Box::new(Callback {}))).context("Failed to start VM")?;
 
     info!("started IAccessor VM with CID {}", vm.cid());
 
diff --git a/tests/vmbase_example/src/main.rs b/tests/vmbase_example/src/main.rs
index 81812cd..0eda90e 100644
--- a/tests/vmbase_example/src/main.rs
+++ b/tests/vmbase_example/src/main.rs
@@ -116,10 +116,9 @@
         /* consoleIn */ None,
         Some(log_writer),
         /* dump_dt */ None,
-        None,
     )
     .context("Failed to create VM")?;
-    vm.start().context("Failed to start VM")?;
+    vm.start(None).context("Failed to start VM")?;
     info!("Started example VM.");
 
     // Wait for VM to finish, and check that it shut down cleanly.