[Thread] create a base class TestUdpServer

This class is useful for creating a class that responds to specific type of UDP messages. E.g. DNS server and UDP echo server.

Bug: 356770473
Test: atest ThreadNetworkIntegrationTests:android.net.thread.InternetAccessTest

Change-Id: I6880bb825ac5077d4a73f9d986116c571863e1d8
diff --git a/thread/tests/integration/src/android/net/thread/utils/TestDnsServer.kt b/thread/tests/integration/src/android/net/thread/utils/TestDnsServer.kt
index c52fc49..f97c0f2 100644
--- a/thread/tests/integration/src/android/net/thread/utils/TestDnsServer.kt
+++ b/thread/tests/integration/src/android/net/thread/utils/TestDnsServer.kt
@@ -16,18 +16,16 @@
 
 package android.net.thread.utils
 
-import android.net.thread.utils.IntegrationTestUtils.pollForPacket
 import android.system.OsConstants.IPPROTO_IP
 import android.system.OsConstants.IPPROTO_UDP
 import com.android.net.module.util.DnsPacket
 import com.android.net.module.util.PacketBuilder
-import com.android.net.module.util.Struct
 import com.android.net.module.util.structs.Ipv4Header
 import com.android.net.module.util.structs.UdpHeader
 import com.android.testutils.PollPacketReader
 import java.net.InetAddress
+import java.net.InetSocketAddress
 import java.nio.ByteBuffer
