Test onStartSocketKeepalive

Test: this
Bug: 139268426
Change-Id: I26f68fdf9b687a2d87c971525b1cd2c48d4579bb
Merged-In: I4e251fa0203a1888badef9ed90495fe8b3340a1c
(cherry-picked aosp/1258137)
diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
index d0e3023..32f2bfa 100644
--- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
+++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
@@ -18,29 +18,51 @@
 import android.app.Instrumentation
 import android.content.Context
 import android.net.ConnectivityManager
+import android.net.KeepalivePacketData
+import android.net.LinkAddress
 import android.net.LinkProperties
+import android.net.Network
 import android.net.NetworkAgent
+import android.net.NetworkAgent.CMD_ADD_KEEPALIVE_PACKET_FILTER
+import android.net.NetworkAgent.CMD_REMOVE_KEEPALIVE_PACKET_FILTER
+import android.net.NetworkAgent.CMD_START_SOCKET_KEEPALIVE
+import android.net.NetworkAgent.CMD_STOP_SOCKET_KEEPALIVE
 import android.net.NetworkAgentConfig
 import android.net.NetworkCapabilities
 import android.net.NetworkProvider
 import android.net.NetworkRequest
-import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnBandwidthUpdateRequested
-import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnNetworkUnwanted
+import android.net.SocketKeepalive
 import android.os.Build
+import android.os.Handler
 import android.os.HandlerThread
 import android.os.Looper
+import android.os.Message
+import android.os.Messenger
+import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnAddKeepalivePacketFilter
+import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnBandwidthUpdateRequested
+import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnNetworkUnwanted
+import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnRemoveKeepalivePacketFilter
+import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnStartSocketKeepalive
+import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnStopSocketKeepalive
 import androidx.test.InstrumentationRegistry
 import androidx.test.runner.AndroidJUnit4
+import com.android.internal.util.AsyncChannel
 import com.android.testutils.ArrayTrackRecord
 import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.RecorderCallback.CallbackEntry.Lost
 import com.android.testutils.TestableNetworkCallback
 import org.junit.After
+import org.junit.Assert.fail
 import org.junit.Before
 import org.junit.Rule
 import org.junit.Test
 import org.junit.runner.RunWith
+import java.net.InetAddress
+import java.time.Duration
+import kotlin.test.assertEquals
 import kotlin.test.assertFailsWith
+import kotlin.test.assertNotNull
+import kotlin.test.assertNull
 import kotlin.test.assertTrue
 
 // This test doesn't really have a constraint on how fast the methods should return. If it's
@@ -51,18 +73,29 @@
 // requests filed by the test and should never match normal internet requests. 70 is the default
 // score of Ethernet networks, it's as good a value as any other.
 private const val TEST_NETWORK_SCORE = 70
+private const val FAKE_NET_ID = 1098
 private val instrumentation: Instrumentation
     get() = InstrumentationRegistry.getInstrumentation()
 private val context: Context
     get() = InstrumentationRegistry.getContext()
+private fun Message(what: Int, arg1: Int, arg2: Int, obj: Any?) = Message.obtain().also {
+    it.what = what
+    it.arg1 = arg1
+    it.arg2 = arg2
+    it.obj = obj
+}
 
 @RunWith(AndroidJUnit4::class)
 class NetworkAgentTest {
     @Rule @JvmField
     val ignoreRule = DevSdkIgnoreRule(ignoreClassUpTo = Build.VERSION_CODES.Q)
 
+    private val LOCAL_IPV4_ADDRESS = InetAddress.parseNumericAddress("192.0.2.1")
+    private val REMOTE_IPV4_ADDRESS = InetAddress.parseNumericAddress("192.0.2.2")
+
     private val mCM = context.getSystemService(ConnectivityManager::class.java)
     private val mHandlerThread = HandlerThread("${javaClass.simpleName} handler thread")
+    private val mFakeConnectivityService by lazy { FakeConnectivityService(mHandlerThread.looper) }
 
     private class Provider(context: Context, looper: Looper) :
             NetworkProvider(context, looper, "NetworkAgentTest NetworkProvider")
@@ -84,8 +117,37 @@
         instrumentation.getUiAutomation().dropShellPermissionIdentity()
     }
 
