Merge "Handle error codes from rkpd getKey operations"
diff --git a/identity/RemotelyProvisionedKey.cpp b/identity/RemotelyProvisionedKey.cpp
index 7e90d63..784a680 100644
--- a/identity/RemotelyProvisionedKey.cpp
+++ b/identity/RemotelyProvisionedKey.cpp
@@ -21,6 +21,7 @@
 #include <android-base/logging.h>
 #include <android/security/rkp/BnGetKeyCallback.h>
 #include <android/security/rkp/BnGetRegistrationCallback.h>
+#include <android/security/rkp/IGetKeyCallback.h>
 #include <android/security/rkp/IRemoteProvisioning.h>
 #include <binder/IServiceManager.h>
 #include <binder/Status.h>
@@ -38,6 +39,7 @@
 using ::android::hardware::security::keymint::RpcHardwareInfo;
 using ::android::security::rkp::BnGetKeyCallback;
 using ::android::security::rkp::BnGetRegistrationCallback;
+using ::android::security::rkp::IGetKeyCallback;
 using ::android::security::rkp::IRegistration;
 using ::android::security::rkp::IRemoteProvisioning;
 using ::android::security::rkp::RemotelyProvisionedKey;
@@ -96,11 +98,11 @@
         keyPromise_.set_value(std::nullopt);
         return Status::ok();
     }
-    Status onError(const String16& error) override {
+    Status onError(IGetKeyCallback::ErrorCode error, const String16& description) override {
         if (called_.test_and_set()) {
             return Status::ok();
         }
-        LOG(ERROR) << "GetKeyCallback failed: " << error;
+        LOG(ERROR) << "GetKeyCallback failed: " << static_cast<int>(error) << ", " << description;
         keyPromise_.set_value(std::nullopt);
         return Status::ok();
     }
@@ -124,7 +126,8 @@
         auto cb = sp<GetKeyCallback>::make(std::move(keyPromise_));
         auto status = registration->getKey(keyId_, cb);
         if (!status.isOk()) {
-            cb->onError(String16("Failed to register GetKeyCallback"));
+            cb->onError(IGetKeyCallback::ErrorCode::ERROR_UNKNOWN,
+                        String16("Failed to register GetKeyCallback"));
         }
         return Status::ok();
     }
diff --git a/keystore2/src/error.rs b/keystore2/src/error.rs
index d1d58a4..3ca3942 100644
--- a/keystore2/src/error.rs
+++ b/keystore2/src/error.rs
@@ -71,11 +71,6 @@
     pub fn perm() -> Self {
         Error::Rc(ResponseCode::PERMISSION_DENIED)
     }
-
-    /// Short hand for `Error::Rc(ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR)`
-    pub fn out_of_keys() -> Self {
-        Error::Rc(ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR)
-    }
 }
 
 /// Helper function to map the binder status we get from calls into KeyMint
diff --git a/keystore2/src/rkpd_client.rs b/keystore2/src/rkpd_client.rs
index d0d036c..f1e8e11 100644
--- a/keystore2/src/rkpd_client.rs
+++ b/keystore2/src/rkpd_client.rs
@@ -13,16 +13,15 @@
 // limitations under the License.
 
 //! Helper wrapper around RKPD interface.
-// TODO(b/264891956): Return RKP specific errors.
 
-use crate::error::{map_binder_status_code, Error};
+use crate::error::{map_binder_status_code, Error, ResponseCode};
 use crate::globals::get_remotely_provisioned_component_name;
 use crate::ks_err;
 use crate::utils::watchdog as wd;
 use android_hardware_security_keymint::aidl::android::hardware::security::keymint::SecurityLevel::SecurityLevel;
 use android_security_rkp_aidl::aidl::android::security::rkp::{
-    IGetKeyCallback::BnGetKeyCallback, IGetKeyCallback::IGetKeyCallback,
-    IGetRegistrationCallback::BnGetRegistrationCallback,
+    IGetKeyCallback::BnGetKeyCallback, IGetKeyCallback::ErrorCode::ErrorCode as GetKeyErrorCode,
+    IGetKeyCallback::IGetKeyCallback, IGetRegistrationCallback::BnGetRegistrationCallback,
     IGetRegistrationCallback::IGetRegistrationCallback, IRegistration::IRegistration,
     IRemoteProvisioning::IRemoteProvisioning,
     IStoreUpgradedKeyCallback::BnStoreUpgradedKeyCallback,
@@ -30,7 +29,6 @@
     RemotelyProvisionedKey::RemotelyProvisionedKey,
 };
 use android_security_rkp_aidl::binder::{BinderFeatures, Interface, Strong};
