Add test for downstream tethering

Add end-to-end testing for testing NsdManager advertising and
discovering works fine with downstream tethering interfaces.

Bug: 281639507
Test: atest NsdManagerTest
Change-Id: I5a66423f216cfe0c82db5128502c885980ab264b
diff --git a/Tethering/tests/integration/Android.bp b/Tethering/tests/integration/Android.bp
index 20f0bc6..2594a5e 100644
--- a/Tethering/tests/integration/Android.bp
+++ b/Tethering/tests/integration/Android.bp
@@ -28,7 +28,7 @@
         "DhcpPacketLib",
         "androidx.test.rules",
         "cts-net-utils",
-        "mockito-target-extended-minus-junit4",
+        "mockito-target-minus-junit4",
         "net-tests-utils",
         "net-utils-device-common",
         "net-utils-device-common-bpf",
@@ -40,11 +40,6 @@
         "android.test.base",
         "android.test.mock",
     ],
-    jni_libs: [
-        // For mockito extended
-        "libdexmakerjvmtiagent",
-        "libstaticjvmtiagent",
-    ],
 }
 
 android_library {
@@ -54,6 +49,7 @@
     defaults: ["TetheringIntegrationTestsDefaults"],
     visibility: [
         "//packages/modules/Connectivity/Tethering/tests/mts",
+        "//packages/modules/Connectivity/tests/cts/net",
     ]
 }
 
diff --git a/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java b/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
index 83fc3e4..0702aa7 100644
--- a/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
+++ b/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
@@ -31,14 +31,12 @@
 import static android.net.TetheringTester.isExpectedIcmpPacket;
 import static android.net.TetheringTester.isExpectedTcpPacket;
 import static android.net.TetheringTester.isExpectedUdpPacket;
-
 import static com.android.net.module.util.HexDump.dumpHexString;
 import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ROUTER_ADVERTISEMENT;
 import static com.android.net.module.util.NetworkStackConstants.TCPHDR_ACK;
 import static com.android.net.module.util.NetworkStackConstants.TCPHDR_SYN;
 import static com.android.testutils.TestNetworkTrackerKt.initTestNetwork;
 import static com.android.testutils.TestPermissionUtil.runAsShell;
-
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
@@ -164,6 +162,10 @@
     private TapPacketReader mDownstreamReader;
     private MyTetheringEventCallback mTetheringEventCallback;
 
