Fix ConntrackSocketTest#testIpv4ConntrackSocket flaky

If we run an iterations test, TCP connections opened by previous tests
won't be closed immediately. They will stay in the TIME_WAIT state for
120 seconds before closed. Netfilter socket listens to both
NF_NETLINK_CONNTRACK_NEW and NF_NETLINK_CONNTRACK_DESTROY groups.

There is a tiny time gap between socket creation and sending dump
request to kernel. If an old TCP connection happened to be closed in
this time gap, listening socket would see a DESTROY message before
seeing the NEW messages triggered by dump request.

The current test only parsing the first contrack message in the read
buffer, so it will only find the DESTROY message and ignore all of the
following NEW messages. It does a netfilter socket dump, but does not
check whether the dump produces reasonable output. Now that we have
the ability to parse conntrack messages, do more in-depth checking.

Bug: 254608655
Test: atest ConntrackSocketTest --iteration 500
Change-Id: Icf95adb9d9390d598e9c6bdb4e245c5fce764e3e
diff --git a/Tethering/src/com/android/networkstack/tethering/OffloadHardwareInterface.java b/Tethering/src/com/android/networkstack/tethering/OffloadHardwareInterface.java
index fbb342d..846abcb 100644
--- a/Tethering/src/com/android/networkstack/tethering/OffloadHardwareInterface.java
+++ b/Tethering/src/com/android/networkstack/tethering/OffloadHardwareInterface.java
@@ -295,8 +295,7 @@
                 NF_NETLINK_CONNTRACK_NEW | NF_NETLINK_CONNTRACK_DESTROY);
         if (h1 == null) return false;
 
-        sendIpv4NfGenMsg(h1, (short) ((NFNL_SUBSYS_CTNETLINK << 8) | IPCTNL_MSG_CT_GET),
-                           (short) (NLM_F_REQUEST | NLM_F_DUMP));
+        requestSocketDump(h1);
 
         final NativeHandle h2 = mDeps.createConntrackSocket(
                 NF_NETLINK_CONNTRACK_UPDATE | NF_NETLINK_CONNTRACK_DESTROY);
@@ -325,7 +324,7 @@
     }
 
     @VisibleForTesting
-    public void sendIpv4NfGenMsg(@NonNull NativeHandle handle, short type, short flags) {
+    void sendIpv4NfGenMsg(@NonNull NativeHandle handle, short type, short flags) {
         final int length = StructNlMsgHdr.STRUCT_SIZE + StructNfGenMsg.STRUCT_SIZE;
         final byte[] msg = new byte[length];
         final ByteBuffer byteBuffer = ByteBuffer.wrap(msg);
@@ -350,6 +349,12 @@
         }
     }
 
