Merge "Keystore2: Batching listing of key entries" am: 4ec7585ff8 am: 6e5fe5015b am: cf61a5003d

Original change: https://android-review.googlesource.com/c/platform/system/security/+/2402644

Change-Id: Id7609b51a8008816be4ea23465ee351643998fef
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
diff --git a/keystore2/src/database.rs b/keystore2/src/database.rs
index 7b90fd5..6108221 100644
--- a/keystore2/src/database.rs
+++ b/keystore2/src/database.rs
@@ -2987,32 +2987,50 @@
         })
     }
 
-    /// Returns a list of KeyDescriptors in the selected domain/namespace.
+    /// Returns a list of KeyDescriptors in the selected domain/namespace whose
+    /// aliases are strictly greater than `start_past_alias`. If no value
+    /// is provided, returns all KeyDescriptors.
     /// The key descriptors will have the domain, nspace, and alias field set.
+    /// The returned list will be sorted by alias and contain no duplicate aliases.
     /// Domain must be APP or SELINUX, the caller must make sure of that.
-    pub fn list(
+    pub fn list_past_alias(
         &mut self,
         domain: Domain,
         namespace: i64,
         key_type: KeyType,
+        start_past_alias: Option<&str>,
     ) -> Result<Vec<KeyDescriptor>> {
-        let _wp = wd::watch_millis("KeystoreDB::list", 500);
+        let _wp = wd::watch_millis("KeystoreDB::list_past_alias", 500);
 
-        self.with_transaction(TransactionBehavior::Deferred, |tx| {
-            let mut stmt = tx
-                .prepare(
-                    "SELECT alias FROM persistent.keyentry
+        // Only a constant clause is spliced in; the alias itself is bound as a parameter.
+        let query = format!(
+            "SELECT DISTINCT alias FROM persistent.keyentry
                      WHERE domain = ?
                      AND namespace = ?
                      AND alias IS NOT NULL
                      AND state = ?
-                     AND key_type = ?;",
-                )
-                .context(ks_err!("Failed to prepare."))?;
+                     AND key_type = ?
+                     {}
+                     ORDER BY alias ASC;",
+            if start_past_alias.is_some() { " AND alias > ?" } else { "" }
+        );
 
-            let mut rows = stmt
-                .query(params![domain.0 as u32, namespace, KeyLifeCycle::Live, key_type])
-                .context(ks_err!("Failed to query."))?;
+        self.with_transaction(TransactionBehavior::Deferred, |tx| {
+            let mut stmt = tx.prepare(&query).context(ks_err!("Failed to prepare."))?;
+
+            let mut rows = match start_past_alias {
+                Some(past_alias) => stmt.query(params![
+                    domain.0 as u32,
+                    namespace,
+                    KeyLifeCycle::Live,
+                    key_type,
+                    past_alias
+                ]),
+                None => {
+                    stmt.query(params![domain.0 as u32, namespace, KeyLifeCycle::Live, key_type])
+                }
+            }
+            .context(ks_err!("Failed to query."))?;
 
             let mut descriptors: Vec<KeyDescriptor> = Vec::new();
             db_utils::with_rows_extract_all(&mut rows, |row| {
@@ -3029,6 +3047,33 @@
         })
     }
 
+    /// Returns the number of distinct live key aliases of `key_type` in the selected
+    /// domain/namespace. Domain must be APP or SELINUX, the caller must make sure of that.
+    pub fn count_keys(
+        &mut self,
+        domain: Domain,
+        namespace: i64,
+        key_type: KeyType,
+    ) -> Result<usize> {
+        let _wp = wd::watch_millis("KeystoreDB::count_keys", 500);
+
+        let num_keys = self.with_transaction(TransactionBehavior::Deferred, |tx| {
+            tx.query_row(
+                "SELECT COUNT(DISTINCT alias) FROM persistent.keyentry
+                     WHERE domain = ?
+                     AND namespace = ?
+                     AND alias IS NOT NULL
+                     AND state = ?
+                     AND key_type = ?;",
+                params![domain.0 as u32, namespace, KeyLifeCycle::Live, key_type],
+                |row| row.get(0),
+            )
+            .context(ks_err!("Failed to count number of keys."))
+            .no_gc()
+        })?;
+        Ok(num_keys)
+    }
+
     /// Adds a grant to the grant table.
     /// Like `load_key_entry` this function loads the access tuple before
     /// it uses the callback for a permission check. Upon success,
@@ -4920,7 +4965,7 @@
                 })
                 .collect();
             list_o_descriptors.sort();
-            let mut list_result = db.list(*domain, *namespace, KeyType::Client)?;
-            list_result.sort();
+            // list_past_alias() returns entries ordered by alias, so compare without re-sorting.
+            let list_result = db.list_past_alias(*domain, *namespace, KeyType::Client, None)?;
             assert_eq!(list_o_descriptors, list_result);
 
@@ -4950,7 +4995,10 @@
             loaded_entries.sort_unstable();
             assert_eq!(list_o_ids, loaded_entries);
         }
-        assert_eq!(Vec::<KeyDescriptor>::new(), db.list(Domain::SELINUX, 101, KeyType::Client)?);
+        assert_eq!(
+            Vec::<KeyDescriptor>::new(),
+            db.list_past_alias(Domain::SELINUX, 101, KeyType::Client, None)?
+        );
 
         Ok(())
     }
@@ -5474,11 +5522,11 @@
         make_test_key_entry(&mut db, Domain::APP, 110000, TEST_ALIAS, None)?;
         db.unbind_keys_for_user(2, false)?;
 
-        assert_eq!(1, db.list(Domain::APP, 110000, KeyType::Client)?.len());
-        assert_eq!(0, db.list(Domain::APP, 210000, KeyType::Client)?.len());
+        assert_eq!(1, db.list_past_alias(Domain::APP, 110000, KeyType::Client, None)?.len());
+        assert_eq!(0, db.list_past_alias(Domain::APP, 210000, KeyType::Client, None)?.len());
 
         db.unbind_keys_for_user(1, true)?;
-        assert_eq!(0, db.list(Domain::APP, 110000, KeyType::Client)?.len());
+        assert_eq!(0, db.list_past_alias(Domain::APP, 110000, KeyType::Client, None)?.len());
 
         Ok(())
     }
