Merge "Refactor Benchmark VM event listeners"
diff --git a/tests/benchmark/src/java/com/android/microdroid/benchmark/BenchmarkVmListener.java b/tests/benchmark/src/java/com/android/microdroid/benchmark/BenchmarkVmListener.java
new file mode 100644
index 0000000..eb45a71
--- /dev/null
+++ b/tests/benchmark/src/java/com/android/microdroid/benchmark/BenchmarkVmListener.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.microdroid.benchmark;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.os.RemoteException;
+import android.system.virtualmachine.VirtualMachine;
+import android.util.Log;
+
+import com.android.microdroid.test.MicrodroidDeviceTestBase.VmEventListener;
+import com.android.microdroid.testservice.IBenchmarkService;
+
+/**
+ * This VM listener is used in {@link MicrodroidBenchmarks} tests to facilitate the communication
+ * between the host and VM via {@link IBenchmarkService}.
+ */
+class BenchmarkVmListener extends VmEventListener {
+    private static final String TAG = "BenchmarkVm";
+
+    interface InnerListener {
+        /** This is invoked when both the payload and {@link IBenchmarkService} are ready. */
+        void onPayloadReady(VirtualMachine vm, IBenchmarkService benchmarkService)
+                throws RemoteException;
+    }
+
+    private final InnerListener mListener;
+
+    private BenchmarkVmListener(InnerListener listener) {
+        mListener = listener;
+    }
+
+    @Override
+    public final void onPayloadReady(VirtualMachine vm) {
+        try {
+            IBenchmarkService benchmarkService =
+                    IBenchmarkService.Stub.asInterface(
+                            vm.connectToVsockServer(IBenchmarkService.SERVICE_PORT).get());
+            assertThat(benchmarkService).isNotNull();
+
+            mListener.onPayloadReady(vm, benchmarkService);
+        } catch (Exception e) {
+            Log.e(TAG, "Error inside onPayloadReady():" + e);
+            throw new RuntimeException(e);
+        }
+        forceStop(vm);
+    }
+
+    static BenchmarkVmListener create(InnerListener listener) {
+        return new BenchmarkVmListener(listener);
+    }
+}
diff --git a/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java b/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java
index e19c72a..200b7bf 100644
--- a/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java
+++ b/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java
@@ -24,6 +24,7 @@
 import android.app.Instrumentation;
 import android.os.Bundle;
 import android.os.ParcelFileDescriptor;
+import android.os.RemoteException;
 import android.system.virtualmachine.VirtualMachine;
 import android.system.virtualmachine.VirtualMachineConfig;
 import android.system.virtualmachine.VirtualMachineConfig.DebugLevel;
@@ -193,8 +194,7 @@
             String vmName = "test_vm_io_" + i;
             mInner.forceCreateNewVirtualMachine(vmName, config);
             VirtualMachine vm = mInner.getVirtualMachineManager().get(vmName);
-            VsockVmEventListener listener = new VsockVmEventListener(transferRates, port);
-            listener.runToFinish(TAG, vm);
+            BenchmarkVmListener.create(new VsockListener(transferRates, port)).runToFinish(TAG, vm);
         }
         reportMetrics(transferRates, "vsock/transfer_host_to_vm_", "_mb_per_sec");
     }
@@ -226,8 +226,8 @@
             String vmName = "test_vm_io_" + i;
             mInner.forceCreateNewVirtualMachine(vmName, config);
             VirtualMachine vm = mInner.getVirtualMachineManager().get(vmName);
-            VirtioBlkVmEventListener listener = new VirtioBlkVmEventListener(readRates, isRand);
-            listener.runToFinish(TAG, vm);
+            BenchmarkVmListener.create(new VirtioBlkListener(readRates, isRand))
+                    .runToFinish(TAG, vm);
         }
         reportMetrics(
                 readRates,
@@ -260,14 +260,14 @@
         mInstrumentation.sendStatus(0, bundle);
     }
 
