diff --git a/src/com/android/settings/biometrics2/ui/viewmodel/FingerprintEnrollProgressViewModel.java b/src/com/android/settings/biometrics2/ui/viewmodel/FingerprintEnrollProgressViewModel.java
index b1b420d..695ea0d 100644
--- a/src/com/android/settings/biometrics2/ui/viewmodel/FingerprintEnrollProgressViewModel.java
+++ b/src/com/android/settings/biometrics2/ui/viewmodel/FingerprintEnrollProgressViewModel.java
@@ -16,6 +16,8 @@
 
 package com.android.settings.biometrics2.ui.viewmodel;
 
+import static android.hardware.fingerprint.FingerprintManager.ENROLL_ENROLL;
+
 import static com.android.settings.biometrics2.ui.model.EnrollmentProgress.INITIAL_REMAINING;
 import static com.android.settings.biometrics2.ui.model.EnrollmentProgress.INITIAL_STEPS;
 
@@ -63,7 +65,6 @@
     private final int mUserId;
 
     private final FingerprintUpdater mFingerprintUpdater;
-    private final MessageDisplayController mMessageDisplayController;
     @Nullable private CancellationSignal mCancellationSignal = null;
     private final EnrollmentCallback mEnrollmentCallback = new EnrollmentCallback() {
 
@@ -81,6 +82,9 @@
 
         @Override
         public void onEnrollmentHelp(int helpMsgId, CharSequence helpString) {
+            if (DEBUG) { // Log the raw help id/string before it is posted to LiveData.
+                Log.d(TAG, "onEnrollmentHelp(" + helpMsgId + ", " + helpString + ")");
+            }
             mHelpMessageLiveData.postValue(new EnrollmentStatusMessage(helpMsgId, helpString));
         }
 
@@ -113,20 +117,6 @@
         super(application);
         mFingerprintUpdater = fingerprintUpdater;
         mUserId = userId;
-
-        final Resources res = application.getResources();
-        mMessageDisplayController =
-                res.getBoolean(R.bool.enrollment_message_display_controller_flag)
-                        ? new MessageDisplayController(
-                                application.getMainThreadHandler(),
-                                mEnrollmentCallback,
-                                SystemClock.elapsedRealtimeClock(),
-                                res.getInteger(R.integer.enrollment_help_minimum_time_display),
-                                res.getInteger(R.integer.enrollment_progress_minimum_time_display),
-                                res.getBoolean(R.bool.enrollment_progress_priority_over_help),
-                                res.getBoolean(R.bool.enrollment_prioritize_acquire_messages),
-                                res.getInteger(R.integer.enrollment_collect_time))
-                        : null;
     }
 
     public void setToken(byte[] token) {
@@ -195,9 +185,24 @@
         mErrorMessageLiveData.setValue(null);
 
         mCancellationSignal = new CancellationSignal();
-        mFingerprintUpdater.enroll(mToken, mCancellationSignal, mUserId,
-                mMessageDisplayController != null ? mMessageDisplayController : mEnrollmentCallback,
-                reason);
+
+        final Resources res = getApplication().getResources();
+        if (reason == ENROLL_ENROLL
+                && res.getBoolean(R.bool.enrollment_message_display_controller_flag)) {
+            final EnrollmentCallback callback = new MessageDisplayController(
+                    getApplication().getMainThreadHandler(),
+                    mEnrollmentCallback,
+                    SystemClock.elapsedRealtimeClock(),
+                    res.getInteger(R.integer.enrollment_help_minimum_time_display),
+                    res.getInteger(R.integer.enrollment_progress_minimum_time_display),
+                    res.getBoolean(R.bool.enrollment_progress_priority_over_help),
+                    res.getBoolean(R.bool.enrollment_prioritize_acquire_messages),
+                    res.getInteger(R.integer.enrollment_collect_time));
+            mFingerprintUpdater.enroll(mToken, mCancellationSignal, mUserId, callback, reason);
+        } else {
+            mFingerprintUpdater.enroll(mToken, mCancellationSignal, mUserId, mEnrollmentCallback,
+                    reason);
+        }
         return true;
     }
 