-import kotlin.concurrent.thread
 
 /**
  * A class that simulates a DNS server.
@@ -41,11 +39,12 @@
 class TestDnsServer(
     private val packetReader: PollPacketReader,
     private val serverAddress: InetAddress,
-    private val answerRecords: List<DnsPacket.DnsRecord>,
-) {
-    private val TAG = TestDnsServer::class.java.simpleName
-    private val DNS_UDP_PORT = 53
-    private var workerThread: Thread? = null
+    private val serverAnswers: List<DnsPacket.DnsRecord>,
+) : TestUdpServer(packetReader, InetSocketAddress(serverAddress, DNS_UDP_PORT)) {
+    companion object {
+        private val TAG = TestDnsServer::class.java.simpleName
+        private const val DNS_UDP_PORT = 53
+    }
 
     private class TestDnsPacket : DnsPacket {
 
@@ -61,49 +60,12 @@
         val records = super.mRecords
     }
 
-    /**
-     * Starts the DNS server to respond to DNS requests.
-     *
-     * <p> The server polls the DNS requests from the {@code packetReader} and responds with the
-     * {@code answerRecords}. The server will automatically stop when it fails to poll a DNS request
-     * within the timeout (3000 ms, as defined in IntegrationTestUtils).
-     */
-    fun start() {
-        workerThread = thread {
-            var requestPacket: ByteArray
-            while (true) {
-                requestPacket = pollForDnsPacket() ?: break
-                val buf = ByteBuffer.wrap(requestPacket)
-                packetReader.sendResponse(buildDnsResponse(buf, answerRecords))
-            }
-        }
-    }
-
-    /** Stops the DNS server. */
-    fun stop() {
-        workerThread?.join()
-    }
-
-    private fun pollForDnsPacket(): ByteArray? {
-        val filter =
-            fun(packet: ByteArray): Boolean {
-                val buf = ByteBuffer.wrap(packet)
-                val ipv4Header = Struct.parse(Ipv4Header::class.java, buf) ?: return false
-                val udpHeader = Struct.parse(UdpHeader::class.java, buf) ?: return false
-                return ipv4Header.dstIp == serverAddress && udpHeader.dstPort == DNS_UDP_PORT
-            }
-        return pollForPacket(packetReader, filter)
-    }
-
-    private fun buildDnsResponse(
-        requestPacket: ByteBuffer,
-        serverAnswers: List<DnsPacket.DnsRecord>,
+    override fun buildResponse(
+        requestIpv4Header: Ipv4Header,
+        requestUdpHeader: UdpHeader,
+        requestUdpPayload: ByteArray,
     ): ByteBuffer? {
-        val requestIpv4Header = Struct.parse(Ipv4Header::class.java, requestPacket) ?: return null
-        val requestUdpHeader = Struct.parse(UdpHeader::class.java, requestPacket) ?: return null
-        val remainingRequestPacket = ByteArray(requestPacket.remaining())
-        requestPacket.get(remainingRequestPacket)
-        val requestDnsPacket = TestDnsPacket(remainingRequestPacket)
+        val requestDnsPacket = TestDnsPacket(requestUdpPayload)
         val requestDnsHeader = requestDnsPacket.header
 
         val answerRecords =
diff --git a/thread/tests/integration/src/android/net/thread/utils/TestUdpServer.kt b/thread/tests/integration/src/android/net/thread/utils/TestUdpServer.kt
new file mode 100644
index 0000000..fb0942e
--- /dev/null
+++ b/thread/tests/integration/src/android/net/thread/utils/TestUdpServer.kt
@@ -0,0 +1,98 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.thread.utils
+
+import android.net.thread.utils.IntegrationTestUtils.pollForPacket
+import com.android.net.module.util.Struct
+import com.android.net.module.util.structs.Ipv4Header
+import com.android.net.module.util.structs.UdpHeader
+import com.android.testutils.PollPacketReader
+import java.net.InetSocketAddress
+import java.nio.ByteBuffer
+import kotlin.concurrent.thread
+
+/**
+ * A class that simulates a UDP server that replies to incoming UDP messages.
+ *
+ * @param packetReader the packet reader to poll UDP requests from
+ * @param serverAddress the address and port of the UDP server
+ */
+abstract class TestUdpServer(
+    private val packetReader: PollPacketReader,
+    private val serverAddress: InetSocketAddress,
+) {
+    private val TAG = TestUdpServer::class.java.simpleName
+    private var workerThread: Thread? = null
+
+    /**
+     * Starts the UDP server to respond to UDP messages.
+     *
+     * <p> The server polls the UDP messages from the {@code packetReader} and responds with a
+     * message built by {@code buildResponse}. The server will automatically stop when it fails to
+     * poll a UDP request within the timeout (3000 ms, as defined in IntegrationTestUtils).
+     */
+    fun start() {
+        workerThread = thread {
+            var requestPacket: ByteArray
+            while (true) {
+                requestPacket = pollForUdpPacket() ?: break
+                val buf = ByteBuffer.wrap(requestPacket)
+                packetReader.sendResponse(buildResponse(buf) ?: break)
+            }
+        }
+    }
+
+    /** Stops the UDP server. */
+    fun stop() {
+        workerThread?.join()
+    }
+
+    /**
+     * Builds the UDP response for the given UDP request.
+     *
+     * @param ipv4Header the IPv4 header of the UDP request
+     * @param udpHeader the UDP header of the UDP request
+     * @param udpPayload the payload of the UDP request
+     * @return the UDP response
+     */
+    abstract fun buildResponse(
+        requestIpv4Header: Ipv4Header,
+        requestUdpHeader: UdpHeader,
+        requestUdpPayload: ByteArray,
+    ): ByteBuffer?
+
+    private fun pollForUdpPacket(): ByteArray? {
+        val filter =
+            fun(packet: ByteArray): Boolean {
+                val buf = ByteBuffer.wrap(packet)
+                val ipv4Header = Struct.parse(Ipv4Header::class.java, buf) ?: return false
+                val udpHeader = Struct.parse(UdpHeader::class.java, buf) ?: return false
+                return ipv4Header.dstIp == serverAddress.address &&
+                    udpHeader.dstPort == serverAddress.port
+            }
+        return pollForPacket(packetReader, filter)
+    }
+
+    private fun buildResponse(requestPacket: ByteBuffer): ByteBuffer? {
+        val requestIpv4Header = Struct.parse(Ipv4Header::class.java, requestPacket) ?: return null
+        val requestUdpHeader = Struct.parse(UdpHeader::class.java, requestPacket) ?: return null
+        val remainingRequestPacket = ByteArray(requestPacket.remaining())
+        requestPacket.get(remainingRequestPacket)
+
+        return buildResponse(requestIpv4Header, requestUdpHeader, remainingRequestPacket)
+    }
+}