Merge "Don't use non-protected VMs for CompOS" into main
diff --git a/pvmfw/Android.bp b/pvmfw/Android.bp
index 769a955..144e81e 100644
--- a/pvmfw/Android.bp
+++ b/pvmfw/Android.bp
@@ -117,9 +117,10 @@
     rustlibs: [
         "libcbor_util",
         "libciborium",
-        "libdiced_open_dice_nostd",
+        "libdiced_open_dice",
         "libpvmfw_avb_nostd",
         "libzerocopy_nostd",
+        "libhex",
     ],
 }
 
@@ -320,15 +321,22 @@
     installable: false,
 }
 
-prebuilt_etc {
+filegroup {
     name: "pvmfw_embedded_key",
-    src: ":avb_testkey_rsa4096_pub_bin",
-    installable: false,
+    srcs: [":avb_testkey_rsa4096"],
+}
+
+genrule {
+    name: "pvmfw_embedded_key_pub_bin",
+    tools: ["avbtool"],
+    srcs: [":pvmfw_embedded_key"],
+    out: ["pvmfw_embedded_key_pub.bin"],
+    cmd: "$(location avbtool) extract_public_key --key $(in) --output $(out)",
 }
 
 genrule {
     name: "pvmfw_embedded_key_rs",
-    srcs: [":pvmfw_embedded_key"],
+    srcs: [":pvmfw_embedded_key_pub_bin"],
     out: ["lib.rs"],
     cmd: "(" +
         "    echo '#![no_std]';" +
diff --git a/pvmfw/src/dice.rs b/pvmfw/src/dice.rs
index aaf2691..da19931 100644
--- a/pvmfw/src/dice.rs
+++ b/pvmfw/src/dice.rs
@@ -71,6 +71,7 @@
     Ok(hash(&digests)?)
 }
 
+#[derive(Clone)]
 pub struct PartialInputs {
     pub code_hash: Hash,
     pub auth_hash: Hash,
@@ -96,6 +97,7 @@
         current_bcc_handover: &[u8],
         salt: &[u8; HIDDEN_SIZE],
         instance_hash: Option<Hash>,
+        deferred_rollback_protection: bool,
         next_bcc: &mut [u8],
     ) -> Result<()> {
         let config = self
@@ -107,16 +109,23 @@
             Config::Descriptor(&config),
             self.auth_hash,
             self.mode,
-            self.make_hidden(salt)?,
+            self.make_hidden(salt, deferred_rollback_protection)?,
         );
         let _ = bcc_handover_main_flow(current_bcc_handover, &dice_inputs, next_bcc)?;
         Ok(())
     }
 
