[mdns] re-announce services when adding addresses to the associated host

Similar to aosp/2988171, this CL restarts the announcement of services
when the services' host gets new addresses.

Bug: 327304356
Bug: 323712889
Change-Id: I1d9706d4721c216dacda203bf0445f557b008ebf
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
index 61eb766..0b2003f 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
@@ -36,6 +36,7 @@
 
 import java.io.IOException;
 import java.net.InetSocketAddress;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -134,6 +135,15 @@
 
             mAnnouncer.startSending(info.getServiceId(), announcementInfo,
                     0L /* initialDelayMs */);
+
+            // Re-announce the services which have the same custom hostname.
+            final String hostname = mRecordRepository.getHostnameForServiceId(info.getServiceId());
+            if (hostname != null) {
+                final List<MdnsAnnouncer.AnnouncementInfo> announcementInfos =
+                        new ArrayList<>(mRecordRepository.restartAnnouncingForHostname(hostname));
+                announcementInfos.removeIf((i) -> i.getServiceId() == info.getServiceId());
+                reannounceServices(announcementInfos);
+            }
         }
     }
 
@@ -308,17 +318,10 @@
         if (hostname != null) {
             final List<MdnsProber.ProbingInfo> probingInfos =
                     mRecordRepository.restartProbingForHostname(hostname);
-            for (MdnsProber.ProbingInfo probingInfo : probingInfos) {
-                mProber.stop(probingInfo.getServiceId());
-                mProber.startProbing(probingInfo);
-            }
+            reprobeServices(probingInfos);
             final List<MdnsAnnouncer.AnnouncementInfo> announcementInfos =
                     mRecordRepository.restartAnnouncingForHostname(hostname);
-            for (MdnsAnnouncer.AnnouncementInfo announcementInfo : announcementInfos) {
-                mAnnouncer.stop(announcementInfo.getServiceId());
-                mAnnouncer.startSending(
-                        announcementInfo.getServiceId(), announcementInfo, 0 /* initialDelayMs */);
-            }
+            reannounceServices(announcementInfos);
         }
     }
 
@@ -464,4 +467,19 @@
             return new byte[0];
         }
     }
+
+    private void reprobeServices(List<MdnsProber.ProbingInfo> probingInfos) {
+        for (MdnsProber.ProbingInfo probingInfo : probingInfos) {
+            mProber.stop(probingInfo.getServiceId());
+            mProber.startProbing(probingInfo);
+        }
+    }
+
+    private void reannounceServices(List<MdnsAnnouncer.AnnouncementInfo> announcementInfos) {
+        for (MdnsAnnouncer.AnnouncementInfo announcementInfo : announcementInfos) {
+            mAnnouncer.stop(announcementInfo.getServiceId());
+            mAnnouncer.startSending(
+                    announcementInfo.getServiceId(), announcementInfo, 0 /* initialDelayMs */);
+        }
+    }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index 7ac7bee..629ac67 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -47,7 +47,6 @@
 import org.mockito.Mockito.anyInt
 import org.mockito.Mockito.anyString
 import org.mockito.Mockito.argThat
-import org.mockito.Mockito.atLeastOnce
 import org.mockito.Mockito.doAnswer
 import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.eq
@@ -55,7 +54,6 @@
 import org.mockito.Mockito.never
 import org.mockito.Mockito.times
 import org.mockito.Mockito.verify
-import org.mockito.Mockito.clearInvocations
 import org.mockito.Mockito.inOrder
 
 private const val LOG_TAG = "testlogtag"
@@ -234,19 +232,49 @@
                 addServiceAndFinishProbing(TEST_SERVICE_ID_2, TEST_SERVICE_1_CUSTOM_HOST)
         doReturn("MyTestHost")
                 .`when`(repository).getHostnameForServiceId(TEST_SERVICE_ID_1)
-        doReturn(TEST_SERVICE_ID_2).`when`(announcementInfo).serviceId
         doReturn(listOf(announcementInfo))
                 .`when`(repository).restartAnnouncingForHostname("MyTestHost")
