Merge "Cleanup widgetsModel and add tests" into main
diff --git a/quickstep/src/com/android/launcher3/model/WidgetsPredictionUpdateTask.java b/quickstep/src/com/android/launcher3/model/WidgetsPredictionUpdateTask.java
index 64bb05e..0395d32 100644
--- a/quickstep/src/com/android/launcher3/model/WidgetsPredictionUpdateTask.java
+++ b/quickstep/src/com/android/launcher3/model/WidgetsPredictionUpdateTask.java
@@ -65,7 +65,7 @@
                 Collectors.toSet());
         Predicate<WidgetItem> notOnWorkspace = w -> !widgetsInWorkspace.contains(w);
         Map<ComponentKey, WidgetItem> allWidgets =
-                dataModel.widgetsModel.getAllWidgetComponentsWithoutShortcuts();
+                dataModel.widgetsModel.getWidgetsByComponentKey();
 
         List<WidgetItem> servicePredictedItems = new ArrayList<>();
 
diff --git a/src/com/android/launcher3/graphics/LauncherPreviewRenderer.java b/src/com/android/launcher3/graphics/LauncherPreviewRenderer.java
index 6088941..2408955 100644
--- a/src/com/android/launcher3/graphics/LauncherPreviewRenderer.java
+++ b/src/com/android/launcher3/graphics/LauncherPreviewRenderer.java
@@ -78,8 +78,6 @@
 import com.android.launcher3.folder.FolderIcon;
 import com.android.launcher3.model.BgDataModel;
 import com.android.launcher3.model.BgDataModel.FixedContainerItems;
-import com.android.launcher3.model.WidgetItem;
-import com.android.launcher3.model.WidgetsModel;
 import com.android.launcher3.model.data.AppPairInfo;
 import com.android.launcher3.model.data.CollectionInfo;
 import com.android.launcher3.model.data.FolderInfo;
@@ -106,6 +104,7 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
 
 /**
  * Utility class for generating the preview of Launcher for a given InvariantDeviceProfile.
@@ -376,15 +375,6 @@
                 getApplicationContext(), providerInfo));
     }
 
-    private void inflateAndAddWidgets(LauncherAppWidgetInfo info, WidgetsModel widgetsModel) {
-        WidgetItem widgetItem = widgetsModel.getWidgetProviderInfoByProviderName(
-                info.providerName, info.user, mContext);
-        if (widgetItem == null) {
-            return;
-        }
-        inflateAndAddWidgets(info, widgetItem.widgetInfo);
-    }
-
     private void inflateAndAddWidgets(
             LauncherAppWidgetInfo info, LauncherAppWidgetProviderInfo providerInfo) {
         AppWidgetHostView view = mAppWidgetHost.createView(
@@ -468,17 +458,22 @@
                     break;
             }
         }
+        Map<ComponentKey, AppWidgetProviderInfo> widgetsMap = widgetProviderInfoMap;
         for (ItemInfo itemInfo : currentAppWidgets) {
             switch (itemInfo.itemType) {
                 case Favorites.ITEM_TYPE_APPWIDGET:
                 case Favorites.ITEM_TYPE_CUSTOM_APPWIDGET:
-                    if (widgetProviderInfoMap != null) {
-                        inflateAndAddWidgets(
-                                (LauncherAppWidgetInfo) itemInfo, widgetProviderInfoMap);
-                    } else {
-                        inflateAndAddWidgets((LauncherAppWidgetInfo) itemInfo,
-                                dataModel.widgetsModel);
+                    if (widgetsMap == null) {
+                        widgetsMap = dataModel.widgetsModel.getWidgetsByComponentKey()
+                                .entrySet()
+                                .stream()
+                                .filter(entry -> entry.getValue().widgetInfo != null)
+                                .collect(Collectors.toMap(
+                                        Map.Entry::getKey,
+                                        entry -> entry.getValue().widgetInfo
+                                ));
                     }
+                    inflateAndAddWidgets((LauncherAppWidgetInfo) itemInfo, widgetsMap);
                     break;
                 default:
                     break;
diff --git a/src/com/android/launcher3/model/WidgetsModel.java b/src/com/android/launcher3/model/WidgetsModel.java
index 454ae96..58ebf0f 100644
--- a/src/com/android/launcher3/model/WidgetsModel.java
+++ b/src/com/android/launcher3/model/WidgetsModel.java
@@ -54,7 +54,9 @@
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Set;
+import java.util.function.Function;
 import java.util.function.Predicate;
+import java.util.stream.Collectors;
 
 /**
  * Widgets data model that is used by the adapters of the widget views and controllers.
@@ -67,7 +69,44 @@
     private static final boolean DEBUG = false;

     /* Map of widgets and shortcuts that are tracked per package. */
