Merge "BpfMap.java - add requires S+ annotation" into main
diff --git a/framework/src/android/net/BpfNetMapsReader.java b/framework/src/android/net/BpfNetMapsReader.java
index 37c58f0..4ab6d3e 100644
--- a/framework/src/android/net/BpfNetMapsReader.java
+++ b/framework/src/android/net/BpfNetMapsReader.java
@@ -17,6 +17,9 @@
 package android.net;
 
 import static android.net.BpfNetMapsConstants.CONFIGURATION_MAP_PATH;
+import static android.net.BpfNetMapsConstants.DATA_SAVER_ENABLED;
+import static android.net.BpfNetMapsConstants.DATA_SAVER_ENABLED_KEY;
+import static android.net.BpfNetMapsConstants.DATA_SAVER_ENABLED_MAP_PATH;
 import static android.net.BpfNetMapsConstants.HAPPY_BOX_MATCH;
 import static android.net.BpfNetMapsConstants.PENALTY_BOX_MATCH;
 import static android.net.BpfNetMapsConstants.UID_OWNER_MAP_PATH;
@@ -33,14 +36,15 @@
 import android.os.ServiceSpecificException;
 import android.system.ErrnoException;
 import android.system.Os;
+import android.util.Log;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.modules.utils.build.SdkLevel;
 import com.android.net.module.util.BpfMap;
 import com.android.net.module.util.IBpfMap;
-import com.android.net.module.util.Struct;
 import com.android.net.module.util.Struct.S32;
 import com.android.net.module.util.Struct.U32;
