Merge "Add useful utils to CollectionUtils"
diff --git a/staticlibs/tests/unit/src/com/android/testutils/TestDnsServerTest.kt b/staticlibs/tests/unit/src/com/android/testutils/TestDnsServerTest.kt
new file mode 100644
index 0000000..6f4587b
--- /dev/null
+++ b/staticlibs/tests/unit/src/com/android/testutils/TestDnsServerTest.kt
@@ -0,0 +1,122 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.DnsResolver.CLASS_IN
+import android.net.DnsResolver.TYPE_AAAA
+import android.net.Network
+import androidx.test.ext.junit.runners.AndroidJUnit4
+import androidx.test.filters.SmallTest
+import com.android.net.module.util.DnsPacket
+import com.android.net.module.util.DnsPacket.DnsRecord
+import libcore.net.InetAddressUtils
+import org.junit.After
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.Mockito
+import java.net.DatagramPacket
+import java.net.DatagramSocket
+import java.net.InetAddress
+import java.net.InetSocketAddress
+import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
+import kotlin.test.assertFalse
+import kotlin.test.assertTrue
+
+/** IPv6 address served by the test DNS server as the answer for [TEST_DOMAIN]. */
+val TEST_V6_ADDR: InetAddress = InetAddressUtils.parseNumericAddress("2001:db8::3")
+
+/** Hostname queried by the tests. */
+const val TEST_DOMAIN: String = "hello.example.com"
+
+/**
+ * Tests for [TestDnsServer], using a server bound to the local host on a system-chosen port.
+ */
+@RunWith(AndroidJUnit4::class)
+@SmallTest
+class TestDnsServerTest {
+    private val network = Mockito.mock(Network::class.java)
+    private val localAddr = InetSocketAddress(InetAddress.getLocalHost(), 0 /* port */)
+    private val testServer: TestDnsServer = TestDnsServer(network, localAddr)
+
+    @After
+    fun tearDown() {
+        if (testServer.isAlive) testServer.stop()
+    }
+
+    @Test
+    fun testStartStop() {
+        repeat(100) {
+            val server = TestDnsServer(network, localAddr)
+            server.start()
+            assertTrue(server.isAlive)
+            server.stop()
+            assertFalse(server.isAlive)
+        }
+
+        // Test illegal start/stop.
+        assertFailsWith<IllegalStateException> { testServer.stop() }
+        testServer.start()
+        assertTrue(testServer.isAlive)
+        assertFailsWith<IllegalStateException> { testServer.start() }
+        testServer.stop()
+        assertFalse(testServer.isAlive)
+        assertFailsWith<IllegalStateException> { testServer.stop() }
+        // TestDnsServer rejects start after stop.
+        assertFailsWith<IllegalStateException> { testServer.start() }
+    }
+
+    @Test
+    fun testHandleDnsQuery() {
+        testServer.setAnswer(TEST_DOMAIN, listOf(TEST_V6_ADDR))
+        testServer.start()
+
+        // Mock query and send it to the test server.
+        val queryHeader = DnsPacket.DnsHeader(0xbeef /* id */,
+                0x0 /* flag */, 1 /* qcount */, 0 /* ancount */)
+        val qlist = listOf(DnsRecord.makeQuestion(TEST_DOMAIN, TYPE_AAAA, CLASS_IN))
+        val queryPacket = TestDnsServer.DnsQueryPacket(queryHeader, qlist, emptyList())
+        val response = resolve(queryPacket, testServer.port)
+
+        // Verify expected answer packet. Set QR bit of flag to 1 for response packet
+        // according to RFC 1035 section 4.1.1.
+        val answerHeader = DnsPacket.DnsHeader(0xbeef,
+                1 shl 15 /* flag */, 1 /* qcount */, 1 /* ancount */)
+        val alist = listOf(DnsRecord.makeAOrAAAARecord(DnsPacket.ANSECTION, TEST_DOMAIN,
+                CLASS_IN, DEFAULT_TTL_S, TEST_V6_ADDR))
+        val expectedAnswerPacket = TestDnsServer.DnsAnswerPacket(answerHeader, qlist, alist)
+        assertEquals(expectedAnswerPacket, response)
+
+        // The server is stopped in tearDown().
+    }
+
+    /**
+     * Sends [queryDnsPacket] to the server listening on [serverPort] at [localAddr]'s address
+     * and returns the parsed reply.
+     *
+     * The client socket is always closed before returning. Throws
+     * [java.net.SocketTimeoutException] if no reply arrives within 100ms.
+     */
+    private fun resolve(queryDnsPacket: DnsPacket, serverPort: Int): TestDnsServer.DnsAnswerPacket {
+        val bytes = queryDnsPacket.bytes
+        // Create a new client socket; since localAddr uses port 0 it is bound to a
+        // system-chosen port other than the server port.
+        return DatagramSocket(localAddr).use { socket ->
+            socket.soTimeout = 100
+            val queryPacket = DatagramPacket(bytes, bytes.size, localAddr.address, serverPort)
+
+            // Send query and wait for the reply.
+            socket.send(queryPacket)
+            val buffer = ByteArray(MAX_BUF_SIZE)
+            val reply = DatagramPacket(buffer, buffer.size)
+            socket.receive(reply)
+            // Only parse the bytes that were actually received.
+            TestDnsServer.DnsAnswerPacket(reply.data.copyOf(reply.length))
+        }
+    }
+
+    // TODO: Add more tests, which includes:
+    // * Empty question RR packet (or more unexpected states)
+    // * No answer found (call setAnswer with an empty list)
+    // * Test one or multi A record(s)
+    // * Test multi AAAA records
+    // * Test CNAME records
+}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt b/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt
new file mode 100644
index 0000000..3d98cc3
--- /dev/null
+++ b/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt
@@ -0,0 +1,176 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.Manifest.permission.READ_DEVICE_CONFIG
+import android.Manifest.permission.WRITE_DEVICE_CONFIG
+import android.provider.DeviceConfig
+import android.util.Log
+import com.android.modules.utils.build.SdkLevel
+import com.android.testutils.FunctionalUtils.ThrowingRunnable
+import org.junit.rules.TestRule
+import org.junit.runner.Description
+import org.junit.runners.model.Statement
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.Executor
+import java.util.concurrent.TimeUnit
+
+/** Log tag: the simple name of [DeviceConfigRule]. */
+private val TAG = DeviceConfigRule::class.java.simpleName
+
+/** How long [DeviceConfigRule.setConfig] waits for a change to be observed, in milliseconds. */
+private const val TIMEOUT_MS: Long = 20 * 1000L
+
+/**
+ * A [TestRule] that helps set [DeviceConfig] for tests and clean up the test configuration
+ * automatically on teardown.
+ *
+ * The rule can also optionally retry tests when they fail following an external change of
+ * DeviceConfig before S; this typically happens because device config flags are synced while the
+ * test is running, and DisableConfigSyncTargetPreparer is only usable starting from S.
+ *
+ * @param retryCountBeforeSIfConfigChanged if > 0, when the test fails before S, check if
+ *        the configs that were set through this rule were changed, and retry the test
+ *        up to the specified number of times if yes.
+ */
+class DeviceConfigRule @JvmOverloads constructor(
+    val retryCountBeforeSIfConfigChanged: Int = 0
+) : TestRule {
+    // Maps (namespace, key) -> value the config had before this rule first changed it.
+    private val originalConfig = mutableMapOf<Pair<String, String>, String?>()
+    // Maps (namespace, key) -> value most recently set through this rule.
+    private val usedConfig = mutableMapOf<Pair<String, String>, String?>()
+
+    /**
+     * Actions to be run after cleanup of the config, for the current test only.
+     */
+    private val currentTestCleanupActions = mutableListOf<ThrowingRunnable>()
+
+    override fun apply(base: Statement, description: Description): Statement {
+        return ConfigCleanupStatement(base, description)
+    }
+
+    /**
+     * Runs the test; before S, retries it if it fails and a config set through this rule was
+     * changed externally. After each run, restores the original config values, clears the
+     * tracked configs, then runs the actions registered with [runAfterNextCleanup].
+     */
+    private inner class ConfigCleanupStatement(
+        private val base: Statement,
+        private val description: Description
+    ) : Statement() {
+        override fun evaluate() {
+            var retryCount = if (SdkLevel.isAtLeastS()) 1 else retryCountBeforeSIfConfigChanged + 1
+            while (retryCount > 0) {
+                retryCount--
+                tryTest {
+                    base.evaluate()
+                    // Can't use break/return out of a loop here because this is a tryTest lambda,
+                    // so set retryCount to exit instead
+                    retryCount = 0
+                }.catch<Throwable> { e -> // junit AssertionFailedError does not extend Exception
+                    if (retryCount == 0) throw e
+                    usedConfig.forEach { (key, value) ->
+                        val currentValue = runAsShell(READ_DEVICE_CONFIG) {
+                            DeviceConfig.getProperty(key.first, key.second)
+                        }
+                        if (currentValue != value) {
+                            Log.w(TAG, "Test failed with unexpected device config change, retrying")
+                            return@catch
+                        }
+                    }
+                    throw e
+                } cleanupStep {
+                    runAsShell(WRITE_DEVICE_CONFIG) {
+                        originalConfig.forEach { (key, value) ->
+                            DeviceConfig.setProperty(
+                                    key.first, key.second, value, false /* makeDefault */)
+                        }
+                    }
+                } cleanupStep {
+                    originalConfig.clear()
+                    usedConfig.clear()
+                } cleanup {
+                    // Fold all cleanup actions into cleanup steps of an empty tryTest, so they are
+                    // all run even if exceptions are thrown, and exceptions are reported properly.
+                    currentTestCleanupActions.fold(tryTest { }) {
+                        tryBlock, action -> tryBlock.cleanupStep { action.run() }
+                    }.cleanup {
+                        currentTestCleanupActions.clear()
+                    }
+                }
+            }
+        }
+    }
+
+    /**
+     * Set a configuration key/value. After the test case ends, it will be restored to the value it
+     * had when this method was first called.
+     *
+     * The value is read and written in [namespace], and this method blocks until the change is
+     * observed through a [DeviceConfig.OnPropertiesChangedListener] on that namespace.
+     *
+     * @param namespace the DeviceConfig namespace containing [key].
+     * @param key the config key to set.
+     * @param value the value to set; null means the key is absent.
+     * @return [value], once it is observed to be in effect.
+     * @throws java.util.concurrent.TimeoutException if the change is not observed within
+     *         [TIMEOUT_MS] milliseconds.
+     */
+    fun setConfig(namespace: String, key: String, value: String?): String? {
+        Log.i(TAG, "Setting config \"$key\" to \"$value\"")
+        val readWritePermissions = arrayOf(READ_DEVICE_CONFIG, WRITE_DEVICE_CONFIG)
+
+        val keyPair = Pair(namespace, key)
+        val existingValue = runAsShell(*readWritePermissions) {
+            DeviceConfig.getProperty(namespace, key)
+        }
+        if (!originalConfig.containsKey(keyPair)) {
+            originalConfig[keyPair] = existingValue
+        }
+        usedConfig[keyPair] = value
+        if (existingValue == value) {
+            // Already the correct value. There may be a race if a change is already in flight,
+            // but if multiple threads update the config there is no way to fix that anyway.
+            Log.i(TAG, "\"$key\" already had value \"$value\"")
+            return value
+        }
+
+        // Nullable because null is a legitimate target value (key absent).
+        val future = CompletableFuture<String?>()
+        val listener = DeviceConfig.OnPropertiesChangedListener {
+            // The listener receives updates for any change to any key, so don't react to
+            // changes that do not affect the relevant key
+            if (!it.keyset.contains(key)) return@OnPropertiesChangedListener
+            // "null" means absent in DeviceConfig : there is no such thing as a present but
+            // null value, so the following works even if |value| is null.
+            if (it.getString(key, null) == value) {
+                future.complete(value)
+            }
+        }
+
+        return tryTest {
+            runAsShell(*readWritePermissions) {
+                // Listen and write in the same namespace the original value was read from, so
+                // that cleanup (which restores in keyPair.first) undoes exactly this change.
+                DeviceConfig.addOnPropertiesChangedListener(
+                        namespace,
+                        inlineExecutor,
+                        listener)
+                DeviceConfig.setProperty(
+                        namespace,
+                        key,
+                        value,
+                        false /* makeDefault */)
+                // Don't drop the permission until the config is applied, just in case
+                future.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)
+            }.also {
+                Log.i(TAG, "Config \"$key\" successfully set to \"$value\"")
+            }
+        } cleanup {
+            DeviceConfig.removeOnPropertiesChangedListener(listener)
+        }
+    }
+
+    /** Runs listener callbacks directly on the thread that delivers them. */
+    private val inlineExecutor get() = Executor { r -> r.run() }
+
+    /**
+     * Add an action to be run after config cleanup when the current test case ends.
+     *
+     * Actions run in registration order, all of them even if some throw, and are discarded
+     * once the current test case ends.
+     */
+    fun runAfterNextCleanup(action: ThrowingRunnable) {
+        currentTestCleanupActions.add(action)
+    }
+}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/DnsAnswerProvider.kt b/staticlibs/testutils/devicetests/com/android/testutils/DnsAnswerProvider.kt
new file mode 100644
index 0000000..6a804bf
--- /dev/null
+++ b/staticlibs/testutils/devicetests/com/android/testutils/DnsAnswerProvider.kt
@@ -0,0 +1,65 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.DnsResolver.CLASS_IN
+import com.android.net.module.util.DnsPacket
+import com.android.net.module.util.DnsPacket.ANSECTION
+import java.net.InetAddress
+import java.util.concurrent.ConcurrentHashMap
+
+/** TTL, in seconds, of the records built by [DnsAnswerProvider.setAnswer]. */
+const val DEFAULT_TTL_S: Long = 5
+
+/**
+ * Helper class to store the mapping of DNS queries.
+ *
+ * DnsAnswerProvider is built atop a ConcurrentHashMap and as such it provides the same
+ * guarantees as ConcurrentHashMap between writing and reading elements. Specifically:
+ * - Setting an answer happens-before reading the same answer.
+ * - Callers can read and write concurrently from DnsAnswerProvider and expect no
+ *   ConcurrentModificationException.
+ *
+ * Freshness of the answers depends on ordering of the threads; if callers need a
+ * freshness guarantee, they need to provide the happens-before relationship from a
+ * write that they want to observe to the read that they need to be observed.
+ *
+ * Hostnames are used as exact map keys, so lookups are case-sensitive.
+ */
+class DnsAnswerProvider {
+    private val dnsKeyToRecords = ConcurrentHashMap<String, List<DnsPacket.DnsRecord>>()
+
+    /**
+     * Get answer for the specified hostname.
+     *
+     * @param query the target hostname.
+     * @param type type of record, could be A or AAAA.
+     *
+     * @return list of [DnsPacket.DnsRecord] associated to the query. Empty if no record matches.
+     */
+    fun getAnswer(query: String, type: Int): List<DnsPacket.DnsRecord> =
+        dnsKeyToRecords[query].orEmpty().filter { it.nsType == type }
+
+    /**
+     * Set answer for the specified [query], replacing any answer previously set for it.
+     *
+     * @param query the target hostname.
+     * @param hosts addresses used to generate one A or AAAA answer record each, with a TTL of
+     *        [DEFAULT_TTL_S] seconds.
+     * @return the records previously set for [query], or null if there were none.
+     */
+    fun setAnswer(query: String, hosts: List<InetAddress>) = dnsKeyToRecords.put(query, hosts.map {
+        DnsPacket.DnsRecord.makeAOrAAAARecord(ANSECTION, query, CLASS_IN, DEFAULT_TTL_S, it)
+    })
+
+    /**
+     * Remove the answer for [query], so it matches no record.
+     *
+     * @return the records previously set for [query], or null if there were none.
+     */
+    fun clearAnswer(query: String) = dnsKeyToRecords.remove(query)
+
+    /** Remove all answers. */
+    fun clearAll() = dnsKeyToRecords.clear()
+}
\ No newline at end of file
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/TestDnsServer.kt b/staticlibs/testutils/devicetests/com/android/testutils/TestDnsServer.kt
new file mode 100644
index 0000000..c63b38f
--- /dev/null
+++ b/staticlibs/testutils/devicetests/com/android/testutils/TestDnsServer.kt
@@ -0,0 +1,169 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.Network
+import android.util.Log
+import com.android.internal.annotations.GuardedBy
+import com.android.internal.annotations.VisibleForTesting
+import com.android.internal.annotations.VisibleForTesting.Visibility.PRIVATE
+import com.android.net.module.util.DnsPacket
+import java.net.DatagramPacket
+import java.net.DatagramSocket
+import java.net.InetAddress
+import java.net.InetSocketAddress
+import java.net.SocketAddress
+import java.net.SocketException
+import java.util.ArrayList
+
+/** Log tag of [TestDnsServer]. */
+private const val TAG: String = "TestDnsServer"
+/** When true, every handled query is logged verbosely. */
+private const val VDBG: Boolean = true
+/** Size in bytes of the buffers used to receive DNS packets. */
+@VisibleForTesting(visibility = PRIVATE)
+const val MAX_BUF_SIZE: Int = 8 * 1024
+
+/**
+ * A simple implementation of Dns Server that can be bound on specific address and Network.
+ *
+ * The caller should use start() to make the server start a new thread to receive DNS queries
+ * on the bound address, [isAlive] to check status, and stop() for stopping.
+ * The server allows user to manipulate the records to be answered through
+ * [setAnswer] at runtime.
+ *
+ * This server runs on its own thread. Please make sure writing the query to the socket
+ * happens-after using [setAnswer] to guarantee the correct answer is returned. If possible,
+ * use [setAnswer] before calling [start] for simplicity.
+ *
+ * Packets that cannot be parsed as a query, or that do not contain exactly one question,
+ * are logged and dropped; they do not stop the server.
+ */
+class TestDnsServer(network: Network, addr: InetSocketAddress) {
+    enum class Status {
+        NOT_STARTED, STARTED, STOPPED
+    }
+
+    // Guards status transitions. A dedicated lock is used instead of the Thread object,
+    // whose own monitor is used internally by Thread.join().
+    private val lock = Any()
+    @GuardedBy("lock")
+    private var status: Status = Status.NOT_STARTED
+    private val thread = ReceivingThread()
+    private val socket = DatagramSocket(addr).also { network.bindSocket(it) }
+    private val ansProvider = DnsAnswerProvider()
+
+    // The buffer to store the received packet. They are being reused for
+    // efficiency and it's fine because they are only ever accessed
+    // on the server thread in a sequential manner.
+    private val buffer = ByteArray(MAX_BUF_SIZE)
+    private val packet = DatagramPacket(buffer, buffer.size)
+
+    /** Sets the addresses returned for [hostname]; see [DnsAnswerProvider.setAnswer]. */
+    fun setAnswer(hostname: String, answer: List<InetAddress>) =
+            ansProvider.setAnswer(hostname, answer)
+
+    /** Blocks until one datagram is received, then answers it. Runs on the server thread. */
+    private fun processPacket() {
+        // receive() sets the packet length to the size of the datagram it read; restore the
+        // full buffer size since the same packet object is reused for every query.
+        packet.length = buffer.size
+        // Blocking read, then try to construct a DnsQueryPacket object.
+        socket.receive(packet)
+        // Parse only the bytes actually received: past packet.length, the reused buffer may
+        // still hold stale data from an earlier, longer packet.
+        val data = packet.data.copyOfRange(packet.offset, packet.offset + packet.length)
+        handleDnsQuery(DnsQueryPacket(data), packet.socketAddress)
+    }
+
+    /**
+     * Replies to [src] with the answers for the single question in [q].
+     *
+     * @throws IllegalArgumentException if [q] does not contain exactly one question.
+     */
+    // TODO: Add support to reply some error with a DNS reply packet with failure RCODE.
+    private fun handleDnsQuery(q: DnsQueryPacket, src: SocketAddress) {
+        val queryRecords = q.queryRecords
+        require(queryRecords.size == 1) {
+            "Expected one dns query record but got ${queryRecords.size}"
+        }
+        val question = queryRecords[0]
+        val answerRecords = ansProvider.getAnswer(question.dName, question.nsType)
+
+        if (VDBG) {
+            Log.v(TAG, "handleDnsPacket: " +
+                    queryRecords.joinToString { "${it.dName},${it.nsType}" } +
+                    " ansCount=${answerRecords.size} socketAddress=$src")
+        }
+
+        val bytes = q.getAnswerPacket(answerRecords).bytes
+        val reply = DatagramPacket(bytes, bytes.size, src)
+        socket.send(reply)
+    }
+
+    /**
+     * Starts the server thread.
+     *
+     * @throws IllegalStateException if the server was already started, even if since stopped.
+     */
+    fun start() {
+        synchronized(lock) {
+            check(status == Status.NOT_STARTED) { "unexpected status: $status" }
+            thread.start()
+            status = Status.STARTED
+        }
+    }
+
+    /**
+     * Closes the socket and waits for the server thread to exit.
+     *
+     * @throws IllegalStateException if the server is not currently started.
+     */
+    fun stop() {
+        synchronized(lock) {
+            check(status == Status.STARTED) { "unexpected status: $status" }
+            // Closing the socket makes a blocked receive() throw SocketException, which ends
+            // the receiving loop.
+            socket.close()
+            thread.interrupt()
+            thread.join()
+            status = Status.STOPPED
+        }
+    }
+
+    /** Whether the server thread is running. */
+    val isAlive get() = thread.isAlive
+
+    /** Local port the server socket is bound to. */
+    val port get() = socket.localPort
+
+    /** Server loop: answers queries until the socket is closed or the thread is interrupted. */
+    inner class ReceivingThread : Thread() {
+        override fun run() {
+            Log.i(TAG, "starting addr={${socket.localSocketAddress}}")
+            while (!interrupted() && !socket.isClosed) {
+                try {
+                    processPacket()
+                } catch (e: InterruptedException) {
+                    // The caller terminated the server, exit.
+                    break
+                } catch (e: SocketException) {
+                    // The caller terminated the server, exit.
+                    break
+                } catch (e: DnsPacket.ParseException) {
+                    // A malformed or non-query packet must not end the server thread (an
+                    // uncaught exception here would, and may take down the test process).
+                    Log.e(TAG, "Dropping unparseable packet: ${e.message}")
+                } catch (e: IllegalArgumentException) {
+                    // Unsupported query shape (see handleDnsQuery); drop it and keep serving.
+                    Log.e(TAG, "Dropping unsupported query: ${e.message}")
+                }
+            }
+            Log.i(TAG, "exiting socket={$socket}")
+        }
+    }
+
+    /**
+     * A DNS packet that must be a query: construction throws [DnsPacket.ParseException] if the
+     * header marks it as a response.
+     */
+    @VisibleForTesting(visibility = PRIVATE)
+    class DnsQueryPacket : DnsPacket {
+        constructor(data: ByteArray) : super(data)
+        constructor(header: DnsHeader, qd: List<DnsRecord>, an: List<DnsRecord>) :
+                super(header, qd, an)
+
+        init {
+            if (mHeader.isResponse) {
+                throw ParseException("Not a query packet")
+            }
+        }
+
+        /** Records of the question section. */
+        val queryRecords: List<DnsRecord>
+            get() = mRecords[QDSECTION]
+
+        /** Builds the response to this query carrying the answer records [ar]. */
+        fun getAnswerPacket(ar: List<DnsRecord>): DnsAnswerPacket {
+            // Set QR bit of flag to 1 for response packet according to RFC 1035 section 4.1.1.
+            val flags = 1 shl 15
+            val qr = ArrayList(mRecords[QDSECTION])
+            // Copy the query packet header id to the answer packet as RFC 1035 section 4.1.1.
+            val header = DnsHeader(mHeader.id, flags, qr.size, ar.size)
+            return DnsAnswerPacket(header, qr, ar)
+        }
+    }
+
+    /** A DNS response packet, built by the server or parsed from raw bytes in tests. */
+    class DnsAnswerPacket : DnsPacket {
+        constructor(header: DnsHeader, qr: List<DnsRecord>, ar: List<DnsRecord>) :
+                super(header, qr, ar)
+        @VisibleForTesting(visibility = PRIVATE)
+        constructor(bytes: ByteArray) : super(bytes)
+    }
+}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
index e86ea98..b84f9a6 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
@@ -376,8 +376,17 @@
 fun <T : CallbackEntry> eventuallyExpect(
 type: KClass<T>,
 timeoutMs: Long = defaultTimeoutMs,
- predicate: (T: CallbackEntry) -> Boolean = { true }
- ) = history.poll(timeoutMs) { type.java.isInstance(it) && predicate(it as T) }.also {
+ predicate: (cb: T) -> Boolean = { true }
+ ) = history.poll(timeoutMs) { type.java.isInstance(it) && predicate(it as T) }.also {
+ assertNotNull(it, "Callback ${type.java} not received within ${timeoutMs}ms")
+ } as T
+
+ /**
+ * Same as the overload above, but passes [from] (by default the current [mark]) as the
+ * starting position to `history.poll` — presumably the index in the callback history from
+ * which to start searching; TODO confirm against the history's poll(timeout, pos) contract.
+ *
+ * Waits up to [timeoutMs] for a callback of [type] that satisfies [predicate], and fails with
+ * an assertion error if none arrives. The `it as T` cast is only evaluated after
+ * `type.java.isInstance(it)` has returned true, since `&&` short-circuits.
+ */
+ fun <T : CallbackEntry> eventuallyExpect(
+ type: KClass<T>,
+ timeoutMs: Long = defaultTimeoutMs,
+ from: Int = mark,
+ predicate: (cb: T) -> Boolean = { true }
+ ) = history.poll(timeoutMs, from) { type.java.isInstance(it) && predicate(it as T) }.also {
 assertNotNull(it, "Callback ${type.java} not received within ${timeoutMs}ms")
 } as T