Add a shared helper for device tests to remove duplicated code

Test: atest MicrodroidTests MicrodroidBenchmarks
Change-Id: I0983c0d6dd21a9eab1be2bf8ee78823e3614f646
diff --git a/tests/benchmark/Android.bp b/tests/benchmark/Android.bp
index f333d03..e6d5b83 100644
--- a/tests/benchmark/Android.bp
+++ b/tests/benchmark/Android.bp
@@ -9,6 +9,7 @@
     ],
     srcs: ["src/java/**/*.java"],
     static_libs: [
+        "MicroroidDeviceTestHelper",
         "androidx.test.runner",
         "androidx.test.ext.junit",
         "truth-prebuilt",
diff --git a/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java b/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java
index 864d2d5..4964216 100644
--- a/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java
+++ b/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java
@@ -20,25 +20,14 @@
 import static com.google.common.truth.Truth.assertThat;
 import static com.google.common.truth.TruthJUnit.assume;
 
-import static org.junit.Assume.assumeNoException;
-
 import android.app.Instrumentation;
-import android.content.Context;
 import android.os.Bundle;
-import android.os.ParcelFileDescriptor;
 import android.os.SystemProperties;
-import android.sysprop.HypervisorProperties;
-import android.system.virtualizationservice.DeathReason;
-import android.system.virtualmachine.VirtualMachine;
-import android.system.virtualmachine.VirtualMachineCallback;
 import android.system.virtualmachine.VirtualMachineConfig;
 import android.system.virtualmachine.VirtualMachineConfig.DebugLevel;
 import android.system.virtualmachine.VirtualMachineException;
-import android.system.virtualmachine.VirtualMachineManager;
-import android.util.Log;
 
-import androidx.annotation.CallSuper;
-import androidx.test.core.app.ApplicationProvider;
+import com.android.microdroid.test.MicrodroidDeviceTestBase;
 
 import org.junit.After;
 import org.junit.Before;
@@ -48,17 +37,10 @@
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
-import java.io.BufferedReader;
 import java.io.IOException;
-import java.io.InputStream;
-import java.io.InputStreamReader;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.TimeUnit;
 
 @RunWith(Parameterized.class)
-public class MicrodroidBenchmarks {
+public class MicrodroidBenchmarks extends MicrodroidDeviceTestBase {
     private static final String TAG = "MicrodroidBenchmarks";
 
     @Rule public Timeout globalTimeout = Timeout.seconds(300);
@@ -74,41 +56,6 @@
                         || productName.startsWith("cf_arm"));
     }
 
-    /** Copy output from the VM to logcat. This is helpful when things go wrong. */
-    private static void logVmOutput(InputStream vmOutputStream, String name) {
-        new Thread(
-                () -> {
-                    try {
-                        BufferedReader reader =
-                                new BufferedReader(new InputStreamReader(vmOutputStream));
-                        String line;
-                        while ((line = reader.readLine()) != null
-                                && !Thread.interrupted()) {
-                            Log.i(TAG, name + ": " + line);
-                        }
-                    } catch (Exception e) {
-                        Log.w(TAG, name, e);
-                    }
-                }).start();
-    }
-
-    private static class Inner {
-        public boolean mProtectedVm;
-        public Context mContext;
-        public VirtualMachineManager mVmm;
-        public VirtualMachine mVm;
-
-        Inner(boolean protectedVm) {
-            mProtectedVm = protectedVm;
-        }
-
-        /** Create a new VirtualMachineConfig.Builder with the parameterized protection mode. */
-        public VirtualMachineConfig.Builder newVmConfigBuilder(String payloadConfigPath) {
-            return new VirtualMachineConfig.Builder(mContext, payloadConfigPath)
-                    .protectedVm(mProtectedVm);
-        }
-    }
-
     @Parameterized.Parameters(name = "protectedVm={0}")
     public static Object[] protectedVmConfigs() {
         return new Object[] {false, true};
@@ -116,128 +63,17 @@
 
     @Parameterized.Parameter public boolean mProtectedVm;
 
-    private boolean mPkvmSupported = false;
-    private Inner mInner;
-
     private Instrumentation mInstrumentation;
 
     @Before
     public void setup() {
-        // In case when the virt APEX doesn't exist on the device, classes in the
-        // android.system.virtualmachine package can't be loaded. Therefore, before using the
-        // classes, check the existence of a class in the package and skip this test if not exist.
-        try {
-            Class.forName("android.system.virtualmachine.VirtualMachineManager");
-            mPkvmSupported = true;
-        } catch (ClassNotFoundException e) {
-            assumeNoException(e);
-            return;
-        }
-        if (mProtectedVm) {
-            assume().withMessage("Skip where protected VMs aren't support")
-                    .that(HypervisorProperties.hypervisor_protected_vm_supported().orElse(false))
-                    .isTrue();
-        } else {
-            assume().withMessage("Skip where VMs aren't support")
-                    .that(HypervisorProperties.hypervisor_vm_supported().orElse(false))
-                    .isTrue();
-        }
-        mInner = new Inner(mProtectedVm);
-        mInner.mContext = ApplicationProvider.getApplicationContext();
-        mInner.mVmm = VirtualMachineManager.getInstance(mInner.mContext);
+        prepareTestSetup(mProtectedVm); // skips the test when this VM mode is unsupported
         mInstrumentation = getInstrumentation();
     }
 
     @After
     public void cleanup() throws VirtualMachineException {
-        if (!mPkvmSupported) {
-            return;
-        }
-        if (mInner == null) {
-            return;
-        }
-        if (mInner.mVm == null) {
-            return;
-        }
-        mInner.mVm.stop();
-        mInner.mVm.delete();
-    }
-
-    private abstract static class VmEventListener implements VirtualMachineCallback {
-        private ExecutorService mExecutorService = Executors.newSingleThreadExecutor();
-
-        void runToFinish(VirtualMachine vm) throws VirtualMachineException, InterruptedException {
-            vm.setCallback(mExecutorService, this);
-            vm.run();
-            logVmOutput(vm.getConsoleOutputStream(), "Console");
-            logVmOutput(vm.getLogOutputStream(), "Log");
-            mExecutorService.awaitTermination(300, TimeUnit.SECONDS);
-        }
-
-        void forceStop(VirtualMachine vm) {
-            try {
-                vm.clearCallback();
-                vm.stop();
-                mExecutorService.shutdown();
-            } catch (VirtualMachineException e) {
-                throw new RuntimeException(e);
-            }
-        }
-
-        @Override
-        public void onPayloadStarted(VirtualMachine vm, ParcelFileDescriptor stream) {}
-
-        @Override
-        public void onPayloadReady(VirtualMachine vm) {}
-
-        @Override
-        public void onPayloadFinished(VirtualMachine vm, int exitCode) {}
-
-        @Override
-        public void onError(VirtualMachine vm, int errorCode, String message) {}
-
-        @Override
-        @CallSuper
-        public void onDied(VirtualMachine vm, @DeathReason int reason) {
-            mExecutorService.shutdown();
-        }
-
-        @Override
-        public void onRamdump(VirtualMachine vm, ParcelFileDescriptor ramdump) {}
-    }
-
-    private static class BootResult {
-        public final boolean payloadStarted;
-        public final int deathReason;
-
-        BootResult(boolean payloadStarted, int deathReason) {
-            this.payloadStarted = payloadStarted;
-            this.deathReason = deathReason;
-        }
-    }
-
-    private BootResult tryBootVm(String vmName)
-            throws VirtualMachineException, InterruptedException {
-        mInner.mVm = mInner.mVmm.get(vmName); // re-load the vm before running tests
-        final CompletableFuture<Boolean> payloadStarted = new CompletableFuture<>();
-        final CompletableFuture<Integer> deathReason = new CompletableFuture<>();
-        VmEventListener listener =
-                new VmEventListener() {
-                    @Override
-                    public void onPayloadStarted(VirtualMachine vm, ParcelFileDescriptor stream) {
-                        payloadStarted.complete(true);
-                        forceStop(vm);
-                    }
-
-                    @Override
-                    public void onDied(VirtualMachine vm, int reason) {
-                        deathReason.complete(reason);
-                        super.onDied(vm, reason);
-                    }
-                };
-        listener.runToFinish(mInner.mVm);
-        return new BootResult(
-                payloadStarted.getNow(false), deathReason.getNow(DeathReason.INFRASTRUCTURE_ERROR));
+        cleanupTestSetup(); // shared per-test teardown from MicrodroidDeviceTestBase
     }
 
     private boolean canBootMicrodroidWithMemory(int mem)
@@ -246,18 +82,13 @@
 
         // returns true if succeeded at least once.
         for (int i = 0; i < trialCount; i++) {
-            VirtualMachine existingVm = mInner.mVmm.get("test_vm_minimum_memory");
-            if (existingVm != null) {
-                existingVm.delete();
-            }
-
             VirtualMachineConfig.Builder builder =
                     mInner.newVmConfigBuilder("assets/vm_config.json");
             VirtualMachineConfig normalConfig =
                     builder.debugLevel(DebugLevel.FULL).memoryMib(mem).build();
-            mInner.mVmm.create("test_vm_minimum_memory", normalConfig);
+            mInner.forceCreateNewVirtualMachine("test_vm_minimum_memory", normalConfig);
 
-            if (tryBootVm("test_vm_minimum_memory").payloadStarted) return true;
+            if (tryBootVm(TAG, "test_vm_minimum_memory").payloadStarted) return true;
         }
 
         return false;
diff --git a/tests/helper/Android.bp b/tests/helper/Android.bp
new file mode 100644
index 0000000..679fbfe
--- /dev/null
+++ b/tests/helper/Android.bp
@@ -0,0 +1,15 @@
+package {
+    default_applicable_licenses: ["Android-Apache-2.0"],
+}
+
+java_library_static { // Shared base class (MicrodroidDeviceTestBase) for Microdroid device tests.
+    name: "MicroroidDeviceTestHelper", // NOTE(review): likely a typo of "Microdroid"; rename with all users.
+    srcs: ["src/java/**/*.java"],
+    static_libs: [
+        "androidx.test.runner",
+        "androidx.test.ext.junit",
+        "truth-prebuilt",
+    ],
+    libs: ["android.system.virtualmachine"], // compiled against, not bundled; the virt APEX may be absent
+    platform_apis: true,
+}
diff --git a/tests/helper/src/java/com/android/microdroid/test/MicrodroidDeviceTestBase.java b/tests/helper/src/java/com/android/microdroid/test/MicrodroidDeviceTestBase.java
new file mode 100644
index 0000000..b4c814b
--- /dev/null
+++ b/tests/helper/src/java/com/android/microdroid/test/MicrodroidDeviceTestBase.java
@@ -0,0 +1,274 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.android.microdroid.test;
+
+import static com.google.common.truth.TruthJUnit.assume;
+
+import static org.junit.Assume.assumeNoException;
+
+import android.content.Context;
+import android.os.ParcelFileDescriptor;
+import android.sysprop.HypervisorProperties;
+import android.system.virtualizationservice.DeathReason;
+import android.system.virtualmachine.VirtualMachine;
+import android.system.virtualmachine.VirtualMachineCallback;
+import android.system.virtualmachine.VirtualMachineConfig;
+import android.system.virtualmachine.VirtualMachineException;
+import android.system.virtualmachine.VirtualMachineManager;
+import android.util.Log;
+
+import androidx.annotation.CallSuper;
+import androidx.test.core.app.ApplicationProvider;
+
+import java.io.BufferedReader;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.charset.StandardCharsets;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Base class for Microdroid device tests. It skips tests on devices that cannot run the
+ * requested kind of VM and provides helpers to create, boot, and observe VMs.
+ */
+public abstract class MicrodroidDeviceTestBase {
+    /** Seconds {@link VmEventListener#runToFinish} waits for a VM before giving up. */
+    private static final long VM_RUN_TIMEOUT_SECONDS = 300;
+
+    /**
+     * Copies output from the VM to logcat on a background thread. This is helpful when things go
+     * wrong.
+     *
+     * @param tag logcat tag used for every logged line
+     * @param vmOutputStream stream to drain until EOF
+     * @param name prefix identifying the stream in each logged line
+     */
+    protected static void logVmOutput(String tag, InputStream vmOutputStream, String name) {
+        new Thread(
+                () -> {
+                    // The reader is intentionally left open: callers may pass a stream that wraps
+                    // a file descriptor owned elsewhere (e.g. by a ParcelFileDescriptor).
+                    try {
+                        BufferedReader reader =
+                                new BufferedReader(
+                                        new InputStreamReader(
+                                                vmOutputStream, StandardCharsets.UTF_8));
+                        String line;
+                        while ((line = reader.readLine()) != null
+                                && !Thread.interrupted()) {
+                            Log.i(tag, name + ": " + line);
+                        }
+                    } catch (Exception e) {
+                        Log.w(tag, name, e);
+                    }
+                }).start();
+    }
+
+    // TODO(b/220920264): remove Inner class; this is a hack to hide virt APEX types
+    protected static class Inner {
+        private final boolean mProtectedVm;
+        private final Context mContext;
+        private final VirtualMachineManager mVmm;
+
+        public Inner(Context context, boolean protectedVm, VirtualMachineManager vmm) {
+            mProtectedVm = protectedVm;
+            mVmm = vmm;
+            mContext = context;
+        }
+
+        /** Returns the manager used to create and look up VMs. */
+        public VirtualMachineManager getVirtualMachineManager() {
+            return mVmm;
+        }
+
+        /** Returns the context that VM configs are built with. */
+        public Context getContext() {
+            return mContext;
+        }
+
+        /** Create a new VirtualMachineConfig.Builder with the parameterized protection mode. */
+        public VirtualMachineConfig.Builder newVmConfigBuilder(String payloadConfigPath) {
+            return new VirtualMachineConfig.Builder(mContext, payloadConfigPath)
+                    .protectedVm(mProtectedVm);
+        }
+
+        /**
+         * Creates a new virtual machine, potentially removing an existing virtual machine with
+         * given name.
+         */
+        public VirtualMachine forceCreateNewVirtualMachine(String name, VirtualMachineConfig config)
+                throws VirtualMachineException {
+            VirtualMachine existingVm = mVmm.get(name);
+            if (existingVm != null) {
+                existingVm.delete();
+            }
+            return mVmm.create(name, config);
+        }
+    }
+
+    /** Per-test VM helpers; assigned by {@link #prepareTestSetup} once VM support is confirmed. */
+    protected Inner mInner;
+
+    /** Returns the application context the test VMs are created with. */
+    protected Context getContext() {
+        return mInner.getContext();
+    }
+
+    /**
+     * Initializes {@link #mInner}, or skips the current test (via JUnit assumptions) when the
+     * virt APEX is missing or the device cannot run VMs of the requested protection mode.
+     *
+     * @param protectedVm whether the test runs protected VMs
+     */
+    public void prepareTestSetup(boolean protectedVm) {
+        // In case when the virt APEX doesn't exist on the device, classes in the
+        // android.system.virtualmachine package can't be loaded. Therefore, before using the
+        // classes, check the existence of a class in the package and skip this test if not exist.
+        try {
+            Class.forName("android.system.virtualmachine.VirtualMachineManager");
+        } catch (ClassNotFoundException e) {
+            assumeNoException(e);
+            return;
+        }
+        if (protectedVm) {
+            assume().withMessage("Skip where protected VMs aren't supported")
+                    .that(HypervisorProperties.hypervisor_protected_vm_supported().orElse(false))
+                    .isTrue();
+        } else {
+            assume().withMessage("Skip where VMs aren't supported")
+                    .that(HypervisorProperties.hypervisor_vm_supported().orElse(false))
+                    .isTrue();
+        }
+        Context context = ApplicationProvider.getApplicationContext();
+        mInner = new Inner(context, protectedVm, VirtualMachineManager.getInstance(context));
+    }
+
+    /**
+     * Tears down per-test state. Currently a no-op: tests create their VMs with
+     * {@link Inner#forceCreateNewVirtualMachine}, which replaces any VM an earlier test left
+     * behind under the same name.
+     */
+    public void cleanupTestSetup() throws VirtualMachineException {}
+
+    /**
+     * Base VM callback that runs a VM and blocks until the listener is done with it. Overrides of
+     * {@link #onDied} must call {@code super.onDied} so that {@link #runToFinish} returns promptly.
+     */
+    protected abstract static class VmEventListener implements VirtualMachineCallback {
+        private final ExecutorService mExecutorService = Executors.newSingleThreadExecutor();
+
+        /**
+         * Starts {@code vm}, copies its console and log output to logcat under {@code logTag},
+         * and waits up to {@code VM_RUN_TIMEOUT_SECONDS} for {@link #onDied} or
+         * {@link #forceStop} to shut down the callback executor.
+         */
+        public void runToFinish(String logTag, VirtualMachine vm)
+                throws VirtualMachineException, InterruptedException {
+            vm.setCallback(mExecutorService, this);
+            vm.run();
+            logVmOutput(logTag, vm.getConsoleOutputStream(), "Console");
+            logVmOutput(logTag, vm.getLogOutputStream(), "Log");
+            if (!mExecutorService.awaitTermination(VM_RUN_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
+                Log.w(logTag, "VM did not finish within " + VM_RUN_TIMEOUT_SECONDS + " seconds");
+            }
+        }
+
+        /**
+         * Detaches this listener from {@code vm}, stops it, and lets {@link #runToFinish} return.
+         * Failures are rethrown unchecked so this can be called from callback methods.
+         */
+        public void forceStop(VirtualMachine vm) {
+            try {
+                vm.clearCallback();
+                vm.stop();
+                mExecutorService.shutdown();
+            } catch (VirtualMachineException e) {
+                throw new RuntimeException(e);
+            }
+        }
+
+        @Override
+        public void onPayloadStarted(VirtualMachine vm, ParcelFileDescriptor stream) {}
+
+        @Override
+        public void onPayloadReady(VirtualMachine vm) {}
+
+        @Override
+        public void onPayloadFinished(VirtualMachine vm, int exitCode) {}
+
+        @Override
+        public void onError(VirtualMachine vm, int errorCode, String message) {}
+
+        @Override
+        @CallSuper
+        public void onDied(VirtualMachine vm, @DeathReason int reason) {
+            mExecutorService.shutdown();
+        }
+
+        @Override
+        public void onRamdump(VirtualMachine vm, ParcelFileDescriptor ramdump) {}
+    }
+
+    /** Outcome of {@link MicrodroidDeviceTestBase#tryBootVm}. */
+    public static class BootResult {
+        /** Whether the payload reported that it started. */
+        public final boolean payloadStarted;
+        /** Reason passed to onDied, or {@code DeathReason.INFRASTRUCTURE_ERROR} if none was. */
+        public final int deathReason;
+
+        BootResult(boolean payloadStarted, int deathReason) {
+            this.payloadStarted = payloadStarted;
+            this.deathReason = deathReason;
+        }
+    }
+
+    /**
+     * Boots the existing VM named {@code vmName} and stops it as soon as its payload starts.
+     *
+     * @param logTag logcat tag for the VM's console and log output
+     * @param vmName name of an existing VM, e.g. one made by forceCreateNewVirtualMachine
+     * @return whether the payload started, and the death reason if one was reported
+     * @throws IllegalStateException if no VM named {@code vmName} exists
+     */
+    public BootResult tryBootVm(String logTag, String vmName)
+            throws VirtualMachineException, InterruptedException {
+        VirtualMachine vm = mInner.getVirtualMachineManager().get(vmName);
+        if (vm == null) {
+            throw new IllegalStateException("No VM named " + vmName + "; create it first");
+        }
+        final CompletableFuture<Boolean> payloadStarted = new CompletableFuture<>();
+        final CompletableFuture<Integer> deathReason = new CompletableFuture<>();
+        VmEventListener listener =
+                new VmEventListener() {
+                    @Override
+                    public void onPayloadStarted(VirtualMachine vm, ParcelFileDescriptor stream) {
+                        payloadStarted.complete(true);
+                        forceStop(vm);
+                    }
+
+                    @Override
+                    public void onDied(VirtualMachine vm, int reason) {
+                        deathReason.complete(reason);
+                        super.onDied(vm, reason);
+                    }
+                };
+        listener.runToFinish(logTag, vm);
+        return new BootResult(
+                payloadStarted.getNow(false), deathReason.getNow(DeathReason.INFRASTRUCTURE_ERROR));
+    }
+}
diff --git a/tests/testapk/Android.bp b/tests/testapk/Android.bp
index b3b0808..d468d76 100644
--- a/tests/testapk/Android.bp
+++ b/tests/testapk/Android.bp
@@ -10,6 +10,7 @@
     ],
     srcs: ["src/java/**/*.java"],
     static_libs: [
+        "MicroroidDeviceTestHelper",
         "androidx.test.runner",
         "androidx.test.ext.junit",
         "authfs_test_apk_assets",
diff --git a/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java b/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java
index 3a874c4..e7e4647 100644
--- a/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java
+++ b/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java
@@ -18,27 +18,18 @@
 import static com.google.common.truth.Truth.assertThat;
 import static com.google.common.truth.TruthJUnit.assume;
 
-import static org.junit.Assume.assumeNoException;
-
 import static java.nio.file.StandardCopyOption.REPLACE_EXISTING;
 
-import android.content.Context;
 import android.os.Build;
 import android.os.ParcelFileDescriptor;
 import android.os.SystemProperties;
-import android.sysprop.HypervisorProperties;
 import android.system.virtualizationservice.DeathReason;
 import android.system.virtualmachine.VirtualMachine;
-import android.system.virtualmachine.VirtualMachineCallback;
 import android.system.virtualmachine.VirtualMachineConfig;
 import android.system.virtualmachine.VirtualMachineConfig.DebugLevel;
 import android.system.virtualmachine.VirtualMachineException;
-import android.system.virtualmachine.VirtualMachineManager;
 import android.util.Log;
 
-import androidx.annotation.CallSuper;
-import androidx.test.core.app.ApplicationProvider;
-
 import com.android.microdroid.testservice.ITestService;
 
 import org.junit.After;
@@ -49,22 +40,16 @@
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
-import java.io.BufferedReader;
 import java.io.ByteArrayInputStream;
 import java.io.File;
 import java.io.FileInputStream;
 import java.io.IOException;
-import java.io.InputStream;
-import java.io.InputStreamReader;
 import java.io.RandomAccessFile;
 import java.nio.file.Files;
 import java.util.List;
 import java.util.OptionalLong;
 import java.util.UUID;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.TimeUnit;
 
 import co.nstant.in.cbor.CborDecoder;
 import co.nstant.in.cbor.CborException;
@@ -73,140 +58,28 @@
 import co.nstant.in.cbor.model.MajorType;
 
 @RunWith(Parameterized.class)
-public class MicrodroidTests {
+public class MicrodroidTests extends MicrodroidDeviceTestBase {
     private static final String TAG = "MicrodroidTests";
 
     @Rule public Timeout globalTimeout = Timeout.seconds(300);
 
     private static final String KERNEL_VERSION = SystemProperties.get("ro.kernel.version");
 
-    /** Copy output from the VM to logcat. This is helpful when things go wrong. */
-    private static void logVmOutput(InputStream vmOutputStream, String name) {
-        new Thread(() -> {
-            try {
-                BufferedReader reader = new BufferedReader(new InputStreamReader(vmOutputStream));
-                String line;
-                while ((line = reader.readLine()) != null && !Thread.interrupted()) {
-                    Log.i(TAG, name + ": " + line);
-                }
-            } catch (Exception e) {
-                Log.w(TAG, name, e);
-            }
-        }).start();
-    }
-
-    private static class Inner {
-        public boolean mProtectedVm;
-        public Context mContext;
-        public VirtualMachineManager mVmm;
-        public VirtualMachine mVm;
-
-        Inner(boolean protectedVm) {
-            mProtectedVm = protectedVm;
-        }
-
-        /** Create a new VirtualMachineConfig.Builder with the parameterized protection mode. */
-        public VirtualMachineConfig.Builder newVmConfigBuilder(String payloadConfigPath) {
-            return new VirtualMachineConfig.Builder(mContext, payloadConfigPath)
-                            .protectedVm(mProtectedVm);
-        }
-    }
-
     @Parameterized.Parameters(name = "protectedVm={0}")
     public static Object[] protectedVmConfigs() {
         return new Object[] { false, true };
     }
 
-    @Parameterized.Parameter
-    public boolean mProtectedVm;
-
-    private boolean mPkvmSupported = false;
-    private Inner mInner;
+    @Parameterized.Parameter public boolean mProtectedVm;
 
     @Before
     public void setup() {
-        // In case when the virt APEX doesn't exist on the device, classes in the
-        // android.system.virtualmachine package can't be loaded. Therefore, before using the
-        // classes, check the existence of a class in the package and skip this test if not exist.
-        try {
-            Class.forName("android.system.virtualmachine.VirtualMachineManager");
-            mPkvmSupported = true;
-        } catch (ClassNotFoundException e) {
-            assumeNoException(e);
-            return;
-        }
-        if (mProtectedVm) {
-            assume()
-                .withMessage("Skip where protected VMs aren't support")
-                .that(HypervisorProperties.hypervisor_protected_vm_supported().orElse(false))
-                .isTrue();
-        } else {
-            assume()
-                .withMessage("Skip where VMs aren't support")
-                .that(HypervisorProperties.hypervisor_vm_supported().orElse(false))
-                .isTrue();
-        }
-        mInner = new Inner(mProtectedVm);
-        mInner.mContext = ApplicationProvider.getApplicationContext();
-        mInner.mVmm = VirtualMachineManager.getInstance(mInner.mContext);
+        prepareTestSetup(mProtectedVm); // skips the test when this VM mode is unsupported
     }
 
     @After
     public void cleanup() throws VirtualMachineException {
-        if (!mPkvmSupported) {
-            return;
-        }
-        if (mInner == null) {
-            return;
-        }
-        if (mInner.mVm == null) {
-            return;
-        }
-        mInner.mVm.stop();
-        mInner.mVm.delete();
-    }
-
-    private abstract static class VmEventListener implements VirtualMachineCallback {
-        private ExecutorService mExecutorService = Executors.newSingleThreadExecutor();
-
-        void runToFinish(VirtualMachine vm) throws VirtualMachineException, InterruptedException {
-            vm.setCallback(mExecutorService, this);
-            vm.run();
-            logVmOutput(vm.getConsoleOutputStream(), "Console");
-            logVmOutput(vm.getLogOutputStream(), "Log");
-            mExecutorService.awaitTermination(300, TimeUnit.SECONDS);
-        }
-
-        void forceStop(VirtualMachine vm) {
-            try {
-                vm.clearCallback();
-                vm.stop();
-                mExecutorService.shutdown();
-            } catch (VirtualMachineException e) {
-                throw new RuntimeException(e);
-            }
-        }
-
-        @Override
-        public void onPayloadStarted(VirtualMachine vm, ParcelFileDescriptor stream) {}
-
-        @Override
-        public void onPayloadReady(VirtualMachine vm) {}
-
-        @Override
-        public void onPayloadFinished(VirtualMachine vm, int exitCode) {}
-
-        @Override
-        public void onError(VirtualMachine vm, int errorCode, String message) {}
-
-        @Override
-        @CallSuper
-        public void onDied(VirtualMachine vm, @DeathReason int reason) {
-            mExecutorService.shutdown();
-        }
-
-        @Override
-        public void onRamdump(VirtualMachine vm, ParcelFileDescriptor ramdump) {}
+        cleanupTestSetup(); // shared per-test teardown from MicrodroidDeviceTestBase
     }
 
     private static final int MIN_MEM_ARM64 = 150;
@@ -233,8 +106,7 @@
             }
         }
         VirtualMachineConfig config = builder.build();
-
-        mInner.mVm = mInner.mVmm.getOrCreate("test_vm_extra_apk", config);
+        VirtualMachine vm = mInner.forceCreateNewVirtualMachine("test_vm_extra_apk", config);
 
         class TestResults {
             Exception mException;
@@ -276,10 +148,11 @@
                     public void onPayloadStarted(VirtualMachine vm, ParcelFileDescriptor stream) {
                         Log.i(TAG, "onPayloadStarted");
                         payloadStarted.complete(true);
-                        logVmOutput(new FileInputStream(stream.getFileDescriptor()), "Payload");
+                        logVmOutput(TAG, new FileInputStream(stream.getFileDescriptor()),
+                                "Payload");
                     }
                 };
-        listener.runToFinish(mInner.mVm);
+        listener.runToFinish(TAG, vm);
         assertThat(payloadStarted.getNow(false)).isTrue();
         assertThat(payloadReady.getNow(false)).isTrue();
         assertThat(testResults.mException).isNull();
@@ -299,7 +172,7 @@
 
         VirtualMachineConfig.Builder builder = mInner.newVmConfigBuilder("assets/vm_config.json");
         VirtualMachineConfig normalConfig = builder.debugLevel(DebugLevel.NONE).build();
-        mInner.mVm = mInner.mVmm.getOrCreate("test_vm", normalConfig);
+        VirtualMachine vm = mInner.forceCreateNewVirtualMachine("test_vm", normalConfig);
         VmEventListener listener =
                 new VmEventListener() {
                     @Override
@@ -307,19 +180,20 @@
                         forceStop(vm);
                     }
                 };
-        listener.runToFinish(mInner.mVm);
+        listener.runToFinish(TAG, vm);
 
         // Launch the same VM with different debug level. The Java API prohibits this (thankfully).
         // For testing, we do that by creating another VM with debug level, and copy the config file
         // from the new VM directory to the old VM directory.
         VirtualMachineConfig debugConfig = builder.debugLevel(DebugLevel.FULL).build();
-        VirtualMachine newVm  = mInner.mVmm.getOrCreate("test_debug_vm", debugConfig);
-        File vmRoot = new File(mInner.mContext.getFilesDir(), "vm");
+        VirtualMachine newVm = mInner.forceCreateNewVirtualMachine("test_debug_vm", debugConfig);
+        File vmRoot = new File(getContext().getFilesDir(), "vm");
         File newVmConfig = new File(new File(vmRoot, "test_debug_vm"), "config.xml");
         File oldVmConfig = new File(new File(vmRoot, "test_vm"), "config.xml");
         Files.copy(newVmConfig.toPath(), oldVmConfig.toPath(), REPLACE_EXISTING);
         newVm.delete();
-        mInner.mVm = mInner.mVmm.get("test_vm"); // re-load with the copied-in config file.
+        // re-load with the copied-in config file.
+        vm = mInner.getVirtualMachineManager().get("test_vm");
         final CompletableFuture<Boolean> payloadStarted = new CompletableFuture<>();
         listener =
                 new VmEventListener() {
@@ -329,7 +203,7 @@
                         forceStop(vm);
                     }
                 };
-        listener.runToFinish(mInner.mVm);
+        listener.runToFinish(TAG, vm);
         assertThat(payloadStarted.getNow(false)).isFalse();
     }
 
