Merge "Implement setThreads APIs for PerformanceHintManager.Session."
diff --git a/core/api/current.txt b/core/api/current.txt
index 9da5b7e..b59e32d 100644
--- a/core/api/current.txt
+++ b/core/api/current.txt
@@ -32519,6 +32519,7 @@
   public static class PerformanceHintManager.Session implements java.io.Closeable {
     method public void close();
     method public void reportActualWorkDuration(long);
+    method public void setThreads(@NonNull int[]);
     method public void updateTargetWorkDuration(long);
   }
 
diff --git a/core/api/test-current.txt b/core/api/test-current.txt
index 1cfd644..82cc3fd 100644
--- a/core/api/test-current.txt
+++ b/core/api/test-current.txt
@@ -1861,6 +1861,7 @@
   }
 
   public static class PerformanceHintManager.Session implements java.io.Closeable {
+    method @Nullable public int[] getThreadIds();
     method public void sendHint(int);
     field public static final int CPU_LOAD_DOWN = 1; // 0x1
     field public static final int CPU_LOAD_RESET = 2; // 0x2
diff --git a/core/java/android/os/IHintManager.aidl b/core/java/android/os/IHintManager.aidl
index 661b95a..d97ea54 100644
--- a/core/java/android/os/IHintManager.aidl
+++ b/core/java/android/os/IHintManager.aidl
@@ -24,10 +24,20 @@
     /**
      * Creates a {@link Session} for the given set of threads and associates to a binder token.
      */
-   IHintSession createHintSession(in IBinder token, in int[] tids, long durationNanos);
+    IHintSession createHintSession(in IBinder token, in int[] tids, long durationNanos);
 
     /**
      * Get preferred rate limit in nano second.
      */
-   long getHintSessionPreferredRate();
+    long getHintSessionPreferredRate();
+
+    /**
+     * Replaces the threads of the given session; applied later if hints are paused.
+     */
+    void setHintSessionThreads(in IHintSession hintSession, in int[] tids);
+
+    /**
+     * Returns the thread ids currently applied to the given session.
+     */
+    int[] getHintSessionThreadIds(in IHintSession hintSession);
 }
diff --git a/core/java/android/os/PerformanceHintManager.java b/core/java/android/os/PerformanceHintManager.java
index 85d6d83..f79d6e6 100644
--- a/core/java/android/os/PerformanceHintManager.java
+++ b/core/java/android/os/PerformanceHintManager.java
@@ -222,16 +222,65 @@
                 Reference.reachabilityFence(this);
             }
         }
+
+        /**
+         * Set a list of threads to the performance hint session. This operation will replace
+         * the current list of threads with the given list of threads.
+         * Note that this is not a oneway method.
+         *
+         * <p>If the session is not currently allowed to send hints (for example, because the
+         * app is not in the foreground), the system keeps the new list and applies it once the
+         * session resumes.
+         *
+         * @param tids The list of threads to be associated with this session. They must be
+         *     part of this app's thread group.
+         *
+         * @throws IllegalArgumentException if the thread id list is empty.
+         * @throws SecurityException if any thread id doesn't belong to the application.
+         */
+        public void setThreads(@NonNull int[] tids) {
+            if (mNativeSessionPtr == 0) {
+                return;
+            }
+            if (tids.length == 0) {
+                throw new IllegalArgumentException("Thread id list can't be empty.");
+            }
+            try {
+                nativeSetThreads(mNativeSessionPtr, tids);
+            } finally {
+                Reference.reachabilityFence(this);
+            }
+        }
+
+        /**
+         * Returns the list of thread ids currently associated with this session, or
+         * {@code null} if this session no longer has a native session (e.g. after close).
+         *
+         * @hide
+         */
+        @TestApi
+        public @Nullable int[] getThreadIds() {
+            if (mNativeSessionPtr == 0) {
+                return null;
+            }
+            try {
+                return nativeGetThreadIds(mNativeSessionPtr);
+            } finally {
+                Reference.reachabilityFence(this);
+            }
+        }
     }
 
     private static native long nativeAcquireManager();
     private static native long nativeGetPreferredUpdateRateNanos(long nativeManagerPtr);
     private static native long nativeCreateSession(long nativeManagerPtr,
             int[] tids, long initialTargetWorkDurationNanos);
