Merge "Add CTS for UrlRequest's getters"
diff --git a/Tethering/src/com/android/networkstack/tethering/Tethering.java b/Tethering/src/com/android/networkstack/tethering/Tethering.java
index e5f644e..b371178 100644
--- a/Tethering/src/com/android/networkstack/tethering/Tethering.java
+++ b/Tethering/src/com/android/networkstack/tethering/Tethering.java
@@ -157,6 +157,7 @@
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.Executor;
@@ -1845,11 +1846,12 @@
 
             setUpstreamNetwork(ns);
             final Network newUpstream = (ns != null) ? ns.network : null;
-            if (mTetherUpstream != newUpstream) {
+            if (!Objects.equals(mTetherUpstream, newUpstream)) {
                 mTetherUpstream = newUpstream;
                 reportUpstreamChanged(mTetherUpstream);
-                // Need to notify capabilities change after upstream network changed because new
-                // network's capabilities should be checked every time.
+                // Need to notify capabilities change after upstream network changed because
+                // upstream may switch to an existing network that doesn't have an
+                // UpstreamNetworkMonitor.EVENT_ON_CAPABILITIES callback.
                 mNotificationUpdater.onUpstreamCapabilitiesChanged(
                         (ns != null) ? ns.networkCapabilities : null);
             }
diff --git a/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java b/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
index 007bf23..a9add20 100644
--- a/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
+++ b/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
@@ -25,17 +25,14 @@
 import static android.net.TetheringManager.CONNECTIVITY_SCOPE_GLOBAL;
 import static android.net.TetheringManager.CONNECTIVITY_SCOPE_LOCAL;
 import static android.net.TetheringManager.TETHERING_ETHERNET;
+import static android.net.TetheringTester.buildTcpPacket;
+import static android.net.TetheringTester.buildUdpPacket;
+import static android.net.TetheringTester.isAddressIpv4;
 import static android.net.TetheringTester.isExpectedIcmpPacket;
 import static android.net.TetheringTester.isExpectedTcpPacket;
 import static android.net.TetheringTester.isExpectedUdpPacket;
-import static android.system.OsConstants.IPPROTO_IP;
-import static android.system.OsConstants.IPPROTO_IPV6;
-import static android.system.OsConstants.IPPROTO_TCP;
-import static android.system.OsConstants.IPPROTO_UDP;
 
 import static com.android.net.module.util.HexDump.dumpHexString;
-import static com.android.net.module.util.NetworkStackConstants.ETHER_TYPE_IPV4;
-import static com.android.net.module.util.NetworkStackConstants.ETHER_TYPE_IPV6;
 import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ROUTER_ADVERTISEMENT;
 import static com.android.net.module.util.NetworkStackConstants.TCPHDR_ACK;
 import static com.android.net.module.util.NetworkStackConstants.TCPHDR_SYN;
@@ -67,11 +64,9 @@
 import android.util.Log;
 
 import androidx.annotation.NonNull;
-import androidx.annotation.Nullable;
 import androidx.test.platform.app.InstrumentationRegistry;
 
 import com.android.modules.utils.build.SdkLevel;
-import com.android.net.module.util.PacketBuilder;
 import com.android.net.module.util.Struct;
 import com.android.net.module.util.structs.Ipv6Header;
 import com.android.testutils.HandlerUtils;
@@ -128,25 +123,11 @@
             (Inet6Address) parseNumericAddress("64:ff9b::808:808");
     protected static final IpPrefix TEST_NAT64PREFIX = new IpPrefix("64:ff9b::/96");
 
-    // IPv4 header definition.
-    protected static final short ID = 27149;
-    protected static final short FLAGS_AND_FRAGMENT_OFFSET = (short) 0x4000; // flags=DF, offset=0
-    protected static final byte TIME_TO_LIVE = (byte) 0x40;
-    protected static final byte TYPE_OF_SERVICE = 0;
-
-    // IPv6 header definition.
-    private static final short HOP_LIMIT = 0x40;
-    // version=6, traffic class=0x0, flowlabel=0x0;
-    private static final int VERSION_TRAFFICCLASS_FLOWLABEL = 0x60000000;
-
-    // UDP and TCP header definition.
     // LOCAL_PORT is used by public port and private port. Assume port 9876 has not been used yet
     // before the testing that public port and private port are the same in the testing. Note that
     // NAT port forwarding could be different between private port and public port.
     protected static final short LOCAL_PORT = 9876;
     protected static final short REMOTE_PORT = 433;
-    private static final short WINDOW = (short) 0x2000;
-    private static final short URGENT_POINTER = 0;
 
     // Payload definition.
     protected static final ByteBuffer EMPTY_PAYLOAD = ByteBuffer.wrap(new byte[0]);
@@ -654,72 +635,6 @@
         return runAsShell(MANAGE_TEST_NETWORKS, () -> initTestNetwork(mContext, lp, TIMEOUT_MS));
     }
 
-    private short getEthType(@NonNull final InetAddress srcIp, @NonNull final InetAddress dstIp) {
-        return isAddressIpv4(srcIp, dstIp) ? (short) ETHER_TYPE_IPV4 : (short) ETHER_TYPE_IPV6;
-    }
-
-    private int getIpProto(@NonNull final InetAddress srcIp, @NonNull final InetAddress dstIp) {
-        return isAddressIpv4(srcIp, dstIp) ? IPPROTO_IP : IPPROTO_IPV6;
-    }
-
-    @NonNull
-    protected ByteBuffer buildUdpPacket(
-            @Nullable final MacAddress srcMac, @Nullable final MacAddress dstMac,
-            @NonNull final InetAddress srcIp, @NonNull final InetAddress dstIp,
-            short srcPort, short dstPort, @Nullable final ByteBuffer payload)
-            throws Exception {
-        final int ipProto = getIpProto(srcIp, dstIp);
-        final boolean hasEther = (srcMac != null && dstMac != null);
-        final int payloadLen = (payload == null) ? 0 : payload.limit();
-        final ByteBuffer buffer = PacketBuilder.allocate(hasEther, ipProto, IPPROTO_UDP,
-                payloadLen);
-        final PacketBuilder packetBuilder = new PacketBuilder(buffer);
-
-        // [1] Ethernet header
-        if (hasEther) {
-            packetBuilder.writeL2Header(srcMac, dstMac, getEthType(srcIp, dstIp));
-        }
-
-        // [2] IP header
-        if (ipProto == IPPROTO_IP) {
-            packetBuilder.writeIpv4Header(TYPE_OF_SERVICE, ID, FLAGS_AND_FRAGMENT_OFFSET,
-                    TIME_TO_LIVE, (byte) IPPROTO_UDP, (Inet4Address) srcIp, (Inet4Address) dstIp);
-        } else {
-            packetBuilder.writeIpv6Header(VERSION_TRAFFICCLASS_FLOWLABEL, (byte) IPPROTO_UDP,
-                    HOP_LIMIT, (Inet6Address) srcIp, (Inet6Address) dstIp);
-        }
-
-        // [3] UDP header
-        packetBuilder.writeUdpHeader(srcPort, dstPort);
-
-        // [4] Payload
-        if (payload != null) {
-            buffer.put(payload);
-            // in case data might be reused by caller, restore the position and
-            // limit of bytebuffer.
-            payload.clear();
-        }
-
-        return packetBuilder.finalizePacket();
-    }
-
-    @NonNull
-    protected ByteBuffer buildUdpPacket(@NonNull final InetAddress srcIp,
-            @NonNull final InetAddress dstIp, short srcPort, short dstPort,
-            @Nullable final ByteBuffer payload) throws Exception {
-        return buildUdpPacket(null /* srcMac */, null /* dstMac */, srcIp, dstIp, srcPort,
-                dstPort, payload);
-    }
-
-    private boolean isAddressIpv4(@NonNull final  InetAddress srcIp,
-            @NonNull final InetAddress dstIp) {
-        if (srcIp instanceof Inet4Address && dstIp instanceof Inet4Address) return true;
-        if (srcIp instanceof Inet6Address && dstIp instanceof Inet6Address) return false;
-
-        fail("Unsupported conditions: srcIp " + srcIp + ", dstIp " + dstIp);
-        return false;  // unreachable
-    }
-
     protected void sendDownloadPacketUdp(@NonNull final InetAddress srcIp,
             @NonNull final InetAddress dstIp, @NonNull final TetheringTester tester,
             boolean is6To4) throws Exception {
@@ -761,45 +676,6 @@
         });
     }
 
-
-    @NonNull
-    private ByteBuffer buildTcpPacket(
-            @Nullable final MacAddress srcMac, @Nullable final MacAddress dstMac,
-            @NonNull final InetAddress srcIp, @NonNull final InetAddress dstIp,
-            short srcPort, short dstPort, final short seq, final short ack,
-            final byte tcpFlags, @NonNull final ByteBuffer payload) throws Exception {
-        final int ipProto = getIpProto(srcIp, dstIp);
-        final boolean hasEther = (srcMac != null && dstMac != null);
-        final ByteBuffer buffer = PacketBuilder.allocate(hasEther, ipProto, IPPROTO_TCP,
-                payload.limit());
-        final PacketBuilder packetBuilder = new PacketBuilder(buffer);
-
-        // [1] Ethernet header
-        if (hasEther) {
-            packetBuilder.writeL2Header(srcMac, dstMac, getEthType(srcIp, dstIp));
-        }
-
-        // [2] IP header
-        if (ipProto == IPPROTO_IP) {
-            packetBuilder.writeIpv4Header(TYPE_OF_SERVICE, ID, FLAGS_AND_FRAGMENT_OFFSET,
-                    TIME_TO_LIVE, (byte) IPPROTO_TCP, (Inet4Address) srcIp, (Inet4Address) dstIp);
-        } else {
-            packetBuilder.writeIpv6Header(VERSION_TRAFFICCLASS_FLOWLABEL, (byte) IPPROTO_TCP,
-                    HOP_LIMIT, (Inet6Address) srcIp, (Inet6Address) dstIp);
-        }
-
-        // [3] TCP header
-        packetBuilder.writeTcpHeader(srcPort, dstPort, seq, ack, tcpFlags, WINDOW, URGENT_POINTER);
-
-        // [4] Payload
-        buffer.put(payload);
-        // in case data might be reused by caller, restore the position and
-        // limit of bytebuffer.
-        payload.clear();
-
-        return packetBuilder.finalizePacket();
-    }
-
     protected void sendDownloadPacketTcp(@NonNull final InetAddress srcIp,
             @NonNull final InetAddress dstIp, short seq, short ack, byte tcpFlags,
             @NonNull final ByteBuffer payload, @NonNull final TetheringTester tester,
diff --git a/Tethering/tests/integration/base/android/net/TetheringTester.java b/Tethering/tests/integration/base/android/net/TetheringTester.java
index 1c0803e..33baf93 100644
--- a/Tethering/tests/integration/base/android/net/TetheringTester.java
+++ b/Tethering/tests/integration/base/android/net/TetheringTester.java
@@ -17,8 +17,12 @@
 package android.net;
 
 import static android.net.InetAddresses.parseNumericAddress;
+import static android.system.OsConstants.ICMP_ECHO;
+import static android.system.OsConstants.ICMP_ECHOREPLY;
 import static android.system.OsConstants.IPPROTO_ICMP;
 import static android.system.OsConstants.IPPROTO_ICMPV6;
+import static android.system.OsConstants.IPPROTO_IP;
+import static android.system.OsConstants.IPPROTO_IPV6;
 import static android.system.OsConstants.IPPROTO_TCP;
 import static android.system.OsConstants.IPPROTO_UDP;
 
@@ -27,6 +31,8 @@
 import static com.android.net.module.util.DnsPacket.NSSECTION;
 import static com.android.net.module.util.DnsPacket.QDSECTION;
 import static com.android.net.module.util.HexDump.dumpHexString;
+import static com.android.net.module.util.IpUtils.icmpChecksum;
+import static com.android.net.module.util.IpUtils.ipChecksum;
 import static com.android.net.module.util.NetworkStackConstants.ARP_REPLY;
 import static com.android.net.module.util.NetworkStackConstants.ARP_REQUEST;
 import static com.android.net.module.util.NetworkStackConstants.ETHER_ADDR_LEN;
@@ -38,6 +44,10 @@
 import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ND_OPTION_TLLA;
 import static com.android.net.module.util.NetworkStackConstants.ICMPV6_NEIGHBOR_SOLICITATION;
 import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ROUTER_ADVERTISEMENT;
+import static com.android.net.module.util.NetworkStackConstants.ICMP_CHECKSUM_OFFSET;
+import static com.android.net.module.util.NetworkStackConstants.IPV4_CHECKSUM_OFFSET;
+import static com.android.net.module.util.NetworkStackConstants.IPV4_HEADER_MIN_LEN;
+import static com.android.net.module.util.NetworkStackConstants.IPV4_LENGTH_OFFSET;
 import static com.android.net.module.util.NetworkStackConstants.IPV6_ADDR_ALL_NODES_MULTICAST;
 import static com.android.net.module.util.NetworkStackConstants.NEIGHBOR_ADVERTISEMENT_FLAG_OVERRIDE;
 import static com.android.net.module.util.NetworkStackConstants.NEIGHBOR_ADVERTISEMENT_FLAG_SOLICITED;
@@ -58,6 +68,7 @@
 
 import com.android.net.module.util.DnsPacket;
 import com.android.net.module.util.Ipv6Utils;
+import com.android.net.module.util.PacketBuilder;
 import com.android.net.module.util.Struct;
 import com.android.net.module.util.structs.EthernetHeader;
 import com.android.net.module.util.structs.Icmpv4Header;
@@ -101,6 +112,23 @@
             DhcpPacket.DHCP_LEASE_TIME,
     };
     private static final InetAddress LINK_LOCAL = parseNumericAddress("fe80::1");
+    // IPv4 header definition.
+    protected static final short ID = 27149;
+    protected static final short FLAGS_AND_FRAGMENT_OFFSET = (short) 0x4000; // flags=DF, offset=0
+    protected static final byte TIME_TO_LIVE = (byte) 0x40;
+    protected static final byte TYPE_OF_SERVICE = 0;
+
+    // IPv6 header definition.
+    private static final short HOP_LIMIT = 0x40;
+    // version=6, traffic class=0x0, flowlabel=0x0;
+    private static final int VERSION_TRAFFICCLASS_FLOWLABEL = 0x60000000;
+
+    // UDP and TCP header definition.
+    private static final short WINDOW = (short) 0x2000;
+    private static final short URGENT_POINTER = 0;
+
+    // ICMP definition.
+    private static final short ICMPECHO_CODE = 0x0;
 
     public static final String DHCP_HOSTNAME = "testhostname";
 
@@ -628,6 +656,190 @@
         return false;
     }
 
+    @NonNull
+    public static ByteBuffer buildUdpPacket(
+            @Nullable final MacAddress srcMac, @Nullable final MacAddress dstMac,
+            @NonNull final InetAddress srcIp, @NonNull final InetAddress dstIp,
+            short srcPort, short dstPort, @Nullable final ByteBuffer payload)
+            throws Exception {
+        final int ipProto = getIpProto(srcIp, dstIp);
+        final boolean hasEther = (srcMac != null && dstMac != null);
+        final int payloadLen = (payload == null) ? 0 : payload.limit();
+        final ByteBuffer buffer = PacketBuilder.allocate(hasEther, ipProto, IPPROTO_UDP,
+                payloadLen);
+        final PacketBuilder packetBuilder = new PacketBuilder(buffer);
+
+        // [1] Ethernet header
+        if (hasEther) {
+            packetBuilder.writeL2Header(srcMac, dstMac, getEthType(srcIp, dstIp));
+        }
+
+        // [2] IP header
+        if (ipProto == IPPROTO_IP) {
+            packetBuilder.writeIpv4Header(TYPE_OF_SERVICE, ID, FLAGS_AND_FRAGMENT_OFFSET,
+                    TIME_TO_LIVE, (byte) IPPROTO_UDP, (Inet4Address) srcIp, (Inet4Address) dstIp);
+        } else {
+            packetBuilder.writeIpv6Header(VERSION_TRAFFICCLASS_FLOWLABEL, (byte) IPPROTO_UDP,
+                    HOP_LIMIT, (Inet6Address) srcIp, (Inet6Address) dstIp);
+        }
+
+        // [3] UDP header
+        packetBuilder.writeUdpHeader(srcPort, dstPort);
+
+        // [4] Payload
+        if (payload != null) {
+            buffer.put(payload);
+            // in case data might be reused by caller, restore the position and
+            // limit of bytebuffer.
+            payload.clear();
+        }
+
+        return packetBuilder.finalizePacket();
+    }
+
+    @NonNull
+    public static ByteBuffer buildUdpPacket(@NonNull final InetAddress srcIp,
+            @NonNull final InetAddress dstIp, short srcPort, short dstPort,
+            @Nullable final ByteBuffer payload) throws Exception {
+        return buildUdpPacket(null /* srcMac */, null /* dstMac */, srcIp, dstIp, srcPort,
+                dstPort, payload);
+    }
+
+    @NonNull
+    public static ByteBuffer buildTcpPacket(
+            @Nullable final MacAddress srcMac, @Nullable final MacAddress dstMac,
+            @NonNull final InetAddress srcIp, @NonNull final InetAddress dstIp,
+            short srcPort, short dstPort, final short seq, final short ack,
+            final byte tcpFlags, @NonNull final ByteBuffer payload) throws Exception {
+        final int ipProto = getIpProto(srcIp, dstIp);
+        final boolean hasEther = (srcMac != null && dstMac != null);
+        final ByteBuffer buffer = PacketBuilder.allocate(hasEther, ipProto, IPPROTO_TCP,
+                payload.limit());
+        final PacketBuilder packetBuilder = new PacketBuilder(buffer);
+
+        // [1] Ethernet header
+        if (hasEther) {
+            packetBuilder.writeL2Header(srcMac, dstMac, getEthType(srcIp, dstIp));
+        }
+
+        // [2] IP header
+        if (ipProto == IPPROTO_IP) {
+            packetBuilder.writeIpv4Header(TYPE_OF_SERVICE, ID, FLAGS_AND_FRAGMENT_OFFSET,
+                    TIME_TO_LIVE, (byte) IPPROTO_TCP, (Inet4Address) srcIp, (Inet4Address) dstIp);
+        } else {
+            packetBuilder.writeIpv6Header(VERSION_TRAFFICCLASS_FLOWLABEL, (byte) IPPROTO_TCP,
+                    HOP_LIMIT, (Inet6Address) srcIp, (Inet6Address) dstIp);
+        }
+
+        // [3] TCP header
+        packetBuilder.writeTcpHeader(srcPort, dstPort, seq, ack, tcpFlags, WINDOW, URGENT_POINTER);
+
+        // [4] Payload
+        buffer.put(payload);
+        // in case data might be reused by caller, restore the position and
+        // limit of bytebuffer.
+        payload.clear();
+
+        return packetBuilder.finalizePacket();
+    }
+
+    // PacketBuilder doesn't support IPv4 ICMP packets. PacketBuilder may need refactoring first
+    // because ICMP is a specific layer 3 protocol for PacketBuilder which expects packets always
+    // have layer 3 (IP) and layer 4 (TCP, UDP) for now. Since we don't use IPv4 ICMP packets too
+    // much in this test, we just write an ICMP packet builder here.
+    @NonNull
+    public static ByteBuffer buildIcmpEchoPacketV4(
+            @Nullable final MacAddress srcMac, @Nullable final MacAddress dstMac,
+            @NonNull final Inet4Address srcIp, @NonNull final Inet4Address dstIp,
+            int type, short id, short seq) throws Exception {
+        if (type != ICMP_ECHO && type != ICMP_ECHOREPLY) {
+            fail("Unsupported ICMP type: " + type);
+        }
+
+        // Build ICMP echo id and seq fields as payload. Ignore the data field.
+        final ByteBuffer payload = ByteBuffer.allocate(4);
+        payload.putShort(id);
+        payload.putShort(seq);
+        payload.rewind();
+
+        final boolean hasEther = (srcMac != null && dstMac != null);
+        final int etherHeaderLen = hasEther ? Struct.getSize(EthernetHeader.class) : 0;
+        final int ipv4HeaderLen = Struct.getSize(Ipv4Header.class);
+        final int Icmpv4HeaderLen = Struct.getSize(Icmpv4Header.class);
+        final int payloadLen = payload.limit();
+        final ByteBuffer packet = ByteBuffer.allocate(etherHeaderLen + ipv4HeaderLen
+                + Icmpv4HeaderLen + payloadLen);
+
+        // [1] Ethernet header
+        if (hasEther) {
+            final EthernetHeader ethHeader = new EthernetHeader(dstMac, srcMac, ETHER_TYPE_IPV4);
+            ethHeader.writeToByteBuffer(packet);
+        }
+
+        // [2] IP header
+        final Ipv4Header ipv4Header = new Ipv4Header(TYPE_OF_SERVICE,
+                (short) 0 /* totalLength, calculate later */, ID,
+                FLAGS_AND_FRAGMENT_OFFSET, TIME_TO_LIVE, (byte) IPPROTO_ICMP,
+                (short) 0 /* checksum, calculate later */, srcIp, dstIp);
+        ipv4Header.writeToByteBuffer(packet);
+
+        // [3] ICMP header
+        final Icmpv4Header icmpv4Header = new Icmpv4Header((byte) type, ICMPECHO_CODE,
+                (short) 0 /* checksum, calculate later */);
+        icmpv4Header.writeToByteBuffer(packet);
+
+        // [4] Payload
+        packet.put(payload);
+        packet.flip();
+
+        // [5] Finalize packet
+        // Used for updating IP header fields. If there is an Ethernet header, IPv4 header offset
+        // in buffer equals ethernet header length because IPv4 header is located next to ethernet
+        // header. Otherwise, IPv4 header offset is 0.
+        final int ipv4HeaderOffset = hasEther ? etherHeaderLen : 0;
+
+        // Populate the IPv4 totalLength field.
+        packet.putShort(ipv4HeaderOffset + IPV4_LENGTH_OFFSET,
+                (short) (ipv4HeaderLen + Icmpv4HeaderLen + payloadLen));
+
+        // Populate the IPv4 header checksum field.
+        packet.putShort(ipv4HeaderOffset + IPV4_CHECKSUM_OFFSET,
+                ipChecksum(packet, ipv4HeaderOffset /* headerOffset */));
+
+        // Populate the ICMP checksum field.
+        packet.putShort(ipv4HeaderOffset + IPV4_HEADER_MIN_LEN + ICMP_CHECKSUM_OFFSET,
+                icmpChecksum(packet, ipv4HeaderOffset + IPV4_HEADER_MIN_LEN,
+                        Icmpv4HeaderLen + payloadLen));
+        return packet;
+    }
+
+    @NonNull
+    public static ByteBuffer buildIcmpEchoPacketV4(@NonNull final Inet4Address srcIp,
+            @NonNull final Inet4Address dstIp, int type, short id, short seq)
+            throws Exception {
+        return buildIcmpEchoPacketV4(null /* srcMac */, null /* dstMac */, srcIp, dstIp,
+                type, id, seq);
+    }
+
+    private static short getEthType(@NonNull final InetAddress srcIp,
+            @NonNull final InetAddress dstIp) {
+        return isAddressIpv4(srcIp, dstIp) ? (short) ETHER_TYPE_IPV4 : (short) ETHER_TYPE_IPV6;
+    }
+
+    private static int getIpProto(@NonNull final InetAddress srcIp,
+            @NonNull final InetAddress dstIp) {
+        return isAddressIpv4(srcIp, dstIp) ? IPPROTO_IP : IPPROTO_IPV6;
+    }
+
+    public static boolean isAddressIpv4(@NonNull final  InetAddress srcIp,
+            @NonNull final InetAddress dstIp) {
+        if (srcIp instanceof Inet4Address && dstIp instanceof Inet4Address) return true;
+        if (srcIp instanceof Inet6Address && dstIp instanceof Inet6Address) return false;
+
+        fail("Unsupported conditions: srcIp " + srcIp + ", dstIp " + dstIp);
+        return false;  // unreachable
+    }
+
     public void sendUploadPacket(ByteBuffer packet) throws Exception {
         mDownstreamReader.sendResponse(packet);
     }
diff --git a/Tethering/tests/integration/src/android/net/EthernetTetheringTest.java b/Tethering/tests/integration/src/android/net/EthernetTetheringTest.java
index 5d57aa5..eed308c 100644
--- a/Tethering/tests/integration/src/android/net/EthernetTetheringTest.java
+++ b/Tethering/tests/integration/src/android/net/EthernetTetheringTest.java
@@ -20,23 +20,17 @@
 import static android.net.TetheringManager.CONNECTIVITY_SCOPE_LOCAL;
 import static android.net.TetheringManager.TETHERING_ETHERNET;
 import static android.net.TetheringTester.TestDnsPacket;
+import static android.net.TetheringTester.buildIcmpEchoPacketV4;
+import static android.net.TetheringTester.buildUdpPacket;
 import static android.net.TetheringTester.isExpectedIcmpPacket;
 import static android.net.TetheringTester.isExpectedUdpDnsPacket;
 import static android.system.OsConstants.ICMP_ECHO;
 import static android.system.OsConstants.ICMP_ECHOREPLY;
-import static android.system.OsConstants.IPPROTO_ICMP;
 
 import static com.android.net.module.util.ConnectivityUtils.isIPv6ULA;
 import static com.android.net.module.util.HexDump.dumpHexString;
-import static com.android.net.module.util.IpUtils.icmpChecksum;
-import static com.android.net.module.util.IpUtils.ipChecksum;
-import static com.android.net.module.util.NetworkStackConstants.ETHER_TYPE_IPV4;
 import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ECHO_REPLY_TYPE;
 import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ECHO_REQUEST_TYPE;
-import static com.android.net.module.util.NetworkStackConstants.ICMP_CHECKSUM_OFFSET;
-import static com.android.net.module.util.NetworkStackConstants.IPV4_CHECKSUM_OFFSET;
-import static com.android.net.module.util.NetworkStackConstants.IPV4_HEADER_MIN_LEN;
-import static com.android.net.module.util.NetworkStackConstants.IPV4_LENGTH_OFFSET;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
@@ -53,14 +47,11 @@
 import android.util.Log;
 
 import androidx.annotation.NonNull;
-import androidx.annotation.Nullable;
 import androidx.test.filters.MediumTest;
 import androidx.test.runner.AndroidJUnit4;
 
 import com.android.net.module.util.Ipv6Utils;
 import com.android.net.module.util.Struct;
-import com.android.net.module.util.structs.EthernetHeader;
-import com.android.net.module.util.structs.Icmpv4Header;
 import com.android.net.module.util.structs.Ipv4Header;
 import com.android.net.module.util.structs.UdpHeader;
 import com.android.testutils.DevSdkIgnoreRule;
@@ -96,7 +87,6 @@
     private static final String TAG = EthernetTetheringTest.class.getSimpleName();
 
     private static final short DNS_PORT = 53;
-    private static final short ICMPECHO_CODE = 0x0;
     private static final short ICMPECHO_ID = 0x0;
     private static final short ICMPECHO_SEQ = 0x0;
 
@@ -564,85 +554,6 @@
         runClatUdpTest();
     }
 
-    // PacketBuilder doesn't support IPv4 ICMP packet. It may need to refactor PacketBuilder first
-    // because ICMP is a specific layer 3 protocol for PacketBuilder which expects packets always
-    // have layer 3 (IP) and layer 4 (TCP, UDP) for now. Since we don't use IPv4 ICMP packet too
-    // much in this test, we just write a ICMP packet builder here.
-    // TODO: move ICMPv4 packet build function to common utilis.
-    @NonNull
-    private ByteBuffer buildIcmpEchoPacketV4(
-            @Nullable final MacAddress srcMac, @Nullable final MacAddress dstMac,
-            @NonNull final Inet4Address srcIp, @NonNull final Inet4Address dstIp,
-            int type, short id, short seq) throws Exception {
-        if (type != ICMP_ECHO && type != ICMP_ECHOREPLY) {
-            fail("Unsupported ICMP type: " + type);
-        }
-
-        // Build ICMP echo id and seq fields as payload. Ignore the data field.
-        final ByteBuffer payload = ByteBuffer.allocate(4);
-        payload.putShort(id);
-        payload.putShort(seq);
-        payload.rewind();
-
-        final boolean hasEther = (srcMac != null && dstMac != null);
-        final int etherHeaderLen = hasEther ? Struct.getSize(EthernetHeader.class) : 0;
-        final int ipv4HeaderLen = Struct.getSize(Ipv4Header.class);
-        final int Icmpv4HeaderLen = Struct.getSize(Icmpv4Header.class);
-        final int payloadLen = payload.limit();
-        final ByteBuffer packet = ByteBuffer.allocate(etherHeaderLen + ipv4HeaderLen
-                + Icmpv4HeaderLen + payloadLen);
-
-        // [1] Ethernet header
-        if (hasEther) {
-            final EthernetHeader ethHeader = new EthernetHeader(dstMac, srcMac, ETHER_TYPE_IPV4);
-            ethHeader.writeToByteBuffer(packet);
-        }
-
-        // [2] IP header
-        final Ipv4Header ipv4Header = new Ipv4Header(TYPE_OF_SERVICE,
-                (short) 0 /* totalLength, calculate later */, ID,
-                FLAGS_AND_FRAGMENT_OFFSET, TIME_TO_LIVE, (byte) IPPROTO_ICMP,
-                (short) 0 /* checksum, calculate later */, srcIp, dstIp);
-        ipv4Header.writeToByteBuffer(packet);
-
-        // [3] ICMP header
-        final Icmpv4Header icmpv4Header = new Icmpv4Header((byte) type, ICMPECHO_CODE,
-                (short) 0 /* checksum, calculate later */);
-        icmpv4Header.writeToByteBuffer(packet);
-
-        // [4] Payload
-        packet.put(payload);
-        packet.flip();
-
-        // [5] Finalize packet
-        // Used for updating IP header fields. If there is Ehternet header, IPv4 header offset
-        // in buffer equals ethernet header length because IPv4 header is located next to ethernet
-        // header. Otherwise, IPv4 header offset is 0.
-        final int ipv4HeaderOffset = hasEther ? etherHeaderLen : 0;
-
-        // Populate the IPv4 totalLength field.
-        packet.putShort(ipv4HeaderOffset + IPV4_LENGTH_OFFSET,
-                (short) (ipv4HeaderLen + Icmpv4HeaderLen + payloadLen));
-
-        // Populate the IPv4 header checksum field.
-        packet.putShort(ipv4HeaderOffset + IPV4_CHECKSUM_OFFSET,
-                ipChecksum(packet, ipv4HeaderOffset /* headerOffset */));
-
-        // Populate the ICMP checksum field.
-        packet.putShort(ipv4HeaderOffset + IPV4_HEADER_MIN_LEN + ICMP_CHECKSUM_OFFSET,
-                icmpChecksum(packet, ipv4HeaderOffset + IPV4_HEADER_MIN_LEN,
-                        Icmpv4HeaderLen + payloadLen));
-        return packet;
-    }
-
-    @NonNull
-    private ByteBuffer buildIcmpEchoPacketV4(@NonNull final Inet4Address srcIp,
-            @NonNull final Inet4Address dstIp, int type, short id, short seq)
-            throws Exception {
-        return buildIcmpEchoPacketV4(null /* srcMac */, null /* dstMac */, srcIp, dstIp,
-                type, id, seq);
-    }
-
     @Test
     public void testIcmpv4Echo() throws Exception {
         final TetheringTester tester = initTetheringTester(toList(TEST_IP4_ADDR),
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
index bd8b325..9c6904d 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
@@ -240,6 +240,7 @@
     private static final String TEST_RNDIS_IFNAME = "test_rndis0";
     private static final String TEST_WIFI_IFNAME = "test_wlan0";
     private static final String TEST_WLAN_IFNAME = "test_wlan1";
+    private static final String TEST_WLAN2_IFNAME = "test_wlan2";
     private static final String TEST_P2P_IFNAME = "test_p2p-p2p0-0";
     private static final String TEST_NCM_IFNAME = "test_ncm0";
     private static final String TEST_ETH_IFNAME = "test_eth0";
@@ -392,6 +393,7 @@
             assertTrue("Non-mocked interface " + ifName,
                     ifName.equals(TEST_RNDIS_IFNAME)
                             || ifName.equals(TEST_WLAN_IFNAME)
+                            || ifName.equals(TEST_WLAN2_IFNAME)
                             || ifName.equals(TEST_WIFI_IFNAME)
                             || ifName.equals(TEST_MOBILE_IFNAME)
                             || ifName.equals(TEST_DUN_IFNAME)
@@ -400,8 +402,9 @@
                             || ifName.equals(TEST_ETH_IFNAME)
                             || ifName.equals(TEST_BT_IFNAME));
             final String[] ifaces = new String[] {
-                    TEST_RNDIS_IFNAME, TEST_WLAN_IFNAME, TEST_WIFI_IFNAME, TEST_MOBILE_IFNAME,
-                    TEST_DUN_IFNAME, TEST_P2P_IFNAME, TEST_NCM_IFNAME, TEST_ETH_IFNAME};
+                    TEST_RNDIS_IFNAME, TEST_WLAN_IFNAME, TEST_WLAN2_IFNAME, TEST_WIFI_IFNAME,
+                    TEST_MOBILE_IFNAME, TEST_DUN_IFNAME, TEST_P2P_IFNAME, TEST_NCM_IFNAME,
+                    TEST_ETH_IFNAME};
             return new InterfaceParams(ifName,
                     CollectionUtils.indexOf(ifaces, ifName) + IFINDEX_OFFSET,
                     MacAddress.ALL_ZEROS_ADDRESS);
@@ -428,7 +431,7 @@
 
     public class MockTetheringDependencies extends TetheringDependencies {
         StateMachine mUpstreamNetworkMonitorSM;
-        ArrayList<IpServer> mIpv6CoordinatorNotifyList;
+        ArrayList<IpServer> mAllDownstreams;
 
         @Override
         public BpfCoordinator getBpfCoordinator(
@@ -463,7 +466,7 @@
         @Override
         public IPv6TetheringCoordinator getIPv6TetheringCoordinator(
                 ArrayList<IpServer> notifyList, SharedLog log) {
-            mIpv6CoordinatorNotifyList = notifyList;
+            mAllDownstreams = notifyList;
             return mIPv6TetheringCoordinator;
         }
 
@@ -642,8 +645,8 @@
                 false);
         when(mNetd.interfaceGetList())
                 .thenReturn(new String[] {
-                        TEST_MOBILE_IFNAME, TEST_WLAN_IFNAME, TEST_RNDIS_IFNAME, TEST_P2P_IFNAME,
-                        TEST_NCM_IFNAME, TEST_ETH_IFNAME, TEST_BT_IFNAME});
+                        TEST_MOBILE_IFNAME, TEST_WLAN_IFNAME, TEST_WLAN2_IFNAME, TEST_RNDIS_IFNAME,
+                        TEST_P2P_IFNAME, TEST_NCM_IFNAME, TEST_ETH_IFNAME, TEST_BT_IFNAME});
         when(mResources.getString(R.string.config_wifi_tether_enable)).thenReturn("");
         mInterfaceConfiguration = new InterfaceConfigurationParcel();
         mInterfaceConfiguration.flags = new String[0];
