Shedule task with current run if there is a service in the cache

A service record exists in the cache, and a query transmission
with a long duration has been scheduled because the service
information was recently updated. However, the query will be sent
immediately if a new discovery/resolution/service info callback
listener is registered. Therefore, to reduce unnecessary queries
in aggressive query mode, if a service exists in the cache, the
query timing should either reuse the scheduled run or use a
backoff time if there is no scheduled run.

Fix: 345146854
Test: atest FrameworksNetTests NsdManagerTest
Change-Id: I3335a0f6e0018526a8677dfbfbf053dd0dc906e5
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsQueryScheduler.java b/service-t/src/com/android/server/connectivity/mdns/MdnsQueryScheduler.java
index 5c02767..cfeca5d 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsQueryScheduler.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsQueryScheduler.java
@@ -78,7 +78,7 @@
 
         final long timeToRun = calculateTimeToRun(mLastScheduledQueryTaskArgs,
                 mLastScheduledQueryTaskArgs.config, now, minRemainingTtl, lastSentTime,
-                numOfQueriesBeforeBackoff);
+                numOfQueriesBeforeBackoff, false /* forceEnableBackoff */);
 
         if (timeToRun <= mLastScheduledQueryTaskArgs.timeToRun) {
             return null;
@@ -102,14 +102,16 @@
             long lastSentTime,
             long sessionId,
             int queryMode,
-            int numOfQueriesBeforeBackoff) {
+            int numOfQueriesBeforeBackoff,
+            boolean forceEnableBackoff) {
         final QueryTaskConfig nextRunConfig = currentConfig.getConfigForNextRun(queryMode);
-        final long timeToRun;
-        if (mLastScheduledQueryTaskArgs == null) {
+        long timeToRun;
+        if (mLastScheduledQueryTaskArgs == null && !forceEnableBackoff) {
             timeToRun = now + nextRunConfig.delayUntilNextTaskWithoutBackoffMs;
         } else {
             timeToRun = calculateTimeToRun(mLastScheduledQueryTaskArgs,
-                    nextRunConfig, now, minRemainingTtl, lastSentTime, numOfQueriesBeforeBackoff);
+                    nextRunConfig, now, minRemainingTtl, lastSentTime, numOfQueriesBeforeBackoff,
+                    forceEnableBackoff);
         }
         mLastScheduledQueryTaskArgs = new ScheduledQueryTaskArgs(nextRunConfig, timeToRun,
                 minRemainingTtl + now,
@@ -128,11 +130,12 @@
         return mLastScheduledQueryTaskArgs;
     }
 
-    private static long calculateTimeToRun(@NonNull ScheduledQueryTaskArgs taskArgs,
+    private static long calculateTimeToRun(@Nullable ScheduledQueryTaskArgs taskArgs,
             QueryTaskConfig queryTaskConfig, long now, long minRemainingTtl, long lastSentTime,
-            int numOfQueriesBeforeBackoff) {
+            int numOfQueriesBeforeBackoff, boolean forceEnableBackoff) {
         final long baseDelayInMs = queryTaskConfig.delayUntilNextTaskWithoutBackoffMs;
-        if (!queryTaskConfig.shouldUseQueryBackoff(numOfQueriesBeforeBackoff)) {
+        if (!(forceEnableBackoff
+                || queryTaskConfig.shouldUseQueryBackoff(numOfQueriesBeforeBackoff))) {
             return lastSentTime + baseDelayInMs;
         }
         if (minRemainingTtl <= 0) {
@@ -141,7 +144,7 @@
             return lastSentTime + baseDelayInMs;
         }
         // If the next TTL expiration time hasn't changed, then use previous calculated timeToRun.
-        if (lastSentTime < now
+        if (lastSentTime < now && taskArgs != null
                 && taskArgs.minTtlExpirationTimeWhenScheduled == now + minRemainingTtl) {
             // Use the original scheduling time if the TTL has not changed, to avoid continuously
             // rescheduling to 80% of the remaining TTL as time passes
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index 8959c1b..4b55ea9 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -16,6 +16,7 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.server.connectivity.mdns.MdnsSearchOptions.AGGRESSIVE_QUERY_MODE;
 import static com.android.server.connectivity.mdns.MdnsServiceCache.ServiceExpiredCallback;
 import static com.android.server.connectivity.mdns.MdnsServiceCache.findMatchedResponse;
 import static com.android.server.connectivity.mdns.util.MdnsUtils.Clock;
@@ -182,7 +183,8 @@
                                     lastSentTime,
                                     sentResult.taskArgs.sessionId,
                                     searchOptions.getQueryMode(),
-                                    searchOptions.numOfQueriesBeforeBackoff()
+                                    searchOptions.numOfQueriesBeforeBackoff(),
+                                    false /* forceEnableBackoff */
                             );
                     dependencies.sendMessageDelayed(
                             handler,
@@ -396,11 +398,13 @@
         }
         // Remove the next scheduled periodical task.
         removeScheduledTask();
-        mdnsQueryScheduler.cancelScheduledRun();
-        // Keep tracking the ScheduledFuture for the task so we can cancel it if caller is not
-        // interested anymore.
-        final QueryTaskConfig taskConfig = new QueryTaskConfig(
-                searchOptions.getQueryMode());
+        final boolean forceEnableBackoff =
+                (searchOptions.getQueryMode() == AGGRESSIVE_QUERY_MODE && hadReply);
+        // Keep the latest scheduled run for rescheduling if there is a service in the cache.
+        if (!(forceEnableBackoff)) {
+            mdnsQueryScheduler.cancelScheduledRun();
+        }
+        final QueryTaskConfig taskConfig = new QueryTaskConfig(searchOptions.getQueryMode());
         final long now = clock.elapsedRealtime();
         if (lastSentTime == 0) {
             lastSentTime = now;
@@ -415,7 +419,8 @@
                             lastSentTime,
                             currentSessionId,
                             searchOptions.getQueryMode(),
-                            searchOptions.numOfQueriesBeforeBackoff()
+                            searchOptions.numOfQueriesBeforeBackoff(),
+                            forceEnableBackoff
                     );
             dependencies.sendMessageDelayed(
                     handler,
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index 569f4d7..da0bc88 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -2056,6 +2056,64 @@
         assertTrue(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, subtypeLabels));
     }
 
+    @Test
+    public void sendQueries_AggressiveQueryMode_ServiceInCache() {
+        final int numOfQueriesBeforeBackoff = 11;
+        final MdnsSearchOptions searchOptions = MdnsSearchOptions.newBuilder()
+                .setQueryMode(AGGRESSIVE_QUERY_MODE)
+                .setNumOfQueriesBeforeBackoff(numOfQueriesBeforeBackoff)
+                .build();
+        startSendAndReceive(mockListenerOne, searchOptions);
+        verify(mockDeps, times(1)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
+
+        int burstCounter = 0;
+        int betweenBurstTime = 0;
+        for (int i = 0; i < numOfQueriesBeforeBackoff; i += 3) {
+            verifyAndSendQuery(i, betweenBurstTime, /* expectsUnicastResponse= */ true);
+            verifyAndSendQuery(i + 1, /* timeInMs= */ 0, /* expectsUnicastResponse= */ false);
+            verifyAndSendQuery(i + 2, TIME_BETWEEN_RETRANSMISSION_QUERIES_IN_BURST_MS,
+                    /* expectsUnicastResponse= */ false);
+            betweenBurstTime = Math.min(
+                    INITIAL_AGGRESSIVE_TIME_BETWEEN_BURSTS_MS * (int) Math.pow(2, burstCounter),
+                    MAX_TIME_BETWEEN_AGGRESSIVE_BURSTS_MS);
+            burstCounter++;
+        }
+        // In backoff mode, the current scheduled task will be canceled and reschedule if the
+        // 0.8 * smallestRemainingTtl is larger than time to next run.
+        long currentTime = TEST_TTL / 2 + TEST_ELAPSED_REALTIME;
+        doReturn(currentTime).when(mockDecoderClock).elapsedRealtime();
+        doReturn(true).when(mockDeps).hasMessages(any(), eq(EVENT_START_QUERYTASK));
+        processResponse(createResponse(
+                "service-instance-1", "192.0.2.123", 5353,
+                SERVICE_TYPE_LABELS,
+                Collections.emptyMap(), TEST_TTL), socketKey);
+        verify(mockDeps, times(2)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
+        assertNotNull(delayMessage);
+        assertEquals((long) (TEST_TTL / 2 * 0.8), latestDelayMs);
+
+        // Register another listener. There is a service in cache, the query time should be
+        // rescheduled with previous run.
+        currentTime += (long) ((TEST_TTL / 2 * 0.8) - 500L);
+        doReturn(currentTime).when(mockDecoderClock).elapsedRealtime();
+        startSendAndReceive(mockListenerTwo, searchOptions);
+        verify(mockDeps, times(3)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
+        assertNotNull(delayMessage);
+        assertEquals(500L, latestDelayMs);
+
+        // Stop all listeners
+        stopSendAndReceive(mockListenerOne);
+        stopSendAndReceive(mockListenerTwo);
+        verify(mockDeps, times(4)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
+
+        // Register a new listener. There is a service in cache, the query time should be
+        // rescheduled with remaining ttl.
+        currentTime += 400L;
+        doReturn(currentTime).when(mockDecoderClock).elapsedRealtime();
+        startSendAndReceive(mockListenerOne, searchOptions);
+        assertNotNull(delayMessage);
+        assertEquals(9680L, latestDelayMs);
+    }
+
     private static MdnsServiceInfo matchServiceName(String name) {
         return argThat(info -> info.getServiceInstanceName().equals(name));
     }