Add Ikev2VpnTests including IKE negotiation.

This commit expands IKEv2 VPN CTS testing to ensure that given a
successful IKEv2 negotiation, the VPN network will be correctly set up.
Additionally, it verifies that the stopProvisionedVpnProfile will
teardown the VPN network.

Bug: 148582947
Test: atest CtsNetTestCases:Ikev2VpnTest
Change-Id: Ib6635f0068200ac0172515989fbdee5c3d49e231
diff --git a/tests/cts/net/src/android/net/cts/IkeTunUtils.java b/tests/cts/net/src/android/net/cts/IkeTunUtils.java
new file mode 100644
index 0000000..fc25292
--- /dev/null
+++ b/tests/cts/net/src/android/net/cts/IkeTunUtils.java
@@ -0,0 +1,188 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.net.cts;
+
+import static android.net.cts.PacketUtils.BytePayload;
+import static android.net.cts.PacketUtils.IP4_HDRLEN;
+import static android.net.cts.PacketUtils.IP6_HDRLEN;
+import static android.net.cts.PacketUtils.IpHeader;
+import static android.net.cts.PacketUtils.UDP_HDRLEN;
+import static android.net.cts.PacketUtils.UdpHeader;
+import static android.net.cts.PacketUtils.getIpHeader;
+import static android.system.OsConstants.IPPROTO_UDP;
+
+import android.os.ParcelFileDescriptor;
+
+import java.net.InetAddress;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+
+// TODO: Merge this with the version in the IPsec module (IKEv2 library) CTS tests.
+/** An extension of the TunUtils class with IKE-specific packet handling. */
+public class IkeTunUtils extends TunUtils {
+    private static final int PORT_LEN = 2;
+
+    private static final byte[] NON_ESP_MARKER = new byte[] {0, 0, 0, 0};
+
+    private static final int IKE_HEADER_LEN = 28;
+    private static final int IKE_SPI_LEN = 8;
+    private static final int IKE_IS_RESP_BYTE_OFFSET = 19;
+    private static final int IKE_MSG_ID_OFFSET = 20;
+    private static final int IKE_MSG_ID_LEN = 4;
+
+    public IkeTunUtils(ParcelFileDescriptor tunFd) {
+        super(tunFd);
+    }
+
+    /**
+     * Await an expected IKE request and inject an IKE response.
+     *
+     * @param respIkePkt IKE response packet without IP/UDP headers or NON ESP MARKER.
+     */
+    public byte[] awaitReqAndInjectResp(long expectedInitIkeSpi, int expectedMsgId,
+            boolean encapExpected, byte[] respIkePkt) throws Exception {
+        final byte[] request = awaitIkePacket(expectedInitIkeSpi, expectedMsgId, encapExpected);
+
+        // Build response header by flipping address and port
+        final InetAddress srcAddr = getDstAddress(request);
+        final InetAddress dstAddr = getSrcAddress(request);
+        final int srcPort = getDstPort(request);
+        final int dstPort = getSrcPort(request);
+
+        final byte[] response =
+                buildIkePacket(srcAddr, dstAddr, srcPort, dstPort, encapExpected, respIkePkt);
+        injectPacket(response);
+        return request;
+    }
+
+    private byte[] awaitIkePacket(long expectedInitIkeSpi, int expectedMsgId, boolean expectEncap)
+            throws Exception {
+        return super.awaitPacket(pkt -> isIke(pkt, expectedInitIkeSpi, expectedMsgId, expectEncap));
+    }
+
+    private static boolean isIke(
+            byte[] pkt, long expectedInitIkeSpi, int expectedMsgId, boolean encapExpected) {
+        final int ipProtocolOffset;
+        final int ikeOffset;
+
+        if (isIpv6(pkt)) {
+            ipProtocolOffset = IP6_PROTO_OFFSET;
+            ikeOffset = IP6_HDRLEN + UDP_HDRLEN;
+        } else {
+            if (encapExpected && !hasNonEspMarkerv4(pkt)) {
+                return false;
+            }
+
+            // Use default IPv4 header length (assuming no options)
+            final int encapMarkerLen = encapExpected ? NON_ESP_MARKER.length : 0;
+            ipProtocolOffset = IP4_PROTO_OFFSET;
+            ikeOffset = IP4_HDRLEN + UDP_HDRLEN + encapMarkerLen;
+        }
+
+        return pkt[ipProtocolOffset] == IPPROTO_UDP
+                && areSpiAndMsgIdEqual(pkt, ikeOffset, expectedInitIkeSpi, expectedMsgId);
+    }
+
+    /** Checks if the provided IPv4 packet has a UDP-encapsulation NON-ESP marker */
+    private static boolean hasNonEspMarkerv4(byte[] ipv4Pkt) {
+        final int nonEspMarkerOffset = IP4_HDRLEN + UDP_HDRLEN;
+        if (ipv4Pkt.length < nonEspMarkerOffset + NON_ESP_MARKER.length) {
+            return false;
+        }
+
+        final byte[] nonEspMarker = Arrays.copyOfRange(
+                ipv4Pkt, nonEspMarkerOffset, nonEspMarkerOffset + NON_ESP_MARKER.length);
+        return Arrays.equals(NON_ESP_MARKER, nonEspMarker);
+    }
+
+    private static boolean areSpiAndMsgIdEqual(
+            byte[] pkt, int ikeOffset, long expectedIkeInitSpi, int expectedMsgId) {
+        if (pkt.length <= ikeOffset + IKE_HEADER_LEN) {
+            return false;
+        }
+
+        final ByteBuffer buffer = ByteBuffer.wrap(pkt);
+        final long spi = buffer.getLong(ikeOffset);
+        final int msgId = buffer.getInt(ikeOffset + IKE_MSG_ID_OFFSET);
+
+        return expectedIkeInitSpi == spi && expectedMsgId == msgId;
+    }
+
+    private static InetAddress getSrcAddress(byte[] pkt) throws Exception {
+        return getAddress(pkt, true);
+    }
+
+    private static InetAddress getDstAddress(byte[] pkt) throws Exception {
+        return getAddress(pkt, false);
+    }
+
+    private static InetAddress getAddress(byte[] pkt, boolean getSrcAddr) throws Exception {
+        final int ipLen = isIpv6(pkt) ? IP6_ADDR_LEN : IP4_ADDR_LEN;
+        final int srcIpOffset = isIpv6(pkt) ? IP6_ADDR_OFFSET : IP4_ADDR_OFFSET;
+        final int ipOffset = getSrcAddr ? srcIpOffset : srcIpOffset + ipLen;
+
+        if (pkt.length < ipOffset + ipLen) {
+            // Should be impossible; getAddress() is only called with a full IKE request including
+            // the IP and UDP headers.
+            throw new IllegalArgumentException("Packet was too short to contain IP address");
+        }
+
+        return InetAddress.getByAddress(Arrays.copyOfRange(pkt, ipOffset, ipOffset + ipLen));
+    }
+
+    private static int getSrcPort(byte[] pkt) throws Exception {
+        return getPort(pkt, true);
+    }
+
+    private static int getDstPort(byte[] pkt) throws Exception {
+        return getPort(pkt, false);
+    }
+
+    private static int getPort(byte[] pkt, boolean getSrcPort) {
+        final int srcPortOffset = isIpv6(pkt) ? IP6_HDRLEN : IP4_HDRLEN;
+        final int portOffset = getSrcPort ? srcPortOffset : srcPortOffset + PORT_LEN;
+
+        if (pkt.length < portOffset + PORT_LEN) {
+            // Should be impossible; getPort() is only called with a full IKE request including the
+            // IP and UDP headers.
+            throw new IllegalArgumentException("Packet was too short to contain port");
+        }
+
+        final ByteBuffer buffer = ByteBuffer.wrap(pkt);
+        return Short.toUnsignedInt(buffer.getShort(portOffset));
+    }
+
+    private static byte[] buildIkePacket(
+            InetAddress srcAddr,
+            InetAddress dstAddr,
+            int srcPort,
+            int dstPort,
+            boolean useEncap,
+            byte[] payload)
+            throws Exception {
+        // Append non-ESP marker if encap is enabled
+        if (useEncap) {
+            final ByteBuffer buffer = ByteBuffer.allocate(NON_ESP_MARKER.length + payload.length);
+            buffer.put(NON_ESP_MARKER);
+            buffer.put(payload);
+            payload = buffer.array();
+        }
+
+        final UdpHeader udpPkt = new UdpHeader(srcPort, dstPort, new BytePayload(payload));
+        final IpHeader ipPkt = getIpHeader(udpPkt.getProtocolId(), srcAddr, dstAddr, udpPkt);
+        return ipPkt.getPacketBytes();
+    }
+}
diff --git a/tests/cts/net/src/android/net/cts/Ikev2VpnTest.java b/tests/cts/net/src/android/net/cts/Ikev2VpnTest.java
index 8c1cbbb..ebce513 100644
--- a/tests/cts/net/src/android/net/cts/Ikev2VpnTest.java
+++ b/tests/cts/net/src/android/net/cts/Ikev2VpnTest.java
@@ -16,11 +16,15 @@
 
 package android.net.cts;
 
