Add a helper for device tests to remove duplicated code

Test: atest MicrodroidTests MicrodroidBenchmarks
Change-Id: I0983c0d6dd21a9eab1be2bf8ee78823e3614f646
diff --git a/tests/testapk/Android.bp b/tests/testapk/Android.bp
index b3b0808..d468d76 100644
--- a/tests/testapk/Android.bp
+++ b/tests/testapk/Android.bp
@@ -10,6 +10,7 @@
     ],
     srcs: ["src/java/**/*.java"],
     static_libs: [
+        "MicroroidDeviceTestHelper",
         "androidx.test.runner",
         "androidx.test.ext.junit",
         "authfs_test_apk_assets",
diff --git a/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java b/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java
index 3a874c4..e7e4647 100644
--- a/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java
+++ b/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java
@@ -18,27 +18,18 @@
 import static com.google.common.truth.Truth.assertThat;
 import static com.google.common.truth.TruthJUnit.assume;
 
-import static org.junit.Assume.assumeNoException;
-
 import static java.nio.file.StandardCopyOption.REPLACE_EXISTING;
 
-import android.content.Context;
 import android.os.Build;
 import android.os.ParcelFileDescriptor;
 import android.os.SystemProperties;
-import android.sysprop.HypervisorProperties;
 import android.system.virtualizationservice.DeathReason;
 import android.system.virtualmachine.VirtualMachine;
-import android.system.virtualmachine.VirtualMachineCallback;
 import android.system.virtualmachine.VirtualMachineConfig;
 import android.system.virtualmachine.VirtualMachineConfig.DebugLevel;
 import android.system.virtualmachine.VirtualMachineException;
-import android.system.virtualmachine.VirtualMachineManager;
 import android.util.Log;
 
-import androidx.annotation.CallSuper;
-import androidx.test.core.app.ApplicationProvider;
-
 import com.android.microdroid.testservice.ITestService;
 
 import org.junit.After;
@@ -49,22 +40,16 @@
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
-import java.io.BufferedReader;
 import java.io.ByteArrayInputStream;
 import java.io.File;
 import java.io.FileInputStream;
 import java.io.IOException;
-import java.io.InputStream;
-import java.io.InputStreamReader;
 import java.io.RandomAccessFile;
 import java.nio.file.Files;
 import java.util.List;
 import java.util.OptionalLong;
 import java.util.UUID;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.TimeUnit;
 
 import co.nstant.in.cbor.CborDecoder;
 import co.nstant.in.cbor.CborException;
@@ -73,140 +58,28 @@
 import co.nstant.in.cbor.model.MajorType;
 
 @RunWith(Parameterized.class)
-public class MicrodroidTests {
+public class MicrodroidTests extends MicrodroidDeviceTestBase {
     private static final String TAG = "MicrodroidTests";
 
     @Rule public Timeout globalTimeout = Timeout.seconds(300);
 
     private static final String KERNEL_VERSION = SystemProperties.get("ro.kernel.version");
 
-    /** Copy output from the VM to logcat. This is helpful when things go wrong. */
-    private static void logVmOutput(InputStream vmOutputStream, String name) {
-        new Thread(() -> {
-            try {
-                BufferedReader reader = new BufferedReader(new InputStreamReader(vmOutputStream));
-                String line;
-                while ((line = reader.readLine()) != null && !Thread.interrupted()) {
-                    Log.i(TAG, name + ": " + line);
-                }
-            } catch (Exception e) {
-                Log.w(TAG, name, e);
-            }
-        }).start();
-    }
-
-    private static class Inner {
-        public boolean mProtectedVm;
-        public Context mContext;
-        public VirtualMachineManager mVmm;
-        public VirtualMachine mVm;
-
-        Inner(boolean protectedVm) {
-            mProtectedVm = protectedVm;
-        }
-
-        /** Create a new VirtualMachineConfig.Builder with the parameterized protection mode. */
-        public VirtualMachineConfig.Builder newVmConfigBuilder(String payloadConfigPath) {
-            return new VirtualMachineConfig.Builder(mContext, payloadConfigPath)
-                            .protectedVm(mProtectedVm);
-        }
-    }
-
     @Parameterized.Parameters(name = "protectedVm={0}")
     public static Object[] protectedVmConfigs() {
         return new Object[] { false, true };
     }
 
-    @Parameterized.Parameter
-    public boolean mProtectedVm;
-
-    private boolean mPkvmSupported = false;
-    private Inner mInner;
+    @Parameterized.Parameter public boolean mProtectedVm;
 
     @Before
     public void setup() {
-        // In case when the virt APEX doesn't exist on the device, classes in the
-        // android.system.virtualmachine package can't be loaded. Therefore, before using the
-        // classes, check the existence of a class in the package and skip this test if not exist.
-        try {
-            Class.forName("android.system.virtualmachine.VirtualMachineManager");
-            mPkvmSupported = true;
-        } catch (ClassNotFoundException e) {
-            assumeNoException(e);
-            return;
-        }
-        if (mProtectedVm) {
-            assume()
-                .withMessage("Skip where protected VMs aren't support")
-                .that(HypervisorProperties.hypervisor_protected_vm_supported().orElse(false))
-                .isTrue();
-        } else {
-            assume()
-                .withMessage("Skip where VMs aren't support")
-                .that(HypervisorProperties.hypervisor_vm_supported().orElse(false))
-                .isTrue();
-        }
-        mInner = new Inner(mProtectedVm);
-        mInner.mContext = ApplicationProvider.getApplicationContext();
-        mInner.mVmm = VirtualMachineManager.getInstance(mInner.mContext);
+        prepareTestSetup(mProtectedVm);
     }
 
     @After
     public void cleanup() throws VirtualMachineException {
-        if (!mPkvmSupported) {
-            return;
-        }
-        if (mInner == null) {
-            return;
-        }
-        if (mInner.mVm == null) {
-            return;
-        }
-        mInner.mVm.stop();
-        mInner.mVm.delete();
-    }
-
-    private abstract static class VmEventListener implements VirtualMachineCallback {
-        private ExecutorService mExecutorService = Executors.newSingleThreadExecutor();
-
-        void runToFinish(VirtualMachine vm) throws VirtualMachineException, InterruptedException {
-            vm.setCallback(mExecutorService, this);
-            vm.run();
-            logVmOutput(vm.getConsoleOutputStream(), "Console");
-            logVmOutput(vm.getLogOutputStream(), "Log");
-            mExecutorService.awaitTermination(300, TimeUnit.SECONDS);
-        }
-
-        void forceStop(VirtualMachine vm) {
-            try {
-                vm.clearCallback();
-                vm.stop();
-                mExecutorService.shutdown();
-            } catch (VirtualMachineException e) {
-                throw new RuntimeException(e);
-            }
-        }
-
-        @Override
-        public void onPayloadStarted(VirtualMachine vm, ParcelFileDescriptor stream) {}
-
-        @Override
-        public void onPayloadReady(VirtualMachine vm) {}
-
-        @Override
-        public void onPayloadFinished(VirtualMachine vm, int exitCode) {}
-
-        @Override
-        public void onError(VirtualMachine vm, int errorCode, String message) {}
-
-        @Override
-        @CallSuper
-        public void onDied(VirtualMachine vm, @DeathReason int reason) {
-            mExecutorService.shutdown();
-        }
-
-        @Override
-        public void onRamdump(VirtualMachine vm, ParcelFileDescriptor ramdump) {}
+        cleanupTestSetup();
     }
 
     private static final int MIN_MEM_ARM64 = 150;
@@ -233,8 +106,7 @@
             }
         }
         VirtualMachineConfig config = builder.build();
-
-        mInner.mVm = mInner.mVmm.getOrCreate("test_vm_extra_apk", config);
+        VirtualMachine vm = mInner.forceCreateNewVirtualMachine("test_vm_extra_apk", config);
 
         class TestResults {
             Exception mException;
@@ -276,10 +148,11 @@
                     public void onPayloadStarted(VirtualMachine vm, ParcelFileDescriptor stream) {
                         Log.i(TAG, "onPayloadStarted");
                         payloadStarted.complete(true);
-                        logVmOutput(new FileInputStream(stream.getFileDescriptor()), "Payload");
+                        logVmOutput(TAG, new FileInputStream(stream.getFileDescriptor()),
+                                "Payload");
                     }
                 };
-        listener.runToFinish(mInner.mVm);
+        listener.runToFinish(TAG, vm);
         assertThat(payloadStarted.getNow(false)).isTrue();
         assertThat(payloadReady.getNow(false)).isTrue();
         assertThat(testResults.mException).isNull();
@@ -299,7 +172,7 @@
 
         VirtualMachineConfig.Builder builder = mInner.newVmConfigBuilder("assets/vm_config.json");
         VirtualMachineConfig normalConfig = builder.debugLevel(DebugLevel.NONE).build();
-        mInner.mVm = mInner.mVmm.getOrCreate("test_vm", normalConfig);
+        VirtualMachine vm = mInner.forceCreateNewVirtualMachine("test_vm", normalConfig);
         VmEventListener listener =
                 new VmEventListener() {
                     @Override
@@ -307,19 +180,20 @@
                         forceStop(vm);
                     }
                 };
