CTS for request transform state

Test: atest IpSecManagerTunnelTest (new tests)
Test: atest IpSecManagerTest (new tests)
Bug: 326837293
Change-Id: Ie060546afdabbb9c72df0f4a00938805a972c03c
diff --git a/tests/cts/net/src/android/net/cts/IpSecBaseTest.java b/tests/cts/net/src/android/net/cts/IpSecBaseTest.java
index 7f710d7..2a6c638 100644
--- a/tests/cts/net/src/android/net/cts/IpSecBaseTest.java
+++ b/tests/cts/net/src/android/net/cts/IpSecBaseTest.java
@@ -26,12 +26,15 @@
 import static android.system.OsConstants.FIONREAD;
 
 import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
 
 import android.content.Context;
 import android.net.ConnectivityManager;
 import android.net.IpSecAlgorithm;
 import android.net.IpSecManager;
 import android.net.IpSecTransform;
+import android.net.IpSecTransformState;
+import android.os.OutcomeReceiver;
 import android.platform.test.annotations.AppModeFull;
 import android.system.ErrnoException;
 import android.system.Os;
@@ -65,8 +68,12 @@
 import java.net.SocketImpl;
 import java.net.SocketOptions;
 import java.util.Arrays;
+import java.util.BitSet;
 import java.util.HashSet;
 import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 
 @RunWith(AndroidJUnit4.class)
@@ -83,6 +90,7 @@
     protected static final byte[] TEST_DATA = "Best test data ever!".getBytes();
     protected static final int DATA_BUFFER_LEN = 4096;
     protected static final int SOCK_TIMEOUT = 500;
+    protected static final int REPLAY_BITMAP_LEN_BYTE = 512;
 
     private static final byte[] KEY_DATA = {
         0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
@@ -122,6 +130,47 @@
                                 .getSystemService(Context.CONNECTIVITY_SERVICE);
     }
 
+    protected static void checkTransformState(
+            IpSecTransform transform,
+            long txHighestSeqNum,
+            long rxHighestSeqNum,
+            long packetCnt,
+            long byteCnt,
+            byte[] replayBitmap)
+            throws Exception {
+        final CompletableFuture<IpSecTransformState> futureIpSecTransform =
+                new CompletableFuture<>();
+        transform.requestIpSecTransformState(
+                Executors.newSingleThreadExecutor(),
+                new OutcomeReceiver<IpSecTransformState, RuntimeException>() {
+                    @Override
+                    public void onResult(IpSecTransformState state) {
+                        futureIpSecTransform.complete(state);
+                    }
+                });
+
+        final IpSecTransformState transformState =
+                futureIpSecTransform.get(SOCK_TIMEOUT, TimeUnit.MILLISECONDS);
+
+        assertEquals(txHighestSeqNum, transformState.getTxHighestSequenceNumber());
+        assertEquals(rxHighestSeqNum, transformState.getRxHighestSequenceNumber());
+        assertEquals(packetCnt, transformState.getPacketCount());
+        assertEquals(byteCnt, transformState.getByteCount());
+        assertArrayEquals(replayBitmap, transformState.getReplayBitmap());
+    }
+
+    protected static void checkTransformStateNoTraffic(IpSecTransform transform) throws Exception {
+        checkTransformState(transform, 0L, 0L, 0L, 0L, newReplayBitmap(0));
+    }
+
+    protected static byte[] newReplayBitmap(int receivedPktCnt) {
+        final BitSet bitSet = new BitSet(REPLAY_BITMAP_LEN_BYTE * 8);
+        for (int i = 0; i < receivedPktCnt; i++) {
+            bitSet.set(i);
+        }
+        return Arrays.copyOf(bitSet.toByteArray(), REPLAY_BITMAP_LEN_BYTE);
+    }
+
     /** Checks if an IPsec algorithm is enabled on the device */
     protected static boolean hasIpSecAlgorithm(String algorithm) {
         if (SdkLevel.isAtLeastS()) {
diff --git a/tests/cts/net/src/android/net/cts/IpSecManagerTest.java b/tests/cts/net/src/android/net/cts/IpSecManagerTest.java
index fe86a90..a40ed0f 100644
--- a/tests/cts/net/src/android/net/cts/IpSecManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/IpSecManagerTest.java
@@ -63,11 +63,13 @@
 import static org.junit.Assert.fail;
 import static org.junit.Assume.assumeTrue;
 
+import android.net.InetAddresses;
 import android.net.IpSecAlgorithm;
 import android.net.IpSecManager;
 import android.net.IpSecManager.SecurityParameterIndex;
 import android.net.IpSecManager.UdpEncapsulationSocket;
 import android.net.IpSecTransform;
+import android.net.NetworkUtils;
 import android.net.TrafficStats;
 import android.os.Build;
 import android.platform.test.annotations.AppModeFull;
@@ -381,6 +383,22 @@
         assumeTrue("Not supported by kernel", isIpv6UdpEncapSupportedByKernel());
     }
 
+    // TODO: b/319532485 Figure out whether to support x86_32
+    private static boolean isRequestTransformStateSupportedByKernel() {
+        return NetworkUtils.isKernel64Bit() || !NetworkUtils.isKernelX86();
+    }
+
+    // Package private for use in IpSecManagerTunnelTest
+    static boolean isRequestTransformStateSupported() {
+        return SdkLevel.isAtLeastV() && isRequestTransformStateSupportedByKernel();
+    }
+
+    // Package private for use in IpSecManagerTunnelTest
+    static void assumeRequestIpSecTransformStateSupported() {
+        assumeTrue("Not supported before V", SdkLevel.isAtLeastV());
+        assumeTrue("Not supported by kernel", isRequestTransformStateSupportedByKernel());
+    }
+
     @Test
     public void testCreateTransformIpv4() throws Exception {
         doTestCreateTransform(IPV4_LOOPBACK, false);
@@ -1596,4 +1614,32 @@
             assertTrue("Returned invalid port", encapSocket.getPort() != 0);
         }
     }
