Merge "Enhance ApfIntegrationTest to test multiple packet transmission" into main
diff --git a/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt b/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
index 324078a..3ab6c0d 100644
--- a/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
+++ b/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
@@ -60,6 +60,8 @@
 import android.system.OsConstants
 import android.system.OsConstants.AF_INET6
 import android.system.OsConstants.ETH_P_IPV6
+import android.system.OsConstants.ICMP6_ECHO_REPLY
+import android.system.OsConstants.ICMP6_ECHO_REQUEST
 import android.system.OsConstants.IPPROTO_ICMPV6
 import android.system.OsConstants.SOCK_DGRAM
 import android.system.OsConstants.SOCK_NONBLOCK
@@ -212,8 +214,13 @@
             handler: Handler,
             private val network: Network
     ) : PacketReader(handler, RCV_BUFFER_SIZE) {
+        private data class PingContext(
+            val futureReply: CompletableFuture<List<ByteArray>>,
+            val expectReplyCount: Int,
+            val replyPayloads: MutableList<ByteArray> = mutableListOf()
+        )
         private var sockFd: FileDescriptor? = null
-        private var futureReply: CompletableFuture<ByteArray>? = null
+        private var pingContext: PingContext? = null
 
         override fun createFd(): FileDescriptor {
             // sockFd is closed by calling super.stop()
@@ -225,6 +232,8 @@
         }
 
         override fun handlePacket(recvbuf: ByteArray, length: Int) {
+            val context = pingContext ?: return
+
             // If zero-length or Type is not echo reply: ignore.
             if (length == 0 || recvbuf[0] != 0x81.toByte()) {
                 return
@@ -232,10 +241,14 @@
             // Only copy the ping data and complete the future.
             val result = recvbuf.sliceArray(8..<length)
             Log.i(TAG, "Received ping reply: ${result.toHexString()}")
-            futureReply!!.complete(recvbuf.sliceArray(8..<length))
+            context.replyPayloads.add(result)
+            if (context.replyPayloads.size == context.expectReplyCount) {
+                // Do not clear pingContext: expectPingReply() dereferences it afterwards.
+                context.futureReply.complete(context.replyPayloads.toList())
+            }
         }
 
-        fun sendPing(data: ByteArray, payloadSize: Int) {
+        fun sendPing(data: ByteArray, payloadSize: Int, expectReplyCount: Int = 1) {
             require(data.size == payloadSize)
 
             // rfc4443#section-4.1: Echo Request Message
@@ -251,17 +264,20 @@
             val icmp6Header = byteArrayOf(0x80.toByte(), 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
             val packet = icmp6Header + data
             Log.i(TAG, "Sent ping: ${packet.toHexString()}")
-            futureReply = CompletableFuture<ByteArray>()
+            pingContext = PingContext(
+                futureReply = CompletableFuture<List<ByteArray>>(),
+                expectReplyCount = expectReplyCount
+            )
             Os.sendto(sockFd!!, packet, 0, packet.size, 0, PING_DESTINATION)
         }
 
-        fun expectPingReply(timeoutMs: Long = TIMEOUT_MS): ByteArray {
-            return futureReply!!.get(timeoutMs, TimeUnit.MILLISECONDS)
+        fun expectPingReply(timeoutMs: Long = TIMEOUT_MS): List<ByteArray> {
+            return pingContext!!.futureReply.get(timeoutMs, TimeUnit.MILLISECONDS)
         }
 
         fun expectPingDropped() {
             assertFailsWith(TimeoutException::class) {
-                futureReply!!.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)
+                pingContext!!.futureReply.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)
             }
         }
 
@@ -503,7 +519,7 @@
         }
         val data = ByteArray(payloadSize).also { Random.nextBytes(it) }
         packetReader.sendPing(data, payloadSize)
-        assertThat(packetReader.expectPingReply()).isEqualTo(data)
+        assertThat(packetReader.expectPingReply()[0]).isEqualTo(data)
 
         // Generate an APF program that drops the next ping
         val gen = ApfV4Generator(
@@ -715,69 +731,76 @@
         //     increase PASSED_IPV6_ICMP counter
         //     pass
         //   else
-        //     transmit a ICMPv6 echo request packet with the first byte of the payload in the reply
+        //     transmit 3 ICMPv6 echo requests, each carrying a random 1-byte payload
         //     increase DROPPED_IPV6_NS_REPLIED_NON_DAD counter
         //     drop
-        val program = gen
-                .addLoad16(R0, ETH_ETHERTYPE_OFFSET)
+        gen.addLoad16(R0, ETH_ETHERTYPE_OFFSET)
                 .addJumpIfR0NotEquals(ETH_P_IPV6.toLong(), skipPacketLabel)
                 .addLoad8(R0, IPV6_NEXT_HEADER_OFFSET)
                 .addJumpIfR0NotEquals(IPPROTO_ICMPV6.toLong(), skipPacketLabel)
                 .addLoad8(R0, ICMP6_TYPE_OFFSET)
-                .addJumpIfR0NotEquals(0x81, skipPacketLabel) // Echo reply type
+                .addJumpIfR0NotEquals(ICMP6_ECHO_REPLY.toLong(), skipPacketLabel)
                 .addLoadFromMemory(R0, MemorySlot.PACKET_SIZE)
                 .addCountAndPassIfR0Equals(
-                        (ETHER_HEADER_LEN + IPV6_HEADER_LEN + PING_HEADER_LENGTH + firstByte.size)
-                                .toLong(),
-                        PASSED_IPV6_ICMP
+                    (ETHER_HEADER_LEN + IPV6_HEADER_LEN + PING_HEADER_LENGTH + firstByte.size)
+                        .toLong(),
+                    PASSED_IPV6_ICMP
                 )
-                // Ping Packet Generation
-                .addAllocate(pingRequestPktLen)
-                // Eth header
-                .addPacketCopy(ETHER_SRC_ADDR_OFFSET, ETHER_ADDR_LEN) // dst MAC address
-                .addPacketCopy(ETHER_DST_ADDR_OFFSET, ETHER_ADDR_LEN) // src MAC address
-                .addWriteU16(ETH_P_IPV6) // IPv6 type
-                // IPv6 Header
-                .addWrite32(0x60000000) // IPv6 Header: version, traffic class, flowlabel
-                // payload length (2 bytes) | next header: ICMPv6 (1 byte) | hop limit (1 byte)
-                .addWrite32(pingRequestIpv6PayloadLen shl 16 or (IPPROTO_ICMPV6 shl 8 or 64))
-                .addPacketCopy(IPV6_DEST_ADDR_OFFSET, IPV6_ADDR_LEN) // src ip
-                .addPacketCopy(IPV6_SRC_ADDR_OFFSET, IPV6_ADDR_LEN) // dst ip
-                // ICMPv6
-                .addWriteU8(0x80) // type: echo request
-                .addWriteU8(0) // code
-                .addWriteU16(pingRequestIpv6PayloadLen) // checksum
-                // identifier
-                .addPacketCopy(ETHER_HEADER_LEN + IPV6_HEADER_LEN + ICMPV6_HEADER_MIN_LEN, 2)
-                .addWriteU16(0) // sequence number
-                .addDataCopy(firstByte) // data
-                .addTransmitL4(
+
+        val numOfPacketToTransmit = 3
+        val expectReplyPayloads = (0 until numOfPacketToTransmit).map { Random.nextBytes(1) }
+        expectReplyPayloads.forEach { replyPingPayload ->
+            // Ping Packet Generation
+            gen.addAllocate(pingRequestPktLen)
+                    // Eth header
+                    .addPacketCopy(ETHER_SRC_ADDR_OFFSET, ETHER_ADDR_LEN) // dst MAC address
+                    .addPacketCopy(ETHER_DST_ADDR_OFFSET, ETHER_ADDR_LEN) // src MAC address
+                    .addWriteU16(ETH_P_IPV6) // IPv6 type
+                    // IPv6 Header
+                    .addWrite32(0x60000000) // IPv6 Header: version, traffic class, flowlabel
+                    // payload length (2 bytes) | next header: ICMPv6 (1 byte) | hop limit (1 byte)
+                    .addWrite32(pingRequestIpv6PayloadLen shl 16 or (IPPROTO_ICMPV6 shl 8 or 64))
+                    .addPacketCopy(IPV6_DEST_ADDR_OFFSET, IPV6_ADDR_LEN) // src ip
+                    .addPacketCopy(IPV6_SRC_ADDR_OFFSET, IPV6_ADDR_LEN) // dst ip
+                    // ICMPv6
+                    .addWriteU8(ICMP6_ECHO_REQUEST)
+                    .addWriteU8(0) // code
+                    .addWriteU16(pingRequestIpv6PayloadLen) // checksum
+                    // identifier
+                    .addPacketCopy(ETHER_HEADER_LEN + IPV6_HEADER_LEN + ICMPV6_HEADER_MIN_LEN, 2)
+                    .addWriteU16(0) // sequence number
+                    .addDataCopy(replyPingPayload) // data
+                    .addTransmitL4(
                         ETHER_HEADER_LEN, // ip_ofs
                         ICMP6_CHECKSUM_OFFSET, // csum_ofs
                         IPV6_SRC_ADDR_OFFSET, // csum_start
                         IPPROTO_ICMPV6, // partial_sum
                         false // udp
-                )
-                // Warning: the program abuse DROPPED_IPV6_NS_REPLIED_NON_DAD for debugging purpose
-                .addCountAndDrop(DROPPED_IPV6_NS_REPLIED_NON_DAD)
-                .defineLabel(skipPacketLabel)
-                .addPass()
-                .generate()
+                    )
+        }
 
+        // Warning: the program abuses DROPPED_IPV6_NS_REPLIED_NON_DAD for debugging purposes
+        gen.addCountAndDrop(DROPPED_IPV6_NS_REPLIED_NON_DAD)
+            .defineLabel(skipPacketLabel)
+            .addPass()
+
+        val program = gen.generate()
         installAndVerifyProgram(program)
 
-        packetReader.sendPing(payload, payloadSize)
-
-        val replyPayload = try {
+        packetReader.sendPing(payload, payloadSize, expectReplyCount = numOfPacketToTransmit)
+        val replyPayloads = try {
             packetReader.expectPingReply(TIMEOUT_MS * 2)
         } catch (e: TimeoutException) {
-            byteArrayOf() // Empty payload if timeout occurs
+            emptyList()
         }
 
         val apfCounterTracker = ApfCounterTracker()
         apfCounterTracker.updateCountersFromData(readProgram())
         Log.i(TAG, "counter map: ${apfCounterTracker.counters}")
 
-        assertThat(replyPayload).isEqualTo(firstByte)
+        assertThat(replyPayloads.size).isEqualTo(expectReplyPayloads.size)
+        for (i in replyPayloads.indices) {
+            assertThat(replyPayloads[i]).isEqualTo(expectReplyPayloads[i])
+        }
     }
 }