Merge "benchmark: updated benchmark tests for strongbox" am: a30c39330f am: a8f53c6363

Original change: https://android-review.googlesource.com/c/platform/hardware/interfaces/+/2168252

Change-Id: Ib94e3b4f5a14fb69df445aed52f1bdbfd052b54e
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
diff --git a/security/keymint/aidl/vts/performance/KeyMintBenchmark.cpp b/security/keymint/aidl/vts/performance/KeyMintBenchmark.cpp
index 5bbae4c..0c61c25 100644
--- a/security/keymint/aidl/vts/performance/KeyMintBenchmark.cpp
+++ b/security/keymint/aidl/vts/performance/KeyMintBenchmark.cpp
@@ -16,16 +16,21 @@
 
 #define LOG_TAG "keymint_benchmark"
 
+#include <iostream>
+
 #include <base/command_line.h>
 #include <benchmark/benchmark.h>
-#include <iostream>
 
 #include <aidl/Vintf.h>
 #include <aidl/android/hardware/security/keymint/ErrorCode.h>
 #include <aidl/android/hardware/security/keymint/IKeyMintDevice.h>
 #include <android/binder_manager.h>
 #include <binder/IServiceManager.h>
+
 #include <keymint_support/authorization_set.h>
+#include <keymint_support/openssl_utils.h>
+#include <openssl/curve25519.h>
+#include <openssl/x509.h>
 
 #define SMALL_MESSAGE_SIZE 64
 #define MEDIUM_MESSAGE_SIZE 1024
@@ -119,6 +124,22 @@
         return {};
     }
 
+    // Names the algorithm family of |transform|: the first table token found in it wins.
+    // Logs to stderr and returns "" when no token is present.
+    string getAlgorithmString(string transform) {
+        static constexpr const char* kTokenToName[][2] = {
+                {"AES", "AES"}, {"Hmac", "HMAC"}, {"DESede", "TRIPLE_DES"},
+                {"RSA", "RSA"}, {"EC", "EC"},
+        };
+        for (const auto& row : kTokenToName) {
+            if (transform.find(row[0]) != string::npos) {
+                return row[1];
+            }
+        }
+        std::cerr << "Can't find algorithm for " << transform << std::endl;
+        return "";
+    }
+
     Digest getDigest(string transform) {
         if (transform.find("MD5") != string::npos) {
             return Digest::MD5;
@@ -135,29 +156,56 @@
             return Digest::SHA_2_512;
         } else if (transform.find("RSA") != string::npos &&
                    transform.find("OAEP") != string::npos) {
-            return Digest::SHA1;
+            if (securityLevel_ == SecurityLevel::STRONGBOX) {
+                return Digest::SHA_2_256;
+            } else {
+                return Digest::SHA1;
+            }
         } else if (transform.find("Hmac") != string::npos) {
             return Digest::SHA_2_256;
         }
         return Digest::NONE;
     }
 
+    // Names the digest implied by |transform|, or "" when there is none. Digest tokens are tried
+    // in table order first; RSA OAEP and Hmac transforms without one fall back to a default.
+    string getDigestString(string transform) {
+        static constexpr struct {
+            const char* token;
+            const char* digest;
+        } kNamedDigests[] = {
+                {"MD5", "MD5"},          {"SHA1", "SHA1"},        {"SHA-1", "SHA1"},
+                {"SHA224", "SHA_2_224"}, {"SHA256", "SHA_2_256"}, {"SHA384", "SHA_2_384"},
+                {"SHA512", "SHA_2_512"},
+        };
+        for (const auto& entry : kNamedDigests) {
+            if (transform.find(entry.token) != string::npos) {
+                return entry.digest;
+            }
+        }
+        const bool isRsaOaep =
+                transform.find("RSA") != string::npos && transform.find("OAEP") != string::npos;
+        if (isRsaOaep) {
+            return securityLevel_ == SecurityLevel::STRONGBOX ? "SHA_2_256" : "SHA1";
+        }
+        if (transform.find("Hmac") != string::npos) {
+            return "SHA_2_256";
+        }
+        return "";
+    }
+
     optional<EcCurve> getCurveFromLength(int keySize) {
         switch (keySize) {
             case 224:
                 return EcCurve::P_224;
-                break;
             case 256:
                 return EcCurve::P_256;
-                break;
             case 384:
                 return EcCurve::P_384;
-                break;
             case 521:
                 return EcCurve::P_521;
-                break;
             default:
-                return {};
+                return std::nullopt;  // keySize maps to no EcCurve
         }
     }
 
