Merge "Bypass VPN lockdown for clat initialization"
diff --git a/bpf_progs/bpf_shared.h b/bpf_progs/bpf_shared.h
index 85b9f86..7b1106a 100644
--- a/bpf_progs/bpf_shared.h
+++ b/bpf_progs/bpf_shared.h
@@ -95,13 +95,13 @@
 
 // 'static' - otherwise these constants end up in .rodata in the resulting .o post compilation
 static const int COOKIE_UID_MAP_SIZE = 10000;
-static const int UID_COUNTERSET_MAP_SIZE = 2000;
+static const int UID_COUNTERSET_MAP_SIZE = 4000;
 static const int APP_STATS_MAP_SIZE = 10000;
 static const int STATS_MAP_SIZE = 5000;
 static const int IFACE_INDEX_NAME_MAP_SIZE = 1000;
 static const int IFACE_STATS_MAP_SIZE = 1000;
 static const int CONFIGURATION_MAP_SIZE = 2;
-static const int UID_OWNER_MAP_SIZE = 2000;
+static const int UID_OWNER_MAP_SIZE = 4000;
 
 #ifdef __cplusplus
 
diff --git a/bpf_progs/clatd.c b/bpf_progs/clatd.c
index a2214dc..3ac0603 100644
--- a/bpf_progs/clatd.c
+++ b/bpf_progs/clatd.c
@@ -38,8 +38,19 @@
 #include "bpf_shared.h"
 #include "clat_mark.h"
 
-// From kernel:include/net/ip.h
-#define IP_DF 0x4000  // Flag: "Don't Fragment"
+// IP flags. (from kernel's include/net/ip.h)
+#define IP_CE      0x8000  // Flag: "Congestion" (really reserved 'evil bit')
+#define IP_DF      0x4000  // Flag: "Don't Fragment"
+#define IP_MF      0x2000  // Flag: "More Fragments"
+#define IP_OFFSET  0x1FFF  // "Fragment Offset" part
+
+// from kernel's include/net/ipv6.h
+struct frag_hdr {
+    __u8   nexthdr;
+    __u8   reserved;        // always zero
+    __be16 frag_off;        // 13 bit offset, 2 bits zero, 1 bit "More Fragments"
+    __be32 identification;
+};
 
 DEFINE_BPF_MAP_GRW(clat_ingress6_map, HASH, ClatIngress6Key, ClatIngress6Value, 16, AID_SYSTEM)
 
@@ -89,7 +100,11 @@
 
     if (!v) return TC_ACT_PIPE;
 
