Add a test for prefilled memory slots

The test writes the prefilled memory slots to a counter region, and
triggers program execution by sending a ping packet.

Test: atest ApfIntegrationTest
Change-Id: Ib90ad52e8395ef1af3dff7e026fb2734d5ff3e64
diff --git a/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt b/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
index 6ce8b7c..fcba530 100644
--- a/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
+++ b/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
@@ -31,7 +31,9 @@
 import android.net.apf.ApfConstant.IPV6_NEXT_HEADER_OFFSET
 import android.net.apf.ApfV4Generator
 import android.net.apf.BaseApfGenerator
+import android.net.apf.BaseApfGenerator.MemorySlot
 import android.net.apf.BaseApfGenerator.Register.R0
+import android.net.apf.BaseApfGenerator.Register.R1
 import android.os.Build
 import android.os.Handler
 import android.os.HandlerThread
@@ -64,12 +66,14 @@
 import com.android.testutils.TestableNetworkCallback
 import com.android.testutils.runAsShell
 import com.android.testutils.waitForIdle
+import com.google.common.truth.Expect
 import com.google.common.truth.Truth.assertThat
 import com.google.common.truth.Truth.assertWithMessage
 import com.google.common.truth.TruthJUnit.assume
 import java.io.FileDescriptor
 import java.lang.Thread
 import java.net.InetSocketAddress
+import java.nio.ByteBuffer
 import java.util.concurrent.CompletableFuture
 import java.util.concurrent.TimeUnit
 import java.util.concurrent.TimeoutException
@@ -230,8 +234,8 @@
         }
     }
 
-    @get:Rule
-    val ignoreRule = DevSdkIgnoreRule()
+    @get:Rule val ignoreRule = DevSdkIgnoreRule()
+    @get:Rule val expect = Expect.create()
 
     private val cm by lazy { context.getSystemService(ConnectivityManager::class.java)!! }
     private val pm by lazy { context.packageManager }
@@ -364,6 +368,20 @@
         }
     }
 
+    fun ApfV4Generator.addPassIfNotIcmpv6EchoReply() {
+        // If not IPv6 -> PASS
+        addLoad16(R0, ETH_ETHERTYPE_OFFSET)
+        addJumpIfR0NotEquals(ETH_P_IPV6.toLong(), BaseApfGenerator.PASS_LABEL)
+
+        // If not ICMPv6 -> PASS
+        addLoad8(R0, IPV6_NEXT_HEADER_OFFSET)
+        addJumpIfR0NotEquals(IPPROTO_ICMPV6.toLong(), BaseApfGenerator.PASS_LABEL)
+
+        // If not echo reply -> PASS
+        addLoad8(R0, ICMP6_TYPE_OFFSET)
+        addJumpIfR0NotEquals(0x81, BaseApfGenerator.PASS_LABEL)
+    }
+
     @Test
     fun testDropPingReply() {
         assumeApfVersionSupportAtLeast(4)
@@ -381,17 +399,8 @@
         // Generate an APF program that drops the next ping
         gen = ApfV4Generator(caps.apfVersionSupported)
 
-        // If not IPv6 -> PASS
-        gen.addLoad16(R0, ETH_ETHERTYPE_OFFSET)
-        gen.addJumpIfR0NotEquals(ETH_P_IPV6.toLong(), BaseApfGenerator.PASS_LABEL)
-
-        // If not ICMPv6 -> PASS
-        gen.addLoad8(R0, IPV6_NEXT_HEADER_OFFSET)
-        gen.addJumpIfR0NotEquals(IPPROTO_ICMPV6.toLong(), BaseApfGenerator.PASS_LABEL)
-
-        // If not echo reply -> PASS
-        gen.addLoad8(R0, ICMP6_TYPE_OFFSET)
-        gen.addJumpIfR0NotEquals(0x81, BaseApfGenerator.PASS_LABEL)
+        // If not ICMPv6 Echo Reply -> PASS
+        gen.addPassIfNotIcmpv6EchoReply()
 
         // if not data matches -> PASS
         gen.addLoadImmediate(R0, ICMP6_TYPE_OFFSET + PING_HEADER_LENGTH)
@@ -407,4 +416,52 @@
         packetReader.sendPing(data)
         packetReader.expectPingDropped()
     }
+
+    // APF integration is mostly broken before V
+    @IgnoreUpTo(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
+    @Test
+    fun testPrefilledMemorySlotsV4() {
+        // Test v4 memory slots on both v4 and v6 interpreters.
+        assumeApfVersionSupportAtLeast(4)
+        // Clear the entire memory before starting this test
+        installProgram(ByteArray(caps.maximumApfProgramSize))
+        val gen = ApfV4Generator(4)
+
+        // If not ICMPv6 Echo Reply -> PASS
+        gen.addPassIfNotIcmpv6EchoReply()
+
+        // Store all prefilled memory slots in counter region [500, 520)
+        val counterRegion = 500
+        gen.addLoadImmediate(R1, counterRegion)
+        gen.addLoadFromMemory(R0, MemorySlot.PROGRAM_SIZE)
+        gen.addStoreData(R0, 0)
+        gen.addLoadFromMemory(R0, MemorySlot.RAM_LEN)
+        gen.addStoreData(R0, 4)
+        gen.addLoadFromMemory(R0, MemorySlot.IPV4_HEADER_SIZE)
+        gen.addStoreData(R0, 8)
+        gen.addLoadFromMemory(R0, MemorySlot.PACKET_SIZE)
+        gen.addStoreData(R0, 12)
+        gen.addLoadFromMemory(R0, MemorySlot.FILTER_AGE_SECONDS)
+        gen.addStoreData(R0, 16)
+
+        val program = gen.generate()
+        assertThat(program.size).isLessThan(counterRegion)
+        installProgram(program)
+        readProgram() // wait for install completion
+
+        // Trigger the program by sending a ping and waiting on the reply.
+        val data = ByteArray(56).also { Random.nextBytes(it) }
+        packetReader.sendPing(data)
+        packetReader.expectPingReply()
+
+        val readResult = readProgram()
+        val buffer = ByteBuffer.wrap(readResult)
+        buffer.position(counterRegion)
+        expect.withMessage("PROGRAM_SIZE").that(buffer.getInt()).isEqualTo(program.size)
+        expect.withMessage("RAM_LEN").that(buffer.getInt()).isEqualTo(caps.maximumApfProgramSize)
+        expect.withMessage("IPV4_HEADER_SIZE").that(buffer.getInt()).isEqualTo(0)
+        // Ping packet (64) + IPv6 header (40) + ethernet header (14)
+        expect.withMessage("PACKET_SIZE").that(buffer.getInt()).isEqualTo(64 + 40 + 14)
+        expect.withMessage("FILTER_AGE_SECONDS").that(buffer.getInt()).isLessThan(5)
+    }
 }