diff --git a/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java b/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java
index 7bf3c4e..d96ceb8 100644
--- a/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java
+++ b/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java
@@ -32,6 +32,7 @@
 import android.util.Log;
 
 import com.android.microdroid.test.common.MetricsProcessor;
+import com.android.microdroid.test.common.ProcessUtil;
 import com.android.microdroid.test.device.MicrodroidDeviceTestBase;
 import com.android.microdroid.testservice.IBenchmarkService;
 
@@ -49,6 +50,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
 
 @RunWith(Parameterized.class)
 public class MicrodroidBenchmarks extends MicrodroidDeviceTestBase {
@@ -273,6 +275,10 @@
         }
     }
 
+    private String executeCommand(String command) {
+        return runInShell(TAG, mInstrumentation.getUiAutomation(), command);
+    }
+
     @Test
     public void testMemoryUsage() throws Exception {
         final String vmName = "test_vm_mem_usage";
@@ -283,7 +289,7 @@
                         .build();
         mInner.forceCreateNewVirtualMachine(vmName, config);
         VirtualMachine vm = mInner.getVirtualMachineManager().get(vmName);
-        MemoryUsageListener listener = new MemoryUsageListener();
+        MemoryUsageListener listener = new MemoryUsageListener(this::executeCommand);
         BenchmarkVmListener.create(listener).runToFinish(TAG, vm);
 
         double mem_overall = 256.0;
@@ -293,6 +299,10 @@
         double mem_buffers = (double) listener.mBuffers / 1024.0;
         double mem_cached = (double) listener.mCached / 1024.0;
         double mem_slab = (double) listener.mSlab / 1024.0;
+        double mem_crosvm_host_rss = (double) listener.mCrosvmHostRss / 1024.0;
+        double mem_crosvm_host_pss = (double) listener.mCrosvmHostPss / 1024.0;
+        double mem_crosvm_guest_rss = (double) listener.mCrosvmGuestRss / 1024.0;
+        double mem_crosvm_guest_pss = (double) listener.mCrosvmGuestPss / 1024.0;
 
         double mem_kernel = mem_overall - mem_total;
         double mem_used = mem_total - mem_free - mem_buffers - mem_cached - mem_slab;
@@ -305,10 +315,20 @@
         bundle.putDouble(METRIC_NAME_PREFIX + "mem_cached_MB", mem_cached);
         bundle.putDouble(METRIC_NAME_PREFIX + "mem_slab_MB", mem_slab);
         bundle.putDouble(METRIC_NAME_PREFIX + "mem_unreclaimable_MB", mem_unreclaimable);
+        bundle.putDouble(METRIC_NAME_PREFIX + "mem_crosvm_host_rss_MB", mem_crosvm_host_rss);
+        bundle.putDouble(METRIC_NAME_PREFIX + "mem_crosvm_host_pss_MB", mem_crosvm_host_pss);
+        bundle.putDouble(METRIC_NAME_PREFIX + "mem_crosvm_guest_rss_MB", mem_crosvm_guest_rss);
+        bundle.putDouble(METRIC_NAME_PREFIX + "mem_crosvm_guest_pss_MB", mem_crosvm_guest_pss);
         mInstrumentation.sendStatus(0, bundle);
     }
 
     private static class MemoryUsageListener implements BenchmarkVmListener.InnerListener {
+        MemoryUsageListener(Function<String, String> shellExecutor) {
+            mShellExecutor = shellExecutor;
+        }
+
+        public Function<String, String> mShellExecutor;
+
         public long mMemTotal;
         public long mMemFree;
         public long mMemAvailable;
@@ -316,6 +336,11 @@
         public long mCached;
         public long mSlab;
 
+        public long mCrosvmHostRss;
+        public long mCrosvmHostPss;
+        public long mCrosvmGuestRss;
+        public long mCrosvmGuestPss;
+
         @Override
         public void onPayloadReady(VirtualMachine vm, IBenchmarkService service)
                 throws RemoteException {
@@ -325,6 +350,39 @@
             mBuffers = service.getMemInfoEntry("Buffers");
             mCached = service.getMemInfoEntry("Cached");
             mSlab = service.getMemInfoEntry("Slab");
+
+            try {
+                List<Integer> crosvmPids =
+                        ProcessUtil.getProcessMap(mShellExecutor).entrySet().stream()
+                                .filter(e -> e.getValue().contains("crosvm"))
+                                .map(e -> e.getKey())
+                                .collect(java.util.stream.Collectors.toList());
+                if (crosvmPids.size() != 1) {
+                    throw new RuntimeException(
+                            "expected to find exactly one crosvm processes, found "
+                                    + crosvmPids.size());
+                }
+
+                mCrosvmHostRss = 0;
+                mCrosvmHostPss = 0;
+                mCrosvmGuestRss = 0;
+                mCrosvmGuestPss = 0;
+                for (ProcessUtil.SMapEntry entry :
+                        ProcessUtil.getProcessSmaps(crosvmPids.get(0), mShellExecutor)) {
+                    long rss = entry.metrics.get("Rss");
+                    long pss = entry.metrics.get("Pss");
+                    if (entry.name.contains("crosvm_guest")) {
+                        mCrosvmGuestRss += rss;
+                        mCrosvmGuestPss += pss;
+                    } else {
+                        mCrosvmHostRss += rss;
+                        mCrosvmHostPss += pss;
+                    }
+                }
+            } catch (Exception e) {
+                Log.e(TAG, "Error inside onPayloadReady():" + e);
+                throw new RuntimeException(e);
+            }
         }
     }
 
