Add a CTS test for NetworkAgent#setUnderlyingNetworks.

Bug: 173331190
Test: atest CtsNetTestCases:NetworkAgentTest#testSetUnderlyingNetworks
Change-Id: I442a618d2d50eb15dbcb8926b60fc6fd0d5b2f3e
diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
index 45a84f8..85d0a2e 100644
--- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
+++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
@@ -82,6 +82,7 @@
 import org.mockito.Mockito.verify
 import java.net.InetAddress
 import java.time.Duration
+import java.util.Arrays
 import java.util.UUID
 import kotlin.test.assertEquals
 import kotlin.test.assertFailsWith
@@ -312,9 +313,11 @@
 
     private fun createNetworkAgent(
         context: Context = realContext,
-        name: String? = null
+        name: String? = null,
+        nc: NetworkCapabilities = NetworkCapabilities(),
+        lp: LinkProperties = LinkProperties()
     ): TestableNetworkAgent {
-        val nc = NetworkCapabilities().apply {
+        nc.apply {
             addTransportType(NetworkCapabilities.TRANSPORT_TEST)
             removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
             removeCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
@@ -325,7 +328,7 @@
                 setNetworkSpecifier(StringNetworkSpecifier(name))
             }
         }
-        val lp = LinkProperties().apply {
+        lp.apply {
             addLinkAddress(LinkAddress(LOCAL_IPV4_ADDRESS, 0))
         }
         val config = NetworkAgentConfig.Builder().build()
@@ -541,8 +544,63 @@
         // tearDown() will unregister the requests and agents
     }
 
+    private fun hasAllTransports(nc: NetworkCapabilities?, transports: IntArray) =
+            nc != null && transports.all { nc.hasTransport(it) }
+
     @Test
-    @IgnoreUpTo(android.os.Build.VERSION_CODES.R)
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    fun testSetUnderlyingNetworks() {
+        val request = NetworkRequest.Builder()
+                .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+                .addTransportType(NetworkCapabilities.TRANSPORT_VPN)
+                .removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
+                .removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED) // TODO: add to VPN!
+                .build()
+        val callback = TestableNetworkCallback()
+        mCM.registerNetworkCallback(request, callback)
+
+        val nc = NetworkCapabilities().apply {
+            addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+            addTransportType(NetworkCapabilities.TRANSPORT_VPN)
+            removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
+        }
+        val defaultNetwork = mCM.activeNetwork
+        assertNotNull(defaultNetwork)
+        val defaultNetworkTransports = mCM.getNetworkCapabilities(defaultNetwork).transportTypes
+
+        val agent = createNetworkAgent(nc = nc)
+        agent.register()
+        agent.markConnected()
+        callback.expectAvailableThenValidatedCallbacks(agent.network!!)
+
+        var vpnNc = mCM.getNetworkCapabilities(agent.network)
+        assertNotNull(vpnNc)
+        assertTrue(NetworkCapabilities.TRANSPORT_VPN in vpnNc.transportTypes)
+        assertTrue(hasAllTransports(vpnNc, defaultNetworkTransports),
+                "VPN transports ${Arrays.toString(vpnNc.transportTypes)}" +
+                " lacking transports from ${Arrays.toString(defaultNetworkTransports)}")
+
+        agent.setUnderlyingNetworks(listOf<Network>())
+        callback.expectCapabilitiesThat(agent.network!!) {
+            it.transportTypes.size == 1 && it.hasTransport(NetworkCapabilities.TRANSPORT_VPN)
+        }
+
+        val expectedTransports = (defaultNetworkTransports.toSet() +
+                NetworkCapabilities.TRANSPORT_VPN).toIntArray()
+        agent.setUnderlyingNetworks(null)
+        callback.expectCapabilitiesThat(agent.network!!) {
+            it.transportTypes.size == expectedTransports.size &&
+                    hasAllTransports(it, expectedTransports)
+        }
+
+        agent.unregister()
+        callback.expectCallback<Lost>(agent.network)
+
+        mCM.unregisterNetworkCallback(callback)
+    }
+
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.R)
     fun testAgentStartsInConnecting() {
         val mockContext = mock(Context::class.java)
         val mockCm = mock(ConnectivityManager::class.java)