+    private static native int[] nativeGetThreadIds(long nativeSessionPtr);
     private static native void nativeUpdateTargetWorkDuration(long nativeSessionPtr,
             long targetDurationNanos);
     private static native void nativeReportActualWorkDuration(long nativeSessionPtr,
             long actualDurationNanos);
     private static native void nativeCloseSession(long nativeSessionPtr);
     private static native void nativeSendHint(long nativeSessionPtr, int hint);
+    private static native void nativeSetThreads(long nativeSessionPtr, int[] tids);
 }
diff --git a/core/jni/android_os_PerformanceHintManager.cpp b/core/jni/android_os_PerformanceHintManager.cpp
index ac1401d..0223b96 100644
--- a/core/jni/android_os_PerformanceHintManager.cpp
+++ b/core/jni/android_os_PerformanceHintManager.cpp
@@ -41,6 +41,8 @@
 typedef void (*APH_reportActualWorkDuration)(APerformanceHintSession*, int64_t);
 typedef void (*APH_closeSession)(APerformanceHintSession* session);
 typedef void (*APH_sendHint)(APerformanceHintSession*, int32_t);
+typedef int (*APH_setThreads)(APerformanceHintSession*, const int32_t*, size_t);
+typedef int (*APH_getThreadIds)(APerformanceHintSession*, int32_t* const, size_t* const);
 
 bool gAPerformanceHintBindingInitialized = false;
 APH_getManager gAPH_getManagerFn = nullptr;
@@ -50,6 +52,8 @@
 APH_reportActualWorkDuration gAPH_reportActualWorkDurationFn = nullptr;
 APH_closeSession gAPH_closeSessionFn = nullptr;
 APH_sendHint gAPH_sendHintFn = nullptr;
+APH_setThreads gAPH_setThreadsFn = nullptr;
+APH_getThreadIds gAPH_getThreadIdsFn = nullptr;
 
 void ensureAPerformanceHintBindingInitialized() {
     if (gAPerformanceHintBindingInitialized) return;
@@ -95,6 +99,14 @@
                         "Failed to find required symbol "
                         "APerformanceHint_sendHint!");
 
+    gAPH_setThreadsFn = (APH_setThreads)dlsym(handle_, "APerformanceHint_setThreads");
+    LOG_ALWAYS_FATAL_IF(gAPH_setThreadsFn == nullptr,
+                        "Failed to find required symbol APerformanceHint_setThreads!");
+
+    gAPH_getThreadIdsFn = (APH_getThreadIds)dlsym(handle_, "APerformanceHint_getThreadIds");
+    LOG_ALWAYS_FATAL_IF(gAPH_getThreadIdsFn == nullptr,
+                        "Failed to find required symbol APerformanceHint_getThreadIds!");
+
     gAPerformanceHintBindingInitialized = true;
 }
 
@@ -150,6 +162,74 @@
     gAPH_sendHintFn(reinterpret_cast<APerformanceHintSession*>(nativeSessionPtr), hint);
 }
 
