Merge "[mdns] Skip conflict check for incoming mDNS answer records which are probed in the repository" into main
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
index 39e8bcc..36f3982 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
@@ -34,6 +34,7 @@
 import android.util.ArrayMap;
 import android.util.ArraySet;
 import android.util.SparseArray;
+import android.util.SparseIntArray;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.net.module.util.CollectionUtils;
@@ -87,6 +88,12 @@
     private static final String[] DNS_SD_SERVICE_TYPE =
             new String[] { "_services", "_dns-sd", "_udp", LOCAL_TLD };
 
+    private enum RecordConflictType {
+        NO_CONFLICT, // The record does not conflict with the registration's records.
+        CONFLICT, // The record conflicts with the registration's records.
+        IDENTICAL // Same data as a record of the registration: not a conflict (RFC6762 9.).
+    }
+
     @NonNull
     private final Random mDelayGenerator = new Random();
     // Map of service unique ID -> records for service
@@ -1172,38 +1179,71 @@
      * {@link MdnsInterfaceAdvertiser#CONFLICT_HOST}.
+     *
+     * <p>An answer record that is identical to a record of a non-exiting registration (as per
+     * RFC6762 9.) is not reported as a conflict for any registration.
      */
     public Map<Integer, Integer> getConflictingServices(MdnsPacket packet) {
-        // Avoid allocating a new set for each incoming packet: use an empty set by default.
+        // Avoid allocating a new map for each incoming packet: use an empty map by default.
         Map<Integer, Integer> conflicting = Collections.emptyMap();
+        // Conflicts caused by the current answer record. Reused (cleared) for every record to
+        // avoid an allocation per record.
+        final SparseIntArray conflictingWithRecord = new SparseIntArray();
         for (MdnsRecord record : packet.answers) {
+            conflictingWithRecord.clear();
             for (int i = 0; i < mServices.size(); i++) {
                 final ServiceRegistration registration = mServices.valueAt(i);
                 if (registration.exiting) continue;
 
-                int conflictType = 0;
+                final RecordConflictType conflictForService =
+                        conflictForService(record, registration);
+                final RecordConflictType conflictForHost = conflictForHost(record, registration);
 
-                if (conflictForService(record, registration)) {
-                    conflictType |= CONFLICT_SERVICE;
+                // Identical record is found in the repository so there won't be a conflict with
+                // any registration: drop the conflicts already found for this record.
+                if (conflictForService == RecordConflictType.IDENTICAL
+                        || conflictForHost == RecordConflictType.IDENTICAL) {
+                    conflictingWithRecord.clear();
+                    break;
                 }
 
-                if (conflictForHost(record, registration)) {
+                int conflictType = 0;
+                if (conflictForService == RecordConflictType.CONFLICT) {
+                    conflictType |= CONFLICT_SERVICE;
+                }
+                if (conflictForHost == RecordConflictType.CONFLICT) {
                     conflictType |= CONFLICT_HOST;
                 }
 
                 if (conflictType != 0) {
-                    if (conflicting.isEmpty()) {
-                        // Conflict was found: use a mutable set
-                        conflicting = new ArrayMap<>();
-                    }
                     final int serviceId = mServices.keyAt(i);
-                    conflicting.put(serviceId, conflictType);
+                    conflictingWithRecord.put(serviceId, conflictType);
                 }
             }
+            // Merge this record's conflicts into the result, OR-ing with conflict types already
+            // found for the same service from previous records.
+            for (int i = 0; i < conflictingWithRecord.size(); i++) {
+                if (conflicting.isEmpty()) {
+                    // Conflict was found: use a mutable map
+                    conflicting = new ArrayMap<>();
+                }
+                final int serviceId = conflictingWithRecord.keyAt(i);
+                final int conflictType = conflictingWithRecord.valueAt(i);
+                final int oldConflictType = conflicting.getOrDefault(serviceId, 0);
+                conflicting.put(serviceId, oldConflictType | conflictType);
+            }
         }
 
         return conflicting;
     }
 
-    private static boolean conflictForService(
+    /**
+     * Classifies {@code record} against the service records of {@code registration}.
+     *
+     * @return {@link RecordConflictType#NO_CONFLICT} if the registration has no SRV or service
+     *     KEY record, or if the record name differs (ignoring DNS case) from the service name;
+     *     {@link RecordConflictType#IDENTICAL} if the record equals the registration's SRV, TXT
+     *     or service KEY record of the same type; {@link RecordConflictType#CONFLICT} otherwise.
+     */
+    private static RecordConflictType conflictForService(
             @NonNull MdnsRecord record, @NonNull ServiceRegistration registration) {
         String[] fullServiceName;
         if (registration.srvRecord != null) {
@@ -1211,75 +1251,86 @@
         } else if (registration.serviceKeyRecord != null) {
             fullServiceName = registration.serviceKeyRecord.record.getName();
         } else {
-            return false;
+            return RecordConflictType.NO_CONFLICT;
         }
 
         if (!MdnsUtils.equalsDnsLabelIgnoreDnsCase(record.getName(), fullServiceName)) {
-            return false;
+            return RecordConflictType.NO_CONFLICT;
         }
 
         // As per RFC6762 9., it's fine if the "conflict" is an identical record with same
         // data.
         if (record instanceof MdnsServiceRecord && equals(record, registration.srvRecord)) {
-            return false;
+            return RecordConflictType.IDENTICAL;
         }
         if (record instanceof MdnsTextRecord && equals(record, registration.txtRecord)) {
-            return false;
+            return RecordConflictType.IDENTICAL;
         }
         if (record instanceof MdnsKeyRecord && equals(record, registration.serviceKeyRecord)) {
-            return false;
+            return RecordConflictType.IDENTICAL;
         }
 
-        return true;
+        return RecordConflictType.CONFLICT;
     }
 
-    private boolean conflictForHost(
+    /**
+     * Classifies {@code record} against the custom host records (address and host KEY records)
+     * of {@code registration}. Registrations without a custom hostname are never checked.
+     *
+     * @return {@link RecordConflictType#IDENTICAL} if the record is an address or KEY record for
+     *     the registration's hostname with the same data as a record the registration owns;
+     *     {@link RecordConflictType#CONFLICT} if the record is for that hostname and either the
+     *     registration is probing or it owns a record of the same kind (address or host KEY);
+     *     {@link RecordConflictType#NO_CONFLICT} otherwise.
+     */
+    private RecordConflictType conflictForHost(
             @NonNull MdnsRecord record, @NonNull ServiceRegistration registration) {
         // Only custom hosts are checked. When using the default host, the hostname is derived from
         // a UUID and it's supposed to be unique.
         if (registration.serviceInfo.getHostname() == null) {
-            return false;
+            return RecordConflictType.NO_CONFLICT;
         }
 
-        // It cannot be a hostname conflict because not record is registered with the hostname.
+        // It cannot be a hostname conflict because no record is registered with the hostname.
         if (registration.addressRecords.isEmpty() && registration.hostKeyRecord == null) {
-            return false;
+            return RecordConflictType.NO_CONFLICT;
         }
 
         // The record's name cannot be registered by NsdManager so it's not a conflict.
-        if (record.getName().length != 2 || !record.getName()[1].equals(LOCAL_TLD)) {
-            return false;
+        if (record.getName().length != 2
+                || !MdnsUtils.equalsIgnoreDnsCase(record.getName()[1], LOCAL_TLD)) {
+            return RecordConflictType.NO_CONFLICT;
         }
 
         // Different names. There won't be a conflict.
         if (!MdnsUtils.equalsIgnoreDnsCase(
                 record.getName()[0], registration.serviceInfo.getHostname())) {
-            return false;
+            return RecordConflictType.NO_CONFLICT;
         }
 
         // As per RFC6762 9., it's fine if the "conflict" is an identical record with same
         // data.
         if (record instanceof MdnsInetAddressRecord
                 && hasInetAddressRecord(registration, (MdnsInetAddressRecord) record)) {
-            return false;
+            return RecordConflictType.IDENTICAL;
         }
         if (record instanceof MdnsKeyRecord && equals(record, registration.hostKeyRecord)) {
-            return false;
+            return RecordConflictType.IDENTICAL;
         }
 
         // Per RFC 6762 8.1, when a record is being probed, any answer containing a record with that
         // name, of any type, MUST be considered a conflicting response.
         if (registration.isProbing) {
-            return true;
+            return RecordConflictType.CONFLICT;
         }
         if (record instanceof MdnsInetAddressRecord && !registration.addressRecords.isEmpty()) {
-            return true;
+            return RecordConflictType.CONFLICT;
         }
         if (record instanceof MdnsKeyRecord && registration.hostKeyRecord != null) {
-            return true;
+            return RecordConflictType.CONFLICT;
         }
 
-        return false;
+        return RecordConflictType.NO_CONFLICT;
     }
 
     private List<RecordInfo<MdnsInetAddressRecord>> getInetAddressRecordsForHostname(
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index d735dc6..2cb97c9 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -1625,6 +1625,44 @@
     }
 
     @Test
+    fun testGetConflictingServices_multipleRegistrationsForHostKey_noConflict() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+
+        repository.addServiceAndFinishProbing(TEST_SERVICE_ID_1, NsdServiceInfo().apply {
+            hostname = "MyHost"
+            hostAddresses = listOf(
+                parseNumericAddress("2001:db8::1"),
+                parseNumericAddress("2001:db8::2"))
+            publicKey = TEST_PUBLIC_KEY
+        })
+        repository.addService(TEST_SERVICE_ID_2, NsdServiceInfo().apply {
+            serviceType = "_testservice._tcp"
+            serviceName = "MyTestService"
+            port = TEST_PORT
+            hostname = "MyHost"
+            publicKey = TEST_PUBLIC_KEY
+        }, null /* ttl */)
+
+        // The second registration is still probing with a KEY RR for "MyHost", but the incoming
+        // address record matches one owned by the first (probed) registration, so no conflict.
+        val otherTtlMillis = 1234L
+        val packet = MdnsPacket(
+            0 /* flags */,
+            emptyList() /* questions */,
+            listOf(
+                MdnsInetAddressRecord(
+                    arrayOf("MyHost", "local"),
+                    0L /* receiptTimeMillis */, true /* cacheFlush */,
+                    otherTtlMillis,
+                    parseNumericAddress("2001:db8::1"))
+            ) /* answers */,
+            emptyList() /* authorityRecords */,
+            emptyList() /* additionalRecords */)
+
+        assertEquals(mapOf(), repository.getConflictingServices(packet))
+    }
+
+    @Test
     fun testGetConflictingServices_IdenticalService() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* ttl */)