Merge "Specify aconfig_declaration modules in java_sdk_library" into main
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
index 3a04dcd..730bd7e 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
@@ -162,10 +162,11 @@
         @NonNull
         public MdnsReplySender makeReplySender(@NonNull String interfaceTag, @NonNull Looper looper,
                 @NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer,
-                @NonNull SharedLog sharedLog) {
+                @NonNull SharedLog sharedLog, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
             return new MdnsReplySender(looper, socket, packetCreationBuffer,
                     sharedLog.forSubComponent(
-                            MdnsReplySender.class.getSimpleName() + "/" + interfaceTag), DBG);
+                            MdnsReplySender.class.getSimpleName() + "/" + interfaceTag), DBG,
+                    mdnsFeatureFlags);
         }
 
         /** @see MdnsAnnouncer */
@@ -208,7 +209,7 @@
         mCb = cb;
         mCbHandler = new Handler(looper);
         mReplySender = deps.makeReplySender(sharedLog.getTag(), looper, socket,
-                packetCreationBuffer, sharedLog);
+                packetCreationBuffer, sharedLog, mdnsFeatureFlags);
         mPacketCreationBuffer = packetCreationBuffer;
         mAnnouncer = deps.makeMdnsAnnouncer(sharedLog.getTag(), looper, mReplySender,
                 mAnnouncingCallback, sharedLog);
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
index 585b097..78c3082 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
@@ -490,6 +490,16 @@
         return ret;
     }
 
