Merge "[mdns] separate the multicast reply quota for IPv4 and IPv6" into main
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
index 073e465..5ff19ca 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
@@ -140,17 +140,25 @@
         public final boolean isSharedName;
 
         /**
-         * Last time (as per SystemClock.elapsedRealtime) when advertised via multicast, 0 if never
+         * Last time (as per SystemClock.elapsedRealtime) when advertised via multicast on IPv4, 0
+         * if never.
          */
-        public long lastAdvertisedTimeMs;
+        public long lastAdvertisedOnIpv4TimeMs;
 
         /**
-         * Last time (as per SystemClock.elapsedRealtime) when sent via unicast or multicast,
-         * 0 if never
+         * Last time (as per SystemClock.elapsedRealtime) when advertised via multicast on IPv6, 0
+         * if never.
          */
-        // FIXME: the `lastSentTimeMs` and `lastAdvertisedTimeMs` should be maintained separately
-        // for IPv4 and IPv6, because neither IPv4 nor and IPv6 clients can receive replies in
-        // different address space.
+        public long lastAdvertisedOnIpv6TimeMs;
+
+        /**
+         * Last time (as per SystemClock.elapsedRealtime) when sent via unicast or multicast, 0 if
+         * never.
+         *
+         * <p>Unlike lastAdvertisedOnIpv4TimeMs and lastAdvertisedOnIpv6TimeMs, lastSentTimeMs
+         * mainly tracks whether a record has ever been sent at all, regardless of unicast vs.
+         * multicast or IPv4 vs. IPv6, so separate IPv4 and IPv6 versions are unnecessary.
+         */
         public long lastSentTimeMs;
 
         RecordInfo(NsdServiceInfo serviceInfo, T record, boolean sharedName) {
@@ -578,6 +586,7 @@
     @Nullable
     public MdnsReplyInfo getReply(MdnsPacket packet, InetSocketAddress src) {
         final long now = SystemClock.elapsedRealtime();
+        final boolean isQuestionOnIpv4 = src.getAddress() instanceof Inet4Address;
 
         // TODO: b/322142420 - Set<RecordInfo<?>> may contain duplicate records wrapped in different
         // RecordInfo<?>s when custom host is enabled.
@@ -595,7 +604,7 @@
                     null /* serviceSrvRecord */, null /* serviceTxtRecord */,
                     null /* hostname */,
                     replyUnicastEnabled, now, answerInfo, additionalAnswerInfo,
-                    Collections.emptyList())) {
+                    Collections.emptyList(), isQuestionOnIpv4)) {
                 replyUnicast &= question.isUnicastReplyRequested();
             }
 
@@ -607,7 +616,7 @@
                         registration.srvRecord, registration.txtRecord,
                         registration.serviceInfo.getHostname(),
                         replyUnicastEnabled, now,
