Implement timeout mechanism to wait for network connectivity.

This includes:
Network connectivity callback with timeout.
Timeout for entire retry mechanism

Test: manual atest and manual testing on pixel 6
Bug: 231660348

Change-Id: I4b9f545e95b43826c3ba7b11467dd9d9e600f9f5
diff --git a/services/core/java/com/android/server/locksettings/RebootEscrowManager.java b/services/core/java/com/android/server/locksettings/RebootEscrowManager.java
index 4d525da..9b42cfc 100644
--- a/services/core/java/com/android/server/locksettings/RebootEscrowManager.java
+++ b/services/core/java/com/android/server/locksettings/RebootEscrowManager.java
@@ -36,6 +36,7 @@
 import android.net.ConnectivityManager;
 import android.net.Network;
 import android.net.NetworkCapabilities;
+import android.net.NetworkRequest;
 import android.os.Handler;
 import android.os.PowerManager;
 import android.os.SystemClock;
@@ -103,12 +104,12 @@
 
     /**
      * Number of boots until we consider the escrow data to be stale for the purposes of metrics.
-     * <p>
-     * If the delta between the current boot number and the boot number stored when the mechanism
+     *
+     * <p>If the delta between the current boot number and the boot number stored when the mechanism
      * was armed is under this number and the escrow mechanism fails, we report it as a failure of
      * the mechanism.
-     * <p>
-     * If the delta over this number and escrow fails, we will not report the metric as failed
+     *
+     * <p>If the delta over this number and escrow fails, we will not report the metric as failed
      * since there most likely was some other issue if the device rebooted several times before
      * getting to the escrow restore code.
      */
@@ -120,8 +121,11 @@
      */
     private static final int DEFAULT_LOAD_ESCROW_DATA_RETRY_COUNT = 3;
     private static final int DEFAULT_LOAD_ESCROW_DATA_RETRY_INTERVAL_SECONDS = 30;
+
     // 3 minutes. It's enough for the default 3 retries with 30 seconds interval
-    private static final int DEFAULT_WAKE_LOCK_TIMEOUT_MILLIS = 180_000;
+    private static final int DEFAULT_LOAD_ESCROW_BASE_TIMEOUT_MILLIS = 180_000;
+    // 5 seconds. An extension of the overall RoR timeout to account for overhead.
+    private static final int DEFAULT_LOAD_ESCROW_TIMEOUT_EXTENSION_MILLIS = 5000;
 
     @IntDef(prefix = {"ERROR_"}, value = {
             ERROR_NONE,
@@ -133,6 +137,7 @@
             ERROR_PROVIDER_MISMATCH,
             ERROR_KEYSTORE_FAILURE,
             ERROR_NO_NETWORK,
+            ERROR_TIMEOUT_EXHAUSTED,
     })
     @Retention(RetentionPolicy.SOURCE)
     @interface RebootEscrowErrorCode {
@@ -147,6 +152,7 @@
     static final int ERROR_PROVIDER_MISMATCH = 6;
     static final int ERROR_KEYSTORE_FAILURE = 7;
     static final int ERROR_NO_NETWORK = 8;
+    static final int ERROR_TIMEOUT_EXHAUSTED = 9;
 
     private @RebootEscrowErrorCode int mLoadEscrowDataErrorCode = ERROR_NONE;
 
@@ -168,6 +174,15 @@
     /** Notified when mRebootEscrowReady changes. */
     private RebootEscrowListener mRebootEscrowListener;
 
+    /** Set when unlocking reboot escrow times out. */
+    private boolean mRebootEscrowTimedOut = false;
+
+    /**
+     * Set when {@link #loadRebootEscrowDataWithRetry} is called to ensure the function is only
+     * called once.
+     */
+    private boolean mLoadEscrowDataWithRetry = false;
+
     /**
      * Hold this lock when checking or generating the reboot escrow key.
      */
@@ -192,6 +207,7 @@
 
     PowerManager.WakeLock mWakeLock;
 
+    private ConnectivityManager.NetworkCallback mNetworkCallback;
 
     interface Callbacks {
         boolean isUserSecure(int userId);
@@ -246,6 +262,11 @@
                     "server_based_ror_enabled", false);
         }
 
+        public boolean waitForInternet() {
+            return DeviceConfig.getBoolean(
+                    DeviceConfig.NAMESPACE_OTA, "wait_for_internet_ror", false);
+        }
+
         public boolean isNetworkConnected() {
             final ConnectivityManager connectivityManager =
                     mContext.getSystemService(ConnectivityManager.class);
@@ -263,6 +284,38 @@
                             NetworkCapabilities.NET_CAPABILITY_VALIDATED);
         }
 
+        /**
+         * Request network with internet connectivity with timeout.
+         *
+         * @param networkCallback callback to be executed if connectivity manager exists.
+         * @return true if success
+         */
+        public boolean requestNetworkWithInternet(
+                ConnectivityManager.NetworkCallback networkCallback) {
+            final ConnectivityManager connectivityManager =
+                    mContext.getSystemService(ConnectivityManager.class);
+            if (connectivityManager == null) {
+                return false;
+            }
+            NetworkRequest request =
+                    new NetworkRequest.Builder()
+                            .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
+                            .build();
+
+            connectivityManager.requestNetwork(
+                    request, networkCallback, getLoadEscrowTimeoutMillis());
+            return true;
+        }
+
+        public void stopRequestingNetwork(ConnectivityManager.NetworkCallback networkCallback) {
+            final ConnectivityManager connectivityManager =
+                    mContext.getSystemService(ConnectivityManager.class);
+            if (connectivityManager == null) {
+                return;
+            }
+            connectivityManager.unregisterNetworkCallback(networkCallback);
+        }
+
         public Context getContext() {
             return mContext;
         }
@@ -318,6 +371,16 @@
                     DEFAULT_LOAD_ESCROW_DATA_RETRY_INTERVAL_SECONDS);
         }
 
+        @VisibleForTesting
+        public int getLoadEscrowTimeoutMillis() {
+            return DEFAULT_LOAD_ESCROW_BASE_TIMEOUT_MILLIS;
+        }
+
+        @VisibleForTesting
+        public int getWakeLockTimeoutMillis() {
+            return getLoadEscrowTimeoutMillis() + DEFAULT_LOAD_ESCROW_TIMEOUT_EXTENSION_MILLIS;
+        }
+
         public void reportMetric(boolean success, int errorCode, int serviceType, int attemptCount,
                 int escrowDurationInSeconds, int vbmetaDigestStatus,
                 int durationSinceBootCompleteInSeconds) {
@@ -351,13 +414,37 @@
         mKeyStoreManager = injector.getKeyStoreManager();
     }
 