@@ -1026,7 +1029,7 @@
      */
     private void sendIPv6TetherUpdates(UpstreamNetworkState upstreamState) {
         // IPv6TetheringCoordinator must have been notified of downstream
-        for (IpServer ipSrv : mTetheringDependencies.mIpv6CoordinatorNotifyList) {
+        for (IpServer ipSrv : mTetheringDependencies.mAllDownstreams) {
             UpstreamNetworkState ipv6OnlyState = buildMobileUpstreamState(false, true, false);
             ipSrv.sendMessage(IpServer.CMD_IPV6_TETHER_UPDATE, 0, 0,
                     upstreamState.linkProperties.isIpv6Provisioned()
@@ -2682,10 +2685,9 @@
         final UpstreamNetworkState upstreamState2 = buildMobileIPv4UpstreamState();
         initTetheringUpstream(upstreamState2);
         stateMachine.chooseUpstreamType(true);
-        // Bug: duplicated upstream change event.
-        mTetheringEventCallback.expectUpstreamChanged(upstreamState2.network);
-        inOrder.verify(mNotificationUpdater)
-                .onUpstreamCapabilitiesChanged(upstreamState2.networkCapabilities);
+        // Expect neither an upstream change event nor a capabilities changed event.
+        mTetheringEventCallback.assertNoUpstreamChangeCallback();
+        inOrder.verify(mNotificationUpdater, never()).onUpstreamCapabilitiesChanged(any());
 
         // Set the upstream with the same network ID but different object and different capability.
         final UpstreamNetworkState upstreamState3 = buildMobileIPv4UpstreamState();
@@ -2693,11 +2695,13 @@
         upstreamState3.networkCapabilities.addCapability(NET_CAPABILITY_VALIDATED);
         initTetheringUpstream(upstreamState3);
         stateMachine.chooseUpstreamType(true);
-        // Bug: duplicated upstream change event.
-        mTetheringEventCallback.expectUpstreamChanged(upstreamState3.network);
+        // Expect no upstream change event; capabilities change arrives via EVENT_ON_CAPABILITIES.
+        mTetheringEventCallback.assertNoUpstreamChangeCallback();
+        stateMachine.handleUpstreamNetworkMonitorCallback(EVENT_ON_CAPABILITIES, upstreamState3);
         inOrder.verify(mNotificationUpdater)
                 .onUpstreamCapabilitiesChanged(upstreamState3.networkCapabilities);
 
+
         // Lose upstream.
         initTetheringUpstream(null);
         stateMachine.chooseUpstreamType(true);
@@ -3046,6 +3050,58 @@
         callback.expectTetheredClientChanged(Collections.emptyList());
     }
 
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.S_V2)
+    public void testConnectedClientsForSapAndLohsConcurrency() throws Exception {
+        TestTetheringEventCallback callback = new TestTetheringEventCallback();
+        runAsShell(NETWORK_SETTINGS, () -> {
+            mTethering.registerTetheringEventCallback(callback);
+            mLooper.dispatchAll();
+        });
+        callback.expectTetheredClientChanged(Collections.emptyList());
+
+        mTethering.interfaceStatusChanged(TEST_WLAN_IFNAME, true);
+        sendWifiApStateChanged(WIFI_AP_STATE_ENABLED, TEST_WLAN_IFNAME, IFACE_IP_MODE_TETHERED);
+        final ArgumentCaptor<IDhcpEventCallbacks> dhcpEventCbsCaptor =
+                 ArgumentCaptor.forClass(IDhcpEventCallbacks.class);
+        verify(mDhcpServer, timeout(DHCPSERVER_START_TIMEOUT_MS)).startWithCallbacks(
+                any(), dhcpEventCbsCaptor.capture());
+        IDhcpEventCallbacks eventCallbacks = dhcpEventCbsCaptor.getValue();
+        final List<TetheredClient> connectedClients = new ArrayList<>();
+        final MacAddress wifiMac = MacAddress.fromString("11:11:11:11:11:11");
+        final DhcpLeaseParcelable wifiLease = createDhcpLeaseParcelable("clientId", wifiMac,
+                "192.168.2.12", 24, Long.MAX_VALUE, "test");
+        verifyHotspotClientUpdate(false /* isLocalOnly */, wifiMac, wifiLease, connectedClients,
+                eventCallbacks, callback);
+        reset(mDhcpServer);
+
+        mTethering.interfaceStatusChanged(TEST_WLAN2_IFNAME, true);
+        sendWifiApStateChanged(WIFI_AP_STATE_ENABLED, TEST_WLAN2_IFNAME, IFACE_IP_MODE_LOCAL_ONLY);
+
+        verify(mDhcpServer, timeout(DHCPSERVER_START_TIMEOUT_MS)).startWithCallbacks(
+                any(), dhcpEventCbsCaptor.capture());
+        eventCallbacks = dhcpEventCbsCaptor.getValue();
+        final MacAddress localOnlyMac = MacAddress.fromString("22:22:22:22:22:22");
+        final DhcpLeaseParcelable localOnlyLease = createDhcpLeaseParcelable("clientId",
+                localOnlyMac, "192.168.43.24", 24, Long.MAX_VALUE, "test");
+        verifyHotspotClientUpdate(true /* isLocalOnly */, localOnlyMac, localOnlyLease,
+                connectedClients, eventCallbacks, callback);
+
+        assertTrue(isIpServerActive(TETHERING_WIFI, TEST_WLAN_IFNAME, IpServer.STATE_TETHERED));
+        assertTrue(isIpServerActive(TETHERING_WIFI, TEST_WLAN2_IFNAME, IpServer.STATE_LOCAL_ONLY));
+    }
+
+    private boolean isIpServerActive(int type, String ifName, int mode) {
+        for (IpServer ipSrv : mTetheringDependencies.mAllDownstreams) {
+            if (ipSrv.interfaceType() == type && ipSrv.interfaceName().equals(ifName)
+                    && ipSrv.servingMode() == mode) {
+                return true;
+            }
+        }
+
+        return false;
+    }
+
     private void verifyHotspotClientUpdate(final boolean isLocalOnly, final MacAddress testMac,
             final DhcpLeaseParcelable dhcpLease, final List<TetheredClient> currentClients,
             final IDhcpEventCallbacks dhcpCallback, final TestTetheringEventCallback callback)
@@ -3486,6 +3542,29 @@
         assertEquals(expectedTypes.size() > 0, mTethering.isTetheringSupported());
         callback.expectSupportedTetheringTypes(expectedTypes);
     }
+
+    @Test
+    public void testIpv4AddressForSapAndLohsConcurrency() throws Exception {
+        mTethering.interfaceStatusChanged(TEST_WLAN_IFNAME, true);
+        sendWifiApStateChanged(WIFI_AP_STATE_ENABLED, TEST_WLAN_IFNAME, IFACE_IP_MODE_TETHERED);
+
+        ArgumentCaptor<InterfaceConfigurationParcel> ifaceConfigCaptor =
+                ArgumentCaptor.forClass(InterfaceConfigurationParcel.class);
+        verify(mNetd).interfaceSetCfg(ifaceConfigCaptor.capture());
+        InterfaceConfigurationParcel ifaceConfig = ifaceConfigCaptor.getValue();
+        final IpPrefix sapPrefix = new IpPrefix(
+                InetAddresses.parseNumericAddress(ifaceConfig.ipv4Addr), ifaceConfig.prefixLength);
+
+        mTethering.interfaceStatusChanged(TEST_WLAN2_IFNAME, true);
+        sendWifiApStateChanged(WIFI_AP_STATE_ENABLED, TEST_WLAN2_IFNAME, IFACE_IP_MODE_LOCAL_ONLY);
+
+        ifaceConfigCaptor = ArgumentCaptor.forClass(InterfaceConfigurationParcel.class);
+        verify(mNetd, times(2)).interfaceSetCfg(ifaceConfigCaptor.capture());
+        ifaceConfig = ifaceConfigCaptor.getValue();
+        final IpPrefix lohsPrefix = new IpPrefix(
+                InetAddresses.parseNumericAddress(ifaceConfig.ipv4Addr), ifaceConfig.prefixLength);
+        assertFalse(sapPrefix.equals(lohsPrefix));
+    }
     // TODO: Test that a request for hotspot mode doesn't interfere with an
     // already operating tethering mode interface.
 }
diff --git a/bpf_progs/block.c b/bpf_progs/block.c
index 3797a38..d734b74 100644
--- a/bpf_progs/block.c
+++ b/bpf_progs/block.c
@@ -71,3 +71,4 @@
 
 LICENSE("Apache 2.0");
 CRITICAL("ConnectivityNative");
+DISABLE_BTF_ON_USER_BUILDS();
diff --git a/bpf_progs/clatd.c b/bpf_progs/clatd.c
index 85ba58e..905b8fa 100644
--- a/bpf_progs/clatd.c
+++ b/bpf_progs/clatd.c
@@ -416,3 +416,4 @@
 
 LICENSE("Apache 2.0");
 CRITICAL("Connectivity");
+DISABLE_BTF_ON_USER_BUILDS();
diff --git a/bpf_progs/dscpPolicy.c b/bpf_progs/dscpPolicy.c
index 262b65b..88b50b5 100644
--- a/bpf_progs/dscpPolicy.c
+++ b/bpf_progs/dscpPolicy.c
@@ -238,3 +238,4 @@
 
 LICENSE("Apache 2.0");
 CRITICAL("Connectivity");
+DISABLE_BTF_ON_USER_BUILDS();
diff --git a/bpf_progs/netd.c b/bpf_progs/netd.c
index 245ad7a..74b09e7 100644
--- a/bpf_progs/netd.c
+++ b/bpf_progs/netd.c
@@ -587,3 +587,4 @@
 
 LICENSE("Apache 2.0");
 CRITICAL("Connectivity and netd");
+DISABLE_BTF_ON_USER_BUILDS();
diff --git a/bpf_progs/offload.c b/bpf_progs/offload.c
index 80d1a41..8645dd7 100644
--- a/bpf_progs/offload.c
+++ b/bpf_progs/offload.c
@@ -863,3 +863,4 @@
 
 LICENSE("Apache 2.0");
 CRITICAL("Connectivity (Tethering)");
+DISABLE_BTF_ON_USER_BUILDS();
diff --git a/bpf_progs/test.c b/bpf_progs/test.c
index 091743c..68469c8 100644
--- a/bpf_progs/test.c
+++ b/bpf_progs/test.c
@@ -67,3 +67,4 @@
 }
 
 LICENSE("Apache 2.0");
+DISABLE_BTF_ON_USER_BUILDS();
diff --git a/common/Android.bp b/common/Android.bp
index 729ef32..ff4de11 100644
--- a/common/Android.bp
+++ b/common/Android.bp
@@ -25,7 +25,7 @@
         "src/com/android/net/module/util/bpf/*.java",
     ],
     sdk_version: "module_current",
