Move scheduling logic into MdnsQueryScheduler class

Move scheduling logic into a standalone MdnsQueryScheduler class to
simplify the MdnsServiceTypeClient class.

Bug: 292470176
Test: atest FrameworksNetTests CtsNetTestCases
Change-Id: I31130c239bc54f6dc0efde2921ce51df35076a74
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsQueryScheduler.java b/service-t/src/com/android/server/connectivity/mdns/MdnsQueryScheduler.java
new file mode 100644
index 0000000..3fcf0d4
--- /dev/null
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsQueryScheduler.java
@@ -0,0 +1,144 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+import android.annotation.NonNull;
+import android.annotation.Nullable;
+
+/**
+ * The query scheduler class for calculating next query tasks parameters.
+ * <p>
+ * The class is not thread-safe and needs to be used on a consistent thread.
+ */
+public class MdnsQueryScheduler {
+
+    /**
+     * The argument for tracking the query tasks status.
+     */
+    public static class ScheduledQueryTaskArgs {
+        public final QueryTaskConfig config;
+        public final long timeToRun;
+        public final long minTtlExpirationTimeWhenScheduled;
+        public final long sessionId;
+
+        ScheduledQueryTaskArgs(@NonNull QueryTaskConfig config, long timeToRun,
+                long minTtlExpirationTimeWhenScheduled, long sessionId) {
+            this.config = config;
+            this.timeToRun = timeToRun;
+            this.minTtlExpirationTimeWhenScheduled = minTtlExpirationTimeWhenScheduled;
+            this.sessionId = sessionId;
+        }
+    }
+
+    @Nullable
+    private ScheduledQueryTaskArgs mLastScheduledQueryTaskArgs;
+
+    public MdnsQueryScheduler() {
+    }
+
+    /**
+     * Cancel the scheduled run. The method needed to be called when the scheduled task need to
+     * be canceled and rescheduling is not need.
+     */
+    public void cancelScheduledRun() {
+        mLastScheduledQueryTaskArgs = null;
+    }
+
+    /**
+     * Calculates ScheduledQueryTaskArgs for rescheduling the current task. Returns null if the
+     * rescheduling is not necessary.
+     */
+    @Nullable
+    public ScheduledQueryTaskArgs maybeRescheduleCurrentRun(long now,
+            long minRemainingTtl, long lastSentTime, long sessionId) {
+        if (mLastScheduledQueryTaskArgs == null) {
+            return null;
+        }
+        if (!mLastScheduledQueryTaskArgs.config.shouldUseQueryBackoff()) {
+            return null;
+        }
+
+        final long timeToRun = calculateTimeToRun(mLastScheduledQueryTaskArgs,
+                mLastScheduledQueryTaskArgs.config, now, minRemainingTtl, lastSentTime);
+
+        if (timeToRun <= mLastScheduledQueryTaskArgs.timeToRun) {
+            return null;
+        }
+
+        mLastScheduledQueryTaskArgs = new ScheduledQueryTaskArgs(mLastScheduledQueryTaskArgs.config,
+                timeToRun,
+                minRemainingTtl + now,
+                sessionId);
+        return mLastScheduledQueryTaskArgs;
+    }
+
+    /**
+     *  Calculates the ScheduledQueryTaskArgs for the next run.
+     */
+    @NonNull
+    public ScheduledQueryTaskArgs scheduleNextRun(
+            @NonNull QueryTaskConfig currentConfig,
+            long minRemainingTtl,
+            long now,
+            long lastSentTime,
+            long sessionId) {
+        final QueryTaskConfig nextRunConfig = currentConfig.getConfigForNextRun();
+        final long timeToRun;
+        if (mLastScheduledQueryTaskArgs == null) {
+            timeToRun = now + nextRunConfig.delayUntilNextTaskWithoutBackoffMs;
+        } else {
+            timeToRun = calculateTimeToRun(mLastScheduledQueryTaskArgs,
+                    nextRunConfig, now, minRemainingTtl, lastSentTime);
+        }
+        mLastScheduledQueryTaskArgs = new ScheduledQueryTaskArgs(nextRunConfig, timeToRun,
+                minRemainingTtl + now,
+                sessionId);
+        return mLastScheduledQueryTaskArgs;
+    }
+
+    /**
+     *  Calculates the ScheduledQueryTaskArgs for the initial run.
+     */
+    public ScheduledQueryTaskArgs scheduleFirstRun(@NonNull QueryTaskConfig taskConfig,
+            long now, long minRemainingTtl, long currentSessionId) {
+        mLastScheduledQueryTaskArgs = new ScheduledQueryTaskArgs(taskConfig, now /* timeToRun */,
+                now + minRemainingTtl/* minTtlExpirationTimeWhenScheduled */,
+                currentSessionId);
+        return mLastScheduledQueryTaskArgs;
+    }
+
+    private static long calculateTimeToRun(@NonNull ScheduledQueryTaskArgs taskArgs,
+            QueryTaskConfig queryTaskConfig, long now, long minRemainingTtl, long lastSentTime) {
+        final long baseDelayInMs = queryTaskConfig.delayUntilNextTaskWithoutBackoffMs;
+        if (!queryTaskConfig.shouldUseQueryBackoff()) {
+            return lastSentTime + baseDelayInMs;
+        }
+        if (minRemainingTtl <= 0) {
+            // There's no service, or there is an expired service. In any case, schedule for the
+            // minimum time, which is the base delay.
+            return lastSentTime + baseDelayInMs;
+        }
+        // If the next TTL expiration time hasn't changed, then use previous calculated timeToRun.
+        if (lastSentTime < now
+                && taskArgs.minTtlExpirationTimeWhenScheduled == now + minRemainingTtl) {
+            // Use the original scheduling time if the TTL has not changed, to avoid continuously
+            // rescheduling to 80% of the remaining TTL as time passes
+            return taskArgs.timeToRun;
+        }
+        return Math.max(now + (long) (0.8 * minRemainingTtl), lastSentTime + baseDelayInMs);
+    }
+}
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index b5fd8a0..cdaebee 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -65,6 +65,7 @@
     @NonNull private final SocketKey socketKey;
     @NonNull private final SharedLog sharedLog;
     @NonNull private final Handler handler;
