Fix system caller cannot query stats for other uids

Follow-up from aosp/3249945. If the system caller queries stats
for other UIDs, the caller would get UNSUPPORTED.
This follows the official documentation, where the API would
not return stats for all other UIDs.

However, partners or other legacy callers might already depend
on this undocumented behavior. Changing this behavior might
damage existing apps.

Thus, this change reverts the behavior change introduced at
aosp/3249945 and allows a system caller to query stats for
other UIDs.

Test: atest ConnectivityCoverageTests:android.net.connectivity.android.net.TrafficStatsTest
Bug: 372851500
Change-Id: Ia7ea051f7b0b6349ffcb890c300362bbfbdee3dc
diff --git a/framework-t/src/android/net/TrafficStats.java b/framework-t/src/android/net/TrafficStats.java
index ab0aaa0..1294b3e 100644
--- a/framework-t/src/android/net/TrafficStats.java
+++ b/framework-t/src/android/net/TrafficStats.java
@@ -17,6 +17,10 @@
 package android.net;
 
 import static android.annotation.SystemApi.Client.MODULE_LIBRARIES;
+import static android.net.NetworkStats.UID_ALL;
+import static android.os.Process.SYSTEM_UID;
+
+import static com.android.internal.annotations.VisibleForTesting.Visibility.PRIVATE;
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
@@ -37,6 +41,9 @@
 import android.os.StrictMode;
 import android.util.Log;
 
+import com.android.internal.annotations.GuardedBy;
+import com.android.internal.annotations.VisibleForTesting;
+
 import java.io.FileDescriptor;
 import java.io.IOException;
 import java.net.DatagramSocket;
@@ -177,10 +184,16 @@
     /** @hide */
     public static final int TAG_SYSTEM_PROBE = 0xFFFFFF42;
 
+    @GuardedBy("TrafficStats.class")
     private static INetworkStatsService sStatsService;
