Merge "Keystore2: fix test_large_number_of_concurrent_db_manipulations"
diff --git a/fsverity_init/fsverity_init.cpp b/fsverity_init/fsverity_init.cpp
index 7ab4097..7bc6022 100644
--- a/fsverity_init/fsverity_init.cpp
+++ b/fsverity_init/fsverity_init.cpp
@@ -37,15 +37,17 @@
     return true;
 }
 
-void LoadKeyFromStdin(key_serial_t keyring_id, const char* keyname) {
+bool LoadKeyFromStdin(key_serial_t keyring_id, const char* keyname) {
     std::string content;
     if (!android::base::ReadFdToString(STDIN_FILENO, &content)) {
         LOG(ERROR) << "Failed to read key from stdin";
-        return;
+        return false;
     }
     if (!LoadKeyToKeyring(keyring_id, keyname, content.c_str(), content.size())) {
         LOG(ERROR) << "Failed to load key from stdin";
+        return false;
     }
+    return true;
 }
 
 void LoadKeyFromFile(key_serial_t keyring_id, const char* keyname, const std::string& path) {
@@ -101,7 +103,9 @@
             LOG(ERROR) << "--load-extra-key requires <key_name> argument.";
             return -1;
         }
-        LoadKeyFromStdin(keyring_id, argv[2]);
+        if (!LoadKeyFromStdin(keyring_id, argv[2])) {
+            return -1;
+        }
     } else if (command == "--lock") {
         // Requires files backed by fs-verity to be verified with a key in .fs-verity
         // keyring.
diff --git a/keystore/Android.bp b/keystore/Android.bp
index 0f2000c..b59bacc 100644
--- a/keystore/Android.bp
+++ b/keystore/Android.bp
@@ -87,7 +87,7 @@
 }
 
 // Library for keystore clients using the WiFi HIDL interface
