Fix StructInetDiagMsg#parse bug

StructInetDiagMsg#parse used absolute position to parse id and idiag_uid
fields.
This causes an issue when byte buffer contains multiple netlink messages
and parsed netlink message starts from the middle of byte buffer.

This CL fixes above issue and adds test for the case that single byte
buffer contains multiple netlink messages.

Test: atest NetworkStaticLibTests
Change-Id: I771daf17b4a606915d5edefd9f0e90ef01abf168
diff --git a/staticlibs/device/com/android/net/module/util/netlink/StructInetDiagMsg.java b/staticlibs/device/com/android/net/module/util/netlink/StructInetDiagMsg.java
index ea018cf..e7fc02f 100644
--- a/staticlibs/device/com/android/net/module/util/netlink/StructInetDiagMsg.java
+++ b/staticlibs/device/com/android/net/module/util/netlink/StructInetDiagMsg.java
@@ -43,9 +43,11 @@
  */
 public class StructInetDiagMsg {
     public static final int STRUCT_SIZE = 4 + StructInetDiagSockId.STRUCT_SIZE + 20;
-    private static final int IDIAG_SOCK_ID_OFFSET = StructNlMsgHdr.STRUCT_SIZE + 4;
-    private static final int IDIAG_UID_OFFSET = StructNlMsgHdr.STRUCT_SIZE + 4
-            + StructInetDiagSockId.STRUCT_SIZE + 12;
+    // Offset to the id field from the beginning of inet_diag_msg struct
+    private static final int IDIAG_SOCK_ID_OFFSET = 4;
+    // Offset to the idiag_uid field from the beginning of inet_diag_msg struct
+    private static final int IDIAG_UID_OFFSET =
+            IDIAG_SOCK_ID_OFFSET + StructInetDiagSockId.STRUCT_SIZE + 12;
     public int idiag_uid;
     @NonNull
     public StructInetDiagSockId id;
@@ -58,14 +60,18 @@
         if (byteBuffer.remaining() < STRUCT_SIZE) {
             return null;
         }
+        final int baseOffset = byteBuffer.position();
         StructInetDiagMsg struct = new StructInetDiagMsg();
         final byte family = byteBuffer.get();
-        byteBuffer.position(IDIAG_SOCK_ID_OFFSET);
+        byteBuffer.position(baseOffset + IDIAG_SOCK_ID_OFFSET);
         struct.id = StructInetDiagSockId.parse(byteBuffer, family);
         if (struct.id == null) {
             return null;
         }
-        struct.idiag_uid = byteBuffer.getInt(IDIAG_UID_OFFSET);
+        struct.idiag_uid = byteBuffer.getInt(baseOffset + IDIAG_UID_OFFSET);
+
+        // Move position to the end of the inet_diag_msg
+        byteBuffer.position(baseOffset + STRUCT_SIZE);
         return struct;
     }
 
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/netlink/InetDiagSocketTest.java b/staticlibs/tests/unit/src/com/android/net/module/util/netlink/InetDiagSocketTest.java
index c7e2a4d..d81422f 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/netlink/InetDiagSocketTest.java
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/netlink/InetDiagSocketTest.java
@@ -221,14 +221,31 @@
         assertArrayEquals(INET_DIAG_REQ_V2_TCP_INET6_NO_ID_SPECIFIED_BYTES, msgExt);
     }
 
-    // Hexadecimal representation of InetDiagReqV2 request.
-    private static final String INET_DIAG_MSG_HEX =
+    private void assertNlMsgHdr(StructNlMsgHdr hdr, short type, short flags, int seq, int pid) {
+        assertNotNull(hdr);
+        assertEquals(type, hdr.nlmsg_type);
+        assertEquals(flags, hdr.nlmsg_flags);
+        assertEquals(seq, hdr.nlmsg_seq);
+        assertEquals(pid, hdr.nlmsg_pid);
+    }
+
+    private void assertInetDiagSockId(StructInetDiagSockId sockId,
+            InetSocketAddress locSocketAddress, InetSocketAddress remSocketAddress,
+            int ifIndex, long cookie) {
+        assertEquals(locSocketAddress, sockId.locSocketAddress);
+        assertEquals(remSocketAddress, sockId.remSocketAddress);
+        assertEquals(ifIndex, sockId.ifIndex);
+        assertEquals(cookie, sockId.cookie);
+    }
+
+    // Hexadecimal representation of InetDiagMessage
+    private static final String INET_DIAG_MSG_HEX1 =
             // struct nlmsghdr
             "58000000" +     // length = 88
             "1400" +         // type = SOCK_DIAG_BY_FAMILY
             "0200" +         // flags = NLM_F_MULTI
             "00000000" +     // seqno
-            "f5220000" +     // pid (0 == kernel)
+            "f5220000" +     // pid
             // struct inet_diag_msg
             "0a" +           // family = AF_INET6
             "01" +           // idiag_state
@@ -244,36 +261,94 @@
             "00000000" +     // idiag_expires
             "00000000" +     // idiag_rqueue
             "00000000" +     // idiag_wqueue
-            "a3270000" +     // idiag_uid
+            "a3270000" +     // idiag_uid = 10147
             "A57E1900";      // idiag_inode
