Add SDK support for thermal headroom callback API

* Add cache for headroom forecast which resets on temperature or threshold update
* Remove the cache for thermal headroom thresholds in PowerManager
  as it can change now
* Only trigger headroom callback on skin type throttling event or
  threshold update event that causes significant difference in headrooms

Bug: 360486877
Flag: android.os.allow_thermal_thresholds_callback
Test: atest ThermalManagerServiceTest ThermalManagerServiceMockingTest PowerManagerTest
Change-Id: Id5e311634f3b94fe041e51732496d182b2a78139
diff --git a/core/api/current.txt b/core/api/current.txt
index c31928dc..6931024 100644
--- a/core/api/current.txt
+++ b/core/api/current.txt
@@ -33862,12 +33862,14 @@
   }
 
   public final class PowerManager {
+    method @FlaggedApi("android.os.allow_thermal_thresholds_callback") public void addThermalHeadroomListener(@NonNull android.os.PowerManager.OnThermalHeadroomChangedListener);
+    method @FlaggedApi("android.os.allow_thermal_thresholds_callback") public void addThermalHeadroomListener(@NonNull java.util.concurrent.Executor, @NonNull android.os.PowerManager.OnThermalHeadroomChangedListener);
     method public void addThermalStatusListener(@NonNull android.os.PowerManager.OnThermalStatusChangedListener);
     method public void addThermalStatusListener(@NonNull java.util.concurrent.Executor, @NonNull android.os.PowerManager.OnThermalStatusChangedListener);
     method @Nullable public java.time.Duration getBatteryDischargePrediction();
     method public int getCurrentThermalStatus();
     method public int getLocationPowerSaveMode();
-    method public float getThermalHeadroom(@IntRange(from=0, to=60) int);
+    method @FloatRange(from=0.0f) public float getThermalHeadroom(@IntRange(from=0, to=60) int);
     method @FlaggedApi("android.os.allow_thermal_headroom_thresholds") @NonNull public java.util.Map<java.lang.Integer,java.lang.Float> getThermalHeadroomThresholds();
     method public boolean isAllowedInLowPowerStandby(int);
     method public boolean isAllowedInLowPowerStandby(@NonNull String);
@@ -33885,6 +33887,7 @@
     method public boolean isWakeLockLevelSupported(int);
     method public android.os.PowerManager.WakeLock newWakeLock(int, String);
     method @RequiresPermission(android.Manifest.permission.REBOOT) public void reboot(@Nullable String);
+    method @FlaggedApi("android.os.allow_thermal_thresholds_callback") public void removeThermalHeadroomListener(@NonNull android.os.PowerManager.OnThermalHeadroomChangedListener);
     method public void removeThermalStatusListener(@NonNull android.os.PowerManager.OnThermalStatusChangedListener);
     field @Deprecated @RequiresPermission(value=android.Manifest.permission.TURN_SCREEN_ON, conditional=true) public static final int ACQUIRE_CAUSES_WAKEUP = 268435456; // 0x10000000
     field public static final String ACTION_DEVICE_IDLE_MODE_CHANGED = "android.os.action.DEVICE_IDLE_MODE_CHANGED";
@@ -33917,6 +33920,10 @@
     field public static final int THERMAL_STATUS_SHUTDOWN = 6; // 0x6
   }
 
+  @FlaggedApi("android.os.allow_thermal_thresholds_callback") public static interface PowerManager.OnThermalHeadroomChangedListener {
+    method public void onThermalHeadroomChanged(@FloatRange(from=0.0f) float, @FloatRange(from=0.0f) float, @IntRange(from=0) int, @NonNull java.util.Map<java.lang.Integer,java.lang.Float>);
+  }
+
   public static interface PowerManager.OnThermalStatusChangedListener {
     method public void onThermalStatusChanged(int);
   }
diff --git a/core/java/Android.bp b/core/java/Android.bp
index 9875efe..71623c5 100644
--- a/core/java/Android.bp
+++ b/core/java/Android.bp
@@ -206,6 +206,7 @@
         "android/os/Temperature.aidl",
         "android/os/CoolingDevice.aidl",
         "android/os/IThermalEventListener.aidl",
+        "android/os/IThermalHeadroomListener.aidl",
         "android/os/IThermalStatusListener.aidl",
         "android/os/IThermalService.aidl",
         "android/os/IPowerManager.aidl",
diff --git a/core/java/android/os/IThermalHeadroomListener.aidl b/core/java/android/os/IThermalHeadroomListener.aidl
new file mode 100644
index 0000000..b2797d8
--- /dev/null
+++ b/core/java/android/os/IThermalHeadroomListener.aidl
@@ -0,0 +1,31 @@
+/*
+** Copyright 2024, The Android Open Source Project
+**
+** Licensed under the Apache License, Version 2.0 (the "License");
+** you may not use this file except in compliance with the License.
+** You may obtain a copy of the License at
+**
+**     http://www.apache.org/licenses/LICENSE-2.0
+**
+** Unless required by applicable law or agreed to in writing, software
+** distributed under the License is distributed on an "AS IS" BASIS,
+** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+** See the License for the specific language governing permissions and
+** limitations under the License.
+*/
+
+package android.os;
+
+/**
+ * Listener for thermal headroom and threshold changes.
+ * This is mainly used by {@link android.os.PowerManager} to serve public thermal headoom related
+ * APIs.
+ * {@hide}
+ */
+oneway interface IThermalHeadroomListener {
+    /**
+     * Called when thermal headroom or thresholds changed.
+     */
+    void onHeadroomChange(in float headroom, in float forecastHeadroom,
+                                 in int forecastSeconds, in float[] thresholds);
+}
diff --git a/core/java/android/os/IThermalService.aidl b/core/java/android/os/IThermalService.aidl
index bcffa45..aa3bcfa 100644
--- a/core/java/android/os/IThermalService.aidl
+++ b/core/java/android/os/IThermalService.aidl
@@ -18,6 +18,7 @@
 
 import android.os.CoolingDevice;
 import android.os.IThermalEventListener;
+import android.os.IThermalHeadroomListener;
 import android.os.IThermalStatusListener;
 import android.os.Temperature;
 
@@ -116,4 +117,20 @@
      * @return thermal headroom for each thermal status
      */
     float[] getThermalHeadroomThresholds();
+
+    /**
+      * Register a listener for thermal headroom change.
+      * @param listener the {@link android.os.IThermalHeadroomListener} to be notified.
+      * @return true if registered successfully.
+      * {@hide}
+      */
+    boolean registerThermalHeadroomListener(in IThermalHeadroomListener listener);
+
+    /**
+      * Unregister a previously-registered listener for thermal headroom.
+      * @param listener the {@link android.os.IThermalHeadroomListener} to no longer be notified.
+      * @return true if unregistered successfully.
+      * {@hide}
+      */
+    boolean unregisterThermalHeadroomListener(in IThermalHeadroomListener listener);
 }
diff --git a/core/java/android/os/PowerManager.java b/core/java/android/os/PowerManager.java
index 32db3be..5a1c8b4 100644
--- a/core/java/android/os/PowerManager.java
+++ b/core/java/android/os/PowerManager.java
@@ -20,6 +20,7 @@
 import android.annotation.CallbackExecutor;
 import android.annotation.CurrentTimeMillisLong;
 import android.annotation.FlaggedApi;
+import android.annotation.FloatRange;
 import android.annotation.IntDef;
 import android.annotation.IntRange;
 import android.annotation.NonNull;
@@ -40,6 +41,7 @@
 import android.util.proto.ProtoOutputStream;
 import android.view.Display;
 
+import com.android.internal.annotations.GuardedBy;
 import com.android.internal.util.Preconditions;
 
 import java.lang.annotation.ElementType;
@@ -1191,10 +1193,12 @@
     /** We lazily initialize it.*/
     private PowerExemptionManager mPowerExemptionManager;
 
+    @GuardedBy("mThermalStatusListenerMap")
     private final ArrayMap<OnThermalStatusChangedListener, IThermalStatusListener>
-            mListenerMap = new ArrayMap<>();
-    private final Object mThermalHeadroomThresholdsLock = new Object();
-    private float[] mThermalHeadroomThresholds = null;
+            mThermalStatusListenerMap = new ArrayMap<>();
+    @GuardedBy("mThermalHeadroomListenerMap")
+    private final ArrayMap<OnThermalHeadroomChangedListener, IThermalHeadroomListener>
+            mThermalHeadroomListenerMap = new ArrayMap<>();
 
     /**
      * {@hide}
@@ -2681,15 +2685,59 @@
         void onThermalStatusChanged(@ThermalStatus int status);
     }
 
+    /**
+     * Listener passed to
+     * {@link PowerManager#addThermalHeadroomListener} and
+     * {@link PowerManager#removeThermalHeadroomListener}
+     * to notify caller of Thermal headroom or thresholds changes.
+     */
+    @FlaggedApi(Flags.FLAG_ALLOW_THERMAL_THRESHOLDS_CALLBACK)
+    public interface OnThermalHeadroomChangedListener {
+
+        /**
+         * Called when overall thermal headroom or headroom thresholds have significantly
+         * changed that requires action.
+         * <p>
+         * This may not be used to fully replace the {@link #getThermalHeadroom(int)} API as it will
+         * only notify on one of the conditions below that will significantly change one or both
+         * values of current headroom and headroom thresholds since previous callback:
+         *   1. thermal throttling events: when the skin temperature has cross any of the thresholds
+         *      and there isn't a previous callback in a short time ago with similar values.
+         *   2. skin temperature threshold change events: note that if the absolute °C threshold
+         *      values change in a way that does not significantly change the current headroom nor
+         *      headroom thresholds, it will not trigger any callback. The client should not
+         *      need to take action in such case since the difference from temperature vs threshold
+         *      hasn't changed.
+         * <p>
+         * By API version 36, it provides a forecast in the same call for developer's convenience
+         * based on a {@code forecastSeconds} defined by the device, which can be static or dynamic
+         * varied by OEM. Be aware that it will not notify on forecast temperature change but the
+         * events mentioned above. So periodically polling against {@link #getThermalHeadroom(int)}
+         * API should still be used to actively monitor temperature forecast in advance.
+         * <p>
+         * This serves as a more advanced option compared to thermal status listener, where the
+         * latter will only notify on thermal throttling events with status update.
+         *
+         * @param headroom current headroom
+         * @param forecastHeadroom forecasted headroom in future
+         * @param forecastSeconds how many seconds in the future used in forecast
+         * @param thresholds new headroom thresholds, see {@link #getThermalHeadroomThresholds()}
+         */
+        void onThermalHeadroomChanged(
+                @FloatRange(from = 0f) float headroom,
+                @FloatRange(from = 0f) float forecastHeadroom,
+                @IntRange(from = 0) int forecastSeconds,
+                @NonNull Map<@ThermalStatus Integer, Float> thresholds);
+    }
 
     /**
-     * This function adds a listener for thermal status change, listen call back will be
+     * This function adds a listener for thermal status change, listener callback will be
      * enqueued tasks on the main thread
      *
      * @param listener listener to be added,
      */
     public void addThermalStatusListener(@NonNull OnThermalStatusChangedListener listener) {
-        Objects.requireNonNull(listener, "listener cannot be null");
+        Objects.requireNonNull(listener, "Thermal status listener cannot be null");
         addThermalStatusListener(mContext.getMainExecutor(), listener);
     }
 