+    @VisibleForTesting
+    void requestSocketDump(NativeHandle handle) {
+        sendIpv4NfGenMsg(handle, (short) ((NFNL_SUBSYS_CTNETLINK << 8) | IPCTNL_MSG_CT_GET),
+                (short) (NLM_F_REQUEST | NLM_F_DUMP));
+    }
+
     private void closeFdInNativeHandle(final NativeHandle h) {
         try {
             h.close();
diff --git a/Tethering/tests/privileged/src/com/android/networkstack/tethering/ConntrackSocketTest.java b/Tethering/tests/privileged/src/com/android/networkstack/tethering/ConntrackSocketTest.java
index d38a7c3..23fb60c 100644
--- a/Tethering/tests/privileged/src/com/android/networkstack/tethering/ConntrackSocketTest.java
+++ b/Tethering/tests/privileged/src/com/android/networkstack/tethering/ConntrackSocketTest.java
@@ -16,28 +16,32 @@
 
 package com.android.networkstack.tethering;
 
+import static android.system.OsConstants.EAGAIN;
+import static android.system.OsConstants.IPPROTO_TCP;
+import static android.system.OsConstants.NETLINK_NETFILTER;
+
 import static com.android.net.module.util.netlink.NetlinkSocket.DEFAULT_RECV_BUFSIZE;
-import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_DUMP;
-import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_REQUEST;
-import static com.android.networkstack.tethering.OffloadHardwareInterface.IPCTNL_MSG_CT_GET;
 import static com.android.networkstack.tethering.OffloadHardwareInterface.IPCTNL_MSG_CT_NEW;
 import static com.android.networkstack.tethering.OffloadHardwareInterface.NFNL_SUBSYS_CTNETLINK;
 import static com.android.networkstack.tethering.OffloadHardwareInterface.NF_NETLINK_CONNTRACK_DESTROY;
 import static com.android.networkstack.tethering.OffloadHardwareInterface.NF_NETLINK_CONNTRACK_NEW;
 
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 import android.os.Handler;
 import android.os.HandlerThread;
 import android.os.Looper;
 import android.os.NativeHandle;
-import android.system.Os;
+import android.system.ErrnoException;
+import android.util.Log;
 
 import androidx.test.filters.SmallTest;
 import androidx.test.runner.AndroidJUnit4;
 
 import com.android.net.module.util.SharedLog;
+import com.android.net.module.util.netlink.ConntrackMessage;
+import com.android.net.module.util.netlink.NetlinkMessage;
+import com.android.net.module.util.netlink.NetlinkSocket;
 import com.android.net.module.util.netlink.StructNlMsgHdr;
 
 import org.junit.Before;
@@ -45,18 +49,18 @@
 import org.junit.runner.RunWith;
 import org.mockito.MockitoAnnotations;
 
+import java.io.FileDescriptor;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.ServerSocket;
 import java.net.Socket;
-import java.net.SocketAddress;
 import java.nio.ByteBuffer;
-import java.nio.ByteOrder;
 
 @RunWith(AndroidJUnit4.class)
 @SmallTest
 public class ConntrackSocketTest {
     private static final long TIMEOUT = 500;
+    private static final String TAG = ConntrackSocketTest.class.getSimpleName();
 
     private HandlerThread mHandlerThread;
     private Handler mHandler;
@@ -80,51 +84,72 @@
         mOffloadHw = new OffloadHardwareInterface(mHandler, mLog, mDeps);
     }
 
+    void findConnectionOrThrow(FileDescriptor fd, InetSocketAddress local, InetSocketAddress remote)
+            throws Exception {
+        Log.d(TAG, "Looking for socket " + local + " -> " + remote);
+
+        // Loop until the socket is found (and return) or recvMessage throws an exception.
+        while (true) {
+            final ByteBuffer buffer = NetlinkSocket.recvMessage(fd, DEFAULT_RECV_BUFSIZE, TIMEOUT);
+
+            // Parse all the netlink messages in the dump.
+            // NetlinkMessage#parse returns null if the message is truncated or invalid.
+            while (buffer.remaining() > 0) {
+                NetlinkMessage nlmsg = NetlinkMessage.parse(buffer, NETLINK_NETFILTER);
+                Log.d(TAG, "Got netlink message: " + nlmsg);
+                if (!(nlmsg instanceof ConntrackMessage)) {
+                    continue;
+                }
+
+                StructNlMsgHdr nlmsghdr = nlmsg.getHeader();
+                ConntrackMessage ctmsg = (ConntrackMessage) nlmsg;
+                ConntrackMessage.Tuple tuple = ctmsg.tupleOrig;
+
+                if (nlmsghdr.nlmsg_type == (NFNL_SUBSYS_CTNETLINK << 8 | IPCTNL_MSG_CT_NEW)
+                        && tuple.protoNum == IPPROTO_TCP
+                        && tuple.srcIp.equals(local.getAddress())
+                        && tuple.dstIp.equals(remote.getAddress())
+                        && tuple.srcPort == (short) local.getPort()
+                        && tuple.dstPort == (short) remote.getPort()) {
+                    return;
+                }
+            }
+        }
+    }
+
     @Test
     public void testIpv4ConntrackSocket() throws Exception {
         // Set up server and connect.
-        final InetSocketAddress anyAddress = new InetSocketAddress(
-                InetAddress.getByName("127.0.0.1"), 0);
+        final InetAddress localhost = InetAddress.getByName("127.0.0.1");
+        final InetSocketAddress anyAddress = new InetSocketAddress(localhost, 0);
         final ServerSocket serverSocket = new ServerSocket();
         serverSocket.bind(anyAddress);
-        final SocketAddress theAddress = serverSocket.getLocalSocketAddress();
+        final InetSocketAddress theAddress =
+                (InetSocketAddress) serverSocket.getLocalSocketAddress();
 
         // Make a connection to the server.
         final Socket socket = new Socket();
         socket.connect(theAddress);
+        final InetSocketAddress localAddress = (InetSocketAddress) socket.getLocalSocketAddress();
         final Socket acceptedSocket = serverSocket.accept();
 
         final NativeHandle handle = mDeps.createConntrackSocket(
                 NF_NETLINK_CONNTRACK_NEW | NF_NETLINK_CONNTRACK_DESTROY);
-        mOffloadHw.sendIpv4NfGenMsg(handle,
-                (short) ((NFNL_SUBSYS_CTNETLINK << 8) | IPCTNL_MSG_CT_GET),
-                (short) (NLM_F_REQUEST | NLM_F_DUMP));
-
-        boolean foundConntrackEntry = false;
-        ByteBuffer buffer = ByteBuffer.allocate(DEFAULT_RECV_BUFSIZE);
-        buffer.order(ByteOrder.nativeOrder());
+        mOffloadHw.requestSocketDump(handle);
 
         try {
-            while (Os.read(handle.getFileDescriptor(), buffer) > 0) {
-                buffer.flip();
-
-                // TODO: ConntrackMessage should get a parse API like StructNlMsgHdr
-                // so we can confirm that the conntrack added is for the TCP connection above.
-                final StructNlMsgHdr nlmsghdr = StructNlMsgHdr.parse(buffer);
-                assertNotNull(nlmsghdr);
-
-                // As long as 1 conntrack entry is found test case will pass, even if it's not
-                // the from the TCP connection above.
-                if (nlmsghdr.nlmsg_type == ((NFNL_SUBSYS_CTNETLINK << 8) | IPCTNL_MSG_CT_NEW)) {
-                    foundConntrackEntry = true;
-                    break;
-                }
+            findConnectionOrThrow(handle.getFileDescriptor(), localAddress, theAddress);
+            // No exceptions? Socket was found, test passes.
+        } catch (ErrnoException e) {
+            if (e.errno == EAGAIN) {
+                fail("Did not find socket " + localAddress + "->" + theAddress + " in dump");
+            } else {
+                throw e;
             }
         } finally {
             socket.close();
             serverSocket.close();
+            acceptedSocket.close();
         }
-        assertTrue("Did not receive any NFNL_SUBSYS_CTNETLINK/IPCTNL_MSG_CT_NEW message",
-                foundConntrackEntry);
     }
 }