Merge changes from topic "trafficstats-client-cache" into main

* changes:
  Introduce client-side rate limiting for TrafficStats APIs
  Unit test for TrafficStats Rate Limit Cache
  Add `putIfAbsent` method to LruCacheWithExpiry
diff --git a/framework-t/src/android/net/TrafficStats.java b/framework-t/src/android/net/TrafficStats.java
index 81f2cf9..868033a 100644
--- a/framework-t/src/android/net/TrafficStats.java
+++ b/framework-t/src/android/net/TrafficStats.java
@@ -17,6 +17,7 @@
 package android.net;
 
 import static android.annotation.SystemApi.Client.MODULE_LIBRARIES;
+import static android.net.NetworkStats.UID_ALL;
 
 import static com.android.internal.annotations.VisibleForTesting.Visibility.PRIVATE;
 
@@ -33,21 +34,25 @@
 import android.content.Context;
 import android.media.MediaPlayer;
 import android.net.netstats.StatsResult;
+import android.net.netstats.TrafficStatsRateLimitCacheConfig;
 import android.os.Binder;
 import android.os.Build;
 import android.os.RemoteException;
 import android.os.StrictMode;
+import android.os.SystemClock;
 import android.util.Log;
 
 import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
+import com.android.net.module.util.BinderUtils;
+import com.android.net.module.util.LruCacheWithExpiry;
 
 import java.io.FileDescriptor;
 import java.io.IOException;
 import java.net.DatagramSocket;
 import java.net.Socket;
 import java.net.SocketException;
-
+import java.util.function.LongSupplier;
 
 /**
  * Class that provides network traffic statistics. These statistics include
@@ -182,13 +187,48 @@
     /** @hide */
     public static final int TAG_SYSTEM_PROBE = 0xFFFFFF42;
 
+    private static final StatsResult EMPTY_STATS = new StatsResult(0L, 0L, 0L, 0L);
+
+    private static final Object sRateLimitCacheLock = new Object();
+
     @GuardedBy("TrafficStats.class")
+    @Nullable
     private static INetworkStatsService sStatsService;
 
     // The variable will only be accessed in the test, which is effectively
     // single-threaded.
+    @Nullable
     private static INetworkStatsService sStatsServiceForTest = null;
 
