Merge changes from topic "VSyncDispatch_unregisterCallback" into udc-dev

* changes:
  Before deleting callbacks, ensure they aren't running
  Fix for potential deadlock in EventThread::onNewVsyncSchedule
  Make VSyncCallbackRegistration's move operator= call unregisterCallback
diff --git a/services/surfaceflinger/Scheduler/EventThread.cpp b/services/surfaceflinger/Scheduler/EventThread.cpp
index 74665a7..1c0bf0d 100644
--- a/services/surfaceflinger/Scheduler/EventThread.cpp
+++ b/services/surfaceflinger/Scheduler/EventThread.cpp
@@ -692,17 +692,27 @@
 }
 
 void EventThread::onNewVsyncSchedule(std::shared_ptr<scheduler::VsyncSchedule> schedule) {
+    // Hold onto the old registration until after releasing the mutex to avoid deadlock.
+    scheduler::VSyncCallbackRegistration oldRegistration =
+            onNewVsyncScheduleInternal(std::move(schedule));
+}
+
+scheduler::VSyncCallbackRegistration EventThread::onNewVsyncScheduleInternal(
+        std::shared_ptr<scheduler::VsyncSchedule> schedule) {
     std::lock_guard<std::mutex> lock(mMutex);
     const bool reschedule = mVsyncRegistration.cancel() == scheduler::CancelResult::Cancelled;
     mVsyncSchedule = std::move(schedule);
-    mVsyncRegistration =
-            scheduler::VSyncCallbackRegistration(mVsyncSchedule->getDispatch(),
-                                                 createDispatchCallback(), mThreadName);
+    auto oldRegistration =
+            std::exchange(mVsyncRegistration,
+                          scheduler::VSyncCallbackRegistration(mVsyncSchedule->getDispatch(),
+                                                               createDispatchCallback(),
+                                                               mThreadName));
     if (reschedule) {
         mVsyncRegistration.schedule({.workDuration = mWorkDuration.get().count(),
                                      .readyDuration = mReadyDuration.count(),
                                      .earliestVsync = mLastVsyncCallbackTime.ns()});
     }
+    return oldRegistration;
 }
 
 scheduler::VSyncDispatch::Callback EventThread::createDispatchCallback() {
diff --git a/services/surfaceflinger/Scheduler/EventThread.h b/services/surfaceflinger/Scheduler/EventThread.h
index 30869e9..684745b 100644
--- a/services/surfaceflinger/Scheduler/EventThread.h
+++ b/services/surfaceflinger/Scheduler/EventThread.h
@@ -174,7 +174,7 @@
 
     size_t getEventThreadConnectionCount() override;
 
-    void onNewVsyncSchedule(std::shared_ptr<scheduler::VsyncSchedule>) override;
+    void onNewVsyncSchedule(std::shared_ptr<scheduler::VsyncSchedule>) override EXCLUDES(mMutex);
 
 private:
     friend EventThreadTest;
@@ -201,6 +201,11 @@
 
     scheduler::VSyncDispatch::Callback createDispatchCallback();
 
+    // Returns the old registration so it can be destructed outside the lock to
+    // avoid deadlock.
+    scheduler::VSyncCallbackRegistration onNewVsyncScheduleInternal(
+            std::shared_ptr<scheduler::VsyncSchedule>) EXCLUDES(mMutex);
+
     const char* const mThreadName;
     TracedOrdinal<int> mVsyncTracer;
     TracedOrdinal<std::chrono::nanoseconds> mWorkDuration GUARDED_BY(mMutex);
diff --git a/services/surfaceflinger/Scheduler/VSyncDispatch.h b/services/surfaceflinger/Scheduler/VSyncDispatch.h
index 77875e3..c3a952f 100644
--- a/services/surfaceflinger/Scheduler/VSyncDispatch.h
+++ b/services/surfaceflinger/Scheduler/VSyncDispatch.h
@@ -155,10 +155,6 @@
     VSyncDispatch& operator=(const VSyncDispatch&) = delete;
 };
 
