Refresh & retry Sk requests in case of failure.

While reading/writing payload's data in Secretkeeper, the session may
have expired. Refresh the session, and retry.

Test: #rollbackProtectedDataCanBeAccessedPostConnectionExpiration
Bug: 389083566
Change-Id: I7c224c46f41081a3ae555254e015870d5b2eb911
diff --git a/guest/microdroid_manager/src/vm_secret.rs b/guest/microdroid_manager/src/vm_secret.rs
index 04d6817..5489751 100644
--- a/guest/microdroid_manager/src/vm_secret.rs
+++ b/guest/microdroid_manager/src/vm_secret.rs
@@ -167,7 +167,11 @@
             return Err(anyhow!("Rollback protected data is not available with V1 secrets"));
         };
         let payload_id = sha::sha512(instance_id);
-        secretkeeper_session.get_secret(payload_id)
+        secretkeeper_session.get_secret(payload_id).or_else(|e| {
+            log::info!("Secretkeeper get failed with {e:?}. Refreshing connection & retrying!");
+            secretkeeper_session.refresh()?;
+            secretkeeper_session.get_secret(payload_id)
+        })
     }
 
     pub fn write_payload_data_rp(&self, data: &[u8; SECRET_SIZE]) -> Result<()> {
@@ -176,7 +180,12 @@
             return Err(anyhow!("Rollback protected data is not available with V1 secrets"));
         };
         let payload_id = sha::sha512(instance_id);
-        secretkeeper_session.store_secret(payload_id, data)
+        if let Err(e) = secretkeeper_session.store_secret(payload_id, data.clone()) {
+            log::info!("Secretkeeper store failed with {e:?}. Refreshing connection & retrying!");
+            secretkeeper_session.refresh()?;
+            secretkeeper_session.store_secret(payload_id, data)?;
+        }
+        Ok(())
     }
 }
 
@@ -272,6 +281,11 @@
         Ok(Self { session, sealing_policy })
     }
 
+    fn refresh(&self) -> Result<()> {
+        let mut session = self.session.lock().unwrap();
+        Ok(session.refresh()?)
+    }
+
     fn store_secret(&self, id: [u8; ID_SIZE], secret: Zeroizing<[u8; SECRET_SIZE]>) -> Result<()> {
         let store_request = StoreSecretRequest {
             id: Id(id),
diff --git a/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java b/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java
index 797214c..2c940c9 100644
--- a/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java
+++ b/tests/testapk/src/java/com/android/microdroid/test/MicrodroidTests.java
@@ -119,6 +119,7 @@
 import java.util.OptionalLong;
 import java.util.UUID;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
@@ -1909,7 +1910,7 @@
                         (ts, tr) -> {
                             tr.mPayloadRpData = ts.insecurelyReadPayloadRpData();
                         });
-        // ainsecurelyReadPayloadRpData()` must've failed since no data was ever written!
+        // `insecurelyReadPayloadRpData()` must've failed since no data was ever written!
         assertWithMessage("The read (unexpectedly) succeeded!")
                 .that(testResults.mException)
                 .isNotNull();
@@ -1940,6 +1941,62 @@
     }
 
     @Test
+    public void rollbackProtectedDataCanBeAccessedPostConnectionExpiration() throws Exception {
+        final long vmSize = minMemoryRequired();
+        // The reference implementation of Secretkeeper maintains 4 live session keys,
+        // dropping the oldest one when new connections are requested. Therefore we spin 8 VMs
+        // asynchronously.
+        // Within a VM, wait for 5 sec (> Microdroid boot time) and trigger rp data access
+        // hoping at least some of the connection between VM <-> Secretkeeper are expired.
+        final int numVMs = 8;
+        final long availableMem = getAvailableMemory();
+
+        // Let's not use more than half of the available memory
+        assume().withMessage("Available memory (" + availableMem + " bytes) too small")
+                .that((numVMs * vmSize) <= (availableMem / 2))
+                .isTrue();
+
+        VirtualMachineConfig config =
+                newVmConfigBuilderWithPayloadBinary("MicrodroidTestNativeLib.so")
+                        .setDebugLevel(DEBUG_LEVEL_FULL)
+                        .setMemoryBytes(vmSize)
+                        .build();
+        byte[] data = new byte[32];
+        Arrays.fill(data, (byte) 0xcc);
+
+        CompletableFuture<TestResults>[] resultFutureList = new CompletableFuture[numVMs];
+        for (int i = 0; i < numVMs; i++) {
+            final VirtualMachine vm =
+                    forceCreateNewVirtualMachine("test_sk_session_expiration_vm_" + i, config);
+            resultFutureList[i] =
+                    CompletableFuture.supplyAsync(
+                            () -> {
+                                try {
+                                    TestResults testResults =
+                                            runVmTestService(
+                                                    TAG,
+                                                    vm,
+                                                    (ts, tr) -> {
+                                                        ts.insecurelyWritePayloadRpData(data);
+                                                        Thread.sleep(5 * 1000); // 5 seconds of wait
+                                                        tr.mPayloadRpData =
+                                                                ts.insecurelyReadPayloadRpData();
+                                                    });
+                                    return testResults;
+                                } catch (Exception e) {
+                                    throw new CompletionException(e);
+                                }
+                            });
+        }
+
+        for (int i = 0; i < numVMs; i++) {
+            TestResults testResult = resultFutureList[i].get();
+            testResult.assertNoException();
+            assertThat(testResult.mPayloadRpData).isEqualTo(data);
+        }
+    }
+
+    @Test
     @CddTest
     public void canReadFileFromAssets_debugFull() throws Exception {
         assumeSupportedDevice();