ECDH encryption module

Add a module for encrypting messages using ECDH, HKDF, and AES-GCM.
Also add serialization and parsing of EC private keys, and remove
derivation of EC keys from secrets; storing a serialized private key
turns out to be a better fit for the way superencryption currently works.

Add a more thorough ECDH test in the crypto module, which simulates an
ephemeral key being used to send a message to the holder of a long-term
key. The new high-level ec_crypto module has a similar round-trip test.

Bug: 163866361
Test: keystore2_crypto_test_rust, keystore2_test
Change-Id: I4c2bb1d8938de078ea37b930619918acc3c28fbe
diff --git a/keystore2/src/crypto/Android.bp b/keystore2/src/crypto/Android.bp
index e386735..21c9b74 100644
--- a/keystore2/src/crypto/Android.bp
+++ b/keystore2/src/crypto/Android.bp
@@ -68,7 +68,8 @@
         "--whitelist-function", "HKDFExpand",
         "--whitelist-function", "ECDHComputeKey",
         "--whitelist-function", "ECKEYGenerateKey",
-        "--whitelist-function", "ECKEYDeriveFromSecret",
+        "--whitelist-function", "ECKEYMarshalPrivateKey",
+        "--whitelist-function", "ECKEYParsePrivateKey",
         "--whitelist-function", "EC_KEY_get0_public_key",
         "--whitelist-function", "ECPOINTPoint2Oct",
         "--whitelist-function", "ECPOINTOct2Point",
diff --git a/keystore2/src/crypto/crypto.cpp b/keystore2/src/crypto/crypto.cpp
index 2e613fd..e4a1ac3 100644
--- a/keystore2/src/crypto/crypto.cpp
+++ b/keystore2/src/crypto/crypto.cpp
@@ -236,10 +236,38 @@
     return key;
 }
 
-EC_KEY* ECKEYDeriveFromSecret(const uint8_t* secret, size_t secret_len) {
+// Serializes the private scalar of priv_key as a DER ECPrivateKey structure
+// (no curve parameters, no public key) into buf, which holds len bytes.
+// Returns the number of bytes written, or 0 on failure (including when buf is
+// too small to hold the encoding).
+size_t ECKEYMarshalPrivateKey(const EC_KEY* priv_key, uint8_t* buf, size_t len) {
+    CBB cbb;
+    size_t out_len;
+    if (!CBB_init_fixed(&cbb, buf, len) ||
+        !EC_KEY_marshal_private_key(&cbb, priv_key, EC_PKEY_NO_PARAMETERS | EC_PKEY_NO_PUBKEY) ||
+        !CBB_finish(&cbb, nullptr, &out_len)) {
+        // CBB_finish releases the CBB's internal state only on success; on any
+        // failure it must be released here or it leaks.
+        CBB_cleanup(&cbb);
+        return 0;
+    }
+    return out_len;
+}
+
+// Parses a DER ECPrivateKey on the P-256 curve from the len bytes at buf.
+// Returns nullptr on failure, or if buf holds any bytes beyond the encoded key.
+// A non-null result is a newly allocated EC_KEY that the caller must free.
+EC_KEY* ECKEYParsePrivateKey(const uint8_t* buf, size_t len) {
+    CBS cbs;
+    CBS_init(&cbs, buf, len);
     EC_GROUP* group = EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1);
-    auto result = EC_KEY_derive_from_secret(group, secret, secret_len);
+    auto result = EC_KEY_parse_private_key(&cbs, group);
     EC_GROUP_free(group);
+    // Reject trailing data so that only an exact encoding is accepted.
+    if (result != nullptr && CBS_len(&cbs) != 0) {
+        EC_KEY_free(result);
+        return nullptr;
+    }
     return result;
 }
 
diff --git a/keystore2/src/crypto/crypto.hpp b/keystore2/src/crypto/crypto.hpp
index 6686c8c..f841eb3 100644
--- a/keystore2/src/crypto/crypto.hpp
+++ b/keystore2/src/crypto/crypto.hpp
@@ -55,7 +55,9 @@
 
   EC_KEY* ECKEYGenerateKey();
 
