Use LruCache to limit the amount of entries
This change also implement a getOrCompute method, if the entry
is not found in the cache or has expired, computes it using the
provided supplier and stores the result in the cache.
Test: atest ConnectivityCoverageTests:android.net.connectivity.com.android.server.net.TrafficStatsRateLimitCacheTest
Bug: N/A
Change-Id: I60f07e84b7d2c754548b7cd431bd079b0d808c0a
diff --git a/service-t/src/com/android/server/net/TrafficStatsRateLimitCache.java b/service-t/src/com/android/server/net/TrafficStatsRateLimitCache.java
index 8598ac4..ca97d07 100644
--- a/service-t/src/com/android/server/net/TrafficStatsRateLimitCache.java
+++ b/service-t/src/com/android/server/net/TrafficStatsRateLimitCache.java
@@ -19,12 +19,13 @@
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.net.NetworkStats;
+import android.util.LruCache;
import com.android.internal.annotations.GuardedBy;
import java.time.Clock;
-import java.util.HashMap;
import java.util.Objects;
+import java.util.function.Supplier;
/**
* A thread-safe cache for storing and retrieving {@link NetworkStats.Entry} objects,
@@ -39,10 +40,12 @@
*
* @param clock The {@link Clock} to use for determining timestamps.
* @param expiryDurationMs The expiry duration in milliseconds.
+ * @param maxSize Maximum number of entries.
*/
- TrafficStatsRateLimitCache(@NonNull Clock clock, long expiryDurationMs) {
+ TrafficStatsRateLimitCache(@NonNull Clock clock, long expiryDurationMs, int maxSize) {
mClock = clock;
mExpiryDurationMs = expiryDurationMs;
+ mMap = new LruCache<>(maxSize);
}
private static class TrafficStatsCacheKey {
@@ -81,7 +84,7 @@
}
@GuardedBy("mMap")
- private final HashMap<TrafficStatsCacheKey, TrafficStatsCacheValue> mMap = new HashMap<>();
+ private final LruCache<TrafficStatsCacheKey, TrafficStatsCacheValue> mMap;
/**
* Retrieves a {@link NetworkStats.Entry} from the cache, associated with the given key.
@@ -105,6 +108,36 @@
}
/**
+ * Retrieves a {@link NetworkStats.Entry} from the cache, associated with the given key.
+ * If the entry is not found in the cache or has expired, computes it using the provided
+ * {@code supplier} and stores the result in the cache.
+ *
+ * @param iface The interface name to include in the cache key. {@code IFACE_ALL}
+ * if not applicable.
+ * @param uid The UID to include in the cache key. {@code UID_ALL} if not applicable.
+ * @param supplier The {@link Supplier} to compute the {@link NetworkStats.Entry} if not found.
+ * @return The cached or computed {@link NetworkStats.Entry}, or null if not found, expired,
+ * or if the {@code supplier} returns null.
+ */
+ @Nullable
+ NetworkStats.Entry getOrCompute(String iface, int uid,
+ @NonNull Supplier<NetworkStats.Entry> supplier) {
+ synchronized (mMap) {
+ final NetworkStats.Entry cachedValue = get(iface, uid);
+ if (cachedValue != null) {
+ return cachedValue;
+ }
+
+ // Entry not found or expired, compute it
+ final NetworkStats.Entry computedEntry = supplier.get();
+ if (computedEntry != null && !computedEntry.isEmpty()) {
+ put(iface, uid, computedEntry);
+ }
+ return computedEntry;
+ }
+ }
+
+ /**
* Stores a {@link NetworkStats.Entry} in the cache, associated with the given key.
*
* @param iface The interface name to include in the cache key. Null if not applicable.
@@ -124,7 +157,7 @@
*/
void clear() {
synchronized (mMap) {
- mMap.clear();
+ mMap.evictAll();
}
}
diff --git a/tests/unit/java/com/android/server/net/TrafficStatsRateLimitCacheTest.kt b/tests/unit/java/com/android/server/net/TrafficStatsRateLimitCacheTest.kt
index 27e6f96..99f762d 100644
--- a/tests/unit/java/com/android/server/net/TrafficStatsRateLimitCacheTest.kt
+++ b/tests/unit/java/com/android/server/net/TrafficStatsRateLimitCacheTest.kt
@@ -16,30 +16,35 @@
package com.android.server.net
-import android.net.NetworkStats
+import android.net.NetworkStats.Entry
import com.android.testutils.DevSdkIgnoreRunner
import java.time.Clock
+import java.util.function.Supplier
import kotlin.test.assertEquals
import kotlin.test.assertNull
+import kotlin.test.fail
import org.junit.Test
import org.junit.runner.RunWith
+import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock
+import org.mockito.Mockito.verify
import org.mockito.Mockito.`when`
@RunWith(DevSdkIgnoreRunner::class)
class TrafficStatsRateLimitCacheTest {
companion object {
private const val expiryDurationMs = 1000L
+ private const val maxSize = 2
}
private val clock = mock(Clock::class.java)
- private val entry = mock(NetworkStats.Entry::class.java)
- private val cache = TrafficStatsRateLimitCache(clock, expiryDurationMs)
+ private val entry = mock(Entry::class.java)
+ private val cache = TrafficStatsRateLimitCache(clock, expiryDurationMs, maxSize)
@Test
fun testGet_returnsEntryIfNotExpired() {
cache.put("iface", 2, entry)
- `when`(clock.millis()).thenReturn(500L) // Set clock to before expiry
+ doReturn(500L).`when`(clock).millis() // Set clock to before expiry
val result = cache.get("iface", 2)
assertEquals(entry, result)
}
@@ -47,7 +52,7 @@
@Test
fun testGet_returnsNullIfExpired() {
cache.put("iface", 2, entry)
- `when`(clock.millis()).thenReturn(2000L) // Set clock to after expiry
+ doReturn(2000L).`when`(clock).millis() // Set clock to after expiry
assertNull(cache.get("iface", 2))
}
@@ -59,8 +64,8 @@
@Test
fun testPutAndGet_retrievesCorrectEntryForDifferentKeys() {
- val entry1 = mock(NetworkStats.Entry::class.java)
- val entry2 = mock(NetworkStats.Entry::class.java)
+ val entry1 = mock(Entry::class.java)
+ val entry2 = mock(Entry::class.java)
cache.put("iface1", 2, entry1)
cache.put("iface2", 4, entry2)
@@ -71,8 +76,8 @@
@Test
fun testPut_overridesExistingEntry() {
- val entry1 = mock(NetworkStats.Entry::class.java)
- val entry2 = mock(NetworkStats.Entry::class.java)
+ val entry1 = mock(Entry::class.java)
+ val entry2 = mock(Entry::class.java)
cache.put("iface", 2, entry1)
cache.put("iface", 2, entry2) // Put with the same key
@@ -81,6 +86,62 @@
}
@Test
+ fun testPut_removeLru() {
+ // Assumes max size is 2. Verify eldest entry get removed.
+ val entry1 = mock(Entry::class.java)
+ val entry2 = mock(Entry::class.java)
+ val entry3 = mock(Entry::class.java)
+
+ cache.put("iface1", 2, entry1)
+ cache.put("iface2", 4, entry2)
+ cache.put("iface3", 8, entry3)
+
+ assertNull(cache.get("iface1", 2))
+ assertEquals(entry2, cache.get("iface2", 4))
+ assertEquals(entry3, cache.get("iface3", 8))
+ }
+
+ @Test
+ fun testGetOrCompute_cacheHit() {
+ val entry1 = mock(Entry::class.java)
+
+ cache.put("iface1", 2, entry1)
+
+ // Set clock to before expiry.
+ doReturn(500L).`when`(clock).millis()
+
+ // Now call getOrCompute
+ val result = cache.getOrCompute("iface1", 2) {
+ fail("Supplier should not be called")
+ }
+
+ // Assertions
+ assertEquals(entry1, result) // Should get the cached entry.
+ }
+
+ @Suppress("UNCHECKED_CAST")
+ @Test
+ fun testGetOrCompute_cacheMiss() {
+ val entry1 = mock(Entry::class.java)
+
+ cache.put("iface1", 2, entry1)
+
+ // Set clock to after expiry.
+ doReturn(1500L).`when`(clock).millis()
+
+ // Mock the supplier to return our network stats entry.
+ val supplier = mock(Supplier::class.java) as Supplier<Entry>
+ doReturn(entry1).`when`(supplier).get()
+
+ // Now call getOrCompute.
+ val result = cache.getOrCompute("iface1", 2, supplier)
+
+ // Assertions.
+ assertEquals(entry1, result) // Should get the cached entry.
+ verify(supplier).get()
+ }
+
+ @Test
fun testClear() {
cache.put("iface", 2, entry)
cache.clear()