+// Converts an errno-style result from the APerformanceHint NDK functions into a pending
+// Java exception. Does nothing when err is 0.
+static void throwExceptionForErrno(JNIEnv* env, int err, const char* msg) {
+    switch (err) {
+        case 0:
+            break;
+        case EINVAL:
+            jniThrowException(env, "java/lang/IllegalArgumentException", msg);
+            break;
+        case EPERM:
+            jniThrowException(env, "java/lang/SecurityException", msg);
+            break;
+        default:
+            jniThrowException(env, "java/lang/RuntimeException", msg);
+            break;
+    }
+}
+
+static void nativeSetThreads(JNIEnv* env, jclass clazz, jlong nativeSessionPtr, jintArray tids) {
+    ensureAPerformanceHintBindingInitialized();
+
+    if (tids == nullptr) {
+        return;
+    }
+    ScopedIntArrayRO tidsArray(env, tids);
+    std::vector<int32_t> tidsVector;
+    tidsVector.reserve(tidsArray.size());
+    for (size_t i = 0; i < tidsArray.size(); ++i) {
+        tidsVector.push_back(static_cast<int32_t>(tidsArray[i]));
+    }
+    // Surface NDK failures to Java so the documented IllegalArgumentException and
+    // SecurityException actually reach the caller instead of being dropped.
+    int err = gAPH_setThreadsFn(reinterpret_cast<APerformanceHintSession*>(nativeSessionPtr),
+                                tidsVector.data(), tidsVector.size());
+    throwExceptionForErrno(env, err, "Failed to set threads for performance hint session");
+}
+
+// This call should only be used for validation in tests. It makes two IPC calls: the first
+// one determines the size of the thread ids list, the second one returns the actual list.
+// Returns null without a pending exception if either NDK call reports an error.
+static jintArray nativeGetThreadIds(JNIEnv* env, jclass clazz, jlong nativeSessionPtr) {
+    ensureAPerformanceHintBindingInitialized();
+    auto* session = reinterpret_cast<APerformanceHintSession*>(nativeSessionPtr);
+    size_t size = 0;
+    if (gAPH_getThreadIdsFn(session, nullptr, &size) != 0) {
+        return nullptr;
+    }
+    if (size == 0) {
+        return env->NewIntArray(0);
+    }
+    std::vector<int32_t> tidsVector(size);
+    if (gAPH_getThreadIdsFn(session, tidsVector.data(), &size) != 0) {
+        return nullptr;
+    }
+    // The list can change between the two IPC calls; never read past the buffer above.
+    if (size > tidsVector.size()) {
+        size = tidsVector.size();
+    }
+    jintArray jintArr = env->NewIntArray(static_cast<jsize>(size));
+    if (jintArr == nullptr) {
+        // NewIntArray has already thrown OutOfMemoryError; throwing again is not allowed.
+        return nullptr;
+    }
+    env->SetIntArrayRegion(jintArr, 0, static_cast<jsize>(size),
+                           reinterpret_cast<const jint*>(tidsVector.data()));
+    return jintArr;
+}
+
 static const JNINativeMethod gPerformanceHintMethods[] = {
         {"nativeAcquireManager", "()J", (void*)nativeAcquireManager},
         {"nativeGetPreferredUpdateRateNanos", "(J)J", (void*)nativeGetPreferredUpdateRateNanos},
@@ -158,6 +238,8 @@
         {"nativeReportActualWorkDuration", "(JJ)V", (void*)nativeReportActualWorkDuration},
         {"nativeCloseSession", "(J)V", (void*)nativeCloseSession},
         {"nativeSendHint", "(JI)V", (void*)nativeSendHint},
+        {"nativeSetThreads", "(J[I)V", (void*)nativeSetThreads},
+        {"nativeGetThreadIds", "(J)[I", (void*)nativeGetThreadIds},
 };
 
 int register_android_os_PerformanceHintManager(JNIEnv* env) {
diff --git a/core/tests/coretests/src/android/os/PerformanceHintManagerTest.java b/core/tests/coretests/src/android/os/PerformanceHintManagerTest.java
index 44923b6..7eefbbc 100644
--- a/core/tests/coretests/src/android/os/PerformanceHintManagerTest.java
+++ b/core/tests/coretests/src/android/os/PerformanceHintManagerTest.java
@@ -137,4 +137,15 @@
         assumeNotNull(s);
         s.close();
     }
+
+    @Test
+    public void testSetThreadsWithIllegalArgument() {
+        Session session = createSession();
+        assumeNotNull(session);
+        try {
+            assertThrows(IllegalArgumentException.class, () -> session.setThreads(new int[0]));
+        } finally {
+            session.close();
+        }
+    }
 }
