Merge "Follow-up changes to Update VPN isolation code for excluded routes"
diff --git a/service-t/src/com/android/server/net/NetworkStatsObservers.java b/service-t/src/com/android/server/net/NetworkStatsObservers.java
index c51a886..df4e7f5 100644
--- a/service-t/src/com/android/server/net/NetworkStatsObservers.java
+++ b/service-t/src/com/android/server/net/NetworkStatsObservers.java
@@ -44,6 +44,7 @@
 import android.util.SparseArray;
 
 import com.android.internal.annotations.VisibleForTesting;
+import com.android.net.module.util.PerUidCounter;
 
 import java.util.concurrent.atomic.AtomicInteger;
 
@@ -63,10 +64,17 @@
 
     private static final int DUMP_USAGE_REQUESTS_COUNT = 200;
 
+    // The maximum number of requests allowed per uid before an exception is thrown.
+    @VisibleForTesting
+    static final int MAX_REQUESTS_PER_UID = 100;
+
     // All access to this map must be done from the handler thread.
     // indexed by DataUsageRequest#requestId
     private final SparseArray<RequestInfo> mDataUsageRequests = new SparseArray<>();
 
+    // Request counters per uid, this is thread safe.
+    private final PerUidCounter mDataUsageRequestsPerUid = new PerUidCounter(MAX_REQUESTS_PER_UID);
+
     // Sequence number of DataUsageRequests
     private final AtomicInteger mNextDataUsageRequestId = new AtomicInteger();
 
@@ -89,8 +97,9 @@
         DataUsageRequest request = buildRequest(context, inputRequest, callingUid);
         RequestInfo requestInfo = buildRequestInfo(request, callback, callingPid, callingUid,
                 callingPackage, accessLevel);
-
         if (LOG) Log.d(TAG, "Registering observer for " + requestInfo);
+        mDataUsageRequestsPerUid.incrementCountOrThrow(callingUid);
+
         getHandler().sendMessage(mHandler.obtainMessage(MSG_REGISTER, requestInfo));
         return request;
     }
@@ -189,6 +198,7 @@
 
         if (LOG) Log.d(TAG, "Unregistering " + requestInfo);
         mDataUsageRequests.remove(request.requestId);
+        mDataUsageRequestsPerUid.decrementCountOrThrow(callingUid);
         requestInfo.unlinkDeathRecipient();
         requestInfo.callCallback(NetworkStatsManager.CALLBACK_RELEASED);
     }
diff --git a/service-t/src/com/android/server/net/NetworkStatsService.java b/service-t/src/com/android/server/net/NetworkStatsService.java
index 7c167ed..b2d8b5e 100644
--- a/service-t/src/com/android/server/net/NetworkStatsService.java
+++ b/service-t/src/com/android/server/net/NetworkStatsService.java
@@ -377,8 +377,18 @@
 
     private long mLastStatsSessionPoll;
 
-    /** Map from key {@code OpenSessionKey} to count of opened sessions */
-    @GuardedBy("mOpenSessionCallsPerCaller")
+    private final Object mOpenSessionCallsLock = new Object();
+    /**
+     * Map from UID to number of opened sessions. This is used to rate-limit apps that open
+     * sessions frequently.
+     */
+    @GuardedBy("mOpenSessionCallsLock")
+    private final SparseIntArray mOpenSessionCallsPerUid = new SparseIntArray();
+    /**
+     * Map from key {@code OpenSessionKey} to count of opened sessions. This is for recording
+     * the caller of open session and it is only for debugging.
+     */
+    @GuardedBy("mOpenSessionCallsLock")
     private final HashMap<OpenSessionKey, Integer> mOpenSessionCallsPerCaller = new HashMap<>();
 
     private final static int DUMP_STATS_SESSION_COUNT = 20;
