keystore2: Add timeout to RKPD get key call

Also replace libfutures with libtokio, as the former doesn't have
support for timeouts.

Bug: 264921138
Test: keystore2_test
Change-Id: I97c9749e93b2d001afe5d17bda8c665f884b0e05
diff --git a/keystore2/Android.bp b/keystore2/Android.bp
index fa80563..57038df 100644
--- a/keystore2/Android.bp
+++ b/keystore2/Android.bp
@@ -44,7 +44,6 @@
         "android.security.rkp_aidl-rust",
         "libanyhow",
         "libbinder_rs",
-        "libfutures",
         "libkeystore2_aaid-rust",
         "libkeystore2_apc_compat-rust",
         "libkeystore2_crypto_rust",
@@ -60,6 +59,7 @@
         "libserde",
         "libserde_cbor",
         "libthiserror",
+        "libtokio",
     ],
     shared_libs: [
         "libcutils",
diff --git a/keystore2/src/rkpd_client.rs b/keystore2/src/rkpd_client.rs
index d611678..c4b0686 100644
--- a/keystore2/src/rkpd_client.rs
+++ b/keystore2/src/rkpd_client.rs
@@ -13,8 +13,9 @@
 // limitations under the License.
 
 //! Helper wrapper around RKPD interface.
+// TODO(b/264891956): Return RKP specific errors.
 
-use crate::error::{map_binder_status_code, Error, ErrorCode};
+use crate::error::{map_binder_status_code, Error};
 use crate::globals::get_remotely_provisioned_component_name;
 use crate::ks_err;
 use crate::utils::watchdog as wd;
@@ -29,10 +30,21 @@
     RemotelyProvisionedKey::RemotelyProvisionedKey,
 };
 use android_security_rkp_aidl::binder::{BinderFeatures, Interface, Strong};
+use android_system_keystore2::aidl::android::system::keystore2::ResponseCode::ResponseCode;
 use anyhow::{Context, Result};
-use futures::channel::oneshot;
-use futures::executor::block_on;
 use std::sync::Mutex;
+use std::time::Duration;
+use tokio::sync::oneshot;
+use tokio::time::timeout;
+
+// Normally, we block indefinitely when making calls outside of keystore and rely on watchdog to
+// report deadlocks. However, RKPD is mainline updatable. Also, calls to RKPD may wait on network
+// for certificates. So, we err on the side of caution and timeout instead.
+static RKPD_TIMEOUT: Duration = Duration::from_secs(10);
+
+fn tokio_rt() -> tokio::runtime::Runtime {
+    tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap()
+}
 
 /// Thread-safe channel for sending a value once and only once. If a value has
 /// already been send, subsequent calls to send will noop.
@@ -79,7 +91,7 @@
         let _wp = wd::watch_millis("IGetRegistrationCallback::onCancel", 500);
         log::warn!("IGetRegistrationCallback cancelled");
         self.registration_tx.send(
-            Err(Error::Km(ErrorCode::OPERATION_CANCELLED))
+            Err(Error::Rc(ResponseCode::OUT_OF_KEYS))
                 .context(ks_err!("GetRegistrationCallback cancelled.")),
         );
         Ok(())
@@ -88,7 +100,7 @@
         let _wp = wd::watch_millis("IGetRegistrationCallback::onError", 500);
         log::error!("IGetRegistrationCallback failed: '{error}'");
         self.registration_tx.send(
-            Err(Error::Km(ErrorCode::UNKNOWN_ERROR))
+            Err(Error::Rc(ResponseCode::OUT_OF_KEYS))
                 .context(ks_err!("GetRegistrationCallback failed: {:?}", error)),
         );
         Ok(())
@@ -113,7 +125,12 @@
         .getRegistration(&rpc_name, &cb)
         .context(ks_err!("Trying to get registration."))?;
 
