Merge changes I1707d28e,Ibb4a2916 into main

* changes:
  Dump top distinct tag counts per uid
  Support per uid tag throttling to NetworkStats
diff --git a/service-t/src/com/android/server/net/NetworkStatsFactory.java b/service-t/src/com/android/server/net/NetworkStatsFactory.java
index 5ff708d..c5a69c0 100644
--- a/service-t/src/com/android/server/net/NetworkStatsFactory.java
+++ b/service-t/src/com/android/server/net/NetworkStatsFactory.java
@@ -19,6 +19,7 @@
 import static android.net.NetworkStats.INTERFACES_ALL;
 import static android.net.NetworkStats.TAG_ALL;
 import static android.net.NetworkStats.UID_ALL;
+import static android.provider.DeviceConfig.NAMESPACE_TETHERING;
 
 import android.annotation.NonNull;
 import android.content.Context;
@@ -26,15 +27,26 @@
 import android.net.UnderlyingNetworkInfo;
 import android.os.ServiceSpecificException;
 import android.os.SystemClock;
+import android.util.ArraySet;
+import android.util.IndentingPrintWriter;
+import android.util.Log;
+import android.util.Pair;
+import android.util.SparseArray;
+import android.util.SparseBooleanArray;
 
 import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
+import com.android.net.module.util.DeviceConfigUtils;
 import com.android.server.BpfNetMaps;
 import com.android.server.connectivity.InterfaceTracker;
 
 import java.io.IOException;
 import java.net.ProtocolException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 
 /**
@@ -65,6 +77,18 @@
     /** Set containing info about active VPNs and their underlying networks. */
     private volatile UnderlyingNetworkInfo[] mUnderlyingNetworkInfos = new UnderlyingNetworkInfo[0];
 
+    static final String CONFIG_PER_UID_TAG_THROTTLING = "per_uid_tag_throttling";
+    static final String CONFIG_PER_UID_TAG_THROTTLING_THRESHOLD =
+            "per_uid_tag_throttling_threshold";
+    private static final int DEFAULT_TAGS_PER_UID_THRESHOLD = 1000;
+    private static final int DUMP_TAGS_PER_UID_COUNT = 20;
+    private final boolean mSupportPerUidTagThrottling;
+    private final int mPerUidTagThrottlingThreshold;
+
+    // Map for set of distinct tags per uid. Used for tag count limiting.
+    @GuardedBy("mPersistentDataLock")
+    private final SparseArray<SparseBooleanArray> mUidTagSets = new SparseArray<>();
+
     // A persistent snapshot of cumulative stats since device start
     @GuardedBy("mPersistentDataLock")
     private NetworkStats mPersistSnapshot;
@@ -110,6 +134,26 @@
         public BpfNetMaps createBpfNetMaps(@NonNull Context ctx) {
             return new BpfNetMaps(ctx, new InterfaceTracker(ctx));
         }
