diff --git a/packages/SystemUI/Android.bp b/packages/SystemUI/Android.bp
index 2f5b5f4..9ee160c 100644
--- a/packages/SystemUI/Android.bp
+++ b/packages/SystemUI/Android.bp
@@ -114,6 +114,7 @@
         "androidx.dynamicanimation_dynamicanimation",
         "androidx-constraintlayout_constraintlayout",
         "androidx.exifinterface_exifinterface",
+        "androidx.test.ext.junit",
         "com.google.android.material_material",
         "kotlinx_coroutines_android",
         "kotlinx_coroutines",
diff --git a/packages/SystemUI/src/com/android/systemui/controls/management/ControlsEditingActivity.kt b/packages/SystemUI/src/com/android/systemui/controls/management/ControlsEditingActivity.kt
index 5e8ce6d..b11103a 100644
--- a/packages/SystemUI/src/com/android/systemui/controls/management/ControlsEditingActivity.kt
+++ b/packages/SystemUI/src/com/android/systemui/controls/management/ControlsEditingActivity.kt
@@ -20,11 +20,14 @@
 import android.content.ComponentName
 import android.content.Intent
 import android.os.Bundle
+import android.util.Log
 import android.view.View
 import android.view.ViewGroup
 import android.view.ViewStub
 import android.widget.Button
 import android.widget.TextView
+import android.window.OnBackInvokedCallback
+import android.window.OnBackInvokedDispatcher
 import androidx.activity.ComponentActivity
 import androidx.recyclerview.widget.GridLayoutManager
 import androidx.recyclerview.widget.ItemTouchHelper
@@ -42,7 +45,7 @@
 /**
  * Activity for rearranging and removing controls for a given structure
  */
-class ControlsEditingActivity @Inject constructor(
+open class ControlsEditingActivity @Inject constructor(
     private val controller: ControlsControllerImpl,
     private val broadcastDispatcher: BroadcastDispatcher,
     private val customIconCache: CustomIconCache,
@@ -50,8 +53,9 @@
 ) : ComponentActivity() {
 
     companion object {
+        private const val DEBUG = false
         private const val TAG = "ControlsEditingActivity"
-        private const val EXTRA_STRUCTURE = ControlsFavoritingActivity.EXTRA_STRUCTURE
+        const val EXTRA_STRUCTURE = ControlsFavoritingActivity.EXTRA_STRUCTURE
         private val SUBTITLE_ID = R.string.controls_favorite_rearrange
         private val EMPTY_TEXT_ID = R.string.controls_favorite_removed
     }
@@ -73,6 +77,17 @@
         }
     }
 
