Set up forwarding rules for local network agents

Test: CSLocalAgentTests, new tests for this
Change-Id: I8994af350a1799ab5f6ebb2872f2abfaf174bd61
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 3dc5692..50b4134 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -4183,7 +4183,7 @@
                 }
                 case NetworkAgent.EVENT_LOCAL_NETWORK_CONFIG_CHANGED: {
                     final LocalNetworkConfig config = (LocalNetworkConfig) arg.second;
-                    updateLocalNetworkConfig(nai, config);
+                    updateLocalNetworkConfig(nai, nai.localNetworkConfig, config);
                     break;
                 }
                 case NetworkAgent.EVENT_NETWORK_SCORE_CHANGED: {
@@ -4944,6 +4944,17 @@
             mDefaultInetConditionPublished = 0;
         }
         notifyIfacesChangedForNetworkStats();
+        // If this was a local network forwarded to some upstream, or if some local network was
+        // forwarded to this nai, then disable forwarding rules now.
+        maybeDisableForwardRulesForDisconnectingNai(nai);
+        // If this is a local network with an upstream selector, remove the associated network
+        // request.
+        if (nai.isLocalNetwork()) {
+            final NetworkRequest selector = nai.localNetworkConfig.getUpstreamSelector();
+            if (null != selector) {
+                handleRemoveNetworkRequest(mNetworkRequests.get(selector));
+            }
+        }
         // TODO - we shouldn't send CALLBACK_LOST to requests that can be satisfied
         // by other networks that are already connected. Perhaps that can be done by
         // sending all CALLBACK_LOST messages (for requests, not listens) at the end
@@ -5057,6 +5068,48 @@
         mNetIdManager.releaseNetId(nai.network.getNetId());
     }
 