+    public Context getContext() {
+        return mContext;
+    }
+
     @BeforeClass
     public static void setUpOnce() throws Exception {
         // The first test case may experience tethering restart with IP conflict handling.
diff --git a/Tethering/tests/integration/base/android/net/TetheringTester.java b/Tethering/tests/integration/base/android/net/TetheringTester.java
index 4f3c6e7..ae4ae55 100644
--- a/Tethering/tests/integration/base/android/net/TetheringTester.java
+++ b/Tethering/tests/integration/base/android/net/TetheringTester.java
@@ -27,12 +27,9 @@
 import static android.system.OsConstants.IPPROTO_IPV6;
 import static android.system.OsConstants.IPPROTO_TCP;
 import static android.system.OsConstants.IPPROTO_UDP;
-
 import static com.android.net.module.util.DnsPacket.ANSECTION;
-import static com.android.net.module.util.DnsPacket.ARSECTION;
 import static com.android.net.module.util.DnsPacket.DnsHeader;
 import static com.android.net.module.util.DnsPacket.DnsRecord;
-import static com.android.net.module.util.DnsPacket.NSSECTION;
 import static com.android.net.module.util.DnsPacket.QDSECTION;
 import static com.android.net.module.util.HexDump.dumpHexString;
 import static com.android.net.module.util.IpUtils.icmpChecksum;
@@ -56,7 +53,6 @@
 import static com.android.net.module.util.NetworkStackConstants.NEIGHBOR_ADVERTISEMENT_FLAG_OVERRIDE;
 import static com.android.net.module.util.NetworkStackConstants.NEIGHBOR_ADVERTISEMENT_FLAG_SOLICITED;
 import static com.android.net.module.util.NetworkStackConstants.TCPHDR_SYN;
-
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.fail;
 
diff --git a/framework/src/android/net/DnsResolver.java b/framework/src/android/net/DnsResolver.java
index c6034f1..5fefcd6 100644
--- a/framework/src/android/net/DnsResolver.java
+++ b/framework/src/android/net/DnsResolver.java
@@ -77,6 +77,15 @@
     @interface QueryType {}
     public static final int TYPE_A = 1;
     public static final int TYPE_AAAA = 28;
+    // TODO: add below constants as part of QueryType and the public API
+    /** @hide */
+    public static final int TYPE_PTR = 12;
+    /** @hide */
+    public static final int TYPE_TXT = 16;
+    /** @hide */
+    public static final int TYPE_SRV = 33;
+    /** @hide */
+    public static final int TYPE_ANY = 255;
 
     @IntDef(prefix = { "FLAG_" }, value = {
             FLAG_EMPTY,
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 1250e65..feb8516 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -1677,7 +1677,10 @@
         mMdnsSocketProvider = deps.makeMdnsSocketProvider(ctx, handler.getLooper(),
                 LOGGER.forSubComponent("MdnsSocketProvider"), new SocketRequestMonitor());
         // Netlink monitor starts on boot, and intentionally never stopped, to ensure that all
-        // address events are received.
+        // address events are received. When the netlink monitor starts, any IP addresses already
+        // on the interfaces will not be seen. In practice, the network will not connect at boot
+        // time As a result, all the netlink message should be observed if the netlink monitor
+        // starts here.
         handler.post(mMdnsSocketProvider::startNetLinkMonitor);
 
         // NsdService is started after ActivityManager (startOtherServices in SystemServer, vs.
diff --git a/tests/cts/net/Android.bp b/tests/cts/net/Android.bp
index 1276d59..6de663a 100644
--- a/tests/cts/net/Android.bp
+++ b/tests/cts/net/Android.bp
@@ -56,6 +56,7 @@
         "modules-utils-build",
         "net-utils-framework-common",
         "truth-prebuilt",
+        "TetheringIntegrationTestsBaseLib",
     ],
 
     // uncomment when b/13249961 is fixed
diff --git a/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt b/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt
new file mode 100644
index 0000000..bc13442
--- /dev/null
+++ b/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt
@@ -0,0 +1,295 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.net.cts
+
+import android.net.DnsResolver
+import android.net.Network
+import android.net.nsd.NsdManager
+import android.net.nsd.NsdServiceInfo
+import android.os.Process
+import com.android.net.module.util.ArrayTrackRecord
+import com.android.net.module.util.DnsPacket
+import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN
+import com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN
+import com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN
+import com.android.net.module.util.TrackRecord
+import com.android.testutils.IPv6UdpFilter
+import com.android.testutils.TapPacketReader
+import kotlin.test.assertEquals
+import kotlin.test.assertNotNull
+import kotlin.test.assertNull
+import kotlin.test.assertTrue
+import kotlin.test.fail
+
+private const val MDNS_REGISTRATION_TIMEOUT_MS = 10_000L
+private const val MDNS_PORT = 5353.toShort()
+const val MDNS_CALLBACK_TIMEOUT = 2000L
+const val MDNS_NO_CALLBACK_TIMEOUT_MS = 200L
+
+interface NsdEvent
+open class NsdRecord<T : NsdEvent> private constructor(
+    private val history: ArrayTrackRecord<T>,
+    private val expectedThreadId: Int? = null
+) : TrackRecord<T> by history {
+    constructor(expectedThreadId: Int? = null) : this(ArrayTrackRecord(), expectedThreadId)
+
+    val nextEvents = history.newReadHead()
+
+    override fun add(e: T): Boolean {
+        if (expectedThreadId != null) {
+            assertEquals(
+                expectedThreadId, Process.myTid(),
+                "Callback is running on the wrong thread"
+            )
+        }
+        return history.add(e)
+    }
+
+    inline fun <reified V : NsdEvent> expectCallbackEventually(
+        timeoutMs: Long = MDNS_CALLBACK_TIMEOUT,
+        crossinline predicate: (V) -> Boolean = { true }
+    ): V = nextEvents.poll(timeoutMs) { e -> e is V && predicate(e) } as V?
+        ?: fail("Callback for ${V::class.java.simpleName} not seen after $timeoutMs ms")
+
+    inline fun <reified V : NsdEvent> expectCallback(timeoutMs: Long = MDNS_CALLBACK_TIMEOUT): V {
+        val nextEvent = nextEvents.poll(timeoutMs)
+        assertNotNull(
+            nextEvent, "No callback received after $timeoutMs ms, expected " +
+                    "${V::class.java.simpleName}"
+        )
+        assertTrue(
+            nextEvent is V, "Expected ${V::class.java.simpleName} but got " +
+                    nextEvent.javaClass.simpleName
+        )
+        return nextEvent
+    }
+
+    inline fun assertNoCallback(timeoutMs: Long = MDNS_NO_CALLBACK_TIMEOUT_MS) {
+        val cb = nextEvents.poll(timeoutMs)
+        assertNull(cb, "Expected no callback but got $cb")
+    }
+}
+
+class NsdDiscoveryRecord(expectedThreadId: Int? = null) :
+    NsdManager.DiscoveryListener, NsdRecord<NsdDiscoveryRecord.DiscoveryEvent>(expectedThreadId) {
+    sealed class DiscoveryEvent : NsdEvent {
+        data class StartDiscoveryFailed(val serviceType: String, val errorCode: Int) :
+            DiscoveryEvent()
+
+        data class StopDiscoveryFailed(val serviceType: String, val errorCode: Int) :
+            DiscoveryEvent()
+
+        data class DiscoveryStarted(val serviceType: String) : DiscoveryEvent()
+        data class DiscoveryStopped(val serviceType: String) : DiscoveryEvent()
+        data class ServiceFound(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()
+        data class ServiceLost(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()
+    }
+
+    override fun onStartDiscoveryFailed(serviceType: String, err: Int) {
+        add(DiscoveryEvent.StartDiscoveryFailed(serviceType, err))
+    }
+
+    override fun onStopDiscoveryFailed(serviceType: String, err: Int) {
+        add(DiscoveryEvent.StopDiscoveryFailed(serviceType, err))
+    }
+
+    override fun onDiscoveryStarted(serviceType: String) {
+        add(DiscoveryEvent.DiscoveryStarted(serviceType))
+    }
+
+    override fun onDiscoveryStopped(serviceType: String) {
+        add(DiscoveryEvent.DiscoveryStopped(serviceType))
+    }
+
+    override fun onServiceFound(si: NsdServiceInfo) {
+        add(DiscoveryEvent.ServiceFound(si))
+    }
+
+    override fun onServiceLost(si: NsdServiceInfo) {
+        add(DiscoveryEvent.ServiceLost(si))
+    }
+
+    fun waitForServiceDiscovered(
+        serviceName: String,
+        serviceType: String,
+        expectedNetwork: Network? = null
+    ): NsdServiceInfo {
+        val serviceFound = expectCallbackEventually<DiscoveryEvent.ServiceFound> {
+            it.serviceInfo.serviceName == serviceName &&
+                    (expectedNetwork == null ||
+                            expectedNetwork == it.serviceInfo.network)
+        }.serviceInfo
+        // Discovered service types have a dot at the end
+        assertEquals("$serviceType.", serviceFound.serviceType)
+        return serviceFound
+    }
+}
+
+class NsdRegistrationRecord(expectedThreadId: Int? = null) : NsdManager.RegistrationListener,
+    NsdRecord<NsdRegistrationRecord.RegistrationEvent>(expectedThreadId) {
+    sealed class RegistrationEvent : NsdEvent {
+        abstract val serviceInfo: NsdServiceInfo
+
+        data class RegistrationFailed(
+            override val serviceInfo: NsdServiceInfo,
+            val errorCode: Int
+        ) : RegistrationEvent()
+
+        data class UnregistrationFailed(
+            override val serviceInfo: NsdServiceInfo,
+            val errorCode: Int
+        ) : RegistrationEvent()
+
+        data class ServiceRegistered(override val serviceInfo: NsdServiceInfo) :
+            RegistrationEvent()
+
+        data class ServiceUnregistered(override val serviceInfo: NsdServiceInfo) :
+            RegistrationEvent()
+    }
+
+    override fun onRegistrationFailed(si: NsdServiceInfo, err: Int) {
+        add(RegistrationEvent.RegistrationFailed(si, err))
+    }
+
+    override fun onUnregistrationFailed(si: NsdServiceInfo, err: Int) {
+        add(RegistrationEvent.UnregistrationFailed(si, err))
+    }
+
+    override fun onServiceRegistered(si: NsdServiceInfo) {
+        add(RegistrationEvent.ServiceRegistered(si))
+    }
+
+    override fun onServiceUnregistered(si: NsdServiceInfo) {
+        add(RegistrationEvent.ServiceUnregistered(si))
+    }
+}
+
+class NsdResolveRecord : NsdManager.ResolveListener,
+    NsdRecord<NsdResolveRecord.ResolveEvent>() {
+    sealed class ResolveEvent : NsdEvent {
+        data class ResolveFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) :
+            ResolveEvent()
+
+        data class ServiceResolved(val serviceInfo: NsdServiceInfo) : ResolveEvent()
+        data class ResolutionStopped(val serviceInfo: NsdServiceInfo) : ResolveEvent()
+        data class StopResolutionFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) :
+            ResolveEvent()
+    }
+
+    override fun onResolveFailed(si: NsdServiceInfo, err: Int) {
+        add(ResolveEvent.ResolveFailed(si, err))
+    }
+
+    override fun onServiceResolved(si: NsdServiceInfo) {
+        add(ResolveEvent.ServiceResolved(si))
+    }
+
+    override fun onResolutionStopped(si: NsdServiceInfo) {
+        add(ResolveEvent.ResolutionStopped(si))
+    }
+
+    override fun onStopResolutionFailed(si: NsdServiceInfo, err: Int) {
+        super.onStopResolutionFailed(si, err)
+        add(ResolveEvent.StopResolutionFailed(si, err))
+    }
+}
+
+class NsdServiceInfoCallbackRecord : NsdManager.ServiceInfoCallback,
+    NsdRecord<NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent>() {
+    sealed class ServiceInfoCallbackEvent : NsdEvent {
+        data class RegisterCallbackFailed(val errorCode: Int) : ServiceInfoCallbackEvent()
+        data class ServiceUpdated(val serviceInfo: NsdServiceInfo) : ServiceInfoCallbackEvent()
+        object ServiceUpdatedLost : ServiceInfoCallbackEvent()
+        object UnregisterCallbackSucceeded : ServiceInfoCallbackEvent()
+    }
+
+    override fun onServiceInfoCallbackRegistrationFailed(err: Int) {
+        add(ServiceInfoCallbackEvent.RegisterCallbackFailed(err))
+    }
+
+    override fun onServiceUpdated(si: NsdServiceInfo) {
+        add(ServiceInfoCallbackEvent.ServiceUpdated(si))
+    }
+
+    override fun onServiceLost() {
+        add(ServiceInfoCallbackEvent.ServiceUpdatedLost)
+    }
+
+    override fun onServiceInfoCallbackUnregistered() {
+        add(ServiceInfoCallbackEvent.UnregisterCallbackSucceeded)
+    }
+}
+
+fun TapPacketReader.pollForMdnsPacket(
+    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS,
+    predicate: (TestDnsPacket) -> Boolean
+): ByteArray? {
+    val mdnsProbeFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT).and {
+        val mdnsPayload = it.copyOfRange(
+            ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, it.size
+        )
+        try {
+            predicate(TestDnsPacket(mdnsPayload))
+        } catch (e: DnsPacket.ParseException) {
+            false
+        }
+    }
+    return poll(timeoutMs, mdnsProbeFilter)
+}
+
+fun TapPacketReader.pollForProbe(
+    serviceName: String,
+    serviceType: String,
+    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
+): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isProbeFor("$serviceName.$serviceType.local") }
+
+fun TapPacketReader.pollForAdvertisement(
+    serviceName: String,
+    serviceType: String,
+    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
+): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isReplyFor("$serviceName.$serviceType.local") }
+
+fun TapPacketReader.pollForQuery(
+    recordName: String,
+    recordType: Int,
+    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
+): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isQueryFor(recordName, recordType) }
+
+fun TapPacketReader.pollForReply(
+    serviceName: String,
+    serviceType: String,
+    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
+): ByteArray? = pollForMdnsPacket(timeoutMs) {
+    it.isReplyFor("$serviceName.$serviceType.local")
+}
+
+class TestDnsPacket(data: ByteArray) : DnsPacket(data) {
+    val header: DnsHeader
+        get() = mHeader
+    val records: Array<List<DnsRecord>>
+        get() = mRecords
+    fun isProbeFor(name: String): Boolean = mRecords[QDSECTION].any {
+        it.dName == name && it.nsType == DnsResolver.TYPE_ANY
+    }
+
+    fun isReplyFor(name: String): Boolean = mRecords[ANSECTION].any {
+        it.dName == name && it.nsType == DnsResolver.TYPE_SRV
+    }
+
+    fun isQueryFor(name: String, type: Int): Boolean = mRecords[QDSECTION].any {
+        it.dName == name && it.nsType == type
+    }
+}
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerDownstreamTetheringTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerDownstreamTetheringTest.kt
new file mode 100644
index 0000000..c2bb7cd
--- /dev/null
+++ b/tests/cts/net/src/android/net/cts/NsdManagerDownstreamTetheringTest.kt
@@ -0,0 +1,150 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.net.cts
+
+import android.net.EthernetTetheringTestBase
+import android.net.LinkAddress
+import android.net.TestNetworkInterface
+import android.net.TetheringManager.CONNECTIVITY_SCOPE_LOCAL
+import android.net.TetheringManager.TETHERING_ETHERNET
+import android.net.TetheringManager.TetheringRequest
+import android.net.nsd.NsdManager
+import android.os.Build
+import androidx.test.filters.SmallTest
+import com.android.testutils.ConnectivityModuleTest
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.TapPacketReader
+import com.android.testutils.tryTest
+import java.util.Random
+import kotlin.test.assertEquals
+import kotlin.test.assertNotNull
+import org.junit.After
+import org.junit.Assume.assumeFalse
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+
+@RunWith(DevSdkIgnoreRunner::class)
+@SmallTest
+@ConnectivityModuleTest
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
+class NsdManagerDownstreamTetheringTest : EthernetTetheringTestBase() {
+    private val nsdManager by lazy { context.getSystemService(NsdManager::class.java)!! }
+    private val serviceType = "_nmt%09d._tcp".format(Random().nextInt(1_000_000_000))
+
+    @Before
+    override fun setUp() {
+        super.setUp()
+        setIncludeTestInterfaces(true)
+    }
+
+    @After
+    override fun tearDown() {
+        super.tearDown()
+        setIncludeTestInterfaces(false)
+    }
+
+    @Test
+    fun testMdnsDiscoveryCanSendPacketOnLocalOnlyDownstreamTetheringInterface() {
+        assumeFalse(isInterfaceForTetheringAvailable)
+
+        var downstreamIface: TestNetworkInterface? = null
+        var tetheringEventCallback: MyTetheringEventCallback? = null
+        var downstreamReader: TapPacketReader? = null
+
+        val discoveryRecord = NsdDiscoveryRecord()
+
+        tryTest {
+            downstreamIface = createTestInterface()
+            val iface = tetheredInterface
+            assertEquals(iface, downstreamIface?.interfaceName)
+            val request = TetheringRequest.Builder(TETHERING_ETHERNET)
+                .setConnectivityScope(CONNECTIVITY_SCOPE_LOCAL).build()
+            tetheringEventCallback = enableEthernetTethering(
+                iface, request,
+                null /* any upstream */
+            ).apply {
+                awaitInterfaceLocalOnly()
+            }
+            // This shouldn't be flaky because the TAP interface will buffer all packets even
+            // before the reader is started.
+            downstreamReader = makePacketReader(downstreamIface)
+            waitForRouterAdvertisement(downstreamReader, iface, WAIT_RA_TIMEOUT_MS)
+
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, discoveryRecord)
+            discoveryRecord.expectCallback<NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStarted>()
+            assertNotNull(downstreamReader?.pollForQuery("$serviceType.local", 12 /* type PTR */))
+        } cleanupStep {
+            nsdManager.stopServiceDiscovery(discoveryRecord)
+            discoveryRecord.expectCallback<NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStopped>()
+        } cleanupStep {
+            maybeStopTapPacketReader(downstreamReader)
+        } cleanupStep {
+            maybeCloseTestInterface(downstreamIface)
+        } cleanup {
+            maybeUnregisterTetheringEventCallback(tetheringEventCallback)
+        }
+    }
+
+    @Test
+    fun testMdnsDiscoveryWorkOnTetheringInterface() {
+        assumeFalse(isInterfaceForTetheringAvailable)
+        setIncludeTestInterfaces(true)
+
+        var downstreamIface: TestNetworkInterface? = null
+        var tetheringEventCallback: MyTetheringEventCallback? = null
+        var downstreamReader: TapPacketReader? = null
+
+        val discoveryRecord = NsdDiscoveryRecord()
+
+        tryTest {
+            downstreamIface = createTestInterface()
+            val iface = tetheredInterface
+            assertEquals(iface, downstreamIface?.interfaceName)
+
+            val localAddr = LinkAddress("192.0.2.3/28")
+            val clientAddr = LinkAddress("192.0.2.2/28")
+            val request = TetheringRequest.Builder(TETHERING_ETHERNET)
+                .setStaticIpv4Addresses(localAddr, clientAddr)
+                .setShouldShowEntitlementUi(false).build()
+            tetheringEventCallback = enableEthernetTethering(
+                iface, request,
+                null /* any upstream */
+            ).apply {
+                awaitInterfaceTethered()
+            }
+
+            val fd = downstreamIface?.fileDescriptor?.fileDescriptor
+            assertNotNull(fd)
+            downstreamReader = makePacketReader(fd, getMTU(downstreamIface))
+
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, discoveryRecord)
+            discoveryRecord.expectCallback<NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStarted>()
+            assertNotNull(downstreamReader?.pollForQuery("$serviceType.local", 12 /* type PTR */))
+            // TODO: Add another test to check packet reply can trigger serviceFound.
+        } cleanupStep {
+            nsdManager.stopServiceDiscovery(discoveryRecord)
+            discoveryRecord.expectCallback<NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStopped>()
+        } cleanupStep {
+            maybeStopTapPacketReader(downstreamReader)
+        } cleanupStep {
+            maybeCloseTestInterface(downstreamIface)
+        } cleanup {
+            maybeUnregisterTetheringEventCallback(tetheringEventCallback)
+        }
+    }
+}
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index 17a135a..27bd5d3 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -38,36 +38,26 @@
 import android.net.TestNetworkManager
 import android.net.TestNetworkSpecifier
 import android.net.connectivity.ConnectivityCompatChanges