diff --git a/native/android/libandroid.map.txt b/native/android/libandroid.map.txt
index 4e6a0c5..e89c8c9 100644
--- a/native/android/libandroid.map.txt
+++ b/native/android/libandroid.map.txt
@@ -330,6 +330,7 @@
     APerformanceHint_updateTargetWorkDuration; # introduced=Tiramisu
     APerformanceHint_reportActualWorkDuration; # introduced=Tiramisu
     APerformanceHint_closeSession; # introduced=Tiramisu
+    APerformanceHint_setThreads; # introduced=UpsideDownCake
   local:
     *;
 };
@@ -338,6 +339,7 @@
   global:
     APerformanceHint_setIHintManagerForTesting;
     APerformanceHint_sendHint;
+    APerformanceHint_getThreadIds;
     extern "C++" {
         ASurfaceControl_registerSurfaceStatsListener*;
         ASurfaceControl_unregisterSurfaceStatsListener*;
diff --git a/native/android/performance_hint.cpp b/native/android/performance_hint.cpp
index 43b3d2e..dfbd7b5 100644
--- a/native/android/performance_hint.cpp
+++ b/native/android/performance_hint.cpp
@@ -62,18 +62,21 @@
 
 struct APerformanceHintSession {
 public:
-    APerformanceHintSession(sp<IHintSession> session, int64_t preferredRateNanos,
-                            int64_t targetDurationNanos);
+    APerformanceHintSession(sp<IHintManager> hintManager, sp<IHintSession> session,
+                            int64_t preferredRateNanos, int64_t targetDurationNanos);
     APerformanceHintSession() = delete;
     ~APerformanceHintSession();
 
     int updateTargetWorkDuration(int64_t targetDurationNanos);
     int reportActualWorkDuration(int64_t actualDurationNanos);
     int sendHint(int32_t hint);
+    int setThreads(const int32_t* threadIds, size_t size);
+    int getThreadIds(int32_t* const threadIds, size_t* size);
 
 private:
     friend struct APerformanceHintManager;
 
+    sp<IHintManager> mHintManager;
     sp<IHintSession> mHintSession;
     // HAL preferred update rate
     const int64_t mPreferredRateNanos;
@@ -140,7 +143,7 @@
     if (!ret.isOk() || !session) {
         return nullptr;
     }
-    return new APerformanceHintSession(std::move(session), mPreferredRateNanos,
+    return new APerformanceHintSession(mHintManager, std::move(session), mPreferredRateNanos,
                                        initialTargetWorkDurationNanos);
 }
 
@@ -150,10 +153,12 @@
 
 // ===================================== APerformanceHintSession implementation
 
-APerformanceHintSession::APerformanceHintSession(sp<IHintSession> session,
+APerformanceHintSession::APerformanceHintSession(sp<IHintManager> hintManager,
+                                                 sp<IHintSession> session,
                                                  int64_t preferredRateNanos,
                                                  int64_t targetDurationNanos)
-      : mHintSession(std::move(session)),
+      : mHintManager(std::move(hintManager)),
+        mHintSession(std::move(session)),
         mPreferredRateNanos(preferredRateNanos),
         mTargetDurationNanos(targetDurationNanos),
         mFirstTargetMetTimestamp(0),
@@ -260,6 +265,63 @@
     return 0;
 }
 
+int APerformanceHintSession::setThreads(const int32_t* threadIds, size_t size) {
+    if (size == 0) {
+        ALOGE("%s: the list of thread ids must not be empty.", __FUNCTION__);
+        return EINVAL;
+    }
+    if (threadIds == nullptr) {
+        ALOGE("%s: threadIds must not be null.", __FUNCTION__);
+        return EINVAL;
+    }
+    std::vector<int32_t> tids(threadIds, threadIds + size);
+    binder::Status ret = mHintManager->setHintSessionThreads(mHintSession, tids);
+    if (!ret.isOk()) {
+        ALOGE("%s: failed: %s", __FUNCTION__, ret.exceptionMessage().c_str());
+        if (ret.exceptionCode() == binder::Status::Exception::EX_ILLEGAL_ARGUMENT) {
+            return EINVAL;
+        }
+        if (ret.exceptionCode() == binder::Status::Exception::EX_SECURITY) {
+            // Distinguish "thread doesn't belong to the caller" from a malformed request.
+            return EPERM;
+        }
+        return EPIPE;
+    }
+    return 0;
+}
+
+int APerformanceHintSession::getThreadIds(int32_t* const threadIds, size_t* size) {
+    if (size == nullptr) {
+        return EINVAL;
+    }
+    std::vector<int32_t> tids;
+    binder::Status ret = mHintManager->getHintSessionThreadIds(mHintSession, &tids);
+    if (!ret.isOk()) {
+        ALOGE("%s: failed: %s", __FUNCTION__, ret.exceptionMessage().c_str());
+        return EPIPE;
+    }
+
+    // When threadIds is nullptr, this is the first call to determine the size
+    // of the thread ids list.
+    if (threadIds == nullptr) {
+        *size = tids.size();
+        return 0;
+    }
+
+    // Second call to return the actual list. On input *size is the capacity of threadIds;
+    // the list may have grown since the size query, so never overflow the caller's buffer.
+    // *size is left untouched on failure so the caller can query the size again.
+    if (tids.size() > *size) {
+        ALOGE("%s: buffer too small: %zu < %zu", __FUNCTION__, *size, tids.size());
+        return EINVAL;
+    }
+    *size = tids.size();
+    for (size_t i = 0; i < tids.size(); ++i) {
+        threadIds[i] = tids[i];
+    }
+    return 0;
+}
+
 // ===================================== C API
 APerformanceHintManager* APerformanceHint_getManager() {
     return APerformanceHintManager::getInstance();
@@ -293,6 +355,23 @@
     return reinterpret_cast<APerformanceHintSession*>(session)->sendHint(hint);
 }
 
+int APerformanceHint_setThreads(APerformanceHintSession* session, const int32_t* threadIds,
+                                size_t size) {
+    if (session == nullptr) {
+        return EINVAL;
+    }
+    return session->setThreads(threadIds, size);
+}
+
+int APerformanceHint_getThreadIds(void* aPerformanceHintSession, int32_t* const threadIds,
+                                  size_t* const size) {
+    if (aPerformanceHintSession == nullptr) {
+        return EINVAL;
+    }
+    return static_cast<APerformanceHintSession*>(aPerformanceHintSession)
+            ->getThreadIds(threadIds, size);
+}
+
 void APerformanceHint_setIHintManagerForTesting(void* iManager) {
     delete gHintManagerForTesting;
     gHintManagerForTesting = nullptr;
diff --git a/native/android/tests/performance_hint/PerformanceHintNativeTest.cpp b/native/android/tests/performance_hint/PerformanceHintNativeTest.cpp
index 0c2d3b6..321a7dd 100644
--- a/native/android/tests/performance_hint/PerformanceHintNativeTest.cpp
+++ b/native/android/tests/performance_hint/PerformanceHintNativeTest.cpp
@@ -37,10 +37,15 @@
 class MockIHintManager : public IHintManager {
 public:
     MOCK_METHOD(Status, createHintSession,
-                (const ::android::sp<::android::IBinder>& token, const ::std::vector<int32_t>& tids,
-                 int64_t durationNanos, ::android::sp<::android::os::IHintSession>* _aidl_return),
+                (const sp<IBinder>& token, const ::std::vector<int32_t>& tids,
+                 int64_t durationNanos, ::android::sp<IHintSession>* _aidl_return),
                 (override));
     MOCK_METHOD(Status, getHintSessionPreferredRate, (int64_t * _aidl_return), (override));
+    MOCK_METHOD(Status, setHintSessionThreads,
+                (const sp<IHintSession>& hintSession, const ::std::vector<int32_t>& tids),
+                (override));
+    MOCK_METHOD(Status, getHintSessionThreadIds,
+                (const sp<IHintSession>& hintSession, ::std::vector<int32_t>* tids), (override));
     MOCK_METHOD(IBinder*, onAsBinder, (), (override));
 };
 
@@ -141,3 +146,46 @@
     EXPECT_CALL(*iSession, close()).Times(Exactly(1));
     APerformanceHint_closeSession(session);
 }
+
+TEST_F(PerformanceHintTest, SetThreads) {
+    APerformanceHintManager* manager = createManager();
+
+    std::vector<int32_t> tids;
+    tids.push_back(1);
+    tids.push_back(2);
+    int64_t targetDuration = 56789L;
+
+    StrictMock<MockIHintSession>* iSession = new StrictMock<MockIHintSession>();
+    sp<IHintSession> session_sp(iSession);
+
+    EXPECT_CALL(*mMockIHintManager, createHintSession(_, Eq(tids), Eq(targetDuration), _))
+            .Times(Exactly(1))
+            .WillRepeatedly(DoAll(SetArgPointee<3>(std::move(session_sp)), Return(Status())));
+
+    APerformanceHintSession* session =
+            APerformanceHint_createSession(manager, tids.data(), tids.size(), targetDuration);
+    ASSERT_TRUE(session);
+
+    std::vector<int32_t> emptyTids;
+    int result = APerformanceHint_setThreads(session, emptyTids.data(), emptyTids.size());
+    EXPECT_EQ(EINVAL, result);
+
+    std::vector<int32_t> newTids;
+    newTids.push_back(1);
+    newTids.push_back(3);
+    EXPECT_CALL(*mMockIHintManager, setHintSessionThreads(_, Eq(newTids)))
+            .Times(Exactly(1))
+            .WillOnce(Return(Status()));
+    result = APerformanceHint_setThreads(session, newTids.data(), newTids.size());
+    EXPECT_EQ(0, result);
+
+    // A service-side illegal-argument error is reported as EINVAL.
+    EXPECT_CALL(*mMockIHintManager, setHintSessionThreads(_, Eq(newTids)))
+            .Times(Exactly(1))
+            .WillOnce(Return(Status::fromExceptionCode(Status::EX_ILLEGAL_ARGUMENT)));
+    result = APerformanceHint_setThreads(session, newTids.data(), newTids.size());
+    EXPECT_EQ(EINVAL, result);
+
+    EXPECT_CALL(*iSession, close()).Times(Exactly(1));
+    APerformanceHint_closeSession(session);
+}
diff --git a/services/core/java/com/android/server/power/hint/HintManagerService.java b/services/core/java/com/android/server/power/hint/HintManagerService.java
index 952fcdc..a9a1d5e 100644
--- a/services/core/java/com/android/server/power/hint/HintManagerService.java
+++ b/services/core/java/com/android/server/power/hint/HintManagerService.java
@@ -189,6 +189,8 @@
 
         private static native void nativeSendHint(long halPtr, int hint);
 
+        private static native void nativeSetThreads(long halPtr, int[] tids);
+
         private static native long nativeGetHintSessionPreferredRate();
 
         /** Wrapper for HintManager.nativeInit */
@@ -237,6 +239,11 @@
         public long halGetHintSessionPreferredRate() {
             return nativeGetHintSessionPreferredRate();
         }
+
+        /** Wrapper for HintManager.nativeSetThreads */
+        public void halSetThreads(long halPtr, int[] tids) {
+            nativeSetThreads(halPtr, tids);
+        }
     }
 
     @VisibleForTesting
@@ -400,6 +407,27 @@
         }
 
         @Override