-  EC_KEY* ECKEYDeriveFromSecret(const uint8_t *secret, size_t secret_len);
+  size_t ECKEYMarshalPrivateKey(const EC_KEY *priv_key, uint8_t *buf, size_t len);
+
+  EC_KEY* ECKEYParsePrivateKey(const uint8_t *buf, size_t len);
 
   size_t ECPOINTPoint2Oct(const EC_POINT *point, uint8_t *buf, size_t len);
 
diff --git a/keystore2/src/crypto/error.rs b/keystore2/src/crypto/error.rs
index 1eec321..a369012 100644
--- a/keystore2/src/crypto/error.rs
+++ b/keystore2/src/crypto/error.rs
@@ -74,9 +74,13 @@
     #[error("Failed to generate key.")]
     ECKEYGenerateKeyFailed,
 
-    /// This is returned if the C implementation of ECKEYDeriveFromSecret returned null.
-    #[error("Failed to derive key.")]
-    ECKEYDeriveFailed,
+    /// This is returned if the C implementation of ECKEYMarshalPrivateKey returned 0.
+    #[error("Failed to marshal private key.")]
+    ECKEYMarshalPrivateKeyFailed,
+
+    /// This is returned if the C implementation of ECKEYParsePrivateKey returned null.
+    #[error("Failed to parse private key.")]
+    ECKEYParsePrivateKeyFailed,
 
     /// This is returned if the C implementation of ECPOINTPoint2Oct returned 0.
     #[error("Failed to convert point to oct.")]
diff --git a/keystore2/src/crypto/lib.rs b/keystore2/src/crypto/lib.rs
index 98e6eef..3523a9d 100644
--- a/keystore2/src/crypto/lib.rs
+++ b/keystore2/src/crypto/lib.rs
@@ -20,9 +20,9 @@
 pub use error::Error;
 use keystore2_crypto_bindgen::{
     extractSubjectFromCertificate, generateKeyFromPassword, randomBytes, AES_gcm_decrypt,
-    AES_gcm_encrypt, ECDHComputeKey, ECKEYDeriveFromSecret, ECKEYGenerateKey, ECPOINTOct2Point,
-    ECPOINTPoint2Oct, EC_KEY_free, EC_KEY_get0_public_key, EC_POINT_free, HKDFExpand, HKDFExtract,
-    EC_KEY, EC_MAX_BYTES, EC_POINT, EVP_MAX_MD_SIZE,
+    AES_gcm_encrypt, ECDHComputeKey, ECKEYGenerateKey, ECKEYMarshalPrivateKey,
+    ECKEYParsePrivateKey, ECPOINTOct2Point, ECPOINTPoint2Oct, EC_KEY_free, EC_KEY_get0_public_key,
+    EC_POINT_free, HKDFExpand, HKDFExtract, EC_KEY, EC_MAX_BYTES, EC_POINT, EVP_MAX_MD_SIZE,
 };
 use std::convert::TryFrom;
 use std::convert::TryInto;
@@ -338,14 +338,32 @@
     Ok(ECKey(key))
 }
 
