Merge "Fix concurrency of IVibratorCallback" into main
diff --git a/services/core/java/com/android/server/vibrator/VibratorController.java b/services/core/java/com/android/server/vibrator/VibratorController.java
index 6710d02..988e8fe 100644
--- a/services/core/java/com/android/server/vibrator/VibratorController.java
+++ b/services/core/java/com/android/server/vibrator/VibratorController.java
@@ -56,10 +56,17 @@
     private volatile boolean mIsUnderExternalControl;
     private volatile float mCurrentAmplitude;
 
-    /** Listener for vibration completion callbacks from native. */
+    /**
+     * Listener for vibration completion callbacks from native.
+     *
+     * <p>Only the latest active native call to {@link VibratorController#on} will ever trigger this
+     * completion callback, to avoid race conditions during a vibration playback. If a new call to
+     * {@link VibratorController#on} or {@link VibratorController#off} happens before a previous
+     * callback was triggered, the previous callback is disabled, even if the new command fails.
+     */
     public interface OnVibrationCompleteListener {
 
-        /** Callback triggered when vibration is complete. */
+        /** Callback triggered when an active vibration command is complete. */
         void onComplete(int vibratorId, long vibrationId);
     }
 
@@ -235,7 +242,7 @@
     }
 
     /**
-     * Turn on the vibrator for {@code milliseconds} time, using {@code vibrationId} or completion
+     * Turn on the vibrator for {@code milliseconds} time, using {@code vibrationId} for completion
      * callback to {@link OnVibrationCompleteListener}.
      *
      * <p>This will affect the state of {@link #isVibrating()}.
@@ -255,7 +262,7 @@
     }
 
     /**
-     * Plays predefined vibration effect, using {@code vibrationId} or completion callback to
+     * Plays predefined vibration effect, using {@code vibrationId} for completion callback to
      * {@link OnVibrationCompleteListener}.
      *
      * <p>This will affect the state of {@link #isVibrating()}.
@@ -276,8 +283,8 @@
     }
 
     /**
-     * Plays a composition of vibration primitives, using {@code vibrationId} or completion callback
-     * to {@link OnVibrationCompleteListener}.
+     * Plays a composition of vibration primitives, using {@code vibrationId} for completion
+     * callback to {@link OnVibrationCompleteListener}.
      *
      * <p>This will affect the state of {@link #isVibrating()}.
      *
@@ -299,7 +306,7 @@
     }
 
     /**
-     * Plays a composition of pwle primitives, using {@code vibrationId} or completion callback
+     * Plays a composition of PWLE primitives, using {@code vibrationId} for completion callback
      * to {@link OnVibrationCompleteListener}.
      *
      * <p>This will affect the state of {@link #isVibrating()}.
@@ -321,7 +328,11 @@
         }
     }
 
-    /** Turns off the vibrator. This will affect the state of {@link #isVibrating()}. */
+    /**
+     * Turns off the vibrator and disables the completion callback for any pending vibration.
+     *
+     * <p>This will affect the state of {@link #isVibrating()}.
+     */
     public void off() {
         synchronized (mLock) {
             mNativeWrapper.off();
diff --git a/services/core/jni/com_android_server_vibrator_VibratorController.cpp b/services/core/jni/com_android_server_vibrator_VibratorController.cpp
index f47a59d..4be21d8 100644
--- a/services/core/jni/com_android_server_vibrator_VibratorController.cpp
+++ b/services/core/jni/com_android_server_vibrator_VibratorController.cpp
@@ -131,17 +131,28 @@
     }
 
     std::function<void()> createCallback(jlong vibrationId) {
-        return [vibrationId, this]() {
+        auto callbackId = ++mCallbackId; // Closures created earlier now hold a stale id
+        return [vibrationId, callbackId, this]() {
+            auto currentCallbackId = mCallbackId.load();
+            if (currentCallbackId != callbackId) {
+                // This callback is from an older HAL call that is no longer relevant to the service
+                return;
+            }
             auto jniEnv = GetOrAttachJNIEnvironment(sJvm);
             jniEnv->CallVoidMethod(mCallbackListener, sMethodIdOnComplete, mVibratorId,
                                    vibrationId);
         };
     }
 
+    void disableOldCallbacks() {
+        mCallbackId++; // Callbacks from createCallback() compare ids, so older ones return early
+    }
+
 private:
     const std::shared_ptr<vibrator::HalController> mHal;
     const int32_t mVibratorId;
     const jobject mCallbackListener;
+    std::atomic<int64_t> mCallbackId{0}; // Bumped by createCallback() and disableOldCallbacks()
 };
 
 static aidl::BrakingPwle brakingPwle(aidl::Braking braking, int32_t duration) {
@@ -236,6 +247,7 @@
     }
     auto offFn = [](vibrator::HalWrapper* hal) { return hal->off(); };
     wrapper->halCall<void>(offFn, "off");
+    wrapper->disableOldCallbacks();
 }
 
 static void vibratorSetAmplitude(JNIEnv* env, jclass /* clazz */, jlong ptr, jfloat amplitude) {