-        listener.runToFinish(mInner.mVm);
+        listener.runToFinish(TAG, vm);
 
         // Launch the same VM with different debug level. The Java API prohibits this (thankfully).
         // For testing, we do that by creating another VM with debug level, and copy the config file
         // from the new VM directory to the old VM directory.
         VirtualMachineConfig debugConfig = builder.debugLevel(DebugLevel.FULL).build();
-        VirtualMachine newVm  = mInner.mVmm.getOrCreate("test_debug_vm", debugConfig);
-        File vmRoot = new File(mInner.mContext.getFilesDir(), "vm");
+        VirtualMachine newVm = mInner.forceCreateNewVirtualMachine("test_debug_vm", debugConfig);
+        File vmRoot = new File(getContext().getFilesDir(), "vm");
         File newVmConfig = new File(new File(vmRoot, "test_debug_vm"), "config.xml");
         File oldVmConfig = new File(new File(vmRoot, "test_vm"), "config.xml");
         Files.copy(newVmConfig.toPath(), oldVmConfig.toPath(), REPLACE_EXISTING);
         newVm.delete();
-        mInner.mVm = mInner.mVmm.get("test_vm"); // re-load with the copied-in config file.
+        // re-load with the copied-in config file.
+        vm = mInner.getVirtualMachineManager().get("test_vm");
         final CompletableFuture<Boolean> payloadStarted = new CompletableFuture<>();
         listener =
                 new VmEventListener() {
@@ -329,7 +203,7 @@
                         forceStop(vm);
                     }
                 };