+import com.android.net.module.util.Struct.U8;
 
 /**
  * A helper class to *read* java BpfMaps.
@@ -48,6 +52,8 @@
  */
 @RequiresApi(Build.VERSION_CODES.TIRAMISU)  // BPF maps were only mainlined in T
 public class BpfNetMapsReader {
+    private static final String TAG = BpfNetMapsReader.class.getSimpleName();
+
     // Locally store the handle of bpf maps. The FileDescriptors are statically cached inside the
     // BpfMap implementation.
 
@@ -57,6 +63,7 @@
     // Bpf map to store per uid traffic control configurations.
     // See {@link UidOwnerValue} for more detail.
     private final IBpfMap<S32, UidOwnerValue> mUidOwnerMap;
+    private final IBpfMap<S32, U8> mDataSaverEnabledMap;
     private final Dependencies mDeps;
 
     // Bitmaps for calculating whether a given uid is blocked by firewall chains.
@@ -104,6 +111,7 @@
         mDeps = deps;
         mConfigurationMap = mDeps.getConfigurationMap();
         mUidOwnerMap = mDeps.getUidOwnerMap();
+        mDataSaverEnabledMap = mDeps.getDataSaverEnabledMap();
     }
 
     /**
@@ -130,6 +138,16 @@
                 throw new IllegalStateException("Cannot open uid owner map", e);
             }
         }
+
+        /** Opens the data saver enabled map read-only, or throws IllegalStateException. */
+        public IBpfMap<S32, U8> getDataSaverEnabledMap() {
+            try {
+                return new BpfMap<>(DATA_SAVER_ENABLED_MAP_PATH, BpfMap.BPF_F_RDONLY, S32.class,
+                        U8.class);
+            } catch (ErrnoException e) {
+                throw new IllegalStateException("Cannot open data saver enabled map", e);
+            }
+        }
     }
 
     /**
@@ -171,12 +189,12 @@
      *                                  cause of the failure.
      */
     public static boolean isChainEnabled(
-            final IBpfMap<Struct.S32, Struct.U32> configurationMap, final int chain) {
+            final IBpfMap<S32, U32> configurationMap, final int chain) {
         throwIfPreT("isChainEnabled is not available on pre-T devices");
 
         final long match = getMatchByFirewallChain(chain);
         try {
-            final Struct.U32 config = configurationMap.getValue(UID_RULES_CONFIGURATION_KEY);
+            final U32 config = configurationMap.getValue(UID_RULES_CONFIGURATION_KEY);
             return (config.val & match) != 0;
         } catch (ErrnoException e) {
             throw new ServiceSpecificException(e.errno,
@@ -195,14 +213,14 @@
      * @throws ServiceSpecificException      in case of failure, with an error code indicating the
      *                                       cause of the failure.
      */
-    public static int getUidRule(final IBpfMap<Struct.S32, UidOwnerValue> uidOwnerMap,
+    public static int getUidRule(final IBpfMap<S32, UidOwnerValue> uidOwnerMap,
             final int chain, final int uid) {
         throwIfPreT("getUidRule is not available on pre-T devices");
 
         final long match = getMatchByFirewallChain(chain);
         final boolean isAllowList = isFirewallAllowList(chain);
         try {
-            final UidOwnerValue uidMatch = uidOwnerMap.getValue(new Struct.S32(uid));
+            final UidOwnerValue uidMatch = uidOwnerMap.getValue(new S32(uid));
             final boolean isMatchEnabled = uidMatch != null && (uidMatch.rule & match) != 0;
             return isMatchEnabled == isAllowList ? FIREWALL_RULE_ALLOW : FIREWALL_RULE_DENY;
         } catch (ErrnoException e) {
@@ -249,4 +267,29 @@
         if ((uidMatch & HAPPY_BOX_MATCH) != 0) return false;
         return isDataSaverEnabled;
     }
+
+    /**
+     * Returns whether Data Saver is enabled, as read from the data saver enabled bpf map.
+     *
+     * NOTE(review): the map value is dereferenced unchecked; confirm the key always exists.
+     * @return true if the value stored under DATA_SAVER_ENABLED_KEY equals DATA_SAVER_ENABLED.
+     * @throws ServiceSpecificException if the map read fails, carrying the failure's errno.
+     */
+    public boolean getDataSaverEnabled() {
+        throwIfPreT("getDataSaverEnabled is not available on pre-T devices");
+
+        // Note that this is not expected to be called until V given that it relies on the
+        // counterpart platform solution to set data saver status to bpf.
+        // See {@code NetworkManagementService#setDataSaverModeEnabled}.
+        if (!SdkLevel.isAtLeastV()) {
+            Log.wtf(TAG, "getDataSaverEnabled is not expected to be called on pre-V devices");
+        }
+
+        try {
+            return mDataSaverEnabledMap.getValue(DATA_SAVER_ENABLED_KEY).val == DATA_SAVER_ENABLED;
+        } catch (ErrnoException e) {
+            throw new ServiceSpecificException(e.errno, "Unable to get data saver: "
+                    + Os.strerror(e.errno));
+        }
+    }
 }
diff --git a/framework/src/android/net/ConnectivityManager.java b/framework/src/android/net/ConnectivityManager.java
index 57ecf49..a934ddb 100644
--- a/framework/src/android/net/ConnectivityManager.java
+++ b/framework/src/android/net/ConnectivityManager.java
@@ -83,6 +83,7 @@
 
 import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
+import com.android.modules.utils.build.SdkLevel;
 
 import libcore.net.event.NetworkEventDispatcher;
 
@@ -6371,10 +6372,13 @@
         final BpfNetMapsReader reader = BpfNetMapsReader.getInstance();
 
         final boolean isDataSaverEnabled;
-        // TODO: For U-QPR3+ devices, get data saver status from bpf configuration map directly.
-        final DataSaverStatusTracker dataSaverStatusTracker =
-                DataSaverStatusTracker.getInstance(mContext);
-        isDataSaverEnabled = dataSaverStatusTracker.getDataSaverEnabled();
+        if (SdkLevel.isAtLeastV()) {
+            isDataSaverEnabled = reader.getDataSaverEnabled();
+        } else {
+            final DataSaverStatusTracker dataSaverStatusTracker =
+                    DataSaverStatusTracker.getInstance(mContext);
+            isDataSaverEnabled = dataSaverStatusTracker.getDataSaverEnabled();
+        }
 
         return reader.isUidNetworkingBlocked(uid, isNetworkMetered, isDataSaverEnabled);
     }
diff --git a/service/Android.bp b/service/Android.bp
index 82f64ba..76741bc 100644
--- a/service/Android.bp
+++ b/service/Android.bp
@@ -107,10 +107,6 @@
         "-Werror",
         "-Wno-unused-parameter",
         "-Wthread-safety",
-
-        // AServiceManager_waitForService is available on only 31+, but it's still safe for Thread
-        // service because it's enabled on only 34+
-        "-Wno-unguarded-availability",
     ],
     srcs: [
         ":services.connectivity-netstats-jni-sources",
diff --git a/service/jni/com_android_server_ServiceManagerWrapper.cpp b/service/jni/com_android_server_ServiceManagerWrapper.cpp
index 0cd58f4..0e32726 100644
--- a/service/jni/com_android_server_ServiceManagerWrapper.cpp
+++ b/service/jni/com_android_server_ServiceManagerWrapper.cpp
@@ -25,7 +25,13 @@
 static jobject com_android_server_ServiceManagerWrapper_waitForService(
         JNIEnv* env, jobject clazz, jstring serviceName) {
     ScopedUtfChars name(env, serviceName);
+
+// AServiceManager_waitForService is only available on 31+, but calling it here is still safe
+// because the Thread service is only enabled on 34+.
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wunguarded-availability"
     return AIBinder_toJavaBinder(env, AServiceManager_waitForService(name.c_str()));
+#pragma clang diagnostic pop
 }
 
 /*
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index f904cd1..ba9ea86 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -6262,8 +6262,10 @@
                     if (!networkFound) return;
 
                     if (underpinnedNetworkFound) {
+                        final NetworkCapabilities underpinnedNc =
+                                getNetworkCapabilitiesInternal(underpinnedNetwork);
                         mKeepaliveTracker.handleMonitorAutomaticKeepalive(ki,
-                                underpinnedNetwork.netId);
+                                underpinnedNetwork.netId, underpinnedNc.getUids());
                     } else {
                         // If no underpinned network, then make sure the keepalive is running.
                         mKeepaliveTracker.handleMaybeResumeKeepalive(ki);
@@ -10083,6 +10085,45 @@
         // Process default network changes if applicable.
         processDefaultNetworkChanges(changes);
 
+        // Update forwarding rules for the upstreams of local networks. Do this before sending
+        // onAvailable so that by the time onAvailable is sent the forwarding rules are set up.
+        // Don't send CALLBACK_LOCAL_NETWORK_INFO_CHANGED yet though: it should be sent after
+        // onAvailable so clients know what network the change is about. Store the affected agents
+        // in a list that's only allocated if necessary (because it's almost never necessary).
+        ArrayList<NetworkAgentInfo> localInfoChangedAgents = null;
+        for (final NetworkAgentInfo nai : mNetworkAgentInfos) {
+            if (!nai.isLocalNetwork()) continue;
+            final NetworkRequest nr = nai.localNetworkConfig.getUpstreamSelector();
+            if (null == nr) continue; // No upstream for this local network
+            final NetworkRequestInfo nri = mNetworkRequests.get(nr);
+            final NetworkReassignment.RequestReassignment change = changes.getReassignment(nri);
+            if (null == change) continue; // No change in upstreams for this network
+            final String fromIface = nai.linkProperties.getInterfaceName();
+            if (!hasSameInterfaceName(change.mOldNetwork, change.mNewNetwork)
+                    || change.mOldNetwork.isDestroyed()) {
+                // There can be a change with the same interface name if the new network is the
+                // replacement for the old network that was unregisteredAfterReplacement.
+                try {
+                    if (null != change.mOldNetwork) {
+                        mRoutingCoordinatorService.removeInterfaceForward(fromIface,
+                                change.mOldNetwork.linkProperties.getInterfaceName());
+                    }
+                    // If the new upstream is already destroyed, there is no point in setting up
+                    // a forward (in fact, it might forward to the interface for some new network !)
+                    // Later when the upstream disconnects CS will try to remove the forward, which
+                    // is ignored with a benign log by RoutingCoordinatorService.
+                    if (null != change.mNewNetwork && !change.mNewNetwork.isDestroyed()) {
+                        mRoutingCoordinatorService.addInterfaceForward(fromIface,
+                                change.mNewNetwork.linkProperties.getInterfaceName());
+                    }
+                } catch (final RemoteException e) {
+                    loge("Can't update forwarding rules", e);
+                }
+            }
+            if (null == localInfoChangedAgents) localInfoChangedAgents = new ArrayList<>();
+            localInfoChangedAgents.add(nai);
+        }
+
         // Notify requested networks are available after the default net is switched, but
         // before LegacyTypeTracker sends legacy broadcasts
         for (final NetworkReassignment.RequestReassignment event :
@@ -10131,38 +10172,12 @@
             notifyNetworkLosing(nai, now);
         }
 
-        // Update forwarding rules for the upstreams of local networks. Do this after sending
-        // onAvailable so that clients understand what network this is about.
-        for (final NetworkAgentInfo nai : mNetworkAgentInfos) {
-            if (!nai.isLocalNetwork()) continue;
-            final NetworkRequest nr = nai.localNetworkConfig.getUpstreamSelector();
-            if (null == nr) continue; // No upstream for this local network
-            final NetworkRequestInfo nri = mNetworkRequests.get(nr);
-            final NetworkReassignment.RequestReassignment change = changes.getReassignment(nri);
-            if (null == change) continue; // No change in upstreams for this network
-            final String fromIface = nai.linkProperties.getInterfaceName();
-            if (!hasSameInterfaceName(change.mOldNetwork, change.mNewNetwork)
-                    || change.mOldNetwork.isDestroyed()) {
-                // There can be a change with the same interface name if the new network is the
-                // replacement for the old network that was unregisteredAfterReplacement.
-                try {
-                    if (null != change.mOldNetwork) {
-                        mRoutingCoordinatorService.removeInterfaceForward(fromIface,
-                                change.mOldNetwork.linkProperties.getInterfaceName());
-                    }
-                    // If the new upstream is already destroyed, there is no point in setting up
-                    // a forward (in fact, it might forward to the interface for some new network !)
-                    // Later when the upstream disconnects CS will try to remove the forward, which
-                    // is ignored with a benign log by RoutingCoordinatorService.
-                    if (null != change.mNewNetwork && !change.mNewNetwork.isDestroyed()) {
-                        mRoutingCoordinatorService.addInterfaceForward(fromIface,
-                                change.mNewNetwork.linkProperties.getInterfaceName());
-                    }
-                } catch (final RemoteException e) {
-                    loge("Can't update forwarding rules", e);
-                }
+        // Send LOCAL_NETWORK_INFO_CHANGED callbacks now that onAvailable and onLost have been sent.
+        if (null != localInfoChangedAgents) {
+            for (final NetworkAgentInfo nai : localInfoChangedAgents) {
+                notifyNetworkCallbacks(nai,
+                        ConnectivityManager.CALLBACK_LOCAL_NETWORK_INFO_CHANGED);
             }
-            notifyNetworkCallbacks(nai, ConnectivityManager.CALLBACK_LOCAL_NETWORK_INFO_CHANGED);
         }
 
         updateLegacyTypeTrackerAndVpnLockdownForRematch(changes, nais);
diff --git a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
index bba132f..8036ae9 100644
--- a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
@@ -53,6 +53,7 @@
 import android.util.LocalLog;
 import android.util.Log;
 import android.util.Pair;
+import android.util.Range;
 import android.util.SparseArray;
 
 import com.android.internal.annotations.VisibleForTesting;
@@ -77,6 +78,7 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
+import java.util.Set;
 
 /**
  * Manages automatic on/off socket keepalive requests.
@@ -373,26 +375,27 @@
      * Determine if any state transition is needed for the specific automatic keepalive.
      */
     public void handleMonitorAutomaticKeepalive(@NonNull final AutomaticOnOffKeepalive ki,
-            final int vpnNetId) {
+            final int vpnNetId, @NonNull Set<Range<Integer>> vpnUidRanges) {
         // Might happen if the automatic keepalive was removed by the app just as the alarm fires.
         if (!mAutomaticOnOffKeepalives.contains(ki)) return;
         if (STATE_ALWAYS_ON == ki.mAutomaticOnOffState) {
             throw new IllegalStateException("Should not monitor non-auto keepalive");
         }
 
-        handleMonitorTcpConnections(ki, vpnNetId);
+        handleMonitorTcpConnections(ki, vpnNetId, vpnUidRanges);
     }
 
     /**
      * Determine if disable or re-enable keepalive is needed or not based on TCP sockets status.
      */
-    private void handleMonitorTcpConnections(@NonNull AutomaticOnOffKeepalive ki, int vpnNetId) {
+    private void handleMonitorTcpConnections(@NonNull AutomaticOnOffKeepalive ki, int vpnNetId,
+            @NonNull Set<Range<Integer>> vpnUidRanges) {
         // Might happen if the automatic keepalive was removed by the app just as the alarm fires.
         if (!mAutomaticOnOffKeepalives.contains(ki)) return;
         if (STATE_ALWAYS_ON == ki.mAutomaticOnOffState) {
             throw new IllegalStateException("Should not monitor non-auto keepalive");
         }
-        if (!isAnyTcpSocketConnected(vpnNetId)) {
+        if (!isAnyTcpSocketConnected(vpnNetId, vpnUidRanges)) {
             // No TCP socket exists. Stop keepalive if ENABLED, and remain SUSPENDED if currently
             // SUSPENDED.
             if (ki.mAutomaticOnOffState == STATE_ENABLED) {
@@ -744,7 +747,7 @@
     }
 
     @VisibleForTesting
-    boolean isAnyTcpSocketConnected(int netId) {
+    boolean isAnyTcpSocketConnected(int netId, @NonNull Set<Range<Integer>> vpnUidRanges) {
         FileDescriptor fd = null;
 
         try {
@@ -757,7 +760,8 @@
 
             // Send request for each IP family
             for (final int family : ADDRESS_FAMILIES) {
-                if (isAnyTcpSocketConnectedForFamily(fd, family, networkMark, networkMask)) {
+                if (isAnyTcpSocketConnectedForFamily(
+                        fd, family, networkMark, networkMask, vpnUidRanges)) {
                     return true;
                 }
             }
@@ -771,7 +775,8 @@
     }
 
     private boolean isAnyTcpSocketConnectedForFamily(FileDescriptor fd, int family, int networkMark,
-            int networkMask) throws ErrnoException, InterruptedIOException {
+            int networkMask, @NonNull Set<Range<Integer>> vpnUidRanges)
+            throws ErrnoException, InterruptedIOException {
         ensureRunningOnHandlerThread();
         // Build SocketDiag messages and cache it.
         if (mSockDiagMsg.get(family) == null) {
diff --git a/tests/unit/java/android/net/BpfNetMapsReaderTest.kt b/tests/unit/java/android/net/BpfNetMapsReaderTest.kt
index 258e422..9de7f4d 100644
--- a/tests/unit/java/android/net/BpfNetMapsReaderTest.kt
+++ b/tests/unit/java/android/net/BpfNetMapsReaderTest.kt
@@ -16,6 +16,9 @@
 
 package android.net
 
+import android.net.BpfNetMapsConstants.DATA_SAVER_DISABLED
+import android.net.BpfNetMapsConstants.DATA_SAVER_ENABLED
+import android.net.BpfNetMapsConstants.DATA_SAVER_ENABLED_KEY
 import android.net.BpfNetMapsConstants.DOZABLE_MATCH
 import android.net.BpfNetMapsConstants.HAPPY_BOX_MATCH
 import android.net.BpfNetMapsConstants.PENALTY_BOX_MATCH
@@ -26,6 +29,8 @@
 import com.android.net.module.util.IBpfMap
 import com.android.net.module.util.Struct.S32
 import com.android.net.module.util.Struct.U32
+import com.android.net.module.util.Struct.U8
+import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
 import com.android.testutils.DevSdkIgnoreRunner
 import com.android.testutils.TestBpfMap
@@ -33,6 +38,7 @@
 import kotlin.test.assertEquals
 import kotlin.test.assertFalse
 import kotlin.test.assertTrue
+import org.junit.Rule
 import org.junit.Test
 import org.junit.runner.RunWith
 
@@ -45,17 +51,24 @@
 @RunWith(DevSdkIgnoreRunner::class)
 @IgnoreUpTo(VERSION_CODES.S_V2)
 class BpfNetMapsReaderTest {
+    @Rule
+    @JvmField
+    val ignoreRule = DevSdkIgnoreRule()
+
     private val testConfigurationMap: IBpfMap<S32, U32> = TestBpfMap()
     private val testUidOwnerMap: IBpfMap<S32, UidOwnerValue> = TestBpfMap()
+    private val testDataSaverEnabledMap: IBpfMap<S32, U8> = TestBpfMap()
     private val bpfNetMapsReader = BpfNetMapsReader(
-        TestDependencies(testConfigurationMap, testUidOwnerMap))
+        TestDependencies(testConfigurationMap, testUidOwnerMap, testDataSaverEnabledMap))
 
     class TestDependencies(
         private val configMap: IBpfMap<S32, U32>,
-        private val uidOwnerMap: IBpfMap<S32, UidOwnerValue>
+        private val uidOwnerMap: IBpfMap<S32, UidOwnerValue>,
+        private val dataSaverEnabledMap: IBpfMap<S32, U8>
     ) : BpfNetMapsReader.Dependencies() {
         override fun getConfigurationMap() = configMap
         override fun getUidOwnerMap() = uidOwnerMap
+        override fun getDataSaverEnabledMap() = dataSaverEnabledMap
     }
 
     private fun doTestIsChainEnabled(chain: Int) {
@@ -199,4 +212,13 @@
         assertFalse(isUidNetworkingBlocked(TEST_UID2))
         assertFalse(isUidNetworkingBlocked(TEST_UID3))
     }
+
+    @IgnoreUpTo(VERSION_CODES.UPSIDE_DOWN_CAKE)
+    @Test
+    fun testGetDataSaverEnabled() {
+        testDataSaverEnabledMap.updateEntry(DATA_SAVER_ENABLED_KEY, U8(DATA_SAVER_DISABLED))
+        assertFalse(bpfNetMapsReader.dataSaverEnabled)
+        testDataSaverEnabledMap.updateEntry(DATA_SAVER_ENABLED_KEY, U8(DATA_SAVER_ENABLED))
+        assertTrue(bpfNetMapsReader.dataSaverEnabled)
+    }
 }
diff --git a/tests/unit/java/android/net/ConnectivityManagerTest.java b/tests/unit/java/android/net/ConnectivityManagerTest.java
index b8c5447..0082af2 100644
--- a/tests/unit/java/android/net/ConnectivityManagerTest.java
+++ b/tests/unit/java/android/net/ConnectivityManagerTest.java
@@ -90,6 +90,7 @@
 import com.android.testutils.DevSdkIgnoreRunner;
 
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
@@ -102,6 +103,8 @@
 @SmallTest
 @DevSdkIgnoreRule.IgnoreUpTo(VERSION_CODES.R)
 public class ConnectivityManagerTest {
+    @Rule
+    public final DevSdkIgnoreRule mIgnoreRule = new DevSdkIgnoreRule();
     private static final int TIMEOUT_MS = 30_000;
     private static final int SHORT_TIMEOUT_MS = 150;
 
@@ -524,6 +527,7 @@
                     + " attempts", ref.get());
     }
 
+    @DevSdkIgnoreRule.IgnoreAfter(VERSION_CODES.UPSIDE_DOWN_CAKE)
     @Test
     public void testDataSaverStatusTracker() {
         mockService(NetworkPolicyManager.class, Context.NETWORK_POLICY_SERVICE, mNpm);
diff --git a/tests/unit/java/com/android/internal/net/VpnProfileTest.java b/tests/unit/java/com/android/internal/net/VpnProfileTest.java
index b2dff2e..acae7d2 100644
--- a/tests/unit/java/com/android/internal/net/VpnProfileTest.java
+++ b/tests/unit/java/com/android/internal/net/VpnProfileTest.java
@@ -26,6 +26,7 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNotSame;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 
@@ -311,4 +312,12 @@
         decoded.password = profile.password;
         assertEquals(profile, decoded);
     }
+
+    @Test
+    public void testClone() {
+        final VpnProfile profile = getSampleIkev2Profile(DUMMY_PROFILE_KEY);
+        final VpnProfile clone = profile.clone();
+        assertEquals(profile, clone);
+        assertNotSame(profile, clone);
+    }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
index 8e19c01..10a0982 100644
--- a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
@@ -72,7 +72,9 @@
 import android.os.SystemClock;
 import android.telephony.SubscriptionManager;
 import android.test.suitebuilder.annotation.SmallTest;
+import android.util.ArraySet;
 import android.util.Log;
+import android.util.Range;
 
 import androidx.annotation.NonNull;
 import androidx.annotation.Nullable;
@@ -102,7 +104,9 @@
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
+import java.util.Set;
 
 @RunWith(DevSdkIgnoreRunner.class)
 @SmallTest
@@ -232,6 +236,9 @@
     private static final byte[] TEST_RESPONSE_BYTES =
             HexEncoding.decode(TEST_RESPONSE_HEX.toCharArray(), false);
 
+    private static final Set<Range<Integer>> TEST_UID_RANGES =
+            new ArraySet<>(Arrays.asList(new Range<>(10000, 99999)));
+
     private static class TestKeepaliveInfo {
         private static List<Socket> sOpenSockets = new ArrayList<>();
 
@@ -409,28 +416,28 @@
     public void testIsAnyTcpSocketConnected_runOnNonHandlerThread() throws Exception {
         setupResponseWithSocketExisting();
         assertThrows(IllegalStateException.class,
-                () -> mAOOKeepaliveTracker.isAnyTcpSocketConnected(TEST_NETID));
+                () -> mAOOKeepaliveTracker.isAnyTcpSocketConnected(TEST_NETID, TEST_UID_RANGES));
     }
 
     @Test
     public void testIsAnyTcpSocketConnected_withTargetNetId() throws Exception {
         setupResponseWithSocketExisting();
         assertTrue(visibleOnHandlerThread(mTestHandler,
-                () -> mAOOKeepaliveTracker.isAnyTcpSocketConnected(TEST_NETID)));
+                () -> mAOOKeepaliveTracker.isAnyTcpSocketConnected(TEST_NETID, TEST_UID_RANGES)));
     }
 
     @Test
     public void testIsAnyTcpSocketConnected_withIncorrectNetId() throws Exception {
         setupResponseWithSocketExisting();
         assertFalse(visibleOnHandlerThread(mTestHandler,
-                () -> mAOOKeepaliveTracker.isAnyTcpSocketConnected(OTHER_NETID)));
+                () -> mAOOKeepaliveTracker.isAnyTcpSocketConnected(OTHER_NETID, TEST_UID_RANGES)));
     }
 
     @Test
     public void testIsAnyTcpSocketConnected_noSocketExists() throws Exception {
         setupResponseWithoutSocketExisting();
         assertFalse(visibleOnHandlerThread(mTestHandler,
-                () -> mAOOKeepaliveTracker.isAnyTcpSocketConnected(TEST_NETID)));
+                () -> mAOOKeepaliveTracker.isAnyTcpSocketConnected(TEST_NETID, TEST_UID_RANGES)));
     }
 
     private void triggerEventKeepalive(int slot, int reason) {
@@ -474,14 +481,16 @@
         setupResponseWithoutSocketExisting();
         visibleOnHandlerThread(
                 mTestHandler,
-                () -> mAOOKeepaliveTracker.handleMonitorAutomaticKeepalive(autoKi, TEST_NETID));
+                () -> mAOOKeepaliveTracker.handleMonitorAutomaticKeepalive(
+                        autoKi, TEST_NETID, TEST_UID_RANGES));
     }
 
     private void doResumeKeepalive(AutomaticOnOffKeepalive autoKi) throws Exception {
         setupResponseWithSocketExisting();
         visibleOnHandlerThread(
                 mTestHandler,
-                () -> mAOOKeepaliveTracker.handleMonitorAutomaticKeepalive(autoKi, TEST_NETID));
+                () -> mAOOKeepaliveTracker.handleMonitorAutomaticKeepalive(
+                        autoKi, TEST_NETID, TEST_UID_RANGES));
     }
 
     private void doStopKeepalive(AutomaticOnOffKeepalive autoKi) throws Exception {
diff --git a/tests/unit/java/com/android/server/connectivity/VpnTest.java b/tests/unit/java/com/android/server/connectivity/VpnTest.java
index ff801e5..109c132 100644
--- a/tests/unit/java/com/android/server/connectivity/VpnTest.java
+++ b/tests/unit/java/com/android/server/connectivity/VpnTest.java
@@ -74,7 +74,9 @@
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.ArgumentMatchers.longThat;
 import static org.mockito.Mockito.after;
+import static org.mockito.Mockito.atLeast;
 import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.clearInvocations;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doNothing;
@@ -188,6 +190,7 @@
 import org.mockito.AdditionalAnswers;
 import org.mockito.Answers;
 import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
 import org.mockito.InOrder;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
@@ -204,6 +207,7 @@
 import java.net.UnknownHostException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -314,6 +318,8 @@
     @Mock DeviceIdleInternal mDeviceIdleInternal;
     private final VpnProfile mVpnProfile;
 
+    @Captor private ArgumentCaptor<Collection<Range<Integer>>> mUidRangesCaptor;
+
     private IpSecManager mIpSecManager;
     private TestDeps mTestDeps;
 
@@ -1093,37 +1099,53 @@
         }
     }
 
-    private Vpn prepareVpnForVerifyAppExclusionList() throws Exception {
-        final Vpn vpn = createVpn(AppOpsManager.OPSTR_ACTIVATE_PLATFORM_VPN);
+    private String startVpnForVerifyAppExclusionList(Vpn vpn) throws Exception {
         when(mVpnProfileStore.get(vpn.getProfileNameForPackage(TEST_VPN_PKG)))
                 .thenReturn(mVpnProfile.encode());
         when(mVpnProfileStore.get(PRIMARY_USER_APP_EXCLUDE_KEY))
                 .thenReturn(HexDump.hexStringToByteArray(PKGS_BYTES));
-
-        vpn.startVpnProfile(TEST_VPN_PKG);
+        final String sessionKey = vpn.startVpnProfile(TEST_VPN_PKG);
+        final Set<Range<Integer>> uidRanges = vpn.createUserAndRestrictedProfilesRanges(
+                PRIMARY_USER.id, null /* allowedApplications */, Arrays.asList(PKGS));
+        verify(mConnectivityManager).setVpnDefaultForUids(eq(sessionKey), eq(uidRanges));
+        clearInvocations(mConnectivityManager);
         verify(mVpnProfileStore).get(eq(vpn.getProfileNameForPackage(TEST_VPN_PKG)));
         vpn.mNetworkAgent = mMockNetworkAgent;
+
+        return sessionKey;
+    }
+
+    private Vpn prepareVpnForVerifyAppExclusionList() throws Exception {
+        final Vpn vpn = createVpn(AppOpsManager.OPSTR_ACTIVATE_PLATFORM_VPN);
+        startVpnForVerifyAppExclusionList(vpn);
+
         return vpn;
     }
 
     @Test
     public void testSetAndGetAppExclusionList() throws Exception {
-        final Vpn vpn = prepareVpnForVerifyAppExclusionList();
+        final Vpn vpn = createVpn(AppOpsManager.OPSTR_ACTIVATE_PLATFORM_VPN);
+        final String sessionKey = startVpnForVerifyAppExclusionList(vpn);
         verify(mVpnProfileStore, never()).put(eq(PRIMARY_USER_APP_EXCLUDE_KEY), any());
         vpn.setAppExclusionList(TEST_VPN_PKG, Arrays.asList(PKGS));
         verify(mVpnProfileStore)
                 .put(eq(PRIMARY_USER_APP_EXCLUDE_KEY),
                      eq(HexDump.hexStringToByteArray(PKGS_BYTES)));
-        assertEquals(vpn.createUserAndRestrictedProfilesRanges(
-                PRIMARY_USER.id, null, Arrays.asList(PKGS)),
-                vpn.mNetworkCapabilities.getUids());
+        final Set<Range<Integer>> uidRanges = vpn.createUserAndRestrictedProfilesRanges(
+                PRIMARY_USER.id, null /* allowedApplications */, Arrays.asList(PKGS));
+        verify(mConnectivityManager).setVpnDefaultForUids(eq(sessionKey), eq(uidRanges));
+        assertEquals(uidRanges, vpn.mNetworkCapabilities.getUids());
         assertEquals(Arrays.asList(PKGS), vpn.getAppExclusionList(TEST_VPN_PKG));
     }
 
     @Test
     public void testRefreshPlatformVpnAppExclusionList_updatesExcludedUids() throws Exception {
-        final Vpn vpn = prepareVpnForVerifyAppExclusionList();
+        final Vpn vpn = createVpn(AppOpsManager.OPSTR_ACTIVATE_PLATFORM_VPN);
+        final String sessionKey = startVpnForVerifyAppExclusionList(vpn);
         vpn.setAppExclusionList(TEST_VPN_PKG, Arrays.asList(PKGS));
+        final Set<Range<Integer>> uidRanges = vpn.createUserAndRestrictedProfilesRanges(
+                PRIMARY_USER.id, null /* allowedApplications */, Arrays.asList(PKGS));
+        verify(mConnectivityManager).setVpnDefaultForUids(eq(sessionKey), eq(uidRanges));
         verify(mMockNetworkAgent).doSendNetworkCapabilities(any());
         assertEquals(Arrays.asList(PKGS), vpn.getAppExclusionList(TEST_VPN_PKG));
 
@@ -1132,33 +1154,36 @@
         // Remove one of the package
         List<Integer> newExcludedUids = toList(PKG_UIDS);
         newExcludedUids.remove((Integer) PKG_UIDS[0]);
+        Set<Range<Integer>> newUidRanges = makeVpnUidRangeSet(PRIMARY_USER.id, newExcludedUids);
         sPackages.remove(PKGS[0]);
         vpn.refreshPlatformVpnAppExclusionList();
 
         // List in keystore is not changed, but UID for the removed packages is no longer exempted.
         assertEquals(Arrays.asList(PKGS), vpn.getAppExclusionList(TEST_VPN_PKG));
-        assertEquals(makeVpnUidRangeSet(PRIMARY_USER.id, newExcludedUids),
-                vpn.mNetworkCapabilities.getUids());
+        assertEquals(newUidRanges, vpn.mNetworkCapabilities.getUids());
         ArgumentCaptor<NetworkCapabilities> ncCaptor =
                 ArgumentCaptor.forClass(NetworkCapabilities.class);
         verify(mMockNetworkAgent).doSendNetworkCapabilities(ncCaptor.capture());
-        assertEquals(makeVpnUidRangeSet(PRIMARY_USER.id, newExcludedUids),
-                ncCaptor.getValue().getUids());
+        assertEquals(newUidRanges, ncCaptor.getValue().getUids());
+        verify(mConnectivityManager).setVpnDefaultForUids(eq(sessionKey), eq(newUidRanges));
 
         reset(mMockNetworkAgent);
 
         // Add the package back
         newExcludedUids.add(PKG_UIDS[0]);
+        newUidRanges = makeVpnUidRangeSet(PRIMARY_USER.id, newExcludedUids);
         sPackages.put(PKGS[0], PKG_UIDS[0]);
         vpn.refreshPlatformVpnAppExclusionList();
 
         // List in keystore is not changed and the uid list should be updated in the net cap.
         assertEquals(Arrays.asList(PKGS), vpn.getAppExclusionList(TEST_VPN_PKG));
-        assertEquals(makeVpnUidRangeSet(PRIMARY_USER.id, newExcludedUids),
-                vpn.mNetworkCapabilities.getUids());
+        assertEquals(newUidRanges, vpn.mNetworkCapabilities.getUids());
         verify(mMockNetworkAgent).doSendNetworkCapabilities(ncCaptor.capture());
-        assertEquals(makeVpnUidRangeSet(PRIMARY_USER.id, newExcludedUids),
-                ncCaptor.getValue().getUids());
+        assertEquals(newUidRanges, ncCaptor.getValue().getUids());
+
+        // The uid ranges match the original setAppExclusionList call, so this is the second call.
+        verify(mConnectivityManager, times(2))
+                .setVpnDefaultForUids(eq(sessionKey), eq(newUidRanges));
     }
 
     private List<Range<Integer>> makeVpnUidRange(int userId, List<Integer> excludedAppIdList) {
@@ -1784,6 +1809,9 @@
                 .getRedactedLinkPropertiesForPackage(any(), anyInt(), anyString());
 
         final String sessionKey = vpn.startVpnProfile(TEST_VPN_PKG);
+        final Set<Range<Integer>> uidRanges = rangeSet(PRIMARY_USER_RANGE);
+        // This is triggered by the Ikev2VpnRunner constructor.
+        verify(mConnectivityManager, times(1)).setVpnDefaultForUids(eq(sessionKey), eq(uidRanges));
         final NetworkCallback cb = triggerOnAvailableAndGetCallback();
 
         verifyInterfaceSetCfgWithFlags(IF_STATE_UP);
@@ -1792,6 +1820,8 @@
         // state
         verify(mIkev2SessionCreator, timeout(TEST_TIMEOUT_MS))
                 .createIkeSession(any(), any(), any(), any(), captor.capture(), any());
+        // This is triggered by Vpn#startOrMigrateIkeSession().
+        verify(mConnectivityManager, times(2)).setVpnDefaultForUids(eq(sessionKey), eq(uidRanges));
         reset(mIkev2SessionCreator);
         // For network lost case, the process should be triggered by calling onLost(), which is the
         // same process with the real case.
@@ -1811,16 +1841,43 @@
                 new String[] {TEST_VPN_PKG}, new VpnProfileState(VpnProfileState.STATE_CONNECTING,
                         sessionKey, false /* alwaysOn */, false /* lockdown */));
         if (errorType == VpnManager.ERROR_CLASS_NOT_RECOVERABLE) {
+            verify(mConnectivityManager).setVpnDefaultForUids(eq(sessionKey),
+                    eq(Collections.EMPTY_LIST));
             verify(mConnectivityManager, timeout(TEST_TIMEOUT_MS))
                     .unregisterNetworkCallback(eq(cb));
         } else if (errorType == VpnManager.ERROR_CLASS_RECOVERABLE
                 // Vpn won't retry when there is no usable underlying network.
                 && errorCode != VpnManager.ERROR_CODE_NETWORK_LOST) {
             int retryIndex = 0;
-            final IkeSessionCallback ikeCb2 = verifyRetryAndGetNewIkeCb(retryIndex++);
+            // First failure occurred above.
+            final IkeSessionCallback retryCb = verifyRetryAndGetNewIkeCb(retryIndex++);
+            // Trigger 2 more failures to let the retry delay increase to 5s.
+            mExecutor.execute(() -> retryCb.onClosedWithException(exception));
+            final IkeSessionCallback retryCb2 = verifyRetryAndGetNewIkeCb(retryIndex++);
+            mExecutor.execute(() -> retryCb2.onClosedWithException(exception));
+            final IkeSessionCallback retryCb3 = verifyRetryAndGetNewIkeCb(retryIndex++);
 
-            mExecutor.execute(() -> ikeCb2.onClosedWithException(exception));
+            // setVpnDefaultForUids may be called again but the uidRanges should not change.
+            verify(mConnectivityManager, atLeast(2)).setVpnDefaultForUids(eq(sessionKey),
+                    mUidRangesCaptor.capture());
+            final List<Collection<Range<Integer>>> capturedUidRanges =
+                    mUidRangesCaptor.getAllValues();
+            for (int i = 2; i < capturedUidRanges.size(); i++) {
+                // Assert equality regardless of element order.
+                assertTrue(
+                        "uid ranges should not be modified. Expected: " + uidRanges
+                                + ", actual: " + capturedUidRanges.get(i),
+                        capturedUidRanges.get(i).containsAll(uidRanges)
+                                && capturedUidRanges.get(i).size() == uidRanges.size());
+            }
+
+            // A fourth failure will cause the retry delay to be greater than 5s.
+            mExecutor.execute(() -> retryCb3.onClosedWithException(exception));
             verifyRetryAndGetNewIkeCb(retryIndex++);
+
+            // The VPN network preference will be cleared when the retry delay is greater than 5s.
+            verify(mConnectivityManager).setVpnDefaultForUids(eq(sessionKey),
+                    eq(Collections.EMPTY_LIST));
         }
     }
 
@@ -2103,7 +2160,9 @@
         when(mVpnProfileStore.get(vpn.getProfileNameForPackage(TEST_VPN_PKG)))
                 .thenReturn(vpnProfile.encode());
 
-        vpn.startVpnProfile(TEST_VPN_PKG);
+        final String sessionKey = vpn.startVpnProfile(TEST_VPN_PKG);
+        final Set<Range<Integer>> uidRanges = Collections.singleton(PRIMARY_USER_RANGE);
+        verify(mConnectivityManager).setVpnDefaultForUids(eq(sessionKey), eq(uidRanges));
         final NetworkCallback nwCb = triggerOnAvailableAndGetCallback(underlyingNetworkCaps);
         // There are 4 interactions with the executor.
         // - Network available
@@ -2196,6 +2255,7 @@
         final PlatformVpnSnapshot vpnSnapShot = verifySetupPlatformVpn(
                 createIkeConfig(createIkeConnectInfo(), true /* isMobikeEnabled */));
         vpnSnapShot.vpn.mVpnRunner.exitVpnRunner();
+        verify(mConnectivityManager).setVpnDefaultForUids(anyString(), eq(Collections.EMPTY_LIST));
     }
 
     @Test