-    rx.await.unwrap()
+    match timeout(RKPD_TIMEOUT, rx).await {
+        Err(e) => {
+            Err(Error::Rc(ResponseCode::SYSTEM_ERROR)).context(ks_err!("Waiting for RKPD: {:?}", e))
+        }
+        Ok(v) => v.unwrap(),
+    }
 }
 
 struct GetKeyCallback {
@@ -144,8 +161,7 @@
         let _wp = wd::watch_millis("IGetKeyCallback::onCancel", 500);
         log::warn!("IGetKeyCallback cancelled");
         self.key_tx.send(
-            Err(Error::Km(ErrorCode::OPERATION_CANCELLED))
-                .context(ks_err!("GetKeyCallback cancelled.")),
+            Err(Error::Rc(ResponseCode::OUT_OF_KEYS)).context(ks_err!("GetKeyCallback cancelled.")),
         );
         Ok(())
     }
@@ -153,13 +169,31 @@
         let _wp = wd::watch_millis("IGetKeyCallback::onError", 500);
         log::error!("IGetKeyCallback failed: {error}");
         self.key_tx.send(
-            Err(Error::Km(ErrorCode::UNKNOWN_ERROR))
+            Err(Error::Rc(ResponseCode::OUT_OF_KEYS))
                 .context(ks_err!("GetKeyCallback failed: {:?}", error)),
         );
         Ok(())
     }
 }
 
+async fn get_rkpd_attestation_key_from_registration_async(
+    registration: &Strong<dyn IRegistration>,
+    caller_uid: u32,
+) -> Result<RemotelyProvisionedKey> {
+    let (tx, rx) = oneshot::channel();
+    let cb = GetKeyCallback::new_native_binder(tx);
+
+    registration
+        .getKey(caller_uid.try_into().unwrap(), &cb)
+        .context(ks_err!("Trying to get key."))?;
+
+    match timeout(RKPD_TIMEOUT, rx).await {
+        Err(e) => Err(Error::Rc(ResponseCode::OUT_OF_KEYS))
+            .context(ks_err!("Waiting for RKPD key timed out: {:?}", e)),
+        Ok(v) => v.unwrap(),
+    }
+}
+
 async fn get_rkpd_attestation_key_async(
     security_level: &SecurityLevel,
     caller_uid: u32,
@@ -167,15 +201,7 @@
     let registration = get_rkpd_registration(security_level)
         .await
         .context(ks_err!("Trying to get to IRegistration service."))?;
-
-    let (tx, rx) = oneshot::channel();
-    let cb = GetKeyCallback::new_native_binder(tx);
-
-    registration
-        .getKey(caller_uid.try_into().unwrap(), &cb)
-        .context(ks_err!("Trying to get key."))?;
-
-    rx.await.unwrap()
+    get_rkpd_attestation_key_from_registration_async(&registration, caller_uid).await
 }
 
 struct StoreUpgradedKeyCallback {
@@ -204,13 +230,32 @@
         let _wp = wd::watch_millis("IGetRegistrationCallback::onError", 500);
         log::error!("IGetRegistrationCallback failed: {error}");
         self.completer.send(
-            Err(Error::Km(ErrorCode::UNKNOWN_ERROR))
+            Err(Error::Rc(ResponseCode::SYSTEM_ERROR))
                 .context(ks_err!("Failed to store upgraded key: {:?}", error)),
         );
         Ok(())
     }
 }
 
