Put known answers to query packets

To support known-answer suppression, the querier should include
its known answers within each query packet. This allows the
service to determine whether a response is necessary based on
the querier's existing knowledge.

Bug: 312657709
Test: atest FrameworksNetTests NsdManagerTest
Change-Id: I2fb421aff1ae150b310e1a44f50f098a3edb9992
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index cfb1a33..aca386f 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -1851,6 +1851,8 @@
                         mContext, MdnsFeatureFlags.NSD_UNICAST_REPLY_ENABLED))
                 .setIsAggressiveQueryModeEnabled(mDeps.isFeatureEnabled(
                         mContext, MdnsFeatureFlags.NSD_AGGRESSIVE_QUERY_MODE))
+                .setIsQueryWithKnownAnswerEnabled(mDeps.isFeatureEnabled(
+                        mContext, MdnsFeatureFlags.NSD_QUERY_WITH_KNOWN_ANSWER))
                 .setOverrideProvider(flag -> mDeps.isFeatureEnabled(
                         mContext, FORCE_ENABLE_FLAG_FOR_TEST_PREFIX + flag))
                 .build();
diff --git a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
index 5537796..e61555a 100644
--- a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
+++ b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
@@ -23,6 +23,7 @@
 import android.text.TextUtils;
 import android.util.Pair;
 
+import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.SharedLog;
 import com.android.server.connectivity.mdns.util.MdnsUtils;
 
@@ -81,6 +82,8 @@
     private final MdnsServiceTypeClient.Dependencies dependencies;
     private final boolean onlyUseIpv6OnIpv6OnlyNetworks;
     private final byte[] packetCreationBuffer = new byte[1500]; // TODO: use interface MTU
+    @NonNull
+    private final List<MdnsResponse> existingServices;
 
     EnqueueMdnsQueryCallable(
             @NonNull MdnsSocketClientBase requestSender,
@@ -94,7 +97,8 @@
             @NonNull Collection<MdnsResponse> servicesToResolve,
             @NonNull MdnsUtils.Clock clock,
             @NonNull SharedLog sharedLog,
-            @NonNull MdnsServiceTypeClient.Dependencies dependencies) {
+            @NonNull MdnsServiceTypeClient.Dependencies dependencies,
+            @NonNull Collection<MdnsResponse> existingServices) {
         weakRequestSender = new WeakReference<>(requestSender);
         serviceTypeLabels = TextUtils.split(serviceType, "\\.");
         this.subtypes = new ArrayList<>(subtypes);
@@ -107,6 +111,7 @@
         this.clock = clock;
         this.sharedLog = sharedLog;
         this.dependencies = dependencies;
+        this.existingServices = new ArrayList<>(existingServices);
     }
 
     /**
@@ -177,11 +182,34 @@
                 return Pair.create(INVALID_TRANSACTION_ID, new ArrayList<>());
             }
 
+            // Put the existing ptr records into known-answer section.
+            final List<MdnsRecord> knownAnswers = new ArrayList<>();
+            if (sendDiscoveryQueries) {
+                for (MdnsResponse existingService : existingServices) {
+                    for (MdnsPointerRecord ptrRecord : existingService.getPointerRecords()) {
+                        // Ignore any PTR records that don't match the current query.
+                        if (!CollectionUtils.any(questions,
+                                q -> q instanceof MdnsPointerRecord
+                                        && MdnsUtils.equalsDnsLabelIgnoreDnsCase(
+                                                q.getName(), ptrRecord.getName()))) {
+                            continue;
+                        }
+
+                        knownAnswers.add(new MdnsPointerRecord(
+                                ptrRecord.getName(),
+                                ptrRecord.getReceiptTime(),
+                                ptrRecord.getCacheFlush(),
+                                ptrRecord.getRemainingTTL(now), // Put the remaining ttl.
+                                ptrRecord.getPointer()));
+                    }
+                }
+            }
+
             final MdnsPacket queryPacket = new MdnsPacket(
                     transactionId,
                     MdnsConstants.FLAGS_QUERY,
                     questions,
-                    Collections.emptyList(), /* answers */
+                    knownAnswers,
                     Collections.emptyList(), /* authorityRecords */
                     Collections.emptyList() /* additionalRecords */);
             sendPacketToIpv4AndIpv6(requestSender, MdnsConstants.MDNS_PORT, queryPacket);
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index 21b7069..7b0c738 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -362,7 +362,7 @@
         return new MdnsServiceTypeClient(
                 serviceType, socketClient,
                 executorProvider.newServiceTypeClientSchedulerExecutor(), socketKey,
-                sharedLog.forSubComponent(tag), looper, serviceCache);
+                sharedLog.forSubComponent(tag), looper, serviceCache, mdnsFeatureFlags);
     }
 
     /**
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
index 56202fd..f4a08ba 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
@@ -62,6 +62,11 @@
      */
     public static final String NSD_AGGRESSIVE_QUERY_MODE = "nsd_aggressive_query_mode";
 
+    /**
+     * A feature flag to control whether the query with known-answer should be enabled.
+     */
+    public static final String NSD_QUERY_WITH_KNOWN_ANSWER = "nsd_query_with_known_answer";
+
     // Flag for offload feature
     public final boolean mIsMdnsOffloadFeatureEnabled;
 
@@ -83,6 +88,9 @@
     // Flag for aggressive query mode
     public final boolean mIsAggressiveQueryModeEnabled;
 
+    // Flag for query with known-answer
+    public final boolean mIsQueryWithKnownAnswerEnabled;
+
     @Nullable
     private final FlagOverrideProvider mOverrideProvider;
 
@@ -126,6 +134,14 @@
     }
 
     /**
+     * Indicates whether {@link #NSD_QUERY_WITH_KNOWN_ANSWER} is enabled, including for testing.
+     */
+    public boolean isQueryWithKnownAnswerEnabled() {
+        return mIsQueryWithKnownAnswerEnabled
+                || isForceEnabledForTest(NSD_QUERY_WITH_KNOWN_ANSWER);
+    }
+
+    /**
      * The constructor for {@link MdnsFeatureFlags}.
      */
     public MdnsFeatureFlags(boolean isOffloadFeatureEnabled,
@@ -135,6 +151,7 @@
             boolean isKnownAnswerSuppressionEnabled,
             boolean isUnicastReplyEnabled,
             boolean isAggressiveQueryModeEnabled,
+            boolean isQueryWithKnownAnswerEnabled,
             @Nullable FlagOverrideProvider overrideProvider) {
         mIsMdnsOffloadFeatureEnabled = isOffloadFeatureEnabled;
         mIncludeInetAddressRecordsInProbing = includeInetAddressRecordsInProbing;
@@ -143,6 +160,7 @@
         mIsKnownAnswerSuppressionEnabled = isKnownAnswerSuppressionEnabled;
         mIsUnicastReplyEnabled = isUnicastReplyEnabled;
         mIsAggressiveQueryModeEnabled = isAggressiveQueryModeEnabled;
+        mIsQueryWithKnownAnswerEnabled = isQueryWithKnownAnswerEnabled;
         mOverrideProvider = overrideProvider;
     }
 
@@ -162,6 +180,7 @@
         private boolean mIsKnownAnswerSuppressionEnabled;
         private boolean mIsUnicastReplyEnabled;
         private boolean mIsAggressiveQueryModeEnabled;
+        private boolean mIsQueryWithKnownAnswerEnabled;
         private FlagOverrideProvider mOverrideProvider;
 
         /**
@@ -175,6 +194,7 @@
             mIsKnownAnswerSuppressionEnabled = false;
             mIsUnicastReplyEnabled = true;
             mIsAggressiveQueryModeEnabled = false;
+            mIsQueryWithKnownAnswerEnabled = false;
             mOverrideProvider = null;
         }
 
@@ -261,6 +281,16 @@
         }
 
         /**
+         * Set whether the query with known-answer is enabled.
+         *
+         * @see #NSD_QUERY_WITH_KNOWN_ANSWER
+         */
+        public Builder setIsQueryWithKnownAnswerEnabled(boolean isQueryWithKnownAnswerEnabled) {
+            mIsQueryWithKnownAnswerEnabled = isQueryWithKnownAnswerEnabled;
+            return this;
+        }
+
+        /**
          * Builds a {@link MdnsFeatureFlags} with the arguments supplied to this builder.
          */
         public MdnsFeatureFlags build() {
@@ -271,6 +301,7 @@
                     mIsKnownAnswerSuppressionEnabled,
                     mIsUnicastReplyEnabled,
                     mIsAggressiveQueryModeEnabled,
+                    mIsQueryWithKnownAnswerEnabled,
                     mOverrideProvider);
         }
     }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index ba6cdd5..bfcd0b4 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -86,6 +86,7 @@
                     notifyRemovedServiceToListeners(previousResponse, "Service record expired");
                 }
             };
