Upgrade remotely provisioned keys if necessary.

This change applies the key upgrade path to remotely provisioned (RKP)
keys when they are used as attestation keys during key generation.
Without it, RKP keys stop working after a device receives an update,
because of a version mismatch in KeyMint.

Test: atest keystore2_test && atest RemoteProvisionerUnitTests
Change-Id: I5dddc8fa1fe7fe9d7dd559b337089d607fcc735a
diff --git a/keystore2/src/attestation_key_utils.rs b/keystore2/src/attestation_key_utils.rs
index a8c1ca9..8354ba5 100644
--- a/keystore2/src/attestation_key_utils.rs
+++ b/keystore2/src/attestation_key_utils.rs
@@ -35,6 +35,7 @@
 /// handled quite differently, thus the different representations.
 pub enum AttestationKeyInfo {
     RemoteProvisioned {
+        key_id_guard: KeyIdGuard,
         attestation_key: AttestationKey,
         attestation_certs: Certificate,
     },
@@ -66,8 +67,12 @@
                 "Trying to get remotely provisioned attestation key."
             ))
             .map(|result| {
-                result.map(|(attestation_key, attestation_certs)| {
-                    AttestationKeyInfo::RemoteProvisioned { attestation_key, attestation_certs }
+                result.map(|(key_id_guard, attestation_key, attestation_certs)| {
+                    AttestationKeyInfo::RemoteProvisioned {
+                        key_id_guard,
+                        attestation_key,
+                        attestation_certs,
+                    }
                 })
             }),
         None => Ok(None),
diff --git a/keystore2/src/database.rs b/keystore2/src/database.rs
index 65ee7ae..28a2e9d 100644
--- a/keystore2/src/database.rs
+++ b/keystore2/src/database.rs
@@ -2049,6 +2049,41 @@
         .context("In get_attestation_pool_status: ")
     }
 
+    /// Returns the id of the first `Live` attestation key entry that matches the given
+    /// domain, namespace and `km_uuid`, or `None` if the query yields no rows.
+    fn query_kid_for_attestation_key_and_cert_chain(
+        &self,
+        tx: &Transaction,
+        domain: Domain,
+        namespace: i64,
+        km_uuid: &Uuid,
+    ) -> Result<Option<i64>> {
+        let mut stmt = tx.prepare(
+            "SELECT id
+             FROM persistent.keyentry
+             WHERE key_type = ?
+                   AND domain = ?
+                   AND namespace = ?
+                   AND state = ?
+                   AND km_uuid = ?;",
+        )?;
+        let rows = stmt
+            .query_map(
+                params![
+                    KeyType::Attestation,
+                    domain.0 as u32,
+                    namespace,
+                    KeyLifeCycle::Live,
+                    km_uuid
+                ],
+                |row| row.get(0),
+            )?
+            .collect::<rusqlite::Result<Vec<i64>>>()
+            .context("In query_kid_for_attestation_key_and_cert_chain: query failed.")?;
+        // Only the first matching id is used; an empty result means no key is assigned.
+        Ok(rows.first().copied())
+    }
+
     /// Fetches the private key and corresponding certificate chain assigned to a
     /// domain/namespace pair. Will either return nothing if the domain/namespace is
     /// not assigned, or one CertificateChain.
@@ -2057,7 +2092,7 @@
         domain: Domain,
         namespace: i64,
         km_uuid: &Uuid,
-    ) -> Result<Option<CertificateChain>> {
+    ) -> Result<Option<(KeyIdGuard, CertificateChain)>> {
         let _wp = wd::watch_millis("KeystoreDB::retrieve_attestation_key_and_cert_chain", 500);
 
         match domain {
@@ -2067,69 +2102,67 @@
                     .context(format!("Domain {:?} must be either App or SELinux.", domain));
             }
         }
