Merge changes from topic "sk_vm_session_refresh" into main

* changes:
  Benchmark rollbackProtectedSecret API
  Refresh & retry Sk requests in case of failure.
  VM may be in stop state when payload finishes
diff --git a/guest/microdroid_manager/src/vm_secret.rs b/guest/microdroid_manager/src/vm_secret.rs
index 56b3482..5999122 100644
--- a/guest/microdroid_manager/src/vm_secret.rs
+++ b/guest/microdroid_manager/src/vm_secret.rs
@@ -171,7 +171,11 @@
             return Err(anyhow!("Rollback protected data is not available with V1 secrets"));
         };
         let payload_id = sha::sha512(instance_id);
-        secretkeeper_session.get_secret(payload_id)
+        secretkeeper_session.get_secret(payload_id).or_else(|e| {
+            log::info!("Secretkeeper get failed with {e:?}. Refreshing connection & retrying!");
+            secretkeeper_session.refresh()?;
+            secretkeeper_session.get_secret(payload_id)
+        })
     }
 
     pub fn write_payload_data_rp(&self, data: &[u8; SECRET_SIZE]) -> Result<()> {
@@ -180,7 +184,12 @@
             return Err(anyhow!("Rollback protected data is not available with V1 secrets"));
         };
         let payload_id = sha::sha512(instance_id);
-        secretkeeper_session.store_secret(payload_id, data)
+        if let Err(e) = secretkeeper_session.store_secret(payload_id, data.clone()) {
+            log::info!("Secretkeeper store failed with {e:?}. Refreshing connection & retrying!");
+            secretkeeper_session.refresh()?;
+            secretkeeper_session.store_secret(payload_id, data)?;
+        }
+        Ok(())
     }
 }
 
@@ -276,6 +285,11 @@
         Ok(Self { session, sealing_policy })
     }
 
+    fn refresh(&self) -> Result<()> {
+        let mut session = self.session.lock().unwrap();
+        Ok(session.refresh()?)
+    }
+
     fn store_secret(&self, id: [u8; ID_SIZE], secret: Zeroizing<[u8; SECRET_SIZE]>) -> Result<()> {
         let store_request = StoreSecretRequest {
             id: Id(id),
diff --git a/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java b/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java
index 109c5e0..1d827b9 100644
--- a/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java
+++ b/tests/benchmark/src/java/com/android/microdroid/benchmark/MicrodroidBenchmarks.java
@@ -27,6 +27,8 @@
 import static com.google.common.truth.Truth.assertWithMessage;
 import static com.google.common.truth.TruthJUnit.assume;
 
+import static org.junit.Assume.assumeTrue;
+
 import android.app.Application;
 import android.app.Instrumentation;
 import android.content.ComponentCallbacks2;
@@ -69,13 +71,16 @@
 import java.io.Writer;
 import java.nio.file.Files;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.OptionalLong;
+import java.util.Random;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
 
@@ -927,4 +932,169 @@
         }
         reportMetrics(requestAttestationTime, "request_attestation_time", "microsecond");
     }
