Merge changes I69128db9,I13db22f8
* changes:
Implement onServiceConflict
Add replying to queries
diff --git a/service-t/src/com/android/server/mdns/MdnsAdvertiser.java b/service-t/src/com/android/server/mdns/MdnsAdvertiser.java
index 4e40efe..977478a 100644
--- a/service-t/src/com/android/server/mdns/MdnsAdvertiser.java
+++ b/service-t/src/com/android/server/mdns/MdnsAdvertiser.java
@@ -29,10 +29,10 @@
import com.android.internal.annotations.VisibleForTesting;
-import java.io.IOException;
+import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
-import java.util.function.Predicate;
+import java.util.function.BiPredicate;
+import java.util.function.Consumer;
/**
* MdnsAdvertiser manages advertising services per {@link com.android.server.NsdService} requests.
@@ -85,7 +85,7 @@
public void onRegisterServiceSucceeded(
@NonNull MdnsInterfaceAdvertiser advertiser, int serviceId) {
// Wait for all current interfaces to be done probing before notifying of success.
- if (anyAdvertiser(a -> a.isProbing(serviceId))) return;
+ if (any(mAllAdvertisers, (k, a) -> a.isProbing(serviceId))) return;
// The service may still be unregistered/renamed if a conflict is found on a later added
// interface, or if a conflicting announcement/reply is detected (RFC6762 9.)
@@ -102,7 +102,37 @@
@Override
public void onServiceConflict(@NonNull MdnsInterfaceAdvertiser advertiser, int serviceId) {
- // TODO: handle conflicts found after registration (during or after probing)
+ if (DBG) {
+ Log.v(TAG, "Found conflict, restarted probing for service " + serviceId);
+ }
+
+ final Registration registration = mRegistrations.get(serviceId);
+ if (registration == null) return;
+ if (registration.mNotifiedRegistrationSuccess) {
+ // TODO: consider notifying clients that the service is no longer registered with
+ // the old name (back to probing). The legacy implementation did not send any
+ // callback though; it only sent onServiceRegistered after re-probing finishes
+ // (with the old, conflicting, actually not used name as argument... The new
+ // implementation will send callbacks with the new name).
+ registration.mNotifiedRegistrationSuccess = false;
+
+ // The service was done probing, just reset it to probing state (RFC6762 9.)
+ forAllAdvertisers(a -> a.restartProbingForConflict(serviceId));
+ return;
+ }
+
+ // Conflict was found during probing: use the first numbered name that does not also
+ // conflict with another local registration on the advertisers hosting this service.
+ // Candidates are generated while the registration still holds its current name, and every
+ // candidate differs from it, so the registration cannot conflict with itself. Renaming
+ // first and then checking would make it match its own new name (assuming
+ // getConflictingService compares against all pending registrations, this one included),
+ // always skipping one number, e.g. "Name (3)" instead of "Name (2)".
+ int renameCount = 1;
+ NsdServiceInfo newInfo = registration.makeNewServiceInfoForConflict(renameCount);
+ while (hasAnyConflict((net, adv) -> adv.hasRegistration(registration), newInfo)) {
+ renameCount++;
+ newInfo = registration.makeNewServiceInfoForConflict(renameCount);
+ }
+ registration.updateForConflict(newInfo, renameCount);
+
+ // Update advertisers to use the new name
+ forAllAdvertisers(a -> a.renameServiceForConflict(
+ serviceId, registration.getServiceInfo()));
}
@Override
@@ -116,6 +146,25 @@
}
};
+ /**
+ * Check whether any advertiser request selected by the filter already has a registration
+ * that conflicts with the proposed service info.
+ *
+ * @param applicableAdvertiserFilter Selects which (network, request) entries to check.
+ * @param newInfo The proposed service info.
+ */
+ private boolean hasAnyConflict(
+ @NonNull BiPredicate<Network, InterfaceAdvertiserRequest> applicableAdvertiserFilter,
+ @NonNull NsdServiceInfo newInfo) {
+ for (int i = 0; i < mAdvertiserRequests.size(); i++) {
+ final Network requestNetwork = mAdvertiserRequests.keyAt(i);
+ final InterfaceAdvertiserRequest request = mAdvertiserRequests.valueAt(i);
+ if (applicableAdvertiserFilter.test(requestNetwork, request)
+ && request.hasConflict(newInfo)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /**
+ * Rename the registration until its name does not conflict on any applicable advertiser.
+ *
+ * Candidates are tried in order: the current name, then the names produced by
+ * {@link Registration#makeNewServiceInfoForConflict} for 1, 2, ... renames. The first free
+ * candidate is applied with {@link Registration#updateForConflict}, along with the number of
+ * renames it took (0 when the current name was already free).
+ *
+ * NOTE(review): if the registration is already pending on an applicable advertiser, its own
+ * current name takes part in the conflict check.
+ *
+ * @param applicableAdvertiserFilter Selects the advertiser requests to check for conflicts.
+ * @param registration The registration to update.
+ */
+ private void updateRegistrationUntilNoConflict(
+ @NonNull BiPredicate<Network, InterfaceAdvertiserRequest> applicableAdvertiserFilter,
+ @NonNull Registration registration) {
+ NsdServiceInfo candidate = registration.getServiceInfo();
+ int renames = 0;
+ while (true) {
+ if (!hasAnyConflict(applicableAdvertiserFilter, candidate)) {
+ break;
+ }
+ renames += 1;
+ candidate = registration.makeNewServiceInfoForConflict(renames);
+ }
+ registration.updateForConflict(candidate, renames);
+ }
+
/**
* A request for a {@link MdnsInterfaceAdvertiser}.
*
@@ -153,6 +202,21 @@
}
/**
+ * Check whether the given registration is one of the registrations pending in this
+ * {@link InterfaceAdvertiserRequest}.
+ */
+ boolean hasRegistration(@NonNull Registration registration) {
+ final int index = mPendingRegistrations.indexOfValue(registration);
+ return index >= 0;
+ }
+
+ /**
+ * Check whether adding a registration with the proposed {@link NsdServiceInfo} would collide
+ * with a registration already pending in this {@link InterfaceAdvertiserRequest}.
+ */
+ boolean hasConflict(@NonNull NsdServiceInfo newInfo) {
+ final int conflictingId = getConflictingService(newInfo);
+ return conflictingId >= 0;
+ }
+
+ /**
* Get the ID of a conflicting service, or -1 if none.
*/
int getConflictingService(@NonNull NsdServiceInfo info) {
@@ -166,16 +230,19 @@
return -1;
}
- void addService(int id, Registration registration)
- throws NameConflictException {
- final int conflicting = getConflictingService(registration.getServiceInfo());
- if (conflicting >= 0) {
- throw new NameConflictException(conflicting);
- }
-
+ /**
+ * Add a service to this request and to every interface advertiser it currently has.
+ *
+ * Conflicts must be checked via {@link #getConflictingService} (or {@link #hasConflict})
+ * before attempting to add: this method does not throw on a name conflict. If an interface
+ * advertiser still reports a {@link NameConflictException}, it is logged with
+ * {@code Log.wtf} and the remaining advertisers are still updated.
+ *
+ * @param id The service ID.
+ * @param registration The registration to add; it is kept as a pending registration.
+ */
+ void addService(int id, Registration registration) {
mPendingRegistrations.put(id, registration);
for (int i = 0; i < mAdvertisers.size(); i++) {
- mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo());
+ // Names are expected to be unique at this point (see above), so a conflict is a bug
+ try {
+ mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo());
+ } catch (NameConflictException e) {
+ Log.wtf(TAG, "Name conflict adding services that should have unique names", e);
+ }
}
}
@@ -239,32 +306,42 @@
/**
* Update the registration to use a different service name, after a conflict was found.
*
+ * The new name is not checked for conflicts here.
+ *
+ * @param newInfo New service info to use.
+ * @param renameCount How many renames were done before reaching the current name.
+ */
+ private void updateForConflict(@NonNull NsdServiceInfo newInfo, int renameCount) {
+ mConflictCount += renameCount;
+ mServiceInfo = newInfo;
+ }
+
+ /**
+ * Make a new service name for the registration, after a conflict was found.
+ *
* If a name conflict was found during probing or because different advertising requests
* used the same name, the registration is attempted again with a new name (here using
* a number suffix, (1), (2) etc). Registration success is notified once probing succeeds
* with a new name. This matches legacy behavior based on mdnsresponder, and appendix D of
* RFC6763.
- * @return The new service info with the updated name.
+ *
+ * The registration itself is not modified; use {@link #updateForConflict} to apply the result.
+ *
+ * @param renameCount How much to increase the number suffix for this conflict.
+ * @return A new service info named "original name (N)" with N = conflicts so far +
+ * renameCount + 1, and the same type, attributes, host, port and network.
*/
@NonNull
- private NsdServiceInfo updateForConflict() {
- mConflictCount++;
+ public NsdServiceInfo makeNewServiceInfoForConflict(int renameCount) {
// In case of conflict choose a different service name. After the first conflict use
// "Name (2)", then "Name (3)" etc.
// TODO: use a hidden method in NsdServiceInfo once MdnsAdvertiser is moved to service-t
final NsdServiceInfo newInfo = new NsdServiceInfo();
- newInfo.setServiceName(mOriginalName + " (" + (mConflictCount + 1) + ")");
+ newInfo.setServiceName(mOriginalName + " (" + (mConflictCount + renameCount + 1) + ")");
newInfo.setServiceType(mServiceInfo.getServiceType());
for (Map.Entry<String, byte[]> attr : mServiceInfo.getAttributes().entrySet()) {
- newInfo.setAttribute(attr.getKey(), attr.getValue());
+ // Decode explicitly as UTF-8 rather than with the platform default charset;
+ // setAttribute(String, String) is expected to re-encode as UTF-8 (confirm).
+ // NOTE(review): TXT values are opaque binary data (RFC6763 6.5); values that are not
+ // valid UTF-8 are altered by this round trip until the byte[] setter can be used.
+ newInfo.setAttribute(attr.getKey(), attr.getValue() == null
+ ? null : new String(attr.getValue(), StandardCharsets.UTF_8));
}
newInfo.setHost(mServiceInfo.getHost());
newInfo.setPort(mServiceInfo.getPort());
newInfo.setNetwork(mServiceInfo.getNetwork());
// interfaceIndex is not set when registering
-
- mServiceInfo = newInfo;
- return mServiceInfo;
+ return newInfo;
}
@NonNull
@@ -338,55 +415,27 @@
Log.i(TAG, "Adding service " + service + " with ID " + id);
}
- try {
- final Registration registration = new Registration(service);
- while (!tryAddRegistration(id, registration)) {
- registration.updateForConflict();
- }
-
- mRegistrations.put(id, registration);
- } catch (IOException e) {
- Log.e(TAG, "Error adding service " + service, e);
- removeService(id);
- // TODO (b/264986328): add a more specific error code
- mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR);
- }
- }
-
- private boolean tryAddRegistration(int id, @NonNull Registration registration)
- throws IOException {
- final NsdServiceInfo serviceInfo = registration.getServiceInfo();
- final Network network = serviceInfo.getNetwork();
- try {
- InterfaceAdvertiserRequest advertiser = mAdvertiserRequests.get(network);
- if (advertiser == null) {
- advertiser = new InterfaceAdvertiserRequest(network);
- mAdvertiserRequests.put(network, advertiser);
- }
- advertiser.addService(id, registration);
- } catch (NameConflictException e) {
- if (DBG) {
- Log.i(TAG, "Service name conflicts: " + serviceInfo.getServiceName());
- }
- removeService(id);
- return false;
+ final Network network = service.getNetwork();
+ final Registration registration = new Registration(service);
+ final BiPredicate<Network, InterfaceAdvertiserRequest> checkConflictFilter;
+ if (network == null) {
+ // If registering on all networks, no advertiser must have conflicts
+ checkConflictFilter = (net, adv) -> true;
+ } else {
+ // If registering on one network, the matching network advertiser and the one for all
+ // networks must not have conflicts
+ checkConflictFilter = (net, adv) -> net == null || network.equals(net);
}
- // When adding a service to a specific network, check that it does not conflict with other
- // registrations advertising on all networks
- final InterfaceAdvertiserRequest allNetworksAdvertiser = mAdvertiserRequests.get(null);
- if (network != null && allNetworksAdvertiser != null
- && allNetworksAdvertiser.getConflictingService(serviceInfo) >= 0) {
- if (DBG) {
- Log.i(TAG, "Service conflicts with advertisement on all networks: "
- + serviceInfo.getServiceName());
- }
- removeService(id);
- return false;
- }
+ updateRegistrationUntilNoConflict(checkConflictFilter, registration);
+ InterfaceAdvertiserRequest advertiser = mAdvertiserRequests.get(network);
+ if (advertiser == null) {
+ advertiser = new InterfaceAdvertiserRequest(network);
+ mAdvertiserRequests.put(network, advertiser);
+ }
+ advertiser.addService(id, registration);
mRegistrations.put(id, registration);
- return true;
}
/**
@@ -406,12 +455,20 @@
mRegistrations.remove(id);
}
- private boolean anyAdvertiser(@NonNull Predicate<MdnsInterfaceAdvertiser> predicate) {
- for (int i = 0; i < mAllAdvertisers.size(); i++) {
- if (predicate.test(mAllAdvertisers.valueAt(i))) {
+ /**
+ * Return whether any (key, value) entry of the map matches the predicate.
+ *
+ * Entries are tested in index order and iteration stops at the first match.
+ *
+ * @param map The map to search.
+ * @param predicate Test applied to each key and its value.
+ */
+ private static <K, V> boolean any(@NonNull ArrayMap<K, V> map,
+ @NonNull BiPredicate<K, V> predicate) {
+ for (int i = 0; i < map.size(); i++) {
+ if (predicate.test(map.keyAt(i), map.valueAt(i))) {
return true;
}
}
return false;
}
+
+ /**
+ * Run the given action on every {@link MdnsInterfaceAdvertiser} in mAllAdvertisers, in index
+ * order.
+ */
+ private void forAllAdvertisers(@NonNull Consumer<MdnsInterfaceAdvertiser> consumer) {
+ for (int i = 0; i < mAllAdvertisers.size(); i++) {
+ consumer.accept(mAllAdvertisers.valueAt(i));
+ }
+ }
}
diff --git a/service-t/src/com/android/server/mdns/MdnsConstants.java b/service-t/src/com/android/server/mdns/MdnsConstants.java
index 396be5f..f0e1717 100644
--- a/service-t/src/com/android/server/mdns/MdnsConstants.java
+++ b/service-t/src/com/android/server/mdns/MdnsConstants.java
@@ -37,6 +37,7 @@
public static final int FLAGS_QUERY = 0x0000;
public static final int FLAGS_RESPONSE_MASK = 0xF80F;
public static final int FLAGS_RESPONSE = 0x8000;
+ // TC (truncated) bit of the DNS header flags (RFC1035 4.1.1, RFC6762 18.5)
+ public static final int FLAG_TRUNCATED = 0x0200;
public static final int QCLASS_INTERNET = 0x0001;
public static final int QCLASS_UNICAST = 0x8000;
public static final String SUBTYPE_LABEL = "_sub";
diff --git a/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
index 790e69a..a14b5ad 100644
--- a/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
@@ -25,16 +25,18 @@
import android.util.Log;
import com.android.internal.annotations.VisibleForTesting;
+import com.android.net.module.util.HexDump;
import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo;
import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback;
import java.io.IOException;
+import java.net.InetSocketAddress;
import java.util.List;
/**
* A class that handles advertising services on a {@link MdnsInterfaceSocket} tied to an interface.
*/
-public class MdnsInterfaceAdvertiser {
+public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHandler {
private static final boolean DBG = MdnsAdvertiser.DBG;
@VisibleForTesting
public static final long EXIT_ANNOUNCEMENT_DELAY_MS = 100L;
@@ -145,9 +147,9 @@
/** @see MdnsReplySender */
@NonNull
- public MdnsReplySender makeReplySender(@NonNull Looper looper,
+ public MdnsReplySender makeReplySender(@NonNull String interfaceTag, @NonNull Looper looper,
@NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer) {
- return new MdnsReplySender(looper, socket, packetCreationBuffer);
+ // interfaceTag is appended to the reply sender's log tag
+ return new MdnsReplySender(interfaceTag, looper, socket, packetCreationBuffer);
}
/** @see MdnsAnnouncer */
@@ -182,7 +184,7 @@
mSocket = socket;
mCb = cb;
mCbHandler = new Handler(looper);
- mReplySender = deps.makeReplySender(looper, socket, packetCreationBuffer);
+ mReplySender = deps.makeReplySender(logTag, looper, socket, packetCreationBuffer);
mAnnouncer = deps.makeMdnsAnnouncer(logTag, looper, mReplySender,
mAnnouncingCallback);
mProber = deps.makeMdnsProber(logTag, looper, mReplySender, mProbingCallback);
@@ -196,7 +198,7 @@
* {@link #destroyNow()}.
*/
public void start() {
- // TODO: start receiving packets
+ // Register as a MulticastPacketReader.PacketHandler so received packets reach
+ // handlePacket; the handler is removed again when the advertiser is destroyed.
+ mSocket.addPacketHandler(this);
}
/**
@@ -267,8 +269,8 @@
mProber.stop(serviceId);
mAnnouncer.stop(serviceId);
}
-
- // TODO: stop receiving packets
+ mReplySender.cancelAll();
+ mSocket.removePacketHandler(this);
mCbHandler.post(() -> mCb.onDestroyed(mSocket));
}
@@ -294,4 +296,33 @@
public boolean isProbing(int serviceId) {
return mRecordRepository.isProbing(serviceId);
}
+
+ /**
+ * Handle a packet received on the socket: parse it, then queue any reply computed by the
+ * record repository on the reply sender. Packets that fail to parse are logged and dropped.
+ */
+ @Override
+ public void handlePacket(byte[] recvbuf, int length, InetSocketAddress src) {
+ final MdnsPacket incoming;
+ try {
+ incoming = MdnsPacket.parse(new MdnsPacketReader(recvbuf, length));
+ } catch (MdnsPacket.ParseException e) {
+ Log.e(mTag, "Error parsing mDNS packet", e);
+ if (DBG) {
+ Log.v(mTag, "Packet: " + HexDump.toHexString(recvbuf, 0, length));
+ }
+ return;
+ }
+
+ if (DBG) {
+ Log.v(mTag, "Parsed packet with " + incoming.questions.size() + " questions, "
+ + incoming.answers.size() + " answers, "
+ + incoming.authorityRecords.size() + " authority, "
+ + incoming.additionalRecords.size() + " additional from " + src);
+ }
+
+ final MdnsRecordRepository.ReplyInfo reply = mRecordRepository.getReply(incoming, src);
+ if (reply != null) {
+ mReplySender.queueReply(reply);
+ }
+ }
}
diff --git a/service-t/src/com/android/server/mdns/MdnsInterfaceSocket.java b/service-t/src/com/android/server/mdns/MdnsInterfaceSocket.java
index d1290b6..119c7a8 100644
--- a/service-t/src/com/android/server/mdns/MdnsInterfaceSocket.java
+++ b/service-t/src/com/android/server/mdns/MdnsInterfaceSocket.java
@@ -162,6 +162,14 @@
}
/**
+ * Remove a handler added via {@link #addPacketHandler}. If the handler is not present, this is
+ * a no-op.
+ *
+ * Must be called on the packet reader's handler thread, which
+ * MulticastPacketReader#removePacketHandler enforces.
+ *
+ * @param handler The handler to remove.
+ */
+ public void removePacketHandler(@NonNull MulticastPacketReader.PacketHandler handler) {
+ mPacketReader.removePacketHandler(handler);
+ }
+
+ /**
* Returns the network interface that this socket is bound to.
*
* <p>This method could be used on any thread.
diff --git a/service-t/src/com/android/server/mdns/MdnsPacketRepeater.java b/service-t/src/com/android/server/mdns/MdnsPacketRepeater.java
index ae54e70..4c385da 100644
--- a/service-t/src/com/android/server/mdns/MdnsPacketRepeater.java
+++ b/service-t/src/com/android/server/mdns/MdnsPacketRepeater.java
@@ -16,6 +16,9 @@
package com.android.server.connectivity.mdns;
+import static com.android.server.connectivity.mdns.MdnsRecordRepository.IPV4_ADDR;
+import static com.android.server.connectivity.mdns.MdnsRecordRepository.IPV6_ADDR;
+
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.os.Handler;
@@ -32,10 +35,6 @@
*/
public abstract class MdnsPacketRepeater<T extends MdnsPacketRepeater.Request> {
private static final boolean DBG = MdnsAdvertiser.DBG;
- private static final InetSocketAddress IPV4_ADDR = new InetSocketAddress(
- MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
- private static final InetSocketAddress IPV6_ADDR = new InetSocketAddress(
- MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
private static final InetSocketAddress[] ALL_ADDRS = new InetSocketAddress[] {
IPV4_ADDR, IPV6_ADDR
};
@@ -114,7 +113,7 @@
final MdnsPacket packet = request.getPacket(index);
if (DBG) {
Log.v(getTag(), "Sending packets for iteration " + index + " out of "
- + request.getNumSends());
+ + request.getNumSends() + " for ID " + msg.what);
}
// Send to both v4 and v6 addresses; the reply sender will take care of ignoring the
// send when the socket has not joined the relevant group.
diff --git a/service-t/src/com/android/server/mdns/MdnsRecord.java b/service-t/src/com/android/server/mdns/MdnsRecord.java
index 00871ea..bcee9d1 100644
--- a/service-t/src/com/android/server/mdns/MdnsRecord.java
+++ b/service-t/src/com/android/server/mdns/MdnsRecord.java
@@ -45,6 +45,7 @@
private static final int FLAG_CACHE_FLUSH = 0x8000;
public static final long RECEIPT_TIME_NOT_SENT = 0L;
+ // Class "ANY" (255) used in questions (RFC1035 3.2.5, where it is written "*")
+ public static final int CLASS_ANY = 0x00ff;
/** Status indicating that the record is current. */
public static final int STATUS_OK = 0;
@@ -317,4 +318,4 @@
return (recordType * 31) + Arrays.hashCode(recordName);
}
}
-}
\ No newline at end of file
+}
diff --git a/service-t/src/com/android/server/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
index dd00212..4b2f553 100644
--- a/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
@@ -34,6 +34,7 @@
import java.io.IOException;
import java.net.Inet4Address;
import java.net.InetAddress;
+import java.net.InetSocketAddress;
import java.net.NetworkInterface;
import java.util.ArrayList;
import java.util.Arrays;
@@ -42,6 +43,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import java.util.UUID;
@@ -54,6 +56,9 @@
*/
@TargetApi(Build.VERSION_CODES.TIRAMISU) // Allow calling T+ APIs; this is only loaded on T+
public class MdnsRecordRepository {
+ // Minimum interval between two multicast replies containing the same record; replies to
+ // probe queries are subject to a shorter limit (RFC6762 6., p.15)
+ private static final long MIN_MULTICAST_REPLY_INTERVAL_MS = 1_000L;
+
// TTLs as per RFC6762 10.
// TTL for records with a host name as the resource record's name (e.g., A, AAAA, HINFO) or a
// host name contained within the resource record's rdata (e.g., SRV, reverse mapping PTR
@@ -69,6 +74,13 @@
private static final String[] DNS_SD_SERVICE_TYPE =
new String[] { "_services", "_dns-sd", "_udp", LOCAL_TLD };
+ // mDNS multicast group destinations (group address, MdnsConstants.MDNS_PORT)
+ public static final InetSocketAddress IPV6_ADDR = new InetSocketAddress(
+ MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
+ public static final InetSocketAddress IPV4_ADDR = new InetSocketAddress(
+ MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
+
+ // Source of the random part of reply delays computed in getReply
+ @NonNull
+ private final Random mDelayGenerator = new Random();
// Map of service unique ID -> records for service
@NonNull
private final SparseArray<ServiceRegistration> mServices = new SparseArray<>();
@@ -139,6 +151,11 @@
public boolean isProbing;
/**
+ * Last time (as per SystemClock.elapsedRealtime) when advertised via multicast, 0 if never
+ */
+ public long lastAdvertisedTimeMs;
+
+ /**
* Last time (as per SystemClock.elapsedRealtime) when sent via unicast or multicast,
* 0 if never
*/
@@ -391,6 +408,212 @@
}
/**
+ * A reply to an incoming packet, as computed by {@link #getReply}.
+ */
+ public static class ReplyInfo {
+ /** Records for the answer section. */
+ @NonNull
+ public final List<MdnsRecord> answers;
+ /** Records for the additional records section. */
+ @NonNull
+ public final List<MdnsRecord> additionalAnswers;
+ /** How long to wait before sending the reply, in milliseconds. */
+ public final long sendDelayMs;
+ /** Where to send the reply. */
+ @NonNull
+ public final InetSocketAddress destination;
+
+ public ReplyInfo(
+ @NonNull List<MdnsRecord> answers,
+ @NonNull List<MdnsRecord> additionalAnswers,
+ long sendDelayMs,
+ @NonNull InetSocketAddress destination) {
+ this.destination = destination;
+ this.sendDelayMs = sendDelayMs;
+ this.additionalAnswers = additionalAnswers;
+ this.answers = answers;
+ }
+
+ @Override
+ public String toString() {
+ return new StringBuilder("{ReplyInfo to ").append(destination)
+ .append(", answers: ").append(answers.size())
+ .append(", additionalAnswers: ").append(additionalAnswers.size())
+ .append(", sendDelayMs ").append(sendDelayMs)
+ .append('}')
+ .toString();
+ }
+ }
+
+ /**
+ * Get the reply to send to an incoming packet.
+ *
+ * Only queries are answered: questions in responses, and messages with a non-zero OPCODE
+ * or RCODE, must be silently ignored (RFC6762 6., 18.2, 18.3, 18.11).
+ *
+ * @param packet The incoming packet.
+ * @param src The source address of the incoming packet.
+ * @return The reply to send, or null if there is nothing to reply.
+ */
+ @Nullable
+ public ReplyInfo getReply(MdnsPacket packet, InetSocketAddress src) {
+ // QR, OPCODE and RCODE (the bits covered by FLAGS_RESPONSE_MASK) must all be 0 in a query
+ if ((packet.flags & MdnsConstants.FLAGS_RESPONSE_MASK) != MdnsConstants.FLAGS_QUERY) {
+ return null;
+ }
+ final long now = SystemClock.elapsedRealtime();
+ // The unicast-response (QU) bit, MdnsConstants.QCLASS_UNICAST, is part of each question's
+ // class field, not of the header flags: in the header, 0x8000 is the QR bit, which is 0 for
+ // every query accepted above. Replies are multicast until per-question QU is supported.
+ // TODO: reply unicast to questions with the QU bit set (RFC6762 5.4), and to legacy
+ // unicast queries (RFC6762 6.7)
+ final boolean replyUnicast = false;
+ final ArrayList<MdnsRecord> additionalAnswerRecords = new ArrayList<>();
+ final ArrayList<RecordInfo<?>> answerInfo = new ArrayList<>();
+ for (MdnsRecord question : packet.questions) {
+ // Add answers from general records
+ addReplyFromService(question, mGeneralRecords, null /* servicePtrRecord */,
+ null /* serviceSrvRecord */, null /* serviceTxtRecord */, replyUnicast, now,
+ answerInfo, additionalAnswerRecords);
+
+ // Add answers from each service
+ for (int i = 0; i < mServices.size(); i++) {
+ final ServiceRegistration registration = mServices.valueAt(i);
+ if (registration.exiting) continue;
+ addReplyFromService(question, registration.allRecords, registration.ptrRecord,
+ registration.srvRecord, registration.txtRecord, replyUnicast, now,
+ answerInfo, additionalAnswerRecords);
+ }
+ }
+
+ if (answerInfo.size() == 0 && additionalAnswerRecords.size() == 0) {
+ return null;
+ }
+
+ // Determine the send delay
+ final long delayMs;
+ if ((packet.flags & MdnsConstants.FLAG_TRUNCATED) != 0) {
+ // RFC 6762 6.: 400-500ms delay if TC bit is set
+ delayMs = 400L + mDelayGenerator.nextInt(100);
+ } else if (packet.questions.size() > 1
+ || CollectionUtils.any(answerInfo, a -> a.isSharedName)) {
+ // 20-120ms if there may be responses from other hosts (not a fully owned
+ // name) (RFC 6762 6.), or if there are multiple questions (6.3).
+ // TODO: this should be 0 if this is a probe query ("can be distinguished from a
+ // normal query by the fact that a probe query contains a proposed record in the
+ // Authority Section that answers the question" in 6.), and the reply is for a fully
+ // owned record.
+ delayMs = 20L + mDelayGenerator.nextInt(100);
+ } else {
+ delayMs = 0L;
+ }
+
+ // Determine the send destination
+ final InetSocketAddress dest;
+ if (replyUnicast) {
+ dest = src;
+ } else if (src.getAddress() instanceof Inet4Address) {
+ dest = IPV4_ADDR;
+ } else {
+ dest = IPV6_ADDR;
+ }
+
+ // Build the list of answer records from their RecordInfo
+ final ArrayList<MdnsRecord> answerRecords = new ArrayList<>(answerInfo.size());
+ for (RecordInfo<?> info : answerInfo) {
+ // TODO: consider actual packet send delay after response aggregation
+ info.lastSentTimeMs = now + delayMs;
+ if (!replyUnicast) {
+ info.lastAdvertisedTimeMs = info.lastSentTimeMs;
+ }
+ answerRecords.add(info.record);
+ }
+
+ return new ReplyInfo(answerRecords, additionalAnswerRecords, delayMs, dest);
+ }
+
+ /**
+ * Add answers and additional answers for a question, from one set of records (the general
+ * records, or the records of one ServiceRegistration).
+ *
+ * Records that are still probing are never used. Matching records are appended to
+ * answerInfo unless a multicast reply would repeat them within
+ * MIN_MULTICAST_REPLY_INTERVAL_MS. When at least one answer is added, the related SRV/TXT and
+ * address records and NSEC records are appended to additionalAnswerRecords.
+ *
+ * @param question The question to answer.
+ * @param serviceRecords The records to search for answers.
+ * @param servicePtrRecord The service's PTR record, or null for general records.
+ * @param serviceSrvRecord The service's SRV record, or null for general records.
+ * @param serviceTxtRecord The service's TXT record, or null for general records.
+ * @param replyUnicast Whether the reply will be sent via unicast (no multicast rate limit).
+ * @param now Current time as per SystemClock.elapsedRealtime.
+ * @param answerInfo List the answers are appended to (shared across calls for one packet).
+ * @param additionalAnswerRecords List additional records are appended to (also shared).
+ */
+ private void addReplyFromService(@NonNull MdnsRecord question,
+ @NonNull List<RecordInfo<?>> serviceRecords,
+ @Nullable RecordInfo<MdnsPointerRecord> servicePtrRecord,
+ @Nullable RecordInfo<MdnsServiceRecord> serviceSrvRecord,
+ @Nullable RecordInfo<MdnsTextRecord> serviceTxtRecord,
+ boolean replyUnicast, long now, @NonNull List<RecordInfo<?>> answerInfo,
+ @NonNull List<MdnsRecord> additionalAnswerRecords) {
+ boolean hasDnsSdPtrRecordAnswer = false;
+ boolean hasDnsSdSrvRecordAnswer = false;
+ boolean hasFullyOwnedNameMatch = false;
+ boolean hasKnownAnswer = false;
+
+ // Answers added by this call are the entries of answerInfo from this index on
+ final int answersStartIndex = answerInfo.size();
+ for (RecordInfo<?> info : serviceRecords) {
+ if (info.isProbing) continue;
+
+ /* RFC6762 6.: the record name must match the question name, the record rrtype
+ must match the question qtype unless the qtype is "ANY" (255) or the rrtype is
+ "CNAME" (5), and the record rrclass must match the question qclass unless the
+ qclass is "ANY" (255) */
+ if (!Arrays.equals(info.record.getName(), question.getName())) continue;
+ // Recorded before the type/class checks: used below to decide whether a negative
+ // (NSEC) answer is due for a fully owned name that has no record of this type
+ hasFullyOwnedNameMatch |= !info.isSharedName;
+
+ // The repository does not store CNAME records
+ if (question.getType() != MdnsRecord.TYPE_ANY
+ && question.getType() != info.record.getType()) {
+ continue;
+ }
+ if (question.getRecordClass() != MdnsRecord.CLASS_ANY
+ && question.getRecordClass() != info.record.getRecordClass()) {
+ continue;
+ }
+
+ // Set even if the record is rate-limited below: the record exists, so no NSEC is due
+ hasKnownAnswer = true;
+ // NOTE(review): only the main service PTR record is compared here; if subtype PTR
+ // records are in serviceRecords, answering them does not add SRV/TXT/address records
+ hasDnsSdPtrRecordAnswer |= (info == servicePtrRecord);
+ hasDnsSdSrvRecordAnswer |= (info == serviceSrvRecord);
+
+ // TODO: responses to probe queries should bypass this check and only ensure the
+ // reply is sent 250ms after the last sent time (RFC 6762 p.15)
+ if (!replyUnicast && info.lastAdvertisedTimeMs > 0L
+ && now - info.lastAdvertisedTimeMs < MIN_MULTICAST_REPLY_INTERVAL_MS) {
+ continue;
+ }
+
+ // TODO: Don't reply if in known answers of the querier (7.1) if TTL is > half
+
+ answerInfo.add(info);
+ }
+
+ // RFC6762 6.1:
+ // "Any time a responder receives a query for a name for which it has verified exclusive
+ // ownership, for a type for which that name has no records, the responder MUST [...]
+ // respond asserting the nonexistence of that record"
+ if (hasFullyOwnedNameMatch && !hasKnownAnswer) {
+ additionalAnswerRecords.add(new MdnsNsecRecord(
+ question.getName(),
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ // TODO: RFC6762 6.1: "In general, the TTL given for an NSEC record SHOULD
+ // be the same as the TTL that the record would have had, had it existed."
+ NAME_RECORDS_TTL_MILLIS,
+ question.getName(),
+ new int[] { question.getType() }));
+ }
+
+ // No more records to add if no answer
+ if (answerInfo.size() == answersStartIndex) return;
+
+ final List<RecordInfo<?>> additionalAnswerInfo = new ArrayList<>();
+ // RFC6763 12.1: if including PTR record, include the SRV and TXT records it names
+ if (hasDnsSdPtrRecordAnswer) {
+ if (serviceTxtRecord != null) {
+ additionalAnswerInfo.add(serviceTxtRecord);
+ }
+ if (serviceSrvRecord != null) {
+ additionalAnswerInfo.add(serviceSrvRecord);
+ }
+ }
+
+ // RFC6763 12.1&.2: if including PTR or SRV record, include the address records it names
+ if (hasDnsSdPtrRecordAnswer || hasDnsSdSrvRecordAnswer) {
+ for (RecordInfo<?> record : mGeneralRecords) {
+ if (record.record instanceof MdnsInetAddressRecord) {
+ additionalAnswerInfo.add(record);
+ }
+ }
+ }
+
+ for (RecordInfo<?> info : additionalAnswerInfo) {
+ additionalAnswerRecords.add(info.record);
+ }
+
+ // RFC6762 6.1: negative responses
+ // Covers both the answers added by this call and the additional records chosen above
+ addNsecRecordsForUniqueNames(additionalAnswerRecords,
+ answerInfo.listIterator(answersStartIndex),
+ additionalAnswerInfo.listIterator());
+ }
+
+ /**
* Add NSEC records indicating that the response records are unique.
*
* Following RFC6762 6.1:
@@ -540,6 +763,7 @@
final long now = SystemClock.elapsedRealtime();
for (RecordInfo<?> record : registration.allRecords) {
record.lastSentTimeMs = now;
+ record.lastAdvertisedTimeMs = now;
}
}
diff --git a/service-t/src/com/android/server/mdns/MdnsReplySender.java b/service-t/src/com/android/server/mdns/MdnsReplySender.java
index c6b8f47..f1389ca 100644
--- a/service-t/src/com/android/server/mdns/MdnsReplySender.java
+++ b/service-t/src/com/android/server/mdns/MdnsReplySender.java
@@ -16,8 +16,15 @@
package com.android.server.connectivity.mdns;
+import static com.android.server.connectivity.mdns.MdnsSocketProvider.ensureRunningOnHandlerThread;
+
import android.annotation.NonNull;
+import android.os.Handler;
import android.os.Looper;
+import android.os.Message;
+import android.util.Log;
+
+import com.android.server.connectivity.mdns.MdnsRecordRepository.ReplyInfo;
import java.io.IOException;
import java.net.DatagramPacket;
@@ -25,6 +32,7 @@
import java.net.Inet6Address;
import java.net.InetSocketAddress;
import java.net.MulticastSocket;
+import java.util.Collections;
/**
* A class that handles sending mDNS replies to a {@link MulticastSocket}, possibly queueing them
@@ -33,30 +41,46 @@
* TODO: implement sending after a delay, combining queued replies and duplicate answer suppression
*/
public class MdnsReplySender {
+ private static final boolean DBG = MdnsAdvertiser.DBG;
+ private static final int MSG_SEND = 1;
+
+ private final String mLogTag;
@NonNull
private final MdnsInterfaceSocket mSocket;
@NonNull
- private final Looper mLooper;
+ private final Handler mHandler;
@NonNull
private final byte[] mPacketCreationBuffer;
- public MdnsReplySender(@NonNull Looper looper,
+ /**
+ * @param interfaceTag Tag of the interface, appended to this sender's log tag.
+ * @param looper Looper of the handler that sends replies; queueReply, sendNow and cancelAll
+ * must be called on its thread.
+ * @param socket Socket to send replies on.
+ * @param packetCreationBuffer Buffer presumably used to serialize packets before sending
+ * (confirm in sendNow).
+ */
+ public MdnsReplySender(@NonNull String interfaceTag, @NonNull Looper looper,
@NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer) {
- mLooper = looper;
+ mHandler = new SendHandler(looper);
+ mLogTag = MdnsReplySender.class.getSimpleName() + "/" + interfaceTag;
mSocket = socket;
mPacketCreationBuffer = packetCreationBuffer;
}
/**
+ * Queue a reply to be sent when its send delay expires.
+ */
+ public void queueReply(@NonNull ReplyInfo reply) {
+ ensureRunningOnHandlerThread(mHandler);
+ // TODO: implement response aggregation (RFC 6762 6.4)
+ mHandler.sendMessageDelayed(mHandler.obtainMessage(MSG_SEND, reply), reply.sendDelayMs);
+
+ if (DBG) {
+ Log.v(mLogTag, "Scheduling " + reply);
+ }
+ }
+
+ /**
* Send a packet immediately.
*
* Must be called on the looper thread used by the {@link MdnsReplySender}.
*/
public void sendNow(@NonNull MdnsPacket packet, @NonNull InetSocketAddress destination)
throws IOException {
- if (Thread.currentThread() != mLooper.getThread()) {
- throw new IllegalStateException("sendNow must be called in the handler thread");
- }
+ ensureRunningOnHandlerThread(mHandler);
if (!((destination.getAddress() instanceof Inet6Address && mSocket.hasJoinedIpv6())
|| (destination.getAddress() instanceof Inet4Address && mSocket.hasJoinedIpv4()))) {
// Skip sending if the socket has not joined the v4/v6 group (there was no address)
@@ -93,4 +117,37 @@
mSocket.send(new DatagramPacket(outBuffer, 0, len, destination));
}
+
+ /**
+ * Cancel all replies queued via {@link #queueReply} that have not been sent yet.
+ *
+ * Must be called on the looper thread used by the {@link MdnsReplySender}.
+ */
+ public void cancelAll() {
+ ensureRunningOnHandlerThread(mHandler);
+ mHandler.removeMessages(MSG_SEND);
+ }
+
+ /** Handler that sends the replies queued by {@link #queueReply} once their delay expires. */
+ private class SendHandler extends Handler {
+ SendHandler(@NonNull Looper looper) {
+ super(looper);
+ }
+
+ @Override
+ public void handleMessage(@NonNull Message msg) {
+ final ReplyInfo reply = (ReplyInfo) msg.obj;
+ if (DBG) {
+ Log.v(mLogTag, "Sending " + reply);
+ }
+
+ // Response (QR, RFC6762 18.2) with the authoritative answer bit (AA, 0x0400, 18.4)
+ final int responseFlags = MdnsConstants.FLAGS_RESPONSE | 0x0400;
+ final MdnsPacket response = new MdnsPacket(responseFlags,
+ Collections.emptyList() /* questions */,
+ reply.answers,
+ Collections.emptyList() /* authorityRecords */,
+ reply.additionalAnswers);
+ try {
+ sendNow(response, reply.destination);
+ } catch (IOException e) {
+ Log.e(mLogTag, "Error sending MDNS response", e);
+ }
+ }
+ }
}
diff --git a/service-t/src/com/android/server/mdns/MulticastPacketReader.java b/service-t/src/com/android/server/mdns/MulticastPacketReader.java
index 20cc47f..b597f0a 100644
--- a/service-t/src/com/android/server/mdns/MulticastPacketReader.java
+++ b/service-t/src/com/android/server/mdns/MulticastPacketReader.java
@@ -107,5 +107,14 @@
ensureRunningOnHandlerThread(mHandler);
mPacketHandlers.add(handler);
}
+
+ /**
+ * Remove a packet handler added via {@link #addPacketHandler}. If the handler was not set,
+ * this is a no-op.
+ *
+ * Must be called on the handler thread of this reader (checked below).
+ *
+ * @param handler The handler to remove.
+ */
+ public void removePacketHandler(@NonNull PacketHandler handler) {
+ ensureRunningOnHandlerThread(mHandler);
+ mPacketHandlers.remove(handler);
+ }
}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
index e2babb1..1febe6d 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
@@ -38,6 +38,7 @@
import org.mockito.Mockito.any
import org.mockito.Mockito.anyInt
import org.mockito.Mockito.argThat
+import org.mockito.Mockito.atLeastOnce
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock
import org.mockito.Mockito.never
@@ -161,6 +162,60 @@
verify(socketProvider).unrequestSocket(socketCb)
}
+    @Test
+    fun testAddService_Conflicts() {
+        // SERVICE_1 (TEST_NETWORK_1 only) and ALL_NETWORKS_SERVICE share the same name. Once
+        // both are on the same interface, the second one is expected to be advertised renamed
+        // with a " (2)" suffix, and each registration reports success on its own.
+        val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps)
+        postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) }
+
+        val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
+        verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture())
+        val oneNetSocketCb = oneNetSocketCbCaptor.value
+
+        // Register a service with the same name on all networks (name conflict)
+        postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE) }
+        val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
+        verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture())
+        val allNetSocketCb = allNetSocketCbCaptor.value
+
+        // Callbacks for matching network and all networks both get the socket
+        postSync {
+            oneNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
+            allNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
+        }
+
+        val expectedRenamed = NsdServiceInfo(
+                "${ALL_NETWORKS_SERVICE.serviceName} (2)", ALL_NETWORKS_SERVICE.serviceType).apply {
+            port = ALL_NETWORKS_SERVICE.port
+            host = ALL_NETWORKS_SERVICE.host
+            network = ALL_NETWORKS_SERVICE.network
+        }
+
+        val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
+        verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)),
+                eq(thread.looper), any(), intAdvCbCaptor.capture())
+        verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1),
+                argThat { it.matches(SERVICE_1) })
+        verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_2),
+                argThat { it.matches(expectedRenamed) })
+
+        doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1)
+        postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
+                mockInterfaceAdvertiser1, SERVICE_ID_1) }
+        verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_1), argThat { it.matches(SERVICE_1) })
+        // Probing success for SERVICE_ID_1 must not be reported as success for SERVICE_ID_2,
+        // whose probing has not been reported as finished yet.
+        verify(cb, never()).onRegisterServiceSucceeded(eq(SERVICE_ID_2), any())
+
+        doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_2)
+        postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
+                mockInterfaceAdvertiser1, SERVICE_ID_2) }
+        verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_2),
+                argThat { it.matches(expectedRenamed) })
+
+        postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
+        postSync { allNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
+
+        // destroyNow can be called multiple times
+        verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow()
+    }
+
private fun postSync(r: () -> Unit) {
handler.post(r)
handler.waitForIdle(TIMEOUT_MS)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
index 334f99d..6c3f729 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
@@ -79,7 +79,7 @@
@Test
fun testAnnounce() {
- val replySender = MdnsReplySender(thread.looper, socket, buffer)
+ val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
@Suppress("UNCHECKED_CAST")
val cb = mock(MdnsPacketRepeater.PacketRepeaterCallback::class.java)
as MdnsPacketRepeater.PacketRepeaterCallback<BaseAnnouncementInfo>
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index 2cb0850..02b3976 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -21,6 +21,7 @@
import android.net.nsd.NsdServiceInfo
import android.os.Build
import android.os.HandlerThread
+import com.android.net.module.util.HexDump
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo
import com.android.server.connectivity.mdns.MdnsAnnouncer.ExitAnnouncementInfo
@@ -30,6 +31,10 @@
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
import com.android.testutils.DevSdkIgnoreRunner
import com.android.testutils.waitForIdle
+import java.net.InetSocketAddress
+import kotlin.test.assertContentEquals
+import kotlin.test.assertEquals
+import kotlin.test.assertTrue
import org.junit.After
import org.junit.Before
import org.junit.Test
@@ -37,8 +42,10 @@
import org.mockito.ArgumentCaptor
import org.mockito.Mockito.any
import org.mockito.Mockito.anyInt
+import org.mockito.Mockito.anyString
import org.mockito.Mockito.doAnswer
import org.mockito.Mockito.doReturn
+import org.mockito.Mockito.eq
import org.mockito.Mockito.mock
import org.mockito.Mockito.times
import org.mockito.Mockito.verify
@@ -67,13 +74,18 @@
private val replySender = mock(MdnsReplySender::class.java)
private val announcer = mock(MdnsAnnouncer::class.java)
private val prober = mock(MdnsProber::class.java)
+ @Suppress("UNCHECKED_CAST")
private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>>
+ @Suppress("UNCHECKED_CAST")
private val announceCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
as ArgumentCaptor<PacketRepeaterCallback<BaseAnnouncementInfo>>
+ private val packetHandlerCaptor = ArgumentCaptor.forClass(
+ MulticastPacketReader.PacketHandler::class.java)
private val probeCb get() = probeCbCaptor.value
private val announceCb get() = announceCbCaptor.value
+ private val packetHandler get() = packetHandlerCaptor.value
private val advertiser by lazy {
MdnsInterfaceAdvertiser(LOG_TAG, socket, TEST_ADDRS, thread.looper, TEST_BUFFER, cb, deps)
@@ -82,9 +94,9 @@
@Before
fun setUp() {
doReturn(repository).`when`(deps).makeRecordRepository(any())
- doReturn(replySender).`when`(deps).makeReplySender(any(), any(), any())
- doReturn(announcer).`when`(deps).makeMdnsAnnouncer(any(), any(), any(), any())
- doReturn(prober).`when`(deps).makeMdnsProber(any(), any(), any(), any())
+ doReturn(replySender).`when`(deps).makeReplySender(anyString(), any(), any(), any())
+ doReturn(announcer).`when`(deps).makeMdnsAnnouncer(anyString(), any(), any(), any())
+ doReturn(prober).`when`(deps).makeMdnsProber(anyString(), any(), any(), any())
val knownServices = mutableSetOf<Int>()
doAnswer { inv ->
@@ -104,6 +116,7 @@
thread.start()
advertiser.start()
+ verify(socket).addPacketHandler(packetHandlerCaptor.capture())
verify(deps).makeMdnsProber(any(), any(), any(), probeCbCaptor.capture())
verify(deps).makeMdnsAnnouncer(any(), any(), any(), announceCbCaptor.capture())
}
@@ -157,6 +170,39 @@
verify(announcer, times(1)).stop(TEST_SERVICE_ID_1)
}
+    @Test
+    fun testReplyToQuery() {
+        addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+
+        val reply = mock(MdnsRecordRepository.ReplyInfo::class.java)
+        doReturn(reply).`when`(repository).getReply(any(), any())
+
+        // Raw PTR query for _testservice._tcp.local, generated with:
+        // scapy.raw(scapy.DNS(
+        //  qd = scapy.DNSQR(qtype='PTR', qname='_testservice._tcp.local'))
+        // ).hex().upper()
+        val queryBytes = HexDump.hexStringToByteArray(
+            "0000010000010000000000000C5F7465737473657276696365045F746370056C6F63616C00000C0001"
+        )
+        val source = InetSocketAddress(parseNumericAddress("2001:db8::456"), MdnsConstants.MDNS_PORT)
+        packetHandler.handlePacket(queryBytes, queryBytes.size, source)
+
+        // The parsed query and its source are forwarded to the repository
+        val parsedQueryCaptor = ArgumentCaptor.forClass(MdnsPacket::class.java)
+        verify(repository).getReply(parsedQueryCaptor.capture(), eq(source))
+
+        with(parsedQueryCaptor.value) {
+            assertEquals(1, questions.size)
+            assertEquals(0, answers.size)
+            assertEquals(0, authorityRecords.size)
+            assertEquals(0, additionalRecords.size)
+
+            val question = questions[0]
+            assertTrue(question is MdnsPointerRecord)
+            assertContentEquals(arrayOf("_testservice", "_tcp", "local"), question.name)
+        }
+
+        // The reply built by the repository is handed to the reply sender unchanged
+        verify(replySender).queueReply(reply)
+    }
+
private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
AnnouncementInfo {
val testProbingInfo = mock(ProbingInfo::class.java)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
index 3caa97d..a2dbbc6 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
@@ -114,7 +114,7 @@
@Test
fun testProbe() {
- val replySender = MdnsReplySender(thread.looper, socket, buffer)
+ val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
val prober = TestProber(thread.looper, replySender, cb)
val probeInfo = TestProbeInfo(
listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)))
@@ -129,7 +129,7 @@
@Test
fun testProbeMultipleRecords() {
- val replySender = MdnsReplySender(thread.looper, socket, buffer)
+ val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
val prober = TestProber(thread.looper, replySender, cb)
val probeInfo = TestProbeInfo(listOf(
makeServiceRecord(TEST_SERVICE_NAME_1, 37890),
@@ -167,7 +167,7 @@
@Test
fun testStopProbing() {
- val replySender = MdnsReplySender(thread.looper, socket, buffer)
+ val replySender = MdnsReplySender("testiface", thread.looper, socket, buffer)
val prober = TestProber(thread.looper, replySender, cb)
val probeInfo = TestProbeInfo(
listOf(makeServiceRecord(TEST_SERVICE_NAME_1, 37890)),
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index 29d0854..597663c 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -21,10 +21,12 @@
import android.net.nsd.NsdServiceInfo
import android.os.Build
import android.os.HandlerThread
+import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
import com.android.server.connectivity.mdns.MdnsRecordRepository.Dependencies
import com.android.server.connectivity.mdns.MdnsRecordRepository.getReverseDnsAddress
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner
+import java.net.InetSocketAddress
import java.net.NetworkInterface
import java.util.Collections
import kotlin.test.assertContentEquals
@@ -150,11 +152,7 @@
@Test
fun testExitAnnouncements() {
val repository = MdnsRecordRepository(thread.looper, deps)
- repository.updateAddresses(TEST_ADDRESSES)
-
- repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
- val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
- repository.onProbingSucceeded(probingInfo)
+ repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.onAdvertisementSent(TEST_SERVICE_ID_1)
val exitAnnouncement = repository.exitService(TEST_SERVICE_ID_1)
@@ -183,9 +181,7 @@
@Test
fun testExitingServiceReAdded() {
val repository = MdnsRecordRepository(thread.looper, deps)
- repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
- val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
- repository.onProbingSucceeded(probingInfo)
+ repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.onAdvertisementSent(TEST_SERVICE_ID_1)
repository.exitService(TEST_SERVICE_ID_1)
@@ -199,11 +195,8 @@
@Test
fun testOnProbingSucceeded() {
val repository = MdnsRecordRepository(thread.looper, deps)
- repository.updateAddresses(TEST_ADDRESSES)
-
- repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
- val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
- val announcementInfo = repository.onProbingSucceeded(probingInfo)
+ val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ repository.onAdvertisementSent(TEST_SERVICE_ID_1)
val packet = announcementInfo.getPacket(0)
assertEquals(0x8400 /* response, authoritative */, packet.flags)
@@ -322,4 +315,98 @@
val expectedV4 = "123.2.0.192.in-addr.arpa".split(".").toTypedArray()
assertContentEquals(expectedV4, getReverseDnsAddress(parseNumericAddress("192.0.2.123")))
}
+
+    @Test
+    fun testGetReply() {
+        val repository = MdnsRecordRepository(thread.looper, deps)
+        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        val questions = listOf(MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"),
+                0L /* receiptTimeMillis */,
+                false /* cacheFlush */,
+                // TTL and data are empty for a question
+                0L /* ttlMillis */,
+                null /* pointer */))
+        val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
+                listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+        // Query from an IPv4 source on the standard mDNS port
+        val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), MdnsConstants.MDNS_PORT)
+        val reply = repository.getReply(query, src)
+
+        assertNotNull(reply)
+        // Source address is IPv4, so the reply is expected to go to the IPv4 mDNS group
+        assertEquals(MdnsConstants.getMdnsIPv4Address(), reply.destination.address)
+        assertEquals(MdnsConstants.MDNS_PORT, reply.destination.port)
+
+        // TTLs as per RFC6762 10.
+        val longTtl = 4_500_000L
+        val shortTtl = 120_000L
+        val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
+
+        assertEquals(listOf(
+                MdnsPointerRecord(
+                        arrayOf("_testservice", "_tcp", "local"),
+                        0L /* receiptTimeMillis */,
+                        false /* cacheFlush */,
+                        longTtl,
+                        serviceName),
+        ), reply.answers)
+
+        assertEquals(listOf(
+                MdnsTextRecord(
+                        serviceName,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        longTtl,
+                        listOf() /* entries */),
+                MdnsServiceRecord(
+                        serviceName,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        shortTtl,
+                        0 /* servicePriority */,
+                        0 /* serviceWeight */,
+                        TEST_PORT,
+                        TEST_HOSTNAME),
+                MdnsInetAddressRecord(
+                        TEST_HOSTNAME,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        shortTtl,
+                        TEST_ADDRESSES[0].address),
+                MdnsInetAddressRecord(
+                        TEST_HOSTNAME,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        shortTtl,
+                        TEST_ADDRESSES[1].address),
+                MdnsInetAddressRecord(
+                        TEST_HOSTNAME,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        shortTtl,
+                        TEST_ADDRESSES[2].address),
+                MdnsNsecRecord(
+                        serviceName,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        longTtl,
+                        serviceName /* nextDomain */,
+                        intArrayOf(MdnsRecord.TYPE_TXT, MdnsRecord.TYPE_SRV)),
+                MdnsNsecRecord(
+                        TEST_HOSTNAME,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        shortTtl,
+                        TEST_HOSTNAME /* nextDomain */,
+                        intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA)),
+        ), reply.additionalAnswers)
+    }
+}
+
+/**
+ * Test helper: sets [TEST_ADDRESSES] on this repository, adds the given service, marks it as
+ * probing and then as successfully probed. Fails the test if [setServiceProbing] returns null.
+ *
+ * @return the [AnnouncementInfo] returned by [MdnsRecordRepository.onProbingSucceeded].
+ */
+private fun MdnsRecordRepository.initWithService(serviceId: Int, serviceInfo: NsdServiceInfo):
+        AnnouncementInfo {
+    updateAddresses(TEST_ADDRESSES)
+    addService(serviceId, serviceInfo)
+    val probingInfo = setServiceProbing(serviceId)
+    assertNotNull(probingInfo)
+    return onProbingSucceeded(probingInfo)
}