+import static android.net.NetworkCapabilities.TRANSPORT_VPN;
+import static android.net.cts.util.CtsNetUtils.TestNetworkCallback;
+
 import static com.android.compatibility.common.util.SystemUtil.runWithShellPermissionIdentity;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
@@ -34,7 +38,12 @@
 import android.net.ConnectivityManager;
 import android.net.Ikev2VpnProfile;
 import android.net.IpSecAlgorithm;
+import android.net.LinkAddress;
+import android.net.Network;
+import android.net.NetworkRequest;
 import android.net.ProxyInfo;
+import android.net.TestNetworkInterface;
+import android.net.TestNetworkManager;
 import android.net.VpnManager;
 import android.net.cts.util.CtsNetUtils;
 import android.platform.test.annotations.AppModeFull;
@@ -42,12 +51,14 @@
 import androidx.test.InstrumentationRegistry;
 import androidx.test.runner.AndroidJUnit4;
 
+import com.android.internal.util.HexDump;
 import com.android.org.bouncycastle.x509.X509V1CertificateGenerator;
 
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
 import java.math.BigInteger;
+import java.net.InetAddress;
 import java.security.KeyPair;
 import java.security.KeyPairGenerator;
 import java.security.PrivateKey;
@@ -64,6 +75,45 @@
 public class Ikev2VpnTest {
     private static final String TAG = Ikev2VpnTest.class.getSimpleName();
 
+    // Test vectors for IKE negotiation in test mode.
+    private static final String SUCCESSFUL_IKE_INIT_RESP =
+            "46b8eca1e0d72a18b2b5d9006d47a0022120222000000000000002d0220000300000002c01010004030000"
+                    + "0c0100000c800e0100030000080300000c030000080200000400000008040000102800020800"
+                    + "100000b8070f159fe5141d8754ca86f72ecc28d66f514927e96cbe9eec0adb42bf2c276a0ab7"
+                    + "a97fa93555f4be9218c14e7f286bb28c6b4fb13825a420f2ffc165854f200bab37d69c8963d4"
+                    + "0acb831d983163aa50622fd35c182efe882cf54d6106222abcfaa597255d302f1b95ab71c142"
+                    + "c279ea5839a180070bff73f9d03fab815f0d5ee2adec7e409d1e35979f8bd92ffd8aab13d1a0"
+                    + "0657d816643ae767e9ae84d2ccfa2bcce1a50572be8d3748ae4863c41ae90da16271e014270f"
+                    + "77edd5cd2e3299f3ab27d7203f93d770bacf816041cdcecd0f9af249033979da4369cb242dd9"
+                    + "6d172e60513ff3db02de63e50eb7d7f596ada55d7946cad0af0669d1f3e2804846ab3f2a930d"
+                    + "df56f7f025f25c25ada694e6231abbb87ee8cfd072c8481dc0b0f6b083fdc3bd89b080e49feb"
+                    + "0288eef6fdf8a26ee2fc564a11e7385215cf2deaf2a9965638fc279c908ccdf04094988d91a2"
+                    + "464b4a8c0326533aff5119ed79ecbd9d99a218b44f506a5eb09351e67da86698b4c58718db25"
+                    + "d55f426fb4c76471b27a41fbce00777bc233c7f6e842e39146f466826de94f564cad8b92bfbe"
+                    + "87c99c4c7973ec5f1eea8795e7da82819753aa7c4fcfdab77066c56b939330c4b0d354c23f83"
+                    + "ea82fa7a64c4b108f1188379ea0eb4918ee009d804100e6bf118771b9058d42141c847d5ec37"
+                    + "6e5ec591c71fc9dac01063c2bd31f9c783b28bf1182900002430f3d5de3449462b31dd28bc27"
+                    + "297b6ad169bccce4f66c5399c6e0be9120166f2900001c0000400428b8df2e66f69c8584a186"
+                    + "c5eac66783551d49b72900001c000040054e7a622e802d5cbfb96d5f30a6e433994370173529"
+                    + "0000080000402e290000100000402f00020003000400050000000800004014";
+    private static final String SUCCESSFUL_IKE_AUTH_RESP =
+            "46b8eca1e0d72a18b2b5d9006d47a0022e20232000000001000000e0240000c420a2500a3da4c66fa6929e"
+                    + "600f36349ba0e38de14f78a3ad0416cba8c058735712a3d3f9a0a6ed36de09b5e9e02697e7c4"
+                    + "2d210ac86cfbd709503cfa51e2eab8cfdc6427d136313c072968f6506a546eb5927164200592"
+                    + "6e36a16ee994e63f029432a67bc7d37ca619e1bd6e1678df14853067ecf816b48b81e8746069"
+                    + "406363e5aa55f13cb2afda9dbebee94256c29d630b17dd7f1ee52351f92b6e1c3d8551c513f1"
+                    + "d74ac52a80b2041397e109fe0aeb3c105b0d4be0ae343a943398764281";
+    private static final long IKE_INITIATOR_SPI = Long.parseLong("46B8ECA1E0D72A18", 16);
+
+    private static final InetAddress LOCAL_OUTER_4 = InetAddress.parseNumericAddress("192.0.2.1");
+    private static final InetAddress LOCAL_OUTER_6 =
+            InetAddress.parseNumericAddress("2001:db8:1::1");
+
+    private static final int IP4_PREFIX_LEN = 32;
+    private static final int IP6_PREFIX_LEN = 128;
+
+    // TODO: Use IPv6 address when we can generate test vectors (GCE does not allow IPv6 yet).
+    private static final String TEST_SERVER_ADDR_V4 = "192.0.2.2";
     private static final String TEST_SERVER_ADDR = "2001:db8::1";
     private static final String TEST_IDENTITY = "client.cts.android.com";
     private static final List<String> TEST_ALLOWED_ALGORITHMS =
@@ -73,7 +123,7 @@
             ProxyInfo.buildDirectProxy("proxy.cts.android.com", 1234);
     private static final int TEST_MTU = 1300;
 
-    private static final byte[] TEST_PSK = "ikev2".getBytes();
+    private static final byte[] TEST_PSK = "ikeAndroidPsk".getBytes();
     private static final String TEST_USER = "username";
     private static final String TEST_PASSWORD = "pa55w0rd";
 
@@ -115,17 +165,22 @@
         }
 
         return builder.setBypassable(true)