@@ -3104,6 +3164,20 @@
         assertThrows(UnsupportedOperationException.class, () -> startLegacyVpn(vpn, profile));
     }
 
+    @Test
+    public void testStartLegacyVpnModifyProfile_TypePSK() throws Exception {
+        setMockedUsers(PRIMARY_USER);
+        final Vpn vpn = createVpn(PRIMARY_USER.id);
+        final Ikev2VpnProfile ikev2VpnProfile =
+                new Ikev2VpnProfile.Builder(TEST_VPN_SERVER, TEST_VPN_IDENTITY)
+                        .setAuthPsk(TEST_VPN_PSK)
+                        .build();
+        final VpnProfile profile = ikev2VpnProfile.toVpnProfile();
+
+        startLegacyVpn(vpn, profile);
+        assertEquals(profile, ikev2VpnProfile.toVpnProfile());
+    }
+
     private void assertTransportInfoMatches(NetworkCapabilities nc, int type) {
         assertNotNull(nc);
         VpnTransportInfo ti = (VpnTransportInfo) nc.getTransportInfo();
@@ -3248,12 +3322,6 @@
         }
 
         @Override
-        public long getNextRetryDelayMs(int retryCount) {
-            // Simply return retryCount as the delay seconds for retrying.
-            return retryCount * 1000;
-        }
-
-        @Override
         public long getValidationFailRecoveryMs(int retryCount) {
             // Simply return retryCount as the delay seconds for retrying.
             return retryCount * 100L;
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSKeepConnectedTest.kt b/tests/unit/java/com/android/server/connectivityservice/CSKeepConnectedTest.kt
index 2126a09..a753922 100644
--- a/tests/unit/java/com/android/server/connectivityservice/CSKeepConnectedTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/CSKeepConnectedTest.kt
@@ -47,8 +47,9 @@
         val keepConnectedAgent = Agent(nc = nc, score = FromS(NetworkScore.Builder()
                 .setKeepConnectedReason(KEEP_CONNECTED_LOCAL_NETWORK)
                 .build()),
-                lnc = LocalNetworkConfig.Builder().build())
-        val dontKeepConnectedAgent = Agent(nc = nc, lnc = LocalNetworkConfig.Builder().build())
+                lnc = FromS(LocalNetworkConfig.Builder().build()))
+        val dontKeepConnectedAgent = Agent(nc = nc,
+                lnc = FromS(LocalNetworkConfig.Builder().build()))
         doTestKeepConnected(keepConnectedAgent, dontKeepConnectedAgent)
     }
 
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentCreationTests.kt b/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentCreationTests.kt
index cfc3a3d..6add6b9 100644
--- a/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentCreationTests.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentCreationTests.kt
@@ -49,7 +49,7 @@
 private fun keepConnectedScore() =
         FromS(NetworkScore.Builder().setKeepConnectedReason(KEEP_CONNECTED_FOR_TEST).build())
 