@@ -415,23 +425,20 @@
      * the caller.
      */
     private static class OpenSessionKey {
-        public final int mPid;
-        public final int mUid;
-        public final String mPackage;
+        public final int uid;
+        public final String packageName;
 
-        OpenSessionKey(int pid, int uid, @NonNull String packageName) {
-            mPid = pid;
-            mUid = uid;
-            mPackage = packageName;
+        OpenSessionKey(int uid, @NonNull String packageName) {
+            this.uid = uid;
+            this.packageName = packageName;
         }
 
         @Override
         public String toString() {
             final StringBuilder sb = new StringBuilder();
             sb.append("{");
-            sb.append("pid=").append(mPid).append(",");
-            sb.append("uid=").append(mUid).append(",");
-            sb.append("package=").append(mPackage);
+            sb.append("uid=").append(uid).append(",");
+            sb.append("package=").append(packageName);
             sb.append("}");
             return sb.toString();
         }
@@ -442,13 +449,12 @@
             if (o.getClass() != getClass()) return false;
 
             final OpenSessionKey key = (OpenSessionKey) o;
-            return mPid == key.mPid && mUid == key.mUid
-                    && TextUtils.equals(mPackage, key.mPackage);
+            return this.uid == key.uid && TextUtils.equals(this.packageName, key.packageName);
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(mPid, mUid, mPackage);
+            return Objects.hash(uid, packageName);
         }
     }
 
@@ -843,21 +849,27 @@
         final long lastCallTime;
         final long now = SystemClock.elapsedRealtime();
 
-        synchronized (mOpenSessionCallsPerCaller) {
-            Integer calls = mOpenSessionCallsPerCaller.get(key);
-            if (calls == null) {
+        synchronized (mOpenSessionCallsLock) {
+            Integer callsPerCaller = mOpenSessionCallsPerCaller.get(key);
+            if (callsPerCaller == null) {
                 mOpenSessionCallsPerCaller.put((key), 1);
             } else {
-                mOpenSessionCallsPerCaller.put(key, Integer.sum(calls, 1));
+                mOpenSessionCallsPerCaller.put(key, Integer.sum(callsPerCaller, 1));
             }
+
+            int callsPerUid = mOpenSessionCallsPerUid.get(key.uid, 0);
+            mOpenSessionCallsPerUid.put(key.uid, callsPerUid + 1);
+
+            if (key.uid == android.os.Process.SYSTEM_UID) {
+                return false;
+            }
+
+            // Update mLastStatsSessionPoll only after the SYSTEM_UID check above, so that
+            // sessions opened by the system do not cause non-system uids to be rate-limited.
             lastCallTime = mLastStatsSessionPoll;
             mLastStatsSessionPoll = now;
         }
 
-        if (key.mUid == android.os.Process.SYSTEM_UID) {
-            return false;
-        }
-
         return now - lastCallTime < POLL_RATE_LIMIT_MS;
     }
 
@@ -870,9 +882,8 @@
             flags |= NetworkStatsManager.FLAG_POLL_ON_OPEN;
         }
         // Non-system uids are rate limited for POLL_ON_OPEN.
-        final int callingPid = Binder.getCallingPid();
         final int callingUid = Binder.getCallingUid();
-        final OpenSessionKey key = new OpenSessionKey(callingPid, callingUid, callingPackage);
+        final OpenSessionKey key = new OpenSessionKey(callingUid, callingPackage);
         flags = isRateLimitedForPoll(key)
                 ? flags & (~NetworkStatsManager.FLAG_POLL_ON_OPEN)
                 : flags;
@@ -1990,6 +2001,9 @@
         for (int uid : uids) {
             deleteKernelTagData(uid);
         }
+
+        // TODO: Remove the UID's entries from mOpenSessionCallsPerUid and
+        // mOpenSessionCallsPerCaller
     }
 
     /**
@@ -2114,7 +2128,7 @@
 
             // Get the top openSession callers
             final HashMap calls;
-            synchronized (mOpenSessionCallsPerCaller) {
+            synchronized (mOpenSessionCallsLock) {
                 calls = new HashMap<>(mOpenSessionCallsPerCaller);
             }
             final List<Map.Entry<OpenSessionKey, Integer>> list = new ArrayList<>(calls.entrySet());
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 7de7c27..66ef9e9 100644
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -1187,6 +1187,7 @@
     /**
      * Keeps track of the number of requests made under different uids.
      */
