Merge "Remove synchronized lock in MdnsServiceTypeClient" into main
diff --git a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
index bd4ec20..13f6dac 100644
--- a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
+++ b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
@@ -17,7 +17,6 @@
 package com.android.server.connectivity.mdns;
 
 import android.annotation.NonNull;
-import android.annotation.Nullable;
 import android.text.TextUtils;
 import android.util.Log;
 import android.util.Pair;
@@ -106,12 +105,11 @@
     // Incompatible return type for override of Callable#call().
     @SuppressWarnings("nullness:override.return.invalid")
     @Override
-    @Nullable
     public Pair<Integer, List<String>> call() {
         try {
             MdnsSocketClientBase requestSender = weakRequestSender.get();
             if (requestSender == null) {
-                return null;
+                return Pair.create(-1, new ArrayList<>());
             }
 
             int numQuestions = 0;
@@ -158,7 +156,7 @@
 
             if (numQuestions == 0) {
                 // No query to send
-                return null;
+                return Pair.create(-1, new ArrayList<>());
             }
 
             // Header.
@@ -197,7 +195,7 @@
         } catch (IOException e) {
             LOGGER.e(String.format("Failed to create mDNS packet for subtype: %s.",
                     TextUtils.join(",", subtypes)), e);
-            return null;
+            return Pair.create(-1, new ArrayList<>());
         }
     }
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index e1e3170..7035c90 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -56,6 +56,7 @@
     private static final int DEFAULT_MTU = 1500;
     @VisibleForTesting
     static final int EVENT_START_QUERYTASK = 1;
+    static final int EVENT_QUERY_RESULT = 2;
 
     private final String serviceType;
     private final String[] serviceTypeLabels;
@@ -66,11 +67,9 @@
     @NonNull private final SharedLog sharedLog;
     @NonNull private final Handler handler;
     @NonNull private final Dependencies dependencies;
-    private final Object lock = new Object();
     private final ArrayMap<MdnsServiceBrowserListener, MdnsSearchOptions> listeners =
             new ArrayMap<>();
     // TODO: change instanceNameToResponse to TreeMap with case insensitive comparator.
-    @GuardedBy("lock")
     private final Map<String, MdnsResponse> instanceNameToResponse = new HashMap<>();
     private final boolean removeServiceAfterTtlExpires =
             MdnsConfigs.removeServiceAfterTtlExpires();
@@ -83,11 +82,8 @@
     // new subtypes. It stays the same between packets for same subtypes.
     private long currentSessionId = 0;
 
-    @GuardedBy("lock")
     @Nullable
