diff --git a/services/surfaceflinger/tests/Transaction_test.cpp b/services/surfaceflinger/tests/Transaction_test.cpp
index 216532a..b95c0ec 100644
--- a/services/surfaceflinger/tests/Transaction_test.cpp
+++ b/services/surfaceflinger/tests/Transaction_test.cpp
@@ -2664,6 +2664,17 @@
     }
 }
 
+struct CallbackData {  // Copy of the arguments passed to one transaction callback.
+    CallbackData() = default;
+    CallbackData(nsecs_t time, const sp<Fence>& fence,
+                 const std::vector<SurfaceControlStats>& stats)
+          : latchTime(time), presentFence(fence), surfaceControlStats(stats) {}
+
+    nsecs_t latchTime = -1;  // -1 (the "not latched" value) until a callback fills it in
+    sp<Fence> presentFence;
+    std::vector<SurfaceControlStats> surfaceControlStats;
+};
+
 class ExpectedResult {
 public:
     enum Transaction {
@@ -2691,8 +2702,7 @@
                     ExpectedResult::Buffer bufferResult = ACQUIRED,
                     ExpectedResult::PreviousBuffer previousBufferResult = NOT_RELEASED) {
         mTransactionResult = transactionResult;
-        mExpectedSurfaceResults.emplace(std::piecewise_construct,
-                                        std::forward_as_tuple(layer->getHandle()),
+        mExpectedSurfaceResults.emplace(std::piecewise_construct, std::forward_as_tuple(layer),
                                         std::forward_as_tuple(bufferResult, previousBufferResult));
     }
 
@@ -2709,8 +2719,8 @@
         mExpectedPresentTime = expectedPresentTime;
     }
 
-    void verifyTransactionStats(const TransactionStats& transactionStats) const {
-        const auto& [latchTime, presentFence, surfaceStats] = transactionStats;
+    void verifyCallbackData(const CallbackData& callbackData) const {
+        const auto& [latchTime, presentFence, surfaceControlStats] = callbackData;
         if (mTransactionResult == ExpectedResult::Transaction::PRESENTED) {
             ASSERT_GE(latchTime, 0) << "bad latch time";
             ASSERT_NE(presentFence, nullptr);
@@ -2727,14 +2737,16 @@
             ASSERT_EQ(latchTime, -1) << "unpresented transactions shouldn't be latched";
         }
 
-        ASSERT_EQ(surfaceStats.size(), mExpectedSurfaceResults.size())
+        ASSERT_EQ(surfaceControlStats.size(), mExpectedSurfaceResults.size())
                 << "wrong number of surfaces";
 
-        for (const auto& stats : surfaceStats) {
+        for (const auto& stats : surfaceControlStats) {
+            ASSERT_NE(stats.surfaceControl, nullptr) << "returned null surface control";
+
             const auto& expectedSurfaceResult = mExpectedSurfaceResults.find(stats.surfaceControl);
             ASSERT_NE(expectedSurfaceResult, mExpectedSurfaceResults.end())
                     << "unexpected surface control";
-            expectedSurfaceResult->second.verifySurfaceStats(stats, latchTime);
+            expectedSurfaceResult->second.verifySurfaceControlStats(stats, latchTime);
         }
     }
 
@@ -2745,8 +2757,9 @@
                               ExpectedResult::PreviousBuffer previousBufferResult)
               : mBufferResult(bufferResult), mPreviousBufferResult(previousBufferResult) {}
 
-        void verifySurfaceStats(const SurfaceStats& surfaceStats, nsecs_t latchTime) const {
-            const auto& [surfaceControl, acquireTime, previousReleaseFence] = surfaceStats;
+        void verifySurfaceControlStats(const SurfaceControlStats& surfaceControlStats,
+                                       nsecs_t latchTime) const {
+            const auto& [surfaceControl, acquireTime, previousReleaseFence] = surfaceControlStats;
 
             ASSERT_EQ(acquireTime > 0, mBufferResult == ExpectedResult::Buffer::ACQUIRED)
                     << "bad acquire time";
@@ -2766,39 +2779,40 @@
         ExpectedResult::PreviousBuffer mPreviousBufferResult;
     };
 
-    struct IBinderHash {
-        std::size_t operator()(const sp<IBinder>& strongPointer) const {
-            return std::hash<IBinder*>{}(strongPointer.get());
+    struct SCHash {  // Hashes sc's handle pointer; sc is dereferenced, so it must be non-null.
+        std::size_t operator()(const sp<SurfaceControl>& sc) const {
+            return std::hash<IBinder*>{}(sc->getHandle().get());  // raw pointer value only
         }
     };
     ExpectedResult::Transaction mTransactionResult = ExpectedResult::Transaction::NOT_PRESENTED;
     nsecs_t mExpectedPresentTime = -1;
-    std::unordered_map<sp<IBinder>, ExpectedSurfaceResult, IBinderHash> mExpectedSurfaceResults;
+    std::unordered_map<sp<SurfaceControl>, ExpectedSurfaceResult, SCHash> mExpectedSurfaceResults;
 };
 
 class CallbackHelper {
 public:
-    static void function(void* callbackContext, const TransactionStats& transactionStats) {
-        if (!callbackContext) {
-            ALOGE("failed to get callback context");
-        }
-        CallbackHelper* helper = static_cast<CallbackHelper*>(callbackContext);
-        std::lock_guard lock(helper->mMutex);
-        helper->mTransactionStatsQueue.push(transactionStats);
-        helper->mConditionVariable.notify_all();
+    static void function(void* callbackContext, nsecs_t latchTime, const sp<Fence>& presentFence,
+                         const std::vector<SurfaceControlStats>& stats) {
+        if (auto* helper = static_cast<CallbackHelper*>(callbackContext)) {
+            std::lock_guard lock(helper->mMutex);  // guards the queue read by getCallbackData()
+            helper->mCallbackDataQueue.emplace(latchTime, presentFence, stats);
+            helper->mConditionVariable.notify_all();
+        } else {  // Null context: log and drop the callback rather than dereference it.
+            ALOGE("failed to get callback context");
+        }
     }
 
-    void getTransactionStats(TransactionStats* outStats) {
+    void getCallbackData(CallbackData* outData) {  // Pops the oldest callback; waits up to 3s.
         std::unique_lock lock(mMutex);
 
-        if (mTransactionStatsQueue.empty()) {
-            ASSERT_NE(mConditionVariable.wait_for(lock, std::chrono::seconds(3)),
-                      std::cv_status::timeout)
-                    << "did not receive callback";
-        }
+        // Predicate wait: a spurious wakeup must not lead to popping an empty queue.
+        const bool received = mConditionVariable.wait_for(lock, std::chrono::seconds(3), [this] {
+            return !mCallbackDataQueue.empty();
+        });
+        ASSERT_TRUE(received) << "did not receive callback";
 
-        *outStats = std::move(mTransactionStatsQueue.front());
-        mTransactionStatsQueue.pop();
+        *outData = std::move(mCallbackDataQueue.front());
+        mCallbackDataQueue.pop();
     }
 
     void verifyFinalState() {
@@ -2806,15 +2820,15 @@
         std::this_thread::sleep_for(500ms);
 
         std::lock_guard lock(mMutex);
-        EXPECT_EQ(mTransactionStatsQueue.size(), 0) << "extra callbacks received";
-        mTransactionStatsQueue = {};
+        EXPECT_EQ(mCallbackDataQueue.size(), 0) << "extra callbacks received";
+        mCallbackDataQueue = {};
     }
 
     void* getContext() { return static_cast<void*>(this); }
 
     std::mutex mMutex;
     std::condition_variable mConditionVariable;
-    std::queue<TransactionStats> mTransactionStatsQueue;
+    std::queue<CallbackData> mCallbackDataQueue;
 };
 
 class LayerCallbackTest : public LayerTransactionTest {
@@ -2843,9 +2857,9 @@
 
     static void waitForCallback(CallbackHelper& helper, const ExpectedResult& expectedResult,
                                 bool finalState = false) {
-        TransactionStats transactionStats;
-        ASSERT_NO_FATAL_FAILURE(helper.getTransactionStats(&transactionStats));
-        EXPECT_NO_FATAL_FAILURE(expectedResult.verifyTransactionStats(transactionStats));
+        CallbackData callbackData;
+        ASSERT_NO_FATAL_FAILURE(helper.getCallbackData(&callbackData));
+        EXPECT_NO_FATAL_FAILURE(expectedResult.verifyCallbackData(callbackData));
 
         if (finalState) {
             ASSERT_NO_FATAL_FAILURE(helper.verifyFinalState());
