Add more crypto operations.

Test: keystore2_crypto_test_rust
Change-Id: Ice2facdc1b41f4e4ece839c2a3b956889e813960
diff --git a/keystore2/src/crypto/Android.bp b/keystore2/src/crypto/Android.bp
index 03c42b2..9ecd823 100644
--- a/keystore2/src/crypto/Android.bp
+++ b/keystore2/src/crypto/Android.bp
@@ -47,6 +47,30 @@
     crate_name: "keystore2_crypto_bindgen",
     source_stem: "bindings",
     host_supported: true,
+    shared_libs: ["libcrypto"],
+    bindgen_flags: [
+        "--size_t-is-usize",
+        "--whitelist-function", "randomBytes",
+        "--whitelist-function", "AES_gcm_encrypt",
+        "--whitelist-function", "AES_gcm_decrypt",
+        "--whitelist-function", "CreateKeyId",
+        "--whitelist-function", "generateKeyFromPassword",
+        "--whitelist-function", "HKDFExtract",
+        "--whitelist-function", "HKDFExpand",
+        "--whitelist-function", "ECDHComputeKey",
+        "--whitelist-function", "ECKEYGenerateKey",
+        "--whitelist-function", "ECKEYDeriveFromSecret",
+        "--whitelist-function", "EC_KEY_get0_public_key",
+        "--whitelist-function", "ECPOINTPoint2Oct",
+        "--whitelist-function", "ECPOINTOct2Point",
+        "--whitelist-function", "EC_KEY_free",
+        "--whitelist-function", "EC_POINT_free",
+        "--whitelist-type", "EC_KEY",
+        "--whitelist-type", "EC_POINT",
+        "--whitelist-var", "EC_MAX_BYTES",
+        "--whitelist-var", "EVP_MAX_MD_SIZE",
+    ],
+    cflags: ["-DBORINGSSL_NO_CXX"],
 }
 
 rust_test {
diff --git a/keystore2/src/crypto/crypto.cpp b/keystore2/src/crypto/crypto.cpp
index 173ed11..3cc19c5 100644
--- a/keystore2/src/crypto/crypto.cpp
+++ b/keystore2/src/crypto/crypto.cpp
@@ -20,7 +20,11 @@
 
 #include <log/log.h>
 #include <openssl/aes.h>
+#include <openssl/ec.h>
+#include <openssl/ec_key.h>
+#include <openssl/ecdh.h>
 #include <openssl/evp.h>
+#include <openssl/hkdf.h>
 #include <openssl/rand.h>
 
 #include <vector>
@@ -197,3 +201,63 @@
 
     PKCS5_PBKDF2_HMAC(pw, pw_len, salt, saltSize, 8192, digest, key_len, key);
 }
