Remove PacketRepeater destinationsSupplier logic

PacketRepeater can just try to send to both v4 and v6 multicast
addresses, and rely on MdnsReplySender to check whether the sockets have
(automatically) joined the v4 or v6 groups, so there is no need to use
this unusual lambda setup anymore.

Bug: 264947218
Test: atest MdnsProberTest MdnsAnnouncerTest
Change-Id: I09e0fa4bf14e1f31f2d2508f17e23adf1415feb7
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java b/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java
index 91e08a8..c056e69 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java
@@ -22,10 +22,8 @@
 
 import com.android.internal.annotations.VisibleForTesting;
 
-import java.net.SocketAddress;
 import java.util.Collections;
 import java.util.List;
-import java.util.function.Supplier;
 
 /**
  * Sends mDns announcements when a service registration changes and at regular intervals.
@@ -43,11 +41,8 @@
     static class AnnouncementInfo implements MdnsPacketRepeater.Request {
         @NonNull
         private final MdnsPacket mPacket;
-        @NonNull
-        private final Supplier<Iterable<SocketAddress>> mDestinationsSupplier;
 
-        AnnouncementInfo(List<MdnsRecord> announcedRecords, List<MdnsRecord> additionalRecords,
-                Supplier<Iterable<SocketAddress>> destinationsSupplier) {
+        AnnouncementInfo(List<MdnsRecord> announcedRecords, List<MdnsRecord> additionalRecords) {
             // Records to announce (as answers)
             // Records to place in the "Additional records", with NSEC negative responses
             // to mark records that have been verified unique
@@ -57,7 +52,6 @@
                     announcedRecords,
                     Collections.emptyList() /* authorityRecords */,
                     additionalRecords);
-            mDestinationsSupplier = destinationsSupplier;
         }
 
         @Override
@@ -66,11 +60,6 @@
         }
 
         @Override
