Support follow-up queries for missing records

If an mDNS advertiser does not reply with the full set of
PTR+SRV+TXT+A+AAAA recors, send follow-up queries for the missing
records, instead of just repeating the PTR query.

See RFC6763 12., especially the last paragraph.

Bug: 267570781
Test: atest NsdManagerTest
Change-Id: I495db18d9d5770522482c76c7a8805f5612324aa
diff --git a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
index fdd1478..c472a8f 100644
--- a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
+++ b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
@@ -60,13 +60,21 @@
         }
     }
 
+    @NonNull
     private final WeakReference<MdnsSocketClientBase> weakRequestSender;
+    @NonNull
     private final MdnsPacketWriter packetWriter;
+    @NonNull
     private final String[] serviceTypeLabels;
+    @NonNull
     private final List<String> subtypes;
     private final boolean expectUnicastResponse;
     private final int transactionId;
+    @Nullable
     private final Network network;
+    private final boolean sendDiscoveryQueries;
+    @NonNull
+    private final List<MdnsResponse> servicesToResolve;
 
     EnqueueMdnsQueryCallable(
             @NonNull MdnsSocketClientBase requestSender,
@@ -75,7 +83,9 @@
             @NonNull Collection<String> subtypes,
             boolean expectUnicastResponse,
             int transactionId,
-            @Nullable Network network) {
+            @Nullable Network network,
+            boolean sendDiscoveryQueries,
+            @NonNull Collection<MdnsResponse> servicesToResolve) {
         weakRequestSender = new WeakReference<>(requestSender);
         this.packetWriter = packetWriter;
         serviceTypeLabels = TextUtils.split(serviceType, "\\.");
@@ -83,6 +93,8 @@
         this.expectUnicastResponse = expectUnicastResponse;
         this.transactionId = transactionId;
         this.network = network;
+        this.sendDiscoveryQueries = sendDiscoveryQueries;
+        this.servicesToResolve = new ArrayList<>(servicesToResolve);
     }
 
     // Incompatible return type for override of Callable#call().
@@ -96,9 +108,48 @@
                 return null;
             }
 