+    @NonNull private final MdnsQueryScheduler mdnsQueryScheduler;
     @NonNull private final Dependencies dependencies;
     /**
      * The service caches for each socket. It should be accessed from looper thread only.
@@ -82,9 +83,6 @@
     // QueryTask for
     // new subtypes. It stays the same between packets for same subtypes.
     private long currentSessionId = 0;
-
-    @Nullable
-    private ScheduledQueryTaskArgs lastScheduledQueryTaskArgs;
     private long lastSentTime;
 
     private class QueryTaskHandler extends Handler {
@@ -97,7 +95,8 @@
         public void handleMessage(Message msg) {
             switch (msg.what) {
                 case EVENT_START_QUERYTASK: {
-                    final ScheduledQueryTaskArgs taskArgs = (ScheduledQueryTaskArgs) msg.obj;
+                    final MdnsQueryScheduler.ScheduledQueryTaskArgs taskArgs =
+                            (MdnsQueryScheduler.ScheduledQueryTaskArgs) msg.obj;
                     // QueryTask should be run immediately after being created (not be scheduled in
                     // advance). Because the result of "makeResponsesForResolve" depends on answers
                     // that were received before it is called, so to take into account all answers
@@ -126,15 +125,21 @@
 
                     tryRemoveServiceAfterTtlExpires();
 
-                    final QueryTaskConfig nextRunConfig =
-                            sentResult.taskArgs.config.getConfigForNextRun();
                     final long now = clock.elapsedRealtime();
                     lastSentTime = now;
                     final long minRemainingTtl = getMinRemainingTtl(now);
-                    final long timeToRun = calculateTimeToRun(lastScheduledQueryTaskArgs,
-                            nextRunConfig, now, minRemainingTtl, lastSentTime);
-                    scheduleNextRun(nextRunConfig, minRemainingTtl, now, timeToRun,
-                            lastScheduledQueryTaskArgs.sessionId);
+                    MdnsQueryScheduler.ScheduledQueryTaskArgs args =
+                            mdnsQueryScheduler.scheduleNextRun(
+                                    sentResult.taskArgs.config,
+                                    minRemainingTtl,
+                                    now,
+                                    lastSentTime,
+                                    sentResult.taskArgs.sessionId
+                            );
+                    dependencies.sendMessageDelayed(
+                            handler,
+                            handler.obtainMessage(EVENT_START_QUERYTASK, args),
+                            calculateTimeToNextTask(args, now, sharedLog));
                     break;
                 }
                 default:
@@ -219,6 +224,7 @@
         this.handler = new QueryTaskHandler(looper);
         this.dependencies = dependencies;
         this.serviceCache = serviceCache;
+        this.mdnsQueryScheduler = new MdnsQueryScheduler();
     }
 
     private static MdnsServiceInfo buildMdnsServiceInfoFromResponse(
@@ -300,6 +306,7 @@
         }
         // Remove the next scheduled periodical task.
         removeScheduledTask();
+        mdnsQueryScheduler.cancelScheduledRun();
         // Keep tracking the ScheduledFuture for the task so we can cancel it if caller is not
         // interested anymore.
         final QueryTaskConfig taskConfig = new QueryTaskConfig(
@@ -312,18 +319,25 @@
         if (lastSentTime == 0) {
             lastSentTime = now;
         }
+        final long minRemainingTtl = getMinRemainingTtl(now);
         if (hadReply) {
-            final QueryTaskConfig queryTaskConfig = taskConfig.getConfigForNextRun();
-            final long minRemainingTtl = getMinRemainingTtl(now);
-            final long timeToRun = now + queryTaskConfig.delayUntilNextTaskWithoutBackoffMs;
-            scheduleNextRun(
-                    queryTaskConfig, minRemainingTtl, now, timeToRun, currentSessionId);
+            MdnsQueryScheduler.ScheduledQueryTaskArgs args =
+                    mdnsQueryScheduler.scheduleNextRun(
+                            taskConfig,
+                            minRemainingTtl,
+                            now,
+                            lastSentTime,
+                            currentSessionId
+                    );
+            dependencies.sendMessageDelayed(
+                    handler,
+                    handler.obtainMessage(EVENT_START_QUERYTASK, args),
+                    calculateTimeToNextTask(args, now, sharedLog));
         } else {
             final List<MdnsResponse> servicesToResolve = makeResponsesForResolve(socketKey);
-            lastScheduledQueryTaskArgs = new ScheduledQueryTaskArgs(taskConfig, now /* timeToRun */,
-                    now + getMinRemainingTtl(now)/* minTtlExpirationTimeWhenScheduled */,
-                    currentSessionId);
-            final QueryTask queryTask = new QueryTask(lastScheduledQueryTaskArgs, servicesToResolve,
+            final QueryTask queryTask = new QueryTask(
+                    mdnsQueryScheduler.scheduleFirstRun(taskConfig, now,
+                            minRemainingTtl, currentSessionId), servicesToResolve,
                     servicesToResolve.size() < listeners.size() /* sendDiscoveryQueries */);
             executor.submit(queryTask);
         }