diff --git a/keystore2/src/service.rs b/keystore2/src/service.rs
index 1040228..7ba8cbc 100644
--- a/keystore2/src/service.rs
+++ b/keystore2/src/service.rs
@@ -22,7 +22,7 @@
 use crate::permission::{KeyPerm, KeystorePerm};
 use crate::security_level::KeystoreSecurityLevel;
 use crate::utils::{
-    check_grant_permission, check_key_permission, check_keystore_permission,
+    check_grant_permission, check_key_permission, check_keystore_permission, count_key_entries,
     key_parameters_to_authorizations, list_key_entries, uid_to_android_user, watchdog as wd,
 };
 use crate::{
@@ -251,7 +251,13 @@
         .context(ks_err!())
     }
 
-    fn list_entries(&self, domain: Domain, namespace: i64) -> Result<Vec<KeyDescriptor>> {
+    /// Returns the descriptor whose domain/namespace the caller may enumerate, failing if
+    /// the caller lacks the required permission. Shared by the list and count entry points.
+    fn get_key_descriptor_for_lookup(
+        &self,
+        domain: Domain,
+        namespace: i64,
+    ) -> Result<KeyDescriptor> {
         let mut k = match domain {
             Domain::APP => KeyDescriptor {
                 domain,
@@ -284,8 +290,32 @@
                 return Err(e).context(ks_err!("While checking key permission."))?;
             }
         }
+        Ok(k)
+    }
 
-        DB.with(|db| list_key_entries(&mut db.borrow_mut(), k.domain, k.nspace))
+    /// Lists the key entries visible to the caller in `domain`/`namespace`. The result may be
+    /// truncated to fit in a binder transaction; `list_entries_batched` allows paging.
+    fn list_entries(&self, domain: Domain, namespace: i64) -> Result<Vec<KeyDescriptor>> {
+        self.list_entries_batched(domain, namespace, None)
+    }
+
+    /// Returns the number of key entries visible to the caller in `domain`/`namespace`.
+    fn count_num_entries(&self, domain: Domain, namespace: i64) -> Result<i32> {
+        let k = self.get_key_descriptor_for_lookup(domain, namespace)?;
+
+        DB.with(|db| count_key_entries(&mut db.borrow_mut(), k.domain, k.nspace))
+    }
+
+    /// Lists the key entries visible to the caller in `domain`/`namespace` whose alias is
+    /// greater than `start_past_alias`, or all of them when it is `None`.
+    fn list_entries_batched(
+        &self,
+        domain: Domain,
+        namespace: i64,
+        start_past_alias: Option<&str>,
+    ) -> Result<Vec<KeyDescriptor>> {
+        let k = self.get_key_descriptor_for_lookup(domain, namespace)?;
+        DB.with(|db| list_key_entries(&mut db.borrow_mut(), k.domain, k.nspace, start_past_alias))
     }
 
     fn delete_key(&self, key: &KeyDescriptor) -> Result<()> {
@@ -389,4 +419,19 @@
         let _wp = wd::watch_millis("IKeystoreService::ungrant", 500);
         map_or_log_err(self.ungrant(key, grantee_uid), Ok)
     }
+
+    fn listEntriesBatched(
+        &self,
+        domain: Domain,
+        namespace: i64,
+        start_past_alias: Option<&str>,
+    ) -> binder::Result<Vec<KeyDescriptor>> {
+        let _wp = wd::watch_millis("IKeystoreService::listEntriesBatched", 500);
+        map_or_log_err(self.list_entries_batched(domain, namespace, start_past_alias), Ok)
+    }
+
+    fn getNumberOfEntries(&self, domain: Domain, namespace: i64) -> binder::Result<i32> {
+        let _wp = wd::watch_millis("IKeystoreService::getNumberOfEntries", 500);
+        map_or_log_err(self.count_num_entries(domain, namespace), Ok)
+    }
 }