-    private class TestableNetworkAgent(
-        looper: Looper,
+    /**
+     * A fake that helps simulating ConnectivityService talking to a harnessed agent.
+     * This fake only supports speaking to one harnessed agent at a time because it
+     * only keeps track of one async channel.
+     */
+    private class FakeConnectivityService(looper: Looper) {
+        private val msgHistory = ArrayTrackRecord<Message>().newReadHead()
+        private val asyncChannel = AsyncChannel()
+        private val handler = object : Handler(looper) {
+            override fun handleMessage(msg: Message) {
+                msgHistory.add(Message.obtain(msg)) // make a copy as the original will be recycled
+                when (msg.what) {
+                    AsyncChannel.CMD_CHANNEL_HALF_CONNECTED ->
+                        asyncChannel.sendMessage(AsyncChannel.CMD_CHANNEL_FULL_CONNECTION)
+                    AsyncChannel.CMD_CHANNEL_DISCONNECT, AsyncChannel.CMD_CHANNEL_DISCONNECTED ->
+                        fail("Agent unexpectedly disconnected")
+                }
+            }
+        }
+
+        fun connect(agentMsngr: Messenger) = asyncChannel.connect(context, handler, agentMsngr)
+
+        fun sendMessage(what: Int, arg1: Int = 0, arg2: Int = 0, obj: Any? = null) =
+            asyncChannel.sendMessage(Message(what, arg1, arg2, obj))
+
+        fun expectMessage(what: Int) =
+            assertNotNull(msgHistory.poll(DEFAULT_TIMEOUT_MS) { it.what == what })
+    }
+
+    private open class TestableNetworkAgent(
+        val looper: Looper,
         nc: NetworkCapabilities,
         lp: LinkProperties,
         conf: NetworkAgentConfig
@@ -96,6 +158,17 @@
         sealed class CallbackEntry {
             object OnBandwidthUpdateRequested : CallbackEntry()
             object OnNetworkUnwanted : CallbackEntry()
+            data class OnAddKeepalivePacketFilter(
+                val slot: Int,
+                val packet: KeepalivePacketData
+            ) : CallbackEntry()
+            data class OnRemoveKeepalivePacketFilter(val slot: Int) : CallbackEntry()
+            data class OnStartSocketKeepalive(
+                val slot: Int,
+                val interval: Int,
+                val packet: KeepalivePacketData
+            ) : CallbackEntry()
+            data class OnStopSocketKeepalive(val slot: Int) : CallbackEntry()
         }
 
         override fun onBandwidthUpdateRequested() {
@@ -106,9 +179,35 @@
             history.add(OnNetworkUnwanted)
         }
 
-        inline fun <reified T : CallbackEntry> expectCallback() {
+        override fun onAddKeepalivePacketFilter(slot: Int, packet: KeepalivePacketData) {
+            history.add(OnAddKeepalivePacketFilter(slot, packet))
+        }
+
+        override fun onRemoveKeepalivePacketFilter(slot: Int) {
+            history.add(OnRemoveKeepalivePacketFilter(slot))
+        }
+
+        override fun onStartSocketKeepalive(
+            slot: Int,
+            interval: Duration,
+            packet: KeepalivePacketData
+        ) {
+            history.add(OnStartSocketKeepalive(slot, interval.seconds.toInt(), packet))
+        }
+
+        override fun onStopSocketKeepalive(slot: Int) {
+            history.add(OnStopSocketKeepalive(slot))
+        }
+
+        inline fun <reified T : CallbackEntry> expectCallback(): T {
             val foundCallback = history.poll(DEFAULT_TIMEOUT_MS)
             assertTrue(foundCallback is T, "Expected ${T::class} but found $foundCallback")
+            return foundCallback
+        }
+
+        fun assertNoCallback() {
+            assertTrue(waitForIdle(DEFAULT_TIMEOUT_MS), "Handler never became idle")
+            assertNull(history.peek())
         }
     }
 
@@ -126,7 +225,9 @@
             addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING)
             addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
         }
-        val lp = LinkProperties()
+        val lp = LinkProperties().apply {
+            addLinkAddress(LinkAddress(LOCAL_IPV4_ADDRESS, 0))
+        }
         val config = NetworkAgentConfig.Builder().build()
         return TestableNetworkAgent(mHandlerThread.looper, nc, lp, config).also {
             agentsToCleanUp.add(it)
@@ -146,6 +247,10 @@
         return agent to callback
     }
 
+    private fun createNetworkAgentWithFakeCS() = createNetworkAgent().also {
+        mFakeConnectivityService.connect(it.registerForTest(Network(FAKE_NET_ID)))
+    }
+
     @Test
     fun testConnectAndUnregister() {
         val (agent, callback) = createConnectedNetworkAgent()
@@ -166,4 +271,48 @@
         agent.expectCallback<OnBandwidthUpdateRequested>()
         agent.unregister()
     }
+
+    @Test
+    fun testSocketKeepalive(): Unit = createNetworkAgentWithFakeCS().let { agent ->
+        val packet = object : KeepalivePacketData(
+                LOCAL_IPV4_ADDRESS /* srcAddress */, 1234 /* srcPort */,
+                REMOTE_IPV4_ADDRESS /* dstAddress */, 4567 /* dstPort */,
+                ByteArray(100 /* size */) { it.toByte() /* init */ }) {}
+        val slot = 4
+        val interval = 37
+
+        mFakeConnectivityService.sendMessage(CMD_ADD_KEEPALIVE_PACKET_FILTER,
+                arg1 = slot, obj = packet)
+        mFakeConnectivityService.sendMessage(CMD_START_SOCKET_KEEPALIVE,
+                arg1 = slot, arg2 = interval, obj = packet)
+
+        agent.expectCallback<OnAddKeepalivePacketFilter>().let {
+            assertEquals(it.slot, slot)
+            assertEquals(it.packet, packet)
+        }
+        agent.expectCallback<OnStartSocketKeepalive>().let {
+            assertEquals(it.slot, slot)
+            assertEquals(it.interval, interval)
+            assertEquals(it.packet, packet)
+        }
+
+        agent.assertNoCallback()
+
+        // Check that when the agent sends a keepalive event, ConnectivityService receives the
+        // expected message.
+        agent.sendSocketKeepaliveEvent(slot, SocketKeepalive.ERROR_UNSUPPORTED)
+        mFakeConnectivityService.expectMessage(NetworkAgent.EVENT_SOCKET_KEEPALIVE).let() {
+            assertEquals(slot, it.arg1)
+            assertEquals(SocketKeepalive.ERROR_UNSUPPORTED, it.arg2)
+        }
+
+        mFakeConnectivityService.sendMessage(CMD_STOP_SOCKET_KEEPALIVE, arg1 = slot)
+        mFakeConnectivityService.sendMessage(CMD_REMOVE_KEEPALIVE_PACKET_FILTER, arg1 = slot)
+        agent.expectCallback<OnStopSocketKeepalive>().let {
+            assertEquals(it.slot, slot)
+        }
+        agent.expectCallback<OnRemoveKeepalivePacketFilter>().let {
+            assertEquals(it.slot, slot)
+        }
+    }
 }