Merge "Test races around NetworkAgent creation." into main
diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
index 60081d4..815c3a5 100644
--- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
+++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
@@ -83,13 +83,17 @@
 import android.os.ConditionVariable
 import android.os.Handler
 import android.os.HandlerThread
+import android.os.Looper
 import android.os.Message
 import android.os.PersistableBundle
 import android.os.Process
 import android.os.SystemClock
 import android.platform.test.annotations.AppModeFull
+import android.system.Os
+import android.system.OsConstants.AF_INET6
 import android.system.OsConstants.IPPROTO_TCP
 import android.system.OsConstants.IPPROTO_UDP
+import android.system.OsConstants.SOCK_DGRAM
 import android.telephony.CarrierConfigManager
 import android.telephony.SubscriptionManager
 import android.telephony.TelephonyManager
@@ -105,6 +109,10 @@
 import com.android.compatibility.common.util.UiccUtil
 import com.android.modules.utils.build.SdkLevel
 import com.android.net.module.util.ArrayTrackRecord
+import com.android.net.module.util.NetworkStackConstants.ETHER_MTU
+import com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN
+import com.android.net.module.util.NetworkStackConstants.IPV6_PROTOCOL_OFFSET
+import com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN
 import com.android.testutils.CompatUtil
 import com.android.testutils.ConnectivityModuleTest
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
@@ -115,6 +123,7 @@
 import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
 import com.android.testutils.RecorderCallback.CallbackEntry.Losing
 import com.android.testutils.RecorderCallback.CallbackEntry.Lost
+import com.android.testutils.PollPacketReader
 import com.android.testutils.TestableNetworkAgent
 import com.android.testutils.TestableNetworkAgent.CallbackEntry.OnAddKeepalivePacketFilter
 import com.android.testutils.TestableNetworkAgent.CallbackEntry.OnAutomaticReconnectDisabled
@@ -133,6 +142,7 @@
 import com.android.testutils.assertThrows
 import com.android.testutils.runAsShell
 import com.android.testutils.tryTest
+import com.android.testutils.waitForIdle
 import java.io.Closeable
 import java.io.IOException
 import java.net.DatagramSocket
@@ -140,10 +150,13 @@
 import java.net.InetSocketAddress
 import java.net.Socket
 import java.security.MessageDigest
+import java.nio.ByteBuffer
 import java.time.Duration
 import java.util.Arrays
+import java.util.Random
 import java.util.UUID
 import java.util.concurrent.Executors
+import kotlin.collections.ArrayList
 import kotlin.test.assertEquals
 import kotlin.test.assertFailsWith
 import kotlin.test.assertFalse
@@ -188,6 +201,11 @@
     it.obj = obj
 }
 
+private val LINK_ADDRESS = LinkAddress("2001:db8::1/64")
+private val REMOTE_ADDRESS = InetAddresses.parseNumericAddress("2001:db8::123")
+private val PREFIX = IpPrefix("2001:db8::/64")
+private val NEXTHOP = InetAddresses.parseNumericAddress("fe80::abcd")
+
 // On T and below, the native network is only created when the agent connects.
 // Starting in U, the native network was to be created as soon as the agent is registered,
 // but this has been flagged off for now pending resolution of race conditions.
@@ -321,6 +339,15 @@
         if (transports.size > 0) removeCapability(NET_CAPABILITY_NOT_RESTRICTED)
     }
 
