diff --git a/src/com/android/launcher3/Launcher.java b/src/com/android/launcher3/Launcher.java
index 6c9d4c0..4e566ab 100644
--- a/src/com/android/launcher3/Launcher.java
+++ b/src/com/android/launcher3/Launcher.java
@@ -312,7 +312,7 @@
 
     private static boolean sIsNewProcess = true;
 
-    private StateManager<LauncherState> mStateManager;
+    private StateManager<LauncherState, Launcher> mStateManager;
 
     private static final int ON_ACTIVITY_RESULT_ANIMATION_DELAY = 500;
 
@@ -2775,7 +2775,7 @@
     }
 
     @Override
-    protected void collectStateHandlers(List<StateHandler> out) {
+    public void collectStateHandlers(List<StateHandler<LauncherState>> out) {
         out.add(getAllAppsController());
         out.add(getWorkspace());
     }
@@ -2996,7 +2996,7 @@
     }
 
     @Override
-    public StateManager<LauncherState> getStateManager() {
+    public StateManager<LauncherState, Launcher> getStateManager() {
         return mStateManager;
     }
 
diff --git a/src/com/android/launcher3/LauncherState.java b/src/com/android/launcher3/LauncherState.java
index d2d56f2..102189b 100644
--- a/src/com/android/launcher3/LauncherState.java
+++ b/src/com/android/launcher3/LauncherState.java
@@ -440,7 +440,7 @@
      */
     public void onBackInvoked(Launcher launcher) {
         if (this != NORMAL) {
-            StateManager<LauncherState> lsm = launcher.getStateManager();
+            StateManager<LauncherState, Launcher> lsm = launcher.getStateManager();
             LauncherState lastState = lsm.getLastState();
             lsm.goToState(lastState, forEndCallback(this::onBackAnimationCompleted));
         }
@@ -461,7 +461,7 @@
      */
     public void onBackProgressed(
             Launcher launcher, @FloatRange(from = 0.0, to = 1.0) float backProgress) {
-        StateManager<LauncherState> lsm = launcher.getStateManager();
+        StateManager<LauncherState, Launcher> lsm = launcher.getStateManager();
         LauncherState toState = lsm.getLastState();
         lsm.onBackProgressed(toState, backProgress);
     }
@@ -471,7 +471,7 @@
      * predictive back gesture.
      */
     public void onBackCancelled(Launcher launcher) {
-        StateManager<LauncherState> lsm = launcher.getStateManager();
+        StateManager<LauncherState, Launcher> lsm = launcher.getStateManager();
         LauncherState toState = lsm.getLastState();
         lsm.onBackCancelled(toState);
     }
diff --git a/src/com/android/launcher3/Workspace.java b/src/com/android/launcher3/Workspace.java
index 584d089..e601a3e 100644
--- a/src/com/android/launcher3/Workspace.java
+++ b/src/com/android/launcher3/Workspace.java
@@ -526,7 +526,7 @@
         }
 
         updateChildrenLayersEnabled();
-        StateManager<LauncherState> stateManager = mLauncher.getStateManager();
+        StateManager<LauncherState, Launcher> stateManager = mLauncher.getStateManager();
         stateManager.addStateListener(new StateManager.StateListener<LauncherState>() {
             @Override
             public void onStateTransitionComplete(LauncherState finalState) {
diff --git a/src/com/android/launcher3/allapps/LauncherAllAppsContainerView.java b/src/com/android/launcher3/allapps/LauncherAllAppsContainerView.java
index 63a168e..1e0f289 100644
--- a/src/com/android/launcher3/allapps/LauncherAllAppsContainerView.java
+++ b/src/com/android/launcher3/allapps/LauncherAllAppsContainerView.java
@@ -56,7 +56,7 @@
             return false;
         }
         Launcher launcher = mActivityContext;
-        StateManager<LauncherState> manager = launcher.getStateManager();
+        StateManager<LauncherState, Launcher> manager = launcher.getStateManager();
         if (manager.isInTransition() && manager.getTargetState() != null) {
             return manager.getTargetState().shouldFloatingSearchBarUsePillWhenUnfocused(launcher);
         }
@@ -70,7 +70,7 @@
             return super.getFloatingSearchBarRestingMarginBottom();
         }
         Launcher launcher = mActivityContext;
-        StateManager<LauncherState> stateManager = launcher.getStateManager();
+        StateManager<LauncherState, Launcher> stateManager = launcher.getStateManager();
 
         // We want to rest at the current state's resting position, unless we are in transition and
         // the target state's resting position is higher (that way if we are closing the keyboard,
@@ -95,7 +95,7 @@
             return super.getFloatingSearchBarRestingMarginStart();
         }
 
-        StateManager<LauncherState> stateManager = mActivityContext.getStateManager();
+        StateManager<LauncherState, Launcher> stateManager = mActivityContext.getStateManager();
 
         // Special case to not expand the search bar when exiting All Apps on phones.
         if (stateManager.getCurrentStableState() == LauncherState.ALL_APPS
@@ -117,7 +117,7 @@
             return super.getFloatingSearchBarRestingMarginEnd();
         }
 
-        StateManager<LauncherState> stateManager = mActivityContext.getStateManager();
+        StateManager<LauncherState, Launcher> stateManager = mActivityContext.getStateManager();
 
         // Special case to not expand the search bar when exiting All Apps on phones.
         if (stateManager.getCurrentStableState() == LauncherState.ALL_APPS
diff --git a/src/com/android/launcher3/statemanager/BaseState.java b/src/com/android/launcher3/statemanager/BaseState.java
index a01d402..b81729a 100644
--- a/src/com/android/launcher3/statemanager/BaseState.java
+++ b/src/com/android/launcher3/statemanager/BaseState.java
@@ -21,7 +21,7 @@
 import com.android.launcher3.views.ActivityContext;
 
 /**
- * Interface representing a state of a StatefulActivity
+ * Interface representing a state of a StatefulContainer
  */
 public interface BaseState<T extends BaseState> {
 
diff --git a/src/com/android/launcher3/statemanager/StateManager.java b/src/com/android/launcher3/statemanager/StateManager.java
index eea1a7d..ac07c0f 100644
--- a/src/com/android/launcher3/statemanager/StateManager.java
+++ b/src/com/android/launcher3/statemanager/StateManager.java
@@ -26,6 +26,7 @@
 import android.animation.Animator.AnimatorListener;
 import android.animation.AnimatorListenerAdapter;
 import android.animation.AnimatorSet;
+import android.content.Context;
 import android.os.Handler;
 import android.os.Looper;
 import android.util.Log;
@@ -47,8 +48,11 @@
 /**
  * Class to manage transitions between different states for a StatefulActivity based on different
  * states
+ * @param <STATE_TYPE> BaseState type used by the state manager
+ * @param <STATEFUL_CONTAINER> container object used to manage state
  */
-public class StateManager<STATE_TYPE extends BaseState<STATE_TYPE>> {
+public class StateManager<STATE_TYPE extends BaseState<STATE_TYPE>,
+        STATEFUL_CONTAINER extends Context & StatefulContainer<STATE_TYPE>> {
 
     public static final String TAG = "StateManager";
     // b/279059025, b/325463989
@@ -56,7 +60,7 @@
 
     private final AnimationState mConfig = new AnimationState();
     private final Handler mUiHandler;
-    private final StatefulActivity<STATE_TYPE> mActivity;
+    private final STATEFUL_CONTAINER mStatefulContainer;
     private final ArrayList<StateListener<STATE_TYPE>> mListeners = new ArrayList<>();
     private final STATE_TYPE mBaseState;
 
@@ -71,12 +75,12 @@
 
     private STATE_TYPE mRestState;
 
-    public StateManager(StatefulActivity<STATE_TYPE> l, STATE_TYPE baseState) {
+    public StateManager(STATEFUL_CONTAINER container, STATE_TYPE baseState) {
         mUiHandler = new Handler(Looper.getMainLooper());
-        mActivity = l;
+        mStatefulContainer = container;
         mBaseState = baseState;
         mState = mLastStableState = mCurrentStableState = baseState;
-        mAtomicAnimationFactory = l.createAtomicAnimationFactory();
+        mAtomicAnimationFactory = container.createAtomicAnimationFactory();
     }
 
     public STATE_TYPE getState() {
@@ -109,10 +113,10 @@
         writer.println(prefix + "\tisInTransition:" + isInTransition());
     }
 
-    public StateHandler[] getStateHandlers() {
+    public StateHandler<STATE_TYPE>[] getStateHandlers() {
         if (mStateHandlers == null) {
-            ArrayList<StateHandler> handlers = new ArrayList<>();
-            mActivity.collectStateHandlers(handlers);
+            ArrayList<StateHandler<STATE_TYPE>> handlers = new ArrayList<>();
+            mStatefulContainer.collectStateHandlers(handlers);
             mStateHandlers = handlers.toArray(new StateHandler[handlers.size()]);
         }
         return mStateHandlers;
@@ -130,7 +134,7 @@
      * Returns true if the state changes should be animated.
      */
     public boolean shouldAnimateStateChange() {
-        return !mActivity.isForceInvisible() && mActivity.isStarted();
+        return mStatefulContainer.shouldAnimateStateChange();
     }
 
     /**
@@ -245,7 +249,7 @@
         }
 
         animated &= areAnimatorsEnabled();
-        if (mActivity.isInState(state)) {
+        if (mStatefulContainer.isInState(state)) {
             if (mConfig.currentAnimation == null) {
                 // Run any queued runnable
                 if (listener != null) {
@@ -302,8 +306,8 @@
         // Since state mBaseState can be reached from multiple states, just assume that the
         // transition plays in reverse and use the same duration as previous state.
         mConfig.duration = state == mBaseState
-                ? fromState.getTransitionDuration(mActivity, false /* isToState */)
-                : state.getTransitionDuration(mActivity, true /* isToState */);
+                ? fromState.getTransitionDuration(mStatefulContainer, false /* isToState */)
+                : state.getTransitionDuration(mStatefulContainer, true /* isToState */);
         prepareForAtomicAnimation(fromState, state, mConfig);
         AnimatorSet animation = createAnimationToNewWorkspaceInternal(state).buildAnim();
         if (listener != null) {
@@ -336,7 +340,7 @@
         PendingAnimation builder = new PendingAnimation(config.duration);
         prepareForAtomicAnimation(fromState, toState, config);
 
-        for (StateHandler handler : mActivity.getStateManager().getStateHandlers()) {
+        for (StateHandler handler : mStatefulContainer.getStateManager().getStateHandlers()) {
             handler.setStateWithAnimation(toState, config, builder);
         }
         return builder.buildAnim();
@@ -402,7 +406,7 @@
 
     private void onStateTransitionStart(STATE_TYPE state) {
         mState = state;
-        mActivity.onStateSetStart(mState);
+        mStatefulContainer.onStateSetStart(mState);
 
         if (DEBUG) {
             Log.d(TAG, "onStateTransitionStart - state: " + state);
@@ -419,7 +423,7 @@
             mCurrentStableState = state;
         }
 
-        mActivity.onStateSetEnd(state);
+        mStatefulContainer.onStateSetEnd(state);
         if (state == mBaseState) {
             setRestState(null);
         }
diff --git a/src/com/android/launcher3/statemanager/StatefulActivity.java b/src/com/android/launcher3/statemanager/StatefulActivity.java
index 30ba703..28f2def 100644
--- a/src/com/android/launcher3/statemanager/StatefulActivity.java
+++ b/src/com/android/launcher3/statemanager/StatefulActivity.java
@@ -18,7 +18,6 @@
 import static android.content.pm.ActivityInfo.CONFIG_ORIENTATION;
 import static android.content.pm.ActivityInfo.CONFIG_SCREEN_SIZE;
 
-import static com.android.launcher3.LauncherState.FLAG_CLOSE_POPUPS;
 import static com.android.launcher3.LauncherState.FLAG_NON_INTERACTIVE;
 
 import android.content.res.Configuration;
@@ -29,11 +28,9 @@
 
 import androidx.annotation.CallSuper;
 
-import com.android.launcher3.AbstractFloatingView;
 import com.android.launcher3.BaseDraggingActivity;
 import com.android.launcher3.LauncherRootView;
 import com.android.launcher3.Utilities;
-import com.android.launcher3.statemanager.StateManager.AtomicAnimationFactory;
 import com.android.launcher3.statemanager.StateManager.StateHandler;
 import com.android.launcher3.util.window.WindowManagerProxy;
 import com.android.launcher3.views.BaseDragLayer;
@@ -45,7 +42,7 @@
  * @param <STATE_TYPE> Type of state object
  */
 public abstract class StatefulActivity<STATE_TYPE extends BaseState<STATE_TYPE>>
-        extends BaseDraggingActivity {
+        extends BaseDraggingActivity implements StatefulContainer<STATE_TYPE> {
 
     public final Handler mHandler = new Handler();
     private final Runnable mHandleDeferredResume = this::handleDeferredResume;
@@ -67,19 +64,9 @@
     /**
      * Create handlers to control the property changes for this activity
      */
-    protected abstract void collectStateHandlers(List<StateHandler> out);
 
-    /**
-     * Returns true if the activity is in the provided state
-     */
-    public boolean isInState(STATE_TYPE state) {
-        return getStateManager().getState() == state;
-    }
-
-    /**
-     * Returns the state manager for this activity
-     */
-    public abstract StateManager<STATE_TYPE> getStateManager();
+    @Override
+    public abstract void collectStateHandlers(List<StateHandler<STATE_TYPE>> out);
 
     protected void inflateRootView(int layoutId) {
         mRootView = (LauncherRootView) LayoutInflater.from(this).inflate(layoutId, null);
@@ -106,22 +93,12 @@
         if (mDeferredResumePending) {
             handleDeferredResume();
         }
-
-        if (state.hasFlag(FLAG_CLOSE_POPUPS)) {
-            AbstractFloatingView.closeAllOpenViews(this, !state.hasFlag(FLAG_NON_INTERACTIVE));
-        }
+        StatefulContainer.super.onStateSetStart(state);
     }
 
-    /**
-     * Called when transition to state ends
-     */
-    public void onStateSetEnd(STATE_TYPE state) { }
-
-    /**
-     * Creates a factory for atomic state animations
-     */
-    public AtomicAnimationFactory<STATE_TYPE> createAtomicAnimationFactory() {
-        return new AtomicAnimationFactory(0);
+    @Override
+    public boolean shouldAnimateStateChange() {
+        return !isForceInvisible() && isStarted();
     }
 
     @Override
diff --git a/src/com/android/launcher3/statemanager/StatefulContainer.java b/src/com/android/launcher3/statemanager/StatefulContainer.java
new file mode 100644
index 0000000..0cf0a27
--- /dev/null
+++ b/src/com/android/launcher3/statemanager/StatefulContainer.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.launcher3.statemanager;
+
+
+import static com.android.launcher3.LauncherState.FLAG_CLOSE_POPUPS;
+import static com.android.launcher3.statemanager.BaseState.FLAG_NON_INTERACTIVE;
+
+import androidx.annotation.CallSuper;
+
+import com.android.launcher3.AbstractFloatingView;
+import com.android.launcher3.views.ActivityContext;
+
+import java.util.List;
+
+/**
+ * Interface for a container that can be managed by a state manager.
+ *
+ * @param <STATE_TYPE> The type of state that the container can be in.
+ */
+public interface StatefulContainer<STATE_TYPE extends BaseState<STATE_TYPE>> extends
+        ActivityContext {
+
+    /**
+     * Creates a factory for atomic state animations
+     */
+    default StateManager.AtomicAnimationFactory<STATE_TYPE> createAtomicAnimationFactory() {
+        return new StateManager.AtomicAnimationFactory<>(0);
+    }
+
+    /**
+     * Collect handlers to control the property changes for this container
+     */
+    void collectStateHandlers(List<StateManager.StateHandler<STATE_TYPE>> out);
+
+    /**
+     * Returns the state manager for this container
+     */
+    StateManager<STATE_TYPE, ?> getStateManager();
+
+    /**
+     * Called when transition to state ends
+     * @param state the state whose transition has ended
+     */
+    default void onStateSetEnd(STATE_TYPE state) { }
+
+    /**
+     * Called when transition to state starts
+     * @param state the state being transitioned to
+     */
+    @CallSuper
+    default void onStateSetStart(STATE_TYPE state) {
+        if (state.hasFlag(FLAG_CLOSE_POPUPS)) {
+            AbstractFloatingView.closeAllOpenViews(this, !state.hasFlag(FLAG_NON_INTERACTIVE));
+        }
+    }
+
+    /**
+     * Returns true if the container is in the provided state
+     * @param state the state to compare against the state manager's current state
+     */
+    default boolean isInState(STATE_TYPE state) {
+        return getStateManager().getState() == state;
+    }
+
+    /**
+     * Returns true if state change should transition with animation
+     */
+    boolean shouldAnimateStateChange();
+}
diff --git a/src/com/android/launcher3/util/CannedAnimationCoordinator.kt b/src/com/android/launcher3/util/CannedAnimationCoordinator.kt
index 85f81f5..749caac 100644
--- a/src/com/android/launcher3/util/CannedAnimationCoordinator.kt
+++ b/src/com/android/launcher3/util/CannedAnimationCoordinator.kt
@@ -113,10 +113,9 @@
         // flags accordingly
         animationController =
             controller.apply {
-                activity.stateManager.setCurrentAnimation(
-                    this,
-                    USER_CONTROLLED or HANDLE_STATE_APPLY
-                )
+                activity
+                    .stateManager
+                    .setCurrentAnimation(this, USER_CONTROLLED or HANDLE_STATE_APPLY)
             }
         recreateAnimation(provider)
     }