@@ -261,6 +309,109 @@
         return GetReturnErrorCode(result);
     }
 
+    /* Encrypts |message| locally with the public key of the leaf certificate in cert_chain_,
+     * using the digest, padding and MGF digest tags in |params|. Based on LocalRsaEncryptMessage
+     * in hardware/interfaces/security/keymint/aidl/vts/functional/KeyMintAidlTestBase.cpp: asserts
+     * became checks that log to stderr and return the sentinel string "Failure" (never an empty
+     * optional), so the caller can skip the benchmark case instead of aborting.
+     */
+    optional<string> LocalRsaEncryptMessage(const string& message, const AuthorizationSet& params) {
+        // Retrieve the public key from the leaf certificate.
+        if (cert_chain_.empty()) {
+            std::cerr << "Local RSA encrypt Error: invalid cert_chain_" << std::endl;
+            return "Failure";
+        }
+        X509_Ptr key_cert(parse_cert_blob(cert_chain_[0].encodedCertificate));
+        EVP_PKEY_Ptr pub_key(key_cert ? X509_get_pubkey(key_cert.get()) : nullptr);
+        if (!pub_key) {  // certificate did not parse or carries no usable key
+            std::cerr << "Local RSA encrypt Error: invalid leaf certificate" << std::endl;
+            return "Failure";
+        }
+
+        // Retrieve relevant tags.
+        Digest digest = Digest::NONE;
+        Digest mgf_digest = Digest::SHA1;
+        PaddingMode padding = PaddingMode::NONE;
+
+        auto digest_tag = params.GetTagValue(TAG_DIGEST);
+        if (digest_tag.has_value()) digest = digest_tag.value();
+        auto pad_tag = params.GetTagValue(TAG_PADDING);
+        if (pad_tag.has_value()) padding = pad_tag.value();
+        auto mgf_tag = params.GetTagValue(TAG_RSA_OAEP_MGF_DIGEST);
+        if (mgf_tag.has_value()) mgf_digest = mgf_tag.value();
+
+        const EVP_MD* md = openssl_digest(digest);
+        const EVP_MD* mgf_md = openssl_digest(mgf_digest);
+
+        // Set up encryption context.
+        EVP_PKEY_CTX_Ptr ctx(EVP_PKEY_CTX_new(pub_key.get(), /* engine= */ nullptr));
+        if (EVP_PKEY_encrypt_init(ctx.get()) <= 0) {
+            std::cerr << "Local RSA encrypt Error: Encryption init failed" << std::endl;
+            return "Failure";
+        }
+
+        int rc = -1;  // stays -1 for any padding mode not handled below
+        switch (padding) {
+            case PaddingMode::NONE:
+                rc = EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_NO_PADDING);
+                break;
+            case PaddingMode::RSA_PKCS1_1_5_ENCRYPT:
+                rc = EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_PADDING);
+                break;
+            case PaddingMode::RSA_OAEP:
+                rc = EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_OAEP_PADDING);
+                break;
+            default:
+                break;
+        }
+        if (rc <= 0) {
+            std::cerr << "Local RSA encrypt Error: Set padding failed" << std::endl;
+            return "Failure";
+        }
+        if (padding == PaddingMode::RSA_OAEP) {
+            if (!EVP_PKEY_CTX_set_rsa_oaep_md(ctx.get(), md)) {
+                std::cerr << "Local RSA encrypt Error: Set digest failed: " << ERR_peek_last_error()
+                          << std::endl;
+                return "Failure";
+            }
+            if (!EVP_PKEY_CTX_set_rsa_mgf1_md(ctx.get(), mgf_md)) {
+                std::cerr << "Local RSA encrypt Error: Set MGF digest failed: "
+                          << ERR_peek_last_error() << std::endl;
+                return "Failure";
+            }
+        }
+
+        // Determine output size.
+        size_t outlen = 0;
+        if (EVP_PKEY_encrypt(ctx.get(), nullptr /* out */, &outlen,
+                             reinterpret_cast<const uint8_t*>(message.data()),
+                             message.size()) <= 0) {
+            std::cerr << "Local RSA encrypt Error: Determine output size failed: "
+                      << ERR_peek_last_error() << std::endl;
+            return "Failure";
+        }
+
+        // Left-zero-pad the input if necessary.
+        const uint8_t* to_encrypt = reinterpret_cast<const uint8_t*>(message.data());
+        size_t to_encrypt_len = message.size();
+        string zero_padded_message;  // must outlive to_encrypt when padding is applied
+        if (padding == PaddingMode::NONE && to_encrypt_len < outlen) {
+            zero_padded_message = string(outlen - to_encrypt_len, '\0') + message;
+            to_encrypt = reinterpret_cast<const uint8_t*>(zero_padded_message.data());
+            to_encrypt_len = outlen;
+        }
+
+        // Do the encryption.
+        string output(outlen, '\0');
+        if (EVP_PKEY_encrypt(ctx.get(), reinterpret_cast<uint8_t*>(output.data()), &outlen,
+                             to_encrypt, to_encrypt_len) <= 0) {
+            std::cerr << "Local RSA encrypt Error: Encryption failed: " << ERR_peek_last_error()
+                      << std::endl;
+            return "Failure";
+        }
+        return output;
+    }
+
     SecurityLevel securityLevel_;
     string name_;
 
