Add support for DSA and ECDSA key types

Change-Id: Ibee8d172eeb36f1a2e2ce62f275aea55ada5bcbf
diff --git a/softkeymaster/keymaster_openssl.cpp b/softkeymaster/keymaster_openssl.cpp
index 3620450..19ec999 100644
--- a/softkeymaster/keymaster_openssl.cpp
+++ b/softkeymaster/keymaster_openssl.cpp
@@ -57,6 +57,20 @@
 };
 typedef UniquePtr<PKCS8_PRIV_KEY_INFO, PKCS8_PRIV_KEY_INFO_Delete> Unique_PKCS8_PRIV_KEY_INFO;
 
+/* Deleter that lets UniquePtr release a DSA through DSA_free(). */
+struct DSA_Delete {
+    void operator()(DSA* dsa) const { DSA_free(dsa); }
+};
+
+typedef UniquePtr<DSA, DSA_Delete> Unique_DSA;
+
+/* Deleter that lets UniquePtr release an EC_KEY through EC_KEY_free(). */
+struct EC_KEY_Delete {
+    void operator()(EC_KEY* eckey) const { EC_KEY_free(eckey); }
+};
+
+typedef UniquePtr<EC_KEY, EC_KEY_Delete> Unique_EC_KEY;
+
 struct RSA_Delete {
     void operator()(RSA* p) const {
         RSA_free(p);
@@ -93,12 +107,15 @@
 }
 
 static int wrap_key(EVP_PKEY* pkey, int type, uint8_t** keyBlob, size_t* keyBlobLength) {
-    /* Find the length of each size */
-    int publicLen = i2d_PublicKey(pkey, NULL);
+    /*
+     * Find the length of each part. The public key is no longer stored, but its length field
+     * is kept for alignment purposes.
+     */
+    int publicLen = 0;
     int privateLen = i2d_PrivateKey(pkey, NULL);
 
-    if (privateLen <= 0 || publicLen <= 0) {
-        ALOGE("private or public key size was too big");
+    if (privateLen <= 0) {
+        ALOGE("private key size was too big");
         return -1;
     }
 
@@ -106,7 +123,7 @@
     *keyBlobLength = get_softkey_header_size() + sizeof(int) + sizeof(int) + privateLen
             + sizeof(int) + publicLen;
 
-    UniquePtr<unsigned char[]> derData(new unsigned char[*keyBlobLength]);
+    UniquePtr<unsigned char> derData(new unsigned char[*keyBlobLength]);
     if (derData.get() == NULL) {
         ALOGE("could not allocate memory for key blob");
         return -1;
@@ -125,10 +142,6 @@
     for (int i = sizeof(int) - 1; i >= 0; i--) {
         *p++ = (publicLen >> (8*i)) & 0xFF;
     }
-    if (i2d_PublicKey(pkey, &p) != publicLen) {
-        logOpenSSLError("wrap_key");
-        return -1;
-    }
 
     /* Write private key to allocated buffer */
     for (int i = sizeof(int) - 1; i >= 0; i--) {
@@ -174,12 +187,6 @@
         type = (type << 8) | *p++;
     }
 
-    Unique_EVP_PKEY pkey(EVP_PKEY_new());
-    if (pkey.get() == NULL) {
-        logOpenSSLError("unwrap_key");
-        return NULL;
-    }
-
     for (size_t i = 0; i < sizeof(int); i++) {
         publicLen = (publicLen << 8) | *p++;
     }
@@ -187,9 +194,9 @@
         ALOGE("public key length encoding error: size=%ld, end=%d", publicLen, end - p);
         return NULL;
     }
-    EVP_PKEY* tmp = pkey.get();
-    d2i_PublicKey(type, &tmp, &p, publicLen);
+    const uint8_t *pubKey = p;
 
+    p += publicLen;
     if (end - p < 2) {
         ALOGE("private key truncated");
         return NULL;
@@ -201,64 +208,208 @@
         ALOGE("private key length encoding error: size=%ld, end=%d", privateLen, end - p);
         return NULL;
     }
-    d2i_PrivateKey(type, &tmp, &p, privateLen);
+
+    Unique_EVP_PKEY pkey(EVP_PKEY_new());
+    if (pkey.get() == NULL) {
+        logOpenSSLError("unwrap_key");
+        return NULL;
+    }
+    EVP_PKEY* tmp = pkey.get();
+
+    if (d2i_PrivateKey(type, &tmp, &p, privateLen) == NULL) {
+        logOpenSSLError("unwrap_key");
+        return NULL;
+    }
 
     return pkey.release();
 }
 
+/*
+ * Generates a DSA key pair into pkey, using the caller's p, q and g when all three are given
+ * and freshly generated parameters of key_size bits otherwise. Returns 0 on success, else -1.
+ */
+static int generate_dsa_keypair(EVP_PKEY* pkey, const keymaster_dsa_keygen_params_t* dsa_params)
+{
+    if (dsa_params->key_size < 512) {
+        ALOGI("Requested DSA key size is too small (<512)");
+        return -1;
+    }
+
+    Unique_DSA dsa(DSA_new());
+    if (dsa.get() == NULL) {
+        logOpenSSLError("generate_dsa_keypair");
+        return -1;
+    }
+
+    if (dsa_params->generator_len == 0 ||
+            dsa_params->prime_p_len == 0 ||
+            dsa_params->prime_q_len == 0 ||
+            dsa_params->generator == NULL||
+            dsa_params->prime_p == NULL ||
+            dsa_params->prime_q == NULL) {
+        if (DSA_generate_parameters_ex(dsa.get(), dsa_params->key_size, NULL, 0, NULL, NULL,
+                NULL) != 1) {
+            logOpenSSLError("generate_dsa_keypair");
+            return -1;
+        }
+    } else {
+        dsa->g = BN_bin2bn(dsa_params->generator, dsa_params->generator_len, NULL);
+        if (dsa->g == NULL) {
+            logOpenSSLError("generate_dsa_keypair");
+            return -1;
+        }
+        dsa->p = BN_bin2bn(dsa_params->prime_p, dsa_params->prime_p_len, NULL);
+        if (dsa->p == NULL) {
+            logOpenSSLError("generate_dsa_keypair");
+            return -1;
+        }
+        dsa->q = BN_bin2bn(dsa_params->prime_q, dsa_params->prime_q_len, NULL);
+        if (dsa->q == NULL) {
+            logOpenSSLError("generate_dsa_keypair");
+            return -1;
+        }
+    }
+
+    if (DSA_generate_key(dsa.get()) != 1) {
+        logOpenSSLError("generate_dsa_keypair");
+        return -1;
+    }
+
+    if (EVP_PKEY_assign_DSA(pkey, dsa.get()) == 0) {
+        logOpenSSLError("generate_dsa_keypair");
+        return -1;
+    }
+    OWNERSHIP_TRANSFERRED(dsa);
+
+    return 0;
+}
+
+/*
+ * Generates an EC key pair on the named curve selected by ec_params->field_size and assigns
+ * it to pkey. Returns 0 on success, -1 on failure or for an unsupported field size.
+ */
+static int generate_ec_keypair(EVP_PKEY* pkey, const keymaster_ec_keygen_params_t* ec_params)
+{
+    int nid;
+    switch (ec_params->field_size) {
+    case 192: nid = NID_X9_62_prime192v1; break;
+    case 224: nid = NID_secp224r1; break;
+    case 256: nid = NID_X9_62_prime256v1; break;
+    case 384: nid = NID_secp384r1; break;
+    case 521: nid = NID_secp521r1; break;
+    default:
+        ALOGW("Unsupported EC field size %u", (unsigned int) ec_params->field_size);
+        return -1;
+    }
+
+    /* initialize EC key */
+    Unique_EC_KEY eckey(EC_KEY_new());
+    if (eckey.get() == NULL) {
+        logOpenSSLError("generate_ec_keypair");
+        return -1;
+    }
+
+    EC_GROUP* group = EC_GROUP_new_by_curve_name(nid);
+    if (group == NULL) {
+        logOpenSSLError("generate_ec_keypair");
+        return -1;
+    }
+
+    EC_GROUP_set_point_conversion_form(group, POINT_CONVERSION_UNCOMPRESSED);
+    EC_GROUP_set_asn1_flag(group, OPENSSL_EC_NAMED_CURVE);
+
+    /* EC_KEY_set_group() stores its own copy of the group, so the one allocated above is
+     * released here regardless of whether the call succeeded. */
+    int groupSet = EC_KEY_set_group(eckey.get(), group);
+    EC_GROUP_free(group);
+    if (groupSet != 1) {
+        logOpenSSLError("generate_ec_keypair");
+        return -1;
+    }
+
+    /* EC_KEY_check_key() returns 1 for a valid key and 0 otherwise, so require exactly 1. */
+    if (EC_KEY_generate_key(eckey.get()) != 1
+            || EC_KEY_check_key(eckey.get()) != 1) {
+        logOpenSSLError("generate_ec_keypair");
+        return -1;
+    }
+
+    if (EVP_PKEY_assign_EC_KEY(pkey, eckey.get()) == 0) {
+        logOpenSSLError("generate_ec_keypair");
+        return -1;
+    }
+    OWNERSHIP_TRANSFERRED(eckey);
+
+    return 0;
+}
+
+/* Generates an RSA key pair from rsa_params and assigns it to pkey. Returns 0 or -1. */
+static int generate_rsa_keypair(EVP_PKEY* pkey, const keymaster_rsa_keygen_params_t* rsa_params)
+{
+    Unique_BIGNUM bn(BN_new());
+    if (bn.get() == NULL) {
+        logOpenSSLError("generate_rsa_keypair");
+        return -1;
+    }
+    if (BN_set_word(bn.get(), rsa_params->public_exponent) == 0) {
+        logOpenSSLError("generate_rsa_keypair");
+        return -1;
+    }
+
+    /* initialize RSA */
+    Unique_RSA rsa(RSA_new());
+    if (rsa.get() == NULL) {
+        logOpenSSLError("generate_rsa_keypair");
+        return -1;
+    }
+    /* RSA_check_key() returns 1 only for a valid key; 0 (invalid) and -1 (error) both fail. */
+    if (!RSA_generate_key_ex(rsa.get(), rsa_params->modulus_size, bn.get(), NULL)
+            || RSA_check_key(rsa.get()) != 1) {
+        logOpenSSLError("generate_rsa_keypair");
+        return -1;
+    }
+
+    if (EVP_PKEY_assign_RSA(pkey, rsa.get()) == 0) {
+        logOpenSSLError("generate_rsa_keypair");
+        return -1;
+    }
+    OWNERSHIP_TRANSFERRED(rsa);
+
+    return 0;
+}
+
 static int openssl_generate_keypair(const keymaster_device_t* dev,
         const keymaster_keypair_t key_type, const void* key_params,
         uint8_t** keyBlob, size_t* keyBlobLength) {
     ssize_t privateLen, publicLen;
 
-    if (key_type != TYPE_RSA) {
-        ALOGW("Unsupported key type %d", key_type);
-        return -1;
-    } else if (key_params == NULL) {
-        ALOGW("key_params == null");
-        return -1;
-    }
-
-    keymaster_rsa_keygen_params_t* rsa_params = (keymaster_rsa_keygen_params_t*) key_params;
-
-    Unique_BIGNUM bn(BN_new());
-    if (bn.get() == NULL) {
-        logOpenSSLError("openssl_generate_keypair");
-        return -1;
-    }
-
-    if (BN_set_word(bn.get(), rsa_params->public_exponent) == 0) {
-        logOpenSSLError("openssl_generate_keypair");
-        return -1;
-    }
-
-    /* initialize RSA */
-    Unique_RSA rsa(RSA_new());
-    if (rsa.get() == NULL) {
-        logOpenSSLError("openssl_generate_keypair");
-        return -1;
-    }
-
-    if (!RSA_generate_key_ex(rsa.get(), rsa_params->modulus_size, bn.get(), NULL)
-            || RSA_check_key(rsa.get()) < 0) {
-        logOpenSSLError("openssl_generate_keypair");
-        return -1;
-    }
-
-    /* assign to EVP */
     Unique_EVP_PKEY pkey(EVP_PKEY_new());
     if (pkey.get() == NULL) {
         logOpenSSLError("openssl_generate_keypair");
         return -1;
     }
 
-    if (EVP_PKEY_assign_RSA(pkey.get(), rsa.get()) == 0) {
-        logOpenSSLError("openssl_generate_keypair");
+    if (key_params == NULL) {
+        ALOGW("key_params == null");
+        return -1;
+    } else if (key_type == TYPE_DSA) {
+        const keymaster_dsa_keygen_params_t* dsa_params =
+                (const keymaster_dsa_keygen_params_t*) key_params;
+        generate_dsa_keypair(pkey.get(), dsa_params);
+    } else if (key_type == TYPE_EC) {
+        const keymaster_ec_keygen_params_t* ec_params =
+                (const keymaster_ec_keygen_params_t*) key_params;
+        generate_ec_keypair(pkey.get(), ec_params);
+    } else if (key_type == TYPE_RSA) {
+        const keymaster_rsa_keygen_params_t* rsa_params =
+                (const keymaster_rsa_keygen_params_t*) key_params;
+        generate_rsa_keypair(pkey.get(), rsa_params);
+    } else {
+        ALOGW("Unsupported key type %d", key_type);
         return -1;
     }
-    OWNERSHIP_TRANSFERRED(rsa);
 
-    if (wrap_key(pkey.get(), EVP_PKEY_RSA, keyBlob, keyBlobLength)) {
+    if (wrap_key(pkey.get(), EVP_PKEY_type(pkey->type), keyBlob, keyBlobLength)) {
         return -1;
     }
 
@@ -338,6 +489,105 @@
     return 0;
 }
 
+/* Signs data with the DSA key in pkey; on success *signedData is a malloc()ed buffer. */
+static int sign_dsa(EVP_PKEY* pkey, keymaster_dsa_sign_params_t* sign_params, const uint8_t* data,
+        const size_t dataLength, uint8_t** signedData, size_t* signedDataLength) {
+    if (sign_params->digest_type != DIGEST_NONE) {
+        ALOGW("Cannot handle digest type %d", sign_params->digest_type);
+        return -1;
+    }
+
+    Unique_DSA dsa(EVP_PKEY_get1_DSA(pkey));
+    if (dsa.get() == NULL) {
+        logOpenSSLError("openssl_sign_dsa");
+        return -1;
+    }
+
+    unsigned int dsaSize = DSA_size(dsa.get());
+    uint8_t* sigBuf = reinterpret_cast<uint8_t*>(malloc(dsaSize));
+    if (sigBuf == NULL) {
+        logOpenSSLError("openssl_sign_dsa");
+        return -1;
+    }
+
+    if (DSA_sign(0, data, dataLength, sigBuf, &dsaSize, dsa.get()) <= 0) {
+        logOpenSSLError("openssl_sign_dsa");
+        free(sigBuf);  /* malloc()ed, so it must go back through free(), not delete */
+        return -1;
+    }
+
+    *signedDataLength = dsaSize;
+    *signedData = sigBuf;
+    return 0;
+}
+
+/* Signs data with the EC key in pkey; on success *signedData is a malloc()ed buffer. */
+static int sign_ec(EVP_PKEY* pkey, keymaster_ec_sign_params_t* sign_params, const uint8_t* data,
+        const size_t dataLength, uint8_t** signedData, size_t* signedDataLength) {
+    if (sign_params->digest_type != DIGEST_NONE) {
+        ALOGW("Cannot handle digest type %d", sign_params->digest_type);
+        return -1;
+    }
+
+    Unique_EC_KEY eckey(EVP_PKEY_get1_EC_KEY(pkey));
+    if (eckey.get() == NULL) {
+        logOpenSSLError("openssl_sign_ec");
+        return -1;
+    }
+
+    unsigned int ecdsaSize = ECDSA_size(eckey.get());
+    uint8_t* sigBuf = reinterpret_cast<uint8_t*>(malloc(ecdsaSize));
+    if (sigBuf == NULL) {
+        logOpenSSLError("openssl_sign_ec");
+        return -1;
+    }
+
+    if (ECDSA_sign(0, data, dataLength, sigBuf, &ecdsaSize, eckey.get()) <= 0) {
+        logOpenSSLError("openssl_sign_ec");
+        free(sigBuf);  /* malloc()ed, so it must go back through free(), not delete */
+        return -1;
+    }
+
+    *signedDataLength = ecdsaSize;
+    *signedData = sigBuf;
+    return 0;
+}
+
+
+/* Raw-RSA signs data with the key in pkey; on success *signedData is a malloc()ed buffer. */
+static int sign_rsa(EVP_PKEY* pkey, keymaster_rsa_sign_params_t* sign_params, const uint8_t* data,
+        const size_t dataLength, uint8_t** signedData, size_t* signedDataLength) {
+    if (sign_params->digest_type != DIGEST_NONE) {
+        ALOGW("Cannot handle digest type %d", sign_params->digest_type);
+        return -1;
+    } else if (sign_params->padding_type != PADDING_NONE) {
+        ALOGW("Cannot handle padding type %d", sign_params->padding_type);
+        return -1;
+    }
+
+    Unique_RSA rsa(EVP_PKEY_get1_RSA(pkey));
+    if (rsa.get() == NULL) {
+        logOpenSSLError("openssl_sign_rsa");
+        return -1;
+    }
+
+    uint8_t* sigBuf = reinterpret_cast<uint8_t*>(malloc(dataLength));
+    if (sigBuf == NULL) {
+        logOpenSSLError("openssl_sign_rsa");
+        return -1;
+    }
+
+    if (RSA_private_encrypt(dataLength, data, sigBuf, rsa.get(), RSA_NO_PADDING) <= 0) {
+        logOpenSSLError("openssl_sign_rsa");
+        free(sigBuf);  /* malloc()ed, so it must go back through free(), not delete */
+        return -1;
+    }
+
+    *signedDataLength = dataLength;
+    *signedData = sigBuf;
+    return 0;
+}
+
 static int openssl_sign_data(const keymaster_device_t* dev,
         const void* params,
         const uint8_t* keyBlob, const size_t keyBlobLength,
@@ -361,65 +611,69 @@
         return -1;
     }
 
-    if (EVP_PKEY_type(pkey->type) != EVP_PKEY_RSA) {
-        ALOGW("Cannot handle non-RSA keys yet");
+    int type = EVP_PKEY_type(pkey->type);
+    if (type == EVP_PKEY_DSA) {
+        keymaster_dsa_sign_params_t* sign_params = (keymaster_dsa_sign_params_t*) params;
+        return sign_dsa(pkey.get(), sign_params, data, dataLength, signedData, signedDataLength);
+    } else if (type == EVP_PKEY_EC) {
+        keymaster_ec_sign_params_t* sign_params = (keymaster_ec_sign_params_t*) params;
+        return sign_ec(pkey.get(), sign_params, data, dataLength, signedData, signedDataLength);
+    } else if (type == EVP_PKEY_RSA) {
+        keymaster_rsa_sign_params_t* sign_params = (keymaster_rsa_sign_params_t*) params;
+        return sign_rsa(pkey.get(), sign_params, data, dataLength, signedData, signedDataLength);
+    } else {
+        ALOGW("Unsupported key type");
         return -1;
     }
+}
 
-    keymaster_rsa_sign_params_t* sign_params = (keymaster_rsa_sign_params_t*) params;
+static int verify_dsa(EVP_PKEY* pkey, keymaster_dsa_sign_params_t* sign_params,
+        const uint8_t* signedData, const size_t signedDataLength, const uint8_t* signature,
+        const size_t signatureLength) {
     if (sign_params->digest_type != DIGEST_NONE) {
         ALOGW("Cannot handle digest type %d", sign_params->digest_type);
         return -1;
-    } else if (sign_params->padding_type != PADDING_NONE) {
-        ALOGW("Cannot handle padding type %d", sign_params->padding_type);
+    }
+
+    Unique_DSA dsa(EVP_PKEY_get1_DSA(pkey));
+    if (dsa.get() == NULL) {
+        logOpenSSLError("openssl_verify_dsa");
         return -1;
     }
 
-    Unique_RSA rsa(EVP_PKEY_get1_RSA(pkey.get()));
-    if (rsa.get() == NULL) {
-        logOpenSSLError("openssl_sign_data");
+    if (DSA_verify(0, signedData, signedDataLength, signature, signatureLength, dsa.get()) <= 0) {
+        logOpenSSLError("openssl_verify_dsa");
         return -1;
     }
 
-    UniquePtr<uint8_t> signedDataPtr(reinterpret_cast<uint8_t*>(malloc(dataLength)));
-    if (signedDataPtr.get() == NULL) {
-        logOpenSSLError("openssl_sign_data");
-        return -1;
-    }
-
-    unsigned char* tmp = reinterpret_cast<unsigned char*>(signedDataPtr.get());
-    if (RSA_private_encrypt(dataLength, data, tmp, rsa.get(), RSA_NO_PADDING) <= 0) {
-        logOpenSSLError("openssl_sign_data");
-        return -1;
-    }
-
-    *signedDataLength = dataLength;
-    *signedData = signedDataPtr.release();
     return 0;
 }
 
-static int openssl_verify_data(const keymaster_device_t* dev,
-        const void* params,
-        const uint8_t* keyBlob, const size_t keyBlobLength,
-        const uint8_t* signedData, const size_t signedDataLength,
-        const uint8_t* signature, const size_t signatureLength) {
-
-    if (signedData == NULL || signature == NULL) {
-        ALOGW("data or signature buffers == NULL");
+static int verify_ec(EVP_PKEY* pkey, keymaster_ec_sign_params_t* sign_params,
+        const uint8_t* signedData, const size_t signedDataLength, const uint8_t* signature,
+        const size_t signatureLength) {
+    if (sign_params->digest_type != DIGEST_NONE) {
+        ALOGW("Cannot handle digest type %d", sign_params->digest_type);
         return -1;
     }
 
-    Unique_EVP_PKEY pkey(unwrap_key(keyBlob, keyBlobLength));
-    if (pkey.get() == NULL) {
+    Unique_EC_KEY eckey(EVP_PKEY_get1_EC_KEY(pkey));
+    if (eckey.get() == NULL) {
+        logOpenSSLError("openssl_verify_ec");
         return -1;
     }
 
-    if (EVP_PKEY_type(pkey->type) != EVP_PKEY_RSA) {
-        ALOGW("Cannot handle non-RSA keys yet");
+    if (ECDSA_verify(0, signedData, signedDataLength, signature, signatureLength, eckey.get()) <= 0) {
+        logOpenSSLError("openssl_verify_ec");
         return -1;
     }
 
-    keymaster_rsa_sign_params_t* sign_params = (keymaster_rsa_sign_params_t*) params;
+    return 0;
+}
+
+static int verify_rsa(EVP_PKEY* pkey, keymaster_rsa_sign_params_t* sign_params,
+        const uint8_t* signedData, const size_t signedDataLength, const uint8_t* signature,
+        const size_t signatureLength) {
     if (sign_params->digest_type != DIGEST_NONE) {
         ALOGW("Cannot handle digest type %d", sign_params->digest_type);
         return -1;
@@ -431,7 +685,7 @@
         return -1;
     }
 
-    Unique_RSA rsa(EVP_PKEY_get1_RSA(pkey.get()));
+    Unique_RSA rsa(EVP_PKEY_get1_RSA(pkey));
     if (rsa.get() == NULL) {
         logOpenSSLError("openssl_verify_data");
         return -1;
@@ -457,9 +711,40 @@
     return result == 0 ? 0 : -1;
 }
 
+/* Verifies signature over signedData using the key wrapped in keyBlob. Returns 0 or -1. */
+static int openssl_verify_data(const keymaster_device_t* dev,
+        const void* params,
+        const uint8_t* keyBlob, const size_t keyBlobLength,
+        const uint8_t* signedData, const size_t signedDataLength,
+        const uint8_t* signature, const size_t signatureLength) {
+    if (signedData == NULL || signature == NULL) {
+        ALOGW("data or signature buffers == NULL");
+        return -1;
+    }
+    Unique_EVP_PKEY pkey(unwrap_key(keyBlob, keyBlobLength));
+    if (pkey.get() == NULL) {
+        return -1;
+    }
+
+    int type = EVP_PKEY_type(pkey->type);
+    if (type == EVP_PKEY_DSA) {
+        return verify_dsa(pkey.get(), (keymaster_dsa_sign_params_t*) params, signedData,
+                signedDataLength, signature, signatureLength);
+    } else if (type == EVP_PKEY_EC) {
+        return verify_ec(pkey.get(), (keymaster_ec_sign_params_t*) params, signedData,
+                signedDataLength, signature, signatureLength);
+    } else if (type == EVP_PKEY_RSA) {
+        return verify_rsa(pkey.get(), (keymaster_rsa_sign_params_t*) params, signedData,
+                signedDataLength, signature, signatureLength);
+    } else {
+        ALOGW("Unsupported key type %d", type);
+        return -1;
+    }
+}
+
 /* Close an opened OpenSSL instance */
 static int openssl_close(hw_device_t *dev) {
-    free(dev);
+    delete dev;
     return 0;
 }