Merge "Sets deviceId in providers" into udc-dev
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 383ed2c..a62ee7f 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -61,6 +61,7 @@
 import com.android.internal.util.State;
 import com.android.internal.util.StateMachine;
 import com.android.net.module.util.DeviceConfigUtils;
+import com.android.net.module.util.InetAddressUtils;
 import com.android.net.module.util.PermissionUtils;
 import com.android.net.module.util.SharedLog;
 import com.android.server.connectivity.mdns.ExecutorProvider;
@@ -1240,16 +1241,10 @@
         }
         for (String ipv6Address : v6Addrs) {
             try {
-                final InetAddress addr = InetAddresses.parseNumericAddress(ipv6Address);
-                if (addr.isLinkLocalAddress()) {
-                    final Inet6Address v6Addr = Inet6Address.getByAddress(
-                            null /* host */, addr.getAddress(),
-                            serviceInfo.getInterfaceIndex());
-                    addresses.add(v6Addr);
-                } else {
-                    addresses.add(addr);
-                }
-            } catch (IllegalArgumentException | UnknownHostException e) {
+                final Inet6Address addr = (Inet6Address) InetAddresses.parseNumericAddress(
+                        ipv6Address);
+                addresses.add(InetAddressUtils.withScopeId(addr, serviceInfo.getInterfaceIndex()));
+            } catch (IllegalArgumentException e) {
                 Log.wtf(TAG, "Invalid ipv6 address", e);
             }
         }
diff --git a/service-t/src/com/android/server/connectivity/mdns/ISocketNetLinkMonitor.java b/service-t/src/com/android/server/connectivity/mdns/AbstractSocketNetlink.java
similarity index 95%
rename from service-t/src/com/android/server/connectivity/mdns/ISocketNetLinkMonitor.java
rename to service-t/src/com/android/server/connectivity/mdns/AbstractSocketNetlink.java
index ef3928c..b792e46 100644
--- a/service-t/src/com/android/server/connectivity/mdns/ISocketNetLinkMonitor.java
+++ b/service-t/src/com/android/server/connectivity/mdns/AbstractSocketNetlink.java
@@ -19,7 +19,7 @@
 /**
  * The interface for netlink monitor.
  */
-public interface ISocketNetLinkMonitor {
+public interface AbstractSocketNetlink {
 
     /**
      * Returns if the netlink monitor is supported or not. By default, it is not supported.
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceSocket.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceSocket.java
index 31627f8..119c7a8 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceSocket.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceSocket.java
@@ -78,7 +78,7 @@
         }
 
         mPacketReader = new MulticastPacketReader(networkInterface.getName(), mFileDescriptor,
-                new Handler(looper), packetReadBuffer, port);
+                new Handler(looper), packetReadBuffer);
         mPacketReader.start();
     }
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
new file mode 100644
index 0000000..bfda535
--- /dev/null
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
@@ -0,0 +1,178 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+import static com.android.server.connectivity.mdns.MdnsSocketProvider.ensureRunningOnHandlerThread;
+import static com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase;
+import static com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLowerCase;
+
+import android.annotation.NonNull;
+import android.annotation.Nullable;
+import android.net.Network;
+import android.os.Handler;
+import android.os.Looper;
+import android.util.ArrayMap;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * The {@link MdnsServiceCache} manages the services discovered from each socket and caches them
+ * to reduce duplicated queries.
+ *
+ * <p>This class is not thread safe, it is intended to be used only from the looper thread.
+ *  However, the constructor is an exception, as it is called on another thread;
+ *  therefore for thread safety all members of this class MUST either be final or initialized
+ *  to their default value (0, false or null).
+ */
+public class MdnsServiceCache {
+    private static class CacheKey {
+        @NonNull final String mLowercaseServiceType;
+        @Nullable final Network mNetwork;
+
+        CacheKey(@NonNull String serviceType, @Nullable Network network) {
+            mLowercaseServiceType = toDnsLowerCase(serviceType);
+            mNetwork = network;
+        }
+
+        @Override public int hashCode() {
+            return Objects.hash(mLowercaseServiceType, mNetwork);
+        }
+
+        @Override public boolean equals(Object other) {
+            if (this == other) {
+                return true;
+            }
+            if (!(other instanceof CacheKey)) {
+                return false;
+            }
+            return Objects.equals(mLowercaseServiceType, ((CacheKey) other).mLowercaseServiceType)
+                    && Objects.equals(mNetwork, ((CacheKey) other).mNetwork);
+        }
+    }
+    /**
+     * A map of cached services. Key is composed of the lowercased service type and the network.
+     * Value is the list of responses cached for that service type on that network.
+     */
+    @NonNull
+    private final ArrayMap<CacheKey, List<MdnsResponse>> mCachedServices = new ArrayMap<>();
+    @NonNull
+    private final Handler mHandler;
+
+    public MdnsServiceCache(@NonNull Looper looper) {
+        mHandler = new Handler(looper);
+    }
+
+    /**
+     * Get the cached services for the given service type and network.
+     *
+     * @param serviceType the target service type.
+     * @param network the target network
+     * @return an unmodifiable copy of the services cached for that type and network.
+     */
+    @NonNull
+    public List<MdnsResponse> getCachedServices(@NonNull String serviceType,
+            @Nullable Network network) {
+        ensureRunningOnHandlerThread(mHandler);
+        final CacheKey key = new CacheKey(serviceType, network);
+        return mCachedServices.containsKey(key)
+                ? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(key)))
+                : Collections.emptyList();
+    }
+
+    private MdnsResponse findMatchedResponse(@NonNull List<MdnsResponse> responses,
+            @NonNull String serviceName) {
+        for (MdnsResponse response : responses) {
+            if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) {
+                return response;
+            }
+        }
+        return null;
+    }
+
+    /**
+     * Get the cached service.
+     *
+     * @param serviceName the target service name.
+     * @param serviceType the target service type.
+     * @param network the target network
+     * @return the service which matches the given conditions, or null if there is none.
+     */
+    @Nullable
+    public MdnsResponse getCachedService(@NonNull String serviceName,
+            @NonNull String serviceType, @Nullable Network network) {
+        ensureRunningOnHandlerThread(mHandler);
+        final List<MdnsResponse> responses =
+                mCachedServices.get(new CacheKey(serviceType, network));
+        if (responses == null) {
+            return null;
+        }
+        final MdnsResponse response = findMatchedResponse(responses, serviceName);
+        return response != null ? new MdnsResponse(response) : null;
+    }
+
+    /**
+     * Add or update a service.
+     *
+     * @param serviceType the service type.
+     * @param network the target network
+     * @param response the response of the discovered service.
+     */
+    public void addOrUpdateService(@NonNull String serviceType, @Nullable Network network,
+            @NonNull MdnsResponse response) {
+        ensureRunningOnHandlerThread(mHandler);
+        final List<MdnsResponse> responses = mCachedServices.computeIfAbsent(
+                new CacheKey(serviceType, network), key -> new ArrayList<>());
+        // Remove existing service if present.
+        final MdnsResponse existing =
+                findMatchedResponse(responses, response.getServiceInstanceName());
+        responses.remove(existing);
+        responses.add(response);
+    }
+
+    /**
+     * Remove a service which matches the given service name, type and network.
+     *
+     * @param serviceName the target service name.
+     * @param serviceType the target service type.
+     * @param network the target network.
+     */
+    @Nullable
+    public MdnsResponse removeService(@NonNull String serviceName, @NonNull String serviceType,
+            @Nullable Network network) {
+        ensureRunningOnHandlerThread(mHandler);
+        final List<MdnsResponse> responses =
+                mCachedServices.get(new CacheKey(serviceType, network));
+        if (responses == null) {
+            return null;
+        }
+        final Iterator<MdnsResponse> iterator = responses.iterator();
+        while (iterator.hasNext()) {
+            final MdnsResponse response = iterator.next();
+            if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) {
+                iterator.remove();
+                return response;
+            }
+        }
+        return null;
+    }
+
+    // TODO: check ttl expiration for each service and notify to the clients.
+}
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
index c45345a..b45a831 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
@@ -74,7 +74,7 @@
     @NonNull private final Dependencies mDependencies;
     @NonNull private final NetworkCallback mNetworkCallback;
     @NonNull private final TetheringEventCallback mTetheringEventCallback;