diff --git a/tests/helper/src/java/com/android/microdroid/test/common/ProcessUtil.java b/tests/helper/src/java/com/android/microdroid/test/common/ProcessUtil.java
index c5aad6e..611a572 100644
--- a/tests/helper/src/java/com/android/microdroid/test/common/ProcessUtil.java
+++ b/tests/helper/src/java/com/android/microdroid/test/common/ProcessUtil.java
@@ -17,18 +17,41 @@
 package com.android.microdroid.test.common;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.function.Function;
 
 /** This class provides process utility for both device tests and host tests. */
 public final class ProcessUtil {
 
+    /** A memory map entry from /proc/{pid}/smaps */
+    public static class SMapEntry {
+        public String name;
+        public Map<String, Long> metrics;
+    }
+
+    /** Gets metrics key and values mapping of specified process id */
+    public static List<SMapEntry> getProcessSmaps(int pid, Function<String, String> shellExecutor)
+            throws IOException {
+        String path = "/proc/" + pid + "/smaps";
+        return parseSmaps(shellExecutor.apply("cat " + path + " || true"));
+    }
+
     /** Gets metrics key and values mapping of specified process id */
     public static Map<String, Long> getProcessSmapsRollup(
             int pid, Function<String, String> shellExecutor) throws IOException {
         String path = "/proc/" + pid + "/smaps_rollup";
-        return parseMemoryInfo(skipFirstLine(shellExecutor.apply("cat " + path + " || true")));
+        List<SMapEntry> entries = parseSmaps(shellExecutor.apply("cat " + path + " || true"));
+        if (entries.size() > 1) {
+            throw new RuntimeException(
+                    "expected at most one entry in smaps_rollup, got " + entries.size());
+        }
+        if (entries.size() == 1) {
+            return entries.get(0).metrics;
+        }
+        return new HashMap<String, Long>();
     }
 
     /** Gets process id and process name mapping of the device */
@@ -54,21 +77,47 @@
     // To ensures that only one object is created at a time.
     private ProcessUtil() {}
 
-    private static Map<String, Long> parseMemoryInfo(String file) {
-        Map<String, Long> stats = new HashMap<>();
-        for (String line : file.split("[\r\n]+")) {
+    private static List<SMapEntry> parseSmaps(String file) {
+        List<SMapEntry> entries = new ArrayList<SMapEntry>();
+        for (String line : file.split("\n")) {
             line = line.trim();
             if (line.length() == 0) {
                 continue;
             }
-            // Each line is '<metrics>:        <number> kB'.
-            // EX : Pss_Anon:        70712 kB
-            if (line.endsWith(" kB")) line = line.substring(0, line.length() - 3);
-
-            String[] elems = line.split(":");
-            stats.put(elems[0].trim(), Long.parseLong(elems[1].trim()));
+            if (line.contains(": ")) {
+                if (entries.size() == 0) {
+                    throw new RuntimeException("unexpected line: " + line);
+                }
+                // Each line is '<metrics>:        <number> kB'.
+                // EX : Pss_Anon:        70712 kB
+                if (line.endsWith(" kB")) line = line.substring(0, line.length() - 3);
+                String[] elems = line.split(":");
+                String name = elems[0].trim();
+                try {
+                    entries.get(entries.size() - 1)
+                            .metrics
+                            .put(name, Long.parseLong(elems[1].trim()));
+                } catch (java.lang.NumberFormatException e) {
+                    // Some entries, like "VmFlags", aren't numbers, just ignore.
+                }
+                continue;
+            }
+            // Parse the header and create a new entry for it.
+            // Some header examples:
+            //     7f644098a000-7f644098c000 rw-p 00000000 00:00 0
+            //     00400000-0048a000 r-xp 00000000 fd:03 960637   /bin/bash
+            //     75e42af000-75f42af000 rw-s 00000000 00:01 235  /memfd:crosvm_guest (deleted)
+            SMapEntry entry = new SMapEntry();
+            String[] parts = line.split("\\s+", 6);
+            if (parts.length >= 6) {
+                entry.name = parts[5];
+            } else {
+                entry.name = "";
+            }
+            entry.metrics = new HashMap<String, Long>();
+            entries.add(entry);
         }
-        return stats;
+        return entries;
     }
 
     private static String skipFirstLine(String str) {
