Move the default internal display injection into FakeDisplayInjector.

DisplayTransactionTest::injectDefaultInternalDisplay() is no longer defined out of line in
DisplayTransactionTest.cpp; DisplayTransactionTestHelpers.h defines it inline as a forwarder to
FakeDisplayInjector::injectInternalDisplay(), which performs the same setup (1080x1920 native
window, display port 255, HWC display 0, primary).

diff --git a/services/surfaceflinger/tests/unittests/DisplayTransactionTest.cpp b/services/surfaceflinger/tests/unittests/DisplayTransactionTest.cpp
index f04221c..5ddd1c8 100644
--- a/services/surfaceflinger/tests/unittests/DisplayTransactionTest.cpp
+++ b/services/surfaceflinger/tests/unittests/DisplayTransactionTest.cpp
@@ -120,51 +120,6 @@
     });
 }
 
-sp<DisplayDevice> DisplayTransactionTest::injectDefaultInternalDisplay(
-        std::function<void(FakeDisplayDeviceInjector&)> injectExtra) {
-    constexpr PhysicalDisplayId DEFAULT_DISPLAY_ID = PhysicalDisplayId::fromPort(255u);
-    constexpr int DEFAULT_DISPLAY_WIDTH = 1080;
-    constexpr int DEFAULT_DISPLAY_HEIGHT = 1920;
-    constexpr HWDisplayId DEFAULT_DISPLAY_HWC_DISPLAY_ID = 0;
-
-    // The DisplayDevice is required to have a framebuffer (behind the
-    // ANativeWindow interface) which uses the actual hardware display
-    // size.
-    EXPECT_CALL(*mNativeWindow, query(NATIVE_WINDOW_WIDTH, _))
-            .WillRepeatedly(DoAll(SetArgPointee<1>(DEFAULT_DISPLAY_WIDTH), Return(0)));
-    EXPECT_CALL(*mNativeWindow, query(NATIVE_WINDOW_HEIGHT, _))
-            .WillRepeatedly(DoAll(SetArgPointee<1>(DEFAULT_DISPLAY_HEIGHT), Return(0)));
-    EXPECT_CALL(*mNativeWindow, perform(NATIVE_WINDOW_SET_BUFFERS_FORMAT));
-    EXPECT_CALL(*mNativeWindow, perform(NATIVE_WINDOW_API_CONNECT));
-    EXPECT_CALL(*mNativeWindow, perform(NATIVE_WINDOW_SET_USAGE64));
-    EXPECT_CALL(*mNativeWindow, perform(NATIVE_WINDOW_API_DISCONNECT)).Times(AnyNumber());
-
-    auto compositionDisplay =
-            compositionengine::impl::createDisplay(mFlinger.getCompositionEngine(),
-                                                   compositionengine::DisplayCreationArgsBuilder()
-                                                           .setId(DEFAULT_DISPLAY_ID)
-                                                           .setPixels({DEFAULT_DISPLAY_WIDTH,
-                                                                       DEFAULT_DISPLAY_HEIGHT})
-                                                           .setPowerAdvisor(&mPowerAdvisor)
-                                                           .build());
-
-    constexpr bool kIsPrimary = true;
-    auto injector = FakeDisplayDeviceInjector(mFlinger, compositionDisplay,
-                                              ui::DisplayConnectionType::Internal,
-                                              DEFAULT_DISPLAY_HWC_DISPLAY_ID, kIsPrimary);
-
-    injector.setNativeWindow(mNativeWindow);
-    if (injectExtra) {
-        injectExtra(injector);
-    }
-
-    auto displayDevice = injector.inject();
-
-    Mock::VerifyAndClear(mNativeWindow.get());
-
-    return displayDevice;
-}
-
 bool DisplayTransactionTest::hasPhysicalHwcDisplay(HWDisplayId hwcDisplayId) const {
     return mFlinger.hwcPhysicalDisplayIdMap().count(hwcDisplayId) == 1;
 }
diff --git a/services/surfaceflinger/tests/unittests/DisplayTransactionTestHelpers.h b/services/surfaceflinger/tests/unittests/DisplayTransactionTestHelpers.h
index f5235ce..60f773f 100644
--- a/services/surfaceflinger/tests/unittests/DisplayTransactionTestHelpers.h
+++ b/services/surfaceflinger/tests/unittests/DisplayTransactionTestHelpers.h
@@ -42,6 +42,7 @@
 #include <renderengine/mock/RenderEngine.h>
 #include <ui/DebugUtils.h>
 