-import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStarted
-import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStopped
-import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.ServiceFound
-import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.ServiceLost
-import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.StartDiscoveryFailed
-import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.StopDiscoveryFailed
-import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.RegistrationFailed
-import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.ServiceRegistered
-import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.ServiceUnregistered
-import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.UnregistrationFailed
-import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.ResolutionStopped
-import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.ResolveFailed
-import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.ServiceResolved
-import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.StopResolutionFailed
-import android.net.cts.NsdManagerTest.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.RegisterCallbackFailed
-import android.net.cts.NsdManagerTest.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.ServiceUpdated
-import android.net.cts.NsdManagerTest.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.ServiceUpdatedLost
-import android.net.cts.NsdManagerTest.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.UnregisterCallbackSucceeded
+import android.net.cts.NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStarted
+import android.net.cts.NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStopped
+import android.net.cts.NsdDiscoveryRecord.DiscoveryEvent.ServiceFound
+import android.net.cts.NsdDiscoveryRecord.DiscoveryEvent.ServiceLost
+import android.net.cts.NsdRegistrationRecord.RegistrationEvent.ServiceRegistered
+import android.net.cts.NsdRegistrationRecord.RegistrationEvent.ServiceUnregistered
+import android.net.cts.NsdResolveRecord.ResolveEvent.ResolutionStopped
+import android.net.cts.NsdResolveRecord.ResolveEvent.ServiceResolved
+import android.net.cts.NsdResolveRecord.ResolveEvent.StopResolutionFailed
+import android.net.cts.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.ServiceUpdated
+import android.net.cts.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.ServiceUpdatedLost
+import android.net.cts.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.UnregisterCallbackSucceeded
 import android.net.cts.util.CtsNetUtils
 import android.net.nsd.NsdManager