diff --git a/tests/unit/src/com/android/settings/biometrics2/ui/viewmodel/FingerprintEnrollProgressViewModelTest.java b/tests/unit/src/com/android/settings/biometrics2/ui/viewmodel/FingerprintEnrollProgressViewModelTest.java
index 323618a..6190c5e 100644
--- a/tests/unit/src/com/android/settings/biometrics2/ui/viewmodel/FingerprintEnrollProgressViewModelTest.java
+++ b/tests/unit/src/com/android/settings/biometrics2/ui/viewmodel/FingerprintEnrollProgressViewModelTest.java
@@ -21,6 +21,8 @@
 import static android.hardware.fingerprint.FingerprintManager.EnrollReason;
 import static android.hardware.fingerprint.FingerprintManager.EnrollmentCallback;
 
+import static com.android.settings.Utils.SETTINGS_PACKAGE_NAME;
+
 import static com.google.common.truth.Truth.assertThat;
 
 import static org.mockito.ArgumentMatchers.anyInt;
@@ -29,18 +31,24 @@
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.only;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 import android.app.Application;
 import android.content.res.Resources;
 import android.os.CancellationSignal;
+import android.os.Handler;
+import android.os.Looper;
+import android.os.Message;
 
+import androidx.annotation.NonNull;
 import androidx.lifecycle.LiveData;
+import androidx.test.core.app.ApplicationProvider;
 import androidx.test.ext.junit.runners.AndroidJUnit4;
 
-import com.android.settings.R;
 import com.android.settings.biometrics.fingerprint.FingerprintUpdater;
+import com.android.settings.biometrics.fingerprint.MessageDisplayController;
 import com.android.settings.biometrics2.ui.model.EnrollmentProgress;
 import com.android.settings.biometrics2.ui.model.EnrollmentStatusMessage;
 import com.android.settings.testutils.InstantTaskExecutorRule;
@@ -68,12 +76,18 @@
     private FingerprintEnrollProgressViewModel mViewModel;
     private final TestWrapper<CancellationSignal> mCancellationSignalWrapper = new TestWrapper<>();
     private final TestWrapper<EnrollmentCallback> mCallbackWrapper = new TestWrapper<>();
+    private int mEnrollmentMessageDisplayControllerFlagResId;
 
     @Before
     public void setUp() {
+        mEnrollmentMessageDisplayControllerFlagResId = ApplicationProvider.getApplicationContext()
+                .getResources().getIdentifier("enrollment_message_display_controller_flag", "bool",
+                        SETTINGS_PACKAGE_NAME);
+
         when(mApplication.getResources()).thenReturn(mResources);
-        when(mResources.getBoolean(R.bool.enrollment_message_display_controller_flag))
-                .thenReturn(false);
+
+        // Do not use MessageDisplayController by default
+        when(mResources.getBoolean(mEnrollmentMessageDisplayControllerFlagResId)).thenReturn(false);
         mViewModel = new FingerprintEnrollProgressViewModel(mApplication, mFingerprintUpdater,
                 TEST_USER_ID);
 
@@ -88,7 +102,7 @@
     }
 
     @Test
-    public void testStartEnrollment() {
+    public void testStartFindSensor() {
         @EnrollReason final int enrollReason = ENROLL_FIND_SENSOR;
         final byte[] token = new byte[] { 1, 2, 3 };
         mViewModel.setToken(token);
@@ -99,6 +113,54 @@
         assertThat(ret).isTrue();
         verify(mFingerprintUpdater, only()).enroll(eq(token), any(CancellationSignal.class),
                 eq(TEST_USER_ID), any(EnrollmentCallback.class), eq(enrollReason));
+        assertThat(mCallbackWrapper.mValue instanceof MessageDisplayController).isFalse();
+    }
+
+    @Test
+    public void testStartEnrolling() {
+        @EnrollReason final int enrollReason = ENROLL_ENROLL;
+        final byte[] token = new byte[] { 1, 2, 3 };
+        mViewModel.setToken(token);
+
+        // Start enrollment; setUp() leaves the MessageDisplayController flag disabled
+        final boolean ret = mViewModel.startEnrollment(enrollReason);
+
+        assertThat(ret).isTrue();
+        verify(mFingerprintUpdater, only()).enroll(eq(token), any(CancellationSignal.class),
+                eq(TEST_USER_ID), any(EnrollmentCallback.class), eq(enrollReason));
+        assertThat(mCallbackWrapper.mValue).isNotInstanceOf(MessageDisplayController.class);
+    }
+
+    @Test
+    public void testStartEnrollingWithMessageDisplayController() {
+        // Enable MessageDisplayController and mock handler for it
+        when(mResources.getBoolean(mEnrollmentMessageDisplayControllerFlagResId)).thenReturn(true);
+        when(mApplication.getMainThreadHandler()).thenReturn(new TestHandler());
+
+        @EnrollReason final int enrollReason = ENROLL_ENROLL;
+        final byte[] token = new byte[] { 1, 2, 3 };
+        mViewModel.setToken(token);
+
+        // Start enrollment
+        final boolean ret = mViewModel.startEnrollment(enrollReason);
+
+        assertThat(ret).isTrue();
+        verify(mFingerprintUpdater, only()).enroll(eq(token), any(CancellationSignal.class),
+                eq(TEST_USER_ID), any(MessageDisplayController.class), eq(enrollReason));
+        assertThat(mCallbackWrapper.mValue).isNotNull();
+
+        assertThat(mCallbackWrapper.mValue).isInstanceOf(MessageDisplayController.class);
+        final EnrollmentCallback callback1 = mCallbackWrapper.mValue;
+
+        // Cancel and start again
+        mViewModel.cancelEnrollment();
+        assertThat(mViewModel.startEnrollment(enrollReason)).isTrue();
+
+        // A second enrollment must get a fresh MessageDisplayController, not reuse the old one
+        verify(mFingerprintUpdater, times(2)).enroll(eq(token), any(CancellationSignal.class),
+                eq(TEST_USER_ID), any(MessageDisplayController.class), eq(enrollReason));
+        assertThat(mCallbackWrapper.mValue).isNotNull();
+        assertThat(mCallbackWrapper.mValue).isNotSameInstanceAs(callback1);
     }
 
     @Test
@@ -163,6 +225,48 @@
     }
 
     @Test
