Merge "[CTS] Fix CTS failure in mainline train" into android13-tests-dev
diff --git a/tests/cts/net/src/android/net/cts/DscpPolicyTest.kt b/tests/cts/net/src/android/net/cts/DscpPolicyTest.kt
index ad7e536..cfccb8b 100644
--- a/tests/cts/net/src/android/net/cts/DscpPolicyTest.kt
+++ b/tests/cts/net/src/android/net/cts/DscpPolicyTest.kt
@@ -16,10 +16,8 @@
 
 package android.net.cts
 
-import android.net.cts.util.CtsNetUtils.TestNetworkCallback
-
-import android.app.Instrumentation
 import android.Manifest.permission.MANAGE_TEST_NETWORKS
+import android.app.Instrumentation
 import android.content.Context
 import android.net.ConnectivityManager
 import android.net.DscpPolicy
@@ -27,8 +25,8 @@
 import android.net.IpPrefix
 import android.net.LinkAddress
 import android.net.LinkProperties
-import android.net.Network
 import android.net.MacAddress
+import android.net.Network
 import android.net.NetworkAgent
 import android.net.NetworkAgent.DSCP_POLICY_STATUS_DELETED
 import android.net.NetworkAgent.DSCP_POLICY_STATUS_SUCCESS
@@ -46,6 +44,7 @@
 import android.net.RouteInfo
 import android.net.TestNetworkInterface
 import android.net.TestNetworkManager
+import android.net.cts.util.CtsNetUtils.TestNetworkCallback
 import android.os.Build
 import android.os.HandlerThread
 import android.os.SystemClock
@@ -62,28 +61,24 @@
 import android.util.Range
 import androidx.test.InstrumentationRegistry
 import androidx.test.runner.AndroidJUnit4
+import com.android.net.module.util.IpUtils
 import com.android.net.module.util.NetworkStackConstants.ETHER_TYPE_IPV4
 import com.android.net.module.util.NetworkStackConstants.ETHER_TYPE_IPV6
 import com.android.net.module.util.Struct
 import com.android.net.module.util.structs.EthernetHeader
 import com.android.testutils.ArpResponder
 import com.android.testutils.CompatUtil
+import com.android.testutils.ConnectivityModuleTest
 import com.android.testutils.DevSdkIgnoreRule
-import com.android.testutils.assertParcelingIsLossless
 import com.android.testutils.RouterAdvertisementResponder
-import com.android.testutils.runAsShell
 import com.android.testutils.SC_V2
 import com.android.testutils.TapPacketReader
 import com.android.testutils.TestableNetworkAgent
-import com.android.testutils.TestableNetworkAgent.CallbackEntry.OnNetworkCreated
 import com.android.testutils.TestableNetworkAgent.CallbackEntry.OnDscpPolicyStatusUpdated
+import com.android.testutils.TestableNetworkAgent.CallbackEntry.OnNetworkCreated
 import com.android.testutils.TestableNetworkCallback
-import org.junit.After
-import org.junit.Assume.assumeTrue
-import org.junit.Before
-import org.junit.Rule
-import org.junit.Test
-import org.junit.runner.RunWith
+import com.android.testutils.assertParcelingIsLossless
+import com.android.testutils.runAsShell
 import java.net.Inet4Address
 import java.net.Inet6Address
 import java.net.InetSocketAddress
@@ -94,6 +89,12 @@
 import kotlin.test.assertNotNull
 import kotlin.test.assertTrue
 import kotlin.test.fail
+import org.junit.After
+import org.junit.Assume.assumeTrue
+import org.junit.Before
+import org.junit.Rule
+import org.junit.Test
+import org.junit.runner.RunWith
 
 private const val MAX_PACKET_LENGTH = 1500
 
@@ -108,6 +109,7 @@
 
 @AppModeFull(reason = "Instant apps cannot create test networks")
 @RunWith(AndroidJUnit4::class)
