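Join OT multicast groups on the Thread TUN interface

Multicast addresses passed to TunInterfaceController.updateAddresses()
were previously ignored. Join each of them on the TUN interface via a
MulticastSocket, and leave groups that are no longer reported, so that
the interface becomes a member of groups such as ff02::2 (all routers).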

Bug: 333668308

Test: atest ThreadNetworkIntegrationTests:android.net.thread.ThreadIntegrationTest#joinNetwork_tunInterfaceJoinsAllRouterMulticastGroup
Change-Id: I0bb91259ae5cbe0eb9c5d71b7aa4e5b2a95f8c6b
diff --git a/thread/service/java/com/android/server/thread/TunInterfaceController.java b/thread/service/java/com/android/server/thread/TunInterfaceController.java
index 97cdd55..dec72b2 100644
--- a/thread/service/java/com/android/server/thread/TunInterfaceController.java
+++ b/thread/service/java/com/android/server/thread/TunInterfaceController.java
@@ -16,6 +16,8 @@
 
 package com.android.server.thread;
 
+import static android.system.OsConstants.EADDRINUSE;
+
 import android.annotation.Nullable;
 import android.net.IpPrefix;
 import android.net.LinkAddress;
@@ -39,6 +41,10 @@
 import java.io.InterruptedIOException;
 import java.net.Inet6Address;
 import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.MulticastSocket;
+import java.net.NetworkInterface;
+import java.net.SocketException;
 import java.net.UnknownHostException;
 import java.util.ArrayList;
 import java.util.List;
@@ -58,12 +64,16 @@
     private ParcelFileDescriptor mParcelTunFd;
     private FileDescriptor mNetlinkSocket;
     private static int sNetlinkSeqNo = 0;
+    private final MulticastSocket mMulticastSocket; // For joining and leaving multicast groups
+    @Nullable private NetworkInterface mNetworkInterface; // Set in createTunInterface()
+    private final List<InetAddress> mMulticastAddresses = new ArrayList<>(); // Current groups
 
     /** Creates a new {@link TunInterfaceController} instance for given interface. */
     public TunInterfaceController(String interfaceName) {
         mIfName = interfaceName;
         mLinkProperties.setInterfaceName(mIfName);
         mLinkProperties.setMtu(MTU);
+        mMulticastSocket = createMulticastSocket();
     }
 
     /** Returns link properties of the Thread TUN interface. */
@@ -83,6 +93,15 @@
         } catch (ErrnoException e) {
             throw new IOException("Failed to create netlink socket", e);
         }
+        try {
+            mNetworkInterface = NetworkInterface.getByName(mIfName);
+        } catch (SocketException e) {
+            throw new IOException("Failed to get NetworkInterface", e);
+        }
+        // getByName() returns null, rather than throwing, when no such interface exists.
+        if (mNetworkInterface == null) {
+            throw new IOException("Failed to get NetworkInterface for " + mIfName);
+        }
     }

     public void destroyTunInterface() {
@@ -94,6 +113,10 @@
         }
         mParcelTunFd = null;
         mNetlinkSocket = null;
+        mNetworkInterface = null;
+        // Memberships tracked here belonged to the destroyed interface; forget them so that the
+        // next updateAddresses() call joins every reported group on the re-created interface.
+        mMulticastAddresses.clear();
     }

     /** Returns the FD of the tunnel interface. */
@@ -187,6 +210,7 @@

     public void updateAddresses(List<Ipv6AddressInfo> addressInfoList) {
         final List<LinkAddress> newLinkAddresses = new ArrayList<>();
+        final List<InetAddress> newMulticastAddresses = new ArrayList<>();
         boolean hasActiveOmrAddress = false;

         for (Ipv6AddressInfo addressInfo : addressInfoList) {
@@ -199,12 +223,10 @@
         for (Ipv6AddressInfo addressInfo : addressInfoList) {
             InetAddress address = addressInfoToInetAddress(addressInfo);
             if (address.isMulticastAddress()) {
-                // TODO: Logging here will create repeated logs for a single multicast address, and
-                // it currently is not mandatory for debugging. Add log for ignored multicast
-                // address when necessary.
-                continue;
+                newMulticastAddresses.add(address);
+            } else {
+                newLinkAddresses.add(newLinkAddress(addressInfo, hasActiveOmrAddress));
             }
-            newLinkAddresses.add(newLinkAddress(addressInfo, hasActiveOmrAddress));
         }

         final CompareResult<LinkAddress> addressDiff =
@@ -215,6 +237,17 @@
         for (LinkAddress linkAddress : addressDiff.added) {
             addAddress(linkAddress);
         }
+
+        final CompareResult<InetAddress> multicastAddressDiff =
+                new CompareResult<>(mMulticastAddresses, newMulticastAddresses);
+        for (InetAddress address : multicastAddressDiff.removed) {
+            leaveGroup(address);
+        }
+        for (InetAddress address : multicastAddressDiff.added) {
+            joinGroup(address);
+        }
+        mMulticastAddresses.clear();
+        mMulticastAddresses.addAll(newMulticastAddresses);
     }

     private RouteInfo getRouteForAddress(LinkAddress linkAddress) {
@@ -274,4 +307,49 @@
                 deprecationTimeMillis,
                 LinkAddress.LIFETIME_PERMANENT /* expirationTime */);
     }
+
+    /**
+     * Creates the socket used solely to join and leave multicast groups on the TUN interface.
+     *
+     * @throws IllegalStateException if the socket cannot be created
+     */
+    private MulticastSocket createMulticastSocket() {
+        try {
+            return new MulticastSocket();
+        } catch (IOException e) {
+            throw new IllegalStateException("Failed to create multicast socket ", e);
+        }
+    }
+
+    /**
+     * Joins the multicast group {@code address} on the TUN interface.
+     *
+     * <p>Errors are logged, not thrown. An EADDRINUSE cause is treated as "already joined" and
+     * logged as a warning only.
+     */
+    private void joinGroup(InetAddress address) {
+        InetSocketAddress socketAddress = new InetSocketAddress(address, 0);
+        try {
+            mMulticastSocket.joinGroup(socketAddress, mNetworkInterface);
+        } catch (IOException e) {
+            if (e.getCause() instanceof ErrnoException) {
+                ErrnoException ee = (ErrnoException) e.getCause();
+                if (ee.errno == EADDRINUSE) {
+                    Log.w(TAG, "Already joined group " + address.getHostAddress(), e);
+                    return;
+                }
+            }
+            Log.e(TAG, "failed to join group " + address.getHostAddress(), e);
+        }
+    }
+
+    /** Leaves the multicast group {@code address} on the TUN interface, logging any failure. */
+    private void leaveGroup(InetAddress address) {
+        InetSocketAddress socketAddress = new InetSocketAddress(address, 0);
+        try {
+            mMulticastSocket.leaveGroup(socketAddress, mNetworkInterface);
+        } catch (IOException e) {
+            Log.e(TAG, "failed to leave group " + address.getHostAddress(), e);
+        }
+    }
 }
diff --git a/thread/tests/integration/Android.bp b/thread/tests/integration/Android.bp
index 94985b1..71693af 100644
--- a/thread/tests/integration/Android.bp
+++ b/thread/tests/integration/Android.bp
@@ -34,6 +34,7 @@
         "testables",
         "ThreadNetworkTestUtils",
         "truth",
+        "ot-daemon-aidl-java",
     ],
     libs: [
         "android.test.runner",
diff --git a/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java b/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java
index 1410d41..e211e22 100644
--- a/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java
+++ b/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java
@@ -22,15 +22,18 @@
 import static android.net.thread.utils.IntegrationTestUtils.CALLBACK_TIMEOUT;
 import static android.net.thread.utils.IntegrationTestUtils.RESTART_JOIN_TIMEOUT;
 import static android.net.thread.utils.IntegrationTestUtils.getIpv6LinkAddresses;
+import static android.net.thread.utils.IntegrationTestUtils.isInMulticastGroup;
 import static android.net.thread.utils.IntegrationTestUtils.waitFor;
 
 import static com.android.compatibility.common.util.SystemUtil.runShellCommand;
 import static com.android.compatibility.common.util.SystemUtil.runShellCommandOrThrow;
+import static com.android.server.thread.openthread.IOtDaemon.TUN_IF_NAME;
 
 import static com.google.common.io.BaseEncoding.base16;
 import static com.google.common.truth.Truth.assertThat;
 
 import android.content.Context;
+import android.net.InetAddresses;
 import android.net.IpPrefix;
 import android.net.LinkAddress;
 import android.net.thread.utils.FullThreadDevice;
@@ -83,6 +86,9 @@
     private static final ActiveOperationalDataset DEFAULT_DATASET =
             ActiveOperationalDataset.fromThreadTlvs(DEFAULT_DATASET_TLVS);
 
+    private static final Inet6Address GROUP_ADDR_ALL_ROUTERS =
+            (Inet6Address) InetAddresses.parseNumericAddress("ff02::2");
+
     @Rule public final ThreadFeatureCheckerRule mThreadRule = new ThreadFeatureCheckerRule();
 
     private ExecutorService mExecutor;
@@ -224,6 +230,13 @@
         mOtCtl.executeCommand("br enable");
     }
 
+    @Test
+    public void joinNetwork_tunInterfaceJoinsAllRouterMulticastGroup() throws Exception {
+        mController.joinAndWait(DEFAULT_DATASET);
+
+        assertTunInterfaceMemberOfGroup(GROUP_ADDR_ALL_ROUTERS);
+    }
+
     // TODO (b/323300829): add more tests for integration with linux platform and
     // ConnectivityService
 
@@ -259,4 +272,12 @@
             throw new IllegalStateException(e);
         }
     }
+
+    /**
+     * Waits up to {@code TUN_ADDR_UPDATE_TIMEOUT} for the TUN interface to join the multicast
+     * group {@code address}.
+     */
+    private void assertTunInterfaceMemberOfGroup(Inet6Address address) throws Exception {
+        waitFor(() -> isInMulticastGroup(TUN_IF_NAME, address), TUN_ADDR_UPDATE_TIMEOUT);
+    }
 }