diff --git a/quickstep/src/com/android/launcher3/WidgetPickerActivity.java b/quickstep/src/com/android/launcher3/WidgetPickerActivity.java
index 7f3e615..4d3e3be 100644
--- a/quickstep/src/com/android/launcher3/WidgetPickerActivity.java
+++ b/quickstep/src/com/android/launcher3/WidgetPickerActivity.java
@@ -68,7 +68,8 @@
 import java.util.regex.Pattern;
 
 /** An Activity that can host Launcher's widget picker. */
-public class WidgetPickerActivity extends BaseActivity {
+public class WidgetPickerActivity extends BaseActivity implements
+        WidgetPredictionsRequester.WidgetPredictionsListener {
     private static final String TAG = "WidgetPickerActivity";
     /**
      * Name of the extra that indicates that a widget being dragged.
@@ -322,7 +323,7 @@
             if (mUiSurface != null) {
                 mWidgetPredictionsRequester = new WidgetPredictionsRequester(app.getContext(),
                         mUiSurface, mModel.getWidgetsByComponentKeyForPicker());
-                mWidgetPredictionsRequester.request(mAddedWidgets, this::bindRecommendedWidgets);
+                mWidgetPredictionsRequester.request(mAddedWidgets, /*listener=*/ this);
             }
         });
     }
@@ -355,7 +356,8 @@
         });
     }
 
