Merge "bump min_sdk_version from 29 (Q) to 30 (R)" into main
diff --git a/Tethering/src/com/android/networkstack/tethering/util/SyncStateMachine.java b/Tethering/src/com/android/networkstack/tethering/util/SyncStateMachine.java
index cbbbdec..a17eb26 100644
--- a/Tethering/src/com/android/networkstack/tethering/util/SyncStateMachine.java
+++ b/Tethering/src/com/android/networkstack/tethering/util/SyncStateMachine.java
@@ -18,12 +18,14 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.os.Message;
 import android.util.ArrayMap;
 import android.util.ArraySet;
 import android.util.Log;
 
 import com.android.internal.util.State;
 
+import java.util.ArrayDeque;
 import java.util.List;
 import java.util.Objects;
 import java.util.Set;
@@ -48,6 +50,14 @@
     // mDestState only be null before state machine starts and must only be touched on mMyThread.
     @Nullable private State mCurrentState;
     @Nullable private State mDestState;
+    private final ArrayDeque<Message> mSelfMsgQueue = new ArrayDeque<Message>();
+
+    // MIN_VALUE means not currently processing any message.
+    private int mCurrentlyProcessing = Integer.MIN_VALUE;
+    // Indicates whether the automaton can send a self message. Self messages can only be sent
+    // by the automaton from State#enter, State#exit, or State#processMessage. Sending one from
+    // outside of a State is not allowed.
+    private boolean mSelfMsgAllowed = false;
 
     /**
      * A information class about a state and its parent. Used to maintain the state hierarchy.
@@ -141,16 +151,87 @@
         ensureExistingState(initialState);
 
         mDestState = initialState;
+        mSelfMsgAllowed = true;
         performTransitions();
+        mSelfMsgAllowed = false;
+        // If sendSelfMessage was called inside initialState#enter(), mSelfMsgQueue must be
+        // processed.
+        maybeProcessSelfMessageQueue();
     }
 
     /**
-     * Process the message synchronously then perform state transition.
+     * Process the message synchronously then perform state transition. This method is used
+     * externally to the automaton to request that the automaton process the given message.
+     * The message is processed sequentially, so calling this method recursively is not permitted.
+     * In other words, using this method inside State#enter, State#exit, or State#processMessage
+     * is incorrect and will result in an IllegalStateException.
      */
     public final void processMessage(int what, int arg1, int arg2, @Nullable Object obj) {
         ensureCorrectThread();
 
+        if (mCurrentlyProcessing != Integer.MIN_VALUE) {
+            throw new IllegalStateException("Message(" + mCurrentlyProcessing
+                    + ") is still being processed");
+        }
+
+        // mCurrentlyProcessing tracks the external message request and prevents this method from
+        // being called recursively. Once this message is processed and the transitions have been
+        // performed, the automaton will process the self message queue. The messages in the self
+        // message queue are added from within the automaton while it handles a message. During
+        // that time mCurrentlyProcessing still holds the external message's "what", so it does
+        // not prevent self messages from being processed.
+        mCurrentlyProcessing = what;
+        final Message msg = Message.obtain(null, what, arg1, arg2, obj);
+        currentStateProcessMessageThenPerformTransitions(msg);
+        msg.recycle();
+        maybeProcessSelfMessageQueue();
+
+        mCurrentlyProcessing = Integer.MIN_VALUE;
+    }
+
+    private void maybeProcessSelfMessageQueue() {
+        while (!mSelfMsgQueue.isEmpty()) {
+            currentStateProcessMessageThenPerformTransitions(mSelfMsgQueue.poll());
+        }
+    }
+
+    private void currentStateProcessMessageThenPerformTransitions(@NonNull final Message msg) {
+        mSelfMsgAllowed = true;
+        StateInfo consideredState = mStateInfo.get(mCurrentState);
+        while (null != consideredState) {
+            // Ideally this should compare with IState.HANDLED, but that is not a public field, so
+            // just check whether the return value is true (IState.HANDLED = true).
+            if (consideredState.state.processMessage(msg)) {
+                if (mDbg) {
+                    Log.d(mName, "State " + consideredState.state
+                            + " processed message " + msg.what);
+                }
+                break;
+            }
+            consideredState = mStateInfo.get(consideredState.parent);
+        }
+        if (null == consideredState) {
+            Log.wtf(mName, "Message " + msg.what + " was not handled");
+        }
+
         performTransitions();
+        mSelfMsgAllowed = false;
+    }
+
+    /**
+     * Send a self message during state transition.
+     *
+     * Must only be used inside State processMessage, enter or exit. The typical use case is
+     * that something goes wrong during a state transition and an error message is sent, which
+     * is then handled after the current state transitions have finished.
+     */
+    public final void sendSelfMessage(int what, int arg1, int arg2, Object obj) {
+        if (!mSelfMsgAllowed) {
+            throw new IllegalStateException("sendSelfMessage can only be called inside "
+                    + "State#enter, State#exit or State#processMessage");
+        }
+
+        mSelfMsgQueue.add(Message.obtain(null, what, arg1, arg2, obj));
     }
 
     /**
diff --git a/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt b/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt
index bc13442..eef3f87 100644
--- a/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt
+++ b/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt
@@ -233,46 +233,51 @@
     }
 }
 
+private fun getMdnsPayload(packet: ByteArray) = packet.copyOfRange(
+    ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, packet.size)
+
 fun TapPacketReader.pollForMdnsPacket(
     timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS,
     predicate: (TestDnsPacket) -> Boolean
-): ByteArray? {
+): TestDnsPacket? {
     val mdnsProbeFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT).and {
-        val mdnsPayload = it.copyOfRange(
-            ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, it.size
-        )
+        val mdnsPayload = getMdnsPayload(it)
         try {
             predicate(TestDnsPacket(mdnsPayload))
         } catch (e: DnsPacket.ParseException) {
             false
         }
     }
-    return poll(timeoutMs, mdnsProbeFilter)
+    return poll(timeoutMs, mdnsProbeFilter)?.let { TestDnsPacket(getMdnsPayload(it)) }
 }
 
 fun TapPacketReader.pollForProbe(
     serviceName: String,
     serviceType: String,
     timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
-): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isProbeFor("$serviceName.$serviceType.local") }
+): TestDnsPacket? = pollForMdnsPacket(timeoutMs) {
+    it.isProbeFor("$serviceName.$serviceType.local")
+}
 
 fun TapPacketReader.pollForAdvertisement(
     serviceName: String,
     serviceType: String,
     timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
-): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isReplyFor("$serviceName.$serviceType.local") }
+): TestDnsPacket? = pollForMdnsPacket(timeoutMs) {
+    it.isReplyFor("$serviceName.$serviceType.local")
+}
 
 fun TapPacketReader.pollForQuery(
     recordName: String,
-    recordType: Int,
+    vararg requiredTypes: Int,
     timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
-): ByteArray? = pollForMdnsPacket(timeoutMs) { it.isQueryFor(recordName, recordType) }
+): TestDnsPacket? = pollForMdnsPacket(timeoutMs) { it.isQueryFor(recordName, *requiredTypes) }
 
 fun TapPacketReader.pollForReply(
     serviceName: String,
     serviceType: String,
     timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
-): ByteArray? = pollForMdnsPacket(timeoutMs) {
+): TestDnsPacket? = pollForMdnsPacket(timeoutMs) {
     it.isReplyFor("$serviceName.$serviceType.local")
 }
 
@@ -289,7 +294,9 @@
         it.dName == name && it.nsType == DnsResolver.TYPE_SRV
     }
 
-    fun isQueryFor(name: String, type: Int): Boolean = mRecords[QDSECTION].any {
-        it.dName == name && it.nsType == type
+    fun isQueryFor(name: String, vararg requiredTypes: Int): Boolean = requiredTypes.all { type ->
+        mRecords[QDSECTION].any {
+            it.dName == name && it.nsType == type
+        }
     }
 }
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index 27bd5d3..9c44a3e 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -20,6 +20,7 @@
 import android.app.compat.CompatChanges
 import android.net.ConnectivityManager
 import android.net.ConnectivityManager.NetworkCallback
+import android.net.DnsResolver
 import android.net.InetAddresses.parseNumericAddress
 import android.net.LinkAddress
 import android.net.LinkProperties
@@ -87,6 +88,7 @@
 import com.android.testutils.TestableNetworkAgent
 import com.android.testutils.TestableNetworkAgent.CallbackEntry.OnNetworkCreated
 import com.android.testutils.TestableNetworkCallback
+import com.android.testutils.assertEmpty
 import com.android.testutils.filters.CtsNetTestCasesMaxTargetSdk30
 import com.android.testutils.filters.CtsNetTestCasesMaxTargetSdk33
 import com.android.testutils.runAsShell
@@ -424,11 +426,7 @@
 
     @Test
     fun testNsdManager_DiscoverOnNetwork() {
-        val si = NsdServiceInfo()
-        si.serviceType = serviceType
-        si.serviceName = this.serviceName
-        si.port = 12345 // Test won't try to connect so port does not matter
-
+        val si = makeTestServiceInfo()
         val registrationRecord = NsdRegistrationRecord()
         val registeredInfo = registerService(registrationRecord, si)
 
@@ -455,11 +453,7 @@
 
     @Test
     fun testNsdManager_DiscoverWithNetworkRequest() {
-        val si = NsdServiceInfo()
-        si.serviceType = serviceType
-        si.serviceName = this.serviceName
-        si.port = 12345 // Test won't try to connect so port does not matter
-
+        val si = makeTestServiceInfo()
         val handler = Handler(handlerThread.looper)
         val executor = Executor { handler.post(it) }
 
@@ -524,11 +518,6 @@
 
     @Test
     fun testNsdManager_DiscoverWithNetworkRequest_NoMatchingNetwork() {
-        val si = NsdServiceInfo()
-        si.serviceType = serviceType
-        si.serviceName = this.serviceName
-        si.port = 12345 // Test won't try to connect so port does not matter
-
         val handler = Handler(handlerThread.looper)
         val executor = Executor { handler.post(it) }
 
@@ -568,11 +557,7 @@
 
     @Test
     fun testNsdManager_ResolveOnNetwork() {
-        val si = NsdServiceInfo()
-        si.serviceType = serviceType
-        si.serviceName = this.serviceName
-        si.port = 12345 // Test won't try to connect so port does not matter
-
+        val si = makeTestServiceInfo()
         val registrationRecord = NsdRegistrationRecord()
         val registeredInfo = registerService(registrationRecord, si)
         tryTest {
@@ -610,12 +595,7 @@
 
     @Test
     fun testNsdManager_RegisterOnNetwork() {
-        val si = NsdServiceInfo()
-        si.serviceType = serviceType
-        si.serviceName = this.serviceName
-        si.network = testNetwork1.network
-        si.port = 12345 // Test won't try to connect so port does not matter
-
+        val si = makeTestServiceInfo(testNetwork1.network)
         // Register service on testNetwork1
         val registrationRecord = NsdRegistrationRecord()
         registerService(registrationRecord, si)
@@ -889,11 +869,7 @@
 
     @Test
     fun testStopServiceResolution() {
-        val si = NsdServiceInfo()
-        si.serviceType = this@NsdManagerTest.serviceType
-        si.serviceName = this@NsdManagerTest.serviceName
-        si.port = 12345 // Test won't try to connect so port does not matter
-
+        val si = makeTestServiceInfo()
         val resolveRecord = NsdResolveRecord()
         // Try to resolve an unknown service then stop it immediately.
         // Expected ResolutionStopped callback.
@@ -911,12 +887,7 @@
         val addresses = lp.addresses
         assertFalse(addresses.isEmpty())
 
-        val si = NsdServiceInfo().apply {
-            serviceType = this@NsdManagerTest.serviceType
-            serviceName = this@NsdManagerTest.serviceName
-            network = testNetwork1.network
-            port = 12345 // Test won't try to connect so port does not matter
-        }
+        val si = makeTestServiceInfo(testNetwork1.network)
 
         // Register service on the network
         val registrationRecord = NsdRegistrationRecord()
@@ -1022,11 +993,7 @@
         // This test requires shims supporting T+ APIs (NsdServiceInfo.network)
         assumeTrue(TestUtils.shouldTestTApis())
 
-        val si = NsdServiceInfo()
-        si.serviceType = serviceType
-        si.serviceName = serviceName
-        si.network = testNetwork1.network
-        si.port = 12345 // Test won't try to connect so port does not matter
+        val si = makeTestServiceInfo(testNetwork1.network)
 
         val packetReader = TapPacketReader(Handler(handlerThread.looper),
                 testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
@@ -1063,11 +1030,7 @@
         // This test requires shims supporting T+ APIs (NsdServiceInfo.network)
         assumeTrue(TestUtils.shouldTestTApis())
 
-        val si = NsdServiceInfo()
-        si.serviceType = serviceType
-        si.serviceName = serviceName
-        si.network = testNetwork1.network
-        si.port = 12345 // Test won't try to connect so port does not matter
+        val si = makeTestServiceInfo(testNetwork1.network)
 
         // Register service on testNetwork1
         val registrationRecord = NsdRegistrationRecord()
@@ -1137,6 +1100,127 @@
         }
     }
 
+    // Test that even if only a PTR record is received as a reply when discovering, without the
+    // SRV, TXT, address records as recommended (but not mandated) by RFC 6763 section 12, the
+    // service can still be discovered.
+    @Test
+    fun testDiscoveryWithPtrOnlyResponse_ServiceIsFound() {
+        // Discover services on testNetwork1
+        val discoveryRecord = NsdDiscoveryRecord()
+        val packetReader = TapPacketReader(Handler(handlerThread.looper),
+                testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+        packetReader.startAsyncForTest()
+        handlerThread.waitForIdle(TIMEOUT_MS)
+
+        nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
+                testNetwork1.network, { it.run() }, discoveryRecord)
+
+        tryTest {
+            discoveryRecord.expectCallback<DiscoveryStarted>()
+            assertNotNull(packetReader.pollForQuery("$serviceType.local", DnsResolver.TYPE_PTR))
+            /*
+            Generated with:
+            scapy.raw(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an =
+                scapy.DNSRR(rrname='_nmt123456789._tcp.local', type='PTR', ttl=120,
+                rdata='NsdTest123456789._nmt123456789._tcp.local'))).hex()
+             */
+            val ptrResponsePayload = HexDump.hexStringToByteArray("0000840000000001000000000d5f6e" +
+                    "6d74313233343536373839045f746370056c6f63616c00000c000100000078002b104e736454" +
+                    "6573743132333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00")
+
+            replaceServiceNameAndTypeWithTestSuffix(ptrResponsePayload)
+            packetReader.sendResponse(buildMdnsPacket(ptrResponsePayload))
+
+            val serviceFound = discoveryRecord.expectCallback<ServiceFound>()
+            serviceFound.serviceInfo.let {
+                assertEquals(serviceName, it.serviceName)
+                // Discovered service types have a dot at the end
+                assertEquals("$serviceType.", it.serviceType)
+                assertEquals(testNetwork1.network, it.network)
+                // ServiceFound does not provide port, address or attributes (only information
+                // available in the PTR record is included in that callback, regardless of whether
+                // other records exist).
+                assertEquals(0, it.port)
+                assertEmpty(it.hostAddresses)
+                assertEquals(0, it.attributes.size)
+            }
+        } cleanup {
+            nsdManager.stopServiceDiscovery(discoveryRecord)
+            discoveryRecord.expectCallback<DiscoveryStopped>()
+        }
+    }
+
+    // Test RFC 6763 section 12: "Clients MUST be capable of functioning correctly with DNS
+    // servers [...] that fail to generate these additional records automatically, by issuing
+    // subsequent queries for any further record(s) they require"
+    @Test
+    fun testResolveWhenServerSendsNoAdditionalRecord() {
+        // Resolve service on testNetwork1
+        val resolveRecord = NsdResolveRecord()
+        val packetReader = TapPacketReader(Handler(handlerThread.looper),
+                testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+        packetReader.startAsyncForTest()
+        handlerThread.waitForIdle(TIMEOUT_MS)
+
+        val si = makeTestServiceInfo(testNetwork1.network)
+        nsdManager.resolveService(si, { it.run() }, resolveRecord)
+
+        val serviceFullName = "$serviceName.$serviceType.local"
+        // The query should ask for ANY, since both SRV and TXT are requested. Note legacy
+        // mdnsresponder will ask for SRV and TXT separately, and will not proceed to asking for
+        // address records without an answer for both.
+        val srvTxtQuery = packetReader.pollForQuery(serviceFullName, DnsResolver.TYPE_ANY)
+        assertNotNull(srvTxtQuery)
+
+        /*
+        Generated with:
+        scapy.raw(scapy.dns_compress(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an =
+            scapy.DNSRRSRV(rrname='NsdTest123456789._nmt123456789._tcp.local',
+                rclass=0x8001, port=31234, target='testhost.local', ttl=120) /
+            scapy.DNSRR(rrname='NsdTest123456789._nmt123456789._tcp.local', type='TXT', ttl=120,
+                rdata='testkey=testvalue')
+        ))).hex()
+         */
+        val srvTxtResponsePayload = HexDump.hexStringToByteArray("000084000000000200000000104" +
+                "e7364546573743132333435363738390d5f6e6d74313233343536373839045f746370056c6f6" +
+                "3616c0000218001000000780011000000007a020874657374686f7374c030c00c00100001000" +
+                "00078001211746573746b65793d7465737476616c7565")
+        replaceServiceNameAndTypeWithTestSuffix(srvTxtResponsePayload)
+        packetReader.sendResponse(buildMdnsPacket(srvTxtResponsePayload))
+
+        val testHostname = "testhost.local"
+        val addressQuery = packetReader.pollForQuery(testHostname,
+            DnsResolver.TYPE_A, DnsResolver.TYPE_AAAA)
+        assertNotNull(addressQuery)
+
+        /*
+        Generated with:
+        scapy.raw(scapy.dns_compress(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an =
+            scapy.DNSRR(rrname='testhost.local', type='A', ttl=120,
+                rdata='192.0.2.123') /
+            scapy.DNSRR(rrname='testhost.local', type='AAAA', ttl=120,
+                rdata='2001:db8::123')
+        ))).hex()
+         */
+        val addressPayload = HexDump.hexStringToByteArray("0000840000000002000000000874657374" +
+                "686f7374056c6f63616c0000010001000000780004c000027bc00c001c000100000078001020" +
+                "010db8000000000000000000000123")
+        packetReader.sendResponse(buildMdnsPacket(addressPayload))
+
+        val serviceResolved = resolveRecord.expectCallback<ServiceResolved>()
+        serviceResolved.serviceInfo.let {
+            assertEquals(serviceName, it.serviceName)
+            assertEquals(".$serviceType", it.serviceType)
+            assertEquals(testNetwork1.network, it.network)
+            assertEquals(31234, it.port)
+            assertEquals(1, it.attributes.size)
+            assertArrayEquals("testvalue".encodeToByteArray(), it.attributes["testkey"])
+        }
+        assertEquals(
+                setOf(parseNumericAddress("192.0.2.123"), parseNumericAddress("2001:db8::123")),
+                serviceResolved.serviceInfo.hostAddresses.toSet())
+    }
+
     private fun buildConflictingAnnouncement(): ByteBuffer {
         /*
         Generated with:
@@ -1148,21 +1232,37 @@
         val mdnsPayload = HexDump.hexStringToByteArray("000084000000000100000000104e736454657" +
                 "3743132333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00002" +
                 "18001000000780016000000007a0208636f6e666c696374056c6f63616c00")
-        val packetBuffer = ByteBuffer.wrap(mdnsPayload)
-        // Replace service name and types in the packet with the random ones used in the test.
+        replaceServiceNameAndTypeWithTestSuffix(mdnsPayload)
+
+        return buildMdnsPacket(mdnsPayload)
+    }
+
+    /**
+     * Replaces occurrences of "NsdTest123456789" and "_nmt123456789" in mDNS payload with the
+     * actual random name and type that are used by the test.
+     */
+    private fun replaceServiceNameAndTypeWithTestSuffix(mdnsPayload: ByteArray) {
         // Test service name and types have consistent length and are always ASCII
         val testPacketName = "NsdTest123456789".encodeToByteArray()
         val testPacketTypePrefix = "_nmt123456789".encodeToByteArray()
         val encodedServiceName = serviceName.encodeToByteArray()
         val encodedTypePrefix = serviceType.split('.')[0].encodeToByteArray()
-        assertEquals(testPacketName.size, encodedServiceName.size)
-        assertEquals(testPacketTypePrefix.size, encodedTypePrefix.size)
-        packetBuffer.position(mdnsPayload.indexOf(testPacketName))
-        packetBuffer.put(encodedServiceName)
-        packetBuffer.position(mdnsPayload.indexOf(testPacketTypePrefix))
-        packetBuffer.put(encodedTypePrefix)
 
-        return buildMdnsPacket(mdnsPayload)
+        val packetBuffer = ByteBuffer.wrap(mdnsPayload)
+        replaceAll(packetBuffer, testPacketName, encodedServiceName)
+        replaceAll(packetBuffer, testPacketTypePrefix, encodedTypePrefix)
+    }
+
+    private tailrec fun replaceAll(buffer: ByteBuffer, source: ByteArray, replacement: ByteArray) {
+        assertEquals(source.size, replacement.size)
+        val index = buffer.array().indexOf(source)
+        if (index < 0) return
+
+        val origPosition = buffer.position()
+        buffer.position(index)
+        buffer.put(replacement)
+        buffer.position(origPosition)
+        replaceAll(buffer, source, replacement)
     }
 
     private fun buildMdnsPacket(mdnsPayload: ByteArray): ByteBuffer {