diff --git a/keystore2/src/utils.rs b/keystore2/src/utils.rs
index 7bc548e..acac7ee 100644
--- a/keystore2/src/utils.rs
+++ b/keystore2/src/utils.rs
@@ -258,32 +258,49 @@
     rustutils::users::multiuser_get_user_id(uid)
 }
 
-/// List all key aliases for a given domain + namespace.
-pub fn list_key_entries(
-    db: &mut KeystoreDB,
-    domain: Domain,
-    namespace: i64,
-) -> Result<Vec<KeyDescriptor>> {
-    let mut result = Vec::new();
-    result.append(
-        &mut LEGACY_IMPORTER
-            .list_uid(domain, namespace)
-            .context(ks_err!("Trying to list legacy keys."))?,
-    );
-    result.append(
-        &mut db
-            .list(domain, namespace, KeyType::Client)
-            .context(ks_err!("Trying to list keystore database."))?,
-    );
+/// Merges and filters two lists of key descriptors.
+///
+/// `legacy_descriptors` may be unsorted and unfiltered: when `start_past_alias` is given, only
+/// entries whose alias is strictly greater than it are kept (entries without an alias are
+/// dropped). `db_descriptors` is assumed to be filtered by the caller already and is appended
+/// as is. The combined list is then sorted and de-duplicated before it is returned, using
+/// `KeyDescriptor`'s own ordering and equality.
+fn merge_and_filter_key_entry_lists(
+    legacy_descriptors: &[KeyDescriptor],
+    db_descriptors: &[KeyDescriptor],
+    start_past_alias: Option<&str>,
+) -> Vec<KeyDescriptor> {
+    let mut result: Vec<KeyDescriptor> = match start_past_alias {
+        // Keep only legacy entries whose alias sorts strictly after `past_alias`.
+        Some(past_alias) => legacy_descriptors
+            .iter()
+            .filter(|kd| kd.alias.as_deref().map_or(false, |alias| alias > past_alias))
+            .cloned()
+            .collect(),
+        None => legacy_descriptors.to_vec(),
+    };
+
+    result.extend_from_slice(db_descriptors);
     result.sort_unstable();
     result.dedup();
+    result
+}
 
+/// Estimates how many leading entries of `key_descriptors` can be returned without the
+/// estimated serialized size exceeding `response_size_limit` bytes.
+///
+/// Entries are considered in order and counting stops at the first one that would push the
+/// running estimate over the limit, so the result is the length of the longest fitting
+/// prefix (0 if even the first entry does not fit, or if the input is empty).
+fn estimate_safe_amount_to_return(
+    key_descriptors: &[KeyDescriptor],
+    response_size_limit: usize,
+) -> usize {
     let mut items_to_return = 0;
     let mut returned_bytes: usize = 0;
-    const RESPONSE_SIZE_LIMIT: usize = 358400;
     // Estimate the transaction size to avoid returning more items than what
     // could fit in a binder transaction.
-    for kd in result.iter() {
+    for kd in key_descriptors.iter() {
         // 4 bytes for the Domain enum
         // 8 bytes for the Namespace long.
         returned_bytes += 4 + 8;
@@ -298,11 +315,11 @@
         // The binder transaction size limit is 1M. Empirical measurements show
         // that the binder overhead is 60% (to be confirmed). So break after
         // 350KB and return a partial list.
-        if returned_bytes > RESPONSE_SIZE_LIMIT {
+        if returned_bytes > response_size_limit {
             log::warn!(
                 "Key descriptors list ({} items) may exceed binder \
                        size, returning {} items est {} bytes.",
-                result.len(),
+                key_descriptors.len(),
                 items_to_return,
                 returned_bytes
             );
@@ -310,7 +327,67 @@
         }
         items_to_return += 1;
     }
-    Ok(result[..items_to_return].to_vec())
+    items_to_return
+}
+
+/// Lists the key descriptors for a given domain + namespace whose alias is greater than
+/// `start_past_alias` (if provided).
+///
+/// Legacy and database entries are merged, sorted and de-duplicated. The result may be
+/// truncated to an estimated size that fits in one binder transaction; passing the alias of
+/// the last returned descriptor as `start_past_alias` continues the listing.
+pub fn list_key_entries(
+    db: &mut KeystoreDB,
+    domain: Domain,
+    namespace: i64,
+    start_past_alias: Option<&str>,
+) -> Result<Vec<KeyDescriptor>> {
+    let legacy_key_descriptors: Vec<KeyDescriptor> = LEGACY_IMPORTER
+        .list_uid(domain, namespace)
+        .context(ks_err!("Trying to list legacy keys."))?;
+
+    // The results from the database are sorted, unique and already filtered by
+    // `start_past_alias`.
+    let db_key_descriptors: Vec<KeyDescriptor> = db
+        .list_past_alias(domain, namespace, KeyType::Client, start_past_alias)
+        .context(ks_err!("Trying to list keystore database past alias."))?;
+
+    let mut merged_key_entries = merge_and_filter_key_entry_lists(
+        &legacy_key_descriptors,
+        &db_key_descriptors,
+        start_past_alias,
+    );
+
+    const RESPONSE_SIZE_LIMIT: usize = 358400;
+    let safe_amount_to_return =
+        estimate_safe_amount_to_return(&merged_key_entries, RESPONSE_SIZE_LIMIT);
+    // Truncate in place instead of copying the retained prefix into a new vector.
+    merged_key_entries.truncate(safe_amount_to_return);
+    Ok(merged_key_entries)
+}
+
+/// Counts the key aliases for a given domain + namespace.
+///
+/// Legacy and database entries are de-duplicated the same way `list_key_entries` does, so an
+/// alias present in both stores is counted only once.
+pub fn count_key_entries(db: &mut KeystoreDB, domain: Domain, namespace: i64) -> Result<i32> {
+    let legacy_keys = LEGACY_IMPORTER
+        .list_uid(domain, namespace)
+        .context(ks_err!("Trying to list legacy keys."))?;
+
+    let num_keys = if legacy_keys.is_empty() {
+        // Nothing to de-duplicate against, so the cheap COUNT query is sufficient.
+        db.count_keys(domain, namespace, KeyType::Client)
+            .context(ks_err!("Trying to count keys in keystore database."))?
+    } else {
+        let db_keys = db
+            .list_past_alias(domain, namespace, KeyType::Client, None)
+            .context(ks_err!("Trying to list keystore database."))?;
+        merge_and_filter_key_entry_lists(&legacy_keys, &db_keys, None).len()
+    };
+
+    // The binder interface reports the count as an i32; fail rather than silently truncate.
+    i32::try_from(num_keys).context(ks_err!("Number of keys does not fit in an i32."))
 }
 
 /// This module provides helpers for simplified use of the watchdog module.
@@ -407,4 +484,91 @@
             }
         })
     }
