Enhance TimerFileDescriptor to schedule multiple tasks

Improved the class to allow scheduling and execution of multiple tasks
concurrently.

Test: TH
Change-Id: I79b63c6846e0e0a934f257d4a38d9cab08bd13e8
diff --git a/staticlibs/device/com/android/net/module/util/TimerFileDescriptor.java b/staticlibs/device/com/android/net/module/util/TimerFileDescriptor.java
index dbbccc5..9efb00c 100644
--- a/staticlibs/device/com/android/net/module/util/TimerFileDescriptor.java
+++ b/staticlibs/device/com/android/net/module/util/TimerFileDescriptor.java
@@ -24,13 +24,14 @@
 import android.os.Message;
 import android.os.MessageQueue;
 import android.os.ParcelFileDescriptor;
+import android.os.SystemClock;
 import android.util.CloseGuard;
 import android.util.Log;
 
 import androidx.annotation.NonNull;
-import androidx.annotation.Nullable;
 
 import java.io.IOException;
+import java.util.PriorityQueue;
 
 /**
  * Represents a Timer file descriptor object used for scheduling tasks with precise delays.
@@ -48,10 +49,10 @@
  * final TimerFileDescriptor timerFd = new TimerFileDescriptor(handler);
  *
  * // Schedule a new task with a delay.
- * timerFd.setDelayedTask(() -> taskToExecute(), delayTime);
+ * timerFd.postDelayed(() -> taskToExecute(), delayTime);
  *
  * // Once the delay has elapsed, and the task is running, schedule another task.
- * timerFd.setDelayedTask(() -> anotherTaskToExecute(), anotherDelayTime);
+ * timerFd.postDelayed(() -> anotherTaskToExecute(), anotherDelayTime);
  *
  * // Remember to close the TimerFileDescriptor after all tasks have finished running.
  * timerFd.close();
@@ -69,28 +70,54 @@
     @NonNull
     private final ParcelFileDescriptor mParcelFileDescriptor;
     private final int mFdInt;
-    @Nullable
-    private ITask mTask;
+
+    private final PriorityQueue<Task> mTaskQueue;
 
     /**
-     * An interface for defining tasks that can be executed using a {@link Handler}.
+     * An abstract class for defining tasks that can be executed using a {@link Handler}.
      */
