11/ Update MultiStateCallbacks to support multiple callbacks

- Allow multiple callbacks to be set for the same state
- Expose method to set state on ui thread directly
- Ensure callbacks are made immediately if the state is already set
- Clarify that the one shot callbacks vs the state listeners

Bug: 141886704
Change-Id: I8ea0dcd2821ee18d071706eaddeb2852afa13f30
diff --git a/quickstep/recents_ui_overrides/src/com/android/quickstep/BaseSwipeUpHandler.java b/quickstep/recents_ui_overrides/src/com/android/quickstep/BaseSwipeUpHandler.java
index 42d0a0c..a16c365 100644
--- a/quickstep/recents_ui_overrides/src/com/android/quickstep/BaseSwipeUpHandler.java
+++ b/quickstep/recents_ui_overrides/src/com/android/quickstep/BaseSwipeUpHandler.java
@@ -130,7 +130,6 @@
 
     protected Runnable mGestureEndCallback;
 
-    protected final Handler mMainThreadHandler = MAIN_EXECUTOR.getHandler();
     protected MultiStateCallback mStateCallback;
 
     protected boolean mCanceled;
@@ -157,14 +156,6 @@
                 .getDeviceProfile(mContext));
     }
 
-    protected void setStateOnUiThread(int stateFlag) {
-        if (Looper.myLooper() == mMainThreadHandler.getLooper()) {
-            mStateCallback.setState(stateFlag);
-        } else {
-            postAsyncCallback(mMainThreadHandler, () -> mStateCallback.setState(stateFlag));
-        }
-    }
-
     protected void performHapticFeedback() {
         if (!mVibrator.hasVibrator()) {
             return;
@@ -253,9 +244,9 @@
                                     } else {
                                         mActivityInterface.onLaunchTaskSuccess(mActivity);
                                     }
-                                }, mMainThreadHandler);
+                                }, MAIN_EXECUTOR.getHandler());
                     }
-                    setStateOnUiThread(successStateFlag);
+                    mStateCallback.setStateOnUiThread(successStateFlag);
                 }
                 mCanceled = false;
                 mFinishingRecentsAnimationForNewTaskId = -1;
