Merge "Make sure agents start in the CONNECTING state"
diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
index d2ca3f8..7508228 100644
--- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
+++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
@@ -34,6 +34,7 @@
 import android.net.NetworkAgent.VALID_NETWORK
 import android.net.NetworkAgentConfig
 import android.net.NetworkCapabilities
+import android.net.NetworkInfo
 import android.net.NetworkProvider
 import android.net.NetworkRequest
 import android.net.SocketKeepalive
@@ -71,6 +72,13 @@
 import org.junit.Rule
 import org.junit.Test
 import org.junit.runner.RunWith
+import org.mockito.ArgumentMatchers.any
+import org.mockito.ArgumentMatchers.anyInt
+import org.mockito.ArgumentMatchers.argThat
+import org.mockito.ArgumentMatchers.eq
+import org.mockito.Mockito.doReturn
+import org.mockito.Mockito.mock
+import org.mockito.Mockito.verify
 import java.net.InetAddress
 import java.time.Duration
 import java.util.UUID
@@ -98,7 +106,7 @@
 private const val FAKE_NET_ID = 1098
 private val instrumentation: Instrumentation
     get() = InstrumentationRegistry.getInstrumentation()
-private val context: Context
+private val realContext: Context // Real (non-mock) instrumentation context
     get() = InstrumentationRegistry.getContext()
 private fun Message(what: Int, arg1: Int, arg2: Int, obj: Any?) = Message.obtain().also {
     it.what = what
@@ -115,7 +123,7 @@
     private val LOCAL_IPV4_ADDRESS = InetAddress.parseNumericAddress("192.0.2.1")
     private val REMOTE_IPV4_ADDRESS = InetAddress.parseNumericAddress("192.0.2.2")
 
-    private val mCM = context.getSystemService(ConnectivityManager::class.java)
+    private val mCM = realContext.getSystemService(ConnectivityManager::class.java)
     private val mHandlerThread = HandlerThread("${javaClass.simpleName} handler thread")
     private val mFakeConnectivityService by lazy { FakeConnectivityService(mHandlerThread.looper) }
 
@@ -166,7 +174,7 @@
             }
         }
 
-        fun connect(agentMsngr: Messenger) = asyncChannel.connect(context, handler, agentMsngr)
+        fun connect(agentMsngr: Messenger) = asyncChannel.connect(realContext, handler, agentMsngr)
 
         fun disconnect() = asyncChannel.disconnect()
 
@@ -180,6 +188,7 @@
     }
 
     private open class TestableNetworkAgent(
+        context: Context,
         looper: Looper,
         val nc: NetworkCapabilities,
         val lp: LinkProperties,
@@ -300,7 +309,10 @@
         callbacksToCleanUp.add(callback)
     }
 
-    private fun createNetworkAgent(name: String? = null): TestableNetworkAgent {
+    private fun createNetworkAgent(
+        context: Context = realContext,
+        name: String? = null
+    ): TestableNetworkAgent {
         val nc = NetworkCapabilities().apply {
             addTransportType(NetworkCapabilities.TRANSPORT_TEST)
             removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
@@ -316,12 +328,12 @@
             addLinkAddress(LinkAddress(LOCAL_IPV4_ADDRESS, 0))
         }
         val config = NetworkAgentConfig.Builder().build()
-        return TestableNetworkAgent(mHandlerThread.looper, nc, lp, config).also {
+        return TestableNetworkAgent(context, mHandlerThread.looper, nc, lp, config).also {
             agentsToCleanUp.add(it)
         }
     }
 
-    private fun createConnectedNetworkAgent(name: String? = null):
+    private fun createConnectedNetworkAgent(context: Context = realContext, name: String? = null):
             Pair<TestableNetworkAgent, TestableNetworkCallback> {
         val request: NetworkRequest = NetworkRequest.Builder()
                 .clearCapabilities()
@@ -329,7 +341,7 @@
                 .build()
         val callback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
         requestNetwork(request, callback)
-        val agent = createNetworkAgent(name)
+        val agent = createNetworkAgent(context, name)
         agent.register()
         agent.markConnected()
         return agent to callback
@@ -509,13 +521,13 @@
         requestNetwork(request, callback)
 
         // Connect the first Network
-        createConnectedNetworkAgent(name1).let { (agent1, _) ->
+        createConnectedNetworkAgent(name = name1).let { (agent1, _) ->
             callback.expectAvailableThenValidatedCallbacks(agent1.network)
             // Upgrade agent1 to a better score so that there is no ambiguity when
             // agent2 connects that agent1 is still better
             agent1.sendNetworkScore(BETTER_NETWORK_SCORE - 1)
             // Connect the second agent
-            createConnectedNetworkAgent(name2).let { (agent2, _) ->
+            createConnectedNetworkAgent(name = name2).let { (agent2, _) ->
                 agent2.markConnected()
                 // The callback should not see anything yet
                 callback.assertNoCallback(NO_CALLBACK_TIMEOUT)
@@ -529,6 +541,21 @@
     }
 
     @Test
+    fun testAgentStartsInConnecting() { // registerNetworkAgent must see CONNECTING
+        val mockContext = mock(Context::class.java)
+        val mockCm = mock(ConnectivityManager::class.java)
+        doReturn(mockCm).`when`(mockContext).getSystemService(Context.CONNECTIVITY_SERVICE)
+        createNetworkAgent(mockContext).register() // register only; no real-CM request needed
+        verify(mockCm).registerNetworkAgent(any(Messenger::class.java),
+                argThat<NetworkInfo> { it.detailedState == NetworkInfo.DetailedState.CONNECTING },
+                any(LinkProperties::class.java),
+                any(NetworkCapabilities::class.java),
+                anyInt() /* score */,
+                any(NetworkAgentConfig::class.java),
+                eq(NetworkProvider.ID_NONE))
+    }
+
+    @Test
     fun testSetAcceptUnvalidated() {
         createNetworkAgentWithFakeCS().let { agent ->
             mFakeConnectivityService.sendMessage(CMD_SAVE_ACCEPT_UNVALIDATED, 1)