Merge "3/3 Move some Shell utils to the Shared package." into main
diff --git a/core/api/system-current.txt b/core/api/system-current.txt
index 3637ca7..e7ed8fb 100644
--- a/core/api/system-current.txt
+++ b/core/api/system-current.txt
@@ -61,6 +61,7 @@
     field @FlaggedApi("com.android.internal.telephony.flags.use_oem_domain_selection_service") public static final String BIND_DOMAIN_SELECTION_SERVICE = "android.permission.BIND_DOMAIN_SELECTION_SERVICE";
     field public static final String BIND_DOMAIN_VERIFICATION_AGENT = "android.permission.BIND_DOMAIN_VERIFICATION_AGENT";
     field public static final String BIND_EUICC_SERVICE = "android.permission.BIND_EUICC_SERVICE";
+    field @FlaggedApi("android.crashrecovery.flags.enable_crashrecovery") public static final String BIND_EXPLICIT_HEALTH_CHECK_SERVICE = "android.permission.BIND_EXPLICIT_HEALTH_CHECK_SERVICE";
     field public static final String BIND_EXTERNAL_STORAGE_SERVICE = "android.permission.BIND_EXTERNAL_STORAGE_SERVICE";
     field public static final String BIND_FIELD_CLASSIFICATION_SERVICE = "android.permission.BIND_FIELD_CLASSIFICATION_SERVICE";
     field public static final String BIND_GBA_SERVICE = "android.permission.BIND_GBA_SERVICE";
@@ -4631,6 +4632,7 @@
     method public int getCommittedSessionId();
     method @NonNull public java.util.List<android.content.rollback.PackageRollbackInfo> getPackages();
     method public int getRollbackId();
+    method @FlaggedApi("android.crashrecovery.flags.enable_crashrecovery") public int getRollbackImpactLevel();
     method public boolean isStaged();
     method public void writeToParcel(android.os.Parcel, int);
     field @NonNull public static final android.os.Parcelable.Creator<android.content.rollback.RollbackInfo> CREATOR;
