Merge "Implement createTapInterface in vmnic" into main
diff --git a/apex/Android.bp b/apex/Android.bp
index 99b2dee..43819dc 100644
--- a/apex/Android.bp
+++ b/apex/Android.bp
@@ -47,7 +47,7 @@
         "release_avf_enable_device_assignment",
         "release_avf_enable_llpvm_changes",
         "release_avf_enable_network",
-        "release_avf_enable_remote_attestation",
+        "avf_remote_attestation_enabled",
         "release_avf_enable_vendor_modules",
         "release_avf_enable_virt_cpufreq",
         "release_avf_support_custom_vm_with_paravirtualized_devices",
@@ -204,7 +204,7 @@
                 },
             },
         },
-        release_avf_enable_remote_attestation: {
+        avf_remote_attestation_enabled: {
             vintf_fragments: [
                 "virtualizationservice.xml",
             ],
@@ -235,7 +235,7 @@
     config_namespace: "ANDROID",
     bool_variables: [
         "release_avf_enable_llpvm_changes",
-        "release_avf_enable_remote_attestation",
+        "avf_remote_attestation_enabled",
     ],
     properties: ["srcs"],
 }
@@ -247,7 +247,7 @@
         release_avf_enable_llpvm_changes: {
             srcs: ["virtualizationservice.rc.llpvm"],
         },
-        release_avf_enable_remote_attestation: {
+        avf_remote_attestation_enabled: {
             srcs: ["virtualizationservice.rc.ra"],
         },
     },
diff --git a/compos/apk/assets/vm_config.json b/compos/apk/assets/vm_config.json
index 1f5cdba..28e0f07 100644
--- a/compos/apk/assets/vm_config.json
+++ b/compos/apk/assets/vm_config.json
@@ -27,5 +27,6 @@
     }
   ],
   "export_tombstones": true,
-  "enable_authfs": true
+  "enable_authfs": true,
+  "hugepages": true
 }
diff --git a/compos/apk/assets/vm_config_staged.json b/compos/apk/assets/vm_config_staged.json
index 37b1d7a..afc3767 100644
--- a/compos/apk/assets/vm_config_staged.json
+++ b/compos/apk/assets/vm_config_staged.json
@@ -28,5 +28,6 @@
     }
   ],
   "export_tombstones": true,
-  "enable_authfs": true
+  "enable_authfs": true,
+  "hugepages": true
 }
diff --git a/compos/apk/assets/vm_config_system_ext.json b/compos/apk/assets/vm_config_system_ext.json
index 1ef43f0..730f592 100644
--- a/compos/apk/assets/vm_config_system_ext.json
+++ b/compos/apk/assets/vm_config_system_ext.json
@@ -30,5 +30,6 @@
     }
   ],
   "export_tombstones": true,
-  "enable_authfs": true
+  "enable_authfs": true,
+  "hugepages": true
 }
diff --git a/compos/apk/assets/vm_config_system_ext_staged.json b/compos/apk/assets/vm_config_system_ext_staged.json
index 9103a9e..6d91aa2 100644
--- a/compos/apk/assets/vm_config_system_ext_staged.json
+++ b/compos/apk/assets/vm_config_system_ext_staged.json
@@ -31,5 +31,6 @@
     }
   ],
   "export_tombstones": true,
-  "enable_authfs": true
+  "enable_authfs": true,
+  "hugepages": true
 }
diff --git a/docs/vm_remote_attestation.md b/docs/vm_remote_attestation.md
index 835dcac..3483351 100644
--- a/docs/vm_remote_attestation.md
+++ b/docs/vm_remote_attestation.md
@@ -106,3 +106,18 @@
     normal mode.
 -   The `vmComponents` field contains a list of all the APKs and apexes loaded
     by the pVM.
