Add tests for conflicts

Add end-to-end testing for conflicts happening during and after probing.

This includes making sure that interfaces have a usable IPv6 address
before starting each test, as mdnsresponder would not probe on the
interface if the address is not present, and would go straight to
advertising otherwise.

Bug: 266151066
Test: atest NsdManagerTest
Change-Id: Ie1a1e888afbcd5c1bafaf218a98ac4c2f5fe63ee
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index e4ee8de..7be4f78 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -25,6 +25,7 @@
 import android.net.LinkProperties
 import android.net.LocalSocket
 import android.net.LocalSocketAddress
+import android.net.MacAddress
 import android.net.Network
 import android.net.NetworkAgentConfig
 import android.net.NetworkCapabilities
@@ -73,6 +74,8 @@
 import android.system.OsConstants.AF_INET6
 import android.system.OsConstants.EADDRNOTAVAIL
 import android.system.OsConstants.ENETUNREACH
+import android.system.OsConstants.ETH_P_IPV6
+import android.system.OsConstants.IPPROTO_IPV6
 import android.system.OsConstants.IPPROTO_UDP
 import android.system.OsConstants.SOCK_DGRAM
 import android.util.Log
@@ -82,13 +85,21 @@
 import com.android.compatibility.common.util.PropertyUtil
 import com.android.modules.utils.build.SdkLevel.isAtLeastU
 import com.android.net.module.util.ArrayTrackRecord
+import com.android.net.module.util.DnsPacket
+import com.android.net.module.util.HexDump
+import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN
+import com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN
+import com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN
+import com.android.net.module.util.PacketBuilder
 import com.android.net.module.util.TrackRecord
 import com.android.testutils.ConnectivityModuleTest
 import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
 import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.IPv6UdpFilter
 import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
 import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
+import com.android.testutils.TapPacketReader
 import com.android.testutils.TestableNetworkAgent
 import com.android.testutils.TestableNetworkAgent.CallbackEntry.OnNetworkCreated
 import com.android.testutils.TestableNetworkCallback
@@ -103,6 +114,7 @@
 import java.net.InetAddress
 import java.net.NetworkInterface
 import java.net.ServerSocket
+import java.nio.ByteBuffer
 import java.nio.charset.StandardCharsets
 import java.util.Random
 import java.util.concurrent.Executor
@@ -133,6 +145,8 @@
 private const val REGISTRATION_TIMEOUT_MS = 10_000L
 private const val DBG = false
 private const val TEST_PORT = 12345
+private const val MDNS_PORT = 5353.toShort()
+private val multicastIpv6Addr = parseNumericAddress("ff02::fb") as Inet6Address
 
 @AppModeFull(reason = "Socket cannot bind in instant app mode")
 @RunWith(DevSdkIgnoreRunner::class)