+    @NonNull private final MdnsFeatureFlags featureFlags;
     private final ArrayMap<MdnsServiceBrowserListener, ListenerInfo> listeners =
             new ArrayMap<>();
     private final boolean removeServiceAfterTtlExpires =
@@ -144,7 +145,8 @@
                     // before sending the query, it needs to be called just before sending it.
                     final List<MdnsResponse> servicesToResolve = makeResponsesForResolve(socketKey);
                     final QueryTask queryTask = new QueryTask(taskArgs, servicesToResolve,
-                            getAllDiscoverySubtypes(), needSendDiscoveryQueries(listeners));
+                            getAllDiscoverySubtypes(), needSendDiscoveryQueries(listeners),
+                            getExistingServices());
                     executor.submit(queryTask);
                     break;
                 }
@@ -248,9 +250,10 @@
             @NonNull SocketKey socketKey,
             @NonNull SharedLog sharedLog,
             @NonNull Looper looper,
-            @NonNull MdnsServiceCache serviceCache) {
+            @NonNull MdnsServiceCache serviceCache,
+            @NonNull MdnsFeatureFlags featureFlags) {
         this(serviceType, socketClient, executor, new Clock(), socketKey, sharedLog, looper,
-                new Dependencies(), serviceCache);
+                new Dependencies(), serviceCache, featureFlags);
     }
 
     @VisibleForTesting
@@ -263,7 +266,8 @@
             @NonNull SharedLog sharedLog,
             @NonNull Looper looper,
             @NonNull Dependencies dependencies,