-                        answerInfo, additionalAnswerInfo, packet.answers)) {
+                        answerInfo, additionalAnswerInfo, packet.answers, isQuestionOnIpv4)) {
                     replyUnicast &= question.isUnicastReplyRequested();
                     registration.repliedServiceCount++;
                     registration.sentPacketCount++;
@@ -685,7 +694,7 @@
             // multicast responses. Unicast replies are faster as they do not need to wait for the
             // beacon interval on Wi-Fi.
             dest = src;
-        } else if (src.getAddress() instanceof Inet4Address) {
+        } else if (isQuestionOnIpv4) {
             dest = IPV4_SOCKET_ADDR;
         } else {
             dest = IPV6_SOCKET_ADDR;
@@ -697,7 +706,11 @@
             // TODO: consider actual packet send delay after response aggregation
             info.lastSentTimeMs = now + delayMs;
             if (!replyUnicast) {
-                info.lastAdvertisedTimeMs = info.lastSentTimeMs;
+                if (isQuestionOnIpv4) {
+                    info.lastAdvertisedOnIpv4TimeMs = info.lastSentTimeMs;
+                } else {
+                    info.lastAdvertisedOnIpv6TimeMs = info.lastSentTimeMs;
+                }
             }
             // Different RecordInfos may the contain the same record
             if (!answerRecords.contains(info.record)) {
@@ -729,7 +742,8 @@
             @Nullable String hostname,
             boolean replyUnicastEnabled, long now, @NonNull Set<RecordInfo<?>> answerInfo,
             @NonNull Set<RecordInfo<?>> additionalAnswerInfo,
-            @NonNull List<MdnsRecord> knownAnswerRecords) {
+            @NonNull List<MdnsRecord> knownAnswerRecords,
+            boolean isQuestionOnIpv4) {
         boolean hasDnsSdPtrRecordAnswer = false;
         boolean hasDnsSdSrvRecordAnswer = false;
         boolean hasFullyOwnedNameMatch = false;
@@ -778,10 +792,20 @@
 
             // TODO: responses to probe queries should bypass this check and only ensure the
             // reply is sent 250ms after the last sent time (RFC 6762 p.15)
-            if (!(replyUnicastEnabled && question.isUnicastReplyRequested())
-                    && info.lastAdvertisedTimeMs > 0L
-                    && now - info.lastAdvertisedTimeMs < MIN_MULTICAST_REPLY_INTERVAL_MS) {
-                continue;
+            if (!(replyUnicastEnabled && question.isUnicastReplyRequested())) {
+                if (isQuestionOnIpv4) { // IPv4
+                    if (info.lastAdvertisedOnIpv4TimeMs > 0L
+                            && now - info.lastAdvertisedOnIpv4TimeMs
+                                    < MIN_MULTICAST_REPLY_INTERVAL_MS) {
+                        continue;
+                    }
+                } else { // IPv6
+                    if (info.lastAdvertisedOnIpv6TimeMs > 0L
+                            && now - info.lastAdvertisedOnIpv6TimeMs
+                                    < MIN_MULTICAST_REPLY_INTERVAL_MS) {
+                        continue;
+                    }
+                }
             }
 
             answerInfo.add(info);
@@ -1305,7 +1329,8 @@
         final long now = SystemClock.elapsedRealtime();
         for (RecordInfo<?> record : registration.allRecords) {
             record.lastSentTimeMs = now;
-            record.lastAdvertisedTimeMs = now;
+            record.lastAdvertisedOnIpv4TimeMs = now;
+            record.lastAdvertisedOnIpv6TimeMs = now;
         }
         registration.sentPacketCount += sentPacketCount;
     }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index 271cc65..bdefbf5 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -1003,6 +1003,63 @@
     }
 
     @Test
+    fun testGetReply_ipv4AndIpv6Queries_ipv4AndIpv6Replies() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
+        val query = makeQuery(TYPE_PTR to arrayOf("_testservice", "_tcp", "local"))
+
+        // Multicast reply quotas are tracked per address family, so the multicast reply sent
+        // on IPv4 must not throttle the reply to the same query received on IPv6.
+        val srcIpv4 = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
+        val replyIpv4 = repository.getReply(query, srcIpv4)
+        val srcIpv6 = InetSocketAddress(parseNumericAddress("2001:db8::123"), 5353)
+        val replyIpv6 = repository.getReply(query, srcIpv6)
+
+        assertNotNull(replyIpv4)
+        assertEquals(MdnsConstants.getMdnsIPv4Address(), replyIpv4.destination.address)
+        assertEquals(MdnsConstants.MDNS_PORT, replyIpv4.destination.port)
+        assertNotNull(replyIpv6)
+        assertEquals(MdnsConstants.getMdnsIPv6Address(), replyIpv6.destination.address)
+        assertEquals(MdnsConstants.MDNS_PORT, replyIpv6.destination.port)
+    }
+
+    @Test
+    fun testGetReply_twoIpv4Queries_theSecondReplyIsThrottled() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
+        val query = makeQuery(TYPE_PTR to arrayOf("_testservice", "_tcp", "local"))
+
+        val srcIpv4 = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
+        val firstReplyIpv4 = repository.getReply(query, srcIpv4)
+        // The first reply was multicast on IPv4, so the identical follow-up IPv4 query is
+        // throttled by the IPv4 multicast reply interval.
+        val secondReplyIpv4 = repository.getReply(query, srcIpv4)
+
+        assertNotNull(firstReplyIpv4)
+        assertEquals(MdnsConstants.getMdnsIPv4Address(), firstReplyIpv4.destination.address)
+        assertEquals(MdnsConstants.MDNS_PORT, firstReplyIpv4.destination.port)
+        assertNull(secondReplyIpv4)
+    }
+
+    @Test
+    fun testGetReply_twoIpv6Queries_theSecondReplyIsThrottled() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
+        val query = makeQuery(TYPE_PTR to arrayOf("_testservice", "_tcp", "local"))
+
+        val srcIpv6 = InetSocketAddress(parseNumericAddress("2001:db8::123"), 5353)
+        val firstReplyIpv6 = repository.getReply(query, srcIpv6)
+        // The first reply was multicast on IPv6, so the identical follow-up IPv6 query is
+        // throttled by the IPv6 multicast reply interval.
+        val secondReplyIpv6 = repository.getReply(query, srcIpv6)
+
+        assertNotNull(firstReplyIpv6)
+        assertEquals(MdnsConstants.getMdnsIPv6Address(), firstReplyIpv6.destination.address)
+        assertEquals(MdnsConstants.MDNS_PORT, firstReplyIpv6.destination.port)
+        assertNull(secondReplyIpv6)
+    }
+
+    @Test
     fun testGetConflictingServices() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* ttl */)