launcher3: Keep Context in SimpleBroadcastReceiver for register/unregister

  Unless explicitly annotated, Java parameters are nullable by
  default. In a few cases, a null context may be passed to
  SimpleBroadcastReceiver#unregisterReceiverSafely.

  To prevent misuse, such as passing different contexts to
  register and unregister, the constructor now takes the context
  and keeps a strong reference to it.

  Also annotate the context parameters with @NonNull to enforce
  this contract.

Bug: 395019017, 395966395
Flag: NONE - bug fixed
Test: manual - presubmit
Change-Id: Ie371fa45cadceaf51cf184b446df9123ef27c337
diff --git a/src/com/android/launcher3/InvariantDeviceProfile.java b/src/com/android/launcher3/InvariantDeviceProfile.java
index 43876a6..4107c8f 100644
--- a/src/com/android/launcher3/InvariantDeviceProfile.java
+++ b/src/com/android/launcher3/InvariantDeviceProfile.java
@@ -300,10 +300,10 @@
         lifeCycle.addCloseable(() -> prefs.removeListener(prefListener,
                 FIXED_LANDSCAPE_MODE, ENABLE_TWOLINE_ALLAPPS_TOGGLE));
 
-        SimpleBroadcastReceiver localeReceiver = new SimpleBroadcastReceiver(
+        SimpleBroadcastReceiver localeReceiver = new SimpleBroadcastReceiver(context,
                 MAIN_EXECUTOR, i -> onConfigChanged(context));
-        localeReceiver.register(context, Intent.ACTION_LOCALE_CHANGED);
-        lifeCycle.addCloseable(() -> localeReceiver.unregisterReceiverSafely(context));
+        localeReceiver.register(Intent.ACTION_LOCALE_CHANGED);
+        lifeCycle.addCloseable(() -> localeReceiver.unregisterReceiverSafely());
     }
 
     private String initGrid(Context context, String gridName) {
diff --git a/src/com/android/launcher3/LauncherAppState.java b/src/com/android/launcher3/LauncherAppState.java
index e560a14..71013c3 100644
--- a/src/com/android/launcher3/LauncherAppState.java
+++ b/src/com/android/launcher3/LauncherAppState.java
@@ -119,14 +119,13 @@
         }
 
         SimpleBroadcastReceiver modelChangeReceiver =
-                new SimpleBroadcastReceiver(UI_HELPER_EXECUTOR, mModel::onBroadcastIntent);
+                new SimpleBroadcastReceiver(context, UI_HELPER_EXECUTOR, mModel::onBroadcastIntent);
         modelChangeReceiver.register(
-                mContext,
                 ACTION_DEVICE_POLICY_RESOURCE_UPDATED);
         if (BuildConfig.IS_STUDIO_BUILD) {
-            modelChangeReceiver.register(mContext, RECEIVER_EXPORTED, ACTION_FORCE_ROLOAD);
+            modelChangeReceiver.register(RECEIVER_EXPORTED, ACTION_FORCE_ROLOAD);
         }
-        mOnTerminateCallback.add(() -> modelChangeReceiver.unregisterReceiverSafely(mContext));
+        mOnTerminateCallback.add(() -> modelChangeReceiver.unregisterReceiverSafely());
 
         SafeCloseable userChangeListener = UserCache.INSTANCE.get(mContext)
                 .addUserEventListener(mModel::onUserEvent);