+#include "FakeDisplayInjector.h"
 #include "TestableScheduler.h"
 #include "TestableSurfaceFlinger.h"
 #include "mock/DisplayHardware/MockComposer.h"
@@ -90,7 +91,11 @@
     void injectFakeBufferQueueFactory();
     void injectFakeNativeWindowSurfaceFactory();
+    // Injects the default internal display through mFakeDisplayInjector, using the default
+    // FakeDisplayInjectorArgs.
     sp<DisplayDevice> injectDefaultInternalDisplay(
-            std::function<void(TestableSurfaceFlinger::FakeDisplayDeviceInjector&)>);
+            std::function<void(TestableSurfaceFlinger::FakeDisplayDeviceInjector&)> injectExtra) {
+        return mFakeDisplayInjector.injectInternalDisplay(injectExtra);
+    }
 
     // --------------------------------------------------------------------
     // Postcondition helpers
@@ -115,6 +120,10 @@
     sp<GraphicBuffer> mBuffer = new GraphicBuffer();
     Hwc2::mock::PowerAdvisor mPowerAdvisor;
 
+    // Initialized from mFlinger, mPowerAdvisor and mNativeWindow, so it must stay declared after
+    // them: members are initialized in declaration order.
+    FakeDisplayInjector mFakeDisplayInjector{mFlinger, mPowerAdvisor, mNativeWindow};
+
     // These mocks are created by the test, but are destroyed by SurfaceFlinger
     // by virtue of being stored into a std::unique_ptr. However we still need
     // to keep a reference to them for use in setting up call expectations.