@@ -340,10 +214,7 @@
 
     private VmCdis launchVmAndGetCdis(String instanceName)
             throws VirtualMachineException, InterruptedException {
-        VirtualMachineConfig normalConfig = mInner.newVmConfigBuilder("assets/vm_config.json")
-                .debugLevel(DebugLevel.NONE)
-                .build();
-        mInner.mVm = mInner.mVmm.getOrCreate(instanceName, normalConfig);
+        VirtualMachine vm = mInner.getVirtualMachineManager().get(instanceName);
         final VmCdis vmCdis = new VmCdis();
         final CompletableFuture<Exception> exception = new CompletableFuture<>();
         VmEventListener listener =
@@ -361,7 +232,7 @@
                         }
                     }
                 };
-        listener.runToFinish(mInner.mVm);
+        listener.runToFinish(TAG, vm);
         assertThat(exception.getNow(null)).isNull();
         return vmCdis;
     }
@@ -374,6 +245,11 @@
             .that(KERNEL_VERSION)
             .isNotEqualTo("5.4");
 
+        VirtualMachineConfig normalConfig = mInner.newVmConfigBuilder("assets/vm_config.json")
+                .debugLevel(DebugLevel.NONE)
+                .build();
+        mInner.forceCreateNewVirtualMachine("test_vm_a", normalConfig);
+        mInner.forceCreateNewVirtualMachine("test_vm_b", normalConfig);
         VmCdis vm_a_cdis = launchVmAndGetCdis("test_vm_a");
         VmCdis vm_b_cdis = launchVmAndGetCdis("test_vm_b");
         assertThat(vm_a_cdis.cdiAttest).isNotNull();