-    public interface ITask {
+    private abstract static class Task implements Comparable<Task> {
+        private final long mRunTimeMs;
+        private final long mCreatedTimeNs = SystemClock.elapsedRealtimeNanos();
+
+        /**
+         * create a task with a run time
+         */
+        Task(long runTimeMs) {
+            mRunTimeMs = runTimeMs;
+        }
+
         /**
          * Executes the task using the provided {@link Handler}.
          *
          * @param handler The {@link Handler} to use for executing the task.
          */
-        void post(Handler handler);
+        abstract void post(Handler handler);
+
+        @Override
+        public int compareTo(@NonNull Task o) {
+            if (mRunTimeMs != o.mRunTimeMs) {
+                return Long.compare(mRunTimeMs, o.mRunTimeMs);
+            }
+            return Long.compare(mCreatedTimeNs, o.mCreatedTimeNs);
+        }
+
+        /**
+         * Returns the run time of the task.
+         */
+        public long getRunTimeMs() {
+            return mRunTimeMs;
+        }
     }
 
     /**
      * A task that sends a {@link Message} using a {@link Handler}.
      */
-    public static class MessageTask implements ITask {
+    private static class MessageTask extends Task {
         private final Message mMessage;
 
-        public MessageTask(Message message) {
+        MessageTask(Message message, long runTimeMs) {
+            super(runTimeMs);
             mMessage = message;
         }
 
@@ -108,10 +135,11 @@
     /**
      * A task that posts a {@link Runnable} to a {@link Handler}.
      */
-    public static class RunnableTask implements ITask {
+    private static class RunnableTask extends Task {
         private final Runnable mRunnable;
 
-        public RunnableTask(Runnable runnable) {
+        RunnableTask(Runnable runnable, long runTimeMs) {
+            super(runTimeMs);
             mRunnable = runnable;
         }
 
@@ -127,7 +155,7 @@
     }
 
     /**
-     * TimerFileDescriptor constructor
+     * The TimerFileDescriptor constructor
      *
      * Note: The constructor is currently safe to call on another thread because it only sets final
      * members and registers the event to be called on the handler.
@@ -137,54 +165,75 @@
         mParcelFileDescriptor = ParcelFileDescriptor.adoptFd(mFdInt);
         mHandler = handler;
         mQueue = handler.getLooper().getQueue();
+        mTaskQueue = new PriorityQueue<>();
         registerFdEventListener();
 
         mGuard.open("close");
     }
 
-    /**
-     * Set a task to be executed after a specified delay.
-     *
-     * <p> A task can only be scheduled once at a time. Cancel previous scheduled task before the
-     *     new task is scheduled.
-     *
-     * @param task the task to be executed
-     * @param delayMs the delay time in milliseconds
-     * @throws IllegalArgumentException if try to replace the current scheduled task
-     * @throws IllegalArgumentException if the delay time is less than 0
-     */
-    public void setDelayedTask(@NonNull ITask task, long delayMs) {
+    private boolean enqueueTask(@NonNull Task task, long delayMs) {
         ensureRunningOnCorrectThread();
-        if (mTask != null) {
-            throw new IllegalArgumentException("task is already scheduled");
-        }
         if (delayMs <= 0L) {
             task.post(mHandler);
-            return;
+            return true;
         }
-
-        if (TimerFdUtils.setExpirationTime(mFdInt, delayMs)) {
-            mTask = task;
+        if (mTaskQueue.isEmpty() || task.compareTo(mTaskQueue.peek()) < 0) {
+            if (!TimerFdUtils.setExpirationTime(mFdInt, delayMs)) {
+                return false;
+            }
         }
+        mTaskQueue.add(task);
+        return true;
     }
 
     /**
-     * Cancel the scheduled task.
+     * Set a runnable to be executed after a specified delay.
+     *
+     * If delayMs is less than or equal to 0, the runnable will be executed immediately.
+     *
+     * @param runnable the runnable to be executed
+     * @param delayMs the delay time in milliseconds
+     * @return true if the task is scheduled successfully, false otherwise.
      */
-    public void cancelTask() {
-        ensureRunningOnCorrectThread();
-        if (mTask == null) return;
-
-        TimerFdUtils.setExpirationTime(mFdInt, 0 /* delayMs */);
-        mTask = null;
+    public boolean postDelayed(@NonNull Runnable runnable, long delayMs) {
+        return enqueueTask(new RunnableTask(runnable, SystemClock.elapsedRealtime() + delayMs),
+                delayMs);
     }
 
     /**
-     * Check if there is a scheduled task.
+     * Remove a scheduled runnable.
+     *
+     * @param runnable the runnable to be removed
      */
-    public boolean hasDelayedTask() {
+    public void removeDelayedRunnable(@NonNull Runnable runnable) {
         ensureRunningOnCorrectThread();
-        return mTask != null;
+        mTaskQueue.removeIf(task -> task instanceof RunnableTask
+                && ((RunnableTask) task).mRunnable == runnable);
+    }
+
+    /**
+     * Set a message to be sent after a specified delay.
+     *
+     * If delayMs is less than or equal to 0, the message will be sent immediately.
+     *
+     * @param msg the message to be sent
+     * @param delayMs the delay time in milliseconds
+     * @return true if the message is scheduled successfully, false otherwise.
+     */
+    public boolean sendDelayedMessage(Message msg, long delayMs) {
+
+        return enqueueTask(new MessageTask(msg, SystemClock.elapsedRealtime() + delayMs), delayMs);
+    }
+
+    /**
+     * Remove a scheduled message.
+     *
+     * @param what the message to be removed
+     */
+    public void removeDelayedMessage(int what) {
+        ensureRunningOnCorrectThread();
+        mTaskQueue.removeIf(task -> task instanceof MessageTask
+                && ((MessageTask) task).mMessage.what == what);
     }
 
     /**
@@ -216,10 +265,31 @@
     }
 
     private void handleExpiration() {
-        // Execute the task
-        if (mTask != null) {
-            mTask.post(mHandler);
-            mTask = null;
+        long currentTimeMs = SystemClock.elapsedRealtime();
+        while (!mTaskQueue.isEmpty()) {
+            final Task task = mTaskQueue.peek();
+            currentTimeMs = SystemClock.elapsedRealtime();
+            if (currentTimeMs < task.getRunTimeMs()) {
+                break;
+            }
+            task.post(mHandler);
+            mTaskQueue.poll();
+        }
+
+
+        if (!mTaskQueue.isEmpty()) {
+            // Using currentTimeMs ensures that the calculated expiration time
+            // is always positive.
+            if (!TimerFdUtils.setExpirationTime(mFdInt,
+                    mTaskQueue.peek().getRunTimeMs() - currentTimeMs)) {
+                // If setting the expiration time fails, clear the task queue.
+                Log.wtf(TAG, "Failed to set expiration time");
+                mTaskQueue.clear();
+            }
+        } else {
+            // We have to clean up the timer if no tasks are left. Otherwise, the timer will keep
+            // being triggered.
+            TimerFdUtils.setExpirationTime(mFdInt, 0);
         }
     }
 
diff --git a/staticlibs/tests/unit/Android.bp b/staticlibs/tests/unit/Android.bp
index 9d1d291..f4f1ea9 100644
--- a/staticlibs/tests/unit/Android.bp
+++ b/staticlibs/tests/unit/Android.bp
@@ -28,6 +28,7 @@
         "net-utils-device-common-struct-base",
         "net-utils-device-common-wear",
         "net-utils-service-connectivity",
+        "truth",
     ],
     libs: [
         "android.test.runner.stubs",
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/TimerFileDescriptorTest.kt b/staticlibs/tests/unit/src/com/android/net/module/util/TimerFileDescriptorTest.kt
index f5e47c9..3ad979e 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/TimerFileDescriptorTest.kt
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/TimerFileDescriptorTest.kt
@@ -22,33 +22,34 @@
 import android.os.HandlerThread
 import android.os.Looper
 import android.os.Message
+import android.os.SystemClock
 import androidx.test.filters.SmallTest
-import com.android.net.module.util.TimerFileDescriptor.ITask
-import com.android.net.module.util.TimerFileDescriptor.MessageTask
-import com.android.net.module.util.TimerFileDescriptor.RunnableTask
 import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.DevSdkIgnoreRunner
 import com.android.testutils.tryTest
 import com.android.testutils.visibleOnHandlerThread
+import com.google.common.collect.Range
+import com.google.common.truth.Truth.assertThat
+import kotlin.test.assertEquals
 import org.junit.After
 import org.junit.Test
 import org.junit.runner.RunWith
-import java.time.Duration
-import java.time.Instant
-import kotlin.test.assertFalse
-import kotlin.test.assertTrue
-
-private const val MSG_TEST = 1
 
 @DevSdkIgnoreRunner.MonitorThreadLeak
 @RunWith(DevSdkIgnoreRunner::class)
 @SmallTest
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.R)
 class TimerFileDescriptorTest {
+
+    private val TIMEOUT_MS = 1000L
+    private val TOLERANCE_MS = 50L
     private class TestHandler(looper: Looper) : Handler(looper) {
         override fun handleMessage(msg: Message) {
-            val cv = msg.obj as ConditionVariable
+            val pair = msg.obj as Pair<ConditionVariable, MutableList<Long>>
+            val cv = pair.first
             cv.open()
+            val executionTimes = pair.second
+            executionTimes.add(SystemClock.elapsedRealtime())
         }
     }
     private val thread = HandlerThread(TimerFileDescriptorTest::class.simpleName).apply { start() }
@@ -60,55 +61,80 @@
         thread.join()
     }
 
-    private fun assertDelayedTaskPost(
-            timerFd: TimerFileDescriptor,
-            task: ITask,
-            cv: ConditionVariable
-    ) {
-        val delayTime = 10L
-        val startTime1 = Instant.now()
-        handler.post { timerFd.setDelayedTask(task, delayTime) }
-        assertTrue(cv.block(100L /* timeoutMs*/))
-        assertTrue(Duration.between(startTime1, Instant.now()).toMillis() >= delayTime)
-    }
-
     @Test
-    fun testSetDelayedTask() {
-        val timerFd = TimerFileDescriptor(handler)
+    fun testMultiplePostDelayedTasks() {
+        val scheduler = TimerFileDescriptor(handler)
         tryTest {
-            // Verify the delayed task is executed with the self-implemented ITask
-            val cv1 = ConditionVariable()
-            assertDelayedTaskPost(timerFd, { cv1.open() }, cv1)
-
-            // Verify the delayed task is executed with the RunnableTask
-            val cv2 = ConditionVariable()
-            assertDelayedTaskPost(timerFd, RunnableTask{ cv2.open() }, cv2)
-
-            // Verify the delayed task is executed with the MessageTask
-            val cv3 = ConditionVariable()
-            assertDelayedTaskPost(timerFd, MessageTask(handler.obtainMessage(MSG_TEST, cv3)), cv3)
+            val initialTimeMs = SystemClock.elapsedRealtime()
+            val executionTimes = mutableListOf<Long>()
+            val cv = ConditionVariable()
+            handler.post {
+                scheduler.postDelayed(
+                    { executionTimes.add(SystemClock.elapsedRealtime() - initialTimeMs) }, 0)
+                scheduler.postDelayed(
+                    { executionTimes.add(SystemClock.elapsedRealtime() - initialTimeMs) }, 200)
+                val toBeRemoved = Runnable {
+                    executionTimes.add(SystemClock.elapsedRealtime() - initialTimeMs)
+                }
+                scheduler.postDelayed(toBeRemoved, 250)
+                scheduler.removeDelayedRunnable(toBeRemoved)
+                scheduler.postDelayed(
+                    { executionTimes.add(SystemClock.elapsedRealtime() - initialTimeMs) }, 100)
+                scheduler.postDelayed({
+                    executionTimes.add(SystemClock.elapsedRealtime() - initialTimeMs)
+                    cv.open() }, 300)
+            }
+            cv.block(TIMEOUT_MS)
+            assertEquals(4, executionTimes.size)
+            assertThat(executionTimes[0]).isIn(Range.closed(0L, TOLERANCE_MS))
+            assertThat(executionTimes[1]).isIn(Range.closed(100L, 100 + TOLERANCE_MS))
+            assertThat(executionTimes[2]).isIn(Range.closed(200L, 200 + TOLERANCE_MS))
+            assertThat(executionTimes[3]).isIn(Range.closed(300L, 300 + TOLERANCE_MS))
         } cleanup {
-            visibleOnHandlerThread(handler) { timerFd.close() }
+            visibleOnHandlerThread(handler) { scheduler.close() }
         }
     }
 
     @Test
-    fun testCancelTask() {
-        // The task is posted and canceled within the same handler loop, so the short delay used
-        // here won't cause flakes.
-        val delayTime = 10L
-        val timerFd = TimerFileDescriptor(handler)
-        val cv = ConditionVariable()
+    fun testMultipleSendDelayedMessages() {
+        val scheduler = TimerFileDescriptor(handler)
         tryTest {
+            val MSG_ID_0 = 0
+            val MSG_ID_1 = 1
+            val MSG_ID_2 = 2
+            val MSG_ID_3 = 3
+            val MSG_ID_4 = 4
+            val initialTimeMs = SystemClock.elapsedRealtime()
+            val executionTimes = mutableListOf<Long>()
+            val cv = ConditionVariable()
             handler.post {
-                timerFd.setDelayedTask({ cv.open() }, delayTime)
-                assertTrue(timerFd.hasDelayedTask())
-                timerFd.cancelTask()
-                assertFalse(timerFd.hasDelayedTask())
+                scheduler.sendDelayedMessage(
+                    Message.obtain(handler, MSG_ID_0, Pair(ConditionVariable(), executionTimes)), 0)
+                scheduler.sendDelayedMessage(
+                    Message.obtain(handler, MSG_ID_1, Pair(ConditionVariable(), executionTimes)),
+                    200)
+                scheduler.sendDelayedMessage(
+                    Message.obtain(handler, MSG_ID_4, Pair(ConditionVariable(), executionTimes)),
+                    250)
+                scheduler.removeDelayedMessage(MSG_ID_4)
+                scheduler.sendDelayedMessage(
+                    Message.obtain(handler, MSG_ID_2, Pair(ConditionVariable(), executionTimes)),
+                    100)
+                scheduler.sendDelayedMessage(
+                    Message.obtain(handler, MSG_ID_3, Pair(cv, executionTimes)),
+                    300)
             }
-            assertFalse(cv.block(20L /* timeoutMs*/))
+            cv.block(TIMEOUT_MS)
+            assertEquals(4, executionTimes.size)
+            assertThat(executionTimes[0] - initialTimeMs).isIn(Range.closed(0L, TOLERANCE_MS))
+            assertThat(executionTimes[1] - initialTimeMs)
+                .isIn(Range.closed(100L, 100 + TOLERANCE_MS))
+            assertThat(executionTimes[2] - initialTimeMs)
+                .isIn(Range.closed(200L, 200 + TOLERANCE_MS))
+            assertThat(executionTimes[3] - initialTimeMs)
+                .isIn(Range.closed(300L, 300 + TOLERANCE_MS))
         } cleanup {
-            visibleOnHandlerThread(handler) { timerFd.close() }
+            visibleOnHandlerThread(handler) { scheduler.close() }
         }
     }
 }