diff --git a/quickstep/recents_ui_overrides/src/com/android/quickstep/WindowTransformSwipeHandler.java b/quickstep/recents_ui_overrides/src/com/android/quickstep/WindowTransformSwipeHandler.java
index 77ebc40..da933b0 100644
--- a/quickstep/recents_ui_overrides/src/com/android/quickstep/WindowTransformSwipeHandler.java
+++ b/quickstep/recents_ui_overrides/src/com/android/quickstep/WindowTransformSwipeHandler.java
@@ -211,58 +211,57 @@
     private void initStateCallbacks() {
         mStateCallback = new MultiStateCallback(STATE_NAMES);
 
-        mStateCallback.addCallback(STATE_LAUNCHER_PRESENT | STATE_GESTURE_STARTED,
+        mStateCallback.runOnceAtState(STATE_LAUNCHER_PRESENT | STATE_GESTURE_STARTED,
                 this::onLauncherPresentAndGestureStarted);
 
-        mStateCallback.addCallback(STATE_LAUNCHER_DRAWN | STATE_GESTURE_STARTED,
+        mStateCallback.runOnceAtState(STATE_LAUNCHER_DRAWN | STATE_GESTURE_STARTED,
                 this::initializeLauncherAnimationController);
 
-        mStateCallback.addCallback(STATE_LAUNCHER_PRESENT | STATE_LAUNCHER_DRAWN,
+        mStateCallback.runOnceAtState(STATE_LAUNCHER_PRESENT | STATE_LAUNCHER_DRAWN,
                 this::launcherFrameDrawn);
 
-        mStateCallback.addCallback(STATE_LAUNCHER_PRESENT | STATE_LAUNCHER_STARTED
+        mStateCallback.runOnceAtState(STATE_LAUNCHER_PRESENT | STATE_LAUNCHER_STARTED
                         | STATE_GESTURE_CANCELLED,
                 this::resetStateForAnimationCancel);
 
-        mStateCallback.addCallback(STATE_LAUNCHER_STARTED | STATE_APP_CONTROLLER_RECEIVED,
+        mStateCallback.runOnceAtState(STATE_LAUNCHER_STARTED | STATE_APP_CONTROLLER_RECEIVED,
                 this::sendRemoteAnimationsToAnimationFactory);
 
-        mStateCallback.addCallback(STATE_RESUME_LAST_TASK | STATE_APP_CONTROLLER_RECEIVED,
+        mStateCallback.runOnceAtState(STATE_RESUME_LAST_TASK | STATE_APP_CONTROLLER_RECEIVED,
                 this::resumeLastTask);
-        mStateCallback.addCallback(STATE_START_NEW_TASK | STATE_SCREENSHOT_CAPTURED,
+        mStateCallback.runOnceAtState(STATE_START_NEW_TASK | STATE_SCREENSHOT_CAPTURED,
                 this::startNewTask);
 
-        mStateCallback.addCallback(STATE_LAUNCHER_PRESENT | STATE_APP_CONTROLLER_RECEIVED
+        mStateCallback.runOnceAtState(STATE_LAUNCHER_PRESENT | STATE_APP_CONTROLLER_RECEIVED
                         | STATE_LAUNCHER_DRAWN | STATE_CAPTURE_SCREENSHOT,
                 this::switchToScreenshot);
 
-        mStateCallback.addCallback(STATE_SCREENSHOT_CAPTURED | STATE_GESTURE_COMPLETED
+        mStateCallback.runOnceAtState(STATE_SCREENSHOT_CAPTURED | STATE_GESTURE_COMPLETED
                         | STATE_SCALED_CONTROLLER_RECENTS,
                 this::finishCurrentTransitionToRecents);
 
-        mStateCallback.addCallback(STATE_SCREENSHOT_CAPTURED | STATE_GESTURE_COMPLETED
+        mStateCallback.runOnceAtState(STATE_SCREENSHOT_CAPTURED | STATE_GESTURE_COMPLETED
                         | STATE_SCALED_CONTROLLER_HOME,
                 this::finishCurrentTransitionToHome);
-        mStateCallback.addCallback(STATE_SCALED_CONTROLLER_HOME | STATE_CURRENT_TASK_FINISHED,
+        mStateCallback.runOnceAtState(STATE_SCALED_CONTROLLER_HOME | STATE_CURRENT_TASK_FINISHED,
                 this::reset);
 
-        mStateCallback.addCallback(STATE_LAUNCHER_PRESENT | STATE_APP_CONTROLLER_RECEIVED
+        mStateCallback.runOnceAtState(STATE_LAUNCHER_PRESENT | STATE_APP_CONTROLLER_RECEIVED
                         | STATE_LAUNCHER_DRAWN | STATE_SCALED_CONTROLLER_RECENTS
                         | STATE_CURRENT_TASK_FINISHED | STATE_GESTURE_COMPLETED
                         | STATE_GESTURE_STARTED,
                 this::setupLauncherUiAfterSwipeUpToRecentsAnimation);
 
-        mGestureState.addCallback(STATE_END_TARGET_ANIMATION_FINISHED,
-                this::onEndTargetSet);
+        mGestureState.runOnceAtState(STATE_END_TARGET_ANIMATION_FINISHED, this::onEndTargetSet);
 
-        mStateCallback.addCallback(STATE_HANDLER_INVALIDATED, this::invalidateHandler);
-        mStateCallback.addCallback(STATE_LAUNCHER_PRESENT | STATE_HANDLER_INVALIDATED,
+        mStateCallback.runOnceAtState(STATE_HANDLER_INVALIDATED, this::invalidateHandler);
+        mStateCallback.runOnceAtState(STATE_LAUNCHER_PRESENT | STATE_HANDLER_INVALIDATED,
                 this::invalidateHandlerWithLauncher);
-        mStateCallback.addCallback(STATE_HANDLER_INVALIDATED | STATE_RESUME_LAST_TASK,
+        mStateCallback.runOnceAtState(STATE_HANDLER_INVALIDATED | STATE_RESUME_LAST_TASK,
                 this::notifyTransitionCancelled);
 
         if (!ENABLE_QUICKSTEP_LIVE_TILE.get()) {
-            mStateCallback.addChangeHandler(STATE_APP_CONTROLLER_RECEIVED | STATE_LAUNCHER_PRESENT
+            mStateCallback.addChangeListener(STATE_APP_CONTROLLER_RECEIVED | STATE_LAUNCHER_PRESENT
                             | STATE_SCREENSHOT_VIEW_SHOWN | STATE_CAPTURE_SCREENSHOT,
                     (b) -> mRecentsView.setRunningTaskHidden(!b));
         }
@@ -330,7 +329,7 @@
                 // Launcher is visible, but might be about to stop. Thus, if we prepare recents
                 // now, it might get overridden by moveToRestState() in onStop(). To avoid this,
                 // wait until the next gesture (and possibly launcher) starts.
-                mStateCallback.addCallback(STATE_GESTURE_STARTED, initAnimFactory);
+                mStateCallback.runOnceAtState(STATE_GESTURE_STARTED, initAnimFactory);
             } else {
                 initAnimFactory.run();
             }
@@ -596,9 +595,9 @@
         super.onRecentsAnimationStart(controller, targets);
 
         // Only add the callback to enable the input consumer after we actually have the controller
-        mStateCallback.addCallback(STATE_APP_CONTROLLER_RECEIVED | STATE_GESTURE_STARTED,
+        mStateCallback.runOnceAtState(STATE_APP_CONTROLLER_RECEIVED | STATE_GESTURE_STARTED,
                 mRecentsAnimationController::enableInputConsumer);
-        setStateOnUiThread(STATE_APP_CONTROLLER_RECEIVED);
+        mStateCallback.setStateOnUiThread(STATE_APP_CONTROLLER_RECEIVED);
 
         mPassedOverviewThreshold = false;
     }
@@ -608,7 +607,7 @@
         super.onRecentsAnimationCanceled(thumbnailData);
         mRecentsView.setRecentsAnimationTargets(null, null);
         mActivityInitListener.unregister();
-        setStateOnUiThread(STATE_GESTURE_CANCELLED | STATE_HANDLER_INVALIDATED);
+        mStateCallback.setStateOnUiThread(STATE_GESTURE_CANCELLED | STATE_HANDLER_INVALIDATED);
         ActiveGestureLog.INSTANCE.addLog("cancelRecentsAnimation");
     }
 
@@ -616,7 +615,7 @@
     public void onGestureStarted() {
         notifyGestureStartedAsync();
         mShiftAtGestureStart = mCurrentShift.value;
-        setStateOnUiThread(STATE_GESTURE_STARTED);
+        mStateCallback.setStateOnUiThread(STATE_GESTURE_STARTED);
         mGestureStarted = true;
     }
 
@@ -639,7 +638,7 @@
     @Override
     public void onGestureCancelled() {
         updateDisplacement(0);
-        setStateOnUiThread(STATE_GESTURE_COMPLETED);
+        mStateCallback.setStateOnUiThread(STATE_GESTURE_COMPLETED);
         mLogAction = Touch.SWIPE_NOOP;
         handleNormalGestureEnd(0, false, new PointF(), true /* isCancel */);
     }
@@ -654,7 +653,7 @@
         float flingThreshold = mContext.getResources()
                 .getDimension(R.dimen.quickstep_fling_threshold_velocity);
         boolean isFling = mGestureStarted && Math.abs(endVelocity) > flingThreshold;
-        setStateOnUiThread(STATE_GESTURE_COMPLETED);
+        mStateCallback.setStateOnUiThread(STATE_GESTURE_COMPLETED);
 
         mLogAction = isFling ? Touch.FLING : Touch.SWIPE;
         boolean isVelocityVertical = Math.abs(velocity.y) > Math.abs(velocity.x);
@@ -906,7 +905,7 @@
                         return AnimatorPlaybackController.wrap(new AnimatorSet(), duration);
                     }
                 };
-                mStateCallback.addChangeHandler(STATE_LAUNCHER_PRESENT | STATE_HANDLER_INVALIDATED,
+                mStateCallback.addChangeListener(STATE_LAUNCHER_PRESENT | STATE_HANDLER_INVALIDATED,
                         isPresent -> mRecentsView.startHome());
             }
             RectFSpringAnim windowAnim = createWindowAnimationToHome(start, homeAnimFactory);
@@ -1042,7 +1041,7 @@
     }
 
     private void reset() {
-        setStateOnUiThread(STATE_HANDLER_INVALIDATED);
+        mStateCallback.setStateOnUiThread(STATE_HANDLER_INVALIDATED);
     }
 
     /**
@@ -1122,10 +1121,10 @@
                 }
                 mRecentsView.updateThumbnail(mRunningTaskId, mTaskSnapshot, false /* refreshNow */);
             }
-            setStateOnUiThread(STATE_SCREENSHOT_CAPTURED);
+            mStateCallback.setStateOnUiThread(STATE_SCREENSHOT_CAPTURED);
         } else if (!hasTargets()) {
             // If there are no targets, then we don't need to capture anything
-            setStateOnUiThread(STATE_SCREENSHOT_CAPTURED);
+            mStateCallback.setStateOnUiThread(STATE_SCREENSHOT_CAPTURED);
         } else {
             boolean finishTransitionPosted = false;
             if (mRecentsAnimationController != null) {
@@ -1145,14 +1144,15 @@
                     // Defer finishing the animation until the next launcher frame with the
                     // new thumbnail
                     finishTransitionPosted = ViewUtils.postDraw(taskView,
-                            () -> setStateOnUiThread(STATE_SCREENSHOT_CAPTURED), this::isCanceled);
+                            () -> mStateCallback.setStateOnUiThread(STATE_SCREENSHOT_CAPTURED),
+                                    this::isCanceled);
                 }
             }
             if (!finishTransitionPosted) {
                 // If we haven't posted a draw callback, set the state immediately.
                 Object traceToken = TraceHelper.INSTANCE.beginSection(SCREENSHOT_CAPTURED_EVT,
                         TraceHelper.FLAG_CHECK_FOR_RACE_CONDITIONS);
-                setStateOnUiThread(STATE_SCREENSHOT_CAPTURED);
+                mStateCallback.setStateOnUiThread(STATE_SCREENSHOT_CAPTURED);
                 TraceHelper.INSTANCE.endSection(traceToken);
             }
         }
@@ -1160,14 +1160,14 @@
 
     private void finishCurrentTransitionToRecents() {
         if (ENABLE_QUICKSTEP_LIVE_TILE.get()) {
-            setStateOnUiThread(STATE_CURRENT_TASK_FINISHED);
+            mStateCallback.setStateOnUiThread(STATE_CURRENT_TASK_FINISHED);
         } else if (!hasTargets()) {
             // If there are no targets, then there is nothing to finish
-            setStateOnUiThread(STATE_CURRENT_TASK_FINISHED);
+            mStateCallback.setStateOnUiThread(STATE_CURRENT_TASK_FINISHED);
         } else {
             synchronized (mRecentsAnimationController) {
                 mRecentsAnimationController.finish(true /* toRecents */,
-                        () -> setStateOnUiThread(STATE_CURRENT_TASK_FINISHED));
+                        () -> mStateCallback.setStateOnUiThread(STATE_CURRENT_TASK_FINISHED));
             }
         }
         ActiveGestureLog.INSTANCE.addLog("finishRecentsAnimation", true);
@@ -1176,7 +1176,7 @@
     private void finishCurrentTransitionToHome() {
         synchronized (mRecentsAnimationController) {
             mRecentsAnimationController.finish(true /* toRecents */,
-                    () -> setStateOnUiThread(STATE_CURRENT_TASK_FINISHED),
+                    () -> mStateCallback.setStateOnUiThread(STATE_CURRENT_TASK_FINISHED),
                     true /* sendUserLeaveHint */);
         }
         ActiveGestureLog.INSTANCE.addLog("finishRecentsAnimation", true);
diff --git a/quickstep/recents_ui_overrides/src/com/android/quickstep/inputconsumers/DeviceLockedInputConsumer.java b/quickstep/recents_ui_overrides/src/com/android/quickstep/inputconsumers/DeviceLockedInputConsumer.java
index 8fb2e2a..370f161 100644
--- a/quickstep/recents_ui_overrides/src/com/android/quickstep/inputconsumers/DeviceLockedInputConsumer.java
+++ b/quickstep/recents_ui_overrides/src/com/android/quickstep/inputconsumers/DeviceLockedInputConsumer.java
@@ -113,7 +113,7 @@
 
         // Init states
         mStateCallback = new MultiStateCallback(STATE_NAMES);
-        mStateCallback.addCallback(STATE_TARGET_RECEIVED | STATE_HANDLER_INVALIDATED,
+        mStateCallback.runOnceAtState(STATE_TARGET_RECEIVED | STATE_HANDLER_INVALIDATED,
                 this::endRemoteAnimation);
 
         mVelocityTracker = VelocityTracker.obtain();
diff --git a/quickstep/recents_ui_overrides/src/com/android/quickstep/inputconsumers/FallbackNoButtonInputConsumer.java b/quickstep/recents_ui_overrides/src/com/android/quickstep/inputconsumers/FallbackNoButtonInputConsumer.java
index 5b76ba5..e062fc1 100644
--- a/quickstep/recents_ui_overrides/src/com/android/quickstep/inputconsumers/FallbackNoButtonInputConsumer.java
+++ b/quickstep/recents_ui_overrides/src/com/android/quickstep/inputconsumers/FallbackNoButtonInputConsumer.java
@@ -145,20 +145,20 @@
     private void initStateCallbacks() {
         mStateCallback = new MultiStateCallback(STATE_NAMES);
 
-        mStateCallback.addCallback(STATE_HANDLER_INVALIDATED,
+        mStateCallback.runOnceAtState(STATE_HANDLER_INVALIDATED,
                 this::onHandlerInvalidated);
-        mStateCallback.addCallback(STATE_RECENTS_PRESENT | STATE_HANDLER_INVALIDATED,
+        mStateCallback.runOnceAtState(STATE_RECENTS_PRESENT | STATE_HANDLER_INVALIDATED,
                 this::onHandlerInvalidatedWithRecents);
 
-        mStateCallback.addCallback(STATE_GESTURE_CANCELLED | STATE_APP_CONTROLLER_RECEIVED,
+        mStateCallback.runOnceAtState(STATE_GESTURE_CANCELLED | STATE_APP_CONTROLLER_RECEIVED,
                 this::finishAnimationTargetSetAnimationComplete);
 
         if (mInQuickSwitchMode) {
-            mStateCallback.addCallback(STATE_GESTURE_COMPLETED | STATE_APP_CONTROLLER_RECEIVED
+            mStateCallback.runOnceAtState(STATE_GESTURE_COMPLETED | STATE_APP_CONTROLLER_RECEIVED
                             | STATE_RECENTS_PRESENT,
                     this::finishAnimationTargetSet);
         } else {
-            mStateCallback.addCallback(STATE_GESTURE_COMPLETED | STATE_APP_CONTROLLER_RECEIVED,
+            mStateCallback.runOnceAtState(STATE_GESTURE_COMPLETED | STATE_APP_CONTROLLER_RECEIVED,
                     this::finishAnimationTargetSet);
         }
     }
@@ -186,7 +186,7 @@
                 mRecentsView.onGestureAnimationStart(mRunningTaskId);
             }
         }
-        setStateOnUiThread(STATE_RECENTS_PRESENT);
+        mStateCallback.setStateOnUiThread(STATE_RECENTS_PRESENT);
         return true;
     }
 
@@ -251,7 +251,7 @@
     public void onGestureCancelled() {
         updateDisplacement(0);
         mGestureState.setEndTarget(LAST_TASK);
-        setStateOnUiThread(STATE_GESTURE_CANCELLED);
+        mStateCallback.setStateOnUiThread(STATE_GESTURE_CANCELLED);
     }
 
     @Override
@@ -275,7 +275,7 @@
                         : LAST_TASK);
             }
         }
-        setStateOnUiThread(STATE_GESTURE_COMPLETED);
+        mStateCallback.setStateOnUiThread(STATE_GESTURE_COMPLETED);
     }
 
     @Override
@@ -302,7 +302,7 @@
                 mRecentsView.setOnScrollChangeListener(null);
             }
         } else {
-            setStateOnUiThread(STATE_HANDLER_INVALIDATED);
+            mStateCallback.setStateOnUiThread(STATE_HANDLER_INVALIDATED);
         }
     }
 
@@ -366,7 +366,7 @@
             }
         }
 
-        setStateOnUiThread(STATE_HANDLER_INVALIDATED);
+        mStateCallback.setStateOnUiThread(STATE_HANDLER_INVALIDATED);
     }
 
     private void finishAnimationTargetSet() {
@@ -436,14 +436,14 @@
         }
         applyTransformUnchecked();
 