+
+## To Support It
+
+VM remote attestation is a strongly recommended feature starting with Android V.
+To support it, you only need to provide a valid VM DICE chain that satisfies the
+following requirements:
+
+- The DICE chain must have a UDS-rooted public key registered at the RKP factory.
+- The DICE chain should have RKP VM markers that help identify the RKP VM, as
+  required by the [remote provisioning HAL][rkp-hal-markers].
+
+The feature is enabled by default. To disable it, you can set
+`PRODUCT_AVF_REMOTE_ATTESTATION_DISABLED` to true in your Makefile.
+
+[rkp-hal-markers]: https://android.googlesource.com/platform/hardware/interfaces/+/main/security/rkp/README.md#hal
diff --git a/pvmfw/Android.bp b/pvmfw/Android.bp
index 37a321d..769a955 100644
--- a/pvmfw/Android.bp
+++ b/pvmfw/Android.bp
@@ -14,6 +14,7 @@
         "libaarch64_paging",
         "libbssl_avf_nostd",
         "libbssl_sys_nostd",
+        "libcbor_util_nostd",
         "libciborium_nostd",
         "libciborium_io_nostd",
         "libcstr",
diff --git a/pvmfw/src/bcc.rs b/pvmfw/src/bcc.rs
index f56e62b..7a13da7 100644
--- a/pvmfw/src/bcc.rs
+++ b/pvmfw/src/bcc.rs
@@ -27,10 +27,9 @@
 type Result<T> = core::result::Result<T, BccError>;
 
 pub enum BccError {
-    CborDecodeError(ciborium::de::Error<ciborium_io::EndOfFile>),
-    CborEncodeError(ciborium::ser::Error<core::convert::Infallible>),
+    CborDecodeError,
+    CborEncodeError,
     DiceError(diced_open_dice::DiceError),
-    ExtraneousBytes,
     MalformedBcc(&'static str),
     MissingBcc,
 }
@@ -38,10 +37,9 @@
 impl fmt::Display for BccError {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match self {
-            Self::CborDecodeError(e) => write!(f, "Error parsing BCC CBOR: {e:?}"),
-            Self::CborEncodeError(e) => write!(f, "Error encoding BCC CBOR: {e:?}"),
+            Self::CborDecodeError => write!(f, "Error parsing BCC CBOR"),
+            Self::CborEncodeError => write!(f, "Error encoding BCC CBOR"),
             Self::DiceError(e) => write!(f, "Dice error: {e:?}"),
-            Self::ExtraneousBytes => write!(f, "Unexpected trailing data in BCC"),
             Self::MalformedBcc(s) => {
                 write!(f, "BCC does not have the expected CBOR structure: {s}")
             }
@@ -65,7 +63,7 @@
     // }
     let bcc_handover: Vec<(Value, Value)> =
         vec![(1.into(), cdi_attest.as_slice().into()), (2.into(), cdi_seal.as_slice().into())];
-    value_to_bytes(&bcc_handover.into())
+    cbor_util::serialize(&bcc_handover).map_err(|_| BccError::CborEncodeError)
 }
 
 fn taint_cdi(cdi: &Cdi, info: &str) -> Result<Cdi> {
@@ -100,7 +98,8 @@
         // We don't attempt to fully validate the BCC (e.g. we don't check the signatures) - we
         // have to trust our loader. But if it's invalid CBOR or otherwise clearly ill-formed,
         // something is very wrong, so we fail.
-        let bcc_cbor = value_from_bytes(received_bcc)?;
+        let bcc_cbor =
+            cbor_util::deserialize(received_bcc).map_err(|_| BccError::CborDecodeError)?;
 
         // Bcc = [
         //   PubKeyEd25519 / PubKeyECDSA256, // DK_pub
@@ -159,7 +158,7 @@
         // ]
         let payload =
             self.payload_bytes().ok_or(BccError::MalformedBcc("Invalid payload in BccEntry"))?;
-        let payload = value_from_bytes(payload)?;
+        let payload = cbor_util::deserialize(payload).map_err(|_| BccError::CborDecodeError)?;
         trace!("Bcc payload: {payload:?}");
         Ok(BccPayload(payload))
     }
@@ -215,21 +214,3 @@
         None
     }
 }