+    // TODO: Remove the hack and use com.android.net.module.util.PerUidCounter instead.
     public static class PerUidCounter {
         private final int mMaxCountPerUid;
 
diff --git a/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java b/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java
index d618915..7c47a70 100644
--- a/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java
@@ -870,10 +870,9 @@
         }
     }
 
+    @DevSdkIgnoreRule.IgnoreUpTo(SC_V2)
     @Test
     public void testDataMigrationUtils() throws Exception {
-        if (!SdkLevel.isAtLeastT()) return;
-
         final List<String> prefixes = List.of(PREFIX_UID, PREFIX_XT, PREFIX_UID_TAG);
         for (final String prefix : prefixes) {
             final long duration = TextUtils.equals(PREFIX_XT, prefix) ? TimeUnit.HOURS.toMillis(1)
diff --git a/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java b/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java
index 13a85e8..e8c9637 100644
--- a/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java
+++ b/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java
@@ -32,6 +32,7 @@
 import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyInt;
@@ -64,6 +65,7 @@
 import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
 
+import java.util.ArrayList;
 import java.util.Objects;
 
 /**
@@ -185,6 +187,51 @@
     }
 
     @Test
+    public void testRegister_limit() throws Exception {
+        final DataUsageRequest inputRequest = new DataUsageRequest(
+                DataUsageRequest.REQUEST_ID_UNSET, sTemplateWifi, THRESHOLD_BYTES);
+
+        // Register the maximum number of requests allowed for the red uid.
+        final ArrayList<DataUsageRequest> redRequests = new ArrayList<>();
+        for (int i = 0; i < NetworkStatsObservers.MAX_REQUESTS_PER_UID; i++) {
+            final DataUsageRequest returnedRequest =
+                    mStatsObservers.register(mContext, inputRequest, mUsageCallback,
+                            PID_RED, UID_RED, PACKAGE_RED, NetworkStatsAccess.Level.DEVICE);
+            redRequests.add(returnedRequest);
+            assertTrue(returnedRequest.requestId > 0);
+        }
+
+        // Verify that one more request from the red uid exceeds the limit and throws.
+        assertThrows(IllegalStateException.class, () ->
+                mStatsObservers.register(mContext, inputRequest, mUsageCallback,
+                        PID_RED, UID_RED, PACKAGE_RED, NetworkStatsAccess.Level.DEVICE));
+
+        // Verify another uid is not affected: blue can still register up to the maximum.
+        final ArrayList<DataUsageRequest> blueRequests = new ArrayList<>();
+        for (int i = 0; i < NetworkStatsObservers.MAX_REQUESTS_PER_UID; i++) {
+            final DataUsageRequest returnedRequest =
+                    mStatsObservers.register(mContext, inputRequest, mUsageCallback,
+                            PID_BLUE, UID_BLUE, PACKAGE_BLUE, NetworkStatsAccess.Level.DEVICE);
+            blueRequests.add(returnedRequest);
+            assertTrue(returnedRequest.requestId > 0);
+        }
+
+        // Again, verify that a request exceeding the limit throws, this time for the 2nd uid.
+        assertThrows(IllegalStateException.class, () ->
+                mStatsObservers.register(mContext, inputRequest, mUsageCallback,
+                        PID_BLUE, UID_BLUE, PACKAGE_BLUE, NetworkStatsAccess.Level.DEVICE));
+
+        // Unregister all registered requests. Note that exceptions cannot be tested since
+        // unregister is handled in the handler thread.
+        for (final DataUsageRequest request : redRequests) {
+            mStatsObservers.unregister(request, UID_RED);
+        }
+        for (final DataUsageRequest request : blueRequests) {
+            mStatsObservers.unregister(request, UID_BLUE);
+        }
+    }
+
+    @Test
     public void testUnregister_unknownRequest_noop() throws Exception {
         DataUsageRequest unknownRequest = new DataUsageRequest(
                 123456 /* id */, sTemplateWifi, THRESHOLD_BYTES);