-    private QueryTask lastScheduledTask;
-
-    @GuardedBy("lock")
+    private ScheduledQueryTaskArgs lastScheduledQueryTaskArgs;
     private long lastSentTime;
 
     private class QueryTaskHandler extends Handler {
@@ -98,9 +94,48 @@
         @Override
         public void handleMessage(Message msg) {
             switch (msg.what) {
-                case EVENT_START_QUERYTASK:
-                    handleStartQueryTask((QueryTask) msg.obj);
+                case EVENT_START_QUERYTASK: {
+                    final ScheduledQueryTaskArgs taskArgs = (ScheduledQueryTaskArgs) msg.obj;
+                    // Create the QueryTask right before running it instead of scheduling it in
+                    // advance: the result of "makeResponsesForResolve" depends on the answers
+                    // received before it is called, so calling it just before sending the query
+                    // takes all answers received so far into account.
+                    final List<MdnsResponse> servicesToResolve = makeResponsesForResolve(socketKey);
+                    final QueryTask queryTask = new QueryTask(taskArgs, servicesToResolve,
+                            servicesToResolve.size() < listeners.size() /* sendDiscoveryQueries */);
+                    executor.submit(queryTask);
                     break;
+                }
+                case EVENT_QUERY_RESULT: {
+                    final QuerySentResult sentResult = (QuerySentResult) msg.obj;
+                    if (MdnsConfigs.useSessionIdToScheduleMdnsTask()) {
+                        // In case that the task is not canceled successfully, use session ID to
+                        // check if this task should continue to schedule more.
+                        if (sentResult.taskArgs.sessionId != currentSessionId) {
+                            break;
+                        }
+                    }
+
+                    if ((sentResult.transactionId != -1)) {
+                        for (int i = 0; i < listeners.size(); i++) {
+                            listeners.keyAt(i).onDiscoveryQuerySent(
+                                    sentResult.subTypes, sentResult.transactionId);
+                        }
+                    }
+
+                    tryRemoveServiceAfterTtlExpires();
+
+                    final QueryTaskConfig nextRunConfig =
+                            sentResult.taskArgs.config.getConfigForNextRun();
+                    final long now = clock.elapsedRealtime();
+                    lastSentTime = now;
+                    final long minRemainingTtl = getMinRemainingTtl(now);
+                    final long timeToRun = calculateTimeToRun(lastScheduledQueryTaskArgs,
+                            nextRunConfig, now, minRemainingTtl, lastSentTime);
+                    scheduleNextRun(nextRunConfig, minRemainingTtl, now, timeToRun,
+                            lastScheduledQueryTaskArgs.sessionId);
+                    break;
+                }
                 default:
                     sharedLog.e("Unrecognized event " + msg.what);
                     break;
@@ -134,6 +169,13 @@
         public boolean hasMessages(@NonNull Handler handler, int what) {
             return handler.hasMessages(what);
         }
+
+        /**
+         * @see Handler#sendMessage(Message)
+         */
+        public void sendMessage(@NonNull Handler handler, @NonNull Message message) {
+            handler.sendMessage(message);
+        }
     }
 
     /**
@@ -236,62 +278,57 @@
             @NonNull MdnsServiceBrowserListener listener,
             @NonNull MdnsSearchOptions searchOptions) {
         ensureRunningOnHandlerThread(handler);
-        synchronized (lock) {
-            this.searchOptions = searchOptions;
-            boolean hadReply = false;
-            if (listeners.put(listener, searchOptions) == null) {
-                for (MdnsResponse existingResponse : instanceNameToResponse.values()) {
-                    if (!responseMatchesOptions(existingResponse, searchOptions)) continue;
-                    final MdnsServiceInfo info =
-                            buildMdnsServiceInfoFromResponse(existingResponse, serviceTypeLabels);
-                    listener.onServiceNameDiscovered(info);
-                    if (existingResponse.isComplete()) {
-                        listener.onServiceFound(info);
-                        hadReply = true;
-                    }
+        this.searchOptions = searchOptions;
+        boolean hadReply = false;
+        if (listeners.put(listener, searchOptions) == null) {
+            for (MdnsResponse existingResponse : instanceNameToResponse.values()) {
+                if (!responseMatchesOptions(existingResponse, searchOptions)) continue;
+                final MdnsServiceInfo info =
+                        buildMdnsServiceInfoFromResponse(existingResponse, serviceTypeLabels);
+                listener.onServiceNameDiscovered(info);
+                if (existingResponse.isComplete()) {
+                    listener.onServiceFound(info);
+                    hadReply = true;
                 }
             }
-            // Remove the next scheduled periodical task.
-            removeScheduledTaskLock();
-            // Keep tracking the ScheduledFuture for the task so we can cancel it if caller is not
-            // interested anymore.
-            final QueryTaskConfig taskConfig = new QueryTaskConfig(
-                    searchOptions.getSubtypes(),
-                    searchOptions.isPassiveMode(),
-                    searchOptions.onlyUseIpv6OnIpv6OnlyNetworks(),
-                    searchOptions.numOfQueriesBeforeBackoff(),
-                    socketKey);
-            final long now = clock.elapsedRealtime();
-            if (lastSentTime == 0) {
-                lastSentTime = now;
-            }
-            if (hadReply) {
-                final QueryTaskConfig queryTaskConfig = taskConfig.getConfigForNextRun();
-                final long minRemainingTtl = getMinRemainingTtlLocked(now);
-                final long timeToRun = now + queryTaskConfig.delayUntilNextTaskWithoutBackoffMs;
-                scheduleNextRunLocked(
-                        queryTaskConfig, minRemainingTtl, now, timeToRun, currentSessionId);
-            } else {
-                lastScheduledTask = new QueryTask(taskConfig,
-                        now /* timeToRun */,
-                        now + getMinRemainingTtlLocked(now)/* minTtlExpirationTimeWhenScheduled */,
-                        currentSessionId);
-                handleStartQueryTask(lastScheduledTask);
-            }
+        }
+        // Remove the next scheduled periodical task.
+        removeScheduledTask();
+        // The args of the next run are kept in lastScheduledQueryTaskArgs, so that the run can
+        // be cancelled (see removeScheduledTask) if the caller is no longer interested.
+        final QueryTaskConfig taskConfig = new QueryTaskConfig(
+                searchOptions.getSubtypes(),
+                searchOptions.isPassiveMode(),
+                searchOptions.onlyUseIpv6OnIpv6OnlyNetworks(),
+                searchOptions.numOfQueriesBeforeBackoff(),
+                socketKey);
+        final long now = clock.elapsedRealtime();
+        if (lastSentTime == 0) {
+            lastSentTime = now;
+        }
+        if (hadReply) {
+            final QueryTaskConfig queryTaskConfig = taskConfig.getConfigForNextRun();
+            final long minRemainingTtl = getMinRemainingTtl(now);
+            final long timeToRun = now + queryTaskConfig.delayUntilNextTaskWithoutBackoffMs;
+            scheduleNextRun(
+                    queryTaskConfig, minRemainingTtl, now, timeToRun, currentSessionId);
+        } else {
+            final List<MdnsResponse> servicesToResolve = makeResponsesForResolve(socketKey);
+            lastScheduledQueryTaskArgs = new ScheduledQueryTaskArgs(taskConfig, now /* timeToRun */,
+                    now + getMinRemainingTtl(now)/* minTtlExpirationTimeWhenScheduled */,
+                    currentSessionId);
+            final QueryTask queryTask = new QueryTask(lastScheduledQueryTaskArgs, servicesToResolve,
+                    servicesToResolve.size() < listeners.size() /* sendDiscoveryQueries */);
+            executor.submit(queryTask);
         }
     }
 
