Support StructInetDiagSockId parse

Bug: 217624062
Test: atest NetworkStaticLibTests
Change-Id: If42436459874d4c61037a98db18c43b40eb81066
diff --git a/staticlibs/device/com/android/net/module/util/netlink/StructInetDiagSockId.java b/staticlibs/device/com/android/net/module/util/netlink/StructInetDiagSockId.java
index 95d60e5..95723cb 100644
--- a/staticlibs/device/com/android/net/module/util/netlink/StructInetDiagSockId.java
+++ b/staticlibs/device/com/android/net/module/util/netlink/StructInetDiagSockId.java
@@ -16,10 +16,22 @@
 
 package com.android.net.module.util.netlink;
 
+import static android.system.OsConstants.AF_INET;
+import static android.system.OsConstants.AF_INET6;
+
+import static com.android.net.module.util.NetworkStackConstants.IPV4_ADDR_LEN;
+import static com.android.net.module.util.NetworkStackConstants.IPV6_ADDR_LEN;
+
 import static java.nio.ByteOrder.BIG_ENDIAN;
 
+import android.util.Log;
+
+import androidx.annotation.Nullable;
+
 import java.net.Inet4Address;
+import java.net.InetAddress;
 import java.net.InetSocketAddress;
+import java.net.UnknownHostException;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 
@@ -41,19 +53,79 @@
  * @hide
  */
 public class StructInetDiagSockId {
+    private static final String TAG = StructInetDiagSockId.class.getSimpleName();
     public static final int STRUCT_SIZE = 48;
 
-    private static final byte[] INET_DIAG_NOCOOKIE = new byte[]{
-            (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff,
-            (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff};
+    private static final long INET_DIAG_NOCOOKIE = ~0L;
     private static final byte[] IPV4_PADDING = new byte[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
 
-    private final InetSocketAddress mLocSocketAddress;
-    private final InetSocketAddress mRemSocketAddress;
+    public final InetSocketAddress locSocketAddress;
+    public final InetSocketAddress remSocketAddress;
+    public final int ifIndex;
+    public final long cookie;
 
     public StructInetDiagSockId(InetSocketAddress loc, InetSocketAddress rem) {
-        mLocSocketAddress = loc;
-        mRemSocketAddress = rem;
+        this(loc, rem, 0 /* ifIndex */, INET_DIAG_NOCOOKIE);
+    }
+
+    public StructInetDiagSockId(InetSocketAddress loc, InetSocketAddress rem,
+            int ifIndex, long cookie) {
+        this.locSocketAddress = loc;
+        this.remSocketAddress = rem;
+        this.ifIndex = ifIndex;
+        this.cookie = cookie;
+    }
+
+    /**
+     * Parse inet diag socket id from buffer.
+     */
+    @Nullable
+    public static StructInetDiagSockId parse(final ByteBuffer byteBuffer, final byte family) {
+        if (byteBuffer.remaining() < STRUCT_SIZE) {
+            return null;
+        }
+
+        byteBuffer.order(BIG_ENDIAN);
+        final int srcPort = Short.toUnsignedInt(byteBuffer.getShort());
+        final int dstPort = Short.toUnsignedInt(byteBuffer.getShort());
+
+        final byte[] srcAddrByte;
+        final byte[] dstAddrByte;
+        if (family == AF_INET) {
+            srcAddrByte = new byte[IPV4_ADDR_LEN];
+            dstAddrByte = new byte[IPV4_ADDR_LEN];
+            byteBuffer.get(srcAddrByte);
+            // Address always uses IPV6_ADDR_LEN in the buffer. So if the address is IPv4, position
+            // needs to be advanced to the next field.
+            byteBuffer.position(byteBuffer.position() + (IPV6_ADDR_LEN - IPV4_ADDR_LEN));
+            byteBuffer.get(dstAddrByte);
+            byteBuffer.position(byteBuffer.position() + (IPV6_ADDR_LEN - IPV4_ADDR_LEN));
+        } else if (family == AF_INET6) {
+            srcAddrByte = new byte[IPV6_ADDR_LEN];
+            dstAddrByte = new byte[IPV6_ADDR_LEN];
+            byteBuffer.get(srcAddrByte);
+            byteBuffer.get(dstAddrByte);
+        } else {
+            Log.e(TAG, "Invalid address family: " + family);
+            return null;
+        }
+
+        final InetSocketAddress srcAddr;
+        final InetSocketAddress dstAddr;
+        try {
+            srcAddr = new InetSocketAddress(InetAddress.getByAddress(srcAddrByte), srcPort);
+            dstAddr = new InetSocketAddress(InetAddress.getByAddress(dstAddrByte), dstPort);
+        } catch (UnknownHostException e) {
+            // Should not happen. UnknownHostException is thrown only if addr byte array is of
+            // illegal length.
+            Log.e(TAG, "Failed to parse address: " + e);
+            return null;
+        }
+
+        byteBuffer.order(ByteOrder.nativeOrder());
+        final int ifIndex = byteBuffer.getInt();
+        final long cookie = byteBuffer.getLong();
+        return new StructInetDiagSockId(srcAddr, dstAddr, ifIndex, cookie);
     }
 
     /**
@@ -61,30 +133,31 @@
      */
     public void pack(ByteBuffer byteBuffer) {
         byteBuffer.order(BIG_ENDIAN);
-        byteBuffer.putShort((short) mLocSocketAddress.getPort());
-        byteBuffer.putShort((short) mRemSocketAddress.getPort());
-        byteBuffer.put(mLocSocketAddress.getAddress().getAddress());
-        if (mLocSocketAddress.getAddress() instanceof Inet4Address) {
+        byteBuffer.putShort((short) locSocketAddress.getPort());
+        byteBuffer.putShort((short) remSocketAddress.getPort());
+        byteBuffer.put(locSocketAddress.getAddress().getAddress());
+        if (locSocketAddress.getAddress() instanceof Inet4Address) {
             byteBuffer.put(IPV4_PADDING);
         }
-        byteBuffer.put(mRemSocketAddress.getAddress().getAddress());
-        if (mRemSocketAddress.getAddress() instanceof Inet4Address) {
+        byteBuffer.put(remSocketAddress.getAddress().getAddress());
+        if (remSocketAddress.getAddress() instanceof Inet4Address) {
             byteBuffer.put(IPV4_PADDING);
         }
         byteBuffer.order(ByteOrder.nativeOrder());
-        byteBuffer.putInt(0);
-        byteBuffer.put(INET_DIAG_NOCOOKIE);
+        byteBuffer.putInt(ifIndex);
+        byteBuffer.putLong(cookie);
     }
 
     @Override
     public String toString() {
         return "StructInetDiagSockId{ "
-                + "idiag_sport{" + mLocSocketAddress.getPort() + "}, "
-                + "idiag_dport{" + mRemSocketAddress.getPort() + "}, "
-                + "idiag_src{" + mLocSocketAddress.getAddress().getHostAddress() + "}, "
-                + "idiag_dst{" + mRemSocketAddress.getAddress().getHostAddress() + "}, "
-                + "idiag_if{" + 0 + "} "
-                + "idiag_cookie{INET_DIAG_NOCOOKIE}"
+                + "idiag_sport{" + locSocketAddress.getPort() + "}, "
+                + "idiag_dport{" + remSocketAddress.getPort() + "}, "
+                + "idiag_src{" + locSocketAddress.getAddress().getHostAddress() + "}, "
+                + "idiag_dst{" + remSocketAddress.getAddress().getHostAddress() + "}, "
+                + "idiag_if{" + ifIndex + "}, "
+                + "idiag_cookie{"
+                + (cookie == INET_DIAG_NOCOOKIE ? "INET_DIAG_NOCOOKIE" : cookie) + "}"
                 + "}";
     }
 }
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/netlink/StructInetDiagSockIdTest.java b/staticlibs/tests/unit/src/com/android/net/module/util/netlink/StructInetDiagSockIdTest.java
index fb929fc..ce190f2 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/netlink/StructInetDiagSockIdTest.java
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/netlink/StructInetDiagSockIdTest.java
@@ -16,6 +16,9 @@
 
 package com.android.net.module.util.netlink;
 
+import static android.system.OsConstants.AF_INET;
+import static android.system.OsConstants.AF_INET6;
+
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 
@@ -45,6 +48,8 @@
             (Inet6Address) InetAddresses.parseNumericAddress("2001:db8::2");
     private static final int SRC_PORT = 65297;
     private static final int DST_PORT = 443;
+    private static final int IF_INDEX = 7;
+    private static final long COOKIE = 561;
 
     private static final byte[] INET_DIAG_SOCKET_ID_IPV4 =
             new byte[] {
@@ -67,6 +72,27 @@
                     (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff
             };
 
+    private static final byte[] INET_DIAG_SOCKET_ID_IPV4_IF_COOKIE =
+            new byte[] {
+                    // src port, dst port
+                    (byte) 0xff, (byte) 0x11, (byte) 0x01, (byte) 0xbb,
+                    // src address
+                    (byte) 0xc0, (byte) 0x00, (byte) 0x02, (byte) 0x01,
+                    (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+                    (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+                    (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+                    // dst address
+                    (byte) 0xc6, (byte) 0x33, (byte) 0x64, (byte) 0x01,
+                    (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+                    (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+                    (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+                    // if index
+                    (byte) 0x07, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+                    // cookie
+                    (byte) 0x31, (byte) 0x02, (byte) 0x00, (byte) 0x00,
+                    (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+            };
+
     private static final byte[] INET_DIAG_SOCKET_ID_IPV6 =
             new byte[] {
                     // src port, dst port
@@ -88,6 +114,27 @@
                     (byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff
             };
 
+    private static final byte[] INET_DIAG_SOCKET_ID_IPV6_IF_COOKIE =
+            new byte[] {
+                    // src port, dst port
+                    (byte) 0xff, (byte) 0x11, (byte) 0x01, (byte) 0xbb,
+                    // src address
+                    (byte) 0x20, (byte) 0x01, (byte) 0x0d, (byte) 0xb8,
+                    (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+                    (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+                    (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x01,
+                    // dst address
+                    (byte) 0x20, (byte) 0x01, (byte) 0x0d, (byte) 0xb8,
+                    (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+                    (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+                    (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x02,
+                    // if index
+                    (byte) 0x07, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+                    // cookie
+                    (byte) 0x31, (byte) 0x02, (byte) 0x00, (byte) 0x00,
+                    (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+            };
+
     @Test
     public void testPackStructInetDiagSockIdWithIpv4() {
         final InetSocketAddress srcAddr = new InetSocketAddress(IPV4_SRC_ADDR, SRC_PORT);
@@ -109,12 +156,60 @@
     }
 
     @Test
+    public void testPackStructInetDiagSockIdWithIpv4IfIndexCookie() {
+        final InetSocketAddress srcAddr = new InetSocketAddress(IPV4_SRC_ADDR, SRC_PORT);
+        final InetSocketAddress dstAddr = new InetSocketAddress(IPV4_DST_ADDR, DST_PORT);
+        final StructInetDiagSockId sockId =
+                new StructInetDiagSockId(srcAddr, dstAddr, IF_INDEX, COOKIE);
+        final ByteBuffer buffer = ByteBuffer.allocate(StructInetDiagSockId.STRUCT_SIZE);
+        sockId.pack(buffer);
+        assertArrayEquals(INET_DIAG_SOCKET_ID_IPV4_IF_COOKIE, buffer.array());
+    }
+
+    @Test
+    public void testPackStructInetDiagSockIdWithIpv6IfIndexCookie() {
+        final InetSocketAddress srcAddr = new InetSocketAddress(IPV6_SRC_ADDR, SRC_PORT);
+        final InetSocketAddress dstAddr = new InetSocketAddress(IPV6_DST_ADDR, DST_PORT);
+        final StructInetDiagSockId sockId =
+                new StructInetDiagSockId(srcAddr, dstAddr, IF_INDEX, COOKIE);
+        final ByteBuffer buffer = ByteBuffer.allocate(StructInetDiagSockId.STRUCT_SIZE);
+        sockId.pack(buffer);
+        assertArrayEquals(INET_DIAG_SOCKET_ID_IPV6_IF_COOKIE, buffer.array());
+    }
+
+    @Test
+    public void testParseStructInetDiagSockIdWithIpv4() {
+        final ByteBuffer buffer = ByteBuffer.wrap(INET_DIAG_SOCKET_ID_IPV4_IF_COOKIE);
+        final StructInetDiagSockId sockId = StructInetDiagSockId.parse(buffer, (byte) AF_INET);
+
+        assertEquals(SRC_PORT, sockId.locSocketAddress.getPort());
+        assertEquals(IPV4_SRC_ADDR, sockId.locSocketAddress.getAddress());
+        assertEquals(DST_PORT, sockId.remSocketAddress.getPort());
+        assertEquals(IPV4_DST_ADDR, sockId.remSocketAddress.getAddress());
+        assertEquals(IF_INDEX, sockId.ifIndex);
+        assertEquals(COOKIE, sockId.cookie);
+    }
+
+    @Test
+    public void testParseStructInetDiagSockIdWithIpv6() {
+        final ByteBuffer buffer = ByteBuffer.wrap(INET_DIAG_SOCKET_ID_IPV6_IF_COOKIE);
+        final StructInetDiagSockId sockId = StructInetDiagSockId.parse(buffer, (byte) AF_INET6);
+
+        assertEquals(SRC_PORT, sockId.locSocketAddress.getPort());
+        assertEquals(IPV6_SRC_ADDR, sockId.locSocketAddress.getAddress());
+        assertEquals(DST_PORT, sockId.remSocketAddress.getPort());
+        assertEquals(IPV6_DST_ADDR, sockId.remSocketAddress.getAddress());
+        assertEquals(IF_INDEX, sockId.ifIndex);
+        assertEquals(COOKIE, sockId.cookie);
+    }
+
+    @Test
     public void testToStringStructInetDiagSockIdWithIpv4() {
         final InetSocketAddress srcAddr = new InetSocketAddress(IPV4_SRC_ADDR, SRC_PORT);
         final InetSocketAddress dstAddr = new InetSocketAddress(IPV4_DST_ADDR, DST_PORT);
         final StructInetDiagSockId sockId = new StructInetDiagSockId(srcAddr, dstAddr);
         assertEquals("StructInetDiagSockId{ idiag_sport{65297}, idiag_dport{443},"
-                + " idiag_src{192.0.2.1}, idiag_dst{198.51.100.1}, idiag_if{0}"
+                + " idiag_src{192.0.2.1}, idiag_dst{198.51.100.1}, idiag_if{0},"
                 + " idiag_cookie{INET_DIAG_NOCOOKIE}}", sockId.toString());
     }
 
@@ -124,7 +219,7 @@
         final InetSocketAddress dstAddr = new InetSocketAddress(IPV6_DST_ADDR, DST_PORT);
         final StructInetDiagSockId sockId = new StructInetDiagSockId(srcAddr, dstAddr);
         assertEquals("StructInetDiagSockId{ idiag_sport{65297}, idiag_dport{443},"
-                + " idiag_src{2001:db8::1}, idiag_dst{2001:db8::2}, idiag_if{0}"
+                + " idiag_src{2001:db8::1}, idiag_dst{2001:db8::2}, idiag_if{0},"
                 + " idiag_cookie{INET_DIAG_NOCOOKIE}}", sockId.toString());
     }
 }