+    /**
+     * Back callback registered with the platform [OnBackInvokedDispatcher]; it delegates to
+     * [onBackPressed] so back navigation keeps using this activity's existing back handling.
+     */
+    private val onBackInvokedCallback = OnBackInvokedCallback {
+        if (DEBUG) {
+            Log.d(TAG, "Predictive Back dispatcher called onBackInvokedCallback")
+        }
+        onBackPressed()
+    }
+
     override fun onCreate(savedInstanceState: Bundle?) {
         super.onCreate(savedInstanceState)
 
@@ -94,11 +109,22 @@
         setUpList()
 
         currentUserTracker.startTracking()
+
+        if (DEBUG) {
+            Log.d(TAG, "Registered onBackInvokedCallback")
+        }
+        onBackInvokedDispatcher.registerOnBackInvokedCallback(
+                OnBackInvokedDispatcher.PRIORITY_DEFAULT, onBackInvokedCallback)
     }
 
     override fun onStop() {
         super.onStop()
         currentUserTracker.stopTracking()
+
+        if (DEBUG) {
+            Log.d(TAG, "Unregistered onBackInvokedCallback")
+        }
+        onBackInvokedDispatcher.unregisterOnBackInvokedCallback(onBackInvokedCallback)
     }
 
     override fun onBackPressed() {
diff --git a/packages/SystemUI/src/com/android/systemui/controls/management/ControlsFavoritingActivity.kt b/packages/SystemUI/src/com/android/systemui/controls/management/ControlsFavoritingActivity.kt
index be572c5..9b2a728 100644
--- a/packages/SystemUI/src/com/android/systemui/controls/management/ControlsFavoritingActivity.kt
+++ b/packages/SystemUI/src/com/android/systemui/controls/management/ControlsFavoritingActivity.kt
@@ -24,6 +24,7 @@
 import android.content.res.Configuration
 import android.os.Bundle
 import android.text.TextUtils
+import android.util.Log
 import android.view.Gravity
 import android.view.View
 import android.view.ViewGroup
@@ -32,6 +33,8 @@
 import android.widget.FrameLayout
 import android.widget.TextView
 import android.widget.Toast
+import android.window.OnBackInvokedCallback
+import android.window.OnBackInvokedDispatcher
 import androidx.activity.ComponentActivity
 import androidx.viewpager2.widget.ViewPager2
 import com.android.systemui.Prefs
@@ -50,7 +53,7 @@
 import java.util.function.Consumer
 import javax.inject.Inject
 
-class ControlsFavoritingActivity @Inject constructor(
+open class ControlsFavoritingActivity @Inject constructor(
     @Main private val executor: Executor,
     private val controller: ControlsControllerImpl,
     private val listingController: ControlsListingController,
@@ -59,6 +62,7 @@
 ) : ComponentActivity() {
 
     companion object {
+        private const val DEBUG = false
         private const val TAG = "ControlsFavoritingActivity"
 
         // If provided and no structure is available, use as the title
@@ -67,7 +71,7 @@
         // If provided, show this structure page first
         const val EXTRA_STRUCTURE = "extra_structure"
         const val EXTRA_SINGLE_STRUCTURE = "extra_single_structure"
-        internal const val EXTRA_FROM_PROVIDER_SELECTOR = "extra_from_provider_selector"
+        const val EXTRA_FROM_PROVIDER_SELECTOR = "extra_from_provider_selector"
         private const val TOOLTIP_PREFS_KEY = Prefs.Key.CONTROLS_STRUCTURE_SWIPE_TOOLTIP_COUNT
         private const val TOOLTIP_MAX_SHOWN = 2
     }
@@ -102,6 +106,18 @@
         }
     }
 
+    /**
+     * Back callback registered with the platform [OnBackInvokedDispatcher] between [onStart]
+     * and [onStop]; it delegates to [onBackPressed] so back navigation keeps using the
+     * activity's existing back handling.
+     */
+    private val onBackInvokedCallback = OnBackInvokedCallback {
+        if (DEBUG) {
+            Log.d(TAG, "Predictive Back dispatcher called onBackInvokedCallback")
+        }
+        onBackPressed()
+    }
+
     private val listingCallback = object : ControlsListingController.ControlsListingCallback {
 
         override fun onServicesUpdated(serviceInfos: List<ControlsServiceInfo>) {
@@ -346,13 +362,19 @@
     override fun onPause() {
         super.onPause()
         mTooltipManager?.hide(false)
     }
 
     override fun onStart() {
         super.onStart()
 
         listingController.addCallback(listingCallback)
         currentUserTracker.startTracking()
+
+        if (DEBUG) {
+            Log.d(TAG, "Registered onBackInvokedCallback")
+        }
+        onBackInvokedDispatcher.registerOnBackInvokedCallback(
+                OnBackInvokedDispatcher.PRIORITY_DEFAULT, onBackInvokedCallback)
     }
 
     override fun onResume() {
@@ -365,13 +387,18 @@
             loadControls()
             isPagerLoaded = true
         }
     }
 
     override fun onStop() {
         super.onStop()
 
         listingController.removeCallback(listingCallback)
         currentUserTracker.stopTracking()
+
+        if (DEBUG) {
+            Log.d(TAG, "Unregistered onBackInvokedCallback")
+        }
+        onBackInvokedDispatcher.unregisterOnBackInvokedCallback(onBackInvokedCallback)
     }
 
     override fun onConfigurationChanged(newConfig: Configuration) {
diff --git a/packages/SystemUI/src/com/android/systemui/controls/management/ControlsProviderSelectorActivity.kt b/packages/SystemUI/src/com/android/systemui/controls/management/ControlsProviderSelectorActivity.kt
index b26615f..47690a7 100644
--- a/packages/SystemUI/src/com/android/systemui/controls/management/ControlsProviderSelectorActivity.kt
+++ b/packages/SystemUI/src/com/android/systemui/controls/management/ControlsProviderSelectorActivity.kt
@@ -20,16 +20,18 @@
 import android.content.ComponentName
 import android.content.Intent
 import android.os.Bundle
+import android.util.Log
 import android.view.LayoutInflater
 import android.view.View
 import android.view.ViewGroup
 import android.view.ViewStub
 import android.widget.Button
 import android.widget.TextView
+import android.window.OnBackInvokedCallback
+import android.window.OnBackInvokedDispatcher
 import androidx.activity.ComponentActivity
 import androidx.recyclerview.widget.LinearLayoutManager
 import androidx.recyclerview.widget.RecyclerView
-import androidx.recyclerview.widget.RecyclerView.AdapterDataObserver
 import com.android.systemui.R
 import com.android.systemui.broadcast.BroadcastDispatcher
 import com.android.systemui.controls.controller.ControlsController
@@ -44,7 +46,7 @@
 /**
  * Activity to select an application to favorite the [Control] provided by them.
  */
-class ControlsProviderSelectorActivity @Inject constructor(
+open class ControlsProviderSelectorActivity @Inject constructor(
     @Main private val executor: Executor,
     @Background private val backExecutor: Executor,
     private val listingController: ControlsListingController,
@@ -54,6 +56,7 @@
 ) : ComponentActivity() {
 
     companion object {
+        private const val DEBUG = false
         private const val TAG = "ControlsProviderSelectorActivity"
         const val BACK_SHOULD_EXIT = "back_should_exit"
     }
@@ -70,6 +73,17 @@
         }
     }
 
+    /**
+     * Back callback registered with the platform [OnBackInvokedDispatcher]; it delegates to
+     * [onBackPressed] so back navigation keeps using this activity's existing back handling.
+     */
+    private val onBackInvokedCallback = OnBackInvokedCallback {
+        if (DEBUG) {
+            Log.d(TAG, "Predictive Back dispatcher called onBackInvokedCallback")
+        }
+        onBackPressed()
+    }
+
     override fun onCreate(savedInstanceState: Bundle?) {
         super.onCreate(savedInstanceState)
 
@@ -141,11 +155,22 @@
                 }
             })
         }