@@ -268,12 +419,13 @@
     ErrorCode GenerateKey(const AuthorizationSet& key_desc,
                           const optional<AttestationKey>& attest_key = std::nullopt) {
         key_blob_.clear();
+        cert_chain_.clear();  // reset first so a failed generateKey leaves no stale chain
         KeyCreationResult creationResult;
         Status result = keymint_->generateKey(key_desc.vector_data(), attest_key, &creationResult);
         if (result.isOk()) {
             key_blob_ = std::move(creationResult.keyBlob);
+            cert_chain_ = std::move(creationResult.certificateChain);  // read by local RSA encrypt
             creationResult.keyCharacteristics.clear();
-            creationResult.certificateChain.clear();
         }
         return GetReturnErrorCode(result);
     }
@@ -338,6 +490,11 @@
         return ErrorCode::UNKNOWN_ERROR;
     }
 
+    X509_Ptr parse_cert_blob(const vector<uint8_t>& blob) {  // wraps whatever d2i_X509 returns
+        const uint8_t* cursor = blob.data();  // d2i_X509 takes a pointer it is free to move
+        return X509_Ptr{d2i_X509(/* allocate new */ nullptr, &cursor, blob.size())};
+    }
+
     std::shared_ptr<IKeyMintOperation> op_;
     vector<Certificate> cert_chain_;
     vector<uint8_t> key_blob_;
@@ -390,6 +547,10 @@
     BENCHMARK_KM_MSG(encrypt, transform, keySize, msgSize) \
     BENCHMARK_KM_MSG(decrypt, transform, keySize, msgSize)
 
+// Decrypt only: skip public key operations (encrypt) as they are not supported in KeyMint.
+#define BENCHMARK_KM_ASYM_CIPHER(transform, keySize, msgSize)   \
+    BENCHMARK_KM_MSG(decrypt, transform, keySize, msgSize)
+
 #define BENCHMARK_KM_CIPHER_ALL_MSGS(transform, keySize) \
     BENCHMARK_KM_ALL_MSGS(encrypt, transform, keySize)   \
     BENCHMARK_KM_ALL_MSGS(decrypt, transform, keySize)
@@ -397,12 +558,43 @@
 #define BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, keySize) \
     BENCHMARK_KM_ALL_MSGS(sign, transform, keySize)         \
     BENCHMARK_KM_ALL_MSGS(verify, transform, keySize)
-// clang-format on
+
+// Sign only: skip public key operations (verify) as they are not supported in KeyMint.
+#define BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, keySize) \
+    BENCHMARK_KM_ALL_MSGS(sign, transform, keySize)
+// clang-format on
 
 /*
  * ============= KeyGen TESTS ==================
  */
