Merge "Keystore2: fix test_large_number_of_concurrent_db_manipulations" into sc-dev
diff --git a/keystore2/src/crypto/certificate_utils.cpp b/keystore2/src/crypto/certificate_utils.cpp
index 31c7fb4..24b3793 100644
--- a/keystore2/src/crypto/certificate_utils.cpp
+++ b/keystore2/src/crypto/certificate_utils.cpp
@@ -23,10 +23,13 @@
 
 #include <functional>
 #include <limits>
-#include <string>
 #include <variant>
 #include <vector>
 
+#ifndef __LP64__
+#include <time64.h>
+#endif
+
 namespace keystore {
 
 namespace {
@@ -167,45 +170,42 @@
     return key_usage;
 }
 
-template <typename Out, typename In> static Out saturate(In in) {
-    if constexpr (std::is_signed_v<Out> == std::is_signed_v<In>) {
-        if constexpr (sizeof(Out) >= sizeof(In)) {
-            // Same sign, and In fits into Out. Cast is lossless.
-            return static_cast<Out>(in);
-        } else {
-            // Out is smaller than In we may need to truncate.
-            // We pick the smaller of `out::max()` and the greater of `out::min()` and `in`.
-            return static_cast<Out>(
-                std::min(static_cast<In>(std::numeric_limits<Out>::max()),
-                         std::max(static_cast<In>(std::numeric_limits<Out>::min()), in)));
-        }
-    } else {
-        // So we have different signs. This puts the lower bound at 0 because either input or output
-        // is unsigned. The upper bound is max of the smaller type or, if they are equal the max of
-        // the signed type.
-        if constexpr (std::is_signed_v<Out>) {
-            if constexpr (sizeof(Out) > sizeof(In)) {
-                return static_cast<Out>(in);
-            } else {
-                // Because `out` is the signed one, the lower bound of `in` is 0 and fits into
-                // `out`. We just have to compare the maximum and we do it in type In because it has
-                // a greater range than Out, so Out::max() is guaranteed to fit.
-                return static_cast<Out>(
-                    std::min(static_cast<In>(std::numeric_limits<Out>::max()), in));
-            }
-        } else {
-            // Out is unsigned. So we can return 0 if in is negative.
-            if (in < 0) return 0;
-            if constexpr (sizeof(Out) >= sizeof(In)) {
-                // If Out is wider or equal we can assign lossless.
-                return static_cast<Out>(in);
-            } else {
-                // Otherwise we have to take the minimum of Out::max() and `in`.
-                return static_cast<Out>(
-                    std::min(static_cast<In>(std::numeric_limits<Out>::max()), in));
-            }
-        }
+// TODO: Once BoringSSL can take int64_t instead of time_t we can go back to using
+//       ASN1_TIME_set: https://bugs.chromium.org/p/boringssl/issues/detail?id=416
+std::optional<std::array<char, 16>> toTimeString(int64_t timeMillis) {
+    struct tm time;
+    // Integer division truncates toward zero, so bias negative inputs by -999 ms to keep
+    // rounding down to the previous whole second.
+    if (timeMillis < 0 && __builtin_add_overflow(timeMillis, -999, &timeMillis)) {
+        return std::nullopt;
     }
+#if defined(__LP64__)
+    time_t timeSeconds = timeMillis / 1000;
+    if (gmtime_r(&timeSeconds, &time) == nullptr) {
+        return std::nullopt;
+    }
+#else
+    time64_t timeSeconds = timeMillis / 1000;
+    if (gmtime64_r(&timeSeconds, &time) == nullptr) {
+        return std::nullopt;
+    }
+#endif
+    std::array<char, 16> buffer;
+    if (__builtin_add_overflow(time.tm_year, 1900, &time.tm_year)) {
+        return std::nullopt;
+    }
+    if (time.tm_year >= 1950 && time.tm_year < 2050) {
+        // UTCTime according to RFC5280 4.1.2.5.1.
+        snprintf(buffer.data(), buffer.size(), "%02d%02d%02d%02d%02d%02dZ", time.tm_year % 100,
+                 time.tm_mon + 1, time.tm_mday, time.tm_hour, time.tm_min, time.tm_sec);
+    } else if (time.tm_year >= 0 && time.tm_year < 10000) {
+        // GeneralizedTime according to RFC5280 4.1.2.5.2.
+        snprintf(buffer.data(), buffer.size(), "%04d%02d%02d%02d%02d%02dZ", time.tm_year,
+                 time.tm_mon + 1, time.tm_mday, time.tm_hour, time.tm_min, time.tm_sec);
+    } else {
+        return std::nullopt;
+    }
+    return buffer;
 }
 
 // Creates a rump certificate structure with serial, subject and issuer names, as well as
@@ -259,19 +259,24 @@
         return std::get<CertUtilsError>(subjectName);
     }
 
-    time_t notBeforeTime = saturate<time_t>(activeDateTimeMilliSeconds / 1000);
+    auto notBeforeTime = toTimeString(activeDateTimeMilliSeconds);
+    if (!notBeforeTime) {
+        return CertUtilsError::TimeError;
+    }
     // Set activation date.
     ASN1_TIME_Ptr notBefore(ASN1_TIME_new());
-    if (!notBefore || !ASN1_TIME_set(notBefore.get(), notBeforeTime) ||
+    if (!notBefore || !ASN1_TIME_set_string(notBefore.get(), notBeforeTime->data()) ||
         !X509_set_notBefore(certificate.get(), notBefore.get() /* Don't release; copied */))
         return CertUtilsError::BoringSsl;
 
     // Set expiration date.
-    time_t notAfterTime;
-    notAfterTime = saturate<time_t>(usageExpireDateTimeMilliSeconds / 1000);
+    auto notAfterTime = toTimeString(usageExpireDateTimeMilliSeconds);
+    if (!notAfterTime) {
+        return CertUtilsError::TimeError;
+    }
 
     ASN1_TIME_Ptr notAfter(ASN1_TIME_new());
-    if (!notAfter || !ASN1_TIME_set(notAfter.get(), notAfterTime) ||
+    if (!notAfter || !ASN1_TIME_set_string(notAfter.get(), notAfterTime->data()) ||
         !X509_set_notAfter(certificate.get(), notAfter.get() /* Don't release; copied */)) {
         return CertUtilsError::BoringSsl;
     }
diff --git a/keystore2/src/crypto/include/certificate_utils.h b/keystore2/src/crypto/include/certificate_utils.h
index 6c25b9a..b483d88 100644
--- a/keystore2/src/crypto/include/certificate_utils.h
+++ b/keystore2/src/crypto/include/certificate_utils.h
@@ -53,6 +53,7 @@
         InvalidArgument,
         UnexpectedNullPointer,
         SignatureFailed,
+        TimeError,
     };
 
   private:
@@ -137,6 +138,16 @@
 };
 
 /**
+ * Takes an int64_t representing UNIX epoch time in milliseconds and turns it into a UTCTime
+ * or GeneralizedTime string depending on whether the year is in the interval [1950 .. 2050).
+ * Note: The string returned in the array buffer is NUL terminated and of length 13 (UTCTime)
+ * or 15 (GeneralizedTime).
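+ * @param timeMillis UNIX epoch time in milliseconds (may be negative).
+ * @return UTCTime or GeneralizedTime string, or std::nullopt if the time cannot be formatted.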
+ */
+std::optional<std::array<char, 16>> toTimeString(int64_t timeMillis);
+
+/**
  * Sets the signature specifier of the certificate and the signature according to the parameters
  * c. Then it signs the certificate with the `sign` callback.
  * IMPORTANT: The parameters `algo`, `padding`, and `digest` do not control the actual signing
diff --git a/keystore2/src/crypto/tests/certificate_utils_test.cpp b/keystore2/src/crypto/tests/certificate_utils_test.cpp
index 119c3fa..ebd6792 100644
--- a/keystore2/src/crypto/tests/certificate_utils_test.cpp
+++ b/keystore2/src/crypto/tests/certificate_utils_test.cpp
@@ -315,3 +315,23 @@
     EVP_PKEY_Ptr decoded_pkey(X509_get_pubkey(decoded_cert.get()));
     ASSERT_TRUE(X509_verify(decoded_cert.get(), decoded_pkey.get()));
 }
+
+TEST(TimeStringTests, toTimeStringTest) {
+    // Two test vectors that need to result in UTCTime
+    ASSERT_EQ(std::string(toTimeString(1622758591000)->data()), std::string("210603221631Z"));
+    ASSERT_EQ(std::string(toTimeString(0)->data()), std::string("700101000000Z"));
+    // Two test vectors that need to result in GeneralizedTime.
+    ASSERT_EQ(std::string(toTimeString(16227585910000)->data()), std::string("24840325064510Z"));
+    ASSERT_EQ(std::string(toTimeString(-1622758591000)->data()), std::string("19180731014329Z"));
+
+    // Highest possible UTCTime
+    ASSERT_EQ(std::string(toTimeString(2524607999999)->data()), "491231235959Z");
+    // And one millisecond later must be GeneralizedTime.
+    ASSERT_EQ(std::string(toTimeString(2524608000000)->data()), "20500101000000Z");
+
+    // Earliest possible UTCTime
+    ASSERT_EQ(std::string(toTimeString(-631152000000)->data()), "500101000000Z");
+    // And one millisecond earlier must be GeneralizedTime.
+    // This also checks that the rounding direction does not flip when the input is negative.
+    ASSERT_EQ(std::string(toTimeString(-631152000001)->data()), "19491231235959Z");
+}
\ No newline at end of file
diff --git a/keystore2/src/globals.rs b/keystore2/src/globals.rs
index b483a8f..8212213 100644
--- a/keystore2/src/globals.rs
+++ b/keystore2/src/globals.rs
@@ -209,9 +209,12 @@
         }
     };
 
-    let keymint = if let Some(service_name) = service_name {
-        map_binder_status_code(binder::get_interface(&service_name))
-            .context("In connect_keymint: Trying to connect to genuine KeyMint service.")
+    let (keymint, hal_version) = if let Some(service_name) = service_name {
+        (
+            map_binder_status_code(binder::get_interface(&service_name))
+                .context("In connect_keymint: Trying to connect to genuine KeyMint service.")?,
+            Some(100i32), // The HAL version code for KeyMint V1 is 100.
+        )
     } else {
         // This is a no-op if it was called before.
         keystore2_km_compat::add_keymint_device_service();
@@ -219,21 +222,35 @@
         let keystore_compat_service: Strong<dyn IKeystoreCompatService> =
             map_binder_status_code(binder::get_interface("android.security.compat"))
                 .context("In connect_keymint: Trying to connect to compat service.")?;
-        map_binder_status(keystore_compat_service.getKeyMintDevice(*security_level))
-            .map_err(|e| match e {
-                Error::BinderTransaction(StatusCode::NAME_NOT_FOUND) => {
-                    Error::Km(ErrorCode::HARDWARE_TYPE_UNAVAILABLE)
-                }
-                e => e,
-            })
-            .context("In connect_keymint: Trying to get Legacy wrapper.")
-    }?;
+        (
+            map_binder_status(keystore_compat_service.getKeyMintDevice(*security_level))
+                .map_err(|e| match e {
+                    Error::BinderTransaction(StatusCode::NAME_NOT_FOUND) => {
+                        Error::Km(ErrorCode::HARDWARE_TYPE_UNAVAILABLE)
+                    }
+                    e => e,
+                })
+                .context("In connect_keymint: Trying to get Legacy wrapper.")?,
+            None,
+        )
+    };
 
     let wp = wd::watch_millis("In connect_keymint: calling getHardwareInfo()", 500);
-    let hw_info = map_km_error(keymint.getHardwareInfo())
+    let mut hw_info = map_km_error(keymint.getHardwareInfo())
         .context("In connect_keymint: Failed to get hardware info.")?;
     drop(wp);
 
+    // The legacy wrapper sets hw_info.versionNumber to the underlying HAL version like so:
+    // 10 * <major> + <minor>, e.g., KM 3.0 = 30. So 30, 40, and 41 are the only viable values.
+    // For KeyMint the versionNumber is implementation defined and thus completely meaningless
+    // to Keystore 2.0. So at this point the versionNumber field is set to the HAL version, so
+    // that higher levels have a meaningful guide as to which feature set to expect from the
+    // implementation. As of this writing the only meaningful version number is 100 for KeyMint V1,
+    // and future AIDL versions should follow the pattern <AIDL version> * 100.
+    if let Some(hal_version) = hal_version {
+        hw_info.versionNumber = hal_version;
+    }
+
     Ok((Asp::new(keymint.as_binder()), hw_info))
 }
 
diff --git a/keystore2/src/km_compat/km_compat.cpp b/keystore2/src/km_compat/km_compat.cpp
index 19576aa..f6f8bfe 100644
--- a/keystore2/src/km_compat/km_compat.cpp
+++ b/keystore2/src/km_compat/km_compat.cpp
@@ -1395,8 +1395,7 @@
         if (!device) {
             return ScopedAStatus::fromStatus(STATUS_NAME_NOT_FOUND);
         }
-        bool inserted = false;
-        std::tie(i, inserted) = mDeviceCache.insert({in_securityLevel, std::move(device)});
+        i = mDeviceCache.insert(i, {in_securityLevel, std::move(device)});
     }
     *_aidl_return = i->second;
     return ScopedAStatus::ok();
@@ -1404,14 +1403,15 @@
 
 ScopedAStatus KeystoreCompatService::getSharedSecret(KeyMintSecurityLevel in_securityLevel,
                                                      std::shared_ptr<ISharedSecret>* _aidl_return) {
-    if (!mSharedSecret) {
+    auto i = mSharedSecretCache.find(in_securityLevel);
+    if (i == mSharedSecretCache.end()) {
         auto secret = SharedSecret::createSharedSecret(in_securityLevel);
         if (!secret) {
             return ScopedAStatus::fromStatus(STATUS_NAME_NOT_FOUND);
         }
-        mSharedSecret = std::move(secret);
+        i = mSharedSecretCache.insert(i, {in_securityLevel, std::move(secret)});
     }
-    *_aidl_return = mSharedSecret;
+    *_aidl_return = i->second;
     return ScopedAStatus::ok();
 }
 
diff --git a/keystore2/src/km_compat/km_compat.h b/keystore2/src/km_compat/km_compat.h
index 09c9157..2d892da 100644
--- a/keystore2/src/km_compat/km_compat.h
+++ b/keystore2/src/km_compat/km_compat.h
@@ -197,7 +197,7 @@
 class KeystoreCompatService : public BnKeystoreCompatService {
   private:
     std::unordered_map<KeyMintSecurityLevel, std::shared_ptr<IKeyMintDevice>> mDeviceCache;
-    std::shared_ptr<ISharedSecret> mSharedSecret;
+    std::unordered_map<KeyMintSecurityLevel, std::shared_ptr<ISharedSecret>> mSharedSecretCache;
     std::shared_ptr<ISecureClock> mSecureClock;
 
   public:
diff --git a/keystore2/src/shared_secret_negotiation.rs b/keystore2/src/shared_secret_negotiation.rs
index fb55f33..64bc2c3 100644
--- a/keystore2/src/shared_secret_negotiation.rs
+++ b/keystore2/src/shared_secret_negotiation.rs
@@ -24,6 +24,7 @@
 use anyhow::{Context, Result};
 use keystore2_vintf::{get_aidl_instances, get_hidl_instances};
 use std::fmt::{self, Display, Formatter};
+use std::time::Duration;
 
 /// This function initiates the shared secret negotiation. It starts a thread and then returns
 /// immediately. The thread consults the vintf manifest to enumerate expected negotiation
@@ -109,7 +110,11 @@
 
 /// Lists participants.
 fn list_participants() -> Result<Vec<SharedSecretParticipant>> {
-    Ok([(4, 0), (4, 1)]
+    // 4.1 implementations always also register as 4.0. So only the highest version of each
+    // "default" and "strongbox" makes the cut.
+    let mut legacy_default_found: bool = false;
+    let mut legacy_strongbox_found: bool = false;
+    Ok([(4, 1), (4, 0)]
         .iter()
         .map(|(ma, mi)| {
             get_hidl_instances(KEYMASTER_PACKAGE_NAME, *ma, *mi, KEYMASTER_INTERFACE_NAME)
@@ -119,7 +124,24 @@
                     instances
                         .into_iter()
                         .filter_map(|name| {
-                            filter_map_legacy_km_instances(name.to_string(), (*ma, *mi))
+                            filter_map_legacy_km_instances(name.to_string(), (*ma, *mi)).and_then(
+                                |sp| {
+                                    if let SharedSecretParticipant::Hidl {
+                                        is_strongbox: true,
+                                        ..
+                                    } = &sp
+                                    {
+                                        if !legacy_strongbox_found {
+                                            legacy_strongbox_found = true;
+                                            return Some(sp);
+                                        }
+                                    } else if !legacy_default_found {
+                                        legacy_default_found = true;
+                                        return Some(sp);
+                                    }
+                                    None
+                                },
+                            )
                         })
                         .collect::<Vec<SharedSecretParticipant>>()
                 })
@@ -215,7 +237,7 @@
         if participants.is_empty() {
             break;
         }
-        std::thread::sleep(std::time::Duration::from_millis(1000));
+        std::thread::sleep(Duration::from_millis(1000));
     }
     connected_participants
 }
@@ -237,7 +259,7 @@
             Err(e) => {
                 log::warn!("{:?}", e);
                 log::warn!("Retrying in one second.");
-                std::thread::sleep(std::time::Duration::from_millis(1000));
+                std::thread::sleep(Duration::from_millis(1000));
             }
             Ok(params) => break params,
         }
@@ -246,20 +268,28 @@
     params.sort_unstable();
 
     // Phase 2: Send the sorted sharing parameters to all participants.
-    participants
-        .into_iter()
-        .try_fold(None, |acc, (s, p)| {
-            match (acc, map_binder_status(s.computeSharedSecret(&params))) {
-                (None, Ok(new_sum)) => Ok(Some(new_sum)),
-                (Some(old_sum), Ok(new_sum)) => {
-                    if old_sum == new_sum {
-                        Ok(Some(old_sum))
-                    } else {
-                        Err(SharedSecretError::Checksum(p))
-                    }
+    let negotiation_result = participants.into_iter().try_fold(None, |acc, (s, p)| {
+        match (acc, map_binder_status(s.computeSharedSecret(&params))) {
+            (None, Ok(new_sum)) => Ok(Some(new_sum)),
+            (Some(old_sum), Ok(new_sum)) => {
+                if old_sum == new_sum {
+                    Ok(Some(old_sum))
+                } else {
+                    Err(SharedSecretError::Checksum(p))
                 }
-                (_, Err(e)) => Err(SharedSecretError::Computation { e, p }),
             }
-        })
-        .expect("Fatal: Shared secret computation failed.");
+            (_, Err(e)) => Err(SharedSecretError::Computation { e, p }),
+        }
+    });
+
+    if let Err(e) = negotiation_result {
+        log::error!("In negotiate_shared_secret: {:?}.", e);
+        if let SharedSecretError::Checksum(_) = e {
+            log::error!(concat!(
+                "This means that this device is NOT PROVISIONED CORRECTLY.\n",
+                "User authorization and other security functions will not work\n",
+                "as expected. Please contact your OEM for instructions.",
+            ));
+        }
+    }
 }