[VCN05] Pass request type when requesting network

Currently, ConnectivityService decides the request type by
whether NetworkCapabilities is null when handling request
network. However, to fulfill the need of firing background
request via ConnectivityManager in the follow-up patches,
the request type is needed to pass into ConnectivityService.

This change also make ConnectivityService utilizes the passed
request type.

Test: atest ConnectivityManagerTest#testRequestType
Bug: 175662146
Change-Id: I3bc172bca1217c8020db45057a621d0745d43b3c
diff --git a/core/java/android/net/ConnectivityManager.java b/core/java/android/net/ConnectivityManager.java
index 06c1598..8742ecb 100644
--- a/core/java/android/net/ConnectivityManager.java
+++ b/core/java/android/net/ConnectivityManager.java
@@ -16,6 +16,9 @@
 package android.net;
 
 import static android.net.IpSecManager.INVALID_RESOURCE_ID;
+import static android.net.NetworkRequest.Type.LISTEN;
+import static android.net.NetworkRequest.Type.REQUEST;
+import static android.net.NetworkRequest.Type.TRACK_DEFAULT;
 
 import android.annotation.CallbackExecutor;
 import android.annotation.IntDef;
@@ -3730,14 +3733,12 @@
     private static final HashMap<NetworkRequest, NetworkCallback> sCallbacks = new HashMap<>();
     private static CallbackHandler sCallbackHandler;
 