-import android.net.nsd.NsdManager.DiscoveryListener
-import android.net.nsd.NsdManager.RegistrationListener
-import android.net.nsd.NsdManager.ResolveListener
 import android.net.nsd.NsdServiceInfo
 import android.net.nsd.OffloadEngine
 import android.net.nsd.OffloadServiceInfo
 import android.os.Build
 import android.os.Handler
 import android.os.HandlerThread
-import android.os.Process.myTid
 import android.platform.test.annotations.AppModeFull
 import android.system.ErrnoException
 import android.system.Os
@@ -84,19 +74,13 @@
 import com.android.compatibility.common.util.PollingCheck
 import com.android.compatibility.common.util.PropertyUtil
 import com.android.modules.utils.build.SdkLevel.isAtLeastU
-import com.android.net.module.util.ArrayTrackRecord
 import com.android.net.module.util.DnsPacket
 import com.android.net.module.util.HexDump
-import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN
-import com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN
-import com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN
 import com.android.net.module.util.PacketBuilder
-import com.android.net.module.util.TrackRecord
 import com.android.testutils.ConnectivityModuleTest
 import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
 import com.android.testutils.DevSdkIgnoreRunner
-import com.android.testutils.IPv6UdpFilter
 import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
 import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
 import com.android.testutils.TapPacketReader
