Merge changes from topic "add-upgrade-key"
* changes:
The RKPD store upgraded key interface is now async
Create helper type for oneshot::Sender
diff --git a/keystore2/src/rkpd_client.rs b/keystore2/src/rkpd_client.rs
index 2d1b23b..039cb90 100644
--- a/keystore2/src/rkpd_client.rs
+++ b/keystore2/src/rkpd_client.rs
@@ -23,7 +23,10 @@
IGetKeyCallback::BnGetKeyCallback, IGetKeyCallback::IGetKeyCallback,
IGetRegistrationCallback::BnGetRegistrationCallback,
IGetRegistrationCallback::IGetRegistrationCallback, IRegistration::IRegistration,
- IRemoteProvisioning::IRemoteProvisioning, RemotelyProvisionedKey::RemotelyProvisionedKey,
+ IRemoteProvisioning::IRemoteProvisioning,
+ IStoreUpgradedKeyCallback::BnStoreUpgradedKeyCallback,
+ IStoreUpgradedKeyCallback::IStoreUpgradedKeyCallback,
+ RemotelyProvisionedKey::RemotelyProvisionedKey,
};
use android_security_rkp_aidl::binder::{BinderFeatures, Interface, Strong};
use anyhow::{Context, Result};
@@ -31,44 +34,53 @@
use futures::executor::block_on;
use std::sync::Mutex;
-type RegistrationSender = oneshot::Sender<Result<binder::Strong<dyn IRegistration>>>;
+/// Thread-safe channel for sending a value once and only once. If a value has
+/// already been sent, subsequent calls to send will noop.
+struct SafeSender<T> {
+ inner: Mutex<Option<oneshot::Sender<T>>>,
+}
+
+impl<T> SafeSender<T> {
+ fn new(sender: oneshot::Sender<T>) -> Self {
+ Self { inner: Mutex::new(Some(sender)) }
+ }
+
+ /// Sends `value` on the first call only. Later calls, and a send whose receiver has
+ /// already been dropped, discard `value`. Panics if the inner mutex is poisoned.
+ fn send(&self, value: T) {
+ if let Some(inner) = self.inner.lock().unwrap().take() {
+ // The receiver may legitimately be gone (e.g. the awaiting future was dropped).
+ // That is not actionable here, so discard the value instead of panicking.
+ let _ = inner.send(value);
+ }
+ }
+}
+/// `IGetRegistrationCallback` implementation that forwards the outcome (registration,
+/// cancellation, or error) through a one-shot `SafeSender`.
struct GetRegistrationCallback {
- registration_tx: Mutex<Option<RegistrationSender>>,
+ registration_tx: SafeSender<Result<binder::Strong<dyn IRegistration>>>,
}
impl GetRegistrationCallback {
+ /// Wraps `registration_tx` in a `BnGetRegistrationCallback` binder object.
pub fn new_native_binder(
- registration_tx: RegistrationSender,
+ registration_tx: oneshot::Sender<Result<binder::Strong<dyn IRegistration>>>,
) -> Result<Strong<dyn IGetRegistrationCallback>> {
let result: Self =
- GetRegistrationCallback { registration_tx: Mutex::new(Some(registration_tx)) };
+ GetRegistrationCallback { registration_tx: SafeSender::new(registration_tx) };
Ok(BnGetRegistrationCallback::new_binder(result, BinderFeatures::default()))
}
+ /// Sends a clone of `registration` as the successful result.
fn on_success(&self, registration: &binder::Strong<dyn IRegistration>) -> Result<()> {
- if let Some(tx) = self.registration_tx.lock().unwrap().take() {
- tx.send(Ok(registration.clone())).unwrap();
- }
+ self.registration_tx.send(Ok(registration.clone()));
Ok(())
}
+ /// Sends an `Error::Km(OPERATION_CANCELLED)` result.
fn on_cancel(&self) -> Result<()> {
- if let Some(tx) = self.registration_tx.lock().unwrap().take() {
- tx.send(
- Err(Error::Km(ErrorCode::OPERATION_CANCELLED))
- .context(ks_err!("GetRegistrationCallback cancelled.")),
- )
- .unwrap();
- }
+ self.registration_tx.send(
+ Err(Error::Km(ErrorCode::OPERATION_CANCELLED))
+ .context(ks_err!("GetRegistrationCallback cancelled.")),
+ );
Ok(())
}
+ /// Sends an `Error::Km(UNKNOWN_ERROR)` result with `error` included in the context.
fn on_error(&self, error: &str) -> Result<()> {
- if let Some(tx) = self.registration_tx.lock().unwrap().take() {
- tx.send(
- Err(Error::Km(ErrorCode::UNKNOWN_ERROR))
- .context(ks_err!("GetRegistrationCallback failed: {:?}", error)),
- )
- .unwrap();
- }
+ self.registration_tx.send(
+ Err(Error::Km(ErrorCode::UNKNOWN_ERROR))
+ .context(ks_err!("GetRegistrationCallback failed: {:?}", error)),
+ );
Ok(())
}
}
@@ -112,45 +124,36 @@
rx.await.unwrap()
}
-type KeySender = oneshot::Sender<Result<RemotelyProvisionedKey>>;
-
+/// `IGetKeyCallback` implementation that forwards the outcome (key, cancellation, or
+/// error) through a one-shot `SafeSender`.
struct GetKeyCallback {
- key_tx: Mutex<Option<KeySender>>,
+ key_tx: SafeSender<Result<RemotelyProvisionedKey>>,
}
impl GetKeyCallback {
- pub fn new_native_binder(key_tx: KeySender) -> Result<Strong<dyn IGetKeyCallback>> {
- let result: Self = GetKeyCallback { key_tx: Mutex::new(Some(key_tx)) };
+ /// Wraps `key_tx` in a `BnGetKeyCallback` binder object.
+ pub fn new_native_binder(
+ key_tx: oneshot::Sender<Result<RemotelyProvisionedKey>>,
+ ) -> Result<Strong<dyn IGetKeyCallback>> {
+ let result: Self = GetKeyCallback { key_tx: SafeSender::new(key_tx) };
Ok(BnGetKeyCallback::new_binder(result, BinderFeatures::default()))
}
+ /// Sends a copy of `key` (blob and certificate chain) as the successful result.
fn on_success(&self, key: &RemotelyProvisionedKey) -> Result<()> {
- if let Some(tx) = self.key_tx.lock().unwrap().take() {
- tx.send(Ok(RemotelyProvisionedKey {
- keyBlob: key.keyBlob.clone(),
- encodedCertChain: key.encodedCertChain.clone(),
- }))
- .unwrap();
- }
+ self.key_tx.send(Ok(RemotelyProvisionedKey {
+ keyBlob: key.keyBlob.clone(),
+ encodedCertChain: key.encodedCertChain.clone(),
+ }));
Ok(())
}
+ /// Sends an `Error::Km(OPERATION_CANCELLED)` result.
fn on_cancel(&self) -> Result<()> {
- if let Some(tx) = self.key_tx.lock().unwrap().take() {
- tx.send(
- Err(Error::Km(ErrorCode::OPERATION_CANCELLED))
- .context(ks_err!("GetKeyCallback cancelled.")),
- )
- .unwrap();
- }
+ self.key_tx.send(
+ Err(Error::Km(ErrorCode::OPERATION_CANCELLED))
+ .context(ks_err!("GetKeyCallback cancelled.")),
+ );
Ok(())
}
+ /// Sends an `Error::Km(UNKNOWN_ERROR)` result with `error` included in the context.
fn on_error(&self, error: &str) -> Result<()> {
- if let Some(tx) = self.key_tx.lock().unwrap().take() {
- tx.send(
- Err(Error::Km(ErrorCode::UNKNOWN_ERROR))
- .context(ks_err!("GetKeyCallback failed: {:?}", error)),
- )
- .unwrap();
- }
+ self.key_tx.send(
+ Err(Error::Km(ErrorCode::UNKNOWN_ERROR))
+ .context(ks_err!("GetKeyCallback failed: {:?}", error)),
+ );
Ok(())
}
}
@@ -191,6 +194,46 @@
rx.await.unwrap()
}
+/// `IStoreUpgradedKeyCallback` implementation that reports the outcome of an asynchronous
+/// `storeUpgradedKeyAsync` call through a one-shot `SafeSender`.
+struct StoreUpgradedKeyCallback {
+ completer: SafeSender<Result<()>>,
+}
+
+impl StoreUpgradedKeyCallback {
+ /// Wraps `completer` in a `BnStoreUpgradedKeyCallback` binder object.
+ pub fn new_native_binder(
+ completer: oneshot::Sender<Result<()>>,
+ ) -> Result<Strong<dyn IStoreUpgradedKeyCallback>> {
+ let result: Self = StoreUpgradedKeyCallback { completer: SafeSender::new(completer) };
+ Ok(BnStoreUpgradedKeyCallback::new_binder(result, BinderFeatures::default()))
+ }
+
+ /// Completes the channel with `Ok(())`.
+ fn on_success(&self) -> Result<()> {
+ self.completer.send(Ok(()));
+ Ok(())
+ }
+
+ /// Completes the channel with `Error::Km(UNKNOWN_ERROR)`, with `error` in the context.
+ fn on_error(&self, error: &str) -> Result<()> {
+ self.completer.send(
+ Err(Error::Km(ErrorCode::UNKNOWN_ERROR))
+ .context(ks_err!("Failed to store upgraded key: {:?}", error)),
+ );
+ Ok(())
+ }
+}
+
+impl Interface for StoreUpgradedKeyCallback {}
+
+impl IStoreUpgradedKeyCallback for StoreUpgradedKeyCallback {
+ fn onSuccess(&self) -> binder::Result<()> {
+ // The `wd::watch_millis` label names this interface, not IGetRegistrationCallback.
+ let _wp = wd::watch_millis("IStoreUpgradedKeyCallback::onSuccess", 500);
+ map_or_log_err(self.on_success(), Ok)
+ }
+
+ fn onError(&self, error: &str) -> binder::Result<()> {
+ let _wp = wd::watch_millis("IStoreUpgradedKeyCallback::onError", 500);
+ map_or_log_err(self.on_error(error), Ok)
+ }
+}
+
async fn store_rkpd_attestation_key_async(
security_level: &SecurityLevel,
key_blob: &[u8],
@@ -200,10 +243,15 @@
.await
.context(ks_err!("Trying to get to IRegistration service."))?;
+ let (tx, rx) = oneshot::channel();
+ let cb = StoreUpgradedKeyCallback::new_native_binder(tx)
+ .context(ks_err!("Trying to create a StoreUpgradedKeyCallback."))?;
+
registration
- .storeUpgradedKey(key_blob, upgraded_blob)
+ .storeUpgradedKeyAsync(key_blob, upgraded_blob, &cb)
.context(ks_err!("Failed to store upgraded blob with RKPD."))?;
- Ok(())
+
+ rx.await.unwrap()
}
/// Get attestation key from RKPD.
@@ -257,7 +305,12 @@
todo!()
}
- fn storeUpgradedKey(&self, _: &[u8], _: &[u8]) -> binder::Result<()> {
+ fn storeUpgradedKeyAsync(
+ &self,
+ _: &[u8],
+ _: &[u8],
+ _: &Strong<dyn IStoreUpgradedKeyCallback>,
+ ) -> binder::Result<()> {
+ // Stub: not implemented here; panics via todo!() if called.
todo!()
}
}
@@ -342,6 +395,28 @@
}
#[test]
+ fn test_store_upgraded_key_cb_success() {
+ // onSuccess must complete the channel with Ok(()).
+ let (tx, rx) = oneshot::channel();
+ let cb = StoreUpgradedKeyCallback::new_native_binder(tx).unwrap();
+ assert!(cb.onSuccess().is_ok());
+
+ block_on(rx).unwrap().unwrap();
+ }
+
+ #[test]
+ fn test_store_upgraded_key_cb_error() {
+ // onError must complete the channel with an Error::Km(UNKNOWN_ERROR) result.
+ let (tx, rx) = oneshot::channel();
+ let cb = StoreUpgradedKeyCallback::new_native_binder(tx).unwrap();
+ assert!(cb.onError("oh no! it failed").is_ok());
+
+ let result = block_on(rx).unwrap();
+ assert_eq!(
+ result.unwrap_err().downcast::<Error>().unwrap(),
+ Error::Km(ErrorCode::UNKNOWN_ERROR)
+ );
+ }
+
+ #[test]
#[ignore]
fn test_get_rkpd_attestation_key() {
let key = get_rkpd_attestation_key(&SecurityLevel::TRUSTED_ENVIRONMENT, 0).unwrap();