Add replying to queries

MdnsInterfaceAdvertiser registers to receive incoming packets, and
sends replies to queries as built by MdnsRecordRepository.

Bug: 241738458
Test: atest
Change-Id: I13db22f8efc870b6e0747d105f6bc8f759910f81
diff --git a/service-t/src/com/android/server/mdns/MdnsConstants.java b/service-t/src/com/android/server/mdns/MdnsConstants.java
index 396be5f..f0e1717 100644
--- a/service-t/src/com/android/server/mdns/MdnsConstants.java
+++ b/service-t/src/com/android/server/mdns/MdnsConstants.java
@@ -37,6 +37,7 @@
     public static final int FLAGS_QUERY = 0x0000;
     public static final int FLAGS_RESPONSE_MASK = 0xF80F;
     public static final int FLAGS_RESPONSE = 0x8000;
+    public static final int FLAG_TRUNCATED = 0x0200;
     public static final int QCLASS_INTERNET = 0x0001;
     public static final int QCLASS_UNICAST = 0x8000;
     public static final String SUBTYPE_LABEL = "_sub";
diff --git a/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
index 790e69a..a14b5ad 100644
--- a/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
@@ -25,16 +25,18 @@
 import android.util.Log;
 
 import com.android.internal.annotations.VisibleForTesting;
+import com.android.net.module.util.HexDump;
 import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo;
 import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback;
 
 import java.io.IOException;
+import java.net.InetSocketAddress;
 import java.util.List;
 
 /**
  * A class that handles advertising services on a {@link MdnsInterfaceSocket} tied to an interface.
  */
