Merge "Add getter to PerUidCounter"
diff --git a/staticlibs/framework/com/android/net/module/util/PerUidCounter.java b/staticlibs/framework/com/android/net/module/util/PerUidCounter.java
index 0b2de7a..463b0c4 100644
--- a/staticlibs/framework/com/android/net/module/util/PerUidCounter.java
+++ b/staticlibs/framework/com/android/net/module/util/PerUidCounter.java
@@ -32,7 +32,7 @@
 
     // Map from UID to count that UID has filed.
     @VisibleForTesting
-    @GuardedBy("mUidToCount")
+    @GuardedBy("this")
     final SparseIntArray mUidToCount = new SparseIntArray();
 
     /**
@@ -57,15 +57,8 @@
      *
      * @param uid the uid that the counter was made under
      */
-    public void incrementCountOrThrow(final int uid) {
-        incrementCountOrThrow(uid, 1 /* numToIncrement */);
-    }
-
-    public synchronized void incrementCountOrThrow(final int uid, final int numToIncrement) {
-        if (numToIncrement <= 0) {
-            throw new IllegalArgumentException("Increment count must be positive");
-        }
-        final long newCount = ((long) mUidToCount.get(uid, 0)) + numToIncrement;
+    public synchronized void incrementCountOrThrow(final int uid) {
+        final long newCount = ((long) mUidToCount.get(uid, 0)) + 1;
         if (newCount > mMaxCountPerUid) {
             throw new IllegalStateException("Uid " + uid + " exceeded its allowed limit");
         }
@@ -83,15 +76,8 @@
      *
      * @param uid the uid that the count was made under
      */
-    public void decrementCountOrThrow(final int uid) {
-        decrementCountOrThrow(uid, 1 /* numToDecrement */);
-    }
-
-    public synchronized void decrementCountOrThrow(final int uid, final int numToDecrement) {
-        if (numToDecrement <= 0) {
-            throw new IllegalArgumentException("Decrement count must be positive");
-        }
-        final int newCount = mUidToCount.get(uid, 0) - numToDecrement;
+    public synchronized void decrementCountOrThrow(final int uid) {
+        final int newCount = mUidToCount.get(uid, 0) - 1;
         if (newCount < 0) {
             throw new IllegalStateException("BUG: too small count " + newCount + " for UID " + uid);
         } else if (newCount == 0) {
@@ -100,4 +86,18 @@
             mUidToCount.put(uid, newCount);
         }
     }
+
+    /**
+     * Returns the count currently recorded for the given UID.
+     *
+     * <p>Synchronized on {@code this}, the same lock used by the increment and decrement
+     * methods, so the value read is consistent with their updates.
+     *
+     * @param uid the uid whose count to return
+     * @return the count for {@code uid}, or 0 if no count is recorded for it
+     */
+    @VisibleForTesting
+    public synchronized int get(final int uid) {
+        return mUidToCount.get(uid, 0);
+    }
 }
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/PerUidCounterTest.kt b/staticlibs/tests/unit/src/com/android/net/module/util/PerUidCounterTest.kt
index 0f2d52a..321fe59 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/PerUidCounterTest.kt
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/PerUidCounterTest.kt
@@ -20,6 +20,7 @@
 import androidx.test.runner.AndroidJUnit4
 import org.junit.Test
 import org.junit.runner.RunWith
+import kotlin.test.assertEquals
 import kotlin.test.assertFailsWith
 
 @RunWith(AndroidJUnit4::class)
@@ -27,6 +28,7 @@
 class PerUidCounterTest {
     private val UID_A = 1000
     private val UID_B = 1001
+    private val UID_C = 1002
 
     @Test
     fun testCounterMaximum() {
@@ -37,31 +39,35 @@
             PerUidCounter(0)
         }
 
-        val largeMaxCounter = PerUidCounter(Integer.MAX_VALUE)
-        largeMaxCounter.incrementCountOrThrow(UID_A, Integer.MAX_VALUE)
-        assertFailsWith<IllegalStateException> {
-            largeMaxCounter.incrementCountOrThrow(UID_A)
+        val testLimit = 1000
+        val testCounter = PerUidCounter(testLimit)
+        assertEquals(0, testCounter[UID_A])
+        repeat(testLimit) {
+            testCounter.incrementCountOrThrow(UID_A)
         }
+        assertEquals(testLimit, testCounter[UID_A])
+        assertFailsWith<IllegalStateException> {
+            testCounter.incrementCountOrThrow(UID_A)
+        }
+        assertEquals(testLimit, testCounter[UID_A])
     }
 
     @Test
     fun testIncrementCountOrThrow() {
         val counter = PerUidCounter(3)
 
-        // Verify the increment count cannot be zero.
-        assertFailsWith<IllegalArgumentException> {
-            counter.incrementCountOrThrow(UID_A, 0)
-        }
-
         // Verify the counters work independently.
         counter.incrementCountOrThrow(UID_A)
-        counter.incrementCountOrThrow(UID_B, 2)
+        counter.incrementCountOrThrow(UID_B)
         counter.incrementCountOrThrow(UID_B)
         counter.incrementCountOrThrow(UID_A)
         counter.incrementCountOrThrow(UID_A)
+        assertEquals(3, counter[UID_A])
+        assertEquals(2, counter[UID_B])
         assertFailsWith<IllegalStateException> {
             counter.incrementCountOrThrow(UID_A)
         }
+        counter.incrementCountOrThrow(UID_B)
         assertFailsWith<IllegalStateException> {
             counter.incrementCountOrThrow(UID_B)
         }
@@ -71,39 +77,71 @@
             counter.incrementCountOrThrow(UID_A)
         }
         assertFailsWith<IllegalStateException> {
-            counter.incrementCountOrThrow(UID_A, 3)
+            repeat(3) {
+                counter.incrementCountOrThrow(UID_A)
+            }
         }
+        assertEquals(3, counter[UID_A])
+        assertEquals(3, counter[UID_B])
+        assertEquals(0, counter[UID_C])
     }
 
     @Test
     fun testDecrementCountOrThrow() {
         val counter = PerUidCounter(3)
 
-        // Verify the decrement count cannot be zero.
-        assertFailsWith<IllegalArgumentException> {
-            counter.decrementCountOrThrow(UID_A, 0)
-        }
-
         // Verify the count cannot go below zero.
         assertFailsWith<IllegalStateException> {
             counter.decrementCountOrThrow(UID_A)
         }
         assertFailsWith<IllegalStateException> {
-            counter.decrementCountOrThrow(UID_A, 5)
-        }
-        assertFailsWith<IllegalStateException> {
-            counter.decrementCountOrThrow(UID_A, Integer.MAX_VALUE)
+            repeat(5) {
+                counter.decrementCountOrThrow(UID_A)
+            }
         }
+        assertEquals(0, counter[UID_A])
 
         // Verify the counters work independently.
         counter.incrementCountOrThrow(UID_A)
         counter.incrementCountOrThrow(UID_B)
+        assertEquals(1, counter[UID_A])
+        assertEquals(1, counter[UID_B])
         assertFailsWith<IllegalStateException> {
-            counter.decrementCountOrThrow(UID_A, 3)
+            // The first decrement succeeds (1 -> 0); the second one must throw.
+            repeat(3) {
+                counter.decrementCountOrThrow(UID_A)
+            }
         }
-        counter.decrementCountOrThrow(UID_A)
+        // UID_A is already at zero, so any further decrement must throw.
         assertFailsWith<IllegalStateException> {
             counter.decrementCountOrThrow(UID_A)
         }
+        assertEquals(0, counter[UID_A])
+        assertEquals(1, counter[UID_B])
+
+        // Verify mixing increment and decrement.
+        val largeCounter = PerUidCounter(100)
+        repeat(90) {
+            largeCounter.incrementCountOrThrow(UID_A)
+        }
+        assertEquals(90, largeCounter[UID_A])
+        repeat(70) {
+            largeCounter.decrementCountOrThrow(UID_A)
+        }
+        assertEquals(20, largeCounter[UID_A])
+        repeat(80) {
+            largeCounter.incrementCountOrThrow(UID_A)
+        }
+        assertFailsWith<IllegalStateException> {
+            largeCounter.incrementCountOrThrow(UID_A)
+        }
+        assertEquals(100, largeCounter[UID_A])
+        repeat(100) {
+            largeCounter.decrementCountOrThrow(UID_A)
+        }
+        assertFailsWith<IllegalStateException> {
+            largeCounter.decrementCountOrThrow(UID_A)
+        }
+        assertEquals(0, largeCounter[UID_A])
     }
 }
\ No newline at end of file