+
+        /**
+         * Check whether one specific feature is not disabled. Wrapped here for test injection.
+         * @param name Flag name of the experiment in the tethering namespace.
+         * @see DeviceConfigUtils#isTetheringFeatureNotChickenedOut(Context, String)
+         */
+        public boolean isFeatureNotChickenedOut(@NonNull Context context, @NonNull String name) {
+            return DeviceConfigUtils.isTetheringFeatureNotChickenedOut(context, name);
+        }
+
+        /**
+         * Wrapper method for DeviceConfigUtils#getDeviceConfigPropertyInt for test injections.
+         * The property is always looked up in {@code NAMESPACE_TETHERING}.
+         * See {@link DeviceConfigUtils#getDeviceConfigPropertyInt(String, String, int)}
+         * for more detailed information.
+         */
+        public int getDeviceConfigPropertyInt(@NonNull String name, int defaultValue) {
+            return DeviceConfigUtils.getDeviceConfigPropertyInt(
+                    NAMESPACE_TETHERING, name, defaultValue);
+        }
     }
 
     /**
@@ -162,6 +206,10 @@
         }
         mContext = ctx;
         mDeps = deps;
+        mSupportPerUidTagThrottling = mDeps.isFeatureNotChickenedOut(
+            ctx, CONFIG_PER_UID_TAG_THROTTLING);
+        mPerUidTagThrottlingThreshold = mDeps.getDeviceConfigPropertyInt(
+                CONFIG_PER_UID_TAG_THROTTLING_THRESHOLD, DEFAULT_TAGS_PER_UID_THRESHOLD);
     }
 
     /**
@@ -210,10 +258,13 @@
             requestSwapActiveStatsMapLocked();
             // Stats are always read from the inactive map, so they must be read after the
             // swap
-            final NetworkStats stats = mDeps.getNetworkStatsDetail();
+            final NetworkStats diff = mDeps.getNetworkStatsDetail();
+            // Filter based on UID tag set before merging.
+            final NetworkStats filteredDiff = mSupportPerUidTagThrottling
+                    ? filterStatsByUidTagSets(diff) : diff;
             // BPF stats are incremental; fold into mPersistSnapshot.
-            mPersistSnapshot.setElapsedRealtime(stats.getElapsedRealtime());
-            mPersistSnapshot.combineAllValues(stats);
+            mPersistSnapshot.setElapsedRealtime(diff.getElapsedRealtime());
+            mPersistSnapshot.combineAllValues(filteredDiff);
 
             NetworkStats adjustedStats = adjustForTunAnd464Xlat(mPersistSnapshot, prev, vpnArray);
 
@@ -224,6 +275,41 @@
     }
 
     @GuardedBy("mPersistentDataLock")
+    private NetworkStats filterStatsByUidTagSets(NetworkStats stats) {
+        // Entries that pass the per-uid distinct-tag limit are accumulated here.
+        final NetworkStats accepted = new NetworkStats(stats.getElapsedRealtime(), stats.size());
+        final Set<Integer> throttledUids = new ArraySet<>();
+        final NetworkStats.Entry recycle = new NetworkStats.Entry();
+
+        final int count = stats.size();
+        for (int index = 0; index < count; index++) {
+            stats.getValues(index, recycle);
+            // Untagged entries are always kept and never counted against the limit.
+            if (recycle.tag == NetworkStats.TAG_NONE) {
+                accepted.combineValues(recycle);
+                continue;
+            }
+
+            final SparseBooleanArray known = mUidTagSets.get(recycle.uid);
+            final SparseBooleanArray tags = (known != null) ? known : new SparseBooleanArray();
+            // A uid at its tag limit may keep using known tags but cannot add new ones.
+            if (tags.size() >= mPerUidTagThrottlingThreshold && !tags.get(recycle.tag)) {
+                throttledUids.add(recycle.uid);
+                continue;
+            }
+            accepted.combineValues(recycle);
+            tags.put(recycle.tag, true);
+            mUidTagSets.put(recycle.uid, tags);
+        }
+
+        // Report all offending uids once per call rather than once per rejected entry.
+        if (!throttledUids.isEmpty()) {
+            Log.wtf(TAG, "Too many tags detected for uids: " + throttledUids);
+        }
+        return accepted;
+    }
+
+    @GuardedBy("mPersistentDataLock")
     private NetworkStats adjustForTunAnd464Xlat(NetworkStats uidDetailStats,
             NetworkStats previousStats, UnderlyingNetworkInfo[] vpnArray) {
         // Calculate delta from last snapshot
@@ -307,4 +393,34 @@
         pe.initCause(cause);
         return pe;
     }
+
+    /**
+     * Dump the contents of NetworkStatsFactory; currently the per-uid distinct tag counts.
+     */
+    public void dump(IndentingPrintWriter pw) {
+        dumpUidTagSets(pw);
+    }
+
+    private void dumpUidTagSets(IndentingPrintWriter pw) {
+        pw.println("Top distinct tag counts in UidTagSets:");
+        pw.increaseIndent();
+        // Snapshot (uid, distinct tag count) pairs under the lock; sort and print outside it.
+        final List<Pair<Integer, Integer>> tagCounts = new ArrayList<>();
+        synchronized (mPersistentDataLock) {
+            final int uidCount = mUidTagSets.size();
+            for (int index = 0; index < uidCount; index++) {
+                final int uid = mUidTagSets.keyAt(index);
+                tagCounts.add(new Pair<>(uid, mUidTagSets.valueAt(index).size()));
+            }
+        }
+        // Order by descending tag count so the uids with the most tags are printed first.
+        Collections.sort(tagCounts, (lhs, rhs) -> Integer.compare(rhs.second, lhs.second));
+        final int limit = Math.min(tagCounts.size(), DUMP_TAGS_PER_UID_COUNT);
+        for (final Pair<Integer, Integer> uidAndCount : tagCounts.subList(0, limit)) {
+            pw.print(uidAndCount.first);
+            pw.print("=");
+            pw.println(uidAndCount.second);
+        }
+        pw.decreaseIndent();
+    }
 }
diff --git a/service-t/src/com/android/server/net/NetworkStatsService.java b/service-t/src/com/android/server/net/NetworkStatsService.java
index 5c5f4ca..75d30a9 100644
--- a/service-t/src/com/android/server/net/NetworkStatsService.java
+++ b/service-t/src/com/android/server/net/NetworkStatsService.java
@@ -3228,6 +3228,12 @@
             pw.increaseIndent();
             mSkDestroyListener.dump(pw);
             pw.decreaseIndent();