-            int numQuestions = 1;
-            if (!subtypes.isEmpty()) {
-                numQuestions += subtypes.size();
+            int numQuestions = 0;
+
+            if (sendDiscoveryQueries) {
+                numQuestions++; // Base service type
+                if (!subtypes.isEmpty()) {
+                    numQuestions += subtypes.size();
+                }
+            }
+
+            // List of (name, type) to query
+            final ArrayList<Pair<String[], Integer>> missingKnownAnswerRecords = new ArrayList<>();
+            for (MdnsResponse response : servicesToResolve) {
+                // In practice responses should always have at least one pointer record, since the
+                // record is added after creation in MdnsResponseDecoder. All PTR records point to
+                // the same instance name, since addPointerRecord is only called on instances
+                // obtained through MdnsResponseDecoder.findResponseWithPointer.
+                // TODO: also send queries to renew record TTL (as per RFC6762 7.1 no need to query
+                // if remaining TTL is more than half the original one, so send the queries if half
+                // the TTL has passed).
+                if (!response.hasPointerRecords() || response.isComplete()) continue;
+                final String[] instanceName = response.getPointerRecords().get(0).getPointer();
+                if (instanceName == null) continue;
+                if (!response.hasTextRecord()) {
+                    missingKnownAnswerRecords.add(new Pair<>(instanceName, MdnsRecord.TYPE_TXT));
+                }
+                if (!response.hasServiceRecord()) {
+                    missingKnownAnswerRecords.add(new Pair<>(instanceName, MdnsRecord.TYPE_SRV));
+                    // The hostname is not yet known, so queries for address records will be sent
+                    // the next time the EnqueueMdnsQueryCallable is enqueued if the reply does not
+                    // contain them. In practice, advertisers should include the address records
+                    // when queried for SRV, although it's not a MUST requirement (RFC6763 12.2).
+                } else if (!response.hasInet4AddressRecord() && !response.hasInet6AddressRecord()) {
+                    final String[] host = response.getServiceRecord().getServiceHost();
+                    missingKnownAnswerRecords.add(new Pair<>(host, MdnsRecord.TYPE_A));
+                    missingKnownAnswerRecords.add(new Pair<>(host, MdnsRecord.TYPE_AAAA));
+                }
+            }
+            numQuestions += missingKnownAnswerRecords.size();
+
+            if (numQuestions == 0) {
+                // No query to send
+                return null;
             }
 
             // Header.
@@ -109,28 +160,25 @@
             packetWriter.writeUInt16(0); // number of authority entries
             packetWriter.writeUInt16(0); // number of additional records
 
-            // Question(s). There will be one question for each (fqdn+subtype, recordType)
-          // combination,
-            // as well as one for each (fqdn, recordType) combination.
-
-            for (String subtype : subtypes) {
-                String[] labels = new String[serviceTypeLabels.length + 2];
-                labels[0] = MdnsConstants.SUBTYPE_PREFIX + subtype;
-                labels[1] = MdnsConstants.SUBTYPE_LABEL;
-                System.arraycopy(serviceTypeLabels, 0, labels, 2, serviceTypeLabels.length);
-
-                packetWriter.writeLabels(labels);
-                packetWriter.writeUInt16(MdnsRecord.TYPE_PTR);
-                packetWriter.writeUInt16(
-                        MdnsConstants.QCLASS_INTERNET
-                                | (expectUnicastResponse ? MdnsConstants.QCLASS_UNICAST : 0));
+            // Question(s) for missing records on known answers
+            for (Pair<String[], Integer> question : missingKnownAnswerRecords) {
+                writeQuestion(question.first, question.second);
             }
 
-            packetWriter.writeLabels(serviceTypeLabels);
-            packetWriter.writeUInt16(MdnsRecord.TYPE_PTR);
-            packetWriter.writeUInt16(
-                    MdnsConstants.QCLASS_INTERNET
-                            | (expectUnicastResponse ? MdnsConstants.QCLASS_UNICAST : 0));
+            // Question(s) for discovering other services with the type. There will be one question
+            // for each (fqdn+subtype, recordType) combination, as well as one for each (fqdn,
+            // recordType) combination.
+            if (sendDiscoveryQueries) {
+                for (String subtype : subtypes) {
+                    String[] labels = new String[serviceTypeLabels.length + 2];
+                    labels[0] = MdnsConstants.SUBTYPE_PREFIX + subtype;
+                    labels[1] = MdnsConstants.SUBTYPE_LABEL;
+                    System.arraycopy(serviceTypeLabels, 0, labels, 2, serviceTypeLabels.length);
+
+                    writeQuestion(labels, MdnsRecord.TYPE_PTR);
+                }
+                writeQuestion(serviceTypeLabels, MdnsRecord.TYPE_PTR);
+            }
 
             if (requestSender instanceof MdnsMultinetworkSocketClient) {
                 sendPacketToIpv4AndIpv6(requestSender, MdnsConstants.MDNS_PORT, network);
@@ -159,6 +207,14 @@
         }
     }
 
+    private void writeQuestion(String[] labels, int type) throws IOException {
+        packetWriter.writeLabels(labels);
+        packetWriter.writeUInt16(type);
+        packetWriter.writeUInt16(
+                MdnsConstants.QCLASS_INTERNET
+                        | (expectUnicastResponse ? MdnsConstants.QCLASS_UNICAST : 0));
+    }
+
     private void sendPacketTo(MdnsSocketClientBase requestSender, InetSocketAddress address)
             throws IOException {
         DatagramPacket packet = packetWriter.getPacket(address);
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java
index 583c4a9..3da6bd0 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java
@@ -46,7 +46,8 @@
                 public MdnsSearchOptions createFromParcel(Parcel source) {
                     return new MdnsSearchOptions(source.createStringArrayList(),
                             source.readBoolean(), source.readBoolean(),
-                            source.readParcelable(null));
+                            source.readParcelable(null),
+                            source.readString());
                 }
 
                 @Override
@@ -56,6 +57,8 @@
             };
     private static MdnsSearchOptions defaultOptions;
     private final List<String> subtypes;
+    @Nullable
+    private final String resolveInstanceName;
 
     private final boolean isPassiveMode;
     private final boolean removeExpiredService;
