Add a test for closing sockets when a VPN comes up.

Bug: 28251576
Change-Id: Iab0a8643cff3c54eb04168a7cdfa116c0b8e30b1
diff --git a/tests/cts/hostside/aidl/Android.mk b/tests/cts/hostside/aidl/Android.mk
new file mode 100644
index 0000000..a7ec6ef
--- /dev/null
+++ b/tests/cts/hostside/aidl/Android.mk
@@ -0,0 +1,22 @@
+# Copyright (C) 2016 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+LOCAL_PATH := $(call my-dir)
+
+include $(CLEAR_VARS)
+LOCAL_MODULE_TAGS := tests
+LOCAL_SDK_VERSION := current
+LOCAL_SRC_FILES := com/android/cts/net/hostside/IRemoteSocketFactory.aidl
+LOCAL_MODULE := CtsHostsideNetworkTestsAidl
+include $(BUILD_JAVA_LIBRARY)
diff --git a/tests/cts/hostside/aidl/com/android/cts/net/hostside/IRemoteSocketFactory.aidl b/tests/cts/hostside/aidl/com/android/cts/net/hostside/IRemoteSocketFactory.aidl
new file mode 100644
index 0000000..68176ad
--- /dev/null
+++ b/tests/cts/hostside/aidl/com/android/cts/net/hostside/IRemoteSocketFactory.aidl
@@ -0,0 +1,25 @@
+/*
+ * Copyright (C) 2016 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.cts.net.hostside;
+
+import android.os.ParcelFileDescriptor;
+
+interface IRemoteSocketFactory {
+    ParcelFileDescriptor openSocketFd(String host, int port, int timeoutMs);
+    String getPackageName();
+    int getUid();
+}
diff --git a/tests/cts/hostside/app/Android.mk b/tests/cts/hostside/app/Android.mk
index 7f8da07..9519ec5 100644
--- a/tests/cts/hostside/app/Android.mk
+++ b/tests/cts/hostside/app/Android.mk
@@ -20,7 +20,8 @@
 
 LOCAL_MODULE_TAGS := tests
 LOCAL_SDK_VERSION := current
-LOCAL_STATIC_JAVA_LIBRARIES := ctsdeviceutil ctstestrunner ub-uiautomator
+LOCAL_STATIC_JAVA_LIBRARIES := ctsdeviceutil ctstestrunner ub-uiautomator \
+        CtsHostsideNetworkTestsAidl
 
 LOCAL_SRC_FILES := $(call all-java-files-under, src)
 
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/RemoteSocketFactoryClient.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/RemoteSocketFactoryClient.java
new file mode 100644
index 0000000..799fe50
--- /dev/null
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/RemoteSocketFactoryClient.java
@@ -0,0 +1,91 @@
+/*
+ * Copyright (C) 2016 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.cts.net.hostside;
+
+import android.content.ComponentName;
+import android.content.Context;
+import android.content.Intent;
+import android.content.ServiceConnection;
+import android.os.ConditionVariable;
+import android.os.IBinder;
+import android.os.RemoteException;
+
+import com.android.cts.net.hostside.IRemoteSocketFactory;
+
+import java.io.FileDescriptor;
+
+public class RemoteSocketFactoryClient {
+    private static final int TIMEOUT_MS = 5000;
+    private static final String PACKAGE = RemoteSocketFactoryClient.class.getPackage().getName();
+    private static final String APP2_PACKAGE = PACKAGE + ".app2";
+    private static final String SERVICE_NAME = APP2_PACKAGE + ".RemoteSocketFactoryService";
+
+    private Context mContext;
+    private ServiceConnection mServiceConnection;
+    private IRemoteSocketFactory mService;
+
+    public RemoteSocketFactoryClient(Context context) {
+        mContext = context;
+    }
+
+    public void bind() {
+        if (mService != null) {
+            throw new IllegalStateException("Already bound");
+        }
+
+        final ConditionVariable cv = new ConditionVariable();
+        mServiceConnection = new ServiceConnection() {
+            @Override
+            public void onServiceConnected(ComponentName name, IBinder service) {
+                mService = IRemoteSocketFactory.Stub.asInterface(service);
+                cv.open();
+            }
+            @Override
+            public void onServiceDisconnected(ComponentName name) {
+                mService = null;
+            }
+        };
+
+        final Intent intent = new Intent();
+        intent.setComponent(new ComponentName(APP2_PACKAGE, SERVICE_NAME));
+        mContext.bindService(intent, mServiceConnection, Context.BIND_AUTO_CREATE);
+        cv.block(TIMEOUT_MS);
+        if (mService == null) {
+            throw new IllegalStateException(
+                    "Could not bind to RemoteSocketFactory service after " + TIMEOUT_MS + "ms");
+        }
+    }
+
+    public void unbind() {
+        if (mService != null) {
+            mContext.unbindService(mServiceConnection);
+        }
+    }
+
+    public FileDescriptor openSocketFd(
+            String host, int port, int timeoutMs) throws RemoteException {
+        return mService.openSocketFd(host, port, timeoutMs).getFileDescriptor();
+    }
+
+    public String getPackageName() throws RemoteException {
+        return mService.getPackageName();
+    }
+
+    public int getUid() throws RemoteException {
+        return mService.getUid();
+    }
+}
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
index 5045cc2..12fe625 100755
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
@@ -27,6 +27,8 @@
 import android.net.NetworkCapabilities;
 import android.net.NetworkRequest;
 import android.net.VpnService;
+import android.os.ParcelFileDescriptor;
+import android.os.Process;
 import android.support.test.uiautomator.UiDevice;
 import android.support.test.uiautomator.UiObject;
 import android.support.test.uiautomator.UiObjectNotFoundException;
@@ -40,11 +42,18 @@
 import android.text.TextUtils;
 import android.util.Log;
 
+import com.android.cts.net.hostside.IRemoteSocketFactory;
+
+import java.io.BufferedReader;
 import java.io.Closeable;
 import java.io.FileDescriptor;
+import java.io.FileOutputStream;
+import java.io.FileInputStream;
+import java.io.InputStreamReader;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
+import java.io.PrintWriter;
 import java.net.DatagramPacket;
 import java.net.DatagramSocket;
 import java.net.Inet6Address;
@@ -52,6 +61,8 @@
 import java.net.InetSocketAddress;
 import java.net.ServerSocket;
 import java.net.Socket;
+import java.net.SocketException;
+import java.nio.charset.StandardCharsets;
 import java.util.Random;
 
 /**
@@ -79,11 +90,14 @@
     public static String TAG = "VpnTest";
     public static int TIMEOUT_MS = 3 * 1000;
     public static int SOCKET_TIMEOUT_MS = 100;
+    public static String TEST_HOST = "connectivitycheck.gstatic.com";
 
     private UiDevice mDevice;
     private MyActivity mActivity;
     private String mPackageName;
     private ConnectivityManager mCM;
+    private RemoteSocketFactoryClient mRemoteSocketFactoryClient;
+
     Network mNetwork;
     NetworkCallback mCallback;
     final Object mLock = new Object();
@@ -107,11 +121,14 @@
                 MyActivity.class, null);
         mPackageName = mActivity.getPackageName();
         mCM = (ConnectivityManager) mActivity.getSystemService(mActivity.CONNECTIVITY_SERVICE);
+        mRemoteSocketFactoryClient = new RemoteSocketFactoryClient(mActivity);
+        mRemoteSocketFactoryClient.bind();
         mDevice.waitForIdle();
     }
 
     @Override
     public void tearDown() throws Exception {
+        mRemoteSocketFactoryClient.unbind();
         if (mCallback != null) {
             mCM.unregisterNetworkCallback(mCallback);
         }
@@ -441,7 +458,7 @@
         }
     }
 
-    private void checkTrafficOnVpn() throws IOException, ErrnoException {
+    private void checkTrafficOnVpn() throws Exception {
         checkUdpEcho("192.0.2.251", "192.0.2.2");
         checkUdpEcho("2001:db8:dead:beef::f00", "2001:db8:1:2::ffe");
         checkPing("2001:db8:dead:beef::f00");
@@ -449,29 +466,88 @@
         checkTcpReflection("2001:db8:dead:beef::f00", "2001:db8:1:2::ffe");
     }
 
-    private void checkNoTrafficOnVpn() throws IOException, ErrnoException {
+    private void checkNoTrafficOnVpn() throws Exception {
         checkUdpEcho("192.0.2.251", null);
         checkUdpEcho("2001:db8:dead:beef::f00", null);
         checkTcpReflection("192.0.2.252", null);
         checkTcpReflection("2001:db8:dead:beef::f00", null);
     }
 
+    private FileDescriptor openSocketFd(String host, int port, int timeoutMs) throws Exception {
+        Socket s = new Socket(host, port);
+        s.setSoTimeout(timeoutMs);
+        return ParcelFileDescriptor.fromSocket(s).getFileDescriptor();
+    }
+
+    private FileDescriptor openSocketFdInOtherApp(
+            String host, int port, int timeoutMs) throws Exception {
+        Log.d(TAG, String.format("Creating test socket in UID=%d, my UID=%d",
+                mRemoteSocketFactoryClient.getUid(), Os.getuid()));
+        FileDescriptor fd = mRemoteSocketFactoryClient.openSocketFd(host, port, TIMEOUT_MS);
+        return fd;
+    }
+
+    private void sendRequest(FileDescriptor fd, String host) throws Exception {
+        String request = "GET /generate_204 HTTP/1.1\r\n" +
+                "Host: " + host + "\r\n" +
+                "Connection: keep-alive\r\n\r\n";
+        byte[] requestBytes = request.getBytes(StandardCharsets.UTF_8);
+        int ret = Os.write(fd, requestBytes, 0, requestBytes.length);
+        Log.d(TAG, "Wrote " + ret + "bytes");
+
+        String expected = "HTTP/1.1 204 No Content\r\n";
+        byte[] response = new byte[expected.length()];
+        Os.read(fd, response, 0, response.length);
+
+        String actual = new String(response, StandardCharsets.UTF_8);
+        assertEquals(expected, actual);
+        Log.d(TAG, "Got response: " + actual);
+    }
+
+    private void assertSocketStillOpen(FileDescriptor fd, String host) throws Exception {
+        try {
+            sendRequest(fd, host);
+        } finally {
+            Os.close(fd);
+        }
+    }
+
+    private void assertSocketClosed(FileDescriptor fd, String host) throws Exception {
+        try {
+            sendRequest(fd, host);
+            fail("Socket opened before VPN connects should be closed when VPN connects");
+        } catch (ErrnoException expected) {
+            assertEquals(ECONNABORTED, expected.errno);
+        } finally {
+            Os.close(fd);
+        }
+    }
+
     public void testDefault() throws Exception {
         if (!supportedHardware()) return;
 
+        FileDescriptor fd = openSocketFdInOtherApp(TEST_HOST, 80, TIMEOUT_MS);
+
         startVpn(new String[] {"192.0.2.2/32", "2001:db8:1:2::ffe/128"},
                  new String[] {"0.0.0.0/0", "::/0"},
                  "", "");
 
+        assertSocketClosed(fd, TEST_HOST);
+
         checkTrafficOnVpn();
     }
 
     public void testAppAllowed() throws Exception {
         if (!supportedHardware()) return;
 
+        FileDescriptor fd = openSocketFdInOtherApp(TEST_HOST, 80, TIMEOUT_MS);
+
+        String allowedApps = mRemoteSocketFactoryClient.getPackageName() + "," + mPackageName;
         startVpn(new String[] {"192.0.2.2/32", "2001:db8:1:2::ffe/128"},
                  new String[] {"192.0.2.0/24", "2001:db8::/32"},
-                 mPackageName, "");
+                 allowedApps, "");
+
+        assertSocketClosed(fd, TEST_HOST);
 
         checkTrafficOnVpn();
     }
@@ -479,9 +555,16 @@
     public void testAppDisallowed() throws Exception {
         if (!supportedHardware()) return;
 
+        FileDescriptor localFd = openSocketFd(TEST_HOST, 80, TIMEOUT_MS);
+        FileDescriptor remoteFd = openSocketFdInOtherApp(TEST_HOST, 80, TIMEOUT_MS);
+
+        String disallowedApps = mRemoteSocketFactoryClient.getPackageName() + "," + mPackageName;
         startVpn(new String[] {"192.0.2.2/32", "2001:db8:1:2::ffe/128"},
                  new String[] {"192.0.2.0/24", "2001:db8::/32"},
-                 "", mPackageName);
+                 "", disallowedApps);
+
+        assertSocketStillOpen(localFd, TEST_HOST);
+        assertSocketStillOpen(remoteFd, TEST_HOST);
 
         checkNoTrafficOnVpn();
     }
diff --git a/tests/cts/hostside/app2/Android.mk b/tests/cts/hostside/app2/Android.mk
index 3b59f8f..706455d 100644
--- a/tests/cts/hostside/app2/Android.mk
+++ b/tests/cts/hostside/app2/Android.mk
@@ -20,6 +20,7 @@
 
 LOCAL_MODULE_TAGS := tests
 LOCAL_SDK_VERSION := current
+LOCAL_STATIC_JAVA_LIBRARIES := CtsHostsideNetworkTestsAidl
 
 LOCAL_SRC_FILES := $(call all-java-files-under, src)
 
diff --git a/tests/cts/hostside/app2/AndroidManifest.xml b/tests/cts/hostside/app2/AndroidManifest.xml
index 9c4884b..80b669d 100644
--- a/tests/cts/hostside/app2/AndroidManifest.xml
+++ b/tests/cts/hostside/app2/AndroidManifest.xml
@@ -29,11 +29,15 @@
 
          The manifest-defined listener also handles ordered broadcasts used to share data with the
          test app.
+
+         This application also provides a service, RemoteSocketFactoryService, that the test app can
+         use to open sockets to remote hosts as a different user ID.
     -->
     <application>
         <activity android:name=".MyActivity" android:exported="true"/>
         <service android:name=".MyService" android:exported="true"/>
         <service android:name=".MyForegroundService" android:exported="true"/>
+        <service android:name=".RemoteSocketFactoryService" android:exported="true"/>
 
         <receiver android:name=".MyBroadcastReceiver" >
             <intent-filter>
@@ -45,4 +49,4 @@
         </receiver>
     </application>
 
-</manifest>
\ No newline at end of file
+</manifest>
diff --git a/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/RemoteSocketFactoryService.java b/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/RemoteSocketFactoryService.java
new file mode 100644
index 0000000..b1b7d77
--- /dev/null
+++ b/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/RemoteSocketFactoryService.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright (C) 2016 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.cts.net.hostside.app2;
+
+import android.app.Service;
+import android.content.Context;
+import android.content.Intent;
+import android.os.IBinder;
+import android.os.ParcelFileDescriptor;
+import android.os.Process;
+import android.util.Log;
+
+import com.android.cts.net.hostside.IRemoteSocketFactory;
+
+import java.net.Socket;
+
+
+public class RemoteSocketFactoryService extends Service {
+
+    private static final String TAG = RemoteSocketFactoryService.class.getSimpleName();
+
+    private IRemoteSocketFactory.Stub mBinder = new IRemoteSocketFactory.Stub() {
+        @Override
+        public ParcelFileDescriptor openSocketFd(String host, int port, int timeoutMs) {
+            try {
+                Socket s = new Socket(host, port);
+                s.setSoTimeout(timeoutMs);
+                return ParcelFileDescriptor.fromSocket(s);
+            } catch (Exception e) {
+                throw new IllegalArgumentException(e);
+            }
+        }
+
+        @Override
+        public String getPackageName() {
+            return RemoteSocketFactoryService.this.getPackageName();
+        }
+
+        @Override
+        public int getUid() {
+            return Process.myUid();
+        }
+    };
+
+    @Override
+    public IBinder onBind(Intent intent) {
+        return mBinder;
+    }
+}
diff --git a/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkTestCase.java b/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkTestCase.java
index 39b5652..6642512 100644
--- a/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkTestCase.java
+++ b/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkTestCase.java
@@ -42,6 +42,8 @@
     protected static final String TAG = "HostsideNetworkTests";
     protected static final String TEST_PKG = "com.android.cts.net.hostside";
     protected static final String TEST_APK = "CtsHostsideNetworkTestsApp.apk";
+    protected static final String TEST_APP2_PKG = "com.android.cts.net.hostside.app2";
+    protected static final String TEST_APP2_APK = "CtsHostsideNetworkTestsApp2.apk";
 
     private IAbi mAbi;
     private IBuildInfo mCtsBuild;
diff --git a/tests/cts/hostside/src/com/android/cts/net/HostsideRestrictBackgroundNetworkTests.java b/tests/cts/hostside/src/com/android/cts/net/HostsideRestrictBackgroundNetworkTests.java
index ec375d6..1a8634e 100644
--- a/tests/cts/hostside/src/com/android/cts/net/HostsideRestrictBackgroundNetworkTests.java
+++ b/tests/cts/hostside/src/com/android/cts/net/HostsideRestrictBackgroundNetworkTests.java
@@ -21,9 +21,6 @@
 
 public class HostsideRestrictBackgroundNetworkTests extends HostsideNetworkTestCase {
 
-    private static final String TEST_APP2_PKG = "com.android.cts.net.hostside.app2";
-    private static final String TEST_APP2_APK = "CtsHostsideNetworkTestsApp2.apk";
-
     @Override
     protected void setUp() throws Exception {
         super.setUp();
diff --git a/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java b/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java
index dc965c5..69b07af 100644
--- a/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java
+++ b/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java
@@ -18,7 +18,30 @@
 
 public class HostsideVpnTests extends HostsideNetworkTestCase {
 
-    public void testVpn() throws Exception {
-        runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest");
+    @Override
+    protected void setUp() throws Exception {
+        super.setUp();
+
+        uninstallPackage(TEST_APP2_PKG, false);
+        installPackage(TEST_APP2_APK);
+    }
+
+    @Override
+    protected void tearDown() throws Exception {
+        super.tearDown();
+
+        uninstallPackage(TEST_APP2_PKG, true);
+    }
+
+    public void testDefault() throws Exception {
+        runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testDefault");
+    }
+
+    public void testAppAllowed() throws Exception {
+        runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testAppAllowed");
+    }
+
+    public void testAppDisallowed() throws Exception {
+        runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testAppDisallowed");
     }
 }