Merge "Moving keystore2_client_tests to presubmit." into main
diff --git a/keystore2/Android.bp b/keystore2/Android.bp
index 5726078..7bba687 100644
--- a/keystore2/Android.bp
+++ b/keystore2/Android.bp
@@ -113,6 +113,7 @@
         "liblibsqlite3_sys",
         "libnix",
         "librusqlite",
+        "libtempfile",
     ],
     // The test should always include watchdog.
     features: [
diff --git a/keystore2/src/database/tests.rs b/keystore2/src/database/tests.rs
index 031d749..5f882cd 100644
--- a/keystore2/src/database/tests.rs
+++ b/keystore2/src/database/tests.rs
@@ -42,7 +42,11 @@
 use std::time::Instant;
 
 pub fn new_test_db() -> Result<KeystoreDB> {
-    let conn = KeystoreDB::make_connection("file::memory:")?;
+    new_test_db_at("file::memory:")
+}
+
+fn new_test_db_at(path: &str) -> Result<KeystoreDB> {
+    let conn = KeystoreDB::make_connection(path)?;
 
     let mut db = KeystoreDB { conn, gc: None, perboot: Arc::new(perboot::PerbootDB::new()) };
     db.with_transaction(Immediate("TX_new_test_db"), |tx| {
@@ -2411,6 +2415,81 @@
     Ok(())
 }
 
+fn blob_count(db: &mut KeystoreDB, sc_type: SubComponentType) -> usize {
+    db.with_transaction(TransactionBehavior::Deferred, |tx| {
+        tx.query_row(
+            "SELECT COUNT(*) FROM persistent.blobentry
+                     WHERE subcomponent_type = ?;",
+            params![sc_type],
+            |row| row.get(0),
+        )
+        .context(ks_err!("Failed to count number of {sc_type:?} blobs"))
+        .no_gc()
+    })
+    .unwrap()
+}
+
+#[test]
+fn test_blobentry_gc() -> Result<()> {
+    let mut db = new_test_db()?;
+    let _key_id1 = make_test_key_entry(&mut db, Domain::APP, 1, "key1", None)?.0;
+    let key_guard2 = make_test_key_entry(&mut db, Domain::APP, 2, "key2", None)?;
+    let key_guard3 = make_test_key_entry(&mut db, Domain::APP, 3, "key3", None)?;
+    let key_id4 = make_test_key_entry(&mut db, Domain::APP, 4, "key4", None)?.0;
+    let key_id5 = make_test_key_entry(&mut db, Domain::APP, 5, "key5", None)?.0;
+
+    assert_eq!(5, blob_count(&mut db, SubComponentType::KEY_BLOB));
+    assert_eq!(5, blob_count(&mut db, SubComponentType::CERT));
+    assert_eq!(5, blob_count(&mut db, SubComponentType::CERT_CHAIN));
+
+    // Replace the keyblobs for keys 2 and 3.  The previous blobs will still exist.
+    db.set_blob(&key_guard2, SubComponentType::KEY_BLOB, Some(&[1, 2, 3]), None)?;
+    db.set_blob(&key_guard3, SubComponentType::KEY_BLOB, Some(&[1, 2, 3]), None)?;
+
+    assert_eq!(7, blob_count(&mut db, SubComponentType::KEY_BLOB));
+    assert_eq!(5, blob_count(&mut db, SubComponentType::CERT));
+    assert_eq!(5, blob_count(&mut db, SubComponentType::CERT_CHAIN));
+
+    // Delete keys 4 and 5.  The keyblobs aren't removed yet.
+    db.with_transaction(Immediate("TX_delete_test_keys"), |tx| {
+        KeystoreDB::mark_unreferenced(tx, key_id4)?;
+        KeystoreDB::mark_unreferenced(tx, key_id5)?;
+        Ok(()).no_gc()
+    })
+    .unwrap();
+
+    assert_eq!(7, blob_count(&mut db, SubComponentType::KEY_BLOB));
+    assert_eq!(5, blob_count(&mut db, SubComponentType::CERT));
+    assert_eq!(5, blob_count(&mut db, SubComponentType::CERT_CHAIN));
+
+    // First garbage collection should return all 4 blobentry rows that are no longer current for
+    // their key.
+    let superseded = db.handle_next_superseded_blobs(&[], 20).unwrap();
+    let superseded_ids: Vec<i64> = superseded.iter().map(|v| v.blob_id).collect();
+    assert_eq!(4, superseded.len());
+    assert_eq!(7, blob_count(&mut db, SubComponentType::KEY_BLOB));
+    assert_eq!(5, blob_count(&mut db, SubComponentType::CERT));
+    assert_eq!(5, blob_count(&mut db, SubComponentType::CERT_CHAIN));
+
+    // Feed the superseded blob IDs back in, to trigger removal of the old KEY_BLOB entries.  As no
+    // new superseded KEY_BLOBs are found, the unreferenced CERT/CERT_CHAIN blobs are removed.
+    let superseded = db.handle_next_superseded_blobs(&superseded_ids, 20).unwrap();
+    let superseded_ids: Vec<i64> = superseded.iter().map(|v| v.blob_id).collect();
+    assert_eq!(0, superseded.len());
+    assert_eq!(3, blob_count(&mut db, SubComponentType::KEY_BLOB));
+    assert_eq!(3, blob_count(&mut db, SubComponentType::CERT));
+    assert_eq!(3, blob_count(&mut db, SubComponentType::CERT_CHAIN));
+
+    // Nothing left to garbage collect.
+    let superseded = db.handle_next_superseded_blobs(&superseded_ids, 20).unwrap();
+    assert_eq!(0, superseded.len());
+    assert_eq!(3, blob_count(&mut db, SubComponentType::KEY_BLOB));
+    assert_eq!(3, blob_count(&mut db, SubComponentType::CERT));
+    assert_eq!(3, blob_count(&mut db, SubComponentType::CERT_CHAIN));
+
+    Ok(())
+}
+
 #[test]
 fn test_load_key_descriptor() -> Result<()> {
     let mut db = new_test_db()?;
@@ -2526,3 +2605,151 @@
     assert_eq!(third_sid_apps, vec![second_app_id]);
     Ok(())
 }
+
+// Starting from `next_keyid`, add keys to the database until the count reaches
+// `key_count`.  (`next_keyid` is assumed to indicate how many rows already exist.)
+fn db_populate_keys(db: &mut KeystoreDB, next_keyid: usize, key_count: usize) {
+    db.with_transaction(Immediate("test_keyentry"), |tx| {
+        for next_keyid in next_keyid..key_count {
+            tx.execute(
+                "INSERT into persistent.keyentry
+                        (id, key_type, domain, namespace, alias, state, km_uuid)
+                        VALUES(?, ?, ?, ?, ?, ?, ?);",
+                params![
+                    next_keyid,
+                    KeyType::Client,
+                    Domain::APP.0 as u32,
+                    10001,
+                    &format!("alias-{next_keyid}"),
+                    KeyLifeCycle::Live,
+                    KEYSTORE_UUID,
+                ],
+            )?;
+            tx.execute(
+                "INSERT INTO persistent.blobentry
+                         (subcomponent_type, keyentryid, blob) VALUES (?, ?, ?);",
+                params![SubComponentType::KEY_BLOB, next_keyid, TEST_KEY_BLOB],
+            )?;
+            tx.execute(
+                "INSERT INTO persistent.blobentry
+                         (subcomponent_type, keyentryid, blob) VALUES (?, ?, ?);",
+                params![SubComponentType::CERT, next_keyid, TEST_CERT_BLOB],
+            )?;
+            tx.execute(
+                "INSERT INTO persistent.blobentry
+                         (subcomponent_type, keyentryid, blob) VALUES (?, ?, ?);",
+                params![SubComponentType::CERT_CHAIN, next_keyid, TEST_CERT_CHAIN_BLOB],
+            )?;
+        }
+        Ok(()).no_gc()
+    })
+    .unwrap()
+}
+
+/// Run the provided `test_fn` against the database at various increasing stages of
+/// database population.
+fn run_with_many_keys<F, T>(max_count: usize, test_fn: F) -> Result<()>
+where
+    F: Fn(&mut KeystoreDB) -> T,
+{
+    android_logger::init_once(
+        android_logger::Config::default()
+            .with_tag("keystore2_test")
+            .with_max_level(log::LevelFilter::Debug),
+    );
+    // Put the test database on disk for a more realistic result.
+    let db_root = tempfile::Builder::new().prefix("ks2db-test-").tempdir().unwrap();
+    let mut db_path = db_root.path().to_owned();
+    db_path.push("ks2-test.sqlite");
+    let mut db = new_test_db_at(&db_path.to_string_lossy())?;
+
+    println!("\nNumber_of_keys,time_in_s");
+    let mut key_count = 10;
+    let mut next_keyid = 0;
+    while key_count < max_count {
+        db_populate_keys(&mut db, next_keyid, key_count);
+        assert_eq!(db_key_count(&mut db), key_count);
+
+        let start = std::time::Instant::now();
+        let _result = test_fn(&mut db);
+        println!("{key_count}, {}", start.elapsed().as_secs_f64());
+
+        next_keyid = key_count;
+        key_count *= 2;
+    }
+
+    Ok(())
+}
+
+fn db_key_count(db: &mut KeystoreDB) -> usize {
+    db.with_transaction(TransactionBehavior::Deferred, |tx| {
+        tx.query_row(
+            "SELECT COUNT(*) FROM persistent.keyentry
+                         WHERE domain = ? AND state = ? AND key_type = ?;",
+            params![Domain::APP.0 as u32, KeyLifeCycle::Live, KeyType::Client],
+            |row| row.get::<usize, usize>(0),
+        )
+        .context(ks_err!("Failed to count number of keys."))
+        .no_gc()
+    })
+    .unwrap()
+}
+
+#[test]
+fn test_handle_superseded_with_many_keys() -> Result<()> {
+    run_with_many_keys(1_000_000, |db| db.handle_next_superseded_blobs(&[], 20))
+}
+
+#[test]
+fn test_get_storage_stats_with_many_keys() -> Result<()> {
+    use android_security_metrics::aidl::android::security::metrics::Storage::Storage as MetricsStorage;
+    run_with_many_keys(1_000_000, |db| {
+        db.get_storage_stat(MetricsStorage::DATABASE).unwrap();
+        db.get_storage_stat(MetricsStorage::KEY_ENTRY).unwrap();
+        db.get_storage_stat(MetricsStorage::KEY_ENTRY_ID_INDEX).unwrap();
+        db.get_storage_stat(MetricsStorage::KEY_ENTRY_DOMAIN_NAMESPACE_INDEX).unwrap();
+        db.get_storage_stat(MetricsStorage::BLOB_ENTRY).unwrap();
+        db.get_storage_stat(MetricsStorage::BLOB_ENTRY_KEY_ENTRY_ID_INDEX).unwrap();
+        db.get_storage_stat(MetricsStorage::KEY_PARAMETER).unwrap();
+        db.get_storage_stat(MetricsStorage::KEY_PARAMETER_KEY_ENTRY_ID_INDEX).unwrap();
+        db.get_storage_stat(MetricsStorage::KEY_METADATA).unwrap();
+        db.get_storage_stat(MetricsStorage::KEY_METADATA_KEY_ENTRY_ID_INDEX).unwrap();
+        db.get_storage_stat(MetricsStorage::GRANT).unwrap();
+        db.get_storage_stat(MetricsStorage::AUTH_TOKEN).unwrap();
+        db.get_storage_stat(MetricsStorage::BLOB_METADATA).unwrap();
+        db.get_storage_stat(MetricsStorage::BLOB_METADATA_BLOB_ENTRY_ID_INDEX).unwrap();
+    })
+}
+
+#[test]
+fn test_list_keys_with_many_keys() -> Result<()> {
+    run_with_many_keys(1_000_000, |db: &mut KeystoreDB| -> Result<()> {
+        // Behave equivalently to how clients list aliases.
+        let domain = Domain::APP;
+        let namespace = 10001;
+        let mut start_past: Option<String> = None;
+        let mut count = 0;
+        let mut batches = 0;
+        loop {
+            let keys = db
+                .list_past_alias(domain, namespace, KeyType::Client, start_past.as_deref())
+                .unwrap();
+            let batch_size = crate::utils::estimate_safe_amount_to_return(
+                domain,
+                namespace,
+                &keys,
+                crate::utils::RESPONSE_SIZE_LIMIT,
+            );
+            let batch = &keys[..batch_size];
+            count += batch.len();
+            match batch.last() {
+                Some(key) => start_past.clone_from(&key.alias),
+                None => {
+                    log::info!("got {count} keys in {batches} non-empty batches");
+                    return Ok(());
+                }
+            }
+            batches += 1;
+        }
+    })
+}
diff --git a/keystore2/src/metrics_store.rs b/keystore2/src/metrics_store.rs
index 895374c..01f0a01 100644
--- a/keystore2/src/metrics_store.rs
+++ b/keystore2/src/metrics_store.rs
@@ -516,7 +516,7 @@
     }
 }
 
-fn pull_storage_stats() -> Result<Vec<KeystoreAtom>> {
+pub(crate) fn pull_storage_stats() -> Result<Vec<KeystoreAtom>> {
     let mut atom_vec: Vec<KeystoreAtom> = Vec::new();
     let mut append = |stat| {
         match stat {
diff --git a/keystore2/src/utils.rs b/keystore2/src/utils.rs
index 03c0626..81ebdab 100644
--- a/keystore2/src/utils.rs
+++ b/keystore2/src/utils.rs
@@ -522,7 +522,7 @@
     result
 }
 
-fn estimate_safe_amount_to_return(
+pub(crate) fn estimate_safe_amount_to_return(
     domain: Domain,
     namespace: i64,
     key_descriptors: &[KeyDescriptor],
@@ -560,6 +560,9 @@
     items_to_return
 }
 
+/// Estimate for maximum size of a Binder response in bytes.
+pub(crate) const RESPONSE_SIZE_LIMIT: usize = 358400;
+
 /// List all key aliases for a given domain + namespace. whose alias is greater
 /// than start_past_alias (if provided).
 pub fn list_key_entries(
@@ -583,7 +586,6 @@
         start_past_alias,
     );
 
-    const RESPONSE_SIZE_LIMIT: usize = 358400;
     let safe_amount_to_return =
         estimate_safe_amount_to_return(domain, namespace, &merged_key_entries, RESPONSE_SIZE_LIMIT);
     Ok(merged_key_entries[..safe_amount_to_return].to_vec())
diff --git a/keystore2/test_utils/key_generations.rs b/keystore2/test_utils/key_generations.rs
index e2f0b3e..e63ee60 100644
--- a/keystore2/test_utils/key_generations.rs
+++ b/keystore2/test_utils/key_generations.rs
@@ -520,18 +520,6 @@
         }
     ));
 
-    // Access denied for finding vendor-patch-level ("ro.vendor.build.security_patch") property
-    // in a test running with `untrusted_app` context. Keeping this check to verify
-    // vendor-patch-level in tests running with `su` context.
-    if getuid().is_root() {
-        assert!(check_key_param(
-            authorizations,
-            &KeyParameter {
-                tag: Tag::VENDOR_PATCHLEVEL,
-                value: KeyParameterValue::Integer(get_vendor_patchlevel().try_into().unwrap())
-            }
-        ));
-    }
     assert!(check_key_param(
         authorizations,
         &KeyParameter { tag: Tag::ORIGIN, value: KeyParameterValue::Origin(expected_key_origin) }
@@ -553,6 +541,22 @@
             .iter()
             .map(|auth| &auth.keyParameter)
             .any(|key_param| key_param.tag == Tag::CREATION_DATETIME));
+
+        // Reading the vendor-patch-level property ("ro.vendor.build.security_patch") is denied
+        // in a test running with the `untrusted_app` context, so this check is kept only to
+        // verify vendor-patch-level in tests running with the `su` context.
+        if getuid().is_root() {
+            // Keymaster implementations may not consistently include `Tag::VENDOR_PATCHLEVEL`
+            // in generated key characteristics, so only check this if the underlying device is a
+            // KeyMint implementation.
+            assert!(check_key_param(
+                authorizations,
+                &KeyParameter {
+                    tag: Tag::VENDOR_PATCHLEVEL,
+                    value: KeyParameterValue::Integer(get_vendor_patchlevel().try_into().unwrap())
+                }
+            ));
+        }
     }
 }
 
diff --git a/provisioner/rkp_factory_extraction_lib.cpp b/provisioner/rkp_factory_extraction_lib.cpp
index ec70d08..2c2614d 100644
--- a/provisioner/rkp_factory_extraction_lib.cpp
+++ b/provisioner/rkp_factory_extraction_lib.cpp
@@ -224,7 +224,8 @@
 }
 
 CborResult<cppbor::Array> getCsrV3(std::string_view componentName,
-                                   IRemotelyProvisionedComponent* irpc, bool selfTest) {
+                                   IRemotelyProvisionedComponent* irpc, bool selfTest,
+                                   bool allowDegenerate) {
     std::vector<uint8_t> csr;
     std::vector<MacedPublicKey> emptyKeys;
     const std::vector<uint8_t> challenge = generateChallenge();
@@ -237,7 +238,8 @@
     }
 
     if (selfTest) {
-        auto result = verifyFactoryCsr(/*keysToSign=*/cppbor::Array(), csr, irpc, challenge);
+        auto result =
+            verifyFactoryCsr(/*keysToSign=*/cppbor::Array(), csr, irpc, challenge, allowDegenerate);
         if (!result) {
             std::cerr << "Self test failed for IRemotelyProvisionedComponent '" << componentName
                       << "'. Error message: '" << result.message() << "'." << std::endl;
@@ -249,7 +251,7 @@
 }
 
 CborResult<Array> getCsr(std::string_view componentName, IRemotelyProvisionedComponent* irpc,
-                         bool selfTest) {
+                         bool selfTest, bool allowDegenerate) {
     RpcHardwareInfo hwInfo;
     auto status = irpc->getHardwareInfo(&hwInfo);
     if (!status.isOk()) {
@@ -264,7 +266,7 @@
         }
         return getCsrV1(componentName, irpc);
     } else {
-        return getCsrV3(componentName, irpc, selfTest);
+        return getCsrV3(componentName, irpc, selfTest, allowDegenerate);
     }
 }
 