+
+            pw.println();
+            pw.println("NetworkStatsFactory logs:");
+            pw.increaseIndent();
+            mStatsFactory.dump(pw);
+            pw.decreaseIndent();
         }
     }
 
diff --git a/tests/unit/java/com/android/server/net/NetworkStatsFactoryTest.java b/tests/unit/java/com/android/server/net/NetworkStatsFactoryTest.java
index 63daebc..89acf69 100644
--- a/tests/unit/java/com/android/server/net/NetworkStatsFactoryTest.java
+++ b/tests/unit/java/com/android/server/net/NetworkStatsFactoryTest.java
@@ -20,6 +20,7 @@
 import static android.net.NetworkStats.DEFAULT_NETWORK_NO;
 import static android.net.NetworkStats.METERED_ALL;
 import static android.net.NetworkStats.METERED_NO;
+import static android.net.NetworkStats.METERED_YES;
 import static android.net.NetworkStats.ROAMING_ALL;
 import static android.net.NetworkStats.ROAMING_NO;
 import static android.net.NetworkStats.SET_ALL;
@@ -29,6 +30,8 @@
 import static android.net.NetworkStats.TAG_NONE;
 import static android.net.NetworkStats.UID_ALL;
 
+import static com.android.server.net.NetworkStatsFactory.CONFIG_PER_UID_TAG_THROTTLING;
+import static com.android.server.net.NetworkStatsFactory.CONFIG_PER_UID_TAG_THROTTLING_THRESHOLD;
 import static com.android.server.net.NetworkStatsFactory.kernelToTag;
 import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
 
@@ -36,6 +39,9 @@
 import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
 
 import android.content.Context;
@@ -52,12 +58,15 @@
 import com.android.server.BpfNetMaps;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRunner;
+import com.android.testutils.com.android.testutils.SetFeatureFlagsRule;
+import com.android.testutils.com.android.testutils.SetFeatureFlagsRule.FeatureFlag;
 
 import libcore.io.IoUtils;
 import libcore.testing.io.TestIoUtils;
 
 import org.junit.After;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
@@ -66,6 +75,7 @@
 import java.io.File;
 import java.io.IOException;
 import java.net.ProtocolException;
+import java.util.HashMap;
 
 /** Tests for {@link NetworkStatsFactory}. */
 @RunWith(DevSdkIgnoreRunner.class)