+    private boolean isTruncatedKnownAnswerPacket(MdnsPacket packet) {
+        if (!mMdnsFeatureFlags.mIsKnownAnswerSuppressionEnabled
+                // Response packets are not Known-Answer continuation queries; ignore them.
+                || (packet.flags & MdnsConstants.FLAGS_RESPONSE) != 0) {
+            return false;
+        }
+        // RFC6762 7.2: a continuation packet has no questions, only more Known-Answer records.
+        return packet.questions.size() == 0 && packet.answers.size() != 0;
+    }
+
     /**
      * Get the reply to send to an incoming packet.
      *
@@ -550,7 +560,20 @@
                 answerInfo.iterator(), additionalAnswerInfo.iterator());
 
         if (answerInfo.size() == 0 && additionalAnswerRecords.size() == 0) {
-            return null;
+            // RFC6762 7.2. Multipacket Known-Answer Suppression
+            // Sometimes a Multicast DNS querier will already have too many answers
+            // to fit in the Known-Answer Section of its query packets. In this
+            // case, it should issue a Multicast DNS query containing a question and
+            // as many Known-Answer records as will fit.  It MUST then set the TC
+            // (Truncated) bit in the header before sending the query.  It MUST
+            // immediately follow the packet with another query packet containing no
+            // questions and as many more Known-Answer records as will fit.  If
+            // there are still too many records remaining to fit in the packet, it
+            // again sets the TC bit and continues until all the Known-Answer
+            // records have been sent.
+            if (!isTruncatedKnownAnswerPacket(packet)) {
+                return null;
+            }
         }
 
         // Determine the send delay
@@ -598,7 +621,8 @@
             answerRecords.add(info.record);
         }
 
-        return new MdnsReplyInfo(answerRecords, additionalAnswerRecords, delayMs, dest);
+        return new MdnsReplyInfo(answerRecords, additionalAnswerRecords, delayMs, dest, src,
+                new ArrayList<>(packet.answers));
     }
 
     private boolean isKnownAnswer(MdnsRecord answer, @NonNull List<MdnsRecord> knownAnswerRecords) {
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsReplyInfo.java b/service-t/src/com/android/server/connectivity/mdns/MdnsReplyInfo.java
index ce61b54..8747f67 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsReplyInfo.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsReplyInfo.java
@@ -32,22 +32,32 @@
     public final long sendDelayMs;
     @NonNull
     public final InetSocketAddress destination;
+    @NonNull
+    public final InetSocketAddress source;
+    @NonNull
+    public final List<MdnsRecord> knownAnswers;
 
     public MdnsReplyInfo(
             @NonNull List<MdnsRecord> answers,
             @NonNull List<MdnsRecord> additionalAnswers,
             long sendDelayMs,
-            @NonNull InetSocketAddress destination) {
+            @NonNull InetSocketAddress destination,
+            @NonNull InetSocketAddress source,
+            @NonNull List<MdnsRecord> knownAnswers) {
         this.answers = answers;
         this.additionalAnswers = additionalAnswers;
         this.sendDelayMs = sendDelayMs;
         this.destination = destination;
+        this.source = source;
+        this.knownAnswers = knownAnswers;
     }
 
     @Override
     public String toString() {
-        return "{MdnsReplyInfo to " + destination + ", answers: " + answers.size()
+        return "{MdnsReplyInfo: " + source + " to " + destination
+                + ", answers: " + answers.size()
                 + ", additionalAnswers: " + additionalAnswers.size()
+                + ", knownAnswers: " + knownAnswers.size()
                 + ", sendDelayMs " + sendDelayMs + "}";
     }
 }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsReplySender.java b/service-t/src/com/android/server/connectivity/mdns/MdnsReplySender.java
index 651b643..a46be3b 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsReplySender.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsReplySender.java
@@ -16,6 +16,8 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.server.connectivity.mdns.MdnsConstants.IPV4_SOCKET_ADDR;
+import static com.android.server.connectivity.mdns.MdnsConstants.IPV6_SOCKET_ADDR;
 import static com.android.server.connectivity.mdns.util.MdnsUtils.ensureRunningOnHandlerThread;
 
 import android.annotation.NonNull;
@@ -24,6 +26,8 @@
 import android.os.Handler;
 import android.os.Looper;
 import android.os.Message;
+import android.util.ArrayMap;
+import android.util.ArraySet;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.net.module.util.SharedLog;
@@ -35,7 +39,10 @@
 import java.net.Inet6Address;
 import java.net.InetSocketAddress;
 import java.net.MulticastSocket;
+import java.util.ArrayList;
 import java.util.Collections;
+import java.util.Map;
+import java.util.Set;
 
 /**
  * A class that handles sending mDNS replies to a {@link MulticastSocket}, possibly queueing them
@@ -60,6 +67,12 @@
     private final boolean mEnableDebugLog;
     @NonNull
     private final Dependencies mDependencies;
+    // RFC6762 15.2. Multipacket Known-Answer lists
+    // Multicast DNS responders associate the initial truncated query with its
+    // continuation packets by examining the source IP address in each packet.
+    private final Map<InetSocketAddress, MdnsReplyInfo> mSrcReplies = new ArrayMap<>();
+    @NonNull
+    private final MdnsFeatureFlags mMdnsFeatureFlags;
 
     /**
      * Dependencies of MdnsReplySender, for injection in tests.
@@ -80,24 +93,50 @@
         public void removeMessages(@NonNull Handler handler, int what) {
             handler.removeMessages(what);
         }
+
+        /**
+         * Matches msg.obj by identity (==); see {@link Handler#removeMessages(int, Object)}.
+         */
+        public void removeMessages(@NonNull Handler handler, int what, @NonNull Object object) {
+            handler.removeMessages(what, object);
+        }
     }
 
     public MdnsReplySender(@NonNull Looper looper, @NonNull MdnsInterfaceSocket socket,
             @NonNull byte[] packetCreationBuffer, @NonNull SharedLog sharedLog,
-            boolean enableDebugLog) {
-        this(looper, socket, packetCreationBuffer, sharedLog, enableDebugLog, new Dependencies());
+            boolean enableDebugLog, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
+        this(looper, socket, packetCreationBuffer, sharedLog, enableDebugLog, new Dependencies(),
+                mdnsFeatureFlags);
     }
 
     @VisibleForTesting
     public MdnsReplySender(@NonNull Looper looper, @NonNull MdnsInterfaceSocket socket,
             @NonNull byte[] packetCreationBuffer, @NonNull SharedLog sharedLog,
-            boolean enableDebugLog, @NonNull Dependencies dependencies) {
+            boolean enableDebugLog, @NonNull Dependencies dependencies,
+            @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
         mHandler = new SendHandler(looper);
         mSocket = socket;
         mPacketCreationBuffer = packetCreationBuffer;
         mSharedLog = sharedLog;
         mEnableDebugLog = enableDebugLog;
         mDependencies = dependencies;
+        mMdnsFeatureFlags = mdnsFeatureFlags;
+    }
+
+    static InetSocketAddress getReplyDestination(@NonNull InetSocketAddress queuingDest,
+            @NonNull InetSocketAddress incomingDest) {
+        // The queued reply is already multicast: keep its multicast destination.
+        if (queuingDest.equals(IPV4_SOCKET_ADDR) || queuingDest.equals(IPV6_SOCKET_ADDR)) {
+            return queuingDest;
+        }
+
+        // The incoming reply is multicast: upgrade the queued unicast reply to multicast, since
+        // answering a query that requested a unicast response via unicast is optional.
+        if (incomingDest.equals(IPV4_SOCKET_ADDR) || incomingDest.equals(IPV6_SOCKET_ADDR)) {
+            return incomingDest;
+        }
+        // Both destinations are unicast: keep the queued reply's destination.
+        return queuingDest;
     }
 
     /**
@@ -105,9 +144,53 @@
      */
     public void queueReply(@NonNull MdnsReplyInfo reply) {
         ensureRunningOnHandlerThread(mHandler);
-        // TODO: implement response aggregation (RFC 6762 6.4)
-        mDependencies.sendMessageDelayed(
-                mHandler, mHandler.obtainMessage(MSG_SEND, reply), reply.sendDelayMs);
+
+        if (mMdnsFeatureFlags.mIsKnownAnswerSuppressionEnabled) {
+            final MdnsReplyInfo queuingReply = mSrcReplies.remove(reply.source);
+            final ArraySet<MdnsRecord> answers = new ArraySet<>();
+            final Set<MdnsRecord> additionalAnswers = new ArraySet<>();
+            final Set<MdnsRecord> knownAnswers = new ArraySet<>();
+            if (queuingReply != null) {
+                // Handler matches msg.obj by identity: cancel via the instance that was queued.
+                mDependencies.removeMessages(mHandler, MSG_SEND, queuingReply.source);
+                answers.addAll(queuingReply.answers);
+                additionalAnswers.addAll(queuingReply.additionalAnswers);
+                knownAnswers.addAll(queuingReply.knownAnswers);
+            }
+            answers.addAll(reply.answers);
+            additionalAnswers.addAll(reply.additionalAnswers);
+            knownAnswers.addAll(reply.knownAnswers);
+            // RFC6762 7.2. Multipacket Known-Answer Suppression
+            // If the responder sees any of its answers listed in the Known-Answer
+            // lists of subsequent packets from the querying host, it MUST delete
+            // that answer from the list of answers it is planning to give.
+            // Per RFC6762 7.1, only suppress when the known TTL is at least half the true TTL.
+            for (MdnsRecord knownAnswer : knownAnswers) {
+                final int idx = answers.indexOf(knownAnswer);
+                if (idx >= 0 && knownAnswer.getTtl() * 2 >= answers.valueAt(idx).getTtl()) {
+                    answers.removeAt(idx);
+                }
+            }
+
+            if (answers.size() == 0) {
+                return;
+            }
+
+            final MdnsReplyInfo newReply = new MdnsReplyInfo(
+                    new ArrayList<>(answers),
+                    new ArrayList<>(additionalAnswers),
+                    reply.sendDelayMs,
+                    queuingReply == null ? reply.destination
+                            : getReplyDestination(queuingReply.destination, reply.destination),
+                    reply.source,
+                    new ArrayList<>(knownAnswers));
+            mSrcReplies.put(newReply.source, newReply);
+            mDependencies.sendMessageDelayed(mHandler,
+                    mHandler.obtainMessage(MSG_SEND, newReply.source), newReply.sendDelayMs);
+        } else {
+            mDependencies.sendMessageDelayed(
+                    mHandler, mHandler.obtainMessage(MSG_SEND, reply), reply.sendDelayMs);
+        }
 
         if (mEnableDebugLog) {
             mSharedLog.v("Scheduling " + reply);
@@ -147,7 +230,21 @@
 
         @Override
         public void handleMessage(@NonNull Message msg) {
-            final MdnsReplyInfo replyInfo = (MdnsReplyInfo) msg.obj;
+            final MdnsReplyInfo replyInfo;
+            if (mMdnsFeatureFlags.mIsKnownAnswerSuppressionEnabled) {
+                // Retrieve the MdnsReplyInfo from the map via a source address, as the reply info
+                // will be combined or updated.
+                final InetSocketAddress source = (InetSocketAddress) msg.obj;
+                replyInfo = mSrcReplies.remove(source);
+            } else {
+                replyInfo = (MdnsReplyInfo) msg.obj;
+            }
+
+            if (replyInfo == null) {
+                mSharedLog.wtf("Unknown reply info.");
+                return;
+            }
+
             if (mEnableDebugLog) mSharedLog.v("Sending " + replyInfo);
 
             final int flags = 0x8400; // Response, authoritative (rfc6762 18.4)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
index 2797462..27242f1 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
@@ -55,6 +55,7 @@
     private val socket = mock(MdnsInterfaceSocket::class.java)
     private val sharedLog = mock(SharedLog::class.java)
     private val buffer = ByteArray(1500)
+    private val flags = MdnsFeatureFlags.newBuilder().build()
 
     @Before
     fun setUp() {
@@ -83,7 +84,7 @@
     @Test
     fun testAnnounce() {
         val replySender = MdnsReplySender(
-                thread.looper, socket, buffer, sharedLog, true /* enableDebugLog */)
+                thread.looper, socket, buffer, sharedLog, true /* enableDebugLog */, flags)
         @Suppress("UNCHECKED_CAST")
         val cb = mock(MdnsPacketRepeater.PacketRepeaterCallback::class.java)
                 as MdnsPacketRepeater.PacketRepeaterCallback<BaseAnnouncementInfo>
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index ee0bd1a..0e5cc50 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -45,6 +45,7 @@
 import org.mockito.Mockito.any
 import org.mockito.Mockito.anyInt
 import org.mockito.Mockito.anyString
+import org.mockito.Mockito.argThat
 import org.mockito.Mockito.doAnswer
 import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.eq
@@ -87,7 +88,8 @@
     private val announcer = mock(MdnsAnnouncer::class.java)
     private val prober = mock(MdnsProber::class.java)
     private val sharedlog = SharedLog("MdnsInterfaceAdvertiserTest")
-    private val flags = MdnsFeatureFlags.newBuilder().build()
+    private val flags = MdnsFeatureFlags.newBuilder()
+            .setIsKnownAnswerSuppressionEnabled(true).build()
     @Suppress("UNCHECKED_CAST")
     private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
             as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>>
@@ -118,7 +120,8 @@
     @Before
     fun setUp() {
         doReturn(repository).`when`(deps).makeRecordRepository(any(), eq(TEST_HOSTNAME), any())
-        doReturn(replySender).`when`(deps).makeReplySender(anyString(), any(), any(), any(), any())
+        doReturn(replySender).`when`(deps).makeReplySender(
+                anyString(), any(), any(), any(), any(), any())
         doReturn(announcer).`when`(deps).makeMdnsAnnouncer(anyString(), any(), any(), any(), any())
         doReturn(prober).`when`(deps).makeMdnsProber(anyString(), any(), any(), any(), any())
 
@@ -200,7 +203,8 @@
     fun testReplyToQuery() {
         addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
 
-        val testReply = MdnsReplyInfo(emptyList(), emptyList(), 0, InetSocketAddress(0))
+        val testReply = MdnsReplyInfo(emptyList(), emptyList(), 0, InetSocketAddress(0),
+                InetSocketAddress(0), emptyList())
         doReturn(testReply).`when`(repository).getReply(any(), any())
 
         // Query obtained with:
@@ -235,6 +239,112 @@
     }
 
     @Test
+    fun testReplyToQuery_TruncatedBitSet() {
+        addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        val src = InetSocketAddress(parseNumericAddress("2001:db8::456"), MdnsConstants.MDNS_PORT)
+        val testReply = MdnsReplyInfo(emptyList(), emptyList(), 400L, InetSocketAddress(0), src,
+                emptyList())
+        val knownAnswersReply = MdnsReplyInfo(emptyList(), emptyList(), 400L, InetSocketAddress(0),
+                src, emptyList())
+        val knownAnswersReply2 = MdnsReplyInfo(emptyList(), emptyList(), 0L, InetSocketAddress(0),
+                src, emptyList())
+        doReturn(testReply).`when`(repository).getReply(
+                argThat { pkg -> pkg.questions.size != 0 && pkg.answers.size == 0 &&
+                        (pkg.flags and MdnsConstants.FLAG_TRUNCATED) != 0},
+                eq(src))
+        doReturn(knownAnswersReply).`when`(repository).getReply(
+                argThat { pkg -> pkg.questions.size == 0 && pkg.answers.size != 0 &&
+                        (pkg.flags and MdnsConstants.FLAG_TRUNCATED) != 0},
+                eq(src))
+        doReturn(knownAnswersReply2).`when`(repository).getReply(
+                argThat { pkg -> pkg.questions.size == 0 && pkg.answers.size != 0 &&
+                        (pkg.flags and MdnsConstants.FLAG_TRUNCATED) == 0},
+                eq(src))
+
+        // Query obtained with:
+        // scapy.raw(scapy.DNS(
+        //  tc = 1, qd = scapy.DNSQR(qtype='PTR', qname='_testservice._tcp.local'))
+        // ).hex().upper()
+        val query = HexDump.hexStringToByteArray(
+                "0000030000010000000000000C5F7465737473657276696365045F746370056C6F63616C00000C0001"
+        )
+
+        packetHandler.handlePacket(query, query.size, src)
+
+        val packetCaptor = ArgumentCaptor.forClass(MdnsPacket::class.java)
+        verify(repository).getReply(packetCaptor.capture(), eq(src))
+
+        packetCaptor.value.let {
+            assertTrue((it.flags and MdnsConstants.FLAG_TRUNCATED) != 0)
+            assertEquals(1, it.questions.size)
+            assertEquals(0, it.answers.size)
+            assertEquals(0, it.authorityRecords.size)
+            assertEquals(0, it.additionalRecords.size)
+
+            assertTrue(it.questions[0] is MdnsPointerRecord)
+            assertContentEquals(arrayOf("_testservice", "_tcp", "local"), it.questions[0].name)
+        }
+
+        verify(replySender).queueReply(testReply)
+
+        // Known-Answer packet with truncated bit set obtained with:
+        // scapy.raw(scapy.DNS(
+        //   tc = 1, qd = None, an = scapy.DNSRR(type='PTR', rrname='_testtype._tcp.local',
+        //   rdata='othertestservice._testtype._tcp.local', rclass='IN', ttl=4500))
+        // ).hex().upper()
+        val knownAnswers = HexDump.hexStringToByteArray(
+                "000003000000000100000000095F7465737474797065045F746370056C6F63616C00000C0001000" +
+                        "011940027106F746865727465737473657276696365095F7465737474797065045F7463" +
+                        "70056C6F63616C00"
+        )
+
+        packetHandler.handlePacket(knownAnswers, knownAnswers.size, src)
+
+        verify(repository, times(2)).getReply(packetCaptor.capture(), eq(src))
+
+        packetCaptor.value.let {
+            assertTrue((it.flags and MdnsConstants.FLAG_TRUNCATED) != 0)
+            assertEquals(0, it.questions.size)
+            assertEquals(1, it.answers.size)
+            assertEquals(0, it.authorityRecords.size)
+            assertEquals(0, it.additionalRecords.size)
+
+            assertTrue(it.answers[0] is MdnsPointerRecord)
+            assertContentEquals(arrayOf("_testtype", "_tcp", "local"), it.answers[0].name)
+        }
+
+        verify(replySender).queueReply(knownAnswersReply)
+
+        // Known-Answer packet obtained with:
+        // scapy.raw(scapy.DNS(
+        //   qd = None, an = scapy.DNSRR(type='PTR', rrname='_testtype._tcp.local',
+        //   rdata='testservice._testtype._tcp.local', rclass='IN', ttl=4500))
+        // ).hex().upper()
+        val knownAnswers2 = HexDump.hexStringToByteArray(
+                "000001000000000100000000095F7465737474797065045F746370056C6F63616C00000C0001000" +
+                        "0119400220B7465737473657276696365095F7465737474797065045F746370056C6F63" +
+                        "616C00"
+        )
+
+        packetHandler.handlePacket(knownAnswers2, knownAnswers2.size, src)
+
+        verify(repository, times(3)).getReply(packetCaptor.capture(), eq(src))
+
+        packetCaptor.value.let {
+            assertTrue((it.flags and MdnsConstants.FLAG_TRUNCATED) == 0)
+            assertEquals(0, it.questions.size)
+            assertEquals(1, it.answers.size)
+            assertEquals(0, it.authorityRecords.size)
+            assertEquals(0, it.additionalRecords.size)
+
+            assertTrue(it.answers[0] is MdnsPointerRecord)
+            assertContentEquals(arrayOf("_testtype", "_tcp", "local"), it.answers[0].name)
+        }
+
+        verify(replySender).queueReply(knownAnswersReply2)
+    }
+
+    @Test
     fun testConflict() {
         addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         doReturn(setOf(TEST_SERVICE_ID_1)).`when`(repository).getConflictingServices(any())
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
index 5b7c0ba..9befbc1 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
@@ -61,6 +61,7 @@
     private val cb = mock(MdnsPacketRepeater.PacketRepeaterCallback::class.java)
         as MdnsPacketRepeater.PacketRepeaterCallback<ProbingInfo>
     private val buffer = ByteArray(1500)
+    private val flags = MdnsFeatureFlags.newBuilder().build()
 
     @Before
     fun setUp() {
@@ -120,7 +121,7 @@
     @Test
     fun testProbe() {
         val replySender = MdnsReplySender(
-                thread.looper, socket, buffer, sharedLog, true /* enableDebugLog */)
+                thread.looper, socket, buffer, sharedLog, true /* enableDebugLog */, flags)
         val prober = TestProber(thread.looper, replySender, cb, sharedLog)
         val probeInfo = TestProbeInfo(
                 listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)))
@@ -145,7 +146,7 @@
     @Test
     fun testProbeMultipleRecords() {
         val replySender = MdnsReplySender(
-                thread.looper, socket, buffer, sharedLog, true /* enableDebugLog */)
+                thread.looper, socket, buffer, sharedLog, true /* enableDebugLog */, flags)
         val prober = TestProber(thread.looper, replySender, cb, sharedLog)
         val probeInfo = TestProbeInfo(listOf(
                 makeServiceRecord(TEST_SERVICE_NAME_1, 37890),
@@ -184,7 +185,7 @@
     @Test
     fun testStopProbing() {
         val replySender = MdnsReplySender(
-                thread.looper, socket, buffer, sharedLog, true /* enableDebugLog */)
+                thread.looper, socket, buffer, sharedLog, true /* enableDebugLog */, flags)
         val prober = TestProber(thread.looper, replySender, cb, sharedLog)
         val probeInfo = TestProbeInfo(
                 listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)),
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index 1edc806..06f12fe 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -175,8 +175,8 @@
 
         val queriedName = arrayOf(TEST_SUBTYPE, "_sub", "_testservice", "_tcp", "local")
         val questions = listOf(MdnsPointerRecord(queriedName, false /* isUnicast */))
-        val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
-                listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+        val query = MdnsPacket(0 /* flags */, questions, emptyList() /* answers */,
+                emptyList() /* authorityRecords */, emptyList() /* additionalRecords */)
         val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
         val reply = repository.getReply(query, src)
 
@@ -510,8 +510,8 @@
         val questionsCaseInSensitive = listOf(
                 MdnsPointerRecord(arrayOf("_TESTSERVICE", "_TCP", "local"), false /* isUnicast */))
         val queryCaseInsensitive = MdnsPacket(0 /* flags */, questionsCaseInSensitive,
-            listOf() /* answers */, listOf() /* authorityRecords */,
-            listOf() /* additionalRecords */)
+            emptyList() /* answers */, emptyList() /* authorityRecords */,
+            emptyList() /* additionalRecords */)
         val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
         val replyCaseInsensitive = repository.getReply(queryCaseInsensitive, src)
         assertNotNull(replyCaseInsensitive)
@@ -524,8 +524,8 @@
      */
     private fun makeQuery(vararg queries: Pair<Int, Array<String>>): MdnsPacket {
         val questions = queries.map { (type, name) -> makeQuestionRecord(name, type) }
-        return MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
-            listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+        return MdnsPacket(0 /* flags */, questions, emptyList() /* answers */,
+                emptyList() /* authorityRecords */, emptyList() /* additionalRecords */)
     }
 
     private fun makeQuestionRecord(name: Array<String>, type: Int): MdnsRecord {
@@ -554,7 +554,7 @@
                     arrayOf("_testservice", "_tcp", "local"), 0L, false, LONG_TTL, serviceName)),
             reply.answers)
         assertEquals(listOf(
-                MdnsTextRecord(serviceName, 0L, true, LONG_TTL, listOf()),
+                MdnsTextRecord(serviceName, 0L, true, LONG_TTL, emptyList()),
                 MdnsServiceRecord(serviceName, 0L, true, SHORT_TTL, 0, 0, TEST_PORT, TEST_HOSTNAME),
                 MdnsInetAddressRecord(
                     TEST_HOSTNAME, 0L, true, SHORT_TTL, TEST_ADDRESSES[0].address),
@@ -587,7 +587,7 @@
                     LONG_TTL, serviceName)),
             reply.answers)
         assertEquals(listOf(
-                MdnsTextRecord(serviceName, 0L, true, LONG_TTL, listOf()),
+                MdnsTextRecord(serviceName, 0L, true, LONG_TTL, emptyList()),
                 MdnsServiceRecord(serviceName, 0L, true, SHORT_TTL, 0, 0, TEST_PORT, TEST_HOSTNAME),
                 MdnsInetAddressRecord(
                     TEST_HOSTNAME, 0L, true, SHORT_TTL, TEST_ADDRESSES[0].address),
@@ -620,7 +620,7 @@
                     arrayOf("_testservice", "_tcp", "local"), 0L, false, LONG_TTL, serviceName)),
             reply.answers)
         assertEquals(listOf(
-                MdnsTextRecord(serviceName, 0L, true, LONG_TTL, listOf()),
+                MdnsTextRecord(serviceName, 0L, true, LONG_TTL, emptyList()),
                 MdnsServiceRecord(serviceName, 0L, true, SHORT_TTL, 0, 0, TEST_PORT, TEST_HOSTNAME),
                 MdnsInetAddressRecord(
                     TEST_HOSTNAME, 0L, true, SHORT_TTL, TEST_ADDRESSES[0].address),
@@ -656,7 +656,7 @@
                     0L, false, LONG_TTL, serviceName)),
             reply.answers)
         assertEquals(listOf(
-                MdnsTextRecord(serviceName, 0L, true, LONG_TTL, listOf()),
+                MdnsTextRecord(serviceName, 0L, true, LONG_TTL, emptyList()),
                 MdnsServiceRecord(serviceName, 0L, true, SHORT_TTL, 0, 0, TEST_PORT, TEST_HOSTNAME),
                 MdnsInetAddressRecord(
                     TEST_HOSTNAME, 0L, true, SHORT_TTL, TEST_ADDRESSES[0].address),
@@ -682,7 +682,7 @@
         val reply = repository.getReply(query, src)
 
         assertNotNull(reply)
-        assertEquals(listOf(MdnsTextRecord(serviceName, 0L, true, LONG_TTL, listOf())),
+        assertEquals(listOf(MdnsTextRecord(serviceName, 0L, true, LONG_TTL, emptyList())),
                 reply.answers)
         // No NSEC records because the reply doesn't include the SRV record
         assertTrue(reply.additionalAnswers.isEmpty())
@@ -747,7 +747,7 @@
         assertNotNull(reply)
         assertEquals(listOf(
                 MdnsServiceRecord(serviceName, 0L, true, SHORT_TTL, 0, 0, TEST_PORT, TEST_HOSTNAME),
-                MdnsTextRecord(serviceName, 0L, true, LONG_TTL, listOf()),
+                MdnsTextRecord(serviceName, 0L, true, LONG_TTL, emptyList()),
                 MdnsInetAddressRecord(
                         TEST_HOSTNAME, 0L, true, SHORT_TTL, TEST_ADDRESSES[0].address),
                 MdnsInetAddressRecord(
@@ -915,8 +915,8 @@
 
         val questions = listOf(
                 MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"), false /* isUnicast */))
-        val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
-                listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+        val query = MdnsPacket(0 /* flags */, questions, emptyList() /* answers */,
+                emptyList() /* authorityRecords */, emptyList() /* additionalRecords */)
         val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
 
         // Reply to the question and verify there is one packet replied.
@@ -994,18 +994,17 @@
             questions: List<MdnsRecord>,
             knownAnswers: List<MdnsRecord>,
             replyAnswers: List<MdnsRecord>,
-            additionalAnswers: List<MdnsRecord>,
-            expectReply: Boolean
+            additionalAnswers: List<MdnsRecord>
     ) {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME,
             makeFlags(isKnownAnswerSuppressionEnabled = true))
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         val query = MdnsPacket(0 /* flags */, questions, knownAnswers,
-                listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+                emptyList() /* authorityRecords */, emptyList() /* additionalRecords */)
         val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
         val reply = repository.getReply(query, src)
 
-        if (!expectReply) {
+        if (replyAnswers.isEmpty() && additionalAnswers.isEmpty()) {
             assertNull(reply)
             return
         }
@@ -1016,6 +1015,7 @@
         assertEquals(MdnsConstants.MDNS_PORT, reply.destination.port)
         assertEquals(replyAnswers, reply.answers)
         assertEquals(additionalAnswers, reply.additionalAnswers)
+        assertEquals(knownAnswers, reply.knownAnswers)
     }
 
     @Test
@@ -1028,8 +1028,8 @@
                 false /* cacheFlush */,
                 LONG_TTL,
                 arrayOf("MyTestService", "_testservice", "_tcp", "local")))
-        doGetReplyWithAnswersTest(questions, knownAnswers, listOf() /* replyAnswers */,
-                listOf() /* additionalAnswers */, false /* expectReply */)
+        doGetReplyWithAnswersTest(questions, knownAnswers, emptyList() /* replyAnswers */,
+                emptyList() /* additionalAnswers */)
     }
 
     @Test
@@ -1055,7 +1055,7 @@
                         0L /* receiptTimeMillis */,
                         true /* cacheFlush */,
                         LONG_TTL,
-                        listOf() /* entries */),
+                        emptyList() /* entries */),
                 MdnsServiceRecord(
                         serviceName,
                         0L /* receiptTimeMillis */,
@@ -1097,8 +1097,7 @@
                         SHORT_TTL,
                         TEST_HOSTNAME /* nextDomain */,
                         intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA)))
