Add tests for conflicts
Add end-to-end testing for conflicts happening during and after probing.
This includes making sure that interfaces have a usable IPv6 address
before starting each test, as mdnsresponder would not probe on the
interface if the address is not present, and would go straight to
advertising otherwise.
Bug: 266151066
Test: atest NsdManagerTest
Change-Id: Ie1a1e888afbcd5c1bafaf218a98ac4c2f5fe63ee
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index e4ee8de..7be4f78 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -25,6 +25,7 @@
import android.net.LinkProperties
import android.net.LocalSocket
import android.net.LocalSocketAddress
+import android.net.MacAddress
import android.net.Network
import android.net.NetworkAgentConfig
import android.net.NetworkCapabilities
@@ -73,6 +74,8 @@
import android.system.OsConstants.AF_INET6
import android.system.OsConstants.EADDRNOTAVAIL
import android.system.OsConstants.ENETUNREACH
+import android.system.OsConstants.ETH_P_IPV6
+import android.system.OsConstants.IPPROTO_IPV6
import android.system.OsConstants.IPPROTO_UDP
import android.system.OsConstants.SOCK_DGRAM
import android.util.Log
@@ -82,13 +85,21 @@
import com.android.compatibility.common.util.PropertyUtil
import com.android.modules.utils.build.SdkLevel.isAtLeastU
import com.android.net.module.util.ArrayTrackRecord
+import com.android.net.module.util.DnsPacket
+import com.android.net.module.util.HexDump
+import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN
+import com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN
+import com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN
+import com.android.net.module.util.PacketBuilder
import com.android.net.module.util.TrackRecord
import com.android.testutils.ConnectivityModuleTest
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.IPv6UdpFilter
import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
+import com.android.testutils.TapPacketReader
import com.android.testutils.TestableNetworkAgent
import com.android.testutils.TestableNetworkAgent.CallbackEntry.OnNetworkCreated
import com.android.testutils.TestableNetworkCallback
@@ -103,6 +114,7 @@
import java.net.InetAddress
import java.net.NetworkInterface
import java.net.ServerSocket
+import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.util.Random
import java.util.concurrent.Executor
@@ -133,6 +145,8 @@
private const val REGISTRATION_TIMEOUT_MS = 10_000L
private const val DBG = false
private const val TEST_PORT = 12345
+private const val MDNS_PORT = 5353.toShort()
+private val multicastIpv6Addr = parseNumericAddress("ff02::fb") as Inet6Address
@AppModeFull(reason = "Socket cannot bind in instant app mode")
@RunWith(DevSdkIgnoreRunner::class)
@@ -194,8 +208,8 @@
inline fun <reified V : NsdEvent> expectCallback(timeoutMs: Long = TIMEOUT_MS): V {
val nextEvent = nextEvents.poll(timeoutMs)
- assertNotNull(nextEvent, "No callback received after $timeoutMs ms, expected " +
- "${V::class.java.simpleName}")
+ assertNotNull(nextEvent, "No callback received after $timeoutMs ms, " +
+ "expected ${V::class.java.simpleName}")
assertTrue(nextEvent is V, "Expected ${V::class.java.simpleName} but got " +
nextEvent.javaClass.simpleName)
return nextEvent
@@ -411,7 +425,6 @@
val lp = LinkProperties().apply {
interfaceName = ifaceName
}
-
val agent = TestableNetworkAgent(context, handlerThread.looper,
NetworkCapabilities().apply {
removeCapability(NET_CAPABILITY_TRUSTED)
@@ -1144,6 +1157,176 @@
}
}
+ @Test
+ fun testRegisterWithConflictDuringProbing() {
+ // This test requires shims supporting T+ APIs (NsdServiceInfo.network)
+ assumeTrue(TestUtils.shouldTestTApis())
+
+ val si = NsdServiceInfo()
+ si.serviceType = serviceType
+ si.serviceName = serviceName
+ si.network = testNetwork1.network
+ si.port = 12345 // Test won't try to connect so port does not matter
+
+ val packetReader = TapPacketReader(Handler(handlerThread.looper),
+ testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+ packetReader.startAsyncForTest()
+ handlerThread.waitForIdle(TIMEOUT_MS)
+
+ // Register service on testNetwork1
+ val registrationRecord = NsdRegistrationRecord()
+ nsdManager.registerService(si, NsdManager.PROTOCOL_DNS_SD, { it.run() },
+ registrationRecord)
+
+ tryTest {
+ assertNotNull(packetReader.pollForProbe(serviceName, serviceType),
+ "Did not find a probe for the service")
+ packetReader.sendResponse(buildConflictingAnnouncement())
+
+ // Registration must use an updated name to avoid the conflict
+ val cb = registrationRecord.expectCallback<ServiceRegistered>(REGISTRATION_TIMEOUT_MS)
+ cb.serviceInfo.serviceName.let {
+ assertTrue("Unexpected registered name: $it",
+ it.startsWith(serviceName) && it != serviceName)
+ }
+ } cleanupStep {
+ nsdManager.unregisterService(registrationRecord)
+ registrationRecord.expectCallback<ServiceUnregistered>()
+ } cleanup {
+ packetReader.handler.post { packetReader.stop() }
+ handlerThread.waitForIdle(TIMEOUT_MS)
+ }
+ }
+
+ @Test
+ fun testRegisterWithConflictAfterProbing() {
+ // This test requires shims supporting T+ APIs (NsdServiceInfo.network)
+ assumeTrue(TestUtils.shouldTestTApis())
+
+ val si = NsdServiceInfo()
+ si.serviceType = serviceType
+ si.serviceName = serviceName
+ si.network = testNetwork1.network
+ si.port = 12345 // Test won't try to connect so port does not matter
+
+ // Register service on testNetwork1
+ val registrationRecord = NsdRegistrationRecord()
+ val discoveryRecord = NsdDiscoveryRecord()
+ val registeredService = registerService(registrationRecord, si)
+ val packetReader = TapPacketReader(Handler(handlerThread.looper),
+ testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+ packetReader.startAsyncForTest()
+ handlerThread.waitForIdle(TIMEOUT_MS)
+
+ tryTest {
+ assertNotNull(packetReader.pollForAdvertisement(serviceName, serviceType),
+ "No announcements sent after initial probing")
+
+ assertEquals(si.serviceName, registeredService.serviceName)
+
+ nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
+ testNetwork1.network, { it.run() }, discoveryRecord)
+ discoveryRecord.waitForServiceDiscovered(si.serviceName, serviceType)
+
+ // Send a conflicting announcement
+ val conflictingAnnouncement = buildConflictingAnnouncement()
+ packetReader.sendResponse(conflictingAnnouncement)
+
+ // Expect to see probes (RFC6762 9., service is reset to probing state)
+ assertNotNull(packetReader.pollForProbe(serviceName, serviceType),
+ "Probe not received within timeout after conflict")
+
+ // Send the conflicting packet again to reply to the probe
+ packetReader.sendResponse(conflictingAnnouncement)
+
+ // Note the legacy mdnsresponder would send an exit announcement here (a 0-lifetime
+ // advertisement just for the PTR record), but not the new advertiser. This probably
+ // follows RFC 6762 8.4, saying that when a record rdata changed, "In the case of shared
+ // records, a host MUST send a "goodbye" announcement with RR TTL zero [...] for the old
+ // rdata, to cause it to be deleted from peer caches, before announcing the new rdata".
+ //
+ // This should be implemented by the new advertiser, but in the case of conflicts it is
+ // not very valuable since an identical PTR record would be used by the conflicting
+ // service (except for subtypes). In that case the exit announcement may be
+ // counter-productive as it conflicts with announcements done by the conflicting
+ // service.
+
+ // Note that before sending the following ServiceRegistered callback for the renamed
+ // service, the legacy mdnsresponder-based implementation would first send a
+ // Service*Registered* callback for the original service name being *unregistered*; it
+ // should have been a ServiceUnregistered callback instead (bug in NsdService
+ // interpretation of the callback).
+ val newRegistration = registrationRecord.expectCallbackEventually<ServiceRegistered>(
+ REGISTRATION_TIMEOUT_MS) {
+ it.serviceInfo.serviceName.startsWith(serviceName) &&
+ it.serviceInfo.serviceName != serviceName
+ }
+
+ discoveryRecord.expectCallbackEventually<ServiceFound> {
+ it.serviceInfo.serviceName == newRegistration.serviceInfo.serviceName
+ }
+ } cleanupStep {
+ nsdManager.stopServiceDiscovery(discoveryRecord)
+ discoveryRecord.expectCallback<DiscoveryStopped>()
+ } cleanupStep {
+ nsdManager.unregisterService(registrationRecord)
+ registrationRecord.expectCallback<ServiceUnregistered>()
+ } cleanup {
+ packetReader.handler.post { packetReader.stop() }
+ handlerThread.waitForIdle(TIMEOUT_MS)
+ }
+ }
+
+ private fun buildConflictingAnnouncement(): ByteBuffer {
+ /*
+ Generated with:
+ scapy.raw(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an =
+ scapy.DNSRRSRV(rrname='NsdTest123456789._nmt123456789._tcp.local',
+ rclass=0x8001, port=31234, target='conflict.local', ttl=120)
+ )).hex()
+ */
+ val mdnsPayload = HexDump.hexStringToByteArray("000084000000000100000000104e736454657" +
+ "3743132333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00002" +
+ "18001000000780016000000007a0208636f6e666c696374056c6f63616c00")
+ val packetBuffer = ByteBuffer.wrap(mdnsPayload)
+ // Replace service name and types in the packet with the random ones used in the test.
+ // Test service name and types have consistent length and are always ASCII
+ val testPacketName = "NsdTest123456789".encodeToByteArray()
+ val testPacketTypePrefix = "_nmt123456789".encodeToByteArray()
+ val encodedServiceName = serviceName.encodeToByteArray()
+ val encodedTypePrefix = serviceType.split('.')[0].encodeToByteArray()
+ assertEquals(testPacketName.size, encodedServiceName.size)
+ assertEquals(testPacketTypePrefix.size, encodedTypePrefix.size)
+ packetBuffer.position(mdnsPayload.indexOf(testPacketName))
+ packetBuffer.put(encodedServiceName)
+ packetBuffer.position(mdnsPayload.indexOf(testPacketTypePrefix))
+ packetBuffer.put(encodedTypePrefix)
+
+ return buildMdnsPacket(mdnsPayload)
+ }
+
+ private fun buildMdnsPacket(mdnsPayload: ByteArray): ByteBuffer {
+ val packetBuffer = PacketBuilder.allocate(true /* hasEther */, IPPROTO_IPV6,
+ IPPROTO_UDP, mdnsPayload.size)
+ val packetBuilder = PacketBuilder(packetBuffer)
+ // Multicast ethernet address for IPv6 to ff02::fb
+ val multicastEthAddr = MacAddress.fromBytes(
+ byteArrayOf(0x33, 0x33, 0, 0, 0, 0xfb.toByte()))
+ packetBuilder.writeL2Header(
+ MacAddress.fromBytes(byteArrayOf(1, 2, 3, 4, 5, 6)) /* srcMac */,
+ multicastEthAddr,
+ ETH_P_IPV6.toShort())
+ packetBuilder.writeIpv6Header(
+ 0x60000000, // version=6, traffic class=0x0, flowlabel=0x0
+ IPPROTO_UDP.toByte(),
+ 64 /* hop limit */,
+ parseNumericAddress("2001:db8::123") as Inet6Address /* srcIp */,
+ multicastIpv6Addr /* dstIp */)
+ packetBuilder.writeUdpHeader(MDNS_PORT /* srcPort */, MDNS_PORT /* dstPort */)
+ packetBuffer.put(mdnsPayload)
+ return packetBuilder.finalizePacket()
+ }
+
/**
* Register a service and return its registration record.
*/
@@ -1169,7 +1352,65 @@
}
}
+private fun TapPacketReader.pollForMdnsPacket(
+ timeoutMs: Long = REGISTRATION_TIMEOUT_MS,
+ predicate: (TestDnsPacket) -> Boolean
+): ByteArray? {
+ val mdnsProbeFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT).and {
+ val mdnsPayload = it.copyOfRange(
+ ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, it.size)
+ try {
+ predicate(TestDnsPacket(mdnsPayload))
+ } catch (e: DnsPacket.ParseException) {
+ false
+ }
+ }
+ return poll(timeoutMs, mdnsProbeFilter)
+}
+
+private fun TapPacketReader.pollForProbe(
+ serviceName: String,
+ serviceType: String,
+ timeoutMs: Long = REGISTRATION_TIMEOUT_MS
+): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isProbeFor("$serviceName.$serviceType.local") }
+
+private fun TapPacketReader.pollForAdvertisement(
+ serviceName: String,
+ serviceType: String,
+ timeoutMs: Long = REGISTRATION_TIMEOUT_MS
+): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isReplyFor("$serviceName.$serviceType.local") }
+
+private class TestDnsPacket(data: ByteArray) : DnsPacket(data) {
+ fun isProbeFor(name: String): Boolean = mRecords[QDSECTION].any {
+ it.dName == name && it.nsType == 0xff /* ANY */
+ }
+
+ fun isReplyFor(name: String): Boolean = mRecords[ANSECTION].any {
+ it.dName == name && it.nsType == 0x21 /* SRV */
+ }
+}
+
private fun ByteArray?.utf8ToString(): String {
if (this == null) return ""
return String(this, StandardCharsets.UTF_8)
}
+
+private fun ByteArray.indexOf(sub: ByteArray): Int {
+ var subIndex = 0
+ forEachIndexed { i, b ->
+ when (b) {
+ // Still matching: continue comparing with next byte
+ sub[subIndex] -> {
+ subIndex++
+ if (subIndex == sub.size) {
+ return i - sub.size + 1
+ }
+ }
+ // Not matching next byte but matches first byte: continue comparing with 2nd byte
+ sub[0] -> subIndex = 1
+ // No matches: continue comparing from first byte
+ else -> subIndex = 0
+ }
+ }
+ return -1
+}