-    min_sdk_version: "29",
+    min_sdk_version: "30",
     visibility: [
         // Do not add any lib. This library is only shared inside connectivity module
         // and its tests.
diff --git a/framework/src/android/net/NattKeepalivePacketData.java b/framework/src/android/net/NattKeepalivePacketData.java
index a18e713..9e6d80d 100644
--- a/framework/src/android/net/NattKeepalivePacketData.java
+++ b/framework/src/android/net/NattKeepalivePacketData.java
@@ -29,7 +29,9 @@
 import com.android.net.module.util.IpUtils;
 
 import java.net.Inet4Address;
+import java.net.Inet6Address;
 import java.net.InetAddress;
+import java.net.UnknownHostException;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.util.Objects;
@@ -38,6 +40,7 @@
 @SystemApi
 public final class NattKeepalivePacketData extends KeepalivePacketData implements Parcelable {
     private static final int IPV4_HEADER_LENGTH = 20;
+    private static final int IPV6_HEADER_LENGTH = 40;
     private static final int UDP_HEADER_LENGTH = 8;
 
     // This should only be constructed via static factory methods, such as
@@ -59,13 +62,25 @@
             throw new InvalidPacketException(ERROR_INVALID_PORT);
         }
 
-        if (!(srcAddress instanceof Inet4Address) || !(dstAddress instanceof Inet4Address)) {
+        // Convert IPv4-mapped IPv6 addresses to IPv4, if any.
+        final InetAddress srcAddr, dstAddr;
+        try {
+            srcAddr = InetAddress.getByAddress(srcAddress.getAddress());
+            dstAddr = InetAddress.getByAddress(dstAddress.getAddress());
+        } catch (UnknownHostException e) {
             throw new InvalidPacketException(ERROR_INVALID_IP_ADDRESS);
         }
 
-        return nattKeepalivePacketv4(
-                (Inet4Address) srcAddress, srcPort,
-                (Inet4Address) dstAddress, dstPort);
+        if (srcAddr instanceof Inet4Address && dstAddr instanceof Inet4Address) {
+            return nattKeepalivePacketv4(
+                    (Inet4Address) srcAddr, srcPort, (Inet4Address) dstAddr, dstPort);
+        } else if (srcAddr instanceof Inet6Address && dstAddr instanceof Inet6Address) {
+            return nattKeepalivePacketv6(
+                    (Inet6Address) srcAddr, srcPort, (Inet6Address) dstAddr, dstPort);
+        } else {
+            // Destination address and source address must be of the same IP family.
+            throw new InvalidPacketException(ERROR_INVALID_IP_ADDRESS);
+        }
     }
 
     private static NattKeepalivePacketData nattKeepalivePacketv4(
@@ -82,14 +97,14 @@
         // /proc/sys/net/ipv4/ip_default_ttl. Use hard-coded 64 for simplicity.
         buf.put((byte) 64);                                 // TTL
         buf.put((byte) OsConstants.IPPROTO_UDP);
-        int ipChecksumOffset = buf.position();
+        final int ipChecksumOffset = buf.position();
         buf.putShort((short) 0);                            // IP checksum
         buf.put(srcAddress.getAddress());
         buf.put(dstAddress.getAddress());
         buf.putShort((short) srcPort);
         buf.putShort((short) dstPort);
         buf.putShort((short) (UDP_HEADER_LENGTH + 1));      // UDP length
-        int udpChecksumOffset = buf.position();
+        final int udpChecksumOffset = buf.position();
         buf.putShort((short) 0);                            // UDP checksum
         buf.put((byte) 0xff);                               // NAT-T keepalive
         buf.putShort(ipChecksumOffset, IpUtils.ipChecksum(buf, 0));
@@ -98,6 +113,30 @@
         return new NattKeepalivePacketData(srcAddress, srcPort, dstAddress, dstPort, buf.array());
     }
 
+    private static NattKeepalivePacketData nattKeepalivePacketv6(
+            Inet6Address srcAddress, int srcPort, Inet6Address dstAddress, int dstPort)
+            throws InvalidPacketException {
+        final ByteBuffer buf = ByteBuffer.allocate(IPV6_HEADER_LENGTH + UDP_HEADER_LENGTH + 1);
+        buf.order(ByteOrder.BIG_ENDIAN);
+        buf.putInt(0x60000000);                         // IP version, traffic class and flow label
+        buf.putShort((short) (UDP_HEADER_LENGTH + 1));  // Payload length
+        buf.put((byte) OsConstants.IPPROTO_UDP);        // Next header
+        // For native ipv6, this hop limit value should use the per interface v6 hoplimit sysctl.
+        // For 464xlat, this value should use the v4 ttl sysctl.
+        // Either way, for simplicity, just hard code 64.
+        buf.put((byte) 64);                             // Hop limit
+        buf.put(srcAddress.getAddress());
+        buf.put(dstAddress.getAddress());
+        // UDP
+        buf.putShort((short) srcPort);
+        buf.putShort((short) dstPort);
+        buf.putShort((short) (UDP_HEADER_LENGTH + 1));  // UDP length = Payload length
+        final int udpChecksumOffset = buf.position();
+        buf.putShort((short) 0);                        // UDP checksum
+        buf.put((byte) 0xff);                           // NAT-T keepalive. 1 byte of data
+        buf.putShort(udpChecksumOffset, IpUtils.udpChecksum(buf, 0, IPV6_HEADER_LENGTH));
+        return new NattKeepalivePacketData(srcAddress, srcPort, dstAddress, dstPort, buf.array());
+    }
     /** Parcelable Implementation */
     public int describeContents() {
         return 0;
diff --git a/netd/BpfHandler.cpp b/netd/BpfHandler.cpp
index 3984249..d239277 100644
--- a/netd/BpfHandler.cpp
+++ b/netd/BpfHandler.cpp
@@ -52,25 +52,25 @@
 static Status attachProgramToCgroup(const char* programPath, const unique_fd& cgroupFd,
                                     bpf_attach_type type) {
     unique_fd cgroupProg(retrieveProgram(programPath));
-    if (cgroupProg == -1) {
-        int ret = errno;
-        ALOGE("Failed to get program from %s: %s", programPath, strerror(ret));
-        return statusFromErrno(ret, "cgroup program get failed");
+    if (!cgroupProg.ok()) {
+        const int err = errno;
+        ALOGE("Failed to get program from %s: %s", programPath, strerror(err));
+        return statusFromErrno(err, "cgroup program get failed");
     }
     if (android::bpf::attachProgram(type, cgroupProg, cgroupFd)) {
-        int ret = errno;
-        ALOGE("Program from %s attach failed: %s", programPath, strerror(ret));
-        return statusFromErrno(ret, "program attach failed");
+        const int err = errno;
+        ALOGE("Program from %s attach failed: %s", programPath, strerror(err));
+        return statusFromErrno(err, "program attach failed");
     }
     return netdutils::status::ok;
 }
 
 static Status checkProgramAccessible(const char* programPath) {
     unique_fd prog(retrieveProgram(programPath));
-    if (prog == -1) {
-        int ret = errno;
-        ALOGE("Failed to get program from %s: %s", programPath, strerror(ret));
-        return statusFromErrno(ret, "program retrieve failed");
+    if (!prog.ok()) {
+        const int err = errno;
+        ALOGE("Failed to get program from %s: %s", programPath, strerror(err));
+        return statusFromErrno(err, "program retrieve failed");
     }
     return netdutils::status::ok;
 }
@@ -79,10 +79,10 @@
     if (modules::sdklevel::IsAtLeastU() && !!strcmp(cg2_path, "/sys/fs/cgroup")) abort();
 
     unique_fd cg_fd(open(cg2_path, O_DIRECTORY | O_RDONLY | O_CLOEXEC));
-    if (cg_fd == -1) {
-        const int ret = errno;
-        ALOGE("Failed to open the cgroup directory: %s", strerror(ret));
-        return statusFromErrno(ret, "Open the cgroup directory failed");
+    if (!cg_fd.ok()) {
+        const int err = errno;
+        ALOGE("Failed to open the cgroup directory: %s", strerror(err));
+        return statusFromErrno(err, "Open the cgroup directory failed");
     }
     RETURN_IF_NOT_OK(checkProgramAccessible(XT_BPF_ALLOWLIST_PROG_PATH));
     RETURN_IF_NOT_OK(checkProgramAccessible(XT_BPF_DENYLIST_PROG_PATH));
diff --git a/service-t/native/libs/libnetworkstats/NetworkTracePoller.cpp b/service-t/native/libs/libnetworkstats/NetworkTracePoller.cpp
index d538368..80c315a 100644
--- a/service-t/native/libs/libnetworkstats/NetworkTracePoller.cpp
+++ b/service-t/native/libs/libnetworkstats/NetworkTracePoller.cpp
@@ -29,16 +29,16 @@
 namespace bpf {
 namespace internal {
 
-void NetworkTracePoller::SchedulePolling() {
-  // Schedules another run of ourselves to recursively poll periodically.
-  mTaskRunner->PostDelayedTask(
-      [this]() {
-        mMutex.lock();
-        SchedulePolling();
-        ConsumeAllLocked();
-        mMutex.unlock();
-      },
-      mPollMs);
+void NetworkTracePoller::PollAndSchedule(perfetto::base::TaskRunner* runner,
+                                         uint32_t poll_ms) {
+  // Always schedule another run of ourselves to recursively poll periodically.
+  // The task runner is sequential so these can't run on top of each other.
+  runner->PostDelayedTask([=]() { PollAndSchedule(runner, poll_ms); }, poll_ms);
+
+  if (mMutex.try_lock()) {
+    ConsumeAllLocked();
+    mMutex.unlock();
+  }
 }
 
 bool NetworkTracePoller::Start(uint32_t pollMs) {
@@ -81,7 +81,7 @@
   // Start a task runner to run ConsumeAll every mPollMs milliseconds.
   mTaskRunner = perfetto::Platform::GetDefaultPlatform()->CreateTaskRunner({});
   mPollMs = pollMs;
-  SchedulePolling();
+  PollAndSchedule(mTaskRunner.get(), mPollMs);
 
   mSessionCount++;
   return true;
diff --git a/service-t/native/libs/libnetworkstats/include/netdbpf/NetworkTracePoller.h b/service-t/native/libs/libnetworkstats/include/netdbpf/NetworkTracePoller.h
index adde51e..8433934 100644
--- a/service-t/native/libs/libnetworkstats/include/netdbpf/NetworkTracePoller.h
+++ b/service-t/native/libs/libnetworkstats/include/netdbpf/NetworkTracePoller.h
@@ -53,7 +53,12 @@
   bool ConsumeAll() EXCLUDES(mMutex);
 
  private:
-  void SchedulePolling() REQUIRES(mMutex);
+  // Poll the ring buffer for new data and schedule another run of ourselves
+  // after poll_ms (essentially polling periodically until stopped). This takes
+  // in the runner and poll duration to prevent a hard requirement on the lock
+  // and thus a deadlock while resetting the TaskRunner. The runner pointer is
+  // always valid within tasks run by that runner.
+  void PollAndSchedule(perfetto::base::TaskRunner* runner, uint32_t poll_ms);
   bool ConsumeAllLocked() REQUIRES(mMutex);
 
   std::mutex mMutex;
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 25aa693..754a60b 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -1317,10 +1317,9 @@
                 final NsdServiceInfo info = buildNsdServiceInfoFromMdnsEvent(event, code);
                 // Errors are already logged if null
                 if (info == null) return false;
-                if (DBG) {
-                    Log.d(TAG, String.format("MdnsDiscoveryManager event code=%s transactionId=%d",
-                            NsdManager.nameOf(code), transactionId));
-                }
+                mServiceLogs.log(String.format(
+                        "MdnsDiscoveryManager event code=%s transactionId=%d",
+                        NsdManager.nameOf(code), transactionId));
                 switch (code) {
                     case NsdManager.SERVICE_FOUND:
                         clientInfo.onServiceFound(clientId, info);
@@ -1722,6 +1721,7 @@
     private class AdvertiserCallback implements MdnsAdvertiser.AdvertiserCallback {
         @Override
         public void onRegisterServiceSucceeded(int serviceId, NsdServiceInfo registeredInfo) {
+            mServiceLogs.log("onRegisterServiceSucceeded: serviceId " + serviceId);
             final ClientInfo clientInfo = getClientInfoOrLog(serviceId);
             if (clientInfo == null) return;
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
index 84faf12..bd4ec20 100644
--- a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
+++ b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
@@ -18,7 +18,6 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
-import android.net.Network;
 import android.text.TextUtils;
 import android.util.Log;
 import android.util.Pair;
@@ -29,7 +28,6 @@
 import java.io.IOException;
 import java.lang.ref.WeakReference;
 import java.net.DatagramPacket;
-import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.util.ArrayList;
 import java.util.Collection;
@@ -71,13 +69,14 @@
     private final List<String> subtypes;
     private final boolean expectUnicastResponse;
     private final int transactionId;
-    @Nullable
-    private final Network network;
+    @NonNull
+    private final SocketKey socketKey;
     private final boolean sendDiscoveryQueries;
     @NonNull
     private final List<MdnsResponse> servicesToResolve;
     @NonNull
     private final MdnsResponseDecoder.Clock clock;
+    private final boolean onlyUseIpv6OnIpv6OnlyNetworks;
 
     EnqueueMdnsQueryCallable(
             @NonNull MdnsSocketClientBase requestSender,
@@ -86,7 +85,8 @@
             @NonNull Collection<String> subtypes,
             boolean expectUnicastResponse,
             int transactionId,
-            @Nullable Network network,
+            @NonNull SocketKey socketKey,
+            boolean onlyUseIpv6OnIpv6OnlyNetworks,
             boolean sendDiscoveryQueries,
             @NonNull Collection<MdnsResponse> servicesToResolve,
             @NonNull MdnsResponseDecoder.Clock clock) {
@@ -96,7 +96,8 @@
         this.subtypes = new ArrayList<>(subtypes);
         this.expectUnicastResponse = expectUnicastResponse;
         this.transactionId = transactionId;
-        this.network = network;
+        this.socketKey = socketKey;
+        this.onlyUseIpv6OnIpv6OnlyNetworks = onlyUseIpv6OnIpv6OnlyNetworks;
         this.sendDiscoveryQueries = sendDiscoveryQueries;
         this.servicesToResolve = new ArrayList<>(servicesToResolve);
         this.clock = clock;
@@ -128,24 +129,29 @@
             for (MdnsResponse response : servicesToResolve) {
                 final String[] serviceName = response.getServiceName();
                 if (serviceName == null) continue;
-                if (!response.hasTextRecord() || MdnsUtils.isRecordRenewalNeeded(
-                        response.getTextRecord(), now)) {
-                    missingKnownAnswerRecords.add(new Pair<>(serviceName, MdnsRecord.TYPE_TXT));
-                }
-                if (!response.hasServiceRecord() || MdnsUtils.isRecordRenewalNeeded(
-                        response.getServiceRecord(), now)) {
-                    missingKnownAnswerRecords.add(new Pair<>(serviceName, MdnsRecord.TYPE_SRV));
-                    // The hostname is not yet known, so queries for address records will be sent
-                    // the next time the EnqueueMdnsQueryCallable is enqueued if the reply does not
-                    // contain them. In practice, advertisers should include the address records
-                    // when queried for SRV, although it's not a MUST requirement (RFC6763 12.2).
-                    // TODO: Figure out how to renew the A/AAAA record. Usually A/AAAA record will
-                    //  be included in the response to the SRV record so in high chances there is
-                    //  no need to renew them individually.
-                } else if (!response.hasInet4AddressRecord() && !response.hasInet6AddressRecord()) {
-                    final String[] host = response.getServiceRecord().getServiceHost();
-                    missingKnownAnswerRecords.add(new Pair<>(host, MdnsRecord.TYPE_A));
-                    missingKnownAnswerRecords.add(new Pair<>(host, MdnsRecord.TYPE_AAAA));
+                boolean renewTxt = !response.hasTextRecord() || MdnsUtils.isRecordRenewalNeeded(
+                        response.getTextRecord(), now);
+                boolean renewSrv = !response.hasServiceRecord() || MdnsUtils.isRecordRenewalNeeded(
+                        response.getServiceRecord(), now);
+                if (renewSrv && renewTxt) {
+                    missingKnownAnswerRecords.add(new Pair<>(serviceName, MdnsRecord.TYPE_ANY));
+                } else {
+                    if (renewTxt) {
+                        missingKnownAnswerRecords.add(new Pair<>(serviceName, MdnsRecord.TYPE_TXT));
+                    }
+                    if (renewSrv) {
+                        missingKnownAnswerRecords.add(new Pair<>(serviceName, MdnsRecord.TYPE_SRV));
+                        // The hostname is not yet known, so queries for address records will be
+                        // sent the next time the EnqueueMdnsQueryCallable is enqueued if the reply
+                        // does not contain them. In practice, advertisers should include the
+                        // address records when queried for SRV, although it's not a MUST
+                        // requirement (RFC6763 12.2).
+                    } else if (!response.hasInet4AddressRecord()
+                            && !response.hasInet6AddressRecord()) {
+                        final String[] host = response.getServiceRecord().getServiceHost();
+                        missingKnownAnswerRecords.add(new Pair<>(host, MdnsRecord.TYPE_A));
+                        missingKnownAnswerRecords.add(new Pair<>(host, MdnsRecord.TYPE_AAAA));
+                    }
                 }
             }
             numQuestions += missingKnownAnswerRecords.size();
@@ -183,24 +189,9 @@
                 writeQuestion(serviceTypeLabels, MdnsRecord.TYPE_PTR);
             }
 
-            if (requestSender instanceof MdnsMultinetworkSocketClient) {
-                sendPacketToIpv4AndIpv6(requestSender, MdnsConstants.MDNS_PORT, network);
-                for (Integer emulatorPort : castShellEmulatorMdnsPorts) {
-                    sendPacketToIpv4AndIpv6(requestSender, emulatorPort, network);
-                }
-            } else if (requestSender instanceof MdnsSocketClient) {
-                final MdnsSocketClient client = (MdnsSocketClient) requestSender;
-                InetAddress mdnsAddress = MdnsConstants.getMdnsIPv4Address();
-                if (client.isOnIPv6OnlyNetwork()) {
-                    mdnsAddress = MdnsConstants.getMdnsIPv6Address();
-                }
-
-                sendPacketTo(client, new InetSocketAddress(mdnsAddress, MdnsConstants.MDNS_PORT));
-                for (Integer emulatorPort : castShellEmulatorMdnsPorts) {
-                    sendPacketTo(client, new InetSocketAddress(mdnsAddress, emulatorPort));
-                }
-            } else {
-                throw new IOException("Unknown socket client type: " + requestSender.getClass());
+            sendPacketToIpv4AndIpv6(requestSender, MdnsConstants.MDNS_PORT);
+            for (Integer emulatorPort : castShellEmulatorMdnsPorts) {
+                sendPacketToIpv4AndIpv6(requestSender, emulatorPort);
             }
             return Pair.create(transactionId, subtypes);
         } catch (IOException e) {
@@ -218,38 +209,39 @@
                         | (expectUnicastResponse ? MdnsConstants.QCLASS_UNICAST : 0));
     }
 
-    private void sendPacketTo(MdnsSocketClient requestSender, InetSocketAddress address)
+    private void sendPacket(MdnsSocketClientBase requestSender, InetSocketAddress address)
             throws IOException {
         DatagramPacket packet = packetWriter.getPacket(address);
         if (expectUnicastResponse) {
-            requestSender.sendUnicastPacket(packet);
+            if (requestSender instanceof MdnsMultinetworkSocketClient) {
+                ((MdnsMultinetworkSocketClient) requestSender).sendPacketRequestingUnicastResponse(
+                        packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
+            } else {
+                requestSender.sendPacketRequestingUnicastResponse(
+                        packet, onlyUseIpv6OnIpv6OnlyNetworks);
+            }
         } else {
-            requestSender.sendMulticastPacket(packet);
+            if (requestSender instanceof MdnsMultinetworkSocketClient) {
+                ((MdnsMultinetworkSocketClient) requestSender)
+                        .sendPacketRequestingMulticastResponse(
+                                packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
+            } else {
+                requestSender.sendPacketRequestingMulticastResponse(
+                        packet, onlyUseIpv6OnIpv6OnlyNetworks);
+            }
         }
     }
 
-    private void sendPacketFromNetwork(MdnsSocketClientBase requestSender,
-            InetSocketAddress address, Network network)
-            throws IOException {
-        DatagramPacket packet = packetWriter.getPacket(address);
-        if (expectUnicastResponse) {
-            requestSender.sendUnicastPacket(packet, network);
-        } else {
-            requestSender.sendMulticastPacket(packet, network);
-        }
-    }
-
-    private void sendPacketToIpv4AndIpv6(MdnsSocketClientBase requestSender, int port,
-            Network network) {
+    private void sendPacketToIpv4AndIpv6(MdnsSocketClientBase requestSender, int port) {
         try {
-            sendPacketFromNetwork(requestSender,
-                    new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), port), network);
+            sendPacket(requestSender,
+                    new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), port));
         } catch (IOException e) {
             Log.i(TAG, "Can't send packet to IPv4", e);
         }
         try {
-            sendPacketFromNetwork(requestSender,
-                    new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), port), network);
+            sendPacket(requestSender,
+                    new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), port));
         } catch (IOException e) {
             Log.i(TAG, "Can't send packet to IPv6", e);
         }
diff --git a/service-t/src/com/android/server/connectivity/mdns/ExecutorProvider.java b/service-t/src/com/android/server/connectivity/mdns/ExecutorProvider.java
index 72b65e0..0eebc61 100644
--- a/service-t/src/com/android/server/connectivity/mdns/ExecutorProvider.java
+++ b/service-t/src/com/android/server/connectivity/mdns/ExecutorProvider.java
@@ -42,6 +42,9 @@
     /** Shuts down all the created {@link ScheduledExecutorService} instances. */
     public void shutdownAll() {
         for (ScheduledExecutorService executor : serviceTypeClientSchedulerExecutors) {
+            if (executor.isShutdown()) {
+                continue;
+            }
             executor.shutdownNow();
         }
     }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
index 5f27b6a..158d7a3 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
@@ -287,7 +287,7 @@
         }
 
         @Override
-        public void onSocketCreated(@NonNull Network network,
+        public void onSocketCreated(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket,
                 @NonNull List<LinkAddress> addresses) {
             MdnsInterfaceAdvertiser advertiser = mAllAdvertisers.get(socket);
@@ -311,14 +311,14 @@
         }
 
         @Override
-        public void onInterfaceDestroyed(@NonNull Network network,
+        public void onInterfaceDestroyed(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket) {
             final MdnsInterfaceAdvertiser advertiser = mAdvertisers.get(socket);
             if (advertiser != null) advertiser.destroyNow();
         }
 
         @Override
-        public void onAddressesChanged(@NonNull Network network,
+        public void onAddressesChanged(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {
             final MdnsInterfaceAdvertiser advertiser = mAdvertisers.get(socket);
             if (advertiser != null) advertiser.updateAddresses(addresses);
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsConstants.java b/service-t/src/com/android/server/connectivity/mdns/MdnsConstants.java
index f0e1717..ce5f540 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsConstants.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsConstants.java
@@ -16,6 +16,8 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.internal.annotations.VisibleForTesting.Visibility.PACKAGE;
+
 import static java.nio.charset.StandardCharsets.UTF_8;
 
 import com.android.internal.annotations.VisibleForTesting;
@@ -25,7 +27,7 @@
 import java.nio.charset.Charset;
 
 /** mDNS-related constants. */
-@VisibleForTesting
+@VisibleForTesting(visibility = PACKAGE)
 public final class MdnsConstants {
     public static final int MDNS_PORT = 5353;
     // Flags word format is:
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index 39fceb9..05b1dcf 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -22,7 +22,6 @@
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.annotation.RequiresPermission;
-import android.net.Network;
 import android.os.Handler;
 import android.os.HandlerThread;
 import android.util.ArrayMap;
@@ -36,7 +35,6 @@
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
-import java.util.Objects;
 
 /**
  * This class keeps tracking the set of registered {@link MdnsServiceBrowserListener} instances, and
@@ -50,54 +48,58 @@
     private final MdnsSocketClientBase socketClient;
     @NonNull private final SharedLog sharedLog;
 
-    @NonNull private final PerNetworkServiceTypeClients perNetworkServiceTypeClients;
+    @NonNull private final PerSocketServiceTypeClients perSocketServiceTypeClients;
     @NonNull private final Handler handler;
     @Nullable private final HandlerThread handlerThread;
 
-    private static class PerNetworkServiceTypeClients {
-        private final ArrayMap<Pair<String, Network>, MdnsServiceTypeClient> clients =
+    private static class PerSocketServiceTypeClients {
+        private final ArrayMap<Pair<String, SocketKey>, MdnsServiceTypeClient> clients =
                 new ArrayMap<>();
 
-        public void put(@NonNull String serviceType, @Nullable Network network,
+        public void put(@NonNull String serviceType, @NonNull SocketKey socketKey,
                 @NonNull MdnsServiceTypeClient client) {
             final String dnsLowerServiceType = MdnsUtils.toDnsLowerCase(serviceType);
-            final Pair<String, Network> perNetworkServiceType = new Pair<>(dnsLowerServiceType,
-                    network);
-            clients.put(perNetworkServiceType, client);
+            final Pair<String, SocketKey> perSocketServiceType = new Pair<>(dnsLowerServiceType,
+                    socketKey);
+            clients.put(perSocketServiceType, client);
         }
 
         @Nullable
-        public MdnsServiceTypeClient get(@NonNull String serviceType, @Nullable Network network) {
+        public MdnsServiceTypeClient get(
+                @NonNull String serviceType, @NonNull SocketKey socketKey) {
             final String dnsLowerServiceType = MdnsUtils.toDnsLowerCase(serviceType);
-            final Pair<String, Network> perNetworkServiceType = new Pair<>(dnsLowerServiceType,
-                    network);
-            return clients.getOrDefault(perNetworkServiceType, null);
+            final Pair<String, SocketKey> perSocketServiceType = new Pair<>(dnsLowerServiceType,
+                    socketKey);
+            return clients.getOrDefault(perSocketServiceType, null);
         }
 
         public List<MdnsServiceTypeClient> getByServiceType(@NonNull String serviceType) {
             final String dnsLowerServiceType = MdnsUtils.toDnsLowerCase(serviceType);
             final List<MdnsServiceTypeClient> list = new ArrayList<>();
             for (int i = 0; i < clients.size(); i++) {
-                final Pair<String, Network> perNetworkServiceType = clients.keyAt(i);
-                if (dnsLowerServiceType.equals(perNetworkServiceType.first)) {
+                final Pair<String, SocketKey> perSocketServiceType = clients.keyAt(i);
+                if (dnsLowerServiceType.equals(perSocketServiceType.first)) {
                     list.add(clients.valueAt(i));
                 }
             }
             return list;
         }
 
-        public List<MdnsServiceTypeClient> getByNetwork(@Nullable Network network) {
+        public List<MdnsServiceTypeClient> getBySocketKey(@NonNull SocketKey socketKey) {
             final List<MdnsServiceTypeClient> list = new ArrayList<>();
             for (int i = 0; i < clients.size(); i++) {
-                final Pair<String, Network> perNetworkServiceType = clients.keyAt(i);
-                final Network serviceTypeNetwork = perNetworkServiceType.second;
-                if (Objects.equals(network, serviceTypeNetwork)) {
+                final Pair<String, SocketKey> perSocketServiceType = clients.keyAt(i);
+                if (socketKey.equals(perSocketServiceType.second)) {
                     list.add(clients.valueAt(i));
                 }
             }
             return list;
         }
 
+        public List<MdnsServiceTypeClient> getAllMdnsServiceTypeClient() {
+            return new ArrayList<>(clients.values());
+        }
+
         public void remove(@NonNull MdnsServiceTypeClient client) {
             final int index = clients.indexOfValue(client);
             clients.removeAt(index);
@@ -113,7 +115,7 @@
         this.executorProvider = executorProvider;
         this.socketClient = socketClient;
         this.sharedLog = sharedLog;
-        this.perNetworkServiceTypeClients = new PerNetworkServiceTypeClients();
+        this.perSocketServiceTypeClients = new PerSocketServiceTypeClients();
         if (socketClient.getLooper() != null) {
             this.handlerThread = null;
             this.handler = new Handler(socketClient.getLooper());
@@ -164,7 +166,7 @@
             @NonNull String serviceType,
             @NonNull MdnsServiceBrowserListener listener,
             @NonNull MdnsSearchOptions searchOptions) {
-        if (perNetworkServiceTypeClients.isEmpty()) {
+        if (perSocketServiceTypeClients.isEmpty()) {
             // First listener. Starts the socket client.
             try {
                 socketClient.startDiscovery();
@@ -177,29 +179,29 @@
         socketClient.notifyNetworkRequested(listener, searchOptions.getNetwork(),
                 new MdnsSocketClientBase.SocketCreationCallback() {
                     @Override
-                    public void onSocketCreated(@Nullable Network network) {
+                    public void onSocketCreated(@NonNull SocketKey socketKey) {
                         ensureRunningOnHandlerThread(handler);
                         // All listeners of the same service types shares the same
                         // MdnsServiceTypeClient.
                         MdnsServiceTypeClient serviceTypeClient =
-                                perNetworkServiceTypeClients.get(serviceType, network);
+                                perSocketServiceTypeClients.get(serviceType, socketKey);
                         if (serviceTypeClient == null) {
-                            serviceTypeClient = createServiceTypeClient(serviceType, network);
-                            perNetworkServiceTypeClients.put(serviceType, network,
+                            serviceTypeClient = createServiceTypeClient(serviceType, socketKey);
+                            perSocketServiceTypeClients.put(serviceType, socketKey,
                                     serviceTypeClient);
                         }
                         serviceTypeClient.startSendAndReceive(listener, searchOptions);
                     }
 
                     @Override
-                    public void onAllSocketsDestroyed(@Nullable Network network) {
+                    public void onSocketDestroyed(@NonNull SocketKey socketKey) {
                         ensureRunningOnHandlerThread(handler);
                         final MdnsServiceTypeClient serviceTypeClient =
-                                perNetworkServiceTypeClients.get(serviceType, network);
+                                perSocketServiceTypeClients.get(serviceType, socketKey);
                         if (serviceTypeClient == null) return;
                         // Notify all listeners that all services are removed from this socket.
                         serviceTypeClient.notifySocketDestroyed();
-                        perNetworkServiceTypeClients.remove(serviceTypeClient);
+                        perSocketServiceTypeClients.remove(serviceTypeClient);
                     }
                 });
     }
@@ -224,7 +226,7 @@
         socketClient.notifyNetworkUnrequested(listener);
 
         final List<MdnsServiceTypeClient> serviceTypeClients =
-                perNetworkServiceTypeClients.getByServiceType(serviceType);
+                perSocketServiceTypeClients.getByServiceType(serviceType);
         if (serviceTypeClients.isEmpty()) {
             return;
         }
@@ -233,60 +235,60 @@
             if (serviceTypeClient.stopSendAndReceive(listener)) {
                 // No listener is registered for the service type anymore, remove it from the list
                 // of the service type clients.
-                perNetworkServiceTypeClients.remove(serviceTypeClient);
+                perSocketServiceTypeClients.remove(serviceTypeClient);
             }
         }
-        if (perNetworkServiceTypeClients.isEmpty()) {
+        if (perSocketServiceTypeClients.isEmpty()) {
             // No discovery request. Stops the socket client.
+            sharedLog.i("All service type listeners unregistered; stopping discovery");
             socketClient.stopDiscovery();
         }
     }
 
     @Override
-    public void onResponseReceived(@NonNull MdnsPacket packet,
-            int interfaceIndex, @Nullable Network network) {
+    public void onResponseReceived(@NonNull MdnsPacket packet, @NonNull SocketKey socketKey) {
         checkAndRunOnHandlerThread(() ->
-                handleOnResponseReceived(packet, interfaceIndex, network));
+                handleOnResponseReceived(packet, socketKey));
     }
 
-    private void handleOnResponseReceived(@NonNull MdnsPacket packet, int interfaceIndex,
-            @Nullable Network network) {
-        for (MdnsServiceTypeClient serviceTypeClient
-                : getMdnsServiceTypeClient(network)) {
-            serviceTypeClient.processResponse(packet, interfaceIndex, network);
+    private void handleOnResponseReceived(@NonNull MdnsPacket packet,
+            @NonNull SocketKey socketKey) {
+        for (MdnsServiceTypeClient serviceTypeClient : getMdnsServiceTypeClient(socketKey)) {
+            serviceTypeClient.processResponse(packet, socketKey);
         }
     }
 
-    private List<MdnsServiceTypeClient> getMdnsServiceTypeClient(@Nullable Network network) {
+    private List<MdnsServiceTypeClient> getMdnsServiceTypeClient(@NonNull SocketKey socketKey) {
         if (socketClient.supportsRequestingSpecificNetworks()) {
-            return perNetworkServiceTypeClients.getByNetwork(network);
+            return perSocketServiceTypeClients.getBySocketKey(socketKey);
         } else {
-            return perNetworkServiceTypeClients.getByNetwork(null);
+            return perSocketServiceTypeClients.getAllMdnsServiceTypeClient();
         }
     }
 
     @Override
     public void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode,
-            @Nullable Network network) {
+            @NonNull SocketKey socketKey) {
         checkAndRunOnHandlerThread(() ->
-                handleOnFailedToParseMdnsResponse(receivedPacketNumber, errorCode, network));
+                handleOnFailedToParseMdnsResponse(receivedPacketNumber, errorCode, socketKey));
     }
 
     private void handleOnFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode,
-            @Nullable Network network) {
-        for (MdnsServiceTypeClient serviceTypeClient
-                : getMdnsServiceTypeClient(network)) {
+            @NonNull SocketKey socketKey) {
+        for (MdnsServiceTypeClient serviceTypeClient : getMdnsServiceTypeClient(socketKey)) {
             serviceTypeClient.onFailedToParseMdnsResponse(receivedPacketNumber, errorCode);
         }
     }
 
     @VisibleForTesting
     MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType,
-            @Nullable Network network) {
-        sharedLog.log("createServiceTypeClient for type:" + serviceType + ", net:" + network);
+            @NonNull SocketKey socketKey) {
+        sharedLog.log("createServiceTypeClient for type:" + serviceType + " " + socketKey);
+        final String tag = serviceType + "-" + socketKey.getNetwork()
+                + "/" + socketKey.getInterfaceIndex();
         return new MdnsServiceTypeClient(
                 serviceType, socketClient,
-                executorProvider.newServiceTypeClientSchedulerExecutor(), network,
-                sharedLog.forSubComponent(serviceType + "-" + network));
+                executorProvider.newServiceTypeClientSchedulerExecutor(), socketKey,
+                sharedLog.forSubComponent(tag));
     }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
index 73e4497..d1fa57c 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -64,7 +64,7 @@
         @NonNull
         private final SocketCreationCallback mSocketCreationCallback;
         @NonNull
-        private final ArrayMap<MdnsInterfaceSocket, Network> mActiveNetworkSockets =
+        private final ArrayMap<MdnsInterfaceSocket, SocketKey> mActiveSockets =
                 new ArrayMap<>();
 
         InterfaceSocketCallback(SocketCreationCallback socketCreationCallback) {
@@ -72,39 +72,39 @@
         }
 
         @Override
-        public void onSocketCreated(@Nullable Network network,
+        public void onSocketCreated(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {
             // The socket may be already created by other request before, try to get the stored
             // ReadPacketHandler.
             ReadPacketHandler handler = mSocketPacketHandlers.get(socket);
             if (handler == null) {
                 // First request to create this socket. Initial a ReadPacketHandler for this socket.
-                handler = new ReadPacketHandler(network, socket.getInterface().getIndex());
+                handler = new ReadPacketHandler(socketKey);
                 mSocketPacketHandlers.put(socket, handler);
             }
             socket.addPacketHandler(handler);
-            mActiveNetworkSockets.put(socket, network);
-            mSocketCreationCallback.onSocketCreated(network);
+            mActiveSockets.put(socket, socketKey);
+            mSocketCreationCallback.onSocketCreated(socketKey);
         }
 
         @Override
-        public void onInterfaceDestroyed(@Nullable Network network,
+        public void onInterfaceDestroyed(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket) {
             notifySocketDestroyed(socket);
             maybeCleanupPacketHandler(socket);
         }
 
         private void notifySocketDestroyed(@NonNull MdnsInterfaceSocket socket) {
-            final Network network = mActiveNetworkSockets.remove(socket);
-            if (!isAnySocketActive(network)) {
-                mSocketCreationCallback.onAllSocketsDestroyed(network);
+            final SocketKey socketKey = mActiveSockets.remove(socket);
+            if (!isSocketActive(socket)) {
+                mSocketCreationCallback.onSocketDestroyed(socketKey);
             }
         }
 
         void onNetworkUnrequested() {
-            for (int i = mActiveNetworkSockets.size() - 1; i >= 0; i--) {
+            for (int i = mActiveSockets.size() - 1; i >= 0; i--) {
                 // Iterate from the end so the socket can be removed
-                final MdnsInterfaceSocket socket = mActiveNetworkSockets.keyAt(i);
+                final MdnsInterfaceSocket socket = mActiveSockets.keyAt(i);
                 notifySocketDestroyed(socket);
                 maybeCleanupPacketHandler(socket);
             }
@@ -114,28 +114,18 @@
     private boolean isSocketActive(@NonNull MdnsInterfaceSocket socket) {
         for (int i = 0; i < mRequestedNetworks.size(); i++) {
             final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
-            if (isc.mActiveNetworkSockets.containsKey(socket)) {
+            if (isc.mActiveSockets.containsKey(socket)) {
                 return true;
             }
         }
         return false;
     }
 
-    private boolean isAnySocketActive(@Nullable Network network) {
+    private ArrayMap<MdnsInterfaceSocket, SocketKey> getActiveSockets() {
+        final ArrayMap<MdnsInterfaceSocket, SocketKey> sockets = new ArrayMap<>();
         for (int i = 0; i < mRequestedNetworks.size(); i++) {
             final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
-            if (isc.mActiveNetworkSockets.containsValue(network)) {
-                return true;
-            }
-        }
-        return false;
-    }
-
-    private ArrayMap<MdnsInterfaceSocket, Network> getActiveSockets() {
-        final ArrayMap<MdnsInterfaceSocket, Network> sockets = new ArrayMap<>();
-        for (int i = 0; i < mRequestedNetworks.size(); i++) {
-            final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
-            sockets.putAll(isc.mActiveNetworkSockets);
+            sockets.putAll(isc.mActiveSockets);
         }
         return sockets;
     }
@@ -146,17 +136,15 @@
     }
 
     private class ReadPacketHandler implements MulticastPacketReader.PacketHandler {
-        private final Network mNetwork;
-        private final int mInterfaceIndex;
+        @NonNull private final SocketKey mSocketKey;
 
-        ReadPacketHandler(@NonNull Network network, int interfaceIndex) {
-            mNetwork = network;
-            mInterfaceIndex = interfaceIndex;
+        ReadPacketHandler(@NonNull SocketKey socketKey) {
+            mSocketKey = socketKey;
         }
 
         @Override
         public void handlePacket(byte[] recvbuf, int length, InetSocketAddress src) {
-            processResponsePacket(recvbuf, length, mInterfaceIndex, mNetwork);
+            processResponsePacket(recvbuf, length, mSocketKey);
         }
     }
 
@@ -215,21 +203,22 @@
         return true;
     }
 
-    private void sendMdnsPacket(@NonNull DatagramPacket packet, @Nullable Network targetNetwork) {
+    private void sendMdnsPacket(@NonNull DatagramPacket packet, @NonNull SocketKey targetSocketKey,
+            boolean onlyUseIpv6OnIpv6OnlyNetworks) {
         final boolean isIpv6 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
                 instanceof Inet6Address;
         final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
                 instanceof Inet4Address;
-        final ArrayMap<MdnsInterfaceSocket, Network> activeSockets = getActiveSockets();
+        final ArrayMap<MdnsInterfaceSocket, SocketKey> activeSockets = getActiveSockets();
+        boolean shouldQueryIpv6 = !onlyUseIpv6OnIpv6OnlyNetworks || isIpv6OnlySockets(
+                activeSockets, targetSocketKey);
         for (int i = 0; i < activeSockets.size(); i++) {
             final MdnsInterfaceSocket socket = activeSockets.keyAt(i);
-            final Network network = activeSockets.valueAt(i);
+            final SocketKey socketKey = activeSockets.valueAt(i);
             // Check ip capability and network before sending packet
-            if (((isIpv6 && socket.hasJoinedIpv6()) || (isIpv4 && socket.hasJoinedIpv4()))
-                    // Contrary to MdnsUtils.isNetworkMatched, only send packets targeting
-                    // the null network to interfaces that have the null network (tethering
-                    // downstream interfaces).
-                    && Objects.equals(network, targetNetwork)) {
+            if (((isIpv6 && socket.hasJoinedIpv6() && shouldQueryIpv6)
+                    || (isIpv4 && socket.hasJoinedIpv4()))
+                    && Objects.equals(socketKey, targetSocketKey)) {
                 try {
                     socket.send(packet);
                 } catch (IOException e) {
@@ -239,8 +228,20 @@
         }
     }
 
-    private void processResponsePacket(byte[] recvbuf, int length, int interfaceIndex,
-            @NonNull Network network) {
+    private boolean isIpv6OnlySockets(
+            @NonNull ArrayMap<MdnsInterfaceSocket, SocketKey> activeSockets,
+            @NonNull SocketKey targetSocketKey) {
+        for (int i = 0; i < activeSockets.size(); i++) {
+            final MdnsInterfaceSocket socket = activeSockets.keyAt(i);
+            final SocketKey socketKey = activeSockets.valueAt(i);
+            if (Objects.equals(socketKey, targetSocketKey) && socket.hasJoinedIpv4()) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    private void processResponsePacket(byte[] recvbuf, int length, @NonNull SocketKey socketKey) {
         int packetNumber = ++mReceivedPacketNumber;
 
         final MdnsPacket response;
@@ -250,33 +251,47 @@
             if (e.code != MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE) {
                 Log.e(TAG, e.getMessage(), e);
                 if (mCallback != null) {
-                    mCallback.onFailedToParseMdnsResponse(packetNumber, e.code, network);
+                    mCallback.onFailedToParseMdnsResponse(packetNumber, e.code, socketKey);
                 }
             }
             return;
         }
 
         if (mCallback != null) {
-            mCallback.onResponseReceived(response, interfaceIndex, network);
+            mCallback.onResponseReceived(response, socketKey);
         }
     }
 
     /**
-     * Sends a mDNS request packet via given network that asks for multicast response. Null network
-     * means sending packet via all networks.
+     * Send a mDNS request packet via given socket key that asks for multicast response.
      */
+    public void sendPacketRequestingMulticastResponse(@NonNull DatagramPacket packet,
+            @NonNull SocketKey socketKey, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+        mHandler.post(() -> sendMdnsPacket(packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks));
+    }
+
     @Override
-    public void sendMulticastPacket(@NonNull DatagramPacket packet, @Nullable Network network) {
-        mHandler.post(() -> sendMdnsPacket(packet, network));
+    public void sendPacketRequestingMulticastResponse(
+            @NonNull DatagramPacket packet, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+        throw new UnsupportedOperationException("This socket client need to specify the socket to"
+                + " send packet");
     }
 
     /**
-     * Sends a mDNS request packet via given network that asks for unicast response. Null network
-     * means sending packet via all networks.
+     * Send a mDNS request packet via given socket key that asks for unicast response.
+     *
+     * <p>The packet is posted to the handler and sent only on active sockets whose key equals
+     * {@code socketKey}, subject to the IPv4/IPv6 checks in {@code sendMdnsPacket}.
      */
-    @Override
-    public void sendUnicastPacket(@NonNull DatagramPacket packet, @Nullable Network network) {
-        // TODO: Separate unicast packet.
-        mHandler.post(() -> sendMdnsPacket(packet, network));
+    public void sendPacketRequestingUnicastResponse(@NonNull DatagramPacket packet,
+            @NonNull SocketKey socketKey, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+        mHandler.post(() -> sendMdnsPacket(packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks));
     }
-}
+
+    @Override
+    public void sendPacketRequestingUnicastResponse(
+            @NonNull DatagramPacket packet, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+        throw new UnsupportedOperationException("This socket client need to specify the socket to"
+                + " send packet");
+    }
+}
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java
index 3da6bd0..98c80ee 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java
@@ -44,10 +44,13 @@
             new Parcelable.Creator<MdnsSearchOptions>() {
                 @Override
                 public MdnsSearchOptions createFromParcel(Parcel source) {
-                    return new MdnsSearchOptions(source.createStringArrayList(),
-                            source.readBoolean(), source.readBoolean(),
+                    return new MdnsSearchOptions(
+                            source.createStringArrayList(),
+                            source.readBoolean(),
+                            source.readBoolean(),
                             source.readParcelable(null),
-                            source.readString());
+                            source.readString(),
+                            (source.dataAvail() > 0) ? source.readBoolean() : false);
                 }
 
                 @Override
@@ -61,18 +64,25 @@
     private final String resolveInstanceName;
 
     private final boolean isPassiveMode;
+    private final boolean onlyUseIpv6OnIpv6OnlyNetworks;
     private final boolean removeExpiredService;
     // The target network for searching. Null network means search on all possible interfaces.
     @Nullable private final Network mNetwork;
 
     /** Parcelable constructs for a {@link MdnsSearchOptions}. */
-    MdnsSearchOptions(List<String> subtypes, boolean isPassiveMode, boolean removeExpiredService,
-            @Nullable Network network, @Nullable String resolveInstanceName) {
+    MdnsSearchOptions(
+            List<String> subtypes,
+            boolean isPassiveMode,
+            boolean removeExpiredService,
+            @Nullable Network network,
+            @Nullable String resolveInstanceName,
+            boolean onlyUseIpv6OnIpv6OnlyNetworks) {
         this.subtypes = new ArrayList<>();
         if (subtypes != null) {
             this.subtypes.addAll(subtypes);
         }
         this.isPassiveMode = isPassiveMode;
+        this.onlyUseIpv6OnIpv6OnlyNetworks = onlyUseIpv6OnIpv6OnlyNetworks;
         this.removeExpiredService = removeExpiredService;
         mNetwork = network;
         this.resolveInstanceName = resolveInstanceName;
@@ -104,6 +114,14 @@
         return isPassiveMode;
     }
 
+    /**
+     * @return {@code true} if only the IPv4 mDNS host should be queried on a network that
+     * supports both IPv6 and IPv4. On an IPv6-only network, this is ignored.
+     */
+    public boolean onlyUseIpv6OnIpv6OnlyNetworks() {
+        return onlyUseIpv6OnIpv6OnlyNetworks;
+    }
+
     /** Returns {@code true} if service will be removed after its TTL expires. */
     public boolean removeExpiredService() {
         return removeExpiredService;
@@ -140,12 +158,14 @@
         out.writeBoolean(removeExpiredService);
         out.writeParcelable(mNetwork, 0);
         out.writeString(resolveInstanceName);
+        out.writeBoolean(onlyUseIpv6OnIpv6OnlyNetworks);
     }
 
     /** A builder to create {@link MdnsSearchOptions}. */
     public static final class Builder {
         private final Set<String> subtypes;
         private boolean isPassiveMode = true;
+        private boolean onlyUseIpv6OnIpv6OnlyNetworks = false;
         private boolean removeExpiredService;
         private Network mNetwork;
         private String resolveInstanceName;
@@ -190,6 +210,15 @@
         }
 
         /**
+         * Sets if only the IPv4 mDNS host should be queried on a network that has both IPv4 and
+         * IPv6. On an IPv6-only network, this is ignored.
+         */
+        public Builder setOnlyUseIpv6OnIpv6OnlyNetworks(boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+            this.onlyUseIpv6OnIpv6OnlyNetworks = onlyUseIpv6OnIpv6OnlyNetworks;
+            return this;
+        }
+
+        /**
          * Sets if the service should be removed after TTL.
          *
          * @param removeExpiredService If set to {@code true}, the service will be removed after TTL
@@ -223,8 +252,13 @@
 
         /** Builds a {@link MdnsSearchOptions} with the arguments supplied to this builder. */
         public MdnsSearchOptions build() {
-            return new MdnsSearchOptions(new ArrayList<>(subtypes), isPassiveMode,
-                    removeExpiredService, mNetwork, resolveInstanceName);
+            return new MdnsSearchOptions(
+                    new ArrayList<>(subtypes),
+                    isPassiveMode,
+                    removeExpiredService,
+                    mNetwork,
+                    resolveInstanceName,
+                    onlyUseIpv6OnIpv6OnlyNetworks);
         }
     }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index 49a376c..48e4724 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -20,7 +20,6 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
-import android.net.Network;
 import android.text.TextUtils;
 import android.util.ArrayMap;
 import android.util.ArraySet;
@@ -58,7 +57,7 @@
     private final MdnsSocketClientBase socketClient;
     private final MdnsResponseDecoder responseDecoder;
     private final ScheduledExecutorService executor;
-    @Nullable private final Network network;
+    @NonNull private final SocketKey socketKey;
     @NonNull private final SharedLog sharedLog;
     private final Object lock = new Object();
     private final ArrayMap<MdnsServiceBrowserListener, MdnsSearchOptions> listeners =
@@ -90,9 +89,9 @@
             @NonNull String serviceType,
             @NonNull MdnsSocketClientBase socketClient,
             @NonNull ScheduledExecutorService executor,
-            @Nullable Network network,
+            @NonNull SocketKey socketKey,
             @NonNull SharedLog sharedLog) {
-        this(serviceType, socketClient, executor, new MdnsResponseDecoder.Clock(), network,
+        this(serviceType, socketClient, executor, new MdnsResponseDecoder.Clock(), socketKey,
                 sharedLog);
     }
 
@@ -102,7 +101,7 @@
             @NonNull MdnsSocketClientBase socketClient,
             @NonNull ScheduledExecutorService executor,
             @NonNull MdnsResponseDecoder.Clock clock,
-            @Nullable Network network,
+            @NonNull SocketKey socketKey,
             @NonNull SharedLog sharedLog) {
         this.serviceType = serviceType;
         this.socketClient = socketClient;
@@ -110,7 +109,7 @@
         this.serviceTypeLabels = TextUtils.split(serviceType, "\\.");
         this.responseDecoder = new MdnsResponseDecoder(clock, serviceTypeLabels);
         this.clock = clock;
-        this.network = network;
+        this.socketKey = socketKey;
         this.sharedLog = sharedLog;
     }
 
@@ -198,8 +197,9 @@
             final QueryTaskConfig taskConfig = new QueryTaskConfig(
                     searchOptions.getSubtypes(),
                     searchOptions.isPassiveMode(),
+                    searchOptions.onlyUseIpv6OnIpv6OnlyNetworks(),
                     currentSessionId,
-                    network);
+                    socketKey);
             if (hadReply) {
                 requestTaskFuture = scheduleNextRunLocked(taskConfig);
             } else {
@@ -220,7 +220,7 @@
         final boolean matchesInstanceName = options.getResolveInstanceName() == null
                 // DNS is case-insensitive, so ignore case in the comparison
                 || MdnsUtils.equalsIgnoreDnsCase(options.getResolveInstanceName(),
-                        response.getServiceInstanceName());
+                response.getServiceInstanceName());
 
         // If discovery is requiring some subtypes, the response must have one that matches a
         // requested one.
@@ -261,15 +261,13 @@
     /**
      * Process an incoming response packet.
      */
-    public synchronized void processResponse(@NonNull MdnsPacket packet, int interfaceIndex,
-            Network network) {
+    public synchronized void processResponse(@NonNull MdnsPacket packet,
+            @NonNull SocketKey socketKey) {
         synchronized (lock) {
             // Augment the list of current known responses, and generated responses for resolve
             // requests if there is no known response
             final List<MdnsResponse> currentList = new ArrayList<>(instanceNameToResponse.values());
-
-            List<MdnsResponse> additionalResponses = makeResponsesForResolve(interfaceIndex,
-                    network);
+            List<MdnsResponse> additionalResponses = makeResponsesForResolve(socketKey);
             for (MdnsResponse additionalResponse : additionalResponses) {
                 if (!instanceNameToResponse.containsKey(
                         additionalResponse.getServiceInstanceName())) {
@@ -277,7 +275,8 @@
                 }
             }
             final Pair<ArraySet<MdnsResponse>, ArrayList<MdnsResponse>> augmentedResult =
-                    responseDecoder.augmentResponses(packet, currentList, interfaceIndex, network);
+                    responseDecoder.augmentResponses(packet, currentList,
+                            socketKey.getInterfaceIndex(), socketKey.getNetwork());
 
             final ArraySet<MdnsResponse> modifiedResponse = augmentedResult.first;
             final ArrayList<MdnsResponse> allResponses = augmentedResult.second;
@@ -348,6 +347,11 @@
             boolean after = response.isComplete();
             serviceBecomesComplete = !before && after;
         }
+        sharedLog.i(String.format(
+                "Handling response from service: %s, newServiceFound: %b, serviceBecomesComplete:"
+                        + " %b, responseIsComplete: %b",
+                serviceInstanceName, newServiceFound, serviceBecomesComplete,
+                response.isComplete()));
         MdnsServiceInfo serviceInfo =
                 buildMdnsServiceInfoFromResponse(response, serviceTypeLabels);
 
@@ -422,6 +426,7 @@
         private final boolean alwaysAskForUnicastResponse =
                 MdnsConfigs.alwaysAskForUnicastResponseInEachBurst();
         private final boolean usePassiveMode;
+        private final boolean onlyUseIpv6OnIpv6OnlyNetworks;
         private final long sessionId;
         @VisibleForTesting
         int transactionId;
@@ -432,11 +437,15 @@
         private int burstCounter;
         private int timeToRunNextTaskInMs;
         private boolean isFirstBurst;
-        @Nullable private final Network network;
+        @NonNull private final SocketKey socketKey;
 
-        QueryTaskConfig(@NonNull Collection<String> subtypes, boolean usePassiveMode,
-                long sessionId, @Nullable Network network) {
+        QueryTaskConfig(@NonNull Collection<String> subtypes,
+                boolean usePassiveMode,
+                boolean onlyUseIpv6OnIpv6OnlyNetworks,
+                long sessionId,
+                @Nullable SocketKey socketKey) {
             this.usePassiveMode = usePassiveMode;
+            this.onlyUseIpv6OnIpv6OnlyNetworks = onlyUseIpv6OnIpv6OnlyNetworks;
             this.subtypes = new ArrayList<>(subtypes);
             this.queriesPerBurst = QUERIES_PER_BURST;
             this.burstCounter = 0;
@@ -457,7 +466,7 @@
                 // doubles until it maxes out at TIME_BETWEEN_BURSTS_MS.
                 this.timeBetweenBurstsInMs = INITIAL_TIME_BETWEEN_BURSTS_MS;
             }
-            this.network = network;
+            this.socketKey = socketKey;
         }
 
         QueryTaskConfig getConfigForNextRun() {
@@ -497,8 +506,7 @@
         }
     }
 
-    private List<MdnsResponse> makeResponsesForResolve(int interfaceIndex,
-            @NonNull Network network) {
+    private List<MdnsResponse> makeResponsesForResolve(@NonNull SocketKey socketKey) {
         final List<MdnsResponse> resolveResponses = new ArrayList<>();
         for (int i = 0; i < listeners.size(); i++) {
             final String resolveName = listeners.valueAt(i).getResolveInstanceName();
@@ -513,7 +521,7 @@
                 instanceFullName.addAll(Arrays.asList(serviceTypeLabels));
                 knownResponse = new MdnsResponse(
                         0L /* lastUpdateTime */, instanceFullName.toArray(new String[0]),
-                        interfaceIndex, network);
+                        socketKey.getInterfaceIndex(), socketKey.getNetwork());
             }
             resolveResponses.add(knownResponse);
         }
@@ -537,10 +545,7 @@
                 // The listener is requesting to resolve a service that has no info in
                 // cache. Use the provided name to generate a minimal response, so other records are
                 // queried to complete it.
-                // Only the names are used to know which queries to send, other parameters like
-                // interfaceIndex do not matter.
-                servicesToResolve = makeResponsesForResolve(
-                        0 /* interfaceIndex */, config.network);
+                servicesToResolve = makeResponsesForResolve(config.socketKey);
                 sendDiscoveryQueries = servicesToResolve.size() < listeners.size();
             }
             Pair<Integer, List<String>> result;
@@ -553,7 +558,8 @@
                                 config.subtypes,
                                 config.expectUnicastResponse,
                                 config.transactionId,
-                                config.network,
+                                config.socketKey,
+                                config.onlyUseIpv6OnIpv6OnlyNetworks,
                                 sendDiscoveryQueries,
                                 servicesToResolve,
                                 clock)
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocket.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocket.java
index 5fd1354..cdd9f76 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocket.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocket.java
@@ -93,6 +93,10 @@
         }
         for (NetworkInterfaceWrapper networkInterface : networkInterfaces) {
             multicastSocket.joinGroup(multicastAddress, networkInterface.getNetworkInterface());
+            if (!isOnIPv6OnlyNetwork) {
+                multicastSocket.joinGroup(
+                        MULTICAST_IPV6_ADDRESS, networkInterface.getNetworkInterface());
+            }
         }
     }
 
@@ -105,6 +109,10 @@
         }
         for (NetworkInterfaceWrapper networkInterface : networkInterfaces) {
             multicastSocket.leaveGroup(multicastAddress, networkInterface.getNetworkInterface());
+            if (!isOnIPv6OnlyNetwork) {
+                multicastSocket.leaveGroup(
+                        MULTICAST_IPV6_ADDRESS, networkInterface.getNetworkInterface());
+            }
         }
     }
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
index b982644..9c9812d 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
@@ -31,6 +31,9 @@
 
 import java.io.IOException;
 import java.net.DatagramPacket;
+import java.net.Inet4Address;
+import java.net.Inet6Address;
+import java.net.InetSocketAddress;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.List;
@@ -194,39 +197,23 @@
         }
     }
 
-    /** Sends a mDNS request packet that asks for multicast response. */
-    public void sendMulticastPacket(@NonNull DatagramPacket packet) {
-        sendMdnsPacket(packet, multicastPacketQueue);
+    @Override
+    public void sendPacketRequestingMulticastResponse(@NonNull DatagramPacket packet,
+            boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+        sendMdnsPacket(packet, multicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
     }
 
-    /** Sends a mDNS request packet that asks for unicast response. */
-    public void sendUnicastPacket(DatagramPacket packet) {
+    @Override
+    public void sendPacketRequestingUnicastResponse(@NonNull DatagramPacket packet,
+            boolean onlyUseIpv6OnIpv6OnlyNetworks) {
         if (useSeparateSocketForUnicast) {
-            sendMdnsPacket(packet, unicastPacketQueue);
+            sendMdnsPacket(packet, unicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
         } else {
-            sendMdnsPacket(packet, multicastPacketQueue);
+            sendMdnsPacket(packet, multicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
         }
     }
 
     @Override
-    public void sendMulticastPacket(@NonNull DatagramPacket packet, @Nullable Network network) {
-        if (network != null) {
-            throw new IllegalArgumentException("This socket client does not support sending to "
-                    + "specific networks");
-        }
-        sendMulticastPacket(packet);
-    }
-
-    @Override
-    public void sendUnicastPacket(@NonNull DatagramPacket packet, @Nullable Network network) {
-        if (network != null) {
-            throw new IllegalArgumentException("This socket client does not support sending to "
-                    + "specific networks");
-        }
-        sendUnicastPacket(packet);
-    }
-
-    @Override
     public void notifyNetworkRequested(
             @NonNull MdnsServiceBrowserListener listener,
             @Nullable Network network,
@@ -235,7 +222,7 @@
             throw new IllegalArgumentException("This socket client does not support requesting "
                     + "specific networks");
         }
-        socketCreationCallback.onSocketCreated(null);
+        socketCreationCallback.onSocketCreated(new SocketKey(multicastSocket.getInterfaceIndex()));
     }
 
     @Override
@@ -243,11 +230,25 @@
         return false;
     }
 
-    private void sendMdnsPacket(DatagramPacket packet, Queue<DatagramPacket> packetQueueToUse) {
+    private void sendMdnsPacket(DatagramPacket packet, Queue<DatagramPacket> packetQueueToUse,
+            boolean onlyUseIpv6OnIpv6OnlyNetworks) {
         if (shouldStopSocketLoop && !MdnsConfigs.allowAddMdnsPacketAfterDiscoveryStops()) {
             LOGGER.w("sendMdnsPacket() is called after discovery already stopped");
             return;
         }
+
+        final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
+                instanceof Inet4Address;
+        final boolean isIpv6 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
+                instanceof Inet6Address;
+        final boolean ipv6Only = multicastSocket != null && multicastSocket.isOnIPv6OnlyNetwork();
+        if (isIpv4 && ipv6Only) {
+            return;
+        }
+        if (isIpv6 && !ipv6Only && onlyUseIpv6OnIpv6OnlyNetworks) {
+            return;
+        }
+
         synchronized (packetQueueToUse) {
             while (packetQueueToUse.size() >= MdnsConfigs.mdnsPacketQueueMaxSize()) {
                 packetQueueToUse.remove();
@@ -456,7 +457,8 @@
             LOGGER.w(String.format("Error while decoding %s packet (%d): %d",
                     responseType, packetNumber, e.code));
             if (callback != null) {
-                callback.onFailedToParseMdnsResponse(packetNumber, e.code, network);
+                callback.onFailedToParseMdnsResponse(packetNumber, e.code,
+                        new SocketKey(network, interfaceIndex));
             }
             return e.code;
         }
@@ -466,7 +468,8 @@
         }
 
         if (callback != null) {
-            callback.onResponseReceived(response, interfaceIndex, network);
+            callback.onResponseReceived(
+                    response, new SocketKey(network, interfaceIndex));
         }
 
         return MdnsResponseErrorCode.SUCCESS;
@@ -533,8 +536,4 @@
         }
         packets.clear();
     }
-
-    public boolean isOnIPv6OnlyNetwork() {
-        return multicastSocket != null && multicastSocket.isOnIPv6OnlyNetwork();
-    }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
index e0762f9..b6000f0 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
@@ -41,19 +41,15 @@
 
     /**
      * Send a mDNS request packet via given network that asks for multicast response.
-     *
-     * <p>The socket client may use a null network to identify some or all interfaces, in which case
-     * passing null sends the packet to these.
      */
-    void sendMulticastPacket(@NonNull DatagramPacket packet, @Nullable Network network);
+    void sendPacketRequestingMulticastResponse(@NonNull DatagramPacket packet,
+            boolean onlyUseIpv6OnIpv6OnlyNetworks);
 
     /**
      * Send a mDNS request packet via given network that asks for unicast response.
-     *
-     * <p>The socket client may use a null network to identify some or all interfaces, in which case
-     * passing null sends the packet to these.
      */
-    void sendUnicastPacket(@NonNull DatagramPacket packet, @Nullable Network network);
+    void sendPacketRequestingUnicastResponse(@NonNull DatagramPacket packet,
+            boolean onlyUseIpv6OnIpv6OnlyNetworks);
 
     /*** Notify that the given network is requested for mdns discovery / resolution */
     void notifyNetworkRequested(@NonNull MdnsServiceBrowserListener listener,
@@ -73,20 +69,19 @@
     /*** Callback for mdns response  */
     interface Callback {
         /*** Receive a mdns response */
-        void onResponseReceived(@NonNull MdnsPacket packet, int interfaceIndex,
-                @Nullable Network network);
+        void onResponseReceived(@NonNull MdnsPacket packet, @NonNull SocketKey socketKey);
 
         /*** Parse a mdns response failed */
         void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode,
-                @Nullable Network network);
+                @NonNull SocketKey socketKey);
     }
 
     /*** Callback for requested socket creation  */
     interface SocketCreationCallback {
         /*** Notify requested socket is created */
-        void onSocketCreated(@Nullable Network network);
+        void onSocketCreated(@NonNull SocketKey socketKey);
 
         /*** Notify requested socket is destroyed */
-        void onAllSocketsDestroyed(@Nullable Network network);
+        void onSocketDestroyed(@NonNull SocketKey socketKey);
     }
-}
+}
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
index d90f67f..e963ab7 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
@@ -318,11 +318,14 @@
         final MdnsInterfaceSocket mSocket;
         final List<LinkAddress> mAddresses;
         final int[] mTransports;
+        @NonNull final SocketKey mSocketKey;
 
-        SocketInfo(MdnsInterfaceSocket socket, List<LinkAddress> addresses, int[] transports) {
+        SocketInfo(MdnsInterfaceSocket socket, List<LinkAddress> addresses, int[] transports,
+                @NonNull SocketKey socketKey) {
             mSocket = socket;
             mAddresses = new ArrayList<>(addresses);
             mTransports = transports;
+            mSocketKey = socketKey;
         }
     }
 
@@ -442,7 +445,7 @@
         // Try to join the group again.
         socketInfo.mSocket.joinGroup(addresses);
 
-        notifyAddressesChanged(network, socketInfo.mSocket, addresses);
+        notifyAddressesChanged(network, socketInfo, addresses);
     }
     private LinkProperties createLPForTetheredInterface(@NonNull final String interfaceName,
             int ifaceIndex) {
@@ -523,21 +526,22 @@
                     networkInterface.getNetworkInterface(), MdnsConstants.MDNS_PORT, mLooper,
                     mPacketReadBuffer);
             final List<LinkAddress> addresses = lp.getLinkAddresses();
+            final Network network =
+                    networkKey == LOCAL_NET ? null : ((NetworkAsKey) networkKey).mNetwork;
+            final SocketKey socketKey = new SocketKey(network, networkInterface.getIndex());
             // TODO: technically transport types are mutable, although generally not in ways that
             // would meaningfully impact the logic using it here. Consider updating logic to
             // support transports being added/removed.
-            final SocketInfo socketInfo = new SocketInfo(socket, addresses, transports);
+            final SocketInfo socketInfo = new SocketInfo(socket, addresses, transports, socketKey);
             if (networkKey == LOCAL_NET) {
                 mTetherInterfaceSockets.put(interfaceName, socketInfo);
             } else {
-                mNetworkSockets.put(((NetworkAsKey) networkKey).mNetwork, socketInfo);
+                mNetworkSockets.put(network, socketInfo);
             }
             // Try to join IPv4/IPv6 group.
             socket.joinGroup(addresses);
 
             // Notify the listeners which need this socket.
-            final Network network =
-                    networkKey == LOCAL_NET ? null : ((NetworkAsKey) networkKey).mNetwork;
             notifySocketCreated(network, socketInfo);
         } catch (IOException e) {
             mSharedLog.e("Create socket failed ifName:" + interfaceName, e);
@@ -579,7 +583,7 @@
         if (socketInfo == null) return;
 
         socketInfo.mSocket.destroy();
-        notifyInterfaceDestroyed(network, socketInfo.mSocket);
+        notifyInterfaceDestroyed(network, socketInfo);
         mSocketRequestMonitor.onSocketDestroyed(network, socketInfo.mSocket);
         mSharedLog.log("Remove socket on net:" + network);
     }
@@ -588,7 +592,7 @@
         final SocketInfo socketInfo = mTetherInterfaceSockets.remove(interfaceName);
         if (socketInfo == null) return;
         socketInfo.mSocket.destroy();
-        notifyInterfaceDestroyed(null /* network */, socketInfo.mSocket);
+        notifyInterfaceDestroyed(null /* network */, socketInfo);
         mSocketRequestMonitor.onSocketDestroyed(null /* network */, socketInfo.mSocket);
         mSharedLog.log("Remove socket on ifName:" + interfaceName);
     }
@@ -597,30 +601,31 @@
         for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
             final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
             if (isNetworkMatched(requestedNetwork, network)) {
-                mCallbacksToRequestedNetworks.keyAt(i).onSocketCreated(network, socketInfo.mSocket,
-                        socketInfo.mAddresses);
+                mCallbacksToRequestedNetworks.keyAt(i).onSocketCreated(socketInfo.mSocketKey,
+                        socketInfo.mSocket, socketInfo.mAddresses);
                 mSocketRequestMonitor.onSocketRequestFulfilled(network, socketInfo.mSocket,
                         socketInfo.mTransports);
             }
         }
     }
 