-    @GuardedBy("lock")
-    private void removeScheduledTaskLock() {
+    private void removeScheduledTask() {
         dependencies.removeMessages(handler, EVENT_START_QUERYTASK);
         sharedLog.log("Remove EVENT_START_QUERYTASK"
                 + ", current session: " + currentSessionId);
         ++currentSessionId;
-        lastScheduledTask = null;
-    }
-
-    private void handleStartQueryTask(@NonNull QueryTask task) {
-        executor.submit(task);
+        lastScheduledQueryTaskArgs = null; // No EVENT_START_QUERYTASK message is pending now.
     }
 
     private boolean responseMatchesOptions(@NonNull MdnsResponse response,
@@ -323,15 +360,13 @@
      */
     public boolean stopSendAndReceive(@NonNull MdnsServiceBrowserListener listener) {
         ensureRunningOnHandlerThread(handler);
-        synchronized (lock) {
-            if (listeners.remove(listener) == null) {
-                return listeners.isEmpty();
-            }
-            if (listeners.isEmpty()) {
-                removeScheduledTaskLock();
-            }
+        if (listeners.remove(listener) == null) { // Listener was not registered.
             return listeners.isEmpty();
         }
+        if (listeners.isEmpty()) {
+            removeScheduledTask(); // Last listener removed: stop scheduled queries.
+        }
+        return listeners.isEmpty();
     }
 
     /**
@@ -340,51 +375,49 @@
-    public synchronized void processResponse(@NonNull MdnsPacket packet,
+    public void processResponse(@NonNull MdnsPacket packet,
             @NonNull SocketKey socketKey) {
+        // Expected to run only on the handler thread (checked below), so no lock is taken.
         ensureRunningOnHandlerThread(handler);
-        synchronized (lock) {
-            // Augment the list of current known responses, and generated responses for resolve
-            // requests if there is no known response
-            final List<MdnsResponse> currentList = new ArrayList<>(instanceNameToResponse.values());
-            List<MdnsResponse> additionalResponses = makeResponsesForResolve(socketKey);
-            for (MdnsResponse additionalResponse : additionalResponses) {
-                if (!instanceNameToResponse.containsKey(
-                        additionalResponse.getServiceInstanceName())) {
-                    currentList.add(additionalResponse);
-                }
+        // Augment the currently known responses with responses generated for resolve requests
+        // that have no known response yet.
+        final List<MdnsResponse> currentList = new ArrayList<>(instanceNameToResponse.values());
+        List<MdnsResponse> additionalResponses = makeResponsesForResolve(socketKey);
+        for (MdnsResponse additionalResponse : additionalResponses) {
+            if (!instanceNameToResponse.containsKey(
+                    additionalResponse.getServiceInstanceName())) {
+                currentList.add(additionalResponse);
             }
-            final Pair<ArraySet<MdnsResponse>, ArrayList<MdnsResponse>> augmentedResult =
-                    responseDecoder.augmentResponses(packet, currentList,
-                            socketKey.getInterfaceIndex(), socketKey.getNetwork());
+        }
+        final Pair<ArraySet<MdnsResponse>, ArrayList<MdnsResponse>> augmentedResult =
+                responseDecoder.augmentResponses(packet, currentList,
+                        socketKey.getInterfaceIndex(), socketKey.getNetwork());
 
-            final ArraySet<MdnsResponse> modifiedResponse = augmentedResult.first;
-            final ArrayList<MdnsResponse> allResponses = augmentedResult.second;
+        final ArraySet<MdnsResponse> modifiedResponse = augmentedResult.first;
+        final ArrayList<MdnsResponse> allResponses = augmentedResult.second;
 
-            for (MdnsResponse response : allResponses) {
-                if (modifiedResponse.contains(response)) {
-                    if (response.isGoodbye()) {
-                        onGoodbyeReceivedLocked(response.getServiceInstanceName());
-                    } else {
-                        onResponseModifiedLocked(response);
-                    }
-                } else if (instanceNameToResponse.containsKey(response.getServiceInstanceName())) {
-                    // If the response is not modified and already in the cache. The cache will
-                    // need to be updated to refresh the last receipt time.
-                    instanceNameToResponse.put(response.getServiceInstanceName(), response);
+        for (MdnsResponse response : allResponses) {
+            if (modifiedResponse.contains(response)) {
+                if (response.isGoodbye()) {
+                    onGoodbyeReceived(response.getServiceInstanceName());
+                } else {
+                    onResponseModified(response);
                 }
+            } else if (instanceNameToResponse.containsKey(response.getServiceInstanceName())) {
+                // The response is unmodified but already cached; store it again to refresh
+                // its last receipt time.
+                instanceNameToResponse.put(response.getServiceInstanceName(), response);
             }
-            if (dependencies.hasMessages(handler, EVENT_START_QUERYTASK)
-                    && lastScheduledTask != null
-                    && lastScheduledTask.config.shouldUseQueryBackoff()) {
-                final long now = clock.elapsedRealtime();
-                final long minRemainingTtl = getMinRemainingTtlLocked(now);
-                final long timeToRun = calculateTimeToRun(lastScheduledTask,
-                        lastScheduledTask.config, now,
-                        minRemainingTtl, lastSentTime);
-                if (timeToRun > lastScheduledTask.timeToRun) {
-                    QueryTaskConfig lastTaskConfig = lastScheduledTask.config;
-                    removeScheduledTaskLock();
-                    scheduleNextRunLocked(
-                            lastTaskConfig, minRemainingTtl, now, timeToRun, currentSessionId);
-                }
+        }
+        if (dependencies.hasMessages(handler, EVENT_START_QUERYTASK)
+                && lastScheduledQueryTaskArgs != null
+                && lastScheduledQueryTaskArgs.config.shouldUseQueryBackoff()) {
+            final long now = clock.elapsedRealtime();
+            final long minRemainingTtl = getMinRemainingTtl(now);
+            final long timeToRun = calculateTimeToRun(lastScheduledQueryTaskArgs,
+                    lastScheduledQueryTaskArgs.config, now,
+                    minRemainingTtl, lastSentTime);
+            if (timeToRun > lastScheduledQueryTaskArgs.timeToRun) {
+                QueryTaskConfig lastTaskConfig = lastScheduledQueryTaskArgs.config;
+                removeScheduledTask();
+                scheduleNextRun(lastTaskConfig, minRemainingTtl, now, timeToRun, currentSessionId);
             }
         }
     }
@@ -399,29 +431,26 @@
     /** Notify all services are removed because the socket is destroyed. */
     public void notifySocketDestroyed() {
         ensureRunningOnHandlerThread(handler);
-        synchronized (lock) {
-            for (MdnsResponse response : instanceNameToResponse.values()) {
-                final String name = response.getServiceInstanceName();
-                if (name == null) continue;
-                for (int i = 0; i < listeners.size(); i++) {
-                    if (!responseMatchesOptions(response, listeners.valueAt(i))) continue;
-                    final MdnsServiceBrowserListener listener = listeners.keyAt(i);
-                    final MdnsServiceInfo serviceInfo =
-                            buildMdnsServiceInfoFromResponse(response, serviceTypeLabels);
-                    if (response.isComplete()) {
-                        sharedLog.log("Socket destroyed. onServiceRemoved: " + name);
-                        listener.onServiceRemoved(serviceInfo);
-                    }
-                    sharedLog.log("Socket destroyed. onServiceNameRemoved: " + name);
-                    listener.onServiceNameRemoved(serviceInfo);
+        for (MdnsResponse response : instanceNameToResponse.values()) {
+            final String name = response.getServiceInstanceName();
+            if (name == null) continue;
+            for (int i = 0; i < listeners.size(); i++) {
+                if (!responseMatchesOptions(response, listeners.valueAt(i))) continue;
+                final MdnsServiceBrowserListener listener = listeners.keyAt(i);
+                final MdnsServiceInfo serviceInfo =
+                        buildMdnsServiceInfoFromResponse(response, serviceTypeLabels);
+                if (response.isComplete()) {
+                    sharedLog.log("Socket destroyed. onServiceRemoved: " + name);
+                    listener.onServiceRemoved(serviceInfo);
                 }
+                sharedLog.log("Socket destroyed. onServiceNameRemoved: " + name);
+                listener.onServiceNameRemoved(serviceInfo);
             }
-            removeScheduledTaskLock();
         }
+        removeScheduledTask(); // Also cancel any pending EVENT_START_QUERYTASK.
     }
 
-    @GuardedBy("lock")
-    private void onResponseModifiedLocked(@NonNull MdnsResponse response) {
+    private void onResponseModified(@NonNull MdnsResponse response) {
         final String serviceInstanceName = response.getServiceInstanceName();
         final MdnsResponse currentResponse =
                 instanceNameToResponse.get(serviceInstanceName);
@@ -467,8 +496,7 @@
         }
     }
 
-    @GuardedBy("lock")
-    private void onGoodbyeReceivedLocked(@Nullable String serviceInstanceName) {
+    private void onGoodbyeReceived(@Nullable String serviceInstanceName) {
         final MdnsResponse response = instanceNameToResponse.remove(serviceInstanceName);
         if (response == null) {
             return;
@@ -660,34 +688,80 @@
         return resolveResponses;
     }
 
-    // A FutureTask that enqueues a single query, and schedule a new FutureTask for the next task.
-    private class QueryTask implements Runnable {
+    private void tryRemoveServiceAfterTtlExpires() {
+        if (!shouldRemoveServiceAfterTtlExpires()) return;
 
+        Iterator<MdnsResponse> iter = instanceNameToResponse.values().iterator();
+        while (iter.hasNext()) {
+            MdnsResponse existingResponse = iter.next();
+            if (existingResponse.hasServiceRecord()
+                    && existingResponse.getServiceRecord()
+                    .getRemainingTTL(clock.elapsedRealtime()) == 0) {
+                iter.remove(); // TTL expired: drop from the cache, then notify matching listeners.
+                for (int i = 0; i < listeners.size(); i++) {
+                    if (!responseMatchesOptions(existingResponse, listeners.valueAt(i))) {
+                        continue;
+                    }
+                    final MdnsServiceBrowserListener listener = listeners.keyAt(i);
+                    if (existingResponse.getServiceInstanceName() != null) {
+                        final MdnsServiceInfo serviceInfo = buildMdnsServiceInfoFromResponse(
+                                existingResponse, serviceTypeLabels);
+                        if (existingResponse.isComplete()) {
+                            sharedLog.log("TTL expired. onServiceRemoved: " + serviceInfo);
+                            listener.onServiceRemoved(serviceInfo);
+                        }
+                        sharedLog.log("TTL expired. onServiceNameRemoved: " + serviceInfo);
+                        listener.onServiceNameRemoved(serviceInfo);
+                    }
+                }
+            }
+        }
+    }
+
+    private static class ScheduledQueryTaskArgs { // Args of one scheduled query run.
         private final QueryTaskConfig config;
         private final long timeToRun;
         private final long minTtlExpirationTimeWhenScheduled;
         private final long sessionId;
 
-        QueryTask(@NonNull QueryTaskConfig config, long timeToRun,
-                long minTtlExpirationTimeWhenScheduled,
-                long sessionId) {
+        ScheduledQueryTaskArgs(@NonNull QueryTaskConfig config, long timeToRun,
+                long minTtlExpirationTimeWhenScheduled, long sessionId) {
             this.config = config;
             this.timeToRun = timeToRun;
             this.minTtlExpirationTimeWhenScheduled = minTtlExpirationTimeWhenScheduled;
             this.sessionId = sessionId;
         }
+    }
+
+    /** Outcome of one QueryTask run, posted back to the handler as EVENT_QUERY_RESULT. */
+    private static class QuerySentResult {
+        private final ScheduledQueryTaskArgs taskArgs;
+        private final int transactionId;
+        private final List<String> subTypes;
+        QuerySentResult(int transactionId, @NonNull List<String> subTypes,
+                @NonNull ScheduledQueryTaskArgs taskArgs) {
+            this.taskArgs = taskArgs;
+            this.transactionId = transactionId;
+            this.subTypes = new ArrayList<>(subTypes);
+        }
+    }
+
+    // Runnable that enqueues a single query and posts the result back as EVENT_QUERY_RESULT.
+    private class QueryTask implements Runnable {
+
+        private final ScheduledQueryTaskArgs taskArgs;
+        private final List<MdnsResponse> servicesToResolve = new ArrayList<>();
+        private final boolean sendDiscoveryQueries;
+
+        QueryTask(@NonNull ScheduledQueryTaskArgs taskArgs,
+                @NonNull List<MdnsResponse> servicesToResolve, boolean sendDiscoveryQueries) {
+            this.taskArgs = taskArgs;
+            this.servicesToResolve.addAll(servicesToResolve);
+            this.sendDiscoveryQueries = sendDiscoveryQueries;
+        }
 
         @Override
         public void run() {
-            final List<MdnsResponse> servicesToResolve;
-            final boolean sendDiscoveryQueries;
-            synchronized (lock) {
-                // The listener is requesting to resolve a service that has no info in
-                // cache. Use the provided name to generate a minimal response, so other records are
-                // queried to complete it.
-                servicesToResolve = makeResponsesForResolve(config.socketKey);
-                sendDiscoveryQueries = servicesToResolve.size() < listeners.size();
-            }
             Pair<Integer, List<String>> result;
             try {
                 result =
@@ -695,80 +769,27 @@
                                 socketClient,
                                 createMdnsPacketWriter(),
                                 serviceType,
-                                config.subtypes,
-                                config.expectUnicastResponse,
-                                config.transactionId,
-                                config.socketKey,
-                                config.onlyUseIpv6OnIpv6OnlyNetworks,
+                                taskArgs.config.subtypes,
+                                taskArgs.config.expectUnicastResponse,
+                                taskArgs.config.transactionId,
+                                taskArgs.config.socketKey,
+                                taskArgs.config.onlyUseIpv6OnIpv6OnlyNetworks,
                                 sendDiscoveryQueries,
                                 servicesToResolve,
                                 clock)
                                 .call();
             } catch (RuntimeException e) {
                 sharedLog.e(String.format("Failed to run EnqueueMdnsQueryCallable for subtype: %s",
-                        TextUtils.join(",", config.subtypes)), e);
-                result = null;
+                        TextUtils.join(",", taskArgs.config.subtypes)), e);
+                result = Pair.create(-1, new ArrayList<>());
             }
-            synchronized (lock) {
-                if (MdnsConfigs.useSessionIdToScheduleMdnsTask()) {
-                    // In case that the task is not canceled successfully, use session ID to check
-                    // if this task should continue to schedule more.
-                    if (sessionId != currentSessionId) {
-                        return;
-                    }
-                }
-
-                if ((result != null)) {
-                    for (int i = 0; i < listeners.size(); i++) {
-                        listeners.keyAt(i).onDiscoveryQuerySent(result.second, result.first);
-                    }
-                }
-                if (shouldRemoveServiceAfterTtlExpires()) {
-                    Iterator<MdnsResponse> iter = instanceNameToResponse.values().iterator();
-                    while (iter.hasNext()) {
-                        MdnsResponse existingResponse = iter.next();
-                        if (existingResponse.hasServiceRecord()
-                                && existingResponse
-                                .getServiceRecord()
-                                .getRemainingTTL(clock.elapsedRealtime())
-                                == 0) {
-                            iter.remove();
-                            for (int i = 0; i < listeners.size(); i++) {
-                                if (!responseMatchesOptions(existingResponse,
-                                        listeners.valueAt(i)))  {
-                                    continue;
-                                }
-                                final MdnsServiceBrowserListener listener = listeners.keyAt(i);
-                                if (existingResponse.getServiceInstanceName() != null) {
-                                    final MdnsServiceInfo serviceInfo =
-                                            buildMdnsServiceInfoFromResponse(
-                                                    existingResponse, serviceTypeLabels);
-                                    if (existingResponse.isComplete()) {
-                                        sharedLog.log("TTL expired. onServiceRemoved: "
-                                                + serviceInfo);
-                                        listener.onServiceRemoved(serviceInfo);
-                                    }
-                                    sharedLog.log("TTL expired. onServiceNameRemoved: "
-                                            + serviceInfo);
-                                    listener.onServiceNameRemoved(serviceInfo);
-                                }
-                            }
-                        }
-                    }
-                }
-                QueryTaskConfig nextRunConfig = this.config.getConfigForNextRun();
-                final long now = clock.elapsedRealtime();
-                lastSentTime = now;
-                final long minRemainingTtl = getMinRemainingTtlLocked(now);
-                final long timeToRun = calculateTimeToRun(this, nextRunConfig, now,
-                        minRemainingTtl, lastSentTime);
-                scheduleNextRunLocked(nextRunConfig, minRemainingTtl, now, timeToRun,
-                        lastScheduledTask.sessionId);
-            }
+            dependencies.sendMessage(
+                    handler, handler.obtainMessage(EVENT_QUERY_RESULT,
+                            new QuerySentResult(result.first, result.second, taskArgs)));
         }
     }
 
-    private static long calculateTimeToRun(@NonNull QueryTask lastScheduledTask,
+    private static long calculateTimeToRun(@NonNull ScheduledQueryTaskArgs taskArgs,
             QueryTaskConfig queryTaskConfig, long now, long minRemainingTtl, long lastSentTime) {
         final long baseDelayInMs = queryTaskConfig.delayUntilNextTaskWithoutBackoffMs;
         if (!queryTaskConfig.shouldUseQueryBackoff()) {
@@ -781,16 +802,15 @@
         }
         // If the next TTL expiration time hasn't changed, then use previous calculated timeToRun.
         if (lastSentTime < now
-                && lastScheduledTask.minTtlExpirationTimeWhenScheduled == now + minRemainingTtl) {
+                && taskArgs.minTtlExpirationTimeWhenScheduled == now + minRemainingTtl) {
             // Use the original scheduling time if the TTL has not changed, to avoid continuously
             // rescheduling to 80% of the remaining TTL as time passes
-            return lastScheduledTask.timeToRun;
+            return taskArgs.timeToRun;
         }
         return Math.max(now + (long) (0.8 * minRemainingTtl), lastSentTime + baseDelayInMs);
     }
 
-    @GuardedBy("lock")
-    private long getMinRemainingTtlLocked(long now) {
+    private long getMinRemainingTtl(long now) {
         long minRemainingTtl = Long.MAX_VALUE;
         for (MdnsResponse response : instanceNameToResponse.values()) {
             if (!response.isComplete()) {
@@ -811,19 +831,18 @@
 
     @GuardedBy("lock")
     @NonNull
-    private void scheduleNextRunLocked(@NonNull QueryTaskConfig nextRunConfig,
+    private void scheduleNextRun(@NonNull QueryTaskConfig nextRunConfig,
             long minRemainingTtl,
             long timeWhenScheduled, long timeToRun, long sessionId) {
-        lastScheduledTask = new QueryTask(nextRunConfig, timeToRun,
+        lastScheduledQueryTaskArgs = new ScheduledQueryTaskArgs(nextRunConfig, timeToRun,
                 minRemainingTtl + timeWhenScheduled, sessionId);
         // The timeWhenScheduled could be greater than the timeToRun if the Runnable is delayed.
         long timeToNextTasksWithBackoffInMs = Math.max(timeToRun - timeWhenScheduled, 0);
-        sharedLog.log(
-                String.format("Next run: sessionId: %d, in %d ms", lastScheduledTask.sessionId,
-                        timeToNextTasksWithBackoffInMs));
+        sharedLog.log(String.format("Next run: sessionId: %d, in %d ms",
+                lastScheduledQueryTaskArgs.sessionId, timeToNextTasksWithBackoffInMs));
         dependencies.sendMessageDelayed(
                 handler,
-                handler.obtainMessage(EVENT_START_QUERYTASK, lastScheduledTask),
+                handler.obtainMessage(EVENT_START_QUERYTASK, lastScheduledQueryTaskArgs),
                 timeToNextTasksWithBackoffInMs);
     }
 }
\ No newline at end of file
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index ad5583b..4328053 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -204,6 +204,13 @@
             return true;
         }).when(mockDeps).sendMessageDelayed(any(Handler.class), any(Message.class), anyLong());
 
+        doAnswer(inv -> {
+            final Handler handler = (Handler) inv.getArguments()[0];
+            final Message message = (Message) inv.getArguments()[1];
+            runOnHandler(() -> handler.dispatchMessage(message));
+            return true;
+        }).when(mockDeps).sendMessage(any(Handler.class), any(Message.class));
+
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
                         mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps) {
@@ -925,6 +932,7 @@
         // Simulate the case where the response is under TTL.
         doReturn(TEST_ELAPSED_REALTIME + TEST_TTL - 1L).when(mockDecoderClock).elapsedRealtime();
         firstMdnsTask.run();
+        verify(mockDeps, times(1)).sendMessage(any(), any(Message.class));
 
         // Verify removed callback was not called.
         verifyServiceRemovedNoCallback(mockListenerOne);
@@ -932,6 +940,7 @@
         // Simulate the case where the response is after TTL.
         doReturn(TEST_ELAPSED_REALTIME + TEST_TTL + 1L).when(mockDecoderClock).elapsedRealtime();
         firstMdnsTask.run();
+        verify(mockDeps, times(2)).sendMessage(any(), any(Message.class));
 
         // Verify removed callback was called.
         verifyServiceRemovedCallback(
@@ -1118,6 +1127,7 @@
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
                 srvTxtQueryCaptor.capture(),
                 eq(socketKey), eq(false));
+        verify(mockDeps, times(1)).sendMessage(any(), any(Message.class));
         assertNotNull(delayMessage);
 
         final MdnsPacket srvTxtQueryPacket = MdnsPacket.parse(
@@ -1210,6 +1220,7 @@
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
                 srvTxtQueryCaptor.capture(),
                 eq(socketKey), eq(false));
+        verify(mockDeps, times(1)).sendMessage(any(), any(Message.class));
         assertNotNull(delayMessage);
 
         final MdnsPacket srvTxtQueryPacket = MdnsPacket.parse(
@@ -1249,6 +1260,7 @@
         // Advance time so 75% of TTL passes and re-execute
         doReturn(TEST_ELAPSED_REALTIME + (long) (TEST_TTL * 0.75))
                 .when(mockDecoderClock).elapsedRealtime();
+        verify(mockDeps, times(2)).sendMessage(any(), any(Message.class));
         assertNotNull(delayMessage);
         dispatchMessage();
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
@@ -1260,12 +1272,13 @@
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
                 renewalQueryCaptor.capture(),
                 eq(socketKey), eq(false));
+        verify(mockDeps, times(3)).sendMessage(any(), any(Message.class));
+        assertNotNull(delayMessage);
         inOrder.verify(mockListenerOne).onDiscoveryQuerySent(any(), anyInt());
         final MdnsPacket renewalPacket = MdnsPacket.parse(
                 new MdnsPacketReader(renewalQueryCaptor.getValue()));
         assertTrue(hasQuestion(renewalPacket, MdnsRecord.TYPE_ANY, serviceName));
         inOrder.verifyNoMoreInteractions();
-        assertNotNull(delayMessage);
 
         long updatedReceiptTime =  TEST_ELAPSED_REALTIME + TEST_TTL;
         final MdnsPacket refreshedSrvTxtResponse = new MdnsPacket(
@@ -1545,6 +1558,8 @@
                         expectedIPv6Packets[index], socketKey, false);
             }
         }
+        verify(mockDeps, times(index + 1))
+                .sendMessage(any(Handler.class), any(Message.class));
         // Verify the task has been scheduled.
         verify(mockDeps, times(scheduledCount))
                 .sendMessageDelayed(any(Handler.class), any(Message.class), anyLong());