Verify kernel implementation of ChaCha20Poly1305

This CL adds the test infra for testing kernel implementation of IPsec
algorithms and a test for ChaCha20Poly1305

Since there is no hardware that first launched with SDK beyond R at the
time of writing this CL, the test for ChaChaPoly was manually
enabled and verified on the pixel with an updated kernel

Bug: 171083832
Test: atest IpSecAlgorithmImplTest
Original-Change: https://android-review.googlesource.com/1503694
Merged-In: Ia29540c7fd6848a89bfa2d25c6a87921e45d98da
Change-Id: Ia29540c7fd6848a89bfa2d25c6a87921e45d98da
diff --git a/tests/cts/net/src/android/net/cts/IpSecAlgorithmImplTest.java b/tests/cts/net/src/android/net/cts/IpSecAlgorithmImplTest.java
new file mode 100644
index 0000000..3b110a4
--- /dev/null
+++ b/tests/cts/net/src/android/net/cts/IpSecAlgorithmImplTest.java
@@ -0,0 +1,225 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.cts;
+
+import static android.net.IpSecAlgorithm.AUTH_CRYPT_CHACHA20_POLY1305;
+import static android.net.cts.PacketUtils.CHACHA20_POLY1305;
+import static android.net.cts.PacketUtils.CHACHA20_POLY1305_BLK_SIZE;
+import static android.net.cts.PacketUtils.CHACHA20_POLY1305_ICV_LEN;
+import static android.net.cts.PacketUtils.CHACHA20_POLY1305_IV_LEN;
+import static android.net.cts.PacketUtils.CHACHA20_POLY1305_KEY_LEN;
+import static android.net.cts.PacketUtils.CHACHA20_POLY1305_SALT_LEN;
+import static android.net.cts.PacketUtils.ESP_HDRLEN;
+import static android.net.cts.PacketUtils.IP6_HDRLEN;
+import static android.net.cts.PacketUtils.getIpHeader;
+import static android.net.cts.util.CtsNetUtils.TestNetworkCallback;
+
+import static com.android.compatibility.common.util.SystemUtil.runWithShellPermissionIdentity;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assume.assumeTrue;
+
+import android.net.IpSecAlgorithm;
+import android.net.IpSecManager;
+import android.net.IpSecTransform;
+import android.net.Network;
+import android.net.TestNetworkInterface;
+import android.net.cts.PacketUtils.BytePayload;
+import android.net.cts.PacketUtils.EspAeadCipher;
+import android.net.cts.PacketUtils.EspAuth;
+import android.net.cts.PacketUtils.EspAuthNull;
+import android.net.cts.PacketUtils.EspCipher;
+import android.net.cts.PacketUtils.EspHeader;
+import android.net.cts.PacketUtils.IpHeader;
+import android.net.cts.PacketUtils.UdpHeader;
+import android.platform.test.annotations.AppModeFull;
+
+import androidx.test.InstrumentationRegistry;
+import androidx.test.runner.AndroidJUnit4;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import java.net.DatagramPacket;
+import java.net.DatagramSocket;
+import java.net.InetAddress;
+import java.util.Arrays;
+
+@RunWith(AndroidJUnit4.class)
+@AppModeFull(reason = "Socket cannot bind in instant app mode")
+public class IpSecAlgorithmImplTest extends IpSecBaseTest {
+    private static final InetAddress LOCAL_ADDRESS =
+            InetAddress.parseNumericAddress("2001:db8:1::1");
+    private static final InetAddress REMOTE_ADDRESS =
+            InetAddress.parseNumericAddress("2001:db8:1::2");
+
+    private static final int REMOTE_PORT = 12345;
+    private static final IpSecManager IPSEC_MANAGER =
+            InstrumentationRegistry.getContext().getSystemService(IpSecManager.class);
+
+    private static class CheckCryptoImplTest implements TestNetworkRunnable.Test {
+        private final IpSecAlgorithm mIpsecEncryptAlgo;
+        private final IpSecAlgorithm mIpsecAuthAlgo;
+        private final IpSecAlgorithm mIpsecAeadAlgo;
+        private final EspCipher mEspCipher;
+        private final EspAuth mEspAuth;
+
+        public CheckCryptoImplTest(
+                IpSecAlgorithm ipsecEncryptAlgo,
+                IpSecAlgorithm ipsecAuthAlgo,
+                IpSecAlgorithm ipsecAeadAlgo,
+                EspCipher espCipher,
+                EspAuth espAuth) {
+            mIpsecEncryptAlgo = ipsecEncryptAlgo;
+            mIpsecAuthAlgo = ipsecAuthAlgo;
+            mIpsecAeadAlgo = ipsecAeadAlgo;
+            mEspCipher = espCipher;
+            mEspAuth = espAuth;
+        }
+
+        private static byte[] buildTransportModeEspPayload(
+                int srcPort, int dstPort, int spi, EspCipher espCipher, EspAuth espAuth)
+                throws Exception {
+            final UdpHeader udpPayload =
+                    new UdpHeader(srcPort, dstPort, new BytePayload(TEST_DATA));
+            final IpHeader preEspIpHeader =
+                    getIpHeader(
+                            udpPayload.getProtocolId(), LOCAL_ADDRESS, REMOTE_ADDRESS, udpPayload);
+
+            final PacketUtils.EspHeader espPayload =
+                    new EspHeader(
+                            udpPayload.getProtocolId(),
+                            spi,
+                            1 /* sequence number */,
+                            udpPayload.getPacketBytes(preEspIpHeader),
+                            espCipher,
+                            espAuth);
+            return espPayload.getPacketBytes(preEspIpHeader);
+        }
+
+        @Override
+        public void runTest(TestNetworkInterface testIface, TestNetworkCallback tunNetworkCallback)
+                throws Exception {
+            final TunUtils tunUtils = new TunUtils(testIface.getFileDescriptor());
+            tunNetworkCallback.waitForAvailable();
+            final Network testNetwork = tunNetworkCallback.currentNetwork;
+
+            final IpSecTransform.Builder transformBuilder =
+                    new IpSecTransform.Builder(InstrumentationRegistry.getContext());
+            if (mIpsecAeadAlgo != null) {
+                transformBuilder.setAuthenticatedEncryption(mIpsecAeadAlgo);
+            } else {
+                if (mIpsecEncryptAlgo != null) {
+                    transformBuilder.setEncryption(mIpsecEncryptAlgo);
+                }
+                if (mIpsecAuthAlgo != null) {
+                    transformBuilder.setAuthentication(mIpsecAuthAlgo);
+                }
+            }
+
+            try (final IpSecManager.SecurityParameterIndex outSpi =
+                            IPSEC_MANAGER.allocateSecurityParameterIndex(REMOTE_ADDRESS);
+                    final IpSecManager.SecurityParameterIndex inSpi =
+                            IPSEC_MANAGER.allocateSecurityParameterIndex(LOCAL_ADDRESS);
+                    final IpSecTransform outTransform =
+                            transformBuilder.buildTransportModeTransform(LOCAL_ADDRESS, outSpi);
+                    final IpSecTransform inTransform =
+                            transformBuilder.buildTransportModeTransform(REMOTE_ADDRESS, inSpi);
+                    // Bind localSocket to a random available port.
+                    final DatagramSocket localSocket = new DatagramSocket(0)) {
+                IPSEC_MANAGER.applyTransportModeTransform(
+                        localSocket, IpSecManager.DIRECTION_IN, inTransform);
+                IPSEC_MANAGER.applyTransportModeTransform(
+                        localSocket, IpSecManager.DIRECTION_OUT, outTransform);
+
+                // Send ESP packet
+                final DatagramPacket outPacket =
+                        new DatagramPacket(
+                                TEST_DATA, 0, TEST_DATA.length, REMOTE_ADDRESS, REMOTE_PORT);
+                testNetwork.bindSocket(localSocket);
+                localSocket.send(outPacket);
+                final byte[] outEspPacket =
+                        tunUtils.awaitEspPacket(outSpi.getSpi(), false /* useEncap */);
+
+                // Remove transform for good hygiene
+                IPSEC_MANAGER.removeTransportModeTransforms(localSocket);
+
+                // Get the kernel-generated ESP payload
+                final byte[] outEspPayload = new byte[outEspPacket.length - IP6_HDRLEN];
+                System.arraycopy(outEspPacket, IP6_HDRLEN, outEspPayload, 0, outEspPayload.length);
+
+                // Get the IV of the kernel-generated ESP payload
+                final byte[] iv =
+                        Arrays.copyOfRange(
+                                outEspPayload, ESP_HDRLEN, ESP_HDRLEN + mEspCipher.ivLen);
+
+                // Build ESP payload using the kernel-generated IV and the user space crypto
+                // implementations
+                mEspCipher.updateIv(iv);
+                final byte[] expectedEspPayload =
+                        buildTransportModeEspPayload(
+                                localSocket.getLocalPort(),
+                                REMOTE_PORT,
+                                outSpi.getSpi(),
+                                mEspCipher,
+                                mEspAuth);
+
+                // Compare user-space-generated and kernel-generated ESP payload
+                assertArrayEquals(expectedEspPayload, outEspPayload);
+            }
+        }
+
+        @Override
+        public void cleanupTest() {
+            // Do nothing
+        }
+
+        @Override
+        public InetAddress[] getTestNetworkAddresses() {
+            return new InetAddress[] {LOCAL_ADDRESS};
+        }
+    }
+
+    @Test
+    public void testChaCha20Poly1305() throws Exception {
+        assumeTrue(hasIpSecAlgorithm(AUTH_CRYPT_CHACHA20_POLY1305));
+
+        final byte[] cryptKey = getKeyBytes(CHACHA20_POLY1305_KEY_LEN);
+        final IpSecAlgorithm ipsecAeadAlgo =
+                new IpSecAlgorithm(
+                        IpSecAlgorithm.AUTH_CRYPT_CHACHA20_POLY1305,
+                        cryptKey,
+                        CHACHA20_POLY1305_ICV_LEN * 8);
+        final EspAeadCipher espAead =
+                new EspAeadCipher(
+                        CHACHA20_POLY1305,
+                        CHACHA20_POLY1305_BLK_SIZE,
+                        cryptKey,
+                        CHACHA20_POLY1305_IV_LEN,
+                        CHACHA20_POLY1305_ICV_LEN,
+                        CHACHA20_POLY1305_SALT_LEN);
+
+        runWithShellPermissionIdentity(
+                new TestNetworkRunnable(
+                        new CheckCryptoImplTest(
+                                null /* ipsecEncryptAlgo */,
+                                null /* ipsecAuthAlgo */,
+                                ipsecAeadAlgo,
+                                espAead,
+                                EspAuthNull.getInstance())));
+    }
+}
diff --git a/tests/cts/net/src/android/net/cts/PacketUtils.java b/tests/cts/net/src/android/net/cts/PacketUtils.java
index 5da0d26..27c9f3b 100644
--- a/tests/cts/net/src/android/net/cts/PacketUtils.java
+++ b/tests/cts/net/src/android/net/cts/PacketUtils.java
@@ -43,6 +43,7 @@
     static final int UDP_HDRLEN = 8;
     static final int TCP_HDRLEN = 20;
     static final int TCP_HDRLEN_WITH_TIMESTAMP_OPT = TCP_HDRLEN + 12;
