Add a test for closing sockets when a VPN comes up.

Bug: 28251576
Change-Id: Iab0a8643cff3c54eb04168a7cdfa116c0b8e30b1
diff --git a/tests/cts/hostside/app/Android.mk b/tests/cts/hostside/app/Android.mk
index 7f8da07..9519ec5 100644
--- a/tests/cts/hostside/app/Android.mk
+++ b/tests/cts/hostside/app/Android.mk
@@ -20,7 +20,11 @@
 
 LOCAL_MODULE_TAGS := tests
 LOCAL_SDK_VERSION := current
-LOCAL_STATIC_JAVA_LIBRARIES := ctsdeviceutil ctstestrunner ub-uiautomator
+LOCAL_STATIC_JAVA_LIBRARIES := \
+        ctsdeviceutil \
+        ctstestrunner \
+        ub-uiautomator \
+        CtsHostsideNetworkTestsAidl
 
 LOCAL_SRC_FILES := $(call all-java-files-under, src)
 
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/RemoteSocketFactoryClient.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/RemoteSocketFactoryClient.java
new file mode 100644
index 0000000..799fe50
--- /dev/null
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/RemoteSocketFactoryClient.java
@@ -0,0 +1,142 @@
+/*
+ * Copyright (C) 2016 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.cts.net.hostside;
+
+import android.content.ComponentName;
+import android.content.Context;
+import android.content.Intent;
+import android.content.ServiceConnection;
+import android.os.ConditionVariable;
+import android.os.IBinder;
+import android.os.ParcelFileDescriptor;
+import android.os.RemoteException;
+import android.system.ErrnoException;
+import android.system.Os;
+
+import com.android.cts.net.hostside.IRemoteSocketFactory;
+
+import java.io.FileDescriptor;
+import java.io.IOException;
+
+/**
+ * Test-side client for the {@code RemoteSocketFactoryService} exported by the {@code .app2}
+ * package. It lets a test have sockets opened by that other app, so the test can check how
+ * the system treats sockets the test app itself does not own.
+ *
+ * <p>Call {@link #bind} before any other method, and {@link #unbind} when done.
+ */
+public class RemoteSocketFactoryClient {
+    private static final int TIMEOUT_MS = 5000;
+    private static final String PACKAGE = RemoteSocketFactoryClient.class.getPackage().getName();
+    private static final String APP2_PACKAGE = PACKAGE + ".app2";
+    private static final String SERVICE_NAME = APP2_PACKAGE + ".RemoteSocketFactoryService";
+
+    private final Context mContext;
+    private ServiceConnection mServiceConnection;
+    // Written by the ServiceConnection callbacks (main thread), read by the calling thread.
+    private volatile IRemoteSocketFactory mService;
+
+    public RemoteSocketFactoryClient(Context context) {
+        mContext = context;
+    }
+
+    /**
+     * Binds to the remote service, waiting up to {@code TIMEOUT_MS} for it to connect.
+     *
+     * @throws IllegalStateException if already bound, or if the service did not connect in
+     *         time (in which case the pending binding is released before throwing).
+     */
+    public void bind() {
+        if (mService != null) {
+            throw new IllegalStateException("Already bound");
+        }
+
+        final ConditionVariable cv = new ConditionVariable();
+        mServiceConnection = new ServiceConnection() {
+            @Override
+            public void onServiceConnected(ComponentName name, IBinder service) {
+                mService = IRemoteSocketFactory.Stub.asInterface(service);
+                cv.open();
+            }
+            @Override
+            public void onServiceDisconnected(ComponentName name) {
+                mService = null;
+            }
+        };
+
+        final Intent intent = new Intent();
+        intent.setComponent(new ComponentName(APP2_PACKAGE, SERVICE_NAME));
+        // Whatever bindService() returns, the connection must eventually be released with
+        // unbindService(); the timeout path below does that so the binding is not leaked.
+        mContext.bindService(intent, mServiceConnection, Context.BIND_AUTO_CREATE);
+        cv.block(TIMEOUT_MS);
+        if (mService == null) {
+            unbind();
+            throw new IllegalStateException(
+                    "Could not bind to RemoteSocketFactory service after " + TIMEOUT_MS + "ms");
+        }
+    }
+
+    /**
+     * Releases the binding made by {@link #bind}. Safe to call more than once, and releases the
+     * binding even if the remote service has disconnected in the meantime.
+     */
+    public void unbind() {
+        if (mServiceConnection != null) {
+            mContext.unbindService(mServiceConnection);
+            mServiceConnection = null;
+        }
+        mService = null;
+    }
+
+    /**
+     * Asks the remote app to open a socket to {@code host}:{@code port} using {@code timeoutMs}.
+     *
+     * @return a new descriptor for the socket, owned by the caller, who must close it.
+     */
+    public FileDescriptor openSocketFd(String host, int port, int timeoutMs)
+            throws RemoteException, ErrnoException, IOException {
+        final ParcelFileDescriptor pfd = getService().openSocketFd(host, port, timeoutMs);
+        try {
+            // Return a dup rather than pfd.getFileDescriptor(): the ParcelFileDescriptor closes
+            // its fd when it is closed or garbage-collected, which would invalidate the returned
+            // descriptor (or, worse, leave it pointing at whatever reuses that fd number).
+            return Os.dup(pfd.getFileDescriptor());
+        } finally {
+            pfd.close();
+        }
+    }
+
+    /** Returns the package name reported by the remote service. */
+    public String getPackageName() throws RemoteException {
+        return getService().getPackageName();
+    }
+
+    /** Returns the UID reported by the remote service. */
+    public int getUid() throws RemoteException {
+        return getService().getUid();
+    }
+
+    /** Returns the bound service, or throws IllegalStateException if not currently bound. */
+    private IRemoteSocketFactory getService() {
+        final IRemoteSocketFactory service = mService;
+        if (service == null) {
+            throw new IllegalStateException("Not bound to RemoteSocketFactory service");
+        }
+        return service;
+    }
+}
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
index 5045cc2..12fe625 100755
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
@@ -27,6 +27,8 @@
 import android.net.NetworkCapabilities;
 import android.net.NetworkRequest;
 import android.net.VpnService;
