Merge changes from topic "experimental-udp-encap" am: 5d8c3b1428

Original change: https://android-review.googlesource.com/c/platform/packages/modules/Connectivity/+/2422042

Change-Id: I19e71b76fa02b5d4ebc7fb56b439613d0c4d9011
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
diff --git a/service-t/src/com/android/server/IpSecService.java b/service-t/src/com/android/server/IpSecService.java
index 9e71eb3..1a51e91 100644
--- a/service-t/src/com/android/server/IpSecService.java
+++ b/service-t/src/com/android/server/IpSecService.java
@@ -17,6 +17,7 @@
 package com.android.server;
 
 import static android.Manifest.permission.DUMP;
+import static android.Manifest.permission.NETWORK_SETTINGS;
 import static android.net.IpSecManager.FEATURE_IPSEC_TUNNEL_MIGRATION;
 import static android.net.IpSecManager.INVALID_RESOURCE_ID;
 import static android.system.OsConstants.AF_INET;
@@ -65,6 +66,7 @@
 import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.Preconditions;
+import com.android.modules.utils.build.SdkLevel;
 import com.android.net.module.util.BinderUtils;
 import com.android.net.module.util.NetdUtils;
 import com.android.net.module.util.PermissionUtils;
@@ -102,6 +104,7 @@
 
     private static final int NETD_FETCH_TIMEOUT_MS = 5000; // ms
     private static final InetAddress INADDR_ANY;
+    private static final InetAddress IN6ADDR_ANY;
 
     @VisibleForTesting static final int MAX_PORT_BIND_ATTEMPTS = 10;
 