+
+    @IgnoreUpTo(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
+    @Test
+    public void testRequestIpSecTransformState() throws Exception {
+        assumeRequestIpSecTransformStateSupported();
+
+        final InetAddress localAddr = InetAddresses.parseNumericAddress(IPV6_LOOPBACK);
+        try (SecurityParameterIndex spi = mISM.allocateSecurityParameterIndex(localAddr);
+                IpSecTransform transform =
+                        buildTransportModeTransform(spi, localAddr, null /* encapSocket*/)) {
+            final SocketPair<JavaUdpSocket> sockets =
+                    getJavaUdpSocketPair(localAddr, mISM, transform, false);
+
+            sockets.mLeftSock.sendTo(TEST_DATA, localAddr, sockets.mRightSock.getPort());
+            sockets.mRightSock.receive();
+
+            final int expectedPacketCount = 1;
+            final int expectedInnerPacketSize = TEST_DATA.length + UDP_HDRLEN;
+
+            checkTransformState(
+                    transform,
+                    expectedPacketCount,
+                    expectedPacketCount,
+                    2 * (long) expectedPacketCount,
+                    2 * (long) expectedInnerPacketSize,
+                    newReplayBitmap(expectedPacketCount));
+        }
+    }
 }
diff --git a/tests/cts/net/src/android/net/cts/IpSecManagerTunnelTest.java b/tests/cts/net/src/android/net/cts/IpSecManagerTunnelTest.java
index 1ede5c1..22a51d6 100644
--- a/tests/cts/net/src/android/net/cts/IpSecManagerTunnelTest.java
+++ b/tests/cts/net/src/android/net/cts/IpSecManagerTunnelTest.java
@@ -19,7 +19,9 @@
 import static android.app.AppOpsManager.OP_MANAGE_IPSEC_TUNNELS;
 import static android.net.IpSecManager.UdpEncapsulationSocket;
 import static android.net.cts.IpSecManagerTest.assumeExperimentalIpv6UdpEncapSupported;