-use android_system_keystore2::aidl::android::system::keystore2::ResponseCode::ResponseCode;
 use anyhow::{Context, Result};
 use std::sync::Mutex;
 use std::time::Duration;
@@ -91,17 +89,17 @@
         let _wp = wd::watch_millis("IGetRegistrationCallback::onCancel", 500);
         log::warn!("IGetRegistrationCallback cancelled");
         self.registration_tx.send(
-            Err(Error::Rc(ResponseCode::OUT_OF_KEYS))
+            Err(Error::Rc(ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR))
                 .context(ks_err!("GetRegistrationCallback cancelled.")),
         );
         Ok(())
     }
-    fn onError(&self, error: &str) -> binder::Result<()> {
+    fn onError(&self, description: &str) -> binder::Result<()> {
         let _wp = wd::watch_millis("IGetRegistrationCallback::onError", 500);
-        log::error!("IGetRegistrationCallback failed: '{error}'");
+        log::error!("IGetRegistrationCallback failed: '{description}'");
         self.registration_tx.send(
-            Err(Error::Rc(ResponseCode::OUT_OF_KEYS))
-                .context(ks_err!("GetRegistrationCallback failed: {:?}", error)),
+            Err(Error::Rc(ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR))
+                .context(ks_err!("GetRegistrationCallback failed: {:?}", description)),
         );
         Ok(())
     }
@@ -161,17 +159,33 @@
         let _wp = wd::watch_millis("IGetKeyCallback::onCancel", 500);
         log::warn!("IGetKeyCallback cancelled");
         self.key_tx.send(
-            Err(Error::Rc(ResponseCode::OUT_OF_KEYS)).context(ks_err!("GetKeyCallback cancelled.")),
+            Err(Error::Rc(ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR))
+                .context(ks_err!("GetKeyCallback cancelled.")),
         );
         Ok(())
     }
-    fn onError(&self, error: &str) -> binder::Result<()> {
+    fn onError(&self, error: GetKeyErrorCode, description: &str) -> binder::Result<()> {
         let _wp = wd::watch_millis("IGetKeyCallback::onError", 500);
-        log::error!("IGetKeyCallback failed: {error}");
-        self.key_tx.send(
-            Err(Error::Rc(ResponseCode::OUT_OF_KEYS))
-                .context(ks_err!("GetKeyCallback failed: {:?}", error)),
-        );
+        log::error!("IGetKeyCallback failed: {description}");
+        let rc = match error {
+            GetKeyErrorCode::ERROR_UNKNOWN => ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR,
+            GetKeyErrorCode::ERROR_PERMANENT => ResponseCode::OUT_OF_KEYS_PERMANENT_ERROR,
+            GetKeyErrorCode::ERROR_PENDING_INTERNET_CONNECTIVITY => {
+                ResponseCode::OUT_OF_KEYS_PENDING_INTERNET_CONNECTIVITY
+            }
+            GetKeyErrorCode::ERROR_REQUIRES_SECURITY_PATCH => {
+                ResponseCode::OUT_OF_KEYS_REQUIRES_SYSTEM_UPGRADE
+            }
+            _ => {
+                log::error!("Unexpected error from rkpd: {error:?}");
+                ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR
+            }
+        };
+        self.key_tx.send(Err(Error::Rc(rc)).context(ks_err!(
+            "GetKeyCallback failed: {:?} {:?}",
+            error,
+            description
+        )));
         Ok(())
     }
 }
