[ST03] Add test dns server for integration tests

Bug: 139774492
Test: atest DnsAnswerProviderTest TestDnsServerTest
Change-Id: Ia5039a47f4b818efb09aab8174beb9d921339c3d
diff --git a/staticlibs/tests/unit/src/com/android/testutils/TestDnsServerTest.kt b/staticlibs/tests/unit/src/com/android/testutils/TestDnsServerTest.kt
new file mode 100644
index 0000000..6f4587b
--- /dev/null
+++ b/staticlibs/tests/unit/src/com/android/testutils/TestDnsServerTest.kt
@@ -0,0 +1,122 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.DnsResolver.CLASS_IN
+import android.net.DnsResolver.TYPE_AAAA
+import android.net.Network
+import androidx.test.ext.junit.runners.AndroidJUnit4
+import androidx.test.filters.SmallTest
+import com.android.net.module.util.DnsPacket
+import com.android.net.module.util.DnsPacket.DnsRecord
+import libcore.net.InetAddressUtils
+import org.junit.After
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.Mockito
+import java.net.DatagramPacket
+import java.net.DatagramSocket
+import java.net.InetAddress
+import java.net.InetSocketAddress
+import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
+import kotlin.test.assertFalse
+import kotlin.test.assertTrue
+
+val TEST_V6_ADDR = InetAddressUtils.parseNumericAddress("2001:db8::3")
+const val TEST_DOMAIN = "hello.example.com"
+
+@RunWith(AndroidJUnit4::class)
+@SmallTest
+class TestDnsServerTest {
+    private val network = Mockito.mock(Network::class.java)
+    private val localAddr = InetSocketAddress(InetAddress.getLocalHost(), 0 /* port */)
+    private val testServer: TestDnsServer = TestDnsServer(network, localAddr)
+
+    @After
+    fun tearDown() {
+        if (testServer.isAlive) testServer.stop()
+    }
+
+    @Test
+    fun testStartStop() {
+        repeat(100) {
+            val server = TestDnsServer(network, localAddr)
+            server.start()
+            assertTrue(server.isAlive)
+            server.stop()
+            assertFalse(server.isAlive)
+        }
+
+        // Test illegal start/stop.
+        assertFailsWith<IllegalStateException> { testServer.stop() }
+        testServer.start()
+        assertTrue(testServer.isAlive)
+        assertFailsWith<IllegalStateException> { testServer.start() }
+        testServer.stop()
+        assertFalse(testServer.isAlive)
+        assertFailsWith<IllegalStateException> { testServer.stop() }
+        // TestDnsServer rejects start after stop.
+        assertFailsWith<IllegalStateException> { testServer.start() }
+    }
+
+    @Test
+    fun testHandleDnsQuery() {
+        testServer.setAnswer(TEST_DOMAIN, listOf(TEST_V6_ADDR))
+        testServer.start()
+
+        // Mock query and send it to the test server.
+        val queryHeader = DnsPacket.DnsHeader(0xbeef /* id */,
+                0x0 /* flag */, 1 /* qcount */, 0 /* ancount */)
+        val qlist = listOf(DnsRecord.makeQuestion(TEST_DOMAIN, TYPE_AAAA, CLASS_IN))
+        val queryPacket = TestDnsServer.DnsQueryPacket(queryHeader, qlist, emptyList())
+        val response = resolve(queryPacket, testServer.port)
+
+        // Verify expected answer packet. Set QR bit of flag to 1 for response packet
+        // according to RFC 1035 section 4.1.1.
+        val answerHeader = DnsPacket.DnsHeader(0xbeef,
+            1 shl 15 /* flag */, 1 /* qcount */, 1 /* ancount */)
+        val alist = listOf(DnsRecord.makeAOrAAAARecord(DnsPacket.ANSECTION, TEST_DOMAIN,
+                    CLASS_IN, DEFAULT_TTL_S, TEST_V6_ADDR))
+        val expectedAnswerPacket = TestDnsServer.DnsAnswerPacket(answerHeader, qlist, alist)
+        assertEquals(expectedAnswerPacket, response)
+
+        // Clean up the server in tearDown.
+    }
+
+    private fun resolve(queryDnsPacket: DnsPacket, serverPort: Int): TestDnsServer.DnsAnswerPacket {
+        val bytes = queryDnsPacket.bytes
+        // Create a new client socket, the socket will be bound to a
+        // random port other than the server port.
+        val socket = DatagramSocket(localAddr).also { it.soTimeout = 100 }
+        val queryPacket = DatagramPacket(bytes, bytes.size, localAddr.address, serverPort)
+
+        // Send query and wait for the reply.
+        socket.send(queryPacket)
+        val buffer = ByteArray(MAX_BUF_SIZE)
+        val reply = DatagramPacket(buffer, buffer.size)
+        socket.receive(reply)
+        return TestDnsServer.DnsAnswerPacket(reply.data)
+    }
+
+    // TODO: Add more tests, which includes:
+    //  * Empty question RR packet (or more unexpected states)
+    //  * No answer found (setAnswer empty list at L.78)
+    //  * Test one or multi A record(s)
+    //  * Test multi AAAA records
+    //  * Test CNAME records
+}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/DnsAnswerProvider.kt b/staticlibs/testutils/devicetests/com/android/testutils/DnsAnswerProvider.kt
new file mode 100644
index 0000000..6a804bf
--- /dev/null
+++ b/staticlibs/testutils/devicetests/com/android/testutils/DnsAnswerProvider.kt
@@ -0,0 +1,65 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.DnsResolver.CLASS_IN
+import com.android.net.module.util.DnsPacket
+import com.android.net.module.util.DnsPacket.ANSECTION
+import java.net.InetAddress
+import java.util.concurrent.ConcurrentHashMap
+
+const val DEFAULT_TTL_S = 5L
+
+/**
+ * Helper class to store the mapping of DNS queries.
+ *
+ * DnsAnswerProvider is built atop a ConcurrentHashMap and as such it provides the same
+ * guarantees as ConcurrentHashMap between writing and reading elements. Specifically :
+ * - Setting an answer happens-before reading the same answer.
+ * - Callers can read and write concurrently from DnsAnswerProvider and expect no
+ *   ConcurrentModificationException.
+ * Freshness of the answers depends on ordering of the threads ; if callers need a
+ * freshness guarantee, they need to provide the happens-before relationship from a
+ * write that they want to observe to the read that they need to be observed.
+ */
+class DnsAnswerProvider {
+    private val mDnsKeyToRecords = ConcurrentHashMap<String, List<DnsPacket.DnsRecord>>()
+
+    /**
+     * Get answer for the specified hostname.
+     *
+     * @param query the target hostname.
+     * @param type type of record, could be A or AAAA.
+     *
+     * @return list of [DnsPacket.DnsRecord] associated to the query. Empty if no record matches.
+     */
+    fun getAnswer(query: String, type: Int) = mDnsKeyToRecords[query]
+            .orEmpty().filter { it.nsType == type }
+
+    /** Set answer for the specified {@code query}.
+     *
+     * @param query the target hostname
+     * @param addresses [List<InetAddress>] which could be used to generate multiple A or AAAA
+     *                  RRs with the corresponding addresses.
+     */
+    fun setAnswer(query: String, hosts: List<InetAddress>) = mDnsKeyToRecords.put(query, hosts.map {
+            DnsPacket.DnsRecord.makeAOrAAAARecord(ANSECTION, query, CLASS_IN, DEFAULT_TTL_S, it)
+        })
+
+    fun clearAnswer(query: String) = mDnsKeyToRecords.remove(query)
+    fun clearAll() = mDnsKeyToRecords.clear()
+}
\ No newline at end of file
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/TestDnsServer.kt b/staticlibs/testutils/devicetests/com/android/testutils/TestDnsServer.kt
new file mode 100644
index 0000000..c63b38f
--- /dev/null
+++ b/staticlibs/testutils/devicetests/com/android/testutils/TestDnsServer.kt
@@ -0,0 +1,169 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.Network
+import android.util.Log
+import com.android.internal.annotations.GuardedBy
+import com.android.internal.annotations.VisibleForTesting
+import com.android.internal.annotations.VisibleForTesting.Visibility.PRIVATE
+import com.android.net.module.util.DnsPacket
+import java.net.DatagramPacket
+import java.net.DatagramSocket
+import java.net.InetAddress
+import java.net.InetSocketAddress
+import java.net.SocketAddress
+import java.net.SocketException
+import java.util.ArrayList
+
+private const val TAG = "TestDnsServer"
+private const val VDBG = true
+@VisibleForTesting(visibility = PRIVATE)
+const val MAX_BUF_SIZE = 8192
+
+/**
+ * A simple implementation of Dns Server that can be bound on specific address and Network.
+ *
+ * The caller should use start() to make the server start a new thread to receive DNS queries
+ * on the bound address, [isAlive] to check status, and stop() for stopping.
+ * The server allows user to manipulate the records to be answered through
+ * [setAnswer] at runtime.
+ *
+ * This server runs on its own thread. Please make sure writing the query to the socket
+ * happens-after using [setAnswer] to guarantee the correct answer is returned. If possible,
+ * use [setAnswer] before calling [start] for simplicity.
+ */
+class TestDnsServer(network: Network, addr: InetSocketAddress) {
+    enum class Status {
+        NOT_STARTED, STARTED, STOPPED
+    }
+    @GuardedBy("thread")
+    private var status: Status = Status.NOT_STARTED
+    private val thread = ReceivingThread()
+    private val socket = DatagramSocket(addr).also { network.bindSocket(it) }
+    private val ansProvider = DnsAnswerProvider()
+
+    // The buffer to store the received packet. They are being reused for
+    // efficiency and it's fine because they are only ever accessed
+    // on the server thread in a sequential manner.
+    private val buffer = ByteArray(MAX_BUF_SIZE)
+    private val packet = DatagramPacket(buffer, buffer.size)
+
+    fun setAnswer(hostname: String, answer: List<InetAddress>) =
+        ansProvider.setAnswer(hostname, answer)
+
+    private fun processPacket() {
+        // Blocking read and try construct a DnsQueryPacket object.
+        socket.receive(packet)
+        val q = DnsQueryPacket(packet.data)
+        handleDnsQuery(q, packet.socketAddress)
+    }
+
+    // TODO: Add support to reply some error with a DNS reply packet with failure RCODE.
+    private fun handleDnsQuery(q: DnsQueryPacket, src: SocketAddress) {
+        val queryRecords = q.queryRecords
+        if (queryRecords.size != 1) {
+            throw IllegalArgumentException(
+                "Expected one dns query record but got ${queryRecords.size}"
+            )
+        }
+        val answerRecords = queryRecords[0].let { ansProvider.getAnswer(it.dName, it.nsType) }
+
+        if (VDBG) {
+            Log.v(TAG, "handleDnsPacket: " +
+                        queryRecords.map { "${it.dName},${it.nsType}" }.joinToString() +
+                        " ansCount=${answerRecords.size} socketAddress=$src")
+        }
+
+        val bytes = q.getAnswerPacket(answerRecords).bytes
+        val reply = DatagramPacket(bytes, bytes.size, src)
+        socket.send(reply)
+    }
+
+    fun start() {
+        synchronized(thread) {
+            if (status != Status.NOT_STARTED) {
+                throw IllegalStateException("unexpected status: $status")
+            }
+            thread.start()
+            status = Status.STARTED
+        }
+    }
+    fun stop() {
+        synchronized(thread) {
+            if (status != Status.STARTED) {
+                throw IllegalStateException("unexpected status: $status")
+            }
+            socket.close()
+            thread.interrupt()
+            thread.join()
+            status = Status.STOPPED
+        }
+    }
+    val isAlive get() = thread.isAlive
+    val port get() = socket.localPort
+
+    inner class ReceivingThread : Thread() {
+        override fun run() {
+            Log.i(TAG, "starting addr={${socket.localSocketAddress}}")
+            while (!interrupted() && !socket.isClosed) {
+                try {
+                    processPacket()
+                } catch (e: InterruptedException) {
+                    // The caller terminated the server, exit.
+                    break
+                } catch (e: SocketException) {
+                    // The caller terminated the server, exit.
+                    break
+                }
+            }
+            Log.i(TAG, "exiting socket={$socket}")
+        }
+    }
+
+    @VisibleForTesting(visibility = PRIVATE)
+    class DnsQueryPacket : DnsPacket {
+        constructor(data: ByteArray) : super(data)
+        constructor(header: DnsHeader, qd: List<DnsRecord>, an: List<DnsRecord>) :
+                super(header, qd, an)
+
+        init {
+            if (mHeader.isResponse) {
+                throw ParseException("Not a query packet")
+            }
+        }
+
+        val queryRecords: List<DnsRecord>
+            get() = mRecords[QDSECTION]
+
+        fun getAnswerPacket(ar: List<DnsRecord>): DnsAnswerPacket {
+            // Set QR bit of flag to 1 for response packet according to RFC 1035 section 4.1.1.
+            val flags = 1 shl 15
+            val qr = ArrayList(mRecords[QDSECTION])
+            // Copy the query packet header id to the answer packet as RFC 1035 section 4.1.1.
+            val header = DnsHeader(mHeader.id, flags, qr.size, ar.size)
+            return DnsAnswerPacket(header, qr, ar)
+        }
+    }
+
+    class DnsAnswerPacket : DnsPacket {
+        constructor(header: DnsHeader, qr: List<DnsRecord>, ar: List<DnsRecord>) :
+                super(header, qr, ar)
+        @VisibleForTesting(visibility = PRIVATE)
+        constructor(bytes: ByteArray) : super(bytes)
+    }
+}