Test PSCI MEM_PROTECT

The PSCI MEM_PROTECT flag is a protection against cold-reboot type
attacks. It is implemented in the hypervisor with a counter that can be
observed with a hypervisor event.

Ensure this counter is incremented during the run of a protected VM and
then return to 0.

Bug: 340244758
Test: AVFHostTestCase#testPsciMemProtect
Change-Id: Ifec1533fca9a8199fabd307c887e79065c6ef263
diff --git a/tests/benchmark_hostside/java/android/avf/test/AVFHostTestCase.java b/tests/benchmark_hostside/java/android/avf/test/AVFHostTestCase.java
index 99a8df8..4a61016 100644
--- a/tests/benchmark_hostside/java/android/avf/test/AVFHostTestCase.java
+++ b/tests/benchmark_hostside/java/android/avf/test/AVFHostTestCase.java
@@ -47,6 +47,7 @@
 import org.junit.runner.RunWith;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.regex.Matcher;
@@ -139,6 +140,36 @@
     }
 
     @Test
+    public void testPsciMemProtect() throws Exception {
+        String[] hypEvents = {
+            "psci_mem_protect"
+        };
+
+        assumeTrue("Skip without hypervisor tracing",
+            KvmHypTracer.isSupported(getDevice(), hypEvents));
+        KvmHypTracer tracer = new KvmHypTracer(getDevice(), hypEvents);
+
+        /* We need to wait for crosvm to die so all the VM pages are reclaimed */
+        String result = tracer.run(COMPOSD_CMD_BIN + " test-compile && killall -w crosvm || true");
+        assertWithMessage("Failed to test compilation VM.")
+                .that(result).ignoringCase().contains("all ok");
+
+        List<Integer> values = tracer.getPsciMemProtect();
+
+        assertWithMessage("PSCI MEM_PROTECT events not recorded")
+            .that(values.size()).isGreaterThan(2);
+
+        assertWithMessage("PSCI MEM_PROTECT counter not starting from 0")
+            .that(values.get(0)).isEqualTo(0);
+
+        assertWithMessage("PSCI MEM_PROTECT counter not ending with 0")
+            .that(values.get(values.size() - 1)).isEqualTo(0);
+
+        assertWithMessage("PSCI MEM_PROTECT counter didn't increment")
+            .that(Collections.max(values)).isGreaterThan(0);
+    }
+
+    @Test
     public void testCameraAppStartupTime() throws Exception {
         String[] launchIntentPackages = {
             "com.android.camera2",
diff --git a/tests/hostside/helper/java/com/android/microdroid/test/host/KvmHypTracer.java b/tests/hostside/helper/java/com/android/microdroid/test/host/KvmHypTracer.java
index 06b3624..5c72358 100644
--- a/tests/hostside/helper/java/com/android/microdroid/test/host/KvmHypTracer.java
+++ b/tests/hostside/helper/java/com/android/microdroid/test/host/KvmHypTracer.java
@@ -30,6 +30,7 @@
 import java.text.ParseException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.List;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 import javax.annotation.Nonnull;
@@ -135,15 +136,21 @@
             cmd += "CPU" + i + "_TRACE_PIPE_PID=$!;";
         }
 
+        String cmd_script = mRunner.run("mktemp -t cmd_script.XXXXXXXXXX");
+        mRunner.run("echo '" + payload_cmd + "' > " + cmd_script);
+
         /* Run the payload with tracing enabled */
         cmd += "echo 1 > tracing_on;";
         String cmd_stdout = mRunner.run("mktemp -t cmd_stdout.XXXXXXXXXX");
-        cmd += payload_cmd + " > " + cmd_stdout + ";";
+        cmd += "sh " + cmd_script + " > " + cmd_stdout + ";";
         cmd += "echo 0 > tracing_on;";
 
-        /* Actively kill the cat subprocesses as trace_pipe is blocking */
-        for (int i = 0; i < mNrCpus; i++)
+        /* Wait for cat to finish reading the pipe interface before killing it */
+        for (int i = 0; i < mNrCpus; i++) {
+            cmd += "while $(test '$(ps -o S -p $CPU" + i
+                + "_TRACE_PIPE_PID | tail -n 1)' = 'R'); do sleep 1; done;";
             cmd += "kill -9 $CPU" + i + "_TRACE_PIPE_PID;";
+        }
         cmd += "wait";
 
         /*
@@ -155,6 +162,8 @@
          */
         mRunner.run(cmd);
 
+        mRunner.run("rm -f " + cmd_script);
+
         for (String t: trace_pipes) {
             File trace = mDevice.pullFile(t);
             assertNotNull(trace);
@@ -201,13 +210,10 @@
         for (File trace: mTraces) {
             BufferedReader br = new BufferedReader(new FileReader(trace));
             double last = 0.0, hyp_enter = 0.0;
-            String l, prev_event = "";
-            while ((l = br.readLine()) != null) {
-                KvmHypEvent hypEvent = new KvmHypEvent(l);
+            String prev_event = "";
+            KvmHypEvent hypEvent;
 
-                if (!hypEvent.valid)
-                    continue;
-
+            while ((hypEvent = getNextEvent(br)) != null) {
                 int cpu = hypEvent.cpu;
                 if (cpu < 0 || cpu >= mNrCpus)
                     throw new ParseException("Incorrect CPU number: " + cpu, 0);
@@ -219,8 +225,8 @@
 
                 String event = hypEvent.name;
                 if (event.equals(prev_event)) {
-                    throw new ParseException("Hyp event found twice in a row: " + trace + " - " + l,
-                                             0);
+                    throw new ParseException("Hyp event found twice in a row: " +
+                                             trace + " - " + hypEvent, 0);
                 }
 
                 switch (event) {
@@ -232,7 +238,7 @@
                         hyp_enter = cur;
                         break;
                     default:
-                        throw new ParseException("Unexpected line in trace " + l, 0);
+                        throw new ParseException("Unexpected line in trace " + hypEvent, 0);
                 }
                 prev_event = event;
             }
@@ -240,4 +246,55 @@
 
         return stats;
     }
+
+    public List<Integer> getPsciMemProtect() throws Exception {
+        String[] reqEvents = {"psci_mem_protect"};
+        List<Integer> psciMemProtect = new ArrayList<>();
+
+        assertWithMessage("KvmHypTracer() is missing events " + String.join(",", reqEvents))
+            .that(hasEvents(reqEvents)).isTrue();
+
+        BufferedReader[] brs = new BufferedReader[mTraces.size()];
+        KvmHypEvent[] next = new KvmHypEvent[mTraces.size()];
+
+        for (int i = 0; i < mTraces.size(); i++) {
+            brs[i] = new BufferedReader(new FileReader(mTraces.get(i)));
+            next[i] = getNextEvent(brs[i]);
+        }
+
+        while (true) {
+            double oldest = Double.MAX_VALUE;
+            int oldestIdx = -1;
+
+            for (int i = 0; i < mTraces.size(); i ++) {
+                if ((next[i] != null) && (next[i].timestamp < oldest)) {
+                    oldest = next[i].timestamp;
+                    oldestIdx = i;
+                }
+            }
+
+            if (oldestIdx < 0)
+                break;
+
+            Pattern pattern = Pattern.compile(
+                "count=([0-9]*) was=([0-9]*)");
+            Matcher matcher = pattern.matcher(next[oldestIdx].args);
+            if (!matcher.find()) {
+                throw new ParseException("Unexpected psci_mem_protect event: " +
+                                         next[oldestIdx], 0);
+            }
+
+            int count = Integer.parseInt(matcher.group(1));
+            int was = Integer.parseInt(matcher.group(2));
+
+            if (psciMemProtect.isEmpty()) {
+                psciMemProtect.add(was);
+            }
+
+            psciMemProtect.add(count);
+            next[oldestIdx] = getNextEvent(brs[oldestIdx]);
+        }
+
+        return psciMemProtect;
+    }
 }