+        public void setHintSessionThreads(@NonNull IHintSession hintSession, @NonNull int[] tids) {
+            toAppHintSession(hintSession).setThreads(tids);
+        }
+
+        @Override
+        public int[] getHintSessionThreadIds(@NonNull IHintSession hintSession) {
+            return toAppHintSession(hintSession).getThreadIds();
+        }
+
+        /**
+         * Casts a binder-supplied session to {@link AppHintSession}, rejecting any other
+         * object with a parcelable exception instead of a ClassCastException.
+         */
+        private AppHintSession toAppHintSession(IHintSession hintSession) {
+            if (!(hintSession instanceof AppHintSession)) {
+                throw new IllegalArgumentException("Not a hint session created by this service.");
+            }
+            return (AppHintSession) hintSession;
+        }
+
+        @Override
         public void dump(FileDescriptor fd, PrintWriter pw, String[] args) {
             if (!DumpUtils.checkDumpPermission(getContext(), TAG, pw)) {
                 return;
@@ -434,11 +462,12 @@
     final class AppHintSession extends IHintSession.Stub implements IBinder.DeathRecipient {
         protected final int mUid;
         protected final int mPid;
-        protected final int[] mThreadIds;
+        protected int[] mThreadIds;
         protected final IBinder mToken;
         protected long mHalSessionPtr;
         protected long mTargetDurationNanos;
         protected boolean mUpdateAllowed;
+        protected int[] mNewThreadIds;
 
         protected AppHintSession(
                 int uid, int pid, int[] threadIds, IBinder token,
@@ -541,6 +570,51 @@
             }
         }
 
+        /**
+         * Replaces the threads associated with this session.
+         *
+         * <p>The new thread ids must belong to the calling app. If hint updates are not
+         * currently allowed for this session, the list is stored and applied when the session
+         * resumes. Does nothing if there is no underlying HAL session.
+         *
+         * @throws IllegalArgumentException if {@code tids} is empty.
+         * @throws SecurityException if a thread id doesn't belong to the calling app.
+         */
+        public void setThreads(@NonNull int[] tids) {
+            synchronized (mLock) {
+                if (mHalSessionPtr == 0) {
+                    return;
+                }
+                if (tids.length == 0) {
+                    throw new IllegalArgumentException("Thread id list can't be empty.");
+                }
+                final int callingUid = Binder.getCallingUid();
+                final int callingTgid = Process.getThreadGroupLeader(Binder.getCallingPid());
+                final long identity = Binder.clearCallingIdentity();
+                try {
+                    if (!checkTidValid(callingUid, callingTgid, tids)) {
+                        throw new SecurityException("Some tid doesn't belong to the application.");
+                    }
+                } finally {
+                    Binder.restoreCallingIdentity(identity);
+                }
+                if (!updateHintAllowed()) {
+                    Slogf.v(TAG, "update hint not allowed, storing tids.");
+                    mNewThreadIds = tids;
+                    return;
+                }
+                mNativeWrapper.halSetThreads(mHalSessionPtr, tids);
+                mThreadIds = tids;
+            }
+        }
+
+        /** Returns a copy of the thread ids currently applied to this session. */
+        public int[] getThreadIds() {
+            synchronized (mLock) {
+                return mThreadIds == null ? null : mThreadIds.clone();
+            }
+        }
+
         private void onProcStateChanged() {
             updateHintAllowed();
         }
@@ -556,6 +630,11 @@
             synchronized (mLock) {
                 if (mHalSessionPtr == 0) return;
                 mNativeWrapper.halResumeHintSession(mHalSessionPtr);
+                if (mNewThreadIds != null) {
+                    mNativeWrapper.halSetThreads(mHalSessionPtr, mNewThreadIds);
+                    mThreadIds = mNewThreadIds;
+                    mNewThreadIds = null;
+                }
             }
         }
 
diff --git a/services/core/jni/com_android_server_hint_HintManagerService.cpp b/services/core/jni/com_android_server_hint_HintManagerService.cpp
index d975760..e322fa2 100644
--- a/services/core/jni/com_android_server_hint_HintManagerService.cpp
+++ b/services/core/jni/com_android_server_hint_HintManagerService.cpp
@@ -87,6 +87,11 @@
     appSession->sendHint(hint);
 }
 