@@ -2701,29 +2749,31 @@
      */
     public void addThermalStatusListener(@NonNull @CallbackExecutor Executor executor,
             @NonNull OnThermalStatusChangedListener listener) {
-        Objects.requireNonNull(listener, "listener cannot be null");
-        Objects.requireNonNull(executor, "executor cannot be null");
-        Preconditions.checkArgument(!mListenerMap.containsKey(listener),
-                "Listener already registered: %s", listener);
-        IThermalStatusListener internalListener = new IThermalStatusListener.Stub() {
-            @Override
-            public void onStatusChange(int status) {
-                final long token = Binder.clearCallingIdentity();
-                try {
-                    executor.execute(() -> listener.onThermalStatusChanged(status));
-                } finally {
-                    Binder.restoreCallingIdentity(token);
+        Objects.requireNonNull(listener, "Thermal status listener cannot be null");
+        Objects.requireNonNull(executor, "Executor cannot be null");
+        synchronized (mThermalStatusListenerMap) {
+            Preconditions.checkArgument(!mThermalStatusListenerMap.containsKey(listener),
+                    "Thermal status listener already registered: %s", listener);
+            IThermalStatusListener internalListener = new IThermalStatusListener.Stub() {
+                @Override
+                public void onStatusChange(int status) {
+                    final long token = Binder.clearCallingIdentity();
+                    try {
+                        executor.execute(() -> listener.onThermalStatusChanged(status));
+                    } finally {
+                        Binder.restoreCallingIdentity(token);
+                    }
                 }
+            };
+            try {
+                if (mThermalService.registerThermalStatusListener(internalListener)) {
+                    mThermalStatusListenerMap.put(listener, internalListener);
+                } else {
+                    throw new RuntimeException("Thermal status listener failed to set");
+                }
+            } catch (RemoteException e) {
+                throw e.rethrowFromSystemServer();
             }
-        };
-        try {
-            if (mThermalService.registerThermalStatusListener(internalListener)) {
-                mListenerMap.put(listener, internalListener);
-            } else {
-                throw new RuntimeException("Listener failed to set");
-            }
-        } catch (RemoteException e) {
-            throw e.rethrowFromSystemServer();
         }
     }
 
@@ -2733,20 +2783,101 @@
      * @param listener listener to be removed
      */
     public void removeThermalStatusListener(@NonNull OnThermalStatusChangedListener listener) {
-        Objects.requireNonNull(listener, "listener cannot be null");
-        IThermalStatusListener internalListener = mListenerMap.get(listener);
-        Preconditions.checkArgument(internalListener != null, "Listener was not added");
-        try {
-            if (mThermalService.unregisterThermalStatusListener(internalListener)) {
-                mListenerMap.remove(listener);
-            } else {
-                throw new RuntimeException("Listener failed to remove");
+        Objects.requireNonNull(listener, "Thermal status listener cannot be null");
+        synchronized (mThermalStatusListenerMap) {
+            IThermalStatusListener internalListener = mThermalStatusListenerMap.get(listener);
+            Preconditions.checkArgument(internalListener != null,
+                    "Thermal status listener was not added");
+            try {
+                if (mThermalService.unregisterThermalStatusListener(internalListener)) {
+                    mThermalStatusListenerMap.remove(listener);
+                } else {
+                    throw new RuntimeException("Failed to unregister thermal status listener");
+                }
+            } catch (RemoteException e) {
+                throw e.rethrowFromSystemServer();
             }
-        } catch (RemoteException e) {
-            throw e.rethrowFromSystemServer();
         }
     }
 
+    /**
+     * This function adds a listener for thermal headroom change, listener callback will be
+     * enqueued tasks on the main thread
+     *
+     * @param listener listener to be added,
+     */
+    @FlaggedApi(Flags.FLAG_ALLOW_THERMAL_THRESHOLDS_CALLBACK)
+    public void addThermalHeadroomListener(@NonNull OnThermalHeadroomChangedListener listener) {
+        Objects.requireNonNull(listener, "Thermal headroom listener cannot be null");
+        addThermalHeadroomListener(mContext.getMainExecutor(), listener);
+    }
+
+    /**
+     * This function adds a listener for thermal headroom change.
+     *
+     * @param executor {@link Executor} to handle listener callback.
+     * @param listener listener to be added.
+     */
+    @FlaggedApi(Flags.FLAG_ALLOW_THERMAL_THRESHOLDS_CALLBACK)
+    public void addThermalHeadroomListener(@NonNull @CallbackExecutor Executor executor,
+            @NonNull OnThermalHeadroomChangedListener listener) {
+        Objects.requireNonNull(listener, "Thermal headroom listener cannot be null");
+        Objects.requireNonNull(executor, "Executor cannot be null");
+        synchronized (mThermalHeadroomListenerMap) {
+            Preconditions.checkArgument(!mThermalHeadroomListenerMap.containsKey(listener),
+                    "Thermal headroom listener already registered: %s", listener);
+            IThermalHeadroomListener internalListener = new IThermalHeadroomListener.Stub() {
+                @Override
+                public void onHeadroomChange(float headroom, float forecastHeadroom,
+                        int forecastSeconds, float[] thresholds)
+                        throws RemoteException {
+                    final Map<Integer, Float> map = convertThresholdsToMap(thresholds);
+                    final long token = Binder.clearCallingIdentity();
+                    try {
+                        executor.execute(() -> listener.onThermalHeadroomChanged(headroom,
+                                forecastHeadroom, forecastSeconds, map));
+                    } finally {
+                        Binder.restoreCallingIdentity(token);
+                    }
+                }
+            };
+            try {
+                if (mThermalService.registerThermalHeadroomListener(internalListener)) {
+                    mThermalHeadroomListenerMap.put(listener, internalListener);
+                } else {
+                    throw new RuntimeException("Thermal headroom listener failed to set");
+                }
+            } catch (RemoteException e) {
+                throw e.rethrowFromSystemServer();
+            }
+        }
+    }
+
+    /**
+     * This function removes a listener for Thermal headroom change
+     *
+     * @param listener listener to be removed
+     */
+    @FlaggedApi(Flags.FLAG_ALLOW_THERMAL_THRESHOLDS_CALLBACK)
+    public void removeThermalHeadroomListener(@NonNull OnThermalHeadroomChangedListener listener) {
+        Objects.requireNonNull(listener, "Thermal headroom listener cannot be null");
+        synchronized (mThermalHeadroomListenerMap) {
+            IThermalHeadroomListener internalListener = mThermalHeadroomListenerMap.get(listener);
+            Preconditions.checkArgument(internalListener != null,
+                    "Thermal headroom listener was not added");
+            try {
+                if (mThermalService.unregisterThermalHeadroomListener(internalListener)) {
+                    mThermalHeadroomListenerMap.remove(listener);
+                } else {
+                    throw new RuntimeException("Failed to unregister thermal status listener");
+                }
+            } catch (RemoteException e) {
+                throw e.rethrowFromSystemServer();
+            }
+        }
+    }
+
+
     @CurrentTimeMillisLong
     private final AtomicLong mLastHeadroomUpdate = new AtomicLong(0L);
     private static final int MINIMUM_HEADROOM_TIME_MILLIS = 500;
@@ -2786,7 +2917,8 @@
      *         functionality or if this function is called significantly faster than once per
      *         second.
      */
-    public float getThermalHeadroom(@IntRange(from = 0, to = 60) int forecastSeconds) {
+    public @FloatRange(from = 0f) float getThermalHeadroom(
+            @IntRange(from = 0, to = 60) int forecastSeconds) {
         // Rate-limit calls into the thermal service
         long now = SystemClock.elapsedRealtime();
         long timeSinceLastUpdate = now - mLastHeadroomUpdate.get();
@@ -2831,9 +2963,11 @@
      * headroom of 0.75 will never come with {@link #THERMAL_STATUS_MODERATE} but lower, and 0.65
      * will never come with {@link #THERMAL_STATUS_LIGHT} but {@link #THERMAL_STATUS_NONE}.
      * <p>
-     * The returned map of thresholds will not change between calls to this function, so it's
-     * best to call this once on initialization. Modifying the result will not change the thresholds
-     * cached by the system, and a new call to the API will get a new copy.
+     * Starting at {@link android.os.Build.VERSION_CODES#BAKLAVA} the returned map of thresholds can
+     * change between calls to this function, one could use the new
+     * {@link #addThermalHeadroomListener(Executor, OnThermalHeadroomChangedListener)} API to
+     * register a listener and get callback for changes to thresholds.
+     * <p>
      *
      * @return map from each thermal status to its thermal headroom
      * @throws IllegalStateException if the thermal service is not ready
@@ -2842,24 +2976,22 @@
     @FlaggedApi(Flags.FLAG_ALLOW_THERMAL_HEADROOM_THRESHOLDS)
     public @NonNull Map<@ThermalStatus Integer, Float> getThermalHeadroomThresholds() {
         try {
-            synchronized (mThermalHeadroomThresholdsLock) {
-                if (mThermalHeadroomThresholds == null) {
-                    mThermalHeadroomThresholds = mThermalService.getThermalHeadroomThresholds();
-                }
-                final ArrayMap<Integer, Float> ret = new ArrayMap<>(THERMAL_STATUS_SHUTDOWN);
-                for (int status = THERMAL_STATUS_LIGHT; status <= THERMAL_STATUS_SHUTDOWN;
-                        status++) {
-                    if (!Float.isNaN(mThermalHeadroomThresholds[status])) {
-                        ret.put(status, mThermalHeadroomThresholds[status]);
-                    }
-                }
-                return ret;
-            }
+            return convertThresholdsToMap(mThermalService.getThermalHeadroomThresholds());
         } catch (RemoteException e) {
             throw e.rethrowFromSystemServer();
         }
     }
 
+    private Map<@ThermalStatus Integer, Float> convertThresholdsToMap(final float[] thresholds) {
+        final ArrayMap<Integer, Float> ret = new ArrayMap<>(THERMAL_STATUS_SHUTDOWN);
+        for (int status = THERMAL_STATUS_LIGHT; status <= THERMAL_STATUS_SHUTDOWN; status++) {
+            if (!Float.isNaN(thresholds[status])) {
+                ret.put(status, thresholds[status]);
+            }
+        }
+        return ret;
+    }
+
     /**
      * If true, the doze component is not started until after the screen has been
      * turned off and the screen off animation has been performed.
diff --git a/core/tests/coretests/src/android/os/PowerManagerTest.java b/core/tests/coretests/src/android/os/PowerManagerTest.java
index 3b27fc0..e4e965f 100644
--- a/core/tests/coretests/src/android/os/PowerManagerTest.java
+++ b/core/tests/coretests/src/android/os/PowerManagerTest.java
@@ -21,6 +21,8 @@
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.timeout;
@@ -61,9 +63,13 @@
     private UiDevice mUiDevice;
     private Executor mExec = Executors.newSingleThreadExecutor();
     @Mock
-    private PowerManager.OnThermalStatusChangedListener mListener1;
+    private PowerManager.OnThermalStatusChangedListener mStatusListener1;
     @Mock
-    private PowerManager.OnThermalStatusChangedListener mListener2;
+    private PowerManager.OnThermalStatusChangedListener mStatusListener2;
+    @Mock
+    private PowerManager.OnThermalHeadroomChangedListener mHeadroomListener1;
+    @Mock
+    private PowerManager.OnThermalHeadroomChangedListener mHeadroomListener2;
     private static final long CALLBACK_TIMEOUT_MILLI_SEC = 5000;
     private native Parcel nativeObtainPowerSaveStateParcel(boolean batterySaverEnabled,
             boolean globalBatterySaverEnabled, int locationMode, int soundTriggerMode,
@@ -245,53 +251,90 @@
         // Initial override status is THERMAL_STATUS_NONE
         int status = PowerManager.THERMAL_STATUS_NONE;
         // Add listener1
-        mPm.addThermalStatusListener(mExec, mListener1);
-        verify(mListener1, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+        mPm.addThermalStatusListener(mExec, mStatusListener1);
+        verify(mStatusListener1, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
                 .times(1)).onThermalStatusChanged(status);
-        reset(mListener1);
+        reset(mStatusListener1);
         status = PowerManager.THERMAL_STATUS_SEVERE;
         mUiDevice.executeShellCommand("cmd thermalservice override-status "
                 + Integer.toString(status));
-        verify(mListener1, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+        verify(mStatusListener1, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
                 .times(1)).onThermalStatusChanged(status);
-        reset(mListener1);
+        reset(mStatusListener1);
         // Add listener1 again
         try {
-            mPm.addThermalStatusListener(mListener1);
+            mPm.addThermalStatusListener(mStatusListener1);
             fail("Expected exception not thrown");
         } catch (IllegalArgumentException expectedException) {
         }
         // Add listener2 on main thread.
-        mPm.addThermalStatusListener(mListener2);
-        verify(mListener2, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+        mPm.addThermalStatusListener(mStatusListener2);
+        verify(mStatusListener2, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
             .times(1)).onThermalStatusChanged(status);
-        reset(mListener2);
+        reset(mStatusListener2);
         status = PowerManager.THERMAL_STATUS_MODERATE;
         mUiDevice.executeShellCommand("cmd thermalservice override-status "
                 + Integer.toString(status));
-        verify(mListener1, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+        verify(mStatusListener1, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
                 .times(1)).onThermalStatusChanged(status);
-        verify(mListener2, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+        verify(mStatusListener2, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
                 .times(1)).onThermalStatusChanged(status);
-        reset(mListener1);
-        reset(mListener2);
+        reset(mStatusListener1);
+        reset(mStatusListener2);
         // Remove listener1
-        mPm.removeThermalStatusListener(mListener1);
+        mPm.removeThermalStatusListener(mStatusListener1);
         // Remove listener1 again
         try {
-            mPm.removeThermalStatusListener(mListener1);
+            mPm.removeThermalStatusListener(mStatusListener1);
             fail("Expected exception not thrown");
         } catch (IllegalArgumentException expectedException) {
         }
         status = PowerManager.THERMAL_STATUS_LIGHT;
         mUiDevice.executeShellCommand("cmd thermalservice override-status "
                 + Integer.toString(status));
-        verify(mListener1, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+        verify(mStatusListener1, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
                 .times(0)).onThermalStatusChanged(status);
-        verify(mListener2, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+        verify(mStatusListener2, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
                 .times(1)).onThermalStatusChanged(status);
     }
 
+    /**
+     * Confirm that we can add/remove thermal headroom listener.
+     */
+    @Test
+    @RequiresFlagsEnabled(Flags.FLAG_ALLOW_THERMAL_THRESHOLDS_CALLBACK)
+    public void testThermalHeadroomCallback() throws Exception {
+        float headroom = mPm.getThermalHeadroom(0);
+        // If the device doesn't support thermal headroom, return early
+        if (Float.isNaN(headroom)) {
+            return;
+        }
+        // Add listener1
+        mPm.addThermalHeadroomListener(mExec, mHeadroomListener1);
+        verify(mHeadroomListener1, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+                .times(1)).onThermalHeadroomChanged(anyInt(), anyInt(), anyInt(), any());
+        reset(mHeadroomListener1);
+        // Add listener1 again
+        try {
+            mPm.addThermalHeadroomListener(mHeadroomListener1);
+            fail("Expected exception not thrown");
+        } catch (IllegalArgumentException expectedException) {
+        }
+        // Add listener2 on main thread.
+        mPm.addThermalHeadroomListener(mHeadroomListener2);
+        verify(mHeadroomListener2, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+                .times(1)).onThermalHeadroomChanged(anyInt(), anyInt(), anyInt(), any());
+        reset(mHeadroomListener2);
+        // Remove listener1
+        mPm.removeThermalHeadroomListener(mHeadroomListener1);
+        // Remove listener1 again
+        try {
+            mPm.removeThermalHeadroomListener(mHeadroomListener1);
+            fail("Expected exception not thrown");
+        } catch (IllegalArgumentException expectedException) {
+        }
+    }
+
     @Test
     public void testGetThermalHeadroom() throws Exception {
         float headroom = mPm.getThermalHeadroom(0);
diff --git a/native/android/tests/thermal/NativeThermalUnitTest.cpp b/native/android/tests/thermal/NativeThermalUnitTest.cpp
index 6d6861a..4e319fc 100644
--- a/native/android/tests/thermal/NativeThermalUnitTest.cpp
+++ b/native/android/tests/thermal/NativeThermalUnitTest.cpp
@@ -67,6 +67,14 @@
     MOCK_METHOD(Status, getThermalHeadroomThresholds, (::std::vector<float> * _aidl_return),
                 (override));
     MOCK_METHOD(IBinder*, onAsBinder, (), (override));
+    MOCK_METHOD(Status, registerThermalHeadroomListener,
+                (const ::android::sp<::android::os::IThermalHeadroomListener>& listener,
+                 bool* _aidl_return),
+                (override));
+    MOCK_METHOD(Status, unregisterThermalHeadroomListener,
+                (const ::android::sp<::android::os::IThermalHeadroomListener>& listener,
+                 bool* _aidl_return),
+                (override));
 };
 
 class NativeThermalUnitTest : public Test {
diff --git a/services/core/java/com/android/server/power/ThermalManagerService.java b/services/core/java/com/android/server/power/ThermalManagerService.java
index 78bc06c..42dbb79 100644
--- a/services/core/java/com/android/server/power/ThermalManagerService.java
+++ b/services/core/java/com/android/server/power/ThermalManagerService.java
@@ -43,6 +43,7 @@
 import android.os.HwBinder;
 import android.os.IBinder;
 import android.os.IThermalEventListener;
+import android.os.IThermalHeadroomListener;
 import android.os.IThermalService;
 import android.os.IThermalStatusListener;
 import android.os.PowerManager;
@@ -59,6 +60,7 @@
 import android.util.ArrayMap;
 import android.util.EventLog;
 import android.util.Slog;
+import android.util.SparseArray;
 import android.util.StatsEvent;
 
 import com.android.internal.annotations.GuardedBy;
@@ -96,6 +98,15 @@
     /** Input range limits for getThermalHeadroom API */
     public static final int MIN_FORECAST_SEC = 0;
     public static final int MAX_FORECAST_SEC = 60;
+    public static final int DEFAULT_FORECAST_SECONDS = 10;
+    public static final int HEADROOM_CALLBACK_MIN_INTERVAL_MILLIS = 5000;
+    // headroom to temperature conversion: 3C every 0.1 headroom difference
+    // if no throttling event, the temperature difference should be at least 0.9C (or 0.03 headroom)
+    // to make a callback
+    public static final float HEADROOM_CALLBACK_MIN_DIFFERENCE = 0.03f;
+    // if no throttling event, the threshold headroom difference should be at least 0.01 (or 0.3C)
+    // to make a callback
+    public static final float HEADROOM_THRESHOLD_CALLBACK_MIN_DIFFERENCE = 0.01f;
 
     /** Lock to protect listen list. */
     private final Object mLock = new Object();
@@ -113,6 +124,15 @@
     private final RemoteCallbackList<IThermalStatusListener> mThermalStatusListeners =
             new RemoteCallbackList<>();
 
+    /** Registered observers of the thermal headroom. */
+    @GuardedBy("mLock")
+    private final RemoteCallbackList<IThermalHeadroomListener> mThermalHeadroomListeners =
+            new RemoteCallbackList<>();
+    @GuardedBy("mLock")
+    private long mLastHeadroomCallbackTimeMillis;
+    @GuardedBy("mLock")
+    private HeadroomCallbackData mLastHeadroomCallbackData = null;
+
     /** Current thermal status */
     @GuardedBy("mLock")
     private int mStatus;
@@ -133,7 +153,7 @@
 
     /** Watches temperatures to forecast when throttling will occur */
     @VisibleForTesting
-    final TemperatureWatcher mTemperatureWatcher = new TemperatureWatcher();
+    final TemperatureWatcher mTemperatureWatcher;
 
     private final ThermalHalWrapper.WrapperThermalChangedCallback mWrapperCallback =
             new ThermalHalWrapper.WrapperThermalChangedCallback() {
@@ -151,8 +171,14 @@
                 public void onThresholdChanged(TemperatureThreshold threshold) {
                     final long token = Binder.clearCallingIdentity();
                     try {
+                        final HeadroomCallbackData data;
                         synchronized (mTemperatureWatcher.mSamples) {
+                            Slog.d(TAG, "Updating skin threshold: " + threshold);
                             mTemperatureWatcher.updateTemperatureThresholdLocked(threshold, true);
+                            data = mTemperatureWatcher.getHeadroomCallbackDataLocked();
+                        }
+                        synchronized (mLock) {
+                            checkAndNotifyHeadroomListenersLocked(data);
                         }
                     } finally {
                         Binder.restoreCallingIdentity(token);
@@ -175,6 +201,7 @@
             halWrapper.setCallback(mWrapperCallback);
         }
         mStatus = Temperature.THROTTLING_NONE;
+        mTemperatureWatcher = new TemperatureWatcher();
     }
 
     @Override
@@ -231,32 +258,79 @@
         }
     }
 
-    private void postStatusListener(IThermalStatusListener listener) {
+    @GuardedBy("mLock")
+    private void postStatusListenerLocked(IThermalStatusListener listener) {
         final boolean thermalCallbackQueued = FgThread.getHandler().post(() -> {
             try {
                 listener.onStatusChange(mStatus);
             } catch (RemoteException | RuntimeException e) {
-                Slog.e(TAG, "Thermal callback failed to call", e);
+                Slog.e(TAG, "Thermal status callback failed to call", e);
             }
         });
         if (!thermalCallbackQueued) {
-            Slog.e(TAG, "Thermal callback failed to queue");
+            Slog.e(TAG, "Thermal status callback failed to queue");
         }
     }
 
+    @GuardedBy("mLock")
     private void notifyStatusListenersLocked() {
         final int length = mThermalStatusListeners.beginBroadcast();
         try {
             for (int i = 0; i < length; i++) {
                 final IThermalStatusListener listener =
                         mThermalStatusListeners.getBroadcastItem(i);
-                postStatusListener(listener);
+                postStatusListenerLocked(listener);
             }
         } finally {
             mThermalStatusListeners.finishBroadcast();
         }
     }
 
+    @GuardedBy("mLock")
+    private void postHeadroomListenerLocked(IThermalHeadroomListener listener,
+            HeadroomCallbackData data) {
+        if (!mHalReady.get()) {
+            return;
+        }
+        final boolean thermalCallbackQueued = FgThread.getHandler().post(() -> {
+            try {
+                if (Float.isNaN(data.mHeadroom)) {
+                    return;
+                }
+                listener.onHeadroomChange(data.mHeadroom, data.mForecastHeadroom,
+                        data.mForecastSeconds, data.mHeadroomThresholds);
+            } catch (RemoteException | RuntimeException e) {
+                Slog.e(TAG, "Thermal headroom callback failed to call", e);
+            }
+        });
+        if (!thermalCallbackQueued) {
+            Slog.e(TAG, "Thermal headroom callback failed to queue");
+        }
+    }
+
+    @GuardedBy("mLock")
+    private void checkAndNotifyHeadroomListenersLocked(HeadroomCallbackData data) {
+        if (!data.isSignificantDifferentFrom(mLastHeadroomCallbackData)
+                && System.currentTimeMillis()
+                < mLastHeadroomCallbackTimeMillis + HEADROOM_CALLBACK_MIN_INTERVAL_MILLIS) {
+            // skip notifying the client with similar data within a short period
+            return;
+        }
+        mLastHeadroomCallbackTimeMillis = System.currentTimeMillis();
+        mLastHeadroomCallbackData = data;
+        final int length = mThermalHeadroomListeners.beginBroadcast();
+        try {
+            for (int i = 0; i < length; i++) {
+                final IThermalHeadroomListener listener =
+                        mThermalHeadroomListeners.getBroadcastItem(i);
+                postHeadroomListenerLocked(listener, data);
+            }
+        } finally {
+            mThermalHeadroomListeners.finishBroadcast();
+        }
+    }
+
+    @GuardedBy("mLock")
     private void onTemperatureMapChangedLocked() {
         int newStatus = Temperature.THROTTLING_NONE;
         final int count = mTemperatureMap.size();
@@ -272,6 +346,7 @@
         }
     }
 
+    @GuardedBy("mLock")
     private void setStatusLocked(int newStatus) {
         if (newStatus != mStatus) {
             Trace.traceCounter(Trace.TRACE_TAG_POWER, "ThermalManagerService.status", newStatus);
@@ -280,18 +355,18 @@
         }
     }
 
-    private void postEventListenerCurrentTemperatures(IThermalEventListener listener,
+    @GuardedBy("mLock")
+    private void postEventListenerCurrentTemperaturesLocked(IThermalEventListener listener,
             @Nullable Integer type) {
-        synchronized (mLock) {
-            final int count = mTemperatureMap.size();
-            for (int i = 0; i < count; i++) {
-                postEventListener(mTemperatureMap.valueAt(i), listener,
-                        type);
-            }
+        final int count = mTemperatureMap.size();
+        for (int i = 0; i < count; i++) {
+            postEventListenerLocked(mTemperatureMap.valueAt(i), listener,
+                    type);
         }
     }
 
-    private void postEventListener(Temperature temperature,
+    @GuardedBy("mLock")
+    private void postEventListenerLocked(Temperature temperature,
             IThermalEventListener listener,
             @Nullable Integer type) {
         // Skip if listener registered with a different type
@@ -302,14 +377,15 @@
             try {
                 listener.notifyThrottling(temperature);
             } catch (RemoteException | RuntimeException e) {
-                Slog.e(TAG, "Thermal callback failed to call", e);
+                Slog.e(TAG, "Thermal event callback failed to call", e);
             }
         });
         if (!thermalCallbackQueued) {
-            Slog.e(TAG, "Thermal callback failed to queue");
+            Slog.e(TAG, "Thermal event callback failed to queue");
         }
     }
 
+    @GuardedBy("mLock")
     private void notifyEventListenersLocked(Temperature temperature) {
         final int length = mThermalEventListeners.beginBroadcast();
         try {
@@ -318,7 +394,7 @@
                         mThermalEventListeners.getBroadcastItem(i);
                 final Integer type =
                         (Integer) mThermalEventListeners.getBroadcastCookie(i);
-                postEventListener(temperature, listener, type);
+                postEventListenerLocked(temperature, listener, type);
             }
         } finally {
             mThermalEventListeners.finishBroadcast();
@@ -348,17 +424,31 @@
         }
     }
 
-    private void onTemperatureChanged(Temperature temperature, boolean sendStatus) {
+    private void onTemperatureChanged(Temperature temperature, boolean sendCallback) {
         shutdownIfNeeded(temperature);
         synchronized (mLock) {
             Temperature old = mTemperatureMap.put(temperature.getName(), temperature);
             if (old == null || old.getStatus() != temperature.getStatus()) {
                 notifyEventListenersLocked(temperature);
             }
-            if (sendStatus) {
+            if (sendCallback) {
                 onTemperatureMapChangedLocked();
             }
         }
+        if (sendCallback && Flags.allowThermalThresholdsCallback()
+                && temperature.getType() == Temperature.TYPE_SKIN) {
+            final HeadroomCallbackData data;
+            synchronized (mTemperatureWatcher.mSamples) {
+                Slog.d(TAG, "Updating new temperature: " + temperature);
+                mTemperatureWatcher.updateTemperatureSampleLocked(System.currentTimeMillis(),
+                        temperature);
+                mTemperatureWatcher.mCachedHeadrooms.clear();
+                data = mTemperatureWatcher.getHeadroomCallbackDataLocked();
+            }
+            synchronized (mLock) {
+                checkAndNotifyHeadroomListenersLocked(data);
+            }
+        }
     }
 
     private void registerStatsCallbacks() {
@@ -399,7 +489,7 @@
                         return false;
                     }
                     // Notify its callback after new client registered.
-                    postEventListenerCurrentTemperatures(listener, null);
+                    postEventListenerCurrentTemperaturesLocked(listener, null);
                     return true;
                 } finally {
                     Binder.restoreCallingIdentity(token);
@@ -415,11 +505,11 @@
             synchronized (mLock) {
                 final long token = Binder.clearCallingIdentity();
                 try {
-                    if (!mThermalEventListeners.register(listener, new Integer(type))) {
+                    if (!mThermalEventListeners.register(listener, type)) {
                         return false;
                     }
                     // Notify its callback after new client registered.
-                    postEventListenerCurrentTemperatures(listener, new Integer(type));
+                    postEventListenerCurrentTemperaturesLocked(listener, type);
                     return true;
                 } finally {
                     Binder.restoreCallingIdentity(token);
@@ -484,7 +574,7 @@
                         return false;
                     }
                     // Notify its callback after new client registered.
-                    postStatusListener(listener);
+                    postStatusListenerLocked(listener);
                     return true;
                 } finally {
                     Binder.restoreCallingIdentity(token);
@@ -557,11 +647,50 @@
         }
 
         @Override
+        public boolean registerThermalHeadroomListener(IThermalHeadroomListener listener) {
+            if (!mHalReady.get()) {
+                return false;
+            }
+            synchronized (mLock) {
+                // Notify its callback after new client registered.
+                final long token = Binder.clearCallingIdentity();
+                try {
+                    if (!mThermalHeadroomListeners.register(listener)) {
+                        return false;
+                    }
+                } finally {
+                    Binder.restoreCallingIdentity(token);
+                }
+            }
+            final HeadroomCallbackData data;
+            synchronized (mTemperatureWatcher.mSamples) {
+                data = mTemperatureWatcher.getHeadroomCallbackDataLocked();
+            }
+            // Notify its callback after new client registered.
+            synchronized (mLock) {
+                postHeadroomListenerLocked(listener, data);
+            }
+            return true;
+        }
+
+        @Override
+        public boolean unregisterThermalHeadroomListener(IThermalHeadroomListener listener) {
+            synchronized (mLock) {
+                final long token = Binder.clearCallingIdentity();
+                try {
+                    return mThermalHeadroomListeners.unregister(listener);
+                } finally {
+                    Binder.restoreCallingIdentity(token);
+                }
+            }
+        }
+
+        @Override
         public float getThermalHeadroom(int forecastSeconds) {
             if (!mHalReady.get()) {
                 FrameworkStatsLog.write(FrameworkStatsLog.THERMAL_HEADROOM_CALLED, getCallingUid(),
-                            FrameworkStatsLog.THERMAL_HEADROOM_CALLED__API_STATUS__HAL_NOT_READY,
-                            Float.NaN, forecastSeconds);
+                        FrameworkStatsLog.THERMAL_HEADROOM_CALLED__API_STATUS__HAL_NOT_READY,
+                        Float.NaN, forecastSeconds);
                 return Float.NaN;
             }
 
@@ -570,8 +699,8 @@
                     Slog.d(TAG, "Invalid forecastSeconds: " + forecastSeconds);
                 }
                 FrameworkStatsLog.write(FrameworkStatsLog.THERMAL_HEADROOM_CALLED, getCallingUid(),
-                            FrameworkStatsLog.THERMAL_HEADROOM_CALLED__API_STATUS__INVALID_ARGUMENT,
-                            Float.NaN, forecastSeconds);
+                        FrameworkStatsLog.THERMAL_HEADROOM_CALLED__API_STATUS__INVALID_ARGUMENT,
+                        Float.NaN, forecastSeconds);
                 return Float.NaN;
             }
 
@@ -592,13 +721,10 @@
                         THERMAL_HEADROOM_THRESHOLDS_CALLED__API_STATUS__FEATURE_NOT_SUPPORTED);
                 throw new UnsupportedOperationException("Thermal headroom thresholds not enabled");
             }
-            synchronized (mTemperatureWatcher.mSamples) {
-                FrameworkStatsLog.write(FrameworkStatsLog.THERMAL_HEADROOM_THRESHOLDS_CALLED,
-                        Binder.getCallingUid(),
-                        THERMAL_HEADROOM_THRESHOLDS_CALLED__API_STATUS__SUCCESS);
-                return Arrays.copyOf(mTemperatureWatcher.mHeadroomThresholds,
-                        mTemperatureWatcher.mHeadroomThresholds.length);
-            }
+            FrameworkStatsLog.write(FrameworkStatsLog.THERMAL_HEADROOM_THRESHOLDS_CALLED,
+                    Binder.getCallingUid(),
+                    THERMAL_HEADROOM_THRESHOLDS_CALLED__API_STATUS__SUCCESS);
+            return mTemperatureWatcher.getHeadroomThresholds();
         }
 
         @Override
@@ -711,7 +837,7 @@
     class ThermalShellCommand extends ShellCommand {
         @Override
         public int onCommand(String cmd) {
-            switch(cmd != null ? cmd : "") {
+            switch (cmd != null ? cmd : "") {
                 case "inject-temperature":
                     return runInjectTemperature();
                 case "override-status":
@@ -1112,7 +1238,8 @@
         }
 
         @Override
-        @NonNull protected List<TemperatureThreshold> getTemperatureThresholds(
+        @NonNull
+        protected List<TemperatureThreshold> getTemperatureThresholds(
                 boolean shouldFilter, int type) {
             synchronized (mHalLock) {
                 final List<TemperatureThreshold> ret = new ArrayList<>();
@@ -1631,14 +1758,68 @@
         }
     }
 
+    private static final class HeadroomCallbackData {
+        float mHeadroom;
+        float mForecastHeadroom;
+        int mForecastSeconds;
+        float[] mHeadroomThresholds;
+
+        HeadroomCallbackData(float headroom, float forecastHeadroom, int forecastSeconds,
+                @NonNull float[] headroomThresholds) {
+            mHeadroom = headroom;
+            mForecastHeadroom = forecastHeadroom;
+            mForecastSeconds = forecastSeconds;
+            mHeadroomThresholds = headroomThresholds;
+        }
+
+        private boolean isSignificantDifferentFrom(HeadroomCallbackData other) {
+            if (other == null) return true;
+            // currently this is always the same as DEFAULT_FORECAST_SECONDS, when it's retried
+            // from thermal HAL, we may want to adjust this.
+            if (this.mForecastSeconds != other.mForecastSeconds) return true;
+            if (Math.abs(this.mHeadroom - other.mHeadroom)
+                    >= HEADROOM_CALLBACK_MIN_DIFFERENCE) return true;
+            if (Math.abs(this.mForecastHeadroom - other.mForecastHeadroom)
+                    >= HEADROOM_CALLBACK_MIN_DIFFERENCE) return true;
+            for (int i = 0; i < this.mHeadroomThresholds.length; i++) {
+                if (Float.isNaN(this.mHeadroomThresholds[i]) != Float.isNaN(
+                        other.mHeadroomThresholds[i])) {
+                    return true;
+                }
+                if (Math.abs(this.mHeadroomThresholds[i] - other.mHeadroomThresholds[i])
+                        >= HEADROOM_THRESHOLD_CALLBACK_MIN_DIFFERENCE) {
+                    return true;
+                }
+            }
+            return false;
+        }
+
+        @Override
+        public String toString() {
+            return "HeadroomCallbackData[mHeadroom=" + mHeadroom + ", mForecastHeadroom="
+                    + mForecastHeadroom + ", mForecastSeconds=" + mForecastSeconds
+                    + ", mHeadroomThresholds=" + Arrays.toString(mHeadroomThresholds) + "]";
+        }
+    }
+
     @VisibleForTesting
     class TemperatureWatcher {
+        private static final int RING_BUFFER_SIZE = 30;
+        private static final int INACTIVITY_THRESHOLD_MILLIS = 10000;
+        @VisibleForTesting
+        long mInactivityThresholdMillis = INACTIVITY_THRESHOLD_MILLIS;
+
         private final Handler mHandler = BackgroundThread.getHandler();
 
-        /** Map of skin temperature sensor name to a corresponding list of samples */
+        /**
+         * Map of skin temperature sensor name to a corresponding list of samples
+         * Updates to the samples should also clear the headroom cache.
+         */
         @GuardedBy("mSamples")
         @VisibleForTesting
         final ArrayMap<String, ArrayList<Sample>> mSamples = new ArrayMap<>();
+        @GuardedBy("mSamples")
+        private final SparseArray<Float> mCachedHeadrooms = new SparseArray<>(2);
 
         /** Map of skin temperature sensor name to the corresponding SEVERE temperature threshold */
         @GuardedBy("mSamples")
@@ -1650,13 +1831,9 @@
         @GuardedBy("mSamples")
         private long mLastForecastCallTimeMillis = 0;
 
-        private static final int INACTIVITY_THRESHOLD_MILLIS = 10000;
-        @VisibleForTesting
-        long mInactivityThresholdMillis = INACTIVITY_THRESHOLD_MILLIS;
-
         void getAndUpdateThresholds() {
             List<TemperatureThreshold> thresholds =
-                        mHalWrapper.getTemperatureThresholds(true, Temperature.TYPE_SKIN);
+                    mHalWrapper.getTemperatureThresholds(true, Temperature.TYPE_SKIN);
             synchronized (mSamples) {
                 if (Flags.allowThermalHeadroomThresholds()) {
                     Arrays.fill(mHeadroomThresholds, Float.NaN);
@@ -1684,6 +1861,8 @@
                 return;
             }
             if (override) {
+                Slog.d(TAG, "Headroom cache cleared on threshold update " + threshold);
+                mCachedHeadrooms.clear();
                 Arrays.fill(mHeadroomThresholds, Float.NaN);
             }
             for (int severity = ThrottlingSeverity.LIGHT;
@@ -1693,62 +1872,61 @@
                     if (Float.isNaN(t)) {
                         continue;
                     }
-                    synchronized (mSamples) {
-                        if (severity == ThrottlingSeverity.SEVERE) {
-                            mHeadroomThresholds[severity] = 1.0f;
-                            continue;
-                        }
-                        float headroom = normalizeTemperature(t, severeThreshold);
-                        if (Float.isNaN(mHeadroomThresholds[severity])) {
-                            mHeadroomThresholds[severity] = headroom;
-                        } else {
-                            float lastHeadroom = mHeadroomThresholds[severity];
-                            mHeadroomThresholds[severity] = Math.min(lastHeadroom, headroom);
-                        }
+                    if (severity == ThrottlingSeverity.SEVERE) {
+                        mHeadroomThresholds[severity] = 1.0f;
+                        continue;
+                    }
+                    float headroom = normalizeTemperature(t, severeThreshold);
+                    if (Float.isNaN(mHeadroomThresholds[severity])) {
+                        mHeadroomThresholds[severity] = headroom;
+                    } else {
+                        float lastHeadroom = mHeadroomThresholds[severity];
+                        mHeadroomThresholds[severity] = Math.min(lastHeadroom, headroom);
                     }
                 }
             }
         }
 
-        private static final int RING_BUFFER_SIZE = 30;
-
-        private void updateTemperature() {
+        private void getAndUpdateTemperatureSamples() {
             synchronized (mSamples) {
                 if (SystemClock.elapsedRealtime() - mLastForecastCallTimeMillis
                         < mInactivityThresholdMillis) {
                     // Trigger this again after a second as long as forecast has been called more
                     // recently than the inactivity timeout
-                    mHandler.postDelayed(this::updateTemperature, 1000);
+                    mHandler.postDelayed(this::getAndUpdateTemperatureSamples, 1000);
                 } else {
                     // Otherwise, we've been idle for at least 10 seconds, so we should
                     // shut down
                     mSamples.clear();
+                    mCachedHeadrooms.clear();
                     return;
                 }
 
                 long now = SystemClock.elapsedRealtime();
-                List<Temperature> temperatures = mHalWrapper.getCurrentTemperatures(true,
+                final List<Temperature> temperatures = mHalWrapper.getCurrentTemperatures(true,
                         Temperature.TYPE_SKIN);
-
-                for (int t = 0; t < temperatures.size(); ++t) {
-                    Temperature temperature = temperatures.get(t);
-
-                    // Filter out invalid temperatures. If this results in no values being stored at
-                    // all, the mSamples.empty() check in getForecast() will catch it.
-                    if (Float.isNaN(temperature.getValue())) {
-                        continue;
-                    }
-
-                    ArrayList<Sample> samples = mSamples.computeIfAbsent(temperature.getName(),
-                            k -> new ArrayList<>(RING_BUFFER_SIZE));
-                    if (samples.size() == RING_BUFFER_SIZE) {
-                        samples.removeFirst();
-                    }
-                    samples.add(new Sample(now, temperature.getValue()));
+                for (Temperature temperature : temperatures) {
+                    updateTemperatureSampleLocked(now, temperature);
                 }
+                mCachedHeadrooms.clear();
             }
         }
 
+        @GuardedBy("mSamples")
+        private void updateTemperatureSampleLocked(long timeNow, Temperature temperature) {
+            // Filter out invalid temperatures. If this results in no values being stored at
+            // all, the mSamples.empty() check in getForecast() will catch it.
+            if (Float.isNaN(temperature.getValue())) {
+                return;
+            }
+            ArrayList<Sample> samples = mSamples.computeIfAbsent(temperature.getName(),
+                    k -> new ArrayList<>(RING_BUFFER_SIZE));
+            if (samples.size() == RING_BUFFER_SIZE) {
+                samples.removeFirst();
+            }
+            samples.add(new Sample(timeNow, temperature.getValue()));
+        }
+
         /**
          * Calculates the trend using a linear regression. As the samples are degrees Celsius with
          * associated timestamps in milliseconds, the slope is in degrees Celsius per millisecond.
@@ -1801,7 +1979,7 @@
             synchronized (mSamples) {
                 mLastForecastCallTimeMillis = SystemClock.elapsedRealtime();
                 if (mSamples.isEmpty()) {
-                    updateTemperature();
+                    getAndUpdateTemperatureSamples();
                 }
 
                 // If somehow things take much longer than expected or there are no temperatures
@@ -1826,6 +2004,14 @@
                     return Float.NaN;
                 }
 
+                if (mCachedHeadrooms.contains(forecastSeconds)) {
+                    // TODO(b/360486877): replace with metrics
+                    Slog.d(TAG,
+                            "Headroom forecast in " + forecastSeconds + "s served from cache: "
+                                    + mCachedHeadrooms.get(forecastSeconds));
+                    return mCachedHeadrooms.get(forecastSeconds);
+                }
+
                 float maxNormalized = Float.NaN;
                 int noThresholdSampleCount = 0;
                 for (Map.Entry<String, ArrayList<Sample>> entry : mSamples.entrySet()) {
@@ -1842,6 +2028,12 @@
                     float currentTemperature = samples.getLast().temperature;
 
                     if (samples.size() < MINIMUM_SAMPLE_COUNT) {
+                        if (mSamples.size() == 1 && mCachedHeadrooms.contains(0)) {
+                            // if only one sensor name exists, then try reading the cache
+                            // TODO(b/360486877): replace with metrics
+                            Slog.d(TAG, "Headroom forecast cached: " + mCachedHeadrooms.get(0));
+                            return mCachedHeadrooms.get(0);
+                        }
                         // Don't try to forecast, just use the latest one we have
                         float normalized = normalizeTemperature(currentTemperature, threshold);
                         if (Float.isNaN(maxNormalized) || normalized > maxNormalized) {
@@ -1849,8 +2041,10 @@
                         }
                         continue;
                     }
-
-                    float slope = getSlopeOf(samples);
+                    float slope = 0.0f;
+                    if (forecastSeconds > 0) {
+                        slope = getSlopeOf(samples);
+                    }
                     float normalized = normalizeTemperature(
                             currentTemperature + slope * forecastSeconds * 1000, threshold);
                     if (Float.isNaN(maxNormalized) || normalized > maxNormalized) {
@@ -1868,10 +2062,28 @@
                             FrameworkStatsLog.THERMAL_HEADROOM_CALLED__API_STATUS__SUCCESS,
                             maxNormalized, forecastSeconds);
                 }
+                mCachedHeadrooms.put(forecastSeconds, maxNormalized);
                 return maxNormalized;
             }
         }
 
+        float[] getHeadroomThresholds() {
+            synchronized (mSamples) {
+                return Arrays.copyOf(mHeadroomThresholds, mHeadroomThresholds.length);
+            }
+        }
+
+        @GuardedBy("mSamples")
+        HeadroomCallbackData getHeadroomCallbackDataLocked() {
+            final HeadroomCallbackData data = new HeadroomCallbackData(
+                    getForecast(0),
+                    getForecast(DEFAULT_FORECAST_SECONDS),
+                    DEFAULT_FORECAST_SECONDS,
+                    Arrays.copyOf(mHeadroomThresholds, mHeadroomThresholds.length));
+            Slog.d(TAG, "New headroom callback data: " + data);
+            return data;
+        }
+
         @VisibleForTesting
         // Since Sample is inside an inner class, we can't make it static
         // This allows test code to create Sample objects via ThermalManagerService
@@ -1880,7 +2092,7 @@
         }
 
         @VisibleForTesting
-        class Sample {
+        static class Sample {
             public long time;
             public float temperature;
 
@@ -1888,6 +2100,11 @@
                 this.time = time;
                 this.temperature = temperature;
             }
+
+            @Override
+            public String toString() {
+                return "Sample[temperature=" + temperature + ", time=" + time + "]";
+            }
         }
     }
 }
diff --git a/services/tests/servicestests/src/com/android/server/power/ThermalManagerServiceTest.java b/services/tests/servicestests/src/com/android/server/power/ThermalManagerServiceTest.java
index cfe3d84..2ed71ce 100644
--- a/services/tests/servicestests/src/com/android/server/power/ThermalManagerServiceTest.java
+++ b/services/tests/servicestests/src/com/android/server/power/ThermalManagerServiceTest.java
@@ -22,10 +22,12 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertNotSame;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.AdditionalMatchers.aryEq;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyFloat;
 import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.reset;
@@ -39,14 +41,18 @@
 import android.hardware.thermal.TemperatureThreshold;
 import android.hardware.thermal.ThrottlingSeverity;
 import android.os.CoolingDevice;
+import android.os.Flags;
 import android.os.IBinder;
 import android.os.IPowerManager;
 import android.os.IThermalEventListener;
+import android.os.IThermalHeadroomListener;
 import android.os.IThermalService;
 import android.os.IThermalStatusListener;
 import android.os.PowerManager;
 import android.os.RemoteException;
 import android.os.Temperature;
+import android.platform.test.annotations.EnableFlags;
+import android.platform.test.flag.junit.SetFlagsRule;
 
 import androidx.test.filters.SmallTest;
 import androidx.test.runner.AndroidJUnit4;
@@ -56,6 +62,8 @@
 import com.android.server.power.ThermalManagerService.ThermalHalWrapper;
 
 import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
@@ -78,6 +86,11 @@
 @SmallTest
 @RunWith(AndroidJUnit4.class)
 public class ThermalManagerServiceTest {
+    @ClassRule
+    public static final SetFlagsRule.ClassRule mSetFlagsClassRule = new SetFlagsRule.ClassRule();
+    @Rule
+    public final SetFlagsRule mSetFlagsRule = mSetFlagsClassRule.createSetFlagsRule();
+
     private static final long CALLBACK_TIMEOUT_MILLI_SEC = 5000;
     private ThermalManagerService mService;
     private ThermalHalFake mFakeHal;
@@ -89,6 +102,8 @@
     @Mock
     private IThermalService mIThermalServiceMock;
     @Mock
+    private IThermalHeadroomListener mHeadroomListener;
+    @Mock
     private IThermalEventListener mEventListener1;
     @Mock
     private IThermalEventListener mEventListener2;
@@ -102,22 +117,23 @@
      */
     private class ThermalHalFake extends ThermalHalWrapper {
         private static final int INIT_STATUS = Temperature.THROTTLING_NONE;
-        private ArrayList<Temperature> mTemperatureList = new ArrayList<>();
-        private ArrayList<CoolingDevice> mCoolingDeviceList = new ArrayList<>();
-        private ArrayList<TemperatureThreshold> mTemperatureThresholdList = initializeThresholds();
+        private List<Temperature> mTemperatureList = new ArrayList<>();
+        private List<Temperature> mOverrideTemperatures = null;
+        private List<CoolingDevice> mCoolingDeviceList = new ArrayList<>();
+        private List<TemperatureThreshold> mTemperatureThresholdList = initializeThresholds();
 
-        private Temperature mSkin1 = new Temperature(0, Temperature.TYPE_SKIN, "skin1",
+        private Temperature mSkin1 = new Temperature(28, Temperature.TYPE_SKIN, "skin1",
                 INIT_STATUS);
-        private Temperature mSkin2 = new Temperature(0, Temperature.TYPE_SKIN, "skin2",
+        private Temperature mSkin2 = new Temperature(31, Temperature.TYPE_SKIN, "skin2",
                 INIT_STATUS);
-        private Temperature mBattery = new Temperature(0, Temperature.TYPE_BATTERY, "batt",
+        private Temperature mBattery = new Temperature(34, Temperature.TYPE_BATTERY, "batt",
                 INIT_STATUS);
-        private Temperature mUsbPort = new Temperature(0, Temperature.TYPE_USB_PORT, "usbport",
+        private Temperature mUsbPort = new Temperature(37, Temperature.TYPE_USB_PORT, "usbport",
                 INIT_STATUS);
-        private CoolingDevice mCpu = new CoolingDevice(0, CoolingDevice.TYPE_BATTERY, "cpu");
-        private CoolingDevice mGpu = new CoolingDevice(0, CoolingDevice.TYPE_BATTERY, "gpu");
+        private CoolingDevice mCpu = new CoolingDevice(40, CoolingDevice.TYPE_BATTERY, "cpu");
+        private CoolingDevice mGpu = new CoolingDevice(43, CoolingDevice.TYPE_BATTERY, "gpu");
 
-        private ArrayList<TemperatureThreshold> initializeThresholds() {
+        private List<TemperatureThreshold> initializeThresholds() {
             ArrayList<TemperatureThreshold> thresholds = new ArrayList<>();
 
             TemperatureThreshold skinThreshold = new TemperatureThreshold();
@@ -157,6 +173,14 @@
             mCoolingDeviceList.add(mGpu);
         }
 
+        void setOverrideTemperatures(List<Temperature> temperatures) {
+            mOverrideTemperatures = temperatures;
+        }
+
+        void resetOverrideTemperatures() {
+            mOverrideTemperatures = null;
+        }
+
         @Override
         protected List<Temperature> getCurrentTemperatures(boolean shouldFilter, int type) {
             List<Temperature> ret = new ArrayList<>();
@@ -221,22 +245,36 @@
         when(mContext.getSystemService(PowerManager.class)).thenReturn(mPowerManager);
         resetListenerMock();
         mService = new ThermalManagerService(mContext, mFakeHal);
-        // Register callbacks before AMS ready and no callback sent
+        mService.onBootPhase(SystemService.PHASE_ACTIVITY_MANAGER_READY);
+    }
+
+    private void resetListenerMock() {
+        reset(mEventListener1);
+        reset(mStatusListener1);
+        reset(mEventListener2);
+        reset(mStatusListener2);
+        reset(mHeadroomListener);
+        doReturn(mock(IBinder.class)).when(mEventListener1).asBinder();
+        doReturn(mock(IBinder.class)).when(mStatusListener1).asBinder();
+        doReturn(mock(IBinder.class)).when(mEventListener2).asBinder();
+        doReturn(mock(IBinder.class)).when(mStatusListener2).asBinder();
+        doReturn(mock(IBinder.class)).when(mHeadroomListener).asBinder();
+    }
+
+    @Test
+    public void testRegister() throws Exception {
+        mService = new ThermalManagerService(mContext, mFakeHal);
+        // Register callbacks before AMS ready and verify they are called after AMS is ready
         assertTrue(mService.mService.registerThermalEventListener(mEventListener1));
         assertTrue(mService.mService.registerThermalStatusListener(mStatusListener1));
         assertTrue(mService.mService.registerThermalEventListenerWithType(mEventListener2,
                 Temperature.TYPE_SKIN));
         assertTrue(mService.mService.registerThermalStatusListener(mStatusListener2));
-        verify(mEventListener1, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
-                .times(0)).notifyThrottling(any(Temperature.class));
-        verify(mStatusListener1, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
-                .times(1)).onStatusChange(Temperature.THROTTLING_NONE);
-        verify(mEventListener2, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
-                .times(0)).notifyThrottling(any(Temperature.class));
-        verify(mStatusListener2, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
-                .times(1)).onStatusChange(Temperature.THROTTLING_NONE);
+        Thread.sleep(CALLBACK_TIMEOUT_MILLI_SEC);
         resetListenerMock();
         mService.onBootPhase(SystemService.PHASE_ACTIVITY_MANAGER_READY);
+        assertTrue(mService.mService.registerThermalHeadroomListener(mHeadroomListener));
+
         ArgumentCaptor<Temperature> captor = ArgumentCaptor.forClass(Temperature.class);
         verify(mEventListener1, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
                 .times(4)).notifyThrottling(captor.capture());
@@ -251,31 +289,18 @@
                 captor.getAllValues());
         verify(mStatusListener2, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
                 .times(0)).onStatusChange(Temperature.THROTTLING_NONE);
-    }
-
-    private void resetListenerMock() {
-        reset(mEventListener1);
-        reset(mStatusListener1);
-        reset(mEventListener2);
-        reset(mStatusListener2);
-        doReturn(mock(IBinder.class)).when(mEventListener1).asBinder();
-        doReturn(mock(IBinder.class)).when(mStatusListener1).asBinder();
-        doReturn(mock(IBinder.class)).when(mEventListener2).asBinder();
-        doReturn(mock(IBinder.class)).when(mStatusListener2).asBinder();
-    }
-
-    @Test
-    public void testRegister() throws RemoteException {
         resetListenerMock();
-        // Register callbacks and verify they are called
+
+        // Register callbacks after AMS ready and verify they are called
         assertTrue(mService.mService.registerThermalEventListener(mEventListener1));
         assertTrue(mService.mService.registerThermalStatusListener(mStatusListener1));
-        ArgumentCaptor<Temperature> captor = ArgumentCaptor.forClass(Temperature.class);
+        captor = ArgumentCaptor.forClass(Temperature.class);
         verify(mEventListener1, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
                 .times(4)).notifyThrottling(captor.capture());
         assertListEqualsIgnoringOrder(mFakeHal.mTemperatureList, captor.getAllValues());
         verify(mStatusListener1, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
                 .times(1)).onStatusChange(Temperature.THROTTLING_NONE);
+
         // Register new callbacks and verify old ones are not called (remained same) while new
         // ones are called
         assertTrue(mService.mService.registerThermalEventListenerWithType(mEventListener2,
@@ -296,7 +321,15 @@
     }
 
     @Test
-    public void testNotifyThrottling() throws RemoteException {
+    public void testNotifyThrottling() throws Exception {
+        assertTrue(mService.mService.registerThermalEventListener(mEventListener1));
+        assertTrue(mService.mService.registerThermalStatusListener(mStatusListener1));
+        assertTrue(mService.mService.registerThermalEventListenerWithType(mEventListener2,
+                Temperature.TYPE_SKIN));
+        assertTrue(mService.mService.registerThermalStatusListener(mStatusListener2));
+        Thread.sleep(CALLBACK_TIMEOUT_MILLI_SEC);
+        resetListenerMock();
+
         int status = Temperature.THROTTLING_SEVERE;
         // Should only notify event not status
         Temperature newBattery = new Temperature(50, Temperature.TYPE_BATTERY, "batt", status);
@@ -349,6 +382,57 @@
     }
 
     @Test
+    @EnableFlags({Flags.FLAG_ALLOW_THERMAL_THRESHOLDS_CALLBACK})
+    public void testNotifyThrottling_headroomCallback() throws Exception {
+        assertTrue(mService.mService.registerThermalHeadroomListener(mHeadroomListener));
+        Thread.sleep(CALLBACK_TIMEOUT_MILLI_SEC);
+        resetListenerMock();
+        int status = Temperature.THROTTLING_SEVERE;
+        mFakeHal.setOverrideTemperatures(new ArrayList<>());
+
+        // Should not notify on non-skin type
+        Temperature newBattery = new Temperature(37, Temperature.TYPE_BATTERY, "batt", status);
+        mFakeHal.mCallback.onTemperatureChanged(newBattery);
+        verify(mHeadroomListener, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+                .times(0)).onHeadroomChange(anyFloat(), anyFloat(), anyInt(), any());
+        resetListenerMock();
+
+        // Notify headroom on skin temperature change
+        Temperature newSkin = new Temperature(37, Temperature.TYPE_SKIN, "skin1", status);
+        mFakeHal.mCallback.onTemperatureChanged(newSkin);
+        verify(mHeadroomListener, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+                .times(1)).onHeadroomChange(eq(0.9f), anyFloat(), anyInt(),
+                eq(new float[]{Float.NaN, 0.6666667f, 0.8333333f, 1.0f, 1.1666666f, 1.3333334f,
+                        1.5f}));
+        resetListenerMock();
+
+        // Same or similar temperature should not trigger in a short period
+        mFakeHal.mCallback.onTemperatureChanged(newSkin);
+        newSkin = new Temperature(36.9f, Temperature.TYPE_SKIN, "skin1", status);
+        mFakeHal.mCallback.onTemperatureChanged(newSkin);
+        newSkin = new Temperature(37.1f, Temperature.TYPE_SKIN, "skin1", status);
+        mFakeHal.mCallback.onTemperatureChanged(newSkin);
+        verify(mHeadroomListener, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+                .times(0)).onHeadroomChange(anyFloat(), anyFloat(), anyInt(), any());
+        resetListenerMock();
+
+        // Significant temperature should trigger in a short period
+        newSkin = new Temperature(34f, Temperature.TYPE_SKIN, "skin1", status);
+        mFakeHal.mCallback.onTemperatureChanged(newSkin);
+        verify(mHeadroomListener, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+                .times(1)).onHeadroomChange(eq(0.8f), anyFloat(), anyInt(),
+                eq(new float[]{Float.NaN, 0.6666667f, 0.8333333f, 1.0f, 1.1666666f, 1.3333334f,
+                        1.5f}));
+        resetListenerMock();
+        newSkin = new Temperature(40f, Temperature.TYPE_SKIN, "skin1", status);
+        mFakeHal.mCallback.onTemperatureChanged(newSkin);
+        verify(mHeadroomListener, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+                .times(1)).onHeadroomChange(eq(1.0f), anyFloat(), anyInt(),
+                eq(new float[]{Float.NaN, 0.6666667f, 0.8333333f, 1.0f, 1.1666666f, 1.3333334f,
+                        1.5f}));
+    }
+
+    @Test
     public void testGetCurrentTemperatures() throws RemoteException {
         assertListEqualsIgnoringOrder(mFakeHal.getCurrentTemperatures(false, 0),
                 Arrays.asList(mService.mService.getCurrentTemperatures()));
@@ -388,13 +472,28 @@
         // Do no call onActivityManagerReady to skip connect HAL
         assertTrue(mService.mService.registerThermalEventListener(mEventListener1));
         assertTrue(mService.mService.registerThermalStatusListener(mStatusListener1));
-        assertTrue(mService.mService.unregisterThermalEventListener(mEventListener1));
-        assertTrue(mService.mService.unregisterThermalStatusListener(mStatusListener1));
+        assertTrue(mService.mService.registerThermalEventListenerWithType(mEventListener2,
+                Temperature.TYPE_SKIN));
+        assertFalse(mService.mService.registerThermalHeadroomListener(mHeadroomListener));
+        verify(mEventListener1, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+                .times(0)).notifyThrottling(any(Temperature.class));
+        verify(mStatusListener1, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+                .times(1)).onStatusChange(Temperature.THROTTLING_NONE);
+        verify(mEventListener2, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+                .times(0)).notifyThrottling(any(Temperature.class));
+        verify(mHeadroomListener, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+                .times(0)).onHeadroomChange(anyFloat(), anyFloat(), anyInt(), any());
+
         assertEquals(0, Arrays.asList(mService.mService.getCurrentTemperatures()).size());
         assertEquals(0, Arrays.asList(mService.mService.getCurrentTemperaturesWithType(
-                        Temperature.TYPE_SKIN)).size());
+                Temperature.TYPE_SKIN)).size());
         assertEquals(Temperature.THROTTLING_NONE, mService.mService.getCurrentThermalStatus());
         assertTrue(Float.isNaN(mService.mService.getThermalHeadroom(0)));
+
+        assertTrue(mService.mService.unregisterThermalEventListener(mEventListener1));
+        assertTrue(mService.mService.unregisterThermalEventListener(mEventListener2));
+        assertTrue(mService.mService.unregisterThermalStatusListener(mStatusListener1));
+        assertFalse(mService.mService.unregisterThermalHeadroomListener(mHeadroomListener));
     }
 
     @Test
@@ -419,35 +518,45 @@
     }
 
     @Test
-    public void testTemperatureWatcherUpdateSevereThresholds() {
+    @EnableFlags({Flags.FLAG_ALLOW_THERMAL_THRESHOLDS_CALLBACK,
+            Flags.FLAG_ALLOW_THERMAL_HEADROOM_THRESHOLDS})
+    public void testTemperatureWatcherUpdateSevereThresholds() throws Exception {
+        assertTrue(mService.mService.registerThermalHeadroomListener(mHeadroomListener));
+        verify(mHeadroomListener, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+                .times(1)).onHeadroomChange(eq(0.6f), eq(0.6f), anyInt(),
+                aryEq(new float[]{Float.NaN, 0.6666667f, 0.8333333f, 1.0f, 1.1666666f, 1.3333334f,
+                        1.5f}));
+        resetListenerMock();
         TemperatureWatcher watcher = mService.mTemperatureWatcher;
+        TemperatureThreshold newThreshold = new TemperatureThreshold();
+        newThreshold.name = "skin1";
+        newThreshold.type = Temperature.TYPE_SKIN;
+        // significant change in threshold (> 0.3C) should trigger a callback
+        newThreshold.hotThrottlingThresholds = new float[]{
+                Float.NaN, 43.0f, 46.0f, 49.0f, Float.NaN, Float.NaN, Float.NaN
+        };
+        mFakeHal.mCallback.onThresholdChanged(newThreshold);
         synchronized (watcher.mSamples) {
-            watcher.mSevereThresholds.erase();
-            watcher.getAndUpdateThresholds();
-            assertEquals(1, watcher.mSevereThresholds.size());
-            assertEquals("skin1", watcher.mSevereThresholds.keyAt(0));
             Float threshold = watcher.mSevereThresholds.get("skin1");
             assertNotNull(threshold);
-            assertEquals(40.0f, threshold, 0.0f);
+            assertEquals(49.0f, threshold, 0.0f);
             assertArrayEquals("Got" + Arrays.toString(watcher.mHeadroomThresholds),
-                    new float[]{Float.NaN, 0.6667f, 0.8333f, 1.0f, 1.166f, 1.3333f,
-                            1.5f},
-                    watcher.mHeadroomThresholds, 0.01f);
-
-            TemperatureThreshold newThreshold = new TemperatureThreshold();
-            newThreshold.name = "skin1";
-            newThreshold.hotThrottlingThresholds = new float[] {
-                    Float.NaN, 44.0f, 47.0f, 50.0f, Float.NaN, Float.NaN, Float.NaN
-            };
-            mFakeHal.mCallback.onThresholdChanged(newThreshold);
-            threshold = watcher.mSevereThresholds.get("skin1");
-            assertNotNull(threshold);
-            assertEquals(50.0f, threshold, 0.0f);
-            assertArrayEquals("Got" + Arrays.toString(watcher.mHeadroomThresholds),
-                    new float[]{Float.NaN, 0.8f, 0.9f, 1.0f, Float.NaN, Float.NaN,
-                            Float.NaN},
+                    new float[]{Float.NaN, 0.8f, 0.9f, 1.0f, Float.NaN, Float.NaN, Float.NaN},
                     watcher.mHeadroomThresholds, 0.01f);
         }
+        verify(mHeadroomListener, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+                .times(1)).onHeadroomChange(eq(0.3f), eq(0.3f), anyInt(),
+                aryEq(new float[]{Float.NaN, 0.8f, 0.9f, 1.0f, Float.NaN, Float.NaN, Float.NaN}));
+        resetListenerMock();
+
+        // same or similar threshold callback data within a second should not trigger callback
+        mFakeHal.mCallback.onThresholdChanged(newThreshold);
+        newThreshold.hotThrottlingThresholds = new float[]{
+                Float.NaN, 43.1f, 45.9f, 49.0f, Float.NaN, Float.NaN, Float.NaN
+        };
+        mFakeHal.mCallback.onThresholdChanged(newThreshold);
+        verify(mHeadroomListener, timeout(CALLBACK_TIMEOUT_MILLI_SEC)
+                .times(0)).onHeadroomChange(anyFloat(), anyFloat(), anyInt(), any());
     }
 
     @Test
@@ -475,28 +584,34 @@
     }
 
     @Test
-    public void testGetThermalHeadroomThresholdsOnlyReadOnce() throws Exception {
+    public void testGetThermalHeadroomThresholds() throws Exception {
         float[] expected = new float[]{Float.NaN, 0.1f, 0.2f, 0.3f, 0.4f, Float.NaN, 0.6f};
         when(mIThermalServiceMock.getThermalHeadroomThresholds()).thenReturn(expected);
         Map<Integer, Float> thresholds1 = mPowerManager.getThermalHeadroomThresholds();
         verify(mIThermalServiceMock, times(1)).getThermalHeadroomThresholds();
+        checkHeadroomThresholds(expected, thresholds1);
+
+        reset(mIThermalServiceMock);
+        expected = new float[]{Float.NaN, 0.2f, 0.3f, 0.4f, 0.4f, Float.NaN, 0.6f};
+        when(mIThermalServiceMock.getThermalHeadroomThresholds()).thenReturn(expected);
+        Map<Integer, Float> thresholds2 = mPowerManager.getThermalHeadroomThresholds();
+        verify(mIThermalServiceMock, times(1)).getThermalHeadroomThresholds();
+        checkHeadroomThresholds(expected, thresholds2);
+    }
+
+    private void checkHeadroomThresholds(float[] expected, Map<Integer, Float> thresholds) {
         for (int status = PowerManager.THERMAL_STATUS_LIGHT;
                 status <= PowerManager.THERMAL_STATUS_SHUTDOWN; status++) {
             if (Float.isNaN(expected[status])) {
-                assertFalse(thresholds1.containsKey(status));
+                assertFalse(thresholds.containsKey(status));
             } else {
-                assertEquals(expected[status], thresholds1.get(status), 0.01f);
+                assertEquals(expected[status], thresholds.get(status), 0.01f);
             }
         }
-        reset(mIThermalServiceMock);
-        Map<Integer, Float> thresholds2 = mPowerManager.getThermalHeadroomThresholds();
-        verify(mIThermalServiceMock, times(0)).getThermalHeadroomThresholds();
-        assertNotSame(thresholds1, thresholds2);
-        assertEquals(thresholds1, thresholds2);
     }
 
     @Test
-    public void testGetThermalHeadroomThresholdsOnDefaultHalResult() throws Exception  {
+    public void testGetThermalHeadroomThresholdsOnDefaultHalResult() throws Exception {
         TemperatureWatcher watcher = mService.mTemperatureWatcher;
         ArrayList<TemperatureThreshold> thresholds = new ArrayList<>();
         mFakeHal.mTemperatureThresholdList = thresholds;
@@ -510,8 +625,8 @@
         TemperatureThreshold nanThresholds = new TemperatureThreshold();
         nanThresholds.name = "nan";
         nanThresholds.type = Temperature.TYPE_SKIN;
-        nanThresholds.hotThrottlingThresholds = new float[ThrottlingSeverity.SHUTDOWN  + 1];
-        nanThresholds.coldThrottlingThresholds = new float[ThrottlingSeverity.SHUTDOWN  + 1];
+        nanThresholds.hotThrottlingThresholds = new float[ThrottlingSeverity.SHUTDOWN + 1];
+        nanThresholds.coldThrottlingThresholds = new float[ThrottlingSeverity.SHUTDOWN + 1];
         Arrays.fill(nanThresholds.hotThrottlingThresholds, Float.NaN);
         Arrays.fill(nanThresholds.coldThrottlingThresholds, Float.NaN);
         thresholds.add(nanThresholds);
@@ -607,7 +722,13 @@
     }
 
     @Test
-    public void testDump() {
+    public void testDump() throws Exception {
+        assertTrue(mService.mService.registerThermalEventListener(mEventListener1));
+        assertTrue(mService.mService.registerThermalStatusListener(mStatusListener1));
+        assertTrue(mService.mService.registerThermalEventListenerWithType(mEventListener2,
+                Temperature.TYPE_SKIN));
+        assertTrue(mService.mService.registerThermalStatusListener(mStatusListener2));
+
         when(mContext.checkCallingOrSelfPermission(android.Manifest.permission.DUMP))
                 .thenReturn(PackageManager.PERMISSION_GRANTED);
         final StringWriter out = new StringWriter();
@@ -628,22 +749,22 @@
         assertThat(dumpStr).contains("Thermal Status: 0");
         assertThat(dumpStr).contains(
                 "Cached temperatures:\n"
-                + "\tTemperature{mValue=0.0, mType=4, mName=usbport, mStatus=0}\n"
-                + "\tTemperature{mValue=0.0, mType=2, mName=batt, mStatus=0}\n"
-                + "\tTemperature{mValue=0.0, mType=3, mName=skin1, mStatus=0}\n"
-                + "\tTemperature{mValue=0.0, mType=3, mName=skin2, mStatus=0}"
+                        + "\tTemperature{mValue=37.0, mType=4, mName=usbport, mStatus=0}\n"
+                        + "\tTemperature{mValue=34.0, mType=2, mName=batt, mStatus=0}\n"
+                        + "\tTemperature{mValue=28.0, mType=3, mName=skin1, mStatus=0}\n"
+                        + "\tTemperature{mValue=31.0, mType=3, mName=skin2, mStatus=0}"
         );
         assertThat(dumpStr).contains("HAL Ready: true\n"
                 + "HAL connection:\n"
                 + "\tThermalHAL AIDL 1  connected: yes");
         assertThat(dumpStr).contains("Current temperatures from HAL:\n"
-                + "\tTemperature{mValue=0.0, mType=3, mName=skin1, mStatus=0}\n"
-                + "\tTemperature{mValue=0.0, mType=3, mName=skin2, mStatus=0}\n"
-                + "\tTemperature{mValue=0.0, mType=2, mName=batt, mStatus=0}\n"
-                + "\tTemperature{mValue=0.0, mType=4, mName=usbport, mStatus=0}\n");
+                + "\tTemperature{mValue=28.0, mType=3, mName=skin1, mStatus=0}\n"
+                + "\tTemperature{mValue=31.0, mType=3, mName=skin2, mStatus=0}\n"
+                + "\tTemperature{mValue=34.0, mType=2, mName=batt, mStatus=0}\n"
+                + "\tTemperature{mValue=37.0, mType=4, mName=usbport, mStatus=0}\n");
         assertThat(dumpStr).contains("Current cooling devices from HAL:\n"
-                + "\tCoolingDevice{mValue=0, mType=1, mName=cpu}\n"
-                + "\tCoolingDevice{mValue=0, mType=1, mName=gpu}\n");
+                + "\tCoolingDevice{mValue=40, mType=1, mName=cpu}\n"
+                + "\tCoolingDevice{mValue=43, mType=1, mName=gpu}\n");
         assertThat(dumpStr).contains("Temperature static thresholds from HAL:\n"
                 + "\tTemperatureThreshold{mType=3, mName=skin1, mHotThrottlingThresholds=[25.0, "
                 + "30.0, 35.0, 40.0, 45.0, 50.0, 55.0], mColdThrottlingThresholds=[0.0, 0.0, 0.0,"