+import static android.net.cts.IpSecManagerTest.assumeRequestIpSecTransformStateSupported;
 import static android.net.cts.IpSecManagerTest.isIpv6UdpEncapSupported;
+import static android.net.cts.IpSecManagerTest.isRequestTransformStateSupported;
 import static android.net.cts.PacketUtils.AES_CBC_BLK_SIZE;
 import static android.net.cts.PacketUtils.AES_CBC_IV_LEN;
 import static android.net.cts.PacketUtils.BytePayload;
@@ -117,6 +119,8 @@
 
     private static final int TIMEOUT_MS = 500;
 
+    private static final int PACKET_COUNT = 5000;
+
     // Static state to reduce setup/teardown
     private static ConnectivityManager sCM;
     private static TestNetworkManager sTNM;
@@ -256,7 +260,7 @@
     }
 
     /* Test runnables for callbacks after IPsec tunnels are set up. */
-    private abstract class IpSecTunnelTestRunnable {
+    private interface IpSecTunnelTestRunnable {
         /**
          * Runs the test code, and returns the inner socket port, if any.
          *
@@ -282,8 +286,7 @@
                 throws Exception;
     }
 
-    private int getPacketSize(
-            int innerFamily, int outerFamily, boolean useEncap, boolean transportInTunnelMode) {
+    private static int getInnerPacketSize(int innerFamily, boolean transportInTunnelMode) {
         int expectedPacketSize = TEST_DATA.length + UDP_HDRLEN;
 
         // Inner Transport mode packet size
@@ -299,6 +302,13 @@
         // Inner IP Header
         expectedPacketSize += innerFamily == AF_INET ? IP4_HDRLEN : IP6_HDRLEN;
 
+        return expectedPacketSize;
+    }
+
+    private static int getPacketSize(
+            int innerFamily, int outerFamily, boolean useEncap, boolean transportInTunnelMode) {
+        int expectedPacketSize = getInnerPacketSize(innerFamily, transportInTunnelMode);
+
         // Tunnel mode transform size
         expectedPacketSize =
                 PacketUtils.calculateEspPacketSize(
@@ -401,6 +411,20 @@
                             spi, TEST_DATA, useEncap, expectedPacketSize);
                     socket.close();
 
+                    if (isRequestTransformStateSupported()) {
+                        final int innerPacketSize =
+                                getInnerPacketSize(innerFamily, transportInTunnelMode);
+
+                        checkTransformState(
+                                outTunnelTransform,
+                                seqNum,
+                                0L,
+                                seqNum,
+                                seqNum * (long) innerPacketSize,
+                                newReplayBitmap(0));
+                        checkTransformStateNoTraffic(inTunnelTransform);
+                    }
+
                     return innerSocketPort;
                 }
             };
@@ -524,6 +548,22 @@
 
                     socket.close();
 
+                    if (isRequestTransformStateSupported()) {
+                        final int innerFamily =
+                                localInner instanceof Inet4Address ? AF_INET : AF_INET6;
+                        final int innerPacketSize =
+                                getInnerPacketSize(innerFamily, transportInTunnelMode);
+
+                        checkTransformStateNoTraffic(outTunnelTransform);
+                        checkTransformState(
+                                inTunnelTransform,
+                                0L,
+                                seqNum,
+                                seqNum,
+                                seqNum * (long) innerPacketSize,
+                                newReplayBitmap(seqNum));
+                    }
+
                     return 0;
                 }
             };
@@ -1127,6 +1167,18 @@
         return innerSocketPort;
     }
 
+    private int buildTunnelNetworkAndRunTestsSimple(int spi, IpSecTunnelTestRunnable test)
+            throws Exception {
+        return buildTunnelNetworkAndRunTests(
+                LOCAL_INNER_6,
+                REMOTE_INNER_6,
+                LOCAL_OUTER_6,
+                REMOTE_OUTER_6,
+                spi,
+                null /* encapSocket */,
+                test);
+    }
+
     private static void receiveAndValidatePacket(JavaUdpSocket socket) throws Exception {
         byte[] socketResponseBytes = socket.receive();
         assertArrayEquals(TEST_DATA, socketResponseBytes);
@@ -1691,4 +1743,101 @@
         assumeExperimentalIpv6UdpEncapSupported();
         doTestMigrateTunnelModeTransform(AF_INET6, AF_INET6, true, false);
     }
