Merge changes Ic4181fc8,Icffeed2e,I9fdba4a9
am: 54c0726220

Change-Id: I3a5a1a1ae0096eba2993798bf2707a8cfda81c97
diff --git a/tests/cts/net/src/android/net/cts/IpSecBaseTest.java b/tests/cts/net/src/android/net/cts/IpSecBaseTest.java
index 35d0f48..087dbda 100644
--- a/tests/cts/net/src/android/net/cts/IpSecBaseTest.java
+++ b/tests/cts/net/src/android/net/cts/IpSecBaseTest.java
@@ -28,11 +28,12 @@
 import android.test.AndroidTestCase;
 import android.util.Log;
 
+import androidx.test.InstrumentationRegistry;
+
 import java.io.FileDescriptor;
 import java.io.IOException;
 import java.net.DatagramPacket;
 import java.net.DatagramSocket;
-import java.net.Inet4Address;
 import java.net.Inet6Address;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
@@ -72,8 +73,14 @@
 
     protected void setUp() throws Exception {
         super.setUp();
-        mISM = (IpSecManager) getContext().getSystemService(Context.IPSEC_SERVICE);
-        mCM = (ConnectivityManager) getContext().getSystemService(Context.CONNECTIVITY_SERVICE);
+        mISM =
+                (IpSecManager)
+                        InstrumentationRegistry.getContext()
+                                .getSystemService(Context.IPSEC_SERVICE);
+        mCM =
+                (ConnectivityManager)
+                        InstrumentationRegistry.getContext()
+                                .getSystemService(Context.CONNECTIVITY_SERVICE);
     }
 
     protected static byte[] getKey(int bitLength) {
@@ -195,6 +202,17 @@
     public static class JavaUdpSocket implements GenericUdpSocket {
         public final DatagramSocket mSocket;
 
+        public JavaUdpSocket(InetAddress localAddr, int port) {
+            try {
+                mSocket = new DatagramSocket(port, localAddr);
+                mSocket.setSoTimeout(SOCK_TIMEOUT);
+            } catch (SocketException e) {
+                // Fail loudly if we can't set up sockets properly. And without the timeout, we
+                // could easily end up in an endless wait.
+                throw new RuntimeException(e);
+            }
+        }
+
         public JavaUdpSocket(InetAddress localAddr) {
             try {
                 mSocket = new DatagramSocket(0, localAddr);
@@ -425,26 +443,25 @@
     }
 
     protected static IpSecTransform buildIpSecTransform(
-            Context mContext,
+            Context context,
             IpSecManager.SecurityParameterIndex spi,
             IpSecManager.UdpEncapsulationSocket encapSocket,
             InetAddress remoteAddr)
             throws Exception {
-        String localAddr = (remoteAddr instanceof Inet4Address) ? IPV4_LOOPBACK : IPV6_LOOPBACK;
         IpSecTransform.Builder builder =
-                new IpSecTransform.Builder(mContext)
-                .setEncryption(new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY))
-                .setAuthentication(
-                        new IpSecAlgorithm(
-                                IpSecAlgorithm.AUTH_HMAC_SHA256,
-                                AUTH_KEY,
-                                AUTH_KEY.length * 4));
+                new IpSecTransform.Builder(context)
+                        .setEncryption(new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY))
+                        .setAuthentication(
+                                new IpSecAlgorithm(
+                                        IpSecAlgorithm.AUTH_HMAC_SHA256,
+                                        AUTH_KEY,
+                                        AUTH_KEY.length * 4));
 
         if (encapSocket != null) {
             builder.setIpv4Encapsulation(encapSocket, encapSocket.getPort());
         }
 
-        return builder.buildTransportModeTransform(InetAddress.getByName(localAddr), spi);
+        return builder.buildTransportModeTransform(remoteAddr, spi);
     }
 
     private IpSecTransform buildDefaultTransform(InetAddress localAddr) throws Exception {
diff --git a/tests/cts/net/src/android/net/cts/IpSecManagerTest.java b/tests/cts/net/src/android/net/cts/IpSecManagerTest.java
index 3387064..60d1c03 100644
--- a/tests/cts/net/src/android/net/cts/IpSecManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/IpSecManagerTest.java
@@ -16,6 +16,14 @@
 
 package android.net.cts;
 
+import static android.net.cts.PacketUtils.AES_CBC_BLK_SIZE;
+import static android.net.cts.PacketUtils.AES_CBC_IV_LEN;
+import static android.net.cts.PacketUtils.AES_GCM_BLK_SIZE;
+import static android.net.cts.PacketUtils.AES_GCM_IV_LEN;
+import static android.net.cts.PacketUtils.IP4_HDRLEN;
+import static android.net.cts.PacketUtils.IP6_HDRLEN;
+import static android.net.cts.PacketUtils.TCP_HDRLEN_WITH_TIMESTAMP_OPT;
+import static android.net.cts.PacketUtils.UDP_HDRLEN;
 import static android.system.OsConstants.IPPROTO_TCP;
 import static android.system.OsConstants.IPPROTO_UDP;
 import static org.junit.Assert.assertArrayEquals;
@@ -53,17 +61,6 @@
 
     private static final byte[] AEAD_KEY = getKey(288);
 
-    private static final int TCP_HDRLEN_WITH_OPTIONS = 32;
-    private static final int UDP_HDRLEN = 8;
-    private static final int IP4_HDRLEN = 20;
-    private static final int IP6_HDRLEN = 40;
-
-    // Encryption parameters
-    private static final int AES_GCM_IV_LEN = 8;
-    private static final int AES_CBC_IV_LEN = 16;
-    private static final int AES_GCM_BLK_SIZE = 4;
-    private static final int AES_CBC_BLK_SIZE = 16;
-
     protected void setUp() throws Exception {
         super.setUp();
     }
@@ -432,19 +429,6 @@
         }
     }
 