-
-/// Decodes the provided binary CBOR-encoded value and returns a
-/// ciborium::Value struct wrapped in Result.
-fn value_from_bytes(mut bytes: &[u8]) -> Result<Value> {
-    let value = ciborium::de::from_reader(&mut bytes).map_err(BccError::CborDecodeError)?;
-    // Ciborium tries to read one Value, but doesn't care if there is trailing data after it. We do.
-    if !bytes.is_empty() {
-        return Err(BccError::ExtraneousBytes);
-    }
-    Ok(value)
-}
-
-/// Encodes a ciborium::Value into bytes.
-fn value_to_bytes(value: &Value) -> Result<Vec<u8>> {
-    let mut bytes: Vec<u8> = Vec::new();
-    ciborium::ser::into_writer(&value, &mut bytes).map_err(BccError::CborEncodeError)?;
-    Ok(bytes)
-}
diff --git a/pvmfw/src/dice.rs b/pvmfw/src/dice.rs
index 67865e5..9283b80 100644
--- a/pvmfw/src/dice.rs
+++ b/pvmfw/src/dice.rs
@@ -13,16 +13,48 @@
 // limitations under the License.
 
 //! Support for DICE derivation and BCC generation.
+extern crate alloc;
 
+use alloc::format;
+use alloc::vec::Vec;
+use ciborium::cbor;
+use ciborium::Value;
 use core::mem::size_of;
-use cstr::cstr;
 use diced_open_dice::{
-    bcc_format_config_descriptor, bcc_handover_main_flow, hash, Config, DiceConfigValues, DiceMode,
-    Hash, InputValues, HIDDEN_SIZE,
+    bcc_handover_main_flow, hash, Config, DiceMode, Hash, InputValues, HIDDEN_SIZE,
 };
 use pvmfw_avb::{Capability, DebugLevel, Digest, VerifiedBootData};
 use zerocopy::AsBytes;
 
+const COMPONENT_NAME_KEY: i64 = -70002;
+const SECURITY_VERSION_KEY: i64 = -70005;
+const RKP_VM_MARKER_KEY: i64 = -70006;
+// TODO(b/291245237): Document this key along with others used in ConfigDescriptor in AVF based VM.
+const INSTANCE_HASH_KEY: i64 = -71003;
+
+#[derive(Debug)]
+pub enum Error {
+    /// Error in CBOR operations
+    CborError(ciborium::value::Error),
+    /// Error in DICE operations
+    DiceError(diced_open_dice::DiceError),
+}
+
+impl From<ciborium::value::Error> for Error {
+    fn from(e: ciborium::value::Error) -> Self {
+        Self::CborError(e)
+    }
+}
+
+impl From<diced_open_dice::DiceError> for Error {
+    fn from(e: diced_open_dice::DiceError) -> Self {
+        Self::DiceError(e)
+    }
+}
+
+// DICE in pvmfw result type.
+type Result<T> = core::result::Result<T, Error>;
+
 fn to_dice_mode(debug_level: DebugLevel) -> DiceMode {
     match debug_level {
         DebugLevel::None => DiceMode::kDiceModeNormal,
@@ -30,13 +62,13 @@
     }
 }
 