-private fun defaultLnc() = LocalNetworkConfig.Builder().build()
+private fun defaultLnc() = FromS(LocalNetworkConfig.Builder().build())
 
 @RunWith(DevSdkIgnoreRunner::class)
 @SmallTest
@@ -124,8 +124,7 @@
                     lnc = null)
         }
         assertFailsWith<IllegalArgumentException> {
-            Agent(nc = NetworkCapabilities.Builder().build(),
-                    lnc = LocalNetworkConfig.Builder().build())
+            Agent(nc = NetworkCapabilities.Builder().build(), lnc = defaultLnc())
         }
     }
 }
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentTests.kt b/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentTests.kt
index 1dab548..ad21bf5 100644
--- a/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentTests.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentTests.kt
@@ -94,7 +94,7 @@
         }
         assertFailsWith<IllegalArgumentException> {
             Agent(nc = NetworkCapabilities.Builder().build(),
-                    lnc = LocalNetworkConfig.Builder().build())
+                    lnc = FromS(LocalNetworkConfig.Builder().build()))
         }
     }
 
@@ -110,7 +110,7 @@
         val agent = Agent(nc = NetworkCapabilities.Builder()
                 .addCapability(NET_CAPABILITY_LOCAL_NETWORK)
                 .build(),
-                lnc = LocalNetworkConfig.Builder().build())
+                lnc = FromS(LocalNetworkConfig.Builder().build()))
         agent.connect()
         cb.expectAvailableCallbacks(agent.network, validated = false)
         agent.sendNetworkCapabilities(NetworkCapabilities.Builder().build())