-        self.with_transaction(TransactionBehavior::Deferred, |tx| {
-            let mut stmt = tx.prepare(
-                "SELECT subcomponent_type, blob
-             FROM persistent.blobentry
-             WHERE keyentryid IN
-                (SELECT id
-                 FROM persistent.keyentry
-                 WHERE key_type = ?
-                       AND domain = ?
-                       AND namespace = ?
-                       AND state = ?
-                       AND km_uuid = ?);",
-            )?;
-            let rows = stmt
-                .query_map(
-                    params![
-                        KeyType::Attestation,
-                        domain.0 as u32,
-                        namespace,
-                        KeyLifeCycle::Live,
-                        km_uuid
-                    ],
-                    |row| Ok((row.get(0)?, row.get(1)?)),
-                )?
-                .collect::<rusqlite::Result<Vec<(SubComponentType, Vec<u8>)>>>()
-                .context("query failed.")?;
-            if rows.is_empty() {
-                return Ok(None).no_gc();
-            } else if rows.len() != 3 {
-                return Err(KsError::sys()).context(format!(
-                    concat!(
-                        "Expected to get a single attestation",
-                        "key, cert, and cert chain for a total of 3 entries, but instead got {}."
-                    ),
-                    rows.len()
-                ));
-            }
-            let mut km_blob: Vec<u8> = Vec::new();
-            let mut cert_chain_blob: Vec<u8> = Vec::new();
-            let mut batch_cert_blob: Vec<u8> = Vec::new();
-            for row in rows {
-                let sub_type: SubComponentType = row.0;
-                match sub_type {
-                    SubComponentType::KEY_BLOB => {
-                        km_blob = row.1;
-                    }
-                    SubComponentType::CERT_CHAIN => {
-                        cert_chain_blob = row.1;
-                    }
-                    SubComponentType::CERT => {
-                        batch_cert_blob = row.1;
-                    }
-                    _ => Err(KsError::sys()).context("Unknown or incorrect subcomponent type.")?,
+
+        let tx = self.conn.unchecked_transaction().context(
+            "In retrieve_attestation_key_and_cert_chain: Failed to initialize transaction.",
+        )?;
+        let key_id: i64;
+        match self.query_kid_for_attestation_key_and_cert_chain(&tx, domain, namespace, km_uuid)? {
+            None => return Ok(None),
+            Some(kid) => key_id = kid,
+        }
+        tx.commit()
+            .context("In retrieve_attestation_key_and_cert_chain: Failed to commit keyid query")?;
+        let key_id_guard = KEY_ID_LOCK.get(key_id);
+        let tx = self.conn.unchecked_transaction().context(
+            "In retrieve_attestation_key_and_cert_chain: Failed to initialize transaction.",
+        )?;
+        let mut stmt = tx.prepare(
+            "SELECT subcomponent_type, blob
+            FROM persistent.blobentry
+            WHERE keyentryid = ?;",
+        )?;
+        let rows = stmt
+            .query_map(params![key_id_guard.id()], |row| Ok((row.get(0)?, row.get(1)?)))?
+            .collect::<rusqlite::Result<Vec<(SubComponentType, Vec<u8>)>>>()
+            .context("query failed.")?;
+        if rows.is_empty() {
+            return Ok(None);
+        } else if rows.len() != 3 {
+            return Err(KsError::sys()).context(format!(
+                concat!(
+                    "Expected to get a single attestation",
+                    "key, cert, and cert chain for a total of 3 entries, but instead got {}."
+                ),
+                rows.len()
+            ));
+        }
+        let mut km_blob: Vec<u8> = Vec::new();
+        let mut cert_chain_blob: Vec<u8> = Vec::new();
+        let mut batch_cert_blob: Vec<u8> = Vec::new();
+        for row in rows {
+            let sub_type: SubComponentType = row.0;
+            match sub_type {
+                SubComponentType::KEY_BLOB => {
+                    km_blob = row.1;
                 }
+                SubComponentType::CERT_CHAIN => {
+                    cert_chain_blob = row.1;
+                }
+                SubComponentType::CERT => {
+                    batch_cert_blob = row.1;
+                }
+                _ => Err(KsError::sys()).context("Unknown or incorrect subcomponent type.")?,
             }
-            Ok(Some(CertificateChain {
+        }
+        Ok(Some((
+            key_id_guard,
+            CertificateChain {
                 private_key: ZVec::try_from(km_blob)?,
                 batch_cert: batch_cert_blob,
                 cert_chain: cert_chain_blob,
-            }))
-            .no_gc()
-        })
-        .context("In retrieve_attestation_key_and_cert_chain:")
+            },
+        )))
     }
 
     /// Updates the alias column of the given key id `newid` with the given alias,
@@ -3516,7 +3549,7 @@
         let chain =
             db.retrieve_attestation_key_and_cert_chain(Domain::APP, namespace, &KEYSTORE_UUID)?;
         assert!(chain.is_some());
-        let cert_chain = chain.unwrap();
+        let (_, cert_chain) = chain.unwrap();
         assert_eq!(cert_chain.private_key.to_vec(), loaded_values.priv_key);
         assert_eq!(cert_chain.batch_cert, loaded_values.batch_cert);
         assert_eq!(cert_chain.cert_chain, loaded_values.cert_chain);
@@ -3611,7 +3644,7 @@
         let mut cert_chain =
             db.retrieve_attestation_key_and_cert_chain(Domain::APP, namespace, &KEYSTORE_UUID)?;
         assert!(cert_chain.is_some());
-        let value = cert_chain.unwrap();
+        let (_, value) = cert_chain.unwrap();
         assert_eq!(entry_values.batch_cert, value.batch_cert);
         assert_eq!(entry_values.cert_chain, value.cert_chain);
         assert_eq!(entry_values.priv_key, value.private_key.to_vec());
diff --git a/keystore2/src/remote_provisioning.rs b/keystore2/src/remote_provisioning.rs
index 639fe1e..be23ae5 100644
--- a/keystore2/src/remote_provisioning.rs
+++ b/keystore2/src/remote_provisioning.rs
@@ -45,7 +45,7 @@
 use std::collections::BTreeMap;
 use std::sync::atomic::{AtomicBool, Ordering};
 
-use crate::database::{CertificateChain, KeystoreDB, Uuid};
+use crate::database::{CertificateChain, KeyIdGuard, KeystoreDB, Uuid};
 use crate::error::{self, map_or_log_err, map_rem_prov_error, Error};
 use crate::globals::{get_keymint_device, get_remotely_provisioned_component, DB};
 use crate::metrics_store::log_rkp_error_stats;
@@ -73,6 +73,11 @@
         Self { security_level, km_uuid, is_hal_present: AtomicBool::new(true) }
     }
 