+static void setThreads(int64_t session_ptr, const std::vector<int32_t>& threadIds) {
+    sp<IPowerHintSession> appSession = reinterpret_cast<IPowerHintSession*>(session_ptr);
+    appSession->setThreads(threadIds);
+}
+
 static int64_t getHintSessionPreferredRate() {
     int64_t rate = -1;
     auto result = gPowerHalController.getHintSessionPreferredRate();
@@ -149,6 +154,14 @@
     sendHint(session_ptr, static_cast<SessionHint>(hint));
 }
 
+static void nativeSetThreads(JNIEnv* env, jclass /* clazz */, jlong session_ptr, jintArray tids) {
+    // Copy the Java int[] into a std::vector so it can be handed to the HAL session.
+    ScopedIntArrayRO javaTids(env, tids);
+    const jint* first = javaTids.get();
+    const std::vector<int32_t> threadIds(first, first + javaTids.size());
+    setThreads(session_ptr, threadIds);
+}
+
 static jlong nativeGetHintSessionPreferredRate(JNIEnv* /* env */, jclass /* clazz */) {
     return static_cast<jlong>(getHintSessionPreferredRate());
 }
@@ -164,6 +177,7 @@
         {"nativeUpdateTargetWorkDuration", "(JJ)V", (void*)nativeUpdateTargetWorkDuration},
         {"nativeReportActualWorkDuration", "(J[J[J)V", (void*)nativeReportActualWorkDuration},
         {"nativeSendHint", "(JI)V", (void*)nativeSendHint},