-    private void notifyInterfaceDestroyed(Network network, MdnsInterfaceSocket socket) {
+    private void notifyInterfaceDestroyed(Network network, SocketInfo socketInfo) {
         for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
             final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
             if (isNetworkMatched(requestedNetwork, network)) {
-                mCallbacksToRequestedNetworks.keyAt(i).onInterfaceDestroyed(network, socket);
+                mCallbacksToRequestedNetworks.keyAt(i)
+                        .onInterfaceDestroyed(socketInfo.mSocketKey, socketInfo.mSocket);
             }
         }
     }
 
-    private void notifyAddressesChanged(Network network, MdnsInterfaceSocket socket,
+    private void notifyAddressesChanged(Network network, SocketInfo socketInfo,
             List<LinkAddress> addresses) {
         for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
             final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
             if (isNetworkMatched(requestedNetwork, network)) {
                 mCallbacksToRequestedNetworks.keyAt(i)
-                        .onAddressesChanged(network, socket, addresses);
+                        .onAddressesChanged(socketInfo.mSocketKey, socketInfo.mSocket, addresses);
             }
         }
     }
@@ -637,7 +642,7 @@
             createSocket(new NetworkAsKey(network), lp);
         } else {
             // Notify the socket for requested network.
-            cb.onSocketCreated(network, socketInfo.mSocket, socketInfo.mAddresses);
+            cb.onSocketCreated(socketInfo.mSocketKey, socketInfo.mSocket, socketInfo.mAddresses);
             mSocketRequestMonitor.onSocketRequestFulfilled(network, socketInfo.mSocket,
                     socketInfo.mTransports);
         }
@@ -652,8 +657,7 @@
                     createLPForTetheredInterface(interfaceName, ifaceIndex));
         } else {
             // Notify the socket for requested network.
-            cb.onSocketCreated(
-                    null /* network */, socketInfo.mSocket, socketInfo.mAddresses);
+            cb.onSocketCreated(socketInfo.mSocketKey, socketInfo.mSocket, socketInfo.mAddresses);
             mSocketRequestMonitor.onSocketRequestFulfilled(null /* socketNetwork */,
                     socketInfo.mSocket, socketInfo.mTransports);
         }
@@ -741,21 +745,21 @@
          * This may be called immediately when the request is registered with an existing socket,
          * if it had been created previously for other requests.
          */