+
+    fn create_key_descriptors_from_aliases(key_aliases: &[&str]) -> Vec<KeyDescriptor> {
+        key_aliases
+            .iter()
+            .map(|key_alias| KeyDescriptor {
+                domain: Domain::APP,
+                nspace: 0,
+                alias: Some(key_alias.to_string()),
+                blob: None,
+            })
+            .collect::<Vec<KeyDescriptor>>()
+    }
+
+    fn aliases_from_key_descriptors(key_descriptors: &[KeyDescriptor]) -> Vec<String> {
+        // Descriptors without an alias map to "" so they still show up in comparisons.
+        key_descriptors.iter().map(|kd| kd.alias.clone().unwrap_or_default()).collect()
+    }
+
+    #[test]
+    fn test_safe_amount_to_return() -> Result<()> {
+        let key_aliases = vec!["key1", "key2", "key3"];
+        let key_descriptors = create_key_descriptors_from_aliases(&key_aliases);
+
+        assert_eq!(estimate_safe_amount_to_return(&key_descriptors, 20), 1);
+        assert_eq!(estimate_safe_amount_to_return(&key_descriptors, 50), 2);
+        assert_eq!(estimate_safe_amount_to_return(&key_descriptors, 100), 3);
+        // Nothing fits, or there is nothing to return.
+        assert_eq!(estimate_safe_amount_to_return(&key_descriptors, 0), 0);
+        assert_eq!(estimate_safe_amount_to_return(&[], 100), 0);
+        Ok(())
+    }
+
+    #[test]
+    fn test_merge_and_sort_lists_without_filtering() -> Result<()> {
+        let legacy_key_aliases = vec!["key_c", "key_a", "key_b"];
+        let legacy_key_descriptors = create_key_descriptors_from_aliases(&legacy_key_aliases);
+        let db_key_aliases = vec!["key_a", "key_d"];
+        let db_key_descriptors = create_key_descriptors_from_aliases(&db_key_aliases);
+        let result =
+            merge_and_filter_key_entry_lists(&legacy_key_descriptors, &db_key_descriptors, None);
+        assert_eq!(aliases_from_key_descriptors(&result), vec!["key_a", "key_b", "key_c", "key_d"]);
+        Ok(())
+    }
+
+    #[test]
+    fn test_merge_and_sort_lists_with_filtering() -> Result<()> {
+        let legacy_key_aliases = vec!["key_f", "key_a", "key_e", "key_b"];
+        let legacy_key_descriptors = create_key_descriptors_from_aliases(&legacy_key_aliases);
+        let db_key_aliases = vec!["key_c", "key_g"];
+        let db_key_descriptors = create_key_descriptors_from_aliases(&db_key_aliases);
+        let result = merge_and_filter_key_entry_lists(
+            &legacy_key_descriptors,
+            &db_key_descriptors,
+            Some("key_b"),
+        );
+        assert_eq!(aliases_from_key_descriptors(&result), vec!["key_c", "key_e", "key_f", "key_g"]);
+        Ok(())
+    }
+
+    #[test]
+    fn test_merge_and_sort_lists_with_filtering_and_dups() -> Result<()> {
+        let legacy_key_aliases = vec!["key_f", "key_a", "key_e", "key_b"];
+        let legacy_key_descriptors = create_key_descriptors_from_aliases(&legacy_key_aliases);
+        let db_key_aliases = vec!["key_d", "key_e", "key_g"];
+        let db_key_descriptors = create_key_descriptors_from_aliases(&db_key_aliases);
+        let result = merge_and_filter_key_entry_lists(
+            &legacy_key_descriptors,
+            &db_key_descriptors,
+            Some("key_c"),
+        );
+        assert_eq!(aliases_from_key_descriptors(&result), vec!["key_d", "key_e", "key_f", "key_g"]);
+        Ok(())
+    }
+
+    #[test]
+    fn test_merge_drops_legacy_entries_without_alias_when_filtering() -> Result<()> {
+        let mut legacy_key_descriptors = create_key_descriptors_from_aliases(&["key_b"]);
+        legacy_key_descriptors.push(KeyDescriptor {
+            domain: Domain::APP,
+            nspace: 0,
+            alias: None,
+            blob: None,
+        });
+        let result = merge_and_filter_key_entry_lists(&legacy_key_descriptors, &[], Some("key_a"));
+        assert_eq!(aliases_from_key_descriptors(&result), vec!["key_b"]);
+        Ok(())
+    }
 }