Add support for processing netlink dump messages
Test: atest NetworkStaticLibTests:com.android.net.moduletests.util.netlink.NetlinkUtilsTest
Change-Id: I15e208dca5ed6a723585f79cf09276017bc8a885
diff --git a/staticlibs/device/com/android/net/module/util/netlink/NetlinkUtils.java b/staticlibs/device/com/android/net/module/util/netlink/NetlinkUtils.java
index f1f30d3..f6282fd 100644
--- a/staticlibs/device/com/android/net/module/util/netlink/NetlinkUtils.java
+++ b/staticlibs/device/com/android/net/module/util/netlink/NetlinkUtils.java
@@ -29,7 +29,12 @@
import static android.system.OsConstants.SO_RCVBUF;
import static android.system.OsConstants.SO_RCVTIMEO;
import static android.system.OsConstants.SO_SNDTIMEO;
+import static com.android.net.module.util.netlink.NetlinkConstants.hexify;
+import static com.android.net.module.util.netlink.NetlinkConstants.NLMSG_DONE;
+import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_DUMP;
+import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_REQUEST;
+import android.net.ParseException;
import android.net.util.SocketUtils;
import android.system.ErrnoException;
import android.system.Os;
@@ -47,7 +52,11 @@
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Objects;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
/**
* Utilities for netlink related class that may not be able to fit into a specific class.
@@ -163,11 +172,7 @@
Log.e(TAG, errPrefix, e);
throw new ErrnoException(errPrefix, EIO, e);
} finally {
- try {
- SocketUtils.closeSocket(fd);
- } catch (IOException e) {
- // Nothing we can do here
- }
+ closeSocketQuietly(fd);
}
}
@@ -308,4 +313,85 @@
}
private NetlinkUtils() {}
+
+ /**
+ * Sends a netlink dump request and processes the returned dump messages
+ *
+ * @param <T> extends NetlinkMessage
+ * @param dumpRequestMessage netlink dump request message to be sent
+ * @param nlFamily netlink family
+ * @param msgClass expected class of the netlink message
+ * @param func function defined by caller to handle the dump messages
+ * @throws SocketException when fails to create socket
+ * @throws InterruptedIOException when fails to read the dumpFd
+ * @throws ErrnoException when fails to send dump request
+ * @throws ParseException when message can't be parsed
+ */
+ public static <T extends NetlinkMessage> void getAndProcessNetlinkDumpMessages(
+ byte[] dumpRequestMessage, int nlFamily, Class<T> msgClass,
+ Consumer<T> func)
+ throws SocketException, InterruptedIOException, ErrnoException, ParseException {
+ // Create socket and send dump request
+ final FileDescriptor fd;
+ try {
+ fd = netlinkSocketForProto(nlFamily);
+ } catch (ErrnoException e) {
+ Log.e(TAG, "Failed to create netlink socket " + e);
+ throw e.rethrowAsSocketException();
+ }
+
+ try {
+ connectToKernel(fd);
+ } catch (ErrnoException | SocketException e) {
+ Log.e(TAG, "Failed to connect netlink socket to kernel " + e);
+ closeSocketQuietly(fd);
+ return;
+ }
+
+ try {
+ sendMessage(fd, dumpRequestMessage, 0, dumpRequestMessage.length, IO_TIMEOUT_MS);
+ } catch (InterruptedIOException | ErrnoException e) {
+ Log.e(TAG, "Failed to send dump request " + e);
+ closeSocketQuietly(fd);
+ throw e;
+ }
+
+ while (true) {
+ final ByteBuffer buf = recvMessage(
+ fd, NetlinkUtils.DEFAULT_RECV_BUFSIZE, IO_TIMEOUT_MS);
+
+ while (buf.remaining() > 0) {
+ final int position = buf.position();
+ final NetlinkMessage nlMsg = NetlinkMessage.parse(buf, nlFamily);
+ if (nlMsg == null) {
+ // Move to the position where parse started for error log.
+ buf.position(position);
+ Log.e(TAG, "Failed to parse netlink message: " + hexify(buf));
+ closeSocketQuietly(fd);
+ throw new ParseException("Failed to parse netlink message");
+ }
+
+ if (nlMsg.getHeader().nlmsg_type == NLMSG_DONE) {
+ closeSocketQuietly(fd);
+ return;
+ }
+
+ if (!msgClass.isInstance(nlMsg)) {
+ Log.e(TAG, "Received unexpected netlink message: " + nlMsg);
+ continue;
+ }
+
+ final T msg = (T) nlMsg;
+ func.accept(msg);
+ }
+ }
+ }
+
+ private static void closeSocketQuietly(final FileDescriptor fd) {
+ try {
+ SocketUtils.closeSocket(fd);
+ } catch (IOException e) {
+ // Nothing we can do here
+ }
+ }
}
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/netlink/NetlinkUtilsTest.java b/staticlibs/tests/unit/src/com/android/net/module/util/netlink/NetlinkUtilsTest.java
index 5a231fc..17d4e81 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/netlink/NetlinkUtilsTest.java
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/netlink/NetlinkUtilsTest.java
@@ -55,6 +55,9 @@
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Consumer;
@RunWith(AndroidJUnit4.class)
@SmallTest
@@ -65,19 +68,14 @@
@Test
public void testGetNeighborsQuery() throws Exception {
- final FileDescriptor fd = NetlinkUtils.netlinkSocketForProto(NETLINK_ROUTE);
- assertNotNull(fd);
-
- NetlinkUtils.connectToKernel(fd);
-
- final NetlinkSocketAddress localAddr = (NetlinkSocketAddress) Os.getsockname(fd);
- assertNotNull(localAddr);
- assertEquals(0, localAddr.getGroupsMask());
- assertTrue(0 != localAddr.getPortId());
-
final byte[] req = RtNetlinkNeighborMessage.newGetNeighborsRequest(TEST_SEQNO);
assertNotNull(req);
+ List<RtNetlinkNeighborMessage> msgs = new ArrayList<>();
+ Consumer<RtNetlinkNeighborMessage> handleNlDumpMsg = (msg) -> {
+ msgs.add(msg);
+ };
+
final Context ctx = InstrumentationRegistry.getInstrumentation().getContext();
final int targetSdk =
ctx.getPackageManager()
@@ -94,7 +92,8 @@
assumeFalse("network_stack context is expected to have permission to send RTM_GETNEIGH",
ctxt.startsWith("u:r:network_stack:s0"));
try {
- NetlinkUtils.sendMessage(fd, req, 0, req.length, TEST_TIMEOUT_MS);
+ NetlinkUtils.<RtNetlinkNeighborMessage>getAndProcessNetlinkDumpMessages(req,
+ NETLINK_ROUTE, RtNetlinkNeighborMessage.class, handleNlDumpMsg);
fail("RTM_GETNEIGH is not allowed for apps targeting SDK > 31 on T+ platforms,"
+ " target SDK version: " + targetSdk);
} catch (ErrnoException e) {
@@ -105,106 +104,70 @@
}
// Check that apps targeting lower API levels / running on older platforms succeed
- assertEquals(req.length,
- NetlinkUtils.sendMessage(fd, req, 0, req.length, TEST_TIMEOUT_MS));
+ NetlinkUtils.<RtNetlinkNeighborMessage>getAndProcessNetlinkDumpMessages(req,
+ NETLINK_ROUTE, RtNetlinkNeighborMessage.class, handleNlDumpMsg);
- int neighMessageCount = 0;
- int doneMessageCount = 0;
-
- while (doneMessageCount == 0) {
- ByteBuffer response =
- NetlinkUtils.recvMessage(fd, DEFAULT_RECV_BUFSIZE, TEST_TIMEOUT_MS);
- assertNotNull(response);
- assertTrue(StructNlMsgHdr.STRUCT_SIZE <= response.limit());
- assertEquals(0, response.position());
- assertEquals(ByteOrder.nativeOrder(), response.order());
-
- // Verify the messages at least appears minimally reasonable.
- while (response.remaining() > 0) {
- final NetlinkMessage msg = NetlinkMessage.parse(response, NETLINK_ROUTE);
- assertNotNull(msg);
- final StructNlMsgHdr hdr = msg.getHeader();
- assertNotNull(hdr);
-
- if (hdr.nlmsg_type == NetlinkConstants.NLMSG_DONE) {
- doneMessageCount++;
- continue;
- }
-
- assertEquals(NetlinkConstants.RTM_NEWNEIGH, hdr.nlmsg_type);
- assertTrue(msg instanceof RtNetlinkNeighborMessage);
- assertTrue((hdr.nlmsg_flags & StructNlMsgHdr.NLM_F_MULTI) != 0);
- assertEquals(TEST_SEQNO, hdr.nlmsg_seq);
- assertEquals(localAddr.getPortId(), hdr.nlmsg_pid);
-
- neighMessageCount++;
- }
+ for (var msg : msgs) {
+ assertNotNull(msg);
+ final StructNlMsgHdr hdr = msg.getHeader();
+ assertNotNull(hdr);
+ assertEquals(NetlinkConstants.RTM_NEWNEIGH, hdr.nlmsg_type);
+ assertTrue((hdr.nlmsg_flags & StructNlMsgHdr.NLM_F_MULTI) != 0);
+ assertEquals(TEST_SEQNO, hdr.nlmsg_seq);
}
- assertEquals(1, doneMessageCount);
// TODO: make sure this test passes sanely in airplane mode.
- assertTrue(neighMessageCount > 0);
-
- IoUtils.closeQuietly(fd);
+ assertTrue(msgs.size() > 0);
}
@Test
public void testBasicWorkingGetAddrQuery() throws Exception {
- final FileDescriptor fd = NetlinkUtils.netlinkSocketForProto(NETLINK_ROUTE);
- assertNotNull(fd);
-
- NetlinkUtils.connectToKernel(fd);
-
- final NetlinkSocketAddress localAddr = (NetlinkSocketAddress) Os.getsockname(fd);
- assertNotNull(localAddr);
- assertEquals(0, localAddr.getGroupsMask());
- assertTrue(0 != localAddr.getPortId());
-
final int testSeqno = 8;
final byte[] req = newGetAddrRequest(testSeqno);
assertNotNull(req);
- final long timeout = 500;
- assertEquals(req.length, NetlinkUtils.sendMessage(fd, req, 0, req.length, timeout));
+ List<RtNetlinkAddressMessage> msgs = new ArrayList<>();
+ Consumer<RtNetlinkAddressMessage> handleNlDumpMsg = (msg) -> {
+ msgs.add(msg);
+ };
+ NetlinkUtils.<RtNetlinkAddressMessage>getAndProcessNetlinkDumpMessages(req, NETLINK_ROUTE,
+ RtNetlinkAddressMessage.class, handleNlDumpMsg);
- int addrMessageCount = 0;
+ boolean ipv4LoopbackAddressFound = false;
+ boolean ipv6LoopbackAddressFound = false;
+ final InetAddress loopbackIpv4 = InetAddress.getByName("127.0.0.1");
+ final InetAddress loopbackIpv6 = InetAddress.getByName("::1");
- while (true) {
- ByteBuffer response = NetlinkUtils.recvMessage(fd, DEFAULT_RECV_BUFSIZE, timeout);
- assertNotNull(response);
- assertTrue(StructNlMsgHdr.STRUCT_SIZE <= response.limit());
- assertEquals(0, response.position());
- assertEquals(ByteOrder.nativeOrder(), response.order());
-
- final NetlinkMessage msg = NetlinkMessage.parse(response, NETLINK_ROUTE);
+ for (var msg : msgs) {
assertNotNull(msg);
final StructNlMsgHdr nlmsghdr = msg.getHeader();
assertNotNull(nlmsghdr);
-
- if (nlmsghdr.nlmsg_type == NetlinkConstants.NLMSG_DONE) {
- break;
- }
-
assertEquals(NetlinkConstants.RTM_NEWADDR, nlmsghdr.nlmsg_type);
assertTrue((nlmsghdr.nlmsg_flags & StructNlMsgHdr.NLM_F_MULTI) != 0);
assertEquals(testSeqno, nlmsghdr.nlmsg_seq);
- assertEquals(localAddr.getPortId(), nlmsghdr.nlmsg_pid);
assertTrue(msg instanceof RtNetlinkAddressMessage);
- addrMessageCount++;
-
- // From the query response we can see the RTM_NEWADDR messages representing for IPv4
- // and IPv6 loopback address: 127.0.0.1 and ::1.
+ // When parsing the full response we can see the RTM_NEWADDR messages representing for
+ // IPv4 and IPv6 loopback address: 127.0.0.1 and ::1 and non-loopback addresses.
final StructIfaddrMsg ifaMsg = ((RtNetlinkAddressMessage) msg).getIfaddrHeader();
final InetAddress ipAddress = ((RtNetlinkAddressMessage) msg).getIpAddress();
assertTrue(
"Non-IP address family: " + ifaMsg.family,
ifaMsg.family == AF_INET || ifaMsg.family == AF_INET6);
- assertTrue(ipAddress.isLoopbackAddress());
+ assertNotNull(ipAddress);
+
+ if (ipAddress.equals(loopbackIpv4)) {
+ ipv4LoopbackAddressFound = true;
+ assertTrue(ipAddress.isLoopbackAddress());
+ }
+ if (ipAddress.equals(loopbackIpv6)) {
+ ipv6LoopbackAddressFound = true;
+ assertTrue(ipAddress.isLoopbackAddress());
+ }
}
- assertTrue(addrMessageCount > 0);
-
- IoUtils.closeQuietly(fd);
+ assertTrue(msgs.size() > 0);
+ // Check ipv4 and ipv6 loopback addresses are in the output
+ assertTrue(ipv4LoopbackAddressFound && ipv6LoopbackAddressFound);
}
/** A convenience method to create an RTM_GETADDR request message. */