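Add HKDF and EC crypto operations.

Add Rust wrappers for HKDF extract/expand, EC key generation and derivation
from a secret, ECDH shared-secret computation and EC point (de)serialization,
and drop the explicit size_t casts from the existing FFI calls.
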
Test: keystore2_crypto_test_rust
Change-Id: Ice2facdc1b41f4e4ece839c2a3b956889e813960
diff --git a/keystore2/src/crypto/lib.rs b/keystore2/src/crypto/lib.rs
index 338bdb9..92b257c 100644
--- a/keystore2/src/crypto/lib.rs
+++ b/keystore2/src/crypto/lib.rs
@@ -19,8 +19,13 @@
mod zvec;
pub use error::Error;
use keystore2_crypto_bindgen::{
- generateKeyFromPassword, randomBytes, size_t, AES_gcm_decrypt, AES_gcm_encrypt,
+ generateKeyFromPassword, randomBytes, AES_gcm_decrypt, AES_gcm_encrypt, ECDHComputeKey,
+ ECKEYDeriveFromSecret, ECKEYGenerateKey, ECPOINTOct2Point, ECPOINTPoint2Oct, EC_KEY_free,
+ EC_KEY_get0_public_key, EC_POINT_free, HKDFExpand, HKDFExtract, EC_KEY, EC_MAX_BYTES, EC_POINT,
+ EVP_MAX_MD_SIZE,
};
+use std::convert::TryInto;
+use std::marker::PhantomData;
pub use zvec::ZVec;
/// Length of the expected initialization vector.
@@ -43,7 +48,7 @@
pub fn generate_aes256_key() -> Result<ZVec, Error> {
// Safety: key has the same length as the requested number of random bytes.
let mut key = ZVec::new(AES_256_KEY_LENGTH)?;
- if unsafe { randomBytes(key.as_mut_ptr(), AES_256_KEY_LENGTH as size_t) } {
+ if unsafe { randomBytes(key.as_mut_ptr(), AES_256_KEY_LENGTH) } {
Ok(key)
} else {
Err(Error::RandomNumberGenerationFailed)
@@ -54,7 +59,7 @@
pub fn generate_salt() -> Result<Vec<u8>, Error> {
// Safety: salt has the same length as the requested number of random bytes.
let mut salt = vec![0; SALT_LENGTH];
- if unsafe { randomBytes(salt.as_mut_ptr(), SALT_LENGTH as size_t) } {
+ if unsafe { randomBytes(salt.as_mut_ptr(), SALT_LENGTH) } {
Ok(salt)
} else {
Err(Error::RandomNumberGenerationFailed)
@@ -91,9 +96,9 @@
AES_gcm_decrypt(
data.as_ptr(),
result.as_mut_ptr(),
- data.len() as size_t,
+ data.len(),
key.as_ptr(),
- key.len() as size_t,
+ key.len(),
iv.as_ptr(),
tag.as_ptr(),
)
@@ -111,7 +116,7 @@
let mut iv = vec![0; IV_LENGTH];
// Safety: iv is longer than GCM_IV_LENGTH, which is 12 while IV_LENGTH is 16.
// The iv needs to be 16 bytes long, but the last 4 bytes remain zeroed.
- if !unsafe { randomBytes(iv.as_mut_ptr(), GCM_IV_LENGTH as size_t) } {
+ if !unsafe { randomBytes(iv.as_mut_ptr(), GCM_IV_LENGTH) } {
return Err(Error::RandomNumberGenerationFailed);
}
@@ -126,9 +131,9 @@
AES_gcm_encrypt(
data.as_ptr(),
result.as_mut_ptr(),
- data.len() as size_t,
+ data.len(),
key.as_ptr(),
- key.len() as size_t,
+ key.len(),
iv.as_ptr(),
tag.as_mut_ptr(),
)
@@ -166,9 +171,9 @@
unsafe {
generateKeyFromPassword(
result.as_mut_ptr(),
- result.len() as size_t,
+ result.len(),
pw.as_ptr() as *const std::os::raw::c_char,
- pw.len() as size_t,
+ pw.len(),
salt,
)
};
@@ -176,6 +181,178 @@
Ok(result)
}
+/// Calls the boringssl HKDF_extract function.
+///
+/// Returns the pseudorandom key derived from `secret` and `salt`, truncated to
+/// the length reported by the C wrapper (at most `EVP_MAX_MD_SIZE` bytes).
+pub fn hkdf_extract(secret: &[u8], salt: &[u8]) -> Result<ZVec, Error> {
+    let capacity: usize = EVP_MAX_MD_SIZE.try_into().unwrap();
+    let mut prk = ZVec::new(capacity)?;
+    let mut written = 0;
+    // Safety: HKDF_extract writes at most EVP_MAX_MD_SIZE bytes.
+    // Secret and salt point to valid buffers.
+    let ok = unsafe {
+        HKDFExtract(
+            prk.as_mut_ptr(),
+            &mut written,
+            secret.as_ptr(),
+            secret.len(),
+            salt.as_ptr(),
+            salt.len(),
+        )
+    };
+    // Reject a reported failure as well as an out-of-range length; the latter
+    // should never happen according to the boringssl API.
+    if !ok || written > capacity {
+        return Err(Error::HKDFExtractFailed);
+    }
+    // HKDF_extract may write fewer than the maximum number of bytes, so the
+    // buffer is shrunk to the reported length.
+    prk.reduce_len(written);
+    Ok(prk)
+}
+
+/// Calls the boringssl HKDF_expand function.
+///
+/// Expands the pseudorandom key `prk` together with `info` into a buffer of
+/// exactly `out_len` bytes.
+pub fn hkdf_expand(out_len: usize, prk: &[u8], info: &[u8]) -> Result<ZVec, Error> {
+    let mut okm = ZVec::new(out_len)?;
+    // Safety: HKDF_expand writes out_len bytes to the buffer.
+    // prk and info are valid buffers.
+    let expanded = unsafe {
+        HKDFExpand(okm.as_mut_ptr(), out_len, prk.as_ptr(), prk.len(), info.as_ptr(), info.len())
+    };
+    if expanded {
+        Ok(okm)
+    } else {
+        Err(Error::HKDFExpandFailed)
+    }
+}
+
+/// Owning wrapper around a boringssl `EC_KEY`.
+///
+/// The wrapped key is released with `EC_KEY_free` when this value is dropped.
+pub struct ECKey(*mut EC_KEY);
+
+impl Drop for ECKey {
+    fn drop(&mut self) {
+        let raw_key = self.0;
+        // Safety: ECKey values are only constructed from valid EC_KEY pointers,
+        // and each ECKey is the sole owner of its key.
+        unsafe { EC_KEY_free(raw_key) };
+    }
+}
+
+// Wrappers around the boringssl EC_POINT type.
+// A point is either owned by its wrapper (`OwnedECPoint`), in which case it is
+// freed on drop, or borrowed from another object such as an EC_KEY
+// (`BorrowedECPoint`), in which case it is never freed here and the wrapper's
+// lifetime parameter keeps it from outliving its owner.
+
+/// An owned EC_POINT object, released with `EC_POINT_free` on drop.
+pub struct OwnedECPoint(*mut EC_POINT);
+
+/// A pointer to an EC_POINT object owned by someone else.
+///
+/// The lifetime `'a` ties this wrapper to the owner of the point.
+pub struct BorrowedECPoint<'a> {
+    data: *const EC_POINT,
+    phantom: PhantomData<&'a EC_POINT>,
+}
+
+impl OwnedECPoint {
+    /// Get the wrapped EC_POINT object.
+    pub fn get_point(&self) -> &EC_POINT {
+        let raw_point: *const EC_POINT = self.0;
+        // Safety: We only create OwnedECPoint objects for valid EC_POINTs.
+        let point = unsafe { raw_point.as_ref() };
+        point.unwrap()
+    }
+}
+
+impl BorrowedECPoint<'_> {
+    /// Get the wrapped EC_POINT object.
+    pub fn get_point(&self) -> &EC_POINT {
+        // Safety: We only create BorrowedECPoint objects for valid EC_POINTs.
+        let point = unsafe { self.data.as_ref() };
+        point.unwrap()
+    }
+}
+
+impl Drop for OwnedECPoint {
+    fn drop(&mut self) {
+        let raw_point = self.0;
+        // Safety: We only create OwnedECPoint objects for valid
+        // EC_POINTs and they are the sole owners of those points.
+        unsafe { EC_POINT_free(raw_point) };
+    }
+}
+
+/// Calls the boringssl ECDH_compute_key function.
+///
+/// Computes the shared secret between `pub_key` and the private part of
+/// `priv_key`. The secret is returned in a `ZVec` truncated to the length
+/// reported by the C wrapper.
+///
+/// # Errors
+///
+/// Returns `Error::ECDHComputeKeyFailed` if the wrapper reports failure (any
+/// negative return value) or a length larger than the output buffer.
+pub fn ecdh_compute_key(pub_key: &EC_POINT, priv_key: &ECKey) -> Result<ZVec, Error> {
+    let mut buf = ZVec::new(EC_MAX_BYTES)?;
+    // Safety: Our ECDHComputeKey wrapper passes EC_MAX_BYTES to ECDH_compute_key, which
+    // writes at most that many bytes to the output.
+    // The two keys are valid objects.
+    let result =
+        unsafe { ECDHComputeKey(buf.as_mut_ptr() as *mut std::ffi::c_void, pub_key, priv_key.0) };
+    // ECDH_compute_key signals failure with -1. Treat every negative value as an
+    // error rather than only -1, so this conversion can never panic.
+    let out_len: usize = result.try_into().map_err(|_| Error::ECDHComputeKeyFailed)?;
+    // According to the boringssl API, this should never happen.
+    if out_len > buf.len() {
+        return Err(Error::ECDHComputeKeyFailed);
+    }
+    // ECDH_compute_key may write fewer than the maximum number of bytes, so we
+    // truncate the buffer.
+    buf.reduce_len(out_len);
+    Ok(buf)
+}
+
+/// Calls the boringssl EC_KEY_generate_key function.
+///
+/// Returns the new key wrapped in an owning `ECKey`. Returns
+/// `Error::ECKEYGenerateKeyFailed` if the C wrapper returns null.
+pub fn ec_key_generate_key() -> Result<ECKey, Error> {
+    // Safety: Creates a new key on its own.
+    let raw_key = unsafe { ECKEYGenerateKey() };
+    if raw_key.is_null() {
+        Err(Error::ECKEYGenerateKeyFailed)
+    } else {
+        Ok(ECKey(raw_key))
+    }
+}
+
+/// Calls the boringssl EC_KEY_derive_from_secret function.
+///
+/// Derives a key from `secret` via the C wrapper and returns it as an owning
+/// `ECKey`. Returns `Error::ECKEYDeriveFailed` if the wrapper returns null.
+pub fn ec_key_derive_from_secret(secret: &[u8]) -> Result<ECKey, Error> {
+    // Safety: secret is a valid buffer.
+    let derived = unsafe { ECKEYDeriveFromSecret(secret.as_ptr(), secret.len()) };
+    if derived.is_null() {
+        Err(Error::ECKEYDeriveFailed)
+    } else {
+        Ok(ECKey(derived))
+    }
+}
+
+/// Calls the boringssl EC_KEY_get0_public_key function.
+///
+/// The returned point is borrowed from `key`. It is never freed by the wrapper,
+/// and its lifetime is tied to `key`.
+pub fn ec_key_get0_public_key(key: &ECKey) -> BorrowedECPoint<'_> {
+    // Safety: The key is valid.
+    // This returns a pointer to a key, so we create an immutable variant.
+    let public_point = unsafe { EC_KEY_get0_public_key(key.0) };
+    // NOTE(review): boringssl may return NULL for a key without a public part,
+    // which would make `get_point` panic; keys built by this module's
+    // constructors are assumed to always have one — confirm.
+    BorrowedECPoint { data: public_point, phantom: PhantomData }
+}
+
+/// Calls the boringssl EC_POINT_point2oct.
+///
+/// Serializes `point` into a buffer of at most 65 bytes and returns exactly
+/// the bytes that were written.
+pub fn ec_point_point_to_oct(point: &EC_POINT) -> Result<Vec<u8>, Error> {
+    // We fix the length to 65 (1 + 2 * field_elem_size), as we get an error if it's too small.
+    const MAX_OCT_LEN: usize = 65;
+    let mut oct = vec![0; MAX_OCT_LEN];
+    // Safety: EC_POINT_point2oct writes at most MAX_OCT_LEN bytes. The point is valid.
+    let written = unsafe { ECPOINTPoint2Oct(point, oct.as_mut_ptr(), MAX_OCT_LEN) };
+    // A zero return signals failure; a length beyond the buffer should never
+    // happen according to the boringssl API.
+    if written == 0 || written > MAX_OCT_LEN {
+        return Err(Error::ECPoint2OctFailed);
+    }
+    oct.truncate(written);
+    Ok(oct)
+}
+
+/// Calls the boringssl EC_POINT_oct2point function.
+///
+/// Parses `buf` into a new EC_POINT created by the C wrapper. The point is
+/// owned by the returned `OwnedECPoint` and freed when it is dropped.
+pub fn ec_point_oct_to_point(buf: &[u8]) -> Result<OwnedECPoint, Error> {
+    // Safety: The buffer is valid.
+    let raw_point = unsafe { ECPOINTOct2Point(buf.as_ptr(), buf.len()) };
+    if raw_point.is_null() {
+        // NOTE(review): this reports `ECPoint2OctFailed` even though the failing
+        // operation is oct2point; consider a dedicated error variant.
+        Err(Error::ECPoint2OctFailed)
+    } else {
+        Ok(OwnedECPoint(raw_point))
+    }
+}
+
#[cfg(test)]
mod tests {
@@ -249,4 +426,39 @@
}
assert_ne!(key, vec![0; 16]);
}
+
+    #[test]
+    fn test_hkdf() {
+        let extracted = hkdf_extract(&[0; 16], &[0; 16]);
+        assert!(extracted.is_ok());
+        // hkdf_expand must return exactly the requested number of bytes.
+        for expected_len in 4usize..=8 {
+            let expanded = hkdf_expand(expected_len, &[0; 16], &[0; 16]);
+            assert!(expanded.is_ok());
+            assert_eq!(expanded.unwrap().len(), expected_len);
+        }
+    }
+
+    #[test]
+    fn test_ec() {
+        let key = ec_key_generate_key();
+        assert!(key.is_ok());
+        assert!(!key.unwrap().0.is_null());
+
+        let key = ec_key_derive_from_secret(&[42; 16]);
+        assert!(key.is_ok());
+        let key = key.unwrap();
+        assert!(!key.0.is_null());
+
+        let point = ec_key_get0_public_key(&key);
+
+        let result = ecdh_compute_key(point.get_point(), &key);
+        assert!(result.is_ok());
+
+        let oct = ec_point_point_to_oct(point.get_point());
+        assert!(oct.is_ok());
+        let oct = oct.unwrap();
+
+        let point2 = ec_point_oct_to_point(oct.as_slice());
+        assert!(point2.is_ok());
+
+        // Serializing the parsed point again must reproduce the same encoding.
+        let point2 = point2.unwrap();
+        let oct2 = ec_point_point_to_oct(point2.get_point()).unwrap();
+        assert_eq!(oct, oct2);
+
+        // Deriving from the same secret must yield the same public key.
+        let same_key = ec_key_derive_from_secret(&[42; 16]).unwrap();
+        let same_point = ec_key_get0_public_key(&same_key);
+        let same_oct = ec_point_point_to_oct(same_point.get_point()).unwrap();
+        assert_eq!(oct, same_oct);
+
+        // ECDH agreement: both parties must compute the same shared secret.
+        let other_key = ec_key_derive_from_secret(&[7; 16]).unwrap();
+        let other_point = ec_key_get0_public_key(&other_key);
+        let secret_ab = ecdh_compute_key(other_point.get_point(), &key).unwrap();
+        let secret_ba = ecdh_compute_key(point.get_point(), &other_key).unwrap();
+        assert_eq!(&secret_ab[..], &secret_ba[..]);
+    }
}