Merge "Switch to RsaPublicKey."
diff --git a/ondevice-signing/CertUtils.cpp b/ondevice-signing/CertUtils.cpp
index acc11e4..d2f19ce 100644
--- a/ondevice-signing/CertUtils.cpp
+++ b/ondevice-signing/CertUtils.cpp
@@ -299,6 +299,37 @@
     return extractPublicKey(X509_get_pubkey(cert.value().get()));
 }
 
+// Serializes the RSA key held by |pkey| as a DER-encoded PKCS#1 RSAPublicKey.
+// Not declared in CertUtils.h, so give it internal linkage to keep it file-local.
+static Result<std::vector<uint8_t>> extractRsaPublicKey(EVP_PKEY* pkey) {
+    RSA* rsa = EVP_PKEY_get0_RSA(pkey);
+    if (rsa == nullptr) {
+        return Error() << "The public key is not an RSA key";
+    }
+    uint8_t* out = nullptr;
+    // On success i2d_RSAPublicKey allocates |out| and returns the encoded length.
+    int size = i2d_RSAPublicKey(rsa, &out);
+    if (size < 0 || !out) {
+        return Error() << "Failed to convert to RSAPublicKey";
+    }
+    bssl::UniquePtr<uint8_t> buffer(out);  // Frees |out| when this function returns.
+    return std::vector<uint8_t>(out, out + size);
+}
+
+// Returns the public key of DER X.509 certificate |derCert| as a DER PKCS#1 RSAPublicKey.
+Result<std::vector<uint8_t>> extractRsaPublicKeyFromX509(const std::vector<uint8_t>& derCert) {
+    const uint8_t* cursor = derCert.data();
+    bssl::UniquePtr<X509> cert(d2i_X509(nullptr, &cursor, derCert.size()));
+    if (!cert) {
+        return Error() << "Failed to decode X509 certificate.";
+    }
+    bssl::UniquePtr<EVP_PKEY> pkey(X509_get_pubkey(cert.get()));  // Takes a new reference.
+    if (!pkey) {
+        return Error() << "Failed to extract public key from x509 cert";
+    }
+    return extractRsaPublicKey(pkey.get());
+}
+
 Result<CertInfo> verifyAndExtractCertInfoFromX509(const std::string& path,
                                                   const std::vector<uint8_t>& publicKey) {
     auto public_key = toRsaPkey(publicKey);
diff --git a/ondevice-signing/CertUtils.h b/ondevice-signing/CertUtils.h
index 1fa5bbc..fd6080d 100644
--- a/ondevice-signing/CertUtils.h
+++ b/ondevice-signing/CertUtils.h
@@ -60,6 +60,11 @@
 extractPublicKeyFromSubjectPublicKeyInfo(const std::vector<uint8_t>& subjectKeyInfo);
 android::base::Result<std::vector<uint8_t>> extractPublicKeyFromX509(const std::string& path);
 
+// Returns the public key of the DER-encoded X.509 certificate |derCert| as a DER-encoded
+// PKCS#1 RSAPublicKey. Fails if the certificate cannot be parsed or its key is not RSA.
+android::base::Result<std::vector<uint8_t>>
+extractRsaPublicKeyFromX509(const std::vector<uint8_t>& derCert);
+
 android::base::Result<CertInfo>
 verifyAndExtractCertInfoFromX509(const std::string& path, const std::vector<uint8_t>& publicKey);
 
diff --git a/ondevice-signing/FakeCompOs.cpp b/ondevice-signing/FakeCompOs.cpp
index 48eb01a..637fe5c 100644
--- a/ondevice-signing/FakeCompOs.cpp
+++ b/ondevice-signing/FakeCompOs.cpp
@@ -16,7 +16,6 @@
 
 #include "FakeCompOs.h"
 
-#include "CertUtils.h"
 #include "KeyConstants.h"
 
 #include <android-base/file.h>
@@ -26,7 +25,10 @@
 
 #include <binder/IServiceManager.h>
 
+#include <openssl/nid.h>
 #include <openssl/rand.h>
+#include <openssl/rsa.h>
+#include <openssl/sha.h>
 
 using android::String16;
 
@@ -210,6 +212,28 @@
     return signature.value();
 }
 
+// Verifies that |signature| is an RSA PKCS#1 v1.5 signature over SHA-256(|message|) made by
+// |rsaPublicKey|, a DER PKCS#1 RSAPublicKey that must not carry trailing bytes.
+Result<void> FakeCompOs::verifySignature(const ByteVector& message, const ByteVector& signature,
+                                         const ByteVector& rsaPublicKey) const {
+    const uint8_t* cursor = rsaPublicKey.data();
+    const uint8_t* const keyEnd = rsaPublicKey.data() + rsaPublicKey.size();
+    bssl::UniquePtr<RSA> rsa(d2i_RSAPublicKey(nullptr, &cursor, rsaPublicKey.size()));
+    if (!rsa) {
+        return Error() << "Failed to parse RsaPublicKey";
+    }
+    if (cursor != keyEnd) {
+        return Error() << "Key has unexpected trailing data";
+    }
+    uint8_t digest[SHA256_DIGEST_LENGTH];
+    SHA256(message.data(), message.size(), digest);
+    if (!RSA_verify(NID_sha256, digest, sizeof(digest), signature.data(), signature.size(),
+                    rsa.get())) {
+        return Error() << "Failed to verify signature";
+    }
+    return {};
+}
+
 Result<void> FakeCompOs::loadAndVerifyKey(const ByteVector& keyBlob,
                                           const ByteVector& publicKey) const {
     // To verify the key is valid, we use it to sign some data, and then verify the signature using
@@ -225,8 +249,5 @@
         return signature.error();
     }
 
-    std::string dataStr(data.begin(), data.end());
-    std::string signatureStr(signature.value().begin(), signature.value().end());
-
-    return verifySignature(dataStr, signatureStr, publicKey);
+    return verifySignature(data, signature.value(), publicKey);
 }
diff --git a/ondevice-signing/FakeCompOs.h b/ondevice-signing/FakeCompOs.h
index 7d76938..6c12c60 100644
--- a/ondevice-signing/FakeCompOs.h
+++ b/ondevice-signing/FakeCompOs.h
@@ -53,6 +53,9 @@
 
     android::base::Result<ByteVector> signData(const ByteVector& keyBlob,
                                                const ByteVector& data) const;
+    android::base::Result<void> verifySignature(const ByteVector& message,
+                                                const ByteVector& signature,
+                                                const ByteVector& rsaPublicKey) const;
 
     KeyDescriptor mDescriptor;
     android::sp<IKeystoreService> mService;
diff --git a/ondevice-signing/odsign_main.cpp b/ondevice-signing/odsign_main.cpp
index 55d8b1c..b14a91e 100644
--- a/ondevice-signing/odsign_main.cpp
+++ b/ondevice-signing/odsign_main.cpp
@@ -221,7 +221,7 @@
         if (!keyData.ok()) {
             return Error() << "Failed to generate key: " << keyData.error();
         }
-        auto publicKeyStatus = extractPublicKeyFromX509(keyData.value().cert);
+        auto publicKeyStatus = extractRsaPublicKeyFromX509(keyData.value().cert);
         if (!publicKeyStatus.ok()) {
             return Error() << "Failed to extract CompOs public key" << publicKeyStatus.error();
         }