+
+    List<Double> rpDataAccessWithExistingSession(boolean measureWrite) throws Exception {
+        assumeTrue(
+                "Rollback protected secrets are only available in Updatable VMs",
+                isUpdatableVmSupported());
+        final int NUM_WARMUPS = 10;
+        final int NUM_REQUESTS = 10_000;
+
+        VirtualMachineConfig config =
+                newVmConfigBuilderWithPayloadBinary("MicrodroidTestNativeLib.so")
+                        .setDebugLevel(DEBUG_LEVEL_NONE)
+                        .build();
+
+        byte[] data = new byte[32];
+        Arrays.fill(data, (byte) 0xcc);
+
+        List<Double> requestLatencies = new ArrayList<>(NUM_REQUESTS);
+        VirtualMachine vm = forceCreateNewVirtualMachine("rp_data_access", config);
+        TestResults testResult =
+                runVmTestService(
+                        TAG,
+                        vm,
+                        (ts, tr) -> {
+                            tr.mTimings = new long[NUM_REQUESTS];
+                            for (int i = 0; i < NUM_WARMUPS; i++) {
+                                ts.insecurelyWritePayloadRpData(data);
+                                ts.insecurelyReadPayloadRpData();
+                            }
+                            for (int i = 0; i < NUM_REQUESTS; i++) {
+                                long start = System.nanoTime();
+                                if (measureWrite) {
+                                    ts.insecurelyWritePayloadRpData(data);
+                                    tr.mTimings[i] = System.nanoTime() - start;
+                                } else {
+                                    tr.mPayloadRpData = ts.insecurelyReadPayloadRpData();
+                                    tr.mTimings[i] = System.nanoTime() - start;
+                                    assertThat(tr.mPayloadRpData).isEqualTo(data);
+                                }
+                            }
+                        });
+        // Correctness check.
+        testResult.assertNoException();
+        for (long timings : testResult.mTimings) {
+            requestLatencies.add((double) timings / NANO_TO_MICRO);
+        }
+        return requestLatencies;
+    }
+
+    @Test
+    public void rpDataReadWithExistingSession() throws Exception {
+        reportMetrics(
+                rpDataAccessWithExistingSession(false),
+                "latency/readRollbackProtectedSecretWithExistingSession",
+                "us");
+    }
+
+    @Test
+    public void rpDataWriteWithExistingSession() throws Exception {
+        reportMetrics(
+                rpDataAccessWithExistingSession(true),
+                "latency/writeRollbackProtectedSecretWithExistingSession",
+                "us");
+    }
+
+    List<Double> rpDataAccessWithRefreshingSession(boolean measureWrite) throws Exception {
+        assumeTrue(
+                "Rollback protected secrets are only available in Updatable VMs",
+                isUpdatableVmSupported());
+        final long vmSize = minMemoryRequired();
+        final int numVMs = 8;
+        final int NUM_REQUESTS = 10;
+        final long availableMem = getAvailableMemory();
+
+        // Let's not use more than half of the available memory
+        assume().withMessage("Available memory (" + availableMem + " bytes) too small")
+                .that((numVMs * vmSize) <= (availableMem / 2))
+                .isTrue();
+
+        VirtualMachineConfig config =
+                newVmConfigBuilderWithPayloadBinary("MicrodroidTestNativeLib.so")
+                        .setDebugLevel(DEBUG_LEVEL_FULL)
+                        .setMemoryBytes(vmSize)
+                        .build();
+
+        byte[] data = new byte[32];
+        Arrays.fill(data, (byte) 0xcc);
+
+        List<Double> requestLatencies = new ArrayList<>(numVMs * NUM_REQUESTS);
+        CompletableFuture<TestResults>[] resultFutureList = new CompletableFuture[numVMs];
+        RunTestsAgainstTestService testToRun =
+                (ts, tr) -> {
+                    tr.mTimings = new long[NUM_REQUESTS];
+                    // Warm up request!
+                    ts.insecurelyWritePayloadRpData(data);
+                    for (int j = 0; j < NUM_REQUESTS; j++) {
+                        // Sleep time between 2 requests.
+                        // Randomized
+                        // between 200ms-300ms.
+                        long rnd_sleep_time = (long) (200.0 + new Random().nextDouble() * 100);
+                        Thread.sleep(rnd_sleep_time); // Sleep
+                        long start = System.nanoTime();
+                        if (measureWrite) {
+                            // Write
+                            ts.insecurelyWritePayloadRpData(data);
+                            tr.mTimings[j] = System.nanoTime() - start;
+
+                        } else {
+                            tr.mPayloadRpData = ts.insecurelyReadPayloadRpData();
+                            tr.mTimings[j] = System.nanoTime() - start;
+                            assertThat(tr.mPayloadRpData).isEqualTo(data);
+                        }
+                    }
+                };
+        for (int i = 0; i < numVMs; i++) {
+            final VirtualMachine vm =
+                    forceCreateNewVirtualMachine("rp_data_access_refresh" + i, config);
+            resultFutureList[i] =
+                    CompletableFuture.supplyAsync(
+                            () -> {
+                                try {
+                                    TestResults testResult = runVmTestService(TAG, vm, testToRun);
+                                    // Correctness check.
+                                    testResult.assertNoException();
+                                    return testResult;
+                                } catch (Exception e) {
+                                    throw new CompletionException(e);
+                                }
+                            });
+        }
+
+        for (int i = 0; i < numVMs; i++) {
+            TestResults tr = resultFutureList[i].get();
+            tr.assertNoException();
+            for (long timings : tr.mTimings) {
+                requestLatencies.add((double) timings / NANO_TO_MICRO);
+            }
+        }
+        return requestLatencies;
+    }
+
+    // The following benchmarks cover the case where the payload accesses the rollback protected
+    // secret but there is no existing session with Secretkeeper - which can happen when several
+    // VMs are attempting to establish a connection.
+    //
+    // Implementation detail of the API in such a scenario: Microdroid attempts to access the
+    // secret from Secretkeeper -> gets an error ("UnknownKeyId") -> refreshes the session (this
+    // includes several calls to the AuthGraph key exchange HAL) -> retries the access.
+    //
+    // Essentially this latency is (failed Secretkeeper access from pVM + AuthGraph key exchange
+    // protocol between pVM & Secretkeeper + successful Secretkeeper access from pVM).
+    @Test
+    public void rpDataReadWithRefreshingSession() throws Exception {
+        reportMetrics(
+                rpDataAccessWithRefreshingSession(false),
+                "latency/readRollbackProtectedSecretWithRefreshSession",
+                "us");
+    }
+
+    @Test
+    public void rpDataWriteWithRefreshingSession() throws Exception {
+        reportMetrics(
+                rpDataAccessWithRefreshingSession(true),
+                "latency/writeRollbackProtectedSecretWithRefreshSession",
+                "us");
+    }
 }
