Merge "[Thread] fix upstream network selector to exclude VPN" into main
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
index e6f272b..19084c6 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
@@ -271,10 +271,12 @@
     }
 
     private NetworkRequest newUpstreamNetworkRequest() {
-        NetworkRequest.Builder builder = new NetworkRequest.Builder().clearCapabilities();
+        NetworkRequest.Builder builder = new NetworkRequest.Builder();
 
         if (mUpstreamTestNetworkSpecifier != null) {
-            return builder.addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+            // Test networks don't have NET_CAPABILITY_TRUSTED
+            return builder.removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
+                    .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
                     .setNetworkSpecifier(mUpstreamTestNetworkSpecifier)
                     .build();
         }
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
index a5dc25a..df1a65b 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
@@ -16,6 +16,7 @@
 
 package com.android.server.thread;
 
+import static android.Manifest.permission.NETWORK_SETTINGS;
 import static android.net.thread.ActiveOperationalDataset.CHANNEL_PAGE_24_GHZ;
 import static android.net.thread.ThreadNetworkController.STATE_DISABLED;
 import static android.net.thread.ThreadNetworkController.STATE_ENABLED;
@@ -38,6 +39,7 @@
 import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.clearInvocations;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doThrow;
@@ -56,7 +58,9 @@
 import android.content.res.Resources;
 import android.net.ConnectivityManager;
 import android.net.NetworkAgent;
+import android.net.NetworkCapabilities;
 import android.net.NetworkProvider;
+import android.net.NetworkRequest;
 import android.net.thread.ActiveOperationalDataset;
 import android.net.thread.IActiveOperationalDatasetReceiver;
 import android.net.thread.IOperationReceiver;
@@ -181,6 +185,9 @@
                 .when(mContext)
                 .enforceCallingOrSelfPermission(
                         eq(PERMISSION_THREAD_NETWORK_PRIVILEGED), anyString());
+        doNothing()
+                .when(mContext)
+                .enforceCallingOrSelfPermission(eq(NETWORK_SETTINGS), anyString());
 
         mTestLooper = new TestLooper();
         final Handler handler = new Handler(mTestLooper.getLooper());
@@ -737,4 +744,34 @@
         inOrder.verify(mockReceiver2).onSuccess();
         inOrder.verify(mockReceiver3).onSuccess();
     }
+
+    @Test
+    public void setTestNetworkAsUpstream_upstreamNetworkRequestAlwaysDisallowsVpn() {
+        mService.initialize();
+        mTestLooper.dispatchAll();
+        // Discard the ConnectivityManager interactions recorded up to this point.
+        clearInvocations(mMockConnectivityManager);
+
+        // First pass a test network name as upstream, then pass null instead.
+        mService.setTestNetworkAsUpstream("test-network", mock(IOperationReceiver.class));
+        mService.setTestNetworkAsUpstream(null, mock(IOperationReceiver.class));
+        mTestLooper.dispatchAll();
+
+        final ArgumentCaptor<NetworkRequest> requestCaptor =
+                ArgumentCaptor.forClass(NetworkRequest.class);
+        verify(mMockConnectivityManager, times(2))
+                .registerNetworkCallback(
+                        requestCaptor.capture(),
+                        any(ConnectivityManager.NetworkCallback.class),
+                        any(Handler.class));
+        assertThat(requestCaptor.getAllValues().size()).isEqualTo(2);
+        final NetworkRequest firstRequest = requestCaptor.getAllValues().get(0);
+        final NetworkRequest secondRequest = requestCaptor.getAllValues().get(1);
+        final int notVpn = NetworkCapabilities.NET_CAPABILITY_NOT_VPN;
+        // Only the first request carries a specifier; both must require NOT_VPN.
+        assertThat(firstRequest.getNetworkSpecifier()).isNotNull();
+        assertThat(firstRequest.hasCapability(notVpn)).isTrue();
+        assertThat(secondRequest.getNetworkSpecifier()).isNull();
+        assertThat(secondRequest.hasCapability(notVpn)).isTrue();
+    }
 }