Merge "Dice Policy: Add dice chain matching functionality" into main
diff --git a/secretkeeper/dice_policy/Android.bp b/secretkeeper/dice_policy/Android.bp
index a7ac5b9..4f1e8b6 100644
--- a/secretkeeper/dice_policy/Android.bp
+++ b/secretkeeper/dice_policy/Android.bp
@@ -13,7 +13,9 @@
         "libanyhow",
         "libciborium",
         "libcoset",
+        "libnum_traits",
     ],
+    proc_macros: ["libnum_derive"],
 }
 
 rust_library {
diff --git a/secretkeeper/dice_policy/src/lib.rs b/secretkeeper/dice_policy/src/lib.rs
index 327b8a4..2e91305 100644
--- a/secretkeeper/dice_policy/src/lib.rs
+++ b/secretkeeper/dice_policy/src/lib.rs
@@ -57,16 +57,20 @@
 //!
 //! value = bool / int / tstr / bstr
 
-use anyhow::{anyhow, bail, Context, Result};
+use anyhow::{anyhow, bail, ensure, Context, Result};
 use ciborium::Value;
 use coset::{AsCborValue, CoseSign1};
+use num_derive::FromPrimitive;
+use num_traits::FromPrimitive;
 use std::borrow::Cow;
+use std::iter::zip;
 
 const DICE_POLICY_VERSION: u64 = 1;
 
 /// Constraint Types supported in Dice policy.
+#[repr(u16)]
 #[non_exhaustive]
-#[derive(Clone, Copy, Debug, PartialEq)]
+#[derive(Clone, Copy, Debug, FromPrimitive, PartialEq)]
 pub enum ConstraintType {
     /// Enforce exact match criteria, indicating the policy should match
     /// if the dice chain has exact same specified values.
@@ -133,6 +137,7 @@
     ///    ];
     ///
     /// 2. For hypothetical (and highly simplified) dice chain:
+    ///
     ///    [ROT_KEY, [{1 : 'a', 2 : {200 : 5, 201 : 'b'}}]]
     ///    The following can be used
     ///    constraint_spec =[
@@ -140,13 +145,7 @@
     ///     ConstraintSpec(ConstraintType::GreaterOrEqual, vec![2, 200]),// matches any value >= 5
     ///    ];
     pub fn from_dice_chain(dice_chain: &[u8], constraint_spec: &[ConstraintSpec]) -> Result<Self> {
-        // TODO(b/298217847): Check if the given dice chain adheres to Explicit-key DiceCertChain
-        // format and if not, convert it before policy construction.
-        let dice_chain = value_from_bytes(dice_chain).context("Unable to decode top-level CBOR")?;
-        let dice_chain = match dice_chain {
-            Value::Array(array) if array.len() >= 2 => array,
-            _ => bail!("Expected an array of at least length 2, found: {:?}", dice_chain),
-        };
+        let dice_chain = deserialize_dice_chain(dice_chain)?;
         let mut constraints_list: Vec<NodeConstraints> = Vec::with_capacity(dice_chain.len());
         let mut it = dice_chain.into_iter();
 
@@ -167,6 +166,61 @@
             node_constraints_list: constraints_list.into_boxed_slice(),
         })
     }