-    private void bindRecommendedWidgets(List<ItemInfo> recommendedWidgets) {
+    @Override
+    public void onPredictionsAvailable(List<ItemInfo> recommendedWidgets) {
         // Bind recommendations once picker has finished open animation.
         MAIN_EXECUTOR.getHandler().postDelayed(
                 () -> mWidgetPickerDataProvider.setWidgetRecommendations(recommendedWidgets),
diff --git a/quickstep/src/com/android/launcher3/model/WidgetPredictionsRequester.java b/quickstep/src/com/android/launcher3/model/WidgetPredictionsRequester.java
index d3ac975..f9cec82 100644
--- a/quickstep/src/com/android/launcher3/model/WidgetPredictionsRequester.java
+++ b/quickstep/src/com/android/launcher3/model/WidgetPredictionsRequester.java
@@ -48,7 +48,6 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.function.Consumer;
 import java.util.function.Predicate;
 import java.util.stream.Collectors;
 
@@ -56,7 +55,7 @@
  * Works with app predictor to fetch and process widget predictions displayed in a standalone
  * widget picker activity for a UI surface.
  */
-public class WidgetPredictionsRequester {
+public class WidgetPredictionsRequester implements AppPredictor.Callback {
     private static final int NUM_OF_RECOMMENDED_WIDGETS_PREDICATION = 20;
     private static final String BUNDLE_KEY_ADDED_APP_WIDGETS = "added_app_widgets";
     // container/screenid/[positionx,positiony]/[spanx,spany]
@@ -71,6 +70,9 @@
     @NonNull
     private final String mUiSurface;
     private boolean mPredictionsAvailable;
+    @Nullable // listener of the active request; assigned in request(), reset in clear()
+    private WidgetPredictionsListener mPredictionsListener = null;
+    @Nullable private Predicate<WidgetItem> mFilter = null; // filter of the active request
     @NonNull
     private final Map<ComponentKey, WidgetItem> mAllWidgets;
 
@@ -81,36 +83,49 @@
         mAllWidgets = Collections.unmodifiableMap(allWidgets);
     }
 
+    // AppPredictor.Callback -> onTargetsAvailable
+    // Forwards only the first batch of targets received for the current request to the listener.
+    @Override
+    @WorkerThread
+    public void onTargetsAvailable(List<AppTarget> targets) {
+        // Snapshot the listener: clear() may null the field before the posted task runs.
+        final WidgetPredictionsListener listener = mPredictionsListener;
+        if (mPredictionsAvailable || listener == null) return;
+        mPredictionsAvailable = true;
+        List<WidgetItem> filteredPredictions = filterPredictions(targets, mAllWidgets, mFilter);
+        List<ItemInfo> mappedPredictions = mapWidgetItemsToItemInfo(filteredPredictions);
+        MAIN_EXECUTOR.execute(() -> listener.onPredictionsAvailable(mappedPredictions));
+    }
+
     /**
      * Requests one time predictions from the app predictions manager and invokes provided callback
-     * once predictions are available.
+     * once predictions are available. Any request still in flight is cancelled first.
      *
      * @param existingWidgets widgets that are currently added to the surface;
-     * @param callback        consumer of prediction results to be called when predictions are
-     *                        available
+     * @param listener        receiver of the prediction results, invoked via MAIN_EXECUTOR once
+     *                        available; any previous listener will no longer receive updates.
      */
+    @WorkerThread // e.g. MODEL_EXECUTOR
     public void request(List<AppWidgetProviderInfo> existingWidgets,
-            Consumer<List<ItemInfo>> callback) {
+            WidgetPredictionsListener listener) {
+        clear(); // tears down any previous session and resets listener and filter
+        mPredictionsListener = listener;
+        mFilter = notOnUiSurfaceFilter(existingWidgets);
+
+        AppPredictionManager apm = mContext.getSystemService(AppPredictionManager.class);
+        if (apm == null) {
+            return; // no prediction service: the listener is never invoked
+        }
+
         Bundle bundle = buildBundleForPredictionSession(existingWidgets);
-        Predicate<WidgetItem> filter = notOnUiSurfaceFilter(existingWidgets);
-
-        MODEL_EXECUTOR.execute(() -> {
-            clear();
-            AppPredictionManager apm = mContext.getSystemService(AppPredictionManager.class);
-            if (apm == null) {
-                return;
-            }
-
-            mAppPredictor = apm.createAppPredictionSession(
-                    new AppPredictionContext.Builder(mContext)
-                            .setUiSurface(mUiSurface)
-                            .setExtras(bundle)
-                            .setPredictedTargetCount(NUM_OF_RECOMMENDED_WIDGETS_PREDICATION)
-                            .build());
-            mAppPredictor.registerPredictionUpdates(MODEL_EXECUTOR,
-                    targets -> bindPredictions(targets, filter, callback));
-            mAppPredictor.requestPredictionUpdate();
-        });
+        mAppPredictor = apm.createAppPredictionSession(
+                new AppPredictionContext.Builder(mContext)
+                        .setUiSurface(mUiSurface)
+                        .setExtras(bundle)
+                        .setPredictedTargetCount(NUM_OF_RECOMMENDED_WIDGETS_PREDICATION)
+                        .build());
+        mAppPredictor.registerPredictionUpdates(MODEL_EXECUTOR, /*callback=*/ this);
+        mAppPredictor.requestPredictionUpdate();
     }
 
     /**
@@ -158,27 +173,14 @@
         return widgetItem -> !existingComponentKeys.contains(widgetItem);
     }
 
-    /** Provides the predictions returned by the predictor to the registered callback. */
-    @WorkerThread
-    private void bindPredictions(List<AppTarget> targets, Predicate<WidgetItem> filter,
-            Consumer<List<ItemInfo>> callback) {
-        if (!mPredictionsAvailable) {
-            mPredictionsAvailable = true;
-            List<WidgetItem> filteredPredictions = filterPredictions(targets, mAllWidgets, filter);
-            List<ItemInfo> mappedPredictions = mapWidgetItemsToItemInfo(filteredPredictions);
-
-            MAIN_EXECUTOR.execute(() -> callback.accept(mappedPredictions));
-            MODEL_EXECUTOR.execute(this::clear);
-        }
-    }
-
     /**
      * Applies the provided filter (e.g. widgets not on workspace) on the predictions returned by
      * the predictor.
      */
     @VisibleForTesting
     static List<WidgetItem> filterPredictions(List<AppTarget> predictions,
-            Map<ComponentKey, WidgetItem> allWidgets, Predicate<WidgetItem> filter) {
+            @NonNull Map<ComponentKey, WidgetItem> allWidgets,
+            @Nullable Predicate<WidgetItem> filter) {
         List<WidgetItem> servicePredictedItems = new ArrayList<>();
 
         for (AppTarget prediction : predictions) {
@@ -187,7 +189,7 @@
                 WidgetItem widgetItem = allWidgets.get(
                         new ComponentKey(new ComponentName(prediction.getPackageName(), className),
                                 prediction.getUser()));
-                if (widgetItem != null && filter.test(widgetItem)) {
+                if (widgetItem != null && (filter == null || filter.test(widgetItem))) {
                     servicePredictedItems.add(widgetItem);
                 }
             }
@@ -218,9 +220,23 @@
     /** Cleans up any open prediction sessions. */
     public void clear() {
         if (mAppPredictor != null) {
+            mAppPredictor.unregisterPredictionUpdates(this); // unregister before destroy()
             mAppPredictor.destroy();
             mAppPredictor = null;
         }
+        mPredictionsListener = null; // the previous listener gets no further updates
         mPredictionsAvailable = false;
+        mFilter = null;
+    }
+
+    /** Receives the predictions fetched by a {@link WidgetPredictionsRequester}. */
+    @FunctionalInterface
+    public interface WidgetPredictionsListener {
+        /**
+         * Called once the predicted widgets are available.
+         *
+         * @param predictions predicted widgets (expected to be {@link PendingAddWidgetInfo}s)
+         */
+        void onPredictionsAvailable(List<ItemInfo> predictions);
     }
 }
diff --git a/quickstep/tests/multivalentTests/src/com/android/launcher3/model/WidgetsPredictionsRequesterTest.kt b/quickstep/tests/multivalentTests/src/com/android/launcher3/model/WidgetsPredictionsRequesterTest.kt
index 4ea74df..d445189 100644
--- a/quickstep/tests/multivalentTests/src/com/android/launcher3/model/WidgetsPredictionsRequesterTest.kt
+++ b/quickstep/tests/multivalentTests/src/com/android/launcher3/model/WidgetsPredictionsRequesterTest.kt
@@ -16,6 +16,8 @@
 
 package com.android.launcher3.model
 
+import android.app.prediction.AppPredictionManager
+import android.app.prediction.AppPredictor
 import android.app.prediction.AppTarget
 import android.app.prediction.AppTargetEvent
 import android.app.prediction.AppTargetId
@@ -36,9 +38,15 @@
 import com.android.launcher3.model.WidgetPredictionsRequester.notOnUiSurfaceFilter
 import com.android.launcher3.util.ActivityContextWrapper
 import com.android.launcher3.util.ComponentKey
+import com.android.launcher3.util.Executors
+import com.android.launcher3.util.Executors.MODEL_EXECUTOR
+import com.android.launcher3.util.TestUtil
 import com.android.launcher3.util.WidgetUtils.createAppWidgetProviderInfo
 import com.android.launcher3.widget.LauncherAppWidgetProviderInfo
+import com.android.launcher3.widget.PendingAddWidgetInfo
 import com.google.common.truth.Truth.assertThat
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.TimeUnit
 import java.util.function.Predicate
 import junit.framework.Assert.assertNotNull
 import org.junit.Before
@@ -46,6 +54,9 @@
 import org.junit.runner.RunWith
 import org.mockito.Mock
 import org.mockito.MockitoAnnotations
+import org.mockito.kotlin.any
+import org.mockito.kotlin.doAnswer
+import org.mockito.kotlin.whenever
 
 @RunWith(AndroidJUnit4::class)
 class WidgetsPredictionsRequesterTest {
@@ -67,11 +78,26 @@
 
     @Mock private lateinit var iconCache: IconCache
 
+    @Mock private lateinit var apmMock: AppPredictionManager
+
+    @Mock private lateinit var predictorMock: AppPredictor
+
     @Before
     fun setUp() {
         MockitoAnnotations.initMocks(this)
         mUserHandle = myUserHandle()
-        context = ActivityContextWrapper(ApplicationProvider.getApplicationContext())
+
+        whenever(apmMock.createAppPredictionSession(any())).thenReturn(predictorMock)
+
+        context =
+            object : ActivityContextWrapper(ApplicationProvider.getApplicationContext()) {
+                override fun getSystemService(name: String): Any? {
+                    if (name == "app_prediction") {
+                        return apmMock
+                    }
+                    return super.getSystemService(name)
+                }
+            }
         testInvariantProfile = LauncherAppState.getIDP(context)
         deviceProfile = testInvariantProfile.getDeviceProfile(context).copy(context)
 
@@ -114,22 +140,68 @@
                 buildExpectedAppTargetEvent(
                     /*pkg=*/ APP_1_PACKAGE_NAME,
                     /*providerClassName=*/ APP_1_PROVIDER_A_CLASS_NAME,
-                    /*user=*/ mUserHandle
+                    /*user=*/ mUserHandle,
                 ),
                 buildExpectedAppTargetEvent(
                     /*pkg=*/ APP_1_PACKAGE_NAME,
                     /*providerClassName=*/ APP_1_PROVIDER_B_CLASS_NAME,
-                    /*user=*/ mUserHandle
+                    /*user=*/ mUserHandle,
                 ),
                 buildExpectedAppTargetEvent(
                     /*pkg=*/ APP_2_PACKAGE_NAME,
                     /*providerClassName=*/ APP_2_PROVIDER_1_CLASS_NAME,
-                    /*user=*/ mUserHandle
-                )
+                    /*user=*/ mUserHandle,
+                ),
             )
     }
 
     @Test
+    fun request_invokesCallbackWithPredictedItems() {
+        TestUtil.runOnExecutorSync(MODEL_EXECUTOR) {
+            val underTest = WidgetPredictionsRequester(context, TEST_UI_SURFACE, allWidgets)
+            val existingWidgets = arrayListOf(widget1aInfo, widget1bInfo)
+            val predictions =
+                listOf(
+                    // (existing) already on the surface, so it must be filtered out
+                    AppTarget(
+                        AppTargetId(APP_1_PACKAGE_NAME),
+                        APP_1_PACKAGE_NAME,
+                        APP_1_PROVIDER_B_CLASS_NAME,
+                        mUserHandle,
+                    ),
+                    // eligible: the only prediction expected to reach the listener
+                    AppTarget(
+                        AppTargetId(APP_2_PACKAGE_NAME),
+                        APP_2_PACKAGE_NAME,
+                        APP_2_PROVIDER_1_CLASS_NAME,
+                        mUserHandle,
+                    ),
+                )
+            doAnswer {
+                    underTest.onTargetsAvailable(predictions)
+                    null
+                }
+                .whenever(predictorMock)
+                .requestPredictionUpdate()
+            val testCountDownLatch = CountDownLatch(1)
+            val listener =
+                WidgetPredictionsRequester.WidgetPredictionsListener { itemInfos ->
+                    if (itemInfos.size == 1 && itemInfos[0] is PendingAddWidgetInfo) {
+                        // only one item was eligible.
+                        testCountDownLatch.countDown()
+                    } else {
+                        println("Unexpected prediction items found: ${itemInfos.size}")
+                    }
+                }
+
+            underTest.request(existingWidgets, listener)
+            TestUtil.runOnExecutorSync(Executors.MAIN_EXECUTOR) {} // presumably drains the posted call
+
+            assertThat(testCountDownLatch.await(TEST_TIMEOUT, TimeUnit.SECONDS)).isTrue()
+        }
+    }
+
+    @Test
     fun filterPredictions_notOnUiSurfaceFilter_returnsOnlyEligiblePredictions() {
         val widgetsAlreadyOnSurface = arrayListOf(widget1bInfo)
         val filter: Predicate<WidgetItem> = notOnUiSurfaceFilter(widgetsAlreadyOnSurface)
@@ -141,15 +213,15 @@
                     AppTargetId(APP_1_PACKAGE_NAME),
                     APP_1_PACKAGE_NAME,
                     APP_1_PROVIDER_B_CLASS_NAME,
-                    mUserHandle
+                    mUserHandle,
                 ),
                 // eligible
                 AppTarget(
                     AppTargetId(APP_2_PACKAGE_NAME),
                     APP_2_PACKAGE_NAME,
                     APP_2_PROVIDER_1_CLASS_NAME,
-                    mUserHandle
-                )
+                    mUserHandle,
+                ),
             )
 
         // only 2 was eligible
@@ -167,27 +239,27 @@
                     AppTargetId(APP_1_PACKAGE_NAME),
                     APP_1_PACKAGE_NAME,
                     "$APP_1_PACKAGE_NAME.SomeActivity",
-                    mUserHandle
+                    mUserHandle,
                 ),
                 AppTarget(
                     AppTargetId(APP_2_PACKAGE_NAME),
                     APP_2_PACKAGE_NAME,
                     "$APP_2_PACKAGE_NAME.SomeActivity2",
-                    mUserHandle
+                    mUserHandle,
                 ),
             )
 
         assertThat(filterPredictions(predictions, allWidgets, filter)).isEmpty()
     }
 
-    private fun createWidgetItem(
-        providerInfo: AppWidgetProviderInfo,
-    ): WidgetItem {
+    private fun createWidgetItem(providerInfo: AppWidgetProviderInfo): WidgetItem {
         val widgetInfo = LauncherAppWidgetProviderInfo.fromProviderInfo(context, providerInfo)
         return WidgetItem(widgetInfo, testInvariantProfile, iconCache, context)
     }
 
     companion object {
+        const val TEST_TIMEOUT = 3L
+
         const val TEST_UI_SURFACE = "widgets_test"
         const val BUNDLE_KEY_ADDED_APP_WIDGETS = "added_app_widgets"
 
@@ -203,13 +275,13 @@
         private fun buildExpectedAppTargetEvent(
             pkg: String,
             providerClassName: String,
-            userHandle: UserHandle
+            userHandle: UserHandle,
         ): AppTargetEvent {
             val appTarget =
                 AppTarget.Builder(
                         /*id=*/ AppTargetId("widget:$pkg"),
                         /*packageName=*/ pkg,
-                        /*user=*/ userHandle
+                        /*user=*/ userHandle,
                     )
                     .setClassName(providerClassName)
                     .build()