-        default void onSocketCreated(@Nullable Network network, @NonNull MdnsInterfaceSocket socket,
-                @NonNull List<LinkAddress> addresses) {}
+        default void onSocketCreated(@NonNull SocketKey socketKey,
+                @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {}
 
         /**
          * Notify that the interface was destroyed, so the provided socket cannot be used anymore.
          *
          * This indicates that although the socket was still requested, it had to be destroyed.
          */
-        default void onInterfaceDestroyed(@Nullable Network network,
+        default void onInterfaceDestroyed(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket) {}
 
         /**
          * Notify the interface addresses have changed for the network.
          */
-        default void onAddressesChanged(@Nullable Network network,
+        default void onAddressesChanged(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {}
     }
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/NetworkInterfaceWrapper.java b/service-t/src/com/android/server/connectivity/mdns/NetworkInterfaceWrapper.java
index 0ecae48..48c396e 100644
--- a/service-t/src/com/android/server/connectivity/mdns/NetworkInterfaceWrapper.java
+++ b/service-t/src/com/android/server/connectivity/mdns/NetworkInterfaceWrapper.java
@@ -57,6 +57,10 @@
         return networkInterface.getInterfaceAddresses();
     }
 
+    public int getIndex() {
+        return networkInterface.getIndex();
+    }
+
     @Override
     public String toString() {
         return networkInterface.toString();
diff --git a/service-t/src/com/android/server/connectivity/mdns/SocketKey.java b/service-t/src/com/android/server/connectivity/mdns/SocketKey.java
new file mode 100644
index 0000000..f13d0e0
--- /dev/null
+++ b/service-t/src/com/android/server/connectivity/mdns/SocketKey.java
@@ -0,0 +1,73 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+import android.annotation.Nullable;
+import android.net.Network;
+
+import java.util.Objects;
+
+/**
+ * A class that identifies a socket.
+ *
+ * <p> A socket is typically created with an associated network. However, tethering interfaces do
+ * not have an associated network, only an interface index. This means that the socket cannot be
+ * identified in some places. Therefore, this class is necessary for identifying a socket. It
+ * includes both the network and interface index.
+ */
+public class SocketKey {
+    @Nullable
+    private final Network mNetwork; // Null when the socket has no associated network.
+    private final int mInterfaceIndex; // Second component of the key, alongside mNetwork.
+
+    SocketKey(int interfaceIndex) { // Creates a key that has no associated network.
+        this(null /* network */, interfaceIndex);
+    }
+
+    SocketKey(@Nullable Network network, int interfaceIndex) { // network may be null.
+        mNetwork = network;
+        mInterfaceIndex = interfaceIndex;
+    }
+
+    @Nullable
+    public Network getNetwork() { // Returns the network given at construction; may be null.
+        return mNetwork;
+    }
+
+    public int getInterfaceIndex() { // Returns the interface index given at construction.
+        return mInterfaceIndex;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(mNetwork, mInterfaceIndex); // Hashes the same fields equals() uses.
+    }
+
+    @Override
+    public boolean equals(@Nullable Object other) {
+        if (!(other instanceof SocketKey)) { // Also rejects a null argument.
+            return false;
+        }
+        return Objects.equals(mNetwork, ((SocketKey) other).mNetwork) // Null-safe comparison.
+                && mInterfaceIndex == ((SocketKey) other).mInterfaceIndex;
+    }
+
+    @Override
+    public String toString() {
+        return "SocketKey{ network=" + mNetwork + " interfaceIndex=" + mInterfaceIndex + " }";
+    }
+}
diff --git a/service-t/src/com/android/server/ethernet/EthernetNetworkFactory.java b/service-t/src/com/android/server/ethernet/EthernetNetworkFactory.java
index 6776920..ece10f3 100644
--- a/service-t/src/com/android/server/ethernet/EthernetNetworkFactory.java
+++ b/service-t/src/com/android/server/ethernet/EthernetNetworkFactory.java
@@ -313,17 +313,12 @@
                 mIpClientShutdownCv.block();
             }
 
-            // At the time IpClient is stopped, an IpClient event may have already been posted on
-            // the back of the handler and is awaiting execution. Once that event is executed, the
-            // associated callback object may not be valid anymore
-            // (NetworkInterfaceState#mIpClientCallback points to a different object / null).
-            private boolean isCurrentCallback() {
-                return this == mIpClientCallback;
-            }
-
-            private void handleIpEvent(final @NonNull Runnable r) {
+            private void safelyPostOnHandler(Runnable r) {
                 mHandler.post(() -> {
-                    if (!isCurrentCallback()) {
+                    if (this != mIpClientCallback) {
+                        // At the time IpClient is stopped, an IpClient event may have already been
+                        // posted on the handler and is awaiting execution. Once that event is
+                        // executed, the associated callback object may not be valid anymore.
                         Log.i(TAG, "Ignoring stale IpClientCallbacks " + this);
                         return;
                     }
@@ -333,24 +328,24 @@
 
             @Override
             public void onProvisioningSuccess(LinkProperties newLp) {
-                handleIpEvent(() -> onIpLayerStarted(newLp));
+                safelyPostOnHandler(() -> onIpLayerStarted(newLp));
             }
 
             @Override
             public void onProvisioningFailure(LinkProperties newLp) {
                 // This cannot happen due to provisioning timeout, because our timeout is 0. It can
                 // happen due to errors while provisioning or on provisioning loss.
-                handleIpEvent(() -> onIpLayerStopped());
+                safelyPostOnHandler(() -> onIpLayerStopped());
             }
 
             @Override
             public void onLinkPropertiesChange(LinkProperties newLp) {
-                handleIpEvent(() -> updateLinkProperties(newLp));
+                safelyPostOnHandler(() -> updateLinkProperties(newLp));
             }
 
             @Override
             public void onReachabilityLost(String logMsg) {
-                handleIpEvent(() -> updateNeighborLostEvent(logMsg));
+                safelyPostOnHandler(() -> updateNeighborLostEvent(logMsg));
             }
 
             @Override
diff --git a/service/ServiceConnectivityResources/res/values/config.xml b/service/ServiceConnectivityResources/res/values/config.xml
index 22d9b01..f30abc6 100644
--- a/service/ServiceConnectivityResources/res/values/config.xml
+++ b/service/ServiceConnectivityResources/res/values/config.xml
@@ -135,10 +135,17 @@
     <!-- Whether to cancel network notifications automatically when tapped -->
     <bool name="config_autoCancelNetworkNotifications">true</bool>
 
-    <!-- When no internet or partial connectivity is detected on a network, and a high priority
-         (heads up) notification would be shown due to the network being explicitly selected,
-         directly show the dialog that would normally be shown when tapping the notification
-         instead of showing the notification. -->
+    <!-- Configuration to let OEMs customize what to do when:
+         • Partial connectivity is detected on the network
+         • No internet is detected on the network, and
+           - the network was explicitly selected
+           - the system is configured to actively prefer bad wifi (see config_activelyPreferBadWifi)
+         The default behavior (false) is to post a notification with a PendingIntent so
+         the user is informed and can act if they wish.
+         Making this true instead will have the system fire the intent immediately instead
+         of showing a notification. OEMs who do this should have some intent receiver
+         listening to the intent and take the action they prefer (e.g. show a dialog,
+         show a customized notification etc).  -->
     <bool name="config_notifyNoInternetAsDialogWhenHighPriority">false</bool>
 
     <!-- When showing notifications indicating partial connectivity, display the same notifications
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 447dbf4..32c3b19 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -390,7 +390,7 @@
     // Timeout in case the "actively prefer bad wifi" feature is on
     private static final int ACTIVELY_PREFER_BAD_WIFI_INITIAL_TIMEOUT_MS = 20 * 1000;
     // Timeout in case the "actively prefer bad wifi" feature is off
-    private static final int DONT_ACTIVELY_PREFER_BAD_WIFI_INITIAL_TIMEOUT_MS = 8 * 1000;
+    private static final int DEFAULT_EVALUATION_TIMEOUT_MS = 8 * 1000;
 
     // Default to 30s linger time-out, and 5s for nascent network. Modifiable only for testing.
     private static final String LINGER_DELAY_PROPERTY = "persist.netmon.linger";
@@ -1704,7 +1704,8 @@
         mUserAllContext.registerReceiver(mPackageIntentReceiver, packageIntentFilter,
                 null /* broadcastPermission */, mHandler);
 
-        mNetworkActivityTracker = new LegacyNetworkActivityTracker(mContext, mHandler, mNetd);
+        mNetworkActivityTracker =
+                new LegacyNetworkActivityTracker(mContext, mNetd, mHandler, mDeps.isAtLeastU());
 
         final NetdCallback netdCallback = new NetdCallback();
         try {
@@ -5877,7 +5878,8 @@
                     break;
                 }
                 case EVENT_REPORT_NETWORK_ACTIVITY:
-                    mNetworkActivityTracker.handleReportNetworkActivity();
+                    final NetworkActivityParams arg = (NetworkActivityParams) msg.obj;
+                    mNetworkActivityTracker.handleReportNetworkActivity(arg);
                     break;
                 case EVENT_MOBILE_DATA_PREFERRED_UIDS_CHANGED:
                     handleMobileDataPreferredUidsChanged();
@@ -9933,10 +9935,25 @@
                 networkAgent.networkMonitor().notifyNetworkConnected(params.linkProperties,
                         params.networkCapabilities);
             }
-            final long delay = !avoidBadWifi() && activelyPreferBadWifi()
-                    ? ACTIVELY_PREFER_BAD_WIFI_INITIAL_TIMEOUT_MS
-                    : DONT_ACTIVELY_PREFER_BAD_WIFI_INITIAL_TIMEOUT_MS;
-            scheduleEvaluationTimeout(networkAgent.network, delay);
+            final long evaluationDelay;
+            if (!networkAgent.networkCapabilities.hasSingleTransport(TRANSPORT_WIFI)) {
+                // If the network is anything other than pure wifi, use the default timeout.
+                evaluationDelay = DEFAULT_EVALUATION_TIMEOUT_MS;
+            } else if (networkAgent.networkAgentConfig.isExplicitlySelected()) {
+                // If the network is explicitly selected, use the default timeout because it's
+                // shorter and the user is likely staring at the screen expecting it to validate
+                // right away.
+                evaluationDelay = DEFAULT_EVALUATION_TIMEOUT_MS;
+            } else if (avoidBadWifi() || !activelyPreferBadWifi()) {
+                // Avoiding bad wifi, or not actively preferring it: use the default timeout.
+                evaluationDelay = DEFAULT_EVALUATION_TIMEOUT_MS;
+            } else {
+                // It's wifi, automatically connected, and bad wifi is preferred : use the
+                // longer timeout to avoid the device switching to captive portals with bad
+                // signal or very slow response.
+                evaluationDelay = ACTIVELY_PREFER_BAD_WIFI_INITIAL_TIMEOUT_MS;
+            }
+            scheduleEvaluationTimeout(networkAgent.network, evaluationDelay);
 
             // Whether a particular NetworkRequest listen should cause signal strength thresholds to
             // be communicated to a particular NetworkAgent depends only on the network's immutable,
@@ -11060,11 +11077,34 @@
         notifyDataStallSuspected(p, network.getNetId());
     }
 
+    /**
+     * Class to hold the information for network activity change event from idle timers
+     * {@link NetdCallback#onInterfaceClassActivityChanged(boolean, int, long, int)}
+     */
+    private static final class NetworkActivityParams {
+        public final boolean isActive;
+        // Label used for idle timer. Transport type is used as label.
+        // label is int since NMS was using the identifier as int, and it has not been changed
+        public final int label;
+        public final long timestampNs;
+        // The uid that was responsible for waking the radio.
+        // -1 means no uid; the uid is always -1 when isActive is false.
+        public final int uid;
+
+        NetworkActivityParams(boolean isActive, int label, long timestampNs, int uid) {
+            this.isActive = isActive;
+            this.label = label;
+            this.timestampNs = timestampNs;
+            this.uid = uid;
+        }
+    }
+
     private class NetdCallback extends BaseNetdUnsolicitedEventListener {
         @Override
-        public void onInterfaceClassActivityChanged(boolean isActive, int transportType,
+        public void onInterfaceClassActivityChanged(boolean isActive, int label,
                 long timestampNs, int uid) {
-            mNetworkActivityTracker.setAndReportNetworkActive(isActive, transportType, timestampNs);
+            mHandler.sendMessage(mHandler.obtainMessage(EVENT_REPORT_NETWORK_ACTIVITY,
+                    new NetworkActivityParams(isActive, label, timestampNs, uid)));
         }
 
         @Override
@@ -11092,14 +11132,14 @@
         private static final int NO_UID = -1;
         private final Context mContext;
         private final INetd mNetd;
+        private final Handler mHandler;
         private final RemoteCallbackList<INetworkActivityListener> mNetworkActivityListeners =
                 new RemoteCallbackList<>();
         // Indicate the current system default network activity is active or not.
-        @GuardedBy("mActiveIdleTimers")
-        private boolean mNetworkActive;
-        @GuardedBy("mActiveIdleTimers")
+        // This needs to be volatile to allow non-handler threads to read this value without a lock.
+        private volatile boolean mIsDefaultNetworkActive;
         private final ArrayMap<String, IdleTimerParams> mActiveIdleTimers = new ArrayMap<>();
-        private final Handler mHandler;
+        private final boolean mIsAtLeastU;
 
         private static class IdleTimerParams {
             public final int timeout;
@@ -11111,34 +11151,37 @@
             }
         }
 
-        LegacyNetworkActivityTracker(@NonNull Context context, @NonNull Handler handler,
-                @NonNull INetd netd) {
+        LegacyNetworkActivityTracker(@NonNull Context context, @NonNull INetd netd,
+                @NonNull Handler handler, boolean isAtLeastU) {
             mContext = context;
             mNetd = netd;
             mHandler = handler;
+            mIsAtLeastU = isAtLeastU;
         }
 
-        public void setAndReportNetworkActive(boolean active, int transportType, long tsNanos) {
-            sendDataActivityBroadcast(transportTypeToLegacyType(transportType), active, tsNanos);
-            synchronized (mActiveIdleTimers) {
-                mNetworkActive = active;
-                // If there are no idle timers, it means that system is not monitoring
-                // activity, so the system default network for those default network
-                // unspecified apps is always considered active.
-                //
-                // TODO: If the mActiveIdleTimers is empty, netd will actually not send
-                // any network activity change event. Whenever this event is received,
-                // the mActiveIdleTimers should be always not empty. The legacy behavior
-                // is no-op. Remove to refer to mNetworkActive only.
-                if (mNetworkActive || mActiveIdleTimers.isEmpty()) {
-                    mHandler.sendMessage(mHandler.obtainMessage(EVENT_REPORT_NETWORK_ACTIVITY));
-                }
+        private void ensureRunningOnConnectivityServiceThread() {
+            if (mHandler.getLooper().getThread() != Thread.currentThread()) {
+                throw new IllegalStateException("Not running on ConnectivityService thread: "
+                                + Thread.currentThread().getName());
             }
         }
 
-        // The network activity should only be updated from ConnectivityService handler thread
-        // when mActiveIdleTimers lock is held.
-        @GuardedBy("mActiveIdleTimers")
+        public void handleReportNetworkActivity(NetworkActivityParams activityParams) {
+            ensureRunningOnConnectivityServiceThread();
+            if (mActiveIdleTimers.isEmpty()) {
+                // This activity change is not for the current default network.
+                // This can happen if the netd callback posts an activity change event message
+                // but the default network is lost before this message is processed.
+                return;
+            }
+            sendDataActivityBroadcast(transportTypeToLegacyType(activityParams.label),
+                    activityParams.isActive, activityParams.timestampNs);
+            mIsDefaultNetworkActive = activityParams.isActive;
+            if (mIsDefaultNetworkActive) {
+                reportNetworkActive();
+            }
+        }
+
         private void reportNetworkActive() {
             final int length = mNetworkActivityListeners.beginBroadcast();
             if (DDBG) log("reportNetworkActive, notify " + length + " listeners");
@@ -11155,13 +11198,6 @@
             }
         }
 
-        @GuardedBy("mActiveIdleTimers")
-        public void handleReportNetworkActivity() {
-            synchronized (mActiveIdleTimers) {
-                reportNetworkActive();
-            }
-        }
-
         // This is deprecated and only to support legacy use cases.
         private int transportTypeToLegacyType(int type) {
             switch (type) {
@@ -11203,8 +11239,10 @@
          *
          * Every {@code setupDataActivityTracking} should be paired with a
          * {@link #removeDataActivityTracking} for cleanup.
+         *
+         * @return true if the idleTimer is added to the network, false otherwise
          */
-        private void setupDataActivityTracking(NetworkAgentInfo networkAgent) {
+        private boolean setupDataActivityTracking(NetworkAgentInfo networkAgent) {
             final String iface = networkAgent.linkProperties.getInterfaceName();
 
             final int timeout;
@@ -11223,25 +11261,22 @@
                         15);
                 type = NetworkCapabilities.TRANSPORT_WIFI;
             } else {
-                return; // do not track any other networks
+                return false; // do not track any other networks
             }
 
             updateRadioPowerState(true /* isActive */, type);
 
             if (timeout > 0 && iface != null) {
                 try {
-                    synchronized (mActiveIdleTimers) {
-                        // Networks start up.
-                        mNetworkActive = true;
-                        mActiveIdleTimers.put(iface, new IdleTimerParams(timeout, type));
-                        mNetd.idletimerAddInterface(iface, timeout, Integer.toString(type));
-                        reportNetworkActive();
-                    }
+                    mActiveIdleTimers.put(iface, new IdleTimerParams(timeout, type));
+                    mNetd.idletimerAddInterface(iface, timeout, Integer.toString(type));
+                    return true;
                 } catch (Exception e) {
                     // You shall not crash!
                     loge("Exception in setupDataActivityTracking " + e);
                 }
             }
+            return false;
         }
 
         /**
@@ -11264,26 +11299,45 @@
 
             try {
                 updateRadioPowerState(false /* isActive */, type);
-                synchronized (mActiveIdleTimers) {
-                    final IdleTimerParams params = mActiveIdleTimers.remove(iface);
-                    // The call fails silently if no idle timer setup for this interface
-                    mNetd.idletimerRemoveInterface(iface, params.timeout,
-                            Integer.toString(params.transportType));
+                final IdleTimerParams params = mActiveIdleTimers.remove(iface);
+                if (params == null) {
+                    // IdleTimer is not added if the configured timeout is 0 or a negative value
+                    return;
                 }
+                // The call fails silently if no idle timer setup for this interface
+                mNetd.idletimerRemoveInterface(iface, params.timeout,
+                        Integer.toString(params.transportType));
             } catch (Exception e) {
                 // You shall not crash!
                 loge("Exception in removeDataActivityTracking " + e);
             }
         }
 
+        private void updateDefaultNetworkActivity(NetworkAgentInfo defaultNetwork,
+                boolean hasIdleTimer) {
+            if (defaultNetwork != null) {
+                mIsDefaultNetworkActive = true;
+                // On T-, callbacks are called only when the network has the idle timer.
+                if (mIsAtLeastU || hasIdleTimer) {
+                    reportNetworkActive();
+                }
+            } else {
+                // If there is no default network, default network is considered inactive.
+                mIsDefaultNetworkActive = false;
+            }
+        }
+
         /**
          * Update data activity tracking when network state is updated.
          */
         public void updateDataActivityTracking(NetworkAgentInfo newNetwork,
                 NetworkAgentInfo oldNetwork) {
+            ensureRunningOnConnectivityServiceThread();
+            boolean hasIdleTimer = false;
             if (newNetwork != null) {
-                setupDataActivityTracking(newNetwork);
+                hasIdleTimer = setupDataActivityTracking(newNetwork);
             }
+            updateDefaultNetworkActivity(newNetwork, hasIdleTimer);
             if (oldNetwork != null) {
                 removeDataActivityTracking(oldNetwork);
             }
@@ -11304,15 +11358,7 @@
         }
 
         public boolean isDefaultNetworkActive() {
-            synchronized (mActiveIdleTimers) {
-                // If there are no idle timers, it means that system is not monitoring activity,
-                // so the default network is always considered active.
-                //
-                // TODO : Distinguish between the cases where mActiveIdleTimers is empty because
-                // tracking is disabled (negative idle timer value configured), or no active default
-                // network. In the latter case, this reports active but it should report inactive.
-                return mNetworkActive || mActiveIdleTimers.isEmpty();
-            }
+            return mIsDefaultNetworkActive;
         }
 
         public void registerNetworkActivityListener(@NonNull INetworkActivityListener l) {
@@ -11324,15 +11370,22 @@
         }
 
         public void dump(IndentingPrintWriter pw) {
-            synchronized (mActiveIdleTimers) {
-                pw.print("mNetworkActive="); pw.println(mNetworkActive);
-                pw.println("Idle timers:");
-                for (HashMap.Entry<String, IdleTimerParams> ent : mActiveIdleTimers.entrySet()) {
+            pw.print("mIsDefaultNetworkActive="); pw.println(mIsDefaultNetworkActive);
+            pw.println("Idle timers:");
+            try {
+                for (Map.Entry<String, IdleTimerParams> ent : mActiveIdleTimers.entrySet()) {
                     pw.print("  "); pw.print(ent.getKey()); pw.println(":");
                     final IdleTimerParams params = ent.getValue();
                     pw.print("    timeout="); pw.print(params.timeout);
                     pw.print(" type="); pw.println(params.transportType);
                 }
+            } catch (Exception e) {
+                // mActiveIdleTimers should only be accessed from the handler thread, except dump().
+                // As dump() is never called in normal usage, it would be needlessly expensive
+                // to lock the collection only for its benefit.
+                // Also, mActiveIdleTimers is not expected to be updated frequently.
+                // So just catch any exception and report it in the dump output.
+                pw.println("Failed to dump NetworkActivityTracker: " + e);
             }
         }
     }
diff --git a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
index 981bd4d..6ba2033 100644
--- a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
@@ -94,6 +94,7 @@
     private static final int ADJUST_TCP_POLLING_DELAY_MS = 2000;
     private static final String AUTOMATIC_ON_OFF_KEEPALIVE_VERSION =
             "automatic_on_off_keepalive_version";
+    public static final long METRICS_COLLECTION_DURATION_MS = 24 * 60 * 60 * 1_000L;
 
     // ConnectivityService parses message constants from itself and AutomaticOnOffKeepaliveTracker
     // with MessageUtils for debugging purposes, and crashes if some messages have the same values.
@@ -180,6 +181,9 @@
     private final LocalLog mEventLog = new LocalLog(MAX_EVENTS_LOGS);
 
     private final KeepaliveStatsTracker mKeepaliveStatsTracker;
+
+    private final long mMetricsWriteTimeBase;
+
     /**
      * Information about a managed keepalive.
      *
@@ -311,7 +315,26 @@
                 mContext, mConnectivityServiceHandler);
 
         mAlarmManager = mDependencies.getAlarmManager(context);
-        mKeepaliveStatsTracker = new KeepaliveStatsTracker(context, handler);
+        mKeepaliveStatsTracker =
+                mDependencies.newKeepaliveStatsTracker(context, handler);
+
+        final long time = mDependencies.getElapsedRealtime();
+        mMetricsWriteTimeBase = time % METRICS_COLLECTION_DURATION_MS;
+        final long triggerAtMillis = mMetricsWriteTimeBase + METRICS_COLLECTION_DURATION_MS;
+        mAlarmManager.set(AlarmManager.ELAPSED_REALTIME_WAKEUP, triggerAtMillis, TAG,
+                this::writeMetricsAndRescheduleAlarm, handler);
+    }
+
+    private void writeMetricsAndRescheduleAlarm() {
+        mKeepaliveStatsTracker.writeAndResetMetrics();
+
+        final long time = mDependencies.getElapsedRealtime();
+        final long triggerAtMillis =
+                mMetricsWriteTimeBase
+                        + (time - time % METRICS_COLLECTION_DURATION_MS)
+                        + METRICS_COLLECTION_DURATION_MS;
+        mAlarmManager.set(AlarmManager.ELAPSED_REALTIME_WAKEUP, triggerAtMillis, TAG,
+                this::writeMetricsAndRescheduleAlarm, mConnectivityServiceHandler);
     }
 
     private void startTcpPollingAlarm(@NonNull AutomaticOnOffKeepalive ki) {
@@ -898,6 +921,14 @@
         }
 
         /**
+         * Construct a new KeepaliveStatsTracker.
+         */
+        public KeepaliveStatsTracker newKeepaliveStatsTracker(@NonNull Context context,
+                @NonNull Handler connectivityserviceHander) {
+            return new KeepaliveStatsTracker(context, connectivityserviceHander);
+        }
+
+        /**
          * Find out if a feature is enabled from DeviceConfig.
          *
          * @param name The name of the property to look up.
diff --git a/service/src/com/android/server/connectivity/ClatCoordinator.java b/service/src/com/android/server/connectivity/ClatCoordinator.java
index d87f250..eb3e7ce 100644
--- a/service/src/com/android/server/connectivity/ClatCoordinator.java
+++ b/service/src/com/android/server/connectivity/ClatCoordinator.java
@@ -37,9 +37,10 @@
 import android.system.ErrnoException;
 import android.util.Log;
 
+import androidx.annotation.RequiresApi;
+
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.IndentingPrintWriter;
-import com.android.modules.utils.build.SdkLevel;
 import com.android.net.module.util.BpfMap;
 import com.android.net.module.util.IBpfMap;
 import com.android.net.module.util.InterfaceParams;
@@ -59,8 +60,6 @@
 import java.nio.ByteBuffer;
 import java.util.Objects;
 
-import androidx.annotation.RequiresApi;
-
 /**
  * This coordinator is responsible for providing clat relevant functionality.
  *
diff --git a/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java b/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java
index ff9bb70..d59d526 100644
--- a/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java
+++ b/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java
@@ -19,7 +19,10 @@
 import static android.telephony.SubscriptionManager.OnSubscriptionsChangedListener;
 
 import android.annotation.NonNull;
+import android.content.BroadcastReceiver;
 import android.content.Context;
+import android.content.Intent;
+import android.content.IntentFilter;
 import android.net.Network;
 import android.net.NetworkCapabilities;
 import android.net.NetworkSpecifier;
@@ -42,6 +45,8 @@
 import com.android.metrics.KeepaliveLifetimeForCarrier;
 import com.android.metrics.KeepaliveLifetimePerCarrier;
 import com.android.modules.utils.BackgroundThread;
+import com.android.net.module.util.CollectionUtils;
+import com.android.server.ConnectivityStatsLog;
 
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -67,9 +72,7 @@
     // Mapping of subId to carrierId. Updates are received from OnSubscriptionsChangedListener
     private final SparseIntArray mCachedCarrierIdPerSubId = new SparseIntArray();
     // The default subscription id obtained from SubscriptionManager.getDefaultSubscriptionId.
-    // Updates are done from the OnSubscriptionsChangedListener. Note that there is no callback done
-    // to OnSubscriptionsChangedListener when the default sub id changes.
-    // TODO: Register a listener for the default subId when it is possible.
+    // Updates are received from the ACTION_DEFAULT_SUBSCRIPTION_CHANGED broadcast.
     private int mCachedDefaultSubscriptionId = SubscriptionManager.INVALID_SUBSCRIPTION_ID;
 
     // Class to store network information, lifetime durations and active state of a keepalive.
@@ -267,6 +270,19 @@
                 Objects.requireNonNull(context.getSystemService(SubscriptionManager.class));
 
         mLastUpdateDurationsTimestamp = mDependencies.getElapsedRealtime();
+        context.registerReceiver(
+                new BroadcastReceiver() {
+                    @Override
+                    public void onReceive(Context context, Intent intent) {
+                        mCachedDefaultSubscriptionId =
+                                intent.getIntExtra(
+                                        SubscriptionManager.EXTRA_SUBSCRIPTION_INDEX,
+                                        SubscriptionManager.INVALID_SUBSCRIPTION_ID);
+                    }
+                },
+                new IntentFilter(SubscriptionManager.ACTION_DEFAULT_SUBSCRIPTION_CHANGED),
+                /* broadcastPermission= */ null,
+                mConnectivityServiceHandler);
 
         // The default constructor for OnSubscriptionsChangedListener will always implicitly grab
         // the looper of the current thread. In the case the current thread does not have a looper,
@@ -284,11 +300,8 @@
                                 // but not necessarily empty, simply ignore it. Another call to the
                                 // listener will be invoked in the future.
                                 if (activeSubInfoList == null) return;
-                                final int defaultSubId =
-                                        subscriptionManager.getDefaultSubscriptionId();
                                 mConnectivityServiceHandler.post(() -> {
                                     mCachedCarrierIdPerSubId.clear();
-                                    mCachedDefaultSubscriptionId = defaultSubId;
 
                                     for (final SubscriptionInfo subInfo : activeSubInfoList) {
                                         mCachedCarrierIdPerSubId.put(subInfo.getSubscriptionId(),
@@ -592,6 +605,7 @@
      *
      * @return the DailykeepaliveInfoReported proto that was built.
      */
+    @VisibleForTesting
     public @NonNull DailykeepaliveInfoReported buildAndResetMetrics() {
         ensureRunningOnHandlerThread();
         final long timeNow = mDependencies.getElapsedRealtime();
@@ -620,6 +634,20 @@
         return metrics;
     }
 
+    /** Writes the stored metrics to ConnectivityStatsLog and resets them. */
+    public void writeAndResetMetrics() {
+        ensureRunningOnHandlerThread();
+        final DailykeepaliveInfoReported dailyKeepaliveInfoReported = buildAndResetMetrics();
+        ConnectivityStatsLog.write(
+                ConnectivityStatsLog.DAILY_KEEPALIVE_INFO_REPORTED,
+                dailyKeepaliveInfoReported.getDurationPerNumOfKeepalive().toByteArray(),
+                dailyKeepaliveInfoReported.getKeepaliveLifetimePerCarrier().toByteArray(),
+                dailyKeepaliveInfoReported.getKeepaliveRequests(),
+                dailyKeepaliveInfoReported.getAutomaticKeepaliveRequests(),
+                dailyKeepaliveInfoReported.getDistinctUserCount(),
+                CollectionUtils.toIntArray(dailyKeepaliveInfoReported.getUidList()));
+    }
+
     private void ensureRunningOnHandlerThread() {
         if (mConnectivityServiceHandler.getLooper().getThread() != Thread.currentThread()) {
             throw new IllegalStateException(
diff --git a/service/src/com/android/server/connectivity/KeepaliveTracker.java b/service/src/com/android/server/connectivity/KeepaliveTracker.java
index b4f74d5..76e97e2 100644
--- a/service/src/com/android/server/connectivity/KeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/KeepaliveTracker.java
@@ -104,19 +104,21 @@
     // Allowed unprivileged keepalive slots per uid. Caller's permission will be enforced if
     // the number of remaining keepalive slots is less than or equal to the threshold.
     private final int mAllowedUnprivilegedSlotsForUid;
-
+    private final Dependencies mDependencies;
     public KeepaliveTracker(Context context, Handler handler) {
-        this(context, handler, new TcpKeepaliveController(handler));
+        this(context, handler, new TcpKeepaliveController(handler), new Dependencies());
     }
 
     @VisibleForTesting
-    KeepaliveTracker(Context context, Handler handler, TcpKeepaliveController tcpController) {
+    KeepaliveTracker(Context context, Handler handler, TcpKeepaliveController tcpController,
+            Dependencies deps) {
         mTcpController = tcpController;
         mContext = context;
+        mDependencies = deps;
 
-        mSupportedKeepalives = KeepaliveResourceUtil.getSupportedKeepalives(context);
+        mSupportedKeepalives = mDependencies.getSupportedKeepalives(mContext);
 
-        final ConnectivityResources res = new ConnectivityResources(mContext);
+        final ConnectivityResources res = mDependencies.createConnectivityResources(mContext);
         mReservedPrivilegedSlots = res.get().getInteger(
                 R.integer.config_reservedPrivilegedKeepaliveSlots);
         mAllowedUnprivilegedSlotsForUid = res.get().getInteger(
@@ -148,8 +150,9 @@
         public static final int TYPE_TCP = 2;
 
         // Keepalive slot. A small integer that identifies this keepalive among the ones handled
-        // by this network.
-        private int mSlot = NO_KEEPALIVE;
+        // by this network. This is initialized to NO_KEEPALIVE for new keepalives, but to the
+        // old slot for resumed keepalives.
+        private int mSlot;
 
         // Packet data.
         private final KeepalivePacketData mPacket;
@@ -169,25 +172,30 @@
                 int interval,
                 int type,
                 @Nullable FileDescriptor fd) throws InvalidSocketException {
-            this(callback, nai, packet, interval, type, fd, false /* resumed */);
+            this(callback, nai, packet, Binder.getCallingPid(), Binder.getCallingUid(), interval,
+                    type, fd, NO_KEEPALIVE /* slot */, false /* resumed */);
         }
 
         KeepaliveInfo(@NonNull ISocketKeepaliveCallback callback,
                 @NonNull NetworkAgentInfo nai,
                 @NonNull KeepalivePacketData packet,
+                int pid,
+                int uid,
                 int interval,
                 int type,
                 @Nullable FileDescriptor fd,
+                int slot,
                 boolean resumed) throws InvalidSocketException {
             mCallback = callback;
-            mPid = Binder.getCallingPid();
-            mUid = Binder.getCallingUid();
+            mPid = pid;
+            mUid = uid;
             mPrivileged = (PERMISSION_GRANTED == mContext.checkPermission(PERMISSION, mPid, mUid));
 
             mNai = nai;
             mPacket = packet;
             mInterval = interval;
             mType = type;
+            mSlot = slot;
             mResumed = resumed;
 
             // For SocketKeepalive, a dup of fd is kept in mFd so the source port from which the
@@ -468,8 +476,8 @@
          * Construct a new KeepaliveInfo from existing KeepaliveInfo with a new fd.
          */
         public KeepaliveInfo withFd(@NonNull FileDescriptor fd) throws InvalidSocketException {
-            return new KeepaliveInfo(mCallback, mNai, mPacket, mInterval, mType, fd,
-                    true /* resumed */);
+            return new KeepaliveInfo(mCallback, mNai, mPacket, mPid, mUid, mInterval, mType,
+                    fd, mSlot, true /* resumed */);
         }
     }
 
@@ -508,7 +516,9 @@
      */
     public int handleStartKeepalive(KeepaliveInfo ki) {
         NetworkAgentInfo nai = ki.getNai();
-        int slot = findFirstFreeSlot(nai);
+        // If this was a paused keepalive, then reuse the same slot that was kept for it. Otherwise,
+        // use the first free slot for this network agent.
+        final int slot = NO_KEEPALIVE != ki.mSlot ? ki.mSlot : findFirstFreeSlot(nai);
         mKeepalives.get(nai).put(slot, ki);
         return ki.start(slot);
     }
@@ -518,6 +528,12 @@
         if (networkKeepalives != null) {
             final ArrayList<KeepaliveInfo> kalist = new ArrayList(networkKeepalives.values());
             for (KeepaliveInfo ki : kalist) {
+                // If the keepalive is paused, it is already stopped in hardware, so skip it here.
+                // Note that to send the appropriate stop reason callback,
+                // AutomaticOnOffKeepaliveTracker will call finalizePausedKeepalive, which will
+                // also finally remove this keepalive from the map and free its slot.
+                if (ki.mStopReason == SUCCESS_PAUSED) continue;
+
                 ki.stop(reason);
                 // Clean up keepalives since the network agent is disconnected and unable to pass
                 // back asynchronous result of stop().
@@ -556,17 +572,22 @@
             return;
         }
 
-        // Remove the keepalive from hash table so the slot can be considered available when reusing
-        // it.
-        networkKeepalives.remove(slot);
-        Log.d(TAG, "Remove keepalive " + slot + " on " + networkName + ", "
-                + networkKeepalives.size() + " remains.");
+        // If the keepalive was stopped for good, remove it from the hash table so the slot can
+        // be considered available when reusing it. If it was only a pause, keep it in the map
+        // so that its slot stays reserved.
+        final int reason = ki.mStopReason;
+        if (reason != SUCCESS_PAUSED) {
+            networkKeepalives.remove(slot);
+            Log.d(TAG, "Remove keepalive " + slot + " on " + networkName + ", "
+                    + networkKeepalives.size() + " remains.");
+        } else {
+            Log.d(TAG, "Pause keepalive " + slot + " on " + networkName + ", keep slot reserved");
+        }
         if (networkKeepalives.isEmpty()) {
             mKeepalives.remove(nai);
         }
 
         // Notify app that the keepalive is stopped.
-        final int reason = ki.mStopReason;
         if (reason == SUCCESS) {
             try {
                 ki.mCallback.onStopped();
@@ -612,7 +633,8 @@
     /**
      * Finalize a paused keepalive.
      *
-     * This will send the appropriate callback after checking that this keepalive is indeed paused.
+     * This will send the appropriate callback after checking that this keepalive is indeed paused,
+     * and free the slot.
      *
      * @param ki the keepalive to finalize
      * @param reason the reason the keepalive is stopped
@@ -630,6 +652,13 @@
         } else {
             notifyErrorCallback(ki.mCallback, reason);
         }
+
+        final HashMap<Integer, KeepaliveInfo> networkKeepalives = mKeepalives.get(ki.mNai);
+        if (networkKeepalives == null) {
+            Log.e(TAG, "Attempt to finalize keepalive on nonexistent network " + ki.mNai);
+            return;
+        }
+        networkKeepalives.remove(ki.mSlot);
     }
 
     /**
@@ -725,6 +754,7 @@
             srcAddress = InetAddresses.parseNumericAddress(srcAddrString);
             dstAddress = InetAddresses.parseNumericAddress(dstAddrString);
         } catch (IllegalArgumentException e) {
+            Log.e(TAG, "Fail to construct address", e);
             notifyErrorCallback(cb, ERROR_INVALID_IP_ADDRESS);
             return null;
         }
@@ -734,6 +764,7 @@
             packet = NattKeepalivePacketData.nattKeepalivePacket(
                     srcAddress, srcPort, dstAddress, NATT_PORT);
         } catch (InvalidPacketException e) {
+            Log.e(TAG, "Fail to construct keepalive packet", e);
             notifyErrorCallback(cb, e.getError());
             return null;
         }
@@ -866,4 +897,31 @@
         }
         pw.decreaseIndent();
     }
+
+    /**
+     * Dependencies class for testing.
+     */
+    @VisibleForTesting
+    public static class Dependencies {
+        /**
+         * Read supported keepalive count for each transport type from overlay resource. This should
+         * be used to create a local variable store of resource customization, and set as the
+         * input for {@link getSupportedKeepalivesForNetworkCapabilities}.
+         *
+         * @param context The context to read resource from.
+         * @return An array of supported keepalive count for each transport type.
+         */
+        @NonNull
+        public int[] getSupportedKeepalives(@NonNull Context context) {
+            return KeepaliveResourceUtil.getSupportedKeepalives(context);
+        }
+
+        /**
+         * Create a new {@link ConnectivityResources}.
+         */
+        @NonNull
+        public ConnectivityResources createConnectivityResources(@NonNull Context context) {
+            return new ConnectivityResources(context);
+        }
+    }
 }
diff --git a/service/src/com/android/server/connectivity/Nat464Xlat.java b/service/src/com/android/server/connectivity/Nat464Xlat.java
index 90cddda..b315235 100644
--- a/service/src/com/android/server/connectivity/Nat464Xlat.java
+++ b/service/src/com/android/server/connectivity/Nat464Xlat.java
@@ -44,7 +44,9 @@
 import com.android.server.ConnectivityService;
 
 import java.io.IOException;
+import java.net.Inet4Address;
 import java.net.Inet6Address;
+import java.net.UnknownHostException;
 import java.util.Objects;
 
 /**
@@ -99,7 +101,8 @@
     private IpPrefix mNat64PrefixFromRa;
     private String mBaseIface;
     private String mIface;
-    private Inet6Address mIPv6Address;
+    @VisibleForTesting
+    Inet6Address mIPv6Address;
     private State mState = State.IDLE;
     private final ClatCoordinator mClatCoordinator;  // non-null iff T+
 
@@ -239,6 +242,7 @@
         mNat64PrefixInUse = null;
         mIface = null;
         mBaseIface = null;
+        mIPv6Address = null;
 
         if (!mPrefixDiscoveryRunning) {
             setPrefix64(null);
@@ -541,6 +545,52 @@
     }
 
     /**
+     * Translate the input v4 address to v6 clat address.
+     */
+    @Nullable
+    public Inet6Address translateV4toV6(@NonNull Inet4Address addr) {
+        // Variables in Nat464Xlat should only be accessed from handler thread.
+        ensureRunningOnHandlerThread();
+        if (!isStarted()) return null;
+
+        return convertv4ToClatv6(mNat64PrefixInUse, addr);
+    }
+
+    @Nullable
+    private static Inet6Address convertv4ToClatv6(
+            @NonNull IpPrefix prefix, @NonNull Inet4Address addr) {
+        final byte[] v6Addr = new byte[16];
+        // Generate a v6 address from Nat64 prefix. Prefix should be 12 bytes long.
+        System.arraycopy(prefix.getAddress().getAddress(), 0, v6Addr, 0, 12);
+        System.arraycopy(addr.getAddress(), 0, v6Addr, 12, 4);
+
+        try {
+            return (Inet6Address) Inet6Address.getByAddress(v6Addr);
+        } catch (UnknownHostException e) {
+            Log.e(TAG, "getByAddress should never throw for a numeric address");
+            return null;
+        }
+    }
+
+    /**
+     * Get the generated v6 address of clat.
+     */
+    @Nullable
+    public Inet6Address getClatv6SrcAddress() {
+        // Variables in Nat464Xlat should only be accessed from handler thread.
+        ensureRunningOnHandlerThread();
+
+        return mIPv6Address;
+    }
+
+    private void ensureRunningOnHandlerThread() {
+        if (mNetwork.handler().getLooper().getThread() != Thread.currentThread()) {
+            throw new IllegalStateException(
+                    "Not running on handler thread: " + Thread.currentThread().getName());
+        }
+    }
+
+    /**
      * Dump the NAT64 xlat information.
      *
      * @param pw print writer.
diff --git a/service/src/com/android/server/connectivity/NetworkAgentInfo.java b/service/src/com/android/server/connectivity/NetworkAgentInfo.java
index 85282cb..08c1455 100644
--- a/service/src/com/android/server/connectivity/NetworkAgentInfo.java
+++ b/service/src/com/android/server/connectivity/NetworkAgentInfo.java
@@ -68,6 +68,8 @@
 import com.android.server.ConnectivityService;
 
 import java.io.PrintWriter;
+import java.net.Inet4Address;
+import java.net.Inet6Address;
 import java.time.Instant;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -1033,6 +1035,22 @@
     }
 
     /**
+     * Get the generated v6 address of clat.
+     */
+    @Nullable
+    public Inet6Address getClatv6SrcAddress() {
+        return clatd.getClatv6SrcAddress();
+    }
+
+    /**
+     * Translate the input v4 address to v6 clat address.
+     */
+    @Nullable
+    public Inet6Address translateV4toClatV6(@NonNull Inet4Address addr) {
+        return clatd.translateV4toV6(addr);
+    }
+
+    /**
      * Get the NetworkMonitorManager in this NetworkAgentInfo.
      *
      * <p>This will be null before {@link #onNetworkMonitorCreated(INetworkMonitor)} is called.
diff --git a/service/src/com/android/server/connectivity/NetworkNotificationManager.java b/service/src/com/android/server/connectivity/NetworkNotificationManager.java
index 8b0cb7c..bc13592 100644
--- a/service/src/com/android/server/connectivity/NetworkNotificationManager.java
+++ b/service/src/com/android/server/connectivity/NetworkNotificationManager.java
@@ -322,7 +322,8 @@
 
     private boolean maybeNotifyViaDialog(Resources res, NotificationType notifyType,
             PendingIntent intent) {
-        if (notifyType != NotificationType.NO_INTERNET
+        if (notifyType != NotificationType.LOST_INTERNET
+                && notifyType != NotificationType.NO_INTERNET
                 && notifyType != NotificationType.PARTIAL_CONNECTIVITY) {
             return false;
         }
@@ -432,7 +433,8 @@
      * A notification with a higher number will take priority over a notification with a lower
      * number.
      */
-    private static int priority(NotificationType t) {
+    @VisibleForTesting
+    public static int priority(NotificationType t) {
         if (t == null) {
             return 0;
         }
diff --git a/service/src/com/android/server/connectivity/ProxyTracker.java b/service/src/com/android/server/connectivity/ProxyTracker.java
index bda4b8f..6a0918b 100644
--- a/service/src/com/android/server/connectivity/ProxyTracker.java
+++ b/service/src/com/android/server/connectivity/ProxyTracker.java
@@ -86,6 +86,7 @@
 
     private final Handler mConnectivityServiceHandler;
 
+    @Nullable
     private final PacProxyManager mPacProxyManager;
 
     private class PacProxyInstalledListener implements PacProxyManager.PacProxyInstalledListener {
@@ -109,9 +110,11 @@
         mConnectivityServiceHandler = connectivityServiceInternalHandler;
         mPacProxyManager = context.getSystemService(PacProxyManager.class);
 
-        PacProxyInstalledListener listener = new PacProxyInstalledListener(pacChangedEvent);
-        mPacProxyManager.addPacProxyInstalledListener(
+        if (mPacProxyManager != null) {
+            PacProxyInstalledListener listener = new PacProxyInstalledListener(pacChangedEvent);
+            mPacProxyManager.addPacProxyInstalledListener(
                 mConnectivityServiceHandler::post, listener);
+        }
     }
 
     // Convert empty ProxyInfo's to null as null-checks are used to determine if proxies are present
@@ -205,7 +208,7 @@
                 mGlobalProxy = proxyProperties;
             }
 
-            if (!TextUtils.isEmpty(pacFileUrl)) {
+            if (!TextUtils.isEmpty(pacFileUrl) && mPacProxyManager != null) {
                 mConnectivityServiceHandler.post(
                         () -> mPacProxyManager.setCurrentProxyScriptUrl(proxyProperties));
             }
@@ -251,7 +254,10 @@
         final ProxyInfo defaultProxy = getDefaultProxy();
         final ProxyInfo proxyInfo = null != defaultProxy ?
                 defaultProxy : ProxyInfo.buildDirectProxy("", 0, Collections.emptyList());
-        mPacProxyManager.setCurrentProxyScriptUrl(proxyInfo);
+
+        if (mPacProxyManager != null) {
+            mPacProxyManager.setCurrentProxyScriptUrl(proxyInfo);
+        }
 
         if (!shouldSendBroadcast(proxyInfo)) {
             return;
diff --git a/tests/common/OWNERS b/tests/common/OWNERS
new file mode 100644
index 0000000..3101da5
--- /dev/null
+++ b/tests/common/OWNERS
@@ -0,0 +1,2 @@
+# Bug template url: http://b/new?component=31808
+# TODO: move bug template config to common owners file once b/226427845 is resolved
\ No newline at end of file
diff --git a/tests/common/java/android/net/NattKeepalivePacketDataTest.kt b/tests/common/java/android/net/NattKeepalivePacketDataTest.kt
index 36b9a91..dde1d86 100644
--- a/tests/common/java/android/net/NattKeepalivePacketDataTest.kt
+++ b/tests/common/java/android/net/NattKeepalivePacketDataTest.kt
@@ -22,6 +22,7 @@
 import android.os.Build
 import androidx.test.filters.SmallTest
 import androidx.test.runner.AndroidJUnit4
+import com.android.testutils.ConnectivityModuleTest
 import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
 import com.android.testutils.assertEqualBothWays
@@ -46,6 +47,8 @@
     private val TEST_SRC_ADDRV4 = "198.168.0.2".address()
     private val TEST_DST_ADDRV4 = "198.168.0.1".address()
     private val TEST_ADDRV6 = "2001:db8::1".address()
+    private val TEST_ADDRV4MAPPEDV6 = "::ffff:1.2.3.4".address()
+    private val TEST_ADDRV4 = "1.2.3.4".address()
 
     private fun String.address() = InetAddresses.parseNumericAddress(this)
     private fun nattKeepalivePacket(
@@ -83,6 +86,28 @@
         }
     }
 
+    @Test @IgnoreUpTo(Build.VERSION_CODES.R) @ConnectivityModuleTest
+    fun testConstructor_afterR() {
+        // v4 mapped v6 will be translated to a v4 address.
+        assertFailsWith<InvalidPacketException> {
+            nattKeepalivePacket(srcAddress = TEST_ADDRV6, dstAddress = TEST_ADDRV4MAPPEDV6)
+        }
+        assertFailsWith<InvalidPacketException> {
+            nattKeepalivePacket(srcAddress = TEST_ADDRV4MAPPEDV6, dstAddress = TEST_ADDRV6)
+        }
+
+        // Both src and dst address will be v4 after translation, so it won't cause exception.
+        val packet1 = nattKeepalivePacket(
+            dstAddress = TEST_ADDRV4MAPPEDV6, srcAddress = TEST_ADDRV4MAPPEDV6)
+        assertEquals(TEST_ADDRV4, packet1.srcAddress)
+        assertEquals(TEST_ADDRV4, packet1.dstAddress)
+
+        // Packet with v6 src and v6 dst address is valid.
+        val packet2 = nattKeepalivePacket(srcAddress = TEST_ADDRV6, dstAddress = TEST_ADDRV6)
+        assertEquals(TEST_ADDRV6, packet2.srcAddress)
+        assertEquals(TEST_ADDRV6, packet2.dstAddress)
+    }
+
     @Test @IgnoreUpTo(Build.VERSION_CODES.Q)
     fun testParcel() {
         assertParcelingIsLossless(nattKeepalivePacket())
diff --git a/tests/cts/OWNERS b/tests/cts/OWNERS
index 8388cb7..8c2408b 100644
--- a/tests/cts/OWNERS
+++ b/tests/cts/OWNERS
@@ -1,4 +1,5 @@
 # Bug template url: http://b/new?component=31808
+# TODO: move bug template config to common owners file once b/226427845 is resolved
 set noparent
 file:platform/packages/modules/Connectivity:master:/OWNERS_core_networking_xts
 
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
index 73a6502..cd3b650 100755
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
@@ -1175,8 +1175,8 @@
 
         final String origMode = runWithShellPermissionIdentity(() -> {
             final String mode = DeviceConfig.getProperty(
-                    DeviceConfig.NAMESPACE_CONNECTIVITY, AUTOMATIC_ON_OFF_KEEPALIVE_VERSION);
-            DeviceConfig.setProperty(DeviceConfig.NAMESPACE_CONNECTIVITY,
+                    DeviceConfig.NAMESPACE_TETHERING, AUTOMATIC_ON_OFF_KEEPALIVE_VERSION);
+            DeviceConfig.setProperty(DeviceConfig.NAMESPACE_TETHERING,
                     AUTOMATIC_ON_OFF_KEEPALIVE_VERSION,
                     AUTOMATIC_ON_OFF_KEEPALIVE_ENABLED, false /* makeDefault */);
             return mode;
@@ -1216,7 +1216,7 @@
 
             runWithShellPermissionIdentity(() -> {
                 DeviceConfig.setProperty(
-                                DeviceConfig.NAMESPACE_CONNECTIVITY,
+                                DeviceConfig.NAMESPACE_TETHERING,
                                 AUTOMATIC_ON_OFF_KEEPALIVE_VERSION,
                                 origMode, false);
                 mCM.setTestLowTcpPollingTimerForKeepalive(0);
diff --git a/tests/cts/hostside/src/com/android/cts/net/ProcNetTest.java b/tests/cts/hostside/src/com/android/cts/net/ProcNetTest.java
index 1a528b1..ff06a90 100644
--- a/tests/cts/hostside/src/com/android/cts/net/ProcNetTest.java
+++ b/tests/cts/hostside/src/com/android/cts/net/ProcNetTest.java
@@ -157,13 +157,13 @@
         for (String interfaceDir : mSysctlDirs) {
             String path = IPV6_SYSCTL_DIR + "/" + interfaceDir + "/" + "router_solicitations";
             int value = readIntFromPath(path);
-            assertEquals(IPV6_WIFI_ROUTER_SOLICITATIONS, value);
+            assertEquals(path, IPV6_WIFI_ROUTER_SOLICITATIONS, value);
             path = IPV6_SYSCTL_DIR + "/" + interfaceDir + "/" + "router_solicitation_max_interval";
             int interval = readIntFromPath(path);
             final int lowerBoundSec = 15 * 60;
             final int upperBoundSec = 60 * 60;
-            assertTrue(lowerBoundSec <= interval);
-            assertTrue(interval <= upperBoundSec);
+            assertTrue(path, lowerBoundSec <= interval);
+            assertTrue(path, interval <= upperBoundSec);
         }
     }
 
diff --git a/tests/cts/net/Android.bp b/tests/cts/net/Android.bp
index 51ee074..1276d59 100644
--- a/tests/cts/net/Android.bp
+++ b/tests/cts/net/Android.bp
@@ -92,7 +92,7 @@
         "NetworkStackApiStableShims",
     ],
     jni_uses_sdk_apis: true,
-    min_sdk_version: "29",
+    min_sdk_version: "30",
 }
 
 // Networking CTS tests that target the latest released SDK. These tests can be installed on release
diff --git a/tests/cts/net/jni/Android.bp b/tests/cts/net/jni/Android.bp
index 8f0d78f..a421349 100644
--- a/tests/cts/net/jni/Android.bp
+++ b/tests/cts/net/jni/Android.bp
@@ -31,9 +31,9 @@
         "liblog",
     ],
     stl: "libc++_static",
-    // To be compatible with Q devices, the min_sdk_version must be 29.
+    // To be compatible with R devices, the min_sdk_version must be 30.
     sdk_version: "current",
-    min_sdk_version: "29",
+    min_sdk_version: "30",
 }
 
 cc_library_shared {
diff --git a/tests/cts/net/native/dns/Android.bp b/tests/cts/net/native/dns/Android.bp
index 49b9337..2469710 100644
--- a/tests/cts/net/native/dns/Android.bp
+++ b/tests/cts/net/native/dns/Android.bp
@@ -28,8 +28,8 @@
         "libbase",
         "libnetdutils",
     ],
-    // To be compatible with Q devices, the min_sdk_version must be 29.
-    min_sdk_version: "29",
+    // To be compatible with R devices, the min_sdk_version must be 30.
+    min_sdk_version: "30",
 }
 
 cc_test {
diff --git a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
index ee2f6bb..a0508e1 100644
--- a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
@@ -847,7 +847,7 @@
         assumeTrue(mPackageManager.hasSystemFeature(FEATURE_WIFI));
         assumeTrue(mPackageManager.hasSystemFeature(FEATURE_TELEPHONY));
 
-        Network wifiNetwork = mCtsNetUtils.connectToWifi();
+        Network wifiNetwork = mCtsNetUtils.ensureWifiConnected();
         Network cellNetwork = mCtsNetUtils.connectToCell();
         // This server returns the requestor's IP address as the response body.
         URL url = new URL("http://google-ipv6test.appspot.com/ip.js?fmt=text");
@@ -2560,9 +2560,10 @@
         assertThrows(SecurityException.class, () -> mCm.factoryReset());
     }
 
-    @AppModeFull(reason = "Cannot get WifiManager in instant app mode")
-    @Test
-    public void testFactoryReset() throws Exception {
+    // @AppModeFull(reason = "Cannot get WifiManager in instant app mode")
+    // @Test
+    // Temporarily disable the unreliable test, which is blocked by b/254183718.
+    private void testFactoryReset() throws Exception {
         assumeTrue(TestUtils.shouldTestSApis());
 
         // Store current settings.
@@ -3151,7 +3152,8 @@
     }
 
     @AppModeFull(reason = "Need WiFi support to test the default active network")
-    @Test
+    // NetworkActivityTracker is not mainlined before S.
+    @Test @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.R)
     public void testDefaultNetworkActiveListener() throws Exception {
         final boolean supportWifi = mPackageManager.hasSystemFeature(FEATURE_WIFI);
         final boolean supportTelephony = mPackageManager.hasSystemFeature(FEATURE_TELEPHONY);
@@ -3235,7 +3237,8 @@
             newMobileDataPreferredUids.add(uid);
             ConnectivitySettingsManager.setMobileDataPreferredUids(
                     mContext, newMobileDataPreferredUids);
-            waitForAvailable(defaultTrackingCb, cellNetwork);
+            defaultTrackingCb.eventuallyExpect(CallbackEntry.AVAILABLE, NETWORK_CALLBACK_TIMEOUT_MS,
+                    entry -> cellNetwork.equals(entry.getNetwork()));
             // No change for system default network. Expect no callback except CapabilitiesChanged
             // or LinkPropertiesChanged which may be triggered randomly from wifi network.
             assertNoCallbackExceptCapOrLpChange(systemDefaultCb);
@@ -3247,7 +3250,8 @@
             newMobileDataPreferredUids.remove(uid);
             ConnectivitySettingsManager.setMobileDataPreferredUids(
                     mContext, newMobileDataPreferredUids);
-            waitForAvailable(defaultTrackingCb, wifiNetwork);
+            defaultTrackingCb.eventuallyExpect(CallbackEntry.AVAILABLE, NETWORK_CALLBACK_TIMEOUT_MS,
+                    entry -> wifiNetwork.equals(entry.getNetwork()));
             // No change for system default network. Expect no callback except CapabilitiesChanged
             // or LinkPropertiesChanged which may be triggered randomly from wifi network.
             assertNoCallbackExceptCapOrLpChange(systemDefaultCb);
@@ -3410,6 +3414,7 @@
 
     private void checkFirewallBlocking(final DatagramSocket srcSock, final DatagramSocket dstSock,
             final boolean expectBlock, final int chain) throws Exception {
+        final int uid = Process.myUid();
         final Random random = new Random();
         final byte[] sendData = new byte[100];
         random.nextBytes(sendData);
@@ -3425,7 +3430,8 @@
             fail("Expect not to be blocked by firewall but sending packet was blocked:"
                     + " chain=" + chain
                     + " chainEnabled=" + mCm.getFirewallChainEnabled(chain)
-                    + " uidFirewallRule=" + mCm.getUidFirewallRule(chain, Process.myUid()));
+                    + " uid=" + uid
+                    + " uidFirewallRule=" + mCm.getUidFirewallRule(chain, uid));
         }
 
         dstSock.receive(pkt);
@@ -3435,7 +3441,8 @@
             fail("Expect to be blocked by firewall but sending packet was not blocked:"
                     + " chain=" + chain
                     + " chainEnabled=" + mCm.getFirewallChainEnabled(chain)
-                    + " uidFirewallRule=" + mCm.getUidFirewallRule(chain, Process.myUid()));
+                    + " uid=" + uid
+                    + " uidFirewallRule=" + mCm.getUidFirewallRule(chain, uid));
         }
     }
 
diff --git a/tests/cts/net/src/android/net/cts/PacProxyManagerTest.java b/tests/cts/net/src/android/net/cts/PacProxyManagerTest.java
index 4854901..b462f71 100644
--- a/tests/cts/net/src/android/net/cts/PacProxyManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/PacProxyManagerTest.java
@@ -44,12 +44,15 @@
 
 import androidx.test.InstrumentationRegistry;
 
+import com.android.compatibility.common.util.RequiredFeatureRule;
+
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;
 import com.android.testutils.DevSdkIgnoreRunner;
 import com.android.testutils.TestHttpServer;
 
 import org.junit.After;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
@@ -76,6 +79,11 @@
     private ServerSocket mServerSocket;
     private Instrumentation mInstrumentation;
 
+    // Devices without WebView/JavaScript cannot support PAC proxies.
+    @Rule
+    public RequiredFeatureRule mRequiredWebviewFeatureRule =
+        new RequiredFeatureRule(PackageManager.FEATURE_WEBVIEW);
+
     private static final String PAC_FILE = "function FindProxyForURL(url, host)"
             + "{"
             + "  return \"PROXY 192.168.0.1:9091\";"
@@ -152,9 +160,6 @@
     @AppModeFull(reason = "Instant apps can't bind sockets to localhost for a test proxy server")
     @Test
     public void testSetCurrentProxyScriptUrl() throws Exception {
-        // Devices without WebView/JavaScript cannot support PAC proxies
-        assumeTrue(mContext.getPackageManager().hasSystemFeature(PackageManager.FEATURE_WEBVIEW));
-
         // Register a PacProxyInstalledListener
         final TestPacProxyInstalledListener listener = new TestPacProxyInstalledListener();
         final Executor executor = (Runnable r) -> r.run();
diff --git a/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java b/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java
index ce789fc..21f1358 100644
--- a/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java
+++ b/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java
@@ -86,7 +86,7 @@
     private static final String FEATURE_IPSEC_TUNNEL_MIGRATION =
             "android.software.ipsec_tunnel_migration";
 
-    private static final int SOCKET_TIMEOUT_MS = 2000;
+    private static final int SOCKET_TIMEOUT_MS = 10_000;
     private static final int PRIVATE_DNS_PROBE_MS = 1_000;
 
     private static final int PRIVATE_DNS_SETTING_TIMEOUT_MS = 10_000;
diff --git a/tests/cts/netpermission/internetpermission/Android.bp b/tests/cts/netpermission/internetpermission/Android.bp
index 37ad7cb..5314396 100644
--- a/tests/cts/netpermission/internetpermission/Android.bp
+++ b/tests/cts/netpermission/internetpermission/Android.bp
@@ -29,5 +29,5 @@
         "cts",
         "general-tests",
     ],
-
+    host_required: ["net-tests-utils-host-common"],
 }
diff --git a/tests/cts/netpermission/internetpermission/AndroidTest.xml b/tests/cts/netpermission/internetpermission/AndroidTest.xml
index 3b23e72..e326844 100644
--- a/tests/cts/netpermission/internetpermission/AndroidTest.xml
+++ b/tests/cts/netpermission/internetpermission/AndroidTest.xml
@@ -24,6 +24,8 @@
         <option name="cleanup-apks" value="true" />
         <option name="test-file-name" value="CtsNetTestCasesInternetPermission.apk" />
     </target_preparer>
+    <target_preparer class="com.android.testutils.ConnectivityTestTargetPreparer">
+    </target_preparer>
     <test class="com.android.tradefed.testtype.AndroidJUnitTest" >
         <option name="package" value="android.networkpermission.internetpermission.cts" />
         <option name="runtime-hint" value="10s" />
diff --git a/tests/cts/netpermission/updatestatspermission/Android.bp b/tests/cts/netpermission/updatestatspermission/Android.bp
index 7a24886..40474db 100644
--- a/tests/cts/netpermission/updatestatspermission/Android.bp
+++ b/tests/cts/netpermission/updatestatspermission/Android.bp
@@ -29,5 +29,5 @@
         "cts",
         "general-tests",
     ],
-
+    host_required: ["net-tests-utils-host-common"],
 }
diff --git a/tests/cts/netpermission/updatestatspermission/AndroidTest.xml b/tests/cts/netpermission/updatestatspermission/AndroidTest.xml
index c47cad9..a1019fa 100644
--- a/tests/cts/netpermission/updatestatspermission/AndroidTest.xml
+++ b/tests/cts/netpermission/updatestatspermission/AndroidTest.xml
@@ -24,6 +24,8 @@
         <option name="cleanup-apks" value="true" />
         <option name="test-file-name" value="CtsNetTestCasesUpdateStatsPermission.apk" />
     </target_preparer>
+    <target_preparer class="com.android.testutils.ConnectivityTestTargetPreparer">
+    </target_preparer>
     <test class="com.android.tradefed.testtype.AndroidJUnitTest" >
         <option name="package" value="android.networkpermission.updatestatspermission.cts" />
         <option name="runtime-hint" value="10s" />
diff --git a/tests/integration/OWNERS b/tests/integration/OWNERS
new file mode 100644
index 0000000..3101da5
--- /dev/null
+++ b/tests/integration/OWNERS
@@ -0,0 +1,2 @@
+# Bug template url: http://b/new?component=31808
+# TODO: move bug template config to common owners file once b/226427845 is resolved
\ No newline at end of file
diff --git a/tests/mts/Android.bp b/tests/mts/Android.bp
index 74fee3d..6425223 100644
--- a/tests/mts/Android.bp
+++ b/tests/mts/Android.bp
@@ -38,5 +38,5 @@
         "bpf_existence_test.cpp",
     ],
     compile_multilib: "first",
-    min_sdk_version: "29",  // Ensure test runs on Q and above.
+    min_sdk_version: "30",  // Ensure test runs on R and above.
 }
diff --git a/tests/mts/OWNERS b/tests/mts/OWNERS
new file mode 100644
index 0000000..3101da5
--- /dev/null
+++ b/tests/mts/OWNERS
@@ -0,0 +1,2 @@
+# Bug template url: http://b/new?component=31808
+# TODO: move bug template config to common owners file once b/226427845 is resolved
\ No newline at end of file
diff --git a/tests/native/connectivity_native_test/OWNERS b/tests/native/connectivity_native_test/OWNERS
index 8dfa455..fbfcf92 100644
--- a/tests/native/connectivity_native_test/OWNERS
+++ b/tests/native/connectivity_native_test/OWNERS
@@ -1,3 +1,4 @@
-# Bug component: 31808
+# Bug template url: http://b/new?component=31808
+# TODO: move bug template config to common owners file once b/226427845 is resolved
 set noparent
 file:platform/packages/modules/Connectivity:master:/OWNERS_core_networking_xts
diff --git a/tests/unit/OWNERS b/tests/unit/OWNERS
new file mode 100644
index 0000000..3101da5
--- /dev/null
+++ b/tests/unit/OWNERS
@@ -0,0 +1,2 @@
+# Bug template url: http://b/new?component=31808
+# TODO: move bug template config to common owners file once b/226427845 is resolved
\ No newline at end of file
diff --git a/tests/unit/java/android/net/KeepalivePacketDataUtilTest.java b/tests/unit/java/android/net/KeepalivePacketDataUtilTest.java
index 6afa4e9..7e245dc 100644
--- a/tests/unit/java/android/net/KeepalivePacketDataUtilTest.java
+++ b/tests/unit/java/android/net/KeepalivePacketDataUtilTest.java
@@ -168,12 +168,6 @@
         assertEquals(resultData.rcvWndScale, wndScale);
         assertEquals(resultData.tos, tos);
         assertEquals(resultData.ttl, ttl);
-
-        final String expected = TcpKeepalivePacketDataParcelable.class.getName()
-                + "{srcAddress: [10, 0, 0, 1],"
-                + " srcPort: 1234, dstAddress: [10, 0, 0, 5], dstPort: 4321, seq: 286331153,"
-                + " ack: 572662306, rcvWnd: 48000, rcvWndScale: 2, tos: 4, ttl: 64}";
-        assertEquals(expected, resultData.toString());
     }
 
     @Test
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 57c3acc..9406edd 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -46,14 +46,18 @@
 import static android.content.pm.PackageManager.PERMISSION_DENIED;
 import static android.content.pm.PackageManager.PERMISSION_GRANTED;
 import static android.net.ConnectivityManager.ACTION_CAPTIVE_PORTAL_SIGN_IN;
+import static android.net.ConnectivityManager.ACTION_DATA_ACTIVITY_CHANGE;
 import static android.net.ConnectivityManager.BLOCKED_METERED_REASON_DATA_SAVER;
 import static android.net.ConnectivityManager.BLOCKED_METERED_REASON_MASK;
 import static android.net.ConnectivityManager.BLOCKED_METERED_REASON_USER_RESTRICTED;
 import static android.net.ConnectivityManager.BLOCKED_REASON_BATTERY_SAVER;
 import static android.net.ConnectivityManager.BLOCKED_REASON_NONE;
 import static android.net.ConnectivityManager.CONNECTIVITY_ACTION;
+import static android.net.ConnectivityManager.EXTRA_DEVICE_TYPE;
+import static android.net.ConnectivityManager.EXTRA_IS_ACTIVE;
 import static android.net.ConnectivityManager.EXTRA_NETWORK_INFO;
 import static android.net.ConnectivityManager.EXTRA_NETWORK_TYPE;
+import static android.net.ConnectivityManager.EXTRA_REALTIME_NS;
 import static android.net.ConnectivityManager.FIREWALL_CHAIN_DOZABLE;
 import static android.net.ConnectivityManager.FIREWALL_CHAIN_LOW_POWER_STANDBY;
 import static android.net.ConnectivityManager.FIREWALL_CHAIN_OEM_DENY_1;
@@ -161,6 +165,7 @@
 import static com.android.server.ConnectivityServiceTestUtils.transportToLegacyType;
 import static com.android.server.NetworkAgentWrapper.CallbackType.OnQosCallbackRegister;
 import static com.android.server.NetworkAgentWrapper.CallbackType.OnQosCallbackUnregister;
+import static com.android.testutils.Cleanup.testAndCleanup;
 import static com.android.testutils.ConcurrentUtils.await;
 import static com.android.testutils.ConcurrentUtils.durationOf;
 import static com.android.testutils.DevSdkIgnoreRule.IgnoreAfter;
@@ -379,6 +384,7 @@
 import com.android.internal.util.test.BroadcastInterceptingContext;
 import com.android.internal.util.test.FakeSettingsProvider;
 import com.android.net.module.util.ArrayTrackRecord;
+import com.android.net.module.util.BaseNetdUnsolicitedEventListener;
 import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.LocationPermissionChecker;
 import com.android.net.module.util.NetworkMonitorUtils;
@@ -537,10 +543,12 @@
     private static final String WIFI_IFNAME = "test_wlan0";
     private static final String WIFI_WOL_IFNAME = "test_wlan_wol";
     private static final String VPN_IFNAME = "tun10042";
+    private static final String ETHERNET_IFNAME = "eth0";
     private static final String TEST_PACKAGE_NAME = "com.android.test.package";
     private static final int TEST_PACKAGE_UID = 123;
     private static final int TEST_PACKAGE_UID2 = 321;
     private static final int TEST_PACKAGE_UID3 = 456;
+    private static final int NETWORK_ACTIVITY_NO_UID = -1;
 
     private static final int PACKET_WAKEUP_MARK_MASK = 0x80000000;
 
@@ -888,6 +896,25 @@
             }
             super.sendStickyBroadcast(intent, options);
         }
+
+        private final ArrayTrackRecord<Intent>.ReadHead mOrderedBroadcastAsUserHistory =
+                new ArrayTrackRecord<Intent>().newReadHead();
+
+        public void expectDataActivityBroadcast(int deviceType, boolean isActive, long tsNanos) {
+            assertNotNull(mOrderedBroadcastAsUserHistory.poll(BROADCAST_TIMEOUT_MS,
+                    intent -> ACTION_DATA_ACTIVITY_CHANGE.equals(intent.getAction()) // null-safe
+                            && intent.getIntExtra(EXTRA_DEVICE_TYPE, -1) == deviceType
+                            && intent.getBooleanExtra(EXTRA_IS_ACTIVE, !isActive) == isActive
+                            && intent.getLongExtra(EXTRA_REALTIME_NS, -1) == tsNanos
+            ));
+        }
+
+        @Override
+        public void sendOrderedBroadcastAsUser(Intent intent, UserHandle user,
+                String receiverPermission, BroadcastReceiver resultReceiver,
+                Handler scheduler, int initialCode, String initialData, Bundle initialExtras) {
+            mOrderedBroadcastAsUserHistory.add(intent);
+        }
     }
 
     // This was only added in the T SDK, but this test needs to build against the R+S SDKs, too.
@@ -2247,12 +2274,12 @@
             mDestroySocketsWrapper.destroyLiveTcpSocketsByOwnerUids(ownerUids);
         }
 
-        final ArrayTrackRecord<Long>.ReadHead mScheduledEvaluationTimeouts =
-                new ArrayTrackRecord<Long>().newReadHead();
+        final ArrayTrackRecord<Pair<Integer, Long>>.ReadHead mScheduledEvaluationTimeouts =
+                new ArrayTrackRecord<Pair<Integer, Long>>().newReadHead();
         @Override
         public void scheduleEvaluationTimeout(@NonNull Handler handler,
                 @NonNull final Network network, final long delayMs) {
-            mScheduledEvaluationTimeouts.add(delayMs);
+            mScheduledEvaluationTimeouts.add(new Pair<>(network.netId, delayMs));
             super.scheduleEvaluationTimeout(handler, network, delayMs);
         }
     }
@@ -6124,7 +6151,7 @@
     }
 
     public void doTestPreferBadWifi(final boolean avoidBadWifi,
-            final boolean preferBadWifi,
+            final boolean preferBadWifi, final boolean explicitlySelected,
             @NonNull Predicate<Long> checkUnvalidationTimeout) throws Exception {
         // Pretend we're on a carrier that restricts switching away from bad wifi, and
         // depending on the parameter one that may indeed prefer bad wifi.
@@ -6148,10 +6175,13 @@
         mDefaultNetworkCallback.expectAvailableThenValidatedCallbacks(mCellAgent);
 
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
+        mWiFiAgent.explicitlySelected(explicitlySelected, false /* acceptUnvalidated */);
         mWiFiAgent.connect(false);
         wifiCallback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
 
-        mDeps.mScheduledEvaluationTimeouts.poll(TIMEOUT_MS, t -> checkUnvalidationTimeout.test(t));
+        assertNotNull(mDeps.mScheduledEvaluationTimeouts.poll(TIMEOUT_MS,
+                t -> t.first == mWiFiAgent.getNetwork().netId
+                        && checkUnvalidationTimeout.test(t.second)));
 
         if (!avoidBadWifi && preferBadWifi) {
             expectUnvalidationCheckWillNotify(mWiFiAgent, NotificationType.LOST_INTERNET);
@@ -6167,27 +6197,33 @@
         // Starting with U this mode is no longer supported and can't actually be tested
         assumeFalse(mDeps.isAtLeastU());
         doTestPreferBadWifi(false /* avoidBadWifi */, false /* preferBadWifi */,
-                timeout -> timeout < 14_000);
+                false /* explicitlySelected */, timeout -> timeout < 14_000);
     }
 
     @Test
-    public void testPreferBadWifi_doNotAvoid_doPrefer() throws Exception {
+    public void testPreferBadWifi_doNotAvoid_doPrefer_notExplicit() throws Exception {
         doTestPreferBadWifi(false /* avoidBadWifi */, true /* preferBadWifi */,
-                timeout -> timeout > 14_000);
+                false /* explicitlySelected */, timeout -> timeout > 14_000);
+    }
+
+    @Test
+    public void testPreferBadWifi_doNotAvoid_doPrefer_explicitlySelected() throws Exception {
+        doTestPreferBadWifi(false /* avoidBadWifi */, true /* preferBadWifi */,
+                true /* explicitlySelected */, timeout -> timeout < 14_000);
     }
 
     @Test
     public void testPreferBadWifi_doAvoid_doNotPrefer() throws Exception {
         // If avoidBadWifi=true, then preferBadWifi should be irrelevant. Test anyway.
         doTestPreferBadWifi(true /* avoidBadWifi */, false /* preferBadWifi */,
-                timeout -> timeout < 14_000);
+                false /* explicitlySelected */, timeout -> timeout < 14_000);
     }
 
     @Test
     public void testPreferBadWifi_doAvoid_doPrefer() throws Exception {
         // If avoidBadWifi=true, then preferBadWifi should be irrelevant. Test anyway.
         doTestPreferBadWifi(true /* avoidBadWifi */, true /* preferBadWifi */,
-                timeout -> timeout < 14_000);
+                false /* explicitlySelected */, timeout -> timeout < 14_000);
     }
 
     @Test
@@ -6856,10 +6892,6 @@
         ka = mCm.startNattKeepalive(myNet, validKaInterval, callback, myIPv6, 1234, dstIPv4);
         callback.expectError(PacketKeepalive.ERROR_INVALID_IP_ADDRESS);
 
-        // NAT-T is only supported for IPv4.
-        ka = mCm.startNattKeepalive(myNet, validKaInterval, callback, myIPv6, 1234, dstIPv6);
-        callback.expectError(PacketKeepalive.ERROR_INVALID_IP_ADDRESS);
-
         ka = mCm.startNattKeepalive(myNet, validKaInterval, callback, myIPv4, 123456, dstIPv4);
         callback.expectError(PacketKeepalive.ERROR_INVALID_PORT);
 
@@ -7010,13 +7042,6 @@
             callback.expectError(SocketKeepalive.ERROR_INVALID_IP_ADDRESS);
         }
 
-        // NAT-T is only supported for IPv4.
-        try (SocketKeepalive ka = mCm.createSocketKeepalive(
-                myNet, testSocket, myIPv6, dstIPv6, executor, callback)) {
-            ka.start(validKaInterval);
-            callback.expectError(SocketKeepalive.ERROR_INVALID_IP_ADDRESS);
-        }
-
         // Basic check before testing started keepalive.
         try (SocketKeepalive ka = mCm.createSocketKeepalive(
                 myNet, testSocket, myIPv4, dstIPv4, executor, callback)) {
@@ -11195,6 +11220,127 @@
         assertTrue("Nat464Xlat was not IDLE", !clat.isStarted());
     }
 
+    final String transportToTestIfaceName(int transport) {
+        switch (transport) {
+            case TRANSPORT_WIFI:
+                return WIFI_IFNAME;
+            case TRANSPORT_CELLULAR:
+                return MOBILE_IFNAME;
+            case TRANSPORT_ETHERNET:
+                return ETHERNET_IFNAME;
+            default:
+                throw new AssertionError("Unsupported transport type " + transport);
+        }
+    }
+
+    private void doTestInterfaceClassActivityChanged(final int transportType) throws Exception {
+        final int legacyType = transportToLegacyType(transportType);
+        final LinkProperties lp = new LinkProperties();
+        lp.setInterfaceName(transportToTestIfaceName(transportType));
+        final TestNetworkAgentWrapper agent = new TestNetworkAgentWrapper(transportType, lp);
+
+        final ConditionVariable onNetworkActiveCv = new ConditionVariable();
+        final ConnectivityManager.OnNetworkActiveListener listener = onNetworkActiveCv::open;
+
+        testAndCleanup(() -> {
+            agent.connect(true);
+
+            // Network is considered active when the network becomes the default network.
+            assertTrue(mCm.isDefaultNetworkActive());
+
+            mCm.addDefaultNetworkActiveListener(listener);
+
+            ArgumentCaptor<BaseNetdUnsolicitedEventListener> netdCallbackCaptor =
+                    ArgumentCaptor.forClass(BaseNetdUnsolicitedEventListener.class);
+            verify(mMockNetd).registerUnsolicitedEventListener(netdCallbackCaptor.capture());
+
+            // Interface goes to inactive state
+            netdCallbackCaptor.getValue().onInterfaceClassActivityChanged(false /* isActive */,
+                    transportType, TIMESTAMP, NETWORK_ACTIVITY_NO_UID);
+            mServiceContext.expectDataActivityBroadcast(legacyType, false /* isActive */,
+                    TIMESTAMP);
+            assertFalse(onNetworkActiveCv.block(TEST_CALLBACK_TIMEOUT_MS));
+            assertFalse(mCm.isDefaultNetworkActive());
+
+            // Interface goes to active state
+            netdCallbackCaptor.getValue().onInterfaceClassActivityChanged(true /* isActive */,
+                    transportType, TIMESTAMP, TEST_PACKAGE_UID);
+            mServiceContext.expectDataActivityBroadcast(legacyType, true /* isActive */, TIMESTAMP);
+            assertTrue(onNetworkActiveCv.block(TEST_CALLBACK_TIMEOUT_MS));
+            assertTrue(mCm.isDefaultNetworkActive());
+        }, () -> { // Cleanup
+                mCm.removeDefaultNetworkActiveListener(listener);
+            }, () -> { // Cleanup
+                agent.disconnect();
+            });
+    }
+
+    @Test
+    public void testInterfaceClassActivityChangedWifi() throws Exception {
+        doTestInterfaceClassActivityChanged(TRANSPORT_WIFI);
+    }
+
+    @Test
+    public void testInterfaceClassActivityChangedCellular() throws Exception {
+        doTestInterfaceClassActivityChanged(TRANSPORT_CELLULAR);
+    }
+
+    private void doTestOnNetworkActive_NewNetworkConnects(int transportType, boolean expectCallback)
+            throws Exception {
+        final ConditionVariable onNetworkActiveCv = new ConditionVariable();
+        final ConnectivityManager.OnNetworkActiveListener listener = onNetworkActiveCv::open;
+
+        final LinkProperties lp = new LinkProperties();
+        lp.setInterfaceName(transportToTestIfaceName(transportType));
+        final TestNetworkAgentWrapper agent = new TestNetworkAgentWrapper(transportType, lp);
+
+        testAndCleanup(() -> {
+            mCm.addDefaultNetworkActiveListener(listener);
+            agent.connect(true);
+            if (expectCallback) {
+                assertTrue(onNetworkActiveCv.block(TEST_CALLBACK_TIMEOUT_MS));
+            } else {
+                assertFalse(onNetworkActiveCv.block(TEST_CALLBACK_TIMEOUT_MS));
+            }
+            assertTrue(mCm.isDefaultNetworkActive());
+        }, () -> { // Cleanup
+                mCm.removeDefaultNetworkActiveListener(listener);
+            }, () -> { // Cleanup
+                agent.disconnect();
+            });
+    }
+
+    @Test
+    public void testOnNetworkActive_NewCellConnects_CallbackCalled() throws Exception {
+        doTestOnNetworkActive_NewNetworkConnects(TRANSPORT_CELLULAR, true /* expectCallback */);
+    }
+
+    @Test
+    public void testOnNetworkActive_NewEthernetConnects_Callback() throws Exception {
+        // On T-, LegacyNetworkActivityTracker calls the onNetworkActive callback only for networks
+        // that the tracker adds an idle timer to, and the tracker does not set an idle timer for
+        // ethernet networks.
+        // So onNetworkActive is not called when ethernet becomes the default network.
+        doTestOnNetworkActive_NewNetworkConnects(TRANSPORT_ETHERNET, mDeps.isAtLeastU());
+    }
+
+    @Test
+    public void testIsDefaultNetworkActiveNoDefaultNetwork() throws Exception {
+        assertFalse(mCm.isDefaultNetworkActive());
+
+        final LinkProperties cellLp = new LinkProperties();
+        cellLp.setInterfaceName(MOBILE_IFNAME);
+        mCellAgent = new TestNetworkAgentWrapper(TRANSPORT_CELLULAR, cellLp);
+        mCellAgent.connect(true);
+        // Network is considered active when the network becomes the default network.
+        assertTrue(mCm.isDefaultNetworkActive());
+
+        mCellAgent.disconnect();
+        waitForIdle();
+
+        assertFalse(mCm.isDefaultNetworkActive());
+    }
+
     @Test
     public void testDataActivityTracking() throws Exception {
         final TestNetworkCallback networkCallback = new TestNetworkCallback();
diff --git a/tests/unit/java/com/android/server/VpnManagerServiceTest.java b/tests/unit/java/com/android/server/VpnManagerServiceTest.java
index deb56ef..bf23cd1 100644
--- a/tests/unit/java/com/android/server/VpnManagerServiceTest.java
+++ b/tests/unit/java/com/android/server/VpnManagerServiceTest.java
@@ -75,12 +75,15 @@
 @IgnoreUpTo(R) // VpnManagerService is not available before R
 @SmallTest
 public class VpnManagerServiceTest extends VpnTestBase {
+    private static final String CONTEXT_ATTRIBUTION_TAG = "VPN_MANAGER";
+
     @Rule
     public final DevSdkIgnoreRule mIgnoreRule = new DevSdkIgnoreRule();
 
     private static final int TIMEOUT_MS = 2_000;
 
     @Mock Context mContext;
+    @Mock Context mContextWithoutAttributionTag;
     @Mock Context mSystemContext;
     @Mock Context mUserAllContext;
     private HandlerThread mHandlerThread;
@@ -144,6 +147,13 @@
 
         mHandlerThread = new HandlerThread("TestVpnManagerService");
         mDeps = new VpnManagerServiceDependencies();
+
+        // The attribution tag is a dependency for the IKE library to collect VPN metrics correctly
+        // and thus should not be changed without updating the IKE code.
+        doReturn(mContext)
+                .when(mContextWithoutAttributionTag)
+                .createAttributionContext(CONTEXT_ATTRIBUTION_TAG);
+
         doReturn(mUserAllContext).when(mContext).createContextAsUser(UserHandle.ALL, 0);
         doReturn(mSystemContext).when(mContext).createContextAsUser(UserHandle.SYSTEM, 0);
         doReturn(mPackageManager).when(mContext).getPackageManager();
@@ -153,7 +163,7 @@
         mockService(mContext, UserManager.class, Context.USER_SERVICE, mUserManager);
         doReturn(SYSTEM_USER).when(mUserManager).getUserInfo(eq(SYSTEM_USER_ID));
 
-        mService = new VpnManagerService(mContext, mDeps);
+        mService = new VpnManagerService(mContextWithoutAttributionTag, mDeps);
         mService.systemReady();
 
         final ArgumentCaptor<BroadcastReceiver> intentReceiverCaptor =
diff --git a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
index 0aecd64..9e604e3 100644
--- a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
@@ -21,6 +21,7 @@
 import static android.net.NetworkAgent.CMD_STOP_SOCKET_KEEPALIVE;
 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
 
+import static com.android.server.connectivity.AutomaticOnOffKeepaliveTracker.METRICS_COLLECTION_DURATION_MS;
 import static com.android.testutils.HandlerUtils.visibleOnHandlerThread;
 
 import static org.junit.Assert.assertEquals;
@@ -42,6 +43,7 @@
 import static org.mockito.Mockito.ignoreStubs;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 
@@ -125,6 +127,7 @@
     @Mock NetworkAgentInfo mNai;
     @Mock SubscriptionManager mSubscriptionManager;
 
+    KeepaliveStatsTracker mKeepaliveStatsTracker;
     TestKeepaliveTracker mKeepaliveTracker;
     AOOTestHandler mTestHandler;
     TestTcpKeepaliveController mTcpController;
@@ -262,7 +265,7 @@
 
         TestKeepaliveTracker(@NonNull final Context context, @NonNull final Handler handler,
                 @NonNull final TcpKeepaliveController tcpController) {
-            super(context, handler, tcpController);
+            super(context, handler, tcpController, new Dependencies());
         }
 
         public void setReturnedKeepaliveInfo(@NonNull final KeepaliveInfo ki) {
@@ -344,8 +347,14 @@
         mTestHandler = new AOOTestHandler(mHandlerThread.getLooper());
         mTcpController = new TestTcpKeepaliveController(mTestHandler);
         mKeepaliveTracker = new TestKeepaliveTracker(mCtx, mTestHandler, mTcpController);
+        mKeepaliveStatsTracker = spy(new KeepaliveStatsTracker(mCtx, mTestHandler));
         doReturn(mKeepaliveTracker).when(mDependencies).newKeepaliveTracker(mCtx, mTestHandler);
+        doReturn(mKeepaliveStatsTracker)
+                .when(mDependencies)
+                .newKeepaliveStatsTracker(mCtx, mTestHandler);
+
         doReturn(true).when(mDependencies).isFeatureEnabled(any(), anyBoolean());
+        doReturn(0L).when(mDependencies).getElapsedRealtime();
         mAOOKeepaliveTracker =
                 new AutomaticOnOffKeepaliveTracker(mCtx, mTestHandler, mDependencies);
     }
@@ -499,6 +508,30 @@
         assertEquals(testInfo.underpinnedNetwork, mTestHandler.mLastAutoKi.getUnderpinnedNetwork());
     }
 
+    @Test
+    public void testAlarm_writeMetrics() throws Exception {
+        final ArgumentCaptor<AlarmManager.OnAlarmListener> listenerCaptor =
+                ArgumentCaptor.forClass(AlarmManager.OnAlarmListener.class);
+
+        // First AlarmManager.set call from the constructor.
+        verify(mAlarmManager).set(eq(AlarmManager.ELAPSED_REALTIME_WAKEUP),
+                eq(METRICS_COLLECTION_DURATION_MS), any() /* tag */, listenerCaptor.capture(),
+                eq(mTestHandler));
+
+        final AlarmManager.OnAlarmListener listener = listenerCaptor.getValue();
+
+        doReturn(METRICS_COLLECTION_DURATION_MS).when(mDependencies).getElapsedRealtime();
+        // For realism, the listener should be posted on the handler
+        mTestHandler.post(() -> listener.onAlarm());
+        HandlerUtils.waitForIdle(mTestHandler, TIMEOUT_MS);
+
+        verify(mKeepaliveStatsTracker).writeAndResetMetrics();
+        // Alarm is rescheduled.
+        verify(mAlarmManager).set(eq(AlarmManager.ELAPSED_REALTIME_WAKEUP),
+                eq(METRICS_COLLECTION_DURATION_MS * 2),
+                any() /* tag */, listenerCaptor.capture(), eq(mTestHandler));
+    }
+
     private void setupResponseWithSocketExisting() throws Exception {
         final ByteBuffer tcpBufferV6 = getByteBuffer(TEST_RESPONSE_BYTES);
         final ByteBuffer tcpBufferV4 = getByteBuffer(TEST_RESPONSE_BYTES);
@@ -659,6 +692,10 @@
         assertNull(getAutoKiForBinder(testInfo.binder));
 
         verifyNoMoreInteractions(ignoreStubs(testInfo.socketKeepaliveCallback));
+
+        // Make sure the slot is free
+        final TestKeepaliveInfo testInfo2 = doStartNattKeepalive();
+        checkAndProcessKeepaliveStart(testInfo2.kpd);
     }
 
     @Test
@@ -787,36 +824,34 @@
 
         clearInvocations(mNai);
         // Start the second keepalive while the first is paused.
-        // TODO: Uncomment the following test after fixing b/283886067. Currently this attempts to
-        // start the keepalive on TEST_SLOT and this throws in the handler thread.
-        // final TestKeepaliveInfo testInfo2 = doStartNattKeepalive();
-        // // The slot used is TEST_SLOT + 1 since TEST_SLOT is being taken by the paused keepalive.
-        // checkAndProcessKeepaliveStart(TEST_SLOT + 1, testInfo2.kpd);
-        // verify(testInfo2.socketKeepaliveCallback).onStarted();
-        // assertNotNull(getAutoKiForBinder(testInfo2.binder));
+        final TestKeepaliveInfo testInfo2 = doStartNattKeepalive();
+        // The slot used is TEST_SLOT + 1 since TEST_SLOT is being taken by the paused keepalive.
+        checkAndProcessKeepaliveStart(TEST_SLOT + 1, testInfo2.kpd);
+        verify(testInfo2.socketKeepaliveCallback).onStarted();
+        assertNotNull(getAutoKiForBinder(testInfo2.binder));
 
-        // clearInvocations(mNai);
-        // doResumeKeepalive(autoKi1);
-        // // Resume on TEST_SLOT.
-        // checkAndProcessKeepaliveStart(TEST_SLOT, testInfo1.kpd);
-        // verify(testInfo1.socketKeepaliveCallback).onResumed();
+        clearInvocations(mNai);
+        doResumeKeepalive(autoKi1);
+        // Resume on TEST_SLOT.
+        checkAndProcessKeepaliveStart(TEST_SLOT, testInfo1.kpd);
+        verify(testInfo1.socketKeepaliveCallback).onResumed();
 
-        // clearInvocations(mNai);
-        // doStopKeepalive(autoKi1);
-        // checkAndProcessKeepaliveStop(TEST_SLOT);
-        // verify(testInfo1.socketKeepaliveCallback).onStopped();
-        // verify(testInfo2.socketKeepaliveCallback, never()).onStopped();
-        // assertNull(getAutoKiForBinder(testInfo1.binder));
+        clearInvocations(mNai);
+        doStopKeepalive(autoKi1);
+        checkAndProcessKeepaliveStop(TEST_SLOT);
+        verify(testInfo1.socketKeepaliveCallback).onStopped();
+        verify(testInfo2.socketKeepaliveCallback, never()).onStopped();
+        assertNull(getAutoKiForBinder(testInfo1.binder));
 
-        // clearInvocations(mNai);
-        // assertNotNull(getAutoKiForBinder(testInfo2.binder));
-        // doStopKeepalive(getAutoKiForBinder(testInfo2.binder));
-        // checkAndProcessKeepaliveStop(TEST_SLOT + 1);
-        // verify(testInfo2.socketKeepaliveCallback).onStopped();
-        // assertNull(getAutoKiForBinder(testInfo2.binder));
+        clearInvocations(mNai);
+        assertNotNull(getAutoKiForBinder(testInfo2.binder));
+        doStopKeepalive(getAutoKiForBinder(testInfo2.binder));
+        checkAndProcessKeepaliveStop(TEST_SLOT + 1);
+        verify(testInfo2.socketKeepaliveCallback).onStopped();
+        assertNull(getAutoKiForBinder(testInfo2.binder));
 
-        // verifyNoMoreInteractions(ignoreStubs(testInfo1.socketKeepaliveCallback));
-        // verifyNoMoreInteractions(ignoreStubs(testInfo2.socketKeepaliveCallback));
+        verifyNoMoreInteractions(ignoreStubs(testInfo1.socketKeepaliveCallback));
+        verifyNoMoreInteractions(ignoreStubs(testInfo2.socketKeepaliveCallback));
     }
 
     @Test
diff --git a/tests/unit/java/com/android/server/connectivity/KeepaliveStatsTrackerTest.java b/tests/unit/java/com/android/server/connectivity/KeepaliveStatsTrackerTest.java
index 9472ded..0d2e540 100644
--- a/tests/unit/java/com/android/server/connectivity/KeepaliveStatsTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/KeepaliveStatsTrackerTest.java
@@ -27,12 +27,15 @@
 import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 
+import android.content.BroadcastReceiver;
 import android.content.Context;
+import android.content.Intent;
 import android.net.Network;
 import android.net.NetworkCapabilities;
 import android.net.TelephonyNetworkSpecifier;
@@ -109,6 +112,18 @@
     @Mock private KeepaliveStatsTracker.Dependencies mDependencies;
     @Mock private SubscriptionManager mSubscriptionManager;
 
+    private void triggerBroadcastDefaultSubId(int subId) {
+        final ArgumentCaptor<BroadcastReceiver> receiverCaptor =
+                ArgumentCaptor.forClass(BroadcastReceiver.class);
+        verify(mContext).registerReceiver(receiverCaptor.capture(), /* filter= */ any(),
+                /* broadcastPermission= */ any(), eq(mTestHandler));
+        final Intent intent =
+                new Intent(TelephonyManager.ACTION_SUBSCRIPTION_CARRIER_IDENTITY_CHANGED);
+        intent.putExtra(SubscriptionManager.EXTRA_SUBSCRIPTION_INDEX, subId);
+
+        receiverCaptor.getValue().onReceive(mContext, intent);
+    }
+
     private OnSubscriptionsChangedListener getOnSubscriptionsChangedListener() {
         final ArgumentCaptor<OnSubscriptionsChangedListener> listenerCaptor =
                 ArgumentCaptor.forClass(OnSubscriptionsChangedListener.class);
@@ -1128,4 +1143,53 @@
                             writeTime * 3 - startTime1 - startTime2 - startTime3)
                 });
     }