diff --git a/tests/helper/src/java/com/android/microdroid/test/device/MicrodroidDeviceTestBase.java b/tests/helper/src/java/com/android/microdroid/test/device/MicrodroidDeviceTestBase.java
index c05fb0b..94f7ced 100644
--- a/tests/helper/src/java/com/android/microdroid/test/device/MicrodroidDeviceTestBase.java
+++ b/tests/helper/src/java/com/android/microdroid/test/device/MicrodroidDeviceTestBase.java
@@ -27,9 +27,11 @@
 import static org.junit.Assume.assumeFalse;
 import static org.junit.Assume.assumeTrue;
 
+import android.app.ActivityManager;
 import android.app.Instrumentation;
 import android.app.UiAutomation;
 import android.content.Context;
+import android.os.Build;
 import android.os.ParcelFileDescriptor;
 import android.os.SystemProperties;
 import android.system.Os;
@@ -79,6 +81,10 @@
                                     "microdroid_16k",
                                     "microdroid_gki-android15-6.6")));
 
+    private static final long ONE_MEBI = 1024 * 1024;
+    private static final long MIN_MEM_ARM64 = 170 * ONE_MEBI;
+    private static final long MIN_MEM_X86_64 = 196 * ONE_MEBI;
+
     public static boolean isCuttlefish() {
         return getDeviceProperties().isCuttlefish();
     }
@@ -393,6 +399,10 @@
             return mProcessedBootTimeMetrics;
         }
 
+        // Stopping a virtual machine is like pulling the plug on a real computer. VM may be left in
+        // an inconsistent state.
+        // For a graceful shutdown, request the payload to call {@code exit()} and wait for
+        // {@link VirtualMachineCallback#onPayloadFinished} to be called.
         protected void forceStop(VirtualMachine vm) {
             try {
                 vm.stop();
@@ -722,7 +732,6 @@
                     public void onPayloadFinished(VirtualMachine vm, int exitCode) {
                         Log.i(logTag, "onPayloadFinished: " + exitCode);
                         payloadFinished.complete(true);
-                        forceStop(vm);
                     }
                 };
 
@@ -733,6 +742,26 @@
         return testResults;
     }
 
+    protected long getAvailableMemory() {
+        ActivityManager am = getContext().getSystemService(ActivityManager.class);
+        ActivityManager.MemoryInfo memoryInfo = new ActivityManager.MemoryInfo();
+        am.getMemoryInfo(memoryInfo);
+        return memoryInfo.availMem;
+    }
+
+    protected long minMemoryRequired() {
+        assertThat(Build.SUPPORTED_ABIS).isNotEmpty();
+        String primaryAbi = Build.SUPPORTED_ABIS[0];
+        switch (primaryAbi) {
+            case "x86_64":
+                return MIN_MEM_X86_64;
+            case "arm64-v8a":
+            case "arm64-v8a-hwasan":
+                return MIN_MEM_ARM64;
+        }
+        throw new AssertionError("Unsupported ABI: " + primaryAbi);
+    }
+
     @FunctionalInterface
     protected interface RunTestsAgainstTestService {
         void runTests(ITestService testService, TestResults testResults) throws Exception;
diff --git a/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java b/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java
index a2b4747..3aaed5e 100644
--- a/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java
+++ b/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java
@@ -39,7 +39,6 @@
 import static java.nio.file.StandardCopyOption.REPLACE_EXISTING;
 import static java.util.stream.Collectors.toList;
 
-import android.app.ActivityManager;
 import android.app.Instrumentation;
 import android.app.UiAutomation;
 import android.content.ComponentName;
@@ -47,7 +46,6 @@
 import android.content.ContextWrapper;
 import android.content.Intent;
 import android.content.ServiceConnection;
-import android.os.Build;
 import android.os.IBinder;
 import android.os.Parcel;
 import android.os.ParcelFileDescriptor;
@@ -119,6 +117,7 @@
 import java.util.OptionalLong;
 import java.util.UUID;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
@@ -171,11 +170,6 @@
     public void tearDown() {
         revokePermission(VirtualMachine.USE_CUSTOM_VIRTUAL_MACHINE_PERMISSION);
     }
-
-    private static final long ONE_MEBI = 1024 * 1024;
-
-    private static final long MIN_MEM_ARM64 = 170 * ONE_MEBI;
-    private static final long MIN_MEM_X86_64 = 196 * ONE_MEBI;
     private static final String EXAMPLE_STRING = "Literally any string!! :)";
 
     private static final String VM_SHARE_APP_PACKAGE_NAME = "com.android.microdroid.vmshare_app";
@@ -1900,7 +1894,7 @@
                         (ts, tr) -> {
                             tr.mPayloadRpData = ts.insecurelyReadPayloadRpData();
                         });
