Test accept unvalidated

Test: this
Bug: 139268426
Change-Id: I9343f72e1b1f4752e9781ff9b44e2a561d166cfb
Merged-In: I3326a2119d66e67566fce0268ea4861729b1c64c
(cherry-picked from aosp/1284557)
diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
index 32f2bfa..14f52e1 100644
--- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
+++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
@@ -24,7 +24,9 @@
 import android.net.Network
 import android.net.NetworkAgent
 import android.net.NetworkAgent.CMD_ADD_KEEPALIVE_PACKET_FILTER
+import android.net.NetworkAgent.CMD_PREVENT_AUTOMATIC_RECONNECT
 import android.net.NetworkAgent.CMD_REMOVE_KEEPALIVE_PACKET_FILTER
+import android.net.NetworkAgent.CMD_SAVE_ACCEPT_UNVALIDATED
 import android.net.NetworkAgent.CMD_START_SOCKET_KEEPALIVE
 import android.net.NetworkAgent.CMD_STOP_SOCKET_KEEPALIVE
 import android.net.NetworkAgentConfig
@@ -39,9 +41,11 @@
 import android.os.Message
 import android.os.Messenger
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnAddKeepalivePacketFilter
+import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnAutomaticReconnectDisabled
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnBandwidthUpdateRequested
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnNetworkUnwanted
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnRemoveKeepalivePacketFilter
+import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnSaveAcceptUnvalidated
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnStartSocketKeepalive
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnStopSocketKeepalive
 import androidx.test.InstrumentationRegistry
@@ -60,6 +64,7 @@
 import java.net.InetAddress
 import java.time.Duration
 import kotlin.test.assertEquals
+import kotlin.test.assertFalse
 import kotlin.test.assertFailsWith
 import kotlin.test.assertNotNull
 import kotlin.test.assertNull
@@ -123,27 +128,38 @@
      * only keeps track of one async channel.
      */
     private class FakeConnectivityService(looper: Looper) {
+        private val CMD_EXPECT_DISCONNECT = 1
+        private var disconnectExpected = false
         private val msgHistory = ArrayTrackRecord<Message>().newReadHead()
         private val asyncChannel = AsyncChannel()
         private val handler = object : Handler(looper) {
             override fun handleMessage(msg: Message) {
                 msgHistory.add(Message.obtain(msg)) // make a copy as the original will be recycled
                 when (msg.what) {
+                    CMD_EXPECT_DISCONNECT -> disconnectExpected = true
                     AsyncChannel.CMD_CHANNEL_HALF_CONNECTED ->
                         asyncChannel.sendMessage(AsyncChannel.CMD_CHANNEL_FULL_CONNECTION)
-                    AsyncChannel.CMD_CHANNEL_DISCONNECT, AsyncChannel.CMD_CHANNEL_DISCONNECTED ->
-                        fail("Agent unexpectedly disconnected")
+                    AsyncChannel.CMD_CHANNEL_DISCONNECTED ->
+                        if (!disconnectExpected) {
+                            fail("Agent unexpectedly disconnected")
+                        } else {
+                            disconnectExpected = false
+                        }
                 }
             }
         }
 
         fun connect(agentMsngr: Messenger) = asyncChannel.connect(context, handler, agentMsngr)
 
+        fun disconnect() = asyncChannel.disconnect()
+
         fun sendMessage(what: Int, arg1: Int = 0, arg2: Int = 0, obj: Any? = null) =
             asyncChannel.sendMessage(Message(what, arg1, arg2, obj))
 
         fun expectMessage(what: Int) =
             assertNotNull(msgHistory.poll(DEFAULT_TIMEOUT_MS) { it.what == what })
+
+        fun willExpectDisconnectOnce() = handler.sendEmptyMessage(CMD_EXPECT_DISCONNECT)
     }
 
     private open class TestableNetworkAgent(
@@ -169,6 +185,8 @@
                 val packet: KeepalivePacketData
             ) : CallbackEntry()
             data class OnStopSocketKeepalive(val slot: Int) : CallbackEntry()
+            data class OnSaveAcceptUnvalidated(val accept: Boolean) : CallbackEntry()
+            object OnAutomaticReconnectDisabled : CallbackEntry()
         }
 
         override fun onBandwidthUpdateRequested() {
@@ -199,6 +217,14 @@
             history.add(OnStopSocketKeepalive(slot))
         }
 
+        override fun onSaveAcceptUnvalidated(accept: Boolean) {
+            history.add(OnSaveAcceptUnvalidated(accept))
+        }
+
+        override fun onAutomaticReconnectDisabled() {
+            history.add(OnAutomaticReconnectDisabled)
+        }
+
         inline fun <reified T : CallbackEntry> expectCallback(): T {
             val foundCallback = history.poll(DEFAULT_TIMEOUT_MS)
             assertTrue(foundCallback is T, "Expected ${T::class} but found $foundCallback")
@@ -315,4 +341,40 @@
             assertEquals(it.slot, slot)
         }
     }
+
+    @Test
+    fun testSetAcceptUnvalidated() {
+        createNetworkAgentWithFakeCS().let { agent ->
+            mFakeConnectivityService.sendMessage(CMD_SAVE_ACCEPT_UNVALIDATED, 1)
+            agent.expectCallback<OnSaveAcceptUnvalidated>().let {
+                assertTrue(it.accept)
+            }
+            agent.assertNoCallback()
+        }
+        createNetworkAgentWithFakeCS().let { agent ->
+            mFakeConnectivityService.sendMessage(CMD_SAVE_ACCEPT_UNVALIDATED, 0)
+            mFakeConnectivityService.sendMessage(CMD_PREVENT_AUTOMATIC_RECONNECT)
+            agent.expectCallback<OnSaveAcceptUnvalidated>().let {
+                assertFalse(it.accept)
+            }
+            agent.expectCallback<OnAutomaticReconnectDisabled>()
+            agent.assertNoCallback()
+            // When automatic reconnect is turned off, the network is torn down and
+            // ConnectivityService sends a disconnect. This in turn causes the agent
+            // to send a DISCONNECTED message to CS.
+            mFakeConnectivityService.willExpectDisconnectOnce()
+            mFakeConnectivityService.disconnect()
+            mFakeConnectivityService.expectMessage(AsyncChannel.CMD_CHANNEL_DISCONNECTED)
+            agent.expectCallback<OnNetworkUnwanted>()
+        }
+        createNetworkAgentWithFakeCS().let { agent ->
+            mFakeConnectivityService.sendMessage(CMD_PREVENT_AUTOMATIC_RECONNECT)
+            agent.expectCallback<OnAutomaticReconnectDisabled>()
+            agent.assertNoCallback()
+            mFakeConnectivityService.willExpectDisconnectOnce()
+            mFakeConnectivityService.disconnect()
+            mFakeConnectivityService.expectMessage(AsyncChannel.CMD_CHANNEL_DISCONNECTED)
+            agent.expectCallback<OnNetworkUnwanted>()
+        }
+    }
 }