@@ -194,8 +208,8 @@
 
         inline fun <reified V : NsdEvent> expectCallback(timeoutMs: Long = TIMEOUT_MS): V {
             val nextEvent = nextEvents.poll(timeoutMs)
-            assertNotNull(nextEvent, "No callback received after $timeoutMs ms, expected " +
-                    "${V::class.java.simpleName}")
+            assertNotNull(nextEvent, "No callback received after $timeoutMs ms, " +
+                    "expected ${V::class.java.simpleName}")
             assertTrue(nextEvent is V, "Expected ${V::class.java.simpleName} but got " +
                     nextEvent.javaClass.simpleName)
             return nextEvent
@@ -411,7 +425,6 @@
         val lp = LinkProperties().apply {
             interfaceName = ifaceName
         }
-
         val agent = TestableNetworkAgent(context, handlerThread.looper,
                 NetworkCapabilities().apply {
                     removeCapability(NET_CAPABILITY_TRUSTED)
@@ -1144,6 +1157,176 @@
         }
     }
 
+    @Test
+    fun testRegisterWithConflictDuringProbing() {
+        // This test requires shims supporting T+ APIs (NsdServiceInfo.network)
+        assumeTrue(TestUtils.shouldTestTApis())
+
+        val si = NsdServiceInfo()
+        si.serviceType = serviceType
+        si.serviceName = serviceName
+        si.network = testNetwork1.network
+        si.port = 12345 // Test won't try to connect so port does not matter
+
+        val packetReader = TapPacketReader(Handler(handlerThread.looper),
+                testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+        packetReader.startAsyncForTest()
+        handlerThread.waitForIdle(TIMEOUT_MS)
+
+        // Register service on testNetwork1
+        val registrationRecord = NsdRegistrationRecord()
+        nsdManager.registerService(si, NsdManager.PROTOCOL_DNS_SD, { it.run() },
+                registrationRecord)
+
+        tryTest {
+            assertNotNull(packetReader.pollForProbe(serviceName, serviceType),
+                    "Did not find a probe for the service")
+            packetReader.sendResponse(buildConflictingAnnouncement())
+
+            // Registration must use an updated name to avoid the conflict
+            val cb = registrationRecord.expectCallback<ServiceRegistered>(REGISTRATION_TIMEOUT_MS)
+            cb.serviceInfo.serviceName.let {
+                assertTrue("Unexpected registered name: $it",
+                        it.startsWith(serviceName) && it != serviceName)
+            }
+        } cleanupStep {
+            nsdManager.unregisterService(registrationRecord)
+            registrationRecord.expectCallback<ServiceUnregistered>()
+        } cleanup {
+            packetReader.handler.post { packetReader.stop() }
+            handlerThread.waitForIdle(TIMEOUT_MS)
+        }
+    }
+
+    @Test
+    fun testRegisterWithConflictAfterProbing() {
+        // This test requires shims supporting T+ APIs (NsdServiceInfo.network)
+        assumeTrue(TestUtils.shouldTestTApis())
+
+        val si = NsdServiceInfo()
+        si.serviceType = serviceType
+        si.serviceName = serviceName
+        si.network = testNetwork1.network
+        si.port = 12345 // Test won't try to connect so port does not matter
+
+        // Register service on testNetwork1
+        val registrationRecord = NsdRegistrationRecord()
+        val discoveryRecord = NsdDiscoveryRecord()
+        val registeredService = registerService(registrationRecord, si)
+        val packetReader = TapPacketReader(Handler(handlerThread.looper),
+                testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+        packetReader.startAsyncForTest()
+        handlerThread.waitForIdle(TIMEOUT_MS)
+
+        tryTest {
+            assertNotNull(packetReader.pollForAdvertisement(serviceName, serviceType),
+                    "No announcements sent after initial probing")
+
+            assertEquals(si.serviceName, registeredService.serviceName)
+
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
+                testNetwork1.network, { it.run() }, discoveryRecord)
+            discoveryRecord.waitForServiceDiscovered(si.serviceName, serviceType)
+
+            // Send a conflicting announcement
+            val conflictingAnnouncement = buildConflictingAnnouncement()
+            packetReader.sendResponse(conflictingAnnouncement)
+
+            // Expect to see probes (RFC6762 9., service is reset to probing state)
+            assertNotNull(packetReader.pollForProbe(serviceName, serviceType),
+                    "Probe not received within timeout after conflict")
+
+            // Send the conflicting packet again to reply to the probe
+            packetReader.sendResponse(conflictingAnnouncement)
+
+            // Note the legacy mdnsresponder would send an exit announcement here (a 0-lifetime
+            // advertisement just for the PTR record), but not the new advertiser. This probably
+            // follows RFC 6762 8.4, saying that when a record rdata changed, "In the case of shared
+            // records, a host MUST send a "goodbye" announcement with RR TTL zero [...] for the old
+            // rdata, to cause it to be deleted from peer caches, before announcing the new rdata".
+            //
+            // This should be implemented by the new advertiser, but in the case of conflicts it is
+            // not very valuable since an identical PTR record would be used by the conflicting
+            // service (except for subtypes). In that case the exit announcement may be
+            // counter-productive as it conflicts with announcements done by the conflicting
+            // service.
+
+            // Note that before sending the following ServiceRegistered callback for the renamed
+            // service, the legacy mdnsresponder-based implementation would first send a
+            // Service*Registered* callback for the original service name being *unregistered*; it
+            // should have been a ServiceUnregistered callback instead (bug in NsdService
+            // interpretation of the callback).
+            val newRegistration = registrationRecord.expectCallbackEventually<ServiceRegistered>(
+                    REGISTRATION_TIMEOUT_MS) {
+                it.serviceInfo.serviceName.startsWith(serviceName) &&
+                        it.serviceInfo.serviceName != serviceName
+            }
+
+            discoveryRecord.expectCallbackEventually<ServiceFound> {
+                it.serviceInfo.serviceName == newRegistration.serviceInfo.serviceName
+            }
+        } cleanupStep {
+            nsdManager.stopServiceDiscovery(discoveryRecord)
+            discoveryRecord.expectCallback<DiscoveryStopped>()
+        } cleanupStep {
+            nsdManager.unregisterService(registrationRecord)
+            registrationRecord.expectCallback<ServiceUnregistered>()
+        } cleanup {
+            packetReader.handler.post { packetReader.stop() }
+            handlerThread.waitForIdle(TIMEOUT_MS)
+        }
+    }
+
+    private fun buildConflictingAnnouncement(): ByteBuffer {
+        /*
+        Generated with:
+        scapy.raw(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an =
+                scapy.DNSRRSRV(rrname='NsdTest123456789._nmt123456789._tcp.local',
+                    rclass=0x8001, port=31234, target='conflict.local', ttl=120)
+        )).hex()
+         */
+        val mdnsPayload = HexDump.hexStringToByteArray("000084000000000100000000104e736454657" +
+                "3743132333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00002" +
+                "18001000000780016000000007a0208636f6e666c696374056c6f63616c00")
+        val packetBuffer = ByteBuffer.wrap(mdnsPayload)
+        // Replace service name and types in the packet with the random ones used in the test.
+        // Test service name and types have consistent length and are always ASCII
+        val testPacketName = "NsdTest123456789".encodeToByteArray()
+        val testPacketTypePrefix = "_nmt123456789".encodeToByteArray()
+        val encodedServiceName = serviceName.encodeToByteArray()
+        val encodedTypePrefix = serviceType.split('.')[0].encodeToByteArray()
+        assertEquals(testPacketName.size, encodedServiceName.size)
+        assertEquals(testPacketTypePrefix.size, encodedTypePrefix.size)
+        packetBuffer.position(mdnsPayload.indexOf(testPacketName))
+        packetBuffer.put(encodedServiceName)
+        packetBuffer.position(mdnsPayload.indexOf(testPacketTypePrefix))
+        packetBuffer.put(encodedTypePrefix)
+
+        return buildMdnsPacket(mdnsPayload)
+    }
+
+    private fun buildMdnsPacket(mdnsPayload: ByteArray): ByteBuffer {
+        val packetBuffer = PacketBuilder.allocate(true /* hasEther */, IPPROTO_IPV6,
+                IPPROTO_UDP, mdnsPayload.size)
+        val packetBuilder = PacketBuilder(packetBuffer)
+        // Multicast ethernet address for IPv6 to ff02::fb
+        val multicastEthAddr = MacAddress.fromBytes(
+                byteArrayOf(0x33, 0x33, 0, 0, 0, 0xfb.toByte()))
+        packetBuilder.writeL2Header(
+                MacAddress.fromBytes(byteArrayOf(1, 2, 3, 4, 5, 6)) /* srcMac */,
+                multicastEthAddr,
+                ETH_P_IPV6.toShort())
+        packetBuilder.writeIpv6Header(
+                0x60000000, // version=6, traffic class=0x0, flowlabel=0x0
+                IPPROTO_UDP.toByte(),
+                64 /* hop limit */,
+                parseNumericAddress("2001:db8::123") as Inet6Address /* srcIp */,
+                multicastIpv6Addr /* dstIp */)
+        packetBuilder.writeUdpHeader(MDNS_PORT /* srcPort */, MDNS_PORT /* dstPort */)
+        packetBuffer.put(mdnsPayload)
+        return packetBuilder.finalizePacket()
+    }
+
     /**
      * Register a service and return its registration record.
      */
@@ -1169,7 +1352,65 @@
     }
 }
 
