[ST03] Add test dns server for integration tests
Bug: 139774492
Test: atest DnsAnswerProviderTest TestDnsServerTest
Change-Id: Ia5039a47f4b818efb09aab8174beb9d921339c3d
diff --git a/staticlibs/tests/unit/src/com/android/testutils/TestDnsServerTest.kt b/staticlibs/tests/unit/src/com/android/testutils/TestDnsServerTest.kt
new file mode 100644
index 0000000..6f4587b
--- /dev/null
+++ b/staticlibs/tests/unit/src/com/android/testutils/TestDnsServerTest.kt
@@ -0,0 +1,122 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.DnsResolver.CLASS_IN
+import android.net.DnsResolver.TYPE_AAAA
+import android.net.Network
+import androidx.test.ext.junit.runners.AndroidJUnit4
+import androidx.test.filters.SmallTest
+import com.android.net.module.util.DnsPacket
+import com.android.net.module.util.DnsPacket.DnsRecord
+import libcore.net.InetAddressUtils
+import org.junit.After
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.Mockito
+import java.net.DatagramPacket
+import java.net.DatagramSocket
+import java.net.InetAddress
+import java.net.InetSocketAddress
+import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
+import kotlin.test.assertFalse
+import kotlin.test.assertTrue
+
+val TEST_V6_ADDR = InetAddressUtils.parseNumericAddress("2001:db8::3")
+const val TEST_DOMAIN = "hello.example.com"
+
+@RunWith(AndroidJUnit4::class)
+@SmallTest
+class TestDnsServerTest {
+ private val network = Mockito.mock(Network::class.java)
+ private val localAddr = InetSocketAddress(InetAddress.getLocalHost(), 0 /* port */)
+ private val testServer: TestDnsServer = TestDnsServer(network, localAddr)
+
+ @After
+ fun tearDown() {
+ if (testServer.isAlive) testServer.stop()
+ }
+
+ @Test
+ fun testStartStop() {
+ repeat(100) {
+ val server = TestDnsServer(network, localAddr)
+ server.start()
+ assertTrue(server.isAlive)
+ server.stop()
+ assertFalse(server.isAlive)
+ }
+
+ // Test illegal start/stop.
+ assertFailsWith<IllegalStateException> { testServer.stop() }
+ testServer.start()
+ assertTrue(testServer.isAlive)
+ assertFailsWith<IllegalStateException> { testServer.start() }
+ testServer.stop()
+ assertFalse(testServer.isAlive)
+ assertFailsWith<IllegalStateException> { testServer.stop() }
+ // TestDnsServer rejects start after stop.
+ assertFailsWith<IllegalStateException> { testServer.start() }
+ }
+
+ @Test
+ fun testHandleDnsQuery() {
+ testServer.setAnswer(TEST_DOMAIN, listOf(TEST_V6_ADDR))
+ testServer.start()
+
+ // Mock query and send it to the test server.
+ val queryHeader = DnsPacket.DnsHeader(0xbeef /* id */,
+ 0x0 /* flag */, 1 /* qcount */, 0 /* ancount */)
+ val qlist = listOf(DnsRecord.makeQuestion(TEST_DOMAIN, TYPE_AAAA, CLASS_IN))
+ val queryPacket = TestDnsServer.DnsQueryPacket(queryHeader, qlist, emptyList())
+ val response = resolve(queryPacket, testServer.port)
+
+ // Verify expected answer packet. Set QR bit of flag to 1 for response packet
+ // according to RFC 1035 section 4.1.1.
+ val answerHeader = DnsPacket.DnsHeader(0xbeef,
+ 1 shl 15 /* flag */, 1 /* qcount */, 1 /* ancount */)
+ val alist = listOf(DnsRecord.makeAOrAAAARecord(DnsPacket.ANSECTION, TEST_DOMAIN,
+ CLASS_IN, DEFAULT_TTL_S, TEST_V6_ADDR))
+ val expectedAnswerPacket = TestDnsServer.DnsAnswerPacket(answerHeader, qlist, alist)
+ assertEquals(expectedAnswerPacket, response)
+
+ // Clean up the server in tearDown.
+ }
+
+ private fun resolve(queryDnsPacket: DnsPacket, serverPort: Int): TestDnsServer.DnsAnswerPacket {
+ val bytes = queryDnsPacket.bytes
+ // Create a new client socket, the socket will be bound to a
+ // random port other than the server port.
+ val socket = DatagramSocket(localAddr).also { it.soTimeout = 100 }
+ val queryPacket = DatagramPacket(bytes, bytes.size, localAddr.address, serverPort)
+
+ // Send query and wait for the reply.
+ socket.send(queryPacket)
+ val buffer = ByteArray(MAX_BUF_SIZE)
+ val reply = DatagramPacket(buffer, buffer.size)
+ socket.receive(reply)
+ return TestDnsServer.DnsAnswerPacket(reply.data)
+ }
+
+ // TODO: Add more tests, which includes:
+ // * Empty question RR packet (or more unexpected states)
+ // * No answer found (setAnswer empty list at L.78)
+ // * Test one or multi A record(s)
+ // * Test multi AAAA records
+ // * Test CNAME records
+}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/DnsAnswerProvider.kt b/staticlibs/testutils/devicetests/com/android/testutils/DnsAnswerProvider.kt
new file mode 100644
index 0000000..6a804bf
--- /dev/null
+++ b/staticlibs/testutils/devicetests/com/android/testutils/DnsAnswerProvider.kt
@@ -0,0 +1,65 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.DnsResolver.CLASS_IN
+import com.android.net.module.util.DnsPacket
+import com.android.net.module.util.DnsPacket.ANSECTION
+import java.net.InetAddress
+import java.util.concurrent.ConcurrentHashMap
+
+const val DEFAULT_TTL_S = 5L
+
+/**
+ * Helper class to store the mapping of DNS queries.
+ *
+ * DnsAnswerProvider is built atop a ConcurrentHashMap and as such it provides the same
+ * guarantees as ConcurrentHashMap between writing and reading elements. Specifically :
+ * - Setting an answer happens-before reading the same answer.
+ * - Callers can read and write concurrently from DnsAnswerProvider and expect no
+ * ConcurrentModificationException.
+ * Freshness of the answers depends on ordering of the threads ; if callers need a
+ * freshness guarantee, they need to provide the happens-before relationship from a
+ * write that they want to observe to the read that they need to be observed.
+ */
+class DnsAnswerProvider {
+ private val mDnsKeyToRecords = ConcurrentHashMap<String, List<DnsPacket.DnsRecord>>()
+
+ /**
+ * Get answer for the specified hostname.
+ *
+ * @param query the target hostname.
+ * @param type type of record, could be A or AAAA.
+ *
+ * @return list of [DnsPacket.DnsRecord] associated to the query. Empty if no record matches.
+ */
+ fun getAnswer(query: String, type: Int) = mDnsKeyToRecords[query]
+ .orEmpty().filter { it.nsType == type }
+
+ /** Set answer for the specified {@code query}.
+ *
+ * @param query the target hostname
+ * @param addresses [List<InetAddress>] which could be used to generate multiple A or AAAA
+ * RRs with the corresponding addresses.
+ */
+ fun setAnswer(query: String, hosts: List<InetAddress>) = mDnsKeyToRecords.put(query, hosts.map {
+ DnsPacket.DnsRecord.makeAOrAAAARecord(ANSECTION, query, CLASS_IN, DEFAULT_TTL_S, it)
+ })
+
+ fun clearAnswer(query: String) = mDnsKeyToRecords.remove(query)
+ fun clearAll() = mDnsKeyToRecords.clear()
+}
\ No newline at end of file
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/TestDnsServer.kt b/staticlibs/testutils/devicetests/com/android/testutils/TestDnsServer.kt
new file mode 100644
index 0000000..c63b38f
--- /dev/null
+++ b/staticlibs/testutils/devicetests/com/android/testutils/TestDnsServer.kt
@@ -0,0 +1,169 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.Network
+import android.util.Log
+import com.android.internal.annotations.GuardedBy
+import com.android.internal.annotations.VisibleForTesting
+import com.android.internal.annotations.VisibleForTesting.Visibility.PRIVATE
+import com.android.net.module.util.DnsPacket
+import java.net.DatagramPacket
+import java.net.DatagramSocket
+import java.net.InetAddress
+import java.net.InetSocketAddress
+import java.net.SocketAddress
+import java.net.SocketException
+import java.util.ArrayList
+
+private const val TAG = "TestDnsServer"
+private const val VDBG = true
+@VisibleForTesting(visibility = PRIVATE)
+const val MAX_BUF_SIZE = 8192
+
+/**
+ * A simple implementation of Dns Server that can be bound on specific address and Network.
+ *
+ * The caller should use start() to make the server start a new thread to receive DNS queries
+ * on the bound address, [isAlive] to check status, and stop() for stopping.
+ * The server allows user to manipulate the records to be answered through
+ * [setAnswer] at runtime.
+ *
+ * This server runs on its own thread. Please make sure writing the query to the socket
+ * happens-after using [setAnswer] to guarantee the correct answer is returned. If possible,
+ * use [setAnswer] before calling [start] for simplicity.
+ */
+class TestDnsServer(network: Network, addr: InetSocketAddress) {
+ enum class Status {
+ NOT_STARTED, STARTED, STOPPED
+ }
+ @GuardedBy("thread")
+ private var status: Status = Status.NOT_STARTED
+ private val thread = ReceivingThread()
+ private val socket = DatagramSocket(addr).also { network.bindSocket(it) }
+ private val ansProvider = DnsAnswerProvider()
+
+ // The buffer to store the received packet. They are being reused for
+ // efficiency and it's fine because they are only ever accessed
+ // on the server thread in a sequential manner.
+ private val buffer = ByteArray(MAX_BUF_SIZE)
+ private val packet = DatagramPacket(buffer, buffer.size)
+
+ fun setAnswer(hostname: String, answer: List<InetAddress>) =
+ ansProvider.setAnswer(hostname, answer)
+
+ private fun processPacket() {
+ // Blocking read and try construct a DnsQueryPacket object.
+ socket.receive(packet)
+ val q = DnsQueryPacket(packet.data)
+ handleDnsQuery(q, packet.socketAddress)
+ }
+
+ // TODO: Add support to reply some error with a DNS reply packet with failure RCODE.
+ private fun handleDnsQuery(q: DnsQueryPacket, src: SocketAddress) {
+ val queryRecords = q.queryRecords
+ if (queryRecords.size != 1) {
+ throw IllegalArgumentException(
+ "Expected one dns query record but got ${queryRecords.size}"
+ )
+ }
+ val answerRecords = queryRecords[0].let { ansProvider.getAnswer(it.dName, it.nsType) }
+
+ if (VDBG) {
+ Log.v(TAG, "handleDnsPacket: " +
+ queryRecords.map { "${it.dName},${it.nsType}" }.joinToString() +
+ " ansCount=${answerRecords.size} socketAddress=$src")
+ }
+
+ val bytes = q.getAnswerPacket(answerRecords).bytes
+ val reply = DatagramPacket(bytes, bytes.size, src)
+ socket.send(reply)
+ }
+
+ fun start() {
+ synchronized(thread) {
+ if (status != Status.NOT_STARTED) {
+ throw IllegalStateException("unexpected status: $status")
+ }
+ thread.start()
+ status = Status.STARTED
+ }
+ }
+ fun stop() {
+ synchronized(thread) {
+ if (status != Status.STARTED) {
+ throw IllegalStateException("unexpected status: $status")
+ }
+ socket.close()
+ thread.interrupt()
+ thread.join()
+ status = Status.STOPPED
+ }
+ }
+ val isAlive get() = thread.isAlive
+ val port get() = socket.localPort
+
+ inner class ReceivingThread : Thread() {
+ override fun run() {
+ Log.i(TAG, "starting addr={${socket.localSocketAddress}}")
+ while (!interrupted() && !socket.isClosed) {
+ try {
+ processPacket()
+ } catch (e: InterruptedException) {
+ // The caller terminated the server, exit.
+ break
+ } catch (e: SocketException) {
+ // The caller terminated the server, exit.
+ break
+ }
+ }
+ Log.i(TAG, "exiting socket={$socket}")
+ }
+ }
+
+ @VisibleForTesting(visibility = PRIVATE)
+ class DnsQueryPacket : DnsPacket {
+ constructor(data: ByteArray) : super(data)
+ constructor(header: DnsHeader, qd: List<DnsRecord>, an: List<DnsRecord>) :
+ super(header, qd, an)
+
+ init {
+ if (mHeader.isResponse) {
+ throw ParseException("Not a query packet")
+ }
+ }
+
+ val queryRecords: List<DnsRecord>
+ get() = mRecords[QDSECTION]
+
+ fun getAnswerPacket(ar: List<DnsRecord>): DnsAnswerPacket {
+ // Set QR bit of flag to 1 for response packet according to RFC 1035 section 4.1.1.
+ val flags = 1 shl 15
+ val qr = ArrayList(mRecords[QDSECTION])
+ // Copy the query packet header id to the answer packet as RFC 1035 section 4.1.1.
+ val header = DnsHeader(mHeader.id, flags, qr.size, ar.size)
+ return DnsAnswerPacket(header, qr, ar)
+ }
+ }
+
+ class DnsAnswerPacket : DnsPacket {
+ constructor(header: DnsHeader, qr: List<DnsRecord>, ar: List<DnsRecord>) :
+ super(header, qr, ar)
+ @VisibleForTesting(visibility = PRIVATE)
+ constructor(bytes: ByteArray) : super(bytes)
+ }
+}