Merge changes I3b1ad1be,Id4c2e610
* changes:
Implement exit announcements
Implement announcements on probing success
diff --git a/service-t/src/com/android/server/mdns/MdnsAnnouncer.java b/service-t/src/com/android/server/mdns/MdnsAnnouncer.java
index 7c84323..27fc945 100644
--- a/service-t/src/com/android/server/mdns/MdnsAnnouncer.java
+++ b/service-t/src/com/android/server/mdns/MdnsAnnouncer.java
@@ -30,23 +30,27 @@
*
* This allows maintaining other hosts' caches up-to-date. See RFC6762 8.3.
*/
-public class MdnsAnnouncer extends MdnsPacketRepeater<MdnsAnnouncer.AnnouncementInfo> {
+public class MdnsAnnouncer extends MdnsPacketRepeater<MdnsAnnouncer.BaseAnnouncementInfo> {
private static final long ANNOUNCEMENT_INITIAL_DELAY_MS = 1000L;
@VisibleForTesting
static final int ANNOUNCEMENT_COUNT = 8;
+ // Matches delay and GoodbyeCount used by the legacy implementation
+ private static final long EXIT_DELAY_MS = 2000L;
+ private static final int EXIT_COUNT = 3;
+
@NonNull
private final String mLogTag;
- /** Announcement request to send with {@link MdnsAnnouncer}. */
- public static class AnnouncementInfo implements MdnsPacketRepeater.Request {
+ /** Base class for announcement requests to send with {@link MdnsAnnouncer}. */
+ public abstract static class BaseAnnouncementInfo implements MdnsPacketRepeater.Request {
+ private final int mServiceId;
@NonNull
private final MdnsPacket mPacket;
- AnnouncementInfo(List<MdnsRecord> announcedRecords, List<MdnsRecord> additionalRecords) {
- // Records to announce (as answers)
- // Records to place in the "Additional records", with NSEC negative responses
- // to mark records that have been verified unique
+ protected BaseAnnouncementInfo(int serviceId, @NonNull List<MdnsRecord> announcedRecords,
+ @NonNull List<MdnsRecord> additionalRecords) {
+ mServiceId = serviceId;
final int flags = 0x8400; // Response, authoritative (rfc6762 18.4)
mPacket = new MdnsPacket(flags,
Collections.emptyList() /* questions */,
@@ -55,10 +59,22 @@
additionalRecords);
}
+ public int getServiceId() {
+ return mServiceId;
+ }
+
@Override
public MdnsPacket getPacket(int index) {
return mPacket;
}
+ }
+
+ /** Announcement request to send with {@link MdnsAnnouncer}. */
+ public static class AnnouncementInfo extends BaseAnnouncementInfo {
+ AnnouncementInfo(int serviceId, List<MdnsRecord> announcedRecords,
+ List<MdnsRecord> additionalRecords) {
+ super(serviceId, announcedRecords, additionalRecords);
+ }
@Override
public long getDelayMs(int nextIndex) {
@@ -72,9 +88,26 @@
}
}
+ /**
+ * Service exit ("goodbye") announcement request to send with {@link MdnsAnnouncer}.
+ *
+ * The packet only carries the records being withdrawn: no additional records are included.
+ * It is sent {@code EXIT_COUNT} times, with {@code EXIT_DELAY_MS} returned as the delay for
+ * every send.
+ */
+ public static class ExitAnnouncementInfo extends BaseAnnouncementInfo {
+ ExitAnnouncementInfo(int serviceId, List<MdnsRecord> goodbyeRecords) {
+ super(serviceId, goodbyeRecords, Collections.emptyList() /* additionalRecords */);
+ }
+
+ @Override
+ public int getNumSends() {
+ // Same repeat count as the legacy implementation's goodbye packets
+ return EXIT_COUNT;
+ }
+
+ @Override
+ public long getDelayMs(int nextIndex) {
+ // Constant spacing, independent of which send comes next
+ return EXIT_DELAY_MS;
+ }
+ }
+
+ /**
+ * Create an announcer.
+ *
+ * @param interfaceTag tag appended to this announcer's log tag (presumably identifying the
+ * interface; confirm with callers).
+ * @param looper looper passed to {@link MdnsPacketRepeater}.
+ * @param replySender sender passed to {@link MdnsPacketRepeater}.
+ * @param cb optional callback, notified for both regular and exit announcements.
+ */
public MdnsAnnouncer(@NonNull String interfaceTag, @NonNull Looper looper,
@NonNull MdnsReplySender replySender,
- @Nullable PacketRepeaterCallback<AnnouncementInfo> cb) {
+ @Nullable PacketRepeaterCallback<BaseAnnouncementInfo> cb) {
super(looper, replySender, cb);
mLogTag = MdnsAnnouncer.class.getSimpleName() + "/" + interfaceTag;
}
diff --git a/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
index 997dcbb..790e69a 100644
--- a/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/mdns/MdnsInterfaceAdvertiser.java
@@ -25,6 +25,7 @@
import android.util.Log;
import com.android.internal.annotations.VisibleForTesting;
+import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo;
import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback;
import java.io.IOException;
@@ -113,9 +114,22 @@
/**
* Callbacks from {@link MdnsAnnouncer}.
*/
- private class AnnouncingCallback
- implements PacketRepeaterCallback<MdnsAnnouncer.AnnouncementInfo> {
- // TODO: implement
+ private class AnnouncingCallback implements PacketRepeaterCallback<BaseAnnouncementInfo> {
+ // Every sent packet (regular or exit announcement) updates the send time of the service
+ // records in the repository; exitService relies on it to know whether a goodbye is needed.
+ @Override
+ public void onSent(int index, @NonNull BaseAnnouncementInfo info) {
+ mRecordRepository.onAdvertisementSent(info.getServiceId());
+ }
+
+ // Only exit announcements need handling on completion: once the last goodbye packet is
+ // sent the service is dropped, and the advertiser destroys itself if no service is left.
+ @Override
+ public void onFinished(@NonNull BaseAnnouncementInfo info) {
+ if (info instanceof MdnsAnnouncer.ExitAnnouncementInfo) {
+ mRecordRepository.removeService(info.getServiceId());
+
+ if (mRecordRepository.getServicesCount() == 0) {
+ destroyNow();
+ }
+ }
+ }
}
/**
@@ -139,7 +153,7 @@
/** @see MdnsAnnouncer */
public MdnsAnnouncer makeMdnsAnnouncer(@NonNull String interfaceTag, @NonNull Looper looper,
@NonNull MdnsReplySender replySender,
- @Nullable PacketRepeaterCallback<MdnsAnnouncer.AnnouncementInfo> cb) {
+ @Nullable PacketRepeaterCallback<MdnsAnnouncer.BaseAnnouncementInfo> cb) {
+ // Factory method: tests stub it to inject a mock announcer (see MdnsInterfaceAdvertiserTest)
return new MdnsAnnouncer(interfaceTag, looper, replySender, cb);
}
@@ -210,9 +224,10 @@
* This will trigger exit announcements for the service.
*/
public void removeService(int id) {
+ if (!mRecordRepository.hasActiveService(id)) return;
mProber.stop(id);
mAnnouncer.stop(id);
- final MdnsAnnouncer.AnnouncementInfo exitInfo = mRecordRepository.exitService(id);
+ final MdnsAnnouncer.ExitAnnouncementInfo exitInfo = mRecordRepository.exitService(id);
if (exitInfo != null) {
// This effectively schedules destroyNow(), as it is to be called when the exit
// announcement finishes if there is no service left.
diff --git a/service-t/src/com/android/server/mdns/MdnsNsecRecord.java b/service-t/src/com/android/server/mdns/MdnsNsecRecord.java
index 06fdd5e..6ec2f99 100644
--- a/service-t/src/com/android/server/mdns/MdnsNsecRecord.java
+++ b/service-t/src/com/android/server/mdns/MdnsNsecRecord.java
@@ -17,12 +17,14 @@
package com.android.server.connectivity.mdns;
import android.net.DnsResolver;
+import android.text.TextUtils;
import com.android.net.module.util.CollectionUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Objects;
/**
* A mDNS "NSEC" record, used in particular for negative responses (RFC6762 6.1).
@@ -140,4 +142,30 @@
writer.writeUInt8(bytesInBlock);
writer.writeBytes(bytes);
}
+
+ /** Debug representation: next domain name (dot-joined) and the list of record types. */
+ @Override
+ public String toString() {
+ return "NSEC: NextDomain: " + TextUtils.join(".", mNextDomain)
+ + " Types: " + Arrays.toString(mTypes);
+ }
+
+ /** Hash consistent with {@link #equals}: base record hash, next domain and types. */
+ @Override
+ public int hashCode() {
+ return Objects.hash(super.hashCode(),
+ Arrays.hashCode(mNextDomain), Arrays.hashCode(mTypes));
+ }
+
+ /**
+ * Compare with another NSEC record: fields compared by super.equals, plus the next domain
+ * name and the type list.
+ */
+ @Override
+ public boolean equals(Object other) {
+ if (this == other) {
+ return true;
+ }
+ if (!(other instanceof MdnsNsecRecord)) {
+ return false;
+ }
+ final MdnsNsecRecord otherRecord = (MdnsNsecRecord) other;
+ // Arrays must be compared by content; their own equals() is identity-based
+ return super.equals(other)
+ && Arrays.equals(mNextDomain, otherRecord.mNextDomain)
+ && Arrays.equals(mTypes, otherRecord.mTypes);
+ }
}
diff --git a/service-t/src/com/android/server/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
index bb9c751..dd00212 100644
--- a/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/mdns/MdnsRecordRepository.java
@@ -18,21 +18,32 @@
import android.annotation.NonNull;
import android.annotation.Nullable;
+import android.annotation.TargetApi;
import android.net.LinkAddress;
import android.net.nsd.NsdServiceInfo;
+import android.os.Build;
import android.os.Looper;
+import android.os.SystemClock;
+import android.util.ArraySet;
import android.util.SparseArray;
import com.android.internal.annotations.VisibleForTesting;
+import com.android.net.module.util.CollectionUtils;
+import com.android.net.module.util.HexDump;
import java.io.IOException;
+import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.NetworkInterface;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
+import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
@@ -41,6 +52,7 @@
*
* Must be used on a consistent looper thread.
*/
+@TargetApi(Build.VERSION_CODES.TIRAMISU) // Allow calling T+ APIs; this is only loaded on T+
public class MdnsRecordRepository {
// TTLs as per RFC6762 10.
// TTL for records with a host name as the resource record's name (e.g., A, AAAA, HINFO) or a
@@ -61,6 +73,8 @@
@NonNull
private final SparseArray<ServiceRegistration> mServices = new SparseArray<>();
@NonNull
+ private final List<RecordInfo<?>> mGeneralRecords = new ArrayList<>();
+ @NonNull
private final Looper mLooper;
@NonNull
private String[] mDeviceHostname;
@@ -124,6 +138,12 @@
*/
public boolean isProbing;
+ /**
+ * Last time (as per SystemClock.elapsedRealtime) when sent via unicast or multicast,
+ * 0 if never
+ */
+ public long lastSentTimeMs;
+
RecordInfo(NsdServiceInfo serviceInfo, T record, boolean sharedName,
boolean probing) {
this.serviceInfo = serviceInfo;
@@ -221,7 +241,29 @@
* Inform the repository of the latest interface addresses.
*/
public void updateAddresses(@NonNull List<LinkAddress> newAddresses) {
- // TODO: implement to update addresses in records
+ // Rebuild the address records from scratch. For every address: a reverse-lookup PTR
+ // record pointing to the device hostname, then the address record for the hostname.
+ mGeneralRecords.clear();
+ for (LinkAddress linkAddress : newAddresses) {
+ final InetAddress address = linkAddress.getAddress();
+
+ final MdnsPointerRecord reversePtr = new MdnsPointerRecord(
+ getReverseDnsAddress(address),
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ NAME_RECORDS_TTL_MILLIS,
+ mDeviceHostname);
+ mGeneralRecords.add(new RecordInfo<>(null /* serviceInfo */, reversePtr,
+ false /* sharedName */, false /* probing */));
+
+ final MdnsInetAddressRecord hostAddress = new MdnsInetAddressRecord(
+ mDeviceHostname,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ NAME_RECORDS_TTL_MILLIS,
+ address);
+ mGeneralRecords.add(new RecordInfo<>(null /* serviceInfo */, hostAddress,
+ false /* sharedName */, false /* probing */));
+ }
}
/**
@@ -298,20 +340,31 @@
* @return The exit announcement to indicate the service was removed, or null if not necessary.
*/
@Nullable
- public MdnsAnnouncer.AnnouncementInfo exitService(int id) {
+ public MdnsAnnouncer.ExitAnnouncementInfo exitService(int id) {
+ // Unknown service, or exit already in progress: nothing to announce
final ServiceRegistration registration = mServices.get(id);
if (registration == null) return null;
if (registration.exiting) return null;
- registration.exiting = true;
+ // Send exit (TTL 0) for the PTR record, if the record was sent (in particular don't send
+ // if still probing)
+ if (registration.ptrRecord.lastSentTimeMs == 0L) {
+ return null;
+ }
- // TODO: implement
- return null;
+ registration.exiting = true;
+ // NOTE(review): the service-type PTR is a shared record (it is announced with
+ // cacheFlush=false, see testOnProbingSucceeded), yet this goodbye sets cacheFlush=true.
+ // RFC6762 10.2 reserves the cache-flush bit for unique records; receivers honoring it may
+ // flush other hosts' PTR records for the same service type. Confirm; changing it also
+ // requires updating testExitAnnouncements, which expects cacheFlush=true.
+ final MdnsPointerRecord expiredRecord = new MdnsPointerRecord(
+ registration.ptrRecord.record.getName(),
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ 0L /* ttlMillis */,
+ registration.ptrRecord.record.getPointer());
+
+ // Exit should be skipped if the record is still advertised by another service, but that
+ // would be a conflict (2 service registrations with the same service name), so it would
+ // not have been allowed by the repository.
+ return new MdnsAnnouncer.ExitAnnouncementInfo(id, Collections.singletonList(expiredRecord));
}
- /**
- * Remove a service from the repository
- */
+ /**
+ * Remove a service from the repository.
+ *
+ * Unlike {@link #exitService}, this does not generate exit announcements; for example
+ * MdnsInterfaceAdvertiser calls it once the exit announcements for the service have finished.
+ * Unknown IDs are ignored.
+ */
public void removeService(int id) {
mServices.remove(id);
}
@@ -338,14 +391,110 @@
}
/**
+ * Append NSEC records asserting that unique (non-shared) names in the answers own no other
+ * record types.
+ *
+ * RFC6762 6.1 allows this: "On receipt of a question for a particular name, rrtype, and
+ * rrclass, for which a responder does have one or more unique answers, the responder MAY also
+ * include an NSEC record in the Additional Record Section indicating the nonexistence of other
+ * rrtypes for that name and rrclass."
+ *
+ * One NSEC record is added per distinct unique name, in the order names first appear. Each
+ * lists the types of the unique records with that name, and uses the smallest of their TTLs.
+ * @param destinationList List to append the NSEC records to.
+ * @param answerRecords Iterators over the answered records (typically the answer and
+ * additional answer sections); they are consumed.
+ */
+ @SafeVarargs
+ private static void addNsecRecordsForUniqueNames(
+ List<MdnsRecord> destinationList,
+ Iterator<RecordInfo<?>>... answerRecords) {
+ // Arrays use identity equals/hashCode, so names are grouped in a TreeMap comparing
+ // contents. A separate list keeps first-occurrence order, so output follows the order of
+ // the original records instead of the alphabetical order of their names.
+ final Map<String[], List<MdnsRecord>> recordsByName = new TreeMap<>(Arrays::compare);
+ final List<String[]> orderedNames = new ArrayList<>();
+ for (Iterator<RecordInfo<?>> section : answerRecords) {
+ addNonSharedRecordsToMap(section, recordsByName, orderedNames);
+ }
+
+ for (String[] name : orderedNames) {
+ final List<MdnsRecord> sameNameRecords = recordsByName.get(name);
+ final Set<Integer> rrTypes = new ArraySet<>(sameNameRecords.size());
+ long ttl = Long.MAX_VALUE;
+ for (MdnsRecord sameNameRecord : sameNameRecords) {
+ ttl = Math.min(ttl, sameNameRecord.getTtl());
+ rrTypes.add(sameNameRecord.getType());
+ }
+
+ // The NSEC "next domain" is the record's own name (RFC6762 6.1)
+ destinationList.add(new MdnsNsecRecord(
+ name,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ ttl,
+ name,
+ CollectionUtils.toIntArray(rrTypes)));
+ }
+ }
+
+ /**
+ * Collect records whose name is not shared, grouped by name.
+ *
+ * Each key of {@code dest} is a record name, mapped to every collected record with that name.
+ * The first time a name is seen it is also appended to {@code namesInAddedOrder}, so that list
+ * keeps the first-occurrence order of names.
+ * @param records Records to scan; the iterator is consumed.
+ * @param dest Map receiving the records, keyed by name.
+ * @param namesInAddedOrder List receiving each new name, in first-occurrence order.
+ */
+ private static void addNonSharedRecordsToMap(
+ Iterator<RecordInfo<?>> records,
+ Map<String[], List<MdnsRecord>> dest,
+ List<String[]> namesInAddedOrder) {
+ records.forEachRemaining(info -> {
+ if (info.isSharedName) return;
+ final String[] recordName = info.record.name;
+ List<MdnsRecord> sameName = dest.get(recordName);
+ if (sameName == null) {
+ sameName = new ArrayList<>();
+ dest.put(recordName, sameName);
+ namesInAddedOrder.add(recordName);
+ }
+ sameName.add(info.record);
+ });
+ }
+
+ /**
* Called to indicate that probing succeeded for a service.
* @param probeSuccessInfo The successful probing info.
* @return The {@link MdnsAnnouncer.AnnouncementInfo} to send, now that probing has succeeded.
+ * @throws IOException if the service is not registered in this repository.
*/
public MdnsAnnouncer.AnnouncementInfo onProbingSucceeded(
- MdnsProber.ProbingInfo probeSuccessInfo) throws IOException {
- // TODO: implement: set service as not probing anymore and generate announcements
- throw new IOException("Announcements not implemented");
+ MdnsProber.ProbingInfo probeSuccessInfo)
+ throws IOException {
+ final int serviceId = probeSuccessInfo.getServiceId();
+ final ServiceRegistration registration = mServices.get(serviceId);
+ if (registration == null) {
+ throw new IOException("Service is not registered: " + serviceId);
+ }
+ registration.setProbing(false);
+
+ final List<MdnsRecord> answers = new ArrayList<>();
+ final List<MdnsRecord> additionalAnswers = new ArrayList<>();
+
+ // Interface address records in general records, followed by all service records
+ for (RecordInfo<?> record : mGeneralRecords) {
+ answers.add(record.record);
+ }
+ for (RecordInfo<?> info : registration.allRecords) {
+ answers.add(info.record);
+ }
+
+ // NSEC records for the unique names, as additional records
+ addNsecRecordsForUniqueNames(additionalAnswers,
+ mGeneralRecords.iterator(), registration.allRecords.iterator());
+
+ return new MdnsAnnouncer.AnnouncementInfo(serviceId, answers, additionalAnswers);
}
/**
@@ -371,6 +520,60 @@
return registration.srvRecord.isProbing;
}
+ /**
+ * Check whether a service with the given ID is registered and not exiting.
+ *
+ * @param serviceId ID of the service to look up.
+ * @return false if the service is unknown, or if {@link #exitService} marked it as exiting.
+ */
+ public boolean hasActiveService(int serviceId) {
+ final ServiceRegistration found = mServices.get(serviceId);
+ return found != null && !found.exiting;
+ }
+
+ /**
+ * Called when {@link MdnsInterfaceAdvertiser} sent an advertisement for the given service.
+ *
+ * Marks every record of the service as sent now, using
+ * {@link SystemClock#elapsedRealtime()}. {@link #exitService} only returns exit announcements
+ * once the service PTR record has a non-zero send time. Unknown service IDs are ignored.
+ *
+ * NOTE(review): interface address records (mGeneralRecords) are included in announcements
+ * too, but their send time is not updated here.
+ * @param serviceId ID of the service that was advertised.
+ */
+ public void onAdvertisementSent(int serviceId) {
+ final ServiceRegistration registration = mServices.get(serviceId);
+ if (registration == null) return;
+
+ final long now = SystemClock.elapsedRealtime();
+ for (RecordInfo<?> record : registration.allRecords) {
+ record.lastSentTimeMs = now;
+ }
+ }
+
+ /**
+ * Build the reverse-lookup name of an address, as an array of labels.
+ *
+ * For example:
+ * 2001:db8::1 --> 1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.0.1.0.0.2.ip6.arpa
+ * 192.0.2.123 --> 123.2.0.192.in-addr.arpa
+ *
+ * IPv4 addresses give one decimal label per byte. Any other address is treated as IPv6 and
+ * gives one hex-digit label per nibble, in the case produced by HexDump.toHexString. In both
+ * cases the labels are in reverse order.
+ */
+ @VisibleForTesting
+ public static String[] getReverseDnsAddress(@NonNull InetAddress addr) {
+ final byte[] addrBytes = addr.getAddress();
+ final List<String> labels = new ArrayList<>();
+ final String suffix;
+ if (addr instanceof Inet4Address) {
+ // One decimal label per byte, least significant byte first
+ for (int i = addrBytes.length - 1; i >= 0; i--) {
+ labels.add(Integer.toString(addrBytes[i] & 0xff));
+ }
+ suffix = "in-addr";
+ } else {
+ // One label per hex digit (nibble), least significant nibble first
+ final char[] hexDigits = HexDump.toHexString(addrBytes).toCharArray();
+ for (int i = hexDigits.length - 1; i >= 0; i--) {
+ labels.add(String.valueOf(hexDigits[i]));
+ }
+ suffix = "ip6";
+ }
+ labels.add(suffix);
+ labels.add("arpa");
+ return labels.toArray(new String[0]);
+ }
+
private static String[] splitFullyQualifiedName(
@NonNull NsdServiceInfo info, @NonNull String[] serviceType) {
final String[] split = new String[serviceType.length + 1];
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
index 961f0f0..650607d 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
@@ -22,11 +22,11 @@
import android.os.SystemClock
import com.android.internal.util.HexDump
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
+import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo
+import com.android.server.connectivity.mdns.MdnsRecordRepository.getReverseDnsAddress
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
import com.android.testutils.DevSdkIgnoreRunner
import java.net.DatagramPacket
-import java.net.Inet6Address
-import java.net.InetAddress
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import org.junit.After
@@ -68,7 +68,7 @@
private class TestAnnouncementInfo(
announcedRecords: List<MdnsRecord>,
additionalRecords: List<MdnsRecord>
- ) : AnnouncementInfo(announcedRecords, additionalRecords) {
+ ) : AnnouncementInfo(1 /* serviceId */, announcedRecords, additionalRecords) {
override fun getDelayMs(nextIndex: Int) =
if (nextIndex < FIRST_ANNOUNCES_COUNT) {
FIRST_ANNOUNCES_DELAY
@@ -82,7 +82,7 @@
val replySender = MdnsReplySender(thread.looper, socket, buffer)
@Suppress("UNCHECKED_CAST")
val cb = mock(MdnsPacketRepeater.PacketRepeaterCallback::class.java)
- as MdnsPacketRepeater.PacketRepeaterCallback<AnnouncementInfo>
+ as MdnsPacketRepeater.PacketRepeaterCallback<BaseAnnouncementInfo>
val announcer = MdnsAnnouncer("testiface", thread.looper, replySender, cb)
/*
The expected packet replicates records announced when registering a service, as observed in
@@ -150,8 +150,8 @@
val v6Addr1 = parseNumericAddress("2001:DB8::123")
val v6Addr2 = parseNumericAddress("2001:DB8::456")
val v4AddrRev = arrayOf("123", "0", "2", "192", "in-addr", "arpa")
- val v6Addr1Rev = getReverseV6AddressName(v6Addr1)
- val v6Addr2Rev = getReverseV6AddressName(v6Addr2)
+ val v6Addr1Rev = getReverseDnsAddress(v6Addr1)
+ val v6Addr2Rev = getReverseDnsAddress(v6Addr2)
val announcedRecords = listOf(
// Reverse address records
@@ -267,13 +267,3 @@
}
}
}
-
-/**
- * Compute 2001:db8::1 --> 1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.1.0.0.2.ip6.arpa
- */
-private fun getReverseV6AddressName(addr: InetAddress): Array<String> {
- assertTrue(addr is Inet6Address)
- return addr.address.flatMapTo(mutableListOf("arpa", "ip6")) {
- HexDump.toHexString(it).toCharArray().map(Char::toString)
- }.reversed().toTypedArray()
-}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index ad22305..2cb0850 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -22,6 +22,8 @@
import android.os.Build
import android.os.HandlerThread
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
+import com.android.server.connectivity.mdns.MdnsAnnouncer.BaseAnnouncementInfo
+import com.android.server.connectivity.mdns.MdnsAnnouncer.ExitAnnouncementInfo
import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.EXIT_ANNOUNCEMENT_DELAY_MS
import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback
import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo
@@ -35,8 +37,10 @@
import org.mockito.ArgumentCaptor
import org.mockito.Mockito.any
import org.mockito.Mockito.anyInt
+import org.mockito.Mockito.doAnswer
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock
+import org.mockito.Mockito.times
import org.mockito.Mockito.verify
private const val LOG_TAG = "testlogtag"
@@ -66,7 +70,7 @@
private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>>
private val announceCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
- as ArgumentCaptor<PacketRepeaterCallback<AnnouncementInfo>>
+ as ArgumentCaptor<PacketRepeaterCallback<BaseAnnouncementInfo>>
private val probeCb get() = probeCbCaptor.value
private val announceCb get() = announceCbCaptor.value
@@ -82,7 +86,21 @@
doReturn(announcer).`when`(deps).makeMdnsAnnouncer(any(), any(), any(), any())
doReturn(prober).`when`(deps).makeMdnsProber(any(), any(), any(), any())
- doReturn(-1).`when`(repository).addService(anyInt(), any())
+ val knownServices = mutableSetOf<Int>()
+ doAnswer { inv ->
+ knownServices.add(inv.getArgument(0))
+ -1
+ }.`when`(repository).addService(anyInt(), any())
+ doAnswer { inv ->
+ knownServices.remove(inv.getArgument(0))
+ null
+ }.`when`(repository).removeService(anyInt())
+ doAnswer {
+ knownServices.toIntArray().also { knownServices.clear() }
+ }.`when`(repository).clearServices()
+ doAnswer { inv ->
+ knownServices.contains(inv.getArgument(0))
+ }.`when`(repository).hasActiveService(anyInt())
thread.start()
advertiser.start()
@@ -97,18 +115,7 @@
@Test
fun testAddRemoveService() {
- val testProbingInfo = mock(ProbingInfo::class.java)
- doReturn(TEST_SERVICE_ID_1).`when`(testProbingInfo).serviceId
- doReturn(testProbingInfo).`when`(repository).setServiceProbing(TEST_SERVICE_ID_1)
-
- advertiser.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
- verify(repository).addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
- verify(prober).startProbing(testProbingInfo)
-
- // Simulate probing success: continues to announcing
- val testAnnouncementInfo = mock(AnnouncementInfo::class.java)
- doReturn(testAnnouncementInfo).`when`(repository).onProbingSucceeded(testProbingInfo)
- probeCb.onFinished(testProbingInfo)
+ val testAnnouncementInfo = addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
verify(announcer).startSending(TEST_SERVICE_ID_1, testAnnouncementInfo,
0L /* initialDelayMs */)
@@ -117,13 +124,53 @@
verify(cb).onRegisterServiceSucceeded(advertiser, TEST_SERVICE_ID_1)
// Remove the service: expect exit announcements
- val testExitInfo = mock(AnnouncementInfo::class.java)
+ val testExitInfo = mock(ExitAnnouncementInfo::class.java)
doReturn(testExitInfo).`when`(repository).exitService(TEST_SERVICE_ID_1)
advertiser.removeService(TEST_SERVICE_ID_1)
+ verify(prober).stop(TEST_SERVICE_ID_1)
+ verify(announcer).stop(TEST_SERVICE_ID_1)
verify(announcer).startSending(TEST_SERVICE_ID_1, testExitInfo, EXIT_ANNOUNCEMENT_DELAY_MS)
- // TODO: after exit announcements are implemented, verify that announceCb.onFinished causes
- // cb.onDestroyed to be called.
+ // Exit announcements finish: the advertiser has no left service and destroys itself
+ announceCb.onFinished(testExitInfo)
+ thread.waitForIdle(TIMEOUT_MS)
+ verify(cb).onDestroyed(socket)
+ }
+
+ /** Removing the same service twice must only trigger one exit sequence. */
+ @Test
+ fun testDoubleRemove() {
+ addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+
+ // First removal: stops probing/announcing and schedules exit announcements
+ val testExitInfo = mock(ExitAnnouncementInfo::class.java)
+ doReturn(testExitInfo).`when`(repository).exitService(TEST_SERVICE_ID_1)
+ advertiser.removeService(TEST_SERVICE_ID_1)
+
+ verify(prober).stop(TEST_SERVICE_ID_1)
+ verify(announcer).stop(TEST_SERVICE_ID_1)
+ verify(announcer).startSending(TEST_SERVICE_ID_1, testExitInfo, EXIT_ANNOUNCEMENT_DELAY_MS)
+
+ // The hasActiveService stub is backed by knownServices, which only shrinks when
+ // repository.removeService is called: override it to simulate the service now exiting.
+ doReturn(false).`when`(repository).hasActiveService(TEST_SERVICE_ID_1)
+ advertiser.removeService(TEST_SERVICE_ID_1)
+ // Prober, announcer were still stopped only one time
+ verify(prober, times(1)).stop(TEST_SERVICE_ID_1)
+ verify(announcer, times(1)).stop(TEST_SERVICE_ID_1)
+ }
+
+ /**
+ * Add a service to the advertiser, verify that probing started, then simulate probing success.
+ *
+ * @return the mock [AnnouncementInfo] the repository returns when probing succeeds.
+ */
+ private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
+ AnnouncementInfo {
+ val testProbingInfo = mock(ProbingInfo::class.java)
+ doReturn(serviceId).`when`(testProbingInfo).serviceId
+ doReturn(testProbingInfo).`when`(repository).setServiceProbing(serviceId)
+
+ advertiser.addService(serviceId, serviceInfo)
+ verify(repository).addService(serviceId, serviceInfo)
+ verify(prober).startProbing(testProbingInfo)
+
+ // Simulate probing success: continues to announcing
+ val testAnnouncementInfo = mock(AnnouncementInfo::class.java)
+ doReturn(testAnnouncementInfo).`when`(repository).onProbingSucceeded(testProbingInfo)
+ probeCb.onFinished(testProbingInfo)
+ return testAnnouncementInfo
}
}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index 502a36a..29d0854 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -17,10 +17,12 @@
package com.android.server.connectivity.mdns
import android.net.InetAddresses.parseNumericAddress
+import android.net.LinkAddress
import android.net.nsd.NsdServiceInfo
import android.os.Build
import android.os.HandlerThread
import com.android.server.connectivity.mdns.MdnsRecordRepository.Dependencies
+import com.android.server.connectivity.mdns.MdnsRecordRepository.getReverseDnsAddress
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner
import java.net.NetworkInterface
@@ -28,6 +30,7 @@
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
+import kotlin.test.assertFalse
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
import org.junit.After
@@ -39,10 +42,10 @@
private const val TEST_SERVICE_ID_2 = 43
private const val TEST_PORT = 12345
private val TEST_HOSTNAME = arrayOf("Android_000102030405060708090A0B0C0D0E0F", "local")
-private val TEST_ADDRESSES = arrayOf(
- parseNumericAddress("192.0.2.111"),
- parseNumericAddress("2001:db8::111"),
- parseNumericAddress("2001:db8::222"))
+private val TEST_ADDRESSES = listOf(
+ LinkAddress(parseNumericAddress("192.0.2.111"), 24),
+ LinkAddress(parseNumericAddress("2001:db8::111"), 64),
+ LinkAddress(parseNumericAddress("2001:db8::222"), 64))
private val TEST_SERVICE_1 = NsdServiceInfo().apply {
serviceType = "_testservice._tcp"
@@ -50,6 +53,12 @@
port = TEST_PORT
}
+private val TEST_SERVICE_2 = NsdServiceInfo().apply {
+ serviceType = "_testservice._tcp"
+ serviceName = "MyOtherTestService"
+ port = TEST_PORT
+}
+
@RunWith(DevSdkIgnoreRunner::class)
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
class MdnsRecordRepositoryTest {
@@ -57,7 +66,7 @@
private val deps = object : Dependencies() {
override fun getHostname() = TEST_HOSTNAME
override fun getInterfaceInetAddresses(iface: NetworkInterface) =
- Collections.enumeration(TEST_ADDRESSES.toList())
+ Collections.enumeration(TEST_ADDRESSES.map { it.address })
}
@Before
@@ -113,9 +122,71 @@
}
+ /** Adding a second service with an ID that is already in use must be rejected. */
@Test
+ fun testInvalidReuseOfServiceId() {
+ val repository = MdnsRecordRepository(thread.looper, deps)
+ repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ assertFailsWith(IllegalArgumentException::class) {
+ repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2)
+ }
+ }
+
+ @Test
+ fun testHasActiveService() {
+ val repository = MdnsRecordRepository(thread.looper, deps)
+ assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1))
+
+ repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ assertTrue(repository.hasActiveService(TEST_SERVICE_ID_1))
+
+ // Still active once probing succeeded and the service was advertised
+ val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
+ repository.onProbingSucceeded(probingInfo)
+ repository.onAdvertisementSent(TEST_SERVICE_ID_1)
+ assertTrue(repository.hasActiveService(TEST_SERVICE_ID_1))
+
+ // An exiting service is no longer active, although it is still registered
+ repository.exitService(TEST_SERVICE_ID_1)
+ assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1))
+ }
+
+ @Test
+ fun testExitAnnouncements() {
+ val repository = MdnsRecordRepository(thread.looper, deps)
+ repository.updateAddresses(TEST_ADDRESSES)
+
+ // Exit announcements are only generated once the service records have been sent
+ repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
+ repository.onProbingSucceeded(probingInfo)
+ repository.onAdvertisementSent(TEST_SERVICE_ID_1)
+
+ val exitAnnouncement = repository.exitService(TEST_SERVICE_ID_1)
+ assertNotNull(exitAnnouncement)
+ // The service stays registered until removeService is called
+ assertEquals(1, repository.servicesCount)
+ val packet = exitAnnouncement.getPacket(0)
+
+ assertEquals(0x8400 /* response, authoritative */, packet.flags)
+ assertEquals(0, packet.questions.size)
+ assertEquals(0, packet.authorityRecords.size)
+ assertEquals(0, packet.additionalRecords.size)
+
+ // Only a goodbye (TTL 0) for the service PTR record is sent
+ assertContentEquals(listOf(
+ MdnsPointerRecord(
+ arrayOf("_testservice", "_tcp", "local"),
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ 0L /* ttlMillis */,
+ arrayOf("MyTestService", "_testservice", "_tcp", "local"))
+ ), packet.answers)
+
+ repository.removeService(TEST_SERVICE_ID_1)
+ assertEquals(0, repository.servicesCount)
+ }
+
+ @Test
fun testExitingServiceReAdded() {
val repository = MdnsRecordRepository(thread.looper, deps)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
+ repository.onProbingSucceeded(probingInfo)
+ repository.onAdvertisementSent(TEST_SERVICE_ID_1)
repository.exitService(TEST_SERVICE_ID_1)
assertEquals(TEST_SERVICE_ID_1, repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1))
@@ -124,4 +195,131 @@
repository.removeService(TEST_SERVICE_ID_2)
assertEquals(0, repository.servicesCount)
}
+
+ /**
+ * Probing success must produce an announcement with all address and service records as
+ * answers, and NSEC records for each unique name as additional records.
+ */
+ @Test
+ fun testOnProbingSucceeded() {
+ val repository = MdnsRecordRepository(thread.looper, deps)
+ repository.updateAddresses(TEST_ADDRESSES)
+
+ repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
+ val announcementInfo = repository.onProbingSucceeded(probingInfo)
+ val packet = announcementInfo.getPacket(0)
+
+ assertEquals(0x8400 /* response, authoritative */, packet.flags)
+ assertEquals(0, packet.questions.size)
+ assertEquals(0, packet.authorityRecords.size)
+
+ val serviceType = arrayOf("_testservice", "_tcp", "local")
+ val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
+ val v4AddrRev = getReverseDnsAddress(TEST_ADDRESSES[0].address)
+ val v6Addr1Rev = getReverseDnsAddress(TEST_ADDRESSES[1].address)
+ val v6Addr2Rev = getReverseDnsAddress(TEST_ADDRESSES[2].address)
+
+ assertContentEquals(listOf(
+ // Reverse address and address records for the hostname
+ MdnsPointerRecord(v4AddrRev,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ 120000L /* ttlMillis */,
+ TEST_HOSTNAME),
+ MdnsInetAddressRecord(TEST_HOSTNAME,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ 120000L /* ttlMillis */,
+ TEST_ADDRESSES[0].address),
+ MdnsPointerRecord(v6Addr1Rev,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ 120000L /* ttlMillis */,
+ TEST_HOSTNAME),
+ MdnsInetAddressRecord(TEST_HOSTNAME,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ 120000L /* ttlMillis */,
+ TEST_ADDRESSES[1].address),
+ MdnsPointerRecord(v6Addr2Rev,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ 120000L /* ttlMillis */,
+ TEST_HOSTNAME),
+ MdnsInetAddressRecord(TEST_HOSTNAME,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ 120000L /* ttlMillis */,
+ TEST_ADDRESSES[2].address),
+ // Service registration records (RFC6763)
+ MdnsPointerRecord(
+ serviceType,
+ 0L /* receiptTimeMillis */,
+ // Not a unique name owned by the announcer, so cacheFlush=false
+ false /* cacheFlush */,
+ 4500000L /* ttlMillis */,
+ serviceName),
+ MdnsServiceRecord(
+ serviceName,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ 120000L /* ttlMillis */,
+ 0 /* servicePriority */,
+ 0 /* serviceWeight */,
+ TEST_PORT /* servicePort */,
+ TEST_HOSTNAME),
+ MdnsTextRecord(
+ serviceName,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ 4500000L /* ttlMillis */,
+ emptyList() /* entries */),
+ // Service type enumeration record (RFC6763 9.)
+ MdnsPointerRecord(
+ arrayOf("_services", "_dns-sd", "_udp", "local"),
+ 0L /* receiptTimeMillis */,
+ false /* cacheFlush */,
+ 4500000L /* ttlMillis */,
+ serviceType)
+ ), packet.answers)
+
+ // One NSEC per unique name, in first-appearance order, with the minimum TTL of the records
+ // for that name; the shared PTR names (service type, _services) get none.
+ assertContentEquals(listOf(
+ MdnsNsecRecord(v4AddrRev,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ 120000L /* ttlMillis */,
+ v4AddrRev,
+ intArrayOf(MdnsRecord.TYPE_PTR)),
+ MdnsNsecRecord(TEST_HOSTNAME,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ 120000L /* ttlMillis */,
+ TEST_HOSTNAME,
+ intArrayOf(MdnsRecord.TYPE_A, MdnsRecord.TYPE_AAAA)),
+ MdnsNsecRecord(v6Addr1Rev,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ 120000L /* ttlMillis */,
+ v6Addr1Rev,
+ intArrayOf(MdnsRecord.TYPE_PTR)),
+ MdnsNsecRecord(v6Addr2Rev,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ 120000L /* ttlMillis */,
+ v6Addr2Rev,
+ intArrayOf(MdnsRecord.TYPE_PTR)),
+ MdnsNsecRecord(serviceName,
+ 0L /* receiptTimeMillis */,
+ true /* cacheFlush */,
+ 4500000L /* ttlMillis */,
+ serviceName,
+ intArrayOf(MdnsRecord.TYPE_TXT, MdnsRecord.TYPE_SRV))
+ ), packet.additionalRecords)
+ }
+
+ /** Reverse DNS names: one label per nibble for IPv6, one label per byte for IPv4. */
+ @Test
+ fun testGetReverseDnsAddress() {
+ val expectedV6 = "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.0.1.0.0.2.ip6.arpa"
+ .split(".").toTypedArray()
+ assertContentEquals(expectedV6, getReverseDnsAddress(parseNumericAddress("2001:db8::1")))
+ val expectedV4 = "123.2.0.192.in-addr.arpa".split(".").toTypedArray()
+ assertContentEquals(expectedV4, getReverseDnsAddress(parseNumericAddress("192.0.2.123")))
+ }
}