diff --git a/provisioner/rkp_factory_extraction_lib.h b/provisioner/rkp_factory_extraction_lib.h
index 93c498a..94bd751 100644
--- a/provisioner/rkp_factory_extraction_lib.h
+++ b/provisioner/rkp_factory_extraction_lib.h
@@ -47,7 +47,7 @@
 CborResult<cppbor::Array>
 getCsr(std::string_view componentName,
        aidl::android::hardware::security::keymint::IRemotelyProvisionedComponent* irpc,
-       bool selfTest);
+       bool selfTest, bool allowDegenerate);
 
 // Generates a test certificate chain and validates it, exiting the process on error.
 void selfTestGetCsr(
diff --git a/provisioner/rkp_factory_extraction_lib_test.cpp b/provisioner/rkp_factory_extraction_lib_test.cpp
index 3fe88da..247c508 100644
--- a/provisioner/rkp_factory_extraction_lib_test.cpp
+++ b/provisioner/rkp_factory_extraction_lib_test.cpp
@@ -181,7 +181,7 @@
                         Return(ByMove(ScopedAStatus::ok()))));  //
 
     auto [csr, csrErrMsg] = getCsr("mock component name", mockRpc.get(),
-                                   /*selfTest=*/false);
+                                   /*selfTest=*/false, /*allowDegenerate=*/true);
     ASSERT_THAT(csr, NotNull()) << csrErrMsg;
     ASSERT_THAT(csr->asArray(), Pointee(Property(&Array::size, Eq(4))));
 