@@ -110,6 +113,8 @@
     static {
         try {
             INADDR_ANY = InetAddress.getByAddress(new byte[] {0, 0, 0, 0});
+            IN6ADDR_ANY = InetAddress.getByAddress(
+                    new byte[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
         } catch (UnknownHostException e) {
             throw new RuntimeException(e);
         }
@@ -1013,11 +1018,13 @@
     private final class EncapSocketRecord extends OwnedResourceRecord {
         private FileDescriptor mSocket;
         private final int mPort;
+        private final int mFamily;  // AF_INET or AF_INET6. TODO: what about IPV6_ADDRFORM?
 
-        EncapSocketRecord(int resourceId, FileDescriptor socket, int port) {
+        EncapSocketRecord(int resourceId, FileDescriptor socket, int port, int family) {
             super(resourceId);
             mSocket = socket;
             mPort = port;
+            mFamily = family;
         }
 
         /** always guarded by IpSecService#this */
@@ -1038,6 +1045,10 @@
             return mSocket;
         }
 
+        public int getFamily() {
+            return mFamily;
+        }
+
         @Override
         protected ResourceTracker getResourceTracker() {
             return getUserRecord().mSocketQuotaTracker;
@@ -1210,15 +1221,16 @@
      * and re-binding, during which the system could *technically* hand that port out to someone
      * else.
      */
-    private int bindToRandomPort(FileDescriptor sockFd) throws IOException {
+    private int bindToRandomPort(FileDescriptor sockFd, int family) throws IOException {
+        final InetAddress any = (family == AF_INET6) ? IN6ADDR_ANY : INADDR_ANY;
         for (int i = MAX_PORT_BIND_ATTEMPTS; i > 0; i--) {
             try {
-                FileDescriptor probeSocket = Os.socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
-                Os.bind(probeSocket, INADDR_ANY, 0);
+                FileDescriptor probeSocket = Os.socket(family, SOCK_DGRAM, IPPROTO_UDP);
+                Os.bind(probeSocket, any, 0);
                 int port = ((InetSocketAddress) Os.getsockname(probeSocket)).getPort();
                 Os.close(probeSocket);
                 Log.v(TAG, "Binding to port " + port);
-                Os.bind(sockFd, INADDR_ANY, port);
+                Os.bind(sockFd, any, port);
                 return port;
             } catch (ErrnoException e) {
                 // Someone miraculously claimed the port just after we closed probeSocket.
@@ -1260,6 +1272,19 @@
     @Override
     public synchronized IpSecUdpEncapResponse openUdpEncapsulationSocket(int port, IBinder binder)
             throws RemoteException {
+        // Experimental IPv6 UDP encap: privileged callers on U+ request it as (port + 65536).
+        final int family;
+        final InetAddress localAddr;
+        if (SdkLevel.isAtLeastU() && port >= 65536) {
+            PermissionUtils.enforceNetworkStackPermissionOr(mContext, NETWORK_SETTINGS);
+            port -= 65536;
+            family = AF_INET6;
+            localAddr = IN6ADDR_ANY;
+        } else {
+            family = AF_INET;
+            localAddr = INADDR_ANY;
+        }
+
         if (port != 0 && (port < FREE_PORT_MIN || port > PORT_MAX)) {
             throw new IllegalArgumentException(
                     "Specified port number must be a valid non-reserved UDP port");
@@ -1278,7 +1303,7 @@
 
             FileDescriptor sockFd = null;
             try {
-                sockFd = Os.socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
+                sockFd = Os.socket(family, SOCK_DGRAM, IPPROTO_UDP);
                 pFd = ParcelFileDescriptor.dup(sockFd);
             } finally {
                 IoUtils.closeQuietly(sockFd);
@@ -1295,15 +1320,16 @@
             mNetd.ipSecSetEncapSocketOwner(pFd, callingUid);
             if (port != 0) {
                 Log.v(TAG, "Binding to port " + port);
-                Os.bind(pFd.getFileDescriptor(), INADDR_ANY, port);
+                Os.bind(pFd.getFileDescriptor(), localAddr, port);
             } else {
-                port = bindToRandomPort(pFd.getFileDescriptor());
+                port = bindToRandomPort(pFd.getFileDescriptor(), family);
             }
 
             userRecord.mEncapSocketRecords.put(
                     resourceId,
                     new RefcountedResource<EncapSocketRecord>(
-                            new EncapSocketRecord(resourceId, pFd.getFileDescriptor(), port),
+                            new EncapSocketRecord(resourceId, pFd.getFileDescriptor(), port,
+                                    family),
                             binder));
             return new IpSecUdpEncapResponse(IpSecManager.Status.OK, resourceId, port,
                     pFd.getFileDescriptor());
@@ -1580,6 +1606,7 @@
      */
     private void checkIpSecConfig(IpSecConfig config) {
         UserRecord userRecord = mUserResourceTracker.getUserRecord(Binder.getCallingUid());
+        EncapSocketRecord encapSocketRecord = null;
 
         switch (config.getEncapType()) {
             case IpSecTransform.ENCAP_NONE:
@@ -1587,7 +1614,7 @@
             case IpSecTransform.ENCAP_ESPINUDP:
             case IpSecTransform.ENCAP_ESPINUDP_NON_IKE:
                 // Retrieve encap socket record; will throw IllegalArgumentException if not found
-                userRecord.mEncapSocketRecords.getResourceOrThrow(
+                encapSocketRecord = userRecord.mEncapSocketRecords.getResourceOrThrow(
                         config.getEncapSocketResourceId());
 
                 int port = config.getEncapRemotePort();
@@ -1641,10 +1668,9 @@
                             + ") have different address families.");
         }
 
-        // Throw an error if UDP Encapsulation is not used in IPv4.
-        if (config.getEncapType() != IpSecTransform.ENCAP_NONE && sourceFamily != AF_INET) {
+        if (encapSocketRecord != null && encapSocketRecord.getFamily() != destinationFamily) {
             throw new IllegalArgumentException(
-                    "UDP Encapsulation is not supported for this address family");
+                    "UDP encapsulation socket and destination address families must match");
         }
 
         switch (config.getMode()) {
diff --git a/tests/cts/net/src/android/net/cts/IpSecManagerTest.java b/tests/cts/net/src/android/net/cts/IpSecManagerTest.java
index 8234ec1..4fa0080 100644
--- a/tests/cts/net/src/android/net/cts/IpSecManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/IpSecManagerTest.java
@@ -16,6 +16,7 @@
 
 package android.net.cts;
 
+import static android.Manifest.permission.NETWORK_SETTINGS;
 import static android.net.IpSecAlgorithm.AUTH_AES_CMAC;
 import static android.net.IpSecAlgorithm.AUTH_AES_XCBC;
 import static android.net.IpSecAlgorithm.AUTH_CRYPT_AES_GCM;
@@ -52,7 +53,9 @@
 
 import static com.android.compatibility.common.util.PropertyUtil.getFirstApiLevel;
 import static com.android.compatibility.common.util.PropertyUtil.getVendorApiLevel;
+import static com.android.testutils.DeviceInfoUtils.isKernelVersionAtLeast;
 import static com.android.testutils.MiscAsserts.assertThrows;
+import static com.android.testutils.TestPermissionUtil.runAsShell;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
@@ -62,6 +65,8 @@
 
 import android.net.IpSecAlgorithm;
 import android.net.IpSecManager;
+import android.net.IpSecManager.SecurityParameterIndex;
+import android.net.IpSecManager.UdpEncapsulationSocket;
 import android.net.IpSecTransform;
 import android.net.TrafficStats;
 import android.os.Build;
@@ -73,6 +78,7 @@
 import androidx.test.InstrumentationRegistry;
 import androidx.test.runner.AndroidJUnit4;
 
+import com.android.modules.utils.build.SdkLevel;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;
 
@@ -120,7 +126,7 @@
     @Test
     public void testAllocSpi() throws Exception {
         for (InetAddress addr : GOOGLE_DNS_LIST) {
-            IpSecManager.SecurityParameterIndex randomSpi = null, droidSpi = null;
+            SecurityParameterIndex randomSpi, droidSpi;
             randomSpi = mISM.allocateSecurityParameterIndex(addr);
             assertTrue(
                     "Failed to receive a valid SPI",
@@ -258,6 +264,24 @@
         accepted.close();
     }
 
+    private IpSecTransform buildTransportModeTransform(
+            SecurityParameterIndex spi, InetAddress localAddr,
+            UdpEncapsulationSocket encapSocket)
+            throws Exception {
+        final IpSecTransform.Builder builder =
+                new IpSecTransform.Builder(InstrumentationRegistry.getContext())
+                        .setEncryption(new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY))
+                        .setAuthentication(
+                                new IpSecAlgorithm(
+                                        IpSecAlgorithm.AUTH_HMAC_SHA256,
+                                        AUTH_KEY,
+                                        AUTH_KEY.length * 8));
+        if (encapSocket != null) {
+            builder.setIpv4Encapsulation(encapSocket, encapSocket.getPort());
+        }
+        return builder.buildTransportModeTransform(localAddr, spi);
+    }
+
     /*
      * Alloc outbound SPI
      * Alloc inbound SPI
@@ -268,21 +292,8 @@
      * release transform
      * send data (expect exception)
      */
-    @Test
-    public void testCreateTransform() throws Exception {
-        InetAddress localAddr = InetAddress.getByName(IPV4_LOOPBACK);
-        IpSecManager.SecurityParameterIndex spi =
-                mISM.allocateSecurityParameterIndex(localAddr);
-
-        IpSecTransform transform =
-                new IpSecTransform.Builder(InstrumentationRegistry.getContext())
-                        .setEncryption(new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY))
-                        .setAuthentication(
-                                new IpSecAlgorithm(
-                                        IpSecAlgorithm.AUTH_HMAC_SHA256,
-                                        AUTH_KEY,
-                                        AUTH_KEY.length * 8))
-                        .buildTransportModeTransform(localAddr, spi);
+    private void doTestCreateTransform(String loopbackAddrString, boolean encap) throws Exception {
+        InetAddress localAddr = InetAddress.getByName(loopbackAddrString);
 
         final boolean [][] applyInApplyOut = {
                 {false, false}, {false, true}, {true, false}, {true,true}};
@@ -291,50 +302,93 @@
 
         byte[] in = new byte[data.length];
         DatagramPacket inPacket = new DatagramPacket(in, in.length);
-        DatagramSocket localSocket;
         int localPort;
 
         for(boolean[] io : applyInApplyOut) {
             boolean applyIn = io[0];
             boolean applyOut = io[1];
-            // Bind localSocket to a random available port.
-            localSocket = new DatagramSocket(0);
-            localPort = localSocket.getLocalPort();
-            localSocket.setSoTimeout(200);
-            outPacket.setPort(localPort);
-            if (applyIn) {
-                mISM.applyTransportModeTransform(
-                        localSocket, IpSecManager.DIRECTION_IN, transform);
-            }
-            if (applyOut) {
-                mISM.applyTransportModeTransform(
-                        localSocket, IpSecManager.DIRECTION_OUT, transform);
-            }
-            if (applyIn == applyOut) {
-                localSocket.send(outPacket);
-                localSocket.receive(inPacket);
-                assertTrue("Encapsulated data did not match.",
-                        Arrays.equals(outPacket.getData(), inPacket.getData()));
-                mISM.removeTransportModeTransforms(localSocket);
-                localSocket.close();
-            } else {
-                try {
+            try (
+                SecurityParameterIndex spi = mISM.allocateSecurityParameterIndex(localAddr);
+                UdpEncapsulationSocket encapSocket = encap
+                        ? getPrivilegedUdpEncapSocket(/*ipv6=*/ localAddr instanceof Inet6Address)
+                        : null;
+                IpSecTransform transform = buildTransportModeTransform(spi, localAddr,
+                        encapSocket);
+                // Bind localSocket to a random available port.
+                DatagramSocket localSocket = new DatagramSocket(0);
+            ) {
+                localPort = localSocket.getLocalPort();
+                localSocket.setSoTimeout(200);
+                outPacket.setPort(localPort);
+                if (applyIn) {
+                    mISM.applyTransportModeTransform(
+                            localSocket, IpSecManager.DIRECTION_IN, transform);
+                }
+                if (applyOut) {
+                    mISM.applyTransportModeTransform(
+                            localSocket, IpSecManager.DIRECTION_OUT, transform);
+                }
+                if (applyIn == applyOut) {
                     localSocket.send(outPacket);
                     localSocket.receive(inPacket);
-                } catch (IOException e) {
-                    continue;
-                } finally {
+                    assertTrue("Encrypted data did not match.",
+                            Arrays.equals(outPacket.getData(), inPacket.getData()));
                     mISM.removeTransportModeTransforms(localSocket);
-                    localSocket.close();
+                } else {
+                    try {
+                        localSocket.send(outPacket);
+                        localSocket.receive(inPacket);
+                    } catch (IOException e) {
+                        continue;
+                    } finally {
+                        mISM.removeTransportModeTransforms(localSocket);
+                    }
+                    // FIXME: This check is disabled because sockets currently receive data
+                    // if there is a valid SA for decryption, even when the input policy is
+                    // not applied to a socket.
+                    //  fail("Data IO should fail on asymmetrical transforms! + Input="
+                    //          + applyIn + " Output=" + applyOut);
                 }
-                // FIXME: This check is disabled because sockets currently receive data
-                // if there is a valid SA for decryption, even when the input policy is
-                // not applied to a socket.
-                //  fail("Data IO should fail on asymmetrical transforms! + Input="
-                //          + applyIn + " Output=" + applyOut);
             }
         }
-        transform.close();
+    }
+
+    private UdpEncapsulationSocket getPrivilegedUdpEncapSocket(boolean ipv6) throws Exception {
+        return runAsShell(NETWORK_SETTINGS, () -> {
+            if (ipv6) {
+                return mISM.openUdpEncapsulationSocket(65536);
+            } else {
+                // Can't pass 0 to IpSecManager#openUdpEncapsulationSocket(int).
+                return mISM.openUdpEncapsulationSocket();
+            }
+        });
+    }
+
+    private void assumeExperimentalIpv6UdpEncapSupported() throws Exception {
+        assumeTrue("Not supported before U", SdkLevel.isAtLeastU());
+        assumeTrue("Not supported by kernel", isKernelVersionAtLeast("5.15.31")
+                || (isKernelVersionAtLeast("5.10.108") && !isKernelVersionAtLeast("5.15.0")));
+    }
+
+    @Test
+    public void testCreateTransformIpv4() throws Exception {
+        doTestCreateTransform(IPV4_LOOPBACK, false);
+    }
+
+    @Test
+    public void testCreateTransformIpv6() throws Exception {
+        doTestCreateTransform(IPV6_LOOPBACK, false);
+    }
+
+    @Test
+    public void testCreateTransformIpv4Encap() throws Exception {
+        doTestCreateTransform(IPV4_LOOPBACK, true);
+    }
+
+    @Test
+    public void testCreateTransformIpv6Encap() throws Exception {
+        assumeExperimentalIpv6UdpEncapSupported();
+        doTestCreateTransform(IPV6_LOOPBACK, true);
     }
 
     /** Snapshot of TrafficStats as of initStatsChecker call for later comparisons */
@@ -503,8 +557,8 @@
         StatsChecker.initStatsChecker();
         InetAddress local = InetAddress.getByName(localAddress);
 
-        try (IpSecManager.UdpEncapsulationSocket encapSocket = mISM.openUdpEncapsulationSocket();
-                IpSecManager.SecurityParameterIndex spi =
+        try (UdpEncapsulationSocket encapSocket = mISM.openUdpEncapsulationSocket();
+                SecurityParameterIndex spi =
                         mISM.allocateSecurityParameterIndex(local)) {
 
             IpSecTransform.Builder transformBuilder =
@@ -656,7 +710,7 @@
     public void testIkeOverUdpEncapSocket() throws Exception {
         // IPv6 not supported for UDP-encap-ESP
         InetAddress local = InetAddress.getByName(IPV4_LOOPBACK);
-        try (IpSecManager.UdpEncapsulationSocket encapSocket = mISM.openUdpEncapsulationSocket()) {
+        try (UdpEncapsulationSocket encapSocket = mISM.openUdpEncapsulationSocket()) {
             NativeUdpSocket wrappedEncapSocket =
                     new NativeUdpSocket(encapSocket.getFileDescriptor());
             checkIkePacket(wrappedEncapSocket, local);
@@ -665,7 +719,7 @@
             IpSecAlgorithm crypt = new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY);
             IpSecAlgorithm auth = new IpSecAlgorithm(IpSecAlgorithm.AUTH_HMAC_MD5, getKey(128), 96);
 
-            try (IpSecManager.SecurityParameterIndex spi =
+            try (SecurityParameterIndex spi =
                             mISM.allocateSecurityParameterIndex(local);
                     IpSecTransform transform =
                             new IpSecTransform.Builder(InstrumentationRegistry.getContext())
@@ -1498,7 +1552,7 @@
 
     @Test
     public void testOpenUdpEncapSocketSpecificPort() throws Exception {
-        IpSecManager.UdpEncapsulationSocket encapSocket = null;
+        UdpEncapsulationSocket encapSocket = null;
         int port = -1;
         for (int i = 0; i < MAX_PORT_BIND_ATTEMPTS; i++) {
             try {
@@ -1527,7 +1581,7 @@
 
     @Test
     public void testOpenUdpEncapSocketRandomPort() throws Exception {
-        try (IpSecManager.UdpEncapsulationSocket encapSocket = mISM.openUdpEncapsulationSocket()) {
+        try (UdpEncapsulationSocket encapSocket = mISM.openUdpEncapsulationSocket()) {
             assertTrue("Returned invalid port", encapSocket.getPort() != 0);
         }
     }
diff --git a/tests/unit/java/com/android/server/IpSecServiceTest.java b/tests/unit/java/com/android/server/IpSecServiceTest.java
index 6955620..4b6857c 100644
--- a/tests/unit/java/com/android/server/IpSecServiceTest.java
+++ b/tests/unit/java/com/android/server/IpSecServiceTest.java
@@ -82,7 +82,7 @@
     private static final int MAX_NUM_ENCAP_SOCKETS = 100;
     private static final int MAX_NUM_SPIS = 100;
     private static final int TEST_UDP_ENCAP_INVALID_PORT = 100;
-    private static final int TEST_UDP_ENCAP_PORT_OUT_RANGE = 100000;
+    private static final int TEST_UDP_ENCAP_PORT_OUT_RANGE = 200000;
 
     private static final InetAddress INADDR_ANY;