Merge "Remove NsdShim as it is not required" into main
diff --git a/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp b/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp
index c5f9631..ec63e41 100644
--- a/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp
+++ b/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp
@@ -119,7 +119,14 @@
       // the session and delegates writing. The corresponding handler will write
       // with the setting specified in the trace config.
       NetworkTraceHandler::Trace([&](NetworkTraceHandler::TraceContext ctx) {
-        ctx.GetDataSourceLocked()->Write(packets, ctx);
+        perfetto::LockedHandle<NetworkTraceHandler> handle =
+            ctx.GetDataSourceLocked();
+        // The underlying handle can be invalidated between when Trace starts
+        // and GetDataSourceLocked is called, but not while the LockedHandle
+        // exists and holds the lock. Check validity prior to use.
+        if (handle.valid()) {
+          handle->Write(packets, ctx);
+        }
       });
     });
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsConfigs.java b/service-t/src/com/android/server/connectivity/mdns/MdnsConfigs.java
index f5e7790..8cb3e96 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsConfigs.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsConfigs.java
@@ -54,10 +54,6 @@
         return true;
     }
 
-    public static boolean shouldCancelScanTaskWhenFutureIsNull() {
-        return false;
-    }
-
     public static long sleepTimeForSocketThreadMs() {
         return 20_000L;
     }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index 9c49b8f..e1e3170 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -18,12 +18,11 @@
 
 import static com.android.server.connectivity.mdns.util.MdnsUtils.ensureRunningOnHandlerThread;
 
-import static java.util.concurrent.TimeUnit.MILLISECONDS;
-
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.os.Handler;
 import android.os.Looper;
+import android.os.Message;
 import android.text.TextUtils;
 import android.util.ArrayMap;
 import android.util.ArraySet;
@@ -45,7 +44,6 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
-import java.util.concurrent.Future;
 import java.util.concurrent.ScheduledExecutorService;
 
 /**
@@ -56,6 +54,8 @@
 
     private static final String TAG = MdnsServiceTypeClient.class.getSimpleName();
     private static final int DEFAULT_MTU = 1500;
+    @VisibleForTesting
+    static final int EVENT_START_QUERYTASK = 1;
 
     private final String serviceType;
     private final String[] serviceTypeLabels;
@@ -65,6 +65,7 @@
     @NonNull private final SocketKey socketKey;
     @NonNull private final SharedLog sharedLog;
     @NonNull private final Handler handler;
+    @NonNull private final Dependencies dependencies;
     private final Object lock = new Object();
     private final ArrayMap<MdnsServiceBrowserListener, MdnsSearchOptions> listeners =
             new ArrayMap<>();
@@ -84,15 +85,57 @@
 
     @GuardedBy("lock")
     @Nullable
-    private Future<?> nextQueryTaskFuture;
-
-    @GuardedBy("lock")
-    @Nullable
     private QueryTask lastScheduledTask;
 
     @GuardedBy("lock")
     private long lastSentTime;
 
+    private class QueryTaskHandler extends Handler {
+        QueryTaskHandler(Looper looper) {
+            super(looper);
+        }
+
+        @Override
+        public void handleMessage(Message msg) {
+            switch (msg.what) {
+                case EVENT_START_QUERYTASK:
+                    handleStartQueryTask((QueryTask) msg.obj);
+                    break;
+                default:
+                    sharedLog.e("Unrecognized event " + msg.what);
+                    break;
+            }
+        }
+    }
+
+    /**
+     * Dependencies of MdnsServiceTypeClient, for injection in tests.
+     */
+    @VisibleForTesting
+    public static class Dependencies {
+        /**
+         * @see Handler#sendMessageDelayed(Message, long)
+         */
+        public void sendMessageDelayed(@NonNull Handler handler, @NonNull Message message,
+                long delayMillis) {
+            handler.sendMessageDelayed(message, delayMillis);
+        }
+
+        /**
+         * @see Handler#removeMessages(int)
+         */
+        public void removeMessages(@NonNull Handler handler, int what) {
+            handler.removeMessages(what);
+        }
+
+        /**
+         * @see Handler#hasMessages(int)
+         */
+        public boolean hasMessages(@NonNull Handler handler, int what) {
+            return handler.hasMessages(what);
+        }
+    }
+
     /**
      * Constructor of {@link MdnsServiceTypeClient}.
      *
@@ -107,7 +150,7 @@
             @NonNull SharedLog sharedLog,
             @NonNull Looper looper) {
         this(serviceType, socketClient, executor, new MdnsResponseDecoder.Clock(), socketKey,
-                sharedLog, looper);
+                sharedLog, looper, new Dependencies());
     }
 
     @VisibleForTesting
@@ -118,7 +161,8 @@
             @NonNull MdnsResponseDecoder.Clock clock,
             @NonNull SocketKey socketKey,
             @NonNull SharedLog sharedLog,
-            @NonNull Looper looper) {
+            @NonNull Looper looper,
+            @NonNull Dependencies dependencies) {
         this.serviceType = serviceType;
         this.socketClient = socketClient;
         this.executor = executor;
@@ -127,7 +171,8 @@
         this.clock = clock;
         this.socketKey = socketKey;
         this.sharedLog = sharedLog;
-        this.handler = new Handler(looper);
+        this.handler = new QueryTaskHandler(looper);
+        this.dependencies = dependencies;
     }
 
     private static MdnsServiceInfo buildMdnsServiceInfoFromResponse(
@@ -206,10 +251,8 @@
                     }
                 }
             }
-            // Cancel the next scheduled periodical task.
-            if (nextQueryTaskFuture != null) {
-                cancelRequestTaskLocked();
-            }
+            // Remove the next scheduled periodical task.
+            removeScheduledTaskLock();
             // Keep tracking the ScheduledFuture for the task so we can cancel it if caller is not
             // interested anymore.
             final QueryTaskConfig taskConfig = new QueryTaskConfig(
@@ -226,30 +269,29 @@
                 final QueryTaskConfig queryTaskConfig = taskConfig.getConfigForNextRun();
                 final long minRemainingTtl = getMinRemainingTtlLocked(now);
                 final long timeToRun = now + queryTaskConfig.delayUntilNextTaskWithoutBackoffMs;
-                nextQueryTaskFuture = scheduleNextRunLocked(queryTaskConfig,
-                        minRemainingTtl, now, timeToRun, currentSessionId);
+                scheduleNextRunLocked(
+                        queryTaskConfig, minRemainingTtl, now, timeToRun, currentSessionId);
             } else {
                 lastScheduledTask = new QueryTask(taskConfig,
                         now /* timeToRun */,
                         now + getMinRemainingTtlLocked(now)/* minTtlExpirationTimeWhenScheduled */,
                         currentSessionId);