-    @NonNull private final ISocketNetLinkMonitor mSocketNetlinkMonitor;
+    @NonNull private final AbstractSocketNetlink mSocketNetlinkMonitor;
     private final ArrayMap<Network, SocketInfo> mNetworkSockets = new ArrayMap<>();
     private final ArrayMap<String, SocketInfo> mTetherInterfaceSockets = new ArrayMap<>();
     private final ArrayMap<Network, LinkProperties> mActiveNetworksLinkProperties =
@@ -171,7 +171,7 @@
             return iface.getIndex();
         }
         /*** Creates a SocketNetlinkMonitor */
-        public ISocketNetLinkMonitor createSocketNetlinkMonitor(@NonNull final Handler handler,
+        public AbstractSocketNetlink createSocketNetlinkMonitor(@NonNull final Handler handler,
                 @NonNull final SharedLog log,
                 @NonNull final NetLinkMonitorCallBack cb) {
             return SocketNetLinkMonitorFactory.createNetLinkMonitor(handler, log, cb);
diff --git a/service-t/src/com/android/server/connectivity/mdns/MulticastPacketReader.java b/service-t/src/com/android/server/connectivity/mdns/MulticastPacketReader.java
index 078c4dd..0e5a2a4 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MulticastPacketReader.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MulticastPacketReader.java
@@ -63,8 +63,10 @@
      */
     protected MulticastPacketReader(@NonNull String interfaceTag,
             @NonNull ParcelFileDescriptor socket, @NonNull Handler handler,
-            @NonNull byte[] buffer, int port) {
-        super(handler, new RecvBuffer(buffer, new InetSocketAddress(port)));
+            @NonNull byte[] buffer) {
+        // Set the port to zero as placeholder as the recvfrom() call will fill the actual port
+        // value later.
+        super(handler, new RecvBuffer(buffer, new InetSocketAddress(0 /* port */)));
         mLogTag = MulticastPacketReader.class.getSimpleName() + "/" + interfaceTag;
         mSocket = socket;
         mHandler = handler;
diff --git a/service-t/src/com/android/server/connectivity/mdns/SocketNetLinkMonitorFactory.java b/service-t/src/com/android/server/connectivity/mdns/SocketNetLinkMonitorFactory.java
index 4650255..6bc7941 100644
--- a/service-t/src/com/android/server/connectivity/mdns/SocketNetLinkMonitorFactory.java
+++ b/service-t/src/com/android/server/connectivity/mdns/SocketNetLinkMonitorFactory.java
@@ -30,7 +30,7 @@
     /**
      * Creates a new netlink monitor.
      */
-    public static ISocketNetLinkMonitor createNetLinkMonitor(@NonNull final Handler handler,
+    public static AbstractSocketNetlink createNetLinkMonitor(@NonNull final Handler handler,
             @NonNull SharedLog log, @NonNull MdnsSocketProvider.NetLinkMonitorCallBack cb) {
         return new SocketNetlinkMonitor(handler, log, cb);
     }
diff --git a/service-t/src/com/android/server/connectivity/mdns/internal/SocketNetlinkMonitor.java b/service-t/src/com/android/server/connectivity/mdns/internal/SocketNetlinkMonitor.java
index 6395b53..40191cf 100644
--- a/service-t/src/com/android/server/connectivity/mdns/internal/SocketNetlinkMonitor.java
+++ b/service-t/src/com/android/server/connectivity/mdns/internal/SocketNetlinkMonitor.java
@@ -28,13 +28,13 @@
 import com.android.net.module.util.netlink.NetlinkMessage;
 import com.android.net.module.util.netlink.RtNetlinkAddressMessage;
 import com.android.net.module.util.netlink.StructIfaddrMsg;
-import com.android.server.connectivity.mdns.ISocketNetLinkMonitor;
+import com.android.server.connectivity.mdns.AbstractSocketNetlink;
 import com.android.server.connectivity.mdns.MdnsSocketProvider;
 
 /**
  * The netlink monitor for MdnsSocketProvider.
  */
-public class SocketNetlinkMonitor extends NetlinkMonitor implements ISocketNetLinkMonitor {
+public class SocketNetlinkMonitor extends NetlinkMonitor implements AbstractSocketNetlink {
 
     public static final String TAG = SocketNetlinkMonitor.class.getSimpleName();
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
new file mode 100644
index 0000000..4b0f2a4
--- /dev/null
+++ b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
@@ -0,0 +1,59 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns.util;
+
+import android.annotation.NonNull;
+
+/**
+ * Mdns utility functions.
+ */
+public class MdnsUtils {
+
+    private MdnsUtils() { }
+
+    /**
+     * Convert the string to DNS case-insensitive lowercase
+     *
+     * Per rfc6762#page-46, accented characters are not defined to be automatically equivalent to
+     * their unaccented counterparts. So "DNS lowercase" means that characters A-Z are transformed
+     * into a-z, and all other characters are kept as-is.
+     */
+    public static String toDnsLowerCase(@NonNull String string) {
+        final char[] outChars = new char[string.length()];
+        for (int i = 0; i < string.length(); i++) {
+            outChars[i] = toDnsLowerCase(string.charAt(i));
+        }
+        return new String(outChars);
+    }
+
+    /**
+     * Compare two strings for equality, ignoring DNS case (only A-Z and a-z are folded).
+     */
+    public static boolean equalsIgnoreDnsCase(@NonNull String a, @NonNull String b) {
+        if (a.length() != b.length()) return false;
+        for (int i = 0; i < a.length(); i++) {
+            if (toDnsLowerCase(a.charAt(i)) != toDnsLowerCase(b.charAt(i))) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    private static char toDnsLowerCase(char a) {
+        return a >= 'A' && a <= 'Z' ? (char) (a + ('a' - 'A')) : a;
+    }
+}
diff --git a/service-t/src/com/android/server/net/NetworkStatsService.java b/service-t/src/com/android/server/net/NetworkStatsService.java
index 961337d..c660792 100644
--- a/service-t/src/com/android/server/net/NetworkStatsService.java
+++ b/service-t/src/com/android/server/net/NetworkStatsService.java
@@ -946,7 +946,11 @@
     @GuardedBy("mStatsLock")
     private void shutdownLocked() {
         final TetheringManager tetheringManager = mContext.getSystemService(TetheringManager.class);
-        tetheringManager.unregisterTetheringEventCallback(mTetherListener);
+        try {
+            tetheringManager.unregisterTetheringEventCallback(mTetherListener);
+        } catch (IllegalStateException e) {
+            Log.i(TAG, "shutdownLocked: error when unregister tethering, ignored. e=" + e);
+        }
         mContext.unregisterReceiver(mPollReceiver);
         mContext.unregisterReceiver(mRemovedReceiver);
         mContext.unregisterReceiver(mUserReceiver);
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index ba503e0..6fa2746 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -7904,16 +7904,27 @@
         return captivePortalBuilder.build();
     }
 
-    private String makeNflogPrefix(String iface, long networkHandle) {
+    @VisibleForTesting
+    static String makeNflogPrefix(String iface, long networkHandle) {
         // This needs to be kept in sync and backwards compatible with the decoding logic in
         // NetdEventListenerService, which is non-mainline code.
         return SdkLevel.isAtLeastU() ? (networkHandle + ":" + iface) : ("iface:" + iface);
     }
 
+    private static boolean isWakeupMarkingSupported(NetworkCapabilities capabilities) {
+        if (capabilities.hasTransport(TRANSPORT_WIFI)) {
+            return true;
+        }
+        if (SdkLevel.isAtLeastU() && capabilities.hasTransport(TRANSPORT_CELLULAR)) {
+            return true;
+        }
+        return false;
+    }
+
     private void wakeupModifyInterface(String iface, NetworkAgentInfo nai, boolean add) {
         // Marks are only available on WiFi interfaces. Checking for
         // marks on unsupported interfaces is harmless.
-        if (!nai.networkCapabilities.hasTransport(NetworkCapabilities.TRANSPORT_WIFI)) {
+        if (!isWakeupMarkingSupported(nai.networkCapabilities)) {
             return;
         }
 
diff --git a/service/src/com/android/server/connectivity/NetworkDiagnostics.java b/service/src/com/android/server/connectivity/NetworkDiagnostics.java
index 509110d..15d0925 100644
--- a/service/src/com/android/server/connectivity/NetworkDiagnostics.java
+++ b/service/src/com/android/server/connectivity/NetworkDiagnostics.java
@@ -18,6 +18,11 @@
 
 import static android.system.OsConstants.*;
 
+import static com.android.net.module.util.NetworkStackConstants.ICMP_HEADER_LEN;
+import static com.android.net.module.util.NetworkStackConstants.IPV4_HEADER_MIN_LEN;
+import static com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN;
+import static com.android.net.module.util.NetworkStackConstants.IPV6_MIN_MTU;
+
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.net.InetAddresses;
@@ -33,6 +38,7 @@
 import android.system.Os;
 import android.system.StructTimeval;
 import android.text.TextUtils;
+import android.util.Log;
 import android.util.Pair;
 
 import com.android.internal.util.IndentingPrintWriter;
@@ -172,7 +178,7 @@
         }
     }
 
-    private final Map<InetAddress, Measurement> mIcmpChecks = new HashMap<>();
+    private final Map<Pair<InetAddress, Integer>, Measurement> mIcmpChecks = new HashMap<>();
     private final Map<Pair<InetAddress, InetAddress>, Measurement> mExplicitSourceIcmpChecks =
             new HashMap<>();
     private final Map<InetAddress, Measurement> mDnsUdpChecks = new HashMap<>();
@@ -205,17 +211,21 @@
             mLinkProperties.addDnsServer(TEST_DNS6);
         }
 
+        final int mtu = mLinkProperties.getMtu();
         for (RouteInfo route : mLinkProperties.getRoutes()) {
             if (route.getType() == RouteInfo.RTN_UNICAST && route.hasGateway()) {
                 InetAddress gateway = route.getGateway();
-                prepareIcmpMeasurement(gateway);
+                // Use the route's MTU if set. Otherwise, use the one from the link properties.
+                final int routeMtu = route.getMtu();
+                prepareIcmpMeasurements(gateway, (routeMtu > 0) ? routeMtu : mtu);
                 if (route.isIPv6Default()) {
                     prepareExplicitSourceIcmpMeasurements(gateway);
                 }
             }
         }
+
         for (InetAddress nameserver : mLinkProperties.getDnsServers()) {
-            prepareIcmpMeasurement(nameserver);
+            prepareIcmpMeasurements(nameserver, mtu);
             prepareDnsMeasurement(nameserver);
 
             // Unlike the DnsResolver which doesn't do certificate validation in opportunistic mode,
@@ -261,11 +271,50 @@
                 localAddr.getHostAddress(), inetSockAddr.getPort());
     }
 
-    private void prepareIcmpMeasurement(InetAddress target) {
-        if (!mIcmpChecks.containsKey(target)) {
-            Measurement measurement = new Measurement();
-            measurement.thread = new Thread(new IcmpCheck(target, measurement));
-            mIcmpChecks.put(target, measurement);
+    private static int getHeaderLen(@NonNull InetAddress target) {
+        // Convert IPv4 mapped v6 address to v4 if any.
+        try {
+            final InetAddress addr = InetAddress.getByAddress(target.getAddress());
+            // An ICMPv6 header is technically 4 bytes, but the implementation in IcmpCheck#run()
+            // will always fill in another 4 bytes padding in the v6 diagnostic packets, so the size
+            // before icmp data is always 8 bytes in the implementation of ICMP diagnostics for both
+            // v4 and v6 packets. Thus, it's fine to use the v4 header size in the length
+            // calculation.
+            if (addr instanceof Inet6Address) {
+                return IPV6_HEADER_LEN + ICMP_HEADER_LEN;
+            }
+        } catch (UnknownHostException e) {
+            Log.e(TAG, "Create InetAddress fail(" + target + "): " + e);
+        }
+
+        return IPV4_HEADER_MIN_LEN + ICMP_HEADER_LEN;
+    }
+
+    private void prepareIcmpMeasurements(@NonNull InetAddress target, int targetNetworkMtu) {
+        // Test with different size payload ICMP.
+        // 1. Test with 0 payload.
+        addPayloadIcmpMeasurement(target, 0);
+        final int header = getHeaderLen(target);
+        // 2. Test with full size MTU.
+        addPayloadIcmpMeasurement(target, targetNetworkMtu - header);
+        // 3. If v6, make another measurement with the full v6 min MTU, unless that's what
+        //    was done above.
+        if ((target instanceof Inet6Address) && (targetNetworkMtu != IPV6_MIN_MTU)) {
+            addPayloadIcmpMeasurement(target, IPV6_MIN_MTU - header);
+        }
+    }
+
+    private void addPayloadIcmpMeasurement(@NonNull InetAddress target, int payloadLen) {
+        // This can happen if no MTU is set (i.e. it is 0) in the link properties: the
+        // payload length then becomes negative after subtracting the header length.
+        if (payloadLen < 0) return;
+
+        final Pair<InetAddress, Integer> lenTarget =
+                new Pair<>(target, Integer.valueOf(payloadLen));
+        if (!mIcmpChecks.containsKey(lenTarget)) {
+            final Measurement measurement = new Measurement();
+            measurement.thread = new Thread(new IcmpCheck(target, payloadLen, measurement));
+            mIcmpChecks.put(lenTarget, measurement);
         }
     }
 
@@ -276,7 +325,7 @@
                 Pair<InetAddress, InetAddress> srcTarget = new Pair<>(source, target);
                 if (!mExplicitSourceIcmpChecks.containsKey(srcTarget)) {
                     Measurement measurement = new Measurement();
-                    measurement.thread = new Thread(new IcmpCheck(source, target, measurement));
+                    measurement.thread = new Thread(new IcmpCheck(source, target, 0, measurement));
                     mExplicitSourceIcmpChecks.put(srcTarget, measurement);
                 }
             }
@@ -334,8 +383,8 @@
         ArrayList<Measurement> measurements = new ArrayList(totalMeasurementCount());
 
         // Sort measurements IPv4 first.
-        for (Map.Entry<InetAddress, Measurement> entry : mIcmpChecks.entrySet()) {
-            if (entry.getKey() instanceof Inet4Address) {
+        for (Map.Entry<Pair<InetAddress, Integer>, Measurement> entry : mIcmpChecks.entrySet()) {
+            if (entry.getKey().first instanceof Inet4Address) {
                 measurements.add(entry.getValue());
             }
         }
@@ -357,8 +406,8 @@
         }
 
         // IPv6 measurements second.
-        for (Map.Entry<InetAddress, Measurement> entry : mIcmpChecks.entrySet()) {
-            if (entry.getKey() instanceof Inet6Address) {
+        for (Map.Entry<Pair<InetAddress, Integer>, Measurement> entry : mIcmpChecks.entrySet()) {
+            if (entry.getKey().first instanceof Inet6Address) {
                 measurements.add(entry.getValue());
             }
         }
@@ -489,8 +538,11 @@
         private static final int PACKET_BUFSIZE = 512;
         private final int mProtocol;
         private final int mIcmpType;
+        private final int mPayloadSize;
+        // The length parameter is effectively the -s flag to ping/ping6 to specify the number of
+        // data bytes to be sent.
+        IcmpCheck(InetAddress source, InetAddress target, int length, Measurement measurement) {
 
-        public IcmpCheck(InetAddress source, InetAddress target, Measurement measurement) {
             super(source, target, measurement);
 
             if (mAddressFamily == AF_INET6) {
@@ -502,12 +554,13 @@
                 mIcmpType = NetworkConstants.ICMPV4_ECHO_REQUEST_TYPE;
                 mMeasurement.description = "ICMPv4";
             }
-
-            mMeasurement.description += " dst{" + mTarget.getHostAddress() + "}";
+            mPayloadSize = length;
+            mMeasurement.description += " payloadLength{" + mPayloadSize  + "}"
+                    + " dst{" + mTarget.getHostAddress() + "}";
         }
 
-        public IcmpCheck(InetAddress target, Measurement measurement) {
-            this(null, target, measurement);
+        IcmpCheck(InetAddress target, int length, Measurement measurement) {
+            this(null, target, length, measurement);
         }
 
         @Override
@@ -523,9 +576,11 @@
             mMeasurement.description += " src{" + socketAddressToString(mSocketAddress) + "}";
 
             // Build a trivial ICMP packet.
-            final byte[] icmpPacket = {
-                    (byte) mIcmpType, 0, 0, 0, 0, 0, 0, 0  // ICMP header
-            };
+            // The ICMPv4 header is ICMP_HEADER_LEN (8) bytes, while the ICMPv6 header is only 4
+            // bytes (followed by 4 bytes of message body before the payload).
+            // Use 8 bytes for both v4 and v6 for simplicity.
+            final byte[] icmpPacket = new byte[ICMP_HEADER_LEN + mPayloadSize];
+            icmpPacket[0] = (byte) mIcmpType;
 
             int count = 0;
             mMeasurement.startTime = now();
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 3188c9c..32c854b 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -154,6 +154,7 @@
 import static com.android.server.ConnectivityService.PREFERENCE_ORDER_PROFILE;
 import static com.android.server.ConnectivityService.PREFERENCE_ORDER_VPN;
 import static com.android.server.ConnectivityService.createDeliveryGroupKeyForConnectivityAction;
+import static com.android.server.ConnectivityService.makeNflogPrefix;
 import static com.android.server.ConnectivityServiceTestUtils.transportToLegacyType;
 import static com.android.server.NetworkAgentWrapper.CallbackType.OnQosCallbackRegister;
 import static com.android.server.NetworkAgentWrapper.CallbackType.OnQosCallbackUnregister;
@@ -533,6 +534,10 @@
     private static final int TEST_PACKAGE_UID = 123;
     private static final int TEST_PACKAGE_UID2 = 321;
     private static final int TEST_PACKAGE_UID3 = 456;
+
+    private static final int PACKET_WAKEUP_MASK = 0xffff0000;
+    private static final int PACKET_WAKEUP_MARK = 0x88880000;
+
     private static final String ALWAYS_ON_PACKAGE = "com.android.test.alwaysonvpn";
 
     private static final String INTERFACE_NAME = "interface";
@@ -1910,6 +1915,10 @@
         doReturn(0).when(mResources).getInteger(R.integer.config_activelyPreferBadWifi);
         doReturn(true).when(mResources)
                 .getBoolean(R.bool.config_cellular_radio_timesharing_capable);
+        doReturn(PACKET_WAKEUP_MASK).when(mResources).getInteger(
+                R.integer.config_networkWakeupPacketMask);
+        doReturn(PACKET_WAKEUP_MARK).when(mResources).getInteger(
+                R.integer.config_networkWakeupPacketMark);
     }
 
     class ConnectivityServiceDependencies extends ConnectivityService.Dependencies {
@@ -10386,6 +10395,16 @@
         return event;
     }
 
+    private void verifyWakeupModifyInterface(String iface, boolean add) throws RemoteException {
+        if (add) {
+            verify(mMockNetd).wakeupAddInterface(eq(iface), anyString(), anyInt(),
+                    anyInt());
+        } else {
+            verify(mMockNetd).wakeupDelInterface(eq(iface), anyString(), anyInt(),
+                    anyInt());
+        }
+    }
+
     private <T> T verifyWithOrder(@Nullable InOrder inOrder, @NonNull T t) {
         if (inOrder != null) {
             return inOrder.verify(t);
@@ -10612,6 +10631,11 @@
         clat.interfaceRemoved(CLAT_MOBILE_IFNAME);
         networkCallback.assertNoCallback();
         verify(mMockNetd, times(1)).networkRemoveInterface(cellNetId, CLAT_MOBILE_IFNAME);
+
+        if (SdkLevel.isAtLeastU()) {
+            verifyWakeupModifyInterface(CLAT_MOBILE_IFNAME, false);
+        }
+
         verifyNoMoreInteractions(mMockNetd);
         verifyNoMoreInteractions(mClatCoordinator);
         verifyNoMoreInteractions(mMockDnsResolver);
@@ -10648,6 +10672,10 @@
         assertRoutesAdded(cellNetId, stackedDefault);
         verify(mMockNetd, times(1)).networkAddInterface(cellNetId, CLAT_MOBILE_IFNAME);
 
+        if (SdkLevel.isAtLeastU()) {
+            verifyWakeupModifyInterface(CLAT_MOBILE_IFNAME, true);
+        }
+
         // NAT64 prefix is removed. Expect that clat is stopped.
         mService.mResolverUnsolEventCallback.onNat64PrefixEvent(makeNat64PrefixEvent(
                 cellNetId, PREFIX_OPERATION_REMOVED, kNat64PrefixString, 96));
@@ -10662,6 +10690,11 @@
                 cb -> cb.getLp().getStackedLinks().size() == 0);
         verify(mMockNetd, times(1)).networkRemoveInterface(cellNetId, CLAT_MOBILE_IFNAME);
         verify(mMockNetd, times(1)).interfaceGetCfg(CLAT_MOBILE_IFNAME);
+
+        if (SdkLevel.isAtLeastU()) {
+            verifyWakeupModifyInterface(CLAT_MOBILE_IFNAME, false);
+        }
+
         // Clean up.
         mCellAgent.disconnect();
         networkCallback.expect(LOST, mCellAgent);
@@ -10674,6 +10707,11 @@
         } else {
             verify(mMockNetd, never()).setNetworkAllowlist(any());
         }
+
+        if (SdkLevel.isAtLeastU()) {
+            verifyWakeupModifyInterface(MOBILE_IFNAME, false);
+        }
+
         verifyNoMoreInteractions(mMockNetd);
         verifyNoMoreInteractions(mClatCoordinator);
         reset(mMockNetd);
@@ -10703,6 +10741,11 @@
         verify(mMockNetd).networkAddInterface(cellNetId, CLAT_MOBILE_IFNAME);
         // assertRoutesAdded sees all calls since last mMockNetd reset, so expect IPv6 routes again.
         assertRoutesAdded(cellNetId, ipv6Subnet, ipv6Default, stackedDefault);
+
+        if (SdkLevel.isAtLeastU()) {
+            verifyWakeupModifyInterface(MOBILE_IFNAME, true);
+        }
+
         reset(mMockNetd);
         reset(mClatCoordinator);
 
@@ -10711,6 +10754,11 @@
         networkCallback.expect(LOST, mCellAgent);
         networkCallback.assertNoCallback();
         verifyClatdStop(null /* inOrder */, MOBILE_IFNAME);
+
+        if (SdkLevel.isAtLeastU()) {
+            verifyWakeupModifyInterface(CLAT_MOBILE_IFNAME, false);
+        }
+
         verify(mMockNetd).idletimerRemoveInterface(eq(MOBILE_IFNAME), anyInt(),
                 eq(Integer.toString(TRANSPORT_CELLULAR)));
         verify(mMockNetd).networkDestroy(cellNetId);
@@ -10719,6 +10767,11 @@
         } else {
             verify(mMockNetd, never()).setNetworkAllowlist(any());
         }
+
+        if (SdkLevel.isAtLeastU()) {
+            verifyWakeupModifyInterface(MOBILE_IFNAME, false);
+        }
+
         verifyNoMoreInteractions(mMockNetd);
         verifyNoMoreInteractions(mClatCoordinator);
 
@@ -17666,4 +17719,48 @@
         info.setExtraInfo("test_info");
         assertEquals("0;2;test_info", createDeliveryGroupKeyForConnectivityAction(info));
     }
+
+    @Test
+    public void testNetdWakeupAddInterfaceForWifiTransport() throws Exception {
+        // A connected Wi-Fi agent must trigger wakeupAddInterface with the configured mark/mask.
+        final LinkProperties lp = new LinkProperties();
+        lp.setInterfaceName(WIFI_IFNAME);
+        mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI, lp);
+        mWiFiAgent.connect(false /* validated */);
+        final long netHandle = mWiFiAgent.getNetwork().getNetworkHandle();
+        final String nflogPrefix = makeNflogPrefix(WIFI_IFNAME, netHandle);
+        verify(mMockNetd).wakeupAddInterface(
+                WIFI_IFNAME, nflogPrefix, PACKET_WAKEUP_MARK, PACKET_WAKEUP_MASK);
+    }
+
+    @Test
+    public void testNetdWakeupAddInterfaceForCellularTransport() throws Exception {
+        final LinkProperties lp = new LinkProperties();
+        lp.setInterfaceName(MOBILE_IFNAME);
+        mCellAgent = new TestNetworkAgentWrapper(TRANSPORT_CELLULAR, lp);
+        mCellAgent.connect(false /* validated */);
+        // Before U, connecting a cellular agent must not call wakeupAddInterface at all.
+        if (!SdkLevel.isAtLeastU()) {
+            verify(mMockNetd, never()).wakeupAddInterface(
+                    eq(MOBILE_IFNAME), anyString(), anyInt(), anyInt());
+            return;
+        }
+        final long netHandle = mCellAgent.getNetwork().getNetworkHandle();
+        final String nflogPrefix = makeNflogPrefix(MOBILE_IFNAME, netHandle);
+        verify(mMockNetd).wakeupAddInterface(
+                MOBILE_IFNAME, nflogPrefix, PACKET_WAKEUP_MARK, PACKET_WAKEUP_MASK);
+    }
+
+    @Test
+    public void testNetdWakeupAddInterfaceForEthernetTransport() throws Exception {
+        final String iface = "eth42";
+        final LinkProperties lp = new LinkProperties();
+        lp.setInterfaceName(iface);
+        mEthernetAgent = new TestNetworkAgentWrapper(TRANSPORT_ETHERNET, lp);
+        mEthernetAgent.connect(false /* validated */);
+
+        // Connecting an Ethernet agent must never call wakeupAddInterface for its interface.
+        verify(mMockNetd, never()).wakeupAddInterface(
+                eq(iface), anyString(), anyInt(), anyInt());
+    }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/NetdEventListenerServiceTest.java b/tests/unit/java/com/android/server/connectivity/NetdEventListenerServiceTest.java
index f4b6464..d667662 100644
--- a/tests/unit/java/com/android/server/connectivity/NetdEventListenerServiceTest.java
+++ b/tests/unit/java/com/android/server/connectivity/NetdEventListenerServiceTest.java
@@ -24,6 +24,7 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 
@@ -41,6 +42,8 @@
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRunner;
 
+import libcore.util.EmptyArray;
+
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -60,7 +63,8 @@
     private static final String EXAMPLE_IPV4 = "192.0.2.1";
     private static final String EXAMPLE_IPV6 = "2001:db8:1200::2:1";
 
-    private static final long NET_HANDLE = new Network(5391).getNetworkHandle();
+    private static final Network TEST_WIFI_NETWORK = new Network(5391);
+    private static final Network TEST_CELL_NETWORK = new Network(5832);
 
     private static final byte[] MAC_ADDR =
             {(byte)0x84, (byte)0xc9, (byte)0xb2, (byte)0x6a, (byte)0xed, (byte)0x4b};
@@ -78,6 +82,8 @@
     public void setUp() {
         mCm = mock(ConnectivityManager.class);
         mService = new NetdEventListenerService(mCm);
+        doReturn(CAPABILITIES_WIFI).when(mCm).getNetworkCapabilities(TEST_WIFI_NETWORK);
+        doReturn(CAPABILITIES_CELL).when(mCm).getNetworkCapabilities(TEST_CELL_NETWORK);
     }
 
     @Test
@@ -111,19 +117,25 @@
         wakeupEvent(iface, uids[5], v4, tcp, mac, srcIp, dstIp, sport, dport, now);
         wakeupEvent(iface, uids[6], v6, udp, mac, srcIp6, dstIp6, sport, dport, now);
         wakeupEvent(iface, uids[7], v6, tcp, mac, srcIp6, dstIp6, sport, dport, now);
-        wakeupEvent(iface, uids[8], v6, udp, mac, srcIp6, dstIp6, sport, dport, now);
+        wakeupEvent("rmnet0", uids[8], v6, udp, EmptyArray.BYTE, srcIp6, dstIp6, sport, dport, now,
+                TEST_CELL_NETWORK);
 
         String[] events2 = remove(listNetdEvent(), baseline);
-        int expectedLength2 = uids.length + 1; // +1 for the WakeupStats line
+        int expectedLength2 = uids.length + 2; // +2 for the WakeupStats headers
         assertEquals(expectedLength2, events2.length);
+
         assertStringContains(events2[0], "WakeupStats");
-        assertStringContains(events2[0], "wlan0");
-        assertStringContains(events2[0], "0x800");
+        assertStringContains(events2[0], "rmnet0");
         assertStringContains(events2[0], "0x86dd");
+
+        assertStringContains(events2[1], "WakeupStats");
+        assertStringContains(events2[1], "wlan0");
+        assertStringContains(events2[1], "0x800");
+        assertStringContains(events2[1], "0x86dd");
         for (int i = 0; i < uids.length; i++) {
-            String got = events2[i+1];
+            String got = events2[i + 2];
             assertStringContains(got, "WakeupEvent");
-            assertStringContains(got, "wlan0");
+            assertStringContains(got, ((i == 8) ? "rmnet0" : "wlan0"));
             assertStringContains(got, "uid: " + uids[i]);
         }
 
@@ -134,11 +146,13 @@
         }
 
         String[] events3 = remove(listNetdEvent(), baseline);
-        int expectedLength3 = BUFFER_LENGTH + 1; // +1 for the WakeupStats line
+        int expectedLength3 = BUFFER_LENGTH + 2; // +2 for the WakeupStats headers
         assertEquals(expectedLength3, events3.length);
-        assertStringContains(events2[0], "WakeupStats");
-        assertStringContains(events2[0], "wlan0");
-        for (int i = 1; i < expectedLength3; i++) {
+        assertStringContains(events3[0], "WakeupStats");
+        assertStringContains(events3[0], "rmnet0");
+        assertStringContains(events3[1], "WakeupStats");
+        assertStringContains(events3[1], "wlan0");
+        for (int i = 2; i < expectedLength3; i++) {
             String got = events3[i];
             assertStringContains(got, "WakeupEvent");
             assertStringContains(got, "wlan0");
@@ -173,19 +187,24 @@
         final int icmp6 = 58;
 
         wakeupEvent("wlan0", 1000, v4, tcp, mac, srcIp, dstIp, sport, dport, now);
-        wakeupEvent("rmnet0", 10123, v4, tcp, mac, srcIp, dstIp, sport, dport, now);
+        wakeupEvent("rmnet0", 10123, v4, tcp, mac, srcIp, dstIp, sport, dport, now,
+                TEST_CELL_NETWORK);
         wakeupEvent("wlan0", 1000, v4, udp, mac, srcIp, dstIp, sport, dport, now);
-        wakeupEvent("rmnet0", 10008, v4, tcp, mac, srcIp, dstIp, sport, dport, now);
+        wakeupEvent("rmnet0", 10008, v4, tcp, EmptyArray.BYTE, srcIp, dstIp, sport, dport, now,
+                TEST_CELL_NETWORK);
         wakeupEvent("wlan0", -1, v6, icmp6, mac, srcIp6, dstIp6, sport, dport, now);
         wakeupEvent("wlan0", 10008, v4, tcp, mac, srcIp, dstIp, sport, dport, now);
-        wakeupEvent("rmnet0", 1000, v4, tcp, mac, srcIp, dstIp, sport, dport, now);
+        wakeupEvent("rmnet0", 1000, v4, tcp, mac, srcIp, dstIp, sport, dport, now,
+                TEST_CELL_NETWORK);
         wakeupEvent("wlan0", 10004, v4, udp, mac, srcIp, dstIp, sport, dport, now);
         wakeupEvent("wlan0", 1000, v6, tcp, mac, srcIp6, dstIp6, sport, dport, now);
         wakeupEvent("wlan0", 0, v6, udp, mac, srcIp6, dstIp6, sport, dport, now);
         wakeupEvent("wlan0", -1, v6, icmp6, mac, srcIp6, dstIp6, sport, dport, now);
-        wakeupEvent("rmnet0", 10052, v4, tcp, mac, srcIp, dstIp, sport, dport, now);
+        wakeupEvent("rmnet0", 10052, v4, tcp, mac, srcIp, dstIp, sport, dport, now,
+                TEST_CELL_NETWORK);
         wakeupEvent("wlan0", 0, v6, udp, mac, srcIp6, dstIp6, sport, dport, now);
-        wakeupEvent("rmnet0", 1000, v6, tcp, mac, srcIp6, dstIp6, sport, dport, now);
+        wakeupEvent("rmnet0", 1000, v6, tcp, null, srcIp6, dstIp6, sport, dport, now,
+                TEST_CELL_NETWORK);
         wakeupEvent("wlan0", 1010, v4, udp, mac, srcIp, dstIp, sport, dport, now);
 
         String got = flushStatistics();
@@ -214,7 +233,7 @@
                 "    >",
                 "    l2_broadcast_count: 0",
                 "    l2_multicast_count: 0",
-                "    l2_unicast_count: 5",
+                "    l2_unicast_count: 3",
                 "    no_uid_wakeups: 0",
                 "    non_application_wakeups: 0",
                 "    root_wakeups: 0",
@@ -499,8 +518,13 @@
     }
 
     void wakeupEvent(String iface, int uid, int ether, int ip, byte[] mac, String srcIp,
-            String dstIp, int sport, int dport, long now) throws Exception {
-        String prefix = NET_HANDLE + ":" + iface;
+            String dstIp, int sport, int dport, long now) {
+        wakeupEvent(iface, uid, ether, ip, mac, srcIp, dstIp, sport, dport, now, TEST_WIFI_NETWORK);
+    }
+
+    void wakeupEvent(String iface, int uid, int ether, int ip, byte[] mac, String srcIp,
+            String dstIp, int sport, int dport, long now, Network network) {
+        String prefix = network.getNetworkHandle() + ":" + iface;
         mService.onWakeupEvent(prefix, uid, ether, ip, mac, srcIp, dstIp, sport, dport, now);
     }
 
diff --git a/tests/unit/java/com/android/server/connectivity/VpnTest.java b/tests/unit/java/com/android/server/connectivity/VpnTest.java
index 2926c9a..395e2bb 100644
--- a/tests/unit/java/com/android/server/connectivity/VpnTest.java
+++ b/tests/unit/java/com/android/server/connectivity/VpnTest.java
@@ -142,6 +142,7 @@
 import android.net.ipsec.ike.exceptions.IkeNonProtocolException;
 import android.net.ipsec.ike.exceptions.IkeProtocolException;
 import android.net.ipsec.ike.exceptions.IkeTimeoutException;
+import android.net.vcn.VcnTransportInfo;
 import android.net.wifi.WifiInfo;
 import android.os.Build.VERSION_CODES;
 import android.os.Bundle;
@@ -1563,6 +1564,11 @@
     }
 
     private NetworkCallback triggerOnAvailableAndGetCallback() throws Exception {
+        return triggerOnAvailableAndGetCallback(new NetworkCapabilities.Builder().build());
+    }
+
+    private NetworkCallback triggerOnAvailableAndGetCallback(
+            @NonNull final NetworkCapabilities caps) throws Exception {
         final ArgumentCaptor<NetworkCallback> networkCallbackCaptor =
                 ArgumentCaptor.forClass(NetworkCallback.class);
         verify(mConnectivityManager, timeout(TEST_TIMEOUT_MS))
@@ -1579,7 +1585,7 @@
         // if NetworkCapabilities and LinkProperties of underlying network will be sent/cleared or
         // not.
         // See verifyVpnManagerEvent().
-        cb.onCapabilitiesChanged(TEST_NETWORK, new NetworkCapabilities());
+        cb.onCapabilitiesChanged(TEST_NETWORK, caps);
         cb.onLinkPropertiesChanged(TEST_NETWORK, new LinkProperties());
         return cb;
     }
@@ -1903,12 +1909,15 @@
 
     private PlatformVpnSnapshot verifySetupPlatformVpn(VpnProfile vpnProfile,
             IkeSessionConfiguration ikeConfig, boolean mtuSupportsIpv6) throws Exception {
-        return verifySetupPlatformVpn(vpnProfile, ikeConfig, mtuSupportsIpv6,
-                false /* areLongLivedTcpConnectionsExpensive */);
+        return verifySetupPlatformVpn(vpnProfile, ikeConfig,
+                new NetworkCapabilities.Builder().build() /* underlying network caps */,
+                mtuSupportsIpv6, false /* areLongLivedTcpConnectionsExpensive */);
     }
 
     private PlatformVpnSnapshot verifySetupPlatformVpn(VpnProfile vpnProfile,
-            IkeSessionConfiguration ikeConfig, boolean mtuSupportsIpv6,
+            IkeSessionConfiguration ikeConfig,
+            @NonNull final NetworkCapabilities underlyingNetworkCaps,
+            boolean mtuSupportsIpv6,
             boolean areLongLivedTcpConnectionsExpensive) throws Exception {
         if (!mtuSupportsIpv6) {
             doReturn(IPV6_MIN_MTU - 1).when(mTestDeps).calculateVpnMtu(any(), anyInt(), anyInt(),
@@ -1925,7 +1934,7 @@
                 .thenReturn(vpnProfile.encode());
 
         vpn.startVpnProfile(TEST_VPN_PKG);
-        final NetworkCallback nwCb = triggerOnAvailableAndGetCallback();
+        final NetworkCallback nwCb = triggerOnAvailableAndGetCallback(underlyingNetworkCaps);
         verify(mExecutor, atLeastOnce()).schedule(any(Runnable.class), anyLong(), any());
         reset(mExecutor);
 
@@ -2079,15 +2088,16 @@
         doTestMigrateIkeSession(ikeProfile.toVpnProfile(),
                 expectedKeepalive,
                 ESP_IP_VERSION_AUTO /* expectedIpVersion */,
-                ESP_ENCAP_TYPE_AUTO /* expectedEncapType */);
+                ESP_ENCAP_TYPE_AUTO /* expectedEncapType */,
+                new NetworkCapabilities.Builder().build());
     }
 
-    private void doTestMigrateIkeSession_FromIkeTunnConnParams(
+    private Ikev2VpnProfile makeIkeV2VpnProfile(
             boolean isAutomaticIpVersionSelectionEnabled,
             boolean isAutomaticNattKeepaliveTimerEnabled,
             int keepaliveInProfile,
             int ipVersionInProfile,
-            int encapTypeInProfile) throws Exception {
+            int encapTypeInProfile) {
         // TODO: Update helper function in IkeSessionTestUtils to support building IkeSessionParams
         // with IP version and encap type when mainline-prod branch support these two APIs.
         final IkeSessionParams params = getTestIkeSessionParams(true /* testIpv6 */,
@@ -2099,12 +2109,40 @@
 
         final IkeTunnelConnectionParams tunnelParams =
                 new IkeTunnelConnectionParams(ikeSessionParams, CHILD_PARAMS);
-        final Ikev2VpnProfile ikeProfile = new Ikev2VpnProfile.Builder(tunnelParams)
+        return new Ikev2VpnProfile.Builder(tunnelParams)
                 .setBypassable(true)
                 .setAutomaticNattKeepaliveTimerEnabled(isAutomaticNattKeepaliveTimerEnabled)
                 .setAutomaticIpVersionSelectionEnabled(isAutomaticIpVersionSelectionEnabled)
                 .build();
+    }
 
+    // Convenience overload: runs the six-argument variant with default-built capabilities.
+    private void doTestMigrateIkeSession_FromIkeTunnConnParams(
+            boolean isAutomaticIpVersionSelectionEnabled,
+            boolean isAutomaticNattKeepaliveTimerEnabled,
+            int keepaliveInProfile, int ipVersionInProfile,
+            int encapTypeInProfile) throws Exception {
+        doTestMigrateIkeSession_FromIkeTunnConnParams(isAutomaticIpVersionSelectionEnabled,
+                isAutomaticNattKeepaliveTimerEnabled, keepaliveInProfile, ipVersionInProfile,
+                encapTypeInProfile, new NetworkCapabilities.Builder().build());
+    }
+
+    private void doTestMigrateIkeSession_FromIkeTunnConnParams(
+            boolean isAutomaticIpVersionSelectionEnabled,
+            boolean isAutomaticNattKeepaliveTimerEnabled,
+            int keepaliveInProfile,
+            int ipVersionInProfile,
+            int encapTypeInProfile,
+            @NonNull final NetworkCapabilities nc) throws Exception {
+        final Ikev2VpnProfile ikeProfile = makeIkeV2VpnProfile(
+                isAutomaticIpVersionSelectionEnabled,
+                isAutomaticNattKeepaliveTimerEnabled,
+                keepaliveInProfile,
+                ipVersionInProfile,
+                encapTypeInProfile);
+
+        final IkeSessionParams ikeSessionParams =
+                ikeProfile.getIkeTunnelConnectionParams().getIkeSessionParams();
         final int expectedKeepalive = isAutomaticNattKeepaliveTimerEnabled
                 ? AUTOMATIC_KEEPALIVE_DELAY_SECONDS
                 : ikeSessionParams.getNattKeepAliveDelaySeconds();
@@ -2115,22 +2153,48 @@
                 ? ESP_ENCAP_TYPE_AUTO
                 : ikeSessionParams.getEncapType();
         doTestMigrateIkeSession(ikeProfile.toVpnProfile(), expectedKeepalive,
-                expectedIpVersion, expectedEncapType);
+                expectedIpVersion, expectedEncapType, nc);
     }
 
-    private void doTestMigrateIkeSession(VpnProfile profile,
-            int expectedKeepalive, int expectedIpVersion, int expectedEncapType) throws Exception {
+    @Test
+    public void doTestMigrateIkeSession_Vcn() throws Exception {
+        // The value given to VcnTransportInfo is also the keepalive the migration is expected
+        // to use; any number that is unlikely to show up by chance will do.
+        final int vcnKeepalive = 2097;
+        final VcnTransportInfo vcnInfo = new VcnTransportInfo(TEST_SUB_ID, vcnKeepalive);
+        final NetworkCapabilities underlyingCaps = new NetworkCapabilities.Builder()
+                .addTransportType(TRANSPORT_CELLULAR)
+                .setTransportInfo(vcnInfo)
+                .build();
+        // Keepalive, IP version and encap type set in the profile should all be ignored.
+        final VpnProfile profile = makeIkeV2VpnProfile(
+                true /* isAutomaticIpVersionSelectionEnabled */,
+                true /* isAutomaticNattKeepaliveTimerEnabled */,
+                234 /* keepaliveInProfile */,
+                ESP_IP_VERSION_IPV4 /* ipVersionInProfile */,
+                ESP_ENCAP_TYPE_UDP /* encapTypeInProfile */).toVpnProfile();
+        doTestMigrateIkeSession(profile, vcnKeepalive,
+                ESP_IP_VERSION_AUTO /* expectedIpVersion */,
+                ESP_ENCAP_TYPE_AUTO /* expectedEncapType */, underlyingCaps);
+    }
+
+    private void doTestMigrateIkeSession(
+            @NonNull final VpnProfile profile,
+            final int expectedKeepalive,
+            final int expectedIpVersion,
+            final int expectedEncapType,
+            @NonNull final NetworkCapabilities caps) throws Exception {
         final PlatformVpnSnapshot vpnSnapShot =
                 verifySetupPlatformVpn(profile,
                         createIkeConfig(createIkeConnectInfo(), true /* isMobikeEnabled */),
+                        caps /* underlying network capabilities */,
                         false /* mtuSupportsIpv6 */,
                         expectedKeepalive < DEFAULT_LONG_LIVED_TCP_CONNS_EXPENSIVE_TIMEOUT_SEC);
         // Simulate a new network coming up
         vpnSnapShot.nwCb.onAvailable(TEST_NETWORK_2);
         verify(mIkeSessionWrapper, never()).setNetwork(any(), anyInt(), anyInt(), anyInt());
 
-        vpnSnapShot.nwCb.onCapabilitiesChanged(
-                TEST_NETWORK_2, new NetworkCapabilities.Builder().build());
+        vpnSnapShot.nwCb.onCapabilitiesChanged(TEST_NETWORK_2, caps);
         // Verify MOBIKE is triggered
         verify(mIkeSessionWrapper, timeout(TEST_TIMEOUT_MS)).setNetwork(TEST_NETWORK_2,
                 expectedIpVersion, expectedEncapType, expectedKeepalive);
@@ -2156,6 +2220,7 @@
         final PlatformVpnSnapshot vpnSnapShot =
                 verifySetupPlatformVpn(ikeProfile.toVpnProfile(),
                         createIkeConfig(createIkeConnectInfo(), true /* isMobikeEnabled */),
+                        new NetworkCapabilities.Builder().build() /* underlying network caps */,
                         hasV6 /* mtuSupportsIpv6 */,
                         false /* areLongLivedTcpConnectionsExpensive */);
         reset(mExecutor);
@@ -2343,6 +2408,7 @@
         final PlatformVpnSnapshot vpnSnapShot =
                 verifySetupPlatformVpn(ikeProfile.toVpnProfile(),
                         createIkeConfig(createIkeConnectInfo(), true /* isMobikeEnabled */),
+                        new NetworkCapabilities.Builder().build() /* underlying network caps */,
                         false /* mtuSupportsIpv6 */,
                         true /* areLongLivedTcpConnectionsExpensive */);
 
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
new file mode 100644
index 0000000..f091eea
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
@@ -0,0 +1,136 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns
+
+import android.net.Network
+import android.os.Build
+import android.os.Handler
+import android.os.HandlerThread
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.TimeUnit
+import kotlin.test.assertNotNull
+import org.junit.After
+import org.junit.Assert.assertEquals
+import org.junit.Assert.assertNull
+import org.junit.Assert.assertTrue
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.Mockito.mock
+
+private const val SERVICE_NAME_1 = "service-instance-1"
+private const val SERVICE_NAME_2 = "service-instance-2"
+private const val SERVICE_TYPE_1 = "_test1._tcp.local"
+private const val SERVICE_TYPE_2 = "_test2._tcp.local"
+private const val INTERFACE_INDEX = 999
+private const val DEFAULT_TIMEOUT_MS = 2000L
+
+// Unit tests for MdnsServiceCache. Every cache call is posted to the handler thread whose
+// looper the cache is constructed with, and the test waits synchronously for the result.
+@RunWith(DevSdkIgnoreRunner::class)
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
+class MdnsServiceCacheTest {
+    private val network = mock(Network::class.java)
+    private val thread = HandlerThread(MdnsServiceCacheTest::class.simpleName)
+    // Created lazily so that thread.looper is first read after setUp() has started the thread.
+    private val handler by lazy { Handler(thread.looper) }
+    private val serviceCache by lazy { MdnsServiceCache(thread.looper) }
+
+    @Before
+    fun setUp() {
+        thread.start()
+    }
+
+    @After
+    fun tearDown() {
+        thread.quitSafely()
+    }
+
+    // Runs functor on the handler thread and waits up to DEFAULT_TIMEOUT_MS for its result.
+    // A functor failure completes the future exceptionally, so get() surfaces it to the caller.
+    private fun <T> runningOnHandlerAndReturn(functor: (() -> T)): T {
+        val future = CompletableFuture<T>()
+        handler.post { runCatching(functor).fold(future::complete, future::completeExceptionally) }
+        return future.get(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS)
+    }
+
+    private fun addOrUpdateService(serviceType: String, network: Network, service: MdnsResponse):
+            Unit = runningOnHandlerAndReturn {
+        serviceCache.addOrUpdateService(serviceType, network, service) }
+
+    private fun removeService(serviceName: String, serviceType: String, network: Network):
+            Unit = runningOnHandlerAndReturn {
+        serviceCache.removeService(serviceName, serviceType, network) }
+
+    private fun getService(serviceName: String, serviceType: String, network: Network):
+            MdnsResponse? = runningOnHandlerAndReturn {
+        serviceCache.getCachedService(serviceName, serviceType, network) }
+
+    private fun getServices(serviceType: String, network: Network): List<MdnsResponse> =
+        runningOnHandlerAndReturn { serviceCache.getCachedServices(serviceType, network) }
+
+    @Test
+    fun testAddAndRemoveService() {
+        addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
+        var response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, network)
+        assertNotNull(response)
+        assertEquals(SERVICE_NAME_1, response.serviceInstanceName)
+        removeService(SERVICE_NAME_1, SERVICE_TYPE_1, network)
+        response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, network)
+        assertNull(response)
+    }
+
+    @Test
+    fun testGetCachedServices_multipleServiceTypes() {
+        addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
+        addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
+        addOrUpdateService(SERVICE_TYPE_2, network, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
+
+        val responses1 = getServices(SERVICE_TYPE_1, network)
+        assertEquals(2, responses1.size)
+        assertTrue(responses1.any { response ->
+            response.serviceInstanceName == SERVICE_NAME_1
+        })
+        assertTrue(responses1.any { response ->
+            response.serviceInstanceName == SERVICE_NAME_2
+        })
+        val responses2 = getServices(SERVICE_TYPE_2, network)
+        assertEquals(1, responses2.size)
+        assertTrue(responses2.any { response ->
+            response.serviceInstanceName == SERVICE_NAME_2
+        })
+
+        // Removing SERVICE_NAME_2 from SERVICE_TYPE_1 must leave the SERVICE_TYPE_2 entry intact.
+        removeService(SERVICE_NAME_2, SERVICE_TYPE_1, network)
+        val responses3 = getServices(SERVICE_TYPE_1, network)
+        assertEquals(1, responses3.size)
+        assertTrue(responses3.any { response ->
+            response.serviceInstanceName == SERVICE_NAME_1
+        })
+        val responses4 = getServices(SERVICE_TYPE_2, network)
+        assertEquals(1, responses4.size)
+        assertTrue(responses4.any { response ->
+            response.serviceInstanceName == SERVICE_NAME_2
+        })
+    }
+
+    private fun createResponse(serviceInstanceName: String, serviceType: String) = MdnsResponse(
+        0 /* now */, "$serviceInstanceName.$serviceType".split(".").toTypedArray(),
+            INTERFACE_INDEX, network)
+}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
index 2d73c98..6f3322b 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
@@ -112,7 +112,7 @@
     private NetworkCallback mNetworkCallback;
     private TetheringEventCallback mTetheringEventCallback;
 
-    private TestNetLinkMonitor mTestSocketNetLinkMonitor;
+    private TestNetlinkMonitor mTestSocketNetLinkMonitor;
     @Before
     public void setUp() throws IOException {
         MockitoAnnotations.initMocks(this);
@@ -147,7 +147,7 @@
         doReturn(mTestSocketNetLinkMonitor).when(mDeps).createSocketNetlinkMonitor(any(), any(),
                 any());
         doAnswer(inv -> {
-            mTestSocketNetLinkMonitor = new TestNetLinkMonitor(inv.getArgument(0),
+            mTestSocketNetLinkMonitor = new TestNetlinkMonitor(inv.getArgument(0),
                     inv.getArgument(1),
                     inv.getArgument(2));
             return mTestSocketNetLinkMonitor;
@@ -174,8 +174,8 @@
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
     }
 
-    private static class TestNetLinkMonitor extends SocketNetlinkMonitor {
-        TestNetLinkMonitor(@NonNull Handler handler,
+    private static class TestNetlinkMonitor extends SocketNetlinkMonitor {
+        TestNetlinkMonitor(@NonNull Handler handler,
                 @NonNull SharedLog log,
                 @Nullable MdnsSocketProvider.NetLinkMonitorCallBack cb) {
             super(handler, log, cb);
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
new file mode 100644
index 0000000..f584ed5
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
@@ -0,0 +1,68 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License")
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns.util
+
+import android.os.Build
+import com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase
+import com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLowerCase
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
+import org.junit.Assert.assertEquals
+import org.junit.Assert.assertFalse
+import org.junit.Assert.assertTrue
+import org.junit.Test
+import org.junit.runner.RunWith
+
+@RunWith(DevSdkIgnoreRunner::class)
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
+class MdnsUtilsTest {
+    @Test
+    fun testToDnsLowerCase() {
+        assertEquals("test", toDnsLowerCase("TEST"))
+        assertEquals("test", toDnsLowerCase("TeSt"))
+        assertEquals("test", toDnsLowerCase("test"))
+        assertEquals("tÉst", toDnsLowerCase("TÉST"))
+        assertEquals("ţést", toDnsLowerCase("ţést"))
+        // Unicode characters 0x10000 (𐀀), 0x10001 (𐀁), 0x10041 (𐁁)
+        // Note the last 2 bytes of 0x10041 are identical to 'A', but it should remain unchanged.
+        assertEquals("test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ",
+                toDnsLowerCase("Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- "))
+        // Also test some characters where the first surrogate is not \ud800
+        assertEquals("test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
+                "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<",
+                toDnsLowerCase("Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
+                        "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"))
+    }
+
+    @Test
+    fun testEqualsIgnoreDnsCase() {
+        assertTrue(equalsIgnoreDnsCase("TEST", "Test"))
+        assertTrue(equalsIgnoreDnsCase("TEST", "test"))
+        assertTrue(equalsIgnoreDnsCase("test", "TeSt"))
+        assertTrue(equalsIgnoreDnsCase("Tést", "tést"))
+        assertFalse(equalsIgnoreDnsCase("ŢÉST", "ţést"))
+        // Unicode characters 0x10000 (𐀀), 0x10001 (𐀁), 0x10041 (𐁁)
+        // Note the last 2 bytes of 0x10041 are identical to 'A', but it should remain unchanged.
+        assertTrue(equalsIgnoreDnsCase("test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ",
+                "Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- "))
+        // Also test some characters where the first surrogate is not \ud800
+        assertTrue(equalsIgnoreDnsCase("test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
+                "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<",
+                "Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
+                        "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"))
+    }
+}