+async fn store_rkpd_attestation_key_with_registration_async(
+    registration: &Strong<dyn IRegistration>,
+    key_blob: &[u8],
+    upgraded_blob: &[u8],
+) -> Result<()> {
+    let (tx, rx) = oneshot::channel();
+    let cb = StoreUpgradedKeyCallback::new_native_binder(tx);
+
+    registration
+        .storeUpgradedKeyAsync(key_blob, upgraded_blob, &cb)
+        .context(ks_err!("Failed to store upgraded blob with RKPD."))?;
+
+    match timeout(RKPD_TIMEOUT, rx).await {
+        Err(e) => Err(Error::Rc(ResponseCode::SYSTEM_ERROR))
+            .context(ks_err!("Waiting for RKPD to complete storing key: {:?}", e)),
+        Ok(v) => v.unwrap(),
+    }
+}
+
 async fn store_rkpd_attestation_key_async(
     security_level: &SecurityLevel,
     key_blob: &[u8],
@@ -219,15 +264,7 @@
     let registration = get_rkpd_registration(security_level)
         .await
         .context(ks_err!("Trying to get to IRegistration service."))?;
-
-    let (tx, rx) = oneshot::channel();
-    let cb = StoreUpgradedKeyCallback::new_native_binder(tx);
-
-    registration
-        .storeUpgradedKeyAsync(key_blob, upgraded_blob, &cb)
-        .context(ks_err!("Failed to store upgraded blob with RKPD."))?;
-
-    rx.await.unwrap()
+    store_rkpd_attestation_key_with_registration_async(&registration, key_blob, upgraded_blob).await
 }
 
 /// Get attestation key from RKPD.
@@ -236,7 +273,7 @@
     caller_uid: u32,
 ) -> Result<RemotelyProvisionedKey> {
     let _wp = wd::watch_millis("Calling get_rkpd_attestation_key()", 500);
-    block_on(get_rkpd_attestation_key_async(security_level, caller_uid))
+    tokio_rt().block_on(get_rkpd_attestation_key_async(security_level, caller_uid))
 }
 
 /// Store attestation key in RKPD.
@@ -246,7 +283,7 @@
     upgraded_blob: &[u8],
 ) -> Result<()> {
     let _wp = wd::watch_millis("Calling store_rkpd_attestation_key()", 500);
-    block_on(store_rkpd_attestation_key_async(security_level, key_blob, upgraded_blob))
+    tokio_rt().block_on(store_rkpd_attestation_key_async(security_level, key_blob, upgraded_blob))
 }
 
 #[cfg(test)]
@@ -254,19 +291,25 @@
     use super::*;
     use android_security_rkp_aidl::aidl::android::security::rkp::IRegistration::BnRegistration;
     use std::sync::atomic::{AtomicU32, Ordering};
-    use std::sync::Arc;
 
     #[derive(Default)]