@@ -188,7 +202,7 @@
         .context(ks_err!("Trying to get key."))?;
 
     match timeout(RKPD_TIMEOUT, rx).await {
-        Err(e) => Err(Error::Rc(ResponseCode::OUT_OF_KEYS))
+        Err(e) => Err(Error::Rc(ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR))
             .context(ks_err!("Waiting for RKPD key timed out: {:?}", e)),
         Ok(v) => v.unwrap(),
     }
@@ -298,6 +312,7 @@
     };
     use android_security_rkp_aidl::aidl::android::security::rkp::IRegistration::BnRegistration;
     use keystore2_crypto::parse_subject_from_certificate;
+    use std::collections::HashMap;
     use std::sync::atomic::{AtomicU32, Ordering};
 
     #[derive(Default)]
@@ -403,7 +418,7 @@
         let result = tokio_rt().block_on(rx).unwrap();
         assert_eq!(
             result.unwrap_err().downcast::<Error>().unwrap(),
-            Error::Rc(ResponseCode::OUT_OF_KEYS)
+            Error::Rc(ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR)
         );
     }
 
@@ -416,7 +431,7 @@
         let result = tokio_rt().block_on(rx).unwrap();
         assert_eq!(
             result.unwrap_err().downcast::<Error>().unwrap(),
-            Error::Rc(ResponseCode::OUT_OF_KEYS)
+            Error::Rc(ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR)
         );
     }
 
@@ -441,21 +456,38 @@
         let result = tokio_rt().block_on(rx).unwrap();
         assert_eq!(
             result.unwrap_err().downcast::<Error>().unwrap(),
-            Error::Rc(ResponseCode::OUT_OF_KEYS)
+            Error::Rc(ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR)
         );
     }
 
     #[test]
     fn test_get_key_cb_error() {
-        let (tx, rx) = oneshot::channel();
-        let cb = GetKeyCallback::new_native_binder(tx);
-        assert!(cb.onError("error").is_ok());
+        let error_mapping = HashMap::from([
+            (GetKeyErrorCode::ERROR_UNKNOWN, ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR),
+            (GetKeyErrorCode::ERROR_PERMANENT, ResponseCode::OUT_OF_KEYS_PERMANENT_ERROR),
+            (
+                GetKeyErrorCode::ERROR_PENDING_INTERNET_CONNECTIVITY,
+                ResponseCode::OUT_OF_KEYS_PENDING_INTERNET_CONNECTIVITY,
+            ),
+            (
+                GetKeyErrorCode::ERROR_REQUIRES_SECURITY_PATCH,
+                ResponseCode::OUT_OF_KEYS_REQUIRES_SYSTEM_UPGRADE,
+            ),
+        ]);
 
-        let result = tokio_rt().block_on(rx).unwrap();
-        assert_eq!(
-            result.unwrap_err().downcast::<Error>().unwrap(),
-            Error::Rc(ResponseCode::OUT_OF_KEYS)
-        );
+        // Loop over the generated list of enum values to better ensure this test stays in
+        // sync with the AIDL.
+        for get_key_error in GetKeyErrorCode::enum_values() {
+            let (tx, rx) = oneshot::channel();
+            let cb = GetKeyCallback::new_native_binder(tx);
+            assert!(cb.onError(get_key_error, "error").is_ok());
+
+            let result = tokio_rt().block_on(rx).unwrap();
+            assert_eq!(
+                result.unwrap_err().downcast::<Error>().unwrap(),
+                Error::Rc(error_mapping[&get_key_error]),
+            );
+        }
     }
 
     #[test]
@@ -503,7 +535,7 @@
             tokio_rt().block_on(get_rkpd_attestation_key_from_registration_async(&registration, 0));
         assert_eq!(
             result.unwrap_err().downcast::<Error>().unwrap(),
-            Error::Rc(ResponseCode::OUT_OF_KEYS)
+            Error::Rc(ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR)
         );
     }