diff --git a/src/com/android/launcher3/graphics/ThemeManager.kt b/src/com/android/launcher3/graphics/ThemeManager.kt
index 242220a..1636da8 100644
--- a/src/com/android/launcher3/graphics/ThemeManager.kt
+++ b/src/com/android/launcher3/graphics/ThemeManager.kt
@@ -62,8 +62,8 @@
     private val listeners = CopyOnWriteArrayList<ThemeChangeListener>()
 
     init {
-        val receiver = SimpleBroadcastReceiver(MAIN_EXECUTOR) { verifyIconState() }
-        receiver.registerPkgActions(context, "android", ACTION_OVERLAY_CHANGED)
+        val receiver = SimpleBroadcastReceiver(context, MAIN_EXECUTOR) { verifyIconState() }
+        receiver.registerPkgActions("android", ACTION_OVERLAY_CHANGED)
 
         val prefListener = LauncherPrefChangeListener { key ->
             when (key) {
@@ -74,7 +74,7 @@
         prefs.addListener(prefListener, THEMED_ICONS, PREF_ICON_SHAPE)
 
         lifecycle.addCloseable {
-            receiver.unregisterReceiverSafely(context)
+            receiver.unregisterReceiverSafely()
             prefs.removeListener(prefListener)
         }
     }
diff --git a/src/com/android/launcher3/pm/UserCache.java b/src/com/android/launcher3/pm/UserCache.java
index 0b18a87..20c0ecc 100644
--- a/src/com/android/launcher3/pm/UserCache.java
+++ b/src/com/android/launcher3/pm/UserCache.java
@@ -81,10 +81,7 @@
     }
 
     private final List<BiConsumer<UserHandle, String>> mUserEventListeners = new ArrayList<>();
-    private final SimpleBroadcastReceiver mUserChangeReceiver =
-            new SimpleBroadcastReceiver(MODEL_EXECUTOR, this::onUsersChanged);
-
-    private final Context mContext;
+    private final SimpleBroadcastReceiver mUserChangeReceiver;
     private final ApiWrapper mApiWrapper;
 
     @NonNull
@@ -99,16 +96,17 @@
             DaggerSingletonTracker tracker,
             ApiWrapper apiWrapper
     ) {
-        mContext = context;
         mApiWrapper = apiWrapper;
+        mUserChangeReceiver = new SimpleBroadcastReceiver(context,
+                MODEL_EXECUTOR, this::onUsersChanged);
         mUserToSerialMap = Collections.emptyMap();
         MODEL_EXECUTOR.execute(this::initAsync);
-        tracker.addCloseable(() -> mUserChangeReceiver.unregisterReceiverSafely(mContext));
+        tracker.addCloseable(() -> mUserChangeReceiver.unregisterReceiverSafely());
     }
 
     @WorkerThread
     private void initAsync() {
-        mUserChangeReceiver.register(mContext,
+        mUserChangeReceiver.register(
                 Intent.ACTION_MANAGED_PROFILE_AVAILABLE,
                 Intent.ACTION_MANAGED_PROFILE_UNAVAILABLE,
                 Intent.ACTION_MANAGED_PROFILE_REMOVED,
diff --git a/src/com/android/launcher3/util/DisplayController.java b/src/com/android/launcher3/util/DisplayController.java
index ee1af81..89f12d8 100644
--- a/src/com/android/launcher3/util/DisplayController.java
+++ b/src/com/android/launcher3/util/DisplayController.java
@@ -107,7 +107,6 @@
     private static final String ACTION_OVERLAY_CHANGED = "android.intent.action.OVERLAY_CHANGED";
     private static final String TARGET_OVERLAY_PACKAGE = "android";
 
-    private final Context mContext;
     private final WindowManagerProxy mWMProxy;
 
     // Null for SDK < S
@@ -120,8 +119,7 @@
 
     // We will register broadcast receiver on main thread to ensure not missing changes on
     // TARGET_OVERLAY_PACKAGE and ACTION_OVERLAY_CHANGED.
-    private final SimpleBroadcastReceiver mReceiver =
-            new SimpleBroadcastReceiver(MAIN_EXECUTOR, this::onIntent);
+    private final SimpleBroadcastReceiver mReceiver;
 
     private Info mInfo;
     private boolean mDestroyed = false;
@@ -131,7 +129,6 @@
             WindowManagerProxy wmProxy,
             LauncherPrefs prefs,
             DaggerSingletonTracker lifecycle) {
-        mContext = context;
         mWMProxy = wmProxy;
 
         if (enableTaskbarPinning()) {
@@ -155,11 +152,12 @@
 
         Display display = context.getSystemService(DisplayManager.class)
                 .getDisplay(DEFAULT_DISPLAY);
-        mWindowContext = mContext.createWindowContext(display, TYPE_APPLICATION, null);
+        mWindowContext = context.createWindowContext(display, TYPE_APPLICATION, null);
         mWindowContext.registerComponentCallbacks(this);
 
         // Initialize navigation mode change listener
-        mReceiver.registerPkgActions(mContext, TARGET_OVERLAY_PACKAGE, ACTION_OVERLAY_CHANGED);
+        mReceiver = new SimpleBroadcastReceiver(context, MAIN_EXECUTOR, this::onIntent);
+        mReceiver.registerPkgActions(TARGET_OVERLAY_PACKAGE, ACTION_OVERLAY_CHANGED);
 
         mInfo = new Info(mWindowContext, wmProxy,
                 wmProxy.estimateInternalDisplayBounds(mWindowContext));
@@ -169,7 +167,7 @@
         lifecycle.addCloseable(() -> {
             mDestroyed = true;
             mWindowContext.unregisterComponentCallbacks(this);
-            mReceiver.unregisterReceiverSafely(mContext);
+            mReceiver.unregisterReceiverSafely();
             wmProxy.unregisterDesktopVisibilityListener(this);
         });
     }
diff --git a/src/com/android/launcher3/util/LockedUserState.kt b/src/com/android/launcher3/util/LockedUserState.kt
index a6a6ceb..742a327 100644
--- a/src/com/android/launcher3/util/LockedUserState.kt
+++ b/src/com/android/launcher3/util/LockedUserState.kt
@@ -44,7 +44,7 @@
 
     @VisibleForTesting
     val userUnlockedReceiver =
-        SimpleBroadcastReceiver(UI_HELPER_EXECUTOR) {
+        SimpleBroadcastReceiver(context, UI_HELPER_EXECUTOR) {
             if (Intent.ACTION_USER_UNLOCKED == it.action) {
                 isUserUnlocked = true
             }
@@ -61,7 +61,6 @@
         isUserUnlockedAtLauncherStartup = isUserUnlocked
         if (!isUserUnlocked) {
             userUnlockedReceiver.register(
-                context,
                 {
                     // If user is unlocked while registering broadcast receiver, we should update
                     // [isUserUnlocked], which will call [notifyUserUnlocked] in setter
@@ -72,7 +71,7 @@
                 Intent.ACTION_USER_UNLOCKED,
             )
         }
-        lifeCycle.addCloseable { userUnlockedReceiver.unregisterReceiverSafely(context) }
+        lifeCycle.addCloseable { userUnlockedReceiver.unregisterReceiverSafely() }
     }
 
     private fun checkIsUserUnlocked() =
@@ -80,7 +79,7 @@
 
     private fun notifyUserUnlocked() {
         mUserUnlockedActions.executeAllAndDestroy()
-        userUnlockedReceiver.unregisterReceiverSafely(context)
+        userUnlockedReceiver.unregisterReceiverSafely()
     }
 
     /**
diff --git a/src/com/android/launcher3/util/ScreenOnTracker.java b/src/com/android/launcher3/util/ScreenOnTracker.java
index 50be98b..8ffe9ea 100644
--- a/src/com/android/launcher3/util/ScreenOnTracker.java
+++ b/src/com/android/launcher3/util/ScreenOnTracker.java
@@ -46,34 +46,31 @@
     private final SimpleBroadcastReceiver mReceiver;
     private final CopyOnWriteArrayList<ScreenOnListener> mListeners = new CopyOnWriteArrayList<>();
 
-    private final Context mContext;
     private boolean mIsScreenOn;
 
     @Inject
     ScreenOnTracker(@ApplicationContext Context context, DaggerSingletonTracker tracker) {
         // Assume that the screen is on to begin with
-        mContext = context;
-        mReceiver = new SimpleBroadcastReceiver(UI_HELPER_EXECUTOR, this::onReceive);
+        mReceiver = new SimpleBroadcastReceiver(context, UI_HELPER_EXECUTOR, this::onReceive);
         init(tracker);
     }
 
     @VisibleForTesting
     ScreenOnTracker(@ApplicationContext Context context, SimpleBroadcastReceiver receiver,
             DaggerSingletonTracker tracker) {
-        mContext = context;
         mReceiver = receiver;
         init(tracker);
     }
 
     private void init(DaggerSingletonTracker tracker) {
         mIsScreenOn = true;
-        mReceiver.register(mContext, ACTION_SCREEN_ON, ACTION_SCREEN_OFF, ACTION_USER_PRESENT);
+        mReceiver.register(ACTION_SCREEN_ON, ACTION_SCREEN_OFF, ACTION_USER_PRESENT);
         tracker.addCloseable(this);
     }
 
     @Override
     public void close() {
-        mReceiver.unregisterReceiverSafely(mContext);
+        mReceiver.unregisterReceiverSafely();
     }
 
     @VisibleForTesting
diff --git a/src/com/android/launcher3/util/SimpleBroadcastReceiver.java b/src/com/android/launcher3/util/SimpleBroadcastReceiver.java
index 539a7cb..7a40abe 100644
--- a/src/com/android/launcher3/util/SimpleBroadcastReceiver.java
+++ b/src/com/android/launcher3/util/SimpleBroadcastReceiver.java
@@ -25,22 +25,29 @@
 import android.text.TextUtils;
 
 import androidx.annotation.AnyThread;
+import androidx.annotation.NonNull;
 import androidx.annotation.Nullable;
 
 import java.util.function.Consumer;
 
 public class SimpleBroadcastReceiver extends BroadcastReceiver {
+    public static final String TAG = "SimpleBroadcastReceiver";
+    // Keeps a strong reference to the context.
+    private final Context mContext;
 
     private final Consumer<Intent> mIntentConsumer;
 
     // Handler to register/unregister broadcast receiver
     private final Handler mHandler;
 
-    public SimpleBroadcastReceiver(LooperExecutor looperExecutor, Consumer<Intent> intentConsumer) {
-        this(looperExecutor.getHandler(), intentConsumer);
+    public SimpleBroadcastReceiver(@NonNull Context context, LooperExecutor looperExecutor,
+            Consumer<Intent> intentConsumer) {
+        this(context, looperExecutor.getHandler(), intentConsumer);
     }
 
-    public SimpleBroadcastReceiver(Handler handler, Consumer<Intent> intentConsumer) {
+    public SimpleBroadcastReceiver(@NonNull Context context, Handler handler,
+            Consumer<Intent> intentConsumer) {
+        mContext = context;
         mIntentConsumer = intentConsumer;
         mHandler = handler;
     }
@@ -50,18 +57,18 @@
         mIntentConsumer.accept(intent);
     }
 
-    /** Calls {@link #register(Context, Runnable, String...)} with null completionCallback. */
+    /** Calls {@link #register(Runnable, String...)} with null completionCallback. */
     @AnyThread
-    public void register(Context context, String... actions) {
-        register(context, null, actions);
+    public void register(String... actions) {
+        register(null, actions);
     }
 
     /**
-     * Calls {@link #register(Context, Runnable, int, String...)} with null completionCallback.
+     * Calls {@link #register(Runnable, int, String...)} with null completionCallback.
      */
     @AnyThread
-    public void register(Context context, int flags, String... actions) {
-        register(context, null, flags, actions);
+    public void register(int flags, String... actions) {
+        register(null, flags, actions);
     }
 
     /**
@@ -74,19 +81,18 @@
      *                           while registerReceiver() is executed on a binder call.
      */
     @AnyThread
-    public void register(
-            Context context, @Nullable Runnable completionCallback, String... actions) {
+    public void register(@Nullable Runnable completionCallback, String... actions) {
         if (Looper.myLooper() == mHandler.getLooper()) {
-            registerInternal(context, completionCallback, actions);
+            registerInternal(mContext, completionCallback, actions);
         } else {
-            mHandler.post(() -> registerInternal(context, completionCallback, actions));
+            mHandler.post(() -> registerInternal(mContext, completionCallback, actions));
         }
     }
 
     /** Register broadcast receiver and run completion callback if passed. */
     @AnyThread
     private void registerInternal(
-            Context context, @Nullable Runnable completionCallback, String... actions) {
+            @NonNull Context context, @Nullable Runnable completionCallback, String... actions) {
         context.registerReceiver(this, getFilter(actions));
         if (completionCallback != null) {
             completionCallback.run();
@@ -94,37 +100,37 @@
     }
 
     /**
-     * Same as {@link #register(Context, Runnable, String...)} above but with additional flags
-     * params.
+     * Same as {@link #register(Runnable, String...)} above but with additional flags
+     * params, using the {@link Context} passed to the constructor.
      */
     @AnyThread
-    public void register(
-            Context context, @Nullable Runnable completionCallback, int flags, String... actions) {
+    public void register(@Nullable Runnable completionCallback, int flags, String... actions) {
         if (Looper.myLooper() == mHandler.getLooper()) {
-            registerInternal(context, completionCallback, flags, actions);
+            registerInternal(mContext, completionCallback, flags, actions);
         } else {
-            mHandler.post(() -> registerInternal(context, completionCallback, flags, actions));
+            mHandler.post(() -> registerInternal(mContext, completionCallback, flags, actions));
         }
     }
 
     /** Register broadcast receiver and run completion callback if passed. */
     @AnyThread
     private void registerInternal(
-            Context context, @Nullable Runnable completionCallback, int flags, String... actions) {
+            @NonNull Context context, @Nullable Runnable completionCallback, int flags,
+            String... actions) {
         context.registerReceiver(this, getFilter(actions), flags);
         if (completionCallback != null) {
             completionCallback.run();
         }
     }
 
-    /** Same as {@link #register(Context, Runnable, String...)} above but with pkg name. */
+    /** Same as {@link #register(Runnable, String...)} above but with pkg name. */
     @AnyThread
-    public void registerPkgActions(Context context, @Nullable String pkg, String... actions) {
+    public void registerPkgActions(@Nullable String pkg, String... actions) {
         if (Looper.myLooper() == mHandler.getLooper()) {
-            context.registerReceiver(this, getPackageFilter(pkg, actions));
+            mContext.registerReceiver(this, getPackageFilter(pkg, actions));
         } else {
             mHandler.post(() -> {
-                context.registerReceiver(this, getPackageFilter(pkg, actions));
+                mContext.registerReceiver(this, getPackageFilter(pkg, actions));
             });
         }
     }
@@ -135,19 +141,19 @@
      * unregister happens on {@link #mHandler}'s looper.
      */
     @AnyThread
-    public void unregisterReceiverSafely(Context context) {
+    public void unregisterReceiverSafely() {
         if (Looper.myLooper() == mHandler.getLooper()) {
-            unregisterReceiverSafelyInternal(context);
+            unregisterReceiverSafelyInternal(mContext);
         } else {
             mHandler.post(() -> {
-                unregisterReceiverSafelyInternal(context);
+                unregisterReceiverSafelyInternal(mContext);
             });
         }
     }
 
     /** Unregister broadcast receiver ignoring any errors. */
     @AnyThread
-    private void unregisterReceiverSafelyInternal(Context context) {
+    private void unregisterReceiverSafelyInternal(@NonNull Context context) {
         try {
             context.unregisterReceiver(this);
         } catch (IllegalArgumentException e) {
diff --git a/src/com/android/launcher3/util/WallpaperOffsetInterpolator.java b/src/com/android/launcher3/util/WallpaperOffsetInterpolator.java
index f8cbe0d..26a04a5 100644
--- a/src/com/android/launcher3/util/WallpaperOffsetInterpolator.java
+++ b/src/com/android/launcher3/util/WallpaperOffsetInterpolator.java
@@ -31,8 +31,7 @@
     // Don't use all the wallpaper for parallax until you have at least this many pages
     private static final int MIN_PARALLAX_PAGE_SPAN = 4;
 
-    private final SimpleBroadcastReceiver mWallpaperChangeReceiver =
-            new SimpleBroadcastReceiver(UI_HELPER_EXECUTOR, i -> onWallpaperChanged());
+    private final SimpleBroadcastReceiver mWallpaperChangeReceiver;
     private final Workspace<?> mWorkspace;
     private final boolean mIsRtl;
     private final Handler mHandler;
@@ -46,6 +45,8 @@
 
     public WallpaperOffsetInterpolator(Workspace<?> workspace) {
         mWorkspace = workspace;
+        mWallpaperChangeReceiver = new SimpleBroadcastReceiver(
+                workspace.getContext(), UI_HELPER_EXECUTOR, i -> onWallpaperChanged());
         mIsRtl = Utilities.isRtl(workspace.getResources());
         mHandler = new OffsetHandler(workspace.getContext());
     }
@@ -198,11 +199,10 @@
     public void setWindowToken(IBinder token) {
         mWindowToken = token;
         if (mWindowToken == null && mRegistered) {
-            mWallpaperChangeReceiver.unregisterReceiverSafely(mWorkspace.getContext());
+            mWallpaperChangeReceiver.unregisterReceiverSafely();
             mRegistered = false;
         } else if (mWindowToken != null && !mRegistered) {
-            mWallpaperChangeReceiver.register(
-                    mWorkspace.getContext(), ACTION_WALLPAPER_CHANGED);
+            mWallpaperChangeReceiver.register(ACTION_WALLPAPER_CHANGED);
             onWallpaperChanged();
             mRegistered = true;
         }