+    private void maybeDisableForwardRulesForDisconnectingNai(
+            @NonNull final NetworkAgentInfo disconnecting) {
+        // Step 1 : maybe this network was the upstream for one or more local networks.
+        for (final NetworkAgentInfo local : mNetworkAgentInfos) {
+            if (!local.isLocalNetwork()) continue;
+            final NetworkRequest selector = local.localNetworkConfig.getUpstreamSelector();
+            if (null == selector) continue;
+            final NetworkRequestInfo nri = mNetworkRequests.get(selector);
+            // null == nri can happen while disconnecting a network, because destroyNetwork() is
+            // called after removing all associated NRIs from mNetworkRequests.
+            if (null == nri) continue;
+            final NetworkAgentInfo satisfier = nri.getSatisfier();
+            if (disconnecting != satisfier) continue;
+            removeLocalNetworkUpstream(local, disconnecting);
+        }
+
+        // Step 2 : maybe this is a local network that had an upstream.
+        if (!disconnecting.isLocalNetwork()) return;
+        final NetworkRequest selector = disconnecting.localNetworkConfig.getUpstreamSelector();
+        if (null == selector) return;
+        final NetworkRequestInfo nri = mNetworkRequests.get(selector);
+        // As above null == nri can happen while disconnecting a network, because destroyNetwork()
+        // is called after removing all associated NRIs from mNetworkRequests.
+        if (null == nri) return;
+        final NetworkAgentInfo satisfier = nri.getSatisfier();
+        if (null == satisfier) return;
+        removeLocalNetworkUpstream(disconnecting, satisfier);
+    }
+
+    private void removeLocalNetworkUpstream(@NonNull final NetworkAgentInfo localAgent,
+            @NonNull final NetworkAgentInfo upstream) {
+        try {
+            mRoutingCoordinatorService.removeInterfaceForward(
+                    localAgent.linkProperties.getInterfaceName(),
+                    upstream.linkProperties.getInterfaceName());
+        } catch (RemoteException e) {
+            loge("Couldn't remove interface forward for "
+                    + localAgent.linkProperties.getInterfaceName() + " to "
+                    + upstream.linkProperties.getInterfaceName() + " while disconnecting");
+        }
+    }
+
     private boolean createNativeNetwork(@NonNull NetworkAgentInfo nai) {
         try {
             // This should never fail.  Specifying an already in use NetID will cause failure.
@@ -5071,10 +5124,9 @@
                         !nai.networkAgentConfig.allowBypass /* secure */,
                         getVpnType(nai), nai.networkAgentConfig.excludeLocalRouteVpn);
             } else {
-                final boolean hasLocalCap =
-                        nai.networkCapabilities.hasCapability(NET_CAPABILITY_LOCAL_NETWORK);
                 config = new NativeNetworkConfig(nai.network.getNetId(),
-                        hasLocalCap ? NativeNetworkType.PHYSICAL_LOCAL : NativeNetworkType.PHYSICAL,
+                        nai.isLocalNetwork() ? NativeNetworkType.PHYSICAL_LOCAL
+                                : NativeNetworkType.PHYSICAL,
                         getNetworkPermission(nai.networkCapabilities),
                         false /* secure */,
                         VpnManager.TYPE_VPN_NONE,
@@ -5095,6 +5147,9 @@
         if (mDscpPolicyTracker != null) {
             mDscpPolicyTracker.removeAllDscpPolicies(nai, false);
         }
+        // Remove any forwarding rules to and from the interface for this network, since
+        // the interface is going to go away.
+        maybeDisableForwardRulesForDisconnectingNai(nai);
         try {
             mNetd.networkDestroy(nai.network.getNetId());
         } catch (RemoteException | ServiceSpecificException e) {
@@ -8264,6 +8319,9 @@
             e.rethrowAsRuntimeException();
         }
 
+        if (nai.isLocalNetwork()) {
+            updateLocalNetworkConfig(nai, null /* oldConfig */, nai.localNetworkConfig);
+        }
         nai.notifyRegistered();
         NetworkInfo networkInfo = nai.networkInfo;
         updateNetworkInfo(nai, networkInfo);
@@ -8929,14 +8987,67 @@
         updateCapabilities(nai.getScore(), nai, nai.networkCapabilities);
     }
 
+    // oldConfig is null iff this is the original registration of the local network config
     private void updateLocalNetworkConfig(@NonNull final NetworkAgentInfo nai,
-            @NonNull final LocalNetworkConfig config) {
-        if (!nai.networkCapabilities.hasCapability(NET_CAPABILITY_LOCAL_NETWORK)) {
+            @Nullable final LocalNetworkConfig oldConfig,
+            @NonNull final LocalNetworkConfig newConfig) {
+        if (!nai.isLocalNetwork()) {
             Log.wtf(TAG, "Ignoring update of a local network info on non-local network " + nai);
             return;
         }
-        // TODO : actually apply the diff.
-        nai.localNetworkConfig = config;
+
+        final LocalNetworkConfig.Builder configBuilder = new LocalNetworkConfig.Builder();
+        // TODO : apply the diff for multicast routing.
+        configBuilder.setUpstreamMulticastRoutingConfig(
+                newConfig.getUpstreamMulticastRoutingConfig());
+        configBuilder.setDownstreamMulticastRoutingConfig(
+                newConfig.getDownstreamMulticastRoutingConfig());
+
+        final NetworkRequest oldRequest =
+                (null == oldConfig) ? null : oldConfig.getUpstreamSelector();
+        final NetworkCapabilities oldCaps =
+                (null == oldRequest) ? null : oldRequest.networkCapabilities;
+        final NetworkRequestInfo oldNri =
+                null == oldRequest ? null : mNetworkRequests.get(oldRequest);
+        final NetworkAgentInfo oldSatisfier =
+                null == oldNri ? null : oldNri.getSatisfier();
+        final NetworkRequest newRequest = newConfig.getUpstreamSelector();
+        final NetworkCapabilities newCaps =
+                (null == newRequest) ? null : newRequest.networkCapabilities;
+        final boolean requestUpdated = !Objects.equals(newCaps, oldCaps);
+        if (null != oldRequest && requestUpdated) {
+            handleRemoveNetworkRequest(mNetworkRequests.get(oldRequest));
+            if (null == newRequest && null != oldSatisfier) {
+                // If there is an old satisfier, but no new request, then remove the old upstream.
+                removeLocalNetworkUpstream(nai, oldSatisfier);
+                nai.localNetworkConfig = configBuilder.build();
+                return;
+            }
+        }
+        if (null != newRequest && requestUpdated) {
+            // File the new request if :
+            //  - it has changed (requestUpdated), or
+            //  - it's the first time this local info (null == oldConfig)
+            // is updated and the request has not been filed yet.
+            // Requests for local info are always LISTEN_FOR_BEST, because they have at most one
+            // upstream (the best) but never request it to be brought up.
+            final NetworkRequest nr = new NetworkRequest(newCaps, ConnectivityManager.TYPE_NONE,
+                    nextNetworkRequestId(), LISTEN_FOR_BEST);
+            configBuilder.setUpstreamSelector(nr);
+            final NetworkRequestInfo nri = new NetworkRequestInfo(
+                    nai.creatorUid, nr, null /* messenger */, null /* binder */,
+                    0 /* callbackFlags */, null /* attributionTag */);
+            if (null != oldSatisfier) {
+                // Set the old satisfier in the new NRI so that the rematch will see any changes
+                nri.setSatisfier(oldSatisfier, nr);
+            }
+            nai.localNetworkConfig = configBuilder.build();
+            handleRegisterNetworkRequest(nri);
+        } else {
+            configBuilder.setUpstreamSelector(oldRequest);
+            nai.localNetworkConfig = configBuilder.build();
+        }
+
     }
 
     /**
@@ -9718,7 +9829,8 @@
             if (VDBG) log("rematch for " + newSatisfier.toShortString());
             if (null != previousRequest && null != previousSatisfier) {
                 if (VDBG || DDBG) {
-                    log("   accepting network in place of " + previousSatisfier.toShortString());
+                    log("   accepting network in place of " + previousSatisfier.toShortString()
+                            + " for " + newRequest);
                 }
                 previousSatisfier.removeRequest(previousRequest.requestId);
                 if (canSupportGracefulNetworkSwitch(previousSatisfier, newSatisfier)
@@ -9737,7 +9849,7 @@
                     previousSatisfier.lingerRequest(previousRequest.requestId, now);
                 }
             } else {
-                if (VDBG || DDBG) log("   accepting network in place of null");
+                if (VDBG || DDBG) log("   accepting network in place of null for " + newRequest);
             }
 
             // To prevent constantly CPU wake up for nascent timer, if a network comes up
@@ -9853,6 +9965,14 @@
         }
     }
 
+    private boolean hasSameInterfaceName(@Nullable final NetworkAgentInfo nai1,
+            @Nullable final NetworkAgentInfo nai2) {
+        if (null == nai1) return null == nai2;
+        if (null == nai2) return false;
+        return nai1.linkProperties.getInterfaceName()
+                .equals(nai2.linkProperties.getInterfaceName());
+    }
+
     private void applyNetworkReassignment(@NonNull final NetworkReassignment changes,
             final long now) {
         final Collection<NetworkAgentInfo> nais = mNetworkAgentInfos;
@@ -9926,6 +10046,39 @@
             notifyNetworkLosing(nai, now);
         }
 
+        // Update forwarding rules for the upstreams of local networks. Do this after sending
+        // onAvailable so that clients understand what network this is about.
+        for (final NetworkAgentInfo nai : mNetworkAgentInfos) {
+            if (!nai.isLocalNetwork()) continue;
+            final NetworkRequest nr = nai.localNetworkConfig.getUpstreamSelector();
+            if (null == nr) continue; // No upstream for this local network
+            final NetworkRequestInfo nri = mNetworkRequests.get(nr);
+            final NetworkReassignment.RequestReassignment change = changes.getReassignment(nri);
+            if (null == change) continue; // No change in upstreams for this network
+            final String fromIface = nai.linkProperties.getInterfaceName();
+            if (!hasSameInterfaceName(change.mOldNetwork, change.mNewNetwork)
+                    || change.mOldNetwork.isDestroyed()) {
+                // There can be a change with the same interface name if the new network is the
+                // replacement for the old network that was unregisteredAfterReplacement.
+                try {
+                    if (null != change.mOldNetwork) {
+                        mRoutingCoordinatorService.removeInterfaceForward(fromIface,
+                                change.mOldNetwork.linkProperties.getInterfaceName());
+                    }
+                    // If the new upstream is already destroyed, there is no point in setting up
+                    // a forward (in fact, it might forward to the interface for some new network !)
+                    // Later when the upstream disconnects CS will try to remove the forward, which
+                    // is ignored with a benign log by RoutingCoordinatorService.
+                    if (null != change.mNewNetwork && !change.mNewNetwork.isDestroyed()) {
+                        mRoutingCoordinatorService.addInterfaceForward(fromIface,
+                                change.mNewNetwork.linkProperties.getInterfaceName());
+                    }
+                } catch (final RemoteException e) {
+                    loge("Can't update forwarding rules", e);
+                }
+            }
+        }
+
         updateLegacyTypeTrackerAndVpnLockdownForRematch(changes, nais);
 
         // Tear down all unneeded networks.
diff --git a/service/src/com/android/server/connectivity/NetworkAgentInfo.java b/service/src/com/android/server/connectivity/NetworkAgentInfo.java
index dacae20..7cd3cc8 100644
--- a/service/src/com/android/server/connectivity/NetworkAgentInfo.java
+++ b/service/src/com/android/server/connectivity/NetworkAgentInfo.java
@@ -1258,6 +1258,11 @@
         return networkCapabilities.hasTransport(NetworkCapabilities.TRANSPORT_VPN);
     }
 
+    /** Whether this network is a local network */
+    public boolean isLocalNetwork() {
+        return networkCapabilities.hasCapability(NET_CAPABILITY_LOCAL_NETWORK);
+    }
+
     /**
      * Whether this network should propagate the capabilities from its underlying networks.
      * Currently only true for VPNs.
diff --git a/service/src/com/android/server/connectivity/RoutingCoordinatorService.java b/service/src/com/android/server/connectivity/RoutingCoordinatorService.java
index ede78ce..3350d2d 100644
--- a/service/src/com/android/server/connectivity/RoutingCoordinatorService.java
+++ b/service/src/com/android/server/connectivity/RoutingCoordinatorService.java
@@ -197,8 +197,16 @@
         synchronized (mIfacesLock) {
             final ForwardingPair fwp = new ForwardingPair(fromIface, toIface);
             if (!mForwardedInterfaces.contains(fwp)) {
-                throw new IllegalStateException("No forward set up between interfaces "
-                        + fromIface + " → " + toIface);
+                // This can happen when an upstream was unregisteredAfterReplacement. The forward
+                // is removed immediately when the upstream is destroyed, but later when the
+                // network actually disconnects CS does not know that and it asks for removal
+                // again.
+                // This can also happen if the network was destroyed before being set as an
+                // upstream, because then CS does not set up the forward rules seeing how the
+                // interface was removed anyway.
+                // Either way, this is benign.
+                Log.i(TAG, "No forward set up between interfaces " + fromIface + " → " + toIface);
+                return;
             }
             mForwardedInterfaces.remove(fwp);
             try {
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentTests.kt b/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentTests.kt
index d9f7f9f..3a76ad0 100644
--- a/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentTests.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentTests.kt
@@ -21,13 +21,19 @@
 import android.net.LinkProperties
 import android.net.LocalNetworkConfig
 import android.net.NetworkCapabilities
+import android.net.NetworkCapabilities.NET_CAPABILITY_DUN
+import android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET
 import android.net.NetworkCapabilities.NET_CAPABILITY_LOCAL_NETWORK
 import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED
 import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING
 import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED
 import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VCN_MANAGED
+import android.net.NetworkCapabilities.TRANSPORT_CELLULAR
 import android.net.NetworkCapabilities.TRANSPORT_WIFI
 import android.net.NetworkRequest
+import android.net.NetworkScore
+import android.net.NetworkScore.KEEP_CONNECTED_FOR_TEST
+import android.net.NetworkScore.KEEP_CONNECTED_LOCAL_NETWORK
 import android.net.RouteInfo
 import android.os.Build
 import com.android.testutils.DevSdkIgnoreRule
@@ -40,8 +46,17 @@
 import com.android.testutils.TestableNetworkCallback
 import org.junit.Test
 import org.junit.runner.RunWith
+import org.mockito.Mockito.clearInvocations
+import org.mockito.Mockito.inOrder
+import org.mockito.Mockito.never
+import org.mockito.Mockito.timeout
+import org.mockito.Mockito.verify
 import kotlin.test.assertFailsWith
 
+private const val TIMEOUT_MS = 200L
+private const val MEDIUM_TIMEOUT_MS = 1_000L
+private const val LONG_TIMEOUT_MS = 5_000
+
 private fun nc(transport: Int, vararg caps: Int) = NetworkCapabilities.Builder().apply {
     addTransportType(transport)
     caps.forEach {
@@ -60,6 +75,12 @@
     addRoute(RouteInfo(IpPrefix("0.0.0.0/0"), null, null))
 }
 
+// This allows keeping all the networks connected without having to file individual requests
+// for them.
+private fun keepScore() = FromS(
+        NetworkScore.Builder().setKeepConnectedReason(KEEP_CONNECTED_FOR_TEST).build()
+)
+
 @RunWith(DevSdkIgnoreRunner::class)
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
 class CSLocalAgentTests : CSTest() {
@@ -125,7 +146,8 @@
                 cb)
 
         // Set up a local agent that should forward its traffic to the best DUN upstream.
-        val localAgent = Agent(nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_LOCAL_NETWORK),
+        val localAgent = Agent(
+                nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_LOCAL_NETWORK),
                 lp = lp("local0"),
                 lnc = LocalNetworkConfig.Builder().build(),
         )
@@ -145,4 +167,242 @@
 
         localAgent.disconnect()
     }
+
+    @Test
+    fun testUnregisterUpstreamAfterReplacement_SameIfaceName() {
+        doTestUnregisterUpstreamAfterReplacement(true)
+    }
+
+    @Test
+    fun testUnregisterUpstreamAfterReplacement_DifferentIfaceName() {
+        doTestUnregisterUpstreamAfterReplacement(false)
+    }
+
+    fun doTestUnregisterUpstreamAfterReplacement(sameIfaceName: Boolean) {
+        deps.setBuildSdk(VERSION_V)
+        val cb = TestableNetworkCallback()
+        cm.registerNetworkCallback(NetworkRequest.Builder().clearCapabilities().build(), cb)
+
+        // Set up a local agent that should forward its traffic to the best wifi upstream.
+        val localAgent = Agent(nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_LOCAL_NETWORK),
+                lp = lp("local0"),
+                lnc = LocalNetworkConfig.Builder()
+                .setUpstreamSelector(NetworkRequest.Builder()
+                        .addTransportType(TRANSPORT_WIFI)
+                        .build())
+                .build(),
+                score = FromS(NetworkScore.Builder()
+                        .setKeepConnectedReason(KEEP_CONNECTED_LOCAL_NETWORK)
+                        .build())
+        )
+        localAgent.connect()
+
+        cb.expectAvailableCallbacks(localAgent.network, validated = false)
+
+        val wifiAgent = Agent(lp = lp("wifi0"),
+                nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_INTERNET))
+        wifiAgent.connect()
+
+        cb.expectAvailableCallbacks(wifiAgent.network, validated = false)
+
+        clearInvocations(netd)
+        val inOrder = inOrder(netd)
+        wifiAgent.unregisterAfterReplacement(LONG_TIMEOUT_MS)
+        waitForIdle()
+        inOrder.verify(netd).ipfwdRemoveInterfaceForward("local0", "wifi0")
+        inOrder.verify(netd).networkDestroy(wifiAgent.network.netId)
+
+        val wifiIface2 = if (sameIfaceName) "wifi0" else "wifi1"
+        val wifiAgent2 = Agent(lp = lp(wifiIface2),
+                nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_INTERNET))
+        wifiAgent2.connect()
+
+        cb.expectAvailableCallbacks(wifiAgent2.network, validated = false)
+        cb.expect<Lost> { it.network == wifiAgent.network }
+
+        inOrder.verify(netd).ipfwdAddInterfaceForward("local0", wifiIface2)
+        if (sameIfaceName) {
+            inOrder.verify(netd, never()).ipfwdRemoveInterfaceForward(any(), any())
+        }
+    }
+
+    @Test
+    fun testUnregisterUpstreamAfterReplacement_neverReplaced() {
+        deps.setBuildSdk(VERSION_V)
+        val cb = TestableNetworkCallback()
+        cm.registerNetworkCallback(NetworkRequest.Builder().clearCapabilities().build(), cb)
+
+        // Set up a local agent that should forward its traffic to the best wifi upstream.
+        val localAgent = Agent(nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_LOCAL_NETWORK),
+                lp = lp("local0"),
+                lnc = LocalNetworkConfig.Builder()
+                        .setUpstreamSelector(NetworkRequest.Builder()
+                                .addTransportType(TRANSPORT_WIFI)
+                                .build())
+                        .build(),
+                score = FromS(NetworkScore.Builder()
+                        .setKeepConnectedReason(KEEP_CONNECTED_LOCAL_NETWORK)
+                        .build())
+        )
+        localAgent.connect()
+
+        cb.expectAvailableCallbacks(localAgent.network, validated = false)
+
+        val wifiAgent = Agent(lp = lp("wifi0"),
+                nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_INTERNET))
+        wifiAgent.connect()
+
+        cb.expectAvailableCallbacksUnvalidated(wifiAgent)
+
+        clearInvocations(netd)
+        wifiAgent.unregisterAfterReplacement(TIMEOUT_MS.toInt())
+        waitForIdle()
+        verify(netd).networkDestroy(wifiAgent.network.netId)
+        verify(netd).ipfwdRemoveInterfaceForward("local0", "wifi0")
+
+        cb.expect<Lost> { it.network == wifiAgent.network }
+    }
+
+    @Test
+    fun testUnregisterLocalAgentAfterReplacement() {
+        deps.setBuildSdk(VERSION_V)
+
+        val localCb = TestableNetworkCallback()
+        cm.requestNetwork(NetworkRequest.Builder().clearCapabilities()
+                .addCapability(NET_CAPABILITY_LOCAL_NETWORK)
+                .build(),
+                localCb)
+
+        val cb = TestableNetworkCallback()
+        cm.registerNetworkCallback(NetworkRequest.Builder().clearCapabilities().build(), cb)
+
+        val localNc = nc(TRANSPORT_WIFI, NET_CAPABILITY_LOCAL_NETWORK)
+        val lnc = LocalNetworkConfig.Builder()
+                .setUpstreamSelector(NetworkRequest.Builder()
+                        .addTransportType(TRANSPORT_WIFI)
+                        .build())
+                .build()
+        val localScore = FromS(NetworkScore.Builder().build())
+
+        // Set up a local agent that should forward its traffic to the best wifi upstream.
+        val localAgent = Agent(nc = localNc, lp = lp("local0"), lnc = lnc, score = localScore)
+        localAgent.connect()
+
+        localCb.expectAvailableCallbacks(localAgent.network, validated = false)
+        cb.expectAvailableCallbacks(localAgent.network, validated = false)
+
+        val wifiAgent = Agent(lp = lp("wifi0"), nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_INTERNET))
+        wifiAgent.connect()
+
+        cb.expectAvailableCallbacksUnvalidated(wifiAgent)
+
+        verify(netd).ipfwdAddInterfaceForward("local0", "wifi0")
+
+        localAgent.unregisterAfterReplacement(LONG_TIMEOUT_MS)
+
+        val localAgent2 = Agent(nc = localNc, lp = lp("local0"), lnc = lnc, score = localScore)
+        localAgent2.connect()
+
+        localCb.expectAvailableCallbacks(localAgent2.network, validated = false)
+        cb.expectAvailableCallbacks(localAgent2.network, validated = false)
+        cb.expect<Lost> { it.network == localAgent.network }
+    }
+
+    @Test
+    fun testDestroyedNetworkAsSelectedUpstream() {
+        deps.setBuildSdk(VERSION_V)
+        val cb = TestableNetworkCallback()
+        cm.registerNetworkCallback(NetworkRequest.Builder().clearCapabilities().build(), cb)
+
+        val wifiAgent = Agent(lp = lp("wifi0"), nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_INTERNET))
+        wifiAgent.connect()
+        cb.expectAvailableCallbacksUnvalidated(wifiAgent)
+
+        // Set up a local agent that should forward its traffic to the best wifi upstream.
+        val localAgent = Agent(nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_LOCAL_NETWORK),
+                lp = lp("local0"),
+                lnc = LocalNetworkConfig.Builder()
+                        .setUpstreamSelector(NetworkRequest.Builder()
+                                .addTransportType(TRANSPORT_WIFI)
+                                .build())
+                        .build(),
+                score = FromS(NetworkScore.Builder()
+                        .setKeepConnectedReason(KEEP_CONNECTED_LOCAL_NETWORK)
+                        .build())
+        )
+
+        // ...but destroy the wifi agent before connecting it
+        wifiAgent.unregisterAfterReplacement(LONG_TIMEOUT_MS)
+
+        localAgent.connect()
+        cb.expectAvailableCallbacks(localAgent.network, validated = false)
+
+        verify(netd).ipfwdAddInterfaceForward("local0", "wifi0")
+        verify(netd).ipfwdRemoveInterfaceForward("local0", "wifi0")
+    }
+
+    @Test
+    fun testForwardingRules() {
+        deps.setBuildSdk(VERSION_V)
+        // Set up a local agent that should forward its traffic to the best DUN upstream.
+        val lnc = LocalNetworkConfig.Builder()
+                .setUpstreamSelector(NetworkRequest.Builder()
+                        .addCapability(NET_CAPABILITY_DUN)
+                        .build())
+                .build()
+        val localAgent = Agent(nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_LOCAL_NETWORK),
+                lp = lp("local0"),
+                lnc = lnc,
+                score = FromS(NetworkScore.Builder()
+                        .setKeepConnectedReason(KEEP_CONNECTED_LOCAL_NETWORK)
+                        .build())
+        )
+        localAgent.connect()
+
+        val wifiAgent = Agent(score = keepScore(), lp = lp("wifi0"),
+                nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_INTERNET))
+        val cellAgentDun = Agent(score = keepScore(), lp = lp("cell0"),
+                nc = nc(TRANSPORT_CELLULAR, NET_CAPABILITY_INTERNET, NET_CAPABILITY_DUN))
+        val wifiAgentDun = Agent(score = keepScore(), lp = lp("wifi1"),
+                nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_INTERNET, NET_CAPABILITY_DUN))
+
+        val inOrder = inOrder(netd)
+        inOrder.verify(netd, never()).ipfwdAddInterfaceForward(any(), any())
+
+        wifiAgent.connect()
+        inOrder.verify(netd, never()).ipfwdAddInterfaceForward(any(), any())
+
+        cellAgentDun.connect()
+        inOrder.verify(netd).ipfwdEnableForwarding(any())
+        inOrder.verify(netd).ipfwdAddInterfaceForward("local0", "cell0")
+
+        wifiAgentDun.connect()
+        inOrder.verify(netd).ipfwdRemoveInterfaceForward("local0", "cell0")
+        inOrder.verify(netd).ipfwdAddInterfaceForward("local0", "wifi1")
+
+        // Make sure sending the same config again doesn't do anything
+        repeat(5) {
+            localAgent.sendLocalNetworkConfig(lnc)
+        }
+        inOrder.verifyNoMoreInteractions()
+
+        wifiAgentDun.disconnect()
+        inOrder.verify(netd).ipfwdRemoveInterfaceForward("local0", "wifi1")
+        // This can take a little bit of time because it needs to wait for the rematch
+        inOrder.verify(netd, timeout(MEDIUM_TIMEOUT_MS)).ipfwdAddInterfaceForward("local0", "cell0")
+
+        cellAgentDun.disconnect()
+        inOrder.verify(netd).ipfwdRemoveInterfaceForward("local0", "cell0")
+        inOrder.verify(netd).ipfwdDisableForwarding(any())
+
+        val wifiAgentDun2 = Agent(score = keepScore(), lp = lp("wifi2"),
+                nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_INTERNET, NET_CAPABILITY_DUN))
+        wifiAgentDun2.connect()
+        inOrder.verify(netd).ipfwdEnableForwarding(any())
+        inOrder.verify(netd).ipfwdAddInterfaceForward("local0", "wifi2")
+
+        localAgent.disconnect()
+        inOrder.verify(netd).ipfwdRemoveInterfaceForward("local0", "wifi2")
+        inOrder.verify(netd).ipfwdDisableForwarding(any())
+    }
 }
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt
index f903e51..013a749 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt
@@ -168,6 +168,7 @@
     }
 
     fun unregisterAfterReplacement(timeoutMs: Int) = agent.unregisterAfterReplacement(timeoutMs)
+
     fun sendLocalNetworkConfig(lnc: LocalNetworkConfig) = agent.sendLocalNetworkConfig(lnc)
     fun sendNetworkCapabilities(nc: NetworkCapabilities) = agent.sendNetworkCapabilities(nc)
 }