-    private void onGetRebootEscrowKeyFailed(List<UserInfo> users, int attemptCount) {
+    /** Wrapper function to set error code serialized through handler, */
+    private void setLoadEscrowDataErrorCode(@RebootEscrowErrorCode int value, Handler handler) {
+        if (mInjector.waitForInternet()) {
+            mInjector.post(
+                    handler,
+                    () -> {
+                        mLoadEscrowDataErrorCode = value;
+                    });
+        } else {
+            mLoadEscrowDataErrorCode = value;
+        }
+    }
+
+    /** Wrapper function to compare and set error code serialized through handler. */
+    private void compareAndSetLoadEscrowDataErrorCode(
+            @RebootEscrowErrorCode int expectedValue,
+            @RebootEscrowErrorCode int newValue,
+            Handler handler) {
+        if (expectedValue == mLoadEscrowDataErrorCode) {
+            setLoadEscrowDataErrorCode(newValue, handler);
+        }
+    }
+
+    private void onGetRebootEscrowKeyFailed(
+            List<UserInfo> users, int attemptCount, Handler retryHandler) {
         Slog.w(TAG, "Had reboot escrow data for users, but no key; removing escrow storage.");
         for (UserInfo user : users) {
             mStorage.removeRebootEscrow(user.id);
         }
 
-        onEscrowRestoreComplete(false, attemptCount);
+        onEscrowRestoreComplete(false, attemptCount, retryHandler);
     }
 
     void loadRebootEscrowDataIfAvailable(Handler retryHandler) {
@@ -380,39 +467,130 @@
         mWakeLock = mInjector.getWakeLock();
         if (mWakeLock != null) {
             mWakeLock.setReferenceCounted(false);
-            mWakeLock.acquire(DEFAULT_WAKE_LOCK_TIMEOUT_MILLIS);
+            mWakeLock.acquire(mInjector.getWakeLockTimeoutMillis());
+        }
+
+        if (mInjector.waitForInternet()) {
+            // Timeout to stop retrying same as the wake lock timeout.
+            mInjector.postDelayed(
+                    retryHandler,
+                    () -> {
+                        mRebootEscrowTimedOut = true;
+                    },
+                    mInjector.getLoadEscrowTimeoutMillis());
+
+            mInjector.post(
+                    retryHandler,
+                    () -> loadRebootEscrowDataOnInternet(retryHandler, users, rebootEscrowUsers));
+            return;
         }
 
         mInjector.post(retryHandler, () -> loadRebootEscrowDataWithRetry(
                 retryHandler, 0, users, rebootEscrowUsers));
     }
 
-    void scheduleLoadRebootEscrowDataOrFail(Handler retryHandler, int attemptNumber,
-            List<UserInfo> users, List<UserInfo> rebootEscrowUsers) {
+    void scheduleLoadRebootEscrowDataOrFail(
+            Handler retryHandler,
+            int attemptNumber,
+            List<UserInfo> users,
+            List<UserInfo> rebootEscrowUsers) {
         Objects.requireNonNull(retryHandler);
 
         final int retryLimit = mInjector.getLoadEscrowDataRetryLimit();
         final int retryIntervalInSeconds = mInjector.getLoadEscrowDataRetryIntervalSeconds();
 
-        if (attemptNumber < retryLimit) {
+        if (attemptNumber < retryLimit && !mRebootEscrowTimedOut) {
             Slog.i(TAG, "Scheduling loadRebootEscrowData retry number: " + attemptNumber);
             mInjector.postDelayed(retryHandler, () -> loadRebootEscrowDataWithRetry(
-                    retryHandler, attemptNumber, users, rebootEscrowUsers),
+                            retryHandler, attemptNumber, users, rebootEscrowUsers),
                     retryIntervalInSeconds * 1000);
             return;
         }
 
+        if (mInjector.waitForInternet()) {
+            if (mRebootEscrowTimedOut) {
+                Slog.w(TAG, "Failed to load reboot escrow data within timeout");
+                compareAndSetLoadEscrowDataErrorCode(
+                        ERROR_NONE, ERROR_TIMEOUT_EXHAUSTED, retryHandler);
+            } else {
+                Slog.w(
+                        TAG,
+                        "Failed to load reboot escrow data after " + attemptNumber + " attempts");
+                compareAndSetLoadEscrowDataErrorCode(
+                        ERROR_NONE, ERROR_RETRY_COUNT_EXHAUSTED, retryHandler);
+            }
+            onGetRebootEscrowKeyFailed(users, attemptNumber, retryHandler);
+            return;
+        }
+
         Slog.w(TAG, "Failed to load reboot escrow data after " + attemptNumber + " attempts");
         if (mInjector.serverBasedResumeOnReboot() && !mInjector.isNetworkConnected()) {
             mLoadEscrowDataErrorCode = ERROR_NO_NETWORK;
         } else {
             mLoadEscrowDataErrorCode = ERROR_RETRY_COUNT_EXHAUSTED;
         }
-        onGetRebootEscrowKeyFailed(users, attemptNumber);
+        onGetRebootEscrowKeyFailed(users, attemptNumber, retryHandler);
     }
 
-    void loadRebootEscrowDataWithRetry(Handler retryHandler, int attemptNumber,
-            List<UserInfo> users, List<UserInfo> rebootEscrowUsers) {
+    void loadRebootEscrowDataOnInternet(
+            Handler retryHandler, List<UserInfo> users, List<UserInfo> rebootEscrowUsers) {
+
+        // HAL-Based RoR does not require network connectivity.
+        if (!mInjector.serverBasedResumeOnReboot()) {
+            loadRebootEscrowDataWithRetry(
+                    retryHandler, /* attemptNumber = */ 0, users, rebootEscrowUsers);
+            return;
+        }
+
+        mNetworkCallback =
+                new ConnectivityManager.NetworkCallback() {
+                    @Override
+                    public void onAvailable(Network network) {
+                        compareAndSetLoadEscrowDataErrorCode(
+                                ERROR_NO_NETWORK, ERROR_NONE, retryHandler);
+
+                        if (!mLoadEscrowDataWithRetry) {
+                            mLoadEscrowDataWithRetry = true;
+                            // Only kickoff retry mechanism on first onAvailable call.
+                            loadRebootEscrowDataWithRetry(
+                                    retryHandler,
+                                    /* attemptNumber = */ 0,
+                                    users,
+                                    rebootEscrowUsers);
+                        }
+                    }
+
+                    @Override
+                    public void onUnavailable() {
+                        Slog.w(TAG, "Failed to connect to network within timeout");
+                        compareAndSetLoadEscrowDataErrorCode(
+                                ERROR_NONE, ERROR_NO_NETWORK, retryHandler);
+                        onGetRebootEscrowKeyFailed(users, /* attemptCount= */ 0, retryHandler);
+                    }
+
+                    @Override
+                    public void onLost(Network lostNetwork) {
+                        // TODO(b/231660348): If network is lost, wait for network to become
+                        // available again.
+                        Slog.w(TAG, "Network lost, still attempting to load escrow key.");
+                        compareAndSetLoadEscrowDataErrorCode(
+                                ERROR_NONE, ERROR_NO_NETWORK, retryHandler);
+                    }
+                };
+
+        // Fallback to retrying without waiting for internet on failure.
+        boolean success = mInjector.requestNetworkWithInternet(mNetworkCallback);
+        if (!success) {
+            loadRebootEscrowDataWithRetry(
+                    retryHandler, /* attemptNumber = */ 0, users, rebootEscrowUsers);
+        }
+    }
+
+    void loadRebootEscrowDataWithRetry(
+            Handler retryHandler,
+            int attemptNumber,
+            List<UserInfo> users,
+            List<UserInfo> rebootEscrowUsers) {
         // Fetch the key from keystore to decrypt the escrow data & escrow key; this key is
         // generated before reboot. Note that we will clear the escrow key even if the keystore key
         // is null.
@@ -423,7 +601,7 @@
 
         RebootEscrowKey escrowKey;
         try {
-            escrowKey = getAndClearRebootEscrowKey(kk);
+            escrowKey = getAndClearRebootEscrowKey(kk, retryHandler);
         } catch (IOException e) {
             Slog.i(TAG, "Failed to load escrow key, scheduling retry.", e);
             scheduleLoadRebootEscrowDataOrFail(retryHandler, attemptNumber + 1, users,
@@ -438,12 +616,12 @@
                         ? RebootEscrowProviderInterface.TYPE_SERVER_BASED
                         : RebootEscrowProviderInterface.TYPE_HAL;
                 if (providerType != mStorage.getInt(REBOOT_ESCROW_KEY_PROVIDER, -1, USER_SYSTEM)) {
-                    mLoadEscrowDataErrorCode = ERROR_PROVIDER_MISMATCH;
+                    setLoadEscrowDataErrorCode(ERROR_PROVIDER_MISMATCH, retryHandler);
                 } else {
-                    mLoadEscrowDataErrorCode = ERROR_LOAD_ESCROW_KEY;
+                    setLoadEscrowDataErrorCode(ERROR_LOAD_ESCROW_KEY, retryHandler);
                 }
             }
-            onGetRebootEscrowKeyFailed(users, attemptNumber + 1);
+            onGetRebootEscrowKeyFailed(users, attemptNumber + 1, retryHandler);
             return;
         }
 
@@ -454,10 +632,10 @@
             allUsersUnlocked &= restoreRebootEscrowForUser(user.id, escrowKey, kk);
         }
 
-        if (!allUsersUnlocked && mLoadEscrowDataErrorCode == ERROR_NONE) {
-            mLoadEscrowDataErrorCode = ERROR_UNLOCK_ALL_USERS;
+        if (!allUsersUnlocked) {
+            compareAndSetLoadEscrowDataErrorCode(ERROR_NONE, ERROR_UNLOCK_ALL_USERS, retryHandler);
         }
-        onEscrowRestoreComplete(allUsersUnlocked, attemptNumber + 1);
+        onEscrowRestoreComplete(allUsersUnlocked, attemptNumber + 1, retryHandler);
     }
 
     private void clearMetricsStorage() {
@@ -497,7 +675,8 @@
                 .REBOOT_ESCROW_RECOVERY_REPORTED__VBMETA_DIGEST_STATUS__MISMATCH;
     }
 
-    private void reportMetricOnRestoreComplete(boolean success, int attemptCount) {
+    private void reportMetricOnRestoreComplete(
+            boolean success, int attemptCount, Handler retryHandler) {
         int serviceType = mInjector.serverBasedResumeOnReboot()
                 ? FrameworkStatsLog.REBOOT_ESCROW_RECOVERY_REPORTED__TYPE__SERVER_BASED
                 : FrameworkStatsLog.REBOOT_ESCROW_RECOVERY_REPORTED__TYPE__HAL;
@@ -511,52 +690,69 @@
         }
 
         int vbmetaDigestStatus = getVbmetaDigestStatusOnRestoreComplete();
-        if (!success && mLoadEscrowDataErrorCode == ERROR_NONE) {
-            mLoadEscrowDataErrorCode = ERROR_UNKNOWN;
+        if (!success) {
+            compareAndSetLoadEscrowDataErrorCode(ERROR_NONE, ERROR_UNKNOWN, retryHandler);
         }
 
-        Slog.i(TAG, "Reporting RoR recovery metrics, success: " + success + ", service type: "
-                + serviceType + ", error code: " + mLoadEscrowDataErrorCode);
+        Slog.i(
+                TAG,
+                "Reporting RoR recovery metrics, success: "
+                        + success
+                        + ", service type: "
+                        + serviceType
+                        + ", error code: "
+                        + mLoadEscrowDataErrorCode);
         // TODO(179105110) report the duration since boot complete.
-        mInjector.reportMetric(success, mLoadEscrowDataErrorCode, serviceType, attemptCount,
-                escrowDurationInSeconds, vbmetaDigestStatus, -1);
+        mInjector.reportMetric(
+                success,
+                mLoadEscrowDataErrorCode,
+                serviceType,
+                attemptCount,
+                escrowDurationInSeconds,
+                vbmetaDigestStatus,
+                -1);
 
-        mLoadEscrowDataErrorCode = ERROR_NONE;
+        setLoadEscrowDataErrorCode(ERROR_NONE, retryHandler);
     }
 
-    private void onEscrowRestoreComplete(boolean success, int attemptCount) {
+    private void onEscrowRestoreComplete(boolean success, int attemptCount, Handler retryHandler) {
         int previousBootCount = mStorage.getInt(REBOOT_ESCROW_ARMED_KEY, -1, USER_SYSTEM);
 
         int bootCountDelta = mInjector.getBootCount() - previousBootCount;
         if (success || (previousBootCount != -1 && bootCountDelta <= BOOT_COUNT_TOLERANCE)) {
-            reportMetricOnRestoreComplete(success, attemptCount);
+            reportMetricOnRestoreComplete(success, attemptCount, retryHandler);
         }
-
         // Clear the old key in keystore. A new key will be generated by new RoR requests.
         mKeyStoreManager.clearKeyStoreEncryptionKey();
         // Clear the saved reboot escrow provider
         mInjector.clearRebootEscrowProvider();
         clearMetricsStorage();
 
+        if (mNetworkCallback != null) {
+            mInjector.stopRequestingNetwork(mNetworkCallback);
+        }
+
         if (mWakeLock != null) {
             mWakeLock.release();
         }
     }
 
-    private RebootEscrowKey getAndClearRebootEscrowKey(SecretKey kk) throws IOException {
+    private RebootEscrowKey getAndClearRebootEscrowKey(SecretKey kk, Handler retryHandler)
+            throws IOException {
         RebootEscrowProviderInterface rebootEscrowProvider =
                 mInjector.createRebootEscrowProviderIfNeeded();
         if (rebootEscrowProvider == null) {
-            Slog.w(TAG,
+            Slog.w(
+                    TAG,
                     "Had reboot escrow data for users, but RebootEscrowProvider is unavailable");
-            mLoadEscrowDataErrorCode = ERROR_NO_PROVIDER;
+            setLoadEscrowDataErrorCode(ERROR_NO_PROVIDER, retryHandler);
             return null;
         }
 
         // Server based RoR always need the decryption key from keystore.
         if (rebootEscrowProvider.getType() == RebootEscrowProviderInterface.TYPE_SERVER_BASED
                 && kk == null) {
-            mLoadEscrowDataErrorCode = ERROR_KEYSTORE_FAILURE;
+            setLoadEscrowDataErrorCode(ERROR_KEYSTORE_FAILURE, retryHandler);
             return null;
         }
 
@@ -870,6 +1066,9 @@
         pw.print("mRebootEscrowListener=");
         pw.println(mRebootEscrowListener);
 
+        pw.print("mLoadEscrowDataErrorCode=");
+        pw.println(mLoadEscrowDataErrorCode);
+
         boolean keySet;
         synchronized (mKeyGenerationLock) {
             keySet = mPendingRebootEscrowKey != null;
diff --git a/services/tests/servicestests/src/com/android/server/locksettings/RebootEscrowManagerTests.java b/services/tests/servicestests/src/com/android/server/locksettings/RebootEscrowManagerTests.java
index b01c1c8..ce6bd6c 100644
--- a/services/tests/servicestests/src/com/android/server/locksettings/RebootEscrowManagerTests.java
+++ b/services/tests/servicestests/src/com/android/server/locksettings/RebootEscrowManagerTests.java
@@ -50,6 +50,8 @@
 import android.content.ContextWrapper;
 import android.content.pm.UserInfo;
 import android.hardware.rebootescrow.IRebootEscrow;
+import android.net.ConnectivityManager;
+import android.net.Network;
 import android.os.Handler;
 import android.os.HandlerThread;
 import android.os.RemoteException;
@@ -72,6 +74,7 @@
 import java.io.File;
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.function.Consumer;
 
 import javax.crypto.SecretKey;
 import javax.crypto.spec.SecretKeySpec;
@@ -113,6 +116,7 @@
     private RebootEscrowManager mService;
     private SecretKey mAesKey;
     private MockInjector mMockInjector;
+    private Handler mHandler;
 
     public interface MockableRebootEscrowInjected {
         int getBootCount();
@@ -132,6 +136,9 @@
         private final RebootEscrowKeyStoreManager mKeyStoreManager;
         private boolean mServerBased;
         private RebootEscrowProviderInterface mRebootEscrowProviderInUse;
+        private ConnectivityManager.NetworkCallback mNetworkCallback;
+        private Consumer<ConnectivityManager.NetworkCallback> mNetworkConsumer;
+        private boolean mWaitForInternet;
 
         MockInjector(Context context, UserManager userManager,
                 IRebootEscrow rebootEscrow,
@@ -142,6 +149,7 @@
             mRebootEscrow = rebootEscrow;
             mServiceConnection = null;
             mServerBased = false;
+            mWaitForInternet = false;
             RebootEscrowProviderHalImpl.Injector halInjector =
                     new RebootEscrowProviderHalImpl.Injector() {
                         @Override
@@ -164,6 +172,7 @@
             mServiceConnection = serviceConnection;
             mRebootEscrow = null;
             mServerBased = true;
+            mWaitForInternet = false;
             RebootEscrowProviderServerBasedImpl.Injector injector =
                     new RebootEscrowProviderServerBasedImpl.Injector(serviceConnection) {
                         @Override
@@ -199,11 +208,33 @@
         }
 
         @Override
+        public boolean waitForInternet() {
+            return mWaitForInternet;
+        }
+
+        public void setWaitForNetwork(boolean waitForNetworkEnabled) {
+            mWaitForInternet = waitForNetworkEnabled;
+        }
+
+        @Override
         public boolean isNetworkConnected() {
             return false;
         }
 
         @Override
+        public boolean requestNetworkWithInternet(
+                ConnectivityManager.NetworkCallback networkCallback) {
+            mNetworkCallback = networkCallback;
+            mNetworkConsumer.accept(networkCallback);
+            return true;
+        }
+
+        @Override
+        public void stopRequestingNetwork(ConnectivityManager.NetworkCallback networkCallback) {
+            mNetworkCallback = null;
+        }
+
+        @Override
         public RebootEscrowProviderInterface createRebootEscrowProviderIfNeeded() {
             mRebootEscrowProviderInUse = mDefaultRebootEscrowProvider;
             return mRebootEscrowProviderInUse;
@@ -242,6 +273,12 @@
         }
 
         @Override
+        public int getLoadEscrowTimeoutMillis() {
+            // Timeout in 3 seconds.
+            return 3000;
+        }
+
+        @Override
         public String getVbmetaDigest(boolean other) {
             return other ? "" : "fake digest";
         }
@@ -291,6 +328,9 @@
         mMockInjector = new MockInjector(mContext, mUserManager, mRebootEscrow,
                 mKeyStoreManager, mStorage, mInjected);
         mService = new RebootEscrowManager(mMockInjector, mCallbacks, mStorage);
+        HandlerThread thread = new HandlerThread("RebootEscrowManagerTest");
+        thread.start();
+        mHandler = new Handler(thread.getLooper());
     }
 
     private void setServerBasedRebootEscrowProvider() throws Exception {
@@ -462,7 +502,7 @@
 
     @Test
     public void loadRebootEscrowDataIfAvailable_NothingAvailable_Success() throws Exception {
-        mService.loadRebootEscrowDataIfAvailable(null);
+        mService.loadRebootEscrowDataIfAvailable(mHandler);
     }
 
     @Test
@@ -499,7 +539,7 @@
                 eq(20), eq(0) /* vbmeta status */, anyInt());
         when(mRebootEscrow.retrieveKey()).thenAnswer(invocation -> keyByteCaptor.getValue());
 
-        mService.loadRebootEscrowDataIfAvailable(null);
+        mService.loadRebootEscrowDataIfAvailable(mHandler);
         verify(mRebootEscrow).retrieveKey();
         assertTrue(metricsSuccessCaptor.getValue());
         verify(mKeyStoreManager).clearKeyStoreEncryptionKey();
@@ -531,9 +571,16 @@
         // pretend reboot happens here
         when(mInjected.getBootCount()).thenReturn(1);
         ArgumentCaptor<Boolean> metricsSuccessCaptor = ArgumentCaptor.forClass(Boolean.class);
-        doNothing().when(mInjected).reportMetric(metricsSuccessCaptor.capture(),
-                eq(0) /* error code */, eq(2) /* Server based */, eq(1) /* attempt count */,
-                anyInt(), eq(0) /* vbmeta status */, anyInt());
+        doNothing()
+                .when(mInjected)
+                .reportMetric(
+                        metricsSuccessCaptor.capture(),
+                        eq(0) /* error code */,
+                        eq(2) /* Server based */,
+                        eq(1) /* attempt count */,
+                        anyInt(),
+                        eq(0) /* vbmeta status */,
+                        anyInt());
 
         when(mServiceConnection.unwrap(any(), anyLong()))
                 .thenAnswer(invocation -> invocation.getArgument(0));
@@ -569,15 +616,23 @@
         when(mInjected.getBootCount()).thenReturn(1);
         ArgumentCaptor<Boolean> metricsSuccessCaptor = ArgumentCaptor.forClass(Boolean.class);
         ArgumentCaptor<Integer> metricsErrorCodeCaptor = ArgumentCaptor.forClass(Integer.class);
-        doNothing().when(mInjected).reportMetric(metricsSuccessCaptor.capture(),
-                metricsErrorCodeCaptor.capture(), eq(2) /* Server based */,
-                eq(1) /* attempt count */, anyInt(), eq(0) /* vbmeta status */, anyInt());
+        doNothing()
+                .when(mInjected)
+                .reportMetric(
+                        metricsSuccessCaptor.capture(),
+                        metricsErrorCodeCaptor.capture(),
+                        eq(2) /* Server based */,
+                        eq(1) /* attempt count */,
+                        anyInt(),
+                        eq(0) /* vbmeta status */,
+                        anyInt());
 
         when(mServiceConnection.unwrap(any(), anyLong())).thenThrow(RemoteException.class);
         mService.loadRebootEscrowDataIfAvailable(null);
         verify(mServiceConnection).unwrap(any(), anyLong());
         assertFalse(metricsSuccessCaptor.getValue());
-        assertEquals(Integer.valueOf(RebootEscrowManager.ERROR_LOAD_ESCROW_KEY),
+        assertEquals(
+                Integer.valueOf(RebootEscrowManager.ERROR_LOAD_ESCROW_KEY),
                 metricsErrorCodeCaptor.getValue());
     }
 
@@ -606,18 +661,24 @@
         when(mInjected.getBootCount()).thenReturn(1);
         ArgumentCaptor<Boolean> metricsSuccessCaptor = ArgumentCaptor.forClass(Boolean.class);
         ArgumentCaptor<Integer> metricsErrorCodeCaptor = ArgumentCaptor.forClass(Integer.class);
-        doNothing().when(mInjected).reportMetric(metricsSuccessCaptor.capture(),
-                metricsErrorCodeCaptor.capture(), eq(2) /* Server based */,
-                eq(2) /* attempt count */, anyInt(), eq(0) /* vbmeta status */, anyInt());
+        doNothing()
+                .when(mInjected)
+                .reportMetric(
+                        metricsSuccessCaptor.capture(),
+                        metricsErrorCodeCaptor.capture(),
+                        eq(2) /* Server based */,
+                        eq(2) /* attempt count */,
+                        anyInt(),
+                        eq(0) /* vbmeta status */,
+                        anyInt());
         when(mServiceConnection.unwrap(any(), anyLong())).thenThrow(IOException.class);
 
-        HandlerThread thread = new HandlerThread("RebootEscrowManagerTest");
-        thread.start();
-        mService.loadRebootEscrowDataIfAvailable(new Handler(thread.getLooper()));
+        mService.loadRebootEscrowDataIfAvailable(mHandler);
         // Sleep 5s for the retry to complete
         Thread.sleep(5 * 1000);
         assertFalse(metricsSuccessCaptor.getValue());
-        assertEquals(Integer.valueOf(RebootEscrowManager.ERROR_NO_NETWORK),
+        assertEquals(
+                Integer.valueOf(RebootEscrowManager.ERROR_NO_NETWORK),
                 metricsErrorCodeCaptor.getValue());
     }
 
@@ -645,16 +706,22 @@
         // pretend reboot happens here
         when(mInjected.getBootCount()).thenReturn(1);
         ArgumentCaptor<Boolean> metricsSuccessCaptor = ArgumentCaptor.forClass(Boolean.class);
-        doNothing().when(mInjected).reportMetric(metricsSuccessCaptor.capture(),
-                anyInt(), anyInt(), eq(2) /* attempt count */, anyInt(), anyInt(), anyInt());
+        doNothing()
+                .when(mInjected)
+                .reportMetric(
+                        metricsSuccessCaptor.capture(),
+                        anyInt(),
+                        anyInt(),
+                        eq(2) /* attempt count */,
+                        anyInt(),
+                        anyInt(),
+                        anyInt());
 
         when(mServiceConnection.unwrap(any(), anyLong()))
                 .thenThrow(new IOException())
                 .thenAnswer(invocation -> invocation.getArgument(0));
 
-        HandlerThread thread = new HandlerThread("RebootEscrowManagerTest");
-        thread.start();
-        mService.loadRebootEscrowDataIfAvailable(new Handler(thread.getLooper()));
+        mService.loadRebootEscrowDataIfAvailable(mHandler);
         // Sleep 5s for the retry to complete
         Thread.sleep(5 * 1000);
         verify(mServiceConnection, times(2)).unwrap(any(), anyLong());
@@ -663,6 +730,447 @@
     }
 
     @Test
+    public void loadRebootEscrowDataIfAvailable_serverBasedWaitForInternet_success()
+            throws Exception {
+        setServerBasedRebootEscrowProvider();
+        mMockInjector.setWaitForNetwork(true);
+
+        when(mInjected.getBootCount()).thenReturn(0);
+        RebootEscrowListener mockListener = mock(RebootEscrowListener.class);
+        mService.setRebootEscrowListener(mockListener);
+        mService.prepareRebootEscrow();
+
+        clearInvocations(mServiceConnection);
+        mService.callToRebootEscrowIfNeeded(PRIMARY_USER_ID, FAKE_SP_VERSION, FAKE_AUTH_TOKEN);
+        verify(mockListener).onPreparedForReboot(eq(true));
+        verify(mServiceConnection, never()).wrapBlob(any(), anyLong(), anyLong());
+
+        // Use x -> x for both wrap & unwrap functions.
+        when(mServiceConnection.wrapBlob(any(), anyLong(), anyLong()))
+                .thenAnswer(invocation -> invocation.getArgument(0));
+        assertEquals(ARM_REBOOT_ERROR_NONE, mService.armRebootEscrowIfNeeded());
+        verify(mServiceConnection).wrapBlob(any(), anyLong(), anyLong());
+        assertTrue(mStorage.hasRebootEscrowServerBlob());
+
+        // pretend reboot happens here
+        when(mInjected.getBootCount()).thenReturn(1);
+        ArgumentCaptor<Boolean> metricsSuccessCaptor = ArgumentCaptor.forClass(Boolean.class);
+        doNothing()
+                .when(mInjected)
+                .reportMetric(
+                        metricsSuccessCaptor.capture(),
+                        eq(0) /* error code */,
+                        eq(2) /* Server based */,
+                        eq(1) /* attempt count */,
+                        anyInt(),
+                        eq(0) /* vbmeta status */,
+                        anyInt());
+
+        // load escrow data
+        when(mServiceConnection.unwrap(any(), anyLong()))
+                .thenAnswer(invocation -> invocation.getArgument(0));
+        Network mockNetwork = mock(Network.class);
+        mMockInjector.mNetworkConsumer =
+                (callback) -> {
+                    callback.onAvailable(mockNetwork);
+                };
+
+        mService.loadRebootEscrowDataIfAvailable(mHandler);
+        verify(mServiceConnection).unwrap(any(), anyLong());
+        assertTrue(metricsSuccessCaptor.getValue());
+        verify(mKeyStoreManager).clearKeyStoreEncryptionKey();
+        assertNull(mMockInjector.mNetworkCallback);
+    }
+
+    @Test
+    public void loadRebootEscrowDataIfAvailable_serverBasedWaitForInternetRemoteException_Failure()
+            throws Exception {
+        setServerBasedRebootEscrowProvider();
+        mMockInjector.setWaitForNetwork(true);
+
+        when(mInjected.getBootCount()).thenReturn(0);
+        RebootEscrowListener mockListener = mock(RebootEscrowListener.class);
+        mService.setRebootEscrowListener(mockListener);
+        mService.prepareRebootEscrow();
+
+        clearInvocations(mServiceConnection);
+        mService.callToRebootEscrowIfNeeded(PRIMARY_USER_ID, FAKE_SP_VERSION, FAKE_AUTH_TOKEN);
+        verify(mockListener).onPreparedForReboot(eq(true));
+        verify(mServiceConnection, never()).wrapBlob(any(), anyLong(), anyLong());
+
+        // Use x -> x for both wrap & unwrap functions.
+        when(mServiceConnection.wrapBlob(any(), anyLong(), anyLong()))
+                .thenAnswer(invocation -> invocation.getArgument(0));
+        assertEquals(ARM_REBOOT_ERROR_NONE, mService.armRebootEscrowIfNeeded());
+        verify(mServiceConnection).wrapBlob(any(), anyLong(), anyLong());
+        assertTrue(mStorage.hasRebootEscrowServerBlob());
+
+        // pretend reboot happens here
+        when(mInjected.getBootCount()).thenReturn(1);
+        ArgumentCaptor<Boolean> metricsSuccessCaptor = ArgumentCaptor.forClass(Boolean.class);
+        ArgumentCaptor<Integer> metricsErrorCodeCaptor = ArgumentCaptor.forClass(Integer.class);
+        doNothing()
+                .when(mInjected)
+                .reportMetric(
+                        metricsSuccessCaptor.capture(),
+                        metricsErrorCodeCaptor.capture(),
+                        eq(2) /* Server based */,
+                        eq(1) /* attempt count */,
+                        anyInt(),
+                        eq(0) /* vbmeta status */,
+                        anyInt());
+
+        // load escrow data
+        when(mServiceConnection.unwrap(any(), anyLong())).thenThrow(RemoteException.class);
+        Network mockNetwork = mock(Network.class);
+        mMockInjector.mNetworkConsumer =
+                (callback) -> {
+                    callback.onAvailable(mockNetwork);
+                };
+
+        mService.loadRebootEscrowDataIfAvailable(mHandler);
+        verify(mServiceConnection).unwrap(any(), anyLong());
+        assertFalse(metricsSuccessCaptor.getValue());
+        assertEquals(
+                Integer.valueOf(RebootEscrowManager.ERROR_LOAD_ESCROW_KEY),
+                metricsErrorCodeCaptor.getValue());
+        assertNull(mMockInjector.mNetworkCallback);
+    }
+
+    @Test
+    public void loadRebootEscrowDataIfAvailable_waitForInternet_networkUnavailable()
+            throws Exception {
+        setServerBasedRebootEscrowProvider();
+        mMockInjector.setWaitForNetwork(true);
+
+        when(mInjected.getBootCount()).thenReturn(0);
+        RebootEscrowListener mockListener = mock(RebootEscrowListener.class);
+        mService.setRebootEscrowListener(mockListener);
+        mService.prepareRebootEscrow();
+
+        clearInvocations(mServiceConnection);
+        mService.callToRebootEscrowIfNeeded(PRIMARY_USER_ID, FAKE_SP_VERSION, FAKE_AUTH_TOKEN);
+        verify(mockListener).onPreparedForReboot(eq(true));
+        verify(mServiceConnection, never()).wrapBlob(any(), anyLong(), anyLong());
+
+        // Use x -> x for both wrap & unwrap functions.
+        when(mServiceConnection.wrapBlob(any(), anyLong(), anyLong()))
+                .thenAnswer(invocation -> invocation.getArgument(0));
+        assertEquals(ARM_REBOOT_ERROR_NONE, mService.armRebootEscrowIfNeeded());
+        verify(mServiceConnection).wrapBlob(any(), anyLong(), anyLong());
+        assertTrue(mStorage.hasRebootEscrowServerBlob());
+
+        // pretend reboot happens here
+        when(mInjected.getBootCount()).thenReturn(1);
+        ArgumentCaptor<Boolean> metricsSuccessCaptor = ArgumentCaptor.forClass(Boolean.class);
+        ArgumentCaptor<Integer> metricsErrorCodeCaptor = ArgumentCaptor.forClass(Integer.class);
+        doNothing()
+                .when(mInjected)
+                .reportMetric(
+                        metricsSuccessCaptor.capture(),
+                        metricsErrorCodeCaptor.capture(),
+                        eq(2) /* Server based */,
+                        eq(0) /* attempt count */,
+                        anyInt(),
+                        eq(0) /* vbmeta status */,
+                        anyInt());
+
+        // Network is not available within timeout.
+        mMockInjector.mNetworkConsumer = ConnectivityManager.NetworkCallback::onUnavailable;
+        mService.loadRebootEscrowDataIfAvailable(mHandler);
+        assertFalse(metricsSuccessCaptor.getValue());
+        assertEquals(
+                Integer.valueOf(RebootEscrowManager.ERROR_NO_NETWORK),
+                metricsErrorCodeCaptor.getValue());
+        assertNull(mMockInjector.mNetworkCallback);
+    }
+
+    @Test
+    public void loadRebootEscrowDataIfAvailable_waitForInternet_networkLost() throws Exception {
+        setServerBasedRebootEscrowProvider();
+        mMockInjector.setWaitForNetwork(true);
+
+        when(mInjected.getBootCount()).thenReturn(0);
+        RebootEscrowListener mockListener = mock(RebootEscrowListener.class);
+        mService.setRebootEscrowListener(mockListener);
+        mService.prepareRebootEscrow();
+
+        clearInvocations(mServiceConnection);
+        mService.callToRebootEscrowIfNeeded(PRIMARY_USER_ID, FAKE_SP_VERSION, FAKE_AUTH_TOKEN);
+        verify(mockListener).onPreparedForReboot(eq(true));
+        verify(mServiceConnection, never()).wrapBlob(any(), anyLong(), anyLong());
+
+        // Use x -> x for both wrap & unwrap functions.
+        when(mServiceConnection.wrapBlob(any(), anyLong(), anyLong()))
+                .thenAnswer(invocation -> invocation.getArgument(0));
+        assertEquals(ARM_REBOOT_ERROR_NONE, mService.armRebootEscrowIfNeeded());
+        verify(mServiceConnection).wrapBlob(any(), anyLong(), anyLong());
+        assertTrue(mStorage.hasRebootEscrowServerBlob());
+
+        // pretend reboot happens here
+        when(mInjected.getBootCount()).thenReturn(1);
+        ArgumentCaptor<Boolean> metricsSuccessCaptor = ArgumentCaptor.forClass(Boolean.class);
+        ArgumentCaptor<Integer> metricsErrorCodeCaptor = ArgumentCaptor.forClass(Integer.class);
+        doNothing()
+                .when(mInjected)
+                .reportMetric(
+                        metricsSuccessCaptor.capture(),
+                        metricsErrorCodeCaptor.capture(),
+                        eq(2) /* Server based */,
+                        eq(2) /* attempt count */,
+                        anyInt(),
+                        eq(0) /* vbmeta status */,
+                        anyInt());
+
+        // Network is available, then lost.
+        when(mServiceConnection.unwrap(any(), anyLong())).thenThrow(new IOException());
+        Network mockNetwork = mock(Network.class);
+        mMockInjector.mNetworkConsumer =
+                (callback) -> {
+                    callback.onAvailable(mockNetwork);
+                    callback.onLost(mockNetwork);
+                };
+        mService.loadRebootEscrowDataIfAvailable(mHandler);
+        // Sleep 5s for the retry to complete
+        Thread.sleep(5 * 1000);
+        assertFalse(metricsSuccessCaptor.getValue());
+        assertEquals(
+                Integer.valueOf(RebootEscrowManager.ERROR_NO_NETWORK),
+                metricsErrorCodeCaptor.getValue());
+        assertNull(mMockInjector.mNetworkCallback);
+    }
+
+    @Test
+    public void loadRebootEscrowDataIfAvailable_waitForInternet_networkAvailableWithDelay()
+            throws Exception {
+        setServerBasedRebootEscrowProvider();
+        mMockInjector.setWaitForNetwork(true);
+
+        when(mInjected.getBootCount()).thenReturn(0);
+        RebootEscrowListener mockListener = mock(RebootEscrowListener.class);
+        mService.setRebootEscrowListener(mockListener);
+        mService.prepareRebootEscrow();
+
+        clearInvocations(mServiceConnection);
+        mService.callToRebootEscrowIfNeeded(PRIMARY_USER_ID, FAKE_SP_VERSION, FAKE_AUTH_TOKEN);
+        verify(mockListener).onPreparedForReboot(eq(true));
+        verify(mServiceConnection, never()).wrapBlob(any(), anyLong(), anyLong());
+
+        // Use x -> x for both wrap & unwrap functions.
+        when(mServiceConnection.wrapBlob(any(), anyLong(), anyLong()))
+                .thenAnswer(invocation -> invocation.getArgument(0));
+        assertEquals(ARM_REBOOT_ERROR_NONE, mService.armRebootEscrowIfNeeded());
+        verify(mServiceConnection).wrapBlob(any(), anyLong(), anyLong());
+        assertTrue(mStorage.hasRebootEscrowServerBlob());
+
+        // pretend reboot happens here
+        when(mInjected.getBootCount()).thenReturn(1);
+        ArgumentCaptor<Boolean> metricsSuccessCaptor = ArgumentCaptor.forClass(Boolean.class);
+        ArgumentCaptor<Integer> metricsErrorCodeCaptor = ArgumentCaptor.forClass(Integer.class);
+        doNothing()
+                .when(mInjected)
+                .reportMetric(
+                        metricsSuccessCaptor.capture(),
+                        metricsErrorCodeCaptor.capture(),
+                        eq(2) /* Server based */,
+                        eq(1) /* attempt count */,
+                        anyInt(),
+                        eq(0) /* vbmeta status */,
+                        anyInt());
+
+        // load escrow data
+        when(mServiceConnection.unwrap(any(), anyLong()))
+                .thenAnswer(invocation -> invocation.getArgument(0));
+        // network available after 1 sec
+        Network mockNetwork = mock(Network.class);
+        mMockInjector.mNetworkConsumer =
+                (callback) -> {
+                    try {
+                        Thread.sleep(1000);
+                    } catch (InterruptedException e) {
+                        throw new RuntimeException(e);
+                    }
+                    callback.onAvailable(mockNetwork);
+                };
+        mService.loadRebootEscrowDataIfAvailable(mHandler);
+        verify(mServiceConnection).unwrap(any(), anyLong());
+        assertTrue(metricsSuccessCaptor.getValue());
+        verify(mKeyStoreManager).clearKeyStoreEncryptionKey();
+        assertNull(mMockInjector.mNetworkCallback);
+    }
+
+    @Test
+    public void loadRebootEscrowDataIfAvailable_waitForInternet_timeoutExhausted()
+            throws Exception {
+        setServerBasedRebootEscrowProvider();
+        mMockInjector.setWaitForNetwork(true);
+
+        when(mInjected.getBootCount()).thenReturn(0);
+        RebootEscrowListener mockListener = mock(RebootEscrowListener.class);
+        mService.setRebootEscrowListener(mockListener);
+        mService.prepareRebootEscrow();
+
+        clearInvocations(mServiceConnection);
+        mService.callToRebootEscrowIfNeeded(PRIMARY_USER_ID, FAKE_SP_VERSION, FAKE_AUTH_TOKEN);
+        verify(mockListener).onPreparedForReboot(eq(true));
+        verify(mServiceConnection, never()).wrapBlob(any(), anyLong(), anyLong());
+
+        // Use x -> x for both wrap & unwrap functions.
+        when(mServiceConnection.wrapBlob(any(), anyLong(), anyLong()))
+                .thenAnswer(invocation -> invocation.getArgument(0));
+        assertEquals(ARM_REBOOT_ERROR_NONE, mService.armRebootEscrowIfNeeded());
+        verify(mServiceConnection).wrapBlob(any(), anyLong(), anyLong());
+        assertTrue(mStorage.hasRebootEscrowServerBlob());
+
+        // pretend reboot happens here
+        when(mInjected.getBootCount()).thenReturn(1);
+        ArgumentCaptor<Boolean> metricsSuccessCaptor = ArgumentCaptor.forClass(Boolean.class);
+        ArgumentCaptor<Integer> metricsErrorCodeCaptor = ArgumentCaptor.forClass(Integer.class);
+        doNothing()
+                .when(mInjected)
+                .reportMetric(
+                        metricsSuccessCaptor.capture(),
+                        metricsErrorCodeCaptor.capture(),
+                        eq(2) /* Server based */,
+                        eq(1) /* attempt count */,
+                        anyInt(),
+                        eq(0) /* vbmeta status */,
+                        anyInt());
+
+        // load reboot escrow data
+        when(mServiceConnection.unwrap(any(), anyLong())).thenThrow(IOException.class);
+        Network mockNetwork = mock(Network.class);
+        // wait past timeout
+        mMockInjector.mNetworkConsumer =
+                (callback) -> {
+                    try {
+                        Thread.sleep(3500);
+                    } catch (InterruptedException e) {
+                        throw new RuntimeException(e);
+                    }
+                    callback.onAvailable(mockNetwork);
+                };
+        mService.loadRebootEscrowDataIfAvailable(mHandler);
+        verify(mServiceConnection).unwrap(any(), anyLong());
+        assertFalse(metricsSuccessCaptor.getValue());
+        assertEquals(
+                Integer.valueOf(RebootEscrowManager.ERROR_TIMEOUT_EXHAUSTED),
+                metricsErrorCodeCaptor.getValue());
+        assertNull(mMockInjector.mNetworkCallback);
+    }
+
+    @Test
+    public void loadRebootEscrowDataIfAvailable_serverBasedWaitForNetwork_retryCountExhausted()
+            throws Exception {
+        setServerBasedRebootEscrowProvider();
+        mMockInjector.setWaitForNetwork(true);
+
+        when(mInjected.getBootCount()).thenReturn(0);
+        RebootEscrowListener mockListener = mock(RebootEscrowListener.class);
+        mService.setRebootEscrowListener(mockListener);
+        mService.prepareRebootEscrow();
+
+        clearInvocations(mServiceConnection);
+        mService.callToRebootEscrowIfNeeded(PRIMARY_USER_ID, FAKE_SP_VERSION, FAKE_AUTH_TOKEN);
+        verify(mockListener).onPreparedForReboot(eq(true));
+        verify(mServiceConnection, never()).wrapBlob(any(), anyLong(), anyLong());
+
+        // Use x -> x for both wrap & unwrap functions.
+        when(mServiceConnection.wrapBlob(any(), anyLong(), anyLong()))
+                .thenAnswer(invocation -> invocation.getArgument(0));
+        assertEquals(ARM_REBOOT_ERROR_NONE, mService.armRebootEscrowIfNeeded());
+        verify(mServiceConnection).wrapBlob(any(), anyLong(), anyLong());
+        assertTrue(mStorage.hasRebootEscrowServerBlob());
+
+        // pretend reboot happens here
+        when(mInjected.getBootCount()).thenReturn(1);
+        ArgumentCaptor<Boolean> metricsSuccessCaptor = ArgumentCaptor.forClass(Boolean.class);
+        ArgumentCaptor<Integer> metricsErrorCodeCaptor = ArgumentCaptor.forClass(Integer.class);
+        doNothing()
+                .when(mInjected)
+                .reportMetric(
+                        metricsSuccessCaptor.capture(),
+                        metricsErrorCodeCaptor.capture(),
+                        eq(2) /* Server based */,
+                        eq(2) /* attempt count */,
+                        anyInt(),
+                        eq(0) /* vbmeta status */,
+                        anyInt());
+
+        when(mServiceConnection.unwrap(any(), anyLong())).thenThrow(new IOException());
+        Network mockNetwork = mock(Network.class);
+        mMockInjector.mNetworkConsumer =
+                (callback) -> {
+                    callback.onAvailable(mockNetwork);
+                };
+
+        mService.loadRebootEscrowDataIfAvailable(mHandler);
+        // Sleep 5s for the retry to complete
+        Thread.sleep(5 * 1000);
+        verify(mServiceConnection, times(2)).unwrap(any(), anyLong());
+        assertFalse(metricsSuccessCaptor.getValue());
+        assertEquals(
+                Integer.valueOf(RebootEscrowManager.ERROR_RETRY_COUNT_EXHAUSTED),
+                metricsErrorCodeCaptor.getValue());
+        assertNull(mMockInjector.mNetworkCallback);
+    }
+
+    @Test
+    public void loadRebootEscrowDataIfAvailable_ServerBasedWaitForInternet_RetrySuccess()
+            throws Exception {
+        setServerBasedRebootEscrowProvider();
+        mMockInjector.setWaitForNetwork(true);
+
+        when(mInjected.getBootCount()).thenReturn(0);
+        RebootEscrowListener mockListener = mock(RebootEscrowListener.class);
+        mService.setRebootEscrowListener(mockListener);
+        mService.prepareRebootEscrow();
+
+        clearInvocations(mServiceConnection);
+        mService.callToRebootEscrowIfNeeded(PRIMARY_USER_ID, FAKE_SP_VERSION, FAKE_AUTH_TOKEN);
+        verify(mockListener).onPreparedForReboot(eq(true));
+        verify(mServiceConnection, never()).wrapBlob(any(), anyLong(), anyLong());
+
+        // Use x -> x for both wrap & unwrap functions.
+        when(mServiceConnection.wrapBlob(any(), anyLong(), anyLong()))
+                .thenAnswer(invocation -> invocation.getArgument(0));
+        assertEquals(ARM_REBOOT_ERROR_NONE, mService.armRebootEscrowIfNeeded());
+        verify(mServiceConnection).wrapBlob(any(), anyLong(), anyLong());
+        assertTrue(mStorage.hasRebootEscrowServerBlob());
+
+        // pretend reboot happens here
+        when(mInjected.getBootCount()).thenReturn(1);
+        ArgumentCaptor<Boolean> metricsSuccessCaptor = ArgumentCaptor.forClass(Boolean.class);
+        doNothing()
+                .when(mInjected)
+                .reportMetric(
+                        metricsSuccessCaptor.capture(),
+                        anyInt(),
+                        anyInt(),
+                        eq(2) /* attempt count */,
+                        anyInt(),
+                        anyInt(),
+                        anyInt());
+
+        when(mServiceConnection.unwrap(any(), anyLong()))
+                .thenThrow(new IOException())
+                .thenAnswer(invocation -> invocation.getArgument(0));
+        Network mockNetwork = mock(Network.class);
+        mMockInjector.mNetworkConsumer =
+                (callback) -> {
+                    callback.onAvailable(mockNetwork);
+                };
+
+        mService.loadRebootEscrowDataIfAvailable(mHandler);
+        // Sleep 5s for the retry to complete
+        Thread.sleep(5 * 1000);
+        verify(mServiceConnection, times(2)).unwrap(any(), anyLong());
+        assertTrue(metricsSuccessCaptor.getValue());
+        verify(mKeyStoreManager).clearKeyStoreEncryptionKey();
+        assertNull(mMockInjector.mNetworkCallback);
+    }
+
+    @Test
     public void loadRebootEscrowDataIfAvailable_TooManyBootsInBetween_NoMetrics() throws Exception {
         when(mInjected.getBootCount()).thenReturn(0);
 
@@ -687,7 +1195,7 @@
         when(mInjected.getBootCount()).thenReturn(10);
         when(mRebootEscrow.retrieveKey()).thenReturn(new byte[32]);
 
-        mService.loadRebootEscrowDataIfAvailable(null);
+        mService.loadRebootEscrowDataIfAvailable(mHandler);
         verify(mRebootEscrow).retrieveKey();
         verify(mInjected, never()).reportMetric(anyBoolean(), anyInt(), anyInt(), anyInt(),
                 anyInt(), anyInt(), anyInt());
@@ -715,7 +1223,7 @@
         when(mInjected.getBootCount()).thenReturn(10);
         when(mRebootEscrow.retrieveKey()).thenReturn(new byte[32]);
 
-        mService.loadRebootEscrowDataIfAvailable(null);
+        mService.loadRebootEscrowDataIfAvailable(mHandler);
         verify(mInjected, never()).reportMetric(anyBoolean(), anyInt(), anyInt(), anyInt(),
                 anyInt(), anyInt(), anyInt());
     }
@@ -753,7 +1261,7 @@
         // Trigger a vbmeta digest mismatch
         mStorage.setString(RebootEscrowManager.REBOOT_ESCROW_KEY_VBMETA_DIGEST,
                 "non sense value", USER_SYSTEM);
-        mService.loadRebootEscrowDataIfAvailable(null);
+        mService.loadRebootEscrowDataIfAvailable(mHandler);
         verify(mInjected).reportMetric(eq(true), eq(0) /* error code */, eq(1) /* HAL based */,
                 eq(1) /* attempt count */, anyInt(), eq(2) /* vbmeta status */, anyInt());
         assertEquals(mStorage.getString(RebootEscrowManager.REBOOT_ESCROW_KEY_VBMETA_DIGEST,
@@ -790,7 +1298,7 @@
                 eq(1) /* attempt count */, anyInt(), anyInt(), anyInt());
 
         when(mRebootEscrow.retrieveKey()).thenAnswer(invocation -> null);
-        mService.loadRebootEscrowDataIfAvailable(null);
+        mService.loadRebootEscrowDataIfAvailable(mHandler);
         verify(mRebootEscrow).retrieveKey();
         assertFalse(metricsSuccessCaptor.getValue());
         assertEquals(Integer.valueOf(RebootEscrowManager.ERROR_LOAD_ESCROW_KEY),