-    struct MockRegistrationValues {
-        _key: RemotelyProvisionedKey,
+    struct MockRegistration {
+        key: RemotelyProvisionedKey,
+        latency: Option<Duration>,
     }
 
-    #[derive(Default)]
-    struct MockRegistration(Arc<Mutex<MockRegistrationValues>>);
-
     impl MockRegistration {
-        pub fn new_native_binder() -> Strong<dyn IRegistration> {
-            let result: Self = Default::default();
+        pub fn new_native_binder(
+            key: &RemotelyProvisionedKey,
+            latency: Option<Duration>,
+        ) -> Strong<dyn IRegistration> {
+            let result = Self {
+                key: RemotelyProvisionedKey {
+                    keyBlob: key.keyBlob.clone(),
+                    encodedCertChain: key.encodedCertChain.clone(),
+                },
+                latency,
+            };
             BnRegistration::new_binder(result, BinderFeatures::default())
         }
     }
@@ -274,8 +317,22 @@
     impl Interface for MockRegistration {}
 
     impl IRegistration for MockRegistration {
-        fn getKey(&self, _: i32, _: &Strong<dyn IGetKeyCallback>) -> binder::Result<()> {
-            todo!()
+        fn getKey(&self, _: i32, cb: &Strong<dyn IGetKeyCallback>) -> binder::Result<()> {
+            let key = RemotelyProvisionedKey {
+                keyBlob: self.key.keyBlob.clone(),
+                encodedCertChain: self.key.encodedCertChain.clone(),
+            };
+            let latency = self.latency;
+            let get_key_cb = cb.clone();
+
+            // Need a separate thread to trigger timeout in the caller.
+            std::thread::spawn(move || {
+                if let Some(duration) = latency {
+                    std::thread::sleep(duration);
+                }
+                get_key_cb.onSuccess(&key).unwrap();
+            });
+            Ok(())
         }
 
         fn cancelGetKey(&self, _: &Strong<dyn IGetKeyCallback>) -> binder::Result<()> {
@@ -286,19 +343,33 @@
             &self,
             _: &[u8],
             _: &[u8],
-            _: &Strong<dyn IStoreUpgradedKeyCallback>,
+            cb: &Strong<dyn IStoreUpgradedKeyCallback>,
         ) -> binder::Result<()> {
-            todo!()
+            // We are primarily concerned with timing out correctly. Storing the key in this mock
+            // registration isn't particularly interesting, so skip that part.
+            let store_cb = cb.clone();
+            let latency = self.latency;
+
+            std::thread::spawn(move || {
+                if let Some(duration) = latency {
+                    std::thread::sleep(duration);
+                }
+                store_cb.onSuccess().unwrap();
+            });
+            Ok(())
         }
     }
 
-    fn get_mock_registration() -> Result<binder::Strong<dyn IRegistration>> {
+    fn get_mock_registration(
+        key: &RemotelyProvisionedKey,
+        latency: Option<Duration>,
+    ) -> Result<binder::Strong<dyn IRegistration>> {
         let (tx, rx) = oneshot::channel();
         let cb = GetRegistrationCallback::new_native_binder(tx);
-        let mock_registration = MockRegistration::new_native_binder();
+        let mock_registration = MockRegistration::new_native_binder(key, latency);
 
         assert!(cb.onSuccess(&mock_registration).is_ok());
-        block_on(rx).unwrap()
+        tokio_rt().block_on(rx).unwrap()
     }
 
     // Using the same key ID makes test cases race with each other. So, we use separate key IDs for
@@ -310,7 +381,8 @@
 
     #[test]
     fn test_get_registration_cb_success() {
-        let registration = get_mock_registration();
+        let key: RemotelyProvisionedKey = Default::default();
+        let registration = get_mock_registration(&key, /*latency=*/ None);
         assert!(registration.is_ok());
     }
 
@@ -320,10 +392,10 @@
         let cb = GetRegistrationCallback::new_native_binder(tx);
         assert!(cb.onCancel().is_ok());
 
-        let result = block_on(rx).unwrap();
+        let result = tokio_rt().block_on(rx).unwrap();
         assert_eq!(
             result.unwrap_err().downcast::<Error>().unwrap(),
-            Error::Km(ErrorCode::OPERATION_CANCELLED)
+            Error::Rc(ResponseCode::OUT_OF_KEYS)
         );
     }
 
@@ -333,10 +405,10 @@
         let cb = GetRegistrationCallback::new_native_binder(tx);
         assert!(cb.onError("error").is_ok());
 
-        let result = block_on(rx).unwrap();
+        let result = tokio_rt().block_on(rx).unwrap();
         assert_eq!(
             result.unwrap_err().downcast::<Error>().unwrap(),
-            Error::Km(ErrorCode::UNKNOWN_ERROR)
+            Error::Rc(ResponseCode::OUT_OF_KEYS)
         );
     }
 
@@ -348,7 +420,7 @@
         let cb = GetKeyCallback::new_native_binder(tx);
         assert!(cb.onSuccess(&mock_key).is_ok());
 
-        let key = block_on(rx).unwrap().unwrap();
+        let key = tokio_rt().block_on(rx).unwrap().unwrap();
         assert_eq!(key, mock_key);
     }
 
@@ -358,10 +430,10 @@
         let cb = GetKeyCallback::new_native_binder(tx);
         assert!(cb.onCancel().is_ok());
 
-        let result = block_on(rx).unwrap();
+        let result = tokio_rt().block_on(rx).unwrap();
         assert_eq!(
             result.unwrap_err().downcast::<Error>().unwrap(),
-            Error::Km(ErrorCode::OPERATION_CANCELLED)
+            Error::Rc(ResponseCode::OUT_OF_KEYS)
         );
     }
 