+    // Configuration for the client-side TrafficStats rate-limit caches. It is
+    // fetched lazily from the service the first time it is needed (e.g. on the
+    // first get*Stats call) and afterwards only replaced by reinitRateLimitCacheForTest().
+    // This variable can be accessed from any thread with sRateLimitCacheLock held.
+    @GuardedBy("sRateLimitCacheLock")
+    @Nullable
+    private static TrafficStatsRateLimitCacheConfig sRateLimitCacheConfig;
+
+    // Cache for the getIfaceStats binder interface, keyed by interface name.
+    // This variable can be accessed from any thread with the lock held,
+    // while the cache itself is thread-safe and can be accessed outside
+    // the lock. It is null until initialized or when caching is disabled.
+    @GuardedBy("sRateLimitCacheLock")
+    @Nullable
+    private static LruCacheWithExpiry<String, StatsResult> sRateLimitIfaceCache;
+
+    // Cache for the getUidStats and getTotalStats binder interfaces, keyed by uid
+    // (total stats are stored under UID_ALL). This variable can be accessed from
+    // any thread with the lock held, while the cache itself is thread-safe and can
+    // be accessed outside the lock. Null until initialized or if caching is disabled.
+    @GuardedBy("sRateLimitCacheLock")
+    @Nullable
+    private static LruCacheWithExpiry<Integer, StatsResult> sRateLimitUidCache;
+
+    // Only set by tests; when null, SystemClock.elapsedRealtime() is used. It is
+    // read whenever the caches are (re)initialized.
+    @Nullable
+    private static LongSupplier sTimeSupplierForTest = null;
+
     @UnsupportedAppUsage(maxTargetSdk = Build.VERSION_CODES.P, trackingBug = 130143562)
     private synchronized static INetworkStatsService getStatsService() {
         if (sStatsServiceForTest != null) return sStatsServiceForTest;
@@ -215,6 +255,28 @@
     }
 
     /**
+     * Set time supplier for test, or null to reset; applied when the caches are recreated.
+     *
+     * @hide
+     */
+    @VisibleForTesting(visibility = PRIVATE)
+    public static void setTimeSupplierForTest(@Nullable LongSupplier timeSupplier) {
+        sTimeSupplierForTest = timeSupplier;
+    }
+
+    /**
+     * Re-query the rate-limit cache config from the service and re-create the caches.
+     *
+     * This is for testing purposes only.
+     *
+     * @hide
+     */
+    @VisibleForTesting(visibility = PRIVATE)
+    public static void reinitRateLimitCacheForTest() {
+        maybeGetConfigAndInitRateLimitCache(true /* forceReinit */);
+    }
+
+    /**
      * Snapshot of {@link NetworkStats} when the currently active profiling
      * session started, or {@code null} if no session active.
      *
@@ -254,6 +316,92 @@
         sStatsService = statsManager.getBinder();
     }
 
+    @Nullable
+    private static LruCacheWithExpiry<String, StatsResult> maybeGetRateLimitIfaceCache() {
+        if (!maybeGetConfigAndInitRateLimitCache(false /* forceReinit */)) return null;
+        synchronized (sRateLimitCacheLock) {
+            return sRateLimitIfaceCache;
+        }
+    }
+
+    @Nullable
+    private static LruCacheWithExpiry<Integer, StatsResult> maybeGetRateLimitUidCache() {
+        if (!maybeGetConfigAndInitRateLimitCache(false /* forceReinit */)) return null;
+        synchronized (sRateLimitCacheLock) {
+            return sRateLimitUidCache;
+        }
+    }
+
+    /**
+     * Gets the rate-limit cache configuration and initializes the caches if needed.
+     *
+     * The configuration is not expected to change dynamically, so once it has been
+     * fetched from the service it is kept and used to initialize the rate-limit
+     * caches.
+     *
+     * @param forceReinit re-fetch the configuration and re-create the caches even if
+     *                    they are already initialized; used by tests.
+     * @return whether the rate-limit cache is enabled.
+     */
+    private static boolean maybeGetConfigAndInitRateLimitCache(boolean forceReinit) {
+        // Access the service outside the lock to avoid potential deadlocks. This is
+        // especially important when the caller is a system component (e.g.,
+        // NetworkPolicyManagerService) that might hold other locks that the service
+        // also needs.
+        // This allows multiple threads to query the service concurrently, which is
+        // acceptable because the configuration doesn't change dynamically. Unless
+        // forceReinit is set, the first stored config wins; the returned flag always
+        // comes from the stored config so it agrees with the caches actually installed.
+        synchronized (sRateLimitCacheLock) {
+            if (sRateLimitCacheConfig != null && !forceReinit) {
+                return sRateLimitCacheConfig.isCacheEnabled;
+            }
+        }
+
+        final TrafficStatsRateLimitCacheConfig config;
+        try {
+            config = getStatsService().getRateLimitCacheConfig();
+        } catch (RemoteException e) {
+            throw e.rethrowFromSystemServer();
+        }
+
+        synchronized (sRateLimitCacheLock) {
+            if (sRateLimitCacheConfig == null || forceReinit) {
+                sRateLimitCacheConfig = config;
+                initRateLimitCacheLocked();
+            }
+            return sRateLimitCacheConfig.isCacheEnabled;
+        }
+    }
+
+    @GuardedBy("sRateLimitCacheLock")
+    private static void initRateLimitCacheLocked() {
+        // (Re)create the caches from sRateLimitCacheConfig, or drop them if disabled.
+        // Total stats share the uid cache under the UID_ALL key.
+        if (sRateLimitCacheConfig.isCacheEnabled) {
+            // A time supplier which is monotonic until device reboots, and counts
+            // time spent in sleep. This is needed to ensure the get*Stats caller
+            // won't get stale values after system time adjustment or waking from sleep.
+            final LongSupplier realtimeSupplier = (sTimeSupplierForTest != null
+                    ? sTimeSupplierForTest : SystemClock::elapsedRealtime);
+            sRateLimitIfaceCache = new LruCacheWithExpiry<String, StatsResult>(
+                    realtimeSupplier,
+                    sRateLimitCacheConfig.expiryDurationMs,
+                    sRateLimitCacheConfig.maxEntries,
+                    (statsResult) -> !isEmpty(statsResult)
+            );
+            sRateLimitUidCache = new LruCacheWithExpiry<Integer, StatsResult>(
+                    realtimeSupplier,
+                    sRateLimitCacheConfig.expiryDurationMs,
+                    sRateLimitCacheConfig.maxEntries,
+                    (statsResult) -> !isEmpty(statsResult)
+            );
+        } else {
+            sRateLimitIfaceCache = null;
+            sRateLimitUidCache = null;
+        }
+    }
+
     /**
      * Attach the socket tagger implementation to the current process, to
      * get notified when a socket's {@link FileDescriptor} is assigned to
@@ -736,6 +884,14 @@
             android.Manifest.permission.NETWORK_STACK,
             android.Manifest.permission.NETWORK_SETTINGS})
     public static void clearRateLimitCaches() {
+        // Drop this process's client-side caches first; the service-side caches are
+        // cleared below. Each getter returns null when rate limiting is disabled.
+        final LruCacheWithExpiry<String, StatsResult> ifaceCache =
+                maybeGetRateLimitIfaceCache();
+        if (ifaceCache != null) ifaceCache.clear();
+        final LruCacheWithExpiry<Integer, StatsResult> uidCache =
+                maybeGetRateLimitUidCache();
+        if (uidCache != null) uidCache.clear();
         try {
             getStatsService().clearTrafficStatsRateLimitCaches();
         } catch (RemoteException e) {
@@ -985,35 +1141,76 @@
 
     /** @hide */
     public static long getUidStats(int uid, int type) {
-        final StatsResult stats;
+        return fetchStats(maybeGetRateLimitUidCache(), uid,
+                () -> getStatsService().getUidStats(uid), type);
+    }
+
+    // Note: This method may call into the service; do not invoke it with a lock held.
+    private static <K> long fetchStats(
+            @Nullable LruCacheWithExpiry<K, StatsResult> cache, K key,
+            BinderUtils.ThrowingSupplier<StatsResult, RemoteException> statsFetcher, int type) {
         try {
-            stats = getStatsService().getUidStats(uid);
+            final StatsResult stats;
+            if (cache != null) {
+                stats = fetchStatsWithCache(cache, key, statsFetcher);
+            } else {
+                // Rate-limit cache is disabled; fetch directly from the service.
+                stats = statsFetcher.get();
+            }
+            return getEntryValueForType(stats, type);
         } catch (RemoteException e) {
             throw e.rethrowFromSystemServer();
         }
-        return getEntryValueForType(stats, type);
+    }
+
+    // Note: This method may call into the service; do not invoke it with a lock held.
+    @Nullable
+    private static <K> StatsResult fetchStatsWithCache(LruCacheWithExpiry<K, StatsResult> cache,
+            K key, BinderUtils.ThrowingSupplier<StatsResult, RemoteException> statsFetcher)
+            throws RemoteException {
+        // Attempt to retrieve an unexpired value from the cache first.
+        StatsResult stats = cache.get(key);
+
+        // Although the cache instance is thread-safe, there is no lock spanning the
+        // get, fetch and put operations, so threads of the same process can race and
+        // potentially return non-monotonic results. This is considered acceptable
+        // because varying thread execution speeds can also cause non-monotonic
+        // results, even with locking. Using putIfAbsent below keeps the results
+        // consistent once a value has been cached.
+        if (stats == null) {
+            // Cache miss (or expired entry): fetch from the service.
+            stats = statsFetcher.get();
+
+            // Only cache valid results; null or empty stats typically mean an invalid query.
+            if (stats != null && !isEmpty(stats)) {
+                final StatsResult cachedValue = cache.putIfAbsent(key, stats);
+                if (cachedValue != null) {
+                    // Another thread cached a value after this thread got a
+                    // cache miss. Return the cached value so that all values
+                    // returned after caching are consistent.
+                    return cachedValue;
+                }
+            }
+        }
+        return stats;
+    }
+
+    private static boolean isEmpty(StatsResult stats) {
+        return stats.equals(EMPTY_STATS);
     }
 
     /** @hide */
     public static long getTotalStats(int type) {
-        final StatsResult stats;
-        try {
-            stats = getStatsService().getTotalStats();
-        } catch (RemoteException e) {
-            throw e.rethrowFromSystemServer();
-        }
-        return getEntryValueForType(stats, type);
+        // Total stats share the uid cache under the UID_ALL key. In practice, BPF
+        // doesn't use UID_ALL for storing per-UID stats, so the keys don't collide.
+        return fetchStats(maybeGetRateLimitUidCache(), UID_ALL,
+                () -> getStatsService().getTotalStats(), type);
     }
 
     /** @hide */
     public static long getIfaceStats(String iface, int type) {
-        final StatsResult stats;
-        try {
-            stats = getStatsService().getIfaceStats(iface);
-        } catch (RemoteException e) {
-            throw e.rethrowFromSystemServer();
-        }
-        return getEntryValueForType(stats, type);
+        return fetchStats(maybeGetRateLimitIfaceCache(), iface,
+                () -> getStatsService().getIfaceStats(iface), type);
     }
 
     /**
diff --git a/service-t/src/com/android/server/net/NetworkStatsService.java b/service-t/src/com/android/server/net/NetworkStatsService.java
index fb712a1..a8e3203 100644
--- a/service-t/src/com/android/server/net/NetworkStatsService.java
+++ b/service-t/src/com/android/server/net/NetworkStatsService.java
@@ -493,7 +493,8 @@
     @Nullable
     private final TrafficStatsRateLimitCache mTrafficStatsUidCache;
     // A feature flag to control whether the client-side rate limit cache should be enabled.
-    static final String TRAFFICSTATS_CLIENT_RATE_LIMIT_CACHE_ENABLED_FLAG =
+    @VisibleForTesting // Public so that TrafficStatsTest (package android.net) can use it.
+    public static final String TRAFFICSTATS_CLIENT_RATE_LIMIT_CACHE_ENABLED_FLAG =
             "trafficstats_client_rate_limit_cache_enabled_flag";
     static final String TRAFFICSTATS_SERVICE_RATE_LIMIT_CACHE_ENABLED_FLAG =
             "trafficstats_rate_limit_cache_enabled_flag";
diff --git a/staticlibs/framework/com/android/net/module/util/LruCacheWithExpiry.java b/staticlibs/framework/com/android/net/module/util/LruCacheWithExpiry.java
index 31382bb..96d995a 100644
--- a/staticlibs/framework/com/android/net/module/util/LruCacheWithExpiry.java
+++ b/staticlibs/framework/com/android/net/module/util/LruCacheWithExpiry.java
@@ -125,6 +125,25 @@
     }
 
     /**
+     * Stores {@code value} for {@code key} unless an unexpired value is already cached.
+     *
+     * @param key   The key to associate with the value.
+     * @param value The value to store in the cache; must not be null.
+     * @return The existing unexpired value (left in place), if present; otherwise, null.
+     */
+    @Nullable
+    public V putIfAbsent(@NonNull K key, @NonNull V value) {
+        Objects.requireNonNull(value);
+        synchronized (mMap) {
+            final V existingValue = get(key);
+            if (existingValue == null) {
+                put(key, value);
+            }
+            return existingValue;
+        }
+    }
+
+    /**
      * Clear the cache.
      */
     public void clear() {
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/LruCacheWithExpiryTest.kt b/staticlibs/tests/unit/src/com/android/net/module/util/LruCacheWithExpiryTest.kt
new file mode 100644
index 0000000..b6af892
--- /dev/null
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/LruCacheWithExpiryTest.kt
@@ -0,0 +1,77 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.net.module.util
+
+import com.android.testutils.DevSdkIgnoreRunner
+import org.junit.Assert.assertEquals
+import org.junit.Assert.assertNull
+import org.junit.Test
+import org.junit.runner.RunWith
+import java.util.function.LongSupplier
+
+@RunWith(DevSdkIgnoreRunner::class)
+class LruCacheWithExpiryTest {
+
+    companion object {
+        private const val CACHE_SIZE = 2
+        private const val EXPIRY_DURATION_MS = 1000L
+    }
+
+    private val timeSupplier = object : LongSupplier {
+        private var currentTimeMillis = 0L
+        override fun getAsLong(): Long = currentTimeMillis
+        fun advanceTime(millis: Long) { currentTimeMillis += millis }
+    }
+
+    private val cache = LruCacheWithExpiry<Int, String>(
+            timeSupplier, EXPIRY_DURATION_MS, CACHE_SIZE) { true }
+
+    @Test
+    fun testPutIfAbsent_keyNotPresent() {
+        val value = cache.putIfAbsent(1, "value1")
+        assertNull(value)
+        assertEquals("value1", cache.get(1))
+    }
+
+    @Test
+    fun testPutIfAbsent_keyPresent() {
+        cache.put(1, "value1")
+        val value = cache.putIfAbsent(1, "value2")
+        assertEquals("value1", value)
+        assertEquals("value1", cache.get(1))
+    }
+
+    @Test
+    fun testPutIfAbsent_keyPresentButExpired() {
+        cache.put(1, "value1")
+        // Advance time to expire the entry
+        timeSupplier.advanceTime(EXPIRY_DURATION_MS + 1)
+        val value = cache.putIfAbsent(1, "value2")
+        assertNull(value)
+        assertEquals("value2", cache.get(1))
+    }
+
+    @Test
+    fun testPutIfAbsent_maxSizeReached() {
+        cache.put(1, "value1")
+        cache.put(2, "value2")
+        assertNull(cache.putIfAbsent(3, "value3")) // Stored; evicts the LRU entry (1)
+        assertNull(cache.get(1))
+        assertEquals("value2", cache.get(2))
+        assertEquals("value3", cache.get(3))
+    }
+}
diff --git a/tests/unit/java/android/net/TrafficStatsTest.kt b/tests/unit/java/android/net/TrafficStatsTest.kt
new file mode 100644
index 0000000..0f85daf
--- /dev/null
+++ b/tests/unit/java/android/net/TrafficStatsTest.kt
@@ -0,0 +1,251 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net
+
+import android.net.TrafficStats.UNSUPPORTED
+import android.net.netstats.StatsResult
+import android.net.netstats.TrafficStatsRateLimitCacheConfig
+import android.os.Build
+import com.android.server.net.NetworkStatsService.TRAFFICSTATS_CLIENT_RATE_LIMIT_CACHE_ENABLED_FLAG
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.com.android.testutils.SetFeatureFlagsRule
+import com.android.testutils.com.android.testutils.SetFeatureFlagsRule.FeatureFlag
+import org.junit.After
+import org.junit.Assert.assertEquals
+import org.junit.Before
+import org.junit.Rule
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.Mockito.clearInvocations
+import org.mockito.Mockito.doReturn
+import org.mockito.Mockito.mock
+import org.mockito.Mockito.times
+import org.mockito.Mockito.verify
+import java.util.HashMap
+import java.util.function.LongSupplier
+
+private const val TEST_EXPIRY_DURATION_MS = 1000
+private const val TEST_IFACE = "wlan0"
+
+@RunWith(DevSdkIgnoreRunner::class)
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
+class TrafficStatsTest {
+    private val binder = mock(INetworkStatsService::class.java)
+    private val myUid = android.os.Process.myUid()
+    private val mockMyUidStatsResult = StatsResult(5L, 6L, 7L, 8L)
+    private val mockIfaceStatsResult = StatsResult(7L, 3L, 10L, 21L)
+    private val mockTotalStatsResult = StatsResult(8L, 1L, 5L, 2L)
+    private val secondUidStatsResult = StatsResult(3L, 7L, 10L, 5L)
+    private val secondIfaceStatsResult = StatsResult(9L, 8L, 7L, 6L)
+    private val secondTotalStatsResult = StatsResult(4L, 3L, 2L, 1L)
+    private val emptyStatsResult = StatsResult(0L, 0L, 0L, 0L)
+    private val unsupportedStatsResult =
+            StatsResult(UNSUPPORTED.toLong(), UNSUPPORTED.toLong(),
+                    UNSUPPORTED.toLong(), UNSUPPORTED.toLong())
+
+    private val cacheDisabledConfig = TrafficStatsRateLimitCacheConfig.Builder()
+            .setIsCacheEnabled(false)
+            .setExpiryDurationMs(0)
+            .setMaxEntries(0)
+            .build()
+    private val cacheEnabledConfig = TrafficStatsRateLimitCacheConfig.Builder()
+            .setIsCacheEnabled(true)
+            .setExpiryDurationMs(TEST_EXPIRY_DURATION_MS)
+            .setMaxEntries(100)
+            .build()
+    private val testTimeSupplier = TestTimeSupplier()
+
+    private val featureFlags = HashMap<String, Boolean>()
+
+    // This will set feature flags from @FeatureFlag annotations
+    // into the map before setUp() runs.
+    @get:Rule
+    val setFeatureFlagsRule = SetFeatureFlagsRule(
+            { name, enabled -> featureFlags.put(name, enabled == true) },
+            { name -> featureFlags.getOrDefault(name, false) }
+    )
+
+    class TestTimeSupplier : LongSupplier {
+        private var currentTimeMillis = 0L
+
+        override fun getAsLong() = currentTimeMillis
+
+        fun advanceTime(millis: Int) {
+            currentTimeMillis += millis
+        }
+    }
+
+    @Before
+    fun setUp() {
+        TrafficStats.setServiceForTest(binder)
+        TrafficStats.setTimeSupplierForTest(testTimeSupplier)
+        mockStats(mockMyUidStatsResult, mockIfaceStatsResult, mockTotalStatsResult)
+        if (featureFlags.getOrDefault(TRAFFICSTATS_CLIENT_RATE_LIMIT_CACHE_ENABLED_FLAG, false)) {
+            doReturn(cacheEnabledConfig).`when`(binder).getRateLimitCacheConfig()
+        } else {
+            doReturn(cacheDisabledConfig).`when`(binder).getRateLimitCacheConfig()
+        }
+        TrafficStats.reinitRateLimitCacheForTest()
+    }
+
+    @After
+    fun tearDown() {
+        TrafficStats.setServiceForTest(null)
+        TrafficStats.setTimeSupplierForTest(null)
+        TrafficStats.reinitRateLimitCacheForTest()
+    }
+
+    private fun assertUidStats(uid: Int, stats: StatsResult) {
+        assertEquals(stats.rxBytes, TrafficStats.getUidRxBytes(uid))
+        assertEquals(stats.rxPackets, TrafficStats.getUidRxPackets(uid))
+        assertEquals(stats.txBytes, TrafficStats.getUidTxBytes(uid))
+        assertEquals(stats.txPackets, TrafficStats.getUidTxPackets(uid))
+    }
+
+    private fun assertIfaceStats(iface: String, stats: StatsResult) {
+        assertEquals(stats.rxBytes, TrafficStats.getRxBytes(iface))
+        assertEquals(stats.rxPackets, TrafficStats.getRxPackets(iface))
+        assertEquals(stats.txBytes, TrafficStats.getTxBytes(iface))
+        assertEquals(stats.txPackets, TrafficStats.getTxPackets(iface))
+    }
+
+    private fun assertTotalStats(stats: StatsResult) {
+        assertEquals(stats.rxBytes, TrafficStats.getTotalRxBytes())
+        assertEquals(stats.rxPackets, TrafficStats.getTotalRxPackets())
+        assertEquals(stats.txBytes, TrafficStats.getTotalTxBytes())
+        assertEquals(stats.txPackets, TrafficStats.getTotalTxPackets())
+    }
+
+    private fun mockStats(uidStats: StatsResult?, ifaceStats: StatsResult?,
+                          totalStats: StatsResult?) {
+        doReturn(uidStats).`when`(binder).getUidStats(myUid)
+        doReturn(ifaceStats).`when`(binder).getIfaceStats(TEST_IFACE)
+        doReturn(totalStats).`when`(binder).getTotalStats()
+    }
+
+    private fun assertStats(uidStats: StatsResult, ifaceStats: StatsResult,
+                            totalStats: StatsResult) {
+        assertUidStats(myUid, uidStats)
+        assertIfaceStats(TEST_IFACE, ifaceStats)
+        assertTotalStats(totalStats)
+    }
+
+    private fun assertStatsFetchInvocations(wantedInvocations: Int) {
+        verify(binder, times(wantedInvocations)).getUidStats(myUid)
+        verify(binder, times(wantedInvocations)).getIfaceStats(TEST_IFACE)
+        verify(binder, times(wantedInvocations)).getTotalStats()
+    }
+
+    @FeatureFlag(name = TRAFFICSTATS_CLIENT_RATE_LIMIT_CACHE_ENABLED_FLAG)
+    @Test
+    fun testRateLimitCacheExpiry_cacheEnabled() {
+        // Initial fetch, verify binder calls.
+        assertStats(mockMyUidStatsResult, mockIfaceStatsResult, mockTotalStatsResult)
+        assertStatsFetchInvocations(1)
+
+        // Advance time within expiry, verify cached values used.
+        clearInvocations(binder)
+        mockStats(secondUidStatsResult, secondIfaceStatsResult, secondTotalStatsResult)
+        testTimeSupplier.advanceTime(1)
+        assertStats(mockMyUidStatsResult, mockIfaceStatsResult, mockTotalStatsResult)
+        assertStatsFetchInvocations(0)
+
+        // Advance time to expire cache, verify new values fetched.
+        clearInvocations(binder)
+        testTimeSupplier.advanceTime(TEST_EXPIRY_DURATION_MS)
+        assertStats(secondUidStatsResult, secondIfaceStatsResult, secondTotalStatsResult)
+        assertStatsFetchInvocations(1)
+    }
+
+    @FeatureFlag(name = TRAFFICSTATS_CLIENT_RATE_LIMIT_CACHE_ENABLED_FLAG, enabled = false)
+    @Test
+    fun testRateLimitCacheExpiry_cacheDisabled() {
+        // Initial fetch, verify binder calls.
+        assertStats(mockMyUidStatsResult, mockIfaceStatsResult, mockTotalStatsResult)
+        assertStatsFetchInvocations(4)
+
+        // Advance time within expiry, verify new values fetched.
+        clearInvocations(binder)
+        mockStats(secondUidStatsResult, secondIfaceStatsResult, secondTotalStatsResult)
+        testTimeSupplier.advanceTime(1)
+        assertStats(secondUidStatsResult, secondIfaceStatsResult, secondTotalStatsResult)
+        assertStatsFetchInvocations(4)
+    }
+
+    @FeatureFlag(name = TRAFFICSTATS_CLIENT_RATE_LIMIT_CACHE_ENABLED_FLAG)
+    @Test
+    fun testInvalidStatsNotCached_cacheEnabled() {
+        doTestInvalidStatsNotCached()
+    }
+
+    @FeatureFlag(name = TRAFFICSTATS_CLIENT_RATE_LIMIT_CACHE_ENABLED_FLAG, enabled = false)
+    @Test
+    fun testInvalidStatsNotCached_cacheDisabled() {
+        doTestInvalidStatsNotCached()
+    }
+
+    private fun doTestInvalidStatsNotCached() {
+        // Mock null stats, this usually happens when the query is not valid,
+        // e.g. query uid stats of other application.
+        mockStats(null, null, null)
+        assertStats(unsupportedStatsResult, unsupportedStatsResult, unsupportedStatsResult)
+        assertStatsFetchInvocations(4)
+
+        // Verify null stats are not cached, then mock empty stats. This usually
+        // happens when queries with non-existent interface names.
+        clearInvocations(binder)
+        mockStats(emptyStatsResult, emptyStatsResult, emptyStatsResult)
+        assertStats(emptyStatsResult, emptyStatsResult, emptyStatsResult)
+        assertStatsFetchInvocations(4)
+
+        // Verify empty result is also not cached.
+        clearInvocations(binder)
+        assertStats(emptyStatsResult, emptyStatsResult, emptyStatsResult)
+        assertStatsFetchInvocations(4)
+    }
+
+    @FeatureFlag(name = TRAFFICSTATS_CLIENT_RATE_LIMIT_CACHE_ENABLED_FLAG)
+    @Test
+    fun testClearRateLimitCaches_cacheEnabled() {
+        doTestClearRateLimitCaches(true)
+    }
+
+    @FeatureFlag(name = TRAFFICSTATS_CLIENT_RATE_LIMIT_CACHE_ENABLED_FLAG, enabled = false)
+    @Test
+    fun testClearRateLimitCaches_cacheDisabled() {
+        doTestClearRateLimitCaches(false)
+    }
+
+    private fun doTestClearRateLimitCaches(cacheEnabled: Boolean) {
+        // Initial fetch, verify binder calls.
+        assertStats(mockMyUidStatsResult, mockIfaceStatsResult, mockTotalStatsResult)
+        assertStatsFetchInvocations(if (cacheEnabled) 1 else 4)
+
+        // Verify cached values are used.
+        clearInvocations(binder)
+        assertStats(mockMyUidStatsResult, mockIfaceStatsResult, mockTotalStatsResult)
+        assertStatsFetchInvocations(if (cacheEnabled) 0 else 4)
+
+        // Clear caches, verify fetching from the service.
+        clearInvocations(binder)
+        TrafficStats.clearRateLimitCaches()
+        mockStats(secondUidStatsResult, secondIfaceStatsResult, secondTotalStatsResult)
+        assertStats(secondUidStatsResult, secondIfaceStatsResult, secondTotalStatsResult)
+        assertStatsFetchInvocations(if (cacheEnabled) 1 else 4)
+    }
+}