-public class MdnsInterfaceAdvertiser {
+public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHandler {
     private static final boolean DBG = MdnsAdvertiser.DBG;
     @VisibleForTesting
     public static final long EXIT_ANNOUNCEMENT_DELAY_MS = 100L;
@@ -145,9 +147,9 @@
 
         /** @see MdnsReplySender */
         @NonNull
-        public MdnsReplySender makeReplySender(@NonNull Looper looper,
+        public MdnsReplySender makeReplySender(@NonNull String interfaceTag, @NonNull Looper looper,
                 @NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer) {
-            return new MdnsReplySender(looper, socket, packetCreationBuffer);
+            return new MdnsReplySender(interfaceTag, looper, socket, packetCreationBuffer);
         }
 
         /** @see MdnsAnnouncer */
@@ -182,7 +184,7 @@
         mSocket = socket;
         mCb = cb;
         mCbHandler = new Handler(looper);
-        mReplySender = deps.makeReplySender(looper, socket, packetCreationBuffer);
+        mReplySender = deps.makeReplySender(logTag, looper, socket, packetCreationBuffer);
         mAnnouncer = deps.makeMdnsAnnouncer(logTag, looper, mReplySender,
                 mAnnouncingCallback);
         mProber = deps.makeMdnsProber(logTag, looper, mReplySender, mProbingCallback);
@@ -196,7 +198,7 @@
      * {@link #destroyNow()}.
      */
     public void start() {
-        // TODO: start receiving packets
+        mSocket.addPacketHandler(this);
     }
 
     /**
@@ -267,8 +269,8 @@
             mProber.stop(serviceId);
             mAnnouncer.stop(serviceId);
         }
-
-        // TODO: stop receiving packets
+        mReplySender.cancelAll();
+        mSocket.removePacketHandler(this);
         mCbHandler.post(() -> mCb.onDestroyed(mSocket));
     }
 
@@ -294,4 +296,33 @@
     public boolean isProbing(int serviceId) {
         return mRecordRepository.isProbing(serviceId);
     }
+
+    @Override
+    public void handlePacket(byte[] recvbuf, int length, InetSocketAddress src) {
+        final MdnsPacket packet;
+        try {
+            packet = MdnsPacket.parse(new MdnsPacketReader(recvbuf, length));
+        } catch (MdnsPacket.ParseException e) {
+            Log.e(mTag, "Error parsing mDNS packet", e);
+            if (DBG) {
+                Log.v(
+                        mTag, "Packet: " + HexDump.toHexString(recvbuf, 0, length));
+            }
+            return;
+        }
+
+        if (DBG) {
+            Log.v(mTag,
+                    "Parsed packet with " + packet.questions.size() + " questions, "
+                            + packet.answers.size() + " answers, "
+                            + packet.authorityRecords.size() + " authority, "
+                            + packet.additionalRecords.size() + " additional from " + src);
+        }
+
+        final MdnsRecordRepository.ReplyInfo answers =
+                mRecordRepository.getReply(packet, src);
+
+        if (answers == null) return;
+        mReplySender.queueReply(answers);
+    }
 }
diff --git a/service-t/src/com/android/server/mdns/MdnsInterfaceSocket.java b/service-t/src/com/android/server/mdns/MdnsInterfaceSocket.java
index d1290b6..119c7a8 100644
--- a/service-t/src/com/android/server/mdns/MdnsInterfaceSocket.java
+++ b/service-t/src/com/android/server/mdns/MdnsInterfaceSocket.java
@@ -162,6 +162,14 @@
     }
 
     /**
+     * Remove a handler added via {@link #addPacketHandler}. If the handler is not present, this is
+     * a no-op.
+     */
+    public void removePacketHandler(@NonNull MulticastPacketReader.PacketHandler handler) {
+        mPacketReader.removePacketHandler(handler);
+    }
+
+    /**
      * Returns the network interface that this socket is bound to.
      *
      * <p>This method could be used on any thread.
diff --git a/service-t/src/com/android/server/mdns/MdnsPacketRepeater.java b/service-t/src/com/android/server/mdns/MdnsPacketRepeater.java
index ae54e70..4c385da 100644
--- a/service-t/src/com/android/server/mdns/MdnsPacketRepeater.java
+++ b/service-t/src/com/android/server/mdns/MdnsPacketRepeater.java
@@ -16,6 +16,9 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.server.connectivity.mdns.MdnsRecordRepository.IPV4_ADDR;
+import static com.android.server.connectivity.mdns.MdnsRecordRepository.IPV6_ADDR;
+
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.os.Handler;
@@ -32,10 +35,6 @@
  */
 public abstract class MdnsPacketRepeater<T extends MdnsPacketRepeater.Request> {
     private static final boolean DBG = MdnsAdvertiser.DBG;
-    private static final InetSocketAddress IPV4_ADDR = new InetSocketAddress(
-            MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
-    private static final InetSocketAddress IPV6_ADDR = new InetSocketAddress(
-            MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
     private static final InetSocketAddress[] ALL_ADDRS = new InetSocketAddress[] {
             IPV4_ADDR, IPV6_ADDR
     };
@@ -114,7 +113,7 @@
             final MdnsPacket packet = request.getPacket(index);
             if (DBG) {
                 Log.v(getTag(), "Sending packets for iteration " + index + " out of "
-                        + request.getNumSends());
+                        + request.getNumSends() + " for ID " + msg.what);
             }
             // Send to both v4 and v6 addresses; the reply sender will take care of ignoring the
             // send when the socket has not joined the relevant group.
diff --git a/service-t/src/com/android/server/mdns/MdnsRecord.java b/service-t/src/com/android/server/mdns/MdnsRecord.java
index 00871ea..bcee9d1 100644
--- a/service-t/src/com/android/server/mdns/MdnsRecord.java
+++ b/service-t/src/com/android/server/mdns/MdnsRecord.java
@@ -45,6 +45,7 @@
     private static final int FLAG_CACHE_FLUSH = 0x8000;
 
     public static final long RECEIPT_TIME_NOT_SENT = 0L;
+    public static final int CLASS_ANY = 0x00ff;
 
     /** Status indicating that the record is current. */
     public static final int STATUS_OK = 0;
@@ -317,4 +318,4 @@
             return (recordType * 31) + Arrays.hashCode(recordName);
         }
     }
-}
\ No newline at end of file
+}
diff --git a/service-t/src/com/android/server/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
index dd00212..4b2f553 100644
--- a/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
@@ -34,6 +34,7 @@
 import java.io.IOException;
 import java.net.Inet4Address;
 import java.net.InetAddress;
+import java.net.InetSocketAddress;
 import java.net.NetworkInterface;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -42,6 +43,7 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Random;
 import java.util.Set;
 import java.util.TreeMap;
 import java.util.UUID;
@@ -54,6 +56,9 @@
  */
 @TargetApi(Build.VERSION_CODES.TIRAMISU) // Allow calling T+ APIs; this is only loaded on T+
 public class MdnsRecordRepository {
+    // RFC6762 p.15
+    private static final long MIN_MULTICAST_REPLY_INTERVAL_MS = 1_000L;
+
     // TTLs as per RFC6762 10.
     // TTL for records with a host name as the resource record's name (e.g., A, AAAA, HINFO) or a
     // host name contained within the resource record's rdata (e.g., SRV, reverse mapping PTR
@@ -69,6 +74,13 @@
     private static final String[] DNS_SD_SERVICE_TYPE =
             new String[] { "_services", "_dns-sd", "_udp", LOCAL_TLD };
 
+    public static final InetSocketAddress IPV6_ADDR = new InetSocketAddress(
+            MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
+    public static final InetSocketAddress IPV4_ADDR = new InetSocketAddress(
+            MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
+
+    @NonNull
+    private final Random mDelayGenerator = new Random();
     // Map of service unique ID -> records for service
     @NonNull
     private final SparseArray<ServiceRegistration> mServices = new SparseArray<>();
@@ -139,6 +151,11 @@
         public boolean isProbing;
 
         /**
+         * Last time (as per SystemClock.elapsedRealtime) when advertised via multicast, 0 if never
+         */
+        public long lastAdvertisedTimeMs;
+
+        /**
          * Last time (as per SystemClock.elapsedRealtime) when sent via unicast or multicast,
          * 0 if never
          */
@@ -391,6 +408,212 @@
     }
 
     /**
+     * Info about a reply to be sent.
+     */
+    public static class ReplyInfo {
+        @NonNull
+        public final List<MdnsRecord> answers;
+        @NonNull
+        public final List<MdnsRecord> additionalAnswers;
+        public final long sendDelayMs;
+        @NonNull
+        public final InetSocketAddress destination;
+
+        public ReplyInfo(
+                @NonNull List<MdnsRecord> answers,
+                @NonNull List<MdnsRecord> additionalAnswers,
+                long sendDelayMs,
+                @NonNull InetSocketAddress destination) {
+            this.answers = answers;
+            this.additionalAnswers = additionalAnswers;
+            this.sendDelayMs = sendDelayMs;
+            this.destination = destination;
+        }
+
+        @Override
+        public String toString() {
+            return "{ReplyInfo to " + destination + ", answers: " + answers.size()
+                    + ", additionalAnswers: " + additionalAnswers.size()
+                    + ", sendDelayMs " + sendDelayMs + "}";
+        }
+    }
+
+    /**
+     * Get the reply to send to an incoming packet.
+     *
+     * @param packet The incoming packet.
+     * @param src The source address of the incoming packet.
+     */
+    @Nullable
+    public ReplyInfo getReply(MdnsPacket packet, InetSocketAddress src) {
+        final long now = SystemClock.elapsedRealtime();
+        final boolean replyUnicast = (packet.flags & MdnsConstants.QCLASS_UNICAST) != 0;
+        final ArrayList<MdnsRecord> additionalAnswerRecords = new ArrayList<>();
+        final ArrayList<RecordInfo<?>> answerInfo = new ArrayList<>();
+        for (MdnsRecord question : packet.questions) {
+            // Add answers from general records
+            addReplyFromService(question, mGeneralRecords, null /* servicePtrRecord */,
+                    null /* serviceSrvRecord */, null /* serviceTxtRecord */, replyUnicast, now,
+                    answerInfo, additionalAnswerRecords);
+
+            // Add answers from each service
+            for (int i = 0; i < mServices.size(); i++) {
+                final ServiceRegistration registration = mServices.valueAt(i);
+                if (registration.exiting) continue;
+                addReplyFromService(question, registration.allRecords, registration.ptrRecord,
+                        registration.srvRecord, registration.txtRecord, replyUnicast, now,
+                        answerInfo, additionalAnswerRecords);
+            }
+        }
+
+        if (answerInfo.size() == 0 && additionalAnswerRecords.size() == 0) {
+            return null;
+        }
+
+        // Determine the send delay
+        final long delayMs;
+        if ((packet.flags & MdnsConstants.FLAG_TRUNCATED) != 0) {
+            // RFC 6762 6.: 400-500ms delay if TC bit is set
+            delayMs = 400L + mDelayGenerator.nextInt(100);
+        } else if (packet.questions.size() > 1
+                || CollectionUtils.any(answerInfo, a -> a.isSharedName)) {
+            // 20-120ms if there may be responses from other hosts (not a fully owned
+            // name) (RFC 6762 6.), or if there are multiple questions (6.3).
+            // TODO: this should be 0 if this is a probe query ("can be distinguished from a
+            // normal query by the fact that a probe query contains a proposed record in the
+            // Authority Section that answers the question" in 6.), and the reply is for a fully
+            // owned record.
+            delayMs = 20L + mDelayGenerator.nextInt(100);
+        } else {
+            delayMs = 0L;
+        }
+
+        // Determine the send destination
+        final InetSocketAddress dest;
+        if (replyUnicast) {
+            dest = src;
+        } else if (src.getAddress() instanceof Inet4Address) {
+            dest = IPV4_ADDR;
+        } else {
+            dest = IPV6_ADDR;
+        }
+
+        // Build the list of answer records from their RecordInfo
+        final ArrayList<MdnsRecord> answerRecords = new ArrayList<>(answerInfo.size());
+        for (RecordInfo<?> info : answerInfo) {
+            // TODO: consider actual packet send delay after response aggregation
+            info.lastSentTimeMs = now + delayMs;
+            if (!replyUnicast) {
+                info.lastAdvertisedTimeMs = info.lastSentTimeMs;
+            }
+            answerRecords.add(info.record);
+        }
+
+        return new ReplyInfo(answerRecords, additionalAnswerRecords, delayMs, dest);
+    }
+
+    /**
+     * Add answers and additional answers for a question, from a ServiceRegistration.
+     */
+    private void addReplyFromService(@NonNull MdnsRecord question,
+            @NonNull List<RecordInfo<?>> serviceRecords,
+            @Nullable RecordInfo<MdnsPointerRecord> servicePtrRecord,
+            @Nullable RecordInfo<MdnsServiceRecord> serviceSrvRecord,
+            @Nullable RecordInfo<MdnsTextRecord> serviceTxtRecord,
+            boolean replyUnicast, long now, @NonNull List<RecordInfo<?>> answerInfo,
+            @NonNull List<MdnsRecord> additionalAnswerRecords) {
+        boolean hasDnsSdPtrRecordAnswer = false;
+        boolean hasDnsSdSrvRecordAnswer = false;
+        boolean hasFullyOwnedNameMatch = false;
+        boolean hasKnownAnswer = false;
+
+        final int answersStartIndex = answerInfo.size();
+        for (RecordInfo<?> info : serviceRecords) {
+            if (info.isProbing) continue;
+
+             /* RFC6762 6.: the record name must match the question name, the record rrtype
+             must match the question qtype unless the qtype is "ANY" (255) or the rrtype is
+             "CNAME" (5), and the record rrclass must match the question qclass unless the
+             qclass is "ANY" (255) */
+            if (!Arrays.equals(info.record.getName(), question.getName())) continue;
+            hasFullyOwnedNameMatch |= !info.isSharedName;
+
+            // The repository does not store CNAME records
+            if (question.getType() != MdnsRecord.TYPE_ANY
+                    && question.getType() != info.record.getType()) {
+                continue;
+            }
+            if (question.getRecordClass() != MdnsRecord.CLASS_ANY
+                    && question.getRecordClass() != info.record.getRecordClass()) {
+                continue;
+            }
+
+            hasKnownAnswer = true;
+            hasDnsSdPtrRecordAnswer |= (info == servicePtrRecord);
+            hasDnsSdSrvRecordAnswer |= (info == serviceSrvRecord);
+
+            // TODO: responses to probe queries should bypass this check and only ensure the
+            // reply is sent 250ms after the last sent time (RFC 6762 p.15)
+            if (!replyUnicast && info.lastAdvertisedTimeMs > 0L
+                    && now - info.lastAdvertisedTimeMs < MIN_MULTICAST_REPLY_INTERVAL_MS) {
+                continue;
+            }
+
+            // TODO: Don't reply if in known answers of the querier (7.1) if TTL is > half
+
+            answerInfo.add(info);
+        }
+
+        // RFC6762 6.1:
+        // "Any time a responder receives a query for a name for which it has verified exclusive
+        // ownership, for a type for which that name has no records, the responder MUST [...]
+        // respond asserting the nonexistence of that record"
+        if (hasFullyOwnedNameMatch && !hasKnownAnswer) {
+            additionalAnswerRecords.add(new MdnsNsecRecord(
+                    question.getName(),
+                    0L /* receiptTimeMillis */,
+                    true /* cacheFlush */,
+                    // TODO: RFC6762 6.1: "In general, the TTL given for an NSEC record SHOULD
+                    // be the same as the TTL that the record would have had, had it existed."
+                    NAME_RECORDS_TTL_MILLIS,
+                    question.getName(),
+                    new int[] { question.getType() }));
+        }
+
+        // No more records to add if no answer
+        if (answerInfo.size() == answersStartIndex) return;
+
+        final List<RecordInfo<?>> additionalAnswerInfo = new ArrayList<>();
+        // RFC6763 12.1: if including PTR record, include the SRV and TXT records it names
+        if (hasDnsSdPtrRecordAnswer) {
+            if (serviceTxtRecord != null) {
+                additionalAnswerInfo.add(serviceTxtRecord);
+            }
+            if (serviceSrvRecord != null) {
+                additionalAnswerInfo.add(serviceSrvRecord);
+            }
+        }
+
+        // RFC6763 12.1&.2: if including PTR or SRV record, include the address records it names
+        if (hasDnsSdPtrRecordAnswer || hasDnsSdSrvRecordAnswer) {
+            for (RecordInfo<?> record : mGeneralRecords) {
+                if (record.record instanceof MdnsInetAddressRecord) {
+                    additionalAnswerInfo.add(record);
+                }
+            }
+        }
+
+        for (RecordInfo<?> info : additionalAnswerInfo) {
+            additionalAnswerRecords.add(info.record);
+        }
+
+        // RFC6762 6.1: negative responses
+        addNsecRecordsForUniqueNames(additionalAnswerRecords,
+                answerInfo.listIterator(answersStartIndex),
+                additionalAnswerInfo.listIterator());
+    }
+
+    /**
      * Add NSEC records indicating that the response records are unique.
      *
      * Following RFC6762 6.1:
@@ -540,6 +763,7 @@
         final long now = SystemClock.elapsedRealtime();
         for (RecordInfo<?> record : registration.allRecords) {
             record.lastSentTimeMs = now;
+            record.lastAdvertisedTimeMs = now;
         }
     }
 
diff --git a/service-t/src/com/android/server/mdns/MdnsReplySender.java b/service-t/src/com/android/server/mdns/MdnsReplySender.java
index c6b8f47..f1389ca 100644
--- a/service-t/src/com/android/server/mdns/MdnsReplySender.java
+++ b/service-t/src/com/android/server/mdns/MdnsReplySender.java
@@ -16,8 +16,15 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.server.connectivity.mdns.MdnsSocketProvider.ensureRunningOnHandlerThread;
+
 import android.annotation.NonNull;
+import android.os.Handler;
 import android.os.Looper;
+import android.os.Message;
+import android.util.Log;
+
+import com.android.server.connectivity.mdns.MdnsRecordRepository.ReplyInfo;
 
 import java.io.IOException;
 import java.net.DatagramPacket;
@@ -25,6 +32,7 @@
 import java.net.Inet6Address;
 import java.net.InetSocketAddress;
 import java.net.MulticastSocket;
+import java.util.Collections;
 
 /**
  * A class that handles sending mDNS replies to a {@link MulticastSocket}, possibly queueing them
@@ -33,30 +41,46 @@
  * TODO: implement sending after a delay, combining queued replies and duplicate answer suppression
  */
 public class MdnsReplySender {
+    private static final boolean DBG = MdnsAdvertiser.DBG;
+    private static final int MSG_SEND = 1;
+
+    private final String mLogTag;
     @NonNull
     private final MdnsInterfaceSocket mSocket;
     @NonNull
-    private final Looper mLooper;
+    private final Handler mHandler;
     @NonNull
     private final byte[] mPacketCreationBuffer;
 
-    public MdnsReplySender(@NonNull Looper looper,
+    public MdnsReplySender(@NonNull String interfaceTag, @NonNull Looper looper,
             @NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer) {
-        mLooper = looper;
+        mHandler = new SendHandler(looper);
+        mLogTag = MdnsReplySender.class.getSimpleName() + "/" +  interfaceTag;
         mSocket = socket;
         mPacketCreationBuffer = packetCreationBuffer;
     }
 
     /**
+     * Queue a reply to be sent when its send delay expires.
+     */
+    public void queueReply(@NonNull ReplyInfo reply) {
+        ensureRunningOnHandlerThread(mHandler);
+        // TODO: implement response aggregation (RFC 6762 6.4)
+        mHandler.sendMessageDelayed(mHandler.obtainMessage(MSG_SEND, reply), reply.sendDelayMs);
+
+        if (DBG) {
+            Log.v(mLogTag, "Scheduling " + reply);
+        }
+    }
+
+    /**
      * Send a packet immediately.
      *
      * Must be called on the looper thread used by the {@link MdnsReplySender}.
      */
     public void sendNow(@NonNull MdnsPacket packet, @NonNull InetSocketAddress destination)
             throws IOException {
-        if (Thread.currentThread() != mLooper.getThread()) {
-            throw new IllegalStateException("sendNow must be called in the handler thread");
-        }
+        ensureRunningOnHandlerThread(mHandler);
         if (!((destination.getAddress() instanceof Inet6Address && mSocket.hasJoinedIpv6())
                 || (destination.getAddress() instanceof Inet4Address && mSocket.hasJoinedIpv4()))) {
             // Skip sending if the socket has not joined the v4/v6 group (there was no address)
@@ -93,4 +117,37 @@
 
         mSocket.send(new DatagramPacket(outBuffer, 0, len, destination));
     }
+
+    /**
+     * Cancel all pending sends.
+     */
+    public void cancelAll() {
+        ensureRunningOnHandlerThread(mHandler);
+        mHandler.removeMessages(MSG_SEND);
+    }
+
+    private class SendHandler extends Handler {
+        SendHandler(@NonNull Looper looper) {
+            super(looper);
+        }
+
+        @Override
+        public void handleMessage(@NonNull Message msg) {
+            final ReplyInfo replyInfo = (ReplyInfo) msg.obj;
+            if (DBG) Log.v(mLogTag, "Sending " + replyInfo);
+
+            final int flags = 0x8400; // Response, authoritative (rfc6762 18.4)
+            final MdnsPacket packet = new MdnsPacket(flags,
+                    Collections.emptyList() /* questions */,
+                    replyInfo.answers,
+                    Collections.emptyList() /* authorityRecords */,
+                    replyInfo.additionalAnswers);
+
+            try {
+                sendNow(packet, replyInfo.destination);
+            } catch (IOException e) {
+                Log.e(mLogTag, "Error sending MDNS response", e);
+            }
+        }
+    }
 }
diff --git a/service-t/src/com/android/server/mdns/MulticastPacketReader.java b/service-t/src/com/android/server/mdns/MulticastPacketReader.java
index 20cc47f..b597f0a 100644
--- a/service-t/src/com/android/server/mdns/MulticastPacketReader.java
+++ b/service-t/src/com/android/server/mdns/MulticastPacketReader.java
@@ -107,5 +107,14 @@
         ensureRunningOnHandlerThread(mHandler);
         mPacketHandlers.add(handler);
     }
+
+    /**
+     * Remove a packet handler added via {@link #addPacketHandler}. If the handler was not set,
+     * this is a no-op.
+     */
+    public void removePacketHandler(@NonNull PacketHandler handler) {
+        ensureRunningOnHandlerThread(mHandler);
+        mPacketHandlers.remove(handler);
+    }
 }
 
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
index 334f99d..6c3f729 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
@@ -79,7 +79,7 @@
 
     @Test
     fun testAnnounce() {
-        val replySender = MdnsReplySender(thread.looper, socket, buffer)
+        val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
         @Suppress("UNCHECKED_CAST")
         val cb = mock(MdnsPacketRepeater.PacketRepeaterCallback::class.java)
                 as MdnsPacketRepeater.PacketRepeaterCallback<BaseAnnouncementInfo>
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index 2cb0850..02b3976 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -21,6 +21,7 @@
 import android.net.nsd.NsdServiceInfo
 import android.os.Build
 import android.os.HandlerThread
+import com.android.net.module.util.HexDump
 import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
 import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo
 import com.android.server.connectivity.mdns.MdnsAnnouncer.ExitAnnouncementInfo
@@ -30,6 +31,10 @@
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
 import com.android.testutils.DevSdkIgnoreRunner
 import com.android.testutils.waitForIdle
+import java.net.InetSocketAddress
+import kotlin.test.assertContentEquals
+import kotlin.test.assertEquals
+import kotlin.test.assertTrue
 import org.junit.After
 import org.junit.Before
 import org.junit.Test
@@ -37,8 +42,10 @@
 import org.mockito.ArgumentCaptor
 import org.mockito.Mockito.any
 import org.mockito.Mockito.anyInt
+import org.mockito.Mockito.anyString
 import org.mockito.Mockito.doAnswer
 import org.mockito.Mockito.doReturn
+import org.mockito.Mockito.eq
 import org.mockito.Mockito.mock
 import org.mockito.Mockito.times
 import org.mockito.Mockito.verify
@@ -67,13 +74,18 @@
     private val replySender = mock(MdnsReplySender::class.java)
     private val announcer = mock(MdnsAnnouncer::class.java)
     private val prober = mock(MdnsProber::class.java)
+    @Suppress("UNCHECKED_CAST")
     private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
             as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>>
+    @Suppress("UNCHECKED_CAST")
     private val announceCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
             as ArgumentCaptor<PacketRepeaterCallback<BaseAnnouncementInfo>>
+    private val packetHandlerCaptor = ArgumentCaptor.forClass(
+            MulticastPacketReader.PacketHandler::class.java)
 
     private val probeCb get() = probeCbCaptor.value
     private val announceCb get() = announceCbCaptor.value
+    private val packetHandler get() = packetHandlerCaptor.value
 
     private val advertiser by lazy {
         MdnsInterfaceAdvertiser(LOG_TAG, socket, TEST_ADDRS, thread.looper, TEST_BUFFER, cb, deps)
@@ -82,9 +94,9 @@
     @Before
     fun setUp() {
         doReturn(repository).`when`(deps).makeRecordRepository(any())
-        doReturn(replySender).`when`(deps).makeReplySender(any(), any(), any())
-        doReturn(announcer).`when`(deps).makeMdnsAnnouncer(any(), any(), any(), any())
-        doReturn(prober).`when`(deps).makeMdnsProber(any(), any(), any(), any())
+        doReturn(replySender).`when`(deps).makeReplySender(anyString(), any(), any(), any())
+        doReturn(announcer).`when`(deps).makeMdnsAnnouncer(anyString(), any(), any(), any())
+        doReturn(prober).`when`(deps).makeMdnsProber(anyString(), any(), any(), any())
 
         val knownServices = mutableSetOf<Int>()
         doAnswer { inv ->
@@ -104,6 +116,7 @@
         thread.start()
         advertiser.start()
 
+        verify(socket).addPacketHandler(packetHandlerCaptor.capture())
         verify(deps).makeMdnsProber(any(), any(), any(), probeCbCaptor.capture())
         verify(deps).makeMdnsAnnouncer(any(), any(), any(), announceCbCaptor.capture())
     }
@@ -157,6 +170,39 @@
         verify(announcer, times(1)).stop(TEST_SERVICE_ID_1)
     }
 
+    @Test
+    fun testReplyToQuery() {
+        addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+
+        val mockReply = mock(MdnsRecordRepository.ReplyInfo::class.java)
+        doReturn(mockReply).`when`(repository).getReply(any(), any())
+
+        // Query obtained with:
+        // scapy.raw(scapy.DNS(
+        //  qd = scapy.DNSQR(qtype='PTR', qname='_testservice._tcp.local'))
+        // ).hex().upper()
+        val query = HexDump.hexStringToByteArray(
+                "0000010000010000000000000C5F7465737473657276696365045F746370056C6F63616C00000C0001"
+        )
+        val src = InetSocketAddress(parseNumericAddress("2001:db8::456"), MdnsConstants.MDNS_PORT)
+        packetHandler.handlePacket(query, query.size, src)
+
+        val packetCaptor = ArgumentCaptor.forClass(MdnsPacket::class.java)
+        verify(repository).getReply(packetCaptor.capture(), eq(src))
+
+        packetCaptor.value.let {
+            assertEquals(1, it.questions.size)
+            assertEquals(0, it.answers.size)
+            assertEquals(0, it.authorityRecords.size)
+            assertEquals(0, it.additionalRecords.size)
+
+            assertTrue(it.questions[0] is MdnsPointerRecord)
+            assertContentEquals(arrayOf("_testservice", "_tcp", "local"), it.questions[0].name)
+        }
+
+        verify(replySender).queueReply(mockReply)
+    }
+
     private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
             AnnouncementInfo {
         val testProbingInfo = mock(ProbingInfo::class.java)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
index 3caa97d..a2dbbc6 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
@@ -114,7 +114,7 @@
 
     @Test
     fun testProbe() {
-        val replySender = MdnsReplySender(thread.looper, socket, buffer)
+        val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
         val prober = TestProber(thread.looper, replySender, cb)
         val probeInfo = TestProbeInfo(
                 listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)))
@@ -129,7 +129,7 @@
 
     @Test
     fun testProbeMultipleRecords() {
-        val replySender = MdnsReplySender(thread.looper, socket, buffer)
+        val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
         val prober = TestProber(thread.looper, replySender, cb)
         val probeInfo = TestProbeInfo(listOf(
                 makeServiceRecord(TEST_SERVICE_NAME_1, 37890),
@@ -167,7 +167,7 @@
 
     @Test
     fun testStopProbing() {
-        val replySender = MdnsReplySender(thread.looper, socket, buffer)
+        val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
         val prober = TestProber(thread.looper, replySender, cb)
         val probeInfo = TestProbeInfo(
                 listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)),
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index 29d0854..597663c 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -21,10 +21,12 @@
 import android.net.nsd.NsdServiceInfo
 import android.os.Build
 import android.os.HandlerThread
+import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
 import com.android.server.connectivity.mdns.MdnsRecordRepository.Dependencies
 import com.android.server.connectivity.mdns.MdnsRecordRepository.getReverseDnsAddress
 import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.DevSdkIgnoreRunner
+import java.net.InetSocketAddress
 import java.net.NetworkInterface
 import java.util.Collections
 import kotlin.test.assertContentEquals
@@ -150,11 +152,7 @@
     @Test
     fun testExitAnnouncements() {
         val repository = MdnsRecordRepository(thread.looper, deps)
-        repository.updateAddresses(TEST_ADDRESSES)
-
-        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
-        val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
-        repository.onProbingSucceeded(probingInfo)
+        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         repository.onAdvertisementSent(TEST_SERVICE_ID_1)
 
         val exitAnnouncement = repository.exitService(TEST_SERVICE_ID_1)
@@ -183,9 +181,7 @@
     @Test
     fun testExitingServiceReAdded() {
         val repository = MdnsRecordRepository(thread.looper, deps)
-        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
-        val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
-        repository.onProbingSucceeded(probingInfo)
+        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         repository.onAdvertisementSent(TEST_SERVICE_ID_1)
         repository.exitService(TEST_SERVICE_ID_1)
 
@@ -199,11 +195,8 @@
     @Test
     fun testOnProbingSucceeded() {
         val repository = MdnsRecordRepository(thread.looper, deps)
-        repository.updateAddresses(TEST_ADDRESSES)
-
-        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
-        val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
-        val announcementInfo = repository.onProbingSucceeded(probingInfo)
+        val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        repository.onAdvertisementSent(TEST_SERVICE_ID_1)
         val packet = announcementInfo.getPacket(0)
 
         assertEquals(0x8400 /* response, authoritative */, packet.flags)
@@ -322,4 +315,98 @@
         val expectedV4 = "123.2.0.192.in-addr.arpa".split(".").toTypedArray()
         assertContentEquals(expectedV4, getReverseDnsAddress(parseNumericAddress("192.0.2.123")))
     }
+
+    @Test
+    fun testGetReply() {
+        val repository = MdnsRecordRepository(thread.looper, deps)
+        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        val questions = listOf(MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"),
+                0L /* receiptTimeMillis */,
+                false /* cacheFlush */,
+                // TTL and data is empty for a question
+                0L /* ttlMillis */,
+                null /* pointer */))
+        val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
+                listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+        val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
+        val reply = repository.getReply(query, src)
+
+        assertNotNull(reply)
+        // Source address is IPv4
+        assertEquals(MdnsConstants.getMdnsIPv4Address(), reply.destination.address)
+        assertEquals(MdnsConstants.MDNS_PORT, reply.destination.port)
+
+        // TTLs as per RFC6762 10.
+        val longTtl = 4_500_000L
+        val shortTtl = 120_000L
+        val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
+
+        assertEquals(listOf(
+                MdnsPointerRecord(
+                        arrayOf("_testservice", "_tcp", "local"),
+                        0L /* receiptTimeMillis */,
+                        false /* cacheFlush */,
+                        longTtl,
+                        serviceName),
+        ), reply.answers)
+
+        assertEquals(listOf(
+            MdnsTextRecord(
+                    serviceName,
+                    0L /* receiptTimeMillis */,
+                    true /* cacheFlush */,
+                    longTtl,
+                    listOf() /* entries */),
+            MdnsServiceRecord(
+                    serviceName,
+                    0L /* receiptTimeMillis */,
+                    true /* cacheFlush */,
+                    shortTtl,
+                    0 /* servicePriority */,
+                    0 /* serviceWeight */,
+                    TEST_PORT,
+                    TEST_HOSTNAME),
+            MdnsInetAddressRecord(
+                    TEST_HOSTNAME,
+                    0L /* receiptTimeMillis */,
+                    true /* cacheFlush */,
+                    shortTtl,
+                    TEST_ADDRESSES[0].address),
+            MdnsInetAddressRecord(
+                    TEST_HOSTNAME,
+                    0L /* receiptTimeMillis */,
+                    true /* cacheFlush */,
+                    shortTtl,
+                    TEST_ADDRESSES[1].address),
+            MdnsInetAddressRecord(
+                    TEST_HOSTNAME,
+                    0L /* receiptTimeMillis */,
+                    true /* cacheFlush */,
+                    shortTtl,
+                    TEST_ADDRESSES[2].address),
+            MdnsNsecRecord(
+                    serviceName,
+                    0L /* receiptTimeMillis */,
+                    true /* cacheFlush */,
+                    longTtl,
+                    serviceName /* nextDomain */,
+                    intArrayOf(MdnsRecord.TYPE_TXT, MdnsRecord.TYPE_SRV)),
+            MdnsNsecRecord(
+                    TEST_HOSTNAME,
+                    0L /* receiptTimeMillis */,
+                    true /* cacheFlush */,
+                    shortTtl,
+                    TEST_HOSTNAME /* nextDomain */,
+                    intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA)),
+        ), reply.additionalAnswers)
+    }
+}
+
+private fun MdnsRecordRepository.initWithService(serviceId: Int, serviceInfo: NsdServiceInfo):
+        AnnouncementInfo {
+    updateAddresses(TEST_ADDRESSES)
+    addService(serviceId, serviceInfo)
+    val probingInfo = setServiceProbing(serviceId)
+    assertNotNull(probingInfo)
+    return onProbingSucceeded(probingInfo)
 }