-/// Calls the boringssl EC_KEY_derive_from_secret function.
-pub fn ec_key_derive_from_secret(secret: &[u8]) -> Result<ECKey, Error> {
-    // Safety: secret is a valid buffer.
-    let result = unsafe { ECKEYDeriveFromSecret(secret.as_ptr(), secret.len()) };
-    if result.is_null() {
-        return Err(Error::ECKEYDeriveFailed);
+/// Serializes `key` via boringssl EC_KEY_marshal_private_key (no curve params, no pubkey).
+pub fn ec_key_marshal_private_key(key: &ECKey) -> Result<ZVec, Error> {
+    let len = 39; // DER ECPrivateKey of a P-256 key: 7 framing bytes + 32-byte scalar
+    let mut buf = ZVec::new(len)?;
+    // Safety: the key is valid and buf is a writable buffer of buf.len() bytes.
+    // ECKEYMarshalPrivateKey never writes past that length; if the len above
+    // is too short, it fails and returns 0.
+    let written_len =
+        unsafe { ECKEYMarshalPrivateKey(key.0, buf.as_mut_ptr(), buf.len()) } as usize;
+    if written_len == len {
+        Ok(buf)
+    } else {
+        Err(Error::ECKEYMarshalPrivateKeyFailed)
     }
-    Ok(ECKey(result))
+}
+
+/// Parses a key serialized by `ec_key_marshal_private_key` (boringssl EC_KEY_parse_private_key).
+pub fn ec_key_parse_private_key(buf: &[u8]) -> Result<ECKey, Error> {
+    // Safety: buf is a valid slice; ECKEYParsePrivateKey reads at most buf.len()
+    // bytes and returns null unless the whole buffer is consumed by one key.
+    let key = unsafe { ECKEYParsePrivateKey(buf.as_ptr(), buf.len()) };
+    if key.is_null() {
+        Err(Error::ECKEYParsePrivateKeyFailed)
+    } else {
+        Ok(ECKey(key))
+    }
 }
 
 /// Calls the boringssl EC_KEY_get0_public_key function.
@@ -519,26 +537,33 @@
     }
 
     #[test]
-    fn test_ec() {
-        let key = ec_key_generate_key();
-        assert!(key.is_ok());
-        assert!(!key.unwrap().0.is_null());
+    /// Two parties' keys, round-tripped through their serialized forms, must
+    /// agree on the same ECDH secret.
+    fn test_ec() -> Result<(), Error> {
+        let priv0 = ec_key_generate_key()?;
+        assert!(!priv0.0.is_null());
+        let pub0 = ec_key_get0_public_key(&priv0);
 
-        let key = ec_key_derive_from_secret(&[42; 16]);
-        assert!(key.is_ok());
-        let key = key.unwrap();
-        assert!(!key.0.is_null());
+        let priv1 = ec_key_generate_key()?;
+        let pub1 = ec_key_get0_public_key(&priv1);
 
-        let point = ec_key_get0_public_key(&key);
+        let priv0s = ec_key_marshal_private_key(&priv0)?;
+        let pub0s = ec_point_point_to_oct(pub0.get_point())?;
+        let pub1s = ec_point_point_to_oct(pub1.get_point())?;
 
-        let result = ecdh_compute_key(point.get_point(), &key);
-        assert!(result.is_ok());
+        let priv0 = ec_key_parse_private_key(&priv0s)?;
+        let pub0 = ec_point_oct_to_point(&pub0s)?;
+        let pub1 = ec_point_oct_to_point(&pub1s)?;
 
-        let oct = ec_point_point_to_oct(point.get_point());
-        assert!(oct.is_ok());
-        let oct = oct.unwrap();
+        let left_key = ecdh_compute_key(pub0.get_point(), &priv1)?;
+        let right_key = ecdh_compute_key(pub1.get_point(), &priv0)?;
 
-        let point2 = ec_point_oct_to_point(oct.as_slice());
-        assert!(point2.is_ok());
+        assert_eq!(left_key, right_key);
+
+        // The parser must reject input with bytes left over after the key.
+        let mut padded = priv0s.to_vec();
+        padded.push(0);
+        assert!(ec_key_parse_private_key(&padded).is_err());
+        Ok(())
     }
 }
diff --git a/keystore2/src/ec_crypto.rs b/keystore2/src/ec_crypto.rs
new file mode 100644
index 0000000..0425d4a
--- /dev/null
+++ b/keystore2/src/ec_crypto.rs
@@ -0,0 +1,169 @@
+// Copyright 2021, The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! ECDH-based public-key encryption using HKDF key derivation and AES-256-GCM.
+
+use anyhow::{Context, Result};
+use keystore2_crypto::{
+    aes_gcm_decrypt, aes_gcm_encrypt, ec_key_generate_key, ec_key_get0_public_key,
+    ec_key_marshal_private_key, ec_key_parse_private_key, ec_point_oct_to_point,
+    ec_point_point_to_oct, ecdh_compute_key, generate_salt, hkdf_expand, hkdf_extract, ECKey, ZVec,
+    AES_256_KEY_LENGTH,
+};
+
+/// Private key for ECDH encryption.
+pub struct ECDHPrivateKey(ECKey);
+
+impl ECDHPrivateKey {
+    /// Randomly generate a fresh keypair.
+    pub fn generate() -> Result<ECDHPrivateKey> {
+        ec_key_generate_key()
+            .map(ECDHPrivateKey)
+            .context("In ECDHPrivateKey::generate: generation failed")
+    }
+
+    /// Deserialize bytes produced by [`ECDHPrivateKey::private_key`] into an ECDH keypair.
+    pub fn from_private_key(buf: &[u8]) -> Result<ECDHPrivateKey> {
+        ec_key_parse_private_key(buf)
+            .map(ECDHPrivateKey)
+            .context("In ECDHPrivateKey::from_private_key: parsing failed")
+    }
+
+    /// Serialize the private key into bytes that [`ECDHPrivateKey::from_private_key`] accepts.
+    pub fn private_key(&self) -> Result<ZVec> {
+        ec_key_marshal_private_key(&self.0)
+            .context("In ECDHPrivateKey::private_key: marshalling failed")
+    }
+
+    /// Serialize the corresponding public key (an encoded EC point) into bytes.
+    pub fn public_key(&self) -> Result<Vec<u8>> {
+        let point = ec_key_get0_public_key(&self.0);
+        ec_point_point_to_oct(point.get_point())
+            .context("In ECDHPrivateKey::public_key: marshalling failed")
+    }
+
+    /// Use ECDH to agree an AES-256 key with another party whose public key we have.
+    ///
+    /// The key is derived by chaining HKDF-Extract over `salt`, `sender_public_key`,
+    /// `recipient_public_key` and the ECDH shared secret, then running HKDF-Expand with
+    /// the info label "AES-256-GCM key". Sender and recipient public keys are passed
+    /// separately from `other_public_key` because our own key plays the sender role
+    /// when encrypting and the recipient role when decrypting; both sides must feed
+    /// the same values to HKDF in the same order to derive the same key.
+    fn agree_key(
+        &self,
+        salt: &[u8],
+        other_public_key: &[u8],
+        sender_public_key: &[u8],
+        recipient_public_key: &[u8],
+    ) -> Result<ZVec> {
+        let hkdf = hkdf_extract(sender_public_key, salt)
+            .context("In ECDHPrivateKey::agree_key: hkdf_extract on sender_public_key failed")?;
+        let hkdf = hkdf_extract(recipient_public_key, &hkdf)
+            .context("In ECDHPrivateKey::agree_key: hkdf_extract on recipient_public_key failed")?;
+        let other_public_key = ec_point_oct_to_point(other_public_key)
+            .context("In ECDHPrivateKey::agree_key: ec_point_oct_to_point failed")?;
+        let secret = ecdh_compute_key(other_public_key.get_point(), &self.0)
+            .context("In ECDHPrivateKey::agree_key: ecdh_compute_key failed")?;
+        let prk = hkdf_extract(&secret, &hkdf)
+            .context("In ECDHPrivateKey::agree_key: hkdf_extract on secret failed")?;
+
+        let aes_key = hkdf_expand(AES_256_KEY_LENGTH, &prk, b"AES-256-GCM key")
+            .context("In ECDHPrivateKey::agree_key: hkdf_expand failed")?;
+        Ok(aes_key)
+    }
+
+    /// Encrypt `message` to the party holding the private key for `recipient_public_key`.
+    ///
+    /// A fresh ephemeral sender keypair and a random salt are generated for every call.
+    /// Returns `(sender_public_key, salt, iv, ciphertext, tag)`; the recipient needs all
+    /// five values to call [`ECDHPrivateKey::decrypt_message`].
+    pub fn encrypt_message(
+        recipient_public_key: &[u8],
+        message: &[u8],
+    ) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>)> {
+        let sender_key =
+            Self::generate().context("In ECDHPrivateKey::encrypt_message: generate failed")?;
+        let sender_public_key = sender_key
+            .public_key()
+            .context("In ECDHPrivateKey::encrypt_message: public_key failed")?;
+        let salt =
+            generate_salt().context("In ECDHPrivateKey::encrypt_message: generate_salt failed")?;
+        let aes_key = sender_key
+            .agree_key(&salt, recipient_public_key, &sender_public_key, recipient_public_key)
+            .context("In ECDHPrivateKey::encrypt_message: agree_key failed")?;
+        let (ciphertext, iv, tag) = aes_gcm_encrypt(message, &aes_key)
+            .context("In ECDHPrivateKey::encrypt_message: aes_gcm_encrypt failed")?;
+        Ok((sender_public_key, salt, iv, ciphertext, tag))
+    }
+
+    /// Decrypt a message sent to us, given the values returned by
+    /// [`ECDHPrivateKey::encrypt_message`].
+    pub fn decrypt_message(
+        &self,
+        sender_public_key: &[u8],
+        salt: &[u8],
+        iv: &[u8],
+        ciphertext: &[u8],
+        tag: &[u8],
+    ) -> Result<ZVec> {
+        let recipient_public_key = self
+            .public_key()
+            .context("In ECDHPrivateKey::decrypt_message: public_key failed")?;
+        let aes_key = self
+            .agree_key(salt, sender_public_key, sender_public_key, &recipient_public_key)
+            .context("In ECDHPrivateKey::decrypt_message: agree_key failed")?;
+        aes_gcm_decrypt(ciphertext, iv, tag, &aes_key)
+            .context("In ECDHPrivateKey::decrypt_message: aes_gcm_decrypt failed")
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_crypto_roundtrip() -> Result<()> {
+        let message = b"Hello world";
+        let recipient = ECDHPrivateKey::generate()?;
+        let (sender_public_key, salt, iv, ciphertext, tag) =
+            ECDHPrivateKey::encrypt_message(&recipient.public_key()?, message)?;
+        let recipient = ECDHPrivateKey::from_private_key(&recipient.private_key()?)?;
+        let decrypted =
+            recipient.decrypt_message(&sender_public_key, &salt, &iv, &ciphertext, &tag)?;
+        let dc: &[u8] = &decrypted;
+        assert_eq!(message, dc);
+        Ok(())
+    }
+
+    #[test]
+    fn test_decrypt_failures() -> Result<()> {
+        let recipient = ECDHPrivateKey::generate()?;
+        let (sender_public_key, salt, iv, mut ciphertext, tag) =
+            ECDHPrivateKey::encrypt_message(&recipient.public_key()?, b"Hello world")?;
+
+        // A different recipient key derives a different AES key, so authentication fails.
+        let other = ECDHPrivateKey::generate()?;
+        assert!(other
+            .decrypt_message(&sender_public_key, &salt, &iv, &ciphertext, &tag)
+            .is_err());
+
+        // Flipping a single ciphertext bit must be detected by the GCM tag.
+        ciphertext[0] ^= 1;
+        assert!(recipient
+            .decrypt_message(&sender_public_key, &salt, &iv, &ciphertext, &tag)
+            .is_err());
+        Ok(())
+    }
+}
diff --git a/keystore2/src/lib.rs b/keystore2/src/lib.rs
index cb47e3e..b6df8e8 100644
--- a/keystore2/src/lib.rs
+++ b/keystore2/src/lib.rs
@@ -19,6 +19,7 @@
 pub mod async_task;
 pub mod authorization;
 pub mod database;
+pub mod ec_crypto;
 pub mod enforcements;
 pub mod entropy;
 pub mod error;