Merge "Introduce ActivitySnapshotController(3/N)" into udc-dev
diff --git a/core/res/res/values/config.xml b/core/res/res/values/config.xml
index ab5fca2..19fa7a3 100644
--- a/core/res/res/values/config.xml
+++ b/core/res/res/values/config.xml
@@ -3253,6 +3253,9 @@
     <!-- Feature flag to store TaskSnapshot in 16 bit pixel format to save memory. -->
     <bool name="config_use16BitTaskSnapshotPixelFormat">false</bool>
 
+    <!-- The amount to scale fullscreen activity snapshots for the predictive back animation. -->
+    <item name="config_resActivitySnapshotScale" format="float" type="dimen">0.6</item>
+
     <!-- Determines whether recent tasks are provided to the user. Default device has recents
          property. If this is false, then the following recents config flags are ignored. -->
     <bool name="config_hasRecents">true</bool>
diff --git a/core/res/res/values/symbols.xml b/core/res/res/values/symbols.xml
index 60795ad..251db6d 100644
--- a/core/res/res/values/symbols.xml
+++ b/core/res/res/values/symbols.xml
@@ -362,6 +362,7 @@
   <java-symbol type="bool" name="config_disableUsbPermissionDialogs"/>
   <java-symbol type="dimen" name="config_highResTaskSnapshotScale" />
   <java-symbol type="dimen" name="config_lowResTaskSnapshotScale" />
+  <java-symbol type="dimen" name="config_resActivitySnapshotScale" />
   <java-symbol type="dimen" name="config_qsTileStrokeWidthInactive" />
   <java-symbol type="dimen" name="config_qsTileStrokeWidthActive" />
   <java-symbol type="bool" name="config_use16BitTaskSnapshotPixelFormat" />
diff --git a/services/core/java/com/android/server/wm/AbsAppSnapshotController.java b/services/core/java/com/android/server/wm/AbsAppSnapshotController.java
index f215495..83804f7 100644
--- a/services/core/java/com/android/server/wm/AbsAppSnapshotController.java
+++ b/services/core/java/com/android/server/wm/AbsAppSnapshotController.java
@@ -55,8 +55,8 @@
  * @param <CACHE> The basic cache for either Task or ActivityRecord
  */
 abstract class AbsAppSnapshotController<TYPE extends WindowContainer,
-        CACHE extends AbsAppSnapshotCache<TYPE>> {
-    private static final String TAG = TAG_WITH_CLASS_NAME ? "SnapshotController" : TAG_WM;
+        CACHE extends SnapshotCache<TYPE>> {
+    static final String TAG = TAG_WITH_CLASS_NAME ? "SnapshotController" : TAG_WM;
     /**
      * Return value for {@link #getSnapshotMode}: We are allowed to take a real screenshot to be
      * used as the snapshot.
@@ -76,7 +76,7 @@
     static final int SNAPSHOT_MODE_NONE = 2;
 
     protected final WindowManagerService mService;
-    protected final float mHighResTaskSnapshotScale;
+    protected final float mHighResSnapshotScale;
     private final Rect mTmpRect = new Rect();
     /**
      * Flag indicating whether we are running on an Android TV device.
@@ -99,12 +99,13 @@
                 PackageManager.FEATURE_LEANBACK);
         mIsRunningOnIoT = mService.mContext.getPackageManager().hasSystemFeature(
                 PackageManager.FEATURE_EMBEDDED);
-        mHighResTaskSnapshotScale = initSnapshotScale();
+        mHighResSnapshotScale = initSnapshotScale();
     }
 
     protected float initSnapshotScale() {
-        return mService.mContext.getResources().getFloat(
+        final float config = mService.mContext.getResources().getFloat(
                 com.android.internal.R.dimen.config_highResTaskSnapshotScale);
+        return Math.max(Math.min(config, 1f), 0.1f); // Clamp the configured scale to [0.1, 1].
     }
 
     /**
@@ -173,7 +174,7 @@
         final HardwareBuffer buffer = snapshot.getHardwareBuffer();
         if (buffer.getWidth() == 0 || buffer.getHeight() == 0) {
             buffer.close();
-            Slog.e(TAG, "Invalid task snapshot dimensions " + buffer.getWidth() + "x"
+            Slog.e(TAG, "Invalid snapshot dimensions " + buffer.getWidth() + "x"
                     + buffer.getHeight());
             return null;
         } else {
@@ -223,7 +224,7 @@
         Point taskSize = new Point();
         Trace.traceBegin(Trace.TRACE_TAG_WINDOW_MANAGER, "createSnapshot");
         final ScreenCapture.ScreenshotHardwareBuffer taskSnapshot = createSnapshot(source,
-                mHighResTaskSnapshotScale, builder.getPixelFormat(), taskSize, builder);
+                mHighResSnapshotScale, builder.getPixelFormat(), taskSize, builder);
         Trace.traceEnd(Trace.TRACE_TAG_WINDOW_MANAGER);
         builder.setTaskSize(taskSize);
         return taskSnapshot;
@@ -397,11 +398,11 @@
         final SnapshotDrawerUtils.SystemBarBackgroundPainter
                 decorPainter = new SnapshotDrawerUtils.SystemBarBackgroundPainter(attrs.flags,
                 attrs.privateFlags, attrs.insetsFlags.appearance, taskDescription,
-                mHighResTaskSnapshotScale, mainWindow.getRequestedVisibleTypes());
+                mHighResSnapshotScale, mainWindow.getRequestedVisibleTypes());
         final int taskWidth = taskBounds.width();
         final int taskHeight = taskBounds.height();
-        final int width = (int) (taskWidth * mHighResTaskSnapshotScale);
-        final int height = (int) (taskHeight * mHighResTaskSnapshotScale);
+        final int width = (int) (taskWidth * mHighResSnapshotScale);
+        final int height = (int) (taskHeight * mHighResSnapshotScale);
         final RenderNode node = RenderNode.create("SnapshotController", null);
         node.setLeftTopRightBottom(0, 0, width, height);
         node.setClipToBounds(false);
@@ -450,9 +451,28 @@
         return 0;
     }
 
+    /** Notifies the snapshot cache that {@code activity} has been removed. */
+    void onAppRemoved(ActivityRecord activity) {
+        mCache.onAppRemoved(activity);
+    }
+
+    /** Notifies the snapshot cache that the process of {@code activity} has died. */
+    void onAppDied(ActivityRecord activity) {
+        mCache.onAppDied(activity);
+    }
+
+    /** Whether {@code task} is animating by recents or is part of a recents transition. */
+    boolean isAnimatingByRecents(@NonNull Task task) {
+        if (task.isAnimatingByRecents()) {
+            return true;
+        }
+        // Otherwise ask the transition controller whether a recents transition includes it.
+        return mService.mAtmService.getTransitionController().inRecentsTransition(task);
+    }
+
     void dump(PrintWriter pw, String prefix) {
-        pw.println(prefix + "mHighResTaskSnapshotScale=" + mHighResTaskSnapshotScale);
-        pw.println(prefix + "mTaskSnapshotEnabled=" + mSnapshotEnabled);
+        pw.println(prefix + "mHighResSnapshotScale=" + mHighResSnapshotScale);
+        pw.println(prefix + "mSnapshotEnabled=" + mSnapshotEnabled);
         mCache.dump(pw, prefix);
     }
 }
diff --git a/services/core/java/com/android/server/wm/ActivityRecord.java b/services/core/java/com/android/server/wm/ActivityRecord.java
index 1e88ead..d08293b 100644
--- a/services/core/java/com/android/server/wm/ActivityRecord.java
+++ b/services/core/java/com/android/server/wm/ActivityRecord.java
@@ -4223,7 +4223,8 @@
 
         getDisplayContent().mOpeningApps.remove(this);
         getDisplayContent().mUnknownAppVisibilityController.appRemovedOrHidden(this);
-        mWmService.mTaskSnapshotController.onAppRemoved(this);
+        mWmService.mSnapshotController.onAppRemoved(this);
+
         mTaskSupervisor.getActivityMetricsLogger().notifyActivityRemoved(this);
         mTaskSupervisor.mStoppingActivities.remove(this);
         waitingToShow = false;
@@ -5557,7 +5558,7 @@
                 && !fromTransition) {
             // Take the screenshot before possibly hiding the WSA, otherwise the screenshot
             // will not be taken.
-            mWmService.mTaskSnapshotController.notifyAppVisibilityChanged(this, visible);
+            mWmService.mSnapshotController.notifyAppVisibilityChanged(this, visible);
         }
 
         // If we are hidden but there is no delay needed we immediately
@@ -10465,6 +10466,18 @@
                 && !inPinnedWindowingMode() && !inFreeformWindowingMode();
     }
 
