Pass active functor rendering threads to HWUI ADPF session

Test: unit tested

Bug: 329219352

Change-Id: I13e83030e5d0891584ba9d62fe6cc1eb332a7b74
diff --git a/libs/hwui/WebViewFunctorManager.cpp b/libs/hwui/WebViewFunctorManager.cpp
index 5b4ab5f..efa9b11 100644
--- a/libs/hwui/WebViewFunctorManager.cpp
+++ b/libs/hwui/WebViewFunctorManager.cpp
@@ -16,15 +16,16 @@
 
 #include "WebViewFunctorManager.h"
 
+#include <log/log.h>
 #include <private/hwui/WebViewFunctor.h>
+#include <utils/Trace.h>
+
+#include <atomic>
+
 #include "Properties.h"
 #include "renderthread/CanvasContext.h"
 #include "renderthread/RenderThread.h"
 
-#include <log/log.h>
-#include <utils/Trace.h>
-#include <atomic>
-
 namespace android::uirenderer {
 
 namespace {
@@ -265,7 +266,7 @@
 }
 
 void WebViewFunctor::reportRenderingThreads(const int32_t* thread_ids, size_t size) {
-    // TODO(b/329219352): Pass the threads to HWUI and update the ADPF session.
+    mRenderingThreads = std::vector<int32_t>(thread_ids, thread_ids + size);
 }
 
 WebViewFunctorManager& WebViewFunctorManager::instance() {
@@ -365,6 +366,21 @@
     }
 }
 
+std::vector<int32_t> WebViewFunctorManager::getRenderingThreadsForActiveFunctors() {
+    std::vector<int32_t> renderingThreads;
+    std::lock_guard _lock{mLock};
+    for (const auto& iter : mActiveFunctors) {
+        const auto& functorThreads = iter->getRenderingThreads();
+        for (const auto& tid : functorThreads) {
+            if (std::find(renderingThreads.begin(), renderingThreads.end(), tid) ==
+                renderingThreads.end()) {
+                renderingThreads.push_back(tid);
+            }
+        }
+    }
+    return renderingThreads;
+}
+
 sp<WebViewFunctor::Handle> WebViewFunctorManager::handleFor(int functor) {
     std::lock_guard _lock{mLock};
     for (auto& iter : mActiveFunctors) {
diff --git a/libs/hwui/WebViewFunctorManager.h b/libs/hwui/WebViewFunctorManager.h
index 1bf2c1f..2d77dd8 100644
--- a/libs/hwui/WebViewFunctorManager.h
+++ b/libs/hwui/WebViewFunctorManager.h
@@ -60,6 +60,10 @@
 
         void onRemovedFromTree() { mReference.onRemovedFromTree(); }
 
+        const std::vector<int32_t>& getRenderingThreads() const {
+            return mReference.getRenderingThreads();
+        }
+
     private:
         friend class WebViewFunctor;
 
@@ -82,6 +86,7 @@
     void mergeTransaction(ASurfaceTransaction* transaction);
 
     void reportRenderingThreads(const int32_t* thread_ids, size_t size);
+    const std::vector<int32_t>& getRenderingThreads() const { return mRenderingThreads; }
 
     sp<Handle> createHandle() {
         LOG_ALWAYS_FATAL_IF(mCreatedHandle);
@@ -102,6 +107,7 @@
     bool mCreatedHandle = false;
     int32_t mParentSurfaceControlGenerationId = 0;
     ASurfaceControl* mSurfaceControl = nullptr;
+    std::vector<int32_t> mRenderingThreads;
 };
 
 class WebViewFunctorManager {
@@ -113,6 +119,7 @@
     void onContextDestroyed();
     void destroyFunctor(int functor);
     void reportRenderingThreads(int functor, const int32_t* thread_ids, size_t size);
+    std::vector<int32_t> getRenderingThreadsForActiveFunctors();
 
     sp<WebViewFunctor::Handle> handleFor(int functor);
 
diff --git a/libs/hwui/renderthread/CanvasContext.cpp b/libs/hwui/renderthread/CanvasContext.cpp
index abf64d0..1fbd580 100644
--- a/libs/hwui/renderthread/CanvasContext.cpp
+++ b/libs/hwui/renderthread/CanvasContext.cpp
@@ -777,6 +777,8 @@
                                  (std::min(syncDelayDuration, mLastDequeueBufferDuration)) -
                                  dequeueBufferDuration - idleDuration;
         mHintSessionWrapper->reportActualWorkDuration(actualDuration);
+        mHintSessionWrapper->setActiveFunctorThreads(
+                WebViewFunctorManager::instance().getRenderingThreadsForActiveFunctors());
     }
 
     mLastDequeueBufferDuration = dequeueBufferDuration;
diff --git a/libs/hwui/renderthread/HintSessionWrapper.cpp b/libs/hwui/renderthread/HintSessionWrapper.cpp
index 2362331..6993d52 100644
--- a/libs/hwui/renderthread/HintSessionWrapper.cpp
+++ b/libs/hwui/renderthread/HintSessionWrapper.cpp
@@ -20,6 +20,7 @@
 #include <private/performance_hint_private.h>
 #include <utils/Log.h>
 
+#include <algorithm>
 #include <chrono>
 #include <vector>
 
@@ -49,6 +50,7 @@
     BIND_APH_METHOD(updateTargetWorkDuration);
     BIND_APH_METHOD(reportActualWorkDuration);
     BIND_APH_METHOD(sendHint);
+    BIND_APH_METHOD(setThreads);
 
     mInitialized = true;
 }
@@ -67,6 +69,10 @@
         mHintSession = mHintSessionFuture->get();
         mHintSessionFuture = std::nullopt;
     }