-        doGetReplyWithAnswersTest(questions, knownAnswers, replyAnswers, additionalAnswers,
-                true /* expectReply */)
+        doGetReplyWithAnswersTest(questions, knownAnswers, replyAnswers, additionalAnswers)
     }
 
     @Test
@@ -1124,7 +1123,7 @@
                         0L /* receiptTimeMillis */,
                         true /* cacheFlush */,
                         LONG_TTL,
-                        listOf() /* entries */),
+                        emptyList() /* entries */),
                 MdnsServiceRecord(
                         serviceName,
                         0L /* receiptTimeMillis */,
@@ -1166,8 +1165,7 @@
                         SHORT_TTL,
                         TEST_HOSTNAME /* nextDomain */,
                         intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA)))
-        doGetReplyWithAnswersTest(questions, knownAnswers, replyAnswers, additionalAnswers,
-                true /* expectReply */)
+        doGetReplyWithAnswersTest(questions, knownAnswers, replyAnswers, additionalAnswers)
     }
 
     @Test
@@ -1218,8 +1216,7 @@
                         SHORT_TTL,
                         TEST_HOSTNAME /* nextDomain */,
                         intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA)))
-        doGetReplyWithAnswersTest(questions, knownAnswers, replyAnswers, additionalAnswers,
-                true /* expectReply */)
+        doGetReplyWithAnswersTest(questions, knownAnswers, replyAnswers, additionalAnswers)
     }
 
     @Test