-    private final Map<PackageItemInfo, List<WidgetItem>> mWidgetsList = new HashMap<>();
+    private final Map<PackageItemInfo, List<WidgetItem>> mWidgetsByPackageItem = new HashMap<>();
+
+    /**
+     * Returns all widgets and shortcuts keyed by their component key.
+     *
+     * <p>Shortcut items have a {@code null} {@code widgetInfo}; callers that only want app
+     * widgets must filter on it. An item grouped under more than one package item is returned
+     * once. Returns an empty map when widgets are disabled.
+     */
+    public synchronized Map<ComponentKey, WidgetItem> getWidgetsByComponentKey() {
+        if (!WIDGETS_ENABLED) {
+            return Collections.emptyMap();
+        }
+        return mWidgetsByPackageItem.values().stream()
+                .flatMap(Collection::stream)
+                .collect(Collectors.toMap(
+                        widget -> new ComponentKey(widget.componentName, widget.user),
+                        Function.identity(),
+                        // One item can be grouped under several package items; keep the first.
+                        (first, duplicate) -> first
+                ));
+    }
+
+    /**
+     * Returns widgets and shortcuts grouped by the package item that they belong to.
+     *
+     * <p>The result is a snapshot (the map and each list are copies), so callers can iterate it
+     * without holding this model's lock and cannot change the model's grouping through it.
+     * Returns an empty map when widgets are disabled.
+     */
+    public synchronized Map<PackageItemInfo, List<WidgetItem>> getWidgetsByPackageItem() {
+        if (!WIDGETS_ENABLED) {
+            return Collections.emptyMap();
+        }
+        Map<PackageItemInfo, List<WidgetItem>> snapshot = new HashMap<>();
+        mWidgetsByPackageItem.forEach((pkg, items) -> snapshot.put(pkg, new ArrayList<>(items)));
+        return snapshot;
+    }

     /**
      * Returns a list of {@link WidgetsListBaseEntry} filtered using given widget item filter. All
@@ -85,7 +106,8 @@
         ArrayList<WidgetsListBaseEntry> result = new ArrayList<>();
         AlphabeticIndexCompat indexer = new AlphabeticIndexCompat(context);
 
-        for (Map.Entry<PackageItemInfo, List<WidgetItem>> entry : mWidgetsList.entrySet()) {
+        for (Map.Entry<PackageItemInfo, List<WidgetItem>> entry :
+                mWidgetsByPackageItem.entrySet()) {
             PackageItemInfo pkgItem = entry.getKey();
             List<WidgetItem> widgetItems = entry.getValue()
                     .stream()
@@ -112,41 +134,6 @@
         return getFilteredWidgetsListForPicker(context, /*widgetItemFilter=*/ item -> true);
     }
 