@@ -371,10 +443,10 @@
         let cb = GetKeyCallback::new_native_binder(tx);
         assert!(cb.onError("error").is_ok());
 
-        let result = block_on(rx).unwrap();
+        let result = tokio_rt().block_on(rx).unwrap();
         assert_eq!(
             result.unwrap_err().downcast::<Error>().unwrap(),
-            Error::Km(ErrorCode::UNKNOWN_ERROR)
+            Error::Rc(ResponseCode::OUT_OF_KEYS)
         );
     }
 
@@ -384,7 +456,7 @@
         let cb = StoreUpgradedKeyCallback::new_native_binder(tx);
         assert!(cb.onSuccess().is_ok());
 
-        block_on(rx).unwrap().unwrap();
+        tokio_rt().block_on(rx).unwrap().unwrap();
     }
 
     #[test]
@@ -393,17 +465,73 @@
         let cb = StoreUpgradedKeyCallback::new_native_binder(tx);
         assert!(cb.onError("oh no! it failed").is_ok());
 
-        let result = block_on(rx).unwrap();
+        let result = tokio_rt().block_on(rx).unwrap();
         assert_eq!(
             result.unwrap_err().downcast::<Error>().unwrap(),
-            Error::Km(ErrorCode::UNKNOWN_ERROR)
+            Error::Rc(ResponseCode::SYSTEM_ERROR)
+        );
+    }
+
+    #[test]
+    fn test_get_mock_key_success() {
+        let mock_key =
+            RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
+        let registration = get_mock_registration(&mock_key, /*latency=*/ None).unwrap();
+
+        let key = tokio_rt()
+            .block_on(get_rkpd_attestation_key_from_registration_async(&registration, 0))
+            .unwrap();
+        assert_eq!(key, mock_key);
+    }
+
+    #[test]
+    fn test_get_mock_key_timeout() {
+        let mock_key =
+            RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
+        let latency = RKPD_TIMEOUT + Duration::from_secs(10);
+        let registration = get_mock_registration(&mock_key, Some(latency)).unwrap();
+
+        let result =
+            tokio_rt().block_on(get_rkpd_attestation_key_from_registration_async(&registration, 0));
+        assert_eq!(
+            result.unwrap_err().downcast::<Error>().unwrap(),
+            Error::Rc(ResponseCode::OUT_OF_KEYS)
+        );
+    }
+
+    #[test]
+    fn test_store_mock_key_success() {
+        let mock_key =
+            RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
+        let registration = get_mock_registration(&mock_key, /*latency=*/ None).unwrap();
+        tokio_rt()
+            .block_on(store_rkpd_attestation_key_with_registration_async(&registration, &[], &[]))
+            .unwrap();
+    }
+
+    #[test]
+    fn test_store_mock_key_timeout() {
+        let mock_key =
+            RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
+        let latency = RKPD_TIMEOUT + Duration::from_secs(10);
+        let registration = get_mock_registration(&mock_key, Some(latency)).unwrap();
+
+        let result = tokio_rt().block_on(store_rkpd_attestation_key_with_registration_async(
+            &registration,
+            &[],
+            &[],
+        ));
+        assert_eq!(
+            result.unwrap_err().downcast::<Error>().unwrap(),
+            Error::Rc(ResponseCode::SYSTEM_ERROR)
         );
     }
 
     #[test]
     fn test_get_rkpd_attestation_key() {
         binder::ProcessState::start_thread_pool();
-        let key = get_rkpd_attestation_key(&SecurityLevel::TRUSTED_ENVIRONMENT, 0).unwrap();
+        let key_id = get_next_key_id();
+        let key = get_rkpd_attestation_key(&SecurityLevel::TRUSTED_ENVIRONMENT, key_id).unwrap();
         assert!(!key.keyBlob.is_empty());
         assert!(!key.encodedCertChain.is_empty());
     }