Merge "remove support of multi-file partition"
diff --git a/compos/compos_key_cmd/Android.bp b/compos/compos_key_cmd/Android.bp
index 1d889c6..e0584f4 100644
--- a/compos/compos_key_cmd/Android.bp
+++ b/compos/compos_key_cmd/Android.bp
@@ -12,6 +12,7 @@
     ],
 
     shared_libs: [
+        "android.system.virtualizationservice-ndk",
         "compos_aidl_interface-ndk",
         "libbase",
         "libbinder_rpc_unstable",
diff --git a/compos/compos_key_cmd/compos_key_cmd.cpp b/compos/compos_key_cmd/compos_key_cmd.cpp
index bee9de1..84a0a7c 100644
--- a/compos/compos_key_cmd/compos_key_cmd.cpp
+++ b/compos/compos_key_cmd/compos_key_cmd.cpp
@@ -14,12 +14,16 @@
  * limitations under the License.
  */
 
+#include <aidl/android/system/virtualizationservice/BnVirtualMachineCallback.h>
+#include <aidl/android/system/virtualizationservice/IVirtualizationService.h>
 #include <aidl/com/android/compos/ICompOsKeyService.h>
 #include <android-base/file.h>
+#include <android-base/logging.h>
 #include <android-base/result.h>
 #include <android-base/unique_fd.h>
 #include <android/binder_auto_utils.h>
 #include <android/binder_manager.h>
+#include <android/binder_process.h>
 #include <asm/byteorder.h>
 #include <libfsverity.h>
 #include <linux/fsverity.h>
@@ -27,9 +31,13 @@
 #include <openssl/mem.h>
 #include <openssl/sha.h>
 #include <openssl/x509.h>
+#include <unistd.h>
 
+#include <chrono>
+#include <condition_variable>
 #include <filesystem>
 #include <iostream>
+#include <mutex>
 #include <string>
 #include <string_view>
 
@@ -42,6 +50,11 @@
 
 using namespace std::literals;
 
+using aidl::android::system::virtualizationservice::BnVirtualMachineCallback;
+using aidl::android::system::virtualizationservice::IVirtualizationService;
+using aidl::android::system::virtualizationservice::IVirtualMachine;
+using aidl::android::system::virtualizationservice::IVirtualMachineCallback;
+using aidl::android::system::virtualizationservice::VirtualMachineConfig;
 using aidl::com::android::compos::CompOsKeyData;
 using aidl::com::android::compos::ICompOsKeyService;
 using android::base::ErrnoError;
@@ -49,8 +62,19 @@
 using android::base::Result;
 using android::base::unique_fd;
 using compos::proto::Signature;
+using ndk::ScopedAStatus;
+using ndk::ScopedFileDescriptor;
+using ndk::SharedRefBase;
 
-const unsigned int kRpcPort = 3142;
+constexpr unsigned int kRpcPort = 3142;  // Port given to RpcClient() when connecting to a VM.
+
+constexpr const char* kConfigApkPath =  // Opened and passed to startVm() as appConfig.apk.
+        "/apex/com.android.compos/app/CompOSPayloadApp/CompOSPayloadApp.apk";
+constexpr const char* kConfigApkIdsigPath =  // Opened and passed to startVm() as appConfig.idsig.
+        "/apex/com.android.compos/etc/CompOSPayloadApp.apk.idsig";
+
+// This is a path inside the APK; it is passed to startVm() as appConfig.configPath.
+constexpr const char* kConfigFilePath = "assets/key_service_vm_config.json";
 
 static bool writeBytesToFile(const std::vector<uint8_t>& bytes, const std::string& path) {
     std::string str(bytes.begin(), bytes.end());
@@ -66,11 +90,160 @@
 }
 
 static std::shared_ptr<ICompOsKeyService> getService(int cid) {
+    LOG(INFO) << "Connecting to cid " << cid;
     ndk::SpAIBinder binder(cid == 0 ? AServiceManager_getService("android.system.composkeyservice")
                                     : RpcClient(cid, kRpcPort));
     return ICompOsKeyService::fromBinder(binder);
 }
 
+namespace {
+// Receives VM lifecycle callbacks and lets another thread wait for the payload to start.
+class Callback : public BnVirtualMachineCallback {
+public:
+    ScopedAStatus onPayloadStarted(int32_t in_cid, const ScopedFileDescriptor&) override {
+        // TODO: Consider copying stdout (the ignored fd parameter) somewhere useful?
+        LOG(INFO) << "Payload started! cid = " << in_cid;
+        {
+            std::unique_lock lock(mMutex);
+            mStarted = true;
+        }
+        mCv.notify_all();
+        return ScopedAStatus::ok();
+    }
+
+    ScopedAStatus onDied(int32_t in_cid) override {
+        LOG(WARNING) << "VM died! cid = " << in_cid;
+        {
+            std::unique_lock lock(mMutex);
+            mDied = true;
+        }
+        mCv.notify_all();
+        return ScopedAStatus::ok();
+    }
+    // Waits up to 10 seconds; returns true only if the payload started and the VM has not died.
+    bool waitForStarted() {
+        std::unique_lock lock(mMutex);
+        return mCv.wait_for(lock, std::chrono::seconds(10), [this] { return mStarted || mDied; }) &&
+                !mDied;
+    }
+
+private:
+    std::mutex mMutex;  // Guards mStarted and mDied.
+    std::condition_variable mCv;  // Notified whenever mStarted or mDied is set.
+    bool mStarted = false;  // Explicitly initialised so waitForStarted() never reads garbage.
+    bool mDied = false;  // Explicitly initialised for the same reason.
+};
+
+// Command-line choice of target: the host, an existing VM (cid), or a VM started on demand.
+class TargetVm {
+public:
+    TargetVm(int cid, const std::string& logFile, const std::string& instanceImageFile)
+          : mCid(cid), mLogFile(logFile), mInstanceImageFile(instanceImageFile) {}
+
+    // Returns 0 if we are to connect to a local service, otherwise the CID of either an
+    // existing VM or a VM we have started, depending on the command line arguments.
+    Result<int> resolveCid() {
+        if (mInstanceImageFile.empty()) {
+            return mCid;
+        }
+        if (mCid != 0) {
+            return Error() << "Can't specify both cid and image file.";
+        }
+
+        // We need a thread pool to receive VM callbacks.
+        ABinderProcess_startThreadPool();
+
+        ndk::SpAIBinder binder(
+                AServiceManager_waitForService("android.system.virtualizationservice"));
+        auto service = IVirtualizationService::fromBinder(binder);
+        if (!service) {
+            return Error() << "Failed to connect to virtualization service.";
+        }
+
+        ScopedFileDescriptor logFd;
+        if (mLogFile.empty()) {
+            logFd.set(dup(STDOUT_FILENO));
+            if (logFd.get() == -1) {
+                return ErrnoError() << "dup() failed";  // ErrnoError appends ": <strerror>".
+            }
+        } else {
+            logFd.set(TEMP_FAILURE_RETRY(open(mLogFile.c_str(),
+                                              O_WRONLY | O_CREAT | O_TRUNC | O_CLOEXEC,
+                                              S_IRUSR | S_IWUSR)));
+            if (logFd.get() == -1) {
+                return ErrnoError() << "Failed to open " << mLogFile;
+            }
+        }
+
+        ScopedFileDescriptor apkFd(TEMP_FAILURE_RETRY(open(kConfigApkPath, O_RDONLY | O_CLOEXEC)));
+        if (apkFd.get() == -1) {
+            return ErrnoError() << "Failed to open config APK";
+        }
+
+        ScopedFileDescriptor idsigFd(
+                TEMP_FAILURE_RETRY(open(kConfigApkIdsigPath, O_RDONLY | O_CLOEXEC)));
+        if (idsigFd.get() == -1) {
+            return ErrnoError() << "Failed to open config APK signature";
+        }
+        // The instance image is a writable partition (see make-instance), so open it read-write.
+        ScopedFileDescriptor instanceFd(
+                TEMP_FAILURE_RETRY(open(mInstanceImageFile.c_str(), O_RDWR | O_CLOEXEC)));
+        if (instanceFd.get() == -1) {
+            return ErrnoError() << "Failed to open instance image file";
+        }
+
+        auto config = VirtualMachineConfig::make<VirtualMachineConfig::Tag::appConfig>();
+        auto& appConfig = config.get<VirtualMachineConfig::Tag::appConfig>();
+        appConfig.apk = std::move(apkFd);
+        appConfig.idsig = std::move(idsigFd);
+        appConfig.instanceImage = std::move(instanceFd);
+        appConfig.configPath = kConfigFilePath;
+        appConfig.debug = false; // Don't disable selinux in VM
+        appConfig.memoryMib = 0; // Use default
+
+        LOG(INFO) << "Starting VM";
+        auto status = service->startVm(config, logFd, &mVm);
+        if (!status.isOk()) {
+            return Error() << "Failed to start VM: " << status.getDescription();
+        }
+
+        int32_t cid;
+        status = mVm->getCid(&cid);
+        if (!status.isOk()) {
+            return Error() << "Failed to get VM cid: " << status.getDescription();
+        }
+
+        LOG(INFO) << "Started VM with cid = " << cid;
+
+        // We need to use this rather than std::make_shared to make sure the
+        // embedded weak_ptr is initialised.
+        mCallback = SharedRefBase::make<Callback>();
+
+        status = mVm->registerCallback(mCallback);
+        if (!status.isOk()) {
+            return Error() << "Failed to register VM callback: " << status.getDescription();
+        }
+
+        if (!mCallback->waitForStarted()) {
+            return Error() << "VM Payload failed to start";
+        }
+
+        // TODO(b/194677789): Implement a polling loop or find a more reliable
+        // way to detect when the service is listening.
+        sleep(3);
+
+        return cid;
+    }
+
+private:
+    const int mCid;  // CID given with --cid, or 0 if none was given.
+    const std::string mLogFile;  // VM log destination; empty means a dup of stdout is used.
+    const std::string mInstanceImageFile;  // Non-empty means resolveCid() starts a VM from it.
+    std::shared_ptr<Callback> mCallback;  // Set by resolveCid() once a VM has been started.
+    std::shared_ptr<IVirtualMachine> mVm;  // Filled in by startVm() in resolveCid().
+};
+} // namespace
+
 static Result<std::vector<uint8_t>> extractRsaPublicKey(
         const std::vector<uint8_t>& der_certificate) {
     auto data = der_certificate.data();
@@ -102,9 +275,13 @@
     return result;
 }
 
-static Result<void> generate(int cid, const std::string& blob_file,
+static Result<void> generate(TargetVm& vm, const std::string& blob_file,
                              const std::string& public_key_file) {
-    auto service = getService(cid);
+    auto cid = vm.resolveCid();
+    if (!cid.ok()) {
+        return cid.error();
+    }
+    auto service = getService(*cid);
     if (!service) {
         return Error() << "No service";
     }
@@ -130,9 +307,13 @@
     return {};
 }
 
-static Result<bool> verify(int cid, const std::string& blob_file,
+static Result<bool> verify(TargetVm& vm, const std::string& blob_file,
                            const std::string& public_key_file) {
-    auto service = getService(cid);
+    auto cid = vm.resolveCid();
+    if (!cid.ok()) {
+        return cid.error();
+    }
+    auto service = getService(*cid);
     if (!service) {
         return Error() << "No service";
     }
@@ -223,9 +404,13 @@
     return {};
 }
 
-static Result<void> sign(int cid, const std::string& blob_file,
+static Result<void> sign(TargetVm& vm, const std::string& blob_file,
                          const std::vector<std::string>& files) {
-    auto service = getService(cid);
+    auto cid = vm.resolveCid();
+    if (!cid.ok()) {
+        return cid.error();
+    }
+    auto service = getService(*cid);
     if (!service) {
         return Error() << "No service";
     }
@@ -244,30 +429,63 @@
     return {};
 }
 
+// Creates or truncates image_path and has virtualizationservice initialise it as a partition.
+static Result<void> makeInstanceImage(const std::string& image_path) {
+    ndk::SpAIBinder binder(AServiceManager_waitForService("android.system.virtualizationservice"));
+    const auto virtService = IVirtualizationService::fromBinder(binder);
+    if (virtService == nullptr) {
+        return Error() << "Failed to connect to virtualization service.";
+    }
+    ScopedFileDescriptor imageFd(TEMP_FAILURE_RETRY(
+            open(image_path.c_str(), O_CREAT | O_RDWR | O_TRUNC | O_CLOEXEC, S_IRUSR | S_IWUSR)));
+    if (imageFd.get() < 0) {
+        return ErrnoError() << "Failed to create image file";
+    }
+
+    const auto initStatus = virtService->initializeWritablePartition(imageFd, 10 * 1024 * 1024);
+    if (initStatus.isOk()) {
+        return {};
+    }
+    return Error() << "Failed to initialize partition: " << initStatus.getDescription();
+}
+
 int main(int argc, char** argv) {
     // Restrict access to our outputs to the current user.
     umask(077);
 
     int cid = 0;
-    if (argc >= 3 && argv[1] == "--cid"sv) {
-        cid = atoi(argv[2]);
-        if (cid == 0) {
-            std::cerr << "Invalid cid\n";
-            return 1;
+    std::string imageFile;
+    std::string logFile;
+
+    while (argc >= 3) {
+        if (argv[1] == "--cid"sv) {
+            cid = atoi(argv[2]);
+            if (cid == 0) {
+                std::cerr << "Invalid cid\n";
+                return 1;
+            }
+        } else if (argv[1] == "--start"sv) {
+            imageFile = argv[2];
+        } else if (argv[1] == "--log"sv) {
+            logFile = argv[2];
+        } else {
+            break;
         }
         argc -= 2;
         argv += 2;
     }
 
+    TargetVm vm(cid, logFile, imageFile);
+
     if (argc == 4 && argv[1] == "generate"sv) {
-        auto result = generate(cid, argv[2], argv[3]);
+        auto result = generate(vm, argv[2], argv[3]);
         if (result.ok()) {
             return 0;
         } else {
             std::cerr << result.error() << '\n';
         }
     } else if (argc == 4 && argv[1] == "verify"sv) {
-        auto result = verify(cid, argv[2], argv[3]);
+        auto result = verify(vm, argv[2], argv[3]);
         if (result.ok()) {
             if (result.value()) {
                 std::cerr << "Key files are valid.\n";
@@ -280,23 +498,35 @@
         }
     } else if (argc >= 4 && argv[1] == "sign"sv) {
         const std::vector<std::string> files{&argv[3], &argv[argc]};
-        auto result = sign(cid, argv[2], files);
+        auto result = sign(vm, argv[2], files);
         if (result.ok()) {
             std::cerr << "All signatures generated.\n";
             return 0;
         } else {
             std::cerr << result.error() << '\n';
         }
+    } else if (argc == 3 && argv[1] == "make-instance"sv) {
+        auto result = makeInstanceImage(argv[2]);
+        if (result.ok()) {
+            return 0;
+        } else {
+            std::cerr << result.error() << '\n';
+        }
     } else {
-        std::cerr << "Usage: compos_key_cmd [--cid <cid>] generate|verify|sign\n"
-                  << "  generate <blob file> <public key file> Generate new key pair and "
-                     "write\n"
+        std::cerr << "Usage: compos_key_cmd [OPTIONS] generate|verify|sign|make-instance\n"
+                  << "  generate <blob file> <public key file> Generate new key pair and write\n"
                   << "    the private key blob and public key to the specified files.\n "
                   << "  verify <blob file> <public key file> Verify that the content of the\n"
                   << "    specified private key blob and public key files are valid.\n "
                   << "  sign <blob file> <files to be signed> Generate signatures for one or\n"
                   << "    more files using the supplied private key blob.\n"
-                  << "Specify --cid to connect to a VM rather than the host\n";
+                  << "  make-instance <image file> Create an empty instance image file for a VM.\n"
+                  << "\n"
+                  << "OPTIONS: --log <log file> (--cid <cid> | --start <image file>)\n"
+                  << "  Specify --log to write VM log to a file rather than stdout.\n"
+                  << "  Specify --cid to connect to a VM rather than the host.\n"
+                  << "  Specify --start to start a VM from the given instance image file and\n "
+                  << "    connect to that.\n";
     }
     return 1;
 }
diff --git a/compos/tests/java/android/compos/test/ComposTestCase.java b/compos/tests/java/android/compos/test/ComposTestCase.java
index 4471e63..da23919 100644
--- a/compos/tests/java/android/compos/test/ComposTestCase.java
+++ b/compos/tests/java/android/compos/test/ComposTestCase.java
@@ -30,11 +30,13 @@
 
 import org.junit.After;
 import org.junit.Before;
+import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
 @RootPermissionTest
 @RunWith(DeviceJUnit4ClassRunner.class)
+@Ignore("b/194974010")
 public final class ComposTestCase extends VirtualizationTestCaseBase {
 
     /** Path to odrefresh on Microdroid */
diff --git a/tests/hostside/helper/java/android/virt/test/VirtualizationTestCaseBase.java b/tests/hostside/helper/java/android/virt/test/VirtualizationTestCaseBase.java
index a9e5040..0e07c60 100644
--- a/tests/hostside/helper/java/android/virt/test/VirtualizationTestCaseBase.java
+++ b/tests/hostside/helper/java/android/virt/test/VirtualizationTestCaseBase.java
@@ -230,13 +230,17 @@
             throws DeviceNotAvailableException {
         CommandRunner android = new CommandRunner(androidDevice);
 
-        // Close the connection before shutting the VM down. Otherwise, b/192660485.
-        tryRunOnHost("adb", "disconnect", MICRODROID_SERIAL);
-        final String serial = androidDevice.getSerialNumber();
-        tryRunOnHost("adb", "-s", serial, "forward", "--remove", "tcp:" + TEST_VM_ADB_PORT);
-
         // Shutdown the VM
         android.run(VIRT_APEX + "bin/vm", "stop", cid);
+
+        // TODO(b/192660485): Figure out why shutting down the VM temporarily disconnects adb on
+        // cuttlefish. Without this wait, later `runOnAndroid/skipIfFail` calls fail because of
+        // the lost connection, causing assumption failures in the remaining tests.
+        try {
+            Thread.sleep(1000);
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+        }
     }
 
     public static void rootMicrodroid() throws DeviceNotAvailableException {
diff --git a/virtualizationservice/aidl/Android.bp b/virtualizationservice/aidl/Android.bp
index f7cb339..7d85bd3 100644
--- a/virtualizationservice/aidl/Android.bp
+++ b/virtualizationservice/aidl/Android.bp
@@ -16,6 +16,11 @@
         cpp: {
             enabled: true,
         },
+        ndk: {
+            apex_available: [
+                "com.android.compos",
+            ],
+        },
         rust: {
             enabled: true,
             apex_available: ["com.android.virt"],