+
+// Reports whether StrongBox benchmarks may use |keySize| for the algorithm behind |transform|.
+// A transform with no known algorithm is reported as unsupported rather than dereferenced.
+static bool isValidSBKeySize(string transform, int keySize) {
+    const std::optional<Algorithm> algorithm = keymintTest->getAlgorithm(transform);
+    if (!algorithm.has_value()) {
+        return false;
+    }
+    switch (*algorithm) {
+        case Algorithm::AES: return keySize == 128 || keySize == 256;
+        case Algorithm::HMAC: return keySize % 8 == 0 && keySize >= 64 && keySize <= 512;
+        case Algorithm::TRIPLE_DES: return keySize == 168;
+        case Algorithm::RSA: return keySize == 2048;
+        case Algorithm::EC: return keySize == 256;
+    }
+    return false;
+}
+
 static void keygen(benchmark::State& state, string transform, int keySize) {
+    // Skip the test for unsupported key size in StrongBox
+    if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX &&
+        !isValidSBKeySize(transform, keySize)) {
+        state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
+                             " is not supported in StrongBox for algorithm: " +
+                             keymintTest->getAlgorithmString(transform))
+                                    .c_str());
+        return;
+    }
     addDefaultLabel(state);
     for (auto _ : state) {
         if (!keymintTest->GenerateKey(transform, keySize)) {
@@ -438,8 +630,24 @@
 /*
  * ============= SIGNATURE TESTS ==================
  */
-
 static void sign(benchmark::State& state, string transform, int keySize, int msgSize) {
+    // Skip the test for unsupported key size or unsupported digest in StrongBox
+    if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX) {
+        if (!isValidSBKeySize(transform, keySize)) {
+            state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
+                                 " is not supported in StrongBox for algorithm: " +
+                                 keymintTest->getAlgorithmString(transform))
+                                        .c_str());
+            return;
+        }
+        if (keymintTest->getDigest(transform) != Digest::SHA_2_256) {
+            state.SkipWithError(
+                    ("Skipped for STRONGBOX: Digest: " + keymintTest->getDigestString(transform) +
+                     " is not supported in StrongBox")
+                            .c_str());
+            return;
+        }
+    }
     addDefaultLabel(state);
     if (!keymintTest->GenerateKey(transform, keySize, true)) {
         state.SkipWithError(
@@ -469,6 +677,23 @@
 }
 
 static void verify(benchmark::State& state, string transform, int keySize, int msgSize) {
+    // Skip the test for unsupported key size or unsupported digest in StrongBox
+    if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX) {
+        if (!isValidSBKeySize(transform, keySize)) {
+            state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
+                                 " is not supported in StrongBox for algorithm: " +
+                                 keymintTest->getAlgorithmString(transform))
+                                        .c_str());
+            return;
+        }
+        if (keymintTest->getDigest(transform) != Digest::SHA_2_256) {
+            state.SkipWithError(
+                    ("Skipped for STRONGBOX: Digest: " + keymintTest->getDigestString(transform) +
+                     " is not supported in StrongBox")
+                            .c_str());
+            return;
+        }
+    }
     addDefaultLabel(state);
     if (!keymintTest->GenerateKey(transform, keySize, true)) {
         state.SkipWithError(
@@ -525,10 +750,10 @@
 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA512)
 
 #define BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(transform) \
-    BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 224)      \
-    BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 256)      \
-    BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 384)      \
-    BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 521)
+    BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 224)      \
+    BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 256)      \
+    BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 384)      \
+    BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 521)
 
 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(NONEwithECDSA);
 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA1withECDSA);
@@ -538,13 +763,14 @@
 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA512withECDSA);
 
 #define BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(transform) \
-    BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 2048)   \
-    BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 3072)   \
-    BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 4096)
+    BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 2048)   \
+    BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 3072)   \
+    BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 4096)
 
 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(MD5withRSA);
 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA1withRSA);
 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA224withRSA);
+BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA256withRSA);
 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA384withRSA);
 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA512withRSA);
 
@@ -553,6 +779,7 @@
 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA224withRSA/PSS);
 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA384withRSA/PSS);
 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA512withRSA/PSS);
