Add reflected-packet based data tests

This commit adds tests that reflect outgoing packets, flipping the outer
src/dst headers to avoid the need to tear down and rebuild the outer
TUN.

This allows us to at least test that our implementation can interoperate
with itself.

Bug: 72950854
Test: this, passing
Merged-In: Ia969f78f4c1a0c0a017f5aad425a68852ff4433a
Change-Id: Ia969f78f4c1a0c0a017f5aad425a68852ff4433a
(cherry picked from commit 144937f3df37ee0b1d5484f10e8c86a8a70a9cb5)
diff --git a/tests/cts/net/src/android/net/cts/IpSecManagerTunnelTest.java b/tests/cts/net/src/android/net/cts/IpSecManagerTunnelTest.java
index 828abcc..d1438ec 100644
--- a/tests/cts/net/src/android/net/cts/IpSecManagerTunnelTest.java
+++ b/tests/cts/net/src/android/net/cts/IpSecManagerTunnelTest.java
@@ -240,8 +240,16 @@
     }
 
     /* Test runnables for callbacks after IPsec tunnels are set up. */
-    private interface TestRunnable {
-        void run(Network ipsecNetwork) throws Exception;
+    private abstract class IpSecTunnelTestRunnable {
+        /**
+         * Runs the test code, and returns the inner socket port, if any.
+         *
+         * @param ipsecNetwork The IPsec Interface based Network for binding sockets on
+         * @return the integer port of the inner socket if outbound, or 0 if inbound
+         *     IpSecTunnelTestRunnable
+         * @throws Exception if any part of the test failed.
+         */
+        public abstract int run(Network ipsecNetwork) throws Exception;
     }
 
     private static class TestNetworkCallback extends ConnectivityManager.NetworkCallback {
@@ -288,8 +296,8 @@
         return expectedPacketSize;
     }
 
-    private interface TestRunnableFactory {
-        TestRunnable getTestRunnable(
+    private interface IpSecTunnelTestRunnableFactory {
+        IpSecTunnelTestRunnable getIpSecTunnelTestRunnable(
                 boolean transportInTunnelMode,
                 int spi,
                 InetAddress localInner,
@@ -299,12 +307,13 @@
                 IpSecTransform inTransportTransform,
                 IpSecTransform outTransportTransform,
                 int encapPort,
+                int innerSocketPort,
                 int expectedPacketSize)
                 throws Exception;
     }
 
-    private class OutputTestRunnableFactory implements TestRunnableFactory {
-        public TestRunnable getTestRunnable(
+    private class OutputIpSecTunnelTestRunnableFactory implements IpSecTunnelTestRunnableFactory {
+        public IpSecTunnelTestRunnable getIpSecTunnelTestRunnable(
                 boolean transportInTunnelMode,
                 int spi,
                 InetAddress localInner,
@@ -314,13 +323,15 @@
                 IpSecTransform inTransportTransform,
                 IpSecTransform outTransportTransform,
                 int encapPort,
+                int unusedInnerSocketPort,
                 int expectedPacketSize) {
-            return new TestRunnable() {
+            return new IpSecTunnelTestRunnable() {
                 @Override
-                public void run(Network ipsecNetwork) throws Exception {
+                public int run(Network ipsecNetwork) throws Exception {
                     // Build a socket and send traffic
                     JavaUdpSocket socket = new JavaUdpSocket(localInner);
                     ipsecNetwork.bindSocket(socket.mSocket);
+                    int innerSocketPort = socket.getPort();
 
                     // For Transport-In-Tunnel mode, apply transform to socket
                     if (transportInTunnelMode) {
@@ -333,19 +344,22 @@
                     socket.sendTo(TEST_DATA, remoteInner, socket.getPort());
 
                     // Verify that an encrypted packet is sent. As of right now, checking encrypted
-                    // body is not possible, due to our not knowing some of the fields of the
+                    // body is not possible, due to the test not knowing some of the fields of the
                     // inner IP header (flow label, flags, etc)
                     sTunUtils.awaitEspPacketNoPlaintext(
                             spi, TEST_DATA, encapPort != 0, expectedPacketSize);
 
                     socket.close();
+
+                    return innerSocketPort;
                 }
             };
         }
     }
 
-    private class InputPacketGeneratorTestRunnableFactory implements TestRunnableFactory {
-        public TestRunnable getTestRunnable(
+    private class InputReflectedIpSecTunnelTestRunnableFactory
+            implements IpSecTunnelTestRunnableFactory {
+        public IpSecTunnelTestRunnable getIpSecTunnelTestRunnable(
                 boolean transportInTunnelMode,
                 int spi,
                 InetAddress localInner,
@@ -355,14 +369,57 @@
                 IpSecTransform inTransportTransform,
                 IpSecTransform outTransportTransform,
                 int encapPort,
+                int innerSocketPort,
                 int expectedPacketSize)
                 throws Exception {
-            return new TestRunnable() {
+            return new IpSecTunnelTestRunnable() {
                 @Override
-                public void run(Network ipsecNetwork) throws Exception {
+                public int run(Network ipsecNetwork) throws Exception {
+                    // Build a socket and receive traffic
+                    JavaUdpSocket socket = new JavaUdpSocket(localInner, innerSocketPort);
+                    ipsecNetwork.bindSocket(socket.mSocket);
+
+                    // For Transport-In-Tunnel mode, apply transform to socket
+                    if (transportInTunnelMode) {
+                        mISM.applyTransportModeTransform(
+                                socket.mSocket, IpSecManager.DIRECTION_IN, outTransportTransform);
+                        mISM.applyTransportModeTransform(
+                                socket.mSocket, IpSecManager.DIRECTION_OUT, inTransportTransform);
+                    }
+
+                    sTunUtils.reflectPackets();
+
+                    // Receive packet from socket, and validate that the payload is correct
+                    receiveAndValidatePacket(socket);
+
+                    socket.close();
+
+                    return 0;
+                }
+            };
+        }
+    }
+
+    private class InputPacketGeneratorIpSecTunnelTestRunnableFactory
+            implements IpSecTunnelTestRunnableFactory {
+        public IpSecTunnelTestRunnable getIpSecTunnelTestRunnable(
+                boolean transportInTunnelMode,
+                int spi,
+                InetAddress localInner,
+                InetAddress remoteInner,
+                InetAddress localOuter,
+                InetAddress remoteOuter,
+                IpSecTransform inTransportTransform,
+                IpSecTransform outTransportTransform,
+                int encapPort,
+                int innerSocketPort,
+                int expectedPacketSize)
+                throws Exception {
+            return new IpSecTunnelTestRunnable() {
+                @Override
+                public int run(Network ipsecNetwork) throws Exception {
                     // Build a socket and receive traffic
                     JavaUdpSocket socket = new JavaUdpSocket(localInner);
-                    // JavaUdpSocket socket = new JavaUdpSocket(localInner, socketPort.get());
                     ipsecNetwork.bindSocket(socket.mSocket);
 
                     // For Transport-In-Tunnel mode, apply transform to socket
@@ -402,6 +459,8 @@
                     receiveAndValidatePacket(socket);
 
                     socket.close();
+
+                    return 0;
                 }
             };
         }
@@ -415,7 +474,7 @@
                 outerFamily,
                 useEncap,
                 transportInTunnelMode,
-                new OutputTestRunnableFactory());
+                new OutputIpSecTunnelTestRunnableFactory());
     }
 
     private void checkTunnelInput(
@@ -426,7 +485,91 @@
                 outerFamily,
                 useEncap,
                 transportInTunnelMode,
-                new InputPacketGeneratorTestRunnableFactory());
+                new InputPacketGeneratorIpSecTunnelTestRunnableFactory());
+    }
+
+    /**
+     * Validates that the kernel can talk to itself.
+     *
+     * <p>This test takes an outbound IPsec packet, reflects it (by flipping IP src/dst), and
+     * injects it back into the TUN. This test then verifies that a packet with the correct payload
+     * is found on the specified socket/port.
+     */
+    public void checkTunnelReflected(
+            int innerFamily, int outerFamily, boolean useEncap, boolean transportInTunnelMode)
+            throws Exception {
+        if (!hasTunnelsFeature()) return;
+
+        InetAddress localInner = innerFamily == AF_INET ? LOCAL_INNER_4 : LOCAL_INNER_6;
+        InetAddress remoteInner = innerFamily == AF_INET ? REMOTE_INNER_4 : REMOTE_INNER_6;
+
+        InetAddress localOuter = outerFamily == AF_INET ? LOCAL_OUTER_4 : LOCAL_OUTER_6;
+        InetAddress remoteOuter = outerFamily == AF_INET ? REMOTE_OUTER_4 : REMOTE_OUTER_6;
+
+        // Preselect both SPI and encap port, to be used for both inbound and outbound tunnels.
+        int spi = getRandomSpi(localOuter, remoteOuter);
+        int expectedPacketSize =
+                getPacketSize(innerFamily, outerFamily, useEncap, transportInTunnelMode);
+
+        try (IpSecManager.SecurityParameterIndex inTransportSpi =
+                        mISM.allocateSecurityParameterIndex(localInner, spi);
+                IpSecManager.SecurityParameterIndex outTransportSpi =
+                        mISM.allocateSecurityParameterIndex(remoteInner, spi);
+                IpSecTransform inTransportTransform =
+                        buildIpSecTransform(sContext, inTransportSpi, null, remoteInner);
+                IpSecTransform outTransportTransform =
+                        buildIpSecTransform(sContext, outTransportSpi, null, localInner);
+                UdpEncapsulationSocket encapSocket = mISM.openUdpEncapsulationSocket()) {
+
+            // Run output direction tests
+            IpSecTunnelTestRunnable outputIpSecTunnelTestRunnable =
+                    new OutputIpSecTunnelTestRunnableFactory()
+                            .getIpSecTunnelTestRunnable(
+                                    transportInTunnelMode,
+                                    spi,
+                                    localInner,
+                                    remoteInner,
+                                    localOuter,
+                                    remoteOuter,
+                                    inTransportTransform,
+                                    outTransportTransform,
+                                    useEncap ? encapSocket.getPort() : 0,
+                                    0,
+                                    expectedPacketSize);
+            int innerSocketPort =
+                    buildTunnelNetworkAndRunTests(
+                    localInner,
+                    remoteInner,
+                    localOuter,
+                    remoteOuter,
+                    spi,
+                    useEncap ? encapSocket : null,
+                    outputIpSecTunnelTestRunnable);
+
+            // Input direction tests, with matching inner socket ports.
+            IpSecTunnelTestRunnable inputIpSecTunnelTestRunnable =
+                    new InputReflectedIpSecTunnelTestRunnableFactory()
+                            .getIpSecTunnelTestRunnable(
+                                    transportInTunnelMode,
+                                    spi,
+                                    remoteInner,
+                                    localInner,
+                                    localOuter,
+                                    remoteOuter,
+                                    inTransportTransform,
+                                    outTransportTransform,
+                                    useEncap ? encapSocket.getPort() : 0,
+                                    innerSocketPort,
+                                    expectedPacketSize);
+            buildTunnelNetworkAndRunTests(
+                    remoteInner,
+                    localInner,
+                    localOuter,
+                    remoteOuter,
+                    spi,
+                    useEncap ? encapSocket : null,
+                    inputIpSecTunnelTestRunnable);
+        }
     }
 
     public void checkTunnel(
@@ -434,7 +577,7 @@
             int outerFamily,
             boolean useEncap,
             boolean transportInTunnelMode,
-            TestRunnableFactory factory)
+            IpSecTunnelTestRunnableFactory factory)
             throws Exception {
         if (!hasTunnelsFeature()) return;
 
@@ -461,14 +604,14 @@
                         buildIpSecTransform(sContext, outTransportSpi, null, localInner);
                 UdpEncapsulationSocket encapSocket = mISM.openUdpEncapsulationSocket()) {
 
-            buildTunnelAndNetwork(
+            buildTunnelNetworkAndRunTests(
                     localInner,
                     remoteInner,
                     localOuter,
                     remoteOuter,
                     spi,
                     useEncap ? encapSocket : null,
-                    factory.getTestRunnable(
+                    factory.getIpSecTunnelTestRunnable(
                             transportInTunnelMode,
                             spi,
                             localInner,
@@ -478,21 +621,23 @@
                             inTransportTransform,
                             outTransportTransform,
                             useEncap ? encapSocket.getPort() : 0,
+                            0,
                             expectedPacketSize));
         }
     }
 
-    private void buildTunnelAndNetwork(
+    private int buildTunnelNetworkAndRunTests(
             InetAddress localInner,
             InetAddress remoteInner,
             InetAddress localOuter,
             InetAddress remoteOuter,
             int spi,
             UdpEncapsulationSocket encapSocket,
-            TestRunnable test)
+            IpSecTunnelTestRunnable test)
             throws Exception {
         int innerPrefixLen = localInner instanceof Inet6Address ? IP6_PREFIX_LEN : IP4_PREFIX_LEN;
         TestNetworkCallback testNetworkCb = null;
+        int innerSocketPort;
 
         try (IpSecManager.SecurityParameterIndex inSpi =
                         mISM.allocateSecurityParameterIndex(localOuter, spi);
@@ -534,7 +679,7 @@
                 mISM.applyTunnelModeTransform(
                         tunnelIface, IpSecManager.DIRECTION_OUT, outTransform);
 
-                test.run(testNetwork);
+                innerSocketPort = test.run(testNetwork);
             }
 
             // Teardown the test network
@@ -553,6 +698,8 @@
                 sCM.unregisterNetworkCallback(testNetworkCb);
             }
         }
+
+        return innerSocketPort;
     }
 
     private static void receiveAndValidatePacket(JavaUdpSocket socket) throws Exception {
@@ -676,35 +823,65 @@
     }
 
     @Test
+    public void testTransportInTunnelModeV4InV4Reflected() throws Exception {
+        checkTunnelReflected(AF_INET, AF_INET, false, true);
+    }
+
+    @Test
     public void testTransportInTunnelModeV4InV4UdpEncap() throws Exception {
         checkTunnelOutput(AF_INET, AF_INET, true, true);
         checkTunnelInput(AF_INET, AF_INET, true, true);
     }
 
     @Test
+    public void testTransportInTunnelModeV4InV4UdpEncapReflected() throws Exception {
+        checkTunnelReflected(AF_INET, AF_INET, false, true);
+    }
+
+    @Test
     public void testTransportInTunnelModeV4InV6() throws Exception {
         checkTunnelOutput(AF_INET, AF_INET6, false, true);
         checkTunnelInput(AF_INET, AF_INET6, false, true);
     }
 
     @Test
+    public void testTransportInTunnelModeV4InV6Reflected() throws Exception {
+        checkTunnelReflected(AF_INET, AF_INET, false, true);
+    }
+
+    @Test
     public void testTransportInTunnelModeV6InV4() throws Exception {
         checkTunnelOutput(AF_INET6, AF_INET, false, true);
         checkTunnelInput(AF_INET6, AF_INET, false, true);
     }
 
     @Test
+    public void testTransportInTunnelModeV6InV4Reflected() throws Exception {
+        checkTunnelReflected(AF_INET, AF_INET, false, true);
+    }
+
+    @Test
     public void testTransportInTunnelModeV6InV4UdpEncap() throws Exception {
         checkTunnelOutput(AF_INET6, AF_INET, true, true);
         checkTunnelInput(AF_INET6, AF_INET, true, true);
     }
 
     @Test
+    public void testTransportInTunnelModeV6InV4UdpEncapReflected() throws Exception {
+        checkTunnelReflected(AF_INET, AF_INET, false, true);
+    }
+
+    @Test
     public void testTransportInTunnelModeV6InV6() throws Exception {
         checkTunnelOutput(AF_INET, AF_INET6, false, true);
         checkTunnelInput(AF_INET, AF_INET6, false, true);
     }
 
+    @Test
+    public void testTransportInTunnelModeV6InV6Reflected() throws Exception {
+        checkTunnelReflected(AF_INET, AF_INET, false, true);
+    }
+
     // Tunnel mode tests
     @Test
     public void testTunnelV4InV4() throws Exception {
@@ -713,32 +890,62 @@
     }
 
     @Test
+    public void testTunnelV4InV4Reflected() throws Exception {
+        checkTunnelReflected(AF_INET, AF_INET, false, false);
+    }
+
+    @Test
     public void testTunnelV4InV4UdpEncap() throws Exception {
         checkTunnelOutput(AF_INET, AF_INET, true, false);
         checkTunnelInput(AF_INET, AF_INET, true, false);
     }
 
     @Test
+    public void testTunnelV4InV4UdpEncapReflected() throws Exception {
+        checkTunnelReflected(AF_INET, AF_INET, true, false);
+    }
+
+    @Test
     public void testTunnelV4InV6() throws Exception {
         checkTunnelOutput(AF_INET, AF_INET6, false, false);
         checkTunnelInput(AF_INET, AF_INET6, false, false);
     }
 
     @Test
+    public void testTunnelV4InV6Reflected() throws Exception {
+        checkTunnelReflected(AF_INET, AF_INET6, false, false);
+    }
+
+    @Test
     public void testTunnelV6InV4() throws Exception {
         checkTunnelOutput(AF_INET6, AF_INET, false, false);
         checkTunnelInput(AF_INET6, AF_INET, false, false);
     }
 
     @Test
+    public void testTunnelV6InV4Reflected() throws Exception {
+        checkTunnelReflected(AF_INET6, AF_INET, false, false);
+    }
+
+    @Test
     public void testTunnelV6InV4UdpEncap() throws Exception {
         checkTunnelOutput(AF_INET6, AF_INET, true, false);
         checkTunnelInput(AF_INET6, AF_INET, true, false);
     }
 
     @Test
+    public void testTunnelV6InV4UdpEncapReflected() throws Exception {
+        checkTunnelReflected(AF_INET6, AF_INET, true, false);
+    }
+
+    @Test
     public void testTunnelV6InV6() throws Exception {
         checkTunnelOutput(AF_INET6, AF_INET6, false, false);
         checkTunnelInput(AF_INET6, AF_INET6, false, false);
     }
+
+    @Test
+    public void testTunnelV6InV6Reflected() throws Exception {
+        checkTunnelReflected(AF_INET6, AF_INET6, false, false);
+    }
 }