Merge changes I232ec389,I6c9ca2f2,I0dcd1f3a am: dfdc5ce529

Original change: https://android-review.googlesource.com/c/platform/packages/modules/Virtualization/+/2478021

Change-Id: I0dcdfa13516e21a2d0e6dab135f8beab1d4400ba
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
diff --git a/javalib/src/android/system/virtualmachine/VirtualMachine.java b/javalib/src/android/system/virtualmachine/VirtualMachine.java
index 5f39b1c..7713faf 100644
--- a/javalib/src/android/system/virtualmachine/VirtualMachine.java
+++ b/javalib/src/android/system/virtualmachine/VirtualMachine.java
@@ -459,7 +459,7 @@
                 }
             }
 
-            IVirtualizationService service = vm.mVirtualizationService.connect();
+            IVirtualizationService service = vm.mVirtualizationService.getBinder();
 
             try {
                 service.initializeWritablePartition(
@@ -785,7 +785,7 @@
                 throw new VirtualMachineException("Failed to create APK signature file", e);
             }
 
-            IVirtualizationService service = mVirtualizationService.connect();
+            IVirtualizationService service = mVirtualizationService.getBinder();
 
             try {
                 if (mVmOutputCaptured) {
diff --git a/javalib/src/android/system/virtualmachine/VirtualizationService.java b/javalib/src/android/system/virtualmachine/VirtualizationService.java
index c3f2ba3..1cf97b5 100644
--- a/javalib/src/android/system/virtualmachine/VirtualizationService.java
+++ b/javalib/src/android/system/virtualmachine/VirtualizationService.java
@@ -41,6 +41,9 @@
      */
     private final ParcelFileDescriptor mClientFd;
 
+    /* Persistent connection to IVirtualizationService. */
+    private final IVirtualizationService mBinder;
+
     private static native int nativeSpawn();
 
     private native IBinder nativeConnect(int clientFd);
@@ -57,15 +60,18 @@
             throw new VirtualMachineException("Could not spawn VirtualizationService");
         }
         mClientFd = ParcelFileDescriptor.adoptFd(clientFd);
-    }
 
-    /* Connects to the VirtualizationService AIDL service. */
-    public IVirtualizationService connect() throws VirtualMachineException {
         IBinder binder = nativeConnect(mClientFd.getFd());
         if (binder == null) {
             throw new VirtualMachineException("Could not connect to VirtualizationService");
         }
-        return IVirtualizationService.Stub.asInterface(binder);
+        mBinder = IVirtualizationService.Stub.asInterface(binder);
+    }
+
+    /* Returns the IVirtualizationService binder. */
+    @NonNull
+    IVirtualizationService getBinder() {
+        return mBinder;
     }
 
     /*
diff --git a/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java b/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java
index 4b11d77..9851a17 100644
--- a/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java
+++ b/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java
@@ -36,6 +36,7 @@
 import android.system.virtualmachine.VirtualMachine;
 import android.system.virtualmachine.VirtualMachineConfig;
 import android.system.virtualmachine.VirtualMachineException;
+import android.system.Os;
 import android.util.Log;
 
 import com.android.microdroid.test.common.MetricsProcessor;
@@ -347,16 +348,7 @@
 
         CrosvmStats(Function<String, String> shellExecutor) {
             try {
-                List<Integer> crosvmPids =
-                        ProcessUtil.getProcessMap(shellExecutor).entrySet().stream()
-                                .filter(e -> e.getValue().contains("crosvm"))
-                                .map(e -> e.getKey())
-                                .collect(java.util.stream.Collectors.toList());
-                if (crosvmPids.size() != 1) {
-                    throw new IllegalStateException(
-                            "expected to find exactly one crosvm processes, found "
-                                    + crosvmPids.size());
-                }
+                int crosvmPid = ProcessUtil.getCrosvmPid(Os.getpid(), shellExecutor);
 
                 long hostRss = 0;
                 long hostPss = 0;
@@ -364,7 +356,7 @@
                 long guestPss = 0;
                 boolean hasGuestMaps = false;
                 for (ProcessUtil.SMapEntry entry :
-                        ProcessUtil.getProcessSmaps(crosvmPids.get(0), shellExecutor)) {
+                        ProcessUtil.getProcessSmaps(crosvmPid, shellExecutor)) {
                     long rss = entry.metrics.get("Rss");
                     long pss = entry.metrics.get("Pss");
                     if (entry.name.contains("crosvm_guest")) {
diff --git a/tests/helper/src/java/com/android/microdroid/test/common/ProcessUtil.java b/tests/helper/src/java/com/android/microdroid/test/common/ProcessUtil.java
index 940ec9c..c72d91e 100644
--- a/tests/helper/src/java/com/android/microdroid/test/common/ProcessUtil.java
+++ b/tests/helper/src/java/com/android/microdroid/test/common/ProcessUtil.java
@@ -22,9 +22,12 @@
 import java.util.List;
 import java.util.Map;
 import java.util.function.Function;
+import java.util.stream.IntStream;
 
 /** This class provides process utility for both device tests and host tests. */
 public final class ProcessUtil {
+    private static final String CROSVM_BIN = "/apex/com.android.virt/bin/crosvm";
+    private static final String VIRTMGR_BIN = "/apex/com.android.virt/bin/virtmgr";
 
     /** A memory map entry from /proc/{pid}/smaps */
     public static class SMapEntry {
@@ -89,6 +92,63 @@
         return processMap;
     }
 
+    /**
+     * Lists the direct children of {@code pid}, optionally restricted to those whose full
+     * command line matches {@code cmdlineFilter}.
+     *
+     * <p>Runs {@code pgrep -P <pid> [-f <cmdlineFilter>]} via {@code shellExecutor}; pgrep
+     * treats the filter as a regular expression. Blank output yields an empty stream.
+     */
+    private static IntStream getChildProcesses(
+            int pid, String cmdlineFilter, Function<String, String> shellExecutor) {
+        String cmd = "pgrep -P " + pid;
+        if (cmdlineFilter != null) {
+            cmd += " -f " + cmdlineFilter;
+        }
+        return shellExecutor.apply(cmd).trim().lines().mapToInt(Integer::parseInt);
+    }
+
+    /**
+     * Returns the PID of the only direct child of {@code parentPid} whose command line matches
+     * {@code cmdlineFilter}.
+     *
+     * @throws IllegalStateException if zero or more than one matching child is found.
+     */
+    private static int getSingleChildProcess(
+            int parentPid, String cmdlineFilter, Function<String, String> shellExecutor) {
+        int[] pids = getChildProcesses(parentPid, cmdlineFilter, shellExecutor).toArray();
+        if (pids.length == 0) {
+            throw new IllegalStateException(
+                    "No process found for " + cmdlineFilter + " under parent " + parentPid);
+        } else if (pids.length > 1) {
+            throw new IllegalStateException(
+                    "More than one process found for " + cmdlineFilter + " under parent "
+                            + parentPid + ": " + java.util.Arrays.toString(pids));
+        }
+        return pids[0];
+    }
+
+    /**
+     * Returns the PID of the single direct child of {@code parentPid} running virtmgr.
+     *
+     * @throws IllegalStateException if there is not exactly one such child.
+     */
+    public static int getVirtmgrPid(int parentPid, Function<String, String> shellExecutor) {
+        return getSingleChildProcess(parentPid, VIRTMGR_BIN, shellExecutor);
+    }
+
+    /**
+     * Returns the PID of the crosvm process that is a direct child of the virtmgr process, which
+     * in turn must be the single virtmgr child of {@code parentPid}.
+     *
+     * @throws IllegalStateException if {@code parentPid} does not have exactly one virtmgr child,
+     *     or that virtmgr does not have exactly one crosvm child.
+     */
+    public static int getCrosvmPid(int parentPid, Function<String, String> shellExecutor) {
+        int virtmgrPid = getVirtmgrPid(parentPid, shellExecutor);
+        return getSingleChildProcess(virtmgrPid, CROSVM_BIN, shellExecutor);
+    }
+
     // To ensures that only one object is created at a time.
     private ProcessUtil() {}
 
diff --git a/virtualizationmanager/src/crosvm.rs b/virtualizationmanager/src/crosvm.rs
index 745d4f6..7201670 100644
--- a/virtualizationmanager/src/crosvm.rs
+++ b/virtualizationmanager/src/crosvm.rs
@@ -492,6 +492,10 @@
         // first, as monitor_vm_exit() takes it as well.
         monitor_vm_exit_thread.map(JoinHandle::join);
 
+        // Now that the VM has been killed, shut down the VirtualMachineService
+        // server to eagerly free up the server threads.
+        self.vm_context.vm_server.shutdown()?;
+
         Ok(())
     }