Merge "Implement conflict detection"
diff --git a/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
index a14b5ad..c616e01 100644
--- a/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
@@ -278,14 +278,23 @@
* Reset a service to the probing state due to a conflict found on the network.
*/
public void restartProbingForConflict(int serviceId) {
- // TODO: implement
+ final MdnsProber.ProbingInfo info = mRecordRepository.setServiceProbing(serviceId);
+ // A null result means there is no probing to restart for this service.
+ if (info == null) return;
+ mProber.restartForConflict(info);
}
/**
* Rename a service following a conflict found on the network, and restart probing.
+ *
+ * This is a no-op if the service is not registered on this {@link MdnsInterfaceAdvertiser}.
*/
public void renameServiceForConflict(int serviceId, NsdServiceInfo newInfo) {
- // TODO: implement
+ final MdnsProber.ProbingInfo renamedInfo =
+ mRecordRepository.renameServiceForConflict(serviceId, newInfo);
+ // The repository returns null for services it does not know: nothing to probe.
+ if (renamedInfo == null) return;
+ mProber.restartForConflict(renamedInfo);
}
/**
@@ -319,8 +328,15 @@
+ packet.additionalRecords.size() + " additional from " + src);
}
- final MdnsRecordRepository.ReplyInfo answers =
- mRecordRepository.getReply(packet, src);
+ for (int conflictServiceId : mRecordRepository.getConflictingServices(packet)) {
+ mCbHandler.post(() -> mCb.onServiceConflict(this, conflictServiceId));
+ }
+
+ // Even in case of conflict, add replies for other services. But in general conflicts would
+ // happen when the incoming packet has answer records (not a question), so there will be no
+ // answer. One exception is simultaneous probe tiebreaking (rfc6762 8.2), in which case the
+ // conflicting service is still probing and won't reply either.
+ final MdnsRecordRepository.ReplyInfo answers = mRecordRepository.getReply(packet, src);
if (answers == null) return;
mReplySender.queueReply(answers);
diff --git a/service-t/src/com/android/server/mdns/MdnsProber.java b/service-t/src/com/android/server/mdns/MdnsProber.java
index 2cd9148..669b323 100644
--- a/service-t/src/com/android/server/mdns/MdnsProber.java
+++ b/service-t/src/com/android/server/mdns/MdnsProber.java
@@ -33,13 +33,13 @@
* TODO: implement receiving replies and handling conflicts.
*/
public class MdnsProber extends MdnsPacketRepeater<MdnsProber.ProbingInfo> {
+ private static final long CONFLICT_RETRY_DELAY_MS = 5_000L; // Re-probe delay after a conflict (RFC 6762 8.1)
@NonNull
private final String mLogTag;
public MdnsProber(@NonNull String interfaceTag, @NonNull Looper looper,
@NonNull MdnsReplySender replySender,
@NonNull PacketRepeaterCallback<ProbingInfo> cb) {
- // 3 packets as per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1
super(looper, replySender, cb);
mLogTag = MdnsProber.class.getSimpleName() + "/" + interfaceTag;
}
@@ -140,4 +140,18 @@
private void startProbing(@NonNull ProbingInfo info, long delay) {
startSending(info.getServiceId(), info, delay);
}
+
+ /**
+ * Stop probing the service, then probe again with {@code newInfo} after the conflict delay.
+ */
+ public void restartForConflict(@NonNull ProbingInfo newInfo) {
+ final int serviceId = newInfo.getServiceId();
+ stop(serviceId);
+ /* RFC 6762 8.1: "If fifteen conflicts occur within any ten-second period, then the host
+ MUST wait at least five seconds before each successive additional probe attempt. [...]
+ For very simple devices, a valid way to comply with this requirement is to always wait
+ five seconds after any failed probe attempt before trying again." */
+ // TODO: count 15 conflicts in 10s instead of waiting for 5s every time
+ startSending(serviceId, newInfo, CONFLICT_RETRY_DELAY_MS);
+ }
}
diff --git a/service-t/src/com/android/server/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
index 4b2f553..e975ab4 100644
--- a/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
@@ -43,6 +43,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
@@ -721,6 +722,55 @@
}
/**
+ * Get the service IDs of services conflicting with a received packet.
+ *
+ * <p>Only answer records of the packet are examined, and only conflicts on the service instance
+ * name (the name of the SRV and TXT records) are reported. Names are compared ignoring the case
+ * of US-ASCII letters, as required by RFC6762 16. A record identical to the local SRV or TXT
+ * record is not a conflict (RFC6762 9). Services that are exiting are ignored.
+ *
+ * @param packet The received packet.
+ * @return The IDs of conflicting services. This may be an immutable empty set.
+ */
+ public Set<Integer> getConflictingServices(MdnsPacket packet) {
+ // Avoid allocating a new set for each incoming packet: use an empty set by default.
+ Set<Integer> conflicting = Collections.emptySet();
+ for (MdnsRecord record : packet.answers) {
+ for (int i = 0; i < mServices.size(); i++) {
+ final ServiceRegistration registration = mServices.valueAt(i);
+ if (registration.exiting) continue;
+
+ // Only look for conflicts in service name, as a different service name can be used
+ // if there is a conflict, but there is nothing actionable if any other conflict
+ // happens. In fact probing is only done for the service name in the SRV record.
+ // This means only SRV and TXT records need to be checked.
+ final RecordInfo<MdnsServiceRecord> srvRecord = registration.srvRecord;
+ // DNS names are case-insensitive: "MyService" and "myservice" are the same name, so a
+ // case-sensitive comparison would miss such conflicts.
+ if (!equalsIgnoreDnsCase(record.getName(), srvRecord.record.getName())) continue;
+
+ // As per RFC6762 9., it's fine if the "conflict" is an identical record with same
+ // data.
+ // NOTE(review): equals() presumably compares names exactly, so an otherwise-identical
+ // record whose name differs only in case is still reported as a conflict.
+ if (record instanceof MdnsServiceRecord) {
+ final MdnsServiceRecord local = srvRecord.record;
+ final MdnsServiceRecord other = (MdnsServiceRecord) record;
+ // Note "equals" does not consider TTL or receipt time, as intended here
+ if (Objects.equals(local, other)) {
+ continue;
+ }
+ }
+
+ if (record instanceof MdnsTextRecord) {
+ final MdnsTextRecord local = registration.txtRecord.record;
+ final MdnsTextRecord other = (MdnsTextRecord) record;
+ if (Objects.equals(local, other)) {
+ continue;
+ }
+ }
+
+ if (conflicting.size() == 0) {
+ // Conflict was found: use a mutable set
+ conflicting = new ArraySet<>();
+ }
+ final int serviceId = mServices.keyAt(i);
+ conflicting.add(serviceId);
+ }
+ }
+
+ return conflicting;
+ }
+
+ /**
+ * Compare two DNS names label by label, ignoring the case of US-ASCII letters only
+ * (RFC6762 16); all other characters must match exactly.
+ */
+ private static boolean equalsIgnoreDnsCase(String[] a, String[] b) {
+ if (a == b) return true;
+ if (a == null || b == null || a.length != b.length) return false;
+ for (int i = 0; i < a.length; i++) {
+ if (!labelEqualsIgnoreDnsCase(a[i], b[i])) return false;
+ }
+ return true;
+ }
+
+ /**
+ * Compare two DNS labels, ignoring the case of US-ASCII letters only.
+ */
+ private static boolean labelEqualsIgnoreDnsCase(String a, String b) {
+ if (a == b) return true;
+ if (a == null || b == null || a.length() != b.length()) return false;
+ for (int i = 0; i < a.length(); i++) {
+ if (toLowerAsciiChar(a.charAt(i)) != toLowerAsciiChar(b.charAt(i))) return false;
+ }
+ return true;
+ }
+
+ /** Lowercase a US-ASCII uppercase letter; return any other character unchanged. */
+ private static char toLowerAsciiChar(char c) {
+ return (c >= 'A' && c <= 'Z') ? (char) (c - 'A' + 'a') : c;
+ }
+
+ /**
* (Re)set a service to the probing state.
* @return The {@link MdnsProber.ProbingInfo} to send for probing.
*/
@@ -754,6 +804,21 @@
}
/**
+ * Rename a service following a conflict: replace its registration with one built from
+ * {@code newInfo}, and return the probing info for the renamed service.
+ *
+ * Returns null if no service is registered with the specified ID.
+ */
+ @Nullable
+ public MdnsProber.ProbingInfo renameServiceForConflict(int serviceId, NsdServiceInfo newInfo) {
+ if (!mServices.contains(serviceId)) return null;
+ final ServiceRegistration renamed = new ServiceRegistration(mDeviceHostname, newInfo);
+ mServices.put(serviceId, renamed);
+ final MdnsServiceRecord renamedSrv = renamed.srvRecord.record;
+ return makeProbingInfo(serviceId, renamedSrv);
+ }
+
+ /**
* Called when {@link MdnsAdvertiser} sent an advertisement for the given service.
*/
public void onAdvertisementSent(int serviceId) {
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index 02b3976..4a806b1 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -203,6 +203,59 @@
verify(replySender).queueReply(mockReply)
}
+ @Test
+ fun testConflict() {
+ addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ doReturn(setOf(TEST_SERVICE_ID_1)).`when`(repository).getConflictingServices(any())
+
+ // Response packet with a single TXT answer, obtained with:
+ // scapy.raw(scapy.DNS(
+ // qd = None,
+ // an = scapy.DNSRR(type='TXT', rrname='_testservice._tcp.local'))
+ // ).hex().upper()
+ val response = HexDump.hexStringToByteArray("0000010000000001000000000C5F7465737473657276696" +
+ "365045F746370056C6F63616C0000100001000000000000")
+ val src = InetSocketAddress(parseNumericAddress("2001:db8::456"), MdnsConstants.MDNS_PORT)
+ packetHandler.handlePacket(response, response.size, src)
+
+ val packetCaptor = ArgumentCaptor.forClass(MdnsPacket::class.java)
+ verify(repository).getConflictingServices(packetCaptor.capture())
+ // Check the packet passed to the repository was parsed correctly
+ with(packetCaptor.value) {
+ assertEquals(0, questions.size)
+ assertEquals(1, answers.size)
+ assertEquals(0, authorityRecords.size)
+ assertEquals(0, additionalRecords.size)
+
+ assertTrue(answers[0] is MdnsTextRecord)
+ assertContentEquals(arrayOf("_testservice", "_tcp", "local"), answers[0].name)
+ }
+ // Conflict callbacks are posted to a handler: wait for it before verifying
+ thread.waitForIdle(TIMEOUT_MS)
+ verify(cb).onServiceConflict(advertiser, TEST_SERVICE_ID_1)
+ }
+
+ @Test
+ fun testRestartProbingForConflict() {
+ val probingInfo = mock(ProbingInfo::class.java)
+ // The repository resets the service to probing; the prober must then restart probing
+ doReturn(probingInfo).`when`(repository).setServiceProbing(TEST_SERVICE_ID_1)
+ advertiser.restartProbingForConflict(TEST_SERVICE_ID_1)
+
+ verify(prober).restartForConflict(probingInfo)
+ }
+
+ @Test
+ fun testRenameServiceForConflict() {
+ val probingInfo = mock(ProbingInfo::class.java)
+ // The repository renames the service; the prober must then probe the new name
+ doReturn(probingInfo).`when`(repository)
+ .renameServiceForConflict(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ advertiser.renameServiceForConflict(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+
+ verify(prober).restartForConflict(probingInfo)
+ }
+
private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
AnnouncementInfo {
val testProbingInfo = mock(ProbingInfo::class.java)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index 597663c..ecc11ec 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -24,6 +24,7 @@
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
import com.android.server.connectivity.mdns.MdnsRecordRepository.Dependencies
import com.android.server.connectivity.mdns.MdnsRecordRepository.getReverseDnsAddress
+import com.android.server.connectivity.mdns.MdnsServiceInfo.TextEntry
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner
import java.net.InetSocketAddress
@@ -400,6 +401,63 @@
intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA)),
), reply.additionalAnswers)
}
+
+ @Test
+ fun testGetConflictingServices() {
+ val repository = MdnsRecordRepository(thread.looper, deps)
+ repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)
+ // Same SRV name as service 1 but a different port: conflicts with service 1
+ val conflictingSrv = MdnsServiceRecord(
+ arrayOf("MyTestService", "_testservice", "_tcp", "local"),
+ 0L /* receiptTimeMillis */, true /* cacheFlush */, 0L /* ttlMillis */,
+ 0 /* servicePriority */, 0 /* serviceWeight */,
+ TEST_SERVICE_1.port + 1,
+ TEST_HOSTNAME)
+ // Same TXT name as service 2 but different entries: conflicts with service 2
+ val conflictingTxt = MdnsTextRecord(
+ arrayOf("MyOtherTestService", "_testservice", "_tcp", "local"),
+ 0L /* receiptTimeMillis */, true /* cacheFlush */, 0L /* ttlMillis */,
+ listOf(TextEntry.fromString("somedifferent=entry")))
+ val packet = MdnsPacket(
+ 0 /* flags */,
+ emptyList() /* questions */,
+ listOf(conflictingSrv, conflictingTxt) /* answers */,
+ emptyList() /* authorityRecords */,
+ emptyList() /* additionalRecords */)
+
+ val conflicts = repository.getConflictingServices(packet)
+ assertEquals(setOf(TEST_SERVICE_ID_1, TEST_SERVICE_ID_2), conflicts)
+ }
+
+ @Test
+ fun testGetConflictingServices_IdenticalService() {
+ val repository = MdnsRecordRepository(thread.looper, deps)
+ repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)
+ // Records identical to the registrations except for their TTL, which is not compared
+ val otherTtlMillis = 1234L
+ val identicalSrv = MdnsServiceRecord(
+ arrayOf("MyTestService", "_testservice", "_tcp", "local"),
+ 0L /* receiptTimeMillis */, true /* cacheFlush */,
+ otherTtlMillis, 0 /* servicePriority */, 0 /* serviceWeight */,
+ TEST_SERVICE_1.port,
+ TEST_HOSTNAME)
+ val identicalTxt = MdnsTextRecord(
+ arrayOf("MyOtherTestService", "_testservice", "_tcp", "local"),
+ 0L /* receiptTimeMillis */, true /* cacheFlush */,
+ otherTtlMillis, emptyList())
+ val packet = MdnsPacket(
+ 0 /* flags */,
+ emptyList() /* questions */,
+ listOf(identicalSrv, identicalTxt) /* answers */,
+ emptyList() /* authorityRecords */,
+ emptyList() /* additionalRecords */)
+
+ // An identical record is not a conflict (RFC 6762 section 9)
+ val conflicts = repository.getConflictingServices(packet)
+ assertEquals(emptySet(), conflicts)
+ }
}
private fun MdnsRecordRepository.initWithService(serviceId: Int, serviceInfo: NsdServiceInfo):