Merge "Implement conflict detection"
diff --git a/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
index a14b5ad..c616e01 100644
--- a/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
@@ -278,14 +278,23 @@
      * Reset a service to the probing state due to a conflict found on the network.
      */
     public void restartProbingForConflict(int serviceId) {
-        // TODO: implement
+        final MdnsProber.ProbingInfo probingInfo = mRecordRepository.setServiceProbing(serviceId);
+        if (probingInfo == null) return;
+        // Restart probing with the fresh info; the prober schedules it with its conflict delay.
+        mProber.restartForConflict(probingInfo);
     }
 
     /**
      * Rename a service following a conflict found on the network, and restart probing.
+     *
+     * If the service was not registered on this {@link MdnsInterfaceAdvertiser}, this is a no-op.
      */
     public void renameServiceForConflict(int serviceId, NsdServiceInfo newInfo) {
-        // TODO: implement
+        final MdnsProber.ProbingInfo probingInfo = mRecordRepository.renameServiceForConflict(
+                serviceId, newInfo);
+        if (probingInfo == null) return; // Unknown service ID: nothing was renamed
+        // The repository now holds the renamed registration: restart probing with its info.
+        mProber.restartForConflict(probingInfo);
     }
 
     /**
@@ -319,8 +328,15 @@
                             + packet.additionalRecords.size() + " additional from " + src);
         }
 
-        final MdnsRecordRepository.ReplyInfo answers =
-                mRecordRepository.getReply(packet, src);
+        for (int conflictServiceId : mRecordRepository.getConflictingServices(packet)) {
+            mCbHandler.post(() -> mCb.onServiceConflict(this, conflictServiceId));
+        }
+
+        // Even in case of conflict, add replies for other services. In general conflicts happen
+        // when the incoming packet carries answer records (not questions), so there is nothing to
+        // reply to. One exception is simultaneous probe tiebreaking (rfc6762 8.2), in which case
+        // the conflicting service is still probing and won't reply either.
+        final MdnsRecordRepository.ReplyInfo answers = mRecordRepository.getReply(packet, src);
 
         if (answers == null) return;
         mReplySender.queueReply(answers);
diff --git a/service-t/src/com/android/server/mdns/MdnsProber.java b/service-t/src/com/android/server/mdns/MdnsProber.java
index 2cd9148..669b323 100644
--- a/service-t/src/com/android/server/mdns/MdnsProber.java
+++ b/service-t/src/com/android/server/mdns/MdnsProber.java
@@ -33,13 +33,13 @@
  * TODO: implement receiving replies and handling conflicts.
  */
 public class MdnsProber extends MdnsPacketRepeater<MdnsProber.ProbingInfo> {
+    private static final long CONFLICT_RETRY_DELAY_MS = 5_000L;
     @NonNull
     private final String mLogTag;
 
     public MdnsProber(@NonNull String interfaceTag, @NonNull Looper looper,
             @NonNull MdnsReplySender replySender,
             @NonNull PacketRepeaterCallback<ProbingInfo> cb) {
-        // 3 packets as per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1
         super(looper, replySender, cb);
         mLogTag = MdnsProber.class.getSimpleName() + "/" + interfaceTag;
     }
@@ -140,4 +140,18 @@
     private void startProbing(@NonNull ProbingInfo info, long delay) {
         startSending(info.getServiceId(), info, delay);
     }
+
+    /**
+     * Restart probing for a service with the given info, after a conflict was detected.
+     */
+    public void restartForConflict(@NonNull ProbingInfo newInfo) {
+        stop(newInfo.getServiceId());
+
+        /* RFC 6762 8.1: "If fifteen conflicts occur within any ten-second period, then the host
+        MUST wait at least five seconds before each successive additional probe attempt. [...]
+        For very simple devices, a valid way to comply with this requirement is to always wait
+        five seconds after any failed probe attempt before trying again." */
+        // TODO: count 15 conflicts in 10s instead of waiting for 5s every time
+        startProbing(newInfo, CONFLICT_RETRY_DELAY_MS);
+    }
 }
diff --git a/service-t/src/com/android/server/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
index 4b2f553..e975ab4 100644
--- a/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
@@ -43,6 +43,7 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Random;
 import java.util.Set;
 import java.util.TreeMap;