@@ -123,7 +107,6 @@
 import kotlin.test.assertFailsWith
 import kotlin.test.assertNotNull
 import kotlin.test.assertNull
-import kotlin.test.assertTrue
 import kotlin.test.fail
 import org.junit.After
 import org.junit.Assert.assertArrayEquals
@@ -137,7 +120,6 @@
 
 private const val TAG = "NsdManagerTest"
 private const val TIMEOUT_MS = 2000L
-private const val NO_CALLBACK_TIMEOUT_MS = 200L
 // Registration may take a long time if there are devices with the same hostname on the network,
 // as the device needs to try another name and probe again. This is especially true since when using
 // mdnsresponder the usual hostname is "Android", and on conflict "Android-2", "Android-3", ... are
@@ -159,7 +141,9 @@
     val ignoreRule = DevSdkIgnoreRule()
 
     private val context by lazy { InstrumentationRegistry.getInstrumentation().context }
-    private val nsdManager by lazy { context.getSystemService(NsdManager::class.java)!! }
+    private val nsdManager by lazy {
+        context.getSystemService(NsdManager::class.java) ?: fail("Could not get NsdManager service")
+    }
 
     private val cm by lazy { context.getSystemService(ConnectivityManager::class.java)!! }
     private val serviceName = "NsdTest%09d".format(Random().nextInt(1_000_000_000))