+
+// New code.
+
+bool HKDFExtract(uint8_t* out_key, size_t* out_len, const uint8_t* secret, size_t secret_len,
+                 const uint8_t* salt, size_t salt_len) {
+    const EVP_MD* digest = EVP_sha256();
+    auto result = HKDF_extract(out_key, out_len, digest, secret, secret_len, salt, salt_len);
+    return result == 1;
+}
+
+bool HKDFExpand(uint8_t* out_key, size_t out_len, const uint8_t* prk, size_t prk_len,
+                const uint8_t* info, size_t info_len) {
+    const EVP_MD* digest = EVP_sha256();
+    auto result = HKDF_expand(out_key, out_len, digest, prk, prk_len, info, info_len);
+    return result == 1;
+}
+
+int ECDHComputeKey(void* out, const EC_POINT* pub_key, const EC_KEY* priv_key) {
+    return ECDH_compute_key(out, EC_MAX_BYTES, pub_key, priv_key, nullptr);
+}
+
+EC_KEY* ECKEYGenerateKey() {
+    EC_KEY* key = EC_KEY_new();
+    EC_GROUP* group = EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1);
+    EC_KEY_set_group(key, group);
+    auto result = EC_KEY_generate_key(key);
+    if (result == 0) {
+        EC_GROUP_free(group);
+        EC_KEY_free(key);
+        return nullptr;
+    }
+    return key;
+}
+
+EC_KEY* ECKEYDeriveFromSecret(const uint8_t* secret, size_t secret_len) {
+    EC_GROUP* group = EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1);
+    auto result = EC_KEY_derive_from_secret(group, secret, secret_len);
+    EC_GROUP_free(group);
+    return result;
+}
+
+size_t ECPOINTPoint2Oct(const EC_POINT* point, uint8_t* buf, size_t len) {
+    EC_GROUP* group = EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1);
+    point_conversion_form_t form = POINT_CONVERSION_UNCOMPRESSED;
+    auto result = EC_POINT_point2oct(group, point, form, buf, len, nullptr);
+    EC_GROUP_free(group);
+    return result;
+}
+
+EC_POINT* ECPOINTOct2Point(const uint8_t* buf, size_t len) {
+    EC_GROUP* group = EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1);
+    EC_POINT* point = EC_POINT_new(group);
+    auto result = EC_POINT_oct2point(group, point, buf, len, nullptr);
+    EC_GROUP_free(group);
+    if (result == 0) {
+        EC_POINT_free(point);
+        return nullptr;
+    }
+    return point;
+}
diff --git a/keystore2/src/crypto/crypto.hpp b/keystore2/src/crypto/crypto.hpp
index 2e597f1..9bd7758 100644
--- a/keystore2/src/crypto/crypto.hpp
+++ b/keystore2/src/crypto/crypto.hpp
@@ -36,6 +36,30 @@
 
   void generateKeyFromPassword(uint8_t* key, size_t key_len, const char* pw,
                                size_t pw_len, const uint8_t* salt);
+
+  #include "openssl/digest.h"
+  #include "openssl/ec_key.h"
+
+  bool HKDFExtract(uint8_t *out_key, size_t *out_len,
+                   const uint8_t *secret, size_t secret_len,
+                   const uint8_t *salt, size_t salt_len);
+
+  bool HKDFExpand(uint8_t *out_key, size_t out_len,
+                  const uint8_t *prk, size_t prk_len,
+                  const uint8_t *info, size_t info_len);
+
+  // We define this as field_elem_size.
+  static const size_t EC_MAX_BYTES = 32;
+
+  int ECDHComputeKey(void *out, const EC_POINT *pub_key, const EC_KEY *priv_key);
+
+  EC_KEY* ECKEYGenerateKey();
+
+  EC_KEY* ECKEYDeriveFromSecret(const uint8_t *secret, size_t secret_len);
+
+  size_t ECPOINTPoint2Oct(const EC_POINT *point, uint8_t *buf, size_t len);
+
+  EC_POINT* ECPOINTOct2Point(const uint8_t *buf, size_t len);
 }
 
 #endif  //  __CRYPTO_H__
diff --git a/keystore2/src/crypto/error.rs b/keystore2/src/crypto/error.rs
index 2eb97b9..1e84fc6 100644
--- a/keystore2/src/crypto/error.rs
+++ b/keystore2/src/crypto/error.rs
@@ -56,4 +56,33 @@
     /// Nix error.
     #[error(transparent)]
     NixError(#[from] nix::Error),
+
+    /// This is returned if the C implementation of HKDFExtract returned false
+    /// or otherwise failed.
+    #[error("Failed to extract.")]
+    HKDFExtractFailed,
+
+    /// This is returned if the C implementation of HKDFExpand returned false.
+    #[error("Failed to expand.")]
+    HKDFExpandFailed,
+
+    /// This is returned if the C implementation of ECDHComputeKey returned -1.
+    #[error("Failed to compute ecdh key.")]
+    ECDHComputeKeyFailed,
+
+    /// This is returned if the C implementation of ECKEYGenerateKey returned null.
+    #[error("Failed to generate key.")]
+    ECKEYGenerateKeyFailed,
+
+    /// This is returned if the C implementation of ECKEYDeriveFromSecret returned null.
+    #[error("Failed to derive key.")]
+    ECKEYDeriveFailed,
+
+    /// This is returned if the C implementation of ECPOINTPoint2Oct returned 0.
+    #[error("Failed to convert point to oct.")]
+    ECPoint2OctFailed,
+
+    /// This is returned if the C implementation of ECPOINTOct2Point returned null.
+    #[error("Failed to convert oct to point.")]
+    ECOct2PointFailed,
 }