+
+    @IgnoreUpTo(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
+    @Test
+    public void testRequestIpSecTransformStateForRx() throws Exception {
+        assumeRequestIpSecTransformStateSupported();
+
+        final int spi = getRandomSpi(LOCAL_OUTER_6, REMOTE_OUTER_6);
+        buildTunnelNetworkAndRunTestsSimple(
+                spi,
+                (ipsecNetwork,
+                        tunnelIface,
+                        tunUtils,
+                        inTunnelTransform,
+                        outTunnelTransform,
+                        localOuter,
+                        remoteOuter,
+                        seqNum) -> {
+                    // Build a socket and send traffic
+                    final JavaUdpSocket socket = new JavaUdpSocket(LOCAL_INNER_6);
+                    ipsecNetwork.bindSocket(socket.mSocket);
+                    int innerSocketPort = socket.getPort();
+
+                    for (int i = 1; i < PACKET_COUNT + 1; i++) {
+                        byte[] pkt =
+                                getTunnelModePacket(
+                                        spi,
+                                        REMOTE_INNER_6,
+                                        LOCAL_INNER_6,
+                                        remoteOuter,
+                                        localOuter,
+                                        innerSocketPort,
+                                        0,
+                                        i);
+                        tunUtils.injectPacket(pkt);
+                        receiveAndValidatePacket(socket);
+                    }
+
+                    final int innerPacketSize = getInnerPacketSize(AF_INET6, false);
+                    checkTransformState(
+                            inTunnelTransform,
+                            0L,
+                            PACKET_COUNT,
+                            PACKET_COUNT,
+                            PACKET_COUNT * (long) innerPacketSize,
+                            newReplayBitmap(REPLAY_BITMAP_LEN_BYTE * 8));
+
+                    return innerSocketPort;
+                });
+    }
+
+    @IgnoreUpTo(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
+    @Test
+    public void testRequestIpSecTransformStateForTx() throws Exception {
+        assumeRequestIpSecTransformStateSupported();
+
+        final int spi = getRandomSpi(LOCAL_OUTER_6, REMOTE_OUTER_6);
+        buildTunnelNetworkAndRunTestsSimple(
+                spi,
+                (ipsecNetwork,
+                        tunnelIface,
+                        tunUtils,
+                        inTunnelTransform,
+                        outTunnelTransform,
+                        localOuter,
+                        remoteOuter,
+                        seqNum) -> {
+                    // Build a socket and send traffic
+                    final JavaUdpSocket outSocket = new JavaUdpSocket(LOCAL_INNER_6);
+                    ipsecNetwork.bindSocket(outSocket.mSocket);
+                    int innerSocketPort = outSocket.getPort();
+
+                    int expectedPacketSize =
+                            getPacketSize(
+                                    AF_INET6,
+                                    AF_INET6,
+                                    false /* useEncap */,
+                                    false /* transportInTunnelMode */);
+
+                    for (int i = 0; i < PACKET_COUNT; i++) {
+                        outSocket.sendTo(TEST_DATA, REMOTE_INNER_6, innerSocketPort);
+                        tunUtils.awaitEspPacketNoPlaintext(
+                                spi, TEST_DATA, false /* useEncap */, expectedPacketSize);
+                    }
+
+                    final int innerPacketSize =
+                            getInnerPacketSize(AF_INET6, false /* transportInTunnelMode */);
+                    checkTransformState(
+                            outTunnelTransform,
+                            PACKET_COUNT,
+                            0L,
+                            PACKET_COUNT,
+                            PACKET_COUNT * (long) innerPacketSize,
+                            newReplayBitmap(0));
+
+                    return innerSocketPort;
+                });
+    }
 }