keystore2: Add utils to fetch keys from RKPD

Bug: 261214100
Test: keystore2_test
Change-Id: I87ec83dd700b4e3e24c65ce0650cd5643000a390
diff --git a/keystore2/Android.bp b/keystore2/Android.bp
index 51ce9d1..fa80563 100644
--- a/keystore2/Android.bp
+++ b/keystore2/Android.bp
@@ -41,8 +41,10 @@
         "android.security.maintenance-rust",
         "android.security.metrics-rust",
         "android.security.remoteprovisioning-rust",
+        "android.security.rkp_aidl-rust",
         "libanyhow",
         "libbinder_rs",
+        "libfutures",
         "libkeystore2_aaid-rust",
         "libkeystore2_apc_compat-rust",
         "libkeystore2_crypto_rust",
diff --git a/keystore2/src/globals.rs b/keystore2/src/globals.rs
index 425812f..ed59578 100644
--- a/keystore2/src/globals.rs
+++ b/keystore2/src/globals.rs
@@ -441,13 +441,12 @@
 static REMOTE_PROVISIONING_HAL_SERVICE_NAME: &str =
     "android.hardware.security.keymint.IRemotelyProvisionedComponent";
 
-fn connect_remotely_provisioned_component(
-    security_level: &SecurityLevel,
-) -> Result<Strong<dyn IRemotelyProvisionedComponent>> {
+/// Get the service name of a remotely provisioned component corresponding to given security level.
+pub fn get_remotely_provisioned_component_name(security_level: &SecurityLevel) -> Result<String> {
     let remotely_prov_instances =
         get_aidl_instances("android.hardware.security.keymint", 1, "IRemotelyProvisionedComponent");
 
-    let service_name = match *security_level {
+    match *security_level {
         SecurityLevel::TRUSTED_ENVIRONMENT => {
             if remotely_prov_instances.iter().any(|instance| *instance == "default") {
                 Some(format!("{}/default", REMOTE_PROVISIONING_HAL_SERVICE_NAME))
@@ -465,8 +464,13 @@
         _ => None,
     }
     .ok_or(Error::Km(ErrorCode::HARDWARE_TYPE_UNAVAILABLE))
-    .context(ks_err!())?;
+    .context(ks_err!())
+}
 
+fn connect_remotely_provisioned_component(
+    security_level: &SecurityLevel,
+) -> Result<Strong<dyn IRemotelyProvisionedComponent>> {
+    let service_name = get_remotely_provisioned_component_name(security_level)?;
     let rem_prov_hal: Strong<dyn IRemotelyProvisionedComponent> =
         map_binder_status_code(binder::get_interface(&service_name))
             .context(ks_err!("Trying to connect to RemotelyProvisionedComponent service."))?;
diff --git a/keystore2/src/lib.rs b/keystore2/src/lib.rs
index 0b830be..9794889 100644
--- a/keystore2/src/lib.rs
+++ b/keystore2/src/lib.rs
@@ -38,6 +38,7 @@
 pub mod permission;
 pub mod raw_device;
 pub mod remote_provisioning;
+pub mod rkpd_client;
 pub mod security_level;
 pub mod service;
 pub mod shared_secret_negotiation;
diff --git a/keystore2/src/rkpd_client.rs b/keystore2/src/rkpd_client.rs
new file mode 100644
index 0000000..2d1b23b
--- /dev/null
+++ b/keystore2/src/rkpd_client.rs
@@ -0,0 +1,456 @@
+// Copyright 2022, The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Helper wrapper around RKPD interface.
+//!
+//! RKPD reports results asynchronously through binder callbacks. Each request creates a
+//! `oneshot` channel, hands the sending half to a callback object and waits on the receiving
+//! half; the public entry points block the calling thread until a result arrives.
+
+use crate::error::{map_binder_status_code, map_or_log_err, Error, ErrorCode};
+use crate::globals::get_remotely_provisioned_component_name;
+use crate::ks_err;
+use crate::utils::watchdog as wd;
+use android_hardware_security_keymint::aidl::android::hardware::security::keymint::SecurityLevel::SecurityLevel;
+use android_security_rkp_aidl::aidl::android::security::rkp::{
+    IGetKeyCallback::BnGetKeyCallback, IGetKeyCallback::IGetKeyCallback,
+    IGetRegistrationCallback::BnGetRegistrationCallback,
+    IGetRegistrationCallback::IGetRegistrationCallback, IRegistration::IRegistration,
+    IRemoteProvisioning::IRemoteProvisioning, RemotelyProvisionedKey::RemotelyProvisionedKey,
+};
+use android_security_rkp_aidl::binder::{BinderFeatures, Interface, Strong};
+use anyhow::{Context, Result};
+use futures::channel::oneshot;
+use futures::executor::block_on;
+use std::sync::Mutex;
+
+/// A `oneshot::Sender` that can be used from binder callback methods, which only get `&self`.
+///
+/// `oneshot::Sender::send` consumes the sender, so it is kept in a `Mutex<Option<_>>` and taken
+/// on first use. Only the first call to [`SafeSender::send`] delivers a value; later calls are
+/// no-ops.
+struct SafeSender<T> {
+    inner: Mutex<Option<oneshot::Sender<T>>>,
+}
+
+impl<T> SafeSender<T> {
+    fn new(sender: oneshot::Sender<T>) -> Self {
+        Self { inner: Mutex::new(Some(sender)) }
+    }
+
+    /// Delivers `value` to the receiver unless a value has already been sent.
+    ///
+    /// Sending fails if the receiver has been dropped, e.g. because the waiting future was
+    /// abandoned. That is not actionable from inside a binder callback, so it is logged instead
+    /// of being turned into a panic.
+    fn send(&self, value: T) {
+        if let Some(tx) = self.inner.lock().unwrap().take() {
+            if tx.send(value).is_err() {
+                log::error!("SafeSender::send() failed: receiver was dropped.");
+            }
+        }
+    }
+}
+
+type RegistrationSender = oneshot::Sender<Result<binder::Strong<dyn IRegistration>>>;
+
+/// Native `IGetRegistrationCallback` that forwards the first outcome it receives
+/// (registration, cancellation or error) to a `oneshot` channel.
+struct GetRegistrationCallback {
+    registration_tx: SafeSender<Result<binder::Strong<dyn IRegistration>>>,
+}
+
+impl GetRegistrationCallback {
+    pub fn new_native_binder(
+        registration_tx: RegistrationSender,
+    ) -> Result<Strong<dyn IGetRegistrationCallback>> {
+        let result: Self =
+            GetRegistrationCallback { registration_tx: SafeSender::new(registration_tx) };
+        Ok(BnGetRegistrationCallback::new_binder(result, BinderFeatures::default()))
+    }
+    fn on_success(&self, registration: &binder::Strong<dyn IRegistration>) -> Result<()> {
+        self.registration_tx.send(Ok(registration.clone()));
+        Ok(())
+    }
+    fn on_cancel(&self) -> Result<()> {
+        self.registration_tx.send(
+            Err(Error::Km(ErrorCode::OPERATION_CANCELLED))
+                .context(ks_err!("GetRegistrationCallback cancelled.")),
+        );
+        Ok(())
+    }
+    fn on_error(&self, error: &str) -> Result<()> {
+        self.registration_tx.send(
+            Err(Error::Km(ErrorCode::UNKNOWN_ERROR))
+                .context(ks_err!("GetRegistrationCallback failed: {:?}", error)),
+        );
+        Ok(())
+    }
+}
+
+impl Interface for GetRegistrationCallback {}
+
+impl IGetRegistrationCallback for GetRegistrationCallback {
+    fn onSuccess(&self, registration: &Strong<dyn IRegistration>) -> binder::Result<()> {
+        let _wp = wd::watch_millis("IGetRegistrationCallback::onSuccess", 500);
+        map_or_log_err(self.on_success(registration), Ok)
+    }
+    fn onCancel(&self) -> binder::Result<()> {
+        let _wp = wd::watch_millis("IGetRegistrationCallback::onCancel", 500);
+        map_or_log_err(self.on_cancel(), Ok)
+    }
+    fn onError(&self, error: &str) -> binder::Result<()> {
+        let _wp = wd::watch_millis("IGetRegistrationCallback::onError", 500);
+        map_or_log_err(self.on_error(error), Ok)
+    }
+}
+
+/// Makes a new connection to the `IRegistration` of the remotely provisioned component that
+/// corresponds to `security_level`.
+async fn get_rkpd_registration(
+    security_level: &SecurityLevel,
+) -> Result<binder::Strong<dyn IRegistration>> {
+    let remote_provisioning: Strong<dyn IRemoteProvisioning> =
+        map_binder_status_code(binder::get_interface("remote_provisioning"))
+            .context(ks_err!("Trying to connect to IRemoteProvisioning service."))?;
+
+    let rpc_name = get_remotely_provisioned_component_name(security_level)
+        .context(ks_err!("Trying to get IRPC name."))?;
+
+    let (tx, rx) = oneshot::channel();
+    let cb = GetRegistrationCallback::new_native_binder(tx)
+        .context(ks_err!("Trying to create a IGetRegistrationCallback."))?;
+
+    remote_provisioning
+        .getRegistration(&rpc_name, &cb)
+        .context(ks_err!("Trying to get registration."))?;
+
+    // `rx.await` fails only if the sender is dropped without a value being sent; report that as
+    // an error instead of panicking.
+    rx.await.context(ks_err!("Waiting for IGetRegistrationCallback."))?
+}
+
+type KeySender = oneshot::Sender<Result<RemotelyProvisionedKey>>;
+
+/// Native `IGetKeyCallback` that forwards the first outcome it receives (key, cancellation or
+/// error) to a `oneshot` channel.
+struct GetKeyCallback {
+    key_tx: SafeSender<Result<RemotelyProvisionedKey>>,
+}
+
+impl GetKeyCallback {
+    pub fn new_native_binder(key_tx: KeySender) -> Result<Strong<dyn IGetKeyCallback>> {
+        let result: Self = GetKeyCallback { key_tx: SafeSender::new(key_tx) };
+        Ok(BnGetKeyCallback::new_binder(result, BinderFeatures::default()))
+    }
+    fn on_success(&self, key: &RemotelyProvisionedKey) -> Result<()> {
+        self.key_tx.send(Ok(RemotelyProvisionedKey {
+            keyBlob: key.keyBlob.clone(),
+            encodedCertChain: key.encodedCertChain.clone(),
+        }));
+        Ok(())
+    }
+    fn on_cancel(&self) -> Result<()> {
+        self.key_tx.send(
+            Err(Error::Km(ErrorCode::OPERATION_CANCELLED))
+                .context(ks_err!("GetKeyCallback cancelled.")),
+        );
+        Ok(())
+    }
+    fn on_error(&self, error: &str) -> Result<()> {
+        self.key_tx.send(
+            Err(Error::Km(ErrorCode::UNKNOWN_ERROR))
+                .context(ks_err!("GetKeyCallback failed: {:?}", error)),
+        );
+        Ok(())
+    }
+}
+
+impl Interface for GetKeyCallback {}
+
+impl IGetKeyCallback for GetKeyCallback {
+    fn onSuccess(&self, key: &RemotelyProvisionedKey) -> binder::Result<()> {
+        let _wp = wd::watch_millis("IGetKeyCallback::onSuccess", 500);
+        map_or_log_err(self.on_success(key), Ok)
+    }
+    fn onCancel(&self) -> binder::Result<()> {
+        let _wp = wd::watch_millis("IGetKeyCallback::onCancel", 500);
+        map_or_log_err(self.on_cancel(), Ok)
+    }
+    fn onError(&self, error: &str) -> binder::Result<()> {
+        let _wp = wd::watch_millis("IGetKeyCallback::onError", 500);
+        map_or_log_err(self.on_error(error), Ok)
+    }
+}
+
+/// Asks RKPD for an attestation key for `caller_uid` on the given security level.
+async fn get_rkpd_attestation_key_async(
+    security_level: &SecurityLevel,
+    caller_uid: u32,
+) -> Result<RemotelyProvisionedKey> {
+    let registration = get_rkpd_registration(security_level)
+        .await
+        .context(ks_err!("Trying to get to IRegistration service."))?;
+
+    // `IRegistration::getKey` takes the UID as an `i32`; reject UIDs that do not fit instead of
+    // panicking.
+    let caller_uid = i32::try_from(caller_uid)
+        .context(ks_err!("Caller UID {} does not fit in an i32.", caller_uid))?;
+
+    let (tx, rx) = oneshot::channel();
+    let cb = GetKeyCallback::new_native_binder(tx)
+        .context(ks_err!("Trying to create a IGetKeyCallback."))?;
+
+    registration
+        .getKey(caller_uid, &cb)
+        .context(ks_err!("Trying to get key."))?;
+
+    // `rx.await` fails only if the sender is dropped without a value being sent; report that as
+    // an error instead of panicking.
+    rx.await.context(ks_err!("Waiting for IGetKeyCallback."))?
+}
+
+/// Passes `key_blob` and its upgraded counterpart `upgraded_blob` to RKPD's `storeUpgradedKey`.
+async fn store_rkpd_attestation_key_async(
+    security_level: &SecurityLevel,
+    key_blob: &[u8],
+    upgraded_blob: &[u8],
+) -> Result<()> {
+    let registration = get_rkpd_registration(security_level)
+        .await
+        .context(ks_err!("Trying to get to IRegistration service."))?;
+
+    registration
+        .storeUpgradedKey(key_blob, upgraded_blob)
+        .context(ks_err!("Failed to store upgraded blob with RKPD."))?;
+    Ok(())
+}
+
+/// Get attestation key from RKPD.
+///
+/// Blocks the calling thread until RKPD reports a result through `IGetKeyCallback`.
+pub fn get_rkpd_attestation_key(
+    security_level: &SecurityLevel,
+    caller_uid: u32,
+) -> Result<RemotelyProvisionedKey> {
+    let _wp = wd::watch_millis("Calling get_rkpd_attestation_key()", 500);
+    block_on(get_rkpd_attestation_key_async(security_level, caller_uid))
+}
+
+/// Store attestation key in RKPD.
+///
+/// Blocks the calling thread while the registration is fetched and the blob is handed over.
+pub fn store_rkpd_attestation_key(
+    security_level: &SecurityLevel,
+    key_blob: &[u8],
+    upgraded_blob: &[u8],
+) -> Result<()> {
+    let _wp = wd::watch_millis("Calling store_rkpd_attestation_key()", 500);
+    block_on(store_rkpd_attestation_key_async(security_level, key_blob, upgraded_blob))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use android_security_rkp_aidl::aidl::android::security::rkp::IRegistration::BnRegistration;
+    use std::sync::Arc;
+
+    #[derive(Default)]
+    struct MockRegistrationValues {
+        _key: RemotelyProvisionedKey,
+    }
+
+    #[derive(Default)]
+    struct MockRegistration(Arc<Mutex<MockRegistrationValues>>);
+
+    impl MockRegistration {
+        pub fn new_native_binder() -> Strong<dyn IRegistration> {
+            let result: Self = Default::default();
+            BnRegistration::new_binder(result, BinderFeatures::default())
+        }
+    }
+
+    impl Interface for MockRegistration {}
+
+    impl IRegistration for MockRegistration {
+        fn getKey(&self, _: i32, _: &Strong<dyn IGetKeyCallback>) -> binder::Result<()> {
+            todo!()
+        }
+
+        fn cancelGetKey(&self, _: &Strong<dyn IGetKeyCallback>) -> binder::Result<()> {
+            todo!()
+        }
+
+        fn storeUpgradedKey(&self, _: &[u8], _: &[u8]) -> binder::Result<()> {
+            todo!()
+        }
+    }
+
+    fn get_mock_registration() -> Result<binder::Strong<dyn IRegistration>> {
+        let (tx, rx) = oneshot::channel();
+        let cb = GetRegistrationCallback::new_native_binder(tx).unwrap();
+        let mock_registration = MockRegistration::new_native_binder();
+
+        assert!(cb.onSuccess(&mock_registration).is_ok());
+        block_on(rx).unwrap()
+    }
+
+    #[test]
+    fn test_safe_sender_delivers_first_value_only() {
+        let (tx, rx) = oneshot::channel();
+        let sender = SafeSender::new(tx);
+        sender.send(1u32);
+        sender.send(2u32);
+        assert_eq!(block_on(rx).unwrap(), 1);
+    }
+
+    #[test]
+    fn test_safe_sender_receiver_dropped() {
+        let (tx, rx) = oneshot::channel::<u32>();
+        let sender = SafeSender::new(tx);
+        drop(rx);
+        // Must not panic even though nobody is waiting for the value any more.
+        sender.send(1);
+    }
+
+    #[test]
+    fn test_get_registration_cb_success() {
+        let registration = get_mock_registration();
+        assert!(registration.is_ok());
+    }
+
+    #[test]
+    fn test_get_registration_cb_cancel() {
+        let (tx, rx) = oneshot::channel();
+        let cb = GetRegistrationCallback::new_native_binder(tx).unwrap();
+        assert!(cb.onCancel().is_ok());
+
+        let result = block_on(rx).unwrap();
+        assert_eq!(
+            result.unwrap_err().downcast::<Error>().unwrap(),
+            Error::Km(ErrorCode::OPERATION_CANCELLED)
+        );
+    }
+
+    #[test]
+    fn test_get_registration_cb_error() {
+        let (tx, rx) = oneshot::channel();
+        let cb = GetRegistrationCallback::new_native_binder(tx).unwrap();
+        assert!(cb.onError("error").is_ok());
+
+        let result = block_on(rx).unwrap();
+        assert_eq!(
+            result.unwrap_err().downcast::<Error>().unwrap(),
+            Error::Km(ErrorCode::UNKNOWN_ERROR)
+        );
+    }
+
+    #[test]
+    fn test_get_registration_cb_receiver_dropped() {
+        let (tx, rx) = oneshot::channel();
+        let cb = GetRegistrationCallback::new_native_binder(tx).unwrap();
+        drop(rx);
+        assert!(cb.onCancel().is_ok());
+    }
+
+    #[test]
+    fn test_get_key_cb_success() {
+        let mock_key =
+            RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
+        let (tx, rx) = oneshot::channel();
+        let cb = GetKeyCallback::new_native_binder(tx).unwrap();
+        assert!(cb.onSuccess(&mock_key).is_ok());
+
+        let key = block_on(rx).unwrap().unwrap();
+        assert_eq!(key, mock_key);
+    }
+
+    #[test]
+    fn test_get_key_cb_cancel() {
+        let (tx, rx) = oneshot::channel();
+        let cb = GetKeyCallback::new_native_binder(tx).unwrap();
+        assert!(cb.onCancel().is_ok());
+
+        let result = block_on(rx).unwrap();
+        assert_eq!(
+            result.unwrap_err().downcast::<Error>().unwrap(),
+            Error::Km(ErrorCode::OPERATION_CANCELLED)
+        );
+    }
+
+    #[test]
+    fn test_get_key_cb_error() {
+        let (tx, rx) = oneshot::channel();
+        let cb = GetKeyCallback::new_native_binder(tx).unwrap();
+        assert!(cb.onError("error").is_ok());
+
+        let result = block_on(rx).unwrap();
+        assert_eq!(
+            result.unwrap_err().downcast::<Error>().unwrap(),
+            Error::Km(ErrorCode::UNKNOWN_ERROR)
+        );
+    }
+
+    #[test]
+    fn test_get_key_cb_receiver_dropped() {
+        let (tx, rx) = oneshot::channel();
+        let cb = GetKeyCallback::new_native_binder(tx).unwrap();
+        drop(rx);
+        assert!(cb.onError("error").is_ok());
+    }
+
+    #[test]
+    #[ignore]
+    fn test_get_rkpd_attestation_key() {
+        let key = get_rkpd_attestation_key(&SecurityLevel::TRUSTED_ENVIRONMENT, 0).unwrap();
+        assert!(!key.keyBlob.is_empty());
+        assert!(!key.encodedCertChain.is_empty());
+    }
+
+    #[test]
+    #[ignore]
+    fn test_get_rkpd_attestation_key_same_caller() {
+        let sec_level = SecurityLevel::TRUSTED_ENVIRONMENT;
+        let caller_uid = 0;
+
+        // Multiple calls should return the same key.
+        let first_key = get_rkpd_attestation_key(&sec_level, caller_uid).unwrap();
+        let second_key = get_rkpd_attestation_key(&sec_level, caller_uid).unwrap();
+
+        assert_eq!(first_key.keyBlob, second_key.keyBlob);
+        assert_eq!(first_key.encodedCertChain, second_key.encodedCertChain);
+    }
+
+    #[test]
+    #[ignore]
+    fn test_get_rkpd_attestation_key_different_caller() {
+        let sec_level = SecurityLevel::TRUSTED_ENVIRONMENT;
+
+        // Different callers should be getting different keys.
+        let first_key = get_rkpd_attestation_key(&sec_level, 1).unwrap();
+        let second_key = get_rkpd_attestation_key(&sec_level, 2).unwrap();
+
+        assert_ne!(first_key.keyBlob, second_key.keyBlob);
+        assert_ne!(first_key.encodedCertChain, second_key.encodedCertChain);
+    }
+
+    #[test]
+    #[ignore]
+    fn test_store_rkpd_attestation_key() {
+        let sec_level = SecurityLevel::TRUSTED_ENVIRONMENT;
+        let key = get_rkpd_attestation_key(&SecurityLevel::TRUSTED_ENVIRONMENT, 0).unwrap();
+
+        assert!(store_rkpd_attestation_key(&sec_level, &key.keyBlob, &key.keyBlob).is_ok());
+    }
+}