@@ -141,7 +141,7 @@
         val localAgent = Agent(
                 nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_LOCAL_NETWORK),
                 lp = lp("local0"),
-                lnc = LocalNetworkConfig.Builder().build(),
+                lnc = FromS(LocalNetworkConfig.Builder().build()),
         )
         localAgent.connect()
 
@@ -194,11 +194,11 @@
         // Set up a local agent that should forward its traffic to the best wifi upstream.
         val localAgent = Agent(nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_LOCAL_NETWORK),
                 lp = lp("local0"),
-                lnc = LocalNetworkConfig.Builder()
+                lnc = FromS(LocalNetworkConfig.Builder()
                 .setUpstreamSelector(NetworkRequest.Builder()
                         .addTransportType(TRANSPORT_WIFI)
                         .build())
-                .build(),
+                .build()),
                 score = FromS(NetworkScore.Builder()
                         .setKeepConnectedReason(KEEP_CONNECTED_LOCAL_NETWORK)
                         .build())
@@ -247,11 +247,11 @@
         // Set up a local agent that should forward its traffic to the best wifi upstream.
         val localAgent = Agent(nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_LOCAL_NETWORK),
                 lp = lp("local0"),
-                lnc = LocalNetworkConfig.Builder()
+                lnc = FromS(LocalNetworkConfig.Builder()
                         .setUpstreamSelector(NetworkRequest.Builder()
                                 .addTransportType(TRANSPORT_WIFI)
                                 .build())
-                        .build(),
+                        .build()),
                 score = FromS(NetworkScore.Builder()
                         .setKeepConnectedReason(KEEP_CONNECTED_LOCAL_NETWORK)
                         .build())