+                .setAllowedAlgorithms(TEST_ALLOWED_ALGORITHMS)
                 .setProxy(TEST_PROXY_INFO)
                 .setMaxMtu(TEST_MTU)
                 .setMetered(false)
-                .setAllowedAlgorithms(TEST_ALLOWED_ALGORITHMS)
                 .build();
     }
 
     private Ikev2VpnProfile buildIkev2VpnProfilePsk(boolean isRestrictedToTestNetworks)
             throws Exception {
+        return buildIkev2VpnProfilePsk(TEST_SERVER_ADDR, isRestrictedToTestNetworks);
+    }
+
+    private Ikev2VpnProfile buildIkev2VpnProfilePsk(
+            String remote, boolean isRestrictedToTestNetworks) throws Exception {
         final Ikev2VpnProfile.Builder builder =
-                new Ikev2VpnProfile.Builder(TEST_SERVER_ADDR, TEST_IDENTITY).setAuthPsk(TEST_PSK);
+                new Ikev2VpnProfile.Builder(remote, TEST_IDENTITY).setAuthPsk(TEST_PSK);
 
         return buildIkev2VpnProfileCommon(builder, isRestrictedToTestNetworks);
     }
@@ -300,24 +355,84 @@
         }
     }
 
+    private void checkStartStopVpnProfileBuildsNetworks(IkeTunUtils tunUtils) throws Exception {
+        // Requires MANAGE_TEST_NETWORKS to provision a test-mode profile.
+        mCtsNetUtils.setAppopPrivileged(AppOpsManager.OP_ACTIVATE_PLATFORM_VPN, true);
+
+        final Ikev2VpnProfile profile =
+                buildIkev2VpnProfilePsk(TEST_SERVER_ADDR_V4, true /* isRestrictedToTestNetworks */);
+        assertNull(sVpnMgr.provisionVpnProfile(profile));
+
+        sVpnMgr.startProvisionedVpnProfile();
+
+        // Inject IKE negotiation
+        int expectedMsgId = 0;
+        tunUtils.awaitReqAndInjectResp(IKE_INITIATOR_SPI, expectedMsgId++, false /* isEncap */,
+                HexDump.hexStringToByteArray(SUCCESSFUL_IKE_INIT_RESP));
+        tunUtils.awaitReqAndInjectResp(IKE_INITIATOR_SPI, expectedMsgId++, true /* isEncap */,
+                HexDump.hexStringToByteArray(SUCCESSFUL_IKE_AUTH_RESP));
+
+        // Verify the VPN network came up
+        final NetworkRequest nr = new NetworkRequest.Builder()
+                .clearCapabilities().addTransportType(TRANSPORT_VPN).build();
+
+        final TestNetworkCallback cb = new TestNetworkCallback();
+        sCM.requestNetwork(nr, cb);
+        cb.waitForAvailable();
+        final Network vpnNetwork = cb.currentNetwork;
+        assertNotNull(vpnNetwork);
+
+        sVpnMgr.stopProvisionedVpnProfile();
+        cb.waitForLost();
+        assertEquals(vpnNetwork, cb.lastLostNetwork);
+    }
+
+    private void doTestStartStopVpnProfile() throws Exception {
+        // Non-final; these variables ensure we clean up properly after our test if we have
+        // allocated test network resources
+        final TestNetworkManager tnm = sContext.getSystemService(TestNetworkManager.class);
+        TestNetworkInterface testIface = null;
+        TestNetworkCallback tunNetworkCallback = null;
+
+        try {
+            // Build underlying test network
+            testIface = tnm.createTunInterface(
+                    new LinkAddress[] {
+                            new LinkAddress(LOCAL_OUTER_4, IP4_PREFIX_LEN),
+                            new LinkAddress(LOCAL_OUTER_6, IP6_PREFIX_LEN)});
+
+            // Hold on to this callback to ensure network does not get reaped.
+            tunNetworkCallback = mCtsNetUtils.setupAndGetTestNetwork(testIface.getInterfaceName());
+            final IkeTunUtils tunUtils = new IkeTunUtils(testIface.getFileDescriptor());
+
+            checkStartStopVpnProfileBuildsNetworks(tunUtils);
+        } finally {
+            // Make sure to stop the VPN profile. This is safe to call multiple times.
+            sVpnMgr.stopProvisionedVpnProfile();
+
+            if (testIface != null) {
+                testIface.getFileDescriptor().close();
+            }
+
+            if (tunNetworkCallback != null) {
+                sCM.unregisterNetworkCallback(tunNetworkCallback);
+            }
+
+            final Network testNetwork = tunNetworkCallback.currentNetwork;
+            if (testNetwork != null) {
+                tnm.teardownTestNetwork(testNetwork);
+            }
+        }
+    }
+
     @Test
     public void testStartStopVpnProfile() throws Exception {
         assumeTrue(mCtsNetUtils.hasIpsecTunnelsFeature());
 
-        // Requires MANAGE_TEST_NETWORKS to provision a test-mode profile.
+        // Requires shell permission to update appops.
         runWithShellPermissionIdentity(() -> {
-            mCtsNetUtils.setAppopPrivileged(AppOpsManager.OP_ACTIVATE_PLATFORM_VPN, true);
-
-            final Ikev2VpnProfile profile =
-                    buildIkev2VpnProfilePsk(true /* isRestrictedToTestNetworks */);
-            assertNull(sVpnMgr.provisionVpnProfile(profile));
-
-            sVpnMgr.startProvisionedVpnProfile();
-            // TODO: When IKEv2 setup is injectable, verify network was set up properly.
-
-            sVpnMgr.stopProvisionedVpnProfile();
-            // TODO: When IKEv2 setup is injectable, verify network is lost.
-        }, Manifest.permission.MANAGE_TEST_NETWORKS);
+            doTestStartStopVpnProfile();
+        });
     }
 
     private static class CertificateAndKey {
diff --git a/tests/cts/net/src/android/net/cts/TunUtils.java b/tests/cts/net/src/android/net/cts/TunUtils.java
index a030713..adaba9d 100644
--- a/tests/cts/net/src/android/net/cts/TunUtils.java
+++ b/tests/cts/net/src/android/net/cts/TunUtils.java
@@ -21,8 +21,8 @@
 import static android.net.cts.PacketUtils.IPPROTO_ESP;
 import static android.net.cts.PacketUtils.UDP_HDRLEN;
 import static android.system.OsConstants.IPPROTO_UDP;
+
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.fail;
 
 import android.os.ParcelFileDescriptor;
@@ -39,19 +39,18 @@
 public class TunUtils {
     private static final String TAG = TunUtils.class.getSimpleName();
 
+    protected static final int IP4_ADDR_OFFSET = 12;
+    protected static final int IP4_ADDR_LEN = 4;
+    protected static final int IP6_ADDR_OFFSET = 8;
+    protected static final int IP6_ADDR_LEN = 16;
+    protected static final int IP4_PROTO_OFFSET = 9;
+    protected static final int IP6_PROTO_OFFSET = 6;
+
     private static final int DATA_BUFFER_LEN = 4096;
-    private static final int TIMEOUT = 100;
+    private static final int TIMEOUT = 1000;
 
-    private static final int IP4_PROTO_OFFSET = 9;
-    private static final int IP6_PROTO_OFFSET = 6;
-
-    private static final int IP4_ADDR_OFFSET = 12;
-    private static final int IP4_ADDR_LEN = 4;
-    private static final int IP6_ADDR_OFFSET = 8;
-    private static final int IP6_ADDR_LEN = 16;
-
-    private final ParcelFileDescriptor mTunFd;
     private final List<byte[]> mPackets = new ArrayList<>();
+    private final ParcelFileDescriptor mTunFd;
     private final Thread mReaderThread;
 
     public TunUtils(ParcelFileDescriptor tunFd) {
@@ -112,46 +111,15 @@
         return null;
     }
 
-    /**
-     * Checks if the specified bytes were ever sent in plaintext.
-     *
-     * <p>Only checks for known plaintext bytes to prevent triggering on ICMP/RA packets or the like
-     *
-     * @param plaintext the plaintext bytes to check for
-     * @param startIndex the index in the list to check for
-     */
-    public boolean hasPlaintextPacket(byte[] plaintext, int startIndex) {
-        Predicate<byte[]> verifier =
-                (pkt) -> {
-                    return Collections.indexOfSubList(Arrays.asList(pkt), Arrays.asList(plaintext))
-                            != -1;
-                };
-        return getFirstMatchingPacket(verifier, startIndex) != null;
-    }
-
-    public byte[] getEspPacket(int spi, boolean encap, int startIndex) {
-        return getFirstMatchingPacket(
-                (pkt) -> {
-                    return isEsp(pkt, spi, encap);
-                },
-                startIndex);
-    }
-
-    public byte[] awaitEspPacketNoPlaintext(
-            int spi, byte[] plaintext, boolean useEncap, int expectedPacketSize) throws Exception {
+    protected byte[] awaitPacket(Predicate<byte[]> verifier) throws Exception {
         long endTime = System.currentTimeMillis() + TIMEOUT;
         int startIndex = 0;
 
         synchronized (mPackets) {
             while (System.currentTimeMillis() < endTime) {
-                byte[] espPkt = getEspPacket(spi, useEncap, startIndex);
-                if (espPkt != null) {
-                    // Validate packet size
-                    assertEquals(expectedPacketSize, espPkt.length);
-
-                    // Always check plaintext from start
-                    assertFalse(hasPlaintextPacket(plaintext, 0));
-                    return espPkt; // We've found the packet we're looking for.
+                final byte[] pkt = getFirstMatchingPacket(verifier, startIndex);
+                if (pkt != null) {
+                    return pkt; // We've found the packet we're looking for.
                 }
 
                 startIndex = mPackets.size();
@@ -162,10 +130,21 @@
                     mPackets.wait(waitTimeout);
                 }
             }
-
-            fail("No such ESP packet found with SPI " + spi);
         }
-        return null;
+
+        fail("No packet found matching verifier");
+        throw new IllegalStateException("Impossible condition; should have thrown in fail()");
+    }
+
+    public byte[] awaitEspPacketNoPlaintext(
+            int spi, byte[] plaintext, boolean useEncap, int expectedPacketSize) throws Exception {
+        final byte[] espPkt = awaitPacket(
+                (pkt) -> isEspFailIfSpecifiedPlaintextFound(pkt, spi, useEncap, plaintext));
+
+        // Validate packet size
+        assertEquals(expectedPacketSize, espPkt.length);
+
+        return espPkt; // We've found the packet we're looking for.
     }
 
     private static boolean isSpiEqual(byte[] pkt, int espOffset, int spi) {
@@ -176,6 +155,24 @@
                 && pkt[espOffset + 3] == (byte) (spi & 0xff);
     }
 
+    /**
+     * Variant of isEsp that also fails the test if the provided plaintext is found
+     *
+     * @param pkt the packet bytes to verify
+     * @param spi the expected SPI to look for
+     * @param encap whether encap was enabled, and the packet has a UDP header
+     * @param plaintext the plaintext packet before outbound encryption, which MUST not appear in
+     *     the provided packet.
+     */
+    private static boolean isEspFailIfSpecifiedPlaintextFound(
+            byte[] pkt, int spi, boolean encap, byte[] plaintext) {
+        if (Collections.indexOfSubList(Arrays.asList(pkt), Arrays.asList(plaintext)) != -1) {
+            fail("Banned plaintext packet found");
+        }
+
+        return isEsp(pkt, spi, encap);
+    }
+
     private static boolean isEsp(byte[] pkt, int spi, boolean encap) {
         if (isIpv6(pkt)) {
             // IPv6 UDP encap not supported by kernels; assume non-encap.
@@ -191,7 +188,7 @@
         }
     }
 
-    private static boolean isIpv6(byte[] pkt) {
+    public static boolean isIpv6(byte[] pkt) {
         // First nibble shows IP version. 0x60 for IPv6
         return (pkt[0] & (byte) 0xF0) == (byte) 0x60;
     }