Apply multicast routing configs

Applies multicast routing configs when they are updated in
LocalNetworkConfig, or when local network/upstream network
status changes.

Bug: 281217735

Test: atest CSLocalAgentTests
Change-Id: Ib1c43c645a367d0f91e5cf0a0d9f8e5883be2c40
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 7339d08..9850fde 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -171,6 +171,7 @@
 import android.net.LocalNetworkConfig;
 import android.net.LocalNetworkInfo;
 import android.net.MatchAllNetworkSpecifier;
+import android.net.MulticastRoutingConfig;
 import android.net.NativeNetworkConfig;
 import android.net.NativeNetworkType;
 import android.net.NattSocketKeepalive;
@@ -320,6 +321,7 @@
 import com.android.server.connectivity.KeepaliveTracker;
 import com.android.server.connectivity.LingerMonitor;
 import com.android.server.connectivity.MockableSystemProperties;
+import com.android.server.connectivity.MulticastRoutingCoordinatorService;
 import com.android.server.connectivity.MultinetworkPolicyTracker;
 import com.android.server.connectivity.NetworkAgentInfo;
 import com.android.server.connectivity.NetworkDiagnostics;
@@ -347,6 +349,7 @@
 import java.io.PrintWriter;
 import java.io.Writer;
 import java.net.Inet4Address;
+import java.net.Inet6Address;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.SocketException;
@@ -361,6 +364,7 @@
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
 import java.util.NoSuchElementException;
 import java.util.Objects;
 import java.util.Set;
@@ -496,6 +500,7 @@
     @GuardedBy("mTNSLock")
     private TestNetworkService mTNS;
     private final CompanionDeviceManagerProxyService mCdmps;
+    private final MulticastRoutingCoordinatorService mMulticastRoutingCoordinatorService;
     private final RoutingCoordinatorService mRoutingCoordinatorService;
 
     private final Object mTNSLock = new Object();
@@ -1421,6 +1426,17 @@
             return new AutomaticOnOffKeepaliveTracker(c, h);
         }
 
+        public MulticastRoutingCoordinatorService makeMulticastRoutingCoordinatorService(
+                    @NonNull Handler h) {
+            try {
+                return new MulticastRoutingCoordinatorService(h);
+            } catch (UnsupportedOperationException e) {
+                // Multicast routing is not supported by the kernel
+                Log.i(TAG, "Skipping unsupported MulticastRoutingCoordinatorService");
+                return null;
+            }
+        }
+
         /**
          * @see BatteryStatsManager
          */
@@ -1859,6 +1875,8 @@
         }
 
         mRoutingCoordinatorService = new RoutingCoordinatorService(netd);
+        mMulticastRoutingCoordinatorService =
+                mDeps.makeMulticastRoutingCoordinatorService(mHandler);
 
         mDestroyFrozenSockets = mDeps.isAtLeastU()
                 && mDeps.isFeatureEnabled(context, KEY_DESTROY_FROZEN_SOCKETS_VERSION);