-    fn make_hidden(&self, salt: &[u8; HIDDEN_SIZE]) -> Result<[u8; HIDDEN_SIZE]> {
+    fn make_hidden(
+        &self,
+        salt: &[u8; HIDDEN_SIZE],
+        deferred_rollback_protection: bool,
+    ) -> diced_open_dice::Result<[u8; HIDDEN_SIZE]> {
         // We want to make sure we get a different sealing CDI for:
         // - VMs with different salt values
         // - An RKP VM and any other VM (regardless of salt)
+        // - depending on whether rollback protection has been deferred to payload. This ensures the
+        //   adversary cannot leak the secrets by using old images & setting
+        //   `deferred_rollback_protection` to true.
         // The hidden input for DICE affects the sealing CDI (but the values in the config
         // descriptor do not).
         // Since the hidden input has to be a fixed size, create it as a hash of the values we
@@ -126,10 +135,16 @@
         struct HiddenInput {
             rkp_vm_marker: bool,
             salt: [u8; HIDDEN_SIZE],
+            deferred_rollback_protection: bool,
         }
-        // TODO(b/291213394): Include `defer_rollback_protection` flag in the Hidden Input to
-        // differentiate the secrets in both cases.
-        Ok(hash(HiddenInput { rkp_vm_marker: self.rkp_vm_marker, salt: *salt }.as_bytes())?)
+        hash(
+            HiddenInput {
+                rkp_vm_marker: self.rkp_vm_marker,
+                salt: *salt,
+                deferred_rollback_protection,
+            }
+            .as_bytes(),
+        )
     }
 
     fn generate_config_descriptor(&self, instance_hash: Option<Hash>) -> Result<Vec<u8>> {
@@ -176,9 +191,20 @@
 
 #[cfg(test)]
 mod tests {
-    use super::*;
+    use crate::{
+        Hash, PartialInputs, COMPONENT_NAME_KEY, INSTANCE_HASH_KEY, RKP_VM_MARKER_KEY,
+        SECURITY_VERSION_KEY,
+    };
     use ciborium::Value;
+    use diced_open_dice::DiceArtifacts;
+    use diced_open_dice::DiceMode;
+    use diced_open_dice::HIDDEN_SIZE;
+    use pvmfw_avb::Capability;
+    use pvmfw_avb::DebugLevel;
+    use pvmfw_avb::Digest;
+    use pvmfw_avb::VerifiedBootData;
     use std::collections::HashMap;
+    use std::mem::size_of;
     use std::vec;
 
     const COMPONENT_VERSION_KEY: i64 = -70003;
@@ -284,4 +310,59 @@
             .map(|(k, v)| ((k.into_integer().unwrap().try_into().unwrap()), v))
             .collect()
     }
+
+    /// Checks that `deferred_rollback_protection` is mixed into the DICE hidden input: flipping
+    /// the flag must change the sealing CDI, while re-deriving with the same flag value must
+    /// reproduce it.
+    #[test]
+    fn changing_deferred_rbp_changes_secrets() {
+        let vb_data = VerifiedBootData { debug_level: DebugLevel::Full, ..BASE_VB_DATA };
+        let inputs = PartialInputs::new(&vb_data).unwrap();
+
+        let sample_dice_input: &[u8] = &[
+            0xa3, // CBOR map with 3 entries.
+            // 1 => CDI attest (32 zero bytes).
+            0x01, 0x58, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+            // 2 => CDI seal (32 zero bytes).
+            0x02, 0x58, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+            // 3 => DICE chain.
+            0x03, 0x82, 0xa6, 0x01, 0x02, 0x03, 0x27, 0x04, 0x02, 0x20, 0x01, 0x21, 0x40, 0x22,
+            0x40, 0x84, 0x40, 0xa0, 0x40, 0x40,
+            // 8 bytes of trailing data that aren't part of the DICE chain.
+            0x84, 0x41, 0x55, 0xa0, 0x42, 0x11, 0x22, 0x40,
+        ];
+
+        // Derives a next-stage BCC handover with a fixed salt and instance hash, so that
+        // `deferred_rollback_protection` is the only input that differs between calls.
+        let derive_next_bcc = |deferred_rollback_protection: bool| {
+            let mut next_bcc = [0; 4096];
+            inputs
+                .clone()
+                .write_next_bcc(
+                    sample_dice_input,
+                    &[0u8; HIDDEN_SIZE],
+                    Some([0u8; 64]),
+                    deferred_rollback_protection,
+                    &mut next_bcc,
+                )
+                .unwrap();
+            next_bcc
+        };
+
+        let buffer_without_defer = derive_next_bcc(false);
+        let buffer_with_defer = derive_next_bcc(true);
+        let buffer_without_defer_retry = derive_next_bcc(false);
+
+        let bcc_handover1 = diced_open_dice::bcc_handover_parse(&buffer_without_defer).unwrap();
+        let bcc_handover2 = diced_open_dice::bcc_handover_parse(&buffer_with_defer).unwrap();
+        let bcc_handover3 =
+            diced_open_dice::bcc_handover_parse(&buffer_without_defer_retry).unwrap();
+
+        assert_ne!(bcc_handover1.cdi_seal(), bcc_handover2.cdi_seal());
+        assert_eq!(bcc_handover1.cdi_seal(), bcc_handover3.cdi_seal());
+    }
 }
diff --git a/pvmfw/src/main.rs b/pvmfw/src/main.rs
index 5893907..299d1c0 100644
--- a/pvmfw/src/main.rs
+++ b/pvmfw/src/main.rs
@@ -144,9 +144,9 @@
     })?;
 
     let instance_hash = if cfg!(llpvm_changes) { Some(salt_from_instance_id(fdt)?) } else { None };