+    /** Returns whether this activity is currently in a state where a snapshot can be taken. */
+    boolean canCaptureSnapshot() {
+        // Without a showing surface and a main window there is nothing to capture.
+        if (isSurfaceShowing() && findMainWindow() != null) {
+            // Require at least one window whose WSA surface is shown with an alpha greater than
+            // 0 before attempting a screenshot; windows are visited from top to bottom.
+            return forAllWindows(ws -> ws.mWinAnimator != null && ws.mWinAnimator.getShown()
+                    && ws.mWinAnimator.mLastAlpha > 0f, true /* traverseTopToBottom */);
+        }
+        return false;
+    }
+
     void overrideCustomTransition(boolean open, int enterAnim, int exitAnim, int backgroundColor) {
         CustomAppTransition transition = getCustomAnimation(open);
         if (transition == null) {
diff --git a/services/core/java/com/android/server/wm/ActivitySnapshotCache.java b/services/core/java/com/android/server/wm/ActivitySnapshotCache.java
new file mode 100644
index 0000000..a54dd82
--- /dev/null
+++ b/services/core/java/com/android/server/wm/ActivitySnapshotCache.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.wm;
+
+import android.window.TaskSnapshot;
+
+/**
+ * A snapshot cache for activities, keyed by the {@link System#identityHashCode} of the activity.
+ */
+class ActivitySnapshotCache extends SnapshotCache<ActivityRecord> {
+    ActivitySnapshotCache(WindowManagerService service) {
+        super(service, "Activity");
+    }
+
+    /** Caches {@code snapshot} for {@code ar}, evicting any previous owner of the same key. */
+    @Override
+    void putSnapshot(ActivityRecord ar, TaskSnapshot snapshot) {
+        final int key = System.identityHashCode(ar);
+        final CacheEntry previous = mRunningCache.get(key);
+        if (previous != null) {
+            mAppIdMap.remove(previous.topApp);
+        }
+        mAppIdMap.put(ar, key);
+        mRunningCache.put(key, new CacheEntry(snapshot, ar));
+    }
+}
diff --git a/services/core/java/com/android/server/wm/ActivitySnapshotController.java b/services/core/java/com/android/server/wm/ActivitySnapshotController.java
new file mode 100644
index 0000000..90a4820
--- /dev/null
+++ b/services/core/java/com/android/server/wm/ActivitySnapshotController.java
@@ -0,0 +1,505 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.wm;
+
+import static com.android.server.wm.SnapshotController.ACTIVITY_CLOSE;
+import static com.android.server.wm.SnapshotController.ACTIVITY_OPEN;
+import static com.android.server.wm.SnapshotController.TASK_CLOSE;
+import static com.android.server.wm.SnapshotController.TASK_OPEN;
+
+import android.annotation.NonNull;
+import android.annotation.Nullable;
+import android.app.ActivityManager;
+import android.os.Environment;
+import android.os.SystemProperties;
+import android.util.ArraySet;
+import android.util.Slog;
+import android.util.SparseArray;
+import android.window.TaskSnapshot;
+
+import com.android.internal.annotations.VisibleForTesting;
+import com.android.server.LocalServices;
+import com.android.server.pm.UserManagerInternal;
+import com.android.server.wm.BaseAppSnapshotPersister.PersistInfoProvider;
+import com.android.server.wm.SnapshotController.TransitionState;
+
+import java.io.File;
+import java.util.ArrayList;
+
+/**
+ * When an app token becomes invisible, we take a snapshot (bitmap) and put it into our cache.
+ * Internally we use gralloc buffers to be able to draw them wherever we like without any copying.
+ * <p>
+ * System applications may retrieve a snapshot to represent the current state of an activity, and
+ * draw them in their own process.
+ * <p>
+ * Unlike TaskSnapshotController, we only keep one activity snapshot for a visible task in the
+ * cache. Which should largely reduce the memory usage.
+ * <p>
+ * To access this class, acquire the global window manager lock.
+ */
+class ActivitySnapshotController extends AbsAppSnapshotController<ActivityRecord,
+        ActivitySnapshotCache> {
+    private static final boolean DEBUG = false;
+    private static final String TAG = AbsAppSnapshotController.TAG;
+    // Maximum persisted snapshot count on disk.
+    private static final int MAX_PERSIST_SNAPSHOT_COUNT = 20;
+
+    static final String SNAPSHOTS_DIRNAME = "activity_snapshots";
+
+    /**
+     * The pending activities which should capture snapshot when process transition finish.
+     */
+    @VisibleForTesting
+    final ArraySet<ActivityRecord> mPendingCaptureActivity = new ArraySet<>();
+
+    /**
+     * The pending activities which should remove snapshot from memory when process transition
+     * finish.
+     */
+    @VisibleForTesting
+    final ArraySet<ActivityRecord> mPendingRemoveActivity = new ArraySet<>();
+
+    /**
+     * The pending activities which should delete snapshot files when process transition finish.
+     */
+    @VisibleForTesting
+    final ArraySet<ActivityRecord> mPendingDeleteActivity = new ArraySet<>();
+
+    /**
+     * The pending activities which should load snapshot from disk when process transition finish.
+     */
+    @VisibleForTesting
+    final ArraySet<ActivityRecord> mPendingLoadActivity = new ArraySet<>();
+
+    private final SnapshotPersistQueue mSnapshotPersistQueue;
+    private final PersistInfoProvider mPersistInfoProvider;
+    private final AppSnapshotLoader mSnapshotLoader;
+
+    /**
+     * File information holders, to make the sequence align, always update status of
+     * mUserSavedFiles/mSavedFilesInOrder before persist file from mPersister.
+     */
+    private final SparseArray<SparseArray<UserSavedFile>> mUserSavedFiles = new SparseArray<>();
+    // Keep sorted with create timeline.
+    private final ArrayList<UserSavedFile> mSavedFilesInOrder = new ArrayList<>();
+    private final TaskSnapshotPersister mPersister;
+
+    ActivitySnapshotController(WindowManagerService service, SnapshotPersistQueue persistQueue) {
+        super(service);
+        mSnapshotPersistQueue = persistQueue;
+        mPersistInfoProvider = createPersistInfoProvider(service,
+                Environment::getDataSystemCeDirectory);
+        mPersister = new TaskSnapshotPersister(persistQueue, mPersistInfoProvider);
+        mSnapshotLoader = new AppSnapshotLoader(mPersistInfoProvider);
+        initialize(new ActivitySnapshotCache(service));
+
+        final boolean snapshotEnabled =
+                !service.mContext
+                        .getResources()
+                        .getBoolean(com.android.internal.R.bool.config_disableTaskSnapshots)
+                && isSnapshotEnabled()
+                && !ActivityManager.isLowRamDeviceStatic(); // Don't support Android Go
+        setSnapshotEnabled(snapshotEnabled);
+    }
+
+    void systemReady() {
+        if (shouldDisableSnapshots()) {
+            return;
+        }
+        mService.mSnapshotController.registerTransitionStateConsumer(
+                ACTIVITY_OPEN, this::handleOpenActivityTransition);
+        mService.mSnapshotController.registerTransitionStateConsumer(
+                ACTIVITY_CLOSE, this::handleCloseActivityTransition);
+        mService.mSnapshotController.registerTransitionStateConsumer(
+                TASK_OPEN, this::handleOpenTaskTransition);
+        mService.mSnapshotController.registerTransitionStateConsumer(
+                TASK_CLOSE, this::handleCloseTaskTransition);
+    }
+
+    @Override
+    protected float initSnapshotScale() {
+        final float config = mService.mContext.getResources().getFloat(
+                com.android.internal.R.dimen.config_resActivitySnapshotScale);
+        return Math.max(Math.min(config, 1f), 0.1f);
+    }
+
+    // TODO remove when enabled
+    static boolean isSnapshotEnabled() {
+        return SystemProperties.getInt("persist.wm.debug.activity_screenshot", 0) != 0;
+    }
+
+    static PersistInfoProvider createPersistInfoProvider(
+            WindowManagerService service, BaseAppSnapshotPersister.DirectoryResolver resolver) {
+        // Don't persist reduced file, instead we only persist the "HighRes" bitmap which has
+        // already scaled with #initSnapshotScale
+        final boolean use16BitFormat = service.mContext.getResources().getBoolean(
+                com.android.internal.R.bool.config_use16BitTaskSnapshotPixelFormat);
+        return new PersistInfoProvider(resolver, SNAPSHOTS_DIRNAME,
+                false /* enableLowResSnapshots */, 0 /* lowResScaleFactor */, use16BitFormat);
+    }
+
+    /** Retrieves a snapshot for an activity from cache. */
+    @Nullable
+    TaskSnapshot getSnapshot(ActivityRecord ar) {
+        final int code = getSystemHashCode(ar);
+        return mCache.getSnapshot(code);
+    }
+
+    private void cleanUpUserFiles(int userId) {
+        synchronized (mSnapshotPersistQueue.getLock()) {
+            mSnapshotPersistQueue.sendToQueueLocked(
+                    new SnapshotPersistQueue.WriteQueueItem(mPersistInfoProvider) {
+                        @Override
+                        boolean isReady() {
+                            final UserManagerInternal mUserManagerInternal =
+                                    LocalServices.getService(UserManagerInternal.class);
+                            return mUserManagerInternal.isUserUnlocked(userId);
+                        }
+
+                        @Override
+                        void write() {
+                            final File file = mPersistInfoProvider.getDirectory(userId);
+                            if (file.exists()) {
+                                final File[] contents = file.listFiles();
+                                if (contents != null) {
+                                    for (int i = contents.length - 1; i >= 0; i--) {
+                                        contents[i].delete();
+                                    }
+                                }
+                            }
+                        }
+                    });
+        }
+    }
+
+    /**
+     * Prepare to handle on transition start. Clear all temporary fields.
+     */
+    void preTransitionStart() {
+        resetTmpFields();
+    }
+
+    /**
+     * on transition start has notified, start process data.
+     */
+    void postTransitionStart() {
+        if (shouldDisableSnapshots()) {
+            return;
+        }
+        onCommitTransition();
+    }
+
+    @VisibleForTesting
+    void resetTmpFields() {
+        mPendingCaptureActivity.clear();
+        mPendingRemoveActivity.clear();
+        mPendingDeleteActivity.clear();
+        mPendingLoadActivity.clear();
+    }
+
+    /**
+     * Start process all pending activities for a transition.
+     */
+    private void onCommitTransition() {
+        if (DEBUG) {
+            Slog.d(TAG, "ActivitySnapshotController#onCommitTransition result:"
+                    + " capture " + mPendingCaptureActivity
+                    + " remove " + mPendingRemoveActivity
+                    + " delete " + mPendingDeleteActivity
+                    + " load " + mPendingLoadActivity);
+        }
+        // task snapshots
+        for (int i = mPendingCaptureActivity.size() - 1; i >= 0; i--) {
+            recordSnapshot(mPendingCaptureActivity.valueAt(i));
+        }
+        // clear mTmpRemoveActivity from cache
+        for (int i = mPendingRemoveActivity.size() - 1; i >= 0; i--) {
+            final ActivityRecord ar = mPendingRemoveActivity.valueAt(i);
+            final int code = getSystemHashCode(ar);
+            mCache.onIdRemoved(code);
+        }
+        // clear snapshot on cache and delete files
+        for (int i = mPendingDeleteActivity.size() - 1; i >= 0; i--) {
+            final ActivityRecord ar = mPendingDeleteActivity.valueAt(i);
+            final int code = getSystemHashCode(ar);
+            mCache.onIdRemoved(code);
+            removeIfUserSavedFileExist(code, ar.mUserId);
+        }
+        // load snapshot to cache
+        for (int i = mPendingLoadActivity.size() - 1; i >= 0; i--) {
+            final ActivityRecord ar = mPendingLoadActivity.valueAt(i);
+            final int code = getSystemHashCode(ar);
+            final int userId = ar.mUserId;
+            if (mCache.getSnapshot(code) != null) {
+                // already in cache, skip
+                continue;
+            }
+            if (containsFile(code, userId)) {
+                synchronized (mSnapshotPersistQueue.getLock()) {
+                    mSnapshotPersistQueue.sendToQueueLocked(
+                            new SnapshotPersistQueue.WriteQueueItem(mPersistInfoProvider) {
+                                @Override
+                                void write() {
+                                    final TaskSnapshot snapshot = mSnapshotLoader.loadTask(code,
+                                            userId, false /* loadLowResolutionBitmap */);
+                                    synchronized (mService.getWindowManagerLock()) {
+                                        if (snapshot != null && !ar.finishing) {
+                                            mCache.putSnapshot(ar, snapshot);
+                                        }
+                                    }
+                                }
+                            });
+                }
+            }
+        }
+        // don't keep any reference
+        resetTmpFields();
+    }
+
+    private void recordSnapshot(ActivityRecord activity) {
+        final TaskSnapshot snapshot = recordSnapshotInner(activity, false /* allowSnapshotHome */);
+        if (snapshot != null) {
+            final int code = getSystemHashCode(activity);
+            addUserSavedFile(code, activity.mUserId, snapshot);
+        }
+    }
+
+    /**
+     * Called when the visibility of an app changes outside the regular app transition flow.
+     */
+    void notifyAppVisibilityChanged(ActivityRecord appWindowToken, boolean visible) {
+        if (!visible) {
+            resetTmpFields();
+            addBelowTopActivityIfExist(appWindowToken.getTask(), mPendingRemoveActivity,
+                    "remove-snapshot");
+            onCommitTransition();
+        }
+    }
+
+    private static int getSystemHashCode(ActivityRecord activity) {
+        return System.identityHashCode(activity);
+    }
+
+    void handleOpenActivityTransition(TransitionState<ActivityRecord> transitionState) {
+        ArraySet<ActivityRecord> participant = transitionState.getParticipant(false /* open */);
+        for (ActivityRecord ar : participant) {
+            mPendingCaptureActivity.add(ar);
+            // remove the snapshot for the one below close
+            final ActivityRecord below = ar.getTask().getActivityBelow(ar);
+            if (below != null) {
+                mPendingRemoveActivity.add(below);
+            }
+        }
+    }
+
+    void handleCloseActivityTransition(TransitionState<ActivityRecord> transitionState) {
+        ArraySet<ActivityRecord> participant = transitionState.getParticipant(true /* open */);
+        for (ActivityRecord ar : participant) {
+            mPendingDeleteActivity.add(ar);
+            // load next one if exists.
+            final ActivityRecord below = ar.getTask().getActivityBelow(ar);
+            if (below != null) {
+                mPendingLoadActivity.add(below);
+            }
+        }
+    }
+
+    void handleCloseTaskTransition(TransitionState<Task> closeTaskTransitionRecord) {
+        ArraySet<Task> participant = closeTaskTransitionRecord.getParticipant(false /* open */);
+        for (Task close : participant) {
+            // this is close task transition
+            // remove the N - 1 from cache
+            addBelowTopActivityIfExist(close, mPendingRemoveActivity, "remove-snapshot");
+        }
+    }
+
+    void handleOpenTaskTransition(TransitionState<Task> openTaskTransitionRecord) {
+        ArraySet<Task> participant = openTaskTransitionRecord.getParticipant(true /* open */);
+        for (Task open : participant) {
+            // this is close task transition
+            // remove the N - 1 from cache
+            addBelowTopActivityIfExist(open, mPendingLoadActivity, "load-snapshot");
+            // Move the activities to top of mSavedFilesInOrder, so when purge happen, there
+            // will trim the persisted files from the most non-accessed.
+            adjustSavedFileOrder(open);
+        }
+    }
+
+    // Add the top -1 activity to a set if it exists.
+    private void addBelowTopActivityIfExist(Task task, ArraySet<ActivityRecord> set,
+            String debugMessage) {
+        final ActivityRecord topActivity = task.getTopMostActivity();
+        if (topActivity != null) {
+            final ActivityRecord below = task.getActivityBelow(topActivity);
+            if (below != null) {
+                set.add(below);
+                if (DEBUG) {
+                    Slog.d(TAG, "ActivitySnapshotController#addBelowTopActivityIfExist "
+                            + below + " from " + debugMessage);
+                }
+            }
+        }
+    }
+
+    private void adjustSavedFileOrder(Task nextTopTask) {
+        final int userId = nextTopTask.mUserId;
+        nextTopTask.forAllActivities(ar -> {
+            final int code = getSystemHashCode(ar);
+            final UserSavedFile usf = getUserFiles(userId).get(code);
+            if (usf != null) {
+                mSavedFilesInOrder.remove(usf);
+                mSavedFilesInOrder.add(usf);
+            }
+        }, false /* traverseTopToBottom */);
+    }
+
+    @Override
+    void onAppRemoved(ActivityRecord activity) {
+        super.onAppRemoved(activity);
+        final int code = getSystemHashCode(activity);
+        removeIfUserSavedFileExist(code, activity.mUserId);
+        if (DEBUG) {
+            Slog.d(TAG, "ActivitySnapshotController#onAppRemoved delete snapshot " + activity);
+        }
+    }
+
+    @Override
+    void onAppDied(ActivityRecord activity) {
+        super.onAppDied(activity);
+        final int code = getSystemHashCode(activity);
+        removeIfUserSavedFileExist(code, activity.mUserId);
+        if (DEBUG) {
+            Slog.d(TAG, "ActivitySnapshotController#onAppDied delete snapshot " + activity);
+        }
+    }
+
+    @Override
+    ActivityRecord getTopActivity(ActivityRecord activity) {
+        return activity;
+    }
+
+    @Override
+    ActivityRecord getTopFullscreenActivity(ActivityRecord activity) {
+        final WindowState win = activity.findMainWindow();
+        return (win != null && win.mAttrs.isFullscreen()) ? activity : null;
+    }
+
+    @Override
+    ActivityManager.TaskDescription getTaskDescription(ActivityRecord object) {
+        return object.taskDescription;
+    }
+
+    /**
+     * Find the window for a given activity to take a snapshot. During app transitions, trampoline
+     * activities can appear in the children, but should be ignored.
+     */
+    @Override
+    protected ActivityRecord findAppTokenForSnapshot(ActivityRecord activity) {
+        if (activity == null) {
+            return null;
+        }
+        return activity.canCaptureSnapshot() ? activity : null;
+    }
+
+    @Override
+    protected boolean use16BitFormat() {
+        return mPersistInfoProvider.use16BitFormat();
+    }
+
+    @NonNull
+    private SparseArray<UserSavedFile> getUserFiles(int userId) {
+        if (mUserSavedFiles.get(userId) == null) {
+            mUserSavedFiles.put(userId, new SparseArray<>());
+            // This is the first time this user attempt to access snapshot, clear up the disk.
+            cleanUpUserFiles(userId);
+        }
+        return mUserSavedFiles.get(userId);
+    }
+
+    private void removeIfUserSavedFileExist(int code, int userId) {
+        final UserSavedFile usf = getUserFiles(userId).get(code);
+        if (usf != null) {
+            mUserSavedFiles.remove(code);
+            mSavedFilesInOrder.remove(usf);
+            mPersister.removeSnap(code, userId);
+        }
+    }
+
+    private boolean containsFile(int code, int userId) {
+        return getUserFiles(userId).get(code) != null;
+    }
+
+    private void addUserSavedFile(int code, int userId, TaskSnapshot snapshot) {
+        final SparseArray<UserSavedFile> savedFiles = getUserFiles(userId);
+        final UserSavedFile savedFile = savedFiles.get(code);
+        if (savedFile == null) {
+            final UserSavedFile usf = new UserSavedFile(code, userId);
+            savedFiles.put(code, usf);
+            mSavedFilesInOrder.add(usf);
+            mPersister.persistSnapshot(code, userId, snapshot);
+
+            if (mSavedFilesInOrder.size() > MAX_PERSIST_SNAPSHOT_COUNT * 2) {
+                purgeSavedFile();
+            }
+        }
+    }
+
+    private void purgeSavedFile() {
+        final int savedFileCount = mSavedFilesInOrder.size();
+        final int removeCount = savedFileCount - MAX_PERSIST_SNAPSHOT_COUNT;
+        final ArrayList<UserSavedFile> usfs = new ArrayList<>();
+        if (removeCount > 0) {
+            final int removeTillIndex = savedFileCount - removeCount;
+            for (int i = savedFileCount - 1; i > removeTillIndex; --i) {
+                final UserSavedFile usf = mSavedFilesInOrder.remove(i);
+                if (usf != null) {
+                    mUserSavedFiles.remove(usf.mFileId);
+                    usfs.add(usf);
+                }
+            }
+        }
+        if (usfs.size() > 0) {
+            removeSnapshotFiles(usfs);
+        }
+    }
+
+    private void removeSnapshotFiles(ArrayList<UserSavedFile> files) {
+        synchronized (mSnapshotPersistQueue.getLock()) {
+            mSnapshotPersistQueue.sendToQueueLocked(
+                    new SnapshotPersistQueue.WriteQueueItem(mPersistInfoProvider) {
+                        @Override
+                        void write() {
+                            for (int i = files.size() - 1; i >= 0; --i) {
+                                final UserSavedFile usf = files.get(i);
+                                mSnapshotPersistQueue.deleteSnapshot(
+                                        usf.mFileId, usf.mUserId, mPersistInfoProvider);
+                            }
+                        }
+                    });
+        }
+    }
+
+    static class UserSavedFile {
+        int mFileId;
+        int mUserId;
+        UserSavedFile(int fileId, int userId) {
+            mFileId = fileId;
+            mUserId = userId;
+        }
+    }
+}
diff --git a/services/core/java/com/android/server/wm/AppTransitionController.java b/services/core/java/com/android/server/wm/AppTransitionController.java
index 4e94f96..841d28b 100644
--- a/services/core/java/com/android/server/wm/AppTransitionController.java
+++ b/services/core/java/com/android/server/wm/AppTransitionController.java
@@ -321,7 +321,7 @@
             mService.mSurfaceAnimationRunner.continueStartingAnimations();
         }
 
-        mService.mTaskSnapshotController.onTransitionStarting(mDisplayContent);
+        mService.mSnapshotController.onTransitionStarting(mDisplayContent);
 
         mDisplayContent.mOpeningApps.clear();
         mDisplayContent.mClosingApps.clear();
diff --git a/services/core/java/com/android/server/wm/AbsAppSnapshotCache.java b/services/core/java/com/android/server/wm/SnapshotCache.java
similarity index 96%
rename from services/core/java/com/android/server/wm/AbsAppSnapshotCache.java
rename to services/core/java/com/android/server/wm/SnapshotCache.java
index c8adc8f..401b260 100644
--- a/services/core/java/com/android/server/wm/AbsAppSnapshotCache.java
+++ b/services/core/java/com/android/server/wm/SnapshotCache.java
@@ -25,13 +25,13 @@
  * Base class for an app snapshot cache
  * @param <TYPE> The basic type, either Task or ActivityRecord
  */
-abstract class AbsAppSnapshotCache<TYPE extends WindowContainer> {
+abstract class SnapshotCache<TYPE extends WindowContainer> {
     protected final WindowManagerService mService;
     protected final String mName;
     protected final ArrayMap<ActivityRecord, Integer> mAppIdMap = new ArrayMap<>();
     protected final ArrayMap<Integer, CacheEntry> mRunningCache = new ArrayMap<>();
 
-    AbsAppSnapshotCache(WindowManagerService service, String name) {
+    SnapshotCache(WindowManagerService service, String name) {
         mService = service;
         mName = name;
     }
diff --git a/services/core/java/com/android/server/wm/SnapshotController.java b/services/core/java/com/android/server/wm/SnapshotController.java
new file mode 100644
index 0000000..cd1263e
--- /dev/null
+++ b/services/core/java/com/android/server/wm/SnapshotController.java
@@ -0,0 +1,337 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.wm;
+
+import static android.view.WindowManager.TRANSIT_CLOSE;
+import static android.view.WindowManager.TRANSIT_FIRST_CUSTOM;
+import static android.view.WindowManager.TRANSIT_OPEN;
+import static android.view.WindowManager.TRANSIT_TO_BACK;
+import static android.view.WindowManager.TRANSIT_TO_FRONT;
+
+import android.annotation.IntDef;
+import android.util.ArraySet;
+import android.util.Slog;
+import android.util.SparseArray;
+import android.view.WindowManager;
+
+import com.android.internal.annotations.VisibleForTesting;
+
+import java.io.PrintWriter;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.util.ArrayList;
+import java.util.function.Consumer;
+
+/**
+ * Integrates common functionality from TaskSnapshotController and ActivitySnapshotController.
+ *
+ * <p>This class owns the {@link SnapshotPersistQueue} shared by both controllers. For each app
+ * transition it sorts the participants into {@link TransitionStateType} buckets and hands each
+ * bucket's {@link TransitionState} to the consumers registered for that type.
+ */
+class SnapshotController {
+    private static final boolean DEBUG = false;
+    private static final String TAG = AbsAppSnapshotController.TAG;
+
+    // Transition types are bit flags so they can be OR-ed together into mActivatedType.
+    /** Within one task, the opening activity has a higher child index than the closing one. */
+    static final int ACTIVITY_OPEN = 1;
+    /** Within one task, the opening activity has a lower child index than the closing one. */
+    static final int ACTIVITY_CLOSE = 2;
+    /** An opening task that has no closing participant in the same transition. */
+    static final int TASK_OPEN = 4;
+    /** A closing task that has no opening participant in the same transition. */
+    static final int TASK_CLOSE = 8;
+    /** @hide */
+    @Retention(RetentionPolicy.SOURCE)
+    @IntDef(
+            value = {ACTIVITY_OPEN,
+                    ACTIVITY_CLOSE,
+                    TASK_OPEN,
+                    TASK_CLOSE})
+    @interface TransitionStateType {}
+
+    private final SnapshotPersistQueue mSnapshotPersistQueue;
+    final TaskSnapshotController mTaskSnapshotController;
+    final ActivitySnapshotController mActivitySnapshotController;
+
+    // Scratch sets reused by analysisTransition(); cleared at the end of every analysis.
+    private final ArraySet<Task> mTmpCloseTasks = new ArraySet<>();
+    private final ArraySet<Task> mTmpOpenTasks = new ArraySet<>();
+
+    // Participants collected per transition type for the transition being handled, plus the
+    // set of types that received at least one record. Both are reset by clearRecord().
+    private final SparseArray<TransitionState> mTmpOpenCloseRecord = new SparseArray<>();
+    private final ArraySet<Integer> mTmpAnalysisRecord = new ArraySet<>();
+    // Registered consumers, keyed by transition type.
+    private final SparseArray<ArrayList<Consumer<TransitionState>>> mTransitionStateConsumer =
+            new SparseArray<>();
+    // Bitmask of the transition types that currently have at least one registered consumer.
+    private int mActivatedType;
+
+    private final ActivityOrderCheck mActivityOrderCheck = new ActivityOrderCheck();
+    private final ActivityOrderCheck.AnalysisResult mResultHandler = (type, close, open) -> {
+        // ActivityOrderCheck runs when either ACTIVITY_OPEN or ACTIVITY_CLOSE has a consumer,
+        // so it may report a type nobody listens to; recording that type would make
+        // handleAppTransition() notify a type with no consumer list.
+        if (!hasTransitionStateConsumer(type)) {
+            return;
+        }
+        addTransitionRecord(type, true/* open */, open);
+        addTransitionRecord(type, false/* open */, close);
+    };
+
+    /**
+     * Classifies a task that has both closing and opening activities as ACTIVITY_OPEN or
+     * ACTIVITY_CLOSE by comparing the indices, in {@code task.mChildren}, of the first closing
+     * and first opening activity of that task found when scanning each set from its end.
+     */
+    private static class ActivityOrderCheck {
+        private ActivityRecord mOpenActivity;
+        private ActivityRecord mCloseActivity;
+        private int mOpenIndex = -1;
+        private int mCloseIndex = -1;
+
+        private void reset() {
+            mOpenActivity = null;
+            mCloseActivity = null;
+            mOpenIndex = -1;
+            mCloseIndex = -1;
+        }
+
+        private void setTarget(boolean open, ActivityRecord ar, int index) {
+            if (open) {
+                mOpenActivity = ar;
+                mOpenIndex = index;
+            } else {
+                mCloseActivity = ar;
+                mCloseIndex = index;
+            }
+        }
+
+        void analysisOrder(ArraySet<ActivityRecord> closeApps,
+                ArraySet<ActivityRecord> openApps, Task task, AnalysisResult result) {
+            for (int j = closeApps.size() - 1; j >= 0; j--) {
+                final ActivityRecord ar = closeApps.valueAt(j);
+                if (ar.getTask() == task) {
+                    setTarget(false, ar, task.mChildren.indexOf(ar));
+                    break;
+                }
+            }
+            for (int j = openApps.size() - 1; j >= 0; j--) {
+                final ActivityRecord ar = openApps.valueAt(j);
+                if (ar.getTask() == task) {
+                    setTarget(true, ar, task.mChildren.indexOf(ar));
+                    break;
+                }
+            }
+            // An index of -1 means no activity of that kind was found directly in the task;
+            // nothing is reported in that case.
+            if (mOpenIndex > mCloseIndex && mCloseIndex != -1) {
+                result.onCheckResult(ACTIVITY_OPEN, mCloseActivity, mOpenActivity);
+            } else if (mOpenIndex < mCloseIndex && mOpenIndex != -1) {
+                result.onCheckResult(ACTIVITY_CLOSE, mCloseActivity, mOpenActivity);
+            }
+            reset();
+        }
+        private interface AnalysisResult {
+            void onCheckResult(@TransitionStateType int type,
+                    ActivityRecord close, ActivityRecord open);
+        }
+    }
+
+    /** Adds {@code target} to the open or close participants recorded for {@code type}. */
+    private void addTransitionRecord(int type, boolean open, WindowContainer target) {
+        TransitionState record = mTmpOpenCloseRecord.get(type);
+        if (record == null) {
+            record = new TransitionState();
+            mTmpOpenCloseRecord.set(type, record);
+        }
+        record.addParticipant(target, open);
+        mTmpAnalysisRecord.add(type);
+    }
+
+    /** Drops every record collected for the transition that was just handled. */
+    private void clearRecord() {
+        mTmpOpenCloseRecord.clear();
+        mTmpAnalysisRecord.clear();
+    }
+
+    /** Participants of one transition type, split into opening and closing sets. */
+    static class TransitionState<TYPE extends WindowContainer> {
+        private final ArraySet<TYPE> mOpenParticipant = new ArraySet<>();
+        private final ArraySet<TYPE> mCloseParticipant = new ArraySet<>();
+
+        void addParticipant(TYPE target, boolean open) {
+            final ArraySet<TYPE> participant = open
+                    ? mOpenParticipant : mCloseParticipant;
+            participant.add(target);
+        }
+
+        ArraySet<TYPE> getParticipant(boolean open) {
+            return open ? mOpenParticipant : mCloseParticipant;
+        }
+    }
+
+    SnapshotController(WindowManagerService wms) {
+        mSnapshotPersistQueue = new SnapshotPersistQueue();
+        mTaskSnapshotController = new TaskSnapshotController(wms, mSnapshotPersistQueue);
+        mActivitySnapshotController = new ActivitySnapshotController(wms, mSnapshotPersistQueue);
+    }
+
+    /**
+     * Registers {@code consumer} to receive the {@link TransitionState} of every transition
+     * classified as {@code type}. Registering the same consumer twice has no further effect.
+     */
+    void registerTransitionStateConsumer(@TransitionStateType int type,
+            Consumer<TransitionState> consumer) {
+        ArrayList<Consumer<TransitionState>> consumers = mTransitionStateConsumer.get(type);
+        if (consumers == null) {
+            consumers = new ArrayList<>();
+            mTransitionStateConsumer.set(type, consumers);
+        }
+        if (!consumers.contains(consumer)) {
+            consumers.add(consumer);
+        }
+        mActivatedType |= type;
+    }
+
+    /**
+     * Removes {@code consumer} for {@code type}. Once no consumer remains for the type, the type
+     * is no longer analyzed.
+     */
+    void unregisterTransitionStateConsumer(@TransitionStateType int type,
+            Consumer<TransitionState> consumer) {
+        final ArrayList<Consumer<TransitionState>> consumers = mTransitionStateConsumer.get(type);
+        if (consumers == null) {
+            return;
+        }
+        consumers.remove(consumer);
+        if (consumers.size() == 0) {
+            mActivatedType &= ~type;
+        }
+    }
+
+    private boolean hasTransitionStateConsumer(@TransitionStateType int type) {
+        return (mActivatedType & type) != 0;
+    }
+
+    void systemReady() {
+        mSnapshotPersistQueue.systemReady();
+        mTaskSnapshotController.systemReady();
+        mActivitySnapshotController.systemReady();
+    }
+
+    void setPause(boolean paused) {
+        mSnapshotPersistQueue.setPaused(paused);
+    }
+
+    void onAppRemoved(ActivityRecord activity) {
+        mTaskSnapshotController.onAppRemoved(activity);
+        mActivitySnapshotController.onAppRemoved(activity);
+    }
+
+    void onAppDied(ActivityRecord activity) {
+        mTaskSnapshotController.onAppDied(activity);
+        mActivitySnapshotController.onAppDied(activity);
+    }
+
+    /**
+     * Treats an activity becoming invisible as a TASK_CLOSE of its task and immediately notifies
+     * the TASK_CLOSE consumers, if any are registered.
+     */
+    void notifyAppVisibilityChanged(ActivityRecord appWindowToken, boolean visible) {
+        if (visible || !hasTransitionStateConsumer(TASK_CLOSE)) {
+            return;
+        }
+        final Task task = appWindowToken.getTask();
+        if (task == null) {
+            // Like the transition paths below, ignore activities without a task instead of
+            // handing a null participant to the TASK_CLOSE consumers.
+            return;
+        }
+        // close task transition
+        addTransitionRecord(TASK_CLOSE, false /*open*/, task);
+        mActivitySnapshotController.preTransitionStart();
+        notifyTransition(TASK_CLOSE);
+        mActivitySnapshotController.postTransitionStart();
+        clearRecord();
+    }
+
+    // For legacy transition
+    void onTransitionStarting(DisplayContent displayContent) {
+        handleAppTransition(displayContent.mClosingApps, displayContent.mOpeningApps);
+    }
+
+    // For shell transition, adapt to legacy transition.
+    void onTransitionReady(@WindowManager.TransitionType int type,
+            ArraySet<WindowContainer> participants) {
+        final boolean isTransitionOpen = isTransitionOpen(type);
+        final boolean isTransitionClose = isTransitionClose(type);
+        // Skip when nothing is registered, or for built-in types that are neither open nor
+        // close. Custom types (>= TRANSIT_FIRST_CUSTOM) are always analyzed.
+        if (mActivatedType == 0
+                || (!isTransitionOpen && !isTransitionClose && type < TRANSIT_FIRST_CUSTOM)) {
+            return;
+        }
+        final ArraySet<ActivityRecord> openingApps = new ArraySet<>();
+        final ArraySet<ActivityRecord> closingApps = new ArraySet<>();
+
+        for (int i = participants.size() - 1; i >= 0; --i) {
+            final ActivityRecord ar = participants.valueAt(i).asActivityRecord();
+            if (ar == null || ar.getTask() == null) continue;
+            if (ar.isVisibleRequested()) {
+                openingApps.add(ar);
+            } else {
+                closingApps.add(ar);
+            }
+        }
+        handleAppTransition(closingApps, openingApps);
+    }
+
+    private static boolean isTransitionOpen(int type) {
+        return type == TRANSIT_OPEN || type == TRANSIT_TO_FRONT;
+    }
+    private static boolean isTransitionClose(int type) {
+        return type == TRANSIT_CLOSE || type == TRANSIT_TO_BACK;
+    }
+
+    /**
+     * Classifies the given closing/opening activities and notifies the consumers of every
+     * transition type that received a record, bracketed by the activity snapshot controller's
+     * pre/post transition callbacks.
+     */
+    @VisibleForTesting
+    void handleAppTransition(ArraySet<ActivityRecord> closingApps,
+            ArraySet<ActivityRecord> openApps) {
+        if (mActivatedType == 0) {
+            return;
+        }
+        analysisTransition(closingApps, openApps);
+        mActivitySnapshotController.preTransitionStart();
+        for (Integer transitionType : mTmpAnalysisRecord) {
+            notifyTransition(transitionType);
+        }
+        mActivitySnapshotController.postTransitionStart();
+        clearRecord();
+    }
+
+    /** Delivers the record collected for {@code transitionType} to its registered consumers. */
+    private void notifyTransition(int transitionType) {
+        final TransitionState record = mTmpOpenCloseRecord.get(transitionType);
+        final ArrayList<Consumer<TransitionState>> consumers =
+                mTransitionStateConsumer.get(transitionType);
+        // SparseArray#get returns null for a type that was never registered.
+        if (record == null || consumers == null) {
+            return;
+        }
+        for (Consumer<TransitionState> consumer : consumers) {
+            consumer.accept(record);
+        }
+    }
+
+    /**
+     * Fills the per-type records from the participants. A task with both closing and opening
+     * activities is checked for ACTIVITY_OPEN/ACTIVITY_CLOSE. Otherwise a closing task becomes
+     * TASK_CLOSE and an opening task becomes TASK_OPEN, when those types have consumers.
+     */
+    private void analysisTransition(ArraySet<ActivityRecord> closingApps,
+            ArraySet<ActivityRecord> openingApps) {
+        getParticipantTasks(closingApps, mTmpCloseTasks, false /* isOpen */);
+        getParticipantTasks(openingApps, mTmpOpenTasks, true /* isOpen */);
+        if (DEBUG) {
+            Slog.d(TAG, "AppSnapshotController#analysisTransition participants"
+                    + " mTmpCloseTasks " + mTmpCloseTasks
+                    + " mTmpOpenTasks " + mTmpOpenTasks);
+        }
+        for (int i = mTmpCloseTasks.size() - 1; i >= 0; i--) {
+            final Task closeTask = mTmpCloseTasks.valueAt(i);
+            if (mTmpOpenTasks.contains(closeTask)) {
+                if (hasTransitionStateConsumer(ACTIVITY_OPEN)
+                        || hasTransitionStateConsumer(ACTIVITY_CLOSE)) {
+                    mActivityOrderCheck.analysisOrder(closingApps, openingApps, closeTask,
+                            mResultHandler);
+                }
+            } else if (hasTransitionStateConsumer(TASK_CLOSE)) {
+                // close task transition
+                addTransitionRecord(TASK_CLOSE, false /*open*/, closeTask);
+            }
+        }
+        if (hasTransitionStateConsumer(TASK_OPEN)) {
+            for (int i = mTmpOpenTasks.size() - 1; i >= 0; i--) {
+                final Task openTask = mTmpOpenTasks.valueAt(i);
+                if (!mTmpCloseTasks.contains(openTask)) {
+                    // this is open task transition
+                    addTransitionRecord(TASK_OPEN, true /*open*/, openTask);
+                }
+            }
+        }
+        mTmpCloseTasks.clear();
+        mTmpOpenTasks.clear();
+    }
+
+    /**
+     * Adds to {@code tasks} the task of every activity whose requested visibility matches
+     * {@code isOpen}; activities without a task are skipped.
+     */
+    private void getParticipantTasks(ArraySet<ActivityRecord> activityRecords, ArraySet<Task> tasks,
+            boolean isOpen) {
+        for (int i = activityRecords.size() - 1; i >= 0; i--) {
+            final ActivityRecord activity = activityRecords.valueAt(i);
+            final Task task = activity.getTask();
+            if (task == null) continue;
+
+            if (isOpen == activity.isVisibleRequested()) {
+                tasks.add(task);
+            }
+        }
+    }
+
+    void dump(PrintWriter pw, String prefix) {
+        mTaskSnapshotController.dump(pw, prefix);
+        mActivitySnapshotController.dump(pw, prefix);
+    }
+}
diff --git a/services/core/java/com/android/server/wm/SnapshotPersistQueue.java b/services/core/java/com/android/server/wm/SnapshotPersistQueue.java
index fdc3616..afef85e 100644
--- a/services/core/java/com/android/server/wm/SnapshotPersistQueue.java
+++ b/services/core/java/com/android/server/wm/SnapshotPersistQueue.java
@@ -129,7 +129,7 @@
         }
     }
 
-    private void deleteSnapshot(int index, int userId, PersistInfoProvider provider) {
+    void deleteSnapshot(int index, int userId, PersistInfoProvider provider) {
         final File protoFile = provider.getProtoFile(index, userId);
         final File bitmapLowResFile = provider.getLowResolutionBitmapFile(index, userId);
         protoFile.delete();
diff --git a/services/core/java/com/android/server/wm/TaskSnapshotCache.java b/services/core/java/com/android/server/wm/TaskSnapshotCache.java
index 55e863e..33486cc 100644
--- a/services/core/java/com/android/server/wm/TaskSnapshotCache.java
+++ b/services/core/java/com/android/server/wm/TaskSnapshotCache.java
@@ -24,7 +24,7 @@
  * <p>
  * Access to this class should be guarded by the global window manager lock.
  */
-class TaskSnapshotCache extends AbsAppSnapshotCache<Task> {
+class TaskSnapshotCache extends SnapshotCache<Task> {
 
     private final AppSnapshotLoader mLoader;
 
diff --git a/services/core/java/com/android/server/wm/TaskSnapshotController.java b/services/core/java/com/android/server/wm/TaskSnapshotController.java
index 679f0f5..4d0bff9 100644
--- a/services/core/java/com/android/server/wm/TaskSnapshotController.java
+++ b/services/core/java/com/android/server/wm/TaskSnapshotController.java
@@ -16,6 +16,7 @@
 
 package com.android.server.wm;
 
+import static com.android.server.wm.SnapshotController.TASK_CLOSE;
 import static com.android.server.wm.WindowManagerDebugConfig.DEBUG_SCREENSHOT;
 import static com.android.server.wm.WindowManagerDebugConfig.TAG_WM;
 
@@ -77,6 +78,13 @@
         setSnapshotEnabled(snapshotEnabled);
     }
 
+    /**
+     * Unless snapshots are disabled, subscribes {@link #handleTaskClose} to TASK_CLOSE
+     * transitions reported by the shared SnapshotController.
+     */
+    void systemReady() {
+        if (shouldDisableSnapshots()) {
+            return;
+        }
+        mService.mSnapshotController.registerTransitionStateConsumer(
+                TASK_CLOSE, this::handleTaskClose);
+    }
+
     static PersistInfoProvider createPersistInfoProvider(WindowManagerService service,
             BaseAppSnapshotPersister.DirectoryResolver resolver) {
         final float highResTaskSnapshotScale = service.mContext.getResources().getFloat(
@@ -109,8 +117,21 @@
                 enableLowResSnapshots, lowResScaleFactor, use16BitFormat);
     }
 
-    void onTransitionStarting(DisplayContent displayContent) {
-        handleClosingApps(displayContent.mClosingApps);
+    /**
+     * TASK_CLOSE consumer: snapshots the tasks reported as closing. With shell transitions every
+     * reported task is taken as-is. Otherwise each task is filtered through
+     * {@link #getClosingTasksInner}. The recents skip-set is cleared afterwards.
+     */
+    void handleTaskClose(SnapshotController.TransitionState<Task> closeTaskTransitionRecord) {
+        if (shouldDisableSnapshots()) {
+            return;
+        }
+        mTmpTasks.clear();
+        final ArraySet<Task> closingTasks =
+                closeTaskTransitionRecord.getParticipant(false /* open */);
+        final boolean shellTransitionsEnabled =
+                mService.mAtmService.getTransitionController().isShellTransitionsEnabled();
+        if (!shellTransitionsEnabled) {
+            for (int i = 0; i < closingTasks.size(); i++) {
+                getClosingTasksInner(closingTasks.valueAt(i), mTmpTasks);
+            }
+        } else {
+            mTmpTasks.addAll(closingTasks);
+        }
+        snapshotTasks(mTmpTasks);
+        mSkipClosingAppSnapshotTasks.clear();
     }
 
     /**
@@ -189,18 +210,7 @@
      * children, which should be ignored.
      */
     @Nullable protected ActivityRecord findAppTokenForSnapshot(Task task) {
-        return task.getActivity((r) -> {
-            if (r == null || !r.isSurfaceShowing() || r.findMainWindow() == null) {
-                return false;
-            }
-            return r.forAllWindows(
-                    // Ensure at least one window for the top app is visible before attempting to
-                    // take a screenshot. Visible here means that the WSA surface is shown and has
-                    // an alpha greater than 0.
-                    ws -> ws.mWinAnimator != null && ws.mWinAnimator.getShown()
-                            && ws.mWinAnimator.mLastAlpha > 0f, true  /* traverseTopToBottom */);
-
-        });
+        return task.getActivity(ActivityRecord::canCaptureSnapshot);
     }
 
 
@@ -272,32 +282,22 @@
             final Task task = activity.getTask();
             if (task == null) continue;
 
-            // Since RecentsAnimation will handle task snapshot while switching apps with the
-            // best capture timing (e.g. IME window capture),
-            // No need additional task capture while task is controlled by RecentsAnimation.
-            if (isAnimatingByRecents(task)) {
-                mSkipClosingAppSnapshotTasks.add(task);
-            }
-            // If the task of the app is not visible anymore, it means no other app in that task
-            // is opening. Thus, the task is closing.
-            if (!task.isVisible() && !mSkipClosingAppSnapshotTasks.contains(task)) {
-                outClosingTasks.add(task);
-            }
+            getClosingTasksInner(task, outClosingTasks);
         }
     }
 
-    /**
-     * Called when an {@link ActivityRecord} has been removed.
-     */
-    void onAppRemoved(ActivityRecord activity) {
-        mCache.onAppRemoved(activity);
-    }
-
-    /**
-     * Called when the process of an {@link ActivityRecord} has died.
-     */
-    void onAppDied(ActivityRecord activity) {
-        mCache.onAppDied(activity);
+    /**
+     * Adds {@code task} to {@code outClosingTasks} when it should be snapshotted as closing.
+     * A task animated by recents is first recorded in {@code mSkipClosingAppSnapshotTasks}.
+     * Any task in that skip-set, or any task that is still visible, is not added.
+     */
+    void getClosingTasksInner(Task task, ArraySet<Task> outClosingTasks) {
+        // RecentsAnimation already takes the task snapshot at the best capture timing
+        // (e.g. IME window capture), so no additional capture is needed while the task is
+        // controlled by RecentsAnimation.
+        if (isAnimatingByRecents(task)) {
+            mSkipClosingAppSnapshotTasks.add(task);
+        }
+        // A task that is still visible presumably has another app in it that is opening, so
+        // the task itself is not closing.
+        if (task.isVisible()) {
+            return;
+        }
+        if (!mSkipClosingAppSnapshotTasks.contains(task)) {
+            outClosingTasks.add(task);
+        }
     }
 
     void notifyTaskRemovedFromRecents(int taskId, int userId) {
@@ -361,9 +361,4 @@
                 && mService.mPolicy.isKeyguardSecure(mService.mCurrentUserId);
         snapshotTasks(mTmpTasks, allowSnapshotHome);
     }
-
-    private boolean isAnimatingByRecents(@NonNull Task task) {
-        return task.isAnimatingByRecents()
-                || mService.mAtmService.getTransitionController().inRecentsTransition(task);
-    }
 }
diff --git a/services/core/java/com/android/server/wm/Transition.java b/services/core/java/com/android/server/wm/Transition.java
index 873a83d..362e1c8 100644
--- a/services/core/java/com/android/server/wm/Transition.java
+++ b/services/core/java/com/android/server/wm/Transition.java
@@ -879,8 +879,10 @@
                                 && mTransientLaunches != null) {
                             // If transition is transient, then snapshots are taken at end of
                             // transition.
-                            mController.mTaskSnapshotController.recordSnapshot(
-                                    task, false /* allowSnapshotHome */);
+                            mController.mSnapshotController.mTaskSnapshotController
+                                    .recordSnapshot(task, false /* allowSnapshotHome */);
+                            mController.mSnapshotController.mActivitySnapshotController
+                                    .notifyAppVisibilityChanged(ar, false /* visible */);
                         }
                         ar.commitVisibility(false /* visible */, false /* performLayout */,
                                 true /* fromTransition */);
@@ -1225,13 +1227,7 @@
         // transferred. If transition is transient, IME won't be moved during the transition and
         // the tasks are still live, so we take the snapshot at the end of the transition instead.
         if (mTransientLaunches == null) {
-            for (int i = mParticipants.size() - 1; i >= 0; --i) {
-                final ActivityRecord ar = mParticipants.valueAt(i).asActivityRecord();
-                if (ar == null || ar.isVisibleRequested() || ar.getTask() == null
-                        || ar.getTask().isVisibleRequested()) continue;
-                mController.mTaskSnapshotController.recordSnapshot(
-                        ar.getTask(), false /* allowSnapshotHome */);
-            }
+            mController.mSnapshotController.onTransitionReady(mType, mParticipants);
         }
 
         // This is non-null only if display has changes. It handles the visible windows that don't
diff --git a/services/core/java/com/android/server/wm/TransitionController.java b/services/core/java/com/android/server/wm/TransitionController.java
index c74f167..f314b21 100644
--- a/services/core/java/com/android/server/wm/TransitionController.java
+++ b/services/core/java/com/android/server/wm/TransitionController.java
@@ -88,8 +88,9 @@
 
     private WindowProcessController mTransitionPlayerProc;
     final ActivityTaskManagerService mAtm;
+
     final RemotePlayer mRemotePlayer;
-    TaskSnapshotController mTaskSnapshotController;
+    SnapshotController mSnapshotController;
     TransitionTracer mTransitionTracer;
 
     private final ArrayList<WindowManagerInternal.AppTransitionListener> mLegacyListeners =
@@ -153,7 +154,7 @@
     }
 
     void setWindowManager(WindowManagerService wms) {
-        mTaskSnapshotController = wms.mTaskSnapshotController;
+        mSnapshotController = wms.mSnapshotController;
         mTransitionTracer = wms.mTransitionTracer;
         mIsWaitingForDisplayEnabled = !wms.mDisplayEnabled;
         registerLegacyListener(wms.mActivityManagerAppTransitionNotifier);
@@ -739,12 +740,12 @@
             t.setEarlyWakeupStart();
             // Usually transitions put quite a load onto the system already (with all the things
             // happening in app), so pause task snapshot persisting to not increase the load.
-            mAtm.mWindowManager.mSnapshotPersistQueue.setPaused(true);
+            mAtm.mWindowManager.mSnapshotController.setPause(true);
             mAnimatingState = true;
             Trace.asyncTraceBegin(Trace.TRACE_TAG_WINDOW_MANAGER, "transitAnim", 0);
         } else if (!animatingState && mAnimatingState) {
             t.setEarlyWakeupEnd();
-            mAtm.mWindowManager.mSnapshotPersistQueue.setPaused(false);
+            mAtm.mWindowManager.mSnapshotController.setPause(false);
             mAnimatingState = false;
             Trace.asyncTraceEnd(Trace.TRACE_TAG_WINDOW_MANAGER, "transitAnim", 0);
         }
diff --git a/services/core/java/com/android/server/wm/WindowAnimator.java b/services/core/java/com/android/server/wm/WindowAnimator.java
index 2596533..10bedd4 100644
--- a/services/core/java/com/android/server/wm/WindowAnimator.java
+++ b/services/core/java/com/android/server/wm/WindowAnimator.java
@@ -200,11 +200,11 @@
                                 | ANIMATION_TYPE_RECENTS /* typesToCheck */);
         if (runningExpensiveAnimations && !mRunningExpensiveAnimations) {
             // Usually app transitions put quite a load onto the system already (with all the things
-            // happening in app), so pause task snapshot persisting to not increase the load.
-            mService.mSnapshotPersistQueue.setPaused(true);
+            // happening in app), so pause snapshot persisting to not increase the load.
+            mService.mSnapshotController.setPause(true);
             mTransaction.setEarlyWakeupStart();
         } else if (!runningExpensiveAnimations && mRunningExpensiveAnimations) {
-            mService.mSnapshotPersistQueue.setPaused(false);
+            mService.mSnapshotController.setPause(false);
             mTransaction.setEarlyWakeupEnd();
         }
         mRunningExpensiveAnimations = runningExpensiveAnimations;
diff --git a/services/core/java/com/android/server/wm/WindowManagerService.java b/services/core/java/com/android/server/wm/WindowManagerService.java
index f7641f5..918729d 100644
--- a/services/core/java/com/android/server/wm/WindowManagerService.java
+++ b/services/core/java/com/android/server/wm/WindowManagerService.java
@@ -689,8 +689,8 @@
     // changes the orientation.
     private final PowerManager.WakeLock mScreenFrozenLock;
 
-    final SnapshotPersistQueue mSnapshotPersistQueue;
     final TaskSnapshotController mTaskSnapshotController;
+    final SnapshotController mSnapshotController;
 
     final BlurController mBlurController;
     final TaskFpsCallbackController mTaskFpsCallbackController;
@@ -1200,8 +1200,8 @@
         mSyncEngine = new BLASTSyncEngine(this);
 
         mWindowPlacerLocked = new WindowSurfacePlacer(this);
-        mSnapshotPersistQueue = new SnapshotPersistQueue();
-        mTaskSnapshotController = new TaskSnapshotController(this, mSnapshotPersistQueue);
+        mSnapshotController = new SnapshotController(this);
+        mTaskSnapshotController = mSnapshotController.mTaskSnapshotController;
 
         mWindowTracing = WindowTracing.createDefaultAndStartLooper(this,
                 Choreographer.getInstance());
@@ -5141,7 +5141,7 @@
         mSystemReady = true;
         mPolicy.systemReady();
         mRoot.forAllDisplayPolicies(DisplayPolicy::systemReady);
-        mSnapshotPersistQueue.systemReady();
+        mSnapshotController.systemReady();
         mHasWideColorGamutSupport = queryWideColorGamutSupport();
         mHasHdrSupport = queryHdrSupport();
         UiThread.getHandler().post(mSettingsObserver::loadSettings);
@@ -6685,7 +6685,7 @@
                 pw.println();
 
         mInputManagerCallback.dump(pw, "  ");
-        mTaskSnapshotController.dump(pw, "  ");
+        mSnapshotController.dump(pw, " ");
         if (mAccessibilityController.hasCallbacks()) {
             mAccessibilityController.dump(pw, "  ");
         }
diff --git a/services/core/java/com/android/server/wm/WindowState.java b/services/core/java/com/android/server/wm/WindowState.java
index 8a083aa..d1bd06f 100644
--- a/services/core/java/com/android/server/wm/WindowState.java
+++ b/services/core/java/com/android/server/wm/WindowState.java
@@ -2913,8 +2913,9 @@
                             .windowForClientLocked(mSession, mClient, false);
                     Slog.i(TAG, "WIN DEATH: " + win);
                     if (win != null) {
-                        if (win.mActivityRecord != null && win.mActivityRecord.findMainWindow() == win) {
-                            mWmService.mTaskSnapshotController.onAppDied(win.mActivityRecord);
+                        if (win.mActivityRecord != null
+                                && win.mActivityRecord.findMainWindow() == win) {
+                            mWmService.mSnapshotController.onAppDied(win.mActivityRecord);
                         }
                         win.removeIfPossible();
                     } else if (mHasSurface) {
diff --git a/services/tests/wmtests/src/com/android/server/wm/ActivitySnapshotControllerTests.java b/services/tests/wmtests/src/com/android/server/wm/ActivitySnapshotControllerTests.java
new file mode 100644
index 0000000..0eca8c9
--- /dev/null
+++ b/services/tests/wmtests/src/com/android/server/wm/ActivitySnapshotControllerTests.java
@@ -0,0 +1,139 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.wm;
+
+import static android.app.WindowConfiguration.ACTIVITY_TYPE_STANDARD;
+
+import static org.junit.Assert.assertEquals;
+
+import android.platform.test.annotations.Presubmit;
+
+import androidx.test.filters.SmallTest;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+/**
+ * Test class for {@link ActivitySnapshotController}.
+ *
+ * Build/Install/Run:
+ *  atest WmTests:ActivitySnapshotControllerTests
+ */
+@SmallTest
+@Presubmit
+@RunWith(WindowTestRunner.class)
+public class ActivitySnapshotControllerTests extends WindowTestsBase {
+
+    private ActivitySnapshotController mActivitySnapshotController;
+
+    @Before
+    public void setUp() throws Exception {
+        mActivitySnapshotController = mWm.mSnapshotController.mActivitySnapshotController;
+        mActivitySnapshotController.resetTmpFields();
+    }
+
+    /**
+     * Creates a standard app window named {@code name} in {@code task}, commits its activity's
+     * visibility with layout, and returns that activity.
+     * Note for createAppWindow: the new child is added at index 0.
+     */
+    private ActivityRecord createCommittedActivity(Task task, String name, boolean visible) {
+        final WindowState window = createAppWindow(task, ACTIVITY_TYPE_STANDARD, name);
+        window.mActivityRecord.commitVisibility(visible, true /* performLayout */);
+        return window.mActivityRecord;
+    }
+
+    @Test
+    public void testOpenActivityTransition() {
+        final SnapshotController.TransitionState transitionState =
+                new SnapshotController.TransitionState();
+        final Task task = createTask(mDisplayContent);
+        final ActivityRecord opening = createCommittedActivity(task, "openingWindow", true);
+        final ActivityRecord closing = createCommittedActivity(task, "closingWindow", false);
+        transitionState.addParticipant(closing, false);
+        transitionState.addParticipant(opening, true);
+        mActivitySnapshotController.handleOpenActivityTransition(transitionState);
+
+        // Two activities: only the closing one is pending capture.
+        assertEquals(1, mActivitySnapshotController.mPendingCaptureActivity.size());
+        assertEquals(0, mActivitySnapshotController.mPendingRemoveActivity.size());
+        assertEquals(closing, mActivitySnapshotController.mPendingCaptureActivity.valueAt(0));
+        mActivitySnapshotController.resetTmpFields();
+
+        // Three activities: the extra invisible one is pending removal.
+        final ActivityRecord belowClose = createCommittedActivity(task, "belowClose", false);
+        mActivitySnapshotController.handleOpenActivityTransition(transitionState);
+        assertEquals(1, mActivitySnapshotController.mPendingCaptureActivity.size());
+        assertEquals(1, mActivitySnapshotController.mPendingRemoveActivity.size());
+        assertEquals(closing, mActivitySnapshotController.mPendingCaptureActivity.valueAt(0));
+        assertEquals(belowClose, mActivitySnapshotController.mPendingRemoveActivity.valueAt(0));
+    }
+
+    @Test
+    public void testCloseActivityTransition() {
+        final SnapshotController.TransitionState transitionState =
+                new SnapshotController.TransitionState();
+        final Task task = createTask(mDisplayContent);
+        final ActivityRecord closing = createCommittedActivity(task, "closingWindow", false);
+        final ActivityRecord opening = createCommittedActivity(task, "openingWindow", true);
+        transitionState.addParticipant(closing, false);
+        transitionState.addParticipant(opening, true);
+        mActivitySnapshotController.handleCloseActivityTransition(transitionState);
+
+        // Two activities: nothing is captured; the opening one is pending deletion.
+        assertEquals(0, mActivitySnapshotController.mPendingCaptureActivity.size());
+        assertEquals(1, mActivitySnapshotController.mPendingDeleteActivity.size());
+        assertEquals(opening, mActivitySnapshotController.mPendingDeleteActivity.valueAt(0));
+        mActivitySnapshotController.resetTmpFields();
+
+        // Three activities: the extra invisible one is pending load.
+        final ActivityRecord belowOpen = createCommittedActivity(task, "belowOpen", false);
+        mActivitySnapshotController.handleCloseActivityTransition(transitionState);
+        assertEquals(0, mActivitySnapshotController.mPendingCaptureActivity.size());
+        assertEquals(1, mActivitySnapshotController.mPendingDeleteActivity.size());
+        assertEquals(1, mActivitySnapshotController.mPendingLoadActivity.size());
+        assertEquals(opening, mActivitySnapshotController.mPendingDeleteActivity.valueAt(0));
+        assertEquals(belowOpen, mActivitySnapshotController.mPendingLoadActivity.valueAt(0));
+    }
+
+    @Test
+    public void testTaskTransition() {
+        final SnapshotController.TransitionState taskCloseTransition =
+                new SnapshotController.TransitionState();
+        final SnapshotController.TransitionState taskOpenTransition =
+                new SnapshotController.TransitionState();
+        final Task closeTask = createTask(mDisplayContent);
+        createCommittedActivity(closeTask, "closingWindow", false);
+        final ActivityRecord closingBelow =
+                createCommittedActivity(closeTask, "closingWindowBelow", false);
+
+        final Task openTask = createTask(mDisplayContent);
+        createCommittedActivity(openTask, "openingWindow", true);
+        final ActivityRecord openingBelow =
+                createCommittedActivity(openTask, "openingWindowBelow", false);
+        taskCloseTransition.addParticipant(closeTask, false);
+        taskOpenTransition.addParticipant(openTask, true);
+        mActivitySnapshotController.handleCloseTaskTransition(taskCloseTransition);
+        mActivitySnapshotController.handleOpenTaskTransition(taskOpenTransition);
+
+        assertEquals(1, mActivitySnapshotController.mPendingRemoveActivity.size());
+        assertEquals(closingBelow, mActivitySnapshotController.mPendingRemoveActivity.valueAt(0));
+        assertEquals(1, mActivitySnapshotController.mPendingLoadActivity.size());
+        assertEquals(openingBelow, mActivitySnapshotController.mPendingLoadActivity.valueAt(0));
+    }
+}
diff --git a/services/tests/wmtests/src/com/android/server/wm/AppSnapshotControllerTests.java b/services/tests/wmtests/src/com/android/server/wm/AppSnapshotControllerTests.java
new file mode 100644
index 0000000..83af1814
--- /dev/null
+++ b/services/tests/wmtests/src/com/android/server/wm/AppSnapshotControllerTests.java
@@ -0,0 +1,183 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.wm;
+
+import static android.app.WindowConfiguration.ACTIVITY_TYPE_STANDARD;
+
+import static com.android.server.wm.SnapshotController.ACTIVITY_CLOSE;
+import static com.android.server.wm.SnapshotController.ACTIVITY_OPEN;
+import static com.android.server.wm.SnapshotController.TASK_CLOSE;
+import static com.android.server.wm.SnapshotController.TASK_OPEN;
+
+import static org.junit.Assert.assertTrue;
+
+import android.platform.test.annotations.Presubmit;
+import android.util.ArraySet;
+
+import androidx.test.filters.SmallTest;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import java.util.function.Consumer;
+
+/**
+ * Test class for {@link SnapshotController}.
+ *
+ * <p>Checks that {@link SnapshotController#handleAppTransition} reports the opening and
+ * closing participants to the consumer registered for the matching transition type.
+ *
+ * Build/Install/Run:
+ *  atest WmTests:AppSnapshotControllerTests
+ */
+@SmallTest
+@Presubmit
+@RunWith(WindowTestRunner.class)
+public class AppSnapshotControllerTests extends WindowTestsBase {
+    final ArraySet<ActivityRecord> mClosingApps = new ArraySet<>();
+    final ArraySet<ActivityRecord> mOpeningApps = new ArraySet<>();
+
+    final TransitionMonitor mOpenActivityMonitor = new TransitionMonitor();
+    final TransitionMonitor mCloseActivityMonitor = new TransitionMonitor();
+    final TransitionMonitor mOpenTaskMonitor = new TransitionMonitor();
+    final TransitionMonitor mCloseTaskMonitor = new TransitionMonitor();
+
+    @Before
+    public void setUp() throws Exception {
+        resetStatus();
+        mWm.mSnapshotController.registerTransitionStateConsumer(
+                ACTIVITY_CLOSE, mCloseActivityMonitor.mConsumer);
+        mWm.mSnapshotController.registerTransitionStateConsumer(
+                ACTIVITY_OPEN, mOpenActivityMonitor.mConsumer);
+        mWm.mSnapshotController.registerTransitionStateConsumer(
+                TASK_CLOSE, mCloseTaskMonitor.mConsumer);
+        mWm.mSnapshotController.registerTransitionStateConsumer(
+                TASK_OPEN, mOpenTaskMonitor.mConsumer);
+    }
+
+    @After
+    public void tearDown() throws Exception {
+        // Pass the exact Consumer instances registered in setUp(); a freshly evaluated method
+        // reference is not guaranteed to equal the registered one, so it may not be removed.
+        mWm.mSnapshotController.unregisterTransitionStateConsumer(
+                ACTIVITY_CLOSE, mCloseActivityMonitor.mConsumer);
+        mWm.mSnapshotController.unregisterTransitionStateConsumer(
+                ACTIVITY_OPEN, mOpenActivityMonitor.mConsumer);
+        mWm.mSnapshotController.unregisterTransitionStateConsumer(
+                TASK_CLOSE, mCloseTaskMonitor.mConsumer);
+        mWm.mSnapshotController.unregisterTransitionStateConsumer(
+                TASK_OPEN, mOpenTaskMonitor.mConsumer);
+    }
+
+    /** Records the participants of every transition reported through {@link #mConsumer}. */
+    private static class TransitionMonitor {
+        private final ArraySet<WindowContainer> mOpenParticipant = new ArraySet<>();
+        private final ArraySet<WindowContainer> mCloseParticipant = new ArraySet<>();
+        // One cached instance, so the object given to register can be given to unregister.
+        final Consumer<SnapshotController.TransitionState> mConsumer = this::handleTransition;
+
+        void handleTransition(SnapshotController.TransitionState<ActivityRecord> state) {
+            mOpenParticipant.addAll(state.getParticipant(true /* open */));
+            mCloseParticipant.addAll(state.getParticipant(false /* close */));
+        }
+
+        void reset() {
+            mOpenParticipant.clear();
+            mCloseParticipant.clear();
+        }
+    }
+
+    /** Clears the app sets and every monitor so each test starts from an empty state. */
+    private void resetStatus() {
+        mClosingApps.clear();
+        mOpeningApps.clear();
+        mOpenActivityMonitor.reset();
+        mCloseActivityMonitor.reset();
+        mOpenTaskMonitor.reset();
+        mCloseTaskMonitor.reset();
+    }
+
+    @Test
+    public void testHandleAppTransition_openActivityTransition() {
+        final Task task = createTask(mDisplayContent);
+        // note for createAppWindow: the new child is added at index 0
+        final WindowState openingWindow = createAppWindow(task,
+                ACTIVITY_TYPE_STANDARD, "openingWindow");
+        openingWindow.mActivityRecord.commitVisibility(
+                true /* visible */, true /* performLayout */);
+        final WindowState closingWindow = createAppWindow(task, ACTIVITY_TYPE_STANDARD,
+                "closingWindow");
+        closingWindow.mActivityRecord.commitVisibility(
+                false /* visible */, true /* performLayout */);
+        mClosingApps.add(closingWindow.mActivityRecord);
+        mOpeningApps.add(openingWindow.mActivityRecord);
+        mWm.mSnapshotController.handleAppTransition(mClosingApps, mOpeningApps);
+        assertTrue(mOpenActivityMonitor.mCloseParticipant.contains(closingWindow.mActivityRecord));
+        assertTrue(mOpenActivityMonitor.mOpenParticipant.contains(openingWindow.mActivityRecord));
+    }
+
+    @Test
+    public void testHandleAppTransition_closeActivityTransition() {
+        final Task task = createTask(mDisplayContent);
+        // note for createAppWindow: the new child is added at index 0
+        final WindowState closingWindow = createAppWindow(task, ACTIVITY_TYPE_STANDARD,
+                "closingWindow");
+        closingWindow.mActivityRecord.commitVisibility(
+                false /* visible */, true /* performLayout */);
+        final WindowState openingWindow = createAppWindow(task,
+                ACTIVITY_TYPE_STANDARD, "openingWindow");
+        openingWindow.mActivityRecord.commitVisibility(
+                true /* visible */, true /* performLayout */);
+        mClosingApps.add(closingWindow.mActivityRecord);
+        mOpeningApps.add(openingWindow.mActivityRecord);
+        mWm.mSnapshotController.handleAppTransition(mClosingApps, mOpeningApps);
+        assertTrue(mCloseActivityMonitor.mCloseParticipant.contains(closingWindow.mActivityRecord));
+        assertTrue(mCloseActivityMonitor.mOpenParticipant.contains(openingWindow.mActivityRecord));
+    }
+
+    @Test
+    public void testHandleAppTransition_TaskTransition() {
+        final Task closeTask = createTask(mDisplayContent);
+        // note for createAppWindow: the new child is added at index 0
+        final WindowState closingWindow = createAppWindow(closeTask, ACTIVITY_TYPE_STANDARD,
+                "closingWindow");
+        closingWindow.mActivityRecord.commitVisibility(
+                false /* visible */, true /* performLayout */);
+        final WindowState closingWindowBelow = createAppWindow(closeTask, ACTIVITY_TYPE_STANDARD,
+                "closingWindowBelow");
+        closingWindowBelow.mActivityRecord.commitVisibility(
+                false /* visible */, true /* performLayout */);
+
+        final Task openTask = createTask(mDisplayContent);
+        final WindowState openingWindow = createAppWindow(openTask, ACTIVITY_TYPE_STANDARD,
+                "openingWindow");
+        openingWindow.mActivityRecord.commitVisibility(
+                true /* visible */, true /* performLayout */);
+        final WindowState openingWindowBelow = createAppWindow(openTask, ACTIVITY_TYPE_STANDARD,
+                "openingWindowBelow");
+        openingWindowBelow.mActivityRecord.commitVisibility(
+                false /* visible */, true /* performLayout */);
+
+        mClosingApps.add(closingWindow.mActivityRecord);
+        mOpeningApps.add(openingWindow.mActivityRecord);
+        mWm.mSnapshotController.handleAppTransition(mClosingApps, mOpeningApps);
+        assertTrue(mCloseTaskMonitor.mCloseParticipant.contains(closeTask));
+        assertTrue(mOpenTaskMonitor.mOpenParticipant.contains(openTask));
+    }
+}
diff --git a/services/tests/wmtests/src/com/android/server/wm/TransitionTests.java b/services/tests/wmtests/src/com/android/server/wm/TransitionTests.java
index 616d528..582d7d8 100644
--- a/services/tests/wmtests/src/com/android/server/wm/TransitionTests.java
+++ b/services/tests/wmtests/src/com/android/server/wm/TransitionTests.java
@@ -44,6 +44,7 @@
 import static com.android.dx.mockito.inline.extended.ExtendedMockito.doNothing;
 import static com.android.dx.mockito.inline.extended.ExtendedMockito.doReturn;
 import static com.android.dx.mockito.inline.extended.ExtendedMockito.spyOn;
+import static com.android.server.wm.SnapshotController.TASK_CLOSE;
 import static com.android.server.wm.WindowContainer.POSITION_TOP;
 
 import static org.junit.Assert.assertEquals;
@@ -1351,6 +1352,9 @@
 
     @Test
     public void testTransientLaunch() {
+        spyOn(mWm.mSnapshotController.mTaskSnapshotController);
+        mWm.mSnapshotController.registerTransitionStateConsumer(
+                TASK_CLOSE, mWm.mSnapshotController.mTaskSnapshotController::handleTaskClose);
         final ArrayList<ActivityRecord> enteringAnimReports = new ArrayList<>();
         final TransitionController controller = new TestTransitionController(mAtm) {
             @Override
@@ -1361,7 +1365,9 @@
                 super.dispatchLegacyAppTransitionFinished(ar);
             }
         };
-        final TaskSnapshotController snapshotController = controller.mTaskSnapshotController;
+        controller.mSnapshotController = mWm.mSnapshotController;
+        final TaskSnapshotController taskSnapshotController =
+                mWm.mSnapshotController.mTaskSnapshotController;
         final ITransitionPlayer player = new ITransitionPlayer.Default();
         controller.registerTransitionPlayer(player, null /* playerProc */);
         final Transition openTransition = controller.createTransition(TRANSIT_OPEN);
@@ -1391,7 +1397,7 @@
         // normally.
         mWm.mSyncEngine.abort(openTransition.getSyncId());
 
-        verify(snapshotController, times(1)).recordSnapshot(eq(task2), eq(false));
+        verify(taskSnapshotController, times(1)).recordSnapshot(eq(task2), eq(false));
 
         controller.finishTransition(openTransition);
 
@@ -1421,7 +1427,7 @@
 
         // Make sure we haven't called recordSnapshot (since we are transient, it shouldn't be
         // called until finish).
-        verify(snapshotController, times(0)).recordSnapshot(eq(task1), eq(false));
+        verify(taskSnapshotController, times(0)).recordSnapshot(eq(task1), eq(false));
 
         enteringAnimReports.clear();
         doCallRealMethod().when(mWm.mRoot).ensureActivitiesVisible(any(),
@@ -1447,7 +1453,7 @@
         assertFalse(activity1.isVisible());
         assertFalse(activity1.app.hasActivityInVisibleTask());
 
-        verify(snapshotController, times(1)).recordSnapshot(eq(task1), eq(false));
+        verify(taskSnapshotController, times(1)).recordSnapshot(eq(task1), eq(false));
         assertTrue(enteringAnimReports.contains(activity2));
     }
 
diff --git a/services/tests/wmtests/src/com/android/server/wm/WindowTestsBase.java b/services/tests/wmtests/src/com/android/server/wm/WindowTestsBase.java
index b80500a..0d7cdc8 100644
--- a/services/tests/wmtests/src/com/android/server/wm/WindowTestsBase.java
+++ b/services/tests/wmtests/src/com/android/server/wm/WindowTestsBase.java
@@ -1746,7 +1746,8 @@
     static class TestTransitionController extends TransitionController {
         TestTransitionController(ActivityTaskManagerService atms) {
             super(atms);
-            mTaskSnapshotController = mock(TaskSnapshotController.class);
+            mSnapshotController = mock(SnapshotController.class);
+            doReturn(this).when(atms).getTransitionController();
             mTransitionTracer = mock(TransitionTracer.class);
         }
     }