+@ConnectivityModuleTest
 class DscpPolicyTest {
     @JvmField
     @Rule
@@ -203,7 +205,7 @@
     @After
     fun tearDown() {
         if (!kernelIsAtLeast(5, 15)) {
-            return;
+            return
         }
         if (!this::iface.isInitialized) {
             // Test was skipped or crashed in setUp
@@ -242,33 +244,32 @@
 
     private fun waitForGlobalIpv6Address(network: Network): Inet6Address {
         // Wait for global IPv6 address to be available
-        val sock = Os.socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP)
-        network.bindSocket(sock)
-
         var inet6Addr: Inet6Address? = null
-        val timeout = SystemClock.elapsedRealtime() + PACKET_TIMEOUT_MS
-        while (timeout > SystemClock.elapsedRealtime()) {
+        val onLinkPrefix = raResponder.prefix
+        val startTime = SystemClock.elapsedRealtime()
+        while (SystemClock.elapsedRealtime() - startTime < PACKET_TIMEOUT_MS) {
+            SystemClock.sleep(1 /* ms */)
+            val sock = Os.socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP)
             try {
-                // Pick any arbitrary port
-                Os.connect(sock, TEST_TARGET_IPV6_ADDR, 12345)
-                val sockAddr = Os.getsockname(sock) as InetSocketAddress
+                network.bindSocket(sock)
 
-                // TODO: make RouterAdvertisementResponder.SLAAC_PREFIX public and use it here,
-                // or make it configurable and configure it here.
-                if (IpPrefix("2001:db8::/64").contains(sockAddr.address)) {
+                try {
+                    // Pick any arbitrary port
+                    Os.connect(sock, TEST_TARGET_IPV6_ADDR, 12345)
+                } catch (e: ErrnoException) {
+                    // There may not be an address available yet; retry until the timeout.
+                    if (e.errno == ENETUNREACH) continue
+                    throw e
+                }
+                val sockAddr = Os.getsockname(sock) as InetSocketAddress
+                if (onLinkPrefix.contains(sockAddr.address)) {
                     inet6Addr = sockAddr.address as Inet6Address
                     break
                 }
-            } catch (e: ErrnoException) {
-                // ignore ENETUNREACH -- there may not be an address available yet.
-                if (e.errno != ENETUNREACH) {
-                    Os.close(sock)
-                    throw e
-                }
+            } finally {
+                Os.close(sock)
             }
-            SystemClock.sleep(10 /* ms */)
         }
-        Os.close(sock)
         assertNotNull(inet6Addr)
         return inet6Addr!!
     }
@@ -296,7 +297,6 @@
         val lp = LinkProperties().apply {
             addLinkAddress(LinkAddress(LOCAL_IPV4_ADDRESS, IP4_PREFIX_LEN))
             addRoute(RouteInfo(IpPrefix("0.0.0.0/0"), null, null))
-            addRoute(RouteInfo(IpPrefix("::/0"), TEST_ROUTER_IPV6_ADDR))
             setInterfaceName(specifier)
         }
         val config = NetworkAgentConfig.Builder().build()
@@ -322,7 +322,7 @@
     fun sendPacket(
         agent: TestableNetworkAgent,
         sendV6: Boolean,
-        dstPort: Int = 0,
+        dstPort: Int = 0
     ) {
         val testString = "test string"
         val testPacket = ByteBuffer.wrap(testString.toByteArray(Charsets.UTF_8))
@@ -334,11 +334,14 @@
 
         val originalPacket = testPacket.readAsArray()
         Os.sendto(socket, originalPacket, 0 /* bytesOffset */, originalPacket.size, 0 /* flags */,
-                if(sendV6) TEST_TARGET_IPV6_ADDR else TEST_TARGET_IPV4_ADDR, dstPort)
+                if (sendV6) TEST_TARGET_IPV6_ADDR else TEST_TARGET_IPV4_ADDR, dstPort)
         Os.close(socket)
     }
 
-    fun parseV4PacketDscp(buffer : ByteBuffer) : Int {
+    fun parseV4PacketDscp(buffer: ByteBuffer): Int {
+        // Validate checksum before parsing packet.
+        val calCheck = IpUtils.ipChecksum(buffer, Struct.getSize(EthernetHeader::class.java))
+
         val ip_ver = buffer.get()
         val tos = buffer.get()
         val length = buffer.getShort()
@@ -347,10 +350,12 @@
         val ttl = buffer.get()
         val ipType = buffer.get()
         val checksum = buffer.getShort()
+
+        assertEquals(0, calCheck, "Invalid IPv4 header checksum")
         return tos.toInt().shr(2)
     }
 
-    fun parseV6PacketDscp(buffer : ByteBuffer) : Int {
+    fun parseV6PacketDscp(buffer: ByteBuffer): Int {
         val ip_ver = buffer.get()
         val tc = buffer.get()
         val fl = buffer.getShort()
@@ -364,9 +369,9 @@
     }
 
     fun parsePacketIp(
-        buffer : ByteBuffer,
-        sendV6 : Boolean,
-    ) : Boolean {
+        buffer: ByteBuffer,
+        sendV6: Boolean
+    ): Boolean {
         val ipAddr = if (sendV6) ByteArray(16) else ByteArray(4)
         buffer.get(ipAddr)
         val srcIp = if (sendV6) Inet6Address.getByAddress(ipAddr)
@@ -379,18 +384,18 @@
 
         if ((sendV6 && srcIp == srcAddressV6 && dstIp == TEST_TARGET_IPV6_ADDR) ||
                 (!sendV6 && srcIp == LOCAL_IPV4_ADDRESS && dstIp == TEST_TARGET_IPV4_ADDR)) {
-            Log.e(TAG, "IP return true");
+            Log.e(TAG, "IP return true")
             return true
         }
-        Log.e(TAG, "IP return false");
+        Log.e(TAG, "IP return false")
         return false
     }
 
     fun parsePacketPort(
-        buffer : ByteBuffer,
-        srcPort : Int,
-        dstPort : Int
-    ) : Boolean {
+        buffer: ByteBuffer,
+        srcPort: Int,
+        dstPort: Int
+    ): Boolean {
         if (srcPort == 0 && dstPort == 0) return true
 
         val packetSrcPort = buffer.getShort().toInt()
@@ -400,20 +405,20 @@
 
         if ((srcPort == 0 || (srcPort != 0 && srcPort == packetSrcPort)) &&
                 (dstPort == 0 || (dstPort != 0 && dstPort == packetDstPort))) {
-            Log.e(TAG, "Port return true");
+            Log.e(TAG, "Port return true")
             return true
         }
-        Log.e(TAG, "Port return false");
+        Log.e(TAG, "Port return false")
         return false
     }
 
     fun validatePacket(
-        agent : TestableNetworkAgent,
-        sendV6 : Boolean = false,
-        dscpValue : Int = 0,
-        dstPort : Int = 0,
+        agent: TestableNetworkAgent,
+        sendV6: Boolean = false,
+        dscpValue: Int = 0,
+        dstPort: Int = 0
     ) {
-        var packetFound = false;
+        var packetFound = false
         sendPacket(agent, sendV6, dstPort)
         // TODO: grab source port from socket in sendPacket
 
@@ -421,6 +426,7 @@
         val packets = generateSequence { reader.poll(PACKET_TIMEOUT_MS) }
         for (packet in packets) {
             val buffer = ByteBuffer.wrap(packet, 0, packet.size).order(ByteOrder.BIG_ENDIAN)
+
             // TODO: consider using Struct.parse for all packet parsing.
             val etherHdr = Struct.parse(EthernetHeader::class.java, buffer)
             val expectedType = if (sendV6) ETHER_TYPE_IPV6 else ETHER_TYPE_IPV4
@@ -464,6 +470,9 @@
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
         }
         validatePacket(agent, dscpValue = 1, dstPort = 4444)
+        // Send a second packet to validate that the stored BPF policy
+        // is correct for subsequent packets.
+        validatePacket(agent, dscpValue = 1, dstPort = 4444)
 
         agent.sendRemoveDscpPolicy(1)
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
@@ -502,6 +511,9 @@
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
         }
         validatePacket(agent, true, dscpValue = 1, dstPort = 4444)
+        // Send a second packet to validate that the stored BPF policy
+        // is correct for subsequent packets.
+        validatePacket(agent, true, dscpValue = 1, dstPort = 4444)
 
         agent.sendRemoveDscpPolicy(1)
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {