diff --git a/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java b/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java
index 6a7da9e..66dcf6d 100644
--- a/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java
+++ b/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java
@@ -42,7 +42,6 @@
 import android.net.NetworkStats;
 import android.net.NetworkStatsAccess;
 import android.net.NetworkTemplate;
-import android.net.netstats.IUsageCallback;
 import android.os.HandlerThread;
 import android.os.IBinder;
 import android.os.Looper;
@@ -101,7 +100,7 @@
     private ArrayMap<String, NetworkIdentitySet> mActiveUidIfaces;
 
     @Mock private IBinder mUsageCallbackBinder;
-    @Mock private IUsageCallback mUsageCallback;
+    private TestableUsageCallback mUsageCallback;
 
     @Before
     public void setUp() throws Exception {
@@ -119,20 +118,27 @@
 
         mActiveIfaces = new ArrayMap<>();
         mActiveUidIfaces = new ArrayMap<>();
-        Mockito.when(mUsageCallback.asBinder()).thenReturn(mUsageCallbackBinder);
+        mUsageCallback = new TestableUsageCallback(mUsageCallbackBinder);
     }
 
     @Test
     public void testRegister_thresholdTooLow_setsDefaultThreshold() throws Exception {
-        long thresholdTooLowBytes = 1L;
-        DataUsageRequest inputRequest = new DataUsageRequest(
+        final long thresholdTooLowBytes = 1L;
+        final DataUsageRequest inputRequest = new DataUsageRequest(
                 DataUsageRequest.REQUEST_ID_UNSET, sTemplateWifi, thresholdTooLowBytes);
 
-        DataUsageRequest request = mStatsObservers.register(inputRequest, mUsageCallback,
-                Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
-        assertTrue(request.requestId > 0);
-        assertTrue(Objects.equals(sTemplateWifi, request.template));
-        assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
+        final DataUsageRequest requestByApp = mStatsObservers.register(inputRequest, mUsageCallback,
+                UID_RED, NetworkStatsAccess.Level.DEVICE);
+        assertTrue(requestByApp.requestId > 0);
+        assertTrue(Objects.equals(sTemplateWifi, requestByApp.template));
+        assertEquals(THRESHOLD_BYTES, requestByApp.thresholdInBytes);
+
+        // Verify the threshold requested by system uid won't be overridden.
+        final DataUsageRequest requestBySystem = mStatsObservers.register(inputRequest,
+                mUsageCallback, Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
+        assertTrue(requestBySystem.requestId > 0);
+        assertTrue(Objects.equals(sTemplateWifi, requestBySystem.template));
+        assertEquals(1, requestBySystem.thresholdInBytes);
     }
 
     @Test
@@ -304,7 +310,7 @@
         mStatsObservers.updateStats(
                 xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces, TEST_START);
         waitForObserverToIdle();
-        Mockito.verify(mUsageCallback).onThresholdReached(any());
+        mUsageCallback.expectOnThresholdReached(request);
     }
 
     @Test
@@ -337,7 +343,7 @@
         mStatsObservers.updateStats(
                 xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces, TEST_START);
         waitForObserverToIdle();
-        Mockito.verify(mUsageCallback).onThresholdReached(any());
+        mUsageCallback.expectOnThresholdReached(request);
     }
 
     @Test
@@ -402,7 +408,7 @@
         mStatsObservers.updateStats(
                 xtSnapshot, uidSnapshot, mActiveIfaces, mActiveUidIfaces, TEST_START);
         waitForObserverToIdle();
-        Mockito.verify(mUsageCallback).onThresholdReached(any());
+        mUsageCallback.expectOnThresholdReached(request);
     }
 
     @Test
diff --git a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
index e3b3621..13dc3cb 100644
--- a/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
+++ b/tests/unit/java/com/android/server/net/NetworkStatsServiceTest.java
@@ -1285,7 +1285,7 @@
 
 
         // Wait for the caller to invoke expectOnThresholdReached.
-        mUsageCallback.expectOnThresholdReached();
+        mUsageCallback.expectOnThresholdReached(request);
 
         // Allow binder to disconnect
         when(mUsageCallbackBinder.unlinkToDeath(any(IBinder.DeathRecipient.class), anyInt()))
@@ -1295,7 +1295,7 @@
         mService.unregisterUsageRequest(request);
 
         // Wait for the caller to invoke expectOnCallbackReleased.
-        mUsageCallback.expectOnCallbackReleased();
+        mUsageCallback.expectOnCallbackReleased(request);
 
         // Make sure that the caller binder gets disconnected
         verify(mUsageCallbackBinder).unlinkToDeath(any(IBinder.DeathRecipient.class), anyInt());
diff --git a/tests/unit/java/com/android/server/net/TestableUsageCallback.kt b/tests/unit/java/com/android/server/net/TestableUsageCallback.kt
index 44f588c..1917ec3 100644
--- a/tests/unit/java/com/android/server/net/TestableUsageCallback.kt
+++ b/tests/unit/java/com/android/server/net/TestableUsageCallback.kt
@@ -21,37 +21,34 @@
 import android.os.IBinder
 import java.util.concurrent.LinkedBlockingQueue
 import java.util.concurrent.TimeUnit
-import kotlin.test.assertEquals
 import kotlin.test.fail
 
 private const val DEFAULT_TIMEOUT_MS = 200L
 
 // TODO: Move the class to static libs once all downstream have IUsageCallback definition.
-open class TestableUsageCallback(private val binder: IBinder) : IUsageCallback.Stub() {
-    sealed class CallbackType {
-        object OnThresholdReached : CallbackType()
-        object OnCallbackReleased : CallbackType()
+class TestableUsageCallback(private val binder: IBinder) : IUsageCallback.Stub() {
+    sealed class CallbackType(val request: DataUsageRequest) {
+        class OnThresholdReached(request: DataUsageRequest) : CallbackType(request)
+        class OnCallbackReleased(request: DataUsageRequest) : CallbackType(request)
     }
 
     // TODO: Change to use ArrayTrackRecord once moved into to the module.
     private val history = LinkedBlockingQueue<CallbackType>()
 
     override fun onThresholdReached(request: DataUsageRequest) {
-        history.add(CallbackType.OnThresholdReached)
+        history.add(CallbackType.OnThresholdReached(request))
     }
 
     override fun onCallbackReleased(request: DataUsageRequest) {
-        history.add(CallbackType.OnCallbackReleased)
+        history.add(CallbackType.OnCallbackReleased(request))
     }
 
-    fun expectOnThresholdReached() {
-        assertEquals(CallbackType.OnThresholdReached,
-                history.poll(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS))
+    fun expectOnThresholdReached(request: DataUsageRequest) {
+        expectCallback<CallbackType.OnThresholdReached>(request, DEFAULT_TIMEOUT_MS)
     }
 
-    fun expectOnCallbackReleased() {
-        assertEquals(CallbackType.OnCallbackReleased,
-                history.poll(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS))
+    fun expectOnCallbackReleased(request: DataUsageRequest) {
+        expectCallback<CallbackType.OnCallbackReleased>(request, DEFAULT_TIMEOUT_MS)
     }
 
     @JvmOverloads
@@ -60,6 +57,22 @@
         cb?.let { fail("Expected no callback but got $cb") }
     }
 
+    // Expects a callback of the specified request on the specified network within the timeout.
+    // If no callback arrives, or a different callback arrives, fail.
+    private inline fun <reified T : CallbackType> expectCallback(
+        expectedRequest: DataUsageRequest,
+        timeoutMs: Long
+    ) {
+        history.poll(timeoutMs, TimeUnit.MILLISECONDS).let {
+            if (it !is T || it.request != expectedRequest) {
+                fail("Unexpected callback : $it," +
+                        " expected ${T::class} with Request[$expectedRequest]")
+            } else {
+                it
+            }
+        }
+    }
+
     override fun asBinder(): IBinder {
         return binder
     }