+import android.os.ParcelFileDescriptor;
+import android.os.Process;
 import android.support.test.uiautomator.UiDevice;
 import android.support.test.uiautomator.UiObject;
 import android.support.test.uiautomator.UiObjectNotFoundException;
@@ -40,11 +42,18 @@
 import android.text.TextUtils;
 import android.util.Log;
 
+import com.android.cts.net.hostside.IRemoteSocketFactory;
+
+import java.io.BufferedReader;
 import java.io.Closeable;
 import java.io.FileDescriptor;
+import java.io.FileOutputStream;
+import java.io.FileInputStream;
+import java.io.InputStreamReader;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
+import java.io.PrintWriter;
 import java.net.DatagramPacket;
 import java.net.DatagramSocket;
 import java.net.Inet6Address;
@@ -52,6 +61,8 @@
 import java.net.InetSocketAddress;
 import java.net.ServerSocket;
 import java.net.Socket;
+import java.net.SocketException;
+import java.nio.charset.StandardCharsets;
 import java.util.Random;
 
 /**
@@ -79,11 +90,16 @@
     public static String TAG = "VpnTest";
     public static int TIMEOUT_MS = 3 * 1000;
     public static int SOCKET_TIMEOUT_MS = 100;
+    // Host whose /generate_204 endpoint sendRequest() uses to check whether a socket still works.
+    public static final String TEST_HOST = "connectivitycheck.gstatic.com";
 
     private UiDevice mDevice;
     private MyActivity mActivity;
     private String mPackageName;
     private ConnectivityManager mCM;
+    // Client for the service in the second test app; bound in setUp() and unbound in tearDown().
+    private RemoteSocketFactoryClient mRemoteSocketFactoryClient;
+
     Network mNetwork;
     NetworkCallback mCallback;
     final Object mLock = new Object();
@@ -107,11 +121,14 @@
                 MyActivity.class, null);
         mPackageName = mActivity.getPackageName();
         mCM = (ConnectivityManager) mActivity.getSystemService(mActivity.CONNECTIVITY_SERVICE);
+        mRemoteSocketFactoryClient = new RemoteSocketFactoryClient(mActivity);
+        mRemoteSocketFactoryClient.bind();
         mDevice.waitForIdle();
     }
 
     @Override
     public void tearDown() throws Exception {
+        mRemoteSocketFactoryClient.unbind();
         if (mCallback != null) {
             mCM.unregisterNetworkCallback(mCallback);
         }
@@ -441,7 +458,7 @@
         }
     }
 
-    private void checkTrafficOnVpn() throws IOException, ErrnoException {
+    private void checkTrafficOnVpn() throws Exception {
         checkUdpEcho("192.0.2.251", "192.0.2.2");
         checkUdpEcho("2001:db8:dead:beef::f00", "2001:db8:1:2::ffe");
         checkPing("2001:db8:dead:beef::f00");
@@ -449,29 +466,123 @@
         checkTcpReflection("2001:db8:dead:beef::f00", "2001:db8:1:2::ffe");
     }
 
-    private void checkNoTrafficOnVpn() throws IOException, ErrnoException {
+    private void checkNoTrafficOnVpn() throws Exception {
         checkUdpEcho("192.0.2.251", null);
         checkUdpEcho("2001:db8:dead:beef::f00", null);
         checkTcpReflection("192.0.2.252", null);
         checkTcpReflection("2001:db8:dead:beef::f00", null);
     }
 
+    /**
+     * Opens a TCP connection to {@code host}:{@code port} from this app, with a read timeout of
+     * {@code timeoutMs}, and returns a new descriptor for it that the caller owns.
+     */
+    private FileDescriptor openSocketFd(String host, int port, int timeoutMs) throws Exception {
+        try (Socket s = new Socket(host, port)) {
+            s.setSoTimeout(timeoutMs);
+            // Return a dup rather than the socket's own fd: the Socket and the
+            // ParcelFileDescriptor would otherwise close it when closed or garbage-collected,
+            // invalidating the descriptor handed back. The dup keeps the connection open.
+            try (ParcelFileDescriptor pfd = ParcelFileDescriptor.fromSocket(s)) {
+                return Os.dup(pfd.getFileDescriptor());
+            }
+        }
+    }
+
+    /**
+     * Has the other test app open a socket to {@code host}:{@code port} and returns its
+     * descriptor, so that the socket belongs to that app rather than to this one.
+     */
+    private FileDescriptor openSocketFdInOtherApp(
+            String host, int port, int timeoutMs) throws Exception {
+        Log.d(TAG, String.format("Creating test socket in UID=%d, my UID=%d",
+                mRemoteSocketFactoryClient.getUid(), Os.getuid()));
+        return mRemoteSocketFactoryClient.openSocketFd(host, port, timeoutMs);
+    }
+
+    /**
+     * Sends {@code GET /generate_204} to {@code host} over {@code fd} and asserts that the
+     * response begins with an {@code HTTP/1.1 204 No Content} status line.
+     *
+     * @throws ErrnoException if writing to or reading from {@code fd} fails.
+     */
+    private void sendRequest(FileDescriptor fd, String host) throws Exception {
+        String request = "GET /generate_204 HTTP/1.1\r\n" +
+                "Host: " + host + "\r\n" +
+                "Connection: keep-alive\r\n\r\n";
+        byte[] requestBytes = request.getBytes(StandardCharsets.UTF_8);
+        // Os.write() and Os.read() may transfer fewer bytes than requested, so loop.
+        int written = 0;
+        while (written < requestBytes.length) {
+            written += Os.write(fd, requestBytes, written, requestBytes.length - written);
+        }
+        Log.d(TAG, "Wrote " + written + " bytes");
+
+        String expected = "HTTP/1.1 204 No Content\r\n";
+        byte[] response = new byte[expected.length()];
+        int read = 0;
+        while (read < response.length) {
+            int n = Os.read(fd, response, read, response.length - read);
+            if (n <= 0) {
+                break;  // EOF: let the assertion below report the truncated response.
+            }
+            read += n;
+        }
+
+        String actual = new String(response, 0, read, StandardCharsets.UTF_8);
+        assertEquals(expected, actual);
+        Log.d(TAG, "Got response: " + actual);
+    }
+
+    /** Asserts that an HTTP request over {@code fd} still succeeds, then closes {@code fd}. */
+    private void assertSocketStillOpen(FileDescriptor fd, String host) throws Exception {
+        try {
+            sendRequest(fd, host);
+        } finally {
+            Os.close(fd);
+        }
+    }
+
+    /**
+     * Asserts that {@code fd} has been torn down, i.e. that using it fails with
+     * {@code ECONNABORTED}, then closes {@code fd}.
+     */
+    private void assertSocketClosed(FileDescriptor fd, String host) throws Exception {
+        try {
+            sendRequest(fd, host);
+            fail("Socket opened before VPN connects should be closed when VPN connects");
+        } catch (ErrnoException expected) {
+            assertEquals(ECONNABORTED, expected.errno);
+        } finally {
+            Os.close(fd);
+        }
+    }
+
     public void testDefault() throws Exception {
         if (!supportedHardware()) return;
 
+        FileDescriptor fd = openSocketFdInOtherApp(TEST_HOST, 80, TIMEOUT_MS);
+
         startVpn(new String[] {"192.0.2.2/32", "2001:db8:1:2::ffe/128"},
                  new String[] {"0.0.0.0/0", "::/0"},
                  "", "");
+
+        assertSocketClosed(fd, TEST_HOST);
 
         checkTrafficOnVpn();
     }
 
     public void testAppAllowed() throws Exception {
         if (!supportedHardware()) return;
 
+        FileDescriptor fd = openSocketFdInOtherApp(TEST_HOST, 80, TIMEOUT_MS);
+
+        String allowedApps = mRemoteSocketFactoryClient.getPackageName() + "," + mPackageName;
         startVpn(new String[] {"192.0.2.2/32", "2001:db8:1:2::ffe/128"},
                  new String[] {"192.0.2.0/24", "2001:db8::/32"},
-                 mPackageName, "");
+                 allowedApps, "");
+
+        assertSocketClosed(fd, TEST_HOST);
 
         checkTrafficOnVpn();
     }
@@ -479,9 +555,19 @@
     public void testAppDisallowed() throws Exception {
         if (!supportedHardware()) return;
 
+        // Both this app and the other test app are excluded from the VPN, so sockets either of
+        // them opened before the VPN came up are expected to keep working afterwards.
+        FileDescriptor ownSocket = openSocketFd(TEST_HOST, 80, TIMEOUT_MS);
+        FileDescriptor otherAppSocket = openSocketFdInOtherApp(TEST_HOST, 80, TIMEOUT_MS);
+
+        String otherPackage = mRemoteSocketFactoryClient.getPackageName();
+        String disallowedApps = otherPackage + "," + mPackageName;
         startVpn(new String[] {"192.0.2.2/32", "2001:db8:1:2::ffe/128"},
                  new String[] {"192.0.2.0/24", "2001:db8::/32"},
-                 "", mPackageName);
+                 "", disallowedApps);
+
+        assertSocketStillOpen(ownSocket, TEST_HOST);
+        assertSocketStillOpen(otherAppSocket, TEST_HOST);
 
         checkNoTrafficOnVpn();
     }