+private fun TapPacketReader.pollForMdnsPacket(
+    timeoutMs: Long = REGISTRATION_TIMEOUT_MS,
+    predicate: (TestDnsPacket) -> Boolean
+): ByteArray? {
+    val mdnsProbeFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT).and {
+        val mdnsPayload = it.copyOfRange(
+                ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, it.size)
+        try {
+            predicate(TestDnsPacket(mdnsPayload))
+        } catch (e: DnsPacket.ParseException) {
+            false
+        }
+    }
+    return poll(timeoutMs, mdnsProbeFilter)
+}
+
+private fun TapPacketReader.pollForProbe(
+    serviceName: String,
+    serviceType: String,
+    timeoutMs: Long = REGISTRATION_TIMEOUT_MS
+): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isProbeFor("$serviceName.$serviceType.local") }
+
+private fun TapPacketReader.pollForAdvertisement(
+    serviceName: String,
+    serviceType: String,
+    timeoutMs: Long = REGISTRATION_TIMEOUT_MS
+): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isReplyFor("$serviceName.$serviceType.local") }
+
+private class TestDnsPacket(data: ByteArray) : DnsPacket(data) {
+    fun isProbeFor(name: String): Boolean = mRecords[QDSECTION].any {
+        it.dName == name && it.nsType == 0xff /* ANY */
+    }
+
+    fun isReplyFor(name: String): Boolean = mRecords[ANSECTION].any {
+        it.dName == name && it.nsType == 0x21 /* SRV */
+    }
+}
+
 private fun ByteArray?.utf8ToString(): String {
     if (this == null) return ""
     return String(this, StandardCharsets.UTF_8)
 }
+
+private fun ByteArray.indexOf(sub: ByteArray): Int {
+    var subIndex = 0
+    forEachIndexed { i, b ->
+        when (b) {
+            // Still matching: continue comparing with next byte
+            sub[subIndex] -> {
+                subIndex++
+                if (subIndex == sub.size) {
+                    return i - sub.size + 1
+                }
+            }
+            // Not matching next byte but matches first byte: continue comparing with 2nd byte
+            sub[0] -> subIndex = 1
+            // No matches: continue comparing from first byte
+            else -> subIndex = 0
+        }
+    }
+    return -1
+}