@@ -393,6 +269,11 @@
             .that(KERNEL_VERSION)
             .isNotEqualTo("5.4");
 
+        VirtualMachineConfig normalConfig = mInner.newVmConfigBuilder("assets/vm_config.json")
+                .debugLevel(DebugLevel.NONE)
+                .build();
+        mInner.forceCreateNewVirtualMachine("test_vm", normalConfig);
+
         VmCdis first_boot_cdis = launchVmAndGetCdis("test_vm");
         VmCdis second_boot_cdis = launchVmAndGetCdis("test_vm");
         // The attestation CDI isn't specified to be stable, though it might be
@@ -412,7 +293,7 @@
         VirtualMachineConfig normalConfig = mInner.newVmConfigBuilder("assets/vm_config.json")
                 .debugLevel(DebugLevel.NONE)
                 .build();
-        mInner.mVm = mInner.mVmm.getOrCreate("bcc_vm", normalConfig);
+        VirtualMachine vm = mInner.forceCreateNewVirtualMachine("bcc_vm", normalConfig);
         final VmCdis vmCdis = new VmCdis();
         final CompletableFuture<byte[]> bcc = new CompletableFuture<>();
         final CompletableFuture<Exception> exception = new CompletableFuture<>();
@@ -430,7 +311,7 @@
                         }
                     }
                 };