+    static final int ESP_HDRLEN = 8;
     static final int ESP_BLK_SIZE = 4; // ESP has to be 4-byte aligned
     static final int ESP_TRAILER_LEN = 2;
 
@@ -61,8 +62,10 @@
     // AEAD parameters
     static final int AES_GCM_IV_LEN = 8;
     static final int AES_GCM_BLK_SIZE = 4;
+    static final int CHACHA20_POLY1305_KEY_LEN = 36;
     static final int CHACHA20_POLY1305_BLK_SIZE = ESP_BLK_SIZE;
     static final int CHACHA20_POLY1305_IV_LEN = 8;
+    static final int CHACHA20_POLY1305_SALT_LEN = 4;
     static final int CHACHA20_POLY1305_ICV_LEN = 16;
 
     // Authentication parameters
@@ -77,6 +80,11 @@
     // Encryption algorithms
     static final String AES = "AES";
     static final String AES_CBC = "AES/CBC/NoPadding";
+
+    // AEAD algorithms
+    static final String CHACHA20_POLY1305 = "ChaCha20/Poly1305/NoPadding";
+
+    // Authentication algorithms
     static final String HMAC_SHA_256 = "HmacSHA256";
 
     public interface Payload {
@@ -372,6 +380,11 @@
             if (cipher instanceof EspCipherNull && auth instanceof EspAuthNull) {
                 throw new IllegalArgumentException("No algorithm is provided");
             }
+
+            if (cipher instanceof EspAeadCipher && !(auth instanceof EspAuthNull)) {
+                throw new IllegalArgumentException(
+                        "AEAD is provided with an authentication" + " algorithm.");
+            }
         }
 
         private static EspCipher getDefaultCipher(byte[] key) {
@@ -387,8 +400,10 @@
         }
 
         public short length() {
+            final int icvLen =
+                    cipher instanceof EspAeadCipher ? ((EspAeadCipher) cipher).icvLen : auth.icvLen;
             return calculateEspPacketSize(
-                    payload.length, cipher.ivLen, cipher.blockSize, auth.icvLen * 8);
+                    payload.length, cipher.ivLen, cipher.blockSize, icvLen * 8);
         }
 
         public byte[] getPacketBytes(IpHeader header) throws Exception {
@@ -426,12 +441,11 @@
 
     public static short calculateEspPacketSize(
             int payloadLen, int cryptIvLength, int cryptBlockSize, int authTruncLen) {
-        final int ESP_HDRLEN = 4 + 4; // SPI + Seq#
         final int ICV_LEN = authTruncLen / 8; // Auth trailer; based on truncation length
-        payloadLen += cryptIvLength; // Initialization Vector
 
         // Align to block size of encryption algorithm
         payloadLen = calculateEspEncryptedLength(payloadLen, cryptBlockSize);
+        payloadLen += cryptIvLength; // Initialization Vector
         return (short) (payloadLen + ESP_HDRLEN + ICV_LEN);
     }
 
@@ -464,20 +478,28 @@
     }
 
     public abstract static class EspCipher {
+        protected static final int SALT_LEN_UNUSED = 0;
+
         public final String algoName;
         public final int blockSize;
         public final byte[] key;
         public final int ivLen;
+        public final int saltLen;
         protected byte[] iv;
 
-        public EspCipher(String algoName, int blockSize, byte[] key, int ivLen) {
+        public EspCipher(String algoName, int blockSize, byte[] key, int ivLen, int saltLen) {
             this.algoName = algoName;
             this.blockSize = blockSize;
             this.key = key;
             this.ivLen = ivLen;
+            this.saltLen = saltLen;
             this.iv = getIv(ivLen);
         }
 
+        public void updateIv(byte[] iv) {
+            this.iv = iv;
+        }
+
         public static byte[] getPaddedPayload(int nextHeader, byte[] payload, int blockSize) {
             final int paddedLen = calculateEspEncryptedLength(payload.length, blockSize);
             final ByteBuffer paddedPayload = ByteBuffer.allocate(paddedLen);
@@ -514,7 +536,7 @@
         private static final EspCipherNull INSTANCE = new EspCipherNull();
 
         private EspCipherNull() {
-            super(CRYPT_NULL, ESP_BLK_SIZE, KEY_UNUSED, IV_LEN_UNUSED);
+            super(CRYPT_NULL, ESP_BLK_SIZE, KEY_UNUSED, IV_LEN_UNUSED, SALT_LEN_UNUSED);
         }
 
         public static EspCipherNull getInstance() {
@@ -530,7 +552,7 @@
 
     public static class EspCryptCipher extends EspCipher {
         public EspCryptCipher(String algoName, int blockSize, byte[] key, int ivLen) {
-            super(algoName, blockSize, key, ivLen);
+            super(algoName, blockSize, key, ivLen, SALT_LEN_UNUSED);
         }
 
         @Override
@@ -554,7 +576,50 @@
         }
     }
 
-    // TODO: Implement EspAeadCipher in the following CL
+    public static class EspAeadCipher extends EspCipher {
+        public final int icvLen;
+
+        public EspAeadCipher(
+                String algoName, int blockSize, byte[] key, int ivLen, int icvLen, int saltLen) {
+            super(algoName, blockSize, key, ivLen, saltLen);
+            this.icvLen = icvLen;
+        }
+
+        @Override
+        public byte[] getCipherText(int nextHeader, byte[] payload, int spi, int seqNum)
+                throws GeneralSecurityException {
+            // Provided key consists of encryption/decryption key plus salt. Salt is used
+            // with ESP payload IV to build IvParameterSpec.
+            final byte[] secretKey = Arrays.copyOfRange(key, 0, key.length - saltLen);
+            final byte[] salt = Arrays.copyOfRange(key, secretKey.length, key.length);
+
+            final SecretKeySpec secretKeySpec = new SecretKeySpec(secretKey, algoName);
+
+            final ByteBuffer ivParameterBuffer = ByteBuffer.allocate(saltLen + iv.length);
+            ivParameterBuffer.put(salt);
+            ivParameterBuffer.put(iv);
+            final IvParameterSpec ivParameterSpec = new IvParameterSpec(ivParameterBuffer.array());
+
+            final ByteBuffer aadBuffer = ByteBuffer.allocate(ESP_HDRLEN);
+            aadBuffer.putInt(spi);
+            aadBuffer.putInt(seqNum);
+
+            // Encrypt payload
+            final Cipher cipher = Cipher.getInstance(algoName);
+            cipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, ivParameterSpec);
+            cipher.updateAAD(aadBuffer.array());
+            final byte[] encryptedTextAndIcv =
+                    cipher.doFinal(getPaddedPayload(nextHeader, payload, blockSize));
+
+            // Build ciphertext
+            final ByteBuffer cipherText =
+                    ByteBuffer.allocate(iv.length + encryptedTextAndIcv.length);
+            cipherText.put(iv);
+            cipherText.put(encryptedTextAndIcv);
+
+            return getByteArrayFromBuffer(cipherText);
+        }
+    }
 
     public static class EspAuth {
         public final String algoName;
diff --git a/tests/cts/net/src/android/net/cts/TunUtils.java b/tests/cts/net/src/android/net/cts/TunUtils.java
index 7887385..d8e39b4 100644
--- a/tests/cts/net/src/android/net/cts/TunUtils.java
+++ b/tests/cts/net/src/android/net/cts/TunUtils.java
@@ -147,6 +147,10 @@
         return espPkt; // We've found the packet we're looking for.
     }
 
+    public byte[] awaitEspPacket(int spi, boolean useEncap) throws Exception {
+        return awaitPacket((pkt) -> isEsp(pkt, spi, useEncap));
+    }
+
     private static boolean isSpiEqual(byte[] pkt, int espOffset, int spi) {
         // Check SPI byte by byte.
         return pkt[espOffset] == (byte) ((spi >>> 24) & 0xff)