@@ -1248,10 +1245,8 @@
                 TEST_HOSTNAME
             )
         )
-        doGetReplyWithAnswersTest(
-            questions, knownAnswers, listOf() /* replyAnswers */,
-            listOf() /* additionalAnswers */, false /* expectReply */
-        )
+        doGetReplyWithAnswersTest(questions, knownAnswers, emptyList() /* replyAnswers */,
+                emptyList() /* additionalAnswers */)
     }
 
     @Test
@@ -1263,8 +1258,8 @@
         val questions = listOf(
             MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"), true /* isUnicast */),
             MdnsPointerRecord(arrayOf("_otherservice", "_tcp", "local"), true /* isUnicast */))
-        val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
-            listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+        val query = MdnsPacket(0 /* flags */, questions, emptyList() /* answers */,
+                emptyList() /* authorityRecords */, emptyList() /* additionalRecords */)
         val src = InetSocketAddress(parseNumericAddress("2001:db8::123"), 5353)
 
         // Reply to the question and verify it is sent to the source.
@@ -1287,8 +1282,8 @@
         val questions = listOf(
             MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"), true /* isUnicast */),
             MdnsPointerRecord(arrayOf("_otherservice", "_tcp", "local"), false /* isUnicast */))
-        val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
-            listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+        val query = MdnsPacket(0 /* flags */, questions, emptyList() /* answers */,
+                emptyList() /* authorityRecords */, emptyList() /* additionalRecords */)
         val src = InetSocketAddress(parseNumericAddress("2001:db8::123"), 5353)
 
         // Reply to the question and verify it is sent multicast.
