Merge "Use libnetjniutils for JNI File Descriptor info"
diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
index 85d0a2e..1d2f19a 100644
--- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
+++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
@@ -18,6 +18,8 @@
 import android.app.Instrumentation
 import android.content.Context
 import android.net.ConnectivityManager
+import android.net.InetAddresses
+import android.net.IpPrefix
 import android.net.KeepalivePacketData
 import android.net.LinkAddress
 import android.net.LinkProperties
@@ -34,9 +36,20 @@
 import android.net.NetworkAgent.VALID_NETWORK
 import android.net.NetworkAgentConfig
 import android.net.NetworkCapabilities
+import android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_CONGESTED
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_METERED
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED
+import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VPN
+import android.net.NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED
+import android.net.NetworkCapabilities.NET_CAPABILITY_TRUSTED
+import android.net.NetworkCapabilities.TRANSPORT_TEST
+import android.net.NetworkCapabilities.TRANSPORT_VPN
 import android.net.NetworkInfo
 import android.net.NetworkProvider
 import android.net.NetworkRequest
+import android.net.RouteInfo
 import android.net.SocketKeepalive
 import android.net.StringNetworkSpecifier
 import android.net.Uri
@@ -57,6 +70,7 @@
 import android.os.Looper
 import android.os.Message
 import android.os.Messenger
+import android.util.DebugUtils.valueToString
 import androidx.test.InstrumentationRegistry
 import androidx.test.runner.AndroidJUnit4
 import com.android.internal.util.AsyncChannel
@@ -80,7 +94,6 @@
 import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.mock
 import org.mockito.Mockito.verify
-import java.net.InetAddress
 import java.time.Duration
 import java.util.Arrays
 import java.util.UUID
@@ -122,8 +135,8 @@
     @Rule @JvmField
     val ignoreRule = DevSdkIgnoreRule(ignoreClassUpTo = Build.VERSION_CODES.Q)
 
-    private val LOCAL_IPV4_ADDRESS = InetAddress.parseNumericAddress("192.0.2.1")
-    private val REMOTE_IPV4_ADDRESS = InetAddress.parseNumericAddress("192.0.2.2")
+    private val LOCAL_IPV4_ADDRESS = InetAddresses.parseNumericAddress("192.0.2.1")
+    private val REMOTE_IPV4_ADDRESS = InetAddresses.parseNumericAddress("192.0.2.2")
 
     private val mCM = realContext.getSystemService(ConnectivityManager::class.java)
     private val mHandlerThread = HandlerThread("${javaClass.simpleName} handler thread")
@@ -314,22 +327,23 @@
     private fun createNetworkAgent(
         context: Context = realContext,
         name: String? = null,
-        nc: NetworkCapabilities = NetworkCapabilities(),
-        lp: LinkProperties = LinkProperties()
+        initialNc: NetworkCapabilities? = null,
+        initialLp: LinkProperties? = null
     ): TestableNetworkAgent {
-        nc.apply {
-            addTransportType(NetworkCapabilities.TRANSPORT_TEST)
-            removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
-            removeCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
-            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED)
-            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING)
-            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
+        val nc = initialNc ?: NetworkCapabilities().apply {
+            addTransportType(TRANSPORT_TEST)
+            removeCapability(NET_CAPABILITY_TRUSTED)
+            removeCapability(NET_CAPABILITY_INTERNET)
+            addCapability(NET_CAPABILITY_NOT_SUSPENDED)
+            addCapability(NET_CAPABILITY_NOT_ROAMING)
+            addCapability(NET_CAPABILITY_NOT_VPN)
             if (null != name) {
                 setNetworkSpecifier(StringNetworkSpecifier(name))
             }
         }