+    public void testProgressUpdateWithMessageDisplayController() {
+        // Enable MessageDisplayController; TestHandler runs its posted callbacks immediately
+        when(mResources.getBoolean(mEnrollmentMessageDisplayControllerFlagResId)).thenReturn(true);
+        when(mApplication.getMainThreadHandler()).thenReturn(new TestHandler());
+
+        mViewModel.setToken(new byte[] { 1, 2, 3 });
+
+        // Start enrollment
+        final boolean ret = mViewModel.startEnrollment(ENROLL_ENROLL);
+        assertThat(ret).isTrue();
+        assertThat(mCallbackWrapper.mValue).isNotNull();
+
+        // Before any progress callback, steps is -1 and remaining is 0
+        final LiveData<EnrollmentProgress> progressLiveData = mViewModel.getProgressLiveData();
+        EnrollmentProgress progress = progressLiveData.getValue();
+        assertThat(progress).isNotNull();
+        assertThat(progress.getSteps()).isEqualTo(-1);
+        assertThat(progress.getRemaining()).isEqualTo(0);
+
+        // First progress report: total steps becomes 25, with 25 remaining
+        mCallbackWrapper.mValue.onEnrollmentProgress(25);
+        progress = progressLiveData.getValue();
+        assertThat(progress).isNotNull();
+        assertThat(progress.getSteps()).isEqualTo(25);
+        assertThat(progress.getRemaining()).isEqualTo(25);
+
+        // Later report: steps stay 25 while remaining drops to 20
+        mCallbackWrapper.mValue.onEnrollmentProgress(20);
+        progress = progressLiveData.getValue();
+        assertThat(progress).isNotNull();
+        assertThat(progress.getSteps()).isEqualTo(25);
+        assertThat(progress.getRemaining()).isEqualTo(20);
+
+        // Final report: remaining reaches 0 and steps stay 25
+        mCallbackWrapper.mValue.onEnrollmentProgress(0);
+        progress = progressLiveData.getValue();
+        assertThat(progress).isNotNull();
+        assertThat(progress.getSteps()).isEqualTo(25);
+        assertThat(progress.getRemaining()).isEqualTo(0);
+    }
+
+    @Test
     public void testGetErrorMessageLiveData() {
         // Start enrollment
         mViewModel.setToken(new byte[] { 1, 2, 3 });
@@ -262,4 +366,17 @@
     private static class TestWrapper<T> {
         T mValue;
     }
+
+    private static class TestHandler extends Handler {
+
+        TestHandler() {
+            super(Looper.getMainLooper());
+        }
+
+        @Override
+        public boolean sendMessageAtTime(@NonNull Message msg, long uptimeMillis) {
+            dispatchMessage(msg); // Run now on the caller's thread; uptimeMillis is ignored.
+            return true;
+        }
+    }
 }