+
+    private void assertInetDiagMsg1(final NetlinkMessage msg) {
+        assertNotNull(msg);
+
+        assertTrue(msg instanceof InetDiagMessage);
+        final InetDiagMessage inetDiagMsg = (InetDiagMessage) msg;
+
+        assertNlMsgHdr(inetDiagMsg.getHeader(),
+                NetlinkConstants.SOCK_DIAG_BY_FAMILY,
+                StructNlMsgHdr.NLM_F_MULTI,
+                0    /* seq */,
+                8949 /* pid */);
+
+        assertEquals(10147, inetDiagMsg.inetDiagMsg.idiag_uid);
+        assertInetDiagSockId(inetDiagMsg.inetDiagMsg.id,
+                new InetSocketAddress(InetAddresses.parseNumericAddress("2001:db8::1"), 43031),
+                new InetSocketAddress(InetAddresses.parseNumericAddress("2001:db8::2"), 38415),
+                7  /* ifIndex */,
+                88 /* cookie */);
+    }
+
+    // Hexadecimal representation of InetDiagMessage
+    private static final String INET_DIAG_MSG_HEX2 =
+            // struct nlmsghdr
+            "58000000" +     // length = 88
+            "1400" +         // type = SOCK_DIAG_BY_FAMILY
+            "0200" +         // flags = NLM_F_MULTI
+            "00000000" +     // seqno
+            "f5220000" +     // pid
+            // struct inet_diag_msg
+            "0a" +           // family = AF_INET6
+            "01" +           // idiag_state
+            "00" +           // idiag_timer
+            "00" +           // idiag_retrans
+                // inet_diag_sockid
+                "a845" +     // idiag_sport = 43077
+                "01bb" +     // idiag_dport = 443
+                "20010db8000000000000000000000003" + // idiag_src = 2001:db8::3
+                "20010db8000000000000000000000004" + // idiag_dst = 2001:db8::4
+                "08000000" + // idiag_if = 8
+                "6300000000000000" + // idiag_cookie = 99
+            "00000000" +     // idiag_expires
+            "00000000" +     // idiag_rqueue
+            "00000000" +     // idiag_wqueue
+            "39300000" +     // idiag_uid = 12345
+            "A57E1900";      // idiag_inode
+
+    private void assertInetDiagMsg2(final NetlinkMessage msg) {
+        assertNotNull(msg);
+
+        assertTrue(msg instanceof InetDiagMessage);
+        final InetDiagMessage inetDiagMsg = (InetDiagMessage) msg;
+
+        assertNlMsgHdr(inetDiagMsg.getHeader(),
+                NetlinkConstants.SOCK_DIAG_BY_FAMILY,
+                StructNlMsgHdr.NLM_F_MULTI,
+                0    /* seq */,
+                8949 /* pid */);
+
+        assertEquals(12345, inetDiagMsg.inetDiagMsg.idiag_uid);
+        assertInetDiagSockId(inetDiagMsg.inetDiagMsg.id,
+                new InetSocketAddress(InetAddresses.parseNumericAddress("2001:db8::3"), 43077),
+                new InetSocketAddress(InetAddresses.parseNumericAddress("2001:db8::4"), 443),
+                8  /* ifIndex */,
+                99 /* cookie */);
+    }
+
     private static final byte[] INET_DIAG_MSG_BYTES =
-            HexEncoding.decode(INET_DIAG_MSG_HEX.toCharArray(), false);
+            HexEncoding.decode(INET_DIAG_MSG_HEX1.toCharArray(), false);
 
     @Test
     public void testParseInetDiagResponse() throws Exception {
         final ByteBuffer byteBuffer = ByteBuffer.wrap(INET_DIAG_MSG_BYTES);
         byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
-        final NetlinkMessage msg = NetlinkMessage.parse(byteBuffer, NETLINK_INET_DIAG);
-        assertNotNull(msg);
+        assertInetDiagMsg1(NetlinkMessage.parse(byteBuffer, NETLINK_INET_DIAG));
+    }
 
-        assertTrue(msg instanceof InetDiagMessage);
-        final InetDiagMessage inetDiagMsg = (InetDiagMessage) msg;
-        assertEquals(10147, inetDiagMsg.inetDiagMsg.idiag_uid);
-        final StructInetDiagSockId sockId = inetDiagMsg.inetDiagMsg.id;
-        assertEquals(43031, sockId.locSocketAddress.getPort());
-        assertEquals(InetAddresses.parseNumericAddress("2001:db8::1"),
-                sockId.locSocketAddress.getAddress());
-        assertEquals(38415, sockId.remSocketAddress.getPort());
-        assertEquals(InetAddresses.parseNumericAddress("2001:db8::2"),
-                sockId.remSocketAddress.getAddress());
-        assertEquals(7, sockId.ifIndex);
-        assertEquals(88, sockId.cookie);
 
-        final StructNlMsgHdr hdr = inetDiagMsg.getHeader();
-        assertNotNull(hdr);
-        assertEquals(NetlinkConstants.SOCK_DIAG_BY_FAMILY, hdr.nlmsg_type);
-        assertEquals(StructNlMsgHdr.NLM_F_MULTI, hdr.nlmsg_flags);
-        assertEquals(0, hdr.nlmsg_seq);
-        assertEquals(8949, hdr.nlmsg_pid);
+    private static final byte[] INET_DIAG_MSG_BYTES_MULTIPLE =
+            HexEncoding.decode((INET_DIAG_MSG_HEX1 + INET_DIAG_MSG_HEX2).toCharArray(), false);
+
+    @Test
+    public void testParseInetDiagResponseMultiple() {
+        final ByteBuffer byteBuffer = ByteBuffer.wrap(INET_DIAG_MSG_BYTES_MULTIPLE);
+        byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
+        assertInetDiagMsg1(NetlinkMessage.parse(byteBuffer, NETLINK_INET_DIAG));
+        assertInetDiagMsg2(NetlinkMessage.parse(byteBuffer, NETLINK_INET_DIAG));
     }
 }