-/*
- * Helper class to operate on registered callbacks. It is up to user of the class to ensure
- * that VsyncDispatch lifetime exceeds the lifetime of VSyncCallbackRegistation.
- */
 class VSyncCallbackRegistration {
 public:
     VSyncCallbackRegistration(std::shared_ptr<VSyncDispatch>, VSyncDispatch::Callback,
@@ -178,9 +174,10 @@
     CancelResult cancel();
 
 private:
+    friend class VSyncCallbackRegistrationTest;
+
     std::shared_ptr<VSyncDispatch> mDispatch;
-    VSyncDispatch::CallbackToken mToken;
-    bool mValidToken;
+    std::optional<VSyncDispatch::CallbackToken> mToken;
 };
 
 } // namespace android::scheduler
diff --git a/services/surfaceflinger/Scheduler/VSyncDispatchTimerQueue.cpp b/services/surfaceflinger/Scheduler/VSyncDispatchTimerQueue.cpp
index 26389eb..1f922f1 100644
--- a/services/surfaceflinger/Scheduler/VSyncDispatchTimerQueue.cpp
+++ b/services/surfaceflinger/Scheduler/VSyncDispatchTimerQueue.cpp
@@ -21,12 +21,16 @@
 #include <android-base/stringprintf.h>
 #include <ftl/concat.h>
 #include <utils/Trace.h>
+#include <log/log_main.h>
 
 #include <scheduler/TimeKeeper.h>
 
 #include "VSyncDispatchTimerQueue.h"
 #include "VSyncTracker.h"
 