@@ -341,7 +355,6 @@
         sharedLog.log("Remove EVENT_START_QUERYTASK"
                 + ", current session: " + currentSessionId);
         ++currentSessionId;
-        lastScheduledQueryTaskArgs = null;
     }
 
     private boolean responseMatchesOptions(@NonNull MdnsResponse response,
@@ -378,6 +391,7 @@
         }
         if (listeners.isEmpty()) {
             removeScheduledTask();
+            mdnsQueryScheduler.cancelScheduledRun();
         }
         return listeners.isEmpty();
     }
@@ -421,18 +435,18 @@
                 serviceCache.addOrUpdateService(serviceType, socketKey, response);
             }
         }
-        if (dependencies.hasMessages(handler, EVENT_START_QUERYTASK)
-                && lastScheduledQueryTaskArgs != null
-                && lastScheduledQueryTaskArgs.config.shouldUseQueryBackoff()) {
+        if (dependencies.hasMessages(handler, EVENT_START_QUERYTASK)) {
             final long now = clock.elapsedRealtime();
             final long minRemainingTtl = getMinRemainingTtl(now);
-            final long timeToRun = calculateTimeToRun(lastScheduledQueryTaskArgs,
-                    lastScheduledQueryTaskArgs.config, now,
-                    minRemainingTtl, lastSentTime);
-            if (timeToRun > lastScheduledQueryTaskArgs.timeToRun) {
-                QueryTaskConfig lastTaskConfig = lastScheduledQueryTaskArgs.config;
+            MdnsQueryScheduler.ScheduledQueryTaskArgs args =
+                    mdnsQueryScheduler.maybeRescheduleCurrentRun(now, minRemainingTtl,
+                            lastSentTime, currentSessionId + 1);
+            if (args != null) {
                 removeScheduledTask();
-                scheduleNextRun(lastTaskConfig, minRemainingTtl, now, timeToRun, currentSessionId);
+                dependencies.sendMessageDelayed(
+                        handler,
+                        handler.obtainMessage(EVENT_START_QUERYTASK, args),
+                        calculateTimeToNextTask(args, now, sharedLog));
             }
         }
     }
@@ -464,6 +478,7 @@
             }
         }
         removeScheduledTask();
+        mdnsQueryScheduler.cancelScheduledRun();
     }
 
     private void onResponseModified(@NonNull MdnsResponse response) {
@@ -599,28 +614,14 @@
         }
     }
 
