Move large test modules into separate files
This complies better with the Android Rust style guide, and makes it
easier to navigate the code.
Test: keystore2_test libwatchdog_rs.test librkpd_client.test
Change-Id: Iceb49e309af66ec16d31da66b328936b0312061a
diff --git a/keystore2/rkpd_client/src/tests.rs b/keystore2/rkpd_client/src/tests.rs
new file mode 100644
index 0000000..fd0468f
--- /dev/null
+++ b/keystore2/rkpd_client/src/tests.rs
@@ -0,0 +1,338 @@
+// Copyright 2022, The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! RKPD tests.
+
+use super::*;
+use android_security_rkp_aidl::aidl::android::security::rkp::IRegistration::BnRegistration;
+use std::sync::atomic::{AtomicU32, Ordering};
+use std::sync::{Arc, Mutex};
+// Service name the non-mock tests pass to get_rkpd_attestation_key.
+const DEFAULT_RPC_SERVICE_NAME: &str =
+ "android.hardware.security.keymint.IRemotelyProvisionedComponent/default";
+
+struct MockRegistrationValues {
+ key: RemotelyProvisionedKey, // Key whose fields are cloned into every getKey reply.
+ latency: Option<Duration>, // Delay, if any, slept before a callback is invoked.
+ thread_join_handles: Vec<Option<std::thread::JoinHandle<()>>>, // getKey threads for Drop.
+}
+
+struct MockRegistration(Arc<Mutex<MockRegistrationValues>>); // Test mock of IRegistration.
+
+impl MockRegistration {
+    /// Wraps a copy of `key` and the optional `latency` in a mock `IRegistration` binder.
+    pub fn new_native_binder(
+        key: &RemotelyProvisionedKey,
+        latency: Option<Duration>,
+    ) -> Strong<dyn IRegistration> {
+        // The key is copied field by field so that the mock owns its own instance.
+        let key = RemotelyProvisionedKey {
+            keyBlob: key.keyBlob.clone(),
+            encodedCertChain: key.encodedCertChain.clone(),
+        };
+        let values = MockRegistrationValues { key, latency, thread_join_handles: vec![] };
+        let mock = Self(Arc::new(Mutex::new(values)));
+        BnRegistration::new_binder(mock, BinderFeatures::default())
+    }
+}
+
+impl Drop for MockRegistration {
+    fn drop(&mut self) {
+        // Test-only: join each recorded getKey thread and simply panic on any failure.
+        let mut state = self.0.lock().unwrap();
+        state.thread_join_handles.iter_mut().for_each(|slot| {
+            slot.take().unwrap().join().unwrap();
+        });
+    }
+}
+
+impl Interface for MockRegistration {} // Empty body: no trait items are overridden.
+
+impl IRegistration for MockRegistration {
+    fn getKey(&self, _: i32, cb: &Strong<dyn IGetKeyCallback>) -> binder::Result<()> {
+        let mut state = self.0.lock().unwrap();
+        let delay = state.latency;
+        let callback = cb.clone();
+        let reply = RemotelyProvisionedKey {
+            keyBlob: state.key.keyBlob.clone(),
+            encodedCertChain: state.key.encodedCertChain.clone(),
+        };
+        // Answer from a separate thread so that the caller can reach its timeout while this
+        // mock is still sleeping; the handle is recorded so that Drop can join the thread.
+        let worker = std::thread::spawn(move || {
+            if let Some(pause) = delay {
+                std::thread::sleep(pause);
+            }
+            callback.onSuccess(&reply).unwrap();
+        });
+        state.thread_join_handles.push(Some(worker));
+        Ok(())
+    }
+
+    fn cancelGetKey(&self, _cb: &Strong<dyn IGetKeyCallback>) -> binder::Result<()> {
+        Ok(()) // Cancellation is a no-op in this mock.
+    }
+
+    fn storeUpgradedKeyAsync(
+        &self,
+        _: &[u8],
+        _: &[u8],
+        cb: &Strong<dyn IStoreUpgradedKeyCallback>,
+    ) -> binder::Result<()> {
+        // Only the timing matters to the tests, so both blobs are ignored and success is
+        // reported once the configured latency (if any) has elapsed.
+        let state = self.0.lock().unwrap();
+        let delay = state.latency;
+        let callback = cb.clone();
+        // Unlike getKey, the handle of this thread is not recorded, so Drop never joins it.
+        std::thread::spawn(move || {
+            if let Some(pause) = delay {
+                std::thread::sleep(pause);
+            }
+            callback.onSuccess().unwrap();
+        });
+        Ok(())
+    }
+}
+
+fn get_mock_registration(
+    key: &RemotelyProvisionedKey,
+    latency: Option<Duration>,
+) -> Result<binder::Strong<dyn IRegistration>> {
+    let (sender, receiver) = oneshot::channel();
+    let callback = GetRegistrationCallback::new_native_binder(sender);
+    let mock = MockRegistration::new_native_binder(key, latency);
+    // Hand the mock to the callback, then wait for what it forwards over the channel.
+    assert!(callback.onSuccess(&mock).is_ok());
+    let forwarded = tokio_rt().block_on(receiver);
+    forwarded.unwrap()
+}
+
+// Test cases that share a key ID race with each other, so every call hands out a fresh ID
+// taken from a single shared counter.
+fn get_next_key_id() -> u32 {
+    static NEXT_KEY_ID: AtomicU32 = AtomicU32::new(0);
+    NEXT_KEY_ID.fetch_add(1, Ordering::Relaxed)
+}
+
+#[test]
+fn test_get_registration_cb_success() {
+    // A default key is enough here: only delivery of the registration is checked.
+    let default_key = RemotelyProvisionedKey::default();
+    assert!(get_mock_registration(&default_key, /*latency=*/ None).is_ok());
+}
+
+#[test]
+fn test_get_registration_cb_cancel() {
+    let (sender, receiver) = oneshot::channel();
+    let callback = GetRegistrationCallback::new_native_binder(sender);
+    assert!(callback.onCancel().is_ok());
+    // Cancelling must reach the receiver as Error::RequestCancelled.
+    let error = tokio_rt().block_on(receiver).unwrap().unwrap_err();
+    assert_eq!(error.downcast::<Error>().unwrap(), Error::RequestCancelled);
+}
+
+#[test]
+fn test_get_registration_cb_error() {
+    let (sender, receiver) = oneshot::channel();
+    let callback = GetRegistrationCallback::new_native_binder(sender);
+    assert!(callback.onError("error").is_ok());
+    // A reported error must reach the receiver as Error::GetRegistrationFailed.
+    let error = tokio_rt().block_on(receiver).unwrap().unwrap_err();
+    assert_eq!(error.downcast::<Error>().unwrap(), Error::GetRegistrationFailed);
+}
+
+#[test]
+fn test_get_key_cb_success() {
+    let expected =
+        RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
+    let (sender, receiver) = oneshot::channel();
+    let callback = GetKeyCallback::new_native_binder(sender);
+    assert!(callback.onSuccess(&expected).is_ok());
+    // The key seen by the receiver must equal the one passed to the callback.
+    let received = tokio_rt().block_on(receiver).unwrap();
+    assert_eq!(received.unwrap(), expected);
+}
+
+#[test]
+fn test_get_key_cb_cancel() {
+    let (sender, receiver) = oneshot::channel();
+    let callback = GetKeyCallback::new_native_binder(sender);
+    assert!(callback.onCancel().is_ok());
+    // Cancelling must reach the receiver as Error::RequestCancelled.
+    let error = tokio_rt().block_on(receiver).unwrap().unwrap_err();
+    assert_eq!(error.downcast::<Error>().unwrap(), Error::RequestCancelled);
+}
+
+#[test]
+fn test_get_key_cb_error() {
+    // Each value returned by GetKeyErrorCode::enum_values() must come back to the receiver
+    // wrapped in Error::GetKeyFailed.
+    for code in GetKeyErrorCode::enum_values() {
+        let (sender, receiver) = oneshot::channel();
+        let callback = GetKeyCallback::new_native_binder(sender);
+        assert!(callback.onError(code, "error").is_ok());
+
+        let outcome = tokio_rt().block_on(receiver).unwrap();
+        let error = outcome.unwrap_err().downcast::<Error>().unwrap();
+        assert_eq!(error, Error::GetKeyFailed(code));
+    }
+}
+
+#[test]
+fn test_store_upgraded_cb_success() {
+    let (sender, receiver) = oneshot::channel();
+    let callback = StoreUpgradedKeyCallback::new_native_binder(sender);
+    assert!(callback.onSuccess().is_ok());
+    // Both the channel receive and the forwarded result must be Ok.
+    tokio_rt().block_on(receiver).unwrap().unwrap();
+}
+
+#[test]
+fn test_store_upgraded_key_cb_error() {
+    let (sender, receiver) = oneshot::channel();
+    let callback = StoreUpgradedKeyCallback::new_native_binder(sender);
+    assert!(callback.onError("oh no! it failed").is_ok());
+    // A reported error must reach the receiver as Error::StoreUpgradedKeyFailed.
+    let error = tokio_rt().block_on(receiver).unwrap().unwrap_err();
+    assert_eq!(error.downcast::<Error>().unwrap(), Error::StoreUpgradedKeyFailed);
+}
+
+#[test]
+fn test_get_mock_key_success() {
+    let mock_key =
+        RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
+    let registration = get_mock_registration(&mock_key, /*latency=*/ None).unwrap();
+    // No latency is configured, so the key supplied to the mock must come back unchanged.
+    let key = tokio_rt()
+        .block_on(get_rkpd_attestation_key_from_registration_async(&registration, 0))
+        .unwrap();
+    assert_eq!(key, mock_key);
+}
+
+#[test]
+fn test_get_mock_key_timeout() {
+    let mock_key =
+        RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
+    let latency = RKPD_TIMEOUT + Duration::from_secs(1);
+    let registration = get_mock_registration(&mock_key, Some(latency)).unwrap();
+    // The mock replies one second after RKPD_TIMEOUT, so the request must time out first.
+    let result =
+        tokio_rt().block_on(get_rkpd_attestation_key_from_registration_async(&registration, 0));
+    assert_eq!(result.unwrap_err().downcast::<Error>().unwrap(), Error::RetryableTimeout);
+}
+
+#[test]
+fn test_store_mock_key_success() {
+    let mock_key =
+        RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
+    let registration = get_mock_registration(&mock_key, /*latency=*/ None).unwrap();
+    tokio_rt()
+        .block_on(store_rkpd_attestation_key_with_registration_async(&registration, &[], &[]))
+        .unwrap(); // The mock has no latency configured, so the store call must succeed.
+}
+
+#[test]
+fn test_store_mock_key_timeout() {
+    let mock_key =
+        RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
+    let latency = RKPD_TIMEOUT + Duration::from_secs(1);
+    let registration = get_mock_registration(&mock_key, Some(latency)).unwrap();
+    // The mock reports success one second after RKPD_TIMEOUT, so the store must time out.
+    let result = tokio_rt().block_on(store_rkpd_attestation_key_with_registration_async(
+        &registration,
+        &[],
+        &[],
+    ));
+    assert_eq!(result.unwrap_err().downcast::<Error>().unwrap(), Error::Timeout);
+}
+
+#[test]
+fn test_get_rkpd_attestation_key() {
+    binder::ProcessState::start_thread_pool();
+    // The fetched key must carry both a non-empty key blob and a non-empty cert chain.
+    let fetched = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, get_next_key_id()).unwrap();
+    assert!(!fetched.keyBlob.is_empty());
+    assert!(!fetched.encodedCertChain.is_empty());
+}
+
+#[test]
+fn test_get_rkpd_attestation_key_same_caller() {
+    binder::ProcessState::start_thread_pool();
+    let key_id = get_next_key_id();
+    let fetch = || get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, key_id).unwrap();
+
+    // Asking twice with one key ID must yield the same key blob and cert chain.
+    let (first, second) = (fetch(), fetch());
+
+    assert_eq!(first.keyBlob, second.keyBlob);
+    assert_eq!(first.encodedCertChain, second.encodedCertChain);
+}
+
+#[test]
+fn test_get_rkpd_attestation_key_different_caller() {
+    binder::ProcessState::start_thread_pool();
+    let (id_a, id_b) = (get_next_key_id(), get_next_key_id());
+    let fetch = |key_id: u32| get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, key_id).unwrap();
+
+    // Two different key IDs must not be served the same key.
+    let key_a = fetch(id_a);
+    let key_b = fetch(id_b);
+
+    assert_ne!(key_a.keyBlob, key_b.keyBlob);
+    assert_ne!(key_a.encodedCertChain, key_b.encodedCertChain);
+}
+
+#[test]
+// A couple of things to note:
+// 1. This test must never run with the UID of keystore. Otherwise, it can mess up keys stored
+//    by keystore.
+// 2. Storing a key and then reading it back is prone to race conditions, so we only do this in
+//    one test case.
+fn test_store_rkpd_attestation_key() {
+ binder::ProcessState::start_thread_pool();
+ let key_id = get_next_key_id();
+ let key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, key_id).unwrap();
+ let new_blob: [u8; 8] = rand::random();
+
+ assert!(store_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, &key.keyBlob, &new_blob).is_ok());
+
+ let new_key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, key_id).unwrap();
+
+ // Restore the original blob before the final check so RKPD is not left with invalid blobs.
+ assert!(store_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, &new_blob, &key.keyBlob).is_ok());
+ assert_eq!(new_key.keyBlob, new_blob);
+}
+
+#[test]
+fn test_stress_get_rkpd_attestation_key() {
+    const NTHREADS: u32 = 10;
+    const NCALLS: u32 = 1000;
+    binder::ProcessState::start_thread_pool();
+    let key_id = get_next_key_id();
+
+    // Body of every worker thread: fetch the same key ID NCALLS times.
+    let hammer = move || {
+        for _ in 0..NCALLS {
+            let key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, key_id).unwrap();
+            assert!(!key.keyBlob.is_empty());
+            assert!(!key.encodedCertChain.is_empty());
+        }
+    };
+
+    // All workers are started before any of them is joined.
+    let workers: Vec<_> = (0..NTHREADS).map(|_| std::thread::spawn(hammer)).collect();
+    for worker in workers {
+        assert!(worker.join().is_ok());
+    }
+}