[mdns] re-announce services when adding addresses to the associated host
Similar to aosp/2988171, this CL restarts the announcement of services
when the services' host gets new addresses.
Bug: 327304356
Bug: 323712889
Change-Id: I1d9706d4721c216dacda203bf0445f557b008ebf
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
index 61eb766..0b2003f 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
@@ -36,6 +36,7 @@
import java.io.IOException;
import java.net.InetSocketAddress;
+import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -134,6 +135,15 @@
mAnnouncer.startSending(info.getServiceId(), announcementInfo,
0L /* initialDelayMs */);
+
+ // Re-announce the services which have the same custom hostname.
+ final String hostname = mRecordRepository.getHostnameForServiceId(info.getServiceId());
+ if (hostname != null) {
+ final List<MdnsAnnouncer.AnnouncementInfo> announcementInfos =
+ new ArrayList<>(mRecordRepository.restartAnnouncingForHostname(hostname));
+ announcementInfos.removeIf((i) -> i.getServiceId() == info.getServiceId());
+ reannounceServices(announcementInfos);
+ }
}
}
@@ -308,17 +318,10 @@
if (hostname != null) {
final List<MdnsProber.ProbingInfo> probingInfos =
mRecordRepository.restartProbingForHostname(hostname);
- for (MdnsProber.ProbingInfo probingInfo : probingInfos) {
- mProber.stop(probingInfo.getServiceId());
- mProber.startProbing(probingInfo);
- }
+ reprobeServices(probingInfos);
final List<MdnsAnnouncer.AnnouncementInfo> announcementInfos =
mRecordRepository.restartAnnouncingForHostname(hostname);
- for (MdnsAnnouncer.AnnouncementInfo announcementInfo : announcementInfos) {
- mAnnouncer.stop(announcementInfo.getServiceId());
- mAnnouncer.startSending(
- announcementInfo.getServiceId(), announcementInfo, 0 /* initialDelayMs */);
- }
+ reannounceServices(announcementInfos);
}
}
@@ -464,4 +467,19 @@
return new byte[0];
}
}
+
+ private void reprobeServices(List<MdnsProber.ProbingInfo> probingInfos) {
+ for (MdnsProber.ProbingInfo probingInfo : probingInfos) {
+ mProber.stop(probingInfo.getServiceId());
+ mProber.startProbing(probingInfo);
+ }
+ }
+
+ private void reannounceServices(List<MdnsAnnouncer.AnnouncementInfo> announcementInfos) {
+ for (MdnsAnnouncer.AnnouncementInfo announcementInfo : announcementInfos) {
+ mAnnouncer.stop(announcementInfo.getServiceId());
+ mAnnouncer.startSending(
+ announcementInfo.getServiceId(), announcementInfo, 0 /* initialDelayMs */);
+ }
+ }
}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index 7ac7bee..629ac67 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -47,7 +47,6 @@
import org.mockito.Mockito.anyInt
import org.mockito.Mockito.anyString
import org.mockito.Mockito.argThat
-import org.mockito.Mockito.atLeastOnce
import org.mockito.Mockito.doAnswer
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.eq
@@ -55,7 +54,6 @@
import org.mockito.Mockito.never
import org.mockito.Mockito.times
import org.mockito.Mockito.verify
-import org.mockito.Mockito.clearInvocations
import org.mockito.Mockito.inOrder
private const val LOG_TAG = "testlogtag"
@@ -234,19 +232,49 @@
addServiceAndFinishProbing(TEST_SERVICE_ID_2, TEST_SERVICE_1_CUSTOM_HOST)
doReturn("MyTestHost")
.`when`(repository).getHostnameForServiceId(TEST_SERVICE_ID_1)
- doReturn(TEST_SERVICE_ID_2).`when`(announcementInfo).serviceId
doReturn(listOf(announcementInfo))
.`when`(repository).restartAnnouncingForHostname("MyTestHost")
- clearInvocations(announcer)
+ val inOrder = inOrder(prober, announcer)
// Remove the custom host: the custom host's announcement is stopped and the probed services
// which use that hostname are re-announced.
advertiser.removeService(TEST_SERVICE_ID_1)
- verify(prober).stop(TEST_SERVICE_ID_1)
- verify(announcer, atLeastOnce()).stop(TEST_SERVICE_ID_1)
- verify(announcer).stop(TEST_SERVICE_ID_2)
- verify(announcer).startSending(TEST_SERVICE_ID_2, announcementInfo, 0L /* initialDelayMs */)
+ inOrder.verify(prober).stop(TEST_SERVICE_ID_1)
+ inOrder.verify(announcer).stop(TEST_SERVICE_ID_1)
+ inOrder.verify(announcer).stop(TEST_SERVICE_ID_2)
+ inOrder.verify(announcer).startSending(TEST_SERVICE_ID_2, announcementInfo, 0L /* initialDelayMs */)
+ }
+
+ @Test
+ fun testAddMoreAddressesForCustomHost_restartAnnouncingForProbedServices() {
+ val customHost = NsdServiceInfo().apply {
+ hostname = "MyTestHost"
+ hostAddresses = listOf(
+ parseNumericAddress("192.0.2.23"),
+ parseNumericAddress("2001:db8::1"))
+ }
+ doReturn("MyTestHost")
+ .`when`(repository).getHostnameForServiceId(TEST_SERVICE_ID_1)
+ doReturn("MyTestHost")
+ .`when`(repository).getHostnameForServiceId(TEST_SERVICE_ID_2)
+ val announcementInfo1 =
+ addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1_CUSTOM_HOST)
+
+ val probingInfo2 = addServiceAndStartProbing(TEST_SERVICE_ID_2, customHost)
+ val announcementInfo2 = AnnouncementInfo(TEST_SERVICE_ID_2, emptyList(), emptyList())
+ doReturn(announcementInfo2).`when`(repository).onProbingSucceeded(probingInfo2)
+ doReturn(listOf(announcementInfo1, announcementInfo2))
+ .`when`(repository).restartAnnouncingForHostname("MyTestHost")
+ probeCb.onFinished(probingInfo2)
+
+ val inOrder = inOrder(prober, announcer)
+
+ inOrder.verify(announcer)
+ .startSending(TEST_SERVICE_ID_2, announcementInfo2, 0L /* initialDelayMs */)
+ inOrder.verify(announcer).stop(TEST_SERVICE_ID_1)
+ inOrder.verify(announcer)
+ .startSending(TEST_SERVICE_ID_1, announcementInfo1, 0L /* initialDelayMs */)
}
@Test
@@ -489,8 +517,8 @@
verify(prober, never()).startProbing(any())
}
- private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
- AnnouncementInfo {
+ private fun addServiceAndStartProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
+ ProbingInfo {
val testProbingInfo = mock(ProbingInfo::class.java)
doReturn(serviceId).`when`(testProbingInfo).serviceId
doReturn(testProbingInfo).`when`(repository).setServiceProbing(serviceId)
@@ -499,8 +527,15 @@
verify(repository).addService(serviceId, serviceInfo, null /* ttl */)
verify(prober).startProbing(testProbingInfo)
+ return testProbingInfo
+ }
+
+ private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
+ AnnouncementInfo {
+ val testProbingInfo = addServiceAndStartProbing(serviceId, serviceInfo)
+
// Simulate probing success: continues to announcing
- val testAnnouncementInfo = mock(AnnouncementInfo::class.java)
+ val testAnnouncementInfo = AnnouncementInfo(serviceId, emptyList(), emptyList())
doReturn(testAnnouncementInfo).`when`(repository).onProbingSucceeded(testProbingInfo)
probeCb.onFinished(testProbingInfo)
return testAnnouncementInfo