@@ -1306,8 +1301,8 @@
         val questions = listOf(
             MdnsPointerRecord(arrayOf("_otherservice", "_tcp", "local"), true /* isUnicast */),
             MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"), false /* isUnicast */))
-        val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
-            listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+        val query = MdnsPacket(0 /* flags */, questions, emptyList() /* answers */,
+                emptyList() /* authorityRecords */, emptyList() /* additionalRecords */)
         val src = InetSocketAddress(parseNumericAddress("2001:db8::123"), 5353)
 
         // Reply to the question and verify it is sent multicast.
@@ -1325,8 +1320,8 @@
         // The service is known and requests unicast reply, but the feature is disabled
         val questions = listOf(
             MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"), true /* isUnicast */))
-        val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
-            listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+        val query = MdnsPacket(0 /* flags */, questions, emptyList() /* answers */,
+                emptyList() /* authorityRecords */, emptyList() /* additionalRecords */)
         val src = InetSocketAddress(parseNumericAddress("2001:db8::123"), 5353)
 
         // Reply to the question and verify it is sent multicast.
@@ -1334,6 +1329,28 @@
         assertNotNull(reply)
         assertEquals(MdnsConstants.getMdnsIPv6Address(), reply.destination.address)
     }
+
+    @Test
+    fun testGetReply_OnlyKnownAnswers() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME,
+                makeFlags(isKnownAnswerSuppressionEnabled = true))
+        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        val knownAnswers = listOf(MdnsPointerRecord(
+                arrayOf("_testservice", "_tcp", "local"),
+                0L /* receiptTimeMillis */,
+                false /* cacheFlush */,
+                LONG_TTL - 1000L,
+                arrayOf("MyTestService", "_testservice", "_tcp", "local")))
+        val query = MdnsPacket(MdnsConstants.FLAG_TRUNCATED /* flags */, emptyList(),
+                knownAnswers, emptyList() /* authorityRecords */,
+                emptyList() /* additionalRecords */)
+        val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
+        val reply = repository.getReply(query, src)
+        assertNotNull(reply)
+        assertEquals(0, reply.answers.size)
+        assertEquals(0, reply.additionalAnswers.size)
+        assertEquals(knownAnswers, reply.knownAnswers)
+    }
 }
 
 private fun MdnsRecordRepository.initWithService(
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsReplySenderTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsReplySenderTest.kt
index 9e2933f..9bd0530 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsReplySenderTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsReplySenderTest.kt
@@ -24,21 +24,28 @@
 import android.os.Message
 import com.android.net.module.util.SharedLog
 import com.android.server.connectivity.mdns.MdnsConstants.IPV4_SOCKET_ADDR
+import com.android.server.connectivity.mdns.MdnsConstants.IPV6_SOCKET_ADDR
+import com.android.server.connectivity.mdns.MdnsReplySender.getReplyDestination
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
 import com.android.testutils.DevSdkIgnoreRunner
+import java.net.DatagramPacket
 import java.net.InetSocketAddress
 import java.util.concurrent.CompletableFuture
 import java.util.concurrent.TimeUnit
+import kotlin.test.assertEquals
 import org.junit.After
 import org.junit.Before
 import org.junit.Test
 import org.junit.runner.RunWith
 import org.mockito.ArgumentCaptor
+import org.mockito.Mockito.any
+import org.mockito.Mockito.anyLong
 import org.mockito.Mockito.argThat
 import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.eq
 import org.mockito.Mockito.mock
 import org.mockito.Mockito.timeout
+import org.mockito.Mockito.times
 import org.mockito.Mockito.verify
 
 private const val TEST_PORT = 12345
@@ -50,8 +57,12 @@
 @IgnoreUpTo(Build.VERSION_CODES.S_V2)
 class MdnsReplySenderTest {
     private val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
+    private val otherServiceName = arrayOf("OtherTestService", "_testservice", "_tcp", "local")
     private val serviceType = arrayOf("_testservice", "_tcp", "local")
+    private val source = InetSocketAddress(
+            InetAddresses.parseNumericAddress("192.0.2.1"), TEST_PORT)
     private val hostname = arrayOf("Android_000102030405060708090A0B0C0D0E0F", "local")
+    private val otherHostname = arrayOf("Android_0F0E0D0C0B0A09080706050403020100", "local")
     private val hostAddresses = listOf(
             LinkAddress(InetAddresses.parseNumericAddress("192.0.2.111"), 24),
             LinkAddress(InetAddresses.parseNumericAddress("2001:db8::111"), 64),
@@ -59,9 +70,12 @@
     private val answers = listOf(
             MdnsPointerRecord(serviceType, 0L /* receiptTimeMillis */, false /* cacheFlush */,
                     LONG_TTL, serviceName))
+    private val otherAnswers = listOf(
+            MdnsPointerRecord(serviceType, 0L /* receiptTimeMillis */, false /* cacheFlush */,
+                    LONG_TTL, otherServiceName))
     private val additionalAnswers = listOf(
             MdnsTextRecord(serviceName, 0L /* receiptTimeMillis */, true /* cacheFlush */, LONG_TTL,
-                    listOf() /* entries */),
+                    emptyList() /* entries */),
             MdnsServiceRecord(serviceName, 0L /* receiptTimeMillis */, true /* cacheFlush */,
                     SHORT_TTL, 0 /* servicePriority */, 0 /* serviceWeight */, TEST_PORT, hostname),
             MdnsInetAddressRecord(hostname, 0L /* receiptTimeMillis */, true /* cacheFlush */,
@@ -75,15 +89,30 @@
                     intArrayOf(MdnsRecord.TYPE_TXT, MdnsRecord.TYPE_SRV)),
             MdnsNsecRecord(hostname, 0L /* receiptTimeMillis */, true /* cacheFlush */, SHORT_TTL,
                     hostname /* nextDomain */, intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA)))
+    private val otherAdditionalAnswers = listOf(
+            MdnsTextRecord(otherServiceName, 0L /* receiptTimeMillis */, true /* cacheFlush */,
+                    LONG_TTL, emptyList() /* entries */),
+            MdnsServiceRecord(otherServiceName, 0L /* receiptTimeMillis */, true /* cacheFlush */,
+                    SHORT_TTL, 0 /* servicePriority */, 0 /* serviceWeight */, TEST_PORT,
+                    otherHostname),
+            MdnsInetAddressRecord(otherHostname, 0L /* receiptTimeMillis */, true /* cacheFlush */,
+                    SHORT_TTL, hostAddresses[0].address),
+            MdnsInetAddressRecord(otherHostname, 0L /* receiptTimeMillis */, true /* cacheFlush */,
+                    SHORT_TTL, hostAddresses[1].address),
+            MdnsInetAddressRecord(otherHostname, 0L /* receiptTimeMillis */, true /* cacheFlush */,
+                    SHORT_TTL, hostAddresses[2].address),
+            MdnsNsecRecord(otherServiceName, 0L /* receiptTimeMillis */, true /* cacheFlush */,
+                    LONG_TTL, otherServiceName /* nextDomain */,
+                    intArrayOf(MdnsRecord.TYPE_TXT, MdnsRecord.TYPE_SRV)),
+            MdnsNsecRecord(otherHostname, 0L /* receiptTimeMillis */, true /* cacheFlush */,
+                    SHORT_TTL, otherHostname /* nextDomain */,
+                    intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA)))
     private val thread = HandlerThread(MdnsReplySenderTest::class.simpleName)
     private val socket = mock(MdnsInterfaceSocket::class.java)
     private val buffer = ByteArray(1500)
     private val sharedLog = SharedLog(MdnsReplySenderTest::class.simpleName)
     private val deps = mock(MdnsReplySender.Dependencies::class.java)
     private val handler by lazy { Handler(thread.looper) }