-    let (new_instance, salt) = if should_defer_rollback_protection(fdt)?
-        && verified_boot_data.has_capability(Capability::SecretkeeperProtection)
-    {
+    let defer_rollback_protection = should_defer_rollback_protection(fdt)?
+        && verified_boot_data.has_capability(Capability::SecretkeeperProtection);
+    let (new_instance, salt) = if defer_rollback_protection {
         info!("Guest OS is capable of Secretkeeper protection, deferring rollback protection");
         // rollback_index of the image is used as security_version and is expected to be > 0 to
         // discourage implicit allocation.
@@ -201,12 +201,18 @@
         Cow::Owned(truncated_bcc_handover)
     };
 
-    dice_inputs.write_next_bcc(new_bcc_handover.as_ref(), &salt, instance_hash, next_bcc).map_err(
-        |e| {
+    dice_inputs
+        .write_next_bcc(
+            new_bcc_handover.as_ref(),
+            &salt,
+            instance_hash,
+            defer_rollback_protection,
+            next_bcc,
+        )
+        .map_err(|e| {
             error!("Failed to derive next-stage DICE secrets: {e:?}");
             RebootReason::SecretDerivationError
-        },
-    )?;
+        })?;
     flush(next_bcc);
 
     let kaslr_seed = u64::from_ne_bytes(rand::random_array().map_err(|e| {
diff --git a/tests/hostside/helper/java/com/android/microdroid/test/host/MicrodroidHostTestCaseBase.java b/tests/hostside/helper/java/com/android/microdroid/test/host/MicrodroidHostTestCaseBase.java
index c6b2499..46df011 100644
--- a/tests/hostside/helper/java/com/android/microdroid/test/host/MicrodroidHostTestCaseBase.java
+++ b/tests/hostside/helper/java/com/android/microdroid/test/host/MicrodroidHostTestCaseBase.java
@@ -282,4 +282,15 @@
                 .map(os -> os.replaceFirst("^microdroid_gki-", ""))
                 .collect(Collectors.toList());
     }
+
+    /**
+     * Returns whether the device runs the protected KVM (pKVM) hypervisor, i.e. whether
+     * {@code ro.boot.hypervisor.version} is {@code "kvm.arm-protected"}.
+     *
+     * <p>If the property is not set this returns {@code false} instead of throwing a
+     * {@link NullPointerException}.
+     */
+    protected boolean isPkvmHypervisor() throws DeviceNotAvailableException {
+        return "kvm.arm-protected".equals(getDevice().getProperty("ro.boot.hypervisor.version"));
+    }
 }
diff --git a/tests/hostside/java/com/android/microdroid/test/MicrodroidHostTests.java b/tests/hostside/java/com/android/microdroid/test/MicrodroidHostTests.java
index eb456f2..9d0b04b 100644
--- a/tests/hostside/java/com/android/microdroid/test/MicrodroidHostTests.java
+++ b/tests/hostside/java/com/android/microdroid/test/MicrodroidHostTests.java
@@ -809,8 +809,10 @@
         // Check VmCreationRequested atom
         AtomsProto.VmCreationRequested atomVmCreationRequested =
                 data.get(0).getAtom().getVmCreationRequested();
-        assertThat(atomVmCreationRequested.getHypervisor())
-                .isEqualTo(AtomsProto.VmCreationRequested.Hypervisor.PKVM);
+        if (isPkvmHypervisor()) {
+            assertThat(atomVmCreationRequested.getHypervisor())
+                    .isEqualTo(AtomsProto.VmCreationRequested.Hypervisor.PKVM);
+        }
         assertThat(atomVmCreationRequested.getIsProtected()).isEqualTo(mProtectedVm);
         assertThat(atomVmCreationRequested.getCreationSucceeded()).isTrue();
         assertThat(atomVmCreationRequested.getBinderExceptionCode()).isEqualTo(0);
@@ -832,7 +834,11 @@
         assertThat(atomVmExited.getDeathReason()).isEqualTo(AtomsProto.VmExited.DeathReason.KILLED);
         assertThat(atomVmExited.getExitSignal()).isEqualTo(9);
         // In CPU & memory related fields, check whether positive values are collected or not.
-        assertThat(atomVmExited.getGuestTimeMillis()).isGreaterThan(0);
+        if (isPkvmHypervisor()) {
+            // Guest Time may not be updated on other hypervisors.
+            // Checking only if the hypervisor is PKVM.
+            assertThat(atomVmExited.getGuestTimeMillis()).isGreaterThan(0);
+        }
         assertThat(atomVmExited.getRssVmKb()).isGreaterThan(0);
         assertThat(atomVmExited.getRssCrosvmKb()).isGreaterThan(0);
 
diff --git a/virtualizationmanager/src/aidl.rs b/virtualizationmanager/src/aidl.rs
index d173b34..dd17b46 100644
--- a/virtualizationmanager/src/aidl.rs
+++ b/virtualizationmanager/src/aidl.rs
@@ -624,6 +624,8 @@
         } else {
             None
         };
+        let virtio_snd_backend =
+            if cfg!(paravirtualized_devices) { Some(String::from("aaudio")) } else { None };
 
         // Actually start the VM.
         let crosvm_config = CrosvmConfig {
@@ -654,6 +656,7 @@
             input_device_options,
             hugepages: config.hugePages,
             tap,
+            virtio_snd_backend,
         };
         let instance = Arc::new(
             VmInstance::new(
diff --git a/virtualizationmanager/src/crosvm.rs b/virtualizationmanager/src/crosvm.rs
index f73a977..371a908 100644
--- a/virtualizationmanager/src/crosvm.rs
+++ b/virtualizationmanager/src/crosvm.rs
@@ -123,6 +123,7 @@
     pub input_device_options: Vec<InputDeviceOption>,
     pub hugepages: bool,
     pub tap: Option<File>,
+    pub virtio_snd_backend: Option<String>,
 }
 
 #[derive(Debug)]
@@ -1029,6 +1030,12 @@
     debug!("Preserving FDs {:?}", preserved_fds);
     command.preserved_fds(preserved_fds);
 
+    if cfg!(paravirtualized_devices) {
+        if let Some(virtio_snd_backend) = &config.virtio_snd_backend {
+            command.arg("--virtio-snd").arg(format!("backend={}", virtio_snd_backend));
+        }
+    }
+
     print_crosvm_args(&command);
 
     let result = SharedChild::spawn(&mut command)?;