-        setStateOnUiThread(STATE_APP_CONTROLLER_RECEIVED);
+        mStateCallback.setStateOnUiThread(STATE_APP_CONTROLLER_RECEIVED);
     }
 
     @Override
     public void onRecentsAnimationCanceled(ThumbnailData thumbnailData) {
         super.onRecentsAnimationCanceled(thumbnailData);
         mRecentsView.setRecentsAnimationTargets(null, null);
-        setStateOnUiThread(STATE_HANDLER_INVALIDATED);
+        mStateCallback.setStateOnUiThread(STATE_HANDLER_INVALIDATED);
     }
 
     /**
diff --git a/quickstep/recents_ui_overrides/src/com/android/quickstep/inputconsumers/ResetGestureInputConsumer.java b/quickstep/recents_ui_overrides/src/com/android/quickstep/inputconsumers/ResetGestureInputConsumer.java
index 5ef5246..d34b40b 100644
--- a/quickstep/recents_ui_overrides/src/com/android/quickstep/inputconsumers/ResetGestureInputConsumer.java
+++ b/quickstep/recents_ui_overrides/src/com/android/quickstep/inputconsumers/ResetGestureInputConsumer.java
@@ -19,7 +19,6 @@
 
 import com.android.quickstep.InputConsumer;
 import com.android.quickstep.TaskAnimationManager;
-import com.android.quickstep.TouchInteractionService;
 
 /**
  * A NO_OP input consumer which also resets any pending gesture
diff --git a/quickstep/src/com/android/quickstep/GestureState.java b/quickstep/src/com/android/quickstep/GestureState.java
index 67eb9de..98ff410 100644
--- a/quickstep/src/com/android/quickstep/GestureState.java
+++ b/quickstep/src/com/android/quickstep/GestureState.java
@@ -129,8 +129,8 @@
     /**
      * Adds a callback for when the states matching the given {@param stateMask} is set.
      */