-        listener.runToFinish(mInner.mVm);
+        listener.runToFinish(TAG, vm);
         assertThat(payloadStarted.getNow(false)).isFalse();
     }
 
@@ -340,10 +214,7 @@
 
     private VmCdis launchVmAndGetCdis(String instanceName)
             throws VirtualMachineException, InterruptedException {
-        VirtualMachineConfig normalConfig = mInner.newVmConfigBuilder("assets/vm_config.json")
-                .debugLevel(DebugLevel.NONE)
-                .build();
-        mInner.mVm = mInner.mVmm.getOrCreate(instanceName, normalConfig);
+        VirtualMachine vm = mInner.getVirtualMachineManager().get(instanceName);
         final VmCdis vmCdis = new VmCdis();
         final CompletableFuture<Exception> exception = new CompletableFuture<>();
         VmEventListener listener =
@@ -361,7 +232,7 @@
                         }
                     }
                 };
-        listener.runToFinish(mInner.mVm);
+        listener.runToFinish(TAG, vm);
         assertThat(exception.getNow(null)).isNull();
         return vmCdis;
     }
@@ -374,6 +245,11 @@
             .that(KERNEL_VERSION)
             .isNotEqualTo("5.4");
 
+        VirtualMachineConfig normalConfig = mInner.newVmConfigBuilder("assets/vm_config.json")
+                .debugLevel(DebugLevel.NONE)
+                .build();
+        mInner.forceCreateNewVirtualMachine("test_vm_a", normalConfig);
+        mInner.forceCreateNewVirtualMachine("test_vm_b", normalConfig);
         VmCdis vm_a_cdis = launchVmAndGetCdis("test_vm_a");
         VmCdis vm_b_cdis = launchVmAndGetCdis("test_vm_b");
         assertThat(vm_a_cdis.cdiAttest).isNotNull();
@@ -393,6 +269,11 @@
             .that(KERNEL_VERSION)
             .isNotEqualTo("5.4");
 
+        VirtualMachineConfig normalConfig = mInner.newVmConfigBuilder("assets/vm_config.json")
+                .debugLevel(DebugLevel.NONE)
+                .build();
+        mInner.forceCreateNewVirtualMachine("test_vm", normalConfig);
+
         VmCdis first_boot_cdis = launchVmAndGetCdis("test_vm");
         VmCdis second_boot_cdis = launchVmAndGetCdis("test_vm");
         // The attestation CDI isn't specified to be stable, though it might be
@@ -412,7 +293,7 @@
         VirtualMachineConfig normalConfig = mInner.newVmConfigBuilder("assets/vm_config.json")
                 .debugLevel(DebugLevel.NONE)
                 .build();
-        mInner.mVm = mInner.mVmm.getOrCreate("bcc_vm", normalConfig);
+        VirtualMachine vm = mInner.forceCreateNewVirtualMachine("bcc_vm", normalConfig);
         final VmCdis vmCdis = new VmCdis();
         final CompletableFuture<byte[]> bcc = new CompletableFuture<>();
         final CompletableFuture<Exception> exception = new CompletableFuture<>();