-    private static class ScheduledQueryTaskArgs {
-        private final QueryTaskConfig config;
-        private final long timeToRun;
-        private final long minTtlExpirationTimeWhenScheduled;
-        private final long sessionId;
-
-        ScheduledQueryTaskArgs(@NonNull QueryTaskConfig config, long timeToRun,
-                long minTtlExpirationTimeWhenScheduled, long sessionId) {
-            this.config = config;
-            this.timeToRun = timeToRun;
-            this.minTtlExpirationTimeWhenScheduled = minTtlExpirationTimeWhenScheduled;
-            this.sessionId = sessionId;
-        }
-    }
 
     private static class QuerySentArguments {
         private final int transactionId;
         private final List<String> subTypes = new ArrayList<>();
-        private final ScheduledQueryTaskArgs taskArgs;
+        private final MdnsQueryScheduler.ScheduledQueryTaskArgs taskArgs;
 
         QuerySentArguments(int transactionId, @NonNull List<String> subTypes,
-                @NonNull ScheduledQueryTaskArgs taskArgs) {
+                @NonNull MdnsQueryScheduler.ScheduledQueryTaskArgs taskArgs) {
             this.transactionId = transactionId;
             this.subTypes.addAll(subTypes);
             this.taskArgs = taskArgs;
@@ -629,12 +630,10 @@
 
     // A FutureTask that enqueues a single query, and schedule a new FutureTask for the next task.
     private class QueryTask implements Runnable {
-
-        private final ScheduledQueryTaskArgs taskArgs;
+        private final MdnsQueryScheduler.ScheduledQueryTaskArgs taskArgs;
         private final List<MdnsResponse> servicesToResolve = new ArrayList<>();
         private final boolean sendDiscoveryQueries;
-
-        QueryTask(@NonNull ScheduledQueryTaskArgs taskArgs,
+        QueryTask(@NonNull MdnsQueryScheduler.ScheduledQueryTaskArgs taskArgs,
                 @NonNull List<MdnsResponse> servicesToResolve, boolean sendDiscoveryQueries) {
             this.taskArgs = taskArgs;
             this.servicesToResolve.addAll(servicesToResolve);
@@ -670,27 +669,6 @@
         }
     }
 
-    private static long calculateTimeToRun(@NonNull ScheduledQueryTaskArgs taskArgs,
-            QueryTaskConfig queryTaskConfig, long now, long minRemainingTtl, long lastSentTime) {
-        final long baseDelayInMs = queryTaskConfig.delayUntilNextTaskWithoutBackoffMs;
-        if (!queryTaskConfig.shouldUseQueryBackoff()) {
-            return lastSentTime + baseDelayInMs;
-        }
-        if (minRemainingTtl <= 0) {
-            // There's no service, or there is an expired service. In any case, schedule for the
-            // minimum time, which is the base delay.
-            return lastSentTime + baseDelayInMs;
-        }
-        // If the next TTL expiration time hasn't changed, then use previous calculated timeToRun.
-        if (lastSentTime < now
-                && taskArgs.minTtlExpirationTimeWhenScheduled == now + minRemainingTtl) {
-            // Use the original scheduling time if the TTL has not changed, to avoid continuously
-            // rescheduling to 80% of the remaining TTL as time passes
-            return taskArgs.timeToRun;
-        }
-        return Math.max(now + (long) (0.8 * minRemainingTtl), lastSentTime + baseDelayInMs);
-    }
-
     private long getMinRemainingTtl(long now) {
         long minRemainingTtl = Long.MAX_VALUE;
         for (MdnsResponse response : serviceCache.getCachedServices(serviceType, socketKey)) {
@@ -710,19 +688,11 @@
         return minRemainingTtl == Long.MAX_VALUE ? 0 : minRemainingTtl;
     }
 
-    @NonNull
-    private void scheduleNextRun(@NonNull QueryTaskConfig nextRunConfig,
-            long minRemainingTtl,
-            long timeWhenScheduled, long timeToRun, long sessionId) {
-        lastScheduledQueryTaskArgs = new ScheduledQueryTaskArgs(nextRunConfig, timeToRun,
-                minRemainingTtl + timeWhenScheduled, sessionId);
-        // The timeWhenScheduled could be greater than the timeToRun if the Runnable is delayed.
-        long timeToNextTasksWithBackoffInMs = Math.max(timeToRun - timeWhenScheduled, 0);
+    private static long calculateTimeToNextTask(MdnsQueryScheduler.ScheduledQueryTaskArgs args,
+            long now, SharedLog sharedLog) {
+        long timeToNextTasksWithBackoffInMs = Math.max(args.timeToRun - now, 0);
         sharedLog.log(String.format("Next run: sessionId: %d, in %d ms",
-                lastScheduledQueryTaskArgs.sessionId, timeToNextTasksWithBackoffInMs));
-        dependencies.sendMessageDelayed(
-                handler,
-                handler.obtainMessage(EVENT_START_QUERYTASK, lastScheduledQueryTaskArgs),
-                timeToNextTasksWithBackoffInMs);
+                args.sessionId, timeToNextTasksWithBackoffInMs));
+        return timeToNextTasksWithBackoffInMs;
     }
 }
\ No newline at end of file