Add replying to queries
MdnsInterfaceAdvertiser registers to receive incoming packets, and
sends replies to queries as built by MdnsRecordRepository.
Bug: 241738458
Test: atest
Change-Id: I13db22f8efc870b6e0747d105f6bc8f759910f81
diff --git a/service-t/src/com/android/server/mdns/MdnsConstants.java b/service-t/src/com/android/server/mdns/MdnsConstants.java
index 396be5f..f0e1717 100644
--- a/service-t/src/com/android/server/mdns/MdnsConstants.java
+++ b/service-t/src/com/android/server/mdns/MdnsConstants.java
@@ -37,6 +37,7 @@
public static final int FLAGS_QUERY = 0x0000;
public static final int FLAGS_RESPONSE_MASK = 0xF80F;
public static final int FLAGS_RESPONSE = 0x8000;
+ public static final int FLAG_TRUNCATED = 0x0200;
public static final int QCLASS_INTERNET = 0x0001;
public static final int QCLASS_UNICAST = 0x8000;
public static final String SUBTYPE_LABEL = "_sub";
diff --git a/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
index 790e69a..a14b5ad 100644
--- a/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
@@ -25,16 +25,18 @@
import android.util.Log;
import com.android.internal.annotations.VisibleForTesting;
+import com.android.net.module.util.HexDump;
import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo;
import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback;
import java.io.IOException;
+import java.net.InetSocketAddress;
import java.util.List;
/**
* A class that handles advertising services on a {@link MdnsInterfaceSocket} tied to an interface.
*/
-public class MdnsInterfaceAdvertiser {
+public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHandler {
private static final boolean DBG = MdnsAdvertiser.DBG;
@VisibleForTesting
public static final long EXIT_ANNOUNCEMENT_DELAY_MS = 100L;
@@ -145,9 +147,9 @@
/** @see MdnsReplySender */
@NonNull
- public MdnsReplySender makeReplySender(@NonNull Looper looper,
+ public MdnsReplySender makeReplySender(@NonNull String interfaceTag, @NonNull Looper looper,
@NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer) {
- return new MdnsReplySender(looper, socket, packetCreationBuffer);
+ return new MdnsReplySender(interfaceTag, looper, socket, packetCreationBuffer);
}
/** @see MdnsAnnouncer */
@@ -182,7 +184,7 @@
mSocket = socket;
mCb = cb;
mCbHandler = new Handler(looper);
- mReplySender = deps.makeReplySender(looper, socket, packetCreationBuffer);
+ mReplySender = deps.makeReplySender(logTag, looper, socket, packetCreationBuffer);
mAnnouncer = deps.makeMdnsAnnouncer(logTag, looper, mReplySender,
mAnnouncingCallback);
mProber = deps.makeMdnsProber(logTag, looper, mReplySender, mProbingCallback);
@@ -196,7 +198,7 @@
* {@link #destroyNow()}.
*/
public void start() {
- // TODO: start receiving packets
+ mSocket.addPacketHandler(this);
}
/**
@@ -267,8 +269,8 @@
mProber.stop(serviceId);
mAnnouncer.stop(serviceId);
}
-
- // TODO: stop receiving packets
+ mReplySender.cancelAll();
+ mSocket.removePacketHandler(this);
mCbHandler.post(() -> mCb.onDestroyed(mSocket));
}
@@ -294,4 +296,33 @@
public boolean isProbing(int serviceId) {
return mRecordRepository.isProbing(serviceId);
}
+
+ @Override
+ public void handlePacket(byte[] recvbuf, int length, InetSocketAddress src) {
+ final MdnsPacket packet;
+ try {
+ packet = MdnsPacket.parse(new MdnsPacketReader(recvbuf, length));
+ } catch (MdnsPacket.ParseException e) {
+ Log.e(mTag, "Error parsing mDNS packet", e);
+ if (DBG) {
+ Log.v(
+ mTag, "Packet: " + HexDump.toHexString(recvbuf, 0, length));
+ }
+ return;
+ }
+
+ if (DBG) {
+ Log.v(mTag,
+ "Parsed packet with " + packet.questions.size() + " questions, "
+ + packet.answers.size() + " answers, "
+ + packet.authorityRecords.size() + " authority, "
+ + packet.additionalRecords.size() + " additional from " + src);
+ }
+
+ final MdnsRecordRepository.ReplyInfo answers =
+ mRecordRepository.getReply(packet, src);
+
+ if (answers == null) return;
+ mReplySender.queueReply(answers);
+ }
}
diff --git a/service-t/src/com/android/server/mdns/MdnsInterfaceSocket.java b/service-t/src/com/android/server/mdns/MdnsInterfaceSocket.java
index d1290b6..119c7a8 100644
--- a/service-t/src/com/android/server/mdns/MdnsInterfaceSocket.java
+++ b/service-t/src/com/android/server/mdns/MdnsInterfaceSocket.java
@@ -162,6 +162,14 @@
}
/**
+ * Remove a handler added via {@link #addPacketHandler}. If the handler is not present, this is
+ * a no-op.
+ */
+ public void removePacketHandler(@NonNull MulticastPacketReader.PacketHandler handler) {
+ mPacketReader.removePacketHandler(handler);
+ }
+
+ /**
* Returns the network interface that this socket is bound to.
*
* <p>This method could be used on any thread.
diff --git a/service-t/src/com/android/server/mdns/MdnsPacketRepeater.java b/service-t/src/com/android/server/mdns/MdnsPacketRepeater.java
index ae54e70..4c385da 100644
--- a/service-t/src/com/android/server/mdns/MdnsPacketRepeater.java
+++ b/service-t/src/com/android/server/mdns/MdnsPacketRepeater.java
@@ -16,6 +16,9 @@
package com.android.server.connectivity.mdns;
+import static com.android.server.connectivity.mdns.MdnsRecordRepository.IPV4_ADDR;
+import static com.android.server.connectivity.mdns.MdnsRecordRepository.IPV6_ADDR;
+
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.os.Handler;
@@ -32,10 +35,6 @@
*/
public abstract class MdnsPacketRepeater<T extends MdnsPacketRepeater.Request> {
private static final boolean DBG = MdnsAdvertiser.DBG;
- private static final InetSocketAddress IPV4_ADDR = new InetSocketAddress(
- MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
- private static final InetSocketAddress IPV6_ADDR = new InetSocketAddress(
- MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
private static final InetSocketAddress[] ALL_ADDRS = new InetSocketAddress[] {
IPV4_ADDR, IPV6_ADDR
};
@@ -114,7 +113,7 @@
final MdnsPacket packet = request.getPacket(index);
if (DBG) {
Log.v(getTag(), "Sending packets for iteration " + index + " out of "
- + request.getNumSends());
+ + request.getNumSends() + " for ID " + msg.what);
}
// Send to both v4 and v6 addresses; the reply sender will take care of ignoring the
// send when the socket has not joined the relevant group.
diff --git a/service-t/src/com/android/server/mdns/MdnsRecord.java b/service-t/src/com/android/server/mdns/MdnsRecord.java
index 00871ea..bcee9d1 100644
--- a/service-t/src/com/android/server/mdns/MdnsRecord.java
+++ b/service-t/src/com/android/server/mdns/MdnsRecord.java
@@ -45,6 +45,7 @@
private static final int FLAG_CACHE_FLUSH = 0x8000;
public static final long RECEIPT_TIME_NOT_SENT = 0L;
+ public static final int CLASS_ANY = 0x00ff;
/** Status indicating that the record is current. */
public static final int STATUS_OK = 0;
@@ -317,4 +318,4 @@
return (recordType * 31) + Arrays.hashCode(recordName);
}
}
-}
\ No newline at end of file
+}
diff --git a/service-t/src/com/android/server/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
index dd00212..4b2f553 100644
--- a/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
@@ -34,6 +34,7 @@
import java.io.IOException;
import java.net.Inet4Address;
import java.net.InetAddress;
+import java.net.InetSocketAddress;
import java.net.NetworkInterface;
import java.util.ArrayList;
import java.util.Arrays;
@@ -42,6 +43,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import java.util.UUID;
@@ -54,6 +56,9 @@
*/
@TargetApi(Build.VERSION_CODES.TIRAMISU) // Allow calling T+ APIs; this is only loaded on T+
public class MdnsRecordRepository {
+ // RFC6762 p.15
+ private static final long MIN_MULTICAST_REPLY_INTERVAL_MS = 1_000L;
+
// TTLs as per RFC6762 10.
// TTL for records with a host name as the resource record's name (e.g., A, AAAA, HINFO) or a
// host name contained within the resource record's rdata (e.g., SRV, reverse mapping PTR
@@ -69,6 +74,13 @@
private static final String[] DNS_SD_SERVICE_TYPE =
new String[] { "_services", "_dns-sd", "_udp", LOCAL_TLD };
+ public static final InetSocketAddress IPV6_ADDR = new InetSocketAddress(
+ MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
+ public static final InetSocketAddress IPV4_ADDR = new InetSocketAddress(
+ MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
+
+ @NonNull
+ private final Random mDelayGenerator = new Random();
// Map of service unique ID -> records for service
@NonNull
private final SparseArray<ServiceRegistration> mServices = new SparseArray<>();
@@ -139,6 +151,11 @@
public boolean isProbing;
/**
+ * Last time (as per SystemClock.elapsedRealtime) when advertised via multicast, 0 if never
+ */
+ public long lastAdvertisedTimeMs;
+
+ /**
* Last time (as per SystemClock.elapsedRealtime) when sent via unicast or multicast,
* 0 if never
*/
@@ -391,6 +408,212 @@
}
/**
+ * Info about a reply to be sent.
+ */
+ public static class ReplyInfo {
+ @NonNull
+ public final List<MdnsRecord> answers;
+ @NonNull
+ public final List<MdnsRecord> additionalAnswers;
+ public final long sendDelayMs;
+ @NonNull
+ public final InetSocketAddress destination;
+
+ public ReplyInfo(
+ @NonNull List<MdnsRecord> answers,
+ @NonNull List<MdnsRecord> additionalAnswers,
+ long sendDelayMs,
+ @NonNull InetSocketAddress destination) {
+ this.answers = answers;
+ this.additionalAnswers = additionalAnswers;
+ this.sendDelayMs = sendDelayMs;
+ this.destination = destination;
+ }
+
+ @Override
+ public String toString() {
+ return "{ReplyInfo to " + destination + ", answers: " + answers.size()
+ + ", additionalAnswers: " + additionalAnswers.size()
+ + ", sendDelayMs " + sendDelayMs + "}";
+ }
+ }
+
+ /**
+ * Get the reply to send to an incoming packet.
+ *
+ * @param packet The incoming packet.
+ * @param src The source address of the incoming packet.
+ */
+ @Nullable
+ public ReplyInfo getReply(MdnsPacket packet, InetSocketAddress src) {
+ final long now = SystemClock.elapsedRealtime();
+ final boolean replyUnicast = (packet.flags & MdnsConstants.QCLASS_UNICAST) != 0;
+ final ArrayList<MdnsRecord> additionalAnswerRecords = new ArrayList<>();
+ final ArrayList<RecordInfo<?>> answerInfo = new ArrayList<>();
+ for (MdnsRecord question : packet.questions) {
+ // Add answers from general records
+ addReplyFromService(question, mGeneralRecords, null /* servicePtrRecord */,
+ null /* serviceSrvRecord */, null /* serviceTxtRecord */, replyUnicast, now,
+ answerInfo, additionalAnswerRecords);
+
+ // Add answers from each service
+ for (int i = 0; i < mServices.size(); i++) {
+ final ServiceRegistration registration = mServices.valueAt(i);
+ if (registration.exiting) continue;
+ addReplyFromService(question, registration.allRecords, registration.ptrRecord,
+ registration.srvRecord, registration.txtRecord, replyUnicast, now,
+ answerInfo, additionalAnswerRecords);
+ }
+ }
+
+ if (answerInfo.size() == 0 && additionalAnswerRecords.size() == 0) {
+ return null;
+ }
+
+ // Determine the send delay
+ final long delayMs;
+ if ((packet.flags & MdnsConstants.FLAG_TRUNCATED) != 0) {
+ // RFC 6762 6.: 400-500ms delay if TC bit is set
+ delayMs = 400L + mDelayGenerator.nextInt(100);
+ } else if (packet.questions.size() > 1
+ || CollectionUtils.any(answerInfo, a -> a.isSharedName)) {
+ // 20-120ms if there may be responses from other hosts (not a fully owned
+ // name) (RFC 6762 6.), or if there are multiple questions (6.3).
+ // TODO: this should be 0 if this is a probe query ("can be distinguished from a
+ // normal query by the fact that a probe query contains a proposed record in the
+ // Authority Section that answers the question" in 6.), and the reply is for a fully
+ // owned record.
+ delayMs = 20L + mDelayGenerator.nextInt(100);
+ } else {
+ delayMs = 0L;
+ }
+
+ // Determine the send destination
+ final InetSocketAddress dest;
+ if (replyUnicast) {
+ dest = src;
+ } else if (src.getAddress() instanceof Inet4Address) {
+ dest = IPV4_ADDR;
+ } else {
+ dest = IPV6_ADDR;
+ }
+
+ // Build the list of answer records from their RecordInfo
+ final ArrayList<MdnsRecord> answerRecords = new ArrayList<>(answerInfo.size());
+ for (RecordInfo<?> info : answerInfo) {
+ // TODO: consider actual packet send delay after response aggregation
+ info.lastSentTimeMs = now + delayMs;
+ if (!replyUnicast) {
+ info.lastAdvertisedTimeMs = info.lastSentTimeMs;
+ }
+ answerRecords.add(info.record);
+ }
+
+ return new ReplyInfo(answerRecords, additionalAnswerRecords, delayMs, dest);
+ }
+
+ /**
+ * Add answers and additional answers for a question, from a ServiceRegistration.
+ */
+ private void addReplyFromService(@NonNull MdnsRecord question,
+ @NonNull List<RecordInfo<?>> serviceRecords,
+ @Nullable RecordInfo<MdnsPointerRecord> servicePtrRecord,
+ @Nullable RecordInfo<MdnsServiceRecord> serviceSrvRecord,
+ @Nullable RecordInfo<MdnsTextRecord> serviceTxtRecord,
+ boolean replyUnicast, long now, @NonNull List<RecordInfo<?>> answerInfo,
+ @NonNull List<MdnsRecord> additionalAnswerRecords) {
+ boolean hasDnsSdPtrRecordAnswer = false;
+ boolean hasDnsSdSrvRecordAnswer = false;
+ boolean hasFullyOwnedNameMatch = false;
+ boolean hasKnownAnswer = false;
+
+ final int answersStartIndex = answerInfo.size();
+ for (RecordInfo<?> info : serviceRecords) {
+ if (info.isProbing) continue;
+
+ /* RFC6762 6.: the record name must match the question name, the record rrtype
+ must match the question qtype unless the qtype is "ANY" (255) or the rrtype is
+ "CNAME" (5), and the record rrclass must match the question qclass unless the
+ qclass is "ANY" (255) */
+ if (!Arrays.equals(info.record.getName(), question.getName())) continue;
+ hasFullyOwnedNameMatch |= !info.isSharedName;
+
+ // The repository does not store CNAME records
+ if (question.getType() != MdnsRecord.TYPE_ANY
+ && question.getType() != info.record.getType()) {
+ continue;
+ }
+ if (question.getRecordClass() != MdnsRecord.CLASS_ANY
+ && question.getRecordClass() != info.record.getRecordClass()) {
+ continue;
+ }
+
+ hasKnownAnswer = true;
+ hasDnsSdPtrRecordAnswer |= (info == servicePtrRecord);
+ hasDnsSdSrvRecordAnswer |= (info == serviceSrvRecord);
+
+ // TODO: responses to probe queries should bypass this check and only ensure the
+ // reply is sent 250ms after the last sent time (RFC 6762 p.15)
+ if (!replyUnicast && info.lastAdvertisedTimeMs > 0L
+ && now - info.lastAdvertisedTimeMs < MIN_MULTICAST_REPLY_INTERVAL_MS) {
+ continue;
+ }
+
+ // TODO: Don't reply if in known answers of the querier (7.1) if TTL is > half
+
+ answerInfo.add(info);
+ }
+
+ // RFC6762 6.1:
+ // "Any time a responder receives a query for a name for which it has verified exclusive
+ // ownership, for a type for which that name has no records, the responder MUST [...]
+ // respond asserting the nonexistence of that record"
+ if (hasFullyOwnedNameMatch && !hasKnownAnswer) {
+ additionalAnswerRecords.add(new MdnsNsecRecord(
+ question.getName(),
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ // TODO: RFC6762 6.1: "In general, the TTL given for an NSEC record SHOULD
+ // be the same as the TTL that the record would have had, had it existed."
+ NAME_RECORDS_TTL_MILLIS,
+ question.getName(),
+ new int[] { question.getType() }));
+ }
+
+ // No more records to add if no answer
+ if (answerInfo.size() == answersStartIndex) return;
+
+ final List<RecordInfo<?>> additionalAnswerInfo = new ArrayList<>();
+ // RFC6763 12.1: if including PTR record, include the SRV and TXT records it names
+ if (hasDnsSdPtrRecordAnswer) {
+ if (serviceTxtRecord != null) {
+ additionalAnswerInfo.add(serviceTxtRecord);
+ }
+ if (serviceSrvRecord != null) {
+ additionalAnswerInfo.add(serviceSrvRecord);
+ }
+ }
+
+ // RFC6763 12.1&.2: if including PTR or SRV record, include the address records it names
+ if (hasDnsSdPtrRecordAnswer || hasDnsSdSrvRecordAnswer) {
+ for (RecordInfo<?> record : mGeneralRecords) {
+ if (record.record instanceof MdnsInetAddressRecord) {
+ additionalAnswerInfo.add(record);
+ }
+ }
+ }
+
+ for (RecordInfo<?> info : additionalAnswerInfo) {
+ additionalAnswerRecords.add(info.record);
+ }
+
+ // RFC6762 6.1: negative responses
+ addNsecRecordsForUniqueNames(additionalAnswerRecords,
+ answerInfo.listIterator(answersStartIndex),
+ additionalAnswerInfo.listIterator());
+ }
+
+ /**
* Add NSEC records indicating that the response records are unique.
*
* Following RFC6762 6.1:
@@ -540,6 +763,7 @@
final long now = SystemClock.elapsedRealtime();
for (RecordInfo<?> record : registration.allRecords) {
record.lastSentTimeMs = now;
+ record.lastAdvertisedTimeMs = now;
}
}
diff --git a/service-t/src/com/android/server/mdns/MdnsReplySender.java b/service-t/src/com/android/server/mdns/MdnsReplySender.java
index c6b8f47..f1389ca 100644
--- a/service-t/src/com/android/server/mdns/MdnsReplySender.java
+++ b/service-t/src/com/android/server/mdns/MdnsReplySender.java
@@ -16,8 +16,15 @@
package com.android.server.connectivity.mdns;
+import static com.android.server.connectivity.mdns.MdnsSocketProvider.ensureRunningOnHandlerThread;
+
import android.annotation.NonNull;
+import android.os.Handler;
import android.os.Looper;
+import android.os.Message;
+import android.util.Log;
+
+import com.android.server.connectivity.mdns.MdnsRecordRepository.ReplyInfo;
import java.io.IOException;
import java.net.DatagramPacket;
@@ -25,6 +32,7 @@
import java.net.Inet6Address;
import java.net.InetSocketAddress;
import java.net.MulticastSocket;
+import java.util.Collections;
/**
* A class that handles sending mDNS replies to a {@link MulticastSocket}, possibly queueing them
@@ -33,30 +41,46 @@
* TODO: implement sending after a delay, combining queued replies and duplicate answer suppression
*/
public class MdnsReplySender {
+ private static final boolean DBG = MdnsAdvertiser.DBG;
+ private static final int MSG_SEND = 1;
+
+ private final String mLogTag;
@NonNull
private final MdnsInterfaceSocket mSocket;
@NonNull
- private final Looper mLooper;
+ private final Handler mHandler;
@NonNull
private final byte[] mPacketCreationBuffer;
- public MdnsReplySender(@NonNull Looper looper,
+ public MdnsReplySender(@NonNull String interfaceTag, @NonNull Looper looper,
@NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer) {
- mLooper = looper;
+ mHandler = new SendHandler(looper);
+ mLogTag = MdnsReplySender.class.getSimpleName() + "/" + interfaceTag;
mSocket = socket;
mPacketCreationBuffer = packetCreationBuffer;
}
/**
+ * Queue a reply to be sent when its send delay expires.
+ */
+ public void queueReply(@NonNull ReplyInfo reply) {
+ ensureRunningOnHandlerThread(mHandler);
+ // TODO: implement response aggregation (RFC 6762 6.4)
+ mHandler.sendMessageDelayed(mHandler.obtainMessage(MSG_SEND, reply), reply.sendDelayMs);
+
+ if (DBG) {
+ Log.v(mLogTag, "Scheduling " + reply);
+ }
+ }
+
+ /**
* Send a packet immediately.
*
* Must be called on the looper thread used by the {@link MdnsReplySender}.
*/
public void sendNow(@NonNull MdnsPacket packet, @NonNull InetSocketAddress destination)
throws IOException {
- if (Thread.currentThread() != mLooper.getThread()) {
- throw new IllegalStateException("sendNow must be called in the handler thread");
- }
+ ensureRunningOnHandlerThread(mHandler);
if (!((destination.getAddress() instanceof Inet6Address && mSocket.hasJoinedIpv6())
|| (destination.getAddress() instanceof Inet4Address && mSocket.hasJoinedIpv4()))) {
// Skip sending if the socket has not joined the v4/v6 group (there was no address)
@@ -93,4 +117,37 @@
mSocket.send(new DatagramPacket(outBuffer, 0, len, destination));
}
+
+ /**
+ * Cancel all pending sends.
+ */
+ public void cancelAll() {
+ ensureRunningOnHandlerThread(mHandler);
+ mHandler.removeMessages(MSG_SEND);
+ }
+
+ private class SendHandler extends Handler {
+ SendHandler(@NonNull Looper looper) {
+ super(looper);
+ }
+
+ @Override
+ public void handleMessage(@NonNull Message msg) {
+ final ReplyInfo replyInfo = (ReplyInfo) msg.obj;
+ if (DBG) Log.v(mLogTag, "Sending " + replyInfo);
+
+ final int flags = 0x8400; // Response, authoritative (rfc6762 18.4)
+ final MdnsPacket packet = new MdnsPacket(flags,
+ Collections.emptyList() /* questions */,
+ replyInfo.answers,
+ Collections.emptyList() /* authorityRecords */,
+ replyInfo.additionalAnswers);
+
+ try {
+ sendNow(packet, replyInfo.destination);
+ } catch (IOException e) {
+ Log.e(mLogTag, "Error sending MDNS response", e);
+ }
+ }
+ }
}
diff --git a/service-t/src/com/android/server/mdns/MulticastPacketReader.java b/service-t/src/com/android/server/mdns/MulticastPacketReader.java
index 20cc47f..b597f0a 100644
--- a/service-t/src/com/android/server/mdns/MulticastPacketReader.java
+++ b/service-t/src/com/android/server/mdns/MulticastPacketReader.java
@@ -107,5 +107,14 @@
ensureRunningOnHandlerThread(mHandler);
mPacketHandlers.add(handler);
}
+
+ /**
+ * Remove a packet handler added via {@link #addPacketHandler}. If the handler was not set,
+ * this is a no-op.
+ */
+ public void removePacketHandler(@NonNull PacketHandler handler) {
+ ensureRunningOnHandlerThread(mHandler);
+ mPacketHandlers.remove(handler);
+ }
}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
index 334f99d..6c3f729 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
@@ -79,7 +79,7 @@
@Test
fun testAnnounce() {
- val replySender = MdnsReplySender(thread.looper, socket, buffer)
+ val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
@Suppress("UNCHECKED_CAST")
val cb = mock(MdnsPacketRepeater.PacketRepeaterCallback::class.java)
as MdnsPacketRepeater.PacketRepeaterCallback<BaseAnnouncementInfo>
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index 2cb0850..02b3976 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -21,6 +21,7 @@
import android.net.nsd.NsdServiceInfo
import android.os.Build
import android.os.HandlerThread
+import com.android.net.module.util.HexDump
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo
import com.android.server.connectivity.mdns.MdnsAnnouncer.ExitAnnouncementInfo
@@ -30,6 +31,10 @@
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
import com.android.testutils.DevSdkIgnoreRunner
import com.android.testutils.waitForIdle
+import java.net.InetSocketAddress
+import kotlin.test.assertContentEquals
+import kotlin.test.assertEquals
+import kotlin.test.assertTrue
import org.junit.After
import org.junit.Before
import org.junit.Test
@@ -37,8 +42,10 @@
import org.mockito.ArgumentCaptor
import org.mockito.Mockito.any
import org.mockito.Mockito.anyInt
+import org.mockito.Mockito.anyString
import org.mockito.Mockito.doAnswer
import org.mockito.Mockito.doReturn
+import org.mockito.Mockito.eq
import org.mockito.Mockito.mock
import org.mockito.Mockito.times
import org.mockito.Mockito.verify
@@ -67,13 +74,18 @@
private val replySender = mock(MdnsReplySender::class.java)
private val announcer = mock(MdnsAnnouncer::class.java)
private val prober = mock(MdnsProber::class.java)
+ @Suppress("UNCHECKED_CAST")
private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>>
+ @Suppress("UNCHECKED_CAST")
private val announceCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
as ArgumentCaptor<PacketRepeaterCallback<BaseAnnouncementInfo>>
+ private val packetHandlerCaptor = ArgumentCaptor.forClass(
+ MulticastPacketReader.PacketHandler::class.java)
private val probeCb get() = probeCbCaptor.value
private val announceCb get() = announceCbCaptor.value
+ private val packetHandler get() = packetHandlerCaptor.value
private val advertiser by lazy {
MdnsInterfaceAdvertiser(LOG_TAG, socket, TEST_ADDRS, thread.looper, TEST_BUFFER, cb, deps)
@@ -82,9 +94,9 @@
@Before
fun setUp() {
doReturn(repository).`when`(deps).makeRecordRepository(any())
- doReturn(replySender).`when`(deps).makeReplySender(any(), any(), any())
- doReturn(announcer).`when`(deps).makeMdnsAnnouncer(any(), any(), any(), any())
- doReturn(prober).`when`(deps).makeMdnsProber(any(), any(), any(), any())
+ doReturn(replySender).`when`(deps).makeReplySender(anyString(), any(), any(), any())
+ doReturn(announcer).`when`(deps).makeMdnsAnnouncer(anyString(), any(), any(), any())
+ doReturn(prober).`when`(deps).makeMdnsProber(anyString(), any(), any(), any())
val knownServices = mutableSetOf<Int>()
doAnswer { inv ->
@@ -104,6 +116,7 @@
thread.start()
advertiser.start()
+ verify(socket).addPacketHandler(packetHandlerCaptor.capture())
verify(deps).makeMdnsProber(any(), any(), any(), probeCbCaptor.capture())
verify(deps).makeMdnsAnnouncer(any(), any(), any(), announceCbCaptor.capture())
}
@@ -157,6 +170,39 @@
verify(announcer, times(1)).stop(TEST_SERVICE_ID_1)
}
+ @Test
+ fun testReplyToQuery() {
+ addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+
+ val mockReply = mock(MdnsRecordRepository.ReplyInfo::class.java)
+ doReturn(mockReply).`when`(repository).getReply(any(), any())
+
+ // Query obtained with:
+ // scapy.raw(scapy.DNS(
+ // qd = scapy.DNSQR(qtype='PTR', qname='_testservice._tcp.local'))
+ // ).hex().upper()
+ val query = HexDump.hexStringToByteArray(
+ "0000010000010000000000000C5F7465737473657276696365045F746370056C6F63616C00000C0001"
+ )
+ val src = InetSocketAddress(parseNumericAddress("2001:db8::456"), MdnsConstants.MDNS_PORT)
+ packetHandler.handlePacket(query, query.size, src)
+
+ val packetCaptor = ArgumentCaptor.forClass(MdnsPacket::class.java)
+ verify(repository).getReply(packetCaptor.capture(), eq(src))
+
+ packetCaptor.value.let {
+ assertEquals(1, it.questions.size)
+ assertEquals(0, it.answers.size)
+ assertEquals(0, it.authorityRecords.size)
+ assertEquals(0, it.additionalRecords.size)
+
+ assertTrue(it.questions[0] is MdnsPointerRecord)
+ assertContentEquals(arrayOf("_testservice", "_tcp", "local"), it.questions[0].name)
+ }
+
+ verify(replySender).queueReply(mockReply)
+ }
+
private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
AnnouncementInfo {
val testProbingInfo = mock(ProbingInfo::class.java)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
index 3caa97d..a2dbbc6 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
@@ -114,7 +114,7 @@
@Test
fun testProbe() {
- val replySender = MdnsReplySender(thread.looper, socket, buffer)
+ val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
val prober = TestProber(thread.looper, replySender, cb)
val probeInfo = TestProbeInfo(
listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)))
@@ -129,7 +129,7 @@
@Test
fun testProbeMultipleRecords() {
- val replySender = MdnsReplySender(thread.looper, socket, buffer)
+ val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
val prober = TestProber(thread.looper, replySender, cb)
val probeInfo = TestProbeInfo(listOf(
makeServiceRecord(TEST_SERVICE_NAME_1, 37890),
@@ -167,7 +167,7 @@
@Test
fun testStopProbing() {
- val replySender = MdnsReplySender(thread.looper, socket, buffer)
+ val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
val prober = TestProber(thread.looper, replySender, cb)
val probeInfo = TestProbeInfo(
listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)),
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index 29d0854..597663c 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -21,10 +21,12 @@
import android.net.nsd.NsdServiceInfo
import android.os.Build
import android.os.HandlerThread
+import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
import com.android.server.connectivity.mdns.MdnsRecordRepository.Dependencies
import com.android.server.connectivity.mdns.MdnsRecordRepository.getReverseDnsAddress
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner
+import java.net.InetSocketAddress
import java.net.NetworkInterface
import java.util.Collections
import kotlin.test.assertContentEquals
@@ -150,11 +152,7 @@
@Test
fun testExitAnnouncements() {
val repository = MdnsRecordRepository(thread.looper, deps)
- repository.updateAddresses(TEST_ADDRESSES)
-
- repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
- val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
- repository.onProbingSucceeded(probingInfo)
+ repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.onAdvertisementSent(TEST_SERVICE_ID_1)
val exitAnnouncement = repository.exitService(TEST_SERVICE_ID_1)
@@ -183,9 +181,7 @@
@Test
fun testExitingServiceReAdded() {
val repository = MdnsRecordRepository(thread.looper, deps)
- repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
- val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
- repository.onProbingSucceeded(probingInfo)
+ repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.onAdvertisementSent(TEST_SERVICE_ID_1)
repository.exitService(TEST_SERVICE_ID_1)
@@ -199,11 +195,8 @@
@Test
fun testOnProbingSucceeded() {
val repository = MdnsRecordRepository(thread.looper, deps)
- repository.updateAddresses(TEST_ADDRESSES)
-
- repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
- val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
- val announcementInfo = repository.onProbingSucceeded(probingInfo)
+ val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ repository.onAdvertisementSent(TEST_SERVICE_ID_1)
val packet = announcementInfo.getPacket(0)
assertEquals(0x8400 /* response, authoritative */, packet.flags)
@@ -322,4 +315,98 @@
val expectedV4 = "123.2.0.192.in-addr.arpa".split(".").toTypedArray()
assertContentEquals(expectedV4, getReverseDnsAddress(parseNumericAddress("192.0.2.123")))
}
+
+ @Test
+ fun testGetReply() {
+ val repository = MdnsRecordRepository(thread.looper, deps)
+ repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ val questions = listOf(MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"),
+ 0L /* receiptTimeMillis */,
+ false /* cacheFlush */,
+ // TTL and data is empty for a question
+ 0L /* ttlMillis */,
+ null /* pointer */))
+ val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
+ listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+ val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
+ val reply = repository.getReply(query, src)
+
+ assertNotNull(reply)
+ // Source address is IPv4
+ assertEquals(MdnsConstants.getMdnsIPv4Address(), reply.destination.address)
+ assertEquals(MdnsConstants.MDNS_PORT, reply.destination.port)
+
+ // TTLs as per RFC6762 10.
+ val longTtl = 4_500_000L
+ val shortTtl = 120_000L
+ val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
+
+ assertEquals(listOf(
+ MdnsPointerRecord(
+ arrayOf("_testservice", "_tcp", "local"),
+ 0L /* receiptTimeMillis */,
+ false /* cacheFlush */,
+ longTtl,
+ serviceName),
+ ), reply.answers)
+
+ assertEquals(listOf(
+ MdnsTextRecord(
+ serviceName,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ longTtl,
+ listOf() /* entries */),
+ MdnsServiceRecord(
+ serviceName,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ shortTtl,
+ 0 /* servicePriority */,
+ 0 /* serviceWeight */,
+ TEST_PORT,
+ TEST_HOSTNAME),
+ MdnsInetAddressRecord(
+ TEST_HOSTNAME,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ shortTtl,
+ TEST_ADDRESSES[0].address),
+ MdnsInetAddressRecord(
+ TEST_HOSTNAME,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ shortTtl,
+ TEST_ADDRESSES[1].address),
+ MdnsInetAddressRecord(
+ TEST_HOSTNAME,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ shortTtl,
+ TEST_ADDRESSES[2].address),
+ MdnsNsecRecord(
+ serviceName,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ longTtl,
+ serviceName /* nextDomain */,
+ intArrayOf(MdnsRecord.TYPE_TXT, MdnsRecord.TYPE_SRV)),
+ MdnsNsecRecord(
+ TEST_HOSTNAME,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ shortTtl,
+ TEST_HOSTNAME /* nextDomain */,
+ intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA)),
+ ), reply.additionalAnswers)
+ }
+}
+
+private fun MdnsRecordRepository.initWithService(serviceId: Int, serviceInfo: NsdServiceInfo):
+ AnnouncementInfo {
+ updateAddresses(TEST_ADDRESSES)
+ addService(serviceId, serviceInfo)
+ val probingInfo = setServiceProbing(serviceId)
+ assertNotNull(probingInfo)
+ return onProbingSucceeded(probingInfo)
}