Merge "Unregister widget prediction callback on clear" into main
diff --git a/quickstep/src/com/android/launcher3/WidgetPickerActivity.java b/quickstep/src/com/android/launcher3/WidgetPickerActivity.java
index 7f3e615..4d3e3be 100644
--- a/quickstep/src/com/android/launcher3/WidgetPickerActivity.java
+++ b/quickstep/src/com/android/launcher3/WidgetPickerActivity.java
@@ -68,7 +68,8 @@
import java.util.regex.Pattern;
/** An Activity that can host Launcher's widget picker. */
-public class WidgetPickerActivity extends BaseActivity {
+public class WidgetPickerActivity extends BaseActivity implements
+ WidgetPredictionsRequester.WidgetPredictionsListener {
private static final String TAG = "WidgetPickerActivity";
/**
* Name of the extra that indicates that a widget being dragged.
@@ -322,7 +323,7 @@
if (mUiSurface != null) {
mWidgetPredictionsRequester = new WidgetPredictionsRequester(app.getContext(),
mUiSurface, mModel.getWidgetsByComponentKeyForPicker());
- mWidgetPredictionsRequester.request(mAddedWidgets, this::bindRecommendedWidgets);
+ mWidgetPredictionsRequester.request(mAddedWidgets, /*listener=*/ this);
}
});
}
@@ -355,7 +356,8 @@
});
}
- private void bindRecommendedWidgets(List<ItemInfo> recommendedWidgets) {
+ @Override
+ public void onPredictionsAvailable(List<ItemInfo> recommendedWidgets) {
// Bind recommendations once picker has finished open animation.
MAIN_EXECUTOR.getHandler().postDelayed(
() -> mWidgetPickerDataProvider.setWidgetRecommendations(recommendedWidgets),
diff --git a/quickstep/src/com/android/launcher3/model/WidgetPredictionsRequester.java b/quickstep/src/com/android/launcher3/model/WidgetPredictionsRequester.java
index d3ac975..f9cec82 100644
--- a/quickstep/src/com/android/launcher3/model/WidgetPredictionsRequester.java
+++ b/quickstep/src/com/android/launcher3/model/WidgetPredictionsRequester.java
@@ -48,7 +48,6 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
-import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.stream.Collectors;
@@ -56,7 +55,7 @@
* Works with app predictor to fetch and process widget predictions displayed in a standalone
* widget picker activity for a UI surface.
*/
-public class WidgetPredictionsRequester {
+public class WidgetPredictionsRequester implements AppPredictor.Callback {
private static final int NUM_OF_RECOMMENDED_WIDGETS_PREDICATION = 20;
private static final String BUNDLE_KEY_ADDED_APP_WIDGETS = "added_app_widgets";
// container/screenid/[positionx,positiony]/[spanx,spany]
@@ -71,6 +70,9 @@
@NonNull
private final String mUiSurface;
private boolean mPredictionsAvailable;
+ @Nullable
+ private WidgetPredictionsListener mPredictionsListener = null;
+ @Nullable Predicate<WidgetItem> mFilter = null; // NOTE(review): not private like siblings - intended?
@NonNull
private final Map<ComponentKey, WidgetItem> mAllWidgets;
@@ -81,36 +83,49 @@
mAllWidgets = Collections.unmodifiableMap(allWidgets);
}
+ // AppPredictor.Callback -> onTargetsAvailable; forwards only the first update per request.
+ @Override
+ @WorkerThread
+ public void onTargetsAvailable(List<AppTarget> targets) {
+ // Capture locally: clear() may null the field before the posted runnable runs.
+ final WidgetPredictionsListener listener = mPredictionsListener;
+ if (mPredictionsAvailable || listener == null) return;
+ mPredictionsAvailable = true;
+
+ List<WidgetItem> filteredPredictions = filterPredictions(targets, mAllWidgets, mFilter);
+ List<ItemInfo> mappedPredictions = mapWidgetItemsToItemInfo(filteredPredictions);
+ MAIN_EXECUTOR.execute(() -> listener.onPredictionsAvailable(mappedPredictions));
+ }
+
/**
* Requests one time predictions from the app predictions manager and invokes provided callback
- * once predictions are available.
+ * once predictions are available. Any previous request is cancelled first.
*
* @param existingWidgets widgets that are currently added to the surface;
- * @param callback consumer of prediction results to be called when predictions are
- * available
+ * @param listener consumer of prediction results to be called when predictions are
+ * available; any previous listener will no longer receive updates.
*/
+ @WorkerThread // e.g. MODEL_EXECUTOR
public void request(List<AppWidgetProviderInfo> existingWidgets,
- Consumer<List<ItemInfo>> callback) {
+ WidgetPredictionsListener listener) {
+ clear();
+ mPredictionsListener = listener;
+ mFilter = notOnUiSurfaceFilter(existingWidgets);
+
+ AppPredictionManager apm = mContext.getSystemService(AppPredictionManager.class);
+ if (apm == null) {
+ return;
+ }
+
Bundle bundle = buildBundleForPredictionSession(existingWidgets);
- Predicate<WidgetItem> filter = notOnUiSurfaceFilter(existingWidgets);
-
- MODEL_EXECUTOR.execute(() -> {
- clear();
- AppPredictionManager apm = mContext.getSystemService(AppPredictionManager.class);
- if (apm == null) {
- return;
- }
-
- mAppPredictor = apm.createAppPredictionSession(
- new AppPredictionContext.Builder(mContext)
- .setUiSurface(mUiSurface)
- .setExtras(bundle)
- .setPredictedTargetCount(NUM_OF_RECOMMENDED_WIDGETS_PREDICATION)
- .build());
- mAppPredictor.registerPredictionUpdates(MODEL_EXECUTOR,
- targets -> bindPredictions(targets, filter, callback));
- mAppPredictor.requestPredictionUpdate();
- });
+ mAppPredictor = apm.createAppPredictionSession(
+ new AppPredictionContext.Builder(mContext)
+ .setUiSurface(mUiSurface)
+ .setExtras(bundle)
+ .setPredictedTargetCount(NUM_OF_RECOMMENDED_WIDGETS_PREDICATION)
+ .build());
+ mAppPredictor.registerPredictionUpdates(MODEL_EXECUTOR, /*callback=*/ this);
+ mAppPredictor.requestPredictionUpdate();
}
/**
@@ -158,27 +173,14 @@
return widgetItem -> !existingComponentKeys.contains(widgetItem);
}
- /** Provides the predictions returned by the predictor to the registered callback. */
- @WorkerThread
- private void bindPredictions(List<AppTarget> targets, Predicate<WidgetItem> filter,
- Consumer<List<ItemInfo>> callback) {
- if (!mPredictionsAvailable) {
- mPredictionsAvailable = true;
- List<WidgetItem> filteredPredictions = filterPredictions(targets, mAllWidgets, filter);
- List<ItemInfo> mappedPredictions = mapWidgetItemsToItemInfo(filteredPredictions);
-
- MAIN_EXECUTOR.execute(() -> callback.accept(mappedPredictions));
- MODEL_EXECUTOR.execute(this::clear);
- }
- }
-
/**
* Applies the provided filter (e.g. widgets not on workspace) on the predictions returned by
* the predictor.
*/
@VisibleForTesting
static List<WidgetItem> filterPredictions(List<AppTarget> predictions,
- Map<ComponentKey, WidgetItem> allWidgets, Predicate<WidgetItem> filter) {
+ @NonNull Map<ComponentKey, WidgetItem> allWidgets,
+ @Nullable Predicate<WidgetItem> filter) {
List<WidgetItem> servicePredictedItems = new ArrayList<>();
for (AppTarget prediction : predictions) {
@@ -187,7 +189,7 @@
WidgetItem widgetItem = allWidgets.get(
new ComponentKey(new ComponentName(prediction.getPackageName(), className),
prediction.getUser()));
- if (widgetItem != null && filter.test(widgetItem)) {
+ if (widgetItem != null && (filter == null || filter.test(widgetItem))) {
servicePredictedItems.add(widgetItem);
}
}
@@ -218,9 +220,23 @@
/** Cleans up any open prediction sessions. */
public void clear() {
if (mAppPredictor != null) {
+ mAppPredictor.unregisterPredictionUpdates(this);
mAppPredictor.destroy();
mAppPredictor = null;
}
+ mPredictionsListener = null;
mPredictionsAvailable = false;
+ mFilter = null;
+ }
+
+ /**
+ * Listener interface for prediction results produced by {@link WidgetPredictionsRequester}.
+ */
+ public interface WidgetPredictionsListener {
+ /**
+ * Posted via {@code MAIN_EXECUTOR}, at most once per {@code request} call, with the results.
+ * @param predictions list of predicted widgets {@link PendingAddWidgetInfo}
+ */
+ void onPredictionsAvailable(List<ItemInfo> predictions);
}
}
diff --git a/quickstep/tests/multivalentTests/src/com/android/launcher3/model/WidgetsPredictionsRequesterTest.kt b/quickstep/tests/multivalentTests/src/com/android/launcher3/model/WidgetsPredictionsRequesterTest.kt
index 4ea74df..d445189 100644
--- a/quickstep/tests/multivalentTests/src/com/android/launcher3/model/WidgetsPredictionsRequesterTest.kt
+++ b/quickstep/tests/multivalentTests/src/com/android/launcher3/model/WidgetsPredictionsRequesterTest.kt
@@ -16,6 +16,8 @@
package com.android.launcher3.model
+import android.app.prediction.AppPredictionManager
+import android.app.prediction.AppPredictor
import android.app.prediction.AppTarget
import android.app.prediction.AppTargetEvent
import android.app.prediction.AppTargetId
@@ -36,9 +38,15 @@
import com.android.launcher3.model.WidgetPredictionsRequester.notOnUiSurfaceFilter
import com.android.launcher3.util.ActivityContextWrapper
import com.android.launcher3.util.ComponentKey
+import com.android.launcher3.util.Executors
+import com.android.launcher3.util.Executors.MODEL_EXECUTOR
+import com.android.launcher3.util.TestUtil
import com.android.launcher3.util.WidgetUtils.createAppWidgetProviderInfo
import com.android.launcher3.widget.LauncherAppWidgetProviderInfo
+import com.android.launcher3.widget.PendingAddWidgetInfo
import com.google.common.truth.Truth.assertThat
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.TimeUnit
import java.util.function.Predicate
import junit.framework.Assert.assertNotNull
import org.junit.Before
@@ -46,6 +54,9 @@
import org.junit.runner.RunWith
import org.mockito.Mock
import org.mockito.MockitoAnnotations
+import org.mockito.kotlin.any
+import org.mockito.kotlin.doAnswer
+import org.mockito.kotlin.whenever
@RunWith(AndroidJUnit4::class)
class WidgetsPredictionsRequesterTest {
@@ -67,11 +78,26 @@
@Mock private lateinit var iconCache: IconCache
+ @Mock private lateinit var apmMock: AppPredictionManager
+
+ @Mock private lateinit var predictorMock: AppPredictor
+
@Before
fun setUp() {
MockitoAnnotations.initMocks(this)
mUserHandle = myUserHandle()
- context = ActivityContextWrapper(ApplicationProvider.getApplicationContext())
+
+ whenever(apmMock.createAppPredictionSession(any())).thenReturn(predictorMock)
+
+ context =
+ object : ActivityContextWrapper(ApplicationProvider.getApplicationContext()) {
+ override fun getSystemService(name: String): Any? {
+ if (name == "app_prediction") { // serve the mocked AppPredictionManager
+ return apmMock
+ }
+ return super.getSystemService(name)
+ }
+ }
testInvariantProfile = LauncherAppState.getIDP(context)
deviceProfile = testInvariantProfile.getDeviceProfile(context).copy(context)
@@ -114,22 +140,68 @@
buildExpectedAppTargetEvent(
/*pkg=*/ APP_1_PACKAGE_NAME,
/*providerClassName=*/ APP_1_PROVIDER_A_CLASS_NAME,
- /*user=*/ mUserHandle
+ /*user=*/ mUserHandle,
),
buildExpectedAppTargetEvent(
/*pkg=*/ APP_1_PACKAGE_NAME,
/*providerClassName=*/ APP_1_PROVIDER_B_CLASS_NAME,
- /*user=*/ mUserHandle
+ /*user=*/ mUserHandle,
),
buildExpectedAppTargetEvent(
/*pkg=*/ APP_2_PACKAGE_NAME,
/*providerClassName=*/ APP_2_PROVIDER_1_CLASS_NAME,
- /*user=*/ mUserHandle
- )
+ /*user=*/ mUserHandle,
+ ),
)
}
@Test
+ fun request_invokesCallbackWithPredictedItems() {
+ TestUtil.runOnExecutorSync(MODEL_EXECUTOR) {
+ val underTest = WidgetPredictionsRequester(context, TEST_UI_SURFACE, allWidgets)
+ val existingWidgets = arrayListOf(widget1aInfo, widget1bInfo)
+ val predictions =
+ listOf(
+ // (existing) already on surface
+ AppTarget(
+ AppTargetId(APP_1_PACKAGE_NAME),
+ APP_1_PACKAGE_NAME,
+ APP_1_PROVIDER_B_CLASS_NAME,
+ mUserHandle,
+ ),
+ // eligible
+ AppTarget(
+ AppTargetId(APP_2_PACKAGE_NAME),
+ APP_2_PACKAGE_NAME,
+ APP_2_PROVIDER_1_CLASS_NAME,
+ mUserHandle,
+ ),
+ )
+ doAnswer {
+ underTest.onTargetsAvailable(predictions)
+ null
+ }
+ .whenever(predictorMock)
+ .requestPredictionUpdate()
+ val testCountDownLatch = CountDownLatch(1)
+ var receivedItems: List<*>? = null
+ val listener =
+ WidgetPredictionsRequester.WidgetPredictionsListener { itemInfos ->
+ receivedItems = itemInfos
+ testCountDownLatch.countDown()
+ }
+
+ underTest.request(existingWidgets, listener)
+ TestUtil.runOnExecutorSync(Executors.MAIN_EXECUTOR) {}
+
+ assertThat(testCountDownLatch.await(TEST_TIMEOUT, TimeUnit.SECONDS)).isTrue()
+ // Only APP_2's provider is eligible; APP_1's provider B is already on the surface.
+ assertThat(receivedItems).hasSize(1)
+ assertThat(receivedItems?.firstOrNull()).isInstanceOf(PendingAddWidgetInfo::class.java)
+ }
+ }
+
+ @Test
fun filterPredictions_notOnUiSurfaceFilter_returnsOnlyEligiblePredictions() {
val widgetsAlreadyOnSurface = arrayListOf(widget1bInfo)
val filter: Predicate<WidgetItem> = notOnUiSurfaceFilter(widgetsAlreadyOnSurface)
@@ -141,15 +213,15 @@
AppTargetId(APP_1_PACKAGE_NAME),
APP_1_PACKAGE_NAME,
APP_1_PROVIDER_B_CLASS_NAME,
- mUserHandle
+ mUserHandle,
),
// eligible
AppTarget(
AppTargetId(APP_2_PACKAGE_NAME),
APP_2_PACKAGE_NAME,
APP_2_PROVIDER_1_CLASS_NAME,
- mUserHandle
- )
+ mUserHandle,
+ ),
)
// only 2 was eligible
@@ -167,27 +239,27 @@
AppTargetId(APP_1_PACKAGE_NAME),
APP_1_PACKAGE_NAME,
"$APP_1_PACKAGE_NAME.SomeActivity",
- mUserHandle
+ mUserHandle,
),
AppTarget(
AppTargetId(APP_2_PACKAGE_NAME),
APP_2_PACKAGE_NAME,
"$APP_2_PACKAGE_NAME.SomeActivity2",
- mUserHandle
+ mUserHandle,
),
)
assertThat(filterPredictions(predictions, allWidgets, filter)).isEmpty()
}
- private fun createWidgetItem(
- providerInfo: AppWidgetProviderInfo,
- ): WidgetItem {
+ private fun createWidgetItem(providerInfo: AppWidgetProviderInfo): WidgetItem {
val widgetInfo = LauncherAppWidgetProviderInfo.fromProviderInfo(context, providerInfo)
return WidgetItem(widgetInfo, testInvariantProfile, iconCache, context)
}
companion object {
+ const val TEST_TIMEOUT = 3L // in seconds
+
const val TEST_UI_SURFACE = "widgets_test"
const val BUNDLE_KEY_ADDED_APP_WIDGETS = "added_app_widgets"
@@ -203,13 +275,13 @@
private fun buildExpectedAppTargetEvent(
pkg: String,
providerClassName: String,
- userHandle: UserHandle
+ userHandle: UserHandle,
): AppTargetEvent {
val appTarget =
AppTarget.Builder(
/*id=*/ AppTargetId("widget:$pkg"),
/*packageName=*/ pkg,
- /*user=*/ userHandle
+ /*user=*/ userHandle,
)
.setClassName(providerClassName)
.build()