+        {"nativeSetThreads", "(J[I)V", (void*)nativeSetThreads},
         {"nativeGetHintSessionPreferredRate", "()J", (void*)nativeGetHintSessionPreferredRate},
 };
 
diff --git a/services/tests/servicestests/src/com/android/server/power/hint/HintManagerServiceTest.java b/services/tests/servicestests/src/com/android/server/power/hint/HintManagerServiceTest.java
index dcbdcdc..136507d 100644
--- a/services/tests/servicestests/src/com/android/server/power/hint/HintManagerServiceTest.java
+++ b/services/tests/servicestests/src/com/android/server/power/hint/HintManagerServiceTest.java
@@ -19,6 +19,7 @@
 
 import static com.google.common.truth.Truth.assertThat;
 
+import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertNotNull;
@@ -310,4 +311,41 @@
                 a.mUid, ActivityManager.PROCESS_STATE_IMPORTANT_FOREGROUND, 0, 0);
         assertTrue(a.updateHintAllowed());
     }
+
+    @Test
+    public void testSetThreads() throws Exception {
+        HintManagerService service = createService();
+        IBinder token = new Binder();
+
+        AppHintSession a = (AppHintSession) service.getBinderServiceInstance()
+                .createHintSession(token, SESSION_TIDS_A, DEFAULT_TARGET_DURATION);
+
+        a.updateTargetWorkDuration(100L);
+
+        assertThrows(IllegalArgumentException.class, () -> {
+            a.setThreads(new int[]{});
+        });
+
+        a.setThreads(SESSION_TIDS_B);
+        verify(mNativeWrapperMock, times(1)).halSetThreads(anyLong(), eq(SESSION_TIDS_B));
+        assertArrayEquals(SESSION_TIDS_B, a.getThreadIds());
+
+        reset(mNativeWrapperMock);
+        // Set session to background, then the threads would not be updated right away.
+        service.mUidObserver.onUidStateChanged(
+                a.mUid, ActivityManager.PROCESS_STATE_TRANSIENT_BACKGROUND, 0, 0);
+        FgThread.getHandler().runWithScissors(() -> { }, 500);
+        assertFalse(a.updateHintAllowed());
+        a.setThreads(SESSION_TIDS_A);
+        verify(mNativeWrapperMock, never()).halSetThreads(anyLong(), any());
+        assertArrayEquals(SESSION_TIDS_B, a.getThreadIds());
+
+        // Once the session is back in the foreground, the pending threads are applied.
+        service.mUidObserver.onUidStateChanged(
+                a.mUid, ActivityManager.PROCESS_STATE_IMPORTANT_FOREGROUND, 0, 0);
+        FgThread.getHandler().runWithScissors(() -> { }, 500);
+        assertTrue(a.updateHintAllowed());
+        verify(mNativeWrapperMock, times(1)).halSetThreads(anyLong(), eq(SESSION_TIDS_A));
+        assertArrayEquals(SESSION_TIDS_A, a.getThreadIds());
+    }
 }