@@ -430,7 +311,7 @@
                         }
                     }
                 };
-        listener.runToFinish(mInner.mVm);
+        listener.runToFinish(TAG, vm);
         byte[] bccBytes = bcc.getNow(null);
         assertThat(exception.getNow(null)).isNull();
         assertThat(bccBytes).isNotNull();
@@ -483,57 +364,19 @@
         file.writeByte(b ^ 1);
     }
 
-    private static class BootResult {
-        public final boolean payloadStarted;
-        public final int deathReason;
-
-        BootResult(boolean payloadStarted, int deathReason) {
-            this.payloadStarted = payloadStarted;
-            this.deathReason = deathReason;
-        }
-    }
-
-    private BootResult tryBootVm(String vmName)
-            throws VirtualMachineException, InterruptedException {
-        mInner.mVm = mInner.mVmm.get(vmName); // re-load the vm before running tests
-        final CompletableFuture<Boolean> payloadStarted = new CompletableFuture<>();
-        final CompletableFuture<Integer> deathReason = new CompletableFuture<>();
-        VmEventListener listener =
-                new VmEventListener() {
-                    @Override
-                    public void onPayloadStarted(VirtualMachine vm, ParcelFileDescriptor stream) {
-                        payloadStarted.complete(true);
-                        forceStop(vm);
-                    }
-                    @Override
-                    public void onDied(VirtualMachine vm, int reason) {
-                        deathReason.complete(reason);
-                        super.onDied(vm, reason);
-                    }
-                };
-        listener.runToFinish(mInner.mVm);
-        return new BootResult(
-                payloadStarted.getNow(false), deathReason.getNow(DeathReason.INFRASTRUCTURE_ERROR));
-    }
-
     private RandomAccessFile prepareInstanceImage(String vmName)
             throws VirtualMachineException, InterruptedException, IOException {
         VirtualMachineConfig config = mInner.newVmConfigBuilder("assets/vm_config.json")
                 .debugLevel(DebugLevel.FULL)
                 .build();
 
-        // Remove any existing VM so we can start from scratch
-        VirtualMachine oldVm = mInner.mVmm.getOrCreate(vmName, config);
-        oldVm.delete();
-        mInner.mVmm.getOrCreate(vmName, config);
+        mInner.forceCreateNewVirtualMachine(vmName, config);
+        assertThat(tryBootVm(TAG, vmName).payloadStarted).isTrue();
 
-        assertThat(tryBootVm(vmName).payloadStarted).isTrue();
-
-        File vmRoot = new File(mInner.mContext.getFilesDir(), "vm");
+        File vmRoot = new File(getContext().getFilesDir(), "vm");
         File vmDir = new File(vmRoot, vmName);
         File instanceImgPath = new File(vmDir, "instance.img");
         return new RandomAccessFile(instanceImgPath, "rw");
-
     }
 
     private void assertThatPartitionIsMissing(UUID partitionUuid)
@@ -551,7 +394,8 @@
         assertThat(offset.isPresent()).isTrue();
 
         flipBit(instanceFile, offset.getAsLong());
-        assertThat(tryBootVm("test_vm_integrity").payloadStarted).isFalse();
+
+        assertThat(tryBootVm(TAG, "test_vm_integrity").payloadStarted).isFalse();
     }
 
     @Test
@@ -596,17 +440,12 @@
     @Test
     public void bootFailsWhenConfigIsInvalid()
             throws VirtualMachineException, InterruptedException, IOException {
-        VirtualMachine existingVm = mInner.mVmm.get("test_vm_invalid_config");
-        if (existingVm != null) {
-            existingVm.delete();
-        }
-
         VirtualMachineConfig.Builder builder =
                 mInner.newVmConfigBuilder("assets/vm_config_no_task.json");
         VirtualMachineConfig normalConfig = builder.debugLevel(DebugLevel.NONE).build();
-        mInner.mVmm.create("test_vm_invalid_config", normalConfig);
+        mInner.forceCreateNewVirtualMachine("test_vm_invalid_config", normalConfig);
 
-        BootResult bootResult = tryBootVm("test_vm_invalid_config");
+        BootResult bootResult = tryBootVm(TAG, "test_vm_invalid_config");
         assertThat(bootResult.payloadStarted).isFalse();
         assertThat(bootResult.deathReason).isEqualTo(DeathReason.MICRODROID_INVALID_PAYLOAD_CONFIG);
     }