-        lp.apply {
-            addLinkAddress(LinkAddress(LOCAL_IPV4_ADDRESS, 0))
+        val lp = initialLp ?: LinkProperties().apply {
+            addLinkAddress(LinkAddress(LOCAL_IPV4_ADDRESS, 32))
+            addRoute(RouteInfo(IpPrefix("0.0.0.0/0"), null, null))
         }
         val config = NetworkAgentConfig.Builder().build()
         return TestableNetworkAgent(context, mHandlerThread.looper, nc, lp, config).also {
@@ -341,7 +355,7 @@
             Pair<TestableNetworkAgent, TestableNetworkCallback> {
         val request: NetworkRequest = NetworkRequest.Builder()
                 .clearCapabilities()
-                .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+                .addTransportType(TRANSPORT_TEST)
                 .build()
         val callback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
         requestNetwork(request, callback)
@@ -386,7 +400,7 @@
         val callbacks = thresholds.map { strength ->
             val request = NetworkRequest.Builder()
                     .clearCapabilities()
-                    .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+                    .addTransportType(TRANSPORT_TEST)
                     .setSignalStrength(strength)
                     .build()
             TestableNetworkCallback(DEFAULT_TIMEOUT_MS).also {
@@ -486,10 +500,10 @@
             it.getInterfaceName() == ifaceName
         }
         val nc = NetworkCapabilities(agent.nc)
-        nc.addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_METERED)
+        nc.addCapability(NET_CAPABILITY_NOT_METERED)
         agent.sendNetworkCapabilities(nc)
         callback.expectCapabilitiesThat(agent.network) {
-            it.hasCapability(NetworkCapabilities.NET_CAPABILITY_NOT_METERED)
+            it.hasCapability(NET_CAPABILITY_NOT_METERED)
         }
     }
 
@@ -503,12 +517,12 @@
         val name2 = UUID.randomUUID().toString()
         val request1 = NetworkRequest.Builder()
                 .clearCapabilities()
-                .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+                .addTransportType(TRANSPORT_TEST)
                 .setNetworkSpecifier(StringNetworkSpecifier(name1))
                 .build()
         val request2 = NetworkRequest.Builder()
                 .clearCapabilities()
-                .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+                .addTransportType(TRANSPORT_TEST)
                 .setNetworkSpecifier(StringNetworkSpecifier(name2))
                 .build()
         val callback1 = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
@@ -519,7 +533,7 @@
         // Then file the interesting request
         val request = NetworkRequest.Builder()
                 .clearCapabilities()
-                .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+                .addTransportType(TRANSPORT_TEST)
                 .build()
         val callback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
         requestNetwork(request, callback)
@@ -551,52 +565,78 @@
     @IgnoreUpTo(Build.VERSION_CODES.R)
     fun testSetUnderlyingNetworks() {
+        // Verifies that a VPN agent's capabilities follow its underlying networks: when no
+        // underlying networks are set (or they are set to null) the VPN inherits the default
+        // network's transports and some of its capabilities; with an explicitly empty list it
+        // only has its own TEST and VPN transports.
         val request = NetworkRequest.Builder()
-                .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
-                .addTransportType(NetworkCapabilities.TRANSPORT_VPN)
-                .removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
-                .removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED) // TODO: add to VPN!
+                .addTransportType(TRANSPORT_TEST)
+                .addTransportType(TRANSPORT_VPN)
+                .removeCapability(NET_CAPABILITY_NOT_VPN)
+                .removeCapability(NET_CAPABILITY_TRUSTED) // TODO: add to VPN!
                 .build()
-        val callback = TestableNetworkCallback()
-        mCM.registerNetworkCallback(request, callback)
+        // Pass DEFAULT_TIMEOUT_MS explicitly, like the other tests in this file, rather than
+        // relying on the callback's built-in default timeout.
+        val callback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
+        registerNetworkCallback(request, callback)
 
         val nc = NetworkCapabilities().apply {
-            addTransportType(NetworkCapabilities.TRANSPORT_TEST)
-            addTransportType(NetworkCapabilities.TRANSPORT_VPN)
-            removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
+            addTransportType(TRANSPORT_TEST)
+            addTransportType(TRANSPORT_VPN)
+            removeCapability(NET_CAPABILITY_NOT_VPN)
         }
         val defaultNetwork = mCM.activeNetwork
         assertNotNull(defaultNetwork)
-        val defaultNetworkTransports = mCM.getNetworkCapabilities(defaultNetwork).transportTypes
+        // Fail with an assertion rather than an NPE if no capabilities are returned.
+        val defaultNetworkCapabilities = assertNotNull(mCM.getNetworkCapabilities(defaultNetwork))
+        val defaultNetworkTransports = defaultNetworkCapabilities.transportTypes
 
-        val agent = createNetworkAgent(nc = nc)
+        val agent = createNetworkAgent(initialNc = nc)
         agent.register()
         agent.markConnected()
         callback.expectAvailableThenValidatedCallbacks(agent.network!!)
 
+        // Check that the default network's transports are propagated to the VPN.
         var vpnNc = mCM.getNetworkCapabilities(agent.network)
         assertNotNull(vpnNc)
-        assertTrue(NetworkCapabilities.TRANSPORT_VPN in vpnNc.transportTypes)
+
+        val testAndVpn = intArrayOf(TRANSPORT_TEST, TRANSPORT_VPN)
+        assertTrue(hasAllTransports(vpnNc, testAndVpn))
+        assertFalse(vpnNc.hasCapability(NET_CAPABILITY_NOT_VPN))
         assertTrue(hasAllTransports(vpnNc, defaultNetworkTransports),
                 "VPN transports ${Arrays.toString(vpnNc.transportTypes)}" +
                 " lacking transports from ${Arrays.toString(defaultNetworkTransports)}")
 
+        // Check that when no underlying networks are announced the underlying transport disappears.
         agent.setUnderlyingNetworks(listOf<Network>())
         callback.expectCapabilitiesThat(agent.network!!) {
-            it.transportTypes.size == 1 && it.hasTransport(NetworkCapabilities.TRANSPORT_VPN)
+            it.transportTypes.size == 2 && hasAllTransports(it, testAndVpn)
         }
 
-        val expectedTransports = (defaultNetworkTransports.toSet() +
-                NetworkCapabilities.TRANSPORT_VPN).toIntArray()
+        // Put the underlying network back and check that the underlying transport reappears.
+        val expectedTransports = (defaultNetworkTransports.toSet() + TRANSPORT_TEST + TRANSPORT_VPN)
+                .toIntArray()
         agent.setUnderlyingNetworks(null)
         callback.expectCapabilitiesThat(agent.network!!) {
             it.transportTypes.size == expectedTransports.size &&
                     hasAllTransports(it, expectedTransports)
         }
 
+        // Check that some underlying capabilities are propagated.
+        // This is not very accurate because the test does not control the capabilities of the
+        // underlying networks, and because not congested, not roaming, and not suspended are the
+        // default anyway. It's still useful as an extra check though.
+        vpnNc = assertNotNull(mCM.getNetworkCapabilities(agent.network))
+        for (cap in listOf(NET_CAPABILITY_NOT_CONGESTED,
+                NET_CAPABILITY_NOT_ROAMING,
+                NET_CAPABILITY_NOT_SUSPENDED)) {
+            if (defaultNetworkCapabilities.hasCapability(cap) && !vpnNc.hasCapability(cap)) {
+                val capStr = valueToString(NetworkCapabilities::class.java, "NET_CAPABILITY_", cap)
+                fail("$capStr not propagated from underlying: $defaultNetworkCapabilities")
+            }
+        }
+
         agent.unregister()
         callback.expectCallback<Lost>(agent.network)
-
-        mCM.unregisterNetworkCallback(callback)
     }
 
     @Test
@@ -687,7 +720,7 @@
         // First create a request to make sure the network is kept up
         val request1 = NetworkRequest.Builder()
                 .clearCapabilities()
-                .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+                .addTransportType(TRANSPORT_TEST)
                 .build()
         val callback1 = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS).also {
             registerNetworkCallback(request1, it)
@@ -697,7 +730,7 @@
         // Then file the interesting request
         val request = NetworkRequest.Builder()
                 .clearCapabilities()
-                .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+                .addTransportType(TRANSPORT_TEST)
                 .build()
         val callback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
         requestNetwork(request, callback)
@@ -708,18 +741,18 @@
 
             // Send TEMP_NOT_METERED and check that the callback is called appropriately.
             val nc1 = NetworkCapabilities(agent.nc)
-                    .addCapability(NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED)
+                    .addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
             agent.sendNetworkCapabilities(nc1)
             callback.expectCapabilitiesThat(agent.network) {
-                it.hasCapability(NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED)
+                it.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
             }
 
             // Remove TEMP_NOT_METERED and check that the callback is called appropriately.
             val nc2 = NetworkCapabilities(agent.nc)
-                    .removeCapability(NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED)
+                    .removeCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
             agent.sendNetworkCapabilities(nc2)
             callback.expectCapabilitiesThat(agent.network) {
-                !it.hasCapability(NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED)
+                !it.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
             }
         }