-                nextQueryTaskFuture = executor.submit(lastScheduledTask);
+                handleStartQueryTask(lastScheduledTask);
             }
         }
     }
 
     @GuardedBy("lock")
-    private void cancelRequestTaskLocked() {
-        final boolean canceled = nextQueryTaskFuture.cancel(true);
-        sharedLog.log("task canceled:" + canceled + ", current session: " + currentSessionId
-                + " task hashcode: " + getHexString(nextQueryTaskFuture));
+    private void removeScheduledTaskLock() {
+        dependencies.removeMessages(handler, EVENT_START_QUERYTASK);
+        sharedLog.log("Remove EVENT_START_QUERYTASK"
+                + ", current session: " + currentSessionId);
         ++currentSessionId;
-        nextQueryTaskFuture = null;
         lastScheduledTask = null;
     }
 
-    private static String getHexString(Object o) {
-        return Integer.toHexString(System.identityHashCode(o));
+    private void handleStartQueryTask(@NonNull QueryTask task) {
+        executor.submit(task);
     }
 
     private boolean responseMatchesOptions(@NonNull MdnsResponse response,
@@ -285,8 +327,8 @@
             if (listeners.remove(listener) == null) {
                 return listeners.isEmpty();
             }
-            if (listeners.isEmpty() && nextQueryTaskFuture != null) {
-                cancelRequestTaskLocked();
+            if (listeners.isEmpty()) {
+                removeScheduledTaskLock();
             }
             return listeners.isEmpty();
         }
@@ -329,7 +371,8 @@
                     instanceNameToResponse.put(response.getServiceInstanceName(), response);
                 }
             }
-            if (nextQueryTaskFuture != null && lastScheduledTask != null
+            if (dependencies.hasMessages(handler, EVENT_START_QUERYTASK)
+                    && lastScheduledTask != null
                     && lastScheduledTask.config.shouldUseQueryBackoff()) {
                 final long now = clock.elapsedRealtime();
                 final long minRemainingTtl = getMinRemainingTtlLocked(now);
@@ -338,9 +381,9 @@
                         minRemainingTtl, lastSentTime);
                 if (timeToRun > lastScheduledTask.timeToRun) {
                     QueryTaskConfig lastTaskConfig = lastScheduledTask.config;
-                    cancelRequestTaskLocked();
-                    nextQueryTaskFuture = scheduleNextRunLocked(lastTaskConfig, minRemainingTtl,
-                            now, timeToRun, currentSessionId);
+                    removeScheduledTaskLock();
+                    scheduleNextRunLocked(
+                            lastTaskConfig, minRemainingTtl, now, timeToRun, currentSessionId);
                 }
             }
         }
@@ -373,10 +416,7 @@
                     listener.onServiceNameRemoved(serviceInfo);
                 }
             }
-
-            if (nextQueryTaskFuture != null) {
-                cancelRequestTaskLocked();
-            }
+            removeScheduledTaskLock();
         }
     }
 
@@ -678,14 +718,6 @@
                     }
                 }
 
-                if (MdnsConfigs.shouldCancelScanTaskWhenFutureIsNull()) {
-                    if (nextQueryTaskFuture == null) {
-                        // If requestTaskFuture is set to null, the task is cancelled. We can't use
-                        // isCancelled() here because this QueryTask is different from the future
-                        // that is returned from executor.schedule(). See b/71646910.
-                        return;
-                    }
-                }
                 if ((result != null)) {
                     for (int i = 0; i < listeners.size(); i++) {
                         listeners.keyAt(i).onDiscoveryQuerySent(result.second, result.first);
@@ -730,8 +762,8 @@
                 final long minRemainingTtl = getMinRemainingTtlLocked(now);
                 final long timeToRun = calculateTimeToRun(this, nextRunConfig, now,
                         minRemainingTtl, lastSentTime);
-                nextQueryTaskFuture = scheduleNextRunLocked(nextRunConfig,
-                        minRemainingTtl, now, timeToRun, lastScheduledTask.sessionId);
+                scheduleNextRunLocked(nextRunConfig, minRemainingTtl, now, timeToRun,
+                        lastScheduledTask.sessionId);
             }
         }
     }