@@ -64,7 +67,7 @@
 
     /** Parcelable constructs for a {@link MdnsSearchOptions}. */
     MdnsSearchOptions(List<String> subtypes, boolean isPassiveMode, boolean removeExpiredService,
-            @Nullable Network network) {
+            @Nullable Network network, @Nullable String resolveInstanceName) {
         this.subtypes = new ArrayList<>();
         if (subtypes != null) {
             this.subtypes.addAll(subtypes);
@@ -72,6 +75,7 @@
         this.isPassiveMode = isPassiveMode;
         this.removeExpiredService = removeExpiredService;
         mNetwork = network;
+        this.resolveInstanceName = resolveInstanceName;
     }
 
     /** Returns a {@link Builder} for {@link MdnsSearchOptions}. */
@@ -115,6 +119,15 @@
         return mNetwork;
     }
 
+    /**
+     * If non-null, queries should try to resolve all records of this specific service, rather than
+     * discovering all services.
+     */
+    @Nullable
+    public String getResolveInstanceName() {
+        return resolveInstanceName;
+    }
+
     @Override
     public int describeContents() {
         return 0;
@@ -126,6 +139,7 @@
         out.writeBoolean(isPassiveMode);
         out.writeBoolean(removeExpiredService);
         out.writeParcelable(mNetwork, 0);
+        out.writeString(resolveInstanceName);
     }
 
     /** A builder to create {@link MdnsSearchOptions}. */
@@ -134,6 +148,7 @@
         private boolean isPassiveMode = true;
         private boolean removeExpiredService;
         private Network mNetwork;
+        private String resolveInstanceName;
 
         private Builder() {
             subtypes = new ArraySet<>();
@@ -194,10 +209,22 @@
             return this;
         }
 
+        /**
+         * Set the instance name to resolve.
+         *
+         * If non-null, queries should try to resolve all records of this specific service,
+         * rather than discovering all services.
+         * @param name The instance name.
+         */
+        public Builder setResolveInstanceName(String name) {
+            resolveInstanceName = name;
+            return this;
+        }
+
         /** Builds a {@link MdnsSearchOptions} with the arguments supplied to this builder. */
         public MdnsSearchOptions build() {
             return new MdnsSearchOptions(new ArrayList<>(subtypes), isPassiveMode,
-                    removeExpiredService, mNetwork);
+                    removeExpiredService, mNetwork, resolveInstanceName);
         }
     }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index d26fbdb..e33686f 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -23,7 +23,7 @@
 import android.net.Network;
 import android.os.SystemClock;
 import android.text.TextUtils;
-import android.util.ArraySet;
+import android.util.ArrayMap;
 import android.util.Pair;
 
 import com.android.internal.annotations.GuardedBy;
@@ -33,12 +33,12 @@
 import java.net.Inet4Address;
 import java.net.Inet6Address;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
 import java.util.concurrent.Future;
 import java.util.concurrent.ScheduledExecutorService;
 
@@ -56,7 +56,8 @@
     private final MdnsSocketClientBase socketClient;
     private final ScheduledExecutorService executor;
     private final Object lock = new Object();
-    private final Set<MdnsServiceBrowserListener> listeners = new ArraySet<>();
+    private final ArrayMap<MdnsServiceBrowserListener, MdnsSearchOptions> listeners =
+            new ArrayMap<>();
     private final Map<String, MdnsResponse> instanceNameToResponse = new HashMap<>();
     private final boolean removeServiceAfterTtlExpires =
             MdnsConfigs.removeServiceAfterTtlExpires();