@@ -185,192 +169,6 @@
         }
     }
 
-    private interface NsdEvent
-    private open class NsdRecord<T : NsdEvent> private constructor(
-        private val history: ArrayTrackRecord<T>,
-        private val expectedThreadId: Int? = null
-    ) : TrackRecord<T> by history {
-        constructor(expectedThreadId: Int? = null) : this(ArrayTrackRecord(), expectedThreadId)
-
-        val nextEvents = history.newReadHead()
-
-        override fun add(e: T): Boolean {
-            if (expectedThreadId != null) {
-                assertEquals(expectedThreadId, myTid(), "Callback is running on the wrong thread")
-            }
-            return history.add(e)
-        }
-
-        inline fun <reified V : NsdEvent> expectCallbackEventually(
-            timeoutMs: Long = TIMEOUT_MS,
-            crossinline predicate: (V) -> Boolean = { true }
-        ): V = nextEvents.poll(timeoutMs) { e -> e is V && predicate(e) } as V?
-                ?: fail("Callback for ${V::class.java.simpleName} not seen after $timeoutMs ms")
-
-        inline fun <reified V : NsdEvent> expectCallback(timeoutMs: Long = TIMEOUT_MS): V {
-            val nextEvent = nextEvents.poll(timeoutMs)
-            assertNotNull(nextEvent, "No callback received after $timeoutMs ms, " +
-                    "expected ${V::class.java.simpleName}")
-            assertTrue(nextEvent is V, "Expected ${V::class.java.simpleName} but got " +
-                    nextEvent.javaClass.simpleName)
-            return nextEvent
-        }
-
-        inline fun assertNoCallback(timeoutMs: Long = NO_CALLBACK_TIMEOUT_MS) {
-            val cb = nextEvents.poll(timeoutMs)
-            assertNull(cb, "Expected no callback but got $cb")
-        }
-    }
-
-    private class NsdRegistrationRecord(expectedThreadId: Int? = null) : RegistrationListener,
-            NsdRecord<NsdRegistrationRecord.RegistrationEvent>(expectedThreadId) {
-        sealed class RegistrationEvent : NsdEvent {
-            abstract val serviceInfo: NsdServiceInfo
-
-            data class RegistrationFailed(
-                override val serviceInfo: NsdServiceInfo,
-                val errorCode: Int
-            ) : RegistrationEvent()
-
-            data class UnregistrationFailed(
-                override val serviceInfo: NsdServiceInfo,
-                val errorCode: Int
-            ) : RegistrationEvent()
-
-            data class ServiceRegistered(override val serviceInfo: NsdServiceInfo) :
-                    RegistrationEvent()
-            data class ServiceUnregistered(override val serviceInfo: NsdServiceInfo) :
-                    RegistrationEvent()
-        }
-
-        override fun onRegistrationFailed(si: NsdServiceInfo, err: Int) {
-            add(RegistrationFailed(si, err))
-        }
-
-        override fun onUnregistrationFailed(si: NsdServiceInfo, err: Int) {
-            add(UnregistrationFailed(si, err))
-        }
-
-        override fun onServiceRegistered(si: NsdServiceInfo) {
-            add(ServiceRegistered(si))
-        }
-
-        override fun onServiceUnregistered(si: NsdServiceInfo) {
-            add(ServiceUnregistered(si))
-        }
-    }
-
-    private class NsdDiscoveryRecord(expectedThreadId: Int? = null) :
-            DiscoveryListener, NsdRecord<NsdDiscoveryRecord.DiscoveryEvent>(expectedThreadId) {
-        sealed class DiscoveryEvent : NsdEvent {
-            data class StartDiscoveryFailed(val serviceType: String, val errorCode: Int) :
-                    DiscoveryEvent()
-
-            data class StopDiscoveryFailed(val serviceType: String, val errorCode: Int) :
-                    DiscoveryEvent()
-
-            data class DiscoveryStarted(val serviceType: String) : DiscoveryEvent()
-            data class DiscoveryStopped(val serviceType: String) : DiscoveryEvent()
-            data class ServiceFound(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()
-            data class ServiceLost(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()
-        }
-
-        override fun onStartDiscoveryFailed(serviceType: String, err: Int) {
-            add(StartDiscoveryFailed(serviceType, err))
-        }
-
-        override fun onStopDiscoveryFailed(serviceType: String, err: Int) {
-            add(StopDiscoveryFailed(serviceType, err))
-        }
-
-        override fun onDiscoveryStarted(serviceType: String) {
-            add(DiscoveryStarted(serviceType))
-        }
-
-        override fun onDiscoveryStopped(serviceType: String) {
-            add(DiscoveryStopped(serviceType))
-        }
-
-        override fun onServiceFound(si: NsdServiceInfo) {
-            add(ServiceFound(si))
-        }
-
-        override fun onServiceLost(si: NsdServiceInfo) {
-            add(ServiceLost(si))
-        }
-
-        fun waitForServiceDiscovered(
-            serviceName: String,
-            serviceType: String,
-            expectedNetwork: Network? = null
-        ): NsdServiceInfo {
-            val serviceFound = expectCallbackEventually<ServiceFound> {
-                it.serviceInfo.serviceName == serviceName &&
-                        (expectedNetwork == null ||
-                                expectedNetwork == it.serviceInfo.network)
-            }.serviceInfo
-            // Discovered service types have a dot at the end
-            assertEquals("$serviceType.", serviceFound.serviceType)
-            return serviceFound
-        }
-    }
-
-    private class NsdResolveRecord : ResolveListener,
-            NsdRecord<NsdResolveRecord.ResolveEvent>() {
-        sealed class ResolveEvent : NsdEvent {
-            data class ResolveFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) :
-                    ResolveEvent()
-
-            data class ServiceResolved(val serviceInfo: NsdServiceInfo) : ResolveEvent()
-            data class ResolutionStopped(val serviceInfo: NsdServiceInfo) : ResolveEvent()
-            data class StopResolutionFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) :
-                    ResolveEvent()
-        }
-
-        override fun onResolveFailed(si: NsdServiceInfo, err: Int) {
-            add(ResolveFailed(si, err))
-        }
-
-        override fun onServiceResolved(si: NsdServiceInfo) {
-            add(ServiceResolved(si))
-        }
-
-        override fun onResolutionStopped(si: NsdServiceInfo) {
-            add(ResolutionStopped(si))
-        }
-
-        override fun onStopResolutionFailed(si: NsdServiceInfo, err: Int) {
-            super.onStopResolutionFailed(si, err)
-            add(StopResolutionFailed(si, err))
-        }
-    }
-
-    private class NsdServiceInfoCallbackRecord : NsdManager.ServiceInfoCallback,
-            NsdRecord<NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent>() {
-        sealed class ServiceInfoCallbackEvent : NsdEvent {
-            data class RegisterCallbackFailed(val errorCode: Int) : ServiceInfoCallbackEvent()
-            data class ServiceUpdated(val serviceInfo: NsdServiceInfo) : ServiceInfoCallbackEvent()
-            object ServiceUpdatedLost : ServiceInfoCallbackEvent()
-            object UnregisterCallbackSucceeded : ServiceInfoCallbackEvent()
-        }
-
-        override fun onServiceInfoCallbackRegistrationFailed(err: Int) {
-            add(RegisterCallbackFailed(err))
-        }
-
-        override fun onServiceUpdated(si: NsdServiceInfo) {
-            add(ServiceUpdated(si))
-        }
-
-        override fun onServiceLost() {
-            add(ServiceUpdatedLost)
-        }
-
-        override fun onServiceInfoCallbackUnregistered() {
-            add(UnregisterCallbackSucceeded)
-        }
-    }
-
     private class TestNsdOffloadEngine : OffloadEngine,
         NsdRecord<TestNsdOffloadEngine.OffloadEvent>() {
         sealed class OffloadEvent : NsdEvent {
@@ -1414,54 +1212,6 @@
     }
 }
 