+    // Test LinkProperties for ifName: LINK_ADDRESS, a route to PREFIX with no next hop, and
+    // a default route via NEXTHOP.
+    private fun makeTestLinkProperties(ifName: String): LinkProperties = LinkProperties().apply {
+        interfaceName = ifName
+        addLinkAddress(LINK_ADDRESS)
+        addRoute(RouteInfo(PREFIX, null /* nextHop */, ifName))
+        addRoute(RouteInfo(IpPrefix("::/0"), NEXTHOP, ifName))
+    }
+
     private fun createNetworkAgent(
         context: Context = realContext,
         specifier: String? = null,
@@ -341,6 +368,7 @@
 
     private fun createConnectedNetworkAgent(
         context: Context = realContext,
+        lp: LinkProperties? = null,
         specifier: String? = UUID.randomUUID().toString(),
         initialConfig: NetworkAgentConfig? = null,
         expectedInitSignalStrengthThresholds: IntArray = intArrayOf(),
@@ -350,7 +378,8 @@
         // Ensure this NetworkAgent is never unneeded by filing a request with its specifier.
         requestNetwork(makeTestNetworkRequest(specifier), callback)
         val nc = makeTestNetworkCapabilities(specifier, transports)
-        val agent = createNetworkAgent(context, initialConfig = initialConfig, initialNc = nc)
+        val agent = createNetworkAgent(context, initialConfig = initialConfig, initialLp = lp,
+            initialNc = nc)
         agent.setTeardownDelayMillis(0)
         // Connect the agent and verify initial status callbacks.
         agent.register()
@@ -361,8 +390,9 @@
         return agent to callback
     }
 
-    private fun connectNetwork(vararg transports: Int): Pair<TestableNetworkAgent, Network> {
-        val (agent, callback) = createConnectedNetworkAgent(transports = transports)
+    private fun connectNetwork(vararg transports: Int, lp: LinkProperties? = null):
+            Pair<TestableNetworkAgent, Network> {
+        val (agent, callback) = createConnectedNetworkAgent(transports = transports, lp = lp)
         val network = agent.network!!
         // createConnectedNetworkAgent internally files a request; release it so that the network
         // will be torn down if unneeded.
@@ -382,8 +412,9 @@
         assertNoCallback()
     }
 
-    private fun createTunInterface(): TestNetworkInterface = realContext.getSystemService(
-                TestNetworkManager::class.java)!!.createTunInterface(emptyList()).also {
+    private fun createTunInterface(addrs: Collection<LinkAddress> = emptyList()):
+            TestNetworkInterface = realContext.getSystemService(
+                TestNetworkManager::class.java)!!.createTunInterface(addrs).also {
             ifacesToCleanUp.add(it)
     }
 
@@ -1501,15 +1532,101 @@
 
     private fun createEpsAttributes(qci: Int = 1): EpsBearerQosSessionAttributes {
         val remoteAddresses = ArrayList<InetSocketAddress>()
-        remoteAddresses.add(InetSocketAddress("2001:db8::123", 80))
+        remoteAddresses.add(InetSocketAddress(REMOTE_ADDRESS, 80))
         return EpsBearerQosSessionAttributes(
                 qci, 2, 3, 4, 5,
                 remoteAddresses
         )
     }
 
+    /**
+     * Sends a 16-byte random UDP payload to [REMOTE_ADDRESS] through [net] and asserts that
+     * [reader] observes the resulting IPv6/UDP packet within [DEFAULT_TIMEOUT_MS].
+     *
+     * @param net network the sending socket is bound to.
+     * @param reader packet reader polled for the outgoing packet.
+     * @param iface interface named in the failure message.
+     */
+    fun sendAndExpectUdpPacket(net: Network,
+                               reader: PollPacketReader, iface: TestNetworkInterface) {
+        val s = Os.socket(AF_INET6, SOCK_DGRAM, 0)
+        try {
+            net.bindSocket(s)
+            val content = ByteArray(16)
+            Random().nextBytes(content)
+            Os.sendto(s, ByteBuffer.wrap(content), 0, REMOTE_ADDRESS, 7 /* port */)
+            val match = reader.poll(DEFAULT_TIMEOUT_MS) {
+                // The payload follows the fixed IPv6 header and the UDP header.
+                val payloadStart = IPV6_HEADER_LEN + UDP_HEADER_LEN
+                it.size == payloadStart + content.size &&
+                        (it[0].toInt() and 0xf0) == 0x60 && // IP version nibble is 6
+                        it[IPV6_PROTOCOL_OFFSET].toInt() == IPPROTO_UDP &&
+                        Arrays.equals(content,
+                                it.copyOfRange(payloadStart, payloadStart + content.size))
+            }
+            assertNotNull(match, "Did not receive matching packet on ${iface.interfaceName} " +
+                    " after ${DEFAULT_TIMEOUT_MS}ms")
+        } finally {
+            // This helper is called many times per test; close the descriptor so that
+            // sockets do not pile up, including when an assertion above fails.
+            Os.close(s)
+        }
+    }
+
+    /**
+     * Creates a TUN interface holding [LINK_ADDRESS], starts a [PollPacketReader] on its file
+     * descriptor, and builds the matching test [LinkProperties].
+     *
+     * @return the interface, the started reader, and the link properties for the interface.
+     */
+    fun createInterfaceAndReader(): Triple<TestNetworkInterface, PollPacketReader, LinkProperties> {
+        val iface = createTunInterface(listOf(LINK_ADDRESS))
+        val handler = Handler(Looper.getMainLooper())
+        val reader = PollPacketReader(handler, iface.fileDescriptor.fileDescriptor, ETHER_MTU)
+        reader.startAsyncForTest()
+        handler.waitForIdle(DEFAULT_TIMEOUT_MS)
+        val ifName = iface.interfaceName
+        val lp = makeTestLinkProperties(ifName)
+        return Triple(iface, reader, lp)
+    }
+
+    /**
+     * Checks that an agent registered right after another agent on the same interface was
+     * unregistered still gets a working network.
+     */
+    @Test
+    fun testRegisterAfterUnregister() {
+        val (iface, reader, lp) = createInterfaceAndReader()
+
+        // File a request that matches and keeps up the best-scoring test network.
+        val testCallback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
+        requestNetwork(makeTestNetworkRequest(), testCallback)
+
+        // Register and unregister NetworkAgents in a loop, checking that every time an agent
+        // connects, the native network is correctly configured and packets can be sent.
+        // Running 10 iterations takes about 1 second on x86 cuttlefish, and detects the race in
+        // b/286649301 most of the time.
+        for (i in 1..10) {
+            val agent1 = createNetworkAgent(realContext, initialLp = lp)
+            agent1.register()
+            agent1.unregister()
+
+            val agent2 = createNetworkAgent(realContext, initialLp = lp)
+            agent2.register()
+            agent2.markConnected()
+            val network2 = agent2.network!!
+
+            testCallback.expectAvailableThenValidatedCallbacks(network2)
+            sendAndExpectUdpPacket(network2, reader, iface)
+            agent2.unregister()
+            testCallback.expect<Lost>(network2)
+        }
+    }
+
     @Test
     fun testUnregisterAfterReplacement() {
+        val (iface, reader, lp) = createInterfaceAndReader()
+
         // Keeps an eye on all test networks.
         val matchAllCallback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
         registerNetworkCallback(makeTestNetworkRequest(), matchAllCallback)
@@ -1519,14 +1610,13 @@
         requestNetwork(makeTestNetworkRequest(), testCallback)
 
         // Connect the first network. This should satisfy the request.
-        val (agent1, network1) = connectNetwork()
+        val (agent1, network1) = connectNetwork(lp = lp)
         matchAllCallback.expectAvailableThenValidatedCallbacks(network1)
         testCallback.expectAvailableThenValidatedCallbacks(network1)
-        // Check that network1 exists by binding a socket to it and getting no exceptions.
-        network1.bindSocket(DatagramSocket())
+        sendAndExpectUdpPacket(network1, reader, iface)
 
         // Connect a second agent. network1 is preferred because it was already registered, so
-        // testCallback will not see any events. agent2 is be torn down because it has no requests.
+        // testCallback will not see any events. agent2 is torn down because it has no requests.
         val (agent2, network2) = connectNetwork()
         matchAllCallback.expectAvailableThenValidatedCallbacks(network2)
         matchAllCallback.expect<Lost>(network2)
@@ -1551,9 +1641,10 @@
         // as soon as it validates (until then, it is outscored by network1).
         // The fact that the first events seen by matchAllCallback is the connection of network3
         // implicitly ensures that no callbacks are sent since network1 was lost.
-        val (agent3, network3) = connectNetwork()
+        val (agent3, network3) = connectNetwork(lp = lp)
         matchAllCallback.expectAvailableThenValidatedCallbacks(network3)
         testCallback.expectAvailableDoubleValidatedCallbacks(network3)
+        sendAndExpectUdpPacket(network3, reader, iface)
 
         // As soon as the replacement arrives, network1 is disconnected.
         // Check that this happens before the replacement timeout (5 seconds) fires.
@@ -1573,6 +1664,7 @@
         matchAllCallback.expect<Losing>(network3)
         testCallback.expectAvailableCallbacks(network4, validated = true)
         mCM.unregisterNetworkCallback(agent4callback)
+        sendAndExpectUdpPacket(network3, reader, iface)
         agent3.unregisterAfterReplacement(5_000)
         agent3.expectCallback<OnNetworkUnwanted>()
         matchAllCallback.expect<Lost>(network3, 1000L)
@@ -1588,9 +1680,10 @@
 
         // If a network that is awaiting replacement is unregistered, it disconnects immediately,
         // before the replacement timeout fires.
-        val (agent5, network5) = connectNetwork()
+        val (agent5, network5) = connectNetwork(lp = lp)
         matchAllCallback.expectAvailableThenValidatedCallbacks(network5)
         testCallback.expectAvailableThenValidatedCallbacks(network5)
+        sendAndExpectUdpPacket(network5, reader, iface)
         agent5.unregisterAfterReplacement(5_000 /* timeoutMillis */)
         agent5.unregister()
         matchAllCallback.expect<Lost>(network5, 1000L /* timeoutMs */)
@@ -1637,7 +1730,7 @@
         matchAllCallback.assertNoCallback(200 /* timeoutMs */)
 
         // If wifi is replaced within the timeout, the device does not switch to cellular.
-        val (_, cellNetwork) = connectNetwork(TRANSPORT_CELLULAR)
+        val (cellAgent, cellNetwork) = connectNetwork(TRANSPORT_CELLULAR)
         testCallback.expectAvailableThenValidatedCallbacks(cellNetwork)
         matchAllCallback.expectAvailableThenValidatedCallbacks(cellNetwork)
 
@@ -1674,6 +1767,34 @@
         matchAllCallback.expectAvailableThenValidatedCallbacks(newWifiNetwork)
         matchAllCallback.expect<Lost>(wifiNetwork)
         wifiAgent.expectCallback<OnNetworkUnwanted>()
+        testCallback.expect<CapabilitiesChanged>(newWifiNetwork)
+
+        cellAgent.unregister()
+        matchAllCallback.expect<Lost>(cellNetwork)
+        newWifiAgent.unregister()
+        matchAllCallback.expect<Lost>(newWifiNetwork)
+        testCallback.expect<Lost>(newWifiNetwork)
+
+        // Calling unregisterAfterReplacement several times in quick succession works.
+        // These networks are all kept up by testCallback.
+        val agent10 = createNetworkAgent(realContext, initialLp = lp)
+        agent10.register()
+        agent10.unregisterAfterReplacement(5_000)
+
+        val agent11 = createNetworkAgent(realContext, initialLp = lp)
+        agent11.register()
+        agent11.unregisterAfterReplacement(5_000)
+
+        val agent12 = createNetworkAgent(realContext, initialLp = lp)
+        agent12.register()
+        agent12.unregisterAfterReplacement(5_000)
+
+        val agent13 = createNetworkAgent(realContext, initialLp = lp)
+        agent13.register()
+        agent13.markConnected()
+        testCallback.expectAvailableThenValidatedCallbacks(agent13.network!!)
+        sendAndExpectUdpPacket(agent13.network!!, reader, iface)
+        agent13.unregister()
     }
 
     @Test
@@ -1706,14 +1827,7 @@
                 it.underlyingNetworks = listOf()
             }
         }
-        val lp = LinkProperties().apply {
-            interfaceName = ifName
-            addLinkAddress(LinkAddress("2001:db8::1/64"))
-            addRoute(RouteInfo(IpPrefix("2001:db8::/64"), null /* nextHop */, ifName))
-            addRoute(RouteInfo(IpPrefix("::/0"),
-                    InetAddresses.parseNumericAddress("fe80::abcd"),
-                    ifName))
-        }
+        val lp = makeTestLinkProperties(ifName)
 
         // File a request containing the agent's specifier to receive callbacks and to ensure that
         // the agent is not torn down due to being unneeded.