@@ -148,7 +149,7 @@
             @NonNull MdnsSearchOptions searchOptions) {
         synchronized (lock) {
             this.searchOptions = searchOptions;
-            if (listeners.add(listener)) {
+            if (listeners.put(listener, searchOptions) == null) {
                 for (MdnsResponse existingResponse : instanceNameToResponse.values()) {
                     final MdnsServiceInfo info =
                             buildMdnsServiceInfoFromResponse(existingResponse, serviceTypeLabels);
@@ -219,8 +220,8 @@
     }
 
     public synchronized void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode) {
-        for (MdnsServiceBrowserListener listener : listeners) {
-            listener.onFailedToParseMdnsResponse(receivedPacketNumber, errorCode);
+        for (int i = 0; i < listeners.size(); i++) {
+            listeners.keyAt(i).onFailedToParseMdnsResponse(receivedPacketNumber, errorCode);
         }
     }
 
@@ -250,7 +251,8 @@
         MdnsServiceInfo serviceInfo =
                 buildMdnsServiceInfoFromResponse(currentResponse, serviceTypeLabels);
 
-        for (MdnsServiceBrowserListener listener : listeners) {
+        for (int i = 0; i < listeners.size(); i++) {
+            final MdnsServiceBrowserListener listener = listeners.keyAt(i);
             if (newServiceFound) {
                 listener.onServiceNameDiscovered(serviceInfo);
             }
@@ -270,7 +272,8 @@
         if (response == null) {
             return;
         }
-        for (MdnsServiceBrowserListener listener : listeners) {
+        for (int i = 0; i < listeners.size(); i++) {
+            final MdnsServiceBrowserListener listener = listeners.keyAt(i);
             final MdnsServiceInfo serviceInfo =
                     buildMdnsServiceInfoFromResponse(response, serviceTypeLabels);
             if (response.isComplete()) {
@@ -400,6 +403,36 @@
 
         @Override
         public void run() {
+            final List<MdnsResponse> servicesToResolve = new ArrayList<>();
+            boolean sendDiscoveryQueries = false;
+            synchronized (lock) {
+                for (int i = 0; i < listeners.size(); i++) {
+                    final String resolveName = listeners.valueAt(i).getResolveInstanceName();
+                    if (resolveName == null) {
+                        sendDiscoveryQueries = true;
+                        continue;
+                    }
+                    MdnsResponse knownResponse = instanceNameToResponse.get(resolveName);
+                    if (knownResponse == null) {
+                        // The listener is requesting to resolve a service that has no info in
+                        // cache. Use the provided name to generate a minimal response with just a
+                        // PTR record, so other records are queried to complete it.
+                        // Only the names are used to know which queries to send, other
+                        // parameters do not matter.
+                        knownResponse = new MdnsResponse(
+                                0L /* now */, 0 /* interfaceIndex */, config.network);
+                        final ArrayList<String> instanceFullName = new ArrayList<>(
+                                serviceTypeLabels.length + 1);
+                        instanceFullName.add(resolveName);
+                        instanceFullName.addAll(Arrays.asList(serviceTypeLabels));
+                        knownResponse.addPointerRecord(new MdnsPointerRecord(
+                                serviceTypeLabels, 0L /* receiptTimeMillis */,
+                                false /* cacheFlush */, 0L /* ttlMillis */,
+                                instanceFullName.toArray(new String[0])));
+                    }
+                    servicesToResolve.add(knownResponse);
+                }
+            }
             Pair<Integer, List<String>> result;
             try {
                 result =
@@ -410,7 +443,9 @@
                                 config.subtypes,
                                 config.expectUnicastResponse,
                                 config.transactionId,
-                                config.network)
+                                config.network,
+                                sendDiscoveryQueries,
+                                servicesToResolve)
                                 .call();
             } catch (RuntimeException e) {
                 LOGGER.e(String.format("Failed to run EnqueueMdnsQueryCallable for subtype: %s",
@@ -435,8 +470,8 @@
                     }
                 }
                 if ((result != null)) {
-                    for (MdnsServiceBrowserListener listener : listeners) {
-                        listener.onDiscoveryQuerySent(result.second, result.first);
+                    for (int i = 0; i < listeners.size(); i++) {
+                        listeners.keyAt(i).onDiscoveryQuerySent(result.second, result.first);
                     }
                 }
                 if (shouldRemoveServiceAfterTtlExpires()) {
@@ -449,7 +484,8 @@
                                 .getRemainingTTL(SystemClock.elapsedRealtime())
                                 == 0) {
                             iter.remove();
-                            for (MdnsServiceBrowserListener listener : listeners) {
+                            for (int i = 0; i < listeners.size(); i++) {
+                                final MdnsServiceBrowserListener listener = listeners.keyAt(i);
                                 String serviceInstanceName =
                                         existingResponse.getServiceInstanceName();
                                 if (serviceInstanceName != null) {