diff --git a/services/surfaceflinger/tests/unittests/FakeDisplayInjector.h b/services/surfaceflinger/tests/unittests/FakeDisplayInjector.h
new file mode 100644
index 0000000..6e4bf2b
--- /dev/null
+++ b/services/surfaceflinger/tests/unittests/FakeDisplayInjector.h
@@ -0,0 +1,117 @@
+/*
+ * Copyright 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <functional>
+#include <utility>
+
+#include <gmock/gmock.h>
+
+#include "TestableSurfaceFlinger.h"
+#include "mock/DisplayHardware/MockPowerAdvisor.h"
+#include "mock/system/window/MockNativeWindow.h"
+
+namespace android {
+
+using FakeDisplayDeviceInjector = TestableSurfaceFlinger::FakeDisplayDeviceInjector;
+using android::hardware::graphics::composer::hal::HWDisplayId;
+using android::Hwc2::mock::PowerAdvisor;
+
+// Parameters for FakeDisplayInjector::injectInternalDisplay(). The defaults describe a primary
+// display on port 255 with HWC display ID 0.
+struct FakeDisplayInjectorArgs {
+    // Physical display ID given to the composition display.
+    PhysicalDisplayId displayId = PhysicalDisplayId::fromPort(255u);
+    // HWC display ID given to the FakeDisplayDeviceInjector.
+    HWDisplayId hwcDisplayId = 0;
+    // Whether the FakeDisplayDeviceInjector injects the display as the primary display.
+    bool isPrimary = true;
+};
+
+// Injects fake internal displays into a TestableSurfaceFlinger. The flinger and the power
+// advisor are held by reference and must outlive this object; the native window is shared
+// (via sp<>) with the caller.
+class FakeDisplayInjector {
+public:
+    FakeDisplayInjector(TestableSurfaceFlinger& flinger, PowerAdvisor& powerAdvisor,
+                        sp<mock::NativeWindow> nativeWindow)
+          : mFlinger(flinger),
+            mPowerAdvisor(powerAdvisor),
+            mNativeWindow(std::move(nativeWindow)) {}
+
+    // Injects an internal display with a fixed 1080x1920 resolution, backed by mNativeWindow.
+    //
+    // If injectExtra is non-empty, it is invoked with the FakeDisplayDeviceInjector right
+    // before the display is injected, so callers can customize it (e.g. set its power mode).
+    // Before returning, the expectations on mNativeWindow are verified and then cleared along
+    // with its default actions.
+    sp<DisplayDevice> injectInternalDisplay(
+            const std::function<void(FakeDisplayDeviceInjector&)>& injectExtra,
+            FakeDisplayInjectorArgs args = {}) {
+        using testing::_;
+        using testing::AnyNumber;
+        using testing::DoAll;
+        using testing::Mock;
+        using testing::Return;
+        using testing::SetArgPointee;
+
+        constexpr ui::Size kResolution = {1080, 1920};
+
+        // The DisplayDevice is required to have a framebuffer (behind the
+        // ANativeWindow interface) which uses the actual hardware display
+        // size.
+        EXPECT_CALL(*mNativeWindow, query(NATIVE_WINDOW_WIDTH, _))
+                .WillRepeatedly(DoAll(SetArgPointee<1>(kResolution.getWidth()), Return(0)));
+        EXPECT_CALL(*mNativeWindow, query(NATIVE_WINDOW_HEIGHT, _))
+                .WillRepeatedly(DoAll(SetArgPointee<1>(kResolution.getHeight()), Return(0)));
+        EXPECT_CALL(*mNativeWindow, perform(NATIVE_WINDOW_SET_BUFFERS_FORMAT));
+        EXPECT_CALL(*mNativeWindow, perform(NATIVE_WINDOW_API_CONNECT));
+        EXPECT_CALL(*mNativeWindow, perform(NATIVE_WINDOW_SET_USAGE64));
+        EXPECT_CALL(*mNativeWindow, perform(NATIVE_WINDOW_API_DISCONNECT)).Times(AnyNumber());
+
+        auto compositionDisplay = compositionengine::impl::
+                createDisplay(mFlinger.getCompositionEngine(),
+                              compositionengine::DisplayCreationArgsBuilder()
+                                      .setId(args.displayId)
+                                      .setPixels(kResolution)
+                                      .setPowerAdvisor(&mPowerAdvisor)
+                                      .build());
+
+        auto injector = FakeDisplayDeviceInjector(mFlinger, compositionDisplay,
+                                                  ui::DisplayConnectionType::Internal,
+                                                  args.hwcDisplayId, args.isPrimary);
+
+        injector.setNativeWindow(mNativeWindow);
+        if (injectExtra) {
+            injectExtra(injector);
+        }
+
+        auto displayDevice = injector.inject();
+
+        // Verify the window expectations set above, then drop them so they do not carry over
+        // into the rest of the test.
+        Mock::VerifyAndClear(mNativeWindow.get());
+
+        return displayDevice;
+    }
+
+    TestableSurfaceFlinger& mFlinger;
+    PowerAdvisor& mPowerAdvisor;
+    sp<mock::NativeWindow> mNativeWindow;
+};
+
+} // namespace android
diff --git a/services/surfaceflinger/tests/unittests/SurfaceFlinger_SetPowerModeInternalTest.cpp b/services/surfaceflinger/tests/unittests/SurfaceFlinger_SetPowerModeInternalTest.cpp
index 583cf5f..b560025 100644
--- a/services/surfaceflinger/tests/unittests/SurfaceFlinger_SetPowerModeInternalTest.cpp
+++ b/services/surfaceflinger/tests/unittests/SurfaceFlinger_SetPowerModeInternalTest.cpp
@@ -323,6 +323,14 @@
 public:
     template <typename Case>
     void transitionDisplayCommon();
+
+    // Injects a primary display and powers it on, expecting it to be the active display
+    // afterwards. If kBoot, the display starts with no power mode and no active display token,
+    // and mRenderEngine->onActiveDisplaySizeChanged() and mEventThread->onModeChanged() are each
+    // expected once; otherwise the display starts OFF and already active, and neither call is
+    // expected. Returns the injected display.
+    template <bool kBoot>
+    sp<DisplayDevice> activeDisplayTest();
 };
 
 template <PowerMode PowerMode>
@@ -499,5 +507,90 @@
     transitionDisplayCommon<ExternalDisplayPowerCase<TransitionOnToUnknownVariant>>();
 }
 
+template <bool kBoot>
+sp<DisplayDevice> SetPowerModeInternalTest::activeDisplayTest() {
+    using Case = SimplePrimaryDisplayCase;
+
+    // --------------------------------------------------------------------
+    // Preconditions
+
+    // Inject a primary display.
+    Case::Display::injectHwcDisplay(this);
+    auto injector = Case::Display::makeFakeExistingDisplayInjector(this);
+    injector.setPowerMode(kBoot ? std::nullopt : std::make_optional(PowerMode::OFF));
+
+    const auto display = injector.inject();
+    EXPECT_EQ(display->getDisplayToken(), mFlinger.mutableActiveDisplayToken());
+
+    using PowerCase = PrimaryDisplayPowerCase<TransitionOffToOnVariant>;
+    TransitionOffToOnVariant::template setupCallExpectations<PowerCase>(this);
+
+    constexpr int kTimes = kBoot ? 1 : 0;
+    EXPECT_CALL(*mRenderEngine, onActiveDisplaySizeChanged(display->getSize())).Times(kTimes);
+    EXPECT_CALL(*mEventThread, onModeChanged(display->getActiveMode())).Times(kTimes);
+
+    if constexpr (kBoot) {
+        // For the boot case, start with no active display.
+        mFlinger.mutableActiveDisplayToken() = nullptr;
+    }
+
+    // --------------------------------------------------------------------
+    // Invocation
+
+    mFlinger.setPowerModeInternal(display, PowerMode::ON);
+
+    // --------------------------------------------------------------------
+    // Postconditions
+
+    // The primary display should be the active display.
+    EXPECT_EQ(display->getDisplayToken(), mFlinger.mutableActiveDisplayToken());
+
+    Mock::VerifyAndClearExpectations(mComposer);
+    Mock::VerifyAndClearExpectations(mRenderEngine);
+    Mock::VerifyAndClearExpectations(mEventThread);
+    Mock::VerifyAndClearExpectations(mVsyncController);
+    Mock::VerifyAndClearExpectations(mVSyncTracker);
+    Mock::VerifyAndClearExpectations(mFlinger.scheduler());
+    Mock::VerifyAndClearExpectations(&mFlinger.mockSchedulerCallback());
+
+    return display;
+}
+
+TEST_F(SetPowerModeInternalTest, activeDisplayBoot) {
+    constexpr bool kBoot = true;
+    activeDisplayTest<kBoot>();
+}
+
+TEST_F(SetPowerModeInternalTest, activeDisplaySingle) {
+    constexpr bool kBoot = false;
+    activeDisplayTest<kBoot>();
+}
+
+TEST_F(SetPowerModeInternalTest, activeDisplayDual) {
+    constexpr bool kBoot = false;
+    const auto innerDisplay = activeDisplayTest<kBoot>();
+
+    // Inject a powered-off outer display.
+    const auto outerDisplay = mFakeDisplayInjector.injectInternalDisplay(
+            [&](FakeDisplayDeviceInjector& injector) { injector.setPowerMode(PowerMode::OFF); },
+            {.displayId = PhysicalDisplayId::fromPort(254u),
+             .hwcDisplayId = 1,
+             .isPrimary = false});
+
+    EXPECT_EQ(innerDisplay->getDisplayToken(), mFlinger.mutableActiveDisplayToken());
+
+    // Moving power from the inner to the outer display should make the outer one active.
+    mFlinger.setPowerModeInternal(innerDisplay, PowerMode::OFF);
+    mFlinger.setPowerModeInternal(outerDisplay, PowerMode::ON);
+
+    EXPECT_EQ(outerDisplay->getDisplayToken(), mFlinger.mutableActiveDisplayToken());
+
+    // Moving it back should make the inner display active again.
+    mFlinger.setPowerModeInternal(outerDisplay, PowerMode::OFF);
+    mFlinger.setPowerModeInternal(innerDisplay, PowerMode::ON);
+
+    EXPECT_EQ(innerDisplay->getDisplayToken(), mFlinger.mutableActiveDisplayToken());
+}
+
 } // namespace
 } // namespace android
diff --git a/services/surfaceflinger/tests/unittests/TestableSurfaceFlinger.h b/services/surfaceflinger/tests/unittests/TestableSurfaceFlinger.h
index 283f9ca..3cffec1 100644
--- a/services/surfaceflinger/tests/unittests/TestableSurfaceFlinger.h
+++ b/services/surfaceflinger/tests/unittests/TestableSurfaceFlinger.h
@@ -798,7 +798,7 @@
             return *this;
         }
 
-        auto& setPowerMode(hal::PowerMode mode) {
+        auto& setPowerMode(std::optional<hal::PowerMode> mode) {
             mCreationArgs.initialPowerMode = mode;
             return *this;
         }
@@ -867,6 +867,10 @@
                                   .deviceProductInfo = {},
                                   .supportedModes = modes,
                                   .activeMode = activeMode->get()};
+                // Record a primary display's token as the flinger's active display token.
+                if (mCreationArgs.isPrimary) {
+                    mFlinger.mutableActiveDisplayToken() = mDisplayToken;
+                }
             }
 
             state.isSecure = mCreationArgs.isSecure;
