Add CTS test for DnsResolver cancellation

Verify that no more queries are sent after the cancellation is
requested.

This assumes that the DnsResolver leaves enough time between retries for
the cancellation to reliably reach it.

Test: atest
Change-Id: Iad00dc1d181656797557d97d360a46c1014a8669
diff --git a/tests/cts/net/src/android/net/cts/DnsResolverTapTest.kt b/tests/cts/net/src/android/net/cts/DnsResolverTapTest.kt
index c8f2f7c..b1b6e0d 100644
--- a/tests/cts/net/src/android/net/cts/DnsResolverTapTest.kt
+++ b/tests/cts/net/src/android/net/cts/DnsResolverTapTest.kt
@@ -17,24 +17,41 @@
 package android.net.cts
 
 import android.Manifest.permission.MANAGE_TEST_NETWORKS
+import android.Manifest.permission.READ_DEVICE_CONFIG
+import android.net.DnsResolver
 import android.net.InetAddresses.parseNumericAddress
 import android.net.IpPrefix
 import android.net.MacAddress
 import android.net.RouteInfo
+import android.os.CancellationSignal
 import android.os.HandlerThread
+import android.os.SystemClock
+import android.provider.DeviceConfig
+import android.provider.DeviceConfig.NAMESPACE_NETD_NATIVE
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import androidx.test.platform.app.InstrumentationRegistry
+import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN
+import com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN
+import com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN
 import com.android.testutils.AutoReleaseNetworkCallbackRule
+import com.android.testutils.DeviceConfigRule
 import com.android.testutils.DnsResolverModuleTest
+import com.android.testutils.IPv6UdpFilter
 import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
 import com.android.testutils.RouterAdvertisementResponder
 import com.android.testutils.TapPacketReaderRule
 import com.android.testutils.TestableNetworkAgent
+import com.android.testutils.TestDnsPacket
+import com.android.testutils.com.android.testutils.SetFeatureFlagsRule
 import com.android.testutils.runAsShell
 import java.net.Inet6Address
+import java.net.InetAddress
+import kotlin.test.assertNotNull
+import kotlin.test.assertNull
 import org.junit.After
 import org.junit.Before
 import org.junit.Rule
+import org.junit.Test
 import org.junit.runner.RunWith
 
 private val TEST_DNSSERVER_MAC = MacAddress.fromString("00:11:22:33:44:55")
@@ -48,9 +65,29 @@
     private val handlerThread = HandlerThread(TAG)
 
     @get:Rule(order = 1)
-    val packetReaderRule = TapPacketReaderRule()
+    val deviceConfigRule = DeviceConfigRule()
 
     @get:Rule(order = 2)
+    val featureFlagsRule = SetFeatureFlagsRule(
+        setFlagsMethod = { name, enabled ->
+            val value = when (enabled) {
+                null -> null
+                true -> "1"
+                false -> "0"
+            }
+            deviceConfigRule.setConfig(NAMESPACE_NETD_NATIVE, name, value)
+        },
+        getFlagsMethod = {
+            runAsShell(READ_DEVICE_CONFIG) {
+                DeviceConfig.getInt(NAMESPACE_NETD_NATIVE, it, 0) == 1
+            }
+        }
+    )
+
+    @get:Rule(order = 3)
+    val packetReaderRule = TapPacketReaderRule()
+
+    @get:Rule(order = 4)
     val cbRule = AutoReleaseNetworkCallbackRule()
 
     private val ndResponder by lazy { RouterAdvertisementResponder(packetReaderRule.reader) }
@@ -90,4 +127,57 @@
         handlerThread.quitSafely()
         handlerThread.join()
     }
+
+    private class DnsCallback : DnsResolver.Callback<List<InetAddress>> {
+        override fun onAnswer(answer: List<InetAddress>, rcode: Int) = Unit
+        override fun onError(error: DnsResolver.DnsException) = Unit
+    }
+
+    /**
+     * Run a cancellation test.
+     *
+     * @param domain Domain name to query
+     * @param waitTimeForNoRetryAfterCancellationMs If positive, cancel the query and wait for that
+     *                                              delay to check no retry is sent.
+     * @return The duration it took to receive all expected replies.
+     */
+    fun doCancellationTest(domain: String, waitTimeForNoRetryAfterCancellationMs: Long): Long {
+        val cancellationSignal = CancellationSignal()
+        val dnsCb = DnsCallback()
+        val queryStart = SystemClock.elapsedRealtime()
+        DnsResolver.getInstance().query(
+            agent.network, domain, 0 /* flags */,
+            Runnable::run /* executor */, cancellationSignal, dnsCb
+        )
+
+        if (waitTimeForNoRetryAfterCancellationMs > 0) {
+            cancellationSignal.cancel()
+        }
+        // Filter for queries on UDP port 53 for the specified domain
+        val filter = IPv6UdpFilter(dstPort = 53).and {
+            TestDnsPacket(
+                it.copyOfRange(ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, it.size),
+                dstAddr = dnsServerAddr
+            ).isQueryFor(domain, DnsResolver.TYPE_AAAA)
+        }
+
+        val reader = packetReaderRule.reader
+        assertNotNull(reader.poll(TEST_TIMEOUT_MS, filter), "Original query not found")
+        if (waitTimeForNoRetryAfterCancellationMs > 0) {
+            assertNull(reader.poll(waitTimeForNoRetryAfterCancellationMs, filter),
+                "Expected no retry query")
+        } else {
+            assertNotNull(reader.poll(TEST_TIMEOUT_MS, filter), "Retry query not found")
+        }
+        return SystemClock.elapsedRealtime() - queryStart
+    }
+
+    @SetFeatureFlagsRule.FeatureFlag("no_retry_after_cancel", true)
+    @Test
+    fun testCancellation() {
+        val timeWithRetryWhenNotCancelled = doCancellationTest("test1.example.com",
+            waitTimeForNoRetryAfterCancellationMs = 0L)
+        doCancellationTest("test2.example.com",
+            waitTimeForNoRetryAfterCancellationMs = timeWithRetryWhenNotCancelled + 50L)
+    }
 }
\ No newline at end of file