+
+    /// Dice chain policy verifier - Compare the input dice chain against this Dice policy.
+    /// Returns `Ok(())` if the dice chain satisfies every constraint in this policy; returns
+    /// an error if decoding fails, the chain length differs, or any constraint is not met.
+    /// TODO(b/291238565) Create a separate error module for DicePolicy mismatches.
+    pub fn matches_dice_chain(&self, dice_chain: &[u8]) -> Result<()> {
+        let dice_chain = deserialize_dice_chain(dice_chain)?;
+        ensure!(
+            dice_chain.len() == self.node_constraints_list.len(),
+            "Dice chain size({}) does not match policy({})",
+            dice_chain.len(),
+            self.node_constraints_list.len()
+        );
+
+        for (n, (dice_node, node_constraints)) in
+            zip(dice_chain, self.node_constraints_list.iter()).enumerate()
+        {
+            // Entry 0 is checked as-is (the root key in the documented chain layout); every
+            // later entry is unwrapped to the payload carried inside its COSE structure.
+            let dice_node_payload = if n == 0 {
+                dice_node
+            } else {
+                cbor_value_from_cose_sign(dice_node)
+                    .with_context(|| format!("Unable to get Cose payload at: {}", n))?
+            };
+            check_constraints_on_node(node_constraints, &dice_node_payload)
+                .with_context(|| format!("Mismatch found at {}", n))?;
+        }
+        Ok(())
+    }
+}
+
+fn check_constraints_on_node(node_constraints: &NodeConstraints, dice_node: &Value) -> Result<()> {
+    // Walk the node's constraints in order and stop at the first one that is not met,
+    // handing its error back to the caller unchanged.
+    let mut constraints = node_constraints.0.iter();
+    constraints.try_for_each(|constraint| check_constraint_on_node(constraint, dice_node))
+}
+
+fn check_constraint_on_node(constraint: &Constraint, dice_node: &Value) -> Result<()> {
+    let Constraint(cons_type, path, value_in_constraint) = constraint;
+    let value_in_node = lookup_value_in_nested_map(dice_node, path)?;
+    match ConstraintType::from_u16(*cons_type).context("Unexpected Constraint type")? {
+        ConstraintType::ExactMatch => ensure!(value_in_node == *value_in_constraint),
+        ConstraintType::GreaterOrEqual => {
+            let value_in_node = value_in_node
+                .as_integer()
+                .context("Mismatch type: expected a cbor integer")?;
+            let value_min = value_in_constraint
+                .as_integer()
+                .context("Mismatch type: expected a cbor integer")?;
+            ensure!(value_in_node >= value_min);
+        }
+    };
+    Ok(())
 }
 
 // Take the payload of a dice node & construct the constraints on it.
@@ -231,6 +285,17 @@
         Some(payload) => Ok(value_from_bytes(&payload)?),
     }
 }
+fn deserialize_dice_chain(dice_chain_bytes: &[u8]) -> Result<Vec<Value>> {
+    // TODO(b/298217847): Check if the given dice chain adheres to Explicit-key DiceCertChain
+    // format and if not, convert it.
+    let top_level_value =
+        value_from_bytes(dice_chain_bytes).context("Unable to decode top-level CBOR")?;
+    // Only a top-level CBOR array with two or more entries is accepted as a chain.
+    match top_level_value {
+        Value::Array(nodes) if nodes.len() >= 2 => Ok(nodes),
+        other => bail!("Expected an array of at least length 2, found: {:?}", other),
+    }
+}
 
 /// Decodes the provided binary CBOR-encoded value and returns a
 /// ciborium::Value struct wrapped in Result.
@@ -266,38 +331,29 @@
         constraint_spec: Vec<ConstraintSpec>,
         // The expected dice policy if above constraint_spec is applied to input_dice.
         expected_dice_policy: DicePolicy,
+        // Another dice chain, which is almost the same as input_dice, but (roughly) imitates
+        // an 'updated' one, i.e., some int entries are higher than the corresponding entries
+        // in input_dice.
+        updated_input_dice: Vec<u8>,
     }
 
     impl TestArtifacts {
         // Get an example instance of TestArtifacts. This uses a hard coded, hypothetical
         // chain of certificates & a list of constraint_spec on this.
         fn get_example() -> Self {
-            const EXAMPLE_NUM: i64 = 59765;
+            const EXAMPLE_NUM_1: i64 = 59765;
+            const EXAMPLE_NUM_2: i64 = 59766;
             const EXAMPLE_STRING: &str = "testing_dice_policy";
+            const UNCONSTRAINED_STRING: &str = "unconstrained_string";
+            const ANOTHER_UNCONSTRAINED_STRING: &str = "another_unconstrained_string";
 
             let rot_key = CoseKey::default().to_cbor_value().unwrap();
-            let nested_payload = cbor!({
-                100 => EXAMPLE_NUM
-            })
-            .unwrap();
-            let payload = cbor!({
-                1 => EXAMPLE_STRING,
-                2 => "some_other_example_string",
-                3 => Value::Bytes(value_to_bytes(&nested_payload).unwrap()),
-            })
-            .unwrap();
-            let payload = value_to_bytes(&payload).unwrap();
-            let dice_node = CoseSign1 {
-                protected: ProtectedHeader::default(),
-                unprotected: Header::default(),
-                payload: Some(payload),
-                signature: b"ddef".to_vec(),
-            }
-            .to_cbor_value()
-            .unwrap();
-            let input_dice = Value::Array([rot_key.clone(), dice_node].to_vec());
-
-            let input_dice = value_to_bytes(&input_dice).unwrap();
+            let input_dice = Self::get_dice_chain_helper(
+                rot_key.clone(),
+                EXAMPLE_NUM_1,
+                EXAMPLE_STRING,
+                UNCONSTRAINED_STRING,
+            );
 
             // Now construct constraint_spec on the input dice, note this will use the keys
             // which are also hardcoded within the get_dice_chain_helper.
@@ -305,7 +361,7 @@
             let constraint_spec = vec![
                 ConstraintSpec::new(ConstraintType::ExactMatch, vec![1]).unwrap(),
                 // Notice how key "2" is (deliberately) absent in ConstraintSpec
-                // so policy should not constraint it.
+                // so policy should not constrain it.
                 ConstraintSpec::new(ConstraintType::GreaterOrEqual, vec![3, 100]).unwrap(),
             ];
             let expected_dice_policy = DicePolicy {
@@ -325,12 +381,53 @@
                         Constraint(
                             ConstraintType::GreaterOrEqual as u16,
                             vec![3, 100],
-                            Value::from(EXAMPLE_NUM),
+                            Value::from(EXAMPLE_NUM_1),
                         ),
                     ])),
                 ]),
             };