@@ -251,7 +251,7 @@
                         Return(ByMove(ScopedAStatus::ok()))));
 
     auto [csr, csrErrMsg] = getCsr("mock component name", mockRpc.get(),
-                                   /*selfTest=*/false);
+                                   /*selfTest=*/false, /*allowDegenerate=*/true);
     ASSERT_THAT(csr, NotNull()) << csrErrMsg;
     ASSERT_THAT(csr, Pointee(Property(&Array::size, Eq(5))));
 
diff --git a/provisioner/rkp_factory_extraction_tool.cpp b/provisioner/rkp_factory_extraction_tool.cpp
index 1cb1144..c0f6beb 100644
--- a/provisioner/rkp_factory_extraction_tool.cpp
+++ b/provisioner/rkp_factory_extraction_tool.cpp
@@ -43,6 +43,8 @@
             "If true, this tool performs a self-test, validating the payload for correctness. "
             "This checks that the device on the factory line is producing valid output "
             "before attempting to upload the output to the device info service.");
+DEFINE_bool(allow_degenerate, true,
+            "If true, self_test validation will allow degenerate DICE chains in the CSR.");
 DEFINE_string(serialno_prop, "ro.serialno",
               "The property of getting serial number. Defaults to 'ro.serialno'.");
 
@@ -83,7 +85,7 @@
     if (std::string(name) == "avf" && !isRemoteProvisioningSupported(irpc)) {
         return;
     }
-    auto [request, errMsg] = getCsr(name, irpc, FLAGS_self_test);
+    auto [request, errMsg] = getCsr(name, irpc, FLAGS_self_test, FLAGS_allow_degenerate);
     auto fullName = getFullServiceName(descriptor, name);
     if (!request) {
         std::cerr << "Unable to build CSR for '" << fullName << ": " << errMsg << std::endl;