-        listener.runToFinish(mInner.mVm);
+        listener.runToFinish(TAG, vm);
         byte[] bccBytes = bcc.getNow(null);
         assertThat(exception.getNow(null)).isNull();
         assertThat(bccBytes).isNotNull();
@@ -483,57 +364,19 @@
         file.writeByte(b ^ 1);
     }
 
-    private static class BootResult {
-        public final boolean payloadStarted;
-        public final int deathReason;
-
-        BootResult(boolean payloadStarted, int deathReason) {
-            this.payloadStarted = payloadStarted;
-            this.deathReason = deathReason;
-        }
-    }
-
-    private BootResult tryBootVm(String vmName)
-            throws VirtualMachineException, InterruptedException {
-        mInner.mVm = mInner.mVmm.get(vmName); // re-load the vm before running tests
-        final CompletableFuture<Boolean> payloadStarted = new CompletableFuture<>();
-        final CompletableFuture<Integer> deathReason = new CompletableFuture<>();
-        VmEventListener listener =
-                new VmEventListener() {
-                    @Override
-                    public void onPayloadStarted(VirtualMachine vm, ParcelFileDescriptor stream) {
-                        payloadStarted.complete(true);
-                        forceStop(vm);
-                    }
-                    @Override
-                    public void onDied(VirtualMachine vm, int reason) {
-                        deathReason.complete(reason);
-                        super.onDied(vm, reason);
-                    }
-                };
-        listener.runToFinish(mInner.mVm);
-        return new BootResult(
-                payloadStarted.getNow(false), deathReason.getNow(DeathReason.INFRASTRUCTURE_ERROR));
-    }
-
     private RandomAccessFile prepareInstanceImage(String vmName)
             throws VirtualMachineException, InterruptedException, IOException {
         VirtualMachineConfig config = mInner.newVmConfigBuilder("assets/vm_config.json")
                 .debugLevel(DebugLevel.FULL)
                 .build();
 
-        // Remove any existing VM so we can start from scratch
-        VirtualMachine oldVm = mInner.mVmm.getOrCreate(vmName, config);
-        oldVm.delete();
-        mInner.mVmm.getOrCreate(vmName, config);
+        mInner.forceCreateNewVirtualMachine(vmName, config);
+        assertThat(tryBootVm(TAG, vmName).payloadStarted).isTrue();
 
-        assertThat(tryBootVm(vmName).payloadStarted).isTrue();
-
-        File vmRoot = new File(mInner.mContext.getFilesDir(), "vm");
+        File vmRoot = new File(getContext().getFilesDir(), "vm");
         File vmDir = new File(vmRoot, vmName);
         File instanceImgPath = new File(vmDir, "instance.img");
         return new RandomAccessFile(instanceImgPath, "rw");
-
     }
 
     private void assertThatPartitionIsMissing(UUID partitionUuid)
