Migrate ComposKeyTestCase to unit tests

Refactor SigningKey to allow it to be decoupled from Dice for
testing. Add unit tests with the same coverage as ComposKeyTestCase,
then remove the latter. (This does mean we are no longer testing the
code inside a VM, but that's covered by ComposTestCase and the end to
end tests.)

As a result we no longer need compos_key_cmd, so remove that. This is
a win as it was becoming a maintenance burden.

Bug: 213891964
Test: atest compsvc_device_tests
Change-Id: If863abcb4e89eeb97a4be6a4a958b691aa1446be
diff --git a/compos/apex/Android.bp b/compos/apex/Android.bp
index ea72018..aec3c88 100644
--- a/compos/apex/Android.bp
+++ b/compos/apex/Android.bp
@@ -33,7 +33,7 @@
     key: "com.android.compos.key",
     certificate: ":com.android.compos.certificate",
 
-    // TODO(victorhsieh): make it updatable
+    // TODO(b/206618706): make it updatable
     updatable: false,
     future_updatable: true,
     platform_apis: true,
@@ -42,7 +42,6 @@
 
     binaries: [
         // Used in Android
-        "compos_key_cmd",
         "compos_verify_key",
         "composd",
         "composd_cmd",
diff --git a/compos/apk/assets/vm_test_config.json b/compos/apk/assets/vm_test_config.json
deleted file mode 100644
index 16d1037..0000000
--- a/compos/apk/assets/vm_test_config.json
+++ /dev/null
@@ -1,24 +0,0 @@
-{
-  "version": 1,
-  "os": {
-    "name": "microdroid"
-  },
-  "task": {
-    "type": "executable",
-    "command": "/apex/com.android.compos/bin/compsvc"
-  },
-  "apexes": [
-    {
-      "name": "com.android.art"
-    },
-    {
-      "name": "com.android.compos"
-    },
-    {
-      "name": "com.android.sdkext"
-    },
-    {
-      "name": "{CLASSPATH}"
-    }
-  ]
-}
diff --git a/compos/compos_key_cmd/Android.bp b/compos/compos_key_cmd/Android.bp
deleted file mode 100644
index d412f66..0000000
--- a/compos/compos_key_cmd/Android.bp
+++ /dev/null
@@ -1,23 +0,0 @@
-package {
-    default_applicable_licenses: ["Android-Apache-2.0"],
-}
-
-cc_binary {
-    name: "compos_key_cmd",
-    srcs: ["compos_key_cmd.cpp"],
-    apex_available: ["com.android.compos"],
-
-    static_libs: [
-        "lib_odsign_proto",
-    ],
-
-    shared_libs: [
-        "android.system.virtualizationservice-ndk",
-        "compos_aidl_interface-ndk",
-        "libbase",
-        "libbinder_ndk",
-        "libbinder_rpc_unstable",
-        "libfsverity",
-        "libprotobuf-cpp-lite",
-    ],
-}
diff --git a/compos/compos_key_cmd/compos_key_cmd.cpp b/compos/compos_key_cmd/compos_key_cmd.cpp
deleted file mode 100644
index 07ff636..0000000
--- a/compos/compos_key_cmd/compos_key_cmd.cpp
+++ /dev/null
@@ -1,530 +0,0 @@
-/*
- * Copyright (C) 2021 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <aidl/android/system/virtualizationservice/BnVirtualMachineCallback.h>
-#include <aidl/android/system/virtualizationservice/IVirtualizationService.h>
-#include <aidl/com/android/compos/ICompOsService.h>
-#include <android-base/file.h>
-#include <android-base/logging.h>
-#include <android-base/result.h>
-#include <android-base/unique_fd.h>
-#include <android/binder_auto_utils.h>
-#include <android/binder_manager.h>
-#include <android/binder_process.h>
-#include <asm/byteorder.h>
-#include <libfsverity.h>
-#include <linux/fsverity.h>
-#include <stdio.h>
-#include <unistd.h>
-
-#include <binder_rpc_unstable.hpp>
-#include <chrono>
-#include <condition_variable>
-#include <filesystem>
-#include <iostream>
-#include <map>
-#include <mutex>
-#include <string>
-#include <string_view>
-#include <thread>
-
-#include "odsign_info.pb.h"
-
-using namespace std::literals;
-
-using aidl::android::system::virtualizationservice::BnVirtualMachineCallback;
-using aidl::android::system::virtualizationservice::DeathReason;
-using aidl::android::system::virtualizationservice::IVirtualizationService;
-using aidl::android::system::virtualizationservice::IVirtualMachine;
-using aidl::android::system::virtualizationservice::IVirtualMachineCallback;
-using aidl::android::system::virtualizationservice::PartitionType;
-using aidl::android::system::virtualizationservice::VirtualMachineAppConfig;
-using aidl::android::system::virtualizationservice::VirtualMachineConfig;
-using aidl::com::android::compos::CompOsKeyData;
-using aidl::com::android::compos::ICompOsService;
-using android::base::Dirname;
-using android::base::ErrnoError;
-using android::base::Error;
-using android::base::Fdopen;
-using android::base::Result;
-using android::base::unique_fd;
-using android::base::WriteFully;
-using ndk::ScopedAStatus;
-using ndk::ScopedFileDescriptor;
-using ndk::SharedRefBase;
-using odsign::proto::OdsignInfo;
-
-constexpr unsigned int kRpcPort = 6432;
-
-constexpr int kVmMemoryMib = 1024;
-
-constexpr const char* kConfigApkPath =
-        "/apex/com.android.compos/app/CompOSPayloadApp/CompOSPayloadApp.apk";
-
-// These are paths inside the APK
-constexpr const char* kDefaultConfigFilePath = "assets/vm_config.json";
-constexpr const char* kPreferStagedConfigFilePath = "assets/vm_config_staged.json";
-
-static bool writeBytesToFile(const std::vector<uint8_t>& bytes, const std::string& path) {
-    std::string str(bytes.begin(), bytes.end());
-    return android::base::WriteStringToFile(str, path);
-}
-
-static Result<std::vector<uint8_t>> readBytesFromFile(const std::string& path) {
-    std::string str;
-    if (!android::base::ReadFileToString(path, &str)) {
-        return Error() << "Failed to read " << path;
-    }
-    return std::vector<uint8_t>(str.begin(), str.end());
-}
-
-static std::shared_ptr<ICompOsService> getService(int cid) {
-    LOG(INFO) << "Connecting to cid " << cid;
-    ndk::SpAIBinder binder(cid == 0 ? AServiceManager_getService("android.system.composkeyservice")
-                                    : RpcClient(cid, kRpcPort));
-    return ICompOsService::fromBinder(binder);
-}
-
-namespace {
-
-void copyToLog(unique_fd&& fd) {
-    FILE* source = Fdopen(std::move(fd), "r");
-    if (source == nullptr) {
-        LOG(INFO) << "Can't log VM output";
-        return;
-    }
-    size_t size = 0;
-    char* line = nullptr;
-
-    LOG(INFO) << "Started logging VM output";
-
-    for (;;) {
-        ssize_t len = getline(&line, &size, source);
-        if (len < 0) {
-            LOG(INFO) << "VM logging ended: " << ErrnoError().str();
-            break;
-        }
-        LOG(DEBUG) << "VM: " << std::string_view(line, len);
-    }
-    free(line);
-}
-
-class Callback : public BnVirtualMachineCallback {
-public:
-    ::ndk::ScopedAStatus onPayloadStarted(int32_t in_cid,
-                                          const ::ndk::ScopedFileDescriptor& stream) override {
-        LOG(INFO) << "Payload started! cid = " << in_cid;
-
-        unique_fd stream_fd(dup(stream.get()));
-        std::thread logger([fd = std::move(stream_fd)]() mutable { copyToLog(std::move(fd)); });
-        logger.detach();
-
-        return ScopedAStatus::ok();
-    }
-
-    ::ndk::ScopedAStatus onPayloadReady(int32_t in_cid) override {
-        LOG(INFO) << "Payload is ready! cid = " << in_cid;
-        {
-            std::unique_lock lock(mMutex);
-            mReady = true;
-        }
-        mCv.notify_all();
-        return ScopedAStatus::ok();
-    }
-
-    ::ndk::ScopedAStatus onPayloadFinished(int32_t in_cid, int32_t in_exit_code) override {
-        LOG(INFO) << "Payload finished! cid = " << in_cid << ", exit_code = " << in_exit_code;
-        return ScopedAStatus::ok();
-    }
-
-    ::ndk::ScopedAStatus onError(int32_t in_cid, int32_t in_error_code,
-                                 const std::string& in_message) override {
-        LOG(WARNING) << "VM error! cid = " << in_cid << ", error_code = " << in_error_code
-                     << ", message = " << in_message;
-        {
-            std::unique_lock lock(mMutex);
-            mDied = true;
-        }
-        mCv.notify_all();
-        return ScopedAStatus::ok();
-    }
-
-    ::ndk::ScopedAStatus onDied(int32_t in_cid, DeathReason reason) override {
-        LOG(WARNING) << "VM died! cid = " << in_cid << " reason = " << toString(reason);
-        {
-            std::unique_lock lock(mMutex);
-            mDied = true;
-        }
-        mCv.notify_all();
-        return ScopedAStatus::ok();
-    }
-
-    bool waitUntilReady() {
-        std::unique_lock lock(mMutex);
-        // 10s is long enough on real hardware, but it can take 90s when using nested
-        // virtualization.
-        // TODO(b/200924405): Reduce timeout/detect nested virtualization
-        return mCv.wait_for(lock, std::chrono::seconds(120), [this] { return mReady || mDied; }) &&
-                !mDied;
-    }
-
-private:
-    std::mutex mMutex;
-    std::condition_variable mCv;
-    bool mReady;
-    bool mDied;
-};
-
-class TargetVm {
-public:
-    TargetVm(int cid, const std::string& logFile, const std::string& instanceImageFile,
-             bool debuggable, bool preferStaged)
-          : mCid(cid),
-            mLogFile(logFile),
-            mInstanceImageFile(instanceImageFile),
-            mDebuggable(debuggable),
-            mPreferStaged(preferStaged) {}
-
-    // Returns 0 if we are to connect to a local service, otherwise the CID of
-    // either an existing VM or a VM we have started, depending on the command
-    // line arguments.
-    Result<int> resolveCid() {
-        if (mInstanceImageFile.empty()) {
-            return mCid;
-        }
-        if (mCid != 0) {
-            return Error() << "Can't specify both cid and image file.";
-        }
-
-        // Start a new VM with a given instance.img
-
-        // We need a thread pool to receive VM callbacks.
-        ABinderProcess_startThreadPool();
-
-        ndk::SpAIBinder binder(
-                AServiceManager_waitForService("android.system.virtualizationservice"));
-        auto service = IVirtualizationService::fromBinder(binder);
-        if (!service) {
-            return Error() << "Failed to connect to virtualization service.";
-        }
-
-        // Console output and the system log output from the VM are redirected to this file.
-        ScopedFileDescriptor logFd;
-        if (mLogFile.empty()) {
-            logFd.set(dup(STDOUT_FILENO));
-            if (logFd.get() == -1) {
-                return ErrnoError() << "dup() failed: ";
-            }
-        } else {
-            logFd.set(TEMP_FAILURE_RETRY(open(mLogFile.c_str(),
-                                              O_WRONLY | O_CREAT | O_TRUNC | O_CLOEXEC,
-                                              S_IRUSR | S_IWUSR)));
-            if (logFd.get() == -1) {
-                return ErrnoError() << "Failed to open " << mLogFile;
-            }
-        }
-
-        ScopedFileDescriptor apkFd(TEMP_FAILURE_RETRY(open(kConfigApkPath, O_RDONLY | O_CLOEXEC)));
-        if (apkFd.get() == -1) {
-            return ErrnoError() << "Failed to open config APK";
-        }
-
-        // Prepare an idsig file
-        std::string idsigPath = Dirname(mInstanceImageFile) + "/idsig";
-        {
-            ScopedFileDescriptor idsigFd(TEMP_FAILURE_RETRY(
-                    open(idsigPath.c_str(), O_RDWR | O_CREAT | O_TRUNC | O_CLOEXEC,
-                         S_IRUSR | S_IWUSR | S_IRGRP)));
-            if (idsigFd.get() == -1) {
-                return ErrnoError() << "Failed to create an idsig file";
-            }
-            auto status = service->createOrUpdateIdsigFile(apkFd, idsigFd);
-            if (!status.isOk()) {
-                return Error() << status.getDescription();
-            }
-        }
-
-        ScopedFileDescriptor idsigFd(
-                TEMP_FAILURE_RETRY(open(idsigPath.c_str(), O_RDONLY | O_CLOEXEC)));
-        if (idsigFd.get() == -1) {
-            return ErrnoError() << "Failed to open an idsig file";
-        }
-
-        ScopedFileDescriptor instanceFd(
-                TEMP_FAILURE_RETRY(open(mInstanceImageFile.c_str(), O_RDWR | O_CLOEXEC)));
-        if (instanceFd.get() == -1) {
-            return ErrnoError() << "Failed to open instance image file";
-        }
-
-        auto config = VirtualMachineConfig::make<VirtualMachineConfig::Tag::appConfig>();
-        auto& appConfig = config.get<VirtualMachineConfig::Tag::appConfig>();
-        appConfig.apk = std::move(apkFd);
-        appConfig.idsig = std::move(idsigFd);
-        appConfig.instanceImage = std::move(instanceFd);
-        appConfig.configPath = mPreferStaged ? kPreferStagedConfigFilePath : kDefaultConfigFilePath;
-        appConfig.debugLevel = mDebuggable ? VirtualMachineAppConfig::DebugLevel::FULL
-                                           : VirtualMachineAppConfig::DebugLevel::NONE;
-        appConfig.memoryMib = kVmMemoryMib;
-
-        LOG(INFO) << "Starting VM";
-        auto status = service->createVm(config, logFd, logFd, &mVm);
-        if (!status.isOk()) {
-            return Error() << status.getDescription();
-        }
-
-        int32_t cid;
-        status = mVm->getCid(&cid);
-        if (!status.isOk()) {
-            return Error() << status.getDescription();
-        }
-
-        LOG(INFO) << "Created VM with CID = " << cid;
-
-        // We need to use this rather than std::make_shared to make sure the
-        // embedded weak_ptr is initialised.
-        mCallback = SharedRefBase::make<Callback>();
-
-        status = mVm->registerCallback(mCallback);
-        if (!status.isOk()) {
-            return Error() << status.getDescription();
-        }
-
-        status = mVm->start();
-        if (!status.isOk()) {
-            return Error() << status.getDescription();
-        }
-        LOG(INFO) << "Started VM";
-
-        if (!mCallback->waitUntilReady()) {
-            return Error() << "VM Payload failed to start";
-        }
-
-        return cid;
-    }
-
-private:
-    const int mCid;
-    const std::string mLogFile;
-    const std::string mInstanceImageFile;
-    const bool mDebuggable;
-    const bool mPreferStaged;
-    std::shared_ptr<Callback> mCallback;
-    std::shared_ptr<IVirtualMachine> mVm;
-};
-
-} // namespace
-
-static Result<void> generate(TargetVm& vm, const std::string& blob_file,
-                             const std::string& public_key_file) {
-    auto cid = vm.resolveCid();
-    if (!cid.ok()) {
-        return cid.error();
-    }
-    auto service = getService(*cid);
-    if (!service) {
-        return Error() << "No service";
-    }
-
-    CompOsKeyData key_data;
-    auto status = service->generateSigningKey(&key_data);
-    if (!status.isOk()) {
-        return Error() << "Failed to generate key: " << status.getDescription();
-    }
-
-    if (!writeBytesToFile(key_data.keyBlob, blob_file)) {
-        return Error() << "Failed to write keyBlob to " << blob_file;
-    }
-
-    if (!writeBytesToFile(key_data.publicKey, public_key_file)) {
-        return Error() << "Failed to write public key to " << public_key_file;
-    }
-
-    return {};
-}
-
-static Result<bool> verify(TargetVm& vm, const std::string& blob_file,
-                           const std::string& public_key_file) {
-    auto cid = vm.resolveCid();
-    if (!cid.ok()) {
-        return cid.error();
-    }
-    auto service = getService(*cid);
-    if (!service) {
-        return Error() << "No service";
-    }
-
-    auto blob = readBytesFromFile(blob_file);
-    if (!blob.ok()) {
-        return blob.error();
-    }
-
-    auto public_key = readBytesFromFile(public_key_file);
-    if (!public_key.ok()) {
-        return public_key.error();
-    }
-
-    bool result = false;
-    auto status = service->verifySigningKey(blob.value(), public_key.value(), &result);
-    if (!status.isOk()) {
-        return Error() << "Failed to verify key: " << status.getDescription();
-    }
-
-    return result;
-}
-
-static Result<void> initializeKey(TargetVm& vm, const std::string& blob_file) {
-    auto cid = vm.resolveCid();
-    if (!cid.ok()) {
-        return cid.error();
-    }
-    auto service = getService(*cid);
-    if (!service) {
-        return Error() << "No service";
-    }
-
-    auto blob = readBytesFromFile(blob_file);
-    if (!blob.ok()) {
-        return blob.error();
-    }
-
-    auto status = service->initializeSigningKey(blob.value());
-    if (!status.isOk()) {
-        return Error() << "Failed to initialize signing key: " << status.getDescription();
-    }
-    return {};
-}
-
-static Result<void> makeInstanceImage(const std::string& image_path) {
-    ndk::SpAIBinder binder(AServiceManager_waitForService("android.system.virtualizationservice"));
-    auto service = IVirtualizationService::fromBinder(binder);
-    if (!service) {
-        return Error() << "Failed to connect to virtualization service.";
-    }
-
-    ScopedFileDescriptor fd(TEMP_FAILURE_RETRY(
-            open(image_path.c_str(), O_CREAT | O_RDWR | O_TRUNC | O_CLOEXEC, S_IRUSR | S_IWUSR)));
-    if (fd.get() == -1) {
-        return ErrnoError() << "Failed to create image file";
-    }
-
-    auto status = service->initializeWritablePartition(fd, 10 * 1024 * 1024,
-                                                       PartitionType::ANDROID_VM_INSTANCE);
-    if (!status.isOk()) {
-        return Error() << "Failed to initialize partition: " << status.getDescription();
-    }
-    return {};
-}
-
-int main(int argc, char** argv) {
-    // Restrict access to our outputs to the current user.
-    umask(077);
-
-    int cid = 0;
-    std::string imageFile;
-    std::string logFile;
-    bool debuggable = false;
-    bool preferStaged = false;
-
-    for (;;) {
-        // Options with no associated value
-        if (argc >= 2) {
-            if (argv[1] == "--debug"sv) {
-                debuggable = true;
-                argc -= 1;
-                argv += 1;
-                continue;
-            } else if (argv[1] == "--staged"sv) {
-                preferStaged = true;
-                argc -= 1;
-                argv += 1;
-                continue;
-            }
-        }
-        if (argc < 3) break;
-        // Options requiring a value
-        if (argv[1] == "--cid"sv) {
-            cid = atoi(argv[2]);
-            if (cid == 0) {
-                std::cerr << "Invalid cid\n";
-                return 1;
-            }
-        } else if (argv[1] == "--start"sv) {
-            imageFile = argv[2];
-        } else if (argv[1] == "--log"sv) {
-            logFile = argv[2];
-        } else {
-            break;
-        }
-        argc -= 2;
-        argv += 2;
-    }
-
-    TargetVm vm(cid, logFile, imageFile, debuggable, preferStaged);
-
-    if (argc == 4 && argv[1] == "generate"sv) {
-        auto result = generate(vm, argv[2], argv[3]);
-        if (result.ok()) {
-            return 0;
-        } else {
-            std::cerr << result.error() << '\n';
-        }
-    } else if (argc == 4 && argv[1] == "verify"sv) {
-        auto result = verify(vm, argv[2], argv[3]);
-        if (result.ok()) {
-            if (result.value()) {
-                std::cerr << "Key files are valid.\n";
-                return 0;
-            } else {
-                std::cerr << "Key files are not valid.\n";
-            }
-        } else {
-            std::cerr << result.error() << '\n';
-        }
-    } else if (argc == 3 && argv[1] == "init-key"sv) {
-        auto result = initializeKey(vm, argv[2]);
-        if (result.ok()) {
-            return 0;
-        } else {
-            std::cerr << result.error() << '\n';
-        }
-    } else if (argc == 3 && argv[1] == "make-instance"sv) {
-        auto result = makeInstanceImage(argv[2]);
-        if (result.ok()) {
-            return 0;
-        } else {
-            std::cerr << result.error() << '\n';
-        }
-    } else {
-        std::cerr << "Usage: compos_key_cmd [OPTIONS] COMMAND\n"
-                  << "Where COMMAND can be:\n"
-                  << "  make-instance <image file> Create an empty instance image file for a VM.\n"
-                  << "  generate <blob file> <public key file> Generate new key pair and write\n"
-                  << "    the private key blob and public key to the specified files.\n "
-                  << "  verify <blob file> <public key file> Verify that the content of the\n"
-                  << "    specified private key blob and public key files are valid.\n "
-                  << "  init-key <blob file> Initialize the service key.\n"
-                  << "\n"
-                  << "OPTIONS: --log <log file> --debug --staged\n"
-                  << "    (--cid <cid> | --start <image file>)\n"
-                  << "  Specify --log to write VM log to a file rather than stdout.\n"
-                  << "  Specify --debug with --start to make the VM fully debuggable.\n"
-                  << "  Specify --staged with --start to prefer staged APEXes in the VM.\n"
-                  << "  Specify --cid to connect to a VM rather than the host.\n"
-                  << "  Specify --start to start a VM from the given instance image file and\n "
-                  << "    connect to that.\n";
-    }
-    return 1;
-}
diff --git a/compos/src/artifact_signer.rs b/compos/src/artifact_signer.rs
index e1b0efa..b5eb8cb 100644
--- a/compos/src/artifact_signer.rs
+++ b/compos/src/artifact_signer.rs
@@ -20,7 +20,7 @@
 #![allow(dead_code)] // Will be used soon
 
 use crate::fsverity;
-use crate::signing_key::Signer;
+use crate::signing_key::DiceSigner;
 use anyhow::{anyhow, Context, Result};
 use odsign_proto::odsign_info::OdsignInfo;
 use protobuf::Message;
@@ -63,7 +63,7 @@
 
     /// Consume this ArtifactSigner and write details of all its artifacts to the given path,
     /// with accompanying sigature file.
-    pub fn write_info_and_signature(self, signer: Signer, info_path: &Path) -> Result<()> {
+    pub fn write_info_and_signature(self, signer: DiceSigner, info_path: &Path) -> Result<()> {
         let mut info = OdsignInfo::new();
         info.mut_file_hashes().extend(self.file_digests.into_iter());
         let bytes = info.write_to_bytes()?;
diff --git a/compos/src/compilation.rs b/compos/src/compilation.rs
index 7e3834a..48ba4a6 100644
--- a/compos/src/compilation.rs
+++ b/compos/src/compilation.rs
@@ -27,7 +27,7 @@
 use std::process::Command;
 
 use crate::artifact_signer::ArtifactSigner;
-use crate::signing_key::Signer;
+use crate::signing_key::DiceSigner;
 use authfs_aidl_interface::aidl::com::android::virt::fs::{
     AuthFsConfig::{
         AuthFsConfig, InputDirFdAnnotation::InputDirFdAnnotation,
@@ -101,7 +101,7 @@
     odrefresh_path: &Path,
     context: OdrefreshContext,
     authfs_service: Strong<dyn IAuthFsService>,
-    signer: Signer,
+    signer: DiceSigner,
 ) -> Result<ExitCode> {
     // Mount authfs (via authfs_service). The authfs instance unmounts once the `authfs` variable
     // is out of scope.
diff --git a/compos/src/compsvc.rs b/compos/src/compsvc.rs
index 9d754a7..3ec15dd 100644
--- a/compos/src/compsvc.rs
+++ b/compos/src/compsvc.rs
@@ -27,7 +27,8 @@
 use std::sync::RwLock;
 
 use crate::compilation::{odrefresh, OdrefreshContext};
-use crate::signing_key::{Signer, SigningKey};
+use crate::dice::Dice;
+use crate::signing_key::{DiceSigner, DiceSigningKey};
 use authfs_aidl_interface::aidl::com::android::virt::fs::IAuthFsService::IAuthFsService;
 use compos_aidl_interface::aidl::com::android::compos::{
     CompOsKeyData::CompOsKeyData,
@@ -44,7 +45,7 @@
 pub fn new_binder() -> Result<Strong<dyn ICompOsService>> {
     let service = CompOsService {
         odrefresh_path: PathBuf::from(ODREFRESH_PATH),
-        signing_key: SigningKey::new()?,
+        signing_key: DiceSigningKey::new(Dice::new()?),
         key_blob: RwLock::new(Vec::new()),
     };
     Ok(BnCompOsService::new_binder(service, BinderFeatures::default()))
@@ -52,12 +53,12 @@
 
 struct CompOsService {
     odrefresh_path: PathBuf,
-    signing_key: SigningKey,
+    signing_key: DiceSigningKey,
     key_blob: RwLock<Vec<u8>>,
 }
 
 impl CompOsService {
-    fn new_signer(&self) -> BinderResult<Signer> {
+    fn new_signer(&self) -> BinderResult<DiceSigner> {
         let key = &*self.key_blob.read().unwrap();
         if key.is_empty() {
             Err(new_binder_exception(ExceptionCode::ILLEGAL_STATE, "Key is not initialized"))
diff --git a/compos/src/dice.rs b/compos/src/dice.rs
index 9f66b5e..25148ab 100644
--- a/compos/src/dice.rs
+++ b/compos/src/dice.rs
@@ -20,6 +20,7 @@
 use android_security_dice::binder::{wait_for_interface, Strong};
 use anyhow::{Context, Result};
 
+#[derive(Clone)]
 pub struct Dice {
     node: Strong<dyn IDiceNode>,
 }
diff --git a/compos/src/signing_key.rs b/compos/src/signing_key.rs
index 175a11b..90beecf 100644
--- a/compos/src/signing_key.rs
+++ b/compos/src/signing_key.rs
@@ -28,13 +28,26 @@
     signature,
 };
 
-pub struct SigningKey {
-    _unused: (), // Prevent construction other than by new()
+pub type DiceSigningKey = SigningKey<Dice>;
+pub type DiceSigner = Signer<Dice>;
+
+pub struct SigningKey<T: SecretStore> {
+    secret_store: T,
 }
 
-impl SigningKey {
-    pub fn new() -> Result<Self> {
-        Ok(Self { _unused: () })
+pub trait SecretStore: Clone {
+    fn get_secret(&self) -> Result<Vec<u8>>;
+}
+
+impl SecretStore for Dice {
+    fn get_secret(&self) -> Result<Vec<u8>> {
+        self.get_sealing_cdi()
+    }
+}
+
+impl<T: SecretStore> SigningKey<T> {
+    pub fn new(secret_store: T) -> Self {
+        Self { secret_store }
     }
 
     pub fn generate(&self) -> Result<CompOsKeyData> {
@@ -43,7 +56,8 @@
             bail!("Failed to generate key pair: {}", key_result.error);
         }
 
-        let encrypted = encrypt_private_key(&Dice::new()?, &key_result.private_key)?;
+        let encrypted =
+            encrypt_private_key(&self.secret_store.get_secret()?, &key_result.private_key)?;
         Ok(CompOsKeyData { publicKey: key_result.public_key, keyBlob: encrypted })
     }
 
@@ -63,19 +77,19 @@
         Ok(())
     }
 
-    pub fn new_signer(&self, key_blob: &[u8]) -> Result<Signer> {
-        Ok(Signer { key_blob: key_blob.to_owned(), dice: Dice::new()? })
+    pub fn new_signer(&self, key_blob: &[u8]) -> Result<Signer<T>> {
+        Ok(Signer { key_blob: key_blob.to_owned(), secret_store: self.secret_store.clone() })
     }
 }
 
-pub struct Signer {
+pub struct Signer<T: SecretStore> {
     key_blob: Vec<u8>,
-    dice: Dice,
+    secret_store: T,
 }
 
-impl Signer {
+impl<T: SecretStore> Signer<T> {
     pub fn sign(self, data: &[u8]) -> Result<Vec<u8>> {
-        let private_key = decrypt_private_key(&self.dice, &self.key_blob)?;
+        let private_key = decrypt_private_key(&self.secret_store.get_secret()?, &self.key_blob)?;
         let sign_result = compos_native::sign(&private_key, data);
         if sign_result.signature.is_empty() {
             bail!("Failed to sign: {}", sign_result.error);
@@ -84,14 +98,69 @@
     }
 }
 
-fn encrypt_private_key(dice: &Dice, private_key: &[u8]) -> Result<Vec<u8>> {
-    let cdi = dice.get_sealing_cdi()?;
-    let aead_key = blob_encryption::derive_aead_key(&cdi)?;
+fn encrypt_private_key(vm_secret: &[u8], private_key: &[u8]) -> Result<Vec<u8>> {
+    let aead_key = blob_encryption::derive_aead_key(vm_secret)?;
     blob_encryption::encrypt_bytes(aead_key, private_key)
 }
 
-fn decrypt_private_key(dice: &Dice, blob: &[u8]) -> Result<Vec<u8>> {
-    let cdi = dice.get_sealing_cdi()?;
-    let aead_key = blob_encryption::derive_aead_key(&cdi)?;
+fn decrypt_private_key(vm_secret: &[u8], blob: &[u8]) -> Result<Vec<u8>> {
+    let aead_key = blob_encryption::derive_aead_key(vm_secret)?;
     blob_encryption::decrypt_bytes(aead_key, blob)
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    const SECRET: &[u8] = b"This is not very secret";
+
+    #[derive(Clone)]
+    struct TestSecretStore;
+
+    impl SecretStore for TestSecretStore {
+        fn get_secret(&self) -> Result<Vec<u8>> {
+            Ok(SECRET.to_owned())
+        }
+    }
+
+    type TestSigningKey = SigningKey<TestSecretStore>;
+
+    fn signing_key_for_test() -> TestSigningKey {
+        TestSigningKey::new(TestSecretStore)
+    }
+
+    #[test]
+    fn test_generated_key_verifies() -> Result<()> {
+        let signing_key = signing_key_for_test();
+        let key_pair = signing_key.generate()?;
+
+        signing_key.verify(&key_pair.keyBlob, &key_pair.publicKey)
+    }
+
+    #[test]
+    fn test_bogus_key_pair_rejected() -> Result<()> {
+        let signing_key = signing_key_for_test();
+        let key_pair = signing_key.generate()?;
+
+        // Swap public key & key blob - clearly invalid
+        assert!(signing_key.verify(&key_pair.publicKey, &key_pair.keyBlob).is_err());
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_mismatched_key_rejected() -> Result<()> {
+        let signing_key = signing_key_for_test();
+        let key_pair1 = signing_key.generate()?;
+        let key_pair2 = signing_key.generate()?;
+
+        // Both pairs should be valid
+        signing_key.verify(&key_pair1.keyBlob, &key_pair1.publicKey)?;
+        signing_key.verify(&key_pair2.keyBlob, &key_pair2.publicKey)?;
+
+        // But using the public key from one and the private key from the other should not,
+        // even though both are well-formed
+        assert!(signing_key.verify(&key_pair1.publicKey, &key_pair2.keyBlob).is_err());
+        Ok(())
+    }
+}
diff --git a/compos/tests/java/android/compos/test/ComposKeyTestCase.java b/compos/tests/java/android/compos/test/ComposKeyTestCase.java
deleted file mode 100644
index 49235fe..0000000
--- a/compos/tests/java/android/compos/test/ComposKeyTestCase.java
+++ /dev/null
@@ -1,171 +0,0 @@
-/*
- * Copyright (C) 2021 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package android.compos.test;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import android.platform.test.annotations.RootPermissionTest;
-import android.virt.test.CommandRunner;
-import android.virt.test.VirtualizationTestCaseBase;
-
-import com.android.compatibility.common.util.PollingCheck;
-import com.android.tradefed.testtype.DeviceJUnit4ClassRunner;
-import com.android.tradefed.util.CommandResult;
-import com.android.tradefed.util.CommandStatus;
-
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-
-import java.util.Optional;
-
-@RootPermissionTest
-@RunWith(DeviceJUnit4ClassRunner.class)
-public final class ComposKeyTestCase extends VirtualizationTestCaseBase {
-
-    /**
-     * Wait time for service to be ready on boot
-     */
-    private static final int READY_LATENCY_MS = 10 * 1000; // 10 seconds
-
-    /**
-     * Path to compos_key_cmd tool
-     */
-    private static final String COMPOS_KEY_CMD_BIN = "/apex/com.android.compos/bin/compos_key_cmd";
-
-    /**
-     * Path to the com.android.compos.payload APK
-     */
-    private static final String COMPOS_PAYLOAD_APK_PATH =
-            "/apex/com.android.compos/app/CompOSPayloadApp/CompOSPayloadApp.apk";
-
-    /**
-     * Config of the test VM. This is a path inside the APK.
-     */
-    private static final String VM_TEST_CONFIG_PATH = "assets/vm_test_config.json";
-
-    private String mCid;
-
-    @Before
-    public void setUp() throws Exception {
-        testIfDeviceIsCapable(getDevice());
-
-        prepareVirtualizationTestSetup(getDevice());
-    }
-
-    @After
-    public void tearDown() throws Exception {
-        if (mCid != null) {
-            shutdownMicrodroid(getDevice(), mCid);
-            mCid = null;
-        }
-
-        cleanUpVirtualizationTestSetup(getDevice());
-    }
-
-    @Test
-    public void testKeyService() throws Exception {
-        startVm();
-        waitForServiceRunning();
-
-        CommandRunner android = new CommandRunner(getDevice());
-        CommandResult result;
-
-        // Generate keys - should succeed
-        android.run(
-                COMPOS_KEY_CMD_BIN,
-                "--cid " + mCid,
-                "generate",
-                TEST_ROOT + "test_key.blob",
-                TEST_ROOT + "test_key.pubkey");
-
-        // Verify them - should also succeed, since we just generated them
-        android.run(
-                COMPOS_KEY_CMD_BIN,
-                "--cid " + mCid,
-                "verify",
-                TEST_ROOT + "test_key.blob",
-                TEST_ROOT + "test_key.pubkey");
-
-        // Swap public key & blob - should fail to verify
-        result =
-                android.runForResult(
-                        COMPOS_KEY_CMD_BIN,
-                        "--cid " + mCid,
-                        "verify",
-                        TEST_ROOT + "test_key.pubkey",
-                        TEST_ROOT + "test_key.blob");
-        assertThat(result.getStatus()).isEqualTo(CommandStatus.FAILED);
-
-        // Generate another set of keys - should succeed
-        android.run(
-                COMPOS_KEY_CMD_BIN,
-                "--cid " + mCid,
-                "generate",
-                TEST_ROOT + "test_key2.blob",
-                TEST_ROOT + "test_key2.pubkey");
-
-        // They should also verify ok
-        android.run(
-                COMPOS_KEY_CMD_BIN,
-                "--cid " + mCid,
-                "verify",
-                TEST_ROOT + "test_key2.blob",
-                TEST_ROOT + "test_key2.pubkey");
-
-        // Mismatched key blob & public key should fail to verify
-        result =
-                android.runForResult(
-                        COMPOS_KEY_CMD_BIN,
-                        "--cid " + mCid,
-                        "verify",
-                        TEST_ROOT + "test_key.pubkey",
-                        TEST_ROOT + "test_key2.blob");
-        assertThat(result.getStatus()).isEqualTo(CommandStatus.FAILED);
-    }
-
-    private void startVm() throws Exception {
-        final String packageName = "com.android.compos.payload";
-        mCid =
-                startMicrodroid(
-                        getDevice(),
-                        getBuild(),
-                        /* apkName, no need to install */ null,
-                        COMPOS_PAYLOAD_APK_PATH,
-                        /* packageName - not needed, we know the path */ null,
-                        /* extraIdSigPaths */ null,
-                        VM_TEST_CONFIG_PATH,
-                        /* debug */ true,
-                        /* use default memoryMib */ 0,
-                        Optional.empty(),
-                        Optional.empty());
-        adbConnectToMicrodroid(getDevice(), mCid);
-    }
-
-    private void waitForServiceRunning() {
-        try {
-            PollingCheck.waitFor(READY_LATENCY_MS, this::isServiceRunning);
-        } catch (Exception e) {
-            throw new RuntimeException("Service unavailable", e);
-        }
-    }
-
-    private boolean isServiceRunning() {
-        return tryRunOnMicrodroid("pidof compsvc") != null;
-    }
-}