+    @GuardedBy("TrafficStats.class")
+    private static INetworkStatsService sStatsServiceForTest = null;
+    @GuardedBy("TrafficStats.class")
+    private static int sMyUidForTest = UID_ALL;
 
     @UnsupportedAppUsage(maxTargetSdk = Build.VERSION_CODES.P, trackingBug = 130143562)
     private synchronized static INetworkStatsService getStatsService() {
+        if (sStatsServiceForTest != null) return sStatsServiceForTest;
         if (sStatsService == null) {
             throw new IllegalStateException("TrafficStats not initialized, uid="
                     + Binder.getCallingUid());
@@ -188,6 +201,40 @@
         return sStatsService;
     }
 
+    /** @hide */
+    protected static int getMyUid() {
+        synchronized (TrafficStats.class) {
+            if (sMyUidForTest != UID_ALL) {
+                return sMyUidForTest;
+            }
+        }
+        return android.os.Process.myUid();
+    }
+
+    /**
+     * Set the network stats service for testing, or null to reset.
+     *
+     * @hide
+     */
+    @VisibleForTesting(visibility = PRIVATE)
+    public static void setServiceForTest(INetworkStatsService statsService) {
+        synchronized (TrafficStats.class) {
+            sStatsServiceForTest = statsService;
+        }
+    }
+
+    /**
+     * Set myUid for test, or UID_ALL to reset.
+     *
+     * @hide
+     */
+    @VisibleForTesting(visibility = PRIVATE)
+    public static void setMyUidForTest(int myUid) {
+        synchronized (TrafficStats.class) {
+            sMyUidForTest = myUid;
+        }
+    }
+
     /**
      * Snapshot of {@link NetworkStats} when the currently active profiling
      * session started, or {@code null} if no session active.
@@ -450,7 +497,7 @@
      */
     @Deprecated
     public static void setThreadStatsUidSelf() {
-        setThreadStatsUid(android.os.Process.myUid());
+        setThreadStatsUid(getMyUid());
     }
 
     /**
@@ -591,7 +638,7 @@
      * @param operationCount Number of operations to increment count by.
      */
     public static void incrementOperationCount(int tag, int operationCount) {
-        final int uid = android.os.Process.myUid();
+        final int uid = getMyUid();
         try {
             getStatsService().incrementOperationCount(uid, tag, operationCount);
         } catch (RemoteException e) {
@@ -959,8 +1006,11 @@
 
     /** @hide */
     public static long getUidStats(int uid, int type) {
-        if (!isEntryValueTypeValid(type)
-                || android.os.Process.myUid() != uid) {
+        // Perform a quick check on the UID to avoid unnecessary work.
+        // This mirrors a similar check on the service side, but is primarily for
+        // efficiency rather than security, as user-space checks can be bypassed.
+        final int myUid = getMyUid();
+        if (!isEntryValueTypeValid(type) || (myUid != SYSTEM_UID && myUid != uid)) {
             return UNSUPPORTED;
         }
         final StatsResult stats;
@@ -1094,7 +1144,7 @@
      */
     private static NetworkStats getDataLayerSnapshotForUid(Context context) {
         // TODO: take snapshot locally, since proc file is now visible
-        final int uid = android.os.Process.myUid();
+        final int uid = getMyUid();
         try {
             return getStatsService().getDataLayerSnapshotForUid(uid);
         } catch (RemoteException e) {
diff --git a/tests/unit/java/android/net/TrafficStatsTest.kt b/tests/unit/java/android/net/TrafficStatsTest.kt
new file mode 100644
index 0000000..6e8f3db
--- /dev/null
+++ b/tests/unit/java/android/net/TrafficStatsTest.kt
@@ -0,0 +1,92 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net
+
+import android.net.NetworkStats.UID_ALL
+import android.net.TrafficStats.UNSUPPORTED
+import android.net.netstats.StatsResult
+import android.os.Build
+import android.os.Process.SYSTEM_UID
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
+import org.junit.After
+import org.junit.Assert.assertEquals
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.ArgumentMatchers.anyInt
+import org.mockito.Mockito.doReturn
+import org.mockito.Mockito.mock
+import org.mockito.Mockito.never
+import org.mockito.Mockito.verify
+
+@RunWith(DevSdkIgnoreRunner::class)
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
+class TrafficStatsTest {
+    private val binder = mock(INetworkStatsService::class.java)
+    private val myUid = android.os.Process.myUid()
+    private val notMyUid = myUid + 1
+    private val mockSystemUidStatsResult = StatsResult(1L, 2L, 3L, 4L)
+    private val mockMyUidStatsResult = StatsResult(5L, 6L, 7L, 8L)
+    private val mockNotMyUidStatsResult = StatsResult(9L, 10L, 11L, 12L)
+    private val unsupportedStatsResult =
+            StatsResult(UNSUPPORTED.toLong(), UNSUPPORTED.toLong(),
+                    UNSUPPORTED.toLong(), UNSUPPORTED.toLong())
+
+    @Before
+    fun setUp() {
+        TrafficStats.setServiceForTest(binder)
+        doReturn(mockSystemUidStatsResult).`when`(binder).getUidStats(SYSTEM_UID)
+        doReturn(mockMyUidStatsResult).`when`(binder).getUidStats(myUid)
+        doReturn(mockNotMyUidStatsResult).`when`(binder).getUidStats(notMyUid)
+    }
+
+    @After
+    fun tearDown() {
+        TrafficStats.setServiceForTest(null)
+        TrafficStats.setMyUidForTest(UID_ALL)
+    }
+
+    private fun assertUidStats(uid: Int, stats: StatsResult) {
+        assertEquals(stats.rxBytes, TrafficStats.getUidRxBytes(uid))
+        assertEquals(stats.rxPackets, TrafficStats.getUidRxPackets(uid))
+        assertEquals(stats.txBytes, TrafficStats.getUidTxBytes(uid))
+        assertEquals(stats.txPackets, TrafficStats.getUidTxPackets(uid))
+    }
+
+    // Verify a normal caller could get a quick UNSUPPORTED result in the TrafficStats
+    // without accessing the service if query stats other than itself.
+    @Test
+    fun testGetUidStats_appCaller() {
+        assertUidStats(SYSTEM_UID, unsupportedStatsResult)
+        assertUidStats(notMyUid, unsupportedStatsResult)
+        verify(binder, never()).getUidStats(anyInt())
+        assertUidStats(myUid, mockMyUidStatsResult)
+    }
+
+    // Verify that callers with SYSTEM_UID can access network
+    // stats for other UIDs. While this behavior is not officially documented
+    // in the API, it exists for compatibility with existing callers that may
+    // rely on it.
+    @Test
+    fun testGetUidStats_systemCaller() {
+        TrafficStats.setMyUidForTest(SYSTEM_UID)
+        assertUidStats(SYSTEM_UID, mockSystemUidStatsResult)
+        assertUidStats(myUid, mockMyUidStatsResult)
+        assertUidStats(notMyUid, mockNotMyUidStatsResult)
+    }
+}