@@ -551,7 +394,8 @@
         assertThat(offset.isPresent()).isTrue();
 
         flipBit(instanceFile, offset.getAsLong());
-        assertThat(tryBootVm("test_vm_integrity").payloadStarted).isFalse();
+
+        assertThat(tryBootVm(TAG, "test_vm_integrity").payloadStarted).isFalse();
     }
 
     @Test
@@ -596,17 +440,12 @@
     @Test
     public void bootFailsWhenConfigIsInvalid()
             throws VirtualMachineException, InterruptedException, IOException {
-        VirtualMachine existingVm = mInner.mVmm.get("test_vm_invalid_config");
-        if (existingVm != null) {
-            existingVm.delete();
-        }
-
         VirtualMachineConfig.Builder builder =
                 mInner.newVmConfigBuilder("assets/vm_config_no_task.json");
         VirtualMachineConfig normalConfig = builder.debugLevel(DebugLevel.NONE).build();
-        mInner.mVmm.create("test_vm_invalid_config", normalConfig);
+        mInner.forceCreateNewVirtualMachine("test_vm_invalid_config", normalConfig);
 
-        BootResult bootResult = tryBootVm("test_vm_invalid_config");
+        BootResult bootResult = tryBootVm(TAG, "test_vm_invalid_config");
         assertThat(bootResult.payloadStarted).isFalse();
         assertThat(bootResult.deathReason).isEqualTo(DeathReason.MICRODROID_INVALID_PAYLOAD_CONFIG);
     }