Implement a memory reclaim microdroid benchmark test

This manually triggers the onTrimMemory hook and reports on guest
memory usage after allowing memory reclaim to occur.

Bug: 238931615
Test: atest MicrodroidBenchmarks
Change-Id: Ief3a3209cd99384cd716b1d920e1b2fe49896d04
diff --git a/tests/aidl/com/android/microdroid/testservice/IBenchmarkService.aidl b/tests/aidl/com/android/microdroid/testservice/IBenchmarkService.aidl
index 16e4893..c8c8660 100644
--- a/tests/aidl/com/android/microdroid/testservice/IBenchmarkService.aidl
+++ b/tests/aidl/com/android/microdroid/testservice/IBenchmarkService.aidl
@@ -30,6 +30,9 @@
     /** Returns an entry from /proc/meminfo. */
     long getMemInfoEntry(String name);
 
+    /** Allocates anonymous memory and returns the raw pointer. */
+    long allocAnonMemory(long mb);
+
     /**
      * Initializes the vsock server on VM.
      * @return the server socket file descriptor.
diff --git a/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java b/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java
index 14a0e39..3c3faf2 100644
--- a/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java
+++ b/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java
@@ -27,6 +27,7 @@
 import android.app.Instrumentation;
 import android.os.Bundle;
 import android.os.ParcelFileDescriptor;
+import android.os.Process;
 import android.os.RemoteException;
 import android.system.virtualmachine.VirtualMachine;
 import android.system.virtualmachine.VirtualMachineConfig;
@@ -279,6 +280,58 @@
         return runInShell(TAG, mInstrumentation.getUiAutomation(), command);
     }
 
+    private static class CrosvmStats {
+        public final long mHostRss;
+        public final long mHostPss;
+        public final long mGuestRss;
+        public final long mGuestPss;
+
+        CrosvmStats(Function<String, String> shellExecutor) {
+            try {
+                List<Integer> crosvmPids =
+                        ProcessUtil.getProcessMap(shellExecutor).entrySet().stream()
+                                .filter(e -> e.getValue().contains("crosvm"))
+                                .map(e -> e.getKey())
+                                .collect(java.util.stream.Collectors.toList());
+                if (crosvmPids.size() != 1) {
+                    throw new IllegalStateException(
+                            "expected to find exactly one crosvm processes, found "
+                                    + crosvmPids.size());
+                }
+
+                long hostRss = 0;
+                long hostPss = 0;
+                long guestRss = 0;
+                long guestPss = 0;
+                boolean hasGuestMaps = false;
+                for (ProcessUtil.SMapEntry entry :
+                        ProcessUtil.getProcessSmaps(crosvmPids.get(0), shellExecutor)) {
+                    long rss = entry.metrics.get("Rss");
+                    long pss = entry.metrics.get("Pss");
+                    if (entry.name.contains("crosvm_guest")) {
+                        guestRss += rss;
+                        guestPss += pss;
+                        hasGuestMaps = true;
+                    } else {
+                        hostRss += rss;
+                        hostPss += pss;
+                    }
+                }
+                if (!hasGuestMaps) {
+                    throw new IllegalStateException(
+                            "found no crosvm_guest smap entry in crosvm process");
+                }
+                mHostRss = hostRss;
+                mHostPss = hostPss;
+                mGuestRss = guestRss;
+                mGuestPss = guestPss;
+            } catch (Exception e) {
+                Log.e(TAG, "Error inside onPayloadReady():" + e);
+                throw new RuntimeException(e);
+            }
+        }
+    }
+
     @Test
     public void testMemoryUsage() throws Exception {
         final String vmName = "test_vm_mem_usage";
@@ -299,10 +352,10 @@
         double mem_buffers = (double) listener.mBuffers / 1024.0;
         double mem_cached = (double) listener.mCached / 1024.0;
         double mem_slab = (double) listener.mSlab / 1024.0;
-        double mem_crosvm_host_rss = (double) listener.mCrosvmHostRss / 1024.0;
-        double mem_crosvm_host_pss = (double) listener.mCrosvmHostPss / 1024.0;
-        double mem_crosvm_guest_rss = (double) listener.mCrosvmGuestRss / 1024.0;
-        double mem_crosvm_guest_pss = (double) listener.mCrosvmGuestPss / 1024.0;
+        double mem_crosvm_host_rss = (double) listener.mCrosvm.mHostRss / 1024.0;
+        double mem_crosvm_host_pss = (double) listener.mCrosvm.mHostPss / 1024.0;
+        double mem_crosvm_guest_rss = (double) listener.mCrosvm.mGuestRss / 1024.0;
+        double mem_crosvm_guest_pss = (double) listener.mCrosvm.mGuestPss / 1024.0;
 
         double mem_kernel = mem_overall - mem_total;
         double mem_used = mem_total - mem_free - mem_buffers - mem_cached - mem_slab;
@@ -327,7 +380,7 @@
             mShellExecutor = shellExecutor;
         }
 
-        public Function<String, String> mShellExecutor;
+        public final Function<String, String> mShellExecutor;
 
         public long mMemTotal;
         public long mMemFree;
@@ -336,10 +389,7 @@
         public long mCached;
         public long mSlab;
 
-        public long mCrosvmHostRss;
-        public long mCrosvmHostPss;
-        public long mCrosvmGuestRss;
-        public long mCrosvmGuestPss;
+        public CrosvmStats mCrosvm;
 
         @Override
         public void onPayloadReady(VirtualMachine vm, IBenchmarkService service)
@@ -350,39 +400,80 @@
             mBuffers = service.getMemInfoEntry("Buffers");
             mCached = service.getMemInfoEntry("Cached");
             mSlab = service.getMemInfoEntry("Slab");
+            mCrosvm = new CrosvmStats(mShellExecutor);
+        }
+    }
 
+    @Test
+    public void testMemoryReclaim() throws Exception {
+        final String vmName = "test_vm_mem_reclaim";
+        VirtualMachineConfig config =
+                newVmConfigBuilder()
+                        .setPayloadConfigPath("assets/vm_config_io.json")
+                        .setDebugLevel(DEBUG_LEVEL_NONE)
+                        .setMemoryMib(256)
+                        .build();
+        VirtualMachine vm = forceCreateNewVirtualMachine(vmName, config);
+        MemoryReclaimListener listener = new MemoryReclaimListener(this::executeCommand);
+        BenchmarkVmListener.create(listener).runToFinish(TAG, vm);
+
+        double mem_pre_crosvm_host_rss = (double) listener.mPreCrosvm.mHostRss / 1024.0;
+        double mem_pre_crosvm_host_pss = (double) listener.mPreCrosvm.mHostPss / 1024.0;
+        double mem_pre_crosvm_guest_rss = (double) listener.mPreCrosvm.mGuestRss / 1024.0;
+        double mem_pre_crosvm_guest_pss = (double) listener.mPreCrosvm.mGuestPss / 1024.0;
+        double mem_post_crosvm_host_rss = (double) listener.mPostCrosvm.mHostRss / 1024.0;
+        double mem_post_crosvm_host_pss = (double) listener.mPostCrosvm.mHostPss / 1024.0;
+        double mem_post_crosvm_guest_rss = (double) listener.mPostCrosvm.mGuestRss / 1024.0;
+        double mem_post_crosvm_guest_pss = (double) listener.mPostCrosvm.mGuestPss / 1024.0;
+
+        Bundle bundle = new Bundle();
+        bundle.putDouble(
+                METRIC_NAME_PREFIX + "mem_pre_crosvm_host_rss_MB", mem_pre_crosvm_host_rss);
+        bundle.putDouble(
+                METRIC_NAME_PREFIX + "mem_pre_crosvm_host_pss_MB", mem_pre_crosvm_host_pss);
+        bundle.putDouble(
+                METRIC_NAME_PREFIX + "mem_pre_crosvm_guest_rss_MB", mem_pre_crosvm_guest_rss);
+        bundle.putDouble(
+                METRIC_NAME_PREFIX + "mem_pre_crosvm_guest_pss_MB", mem_pre_crosvm_guest_pss);
+        bundle.putDouble(
+                METRIC_NAME_PREFIX + "mem_post_crosvm_host_rss_MB", mem_post_crosvm_host_rss);
+        bundle.putDouble(
+                METRIC_NAME_PREFIX + "mem_post_crosvm_host_pss_MB", mem_post_crosvm_host_pss);
+        bundle.putDouble(
+                METRIC_NAME_PREFIX + "mem_post_crosvm_guest_rss_MB", mem_post_crosvm_guest_rss);
+        bundle.putDouble(
+                METRIC_NAME_PREFIX + "mem_post_crosvm_guest_pss_MB", mem_post_crosvm_guest_pss);
+        mInstrumentation.sendStatus(0, bundle);
+    }
+
+    private static class MemoryReclaimListener implements BenchmarkVmListener.InnerListener {
+        MemoryReclaimListener(Function<String, String> shellExecutor) {
+            mShellExecutor = shellExecutor;
+        }
+
+        public final Function<String, String> mShellExecutor;
+
+        public CrosvmStats mPreCrosvm;
+        public CrosvmStats mPostCrosvm;
+
+        @Override
+        @SuppressWarnings("ReturnValueIgnored")
+        public void onPayloadReady(VirtualMachine vm, IBenchmarkService service)
+                throws RemoteException {
+            // Allocate 256MB of anonymous memory. This will fill all guest
+            // memory and cause swapping to start.
+            service.allocAnonMemory(256);
+            mPreCrosvm = new CrosvmStats(mShellExecutor);
+            // Send a memory trim hint to cause memory reclaim.
+            mShellExecutor.apply("am send-trim-memory " + Process.myPid() + " RUNNING_CRITICAL");
+            // Give time for the memory reclaim to do its work.
             try {
-                List<Integer> crosvmPids =
-                        ProcessUtil.getProcessMap(mShellExecutor).entrySet().stream()
-                                .filter(e -> e.getValue().contains("crosvm"))
-                                .map(e -> e.getKey())
-                                .collect(java.util.stream.Collectors.toList());
-                if (crosvmPids.size() != 1) {
-                    throw new RuntimeException(
-                            "expected to find exactly one crosvm processes, found "
-                                    + crosvmPids.size());
-                }
-
-                mCrosvmHostRss = 0;
-                mCrosvmHostPss = 0;
-                mCrosvmGuestRss = 0;
-                mCrosvmGuestPss = 0;
-                for (ProcessUtil.SMapEntry entry :
-                        ProcessUtil.getProcessSmaps(crosvmPids.get(0), mShellExecutor)) {
-                    long rss = entry.metrics.get("Rss");
-                    long pss = entry.metrics.get("Pss");
-                    if (entry.name.contains("crosvm_guest")) {
-                        mCrosvmGuestRss += rss;
-                        mCrosvmGuestPss += pss;
-                    } else {
-                        mCrosvmHostRss += rss;
-                        mCrosvmHostPss += pss;
-                    }
-                }
-            } catch (Exception e) {
-                Log.e(TAG, "Error inside onPayloadReady():" + e);
-                throw new RuntimeException(e);
+                Thread.sleep(isCuttlefish() ? 10000 : 5000);
+            } catch (InterruptedException e) {
+                Log.e(TAG, "Interrupted sleep:" + e);
+                Thread.currentThread().interrupt();
             }
+            mPostCrosvm = new CrosvmStats(mShellExecutor);
         }
     }
 
diff --git a/tests/benchmark/src/native/benchmarkbinary.cpp b/tests/benchmark/src/native/benchmarkbinary.cpp
index 70ec7db..5c172c0 100644
--- a/tests/benchmark/src/native/benchmarkbinary.cpp
+++ b/tests/benchmark/src/native/benchmarkbinary.cpp
@@ -77,6 +77,11 @@
         return ndk::ScopedAStatus::ok();
     }
 
+    ndk::ScopedAStatus allocAnonMemory(long mb, long* out) override {
+        *out = (long)alloc_anon_memory(mb);
+        return ndk::ScopedAStatus::ok();
+    }
+
     ndk::ScopedAStatus initVsockServer(int32_t port, int32_t* out) override {
         auto res = io_vsock::init_vsock_server(port);
         if (res.ok()) {
@@ -131,6 +136,17 @@
         return {file_size_mb / elapsed_seconds};
     }
 
+    void* alloc_anon_memory(long mb) {
+        long bytes = mb << 20;
+        void* p = malloc(bytes);
+        /*
+         * Heap memory is demand allocated. Dirty all pages to ensure
+         * all are allocated.
+         */
+        memset(p, 0x55, bytes);
+        return p;
+    }
+
     Result<size_t> read_meminfo_entry(const std::string& stat) {
         std::ifstream fs("/proc/meminfo");
         if (!fs.is_open()) {