keystore2_test: Join all test threads
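Joining the threads catches bugs that would otherwise only surface after the main test thread has completed. Also shorten the extra latency in the timeout tests from 10s to 1s, since the mock registration now waits for its callback threads (which sleep for that latency) when it is dropped.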
Bug: 269460851
Test: keystore2_test
Change-Id: I0d723b04a95e83da8aaceb0748f5af0a9eab90e2
diff --git a/keystore2/src/rkpd_client.rs b/keystore2/src/rkpd_client.rs
index f1e8e11..0754a64 100644
--- a/keystore2/src/rkpd_client.rs
+++ b/keystore2/src/rkpd_client.rs
@@ -314,47 +314,63 @@
use keystore2_crypto::parse_subject_from_certificate;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
+ use std::sync::{Arc, Mutex};
- #[derive(Default)]
- struct MockRegistration {
+ struct MockRegistrationValues {
key: RemotelyProvisionedKey,
latency: Option<Duration>,
+ thread_join_handles: Vec<Option<std::thread::JoinHandle<()>>>,
}
+ struct MockRegistration(Arc<Mutex<MockRegistrationValues>>);
+
impl MockRegistration {
pub fn new_native_binder(
key: &RemotelyProvisionedKey,
latency: Option<Duration>,
) -> Strong<dyn IRegistration> {
- let result = Self {
+ let result = Self(Arc::new(Mutex::new(MockRegistrationValues {
key: RemotelyProvisionedKey {
keyBlob: key.keyBlob.clone(),
encodedCertChain: key.encodedCertChain.clone(),
},
latency,
- };
+ thread_join_handles: Vec::new(),
+ })));
BnRegistration::new_binder(result, BinderFeatures::default())
}
}
+ impl Drop for MockRegistration {
+     /// Joins the getKey callback threads recorded in `thread_join_handles`.
+     fn drop(&mut self) {
+         let mut values = self.0.lock().unwrap();
+         // Holding the lock while joining is safe: the joined threads never lock it.
+         // Test-only code: unwrap so that a panicked callback thread fails the test.
+         values.thread_join_handles.drain(..).flatten().for_each(|h| h.join().unwrap());
+     }
+ }
+
impl Interface for MockRegistration {}
impl IRegistration for MockRegistration {
fn getKey(&self, _: i32, cb: &Strong<dyn IGetKeyCallback>) -> binder::Result<()> {
+ let mut values = self.0.lock().unwrap();
let key = RemotelyProvisionedKey {
- keyBlob: self.key.keyBlob.clone(),
- encodedCertChain: self.key.encodedCertChain.clone(),
+ keyBlob: values.key.keyBlob.clone(),
+ encodedCertChain: values.key.encodedCertChain.clone(),
};
- let latency = self.latency;
+ let latency = values.latency;
let get_key_cb = cb.clone();
// Need a separate thread to trigger timeout in the caller.
- std::thread::spawn(move || {
+ let join_handle = std::thread::spawn(move || {
if let Some(duration) = latency {
std::thread::sleep(duration);
}
get_key_cb.onSuccess(&key).unwrap();
});
+ values.thread_join_handles.push(Some(join_handle));
Ok(())
}
@@ -370,8 +386,9 @@
) -> binder::Result<()> {
// We are primarily concerned with timing out correctly. Storing the key in this mock
// registration isn't particularly interesting, so skip that part.
+ let values = self.0.lock().unwrap();
let store_cb = cb.clone();
- let latency = self.latency;
+ let latency = values.latency;
std::thread::spawn(move || {
if let Some(duration) = latency {
@@ -528,7 +545,7 @@
fn test_get_mock_key_timeout() {
let mock_key =
RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
- let latency = RKPD_TIMEOUT + Duration::from_secs(10);
+ let latency = RKPD_TIMEOUT + Duration::from_secs(1);
let registration = get_mock_registration(&mock_key, Some(latency)).unwrap();
let result =
@@ -553,7 +570,7 @@
fn test_store_mock_key_timeout() {
let mock_key =
RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
- let latency = RKPD_TIMEOUT + Duration::from_secs(10);
+ let latency = RKPD_TIMEOUT + Duration::from_secs(1);
let registration = get_mock_registration(&mock_key, Some(latency)).unwrap();
let result = tokio_rt().block_on(store_rkpd_attestation_key_with_registration_async(