Merge "Check target sockets by UID range" into main
diff --git a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
index 8036ae9..94ba9de 100644
--- a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
@@ -25,9 +25,6 @@
import static android.system.OsConstants.SOL_SOCKET;
import static android.system.OsConstants.SO_SNDTIMEO;
-import static com.android.net.module.util.netlink.NetlinkConstants.NLMSG_DONE;
-import static com.android.net.module.util.netlink.NetlinkConstants.SOCKDIAG_MSG_HEADER_SIZE;
-import static com.android.net.module.util.netlink.NetlinkConstants.SOCK_DIAG_BY_FAMILY;
import static com.android.net.module.util.netlink.NetlinkUtils.IO_TIMEOUT_MS;
import android.annotation.IntDef;
@@ -90,6 +87,7 @@
*/
public class AutomaticOnOffKeepaliveTracker {
private static final String TAG = "AutomaticOnOffKeepaliveTracker";
+ private static final boolean DBG = Log.isLoggable(TAG, Log.DEBUG);
private static final int[] ADDRESS_FAMILIES = new int[] {AF_INET6, AF_INET};
private static final long LOW_TCP_POLLING_INTERVAL_MS = 1_000L;
private static final int ADJUST_TCP_POLLING_DELAY_MS = 2000;
@@ -794,22 +792,18 @@
try {
while (NetlinkUtils.enoughBytesRemainForValidNlMsg(bytes)) {
- final int startPos = bytes.position();
+ // NetlinkMessage.parse() will move the byte buffer position.
+ // TODO: Parse dst address information to filter socket.
+ final NetlinkMessage nlMsg = NetlinkMessage.parse(
+ bytes, OsConstants.NETLINK_INET_DIAG);
+ if (!(nlMsg instanceof InetDiagMessage)) {
+ if (DBG) Log.e(TAG, "Not a SOCK_DIAG_BY_FAMILY msg");
+ return false;
+ }
- final int nlmsgLen = bytes.getInt();
- final int nlmsgType = bytes.getShort();
- if (isEndOfMessageOrError(nlmsgType)) return false;
- // TODO: Parse InetDiagMessage to get uid and dst address information to filter
- // socket via NetlinkMessage.parse.
-
- // Skip the header to move to data part.
- bytes.position(startPos + SOCKDIAG_MSG_HEADER_SIZE);
-
- if (isTargetTcpSocket(bytes, nlmsgLen, networkMark, networkMask)) {
- if (Log.isLoggable(TAG, Log.DEBUG)) {
- bytes.position(startPos);
- final InetDiagMessage diagMsg = (InetDiagMessage) NetlinkMessage.parse(
- bytes, OsConstants.NETLINK_INET_DIAG);
+ final InetDiagMessage diagMsg = (InetDiagMessage) nlMsg;
+ if (isTargetTcpSocket(diagMsg, networkMark, networkMask, vpnUidRanges)) {
+ if (DBG) {
Log.d(TAG, String.format("Found open TCP connection by uid %d to %s"
+ " cookie %d",
diagMsg.inetDiagMsg.idiag_uid,
@@ -834,26 +828,31 @@
return false;
}
- private boolean isEndOfMessageOrError(int nlmsgType) {
- return nlmsgType == NLMSG_DONE || nlmsgType != SOCK_DIAG_BY_FAMILY;
+ /**
+ * Returns whether {@code uid} is contained in at least one of {@code ranges}.
+ *
+ * @param ranges the UID ranges to search; an empty set never matches.
+ * @param uid the UID to look for.
+ * @return true as soon as one range reports that it contains {@code uid}, false otherwise.
+ */
+ private static boolean containsUid(Set<Range<Integer>> ranges, int uid) {
+ for (final Range<Integer> range: ranges) {
+ if (range.contains(uid)) {
+ return true;
+ }
+ }
+ return false;
}
- private boolean isTargetTcpSocket(@NonNull ByteBuffer bytes, int nlmsgLen, int networkMark,
- int networkMask) {
- final int mark = readSocketDataAndReturnMark(bytes, nlmsgLen);
+ /**
+ * Decides whether the socket described by {@code diagMsg} is one of the sockets being
+ * looked for.
+ *
+ * A socket matches only when both conditions hold: the uid reported in
+ * {@code diagMsg.inetDiagMsg.idiag_uid} falls in one of {@code vpnUidRanges}, and the
+ * socket mark, once masked with {@code networkMask}, equals {@code networkMark}.
+ *
+ * @param diagMsg the already-parsed inet diag message for one socket.
+ * @param networkMark the expected value of the masked socket mark.
+ * @param networkMask the bits of the socket mark that take part in the comparison.
+ * @param vpnUidRanges the UID ranges a socket's uid must belong to.
+ * @return true if the socket passes both the uid check and the mark check.
+ */
+ private boolean isTargetTcpSocket(@NonNull InetDiagMessage diagMsg,
+ int networkMark, int networkMask, @NonNull Set<Range<Integer>> vpnUidRanges) {
+ // Reject sockets owned by other uids first, before looking at the mark attribute.
+ if (!containsUid(vpnUidRanges, diagMsg.inetDiagMsg.idiag_uid)) return false;
+
+ final int mark = readSocketDataAndReturnMark(diagMsg);
return (mark & networkMask) == networkMark;
}
- private int readSocketDataAndReturnMark(@NonNull ByteBuffer bytes, int nlmsgLen) {
- final int nextMsgOffset = bytes.position() + nlmsgLen - SOCKDIAG_MSG_HEADER_SIZE;
+ private int readSocketDataAndReturnMark(@NonNull InetDiagMessage diagMsg) {
int mark = NetlinkUtils.INIT_MARK_VALUE;
// Get socket mark
- // TODO: Add a parsing method in NetlinkMessage.parse to support this to skip the remaining
- // data.
- while (bytes.position() < nextMsgOffset) {
- final StructNlAttr nlattr = StructNlAttr.parse(bytes);
- if (nlattr != null && nlattr.nla_type == NetlinkUtils.INET_DIAG_MARK) {
- mark = nlattr.getValueAsInteger();
+ for (StructNlAttr attr : diagMsg.nlAttrs) {
+ if (attr.nla_type == NetlinkUtils.INET_DIAG_MARK) {
+ // The netlink attributes should contain only one INET_DIAG_MARK for each socket.
+ mark = attr.getValueAsInteger();
+ break;
}
}
return mark;
diff --git a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
index 015b75a..4fcf8a8 100644
--- a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
@@ -162,7 +162,7 @@
+ "00000000" // idiag_expires
+ "00000000" // idiag_rqueue
+ "00000000" // idiag_wqueue
- + "00000000" // idiag_uid
+ + "39300000" // idiag_uid = 12345
+ "00000000" // idiag_inode
// rtattr
+ "0500" // len = 5
@@ -427,6 +427,16 @@
}
@Test
+ public void testIsAnyTcpSocketConnected_noTargetUidSocket() throws Exception {
+ setupResponseWithSocketExisting();
+ // The uid configured for the socket (12345) is not inside the VPN uid range passed
+ // below (99999-99999), so the socket must not be reported as connected.
+ assertFalse(visibleOnHandlerThread(mTestHandler,
+ () -> mAOOKeepaliveTracker.isAnyTcpSocketConnected(
+ TEST_NETID,
+ new ArraySet<>(Arrays.asList(new Range<>(99999, 99999))))));
+ }
+
+ @Test
public void testIsAnyTcpSocketConnected_withIncorrectNetId() throws Exception {
setupResponseWithSocketExisting();
assertFalse(visibleOnHandlerThread(mTestHandler,