-        public Iterable<SocketAddress> getDestinations(int index) {
-            return mDestinationsSupplier.get();
-        }
-
-        @Override
         public long getDelayMs(int nextIndex) {
             // Delay is doubled for each announcement
             return ANNOUNCEMENT_INITIAL_DELAY_MS << (nextIndex - 1);
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsPacketRepeater.java b/service/mdns/com/android/server/connectivity/mdns/MdnsPacketRepeater.java
index 015dbd8..ae54e70 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsPacketRepeater.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsPacketRepeater.java
@@ -24,7 +24,7 @@
 import android.util.Log;
 
 import java.io.IOException;
-import java.net.SocketAddress;
+import java.net.InetSocketAddress;
 
 /**
  * A class used to send several packets at given time intervals.
@@ -32,6 +32,14 @@
  */
 public abstract class MdnsPacketRepeater<T extends MdnsPacketRepeater.Request> {
     private static final boolean DBG = MdnsAdvertiser.DBG;
+    private static final InetSocketAddress IPV4_ADDR = new InetSocketAddress(
+            MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
+    private static final InetSocketAddress IPV6_ADDR = new InetSocketAddress(
+            MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
+    private static final InetSocketAddress[] ALL_ADDRS = new InetSocketAddress[] {
+            IPV4_ADDR, IPV6_ADDR
+    };
+
     @NonNull
     private final MdnsReplySender mReplySender;
     @NonNull
@@ -70,12 +78,6 @@
         MdnsPacket getPacket(int index);
 
         /**
-         * Get a set of destinations for the packet for one iteration.
-         */
-        @NonNull
-        Iterable<SocketAddress> getDestinations(int index);
-
-        /**
          * Get the delay in milliseconds until the next packet transmission.
          */
         long getDelayMs(int nextIndex);
@@ -110,12 +112,13 @@
             }
 
             final MdnsPacket packet = request.getPacket(index);
-            final Iterable<SocketAddress> destinations = request.getDestinations(index);
             if (DBG) {
-                Log.v(getTag(), "Sending packets to " + destinations + " for iteration "
-                        + index + " out of " + request.getNumSends());
+                Log.v(getTag(), "Sending packets for iteration " + index + " out of "
+                        + request.getNumSends());
             }
-            for (SocketAddress destination : destinations) {
+            // Send to both v4 and v6 addresses; the reply sender will take care of ignoring the
+            // send when the socket has not joined the relevant group.
+            for (InetSocketAddress destination : ALL_ADDRS) {
                 try {
                     mReplySender.sendNow(packet, destination);
                 } catch (IOException e) {
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java b/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java
index db7049e..9a1e62b 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java
@@ -22,12 +22,10 @@
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.net.module.util.CollectionUtils;
 
-import java.net.SocketAddress;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
-import java.util.function.Supplier;
 
 /**
  * Sends mDns probe requests to verify service records are unique on the network.
@@ -51,21 +49,15 @@
         private final int mServiceId;
         @NonNull
         private final MdnsPacket mPacket;
-        @NonNull
-        private final Supplier<Iterable<SocketAddress>> mDestinationsSupplier;
 
         /**
          * Create a new ProbingInfo
          * @param serviceId Service to probe for.
          * @param probeRecords Records to be probed for uniqueness.
-         * @param destinationsSupplier Supplier for the probe destinations. Will be called on the
-         *                             probe handler thread for each probe.
          */
-        ProbingInfo(int serviceId, @NonNull List<MdnsRecord> probeRecords,
-                @NonNull Supplier<Iterable<SocketAddress>> destinationsSupplier) {
+        ProbingInfo(int serviceId, @NonNull List<MdnsRecord> probeRecords) {
             mServiceId = serviceId;
             mPacket = makePacket(probeRecords);
-            mDestinationsSupplier = destinationsSupplier;
         }
 
         public int getServiceId() {
@@ -78,12 +70,6 @@
             return mPacket;
         }
 
-        @NonNull
-        @Override
-        public Iterable<SocketAddress> getDestinations(int index) {
-            return mDestinationsSupplier.get();
-        }
-
         @Override
         public long getDelayMs(int nextIndex) {
             // As per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java b/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java
index adf6f4d..c6b8f47 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java
@@ -21,8 +21,10 @@
 
 import java.io.IOException;
 import java.net.DatagramPacket;
+import java.net.Inet4Address;
+import java.net.Inet6Address;
+import java.net.InetSocketAddress;
 import java.net.MulticastSocket;
-import java.net.SocketAddress;
 
 /**
  * A class that handles sending mDNS replies to a {@link MulticastSocket}, possibly queueing them
@@ -50,11 +52,16 @@
      *
      * Must be called on the looper thread used by the {@link MdnsReplySender}.
      */
-    public void sendNow(@NonNull MdnsPacket packet, @NonNull SocketAddress destination)
+    public void sendNow(@NonNull MdnsPacket packet, @NonNull InetSocketAddress destination)
             throws IOException {
         if (Thread.currentThread() != mLooper.getThread()) {
             throw new IllegalStateException("sendNow must be called in the handler thread");
         }
+        if (!((destination.getAddress() instanceof Inet6Address && mSocket.hasJoinedIpv6())
+                || (destination.getAddress() instanceof Inet4Address && mSocket.hasJoinedIpv4()))) {
+            // Skip sending if the socket has not joined the v4/v6 group (there was no address)
+            return;
+        }
 
         // TODO: support packets over size (send in multiple packets with TC bit set)
         final MdnsPacketWriter writer = new MdnsPacketWriter(mPacketCreationBuffer);
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
index 2051e0c..961f0f0 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
@@ -27,7 +27,6 @@
 import java.net.DatagramPacket
 import java.net.Inet6Address
 import java.net.InetAddress
-import java.net.InetSocketAddress
 import kotlin.test.assertEquals
 import kotlin.test.assertTrue
 import org.junit.After
@@ -37,6 +36,7 @@
 import org.mockito.ArgumentCaptor
 import org.mockito.Mockito.any
 import org.mockito.Mockito.atLeast
+import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.mock
 import org.mockito.Mockito.timeout
 import org.mockito.Mockito.verify
@@ -46,9 +46,6 @@
 private const val NEXT_ANNOUNCES_DELAY = 1L
 private const val TEST_TIMEOUT_MS = 1000L
 
-private val destinationsSupplier = {
-    listOf(InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT)) }
-
 @RunWith(DevSdkIgnoreRunner::class)
 @IgnoreUpTo(Build.VERSION_CODES.S_V2)
 class MdnsAnnouncerTest {
@@ -59,6 +56,7 @@
 
     @Before
     fun setUp() {
+        doReturn(true).`when`(socket).hasJoinedIpv6()
         thread.start()
     }
 
@@ -70,7 +68,7 @@
     private class TestAnnouncementInfo(
         announcedRecords: List<MdnsRecord>,
         additionalRecords: List<MdnsRecord>
-    ) : AnnouncementInfo(announcedRecords, additionalRecords, destinationsSupplier) {
+    ) : AnnouncementInfo(announcedRecords, additionalRecords) {
         override fun getDelayMs(nextIndex: Int) =
                 if (nextIndex < FIRST_ANNOUNCES_COUNT) {
                     FIRST_ANNOUNCES_DELAY
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
index a98a4b2..3caa97d 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
@@ -25,7 +25,6 @@
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
 import com.android.testutils.DevSdkIgnoreRunner
 import java.net.DatagramPacket
-import java.net.InetSocketAddress
 import java.util.concurrent.CompletableFuture
 import java.util.concurrent.TimeUnit
 import kotlin.test.assertEquals
@@ -37,15 +36,13 @@
 import org.mockito.ArgumentCaptor
 import org.mockito.Mockito.any
 import org.mockito.Mockito.atLeast
+import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.mock
 import org.mockito.Mockito.never
 import org.mockito.Mockito.timeout
 import org.mockito.Mockito.times
 import org.mockito.Mockito.verify
 
-private val destinationsSupplier = {
-    listOf(InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT)) }
-
 private const val TEST_TIMEOUT_MS = 10_000L
 private const val SHORT_TIMEOUT_MS = 200L
 
@@ -64,6 +61,7 @@
 
     @Before
     fun setUp() {
+        doReturn(true).`when`(socket).hasJoinedIpv6()
         thread.start()
     }
 
@@ -73,7 +71,7 @@
     }
 
     private class TestProbeInfo(probeRecords: List<MdnsRecord>, private val delayMs: Long = 1L) :
-            ProbingInfo(1 /* serviceId */, probeRecords, destinationsSupplier) {
+            ProbingInfo(1 /* serviceId */, probeRecords) {
         // Just send the packets quickly. Timing-related tests for MdnsPacketRepeater are already
         // done in MdnsAnnouncerTest.
         override fun getDelayMs(nextIndex: Int) = delayMs