@@ -5175,9 +5193,12 @@
     private void removeLocalNetworkUpstream(@NonNull final NetworkAgentInfo localAgent,
             @NonNull final NetworkAgentInfo upstream) {
         try {
+            final String localNetworkInterfaceName = localAgent.linkProperties.getInterfaceName();
+            final String upstreamNetworkInterfaceName = upstream.linkProperties.getInterfaceName();
             mRoutingCoordinatorService.removeInterfaceForward(
-                    localAgent.linkProperties.getInterfaceName(),
-                    upstream.linkProperties.getInterfaceName());
+                    localNetworkInterfaceName,
+                    upstreamNetworkInterfaceName);
+            disableMulticastRouting(localNetworkInterfaceName, upstreamNetworkInterfaceName);
         } catch (RemoteException e) {
             loge("Couldn't remove interface forward for "
                     + localAgent.linkProperties.getInterfaceName() + " to "
@@ -9095,6 +9116,61 @@
         updateCapabilities(nai.getScore(), nai, nai.networkCapabilities);
     }
 
+    private void maybeApplyMulticastRoutingConfig(@NonNull final NetworkAgentInfo nai,
+            final LocalNetworkConfig oldConfig,
+            final LocalNetworkConfig newConfig) {
+        final MulticastRoutingConfig oldUpstreamConfig =
+                oldConfig == null ? MulticastRoutingConfig.CONFIG_FORWARD_NONE :
+                        oldConfig.getUpstreamMulticastRoutingConfig();
+        final MulticastRoutingConfig oldDownstreamConfig =
+                oldConfig == null ? MulticastRoutingConfig.CONFIG_FORWARD_NONE :
+                        oldConfig.getDownstreamMulticastRoutingConfig();
+        final MulticastRoutingConfig newUpstreamConfig =
+                newConfig == null ? MulticastRoutingConfig.CONFIG_FORWARD_NONE :
+                        newConfig.getUpstreamMulticastRoutingConfig();
+        final MulticastRoutingConfig newDownstreamConfig =
+                newConfig == null ? MulticastRoutingConfig.CONFIG_FORWARD_NONE :
+                        newConfig.getDownstreamMulticastRoutingConfig();
+
+        if (oldUpstreamConfig.equals(newUpstreamConfig) &&
+            oldDownstreamConfig.equals(newDownstreamConfig)) {
+            return;
+        }
+
+        final String downstreamNetworkName = nai.linkProperties.getInterfaceName();
+        final LocalNetworkInfo lni = localNetworkInfoForNai(nai);
+        final Network upstreamNetwork = lni.getUpstreamNetwork();
+
+        if (upstreamNetwork != null) {
+            final String upstreamNetworkName =
+                    getLinkProperties(upstreamNetwork).getInterfaceName();
+            applyMulticastRoutingConfig(downstreamNetworkName, upstreamNetworkName, newConfig);
+        }
+    }
+
+    private void applyMulticastRoutingConfig(@NonNull String localNetworkInterfaceName,
+            @NonNull String upstreamNetworkInterfaceName,
+            @NonNull final LocalNetworkConfig config) {
+        if (mMulticastRoutingCoordinatorService == null) return;
+
+        mMulticastRoutingCoordinatorService.applyMulticastRoutingConfig(localNetworkInterfaceName,
+                upstreamNetworkInterfaceName, config.getUpstreamMulticastRoutingConfig());
+        mMulticastRoutingCoordinatorService.applyMulticastRoutingConfig
+                (upstreamNetworkInterfaceName, localNetworkInterfaceName,
+                        config.getDownstreamMulticastRoutingConfig());
+    }
+
+    private void disableMulticastRouting(@NonNull String localNetworkInterfaceName,
+            @NonNull String upstreamNetworkInterfaceName) {
+        if (mMulticastRoutingCoordinatorService == null) return;
+
+        mMulticastRoutingCoordinatorService.applyMulticastRoutingConfig(localNetworkInterfaceName,
+                upstreamNetworkInterfaceName, MulticastRoutingConfig.CONFIG_FORWARD_NONE);
+        mMulticastRoutingCoordinatorService.applyMulticastRoutingConfig
+                (upstreamNetworkInterfaceName, localNetworkInterfaceName,
+                        MulticastRoutingConfig.CONFIG_FORWARD_NONE);
+    }
+
     // oldConfig is null iff this is the original registration of the local network config
     private void handleUpdateLocalNetworkConfig(@NonNull final NetworkAgentInfo nai,
             @Nullable final LocalNetworkConfig oldConfig,
@@ -9108,7 +9184,6 @@
             Log.v(TAG, "Update local network config " + nai.network.netId + " : " + newConfig);
         }
         final LocalNetworkConfig.Builder configBuilder = new LocalNetworkConfig.Builder();
-        // TODO : apply the diff for multicast routing.
         configBuilder.setUpstreamMulticastRoutingConfig(
                 newConfig.getUpstreamMulticastRoutingConfig());
         configBuilder.setDownstreamMulticastRoutingConfig(
@@ -9167,6 +9242,7 @@
             configBuilder.setUpstreamSelector(oldRequest);
             nai.localNetworkConfig = configBuilder.build();
         }
+        maybeApplyMulticastRoutingConfig(nai, oldConfig, newConfig);
     }
 
     /**
@@ -10166,6 +10242,8 @@
                     if (null != change.mOldNetwork) {
                         mRoutingCoordinatorService.removeInterfaceForward(fromIface,
                                 change.mOldNetwork.linkProperties.getInterfaceName());
+                        disableMulticastRouting(fromIface,
+                                change.mOldNetwork.linkProperties.getInterfaceName());
                     }
                     // If the new upstream is already destroyed, there is no point in setting up
                     // a forward (in fact, it might forward to the interface for some new network !)
@@ -10174,6 +10252,9 @@
                     if (null != change.mNewNetwork && !change.mNewNetwork.isDestroyed()) {
                         mRoutingCoordinatorService.addInterfaceForward(fromIface,
                                 change.mNewNetwork.linkProperties.getInterfaceName());
+                        applyMulticastRoutingConfig(fromIface,
+                                change.mNewNetwork.linkProperties.getInterfaceName(),
+                                nai.localNetworkConfig);
                     }
                 } catch (final RemoteException e) {
                     loge("Can't update forwarding rules", e);
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentTests.kt b/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentTests.kt
index dd0706b..016608b 100644
--- a/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentTests.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentTests.kt
@@ -20,6 +20,8 @@
 import android.net.LinkAddress
 import android.net.LinkProperties
 import android.net.LocalNetworkConfig
+import android.net.MulticastRoutingConfig
+import android.net.MulticastRoutingConfig.CONFIG_FORWARD_NONE
 import android.net.NetworkCapabilities
 import android.net.NetworkCapabilities.NET_CAPABILITY_DUN
 import android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET
@@ -45,6 +47,7 @@
 import org.junit.Test
 import org.junit.runner.RunWith
 import org.mockito.Mockito.clearInvocations
+import org.mockito.Mockito.eq
 import org.mockito.Mockito.inOrder
 import org.mockito.Mockito.never
 import org.mockito.Mockito.timeout
@@ -82,6 +85,24 @@
 @RunWith(DevSdkIgnoreRunner::class)
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
 class CSLocalAgentTests : CSTest() {
+    val multicastRoutingConfigMinScope =
+                MulticastRoutingConfig.Builder(MulticastRoutingConfig.FORWARD_WITH_MIN_SCOPE, 4)
+                .build();
+    val multicastRoutingConfigSelected =
+                MulticastRoutingConfig.Builder(MulticastRoutingConfig.FORWARD_SELECTED)
+                .build();
+    val upstreamSelectorAny = NetworkRequest.Builder()
+                .addForbiddenCapability(NET_CAPABILITY_LOCAL_NETWORK)
+                .build()
+    val upstreamSelectorWifi = NetworkRequest.Builder()
+                .addForbiddenCapability(NET_CAPABILITY_LOCAL_NETWORK)
+                .addTransportType(TRANSPORT_WIFI)
+                .build()
+    val upstreamSelectorCell = NetworkRequest.Builder()
+                .addForbiddenCapability(NET_CAPABILITY_LOCAL_NETWORK)
+                .addTransportType(TRANSPORT_CELLULAR)
+                .build()
+
     @Test
     fun testBadAgents() {
         deps.setBuildSdk(VERSION_V)
@@ -177,6 +198,266 @@
         localAgent.disconnect()
     }
 
+    private fun createLocalAgent(name: String, localNetworkConfig: FromS<LocalNetworkConfig>):
+                CSAgentWrapper {
+        val localAgent = Agent(
+                nc = nc(TRANSPORT_THREAD, NET_CAPABILITY_LOCAL_NETWORK),
+                lp = lp(name),
+                lnc = localNetworkConfig,
+        )
+        return localAgent
+    }
+
+    private fun createWifiAgent(name: String): CSAgentWrapper {
+        return Agent(score = keepScore(), lp = lp(name),
+                nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_INTERNET))
+    }
+
+    private fun createCellAgent(name: String): CSAgentWrapper {
+        return Agent(score = keepScore(), lp = lp(name),
+                nc = nc(TRANSPORT_CELLULAR, NET_CAPABILITY_INTERNET))
+    }
+
+    private fun sendLocalNetworkConfig(localAgent: CSAgentWrapper,
+                upstreamSelector: NetworkRequest?, upstreamConfig: MulticastRoutingConfig,
+                downstreamConfig: MulticastRoutingConfig) {
+        val newLnc = LocalNetworkConfig.Builder()
+                .setUpstreamSelector(upstreamSelector)
+                .setUpstreamMulticastRoutingConfig(upstreamConfig)
+                .setDownstreamMulticastRoutingConfig(downstreamConfig)
+                .build()
+        localAgent.sendLocalNetworkConfig(newLnc)
+    }
+
+    @Test
+    fun testMulticastRoutingConfig() {
+        deps.setBuildSdk(VERSION_V)
+        val cb = TestableNetworkCallback()
+        cm.registerNetworkCallback(NetworkRequest.Builder().clearCapabilities().build(), cb)
+        val inOrder = inOrder(multicastRoutingCoordinatorService)
+
+        val lnc = FromS(LocalNetworkConfig.Builder()
+                .setUpstreamSelector(upstreamSelectorWifi)
+                .setUpstreamMulticastRoutingConfig(multicastRoutingConfigMinScope)
+                .setDownstreamMulticastRoutingConfig(multicastRoutingConfigSelected)
+                .build()
+        )
+        val localAgent = createLocalAgent("local0", lnc)
+        localAgent.connect()
+
+        cb.expectAvailableCallbacks(localAgent.network, validated = false)
+
+        val wifiAgent = createWifiAgent("wifi0")
+        wifiAgent.connect()
+        cb.expectAvailableCallbacks(wifiAgent.network, validated = false)
+        cb.expect<LocalInfoChanged>(localAgent.network) {
+            it.info.upstreamNetwork == wifiAgent.network
+        }
+
+        inOrder.verify(multicastRoutingCoordinatorService).applyMulticastRoutingConfig(
+                "local0", "wifi0", multicastRoutingConfigMinScope)
+        inOrder.verify(multicastRoutingCoordinatorService).applyMulticastRoutingConfig(
+                "wifi0", "local0", multicastRoutingConfigSelected)
+
+        wifiAgent.disconnect()
+
+        inOrder.verify(multicastRoutingCoordinatorService)
+                .applyMulticastRoutingConfig("local0", "wifi0", CONFIG_FORWARD_NONE)
+        inOrder.verify(multicastRoutingCoordinatorService)
+                .applyMulticastRoutingConfig("wifi0", "local0", CONFIG_FORWARD_NONE)
+
+        localAgent.disconnect()
+    }
+
+    @Test
+    fun testMulticastRoutingConfig_2LocalNetworks() {
+        deps.setBuildSdk(VERSION_V)
+        val inOrder = inOrder(multicastRoutingCoordinatorService)
+        val lnc = FromS(LocalNetworkConfig.Builder()
+                .setUpstreamSelector(upstreamSelectorWifi)
+                .setUpstreamMulticastRoutingConfig(multicastRoutingConfigMinScope)
+                .setDownstreamMulticastRoutingConfig(multicastRoutingConfigSelected)
+                .build()
+        )
+        val localAgent0 = createLocalAgent("local0", lnc)
+        localAgent0.connect()
+
+        val wifiAgent = createWifiAgent("wifi0")
+        wifiAgent.connect()
+        waitForIdle()
+
+        inOrder.verify(multicastRoutingCoordinatorService).applyMulticastRoutingConfig(
+                "local0", "wifi0", multicastRoutingConfigMinScope)
+        inOrder.verify(multicastRoutingCoordinatorService).applyMulticastRoutingConfig(
+                "wifi0", "local0", multicastRoutingConfigSelected)
+
+        val localAgent1 = createLocalAgent("local1", lnc)
+        localAgent1.connect()
+        waitForIdle()
+
+        inOrder.verify(multicastRoutingCoordinatorService).applyMulticastRoutingConfig(
+                "local1", "wifi0", multicastRoutingConfigMinScope)
+        inOrder.verify(multicastRoutingCoordinatorService).applyMulticastRoutingConfig(
+                "wifi0", "local1", multicastRoutingConfigSelected)
+
+        localAgent0.disconnect()
+        localAgent1.disconnect()
+        wifiAgent.disconnect()
+    }
+
+    @Test
+    fun testMulticastRoutingConfig_UpstreamNetworkCellToWifi() {
+        deps.setBuildSdk(VERSION_V)
+        val cb = TestableNetworkCallback()
+        cm.registerNetworkCallback(NetworkRequest.Builder().clearCapabilities()
+                        .addCapability(NET_CAPABILITY_LOCAL_NETWORK)
+                        .build(), cb)
+        val inOrder = inOrder(multicastRoutingCoordinatorService)
+        val lnc = FromS(LocalNetworkConfig.Builder()
+                .setUpstreamSelector(upstreamSelectorAny)
+                .setUpstreamMulticastRoutingConfig(multicastRoutingConfigMinScope)
+                .setDownstreamMulticastRoutingConfig(multicastRoutingConfigSelected)
+                .build()
+        )
+        val localAgent = createLocalAgent("local0", lnc)
+        val wifiAgent = createWifiAgent("wifi0")
+        val cellAgent = createCellAgent("cell0")
+
+        localAgent.connect()
+        cb.expectAvailableCallbacks(localAgent.network, validated = false)
+
+        cellAgent.connect()
+        cb.expect<LocalInfoChanged>(localAgent.network) {
+            it.info.upstreamNetwork == cellAgent.network
+        }
+
+        inOrder.verify(multicastRoutingCoordinatorService).applyMulticastRoutingConfig(
+                "local0", "cell0", multicastRoutingConfigMinScope)
+        inOrder.verify(multicastRoutingCoordinatorService).applyMulticastRoutingConfig(
+                "cell0", "local0", multicastRoutingConfigSelected)
+
+        wifiAgent.connect()
+
+        cb.expect<LocalInfoChanged>(localAgent.network) {
+            it.info.upstreamNetwork == wifiAgent.network
+        }
+
+        // upstream should have been switched to wifi
+        inOrder.verify(multicastRoutingCoordinatorService)
+                .applyMulticastRoutingConfig("local0", "cell0", CONFIG_FORWARD_NONE)
+        inOrder.verify(multicastRoutingCoordinatorService)
+                .applyMulticastRoutingConfig("cell0", "local0", CONFIG_FORWARD_NONE)
+        inOrder.verify(multicastRoutingCoordinatorService).applyMulticastRoutingConfig(
+                "local0", "wifi0", multicastRoutingConfigMinScope)
+        inOrder.verify(multicastRoutingCoordinatorService).applyMulticastRoutingConfig(
+                "wifi0", "local0", multicastRoutingConfigSelected)
+
+        localAgent.disconnect()
+        cellAgent.disconnect()
+        wifiAgent.disconnect()
+    }
+
+    @Test
+    fun testMulticastRoutingConfig_UpstreamSelectorCellToWifi() {
+        deps.setBuildSdk(VERSION_V)
+        val cb = TestableNetworkCallback()
+        cm.registerNetworkCallback(NetworkRequest.Builder().clearCapabilities()
+                        .addCapability(NET_CAPABILITY_LOCAL_NETWORK)
+                        .build(), cb)
+        val inOrder = inOrder(multicastRoutingCoordinatorService)
+        val lnc = FromS(LocalNetworkConfig.Builder()
+                .setUpstreamSelector(upstreamSelectorCell)
+                .setUpstreamMulticastRoutingConfig(multicastRoutingConfigMinScope)
+                .setDownstreamMulticastRoutingConfig(multicastRoutingConfigSelected)
+                .build()
+        )
+        val localAgent = createLocalAgent("local0", lnc)
+        val wifiAgent = createWifiAgent("wifi0")
+        val cellAgent = createCellAgent("cell0")
+
+        localAgent.connect()
+        cellAgent.connect()
+        wifiAgent.connect()
+        cb.expectAvailableCallbacks(localAgent.network, validated = false)
+        cb.expect<LocalInfoChanged>(localAgent.network) {
+            it.info.upstreamNetwork == cellAgent.network
+        }
+
+        inOrder.verify(multicastRoutingCoordinatorService).applyMulticastRoutingConfig(
+                "local0", "cell0", multicastRoutingConfigMinScope)
+        inOrder.verify(multicastRoutingCoordinatorService).applyMulticastRoutingConfig(
+                "cell0", "local0", multicastRoutingConfigSelected)
+
+        sendLocalNetworkConfig(localAgent, upstreamSelectorWifi, multicastRoutingConfigMinScope,
+                multicastRoutingConfigSelected)
+        cb.expect<LocalInfoChanged>(localAgent.network) {
+            it.info.upstreamNetwork == wifiAgent.network
+        }
+
+        // upstream should have been switched to wifi
+        inOrder.verify(multicastRoutingCoordinatorService)
+                .applyMulticastRoutingConfig("local0", "cell0", CONFIG_FORWARD_NONE)
+        inOrder.verify(multicastRoutingCoordinatorService)
+                .applyMulticastRoutingConfig("cell0", "local0", CONFIG_FORWARD_NONE)
+        inOrder.verify(multicastRoutingCoordinatorService).applyMulticastRoutingConfig(
+                "local0", "wifi0", multicastRoutingConfigMinScope)
+        inOrder.verify(multicastRoutingCoordinatorService).applyMulticastRoutingConfig(
+                "wifi0", "local0", multicastRoutingConfigSelected)
+
+        localAgent.disconnect()
+        cellAgent.disconnect()
+        wifiAgent.disconnect()
+    }
+
+    @Test
+    fun testMulticastRoutingConfig_UpstreamSelectorWifiToNull() {
+        deps.setBuildSdk(VERSION_V)
+        val cb = TestableNetworkCallback()
+        cm.registerNetworkCallback(NetworkRequest.Builder().clearCapabilities()
+                        .addCapability(NET_CAPABILITY_LOCAL_NETWORK)
+                        .build(), cb)
+        val inOrder = inOrder(multicastRoutingCoordinatorService)
+        val lnc = FromS(LocalNetworkConfig.Builder()
+                .setUpstreamSelector(upstreamSelectorWifi)
+                .setUpstreamMulticastRoutingConfig(multicastRoutingConfigMinScope)
+                .setDownstreamMulticastRoutingConfig(multicastRoutingConfigSelected)
+                .build()
+        )
+        val localAgent = createLocalAgent("local0", lnc)
+        localAgent.connect()
+        val wifiAgent = createWifiAgent("wifi0")
+        wifiAgent.connect()
+        cb.expectAvailableCallbacks(localAgent.network, validated = false)
+        cb.expect<LocalInfoChanged>(localAgent.network) {
+            it.info.upstreamNetwork == wifiAgent.network
+        }
+
+        inOrder.verify(multicastRoutingCoordinatorService).applyMulticastRoutingConfig(
+                "local0", "wifi0", multicastRoutingConfigMinScope)
+        inOrder.verify(multicastRoutingCoordinatorService).applyMulticastRoutingConfig(
+                "wifi0", "local0", multicastRoutingConfigSelected)
+
+        sendLocalNetworkConfig(localAgent, null, multicastRoutingConfigMinScope,
+                multicastRoutingConfigSelected)
+        cb.expect<LocalInfoChanged>(localAgent.network) {
+            it.info.upstreamNetwork == null
+        }
+
+        // upstream should have been switched to null
+        inOrder.verify(multicastRoutingCoordinatorService)
+                .applyMulticastRoutingConfig("local0", "wifi0", CONFIG_FORWARD_NONE)
+        inOrder.verify(multicastRoutingCoordinatorService)
+                .applyMulticastRoutingConfig("wifi0", "local0", CONFIG_FORWARD_NONE)
+        inOrder.verify(multicastRoutingCoordinatorService, never()).applyMulticastRoutingConfig(
+                eq("local0"), any(), eq(multicastRoutingConfigMinScope))
+        inOrder.verify(multicastRoutingCoordinatorService, never()).applyMulticastRoutingConfig(
+                any(), eq("local0"), eq(multicastRoutingConfigSelected))
+
+        localAgent.disconnect()
+        wifiAgent.disconnect()
+    }
+
+
     @Test
     fun testUnregisterUpstreamAfterReplacement_SameIfaceName() {
         doTestUnregisterUpstreamAfterReplacement(true)
@@ -196,11 +477,10 @@
         val localAgent = Agent(nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_LOCAL_NETWORK),
                 lp = lp("local0"),
                 lnc = FromS(LocalNetworkConfig.Builder()
-                .setUpstreamSelector(NetworkRequest.Builder()
-                        .addForbiddenCapability(NET_CAPABILITY_LOCAL_NETWORK)
-                        .addTransportType(TRANSPORT_WIFI)
-                        .build())
-                .build()),
+                        .setUpstreamSelector(upstreamSelectorWifi)
+                        .setUpstreamMulticastRoutingConfig(multicastRoutingConfigMinScope)
+                        .setDownstreamMulticastRoutingConfig(multicastRoutingConfigSelected)
+                        .build()),
                 score = FromS(NetworkScore.Builder()
                         .setKeepConnectedReason(KEEP_CONNECTED_LOCAL_NETWORK)
                         .build())
@@ -219,10 +499,15 @@
         }
 
         clearInvocations(netd)
-        val inOrder = inOrder(netd)
+        clearInvocations(multicastRoutingCoordinatorService)
+        val inOrder = inOrder(netd, multicastRoutingCoordinatorService)
         wifiAgent.unregisterAfterReplacement(LONG_TIMEOUT_MS)
         waitForIdle()
         inOrder.verify(netd).ipfwdRemoveInterfaceForward("local0", "wifi0")
+        inOrder.verify(multicastRoutingCoordinatorService)
+                .applyMulticastRoutingConfig("local0", "wifi0", CONFIG_FORWARD_NONE)
+        inOrder.verify(multicastRoutingCoordinatorService)
+                .applyMulticastRoutingConfig("wifi0", "local0", CONFIG_FORWARD_NONE)
         inOrder.verify(netd).networkDestroy(wifiAgent.network.netId)
 
         val wifiIface2 = if (sameIfaceName) "wifi0" else "wifi1"
@@ -235,9 +520,16 @@
         cb.expect<Lost> { it.network == wifiAgent.network }
 
         inOrder.verify(netd).ipfwdAddInterfaceForward("local0", wifiIface2)
-        if (sameIfaceName) {
-            inOrder.verify(netd, never()).ipfwdRemoveInterfaceForward(any(), any())
-        }
+        inOrder.verify(multicastRoutingCoordinatorService).applyMulticastRoutingConfig(
+                "local0", wifiIface2, multicastRoutingConfigMinScope)
+        inOrder.verify(multicastRoutingCoordinatorService).applyMulticastRoutingConfig(
+                wifiIface2, "local0", multicastRoutingConfigSelected)
+
+        inOrder.verify(netd, never()).ipfwdRemoveInterfaceForward(any(), any())
+        inOrder.verify(multicastRoutingCoordinatorService, never())
+                .applyMulticastRoutingConfig("local0", "wifi0", CONFIG_FORWARD_NONE)
+        inOrder.verify(multicastRoutingCoordinatorService, never())
+                .applyMulticastRoutingConfig("wifi0", "local0", CONFIG_FORWARD_NONE)
     }
 
     @Test
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
index 958c4f2..83a11cf 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
@@ -61,9 +61,11 @@
 import com.android.server.connectivity.CarrierPrivilegeAuthenticator
 import com.android.server.connectivity.ClatCoordinator
 import com.android.server.connectivity.ConnectivityFlags
+import com.android.server.connectivity.MulticastRoutingCoordinatorService
 import com.android.server.connectivity.MultinetworkPolicyTracker
 import com.android.server.connectivity.MultinetworkPolicyTrackerTestDependencies
 import com.android.server.connectivity.ProxyTracker
+import com.android.server.connectivity.RoutingCoordinatorService
 import com.android.testutils.visibleOnHandlerThread
 import com.android.testutils.waitForIdle
 import java.util.concurrent.Executors
@@ -166,6 +168,8 @@
         doReturn(true).`when`(it).isDataCapable()
     }
 
+    val multicastRoutingCoordinatorService = mock<MulticastRoutingCoordinatorService>()
+
     val deps = CSDeps()
     val service = makeConnectivityService(context, netd, deps).also { it.systemReadyInternal() }
     val cm = ConnectivityManager(context, service)
@@ -179,6 +183,8 @@
 
         override fun makeHandlerThread(tag: String) = csHandlerThread
         override fun makeProxyTracker(context: Context, connServiceHandler: Handler) = proxyTracker
+        override fun makeMulticastRoutingCoordinatorService(handler: Handler) =
+                this@CSTest.multicastRoutingCoordinatorService
 
         override fun makeCarrierPrivilegeAuthenticator(
                 context: Context,