diff --git a/keystore2/src/crypto/lib.rs b/keystore2/src/crypto/lib.rs
index 338bdb9..92b257c 100644
--- a/keystore2/src/crypto/lib.rs
+++ b/keystore2/src/crypto/lib.rs
@@ -19,8 +19,13 @@
 mod zvec;
 pub use error::Error;
 use keystore2_crypto_bindgen::{
-    generateKeyFromPassword, randomBytes, size_t, AES_gcm_decrypt, AES_gcm_encrypt,
+    generateKeyFromPassword, randomBytes, AES_gcm_decrypt, AES_gcm_encrypt, ECDHComputeKey,
+    ECKEYDeriveFromSecret, ECKEYGenerateKey, ECPOINTOct2Point, ECPOINTPoint2Oct, EC_KEY_free,
+    EC_KEY_get0_public_key, EC_POINT_free, HKDFExpand, HKDFExtract, EC_KEY, EC_MAX_BYTES, EC_POINT,
+    EVP_MAX_MD_SIZE,
 };
+use std::convert::TryInto;
+use std::marker::PhantomData;
 pub use zvec::ZVec;
 
 /// Length of the expected initialization vector.
@@ -43,7 +48,7 @@
 pub fn generate_aes256_key() -> Result<ZVec, Error> {
     // Safety: key has the same length as the requested number of random bytes.
     let mut key = ZVec::new(AES_256_KEY_LENGTH)?;
-    if unsafe { randomBytes(key.as_mut_ptr(), AES_256_KEY_LENGTH as size_t) } {
+    if unsafe { randomBytes(key.as_mut_ptr(), AES_256_KEY_LENGTH) } {
         Ok(key)
     } else {
         Err(Error::RandomNumberGenerationFailed)
@@ -54,7 +59,7 @@
 pub fn generate_salt() -> Result<Vec<u8>, Error> {
     // Safety: salt has the same length as the requested number of random bytes.
     let mut salt = vec![0; SALT_LENGTH];
-    if unsafe { randomBytes(salt.as_mut_ptr(), SALT_LENGTH as size_t) } {
+    if unsafe { randomBytes(salt.as_mut_ptr(), SALT_LENGTH) } {
         Ok(salt)
     } else {
         Err(Error::RandomNumberGenerationFailed)
@@ -91,9 +96,9 @@
         AES_gcm_decrypt(
             data.as_ptr(),
             result.as_mut_ptr(),
-            data.len() as size_t,
+            data.len(),
             key.as_ptr(),
-            key.len() as size_t,
+            key.len(),
             iv.as_ptr(),
             tag.as_ptr(),
         )
@@ -111,7 +116,7 @@
     let mut iv = vec![0; IV_LENGTH];
     // Safety: iv is longer than GCM_IV_LENGTH, which is 12 while IV_LENGTH is 16.
     // The iv needs to be 16 bytes long, but the last 4 bytes remain zeroed.
-    if !unsafe { randomBytes(iv.as_mut_ptr(), GCM_IV_LENGTH as size_t) } {
+    if !unsafe { randomBytes(iv.as_mut_ptr(), GCM_IV_LENGTH) } {
         return Err(Error::RandomNumberGenerationFailed);
     }
 
@@ -126,9 +131,9 @@
         AES_gcm_encrypt(
             data.as_ptr(),
             result.as_mut_ptr(),
-            data.len() as size_t,
+            data.len(),
             key.as_ptr(),
-            key.len() as size_t,
+            key.len(),
             iv.as_ptr(),
             tag.as_mut_ptr(),
         )
@@ -166,9 +171,9 @@
     unsafe {
         generateKeyFromPassword(
             result.as_mut_ptr(),
-            result.len() as size_t,
+            result.len(),
             pw.as_ptr() as *const std::os::raw::c_char,
-            pw.len() as size_t,
+            pw.len(),
             salt,
         )
     };
@@ -176,6 +181,178 @@
     Ok(result)
 }
 
+/// Calls the boringssl HKDF_extract function.
+pub fn hkdf_extract(secret: &[u8], salt: &[u8]) -> Result<ZVec, Error> {
+    let max_size: usize = EVP_MAX_MD_SIZE.try_into().unwrap();
+    let mut buf = ZVec::new(max_size)?;
+    let mut out_len = 0;
+    // Safety: HKDF_extract writes at most EVP_MAX_MD_SIZE bytes.
+    // Secret and salt point to valid buffers.
+    let result = unsafe {
+        HKDFExtract(
+            buf.as_mut_ptr(),
+            &mut out_len,
+            secret.as_ptr(),
+            secret.len(),
+            salt.as_ptr(),
+            salt.len(),
+        )
+    };
+    if !result {
+        return Err(Error::HKDFExtractFailed);
+    }
+    // According to the boringssl API, this should never happen.
+    if out_len > max_size {
+        return Err(Error::HKDFExtractFailed);
+    }
+    // HKDF_extract may write fewer than the maximum number of bytes, so we
+    // truncate the buffer.
+    buf.reduce_len(out_len);
+    Ok(buf)
+}
+
+/// Calls the boringssl HKDF_expand function.
+pub fn hkdf_expand(out_len: usize, prk: &[u8], info: &[u8]) -> Result<ZVec, Error> {
+    let mut buf = ZVec::new(out_len)?;
+    // Safety: HKDF_expand writes out_len bytes to the buffer.
+    // prk and info are valid buffers.
+    let result = unsafe {
+        HKDFExpand(buf.as_mut_ptr(), out_len, prk.as_ptr(), prk.len(), info.as_ptr(), info.len())
+    };
+    if !result {
+        return Err(Error::HKDFExpandFailed);
+    }
+    Ok(buf)
+}
+
+/// A wrapper around the boringssl EC_KEY type that frees it on drop.
+pub struct ECKey(*mut EC_KEY);
+
+impl Drop for ECKey {
+    fn drop(&mut self) {
+        // Safety: We only create ECKey objects for valid EC_KEYs
+        // and they are the sole owners of those keys.
+        unsafe { EC_KEY_free(self.0) };
+    }
+}
+
+// Wrappers around the boringssl EC_POINT type.
+// The EC_POINT can either be owned (and therefore mutable) or a pointer to an
+// EC_POINT owned by someone else (and thus immutable).  The former are freed
+// on drop.
+
+/// An owned EC_POINT object.
+pub struct OwnedECPoint(*mut EC_POINT);
+
+/// A pointer to an EC_POINT object.
+pub struct BorrowedECPoint<'a> {
+    data: *const EC_POINT,
+    phantom: PhantomData<&'a EC_POINT>,
+}
+
+impl OwnedECPoint {
+    /// Get the wrapped EC_POINT object.
+    pub fn get_point(&self) -> &EC_POINT {
+        // Safety: We only create OwnedECPoint objects for valid EC_POINTs.
+        unsafe { self.0.as_ref().unwrap() }
+    }
+}
+
+impl<'a> BorrowedECPoint<'a> {
+    /// Get the wrapped EC_POINT object.
+    pub fn get_point(&self) -> &EC_POINT {
+        // Safety: We only create BorrowedECPoint objects for valid EC_POINTs.
+        unsafe { self.data.as_ref().unwrap() }
+    }
+}
+
+impl Drop for OwnedECPoint {
+    fn drop(&mut self) {
+        // Safety: We only create OwnedECPoint objects for valid
+        // EC_POINTs and they are the sole owners of those points.
+        unsafe { EC_POINT_free(self.0) };
+    }
+}
+
+/// Calls the boringssl ECDH_compute_key function.
+pub fn ecdh_compute_key(pub_key: &EC_POINT, priv_key: &ECKey) -> Result<ZVec, Error> {
+    let mut buf = ZVec::new(EC_MAX_BYTES)?;
+    // Safety: Our ECDHComputeKey wrapper passes EC_MAX_BYES to ECDH_compute_key, which
+    // writes at most that many bytes to the output.
+    // The two keys are valid objects.
+    let result =
+        unsafe { ECDHComputeKey(buf.as_mut_ptr() as *mut std::ffi::c_void, pub_key, priv_key.0) };
+    if result == -1 {
+        return Err(Error::ECDHComputeKeyFailed);
+    }
+    let out_len = result.try_into().unwrap();
+    // According to the boringssl API, this should never happen.
+    if out_len > buf.len() {
+        return Err(Error::ECDHComputeKeyFailed);
+    }
+    // ECDH_compute_key may write fewer than the maximum number of bytes, so we
+    // truncate the buffer.
+    buf.reduce_len(out_len);
+    Ok(buf)
+}
+
+/// Calls the boringssl EC_KEY_generate_key function.
+pub fn ec_key_generate_key() -> Result<ECKey, Error> {
+    // Safety: Creates a new key on its own.
+    let key = unsafe { ECKEYGenerateKey() };
+    if key.is_null() {
+        return Err(Error::ECKEYGenerateKeyFailed);
+    }
+    Ok(ECKey(key))
+}
+
+/// Calls the boringssl EC_KEY_derive_from_secret function.
+pub fn ec_key_derive_from_secret(secret: &[u8]) -> Result<ECKey, Error> {
+    // Safety: secret is a valid buffer.
+    let result = unsafe { ECKEYDeriveFromSecret(secret.as_ptr(), secret.len()) };
+    if result.is_null() {
+        return Err(Error::ECKEYDeriveFailed);
+    }
+    Ok(ECKey(result))
+}
+
+/// Calls the boringssl EC_KEY_get0_public_key function.
+pub fn ec_key_get0_public_key(key: &ECKey) -> BorrowedECPoint {
+    // Safety: The key is valid.
+    // This returns a pointer to a key, so we create an immutable variant.
+    BorrowedECPoint { data: unsafe { EC_KEY_get0_public_key(key.0) }, phantom: PhantomData }
+}
+
+/// Calls the boringssl EC_POINT_point2oct.
+pub fn ec_point_point_to_oct(point: &EC_POINT) -> Result<Vec<u8>, Error> {
+    // We fix the length to 65 (1 + 2 * field_elem_size), as we get an error if it's too small.
+    let len = 65;
+    let mut buf = vec![0; len];
+    // Safety: EC_POINT_point2oct writes at most len bytes. The point is valid.
+    let result = unsafe { ECPOINTPoint2Oct(point, buf.as_mut_ptr(), len) };
+    if result == 0 {
+        return Err(Error::ECPoint2OctFailed);
+    }
+    // According to the boringssl API, this should never happen.
+    if result > len {
+        return Err(Error::ECPoint2OctFailed);
+    }
+    buf.resize(result, 0);
+    Ok(buf)
+}
+
+/// Calls the boringssl EC_POINT_oct2point function.
+pub fn ec_point_oct_to_point(buf: &[u8]) -> Result<OwnedECPoint, Error> {
+    // Safety: The buffer is valid.
+    let result = unsafe { ECPOINTOct2Point(buf.as_ptr(), buf.len()) };
+    if result.is_null() {
+        return Err(Error::ECPoint2OctFailed);
+    }
+    // Our C wrapper creates a new EC_POINT, so we mark this mutable and free
+    // it on drop.
+    Ok(OwnedECPoint(result))
+}
+
 #[cfg(test)]
 mod tests {
 
@@ -249,4 +426,39 @@
         }
         assert_ne!(key, vec![0; 16]);
     }
+
+    #[test]
+    fn test_hkdf() {
+        let result = hkdf_extract(&[0; 16], &[0; 16]);
+        assert!(result.is_ok());
+        for out_len in 4..=8 {
+            let result = hkdf_expand(out_len, &[0; 16], &[0; 16]);
+            assert!(result.is_ok());
+            assert_eq!(result.unwrap().len(), out_len);
+        }
+    }
+
+    #[test]
+    fn test_ec() {
+        let key = ec_key_generate_key();
+        assert!(key.is_ok());
+        assert!(!key.unwrap().0.is_null());
+
+        let key = ec_key_derive_from_secret(&[42; 16]);
+        assert!(key.is_ok());
+        let key = key.unwrap();
+        assert!(!key.0.is_null());
+
+        let point = ec_key_get0_public_key(&key);
+
+        let result = ecdh_compute_key(point.get_point(), &key);
+        assert!(result.is_ok());
+
+        let oct = ec_point_point_to_oct(point.get_point());
+        assert!(oct.is_ok());
+        let oct = oct.unwrap();
+
+        let point2 = ec_point_oct_to_point(oct.as_slice());
+        assert!(point2.is_ok());
+    }
 }
diff --git a/keystore2/src/crypto/zvec.rs b/keystore2/src/crypto/zvec.rs
index 52addfc..e75e1dc 100644
--- a/keystore2/src/crypto/zvec.rs
+++ b/keystore2/src/crypto/zvec.rs
@@ -21,10 +21,15 @@
 use std::ops::{Deref, DerefMut};
 use std::ptr::write_volatile;
 
-/// A fixed size u8 vector that is zeroed when dropped. Also the data is
-/// pinned in memory with mlock.
+/// A semi fixed size u8 vector that is zeroed when dropped.  It can shrink in
+/// size but cannot grow larger than the original size (and if it shrinks it
+/// still owns the entire buffer).  Also the data is pinned in memory with
+/// mlock.
 #[derive(Default, Eq, PartialEq)]
-pub struct ZVec(Box<[u8]>);
+pub struct ZVec {
+    elems: Box<[u8]>,
+    len: usize,
+}
 
 impl ZVec {
     /// Create a ZVec with the given size.
@@ -34,18 +39,27 @@
         if size > 0 {
             unsafe { mlock(b.as_ptr() as *const std::ffi::c_void, b.len()) }?;
         }
-        Ok(Self(b))
+        Ok(Self { elems: b, len: size })
+    }
+
+    /// Reduce the length to the given value.  Does nothing if that length is
+    /// greater than the length of the vector.  Note that it still owns the
+    /// original allocation even if the length is reduced.
+    pub fn reduce_len(&mut self, len: usize) {
+        if len <= self.elems.len() {
+            self.len = len;
+        }
     }
 }
 
 impl Drop for ZVec {
     fn drop(&mut self) {
-        for i in 0..self.0.len() {
-            unsafe { write_volatile(self.0.as_mut_ptr().add(i), 0) };
+        for i in 0..self.elems.len() {
+            unsafe { write_volatile(self.elems.as_mut_ptr().add(i), 0) };
         }
-        if !self.0.is_empty() {
+        if !self.elems.is_empty() {
             if let Err(e) =
-                unsafe { munlock(self.0.as_ptr() as *const std::ffi::c_void, self.0.len()) }
+                unsafe { munlock(self.elems.as_ptr() as *const std::ffi::c_void, self.elems.len()) }
             {
                 log::error!("In ZVec::drop: `munlock` failed: {:?}.", e);
             }
@@ -57,22 +71,22 @@
     type Target = [u8];
 
     fn deref(&self) -> &Self::Target {
-        &self.0
+        &self.elems[0..self.len]
     }
 }
 
 impl DerefMut for ZVec {
     fn deref_mut(&mut self) -> &mut Self::Target {
-        &mut self.0
+        &mut self.elems[0..self.len]
     }
 }
 
 impl fmt::Debug for ZVec {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        if self.0.is_empty() {
+        if self.elems.is_empty() {
             write!(f, "Zvec empty")
         } else {
-            write!(f, "Zvec size: {} [ Sensitive information redacted ]", self.0.len())
+            write!(f, "Zvec size: {} [ Sensitive information redacted ]", self.len)
         }
     }
 }
@@ -97,6 +111,7 @@
         if !b.is_empty() {
             unsafe { mlock(b.as_ptr() as *const std::ffi::c_void, b.len()) }?;
         }
-        Ok(Self(b))
+        let len = b.len();
+        Ok(Self { elems: b, len })
     }
 }