-    private val replySender by lazy {
-        MdnsReplySender(thread.looper, socket, buffer, sharedLog, false /* enableDebugLog */, deps)
-    }
 
     @Before
     fun setUp() {
@@ -106,37 +135,180 @@
         return future.get(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS)
     }
 
-    private fun sendNow(packet: MdnsPacket, destination: InetSocketAddress):
-            Unit = runningOnHandlerAndReturn { replySender.sendNow(packet, destination) }
+    private fun sendNow(sender: MdnsReplySender, packet: MdnsPacket, dest: InetSocketAddress):
+            Unit = runningOnHandlerAndReturn { sender.sendNow(packet, dest) }
 
-    private fun queueReply(reply: MdnsReplyInfo):
-            Unit = runningOnHandlerAndReturn { replySender.queueReply(reply) }
+    private fun queueReply(sender: MdnsReplySender, reply: MdnsReplyInfo):
+            Unit = runningOnHandlerAndReturn { sender.queueReply(reply) }
+
+    private fun buildFlags(enableKAS: Boolean): MdnsFeatureFlags {
+        return MdnsFeatureFlags.newBuilder()
+                .setIsKnownAnswerSuppressionEnabled(enableKAS).build()
+    }
+
+    private fun createSender(enableKAS: Boolean): MdnsReplySender =
+            MdnsReplySender(thread.looper, socket, buffer, sharedLog, false /* enableDebugLog */,
+                    deps, buildFlags(enableKAS))
 
     @Test
     fun testSendNow() {
+        val replySender = createSender(enableKAS = false)
         val packet = MdnsPacket(0x8400,
-                listOf() /* questions */,
+                emptyList() /* questions */,
                 answers,
-                listOf() /* authorityRecords */,
+                emptyList() /* authorityRecords */,
                 additionalAnswers)
-        sendNow(packet, IPV4_SOCKET_ADDR)
+        sendNow(replySender, packet, IPV4_SOCKET_ADDR)
         verify(socket).send(argThat{ it.socketAddress.equals(IPV4_SOCKET_ADDR) })
     }
 
