Remove PacketRepeater destinationsSupplier logic
PacketRepeater can just try to send to both v4 and v6 multicast
addresses, and rely on MdnsReplySender to check whether the sockets have
(automatically) joined the v4 or v6 groups, so there is no need to use
this unusual lambda setup anymore.
Bug: 264947218
Test: atest MdnsProberTest MdnsAnnouncerTest
Change-Id: I09e0fa4bf14e1f31f2d2508f17e23adf1415feb7
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java b/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java
index 91e08a8..c056e69 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java
@@ -22,10 +22,8 @@
import com.android.internal.annotations.VisibleForTesting;
-import java.net.SocketAddress;
import java.util.Collections;
import java.util.List;
-import java.util.function.Supplier;
/**
* Sends mDns announcements when a service registration changes and at regular intervals.
@@ -43,11 +41,8 @@
static class AnnouncementInfo implements MdnsPacketRepeater.Request {
@NonNull
private final MdnsPacket mPacket;
- @NonNull
- private final Supplier<Iterable<SocketAddress>> mDestinationsSupplier;
- AnnouncementInfo(List<MdnsRecord> announcedRecords, List<MdnsRecord> additionalRecords,
- Supplier<Iterable<SocketAddress>> destinationsSupplier) {
+ AnnouncementInfo(List<MdnsRecord> announcedRecords, List<MdnsRecord> additionalRecords) {
// Records to announce (as answers)
// Records to place in the "Additional records", with NSEC negative responses
// to mark records that have been verified unique
@@ -57,7 +52,6 @@
announcedRecords,
Collections.emptyList() /* authorityRecords */,
additionalRecords);
- mDestinationsSupplier = destinationsSupplier;
}
@Override
@@ -66,11 +60,6 @@
}
@Override
- public Iterable<SocketAddress> getDestinations(int index) {
- return mDestinationsSupplier.get();
- }
-
- @Override
public long getDelayMs(int nextIndex) {
// Delay is doubled for each announcement
return ANNOUNCEMENT_INITIAL_DELAY_MS << (nextIndex - 1);
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsPacketRepeater.java b/service/mdns/com/android/server/connectivity/mdns/MdnsPacketRepeater.java
index 015dbd8..ae54e70 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsPacketRepeater.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsPacketRepeater.java
@@ -24,7 +24,7 @@
import android.util.Log;
import java.io.IOException;
-import java.net.SocketAddress;
+import java.net.InetSocketAddress;
/**
* A class used to send several packets at given time intervals.
@@ -32,6 +32,14 @@
*/
public abstract class MdnsPacketRepeater<T extends MdnsPacketRepeater.Request> {
private static final boolean DBG = MdnsAdvertiser.DBG;
+ private static final InetSocketAddress IPV4_ADDR = new InetSocketAddress(
+ MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
+ private static final InetSocketAddress IPV6_ADDR = new InetSocketAddress(
+ MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
+ private static final InetSocketAddress[] ALL_ADDRS = new InetSocketAddress[] {
+ IPV4_ADDR, IPV6_ADDR
+ };
+
@NonNull
private final MdnsReplySender mReplySender;
@NonNull
@@ -70,12 +78,6 @@
MdnsPacket getPacket(int index);
/**
- * Get a set of destinations for the packet for one iteration.
- */
- @NonNull
- Iterable<SocketAddress> getDestinations(int index);
-
- /**
* Get the delay in milliseconds until the next packet transmission.
*/
long getDelayMs(int nextIndex);
@@ -110,12 +112,13 @@
}
final MdnsPacket packet = request.getPacket(index);
- final Iterable<SocketAddress> destinations = request.getDestinations(index);
if (DBG) {
- Log.v(getTag(), "Sending packets to " + destinations + " for iteration "
- + index + " out of " + request.getNumSends());
+ Log.v(getTag(), "Sending packets for iteration " + index + " out of "
+ + request.getNumSends());
}
- for (SocketAddress destination : destinations) {
+ // Send to both v4 and v6 addresses; the reply sender will take care of ignoring the
+ // send when the socket has not joined the relevant group.
+ for (InetSocketAddress destination : ALL_ADDRS) {
try {
mReplySender.sendNow(packet, destination);
} catch (IOException e) {
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java b/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java
index db7049e..9a1e62b 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java
@@ -22,12 +22,10 @@
import com.android.internal.annotations.VisibleForTesting;
import com.android.net.module.util.CollectionUtils;
-import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
-import java.util.function.Supplier;
/**
* Sends mDns probe requests to verify service records are unique on the network.
@@ -51,21 +49,15 @@
private final int mServiceId;
@NonNull
private final MdnsPacket mPacket;
- @NonNull
- private final Supplier<Iterable<SocketAddress>> mDestinationsSupplier;
/**
* Create a new ProbingInfo
* @param serviceId Service to probe for.
* @param probeRecords Records to be probed for uniqueness.
- * @param destinationsSupplier Supplier for the probe destinations. Will be called on the
- * probe handler thread for each probe.
*/
- ProbingInfo(int serviceId, @NonNull List<MdnsRecord> probeRecords,
- @NonNull Supplier<Iterable<SocketAddress>> destinationsSupplier) {
+ ProbingInfo(int serviceId, @NonNull List<MdnsRecord> probeRecords) {
mServiceId = serviceId;
mPacket = makePacket(probeRecords);
- mDestinationsSupplier = destinationsSupplier;
}
public int getServiceId() {
@@ -78,12 +70,6 @@
return mPacket;
}
- @NonNull
- @Override
- public Iterable<SocketAddress> getDestinations(int index) {
- return mDestinationsSupplier.get();
- }
-
@Override
public long getDelayMs(int nextIndex) {
// As per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java b/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java
index adf6f4d..c6b8f47 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java
@@ -21,8 +21,10 @@
import java.io.IOException;
import java.net.DatagramPacket;
+import java.net.Inet4Address;
+import java.net.Inet6Address;
+import java.net.InetSocketAddress;
import java.net.MulticastSocket;
-import java.net.SocketAddress;
/**
* A class that handles sending mDNS replies to a {@link MulticastSocket}, possibly queueing them
@@ -50,11 +52,16 @@
*
* Must be called on the looper thread used by the {@link MdnsReplySender}.
*/
- public void sendNow(@NonNull MdnsPacket packet, @NonNull SocketAddress destination)
+ public void sendNow(@NonNull MdnsPacket packet, @NonNull InetSocketAddress destination)
throws IOException {
if (Thread.currentThread() != mLooper.getThread()) {
throw new IllegalStateException("sendNow must be called in the handler thread");
}
+ if (!((destination.getAddress() instanceof Inet6Address && mSocket.hasJoinedIpv6())
+ || (destination.getAddress() instanceof Inet4Address && mSocket.hasJoinedIpv4()))) {
+ // Skip sending if the socket has not joined the v4/v6 group (there was no address)
+ return;
+ }
// TODO: support packets over size (send in multiple packets with TC bit set)
final MdnsPacketWriter writer = new MdnsPacketWriter(mPacketCreationBuffer);
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
index 2051e0c..961f0f0 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
@@ -27,7 +27,6 @@
import java.net.DatagramPacket
import java.net.Inet6Address
import java.net.InetAddress
-import java.net.InetSocketAddress
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import org.junit.After
@@ -37,6 +36,7 @@
import org.mockito.ArgumentCaptor
import org.mockito.Mockito.any
import org.mockito.Mockito.atLeast
+import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock
import org.mockito.Mockito.timeout
import org.mockito.Mockito.verify
@@ -46,9 +46,6 @@
private const val NEXT_ANNOUNCES_DELAY = 1L
private const val TEST_TIMEOUT_MS = 1000L
-private val destinationsSupplier = {
- listOf(InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT)) }
-
@RunWith(DevSdkIgnoreRunner::class)
@IgnoreUpTo(Build.VERSION_CODES.S_V2)
class MdnsAnnouncerTest {
@@ -59,6 +56,7 @@
@Before
fun setUp() {
+ doReturn(true).`when`(socket).hasJoinedIpv6()
thread.start()
}
@@ -70,7 +68,7 @@
private class TestAnnouncementInfo(
announcedRecords: List<MdnsRecord>,
additionalRecords: List<MdnsRecord>
- ) : AnnouncementInfo(announcedRecords, additionalRecords, destinationsSupplier) {
+ ) : AnnouncementInfo(announcedRecords, additionalRecords) {
override fun getDelayMs(nextIndex: Int) =
if (nextIndex < FIRST_ANNOUNCES_COUNT) {
FIRST_ANNOUNCES_DELAY
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
index a98a4b2..3caa97d 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
@@ -25,7 +25,6 @@
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
import com.android.testutils.DevSdkIgnoreRunner
import java.net.DatagramPacket
-import java.net.InetSocketAddress
import java.util.concurrent.CompletableFuture
import java.util.concurrent.TimeUnit
import kotlin.test.assertEquals
@@ -37,15 +36,13 @@
import org.mockito.ArgumentCaptor
import org.mockito.Mockito.any
import org.mockito.Mockito.atLeast
+import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock
import org.mockito.Mockito.never
import org.mockito.Mockito.timeout
import org.mockito.Mockito.times
import org.mockito.Mockito.verify
-private val destinationsSupplier = {
- listOf(InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT)) }
-
private const val TEST_TIMEOUT_MS = 10_000L
private const val SHORT_TIMEOUT_MS = 200L
@@ -64,6 +61,7 @@
@Before
fun setUp() {
+ doReturn(true).`when`(socket).hasJoinedIpv6()
thread.start()
}
@@ -73,7 +71,7 @@
}
private class TestProbeInfo(probeRecords: List<MdnsRecord>, private val delayMs: Long = 1L) :
- ProbingInfo(1 /* serviceId */, probeRecords, destinationsSupplier) {
+ ProbingInfo(1 /* serviceId */, probeRecords) {
// Just send the packets quickly. Timing-related tests for MdnsPacketRepeater are already
// done in MdnsAnnouncerTest.
override fun getDelayMs(nextIndex: Int) = delayMs