-            Self { input_dice, constraint_spec, expected_dice_policy }
+
+            let updated_input_dice = Self::get_dice_chain_helper(
+                rot_key.clone(),
+                EXAMPLE_NUM_2,
+                EXAMPLE_STRING,
+                ANOTHER_UNCONSTRAINED_STRING,
+            );
+            Self { input_dice, constraint_spec, expected_dice_policy, updated_input_dice }
+        }
+
+        // Helper method to generate a dice chain with a given rot_key.
+        // Other arguments are ad-hoc values in the node's (nested) payload map. Callers use
+        // these to construct appropriate constraints in dice policies.
+        fn get_dice_chain_helper(
+            rot_key: Value,
+            version: i64,
+            constrained_string: &str,
+            unconstrained_string: &str,
+        ) -> Vec<u8> {
+            let nested_payload = cbor!({
+                100 => version
+            })
+            .unwrap();
+
+            let payload = cbor!({
+                1 => constrained_string,
+                2 => unconstrained_string,
+                3 => Value::Bytes(value_to_bytes(&nested_payload).unwrap()),
+            })
+            .unwrap();
+            let payload = value_to_bytes(&payload).unwrap();
+            let dice_node = CoseSign1 {
+                protected: ProtectedHeader::default(),
+                unprotected: Header::default(),
+                payload: Some(payload),
+                signature: b"ddef".to_vec(),
+            }
+            .to_cbor_value()
+            .unwrap();
+            let input_dice = Value::Array(vec![rot_key, dice_node]);
+
+            value_to_bytes(&input_dice).unwrap()
         }
     }
 
@@ -344,6 +441,43 @@
         assert_eq!(policy, example.expected_dice_policy);
     }
 
+    test!(policy_matches_original_dice_chain);
+    fn policy_matches_original_dice_chain() {
+        let example = TestArtifacts::get_example();
+        // A policy built from a chain must accept that very same chain.
+        let policy =
+            DicePolicy::from_dice_chain(&example.input_dice, &example.constraint_spec).unwrap();
+        assert!(
+            policy.matches_dice_chain(&example.input_dice).is_ok(),
+            "The dice chain did not match the policy constructed out of it!"
+        );
+    }
+
+    test!(policy_matches_updated_dice_chain);
+    fn policy_matches_updated_dice_chain() {
+        let example = TestArtifacts::get_example();
+        // The updated chain keeps the exact-match value and raises the >= one, so it must match.
+        let policy =
+            DicePolicy::from_dice_chain(&example.input_dice, &example.constraint_spec).unwrap();
+        assert!(
+            policy.matches_dice_chain(&example.updated_input_dice).is_ok(),
+            "The updated dice chain did not match the original policy!"
+        );
+    }
+
+    test!(policy_mismatch_downgraded_dice_chain);
+    fn policy_mismatch_downgraded_dice_chain() {
+        let example = TestArtifacts::get_example();
+        assert!(
+            DicePolicy::from_dice_chain(&example.updated_input_dice, &example.constraint_spec)
+                .unwrap()
+                .matches_dice_chain(&example.input_dice)
+                .is_err(),
+            "The (downgraded) dice chain matched the policy constructed out of the 'updated' \
+            dice chain!!"
+        );
+    }
+
     test!(policy_dice_size_is_same);
     fn policy_dice_size_is_same() {
         // This is the number of certs in compos bcc (including the first ROT)