@@ -293,11 +293,11 @@
         cm.registerNetworkCallback(NetworkRequest.Builder().clearCapabilities().build(), cb)
 
         val localNc = nc(TRANSPORT_WIFI, NET_CAPABILITY_LOCAL_NETWORK)
-        val lnc = LocalNetworkConfig.Builder()
+        val lnc = FromS(LocalNetworkConfig.Builder()
                 .setUpstreamSelector(NetworkRequest.Builder()
                         .addTransportType(TRANSPORT_WIFI)
                         .build())
-                .build()
+                .build())
         val localScore = FromS(NetworkScore.Builder().build())
 
         // Set up a local agent that should forward its traffic to the best wifi upstream.
@@ -346,11 +346,11 @@
         wifiAgent.unregisterAfterReplacement(LONG_TIMEOUT_MS)
         val localAgent = Agent(nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_LOCAL_NETWORK),
                 lp = lp("local0"),
-                lnc = LocalNetworkConfig.Builder()
+                lnc = FromS(LocalNetworkConfig.Builder()
                         .setUpstreamSelector(NetworkRequest.Builder()
                                 .addTransportType(TRANSPORT_WIFI)
                                 .build())
-                        .build(),
+                        .build()),
                 score = FromS(NetworkScore.Builder()
                         .setKeepConnectedReason(KEEP_CONNECTED_LOCAL_NETWORK)
                         .build())