+#undef LOG_TAG
+#define LOG_TAG "VSyncDispatch"
+
 namespace android::scheduler {
 
 using base::StringAppendF;
@@ -225,6 +229,10 @@
 VSyncDispatchTimerQueue::~VSyncDispatchTimerQueue() {
     std::lock_guard lock(mMutex);
     cancelTimer();
+    for (auto& [_, entry] : mCallbacks) {
+        ALOGE("Forgot to unregister a callback on VSyncDispatch!");
+        entry->ensureNotRunning();
+    }
 }
 
 void VSyncDispatchTimerQueue::cancelTimer() {
@@ -438,47 +446,44 @@
                                                      VSyncDispatch::Callback callback,
                                                      std::string callbackName)
       : mDispatch(std::move(dispatch)),
-        mToken(mDispatch->registerCallback(std::move(callback), std::move(callbackName))),
-        mValidToken(true) {}
+        mToken(mDispatch->registerCallback(std::move(callback), std::move(callbackName))) {}
 
 VSyncCallbackRegistration::VSyncCallbackRegistration(VSyncCallbackRegistration&& other)
-      : mDispatch(std::move(other.mDispatch)),
-        mToken(std::move(other.mToken)),
-        mValidToken(std::move(other.mValidToken)) {
-    other.mValidToken = false;
-}
+      : mDispatch(std::move(other.mDispatch)), mToken(std::exchange(other.mToken, std::nullopt)) {}
 
 VSyncCallbackRegistration& VSyncCallbackRegistration::operator=(VSyncCallbackRegistration&& other) {
+    if (this == &other) return *this;
+    if (mToken) {
+        mDispatch->unregisterCallback(*mToken);
+    }
     mDispatch = std::move(other.mDispatch);
-    mToken = std::move(other.mToken);
-    mValidToken = std::move(other.mValidToken);
-    other.mValidToken = false;
+    mToken = std::exchange(other.mToken, std::nullopt);
     return *this;
 }
 
 VSyncCallbackRegistration::~VSyncCallbackRegistration() {
-    if (mValidToken) mDispatch->unregisterCallback(mToken);
+    if (mToken) mDispatch->unregisterCallback(*mToken);
 }
 
 ScheduleResult VSyncCallbackRegistration::schedule(VSyncDispatch::ScheduleTiming scheduleTiming) {
-    if (!mValidToken) {
+    if (!mToken) {
         return std::nullopt;
     }
-    return mDispatch->schedule(mToken, scheduleTiming);
+    return mDispatch->schedule(*mToken, scheduleTiming);
 }
 
 ScheduleResult VSyncCallbackRegistration::update(VSyncDispatch::ScheduleTiming scheduleTiming) {
-    if (!mValidToken) {
+    if (!mToken) {
         return std::nullopt;
     }
-    return mDispatch->update(mToken, scheduleTiming);
+    return mDispatch->update(*mToken, scheduleTiming);
 }
 
 CancelResult VSyncCallbackRegistration::cancel() {
-    if (!mValidToken) {
+    if (!mToken) {
         return CancelResult::Error;
     }
-    return mDispatch->cancel(mToken);
+    return mDispatch->cancel(*mToken);
 }
 
 } // namespace android::scheduler
diff --git a/services/surfaceflinger/tests/unittests/Android.bp b/services/surfaceflinger/tests/unittests/Android.bp
index 201d37f..3713275 100644
--- a/services/surfaceflinger/tests/unittests/Android.bp
+++ b/services/surfaceflinger/tests/unittests/Android.bp
@@ -128,6 +128,7 @@
         "TransactionTracingTest.cpp",
         "TunnelModeEnabledReporterTest.cpp",
         "StrongTypingTest.cpp",
+        "VSyncCallbackRegistrationTest.cpp",
         "VSyncDispatchTimerQueueTest.cpp",
         "VSyncDispatchRealtimeTest.cpp",
         "VsyncModulatorTest.cpp",
diff --git a/services/surfaceflinger/tests/unittests/VSyncCallbackRegistrationTest.cpp b/services/surfaceflinger/tests/unittests/VSyncCallbackRegistrationTest.cpp
new file mode 100644
index 0000000..69b3861
--- /dev/null
+++ b/services/surfaceflinger/tests/unittests/VSyncCallbackRegistrationTest.cpp
@@ -0,0 +1,144 @@
+/*
+ * Copyright 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#undef LOG_TAG
+#define LOG_TAG "LibSurfaceFlingerUnittests"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "Scheduler/VSyncDispatch.h"
+#include "mock/MockVSyncDispatch.h"
+
+using namespace testing;
+
+namespace android::scheduler {
+
+class VSyncCallbackRegistrationTest : public Test {
+protected:
+    VSyncDispatch::Callback mCallback = [](nsecs_t, nsecs_t, nsecs_t) {};
+
+    std::shared_ptr<mock::VSyncDispatch> mVsyncDispatch = std::make_shared<mock::VSyncDispatch>();
+    VSyncDispatch::CallbackToken mCallbackToken{7};
+    std::string mCallbackName = "callback";
+
+    std::shared_ptr<mock::VSyncDispatch> mVsyncDispatch2 = std::make_shared<mock::VSyncDispatch>();
+    VSyncDispatch::CallbackToken mCallbackToken2{42};
+    std::string mCallbackName2 = "callback2";
+
+    void assertDispatch(const VSyncCallbackRegistration& registration,
+                        std::shared_ptr<VSyncDispatch> dispatch) {
+        ASSERT_EQ(registration.mDispatch, dispatch);
+    }
+
+    void assertToken(const VSyncCallbackRegistration& registration,
+                     const std::optional<VSyncDispatch::CallbackToken>& token) {
+        ASSERT_EQ(registration.mToken, token);
+    }
+};
+
+TEST_F(VSyncCallbackRegistrationTest, unregistersCallbackOnDestruction) {
+    // TODO (b/279581095): With ftl::Function, `_` can be replaced with
+    // `mCallback`, here and in other calls to `registerCallback`, since the
+    // ftl version has an operator==, unlike std::function.
+    EXPECT_CALL(*mVsyncDispatch, registerCallback(_, mCallbackName))
+            .WillOnce(Return(mCallbackToken));
+    EXPECT_CALL(*mVsyncDispatch, unregisterCallback(mCallbackToken)).Times(1);
+
+    VSyncCallbackRegistration registration(mVsyncDispatch, mCallback, mCallbackName);
+    ASSERT_NO_FATAL_FAILURE(assertDispatch(registration, mVsyncDispatch));
+    ASSERT_NO_FATAL_FAILURE(assertToken(registration, mCallbackToken));
+}
+
+TEST_F(VSyncCallbackRegistrationTest, unregistersCallbackOnPointerMove) {
+    {
+        InSequence seq;
+        EXPECT_CALL(*mVsyncDispatch, registerCallback(_, mCallbackName))
+                .WillOnce(Return(mCallbackToken));
+        EXPECT_CALL(*mVsyncDispatch2, registerCallback(_, mCallbackName2))
+                .WillOnce(Return(mCallbackToken2));
+        EXPECT_CALL(*mVsyncDispatch2, unregisterCallback(mCallbackToken2)).Times(1);
+        EXPECT_CALL(*mVsyncDispatch, unregisterCallback(mCallbackToken)).Times(1);
+    }
+
+    auto registration =
+            std::make_unique<VSyncCallbackRegistration>(mVsyncDispatch, mCallback, mCallbackName);
+
+    auto registration2 =
+            std::make_unique<VSyncCallbackRegistration>(mVsyncDispatch2, mCallback, mCallbackName2);
+
+    registration2 = std::move(registration);
+
+    ASSERT_NO_FATAL_FAILURE(assertDispatch(*registration2.get(), mVsyncDispatch));
+    ASSERT_NO_FATAL_FAILURE(assertToken(*registration2.get(), mCallbackToken));
+}
+
+TEST_F(VSyncCallbackRegistrationTest, unregistersCallbackOnMoveOperator) {
+    {
+        InSequence seq;
+        EXPECT_CALL(*mVsyncDispatch, registerCallback(_, mCallbackName))
+                .WillOnce(Return(mCallbackToken));
+        EXPECT_CALL(*mVsyncDispatch2, registerCallback(_, mCallbackName2))
+                .WillOnce(Return(mCallbackToken2));
+        EXPECT_CALL(*mVsyncDispatch2, unregisterCallback(mCallbackToken2)).Times(1);
+        EXPECT_CALL(*mVsyncDispatch, unregisterCallback(mCallbackToken)).Times(1);
+    }
+
+    VSyncCallbackRegistration registration(mVsyncDispatch, mCallback, mCallbackName);
+
+    VSyncCallbackRegistration registration2(mVsyncDispatch2, mCallback, mCallbackName2);
+
+    registration2 = std::move(registration);
+
+    ASSERT_NO_FATAL_FAILURE(assertDispatch(registration, nullptr));
+    ASSERT_NO_FATAL_FAILURE(assertToken(registration, std::nullopt));
+
+    ASSERT_NO_FATAL_FAILURE(assertDispatch(registration2, mVsyncDispatch));
+    ASSERT_NO_FATAL_FAILURE(assertToken(registration2, mCallbackToken));
+}
+
+TEST_F(VSyncCallbackRegistrationTest, moveConstructor) {
+    EXPECT_CALL(*mVsyncDispatch, registerCallback(_, mCallbackName))
+            .WillOnce(Return(mCallbackToken));
+    EXPECT_CALL(*mVsyncDispatch, unregisterCallback(mCallbackToken)).Times(1);
+
+    VSyncCallbackRegistration registration(mVsyncDispatch, mCallback, mCallbackName);
+    VSyncCallbackRegistration registration2(std::move(registration));
+
+    ASSERT_NO_FATAL_FAILURE(assertDispatch(registration, nullptr));
+    ASSERT_NO_FATAL_FAILURE(assertToken(registration, std::nullopt));
+
+    ASSERT_NO_FATAL_FAILURE(assertDispatch(registration2, mVsyncDispatch));
+    ASSERT_NO_FATAL_FAILURE(assertToken(registration2, mCallbackToken));
+}
+
+TEST_F(VSyncCallbackRegistrationTest, moveOperatorEqualsSelf) {
+    EXPECT_CALL(*mVsyncDispatch, registerCallback(_, mCallbackName))
+            .WillOnce(Return(mCallbackToken));
+    EXPECT_CALL(*mVsyncDispatch, unregisterCallback(mCallbackToken)).Times(1);
+
+    VSyncCallbackRegistration registration(mVsyncDispatch, mCallback, mCallbackName);
+
+    // Use a reference so the compiler doesn't diagnose the self-move
+    // assignment below (e.g. clang's -Wself-move).
+    VSyncCallbackRegistration& registrationRef = registration;
+    registration = std::move(registrationRef);
+
+    ASSERT_NO_FATAL_FAILURE(assertDispatch(registration, mVsyncDispatch));
+    ASSERT_NO_FATAL_FAILURE(assertToken(registration, mCallbackToken));
+}
+
+} // namespace android::scheduler