-        // ainsecurelyReadPayloadRpData()` must've failed since no data was ever written!
+        // `insecurelyReadPayloadRpData()` must've failed since no data was ever written!
         assertWithMessage("The read (unexpectedly) succeeded!")
                 .that(testResults.mException)
                 .isNotNull();
@@ -1931,6 +1925,62 @@
     }
 
     @Test
+    public void rollbackProtectedDataCanBeAccessedPostConnectionExpiration() throws Exception {
+        final long vmSize = minMemoryRequired();
+        // The reference implementation of Secretkeeper maintains 4 live session keys,
+        // dropping the oldest one when new connections are requested. Therefore we spin up 8 VMs
+        // asynchronously.
+        // Within each VM, wait for 5 sec (> Microdroid boot time) and trigger rp data access,
+        // hoping that at least some of the VM <-> Secretkeeper connections have expired.
+        final int numVMs = 8;
+        final long availableMem = getAvailableMemory();
+
+        // Let's not use more than half of the available memory
+        assume().withMessage("Available memory (" + availableMem + " bytes) too small")
+                .that((numVMs * vmSize) <= (availableMem / 2))
+                .isTrue();
+
+        VirtualMachineConfig config =
+                newVmConfigBuilderWithPayloadBinary("MicrodroidTestNativeLib.so")
+                        .setDebugLevel(DEBUG_LEVEL_FULL)
+                        .setMemoryBytes(vmSize)
+                        .build();
+        byte[] data = new byte[32];
+        Arrays.fill(data, (byte) 0xcc);
+
+        CompletableFuture<TestResults>[] resultFutureList = new CompletableFuture[numVMs];
+        for (int i = 0; i < numVMs; i++) {
+            final VirtualMachine vm =
+                    forceCreateNewVirtualMachine("test_sk_session_expiration_vm_" + i, config);
+            resultFutureList[i] =
+                    CompletableFuture.supplyAsync(
+                            () -> {
+                                try {
+                                    TestResults testResults =
+                                            runVmTestService(
+                                                    TAG,
+                                                    vm,
+                                                    (ts, tr) -> {
+                                                        ts.insecurelyWritePayloadRpData(data);
+                                                        Thread.sleep(5 * 1000); // 5 seconds of wait
+                                                        tr.mPayloadRpData =
+                                                                ts.insecurelyReadPayloadRpData();
+                                                    });
+                                    return testResults;
+                                } catch (Exception e) {
+                                    throw new CompletionException(e);
+                                }
+                            });
+        }
+
+        for (int i = 0; i < numVMs; i++) {
+            TestResults testResult = resultFutureList[i].get();
+            testResult.assertNoException();
+            assertThat(testResult.mPayloadRpData).isEqualTo(data);
+        }
+    }
+
+    @Test
     @CddTest
     public void isNewInstanceTest() throws Exception {
         assumeSupportedDevice();
@@ -2771,13 +2821,6 @@
         }
     }
 
-    private long getAvailableMemory() {
-        ActivityManager am = getContext().getSystemService(ActivityManager.class);
-        ActivityManager.MemoryInfo memoryInfo = new ActivityManager.MemoryInfo();
-        am.getMemoryInfo(memoryInfo);
-        return memoryInfo.availMem;
-    }
-
     private VirtualMachineDescriptor toParcelFromParcel(VirtualMachineDescriptor descriptor) {
         Parcel parcel = Parcel.obtain();
         descriptor.writeToParcel(parcel, 0);
@@ -2810,17 +2853,4 @@
         Exception e = assertThrows(VirtualMachineException.class, runnable);
         assertThat(e).hasMessageThat().contains(expectedContents);
     }
-
-    private long minMemoryRequired() {
-        assertThat(Build.SUPPORTED_ABIS).isNotEmpty();
-        String primaryAbi = Build.SUPPORTED_ABIS[0];
-        switch (primaryAbi) {
-            case "x86_64":
-                return MIN_MEM_X86_64;
-            case "arm64-v8a":
-            case "arm64-v8a-hwasan":
-                return MIN_MEM_ARM64;
-        }
-        throw new AssertionError("Unsupported ABI: " + primaryAbi);
-    }
 }