+
+    @Test
+    public void testUpdateDefaultSubId() {
+        final int startTime1 = 1000;
+        final int startTime2 = 3000;
+        final int writeTime = 5000;
+
+        // No TelephonyNetworkSpecifier set with subId to force the use of default subId.
+        final NetworkCapabilities nc =
+                new NetworkCapabilities.Builder().addTransportType(TRANSPORT_CELLULAR).build();
+        onStartKeepalive(startTime1, TEST_SLOT, nc);
+        // Update default subId
+        triggerBroadcastDefaultSubId(TEST_SUB_ID_1);
+        onStartKeepalive(startTime2, TEST_SLOT2, nc);
+
+        final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
+                buildKeepaliveMetrics(writeTime);
+
+        final int[] expectRegisteredDurations =
+                new int[] {startTime1, startTime2 - startTime1, writeTime - startTime2};
+        final int[] expectActiveDurations =
+                new int[] {startTime1, startTime2 - startTime1, writeTime - startTime2};
+        // Expect the carrier id of the first keepalive to be unknown
+        final KeepaliveCarrierStats expectKeepaliveCarrierStats1 =
+                new KeepaliveCarrierStats(
+                        TelephonyManager.UNKNOWN_CARRIER_ID,
+                        /* transportTypes= */ (1 << TRANSPORT_CELLULAR),
+                        TEST_KEEPALIVE_INTERVAL_SEC * 1000,
+                        writeTime - startTime1,
+                        writeTime - startTime1);
+        // Expect the carrier id of the second keepalive to be TEST_CARRIER_ID_1, from TEST_SUB_ID_1
+        final KeepaliveCarrierStats expectKeepaliveCarrierStats2 =
+                new KeepaliveCarrierStats(
+                        TEST_CARRIER_ID_1,
+                        /* transportTypes= */ (1 << TRANSPORT_CELLULAR),
+                        TEST_KEEPALIVE_INTERVAL_SEC * 1000,
+                        writeTime - startTime2,
+                        writeTime - startTime2);
+        assertDailyKeepaliveInfoReported(
+                dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 2,
+                /* expectAutoRequestsCount= */ 2,
+                /* expectAppUids= */ new int[] {TEST_UID},
+                expectRegisteredDurations,
+                expectActiveDurations,
+                new KeepaliveCarrierStats[] {
+                    expectKeepaliveCarrierStats1, expectKeepaliveCarrierStats2
+                });
+    }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/Nat464XlatTest.java b/tests/unit/java/com/android/server/connectivity/Nat464XlatTest.java
index 06e0d6d..58c0114 100644
--- a/tests/unit/java/com/android/server/connectivity/Nat464XlatTest.java
+++ b/tests/unit/java/com/android/server/connectivity/Nat464XlatTest.java
@@ -20,10 +20,12 @@
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.anyInt;
 import static org.mockito.Mockito.anyString;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.never;
@@ -72,6 +74,7 @@
     static final String STACKED_IFACE = "v4-test0";
     static final LinkAddress V6ADDR = new LinkAddress("2001:db8:1::f00/64");
     static final LinkAddress ADDR = new LinkAddress("192.0.2.5/29");
+    static final String CLAT_V6 = "64:ff9b::1";
     static final String NAT64_PREFIX = "64:ff9b::/96";
     static final String OTHER_NAT64_PREFIX = "2001:db8:0:64::/96";
     static final int NETID = 42;
@@ -132,6 +135,8 @@
         when(mNetd.interfaceGetCfg(eq(STACKED_IFACE))).thenReturn(mConfig);
         mConfig.ipv4Addr = ADDR.getAddress().getHostAddress();
         mConfig.prefixLength =  ADDR.getPrefixLength();