diff --git a/core/api/test-current.txt b/core/api/test-current.txt
index 009d082..6454e73 100644
--- a/core/api/test-current.txt
+++ b/core/api/test-current.txt
@@ -1269,10 +1269,6 @@
 
 package android.content.rollback {
 
-  public final class RollbackInfo implements android.os.Parcelable {
-    method @FlaggedApi("android.content.pm.recoverability_detection") public int getRollbackImpactLevel();
-  }
-
   public final class RollbackManager {
     method @RequiresPermission(android.Manifest.permission.TEST_MANAGE_ROLLBACKS) public void blockRollbackManager(long);
     method @RequiresPermission(android.Manifest.permission.TEST_MANAGE_ROLLBACKS) public void expireRollbackForPackage(@NonNull String);
diff --git a/core/java/android/content/rollback/RollbackInfo.java b/core/java/android/content/rollback/RollbackInfo.java
index d128055..a20159d 100644
--- a/core/java/android/content/rollback/RollbackInfo.java
+++ b/core/java/android/content/rollback/RollbackInfo.java
@@ -19,8 +19,6 @@
 import android.annotation.FlaggedApi;
 import android.annotation.NonNull;
 import android.annotation.SystemApi;
-import android.annotation.TestApi;
-import android.content.pm.Flags;
 import android.content.pm.PackageManager;
 import android.content.pm.VersionedPackage;
 import android.os.Parcel;
@@ -136,11 +134,8 @@
      * Get rollback impact level. Refer {@link
      * android.content.pm.PackageInstaller.SessionParams#setRollbackImpactLevel(int)} for more info
      * on impact level.
-     *
-     * @hide
      */
-    @TestApi
-    @FlaggedApi(Flags.FLAG_RECOVERABILITY_DETECTION)
+    @FlaggedApi(android.crashrecovery.flags.Flags.FLAG_ENABLE_CRASHRECOVERY)
     public @PackageManager.RollbackImpactLevel int getRollbackImpactLevel() {
         return mRollbackImpactLevel;
     }
diff --git a/core/java/com/android/internal/protolog/PerfettoProtoLogImpl.java b/core/java/com/android/internal/protolog/PerfettoProtoLogImpl.java
index 2ff8c8c..4264358 100644
--- a/core/java/com/android/internal/protolog/PerfettoProtoLogImpl.java
+++ b/core/java/com/android/internal/protolog/PerfettoProtoLogImpl.java
@@ -930,37 +930,53 @@
     }
 
+    /**
+     * A ProtoLog message, identified either by the hash of a message whose format string lives
+     * in the viewer config, or directly by its format string. Exactly one of the two is non-null.
+     */
     private static class Message {
+        @Nullable
         private final Long mMessageHash;
+        @Nullable
         private final Integer mMessageMask;
+        @Nullable
         private final String mMessageString;
 
-        private Message(Long messageHash, int messageMask) {
+        private Message(long messageHash, int messageMask) {
             this.mMessageHash = messageHash;
             this.mMessageMask = messageMask;
             this.mMessageString = null;
         }
 
-        private Message(String messageString) {
+        private Message(@NonNull String messageString) {
             this.mMessageHash = null;
             final List<Integer> argTypes = LogDataType.parseFormatString(messageString);
             this.mMessageMask = LogDataType.logDataTypesToBitMask(argTypes);
             this.mMessageString = messageString;
         }
 
-        private int getMessageMask() {
+        @Nullable
+        private Integer getMessageMask() {
             return mMessageMask;
         }
 
+        @Nullable
         private String getMessage() {
             return mMessageString;
         }
 
+        @Nullable
         private String getMessage(@NonNull ProtoLogViewerConfigReader viewerConfigReader) {
             if (mMessageString != null) {
                 return mMessageString;
             }
 
-            return viewerConfigReader.getViewerString(mMessageHash);
+            if (mMessageHash != null) {
+                return viewerConfigReader.getViewerString(mMessageHash);
+            }
+
+            // Unreachable: each constructor assigns exactly one of mMessageString / mMessageHash.
+            throw new IllegalStateException(
+                    "mMessageString and mMessageHash should never both be null");
         }
     }
 }
diff --git a/core/java/com/android/internal/protolog/ProtoLogConfigurationService.java b/core/java/com/android/internal/protolog/ProtoLogConfigurationService.java
index 1765738..eeac139 100644
--- a/core/java/com/android/internal/protolog/ProtoLogConfigurationService.java
+++ b/core/java/com/android/internal/protolog/ProtoLogConfigurationService.java
@@ -23,6 +23,7 @@
 import static android.internal.perfetto.protos.Protolog.ProtoLogViewerConfig.MESSAGES;
 import static android.internal.perfetto.protos.Protolog.ProtoLogViewerConfig.MessageData.GROUP_ID;
 import static android.internal.perfetto.protos.Protolog.ProtoLogViewerConfig.MessageData.LEVEL;
+import static android.internal.perfetto.protos.Protolog.ProtoLogViewerConfig.MessageData.LOCATION;
 import static android.internal.perfetto.protos.Protolog.ProtoLogViewerConfig.MessageData.MESSAGE;
 import static android.internal.perfetto.protos.Protolog.ProtoLogViewerConfig.MessageData.MESSAGE_ID;
 import static android.internal.perfetto.protos.TracePacketOuterClass.TracePacket.PROTOLOG_VIEWER_CONFIG;
@@ -210,8 +211,6 @@
          *                             want to write to the trace buffer.
-         * @throws FileNotFoundException if the viewerConfigFilePath is invalid.
          */
-        void trace(@NonNull ProtoLogDataSource dataSource, @NonNull String viewerConfigFilePath)
-                throws FileNotFoundException;
+        void trace(@NonNull ProtoLogDataSource dataSource, @NonNull String viewerConfigFilePath);
     }
 
     @Override
@@ -351,11 +351,7 @@
 
     private void onTracingInstanceFlush() {
         for (String fileName : mConfigFileCounts.keySet()) {
-            try {
-                mViewerConfigFileTracer.trace(mDataSource, fileName);
-            } catch (FileNotFoundException e) {
-                throw new RuntimeException(e);
-            }
+            mViewerConfigFileTracer.trace(mDataSource, fileName);
         }
     }
 
@@ -364,10 +360,16 @@
     }
 
     private static void dumpTransitionTraceConfig(@NonNull ProtoLogDataSource dataSource,
-            @NonNull String viewerConfigFilePath) throws FileNotFoundException {
-        final var pis = new ProtoInputStream(new FileInputStream(viewerConfigFilePath));
-
+            @NonNull String viewerConfigFilePath) {
         dataSource.trace(ctx -> {
+            final ProtoInputStream pis;
+            try {
+                pis = new ProtoInputStream(new FileInputStream(viewerConfigFilePath));
+            } catch (FileNotFoundException e) {
+                throw new RuntimeException(
+                        "Failed to load viewer config file " + viewerConfigFilePath, e);
+            }
+
             try {
                 final ProtoOutputStream os = ctx.newTracePacket();
 
@@ -396,11 +398,7 @@
             mConfigFileCounts.put(configFile, newCount);
             boolean lastProcessWithViewerConfig = newCount == 0;
             if (lastProcessWithViewerConfig) {
-                try {
-                    mViewerConfigFileTracer.trace(mDataSource, configFile);
-                } catch (FileNotFoundException e) {
-                    throw new RuntimeException(e);
-                }
+                mViewerConfigFileTracer.trace(mDataSource, configFile);
             }
         }
     }
@@ -446,6 +444,7 @@
                 case (int) MESSAGE -> os.write(MESSAGE, pis.readString(MESSAGE));
                 case (int) LEVEL -> os.write(LEVEL, pis.readInt(LEVEL));
                 case (int) GROUP_ID -> os.write(GROUP_ID, pis.readInt(GROUP_ID));
+                case (int) LOCATION -> os.write(LOCATION, pis.readString(LOCATION));
                 default ->
                     throw new RuntimeException(
                             "Unexpected field id " + pis.getFieldNumber());
diff --git a/core/java/com/android/internal/protolog/ProtoLogViewerConfigReader.java b/core/java/com/android/internal/protolog/ProtoLogViewerConfigReader.java
index 38ca0d8..3b24f27 100644
--- a/core/java/com/android/internal/protolog/ProtoLogViewerConfigReader.java
+++ b/core/java/com/android/internal/protolog/ProtoLogViewerConfigReader.java
@@ -3,7 +3,6 @@
 import static android.internal.perfetto.protos.Protolog.ProtoLogViewerConfig.GROUPS;
 import static android.internal.perfetto.protos.Protolog.ProtoLogViewerConfig.Group.ID;
 import static android.internal.perfetto.protos.Protolog.ProtoLogViewerConfig.Group.NAME;
-
 import static android.internal.perfetto.protos.Protolog.ProtoLogViewerConfig.MESSAGES;
 import static android.internal.perfetto.protos.Protolog.ProtoLogViewerConfig.MessageData.MESSAGE;
 import static android.internal.perfetto.protos.Protolog.ProtoLogViewerConfig.MessageData.MESSAGE_ID;
@@ -11,7 +10,6 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
-import android.util.Log;
 import android.util.LongSparseArray;
 import android.util.proto.ProtoInputStream;
 
@@ -38,6 +36,7 @@
      * Returns message format string for its hash or null if unavailable
      * or the viewer config is not loaded into memory.
      */
+    @Nullable
     public synchronized String getViewerString(long messageHash) {
         return mLogMessageMap.get(messageHash);
     }
diff --git a/core/res/AndroidManifest.xml b/core/res/AndroidManifest.xml
index 17ff2eb..5decf7f 100644
--- a/core/res/AndroidManifest.xml
+++ b/core/res/AndroidManifest.xml
@@ -7648,7 +7648,8 @@
     <permission android:name="android.permission.BIND_CARRIER_MESSAGING_CLIENT_SERVICE"
         android:protectionLevel="signature" />
 
-    <!-- Must be required by an {@link android.service.watchdog.ExplicitHealthCheckService} to
+    <!-- @FlaggedApi(android.crashrecovery.flags.Flags.FLAG_ENABLE_CRASHRECOVERY) @SystemApi
+         Must be required by an {@link android.service.watchdog.ExplicitHealthCheckService} to
          ensure that only the system can bind to it.
          @hide This is not a third-party API (intended for OEMs and system apps).
     -->
diff --git a/libs/input/MouseCursorController.cpp b/libs/input/MouseCursorController.cpp
index eecc741..1afef75 100644
--- a/libs/input/MouseCursorController.cpp
+++ b/libs/input/MouseCursorController.cpp
@@ -25,6 +25,9 @@
 #include <input/Input.h>
 #include <log/log.h>
 
+#define INDENT "  "
+#define INDENT2 "    "
+
 namespace {
 // Time to spend fading out the pointer completely.
 const nsecs_t POINTER_FADE_DURATION = 500 * 1000000LL; // 500 ms
@@ -449,6 +452,29 @@
     return mLocked.resourcesLoaded;
 }
 
+// Dumps the controller's locked state; the snapshot is taken while holding mLock.
+std::string MouseCursorController::dump() const {
+    std::string dump = INDENT "MouseCursorController:\n";
+    std::scoped_lock lock(mLock);
+    dump += StringPrintf(INDENT2 "viewport: %s\n", mLocked.viewport.toString().c_str());
+    dump += StringPrintf(INDENT2 "stylusHoverMode: %s\n",
+                         mLocked.stylusHoverMode ? "true" : "false");
+    dump += StringPrintf(INDENT2 "pointerFadeDirection: %d\n", mLocked.pointerFadeDirection);
+    dump += StringPrintf(INDENT2 "updatePointerIcon: %s\n",
+                         mLocked.updatePointerIcon ? "true" : "false");
+    dump += StringPrintf(INDENT2 "resourcesLoaded: %s\n",
+                         mLocked.resourcesLoaded ? "true" : "false");
+    // Cast explicitly: %d expects an int, and the pointer types may be scoped enums, which are
+    // not promoted when passed through varargs.
+    dump += StringPrintf(INDENT2 "requestedPointerType: %d\n",
+                         static_cast<int>(mLocked.requestedPointerType));
+    dump += StringPrintf(INDENT2 "resolvedPointerType: %d\n",
+                         static_cast<int>(mLocked.resolvedPointerType));
+    dump += StringPrintf(INDENT2 "skipScreenshot: %s\n", mLocked.skipScreenshot ? "true" : "false");
+    dump += StringPrintf(INDENT2 "animating: %s\n", mLocked.animating ? "true" : "false");
+    return dump;
+}
+
 bool MouseCursorController::doAnimations(nsecs_t timestamp) {
     std::scoped_lock lock(mLock);
     bool keepFading = doFadingAnimationLocked(timestamp);
diff --git a/libs/input/MouseCursorController.h b/libs/input/MouseCursorController.h
index 78f6413..8600341 100644
--- a/libs/input/MouseCursorController.h
+++ b/libs/input/MouseCursorController.h
@@ -67,6 +67,8 @@
 
     bool resourcesLoaded();
 
+    std::string dump() const;
+
 private:
     mutable std::mutex mLock;
 
diff --git a/libs/input/PointerController.cpp b/libs/input/PointerController.cpp
index 11b27a2..5ae967b 100644
--- a/libs/input/PointerController.cpp
+++ b/libs/input/PointerController.cpp
@@ -25,6 +25,7 @@
 #include <android-base/stringprintf.h>
 #include <android-base/thread_annotations.h>
 #include <ftl/enum.h>
+#include <input/PrintTools.h>
 
 #include <mutex>
 
@@ -353,6 +354,8 @@
     for (const auto& [_, spotController] : mLocked.spotControllers) {
         spotController.dump(dump, INDENT3);
     }
+    dump += INDENT2 "Cursor Controller:\n";
+    dump += addLinePrefix(mCursorController.dump(), INDENT3);
     return dump;
 }
 
diff --git a/packages/SystemUI/compose/facade/enabled/src/com/android/systemui/scene/QuickSettingsShadeOverlayModule.kt b/packages/SystemUI/compose/facade/enabled/src/com/android/systemui/scene/QuickSettingsShadeOverlayModule.kt
new file mode 100644
index 0000000..bc4adf9
--- /dev/null
+++ b/packages/SystemUI/compose/facade/enabled/src/com/android/systemui/scene/QuickSettingsShadeOverlayModule.kt
@@ -0,0 +1,29 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.systemui.scene
+
+import com.android.systemui.qs.ui.composable.QuickSettingsShadeOverlay
+import com.android.systemui.scene.ui.composable.Overlay
+import dagger.Binds
+import dagger.Module
+import dagger.multibindings.IntoSet
+
+@Module
+interface QuickSettingsShadeOverlayModule {
+
+    @Binds @IntoSet fun quickSettingsShade(overlay: QuickSettingsShadeOverlay): Overlay
+}
diff --git a/packages/SystemUI/compose/features/src/com/android/systemui/qs/ui/composable/QuickSettingsShadeOverlay.kt b/packages/SystemUI/compose/features/src/com/android/systemui/qs/ui/composable/QuickSettingsShadeOverlay.kt
new file mode 100644
index 0000000..fa37729
--- /dev/null
+++ b/packages/SystemUI/compose/features/src/com/android/systemui/qs/ui/composable/QuickSettingsShadeOverlay.kt
@@ -0,0 +1,85 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.systemui.qs.ui.composable
+
+import androidx.compose.foundation.layout.Column
+import androidx.compose.foundation.layout.padding
+import androidx.compose.runtime.Composable
+import androidx.compose.ui.Modifier
+import com.android.compose.animation.scene.ContentScope
+import com.android.systemui.battery.BatteryMeterViewController
+import com.android.systemui.dagger.SysUISingleton
+import com.android.systemui.lifecycle.rememberViewModel
+import com.android.systemui.qs.ui.viewmodel.QuickSettingsShadeOverlayActionsViewModel
+import com.android.systemui.qs.ui.viewmodel.QuickSettingsShadeOverlayContentViewModel
+import com.android.systemui.scene.shared.model.Overlays
+import com.android.systemui.scene.ui.composable.Overlay
+import com.android.systemui.shade.ui.composable.ExpandedShadeHeader
+import com.android.systemui.shade.ui.composable.OverlayShade
+import com.android.systemui.statusbar.phone.ui.StatusBarIconController
+import com.android.systemui.statusbar.phone.ui.TintedIconManager
+import java.util.Optional
+import javax.inject.Inject
+
+@SysUISingleton
+class QuickSettingsShadeOverlay
+@Inject
+constructor(
+    private val actionsViewModelFactory: QuickSettingsShadeOverlayActionsViewModel.Factory,
+    private val contentViewModelFactory: QuickSettingsShadeOverlayContentViewModel.Factory,
+    private val tintedIconManagerFactory: TintedIconManager.Factory,
+    private val batteryMeterViewControllerFactory: BatteryMeterViewController.Factory,
+    private val statusBarIconController: StatusBarIconController,
+) : Overlay {
+
+    override val key = Overlays.QuickSettingsShade
+
+    private val actionsViewModel: QuickSettingsShadeOverlayActionsViewModel by lazy {
+        actionsViewModelFactory.create()
+    }
+
+    override suspend fun activate(): Nothing {
+        actionsViewModel.activate()
+    }
+
+    @Composable
+    override fun ContentScope.Content(
+        modifier: Modifier,
+    ) {
+        val viewModel =
+            rememberViewModel("QuickSettingsShadeOverlay") { contentViewModelFactory.create() }
+        OverlayShade(
+            modifier = modifier,
+            viewModelFactory = viewModel.overlayShadeViewModelFactory,
+            lockscreenContent = { Optional.empty() },
+        ) {
+            Column {
+                ExpandedShadeHeader(
+                    viewModelFactory = viewModel.shadeHeaderViewModelFactory,
+                    createTintedIconManager = tintedIconManagerFactory::create,
+                    createBatteryMeterViewController = batteryMeterViewControllerFactory::create,
+                    statusBarIconController = statusBarIconController,
+                    modifier = Modifier.padding(QuickSettingsShade.Dimensions.Padding),
+                )
+
+                ShadeBody(
+                    viewModel = viewModel.quickSettingsContainerViewModel,
+                )
+            }
+        }
+    }
+}
diff --git a/packages/SystemUI/compose/features/src/com/android/systemui/shade/ui/composable/ShadeScene.kt b/packages/SystemUI/compose/features/src/com/android/systemui/shade/ui/composable/ShadeScene.kt
index b7c6edc..d8ab0a1 100644
--- a/packages/SystemUI/compose/features/src/com/android/systemui/shade/ui/composable/ShadeScene.kt
+++ b/packages/SystemUI/compose/features/src/com/android/systemui/shade/ui/composable/ShadeScene.kt
@@ -30,7 +30,7 @@
 import androidx.compose.foundation.layout.Row
 import androidx.compose.foundation.layout.WindowInsets
 import androidx.compose.foundation.layout.asPaddingValues
-import androidx.compose.foundation.layout.displayCutoutPadding
+import androidx.compose.foundation.layout.displayCutout
 import androidx.compose.foundation.layout.fillMaxHeight
 import androidx.compose.foundation.layout.fillMaxSize
 import androidx.compose.foundation.layout.fillMaxWidth
@@ -47,8 +47,9 @@
 import androidx.compose.runtime.DisposableEffect
 import androidx.compose.runtime.LaunchedEffect
 import androidx.compose.runtime.getValue
-import androidx.compose.runtime.mutableStateOf
+import androidx.compose.runtime.mutableIntStateOf
 import androidx.compose.runtime.remember
+import androidx.compose.runtime.setValue
 import androidx.compose.ui.Alignment
 import androidx.compose.ui.Modifier
 import androidx.compose.ui.graphics.CompositingStrategy
@@ -99,7 +100,6 @@
 import com.android.systemui.notifications.ui.composable.NotificationStackCutoffGuideline
 import com.android.systemui.qs.footer.ui.compose.FooterActionsWithAnimatedVisibility
 import com.android.systemui.qs.ui.composable.BrightnessMirror
-import com.android.systemui.qs.ui.composable.QSMediaMeasurePolicy
 import com.android.systemui.qs.ui.composable.QuickSettings
 import com.android.systemui.qs.ui.composable.QuickSettings.SharedValues.MediaLandscapeTopOffset
 import com.android.systemui.qs.ui.composable.QuickSettings.SharedValues.MediaOffset.InQQS
@@ -269,13 +269,14 @@
     shadeSession: SaveableSession,
 ) {
     val cutoutLocation = LocalDisplayCutout.current.location
+    val cutoutInsets = WindowInsets.Companion.displayCutout
     val isLandscape = LocalWindowSizeClass.current.heightSizeClass == WindowHeightSizeClass.Compact
     val usingCollapsedLandscapeMedia =
         Utils.useCollapsedMediaInLandscape(LocalContext.current.resources)
     val isExpanded = !usingCollapsedLandscapeMedia || !isLandscape
     mediaHost.expansion = if (isExpanded) EXPANDED else COLLAPSED
 
-    val maxNotifScrimTop = remember { mutableStateOf(0f) }
+    var maxNotifScrimTop by remember { mutableIntStateOf(0) }
     val tileSquishiness by
         animateSceneFloatAsState(
             value = 1f,
@@ -301,6 +302,26 @@
             viewModel.qsSceneAdapter,
         )
     }
+    // cutoutLocation and cutoutInsets are plain (non-State) values captured by the lambdas below,
+    // so they are remember keys; otherwise the policy would keep their first-composition values.
+    val shadeMeasurePolicy =
+        remember(mediaInRow, cutoutLocation, cutoutInsets) {
+            SingleShadeMeasurePolicy(
+                isMediaInRow = mediaInRow,
+                mediaOffset = { mediaOffset.roundToPx() },
+                onNotificationsTopChanged = { maxNotifScrimTop = it },
+                mediaZIndex = {
+                    if (MediaContentPicker.shouldElevateMedia(layoutState)) 1f else 0f
+                },
+                cutoutInsetsProvider = {
+                    if (cutoutLocation == CutoutLocation.CENTER) {
+                        null
+                    } else {
+                        cutoutInsets
+                    }
+                },
+            )
+        }
 
     Box(
         modifier =
@@ -318,101 +337,54 @@
                     .background(colorResource(R.color.shade_scrim_background_dark)),
         )
         Layout(
-            contents =
-                listOf(
-                    {
-                        Column(
-                            horizontalAlignment = Alignment.CenterHorizontally,
-                            modifier =
-                                Modifier.fillMaxWidth()
-                                    .thenIf(isEmptySpaceClickable) {
-                                        Modifier.clickable(
-                                            onClick = { viewModel.onEmptySpaceClicked() }
-                                        )
-                                    }
-                                    .thenIf(cutoutLocation != CutoutLocation.CENTER) {
-                                        Modifier.displayCutoutPadding()
-                                    },
-                        ) {
-                            CollapsedShadeHeader(
-                                viewModelFactory = viewModel.shadeHeaderViewModelFactory,
-                                createTintedIconManager = createTintedIconManager,
-                                createBatteryMeterViewController = createBatteryMeterViewController,
-                                statusBarIconController = statusBarIconController,
-                            )
-
-                            val content: @Composable () -> Unit = {
-                                Box(
-                                    Modifier.element(QuickSettings.Elements.QuickQuickSettings)
-                                        .layoutId(QSMediaMeasurePolicy.LayoutId.QS)
-                                ) {
-                                    QuickSettings(
-                                        viewModel.qsSceneAdapter,
-                                        { viewModel.qsSceneAdapter.qqsHeight },
-                                        isSplitShade = false,
-                                        squishiness = { tileSquishiness },
-                                    )
-                                }
-
-                                ShadeMediaCarousel(
-                                    isVisible = isMediaVisible,
-                                    mediaHost = mediaHost,
-                                    mediaOffsetProvider = mediaOffsetProvider,
-                                    modifier =
-                                        Modifier.layoutId(QSMediaMeasurePolicy.LayoutId.Media),
-                                    carouselController = mediaCarouselController,
-                                )
-                            }
-                            val landscapeQsMediaMeasurePolicy = remember {
-                                QSMediaMeasurePolicy(
-                                    { viewModel.qsSceneAdapter.qqsHeight },
-                                    { mediaOffset.roundToPx() },
-                                )
-                            }
-                            if (mediaInRow) {
-                                Layout(
-                                    content = content,
-                                    measurePolicy = landscapeQsMediaMeasurePolicy,
-                                )
-                            } else {
-                                content()
-                            }
-                        }
-                    },
-                    {
-                        NotificationScrollingStack(
-                            shadeSession = shadeSession,
-                            stackScrollView = notificationStackScrollView,
-                            viewModel = notificationsPlaceholderViewModel,
-                            maxScrimTop = { maxNotifScrimTop.value },
-                            shadeMode = ShadeMode.Single,
-                            shouldPunchHoleBehindScrim = shouldPunchHoleBehindScrim,
-                            onEmptySpaceClick =
-                                viewModel::onEmptySpaceClicked.takeIf { isEmptySpaceClickable },
-                        )
-                    },
+            modifier =
+                Modifier.thenIf(isEmptySpaceClickable) {
+                    Modifier.clickable { viewModel.onEmptySpaceClicked() }
+                },
+            content = {
+                CollapsedShadeHeader(
+                    viewModelFactory = viewModel.shadeHeaderViewModelFactory,
+                    createTintedIconManager = createTintedIconManager,
+                    createBatteryMeterViewController = createBatteryMeterViewController,
+                    statusBarIconController = statusBarIconController,
+                    modifier = Modifier.layoutId(SingleShadeMeasurePolicy.LayoutId.ShadeHeader),
                 )
-        ) { measurables, constraints ->
-            check(measurables.size == 2)
-            check(measurables[0].size == 1)
-            check(measurables[1].size == 1)
 
-            val quickSettingsPlaceable = measurables[0][0].measure(constraints)
-            val notificationsPlaceable = measurables[1][0].measure(constraints)
+                Box(
+                    Modifier.element(QuickSettings.Elements.QuickQuickSettings)
+                        .layoutId(SingleShadeMeasurePolicy.LayoutId.QuickSettings)
+                ) {
+                    QuickSettings(
+                        viewModel.qsSceneAdapter,
+                        { viewModel.qsSceneAdapter.qqsHeight },
+                        isSplitShade = false,
+                        squishiness = { tileSquishiness },
+                    )
+                }
 
-            maxNotifScrimTop.value = quickSettingsPlaceable.height.toFloat()
+                ShadeMediaCarousel(
+                    isVisible = isMediaVisible,
+                    isInRow = mediaInRow,
+                    mediaHost = mediaHost,
+                    mediaOffsetProvider = mediaOffsetProvider,
+                    carouselController = mediaCarouselController,
+                    modifier = Modifier.layoutId(SingleShadeMeasurePolicy.LayoutId.Media),
+                )
 
-            layout(constraints.maxWidth, constraints.maxHeight) {
-                val qsZIndex =
-                    if (MediaContentPicker.shouldElevateMedia(layoutState)) {
-                        1f
-                    } else {
-                        0f
-                    }
-                quickSettingsPlaceable.placeRelative(x = 0, y = 0, zIndex = qsZIndex)
-                notificationsPlaceable.placeRelative(x = 0, y = maxNotifScrimTop.value.roundToInt())
-            }
-        }
+                NotificationScrollingStack(
+                    shadeSession = shadeSession,
+                    stackScrollView = notificationStackScrollView,
+                    viewModel = notificationsPlaceholderViewModel,
+                    maxScrimTop = { maxNotifScrimTop.toFloat() },
+                    shadeMode = ShadeMode.Single,
+                    shouldPunchHoleBehindScrim = shouldPunchHoleBehindScrim,
+                    onEmptySpaceClick =
+                        viewModel::onEmptySpaceClicked.takeIf { isEmptySpaceClickable },
+                    modifier = Modifier.layoutId(SingleShadeMeasurePolicy.LayoutId.Notifications),
+                )
+            },
+            measurePolicy = shadeMeasurePolicy,
+        )
         Box(
             modifier =
                 Modifier.align(Alignment.BottomCenter)
@@ -600,6 +572,7 @@
 
                             ShadeMediaCarousel(
                                 isVisible = isMediaVisible,
+                                isInRow = false,
                                 mediaHost = mediaHost,
                                 mediaOffsetProvider = mediaOffsetProvider,
                                 modifier =
@@ -657,6 +630,7 @@
 @Composable
 private fun SceneScope.ShadeMediaCarousel(
     isVisible: Boolean,
+    isInRow: Boolean,
     mediaHost: MediaHost,
     carouselController: MediaCarouselController,
     mediaOffsetProvider: ShadeMediaOffsetProvider,
@@ -668,7 +642,7 @@
         mediaHost = mediaHost,
         carouselController = carouselController,
         offsetProvider =
-            if (MediaContentPicker.shouldElevateMedia(layoutState)) {
+            if (isInRow || MediaContentPicker.shouldElevateMedia(layoutState)) {
                 null
             } else {
                 { mediaOffsetProvider.offset }
diff --git a/packages/SystemUI/compose/features/src/com/android/systemui/shade/ui/composable/SingleShadeMeasurePolicy.kt b/packages/SystemUI/compose/features/src/com/android/systemui/shade/ui/composable/SingleShadeMeasurePolicy.kt
new file mode 100644
index 0000000..6275ac3
--- /dev/null
+++ b/packages/SystemUI/compose/features/src/com/android/systemui/shade/ui/composable/SingleShadeMeasurePolicy.kt
@@ -0,0 +1,155 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.systemui.shade.ui.composable
+
+import androidx.compose.foundation.layout.WindowInsets
+import androidx.compose.ui.layout.Measurable
+import androidx.compose.ui.layout.MeasurePolicy
+import androidx.compose.ui.layout.MeasureResult
+import androidx.compose.ui.layout.MeasureScope
+import androidx.compose.ui.layout.Placeable
+import androidx.compose.ui.layout.layoutId
+import androidx.compose.ui.unit.Constraints
+import androidx.compose.ui.unit.offset
+import androidx.compose.ui.util.fastFirst
+import androidx.compose.ui.util.fastFirstOrNull
+import com.android.systemui.shade.ui.composable.SingleShadeMeasurePolicy.LayoutId
+import kotlin.math.max
+
+/**
+ * Lays out elements from the [LayoutId] in the shade. This policy supports the case when the QS and
+ * UMO share the same row and when they should be one below another.
+ */
+class SingleShadeMeasurePolicy(
+    private val isMediaInRow: Boolean,
+    private val mediaOffset: MeasureScope.() -> Int,
+    private val onNotificationsTopChanged: (Int) -> Unit,
+    private val mediaZIndex: () -> Float,
+    private val cutoutInsetsProvider: () -> WindowInsets?,
+) : MeasurePolicy {
+
+    enum class LayoutId {
+        QuickSettings,
+        Media,
+        Notifications,
+        ShadeHeader,
+    }
+
+    override fun MeasureScope.measure(
+        measurables: List<Measurable>,
+        constraints: Constraints,
+    ): MeasureResult {
+        val cutoutInsets: WindowInsets? = cutoutInsetsProvider()
+        val constraintsWithCutout = applyCutout(constraints, cutoutInsets)
+        val insetsLeft = cutoutInsets?.getLeft(this, layoutDirection) ?: 0
+        val insetsTop = cutoutInsets?.getTop(this) ?: 0
+
+        val shadeHeaderPlaceable =
+            measurables
+                .fastFirst { it.layoutId == LayoutId.ShadeHeader }
+                .measure(constraintsWithCutout)
+        val mediaPlaceable =
+            measurables
+                .fastFirstOrNull { it.layoutId == LayoutId.Media }
+                ?.measure(applyMediaConstraints(constraintsWithCutout, isMediaInRow))
+        val quickSettingsPlaceable =
+            measurables
+                .fastFirst { it.layoutId == LayoutId.QuickSettings }
+                .measure(constraintsWithCutout)
+        val notificationsPlaceable =
+            measurables.fastFirst { it.layoutId == LayoutId.Notifications }.measure(constraints)
+
+        val notificationsTop =
+            calculateNotificationsTop(
+                statusBarHeaderPlaceable = shadeHeaderPlaceable,
+                quickSettingsPlaceable = quickSettingsPlaceable,
+                mediaPlaceable = mediaPlaceable,
+                insetsTop = insetsTop,
+                isMediaInRow = isMediaInRow,
+            )
+        onNotificationsTopChanged(notificationsTop)
+
+        return layout(constraints.maxWidth, constraints.maxHeight) {
+            shadeHeaderPlaceable.placeRelative(x = insetsLeft, y = insetsTop)
+            quickSettingsPlaceable.placeRelative(
+                x = insetsLeft,
+                y = insetsTop + shadeHeaderPlaceable.height,
+            )
+
+            if (isMediaInRow) {
+                mediaPlaceable?.placeRelative(
+                    x = insetsLeft + constraintsWithCutout.maxWidth / 2,
+                    y = mediaOffset() + insetsTop + shadeHeaderPlaceable.height,
+                    zIndex = mediaZIndex(),
+                )
+            } else {
+                mediaPlaceable?.placeRelative(
+                    x = insetsLeft,
+                    y = insetsTop + shadeHeaderPlaceable.height + quickSettingsPlaceable.height,
+                    zIndex = mediaZIndex(),
+                )
+            }
+
+            // Notifications don't need to accommodate for horizontal insets
+            notificationsPlaceable.placeRelative(x = 0, y = notificationsTop)
+        }
+    }
+
+    private fun calculateNotificationsTop(
+        statusBarHeaderPlaceable: Placeable,
+        quickSettingsPlaceable: Placeable,
+        mediaPlaceable: Placeable?,
+        insetsTop: Int,
+        isMediaInRow: Boolean,
+    ): Int {
+        val mediaHeight = mediaPlaceable?.height ?: 0
+        return insetsTop +
+            statusBarHeaderPlaceable.height +
+            if (isMediaInRow) {
+                max(quickSettingsPlaceable.height, mediaHeight)
+            } else {
+                quickSettingsPlaceable.height + mediaHeight
+            }
+    }
+
+    private fun applyMediaConstraints(
+        constraints: Constraints,
+        isMediaInRow: Boolean,
+    ): Constraints {
+        return if (isMediaInRow) {
+            constraints.copy(maxWidth = constraints.maxWidth / 2)
+        } else {
+            constraints
+        }
+    }
+
+    private fun MeasureScope.applyCutout(
+        constraints: Constraints,
+        cutoutInsets: WindowInsets?,
+    ): Constraints {
+        return if (cutoutInsets == null) {
+            constraints
+        } else {
+            val left = cutoutInsets.getLeft(this, layoutDirection)
+            val top = cutoutInsets.getTop(this)
+            val right = cutoutInsets.getRight(this, layoutDirection)
+            val bottom = cutoutInsets.getBottom(this)
+
+            constraints.offset(horizontal = -(left + right), vertical = -(top + bottom))
+        }
+    }
+}
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateContent.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateContent.kt
index b166737..d876606 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateContent.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateContent.kt
@@ -21,9 +21,7 @@
 import androidx.compose.animation.core.SpringSpec
 import com.android.compose.animation.scene.content.state.TransitionState
 import kotlinx.coroutines.CoroutineScope
-import kotlinx.coroutines.CoroutineStart
 import kotlinx.coroutines.Job
-import kotlinx.coroutines.launch
 
 internal fun CoroutineScope.animateContent(
     layoutState: MutableSceneTransitionLayoutStateImpl,
@@ -31,37 +29,24 @@
     oneOffAnimation: OneOffAnimation,
     targetProgress: Float,
     chain: Boolean = true,
-) {
-    // Start the transition. This will compute the TransformationSpec associated to [transition],
-    // which we need to initialize the Animatable that will actually animate it.
-    layoutState.startTransition(transition, chain)
-
-    // The transition now contains the transformation spec that we should use to instantiate the
-    // Animatable.
-    val animationSpec = transition.transformationSpec.progressSpec
-    val visibilityThreshold =
-        (animationSpec as? SpringSpec)?.visibilityThreshold ?: ProgressVisibilityThreshold
-    val replacedTransition = transition.replacedTransition
-    val initialProgress = replacedTransition?.progress ?: 0f
-    val initialVelocity = replacedTransition?.progressVelocity ?: 0f
-    val animatable =
-        Animatable(initialProgress, visibilityThreshold = visibilityThreshold).also {
-            oneOffAnimation.animatable = it
-        }
-
-    // Animate the progress to its target value.
-    //
-    // Important: We start atomically to make sure that we start the coroutine even if it is
-    // cancelled right after it is launched, so that finishTransition() is correctly called.
-    // Otherwise, this transition will never be stopped and we will never settle to Idle.
-    oneOffAnimation.job =
-        launch(start = CoroutineStart.ATOMIC) {
-            try {
-                animatable.animateTo(targetProgress, animationSpec, initialVelocity)
-            } finally {
-                layoutState.finishTransition(transition)
+): Job {
+    oneOffAnimation.onRun = {
+        // Animate the progress to its target value.
+        val animationSpec = transition.transformationSpec.progressSpec
+        val visibilityThreshold =
+            (animationSpec as? SpringSpec)?.visibilityThreshold ?: ProgressVisibilityThreshold
+        val replacedTransition = transition.replacedTransition
+        val initialProgress = replacedTransition?.progress ?: 0f
+        val initialVelocity = replacedTransition?.progressVelocity ?: 0f
+        val animatable =
+            Animatable(initialProgress, visibilityThreshold = visibilityThreshold).also {
+                oneOffAnimation.animatable = it
             }
-        }
+
+        animatable.animateTo(targetProgress, animationSpec, initialVelocity)
+    }
+
+    return layoutState.startTransitionImmediately(animationScope = this, transition, chain)
 }
 
 internal class OneOffAnimation {
@@ -74,8 +59,8 @@
      */
     lateinit var animatable: Animatable<Float, AnimationVector1D>
 
-    /** The job that is animating [animatable]. */
-    lateinit var job: Job
+    /** The runnable to run for this animation. */
+    lateinit var onRun: suspend () -> Unit
 
     val progress: Float
         get() = animatable.value
@@ -83,7 +68,13 @@
     val progressVelocity: Float
         get() = animatable.velocity
 
-    fun finish(): Job = job
+    suspend fun run() {
+        onRun()
+    }
+
+    fun freezeAndAnimateToCurrentState() {
+        // Do nothing, the state of one-off animations never change and we directly animate to it.
+    }
 }
 
 // TODO(b/290184746): Compute a good default visibility threshold that depends on the layout size
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateOverlay.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateOverlay.kt
index e020f14..28116cb 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateOverlay.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateOverlay.kt
@@ -18,7 +18,6 @@
 
 import com.android.compose.animation.scene.content.state.TransitionState
 import kotlinx.coroutines.CoroutineScope
-import kotlinx.coroutines.Job
 
 /** Trigger a one-off transition to show or hide an overlay. */
 internal fun CoroutineScope.showOrHideOverlay(
@@ -120,7 +119,13 @@
     override val isInitiatedByUserInput: Boolean = false
     override val isUserInputOngoing: Boolean = false
 
-    override fun finish(): Job = oneOffAnimation.finish()
+    override suspend fun run() {
+        oneOffAnimation.run()
+    }
+
+    override fun freezeAndAnimateToCurrentState() {
+        oneOffAnimation.freezeAndAnimateToCurrentState()
+    }
 }
 
 private class OneOffOverlayReplacingTransition(
@@ -140,5 +145,11 @@
     override val isInitiatedByUserInput: Boolean = false
     override val isUserInputOngoing: Boolean = false
 
-    override fun finish(): Job = oneOffAnimation.finish()
+    override suspend fun run() {
+        oneOffAnimation.run()
+    }
+
+    override fun freezeAndAnimateToCurrentState() {
+        oneOffAnimation.freezeAndAnimateToCurrentState()
+    }
 }
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt
index e15bc12..86be4a4 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/AnimateToScene.kt
@@ -28,7 +28,7 @@
     layoutState: MutableSceneTransitionLayoutStateImpl,
     target: SceneKey,
     transitionKey: TransitionKey?,
-): TransitionState.Transition.ChangeScene? {
+): Pair<TransitionState.Transition.ChangeScene, Job>? {
     val transitionState = layoutState.transitionState
     if (transitionState.currentScene == target) {
         // This can happen in 3 different situations, for which there isn't anything else to do:
@@ -139,7 +139,7 @@
     reversed: Boolean = false,
     fromScene: SceneKey = layoutState.transitionState.currentScene,
     chain: Boolean = true,
-): TransitionState.Transition.ChangeScene {
+): Pair<TransitionState.Transition.ChangeScene, Job> {
     val oneOffAnimation = OneOffAnimation()
     val targetProgress = if (reversed) 0f else 1f
     val transition =
@@ -165,15 +165,16 @@
             )
         }
 
-    animateContent(
-        layoutState = layoutState,
-        transition = transition,
-        oneOffAnimation = oneOffAnimation,
-        targetProgress = targetProgress,
-        chain = chain,
-    )
+    val job =
+        animateContent(
+            layoutState = layoutState,
+            transition = transition,
+            oneOffAnimation = oneOffAnimation,
+            targetProgress = targetProgress,
+            chain = chain,
+        )
 
-    return transition
+    return transition to job
 }
 
 private class OneOffSceneTransition(
@@ -193,5 +194,11 @@
 
     override val isUserInputOngoing: Boolean = false
 
-    override fun finish(): Job = oneOffAnimation.finish()
+    override suspend fun run() {
+        oneOffAnimation.run()
+    }
+
+    override fun freezeAndAnimateToCurrentState() {
+        oneOffAnimation.freezeAndAnimateToCurrentState()
+    }
 }
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/DraggableHandler.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/DraggableHandler.kt
index 37e4daa..24fef71 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/DraggableHandler.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/DraggableHandler.kt
@@ -28,7 +28,6 @@
 import com.android.compose.animation.scene.content.state.TransitionState.HasOverscrollProperties.Companion.DistanceUnspecified
 import com.android.compose.nestedscroll.PriorityNestedScrollConnection
 import kotlin.math.absoluteValue
-import kotlinx.coroutines.CoroutineScope
 
 internal interface DraggableHandler {
     /**
@@ -63,7 +62,6 @@
 internal class DraggableHandlerImpl(
     internal val layoutImpl: SceneTransitionLayoutImpl,
     internal val orientation: Orientation,
-    internal val coroutineScope: CoroutineScope,
 ) : DraggableHandler {
     internal val nestedScrollKey = Any()
     /** The [DraggableHandler] can only have one active [DragController] at a time. */
@@ -101,11 +99,6 @@
 
         val swipeAnimation = dragController.swipeAnimation
 
-        // Don't intercept a transition that is finishing.
-        if (swipeAnimation.isFinishing) {
-            return false
-        }
-
         // Only intercept the current transition if one of the 2 swipes results is also a transition
         // between the same pair of contents.
         val swipes = computeSwipes(startedPosition, pointersDown = 1)
@@ -140,7 +133,6 @@
             // This [transition] was already driving the animation: simply take over it.
             // Stop animating and start from the current offset.
             val oldSwipeAnimation = oldDragController.swipeAnimation
-            oldSwipeAnimation.cancelOffsetAnimation()
 
             // We need to recompute the swipe results since this is a new gesture, and the
             // fromScene.userActions may have changed.
@@ -192,13 +184,7 @@
                 else -> error("Unknown result $result ($upOrLeftResult $downOrRightResult)")
             }
 
-        return createSwipeAnimation(
-            layoutImpl,
-            layoutImpl.coroutineScope,
-            result,
-            isUpOrLeft,
-            orientation
-        )
+        return createSwipeAnimation(layoutImpl, result, isUpOrLeft, orientation)
     }
 
     private fun computeSwipes(startedPosition: Offset?, pointersDown: Int): Swipes {
@@ -279,16 +265,14 @@
 
     fun updateTransition(newTransition: SwipeAnimation<*>, force: Boolean = false) {
         if (force || isDrivingTransition) {
-            layoutState.startTransition(newTransition.contentTransition)
+            layoutState.startTransitionImmediately(
+                animationScope = draggableHandler.layoutImpl.animationScope,
+                newTransition.contentTransition,
+                true
+            )
         }
 
-        val previous = swipeAnimation
         swipeAnimation = newTransition
-
-        // Finish the previous transition.
-        if (previous != newTransition) {
-            layoutState.finishTransition(previous.contentTransition)
-        }
     }
 
     /**
@@ -302,7 +286,7 @@
     }
 
     private fun <T : ContentKey> onDrag(delta: Float, swipeAnimation: SwipeAnimation<T>): Float {
-        if (delta == 0f || !isDrivingTransition || swipeAnimation.isFinishing) {
+        if (delta == 0f || !isDrivingTransition || swipeAnimation.isAnimatingOffset()) {
             return 0f
         }
 
@@ -409,7 +393,7 @@
         swipeAnimation: SwipeAnimation<T>,
     ): Float {
         // The state was changed since the drag started; don't do anything.
-        if (!isDrivingTransition || swipeAnimation.isFinishing) {
+        if (!isDrivingTransition || swipeAnimation.isAnimatingOffset()) {
             return 0f
         }
 
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/MultiPointerDraggable.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/MultiPointerDraggable.kt
index fd4c310..5780c08 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/MultiPointerDraggable.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/MultiPointerDraggable.kt
@@ -56,7 +56,7 @@
 import com.android.compose.ui.util.SpaceVectorConverter
 import kotlin.coroutines.cancellation.CancellationException
 import kotlin.math.sign
-import kotlinx.coroutines.coroutineScope
+import kotlinx.coroutines.currentCoroutineContext
 import kotlinx.coroutines.isActive
 import kotlinx.coroutines.launch
 
@@ -143,8 +143,8 @@
     CompositionLocalConsumerModifierNode,
     ObserverModifierNode,
     SpaceVectorConverter {
-    private val pointerInputHandler: suspend PointerInputScope.() -> Unit = { pointerInput() }
-    private val delegate = delegate(SuspendingPointerInputModifierNode(pointerInputHandler))
+    private val pointerTracker = delegate(SuspendingPointerInputModifierNode { pointerTracker() })
+    private val pointerInput = delegate(SuspendingPointerInputModifierNode { pointerInput() })
     private val velocityTracker = VelocityTracker()
     private var previousEnabled: Boolean = false
 
@@ -153,7 +153,7 @@
             // Reset the pointer input whenever enabled changed.
             if (value != field) {
                 field = value
-                delegate.resetPointerInputHandler()
+                pointerInput.resetPointerInputHandler()
             }
         }
 
@@ -173,7 +173,7 @@
             if (value != field) {
                 field = value
                 converter = SpaceVectorConverter(value)
-                delegate.resetPointerInputHandler()
+                pointerInput.resetPointerInputHandler()
             }
         }
 
@@ -186,19 +186,26 @@
         observeReads {
             val newEnabled = enabled()
             if (newEnabled != previousEnabled) {
-                delegate.resetPointerInputHandler()
+                pointerInput.resetPointerInputHandler()
             }
             previousEnabled = newEnabled
         }
     }
 
-    override fun onCancelPointerInput() = delegate.onCancelPointerInput()
+    override fun onCancelPointerInput() {
+        pointerTracker.onCancelPointerInput()
+        pointerInput.onCancelPointerInput()
+    }
 
     override fun onPointerEvent(
         pointerEvent: PointerEvent,
         pass: PointerEventPass,
         bounds: IntSize
-    ) = delegate.onPointerEvent(pointerEvent, pass, bounds)
+    ) {
+        // The order is important here: the tracker is always called first.
+        pointerTracker.onPointerEvent(pointerEvent, pass, bounds)
+        pointerInput.onPointerEvent(pointerEvent, pass, bounds)
+    }
 
     private var startedPosition: Offset? = null
     private var pointersDown: Int = 0
@@ -211,81 +218,77 @@
         )
     }
 
+    private suspend fun PointerInputScope.pointerTracker() {
+        val currentContext = currentCoroutineContext()
+        awaitPointerEventScope {
+            // Intercepts pointer inputs and exposes [PointersInfo], via
+            // [requireAncestorPointersInfoOwner], to our descendants.
+            while (currentContext.isActive) {
+                // During the Initial pass, we receive the event after our ancestors.
+                val pointers = awaitPointerEvent(PointerEventPass.Initial).changes
+                pointersDown = pointers.countDown()
+                if (pointersDown == 0) {
+                    // There are no more pointers down
+                    startedPosition = null
+                } else if (startedPosition == null) {
+                    startedPosition = pointers.first().position
+                    if (enabled()) {
+                        onFirstPointerDown()
+                    }
+                }
+            }
+        }
+    }
+
     private suspend fun PointerInputScope.pointerInput() {
         if (!enabled()) {
             return
         }
 
-        coroutineScope {
-            launch {
-                // Intercepts pointer inputs and exposes [PointersInfo], via
-                // [requireAncestorPointersInfoOwner], to our descendants.
-                awaitPointerEventScope {
-                    while (isActive) {
-                        // During the Initial pass, we receive the event after our ancestors.
-                        val pointers = awaitPointerEvent(PointerEventPass.Initial).changes
-
-                        pointersDown = pointers.countDown()
-                        if (pointersDown == 0) {
-                            // There are no more pointers down
-                            startedPosition = null
-                        } else if (startedPosition == null) {
-                            startedPosition = pointers.first().position
-                            onFirstPointerDown()
-                        }
-                    }
-                }
-            }
-
-            // The order is important here: we want to make sure that the previous PointerEventScope
-            // is initialized first. This ensures that the following PointerEventScope doesn't
-            // receive more events than the first one.
-            launch {
-                awaitPointerEventScope {
-                    while (isActive) {
-                        try {
-                            detectDragGestures(
-                                orientation = orientation,
-                                startDragImmediately = startDragImmediately,
-                                onDragStart = { startedPosition, overSlop, pointersDown ->
-                                    velocityTracker.resetTracking()
-                                    onDragStarted(startedPosition, overSlop, pointersDown)
-                                },
-                                onDrag = { controller, change, amount ->
-                                    velocityTracker.addPointerInputChange(change)
-                                    dispatchScrollEvents(
-                                        availableOnPreScroll = amount,
-                                        onScroll = { controller.onDrag(it) },
-                                        source = NestedScrollSource.UserInput,
-                                    )
-                                },
-                                onDragEnd = { controller ->
-                                    startFlingGesture(
-                                        initialVelocity =
-                                            currentValueOf(LocalViewConfiguration)
-                                                .maximumFlingVelocity
-                                                .let {
-                                                    val maxVelocity = Velocity(it, it)
-                                                    velocityTracker.calculateVelocity(maxVelocity)
-                                                }
-                                                .toFloat(),
-                                        onFling = { controller.onStop(it, canChangeContent = true) }
-                                    )
-                                },
-                                onDragCancel = { controller ->
-                                    startFlingGesture(
-                                        initialVelocity = 0f,
-                                        onFling = { controller.onStop(it, canChangeContent = true) }
-                                    )
-                                },
-                                swipeDetector = swipeDetector,
+        val currentContext = currentCoroutineContext()
+        awaitPointerEventScope {
+            while (currentContext.isActive) {
+                try {
+                    detectDragGestures(
+                        orientation = orientation,
+                        startDragImmediately = startDragImmediately,
+                        onDragStart = { startedPosition, overSlop, pointersDown ->
+                            velocityTracker.resetTracking()
+                            onDragStarted(startedPosition, overSlop, pointersDown)
+                        },
+                        onDrag = { controller, change, amount ->
+                            velocityTracker.addPointerInputChange(change)
+                            dispatchScrollEvents(
+                                availableOnPreScroll = amount,
+                                onScroll = { controller.onDrag(it) },
+                                source = NestedScrollSource.UserInput,
                             )
-                        } catch (exception: CancellationException) {
-                            // If the coroutine scope is active, we can just restart the drag cycle.
-                            if (!isActive) {
-                                throw exception
-                            }
-                        }
+                        },
+                        onDragEnd = { controller ->
+                            startFlingGesture(
+                                initialVelocity =
+                                    currentValueOf(LocalViewConfiguration)
+                                        .maximumFlingVelocity
+                                        .let {
+                                            val maxVelocity = Velocity(it, it)
+                                            velocityTracker.calculateVelocity(maxVelocity)
+                                        }
+                                        .toFloat(),
+                                onFling = { controller.onStop(it, canChangeContent = true) }
+                            )
+                        },
+                        onDragCancel = { controller ->
+                            startFlingGesture(
+                                initialVelocity = 0f,
+                                onFling = { controller.onStop(it, canChangeContent = true) }
+                            )
+                        },
+                        swipeDetector = swipeDetector,
+                    )
+                } catch (exception: CancellationException) {
+                    // If the coroutine scope is active, we can just restart the drag cycle.
+                    if (!currentContext.isActive) {
+                        throw exception
                     }
                 }
             }
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/PredictiveBackHandler.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/PredictiveBackHandler.kt
index e930011..3bf19fc 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/PredictiveBackHandler.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/PredictiveBackHandler.kt
@@ -22,8 +22,10 @@
 import androidx.compose.foundation.gestures.Orientation
 import androidx.compose.runtime.Composable
 import kotlin.coroutines.cancellation.CancellationException
+import kotlinx.coroutines.coroutineScope
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.first
+import kotlinx.coroutines.launch
 
 @Composable
 internal fun PredictiveBackHandler(
@@ -42,7 +44,6 @@
         val animation =
             createSwipeAnimation(
                 layoutImpl,
-                layoutImpl.coroutineScope,
                 result.userActionCopy(
                     transitionKey = result.transitionKey ?: TransitionKey.PredictiveBack
                 ),
@@ -64,7 +65,8 @@
 ) {
     fun animateOffset(targetContent: T, spec: AnimationSpec<Float>? = null) {
         if (
-            layoutImpl.state.transitionState != animation.contentTransition || animation.isFinishing
+            layoutImpl.state.transitionState != animation.contentTransition ||
+                animation.isAnimatingOffset()
         ) {
             return
         }
@@ -76,20 +78,23 @@
         )
     }
 
-    layoutImpl.state.startTransition(animation.contentTransition)
-    try {
-        progress.collect { backEvent -> animation.dragOffset = backEvent.progress }
+    coroutineScope {
+        launch {
+            try {
+                progress.collect { backEvent -> animation.dragOffset = backEvent.progress }
 
-        // Back gesture successful.
-        animateOffset(
-            animation.toContent,
-            animation.contentTransition.transformationSpec.progressSpec
-        )
-    } catch (e: CancellationException) {
-        // Back gesture cancelled.
-        // If the back gesture is cancelled, the progress is animated back to 0f by the system.
-        // Since the remaining change in progress is usually very small, the progressSpec is omitted
-        // and the default spring spec used instead.
-        animateOffset(animation.fromContent)
+                // Back gesture successful.
+                animateOffset(
+                    animation.toContent,
+                    animation.contentTransition.transformationSpec.progressSpec,
+                )
+            } catch (e: CancellationException) {
+                // Back gesture cancelled.
+                animateOffset(animation.fromContent)
+            }
+        }
+
+        // Start the transition.
+        layoutImpl.state.startTransition(animation.contentTransition)
     }
 }
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayout.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayout.kt
index e453430..a0d512c 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayout.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayout.kt
@@ -617,7 +617,7 @@
                 swipeSourceDetector = swipeSourceDetector,
                 transitionInterceptionThreshold = transitionInterceptionThreshold,
                 builder = builder,
-                coroutineScope = coroutineScope,
+                animationScope = coroutineScope,
             )
             .also { onLayoutImpl?.invoke(it) }
     }
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
index b33b4f6..f36c0fa 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutImpl.kt
@@ -59,7 +59,13 @@
     internal var swipeSourceDetector: SwipeSourceDetector,
     internal var transitionInterceptionThreshold: Float,
     builder: SceneTransitionLayoutScope.() -> Unit,
-    internal val coroutineScope: CoroutineScope,
+
+    /**
+     * The scope that should be used by *animations started by this layout only*, i.e. animations
+     * triggered by gestures set up on this layout in [swipeToScene] or interruption decay
+     * animations.
+     */
+    internal val animationScope: CoroutineScope,
 ) {
     /**
      * The map of [Scene]s.
@@ -142,18 +148,10 @@
         // DraggableHandlerImpl must wait for the scenes to be initialized, in order to access the
         // current scene (required for SwipeTransition).
         horizontalDraggableHandler =
-            DraggableHandlerImpl(
-                layoutImpl = this,
-                orientation = Orientation.Horizontal,
-                coroutineScope = coroutineScope,
-            )
+            DraggableHandlerImpl(layoutImpl = this, orientation = Orientation.Horizontal)
 
         verticalDraggableHandler =
-            DraggableHandlerImpl(
-                layoutImpl = this,
-                orientation = Orientation.Vertical,
-                coroutineScope = coroutineScope,
-            )
+            DraggableHandlerImpl(layoutImpl = this, orientation = Orientation.Vertical)
 
         // Make sure that the state is created on the same thread (most probably the main thread)
         // than this STLImpl.
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
index f3128f1..cc7d146 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SceneTransitionLayoutState.kt
@@ -30,6 +30,10 @@
 import com.android.compose.animation.scene.transition.link.StateLink
 import kotlin.math.absoluteValue
 import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.CoroutineStart
+import kotlinx.coroutines.Job
+import kotlinx.coroutines.coroutineScope
+import kotlinx.coroutines.launch
 
 /**
  * The state of a [SceneTransitionLayout].
@@ -108,24 +112,25 @@
      * If [targetScene] is different than the [currentScene][TransitionState.currentScene] of
      * [transitionState], then this will animate to [targetScene]. The associated
      * [TransitionState.Transition] will be returned and will be set as the current
-     * [transitionState] of this [MutableSceneTransitionLayoutState].
+     * [transitionState] of this [MutableSceneTransitionLayoutState]. The [Job] in which the
+     * transition runs will be returned, allowing you to easily [join][Job.join] or
+     * [cancel][Job.cancel] the animation.
      *
      * Note that because a non-null [TransitionState.Transition] is returned does not mean that the
      * transition will finish and that we will settle to [targetScene]. The returned transition
      * might still be interrupted, for instance by another call to [setTargetScene] or by a user
      * gesture.
      *
-     * If [this] [CoroutineScope] is cancelled during the transition and that the transition was
-     * still active, then the [transitionState] of this [MutableSceneTransitionLayoutState] will be
-     * set to `TransitionState.Idle(targetScene)`.
-     *
-     * TODO(b/318794193): Add APIs to await() and cancel() any [TransitionState.Transition].
+     * If [coroutineScope] is cancelled during the transition and that the transition was still
+     * active, then the [transitionState] of this [MutableSceneTransitionLayoutState] will be set to
+     * `TransitionState.Idle(targetScene)`.
      */
     fun setTargetScene(
         targetScene: SceneKey,
+        // TODO(b/362727477): Rename to animationScope.
         coroutineScope: CoroutineScope,
         transitionKey: TransitionKey? = null,
-    ): TransitionState.Transition?
+    ): Pair<TransitionState.Transition, Job>?
 
     /** Immediately snap to the given [scene]. */
     fun snapToScene(
@@ -299,7 +304,7 @@
         targetScene: SceneKey,
         coroutineScope: CoroutineScope,
         transitionKey: TransitionKey?,
-    ): TransitionState.Transition.ChangeScene? {
+    ): Pair<TransitionState.Transition.ChangeScene, Job>? {
         checkThread()
 
         return coroutineScope.animateToScene(
@@ -310,17 +315,67 @@
     }
 
     /**
+     * Instantly start a [transition], running it in [animationScope].
+     *
+     * This call returns immediately and [transition] will be the [currentTransition] of this
+     * [MutableSceneTransitionLayoutState].
+     *
+     * @see startTransition
+     */
+    internal fun startTransitionImmediately(
+        animationScope: CoroutineScope,
+        transition: TransitionState.Transition,
+        chain: Boolean = true,
+    ): Job {
+        // Note that we start with UNDISPATCHED so that startTransition() is called directly and
+        // transition becomes the current [transitionState] right after this call.
+        return animationScope.launch(
+            start = CoroutineStart.UNDISPATCHED,
+        ) {
+            startTransition(transition, chain)
+        }
+    }
+
+    /**
      * Start a new [transition].
      *
      * If [chain] is `true`, then the transitions will simply be added to [currentTransitions] and
      * will run in parallel to the current transitions. If [chain] is `false`, then the list of
      * [currentTransitions] will be cleared and [transition] will be the only running transition.
      *
-     * Important: you *must* call [finishTransition] once the transition is finished.
+     * If any transition is currently ongoing, it will be interrupted and forced to animate to its
+     * current state.
+     *
+     * This method returns when [transition] is done running, i.e. when the call to
+     * [run][TransitionState.Transition.run] returns.
      */
-    internal fun startTransition(transition: TransitionState.Transition, chain: Boolean = true) {
+    internal suspend fun startTransition(
+        transition: TransitionState.Transition,
+        chain: Boolean = true,
+    ) {
         checkThread()
 
+        try {
+            // Keep a reference to the previous transition (if any).
+            val previousTransition = currentTransition
+
+            // Start the transition.
+            startTransitionInternal(transition, chain)
+
+            // Handle transition links.
+            previousTransition?.let { cancelActiveTransitionLinks(it) }
+            if (stateLinks.isNotEmpty()) {
+                coroutineScope { setupTransitionLinks(transition) }
+            }
+
+            // Run the transition until it is finished.
+            transition.run()
+        } finally {
+            finishTransition(transition)
+        }
+    }
+
+    private fun startTransitionInternal(transition: TransitionState.Transition, chain: Boolean) {
         // Set the current scene and overlays on the transition.
         val currentState = transitionState
         transition.currentSceneWhenTransitionStarted = currentState.currentScene
@@ -349,10 +404,6 @@
             transition.updateOverscrollSpecs(fromSpec = null, toSpec = null)
         }
 
-        // Handle transition links.
-        currentTransition?.let { cancelActiveTransitionLinks(it) }
-        setupTransitionLinks(transition)
-
         if (!enableInterruptions) {
             // Set the current transition.
             check(transitionStates.size == 1)
@@ -367,9 +418,8 @@
                 transitionStates = listOf(transition)
             }
             is TransitionState.Transition -> {
-                // Force the current transition to finish to currentScene. The transition will call
-                // [finishTransition] once it's finished.
-                currentState.finish()
+                // Force the current transition to finish to currentScene.
+                currentState.freezeAndAnimateToCurrentState()
 
                 val tooManyTransitions = transitionStates.size >= MAX_CONCURRENT_TRANSITIONS
                 val clearCurrentTransitions = !chain || tooManyTransitions
@@ -423,7 +473,7 @@
         transition.activeTransitionLinks.clear()
     }
 
-    private fun setupTransitionLinks(transition: TransitionState.Transition) {
+    private fun CoroutineScope.setupTransitionLinks(transition: TransitionState.Transition) {
         stateLinks.fastForEach { stateLink ->
             val matchingLinks =
                 stateLink.transitionLinks.fastFilter { it.isMatchingLink(transition) }
@@ -443,7 +493,11 @@
                     key = matchingLink.targetTransitionKey,
                 )
 
-            stateLink.target.startTransition(linkedTransition)
+            // Start with UNDISPATCHED so that startTransition is called directly and the new linked
+            // transition is observable directly.
+            launch(start = CoroutineStart.UNDISPATCHED) {
+                stateLink.target.startTransition(linkedTransition)
+            }
             transition.activeTransitionLinks[stateLink] = linkedTransition
         }
     }
@@ -453,7 +507,7 @@
      * [currentScene][TransitionState.currentScene]. This will do nothing if [transition] was
      * interrupted since it was started.
      */
-    internal fun finishTransition(transition: TransitionState.Transition) {
+    private fun finishTransition(transition: TransitionState.Transition) {
         checkThread()
 
         if (finishedTransitions.contains(transition)) {
@@ -461,6 +515,10 @@
             return
         }
 
+        // Make sure that this transition settles in case it was force finished, for instance by
+        // calling snapToScene().
+        transition.freezeAndAnimateToCurrentState()
+
         val transitionStates = this.transitionStates
         if (!transitionStates.contains(transition)) {
             // This transition was already removed from transitionStates.
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SwipeAnimation.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SwipeAnimation.kt
index be9c567..bd4627d 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SwipeAnimation.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/SwipeAnimation.kt
@@ -28,14 +28,10 @@
 import com.android.compose.animation.scene.content.state.TransitionState
 import com.android.compose.animation.scene.content.state.TransitionState.HasOverscrollProperties.Companion.DistanceUnspecified
 import kotlin.math.absoluteValue
-import kotlinx.coroutines.CoroutineScope
-import kotlinx.coroutines.CoroutineStart
-import kotlinx.coroutines.Job
-import kotlinx.coroutines.launch
+import kotlinx.coroutines.CompletableDeferred
 
 internal fun createSwipeAnimation(
     layoutState: MutableSceneTransitionLayoutStateImpl,
-    animationScope: CoroutineScope,
     result: UserActionResult,
     isUpOrLeft: Boolean,
     orientation: Orientation,
@@ -43,7 +39,6 @@
 ): SwipeAnimation<*> {
     return createSwipeAnimation(
         layoutState,
-        animationScope,
         result,
         isUpOrLeft,
         orientation,
@@ -56,7 +51,6 @@
 
 internal fun createSwipeAnimation(
     layoutImpl: SceneTransitionLayoutImpl,
-    animationScope: CoroutineScope,
     result: UserActionResult,
     isUpOrLeft: Boolean,
     orientation: Orientation,
@@ -88,7 +82,6 @@
 
     return createSwipeAnimation(
         layoutImpl.state,
-        animationScope,
         result,
         isUpOrLeft,
         orientation,
@@ -99,7 +92,6 @@
 
 private fun createSwipeAnimation(
     layoutState: MutableSceneTransitionLayoutStateImpl,
-    animationScope: CoroutineScope,
     result: UserActionResult,
     isUpOrLeft: Boolean,
     orientation: Orientation,
@@ -109,7 +101,6 @@
     fun <T : ContentKey> swipeAnimation(fromContent: T, toContent: T): SwipeAnimation<T> {
         return SwipeAnimation(
             layoutState = layoutState,
-            animationScope = animationScope,
             fromContent = fromContent,
             toContent = toContent,
             orientation = orientation,
@@ -197,7 +188,6 @@
 /** A helper class that contains the main logic for swipe transitions. */
 internal class SwipeAnimation<T : ContentKey>(
     val layoutState: MutableSceneTransitionLayoutStateImpl,
-    val animationScope: CoroutineScope,
     val fromContent: T,
     val toContent: T,
     override val orientation: Orientation,
@@ -210,14 +200,22 @@
     /** The [TransitionState.Transition] whose implementation delegates to this [SwipeAnimation]. */
     lateinit var contentTransition: TransitionState.Transition
 
-    var currentContent by mutableStateOf(currentContent)
+    private var _currentContent by mutableStateOf(currentContent)
+    var currentContent: T
+        get() = _currentContent
+        set(value) {
+            check(!isAnimatingOffset()) {
+                "currentContent can not be changed once we are animating the offset"
+            }
+            _currentContent = value
+        }
 
     val progress: Float
         get() {
             // Important: If we are going to return early because distance is equal to 0, we should
             // still make sure we read the offset before returning so that the calling code still
             // subscribes to the offset value.
-            val animatable = offsetAnimation?.animatable
+            val animatable = offsetAnimation
             val offset =
                 when {
                     animatable != null -> animatable.value
@@ -238,7 +236,7 @@
 
     val progressVelocity: Float
         get() {
-            val animatable = offsetAnimation?.animatable ?: return 0f
+            val animatable = offsetAnimation ?: return 0f
             val distance = distance()
             if (distance == DistanceUnspecified) {
                 return 0f
@@ -263,7 +261,8 @@
     var dragOffset by mutableFloatStateOf(dragOffset)
 
     /** The offset animation that animates the offset once the user lifts their finger. */
-    private var offsetAnimation: OffsetAnimation? by mutableStateOf(null)
+    private var offsetAnimation: Animatable<Float, AnimationVector1D>? by mutableStateOf(null)
+    private val offsetAnimationRunnable = CompletableDeferred<(suspend () -> Unit)?>()
 
     val isUserInputOngoing: Boolean
         get() = offsetAnimation == null
@@ -271,15 +270,10 @@
     override val absoluteDistance: Float
         get() = distance().absoluteValue
 
-    /** Whether [finish] was called on this animation. */
-    var isFinishing = false
-        private set
-
     constructor(
         other: SwipeAnimation<T>
     ) : this(
         layoutState = other.layoutState,
-        animationScope = other.animationScope,
         fromContent = other.fromContent,
         toContent = other.toContent,
         orientation = other.orientation,
@@ -287,9 +281,17 @@
         requiresFullDistanceSwipe = other.requiresFullDistanceSwipe,
         distance = other.distance,
         currentContent = other.currentContent,
-        dragOffset = other.dragOffset,
+        dragOffset = other.offsetAnimation?.value ?: other.dragOffset,
     )
 
+    suspend fun run() {
+        // This animation will first be driven by finger, then when the user lift their finger we
+        // start an animation to the target offset (progress = 1f or progress = 0f). We await() for
+        // offsetAnimationRunnable to be completed and then run it.
+        val runAnimation = offsetAnimationRunnable.await() ?: return
+        runAnimation()
+    }
+
     /**
      * The signed distance between [fromContent] and [toContent]. It is negative if [fromContent] is
      * above or to the left of [toContent].
@@ -300,28 +302,15 @@
      */
     fun distance(): Float = distance(this)
 
-    /** Ends any previous [offsetAnimation] and runs the new [animation]. */
-    private fun startOffsetAnimation(animation: () -> OffsetAnimation): OffsetAnimation {
-        cancelOffsetAnimation()
-        return animation().also { offsetAnimation = it }
-    }
-
-    /** Cancel any ongoing offset animation. */
-    // TODO(b/317063114) This should be a suspended function to avoid multiple jobs running at
-    // the same time.
-    fun cancelOffsetAnimation() {
-        val animation = offsetAnimation ?: return
-        offsetAnimation = null
-
-        dragOffset = animation.animatable.value
-        animation.job.cancel()
-    }
+    fun isAnimatingOffset(): Boolean = offsetAnimation != null
 
     fun animateOffset(
         initialVelocity: Float,
         targetContent: T,
         spec: AnimationSpec<Float>? = null,
-    ): OffsetAnimation {
+    ) {
+        check(!isAnimatingOffset()) { "SwipeAnimation.animateOffset() can only be called once" }
+
         val initialProgress = progress
         // Skip the animation if we have already reached the target content and the overscroll does
         // not animate anything.
@@ -358,74 +347,76 @@
             currentContent = targetContent
         }
 
-        return startOffsetAnimation {
-            val startProgress =
-                if (contentTransition.previewTransformationSpec != null) 0f else dragOffset
-            val animatable = Animatable(startProgress, OffsetVisibilityThreshold)
-            val isTargetGreater = targetOffset > animatable.value
-            val startedWhenOvercrollingTargetContent =
-                if (targetContent == fromContent) initialProgress < 0f else initialProgress > 1f
-            val job =
-                animationScope
-                    // Important: We start atomically to make sure that we start the coroutine even
-                    // if it is cancelled right after it is launched, so that snapToContent() is
-                    // correctly called. Otherwise, this transition will never be stopped and we
-                    // will never settle to Idle.
-                    .launch(start = CoroutineStart.ATOMIC) {
-                        // TODO(b/327249191): Refactor the code so that we don't even launch a
-                        // coroutine if we don't need to animate.
-                        if (skipAnimation) {
-                            snapToContent(targetContent)
-                            dragOffset = targetOffset
-                            return@launch
-                        }
+        val startProgress =
+            if (contentTransition.previewTransformationSpec != null) 0f else dragOffset
 
-                        try {
-                            val swipeSpec =
-                                spec
-                                    ?: contentTransition.transformationSpec.swipeSpec
-                                    ?: layoutState.transitions.defaultSwipeSpec
-                            animatable.animateTo(
-                                targetValue = targetOffset,
-                                animationSpec = swipeSpec,
-                                initialVelocity = initialVelocity,
-                            ) {
-                                if (bouncingContent == null) {
-                                    val isBouncing =
-                                        if (isTargetGreater) {
-                                            if (startedWhenOvercrollingTargetContent) {
-                                                value >= targetOffset
-                                            } else {
-                                                value > targetOffset
-                                            }
-                                        } else {
-                                            if (startedWhenOvercrollingTargetContent) {
-                                                value <= targetOffset
-                                            } else {
-                                                value < targetOffset
-                                            }
-                                        }
+        val animatable =
+            Animatable(startProgress, OffsetVisibilityThreshold).also { offsetAnimation = it }
 
-                                    if (isBouncing) {
-                                        bouncingContent = targetContent
+        check(isAnimatingOffset())
 
-                                        // Immediately stop this transition if we are bouncing on a
-                                        // content that does not bounce.
-                                        if (!contentTransition.isWithinProgressRange(progress)) {
-                                            snapToContent(targetContent)
-                                        }
-                                    }
+        // Note: we still create the animatable and set it on offsetAnimation even when
+        // skipAnimation is true, just so that isUserInputOngoing and isAnimatingOffset() are
+        // unchanged even despite this small skip-optimization (which is just an implementation
+        // detail).
+        if (skipAnimation) {
+            // Unblock the job.
+            offsetAnimationRunnable.complete(null)
+            return
+        }
+
+        val isTargetGreater = targetOffset > animatable.value
+        val startedWhenOvercrollingTargetContent =
+            if (targetContent == fromContent) initialProgress < 0f else initialProgress > 1f
+
+        val swipeSpec =
+            spec
+                ?: contentTransition.transformationSpec.swipeSpec
+                ?: layoutState.transitions.defaultSwipeSpec
+
+        offsetAnimationRunnable.complete {
+            try {
+                animatable.animateTo(
+                    targetValue = targetOffset,
+                    animationSpec = swipeSpec,
+                    initialVelocity = initialVelocity,
+                ) {
+                    if (bouncingContent == null) {
+                        val isBouncing =
+                            if (isTargetGreater) {
+                                if (startedWhenOvercrollingTargetContent) {
+                                    value >= targetOffset
+                                } else {
+                                    value > targetOffset
+                                }
+                            } else {
+                                if (startedWhenOvercrollingTargetContent) {
+                                    value <= targetOffset
+                                } else {
+                                    value < targetOffset
                                 }
                             }
-                        } finally {
-                            snapToContent(targetContent)
+
+                        if (isBouncing) {
+                            bouncingContent = targetContent
+
+                            // Immediately stop this transition if we are bouncing on a content that
+                            // does not bounce.
+                            if (!contentTransition.isWithinProgressRange(progress)) {
+                                throw SnapException()
+                            }
                         }
                     }
-
-            OffsetAnimation(animatable, job)
+                }
+            } catch (_: SnapException) {
+                /* Ignore. */
+            }
         }
     }
 
+    /** An exception thrown during the animation to stop it immediately. */
+    private class SnapException : Exception()
+
     private fun canChangeContent(targetContent: ContentKey): Boolean {
         return when (val transition = contentTransition) {
             is TransitionState.Transition.ChangeScene ->
@@ -446,34 +437,11 @@
         }
     }
 
-    private fun snapToContent(content: T) {
-        cancelOffsetAnimation()
-        check(currentContent == content)
-        layoutState.finishTransition(contentTransition)
+    fun freezeAndAnimateToCurrentState() {
+        if (isAnimatingOffset()) return
+
+        animateOffset(initialVelocity = 0f, targetContent = currentContent)
     }
-
-    fun finish(): Job {
-        if (isFinishing) return requireNotNull(offsetAnimation).job
-        isFinishing = true
-
-        // If we were already animating the offset, simply return the job.
-        offsetAnimation?.let {
-            return it.job
-        }
-
-        // Animate to the current content.
-        val animation = animateOffset(initialVelocity = 0f, targetContent = currentContent)
-        check(offsetAnimation == animation)
-        return animation.job
-    }
-
-    internal class OffsetAnimation(
-        /** The animatable used to animate the offset. */
-        val animatable: Animatable<Float, AnimationVector1D>,
-
-        /** The job in which [animatable] is animated. */
-        val job: Job,
-    )
 }
 
 private object DefaultSwipeDistance : UserActionDistance {
@@ -537,7 +505,13 @@
     override val isUserInputOngoing: Boolean
         get() = swipeAnimation.isUserInputOngoing
 
-    override fun finish(): Job = swipeAnimation.finish()
+    override suspend fun run() {
+        swipeAnimation.run()
+    }
+
+    override fun freezeAndAnimateToCurrentState() {
+        swipeAnimation.freezeAndAnimateToCurrentState()
+    }
 }
 
 private class ShowOrHideOverlaySwipeTransition(
@@ -594,7 +568,13 @@
     override val isUserInputOngoing: Boolean
         get() = swipeAnimation.isUserInputOngoing
 
-    override fun finish(): Job = swipeAnimation.finish()
+    override suspend fun run() {
+        swipeAnimation.run()
+    }
+
+    override fun freezeAndAnimateToCurrentState() {
+        swipeAnimation.freezeAndAnimateToCurrentState()
+    }
 }
 
 private class ReplaceOverlaySwipeTransition(
@@ -645,5 +625,11 @@
     override val isUserInputOngoing: Boolean
         get() = swipeAnimation.isUserInputOngoing
 
-    override fun finish(): Job = swipeAnimation.finish()
+    override suspend fun run() {
+        swipeAnimation.run()
+    }
+
+    override fun freezeAndAnimateToCurrentState() {
+        swipeAnimation.freezeAndAnimateToCurrentState()
+    }
 }
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/content/state/TransitionState.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/content/state/TransitionState.kt
index 0cd8c1a..a47caaa 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/content/state/TransitionState.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/content/state/TransitionState.kt
@@ -35,7 +35,6 @@
 import com.android.compose.animation.scene.TransitionKey
 import com.android.compose.animation.scene.transition.link.LinkedTransition
 import com.android.compose.animation.scene.transition.link.StateLink
-import kotlinx.coroutines.Job
 import kotlinx.coroutines.launch
 
 /** The state associated to a [SceneTransitionLayout] at some specific point in time. */
@@ -300,19 +299,19 @@
             return fromContent == content || toContent == content
         }
 
+        /** Run this transition and return once it is finished. */
+        internal abstract suspend fun run()
+
         /**
-         * Force this transition to finish and animate to an [Idle] state.
+         * Freeze this transition state so that neither [currentScene] nor [currentOverlays] will
+         * change in the future, and animate the progress towards that state. For instance, a
+         * [Transition.ChangeScene] should animate the progress to 0f if its [currentScene] is equal
+         * to its [fromScene][Transition.ChangeScene.fromScene] or animate it to 1f if its equal to
+         * its [toScene][Transition.ChangeScene.toScene].
          *
-         * Important: Once this is called, the effective state of the transition should remain
-         * unchanged. For instance, in the case of a [TransitionState.Transition], its
-         * [currentScene][TransitionState.Transition.currentScene] should never change once [finish]
-         * is called.
-         *
-         * @return the [Job] that animates to the idle state. It can be used to wait until the
-         *   animation is complete or cancel it to snap the animation. Calling [finish] multiple
-         *   times will return the same [Job].
+         * This is called when this transition is interrupted (replaced) by another transition.
          */
-        internal abstract fun finish(): Job
+        internal abstract fun freezeAndAnimateToCurrentState()
 
         internal fun updateOverscrollSpecs(
             fromSpec: OverscrollSpecImpl?,
@@ -350,7 +349,7 @@
 
             fun create(): Animatable<Float, AnimationVector1D> {
                 val animatable = Animatable(1f, visibilityThreshold = ProgressVisibilityThreshold)
-                layoutImpl.coroutineScope.launch {
+                layoutImpl.animationScope.launch {
                     val swipeSpec = layoutImpl.state.transitions.defaultSwipeSpec
                     val progressSpec =
                         spring(
diff --git a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/transition/link/LinkedTransition.kt b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/transition/link/LinkedTransition.kt
index 564d4b3..42ba9ba 100644
--- a/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/transition/link/LinkedTransition.kt
+++ b/packages/SystemUI/compose/scene/src/com/android/compose/animation/scene/transition/link/LinkedTransition.kt
@@ -19,7 +19,6 @@
 import com.android.compose.animation.scene.SceneKey
 import com.android.compose.animation.scene.TransitionKey
 import com.android.compose.animation.scene.content.state.TransitionState
-import kotlinx.coroutines.Job
 
 /** A linked transition which is driven by a [originalTransition]. */
 internal class LinkedTransition(
@@ -50,5 +49,11 @@
     override val progressVelocity: Float
         get() = originalTransition.progressVelocity
 
-    override fun finish(): Job = originalTransition.finish()
+    override suspend fun run() {
+        originalTransition.run()
+    }
+
+    override fun freezeAndAnimateToCurrentState() {
+        originalTransition.freezeAndAnimateToCurrentState()
+    }
 }
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/AnimatedSharedAsStateTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/AnimatedSharedAsStateTest.kt
index 8ebb42a..a491349 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/AnimatedSharedAsStateTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/AnimatedSharedAsStateTest.kt
@@ -39,7 +39,10 @@
 import com.android.compose.animation.scene.TestScenes.SceneB
 import com.android.compose.animation.scene.TestScenes.SceneC
 import com.android.compose.animation.scene.TestScenes.SceneD
+import com.android.compose.test.setContentAndCreateMainScope
+import com.android.compose.test.transition
 import com.google.common.truth.Truth.assertThat
+import kotlinx.coroutines.launch
 import kotlinx.coroutines.test.runTest
 import org.junit.Assert.assertThrows
 import org.junit.Rule
@@ -406,30 +409,33 @@
             }
         }
 
-        rule.setContent {
-            SceneTransitionLayout(state) {
-                // foo goes from 0f to 100f in A => B.
-                scene(SceneA) { animateFloat(0f, foo) }
-                scene(SceneB) { animateFloat(100f, foo) }
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayout(state) {
+                    // foo goes from 0f to 100f in A => B.
+                    scene(SceneA) { animateFloat(0f, foo) }
+                    scene(SceneB) { animateFloat(100f, foo) }
 
-                // bar goes from 0f to 10f in C => D.
-                scene(SceneC) { animateFloat(0f, bar) }
-                scene(SceneD) { animateFloat(10f, bar) }
+                    // bar goes from 0f to 10f in C => D.
+                    scene(SceneC) { animateFloat(0f, bar) }
+                    scene(SceneD) { animateFloat(10f, bar) }
+                }
             }
-        }
 
-        rule.runOnUiThread {
-            // A => B is at 30%.
+        // A => B is at 30%.
+        scope.launch {
             state.startTransition(
                 transition(
                     from = SceneA,
                     to = SceneB,
                     progress = { 0.3f },
-                    onFinish = neverFinish(),
+                    onFreezeAndAnimate = { /* never finish */ },
                 )
             )
+        }
 
-            // C => D is at 70%.
+        // C => D is at 70%.
+        scope.launch {
             state.startTransition(transition(from = SceneC, to = SceneD, progress = { 0.7f }))
         }
         rule.waitForIdle()
@@ -466,17 +472,18 @@
             }
         }
 
-        rule.setContent {
-            SceneTransitionLayout(state) {
-                scene(SceneA) { animateFloat(0f, key) }
-                scene(SceneB) { animateFloat(100f, key) }
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayout(state) {
+                    scene(SceneA) { animateFloat(0f, key) }
+                    scene(SceneB) { animateFloat(100f, key) }
+                }
             }
-        }
 
         // Overscroll on A at -100%: value should be interpolated given that there is no overscroll
         // defined for scene A.
         var progress by mutableStateOf(-1f)
-        rule.runOnIdle {
+        scope.launch {
             state.startTransition(transition(from = SceneA, to = SceneB, progress = { progress }))
         }
         rule.waitForIdle()
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/DraggableHandlerTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/DraggableHandlerTest.kt
index 9fa4722..26743fc 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/DraggableHandlerTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/DraggableHandlerTest.kt
@@ -41,9 +41,9 @@
 import com.android.compose.animation.scene.subjects.assertThat
 import com.android.compose.test.MonotonicClockTestScope
 import com.android.compose.test.runMonotonicClockTest
+import com.android.compose.test.transition
 import com.google.common.truth.Truth.assertThat
 import kotlinx.coroutines.CoroutineScope
-import kotlinx.coroutines.cancelAndJoin
 import kotlinx.coroutines.launch
 import org.junit.Test
 import org.junit.runner.RunWith
@@ -132,7 +132,10 @@
                     swipeSourceDetector = DefaultEdgeDetector,
                     transitionInterceptionThreshold = transitionInterceptionThreshold,
                     builder = scenesBuilder,
-                    coroutineScope = testScope,
+
+                    // Use testScope and not backgroundScope here because backgroundScope does not
+                    // work well with advanceUntilIdle(), which is used by some tests.
+                    animationScope = testScope,
                 )
                 .apply { setContentsAndLayoutTargetSizeForTest(LAYOUT_SIZE) }
 
@@ -301,8 +304,20 @@
         runMonotonicClockTest {
             val testGestureScope = TestGestureScope(testScope = this)
 
-            // run the test
-            testGestureScope.block()
+            try {
+                // Run the test.
+                testGestureScope.block()
+            } finally {
+                // Make sure we stop the last transition if it was not explicitly stopped, otherwise
+                // tests will time out after 10s given that the transitions are now started on the
+                // test scope. We don't use backgroundScope when starting the test transitions
+                // because coroutines started on the background scope don't work well with
+                // advanceUntilIdle(), which is used in a few tests.
+                if (testGestureScope.draggableHandler.isDrivingTransition) {
+                    (testGestureScope.layoutState.transitionState as Transition)
+                        .freezeAndAnimateToCurrentState()
+                }
+            }
         }
     }
 
@@ -940,7 +955,7 @@
     }
 
     @Test
-    fun finish() = runGestureTest {
+    fun freezeAndAnimateToCurrentState() = runGestureTest {
         // Start at scene C.
         navigateToSceneC()
 
@@ -952,35 +967,25 @@
         // The current transition can be intercepted.
         assertThat(draggableHandler.shouldImmediatelyIntercept(middle)).isTrue()
 
-        // Finish the transition.
+        // Freeze the transition.
         val transition = transitionState as Transition
-        val job = transition.finish()
+        transition.freezeAndAnimateToCurrentState()
         assertTransition(isUserInputOngoing = false)
-
-        // The current transition can not be intercepted anymore.
-        assertThat(draggableHandler.shouldImmediatelyIntercept(middle)).isFalse()
-
-        // Calling finish() multiple times returns the same Job.
-        assertThat(transition.finish()).isSameInstanceAs(job)
-        assertThat(transition.finish()).isSameInstanceAs(job)
-        assertThat(transition.finish()).isSameInstanceAs(job)
-
-        // We can join the job to wait for the animation to end.
-        assertTransition()
-        job.join()
+        advanceUntilIdle()
         assertIdle(SceneC)
     }
 
     @Test
-    fun finish_cancelled() = runGestureTest {
-        // Swipe up from the middle to transition to scene B.
-        val middle = Offset(SCREEN_SIZE / 2f, SCREEN_SIZE / 2f)
-        onDragStarted(startedPosition = middle, overSlop = up(0.1f))
-        assertTransition(fromScene = SceneA, toScene = SceneB)
+    fun interruptedTransitionCanNotBeImmediatelyIntercepted() = runGestureTest {
+        assertThat(draggableHandler.shouldImmediatelyIntercept(startedPosition = null)).isFalse()
+        onDragStarted(overSlop = up(0.1f))
+        assertThat(draggableHandler.shouldImmediatelyIntercept(startedPosition = null)).isTrue()
 
-        // Finish the transition and cancel the returned job.
-        (transitionState as Transition).finish().cancelAndJoin()
-        assertIdle(SceneA)
+        layoutState.startTransitionImmediately(
+            animationScope = testScope.backgroundScope,
+            transition(SceneA, SceneB)
+        )
+        assertThat(draggableHandler.shouldImmediatelyIntercept(startedPosition = null)).isFalse()
     }
 
     @Test
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
index 770c0f8..60596de 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ElementTest.kt
@@ -71,10 +71,11 @@
 import com.android.compose.animation.scene.TestScenes.SceneC
 import com.android.compose.animation.scene.subjects.assertThat
 import com.android.compose.test.assertSizeIsEqualTo
+import com.android.compose.test.setContentAndCreateMainScope
+import com.android.compose.test.transition
 import com.google.common.truth.Truth.assertThat
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.launch
-import kotlinx.coroutines.test.runTest
 import org.junit.Assert.assertThrows
 import org.junit.Ignore
 import org.junit.Rule
@@ -504,7 +505,7 @@
     }
 
     @Test
-    fun elementModifierNodeIsRecycledInLazyLayouts() = runTest {
+    fun elementModifierNodeIsRecycledInLazyLayouts() {
         val nPages = 2
         val pagerState = PagerState(currentPage = 0) { nPages }
         var nullableLayoutImpl: SceneTransitionLayoutImpl? = null
@@ -630,18 +631,19 @@
                 )
             }
 
-        rule.setContent {
-            SceneTransitionLayout(state) {
-                scene(SceneA) { Box(Modifier.element(TestElements.Foo).size(20.dp)) }
-                scene(SceneB) {}
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayout(state) {
+                    scene(SceneA) { Box(Modifier.element(TestElements.Foo).size(20.dp)) }
+                    scene(SceneB) {}
+                }
             }
-        }
 
         // Pause the clock to block recompositions.
         rule.mainClock.autoAdvance = false
 
         // Change the current transition.
-        rule.runOnUiThread {
+        scope.launch {
             state.startTransition(transition(from = SceneA, to = SceneB, progress = { 0.5f }))
         }
 
@@ -1296,7 +1298,7 @@
     }
 
     @Test
-    fun interruption() = runTest {
+    fun interruption() {
         // 4 frames of animation.
         val duration = 4 * 16
 
@@ -1336,37 +1338,41 @@
         val valueInC = 200f
 
         lateinit var layoutImpl: SceneTransitionLayoutImpl
-        rule.setContent {
-            SceneTransitionLayoutForTesting(
-                state,
-                Modifier.size(layoutSize),
-                onLayoutImpl = { layoutImpl = it },
-            ) {
-                // In scene A, Foo is aligned at the TopStart.
-                scene(SceneA) {
-                    Box(Modifier.fillMaxSize()) {
-                        Foo(sizeInA, valueInA, Modifier.align(Alignment.TopStart))
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayoutForTesting(
+                    state,
+                    Modifier.size(layoutSize),
+                    onLayoutImpl = { layoutImpl = it },
+                ) {
+                    // In scene A, Foo is aligned at the TopStart.
+                    scene(SceneA) {
+                        Box(Modifier.fillMaxSize()) {
+                            Foo(sizeInA, valueInA, Modifier.align(Alignment.TopStart))
+                        }
                     }
-                }
 
-                // In scene C, Foo is aligned at the BottomEnd, so it moves vertically when coming
-                // from B. We put it before (below) scene B so that we can check that interruptions
-                // values and deltas are properly cleared once all transitions are done.
-                scene(SceneC) {
-                    Box(Modifier.fillMaxSize()) {
-                        Foo(sizeInC, valueInC, Modifier.align(Alignment.BottomEnd))
+                    // In scene C, Foo is aligned at the BottomEnd, so it moves vertically when
+                    // coming
+                    // from B. We put it before (below) scene B so that we can check that
+                    // interruptions
+                    // values and deltas are properly cleared once all transitions are done.
+                    scene(SceneC) {
+                        Box(Modifier.fillMaxSize()) {
+                            Foo(sizeInC, valueInC, Modifier.align(Alignment.BottomEnd))
+                        }
                     }
-                }
 
-                // In scene B, Foo is aligned at the TopEnd, so it moves horizontally when coming
-                // from A.
-                scene(SceneB) {
-                    Box(Modifier.fillMaxSize()) {
-                        Foo(sizeInB, valueInB, Modifier.align(Alignment.TopEnd))
+                    // In scene B, Foo is aligned at the TopEnd, so it moves horizontally when
+                    // coming
+                    // from A.
+                    scene(SceneB) {
+                        Box(Modifier.fillMaxSize()) {
+                            Foo(sizeInB, valueInB, Modifier.align(Alignment.TopEnd))
+                        }
                     }
                 }
             }
-        }
 
         // The offset of Foo when idle in A, B or C.
         val offsetInA = DpOffset.Zero
@@ -1390,12 +1396,12 @@
                 from = SceneA,
                 to = SceneB,
                 progress = { aToBProgress },
-                onFinish = neverFinish(),
+                onFreezeAndAnimate = { /* never finish */ },
             )
         val offsetInAToB = lerp(offsetInA, offsetInB, aToBProgress)
         val sizeInAToB = lerp(sizeInA, sizeInB, aToBProgress)
         val valueInAToB = lerp(valueInA, valueInB, aToBProgress)
-        rule.runOnUiThread { state.startTransition(aToB) }
+        scope.launch { state.startTransition(aToB) }
         rule
             .onNode(isElement(TestElements.Foo, SceneB))
             .assertSizeIsEqualTo(sizeInAToB)
@@ -1415,7 +1421,7 @@
                 progress = { bToCProgress },
                 interruptionProgress = { interruptionProgress },
             )
-        rule.runOnUiThread { state.startTransition(bToC) }
+        scope.launch { state.startTransition(bToC) }
 
         // The interruption deltas, which will be multiplied by the interruption progress then added
         // to the current transition offset and size.
@@ -1476,10 +1482,8 @@
             .assertSizeIsEqualTo(sizeInC)
 
         // Manually finish the transition.
-        rule.runOnUiThread {
-            state.finishTransition(aToB)
-            state.finishTransition(bToC)
-        }
+        aToB.finish()
+        bToC.finish()
         rule.waitForIdle()
         assertThat(state.transitionState).isIdle()
 
@@ -1498,7 +1502,7 @@
     }
 
     @Test
-    fun interruption_sharedTransitionDisabled() = runTest {
+    fun interruption_sharedTransitionDisabled() {
         // 4 frames of animation.
         val duration = 4 * 16
         val layoutSize = DpSize(200.dp, 100.dp)
@@ -1524,21 +1528,22 @@
             Box(modifier.element(TestElements.Foo).size(fooSize))
         }
 
-        rule.setContent {
-            SceneTransitionLayout(state, Modifier.size(layoutSize)) {
-                scene(SceneA) {
-                    Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.TopStart)) }
-                }
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayout(state, Modifier.size(layoutSize)) {
+                    scene(SceneA) {
+                        Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.TopStart)) }
+                    }
 
-                scene(SceneB) {
-                    Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.TopEnd)) }
-                }
+                    scene(SceneB) {
+                        Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.TopEnd)) }
+                    }
 
-                scene(SceneC) {
-                    Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.BottomEnd)) }
+                    scene(SceneC) {
+                        Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.BottomEnd)) }
+                    }
                 }
             }
-        }
 
         // The offset of Foo when idle in A, B or C.
         val offsetInA = DpOffset.Zero
@@ -1547,7 +1552,12 @@
 
         // State is a transition A => B at 50% interrupted by B => C at 30%.
         val aToB =
-            transition(from = SceneA, to = SceneB, progress = { 0.5f }, onFinish = neverFinish())
+            transition(
+                from = SceneA,
+                to = SceneB,
+                progress = { 0.5f },
+                onFreezeAndAnimate = { /* never finish */ },
+            )
         var bToCInterruptionProgress by mutableStateOf(1f)
         val bToC =
             transition(
@@ -1555,11 +1565,11 @@
                 to = SceneC,
                 progress = { 0.3f },
                 interruptionProgress = { bToCInterruptionProgress },
-                onFinish = neverFinish(),
+                onFreezeAndAnimate = { /* never finish */ },
             )
-        rule.runOnUiThread { state.startTransition(aToB) }
+        scope.launch { state.startTransition(aToB) }
         rule.waitForIdle()
-        rule.runOnUiThread { state.startTransition(bToC) }
+        scope.launch { state.startTransition(bToC) }
 
         // Foo is placed in both B and C given that the shared transition is disabled. In B, its
         // offset is impacted by the interruption but in C it is not.
@@ -1579,7 +1589,8 @@
 
         // Manually finish A => B so only B => C is remaining.
         bToCInterruptionProgress = 0f
-        rule.runOnUiThread { state.finishTransition(aToB) }
+        aToB.finish()
+
         rule
             .onNode(isElement(TestElements.Foo, SceneB))
             .assertPositionInRootIsEqualTo(offsetInB.x, offsetInB.y)
@@ -1595,7 +1606,7 @@
                 progress = { 0.7f },
                 interruptionProgress = { 1f },
             )
-        rule.runOnUiThread { state.startTransition(bToA) }
+        scope.launch { state.startTransition(bToA) }
 
         // Foo should have the position it had in B right before the interruption.
         rule
@@ -1609,32 +1620,35 @@
         val state =
             rule.runOnUiThread {
                 MutableSceneTransitionLayoutStateImpl(
-                        SceneA,
-                        transitions { overscrollDisabled(SceneA, Orientation.Horizontal) }
-                    )
-                    .apply {
-                        startTransition(
-                            transition(
-                                from = SceneA,
-                                to = SceneB,
-                                progress = { -1f },
-                                orientation = Orientation.Horizontal
-                            )
-                        )
-                    }
+                    SceneA,
+                    transitions { overscrollDisabled(SceneA, Orientation.Horizontal) }
+                )
             }
 
         lateinit var layoutImpl: SceneTransitionLayoutImpl
-        rule.setContent {
-            SceneTransitionLayoutForTesting(
-                state,
-                Modifier.size(100.dp),
-                onLayoutImpl = { layoutImpl = it },
-            ) {
-                scene(SceneA) {}
-                scene(SceneB) { Box(Modifier.element(TestElements.Foo)) }
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayoutForTesting(
+                    state,
+                    Modifier.size(100.dp),
+                    onLayoutImpl = { layoutImpl = it },
+                ) {
+                    scene(SceneA) {}
+                    scene(SceneB) { Box(Modifier.element(TestElements.Foo)) }
+                }
             }
+
+        scope.launch {
+            state.startTransition(
+                transition(
+                    from = SceneA,
+                    to = SceneB,
+                    progress = { -1f },
+                    orientation = Orientation.Horizontal
+                )
+            )
         }
+        rule.waitForIdle()
 
         assertThat(layoutImpl.elements).containsKey(TestElements.Foo)
         val foo = layoutImpl.elements.getValue(TestElements.Foo)
@@ -1647,7 +1661,7 @@
     }
 
     @Test
-    fun lastAlphaIsNotSetByOutdatedLayer() = runTest {
+    fun lastAlphaIsNotSetByOutdatedLayer() {
         val state =
             rule.runOnUiThread {
                 MutableSceneTransitionLayoutStateImpl(
@@ -1657,23 +1671,24 @@
             }
 
         lateinit var layoutImpl: SceneTransitionLayoutImpl
-        rule.setContent {
-            SceneTransitionLayoutForTesting(state, onLayoutImpl = { layoutImpl = it }) {
-                scene(SceneA) {}
-                scene(SceneB) { Box(Modifier.element(TestElements.Foo)) }
-                scene(SceneC) { Box(Modifier.element(TestElements.Foo)) }
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayoutForTesting(state, onLayoutImpl = { layoutImpl = it }) {
+                    scene(SceneA) {}
+                    scene(SceneB) { Box(Modifier.element(TestElements.Foo)) }
+                    scene(SceneC) { Box(Modifier.element(TestElements.Foo)) }
+                }
             }
-        }
 
         // Start A => B at 0.5f.
         var aToBProgress by mutableStateOf(0.5f)
-        rule.runOnUiThread {
+        scope.launch {
             state.startTransition(
                 transition(
                     from = SceneA,
                     to = SceneB,
                     progress = { aToBProgress },
-                    onFinish = neverFinish(),
+                    onFreezeAndAnimate = { /* never finish */ },
                 )
             )
         }
@@ -1692,7 +1707,7 @@
         assertThat(fooInB.lastAlpha).isEqualTo(0.7f)
 
         // Start B => C at 0.3f.
-        rule.runOnUiThread {
+        scope.launch {
             state.startTransition(transition(from = SceneB, to = SceneC, progress = { 0.3f }))
         }
         rule.waitForIdle()
@@ -1720,16 +1735,17 @@
             }
 
         lateinit var layoutImpl: SceneTransitionLayoutImpl
-        rule.setContent {
-            SceneTransitionLayoutForTesting(state, onLayoutImpl = { layoutImpl = it }) {
-                scene(SceneA) {}
-                scene(SceneB) { Box(Modifier.element(TestElements.Foo)) }
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayoutForTesting(state, onLayoutImpl = { layoutImpl = it }) {
+                    scene(SceneA) {}
+                    scene(SceneB) { Box(Modifier.element(TestElements.Foo)) }
+                }
             }
-        }
 
         // Start A => B at 60%.
         var interruptionProgress by mutableStateOf(1f)
-        rule.runOnUiThread {
+        scope.launch {
             state.startTransition(
                 transition(
                     from = SceneA,
@@ -1774,19 +1790,20 @@
             Box(Modifier.element(TestElements.Foo).size(10.dp))
         }
 
-        rule.setContent {
-            SceneTransitionLayout(state) {
-                scene(SceneA) { Foo() }
-                scene(SceneB) { Foo() }
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayout(state) {
+                    scene(SceneA) { Foo() }
+                    scene(SceneB) { Foo() }
+                }
             }
-        }
 
         rule.onNode(isElement(TestElements.Foo, SceneA)).assertIsDisplayed()
         rule.onNode(isElement(TestElements.Foo, SceneB)).assertDoesNotExist()
 
         // A => B while overscrolling at scene B.
         var progress by mutableStateOf(2f)
-        rule.runOnUiThread {
+        scope.launch {
             state.startTransition(transition(from = SceneA, to = SceneB, progress = { progress }))
         }
         rule.waitForIdle()
@@ -1827,19 +1844,20 @@
             MovableElement(key, modifier) { content { Text(text) } }
         }
 
-        rule.setContent {
-            SceneTransitionLayout(state) {
-                scene(SceneA) { MovableFoo(text = fooInA) }
-                scene(SceneB) { MovableFoo(text = fooInB) }
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayout(state) {
+                    scene(SceneA) { MovableFoo(text = fooInA) }
+                    scene(SceneB) { MovableFoo(text = fooInB) }
+                }
             }
-        }
 
         rule.onNode(hasText(fooInA)).assertIsDisplayed()
         rule.onNode(hasText(fooInB)).assertDoesNotExist()
 
         // A => B while overscrolling at scene B.
         var progress by mutableStateOf(2f)
-        rule.runOnUiThread {
+        scope.launch {
             state.startTransition(transition(from = SceneA, to = SceneB, progress = { progress }))
         }
         rule.waitForIdle()
@@ -1858,7 +1876,7 @@
     }
 
     @Test
-    fun interruptionThenOverscroll() = runTest {
+    fun interruptionThenOverscroll() {
         val state =
             rule.runOnUiThread {
                 MutableSceneTransitionLayoutStateImpl(
@@ -1879,22 +1897,23 @@
             }
         }
 
-        rule.setContent {
-            SceneTransitionLayout(state, Modifier.size(200.dp)) {
-                scene(SceneA) { SceneWithFoo(offset = DpOffset.Zero) }
-                scene(SceneB) { SceneWithFoo(offset = DpOffset(x = 40.dp, y = 0.dp)) }
-                scene(SceneC) { SceneWithFoo(offset = DpOffset(x = 40.dp, y = 40.dp)) }
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayout(state, Modifier.size(200.dp)) {
+                    scene(SceneA) { SceneWithFoo(offset = DpOffset.Zero) }
+                    scene(SceneB) { SceneWithFoo(offset = DpOffset(x = 40.dp, y = 0.dp)) }
+                    scene(SceneC) { SceneWithFoo(offset = DpOffset(x = 40.dp, y = 40.dp)) }
+                }
             }
-        }
 
         // Start A => B at 75%.
-        rule.runOnUiThread {
+        scope.launch {
             state.startTransition(
                 transition(
                     from = SceneA,
                     to = SceneB,
                     progress = { 0.75f },
-                    onFinish = neverFinish(),
+                    onFreezeAndAnimate = { /* never finish */ },
                 )
             )
         }
@@ -1907,7 +1926,7 @@
         // Interrupt A => B with B => C at 0%.
         var progress by mutableStateOf(0f)
         var interruptionProgress by mutableStateOf(1f)
-        rule.runOnUiThread {
+        scope.launch {
             state.startTransition(
                 transition(
                     from = SceneB,
@@ -1915,7 +1934,7 @@
                     progress = { progress },
                     interruptionProgress = { interruptionProgress },
                     orientation = Orientation.Vertical,
-                    onFinish = neverFinish(),
+                    onFreezeAndAnimate = { /* never finish */ },
                 )
             )
         }
@@ -1963,12 +1982,13 @@
         }
 
         lateinit var layoutImpl: SceneTransitionLayoutImpl
-        rule.setContent {
-            SceneTransitionLayoutForTesting(state, onLayoutImpl = { layoutImpl = it }) {
-                scene(SceneA) { NestedFooBar() }
-                scene(SceneB) { NestedFooBar() }
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayoutForTesting(state, onLayoutImpl = { layoutImpl = it }) {
+                    scene(SceneA) { NestedFooBar() }
+                    scene(SceneB) { NestedFooBar() }
+                }
             }
-        }
 
         // Idle on A: composed and placed only in B.
         rule.onNode(isElement(TestElements.Foo, SceneA)).assertIsDisplayed()
@@ -1997,7 +2017,7 @@
         assertThat(barInA.lastScale).isNotEqualTo(Scale.Unspecified)
 
         // A => B: composed in both and placed only in B.
-        rule.runOnUiThread { state.startTransition(transition(from = SceneA, to = SceneB)) }
+        scope.launch { state.startTransition(transition(from = SceneA, to = SceneB)) }
         rule.onNode(isElement(TestElements.Foo, SceneA)).assertExists().assertIsNotDisplayed()
         rule.onNode(isElement(TestElements.Bar, SceneA)).assertExists().assertIsNotDisplayed()
         rule.onNode(isElement(TestElements.Foo, SceneB)).assertIsDisplayed()
@@ -2024,7 +2044,7 @@
     }
 
     @Test
-    fun currentTransitionSceneIsUsedToComputeElementValues() = runTest {
+    fun currentTransitionSceneIsUsedToComputeElementValues() {
         val state =
             rule.runOnIdle {
                 MutableSceneTransitionLayoutStateImpl(
@@ -2044,23 +2064,31 @@
             }
         }
 
-        rule.setContent {
-            SceneTransitionLayout(state, Modifier.size(200.dp)) {
-                scene(SceneA) { Foo() }
-                scene(SceneB) {}
-                scene(SceneC) { Foo() }
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayout(state, Modifier.size(200.dp)) {
+                    scene(SceneA) { Foo() }
+                    scene(SceneB) {}
+                    scene(SceneC) { Foo() }
+                }
             }
-        }
 
         // We have 2 transitions:
         //  - A => B at 100%
         //  - B => C at 0%
         // So Foo should have a size of (40dp, 60dp) in both A and C given that it is scaling its
         // size in B => C.
-        rule.runOnUiThread {
+        scope.launch {
             state.startTransition(
-                transition(from = SceneA, to = SceneB, progress = { 1f }, onFinish = neverFinish())
+                transition(
+                    from = SceneA,
+                    to = SceneB,
+                    progress = { 1f },
+                    onFreezeAndAnimate = { /* never finish */ },
+                )
             )
+        }
+        scope.launch {
             state.startTransition(transition(from = SceneB, to = SceneC, progress = { 0f }))
         }
 
@@ -2069,7 +2097,7 @@
     }
 
     @Test
-    fun interruptionDeltasAreProperlyCleaned() = runTest {
+    fun interruptionDeltasAreProperlyCleaned() {
         val state = rule.runOnIdle { MutableSceneTransitionLayoutStateImpl(SceneA) }
 
         @Composable
@@ -2079,18 +2107,24 @@
             }
         }
 
-        rule.setContent {
-            SceneTransitionLayout(state, Modifier.size(200.dp)) {
-                scene(SceneA) { Foo(offset = 0.dp) }
-                scene(SceneB) { Foo(offset = 20.dp) }
-                scene(SceneC) { Foo(offset = 40.dp) }
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayout(state, Modifier.size(200.dp)) {
+                    scene(SceneA) { Foo(offset = 0.dp) }
+                    scene(SceneB) { Foo(offset = 20.dp) }
+                    scene(SceneC) { Foo(offset = 40.dp) }
+                }
             }
-        }
 
         // Start A => B at 50%.
         val aToB =
-            transition(from = SceneA, to = SceneB, progress = { 0.5f }, onFinish = neverFinish())
-        rule.runOnUiThread { state.startTransition(aToB) }
+            transition(
+                from = SceneA,
+                to = SceneB,
+                progress = { 0.5f },
+                onFreezeAndAnimate = { /* never finish */ },
+            )
+        scope.launch { state.startTransition(aToB) }
         rule.onNode(isElement(TestElements.Foo, SceneB)).assertPositionInRootIsEqualTo(10.dp, 10.dp)
 
         // Start B => C at 0%. This will compute an interruption delta of (-10dp, -10dp) so that the
@@ -2103,9 +2137,9 @@
                 current = { SceneB },
                 progress = { 0f },
                 interruptionProgress = { interruptionProgress },
-                onFinish = neverFinish(),
+                onFreezeAndAnimate = { /* never finish */ },
             )
-        rule.runOnUiThread { state.startTransition(bToC) }
+        scope.launch { state.startTransition(bToC) }
         rule.onNode(isElement(TestElements.Foo, SceneC)).assertPositionInRootIsEqualTo(10.dp, 10.dp)
 
         // Finish the interruption and leave the transition progress at 0f. We should be at the same
@@ -2116,9 +2150,9 @@
         // Finish both transitions but directly start a new one B => A with interruption progress
         // 100%. We should be at (20dp, 20dp), unless the interruption deltas have not been
         // correctly cleaned.
-        rule.runOnUiThread {
-            state.finishTransition(aToB)
-            state.finishTransition(bToC)
+        aToB.finish()
+        bToC.finish()
+        scope.launch {
             state.startTransition(
                 transition(
                     from = SceneB,
@@ -2132,7 +2166,7 @@
     }
 
     @Test
-    fun lastSizeIsUnspecifiedWhenOverscrollingOtherScene() = runTest {
+    fun lastSizeIsUnspecifiedWhenOverscrollingOtherScene() {
         val state =
             rule.runOnIdle {
                 MutableSceneTransitionLayoutStateImpl(
@@ -2147,17 +2181,23 @@
         }
 
         lateinit var layoutImpl: SceneTransitionLayoutImpl
-        rule.setContent {
-            SceneTransitionLayoutForTesting(state, onLayoutImpl = { layoutImpl = it }) {
-                scene(SceneA) { Foo() }
-                scene(SceneB) { Foo() }
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayoutForTesting(state, onLayoutImpl = { layoutImpl = it }) {
+                    scene(SceneA) { Foo() }
+                    scene(SceneB) { Foo() }
+                }
             }
-        }
 
         // Overscroll A => B on A.
-        rule.runOnUiThread {
+        scope.launch {
             state.startTransition(
-                transition(from = SceneA, to = SceneB, progress = { -1f }, onFinish = neverFinish())
+                transition(
+                    from = SceneA,
+                    to = SceneB,
+                    progress = { -1f },
+                    onFreezeAndAnimate = { /* never finish */ },
+                )
             )
         }
         rule.waitForIdle()
@@ -2173,7 +2213,7 @@
     }
 
     @Test
-    fun transparentElementIsNotImpactingInterruption() = runTest {
+    fun transparentElementIsNotImpactingInterruption() {
         val state =
             rule.runOnIdle {
                 MutableSceneTransitionLayoutStateImpl(
@@ -2200,23 +2240,24 @@
             Box(modifier.element(TestElements.Foo).size(10.dp))
         }
 
-        rule.setContent {
-            SceneTransitionLayout(state) {
-                scene(SceneB) { Foo(Modifier.offset(40.dp, 60.dp)) }
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayout(state) {
+                    scene(SceneB) { Foo(Modifier.offset(40.dp, 60.dp)) }
 
-                // Define A after B so that Foo is placed in A during A <=> B.
-                scene(SceneA) { Foo() }
+                    // Define A after B so that Foo is placed in A during A <=> B.
+                    scene(SceneA) { Foo() }
+                }
             }
-        }
 
         // Start A => B at 70%.
-        rule.runOnUiThread {
+        scope.launch {
             state.startTransition(
                 transition(
                     from = SceneA,
                     to = SceneB,
                     progress = { 0.7f },
-                    onFinish = neverFinish(),
+                    onFreezeAndAnimate = { /* never finish */ },
                 )
             )
         }
@@ -2227,14 +2268,14 @@
         // Start B => A at 50% with interruptionProgress = 100%. Foo is placed in A and should still
         // be at (40dp, 60dp) given that it was fully transparent in A before the interruption.
         var interruptionProgress by mutableStateOf(1f)
-        rule.runOnUiThread {
+        scope.launch {
             state.startTransition(
                 transition(
                     from = SceneB,
                     to = SceneA,
                     progress = { 0.5f },
                     interruptionProgress = { interruptionProgress },
-                    onFinish = neverFinish(),
+                    onFreezeAndAnimate = { /* never finish */ },
                 )
             )
         }
@@ -2250,7 +2291,7 @@
     }
 
     @Test
-    fun replacedTransitionDoesNotTriggerInterruption() = runTest {
+    fun replacedTransitionDoesNotTriggerInterruption() {
         val state = rule.runOnIdle { MutableSceneTransitionLayoutStateImpl(SceneA) }
 
         @Composable
@@ -2258,17 +2299,23 @@
             Box(modifier.element(TestElements.Foo).size(10.dp))
         }
 
-        rule.setContent {
-            SceneTransitionLayout(state) {
-                scene(SceneA) { Foo() }
-                scene(SceneB) { Foo(Modifier.offset(40.dp, 60.dp)) }
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayout(state) {
+                    scene(SceneA) { Foo() }
+                    scene(SceneB) { Foo(Modifier.offset(40.dp, 60.dp)) }
+                }
             }
-        }
 
         // Start A => B at 50%.
         val aToB1 =
-            transition(from = SceneA, to = SceneB, progress = { 0.5f }, onFinish = neverFinish())
-        rule.runOnUiThread { state.startTransition(aToB1) }
+            transition(
+                from = SceneA,
+                to = SceneB,
+                progress = { 0.5f },
+                onFreezeAndAnimate = { /* never finish */ },
+            )
+        scope.launch { state.startTransition(aToB1) }
         rule.onNode(isElement(TestElements.Foo, SceneA)).assertIsNotDisplayed()
         rule.onNode(isElement(TestElements.Foo, SceneB)).assertPositionInRootIsEqualTo(20.dp, 30.dp)
 
@@ -2282,7 +2329,7 @@
                 interruptionProgress = { 1f },
                 replacedTransition = aToB1,
             )
-        rule.runOnUiThread { state.startTransition(aToB2) }
+        scope.launch { state.startTransition(aToB2) }
         rule.onNode(isElement(TestElements.Foo, SceneA)).assertIsNotDisplayed()
         rule.onNode(isElement(TestElements.Foo, SceneB)).assertPositionInRootIsEqualTo(40.dp, 60.dp)
     }
@@ -2428,12 +2475,13 @@
         }
 
         lateinit var layoutImpl: SceneTransitionLayoutImpl
-        rule.setContent {
-            SceneTransitionLayoutForTesting(state, onLayoutImpl = { layoutImpl = it }) {
-                scene(from) { Box { exitingElements.forEach { Foo(it) } } }
-                scene(to) { Box { enteringElements.forEach { Foo(it) } } }
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayoutForTesting(state, onLayoutImpl = { layoutImpl = it }) {
+                    scene(from) { Box { exitingElements.forEach { Foo(it) } } }
+                    scene(to) { Box { enteringElements.forEach { Foo(it) } } }
+                }
             }
-        }
 
         val bToA =
             transition(
@@ -2443,7 +2491,7 @@
                 previewProgress = { previewProgress },
                 isInPreviewStage = { isInPreviewStage }
             )
-        rule.runOnUiThread { state.startTransition(bToA) }
+        scope.launch { state.startTransition(bToA) }
         rule.waitForIdle()
         return layoutImpl
     }
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/InterruptionHandlerTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/InterruptionHandlerTest.kt
index 3f6bd2c..7498df1 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/InterruptionHandlerTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/InterruptionHandlerTest.kt
@@ -25,9 +25,9 @@
 import com.android.compose.animation.scene.content.state.TransitionState
 import com.android.compose.animation.scene.subjects.assertThat
 import com.android.compose.test.runMonotonicClockTest
+import com.android.compose.test.transition
 import com.google.common.truth.Correspondence
 import com.google.common.truth.Truth.assertThat
-import kotlinx.coroutines.launch
 import org.junit.Rule
 import org.junit.Test
 import org.junit.runner.RunWith
@@ -155,13 +155,21 @@
                 // Progress must be > visibility threshold otherwise we will directly snap to A.
                 progress = { 0.5f },
                 progressVelocity = { progressVelocity },
-                onFinish = { launch {} },
             )
-        state.startTransition(aToB)
+        state.startTransitionImmediately(animationScope = backgroundScope, aToB)
 
         // Animate back to A. The previous transition is reversed, i.e. it has the same (from, to)
         // pair, and its velocity is used when animating the progress back to 0.
-        val bToA = checkNotNull(state.setTargetScene(SceneA, coroutineScope = this))
+        val bToA =
+            checkNotNull(
+                    state.setTargetScene(
+                        SceneA,
+                        // We use testScope here and not backgroundScope because setTargetScene
+                        // needs the monotonic clock that is only available in the test scope.
+                        coroutineScope = this,
+                    )
+                )
+                .first
         testScheduler.runCurrent()
         assertThat(bToA).hasFromScene(SceneA)
         assertThat(bToA).hasToScene(SceneB)
@@ -181,13 +189,21 @@
                 to = SceneB,
                 current = { SceneA },
                 progressVelocity = { progressVelocity },
-                onFinish = { launch {} },
             )
-        state.startTransition(aToB)
+        state.startTransitionImmediately(animationScope = backgroundScope, aToB)
 
         // Animate to B. The previous transition is reversed, i.e. it has the same (from, to) pair,
         // and its velocity is used when animating the progress to 1.
-        val bToA = checkNotNull(state.setTargetScene(SceneB, coroutineScope = this))
+        val bToA =
+            checkNotNull(
+                    state.setTargetScene(
+                        SceneB,
+                        // We use testScope here and not backgroundScope because setTargetScene
+                        // needs the monotonic clock that is only available in the test scope.
+                        coroutineScope = this,
+                    )
+                )
+                .first
         testScheduler.runCurrent()
         assertThat(bToA).hasFromScene(SceneA)
         assertThat(bToA).hasToScene(SceneB)
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/MovableElementContentPickerTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/MovableElementContentPickerTest.kt
index e1d0945..c8e7e65 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/MovableElementContentPickerTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/MovableElementContentPickerTest.kt
@@ -17,6 +17,7 @@
 package com.android.compose.animation.scene
 
 import androidx.test.ext.junit.runners.AndroidJUnit4
+import com.android.compose.test.transition
 import com.google.common.truth.Truth.assertThat
 import org.junit.Assert.assertThrows
 import org.junit.Test
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ObservableTransitionStateTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ObservableTransitionStateTest.kt
index 0543e7f..2c723ec 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ObservableTransitionStateTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/ObservableTransitionStateTest.kt
@@ -29,6 +29,7 @@
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import com.android.compose.animation.scene.TestScenes.SceneA
 import com.android.compose.animation.scene.TestScenes.SceneB
+import com.android.compose.test.transition
 import com.google.common.truth.Truth.assertThat
 import kotlinx.coroutines.flow.flatMapLatest
 import kotlinx.coroutines.launch
@@ -139,7 +140,7 @@
         var transitionCurrentScene by mutableStateOf(SceneA)
         val transition =
             transition(from = SceneA, to = SceneB, current = { transitionCurrentScene })
-        state.startTransition(transition)
+        state.startTransitionImmediately(animationScope = backgroundScope, transition)
         assertThat(currentScene.value).isEqualTo(SceneA)
 
         // Change the transition current scene.
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutStateTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutStateTest.kt
index 69f2cba..29eedf6 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutStateTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutStateTest.kt
@@ -30,14 +30,14 @@
 import com.android.compose.animation.scene.content.state.TransitionState
 import com.android.compose.animation.scene.subjects.assertThat
 import com.android.compose.animation.scene.transition.link.StateLink
+import com.android.compose.test.MonotonicClockTestScope
+import com.android.compose.test.TestTransition
 import com.android.compose.test.runMonotonicClockTest
+import com.android.compose.test.transition
 import com.google.common.truth.Truth.assertThat
 import kotlinx.coroutines.CoroutineStart
-import kotlinx.coroutines.Job
 import kotlinx.coroutines.cancelAndJoin
 import kotlinx.coroutines.launch
-import kotlinx.coroutines.sync.Mutex
-import kotlinx.coroutines.sync.withLock
 import kotlinx.coroutines.test.runTest
 import org.junit.Rule
 import org.junit.Test
@@ -58,9 +58,12 @@
     }
 
     @Test
-    fun isTransitioningTo_transition() {
+    fun isTransitioningTo_transition() = runTest {
         val state = MutableSceneTransitionLayoutStateImpl(SceneA, SceneTransitions.Empty)
-        state.startTransition(transition(from = SceneA, to = SceneB))
+        state.startTransitionImmediately(
+            animationScope = backgroundScope,
+            transition(from = SceneA, to = SceneB)
+        )
 
         assertThat(state.isTransitioning()).isTrue()
         assertThat(state.isTransitioning(from = SceneA)).isTrue()
@@ -79,11 +82,10 @@
     @Test
     fun setTargetScene_idleToDifferentScene() = runMonotonicClockTest {
         val state = MutableSceneTransitionLayoutState(SceneA)
-        val transition = state.setTargetScene(SceneB, coroutineScope = this)
-        assertThat(transition).isNotNull()
+        val (transition, job) = checkNotNull(state.setTargetScene(SceneB, coroutineScope = this))
         assertThat(state.transitionState).isEqualTo(transition)
 
-        transition!!.finish().join()
+        job.join()
         assertThat(state.transitionState).isEqualTo(TransitionState.Idle(SceneB))
     }
 
@@ -91,11 +93,10 @@
     fun setTargetScene_transitionToSameScene() = runMonotonicClockTest {
         val state = MutableSceneTransitionLayoutState(SceneA)
 
-        val transition = state.setTargetScene(SceneB, coroutineScope = this)
-        assertThat(transition).isNotNull()
+        val (_, job) = checkNotNull(state.setTargetScene(SceneB, coroutineScope = this))
         assertThat(state.setTargetScene(SceneB, coroutineScope = this)).isNull()
 
-        transition!!.finish().join()
+        job.join()
         assertThat(state.transitionState).isEqualTo(TransitionState.Idle(SceneB))
     }
 
@@ -104,10 +105,9 @@
         val state = MutableSceneTransitionLayoutState(SceneA)
 
         assertThat(state.setTargetScene(SceneB, coroutineScope = this)).isNotNull()
-        val transition = state.setTargetScene(SceneC, coroutineScope = this)
-        assertThat(transition).isNotNull()
+        val (_, job) = checkNotNull(state.setTargetScene(SceneC, coroutineScope = this))
 
-        transition!!.finish().join()
+        job.join()
         assertThat(state.transitionState).isEqualTo(TransitionState.Idle(SceneC))
     }
 
@@ -118,7 +118,7 @@
         lateinit var transition: TransitionState.Transition
         val job =
             launch(start = CoroutineStart.UNDISPATCHED) {
-                transition = state.setTargetScene(SceneB, coroutineScope = this)!!
+                transition = checkNotNull(state.setTargetScene(SceneB, coroutineScope = this)).first
             }
         assertThat(state.transitionState).isEqualTo(transition)
 
@@ -127,18 +127,6 @@
         assertThat(state.transitionState).isEqualTo(TransitionState.Idle(SceneB))
     }
 
-    @Test
-    fun transition_finishReturnsTheSameJobWhenCalledMultipleTimes() = runMonotonicClockTest {
-        val state = MutableSceneTransitionLayoutState(SceneA)
-        val transition = state.setTargetScene(SceneB, coroutineScope = this)
-        assertThat(transition).isNotNull()
-
-        val job = transition!!.finish()
-        assertThat(transition.finish()).isSameInstanceAs(job)
-        assertThat(transition.finish()).isSameInstanceAs(job)
-        assertThat(transition.finish()).isSameInstanceAs(job)
-    }
-
     private fun setupLinkedStates(
         parentInitialScene: SceneKey = SceneC,
         childInitialScene: SceneKey = SceneA,
@@ -163,22 +151,24 @@
     }
 
     @Test
-    fun linkedTransition_startsLinkAndFinishesLinkInToState() {
+    fun linkedTransition_startsLinkAndFinishesLinkInToState() = runTest {
         val (parentState, childState) = setupLinkedStates()
 
         val childTransition = transition(SceneA, SceneB)
 
-        childState.startTransition(childTransition)
+        val job =
+            childState.startTransitionImmediately(animationScope = backgroundScope, childTransition)
         assertThat(childState.isTransitioning(SceneA, SceneB)).isTrue()
         assertThat(parentState.isTransitioning(SceneC, SceneD)).isTrue()
 
-        childState.finishTransition(childTransition)
+        childTransition.finish()
+        job.join()
         assertThat(childState.transitionState).isEqualTo(TransitionState.Idle(SceneB))
         assertThat(parentState.transitionState).isEqualTo(TransitionState.Idle(SceneD))
     }
 
     @Test
-    fun linkedTransition_transitiveLink() {
+    fun linkedTransition_transitiveLink() = runTest {
         val parentParentState =
             MutableSceneTransitionLayoutState(SceneB) as MutableSceneTransitionLayoutStateImpl
         val parentLink =
@@ -204,25 +194,27 @@
 
         val childTransition = transition(SceneA, SceneB)
 
-        childState.startTransition(childTransition)
+        val job =
+            childState.startTransitionImmediately(animationScope = backgroundScope, childTransition)
         assertThat(childState.isTransitioning(SceneA, SceneB)).isTrue()
         assertThat(parentState.isTransitioning(SceneC, SceneD)).isTrue()
         assertThat(parentParentState.isTransitioning(SceneB, SceneC)).isTrue()
 
-        childState.finishTransition(childTransition)
+        childTransition.finish()
+        job.join()
         assertThat(childState.transitionState).isEqualTo(TransitionState.Idle(SceneB))
         assertThat(parentState.transitionState).isEqualTo(TransitionState.Idle(SceneD))
         assertThat(parentParentState.transitionState).isEqualTo(TransitionState.Idle(SceneC))
     }
 
     @Test
-    fun linkedTransition_linkProgressIsEqual() {
+    fun linkedTransition_linkProgressIsEqual() = runTest {
         val (parentState, childState) = setupLinkedStates()
 
         var progress = 0f
         val childTransition = transition(SceneA, SceneB, progress = { progress })
 
-        childState.startTransition(childTransition)
+        childState.startTransitionImmediately(animationScope = backgroundScope, childTransition)
         assertThat(parentState.currentTransition?.progress).isEqualTo(0f)
 
         progress = .5f
@@ -230,28 +222,32 @@
     }
 
     @Test
-    fun linkedTransition_reverseTransitionIsNotLinked() {
+    fun linkedTransition_reverseTransitionIsNotLinked() = runTest {
         val (parentState, childState) = setupLinkedStates()
 
         val childTransition = transition(SceneB, SceneA, current = { SceneB })
 
-        childState.startTransition(childTransition)
+        val job =
+            childState.startTransitionImmediately(animationScope = backgroundScope, childTransition)
         assertThat(childState.isTransitioning(SceneB, SceneA)).isTrue()
         assertThat(parentState.transitionState).isEqualTo(TransitionState.Idle(SceneC))
 
-        childState.finishTransition(childTransition)
+        childTransition.finish()
+        job.join()
         assertThat(childState.transitionState).isEqualTo(TransitionState.Idle(SceneB))
         assertThat(parentState.transitionState).isEqualTo(TransitionState.Idle(SceneC))
     }
 
     @Test
-    fun linkedTransition_startsLinkAndFinishesLinkInFromState() {
+    fun linkedTransition_startsLinkAndFinishesLinkInFromState() = runTest {
         val (parentState, childState) = setupLinkedStates()
 
         val childTransition = transition(SceneA, SceneB, current = { SceneA })
-        childState.startTransition(childTransition)
+        val job =
+            childState.startTransitionImmediately(animationScope = backgroundScope, childTransition)
 
-        childState.finishTransition(childTransition)
+        childTransition.finish()
+        job.join()
         assertThat(childState.transitionState).isEqualTo(TransitionState.Idle(SceneA))
         assertThat(parentState.transitionState).isEqualTo(TransitionState.Idle(SceneC))
     }
@@ -260,22 +256,14 @@
     fun linkedTransition_startsLinkButLinkedStateIsTakenOver() = runTest {
         val (parentState, childState) = setupLinkedStates()
 
-        val childTransition =
-            transition(
-                SceneA,
-                SceneB,
-                onFinish = { launch { /* Do nothing. */ } },
-            )
-        val parentTransition =
-            transition(
-                SceneC,
-                SceneA,
-                onFinish = { launch { /* Do nothing. */ } },
-            )
-        childState.startTransition(childTransition)
-        parentState.startTransition(parentTransition)
+        val childTransition = transition(SceneA, SceneB)
+        val parentTransition = transition(SceneC, SceneA)
+        val job =
+            childState.startTransitionImmediately(animationScope = backgroundScope, childTransition)
+        parentState.startTransitionImmediately(animationScope = backgroundScope, parentTransition)
 
-        childState.finishTransition(childTransition)
+        childTransition.finish()
+        job.join()
         assertThat(childState.transitionState).isEqualTo(TransitionState.Idle(SceneB))
         assertThat(parentState.transitionState).isEqualTo(parentTransition)
     }
@@ -321,7 +309,8 @@
     @Test
     fun snapToIdleIfClose_snapToStart() = runMonotonicClockTest {
         val state = MutableSceneTransitionLayoutStateImpl(SceneA, SceneTransitions.Empty)
-        state.startTransition(
+        state.startTransitionImmediately(
+            animationScope = backgroundScope,
             transition(from = SceneA, to = SceneB, current = { SceneA }, progress = { 0.2f })
         )
         assertThat(state.isTransitioning()).isTrue()
@@ -339,7 +328,10 @@
     @Test
     fun snapToIdleIfClose_snapToEnd() = runMonotonicClockTest {
         val state = MutableSceneTransitionLayoutStateImpl(SceneA, SceneTransitions.Empty)
-        state.startTransition(transition(from = SceneA, to = SceneB, progress = { 0.8f }))
+        state.startTransitionImmediately(
+            animationScope = backgroundScope,
+            transition(from = SceneA, to = SceneB, progress = { 0.8f })
+        )
         assertThat(state.isTransitioning()).isTrue()
 
         // Ignore the request if the progress is not close to 0 or 1, using the threshold.
@@ -356,18 +348,12 @@
     fun snapToIdleIfClose_multipleTransitions() = runMonotonicClockTest {
         val state = MutableSceneTransitionLayoutStateImpl(SceneA, SceneTransitions.Empty)
 
-        val aToB =
-            transition(
-                from = SceneA,
-                to = SceneB,
-                progress = { 0.5f },
-                onFinish = { launch { /* do nothing */ } },
-            )
-        state.startTransition(aToB)
+        val aToB = transition(from = SceneA, to = SceneB, progress = { 0.5f })
+        state.startTransitionImmediately(animationScope = backgroundScope, aToB)
         assertThat(state.currentTransitions).containsExactly(aToB).inOrder()
 
         val bToC = transition(from = SceneB, to = SceneC, progress = { 0.8f })
-        state.startTransition(bToC)
+        state.startTransitionImmediately(animationScope = backgroundScope, bToC)
         assertThat(state.currentTransitions).containsExactly(aToB, bToC).inOrder()
 
         // Ignore the request if the progress is not close to 0 or 1, using the threshold.
@@ -385,7 +371,8 @@
         val state = MutableSceneTransitionLayoutStateImpl(SceneA, SceneTransitions.Empty)
         var progress by mutableStateOf(0f)
         var currentScene by mutableStateOf(SceneB)
-        state.startTransition(
+        state.startTransitionImmediately(
+            animationScope = backgroundScope,
             transition(
                 from = SceneA,
                 to = SceneB,
@@ -406,47 +393,51 @@
     }
 
     @Test
-    fun linkedTransition_fuzzyLinksAreMatchedAndStarted() {
+    fun linkedTransition_fuzzyLinksAreMatchedAndStarted() = runTest {
         val (parentState, childState) = setupLinkedStates(SceneC, SceneA, null, null, null, SceneD)
         val childTransition = transition(SceneA, SceneB)
 
-        childState.startTransition(childTransition)
+        val job =
+            childState.startTransitionImmediately(animationScope = backgroundScope, childTransition)
         assertThat(childState.isTransitioning(SceneA, SceneB)).isTrue()
         assertThat(parentState.isTransitioning(SceneC, SceneD)).isTrue()
 
-        childState.finishTransition(childTransition)
+        childTransition.finish()
+        job.join()
         assertThat(childState.transitionState).isEqualTo(TransitionState.Idle(SceneB))
         assertThat(parentState.transitionState).isEqualTo(TransitionState.Idle(SceneD))
     }
 
     @Test
-    fun linkedTransition_fuzzyLinksAreMatchedAndResetToProperPreviousScene() {
+    fun linkedTransition_fuzzyLinksAreMatchedAndResetToProperPreviousScene() = runTest {
         val (parentState, childState) =
             setupLinkedStates(SceneC, SceneA, SceneA, null, null, SceneD)
 
         val childTransition = transition(SceneA, SceneB, current = { SceneA })
 
-        childState.startTransition(childTransition)
+        val job =
+            childState.startTransitionImmediately(animationScope = backgroundScope, childTransition)
         assertThat(childState.isTransitioning(SceneA, SceneB)).isTrue()
         assertThat(parentState.isTransitioning(SceneC, SceneD)).isTrue()
 
-        childState.finishTransition(childTransition)
+        childTransition.finish()
+        job.join()
         assertThat(childState.transitionState).isEqualTo(TransitionState.Idle(SceneA))
         assertThat(parentState.transitionState).isEqualTo(TransitionState.Idle(SceneC))
     }
 
     @Test
-    fun linkedTransition_fuzzyLinksAreNotMatched() {
+    fun linkedTransition_fuzzyLinksAreNotMatched() = runTest {
         val (parentState, childState) =
             setupLinkedStates(SceneC, SceneA, SceneB, null, SceneC, SceneD)
         val childTransition = transition(SceneA, SceneB)
 
-        childState.startTransition(childTransition)
+        childState.startTransitionImmediately(animationScope = backgroundScope, childTransition)
         assertThat(childState.isTransitioning(SceneA, SceneB)).isTrue()
         assertThat(parentState.isTransitioning(SceneC, SceneD)).isFalse()
     }
 
-    private fun startOverscrollableTransistionFromAtoB(
+    private fun MonotonicClockTestScope.startOverscrollableTransistionFromAtoB(
         progress: () -> Float,
         sceneTransitions: SceneTransitions,
     ): MutableSceneTransitionLayoutStateImpl {
@@ -455,7 +446,8 @@
                 SceneA,
                 sceneTransitions,
             )
-        state.startTransition(
+        state.startTransitionImmediately(
+            animationScope = backgroundScope,
             transition(
                 from = SceneA,
                 to = SceneB,
@@ -560,54 +552,54 @@
 
     @Test
     fun multipleTransitions() = runTest {
-        val finishingTransitions = mutableSetOf<TransitionState.Transition>()
-        fun onFinish(transition: TransitionState.Transition): Job {
-            // Instead of letting the transition finish, we put the transition in the
-            // finishingTransitions set so that we can verify that finish() is called when expected
-            // and then we call state STLState.finishTransition() ourselves.
-            finishingTransitions.add(transition)
+        val frozenTransitions = mutableSetOf<TestTransition>()
+        fun onFreezeAndAnimate(transition: TestTransition): () -> Unit {
+            // Instead of letting the transition finish when it is frozen, we put the transition in
+            // the frozenTransitions set so that we can verify that freezeAndAnimateToCurrentState()
+            // is called when expected and then we call finish() ourselves to finish the
+            // transitions.
+            frozenTransitions.add(transition)
 
-            return backgroundScope.launch {
-                // Try to acquire a locked mutex so that this code never completes.
-                Mutex(locked = true).withLock {}
-            }
+            return { /* do nothing */ }
         }
 
         val state = MutableSceneTransitionLayoutStateImpl(SceneA, EmptyTestTransitions)
-        val aToB = transition(SceneA, SceneB, onFinish = ::onFinish)
-        val bToC = transition(SceneB, SceneC, onFinish = ::onFinish)
-        val cToA = transition(SceneC, SceneA, onFinish = ::onFinish)
+        val aToB = transition(SceneA, SceneB, onFreezeAndAnimate = ::onFreezeAndAnimate)
+        val bToC = transition(SceneB, SceneC, onFreezeAndAnimate = ::onFreezeAndAnimate)
+        val cToA = transition(SceneC, SceneA, onFreezeAndAnimate = ::onFreezeAndAnimate)
 
         // Starting state.
-        assertThat(finishingTransitions).isEmpty()
+        assertThat(frozenTransitions).isEmpty()
         assertThat(state.currentTransitions).isEmpty()
 
         // A => B.
-        state.startTransition(aToB)
-        assertThat(finishingTransitions).isEmpty()
+        val aToBJob = state.startTransitionImmediately(animationScope = backgroundScope, aToB)
+        assertThat(frozenTransitions).isEmpty()
         assertThat(state.finishedTransitions).isEmpty()
         assertThat(state.currentTransitions).containsExactly(aToB).inOrder()
 
-        // B => C. This should automatically call finish() on aToB.
-        state.startTransition(bToC)
-        assertThat(finishingTransitions).containsExactly(aToB)
+        // B => C. This should automatically call freezeAndAnimateToCurrentState() on aToB.
+        val bToCJob = state.startTransitionImmediately(animationScope = backgroundScope, bToC)
+        assertThat(frozenTransitions).containsExactly(aToB)
         assertThat(state.finishedTransitions).isEmpty()
         assertThat(state.currentTransitions).containsExactly(aToB, bToC).inOrder()
 
-        // C => A. This should automatically call finish() on bToC.
-        state.startTransition(cToA)
-        assertThat(finishingTransitions).containsExactly(aToB, bToC)
+        // C => A. This should automatically call freezeAndAnimateToCurrentState() on bToC.
+        state.startTransitionImmediately(animationScope = backgroundScope, cToA)
+        assertThat(frozenTransitions).containsExactly(aToB, bToC)
         assertThat(state.finishedTransitions).isEmpty()
         assertThat(state.currentTransitions).containsExactly(aToB, bToC, cToA).inOrder()
 
         // Mark bToC as finished. The list of current transitions does not change because aToB is
         // still not marked as finished.
-        state.finishTransition(bToC)
+        bToC.finish()
+        bToCJob.join()
         assertThat(state.finishedTransitions).containsExactly(bToC)
         assertThat(state.currentTransitions).containsExactly(aToB, bToC, cToA).inOrder()
 
         // Mark aToB as finished. This will remove both aToB and bToC from the list of transitions.
-        state.finishTransition(aToB)
+        aToB.finish()
+        aToBJob.join()
         assertThat(state.finishedTransitions).isEmpty()
         assertThat(state.currentTransitions).containsExactly(cToA).inOrder()
     }
@@ -617,8 +609,9 @@
         val state = MutableSceneTransitionLayoutStateImpl(SceneA, EmptyTestTransitions)
 
         fun startTransition() {
-            val transition = transition(SceneA, SceneB, onFinish = { launch { /* do nothing */ } })
-            state.startTransition(transition)
+            val transition =
+                transition(SceneA, SceneB, onFreezeAndAnimate = { launch { /* do nothing */ } })
+            state.startTransitionImmediately(animationScope = backgroundScope, transition)
         }
 
         var hasLoggedWtf = false
@@ -650,4 +643,21 @@
         assertThat(state.transitionState).isIdle()
         assertThat(state.transitionState).hasCurrentScene(SceneC)
     }
+
+    @Test
+    fun snapToScene_freezesCurrentTransition() = runMonotonicClockTest {
+        val state = MutableSceneTransitionLayoutStateImpl(SceneA)
+
+        // Start A => B; it is never finished explicitly and only completes once frozen. This
+        // deliberately avoids backgroundScope so the test fails if snapping doesn't freeze it.
+        state.startTransitionImmediately(animationScope = this, transition(SceneA, SceneB))
+        val ongoing = assertThat(state.transitionState).isSceneTransition()
+        assertThat(ongoing).hasFromScene(SceneA)
+        assertThat(ongoing).hasToScene(SceneB)
+
+        // Snap to C: the layout must end up idle on C.
+        state.snapToScene(SceneC)
+        assertThat(state.transitionState).isIdle()
+        assertThat(state.transitionState).hasCurrentScene(SceneC)
+    }
 }
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutTest.kt
index b8e13da..63ab04f 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/SceneTransitionLayoutTest.kt
@@ -55,10 +55,13 @@
 import com.android.compose.animation.scene.TestScenes.SceneC
 import com.android.compose.animation.scene.subjects.assertThat
 import com.android.compose.test.assertSizeIsEqualTo
+import com.android.compose.test.setContentAndCreateMainScope
 import com.android.compose.test.subjects.DpOffsetSubject
 import com.android.compose.test.subjects.assertThat
+import com.android.compose.test.transition
 import com.google.common.truth.Truth.assertThat
 import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.launch
 import org.junit.Assert.assertThrows
 import org.junit.Rule
 import org.junit.Test
@@ -327,17 +330,18 @@
             }
 
         val layoutTag = "layout"
-        rule.setContent {
-            SceneTransitionLayout(state, Modifier.testTag(layoutTag)) {
-                scene(SceneA) { Box(Modifier.size(50.dp)) }
-                scene(SceneB) { Box(Modifier.size(70.dp)) }
+        val scope =
+            rule.setContentAndCreateMainScope {
+                SceneTransitionLayout(state, Modifier.testTag(layoutTag)) {
+                    scene(SceneA) { Box(Modifier.size(50.dp)) }
+                    scene(SceneB) { Box(Modifier.size(70.dp)) }
+                }
             }
-        }
 
         // Overscroll on A at -100%: size should be interpolated given that there is no overscroll
         // defined for scene A.
         var progress by mutableStateOf(-1f)
-        rule.runOnIdle {
+        scope.launch {
             state.startTransition(transition(from = SceneA, to = SceneB, progress = { progress }))
         }
         rule.onNodeWithTag(layoutTag).assertSizeIsEqualTo(30.dp)
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/transformation/AnchoredTranslateTest.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/transformation/AnchoredTranslateTest.kt
index 46075c3..5bfc947 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/transformation/AnchoredTranslateTest.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/transformation/AnchoredTranslateTest.kt
@@ -27,7 +27,6 @@
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import com.android.compose.animation.scene.TestElements
 import com.android.compose.animation.scene.testTransition
-import com.android.compose.animation.scene.transition
 import org.junit.Rule
 import org.junit.Test
 import org.junit.runner.RunWith
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/test/SetContentAndCreateScope.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/test/SetContentAndCreateScope.kt
new file mode 100644
index 0000000..28a864f
--- /dev/null
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/test/SetContentAndCreateScope.kt
@@ -0,0 +1,38 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.compose.test
+
+import androidx.compose.runtime.Composable
+import androidx.compose.runtime.rememberCoroutineScope
+import androidx.compose.ui.test.junit4.ComposeContentTestRule
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.Dispatchers
+
+/**
+ * Sets [content] on this rule and returns a [CoroutineScope] dispatching on [Dispatchers.Main].
+ * The scope is obtained via [rememberCoroutineScope], so it lives as long as the composition.
+ */
+fun ComposeContentTestRule.setContentAndCreateMainScope(
+    content: @Composable () -> Unit,
+): CoroutineScope {
+    lateinit var mainScope: CoroutineScope
+    setContent {
+        mainScope = rememberCoroutineScope(getContext = { Dispatchers.Main })
+        content()
+    }
+    return mainScope
+}
diff --git a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/Transition.kt b/packages/SystemUI/compose/scene/tests/src/com/android/compose/test/TestTransition.kt
similarity index 64%
rename from packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/Transition.kt
rename to packages/SystemUI/compose/scene/tests/src/com/android/compose/test/TestTransition.kt
index 467031a..a6a83ee 100644
--- a/packages/SystemUI/compose/scene/tests/src/com/android/compose/animation/scene/Transition.kt
+++ b/packages/SystemUI/compose/scene/tests/src/com/android/compose/test/TestTransition.kt
@@ -14,17 +14,35 @@
  * limitations under the License.
  */
 
-package com.android.compose.animation.scene
+package com.android.compose.test
 
 import androidx.compose.foundation.gestures.Orientation
+import com.android.compose.animation.scene.ContentKey
+import com.android.compose.animation.scene.SceneKey
+import com.android.compose.animation.scene.SceneTransitionLayoutImpl
 import com.android.compose.animation.scene.content.state.TransitionState
-import kotlinx.coroutines.Job
-import kotlinx.coroutines.launch
-import kotlinx.coroutines.sync.Mutex
-import kotlinx.coroutines.sync.withLock
-import kotlinx.coroutines.test.TestScope
+import com.android.compose.animation.scene.content.state.TransitionState.Transition
+import kotlinx.coroutines.CompletableDeferred
 
-/** A utility to easily create a [TransitionState.Transition] in tests. */
+/** A test transition whose [run] suspends until [finish] is called. */
+abstract class TestTransition(
+    fromScene: SceneKey,
+    toScene: SceneKey,
+    replacedTransition: Transition?,
+) : Transition.ChangeScene(fromScene, toScene, replacedTransition) {
+    private val finishSignal = CompletableDeferred<Unit>()
+
+    override suspend fun run() {
+        finishSignal.await()
+    }
+
+    /** Finish this transition, resuming [run]. Subsequent calls have no effect. */
+    fun finish() {
+        finishSignal.complete(Unit)
+    }
+}
+
+/** A utility to easily create a [TestTransition] in tests. */
 fun transition(
     from: SceneKey,
     to: SceneKey,
@@ -40,12 +58,11 @@
     isUpOrLeft: Boolean = false,
     bouncingContent: ContentKey? = null,
     orientation: Orientation = Orientation.Horizontal,
-    onFinish: ((TransitionState.Transition) -> Job)? = null,
-    replacedTransition: TransitionState.Transition? = null,
-): TransitionState.Transition.ChangeScene {
+    onFreezeAndAnimate: ((TestTransition) -> Unit)? = null,
+    replacedTransition: Transition? = null,
+): TestTransition {
     return object :
-        TransitionState.Transition.ChangeScene(from, to, replacedTransition),
-        TransitionState.HasOverscrollProperties {
+        TestTransition(from, to, replacedTransition), TransitionState.HasOverscrollProperties {
         override val currentScene: SceneKey
             get() = current()
 
@@ -71,14 +88,12 @@
         override val orientation: Orientation = orientation
         override val absoluteDistance = 0f
 
-        override fun finish(): Job {
-            val onFinish =
-                onFinish
-                    ?: error(
-                        "onFinish() must be provided if finish() is called on test transitions"
-                    )
-
-            return onFinish(this)
+        override fun freezeAndAnimateToCurrentState() {
+            // Without a custom callback, freezing simply completes the transition.
+            when (val onFreeze = onFreezeAndAnimate) {
+                null -> finish()
+                else -> onFreeze(this)
+            }
         }
 
         override fun interruptionProgress(layoutImpl: SceneTransitionLayoutImpl): Float {
@@ -86,16 +101,3 @@
         }
     }
 }
-
-/**
- * Return a onFinish lambda that can be used with [transition] so that the transition never
- * finishes. This allows to keep the transition in the current transitions list.
- */
-fun TestScope.neverFinish(): (TransitionState.Transition) -> Job {
-    return {
-        backgroundScope.launch {
-            // Try to acquire a locked mutex so that this code never completes.
-            Mutex(locked = true).withLock {}
-        }
-    }
-}
diff --git a/packages/SystemUI/multivalentTests/src/com/android/systemui/qs/ui/viewmodel/QuickSettingsShadeOverlayActionsViewModelTest.kt b/packages/SystemUI/multivalentTests/src/com/android/systemui/qs/ui/viewmodel/QuickSettingsShadeOverlayActionsViewModelTest.kt
new file mode 100644
index 0000000..fbfefb9
--- /dev/null
+++ b/packages/SystemUI/multivalentTests/src/com/android/systemui/qs/ui/viewmodel/QuickSettingsShadeOverlayActionsViewModelTest.kt
@@ -0,0 +1,83 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.systemui.qs.ui.viewmodel
+
+import android.testing.TestableLooper
+import androidx.test.ext.junit.runners.AndroidJUnit4
+import androidx.test.filters.SmallTest
+import com.android.compose.animation.scene.Back
+import com.android.compose.animation.scene.Swipe
+import com.android.compose.animation.scene.UserActionResult
+import com.android.systemui.SysuiTestCase
+import com.android.systemui.coroutines.collectLastValue
+import com.android.systemui.flags.EnableSceneContainer
+import com.android.systemui.kosmos.testScope
+import com.android.systemui.lifecycle.activateIn
+import com.android.systemui.scene.shared.model.Overlays
+import com.android.systemui.shade.data.repository.fakeShadeRepository
+import com.android.systemui.testKosmos
+import com.google.common.truth.Truth.assertThat
+import kotlinx.coroutines.test.runTest
+import org.junit.Test
+import org.junit.runner.RunWith
+
+@SmallTest
+@RunWith(AndroidJUnit4::class)
+@TestableLooper.RunWithLooper
+@EnableSceneContainer
+class QuickSettingsShadeOverlayActionsViewModelTest : SysuiTestCase() {
+
+    private val kosmos = testKosmos()
+    private val testScope = kosmos.testScope
+    private val fakeShadeRepository by lazy { kosmos.fakeShadeRepository }
+
+    private val underTest by lazy { kosmos.quickSettingsShadeOverlayActionsViewModel }
+
+    @Test
+    fun upTransitionSceneKey_topAligned_hidesShade() =
+        testScope.runTest {
+            val userActions by collectLastValue(underTest.actions)
+            fakeShadeRepository.setDualShadeAlignedToBottom(false)
+            underTest.activateIn(this)
+
+            assertThat((userActions?.get(Swipe.Up) as? UserActionResult.HideOverlay)?.overlay)
+                .isEqualTo(Overlays.QuickSettingsShade)
+            assertThat(userActions?.get(Swipe.Down)).isNull()
+        }
+
+    @Test
+    fun upTransitionSceneKey_bottomAligned_doesNothing() =
+        testScope.runTest {
+            val userActions by collectLastValue(underTest.actions)
+            fakeShadeRepository.setDualShadeAlignedToBottom(true)
+            underTest.activateIn(this)
+
+            assertThat(userActions?.get(Swipe.Up)).isNull()
+            assertThat((userActions?.get(Swipe.Down) as? UserActionResult.HideOverlay)?.overlay)
+                .isEqualTo(Overlays.QuickSettingsShade)
+        }
+
+    @Test
+    fun back_hidesShade() =
+        testScope.runTest {
+            val userActions by collectLastValue(underTest.actions)
+            underTest.activateIn(this)
+
+            assertThat((userActions?.get(Back) as? UserActionResult.HideOverlay)?.overlay)
+                .isEqualTo(Overlays.QuickSettingsShade)
+        }
+}
diff --git a/packages/SystemUI/res/values/strings.xml b/packages/SystemUI/res/values/strings.xml
index 1307301..d3d757b 100644
--- a/packages/SystemUI/res/values/strings.xml
+++ b/packages/SystemUI/res/values/strings.xml
@@ -721,8 +721,8 @@
     <!-- QuickSettings: Do not disturb - Priority only [CHAR LIMIT=NONE] -->
     <!-- QuickSettings: Do not disturb - Alarms only [CHAR LIMIT=NONE] -->
     <!-- QuickSettings: Do not disturb - Total silence [CHAR LIMIT=NONE] -->
-    <!-- QuickSettings: Priority modes [CHAR LIMIT=NONE] -->
-    <string name="quick_settings_modes_label">Priority modes</string>
+    <!-- QuickSettings: Modes [CHAR LIMIT=NONE] -->
+    <string name="quick_settings_modes_label">Modes</string>
     <!-- QuickSettings: Bluetooth [CHAR LIMIT=NONE] -->
     <string name="quick_settings_bluetooth_label">Bluetooth</string>
     <!-- QuickSettings: Bluetooth (Multiple) [CHAR LIMIT=NONE] -->
@@ -1097,28 +1097,28 @@
     <!-- QuickStep: Accessibility to toggle overview [CHAR LIMIT=40] -->
     <string name="quick_step_accessibility_toggle_overview">Toggle Overview</string>
 
-    <!-- Priority modes dialog title [CHAR LIMIT=35] -->
-    <string name="zen_modes_dialog_title">Priority modes</string>
+    <!-- Modes dialog title [CHAR LIMIT=35] -->
+    <string name="zen_modes_dialog_title">Modes</string>
 
-    <!-- Priority modes dialog confirmation button [CHAR LIMIT=15] -->
+    <!-- Modes dialog confirmation button [CHAR LIMIT=15] -->
     <string name="zen_modes_dialog_done">Done</string>
 
-    <!-- Priority modes dialog settings shortcut button [CHAR LIMIT=15] -->
+    <!-- Modes dialog settings shortcut button [CHAR LIMIT=15] -->
     <string name="zen_modes_dialog_settings">Settings</string>
 
-    <!-- Priority modes: label for an active mode [CHAR LIMIT=35] -->
+    <!-- Modes: label for an active mode [CHAR LIMIT=35] -->
     <string name="zen_mode_on">On</string>
 
-    <!-- Priority modes: label for an active mode, with details [CHAR LIMIT=10] -->
+    <!-- Modes: label for an active mode, with details [CHAR LIMIT=10] -->
     <string name="zen_mode_on_with_details">On • <xliff:g id="trigger_description" example="Mon-Fri, 23:00-7:00">%1$s</xliff:g></string>
 
-    <!-- Priority modes: label for an inactive mode [CHAR LIMIT=35] -->
+    <!-- Modes: label for an inactive mode [CHAR LIMIT=35] -->
     <string name="zen_mode_off">Off</string>
 
-    <!-- Priority modes: label for a mode that needs to be set up [CHAR LIMIT=35] -->
+    <!-- Modes: label for a mode that needs to be set up [CHAR LIMIT=35] -->
     <string name="zen_mode_set_up">Set up</string>
 
-    <!-- Priority modes: label for a mode that cannot be manually turned on [CHAR LIMIT=35] -->
+    <!-- Modes: label for a mode that cannot be manually turned on [CHAR LIMIT=35] -->
     <string name="zen_mode_no_manual_invocation">Manage in settings</string>
 
     <string name="zen_mode_active_modes">
diff --git a/packages/SystemUI/src/com/android/systemui/qs/ui/viewmodel/QuickSettingsShadeOverlayActionsViewModel.kt b/packages/SystemUI/src/com/android/systemui/qs/ui/viewmodel/QuickSettingsShadeOverlayActionsViewModel.kt
new file mode 100644
index 0000000..b75f180
--- /dev/null
+++ b/packages/SystemUI/src/com/android/systemui/qs/ui/viewmodel/QuickSettingsShadeOverlayActionsViewModel.kt
@@ -0,0 +1,61 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.systemui.qs.ui.viewmodel
+
+import com.android.compose.animation.scene.Back
+import com.android.compose.animation.scene.Swipe
+import com.android.compose.animation.scene.UserAction
+import com.android.compose.animation.scene.UserActionResult
+import com.android.systemui.scene.shared.model.Overlays
+import com.android.systemui.scene.shared.model.TransitionKeys
+import com.android.systemui.scene.ui.viewmodel.SceneActionsViewModel
+import com.android.systemui.shade.domain.interactor.ShadeInteractor
+import com.android.systemui.shade.shared.model.ShadeAlignment
+import dagger.assisted.AssistedFactory
+import dagger.assisted.AssistedInject
+
+/** Models the user actions for navigating away from the quick settings shade overlay. */
+class QuickSettingsShadeOverlayActionsViewModel
+@AssistedInject
+constructor(
+    private val shadeInteractor: ShadeInteractor,
+) : SceneActionsViewModel() {
+
+    override suspend fun hydrateActions(setActions: (Map<UserAction, UserActionResult>) -> Unit) {
+        val actions = buildMap<UserAction, UserActionResult> {
+            // Dismiss by swiping up when top-aligned, otherwise by swiping down.
+            if (shadeInteractor.shadeAlignment != ShadeAlignment.Top) {
+                put(
+                    Swipe.Down,
+                    UserActionResult.HideOverlay(
+                        overlay = Overlays.QuickSettingsShade,
+                        transitionKey = TransitionKeys.OpenBottomShade,
+                    ),
+                )
+            } else {
+                put(Swipe.Up, UserActionResult.HideOverlay(Overlays.QuickSettingsShade))
+            }
+            put(Back, UserActionResult.HideOverlay(Overlays.QuickSettingsShade))
+        }
+        setActions(actions)
+    }
+
+    @AssistedFactory
+    interface Factory {
+        fun create(): QuickSettingsShadeOverlayActionsViewModel
+    }
+}
diff --git a/packages/SystemUI/src/com/android/systemui/qs/ui/viewmodel/QuickSettingsShadeOverlayContentViewModel.kt b/packages/SystemUI/src/com/android/systemui/qs/ui/viewmodel/QuickSettingsShadeOverlayContentViewModel.kt
new file mode 100644
index 0000000..b8311ce
--- /dev/null
+++ b/packages/SystemUI/src/com/android/systemui/qs/ui/viewmodel/QuickSettingsShadeOverlayContentViewModel.kt
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.systemui.qs.ui.viewmodel
+
+import com.android.systemui.shade.ui.viewmodel.OverlayShadeViewModel
+import com.android.systemui.shade.ui.viewmodel.ShadeHeaderViewModel
+import dagger.assisted.AssistedFactory
+import dagger.assisted.AssistedInject
+
+/**
+ * Models UI state used to render the content of the quick settings shade overlay.
+ *
+ * Different from [QuickSettingsShadeOverlayActionsViewModel], which only models the user actions
+ * that can be performed to navigate away from this overlay.
+ */
+class QuickSettingsShadeOverlayContentViewModel
+@AssistedInject
+constructor(
+    val overlayShadeViewModelFactory: OverlayShadeViewModel.Factory,
+    val shadeHeaderViewModelFactory: ShadeHeaderViewModel.Factory,
+    val quickSettingsContainerViewModel: QuickSettingsContainerViewModel,
+) {
+
+    @AssistedFactory
+    interface Factory {
+        fun create(): QuickSettingsShadeOverlayContentViewModel
+    }
+}
diff --git a/packages/SystemUI/src/com/android/systemui/scene/KeyguardlessSceneContainerFrameworkModule.kt b/packages/SystemUI/src/com/android/systemui/scene/KeyguardlessSceneContainerFrameworkModule.kt
index 98cf941..00944b8 100644
--- a/packages/SystemUI/src/com/android/systemui/scene/KeyguardlessSceneContainerFrameworkModule.kt
+++ b/packages/SystemUI/src/com/android/systemui/scene/KeyguardlessSceneContainerFrameworkModule.kt
@@ -46,6 +46,7 @@
             NotificationsShadeOverlayModule::class,
             NotificationsShadeSceneModule::class,
             NotificationsShadeSessionModule::class,
+            QuickSettingsShadeOverlayModule::class,
             QuickSettingsSceneModule::class,
             ShadeSceneModule::class,
             SceneDomainModule::class,
@@ -104,6 +105,7 @@
                 overlayKeys =
                     listOfNotNull(
                         Overlays.NotificationsShade.takeIf { DualShade.isEnabled },
+                        Overlays.QuickSettingsShade.takeIf { DualShade.isEnabled },
                     ),
                 navigationDistances =
                     mapOf(
diff --git a/packages/SystemUI/src/com/android/systemui/scene/SceneContainerFrameworkModule.kt b/packages/SystemUI/src/com/android/systemui/scene/SceneContainerFrameworkModule.kt
index 8fc896c..4061ad8 100644
--- a/packages/SystemUI/src/com/android/systemui/scene/SceneContainerFrameworkModule.kt
+++ b/packages/SystemUI/src/com/android/systemui/scene/SceneContainerFrameworkModule.kt
@@ -48,6 +48,7 @@
             LockscreenSceneModule::class,
             QuickSettingsSceneModule::class,
             ShadeSceneModule::class,
+            QuickSettingsShadeOverlayModule::class,
             QuickSettingsShadeSceneModule::class,
             NotificationsShadeOverlayModule::class,
             NotificationsShadeSceneModule::class,
@@ -111,6 +112,7 @@
                 overlayKeys =
                     listOfNotNull(
                         Overlays.NotificationsShade.takeIf { DualShade.isEnabled },
+                        Overlays.QuickSettingsShade.takeIf { DualShade.isEnabled },
                     ),
                 navigationDistances =
                     mapOf(
diff --git a/packages/SystemUI/src/com/android/systemui/scene/shared/model/Overlays.kt b/packages/SystemUI/src/com/android/systemui/scene/shared/model/Overlays.kt
index 0bb02e9..c47a850 100644
--- a/packages/SystemUI/src/com/android/systemui/scene/shared/model/Overlays.kt
+++ b/packages/SystemUI/src/com/android/systemui/scene/shared/model/Overlays.kt
@@ -36,4 +36,17 @@
      * side-by-side in their own columns).
      */
     @JvmField val NotificationsShade = OverlayKey("notifications_shade")
+
+    /**
+     * The quick settings shade overlay shows the quick settings tiles UI.
+     *
+     * It's used only in the dual shade configuration, where there are two separate shades: one for
+     * quick settings (this overlay) and another for [NotificationsShade].
+     *
+     * It's not used in the single/accordion configuration (swipe down once to reveal the shade,
+     * swipe down again to expand quick settings) or in the "split" shade configuration (on
+     * large screens or unfolded foldables, where notifications and quick settings are shown
+     * side-by-side in their own columns).
+     */
+    @JvmField val QuickSettingsShade = OverlayKey("quick_settings_shade")
 }
diff --git a/packages/SystemUI/src/com/android/systemui/statusbar/notification/stack/NotificationStackScrollLayout.java b/packages/SystemUI/src/com/android/systemui/statusbar/notification/stack/NotificationStackScrollLayout.java
index 64d7124..925ebf3 100644
--- a/packages/SystemUI/src/com/android/systemui/statusbar/notification/stack/NotificationStackScrollLayout.java
+++ b/packages/SystemUI/src/com/android/systemui/statusbar/notification/stack/NotificationStackScrollLayout.java
@@ -789,7 +789,6 @@
     private void onJustBeforeDraw() {
         if (SceneContainerFlag.isEnabled()) {
             if (mChildrenUpdateRequested) {
-                updateForcedScroll();
                 updateChildren();
                 mChildrenUpdateRequested = false;
             }
@@ -1998,7 +1997,8 @@
     }
 
     public void lockScrollTo(View v) {
-        if (mForcedScroll == v) {
+        // NSSL shouldn't handle scrolling with SceneContainer enabled.
+        if (mForcedScroll == v || SceneContainerFlag.isEnabled()) {
             return;
         }
         mForcedScroll = v;
@@ -2006,6 +2006,10 @@
     }
 
     public boolean scrollTo(View v) {
+        // NSSL shouldn't handle scrolling with SceneContainer enabled.
+        if (SceneContainerFlag.isEnabled()) {
+            return false;
+        }
         ExpandableView expandableView = (ExpandableView) v;
         int positionInLinearLayout = getPositionInLinearLayout(v);
         int targetScroll = targetScrollForView(expandableView, positionInLinearLayout);
@@ -2027,6 +2031,7 @@
      * the IME.
      */
     private int targetScrollForView(ExpandableView v, int positionInLinearLayout) {
+        SceneContainerFlag.assertInLegacyMode();
         return positionInLinearLayout + v.getIntrinsicHeight() +
                 getImeInset() - getHeight()
                 + ((!isExpanded() && isPinnedHeadsUp(v)) ? mHeadsUpInset : getTopPadding());
@@ -4172,6 +4177,11 @@
      */
     @Override
     public boolean performAccessibilityActionInternal(int action, Bundle arguments) {
+        // Don't handle scroll accessibility actions in the NSSL when SceneContainer is enabled.
+        if (SceneContainerFlag.isEnabled()) {
+            return super.performAccessibilityActionInternal(action, arguments);
+        }
+
         if (super.performAccessibilityActionInternal(action, arguments)) {
             return true;
         }
@@ -4933,6 +4943,11 @@
     @Override
     public void onInitializeAccessibilityEventInternal(AccessibilityEvent event) {
         super.onInitializeAccessibilityEventInternal(event);
+        // Don't report scroll accessibility events from the NSSL when SceneContainer is enabled.
+        if (SceneContainerFlag.isEnabled()) {
+            return;
+        }
+
         event.setScrollable(mScrollable);
         event.setMaxScrollX(mScrollX);
         event.setScrollY(mOwnScrollY);
@@ -4942,6 +4957,11 @@
     @Override
     public void onInitializeAccessibilityNodeInfoInternal(AccessibilityNodeInfo info) {
         super.onInitializeAccessibilityNodeInfoInternal(info);
+        // Don't expose scroll accessibility info from the NSSL when SceneContainer is enabled.
+        if (SceneContainerFlag.isEnabled()) {
+            return;
+        }
+
         if (mScrollable) {
             info.setScrollable(true);
             if (mBackwardScrollable) {
diff --git a/packages/SystemUI/tests/utils/src/com/android/systemui/qs/ui/viewmodel/QuickSettingsShadeOverlayActionsViewModelKosmos.kt b/packages/SystemUI/tests/utils/src/com/android/systemui/qs/ui/viewmodel/QuickSettingsShadeOverlayActionsViewModelKosmos.kt
new file mode 100644
index 0000000..41ca2f9
--- /dev/null
+++ b/packages/SystemUI/tests/utils/src/com/android/systemui/qs/ui/viewmodel/QuickSettingsShadeOverlayActionsViewModelKosmos.kt
@@ -0,0 +1,28 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.systemui.qs.ui.viewmodel
+
+import com.android.systemui.kosmos.Kosmos
+import com.android.systemui.kosmos.Kosmos.Fixture
+import com.android.systemui.shade.domain.interactor.shadeInteractor
+
+/** A [QuickSettingsShadeOverlayActionsViewModel] backed by this kosmos' [shadeInteractor]. */
+val Kosmos.quickSettingsShadeOverlayActionsViewModel by Fixture {
+    QuickSettingsShadeOverlayActionsViewModel(
+        shadeInteractor = shadeInteractor,
+    )
+}
diff --git a/packages/SystemUI/tests/utils/src/com/android/systemui/qs/ui/viewmodel/QuickSettingsShadeOverlayContentViewModelKosmos.kt b/packages/SystemUI/tests/utils/src/com/android/systemui/qs/ui/viewmodel/QuickSettingsShadeOverlayContentViewModelKosmos.kt
new file mode 100644
index 0000000..9025c5c
--- /dev/null
+++ b/packages/SystemUI/tests/utils/src/com/android/systemui/qs/ui/viewmodel/QuickSettingsShadeOverlayContentViewModelKosmos.kt
@@ -0,0 +1,30 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.systemui.qs.ui.viewmodel
+
+import com.android.systemui.kosmos.Kosmos
+import com.android.systemui.shade.ui.viewmodel.overlayShadeViewModelFactory
+import com.android.systemui.shade.ui.viewmodel.shadeHeaderViewModelFactory
+
+/** A [QuickSettingsShadeOverlayContentViewModel] built from this kosmos' fixtures. */
+val Kosmos.quickSettingsShadeOverlayContentViewModel by Kosmos.Fixture {
+    QuickSettingsShadeOverlayContentViewModel(
+        overlayShadeViewModelFactory = overlayShadeViewModelFactory,
+        shadeHeaderViewModelFactory = shadeHeaderViewModelFactory,
+        quickSettingsContainerViewModel = quickSettingsContainerViewModel,
+    )
+}
diff --git a/packages/SystemUI/tests/utils/src/com/android/systemui/scene/SceneKosmos.kt b/packages/SystemUI/tests/utils/src/com/android/systemui/scene/SceneKosmos.kt
index 7bc2483..dc45d93 100644
--- a/packages/SystemUI/tests/utils/src/com/android/systemui/scene/SceneKosmos.kt
+++ b/packages/SystemUI/tests/utils/src/com/android/systemui/scene/SceneKosmos.kt
@@ -28,6 +28,7 @@
 var Kosmos.overlayKeys by Fixture {
     listOf(
         Overlays.NotificationsShade,
+        Overlays.QuickSettingsShade,
     )
 }
 
diff --git a/services/core/java/com/android/server/power/Notifier.java b/services/core/java/com/android/server/power/Notifier.java
index 1a2a196..303828f 100644
--- a/services/core/java/com/android/server/power/Notifier.java
+++ b/services/core/java/com/android/server/power/Notifier.java
@@ -1064,9 +1064,9 @@
     private void notifyWakeLockListener(IWakeLockCallback callback, String tag, boolean isEnabled,
             int ownerUid, int ownerPid, int flags, WorkSource workSource, String packageName,
             String historyTag) {
+        long currentTime = mInjector.currentTimeMillis();
         mHandler.post(() -> {
             if (mFlags.improveWakelockLatency()) {
-                long currentTime = mInjector.currentTimeMillis();
                 if (isEnabled) {
                     notifyWakelockAcquisition(tag, ownerUid, ownerPid, flags,
                             workSource, packageName, historyTag, currentTime);
diff --git a/services/core/java/com/android/server/webkit/SystemImpl.java b/services/core/java/com/android/server/webkit/SystemImpl.java
index c4d601d..5e35925 100644
--- a/services/core/java/com/android/server/webkit/SystemImpl.java
+++ b/services/core/java/com/android/server/webkit/SystemImpl.java
@@ -19,14 +19,12 @@
 import static android.webkit.Flags.updateServiceV2;
 
 import android.app.ActivityManager;
-import android.app.AppGlobals;
 import android.content.Context;
 import android.content.pm.ApplicationInfo;
 import android.content.pm.PackageInfo;
 import android.content.pm.PackageInstaller;
 import android.content.pm.PackageManager;
 import android.content.pm.PackageManager.NameNotFoundException;
-import android.content.pm.UserInfo;
 import android.content.res.XmlResourceParser;
 import android.os.Build;
 import android.os.RemoteException;
@@ -79,7 +77,7 @@
         XmlResourceParser parser = null;
         List<WebViewProviderInfo> webViewProviders = new ArrayList<WebViewProviderInfo>();
         try {
-            parser = AppGlobals.getInitialApplication().getResources().getXml(
+            parser = mContext.getResources().getXml(
                     com.android.internal.R.xml.config_webview_packages);
             XmlUtils.beginDocument(parser, TAG_START);
             while(true) {
@@ -148,7 +146,7 @@
     }
 
     public long getFactoryPackageVersion(String packageName) throws NameNotFoundException {
-        PackageManager pm = AppGlobals.getInitialApplication().getPackageManager();
+        PackageManager pm = mContext.getPackageManager();
         return pm.getPackageInfo(packageName, PackageManager.MATCH_FACTORY_ONLY)
                 .getLongVersionCode();
     }
@@ -203,47 +201,48 @@
     @Override
     public void enablePackageForAllUsers(String packageName, boolean enable) {
         UserManager userManager = mContext.getSystemService(UserManager.class);
-        for(UserInfo userInfo : userManager.getUsers()) {
-            enablePackageForUser(packageName, enable, userInfo.id);
+        for (UserHandle user : userManager.getUserHandles(false)) {
+            enablePackageForUser(packageName, enable, user);
         }
     }
 
-    private void enablePackageForUser(String packageName, boolean enable, int userId) {
+    private void enablePackageForUser(String packageName, boolean enable, UserHandle user) {
+        Context contextAsUser = mContext.createContextAsUser(user, 0);
+        PackageManager pm = contextAsUser.getPackageManager();
         try {
-            AppGlobals.getPackageManager().setApplicationEnabledSetting(
+            pm.setApplicationEnabledSetting(
                     packageName,
                     enable ? PackageManager.COMPONENT_ENABLED_STATE_DEFAULT :
-                    PackageManager.COMPONENT_ENABLED_STATE_DISABLED_USER, 0,
-                    userId, null);
-        } catch (RemoteException | IllegalArgumentException e) {
+                    PackageManager.COMPONENT_ENABLED_STATE_DISABLED_USER, 0);
+        } catch (IllegalArgumentException e) {
             Log.w(TAG, "Tried to " + (enable ? "enable " : "disable ") + packageName
-                    + " for user " + userId + ": " + e);
+                    + " for user " + user + ": " + e);
         }
     }
 
     @Override
     public void installExistingPackageForAllUsers(String packageName) {
         UserManager userManager = mContext.getSystemService(UserManager.class);
-        for (UserInfo userInfo : userManager.getUsers()) {
-            installPackageForUser(packageName, userInfo.id);
+        for (UserHandle user : userManager.getUserHandles(false)) {
+            installPackageForUser(packageName, user);
         }
     }
 
-    private void installPackageForUser(String packageName, int userId) {
-        final Context contextAsUser = mContext.createContextAsUser(UserHandle.of(userId), 0);
-        final PackageInstaller installer = contextAsUser.getPackageManager().getPackageInstaller();
+    private void installPackageForUser(String packageName, UserHandle user) {
+        Context contextAsUser = mContext.createContextAsUser(user, 0);
+        PackageInstaller installer = contextAsUser.getPackageManager().getPackageInstaller();
         installer.installExistingPackage(packageName, PackageManager.INSTALL_REASON_UNKNOWN, null);
     }
 
     @Override
     public boolean systemIsDebuggable() {
-        return Build.IS_DEBUGGABLE;
+        return Build.isDebuggable();
     }
 
     @Override
     public PackageInfo getPackageInfoForProvider(WebViewProviderInfo configInfo)
             throws NameNotFoundException {
-        PackageManager pm = AppGlobals.getInitialApplication().getPackageManager();
+        PackageManager pm = mContext.getPackageManager();
         return pm.getPackageInfo(configInfo.packageName, PACKAGE_FLAGS);
     }
 
diff --git a/services/core/java/com/android/server/wm/DisplayContent.java b/services/core/java/com/android/server/wm/DisplayContent.java
index 0597ed7..34bbe6a 100644
--- a/services/core/java/com/android/server/wm/DisplayContent.java
+++ b/services/core/java/com/android/server/wm/DisplayContent.java
@@ -1835,7 +1835,7 @@
         if (mTransitionController.useShellTransitionsRotation()) {
             return ROTATION_UNDEFINED;
         }
-        final int activityOrientation = r.getOverrideOrientation();
+        int activityOrientation = r.getOverrideOrientation();
         if (!WindowManagerService.ENABLE_FIXED_ROTATION_TRANSFORM
                 || shouldIgnoreOrientationRequest(activityOrientation)) {
             return ROTATION_UNDEFINED;
@@ -1846,14 +1846,15 @@
                     r /* boundary */, false /* includeBoundary */, true /* traverseTopToBottom */);
             if (nextCandidate != null) {
                 r = nextCandidate;
+                activityOrientation = r.getOverrideOrientation();
             }
         }
-        if (r.inMultiWindowMode() || r.getRequestedConfigurationOrientation(true /* forDisplay */)
-                == getConfiguration().orientation) {
+        if (r.inMultiWindowMode() || r.getRequestedConfigurationOrientation(true /* forDisplay */,
+                activityOrientation) == getConfiguration().orientation) {
             return ROTATION_UNDEFINED;
         }
         final int currentRotation = getRotation();
-        final int rotation = mDisplayRotation.rotationForOrientation(r.getRequestedOrientation(),
+        final int rotation = mDisplayRotation.rotationForOrientation(activityOrientation,
                 currentRotation);
         if (rotation == currentRotation) {
             return ROTATION_UNDEFINED;
diff --git a/services/core/java/com/android/server/wm/OWNERS b/services/core/java/com/android/server/wm/OWNERS
index 781023c..5d6d8bc 100644
--- a/services/core/java/com/android/server/wm/OWNERS
+++ b/services/core/java/com/android/server/wm/OWNERS
@@ -24,6 +24,7 @@
 per-file Background*Start* = set noparent
 per-file Background*Start* = file:/BAL_OWNERS
 per-file Background*Start* = ogunwale@google.com, louischang@google.com
+per-file BackgroundLaunchProcessController.java = file:/BAL_OWNERS
 
 # File related to activity callers
 per-file ActivityCallerState.java = file:/core/java/android/app/COMPONENT_CALLER_OWNERS
diff --git a/services/core/java/com/android/server/wm/WindowContainer.java b/services/core/java/com/android/server/wm/WindowContainer.java
index 6995027..790ca1b 100644
--- a/services/core/java/com/android/server/wm/WindowContainer.java
+++ b/services/core/java/com/android/server/wm/WindowContainer.java
@@ -1731,13 +1731,13 @@
      *         last time {@link #getOrientation(int) was called.
      */
     @Nullable
-    WindowContainer getLastOrientationSource() {
-        final WindowContainer source = mLastOrientationSource;
-        if (source != null && source != this) {
-            final WindowContainer nextSource = source.getLastOrientationSource();
-            if (nextSource != null) {
-                return nextSource;
-            }
+    final WindowContainer<?> getLastOrientationSource() {
+        if (mLastOrientationSource == null) {
+            return null;
+        }
+        WindowContainer<?> source = this;
+        while (source != source.mLastOrientationSource && source.mLastOrientationSource != null) {
+            source = source.mLastOrientationSource;
         }
         return source;
     }
diff --git a/services/core/java/com/android/server/wm/WindowManagerService.java b/services/core/java/com/android/server/wm/WindowManagerService.java
index e3ceb33..29ab4dd 100644
--- a/services/core/java/com/android/server/wm/WindowManagerService.java
+++ b/services/core/java/com/android/server/wm/WindowManagerService.java
@@ -7920,7 +7920,7 @@
             }
             boolean allWindowsDrawn = false;
             synchronized (mGlobalLock) {
-                if ((displayId == DEFAULT_DISPLAY || displayId == INVALID_DISPLAY)
+                if (displayId == INVALID_DISPLAY
                         && mRoot.getDefaultDisplay().mDisplayUpdater.waitForTransition(message)) {
                     // Use the ready-to-play of transition as the signal.
                     return;
diff --git a/services/tests/wmtests/src/com/android/server/wm/DisplayContentDeferredUpdateTests.java b/services/tests/wmtests/src/com/android/server/wm/DisplayContentDeferredUpdateTests.java
index 1933908..14276ae 100644
--- a/services/tests/wmtests/src/com/android/server/wm/DisplayContentDeferredUpdateTests.java
+++ b/services/tests/wmtests/src/com/android/server/wm/DisplayContentDeferredUpdateTests.java
@@ -16,7 +16,6 @@
 
 package com.android.server.wm;
 
-import static android.view.Display.DEFAULT_DISPLAY;
 import static android.view.Display.INVALID_DISPLAY;
 import static android.view.WindowManager.LayoutParams.TYPE_BASE_APPLICATION;
 
@@ -278,7 +277,7 @@
         mDisplayContent.mDisplayUpdater.onDisplaySwitching(/* switching= */ true);
 
         mWmInternal.waitForAllWindowsDrawn(mScreenUnblocker,
-                /* timeout= */ Integer.MAX_VALUE, DEFAULT_DISPLAY);
+                /* timeout= */ Integer.MAX_VALUE, INVALID_DISPLAY);
         mWmInternal.waitForAllWindowsDrawn(mSecondaryScreenUnblocker,
                 /* timeout= */ Integer.MAX_VALUE, mSecondaryDisplayContent.getDisplayId());
 
@@ -317,50 +316,6 @@
         verify(mScreenUnblocker).sendToTarget();
     }
 
-    @Test
-    public void testWaitForAllWindowsDrawnForInvalidDisplay_usesTransitionToUnblock() {
-        mSetFlagsRule.enableFlags(Flags.FLAG_WAIT_FOR_TRANSITION_ON_DISPLAY_SWITCH);
-
-        final WindowState defaultDisplayWindow = createWindow(/* parent= */ null,
-                TYPE_BASE_APPLICATION, mDisplayContent, "DefaultDisplayWindow");
-        makeWindowVisibleAndNotDrawn(defaultDisplayWindow);
-
-        mDisplayContent.mDisplayUpdater.onDisplaySwitching(/* switching= */ true);
-
-        mWmInternal.waitForAllWindowsDrawn(mScreenUnblocker,
-                /* timeout= */ Integer.MAX_VALUE, INVALID_DISPLAY);
-
-        // Perform display update
-        mUniqueId = "new_default_display_unique_id";
-        mDisplayContent.requestDisplayUpdate(mock(Runnable.class));
-
-        when(mDisplayContent.mTransitionController.inTransition()).thenReturn(true);
-
-        // Notify that transition started collecting
-        captureStartTransitionCollection().getAllValues().forEach((callback) ->
-                callback.onCollectStarted(/* deferred= */ true));
-
-        // Verify that screen is not unblocked yet
-        verify(mScreenUnblocker, never()).sendToTarget();
-
-        // Make all display windows drawn
-        defaultDisplayWindow.mWinAnimator.mDrawState = HAS_DRAWN;
-        mWm.mRoot.performSurfacePlacement();
-
-        // Verify that default display is still not unblocked yet
-        // (so it doesn't use old windows drawn path)
-        verify(mScreenUnblocker, never()).sendToTarget();
-
-        // Mark start transaction as presented
-        when(mDisplayContent.mTransitionController.inTransition()).thenReturn(false);
-        captureRequestedTransition().getAllValues().forEach(
-                this::makeTransitionTransactionCompleted);
-
-        // Verify that the default screen unblocker is sent only after start transaction
-        // of the Shell transition is presented
-        verify(mScreenUnblocker).sendToTarget();
-    }
-
     private void prepareSecondaryDisplay() {
         mSecondaryDisplayContent = createNewDisplay();
         when(mSecondaryScreenUnblocker.getTarget()).thenReturn(mWm.mH);
diff --git a/services/tests/wmtests/src/com/android/server/wm/DisplayContentTests.java b/services/tests/wmtests/src/com/android/server/wm/DisplayContentTests.java
index f2ea1c9..eca4d21 100644
--- a/services/tests/wmtests/src/com/android/server/wm/DisplayContentTests.java
+++ b/services/tests/wmtests/src/com/android/server/wm/DisplayContentTests.java
@@ -1144,6 +1144,7 @@
 
     @Test
     public void testOrientationBehind() {
+        assertNull(mDisplayContent.getLastOrientationSource());
         final ActivityRecord prev = new ActivityBuilder(mAtm).setCreateTask(true)
                 .setScreenOrientation(getRotatedOrientation(mDisplayContent)).build();
         prev.setVisibleRequested(false);