-            @NonNull MdnsServiceCache serviceCache) {
+            @NonNull MdnsServiceCache serviceCache,
+            @NonNull MdnsFeatureFlags featureFlags) {
         this.serviceType = serviceType;
         this.socketClient = socketClient;
         this.executor = executor;
@@ -277,6 +281,7 @@
         this.serviceCache = serviceCache;
         this.mdnsQueryScheduler = new MdnsQueryScheduler();
         this.cacheKey = new MdnsServiceCache.CacheKey(serviceType, socketKey);
+        this.featureFlags = featureFlags;
     }
 
     /**
@@ -339,6 +344,11 @@
                 now.plusMillis(response.getMinRemainingTtl(now.toEpochMilli())));
     }
 
+    private List<MdnsResponse> getExistingServices() {
+        return featureFlags.isQueryWithKnownAnswerEnabled()
+                ? serviceCache.getCachedServices(cacheKey) : Collections.emptyList();
+    }
+
     /**
      * Registers {@code listener} for receiving discovery event of mDNS service instances, and
      * starts
@@ -403,7 +413,8 @@
             final QueryTask queryTask = new QueryTask(
                     mdnsQueryScheduler.scheduleFirstRun(taskConfig, now,
                             minRemainingTtl, currentSessionId), servicesToResolve,
-                    getAllDiscoverySubtypes(), needSendDiscoveryQueries(listeners));
+                    getAllDiscoverySubtypes(), needSendDiscoveryQueries(listeners),
+                    getExistingServices());
             executor.submit(queryTask);
         }
 
@@ -701,14 +712,16 @@
         private final List<MdnsResponse> servicesToResolve = new ArrayList<>();
         private final List<String> subtypes = new ArrayList<>();
         private final boolean sendDiscoveryQueries;
+        private final List<MdnsResponse> existingServices = new ArrayList<>();
         QueryTask(@NonNull MdnsQueryScheduler.ScheduledQueryTaskArgs taskArgs,
                 @NonNull Collection<MdnsResponse> servicesToResolve,
-                @NonNull Collection<String> subtypes,
-                boolean sendDiscoveryQueries) {
+                @NonNull Collection<String> subtypes, boolean sendDiscoveryQueries,
+                @NonNull Collection<MdnsResponse> existingServices) {
             this.taskArgs = taskArgs;
             this.servicesToResolve.addAll(servicesToResolve);
             this.subtypes.addAll(subtypes);
             this.sendDiscoveryQueries = sendDiscoveryQueries;
+            this.existingServices.addAll(existingServices);
         }
 
         @Override
@@ -728,7 +741,8 @@
                                 servicesToResolve,
                                 clock,
                                 sharedLog,
-                                dependencies)
+                                dependencies,
+                                existingServices)
                                 .call();
             } catch (RuntimeException e) {
                 sharedLog.e(String.format("Failed to run EnqueueMdnsQueryCallable for subtype: %s",
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index 8dbcf2f..3f72395 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -1888,6 +1888,64 @@
         }
     }
 
+    @Test
+    fun testQueryWhenKnownAnswerSuppressionFlagSet() {
+        // The flag may be removed in the future but known-answer suppression should be enabled by
+        // default in that case. The rule will reset flags automatically on teardown.
+        deviceConfigRule.setConfig(NAMESPACE_TETHERING, "test_nsd_query_with_known_answer", "1")
+
+        // Register service on testNetwork1
+        val discoveryRecord = NsdDiscoveryRecord()
+        val packetReader = TapPacketReader(Handler(handlerThread.looper),
+                testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+        packetReader.startAsyncForTest()
+        handlerThread.waitForIdle(TIMEOUT_MS)
+
+        nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
+                testNetwork1.network, { it.run() }, discoveryRecord)
+
+        tryTest {
+            discoveryRecord.expectCallback<DiscoveryStarted>()
+            assertNotNull(packetReader.pollForQuery("$serviceType.local", DnsResolver.TYPE_PTR))
+            /*
+            Generated with:
+            scapy.raw(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an =
+                scapy.DNSRR(rrname='_nmt123456789._tcp.local', type='PTR', ttl=120,
+                rdata='NsdTest123456789._nmt123456789._tcp.local'))).hex()
+             */
+            val ptrResponsePayload = HexDump.hexStringToByteArray("0000840000000001000000000d5f6e" +
+                    "6d74313233343536373839045f746370056c6f63616c00000c000100000078002b104e736454" +
+                    "6573743132333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00")
+
+            replaceServiceNameAndTypeWithTestSuffix(ptrResponsePayload)
+            packetReader.sendResponse(buildMdnsPacket(ptrResponsePayload))
+
+            val serviceFound = discoveryRecord.expectCallback<ServiceFound>()
+            serviceFound.serviceInfo.let {
+                assertEquals(serviceName, it.serviceName)
+                // Discovered service types have a dot at the end
+                assertEquals("$serviceType.", it.serviceType)
+                assertEquals(testNetwork1.network, it.network)
+                // ServiceFound does not provide port, address or attributes (only information
+                // available in the PTR record is included in that callback, regardless of whether
+                // other records exist).
+                assertEquals(0, it.port)
+                assertEmpty(it.hostAddresses)
+                assertEquals(0, it.attributes.size)
+            }
+
+            // Expect the second query with a known answer
+            val query = packetReader.pollForMdnsPacket { pkt ->
+                pkt.isQueryFor("$serviceType.local", DnsResolver.TYPE_PTR) &&
+                        pkt.isReplyFor("$serviceType.local", DnsResolver.TYPE_PTR)
+            }
+            assertNotNull(query)
+        } cleanup {
+            nsdManager.stopServiceDiscovery(discoveryRecord)
+            discoveryRecord.expectCallback<DiscoveryStopped>()
+        }
+    }
+
     private fun makeLinkLocalAddressOfOtherDeviceOnPrefix(network: Network): Inet6Address {
         val lp = cm.getLinkProperties(network) ?: fail("No LinkProperties for net $network")
         // Expect to have a /64 link-local address
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index b1df7f8..2eb9440 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -144,6 +144,7 @@
     private long latestDelayMs = 0;
     private Message delayMessage = null;
     private Handler realHandler = null;
+    private MdnsFeatureFlags featureFlags = MdnsFeatureFlags.newBuilder().build();
 
     @Before
     @SuppressWarnings("DoNotMock")
@@ -249,7 +250,7 @@
     private MdnsServiceTypeClient makeMdnsServiceTypeClient() {
         return new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
                 mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps,
-                serviceCache);
+                serviceCache, featureFlags);
     }
 
     @After