+    private fun verifyMessageQueued(
+            sender: MdnsReplySender,
+            replies: List<MdnsReplyInfo>
+    ): Pair<Handler, Message> {
+        val handlerCaptor = ArgumentCaptor.forClass(Handler::class.java)
+        val messageCaptor = ArgumentCaptor.forClass(Message::class.java)
+        for (reply in replies) {
+            queueReply(sender, reply)
+            verify(deps).sendMessageDelayed(
+                    handlerCaptor.capture(), messageCaptor.capture(), eq(reply.sendDelayMs))
+        }
+        return Pair(handlerCaptor.value, messageCaptor.value)
+    }
+
+    private fun verifyReplySent(
+            realHandler: Handler,
+            delayMessage: Message,
+            remainingAnswers: List<MdnsRecord>
+    ) {
+        val datagramPacketCaptor = ArgumentCaptor.forClass(DatagramPacket::class.java)
+        realHandler.sendMessage(delayMessage)
+        verify(socket, timeout(DEFAULT_TIMEOUT_MS)).send(datagramPacketCaptor.capture())
+
+        val dPacket = datagramPacketCaptor.value
+        val mdnsPacket = MdnsPacket.parse(MdnsPacketReader(
+                dPacket.data, dPacket.length, buildFlags(enableKAS = false)))
+        assertEquals(mdnsPacket.answers.toSet(), remainingAnswers.toSet())
+    }
+
     @Test
     fun testQueueReply() {
+        val replySender = createSender(enableKAS = false)
         val reply = MdnsReplyInfo(answers, additionalAnswers, 20L /* sendDelayMs */,
-                IPV4_SOCKET_ADDR)
-        val handlerCaptor = ArgumentCaptor.forClass(Handler::class.java)
-        val messageCaptor = ArgumentCaptor.forClass(Message::class.java)
-        queueReply(reply)
-        verify(deps).sendMessageDelayed(handlerCaptor.capture(), messageCaptor.capture(), eq(20L))
+                IPV4_SOCKET_ADDR, source, emptyList())
+        val (handler, message) = verifyMessageQueued(replySender, listOf(reply))
+        verifyReplySent(handler, message, answers)
+    }
 