-cc_library_shared {
+cc_library {
     name: "libkeystore-wifi-hidl",
     defaults: ["keystore_defaults"],
 
@@ -102,5 +102,5 @@
 
     export_include_dirs: ["include"],
 
-    vendor: true,
+    vendor_available: true,
 }
diff --git a/keystore2/src/km_compat/km_compat.cpp b/keystore2/src/km_compat/km_compat.cpp
index f6f8bfe..5f65eaa 100644
--- a/keystore2/src/km_compat/km_compat.cpp
+++ b/keystore2/src/km_compat/km_compat.cpp
@@ -1311,7 +1311,7 @@
     CHECK(serviceManager.get()) << "Failed to get ServiceManager";
     auto result = enumerateKeymasterDevices<Keymaster4>(serviceManager.get());
     auto softKeymaster = result[SecurityLevel::SOFTWARE];
-    if (!result[SecurityLevel::TRUSTED_ENVIRONMENT]) {
+    if ((!result[SecurityLevel::TRUSTED_ENVIRONMENT]) && (!result[SecurityLevel::STRONGBOX])) {
         result = enumerateKeymasterDevices<Keymaster3>(serviceManager.get());
     }
     if (softKeymaster) result[SecurityLevel::SOFTWARE] = softKeymaster;
diff --git a/ondevice-signing/Android.bp b/ondevice-signing/Android.bp
index 432e585..9085d81 100644
--- a/ondevice-signing/Android.bp
+++ b/ondevice-signing/Android.bp
@@ -84,6 +84,7 @@
   srcs: [
     "odsign_main.cpp",
     "CertUtils.cpp",
+    "FakeCompOs.cpp",
     "KeystoreKey.cpp",
     "KeystoreHmacKey.cpp",
     "VerityUtils.cpp",
diff --git a/ondevice-signing/CertUtils.cpp b/ondevice-signing/CertUtils.cpp
index 9867f62..ce2b0fd 100644
--- a/ondevice-signing/CertUtils.cpp
+++ b/ondevice-signing/CertUtils.cpp
@@ -26,39 +26,55 @@
 #include <openssl/x509.h>
 #include <openssl/x509v3.h>
 
-#include <fcntl.h>
+#include <optional>
 #include <vector>
 
 #include "KeyConstants.h"
 
-const char kBasicConstraints[] = "CA:TRUE";
-const char kKeyUsage[] = "critical,keyCertSign,cRLSign,digitalSignature";
-const char kSubjectKeyIdentifier[] = "hash";
-const char kAuthorityKeyIdentifier[] = "keyid:always";
+const char kRootCommonName[] = "ODS";
 constexpr int kCertLifetimeSeconds = 10 * 365 * 24 * 60 * 60;
 
-using android::base::Result;
-// using android::base::ErrnoError;
+using android::base::ErrnoError;
 using android::base::Error;
+using android::base::Result;
 
-static bool add_ext(X509* cert, int nid, const char* value) {
-    size_t len = strlen(value) + 1;
-    std::vector<char> mutableValue(value, value + len);
-    X509V3_CTX context;
+static Result<bssl::UniquePtr<X509>> loadX509(const std::string& path) {
+    X509* rawCert = nullptr;  // d2i_X509_fp() reuses/frees a non-null *out, so start null.
+    auto f = fopen(path.c_str(), "re");
+    if (f == nullptr) {
+        return ErrnoError() << "Failed to open " << path;
+    }
+    if (!d2i_X509_fp(f, &rawCert)) {
+        fclose(f);
+        return Error() << "Unable to decode x509 cert at " << path;
+    }
+    bssl::UniquePtr<X509> cert(rawCert);

-    X509V3_set_ctx_nodb(&context);
+    fclose(f);
+    return cert;
+}
 
-    X509V3_set_ctx(&context, cert, cert, nullptr, nullptr, 0);
-    X509_EXTENSION* ex = X509V3_EXT_nconf_nid(nullptr, &context, nid, mutableValue.data());
+static X509V3_CTX makeContext(X509* issuer, X509* subject) {
+    X509V3_CTX context = {};
+    X509V3_set_ctx(&context, issuer, subject, nullptr, nullptr, 0);
+    return context;
+}
+
+static bool add_ext(X509V3_CTX* context, X509* cert, int nid, const char* value) {
+    bssl::UniquePtr<X509_EXTENSION> ex(X509V3_EXT_nconf_nid(nullptr, context, nid, value));
     if (!ex) {
         return false;
     }
 
-    X509_add_ext(cert, ex, -1);
-    X509_EXTENSION_free(ex);
+    X509_add_ext(cert, ex.get(), -1);
     return true;
 }
 
+static void addNameEntry(X509_NAME* name, const char* field, const char* value) {
+    X509_NAME_add_entry_by_txt(name, field, MBSTRING_ASC,
+                               reinterpret_cast<const unsigned char*>(value), -1, -1, 0);
+}
+
 Result<bssl::UniquePtr<RSA>> getRsa(const std::vector<uint8_t>& publicKey) {
     bssl::UniquePtr<BIGNUM> n(BN_new());
     bssl::UniquePtr<BIGNUM> e(BN_new());
@@ -109,19 +125,31 @@
     return public_key;
 }
 
-Result<void> createSelfSignedCertificate(
-    const std::vector<uint8_t>& publicKey,
-    const std::function<Result<std::string>(const std::string&)>& signFunction,
-    const std::string& path) {
+static Result<void> createCertificate(
+    const char* commonName, const std::vector<uint8_t>& publicKey,
+    const std::function<android::base::Result<std::string>(const std::string&)>& signFunction,
+    const std::optional<std::string>& issuerCertPath, const std::string& path) {
+
+    // If an issuer cert is specified, we are signing someone else's key.
+    // Otherwise we are signing our key - a self-signed certificate.
+    bool selfSigned = !issuerCertPath;
+
     bssl::UniquePtr<X509> x509(X509_new());
     if (!x509) {
         return Error() << "Unable to allocate x509 container";
     }
     X509_set_version(x509.get(), 2);
-
-    ASN1_INTEGER_set(X509_get_serialNumber(x509.get()), 1);
     X509_gmtime_adj(X509_get_notBefore(x509.get()), 0);
     X509_gmtime_adj(X509_get_notAfter(x509.get()), kCertLifetimeSeconds);
+    ASN1_INTEGER_set(X509_get_serialNumber(x509.get()), selfSigned ? 1 : 2);
+
+    bssl::UniquePtr<X509_ALGOR> algor(X509_ALGOR_new());
+    if (!algor ||
+        !X509_ALGOR_set0(algor.get(), OBJ_nid2obj(NID_sha256WithRSAEncryption), V_ASN1_NULL,
+                         NULL) ||
+        !X509_set1_signature_algo(x509.get(), algor.get())) {
+        return Error() << "Unable to set x509 signature algorithm";
+    }
 
     auto public_key = toRsaPkey(publicKey);
     if (!public_key.ok()) {
@@ -132,33 +160,53 @@
         return Error() << "Unable to set x509 public key";
     }
 
-    X509_NAME* name = X509_get_subject_name(x509.get());
-    if (!name) {
+    X509_NAME* subjectName = X509_get_subject_name(x509.get());
+    if (!subjectName) {
         return Error() << "Unable to get x509 subject name";
     }
-    X509_NAME_add_entry_by_txt(name, "C", MBSTRING_ASC,
-                               reinterpret_cast<const unsigned char*>("US"), -1, -1, 0);
-    X509_NAME_add_entry_by_txt(name, "O", MBSTRING_ASC,
-                               reinterpret_cast<const unsigned char*>("Android"), -1, -1, 0);
-    X509_NAME_add_entry_by_txt(name, "CN", MBSTRING_ASC,
-                               reinterpret_cast<const unsigned char*>("ODS"), -1, -1, 0);
-    if (!X509_set_issuer_name(x509.get(), name)) {
-        return Error() << "Unable to set x509 issuer name";
+    addNameEntry(subjectName, "C", "US");
+    addNameEntry(subjectName, "O", "Android");
+    addNameEntry(subjectName, "CN", commonName);
+
+    if (selfSigned) {
+        if (!X509_set_issuer_name(x509.get(), subjectName)) {
+            return Error() << "Unable to set x509 issuer name";
+        }
+    } else {
+        X509_NAME* issuerName = X509_get_issuer_name(x509.get());
+        if (!issuerName) {
+            return Error() << "Unable to get x509 issuer name";
+        }
+        addNameEntry(issuerName, "C", "US");
+        addNameEntry(issuerName, "O", "Android");
+        addNameEntry(issuerName, "CN", kRootCommonName);
     }
 
-    add_ext(x509.get(), NID_basic_constraints, kBasicConstraints);
-    add_ext(x509.get(), NID_key_usage, kKeyUsage);
-    add_ext(x509.get(), NID_subject_key_identifier, kSubjectKeyIdentifier);
-    add_ext(x509.get(), NID_authority_key_identifier, kAuthorityKeyIdentifier);
+    // Beware: context contains a pointer to issuerCert, so we need to keep it alive.
+    bssl::UniquePtr<X509> issuerCert;
+    X509V3_CTX context;
 
-    bssl::UniquePtr<X509_ALGOR> algor(X509_ALGOR_new());
-    if (!algor ||
-        !X509_ALGOR_set0(algor.get(), OBJ_nid2obj(NID_sha256WithRSAEncryption), V_ASN1_NULL,
-                         NULL) ||
-        !X509_set1_signature_algo(x509.get(), algor.get())) {
-        return Error() << "Unable to set x509 signature algorithm";
+    if (selfSigned) {
+        context = makeContext(x509.get(), x509.get());
+    } else {
+        auto certStatus = loadX509(*issuerCertPath);
+        if (!certStatus.ok()) {
+            return Error() << "Unable to load issuer cert: " << certStatus.error();
+        }
+        issuerCert = std::move(certStatus.value());
+        context = makeContext(issuerCert.get(), x509.get());
     }
 
+    // If it's a self-signed cert we use it for signing certs, otherwise only for signing data.
+    const char* basicConstraints = selfSigned ? "CA:TRUE" : "CA:FALSE";
+    const char* keyUsage =
+        selfSigned ? "critical,keyCertSign,cRLSign,digitalSignature" : "critical,digitalSignature";
+
+    add_ext(&context, x509.get(), NID_basic_constraints, basicConstraints);
+    add_ext(&context, x509.get(), NID_key_usage, keyUsage);
+    add_ext(&context, x509.get(), NID_subject_key_identifier, "hash");
+    add_ext(&context, x509.get(), NID_authority_key_identifier, "keyid:always");
+
     // Get the data to be signed
     unsigned char* to_be_signed_buf(nullptr);
     size_t to_be_signed_length = i2d_re_X509_tbs(x509.get(), &to_be_signed_buf);
@@ -177,14 +225,30 @@
 
     auto f = fopen(path.c_str(), "wbe");
     if (f == nullptr) {
-        return Error() << "Failed to open " << path;
+        return ErrnoError() << "Failed to open " << path;
     }
     i2d_X509_fp(f, x509.get());
-    fclose(f);
+    if (fclose(f) != 0) {
+        return ErrnoError() << "Failed to close " << path;
+    }
 
     return {};
 }
 
+Result<void> createSelfSignedCertificate(
+    const std::vector<uint8_t>& publicKey,
+    const std::function<Result<std::string>(const std::string&)>& signFunction,
+    const std::string& path) {
+    return createCertificate(kRootCommonName, publicKey, signFunction, {}, path);
+}
+
+android::base::Result<void> createLeafCertificate(
+    const char* commonName, const std::vector<uint8_t>& publicKey,
+    const std::function<android::base::Result<std::string>(const std::string&)>& signFunction,
+    const std::string& issuerCertPath, const std::string& path) {
+    return createCertificate(commonName, publicKey, signFunction, issuerCertPath, path);
+}
+
 Result<std::vector<uint8_t>> extractPublicKey(EVP_PKEY* pkey) {
     if (pkey == nullptr) {
         return Error() << "Failed to extract public key from x509 cert";
@@ -225,22 +289,6 @@
     return extractPublicKey(decoded_pkey.get());
 }
 
-static Result<bssl::UniquePtr<X509>> loadX509(const std::string& path) {
-    X509* rawCert;
-    auto f = fopen(path.c_str(), "re");
-    if (f == nullptr) {
-        return Error() << "Failed to open " << path;
-    }
-    if (!d2i_X509_fp(f, &rawCert)) {
-        fclose(f);
-        return Error() << "Unable to decode x509 cert at " << path;
-    }
-    bssl::UniquePtr<X509> cert(rawCert);
-
-    fclose(f);
-    return cert;
-}
-
 Result<std::vector<uint8_t>> extractPublicKeyFromX509(const std::string& path) {
     auto cert = loadX509(path);
     if (!cert.ok()) {
diff --git a/ondevice-signing/CertUtils.h b/ondevice-signing/CertUtils.h
index d202fbc..b412d21 100644
--- a/ondevice-signing/CertUtils.h
+++ b/ondevice-signing/CertUtils.h
@@ -16,6 +16,10 @@
 
 #pragma once
 
+#include <functional>
+#include <string>
+#include <vector>
+
 #include <android-base/result.h>
 
 struct CertInfo {
@@ -27,6 +31,12 @@
     const std::vector<uint8_t>& publicKey,
     const std::function<android::base::Result<std::string>(const std::string&)>& signFunction,
     const std::string& path);
+
+android::base::Result<void> createLeafCertificate(
+    const char* commonName, const std::vector<uint8_t>& publicKey,
+    const std::function<android::base::Result<std::string>(const std::string&)>& signFunction,
+    const std::string& issuerCertPath, const std::string& outPath);
+
 android::base::Result<std::vector<uint8_t>> createPkcs7(const std::vector<uint8_t>& signedData);
 
 android::base::Result<std::vector<uint8_t>>
diff --git a/ondevice-signing/FakeCompOs.cpp b/ondevice-signing/FakeCompOs.cpp
new file mode 100644
index 0000000..48eb01a
--- /dev/null
+++ b/ondevice-signing/FakeCompOs.cpp
@@ -0,0 +1,232 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FakeCompOs.h"
+
+#include "CertUtils.h"
+#include "KeyConstants.h"
+
+#include <android-base/file.h>
+#include <android-base/logging.h>
+#include <android-base/result.h>
+#include <android-base/scopeguard.h>
+
+#include <binder/IServiceManager.h>
+
+#include <openssl/rand.h>
+
+using android::String16;
+
+using android::hardware::security::keymint::Algorithm;
+using android::hardware::security::keymint::Digest;
+using android::hardware::security::keymint::KeyParameter;
+using android::hardware::security::keymint::KeyParameterValue;
+using android::hardware::security::keymint::KeyPurpose;
+using android::hardware::security::keymint::PaddingMode;
+using android::hardware::security::keymint::SecurityLevel;
+using android::hardware::security::keymint::Tag;
+
+using android::system::keystore2::CreateOperationResponse;
+using android::system::keystore2::Domain;
+
+using android::base::Error;
+using android::base::Result;
+
+using android::binder::Status;
+
+// TODO: Allocate a namespace for CompOS
+const int64_t kCompOsNamespace = 101;
+
+Result<std::unique_ptr<FakeCompOs>> FakeCompOs::newInstance() {
+    std::unique_ptr<FakeCompOs> compOs(new FakeCompOs);
+    auto init = compOs->initialize();
+    if (init.ok()) {
+        return compOs;
+    } else {
+        return init.error();
+    }
+}
+
+FakeCompOs::FakeCompOs() {}
+
+Result<void> FakeCompOs::initialize() {
+    auto sm = android::defaultServiceManager();
+    if (!sm) {
+        return Error() << "No ServiceManager";
+    }
+    auto rawService = sm->getService(String16("android.system.keystore2.IKeystoreService/default"));
+    if (!rawService) {
+        return Error() << "No Keystore service";
+    }
+    mService = interface_cast<android::system::keystore2::IKeystoreService>(rawService);
+    if (!mService) {
+        return Error() << "Bad Keystore service";
+    }
+
+    // TODO: We probably want SecurityLevel::SOFTWARE here, in the VM, but Keystore doesn't do it
+    auto status = mService->getSecurityLevel(SecurityLevel::TRUSTED_ENVIRONMENT, &mSecurityLevel);
+    if (!status.isOk()) {
+        return Error() << status;
+    }
+
+    return {};
+}
+
+Result<FakeCompOs::KeyData> FakeCompOs::generateKey() const {
+    std::vector<KeyParameter> params;
+
+    KeyParameter algo;
+    algo.tag = Tag::ALGORITHM;
+    algo.value = KeyParameterValue::make<KeyParameterValue::algorithm>(Algorithm::RSA);
+    params.push_back(algo);
+
+    KeyParameter key_size;
+    key_size.tag = Tag::KEY_SIZE;
+    key_size.value = KeyParameterValue::make<KeyParameterValue::integer>(kRsaKeySize);
+    params.push_back(key_size);
+
+    KeyParameter digest;
+    digest.tag = Tag::DIGEST;
+    digest.value = KeyParameterValue::make<KeyParameterValue::digest>(Digest::SHA_2_256);
+    params.push_back(digest);
+
+    KeyParameter padding;
+    padding.tag = Tag::PADDING;
+    padding.value =
+        KeyParameterValue::make<KeyParameterValue::paddingMode>(PaddingMode::RSA_PKCS1_1_5_SIGN);
+    params.push_back(padding);
+
+    KeyParameter exponent;
+    exponent.tag = Tag::RSA_PUBLIC_EXPONENT;
+    exponent.value = KeyParameterValue::make<KeyParameterValue::longInteger>(kRsaKeyExponent);
+    params.push_back(exponent);
+
+    KeyParameter purpose;
+    purpose.tag = Tag::PURPOSE;
+    purpose.value = KeyParameterValue::make<KeyParameterValue::keyPurpose>(KeyPurpose::SIGN);
+    params.push_back(purpose);
+
+    KeyParameter auth;
+    auth.tag = Tag::NO_AUTH_REQUIRED;
+    auth.value = KeyParameterValue::make<KeyParameterValue::boolValue>(true);
+    params.push_back(auth);
+
+    KeyDescriptor descriptor;
+    descriptor.domain = Domain::BLOB;  // key material comes back to us as an opaque blob
+    descriptor.nspace = kCompOsNamespace;
+
+    KeyMetadata metadata;
+    auto status = mSecurityLevel->generateKey(descriptor, {}, params, 0, {}, &metadata);
+    if (!status.isOk()) {
+        return Error() << "Failed to generate key: " << status;
+    }
+
+    auto& cert = metadata.certificate;
+    if (!cert) {
+        return Error() << "No certificate.";
+    }
+
+    auto& blob = metadata.key.blob;
+    if (!blob) {
+        return Error() << "No blob.";
+    }
+
+    KeyData key_data{std::move(*cert), std::move(*blob)};  // both checked non-empty above
+    return key_data;
+}
+
+Result<FakeCompOs::ByteVector> FakeCompOs::signData(const ByteVector& keyBlob,
+                                                    const ByteVector& data) const {
+    KeyDescriptor descriptor;
+    descriptor.domain = Domain::BLOB;
+    descriptor.nspace = kCompOsNamespace;
+    descriptor.blob = keyBlob;
+
+    std::vector<KeyParameter> parameters;
+
+    {
+        KeyParameter algo;
+        algo.tag = Tag::ALGORITHM;
+        algo.value = KeyParameterValue::make<KeyParameterValue::algorithm>(Algorithm::RSA);
+        parameters.push_back(algo);
+
+        KeyParameter digest;
+        digest.tag = Tag::DIGEST;
+        digest.value = KeyParameterValue::make<KeyParameterValue::digest>(Digest::SHA_2_256);
+        parameters.push_back(digest);
+
+        KeyParameter padding;
+        padding.tag = Tag::PADDING;
+        padding.value = KeyParameterValue::make<KeyParameterValue::paddingMode>(
+            PaddingMode::RSA_PKCS1_1_5_SIGN);
+        parameters.push_back(padding);
+
+        KeyParameter purpose;
+        purpose.tag = Tag::PURPOSE;
+        purpose.value = KeyParameterValue::make<KeyParameterValue::keyPurpose>(KeyPurpose::SIGN);
+        parameters.push_back(purpose);
+    }
+
+    Status status;
+
+    CreateOperationResponse response;
+    status = mSecurityLevel->createOperation(descriptor, parameters, /*forced=*/false, &response);
+    if (!status.isOk()) {
+        return Error() << "Failed to create operation: " << status;
+    }
+
+    auto operation = response.iOperation;
+    auto abort_guard = android::base::make_scope_guard([&] { operation->abort(); });
+
+    if (response.operationChallenge.has_value()) {
+        return Error() << "Key requires user authorization";
+    }
+
+    std::optional<ByteVector> signature;
+    status = operation->finish(data, {}, &signature);
+    if (!status.isOk()) {
+        return Error() << "Failed to sign data: " << status;
+    }
+
+    abort_guard.Disable();
+
+    if (!signature.has_value()) {
+        return Error() << "No signature received from keystore.";
+    }
+
+    return signature.value();
+}
+
+Result<void> FakeCompOs::loadAndVerifyKey(const ByteVector& keyBlob,
+                                          const ByteVector& publicKey) const {
+    // To verify the key is valid, we use it to sign some data, and then verify the signature using
+    // the supplied public key.
+
+    ByteVector data(32);
+    if (RAND_bytes(data.data(), data.size()) != 1) {
+        return Error() << "No random bytes";
+    }
+
+    auto signature = signData(keyBlob, data);
+    if (!signature.ok()) {
+        return signature.error();
+    }
+
+    std::string dataStr(data.begin(), data.end());
+    std::string signatureStr(signature.value().begin(), signature.value().end());
+
+    return verifySignature(dataStr, signatureStr, publicKey);
+}
diff --git a/ondevice-signing/FakeCompOs.h b/ondevice-signing/FakeCompOs.h
new file mode 100644
index 0000000..7d76938
--- /dev/null
+++ b/ondevice-signing/FakeCompOs.h
@@ -0,0 +1,60 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <android-base/result.h>
+
+#include <utils/StrongPointer.h>
+
+#include <android/system/keystore2/IKeystoreService.h>
+
+class FakeCompOs {
+    using IKeystoreService = ::android::system::keystore2::IKeystoreService;
+    using IKeystoreSecurityLevel = ::android::system::keystore2::IKeystoreSecurityLevel;
+    using KeyDescriptor = ::android::system::keystore2::KeyDescriptor;
+    using KeyMetadata = ::android::system::keystore2::KeyMetadata;
+
+  public:
+    using ByteVector = std::vector<uint8_t>;
+    struct KeyData {
+        ByteVector cert;  // certificate bytes Keystore returned in KeyMetadata
+        ByteVector blob;  // opaque Domain::BLOB key blob, later passed back for signing
+    };
+
+    static android::base::Result<std::unique_ptr<FakeCompOs>> newInstance();
+
+    android::base::Result<KeyData> generateKey() const;  // new RSA signing key + its cert
+
+    android::base::Result<void> loadAndVerifyKey(const ByteVector& keyBlob,
+                                                 const ByteVector& publicKey) const;
+
+  private:
+    FakeCompOs();  // callers go through newInstance(), which also runs initialize()
+
+    android::base::Result<void> initialize();  // connects to the Keystore2 service
+
+    android::base::Result<ByteVector> signData(const ByteVector& keyBlob,
+                                               const ByteVector& data) const;
+
+    KeyDescriptor mDescriptor;  // NOTE(review): unused by FakeCompOs.cpp; candidate for removal
+    android::sp<IKeystoreService> mService;  // set in initialize()
+    android::sp<IKeystoreSecurityLevel> mSecurityLevel;  // TEE level, set in initialize()
+};
diff --git a/ondevice-signing/KeystoreKey.cpp b/ondevice-signing/KeystoreKey.cpp
index 0951d92..4f41d4b 100644
--- a/ondevice-signing/KeystoreKey.cpp
+++ b/ondevice-signing/KeystoreKey.cpp
@@ -297,19 +297,13 @@
 
     auto status = mSecurityLevel->createOperation(mDescriptor, opParameters, false, &opResponse);
     if (!status.isOk()) {
-        return Error() << "Failed to create keystore signing operation: "
-                       << status.serviceSpecificErrorCode();
+        return Error() << "Failed to create keystore signing operation: " << status;
     }
     auto operation = opResponse.iOperation;
 
-    std::optional<std::vector<uint8_t>> out;
-    status = operation->update({message.begin(), message.end()}, &out);
-    if (!status.isOk()) {
-        return Error() << "Failed to call keystore update operation.";
-    }
-
+    std::optional<std::vector<uint8_t>> input{std::in_place, message.begin(), message.end()};
     std::optional<std::vector<uint8_t>> signature;
-    status = operation->finish({}, {}, &signature);
+    status = operation->finish(input, {}, &signature);
     if (!status.isOk()) {
         return Error() << "Failed to call keystore finish operation.";
     }
diff --git a/ondevice-signing/KeystoreKey.h b/ondevice-signing/KeystoreKey.h
index 1257cbb..f2fbb70 100644
--- a/ondevice-signing/KeystoreKey.h
+++ b/ondevice-signing/KeystoreKey.h
@@ -20,7 +20,6 @@
 
 #include <android-base/macros.h>
 #include <android-base/result.h>
-#include <android-base/unique_fd.h>
 
 #include <utils/StrongPointer.h>
 
diff --git a/ondevice-signing/VerityUtils.cpp b/ondevice-signing/VerityUtils.cpp
index 25f949c..56dcd5e 100644
--- a/ondevice-signing/VerityUtils.cpp
+++ b/ondevice-signing/VerityUtils.cpp
@@ -50,13 +50,6 @@
 #define le16_to_cpu(v) (__builtin_bswap16((__force uint16_t)(v)))
 #endif
 
-struct fsverity_signed_digest {
-    char magic[8]; /* must be "FSVerity" */
-    __le16 digest_algorithm;
-    __le16 digest_size;
-    __u8 digest[];
-};
-
 static std::string toHex(std::span<uint8_t> data) {
     std::stringstream ss;
     for (auto it = data.begin(); it != data.end(); ++it) {
@@ -121,7 +114,7 @@
 
 static Result<std::vector<uint8_t>> signDigest(const SigningKey& key,
                                                const std::vector<uint8_t>& digest) {
-    auto d = makeUniqueWithTrailingData<fsverity_signed_digest>(digest.size());
+    auto d = makeUniqueWithTrailingData<fsverity_formatted_digest>(digest.size());
 
     memcpy(d->magic, "FSVerity", 8);
     d->digest_algorithm = cpu_to_le16(FS_VERITY_HASH_ALG_SHA256);
@@ -247,6 +240,9 @@
     const char* const argv[] = {kFsVerityInitPath, "--load-extra-key", keyName};
 
     int fd = open(path.c_str(), O_RDONLY | O_CLOEXEC);
+    if (fd == -1) {
+        return ErrnoError() << "Failed to open " << path;
+    }
     pid_t pid = fork();
     if (pid == 0) {
         dup2(fd, STDIN_FILENO);
@@ -271,10 +267,8 @@
     if (!WIFEXITED(status)) {
         return Error() << kFsVerityInitPath << ": abnormal process exit";
     }
-    if (WEXITSTATUS(status)) {
-        if (status != 0) {
-            return Error() << kFsVerityInitPath << " exited with " << status;
-        }
+    if (WEXITSTATUS(status) != 0) {
+        return Error() << kFsVerityInitPath << " exited with " << WEXITSTATUS(status);
     }
 
     return {};
diff --git a/ondevice-signing/odsign_main.cpp b/ondevice-signing/odsign_main.cpp
index 135c4a0..5fad7fc 100644
--- a/ondevice-signing/odsign_main.cpp
+++ b/ondevice-signing/odsign_main.cpp
@@ -32,6 +32,7 @@
 #include <odrefresh/odrefresh.h>
 
 #include "CertUtils.h"
+#include "FakeCompOs.h"
 #include "KeystoreKey.h"
 #include "VerityUtils.h"
 
@@ -58,7 +59,10 @@
 static const bool kUseCompOs = false;  // STOPSHIP if true
 
 static const char* kVirtApexPath = "/apex/com.android.virt";
+static const char* kCompOsCommonName = "CompOS";
 const std::string kCompOsCert = "/data/misc/odsign/compos_key.cert";
+const std::string kCompOsPublicKey = "/data/misc/odsign/compos_key.pubkey";
+const std::string kCompOsKeyBlob = "/data/misc/odsign/compos_key.blob";
 
 static const char* kOdsignVerificationDoneProp = "odsign.verification.done";
 static const char* kOdsignKeyDoneProp = "odsign.key.done";
@@ -67,6 +71,17 @@
 static const char* kOdsignVerificationStatusValid = "1";
 static const char* kOdsignVerificationStatusError = "0";
 
+static void writeBytesToFile(const std::vector<uint8_t>& bytes, const std::string& path) {
+    std::string str(bytes.begin(), bytes.end());
+    android::base::WriteStringToFile(str, path);
+}
+
+static std::vector<uint8_t> readBytesFromFile(const std::string& path) {
+    std::string str;
+    android::base::ReadFileToString(path, &str);
+    return std::vector<uint8_t>(str.begin(), str.end());
+}
+
 bool compOsPresent() {
     return access(kVirtApexPath, F_OK) == 0;
 }
@@ -102,8 +117,7 @@
     }
 
     auto keySignFunction = [&](const std::string& to_be_signed) { return key.sign(to_be_signed); };
-    createSelfSignedCertificate(*publicKey, keySignFunction, outPath);
-    return {};
+    return createSelfSignedCertificate(*publicKey, keySignFunction, outPath);
 }
 
 Result<std::vector<uint8_t>> extractPublicKeyFromLeafCert(const SigningKey& key,
@@ -132,6 +146,63 @@
     return existingCertInfo.value().subjectKey;
 }
 
+Result<void> verifyOrGenerateCompOsKey(const SigningKey& signingKey) {
+    auto compOsStatus = FakeCompOs::newInstance();
+    if (!compOsStatus.ok()) {
+        return Error() << "Failed to start CompOs: " << compOsStatus.error();
+    }
+
+    FakeCompOs* compOs = compOsStatus.value().get();
+
+    std::vector<uint8_t> keyBlob;
+    std::vector<uint8_t> publicKey;
+    bool haveKey = false;
+
+    if (access(kCompOsPublicKey.c_str(), F_OK) == 0 && access(kCompOsKeyBlob.c_str(), F_OK) == 0) {
+        // We have a stored key but no valid certificate for it. If CompOs can
+        // prove the blob matches the stored public key, we can certify it now.
+        keyBlob = readBytesFromFile(kCompOsKeyBlob);
+        publicKey = readBytesFromFile(kCompOsPublicKey);
+
+        auto response = compOs->loadAndVerifyKey(keyBlob, publicKey);
+        if (response.ok()) {
+            LOG(INFO) << "Verified existing CompOs key";
+            haveKey = true;
+        } else {
+            LOG(WARNING) << "Failed to verify existing CompOs key: " << response.error();
+        }
+    }
+
+    if (!haveKey) {
+        // If we don't have a key, or it doesn't verify, then we need a new one.
+        auto keyData = compOs->generateKey();
+        if (!keyData.ok()) {
+            return Error() << "Failed to generate key: " << keyData.error();
+        }
+        auto publicKeyStatus = extractPublicKeyFromX509(keyData.value().cert);
+        if (!publicKeyStatus.ok()) {
+            return Error() << "Failed to extract CompOs public key: " << publicKeyStatus.error();
+        }
+
+        keyBlob = std::move(keyData.value().blob);
+        publicKey = std::move(publicKeyStatus.value());
+
+        writeBytesToFile(keyBlob, kCompOsKeyBlob);
+        writeBytesToFile(publicKey, kCompOsPublicKey);
+    }
+
+    auto signFunction = [&](const std::string& to_be_signed) {
+        return signingKey.sign(to_be_signed);
+    };
+    auto certStatus = createLeafCertificate(kCompOsCommonName, publicKey, signFunction,
+                                            kSigningKeyCert, kCompOsCert);
+    if (!certStatus.ok()) {
+        return Error() << "Failed to create CompOs cert: " << certStatus.error();
+    }
+
+    return {};
+}
+
 art::odrefresh::ExitCode compileArtifacts(bool force) {
     const char* const argv[] = {kOdrefreshPath, force ? "--force-compile" : "--compile"};
     const int exit_code =
@@ -334,7 +405,7 @@
 
     auto keystoreResult = KeystoreKey::getInstance();
     if (!keystoreResult.ok()) {
-        LOG(ERROR) << "Could not create keystore key: " << keystoreResult.error().message();
+        LOG(ERROR) << "Could not create keystore key: " << keystoreResult.error();
         return -1;
     }
     SigningKey* key = keystoreResult.value();
@@ -349,12 +420,12 @@
     if (supportsFsVerity) {
         auto existing_cert = verifyExistingRootCert(*key);
         if (!existing_cert.ok()) {
-            LOG(WARNING) << existing_cert.error().message();
+            LOG(WARNING) << existing_cert.error();
 
             // Try to create a new cert
             auto new_cert = createX509RootCert(*key, kSigningKeyCert);
             if (!new_cert.ok()) {
-                LOG(ERROR) << "Failed to create X509 certificate: " << new_cert.error().message();
+                LOG(ERROR) << "Failed to create X509 certificate: " << new_cert.error();
                 // TODO apparently the key become invalid - delete the blob / cert
                 return -1;
             }
@@ -364,29 +435,34 @@
         auto cert_add_result = addCertToFsVerityKeyring(kSigningKeyCert, "fsv_ods");
         if (!cert_add_result.ok()) {
             LOG(ERROR) << "Failed to add certificate to fs-verity keyring: "
-                       << cert_add_result.error().message();
+                       << cert_add_result.error();
             return -1;
         }
     }
 
     if (supportsCompOs) {
-        auto compos_key = extractPublicKeyFromLeafCert(*key, kCompOsCert, "CompOS");
-        if (compos_key.ok()) {
-            auto cert_add_result = addCertToFsVerityKeyring(kCompOsCert, "fsv_compos");
-            if (cert_add_result.ok()) {
-                LOG(INFO) << "Added CompOs key to fs-verity keyring";
+        auto compos_key = extractPublicKeyFromLeafCert(*key, kCompOsCert, kCompOsCommonName);
+        if (!compos_key.ok()) {
+            LOG(WARNING) << compos_key.error();
+
+            auto status = verifyOrGenerateCompOsKey(*key);
+            if (!status.ok()) {
+                LOG(ERROR) << status.error();
             } else {
-                LOG(ERROR) << "Failed to add CompOs certificate to fs-verity keyring: "
-                           << cert_add_result.error().message();
-                // TODO - what do we do now?
-                // return -1;
+                LOG(INFO) << "Generated new CompOs public key certificate";
             }
         } else {
-            LOG(ERROR) << "Failed to retrieve key from CompOs certificate: "
-                       << compos_key.error().message();
+            LOG(INFO) << "Found and verified existing CompOs public key certificate: "
+                      << kCompOsCert;
+        };
+        auto cert_add_result = addCertToFsVerityKeyring(kCompOsCert, "fsv_compos");
+        if (!cert_add_result.ok()) {
+            LOG(ERROR) << "Failed to add CompOs certificate to fs-verity keyring: "
+                       << cert_add_result.error();
             // Best efforts only - nothing we can do if deletion fails.
             unlink(kCompOsCert.c_str());
             // TODO - what do we do now?
+            // return -1;
         }
     }
 
@@ -402,7 +478,7 @@
         if (artifactsPresent) {
             auto verificationResult = verifyArtifacts(*key, supportsFsVerity);
             if (!verificationResult.ok()) {
-                LOG(ERROR) << verificationResult.error().message();
+                LOG(ERROR) << verificationResult.error();
                 return -1;
             }
         }
@@ -420,12 +496,12 @@
             digests = computeDigests(kArtArtifactsDir);
         }
         if (!digests.ok()) {
-            LOG(ERROR) << digests.error().message();
+            LOG(ERROR) << digests.error();
             return -1;
         }
         auto persistStatus = persistDigests(*digests, *key);
         if (!persistStatus.ok()) {
-            LOG(ERROR) << persistStatus.error().message();
+            LOG(ERROR) << persistStatus.error();
             return -1;
         }
     } else if (odrefresh_status == art::odrefresh::ExitCode::kCleanupFailed) {
@@ -444,5 +520,6 @@
     // And we did a successful verification
     SetProperty(kOdsignVerificationDoneProp, "1");
     SetProperty(kOdsignVerificationStatusProp, kOdsignVerificationStatusValid);
+
     return 0;
 }
diff --git a/provisioner/Android.bp b/provisioner/Android.bp
index 12a21d1..ea84063 100644
--- a/provisioner/Android.bp
+++ b/provisioner/Android.bp
@@ -43,15 +43,6 @@
     },
 }
 
-java_binary {
-    name: "provisioner_cli",
-    wrapper: "provisioner_cli",
-    srcs: ["src/com/android/commands/provisioner/**/*.java"],
-    static_libs: [
-        "android.security.provisioner-java",
-    ],
-}
-
 cc_binary {
     name: "rkp_factory_extraction_tool",
     srcs: ["rkp_factory_extraction_tool.cpp"],
diff --git a/provisioner/provisioner_cli b/provisioner/provisioner_cli
deleted file mode 100755
index 7b53d6e..0000000
--- a/provisioner/provisioner_cli
+++ /dev/null
@@ -1,21 +0,0 @@
-#!/system/bin/sh
-#
-# Copyright (C) 2020 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-# Script to start "provisioner_cli" on the device.
-#
-base=/system
-export CLASSPATH=$base/framework/provisioner_cli.jar
-exec app_process $base/bin com.android.commands.provisioner.Cli "$@"
diff --git a/provisioner/src/com/android/commands/provisioner/Cli.java b/provisioner/src/com/android/commands/provisioner/Cli.java
deleted file mode 100644
index 62afdac..0000000
--- a/provisioner/src/com/android/commands/provisioner/Cli.java
+++ /dev/null
@@ -1,141 +0,0 @@
-/*
- * Copyright 2020 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.commands.provisioner;
-
-import android.os.IBinder;
-import android.os.RemoteException;
-import android.os.ServiceManager;
-import android.security.provisioner.IProvisionerService;
-
-import com.android.internal.os.BaseCommand;
-
-import java.io.ByteArrayOutputStream;
-import java.io.InputStream;
-import java.io.IOException;
-import java.io.PrintStream;
-import java.lang.IllegalArgumentException;
-
-/**
- * Contains the implementation of the remote provisioning command-line interface.
- */
-public class Cli extends BaseCommand {
-    /**
-     * Creates an instance of the command-line interface and runs it. This is the entry point of
-     * the tool.
-     */
-    public static void main(String[] args) {
-        new Cli().run(args);
-    }
-
-    /**
-     * Runs the command requested by the invoker. It parses the very first required argument, which
-     * is the command, and calls the appropriate handler.
-     */
-    @Override
-    public void onRun() throws Exception {
-        String cmd = nextArgRequired();
-        switch (cmd) {
-        case "get-req":
-            getRequest();
-            break;
-
-        case "help":
-            onShowUsage(System.out);
-            break;
-
-        default:
-            throw new IllegalArgumentException("unknown command: " + cmd);
-        }
-    }
-
-    /**
-     * Retrieves a 'certificate request' from the provisioning service. The COSE-encoded
-     * 'certificate chain' describing the endpoint encryption key (EEK) to use for encryption is
-     * read from the standard input. The retrieved request is written to the standard output.
-     */
-    private void getRequest() throws Exception {
-        // Process options.
-        boolean test = false;
-        byte[] challenge = null;
-        int count = 0;
-        String arg;
-        while ((arg = nextArg()) != null) {
-            switch (arg) {
-            case "--test":
-                test = true;
-                break;
-
-            case "--challenge":
-                // TODO: We may need a different encoding of the challenge.
-                challenge = nextArgRequired().getBytes();
-                break;
-
-            case "--count":
-                count = Integer.parseInt(nextArgRequired());
-                if (count < 0) {
-                    throw new IllegalArgumentException(
-                            "--count must be followed by non-negative number");
-                }
-                break;
-
-            default:
-                throw new IllegalArgumentException("unknown argument: " + arg);
-            }
-        }
-
-        // Send the request over to the provisioning service and write the result to stdout.
-        byte[] res = getService().getCertificateRequest(test, count, readAll(System.in), challenge);
-        if (res != null) {
-            System.out.write(res);
-        }
-    }
-
-    /**
-     * Retrieves an implementation of the IProvisionerService interface. It allows the caller to
-     * call into the service via binder.
-     */
-    private static IProvisionerService getService() throws RemoteException {
-        IBinder binder = ServiceManager.getService("remote-provisioner");
-        if (binder == null) {
-            throw new RemoteException("Provisioning service is inaccessible");
-        }
-        return IProvisionerService.Stub.asInterface(binder);
-    }
-
-    /** Reads all data from the provided input stream and returns it as a byte array. */
-    private static byte[] readAll(InputStream in) throws IOException {
-        ByteArrayOutputStream out = new ByteArrayOutputStream();
-        byte[] buf = new byte[1024];
-        int read;
-        while ((read = in.read(buf)) != -1) {
-            out.write(buf, 0, read);
-        }
-        return out.toByteArray();
-    }
-
-    /**
-     * Writes the usage information to the given stream. This is displayed to users of the tool when
-     * they ask for help or when they pass incorrect arguments to the tool.
-     */
-    @Override
-    public void onShowUsage(PrintStream out) {
-        out.println(
-                "Usage: provisioner_cli <command> [options]\n" +
-                "Commands: help\n" +
-                "          get-req [--count <n>] [--test] [--challenge <v>]");
-    }
-}