+        doReturn(CLAT_V6).when(mClatCoordinator).clatStart(
+                BASE_IFACE, NETID, new IpPrefix(NAT64_PREFIX));
     }
 
     private void assertRequiresClat(boolean expected, NetworkAgentInfo nai) {
@@ -286,7 +291,8 @@
         assertFalse(c.getValue().getAllInterfaceNames().contains(STACKED_IFACE));
         verify(mDnsResolver).stopPrefix64Discovery(eq(NETID));
         assertIdle(nat);
-
+        // Verify the generated IPv6 address is reset when clat is stopped.
+        assertNull(nat.mIPv6Address);
         // Stacked interface removed notification arrives and is ignored.
         nat.interfaceRemoved(STACKED_IFACE);
         mLooper.dispatchNext();
diff --git a/tests/unit/java/com/android/server/connectivity/NetworkNotificationManagerTest.java b/tests/unit/java/com/android/server/connectivity/NetworkNotificationManagerTest.java
index a27a0bf..967083e 100644
--- a/tests/unit/java/com/android/server/connectivity/NetworkNotificationManagerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/NetworkNotificationManagerTest.java
@@ -62,6 +62,7 @@
 import android.util.DisplayMetrics;
 import android.widget.TextView;
 
+import androidx.annotation.NonNull;
 import androidx.annotation.Nullable;
 import androidx.annotation.StringRes;
 import androidx.test.filters.SmallTest;
@@ -386,14 +387,37 @@
     }
 
     @Test
-    public void testNotifyNoInternetAsDialogWhenHighPriority() throws Exception {
-        doReturn(true).when(mResources).getBoolean(
+    public void testNotifyNoInternet_asNotification() throws Exception {
+        doTestNotifyNotificationAsDialogWhenHighPriority(false, NO_INTERNET);
+    }
+    @Test
+    public void testNotifyNoInternet_asDialog() throws Exception {
+        doTestNotifyNotificationAsDialogWhenHighPriority(true /* configActive */, NO_INTERNET);
+    }
+
+    @Test
+    public void testNotifyLostInternet_asNotification() throws Exception {
+        doTestNotifyNotificationAsDialogWhenHighPriority(false, LOST_INTERNET);
+    }
+
+    @Test
+    public void testNotifyLostInternet_asDialog() throws Exception {
+        doTestNotifyNotificationAsDialogWhenHighPriority(true, LOST_INTERNET);
+    }
+
+    public void doTestNotifyNotificationAsDialogWhenHighPriority(final boolean configActive,
+            @NonNull final NotificationType notifType) throws Exception {
+        doReturn(configActive).when(mResources).getBoolean(
                 R.bool.config_notifyNoInternetAsDialogWhenHighPriority);
 
         final Instrumentation instr = InstrumentationRegistry.getInstrumentation();
         final UiDevice uiDevice =  UiDevice.getInstance(instr);
         final Context ctx = instr.getContext();
         final PowerManager pm = ctx.getSystemService(PowerManager.class);
+        // If this type's priority is below NETWORK_SWITCH's, it is the lowest-priority type, so
+        // there is no lower-priority notification whose cancellation could be verified.
+        final boolean isLowestPrioNotif = NetworkNotificationManager.priority(notifType)
+                < NetworkNotificationManager.priority(NETWORK_SWITCH);
 
         // Wake up the device (it has no effect if the device is already awake).
         uiDevice.executeShellCommand("input keyevent KEYCODE_WAKEUP");
@@ -409,9 +433,13 @@
                 uiDevice.wait(Until.hasObject(By.pkg(launcherPackageName)),
                         UI_AUTOMATOR_WAIT_TIME_MILLIS));
 
-        mManager.showNotification(TEST_NOTIF_ID, NETWORK_SWITCH, mWifiNai, mCellNai, null, false);
-        // Non-"no internet" notifications are not affected
-        verify(mNotificationManager).notify(eq(TEST_NOTIF_TAG), eq(NETWORK_SWITCH.eventId), any());
+        if (!isLowestPrioNotif) {
+            mManager.showNotification(TEST_NOTIF_ID, NETWORK_SWITCH, mWifiNai, mCellNai,
+                    null, false);
+            // Notifications of other types (NETWORK_SWITCH here) are posted as usual
+            verify(mNotificationManager).notify(eq(TEST_NOTIF_TAG), eq(NETWORK_SWITCH.eventId),
+                    any());
+        }
 
         final String testAction = "com.android.connectivity.coverage.TEST_DIALOG";
         final Intent intent = new Intent(testAction)
@@ -420,22 +448,30 @@
         final PendingIntent pendingIntent = PendingIntent.getActivity(ctx, 0 /* requestCode */,
                 intent, PendingIntent.FLAG_CANCEL_CURRENT | PendingIntent.FLAG_IMMUTABLE);
 
-        mManager.showNotification(TEST_NOTIF_ID, NO_INTERNET, mWifiNai, null /* switchToNai */,
+        mManager.showNotification(TEST_NOTIF_ID, notifType, mWifiNai, null /* switchToNai */,
                 pendingIntent, true /* highPriority */);
 
-        // Previous notifications are still dismissed
-        verify(mNotificationManager).cancel(TEST_NOTIF_TAG, NETWORK_SWITCH.eventId);
+        if (!isLowestPrioNotif) {
+            // Previous notifications are still dismissed
+            verify(mNotificationManager).cancel(TEST_NOTIF_TAG, NETWORK_SWITCH.eventId);
+        }
 
-        // Verify that the activity is shown (the activity shows the action on screen)
-        final UiObject actionText = uiDevice.findObject(new UiSelector().text(testAction));
-        assertTrue("Activity not shown", actionText.waitForExists(TEST_TIMEOUT_MS));
+        if (configActive) {
+            // Verify that the activity is shown (the activity shows the action on screen)
+            final UiObject actionText = uiDevice.findObject(new UiSelector().text(testAction));
+            assertTrue("Activity not shown", actionText.waitForExists(TEST_TIMEOUT_MS));
 
-        // Tapping the text should dismiss the dialog
-        actionText.click();
-        assertTrue("Activity not dismissed", actionText.waitUntilGone(TEST_TIMEOUT_MS));
+            // Tapping the text should dismiss the dialog
+            actionText.click();
+            assertTrue("Activity not dismissed", actionText.waitUntilGone(TEST_TIMEOUT_MS));
 
-        // Verify no NO_INTERNET notification was posted
-        verify(mNotificationManager, never()).notify(any(), eq(NO_INTERNET.eventId), any());
+            // Verify that the notification was not posted
+            verify(mNotificationManager, never()).notify(any(), eq(notifType.eventId), any());
+        } else {
+            // Notification should have been posted, and will have overridden the previous
+            // one because it has the same id (hence no cancel).
+            verify(mNotificationManager).notify(eq(TEST_NOTIF_TAG), eq(notifType.eventId), any());
+        }
     }
 
     private void doNotificationTextTest(NotificationType type, @StringRes int expectedTitleRes,
diff --git a/tests/unit/java/com/android/server/connectivity/VpnTest.java b/tests/unit/java/com/android/server/connectivity/VpnTest.java
index 2d2819c..dc50773 100644
--- a/tests/unit/java/com/android/server/connectivity/VpnTest.java
+++ b/tests/unit/java/com/android/server/connectivity/VpnTest.java
@@ -80,6 +80,7 @@
 import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -807,6 +808,32 @@
     }
 
     @Test
+    public void testPrepare_legacyVpnWithoutControlVpn()
+            throws Exception {
+        doThrow(new SecurityException("no CONTROL_VPN")).when(mContext)
+                .enforceCallingOrSelfPermission(eq(CONTROL_VPN), any());
+        final Vpn vpn = createVpn();
+        assertThrows(SecurityException.class,
+                () -> vpn.prepare(null, VpnConfig.LEGACY_VPN, VpnManager.TYPE_VPN_SERVICE));
+
+        // CONTROL_VPN can be held by the caller or another system server process - both are
+        // allowed, so the caller-only `enforceCallingPermission` check must never be used.
+        verify(mContext, never()).enforceCallingPermission(eq(CONTROL_VPN), any());
+    }
+
+    @Test
+    public void testPrepare_legacyVpnWithControlVpn()
+            throws Exception {
+        doNothing().when(mContext).enforceCallingOrSelfPermission(eq(CONTROL_VPN), any());
+        final Vpn vpn = createVpn();
+        assertTrue(vpn.prepare(null, VpnConfig.LEGACY_VPN, VpnManager.TYPE_VPN_SERVICE));
+
+        // CONTROL_VPN can be held by the caller or another system server process - both are
+        // allowed, so the caller-only `enforceCallingPermission` check must never be used.
+        verify(mContext, never()).enforceCallingPermission(eq(CONTROL_VPN), any());
+    }
+
+    @Test
     public void testIsAlwaysOnPackageSupported() throws Exception {
         final Vpn vpn = createVpn(PRIMARY_USER.id);
 
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
index d9acc61..6a0334f 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
@@ -56,7 +56,8 @@
 private val TEST_ADDR = parseNumericAddress("2001:db8::123")
 private val TEST_LINKADDR = LinkAddress(TEST_ADDR, 64 /* prefixLength */)
 private val TEST_NETWORK_1 = mock(Network::class.java)
-private val TEST_NETWORK_2 = mock(Network::class.java)
+private val TEST_SOCKETKEY_1 = SocketKey(1001 /* interfaceIndex */)
+private val TEST_SOCKETKEY_2 = SocketKey(1002 /* interfaceIndex */)
 private val TEST_HOSTNAME = arrayOf("Android_test", "local")
 private const val TEST_SUBTYPE = "_subtype"
 
@@ -145,7 +146,7 @@
         verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture())
 
         val socketCb = socketCbCaptor.value
-        postSync { socketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) }
+        postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR)) }
 
         val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
         verify(mockDeps).makeAdvertiser(
@@ -163,7 +164,7 @@
                 mockInterfaceAdvertiser1, SERVICE_ID_1) }
         verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_1), argThat { it.matches(SERVICE_1) })
 
-        postSync { socketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
+        postSync { socketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
         verify(mockInterfaceAdvertiser1).destroyNow()
     }
 
@@ -177,8 +178,8 @@
                 socketCbCaptor.capture())
 
         val socketCb = socketCbCaptor.value
-        postSync { socketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) }
-        postSync { socketCb.onSocketCreated(TEST_NETWORK_2, mockSocket2, listOf(TEST_LINKADDR)) }
+        postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR)) }
+        postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_2, mockSocket2, listOf(TEST_LINKADDR)) }
 
         val intAdvCbCaptor1 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
         val intAdvCbCaptor2 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
@@ -241,8 +242,8 @@
 
         // Callbacks for matching network and all networks both get the socket
         postSync {
-            oneNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
-            allNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
+            oneNetSocketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR))
+            allNetSocketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR))
         }
 
         val expectedRenamed = NsdServiceInfo(
@@ -294,8 +295,8 @@
         verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_2),
                 argThat { it.matches(expectedRenamed) })
 
-        postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
-        postSync { allNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
+        postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
+        postSync { allNetSocketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
 
         // destroyNow can be called multiple times
         verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow()
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
index a24664e..1a4ae5d 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
@@ -19,7 +19,6 @@
 import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
 
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.never;
@@ -28,7 +27,6 @@
 import static org.mockito.Mockito.when;
 
 import android.annotation.NonNull;
-import android.annotation.Nullable;
 import android.net.Network;
 import android.os.Handler;
 import android.os.HandlerThread;
@@ -65,17 +63,22 @@
     private static final String SERVICE_TYPE_2 = "_test._tcp.local";
     private static final Network NETWORK_1 = Mockito.mock(Network.class);
     private static final Network NETWORK_2 = Mockito.mock(Network.class);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_1_NULL_NETWORK =
-            Pair.create(SERVICE_TYPE_1, null);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_1_NETWORK_1 =
-            Pair.create(SERVICE_TYPE_1, NETWORK_1);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_2_NULL_NETWORK =
-            Pair.create(SERVICE_TYPE_2, null);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_2_NETWORK_1 =
-            Pair.create(SERVICE_TYPE_2, NETWORK_1);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_2_NETWORK_2 =
-            Pair.create(SERVICE_TYPE_2, NETWORK_2);
-
+    private static final SocketKey SOCKET_KEY_NULL_NETWORK =
+            new SocketKey(null /* network */, 999 /* interfaceIndex */);
+    private static final SocketKey SOCKET_KEY_NETWORK_1 =
+            new SocketKey(NETWORK_1, 998 /* interfaceIndex */);
+    private static final SocketKey SOCKET_KEY_NETWORK_2 =
+            new SocketKey(NETWORK_2, 997 /* interfaceIndex */);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_1_NULL_NETWORK =
+            Pair.create(SERVICE_TYPE_1, SOCKET_KEY_NULL_NETWORK);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_2_NULL_NETWORK =
+            Pair.create(SERVICE_TYPE_2, SOCKET_KEY_NULL_NETWORK);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_1_NETWORK_1 =
+            Pair.create(SERVICE_TYPE_1, SOCKET_KEY_NETWORK_1);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_2_NETWORK_1 =
+            Pair.create(SERVICE_TYPE_2, SOCKET_KEY_NETWORK_1);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_2_NETWORK_2 =
+            Pair.create(SERVICE_TYPE_2, SOCKET_KEY_NETWORK_2);
     @Mock private ExecutorProvider executorProvider;
     @Mock private MdnsSocketClientBase socketClient;
     @Mock private MdnsServiceTypeClient mockServiceTypeClientType1NullNetwork;
@@ -104,22 +107,22 @@
                 sharedLog) {
                     @Override
                     MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType,
-                            @Nullable Network network) {
-                        final Pair<String, Network> perNetworkServiceType =
-                                Pair.create(serviceType, network);
-                        if (perNetworkServiceType.equals(PER_NETWORK_SERVICE_TYPE_1_NULL_NETWORK)) {
+                            @NonNull SocketKey socketKey) {
+                        final Pair<String, SocketKey> perSocketServiceType =
+                                Pair.create(serviceType, socketKey);
+                        if (perSocketServiceType.equals(PER_SOCKET_SERVICE_TYPE_1_NULL_NETWORK)) {
                             return mockServiceTypeClientType1NullNetwork;
-                        } else if (perNetworkServiceType.equals(
-                                PER_NETWORK_SERVICE_TYPE_1_NETWORK_1)) {
+                        } else if (perSocketServiceType.equals(
+                                PER_SOCKET_SERVICE_TYPE_1_NETWORK_1)) {
                             return mockServiceTypeClientType1Network1;
-                        } else if (perNetworkServiceType.equals(
-                                PER_NETWORK_SERVICE_TYPE_2_NULL_NETWORK)) {
+                        } else if (perSocketServiceType.equals(
+                                PER_SOCKET_SERVICE_TYPE_2_NULL_NETWORK)) {
                             return mockServiceTypeClientType2NullNetwork;
-                        } else if (perNetworkServiceType.equals(
-                                PER_NETWORK_SERVICE_TYPE_2_NETWORK_1)) {
+                        } else if (perSocketServiceType.equals(
+                                PER_SOCKET_SERVICE_TYPE_2_NETWORK_1)) {
                             return mockServiceTypeClientType2Network1;
-                        } else if (perNetworkServiceType.equals(
-                                PER_NETWORK_SERVICE_TYPE_2_NETWORK_2)) {
+                        } else if (perSocketServiceType.equals(
+                                PER_SOCKET_SERVICE_TYPE_2_NETWORK_2)) {
                             return mockServiceTypeClientType2Network2;
                         }
                         return null;
@@ -156,7 +159,7 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, options);
-        runOnHandler(() -> callback.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(mockListenerOne, options);
 
         when(mockServiceTypeClientType1NullNetwork.stopSendAndReceive(mockListenerOne))
@@ -172,16 +175,16 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, options);
-        runOnHandler(() -> callback.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(mockListenerOne, options);
-        runOnHandler(() -> callback.onSocketCreated(NETWORK_1));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1Network1).startSendAndReceive(mockListenerOne, options);
 
         final SocketCreationCallback callback2 = expectSocketCreationCallback(
                 SERVICE_TYPE_2, mockListenerTwo, options);
-        runOnHandler(() -> callback2.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType2NullNetwork).startSendAndReceive(mockListenerTwo, options);
-        runOnHandler(() -> callback2.onSocketCreated(NETWORK_2));
+        runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NETWORK_2));
         verify(mockServiceTypeClientType2Network2).startSendAndReceive(mockListenerTwo, options);
     }
 
@@ -191,49 +194,48 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, options1);
-        runOnHandler(() -> callback.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(
                 mockListenerOne, options1);
-        runOnHandler(() -> callback.onSocketCreated(NETWORK_1));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1Network1).startSendAndReceive(mockListenerOne, options1);
 
         final MdnsSearchOptions options2 =
                 MdnsSearchOptions.newBuilder().setNetwork(NETWORK_2).build();
         final SocketCreationCallback callback2 = expectSocketCreationCallback(
                 SERVICE_TYPE_2, mockListenerTwo, options2);
-        runOnHandler(() -> callback2.onSocketCreated(NETWORK_2));
+        runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NETWORK_2));
         verify(mockServiceTypeClientType2Network2).startSendAndReceive(mockListenerTwo, options2);
 
         final MdnsPacket responseForServiceTypeOne = createMdnsPacket(SERVICE_TYPE_1);
-        final int ifIndex = 1;
         runOnHandler(() -> discoveryManager.onResponseReceived(
-                responseForServiceTypeOne, ifIndex, null /* network */));
+                responseForServiceTypeOne, SOCKET_KEY_NULL_NETWORK));
         // Packets for network null are only processed by the ServiceTypeClient for network null
-        verify(mockServiceTypeClientType1NullNetwork).processResponse(responseForServiceTypeOne,
-                ifIndex, null /* network */);
-        verify(mockServiceTypeClientType1Network1, never()).processResponse(any(), anyInt(), any());
-        verify(mockServiceTypeClientType2Network2, never()).processResponse(any(), anyInt(), any());
+        verify(mockServiceTypeClientType1NullNetwork).processResponse(
+                responseForServiceTypeOne, SOCKET_KEY_NULL_NETWORK);
+        verify(mockServiceTypeClientType1Network1, never()).processResponse(any(), any());
+        verify(mockServiceTypeClientType2Network2, never()).processResponse(any(), any());
 
         final MdnsPacket responseForServiceTypeTwo = createMdnsPacket(SERVICE_TYPE_2);
         runOnHandler(() -> discoveryManager.onResponseReceived(
-                responseForServiceTypeTwo, ifIndex, NETWORK_1));
-        verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(), anyInt(),
-                eq(NETWORK_1));
-        verify(mockServiceTypeClientType1Network1).processResponse(responseForServiceTypeTwo,
-                ifIndex, NETWORK_1);
-        verify(mockServiceTypeClientType2Network2, never()).processResponse(any(), anyInt(),
-                eq(NETWORK_1));
+                responseForServiceTypeTwo, SOCKET_KEY_NETWORK_1));
+        verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(),
+                eq(SOCKET_KEY_NETWORK_1));
+        verify(mockServiceTypeClientType1Network1).processResponse(
+                responseForServiceTypeTwo, SOCKET_KEY_NETWORK_1);
+        verify(mockServiceTypeClientType2Network2, never()).processResponse(any(),
+                eq(SOCKET_KEY_NETWORK_1));
 
         final MdnsPacket responseForSubtype =
                 createMdnsPacket("subtype._sub._googlecast._tcp.local");
         runOnHandler(() -> discoveryManager.onResponseReceived(
-                responseForSubtype, ifIndex, NETWORK_2));
-        verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(
-                any(), anyInt(), eq(NETWORK_2));
-        verify(mockServiceTypeClientType1Network1, never()).processResponse(
-                any(), anyInt(), eq(NETWORK_2));
+                responseForSubtype, SOCKET_KEY_NETWORK_2));
+        verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(),
+                eq(SOCKET_KEY_NETWORK_2));
+        verify(mockServiceTypeClientType1Network1, never()).processResponse(any(),
+                eq(SOCKET_KEY_NETWORK_2));
         verify(mockServiceTypeClientType2Network2).processResponse(
-                responseForSubtype, ifIndex, NETWORK_2);
+                responseForSubtype, SOCKET_KEY_NETWORK_2);
     }
 
     @Test
@@ -243,55 +245,51 @@
                 MdnsSearchOptions.newBuilder().setNetwork(NETWORK_1).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, network1Options);
-        runOnHandler(() -> callback.onSocketCreated(NETWORK_1));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1Network1).startSendAndReceive(
                 mockListenerOne, network1Options);
 
         // Create a ServiceTypeClient for SERVICE_TYPE_2 and NETWORK_1
         final SocketCreationCallback callback2 = expectSocketCreationCallback(
                 SERVICE_TYPE_2, mockListenerTwo, network1Options);
-        runOnHandler(() -> callback2.onSocketCreated(NETWORK_1));
+        runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType2Network1).startSendAndReceive(
                 mockListenerTwo, network1Options);
 
         // Receive a response, it should be processed on both clients.
         final MdnsPacket response = createMdnsPacket(SERVICE_TYPE_1);
-        final int ifIndex = 1;
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, NETWORK_1));
-        verify(mockServiceTypeClientType1Network1).processResponse(response, ifIndex, NETWORK_1);
-        verify(mockServiceTypeClientType2Network1).processResponse(response, ifIndex, NETWORK_1);
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1));
+        verify(mockServiceTypeClientType1Network1).processResponse(response, SOCKET_KEY_NETWORK_1);
+        verify(mockServiceTypeClientType2Network1).processResponse(response, SOCKET_KEY_NETWORK_1);
 
         // The first callback receives a notification that the network has been destroyed,
         // mockServiceTypeClientOne1 should send service removed notifications and remove from the
         // list of clients.
-        runOnHandler(() -> callback.onAllSocketsDestroyed(NETWORK_1));
+        runOnHandler(() -> callback.onSocketDestroyed(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1Network1).notifySocketDestroyed();
 
         // Receive a response again, it should be processed only on
         // mockServiceTypeClientType2Network1. Because the mockServiceTypeClientType1Network1 is
         // removed from the list of clients, it is no longer able to process responses.
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, NETWORK_1));
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1));
         // Still times(1) as a response was received once previously
-        verify(mockServiceTypeClientType1Network1, times(1))
-                .processResponse(response, ifIndex, NETWORK_1);
-        verify(mockServiceTypeClientType2Network1, times(2))
-                .processResponse(response, ifIndex, NETWORK_1);
+        verify(mockServiceTypeClientType1Network1, times(1)).processResponse(
+                response, SOCKET_KEY_NETWORK_1);
+        verify(mockServiceTypeClientType2Network1, times(2)).processResponse(
+                response, SOCKET_KEY_NETWORK_1);
 
         // The client for NETWORK_1 receives the callback that the NETWORK_2 has been destroyed,
         // mockServiceTypeClientTwo2 shouldn't send any notifications.
-        runOnHandler(() -> callback2.onAllSocketsDestroyed(NETWORK_2));
+        runOnHandler(() -> callback2.onSocketDestroyed(SOCKET_KEY_NETWORK_2));
         verify(mockServiceTypeClientType2Network1, never()).notifySocketDestroyed();
 
         // Receive a response again, mockServiceTypeClientType2Network1 is still in the list of
         // clients, it's still able to process responses.
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, NETWORK_1));
-        verify(mockServiceTypeClientType1Network1, times(1))
-                .processResponse(response, ifIndex, NETWORK_1);
-        verify(mockServiceTypeClientType2Network1, times(3))
-                .processResponse(response, ifIndex, NETWORK_1);
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1));
+        verify(mockServiceTypeClientType1Network1, times(1)).processResponse(
+                response, SOCKET_KEY_NETWORK_1);
+        verify(mockServiceTypeClientType2Network1, times(3)).processResponse(
+                response, SOCKET_KEY_NETWORK_1);
     }
 
     @Test
@@ -301,27 +299,25 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, network1Options);
-        runOnHandler(() -> callback.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(
                 mockListenerOne, network1Options);
 
         // Receive a response, it should be processed on the client.
         final MdnsPacket response = createMdnsPacket(SERVICE_TYPE_1);
         final int ifIndex = 1;
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, null /* network */));
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).processResponse(
-                response, ifIndex, null /* network */);
+                response, SOCKET_KEY_NULL_NETWORK);
 
-        runOnHandler(() -> callback.onAllSocketsDestroyed(null /* network */));
+        runOnHandler(() -> callback.onSocketDestroyed(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).notifySocketDestroyed();
 
         // Receive a response again, it should not be processed.
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, null /* network */));
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NULL_NETWORK));
         // Still times(1) as a response was received once previously
-        verify(mockServiceTypeClientType1NullNetwork, times(1))
-                .processResponse(response, ifIndex, null /* network */);
+        verify(mockServiceTypeClientType1NullNetwork, times(1)).processResponse(
+                response, SOCKET_KEY_NULL_NETWORK);
 
         // Unregister the listener, notifyNetworkUnrequested should be called but other stop methods
         // won't be call because the service type client was unregistered and destroyed. But those
@@ -330,7 +326,7 @@
         verify(socketClient).notifyNetworkUnrequested(mockListenerOne);
         verify(mockServiceTypeClientType1NullNetwork, never()).stopSendAndReceive(any());
         // The stopDiscovery() is only used by MdnsSocketClient, which doesn't send
-        // onAllSocketsDestroyed(). So the socket clients that send onAllSocketsDestroyed() do not
+        // onSocketDestroyed(). So the socket clients that send onSocketDestroyed() do not
         // need to call stopDiscovery().
         verify(socketClient, never()).stopDiscovery();
     }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
index 87ba5d7..29de272 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
@@ -21,7 +21,6 @@
 
 import static org.junit.Assert.assertEquals;
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
@@ -29,7 +28,6 @@
 import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyNoMoreInteractions;
 
 import android.net.InetAddresses;
 import android.net.Network;
@@ -70,13 +68,16 @@
     @Mock private SocketCreationCallback mSocketCreationCallback;
     private MdnsMultinetworkSocketClient mSocketClient;
     private Handler mHandler;
+    private SocketKey mSocketKey;
 
     @Before
     public void setUp() throws SocketException {
         MockitoAnnotations.initMocks(this);
+
         final HandlerThread thread = new HandlerThread("MdnsMultinetworkSocketClientTest");
         thread.start();
         mHandler = new Handler(thread.getLooper());
+        mSocketKey = new SocketKey(1000 /* interfaceIndex */);
         mSocketClient = new MdnsMultinetworkSocketClient(thread.getLooper(), mProvider);
         mHandler.post(() -> mSocketClient.setCallback(mCallback));
     }
@@ -123,27 +124,49 @@
             doReturn(createEmptyNetworkInterface()).when(socket).getInterface();
         }
 
+        final SocketKey tetherSocketKey1 = new SocketKey(1001 /* interfaceIndex */);
+        final SocketKey tetherSocketKey2 = new SocketKey(1002 /* interfaceIndex */);
         // Notify socket created
-        callback.onSocketCreated(mNetwork, mSocket, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
-        callback.onSocketCreated(null, tetherIfaceSock1, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(null);
-        callback.onSocketCreated(null, tetherIfaceSock2, List.of());
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
+        callback.onSocketCreated(mSocketKey, mSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
+        callback.onSocketCreated(tetherSocketKey1, tetherIfaceSock1, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(tetherSocketKey1);
+        callback.onSocketCreated(tetherSocketKey2, tetherIfaceSock2, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(tetherSocketKey2);
 
-        // Send packet to IPv4 with target network and verify sending has been called.
-        mSocketClient.sendMulticastPacket(ipv4Packet, mNetwork);
+        // Send packet to IPv4 with mSocketKey and verify sending has been called.
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket).send(ipv4Packet);
         verify(tetherIfaceSock1, never()).send(any());
         verify(tetherIfaceSock2, never()).send(any());
 
-        // Send packet to IPv6 without target network and verify sending has been called.
-        mSocketClient.sendMulticastPacket(ipv6Packet, null);
+        // Send packet to IPv4 with onlyUseIpv6OnIpv6OnlyNetworks = true, the packet will be sent.
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
+                true /* onlyUseIpv6OnIpv6OnlyNetworks */);
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        verify(mSocket, times(2)).send(ipv4Packet);
+        verify(tetherIfaceSock1, never()).send(any());
+        verify(tetherIfaceSock2, never()).send(any());
+
+        // Send packet to IPv6 with tetherSocketKey1 and verify sending has been called.
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv6Packet, tetherSocketKey1,
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket, never()).send(ipv6Packet);
         verify(tetherIfaceSock1).send(ipv6Packet);
-        verify(tetherIfaceSock2).send(ipv6Packet);
+        verify(tetherIfaceSock2, never()).send(ipv6Packet);
+
+        // Send packet to IPv6 with onlyUseIpv6OnIpv6OnlyNetworks = true, the packet will not be
+        // sent. Therefore, tetherIfaceSock1.send() is still called only once and
+        // tetherIfaceSock2.send() is still never called.
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv6Packet, tetherSocketKey1,
+                true /* onlyUseIpv6OnIpv6OnlyNetworks */);
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        verify(mSocket, never()).send(ipv6Packet);
+        verify(tetherIfaceSock1, times(1)).send(ipv6Packet);
+        verify(tetherIfaceSock2, never()).send(ipv6Packet);
     }
 
     @Test
@@ -164,8 +187,8 @@
 
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
         // Notify socket created
-        callback.onSocketCreated(mNetwork, mSocket, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
+        callback.onSocketCreated(mSocketKey, mSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
 
         final ArgumentCaptor<PacketHandler> handlerCaptor =
                 ArgumentCaptor.forClass(PacketHandler.class);
@@ -176,7 +199,7 @@
         handler.handlePacket(data, data.length, null /* src */);
         final ArgumentCaptor<MdnsPacket> responseCaptor =
                 ArgumentCaptor.forClass(MdnsPacket.class);
-        verify(mCallback).onResponseReceived(responseCaptor.capture(), anyInt(), any());
+        verify(mCallback).onResponseReceived(responseCaptor.capture(), any());
         final MdnsPacket response = responseCaptor.getValue();
         assertEquals(0, response.questions.size());
         assertEquals(0, response.additionalRecords.size());
@@ -214,14 +237,18 @@
         doReturn(createEmptyNetworkInterface()).when(socket2).getInterface();
         doReturn(createEmptyNetworkInterface()).when(socket3).getInterface();
 
-        callback.onSocketCreated(mNetwork, mSocket, List.of());
-        callback.onSocketCreated(null, socket2, List.of());
-        callback.onSocketCreated(null, socket3, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
+        final SocketKey socketKey2 = new SocketKey(1001 /* interfaceIndex */);
+        final SocketKey socketKey3 = new SocketKey(1002 /* interfaceIndex */);
+        callback.onSocketCreated(mSocketKey, mSocket, List.of());
+        callback.onSocketCreated(socketKey2, socket2, List.of());
+        callback.onSocketCreated(socketKey3, socket3, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
+        verify(mSocketCreationCallback).onSocketCreated(socketKey2);
+        verify(mSocketCreationCallback).onSocketCreated(socketKey3);
 
-        // Send IPv4 packet on the non-null Network and verify sending has been called.
-        mSocketClient.sendMulticastPacket(ipv4Packet, mNetwork);
+        // Send IPv4 packet with mSocketKey and verify sending has been called.
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket).send(ipv4Packet);
         verify(socket2, never()).send(any());
@@ -241,80 +268,88 @@
         final SocketCallback callback2 = callback2Captor.getAllValues().get(1);
 
         // Notify socket created for all networks.
-        callback2.onSocketCreated(mNetwork, mSocket, List.of());
-        callback2.onSocketCreated(null, socket2, List.of());
-        callback2.onSocketCreated(null, socket3, List.of());
-        verify(socketCreationCb2).onSocketCreated(mNetwork);
-        verify(socketCreationCb2, times(2)).onSocketCreated(null);
+        callback2.onSocketCreated(mSocketKey, mSocket, List.of());
+        callback2.onSocketCreated(socketKey2, socket2, List.of());
+        callback2.onSocketCreated(socketKey3, socket3, List.of());
+        verify(socketCreationCb2).onSocketCreated(mSocketKey);
+        verify(socketCreationCb2).onSocketCreated(socketKey2);
+        verify(socketCreationCb2).onSocketCreated(socketKey3);
 
-        // Send IPv4 packet to null network and verify sending to the 2 tethered interface sockets.
-        mSocketClient.sendMulticastPacket(ipv4Packet, null);
+        // Send IPv4 packet with socketKey2 and verify it is sent on socket2 only.
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, socketKey2,
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         // ipv4Packet still sent only once on mSocket: times(1) matches the packet sent earlier on
         // mNetwork
         verify(mSocket, times(1)).send(ipv4Packet);
         verify(socket2).send(ipv4Packet);
-        verify(socket3).send(ipv4Packet);
+        verify(socket3, never()).send(ipv4Packet);
 
         // Unregister the second request
         mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(listener2));
         verify(mProvider, timeout(DEFAULT_TIMEOUT)).unrequestSocket(callback2);
 
         // Send IPv4 packet again and verify it's still sent a second time
-        mSocketClient.sendMulticastPacket(ipv4Packet, null);
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, socketKey2,
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(socket2, times(2)).send(ipv4Packet);
-        verify(socket3, times(2)).send(ipv4Packet);
+        verify(socket3, never()).send(ipv4Packet);
 
         // Unrequest remaining sockets
         mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener));
         verify(mProvider, timeout(DEFAULT_TIMEOUT)).unrequestSocket(callback);
 
         // Send IPv4 packet and verify no more sending.
-        mSocketClient.sendMulticastPacket(ipv4Packet, null);
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket, times(1)).send(ipv4Packet);
         verify(socket2, times(2)).send(ipv4Packet);