-    switch (ip6->nexthdr) {
+    __u8 proto = ip6->nexthdr;
+    __be16 ip_id = 0;
+    __be16 frag_off = htons(IP_DF);
+
+    switch (proto) {
         case IPPROTO_TCP:  // For TCP & UDP the checksum neutrality of the chosen IPv6
         case IPPROTO_UDP:  // address means there is no need to update their checksums.
         case IPPROTO_GRE:  // We do not need to bother looking at GRE/ESP headers,
@@ -114,14 +129,14 @@
             .version = 4,                                                      // u4
             .ihl = sizeof(struct iphdr) / sizeof(__u32),                       // u4
             .tos = (ip6->priority << 4) + (ip6->flow_lbl[0] >> 4),             // u8
-            .tot_len = htons(ntohs(ip6->payload_len) + sizeof(struct iphdr)),  // u16
-            .id = 0,                                                           // u16
-            .frag_off = htons(IP_DF),                                          // u16
+            .tot_len = htons(ntohs(ip6->payload_len) + sizeof(struct iphdr)),  // be16
+            .id = ip_id,                                                       // be16
+            .frag_off = frag_off,                                              // be16
             .ttl = ip6->hop_limit,                                             // u8
-            .protocol = ip6->nexthdr,                                          // u8
+            .protocol = proto,                                                 // u8
             .check = 0,                                                        // u16
-            .saddr = ip6->saddr.in6_u.u6_addr32[3],                            // u32
-            .daddr = v->local4.s_addr,                                         // u32
+            .saddr = ip6->saddr.in6_u.u6_addr32[3],                            // be32
+            .daddr = v->local4.s_addr,                                         // be32
     };
 
     // Calculate the IPv4 one's complement checksum of the IPv4 header.
@@ -211,11 +226,6 @@
 
 DEFINE_BPF_MAP_GRW(clat_egress4_map, HASH, ClatEgress4Key, ClatEgress4Value, 16, AID_SYSTEM)
 
-DEFINE_BPF_PROG("schedcls/egress4/clat_ether", AID_ROOT, AID_SYSTEM, sched_cls_egress4_clat_ether)
-(struct __sk_buff* skb) {
-    return TC_ACT_PIPE;
-}
-
 DEFINE_BPF_PROG("schedcls/egress4/clat_rawip", AID_ROOT, AID_SYSTEM, sched_cls_egress4_clat_rawip)
 (struct __sk_buff* skb) {
     // Must be meta-ethernet IPv4 frame
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index e732001..195f046 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -368,7 +368,7 @@
     private static final int DEFAULT_NASCENT_DELAY_MS = 5_000;
 
     // The maximum value for the blocking validation result, in milliseconds.
-    public static final int MAX_VALIDATION_FAILURE_BLOCKING_TIME_MS = 10000;
+    public static final int MAX_VALIDATION_IGNORE_AFTER_ROAM_TIME_MS = 10000;
 
     // The maximum number of network request allowed per uid before an exception is thrown.
     @VisibleForTesting
@@ -4241,12 +4241,18 @@
         return nai.isCreated() && !nai.isDestroyed();
     }
 
-    private boolean shouldIgnoreValidationFailureAfterRoam(NetworkAgentInfo nai) {
+    @VisibleForTesting
+    boolean shouldIgnoreValidationFailureAfterRoam(NetworkAgentInfo nai) {
         // T+ devices should use unregisterAfterReplacement.
         if (SdkLevel.isAtLeastT()) return false;
+
+        // If the network never roamed, return false. The check below is not sufficient if time
+        // since boot is less than blockTimeOut, though that's extremely unlikely to happen.
+        if (nai.lastRoamTime == 0) return false;
+
         final long blockTimeOut = Long.valueOf(mResources.get().getInteger(
                 R.integer.config_validationFailureAfterRoamIgnoreTimeMillis));
-        if (blockTimeOut <= MAX_VALIDATION_FAILURE_BLOCKING_TIME_MS
+        if (blockTimeOut <= MAX_VALIDATION_IGNORE_AFTER_ROAM_TIME_MS
                 && blockTimeOut >= 0) {
             final long currentTimeMs = SystemClock.elapsedRealtime();
             long timeSinceLastRoam = currentTimeMs - nai.lastRoamTime;
@@ -9774,8 +9780,8 @@
         return ((VpnTransportInfo) ti).getType();
     }
 
-    private void maybeUpdateWifiRoamTimestamp(NetworkAgentInfo nai, NetworkCapabilities nc) {
-        if (nai == null) return;
+    private void maybeUpdateWifiRoamTimestamp(@NonNull NetworkAgentInfo nai,
+            @NonNull NetworkCapabilities nc) {
         final TransportInfo prevInfo = nai.networkCapabilities.getTransportInfo();
         final TransportInfo newInfo = nc.getTransportInfo();
         if (!(prevInfo instanceof WifiInfo) || !(newInfo instanceof WifiInfo)) {
diff --git a/service/src/com/android/server/connectivity/ClatCoordinator.java b/service/src/com/android/server/connectivity/ClatCoordinator.java
index fecddd7..42b3827 100644
--- a/service/src/com/android/server/connectivity/ClatCoordinator.java
+++ b/service/src/com/android/server/connectivity/ClatCoordinator.java
@@ -113,12 +113,12 @@
         return "/sys/fs/bpf/net_shared/map_clatd_clat_" + which + "_map";
     }
 
-    private static String makeProgPath(boolean ingress, boolean ether) {
-        String path = "/sys/fs/bpf/net_shared/prog_clatd_schedcls_"
-                + (ingress ? "ingress6" : "egress4")
-                + "_clat_"
+    private static final String CLAT_EGRESS4_RAWIP_PROG_PATH =
+            "/sys/fs/bpf/net_shared/prog_clatd_schedcls_egress4_clat_rawip";
+
+    private static String makeIngressProgPath(boolean ether) {
+        return "/sys/fs/bpf/net_shared/prog_clatd_schedcls_ingress6_clat_"
                 + (ether ? "ether" : "rawip");
-        return path;
     }
 
     @NonNull
@@ -478,7 +478,7 @@
             // tc filter add dev .. egress prio 4 protocol ip bpf object-pinned /sys/fs/bpf/...
             // direct-action
             mDeps.tcFilterAddDevBpf(tracker.v4ifIndex, EGRESS, (short) PRIO_CLAT, (short) ETH_P_IP,
-                    makeProgPath(EGRESS, RAWIP));
+                    CLAT_EGRESS4_RAWIP_PROG_PATH);
         } catch (IOException e) {
             Log.e(TAG, "tc filter add dev (" + tracker.v4ifIndex + "[" + tracker.v4iface
                     + "]) egress prio PRIO_CLAT protocol ip failure: " + e);
@@ -504,7 +504,7 @@
             // tc filter add dev .. ingress prio 4 protocol ipv6 bpf object-pinned /sys/fs/bpf/...
             // direct-action
             mDeps.tcFilterAddDevBpf(tracker.ifIndex, INGRESS, (short) PRIO_CLAT,
-                    (short) ETH_P_IPV6, makeProgPath(INGRESS, isEthernet));
+                    (short) ETH_P_IPV6, makeIngressProgPath(isEthernet));
         } catch (IOException e) {
             Log.e(TAG, "tc filter add dev (" + tracker.ifIndex + "[" + tracker.iface
                     + "]) ingress prio PRIO_CLAT protocol ipv6 failure: " + e);
diff --git a/tests/mts/bpf_existence_test.cpp b/tests/mts/bpf_existence_test.cpp
index c7e8b97..bb07a98 100644
--- a/tests/mts/bpf_existence_test.cpp
+++ b/tests/mts/bpf_existence_test.cpp
@@ -100,7 +100,6 @@
     NETD "map_netd_uid_counterset_map",
     NETD "map_netd_uid_owner_map",
     NETD "map_netd_uid_permission_map",
-    SHARED "prog_clatd_schedcls_egress4_clat_ether",
     SHARED "prog_clatd_schedcls_egress4_clat_rawip",
     SHARED "prog_clatd_schedcls_ingress6_clat_ether",
     SHARED "prog_clatd_schedcls_ingress6_clat_rawip",
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 12478cb..2d8cf80 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -7156,17 +7156,38 @@
         mCm.unregisterNetworkCallback(networkCallback);
     }
 
-    private void expectNotifyNetworkStatus(List<Network> defaultNetworks, String defaultIface,
-            Integer vpnUid, String vpnIfname, List<String> underlyingIfaces) throws Exception {
+    private void expectNotifyNetworkStatus(List<Network> allNetworks, List<Network> defaultNetworks,
+            String defaultIface, Integer vpnUid, String vpnIfname, List<String> underlyingIfaces)
+            throws Exception {
         ArgumentCaptor<List<Network>> defaultNetworksCaptor = ArgumentCaptor.forClass(List.class);
         ArgumentCaptor<List<UnderlyingNetworkInfo>> vpnInfosCaptor =
                 ArgumentCaptor.forClass(List.class);
+        ArgumentCaptor<List<NetworkStateSnapshot>> snapshotsCaptor =
+                ArgumentCaptor.forClass(List.class);
 
         verify(mStatsManager, atLeastOnce()).notifyNetworkStatus(defaultNetworksCaptor.capture(),
-                any(List.class), eq(defaultIface), vpnInfosCaptor.capture());
+                snapshotsCaptor.capture(), eq(defaultIface), vpnInfosCaptor.capture());
 
         assertSameElements(defaultNetworks, defaultNetworksCaptor.getValue());
 
+        List<Network> snapshotNetworks = new ArrayList<Network>();
+        for (NetworkStateSnapshot ns : snapshotsCaptor.getValue()) {
+            snapshotNetworks.add(ns.getNetwork());
+        }
+        assertSameElements(allNetworks, snapshotNetworks);
+
+        if (defaultIface != null) {
+            assertNotNull(
+                    "Did not find interface " + defaultIface + " in call to notifyNetworkStatus",
+                    CollectionUtils.findFirst(snapshotsCaptor.getValue(), (ns) -> {
+                        final LinkProperties lp = ns.getLinkProperties();
+                        if (lp != null && TextUtils.equals(defaultIface, lp.getInterfaceName())) {
+                            return true;
+                        }
+                        return false;
+                    }));
+        }
+
         List<UnderlyingNetworkInfo> infos = vpnInfosCaptor.getValue();
         if (vpnUid != null) {
             assertEquals("Should have exactly one VPN:", 1, infos.size());
@@ -7181,28 +7202,38 @@
     }
 
     private void expectNotifyNetworkStatus(
-            List<Network> defaultNetworks, String defaultIface) throws Exception {
-        expectNotifyNetworkStatus(defaultNetworks, defaultIface, null, null, List.of());
+            List<Network> allNetworks, List<Network> defaultNetworks, String defaultIface)
+            throws Exception {
+        expectNotifyNetworkStatus(allNetworks, defaultNetworks, defaultIface, null, null,
+                List.of());
+    }
+
+    private List<Network> onlyCell() {
+        return List.of(mCellNetworkAgent.getNetwork());
+    }
+
+    private List<Network> onlyWifi() {
+        return List.of(mWiFiNetworkAgent.getNetwork());
+    }
+
+    private List<Network> cellAndWifi() {
+        return List.of(mCellNetworkAgent.getNetwork(), mWiFiNetworkAgent.getNetwork());
     }
 
     @Test
     public void testStatsIfacesChanged() throws Exception {
-        mCellNetworkAgent = new TestNetworkAgentWrapper(TRANSPORT_CELLULAR);
-        mWiFiNetworkAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
-
-        final List<Network> onlyCell = List.of(mCellNetworkAgent.getNetwork());
-        final List<Network> onlyWifi = List.of(mWiFiNetworkAgent.getNetwork());
-
         LinkProperties cellLp = new LinkProperties();
         cellLp.setInterfaceName(MOBILE_IFNAME);
         LinkProperties wifiLp = new LinkProperties();
         wifiLp.setInterfaceName(WIFI_IFNAME);
 
-        // Simple connection should have updated ifaces
+        mCellNetworkAgent = new TestNetworkAgentWrapper(TRANSPORT_CELLULAR, cellLp);
+        mWiFiNetworkAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
+
+        // Simple connection with initial LP should have updated ifaces.
         mCellNetworkAgent.connect(false);
-        mCellNetworkAgent.sendLinkProperties(cellLp);
         waitForIdle();
-        expectNotifyNetworkStatus(onlyCell, MOBILE_IFNAME);
+        expectNotifyNetworkStatus(onlyCell(), onlyCell(), MOBILE_IFNAME);
         reset(mStatsManager);
 
         // Default network switch should update ifaces.
@@ -7210,37 +7241,56 @@
         mWiFiNetworkAgent.sendLinkProperties(wifiLp);
         waitForIdle();
         assertEquals(wifiLp, mService.getActiveLinkProperties());
-        expectNotifyNetworkStatus(onlyWifi, WIFI_IFNAME);
+        expectNotifyNetworkStatus(cellAndWifi(), onlyWifi(), WIFI_IFNAME);
+        reset(mStatsManager);
+
+        // Disconnecting a network updates ifaces again. The soon-to-be disconnected interface is
+        // still in the list to ensure that stats are counted on that interface.
+        // TODO: this means that if anything else uses that interface for any other reason before
+        // notifyNetworkStatus is called again, traffic on that interface will be accounted to the
+        // disconnected network. This is likely a bug in ConnectivityService; it should probably
+        // call notifyNetworkStatus again without the disconnected network.
+        mCellNetworkAgent.disconnect();
+        waitForIdle();
+        expectNotifyNetworkStatus(cellAndWifi(), onlyWifi(), WIFI_IFNAME);
+        verifyNoMoreInteractions(mStatsManager);
+        reset(mStatsManager);
+
+        // Connecting a network updates ifaces even if the network doesn't become default.
+        mCellNetworkAgent = new TestNetworkAgentWrapper(TRANSPORT_CELLULAR, cellLp);
+        mCellNetworkAgent.connect(false);
+        waitForIdle();
+        expectNotifyNetworkStatus(cellAndWifi(), onlyWifi(), WIFI_IFNAME);
         reset(mStatsManager);
 
         // Disconnect should update ifaces.
         mWiFiNetworkAgent.disconnect();
         waitForIdle();
-        expectNotifyNetworkStatus(onlyCell, MOBILE_IFNAME);
+        expectNotifyNetworkStatus(onlyCell(), onlyCell(), MOBILE_IFNAME);
         reset(mStatsManager);
 
         // Metered change should update ifaces
         mCellNetworkAgent.addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_METERED);
         waitForIdle();
-        expectNotifyNetworkStatus(onlyCell, MOBILE_IFNAME);
+        expectNotifyNetworkStatus(onlyCell(), onlyCell(), MOBILE_IFNAME);
         reset(mStatsManager);
 
         mCellNetworkAgent.removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_METERED);
         waitForIdle();
-        expectNotifyNetworkStatus(onlyCell, MOBILE_IFNAME);
+        expectNotifyNetworkStatus(onlyCell(), onlyCell(), MOBILE_IFNAME);
         reset(mStatsManager);
 
         // Temp metered change shouldn't update ifaces
         mCellNetworkAgent.addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED);
         waitForIdle();
-        verify(mStatsManager, never()).notifyNetworkStatus(eq(onlyCell),
+        verify(mStatsManager, never()).notifyNetworkStatus(eq(onlyCell()),
                 any(List.class), eq(MOBILE_IFNAME), any(List.class));
         reset(mStatsManager);
 
         // Roaming change should update ifaces
         mCellNetworkAgent.addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING);
         waitForIdle();
-        expectNotifyNetworkStatus(onlyCell, MOBILE_IFNAME);
+        expectNotifyNetworkStatus(onlyCell(), onlyCell(), MOBILE_IFNAME);
         reset(mStatsManager);
 
         // Test VPNs.
@@ -7254,8 +7304,8 @@
                 List.of(mCellNetworkAgent.getNetwork(), mMockVpn.getNetwork());
 
         // A VPN with default (null) underlying networks sets the underlying network's interfaces...
-        expectNotifyNetworkStatus(cellAndVpn, MOBILE_IFNAME, Process.myUid(), VPN_IFNAME,
-                List.of(MOBILE_IFNAME));
+        expectNotifyNetworkStatus(cellAndVpn, cellAndVpn, MOBILE_IFNAME, Process.myUid(),
+                VPN_IFNAME, List.of(MOBILE_IFNAME));
 
         // ...and updates them as the default network switches.
         mWiFiNetworkAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
@@ -7264,15 +7314,16 @@
         final Network[] onlyNull = new Network[]{null};
         final List<Network> wifiAndVpn =
                 List.of(mWiFiNetworkAgent.getNetwork(), mMockVpn.getNetwork());
-        final List<Network> cellAndWifi =
-                List.of(mCellNetworkAgent.getNetwork(), mWiFiNetworkAgent.getNetwork());
+        final List<Network> cellWifiAndVpn =
+                List.of(mCellNetworkAgent.getNetwork(), mWiFiNetworkAgent.getNetwork(),
+                        mMockVpn.getNetwork());
         final Network[] cellNullAndWifi =
                 new Network[]{mCellNetworkAgent.getNetwork(), null, mWiFiNetworkAgent.getNetwork()};
 
         waitForIdle();
         assertEquals(wifiLp, mService.getActiveLinkProperties());
-        expectNotifyNetworkStatus(wifiAndVpn, WIFI_IFNAME, Process.myUid(), VPN_IFNAME,
-                List.of(WIFI_IFNAME));
+        expectNotifyNetworkStatus(cellWifiAndVpn, wifiAndVpn, WIFI_IFNAME, Process.myUid(),
+                VPN_IFNAME, List.of(WIFI_IFNAME));
         reset(mStatsManager);
 
         // A VPN that sets its underlying networks passes the underlying interfaces, and influences
@@ -7281,23 +7332,23 @@
         // MOBILE_IFNAME even though the default network is wifi.
         // TODO: fix this to pass in the actual default network interface. Whether or not the VPN
         // applies to the system server UID should not have any bearing on network stats.
-        mMockVpn.setUnderlyingNetworks(onlyCell.toArray(new Network[0]));
+        mMockVpn.setUnderlyingNetworks(onlyCell().toArray(new Network[0]));
         waitForIdle();
-        expectNotifyNetworkStatus(wifiAndVpn, MOBILE_IFNAME, Process.myUid(), VPN_IFNAME,
-                List.of(MOBILE_IFNAME));
+        expectNotifyNetworkStatus(cellWifiAndVpn, wifiAndVpn, MOBILE_IFNAME, Process.myUid(),
+                VPN_IFNAME, List.of(MOBILE_IFNAME));
         reset(mStatsManager);
 
-        mMockVpn.setUnderlyingNetworks(cellAndWifi.toArray(new Network[0]));
+        mMockVpn.setUnderlyingNetworks(cellAndWifi().toArray(new Network[0]));
         waitForIdle();
-        expectNotifyNetworkStatus(wifiAndVpn, MOBILE_IFNAME, Process.myUid(), VPN_IFNAME,
-                List.of(MOBILE_IFNAME, WIFI_IFNAME));
+        expectNotifyNetworkStatus(cellWifiAndVpn, wifiAndVpn, MOBILE_IFNAME, Process.myUid(),
+                VPN_IFNAME,  List.of(MOBILE_IFNAME, WIFI_IFNAME));
         reset(mStatsManager);
 
         // Null underlying networks are ignored.
         mMockVpn.setUnderlyingNetworks(cellNullAndWifi);
         waitForIdle();
-        expectNotifyNetworkStatus(wifiAndVpn, MOBILE_IFNAME, Process.myUid(), VPN_IFNAME,
-                List.of(MOBILE_IFNAME, WIFI_IFNAME));
+        expectNotifyNetworkStatus(cellWifiAndVpn, wifiAndVpn, MOBILE_IFNAME, Process.myUid(),
+                VPN_IFNAME,  List.of(MOBILE_IFNAME, WIFI_IFNAME));
         reset(mStatsManager);
 
         // If an underlying network disconnects, that interface should no longer be underlying.
@@ -7310,8 +7361,8 @@
         mCellNetworkAgent.disconnect();
         waitForIdle();
         assertNull(mService.getLinkProperties(mCellNetworkAgent.getNetwork()));
-        expectNotifyNetworkStatus(wifiAndVpn, MOBILE_IFNAME, Process.myUid(), VPN_IFNAME,
-                List.of(MOBILE_IFNAME, WIFI_IFNAME));
+        expectNotifyNetworkStatus(cellWifiAndVpn, wifiAndVpn, MOBILE_IFNAME, Process.myUid(),
+                VPN_IFNAME, List.of(MOBILE_IFNAME, WIFI_IFNAME));
 
         // Confirm that we never tell NetworkStatsService that cell is no longer the underlying
         // network for the VPN...
@@ -7346,25 +7397,25 @@
         // Also, for the same reason as above, the active interface passed in is null.
         mMockVpn.setUnderlyingNetworks(new Network[0]);
         waitForIdle();
-        expectNotifyNetworkStatus(wifiAndVpn, null);
+        expectNotifyNetworkStatus(wifiAndVpn, wifiAndVpn, null);
         reset(mStatsManager);
 
         // Specifying only a null underlying network is the same as no networks.
         mMockVpn.setUnderlyingNetworks(onlyNull);
         waitForIdle();
-        expectNotifyNetworkStatus(wifiAndVpn, null);
+        expectNotifyNetworkStatus(wifiAndVpn, wifiAndVpn, null);
         reset(mStatsManager);
 
         // Specifying networks that are all disconnected is the same as specifying no networks.
-        mMockVpn.setUnderlyingNetworks(onlyCell.toArray(new Network[0]));
+        mMockVpn.setUnderlyingNetworks(onlyCell().toArray(new Network[0]));
         waitForIdle();
-        expectNotifyNetworkStatus(wifiAndVpn, null);
+        expectNotifyNetworkStatus(wifiAndVpn, wifiAndVpn, null);
         reset(mStatsManager);
 
         // Passing in null again means follow the default network again.
         mMockVpn.setUnderlyingNetworks(null);
         waitForIdle();
-        expectNotifyNetworkStatus(wifiAndVpn, WIFI_IFNAME, Process.myUid(), VPN_IFNAME,
+        expectNotifyNetworkStatus(wifiAndVpn, wifiAndVpn, WIFI_IFNAME, Process.myUid(), VPN_IFNAME,
                 List.of(WIFI_IFNAME));
         reset(mStatsManager);
     }
@@ -7419,7 +7470,7 @@
         mWiFiNetworkAgent.connect(true /* validated */);
 
         final List<Network> none = List.of();
-        expectNotifyNetworkStatus(none, null);  // Wifi is not the default network
+        expectNotifyNetworkStatus(onlyWifi(), none, null);  // Wifi is not the default network
 
         // Create a virtual network based on the wifi network.
         final int ownerUid = 10042;
@@ -7444,7 +7495,9 @@
         assertFalse(nc.hasTransport(TRANSPORT_WIFI));
         assertFalse(nc.hasCapability(NET_CAPABILITY_NOT_METERED));
         final List<Network> onlyVcn = List.of(vcn.getNetwork());
-        expectNotifyNetworkStatus(onlyVcn, vcnIface, ownerUid, vcnIface, List.of(WIFI_IFNAME));
+        final List<Network> vcnAndWifi = List.of(vcn.getNetwork(), mWiFiNetworkAgent.getNetwork());
+        expectNotifyNetworkStatus(vcnAndWifi, onlyVcn, vcnIface, ownerUid, vcnIface,
+                List.of(WIFI_IFNAME));
 
         // Add NOT_METERED to the underlying network, check that it is not propagated.
         mWiFiNetworkAgent.addCapability(NET_CAPABILITY_NOT_METERED);
@@ -7467,7 +7520,10 @@
         nc = mCm.getNetworkCapabilities(vcn.getNetwork());
         assertFalse(nc.hasTransport(TRANSPORT_WIFI));
         assertFalse(nc.hasCapability(NET_CAPABILITY_NOT_ROAMING));
-        expectNotifyNetworkStatus(onlyVcn, vcnIface, ownerUid, vcnIface, List.of(MOBILE_IFNAME));
+        final List<Network> vcnWifiAndCell = List.of(vcn.getNetwork(),
+                mWiFiNetworkAgent.getNetwork(), mCellNetworkAgent.getNetwork());
+        expectNotifyNetworkStatus(vcnWifiAndCell, onlyVcn, vcnIface, ownerUid, vcnIface,
+                List.of(MOBILE_IFNAME));
     }
 
     @Test
@@ -16726,23 +16782,25 @@
         assertThrows(NullPointerException.class, () -> mService.unofferNetwork(null));
     }
 
-    public void doTestIgnoreValidationAfterRoam(final boolean enabled) throws Exception {
-        assumeFalse(SdkLevel.isAtLeastT());
-        doReturn(enabled ? 5000 : -1).when(mResources)
+    public void doTestIgnoreValidationAfterRoam(int resValue, final boolean enabled)
+            throws Exception {
+        doReturn(resValue).when(mResources)
                 .getInteger(R.integer.config_validationFailureAfterRoamIgnoreTimeMillis);
 
+        final String bssid1 = "AA:AA:AA:AA:AA:AA";
+        final String bssid2 = "BB:BB:BB:BB:BB:BB";
         mCellNetworkAgent = new TestNetworkAgentWrapper(TRANSPORT_CELLULAR);
         mCellNetworkAgent.connect(true);
         NetworkCapabilities wifiNc1 = new NetworkCapabilities()
                 .addCapability(NET_CAPABILITY_INTERNET)
+                .addCapability(NET_CAPABILITY_NOT_VPN)
+                .addCapability(NET_CAPABILITY_NOT_RESTRICTED)
+                .addCapability(NET_CAPABILITY_TRUSTED)
                 .addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
                 .addTransportType(TRANSPORT_WIFI)
-                .setTransportInfo(new WifiInfo.Builder().setBssid("AA:AA:AA:AA:AA:AA").build());
-        NetworkCapabilities wifiNc2 = new NetworkCapabilities()
-                .addCapability(NET_CAPABILITY_INTERNET)
-                .addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
-                .addTransportType(TRANSPORT_WIFI)
-                .setTransportInfo(new WifiInfo.Builder().setBssid("BB:BB:BB:BB:BB:BB").build());
+                .setTransportInfo(new WifiInfo.Builder().setBssid(bssid1).build());
+        NetworkCapabilities wifiNc2 = new NetworkCapabilities(wifiNc1)
+                .setTransportInfo(new WifiInfo.Builder().setBssid(bssid2).build());
         final LinkProperties wifiLp = new LinkProperties();
         wifiLp.setInterfaceName(WIFI_IFNAME);
         mWiFiNetworkAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI, wifiLp, wifiNc1);
@@ -16806,6 +16864,7 @@
             wifiNetworkCallback.expectCapabilitiesWithout(NET_CAPABILITY_PARTIAL_CONNECTIVITY,
                     mWiFiNetworkAgent);
         }
+        mCm.unregisterNetworkCallback(wifiNetworkCallback);
 
         // Wi-Fi roams from wifiNc1 to wifiNc2, and now becomes really invalid. If validation
         // failures after roam are not ignored, this will cause cell to become the default network.
@@ -16816,40 +16875,77 @@
         mWiFiNetworkAgent.setNetworkInvalid(false /* isStrictMode */);
         mCm.reportNetworkConnectivity(mWiFiNetworkAgent.getNetwork(), false);
 
-        if (!enabled) {
+        if (enabled) {
+            // Network validation failed, but the result will be ignored.
+            assertTrue(mCm.getNetworkCapabilities(mWiFiNetworkAgent.getNetwork()).hasCapability(
+                    NET_CAPABILITY_VALIDATED));
+            mWiFiNetworkAgent.setNetworkValid(false);
+
+            // Behavior after config_validationFailureAfterRoamIgnoreTimeMillis has elapsed.
+            ConditionVariable waitForValidationBlock = new ConditionVariable();
+            doReturn(50).when(mResources)
+                    .getInteger(R.integer.config_validationFailureAfterRoamIgnoreTimeMillis);
+            // Wi-Fi roaming from wifiNc2 to wifiNc1.
+            mWiFiNetworkAgent.setNetworkCapabilities(wifiNc1, true);
+            mWiFiNetworkAgent.setNetworkInvalid(false);
+            waitForValidationBlock.block(150);
+            mCm.reportNetworkConnectivity(mWiFiNetworkAgent.getNetwork(), false);
             mDefaultNetworkCallback.expectAvailableCallbacksValidated(mCellNetworkAgent);
-            return;
+        } else {
+            mDefaultNetworkCallback.expectAvailableCallbacksValidated(mCellNetworkAgent);
         }
 
-        // Network validation failed, but the result will be ignored.
-        assertTrue(mCm.getNetworkCapabilities(mWiFiNetworkAgent.getNetwork()).hasCapability(
-                NET_CAPABILITY_VALIDATED));
-        mWiFiNetworkAgent.setNetworkValid(false);
+        // Wi-Fi is still connected and would become the default network if cell were to
+        // disconnect. This assertion ensures that the switch to cellular was not caused by
+        // Wi-Fi disconnecting (e.g., because the capability change to wifiNc2 caused it
+        // to stop satisfying the default request).
+        mCellNetworkAgent.disconnect();
+        mDefaultNetworkCallback.expectCallback(CallbackEntry.LOST, mCellNetworkAgent);
+        mDefaultNetworkCallback.expectAvailableCallbacksUnvalidated(mWiFiNetworkAgent);
 
-        // Behavior of after config_validationFailureAfterRoamIgnoreTimeMillis
-        ConditionVariable waitForValidationBlock = new ConditionVariable();
-        doReturn(50).when(mResources)
-                .getInteger(R.integer.config_validationFailureAfterRoamIgnoreTimeMillis);
-        // Wi-Fi roaming from wifiNc2 to wifiNc1.
-        mWiFiNetworkAgent.setNetworkCapabilities(wifiNc1, true);
-        mWiFiNetworkAgent.setNetworkInvalid(false);
-        waitForValidationBlock.block(150);
-        mCm.reportNetworkConnectivity(mWiFiNetworkAgent.getNetwork(), false);
-        mDefaultNetworkCallback.expectAvailableCallbacksValidated(mCellNetworkAgent);
-
-        mCm.unregisterNetworkCallback(wifiNetworkCallback);
     }
 
     @Test
     public void testIgnoreValidationAfterRoamDisabled() throws Exception {
-        doTestIgnoreValidationAfterRoam(false /* enabled */);
+        doTestIgnoreValidationAfterRoam(-1, false /* enabled */);
     }
+
     @Test
     public void testIgnoreValidationAfterRoamEnabled() throws Exception {
-        doTestIgnoreValidationAfterRoam(true /* enabled */);
+        final boolean enabled = !SdkLevel.isAtLeastT();
+        doTestIgnoreValidationAfterRoam(5_000, enabled);
     }
 
     @Test
+    public void testShouldIgnoreValidationFailureAfterRoam() {
+        // Always disabled on T+.
+        assumeFalse(SdkLevel.isAtLeastT());
+
+        NetworkAgentInfo nai = fakeWifiNai(new NetworkCapabilities());
+
+        // Enabled, but never roamed.
+        doReturn(5_000).when(mResources)
+                .getInteger(R.integer.config_validationFailureAfterRoamIgnoreTimeMillis);
+        assertEquals(0, nai.lastRoamTime);
+        assertFalse(mService.shouldIgnoreValidationFailureAfterRoam(nai));
+
+        // Roamed recently.
+        nai.lastRoamTime = SystemClock.elapsedRealtime() - 500 /* ms */;
+        assertTrue(mService.shouldIgnoreValidationFailureAfterRoam(nai));
+
+        // Disabled due to invalid setting (maximum is 10 seconds).
+        doReturn(15_000).when(mResources)
+                .getInteger(R.integer.config_validationFailureAfterRoamIgnoreTimeMillis);
+        assertFalse(mService.shouldIgnoreValidationFailureAfterRoam(nai));
+
+        // Disabled.
+        doReturn(-1).when(mResources)
+                .getInteger(R.integer.config_validationFailureAfterRoamIgnoreTimeMillis);
+        assertFalse(mService.shouldIgnoreValidationFailureAfterRoam(nai));
+    }
+
+
+    @Test
     public void testLegacyTetheringApiGuardWithProperPermission() throws Exception {
         final String testIface = "test0";
         mServiceContext.setPermission(ACCESS_NETWORK_STATE, PERMISSION_DENIED);