+
+        if (DEBUG) {
+            Log.d(TAG, "Registered onBackInvokedCallback")
+        }
+        onBackInvokedDispatcher.registerOnBackInvokedCallback(
+                OnBackInvokedDispatcher.PRIORITY_DEFAULT, onBackInvokedCallback)
     }
 
     override fun onStop() {
         super.onStop()
         currentUserTracker.stopTracking()
+
+        if (DEBUG) {
+            Log.d(TAG, "Unregistered onBackInvokedCallback")
+        }
+        onBackInvokedDispatcher.unregisterOnBackInvokedCallback(onBackInvokedCallback)
     }
 
     /**
diff --git a/packages/SystemUI/tests/AndroidManifest.xml b/packages/SystemUI/tests/AndroidManifest.xml
index 1b404a8..8abdc87 100644
--- a/packages/SystemUI/tests/AndroidManifest.xml
+++ b/packages/SystemUI/tests/AndroidManifest.xml
@@ -93,6 +93,21 @@
             android:excludeFromRecents="true"
             />
 
+        <activity android:name="com.android.systemui.controls.management.ControlsEditingActivityTest$TestableControlsEditingActivity"
+            android:exported="false"
+            android:excludeFromRecents="true"
+            />
+
+        <activity android:name="com.android.systemui.controls.management.ControlsFavoritingActivityTest$TestableControlsFavoritingActivity"
+            android:exported="false"
+            android:excludeFromRecents="true"
+            />
+
+        <activity android:name="com.android.systemui.controls.management.ControlsProviderSelectorActivityTest$TestableControlsProviderSelectorActivity"
+            android:exported="false"
+            android:excludeFromRecents="true"
+            />
+
         <activity android:name="com.android.systemui.screenshot.ScrollViewActivity"
                   android:exported="false" />
 
diff --git a/packages/SystemUI/tests/src/com/android/systemui/controls/management/ControlsEditingActivityTest.kt b/packages/SystemUI/tests/src/com/android/systemui/controls/management/ControlsEditingActivityTest.kt
new file mode 100644
index 0000000..0b72a68
--- /dev/null
+++ b/packages/SystemUI/tests/src/com/android/systemui/controls/management/ControlsEditingActivityTest.kt
@@ -0,0 +1,146 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.systemui.controls.management
+
+import android.content.ComponentName
+import android.content.Intent
+import android.testing.AndroidTestingRunner
+import android.testing.TestableLooper
+import android.window.OnBackInvokedCallback
+import android.window.OnBackInvokedDispatcher
+import androidx.test.filters.SmallTest
+import androidx.test.rule.ActivityTestRule
+import androidx.test.runner.intercepting.SingleActivityFactory
+import com.android.systemui.SysuiTestCase
+import com.android.systemui.broadcast.BroadcastDispatcher
+import com.android.systemui.controls.CustomIconCache
+import com.android.systemui.controls.controller.ControlsControllerImpl
+import com.android.systemui.controls.ui.ControlsUiController
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.TimeUnit
+import org.junit.Assert.assertTrue
+import org.junit.Before
+import org.junit.Rule
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.ArgumentCaptor
+import org.mockito.ArgumentMatchers
+import org.mockito.Captor
+import org.mockito.Mock
+import org.mockito.Mockito.verify
+import org.mockito.MockitoAnnotations
+
+/**
+ * Verifies that [ControlsEditingActivity] registers a predictive-back callback when it is
+ * launched and unregisters that same callback once it has stopped.
+ */
+@SmallTest
+@RunWith(AndroidTestingRunner::class)
+@TestableLooper.RunWithLooper
+class ControlsEditingActivityTest : SysuiTestCase() {
+    @Mock lateinit var controller: ControlsControllerImpl
+
+    @Mock lateinit var broadcastDispatcher: BroadcastDispatcher
+
+    @Mock lateinit var customIconCache: CustomIconCache
+
+    @Mock lateinit var uiController: ControlsUiController
+
+    /** Counted down by the activity's onStop so the test thread can wait for the UI thread. */
+    private val latch = CountDownLatch(1)
+
+    @Mock private lateinit var mockDispatcher: OnBackInvokedDispatcher
+    @Captor private lateinit var captureCallback: ArgumentCaptor<OnBackInvokedCallback>
+
+    @Rule
+    @JvmField
+    var activityRule =
+        ActivityTestRule(
+            object :
+                SingleActivityFactory<TestableControlsEditingActivity>(
+                    TestableControlsEditingActivity::class.java
+                ) {
+                override fun create(intent: Intent?): TestableControlsEditingActivity {
+                    return TestableControlsEditingActivity(
+                        controller,
+                        broadcastDispatcher,
+                        customIconCache,
+                        uiController,
+                        mockDispatcher,
+                        latch
+                    )
+                }
+            },
+            /* initialTouchMode= */ false,
+            /* launchActivity= */ false
+        )
+
+    @Before
+    fun setUp() {
+        MockitoAnnotations.initMocks(this)
+        val intent = Intent()
+        intent.putExtra(ControlsEditingActivity.EXTRA_STRUCTURE, "TestTitle")
+        val cname = ComponentName("TestPackageName", "TestClassName")
+        intent.putExtra(Intent.EXTRA_COMPONENT_NAME, cname)
+        activityRule.launchActivity(intent)
+    }
+
+    @Test
+    fun testBackCallbackRegistrationAndUnregistration() {
+        // 1. ensure that launching the activity results in it registering a callback
+        verify(mockDispatcher)
+            .registerOnBackInvokedCallback(
+                ArgumentMatchers.eq(OnBackInvokedDispatcher.PRIORITY_DEFAULT),
+                captureCallback.capture()
+            )
+        activityRule.finishActivity()
+        // ensure activity is stopped; bounded so a missing onStop fails instead of hanging
+        assertTrue(
+            "Timed out waiting for onStop",
+            latch.await(STOP_TIMEOUT_SECONDS, TimeUnit.SECONDS)
+        )
+        // 2. ensure that when the activity is finished, it unregisters the same callback
+        verify(mockDispatcher).unregisterOnBackInvokedCallback(captureCallback.value)
+    }
+
+    /**
+     * Activity under test: returns the mock [OnBackInvokedDispatcher] instead of the real one
+     * and counts down the latch after [onStop] so the test can verify unregistration.
+     */
+    class TestableControlsEditingActivity(
+        controller: ControlsControllerImpl,
+        broadcastDispatcher: BroadcastDispatcher,
+        customIconCache: CustomIconCache,
+        uiController: ControlsUiController,
+        private val mockDispatcher: OnBackInvokedDispatcher,
+        private val latch: CountDownLatch
+    ) : ControlsEditingActivity(controller, broadcastDispatcher, customIconCache, uiController) {
+        override fun getOnBackInvokedDispatcher(): OnBackInvokedDispatcher {
+            return mockDispatcher
+        }
+
+        override fun onStop() {
+            super.onStop()
+            // ensures that test runner thread does not proceed until ui thread is done
+            latch.countDown()
+        }
+    }
+
+    private companion object {
+        const val STOP_TIMEOUT_SECONDS = 30L
+    }
+}
diff --git a/packages/SystemUI/tests/src/com/android/systemui/controls/management/ControlsFavoritingActivityTest.kt b/packages/SystemUI/tests/src/com/android/systemui/controls/management/ControlsFavoritingActivityTest.kt
new file mode 100644
index 0000000..4b0f7e6
--- /dev/null
+++ b/packages/SystemUI/tests/src/com/android/systemui/controls/management/ControlsFavoritingActivityTest.kt
@@ -0,0 +1,155 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.systemui.controls.management
+
+import android.content.Intent
+import android.testing.AndroidTestingRunner
+import android.testing.TestableLooper
+import android.window.OnBackInvokedCallback
+import android.window.OnBackInvokedDispatcher
+import androidx.test.filters.SmallTest
+import androidx.test.rule.ActivityTestRule
+import androidx.test.runner.intercepting.SingleActivityFactory
+import com.android.systemui.SysuiTestCase
+import com.android.systemui.broadcast.BroadcastDispatcher
+import com.android.systemui.controls.controller.ControlsControllerImpl
+import com.android.systemui.controls.ui.ControlsUiController
+import com.google.common.util.concurrent.MoreExecutors
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.Executor
+import java.util.concurrent.TimeUnit
+import org.junit.Assert.assertTrue
+import org.junit.Before
+import org.junit.Rule
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.ArgumentCaptor
+import org.mockito.ArgumentMatchers
+import org.mockito.Captor
+import org.mockito.Mock
+import org.mockito.Mockito.verify
+import org.mockito.MockitoAnnotations
+
+/**
+ * Verifies that [ControlsFavoritingActivity] registers a predictive-back callback when it is
+ * launched and unregisters that same callback once it has stopped.
+ */
+@SmallTest
+@RunWith(AndroidTestingRunner::class)
+@TestableLooper.RunWithLooper
+class ControlsFavoritingActivityTest : SysuiTestCase() {
+    private val executor: Executor = MoreExecutors.directExecutor()
+
+    @Mock lateinit var controller: ControlsControllerImpl
+
+    @Mock lateinit var listingController: ControlsListingController
+
+    @Mock lateinit var broadcastDispatcher: BroadcastDispatcher
+
+    @Mock lateinit var uiController: ControlsUiController
+
+    /** Counted down by the activity's onStop so the test thread can wait for the UI thread. */
+    private val latch = CountDownLatch(1)
+
+    @Mock private lateinit var mockDispatcher: OnBackInvokedDispatcher
+    @Captor private lateinit var captureCallback: ArgumentCaptor<OnBackInvokedCallback>
+
+    @Rule
+    @JvmField
+    var activityRule =
+        ActivityTestRule(
+            object :
+                SingleActivityFactory<TestableControlsFavoritingActivity>(
+                    TestableControlsFavoritingActivity::class.java
+                ) {
+                override fun create(intent: Intent?): TestableControlsFavoritingActivity {
+                    return TestableControlsFavoritingActivity(
+                        executor,
+                        controller,
+                        listingController,
+                        broadcastDispatcher,
+                        uiController,
+                        mockDispatcher,
+                        latch
+                    )
+                }
+            },
+            /* initialTouchMode= */ false,
+            /* launchActivity= */ false
+        )
+
+    @Before
+    fun setUp() {
+        MockitoAnnotations.initMocks(this)
+        val intent = Intent()
+        intent.putExtra(ControlsFavoritingActivity.EXTRA_FROM_PROVIDER_SELECTOR, true)
+        activityRule.launchActivity(intent)
+    }
+
+    @Test
+    fun testBackCallbackRegistrationAndUnregistration() {
+        // 1. ensure that launching the activity results in it registering a callback
+        verify(mockDispatcher)
+            .registerOnBackInvokedCallback(
+                ArgumentMatchers.eq(OnBackInvokedDispatcher.PRIORITY_DEFAULT),
+                captureCallback.capture()
+            )
+        activityRule.finishActivity()
+        // ensure activity is stopped; bounded so a missing onStop fails instead of hanging
+        assertTrue(
+            "Timed out waiting for onStop",
+            latch.await(STOP_TIMEOUT_SECONDS, TimeUnit.SECONDS)
+        )
+        // 2. ensure that when the activity is finished, it unregisters the same callback
+        verify(mockDispatcher).unregisterOnBackInvokedCallback(captureCallback.value)
+    }
+
+    /**
+     * Activity under test: returns the mock [OnBackInvokedDispatcher] instead of the real one
+     * and counts down the latch after [onStop] so the test can verify unregistration.
+     */
+    class TestableControlsFavoritingActivity(
+        executor: Executor,
+        controller: ControlsControllerImpl,
+        listingController: ControlsListingController,
+        broadcastDispatcher: BroadcastDispatcher,
+        uiController: ControlsUiController,
+        private val mockDispatcher: OnBackInvokedDispatcher,
+        private val latch: CountDownLatch
+    ) :
+        ControlsFavoritingActivity(
+            executor,
+            controller,
+            listingController,
+            broadcastDispatcher,
+            uiController
+        ) {
+        override fun getOnBackInvokedDispatcher(): OnBackInvokedDispatcher {
+            return mockDispatcher
+        }
+
+        override fun onStop() {
+            super.onStop()
+            // ensures that test runner thread does not proceed until ui thread is done
+            latch.countDown()
+        }
+    }
+
+    private companion object {
+        const val STOP_TIMEOUT_SECONDS = 30L
+    }
+}
diff --git a/packages/SystemUI/tests/src/com/android/systemui/controls/management/ControlsProviderSelectorActivityTest.kt b/packages/SystemUI/tests/src/com/android/systemui/controls/management/ControlsProviderSelectorActivityTest.kt
new file mode 100644
index 0000000..acc6222
--- /dev/null
+++ b/packages/SystemUI/tests/src/com/android/systemui/controls/management/ControlsProviderSelectorActivityTest.kt
@@ -0,0 +1,160 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.systemui.controls.management
+
+import android.content.Intent
+import android.testing.AndroidTestingRunner
+import android.testing.TestableLooper
+import android.window.OnBackInvokedCallback
+import android.window.OnBackInvokedDispatcher
+import androidx.test.filters.SmallTest
+import androidx.test.rule.ActivityTestRule
+import androidx.test.runner.intercepting.SingleActivityFactory
+import com.android.systemui.SysuiTestCase
+import com.android.systemui.broadcast.BroadcastDispatcher
+import com.android.systemui.controls.controller.ControlsController
+import com.android.systemui.controls.ui.ControlsUiController
+import com.google.common.util.concurrent.MoreExecutors
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.Executor
+import java.util.concurrent.TimeUnit
+import org.junit.Assert.assertTrue
+import org.junit.Before
+import org.junit.Rule
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.ArgumentCaptor
+import org.mockito.ArgumentMatchers
+import org.mockito.Captor
+import org.mockito.Mock
+import org.mockito.Mockito.verify
+import org.mockito.MockitoAnnotations
+
+/**
+ * Verifies that [ControlsProviderSelectorActivity] registers a predictive-back callback when it
+ * is launched and unregisters that same callback once it has stopped.
+ */
+@SmallTest
+@RunWith(AndroidTestingRunner::class)
+@TestableLooper.RunWithLooper
+class ControlsProviderSelectorActivityTest : SysuiTestCase() {
+    private val executor: Executor = MoreExecutors.directExecutor()
+
+    private val backExecutor: Executor = MoreExecutors.directExecutor()
+
+    @Mock lateinit var listingController: ControlsListingController
+
+    @Mock lateinit var controlsController: ControlsController
+
+    @Mock lateinit var broadcastDispatcher: BroadcastDispatcher
+
+    @Mock lateinit var uiController: ControlsUiController
+
+    /** Counted down by the activity's onStop so the test thread can wait for the UI thread. */
+    private val latch = CountDownLatch(1)
+
+    @Mock private lateinit var mockDispatcher: OnBackInvokedDispatcher
+    @Captor private lateinit var captureCallback: ArgumentCaptor<OnBackInvokedCallback>
+
+    @Rule
+    @JvmField
+    var activityRule =
+        ActivityTestRule(
+            object :
+                SingleActivityFactory<TestableControlsProviderSelectorActivity>(
+                    TestableControlsProviderSelectorActivity::class.java
+                ) {
+                override fun create(intent: Intent?): TestableControlsProviderSelectorActivity {
+                    return TestableControlsProviderSelectorActivity(
+                        executor,
+                        backExecutor,
+                        listingController,
+                        controlsController,
+                        broadcastDispatcher,
+                        uiController,
+                        mockDispatcher,
+                        latch
+                    )
+                }
+            },
+            /* initialTouchMode= */ false,
+            /* launchActivity= */ false
+        )
+
+    @Before
+    fun setUp() {
+        MockitoAnnotations.initMocks(this)
+        val intent = Intent()
+        intent.putExtra(ControlsProviderSelectorActivity.BACK_SHOULD_EXIT, true)
+        activityRule.launchActivity(intent)
+    }
+
+    @Test
+    fun testBackCallbackRegistrationAndUnregistration() {
+        // 1. ensure that launching the activity results in it registering a callback
+        verify(mockDispatcher)
+            .registerOnBackInvokedCallback(
+                ArgumentMatchers.eq(OnBackInvokedDispatcher.PRIORITY_DEFAULT),
+                captureCallback.capture()
+            )
+        activityRule.finishActivity()
+        // ensure activity is stopped; bounded so a missing onStop fails instead of hanging
+        assertTrue(
+            "Timed out waiting for onStop",
+            latch.await(STOP_TIMEOUT_SECONDS, TimeUnit.SECONDS)
+        )
+        // 2. ensure that when the activity is finished, it unregisters the same callback
+        verify(mockDispatcher).unregisterOnBackInvokedCallback(captureCallback.value)
+    }
+
+    /**
+     * Activity under test: returns the mock [OnBackInvokedDispatcher] instead of the real one
+     * and counts down the latch after [onStop] so the test can verify unregistration.
+     */
+    class TestableControlsProviderSelectorActivity(
+        executor: Executor,
+        backExecutor: Executor,
+        listingController: ControlsListingController,
+        controlsController: ControlsController,
+        broadcastDispatcher: BroadcastDispatcher,
+        uiController: ControlsUiController,
+        private val mockDispatcher: OnBackInvokedDispatcher,
+        private val latch: CountDownLatch
+    ) :
+        ControlsProviderSelectorActivity(
+            executor,
+            backExecutor,
+            listingController,
+            controlsController,
+            broadcastDispatcher,
+            uiController
+        ) {
+        override fun getOnBackInvokedDispatcher(): OnBackInvokedDispatcher {
+            return mockDispatcher
+        }
+
+        override fun onStop() {
+            super.onStop()
+            // ensures that test runner thread does not proceed until ui thread is done
+            latch.countDown()
+        }
+    }
+
+    private companion object {
+        const val STOP_TIMEOUT_SECONDS = 30L
+    }
+}