-        verify(socket3, times(2)).send(ipv4Packet);
+        verify(socket3, never()).send(ipv4Packet);
     }
 
     @Test
     public void testNotifyNetworkUnrequested_SocketsOnNullNetwork() {
         final MdnsInterfaceSocket otherSocket = mock(MdnsInterfaceSocket.class);
+        final SocketKey otherSocketKey = new SocketKey(1001 /* interfaceIndex */);
         final SocketCallback callback = expectSocketCallback(
                 mListener, null /* requestedNetwork */);
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
         doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
 
-        callback.onSocketCreated(null /* network */, mSocket, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(null);
-        callback.onSocketCreated(null /* network */, otherSocket, List.of());
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
+        callback.onSocketCreated(mSocketKey, mSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
+        callback.onSocketCreated(otherSocketKey, otherSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(otherSocketKey);
 
-        verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(null /* network */);
+        verify(mSocketCreationCallback, never()).onSocketDestroyed(mSocketKey);
+        verify(mSocketCreationCallback, never()).onSocketDestroyed(otherSocketKey);
         mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener));
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
 
         verify(mProvider).unrequestSocket(callback);
-        verify(mSocketCreationCallback).onAllSocketsDestroyed(null /* network */);
+        verify(mSocketCreationCallback).onSocketDestroyed(mSocketKey);
+        verify(mSocketCreationCallback).onSocketDestroyed(otherSocketKey);
     }
 
     @Test
     public void testSocketCreatedAndDestroyed_NullNetwork() throws IOException {
         final MdnsInterfaceSocket otherSocket = mock(MdnsInterfaceSocket.class);
+        final SocketKey otherSocketKey = new SocketKey(1001 /* interfaceIndex */);
         final SocketCallback callback = expectSocketCallback(mListener, null /* network */);
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
         doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
 
-        callback.onSocketCreated(null /* network */, mSocket, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(null);
-        callback.onSocketCreated(null /* network */, otherSocket, List.of());
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
+        callback.onSocketCreated(mSocketKey, mSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
+        callback.onSocketCreated(otherSocketKey, otherSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(otherSocketKey);
 
         // Notify socket destroyed
-        callback.onInterfaceDestroyed(null /* network */, mSocket);
-        verifyNoMoreInteractions(mSocketCreationCallback);
-        callback.onInterfaceDestroyed(null /* network */, otherSocket);
-        verify(mSocketCreationCallback).onAllSocketsDestroyed(null /* network */);
+        callback.onInterfaceDestroyed(mSocketKey, mSocket);
+        verify(mSocketCreationCallback).onSocketDestroyed(mSocketKey);
+        callback.onInterfaceDestroyed(otherSocketKey, otherSocket);
+        verify(mSocketCreationCallback).onSocketDestroyed(otherSocketKey);
     }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index d1adecf..03e893f 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -118,6 +118,7 @@
     private FakeExecutor currentThreadExecutor = new FakeExecutor();
 
     private MdnsServiceTypeClient client;
+    private SocketKey socketKey;
 
     @Before
     @SuppressWarnings("DoNotMock")
@@ -128,6 +129,7 @@
         expectedIPv4Packets = new DatagramPacket[16];
         expectedIPv6Packets = new DatagramPacket[16];
         expectedSendFutures = new ScheduledFuture<?>[16];
+        socketKey = new SocketKey(mockNetwork, INTERFACE_INDEX);
 
         for (int i = 0; i < expectedSendFutures.length; ++i) {
             expectedIPv4Packets[i] = new DatagramPacket(buf, 0 /* offset */, 5 /* length */,
@@ -174,7 +176,7 @@
 
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, mockNetwork, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -325,7 +327,8 @@
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build();
         QueryTaskConfig config = new QueryTaskConfig(
-                searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, mockNetwork);
+                searchOptions.getSubtypes(), searchOptions.isPassiveMode(),
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */, 1, socketKey);
 
         // This is the first query. We will ask for unicast response.
         assertTrue(config.expectUnicastResponse);
@@ -354,7 +357,8 @@
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build();
         QueryTaskConfig config = new QueryTaskConfig(
-                searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, mockNetwork);
+                searchOptions.getSubtypes(), searchOptions.isPassiveMode(),
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */, 1, socketKey);
 
         // This is the first query. We will ask for unicast response.
         assertTrue(config.expectUnicastResponse);
@@ -434,7 +438,7 @@
         client.processResponse(createResponse(
                 "service-instance-1", "192.0.2.123", 5353,
                 SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), TEST_TTL), /* interfaceIndex= */ 20, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         verify(mockListenerOne).onServiceNameDiscovered(any());
         verify(mockListenerOne).onServiceFound(any());
@@ -455,8 +459,7 @@
 
     private static void verifyServiceInfo(MdnsServiceInfo serviceInfo, String serviceName,
             String[] serviceType, List<String> ipv4Addresses, List<String> ipv6Addresses, int port,
-            List<String> subTypes, Map<String, String> attributes, int interfaceIndex,
-            Network network) {
+            List<String> subTypes, Map<String, String> attributes, SocketKey socketKey) {
         assertEquals(serviceName, serviceInfo.getServiceInstanceName());
         assertArrayEquals(serviceType, serviceInfo.getServiceType());
         assertEquals(ipv4Addresses, serviceInfo.getIpv4Addresses());
@@ -467,8 +470,8 @@
             assertTrue(attributes.containsKey(key));
             assertEquals(attributes.get(key), serviceInfo.getAttributeByKey(key));
         }
-        assertEquals(interfaceIndex, serviceInfo.getInterfaceIndex());
-        assertEquals(network, serviceInfo.getNetwork());
+        assertEquals(socketKey.getInterfaceIndex(), serviceInfo.getInterfaceIndex());
+        assertEquals(socketKey.getNetwork(), serviceInfo.getNetwork());
     }
 
     @Test
@@ -478,7 +481,7 @@
         client.processResponse(createResponse(
                 "service-instance-1", null /* host */, 0 /* port */,
                 SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
         verify(mockListenerOne).onServiceNameDiscovered(serviceInfoCaptor.capture());
         verifyServiceInfo(serviceInfoCaptor.getAllValues().get(0),
                 "service-instance-1",
@@ -488,8 +491,7 @@
                 /* port= */ 0,
                 /* subTypes= */ List.of(),
                 Collections.emptyMap(),
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
 
         verify(mockListenerOne, never()).onServiceFound(any(MdnsServiceInfo.class));
         verify(mockListenerOne, never()).onServiceUpdated(any(MdnsServiceInfo.class));
@@ -504,14 +506,14 @@
         client.processResponse(createResponse(
                 "service-instance-1", ipV4Address, 5353,
                 /* subtype= */ "ABCDE",
-                Collections.emptyMap(), TEST_TTL), /* interfaceIndex= */ 20, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Process a second response with a different port and updated text attributes.
         client.processResponse(createResponse(
-                "service-instance-1", ipV4Address, 5354,
-                /* subtype= */ "ABCDE",
-                Collections.singletonMap("key", "value"), TEST_TTL),
-                /* interfaceIndex= */ 20, mockNetwork);
+                        "service-instance-1", ipV4Address, 5354,
+                        /* subtype= */ "ABCDE",
+                        Collections.singletonMap("key", "value"), TEST_TTL),
+                socketKey);
 
         // Verify onServiceNameDiscovered was called once for the initial response.
         verify(mockListenerOne).onServiceNameDiscovered(serviceInfoCaptor.capture());
@@ -523,8 +525,7 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                20 /* interfaceIndex */,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceFound was called once for the initial response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -534,8 +535,8 @@
         assertEquals(initialServiceInfo.getPort(), 5353);
         assertEquals(initialServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertNull(initialServiceInfo.getAttributeByKey("key"));
-        assertEquals(initialServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(mockNetwork, initialServiceInfo.getNetwork());
+        assertEquals(socketKey.getInterfaceIndex(), initialServiceInfo.getInterfaceIndex());
+        assertEquals(socketKey.getNetwork(), initialServiceInfo.getNetwork());
 
         // Verify onServiceUpdated was called once for the second response.
         verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -546,8 +547,8 @@
         assertTrue(updatedServiceInfo.hasSubtypes());
         assertEquals(updatedServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertEquals(updatedServiceInfo.getAttributeByKey("key"), "value");
-        assertEquals(updatedServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(mockNetwork, updatedServiceInfo.getNetwork());
+        assertEquals(socketKey.getInterfaceIndex(), updatedServiceInfo.getInterfaceIndex());
+        assertEquals(socketKey.getNetwork(), updatedServiceInfo.getNetwork());
     }
 
     @Test
@@ -559,14 +560,14 @@
         client.processResponse(createResponse(
                 "service-instance-1", ipV6Address, 5353,
                 /* subtype= */ "ABCDE",
-                Collections.emptyMap(), TEST_TTL), /* interfaceIndex= */ 20, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Process a second response with a different port and updated text attributes.
         client.processResponse(createResponse(
-                "service-instance-1", ipV6Address, 5354,
-                /* subtype= */ "ABCDE",
-                Collections.singletonMap("key", "value"), TEST_TTL),
-                /* interfaceIndex= */ 20, mockNetwork);
+                        "service-instance-1", ipV6Address, 5354,
+                        /* subtype= */ "ABCDE",
+                        Collections.singletonMap("key", "value"), TEST_TTL),
+                socketKey);
 
         // Verify onServiceNameDiscovered was called once for the initial response.
         verify(mockListenerOne).onServiceNameDiscovered(serviceInfoCaptor.capture());
@@ -578,8 +579,7 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                20 /* interfaceIndex */,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceFound was called once for the initial response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -589,8 +589,8 @@
         assertEquals(initialServiceInfo.getPort(), 5353);
         assertEquals(initialServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertNull(initialServiceInfo.getAttributeByKey("key"));
-        assertEquals(initialServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(mockNetwork, initialServiceInfo.getNetwork());
+        assertEquals(socketKey.getInterfaceIndex(), initialServiceInfo.getInterfaceIndex());
+        assertEquals(socketKey.getNetwork(), initialServiceInfo.getNetwork());
 
         // Verify onServiceUpdated was called once for the second response.
         verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -601,8 +601,8 @@
         assertTrue(updatedServiceInfo.hasSubtypes());
         assertEquals(updatedServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertEquals(updatedServiceInfo.getAttributeByKey("key"), "value");
-        assertEquals(updatedServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(mockNetwork, updatedServiceInfo.getNetwork());
+        assertEquals(socketKey.getInterfaceIndex(), updatedServiceInfo.getInterfaceIndex());
+        assertEquals(socketKey.getNetwork(), updatedServiceInfo.getNetwork());
     }
 
     private void verifyServiceRemovedNoCallback(MdnsServiceBrowserListener listener) {
@@ -611,17 +611,17 @@
     }
 
     private void verifyServiceRemovedCallback(MdnsServiceBrowserListener listener,
-            String serviceName, String[] serviceType, int interfaceIndex, Network network) {
+            String serviceName, String[] serviceType, SocketKey socketKey) {
         verify(listener).onServiceRemoved(argThat(
                 info -> serviceName.equals(info.getServiceInstanceName())
                         && Arrays.equals(serviceType, info.getServiceType())
-                        && info.getInterfaceIndex() == interfaceIndex
-                        && network.equals(info.getNetwork())));
+                        && info.getInterfaceIndex() == socketKey.getInterfaceIndex()
+                        && socketKey.getNetwork().equals(info.getNetwork())));
         verify(listener).onServiceNameRemoved(argThat(
                 info -> serviceName.equals(info.getServiceInstanceName())
                         && Arrays.equals(serviceType, info.getServiceType())
-                        && info.getInterfaceIndex() == interfaceIndex
-                        && network.equals(info.getNetwork())));
+                        && info.getInterfaceIndex() == socketKey.getInterfaceIndex()
+                        && socketKey.getNetwork().equals(info.getNetwork())));
     }
 
     @Test
@@ -635,12 +635,12 @@
         client.processResponse(createResponse(
                 serviceName, ipV6Address, 5353,
                 SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         client.processResponse(createResponse(
                 "goodbye-service", ipV6Address, 5353,
                 SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), /* ptrTtlMillis= */ 0L), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), /* ptrTtlMillis= */ 0L), socketKey);
 
         // Verify removed callback won't be called if the service is not existed.
         verifyServiceRemovedNoCallback(mockListenerOne);
@@ -650,11 +650,11 @@
         client.processResponse(createResponse(
                 serviceName, ipV6Address, 5353,
                 SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), 0L), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), 0L), socketKey);
         verifyServiceRemovedCallback(
-                mockListenerOne, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX, mockNetwork);
+                mockListenerOne, serviceName, SERVICE_TYPE_LABELS, socketKey);
         verifyServiceRemovedCallback(
-                mockListenerTwo, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX, mockNetwork);
+                mockListenerTwo, serviceName, SERVICE_TYPE_LABELS, socketKey);
     }
 
     @Test
@@ -663,7 +663,7 @@
         client.processResponse(createResponse(
                 "service-instance-1", "192.168.1.1", 5353,
                 /* subtype= */ "ABCDE",
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         client.startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
 
@@ -677,8 +677,7 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceFound was called once for the existing response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -693,7 +692,7 @@
         client.processResponse(createResponse(
                 "service-instance-1", "192.168.1.1", 5353,
                 SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), /* ptrTtlMillis= */ 0L), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), /* ptrTtlMillis= */ 0L), socketKey);
 
         client.startSendAndReceive(mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
 
@@ -709,7 +708,7 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, mockNetwork, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -723,7 +722,7 @@
         // Process the initial response.
         client.processResponse(createResponse(
                 serviceInstanceName, "192.168.1.1", 5353, /* subtype= */ "ABCDE",
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Clear the scheduled runnable.
         currentThreadExecutor.getAndClearLastScheduledRunnable();
@@ -740,8 +739,8 @@
         firstMdnsTask.run();
 
         // Verify removed callback was called.
-        verifyServiceRemovedCallback(mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS,
-                INTERFACE_INDEX, mockNetwork);
+        verifyServiceRemovedCallback(
+                mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS, socketKey);
     }
 
     @Test
@@ -750,7 +749,7 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, mockNetwork, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -762,7 +761,7 @@
         // Process the initial response.
         client.processResponse(createResponse(
                 serviceInstanceName, "192.168.1.1", 5353, /* subtype= */ "ABCDE",
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Clear the scheduled runnable.
         currentThreadExecutor.getAndClearLastScheduledRunnable();
@@ -783,7 +782,7 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, mockNetwork, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -795,7 +794,7 @@
         // Process the initial response.
         client.processResponse(createResponse(
                 serviceInstanceName, "192.168.1.1", 5353, /* subtype= */ "ABCDE",
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Clear the scheduled runnable.
         currentThreadExecutor.getAndClearLastScheduledRunnable();
@@ -805,8 +804,8 @@
         firstMdnsTask.run();
 
         // Verify removed callback was called.
-        verifyServiceRemovedCallback(mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS,
-                INTERFACE_INDEX, mockNetwork);
+        verifyServiceRemovedCallback(
+                mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS, socketKey);
     }
 
     @Test
@@ -821,23 +820,23 @@
         final String subtype = "ABCDE";
         client.processResponse(createResponse(
                 serviceName, null, 5353, subtype,
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Process a second response which has ip address to make response become complete.
         client.processResponse(createResponse(
                 serviceName, ipV4Address, 5353, subtype,
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Process a third response with a different ip address, port and updated text attributes.
         client.processResponse(createResponse(
                 serviceName, ipV6Address, 5354, subtype,
-                Collections.singletonMap("key", "value"), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.singletonMap("key", "value"), TEST_TTL), socketKey);
 
         // Process the last response which is goodbye message (with the main type, not subtype).
         client.processResponse(createResponse(
-                serviceName, ipV6Address, 5354, SERVICE_TYPE_LABELS,
-                Collections.singletonMap("key", "value"), /* ptrTtlMillis= */ 0L),
-                INTERFACE_INDEX, mockNetwork);
+                        serviceName, ipV6Address, 5354, SERVICE_TYPE_LABELS,
+                        Collections.singletonMap("key", "value"), /* ptrTtlMillis= */ 0L),
+                socketKey);
 
         // Verify onServiceNameDiscovered was first called for the initial response.
         inOrder.verify(mockListenerOne).onServiceNameDiscovered(serviceInfoCaptor.capture());
@@ -849,8 +848,7 @@
                 5353 /* port */,
                 Collections.singletonList(subtype) /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceFound was second called for the second response.
         inOrder.verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -862,8 +860,7 @@
                 5353 /* port */,
                 Collections.singletonList(subtype) /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceUpdated was third called for the third response.
         inOrder.verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -875,8 +872,7 @@
                 5354 /* port */,
                 Collections.singletonList(subtype) /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceRemoved was called for the last response.
         inOrder.verify(mockListenerOne).onServiceRemoved(serviceInfoCaptor.capture());
@@ -888,8 +884,7 @@
                 5354 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceNameRemoved was called for the last response.
         inOrder.verify(mockListenerOne).onServiceNameRemoved(serviceInfoCaptor.capture());
@@ -901,14 +896,13 @@
                 5354 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
     }
 
     @Test
     public void testProcessResponse_Resolve() throws Exception {
         client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);
+                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
 
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
@@ -926,16 +920,16 @@
                 ArgumentCaptor.forClass(DatagramPacket.class);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         // Send twice for IPv4 and IPv6
-        inOrder.verify(mockSocketClient, times(2)).sendUnicastPacket(srvTxtQueryCaptor.capture(),
-                eq(mockNetwork));
+        inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
+                srvTxtQueryCaptor.capture(),
+                eq(socketKey), eq(false));
 
         final MdnsPacket srvTxtQueryPacket = MdnsPacket.parse(
                 new MdnsPacketReader(srvTxtQueryCaptor.getValue()));
 
         final String[] serviceName = getTestServiceName(instanceName);
         assertFalse(hasQuestion(srvTxtQueryPacket, MdnsRecord.TYPE_PTR));
-        assertTrue(hasQuestion(srvTxtQueryPacket, MdnsRecord.TYPE_SRV, serviceName));
-        assertTrue(hasQuestion(srvTxtQueryPacket, MdnsRecord.TYPE_TXT, serviceName));
+        assertTrue(hasQuestion(srvTxtQueryPacket, MdnsRecord.TYPE_ANY, serviceName));
 
         // Process a response with SRV+TXT
         final MdnsPacket srvTxtResponse = new MdnsPacket(
@@ -951,14 +945,15 @@
                 Collections.emptyList() /* authorityRecords */,
                 Collections.emptyList() /* additionalRecords */);
 
-        client.processResponse(srvTxtResponse, INTERFACE_INDEX, mockNetwork);
+        client.processResponse(srvTxtResponse, socketKey);
 
         // Expect a query for A/AAAA
         final ArgumentCaptor<DatagramPacket> addressQueryCaptor =
                 ArgumentCaptor.forClass(DatagramPacket.class);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
-        inOrder.verify(mockSocketClient, times(2)).sendMulticastPacket(addressQueryCaptor.capture(),
-                eq(mockNetwork));
+        inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
+                addressQueryCaptor.capture(),
+                eq(socketKey), eq(false));
 
         final MdnsPacket addressQueryPacket = MdnsPacket.parse(
                 new MdnsPacketReader(addressQueryCaptor.getValue()));
@@ -980,7 +975,7 @@
                 Collections.emptyList() /* additionalRecords */);
 
         inOrder.verify(mockListenerOne, never()).onServiceNameDiscovered(any());
-        client.processResponse(addressResponse, INTERFACE_INDEX, mockNetwork);
+        client.processResponse(addressResponse, socketKey);
 
         inOrder.verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
         verifyServiceInfo(serviceInfoCaptor.getValue(),
@@ -991,14 +986,13 @@
                 1234 /* port */,
                 Collections.emptyList() /* subTypes */,
                 Collections.emptyMap() /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
     }
 
     @Test
     public void testRenewTxtSrvInResolve() throws Exception {
         client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                mockDecoderClock, mockNetwork, mockSharedLog);
+                mockDecoderClock, socketKey, mockSharedLog);
 
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
@@ -1016,15 +1010,15 @@
                 ArgumentCaptor.forClass(DatagramPacket.class);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         // Send twice for IPv4 and IPv6
-        inOrder.verify(mockSocketClient, times(2)).sendUnicastPacket(srvTxtQueryCaptor.capture(),
-                eq(mockNetwork));
+        inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
+                srvTxtQueryCaptor.capture(),
+                eq(socketKey), eq(false));
 
         final MdnsPacket srvTxtQueryPacket = MdnsPacket.parse(
                 new MdnsPacketReader(srvTxtQueryCaptor.getValue()));
 
         final String[] serviceName = getTestServiceName(instanceName);
-        assertTrue(hasQuestion(srvTxtQueryPacket, MdnsRecord.TYPE_SRV, serviceName));
-        assertTrue(hasQuestion(srvTxtQueryPacket, MdnsRecord.TYPE_TXT, serviceName));
+        assertTrue(hasQuestion(srvTxtQueryPacket, MdnsRecord.TYPE_ANY, serviceName));
 
         // Process a response with all records
         final MdnsPacket srvTxtResponse = new MdnsPacket(
@@ -1045,7 +1039,7 @@
                                 InetAddresses.parseNumericAddress(ipV6Address))),
                 Collections.emptyList() /* authorityRecords */,
                 Collections.emptyList() /* additionalRecords */);
-        client.processResponse(srvTxtResponse, INTERFACE_INDEX, mockNetwork);
+        client.processResponse(srvTxtResponse, socketKey);
         inOrder.verify(mockListenerOne).onServiceNameDiscovered(any());
         inOrder.verify(mockListenerOne).onServiceFound(any());
 
@@ -1062,13 +1056,13 @@
         final ArgumentCaptor<DatagramPacket> renewalQueryCaptor =
                 ArgumentCaptor.forClass(DatagramPacket.class);
         // Second and later sends are sent as "expect multicast response" queries
-        inOrder.verify(mockSocketClient, times(2)).sendMulticastPacket(renewalQueryCaptor.capture(),
-                eq(mockNetwork));
+        inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
+                renewalQueryCaptor.capture(),
+                eq(socketKey), eq(false));
         inOrder.verify(mockListenerOne).onDiscoveryQuerySent(any(), anyInt());
         final MdnsPacket renewalPacket = MdnsPacket.parse(
                 new MdnsPacketReader(renewalQueryCaptor.getValue()));
-        assertTrue(hasQuestion(renewalPacket, MdnsRecord.TYPE_SRV, serviceName));
-        assertTrue(hasQuestion(renewalPacket, MdnsRecord.TYPE_TXT, serviceName));
+        assertTrue(hasQuestion(renewalPacket, MdnsRecord.TYPE_ANY, serviceName));
         inOrder.verifyNoMoreInteractions();
 
         long updatedReceiptTime =  TEST_ELAPSED_REALTIME + TEST_TTL;
@@ -1090,7 +1084,7 @@
                                 InetAddresses.parseNumericAddress(ipV6Address))),
                 Collections.emptyList() /* authorityRecords */,
                 Collections.emptyList() /* additionalRecords */);
-        client.processResponse(refreshedSrvTxtResponse, INTERFACE_INDEX, mockNetwork);
+        client.processResponse(refreshedSrvTxtResponse, socketKey);
 
         // Advance time to updatedReceiptTime + 1, expected no refresh query because the cache
         // should contain the record that have update last receipt time.
@@ -1102,7 +1096,7 @@
     @Test
     public void testProcessResponse_ResolveExcludesOtherServices() {
         client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);
+                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
 
         final String requestedInstance = "instance1";
         final String otherInstance = "instance2";
@@ -1119,25 +1113,25 @@
 
         // Complete response from instanceName
         client.processResponse(createResponse(
-                requestedInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
+                        requestedInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
-                INTERFACE_INDEX, mockNetwork);
+                socketKey);
 
         // Complete response from otherInstanceName
         client.processResponse(createResponse(
-                otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
+                        otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
-                INTERFACE_INDEX, mockNetwork);
+                socketKey);
 
         // Address update from otherInstanceName
         client.processResponse(createResponse(
                 otherInstance, ipV6Address, 5353, SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Goodbye from otherInstanceName
         client.processResponse(createResponse(
                 otherInstance, ipV6Address, 5353, SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), 0L /* ttl */), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), 0L /* ttl */), socketKey);
 
         // mockListenerOne gets notified for the requested instance
         verify(mockListenerOne).onServiceNameDiscovered(
@@ -1166,7 +1160,7 @@
     @Test
     public void testProcessResponse_SubtypeDiscoveryLimitedToSubtype() {
         client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);
+                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
 
         final String matchingInstance = "instance1";
         final String subtype = "_subtype";
@@ -1202,23 +1196,23 @@
                 newAnswers,
                 packetWithoutSubtype.authorityRecords,
                 packetWithoutSubtype.additionalRecords);
-        client.processResponse(packetWithSubtype, INTERFACE_INDEX, mockNetwork);
+        client.processResponse(packetWithSubtype, socketKey);
 
         // Complete response from otherInstanceName, without subtype
         client.processResponse(createResponse(
                         otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
-                INTERFACE_INDEX, mockNetwork);
+                socketKey);
 
         // Address update from otherInstanceName
         client.processResponse(createResponse(
                 otherInstance, ipV6Address, 5353, SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Goodbye from otherInstanceName
         client.processResponse(createResponse(
                 otherInstance, ipV6Address, 5353, SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), 0L /* ttl */), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), 0L /* ttl */), socketKey);
 
         // mockListenerOne gets notified for the requested instance
         final ArgumentMatcher<MdnsServiceInfo> subtypeInstanceMatcher = info ->
@@ -1247,7 +1241,7 @@
     @Test
     public void testNotifySocketDestroyed() throws Exception {
         client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);
+                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
 
         final String requestedInstance = "instance1";
         final String otherInstance = "instance2";
@@ -1273,13 +1267,13 @@
         client.processResponse(createResponse(
                         requestedInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
-                INTERFACE_INDEX, mockNetwork);
+                socketKey);
 
         // Complete response from otherInstanceName
         client.processResponse(createResponse(
                         otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
-                INTERFACE_INDEX, mockNetwork);
+                socketKey);
 
         verify(expectedSendFutures[1], never()).cancel(true);
         client.notifySocketDestroyed();
@@ -1326,18 +1320,18 @@
         assertEquals(currentThreadExecutor.getAndClearLastScheduledDelayInMs(), timeInMs);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         if (expectsUnicastResponse) {
-            verify(mockSocketClient).sendUnicastPacket(
-                    expectedIPv4Packets[index], mockNetwork);
+            verify(mockSocketClient).sendPacketRequestingUnicastResponse(
+                    expectedIPv4Packets[index], socketKey, false);
             if (multipleSocketDiscovery) {
-                verify(mockSocketClient).sendUnicastPacket(
-                        expectedIPv6Packets[index], mockNetwork);
+                verify(mockSocketClient).sendPacketRequestingUnicastResponse(
+                        expectedIPv6Packets[index], socketKey, false);
             }
         } else {
-            verify(mockSocketClient).sendMulticastPacket(
-                    expectedIPv4Packets[index], mockNetwork);
+            verify(mockSocketClient).sendPacketRequestingMulticastResponse(
+                    expectedIPv4Packets[index], socketKey, false);
             if (multipleSocketDiscovery) {
-                verify(mockSocketClient).sendMulticastPacket(
-                        expectedIPv6Packets[index], mockNetwork);
+                verify(mockSocketClient).sendPacketRequestingMulticastResponse(
+                        expectedIPv6Packets[index], socketKey, false);
             }
         }
     }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
index abb1747..69efc61 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
@@ -23,6 +23,7 @@
 import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.never;
@@ -53,6 +54,7 @@
 
 import java.io.IOException;
 import java.net.DatagramPacket;
+import java.net.InetSocketAddress;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -220,15 +222,17 @@
         assertTrue(unicastReceiverThread.isAlive());
 
         // Sends a packet.
-        DatagramPacket packet = new DatagramPacket(buf, 0, 5);
-        mdnsClient.sendMulticastPacket(packet);
+        DatagramPacket packet = getTestDatagramPacket();
+        mdnsClient.sendPacketRequestingMulticastResponse(packet,
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         // mockMulticastSocket.send() will be called on another thread. If we verify it immediately,
         // it may not be called yet. So timeout is added.
         verify(mockMulticastSocket, timeout(TIMEOUT).times(1)).send(packet);
         verify(mockUnicastSocket, timeout(TIMEOUT).times(0)).send(packet);
 
         // Verify the packet is sent by the unicast socket.
-        mdnsClient.sendUnicastPacket(packet);
+        mdnsClient.sendPacketRequestingUnicastResponse(packet,
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         verify(mockMulticastSocket, timeout(TIMEOUT).times(1)).send(packet);
         verify(mockUnicastSocket, timeout(TIMEOUT).times(1)).send(packet);
 
@@ -271,15 +275,17 @@
         assertNull(unicastReceiverThread);
 
         // Sends a packet.
-        DatagramPacket packet = new DatagramPacket(buf, 0, 5);
-        mdnsClient.sendMulticastPacket(packet);
+        DatagramPacket packet = getTestDatagramPacket();
+        mdnsClient.sendPacketRequestingMulticastResponse(packet,
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         // mockMulticastSocket.send() will be called on another thread. If we verify it immediately,
         // it may not be called yet. So timeout is added.
         verify(mockMulticastSocket, timeout(TIMEOUT).times(1)).send(packet);
         verify(mockUnicastSocket, timeout(TIMEOUT).times(0)).send(packet);
 
         // Verify the packet is sent by the multicast socket as well.
-        mdnsClient.sendUnicastPacket(packet);
+        mdnsClient.sendPacketRequestingUnicastResponse(packet,
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         verify(mockMulticastSocket, timeout(TIMEOUT).times(2)).send(packet);
         verify(mockUnicastSocket, timeout(TIMEOUT).times(0)).send(packet);
 
@@ -331,7 +337,8 @@
     public void testStopDiscovery_queueIsCleared() throws IOException {
         mdnsClient.startDiscovery();
         mdnsClient.stopDiscovery();
-        mdnsClient.sendMulticastPacket(new DatagramPacket(buf, 0, 5));
+        mdnsClient.sendPacketRequestingMulticastResponse(getTestDatagramPacket(),
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
 
         synchronized (mdnsClient.multicastPacketQueue) {
             assertTrue(mdnsClient.multicastPacketQueue.isEmpty());
@@ -342,7 +349,8 @@
     public void testSendPacket_afterDiscoveryStops() throws IOException {
         mdnsClient.startDiscovery();
         mdnsClient.stopDiscovery();
-        mdnsClient.sendMulticastPacket(new DatagramPacket(buf, 0, 5));
+        mdnsClient.sendPacketRequestingMulticastResponse(getTestDatagramPacket(),
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
 
         synchronized (mdnsClient.multicastPacketQueue) {
             assertTrue(mdnsClient.multicastPacketQueue.isEmpty());
@@ -355,7 +363,8 @@
         //MdnsConfigsFlagsImpl.mdnsPacketQueueMaxSize.override(2L);
         mdnsClient.startDiscovery();
         for (int i = 0; i < 100; i++) {
-            mdnsClient.sendMulticastPacket(new DatagramPacket(buf, 0, 5));
+            mdnsClient.sendPacketRequestingMulticastResponse(getTestDatagramPacket(),
+                    false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         }
 
         synchronized (mdnsClient.multicastPacketQueue) {
@@ -370,7 +379,7 @@
         mdnsClient.startDiscovery();
 
         verify(mockCallback, timeout(TIMEOUT).atLeast(1))
-                .onResponseReceived(any(MdnsPacket.class), anyInt(), any());
+                .onResponseReceived(any(MdnsPacket.class), any(SocketKey.class));
     }
 
     @Test
@@ -379,7 +388,7 @@
         mdnsClient.startDiscovery();
 
         verify(mockCallback, timeout(TIMEOUT).atLeastOnce())
-                .onResponseReceived(any(MdnsPacket.class), anyInt(), any());
+                .onResponseReceived(any(MdnsPacket.class), any(SocketKey.class));
 
         mdnsClient.stopDiscovery();
     }
@@ -451,9 +460,11 @@
         enableUnicastResponse.set(true);
 
         mdnsClient.startDiscovery();
-        DatagramPacket packet = new DatagramPacket(buf, 0, 5);
-        mdnsClient.sendUnicastPacket(packet);
-        mdnsClient.sendMulticastPacket(packet);
+        DatagramPacket packet = getTestDatagramPacket();
+        mdnsClient.sendPacketRequestingUnicastResponse(packet,
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
+        mdnsClient.sendPacketRequestingMulticastResponse(packet,
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
 
         // Wait for the timer to be triggered.
         Thread.sleep(MdnsConfigs.checkMulticastResponseIntervalMs() * 2);
@@ -483,8 +494,10 @@
         assertFalse(mdnsClient.receivedUnicastResponse);
         assertFalse(mdnsClient.cannotReceiveMulticastResponse.get());
 
-        mdnsClient.sendUnicastPacket(packet);
-        mdnsClient.sendMulticastPacket(packet);
+        mdnsClient.sendPacketRequestingUnicastResponse(packet,
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
+        mdnsClient.sendPacketRequestingMulticastResponse(packet,
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         Thread.sleep(MdnsConfigs.checkMulticastResponseIntervalMs() * 2);
 
         // Verify cannotReceiveMulticastResponse is not set the true because we didn't receive the
@@ -513,7 +526,7 @@
         mdnsClient.startDiscovery();
 
         verify(mockCallback, timeout(TIMEOUT).atLeastOnce())
-                .onResponseReceived(any(), eq(21), any());
+                .onResponseReceived(any(), argThat(key -> key.getInterfaceIndex() == 21));
     }
 
     @Test
@@ -536,6 +549,12 @@
         mdnsClient.startDiscovery();
 
         verify(mockMulticastSocket, never()).getInterfaceIndex();
-        verify(mockCallback, timeout(TIMEOUT).atLeast(1)).onResponseReceived(any(), eq(-1), any());
+        verify(mockCallback, timeout(TIMEOUT).atLeast(1))
+                .onResponseReceived(any(), argThat(key -> key.getInterfaceIndex() == -1));
+    }
+
+    private DatagramPacket getTestDatagramPacket() {
+        return new DatagramPacket(buf, 0, 5,
+                new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), 5353 /* port */));
     }
 }
\ No newline at end of file
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
index 4ef64cb..e971de7 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
@@ -227,30 +227,30 @@
 
     private class TestSocketCallback implements MdnsSocketProvider.SocketCallback {
         private class SocketEvent {
-            public final Network mNetwork;
+            public final SocketKey mSocketKey;
             public final List<LinkAddress> mAddresses;
 
-            SocketEvent(Network network, List<LinkAddress> addresses) {
-                mNetwork = network;
+            SocketEvent(SocketKey socketKey, List<LinkAddress> addresses) {
+                mSocketKey = socketKey;
                 mAddresses = Collections.unmodifiableList(addresses);
             }
         }
 
         private class SocketCreatedEvent extends SocketEvent {
-            SocketCreatedEvent(Network nw, List<LinkAddress> addresses) {
-                super(nw, addresses);
+            SocketCreatedEvent(SocketKey socketKey, List<LinkAddress> addresses) {
+                super(socketKey, addresses);
             }
         }
 
         private class InterfaceDestroyedEvent extends SocketEvent {
-            InterfaceDestroyedEvent(Network nw, List<LinkAddress> addresses) {
-                super(nw, addresses);
+            InterfaceDestroyedEvent(SocketKey socketKey, List<LinkAddress> addresses) {
+                super(socketKey, addresses);
             }
         }
 
         private class AddressesChangedEvent extends SocketEvent {
-            AddressesChangedEvent(Network nw, List<LinkAddress> addresses) {
-                super(nw, addresses);
+            AddressesChangedEvent(SocketKey socketKey, List<LinkAddress> addresses) {
+                super(socketKey, addresses);
             }
         }
 
@@ -258,27 +258,27 @@
                 new ArrayTrackRecord<SocketEvent>().newReadHead();
 
         @Override
-        public void onSocketCreated(Network network, MdnsInterfaceSocket socket,
+        public void onSocketCreated(SocketKey socketKey, MdnsInterfaceSocket socket,
                 List<LinkAddress> addresses) {
-            mHistory.add(new SocketCreatedEvent(network, addresses));
+            mHistory.add(new SocketCreatedEvent(socketKey, addresses));
         }
 
         @Override
-        public void onInterfaceDestroyed(Network network, MdnsInterfaceSocket socket) {
-            mHistory.add(new InterfaceDestroyedEvent(network, List.of()));
+        public void onInterfaceDestroyed(SocketKey socketKey, MdnsInterfaceSocket socket) {
+            mHistory.add(new InterfaceDestroyedEvent(socketKey, List.of()));
         }
 
         @Override
-        public void onAddressesChanged(Network network, MdnsInterfaceSocket socket,
+        public void onAddressesChanged(SocketKey socketKey, MdnsInterfaceSocket socket,
                 List<LinkAddress> addresses) {
-            mHistory.add(new AddressesChangedEvent(network, addresses));
+            mHistory.add(new AddressesChangedEvent(socketKey, addresses));
         }
 
         public void expectedSocketCreatedForNetwork(Network network, List<LinkAddress> addresses) {
             final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
             assertNotNull(event);
             assertTrue(event instanceof SocketCreatedEvent);
-            assertEquals(network, event.mNetwork);
+            assertEquals(network, event.mSocketKey.getNetwork());
             assertEquals(addresses, event.mAddresses);
         }
 
@@ -286,7 +286,7 @@
             final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
             assertNotNull(event);
             assertTrue(event instanceof InterfaceDestroyedEvent);
-            assertEquals(network, event.mNetwork);
+            assertEquals(network, event.mSocketKey.getNetwork());
         }
 
         public void expectedAddressesChangedForNetwork(Network network,
@@ -294,7 +294,7 @@
             final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
             assertNotNull(event);
             assertTrue(event instanceof AddressesChangedEvent);
-            assertEquals(network, event.mNetwork);
+            assertEquals(network, event.mSocketKey.getNetwork());
             assertEquals(event.mAddresses, addresses);
         }