-        clearInvocations(announcer)
+        val inOrder = inOrder(prober, announcer)
 
         // Remove the custom host: the custom host's announcement is stopped and the probed services
         // which use that hostname are re-announced.
         advertiser.removeService(TEST_SERVICE_ID_1)
 
-        verify(prober).stop(TEST_SERVICE_ID_1)
-        verify(announcer, atLeastOnce()).stop(TEST_SERVICE_ID_1)
-        verify(announcer).stop(TEST_SERVICE_ID_2)
-        verify(announcer).startSending(TEST_SERVICE_ID_2, announcementInfo, 0L /* initialDelayMs */)
+        inOrder.verify(prober).stop(TEST_SERVICE_ID_1)
+        inOrder.verify(announcer).stop(TEST_SERVICE_ID_1)
+        inOrder.verify(announcer).stop(TEST_SERVICE_ID_2)
+        inOrder.verify(announcer).startSending(TEST_SERVICE_ID_2, announcementInfo, 0L /* initialDelayMs */)
+    }
+
+    @Test
+    fun testAddMoreAddressesForCustomHost_restartAnnouncingForProbedServices() {
+        val customHost = NsdServiceInfo().apply {
+            hostname = "MyTestHost"
+            hostAddresses = listOf(
+                parseNumericAddress("192.0.2.23"),
+                parseNumericAddress("2001:db8::1"))
+        }
+        doReturn("MyTestHost")
+            .`when`(repository).getHostnameForServiceId(TEST_SERVICE_ID_1)
+        doReturn("MyTestHost")
+            .`when`(repository).getHostnameForServiceId(TEST_SERVICE_ID_2)
+        val announcementInfo1 =
+            addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1_CUSTOM_HOST)
+
+        val probingInfo2 = addServiceAndStartProbing(TEST_SERVICE_ID_2, customHost)
+        val announcementInfo2 = AnnouncementInfo(TEST_SERVICE_ID_2, emptyList(), emptyList())
+        doReturn(announcementInfo2).`when`(repository).onProbingSucceeded(probingInfo2)
+        doReturn(listOf(announcementInfo1, announcementInfo2))
+            .`when`(repository).restartAnnouncingForHostname("MyTestHost")
+        probeCb.onFinished(probingInfo2)
+
+        val inOrder = inOrder(prober, announcer)
+
+        inOrder.verify(announcer)
+            .startSending(TEST_SERVICE_ID_2, announcementInfo2, 0L /* initialDelayMs */)
+        inOrder.verify(announcer).stop(TEST_SERVICE_ID_1)
+        inOrder.verify(announcer)
+            .startSending(TEST_SERVICE_ID_1, announcementInfo1, 0L /* initialDelayMs */)
     }
 
     @Test
@@ -489,8 +517,8 @@
         verify(prober, never()).startProbing(any())
     }
 
-    private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
-            AnnouncementInfo {
+    private fun addServiceAndStartProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
+            ProbingInfo {
         val testProbingInfo = mock(ProbingInfo::class.java)
         doReturn(serviceId).`when`(testProbingInfo).serviceId
         doReturn(testProbingInfo).`when`(repository).setServiceProbing(serviceId)
@@ -499,8 +527,15 @@
         verify(repository).addService(serviceId, serviceInfo, null /* ttl */)
         verify(prober).startProbing(testProbingInfo)
 
+        return testProbingInfo
+    }
+
+    private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
+            AnnouncementInfo {
+        val testProbingInfo = addServiceAndStartProbing(serviceId, serviceInfo)
+
         // Simulate probing success: continues to announcing
-        val testAnnouncementInfo = mock(AnnouncementInfo::class.java)
+        val testAnnouncementInfo = AnnouncementInfo(serviceId, emptyList(), emptyList())
         doReturn(testAnnouncementInfo).`when`(repository).onProbingSucceeded(testProbingInfo)
         probeCb.onFinished(testProbingInfo)
         return testAnnouncementInfo