+
 // clang-format on
 
 /*
@@ -560,6 +787,15 @@
  */
 
 static void encrypt(benchmark::State& state, string transform, int keySize, int msgSize) {
+    // Skip the test for unsupported key size in StrongBox
+    if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX &&
+        (!isValidSBKeySize(transform, keySize))) {
+        state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
+                             " is not supported in StrongBox for algorithm: " +
+                             keymintTest->getAlgorithmString(transform))
+                                    .c_str());
+        return;
+    }
     addDefaultLabel(state);
     if (!keymintTest->GenerateKey(transform, keySize)) {
         state.SkipWithError(
@@ -589,6 +825,15 @@
 }
 
 static void decrypt(benchmark::State& state, string transform, int keySize, int msgSize) {
+    // Skip the test for unsupported key size in StrongBox
+    if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX &&
+        (!isValidSBKeySize(transform, keySize))) {
+        state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
+                             " is not supported in StrongBox for algorithm: " +
+                             keymintTest->getAlgorithmString(transform))
+                                    .c_str());
+        return;
+    }
     addDefaultLabel(state);
     if (!keymintTest->GenerateKey(transform, keySize)) {
         state.SkipWithError(
@@ -598,23 +843,34 @@
     AuthorizationSet out_params;
     AuthorizationSet in_params = keymintTest->getOperationParams(transform);
     string message = keymintTest->GenerateMessage(msgSize);
-    auto error = keymintTest->Begin(KeyPurpose::ENCRYPT, in_params, &out_params);
-    if (error != ErrorCode::OK) {
-        state.SkipWithError(
-                ("Encryption begin error, " + std::to_string(keymintTest->getError())).c_str());
-        return;
+    optional<string> encryptedMessage;
+
+    if (keymintTest->getAlgorithm(transform) == Algorithm::RSA) {
+        // KeyMint does no public key operations: encrypt locally; a wrong-sized result is failure.
+        encryptedMessage = keymintTest->LocalRsaEncryptMessage(message, in_params);
+        if (!encryptedMessage || encryptedMessage->size() != static_cast<size_t>(keySize / 8)) {
+            state.SkipWithError("Local Encryption failed");
+            return;
+        }
+    } else {
+        auto error = keymintTest->Begin(KeyPurpose::ENCRYPT, in_params, &out_params);
+        if (error != ErrorCode::OK) {
+            state.SkipWithError(
+                    ("Encryption begin error, " + std::to_string(keymintTest->getError())).c_str());
+            return;
+        }
+        encryptedMessage = keymintTest->Process(message);
+        if (!encryptedMessage) {
+            state.SkipWithError(
+                    ("Encryption error, " + std::to_string(keymintTest->getError())).c_str());
+            return;
+        }
+        in_params.push_back(out_params);
+        out_params.Clear();
+    }
-    auto encryptedMessage = keymintTest->Process(message);
-    if (!encryptedMessage) {
-        state.SkipWithError(
-                ("Encryption error, " + std::to_string(keymintTest->getError())).c_str());
-        return;
-    }
-    in_params.push_back(out_params);
-    out_params.Clear();
     for (auto _ : state) {
         state.PauseTiming();
-        error = keymintTest->Begin(KeyPurpose::DECRYPT, in_params, &out_params);
+        auto error = keymintTest->Begin(KeyPurpose::DECRYPT, in_params, &out_params);
         if (error != ErrorCode::OK) {
             state.SkipWithError(
                     ("Decryption begin error, " + std::to_string(keymintTest->getError())).c_str());
@@ -649,9 +905,9 @@
 BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/ECB/PKCS7Padding, 168);
 
 #define BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(transform, msgSize) \
-    BENCHMARK_KM_CIPHER(transform, 2048, msgSize)            \
-    BENCHMARK_KM_CIPHER(transform, 3072, msgSize)            \
-    BENCHMARK_KM_CIPHER(transform, 4096, msgSize)
+    BENCHMARK_KM_ASYM_CIPHER(transform, 2048, msgSize)            \
+    BENCHMARK_KM_ASYM_CIPHER(transform, 3072, msgSize)            \
+    BENCHMARK_KM_ASYM_CIPHER(transform, 4096, msgSize)
 
 BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/NoPadding, SMALL_MESSAGE_SIZE);
 BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/PKCS1Padding, SMALL_MESSAGE_SIZE);