-    private static class VirtioBlkVmEventListener extends VmEventListener {
+    private static class VirtioBlkListener implements BenchmarkVmListener.InnerListener {
         private static final String FILENAME = APEX_ETC_FS + "microdroid_super.img";
 
         private final long mFileSizeBytes;
         private final List<Double> mReadRates;
         private final boolean mIsRand;
 
-        VirtioBlkVmEventListener(List<Double> readRates, boolean isRand) {
+        VirtioBlkListener(List<Double> readRates, boolean isRand) {
             File file = new File(FILENAME);
             try {
                 mFileSizeBytes = Files.size(file.toPath());
@@ -280,32 +280,26 @@
         }
 
         @Override
-        public void onPayloadReady(VirtualMachine vm) {
-            try {
-                IBenchmarkService benchmarkService =
-                        IBenchmarkService.Stub.asInterface(
-                                vm.connectToVsockServer(IBenchmarkService.SERVICE_PORT).get());
-                double elapsedSeconds =
-                        benchmarkService.readFile(FILENAME, mFileSizeBytes, mIsRand);
-                double fileSizeMb = mFileSizeBytes / SIZE_MB;
-                mReadRates.add(fileSizeMb / elapsedSeconds);
-            } catch (Exception e) {
-                throw new RuntimeException(e);
-            }
-            forceStop(vm);
+        public void onPayloadReady(VirtualMachine vm, IBenchmarkService benchmarkService)
+                throws RemoteException {
+            double elapsedSeconds = benchmarkService.readFile(FILENAME, mFileSizeBytes, mIsRand);
+            double fileSizeMb = mFileSizeBytes / SIZE_MB;
+            mReadRates.add(fileSizeMb / elapsedSeconds);
         }
     }
 
     @Test
     public void testMemoryUsage() throws Exception {
         final String vmName = "test_vm_mem_usage";
-        VirtualMachineConfig.Builder builder = mInner.newVmConfigBuilder(
-                "assets/vm_config_io.json");
-        VirtualMachineConfig config = builder.debugLevel(DebugLevel.NONE).memoryMib(256).build();
+        VirtualMachineConfig config =
+                mInner.newVmConfigBuilder("assets/vm_config_io.json")
+                        .debugLevel(DebugLevel.NONE)
+                        .memoryMib(256)
+                        .build();
         mInner.forceCreateNewVirtualMachine(vmName, config);
         VirtualMachine vm = mInner.getVirtualMachineManager().get(vmName);
         MemoryUsageListener listener = new MemoryUsageListener();
-        listener.runToFinish(TAG, vm);
+        BenchmarkVmListener.create(listener).runToFinish(TAG, vm);
 
         double mem_overall = 256.0;
         double mem_total = (double) listener.mMemTotal / 1024.0;
@@ -329,7 +323,7 @@
         mInstrumentation.sendStatus(0, bundle);
     }
 
-    private static class MemoryUsageListener extends VmEventListener {
+    private static class MemoryUsageListener implements BenchmarkVmListener.InnerListener {
         public long mMemTotal;
         public long mMemFree;
         public long mMemAvailable;
@@ -338,55 +332,38 @@
         public long mSlab;
 
         @Override
-        public void onPayloadReady(VirtualMachine vm) {
-            try {
-                IBenchmarkService service =
-                        IBenchmarkService.Stub.asInterface(
-                                vm.connectToVsockServer(IBenchmarkService.SERVICE_PORT).get());
-
-                mMemTotal = service.getMemInfoEntry("MemTotal");
-                mMemFree = service.getMemInfoEntry("MemFree");
-                mMemAvailable = service.getMemInfoEntry("MemAvailable");
-                mBuffers = service.getMemInfoEntry("Buffers");
-                mCached = service.getMemInfoEntry("Cached");
-                mSlab = service.getMemInfoEntry("Slab");
-            } catch (Exception e) {
-                throw new RuntimeException(e);
-            }
-            forceStop(vm);
+        public void onPayloadReady(VirtualMachine vm, IBenchmarkService service)
+                throws RemoteException {
+            mMemTotal = service.getMemInfoEntry("MemTotal");
+            mMemFree = service.getMemInfoEntry("MemFree");
+            mMemAvailable = service.getMemInfoEntry("MemAvailable");
+            mBuffers = service.getMemInfoEntry("Buffers");
+            mCached = service.getMemInfoEntry("Cached");
+            mSlab = service.getMemInfoEntry("Slab");
         }
     }
 
-    private static class VsockVmEventListener extends VmEventListener {
+    private static class VsockListener implements BenchmarkVmListener.InnerListener {
         private static final int NUM_BYTES_TO_TRANSFER = 48 * 1024 * 1024;
 
         private final List<Double> mReadRates;
         private final int mPort;
 
-        VsockVmEventListener(List<Double> readRates, int port) {
+        VsockListener(List<Double> readRates, int port) {
             mReadRates = readRates;
             mPort = port;
         }
 
         @Override
-        public void onPayloadReady(VirtualMachine vm) {
-            try {
-                IBenchmarkService benchmarkService =
-                        IBenchmarkService.Stub.asInterface(
-                                vm.connectToVsockServer(IBenchmarkService.SERVICE_PORT).get());
-                assertThat(benchmarkService).isNotNull();
-                AtomicReference<Double> sendRate = new AtomicReference();
+        public void onPayloadReady(VirtualMachine vm, IBenchmarkService benchmarkService)
+                throws RemoteException {
+            AtomicReference<Double> sendRate = new AtomicReference<>();
 
-                int serverFd = benchmarkService.initVsockServer(mPort);
-                new Thread(() -> sendRate.set(runVsockClientAndSendData(vm))).start();
-                benchmarkService.runVsockServerAndReceiveData(serverFd, NUM_BYTES_TO_TRANSFER);
+            int serverFd = benchmarkService.initVsockServer(mPort);
+            new Thread(() -> sendRate.set(runVsockClientAndSendData(vm))).start();
+            benchmarkService.runVsockServerAndReceiveData(serverFd, NUM_BYTES_TO_TRANSFER);
 
-                mReadRates.add(sendRate.get());
-            } catch (Exception e) {
-                Log.e(TAG, "Test failed in VM:" + e);
-                throw new RuntimeException(e);
-            }
-            forceStop(vm);
+            mReadRates.add(sendRate.get()); // NOTE(review): thread not joined; may be null
         }
 
         private double runVsockClientAndSendData(VirtualMachine vm) {
diff --git a/tests/helper/src/java/com/android/microdroid/test/MicrodroidDeviceTestBase.java b/tests/helper/src/java/com/android/microdroid/test/MicrodroidDeviceTestBase.java
index 1a573bb..a80111f 100644
--- a/tests/helper/src/java/com/android/microdroid/test/MicrodroidDeviceTestBase.java
+++ b/tests/helper/src/java/com/android/microdroid/test/MicrodroidDeviceTestBase.java
@@ -118,7 +118,7 @@
         mInner = new Inner(context, protectedVm, VirtualMachineManager.getInstance(context));
     }
 
-    protected abstract static class VmEventListener implements VirtualMachineCallback {
+    public abstract static class VmEventListener implements VirtualMachineCallback {
         private ExecutorService mExecutorService = Executors.newSingleThreadExecutor();
         private OptionalLong mVcpuStartedNanoTime = OptionalLong.empty();
         private OptionalLong mKernelStartedNanoTime = OptionalLong.empty();