Implement exit announcements
Build ExitAnnouncementInfo in MdnsRecordRepository.exitService. Use a
separate class for AnnouncementInfo and ExitAnnouncementInfo, so
announcement callbacks can differentiate each case.
Bug: 241738458
Test: atest
Change-Id: I3b1ad1bef3dc1514479d7c789ef06b6a7de02e59
diff --git a/service-t/src/com/android/server/mdns/MdnsAnnouncer.java b/service-t/src/com/android/server/mdns/MdnsAnnouncer.java
index 7c84323..27fc945 100644
--- a/service-t/src/com/android/server/mdns/MdnsAnnouncer.java
+++ b/service-t/src/com/android/server/mdns/MdnsAnnouncer.java
@@ -30,23 +30,27 @@
*
* This allows maintaining other hosts' caches up-to-date. See RFC6762 8.3.
*/
-public class MdnsAnnouncer extends MdnsPacketRepeater<MdnsAnnouncer.AnnouncementInfo> {
+public class MdnsAnnouncer extends MdnsPacketRepeater<MdnsAnnouncer.BaseAnnouncementInfo> {
private static final long ANNOUNCEMENT_INITIAL_DELAY_MS = 1000L;
@VisibleForTesting
static final int ANNOUNCEMENT_COUNT = 8;
+ // Matches delay and GoodbyeCount used by the legacy implementation
+ private static final long EXIT_DELAY_MS = 2000L;
+ private static final int EXIT_COUNT = 3;
+
@NonNull
private final String mLogTag;
- /** Announcement request to send with {@link MdnsAnnouncer}. */
- public static class AnnouncementInfo implements MdnsPacketRepeater.Request {
+ /** Base class for announcement requests to send with {@link MdnsAnnouncer}. */
+ public abstract static class BaseAnnouncementInfo implements MdnsPacketRepeater.Request {
+ private final int mServiceId;
@NonNull
private final MdnsPacket mPacket;
- AnnouncementInfo(List<MdnsRecord> announcedRecords, List<MdnsRecord> additionalRecords) {
- // Records to announce (as answers)
- // Records to place in the "Additional records", with NSEC negative responses
- // to mark records that have been verified unique
+ protected BaseAnnouncementInfo(int serviceId, @NonNull List<MdnsRecord> announcedRecords,
+ @NonNull List<MdnsRecord> additionalRecords) {
+ mServiceId = serviceId;
final int flags = 0x8400; // Response, authoritative (rfc6762 18.4)
mPacket = new MdnsPacket(flags,
Collections.emptyList() /* questions */,
@@ -55,10 +59,22 @@
additionalRecords);
}
+ public int getServiceId() {
+ return mServiceId;
+ }
+
@Override
public MdnsPacket getPacket(int index) {
return mPacket;
}
+ }
+
+ /** Announcement request to send with {@link MdnsAnnouncer}. */
+ public static class AnnouncementInfo extends BaseAnnouncementInfo {
+ AnnouncementInfo(int serviceId, List<MdnsRecord> announcedRecords,
+ List<MdnsRecord> additionalRecords) {
+ super(serviceId, announcedRecords, additionalRecords);
+ }
@Override
public long getDelayMs(int nextIndex) {
@@ -72,9 +88,26 @@
}
}
+ /** Service exit announcement request to send with {@link MdnsAnnouncer}. */
+ public static class ExitAnnouncementInfo extends BaseAnnouncementInfo {
+ ExitAnnouncementInfo(int serviceId, List<MdnsRecord> announcedRecords) {
+ super(serviceId, announcedRecords, Collections.emptyList() /* additionalRecords */);
+ }
+
+ @Override
+ public long getDelayMs(int nextIndex) {
+ return EXIT_DELAY_MS;
+ }
+
+ @Override
+ public int getNumSends() {
+ return EXIT_COUNT;
+ }
+ }
+
public MdnsAnnouncer(@NonNull String interfaceTag, @NonNull Looper looper,
@NonNull MdnsReplySender replySender,
- @Nullable PacketRepeaterCallback<AnnouncementInfo> cb) {
+ @Nullable PacketRepeaterCallback<BaseAnnouncementInfo> cb) {
super(looper, replySender, cb);
mLogTag = MdnsAnnouncer.class.getSimpleName() + "/" + interfaceTag;
}
diff --git a/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
index 997dcbb..790e69a 100644
--- a/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
@@ -25,6 +25,7 @@
import android.util.Log;
import com.android.internal.annotations.VisibleForTesting;
+import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo;
import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback;
import java.io.IOException;
@@ -113,9 +114,22 @@
/**
* Callbacks from {@link MdnsAnnouncer}.
*/
- private class AnnouncingCallback
- implements PacketRepeaterCallback<MdnsAnnouncer.AnnouncementInfo> {
- // TODO: implement
+ private class AnnouncingCallback implements PacketRepeaterCallback<BaseAnnouncementInfo> {
+ @Override
+ public void onSent(int index, @NonNull BaseAnnouncementInfo info) {
+ mRecordRepository.onAdvertisementSent(info.getServiceId());
+ }
+
+ @Override
+ public void onFinished(@NonNull BaseAnnouncementInfo info) {
+ if (info instanceof MdnsAnnouncer.ExitAnnouncementInfo) {
+ mRecordRepository.removeService(info.getServiceId());
+
+ if (mRecordRepository.getServicesCount() == 0) {
+ destroyNow();
+ }
+ }
+ }
}
/**
@@ -139,7 +153,7 @@
/** @see MdnsAnnouncer */
public MdnsAnnouncer makeMdnsAnnouncer(@NonNull String interfaceTag, @NonNull Looper looper,
@NonNull MdnsReplySender replySender,
- @Nullable PacketRepeaterCallback<MdnsAnnouncer.AnnouncementInfo> cb) {
+ @Nullable PacketRepeaterCallback<MdnsAnnouncer.BaseAnnouncementInfo> cb) {
return new MdnsAnnouncer(interfaceTag, looper, replySender, cb);
}
@@ -210,9 +224,10 @@
* This will trigger exit announcements for the service.
*/
public void removeService(int id) {
+ if (!mRecordRepository.hasActiveService(id)) return;
mProber.stop(id);
mAnnouncer.stop(id);
- final MdnsAnnouncer.AnnouncementInfo exitInfo = mRecordRepository.exitService(id);
+ final MdnsAnnouncer.ExitAnnouncementInfo exitInfo = mRecordRepository.exitService(id);
if (exitInfo != null) {
// This effectively schedules destroyNow(), as it is to be called when the exit
// announcement finishes if there is no service left.
diff --git a/service-t/src/com/android/server/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
index 5915f8b..dd00212 100644
--- a/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
@@ -23,6 +23,7 @@
import android.net.nsd.NsdServiceInfo;
import android.os.Build;
import android.os.Looper;
+import android.os.SystemClock;
import android.util.ArraySet;
import android.util.SparseArray;
@@ -137,6 +138,12 @@
*/
public boolean isProbing;
+ /**
+ * Last time (as per SystemClock.elapsedRealtime) when sent via unicast or multicast,
+ * 0 if never
+ */
+ public long lastSentTimeMs;
+
RecordInfo(NsdServiceInfo serviceInfo, T record, boolean sharedName,
boolean probing) {
this.serviceInfo = serviceInfo;
@@ -333,20 +340,31 @@
* @return The exit announcement to indicate the service was removed, or null if not necessary.
*/
@Nullable
- public MdnsAnnouncer.AnnouncementInfo exitService(int id) {
+ public MdnsAnnouncer.ExitAnnouncementInfo exitService(int id) {
final ServiceRegistration registration = mServices.get(id);
if (registration == null) return null;
if (registration.exiting) return null;
- registration.exiting = true;
+ // Send exit (TTL 0) for the PTR record, if the record was sent (in particular don't send
+ // if still probing)
+ if (registration.ptrRecord.lastSentTimeMs == 0L) {
+ return null;
+ }
- // TODO: implement
- return null;
+ registration.exiting = true;
+ final MdnsPointerRecord expiredRecord = new MdnsPointerRecord(
+ registration.ptrRecord.record.getName(),
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ 0L /* ttlMillis */,
+ registration.ptrRecord.record.getPointer());
+
+ // Exit should be skipped if the record is still advertised by another service, but that
+ // would be a conflict (2 service registrations with the same service name), so it would
+ // not have been allowed by the repository.
+ return new MdnsAnnouncer.ExitAnnouncementInfo(id, Collections.singletonList(expiredRecord));
}
- /**
- * Remove a service from the repository
- */
public void removeService(int id) {
mServices.remove(id);
}
@@ -475,7 +493,8 @@
addNsecRecordsForUniqueNames(additionalAnswers,
mGeneralRecords.iterator(), registration.allRecords.iterator());
- return new MdnsAnnouncer.AnnouncementInfo(answers, additionalAnswers);
+ return new MdnsAnnouncer.AnnouncementInfo(probeSuccessInfo.getServiceId(),
+ answers, additionalAnswers);
}
/**
@@ -502,8 +521,31 @@
}
/**
+ * Return whether the repository has an active (non-exiting) service for the given ID.
+ */
+ public boolean hasActiveService(int serviceId) {
+ final ServiceRegistration registration = mServices.get(serviceId);
+ if (registration == null) return false;
+
+ return !registration.exiting;
+ }
+
+ /**
+ * Called when {@link MdnsAdvertiser} sent an advertisement for the given service.
+ */
+ public void onAdvertisementSent(int serviceId) {
+ final ServiceRegistration registration = mServices.get(serviceId);
+ if (registration == null) return;
+
+ final long now = SystemClock.elapsedRealtime();
+ for (RecordInfo<?> record : registration.allRecords) {
+ record.lastSentTimeMs = now;
+ }
+ }
+
+ /**
* Compute:
- * 2001:db8::1 --> 1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.1.0.0.2.ip6.arpa
+ * 2001:db8::1 --> 1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.0.1.0.0.2.ip6.arpa
*
* Or:
* 192.0.2.123 --> 123.2.0.192.in-addr.arpa
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
index abb2627..650607d 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
@@ -22,6 +22,7 @@
import android.os.SystemClock
import com.android.internal.util.HexDump
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
+import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo
import com.android.server.connectivity.mdns.MdnsRecordRepository.getReverseDnsAddress
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
import com.android.testutils.DevSdkIgnoreRunner
@@ -67,7 +68,7 @@
private class TestAnnouncementInfo(
announcedRecords: List<MdnsRecord>,
additionalRecords: List<MdnsRecord>
- ) : AnnouncementInfo(announcedRecords, additionalRecords) {
+ ) : AnnouncementInfo(1 /* serviceId */, announcedRecords, additionalRecords) {
override fun getDelayMs(nextIndex: Int) =
if (nextIndex < FIRST_ANNOUNCES_COUNT) {
FIRST_ANNOUNCES_DELAY
@@ -81,7 +82,7 @@
val replySender = MdnsReplySender(thread.looper, socket, buffer)
@Suppress("UNCHECKED_CAST")
val cb = mock(MdnsPacketRepeater.PacketRepeaterCallback::class.java)
- as MdnsPacketRepeater.PacketRepeaterCallback<AnnouncementInfo>
+ as MdnsPacketRepeater.PacketRepeaterCallback<BaseAnnouncementInfo>
val announcer = MdnsAnnouncer("testiface", thread.looper, replySender, cb)
/*
The expected packet replicates records announced when registering a service, as observed in
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index bcfcf1d..2cb0850 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -22,6 +22,8 @@
import android.os.Build
import android.os.HandlerThread
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
+import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo
+import com.android.server.connectivity.mdns.MdnsAnnouncer.ExitAnnouncementInfo
import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.EXIT_ANNOUNCEMENT_DELAY_MS
import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback
import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo
@@ -35,8 +37,10 @@
import org.mockito.ArgumentCaptor
import org.mockito.Mockito.any
import org.mockito.Mockito.anyInt
+import org.mockito.Mockito.doAnswer
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock
+import org.mockito.Mockito.times
import org.mockito.Mockito.verify
private const val LOG_TAG = "testlogtag"
@@ -66,7 +70,7 @@
private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>>
private val announceCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
- as ArgumentCaptor<PacketRepeaterCallback<AnnouncementInfo>>
+ as ArgumentCaptor<PacketRepeaterCallback<BaseAnnouncementInfo>>
private val probeCb get() = probeCbCaptor.value
private val announceCb get() = announceCbCaptor.value
@@ -82,7 +86,21 @@
doReturn(announcer).`when`(deps).makeMdnsAnnouncer(any(), any(), any(), any())
doReturn(prober).`when`(deps).makeMdnsProber(any(), any(), any(), any())
- doReturn(-1).`when`(repository).addService(anyInt(), any())
+ val knownServices = mutableSetOf<Int>()
+ doAnswer { inv ->
+ knownServices.add(inv.getArgument(0))
+ -1
+ }.`when`(repository).addService(anyInt(), any())
+ doAnswer { inv ->
+ knownServices.remove(inv.getArgument(0))
+ null
+ }.`when`(repository).removeService(anyInt())
+ doAnswer {
+ knownServices.toIntArray().also { knownServices.clear() }
+ }.`when`(repository).clearServices()
+ doAnswer { inv ->
+ knownServices.contains(inv.getArgument(0))
+ }.`when`(repository).hasActiveService(anyInt())
thread.start()
advertiser.start()
@@ -97,18 +115,7 @@
@Test
fun testAddRemoveService() {
- val testProbingInfo = mock(ProbingInfo::class.java)
- doReturn(TEST_SERVICE_ID_1).`when`(testProbingInfo).serviceId
- doReturn(testProbingInfo).`when`(repository).setServiceProbing(TEST_SERVICE_ID_1)
-
- advertiser.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
- verify(repository).addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
- verify(prober).startProbing(testProbingInfo)
-
- // Simulate probing success: continues to announcing
- val testAnnouncementInfo = mock(AnnouncementInfo::class.java)
- doReturn(testAnnouncementInfo).`when`(repository).onProbingSucceeded(testProbingInfo)
- probeCb.onFinished(testProbingInfo)
+ val testAnnouncementInfo = addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
verify(announcer).startSending(TEST_SERVICE_ID_1, testAnnouncementInfo,
0L /* initialDelayMs */)
@@ -117,7 +124,7 @@
verify(cb).onRegisterServiceSucceeded(advertiser, TEST_SERVICE_ID_1)
// Remove the service: expect exit announcements
- val testExitInfo = mock(AnnouncementInfo::class.java)
+ val testExitInfo = mock(ExitAnnouncementInfo::class.java)
doReturn(testExitInfo).`when`(repository).exitService(TEST_SERVICE_ID_1)
advertiser.removeService(TEST_SERVICE_ID_1)
@@ -125,7 +132,45 @@
verify(announcer).stop(TEST_SERVICE_ID_1)
verify(announcer).startSending(TEST_SERVICE_ID_1, testExitInfo, EXIT_ANNOUNCEMENT_DELAY_MS)
- // TODO: after exit announcements are implemented, verify that announceCb.onFinished causes
- // cb.onDestroyed to be called.
+ // Exit announcements finish: the advertiser has no left service and destroys itself
+ announceCb.onFinished(testExitInfo)
+ thread.waitForIdle(TIMEOUT_MS)
+ verify(cb).onDestroyed(socket)
+ }
+
+ @Test
+ fun testDoubleRemove() {
+ addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+
+ val testExitInfo = mock(ExitAnnouncementInfo::class.java)
+ doReturn(testExitInfo).`when`(repository).exitService(TEST_SERVICE_ID_1)
+ advertiser.removeService(TEST_SERVICE_ID_1)
+
+ verify(prober).stop(TEST_SERVICE_ID_1)
+ verify(announcer).stop(TEST_SERVICE_ID_1)
+ verify(announcer).startSending(TEST_SERVICE_ID_1, testExitInfo, EXIT_ANNOUNCEMENT_DELAY_MS)
+
+ doReturn(false).`when`(repository).hasActiveService(TEST_SERVICE_ID_1)
+ advertiser.removeService(TEST_SERVICE_ID_1)
+ // Prober, announcer were still stopped only one time
+ verify(prober, times(1)).stop(TEST_SERVICE_ID_1)
+ verify(announcer, times(1)).stop(TEST_SERVICE_ID_1)
+ }
+
+ private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
+ AnnouncementInfo {
+ val testProbingInfo = mock(ProbingInfo::class.java)
+ doReturn(serviceId).`when`(testProbingInfo).serviceId
+ doReturn(testProbingInfo).`when`(repository).setServiceProbing(serviceId)
+
+ advertiser.addService(serviceId, serviceInfo)
+ verify(repository).addService(serviceId, serviceInfo)
+ verify(prober).startProbing(testProbingInfo)
+
+ // Simulate probing success: continues to announcing
+ val testAnnouncementInfo = mock(AnnouncementInfo::class.java)
+ doReturn(testAnnouncementInfo).`when`(repository).onProbingSucceeded(testProbingInfo)
+ probeCb.onFinished(testProbingInfo)
+ return testAnnouncementInfo
}
}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index aa249e5..29d0854 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -30,6 +30,7 @@
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
+import kotlin.test.assertFalse
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
import org.junit.After
@@ -130,9 +131,62 @@
}
@Test
+ fun testHasActiveService() {
+ val repository = MdnsRecordRepository(thread.looper, deps)
+ assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1))
+
+ repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ assertTrue(repository.hasActiveService(TEST_SERVICE_ID_1))
+
+ val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
+ repository.onProbingSucceeded(probingInfo)
+ repository.onAdvertisementSent(TEST_SERVICE_ID_1)
+ assertTrue(repository.hasActiveService(TEST_SERVICE_ID_1))
+
+ repository.exitService(TEST_SERVICE_ID_1)
+ assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1))
+ }
+
+ @Test
+ fun testExitAnnouncements() {
+ val repository = MdnsRecordRepository(thread.looper, deps)
+ repository.updateAddresses(TEST_ADDRESSES)
+
+ repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
+ repository.onProbingSucceeded(probingInfo)
+ repository.onAdvertisementSent(TEST_SERVICE_ID_1)
+
+ val exitAnnouncement = repository.exitService(TEST_SERVICE_ID_1)
+ assertNotNull(exitAnnouncement)
+ assertEquals(1, repository.servicesCount)
+ val packet = exitAnnouncement.getPacket(0)
+
+ assertEquals(0x8400 /* response, authoritative */, packet.flags)
+ assertEquals(0, packet.questions.size)
+ assertEquals(0, packet.authorityRecords.size)
+ assertEquals(0, packet.additionalRecords.size)
+
+ assertContentEquals(listOf(
+ MdnsPointerRecord(
+ arrayOf("_testservice", "_tcp", "local"),
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ 0L /* ttlMillis */,
+ arrayOf("MyTestService", "_testservice", "_tcp", "local"))
+ ), packet.answers)
+
+ repository.removeService(TEST_SERVICE_ID_1)
+ assertEquals(0, repository.servicesCount)
+ }
+
+ @Test
fun testExitingServiceReAdded() {
val repository = MdnsRecordRepository(thread.looper, deps)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
+ repository.onProbingSucceeded(probingInfo)
+ repository.onAdvertisementSent(TEST_SERVICE_ID_1)
repository.exitService(TEST_SERVICE_ID_1)
assertEquals(TEST_SERVICE_ID_1, repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1))
@@ -262,7 +316,7 @@
@Test
fun testGetReverseDnsAddress() {
- val expectedV6 = "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.1.0.0.2.ip6.arpa"
+ val expectedV6 = "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.0.1.0.0.2.ip6.arpa"
.split(".").toTypedArray()
assertContentEquals(expectedV6, getReverseDnsAddress(parseNumericAddress("2001:db8::1")))
val expectedV4 = "123.2.0.192.in-addr.arpa".split(".").toTypedArray()