@@ -374,11 +374,11 @@
     fun testForwardingRules() {
         deps.setBuildSdk(VERSION_V)
         // Set up a local agent that should forward its traffic to the best DUN upstream.
-        val lnc = LocalNetworkConfig.Builder()
+        val lnc = FromS(LocalNetworkConfig.Builder()
                 .setUpstreamSelector(NetworkRequest.Builder()
                         .addCapability(NET_CAPABILITY_DUN)
                         .build())
-                .build()
+                .build())
         val localAgent = Agent(nc = nc(TRANSPORT_WIFI, NET_CAPABILITY_LOCAL_NETWORK),
                 lp = lp("local0"),
                 lnc = lnc,
@@ -426,7 +426,7 @@
 
         // Make sure sending the same config again doesn't do anything
         repeat(5) {
-            localAgent.sendLocalNetworkConfig(lnc)
+            localAgent.sendLocalNetworkConfig(lnc.value)
         }
         inOrder.verifyNoMoreInteractions()
 
@@ -501,13 +501,13 @@
         }
 
         // Set up a local agent.
-        val lnc = LocalNetworkConfig.Builder().apply {
+        val lnc = FromS(LocalNetworkConfig.Builder().apply {
             if (haveUpstream) {
                 setUpstreamSelector(NetworkRequest.Builder()
                         .addTransportType(TRANSPORT_WIFI)
                         .build())
             }
-        }.build()
+        }.build())
         val localAgent = Agent(nc = nc(TRANSPORT_THREAD, NET_CAPABILITY_LOCAL_NETWORK),
                 lp = lp("local0"),
                 lnc = lnc,
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt
index 013a749..d41c742 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSAgentWrapper.kt
@@ -69,7 +69,7 @@
         nac: NetworkAgentConfig,
         val nc: NetworkCapabilities,
         val lp: LinkProperties,
-        val lnc: LocalNetworkConfig?,
+        val lnc: FromS<LocalNetworkConfig>?,
         val score: FromS<NetworkScore>,
         val provider: NetworkProvider?
 ) : TestableNetworkCallback.HasNetwork {
@@ -101,7 +101,7 @@
         // Create the actual agent. NetworkAgent is abstract, so make an anonymous subclass.
         if (deps.isAtLeastS()) {
             agent = object : NetworkAgent(context, csHandlerThread.looper, TAG,
-                    nc, lp, lnc, score.value, nac, provider) {}
+                    nc, lp, lnc?.value, score.value, nac, provider) {}
         } else {
             agent = object : NetworkAgent(context, csHandlerThread.looper, TAG,
                     nc, lp, 50 /* score */, nac, provider) {}
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
index f21a428..4f5cfc0 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
@@ -321,7 +321,7 @@
             nc: NetworkCapabilities = defaultNc(),
             nac: NetworkAgentConfig = emptyAgentConfig(nc.getLegacyType()),
             lp: LinkProperties = defaultLp(),
-            lnc: LocalNetworkConfig? = null,
+            lnc: FromS<LocalNetworkConfig>? = null,
             score: FromS<NetworkScore> = defaultScore(),
             provider: NetworkProvider? = null
     ) = CSAgentWrapper(context, deps, csHandlerThread, networkStack,