@@ -721,6 +722,55 @@
     }
 
     /**
+     * Get the IDs of non-exiting services conflicting with answer records of a received packet.
+     */
+    public Set<Integer> getConflictingServices(MdnsPacket packet) {
+        // Avoid allocating a new set for each incoming packet: use an empty set by default.
+        Set<Integer> conflicting = Collections.emptySet();
+        for (MdnsRecord record : packet.answers) {
+            for (int i = 0; i < mServices.size(); i++) {
+                final ServiceRegistration registration = mServices.valueAt(i);
+                if (registration.exiting) continue;
+
+                // Only the service name is checked: on conflict the service can be renamed, while
+                // other conflicts are not actionable, and probing only covers the SRV record name.
+                // NOTE(review): Arrays.equals compares labels case-sensitively, but DNS names are
+                // case-insensitive, so "myservice" vs "MyService" is missed here - confirm intent.
+                final RecordInfo<MdnsServiceRecord> srvRecord = registration.srvRecord;
+                if (!Arrays.equals(record.getName(), srvRecord.record.getName())) continue;
+
+                // As per RFC6762 9., it's fine if the "conflict" is an identical record with same
+                // data.
+                if (record instanceof MdnsServiceRecord) {
+                    final MdnsServiceRecord local = srvRecord.record;
+                    final MdnsServiceRecord other = (MdnsServiceRecord) record;
+                    // Note "equals" does not consider TTL or receipt time, as intended here
+                    if (Objects.equals(local, other)) {
+                        continue;
+                    }
+                }
+
+                if (record instanceof MdnsTextRecord) {
+                    final MdnsTextRecord local = registration.txtRecord.record;
+                    final MdnsTextRecord other = (MdnsTextRecord) record;
+                    if (Objects.equals(local, other)) {
+                        continue;
+                    }
+                }
+                // Reached for differing SRV/TXT data and for any other record type on this name.
+                if (conflicting.size() == 0) {
+                    // Conflict was found: use a mutable set
+                    conflicting = new ArraySet<>();
+                }
+                final int serviceId = mServices.keyAt(i);
+                conflicting.add(serviceId);
+            }
+        }
+
+        return conflicting;
+    }
+
+    /**
      * (Re)set a service to the probing state.
      * @return The {@link MdnsProber.ProbingInfo} to send for probing.
      */
@@ -754,6 +804,21 @@
     }
 
     /**
+     * Rename a service to the newly provided info, following a conflict.
+     *
+     * @return The {@link MdnsProber.ProbingInfo} to send for probing, or null if no such service.
+     */
+    @Nullable
+    public MdnsProber.ProbingInfo renameServiceForConflict(int serviceId, NsdServiceInfo newInfo) {
+        if (!mServices.contains(serviceId)) return null;
+        // Replace the existing registration in place, keeping the same service ID.
+        final ServiceRegistration newService = new ServiceRegistration(
+                mDeviceHostname, newInfo);
+        mServices.put(serviceId, newService);
+        return makeProbingInfo(serviceId, newService.srvRecord.record);
+    }
+
+    /**
      * Called when {@link MdnsAdvertiser} sent an advertisement for the given service.
      */
     public void onAdvertisementSent(int serviceId) {
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index 02b3976..4a806b1 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -203,6 +203,59 @@
         verify(replySender).queueReply(mockReply)
     }
 
