Create helper type for oneshot::Sender
Move boilerplate for the sender code into a helper type. This moves
all the mutex unwrapping and optional checking into one spot, cleaning
up the call sites.
Test: keystore2_test
Change-Id: I517e091047d11d4ecca73026e5ee288878658fa3
diff --git a/keystore2/src/rkpd_client.rs b/keystore2/src/rkpd_client.rs
index 2d1b23b..2369d0a 100644
--- a/keystore2/src/rkpd_client.rs
+++ b/keystore2/src/rkpd_client.rs
@@ -31,44 +31,53 @@
use futures::executor::block_on;
use std::sync::Mutex;
-type RegistrationSender = oneshot::Sender<Result<binder::Strong<dyn IRegistration>>>;
+/// Thread-safe channel for sending a value once and only once. If a value has
+/// already been send, subsequent calls to send will noop.
+struct SafeSender<T> {
+ inner: Mutex<Option<oneshot::Sender<T>>>,
+}
+
+impl<T> SafeSender<T> {
+ fn new(sender: oneshot::Sender<T>) -> Self {
+ Self { inner: Mutex::new(Some(sender)) }
+ }
+
+ fn send(&self, value: T) {
+ if let Some(inner) = self.inner.lock().unwrap().take() {
+ // assert instead of unwrap, because on failure send returns Err(value)
+ assert!(inner.send(value).is_ok(), "thread state is terminally broken");
+ }
+ }
+}
struct GetRegistrationCallback {
- registration_tx: Mutex<Option<RegistrationSender>>,
+ registration_tx: SafeSender<Result<binder::Strong<dyn IRegistration>>>,
}
impl GetRegistrationCallback {
pub fn new_native_binder(
- registration_tx: RegistrationSender,
+ registration_tx: oneshot::Sender<Result<binder::Strong<dyn IRegistration>>>,
) -> Result<Strong<dyn IGetRegistrationCallback>> {
let result: Self =
- GetRegistrationCallback { registration_tx: Mutex::new(Some(registration_tx)) };
+ GetRegistrationCallback { registration_tx: SafeSender::new(registration_tx) };
Ok(BnGetRegistrationCallback::new_binder(result, BinderFeatures::default()))
}
fn on_success(&self, registration: &binder::Strong<dyn IRegistration>) -> Result<()> {
- if let Some(tx) = self.registration_tx.lock().unwrap().take() {
- tx.send(Ok(registration.clone())).unwrap();
- }
+ self.registration_tx.send(Ok(registration.clone()));
Ok(())
}
fn on_cancel(&self) -> Result<()> {
- if let Some(tx) = self.registration_tx.lock().unwrap().take() {
- tx.send(
- Err(Error::Km(ErrorCode::OPERATION_CANCELLED))
- .context(ks_err!("GetRegistrationCallback cancelled.")),
- )
- .unwrap();
- }
+ self.registration_tx.send(
+ Err(Error::Km(ErrorCode::OPERATION_CANCELLED))
+ .context(ks_err!("GetRegistrationCallback cancelled.")),
+ );
Ok(())
}
fn on_error(&self, error: &str) -> Result<()> {
- if let Some(tx) = self.registration_tx.lock().unwrap().take() {
- tx.send(
- Err(Error::Km(ErrorCode::UNKNOWN_ERROR))
- .context(ks_err!("GetRegistrationCallback failed: {:?}", error)),
- )
- .unwrap();
- }
+ self.registration_tx.send(
+ Err(Error::Km(ErrorCode::UNKNOWN_ERROR))
+ .context(ks_err!("GetRegistrationCallback failed: {:?}", error)),
+ );
Ok(())
}
}
@@ -112,45 +121,36 @@
rx.await.unwrap()
}
-type KeySender = oneshot::Sender<Result<RemotelyProvisionedKey>>;
-
struct GetKeyCallback {
- key_tx: Mutex<Option<KeySender>>,
+ key_tx: SafeSender<Result<RemotelyProvisionedKey>>,
}
impl GetKeyCallback {
- pub fn new_native_binder(key_tx: KeySender) -> Result<Strong<dyn IGetKeyCallback>> {
- let result: Self = GetKeyCallback { key_tx: Mutex::new(Some(key_tx)) };
+ pub fn new_native_binder(
+ key_tx: oneshot::Sender<Result<RemotelyProvisionedKey>>,
+ ) -> Result<Strong<dyn IGetKeyCallback>> {
+ let result: Self = GetKeyCallback { key_tx: SafeSender::new(key_tx) };
Ok(BnGetKeyCallback::new_binder(result, BinderFeatures::default()))
}
fn on_success(&self, key: &RemotelyProvisionedKey) -> Result<()> {
- if let Some(tx) = self.key_tx.lock().unwrap().take() {
- tx.send(Ok(RemotelyProvisionedKey {
- keyBlob: key.keyBlob.clone(),
- encodedCertChain: key.encodedCertChain.clone(),
- }))
- .unwrap();
- }
+ self.key_tx.send(Ok(RemotelyProvisionedKey {
+ keyBlob: key.keyBlob.clone(),
+ encodedCertChain: key.encodedCertChain.clone(),
+ }));
Ok(())
}
fn on_cancel(&self) -> Result<()> {
- if let Some(tx) = self.key_tx.lock().unwrap().take() {
- tx.send(
- Err(Error::Km(ErrorCode::OPERATION_CANCELLED))
- .context(ks_err!("GetKeyCallback cancelled.")),
- )
- .unwrap();
- }
+ self.key_tx.send(
+ Err(Error::Km(ErrorCode::OPERATION_CANCELLED))
+ .context(ks_err!("GetKeyCallback cancelled.")),
+ );
Ok(())
}
fn on_error(&self, error: &str) -> Result<()> {
- if let Some(tx) = self.key_tx.lock().unwrap().take() {
- tx.send(
- Err(Error::Km(ErrorCode::UNKNOWN_ERROR))
- .context(ks_err!("GetKeyCallback failed: {:?}", error)),
- )
- .unwrap();
- }
+ self.key_tx.send(
+ Err(Error::Km(ErrorCode::UNKNOWN_ERROR))
+ .context(ks_err!("GetKeyCallback failed: {:?}", error)),
+ );
Ok(())
}
}