-    /** Helper function to calculate expected ESP packet size. */
-    private int calculateEspPacketSize(
-            int payloadLen, int cryptIvLength, int cryptBlockSize, int authTruncLen) {
-        final int ESP_HDRLEN = 4 + 4; // SPI + Seq#
-        final int ICV_LEN = authTruncLen / 8; // Auth trailer; based on truncation length
-        payloadLen += cryptIvLength; // Initialization Vector
-        payloadLen += 2; // ESP trailer
-
-        // Align to block size of encryption algorithm
-        payloadLen += (cryptBlockSize - (payloadLen % cryptBlockSize)) % cryptBlockSize;
-        return payloadLen + ESP_HDRLEN + ICV_LEN;
-    }
-
     public void checkTransform(
             int protocol,
             String localAddress,
@@ -485,7 +469,7 @@
             try (IpSecTransform transform =
                         transformBuilder.buildTransportModeTransform(local, spi)) {
                 if (protocol == IPPROTO_TCP) {
-                    transportHdrLen = TCP_HDRLEN_WITH_OPTIONS;
+                    transportHdrLen = TCP_HDRLEN_WITH_TIMESTAMP_OPT;
                     checkTcp(transform, local, sendCount, useJavaSockets);
                 } else if (protocol == IPPROTO_UDP) {
                     transportHdrLen = UDP_HDRLEN;
@@ -522,7 +506,7 @@
 
         int innerPacketSize = TEST_DATA.length + transportHdrLen + ipHdrLen;
         int outerPacketSize =
-                calculateEspPacketSize(
+                PacketUtils.calculateEspPacketSize(
                                 TEST_DATA.length + transportHdrLen, ivLen, blkSize, truncLenBits)
                         + udpEncapLen
                         + ipHdrLen;
@@ -540,13 +524,13 @@
         // Add TCP ACKs for data packets
         if (protocol == IPPROTO_TCP) {
             int encryptedTcpPktSize =
-                    calculateEspPacketSize(TCP_HDRLEN_WITH_OPTIONS, ivLen, blkSize, truncLenBits);
+                    PacketUtils.calculateEspPacketSize(
+                            TCP_HDRLEN_WITH_TIMESTAMP_OPT, ivLen, blkSize, truncLenBits);
 
-
-                // Add data packet ACKs
-                expectedOuterBytes += (encryptedTcpPktSize + udpEncapLen + ipHdrLen) * (sendCount);
-                expectedInnerBytes += (TCP_HDRLEN_WITH_OPTIONS + ipHdrLen) * (sendCount);
-                expectedPackets += sendCount;
+            // Add data packet ACKs
+            expectedOuterBytes += (encryptedTcpPktSize + udpEncapLen + ipHdrLen) * (sendCount);
+            expectedInnerBytes += (TCP_HDRLEN_WITH_TIMESTAMP_OPT + ipHdrLen) * (sendCount);
+            expectedPackets += sendCount;
         }
 
         StatsChecker.waitForNumPackets(expectedPackets);
diff --git a/tests/cts/net/src/android/net/cts/IpSecManagerTunnelTest.java b/tests/cts/net/src/android/net/cts/IpSecManagerTunnelTest.java
index c8c99f4..e8c0a7a 100644
--- a/tests/cts/net/src/android/net/cts/IpSecManagerTunnelTest.java
+++ b/tests/cts/net/src/android/net/cts/IpSecManagerTunnelTest.java
@@ -16,174 +16,728 @@
 
 package android.net.cts;
 
+import static android.app.AppOpsManager.OP_MANAGE_IPSEC_TUNNELS;
+import static android.net.IpSecManager.UdpEncapsulationSocket;
+import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VPN;
+import static android.net.NetworkCapabilities.NET_CAPABILITY_TRUSTED;
+import static android.net.NetworkCapabilities.TRANSPORT_TEST;
+import static android.net.cts.PacketUtils.AES_CBC_BLK_SIZE;
+import static android.net.cts.PacketUtils.AES_CBC_IV_LEN;
+import static android.net.cts.PacketUtils.BytePayload;
+import static android.net.cts.PacketUtils.EspHeader;
+import static android.net.cts.PacketUtils.IP4_HDRLEN;
+import static android.net.cts.PacketUtils.IP6_HDRLEN;
+import static android.net.cts.PacketUtils.Ip4Header;
+import static android.net.cts.PacketUtils.Ip6Header;
+import static android.net.cts.PacketUtils.IpHeader;
+import static android.net.cts.PacketUtils.UDP_HDRLEN;
+import static android.net.cts.PacketUtils.UdpHeader;
+import static android.system.OsConstants.AF_INET;
+import static android.system.OsConstants.AF_INET6;
+
+import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 
+import android.app.AppOpsManager;
+import android.content.Context;
 import android.content.pm.PackageManager;
+import android.net.ConnectivityManager;
 import android.net.IpSecAlgorithm;
 import android.net.IpSecManager;
 import android.net.IpSecTransform;
+import android.net.LinkAddress;
 import android.net.Network;
+import android.net.NetworkRequest;
+import android.net.TestNetworkInterface;
+import android.net.TestNetworkManager;
+import android.net.cts.PacketUtils.Payload;
+import android.os.Binder;
+import android.os.IBinder;
+import android.os.ParcelFileDescriptor;
+
+import androidx.test.InstrumentationRegistry;
+import androidx.test.runner.AndroidJUnit4;
 
 import com.android.compatibility.common.util.SystemUtil;
 
+import java.net.Inet4Address;
 import java.net.Inet6Address;
 import java.net.InetAddress;
 import java.net.InterfaceAddress;
 import java.net.NetworkInterface;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
 
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@RunWith(AndroidJUnit4.class)
 public class IpSecManagerTunnelTest extends IpSecBaseTest {
-
     private static final String TAG = IpSecManagerTunnelTest.class.getSimpleName();
-    private static final int IP4_PREFIX_LEN = 24;
-    private static final int IP6_PREFIX_LEN = 48;
-    private static final InetAddress OUTER_ADDR4 = InetAddress.parseNumericAddress("192.0.2.0");
-    private static final InetAddress OUTER_ADDR6 =
-            InetAddress.parseNumericAddress("2001:db8:f00d::1");
-    private static final InetAddress INNER_ADDR4 = InetAddress.parseNumericAddress("10.0.0.1");
-    private static final InetAddress INNER_ADDR6 =
-            InetAddress.parseNumericAddress("2001:db8:d00d::1");
 
-    private Network mUnderlyingNetwork;
-    private Network mIpSecNetwork;
+    private static final InetAddress LOCAL_OUTER_4 = InetAddress.parseNumericAddress("192.0.2.1");
+    private static final InetAddress REMOTE_OUTER_4 = InetAddress.parseNumericAddress("192.0.2.2");
+    private static final InetAddress LOCAL_OUTER_6 =
+            InetAddress.parseNumericAddress("2001:db8:1::1");
+    private static final InetAddress REMOTE_OUTER_6 =
+            InetAddress.parseNumericAddress("2001:db8:1::2");
 
-    protected void setUp() throws Exception {
+    private static final InetAddress LOCAL_INNER_4 =
+            InetAddress.parseNumericAddress("198.51.100.1");
+    private static final InetAddress REMOTE_INNER_4 =
+            InetAddress.parseNumericAddress("198.51.100.2");
+    private static final InetAddress LOCAL_INNER_6 =
+            InetAddress.parseNumericAddress("2001:db8:2::1");
+    private static final InetAddress REMOTE_INNER_6 =
+            InetAddress.parseNumericAddress("2001:db8:2::2");
+
+    private static final int IP4_PREFIX_LEN = 32;
+    private static final int IP6_PREFIX_LEN = 128;
+
+    private static final int TIMEOUT_MS = 500;
+
+    // Static state to reduce setup/teardown
+    private static ConnectivityManager sCM;
+    private static TestNetworkManager sTNM;
+    private static ParcelFileDescriptor sTunFd;
+    private static TestNetworkCallback sTunNetworkCallback;
+    private static Network sTunNetwork;
+    private static TunUtils sTunUtils;
+
+    private static Context sContext = InstrumentationRegistry.getContext();
+    private static IBinder sBinder = new Binder();
+
+    @BeforeClass
+    public static void setUpBeforeClass() throws Exception {
+        InstrumentationRegistry.getInstrumentation()
+                .getUiAutomation()
+                .adoptShellPermissionIdentity();
+        sCM = (ConnectivityManager) sContext.getSystemService(Context.CONNECTIVITY_SERVICE);
+        sTNM = (TestNetworkManager) sContext.getSystemService(Context.TEST_NETWORK_SERVICE);
+
+        // Under normal circumstances, the MANAGE_IPSEC_TUNNELS appop would be auto-granted, and
+        // a standard permission is insufficient. So we shell out to the appops command to grant
+        // ourselves the appop.
+        setAppop(OP_MANAGE_IPSEC_TUNNELS, true);
+
+        TestNetworkInterface testIntf =
+                sTNM.createTunInterface(
+                        new LinkAddress[] {
+                            new LinkAddress(LOCAL_OUTER_4, IP4_PREFIX_LEN),
+                            new LinkAddress(LOCAL_OUTER_6, IP6_PREFIX_LEN)
+                        });
+
+        sTunFd = testIntf.getFileDescriptor();
+        sTunNetworkCallback = setupAndGetTestNetwork(testIntf.getInterfaceName());
+        sTunNetwork = sTunNetworkCallback.getNetworkBlocking();
+
+        sTunUtils = new TunUtils(sTunFd);
+    }
+
+    @Before
+    public void setUp() throws Exception {
         super.setUp();
+
+        // Set to true before every run; some tests flip this.
+        setAppop(OP_MANAGE_IPSEC_TUNNELS, true);
+
+        // Clear sTunUtils state
+        sTunUtils.reset();
     }
 
-    protected void tearDown() {
-        setAppop(false);
+    @AfterClass
+    public static void tearDownAfterClass() throws Exception {
+        setAppop(OP_MANAGE_IPSEC_TUNNELS, false);
+
+        sCM.unregisterNetworkCallback(sTunNetworkCallback);
+
+        sTNM.teardownTestNetwork(sTunNetwork);
+        sTunFd.close();
+
+        InstrumentationRegistry.getInstrumentation()
+                .getUiAutomation()
+                .dropShellPermissionIdentity();
     }
 
-    private boolean hasTunnelsFeature() {
-        return getContext()
-                .getPackageManager()
-                .hasSystemFeature(PackageManager.FEATURE_IPSEC_TUNNELS);
+    private static boolean hasTunnelsFeature() {
+        return sContext.getPackageManager().hasSystemFeature(PackageManager.FEATURE_IPSEC_TUNNELS);
     }
 
-    private void setAppop(boolean allow) {
-        // Under normal circumstances, the MANAGE_IPSEC_TUNNELS appop would be auto-granted by the
-        // telephony framework, and the only permission that is sufficient is NETWORK_STACK. So we
-        // shell out the appop manager, to give us the right appop permissions.
-        String cmd =
-                "appops set "
-                        + mContext.getPackageName()
-                        + " MANAGE_IPSEC_TUNNELS "
-                + (allow ? "allow" : "deny");
-        SystemUtil.runShellCommand(cmd);
+    private static void setAppop(int appop, boolean allow) {
+        String opName = AppOpsManager.opToName(appop);
+        for (String pkg : new String[] {"com.android.shell", sContext.getPackageName()}) {
+            String cmd =
+                    String.format(
+                            "appops set %s %s %s",
+                            pkg, // Package name
+                            opName, // Appop
+                            (allow ? "allow" : "deny")); // Action
+            SystemUtil.runShellCommand(cmd);
+        }
     }
 
-    public void testSecurityExceptionsCreateTunnelInterface() throws Exception {
+    private static TestNetworkCallback setupAndGetTestNetwork(String ifname) throws Exception {
+        // Build a network request
+        NetworkRequest nr =
+                new NetworkRequest.Builder()
+                        .addTransportType(TRANSPORT_TEST)
+                        .removeCapability(NET_CAPABILITY_TRUSTED)
+                        .removeCapability(NET_CAPABILITY_NOT_VPN)
+                        .setNetworkSpecifier(ifname)
+                        .build();
+
+        TestNetworkCallback cb = new TestNetworkCallback();
+        sCM.requestNetwork(nr, cb);
+
+        // Set up the test network after the network request is filed to prevent the Network from
+        // being reaped due to no requests matching it.
+        sTNM.setupTestNetwork(ifname, sBinder);
+
+        return cb;
+    }
+
+    @Test
+    public void testSecurityExceptionCreateTunnelInterfaceWithoutAppop() throws Exception {
         if (!hasTunnelsFeature()) return;
 
         // Ensure we don't have the appop. Permission is not requested in the Manifest
-        setAppop(false);
+        setAppop(OP_MANAGE_IPSEC_TUNNELS, false);
 
         // Security exceptions are thrown regardless of IPv4/IPv6. Just test one
         try {
-            mISM.createIpSecTunnelInterface(OUTER_ADDR6, OUTER_ADDR6, mUnderlyingNetwork);
+            mISM.createIpSecTunnelInterface(LOCAL_INNER_6, REMOTE_INNER_6, sTunNetwork);
             fail("Did not throw SecurityException for Tunnel creation without appop");
         } catch (SecurityException expected) {
         }
     }
 
-    public void testSecurityExceptionsBuildTunnelTransform() throws Exception {
+    @Test
+    public void testSecurityExceptionBuildTunnelTransformWithoutAppop() throws Exception {
         if (!hasTunnelsFeature()) return;
 
         // Ensure we don't have the appop. Permission is not requested in the Manifest
-        setAppop(false);
+        setAppop(OP_MANAGE_IPSEC_TUNNELS, false);
 
         // Security exceptions are thrown regardless of IPv4/IPv6. Just test one
         try (IpSecManager.SecurityParameterIndex spi =
-                mISM.allocateSecurityParameterIndex(OUTER_ADDR4);
+                        mISM.allocateSecurityParameterIndex(LOCAL_INNER_4);
                 IpSecTransform transform =
-                        new IpSecTransform.Builder(mContext)
-                                .buildTunnelModeTransform(OUTER_ADDR4, spi)) {
+                        new IpSecTransform.Builder(sContext)
+                                .buildTunnelModeTransform(REMOTE_INNER_4, spi)) {
             fail("Did not throw SecurityException for Transform creation without appop");
         } catch (SecurityException expected) {
         }
     }
 
-    private void checkTunnel(InetAddress inner, InetAddress outer, boolean useEncap)
+    /* Test runnables for callbacks after IPsec tunnels are set up. */
+    private interface TestRunnable {
+        void run(Network ipsecNetwork) throws Exception;
+    }
+
+    private static class TestNetworkCallback extends ConnectivityManager.NetworkCallback {
+        private final CompletableFuture<Network> futureNetwork = new CompletableFuture<>();
+
+        @Override
+        public void onAvailable(Network network) {
+            futureNetwork.complete(network);
+        }
+
+        public Network getNetworkBlocking() throws Exception {
+            return futureNetwork.get(TIMEOUT_MS, TimeUnit.MILLISECONDS);
+        }
+    }
+
+    private int getPacketSize(
+            int innerFamily, int outerFamily, boolean useEncap, boolean transportInTunnelMode) {
+        int expectedPacketSize = TEST_DATA.length + UDP_HDRLEN;
+
+        // Inner Transport mode packet size
+        if (transportInTunnelMode) {
+            expectedPacketSize =
+                    PacketUtils.calculateEspPacketSize(
+                            expectedPacketSize,
+                            AES_CBC_IV_LEN,
+                            AES_CBC_BLK_SIZE,
+                            AUTH_KEY.length * 4);
+        }
+
+        // Inner IP Header
+        expectedPacketSize += innerFamily == AF_INET ? IP4_HDRLEN : IP6_HDRLEN;
+
+        // Tunnel mode transform size
+        expectedPacketSize =
+                PacketUtils.calculateEspPacketSize(
+                        expectedPacketSize, AES_CBC_IV_LEN, AES_CBC_BLK_SIZE, AUTH_KEY.length * 4);
+
+        // UDP encap size
+        expectedPacketSize += useEncap ? UDP_HDRLEN : 0;
+
+        // Outer IP Header
+        expectedPacketSize += outerFamily == AF_INET ? IP4_HDRLEN : IP6_HDRLEN;
+
+        return expectedPacketSize;
+    }
+
+    private interface TestRunnableFactory {
+        TestRunnable getTestRunnable(
+                boolean transportInTunnelMode,
+                int spi,
+                InetAddress localInner,
+                InetAddress remoteInner,
+                InetAddress localOuter,
+                InetAddress remoteOuter,
+                IpSecTransform inTransportTransform,
+                IpSecTransform outTransportTransform,
+                int encapPort,
+                int expectedPacketSize)
+                throws Exception;
+    }
+
+    private class OutputTestRunnableFactory implements TestRunnableFactory {
+        public TestRunnable getTestRunnable(
+                boolean transportInTunnelMode,
+                int spi,
+                InetAddress localInner,
+                InetAddress remoteInner,
+                InetAddress localOuter,
+                InetAddress remoteOuter,
+                IpSecTransform inTransportTransform,
+                IpSecTransform outTransportTransform,
+                int encapPort,
+                int expectedPacketSize) {
+            return new TestRunnable() {
+                @Override
+                public void run(Network ipsecNetwork) throws Exception {
+                    // Build a socket and send traffic
+                    JavaUdpSocket socket = new JavaUdpSocket(localInner);
+                    ipsecNetwork.bindSocket(socket.mSocket);
+
+                    // For Transport-In-Tunnel mode, apply transform to socket
+                    if (transportInTunnelMode) {
+                        mISM.applyTransportModeTransform(
+                                socket.mSocket, IpSecManager.DIRECTION_IN, inTransportTransform);
+                        mISM.applyTransportModeTransform(
+                                socket.mSocket, IpSecManager.DIRECTION_OUT, outTransportTransform);
+                    }
+
+                    socket.sendTo(TEST_DATA, remoteInner, socket.getPort());
+
+                    // Verify that an encrypted packet is sent. For now, the encrypted body cannot
+                    // be checked, because some fields of the inner IP header (flow label, flags,
+                    // etc.) are not known to the test.
+                    sTunUtils.awaitEspPacketNoPlaintext(
+                            spi, TEST_DATA, encapPort != 0, expectedPacketSize);
+
+                    socket.close();
+                }
+            };
+        }
+    }
+
+    private class InputPacketGeneratorTestRunnableFactory implements TestRunnableFactory {
+        public TestRunnable getTestRunnable(
+                boolean transportInTunnelMode,
+                int spi,
+                InetAddress localInner,
+                InetAddress remoteInner,
+                InetAddress localOuter,
+                InetAddress remoteOuter,
+                IpSecTransform inTransportTransform,
+                IpSecTransform outTransportTransform,
+                int encapPort,
+                int expectedPacketSize)
+                throws Exception {
+            return new TestRunnable() {
+                @Override
+                public void run(Network ipsecNetwork) throws Exception {
+                    // Build a socket and receive traffic
+                    JavaUdpSocket socket = new JavaUdpSocket(localInner);
+                    // Port is OS-assigned; the packet built below uses socket.getPort().
+                    ipsecNetwork.bindSocket(socket.mSocket);
+
+                    // For Transport-In-Tunnel mode, apply transform to socket
+                    if (transportInTunnelMode) {
+                        mISM.applyTransportModeTransform(
+                                socket.mSocket, IpSecManager.DIRECTION_IN, outTransportTransform);
+                        mISM.applyTransportModeTransform(
+                                socket.mSocket, IpSecManager.DIRECTION_OUT, inTransportTransform);
+                    }
+
+                    byte[] pkt;
+                    if (transportInTunnelMode) {
+                        pkt =
+                                getTransportInTunnelModePacket(
+                                        spi,
+                                        spi,
+                                        remoteInner,
+                                        localInner,
+                                        remoteOuter,
+                                        localOuter,
+                                        socket.getPort(),
+                                        encapPort);
+                    } else {
+                        pkt =
+                                getTunnelModePacket(
+                                        spi,
+                                        remoteInner,
+                                        localInner,
+                                        remoteOuter,
+                                        localOuter,
+                                        socket.getPort(),
+                                        encapPort);
+                    }
+                    sTunUtils.injectPacket(pkt);
+
+                    // Receive packet from socket, and validate
+                    receiveAndValidatePacket(socket);
+
+                    socket.close();
+                }
+            };
+        }
+    }
+
+    private void checkTunnelOutput(
+            int innerFamily, int outerFamily, boolean useEncap, boolean transportInTunnelMode)
+            throws Exception {
+        checkTunnel(
+                innerFamily,
+                outerFamily,
+                useEncap,
+                transportInTunnelMode,
+                new OutputTestRunnableFactory());
+    }
+
+    private void checkTunnelInput(
+            int innerFamily, int outerFamily, boolean useEncap, boolean transportInTunnelMode)
+            throws Exception {
+        checkTunnel(
+                innerFamily,
+                outerFamily,
+                useEncap,
+                transportInTunnelMode,
+                new InputPacketGeneratorTestRunnableFactory());
+    }
+
+    public void checkTunnel(
+            int innerFamily,
+            int outerFamily,
+            boolean useEncap,
+            boolean transportInTunnelMode,
+            TestRunnableFactory factory)
             throws Exception {
         if (!hasTunnelsFeature()) return;
 
-        setAppop(true);
-        int innerPrefixLen = inner instanceof Inet6Address ? IP6_PREFIX_LEN : IP4_PREFIX_LEN;
+        InetAddress localInner = innerFamily == AF_INET ? LOCAL_INNER_4 : LOCAL_INNER_6;
+        InetAddress remoteInner = innerFamily == AF_INET ? REMOTE_INNER_4 : REMOTE_INNER_6;
 
-        try (IpSecManager.SecurityParameterIndex spi = mISM.allocateSecurityParameterIndex(outer);
+        InetAddress localOuter = outerFamily == AF_INET ? LOCAL_OUTER_4 : LOCAL_OUTER_6;
+        InetAddress remoteOuter = outerFamily == AF_INET ? REMOTE_OUTER_4 : REMOTE_OUTER_6;
+
+        // Preselect both SPI and encap port, to be used for both inbound and outbound tunnels.
+        // Re-uses the same SPI to ensure that even in cases of symmetric SPIs shared across tunnel
+        // and transport mode, packets are encrypted/decrypted properly based on the src/dst.
+        int spi = getRandomSpi(localOuter, remoteOuter);
+        int expectedPacketSize =
+                getPacketSize(innerFamily, outerFamily, useEncap, transportInTunnelMode);
+
+        try (IpSecManager.SecurityParameterIndex inTransportSpi =
+                        mISM.allocateSecurityParameterIndex(localInner, spi);
+                IpSecManager.SecurityParameterIndex outTransportSpi =
+                        mISM.allocateSecurityParameterIndex(remoteInner, spi);
+                IpSecTransform inTransportTransform =
+                        buildIpSecTransform(sContext, inTransportSpi, null, remoteInner);
+                IpSecTransform outTransportTransform =
+                        buildIpSecTransform(sContext, outTransportSpi, null, localInner);
+                UdpEncapsulationSocket encapSocket = mISM.openUdpEncapsulationSocket()) {
+
+            buildTunnelAndNetwork(
+                    localInner,
+                    remoteInner,
+                    localOuter,
+                    remoteOuter,
+                    spi,
+                    useEncap ? encapSocket : null,
+                    factory.getTestRunnable(
+                            transportInTunnelMode,
+                            spi,
+                            localInner,
+                            remoteInner,
+                            localOuter,
+                            remoteOuter,
+                            inTransportTransform,
+                            outTransportTransform,
+                            useEncap ? encapSocket.getPort() : 0,
+                            expectedPacketSize));
+        }
+    }
+
+    private void buildTunnelAndNetwork(
+            InetAddress localInner,
+            InetAddress remoteInner,
+            InetAddress localOuter,
+            InetAddress remoteOuter,
+            int spi,
+            UdpEncapsulationSocket encapSocket,
+            TestRunnable test)
+            throws Exception {
+        int innerPrefixLen = localInner instanceof Inet6Address ? IP6_PREFIX_LEN : IP4_PREFIX_LEN;
+        TestNetworkCallback testNetworkCb = null;
+
+        try (IpSecManager.SecurityParameterIndex inSpi =
+                        mISM.allocateSecurityParameterIndex(localOuter, spi);
+                IpSecManager.SecurityParameterIndex outSpi =
+                        mISM.allocateSecurityParameterIndex(remoteOuter, spi);
                 IpSecManager.IpSecTunnelInterface tunnelIntf =
-                        mISM.createIpSecTunnelInterface(outer, outer, mCM.getActiveNetwork());
-                IpSecManager.UdpEncapsulationSocket encapSocket =
-                        mISM.openUdpEncapsulationSocket()) {
+                        mISM.createIpSecTunnelInterface(localOuter, remoteOuter, sTunNetwork)) {
+            // Assign the inner address, then build the test network on the tunnel interface
+            tunnelIntf.addAddress(localInner, innerPrefixLen);
+            testNetworkCb = setupAndGetTestNetwork(tunnelIntf.getInterfaceName());
+            Network testNetwork = testNetworkCb.getNetworkBlocking();
 
-            IpSecTransform.Builder transformBuilder = new IpSecTransform.Builder(mContext);
+            // Check interface was created
+            NetworkInterface netIntf = NetworkInterface.getByName(tunnelIntf.getInterfaceName());
+            assertNotNull(netIntf);
+
+            // Check that the inner address is the only address on the interface
+            List<InterfaceAddress> intfAddrs = netIntf.getInterfaceAddresses();
+            assertEquals(1, intfAddrs.size());
+            assertEquals(localInner, intfAddrs.get(0).getAddress());
+            assertEquals(innerPrefixLen, intfAddrs.get(0).getNetworkPrefixLength());
+
+            // Configure Transform parameters
+            IpSecTransform.Builder transformBuilder = new IpSecTransform.Builder(sContext);
             transformBuilder.setEncryption(
                     new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY));
             transformBuilder.setAuthentication(
                     new IpSecAlgorithm(
                             IpSecAlgorithm.AUTH_HMAC_SHA256, AUTH_KEY, AUTH_KEY.length * 4));
 
-            if (useEncap) {
+            if (encapSocket != null) {
                 transformBuilder.setIpv4Encapsulation(encapSocket, encapSocket.getPort());
             }
 
-            // Check transform application
-            try (IpSecTransform transform = transformBuilder.buildTunnelModeTransform(outer, spi)) {
-                mISM.applyTunnelModeTransform(tunnelIntf, IpSecManager.DIRECTION_IN, transform);
-                mISM.applyTunnelModeTransform(tunnelIntf, IpSecManager.DIRECTION_OUT, transform);
+            // Apply the in/out transforms, then run the supplied test over the test network
+            try (IpSecTransform inTransform =
+                            transformBuilder.buildTunnelModeTransform(remoteOuter, inSpi);
+                    IpSecTransform outTransform =
+                            transformBuilder.buildTunnelModeTransform(localOuter, outSpi)) {
+                mISM.applyTunnelModeTransform(tunnelIntf, IpSecManager.DIRECTION_IN, inTransform);
+                mISM.applyTunnelModeTransform(tunnelIntf, IpSecManager.DIRECTION_OUT, outTransform);
 
-                // TODO: Test to ensure that send/receive works with these transforms.
+                test.run(testNetwork);
             }
 
-            // Check interface was created
-            NetworkInterface netIntf = NetworkInterface.getByName(tunnelIntf.getInterfaceName());
-            assertNotNull(netIntf);
-
-            // Add addresses and check
-            tunnelIntf.addAddress(inner, innerPrefixLen);
-            for (InterfaceAddress intfAddr : netIntf.getInterfaceAddresses()) {
-                assertEquals(intfAddr.getAddress(), inner);
-                assertEquals(intfAddr.getNetworkPrefixLength(), innerPrefixLen);
-            }
+            // Tear down the test network. NOTE(review): not reached if anything above throws.
+            sTNM.teardownTestNetwork(testNetwork);
 
             // Remove addresses and check
-            tunnelIntf.removeAddress(inner, innerPrefixLen);
+            tunnelIntf.removeAddress(localInner, innerPrefixLen);
+            netIntf = NetworkInterface.getByName(tunnelIntf.getInterfaceName());
             assertTrue(netIntf.getInterfaceAddresses().isEmpty());
 
             // Check interface was cleaned up
             tunnelIntf.close();
             netIntf = NetworkInterface.getByName(tunnelIntf.getInterfaceName());
             assertNull(netIntf);
+        } finally {
+            if (testNetworkCb != null) {
+                sCM.unregisterNetworkCallback(testNetworkCb);
+            }
         }
     }
 
-    /*
-     * Create, add and remove addresses, then teardown tunnel
-     */
+    private static void receiveAndValidatePacket(JavaUdpSocket socket) throws Exception {
+        // The bytes handed back by the socket must match TEST_DATA element for element.
+        assertArrayEquals(TEST_DATA, socket.receive());
+    }
+
+    private int getRandomSpi(InetAddress localOuter, InetAddress remoteOuter) throws Exception {
+        // Allocate an SPI for localOuter and the same value for remoteOuter; both close on exit.
+        try (IpSecManager.SecurityParameterIndex localSpi =
+                        mISM.allocateSecurityParameterIndex(localOuter);
+                IpSecManager.SecurityParameterIndex remoteSpi =
+                        mISM.allocateSecurityParameterIndex(remoteOuter, localSpi.getSpi())) {
+            return localSpi.getSpi();
+        }
+    }
+
+    private IpHeader getIpHeader(int protocol, InetAddress src, InetAddress dst, Payload payload) {
+        // Both endpoints must be of the same address family; mixed v4/v6 pairs are rejected.
+        final boolean srcIsV6 = src instanceof Inet6Address;
+        if (srcIsV6 != (dst instanceof Inet6Address)) {
+            throw new IllegalArgumentException("Invalid src/dst address combination");
+        }
+        if (srcIsV6) {
+            return new Ip6Header(protocol, (Inet6Address) src, (Inet6Address) dst, payload);
+        }
+        return new Ip4Header(protocol, (Inet4Address) src, (Inet4Address) dst, payload);
+    }
+
+    private EspHeader buildTransportModeEspPacket(
+            int spi, InetAddress src, InetAddress dst, int port, Payload payload) throws Exception {
+        // The IP header is not serialized here; it is only handed to the payload serializer
+        // (UdpHeader reads its addresses and protocol for the checksum). `port` is unused.
+        final IpHeader ipHeader = getIpHeader(payload.getProtocolId(), src, dst, payload);
+        final byte[] espBody = payload.getPacketBytes(ipHeader);
+
+        // Sequence number 1; CRYPT_KEY doubles as the authentication key.
+        return new EspHeader(
+                payload.getProtocolId(), spi, 1, CRYPT_KEY, espBody);
+    }
+
+    private EspHeader buildTunnelModeEspPacket(
+            int spi,
+            InetAddress srcInner,
+            InetAddress dstInner,
+            InetAddress srcOuter,
+            InetAddress dstOuter,
+            int port,
+            int encapPort,
+            Payload payload)
+            throws Exception {
+        // Tunnel mode wraps the whole inner IP packet; outer addresses and ports are unused here.
+        IpHeader innerPacket = getIpHeader(payload.getProtocolId(), srcInner, dstInner, payload);
+        final int nextHeader = innerPacket.getProtocolId();
+        final byte[] innerBytes = innerPacket.getPacketBytes();
+
+        // Sequence number 1; CRYPT_KEY doubles as the authentication key.
+        return new EspHeader(nextHeader, spi, 1, CRYPT_KEY, innerBytes);
+    }
+
+    private IpHeader maybeEncapPacket(
+            InetAddress src, InetAddress dst, int encapPort, EspHeader espPayload)
+            throws Exception {
+        // An encapPort of 0 means no UDP encapsulation: the ESP packet sits directly on IP.
+        // Otherwise ESP is wrapped in UDP, with encapPort as both source and destination port.
+        final boolean useEncap = encapPort != 0;
+        final Payload outerPayload =
+                useEncap ? new UdpHeader(encapPort, encapPort, espPayload) : espPayload;
+
+        return getIpHeader(outerPayload.getProtocolId(), src, dst, outerPayload);
+    }
+
+    private byte[] getTunnelModePacket(
+            int spi,
+            InetAddress srcInner,
+            InetAddress dstInner,
+            InetAddress srcOuter,
+            InetAddress dstOuter,
+            int port,
+            int encapPort)
+            throws Exception {
+        // Innermost layer: TEST_DATA in a UDP datagram that uses `port` for both ends.
+        final Payload innerUdp = new UdpHeader(port, port, new BytePayload(TEST_DATA));
+        final EspHeader tunnelEsp = buildTunnelModeEspPacket(
+                spi, srcInner, dstInner, srcOuter, dstOuter, port, encapPort, innerUdp);
+        final IpHeader outerPacket = maybeEncapPacket(srcOuter, dstOuter, encapPort, tunnelEsp);
+        return outerPacket.getPacketBytes();
+    }
+
+    private byte[] getTransportInTunnelModePacket(
+            int spiInner,
+            int spiOuter,
+            InetAddress srcInner,
+            InetAddress dstInner,
+            InetAddress srcOuter,
+            InetAddress dstOuter,
+            int port,
+            int encapPort)
+            throws Exception {
+        // Layering, innermost first: UDP(TEST_DATA) -> transport-mode ESP -> tunnel-mode ESP.
+        final Payload innerUdp = new UdpHeader(port, port, new BytePayload(TEST_DATA));
+        final EspHeader transportEsp =
+                buildTransportModeEspPacket(spiInner, srcInner, dstInner, port, innerUdp);
+
+        // The transport-mode ESP packet becomes the payload of the inner IP packet in the tunnel.
+        final EspHeader tunnelEsp =
+                buildTunnelModeEspPacket(
+                        spiOuter, srcInner, dstInner, srcOuter, dstOuter, port, encapPort,
+                        transportEsp);
+
+        // Optionally UDP-encapsulate, then prepend the outer IP header and serialize.
+        final IpHeader outerPacket = maybeEncapPacket(srcOuter, dstOuter, encapPort, tunnelEsp);
+        return outerPacket.getPacketBytes();
+    }
+
+    // Transport-in-Tunnel mode tests; args: inner family, outer family, UDP encap, transport-in-tunnel
+    @Test
+    public void testTransportInTunnelModeV4InV4() throws Exception {
+        checkTunnelOutput(AF_INET, AF_INET, false, true);
+        checkTunnelInput(AF_INET, AF_INET, false, true);
+    }
+
+    @Test
+    public void testTransportInTunnelModeV4InV4UdpEncap() throws Exception {
+        checkTunnelOutput(AF_INET, AF_INET, true, true);
+        checkTunnelInput(AF_INET, AF_INET, true, true);
+    }
+
+    @Test
+    public void testTransportInTunnelModeV4InV6() throws Exception {
+        checkTunnelOutput(AF_INET, AF_INET6, false, true);
+        checkTunnelInput(AF_INET, AF_INET6, false, true);
+    }
+
+    @Test
+    public void testTransportInTunnelModeV6InV4() throws Exception {
+        checkTunnelOutput(AF_INET6, AF_INET, false, true);
+        checkTunnelInput(AF_INET6, AF_INET, false, true);
+    }
+
+    @Test
+    public void testTransportInTunnelModeV6InV4UdpEncap() throws Exception {
+        checkTunnelOutput(AF_INET6, AF_INET, true, true);
+        checkTunnelInput(AF_INET6, AF_INET, true, true);
+    }
+
+    @Test
+    public void testTransportInTunnelModeV6InV6() throws Exception {
+        checkTunnelOutput(AF_INET6, AF_INET6, false, true); // inner v6 carried in outer v6
+        checkTunnelInput(AF_INET6, AF_INET6, false, true); // inner v6 carried in outer v6
+    }
+
+    // Tunnel mode tests; args: inner family, outer family, UDP encap, transport-in-tunnel
+    @Test
     public void testTunnelV4InV4() throws Exception {
-        checkTunnel(INNER_ADDR4, OUTER_ADDR4, false);
+        checkTunnelOutput(AF_INET, AF_INET, false, false);
+        checkTunnelInput(AF_INET, AF_INET, false, false);
     }
 
+    @Test
     public void testTunnelV4InV4UdpEncap() throws Exception {
-        checkTunnel(INNER_ADDR4, OUTER_ADDR4, true);
+        checkTunnelOutput(AF_INET, AF_INET, true, false);
+        checkTunnelInput(AF_INET, AF_INET, true, false);
     }
 
+    @Test
     public void testTunnelV4InV6() throws Exception {
-        checkTunnel(INNER_ADDR4, OUTER_ADDR6, false);
+        checkTunnelOutput(AF_INET, AF_INET6, false, false);
+        checkTunnelInput(AF_INET, AF_INET6, false, false);
     }
 
+    @Test
     public void testTunnelV6InV4() throws Exception {
-        checkTunnel(INNER_ADDR6, OUTER_ADDR4, false);
+        checkTunnelOutput(AF_INET6, AF_INET, false, false);
+        checkTunnelInput(AF_INET6, AF_INET, false, false);
     }
 
+    @Test
     public void testTunnelV6InV4UdpEncap() throws Exception {
-        checkTunnel(INNER_ADDR6, OUTER_ADDR4, true);
+        checkTunnelOutput(AF_INET6, AF_INET, true, false);
+        checkTunnelInput(AF_INET6, AF_INET, true, false);
     }
 
+    @Test
     public void testTunnelV6InV6() throws Exception {
-        checkTunnel(INNER_ADDR6, OUTER_ADDR6, false);
+        checkTunnelOutput(AF_INET6, AF_INET6, false, false);
+        checkTunnelInput(AF_INET6, AF_INET6, false, false);
     }
 }
diff --git a/tests/cts/net/src/android/net/cts/PacketUtils.java b/tests/cts/net/src/android/net/cts/PacketUtils.java
new file mode 100644
index 0000000..6177827
--- /dev/null
+++ b/tests/cts/net/src/android/net/cts/PacketUtils.java
@@ -0,0 +1,460 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.cts;
+
+import static android.system.OsConstants.IPPROTO_IPV6;
+import static android.system.OsConstants.IPPROTO_UDP;
+
+import java.net.Inet4Address;
+import java.net.Inet6Address;
+import java.net.InetAddress;
+import java.nio.ByteBuffer;
+import java.nio.ShortBuffer;
+import java.security.GeneralSecurityException;
+import java.security.SecureRandom;
+import java.util.Arrays;
+import javax.crypto.Cipher;
+import javax.crypto.Mac;
+import javax.crypto.spec.IvParameterSpec;
+import javax.crypto.spec.SecretKeySpec;
+
+public class PacketUtils {
+    private static final String TAG = PacketUtils.class.getSimpleName();
+
+    private static final int DATA_BUFFER_LEN = 4096;
+
+    static final int IP4_HDRLEN = 20;
+    static final int IP6_HDRLEN = 40;
+    static final int UDP_HDRLEN = 8;
+    static final int TCP_HDRLEN = 20;
+    static final int TCP_HDRLEN_WITH_TIMESTAMP_OPT = TCP_HDRLEN + 12;
+
+    // Not defined in OsConstants
+    static final int IPPROTO_IPV4 = 4;
+    static final int IPPROTO_ESP = 50;
+
+    // Encryption parameters
+    static final int AES_GCM_IV_LEN = 8;
+    static final int AES_CBC_IV_LEN = 16;
+    static final int AES_GCM_BLK_SIZE = 4;
+    static final int AES_CBC_BLK_SIZE = 16;
+
+    // Encryption algorithms
+    static final String AES = "AES";
+    static final String AES_CBC = "AES/CBC/NoPadding";
+    static final String HMAC_SHA_256 = "HmacSHA256";
+
+    public interface Payload {
+        byte[] getPacketBytes(IpHeader header) throws Exception;
+        // Appends this payload's serialized bytes to resultBuffer at its current position.
+        void addPacketBytes(IpHeader header, ByteBuffer resultBuffer) throws Exception;
+        // Serialized length in bytes; being a short, it cannot represent lengths above 32767.
+        short length();
+        // Protocol number that an enclosing header uses to identify this payload.
+        int getProtocolId();
+    }
+
+    public abstract static class IpHeader {
+        // State shared by the IPv4/IPv6 headers: protocol byte, endpoints and carried payload.
+        public final byte proto;
+        public final InetAddress srcAddr;
+        public final InetAddress dstAddr;
+        public final Payload payload;
+
+        public IpHeader(int proto, InetAddress src, InetAddress dst, Payload payload) {
+            this.proto = (byte) proto;
+            this.srcAddr = src;
+            this.dstAddr = dst;
+            this.payload = payload;
+        }
+
+        public abstract byte[] getPacketBytes() throws Exception;
+        // Protocol number identifying this header's own IP version (not the payload's protocol).
+        public abstract int getProtocolId();
+    }
+
+    public static class Ip4Header extends IpHeader {
+        private short checksum; // NOTE(review): never read or written in this file - remove?
+
+        public Ip4Header(int proto, Inet4Address src, Inet4Address dst, Payload payload) {
+            super(proto, src, dst, payload);
+        }
+
+        public byte[] getPacketBytes() throws Exception {
+            ByteBuffer resultBuffer = buildHeader();
+            payload.addPacketBytes(this, resultBuffer);
+
+            return getByteArrayFromBuffer(resultBuffer);
+        }
+
+        public ByteBuffer buildHeader() {
+            ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN);
+
+            // Version (4), IHL (5 words = 20 bytes, no options)
+            bb.put((byte) (0x45));
+
+            // DSCP, ECN
+            bb.put((byte) 0);
+
+            // Total Length
+            bb.putShort((short) (IP4_HDRLEN + payload.length()));
+
+            // Identification = 0; Flags = Don't Fragment (0x40); Fragment Offset = 0
+            bb.putShort((short) 0);
+            bb.put((byte) 0x40);
+            bb.put((byte) 0x00);
+
+            // TTL
+            bb.put((byte) 64);
+
+            // Protocol
+            bb.put(proto);
+
+            // Header Checksum (zero placeholder; patched in below once the header is complete)
+            final int ipChecksumOffset = bb.position();
+            bb.putShort((short) 0);
+
+            // Src/Dst addresses
+            bb.put(srcAddr.getAddress());
+            bb.put(dstAddr.getAddress());
+
+            bb.putShort(ipChecksumOffset, calculateChecksum(bb));
+
+            return bb;
+        }
+
+        private short calculateChecksum(ByteBuffer bb) {
+            int checksum = 0;
+
+            // Calculate sum of 16-bit values, excluding checksum. IPv4 headers are always 32-bit
+            // aligned, so no special cases needed for unaligned values.
+            ShortBuffer shortBuffer = ByteBuffer.wrap(getByteArrayFromBuffer(bb)).asShortBuffer();
+            while (shortBuffer.hasRemaining()) {
+                short val = shortBuffer.get();
+
+                // Wrap as needed
+                checksum = addAndWrapForChecksum(checksum, val);
+            }
+
+            return onesComplement(checksum);
+        }
+
+        public int getProtocolId() {
+            return IPPROTO_IPV4;
+        }
+    }
+
+    public static class Ip6Header extends IpHeader {
+        public Ip6Header(int nextHeader, Inet6Address src, Inet6Address dst, Payload payload) {
+            super(nextHeader, src, dst, payload);
+        }
+
+        public byte[] getPacketBytes() throws Exception {
+            ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN);
+
+            // Version (6) | Traffic Class (first 4 bits)
+            bb.put((byte) 0x60);
+
+            // Traffic class (Last 4 bits), Flow Label
+            bb.put((byte) 0);
+            bb.put((byte) 0);
+            bb.put((byte) 0);
+
+            // Payload Length (the payload only; the 40 header bytes written here are not counted)
+            bb.putShort((short) payload.length());
+
+            // Next Header
+            bb.put(proto);
+
+            // Hop Limit
+            bb.put((byte) 64);
+
+            // Src/Dst addresses
+            bb.put(srcAddr.getAddress());
+            bb.put(dstAddr.getAddress());
+
+            // Payload
+            payload.addPacketBytes(this, bb);
+
+            return getByteArrayFromBuffer(bb);
+        }
+
+        public int getProtocolId() {
+            return IPPROTO_IPV6;
+        }
+    }
+
+    public static class BytePayload implements Payload {
+        public final byte[] payload;
+
+        public BytePayload(byte[] payload) {
+            this.payload = payload;
+        }
+
+        public int getProtocolId() {
+            return -1; // NOTE(review): presumably "no protocol number" for raw bytes - confirm
+        }
+
+        public byte[] getPacketBytes(IpHeader header) {
+            ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN);
+            // Payloads longer than DATA_BUFFER_LEN overflow this scratch buffer in put().
+            addPacketBytes(header, bb);
+            return getByteArrayFromBuffer(bb);
+        }
+
+        public void addPacketBytes(IpHeader header, ByteBuffer resultBuffer) {
+            resultBuffer.put(payload);
+        }
+
+        public short length() {
+            return (short) payload.length;
+        }
+    }
+
+    public static class UdpHeader implements Payload {
+        // UDP datagram: 8-byte header (ports, length, checksum) followed by the wrapped payload.
+        public final short srcPort;
+        public final short dstPort;
+        public final Payload payload;
+
+        public UdpHeader(int srcPort, int dstPort, Payload payload) {
+            this.srcPort = (short) srcPort;
+            this.dstPort = (short) dstPort;
+            this.payload = payload;
+        }
+
+        public int getProtocolId() {
+            return IPPROTO_UDP;
+        }
+
+        public short length() {
+            return (short) (payload.length() + UDP_HDRLEN);
+        }
+
+        public byte[] getPacketBytes(IpHeader header) throws Exception {
+            ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN);
+
+            addPacketBytes(header, bb);
+            return getByteArrayFromBuffer(bb);
+        }
+
+        public void addPacketBytes(IpHeader header, ByteBuffer resultBuffer) throws Exception {
+            // Source, Destination port
+            resultBuffer.putShort(srcPort);
+            resultBuffer.putShort(dstPort);
+
+            // Payload Length
+            resultBuffer.putShort(length());
+
+            // Get payload bytes for checksum + payload
+            ByteBuffer payloadBuffer = ByteBuffer.allocate(DATA_BUFFER_LEN);
+            payload.addPacketBytes(header, payloadBuffer);
+            byte[] payloadBytes = getByteArrayFromBuffer(payloadBuffer);
+
+            // Checksum
+            resultBuffer.putShort(calculateChecksum(header, payloadBytes));
+
+            // Payload
+            resultBuffer.put(payloadBytes);
+        }
+
+        private short calculateChecksum(IpHeader header, byte[] payloadBytes) throws Exception {
+            int newChecksum = 0;
+            ShortBuffer srcBuffer = ByteBuffer.wrap(header.srcAddr.getAddress()).asShortBuffer();
+            ShortBuffer dstBuffer = ByteBuffer.wrap(header.dstAddr.getAddress()).asShortBuffer();
+
+            while (srcBuffer.hasRemaining() || dstBuffer.hasRemaining()) {
+                short val = srcBuffer.hasRemaining() ? srcBuffer.get() : dstBuffer.get();
+                newChecksum = addAndWrapForChecksum(newChecksum, val);
+            }
+
+            // Pseudo-header protocol (masked so values >= 0x80 are not sign-extended) and length,
+            // followed by the UDP header's own ports and length.
+            newChecksum = addAndWrapForChecksum(newChecksum, header.proto & 0xff);
+            newChecksum = addAndWrapForChecksum(newChecksum, length());
+            newChecksum = addAndWrapForChecksum(newChecksum, srcPort);
+            newChecksum = addAndWrapForChecksum(newChecksum, dstPort);
+            newChecksum = addAndWrapForChecksum(newChecksum, length());
+
+            ShortBuffer payloadShortBuffer = ByteBuffer.wrap(payloadBytes).asShortBuffer();
+            while (payloadShortBuffer.hasRemaining()) {
+                newChecksum = addAndWrapForChecksum(newChecksum, payloadShortBuffer.get());
+            }
+            if (payloadBytes.length % 2 != 0) {
+                // Odd length: the trailing byte is the high byte of a zero-padded final word.
+                int lastWord = payloadBytes[payloadBytes.length - 1] << 8;
+                newChecksum = addAndWrapForChecksum(newChecksum, lastWord);
+            }
+            // RFC 768: a computed checksum of zero is transmitted as all ones.
+            short result = onesComplement(newChecksum);
+            return (result == 0) ? (short) 0xffff : result;
+        }
+    }
+
+    public static class EspHeader implements Payload {
+        public final int nextHeader;
+        public final int spi;
+        public final int seqNum;
+        public final byte[] key;
+        public final byte[] payload;
+
+        /**
+         * Generic constructor for ESP headers.
+         *
+         * <p>For Tunnel mode, payload will be a full IP header + attached payloads
+         *
+         * <p>For Transport mode, payload will be only the attached payloads, but with the checksum
+         * calculated using the pre-encryption IP header
+         */
+        public EspHeader(int nextHeader, int spi, int seqNum, byte[] key, byte[] payload) {
+            this.nextHeader = nextHeader;
+            this.spi = spi;
+            this.seqNum = seqNum;
+            this.key = key;
+            this.payload = payload;
+        }
+
+        public int getProtocolId() {
+            return IPPROTO_ESP;
+        }
+
+        public short length() {
+            // ALWAYS uses AES-CBC, HMAC-SHA256 (128b trunc len)
+            return (short)
+                    calculateEspPacketSize(payload.length, AES_CBC_IV_LEN, AES_CBC_BLK_SIZE, 128);
+        }
+
+        public byte[] getPacketBytes(IpHeader header) throws Exception {
+            ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN);
+
+            addPacketBytes(header, bb);
+            return getByteArrayFromBuffer(bb);
+        }
+
+        public void addPacketBytes(IpHeader header, ByteBuffer resultBuffer) throws Exception {
+            ByteBuffer espPayloadBuffer = ByteBuffer.allocate(DATA_BUFFER_LEN);
+            espPayloadBuffer.putInt(spi);
+            espPayloadBuffer.putInt(seqNum);
+            espPayloadBuffer.put(getCiphertext(key));
+            // The ICV covers SPI, sequence number, IV and ciphertext; only 16 bytes of it are kept.
+            espPayloadBuffer.put(getIcv(getByteArrayFromBuffer(espPayloadBuffer)), 0, 16);
+            resultBuffer.put(getByteArrayFromBuffer(espPayloadBuffer));
+        }
+
+        private byte[] getIcv(byte[] authenticatedSection) throws GeneralSecurityException {
+            Mac sha256HMAC = Mac.getInstance(HMAC_SHA_256);
+            SecretKeySpec authKey = new SecretKeySpec(key, HMAC_SHA_256);
+            sha256HMAC.init(authKey);
+
+            return sha256HMAC.doFinal(authenticatedSection);
+        }
+
+        /**
+         * Encrypts and builds ciphertext block. Includes the IV, Padding and Next-Header blocks
+         *
+         * <p>The ciphertext does NOT include the SPI/Sequence numbers, or the ICV.
+         */
+        private byte[] getCiphertext(byte[] key) throws GeneralSecurityException {
+            int paddedLen = calculateEspEncryptedLength(payload.length, AES_CBC_BLK_SIZE);
+            ByteBuffer paddedPayload = ByteBuffer.allocate(paddedLen);
+            paddedPayload.put(payload);
+
+            // Add padding - consecutive integers from 0x01
+            int pad = 1;
+            while (paddedPayload.position() < paddedPayload.limit()) {
+                paddedPayload.put((byte) pad++);
+            }
+            // The final two bytes are overwritten with the ESP trailer: pad length, next header.
+            paddedPayload.position(paddedPayload.limit() - 2);
+            paddedPayload.put((byte) (paddedLen - 2 - payload.length)); // Pad length
+            paddedPayload.put((byte) nextHeader);
+
+            // Generate Initialization Vector
+            byte[] iv = new byte[AES_CBC_IV_LEN];
+            new SecureRandom().nextBytes(iv);
+            IvParameterSpec ivParameterSpec = new IvParameterSpec(iv);
+            SecretKeySpec secretKeySpec = new SecretKeySpec(key, AES);
+
+            // Encrypt payload
+            Cipher cipher = Cipher.getInstance(AES_CBC);
+            cipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, ivParameterSpec);
+            byte[] encrypted = cipher.doFinal(getByteArrayFromBuffer(paddedPayload));
+
+            // Build ciphertext
+            ByteBuffer cipherText = ByteBuffer.allocate(AES_CBC_IV_LEN + encrypted.length);
+            cipherText.put(iv);
+            cipherText.put(encrypted);
+
+            return getByteArrayFromBuffer(cipherText);
+        }
+    }
+
+    private static int addAndWrapForChecksum(int currentChecksum, int value) {
+        // Add only the low 16 bits of value, then fold any carry above bit 15 back into the sum.
+        final int sum = currentChecksum + (value & 0xffff);
+        final int carry = sum >>> 16;
+        return carry + (sum & 0xffff);
+    }
+
+    private static short onesComplement(int val) {
+        // Fold the carry once more, then invert the low 16 bits; a zero sum is returned as zero.
+        final int folded = (val >>> 16) + (val & 0xffff);
+        if (folded == 0) return 0;
+        return (short) (~folded & 0xffff);
+    }
+
+    public static int calculateEspPacketSize(
+            int payloadLen, int cryptIvLength, int cryptBlockSize, int authTruncLen) {
+        final int espHeaderLen = 4 + 4; // SPI + sequence number
+        final int icvLen = authTruncLen / 8; // authTruncLen is in bits; the ICV is in bytes
+        final int withIv = payloadLen + cryptIvLength;
+
+        // Add the 2-byte ESP trailer and pad up to the cipher block size.
+        final int paddedLen = calculateEspEncryptedLength(withIv, cryptBlockSize);
+        return paddedLen + espHeaderLen + icvLen;
+    }
+
+    private static int calculateEspEncryptedLength(int payloadLen, int cryptBlockSize) {
+        // The 2-byte ESP trailer (pad length, next header) is counted before block alignment.
+        final int withTrailer = payloadLen + 2;
+        final int padLen = calculateEspPadLen(withTrailer, cryptBlockSize);
+        return withTrailer + padLen;
+    }
+
+    private static int calculateEspPadLen(int payloadLen, int cryptBlockSize) {
+        return (cryptBlockSize - (payloadLen % cryptBlockSize)) % cryptBlockSize; // 0 if aligned
+    }
+
+    private static byte[] getByteArrayFromBuffer(ByteBuffer buffer) {
+        return Arrays.copyOf(buffer.array(), buffer.position()); // backing bytes [0, position)
+    }
+
+    /*
+     * Debug printing
+     */
+    private static final char[] hexArray = "0123456789ABCDEF".toCharArray();
+
+    public static String bytesToHex(byte[] bytes) {
+        // Each byte becomes two uppercase hex digits followed by a space, e.g. {0x0A} -> "0A ".
+        StringBuilder sb = new StringBuilder(bytes.length * 3);
+        for (byte b : bytes) {
+            // Mask after shifting: b is sign-extended to int, so b >>> 4 alone can exceed 15.
+            sb.append(hexArray[(b >>> 4) & 0x0F]).append(hexArray[b & 0x0F]).append(' ');
+        }
+        return sb.toString();
+    }
+}
diff --git a/tests/cts/net/src/android/net/cts/TunUtils.java b/tests/cts/net/src/android/net/cts/TunUtils.java
new file mode 100644
index 0000000..a030713
--- /dev/null
+++ b/tests/cts/net/src/android/net/cts/TunUtils.java
@@ -0,0 +1,337 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.cts;
+
+import static android.net.cts.PacketUtils.IP4_HDRLEN;
+import static android.net.cts.PacketUtils.IP6_HDRLEN;
+import static android.net.cts.PacketUtils.IPPROTO_ESP;
+import static android.net.cts.PacketUtils.UDP_HDRLEN;
+import static android.system.OsConstants.IPPROTO_UDP;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.fail;
+
+import android.os.ParcelFileDescriptor;
+
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Predicate;
+
+/**
+ * Test helper that captures every packet read from a TUN interface on a background thread and
+ * lets tests search the captured packets, reflect them back, or inject new ones.
+ *
+ * <p>The captured-packet list is guarded by its own monitor; the reader thread notifies waiters
+ * whenever a packet is added.
+ */
+public class TunUtils {
+    private static final String TAG = TunUtils.class.getSimpleName();
+
+    private static final int DATA_BUFFER_LEN = 4096;
+    private static final int TIMEOUT = 100; // ms; compared against System.currentTimeMillis()
+
+    private static final int IP4_PROTO_OFFSET = 9;
+    private static final int IP6_PROTO_OFFSET = 6;
+
+    private static final int IP4_ADDR_OFFSET = 12;
+    private static final int IP4_ADDR_LEN = 4;
+    private static final int IP6_ADDR_OFFSET = 8;
+    private static final int IP6_ADDR_LEN = 16;
+
+    private static final int SPI_LEN = 4;
+
+    private final ParcelFileDescriptor mTunFd;
+    private final List<byte[]> mPackets = new ArrayList<>();
+    private final Thread mReaderThread;
+
+    /**
+     * Starts capturing packets from the given TUN file descriptor.
+     *
+     * @param tunFd descriptor of the TUN interface; the reader runs until a read hits EOF or fails
+     */
+    public TunUtils(ParcelFileDescriptor tunFd) {
+        mTunFd = tunFd;
+
+        // Start background reader thread
+        mReaderThread =
+                new Thread(
+                        () -> {
+                            try {
+                                // Loop will exit and thread will quit when tunFd is closed.
+                                // Receiving either EOF or an exception will exit this reader loop.
+                                // FileInputStream is uninterruptible, so there's no good way to
+                                // ensure that this thread shuts down except upon FD closure.
+                                while (true) {
+                                    byte[] intercepted = receiveFromTun();
+                                    if (intercepted == null) {
+                                        // Exit once we've hit EOF
+                                        return;
+                                    } else if (intercepted.length > 0) {
+                                        // Only save packet if we've received any bytes.
+                                        synchronized (mPackets) {
+                                            mPackets.add(intercepted);
+                                            mPackets.notifyAll();
+                                        }
+                                    }
+                                }
+                            } catch (IOException ignored) {
+                                // Simply exit this reader thread
+                                return;
+                            }
+                        });
+        mReaderThread.start();
+    }
+
+    /**
+     * Performs one blocking read from the TUN descriptor.
+     *
+     * @return the bytes read, or null on EOF
+     * @throws IOException if the read fails
+     * @throws IllegalStateException if the read fills the whole buffer
+     */
+    private byte[] receiveFromTun() throws IOException {
+        FileInputStream in = new FileInputStream(mTunFd.getFileDescriptor());
+        byte[] inBytes = new byte[DATA_BUFFER_LEN];
+        int bytesRead = in.read(inBytes);
+
+        if (bytesRead < 0) {
+            return null; // return null for EOF
+        } else if (bytesRead >= DATA_BUFFER_LEN) {
+            throw new IllegalStateException("Too big packet. Fragmentation unsupported");
+        }
+        return Arrays.copyOf(inBytes, bytesRead);
+    }
+
+    /**
+     * Returns the first captured packet at or after startIndex accepted by the verifier, or null
+     * if there is none.
+     */
+    private byte[] getFirstMatchingPacket(Predicate<byte[]> verifier, int startIndex) {
+        synchronized (mPackets) {
+            for (int i = startIndex; i < mPackets.size(); i++) {
+                byte[] pkt = mPackets.get(i);
+                if (verifier.test(pkt)) {
+                    return pkt;
+                }
+            }
+        }
+        return null;
+    }
+
+    /**
+     * Checks if the specified bytes were ever sent in plaintext.
+     *
+     * <p>Only checks for known plaintext bytes to prevent triggering on ICMP/RA packets or the like
+     *
+     * @param plaintext the plaintext bytes to check for; an empty array never matches
+     * @param startIndex the index in the captured-packet list to start checking from
+     * @return true if any captured packet contains the plaintext as a contiguous byte sequence
+     */
+    public boolean hasPlaintextPacket(byte[] plaintext, int startIndex) {
+        // Arrays.asList(byte[]) yields a one-element List<byte[]>, so Collections.indexOfSubList
+        // would compare array references rather than contents; search the raw bytes instead.
+        return getFirstMatchingPacket((pkt) -> containsBytes(pkt, plaintext), startIndex) != null;
+    }
+
+    /** Returns true if needle is non-empty and occurs as a contiguous run inside haystack. */
+    private static boolean containsBytes(byte[] haystack, byte[] needle) {
+        if (needle.length == 0) {
+            return false;
+        }
+        for (int start = 0; start + needle.length <= haystack.length; start++) {
+            int matched = 0;
+            while (matched < needle.length && haystack[start + matched] == needle[matched]) {
+                matched++;
+            }
+            if (matched == needle.length) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /**
+     * Returns the first captured ESP packet with the given SPI at or after startIndex, or null.
+     *
+     * @param spi the SPI to look for
+     * @param encap whether to look for a UDP-encapsulated ESP packet (IPv4 only)
+     * @param startIndex the index in the captured-packet list to start checking from
+     */
+    public byte[] getEspPacket(int spi, boolean encap, int startIndex) {
+        return getFirstMatchingPacket(
+                (pkt) -> {
+                    return isEsp(pkt, spi, encap);
+                },
+                startIndex);
+    }
+
+    /**
+     * Waits up to TIMEOUT milliseconds for an ESP packet with the given SPI, then asserts on its
+     * size and that no captured packet contains the plaintext.
+     *
+     * @param spi the SPI to look for
+     * @param plaintext bytes that must not appear in any captured packet
+     * @param useEncap whether to look for a UDP-encapsulated ESP packet
+     * @param expectedPacketSize the exact length the ESP packet must have
+     * @return the matching ESP packet; fails the test if none arrives before the timeout
+     */
+    public byte[] awaitEspPacketNoPlaintext(
+            int spi, byte[] plaintext, boolean useEncap, int expectedPacketSize) throws Exception {
+        long endTime = System.currentTimeMillis() + TIMEOUT;
+        int startIndex = 0;
+
+        synchronized (mPackets) {
+            while (System.currentTimeMillis() < endTime) {
+                byte[] espPkt = getEspPacket(spi, useEncap, startIndex);
+                if (espPkt != null) {
+                    // Validate packet size
+                    assertEquals(expectedPacketSize, espPkt.length);
+
+                    // Always check plaintext from start
+                    assertFalse(hasPlaintextPacket(plaintext, 0));
+                    return espPkt; // We've found the packet we're looking for.
+                }
+
+                startIndex = mPackets.size();
+
+                // Try to prevent waiting too long. If waitTimeout <= 0, we've already hit timeout
+                long waitTimeout = endTime - System.currentTimeMillis();
+                if (waitTimeout > 0) {
+                    mPackets.wait(waitTimeout);
+                }
+            }
+
+            fail("No such ESP packet found with SPI " + spi);
+        }
+        return null;
+    }
+
+    /** Returns true if the four bytes at espOffset equal the SPI, most-significant byte first. */
+    private static boolean isSpiEqual(byte[] pkt, int espOffset, int spi) {
+        // Check SPI byte by byte.
+        return pkt[espOffset] == (byte) ((spi >>> 24) & 0xff)
+                && pkt[espOffset + 1] == (byte) ((spi >>> 16) & 0xff)
+                && pkt[espOffset + 2] == (byte) ((spi >>> 8) & 0xff)
+                && pkt[espOffset + 3] == (byte) (spi & 0xff);
+    }
+
+    /**
+     * Returns true if pkt is an ESP packet (optionally UDP-encapsulated on IPv4) carrying the
+     * given SPI. Packets too short to hold the SPI at the expected offset never match.
+     */
+    private static boolean isEsp(byte[] pkt, int spi, boolean encap) {
+        if (pkt.length == 0) {
+            return false;
+        }
+
+        final int protoOffset;
+        final int expectedProto;
+        final int espOffset;
+        if (isIpv6(pkt)) {
+            // IPv6 UDP encap not supported by kernels; assume non-encap.
+            protoOffset = IP6_PROTO_OFFSET;
+            expectedProto = IPPROTO_ESP;
+            espOffset = IP6_HDRLEN;
+        } else if (encap) {
+            // Use default IPv4 header length (assuming no options)
+            protoOffset = IP4_PROTO_OFFSET;
+            expectedProto = IPPROTO_UDP;
+            espOffset = IP4_HDRLEN + UDP_HDRLEN;
+        } else {
+            // Use default IPv4 header length (assuming no options)
+            protoOffset = IP4_PROTO_OFFSET;
+            expectedProto = IPPROTO_ESP;
+            espOffset = IP4_HDRLEN;
+        }
+
+        // Captured non-ESP traffic may be shorter than header + SPI; reject it instead of
+        // indexing past the end of the array.
+        if (pkt.length < espOffset + SPI_LEN) {
+            return false;
+        }
+        return pkt[protoOffset] == expectedProto && isSpiEqual(pkt, espOffset, spi);
+    }
+
+    /** Returns true if the version nibble of the first byte is 6; pkt must be non-empty. */
+    private static boolean isIpv6(byte[] pkt) {
+        // First nibble shows IP version. 0x60 for IPv6
+        return (pkt[0] & (byte) 0xF0) == (byte) 0x60;
+    }
+
+    /**
+     * Returns a copy of pkt with the IP source and destination addresses swapped. All other
+     * bytes, including any checksums, are copied unchanged.
+     */
+    private static byte[] getReflectedPacket(byte[] pkt) {
+        byte[] reflected = Arrays.copyOf(pkt, pkt.length);
+
+        if (isIpv6(pkt)) {
+            swapAddresses(pkt, reflected, IP6_ADDR_OFFSET, IP6_ADDR_LEN);
+        } else {
+            swapAddresses(pkt, reflected, IP4_ADDR_OFFSET, IP4_ADDR_LEN);
+        }
+        return reflected;
+    }
+
+    /**
+     * Writes the original's addresses into the reflected copy in swapped positions.
+     *
+     * @param srcAddrOffset offset of the source address; the destination address follows it
+     * @param addrLen length in bytes of each address
+     */
+    private static void swapAddresses(
+            byte[] original, byte[] reflected, int srcAddrOffset, int addrLen) {
+        final int dstAddrOffset = srcAddrOffset + addrLen;
+        // Set reflected packet's src IP to that of the original's dst IP
+        System.arraycopy(original, dstAddrOffset, reflected, srcAddrOffset, addrLen);
+        // Set reflected packet's dst IP to that of the original's src IP
+        System.arraycopy(original, srcAddrOffset, reflected, dstAddrOffset, addrLen);
+    }
+
+    /** Takes all captured packets, flips the src/dst, and re-injects them. */
+    public void reflectPackets() throws IOException {
+        synchronized (mPackets) {
+            for (byte[] pkt : mPackets) {
+                injectPacket(getReflectedPacket(pkt));
+            }
+        }
+    }
+
+    /**
+     * Writes a raw packet to the TUN descriptor.
+     *
+     * @param pkt the complete packet bytes to write
+     * @throws IOException if the write fails
+     */
+    public void injectPacket(byte[] pkt) throws IOException {
+        FileOutputStream out = new FileOutputStream(mTunFd.getFileDescriptor());
+        out.write(pkt);
+        out.flush();
+    }
+
+    /** Resets the intercepted packets. */
+    public void reset() throws IOException {
+        synchronized (mPackets) {
+            mPackets.clear();
+        }
+    }
+}