+    if (mSetThreadsFuture.has_value()) {
+        mSetThreadsFuture->wait();
+        mSetThreadsFuture = std::nullopt;
+    }
     if (mHintSession) {
         mBinding->closeSession(mHintSession);
         mSessionValid = true;
@@ -106,16 +112,16 @@
     APerformanceHintManager* manager = mBinding->getManager();
     if (!manager) return false;
 
-    std::vector<pid_t> tids = CommonPool::getThreadIds();
-    tids.push_back(mUiThreadId);
-    tids.push_back(mRenderThreadId);
+    mPermanentSessionTids = CommonPool::getThreadIds();
+    mPermanentSessionTids.push_back(mUiThreadId);
+    mPermanentSessionTids.push_back(mRenderThreadId);
 
     // Use the cached target value if there is one, otherwise use a default. This is to ensure
     // the cached target and target in PowerHAL are consistent, and that it updates correctly
     // whenever there is a change.
     int64_t targetDurationNanos =
             mLastTargetWorkDuration == 0 ? kDefaultTargetDuration : mLastTargetWorkDuration;
-    mHintSessionFuture = CommonPool::async([=, this, tids = std::move(tids)] {
+    mHintSessionFuture = CommonPool::async([=, this, tids = mPermanentSessionTids] {
         return mBinding->createSession(manager, tids.data(), tids.size(), targetDurationNanos);
     });
     return false;
@@ -143,6 +149,23 @@
     mLastFrameNotification = systemTime();
 }
 
+void HintSessionWrapper::setActiveFunctorThreads(std::vector<pid_t> threadIds) {
+    if (!init()) return;
+    if (!mBinding || !mHintSession) return;
+    // Sort the vector to make sure they're compared as sets.
+    std::sort(threadIds.begin(), threadIds.end());
+    if (threadIds == mActiveFunctorTids) return;
+    mActiveFunctorTids = std::move(threadIds);
+    std::vector<pid_t> combinedTids = mPermanentSessionTids;
+    std::copy(mActiveFunctorTids.begin(), mActiveFunctorTids.end(),
+              std::back_inserter(combinedTids));
+    mSetThreadsFuture = CommonPool::async([this, tids = std::move(combinedTids)] {
+        int ret = mBinding->setThreads(mHintSession, tids.data(), tids.size());
+        ALOGE_IF(ret != 0, "APerformaceHint_setThreads failed: %d", ret);
+        return ret;
+    });
+}
+
 void HintSessionWrapper::sendLoadResetHint() {
     static constexpr int kMaxResetsSinceLastReport = 2;
     if (!init()) return;
diff --git a/libs/hwui/renderthread/HintSessionWrapper.h b/libs/hwui/renderthread/HintSessionWrapper.h
index 41891cd..14e7a53 100644
--- a/libs/hwui/renderthread/HintSessionWrapper.h
+++ b/libs/hwui/renderthread/HintSessionWrapper.h
@@ -20,6 +20,7 @@
 
 #include <future>
 #include <optional>
+#include <vector>
 
 #include "utils/TimeUtils.h"
 
@@ -47,11 +48,15 @@
     nsecs_t getLastUpdate();
     void delayedDestroy(renderthread::RenderThread& rt, nsecs_t delay,
                         std::shared_ptr<HintSessionWrapper> wrapperPtr);
+    // Must be called on Render thread. Otherwise can cause a race condition.
+    void setActiveFunctorThreads(std::vector<pid_t> threadIds);
 
 private:
     APerformanceHintSession* mHintSession = nullptr;
     // This needs to work concurrently for testing
     std::optional<std::shared_future<APerformanceHintSession*>> mHintSessionFuture;
+    // This needs to work concurrently for testing
+    std::optional<std::shared_future<int>> mSetThreadsFuture;
 
     int mResetsSinceLastReport = 0;
     nsecs_t mLastFrameNotification = 0;
@@ -59,6 +64,8 @@
 
     pid_t mUiThreadId;
     pid_t mRenderThreadId;
+    std::vector<pid_t> mPermanentSessionTids;
+    std::vector<pid_t> mActiveFunctorTids;
 
     bool mSessionValid = true;
 
@@ -82,6 +89,8 @@
         void (*reportActualWorkDuration)(APerformanceHintSession* session,
                                          int64_t actualDuration) = nullptr;
         void (*sendHint)(APerformanceHintSession* session, int32_t hintId) = nullptr;
+        int (*setThreads)(APerformanceHintSession* session, const pid_t* tids,
+                          size_t size) = nullptr;
 
     private:
         bool mInitialized = false;
diff --git a/libs/hwui/tests/unit/HintSessionWrapperTests.cpp b/libs/hwui/tests/unit/HintSessionWrapperTests.cpp
index 10a740a1..c16602c 100644
--- a/libs/hwui/tests/unit/HintSessionWrapperTests.cpp
+++ b/libs/hwui/tests/unit/HintSessionWrapperTests.cpp
@@ -58,6 +58,7 @@
         MOCK_METHOD(void, fakeUpdateTargetWorkDuration, (APerformanceHintSession*, int64_t));
         MOCK_METHOD(void, fakeReportActualWorkDuration, (APerformanceHintSession*, int64_t));
         MOCK_METHOD(void, fakeSendHint, (APerformanceHintSession*, int32_t));
+        MOCK_METHOD(int, fakeSetThreads, (APerformanceHintSession*, const std::vector<pid_t>&));
         // Needs to be on the binding so it can be accessed from static methods
         std::promise<int> allowCreationToFinish;
     };
@@ -102,11 +103,20 @@
     static void stubSendHint(APerformanceHintSession* session, int32_t hintId) {
         sMockBinding->fakeSendHint(session, hintId);
     };
+    static int stubSetThreads(APerformanceHintSession* session, const pid_t* ids, size_t size) {
+        std::vector<pid_t> tids(ids, ids + size);
+        return sMockBinding->fakeSetThreads(session, tids);
+    }
     void waitForWrapperReady() {
         if (mWrapper->mHintSessionFuture.has_value()) {
             mWrapper->mHintSessionFuture->wait();
         }
     }
+    void waitForSetThreadsReady() {
+        if (mWrapper->mSetThreadsFuture.has_value()) {
+            mWrapper->mSetThreadsFuture->wait();
+        }
+    }
     void scheduleDelayedDestroyManaged() {
         TestUtils::runOnRenderThread([&](renderthread::RenderThread& rt) {
             // Guaranteed to be scheduled first, allows destruction to start
@@ -130,6 +140,7 @@
     mWrapper->mBinding = sMockBinding;
     EXPECT_CALL(*sMockBinding, fakeGetManager).WillOnce(Return(managerPtr));
     ON_CALL(*sMockBinding, fakeCreateSession).WillByDefault(Return(sessionPtr));
+    ON_CALL(*sMockBinding, fakeSetThreads).WillByDefault(Return(0));
 }
 
 void HintSessionWrapperTests::MockHintSessionBinding::init() {
@@ -141,6 +152,7 @@
     sMockBinding->updateTargetWorkDuration = &stubUpdateTargetWorkDuration;
     sMockBinding->reportActualWorkDuration = &stubReportActualWorkDuration;
     sMockBinding->sendHint = &stubSendHint;
+    sMockBinding->setThreads = &stubSetThreads;
 }
 
 void HintSessionWrapperTests::TearDown() {
@@ -339,4 +351,44 @@
     EXPECT_EQ(mWrapper->alive(), false);
 }
 
+TEST_F(HintSessionWrapperTests, setThreadsUpdatesSessionThreads) {
+    EXPECT_CALL(*sMockBinding, fakeCreateSession(managerPtr, _, Gt(1), _)).Times(1);
+    EXPECT_CALL(*sMockBinding, fakeSetThreads(sessionPtr, testing::IsSupersetOf({11, 22})))
+            .Times(1);
+    mWrapper->init();
+    waitForWrapperReady();
+
+    // This changes the overall set of threads in the session, so the session wrapper should call
+    // setThreads.
+    mWrapper->setActiveFunctorThreads({11, 22});
+    waitForSetThreadsReady();
+
+    // The set of threads doesn't change, so the session wrapper should not call setThreads this
+    // time. The order of the threads shouldn't matter.
+    mWrapper->setActiveFunctorThreads({22, 11});
+    waitForSetThreadsReady();
+}
+
+TEST_F(HintSessionWrapperTests, setThreadsDoesntCrashAfterDestroy) {
+    EXPECT_CALL(*sMockBinding, fakeCloseSession(sessionPtr)).Times(1);
+
+    mWrapper->init();
+    waitForWrapperReady();
+    // Init a second time just to grab the wrapper from the promise
+    mWrapper->init();
+    EXPECT_EQ(mWrapper->alive(), true);
+
+    // Then, kill the session
+    mWrapper->destroy();
+
+    // Verify it died
+    Mock::VerifyAndClearExpectations(sMockBinding.get());
+    EXPECT_EQ(mWrapper->alive(), false);
+
+    // setActiveFunctorThreads shouldn't do anything, and shouldn't crash.
+    EXPECT_CALL(*sMockBinding, fakeSetThreads(_, _)).Times(0);
+    mWrapper->setActiveFunctorThreads({11, 22});
+    waitForSetThreadsReady();
+}
+
 }  // namespace android::uirenderer::renderthread
\ No newline at end of file