Merge "Check target sockets by UID range" into main
diff --git a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
index 8036ae9..94ba9de 100644
--- a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
@@ -25,9 +25,6 @@
 import static android.system.OsConstants.SOL_SOCKET;
 import static android.system.OsConstants.SO_SNDTIMEO;
 
-import static com.android.net.module.util.netlink.NetlinkConstants.NLMSG_DONE;
-import static com.android.net.module.util.netlink.NetlinkConstants.SOCKDIAG_MSG_HEADER_SIZE;
-import static com.android.net.module.util.netlink.NetlinkConstants.SOCK_DIAG_BY_FAMILY;
 import static com.android.net.module.util.netlink.NetlinkUtils.IO_TIMEOUT_MS;
 
 import android.annotation.IntDef;
@@ -90,6 +87,7 @@
  */
 public class AutomaticOnOffKeepaliveTracker {
     private static final String TAG = "AutomaticOnOffKeepaliveTracker";
+    private static final boolean DBG = Log.isLoggable(TAG, Log.DEBUG);
     private static final int[] ADDRESS_FAMILIES = new int[] {AF_INET6, AF_INET};
     private static final long LOW_TCP_POLLING_INTERVAL_MS = 1_000L;
     private static final int ADJUST_TCP_POLLING_DELAY_MS = 2000;
@@ -794,22 +792,18 @@
 
             try {
                 while (NetlinkUtils.enoughBytesRemainForValidNlMsg(bytes)) {
-                    final int startPos = bytes.position();
+                    // NetlinkMessage.parse() advances the byte buffer position.
+                    // TODO: Also parse the dst address information to filter sockets.
+                    final NetlinkMessage nlMsg = NetlinkMessage.parse(
+                            bytes, OsConstants.NETLINK_INET_DIAG);
+                    if (!(nlMsg instanceof InetDiagMessage)) {
+                        if (DBG) Log.e(TAG, "Not a SOCK_DIAG_BY_FAMILY msg");
+                        return false;
+                    }
 
-                    final int nlmsgLen = bytes.getInt();
-                    final int nlmsgType = bytes.getShort();
-                    if (isEndOfMessageOrError(nlmsgType)) return false;
-                    // TODO: Parse InetDiagMessage to get uid and dst address information to filter
-                    //  socket via NetlinkMessage.parse.
-
-                    // Skip the header to move to data part.
-                    bytes.position(startPos + SOCKDIAG_MSG_HEADER_SIZE);
-
-                    if (isTargetTcpSocket(bytes, nlmsgLen, networkMark, networkMask)) {
-                        if (Log.isLoggable(TAG, Log.DEBUG)) {
-                            bytes.position(startPos);
-                            final InetDiagMessage diagMsg = (InetDiagMessage) NetlinkMessage.parse(
-                                    bytes, OsConstants.NETLINK_INET_DIAG);
+                    final InetDiagMessage diagMsg = (InetDiagMessage) nlMsg;
+                    if (isTargetTcpSocket(diagMsg, networkMark, networkMask, vpnUidRanges)) {
+                        if (DBG) {
                             Log.d(TAG, String.format("Found open TCP connection by uid %d to %s"
                                             + " cookie %d",
                                     diagMsg.inetDiagMsg.idiag_uid,
@@ -834,26 +828,31 @@
         return false;
     }
 
-    private boolean isEndOfMessageOrError(int nlmsgType) {
-        return nlmsgType == NLMSG_DONE || nlmsgType != SOCK_DIAG_BY_FAMILY;
+    private static boolean containsUid(Set<Range<Integer>> ranges, int uid) {
+        for (final Range<Integer> range: ranges) {
+            if (range.contains(uid)) {
+                return true;
+            }
+        }
+        return false;
     }
 
-    private boolean isTargetTcpSocket(@NonNull ByteBuffer bytes, int nlmsgLen, int networkMark,
-            int networkMask) {
-        final int mark = readSocketDataAndReturnMark(bytes, nlmsgLen);
+    private boolean isTargetTcpSocket(@NonNull InetDiagMessage diagMsg,
+            int networkMark, int networkMask, @NonNull Set<Range<Integer>> vpnUidRanges) {
+        if (!containsUid(vpnUidRanges, diagMsg.inetDiagMsg.idiag_uid)) return false;
+
+        final int mark = readSocketDataAndReturnMark(diagMsg);
         return (mark & networkMask) == networkMark;
     }
 
-    private int readSocketDataAndReturnMark(@NonNull ByteBuffer bytes, int nlmsgLen) {
-        final int nextMsgOffset = bytes.position() + nlmsgLen - SOCKDIAG_MSG_HEADER_SIZE;
+    private int readSocketDataAndReturnMark(@NonNull InetDiagMessage diagMsg) {
         int mark = NetlinkUtils.INIT_MARK_VALUE;
         // Get socket mark
-        // TODO: Add a parsing method in NetlinkMessage.parse to support this to skip the remaining
-        //  data.
-        while (bytes.position() < nextMsgOffset) {
-            final StructNlAttr nlattr = StructNlAttr.parse(bytes);
-            if (nlattr != null && nlattr.nla_type == NetlinkUtils.INET_DIAG_MARK) {
-                mark = nlattr.getValueAsInteger();
+        for (StructNlAttr attr : diagMsg.nlAttrs) {
+            if (attr.nla_type == NetlinkUtils.INET_DIAG_MARK) {
+                // A socket should have only one INET_DIAG_MARK attribute, so stop at the first.
+                mark = attr.getValueAsInteger();
+                break;
             }
         }
         return mark;
diff --git a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
index 015b75a..4fcf8a8 100644
--- a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
@@ -162,7 +162,7 @@
             + "00000000"            // idiag_expires
             + "00000000"            // idiag_rqueue
             + "00000000"            // idiag_wqueue
-            + "00000000"            // idiag_uid
+            + "39300000"            // idiag_uid = 12345
             + "00000000"            // idiag_inode
             // rtattr
             + "0500"            // len = 5
@@ -427,6 +427,16 @@
     }
 
     @Test
+    public void testIsAnyTcpSocketConnected_noTargetUidSocket() throws Exception {
+        setupResponseWithSocketExisting();
+        // The socket's uid (12345) is outside the requested VPN uid range.
+        assertFalse(visibleOnHandlerThread(mTestHandler,
+                () -> mAOOKeepaliveTracker.isAnyTcpSocketConnected(
+                        TEST_NETID,
+                        new ArraySet<>(Arrays.asList(new Range<>(99999, 99999))))));
+    }
+
+    @Test
     public void testIsAnyTcpSocketConnected_withIncorrectNetId() throws Exception {
         setupResponseWithSocketExisting();
         assertFalse(visibleOnHandlerThread(mTestHandler,