-fn to_dice_hash(verified_boot_data: &VerifiedBootData) -> diced_open_dice::Result<Hash> {
+fn to_dice_hash(verified_boot_data: &VerifiedBootData) -> Result<Hash> {
     let mut digests = [0u8; size_of::<Digest>() * 2];
     digests[..size_of::<Digest>()].copy_from_slice(&verified_boot_data.kernel_digest);
     if let Some(initrd_digest) = verified_boot_data.initrd_digest {
         digests[size_of::<Digest>()..].copy_from_slice(&initrd_digest);
     }
-    hash(&digests)
+    Ok(hash(&digests)?)
 }
 
 pub struct PartialInputs {
@@ -48,7 +80,7 @@
 }
 
 impl PartialInputs {
-    pub fn new(data: &VerifiedBootData) -> diced_open_dice::Result<Self> {
+    pub fn new(data: &VerifiedBootData) -> Result<Self> {
         let code_hash = to_dice_hash(data)?;
         let auth_hash = hash(data.public_key)?;
         let mode = to_dice_mode(data.debug_level);
@@ -63,14 +95,16 @@
         self,
         current_bcc_handover: &[u8],
         salt: &[u8; HIDDEN_SIZE],
+        instance_hash: Option<Hash>,
         next_bcc: &mut [u8],
-    ) -> diced_open_dice::Result<()> {
-        let mut config_descriptor_buffer = [0; 128];
-        let config = self.generate_config_descriptor(&mut config_descriptor_buffer)?;
+    ) -> Result<()> {
+        let config = self
+            .generate_config_descriptor(instance_hash)
+            .map_err(|_| diced_open_dice::DiceError::InvalidInput)?;
 
         let dice_inputs = InputValues::new(
             self.code_hash,
-            Config::Descriptor(config),
+            Config::Descriptor(&config),
             self.auth_hash,
             self.mode,
             self.make_hidden(salt)?,
@@ -79,7 +113,7 @@
         Ok(())
     }
 
-    fn make_hidden(&self, salt: &[u8; HIDDEN_SIZE]) -> diced_open_dice::Result<[u8; HIDDEN_SIZE]> {
+    fn make_hidden(&self, salt: &[u8; HIDDEN_SIZE]) -> Result<[u8; HIDDEN_SIZE]> {
         // We want to make sure we get a different sealing CDI for:
         // - VMs with different salt values
         // - An RKP VM and any other VM (regardless of salt)
@@ -95,23 +129,25 @@
         }
         // TODO(b/291213394): Include `defer_rollback_protection` flag in the Hidden Input to
         // differentiate the secrets in both cases.
-        hash(HiddenInput { rkp_vm_marker: self.rkp_vm_marker, salt: *salt }.as_bytes())
+        Ok(hash(HiddenInput { rkp_vm_marker: self.rkp_vm_marker, salt: *salt }.as_bytes())?)
     }
 
-    fn generate_config_descriptor<'a>(
-        &self,
-        config_descriptor_buffer: &'a mut [u8],
-    ) -> diced_open_dice::Result<&'a [u8]> {
-        let config_values = DiceConfigValues {
-            component_name: Some(cstr!("vm_entry")),
-            security_version: if cfg!(dice_changes) { Some(self.security_version) } else { None },
-            rkp_vm_marker: self.rkp_vm_marker,
-            ..Default::default()
-        };
-        let config_descriptor_size =
-            bcc_format_config_descriptor(&config_values, config_descriptor_buffer)?;
-        let config = &config_descriptor_buffer[..config_descriptor_size];
-        Ok(config)
+    fn generate_config_descriptor(&self, instance_hash: Option<Hash>) -> Result<Vec<u8>> {
+        let mut config = Vec::with_capacity(4);
+        config.push((cbor!(COMPONENT_NAME_KEY)?, cbor!("vm_entry")?));
+        if cfg!(dice_changes) {
+            config.push((cbor!(SECURITY_VERSION_KEY)?, cbor!(self.security_version)?));
+        }
+        if self.rkp_vm_marker {
+            config.push((cbor!(RKP_VM_MARKER_KEY)?, Value::Null))
+        }
+        if let Some(instance_hash) = instance_hash {
+            config.push((cbor!(INSTANCE_HASH_KEY)?, Value::from(instance_hash.as_slice())));
+        }
+        let config = Value::Map(config);
+        Ok(cbor_util::serialize(&config).map_err(|e| {
+            ciborium::value::Error::Custom(format!("Error in serialization: {e:?}"))
+        })?)
     }
 }
 
@@ -145,12 +181,8 @@
     use std::collections::HashMap;
     use std::vec;
 
-    const COMPONENT_NAME_KEY: i64 = -70002;
     const COMPONENT_VERSION_KEY: i64 = -70003;
     const RESETTABLE_KEY: i64 = -70004;
-    const SECURITY_VERSION_KEY: i64 = -70005;
-    const RKP_VM_MARKER_KEY: i64 = -70006;
-
     const BASE_VB_DATA: VerifiedBootData = VerifiedBootData {
         debug_level: DebugLevel::None,
         kernel_digest: [1u8; size_of::<Digest>()],
@@ -159,6 +191,7 @@
         capabilities: vec![],
         rollback_index: 42,
     };
+    const HASH: Hash = *b"sixtyfourbyteslongsentencearerarebutletsgiveitatrycantbethathard";
 
     #[test]
     fn base_data_conversion() {
@@ -193,7 +226,7 @@
     fn base_config_descriptor() {
         let vb_data = BASE_VB_DATA;
         let inputs = PartialInputs::new(&vb_data).unwrap();
-        let config_map = decode_config_descriptor(&inputs);
+        let config_map = decode_config_descriptor(&inputs, None);
 
         assert_eq!(config_map.get(&COMPONENT_NAME_KEY).unwrap().as_text().unwrap(), "vm_entry");
         assert_eq!(config_map.get(&COMPONENT_VERSION_KEY), None);
@@ -214,17 +247,37 @@
         let vb_data =
             VerifiedBootData { capabilities: vec![Capability::RemoteAttest], ..BASE_VB_DATA };
         let inputs = PartialInputs::new(&vb_data).unwrap();
-        let config_map = decode_config_descriptor(&inputs);
+        let config_map = decode_config_descriptor(&inputs, Some(HASH));
 
         assert!(config_map.get(&RKP_VM_MARKER_KEY).unwrap().is_null());
     }
 
-    fn decode_config_descriptor(inputs: &PartialInputs) -> HashMap<i64, Value> {
-        let mut buffer = [0; 128];
-        let config_descriptor = inputs.generate_config_descriptor(&mut buffer).unwrap();
+    #[test]
+    fn config_descriptor_with_instance_hash() {
+        let vb_data =
+            VerifiedBootData { capabilities: vec![Capability::RemoteAttest], ..BASE_VB_DATA };
+        let inputs = PartialInputs::new(&vb_data).unwrap();
+        let config_map = decode_config_descriptor(&inputs, Some(HASH));
+        assert_eq!(*config_map.get(&INSTANCE_HASH_KEY).unwrap(), Value::from(HASH.as_slice()));
+    }
+
+    #[test]
+    fn config_descriptor_without_instance_hash() {
+        let vb_data =
+            VerifiedBootData { capabilities: vec![Capability::RemoteAttest], ..BASE_VB_DATA };
+        let inputs = PartialInputs::new(&vb_data).unwrap();
+        let config_map = decode_config_descriptor(&inputs, None);
+        assert!(config_map.get(&INSTANCE_HASH_KEY).is_none());
+    }
+
+    fn decode_config_descriptor(
+        inputs: &PartialInputs,
+        instance_hash: Option<Hash>,
+    ) -> HashMap<i64, Value> {
+        let config_descriptor = inputs.generate_config_descriptor(instance_hash).unwrap();
 
         let cbor_map =
-            cbor_util::deserialize::<Value>(config_descriptor).unwrap().into_map().unwrap();
+            cbor_util::deserialize::<Value>(&config_descriptor).unwrap().into_map().unwrap();
 
         cbor_map
             .into_iter()
diff --git a/pvmfw/src/main.rs b/pvmfw/src/main.rs
index 2af19c4..5893907 100644
--- a/pvmfw/src/main.rs
+++ b/pvmfw/src/main.rs
@@ -143,8 +143,8 @@
         RebootReason::InternalError
     })?;
 
-    let (new_instance, salt) = if cfg!(llpvm_changes)
-        && should_defer_rollback_protection(fdt)?
+    let instance_hash = if cfg!(llpvm_changes) { Some(salt_from_instance_id(fdt)?) } else { None };
+    let (new_instance, salt) = if should_defer_rollback_protection(fdt)?
         && verified_boot_data.has_capability(Capability::SecretkeeperProtection)
     {
         info!("Guest OS is capable of Secretkeeper protection, deferring rollback protection");
@@ -155,7 +155,7 @@
             return Err(RebootReason::InvalidPayload);
         };
         // `new_instance` cannot be known to pvmfw
-        (false, salt_from_instance_id(fdt)?)
+        (false, instance_hash.ok_or(RebootReason::InternalError)?) // None without llpvm_changes
     } else {
         let (recorded_entry, mut instance_img, header_index) =
             get_recorded_entry(&mut pci_root, cdi_seal).map_err(|e| {
@@ -164,18 +164,15 @@
             })?;
         let (new_instance, salt) = if let Some(entry) = recorded_entry {
             maybe_check_dice_measurements_match_entry(&dice_inputs, &entry)?;
-            let salt = if cfg!(llpvm_changes) { salt_from_instance_id(fdt)? } else { entry.salt };
+            let salt = instance_hash.unwrap_or(entry.salt);
             (false, salt)
         } else {
             // New instance!
-            let salt = if cfg!(llpvm_changes) {
-                salt_from_instance_id(fdt)?
-            } else {
-                rand::random_array().map_err(|e| {
-                    error!("Failed to generated instance.img salt: {e}");
-                    RebootReason::InternalError
-                })?
-            };
+            let salt = instance_hash.map_or_else(rand::random_array, Ok).map_err(|e| {
+                error!("Failed to generated instance.img salt: {e}");
+                RebootReason::InternalError
+            })?;
+
             let entry = EntryBody::new(&dice_inputs, &salt);
             record_instance_entry(&entry, cdi_seal, &mut instance_img, header_index).map_err(
                 |e| {
@@ -204,10 +201,12 @@
         Cow::Owned(truncated_bcc_handover)
     };
 
-    dice_inputs.write_next_bcc(new_bcc_handover.as_ref(), &salt, next_bcc).map_err(|e| {
-        error!("Failed to derive next-stage DICE secrets: {e:?}");
-        RebootReason::SecretDerivationError
-    })?;
+    dice_inputs.write_next_bcc(new_bcc_handover.as_ref(), &salt, instance_hash, next_bcc).map_err(
+        |e| {
+            error!("Failed to derive next-stage DICE secrets: {e:?}");
+            RebootReason::SecretDerivationError
+        },
+    )?;
     flush(next_bcc);
 
     let kaslr_seed = u64::from_ne_bytes(rand::random_array().map_err(|e| {
diff --git a/tests/hostside/helper/java/com/android/microdroid/test/host/MicrodroidHostTestCaseBase.java b/tests/hostside/helper/java/com/android/microdroid/test/host/MicrodroidHostTestCaseBase.java
index 41ddd48..c6b2499 100644
--- a/tests/hostside/helper/java/com/android/microdroid/test/host/MicrodroidHostTestCaseBase.java
+++ b/tests/hostside/helper/java/com/android/microdroid/test/host/MicrodroidHostTestCaseBase.java
@@ -55,6 +55,7 @@
     protected static final String LOG_PATH = TEST_ROOT + "log.txt";
     protected static final String CONSOLE_PATH = TEST_ROOT + "console.txt";
     protected static final String TRADEFED_CONSOLE_PATH = TRADEFED_TEST_ROOT + "console.txt";
+    protected static final String TRADEFED_LOG_PATH = TRADEFED_TEST_ROOT + "log.txt";
     private static final int TEST_VM_ADB_PORT = 8000;
     private static final String MICRODROID_SERIAL = "localhost:" + TEST_VM_ADB_PORT;
     private static final String INSTANCE_IMG = "instance.img";
diff --git a/tests/hostside/java/com/android/microdroid/test/MicrodroidHostTests.java b/tests/hostside/java/com/android/microdroid/test/MicrodroidHostTests.java
index f424ce0..eb456f2 100644
--- a/tests/hostside/java/com/android/microdroid/test/MicrodroidHostTests.java
+++ b/tests/hostside/java/com/android/microdroid/test/MicrodroidHostTests.java
@@ -871,10 +871,13 @@
         assertWithMessage("Incorrect ABI list").that(abis).hasLength(1);
 
         // Check that no denials have happened so far
-        String logText =
-                getDevice().pullFileContents(CONSOLE_PATH) + getDevice().pullFileContents(LOG_PATH);
+        String consoleText = getDevice().pullFileContents(TRADEFED_CONSOLE_PATH);
+        assertWithMessage("Console output shouldn't be empty").that(consoleText).isNotEmpty();
+        String logText = getDevice().pullFileContents(TRADEFED_LOG_PATH);
+        assertWithMessage("Log output shouldn't be empty").that(logText).isNotEmpty();
+
         assertWithMessage("Unexpected denials during VM boot")
-                .that(logText)
+                .that(consoleText + logText)
                 .doesNotContainMatch("avc:\\s+denied");
 
         assertThat(getDeviceNumCpus(microdroid)).isEqualTo(getDeviceNumCpus(android));
@@ -1171,6 +1174,50 @@
         }
     }
 
+    /**
+     * Boots a VM with huge pages requested after switching the host's shmem THP setting from
+     * "never" to "advise", and restores the host state afterwards even if the boot fails.
+     */
+    @Test
+    public void testHugePages() throws Exception {
+        ITestDevice device = getDevice();
+        boolean disableRoot = !device.isAdbRoot();
+        CommandRunner android = new CommandRunner(device);
+
+        final String SHMEM_ENABLED_PATH = "/sys/kernel/mm/transparent_hugepage/shmem_enabled";
+        String thpShmemStr = android.run("cat", SHMEM_ENABLED_PATH);
+
+        assumeFalse("shmem already enabled, skip", thpShmemStr.contains("[advise]"));
+        assumeTrue("Unsupported shmem, skip", thpShmemStr.contains("[never]"));
+
+        device.enableAdbRoot();
+        assumeTrue("adb root is not enabled", device.isAdbRoot());
+        try {
+            android.run("echo advise > " + SHMEM_ENABLED_PATH);
+
+            final String configPath = "assets/vm_config.json";
+            mMicrodroidDevice =
+                    MicrodroidBuilder.fromDevicePath(getPathForPackage(PACKAGE_NAME), configPath)
+                            .debugLevel("full")
+                            .memoryMib(minMemorySize())
+                            .cpuTopology("match_host")
+                            .protectedVm(mProtectedVm)
+                            .gki(mGki)
+                            .hugePages(true)
+                            .build(getAndroidDevice());
+            mMicrodroidDevice.waitForBootComplete(BOOT_COMPLETE_TIMEOUT);
+        } finally {
+            // Always undo the host-wide changes, even when the VM fails to build or boot.
+            try {
+                android.run("echo never >" + SHMEM_ENABLED_PATH);
+            } finally {
+                if (disableRoot) {
+                    device.disableAdbRoot();
+                }
+            }
+        }
+    }
+
     @Before
     public void setUp() throws Exception {
         assumeDeviceIsCapable(getDevice());
diff --git a/tests/pvmfw/helper/Android.bp b/tests/pvmfw/helper/Android.bp
index 90ca03e..a75f034 100644
--- a/tests/pvmfw/helper/Android.bp
+++ b/tests/pvmfw/helper/Android.bp
@@ -5,7 +5,7 @@
 java_library_host {
     name: "PvmfwHostTestHelper",
     srcs: ["java/**/*.java"],
-    libs: [
+    static_libs: [
         "androidx.annotation_annotation",
         "truth",
     ],
diff --git a/tests/pvmfw/tools/PvmfwTool.java b/tests/pvmfw/tools/PvmfwTool.java
index e150ec4..9f0cb42 100644
--- a/tests/pvmfw/tools/PvmfwTool.java
+++ b/tests/pvmfw/tools/PvmfwTool.java
@@ -25,10 +25,10 @@
 public class PvmfwTool {
     public static void printUsage() {
         System.out.println("pvmfw-tool: Appends pvmfw.bin and config payloads.");
-        System.out.println("            Requires BCC and VM reference DT.");
-        System.out.println("            VM DTBO and Debug policy can optionally be specified");
+        System.out.println("            Requires BCC. VM Reference DT, VM DTBO, and Debug policy");
+        System.out.println("            can optionally be specified");
         System.out.println(
-                "Usage: pvmfw-tool <out> <pvmfw.bin> <bcc.dat> <VM reference DT> [VM DTBO] [debug"
+                "Usage: pvmfw-tool <out> <pvmfw.bin> <bcc.dat> [VM reference DT] [VM DTBO] [debug"
                         + " policy]");
     }
 
@@ -41,10 +41,13 @@
         File out = new File(args[0]);
         File pvmfwBin = new File(args[1]);
         File bccData = new File(args[2]);
-        File vmReferenceDt = new File(args[3]);
 
+        File vmReferenceDt = null;
         File vmDtbo = null;
         File dp = null;
+        if (args.length > 3) {
+            vmReferenceDt = new File(args[3]);
+        }
         if (args.length > 4) {
             vmDtbo = new File(args[4]);
         }
@@ -53,12 +56,18 @@
         }
 
         try {
-            Pvmfw pvmfw =
+            Pvmfw.Builder builder =
                     new Pvmfw.Builder(pvmfwBin, bccData)
                             .setVmReferenceDt(vmReferenceDt)
                             .setDebugPolicyOverlay(dp)
-                            .setVmDtbo(vmDtbo)
-                            .build();
+                            .setVmDtbo(vmDtbo);
+            if (vmReferenceDt == null) {
+                builder.setVersion(1, 1);
+            } else {
+                builder.setVersion(1, 2);
+            }
+
+            Pvmfw pvmfw = builder.build();
             pvmfw.serialize(out);
         } catch (IOException e) {
             e.printStackTrace();