-    public void addCallback(int stateMask, Runnable callback) {
-        mStateCallback.addCallback(stateMask, callback);
+    public void runOnceAtState(int stateMask, Runnable callback) {
+        mStateCallback.runOnceAtState(stateMask, callback);
     }
 
     /**
@@ -196,7 +196,7 @@
      */
     public void setFinishingRecentsAnimationTaskId(int taskId) {
         mFinishingRecentsAnimationTaskId = taskId;
-        mStateCallback.addCallback(STATE_RECENTS_ANIMATION_FINISHED, () -> {
+        mStateCallback.runOnceAtState(STATE_RECENTS_ANIMATION_FINISHED, () -> {
             mFinishingRecentsAnimationTaskId = -1;
         });
     }
diff --git a/quickstep/src/com/android/quickstep/MultiStateCallback.java b/quickstep/src/com/android/quickstep/MultiStateCallback.java
index 357c9fc..6c65e01 100644
--- a/quickstep/src/com/android/quickstep/MultiStateCallback.java
+++ b/quickstep/src/com/android/quickstep/MultiStateCallback.java
@@ -15,11 +15,17 @@
  */
 package com.android.quickstep;
 
+import static com.android.launcher3.Utilities.postAsyncCallback;
+import static com.android.launcher3.util.Executors.MAIN_EXECUTOR;
+
+import android.os.Looper;
 import android.util.Log;
 import android.util.SparseArray;
 
 import com.android.launcher3.config.FeatureFlags;
 
+import java.util.ArrayList;
+import java.util.LinkedList;
 import java.util.StringJoiner;
 import java.util.function.Consumer;
 
@@ -31,16 +37,29 @@
     private static final String TAG = "MultiStateCallback";
     public static final boolean DEBUG_STATES = false;
 
-    private final SparseArray<Runnable> mCallbacks = new SparseArray<>();
-    private final SparseArray<Consumer<Boolean>> mStateChangeHandlers = new SparseArray<>();
+    private final SparseArray<LinkedList<Runnable>> mCallbacks = new SparseArray<>();
+    private final SparseArray<ArrayList<Consumer<Boolean>>> mStateChangeListeners =
+            new SparseArray<>();
 
     private final String[] mStateNames;
 
+    private int mState = 0;
+
     public MultiStateCallback(String[] stateNames) {
         mStateNames = DEBUG_STATES ? stateNames : null;
     }
 
-    private int mState = 0;
+    /**
+     * Adds the provided state flags to the global state on the UI thread and executes any callbacks
+     * as a result.
+     */
+    public void setStateOnUiThread(int stateFlag) {
+        if (Looper.myLooper() == Looper.getMainLooper()) {
+            setState(stateFlag);
+        } else {
+            postAsyncCallback(MAIN_EXECUTOR.getHandler(), () -> setState(stateFlag));
+        }
+    }
 
     /**
      * Adds the provided state flags to the global state and executes any callbacks as a result.
@@ -51,7 +70,7 @@
                     + convertToFlagNames(stateFlag) + " to " + convertToFlagNames(mState));
         }
 
-        int oldState = mState;
+        final int oldState = mState;
         mState = mState | stateFlag;
 
         int count = mCallbacks.size();
@@ -59,15 +78,13 @@
             int state = mCallbacks.keyAt(i);
 
             if ((mState & state) == state) {
-                Runnable callback = mCallbacks.valueAt(i);
-                if (callback != null) {
-                    // Set the callback to null, so that it does not run again.
-                    mCallbacks.setValueAt(i, null);
-                    callback.run();
+                LinkedList<Runnable> callbacks = mCallbacks.valueAt(i);
+                while (!callbacks.isEmpty()) {
+                    callbacks.pollFirst().run();
                 }
             }
         }
-        notifyStateChangeHandlers(oldState);
+        notifyStateChangeListeners(oldState);
     }
 
     /**
@@ -82,38 +99,61 @@
 
         int oldState = mState;
         mState = mState & ~stateFlag;
-        notifyStateChangeHandlers(oldState);
+        notifyStateChangeListeners(oldState);
     }
 
-    private void notifyStateChangeHandlers(int oldState) {
-        int count = mStateChangeHandlers.size();
+    private void notifyStateChangeListeners(int oldState) {
+        int count = mStateChangeListeners.size();
         for (int i = 0; i < count; i++) {
-            int state = mStateChangeHandlers.keyAt(i);
+            int state = mStateChangeListeners.keyAt(i);
             boolean wasOn = (state & oldState) == state;
             boolean isOn = (state & mState) == state;
 
             if (wasOn != isOn) {
-                mStateChangeHandlers.valueAt(i).accept(isOn);
+                ArrayList<Consumer<Boolean>> listeners = mStateChangeListeners.valueAt(i);
+                for (Consumer<Boolean> listener : listeners) {
+                    listener.accept(isOn);
+                }
             }
         }
     }
 
     /**
-     * Sets the callbacks to be run when the provided states are enabled.
-     * The callback is only run once.
+     * Sets a callback to be run when the provided states in the given {@param stateMask} is
+     * enabled. The callback is only run *once*, and if the states are already set at the time of
+     * this call then the callback will be made immediately.
      */
-    public void addCallback(int stateMask, Runnable callback) {
-        if (FeatureFlags.IS_DOGFOOD_BUILD && mCallbacks.get(stateMask) != null) {
-            throw new IllegalStateException("Multiple callbacks on same state");
+    public void runOnceAtState(int stateMask, Runnable callback) {
+        if ((mState & stateMask) == stateMask) {
+            callback.run();
+        } else {
+            final LinkedList<Runnable> callbacks;
+            if (mCallbacks.indexOfKey(stateMask) >= 0) {
+                callbacks = mCallbacks.get(stateMask);
+                if (FeatureFlags.IS_DOGFOOD_BUILD && callbacks.contains(callback)) {
+                    throw new IllegalStateException("Existing callback for state found");
+                }
+            } else {
+                callbacks = new LinkedList<>();
+                mCallbacks.put(stateMask, callbacks);
+            }
+            callbacks.add(callback);
         }
-        mCallbacks.put(stateMask, callback);
     }
 
     /**
-     * Sets the handler to be called when the provided states are enabled or disabled.
+     * Adds a persistent listener to be called states in the given {@param stateMask} are enabled
+     * or disabled.
      */
-    public void addChangeHandler(int stateMask, Consumer<Boolean> handler) {
-        mStateChangeHandlers.put(stateMask, handler);
+    public void addChangeListener(int stateMask, Consumer<Boolean> listener) {
+        final ArrayList<Consumer<Boolean>> listeners;
+        if (mStateChangeListeners.indexOfKey(stateMask) >= 0) {
+            listeners = mStateChangeListeners.get(stateMask);
+        } else {
+            listeners = new ArrayList<>();
+            mStateChangeListeners.put(stateMask, listeners);
+        }
+        listeners.add(listener);
     }
 
     public int getState() {