-        val realHandler = handlerCaptor.value
-        val delayMessage = messageCaptor.value
-        realHandler.sendMessage(delayMessage)
-        verify(socket, timeout(DEFAULT_TIMEOUT_MS)).send(argThat{
-            it.socketAddress.equals(IPV4_SOCKET_ADDR)
-        })
+    @Test
+    fun testQueueReply_KnownAnswerSuppressionEnabled() {
+        val replySender = createSender(enableKAS = true)
+        val reply = MdnsReplyInfo(answers, additionalAnswers, 20L /* sendDelayMs */,
+                IPV4_SOCKET_ADDR, source, emptyList())
+        val (handler, message) = verifyMessageQueued(replySender, listOf(reply))
+        verifyReplySent(handler, message, answers)
+    }
+
+    @Test
+    fun testQueueReply_MultiplePacket() {
+        val replySender = createSender(enableKAS = true)
+        val reply = MdnsReplyInfo(answers, additionalAnswers, 400L /* sendDelayMs */,
+                IPV4_SOCKET_ADDR, source, emptyList())
+        verifyMessageQueued(replySender, listOf(reply))
+
+        // Receive a known-answer packet and verify no message queued.
+        val knownAnswersReply = MdnsReplyInfo(emptyList(), emptyList(), 0L /* sendDelayMs */,
+                IPV4_SOCKET_ADDR, source, answers)
+        queueReply(replySender, knownAnswersReply)
+        verify(deps, times(1)).sendMessageDelayed(any(), any(), anyLong())
+    }
+
+    @Test
+    fun testQueueReply_MultiplePacket_LostSubsequentPacket() {
+        val replySender = createSender(enableKAS = true)
+        val reply = MdnsReplyInfo(answers, additionalAnswers, 400L /* sendDelayMs */,
+                IPV4_SOCKET_ADDR, source, emptyList())
+        val (handler, message) = verifyMessageQueued(replySender, listOf(reply))
+
+        // No subsequent packets
+        verifyReplySent(handler, message, answers)
+    }
+
+    @Test
+    fun testQueueReply_MultiplePacket_OtherKnownAnswer() {
+        val replySender = createSender(enableKAS = true)
+        val reply = MdnsReplyInfo(answers, additionalAnswers, 400L /* sendDelayMs */,
+                IPV4_SOCKET_ADDR, source, emptyList())
+        // Other known-answer service
+        val otherKnownAnswersReply = MdnsReplyInfo(emptyList(), emptyList(), 0L /* sendDelayMs */,
+                IPV4_SOCKET_ADDR, source, otherAnswers)
+        val (handler, message) = verifyMessageQueued(
+                replySender, listOf(reply, otherKnownAnswersReply))
+        verifyReplySent(handler, message, answers)
+    }
+
+    @Test
+    fun testQueueReply_MultiplePacket_TwoKnownAnswerPackets() {
+        val replySender = createSender(enableKAS = true)
+        val reply = MdnsReplyInfo(answers, additionalAnswers, 400L /* sendDelayMs */,
+                IPV4_SOCKET_ADDR, source, emptyList())
+        val firstKnownAnswerReply = MdnsReplyInfo(emptyList(), emptyList(), 401L /* sendDelayMs */,
+                IPV4_SOCKET_ADDR, source, otherAnswers)
+        verifyMessageQueued(replySender, listOf(reply, firstKnownAnswerReply))
+
+        // Second known-answer service
+        val secondKnownAnswerReply = MdnsReplyInfo(emptyList(), emptyList(), 0L /* sendDelayMs */,
+                IPV4_SOCKET_ADDR, source, answers)
+        queueReply(replySender, secondKnownAnswerReply)
+
+        // Verify that no reply is queued, as all answers are known.
+        verify(deps, times(2)).sendMessageDelayed(any(), any(), anyLong())
+    }
+
+    @Test
+    fun testQueueReply_MultiplePacket_LostSecondaryPacket() {
+        val replySender = createSender(enableKAS = true)
+        val reply = MdnsReplyInfo(answers, additionalAnswers, 400L /* sendDelayMs */,
+                IPV4_SOCKET_ADDR, source, emptyList())
+        val firstKnownAnswerReply = MdnsReplyInfo(emptyList(), emptyList(), 401L /* sendDelayMs */,
+                IPV4_SOCKET_ADDR, source, otherAnswers)
+        val (handler, message) = verifyMessageQueued(
+                replySender, listOf(reply, firstKnownAnswerReply))
+
+        // Second known-answer service lost
+        verifyReplySent(handler, message, answers)
+    }
+
+    @Test
+    fun testQueueReply_MultiplePacket_WithMultipleQuestions() {
+        val replySender = createSender(enableKAS = true)
+        val twoAnswers = listOf(
+                MdnsPointerRecord(serviceType, 0L /* receiptTimeMillis */, false /* cacheFlush */,
+                        LONG_TTL, serviceName),
+                MdnsServiceRecord(otherServiceName, 0L /* receiptTimeMillis */,
+                        true /* cacheFlush */, SHORT_TTL, 0 /* servicePriority */,
+                        0 /* serviceWeight */, TEST_PORT, otherHostname))
+        val reply = MdnsReplyInfo(twoAnswers, additionalAnswers, 400L /* sendDelayMs */,
+                IPV4_SOCKET_ADDR, source, emptyList())
+        val knownAnswersReply = MdnsReplyInfo(otherAnswers, otherAdditionalAnswers,
+                20L /* sendDelayMs */, IPV4_SOCKET_ADDR, source, answers)
+        val (handler, message) = verifyMessageQueued(replySender, listOf(reply, knownAnswersReply))
+
+        val remainingAnswers = listOf(
+                MdnsPointerRecord(serviceType, 0L /* receiptTimeMillis */, false /* cacheFlush */,
+                        LONG_TTL, otherServiceName),
+                MdnsServiceRecord(otherServiceName, 0L /* receiptTimeMillis */,
+                        true /* cacheFlush */, SHORT_TTL, 0 /* servicePriority */,
+                        0 /* serviceWeight */, TEST_PORT, otherHostname))
+        verifyReplySent(handler, message, remainingAnswers)
+    }
+
+    @Test
+    fun testGetReplyDestination() {
+        assertEquals(IPV4_SOCKET_ADDR, getReplyDestination(IPV4_SOCKET_ADDR, IPV4_SOCKET_ADDR))
+        assertEquals(IPV6_SOCKET_ADDR, getReplyDestination(IPV6_SOCKET_ADDR, IPV6_SOCKET_ADDR))
+        assertEquals(IPV4_SOCKET_ADDR, getReplyDestination(source, IPV4_SOCKET_ADDR))
+        assertEquals(IPV6_SOCKET_ADDR, getReplyDestination(source, IPV6_SOCKET_ADDR))
+        assertEquals(source, getReplyDestination(source, source))
     }
 }