+    /// Returns a copy of `km_uuid`, the uuid for the KM instance attached to this RemProvState.
+    pub fn get_uuid(&self) -> Uuid {
+        self.km_uuid
+    }
+
     /// Checks if remote provisioning is enabled and partially caches the result. On a hybrid system
     /// remote provisioning can flip from being disabled to enabled depending on responses from the
     /// server, so unfortunately caching the presence or absence of the HAL is not enough to fully
@@ -121,7 +126,7 @@
         caller_uid: u32,
         params: &[KeyParameter],
         db: &mut KeystoreDB,
-    ) -> Result<Option<(AttestationKey, Certificate)>> {
+    ) -> Result<Option<(KeyIdGuard, AttestationKey, Certificate)>> {
         if !self.is_asymmetric_key(params) || !self.check_rem_prov_enabled(db)? {
             // There is no remote provisioning component for this security level on the
             // device. Return None so the underlying KM instance knows to use its
@@ -142,7 +147,8 @@
                     Ok(None)
                 }
                 Ok(v) => match v {
-                    Some(cert_chain) => Ok(Some((
+                    Some((guard, cert_chain)) => Ok(Some((
+                        guard,
                         AttestationKey {
                             keyBlob: cert_chain.private_key.to_vec(),
                             attestKeyParams: vec![],
@@ -433,7 +439,7 @@
     caller_uid: u32,
     db: &mut KeystoreDB,
     km_uuid: &Uuid,
-) -> Result<Option<CertificateChain>> {
+) -> Result<Option<(KeyIdGuard, CertificateChain)>> {
     match domain {
         Domain::APP => {
             // Attempt to get an Attestation Key once. If it fails, then the app doesn't
@@ -458,7 +464,7 @@
                             "key and failed silently. Something is very wrong."
                         ))
                     },
-                    |cert_chain| Ok(Some(cert_chain)),
+                    |(guard, cert_chain)| Ok(Some((guard, cert_chain))),
                 )
         }
         _ => Ok(None),
@@ -471,12 +477,12 @@
     caller_uid: u32,
     db: &mut KeystoreDB,
     km_uuid: &Uuid,
-) -> Result<Option<CertificateChain>> {
-    let cert_chain = db
+) -> Result<Option<(KeyIdGuard, CertificateChain)>> {
+    let guard_and_chain = db
         .retrieve_attestation_key_and_cert_chain(domain, caller_uid as i64, km_uuid)
         .context("In get_rem_prov_attest_key_helper: Failed to retrieve a key + cert chain")?;
-    match cert_chain {
-        Some(cert_chain) => Ok(Some(cert_chain)),
+    match guard_and_chain {
+        Some((guard, cert_chain)) => Ok(Some((guard, cert_chain))),
         // Either this app needs to be assigned a key, or the pool is empty. An error will
         // be thrown if there is no key available to assign. This will indicate that the app
         // should be nudged to provision more keys so keystore can retry.
@@ -598,10 +604,11 @@
             .context(format!("In get_attestation_key: unknown irpc id '{}'", irpc_id))?;
         let (_, _, km_uuid) = get_keymint_device(sec_level)?;
 
-        let cert_chain = get_rem_prov_attest_key(Domain::APP, caller_uid as u32, db, &km_uuid)
-            .context("In get_attestation_key")?;
-        match cert_chain {
-            Some(chain) => Ok(RemotelyProvisionedKey {
+        let guard_and_cert_chain =
+            get_rem_prov_attest_key(Domain::APP, caller_uid as u32, db, &km_uuid)
+                .context("In get_attestation_key")?;
+        match guard_and_cert_chain {
+            Some((_, chain)) => Ok(RemotelyProvisionedKey {
                 keyBlob: chain.private_key.to_vec(),
                 encodedCertChain: chain.cert_chain,
             }),
diff --git a/keystore2/src/security_level.rs b/keystore2/src/security_level.rs
index 4cf41c5..8574244 100644
--- a/keystore2/src/security_level.rs
+++ b/keystore2/src/security_level.rs
@@ -319,7 +319,7 @@
                 &*self.keymint,
                 key_id_guard,
                 &km_blob,
-                &blob_metadata,
+                blob_metadata.km_uuid().copied(),
                 operation_parameters,
                 |blob| loop {
                     match map_km_error({
@@ -557,7 +557,7 @@
                     &*self.keymint,
                     Some(key_id_guard),
                     &KeyBlob::Ref(&blob),
-                    &blob_metadata,
+                    blob_metadata.km_uuid().copied(),
                     &params,
                     |blob| {
                         let attest_key = Some(AttestationKey {
@@ -579,23 +579,40 @@
                 )
                 .context("In generate_key: Using user generated attestation key.")
                 .map(|(result, _)| result),
-            Some(AttestationKeyInfo::RemoteProvisioned { attestation_key, attestation_certs }) => {
-                map_km_error({
-                    let _wp = self.watch_millis(
-                        concat!(
-                            "In KeystoreSecurityLevel::generate_key (RemoteProvisioned): ",
-                            "calling generate_key.",
-                        ),
-                        5000, // Generate can take a little longer.
-                    );
-                    self.keymint.generateKey(&params, Some(&attestation_key))
-                })
+            Some(AttestationKeyInfo::RemoteProvisioned {
+                key_id_guard,
+                attestation_key,
+                attestation_certs,
+            }) => self
+                .upgrade_keyblob_if_required_with(
+                    &*self.keymint,
+                    Some(key_id_guard),
+                    &KeyBlob::Ref(&attestation_key.keyBlob),
+                    Some(self.rem_prov_state.get_uuid()),
+                    &[],
+                    |blob| {
+                        map_km_error({
+                            let _wp = self.watch_millis(
+                                concat!(
+                                    "In KeystoreSecurityLevel::generate_key (RemoteProvisioned): ",
+                                    "calling generate_key.",
+                                ),
+                                5000, // Generate can take a little longer.
+                            );
+                            let dynamic_attest_key = Some(AttestationKey {
+                                keyBlob: blob.to_vec(),
+                                attestKeyParams: vec![],
+                                issuerSubjectName: attestation_key.issuerSubjectName.clone(),
+                            });
+                            self.keymint.generateKey(&params, dynamic_attest_key.as_ref())
+                        })
+                    },
+                )
                 .context("While generating Key with remote provisioned attestation key.")
-                .map(|mut creation_result| {
-                    creation_result.certificateChain.push(attestation_certs);
-                    creation_result
-                })
-            }
+                .map(|(mut result, _)| {
+                    result.certificateChain.push(attestation_certs);
+                    result
+                }),
             None => map_km_error({
                 let _wp = self.watch_millis(
                     concat!(
@@ -781,7 +798,7 @@
                 &*self.keymint,
                 Some(wrapping_key_id_guard),
                 &wrapping_key_blob,
-                &wrapping_blob_metadata,
+                wrapping_blob_metadata.km_uuid().copied(),
                 &[],
                 |wrapping_blob| {
                     let _wp = self.watch_millis(
@@ -807,7 +824,7 @@
 
     fn store_upgraded_keyblob(
         key_id_guard: KeyIdGuard,
-        km_uuid: Option<&Uuid>,
+        km_uuid: Option<Uuid>,
         key_blob: &KeyBlob,
         upgraded_blob: &[u8],
     ) -> Result<()> {
@@ -817,7 +834,7 @@
 
         let mut new_blob_metadata = new_blob_metadata.unwrap_or_default();
         if let Some(uuid) = km_uuid {
-            new_blob_metadata.add(BlobMetaEntry::KmUuid(*uuid));
+            new_blob_metadata.add(BlobMetaEntry::KmUuid(uuid));
         }
 
         DB.with(|db| {
@@ -837,7 +854,7 @@
         km_dev: &dyn IKeyMintDevice,
         mut key_id_guard: Option<KeyIdGuard>,
         key_blob: &KeyBlob,
-        blob_metadata: &BlobMetaData,
+        km_uuid: Option<Uuid>,
         params: &[KeyParameter],
         f: F,
     ) -> Result<(T, Option<Vec<u8>>)>
@@ -853,13 +870,9 @@
                 if key_id_guard.is_some() {
                     // Unwrap cannot panic, because the is_some was true.
                     let kid = key_id_guard.take().unwrap();
-                    Self::store_upgraded_keyblob(
-                        kid,
-                        blob_metadata.km_uuid(),
-                        key_blob,
-                        upgraded_blob,
+                    Self::store_upgraded_keyblob(kid, km_uuid, key_blob, upgraded_blob).context(
+                        "In upgrade_keyblob_if_required_with: store_upgraded_keyblob failed",
                     )
-                    .context("In upgrade_keyblob_if_required_with: store_upgraded_keyblob failed")
                 } else {
                     Ok(())
                 }
@@ -872,11 +885,10 @@
         // upgrade was performed above and if one was given in the first place.
         if key_blob.force_reencrypt() {
             if let Some(kid) = key_id_guard {
-                Self::store_upgraded_keyblob(kid, blob_metadata.km_uuid(), key_blob, key_blob)
-                    .context(concat!(
-                        "In upgrade_keyblob_if_required_with: ",
-                        "store_upgraded_keyblob failed in forced reencrypt"
-                    ))?;
+                Self::store_upgraded_keyblob(kid, km_uuid, key_blob, key_blob).context(concat!(
+                    "In upgrade_keyblob_if_required_with: ",
+                    "store_upgraded_keyblob failed in forced reencrypt"
+                ))?;
             }
         }
         Ok((v, upgraded_blob))