-    /** Returns a mapping of packages to their widgets without static shortcuts. */
-    public synchronized Map<PackageUserKey, List<WidgetItem>> getAllWidgetsWithoutShortcuts() {
-        if (!WIDGETS_ENABLED) {
-            return Collections.emptyMap();
-        }
-        Map<PackageUserKey, List<WidgetItem>> packagesToWidgets = new HashMap<>();
-        mWidgetsList.forEach((packageItemInfo, widgetsAndShortcuts) -> {
-            List<WidgetItem> widgets = widgetsAndShortcuts.stream()
-                    .filter(item -> item.widgetInfo != null)
-                    .collect(toList());
-            if (widgets.size() > 0) {
-                packagesToWidgets.put(
-                        new PackageUserKey(packageItemInfo.packageName, packageItemInfo.user),
-                        widgets);
-            }
-        });
-        return packagesToWidgets;
-    }
-
-    /**
-     * Returns a map of widget component keys to corresponding widget items. Excludes the
-     * shortcuts.
-     */
-    public synchronized Map<ComponentKey, WidgetItem> getAllWidgetComponentsWithoutShortcuts() {
-        if (!WIDGETS_ENABLED) {
-            return Collections.emptyMap();
-        }
-        Map<ComponentKey, WidgetItem> widgetsMap = new HashMap<>();
-        mWidgetsList.forEach((packageItemInfo, widgetsAndShortcuts) ->
-                widgetsAndShortcuts.stream().filter(item -> item.widgetInfo != null).forEach(
-                        item -> widgetsMap.put(new ComponentKey(item.componentName, item.user),
-                                item)));
-        return widgetsMap;
-    }
-
     /**
      * @param packageUser If null, all widgets and shortcuts are updated and returned, otherwise
      *                    only widgets and shortcuts associated with the package/user are.
@@ -210,14 +197,14 @@
 
         if (packageUser == null) {
             // Clear the list if this is an update on all widgets and shortcuts.
-            mWidgetsList.clear();
+            mWidgetsByPackageItem.clear();
         } else {
             // Otherwise, only clear the widgets and shortcuts for the changed package.
-            mWidgetsList.remove(packageItemInfoCache.getOrCreate(packageUser));
+            mWidgetsByPackageItem.remove(packageItemInfoCache.getOrCreate(packageUser));
         }
 
         // add and update.
-        mWidgetsList.putAll(rawWidgetsShortcuts.stream()
+        mWidgetsByPackageItem.putAll(rawWidgetsShortcuts.stream()
                 .filter(new WidgetValidityCheck(app))
                 .filter(new WidgetFlagCheck())
                 .flatMap(widgetItem -> getPackageUserKeys(app.getContext(), widgetItem).stream()
@@ -237,7 +224,7 @@
             return;
         }
         WidgetManagerHelper widgetManager = new WidgetManagerHelper(app.getContext());
-        for (Entry<PackageItemInfo, List<WidgetItem>> entry : mWidgetsList.entrySet()) {
+        for (Entry<PackageItemInfo, List<WidgetItem>> entry : mWidgetsByPackageItem.entrySet()) {
             if (packageNames.contains(entry.getKey().packageName)) {
                 List<WidgetItem> items = entry.getValue();
                 int count = items.size();
@@ -258,50 +245,6 @@
         }
     }
 
-    private PackageItemInfo createPackageItemInfo(
-            ComponentName providerName,
-            UserHandle user,
-            int category
-    ) {
-        if (category == NO_CATEGORY) {
-            return new PackageItemInfo(providerName.getPackageName(), user);
-        } else {
-            return new PackageItemInfo("" , category, user);
-        }
-    }
-
-    private IntSet getCategories(ComponentName providerName, Context context) {
-        IntSet categories = WidgetSections.getWidgetsToCategory(context).get(providerName);
-        if (categories != null) {
-            return categories;
-        }
-        categories = new IntSet();
-        categories.add(NO_CATEGORY);
-        return categories;
-    }
-
-    public WidgetItem getWidgetProviderInfoByProviderName(
-            ComponentName providerName, UserHandle user, Context context) {
-        if (!WIDGETS_ENABLED) {
-            return null;
-        }
-        IntSet categories = getCategories(providerName, context);
-
-        // Checking if we have a provider in any of the categories.
-        for (Integer category: categories) {
-            PackageItemInfo key = createPackageItemInfo(providerName, user, category);
-            List<WidgetItem> widgets = mWidgetsList.get(key);
-            if (widgets != null) {
-                return widgets.stream().filter(
-                                item -> item.componentName.equals(providerName)
-                        )
-                        .findFirst()
-                        .orElse(null);
-            }
-        }
-        return null;
-    }
-
     /** Returns {@link PackageItemInfo} of a pending widget. */
     public static PackageItemInfo newPendingItemInfo(Context context, ComponentName provider,
             UserHandle user) {
diff --git a/tests/multivalentTests/src/com/android/launcher3/model/WidgetsModelTest.kt b/tests/multivalentTests/src/com/android/launcher3/model/WidgetsModelTest.kt
new file mode 100644
index 0000000..71f7d47
--- /dev/null
+++ b/tests/multivalentTests/src/com/android/launcher3/model/WidgetsModelTest.kt
@@ -0,0 +1,205 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.launcher3.model
+
+import android.appwidget.AppWidgetManager
+import android.content.ComponentName
+import android.content.Context
+import android.platform.test.rule.AllowedDevices
+import android.platform.test.rule.DeviceProduct
+import android.platform.test.rule.LimitDevicesRule
+import androidx.test.core.app.ApplicationProvider
+import androidx.test.ext.junit.runners.AndroidJUnit4
+import com.android.launcher3.DeviceProfile
+import com.android.launcher3.InvariantDeviceProfile
+import com.android.launcher3.LauncherAppState
+import com.android.launcher3.icons.IconCache
+import com.android.launcher3.model.data.PackageItemInfo
+import com.android.launcher3.util.ActivityContextWrapper
+import com.android.launcher3.util.ComponentKey
+import com.android.launcher3.util.Executors
+import com.android.launcher3.util.IntSet
+import com.android.launcher3.util.PackageUserKey
+import com.android.launcher3.util.WidgetUtils.createAppWidgetProviderInfo
+import com.android.launcher3.widget.LauncherAppWidgetProviderInfo
+import com.android.launcher3.widget.WidgetSections
+import com.android.launcher3.widget.WidgetSections.NO_CATEGORY
+import com.google.common.truth.Truth.assertThat
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.TimeUnit
+import org.junit.Assert.fail
+import org.junit.Before
+import org.junit.Rule
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.Mock
+import org.mockito.junit.MockitoJUnit
+import org.mockito.junit.MockitoRule
+import org.mockito.kotlin.any
+import org.mockito.kotlin.whenever
+
+/** Tests how [WidgetsModel] groups loaded widgets by package item and by component key. */
+@AllowedDevices(allowed = [DeviceProduct.ROBOLECTRIC])
+@RunWith(AndroidJUnit4::class)
+class WidgetsModelTest {
+    @Rule @JvmField val limitDevicesRule = LimitDevicesRule()
+    @Rule @JvmField val mockitoRule: MockitoRule = MockitoJUnit.rule()
+
+    @Mock private lateinit var appWidgetManager: AppWidgetManager
+    @Mock private lateinit var app: LauncherAppState
+    @Mock private lateinit var iconCacheMock: IconCache
+
+    private lateinit var context: Context
+    private lateinit var idp: InvariantDeviceProfile
+    private lateinit var underTest: WidgetsModel
+
+    private var widgetSectionCategory: Int = 0
+    private lateinit var appAPackage: String
+
+    @Before
+    fun setUp() {
+        val appContext: Context = ApplicationProvider.getApplicationContext()
+        idp = InvariantDeviceProfile.INSTANCE[appContext]
+
+        context =
+            object : ActivityContextWrapper(ApplicationProvider.getApplicationContext()) {
+                override fun getSystemService(name: String): Any? {
+                    if (name == "appwidget") {
+                        return appWidgetManager
+                    }
+                    return super.getSystemService(name)
+                }
+
+                override fun getDeviceProfile(): DeviceProfile {
+                    return idp.getDeviceProfile(applicationContext).copy(applicationContext)
+                }
+            }
+
+        whenever(iconCacheMock.getTitleNoCache(any<LauncherAppWidgetProviderInfo>()))
+            .thenReturn("title")
+        whenever(app.iconCache).thenReturn(iconCacheMock)
+        whenever(app.context).thenReturn(context)
+        whenever(app.invariantDeviceProfile).thenReturn(idp)
+
+        val widgetToCategoryEntry: Map.Entry<ComponentName, IntSet> =
+            WidgetSections.getWidgetsToCategory(context).entries.first()
+        widgetSectionCategory = widgetToCategoryEntry.value.first()
+        val appAWidgetComponent = widgetToCategoryEntry.key
+        appAPackage = appAWidgetComponent.packageName
+
+        whenever(appWidgetManager.getInstalledProvidersForProfile(any()))
+            .thenReturn(
+                listOf(
+                    // First widget from widget sections xml
+                    createAppWidgetProviderInfo(appAWidgetComponent),
+                    // A widget that belongs to same package as the widget from widget sections
+                    // xml, but, because it's not mentioned in xml, it would be included in its
+                    // own package section.
+                    createAppWidgetProviderInfo(
+                        ComponentName.createRelative(appAPackage, APP_A_TEST_WIDGET_NAME)
+                    ),
+                    // A widget in different package (none of that app's widgets are in widget
+                    // sections xml)
+                    createAppWidgetProviderInfo(APP_B_TEST_WIDGET_COMPONENT),
+                )
+            )
+
+        underTest = WidgetsModel()
+    }
+
+    @Test
+    fun widgetsByPackage_treatsWidgetSectionsAsSeparatePackageItems() {
+        loadWidgets()
+
+        val packages: Map<PackageItemInfo, List<WidgetItem>> = underTest.widgetsByPackageItem
+
+        // expect 3 package items
+        // one for the custom section with widget from appA
+        // one for package section for second widget from appA (that wasn't listed in xml)
+        // and one for package section for appB
+        assertThat(packages).hasSize(3)
+
+        // Each package item when used as a key is distinct (i.e. even if appA is split into custom
+        // package and owner package section, each of them is a distinct key). This ensures that
+        // clicking on a custom widget section doesn't take user to app package section.
+        val distinctPackageUserKeys =
+            packages.map { PackageUserKey.fromPackageItemInfo(it.key) }.distinct()
+        assertThat(distinctPackageUserKeys).hasSize(3)
+
+        val customSections = packages.filter { it.key.widgetCategory == widgetSectionCategory }
+        assertThat(customSections).hasSize(1)
+        val widgetsInCustomSection = customSections.entries.first().value
+        assertThat(widgetsInCustomSection).hasSize(1)
+
+        val packageSections = packages.filter { it.key.widgetCategory == NO_CATEGORY }
+        assertThat(packageSections).hasSize(2)
+
+        // App A's package section
+        val appAPackageSection = packageSections.filter { it.key.packageName == appAPackage }
+        assertThat(appAPackageSection).hasSize(1)
+        val widgetsInAppASection = appAPackageSection.entries.first().value
+        assertThat(widgetsInAppASection).hasSize(1)
+
+        // App B's package section
+        val appBPackageSection =
+            packageSections.filter { it.key.packageName == APP_B_TEST_WIDGET_COMPONENT.packageName }
+        assertThat(appBPackageSection).hasSize(1)
+        val widgetsInAppBSection = appBPackageSection.entries.first().value
+        assertThat(widgetsInAppBSection).hasSize(1)
+    }
+
+    @Test
+    fun widgetComponentMap_returnsWidgets() {
+        loadWidgets()
+
+        val widgetsByComponentKey: Map<ComponentKey, WidgetItem> = underTest.widgetsByComponentKey
+
+        assertThat(widgetsByComponentKey).hasSize(3)
+        widgetsByComponentKey.forEach { entry ->
+            assertThat(entry.key).isEqualTo(entry.value as ComponentKey)
+        }
+    }
+
+    @Test
+    fun widgets_noData_returnsEmpty() {
+        // no loadWidgets()
+
+        assertThat(underTest.widgetsByComponentKey).isEmpty()
+    }
+
+    /** Runs [WidgetsModel.update] on the model executor, waiting up to the load timeout. */
+    private fun loadWidgets() {
+        val latch = CountDownLatch(1)
+        Executors.MODEL_EXECUTOR.execute {
+            underTest.update(app, /* packageUser= */ null)
+            latch.countDown()
+        }
+        if (!latch.await(LOAD_WIDGETS_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
+            fail("Timed out waiting for widgets to load")
+        }
+    }
+
+    companion object {
+        // Another widget within app A
+        private const val APP_A_TEST_WIDGET_NAME = "MyProvider"
+
+        private val APP_B_TEST_WIDGET_COMPONENT: ComponentName =
+            ComponentName.createRelative("com.test.package", "TestProvider")
+
+        private const val LOAD_WIDGETS_TIMEOUT_SECONDS = 2L
+    }
+}