-private fun TapPacketReader.pollForMdnsPacket(
-    timeoutMs: Long = REGISTRATION_TIMEOUT_MS,
-    predicate: (TestDnsPacket) -> Boolean
-): ByteArray? {
-    val mdnsProbeFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT).and {
-        val mdnsPayload = it.copyOfRange(
-                ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, it.size)
-        try {
-            predicate(TestDnsPacket(mdnsPayload))
-        } catch (e: DnsPacket.ParseException) {
-            false
-        }
-    }
-    return poll(timeoutMs, mdnsProbeFilter)
-}
-
-private fun TapPacketReader.pollForProbe(
-    serviceName: String,
-    serviceType: String,
-    timeoutMs: Long = REGISTRATION_TIMEOUT_MS
-): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isProbeFor("$serviceName.$serviceType.local") }
-
-private fun TapPacketReader.pollForAdvertisement(
-    serviceName: String,
-    serviceType: String,
-    timeoutMs: Long = REGISTRATION_TIMEOUT_MS
-): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isReplyFor("$serviceName.$serviceType.local") }
-
-private class TestDnsPacket(data: ByteArray) : DnsPacket(data) {
-    val header: DnsHeader
-        get() = mHeader
-    val records: Array<List<DnsRecord>>
-        get() = mRecords
-
-    fun isProbeFor(name: String): Boolean = mRecords[QDSECTION].any {
-        it.dName == name && it.nsType == 0xff /* ANY */
-    }
-
-    fun isReplyFor(name: String): Boolean = mRecords[ANSECTION].any {
-        it.dName == name && it.nsType == 0x21 /* SRV */
-    }
-}
-
-private fun ByteArray?.utf8ToString(): String {
-    if (this == null) return ""
-    return String(this, StandardCharsets.UTF_8)
-}
-
 private fun ByteArray.indexOf(sub: ByteArray): Int {
     var subIndex = 0
     forEachIndexed { i, b ->
@@ -1481,3 +1231,8 @@
     }
     return -1
 }
+
+private fun ByteArray?.utf8ToString(): String {
+    if (this == null) return ""
+    return String(this, StandardCharsets.UTF_8)
+}