@@ -1929,6 +1930,138 @@
                 16 /* scheduledCount */);
     }
 
+    @Test
+    public void testSendQueryWithKnownAnswers() throws Exception {
+        client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
+                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps,
+                serviceCache,
+                MdnsFeatureFlags.newBuilder().setIsQueryWithKnownAnswerEnabled(true).build());
+
+        doCallRealMethod().when(mockDeps).getDatagramPacketFromMdnsPacket(
+                any(), any(MdnsPacket.class), any(InetSocketAddress.class));
+
+        startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
+        InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
+
+        final ArgumentCaptor<DatagramPacket> queryCaptor =
+                ArgumentCaptor.forClass(DatagramPacket.class);
+        currentThreadExecutor.getAndClearLastScheduledRunnable().run();
+        // Send twice for IPv4 and IPv6
+        inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
+                queryCaptor.capture(), eq(socketKey), eq(false));
+        verify(mockDeps, times(1)).sendMessage(any(), any(Message.class));
+        assertNotNull(delayMessage);
+
+        final MdnsPacket queryPacket = MdnsPacket.parse(
+                new MdnsPacketReader(queryCaptor.getValue()));
+        assertTrue(hasQuestion(queryPacket, MdnsRecord.TYPE_PTR));
+
+        // Process a response
+        final String serviceName = "service-instance";
+        final String ipV4Address = "192.0.2.0";
+        final String[] subtypeLabels = Stream.concat(Stream.of("_subtype", "_sub"),
+                        Arrays.stream(SERVICE_TYPE_LABELS)).toArray(String[]::new);
+        final MdnsPacket packetWithoutSubtype = createResponse(
+                serviceName, ipV4Address, 5353, SERVICE_TYPE_LABELS,
+                Collections.emptyMap() /* textAttributes */, TEST_TTL);
+        final MdnsPointerRecord originalPtr = (MdnsPointerRecord) CollectionUtils.findFirst(
+                packetWithoutSubtype.answers, r -> r instanceof MdnsPointerRecord);
+
+        // Add a subtype PTR record
+        final ArrayList<MdnsRecord> newAnswers = new ArrayList<>(packetWithoutSubtype.answers);
+        newAnswers.add(new MdnsPointerRecord(subtypeLabels, originalPtr.getReceiptTime(),
+                originalPtr.getCacheFlush(), originalPtr.getTtl(), originalPtr.getPointer()));
+        final MdnsPacket packetWithSubtype = new MdnsPacket(
+                packetWithoutSubtype.flags,
+                packetWithoutSubtype.questions,
+                newAnswers,
+                packetWithoutSubtype.authorityRecords,
+                packetWithoutSubtype.additionalRecords);
+        processResponse(packetWithSubtype, socketKey);
+
+        // Expect a query with known answers
+        dispatchMessage();
+        final ArgumentCaptor<DatagramPacket> knownAnswersQueryCaptor =
+                ArgumentCaptor.forClass(DatagramPacket.class);
+        currentThreadExecutor.getAndClearLastScheduledRunnable().run();
+        inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
+                knownAnswersQueryCaptor.capture(), eq(socketKey), eq(false));
+
+        final MdnsPacket knownAnswersQueryPacket = MdnsPacket.parse(
+                new MdnsPacketReader(knownAnswersQueryCaptor.getValue()));
+        assertTrue(hasQuestion(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
+        assertTrue(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
+        assertFalse(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, subtypeLabels));
+    }
+
+    @Test
+    public void testSendQueryWithSubTypeWithKnownAnswers() throws Exception {
+        client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
+                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps,
+                serviceCache,
+                MdnsFeatureFlags.newBuilder().setIsQueryWithKnownAnswerEnabled(true).build());
+
+        doCallRealMethod().when(mockDeps).getDatagramPacketFromMdnsPacket(
+                any(), any(MdnsPacket.class), any(InetSocketAddress.class));
+
+        final MdnsSearchOptions options = MdnsSearchOptions.newBuilder()
+                .addSubtype("subtype").build();
+        startSendAndReceive(mockListenerOne, options);
+        InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
+
+        final ArgumentCaptor<DatagramPacket> queryCaptor =
+                ArgumentCaptor.forClass(DatagramPacket.class);
+        currentThreadExecutor.getAndClearLastScheduledRunnable().run();
+        // Send twice for IPv4 and IPv6
+        inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
+                queryCaptor.capture(), eq(socketKey), eq(false));
+        verify(mockDeps, times(1)).sendMessage(any(), any(Message.class));
+        assertNotNull(delayMessage);
+
+        final MdnsPacket queryPacket = MdnsPacket.parse(
+                new MdnsPacketReader(queryCaptor.getValue()));
+        final String[] subtypeLabels = Stream.concat(Stream.of("_subtype", "_sub"),
+                Arrays.stream(SERVICE_TYPE_LABELS)).toArray(String[]::new);
+        assertTrue(hasQuestion(queryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
+        assertTrue(hasQuestion(queryPacket, MdnsRecord.TYPE_PTR, subtypeLabels));
+
+        // Process a response
+        final String serviceName = "service-instance";
+        final String ipV4Address = "192.0.2.0";
+        final MdnsPacket packetWithoutSubtype = createResponse(
+                serviceName, ipV4Address, 5353, SERVICE_TYPE_LABELS,
+                Collections.emptyMap() /* textAttributes */, TEST_TTL);
+        final MdnsPointerRecord originalPtr = (MdnsPointerRecord) CollectionUtils.findFirst(
+                packetWithoutSubtype.answers, r -> r instanceof MdnsPointerRecord);
+
+        // Add a subtype PTR record
+        final ArrayList<MdnsRecord> newAnswers = new ArrayList<>(packetWithoutSubtype.answers);
+        newAnswers.add(new MdnsPointerRecord(subtypeLabels, originalPtr.getReceiptTime(),
+                originalPtr.getCacheFlush(), originalPtr.getTtl(), originalPtr.getPointer()));
+        final MdnsPacket packetWithSubtype = new MdnsPacket(
+                packetWithoutSubtype.flags,
+                packetWithoutSubtype.questions,
+                newAnswers,
+                packetWithoutSubtype.authorityRecords,
+                packetWithoutSubtype.additionalRecords);
+        processResponse(packetWithSubtype, socketKey);
+
+        // Expect a query with known answers
+        dispatchMessage();
+        final ArgumentCaptor<DatagramPacket> knownAnswersQueryCaptor =
+                ArgumentCaptor.forClass(DatagramPacket.class);
+        currentThreadExecutor.getAndClearLastScheduledRunnable().run();
+        inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
+                knownAnswersQueryCaptor.capture(), eq(socketKey), eq(false));
+
+        final MdnsPacket knownAnswersQueryPacket = MdnsPacket.parse(
+                new MdnsPacketReader(knownAnswersQueryCaptor.getValue()));
+        assertTrue(hasQuestion(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
+        assertTrue(hasQuestion(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, subtypeLabels));
+        assertTrue(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
+        assertTrue(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, subtypeLabels));
+    }
+
     private static MdnsServiceInfo matchServiceName(String name) {
         return argThat(info -> info.getServiceInstanceName().equals(name));
     }
@@ -1989,6 +2122,12 @@
                 && (name == null || Arrays.equals(q.name, name)));
     }
 
+    private static boolean hasAnswer(MdnsPacket packet, int type, @NonNull String[] name) {
+        return packet.answers.stream().anyMatch(q -> {
+            return q.getType() == type && (Arrays.equals(q.name, name));
+        });
+    }
+
     // A fake ScheduledExecutorService that keeps tracking the last scheduled Runnable and its delay
     // time.
     private class FakeExecutor extends ScheduledThreadPoolExecutor {