+    @Test
+    fun testConflict() {
+        addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        doReturn(setOf(TEST_SERVICE_ID_1)).`when`(repository).getConflictingServices(any())
+
+        // Packet carrying a single TXT answer record and no question, obtained with:
+        // scapy.raw(scapy.DNS(
+        //    qd = None,
+        //    an = scapy.DNSRR(type='TXT', rrname='_testservice._tcp.local'))
+        // ).hex().upper()
+        val query = HexDump.hexStringToByteArray("0000010000000001000000000C5F7465737473657276696" +
+                "365045F746370056C6F63616C0000100001000000000000")
+        val src = InetSocketAddress(parseNumericAddress("2001:db8::456"), MdnsConstants.MDNS_PORT)
+        packetHandler.handlePacket(query, query.size, src)
+
+        val packetCaptor = ArgumentCaptor.forClass(MdnsPacket::class.java)
+        verify(repository).getConflictingServices(packetCaptor.capture())
+
+        packetCaptor.value.let {
+            assertEquals(0, it.questions.size)
+            assertEquals(1, it.answers.size)
+            assertEquals(0, it.authorityRecords.size)
+            assertEquals(0, it.additionalRecords.size)
+
+            assertTrue(it.answers[0] is MdnsTextRecord)
+            assertContentEquals(arrayOf("_testservice", "_tcp", "local"), it.answers[0].name)
+        }
+
+        thread.waitForIdle(TIMEOUT_MS)
+        verify(cb).onServiceConflict(advertiser, TEST_SERVICE_ID_1)
+    }
+
+    @Test
+    fun testRestartProbingForConflict() {
+        val mockProbingInfo = mock(ProbingInfo::class.java)
+        doReturn(mockProbingInfo).`when`(repository).setServiceProbing(TEST_SERVICE_ID_1)
+
+        advertiser.restartProbingForConflict(TEST_SERVICE_ID_1)
+        // The probing info returned by the repository must be passed on to the prober.
+        verify(prober).restartForConflict(mockProbingInfo)
+    }
+
+    @Test
+    fun testRenameServiceForConflict() {
+        val mockProbingInfo = mock(ProbingInfo::class.java)
+        doReturn(mockProbingInfo).`when`(repository).renameServiceForConflict(
+                TEST_SERVICE_ID_1, TEST_SERVICE_1)
+
+        advertiser.renameServiceForConflict(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        // The probing info built for the renamed service must be passed on to the prober.
+        verify(prober).restartForConflict(mockProbingInfo)
+    }
+
     private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
             AnnouncementInfo {
         val testProbingInfo = mock(ProbingInfo::class.java)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index 597663c..ecc11ec 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -24,6 +24,7 @@
 import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
 import com.android.server.connectivity.mdns.MdnsRecordRepository.Dependencies
 import com.android.server.connectivity.mdns.MdnsRecordRepository.getReverseDnsAddress
+import com.android.server.connectivity.mdns.MdnsServiceInfo.TextEntry
 import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.DevSdkIgnoreRunner
 import java.net.InetSocketAddress
@@ -400,6 +401,63 @@
                     intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA)),
         ), reply.additionalAnswers)
     }
+
+    @Test
+    fun testGetConflictingServices() {
+        val repository = MdnsRecordRepository(thread.looper, deps)
+        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)
+        // Names presumably match TEST_SERVICE_1/2 (defined elsewhere); SRV port and TXT differ.
+        val packet = MdnsPacket(
+                0 /* flags */,
+                emptyList() /* questions */,
+                listOf(
+                    MdnsServiceRecord(
+                            arrayOf("MyTestService", "_testservice", "_tcp", "local"),
+                            0L /* receiptTimeMillis */, true /* cacheFlush */, 0L /* ttlMillis */,
+                            0 /* servicePriority */, 0 /* serviceWeight */,
+                            TEST_SERVICE_1.port + 1,
+                            TEST_HOSTNAME),
+                    MdnsTextRecord(
+                            arrayOf("MyOtherTestService", "_testservice", "_tcp", "local"),
+                            0L /* receiptTimeMillis */, true /* cacheFlush */, 0L /* ttlMillis */,
+                            listOf(TextEntry.fromString("somedifferent=entry"))),
+                ) /* answers */,
+                emptyList() /* authorityRecords */,
+                emptyList() /* additionalRecords */)
+
+        assertEquals(setOf(TEST_SERVICE_ID_1, TEST_SERVICE_ID_2),
+                repository.getConflictingServices(packet))
+    }
+
+    @Test
+    fun testGetConflictingServices_IdenticalService() {
+        val repository = MdnsRecordRepository(thread.looper, deps)
+        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)
+        // The TTL is deliberately non-zero: a differing TTL alone must not count as a conflict.
+        val otherTtlMillis = 1234L
+        val packet = MdnsPacket(
+                0 /* flags */,
+                emptyList() /* questions */,
+                listOf(
+                        MdnsServiceRecord(
+                                arrayOf("MyTestService", "_testservice", "_tcp", "local"),
+                                0L /* receiptTimeMillis */, true /* cacheFlush */,
+                                otherTtlMillis, 0 /* servicePriority */, 0 /* serviceWeight */,
+                                TEST_SERVICE_1.port,
+                                TEST_HOSTNAME),
+                        MdnsTextRecord(
+                                arrayOf("MyOtherTestService", "_testservice", "_tcp", "local"),
+                                0L /* receiptTimeMillis */, true /* cacheFlush */,
+                                otherTtlMillis, emptyList()),
+                ) /* answers */,
+                emptyList() /* authorityRecords */,
+                emptyList() /* additionalRecords */)
+
+        // Above records are identical to the actual registrations: no conflict
+        assertEquals(emptySet(), repository.getConflictingServices(packet))
+    }
 }
 
 private fun MdnsRecordRepository.initWithService(serviceId: Int, serviceInfo: NsdServiceInfo):