@@ -73,6 +83,7 @@
 @DevSdkIgnoreRule.IgnoreUpTo(SC_V2)
 public class NetworkStatsFactoryTest extends NetworkStatsBaseTest {
     private static final String CLAT_PREFIX = "v4-";
+    private static final int TEST_TAGS_PER_UID_THRESHOLD = 10;
 
     private File mTestProc;
     private NetworkStatsFactory mFactory;
@@ -80,6 +91,16 @@
     @Mock private NetworkStatsFactory.Dependencies mDeps;
     @Mock private BpfNetMaps mBpfNetMaps;
 
+    final HashMap<String, Boolean> mFeatureFlags = new HashMap<>();
+    // This will set feature flags from @FeatureFlag annotations
+    // into the map before setUp() runs.
+    @Rule
+    public final SetFeatureFlagsRule mSetFeatureFlagsRule =
+            new SetFeatureFlagsRule((name, enabled) -> {
+                mFeatureFlags.put(name, enabled);
+                return null;
+            }, (name) -> mFeatureFlags.getOrDefault(name, false));
+
     @Before
     public void setUp() throws Exception {
         MockitoAnnotations.initMocks(this);
@@ -90,6 +111,10 @@
         // related to networkStatsFactory is compiled to a minimal native library and loaded here.
         System.loadLibrary("networkstatsfactorytestjni");
         doReturn(mBpfNetMaps).when(mDeps).createBpfNetMaps(any());
+        doAnswer(invocation -> mFeatureFlags.getOrDefault((String) invocation.getArgument(1), true))
+            .when(mDeps).isFeatureNotChickenedOut(any(), anyString());
+        doReturn(TEST_TAGS_PER_UID_THRESHOLD).when(mDeps)
+                .getDeviceConfigPropertyInt(eq(CONFIG_PER_UID_TAG_THROTTLING_THRESHOLD), anyInt());
 
         mFactory = new NetworkStatsFactory(mContext, mDeps);
         mFactory.updateUnderlyingNetworkInfos(new UnderlyingNetworkInfo[0]);
@@ -498,6 +523,71 @@
         assertValues(removedUidsStats, TEST_IFACE, UID_GREEN, 64L, 3L, 1024L, 8L);
     }
 
+    @FeatureFlag(name = CONFIG_PER_UID_TAG_THROTTLING)
+    @Test
+    public void testFilterTooManyTags_featureEnabled() throws Exception {
+        doTestFilterTooManyTags(true);
+    }
+
+    @FeatureFlag(name = CONFIG_PER_UID_TAG_THROTTLING, enabled = false)
+    @Test
+    public void testFilterTooManyTags_featureDisabled() throws Exception {
+        doTestFilterTooManyTags(false);
+    }
+
+    private void doTestFilterTooManyTags(boolean supportPerUidTagThrottling) throws Exception {
+        // Fill UID_RED up to exactly the per-uid tag threshold; all of these must be accepted.
+        final NetworkStats statsWithManyTags = new NetworkStats(0L, TEST_TAGS_PER_UID_THRESHOLD);
+        for (int tag = 1; tag <= TEST_TAGS_PER_UID_THRESHOLD; tag++) {
+            statsWithManyTags.combineValues(
+                    new NetworkStats.Entry(TEST_IFACE, UID_RED, SET_DEFAULT, tag,
+                            METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 12L, 18L, 14L, 1L, 0L));
+        }
+        doReturn(statsWithManyTags).when(mDeps).getNetworkStatsDetail();
+        final NetworkStats stats1 = mFactory.readNetworkStatsDetail();
+        assertEquals(TEST_TAGS_PER_UID_THRESHOLD, stats1.size());
+
+        // Add 2 new entries with pre-existing tag, verify they can be added no matter what.
+        final NetworkStats newDiffWithExistingTag = new NetworkStats(0L, 2);
+        // This one should be added as a new entry, as the metered data doesn't exist yet.
+        newDiffWithExistingTag.combineValues(
+                new NetworkStats.Entry(TEST_IFACE, UID_RED, SET_DEFAULT,
+                        TEST_TAGS_PER_UID_THRESHOLD,
+                        METERED_YES, ROAMING_NO, DEFAULT_NETWORK_NO, 3L, 5L, 8L, 1L, 1L));
+        // This one should be combined into existing entry.
+        newDiffWithExistingTag.combineValues(
+                new NetworkStats.Entry(TEST_IFACE, UID_RED, SET_DEFAULT,
+                        TEST_TAGS_PER_UID_THRESHOLD,
+                        METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 1L, 2L, 3L, 4L, 5L));
+
+        doReturn(newDiffWithExistingTag).when(mDeps).getNetworkStatsDetail();
+        final NetworkStats stats2 = mFactory.readNetworkStatsDetail();
+        assertEquals(TEST_TAGS_PER_UID_THRESHOLD + 1, stats2.size());
+        assertValues(stats2, TEST_IFACE, UID_RED, SET_DEFAULT, TEST_TAGS_PER_UID_THRESHOLD,
+                METERED_YES, ROAMING_NO, DEFAULT_NETWORK_NO, 3L, 5L, 8L, 1L, 1L);
+        assertValues(stats2, TEST_IFACE, UID_RED, SET_DEFAULT, TEST_TAGS_PER_UID_THRESHOLD,
+                METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 13L, 20L, 17L, 5L, 5L);
+
+        // Add an entry with a not-yet-seen tag, which would exceed the threshold for UID_RED.
+        final NetworkStats newDiffWithNonExistingTag = new NetworkStats(0L, 1);
+        newDiffWithNonExistingTag.combineValues(
+                new NetworkStats.Entry(TEST_IFACE, UID_RED, SET_DEFAULT,
+                        TEST_TAGS_PER_UID_THRESHOLD + 1,
+                        METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 12L, 18L, 14L, 1L, 0L));
+        doReturn(newDiffWithNonExistingTag).when(mDeps).getNetworkStatsDetail();
+        final NetworkStats stats3 = mFactory.readNetworkStatsDetail();
+        if (supportPerUidTagThrottling) {
+            assertEquals(TEST_TAGS_PER_UID_THRESHOLD + 1, stats3.size());
+            assertNoStatsEntry(stats3, TEST_IFACE, UID_RED, SET_DEFAULT,
+                    TEST_TAGS_PER_UID_THRESHOLD + 1);
+        } else {
+            assertEquals(TEST_TAGS_PER_UID_THRESHOLD + 2, stats3.size());
+            assertValues(stats3, TEST_IFACE, UID_RED, SET_DEFAULT,
+                    TEST_TAGS_PER_UID_THRESHOLD + 1,
+                    METERED_NO, ROAMING_NO, DEFAULT_NETWORK_NO, 12L, 18L, 14L, 1L, 0L);
+        }
+    }
+
     private NetworkStats buildEmptyStats() {
         return new NetworkStats(SystemClock.elapsedRealtime(), 0);
     }