@@ -779,7 +811,7 @@
 
     @GuardedBy("lock")
     @NonNull
-    private Future<?> scheduleNextRunLocked(@NonNull QueryTaskConfig nextRunConfig,
+    private void scheduleNextRunLocked(@NonNull QueryTaskConfig nextRunConfig,
             long minRemainingTtl,
             long timeWhenScheduled, long timeToRun, long sessionId) {
         lastScheduledTask = new QueryTask(nextRunConfig, timeToRun,
@@ -789,7 +821,9 @@
         sharedLog.log(
                 String.format("Next run: sessionId: %d, in %d ms", lastScheduledTask.sessionId,
                         timeToNextTasksWithBackoffInMs));
-        return executor.schedule(lastScheduledTask, timeToNextTasksWithBackoffInMs,
-                MILLISECONDS);
+        dependencies.sendMessageDelayed(
+                handler,
+                handler.obtainMessage(EVENT_START_QUERYTASK, lastScheduledTask),
+                timeToNextTasksWithBackoffInMs);
     }
 }
\ No newline at end of file
diff --git a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
index 368860e..d03cac6 100644
--- a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
@@ -495,8 +495,11 @@
         final AutomaticOnOffKeepalive autoKi;
         try {
             autoKi = target.withKeepaliveInfo(res.second);
-            // Close the duplicated fd.
-            target.close();
+            // Only automatic keepalives duplicate the fd.
+            if (target.mAutomaticOnOffState != STATE_ALWAYS_ON) {
+                // Close the duplicated fd.
+                target.close();
+            }
         } catch (InvalidSocketException e) {
             Log.wtf(TAG, "Fail to create AutomaticOnOffKeepalive", e);
             return;
diff --git a/service/src/com/android/server/connectivity/KeepaliveTracker.java b/service/src/com/android/server/connectivity/KeepaliveTracker.java
index 125c269..2ce186f 100644
--- a/service/src/com/android/server/connectivity/KeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/KeepaliveTracker.java
@@ -553,6 +553,8 @@
 
     private KeepaliveInfo handleUpdateKeepaliveForClat(KeepaliveInfo ki)
             throws InvalidSocketException, InvalidPacketException {
+        // Address translation applies only to NAT-T keepalives.
+        if (ki.mType != KeepaliveInfo.TYPE_NATT) return ki;
         // Only try to translate address if the packet source address is the clat's source address.
         if (!ki.mPacket.getSrcAddress().equals(ki.getNai().getClatv4SrcAddress())) return ki;
 
diff --git a/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java b/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java
index 83b9b81..7bccbde 100644
--- a/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java
@@ -82,9 +82,9 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
+import java.io.BufferedInputStream;
 import java.io.IOException;
 import java.io.InputStream;
-import java.io.InputStreamReader;
 import java.net.HttpURLConnection;
 import java.net.URL;
 import java.net.UnknownHostException;
@@ -220,7 +220,7 @@
         } else {
             Log.w(LOG_TAG, "Network: " + networkInfo.toString());
         }
-        InputStreamReader in = null;
+        BufferedInputStream in = null;
         HttpURLConnection urlc = null;
         String originalKeepAlive = System.getProperty("http.keepAlive");
         System.setProperty("http.keepAlive", "false");
@@ -236,10 +236,10 @@
             urlc.connect();
             boolean ping = urlc.getResponseCode() == 200;
             if (ping) {
-                in = new InputStreamReader((InputStream) urlc.getContent());
-                // Since the test doesn't really care about the precise amount of data, instead
-                // of reading all contents, just read few bytes at the beginning.
-                in.read();
+                in = new BufferedInputStream((InputStream) urlc.getContent());
+                while (in.read() != -1) {
+                    // Intentionally empty: drain the response (comment also suppresses lint).
+                }
             }
         } catch (Exception e) {
             Log.i(LOG_TAG, "Badness during exercising remote server: " + e);
@@ -377,9 +377,14 @@
                 .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
                 .build(), callback);
         synchronized (this) {
-            try {
-                wait((int) (TIMEOUT_MILLIS * 2.4));
-            } catch (InterruptedException e) {
+            long now = System.currentTimeMillis();
+            final long deadline = (long) (now + TIMEOUT_MILLIS * 2.4);
+            while (!callback.success && now < deadline) {
+                try {
+                    wait(deadline - now);
+                } catch (InterruptedException e) {
+                }
+                now = System.currentTimeMillis();
             }
         }
         if (callback.success) {
diff --git a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
index eeffbe1..b30c9ce 100644
--- a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
@@ -612,17 +612,60 @@
         verifyNoMoreInteractions(ignoreStubs(testInfo.socketKeepaliveCallback));
     }
 
-    @Test
-    public void testStartNattKeepalive_addressTranslationOnClat() throws Exception {
-        final InetAddress v6AddrSrc = InetAddresses.parseNumericAddress("2001:db8::1");
-        final InetAddress v6AddrDst = InetAddresses.parseNumericAddress("2001:db8::2");
-        doReturn(v6AddrDst).when(mNai).translateV4toClatV6(any());
-        doReturn(v6AddrSrc).when(mNai).getClatv6SrcAddress();
+    private void setupTestNaiForClat(InetAddress v6Src, InetAddress v6Dst) throws Exception {
+        doReturn(v6Dst).when(mNai).translateV4toClatV6(any());
+        doReturn(v6Src).when(mNai).getClatv6SrcAddress();
         doReturn(InetAddress.getByAddress(V4_SRC_ADDR)).when(mNai).getClatv4SrcAddress();
         // Setup nai to add clat address
         final LinkProperties stacked = new LinkProperties();
         stacked.setInterfaceName(TEST_V4_IFACE);
+        final InetAddress srcAddress = InetAddress.getByAddress(
+                new byte[] { (byte) 192, 0, 0, (byte) 129 });
+        mNai.linkProperties.addLinkAddress(new LinkAddress(srcAddress, 24));
         mNai.linkProperties.addStackedLink(stacked);
+    }
+
+    private TestKeepaliveInfo doStartTcpKeepalive(InetAddress srcAddr) throws Exception {
+        final KeepalivePacketData kpd = new TcpKeepalivePacketData(
+                srcAddr,
+                12345 /* srcPort */,
+                InetAddress.getByAddress(new byte[] { 8, 8, 8, 8}) /* dstAddr */,
+                12345 /* dstPort */, new byte[] {1},  111 /* tcpSeq */,
+                222 /* tcpAck */, 800 /* tcpWindow */, 2 /* tcpWindowScale */,
+                4 /* ipTos */, 64 /* ipTtl */);
+        final TestKeepaliveInfo testInfo = new TestKeepaliveInfo(kpd);
+
+        final KeepaliveInfo ki = mKeepaliveTracker.new KeepaliveInfo(
+                testInfo.socketKeepaliveCallback, mNai, kpd,
+                TEST_KEEPALIVE_INTERVAL_SEC, KeepaliveInfo.TYPE_TCP, testInfo.fd);
+        mKeepaliveTracker.setReturnedKeepaliveInfo(ki);
+
+        // Setup TCP keepalive.
+        mAOOKeepaliveTracker.startTcpKeepalive(mNai, testInfo.fd, TEST_KEEPALIVE_INTERVAL_SEC,
+                testInfo.socketKeepaliveCallback);
+        HandlerUtils.waitForIdle(mTestHandler, TIMEOUT_MS);
+        return testInfo;
+    }
+    @Test
+    public void testStartTcpKeepalive_addressTranslationOnClat() throws Exception {
+        setupTestNaiForClat(InetAddresses.parseNumericAddress("2001:db8::1") /* v6Src */,
+                InetAddresses.parseNumericAddress("2001:db8::2") /* v6Dst */);
+        final InetAddress srcAddr = InetAddress.getByAddress(V4_SRC_ADDR);
+        doStartTcpKeepalive(srcAddr);
+        final ArgumentCaptor<TcpKeepalivePacketData> tpdCaptor =
+                ArgumentCaptor.forClass(TcpKeepalivePacketData.class);
+        verify(mNai).onStartTcpSocketKeepalive(
+                eq(TEST_SLOT), eq(TEST_KEEPALIVE_INTERVAL_SEC), tpdCaptor.capture());
+        final TcpKeepalivePacketData tpd = tpdCaptor.getValue();
+        // Verify that the source address is left untranslated when clat is started.
+        assertEquals(srcAddr, tpd.getSrcAddress());
+    }
+
+    @Test
+    public void testStartNattKeepalive_addressTranslationOnClat() throws Exception {
+        final InetAddress v6AddrSrc = InetAddresses.parseNumericAddress("2001:db8::1");
+        final InetAddress v6AddrDst = InetAddresses.parseNumericAddress("2001:db8::2");
+        setupTestNaiForClat(v6AddrSrc, v6AddrDst);
 
         final TestKeepaliveInfo testInfo = doStartNattKeepalive();
         final ArgumentCaptor<NattKeepalivePacketData> kpdCaptor =
@@ -899,24 +942,8 @@
                 new byte[] { (byte) 192, 0, 0, (byte) 129 });
         mNai.linkProperties.addLinkAddress(new LinkAddress(srcAddress, 24));
 
-        final KeepalivePacketData kpd = new TcpKeepalivePacketData(
-                InetAddress.getByAddress(new byte[] { (byte) 192, 0, 0, (byte) 129 }) /* srcAddr */,
-                12345 /* srcPort */,
-                InetAddress.getByAddress(new byte[] { 8, 8, 8, 8}) /* dstAddr */,
-                12345 /* dstPort */, new byte[] {1},  111 /* tcpSeq */,
-                222 /* tcpAck */, 800 /* tcpWindow */, 2 /* tcpWindowScale */,
-                4 /* ipTos */, 64 /* ipTtl */);
-        final TestKeepaliveInfo testInfo = new TestKeepaliveInfo(kpd);
-
-        final KeepaliveInfo ki = mKeepaliveTracker.new KeepaliveInfo(
-                testInfo.socketKeepaliveCallback, mNai, kpd,
-                TEST_KEEPALIVE_INTERVAL_SEC, KeepaliveInfo.TYPE_TCP, testInfo.fd);
-        mKeepaliveTracker.setReturnedKeepaliveInfo(ki);
-
-        // Setup TCP keepalive.
-        mAOOKeepaliveTracker.startTcpKeepalive(mNai, testInfo.fd, TEST_KEEPALIVE_INTERVAL_SEC,
-                testInfo.socketKeepaliveCallback);
-        HandlerUtils.waitForIdle(mTestHandler, TIMEOUT_MS);
+        final TestKeepaliveInfo testInfo =
+                doStartTcpKeepalive(InetAddress.getByAddress(V4_SRC_ADDR));
 
         // A closed socket will result in EVENT_HANGUP and trigger error to
         // FileDescriptorEventListener.
@@ -924,6 +951,6 @@
         HandlerUtils.waitForIdle(mTestHandler, TIMEOUT_MS);
 
         // The keepalive should be removed in AutomaticOnOffKeepaliveTracker.
-        getAutoKiForBinder(testInfo.binder);
+        assertNull(getAutoKiForBinder(testInfo.binder));
     }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index cf6275f..ad5583b 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -16,6 +16,7 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.server.connectivity.mdns.MdnsServiceTypeClient.EVENT_START_QUERYTASK;
 import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
 
 import static org.junit.Assert.assertArrayEquals;
@@ -26,8 +27,10 @@
 import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.never;
@@ -43,6 +46,7 @@
 import android.net.Network;
 import android.os.Handler;
 import android.os.HandlerThread;
+import android.os.Message;
 import android.text.TextUtils;
 
 import com.android.net.module.util.CollectionUtils;
@@ -112,6 +116,8 @@
     private MdnsResponseDecoder.Clock mockDecoderClock;
     @Mock
     private SharedLog mockSharedLog;
+    @Mock
+    private MdnsServiceTypeClient.Dependencies mockDeps;
     @Captor
     private ArgumentCaptor<MdnsServiceInfo> serviceInfoCaptor;
 
@@ -119,13 +125,15 @@
 
     private DatagramPacket[] expectedIPv4Packets;
     private DatagramPacket[] expectedIPv6Packets;
-    private ScheduledFuture<?>[] expectedSendFutures;
     private FakeExecutor currentThreadExecutor = new FakeExecutor();
 
     private MdnsServiceTypeClient client;
     private SocketKey socketKey;
     private HandlerThread thread;
     private Handler handler;
+    private long latestDelayMs = 0;
+    private Message delayMessage = null;
+    private Handler realHandler = null;
 
     @Before
     @SuppressWarnings("DoNotMock")
@@ -135,15 +143,13 @@
 
         expectedIPv4Packets = new DatagramPacket[16];
         expectedIPv6Packets = new DatagramPacket[16];
-        expectedSendFutures = new ScheduledFuture<?>[16];
         socketKey = new SocketKey(mockNetwork, INTERFACE_INDEX);
 
-        for (int i = 0; i < expectedSendFutures.length; ++i) {
+        for (int i = 0; i < expectedIPv4Packets.length; ++i) {
             expectedIPv4Packets[i] = new DatagramPacket(buf, 0 /* offset */, 5 /* length */,
                     MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
             expectedIPv6Packets[i] = new DatagramPacket(buf, 0 /* offset */, 5 /* length */,
                     MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
-            expectedSendFutures[i] = Mockito.mock(ScheduledFuture.class);
         }
         when(mockPacketWriter.getPacket(IPV4_ADDRESS))
                 .thenReturn(expectedIPv4Packets[0])
@@ -184,9 +190,23 @@
         thread = new HandlerThread("MdnsServiceTypeClientTests");
         thread.start();
         handler = new Handler(thread.getLooper());
+
+        doAnswer(inv -> {
+            latestDelayMs = 0;
+            delayMessage = null;
+            return true;
+        }).when(mockDeps).removeMessages(any(Handler.class), eq(EVENT_START_QUERYTASK));
+
+        doAnswer(inv -> {
+            realHandler = (Handler) inv.getArguments()[0];
+            delayMessage = (Message) inv.getArguments()[1];
+            latestDelayMs = (long) inv.getArguments()[2];
+            return true;
+        }).when(mockDeps).sendMessageDelayed(any(Handler.class), any(Message.class), anyLong());
+
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper()) {
+                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -223,11 +243,18 @@
         runOnHandler(() -> client.notifySocketDestroyed());
     }
 
+    private void dispatchMessage() {
+        runOnHandler(() -> realHandler.dispatchMessage(delayMessage));
+        delayMessage = null;
+    }
+
     @Test
     public void sendQueries_activeScanMode() {
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build();
         startSendAndReceive(mockListenerOne, searchOptions);
+        // Always try to remove the task.
+        verify(mockDeps, times(1)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
 
         // First burst, 3 queries.
         verifyAndSendQuery(0, 0, /* expectsUnicastResponse= */ true);
@@ -265,10 +292,12 @@
                 13, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
         verifyAndSendQuery(
                 14, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
+        // Verify that the task has not been removed before stopSendAndReceive is called.
+        verify(mockDeps, times(1)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
 
         // Stop sending packets.
         stopSendAndReceive(mockListenerOne);
-        verify(expectedSendFutures[15]).cancel(true);
+        verify(mockDeps, times(2)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
     }
 
     @Test
@@ -276,6 +305,8 @@
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build();
         startSendAndReceive(mockListenerOne, searchOptions);
+        // Always try to remove the task.
+        verify(mockDeps, times(1)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
 
         // First burst, first query is sent.
         verifyAndSendQuery(0, 0, /* expectsUnicastResponse= */ true);
@@ -289,7 +320,7 @@
                         .build();
         startSendAndReceive(mockListenerOne, searchOptions);
         // The previous scheduled task should be canceled.
-        verify(expectedSendFutures[1]).cancel(true);
+        verify(mockDeps, times(2)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
 
         // Queries should continue to be sent.
         verifyAndSendQuery(1, 0, /* expectsUnicastResponse= */ true);
@@ -300,7 +331,7 @@
 
         // Stop sending packets.
         stopSendAndReceive(mockListenerOne);
-        verify(expectedSendFutures[5]).cancel(true);
+        verify(mockDeps, times(3)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
     }
 
     @Test
@@ -308,6 +339,8 @@
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(true).build();
         startSendAndReceive(mockListenerOne, searchOptions);
+        // Always try to remove the task.
+        verify(mockDeps, times(1)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
 
         // First burst, 3 query.
         verifyAndSendQuery(0, 0, /* expectsUnicastResponse= */ true);
@@ -324,7 +357,7 @@
 
         // Stop sending packets.
         stopSendAndReceive(mockListenerOne);
-        verify(expectedSendFutures[5]).cancel(true);
+        verify(mockDeps, times(2)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
     }
 
     @Test
@@ -333,6 +366,8 @@
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(
                         false).setNumOfQueriesBeforeBackoff(11).build();
         startSendAndReceive(mockListenerOne, searchOptions);
+        // Always try to remove the task.
+        verify(mockDeps, times(1)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
 
         // First burst, 3 queries.
         verifyAndSendQuery(0, 0, /* expectsUnicastResponse= */ true);
@@ -367,16 +402,21 @@
         // 0.8 * smallestRemainingTtl is larger than time to next run.
         long currentTime = TEST_TTL / 2 + TEST_ELAPSED_REALTIME;
         doReturn(currentTime).when(mockDecoderClock).elapsedRealtime();
+        doReturn(true).when(mockDeps).hasMessages(any(), eq(EVENT_START_QUERYTASK));
         processResponse(createResponse(
                 "service-instance-1", "192.0.2.123", 5353,
                 SERVICE_TYPE_LABELS,
                 Collections.emptyMap(), TEST_TTL), socketKey);
-        verifyAndSendQuery(12, (long) (TEST_TTL / 2 * 0.8), /* expectsUnicastResponse= */
-                false);
+        verify(mockDeps, times(2)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
+        assertNotNull(delayMessage);
+        verifyAndSendQuery(12 /* index */, (long) (TEST_TTL / 2 * 0.8) /* timeInMs */,
+                false /* expectsUnicastResponse */, true /* multipleSocketDiscovery */,
+                14 /* scheduledCount */);
         currentTime += (long) (TEST_TTL / 2 * 0.8);
         doReturn(currentTime).when(mockDecoderClock).elapsedRealtime();
-        verifyAndSendQuery(
-                13, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
+        verifyAndSendQuery(13 /* index */, MdnsConfigs.timeBetweenQueriesInBurstMs(),
+                false /* expectsUnicastResponse */, true /* multipleSocketDiscovery */,
+                15 /* scheduledCount */);
     }
 
     @Test
@@ -385,29 +425,36 @@
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(
                         true).setNumOfQueriesBeforeBackoff(3).build();
         startSendAndReceive(mockListenerOne, searchOptions);
-        verifyAndSendQuery(0, 0, /* expectsUnicastResponse= */ true);
-        verifyAndSendQuery(
-                1, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
-        verifyAndSendQuery(
-                2, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
-        verifyAndSendQuery(3, MdnsConfigs.timeBetweenBurstsMs(), /* expectsUnicastResponse= */
-                false);
-        assertEquals(4, currentThreadExecutor.getNumOfScheduledFuture());
+        // Starting discovery always tries to remove any pending query task message.
+        verify(mockDeps, times(1)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
+
+        verifyAndSendQuery(0 /* index */, 0 /* timeInMs */, true /* expectsUnicastResponse */,
+                true /* multipleSocketDiscovery */, 1 /* scheduledCount */);
+        verifyAndSendQuery(1 /* index */, MdnsConfigs.timeBetweenQueriesInBurstMs(),
+                false /* expectsUnicastResponse */, true /* multipleSocketDiscovery */,
+                2 /* scheduledCount */);
+        verifyAndSendQuery(2 /* index */, MdnsConfigs.timeBetweenQueriesInBurstMs(),
+                false /* expectsUnicastResponse */, true /* multipleSocketDiscovery */,
+                3 /* scheduledCount */);
+        verifyAndSendQuery(3 /* index */, MdnsConfigs.timeBetweenBurstsMs(),
+                false /* expectsUnicastResponse */, true /* multipleSocketDiscovery */,
+                4 /* scheduledCount */);
 
         // In backoff mode, the current scheduled task will be canceled and reschedule if the
         // 0.8 * smallestRemainingTtl is larger than time to next run.
         doReturn(TEST_ELAPSED_REALTIME + 20000).when(mockDecoderClock).elapsedRealtime();
+        doReturn(true).when(mockDeps).hasMessages(any(), eq(EVENT_START_QUERYTASK));
         processResponse(createResponse(
                 "service-instance-1", "192.0.2.123", 5353,
                 SERVICE_TYPE_LABELS,
                 Collections.emptyMap(), TEST_TTL), socketKey);
-        verify(expectedSendFutures[4]).cancel(true);
-        assertEquals(5, currentThreadExecutor.getNumOfScheduledFuture());
-        verifyAndSendQuery(4, 80000 /* timeInMs */, false /* expectsUnicastResponse */);
-        assertEquals(6, currentThreadExecutor.getNumOfScheduledFuture());
+        verify(mockDeps, times(2)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
+        assertNotNull(delayMessage);
+        verifyAndSendQuery(4 /* index */, 80000 /* timeInMs */, false /* expectsUnicastResponse */,
+                true /* multipleSocketDiscovery */, 6 /* scheduledCount */);
         // Next run should also be scheduled in 0.8 * smallestRemainingTtl
-        verifyAndSendQuery(5, 80000 /* timeInMs */, false /* expectsUnicastResponse */);
-        assertEquals(7, currentThreadExecutor.getNumOfScheduledFuture());
+        verifyAndSendQuery(5 /* index */, 80000 /* timeInMs */, false /* expectsUnicastResponse */,
+                true /* multipleSocketDiscovery */, 7 /* scheduledCount */);
 
         // If the records is not refreshed, the current scheduled task will not be canceled.
         doReturn(TEST_ELAPSED_REALTIME + 20001).when(mockDecoderClock).elapsedRealtime();
@@ -416,7 +463,7 @@
                 SERVICE_TYPE_LABELS,
                 Collections.emptyMap(), TEST_TTL,
                 TEST_ELAPSED_REALTIME - 1), socketKey);
-        verify(expectedSendFutures[7], never()).cancel(true);
+        verify(mockDeps, times(2)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
 
         // In backoff mode, the current scheduled task will not be canceled if the
         // 0.8 * smallestRemainingTtl is smaller than time to next run.
@@ -425,10 +472,10 @@
                 "service-instance-1", "192.0.2.123", 5353,
                 SERVICE_TYPE_LABELS,
                 Collections.emptyMap(), TEST_TTL), socketKey);
-        verify(expectedSendFutures[7], never()).cancel(true);
+        verify(mockDeps, times(2)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
 
         stopSendAndReceive(mockListenerOne);
-        verify(expectedSendFutures[7]).cancel(true);
+        verify(mockDeps, times(3)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
     }
 
     @Test
@@ -436,6 +483,8 @@
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(true).build();
         startSendAndReceive(mockListenerOne, searchOptions);
+        // Starting discovery always tries to remove any pending query task message.
+        verify(mockDeps, times(1)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
 
         // First burst, first query is sent.
         verifyAndSendQuery(0, 0, /* expectsUnicastResponse= */ true);
@@ -449,7 +498,7 @@
                         .build();
         startSendAndReceive(mockListenerOne, searchOptions);
         // The previous scheduled task should be canceled.
-        verify(expectedSendFutures[1]).cancel(true);
+        verify(mockDeps, times(2)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
 
         // Queries should continue to be sent.
         verifyAndSendQuery(1, 0, /* expectsUnicastResponse= */ true);
@@ -460,7 +509,7 @@
 
         // Stop sending packets.
         stopSendAndReceive(mockListenerOne);
-        verify(expectedSendFutures[5]).cancel(true);
+        verify(mockDeps, times(3)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
     }
 
     @Test
@@ -596,10 +645,9 @@
 
         // This time no query is submitted, only scheduled
         assertNull(currentThreadExecutor.getAndClearSubmittedRunnable());
-        assertNotNull(currentThreadExecutor.getAndClearLastScheduledRunnable());
         // This just skips the first query of the first burst
-        assertEquals(MdnsConfigs.timeBetweenQueriesInBurstMs(),
-                currentThreadExecutor.getAndClearLastScheduledDelayInMs());
+        verify(mockDeps).sendMessageDelayed(
+                any(), any(), eq(MdnsConfigs.timeBetweenQueriesInBurstMs()));
     }
 
     private static void verifyServiceInfo(MdnsServiceInfo serviceInfo, String serviceName,
@@ -853,7 +901,7 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper()) {
+                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -896,7 +944,7 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper()) {
+                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -929,7 +977,7 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper()) {
+                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -1049,7 +1097,7 @@
     @Test
     public void testProcessResponse_Resolve() throws Exception {
         client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                socketKey, mockSharedLog, thread.getLooper());
+                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps);
 
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
@@ -1070,6 +1118,7 @@
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
                 srvTxtQueryCaptor.capture(),
                 eq(socketKey), eq(false));
+        assertNotNull(delayMessage);
 
         final MdnsPacket srvTxtQueryPacket = MdnsPacket.parse(
                 new MdnsPacketReader(srvTxtQueryCaptor.getValue()));
@@ -1095,6 +1144,7 @@
         processResponse(srvTxtResponse, socketKey);
 
         // Expect a query for A/AAAA
+        dispatchMessage();
         final ArgumentCaptor<DatagramPacket> addressQueryCaptor =
                 ArgumentCaptor.forClass(DatagramPacket.class);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
@@ -1139,7 +1189,7 @@
     @Test
     public void testRenewTxtSrvInResolve() throws Exception {
         client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper());
+                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps);
 
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
@@ -1160,6 +1210,7 @@
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
                 srvTxtQueryCaptor.capture(),
                 eq(socketKey), eq(false));
+        assertNotNull(delayMessage);
 
         final MdnsPacket srvTxtQueryPacket = MdnsPacket.parse(
                 new MdnsPacketReader(srvTxtQueryCaptor.getValue()));
@@ -1187,6 +1238,7 @@
                 Collections.emptyList() /* authorityRecords */,
                 Collections.emptyList() /* additionalRecords */);
         processResponse(srvTxtResponse, socketKey);
+        dispatchMessage();
         inOrder.verify(mockListenerOne).onServiceNameDiscovered(any());
         inOrder.verify(mockListenerOne).onServiceFound(any());
 
@@ -1197,6 +1249,8 @@
         // Advance time so 75% of TTL passes and re-execute
         doReturn(TEST_ELAPSED_REALTIME + (long) (TEST_TTL * 0.75))
                 .when(mockDecoderClock).elapsedRealtime();
+        assertNotNull(delayMessage);
+        dispatchMessage();
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
 
         // Expect a renewal query
@@ -1211,6 +1265,7 @@
                 new MdnsPacketReader(renewalQueryCaptor.getValue()));
         assertTrue(hasQuestion(renewalPacket, MdnsRecord.TYPE_ANY, serviceName));
         inOrder.verifyNoMoreInteractions();
+        assertNotNull(delayMessage);
 
         long updatedReceiptTime =  TEST_ELAPSED_REALTIME + TEST_TTL;
         final MdnsPacket refreshedSrvTxtResponse = new MdnsPacket(
@@ -1232,6 +1287,7 @@
                 Collections.emptyList() /* authorityRecords */,
                 Collections.emptyList() /* additionalRecords */);
         processResponse(refreshedSrvTxtResponse, socketKey);
+        dispatchMessage();
 
         // Advance time to updatedReceiptTime + 1, expected no refresh query because the cache
         // should contain the record that have update last receipt time.
@@ -1243,7 +1299,7 @@
     @Test
     public void testProcessResponse_ResolveExcludesOtherServices() {
         client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                socketKey, mockSharedLog, thread.getLooper());
+                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps);
 
         final String requestedInstance = "instance1";
         final String otherInstance = "instance2";
@@ -1307,7 +1363,7 @@
     @Test
     public void testProcessResponse_SubtypeDiscoveryLimitedToSubtype() {
         client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                socketKey, mockSharedLog, thread.getLooper());
+                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps);
 
         final String matchingInstance = "instance1";
         final String subtype = "_subtype";
@@ -1388,7 +1444,7 @@
     @Test
     public void testNotifySocketDestroyed() throws Exception {
         client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                socketKey, mockSharedLog, thread.getLooper());
+                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps);
 
         final String requestedInstance = "instance1";
         final String otherInstance = "instance2";
@@ -1399,6 +1455,8 @@
                 .setResolveInstanceName("instance1").build();
 
         startSendAndReceive(mockListenerOne, resolveOptions);
+        // Starting discovery always tries to remove any pending query task message.
+        verify(mockDeps, times(1)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
         // Ensure the first task is executed so it schedules a future task
         currentThreadExecutor.getAndClearSubmittedFuture().get(
                 TEST_TIMEOUT_MS, TimeUnit.MILLISECONDS);
@@ -1407,7 +1465,7 @@
                         Integer.MAX_VALUE).build());
 
         // Filing the second request cancels the first future
-        verify(expectedSendFutures[0]).cancel(true);
+        verify(mockDeps, times(2)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
 
         // Ensure it gets executed too
         currentThreadExecutor.getAndClearSubmittedFuture().get(
@@ -1425,9 +1483,8 @@
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
                 socketKey);
 
-        verify(expectedSendFutures[1], never()).cancel(true);
         notifySocketDestroyed();
-        verify(expectedSendFutures[1]).cancel(true);
+        verify(mockDeps, times(3)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
 
         // mockListenerOne gets notified for the requested instance
         final InOrder inOrder1 = inOrder(mockListenerOne);
@@ -1461,13 +1518,17 @@
     // verifies that the right query was enqueued with the right delay, and send query by executing
     // the runnable.
     private void verifyAndSendQuery(int index, long timeInMs, boolean expectsUnicastResponse) {
-        verifyAndSendQuery(
-                index, timeInMs, expectsUnicastResponse, true /* multipleSocketDiscovery */);
+        verifyAndSendQuery(index, timeInMs, expectsUnicastResponse,
+                true /* multipleSocketDiscovery */, index + 1 /* scheduledCount */);
     }
 
     private void verifyAndSendQuery(int index, long timeInMs, boolean expectsUnicastResponse,
-            boolean multipleSocketDiscovery) {
-        assertEquals(timeInMs, currentThreadExecutor.getAndClearLastScheduledDelayInMs());
+            boolean multipleSocketDiscovery, int scheduledCount) {
+        // Dispatch the message, but only when delayMessage and realHandler are both set.
+        if (delayMessage != null && realHandler != null) {
+            dispatchMessage();
+        }
+        assertEquals(timeInMs, latestDelayMs);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         if (expectsUnicastResponse) {
             verify(mockSocketClient).sendPacketRequestingUnicastResponse(
@@ -1484,6 +1545,9 @@
                         expectedIPv6Packets[index], socketKey, false);
             }
         }
+        // Verify the task has been scheduled: scheduledCount sendMessageDelayed calls so far.
+        verify(mockDeps, times(scheduledCount))
+                .sendMessageDelayed(any(Handler.class), any(Message.class), anyLong());
     }
 
     private static String[] getTestServiceName(String instanceName) {
@@ -1528,7 +1592,7 @@
         public ScheduledFuture<?> schedule(Runnable command, long delay, TimeUnit unit) {
             lastScheduledDelayInMs = delay;
             lastScheduledRunnable = command;
-            return expectedSendFutures[futureIndex++];
+            return Mockito.mock(ScheduledFuture.class);
         }
 
         // Returns the delay of the last scheduled task, and clear it.
@@ -1556,10 +1620,6 @@
             lastSubmittedFuture = null;
             return val;
         }
-
-        public int getNumOfScheduledFuture() {
-            return futureIndex - 1;
-        }
     }
 
     private MdnsPacket createResponse(