-    private static final int LISTEN  = 1;
-    private static final int REQUEST = 2;
-
     private NetworkRequest sendRequestForNetwork(NetworkCapabilities need, NetworkCallback callback,
-            int timeoutMs, int action, int legacyType, CallbackHandler handler) {
+            int timeoutMs, NetworkRequest.Type reqType, int legacyType, CallbackHandler handler) {
         printStackTrace();
         checkCallbackNotNull(callback);
-        Preconditions.checkArgument(action == REQUEST || need != null, "null NetworkCapabilities");
+        Preconditions.checkArgument(
+                reqType == TRACK_DEFAULT || need != null, "null NetworkCapabilities");
         final NetworkRequest request;
         final String callingPackageName = mContext.getOpPackageName();
         try {
@@ -3750,13 +3751,13 @@
                 }
                 Messenger messenger = new Messenger(handler);
                 Binder binder = new Binder();
-                if (action == LISTEN) {
+                if (reqType == LISTEN) {
                     request = mService.listenForNetwork(
                             need, messenger, binder, callingPackageName);
                 } else {
                     request = mService.requestNetwork(
-                            need, messenger, timeoutMs, binder, legacyType, callingPackageName,
-                            getAttributionTag());
+                            need, reqType.ordinal(), messenger, timeoutMs, binder, legacyType,
+                            callingPackageName, getAttributionTag());
                 }
                 if (request != null) {
                     sCallbacks.put(request, callback);
@@ -4260,7 +4261,7 @@
         // request, i.e., the system default network.
         CallbackHandler cbHandler = new CallbackHandler(handler);
         sendRequestForNetwork(null /* NetworkCapabilities need */, networkCallback, 0,
-                REQUEST, TYPE_NONE, cbHandler);
+                TRACK_DEFAULT, TYPE_NONE, cbHandler);
     }
 
     /**
diff --git a/core/java/android/net/IConnectivityManager.aidl b/core/java/android/net/IConnectivityManager.aidl
index b32c98b..5e925b6 100644
--- a/core/java/android/net/IConnectivityManager.aidl
+++ b/core/java/android/net/IConnectivityManager.aidl
@@ -167,7 +167,7 @@
             in NetworkCapabilities nc, int score, in NetworkAgentConfig config,
             in int factorySerialNumber);
 
-    NetworkRequest requestNetwork(in NetworkCapabilities networkCapabilities,
+    NetworkRequest requestNetwork(in NetworkCapabilities networkCapabilities, int reqType,
             in Messenger messenger, int timeoutSec, in IBinder binder, int legacy,
             String callingPackageName, String callingAttributionTag);
 
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index 397eeb2..ad54312 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -5613,31 +5613,40 @@
 
     @Override
     public NetworkRequest requestNetwork(NetworkCapabilities networkCapabilities,
-            Messenger messenger, int timeoutMs, IBinder binder, int legacyType,
-            @NonNull String callingPackageName, @Nullable String callingAttributionTag) {
+            int reqTypeInt, Messenger messenger, int timeoutMs, IBinder binder,
+            int legacyType, @NonNull String callingPackageName,
+            @Nullable String callingAttributionTag) {
         if (legacyType != TYPE_NONE && !checkNetworkStackPermission()) {
             if (checkUnsupportedStartingFrom(Build.VERSION_CODES.M, callingPackageName)) {
                 throw new SecurityException("Insufficient permissions to specify legacy type");
             }
         }
         final int callingUid = mDeps.getCallingUid();
-        final NetworkRequest.Type type = (networkCapabilities == null)
-                ? NetworkRequest.Type.TRACK_DEFAULT
-                : NetworkRequest.Type.REQUEST;
-        // If the requested networkCapabilities is null, take them instead from
-        // the default network request. This allows callers to keep track of
-        // the system default network.
-        if (type == NetworkRequest.Type.TRACK_DEFAULT) {
-            networkCapabilities = createDefaultNetworkCapabilitiesForUid(callingUid);
-            enforceAccessPermission();
-        } else {
-            networkCapabilities = new NetworkCapabilities(networkCapabilities);
-            enforceNetworkRequestPermissions(networkCapabilities, callingPackageName,
-                    callingAttributionTag);
-            // TODO: this is incorrect. We mark the request as metered or not depending on the state
-            // of the app when the request is filed, but we never change the request if the app
-            // changes network state. http://b/29964605
-            enforceMeteredApnPolicy(networkCapabilities);
+        final NetworkRequest.Type reqType;
+        try {
+            reqType = NetworkRequest.Type.values()[reqTypeInt];
+        } catch (ArrayIndexOutOfBoundsException e) {
+            throw new IllegalArgumentException("Unsupported request type " + reqTypeInt);
+        }
+        switch (reqType) {
+            case TRACK_DEFAULT:
+                // If the request type is TRACK_DEFAULT, the passed {@code networkCapabilities}
+                // is unused and will be replaced by the one from the default network request.
+                // This allows callers to keep track of the system default network.
+                networkCapabilities = createDefaultNetworkCapabilitiesForUid(callingUid);
+                enforceAccessPermission();
+                break;
+            case REQUEST:
+                networkCapabilities = new NetworkCapabilities(networkCapabilities);
+                enforceNetworkRequestPermissions(networkCapabilities, callingPackageName,
+                        callingAttributionTag);
+                // TODO: this is incorrect. We mark the request as metered or not depending on
+                //  the state of the app when the request is filed, but we never change the
+                //  request if the app changes network state. http://b/29964605
+                enforceMeteredApnPolicy(networkCapabilities);
+                break;
+            default:
+                throw new IllegalArgumentException("Unsupported request type " + reqType);
         }
         ensureRequestableCapabilities(networkCapabilities);
         ensureSufficientPermissionsForRequest(networkCapabilities,
@@ -5656,7 +5665,7 @@
         ensureValid(networkCapabilities);
 
         NetworkRequest networkRequest = new NetworkRequest(networkCapabilities, legacyType,
-                nextNetworkRequestId(), type);
+                nextNetworkRequestId(), reqType);
         NetworkRequestInfo nri = new NetworkRequestInfo(messenger, networkRequest, binder);
         if (DBG) log("requestNetwork for " + nri);
 
diff --git a/tests/net/java/android/net/ConnectivityManagerTest.java b/tests/net/java/android/net/ConnectivityManagerTest.java
index d74a621..f2dd27e 100644
--- a/tests/net/java/android/net/ConnectivityManagerTest.java
+++ b/tests/net/java/android/net/ConnectivityManagerTest.java
@@ -16,6 +16,7 @@
 
 package android.net;
 
+import static android.net.ConnectivityManager.TYPE_NONE;
 import static android.net.NetworkCapabilities.NET_CAPABILITY_CBS;
 import static android.net.NetworkCapabilities.NET_CAPABILITY_DUN;
 import static android.net.NetworkCapabilities.NET_CAPABILITY_FOTA;
@@ -31,16 +32,21 @@
 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
 import static android.net.NetworkCapabilities.TRANSPORT_ETHERNET;
 import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
+import static android.net.NetworkRequest.Type.REQUEST;
+import static android.net.NetworkRequest.Type.TRACK_DEFAULT;
 
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.ArgumentMatchers.nullable;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.anyBoolean;
 import static org.mockito.Mockito.anyInt;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -49,9 +55,7 @@
 import android.app.PendingIntent;
 import android.content.Context;
 import android.content.pm.ApplicationInfo;
-import android.net.ConnectivityManager;
 import android.net.ConnectivityManager.NetworkCallback;
-import android.net.NetworkCapabilities;
 import android.os.Build.VERSION_CODES;
 import android.os.Bundle;
 import android.os.Handler;
@@ -213,9 +217,8 @@
         ArgumentCaptor<Messenger> captor = ArgumentCaptor.forClass(Messenger.class);
 
         // register callback
-        when(mService.requestNetwork(
-                any(), captor.capture(), anyInt(), any(), anyInt(), any(), nullable(String.class)))
-                .thenReturn(request);
+        when(mService.requestNetwork(any(), anyInt(), captor.capture(), anyInt(), any(), anyInt(),
+                any(), nullable(String.class))).thenReturn(request);
         manager.requestNetwork(request, callback, handler);
 
         // callback triggers
@@ -242,9 +245,8 @@
         ArgumentCaptor<Messenger> captor = ArgumentCaptor.forClass(Messenger.class);
 
         // register callback
-        when(mService.requestNetwork(
-                any(), captor.capture(), anyInt(), any(), anyInt(), any(), nullable(String.class)))
-                .thenReturn(req1);
+        when(mService.requestNetwork(any(), anyInt(), captor.capture(), anyInt(), any(), anyInt(),
+                any(), nullable(String.class))).thenReturn(req1);
         manager.requestNetwork(req1, callback, handler);
 
         // callback triggers
@@ -261,9 +263,8 @@
         verify(callback, timeout(100).times(0)).onLosing(any(), anyInt());
 
         // callback can be registered again
-        when(mService.requestNetwork(
-                any(), captor.capture(), anyInt(), any(), anyInt(), any(), nullable(String.class)))
-                .thenReturn(req2);
+        when(mService.requestNetwork(any(), anyInt(), captor.capture(), anyInt(), any(), anyInt(),
+                any(), nullable(String.class))).thenReturn(req2);
         manager.requestNetwork(req2, callback, handler);
 
         // callback triggers
@@ -286,7 +287,7 @@
         info.targetSdkVersion = VERSION_CODES.N_MR1 + 1;
 
         when(mCtx.getApplicationInfo()).thenReturn(info);
-        when(mService.requestNetwork(any(), any(), anyInt(), any(), anyInt(), any(),
+        when(mService.requestNetwork(any(), anyInt(), any(), anyInt(), any(), anyInt(), any(),
                 nullable(String.class))).thenReturn(request);
 
         Handler handler = new Handler(Looper.getMainLooper());
@@ -340,6 +341,35 @@
         }
     }
 
+    @Test
+    public void testRequestType() throws Exception {
+        final String testPkgName = "MyPackage";
+        final ConnectivityManager manager = new ConnectivityManager(mCtx, mService);
+        when(mCtx.getOpPackageName()).thenReturn(testPkgName);
+        final NetworkRequest request = makeRequest(1);
+        final NetworkCallback callback = new ConnectivityManager.NetworkCallback();
+
+        manager.requestNetwork(request, callback);
+        verify(mService).requestNetwork(eq(request.networkCapabilities),
+                eq(REQUEST.ordinal()), any(), anyInt(), any(), eq(TYPE_NONE),
+                eq(testPkgName), eq(null));
+        reset(mService);
+
+        // Verify that register network callback does not calls requestNetwork at all.
+        manager.registerNetworkCallback(request, callback);
+        verify(mService, never()).requestNetwork(any(), anyInt(), any(), anyInt(), any(),
+                anyInt(), any(), any());
+        verify(mService).listenForNetwork(eq(request.networkCapabilities), any(), any(),
+                eq(testPkgName));
+        reset(mService);
+
+        manager.registerDefaultNetworkCallback(callback);
+        verify(mService).requestNetwork(eq(null),
+                eq(TRACK_DEFAULT.ordinal()), any(), anyInt(), any(), eq(TYPE_NONE),
+                eq(testPkgName), eq(null));
+        reset(mService);
+    }
+
     static Message makeMessage(NetworkRequest req, int messageType) {
         Bundle bundle = new Bundle();
         bundle.putParcelable(NetworkRequest.class.getSimpleName(), req);
diff --git a/tests/net/java/com/android/server/ConnectivityServiceTest.java b/tests/net/java/com/android/server/ConnectivityServiceTest.java
index 9421acd..e2cf90d 100644
--- a/tests/net/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/net/java/com/android/server/ConnectivityServiceTest.java
@@ -3226,8 +3226,8 @@
             NetworkCapabilities networkCapabilities = new NetworkCapabilities();
             networkCapabilities.addTransportType(TRANSPORT_WIFI)
                     .setNetworkSpecifier(new MatchAllNetworkSpecifier());
-            mService.requestNetwork(networkCapabilities, null, 0, null,
-                    ConnectivityManager.TYPE_WIFI, mContext.getPackageName(),
+            mService.requestNetwork(networkCapabilities, NetworkRequest.Type.REQUEST.ordinal(),
+                    null, 0, null, ConnectivityManager.TYPE_WIFI, mContext.getPackageName(),
                     getAttributionTag());
         });