Merge "Rename AbstractSocketNetlink to AbstractSocketNetlinkMonitor"
diff --git a/Cronet/tests/cts/src/android/net/http/cts/BidirectionalStreamTest.kt b/Cronet/tests/cts/src/android/net/http/cts/BidirectionalStreamTest.kt
index 0760e68..0885f4f 100644
--- a/Cronet/tests/cts/src/android/net/http/cts/BidirectionalStreamTest.kt
+++ b/Cronet/tests/cts/src/android/net/http/cts/BidirectionalStreamTest.kt
@@ -27,6 +27,7 @@
 import androidx.test.core.app.ApplicationProvider
 import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.DevSdkIgnoreRunner
+import com.google.common.truth.Truth.assertThat
 import kotlin.test.Test
 import kotlin.test.assertEquals
 import org.hamcrest.MatcherAssert
@@ -81,4 +82,113 @@
             "Received byte count must be > 0", info.receivedByteCount, Matchers.greaterThan(0L))
         assertEquals("h2", info.negotiatedProtocol)
     }
+
+    @Test
+    @Throws(Exception::class)
+    fun testBidirectionalStream_getHttpMethod() {
+        val builder = createBidirectionalStreamBuilder(URL)
+        val method = "GET"
+        // The built stream must report the HTTP method configured on its builder.
+        builder.setHttpMethod(method)
+        val built = builder.build().also { stream = it }
+        assertThat(built.httpMethod).isEqualTo(method)
+    }
+
+    @Test
+    @Throws(Exception::class)
+    fun testBidirectionalStream_hasTrafficStatsTag() {
+        val builder = createBidirectionalStreamBuilder(URL)
+        // Setting a tag on the builder must mark the built stream as tagged.
+        builder.setTrafficStatsTag(10)
+        val built = builder.build().also { stream = it }
+        assertThat(built.hasTrafficStatsTag()).isTrue()
+    }
+
+    @Test
+    @Throws(Exception::class)
+    fun testBidirectionalStream_getTrafficStatsTag() {
+        val builder = createBidirectionalStreamBuilder(URL)
+        val trafficStatsTag = 10
+        // The built stream must report the tag configured on its builder.
+        builder.setTrafficStatsTag(trafficStatsTag)
+        val built = builder.build().also { stream = it }
+        assertThat(built.trafficStatsTag).isEqualTo(trafficStatsTag)
+    }
+
+    @Test
+    @Throws(Exception::class)
+    fun testBidirectionalStream_hasTrafficStatsUid() {
+        val builder = createBidirectionalStreamBuilder(URL)
+        // Setting a UID on the builder must mark the built stream as having one.
+        builder.setTrafficStatsUid(10)
+        val built = builder.build().also { stream = it }
+        assertThat(built.hasTrafficStatsUid()).isTrue()
+    }
+
+    @Test
+    @Throws(Exception::class)
+    fun testBidirectionalStream_getTrafficStatsUid() {
+        val builder = createBidirectionalStreamBuilder(URL)
+        val trafficStatsUid = 10
+        // The built stream must report the UID configured on its builder.
+        builder.setTrafficStatsUid(trafficStatsUid)
+        val built = builder.build().also { stream = it }
+        assertThat(built.trafficStatsUid).isEqualTo(trafficStatsUid)
+    }
+
+    @Test
+    @Throws(Exception::class)
+    fun testBidirectionalStream_getHeaders_asList() {
+        val builder = createBidirectionalStreamBuilder(URL)
+        val expectedHeaders = mapOf(
+            "Authorization" to "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==",
+            "Max-Forwards" to "10",
+            "X-Client-Data" to "random custom header content"
+        ).entries.toList()
+
+        // Every header added to the builder must be reported back in list form.
+        expectedHeaders.forEach { (name, value) -> builder.addHeader(name, value) }
+
+        val built = builder.build().also { stream = it }
+        assertThat(built.headers.asList).containsAtLeastElementsIn(expectedHeaders)
+    }
+
+    @Test
+    @Throws(Exception::class)
+    fun testBidirectionalStream_getHeaders_asMap() {
+        val builder = createBidirectionalStreamBuilder(URL)
+        val expectedHeaders = mapOf(
+            "Authorization" to listOf("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="),
+            "Max-Forwards" to listOf("10"),
+            "X-Client-Data" to listOf("random custom header content")
+        )
+
+        // Each header has exactly one value; add it and expect it back in map form.
+        expectedHeaders.forEach { (name, values) -> builder.addHeader(name, values.first()) }
+
+        val built = builder.build().also { stream = it }
+        assertThat(built.headers.asMap).containsAtLeastEntriesIn(expectedHeaders)
+    }
+
+    @Test
+    @Throws(Exception::class)
+    fun testBidirectionalStream_getPriority() {
+        val builder = createBidirectionalStreamBuilder(URL)
+        val priority = BidirectionalStream.STREAM_PRIORITY_LOW
+        // The built stream must report the priority configured on its builder.
+        builder.setPriority(priority)
+        val built = builder.build().also { stream = it }
+        assertThat(built.priority).isEqualTo(priority)
+    }
+
+    @Test
+    @Throws(Exception::class)
+    fun testBidirectionalStream_isDelayRequestHeadersUntilFirstFlushEnabled() {
+        val builder = createBidirectionalStreamBuilder(URL)
+        // The built stream must keep the header-delay option set on its builder.
+        builder.setDelayRequestHeadersUntilFirstFlushEnabled(true)
+        val built = builder.build().also { stream = it }
+        assertThat(built.isDelayRequestHeadersUntilFirstFlushEnabled).isTrue()
+    }
+
 }
diff --git a/Cronet/tests/cts/src/android/net/http/cts/UrlRequestTest.java b/Cronet/tests/cts/src/android/net/http/cts/UrlRequestTest.java
index 07e7d45..3c4d134 100644
--- a/Cronet/tests/cts/src/android/net/http/cts/UrlRequestTest.java
+++ b/Cronet/tests/cts/src/android/net/http/cts/UrlRequestTest.java
@@ -363,6 +363,116 @@
                 .containsAtLeastElementsIn(expectedHeaders);
     }
 
+    @Test
+    public void testUrlRequest_getHttpMethod() throws Exception {
+        final String expectedMethod = "POST";
+        final UrlRequest.Builder requestBuilder =
+                createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        // The built request must report the HTTP method configured on its builder.
+        requestBuilder.setHttpMethod(expectedMethod);
+        assertThat(requestBuilder.build().getHttpMethod()).isEqualTo(expectedMethod);
+    }
+
+    @Test
+    public void testUrlRequest_getHeaders_asList() throws Exception {
+        final List<Map.Entry<String, String>> expectedHeaders = Arrays.asList(
+                Map.entry("Authorization", "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="),
+                Map.entry("Max-Forwards", "10"),
+                Map.entry("X-Client-Data", "random custom header content"));
+        final UrlRequest.Builder requestBuilder =
+                createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        // Every header added to the builder must be reported back in list form.
+        expectedHeaders.forEach(
+                entry -> requestBuilder.addHeader(entry.getKey(), entry.getValue()));
+
+        assertThat(requestBuilder.build().getHeaders().getAsList())
+                .containsAtLeastElementsIn(expectedHeaders);
+    }
+
+    @Test
+    public void testUrlRequest_getHeaders_asMap() throws Exception {
+        final Map<String, List<String>> expectedHeaders = Map.of(
+                "Authorization", Arrays.asList("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="),
+                "Max-Forwards", Arrays.asList("10"),
+                "X-Client-Data", Arrays.asList("random custom header content"));
+        final UrlRequest.Builder requestBuilder =
+                createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        // Each header has exactly one value; add it and expect it back in map form.
+        expectedHeaders.forEach(
+                (name, values) -> requestBuilder.addHeader(name, values.get(0)));
+
+        assertThat(requestBuilder.build().getHeaders().getAsMap())
+                .containsAtLeastEntriesIn(expectedHeaders);
+    }
+
+    @Test
+    public void testUrlRequest_isCacheDisabled() throws Exception {
+        final boolean isCacheDisabled = true;
+        final UrlRequest.Builder requestBuilder =
+                createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        // The built request must reflect the cache setting chosen on its builder.
+        requestBuilder.setCacheDisabled(isCacheDisabled);
+        assertThat(requestBuilder.build().isCacheDisabled()).isEqualTo(isCacheDisabled);
+    }
+
+    @Test
+    public void testUrlRequest_isDirectExecutorAllowed() throws Exception {
+        final boolean allowed = true;
+        final UrlRequest.Builder requestBuilder =
+                createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        // The built request must reflect the executor policy chosen on its builder.
+        requestBuilder.setDirectExecutorAllowed(allowed);
+        assertThat(requestBuilder.build().isDirectExecutorAllowed()).isEqualTo(allowed);
+    }
+
+    @Test
+    public void testUrlRequest_getPriority() throws Exception {
+        final int expectedPriority = UrlRequest.REQUEST_PRIORITY_LOW;
+        final UrlRequest.Builder requestBuilder =
+                createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        // The built request must report the priority configured on its builder.
+        requestBuilder.setPriority(expectedPriority);
+        assertThat(requestBuilder.build().getPriority()).isEqualTo(expectedPriority);
+    }
+
+    @Test
+    public void testUrlRequest_hasTrafficStatsTag() throws Exception {
+        UrlRequest.Builder builder = createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        // Setting a tag on the builder must mark the built request as tagged.
+        builder.setTrafficStatsTag(10);
+        UrlRequest request = builder.build();
+        assertThat(request.hasTrafficStatsTag()).isTrue();
+    }
+
+    @Test
+    public void testUrlRequest_getTrafficStatsTag() throws Exception {
+        final int expectedTag = 10;
+        final UrlRequest.Builder requestBuilder =
+                createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        // The built request must report the tag configured on its builder.
+        requestBuilder.setTrafficStatsTag(expectedTag);
+        assertThat(requestBuilder.build().getTrafficStatsTag()).isEqualTo(expectedTag);
+    }
+
+    @Test
+    public void testUrlRequest_hasTrafficStatsUid() throws Exception {
+        UrlRequest.Builder builder = createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        // Setting a UID on the builder must mark the built request as having one.
+        builder.setTrafficStatsUid(10);
+        UrlRequest request = builder.build();
+        assertThat(request.hasTrafficStatsUid()).isTrue();
+    }
+
+    @Test
+    public void testUrlRequest_getTrafficStatsUid() throws Exception {
+        final int expectedUid = 10;
+        final UrlRequest.Builder requestBuilder =
+                createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        // The built request must report the UID configured on its builder.
+        requestBuilder.setTrafficStatsUid(expectedUid);
+        assertThat(requestBuilder.build().getTrafficStatsUid()).isEqualTo(expectedUid);
+    }
+
     private static List<Map.Entry<String, String>> extractEchoedHeaders(HeaderBlock headers) {
         return headers.getAsList()
                 .stream()
diff --git a/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java b/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
index 007bf23..9dad301 100644
--- a/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
+++ b/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
@@ -25,17 +25,14 @@
 import static android.net.TetheringManager.CONNECTIVITY_SCOPE_GLOBAL;
 import static android.net.TetheringManager.CONNECTIVITY_SCOPE_LOCAL;
 import static android.net.TetheringManager.TETHERING_ETHERNET;
+import static android.net.TetheringTester.buildTcpPacket;
+import static android.net.TetheringTester.buildUdpPacket;
+import static android.net.TetheringTester.isAddressIpv4;
 import static android.net.TetheringTester.isExpectedIcmpPacket;
 import static android.net.TetheringTester.isExpectedTcpPacket;
 import static android.net.TetheringTester.isExpectedUdpPacket;
-import static android.system.OsConstants.IPPROTO_IP;
-import static android.system.OsConstants.IPPROTO_IPV6;
-import static android.system.OsConstants.IPPROTO_TCP;
-import static android.system.OsConstants.IPPROTO_UDP;
 
 import static com.android.net.module.util.HexDump.dumpHexString;
-import static com.android.net.module.util.NetworkStackConstants.ETHER_TYPE_IPV4;
-import static com.android.net.module.util.NetworkStackConstants.ETHER_TYPE_IPV6;
 import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ROUTER_ADVERTISEMENT;
 import static com.android.net.module.util.NetworkStackConstants.TCPHDR_ACK;
 import static com.android.net.module.util.NetworkStackConstants.TCPHDR_SYN;
@@ -67,11 +64,9 @@
 import android.util.Log;
 
 import androidx.annotation.NonNull;
-import androidx.annotation.Nullable;
 import androidx.test.platform.app.InstrumentationRegistry;
 
 import com.android.modules.utils.build.SdkLevel;
-import com.android.net.module.util.PacketBuilder;
 import com.android.net.module.util.Struct;
 import com.android.net.module.util.structs.Ipv6Header;
 import com.android.testutils.HandlerUtils;
@@ -124,29 +119,17 @@
             (Inet4Address) parseNumericAddress("8.8.8.8");
     protected static final Inet6Address REMOTE_IP6_ADDR =
             (Inet6Address) parseNumericAddress("2002:db8:1::515:ca");
+    // The NAT64-synthesized IPv6 address of REMOTE_IP4_ADDR, assuming pref64::/n is
+    // 64:ff9b::/96. For more information, see TetheringTester#PREF64_IPV4ONLY_ADDR, which
+    // assumes the same 64:ff9b::/96 prefix.
     protected static final Inet6Address REMOTE_NAT64_ADDR =
             (Inet6Address) parseNumericAddress("64:ff9b::808:808");
-    protected static final IpPrefix TEST_NAT64PREFIX = new IpPrefix("64:ff9b::/96");
 
-    // IPv4 header definition.
-    protected static final short ID = 27149;
-    protected static final short FLAGS_AND_FRAGMENT_OFFSET = (short) 0x4000; // flags=DF, offset=0
-    protected static final byte TIME_TO_LIVE = (byte) 0x40;
-    protected static final byte TYPE_OF_SERVICE = 0;
-
-    // IPv6 header definition.
-    private static final short HOP_LIMIT = 0x40;
-    // version=6, traffic class=0x0, flowlabel=0x0;
-    private static final int VERSION_TRAFFICCLASS_FLOWLABEL = 0x60000000;
-
-    // UDP and TCP header definition.
     // LOCAL_PORT is used by public port and private port. Assume port 9876 has not been used yet
     // before the testing that public port and private port are the same in the testing. Note that
     // NAT port forwarding could be different between private port and public port.
     protected static final short LOCAL_PORT = 9876;
     protected static final short REMOTE_PORT = 433;
-    private static final short WINDOW = (short) 0x2000;
-    private static final short URGENT_POINTER = 0;
 
     // Payload definition.
     protected static final ByteBuffer EMPTY_PAYLOAD = ByteBuffer.wrap(new byte[0]);
@@ -184,13 +167,18 @@
         mHandlerThread.start();
         mHandler = new Handler(mHandlerThread.getLooper());
 
-        mRunTests = runAsShell(NETWORK_SETTINGS, TETHER_PRIVILEGED, () ->
-                mTm.isTetheringSupported());
+        mRunTests = isEthernetTetheringSupported();
         assumeTrue(mRunTests);
 
         mTetheredInterfaceRequester = new TetheredInterfaceRequester(mHandler, mEm);
     }
 
+    private boolean isEthernetTetheringSupported() throws Exception {
+        // If mEm (presumably the EthernetManager) is null, report unsupported without querying.
+        return mEm != null && runAsShell(NETWORK_SETTINGS, TETHER_PRIVILEGED,
+                () -> mTm.isTetheringSupported());
+    }
+
     protected void maybeStopTapPacketReader(final TapPacketReader tapPacketReader)
             throws Exception {
         if (tapPacketReader != null) {
@@ -649,77 +637,10 @@
         final LinkProperties lp = new LinkProperties();
         lp.setLinkAddresses(addresses);
         lp.setDnsServers(dnses);
-        lp.setNat64Prefix(TEST_NAT64PREFIX);
 
         return runAsShell(MANAGE_TEST_NETWORKS, () -> initTestNetwork(mContext, lp, TIMEOUT_MS));
     }
 
-    private short getEthType(@NonNull final InetAddress srcIp, @NonNull final InetAddress dstIp) {
-        return isAddressIpv4(srcIp, dstIp) ? (short) ETHER_TYPE_IPV4 : (short) ETHER_TYPE_IPV6;
-    }
-
-    private int getIpProto(@NonNull final InetAddress srcIp, @NonNull final InetAddress dstIp) {
-        return isAddressIpv4(srcIp, dstIp) ? IPPROTO_IP : IPPROTO_IPV6;
-    }
-
-    @NonNull
-    protected ByteBuffer buildUdpPacket(
-            @Nullable final MacAddress srcMac, @Nullable final MacAddress dstMac,
-            @NonNull final InetAddress srcIp, @NonNull final InetAddress dstIp,
-            short srcPort, short dstPort, @Nullable final ByteBuffer payload)
-            throws Exception {
-        final int ipProto = getIpProto(srcIp, dstIp);
-        final boolean hasEther = (srcMac != null && dstMac != null);
-        final int payloadLen = (payload == null) ? 0 : payload.limit();
-        final ByteBuffer buffer = PacketBuilder.allocate(hasEther, ipProto, IPPROTO_UDP,
-                payloadLen);
-        final PacketBuilder packetBuilder = new PacketBuilder(buffer);
-
-        // [1] Ethernet header
-        if (hasEther) {
-            packetBuilder.writeL2Header(srcMac, dstMac, getEthType(srcIp, dstIp));
-        }
-
-        // [2] IP header
-        if (ipProto == IPPROTO_IP) {
-            packetBuilder.writeIpv4Header(TYPE_OF_SERVICE, ID, FLAGS_AND_FRAGMENT_OFFSET,
-                    TIME_TO_LIVE, (byte) IPPROTO_UDP, (Inet4Address) srcIp, (Inet4Address) dstIp);
-        } else {
-            packetBuilder.writeIpv6Header(VERSION_TRAFFICCLASS_FLOWLABEL, (byte) IPPROTO_UDP,
-                    HOP_LIMIT, (Inet6Address) srcIp, (Inet6Address) dstIp);
-        }
-
-        // [3] UDP header
-        packetBuilder.writeUdpHeader(srcPort, dstPort);
-
-        // [4] Payload
-        if (payload != null) {
-            buffer.put(payload);
-            // in case data might be reused by caller, restore the position and
-            // limit of bytebuffer.
-            payload.clear();
-        }
-
-        return packetBuilder.finalizePacket();
-    }
-
-    @NonNull
-    protected ByteBuffer buildUdpPacket(@NonNull final InetAddress srcIp,
-            @NonNull final InetAddress dstIp, short srcPort, short dstPort,
-            @Nullable final ByteBuffer payload) throws Exception {
-        return buildUdpPacket(null /* srcMac */, null /* dstMac */, srcIp, dstIp, srcPort,
-                dstPort, payload);
-    }
-
-    private boolean isAddressIpv4(@NonNull final  InetAddress srcIp,
-            @NonNull final InetAddress dstIp) {
-        if (srcIp instanceof Inet4Address && dstIp instanceof Inet4Address) return true;
-        if (srcIp instanceof Inet6Address && dstIp instanceof Inet6Address) return false;
-
-        fail("Unsupported conditions: srcIp " + srcIp + ", dstIp " + dstIp);
-        return false;  // unreachable
-    }
-
     protected void sendDownloadPacketUdp(@NonNull final InetAddress srcIp,
             @NonNull final InetAddress dstIp, @NonNull final TetheringTester tester,
             boolean is6To4) throws Exception {
@@ -761,45 +682,6 @@
         });
     }
 
-
-    @NonNull
-    private ByteBuffer buildTcpPacket(
-            @Nullable final MacAddress srcMac, @Nullable final MacAddress dstMac,
-            @NonNull final InetAddress srcIp, @NonNull final InetAddress dstIp,
-            short srcPort, short dstPort, final short seq, final short ack,
-            final byte tcpFlags, @NonNull final ByteBuffer payload) throws Exception {
-        final int ipProto = getIpProto(srcIp, dstIp);
-        final boolean hasEther = (srcMac != null && dstMac != null);
-        final ByteBuffer buffer = PacketBuilder.allocate(hasEther, ipProto, IPPROTO_TCP,
-                payload.limit());
-        final PacketBuilder packetBuilder = new PacketBuilder(buffer);
-
-        // [1] Ethernet header
-        if (hasEther) {
-            packetBuilder.writeL2Header(srcMac, dstMac, getEthType(srcIp, dstIp));
-        }
-
-        // [2] IP header
-        if (ipProto == IPPROTO_IP) {
-            packetBuilder.writeIpv4Header(TYPE_OF_SERVICE, ID, FLAGS_AND_FRAGMENT_OFFSET,
-                    TIME_TO_LIVE, (byte) IPPROTO_TCP, (Inet4Address) srcIp, (Inet4Address) dstIp);
-        } else {
-            packetBuilder.writeIpv6Header(VERSION_TRAFFICCLASS_FLOWLABEL, (byte) IPPROTO_TCP,
-                    HOP_LIMIT, (Inet6Address) srcIp, (Inet6Address) dstIp);
-        }
-
-        // [3] TCP header
-        packetBuilder.writeTcpHeader(srcPort, dstPort, seq, ack, tcpFlags, WINDOW, URGENT_POINTER);
-
-        // [4] Payload
-        buffer.put(payload);
-        // in case data might be reused by caller, restore the position and
-        // limit of bytebuffer.
-        payload.clear();
-
-        return packetBuilder.finalizePacket();
-    }
-
     protected void sendDownloadPacketTcp(@NonNull final InetAddress srcIp,
             @NonNull final InetAddress dstIp, short seq, short ack, byte tcpFlags,
             @NonNull final ByteBuffer payload, @NonNull final TetheringTester tester,
diff --git a/Tethering/tests/integration/base/android/net/TetheringTester.java b/Tethering/tests/integration/base/android/net/TetheringTester.java
index 1c0803e..3f3768e 100644
--- a/Tethering/tests/integration/base/android/net/TetheringTester.java
+++ b/Tethering/tests/integration/base/android/net/TetheringTester.java
@@ -16,17 +16,27 @@
 
 package android.net;
 
+import static android.net.DnsResolver.CLASS_IN;
+import static android.net.DnsResolver.TYPE_AAAA;
 import static android.net.InetAddresses.parseNumericAddress;
+import static android.system.OsConstants.ICMP_ECHO;
+import static android.system.OsConstants.ICMP_ECHOREPLY;
 import static android.system.OsConstants.IPPROTO_ICMP;
 import static android.system.OsConstants.IPPROTO_ICMPV6;
+import static android.system.OsConstants.IPPROTO_IP;
+import static android.system.OsConstants.IPPROTO_IPV6;
 import static android.system.OsConstants.IPPROTO_TCP;
 import static android.system.OsConstants.IPPROTO_UDP;
 
 import static com.android.net.module.util.DnsPacket.ANSECTION;
 import static com.android.net.module.util.DnsPacket.ARSECTION;
+import static com.android.net.module.util.DnsPacket.DnsHeader;
+import static com.android.net.module.util.DnsPacket.DnsRecord;
 import static com.android.net.module.util.DnsPacket.NSSECTION;
 import static com.android.net.module.util.DnsPacket.QDSECTION;
 import static com.android.net.module.util.HexDump.dumpHexString;
+import static com.android.net.module.util.IpUtils.icmpChecksum;
+import static com.android.net.module.util.IpUtils.ipChecksum;
 import static com.android.net.module.util.NetworkStackConstants.ARP_REPLY;
 import static com.android.net.module.util.NetworkStackConstants.ARP_REQUEST;
 import static com.android.net.module.util.NetworkStackConstants.ETHER_ADDR_LEN;
@@ -38,6 +48,10 @@
 import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ND_OPTION_TLLA;
 import static com.android.net.module.util.NetworkStackConstants.ICMPV6_NEIGHBOR_SOLICITATION;
 import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ROUTER_ADVERTISEMENT;
+import static com.android.net.module.util.NetworkStackConstants.ICMP_CHECKSUM_OFFSET;
+import static com.android.net.module.util.NetworkStackConstants.IPV4_CHECKSUM_OFFSET;
+import static com.android.net.module.util.NetworkStackConstants.IPV4_HEADER_MIN_LEN;
+import static com.android.net.module.util.NetworkStackConstants.IPV4_LENGTH_OFFSET;
 import static com.android.net.module.util.NetworkStackConstants.IPV6_ADDR_ALL_NODES_MULTICAST;
 import static com.android.net.module.util.NetworkStackConstants.NEIGHBOR_ADVERTISEMENT_FLAG_OVERRIDE;
 import static com.android.net.module.util.NetworkStackConstants.NEIGHBOR_ADVERTISEMENT_FLAG_SOLICITED;
@@ -58,6 +72,7 @@
 
 import com.android.net.module.util.DnsPacket;
 import com.android.net.module.util.Ipv6Utils;
+import com.android.net.module.util.PacketBuilder;
 import com.android.net.module.util.Struct;
 import com.android.net.module.util.structs.EthernetHeader;
 import com.android.net.module.util.structs.Icmpv4Header;
@@ -101,6 +116,44 @@
             DhcpPacket.DHCP_LEASE_TIME,
     };
     private static final InetAddress LINK_LOCAL = parseNumericAddress("fe80::1");
+    // IPv4 header definition.
+    protected static final short ID = 27149;
+    protected static final short FLAGS_AND_FRAGMENT_OFFSET = (short) 0x4000; // flags=DF, offset=0
+    protected static final byte TIME_TO_LIVE = (byte) 0x40;
+    protected static final byte TYPE_OF_SERVICE = 0;
+
+    // IPv6 header definition.
+    private static final short HOP_LIMIT = 0x40;
+    // version=6, traffic class=0x0, flowlabel=0x0;
+    private static final int VERSION_TRAFFICCLASS_FLOWLABEL = 0x60000000;
+
+    // UDP and TCP header definition.
+    private static final short WINDOW = (short) 0x2000;
+    private static final short URGENT_POINTER = 0;
+
+    // ICMP definition.
+    private static final short ICMPECHO_CODE = 0x0;
+
+    // Prefix64 discovery definition. See RFC 7050 section 8.
+    // Note that the AAAA response carries Pref64::WKAs, each consisting of Pref64::/n and a WKA.
+    // Use 64:ff9b::/96 as Pref64::/n and WKA 192.0.0.17{0|1} here.
+    //
+    // Host                                          DNS64 server
+    //   |                                                |
+    //   |  "AAAA" query for "ipv4only.arpa."             |
+    //   |----------------------------------------------->|
+    //   |                                                |
+    //   |  "AAAA" response with:                         |
+    //   |  "64:ff9b::192.0.0.170"                        |
+    //   |<-----------------------------------------------|
+    //
+    private static final String PREF64_IPV4ONLY_HOSTNAME = "ipv4only.arpa";
+    private static final InetAddress PREF64_IPV4ONLY_ADDR = parseNumericAddress(
+            "64:ff9b::192.0.0.170");
+
+    // DNS header definition.
+    private static final short FLAG = (short) 0x8100;  // qr, rd (RA, 0x0080, is not set)
+    private static final short TTL = (short) 0;
 
     public static final String DHCP_HOSTNAME = "testhostname";
 
@@ -462,6 +515,11 @@
             super(data);
         }
 
+        TestDnsPacket(@NonNull DnsHeader dnsHeader, @Nullable ArrayList<DnsRecord> questions,
+                @Nullable ArrayList<DnsRecord> answers) {
+            super(dnsHeader, questions, answers);  // Built from explicit QD and AN record lists.
+        }
+
         @Nullable
         public static TestDnsPacket getTestDnsPacket(final ByteBuffer buf) {
             try {
@@ -628,6 +686,190 @@
         return false;
     }
 
+    @NonNull
+    public static ByteBuffer buildUdpPacket(
+            @Nullable final MacAddress srcMac, @Nullable final MacAddress dstMac,
+            @NonNull final InetAddress srcIp, @NonNull final InetAddress dstIp,
+            short srcPort, short dstPort, @Nullable final ByteBuffer payload)
+            throws Exception {
+        final int ipProto = getIpProto(srcIp, dstIp);
+        final boolean hasEther = (srcMac != null && dstMac != null);
+        // Size the packet by the bytes actually copied below (position..limit).
+        final int payloadLen = (payload == null) ? 0 : payload.remaining();
+        final ByteBuffer buffer = PacketBuilder.allocate(hasEther, ipProto, IPPROTO_UDP,
+                payloadLen);
+        final PacketBuilder packetBuilder = new PacketBuilder(buffer);
+
+        // [1] Ethernet header, only when both MAC addresses are given.
+        if (hasEther) {
+            packetBuilder.writeL2Header(srcMac, dstMac, getEthType(srcIp, dstIp));
+        }
+
+        // [2] IP header
+        if (ipProto == IPPROTO_IP) {
+            packetBuilder.writeIpv4Header(TYPE_OF_SERVICE, ID, FLAGS_AND_FRAGMENT_OFFSET,
+                    TIME_TO_LIVE, (byte) IPPROTO_UDP, (Inet4Address) srcIp, (Inet4Address) dstIp);
+        } else {
+            packetBuilder.writeIpv6Header(VERSION_TRAFFICCLASS_FLOWLABEL, (byte) IPPROTO_UDP,
+                    HOP_LIMIT, (Inet6Address) srcIp, (Inet6Address) dstIp);
+        }
+
+        // [3] UDP header
+        packetBuilder.writeUdpHeader(srcPort, dstPort);
+
+        // [4] Payload. Copy from a duplicate so the caller's position and limit are left
+        // untouched and the payload can be reused; clear() would not restore a limit
+        // smaller than the capacity.
+        if (payload != null) {
+            buffer.put(payload.duplicate());
+        }
+
+        return packetBuilder.finalizePacket();
+    }
+
+    @NonNull
+    public static ByteBuffer buildUdpPacket(@NonNull final InetAddress srcIp,
+            @NonNull final InetAddress dstIp, short srcPort, short dstPort,
+            @Nullable final ByteBuffer payload) throws Exception {
+        final MacAddress noMac = null;  // No Ethernet header: build a bare IP packet.
+        return buildUdpPacket(noMac, noMac, srcIp, dstIp, srcPort, dstPort, payload);
+    }
+
+    @NonNull
+    public static ByteBuffer buildTcpPacket(
+            @Nullable final MacAddress srcMac, @Nullable final MacAddress dstMac,
+            @NonNull final InetAddress srcIp, @NonNull final InetAddress dstIp,
+            short srcPort, short dstPort, final short seq, final short ack,
+            final byte tcpFlags, @NonNull final ByteBuffer payload) throws Exception {
+        final int ipProto = getIpProto(srcIp, dstIp);
+        final boolean hasEther = (srcMac != null && dstMac != null);
+        // Size the packet by the bytes actually copied below (position..limit).
+        final ByteBuffer buffer = PacketBuilder.allocate(hasEther, ipProto, IPPROTO_TCP,
+                payload.remaining());
+        final PacketBuilder packetBuilder = new PacketBuilder(buffer);
+
+        // [1] Ethernet header, only when both MAC addresses are given.
+        if (hasEther) {
+            packetBuilder.writeL2Header(srcMac, dstMac, getEthType(srcIp, dstIp));
+        }
+
+        // [2] IP header
+        if (ipProto == IPPROTO_IP) {
+            packetBuilder.writeIpv4Header(TYPE_OF_SERVICE, ID, FLAGS_AND_FRAGMENT_OFFSET,
+                    TIME_TO_LIVE, (byte) IPPROTO_TCP, (Inet4Address) srcIp, (Inet4Address) dstIp);
+        } else {
+            packetBuilder.writeIpv6Header(VERSION_TRAFFICCLASS_FLOWLABEL, (byte) IPPROTO_TCP,
+                    HOP_LIMIT, (Inet6Address) srcIp, (Inet6Address) dstIp);
+        }
+
+        // [3] TCP header
+        packetBuilder.writeTcpHeader(srcPort, dstPort, seq, ack, tcpFlags, WINDOW, URGENT_POINTER);
+
+        // [4] Payload. Copy from a duplicate so the caller's position and limit are left
+        // untouched and the payload can be reused; clear() would not restore a limit
+        // smaller than the capacity.
+        buffer.put(payload.duplicate());
+
+        return packetBuilder.finalizePacket();
+    }
+
+    // PacketBuilder doesn't support IPv4 ICMP packets. Supporting them may require refactoring
+    // PacketBuilder first, because it currently expects every packet to have both a layer 3 (IP)
+    // and a layer 4 (TCP, UDP) header, while ICMP sits directly on top of IP. Since this test
+    // rarely needs IPv4 ICMP packets, a minimal ICMP echo packet builder is written here instead.
+    @NonNull
+    public static ByteBuffer buildIcmpEchoPacketV4(
+            @Nullable final MacAddress srcMac, @Nullable final MacAddress dstMac,
+            @NonNull final Inet4Address srcIp, @NonNull final Inet4Address dstIp,
+            int type, short id, short seq) throws Exception {
+        if (type != ICMP_ECHO && type != ICMP_ECHOREPLY) {
+            fail("Unsupported ICMP type: " + type);
+        }
+
+        // Build ICMP echo id and seq fields as payload. Ignore the data field.
+        final ByteBuffer payload = ByteBuffer.allocate(4);
+        payload.putShort(id);
+        payload.putShort(seq);
+        payload.rewind();
+
+        final boolean hasEther = (srcMac != null && dstMac != null);
+        final int etherHeaderLen = hasEther ? Struct.getSize(EthernetHeader.class) : 0;
+        final int ipv4HeaderLen = Struct.getSize(Ipv4Header.class);
+        final int icmpv4HeaderLen = Struct.getSize(Icmpv4Header.class);
+        final int payloadLen = payload.limit();
+        final ByteBuffer packet = ByteBuffer.allocate(etherHeaderLen + ipv4HeaderLen
+                + icmpv4HeaderLen + payloadLen);
+
+        // [1] Ethernet header
+        if (hasEther) {
+            final EthernetHeader ethHeader = new EthernetHeader(dstMac, srcMac, ETHER_TYPE_IPV4);
+            ethHeader.writeToByteBuffer(packet);
+        }
+
+        // [2] IP header
+        final Ipv4Header ipv4Header = new Ipv4Header(TYPE_OF_SERVICE,
+                (short) 0 /* totalLength, calculate later */, ID,
+                FLAGS_AND_FRAGMENT_OFFSET, TIME_TO_LIVE, (byte) IPPROTO_ICMP,
+                (short) 0 /* checksum, calculate later */, srcIp, dstIp);
+        ipv4Header.writeToByteBuffer(packet);
+
+        // [3] ICMP header
+        final Icmpv4Header icmpv4Header = new Icmpv4Header((byte) type, ICMPECHO_CODE,
+                (short) 0 /* checksum, calculate later */);
+        icmpv4Header.writeToByteBuffer(packet);
+
+        // [4] Payload
+        packet.put(payload);
+        packet.flip();
+
+        // [5] Finalize packet
+        // The IPv4 header immediately follows the Ethernet header, so its offset equals the
+        // Ethernet header length (0 without Ethernet). The ICMP header is located right after
+        // a minimum-length (option-less) IPv4 header.
+        final int ipv4HeaderOffset = etherHeaderLen;
+        final int icmpOffset = ipv4HeaderOffset + IPV4_HEADER_MIN_LEN;
+
+        // Populate the IPv4 totalLength field.
+        packet.putShort(ipv4HeaderOffset + IPV4_LENGTH_OFFSET,
+                (short) (ipv4HeaderLen + icmpv4HeaderLen + payloadLen));
+
+        // Populate the IPv4 header checksum field.
+        packet.putShort(ipv4HeaderOffset + IPV4_CHECKSUM_OFFSET,
+                ipChecksum(packet, ipv4HeaderOffset /* headerOffset */));
+
+        // Populate the ICMP checksum field.
+        packet.putShort(icmpOffset + ICMP_CHECKSUM_OFFSET,
+                icmpChecksum(packet, icmpOffset, icmpv4HeaderLen + payloadLen));
+        return packet;
+    }
+
+    @NonNull
+    public static ByteBuffer buildIcmpEchoPacketV4(@NonNull final Inet4Address srcIp,
+            @NonNull final Inet4Address dstIp, int type, short id, short seq)
+            throws Exception {
+        final MacAddress noMac = null;  // No Ethernet header: build a bare IPv4 packet.
+        return buildIcmpEchoPacketV4(noMac, noMac, srcIp, dstIp, type, id, seq);
+    }
+
+    private static short getEthType(@NonNull final InetAddress srcIp,
+            @NonNull final InetAddress dstIp) {  // EtherType for the pair's IP family
+        return (short) (isAddressIpv4(srcIp, dstIp) ? ETHER_TYPE_IPV4 : ETHER_TYPE_IPV6);
+    }
+
+    private static int getIpProto(@NonNull final InetAddress srcIp,
+            @NonNull final InetAddress dstIp) {  // IPPROTO_IP for IPv4 pairs, else IPPROTO_IPV6
+        return !isAddressIpv4(srcIp, dstIp) ? IPPROTO_IPV6 : IPPROTO_IP;
+    }
+
+    public static boolean isAddressIpv4(@NonNull final InetAddress srcIp,
+            @NonNull final InetAddress dstIp) {  // v4 pair -> true, v6 pair -> false
+        final boolean bothV4 = srcIp instanceof Inet4Address && dstIp instanceof Inet4Address;
+        final boolean bothV6 = srcIp instanceof Inet6Address && dstIp instanceof Inet6Address;
+        if (bothV4 || bothV6) return bothV4;
+        fail("Unsupported conditions: srcIp " + srcIp + ", dstIp " + dstIp);
+        return false;  // unreachable: fail() throws
+    }
+
     public void sendUploadPacket(ByteBuffer packet) throws Exception {
         mDownstreamReader.sendResponse(packet);
     }
@@ -650,10 +892,85 @@
         return null;
     }
 
+    @NonNull
+    private ByteBuffer buildUdpDnsPrefix64ReplyPacket(int dnsId, @NonNull final Inet6Address srcIp,
+            @NonNull final Inet6Address dstIp, short srcPort, short dstPort) throws Exception {
+        // [1] Build prefix64 DNS message.
+        final ArrayList<DnsRecord> qlist = new ArrayList<>();
+        // Fill QD section.
+        qlist.add(DnsRecord.makeQuestion(PREF64_IPV4ONLY_HOSTNAME, TYPE_AAAA, CLASS_IN));
+        final ArrayList<DnsRecord> alist = new ArrayList<>();
+        // Fill AN sections.
+        alist.add(DnsRecord.makeAOrAAAARecord(ANSECTION, PREF64_IPV4ONLY_HOSTNAME, CLASS_IN, TTL,
+                PREF64_IPV4ONLY_ADDR));
+        final TestDnsPacket dns = new TestDnsPacket(
+                new DnsHeader(dnsId, FLAG, qlist.size(), alist.size()), qlist, alist);
+
+        // [2] Build IPv6 UDP DNS packet.
+        return buildUdpPacket(srcIp, dstIp, srcPort, dstPort, ByteBuffer.wrap(dns.getBytes()));
+    }
+
+    private void maybeReplyUdpDnsPrefix64Discovery(@NonNull byte[] packet) {
+        final ByteBuffer buf = ByteBuffer.wrap(packet);
+
+        // [1] Parse the prefix64 discovery DNS query for hostname ipv4only.arpa.
+        // Parse IPv6 and UDP header.
+        Ipv6Header ipv6Header = null;
+        try {
+            ipv6Header = Struct.parse(Ipv6Header.class, buf);
+            if (ipv6Header == null || ipv6Header.nextHeader != IPPROTO_UDP) return;
+        } catch (Exception e) {
+            // Parsing packet fail means it is not IPv6 UDP packet.
+            return;
+        }
+        final UdpHeader udpHeader = Struct.parse(UdpHeader.class, buf);
+
+        // Parse DNS message.
+        final TestDnsPacket pref64Query = TestDnsPacket.getTestDnsPacket(buf);
+        if (pref64Query == null) return;
+        if (pref64Query.getHeader().isResponse()) return;
+        if (pref64Query.getQDCount() != 1) return;
+        if (pref64Query.getANCount() != 0) return;
+        if (pref64Query.getNSCount() != 0) return;
+        if (pref64Query.getARCount() != 0) return;
+
+        final List<DnsRecord> qdRecordList = pref64Query.getRecordList(QDSECTION);
+        if (qdRecordList.size() != 1) return;
+        if (!qdRecordList.get(0).dName.equals(PREF64_IPV4ONLY_HOSTNAME)) return;
+
+        // [2] Build prefix64 DNS discovery reply from received query.
+        // DNS response transaction id must be copied from DNS query. Used by the requester
+        // to match up replies to outstanding queries. See RFC 1035 section 4.1.1. Also reverse
+        // the source/destination address/port of query packet for building reply packet.
+        final ByteBuffer replyPacket;
+        try {
+            replyPacket = buildUdpDnsPrefix64ReplyPacket(pref64Query.getHeader().getId(),
+                    ipv6Header.dstIp /* srcIp */, ipv6Header.srcIp /* dstIp */,
+                    (short) udpHeader.dstPort /* srcPort */,
+                    (short) udpHeader.srcPort /* dstPort */);
+        } catch (Exception e) {
+            fail("Failed to build prefix64 discovery reply for " + ipv6Header.srcIp + ": " + e);
+            return;
+        }
+
+        Log.d(TAG, "Sending prefix64 discovery reply");
+        try {
+            sendDownloadPacket(replyPacket);
+        } catch (Exception e) {
+            fail("Failed to reply prefix64 discovery for " + ipv6Header.srcIp + ": " + e);
+        }
+    }
+
     private byte[] getUploadPacket(Predicate<byte[]> filter) {
         assertNotNull("Can't deal with upstream interface in local only mode", mUpstreamReader);
 
-        return mUpstreamReader.poll(PACKET_READ_TIMEOUT_MS, filter);
+        byte[] packet;
+        while ((packet = mUpstreamReader.poll(PACKET_READ_TIMEOUT_MS)) != null) {
+            if (filter.test(packet)) return packet;
+
+            maybeReplyUdpDnsPrefix64Discovery(packet);
+        }
+        return null;
     }
 
     private @NonNull byte[] verifyPacketNotNull(String message, @Nullable byte[] packet) {
diff --git a/Tethering/tests/integration/src/android/net/EthernetTetheringTest.java b/Tethering/tests/integration/src/android/net/EthernetTetheringTest.java
index 5d57aa5..eed308c 100644
--- a/Tethering/tests/integration/src/android/net/EthernetTetheringTest.java
+++ b/Tethering/tests/integration/src/android/net/EthernetTetheringTest.java
@@ -20,23 +20,17 @@
 import static android.net.TetheringManager.CONNECTIVITY_SCOPE_LOCAL;
 import static android.net.TetheringManager.TETHERING_ETHERNET;
 import static android.net.TetheringTester.TestDnsPacket;
+import static android.net.TetheringTester.buildIcmpEchoPacketV4;
+import static android.net.TetheringTester.buildUdpPacket;
 import static android.net.TetheringTester.isExpectedIcmpPacket;
 import static android.net.TetheringTester.isExpectedUdpDnsPacket;
 import static android.system.OsConstants.ICMP_ECHO;
 import static android.system.OsConstants.ICMP_ECHOREPLY;
-import static android.system.OsConstants.IPPROTO_ICMP;
 
 import static com.android.net.module.util.ConnectivityUtils.isIPv6ULA;
 import static com.android.net.module.util.HexDump.dumpHexString;
-import static com.android.net.module.util.IpUtils.icmpChecksum;
-import static com.android.net.module.util.IpUtils.ipChecksum;
-import static com.android.net.module.util.NetworkStackConstants.ETHER_TYPE_IPV4;
 import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ECHO_REPLY_TYPE;
 import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ECHO_REQUEST_TYPE;
-import static com.android.net.module.util.NetworkStackConstants.ICMP_CHECKSUM_OFFSET;
-import static com.android.net.module.util.NetworkStackConstants.IPV4_CHECKSUM_OFFSET;
-import static com.android.net.module.util.NetworkStackConstants.IPV4_HEADER_MIN_LEN;
-import static com.android.net.module.util.NetworkStackConstants.IPV4_LENGTH_OFFSET;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
@@ -53,14 +47,11 @@
 import android.util.Log;
 
 import androidx.annotation.NonNull;
-import androidx.annotation.Nullable;
 import androidx.test.filters.MediumTest;
 import androidx.test.runner.AndroidJUnit4;
 
 import com.android.net.module.util.Ipv6Utils;
 import com.android.net.module.util.Struct;
-import com.android.net.module.util.structs.EthernetHeader;
-import com.android.net.module.util.structs.Icmpv4Header;
 import com.android.net.module.util.structs.Ipv4Header;
 import com.android.net.module.util.structs.UdpHeader;
 import com.android.testutils.DevSdkIgnoreRule;
@@ -96,7 +87,6 @@
     private static final String TAG = EthernetTetheringTest.class.getSimpleName();
 
     private static final short DNS_PORT = 53;
-    private static final short ICMPECHO_CODE = 0x0;
     private static final short ICMPECHO_ID = 0x0;
     private static final short ICMPECHO_SEQ = 0x0;
 
@@ -564,85 +554,6 @@
         runClatUdpTest();
     }
 
-    // PacketBuilder doesn't support IPv4 ICMP packet. It may need to refactor PacketBuilder first
-    // because ICMP is a specific layer 3 protocol for PacketBuilder which expects packets always
-    // have layer 3 (IP) and layer 4 (TCP, UDP) for now. Since we don't use IPv4 ICMP packet too
-    // much in this test, we just write a ICMP packet builder here.
-    // TODO: move ICMPv4 packet build function to common utilis.
-    @NonNull
-    private ByteBuffer buildIcmpEchoPacketV4(
-            @Nullable final MacAddress srcMac, @Nullable final MacAddress dstMac,
-            @NonNull final Inet4Address srcIp, @NonNull final Inet4Address dstIp,
-            int type, short id, short seq) throws Exception {
-        if (type != ICMP_ECHO && type != ICMP_ECHOREPLY) {
-            fail("Unsupported ICMP type: " + type);
-        }
-
-        // Build ICMP echo id and seq fields as payload. Ignore the data field.
-        final ByteBuffer payload = ByteBuffer.allocate(4);
-        payload.putShort(id);
-        payload.putShort(seq);
-        payload.rewind();
-
-        final boolean hasEther = (srcMac != null && dstMac != null);
-        final int etherHeaderLen = hasEther ? Struct.getSize(EthernetHeader.class) : 0;
-        final int ipv4HeaderLen = Struct.getSize(Ipv4Header.class);
-        final int Icmpv4HeaderLen = Struct.getSize(Icmpv4Header.class);
-        final int payloadLen = payload.limit();
-        final ByteBuffer packet = ByteBuffer.allocate(etherHeaderLen + ipv4HeaderLen
-                + Icmpv4HeaderLen + payloadLen);
-
-        // [1] Ethernet header
-        if (hasEther) {
-            final EthernetHeader ethHeader = new EthernetHeader(dstMac, srcMac, ETHER_TYPE_IPV4);
-            ethHeader.writeToByteBuffer(packet);
-        }
-
-        // [2] IP header
-        final Ipv4Header ipv4Header = new Ipv4Header(TYPE_OF_SERVICE,
-                (short) 0 /* totalLength, calculate later */, ID,
-                FLAGS_AND_FRAGMENT_OFFSET, TIME_TO_LIVE, (byte) IPPROTO_ICMP,
-                (short) 0 /* checksum, calculate later */, srcIp, dstIp);
-        ipv4Header.writeToByteBuffer(packet);
-
-        // [3] ICMP header
-        final Icmpv4Header icmpv4Header = new Icmpv4Header((byte) type, ICMPECHO_CODE,
-                (short) 0 /* checksum, calculate later */);
-        icmpv4Header.writeToByteBuffer(packet);
-
-        // [4] Payload
-        packet.put(payload);
-        packet.flip();
-
-        // [5] Finalize packet
-        // Used for updating IP header fields. If there is Ehternet header, IPv4 header offset
-        // in buffer equals ethernet header length because IPv4 header is located next to ethernet
-        // header. Otherwise, IPv4 header offset is 0.
-        final int ipv4HeaderOffset = hasEther ? etherHeaderLen : 0;
-
-        // Populate the IPv4 totalLength field.
-        packet.putShort(ipv4HeaderOffset + IPV4_LENGTH_OFFSET,
-                (short) (ipv4HeaderLen + Icmpv4HeaderLen + payloadLen));
-
-        // Populate the IPv4 header checksum field.
-        packet.putShort(ipv4HeaderOffset + IPV4_CHECKSUM_OFFSET,
-                ipChecksum(packet, ipv4HeaderOffset /* headerOffset */));
-
-        // Populate the ICMP checksum field.
-        packet.putShort(ipv4HeaderOffset + IPV4_HEADER_MIN_LEN + ICMP_CHECKSUM_OFFSET,
-                icmpChecksum(packet, ipv4HeaderOffset + IPV4_HEADER_MIN_LEN,
-                        Icmpv4HeaderLen + payloadLen));
-        return packet;
-    }
-
-    @NonNull
-    private ByteBuffer buildIcmpEchoPacketV4(@NonNull final Inet4Address srcIp,
-            @NonNull final Inet4Address dstIp, int type, short id, short seq)
-            throws Exception {
-        return buildIcmpEchoPacketV4(null /* srcMac */, null /* dstMac */, srcIp, dstIp,
-                type, id, seq);
-    }
-
     @Test
     public void testIcmpv4Echo() throws Exception {
         final TetheringTester tester = initTetheringTester(toList(TEST_IP4_ADDR),
diff --git a/service-t/native/libs/libnetworkstats/NetworkTracePoller.cpp b/service-t/native/libs/libnetworkstats/NetworkTracePoller.cpp
index d538368..80c315a 100644
--- a/service-t/native/libs/libnetworkstats/NetworkTracePoller.cpp
+++ b/service-t/native/libs/libnetworkstats/NetworkTracePoller.cpp
@@ -29,16 +29,16 @@
 namespace bpf {
 namespace internal {
 
-void NetworkTracePoller::SchedulePolling() {
-  // Schedules another run of ourselves to recursively poll periodically.
-  mTaskRunner->PostDelayedTask(
-      [this]() {
-        mMutex.lock();
-        SchedulePolling();
-        ConsumeAllLocked();
-        mMutex.unlock();
-      },
-      mPollMs);
+void NetworkTracePoller::PollAndSchedule(perfetto::base::TaskRunner* runner,
+                                         uint32_t poll_ms) {
+  // Always schedule another run of ourselves to recursively poll periodically.
+  // The task runner is sequential so these can't run on top of each other.
+  runner->PostDelayedTask([=]() { PollAndSchedule(runner, poll_ms); }, poll_ms);
+
+  if (mMutex.try_lock()) {
+    ConsumeAllLocked();
+    mMutex.unlock();
+  }
 }
 
 bool NetworkTracePoller::Start(uint32_t pollMs) {
@@ -81,7 +81,7 @@
   // Start a task runner to run ConsumeAll every mPollMs milliseconds.
   mTaskRunner = perfetto::Platform::GetDefaultPlatform()->CreateTaskRunner({});
   mPollMs = pollMs;
-  SchedulePolling();
+  PollAndSchedule(mTaskRunner.get(), mPollMs);
 
   mSessionCount++;
   return true;
diff --git a/service-t/native/libs/libnetworkstats/include/netdbpf/NetworkTracePoller.h b/service-t/native/libs/libnetworkstats/include/netdbpf/NetworkTracePoller.h
index adde51e..8433934 100644
--- a/service-t/native/libs/libnetworkstats/include/netdbpf/NetworkTracePoller.h
+++ b/service-t/native/libs/libnetworkstats/include/netdbpf/NetworkTracePoller.h
@@ -53,7 +53,12 @@
   bool ConsumeAll() EXCLUDES(mMutex);
 
  private:
-  void SchedulePolling() REQUIRES(mMutex);
+  // Poll the ring buffer for new data and schedule another run of ourselves
+  // after poll_ms (essentially polling periodically until stopped). This takes
+  // in the runner and poll duration to prevent a hard requirement on the lock
+  // and thus a deadlock while resetting the TaskRunner. The runner pointer is
+  // always valid within tasks run by that runner.
+  void PollAndSchedule(perfetto::base::TaskRunner* runner, uint32_t poll_ms);
   bool ConsumeAllLocked() REQUIRES(mMutex);
 
   std::mutex mMutex;
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 25aa693..b06e9cb 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -188,8 +188,8 @@
      */
     private final HashMap<NsdServiceConnector, ClientInfo> mClients = new HashMap<>();
 
-    /* A map from unique id to client info */
-    private final SparseArray<ClientInfo> mIdToClientInfoMap= new SparseArray<>();
+    /* A map from transaction(unique) id to client info */
+    private final SparseArray<ClientInfo> mTransactionIdToClientInfoMap = new SparseArray<>();
 
     // Note this is not final to avoid depending on the Wi-Fi service starting before NsdService
     @Nullable
@@ -211,16 +211,16 @@
     private int mClientNumberId = 1;
 
     private static class MdnsListener implements MdnsServiceBrowserListener {
-        protected final int mClientId;
+        protected final int mClientRequestId;
         protected final int mTransactionId;
         @NonNull
         protected final NsdServiceInfo mReqServiceInfo;
         @NonNull
         protected final String mListenedServiceType;
 
-        MdnsListener(int clientId, int transactionId, @NonNull NsdServiceInfo reqServiceInfo,
+        MdnsListener(int clientRequestId, int transactionId, @NonNull NsdServiceInfo reqServiceInfo,
                 @NonNull String listenedServiceType) {
-            mClientId = clientId;
+            mClientRequestId = clientRequestId;
             mTransactionId = transactionId;
             mReqServiceInfo = reqServiceInfo;
             mListenedServiceType = listenedServiceType;
@@ -261,67 +261,67 @@
 
     private class DiscoveryListener extends MdnsListener {
 
-        DiscoveryListener(int clientId, int transactionId, @NonNull NsdServiceInfo reqServiceInfo,
-                @NonNull String listenServiceType) {
-            super(clientId, transactionId, reqServiceInfo, listenServiceType);
+        DiscoveryListener(int clientRequestId, int transactionId,
+                @NonNull NsdServiceInfo reqServiceInfo, @NonNull String listenServiceType) {
+            super(clientRequestId, transactionId, reqServiceInfo, listenServiceType);
         }
 
         @Override
         public void onServiceNameDiscovered(@NonNull MdnsServiceInfo serviceInfo) {
             mNsdStateMachine.sendMessage(MDNS_DISCOVERY_MANAGER_EVENT, mTransactionId,
                     NsdManager.SERVICE_FOUND,
-                    new MdnsEvent(mClientId, serviceInfo));
+                    new MdnsEvent(mClientRequestId, serviceInfo));
         }
 
         @Override
         public void onServiceNameRemoved(@NonNull MdnsServiceInfo serviceInfo) {
             mNsdStateMachine.sendMessage(MDNS_DISCOVERY_MANAGER_EVENT, mTransactionId,
                     NsdManager.SERVICE_LOST,
-                    new MdnsEvent(mClientId, serviceInfo));
+                    new MdnsEvent(mClientRequestId, serviceInfo));
         }
     }
 
     private class ResolutionListener extends MdnsListener {
 
-        ResolutionListener(int clientId, int transactionId, @NonNull NsdServiceInfo reqServiceInfo,
-                @NonNull String listenServiceType) {
-            super(clientId, transactionId, reqServiceInfo, listenServiceType);
+        ResolutionListener(int clientRequestId, int transactionId,
+                @NonNull NsdServiceInfo reqServiceInfo, @NonNull String listenServiceType) {
+            super(clientRequestId, transactionId, reqServiceInfo, listenServiceType);
         }
 
         @Override
         public void onServiceFound(MdnsServiceInfo serviceInfo) {
             mNsdStateMachine.sendMessage(MDNS_DISCOVERY_MANAGER_EVENT, mTransactionId,
                     NsdManager.RESOLVE_SERVICE_SUCCEEDED,
-                    new MdnsEvent(mClientId, serviceInfo));
+                    new MdnsEvent(mClientRequestId, serviceInfo));
         }
     }
 
     private class ServiceInfoListener extends MdnsListener {
 
-        ServiceInfoListener(int clientId, int transactionId, @NonNull NsdServiceInfo reqServiceInfo,
-                @NonNull String listenServiceType) {
-            super(clientId, transactionId, reqServiceInfo, listenServiceType);
+        ServiceInfoListener(int clientRequestId, int transactionId,
+                @NonNull NsdServiceInfo reqServiceInfo, @NonNull String listenServiceType) {
+            super(clientRequestId, transactionId, reqServiceInfo, listenServiceType);
         }
 
         @Override
         public void onServiceFound(@NonNull MdnsServiceInfo serviceInfo) {
             mNsdStateMachine.sendMessage(MDNS_DISCOVERY_MANAGER_EVENT, mTransactionId,
                     NsdManager.SERVICE_UPDATED,
-                    new MdnsEvent(mClientId, serviceInfo));
+                    new MdnsEvent(mClientRequestId, serviceInfo));
         }
 
         @Override
         public void onServiceUpdated(@NonNull MdnsServiceInfo serviceInfo) {
             mNsdStateMachine.sendMessage(MDNS_DISCOVERY_MANAGER_EVENT, mTransactionId,
                     NsdManager.SERVICE_UPDATED,
-                    new MdnsEvent(mClientId, serviceInfo));
+                    new MdnsEvent(mClientRequestId, serviceInfo));
         }
 
         @Override
         public void onServiceRemoved(@NonNull MdnsServiceInfo serviceInfo) {
             mNsdStateMachine.sendMessage(MDNS_DISCOVERY_MANAGER_EVENT, mTransactionId,
                     NsdManager.SERVICE_UPDATED_LOST,
-                    new MdnsEvent(mClientId, serviceInfo));
+                    new MdnsEvent(mClientRequestId, serviceInfo));
         }
     }
 
@@ -409,8 +409,8 @@
             // Return early if NSD is not active, or not on any relevant network
             return -1;
         }
-        for (int i = 0; i < mIdToClientInfoMap.size(); i++) {
-            final ClientInfo clientInfo = mIdToClientInfoMap.valueAt(i);
+        for (int i = 0; i < mTransactionIdToClientInfoMap.size(); i++) {
+            final ClientInfo clientInfo = mTransactionIdToClientInfoMap.valueAt(i);
             if (!mRunningAppActiveUids.contains(clientInfo.mUid)) {
                 // Ignore non-active UIDs
                 continue;
@@ -427,12 +427,12 @@
      * Data class of mdns service callback information.
      */
     private static class MdnsEvent {
-        final int mClientId;
+        final int mClientRequestId;
         @NonNull
         final MdnsServiceInfo mMdnsServiceInfo;
 
-        MdnsEvent(int clientId, @NonNull MdnsServiceInfo mdnsServiceInfo) {
-            mClientId = clientId;
+        MdnsEvent(int clientRequestId, @NonNull MdnsServiceInfo mdnsServiceInfo) {
+            mClientRequestId = clientRequestId;
             mMdnsServiceInfo = mdnsServiceInfo;
         }
     }
@@ -471,7 +471,7 @@
         }
 
         private boolean isAnyRequestActive() {
-            return mIdToClientInfoMap.size() != 0;
+            return mTransactionIdToClientInfoMap.size() != 0;
         }
 
         private void scheduleStop() {
@@ -520,7 +520,7 @@
             @Override
             public boolean processMessage(Message msg) {
                 final ClientInfo cInfo;
-                final int clientId = msg.arg2;
+                final int clientRequestId = msg.arg2;
                 switch (msg.what) {
                     case NsdManager.REGISTER_CLIENT:
                         final ConnectorArgs arg = (ConnectorArgs) msg.obj;
@@ -532,7 +532,8 @@
                                     mServiceLogs.forSubComponent(tag));
                             mClients.put(arg.connector, cInfo);
                         } catch (RemoteException e) {
-                            Log.w(TAG, "Client " + clientId + " has already died");
+                            Log.w(TAG, "Client request id " + clientRequestId
+                                    + " has already died");
                         }
                         break;
                     case NsdManager.UNREGISTER_CLIENT:
@@ -551,49 +552,49 @@
                         cInfo = getClientInfoForReply(msg);
                         if (cInfo != null) {
                             cInfo.onDiscoverServicesFailed(
-                                    clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                                    clientRequestId, NsdManager.FAILURE_INTERNAL_ERROR);
                         }
                        break;
                     case NsdManager.STOP_DISCOVERY:
                         cInfo = getClientInfoForReply(msg);
                         if (cInfo != null) {
                             cInfo.onStopDiscoveryFailed(
-                                    clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                                    clientRequestId, NsdManager.FAILURE_INTERNAL_ERROR);
                         }
                         break;
                     case NsdManager.REGISTER_SERVICE:
                         cInfo = getClientInfoForReply(msg);
                         if (cInfo != null) {
                             cInfo.onRegisterServiceFailed(
-                                    clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                                    clientRequestId, NsdManager.FAILURE_INTERNAL_ERROR);
                         }
                         break;
                     case NsdManager.UNREGISTER_SERVICE:
                         cInfo = getClientInfoForReply(msg);
                         if (cInfo != null) {
                             cInfo.onUnregisterServiceFailed(
-                                    clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                                    clientRequestId, NsdManager.FAILURE_INTERNAL_ERROR);
                         }
                         break;
                     case NsdManager.RESOLVE_SERVICE:
                         cInfo = getClientInfoForReply(msg);
                         if (cInfo != null) {
                             cInfo.onResolveServiceFailed(
-                                    clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                                    clientRequestId, NsdManager.FAILURE_INTERNAL_ERROR);
                         }
                         break;
                     case NsdManager.STOP_RESOLUTION:
                         cInfo = getClientInfoForReply(msg);
                         if (cInfo != null) {
                             cInfo.onStopResolutionFailed(
-                                    clientId, NsdManager.FAILURE_OPERATION_NOT_RUNNING);
+                                    clientRequestId, NsdManager.FAILURE_OPERATION_NOT_RUNNING);
                         }
                         break;
                     case NsdManager.REGISTER_SERVICE_CALLBACK:
                         cInfo = getClientInfoForReply(msg);
                         if (cInfo != null) {
                             cInfo.onServiceInfoCallbackRegistrationFailed(
-                                    clientId, NsdManager.FAILURE_BAD_PARAMETERS);
+                                    clientRequestId, NsdManager.FAILURE_BAD_PARAMETERS);
                         }
                         break;
                     case NsdManager.DAEMON_CLEANUP:
@@ -644,27 +645,29 @@
                 return false;
             }
 
-            private void storeLegacyRequestMap(int clientId, int globalId, ClientInfo clientInfo,
-                    int what) {
-                clientInfo.mClientRequests.put(clientId, new LegacyClientRequest(globalId, what));
-                mIdToClientInfoMap.put(globalId, clientInfo);
+            private void storeLegacyRequestMap(int clientRequestId, int transactionId,
+                    ClientInfo clientInfo, int what) {
+                clientInfo.mClientRequests.put(
+                        clientRequestId, new LegacyClientRequest(transactionId, what));
+                mTransactionIdToClientInfoMap.put(transactionId, clientInfo);
                 // Remove the cleanup event because here comes a new request.
                 cancelStop();
             }
 
-            private void storeAdvertiserRequestMap(int clientId, int globalId,
+            private void storeAdvertiserRequestMap(int clientRequestId, int transactionId,
                     ClientInfo clientInfo, @Nullable Network requestedNetwork) {
-                clientInfo.mClientRequests.put(clientId,
-                        new AdvertiserClientRequest(globalId, requestedNetwork));
-                mIdToClientInfoMap.put(globalId, clientInfo);
+                clientInfo.mClientRequests.put(clientRequestId,
+                        new AdvertiserClientRequest(transactionId, requestedNetwork));
+                mTransactionIdToClientInfoMap.put(transactionId, clientInfo);
                 updateMulticastLock();
             }
 
-            private void removeRequestMap(int clientId, int globalId, ClientInfo clientInfo) {
-                final ClientRequest existing = clientInfo.mClientRequests.get(clientId);
+            private void removeRequestMap(
+                    int clientRequestId, int transactionId, ClientInfo clientInfo) {
+                final ClientRequest existing = clientInfo.mClientRequests.get(clientRequestId);
                 if (existing == null) return;
-                clientInfo.mClientRequests.remove(clientId);
-                mIdToClientInfoMap.remove(globalId);
+                clientInfo.mClientRequests.remove(clientRequestId);
+                mTransactionIdToClientInfoMap.remove(transactionId);
 
                 if (existing instanceof LegacyClientRequest) {
                     maybeScheduleStop();
@@ -674,12 +677,12 @@
                 }
             }
 
-            private void storeDiscoveryManagerRequestMap(int clientId, int globalId,
+            private void storeDiscoveryManagerRequestMap(int clientRequestId, int transactionId,
                     MdnsListener listener, ClientInfo clientInfo,
                     @Nullable Network requestedNetwork) {
-                clientInfo.mClientRequests.put(clientId,
-                        new DiscoveryManagerRequest(globalId, listener, requestedNetwork));
-                mIdToClientInfoMap.put(globalId, clientInfo);
+                clientInfo.mClientRequests.put(clientRequestId,
+                        new DiscoveryManagerRequest(transactionId, listener, requestedNetwork));
+                mTransactionIdToClientInfoMap.put(transactionId, clientInfo);
                 updateMulticastLock();
             }
 
@@ -695,17 +698,17 @@
                 return MdnsUtils.truncateServiceName(originalName, MAX_LABEL_LENGTH);
             }
 
-            private void stopDiscoveryManagerRequest(ClientRequest request, int clientId, int id,
-                    ClientInfo clientInfo) {
+            private void stopDiscoveryManagerRequest(ClientRequest request, int clientRequestId,
+                    int transactionId, ClientInfo clientInfo) {
                 clientInfo.unregisterMdnsListenerFromRequest(request);
-                removeRequestMap(clientId, id, clientInfo);
+                removeRequestMap(clientRequestId, transactionId, clientInfo);
             }
 
             @Override
             public boolean processMessage(Message msg) {
                 final ClientInfo clientInfo;
-                final int id;
-                final int clientId = msg.arg2;
+                final int transactionId;
+                final int clientRequestId = msg.arg2;
                 final ListenerArgs args;
                 switch (msg.what) {
                     case NsdManager.DISCOVER_SERVICES: {
@@ -722,12 +725,12 @@
 
                         if (requestLimitReached(clientInfo)) {
                             clientInfo.onDiscoverServicesFailed(
-                                    clientId, NsdManager.FAILURE_MAX_LIMIT);
+                                    clientRequestId, NsdManager.FAILURE_MAX_LIMIT);
                             break;
                         }
 
                         final NsdServiceInfo info = args.serviceInfo;
-                        id = getUniqueId();
+                        transactionId = getUniqueId();
                         final Pair<String, String> typeAndSubtype =
                                 parseTypeAndSubtype(info.getServiceType());
                         final String serviceType = typeAndSubtype == null
@@ -736,15 +739,15 @@
                                 || mDeps.isMdnsDiscoveryManagerEnabled(mContext)
                                 || useDiscoveryManagerForType(serviceType)) {
                             if (serviceType == null) {
-                                clientInfo.onDiscoverServicesFailed(clientId,
+                                clientInfo.onDiscoverServicesFailed(clientRequestId,
                                         NsdManager.FAILURE_INTERNAL_ERROR);
                                 break;
                             }
 
                             final String listenServiceType = serviceType + ".local";
                             maybeStartMonitoringSockets();
-                            final MdnsListener listener =
-                                    new DiscoveryListener(clientId, id, info, listenServiceType);
+                            final MdnsListener listener = new DiscoveryListener(clientRequestId,
+                                    transactionId, info, listenServiceType);
                             final MdnsSearchOptions.Builder optionsBuilder =
                                     MdnsSearchOptions.newBuilder()
                                             .setNetwork(info.getNetwork())
@@ -757,23 +760,24 @@
                             }
                             mMdnsDiscoveryManager.registerListener(
                                     listenServiceType, listener, optionsBuilder.build());
-                            storeDiscoveryManagerRequestMap(clientId, id, listener, clientInfo,
-                                    info.getNetwork());
-                            clientInfo.onDiscoverServicesStarted(clientId, info);
-                            clientInfo.log("Register a DiscoveryListener " + id
+                            storeDiscoveryManagerRequestMap(clientRequestId, transactionId,
+                                    listener, clientInfo, info.getNetwork());
+                            clientInfo.onDiscoverServicesStarted(clientRequestId, info);
+                            clientInfo.log("Register a DiscoveryListener " + transactionId
                                     + " for service type:" + listenServiceType);
                         } else {
                             maybeStartDaemon();
-                            if (discoverServices(id, info)) {
+                            if (discoverServices(transactionId, info)) {
                                 if (DBG) {
-                                    Log.d(TAG, "Discover " + msg.arg2 + " " + id
+                                    Log.d(TAG, "Discover " + msg.arg2 + " " + transactionId
                                             + info.getServiceType());
                                 }
-                                storeLegacyRequestMap(clientId, id, clientInfo, msg.what);
-                                clientInfo.onDiscoverServicesStarted(clientId, info);
+                                storeLegacyRequestMap(
+                                        clientRequestId, transactionId, clientInfo, msg.what);
+                                clientInfo.onDiscoverServicesStarted(clientRequestId, info);
                             } else {
-                                stopServiceDiscovery(id);
-                                clientInfo.onDiscoverServicesFailed(clientId,
+                                stopServiceDiscovery(transactionId);
+                                clientInfo.onDiscoverServicesFailed(clientRequestId,
                                         NsdManager.FAILURE_INTERNAL_ERROR);
                             }
                         }
@@ -791,26 +795,28 @@
                             break;
                         }
 
-                        final ClientRequest request = clientInfo.mClientRequests.get(clientId);
+                        final ClientRequest request =
+                                clientInfo.mClientRequests.get(clientRequestId);
                         if (request == null) {
                             Log.e(TAG, "Unknown client request in STOP_DISCOVERY");
                             break;
                         }
-                        id = request.mGlobalId;
+                        transactionId = request.mTransactionId;
                         // Note isMdnsDiscoveryManagerEnabled may have changed to false at this
                         // point, so this needs to check the type of the original request to
                         // unregister instead of looking at the flag value.
                         if (request instanceof DiscoveryManagerRequest) {
-                            stopDiscoveryManagerRequest(request, clientId, id, clientInfo);
-                            clientInfo.onStopDiscoverySucceeded(clientId);
-                            clientInfo.log("Unregister the DiscoveryListener " + id);
+                            stopDiscoveryManagerRequest(
+                                    request, clientRequestId, transactionId, clientInfo);
+                            clientInfo.onStopDiscoverySucceeded(clientRequestId);
+                            clientInfo.log("Unregister the DiscoveryListener " + transactionId);
                         } else {
-                            removeRequestMap(clientId, id, clientInfo);
-                            if (stopServiceDiscovery(id)) {
-                                clientInfo.onStopDiscoverySucceeded(clientId);
+                            removeRequestMap(clientRequestId, transactionId, clientInfo);
+                            if (stopServiceDiscovery(transactionId)) {
+                                clientInfo.onStopDiscoverySucceeded(clientRequestId);
                             } else {
                                 clientInfo.onStopDiscoveryFailed(
-                                        clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                                        clientRequestId, NsdManager.FAILURE_INTERNAL_ERROR);
                             }
                         }
                         break;
@@ -829,11 +835,11 @@
 
                         if (requestLimitReached(clientInfo)) {
                             clientInfo.onRegisterServiceFailed(
-                                    clientId, NsdManager.FAILURE_MAX_LIMIT);
+                                    clientRequestId, NsdManager.FAILURE_MAX_LIMIT);
                             break;
                         }
 
-                        id = getUniqueId();
+                        transactionId = getUniqueId();
                         final NsdServiceInfo serviceInfo = args.serviceInfo;
                         final String serviceType = serviceInfo.getServiceType();
                         final Pair<String, String> typeSubtype = parseTypeAndSubtype(serviceType);
@@ -844,7 +850,7 @@
                                 || useAdvertiserForType(registerServiceType)) {
                             if (registerServiceType == null) {
                                 Log.e(TAG, "Invalid service type: " + serviceType);
-                                clientInfo.onRegisterServiceFailed(clientId,
+                                clientInfo.onRegisterServiceFailed(clientRequestId,
                                         NsdManager.FAILURE_INTERNAL_ERROR);
                                 break;
                             }
@@ -857,19 +863,23 @@
                             // service type would generate service instance names like
                             // Name._subtype._sub._type._tcp, which is incorrect
                             // (it should be Name._type._tcp).
-                            mAdvertiser.addService(id, serviceInfo, typeSubtype.second);
-                            storeAdvertiserRequestMap(clientId, id, clientInfo,
+                            mAdvertiser.addService(transactionId, serviceInfo, typeSubtype.second);
+                            storeAdvertiserRequestMap(clientRequestId, transactionId, clientInfo,
                                     serviceInfo.getNetwork());
                         } else {
                             maybeStartDaemon();
-                            if (registerService(id, serviceInfo)) {
-                                if (DBG) Log.d(TAG, "Register " + clientId + " " + id);
-                                storeLegacyRequestMap(clientId, id, clientInfo, msg.what);
+                            if (registerService(transactionId, serviceInfo)) {
+                                if (DBG) {
+                                    Log.d(TAG, "Register " + clientRequestId
+                                            + " " + transactionId);
+                                }
+                                storeLegacyRequestMap(
+                                        clientRequestId, transactionId, clientInfo, msg.what);
                                 // Return success after mDns reports success
                             } else {
-                                unregisterService(id);
+                                unregisterService(transactionId);
                                 clientInfo.onRegisterServiceFailed(
-                                        clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                                        clientRequestId, NsdManager.FAILURE_INTERNAL_ERROR);
                             }
 
                         }
@@ -886,26 +896,27 @@
                             Log.e(TAG, "Unknown connector in unregistration");
                             break;
                         }
-                        final ClientRequest request = clientInfo.mClientRequests.get(clientId);
+                        final ClientRequest request =
+                                clientInfo.mClientRequests.get(clientRequestId);
                         if (request == null) {
                             Log.e(TAG, "Unknown client request in UNREGISTER_SERVICE");
                             break;
                         }
-                        id = request.mGlobalId;
-                        removeRequestMap(clientId, id, clientInfo);
+                        transactionId = request.mTransactionId;
+                        removeRequestMap(clientRequestId, transactionId, clientInfo);
 
                         // Note isMdnsAdvertiserEnabled may have changed to false at this point,
                         // so this needs to check the type of the original request to unregister
                         // instead of looking at the flag value.
                         if (request instanceof AdvertiserClientRequest) {
-                            mAdvertiser.removeService(id);
-                            clientInfo.onUnregisterServiceSucceeded(clientId);
+                            mAdvertiser.removeService(transactionId);
+                            clientInfo.onUnregisterServiceSucceeded(clientRequestId);
                         } else {
-                            if (unregisterService(id)) {
-                                clientInfo.onUnregisterServiceSucceeded(clientId);
+                            if (unregisterService(transactionId)) {
+                                clientInfo.onUnregisterServiceSucceeded(clientRequestId);
                             } else {
                                 clientInfo.onUnregisterServiceFailed(
-                                        clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                                        clientRequestId, NsdManager.FAILURE_INTERNAL_ERROR);
                             }
                         }
                         break;
@@ -923,7 +934,7 @@
                         }
 
                         final NsdServiceInfo info = args.serviceInfo;
-                        id = getUniqueId();
+                        transactionId = getUniqueId();
                         final Pair<String, String> typeSubtype =
                                 parseTypeAndSubtype(info.getServiceType());
                         final String serviceType = typeSubtype == null
@@ -932,15 +943,15 @@
                                 ||  mDeps.isMdnsDiscoveryManagerEnabled(mContext)
                                 || useDiscoveryManagerForType(serviceType)) {
                             if (serviceType == null) {
-                                clientInfo.onResolveServiceFailed(clientId,
+                                clientInfo.onResolveServiceFailed(clientRequestId,
                                         NsdManager.FAILURE_INTERNAL_ERROR);
                                 break;
                             }
                             final String resolveServiceType = serviceType + ".local";
 
                             maybeStartMonitoringSockets();
-                            final MdnsListener listener =
-                                    new ResolutionListener(clientId, id, info, resolveServiceType);
+                            final MdnsListener listener = new ResolutionListener(clientRequestId,
+                                    transactionId, info, resolveServiceType);
                             final MdnsSearchOptions options = MdnsSearchOptions.newBuilder()
                                     .setNetwork(info.getNetwork())
                                     .setIsPassiveMode(true)
@@ -949,24 +960,25 @@
                                     .build();
                             mMdnsDiscoveryManager.registerListener(
                                     resolveServiceType, listener, options);
-                            storeDiscoveryManagerRequestMap(clientId, id, listener, clientInfo,
-                                    info.getNetwork());
-                            clientInfo.log("Register a ResolutionListener " + id
+                            storeDiscoveryManagerRequestMap(clientRequestId, transactionId,
+                                    listener, clientInfo, info.getNetwork());
+                            clientInfo.log("Register a ResolutionListener " + transactionId
                                     + " for service type:" + resolveServiceType);
                         } else {
                             if (clientInfo.mResolvedService != null) {
                                 clientInfo.onResolveServiceFailed(
-                                        clientId, NsdManager.FAILURE_ALREADY_ACTIVE);
+                                        clientRequestId, NsdManager.FAILURE_ALREADY_ACTIVE);
                                 break;
                             }
 
                             maybeStartDaemon();
-                            if (resolveService(id, info)) {
+                            if (resolveService(transactionId, info)) {
                                 clientInfo.mResolvedService = new NsdServiceInfo();
-                                storeLegacyRequestMap(clientId, id, clientInfo, msg.what);
+                                storeLegacyRequestMap(
+                                        clientRequestId, transactionId, clientInfo, msg.what);
                             } else {
                                 clientInfo.onResolveServiceFailed(
-                                        clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                                        clientRequestId, NsdManager.FAILURE_INTERNAL_ERROR);
                             }
                         }
                         break;
@@ -983,26 +995,28 @@
                             break;
                         }
 
-                        final ClientRequest request = clientInfo.mClientRequests.get(clientId);
+                        final ClientRequest request =
+                                clientInfo.mClientRequests.get(clientRequestId);
                         if (request == null) {
                             Log.e(TAG, "Unknown client request in STOP_RESOLUTION");
                             break;
                         }
-                        id = request.mGlobalId;
+                        transactionId = request.mTransactionId;
                         // Note isMdnsDiscoveryManagerEnabled may have changed to false at this
                         // point, so this needs to check the type of the original request to
                         // unregister instead of looking at the flag value.
                         if (request instanceof DiscoveryManagerRequest) {
-                            stopDiscoveryManagerRequest(request, clientId, id, clientInfo);
-                            clientInfo.onStopResolutionSucceeded(clientId);
-                            clientInfo.log("Unregister the ResolutionListener " + id);
+                            stopDiscoveryManagerRequest(
+                                    request, clientRequestId, transactionId, clientInfo);
+                            clientInfo.onStopResolutionSucceeded(clientRequestId);
+                            clientInfo.log("Unregister the ResolutionListener " + transactionId);
                         } else {
-                            removeRequestMap(clientId, id, clientInfo);
-                            if (stopResolveService(id)) {
-                                clientInfo.onStopResolutionSucceeded(clientId);
+                            removeRequestMap(clientRequestId, transactionId, clientInfo);
+                            if (stopResolveService(transactionId)) {
+                                clientInfo.onStopResolutionSucceeded(clientRequestId);
                             } else {
                                 clientInfo.onStopResolutionFailed(
-                                        clientId, NsdManager.FAILURE_OPERATION_NOT_RUNNING);
+                                        clientRequestId, NsdManager.FAILURE_OPERATION_NOT_RUNNING);
                             }
                             clientInfo.mResolvedService = null;
                         }
@@ -1021,21 +1035,21 @@
                         }
 
                         final NsdServiceInfo info = args.serviceInfo;
-                        id = getUniqueId();
+                        transactionId = getUniqueId();
                         final Pair<String, String> typeAndSubtype =
                                 parseTypeAndSubtype(info.getServiceType());
                         final String serviceType = typeAndSubtype == null
                                 ? null : typeAndSubtype.first;
                         if (serviceType == null) {
-                            clientInfo.onServiceInfoCallbackRegistrationFailed(clientId,
+                            clientInfo.onServiceInfoCallbackRegistrationFailed(clientRequestId,
                                     NsdManager.FAILURE_BAD_PARAMETERS);
                             break;
                         }
                         final String resolveServiceType = serviceType + ".local";
 
                         maybeStartMonitoringSockets();
-                        final MdnsListener listener =
-                                new ServiceInfoListener(clientId, id, info, resolveServiceType);
+                        final MdnsListener listener = new ServiceInfoListener(clientRequestId,
+                                transactionId, info, resolveServiceType);
                         final MdnsSearchOptions options = MdnsSearchOptions.newBuilder()
                                 .setNetwork(info.getNetwork())
                                 .setIsPassiveMode(true)
@@ -1044,9 +1058,9 @@
                                 .build();
                         mMdnsDiscoveryManager.registerListener(
                                 resolveServiceType, listener, options);
-                        storeDiscoveryManagerRequestMap(clientId, id, listener, clientInfo,
-                                info.getNetwork());
-                        clientInfo.log("Register a ServiceInfoListener " + id
+                        storeDiscoveryManagerRequestMap(clientRequestId, transactionId, listener,
+                                clientInfo, info.getNetwork());
+                        clientInfo.log("Register a ServiceInfoListener " + transactionId
                                 + " for service type:" + resolveServiceType);
                         break;
                     }
@@ -1062,16 +1076,18 @@
                             break;
                         }
 
-                        final ClientRequest request = clientInfo.mClientRequests.get(clientId);
+                        final ClientRequest request =
+                                clientInfo.mClientRequests.get(clientRequestId);
                         if (request == null) {
                             Log.e(TAG, "Unknown client request in UNREGISTER_SERVICE_CALLBACK");
                             break;
                         }
-                        id = request.mGlobalId;
+                        transactionId = request.mTransactionId;
                         if (request instanceof DiscoveryManagerRequest) {
-                            stopDiscoveryManagerRequest(request, clientId, id, clientInfo);
-                            clientInfo.onServiceInfoCallbackUnregistered(clientId);
-                            clientInfo.log("Unregister the ServiceInfoListener " + id);
+                            stopDiscoveryManagerRequest(
+                                    request, clientRequestId, transactionId, clientInfo);
+                            clientInfo.onServiceInfoCallbackUnregistered(clientRequestId);
+                            clientInfo.log("Unregister the ServiceInfoListener " + transactionId);
                         } else {
                             loge("Unregister failed with non-DiscoveryManagerRequest.");
                         }
@@ -1093,26 +1109,28 @@
                 return HANDLED;
             }
 
-            private boolean handleMDnsServiceEvent(int code, int id, Object obj) {
+            private boolean handleMDnsServiceEvent(int code, int transactionId, Object obj) {
                 NsdServiceInfo servInfo;
-                ClientInfo clientInfo = mIdToClientInfoMap.get(id);
+                ClientInfo clientInfo = mTransactionIdToClientInfoMap.get(transactionId);
                 if (clientInfo == null) {
-                    Log.e(TAG, String.format("id %d for %d has no client mapping", id, code));
+                    Log.e(TAG, String.format(
+                            "transactionId %d for %d has no client mapping", transactionId, code));
                     return false;
                 }
 
                 /* This goes in response as msg.arg2 */
-                int clientId = clientInfo.getClientId(id);
-                if (clientId < 0) {
+                int clientRequestId = clientInfo.getClientRequestId(transactionId);
+                if (clientRequestId < 0) {
                     // This can happen because of race conditions. For example,
                     // SERVICE_FOUND may race with STOP_SERVICE_DISCOVERY,
                     // and we may get in this situation.
-                    Log.d(TAG, String.format("%d for listener id %d that is no longer active",
-                            code, id));
+                    Log.d(TAG, String.format("%d for transactionId %d that is no longer active",
+                            code, transactionId));
                     return false;
                 }
                 if (DBG) {
-                    Log.d(TAG, String.format("MDns service event code:%d id=%d", code, id));
+                    Log.d(TAG, String.format(
+                            "MDns service event code:%d transactionId=%d", code, transactionId));
                 }
                 switch (code) {
                     case IMDnsEventListener.SERVICE_FOUND: {
@@ -1134,7 +1152,7 @@
                             break;
                         }
                         setServiceNetworkForCallback(servInfo, info.netId, info.interfaceIdx);
-                        clientInfo.onServiceFound(clientId, servInfo);
+                        clientInfo.onServiceFound(clientRequestId, servInfo);
                         break;
                     }
                     case IMDnsEventListener.SERVICE_LOST: {
@@ -1148,23 +1166,23 @@
                         // TODO: avoid returning null in that case, possibly by remembering
                         // found services on the same interface index and their network at the time
                         setServiceNetworkForCallback(servInfo, lostNetId, info.interfaceIdx);
-                        clientInfo.onServiceLost(clientId, servInfo);
+                        clientInfo.onServiceLost(clientRequestId, servInfo);
                         break;
                     }
                     case IMDnsEventListener.SERVICE_DISCOVERY_FAILED:
                         clientInfo.onDiscoverServicesFailed(
-                                clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                                clientRequestId, NsdManager.FAILURE_INTERNAL_ERROR);
                         break;
                     case IMDnsEventListener.SERVICE_REGISTERED: {
                         final RegistrationInfo info = (RegistrationInfo) obj;
                         final String name = info.serviceName;
                         servInfo = new NsdServiceInfo(name, null /* serviceType */);
-                        clientInfo.onRegisterServiceSucceeded(clientId, servInfo);
+                        clientInfo.onRegisterServiceSucceeded(clientRequestId, servInfo);
                         break;
                     }
                     case IMDnsEventListener.SERVICE_REGISTRATION_FAILED:
                         clientInfo.onRegisterServiceFailed(
-                                clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                                clientRequestId, NsdManager.FAILURE_INTERNAL_ERROR);
                         break;
                     case IMDnsEventListener.SERVICE_RESOLVED: {
                         final ResolutionInfo info = (ResolutionInfo) obj;
@@ -1192,34 +1210,34 @@
                         serviceInfo.setTxtRecords(info.txtRecord);
                         // Network will be added after SERVICE_GET_ADDR_SUCCESS
 
-                        stopResolveService(id);
-                        removeRequestMap(clientId, id, clientInfo);
+                        stopResolveService(transactionId);
+                        removeRequestMap(clientRequestId, transactionId, clientInfo);
 
-                        final int id2 = getUniqueId();
-                        if (getAddrInfo(id2, info.hostname, info.interfaceIdx)) {
-                            storeLegacyRequestMap(clientId, id2, clientInfo,
+                        final int transactionId2 = getUniqueId();
+                        if (getAddrInfo(transactionId2, info.hostname, info.interfaceIdx)) {
+                            storeLegacyRequestMap(clientRequestId, transactionId2, clientInfo,
                                     NsdManager.RESOLVE_SERVICE);
                         } else {
                             clientInfo.onResolveServiceFailed(
-                                    clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                                    clientRequestId, NsdManager.FAILURE_INTERNAL_ERROR);
                             clientInfo.mResolvedService = null;
                         }
                         break;
                     }
                     case IMDnsEventListener.SERVICE_RESOLUTION_FAILED:
                         /* NNN resolveId errorCode */
-                        stopResolveService(id);
-                        removeRequestMap(clientId, id, clientInfo);
+                        stopResolveService(transactionId);
+                        removeRequestMap(clientRequestId, transactionId, clientInfo);
                         clientInfo.onResolveServiceFailed(
-                                clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                                clientRequestId, NsdManager.FAILURE_INTERNAL_ERROR);
                         clientInfo.mResolvedService = null;
                         break;
                     case IMDnsEventListener.SERVICE_GET_ADDR_FAILED:
                         /* NNN resolveId errorCode */
-                        stopGetAddrInfo(id);
-                        removeRequestMap(clientId, id, clientInfo);
+                        stopGetAddrInfo(transactionId);
+                        removeRequestMap(clientRequestId, transactionId, clientInfo);
                         clientInfo.onResolveServiceFailed(
-                                clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                                clientRequestId, NsdManager.FAILURE_INTERNAL_ERROR);
                         clientInfo.mResolvedService = null;
                         break;
                     case IMDnsEventListener.SERVICE_GET_ADDR_SUCCESS: {
@@ -1242,13 +1260,13 @@
                             setServiceNetworkForCallback(clientInfo.mResolvedService,
                                     netId, info.interfaceIdx);
                             clientInfo.onResolveServiceSucceeded(
-                                    clientId, clientInfo.mResolvedService);
+                                    clientRequestId, clientInfo.mResolvedService);
                         } else {
                             clientInfo.onResolveServiceFailed(
-                                    clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                                    clientRequestId, NsdManager.FAILURE_INTERNAL_ERROR);
                         }
-                        stopGetAddrInfo(id);
-                        removeRequestMap(clientId, id, clientInfo);
+                        stopGetAddrInfo(transactionId);
+                        removeRequestMap(clientRequestId, transactionId, clientInfo);
                         clientInfo.mResolvedService = null;
                         break;
                     }
@@ -1305,7 +1323,7 @@
 
             private boolean handleMdnsDiscoveryManagerEvent(
                     int transactionId, int code, Object obj) {
-                final ClientInfo clientInfo = mIdToClientInfoMap.get(transactionId);
+                final ClientInfo clientInfo = mTransactionIdToClientInfoMap.get(transactionId);
                 if (clientInfo == null) {
                     Log.e(TAG, String.format(
                             "id %d for %d has no client mapping", transactionId, code));
@@ -1313,23 +1331,23 @@
                 }
 
                 final MdnsEvent event = (MdnsEvent) obj;
-                final int clientId = event.mClientId;
+                final int clientRequestId = event.mClientRequestId;
                 final NsdServiceInfo info = buildNsdServiceInfoFromMdnsEvent(event, code);
                 // Errors are already logged if null
                 if (info == null) return false;
-                if (DBG) {
-                    Log.d(TAG, String.format("MdnsDiscoveryManager event code=%s transactionId=%d",
-                            NsdManager.nameOf(code), transactionId));
-                }
+                mServiceLogs.log(String.format(
+                        "MdnsDiscoveryManager event code=%s transactionId=%d",
+                        NsdManager.nameOf(code), transactionId));
                 switch (code) {
                     case NsdManager.SERVICE_FOUND:
-                        clientInfo.onServiceFound(clientId, info);
+                        clientInfo.onServiceFound(clientRequestId, info);
                         break;
                     case NsdManager.SERVICE_LOST:
-                        clientInfo.onServiceLost(clientId, info);
+                        clientInfo.onServiceLost(clientRequestId, info);
                         break;
                     case NsdManager.RESOLVE_SERVICE_SUCCEEDED: {
-                        final ClientRequest request = clientInfo.mClientRequests.get(clientId);
+                        final ClientRequest request =
+                                clientInfo.mClientRequests.get(clientRequestId);
                         if (request == null) {
                             Log.e(TAG, "Unknown client request in RESOLVE_SERVICE_SUCCEEDED");
                             break;
@@ -1349,11 +1367,11 @@
                         final List<InetAddress> addresses = getInetAddresses(serviceInfo);
                         if (addresses.size() != 0) {
                             info.setHostAddresses(addresses);
-                            clientInfo.onResolveServiceSucceeded(clientId, info);
+                            clientInfo.onResolveServiceSucceeded(clientRequestId, info);
                         } else {
                             // No address. Notify resolution failure.
                             clientInfo.onResolveServiceFailed(
-                                    clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                                    clientRequestId, NsdManager.FAILURE_INTERNAL_ERROR);
                         }
 
                         // Unregister the listener immediately like IMDnsEventListener design
@@ -1361,7 +1379,8 @@
                             Log.wtf(TAG, "non-DiscoveryManager request in DiscoveryManager event");
                             break;
                         }
-                        stopDiscoveryManagerRequest(request, clientId, transactionId, clientInfo);
+                        stopDiscoveryManagerRequest(
+                                request, clientRequestId, transactionId, clientInfo);
                         break;
                     }
                     case NsdManager.SERVICE_UPDATED: {
@@ -1380,11 +1399,11 @@
 
                         final List<InetAddress> addresses = getInetAddresses(serviceInfo);
                         info.setHostAddresses(addresses);
-                        clientInfo.onServiceUpdated(clientId, info);
+                        clientInfo.onServiceUpdated(clientRequestId, info);
                         break;
                     }
                     case NsdManager.SERVICE_UPDATED_LOST:
-                        clientInfo.onServiceUpdatedLost(clientId);
+                        clientInfo.onServiceUpdatedLost(clientRequestId);
                         break;
                     default:
                         return false;
@@ -1721,44 +1740,46 @@
 
     private class AdvertiserCallback implements MdnsAdvertiser.AdvertiserCallback {
         @Override
-        public void onRegisterServiceSucceeded(int serviceId, NsdServiceInfo registeredInfo) {
-            final ClientInfo clientInfo = getClientInfoOrLog(serviceId);
+        public void onRegisterServiceSucceeded(int transactionId, NsdServiceInfo registeredInfo) {
+            mServiceLogs.log("onRegisterServiceSucceeded: transactionId " + transactionId);
+            final ClientInfo clientInfo = getClientInfoOrLog(transactionId);
             if (clientInfo == null) return;
 
-            final int clientId = getClientIdOrLog(clientInfo, serviceId);
-            if (clientId < 0) return;
+            final int clientRequestId = getClientRequestIdOrLog(clientInfo, transactionId);
+            if (clientRequestId < 0) return;
 
             // onRegisterServiceSucceeded only has the service name in its info. This aligns with
             // historical behavior.
             final NsdServiceInfo cbInfo = new NsdServiceInfo(registeredInfo.getServiceName(), null);
-            clientInfo.onRegisterServiceSucceeded(clientId, cbInfo);
+            clientInfo.onRegisterServiceSucceeded(clientRequestId, cbInfo);
         }
 
         @Override
-        public void onRegisterServiceFailed(int serviceId, int errorCode) {
-            final ClientInfo clientInfo = getClientInfoOrLog(serviceId);
+        public void onRegisterServiceFailed(int transactionId, int errorCode) {
+            final ClientInfo clientInfo = getClientInfoOrLog(transactionId);
             if (clientInfo == null) return;
 
-            final int clientId = getClientIdOrLog(clientInfo, serviceId);
-            if (clientId < 0) return;
+            final int clientRequestId = getClientRequestIdOrLog(clientInfo, transactionId);
+            if (clientRequestId < 0) return;
 
-            clientInfo.onRegisterServiceFailed(clientId, errorCode);
+            clientInfo.onRegisterServiceFailed(clientRequestId, errorCode);
         }
 
-        private ClientInfo getClientInfoOrLog(int serviceId) {
-            final ClientInfo clientInfo = mIdToClientInfoMap.get(serviceId);
+        private ClientInfo getClientInfoOrLog(int transactionId) {
+            final ClientInfo clientInfo = mTransactionIdToClientInfoMap.get(transactionId);
             if (clientInfo == null) {
-                Log.e(TAG, String.format("Callback for service %d has no client", serviceId));
+                Log.e(TAG, String.format("Callback for service %d has no client", transactionId));
             }
             return clientInfo;
         }
 
-        private int getClientIdOrLog(@NonNull ClientInfo info, int serviceId) {
-            final int clientId = info.getClientId(serviceId);
-            if (clientId < 0) {
-                Log.e(TAG, String.format("Client ID not found for service %d", serviceId));
+        private int getClientRequestIdOrLog(@NonNull ClientInfo info, int transactionId) {
+            final int clientRequestId = info.getClientRequestId(transactionId);
+            if (clientRequestId < 0) {
+                Log.e(TAG, String.format(
+                        "Client request ID not found for service %d", transactionId));
             }
-            return clientId;
+            return clientRequestId;
         }
     }
 
@@ -1879,9 +1900,9 @@
         return mUniqueId;
     }
 
-    private boolean registerService(int regId, NsdServiceInfo service) {
+    private boolean registerService(int transactionId, NsdServiceInfo service) {
         if (DBG) {
-            Log.d(TAG, "registerService: " + regId + " " + service);
+            Log.d(TAG, "registerService: " + transactionId + " " + service);
         }
         String name = service.getServiceName();
         String type = service.getServiceType();
@@ -1892,28 +1913,29 @@
             Log.e(TAG, "Interface to register service on not found");
             return false;
         }
-        return mMDnsManager.registerService(regId, name, type, port, textRecord, registerInterface);
+        return mMDnsManager.registerService(
+                transactionId, name, type, port, textRecord, registerInterface);
     }
 
-    private boolean unregisterService(int regId) {
-        return mMDnsManager.stopOperation(regId);
+    private boolean unregisterService(int transactionId) {
+        return mMDnsManager.stopOperation(transactionId);
     }
 
-    private boolean discoverServices(int discoveryId, NsdServiceInfo serviceInfo) {
+    private boolean discoverServices(int transactionId, NsdServiceInfo serviceInfo) {
         final String type = serviceInfo.getServiceType();
         final int discoverInterface = getNetworkInterfaceIndex(serviceInfo);
         if (serviceInfo.getNetwork() != null && discoverInterface == IFACE_IDX_ANY) {
             Log.e(TAG, "Interface to discover service on not found");
             return false;
         }
-        return mMDnsManager.discover(discoveryId, type, discoverInterface);
+        return mMDnsManager.discover(transactionId, type, discoverInterface);
     }
 
-    private boolean stopServiceDiscovery(int discoveryId) {
-        return mMDnsManager.stopOperation(discoveryId);
+    private boolean stopServiceDiscovery(int transactionId) {
+        return mMDnsManager.stopOperation(transactionId);
     }
 
-    private boolean resolveService(int resolveId, NsdServiceInfo service) {
+    private boolean resolveService(int transactionId, NsdServiceInfo service) {
         final String name = service.getServiceName();
         final String type = service.getServiceType();
         final int resolveInterface = getNetworkInterfaceIndex(service);
@@ -1921,7 +1943,7 @@
             Log.e(TAG, "Interface to resolve service on not found");
             return false;
         }
-        return mMDnsManager.resolve(resolveId, name, type, "local.", resolveInterface);
+        return mMDnsManager.resolve(transactionId, name, type, "local.", resolveInterface);
     }
 
     /**
@@ -1970,16 +1992,16 @@
         return iface.getIndex();
     }
 
-    private boolean stopResolveService(int resolveId) {
-        return mMDnsManager.stopOperation(resolveId);
+    private boolean stopResolveService(int transactionId) {
+        return mMDnsManager.stopOperation(transactionId);
     }
 
-    private boolean getAddrInfo(int resolveId, String hostname, int interfaceIdx) {
-        return mMDnsManager.getServiceAddress(resolveId, hostname, interfaceIdx);
+    private boolean getAddrInfo(int transactionId, String hostname, int interfaceIdx) {
+        return mMDnsManager.getServiceAddress(transactionId, hostname, interfaceIdx);
     }
 
-    private boolean stopGetAddrInfo(int resolveId) {
-        return mMDnsManager.stopOperation(resolveId);
+    private boolean stopGetAddrInfo(int transactionId) {
+        return mMDnsManager.stopOperation(transactionId);
     }
 
     @Override
@@ -1999,18 +2021,18 @@
     }
 
     private abstract static class ClientRequest {
-        private final int mGlobalId;
+        private final int mTransactionId;
 
-        private ClientRequest(int globalId) {
-            mGlobalId = globalId;
+        private ClientRequest(int transactionId) {
+            mTransactionId = transactionId;
         }
     }
 
     private static class LegacyClientRequest extends ClientRequest {
         private final int mRequestCode;
 
-        private LegacyClientRequest(int globalId, int requestCode) {
-            super(globalId);
+        private LegacyClientRequest(int transactionId, int requestCode) {
+            super(transactionId);
             mRequestCode = requestCode;
         }
     }
@@ -2019,8 +2041,8 @@
         @Nullable
         private final Network mRequestedNetwork;
 
-        private JavaBackendClientRequest(int globalId, @Nullable Network requestedNetwork) {
-            super(globalId);
+        private JavaBackendClientRequest(int transactionId, @Nullable Network requestedNetwork) {
+            super(transactionId);
             mRequestedNetwork = requestedNetwork;
         }
 
@@ -2031,8 +2053,8 @@
     }
 
     private static class AdvertiserClientRequest extends JavaBackendClientRequest {
-        private AdvertiserClientRequest(int globalId, @Nullable Network requestedNetwork) {
-            super(globalId, requestedNetwork);
+        private AdvertiserClientRequest(int transactionId, @Nullable Network requestedNetwork) {
+            super(transactionId, requestedNetwork);
         }
     }
 
@@ -2040,9 +2062,9 @@
         @NonNull
         private final MdnsListener mListener;
 
-        private DiscoveryManagerRequest(int globalId, @NonNull MdnsListener listener,
+        private DiscoveryManagerRequest(int transactionId, @NonNull MdnsListener listener,
                 @Nullable Network requestedNetwork) {
-            super(globalId, requestedNetwork);
+            super(transactionId, requestedNetwork);
             mListener = listener;
         }
     }
@@ -2055,7 +2077,7 @@
         /* Remembers a resolved service until getaddrinfo completes */
         private NsdServiceInfo mResolvedService;
 
-        /* A map from client-side ID (listenerKey) to the request */
+        /* A map from client request ID (listenerKey) to the request */
         private final SparseArray<ClientRequest> mClientRequests = new SparseArray<>();
 
         // The target SDK of this client < Build.VERSION_CODES.S
@@ -2083,10 +2105,10 @@
             sb.append("mUseJavaBackend ").append(mUseJavaBackend).append("\n");
             sb.append("mUid ").append(mUid).append("\n");
             for (int i = 0; i < mClientRequests.size(); i++) {
-                int clientID = mClientRequests.keyAt(i);
-                sb.append("clientId ")
-                        .append(clientID)
-                        .append(" mDnsId ").append(mClientRequests.valueAt(i).mGlobalId)
+                int clientRequestId = mClientRequests.keyAt(i);
+                sb.append("clientRequestId ")
+                        .append(clientRequestId)
+                        .append(" transactionId ").append(mClientRequests.valueAt(i).mTransactionId)
                         .append(" type ").append(
                                 mClientRequests.valueAt(i).getClass().getSimpleName())
                         .append("\n");
@@ -2115,13 +2137,14 @@
             mClientLogs.log("Client unregistered. expungeAllRequests!");
             // TODO: to keep handler responsive, do not clean all requests for that client at once.
             for (int i = 0; i < mClientRequests.size(); i++) {
-                final int clientId = mClientRequests.keyAt(i);
+                final int clientRequestId = mClientRequests.keyAt(i);
                 final ClientRequest request = mClientRequests.valueAt(i);
-                final int globalId = request.mGlobalId;
-                mIdToClientInfoMap.remove(globalId);
+                final int transactionId = request.mTransactionId;
+                mTransactionIdToClientInfoMap.remove(transactionId);
                 if (DBG) {
-                    Log.d(TAG, "Terminating client-ID " + clientId
-                            + " global-ID " + globalId + " type " + mClientRequests.get(clientId));
+                    Log.d(TAG, "Terminating clientRequestId " + clientRequestId
+                            + " transactionId " + transactionId
+                            + " type " + mClientRequests.get(clientRequestId));
                 }
 
                 if (request instanceof DiscoveryManagerRequest) {
@@ -2130,7 +2153,7 @@
                 }
 
                 if (request instanceof AdvertiserClientRequest) {
-                    mAdvertiser.removeService(globalId);
+                    mAdvertiser.removeService(transactionId);
                     continue;
                 }
 
@@ -2140,13 +2163,13 @@
 
                 switch (((LegacyClientRequest) request).mRequestCode) {
                     case NsdManager.DISCOVER_SERVICES:
-                        stopServiceDiscovery(globalId);
+                        stopServiceDiscovery(transactionId);
                         break;
                     case NsdManager.RESOLVE_SERVICE:
-                        stopResolveService(globalId);
+                        stopResolveService(transactionId);
                         break;
                     case NsdManager.REGISTER_SERVICE:
-                        unregisterService(globalId);
+                        unregisterService(transactionId);
                         break;
                     default:
                         break;
@@ -2175,12 +2198,11 @@
             return false;
         }
 
-        // mClientRequests is a sparse array of listener id -> ClientRequest.  For a given
-        // mDnsClient id, return the corresponding listener id.  mDnsClient id is also called a
-        // global id.
-        private int getClientId(final int globalId) {
+        // mClientRequests is a sparse array of client request id -> ClientRequest.  For a given
+        // transaction id, return the corresponding client request id.
+        private int getClientRequestId(final int transactionId) {
             for (int i = 0; i < mClientRequests.size(); i++) {
-                if (mClientRequests.valueAt(i).mGlobalId == globalId) {
+                if (mClientRequests.valueAt(i).mTransactionId == transactionId) {
                     return mClientRequests.keyAt(i);
                 }
             }
diff --git a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
index 2d5bb00..bd4ec20 100644
--- a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
+++ b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
@@ -18,7 +18,6 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
-import android.net.Network;
 import android.text.TextUtils;
 import android.util.Log;
 import android.util.Pair;
@@ -70,8 +69,8 @@
     private final List<String> subtypes;
     private final boolean expectUnicastResponse;
     private final int transactionId;
-    @Nullable
-    private final Network network;
+    @NonNull
+    private final SocketKey socketKey;
     private final boolean sendDiscoveryQueries;
     @NonNull
     private final List<MdnsResponse> servicesToResolve;
@@ -86,7 +85,7 @@
             @NonNull Collection<String> subtypes,
             boolean expectUnicastResponse,
             int transactionId,
-            @Nullable Network network,
+            @NonNull SocketKey socketKey,
             boolean onlyUseIpv6OnIpv6OnlyNetworks,
             boolean sendDiscoveryQueries,
             @NonNull Collection<MdnsResponse> servicesToResolve,
@@ -97,7 +96,7 @@
         this.subtypes = new ArrayList<>(subtypes);
         this.expectUnicastResponse = expectUnicastResponse;
         this.transactionId = transactionId;
-        this.network = network;
+        this.socketKey = socketKey;
         this.onlyUseIpv6OnIpv6OnlyNetworks = onlyUseIpv6OnIpv6OnlyNetworks;
         this.sendDiscoveryQueries = sendDiscoveryQueries;
         this.servicesToResolve = new ArrayList<>(servicesToResolve);
@@ -216,7 +215,7 @@
         if (expectUnicastResponse) {
             if (requestSender instanceof MdnsMultinetworkSocketClient) {
                 ((MdnsMultinetworkSocketClient) requestSender).sendPacketRequestingUnicastResponse(
-                        packet, network, onlyUseIpv6OnIpv6OnlyNetworks);
+                        packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
             } else {
                 requestSender.sendPacketRequestingUnicastResponse(
                         packet, onlyUseIpv6OnIpv6OnlyNetworks);
@@ -225,7 +224,7 @@
             if (requestSender instanceof MdnsMultinetworkSocketClient) {
                 ((MdnsMultinetworkSocketClient) requestSender)
                         .sendPacketRequestingMulticastResponse(
-                                packet, network, onlyUseIpv6OnIpv6OnlyNetworks);
+                                packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
             } else {
                 requestSender.sendPacketRequestingMulticastResponse(
                         packet, onlyUseIpv6OnIpv6OnlyNetworks);
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsConstants.java b/service-t/src/com/android/server/connectivity/mdns/MdnsConstants.java
index f0e1717..ce5f540 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsConstants.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsConstants.java
@@ -16,6 +16,8 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.internal.annotations.VisibleForTesting.Visibility.PACKAGE;
+
 import static java.nio.charset.StandardCharsets.UTF_8;
 
 import com.android.internal.annotations.VisibleForTesting;
@@ -25,7 +27,7 @@
 import java.nio.charset.Charset;
 
 /** mDNS-related constants. */
-@VisibleForTesting
+@VisibleForTesting(visibility = PACKAGE)
 public final class MdnsConstants {
     public static final int MDNS_PORT = 5353;
     // Flags word format is:
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index afad3b7..dfaec75 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -194,7 +194,7 @@
                     }
 
                     @Override
-                    public void onAllSocketsDestroyed(@NonNull SocketKey socketKey) {
+                    public void onSocketDestroyed(@NonNull SocketKey socketKey) {
                         ensureRunningOnHandlerThread(handler);
                         final MdnsServiceTypeClient serviceTypeClient =
                                 perSocketServiceTypeClients.get(serviceType, socketKey);
@@ -254,8 +254,7 @@
     private void handleOnResponseReceived(@NonNull MdnsPacket packet,
             @NonNull SocketKey socketKey) {
         for (MdnsServiceTypeClient serviceTypeClient : getMdnsServiceTypeClient(socketKey)) {
-            serviceTypeClient.processResponse(
-                    packet, socketKey.getInterfaceIndex(), socketKey.getNetwork());
+            serviceTypeClient.processResponse(packet, socketKey);
         }
     }
 
@@ -285,9 +284,11 @@
     MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType,
             @NonNull SocketKey socketKey) {
         sharedLog.log("createServiceTypeClient for type:" + serviceType + " " + socketKey);
+        final String tag = serviceType + "-" + socketKey.getNetwork()
+                + "/" + socketKey.getInterfaceIndex();
         return new MdnsServiceTypeClient(
                 serviceType, socketClient,
                 executorProvider.newServiceTypeClientSchedulerExecutor(), socketKey,
-                sharedLog.forSubComponent(serviceType + "-" + socketKey));
+                sharedLog.forSubComponent(tag), handler.getLooper());
     }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
index 1253444..d1fa57c 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -64,7 +64,7 @@
         @NonNull
         private final SocketCreationCallback mSocketCreationCallback;
         @NonNull
-        private final ArrayMap<MdnsInterfaceSocket, SocketKey> mActiveNetworkSockets =
+        private final ArrayMap<MdnsInterfaceSocket, SocketKey> mActiveSockets =
                 new ArrayMap<>();
 
         InterfaceSocketCallback(SocketCreationCallback socketCreationCallback) {
@@ -83,7 +83,7 @@
                 mSocketPacketHandlers.put(socket, handler);
             }
             socket.addPacketHandler(handler);
-            mActiveNetworkSockets.put(socket, socketKey);
+            mActiveSockets.put(socket, socketKey);
             mSocketCreationCallback.onSocketCreated(socketKey);
         }
 
@@ -95,16 +95,16 @@
         }
 
         private void notifySocketDestroyed(@NonNull MdnsInterfaceSocket socket) {
-            final SocketKey socketKey = mActiveNetworkSockets.remove(socket);
-            if (!isAnySocketActive(socketKey)) {
-                mSocketCreationCallback.onAllSocketsDestroyed(socketKey);
+            final SocketKey socketKey = mActiveSockets.remove(socket);
+            if (!isSocketActive(socket)) {
+                mSocketCreationCallback.onSocketDestroyed(socketKey);
             }
         }
 
         void onNetworkUnrequested() {
-            for (int i = mActiveNetworkSockets.size() - 1; i >= 0; i--) {
+            for (int i = mActiveSockets.size() - 1; i >= 0; i--) {
                 // Iterate from the end so the socket can be removed
-                final MdnsInterfaceSocket socket = mActiveNetworkSockets.keyAt(i);
+                final MdnsInterfaceSocket socket = mActiveSockets.keyAt(i);
                 notifySocketDestroyed(socket);
                 maybeCleanupPacketHandler(socket);
             }
@@ -114,17 +114,7 @@
     private boolean isSocketActive(@NonNull MdnsInterfaceSocket socket) {
         for (int i = 0; i < mRequestedNetworks.size(); i++) {
             final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
-            if (isc.mActiveNetworkSockets.containsKey(socket)) {
-                return true;
-            }
-        }
-        return false;
-    }
-
-    private boolean isAnySocketActive(@NonNull SocketKey socketKey) {
-        for (int i = 0; i < mRequestedNetworks.size(); i++) {
-            final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
-            if (isc.mActiveNetworkSockets.containsValue(socketKey)) {
+            if (isc.mActiveSockets.containsKey(socket)) {
                 return true;
             }
         }
@@ -135,7 +125,7 @@
         final ArrayMap<MdnsInterfaceSocket, SocketKey> sockets = new ArrayMap<>();
         for (int i = 0; i < mRequestedNetworks.size(); i++) {
             final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
-            sockets.putAll(isc.mActiveNetworkSockets);
+            sockets.putAll(isc.mActiveSockets);
         }
         return sockets;
     }
@@ -213,25 +203,22 @@
         return true;
     }
 
-    private void sendMdnsPacket(@NonNull DatagramPacket packet, @Nullable Network targetNetwork,
+    private void sendMdnsPacket(@NonNull DatagramPacket packet, @NonNull SocketKey targetSocketKey,
             boolean onlyUseIpv6OnIpv6OnlyNetworks) {
         final boolean isIpv6 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
                 instanceof Inet6Address;
         final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
                 instanceof Inet4Address;
         final ArrayMap<MdnsInterfaceSocket, SocketKey> activeSockets = getActiveSockets();
-        boolean shouldQueryIpv6 = !onlyUseIpv6OnIpv6OnlyNetworks || isIpv6OnlyNetworks(
-                activeSockets, targetNetwork);
+        boolean shouldQueryIpv6 = !onlyUseIpv6OnIpv6OnlyNetworks || isIpv6OnlySockets(
+                activeSockets, targetSocketKey);
         for (int i = 0; i < activeSockets.size(); i++) {
             final MdnsInterfaceSocket socket = activeSockets.keyAt(i);
-            final Network network = activeSockets.valueAt(i).getNetwork();
+            final SocketKey socketKey = activeSockets.valueAt(i);
             // Check ip capability and network before sending packet
             if (((isIpv6 && socket.hasJoinedIpv6() && shouldQueryIpv6)
                     || (isIpv4 && socket.hasJoinedIpv4()))
-                    // Contrary to MdnsUtils.isNetworkMatched, only send packets targeting
-                    // the null network to interfaces that have the null network (tethering
-                    // downstream interfaces).
-                    && Objects.equals(network, targetNetwork)) {
+                    && Objects.equals(socketKey, targetSocketKey)) {
                 try {
                     socket.send(packet);
                 } catch (IOException e) {
@@ -241,13 +228,13 @@
         }
     }
 
-    private boolean isIpv6OnlyNetworks(
+    private boolean isIpv6OnlySockets(
             @NonNull ArrayMap<MdnsInterfaceSocket, SocketKey> activeSockets,
-            @Nullable Network targetNetwork) {
+            @NonNull SocketKey targetSocketKey) {
         for (int i = 0; i < activeSockets.size(); i++) {
             final MdnsInterfaceSocket socket = activeSockets.keyAt(i);
-            final Network network = activeSockets.valueAt(i).getNetwork();
-            if (Objects.equals(network, targetNetwork) && socket.hasJoinedIpv4()) {
+            final SocketKey socketKey = activeSockets.valueAt(i);
+            if (Objects.equals(socketKey, targetSocketKey) && socket.hasJoinedIpv4()) {
                 return false;
             }
         }
@@ -276,38 +263,35 @@
     }
 
     /**
-     * Send a mDNS request packet via given network that asks for multicast response.
-     *
-     * <p>The socket client may use a null network to identify some or all interfaces, in which case
-     * passing null sends the packet to these.
+     * Send a mDNS request packet via given socket key that asks for multicast response.
      */
     public void sendPacketRequestingMulticastResponse(@NonNull DatagramPacket packet,
-            @Nullable Network network, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
-        mHandler.post(() -> sendMdnsPacket(packet, network, onlyUseIpv6OnIpv6OnlyNetworks));
+            @NonNull SocketKey socketKey, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+        mHandler.post(() -> sendMdnsPacket(packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks));
     }
 
     @Override
     public void sendPacketRequestingMulticastResponse(
             @NonNull DatagramPacket packet, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
-        sendPacketRequestingMulticastResponse(
-                packet, null /* network */, onlyUseIpv6OnIpv6OnlyNetworks);
+        throw new UnsupportedOperationException("This socket client need to specify the socket to"
+                + "send packet");
     }
 
     /**
-     * Send a mDNS request packet via given network that asks for unicast response.
+     * Send a mDNS request packet via given socket key that asks for unicast response.
      *
      * <p>The socket client may use a null network to identify some or all interfaces, in which case
      * passing null sends the packet to these.
      */
     public void sendPacketRequestingUnicastResponse(@NonNull DatagramPacket packet,
-            @Nullable Network network, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
-        mHandler.post(() -> sendMdnsPacket(packet, network, onlyUseIpv6OnIpv6OnlyNetworks));
+            @NonNull SocketKey socketKey, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+        mHandler.post(() -> sendMdnsPacket(packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks));
     }
 
     @Override
     public void sendPacketRequestingUnicastResponse(
             @NonNull DatagramPacket packet, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
-        sendPacketRequestingUnicastResponse(
-                packet, null /* network */, onlyUseIpv6OnIpv6OnlyNetworks);
+        throw new UnsupportedOperationException("This socket client need to specify the socket to"
+                + "send packet");
     }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java
index 98c80ee..f09596d 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java
@@ -50,7 +50,8 @@
                             source.readBoolean(),
                             source.readParcelable(null),
                             source.readString(),
-                            (source.dataAvail() > 0) ? source.readBoolean() : false);
+                            source.readBoolean(),
+                            source.readInt());
                 }
 
                 @Override
@@ -62,9 +63,9 @@
     private final List<String> subtypes;
     @Nullable
     private final String resolveInstanceName;
-
     private final boolean isPassiveMode;
     private final boolean onlyUseIpv6OnIpv6OnlyNetworks;
+    private final int numOfQueriesBeforeBackoff;
     private final boolean removeExpiredService;
     // The target network for searching. Null network means search on all possible interfaces.
     @Nullable private final Network mNetwork;
@@ -76,13 +77,15 @@
             boolean removeExpiredService,
             @Nullable Network network,
             @Nullable String resolveInstanceName,
-            boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+            boolean onlyUseIpv6OnIpv6OnlyNetworks,
+            int numOfQueriesBeforeBackoff) {
         this.subtypes = new ArrayList<>();
         if (subtypes != null) {
             this.subtypes.addAll(subtypes);
         }
         this.isPassiveMode = isPassiveMode;
         this.onlyUseIpv6OnIpv6OnlyNetworks = onlyUseIpv6OnIpv6OnlyNetworks;
+        this.numOfQueriesBeforeBackoff = numOfQueriesBeforeBackoff;
         this.removeExpiredService = removeExpiredService;
         mNetwork = network;
         this.resolveInstanceName = resolveInstanceName;
@@ -122,6 +125,14 @@
         return onlyUseIpv6OnIpv6OnlyNetworks;
     }
 
+    /**
+     *  Returns number of queries should be executed before backoff mode is enabled.
+     *  The default number is 3 if it is not set.
+     */
+    public int numOfQueriesBeforeBackoff() {
+        return numOfQueriesBeforeBackoff;
+    }
+
     /** Returns {@code true} if service will be removed after its TTL expires. */
     public boolean removeExpiredService() {
         return removeExpiredService;
@@ -159,6 +170,7 @@
         out.writeParcelable(mNetwork, 0);
         out.writeString(resolveInstanceName);
         out.writeBoolean(onlyUseIpv6OnIpv6OnlyNetworks);
+        out.writeInt(numOfQueriesBeforeBackoff);
     }
 
     /** A builder to create {@link MdnsSearchOptions}. */
@@ -166,6 +178,7 @@
         private final Set<String> subtypes;
         private boolean isPassiveMode = true;
         private boolean onlyUseIpv6OnIpv6OnlyNetworks = false;
+        private int numOfQueriesBeforeBackoff = 3;
         private boolean removeExpiredService;
         private Network mNetwork;
         private String resolveInstanceName;
@@ -219,6 +232,14 @@
         }
 
         /**
+         * Sets the number of queries to send before query backoff mode is enabled.
+         */
+        public Builder setNumOfQueriesBeforeBackoff(int numOfQueriesBeforeBackoff) {
+            this.numOfQueriesBeforeBackoff = numOfQueriesBeforeBackoff;
+            return this;
+        }
+
+        /**
          * Sets if the service should be removed after TTL.
          *
          * @param removeExpiredService If set to {@code true}, the service will be removed after TTL
@@ -258,7 +279,8 @@
                     removeExpiredService,
                     mNetwork,
                     resolveInstanceName,
-                    onlyUseIpv6OnIpv6OnlyNetworks);
+                    onlyUseIpv6OnIpv6OnlyNetworks,
+                    numOfQueriesBeforeBackoff);
         }
     }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
index cd0be67..dc99e49 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
@@ -22,7 +22,6 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
-import android.net.Network;
 import android.os.Handler;
 import android.os.Looper;
 import android.util.ArrayMap;
@@ -45,15 +44,15 @@
 public class MdnsServiceCache {
     private static class CacheKey {
         @NonNull final String mLowercaseServiceType;
-        @Nullable final Network mNetwork;
+        @NonNull final SocketKey mSocketKey;
 
-        CacheKey(@NonNull String serviceType, @Nullable Network network) {
+        CacheKey(@NonNull String serviceType, @NonNull SocketKey socketKey) {
             mLowercaseServiceType = toDnsLowerCase(serviceType);
-            mNetwork = network;
+            mSocketKey = socketKey;
         }
 
         @Override public int hashCode() {
-            return Objects.hash(mLowercaseServiceType, mNetwork);
+            return Objects.hash(mLowercaseServiceType, mSocketKey);
         }
 
         @Override public boolean equals(Object other) {
@@ -64,11 +63,11 @@
                 return false;
             }
             return Objects.equals(mLowercaseServiceType, ((CacheKey) other).mLowercaseServiceType)
-                    && Objects.equals(mNetwork, ((CacheKey) other).mNetwork);
+                    && Objects.equals(mSocketKey, ((CacheKey) other).mSocketKey);
         }
     }
     /**
-     * A map of cached services. Key is composed of service name, type and network. Value is the
+     * A map of cached services. Key is composed of service type and socket. Value is the
      * service which use the service type to discover from each socket.
      */
     @NonNull
@@ -81,17 +80,17 @@
     }
 
     /**
-     * Get the cache services which are queried from given service type and network.
+     * Get the cached services which are queried from the given service type and socket.
      *
      * @param serviceType the target service type.
-     * @param network the target network
+     * @param socketKey the target socket
      * @return the set of services which matches the given service type.
      */
     @NonNull
     public List<MdnsResponse> getCachedServices(@NonNull String serviceType,
-            @Nullable Network network) {
+            @NonNull SocketKey socketKey) {
         ensureRunningOnHandlerThread(mHandler);
-        final CacheKey key = new CacheKey(serviceType, network);
+        final CacheKey key = new CacheKey(serviceType, socketKey);
         return mCachedServices.containsKey(key)
                 ? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(key)))
                 : Collections.emptyList();
@@ -112,15 +111,15 @@
      *
      * @param serviceName the target service name.
      * @param serviceType the target service type.
-     * @param network the target network
+     * @param socketKey the target socket
      * @return the service which matches given conditions.
      */
     @Nullable
     public MdnsResponse getCachedService(@NonNull String serviceName,
-            @NonNull String serviceType, @Nullable Network network) {
+            @NonNull String serviceType, @NonNull SocketKey socketKey) {
         ensureRunningOnHandlerThread(mHandler);
         final List<MdnsResponse> responses =
-                mCachedServices.get(new CacheKey(serviceType, network));
+                mCachedServices.get(new CacheKey(serviceType, socketKey));
         if (responses == null) {
             return null;
         }
@@ -132,14 +131,14 @@
      * Add or update a service.
      *
      * @param serviceType the service type.
-     * @param network the target network
+     * @param socketKey the target socket
      * @param response the response of the discovered service.
      */
-    public void addOrUpdateService(@NonNull String serviceType, @Nullable Network network,
+    public void addOrUpdateService(@NonNull String serviceType, @NonNull SocketKey socketKey,
             @NonNull MdnsResponse response) {
         ensureRunningOnHandlerThread(mHandler);
         final List<MdnsResponse> responses = mCachedServices.computeIfAbsent(
-                new CacheKey(serviceType, network), key -> new ArrayList<>());
+                new CacheKey(serviceType, socketKey), key -> new ArrayList<>());
         // Remove existing service if present.
         final MdnsResponse existing =
                 findMatchedResponse(responses, response.getServiceInstanceName());
@@ -148,18 +147,18 @@
     }
 
     /**
-     * Remove a service which matches the given service name, type and network.
+     * Remove a service which matches the given service name, type and socket.
      *
      * @param serviceName the target service name.
      * @param serviceType the target service type.
-     * @param network the target network.
+     * @param socketKey the target socket.
      */
     @Nullable
     public MdnsResponse removeService(@NonNull String serviceName, @NonNull String serviceType,
-            @Nullable Network network) {
+            @NonNull SocketKey socketKey) {
         ensureRunningOnHandlerThread(mHandler);
         final List<MdnsResponse> responses =
-                mCachedServices.get(new CacheKey(serviceType, network));
+                mCachedServices.get(new CacheKey(serviceType, socketKey));
         if (responses == null) {
             return null;
         }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index a36eb1b..9c49b8f 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -16,11 +16,14 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.server.connectivity.mdns.util.MdnsUtils.ensureRunningOnHandlerThread;
+
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
-import android.net.Network;
+import android.os.Handler;
+import android.os.Looper;
 import android.text.TextUtils;
 import android.util.ArrayMap;
 import android.util.ArraySet;
@@ -51,6 +54,7 @@
  */
 public class MdnsServiceTypeClient {
 
+    private static final String TAG = MdnsServiceTypeClient.class.getSimpleName();
     private static final int DEFAULT_MTU = 1500;
 
     private final String serviceType;
@@ -60,10 +64,12 @@
     private final ScheduledExecutorService executor;
     @NonNull private final SocketKey socketKey;
     @NonNull private final SharedLog sharedLog;
+    @NonNull private final Handler handler;
     private final Object lock = new Object();
     private final ArrayMap<MdnsServiceBrowserListener, MdnsSearchOptions> listeners =
             new ArrayMap<>();
     // TODO: change instanceNameToResponse to TreeMap with case insensitive comparator.
+    @GuardedBy("lock")
     private final Map<String, MdnsResponse> instanceNameToResponse = new HashMap<>();
     private final boolean removeServiceAfterTtlExpires =
             MdnsConfigs.removeServiceAfterTtlExpires();
@@ -78,7 +84,14 @@
 
     @GuardedBy("lock")
     @Nullable
-    private Future<?> requestTaskFuture;
+    private Future<?> nextQueryTaskFuture;
+
+    @GuardedBy("lock")
+    @Nullable
+    private QueryTask lastScheduledTask;
+
+    @GuardedBy("lock")
+    private long lastSentTime;
 
     /**
      * Constructor of {@link MdnsServiceTypeClient}.
@@ -91,9 +104,10 @@
             @NonNull MdnsSocketClientBase socketClient,
             @NonNull ScheduledExecutorService executor,
             @NonNull SocketKey socketKey,
-            @NonNull SharedLog sharedLog) {
+            @NonNull SharedLog sharedLog,
+            @NonNull Looper looper) {
         this(serviceType, socketClient, executor, new MdnsResponseDecoder.Clock(), socketKey,
-                sharedLog);
+                sharedLog, looper);
     }
 
     @VisibleForTesting
@@ -103,7 +117,8 @@
             @NonNull ScheduledExecutorService executor,
             @NonNull MdnsResponseDecoder.Clock clock,
             @NonNull SocketKey socketKey,
-            @NonNull SharedLog sharedLog) {
+            @NonNull SharedLog sharedLog,
+            @NonNull Looper looper) {
         this.serviceType = serviceType;
         this.socketClient = socketClient;
         this.executor = executor;
@@ -112,6 +127,7 @@
         this.clock = clock;
         this.socketKey = socketKey;
         this.sharedLog = sharedLog;
+        this.handler = new Handler(looper);
     }
 
     private static MdnsServiceInfo buildMdnsServiceInfoFromResponse(
@@ -174,6 +190,7 @@
     public void startSendAndReceive(
             @NonNull MdnsServiceBrowserListener listener,
             @NonNull MdnsSearchOptions searchOptions) {
+        ensureRunningOnHandlerThread(handler);
         synchronized (lock) {
             this.searchOptions = searchOptions;
             boolean hadReply = false;
@@ -190,7 +207,7 @@
                 }
             }
             // Cancel the next scheduled periodical task.
-            if (requestTaskFuture != null) {
+            if (nextQueryTaskFuture != null) {
                 cancelRequestTaskLocked();
             }
             // Keep tracking the ScheduledFuture for the task so we can cancel it if caller is not
@@ -199,21 +216,40 @@
                     searchOptions.getSubtypes(),
                     searchOptions.isPassiveMode(),
                     searchOptions.onlyUseIpv6OnIpv6OnlyNetworks(),
-                    currentSessionId,
+                    searchOptions.numOfQueriesBeforeBackoff(),
                     socketKey);
+            final long now = clock.elapsedRealtime();
+            if (lastSentTime == 0) {
+                lastSentTime = now;
+            }
             if (hadReply) {
-                requestTaskFuture = scheduleNextRunLocked(taskConfig);
+                final QueryTaskConfig queryTaskConfig = taskConfig.getConfigForNextRun();
+                final long minRemainingTtl = getMinRemainingTtlLocked(now);
+                final long timeToRun = now + queryTaskConfig.delayUntilNextTaskWithoutBackoffMs;
+                nextQueryTaskFuture = scheduleNextRunLocked(queryTaskConfig,
+                        minRemainingTtl, now, timeToRun, currentSessionId);
             } else {
-                requestTaskFuture = executor.submit(new QueryTask(taskConfig));
+                lastScheduledTask = new QueryTask(taskConfig,
+                        now /* timeToRun */,
+                        now + getMinRemainingTtlLocked(now)/* minTtlExpirationTimeWhenScheduled */,
+                        currentSessionId);
+                nextQueryTaskFuture = executor.submit(lastScheduledTask);
             }
         }
     }
 
     @GuardedBy("lock")
     private void cancelRequestTaskLocked() {
-        requestTaskFuture.cancel(true);
+        final boolean canceled = nextQueryTaskFuture.cancel(true);
+        sharedLog.log("task canceled:" + canceled + ", current session: " + currentSessionId
+                + " task hashcode: " + getHexString(nextQueryTaskFuture));
         ++currentSessionId;
-        requestTaskFuture = null;
+        nextQueryTaskFuture = null;
+        lastScheduledTask = null;
+    }
+
+    private static String getHexString(Object o) {
+        return Integer.toHexString(System.identityHashCode(o));
     }
 
     private boolean responseMatchesOptions(@NonNull MdnsResponse response,
@@ -244,33 +280,29 @@
      * listener}. Otherwise returns {@code false}.
      */
     public boolean stopSendAndReceive(@NonNull MdnsServiceBrowserListener listener) {
+        ensureRunningOnHandlerThread(handler);
         synchronized (lock) {
             if (listeners.remove(listener) == null) {
                 return listeners.isEmpty();
             }
-            if (listeners.isEmpty() && requestTaskFuture != null) {
+            if (listeners.isEmpty() && nextQueryTaskFuture != null) {
                 cancelRequestTaskLocked();
             }
             return listeners.isEmpty();
         }
     }
 
-    public String[] getServiceTypeLabels() {
-        return serviceTypeLabels;
-    }
-
     /**
      * Process an incoming response packet.
      */
-    public synchronized void processResponse(@NonNull MdnsPacket packet, int interfaceIndex,
-            Network network) {
+    public synchronized void processResponse(@NonNull MdnsPacket packet,
+            @NonNull SocketKey socketKey) {
+        ensureRunningOnHandlerThread(handler);
         synchronized (lock) {
             // Augment the list of current known responses, and generated responses for resolve
             // requests if there is no known response
             final List<MdnsResponse> currentList = new ArrayList<>(instanceNameToResponse.values());
-
-            List<MdnsResponse> additionalResponses = makeResponsesForResolve(interfaceIndex,
-                    network);
+            List<MdnsResponse> additionalResponses = makeResponsesForResolve(socketKey);
             for (MdnsResponse additionalResponse : additionalResponses) {
                 if (!instanceNameToResponse.containsKey(
                         additionalResponse.getServiceInstanceName())) {
@@ -278,7 +310,8 @@
                 }
             }
             final Pair<ArraySet<MdnsResponse>, ArrayList<MdnsResponse>> augmentedResult =
-                    responseDecoder.augmentResponses(packet, currentList, interfaceIndex, network);
+                    responseDecoder.augmentResponses(packet, currentList,
+                            socketKey.getInterfaceIndex(), socketKey.getNetwork());
 
             final ArraySet<MdnsResponse> modifiedResponse = augmentedResult.first;
             final ArrayList<MdnsResponse> allResponses = augmentedResult.second;
@@ -286,9 +319,9 @@
             for (MdnsResponse response : allResponses) {
                 if (modifiedResponse.contains(response)) {
                     if (response.isGoodbye()) {
-                        onGoodbyeReceived(response.getServiceInstanceName());
+                        onGoodbyeReceivedLocked(response.getServiceInstanceName());
                     } else {
-                        onResponseModified(response);
+                        onResponseModifiedLocked(response);
                     }
                 } else if (instanceNameToResponse.containsKey(response.getServiceInstanceName())) {
                     // If the response is not modified and already in the cache. The cache will
@@ -296,10 +329,25 @@
                     instanceNameToResponse.put(response.getServiceInstanceName(), response);
                 }
             }
+            if (nextQueryTaskFuture != null && lastScheduledTask != null
+                    && lastScheduledTask.config.shouldUseQueryBackoff()) {
+                final long now = clock.elapsedRealtime();
+                final long minRemainingTtl = getMinRemainingTtlLocked(now);
+                final long timeToRun = calculateTimeToRun(lastScheduledTask,
+                        lastScheduledTask.config, now,
+                        minRemainingTtl, lastSentTime);
+                if (timeToRun > lastScheduledTask.timeToRun) {
+                    QueryTaskConfig lastTaskConfig = lastScheduledTask.config;
+                    cancelRequestTaskLocked();
+                    nextQueryTaskFuture = scheduleNextRunLocked(lastTaskConfig, minRemainingTtl,
+                            now, timeToRun, currentSessionId);
+                }
+            }
         }
     }
 
     public synchronized void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode) {
+        ensureRunningOnHandlerThread(handler);
         for (int i = 0; i < listeners.size(); i++) {
             listeners.keyAt(i).onFailedToParseMdnsResponse(receivedPacketNumber, errorCode);
         }
@@ -307,6 +355,7 @@
 
     /** Notify all services are removed because the socket is destroyed. */
     public void notifySocketDestroyed() {
+        ensureRunningOnHandlerThread(handler);
         synchronized (lock) {
             for (MdnsResponse response : instanceNameToResponse.values()) {
                 final String name = response.getServiceInstanceName();
@@ -325,13 +374,14 @@
                 }
             }
 
-            if (requestTaskFuture != null) {
+            if (nextQueryTaskFuture != null) {
                 cancelRequestTaskLocked();
             }
         }
     }
 
-    private void onResponseModified(@NonNull MdnsResponse response) {
+    @GuardedBy("lock")
+    private void onResponseModifiedLocked(@NonNull MdnsResponse response) {
         final String serviceInstanceName = response.getServiceInstanceName();
         final MdnsResponse currentResponse =
                 instanceNameToResponse.get(serviceInstanceName);
@@ -377,7 +427,8 @@
         }
     }
 
-    private void onGoodbyeReceived(@Nullable String serviceInstanceName) {
+    @GuardedBy("lock")
+    private void onGoodbyeReceivedLocked(@Nullable String serviceInstanceName) {
         final MdnsResponse response = instanceNameToResponse.remove(serviceInstanceName);
         if (response == null) {
             return;
@@ -429,32 +480,52 @@
                 MdnsConfigs.alwaysAskForUnicastResponseInEachBurst();
         private final boolean usePassiveMode;
         private final boolean onlyUseIpv6OnIpv6OnlyNetworks;
-        private final long sessionId;
+        private final int numOfQueriesBeforeBackoff;
         @VisibleForTesting
-        int transactionId;
+        final int transactionId;
         @VisibleForTesting
-        boolean expectUnicastResponse;
-        private int queriesPerBurst;
-        private int timeBetweenBurstsInMs;
-        private int burstCounter;
-        private int timeToRunNextTaskInMs;
-        private boolean isFirstBurst;
+        final boolean expectUnicastResponse;
+        private final int queriesPerBurst;
+        private final int timeBetweenBurstsInMs;
+        private final int burstCounter;
+        private final long delayUntilNextTaskWithoutBackoffMs;
+        private final boolean isFirstBurst;
+        private final long queryCount;
         @NonNull private final SocketKey socketKey;
 
+
+        QueryTaskConfig(@NonNull QueryTaskConfig other, long queryCount, int transactionId,
+                boolean expectUnicastResponse, boolean isFirstBurst, int burstCounter,
+                int queriesPerBurst, int timeBetweenBurstsInMs,
+                long delayUntilNextTaskWithoutBackoffMs) {
+            this.subtypes = new ArrayList<>(other.subtypes);
+            this.usePassiveMode = other.usePassiveMode;
+            this.onlyUseIpv6OnIpv6OnlyNetworks = other.onlyUseIpv6OnIpv6OnlyNetworks;
+            this.numOfQueriesBeforeBackoff = other.numOfQueriesBeforeBackoff;
+            this.transactionId = transactionId;
+            this.expectUnicastResponse = expectUnicastResponse;
+            this.queriesPerBurst = queriesPerBurst;
+            this.timeBetweenBurstsInMs = timeBetweenBurstsInMs;
+            this.burstCounter = burstCounter;
+            this.delayUntilNextTaskWithoutBackoffMs = delayUntilNextTaskWithoutBackoffMs;
+            this.isFirstBurst = isFirstBurst;
+            this.queryCount = queryCount;
+            this.socketKey = other.socketKey;
+        }
         QueryTaskConfig(@NonNull Collection<String> subtypes,
                 boolean usePassiveMode,
                 boolean onlyUseIpv6OnIpv6OnlyNetworks,
-                long sessionId,
+                int numOfQueriesBeforeBackoff,
                 @Nullable SocketKey socketKey) {
             this.usePassiveMode = usePassiveMode;
             this.onlyUseIpv6OnIpv6OnlyNetworks = onlyUseIpv6OnIpv6OnlyNetworks;
+            this.numOfQueriesBeforeBackoff = numOfQueriesBeforeBackoff;
             this.subtypes = new ArrayList<>(subtypes);
             this.queriesPerBurst = QUERIES_PER_BURST;
             this.burstCounter = 0;
             this.transactionId = 1;
             this.expectUnicastResponse = true;
             this.isFirstBurst = true;
-            this.sessionId = sessionId;
             // Config the scan frequency based on the scan mode.
             if (this.usePassiveMode) {
                 // In passive scan mode, sends a single burst of QUERIES_PER_BURST queries, and then
@@ -469,47 +540,65 @@
                 this.timeBetweenBurstsInMs = INITIAL_TIME_BETWEEN_BURSTS_MS;
             }
             this.socketKey = socketKey;
+            this.queryCount = 0;
+            this.delayUntilNextTaskWithoutBackoffMs = TIME_BETWEEN_QUERIES_IN_BURST_MS;
         }
 
         QueryTaskConfig getConfigForNextRun() {
-            if (++transactionId > UNSIGNED_SHORT_MAX_VALUE) {
-                transactionId = 1;
+            long newQueryCount = queryCount + 1;
+            int newTransactionId = transactionId + 1;
+            if (newTransactionId > UNSIGNED_SHORT_MAX_VALUE) {
+                newTransactionId = 1;
             }
+            boolean newExpectUnicastResponse = false;
+            boolean newIsFirstBurst = isFirstBurst;
+            int newQueriesPerBurst = queriesPerBurst;
+            int newBurstCounter = burstCounter + 1;
+            long newDelayUntilNextTaskWithoutBackoffMs = delayUntilNextTaskWithoutBackoffMs;
+            int newTimeBetweenBurstsInMs = timeBetweenBurstsInMs;
             // Only the first query expects uni-cast response.
-            expectUnicastResponse = false;
-            if (++burstCounter == queriesPerBurst) {
-                burstCounter = 0;
+            if (newBurstCounter == queriesPerBurst) {
+                newBurstCounter = 0;
 
                 if (alwaysAskForUnicastResponse) {
-                    expectUnicastResponse = true;
+                    newExpectUnicastResponse = true;
                 }
                 // In passive scan mode, sends a single burst of QUERIES_PER_BURST queries, and
                 // then in each TIME_BETWEEN_BURSTS interval, sends QUERIES_PER_BURST_PASSIVE_MODE
                 // queries.
                 if (isFirstBurst) {
-                    isFirstBurst = false;
+                    newIsFirstBurst = false;
                     if (usePassiveMode) {
-                        queriesPerBurst = QUERIES_PER_BURST_PASSIVE_MODE;
+                        newQueriesPerBurst = QUERIES_PER_BURST_PASSIVE_MODE;
                     }
                 }
                 // In active scan mode, sends a burst of QUERIES_PER_BURST queries,
                 // TIME_BETWEEN_QUERIES_IN_BURST_MS apart, then waits for the scan interval, and
                 // then repeats. The scan interval starts as INITIAL_TIME_BETWEEN_BURSTS_MS and
                 // doubles until it maxes out at TIME_BETWEEN_BURSTS_MS.
-                timeToRunNextTaskInMs = timeBetweenBurstsInMs;
+                newDelayUntilNextTaskWithoutBackoffMs = timeBetweenBurstsInMs;
                 if (timeBetweenBurstsInMs < TIME_BETWEEN_BURSTS_MS) {
-                    timeBetweenBurstsInMs = Math.min(timeBetweenBurstsInMs * 2,
+                    newTimeBetweenBurstsInMs = Math.min(timeBetweenBurstsInMs * 2,
                             TIME_BETWEEN_BURSTS_MS);
                 }
             } else {
-                timeToRunNextTaskInMs = TIME_BETWEEN_QUERIES_IN_BURST_MS;
+                newDelayUntilNextTaskWithoutBackoffMs = TIME_BETWEEN_QUERIES_IN_BURST_MS;
             }
-            return this;
+            return new QueryTaskConfig(this, newQueryCount, newTransactionId,
+                    newExpectUnicastResponse, newIsFirstBurst, newBurstCounter, newQueriesPerBurst,
+                    newTimeBetweenBurstsInMs, newDelayUntilNextTaskWithoutBackoffMs);
+        }
+
+        private boolean shouldUseQueryBackoff() {
+            // Don't enable backoff mode in the middle of a burst or during the first burst.
+            if (burstCounter != 0 || isFirstBurst) {
+                return false;
+            }
+            return queryCount > numOfQueriesBeforeBackoff;
         }
     }
 
-    private List<MdnsResponse> makeResponsesForResolve(int interfaceIndex,
-            @NonNull Network network) {
+    private List<MdnsResponse> makeResponsesForResolve(@NonNull SocketKey socketKey) {
         final List<MdnsResponse> resolveResponses = new ArrayList<>();
         for (int i = 0; i < listeners.size(); i++) {
             final String resolveName = listeners.valueAt(i).getResolveInstanceName();
@@ -524,7 +613,7 @@
                 instanceFullName.addAll(Arrays.asList(serviceTypeLabels));
                 knownResponse = new MdnsResponse(
                         0L /* lastUpdateTime */, instanceFullName.toArray(new String[0]),
-                        interfaceIndex, network);
+                        socketKey.getInterfaceIndex(), socketKey.getNetwork());
             }
             resolveResponses.add(knownResponse);
         }
@@ -535,9 +624,17 @@
     private class QueryTask implements Runnable {
 
         private final QueryTaskConfig config;
+        private final long timeToRun;
+        private final long minTtlExpirationTimeWhenScheduled;
+        private final long sessionId;
 
-        QueryTask(@NonNull QueryTaskConfig config) {
+        QueryTask(@NonNull QueryTaskConfig config, long timeToRun,
+                long minTtlExpirationTimeWhenScheduled,
+                long sessionId) {
             this.config = config;
+            this.timeToRun = timeToRun;
+            this.minTtlExpirationTimeWhenScheduled = minTtlExpirationTimeWhenScheduled;
+            this.sessionId = sessionId;
         }
 
         @Override
@@ -548,10 +645,7 @@
                 // The listener is requesting to resolve a service that has no info in
                 // cache. Use the provided name to generate a minimal response, so other records are
                 // queried to complete it.
-                // Only the names are used to know which queries to send, other parameters like
-                // interfaceIndex do not matter.
-                servicesToResolve = makeResponsesForResolve(
-                        0 /* interfaceIndex */, config.socketKey.getNetwork());
+                servicesToResolve = makeResponsesForResolve(config.socketKey);
                 sendDiscoveryQueries = servicesToResolve.size() < listeners.size();
             }
             Pair<Integer, List<String>> result;
@@ -564,7 +658,7 @@
                                 config.subtypes,
                                 config.expectUnicastResponse,
                                 config.transactionId,
-                                config.socketKey.getNetwork(),
+                                config.socketKey,
                                 config.onlyUseIpv6OnIpv6OnlyNetworks,
                                 sendDiscoveryQueries,
                                 servicesToResolve,
@@ -579,13 +673,13 @@
                 if (MdnsConfigs.useSessionIdToScheduleMdnsTask()) {
                     // In case that the task is not canceled successfully, use session ID to check
                     // if this task should continue to schedule more.
-                    if (config.sessionId != currentSessionId) {
+                    if (sessionId != currentSessionId) {
                         return;
                     }
                 }
 
                 if (MdnsConfigs.shouldCancelScanTaskWhenFutureIsNull()) {
-                    if (requestTaskFuture == null) {
+                    if (nextQueryTaskFuture == null) {
                         // If requestTaskFuture is set to null, the task is cancelled. We can't use
                         // isCancelled() here because this QueryTask is different from the future
                         // that is returned from executor.schedule(). See b/71646910.
@@ -630,14 +724,72 @@
                         }
                     }
                 }
-                requestTaskFuture = scheduleNextRunLocked(this.config);
+                QueryTaskConfig nextRunConfig = this.config.getConfigForNextRun();
+                final long now = clock.elapsedRealtime();
+                lastSentTime = now;
+                final long minRemainingTtl = getMinRemainingTtlLocked(now);
+                final long timeToRun = calculateTimeToRun(this, nextRunConfig, now,
+                        minRemainingTtl, lastSentTime);
+                nextQueryTaskFuture = scheduleNextRunLocked(nextRunConfig,
+                        minRemainingTtl, now, timeToRun, lastScheduledTask.sessionId);
             }
         }
     }
 
+    private static long calculateTimeToRun(@NonNull QueryTask lastScheduledTask,
+            QueryTaskConfig queryTaskConfig, long now, long minRemainingTtl, long lastSentTime) {
+        final long baseDelayInMs = queryTaskConfig.delayUntilNextTaskWithoutBackoffMs;
+        if (!queryTaskConfig.shouldUseQueryBackoff()) {
+            return lastSentTime + baseDelayInMs;
+        }
+        if (minRemainingTtl <= 0) {
+            // There's no service, or there is an expired service. In any case, schedule for the
+            // minimum time, which is the base delay.
+            return lastSentTime + baseDelayInMs;
+        }
+        // If the next TTL expiration time is unchanged, reuse the previously computed timeToRun.
+        if (lastSentTime < now
+                && lastScheduledTask.minTtlExpirationTimeWhenScheduled == now + minRemainingTtl) {
+            // Use the original scheduling time if the TTL has not changed, to avoid continuously
+            // rescheduling to 80% of the remaining TTL as time passes
+            return lastScheduledTask.timeToRun;
+        }
+        return Math.max(now + (long) (0.8 * minRemainingTtl), lastSentTime + baseDelayInMs);
+    }
+
+    @GuardedBy("lock")
+    private long getMinRemainingTtlLocked(long now) {
+        long minRemainingTtl = Long.MAX_VALUE;
+        for (MdnsResponse response : instanceNameToResponse.values()) {
+            if (!response.isComplete()) {
+                continue;
+            }
+            long remainingTtl =
+                    response.getServiceRecord().getRemainingTTL(now);
+            // A remainingTtl <= 0 means the service has expired.
+            if (remainingTtl <= 0) {
+                return 0;
+            }
+            if (remainingTtl < minRemainingTtl) {
+                minRemainingTtl = remainingTtl;
+            }
+        }
+        return minRemainingTtl == Long.MAX_VALUE ? 0 : minRemainingTtl;
+    }
+
+    @GuardedBy("lock")
     @NonNull
-    private Future<?> scheduleNextRunLocked(@NonNull QueryTaskConfig lastRunConfig) {
-        QueryTaskConfig config = lastRunConfig.getConfigForNextRun();
-        return executor.schedule(new QueryTask(config), config.timeToRunNextTaskInMs, MILLISECONDS);
+    private Future<?> scheduleNextRunLocked(@NonNull QueryTaskConfig nextRunConfig,
+            long minRemainingTtl,
+            long timeWhenScheduled, long timeToRun, long sessionId) {
+        lastScheduledTask = new QueryTask(nextRunConfig, timeToRun,
+                minRemainingTtl + timeWhenScheduled, sessionId);
+        // The timeWhenScheduled could be greater than the timeToRun if the Runnable is delayed.
+        long timeToNextTasksWithBackoffInMs = Math.max(timeToRun - timeWhenScheduled, 0);
+        sharedLog.log(
+                String.format("Next run: sessionId: %d, in %d ms", lastScheduledTask.sessionId,
+                        timeToNextTasksWithBackoffInMs));
+        return executor.schedule(lastScheduledTask, timeToNextTasksWithBackoffInMs,
+                MILLISECONDS);
     }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
index 5e4a8b5..b6000f0 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
@@ -82,6 +82,6 @@
         void onSocketCreated(@NonNull SocketKey socketKey);
 
         /*** Notify requested socket is destroyed */
-        void onAllSocketsDestroyed(@NonNull SocketKey socketKey);
+        void onSocketDestroyed(@NonNull SocketKey socketKey);
     }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
index 5ee2da1..6925b49 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
@@ -259,11 +259,6 @@
                 @NonNull final NetLinkMonitorCallBack cb) {
             return SocketNetLinkMonitorFactory.createNetLinkMonitor(handler, log, cb);
         }
-
-        /*** Get interface index by given socket */
-        public int getInterfaceIndex(@NonNull MdnsInterfaceSocket socket) {
-            return socket.getInterface().getIndex();
-        }
     }
     /**
      * The callback interface for the netlink monitor messages.
@@ -324,11 +319,14 @@
         final MdnsInterfaceSocket mSocket;
         final List<LinkAddress> mAddresses;
         final int[] mTransports;
+        @NonNull final SocketKey mSocketKey;
 
-        SocketInfo(MdnsInterfaceSocket socket, List<LinkAddress> addresses, int[] transports) {
+        SocketInfo(MdnsInterfaceSocket socket, List<LinkAddress> addresses, int[] transports,
+                @NonNull SocketKey socketKey) {
             mSocket = socket;
             mAddresses = new ArrayList<>(addresses);
             mTransports = transports;
+            mSocketKey = socketKey;
         }
     }
 
@@ -448,7 +446,7 @@
         // Try to join the group again.
         socketInfo.mSocket.joinGroup(addresses);
 
-        notifyAddressesChanged(network, socketInfo.mSocket, addresses);
+        notifyAddressesChanged(network, socketInfo, addresses);
     }
     private LinkProperties createLPForTetheredInterface(@NonNull final String interfaceName,
             int ifaceIndex) {
@@ -529,21 +527,22 @@
                     networkInterface.getNetworkInterface(), MdnsConstants.MDNS_PORT, mLooper,
                     mPacketReadBuffer);
             final List<LinkAddress> addresses = lp.getLinkAddresses();
+            final Network network =
+                    networkKey == LOCAL_NET ? null : ((NetworkAsKey) networkKey).mNetwork;
+            final SocketKey socketKey = new SocketKey(network, networkInterface.getIndex());
             // TODO: technically transport types are mutable, although generally not in ways that
             // would meaningfully impact the logic using it here. Consider updating logic to
             // support transports being added/removed.
-            final SocketInfo socketInfo = new SocketInfo(socket, addresses, transports);
+            final SocketInfo socketInfo = new SocketInfo(socket, addresses, transports, socketKey);
             if (networkKey == LOCAL_NET) {
                 mTetherInterfaceSockets.put(interfaceName, socketInfo);
             } else {
-                mNetworkSockets.put(((NetworkAsKey) networkKey).mNetwork, socketInfo);
+                mNetworkSockets.put(network, socketInfo);
             }
             // Try to join IPv4/IPv6 group.
             socket.joinGroup(addresses);
 
             // Notify the listeners which need this socket.
-            final Network network =
-                    networkKey == LOCAL_NET ? null : ((NetworkAsKey) networkKey).mNetwork;
             notifySocketCreated(network, socketInfo);
         } catch (IOException e) {
             mSharedLog.e("Create socket failed ifName:" + interfaceName, e);
@@ -585,7 +584,7 @@
         if (socketInfo == null) return;
 
         socketInfo.mSocket.destroy();
-        notifyInterfaceDestroyed(network, socketInfo.mSocket);
+        notifyInterfaceDestroyed(network, socketInfo);
         mSocketRequestMonitor.onSocketDestroyed(network, socketInfo.mSocket);
         mSharedLog.log("Remove socket on net:" + network);
     }
@@ -594,7 +593,7 @@
         final SocketInfo socketInfo = mTetherInterfaceSockets.remove(interfaceName);
         if (socketInfo == null) return;
         socketInfo.mSocket.destroy();
-        notifyInterfaceDestroyed(null /* network */, socketInfo.mSocket);
+        notifyInterfaceDestroyed(null /* network */, socketInfo);
         mSocketRequestMonitor.onSocketDestroyed(null /* network */, socketInfo.mSocket);
         mSharedLog.log("Remove socket on ifName:" + interfaceName);
     }
@@ -603,9 +602,7 @@
         for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
             final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
             if (isNetworkMatched(requestedNetwork, network)) {
-                final int ifaceIndex = mDependencies.getInterfaceIndex(socketInfo.mSocket);
-                final SocketKey socketKey = new SocketKey(network, ifaceIndex);
-                mCallbacksToRequestedNetworks.keyAt(i).onSocketCreated(socketKey,
+                mCallbacksToRequestedNetworks.keyAt(i).onSocketCreated(socketInfo.mSocketKey,
                         socketInfo.mSocket, socketInfo.mAddresses);
                 mSocketRequestMonitor.onSocketRequestFulfilled(network, socketInfo.mSocket,
                         socketInfo.mTransports);
@@ -613,25 +610,23 @@
         }
     }
 
-    private void notifyInterfaceDestroyed(Network network, MdnsInterfaceSocket socket) {
+    private void notifyInterfaceDestroyed(Network network, SocketInfo socketInfo) {
         for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
             final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
             if (isNetworkMatched(requestedNetwork, network)) {
-                final int ifaceIndex = mDependencies.getInterfaceIndex(socket);
                 mCallbacksToRequestedNetworks.keyAt(i)
-                        .onInterfaceDestroyed(new SocketKey(network, ifaceIndex), socket);
+                        .onInterfaceDestroyed(socketInfo.mSocketKey, socketInfo.mSocket);
             }
         }
     }
 
-    private void notifyAddressesChanged(Network network, MdnsInterfaceSocket socket,
+    private void notifyAddressesChanged(Network network, SocketInfo socketInfo,
             List<LinkAddress> addresses) {
         for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
             final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
             if (isNetworkMatched(requestedNetwork, network)) {
-                final int ifaceIndex = mDependencies.getInterfaceIndex(socket);
                 mCallbacksToRequestedNetworks.keyAt(i)
-                        .onAddressesChanged(new SocketKey(network, ifaceIndex), socket, addresses);
+                        .onAddressesChanged(socketInfo.mSocketKey, socketInfo.mSocket, addresses);
             }
         }
     }
@@ -648,9 +643,7 @@
             createSocket(new NetworkAsKey(network), lp);
         } else {
             // Notify the socket for requested network.
-            final int ifaceIndex = mDependencies.getInterfaceIndex(socketInfo.mSocket);
-            final SocketKey socketKey = new SocketKey(network, ifaceIndex);
-            cb.onSocketCreated(socketKey, socketInfo.mSocket, socketInfo.mAddresses);
+            cb.onSocketCreated(socketInfo.mSocketKey, socketInfo.mSocket, socketInfo.mAddresses);
             mSocketRequestMonitor.onSocketRequestFulfilled(network, socketInfo.mSocket,
                     socketInfo.mTransports);
         }
@@ -665,9 +658,7 @@
                     createLPForTetheredInterface(interfaceName, ifaceIndex));
         } else {
             // Notify the socket for requested network.
-            final int ifaceIndex = mDependencies.getInterfaceIndex(socketInfo.mSocket);
-            final SocketKey socketKey = new SocketKey(ifaceIndex);
-            cb.onSocketCreated(socketKey, socketInfo.mSocket, socketInfo.mAddresses);
+            cb.onSocketCreated(socketInfo.mSocketKey, socketInfo.mSocket, socketInfo.mAddresses);
             mSocketRequestMonitor.onSocketRequestFulfilled(null /* socketNetwork */,
                     socketInfo.mSocket, socketInfo.mTransports);
         }
diff --git a/service-t/src/com/android/server/connectivity/mdns/NetworkInterfaceWrapper.java b/service-t/src/com/android/server/connectivity/mdns/NetworkInterfaceWrapper.java
index 0ecae48..48c396e 100644
--- a/service-t/src/com/android/server/connectivity/mdns/NetworkInterfaceWrapper.java
+++ b/service-t/src/com/android/server/connectivity/mdns/NetworkInterfaceWrapper.java
@@ -57,6 +57,10 @@
         return networkInterface.getInterfaceAddresses();
     }
 
+    public int getIndex() {
+        return networkInterface.getIndex();
+    }
+
     @Override
     public String toString() {
         return networkInterface.toString();
diff --git a/service-t/src/com/android/server/connectivity/mdns/SocketKey.java b/service-t/src/com/android/server/connectivity/mdns/SocketKey.java
index a893acb..f13d0e0 100644
--- a/service-t/src/com/android/server/connectivity/mdns/SocketKey.java
+++ b/service-t/src/com/android/server/connectivity/mdns/SocketKey.java
@@ -43,6 +43,7 @@
         mInterfaceIndex = interfaceIndex;
     }
 
+    @Nullable
     public Network getNetwork() {
         return mNetwork;
     }
diff --git a/service/ServiceConnectivityResources/res/values/config.xml b/service/ServiceConnectivityResources/res/values/config.xml
index 22d9b01..f30abc6 100644
--- a/service/ServiceConnectivityResources/res/values/config.xml
+++ b/service/ServiceConnectivityResources/res/values/config.xml
@@ -135,10 +135,17 @@
     <!-- Whether to cancel network notifications automatically when tapped -->
     <bool name="config_autoCancelNetworkNotifications">true</bool>
 
-    <!-- When no internet or partial connectivity is detected on a network, and a high priority
-         (heads up) notification would be shown due to the network being explicitly selected,
-         directly show the dialog that would normally be shown when tapping the notification
-         instead of showing the notification. -->
+    <!-- Configuration to let OEMs customize what to do when:
+         • Partial connectivity is detected on the network, or
+         • No internet is detected on the network, and
+           - the network was explicitly selected
+           - the system is configured to actively prefer bad wifi (see config_activelyPreferBadWifi)
+         The default behavior (false) is to post a notification with a PendingIntent so
+         the user is informed and can act if they wish.
+         Setting this to true instead makes the system fire the intent immediately rather
+         than show a notification. OEMs who do this should have an intent receiver
+         listening for the intent that takes the action they prefer (e.g. show a dialog,
+         show a customized notification, etc.). -->
     <bool name="config_notifyNoInternetAsDialogWhenHighPriority">false</bool>
 
     <!-- When showing notifications indicating partial connectivity, display the same notifications
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 53127d1..83afa83 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -318,6 +318,7 @@
 import java.io.InterruptedIOException;
 import java.io.PrintWriter;
 import java.io.Writer;
+import java.lang.IllegalArgumentException;
 import java.net.Inet4Address;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
@@ -389,7 +390,7 @@
     // Timeout in case the "actively prefer bad wifi" feature is on
     private static final int ACTIVELY_PREFER_BAD_WIFI_INITIAL_TIMEOUT_MS = 20 * 1000;
     // Timeout in case the "actively prefer bad wifi" feature is off
-    private static final int DONT_ACTIVELY_PREFER_BAD_WIFI_INITIAL_TIMEOUT_MS = 8 * 1000;
+    private static final int DEFAULT_EVALUATION_TIMEOUT_MS = 8 * 1000;
 
     // Default to 30s linger time-out, and 5s for nascent network. Modifiable only for testing.
     private static final String LINGER_DELAY_PROPERTY = "persist.netmon.linger";
@@ -4010,7 +4011,7 @@
                     // the destroyed flag is only just above the "current satisfier wins"
                     // tie-breaker. But technically anything that affects scoring should rematch.
                     rematchAllNetworksAndRequests();
-                    mHandler.postDelayed(() -> disconnectAndDestroyNetwork(nai), timeoutMs);
+                    mHandler.postDelayed(() -> nai.disconnect(), timeoutMs);
                     break;
                 }
             }
@@ -4547,9 +4548,11 @@
 
     @VisibleForTesting
     protected static boolean shouldCreateNetworksImmediately() {
-        // Before U, physical networks are only created when the agent advances to CONNECTED.
-        // In U and above, all networks are immediately created when the agent is registered.
-        return SdkLevel.isAtLeastU();
+        // The feature of creating networks immediately was slated for U, but race conditions
+        // detected late required that it be flagged off.
+        // TODO: enable this in a Mainline update or in V, and re-enable the test for this
+        // in NetworkAgentTest.
+        return false;
     }
 
     private static boolean shouldCreateNativeNetwork(@NonNull NetworkAgentInfo nai,
@@ -4609,9 +4612,6 @@
         if (DBG) {
             log(nai.toShortString() + " disconnected, was satisfying " + nai.numNetworkRequests());
         }
-
-        nai.disconnect();
-
         // Clear all notifications of this network.
         mNotifier.clearNotification(nai.network.getNetId());
         // A network agent has disconnected.
@@ -5897,7 +5897,7 @@
                     final NetworkAgentInfo nai = getNetworkAgentInfoForNetwork((Network) msg.obj);
                     if (nai == null) break;
                     nai.onPreventAutomaticReconnect();
-                    disconnectAndDestroyNetwork(nai);
+                    nai.disconnect();
                     break;
                 case EVENT_SET_VPN_NETWORK_PREFERENCE:
                     handleSetVpnNetworkPreference((VpnNetworkPreferenceInfo) msg.obj);
@@ -9042,7 +9042,7 @@
                 break;
             }
         }
-        disconnectAndDestroyNetwork(nai);
+        nai.disconnect();
     }
 
     private void handleLingerComplete(NetworkAgentInfo oldNetwork) {
@@ -9584,10 +9584,7 @@
         updateLegacyTypeTrackerAndVpnLockdownForRematch(changes, nais);
 
         // Tear down all unneeded networks.
-        // Iterate in reverse order because teardownUnneededNetwork removes the nai from
-        // mNetworkAgentInfos.
-        for (int i = mNetworkAgentInfos.size() - 1; i >= 0; i--) {
-            final NetworkAgentInfo nai = mNetworkAgentInfos.valueAt(i);
+        for (NetworkAgentInfo nai : mNetworkAgentInfos) {
             if (unneeded(nai, UnneededFor.TEARDOWN)) {
                 if (nai.getInactivityExpiry() > 0) {
                     // This network has active linger timers and no requests, but is not
@@ -9940,10 +9937,25 @@
                 networkAgent.networkMonitor().notifyNetworkConnected(params.linkProperties,
                         params.networkCapabilities);
             }
-            final long delay = !avoidBadWifi() && activelyPreferBadWifi()
-                    ? ACTIVELY_PREFER_BAD_WIFI_INITIAL_TIMEOUT_MS
-                    : DONT_ACTIVELY_PREFER_BAD_WIFI_INITIAL_TIMEOUT_MS;
-            scheduleEvaluationTimeout(networkAgent.network, delay);
+            final long evaluationDelay;
+            if (!networkAgent.networkCapabilities.hasSingleTransport(TRANSPORT_WIFI)) {
+                // If the network is anything other than pure wifi, use the default timeout.
+                evaluationDelay = DEFAULT_EVALUATION_TIMEOUT_MS;
+            } else if (networkAgent.networkAgentConfig.isExplicitlySelected()) {
+                // If the network is explicitly selected, use the default timeout because it's
+                // shorter and the user is likely staring at the screen expecting it to validate
+                // right away.
+                evaluationDelay = DEFAULT_EVALUATION_TIMEOUT_MS;
+            } else if (avoidBadWifi() || !activelyPreferBadWifi()) {
+                // Bad wifi is avoided, or neither avoided nor actively preferred: use the default.
+                evaluationDelay = DEFAULT_EVALUATION_TIMEOUT_MS;
+            } else {
+                // It's wifi, automatically connected, and bad wifi is actively preferred: use
+                // the longer timeout to avoid the device switching to captive portals that have
+                // a bad signal or a very slow response.
+                evaluationDelay = ACTIVELY_PREFER_BAD_WIFI_INITIAL_TIMEOUT_MS;
+            }
+            scheduleEvaluationTimeout(networkAgent.network, evaluationDelay);
 
             // Whether a particular NetworkRequest listen should cause signal strength thresholds to
             // be communicated to a particular NetworkAgent depends only on the network's immutable,
@@ -9970,6 +9982,7 @@
             // This has to happen after matching the requests, because callbacks are just requests.
             notifyNetworkCallbacks(networkAgent, ConnectivityManager.CALLBACK_PRECHECK);
         } else if (state == NetworkInfo.State.DISCONNECTED) {
+            networkAgent.disconnect();
             if (networkAgent.isVPN()) {
                 updateVpnUids(networkAgent, networkAgent.networkCapabilities, null);
             }
diff --git a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
index 6ba2033..368860e 100644
--- a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
@@ -291,6 +291,18 @@
             }
         }
 
+        /**
+         * Construct a new AutomaticOnOffKeepalive from existing AutomaticOnOffKeepalive with a
+         * new KeepaliveInfo.
+         */
+        public AutomaticOnOffKeepalive withKeepaliveInfo(KeepaliveTracker.KeepaliveInfo ki)
+                throws InvalidSocketException {
+            return new AutomaticOnOffKeepalive(
+                    ki,
+                    mAutomaticOnOffState != STATE_ALWAYS_ON /* autoOnOff */,
+                    mUnderpinnedNetwork);
+        }
+
         @Override
         public String toString() {
             return "AutomaticOnOffKeepalive [ "
@@ -470,13 +482,26 @@
      * The message is expected to contain a KeepaliveTracker.KeepaliveInfo.
      */
     public void handleStartKeepalive(Message message) {
-        final AutomaticOnOffKeepalive autoKi = (AutomaticOnOffKeepalive) message.obj;
-        final int error = mKeepaliveTracker.handleStartKeepalive(autoKi.mKi);
+        final AutomaticOnOffKeepalive target = (AutomaticOnOffKeepalive) message.obj;
+        final Pair<Integer, KeepaliveTracker.KeepaliveInfo> res =
+                mKeepaliveTracker.handleStartKeepalive(target.mKi);
+        final int error = res.first;
         if (error != SUCCESS) {
-            mEventLog.log("Failed to start keepalive " + autoKi.mCallback + " on "
-                    + autoKi.getNetwork() + " with error " + error);
+            mEventLog.log("Failed to start keepalive " + target.mCallback + " on "
+                    + target.getNetwork() + " with error " + error);
             return;
         }
+        // Generate a new auto ki with the started keepalive info.
+        final AutomaticOnOffKeepalive autoKi;
+        try {
+            autoKi = target.withKeepaliveInfo(res.second);
+            // Close the duplicated fd.
+            target.close();
+        } catch (InvalidSocketException e) {
+            Log.wtf(TAG, "Fail to create AutomaticOnOffKeepalive", e);
+            return;
+        }
+
         mEventLog.log("Start keepalive " + autoKi.mCallback + " on " + autoKi.getNetwork());
         mKeepaliveStatsTracker.onStartKeepalive(
                 autoKi.getNetwork(),
@@ -506,14 +531,19 @@
      * @return SUCCESS if the keepalive is successfully starting and the error reason otherwise.
      */
     private int handleResumeKeepalive(@NonNull final KeepaliveTracker.KeepaliveInfo ki) {
-        final int error = mKeepaliveTracker.handleStartKeepalive(ki);
+        final Pair<Integer, KeepaliveTracker.KeepaliveInfo> res =
+                mKeepaliveTracker.handleStartKeepalive(ki);
+        final KeepaliveTracker.KeepaliveInfo startedKi = res.second;
+        final int error = res.first;
         if (error != SUCCESS) {
-            mEventLog.log("Failed to resume keepalive " + ki.mCallback + " on " + ki.mNai
-                    + " with error " + error);
+            mEventLog.log("Failed to resume keepalive " + startedKi.mCallback + " on "
+                    + startedKi.mNai + " with error " + error);
             return error;
         }
-        mKeepaliveStatsTracker.onResumeKeepalive(ki.getNai().network(), ki.getSlot());
-        mEventLog.log("Resumed successfully keepalive " + ki.mCallback + " on " + ki.mNai);
+
+        mKeepaliveStatsTracker.onResumeKeepalive(startedKi.getNai().network(), startedKi.getSlot());
+        mEventLog.log("Resumed successfully keepalive " + startedKi.mCallback
+                + " on " + startedKi.mNai);
 
         return SUCCESS;
     }
diff --git a/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java b/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java
index d59d526..414aca3 100644
--- a/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java
+++ b/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java
@@ -45,6 +45,7 @@
 import com.android.metrics.KeepaliveLifetimeForCarrier;
 import com.android.metrics.KeepaliveLifetimePerCarrier;
 import com.android.modules.utils.BackgroundThread;
+import com.android.modules.utils.build.SdkLevel;
 import com.android.net.module.util.CollectionUtils;
 import com.android.server.ConnectivityStatsLog;
 
@@ -251,6 +252,22 @@
         public long getElapsedRealtime() {
             return SystemClock.elapsedRealtime();
         }
+
+        /**
+         * Writes a DAILY_KEEPALIVE_INFO_REPORTED atom to ConnectivityStatsLog.
+         *
+         * @param dailyKeepaliveInfoReported the proto to write to statsD.
+         */
+        public void writeStats(DailykeepaliveInfoReported dailyKeepaliveInfoReported) {
+            ConnectivityStatsLog.write(
+                    ConnectivityStatsLog.DAILY_KEEPALIVE_INFO_REPORTED,
+                    dailyKeepaliveInfoReported.getDurationPerNumOfKeepalive().toByteArray(),
+                    dailyKeepaliveInfoReported.getKeepaliveLifetimePerCarrier().toByteArray(),
+                    dailyKeepaliveInfoReported.getKeepaliveRequests(),
+                    dailyKeepaliveInfoReported.getAutomaticKeepaliveRequests(),
+                    dailyKeepaliveInfoReported.getDistinctUserCount(),
+                    CollectionUtils.toIntArray(dailyKeepaliveInfoReported.getUidList()));
+        }
     }
 
     public KeepaliveStatsTracker(@NonNull Context context, @NonNull Handler handler) {
@@ -637,15 +654,15 @@
     /** Writes the stored metrics to ConnectivityStatsLog and resets.  */
     public void writeAndResetMetrics() {
         ensureRunningOnHandlerThread();
+        // Keepalive stats use repeated atoms, which are only supported on T+. If written to statsd
+        // on S- they will bootloop the system, so they must not be sent on S-. See b/289471411.
+        if (!SdkLevel.isAtLeastT()) {
+            Log.d(TAG, "KeepaliveStatsTracker is disabled before T, skipping write");
+            return;
+        }
+
         final DailykeepaliveInfoReported dailyKeepaliveInfoReported = buildAndResetMetrics();
-        ConnectivityStatsLog.write(
-                ConnectivityStatsLog.DAILY_KEEPALIVE_INFO_REPORTED,
-                dailyKeepaliveInfoReported.getDurationPerNumOfKeepalive().toByteArray(),
-                dailyKeepaliveInfoReported.getKeepaliveLifetimePerCarrier().toByteArray(),
-                dailyKeepaliveInfoReported.getKeepaliveRequests(),
-                dailyKeepaliveInfoReported.getAutomaticKeepaliveRequests(),
-                dailyKeepaliveInfoReported.getDistinctUserCount(),
-                CollectionUtils.toIntArray(dailyKeepaliveInfoReported.getUidList()));
+        mDependencies.writeStats(dailyKeepaliveInfoReported);
     }
 
     private void ensureRunningOnHandlerThread() {
diff --git a/service/src/com/android/server/connectivity/KeepaliveTracker.java b/service/src/com/android/server/connectivity/KeepaliveTracker.java
index 76e97e2..125c269 100644
--- a/service/src/com/android/server/connectivity/KeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/KeepaliveTracker.java
@@ -54,6 +54,7 @@
 import android.system.ErrnoException;
 import android.system.Os;
 import android.util.Log;
+import android.util.Pair;
 
 import com.android.connectivity.resources.R;
 import com.android.internal.annotations.VisibleForTesting;
@@ -62,6 +63,8 @@
 import com.android.net.module.util.IpUtils;
 
 import java.io.FileDescriptor;
+import java.net.Inet4Address;
+import java.net.Inet6Address;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
@@ -292,11 +295,15 @@
 
         private int checkSourceAddress() {
             // Check that we have the source address.
-            for (InetAddress address : mNai.linkProperties.getAddresses()) {
+            for (InetAddress address : mNai.linkProperties.getAllAddresses()) {
                 if (address.equals(mPacket.getSrcAddress())) {
                     return SUCCESS;
                 }
             }
+            // Or the address is the clat source address.
+            if (mPacket.getSrcAddress().equals(mNai.getClatv6SrcAddress())) {
+                return SUCCESS;
+            }
             return ERROR_INVALID_IP_ADDRESS;
         }
 
@@ -479,6 +486,15 @@
             return new KeepaliveInfo(mCallback, mNai, mPacket, mPid, mUid, mInterval, mType,
                     fd, mSlot, true /* resumed */);
         }
+
+        /**
+         * Construct a new KeepaliveInfo from existing KeepaliveInfo with a new KeepalivePacketData.
+         */
+        public KeepaliveInfo withPacketData(@NonNull KeepalivePacketData packet)
+                throws InvalidSocketException {
+            return new KeepaliveInfo(mCallback, mNai, packet, mPid, mUid, mInterval, mType,
+                    mFd, mSlot, mResumed);
+        }
     }
 
     void notifyErrorCallback(ISocketKeepaliveCallback cb, int error) {
@@ -512,15 +528,47 @@
      * Handle start keepalives with the message.
      *
      * @param ki the keepalive to start.
-     * @return SUCCESS if the keepalive is successfully starting and the error reason otherwise.
+     * @return Pair of (SUCCESS if the keepalive is successfully starting, or the error reason
+     *         otherwise; the KeepaliveInfo after any clat rewrite, or ki if the rewrite failed)
      */
-    public int handleStartKeepalive(KeepaliveInfo ki) {
-        NetworkAgentInfo nai = ki.getNai();
+    public Pair<Integer, KeepaliveInfo> handleStartKeepalive(KeepaliveInfo ki) {
+        final KeepaliveInfo newKi;
+        try {
+            newKi = handleUpdateKeepaliveForClat(ki);
+        } catch (InvalidSocketException | InvalidPacketException e) {
+            Log.e(TAG, "Fail to construct keepalive packet");
+            notifyErrorCallback(ki.mCallback, ERROR_INVALID_IP_ADDRESS);
+            // Failed to build the clat keepalive packet; return the original keepalive info.
+            return new Pair<>(ERROR_INVALID_IP_ADDRESS, ki);
+        }
+
+        final NetworkAgentInfo nai = newKi.getNai();
         // If this was a paused keepalive, then reuse the same slot that was kept for it. Otherwise,
         // use the first free slot for this network agent.
-        final int slot = NO_KEEPALIVE != ki.mSlot ? ki.mSlot : findFirstFreeSlot(nai);
-        mKeepalives.get(nai).put(slot, ki);
-        return ki.start(slot);
+        final int slot = NO_KEEPALIVE != newKi.mSlot ? newKi.mSlot : findFirstFreeSlot(nai);
+        mKeepalives.get(nai).put(slot, newKi);
+
+        return new Pair<>(newKi.start(slot), newKi);
+    }
+
+    private KeepaliveInfo handleUpdateKeepaliveForClat(KeepaliveInfo ki)
+            throws InvalidSocketException, InvalidPacketException {
+        // Only try to translate address if the packet source address is the clat's source address.
+        if (!ki.mPacket.getSrcAddress().equals(ki.getNai().getClatv4SrcAddress())) return ki;
+
+        final InetAddress dstAddr = ki.mPacket.getDstAddress();
+        // Do not perform translation for a v6 dst address.
+        if (!(dstAddr instanceof Inet4Address)) return ki;
+
+        final Inet6Address address = ki.getNai().translateV4toClatV6((Inet4Address) dstAddr);
+
+        if (address == null) return ki;
+
+        final int srcPort = ki.mPacket.getSrcPort();
+        final KeepaliveInfo newInfo = ki.withPacketData(NattKeepalivePacketData.nattKeepalivePacket(
+                ki.getNai().getClatv6SrcAddress(), srcPort, address, NATT_PORT));
+        Log.d(TAG, "Src is clat v4 address. Convert from " + ki + " to " + newInfo);
+        return newInfo;
     }
 
     public void handleStopAllKeepalives(NetworkAgentInfo nai, int reason) {
diff --git a/service/src/com/android/server/connectivity/Nat464Xlat.java b/service/src/com/android/server/connectivity/Nat464Xlat.java
index 90cddda..f9e07fd 100644
--- a/service/src/com/android/server/connectivity/Nat464Xlat.java
+++ b/service/src/com/android/server/connectivity/Nat464Xlat.java
@@ -44,7 +44,9 @@
 import com.android.server.ConnectivityService;
 
 import java.io.IOException;
+import java.net.Inet4Address;
 import java.net.Inet6Address;
+import java.net.UnknownHostException;
 import java.util.Objects;
 
 /**
@@ -99,7 +101,8 @@
     private IpPrefix mNat64PrefixFromRa;
     private String mBaseIface;
     private String mIface;
-    private Inet6Address mIPv6Address;
+    @VisibleForTesting
+    Inet6Address mIPv6Address;
     private State mState = State.IDLE;
     private final ClatCoordinator mClatCoordinator;  // non-null iff T+
 
@@ -239,6 +242,7 @@
         mNat64PrefixInUse = null;
         mIface = null;
         mBaseIface = null;
+        mIPv6Address = null;
 
         if (!mPrefixDiscoveryRunning) {
             setPrefix64(null);
@@ -541,6 +545,67 @@
     }
 
     /**
+     * Translate the input v4 address to v6 clat address.
+     */
+    @Nullable
+    public Inet6Address translateV4toV6(@NonNull Inet4Address addr) {
+        // Variables in Nat464Xlat should only be accessed from handler thread.
+        ensureRunningOnHandlerThread();
+        if (!isStarted()) return null;
+
+        return convertv4ToClatv6(mNat64PrefixInUse, addr);
+    }
+
+    @Nullable
+    private static Inet6Address convertv4ToClatv6(
+            @NonNull IpPrefix prefix, @NonNull Inet4Address addr) {
+        final byte[] v6Addr = new byte[16];
+        // v6 address = first 12 bytes (/96) of the NAT64 prefix followed by the 4 v4 bytes.
+        System.arraycopy(prefix.getAddress().getAddress(), 0, v6Addr, 0, 12);
+        System.arraycopy(addr.getAddress(), 0, v6Addr, 12, 4);
+
+        try {
+            return (Inet6Address) Inet6Address.getByAddress(v6Addr);
+        } catch (UnknownHostException e) {
+            Log.wtf(TAG, "getByAddress should never throw for a numeric address", e);
+            return null;
+        }
+    }
+
+    /**
+     * Get the generated v6 address of clat.
+     */
+    @Nullable
+    public Inet6Address getClatv6SrcAddress() {
+        // Variables in Nat464Xlat should only be accessed from handler thread.
+        ensureRunningOnHandlerThread();
+
+        return mIPv6Address;
+    }
+
+    /**
+     * Get the generated v4 address of clat.
+     */
+    @Nullable
+    public Inet4Address getClatv4SrcAddress() {
+        // Variables in Nat464Xlat should only be accessed from handler thread.
+        ensureRunningOnHandlerThread();
+        if (!isStarted()) return null;
+
+        final LinkAddress v4Addr = getLinkAddress(mIface);
+        if (v4Addr == null) return null;
+
+        return (Inet4Address) v4Addr.getAddress();
+    }
+
+    private void ensureRunningOnHandlerThread() {
+        if (mNetwork.handler().getLooper().getThread() != Thread.currentThread()) {
+            throw new IllegalStateException(
+                    "Not running on handler thread: " + Thread.currentThread().getName());
+        }
+    }
+
+    /**
      * Dump the NAT64 xlat information.
      *
      * @param pw print writer.
diff --git a/service/src/com/android/server/connectivity/NetworkAgentInfo.java b/service/src/com/android/server/connectivity/NetworkAgentInfo.java
index 85282cb..845c04c 100644
--- a/service/src/com/android/server/connectivity/NetworkAgentInfo.java
+++ b/service/src/com/android/server/connectivity/NetworkAgentInfo.java
@@ -68,6 +68,8 @@
 import com.android.server.ConnectivityService;
 
 import java.io.PrintWriter;
+import java.net.Inet4Address;
+import java.net.Inet6Address;
 import java.time.Instant;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -1033,6 +1035,30 @@
     }
 
     /**
+     * Get the generated v6 address of clat.
+     */
+    @Nullable
+    public Inet6Address getClatv6SrcAddress() {
+        return clatd.getClatv6SrcAddress();
+    }
+
+    /**
+     * Get the generated v4 address of clat.
+     */
+    @Nullable
+    public Inet4Address getClatv4SrcAddress() {
+        return clatd.getClatv4SrcAddress();
+    }
+
+    /**
+     * Translate the input v4 address to v6 clat address.
+     */
+    @Nullable
+    public Inet6Address translateV4toClatV6(@NonNull Inet4Address addr) {
+        return clatd.translateV4toV6(addr);
+    }
+
+    /**
      * Get the NetworkMonitorManager in this NetworkAgentInfo.
      *
      * <p>This will be null before {@link #onNetworkMonitorCreated(INetworkMonitor)} is called.
diff --git a/service/src/com/android/server/connectivity/NetworkNotificationManager.java b/service/src/com/android/server/connectivity/NetworkNotificationManager.java
index 8b0cb7c..bc13592 100644
--- a/service/src/com/android/server/connectivity/NetworkNotificationManager.java
+++ b/service/src/com/android/server/connectivity/NetworkNotificationManager.java
@@ -322,7 +322,8 @@
 
     private boolean maybeNotifyViaDialog(Resources res, NotificationType notifyType,
             PendingIntent intent) {
-        if (notifyType != NotificationType.NO_INTERNET
+        if (notifyType != NotificationType.LOST_INTERNET
+                && notifyType != NotificationType.NO_INTERNET
                 && notifyType != NotificationType.PARTIAL_CONNECTIVITY) {
             return false;
         }
@@ -432,7 +433,8 @@
      * A notification with a higher number will take priority over a notification with a lower
      * number.
      */
-    private static int priority(NotificationType t) {
+    @VisibleForTesting
+    public static int priority(NotificationType t) {
         if (t == null) {
             return 0;
         }
diff --git a/service/src/com/android/server/connectivity/ProxyTracker.java b/service/src/com/android/server/connectivity/ProxyTracker.java
index 6a0918b..4415007 100644
--- a/service/src/com/android/server/connectivity/ProxyTracker.java
+++ b/service/src/com/android/server/connectivity/ProxyTracker.java
@@ -404,7 +404,7 @@
                 // network, so discount this case.
                 if (null == mGlobalProxy && !lp.getHttpProxy().getPacFileUrl()
                         .equals(defaultProxy.getPacFileUrl())) {
-                    throw new IllegalStateException("Unexpected discrepancy between proxy in LP of "
+                    Log.wtf(TAG, "Unexpected discrepancy between proxy in LP of "
                             + "default network and default proxy. The former has a PAC URL of "
                             + lp.getHttpProxy().getPacFileUrl() + " while the latter has "
                             + defaultProxy.getPacFileUrl());
diff --git a/tests/common/OWNERS b/tests/common/OWNERS
new file mode 100644
index 0000000..3101da5
--- /dev/null
+++ b/tests/common/OWNERS
@@ -0,0 +1,2 @@
+# Bug template url: http://b/new?component=31808
+# TODO: move bug template config to common owners file once b/226427845 is resolved
\ No newline at end of file
diff --git a/tests/unit/java/android/net/nsd/NsdServiceInfoTest.java b/tests/common/java/android/net/nsd/NsdServiceInfoTest.java
similarity index 88%
rename from tests/unit/java/android/net/nsd/NsdServiceInfoTest.java
rename to tests/common/java/android/net/nsd/NsdServiceInfoTest.java
index 9ce0693..ffe0e91 100644
--- a/tests/unit/java/android/net/nsd/NsdServiceInfoTest.java
+++ b/tests/common/java/android/net/nsd/NsdServiceInfoTest.java
@@ -26,10 +26,10 @@
 import android.os.Build;
 import android.os.Bundle;
 import android.os.Parcel;
-import android.os.StrictMode;
 
 import androidx.test.filters.SmallTest;
 
+import com.android.testutils.ConnectivityModuleTest;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRunner;
 
@@ -37,7 +37,6 @@
 import org.junit.runner.RunWith;
 
 import java.net.InetAddress;
-import java.net.UnknownHostException;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
@@ -45,22 +44,11 @@
 @RunWith(DevSdkIgnoreRunner.class)
 @SmallTest
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
+@ConnectivityModuleTest
 public class NsdServiceInfoTest {
 
     private static final InetAddress IPV4_ADDRESS = InetAddresses.parseNumericAddress("192.0.2.1");
     private static final InetAddress IPV6_ADDRESS = InetAddresses.parseNumericAddress("2001:db8::");
-    public final static InetAddress LOCALHOST;
-    static {
-        // Because test.
-        StrictMode.ThreadPolicy policy = new StrictMode.ThreadPolicy.Builder().permitAll().build();
-        StrictMode.setThreadPolicy(policy);
-
-        InetAddress _host = null;
-        try {
-            _host = InetAddress.getLocalHost();
-        } catch (UnknownHostException e) { }
-        LOCALHOST = _host;
-    }
 
     @Test
     public void testLimits() throws Exception {
@@ -89,10 +77,10 @@
         // Single key + value length too long.
         exceptionThrown = false;
         try {
-            String longValue = "loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo" +
-                    "oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo" +
-                    "oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo" +
-                    "ooooooooooooooooooooooooooooong";  // 248 characters.
+            String longValue = "loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo"
+                    + "oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo"
+                    + "oooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo"
+                    + "ooooooooooooooooooooooooooooong";  // 248 characters.
             info.setAttribute("longcat", longValue);  // Key + value == 255 characters.
         } catch (IllegalArgumentException e) {
             exceptionThrown = true;
@@ -127,7 +115,6 @@
         fullInfo.setServiceName("kitten");
         fullInfo.setServiceType("_kitten._tcp");
         fullInfo.setPort(4242);
-        fullInfo.setHost(LOCALHOST);
         fullInfo.setHostAddresses(List.of(IPV4_ADDRESS));
         fullInfo.setNetwork(new Network(123));
         fullInfo.setInterfaceIndex(456);
@@ -143,8 +130,7 @@
         attributedInfo.setServiceName("kitten");
         attributedInfo.setServiceType("_kitten._tcp");
         attributedInfo.setPort(4242);
-        attributedInfo.setHost(LOCALHOST);
-        fullInfo.setHostAddresses(List.of(IPV6_ADDRESS, IPV4_ADDRESS));
+        attributedInfo.setHostAddresses(List.of(IPV6_ADDRESS, IPV4_ADDRESS));
         attributedInfo.setAttribute("color", "pink");
         attributedInfo.setAttribute("sound", (new String("にゃあ")).getBytes("UTF-8"));
         attributedInfo.setAttribute("adorable", (String) null);
diff --git a/tests/cts/OWNERS b/tests/cts/OWNERS
index 8388cb7..8c2408b 100644
--- a/tests/cts/OWNERS
+++ b/tests/cts/OWNERS
@@ -1,4 +1,5 @@
 # Bug template url: http://b/new?component=31808
+# TODO: move bug template config to common owners file once b/226427845 is resolved
 set noparent
 file:platform/packages/modules/Connectivity:master:/OWNERS_core_networking_xts
 
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
index cd3b650..454940f 100755
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
@@ -1726,10 +1726,21 @@
         assertEquals(VpnManager.TYPE_VPN_SERVICE, ((VpnTransportInfo) ti).getType());
     }
 
-    private void assertDefaultProxy(ProxyInfo expected) {
+    private void assertDefaultProxy(ProxyInfo expected) throws Exception {
         assertEquals("Incorrect proxy config.", expected, mCM.getDefaultProxy());
         String expectedHost = expected == null ? null : expected.getHost();
         String expectedPort = expected == null ? null : String.valueOf(expected.getPort());
+
+        // ActivityThread may not have time to set it in the properties yet which will cause flakes.
+        // Wait for some time to deflake the test.
+        int attempt = 0;
+        while (!(Objects.equals(expectedHost, System.getProperty("http.proxyHost"))
+                && Objects.equals(expectedPort, System.getProperty("http.proxyPort")))
+                && attempt < 300) {
+            attempt++;
+            Log.d(TAG, "Wait for proxy being updated, attempt=" + attempt);
+            Thread.sleep(100);
+        }
         assertEquals("Incorrect proxy host system property.", expectedHost,
             System.getProperty("http.proxyHost"));
         assertEquals("Incorrect proxy port system property.", expectedPort,
diff --git a/tests/cts/hostside/src/com/android/cts/net/HostsideConnOnActivityStartTest.java b/tests/cts/hostside/src/com/android/cts/net/HostsideConnOnActivityStartTest.java
index cfd3130..a7d5590 100644
--- a/tests/cts/hostside/src/com/android/cts/net/HostsideConnOnActivityStartTest.java
+++ b/tests/cts/hostside/src/com/android/cts/net/HostsideConnOnActivityStartTest.java
@@ -18,9 +18,13 @@
 
 import android.platform.test.annotations.FlakyTest;
 
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
 public class HostsideConnOnActivityStartTest extends HostsideNetworkTestCase {
     private static final String TEST_CLASS = TEST_PKG + ".ConnOnActivityStartTest";
-    @Override
+    @Before
     public void setUp() throws Exception {
         super.setUp();
 
@@ -28,26 +32,30 @@
         installPackage(TEST_APP2_APK);
     }
 
-    @Override
-    protected void tearDown() throws Exception {
+    @After
+    public void tearDown() throws Exception {
         super.tearDown();
 
         uninstallPackage(TEST_APP2_PKG, true);
     }
 
+    @Test
     public void testStartActivity_batterySaver() throws Exception {
         runDeviceTests(TEST_PKG, TEST_CLASS, "testStartActivity_batterySaver");
     }
 
+    @Test
     public void testStartActivity_dataSaver() throws Exception {
         runDeviceTests(TEST_PKG, TEST_CLASS, "testStartActivity_dataSaver");
     }
 
     @FlakyTest(bugId = 231440256)
+    @Test
     public void testStartActivity_doze() throws Exception {
         runDeviceTests(TEST_PKG, TEST_CLASS, "testStartActivity_doze");
     }
 
+    @Test
     public void testStartActivity_appStandby() throws Exception {
         runDeviceTests(TEST_PKG, TEST_CLASS, "testStartActivity_appStandby");
     }
diff --git a/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkCallbackTests.java b/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkCallbackTests.java
index 1312085..5d7ad62 100644
--- a/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkCallbackTests.java
+++ b/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkCallbackTests.java
@@ -14,26 +14,33 @@
  * limitations under the License.
  */
 package com.android.cts.net;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
 public class HostsideNetworkCallbackTests extends HostsideNetworkTestCase {
 
-    @Override
-    protected void setUp() throws Exception {
+    @Before
+    public void setUp() throws Exception {
         super.setUp();
         uninstallPackage(TEST_APP2_PKG, false);
         installPackage(TEST_APP2_APK);
     }
 
-    @Override
-    protected void tearDown() throws Exception {
+    @After
+    public void tearDown() throws Exception {
         super.tearDown();
         uninstallPackage(TEST_APP2_PKG, true);
     }
 
+    @Test
     public void testOnBlockedStatusChanged_dataSaver() throws Exception {
         runDeviceTests(TEST_PKG,
                 TEST_PKG + ".NetworkCallbackTest", "testOnBlockedStatusChanged_dataSaver");
     }
 
+    @Test
     public void testOnBlockedStatusChanged_powerSaver() throws Exception {
         runDeviceTests(TEST_PKG,
                 TEST_PKG + ".NetworkCallbackTest", "testOnBlockedStatusChanged_powerSaver");
diff --git a/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkPolicyManagerTests.java b/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkPolicyManagerTests.java
index fdb8876..40f5f59 100644
--- a/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkPolicyManagerTests.java
+++ b/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkPolicyManagerTests.java
@@ -16,49 +16,59 @@
 
 package com.android.cts.net;
 
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
 public class HostsideNetworkPolicyManagerTests extends HostsideNetworkTestCase {
-    @Override
-    protected void setUp() throws Exception {
+    @Before
+    public void setUp() throws Exception {
         super.setUp();
         uninstallPackage(TEST_APP2_PKG, false);
         installPackage(TEST_APP2_APK);
     }
 
-    @Override
-    protected void tearDown() throws Exception {
+    @After
+    public void tearDown() throws Exception {
         super.tearDown();
         uninstallPackage(TEST_APP2_PKG, true);
     }
 
+    @Test
     public void testIsUidNetworkingBlocked_withUidNotBlocked() throws Exception {
         runDeviceTests(TEST_PKG,
                 TEST_PKG + ".NetworkPolicyManagerTest",
                 "testIsUidNetworkingBlocked_withUidNotBlocked");
     }
 
+    @Test
     public void testIsUidNetworkingBlocked_withSystemUid() throws Exception {
         runDeviceTests(TEST_PKG,
                 TEST_PKG + ".NetworkPolicyManagerTest", "testIsUidNetworkingBlocked_withSystemUid");
     }
 
+    @Test
     public void testIsUidNetworkingBlocked_withDataSaverMode() throws Exception {
         runDeviceTests(TEST_PKG,
                 TEST_PKG + ".NetworkPolicyManagerTest",
                 "testIsUidNetworkingBlocked_withDataSaverMode");
     }
 
+    @Test
     public void testIsUidNetworkingBlocked_withRestrictedNetworkingMode() throws Exception {
         runDeviceTests(TEST_PKG,
                 TEST_PKG + ".NetworkPolicyManagerTest",
                 "testIsUidNetworkingBlocked_withRestrictedNetworkingMode");
     }
 
+    @Test
     public void testIsUidNetworkingBlocked_withPowerSaverMode() throws Exception {
         runDeviceTests(TEST_PKG,
                 TEST_PKG + ".NetworkPolicyManagerTest",
                 "testIsUidNetworkingBlocked_withPowerSaverMode");
     }
 
+    @Test
     public void testIsUidRestrictedOnMeteredNetworks() throws Exception {
         runDeviceTests(TEST_PKG,
                 TEST_PKG + ".NetworkPolicyManagerTest", "testIsUidRestrictedOnMeteredNetworks");
diff --git a/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkTestCase.java b/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkTestCase.java
index 2aa1032..c896168 100644
--- a/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkTestCase.java
+++ b/tests/cts/hostside/src/com/android/cts/net/HostsideNetworkTestCase.java
@@ -16,28 +16,22 @@
 
 package com.android.cts.net;
 
-import com.android.compatibility.common.tradefed.build.CompatibilityBuildHelper;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
+
 import com.android.ddmlib.Log;
-import com.android.ddmlib.testrunner.RemoteAndroidTestRunner;
-import com.android.ddmlib.testrunner.TestResult.TestStatus;
 import com.android.modules.utils.build.testing.DeviceSdkLevel;
-import com.android.tradefed.build.IBuildInfo;
 import com.android.tradefed.device.DeviceNotAvailableException;
-import com.android.tradefed.result.CollectingTestListener;
-import com.android.tradefed.result.TestDescription;
-import com.android.tradefed.result.TestResult;
-import com.android.tradefed.result.TestRunResult;
-import com.android.tradefed.testtype.DeviceTestCase;
-import com.android.tradefed.testtype.IAbi;
-import com.android.tradefed.testtype.IAbiReceiver;
-import com.android.tradefed.testtype.IBuildReceiver;
+import com.android.tradefed.targetprep.TargetSetupError;
+import com.android.tradefed.testtype.DeviceJUnit4ClassRunner;
+import com.android.tradefed.testtype.junit4.BaseHostJUnit4Test;
+import com.android.tradefed.testtype.junit4.DeviceTestRunOptions;
 import com.android.tradefed.util.RunUtil;
 
-import java.io.FileNotFoundException;
-import java.util.Map;
+import org.junit.runner.RunWith;
 
-abstract class HostsideNetworkTestCase extends DeviceTestCase implements IAbiReceiver,
-        IBuildReceiver {
+@RunWith(DeviceJUnit4ClassRunner.class)
+abstract class HostsideNetworkTestCase extends BaseHostJUnit4Test {
     protected static final boolean DEBUG = false;
     protected static final String TAG = "HostsideNetworkTests";
     protected static final String TEST_PKG = "com.android.cts.net.hostside";
@@ -46,26 +40,7 @@
     protected static final String TEST_APP2_PKG = "com.android.cts.net.hostside.app2";
     protected static final String TEST_APP2_APK = "CtsHostsideNetworkTestsApp2.apk";
 
-    private IAbi mAbi;
-    private IBuildInfo mCtsBuild;
-
-    @Override
-    public void setAbi(IAbi abi) {
-        mAbi = abi;
-    }
-
-    @Override
-    public void setBuild(IBuildInfo buildInfo) {
-        mCtsBuild = buildInfo;
-    }
-
-    @Override
     protected void setUp() throws Exception {
-        super.setUp();
-
-        assertNotNull(mAbi);
-        assertNotNull(mCtsBuild);
-
         DeviceSdkLevel deviceSdkLevel = new DeviceSdkLevel(getDevice());
         String testApk = deviceSdkLevel.isDeviceAtLeastT() ? TEST_APK_NEXT
                 : TEST_APK;
@@ -74,23 +49,20 @@
         installPackage(testApk);
     }
 
-    @Override
     protected void tearDown() throws Exception {
-        super.tearDown();
-
         uninstallPackage(TEST_PKG, true);
     }
 
-    protected void installPackage(String apk) throws FileNotFoundException,
-            DeviceNotAvailableException {
-        CompatibilityBuildHelper buildHelper = new CompatibilityBuildHelper(mCtsBuild);
-        assertNull(getDevice().installPackage(buildHelper.getTestFile(apk),
-                false /* reinstall */, true /* grantPermissions */, "-t"));
+    protected void installPackage(String apk) throws DeviceNotAvailableException, TargetSetupError {
+        // Install for the current user with runtime permissions granted; "-t" is passed
+        // through as an extra install argument (presumably to allow test-only APKs).
+        final int userId = getDevice().getCurrentUser();
+        installPackageAsUser(apk, true /* grantPermission */, userId, "-t");
     }
 
     protected void uninstallPackage(String packageName, boolean shouldSucceed)
             throws DeviceNotAvailableException {
-        final String result = getDevice().uninstallPackage(packageName);
+        final String result = uninstallPackage(packageName);
         if (shouldSucceed) {
             assertNull("uninstallPackage(" + packageName + ") failed: " + result, result);
         }
@@ -126,50 +98,6 @@
         fail("Package '" + packageName + "' not uinstalled after " + max_tries + " seconds");
     }
 
-    protected void runDeviceTests(String packageName, String testClassName)
-            throws DeviceNotAvailableException {
-        runDeviceTests(packageName, testClassName, null);
-    }
-
-    protected void runDeviceTests(String packageName, String testClassName, String methodName)
-            throws DeviceNotAvailableException {
-        RemoteAndroidTestRunner testRunner = new RemoteAndroidTestRunner(packageName,
-                "androidx.test.runner.AndroidJUnitRunner", getDevice().getIDevice());
-
-        if (testClassName != null) {
-            if (methodName != null) {
-                testRunner.setMethodName(testClassName, methodName);
-            } else {
-                testRunner.setClassName(testClassName);
-            }
-        }
-
-        final CollectingTestListener listener = new CollectingTestListener();
-        getDevice().runInstrumentationTests(testRunner, listener);
-
-        final TestRunResult result = listener.getCurrentRunResults();
-        if (result.isRunFailure()) {
-            throw new AssertionError("Failed to successfully run device tests for "
-                    + result.getName() + ": " + result.getRunFailureMessage());
-        }
-
-        if (result.hasFailedTests()) {
-            // build a meaningful error message
-            StringBuilder errorBuilder = new StringBuilder("on-device tests failed:\n");
-            for (Map.Entry<TestDescription, TestResult> resultEntry :
-                    result.getTestResults().entrySet()) {
-                final TestStatus testStatus = resultEntry.getValue().getStatus();
-                if (!TestStatus.PASSED.equals(testStatus)
-                        && !TestStatus.ASSUMPTION_FAILURE.equals(testStatus)) {
-                    errorBuilder.append(resultEntry.getKey().toString());
-                    errorBuilder.append(":\n");
-                    errorBuilder.append(resultEntry.getValue().getStackTrace());
-                }
-            }
-            throw new AssertionError(errorBuilder.toString());
-        }
-    }
-
     protected int getUid(String packageName) throws DeviceNotAvailableException {
         final int currentUser = getDevice().getCurrentUser();
         final String uidLines = runCommand(
diff --git a/tests/cts/hostside/src/com/android/cts/net/HostsideRestrictBackgroundNetworkTests.java b/tests/cts/hostside/src/com/android/cts/net/HostsideRestrictBackgroundNetworkTests.java
index 21c78b7..0977deb 100644
--- a/tests/cts/hostside/src/com/android/cts/net/HostsideRestrictBackgroundNetworkTests.java
+++ b/tests/cts/hostside/src/com/android/cts/net/HostsideRestrictBackgroundNetworkTests.java
@@ -16,30 +16,37 @@
 
 package com.android.cts.net;
 
+import static org.junit.Assert.fail;
+
 import android.platform.test.annotations.SecurityTest;
 
 import com.android.ddmlib.Log;
 import com.android.tradefed.device.DeviceNotAvailableException;
 import com.android.tradefed.util.RunUtil;
 
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
 public class HostsideRestrictBackgroundNetworkTests extends HostsideNetworkTestCase {
 
-    @Override
-    protected void setUp() throws Exception {
+    @Before
+    public void setUp() throws Exception {
         super.setUp();
 
         uninstallPackage(TEST_APP2_PKG, false);
         installPackage(TEST_APP2_APK);
     }
 
-    @Override
-    protected void tearDown() throws Exception {
+    @After
+    public void tearDown() throws Exception {
         super.tearDown();
 
         uninstallPackage(TEST_APP2_PKG, true);
     }
 
     @SecurityTest
+    @Test
     public void testDataWarningReceiver() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".DataWarningReceiverTest",
                 "testSnoozeWarningNotReceived");
@@ -49,26 +56,31 @@
      * Data Saver Mode tests. *
      **************************/
 
+    @Test
     public void testDataSaverMode_disabled() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".DataSaverModeTest",
                 "testGetRestrictBackgroundStatus_disabled");
     }
 
+    @Test
     public void testDataSaverMode_whitelisted() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".DataSaverModeTest",
                 "testGetRestrictBackgroundStatus_whitelisted");
     }
 
+    @Test
     public void testDataSaverMode_enabled() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".DataSaverModeTest",
                 "testGetRestrictBackgroundStatus_enabled");
     }
 
+    @Test
     public void testDataSaverMode_blacklisted() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".DataSaverModeTest",
                 "testGetRestrictBackgroundStatus_blacklisted");
     }
 
+    @Test
     public void testDataSaverMode_reinstall() throws Exception {
         final int oldUid = getUid(TEST_APP2_PKG);
 
@@ -85,11 +97,13 @@
         assertRestrictBackgroundWhitelist(newUid, false);
     }
 
+    @Test
     public void testDataSaverMode_requiredWhitelistedPackages() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".DataSaverModeTest",
                 "testGetRestrictBackgroundStatus_requiredWhitelistedPackages");
     }
 
+    @Test
     public void testDataSaverMode_broadcastNotSentOnUnsupportedDevices() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".DataSaverModeTest",
                 "testBroadcastNotSentOnUnsupportedDevices");
@@ -99,21 +113,25 @@
      * Battery Saver Mode tests. *
      *****************************/
 
+    @Test
     public void testBatterySaverModeMetered_disabled() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".BatterySaverModeMeteredTest",
                 "testBackgroundNetworkAccess_disabled");
     }
 
+    @Test
     public void testBatterySaverModeMetered_whitelisted() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".BatterySaverModeMeteredTest",
                 "testBackgroundNetworkAccess_whitelisted");
     }
 
+    @Test
     public void testBatterySaverModeMetered_enabled() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".BatterySaverModeMeteredTest",
                 "testBackgroundNetworkAccess_enabled");
     }
 
+    @Test
     public void testBatterySaverMode_reinstall() throws Exception {
         if (!isDozeModeEnabled()) {
             Log.w(TAG, "testBatterySaverMode_reinstall() skipped because device does not support "
@@ -131,16 +149,19 @@
         assertPowerSaveModeWhitelist(TEST_APP2_PKG, false);
     }
 
+    @Test
     public void testBatterySaverModeNonMetered_disabled() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".BatterySaverModeNonMeteredTest",
                 "testBackgroundNetworkAccess_disabled");
     }
 
+    @Test
     public void testBatterySaverModeNonMetered_whitelisted() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".BatterySaverModeNonMeteredTest",
                 "testBackgroundNetworkAccess_whitelisted");
     }
 
+    @Test
     public void testBatterySaverModeNonMetered_enabled() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".BatterySaverModeNonMeteredTest",
                 "testBackgroundNetworkAccess_enabled");
@@ -150,26 +171,31 @@
      * App idle tests. *
      *******************/
 
+    @Test
     public void testAppIdleMetered_disabled() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".AppIdleMeteredTest",
                 "testBackgroundNetworkAccess_disabled");
     }
 
+    @Test
     public void testAppIdleMetered_whitelisted() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".AppIdleMeteredTest",
                 "testBackgroundNetworkAccess_whitelisted");
     }
 
+    @Test
     public void testAppIdleMetered_tempWhitelisted() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".AppIdleMeteredTest",
                 "testBackgroundNetworkAccess_tempWhitelisted");
     }
 
+    @Test
     public void testAppIdleMetered_enabled() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".AppIdleMeteredTest",
                 "testBackgroundNetworkAccess_enabled");
     }
 
+    @Test
     public void testAppIdleMetered_idleWhitelisted() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".AppIdleMeteredTest",
                 "testAppIdleNetworkAccess_idleWhitelisted");
@@ -180,41 +206,50 @@
     //    public void testAppIdle_reinstall() throws Exception {
     //    }
 
+    @Test
     public void testAppIdleNonMetered_disabled() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".AppIdleNonMeteredTest",
                 "testBackgroundNetworkAccess_disabled");
     }
 
+
+    @Test
     public void testAppIdleNonMetered_whitelisted() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".AppIdleNonMeteredTest",
                 "testBackgroundNetworkAccess_whitelisted");
     }
 
+    @Test
     public void testAppIdleNonMetered_tempWhitelisted() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".AppIdleNonMeteredTest",
                 "testBackgroundNetworkAccess_tempWhitelisted");
     }
 
+    @Test
     public void testAppIdleNonMetered_enabled() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".AppIdleNonMeteredTest",
                 "testBackgroundNetworkAccess_enabled");
     }
 
+    @Test
     public void testAppIdleNonMetered_idleWhitelisted() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".AppIdleNonMeteredTest",
                 "testAppIdleNetworkAccess_idleWhitelisted");
     }
 
+    @Test
     public void testAppIdleNonMetered_whenCharging() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".AppIdleNonMeteredTest",
                 "testAppIdleNetworkAccess_whenCharging");
     }
 
+    @Test
     public void testAppIdleMetered_whenCharging() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".AppIdleMeteredTest",
                 "testAppIdleNetworkAccess_whenCharging");
     }
 
+    @Test
     public void testAppIdle_toast() throws Exception {
         // Check that showing a toast doesn't bring an app out of standby
         runDeviceTests(TEST_PKG, TEST_PKG + ".AppIdleNonMeteredTest",
@@ -225,21 +260,25 @@
      * Doze Mode tests. *
      ********************/
 
+    @Test
     public void testDozeModeMetered_disabled() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".DozeModeMeteredTest",
                 "testBackgroundNetworkAccess_disabled");
     }
 
+    @Test
     public void testDozeModeMetered_whitelisted() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".DozeModeMeteredTest",
                 "testBackgroundNetworkAccess_whitelisted");
     }
 
+    @Test
     public void testDozeModeMetered_enabled() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".DozeModeMeteredTest",
                 "testBackgroundNetworkAccess_enabled");
     }
 
+    @Test
     public void testDozeModeMetered_enabledButWhitelistedOnNotificationAction() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".DozeModeMeteredTest",
                 "testBackgroundNetworkAccess_enabledButWhitelistedOnNotificationAction");
@@ -250,21 +289,25 @@
     //    public void testDozeMode_reinstall() throws Exception {
     //    }
 
+    @Test
     public void testDozeModeNonMetered_disabled() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".DozeModeNonMeteredTest",
                 "testBackgroundNetworkAccess_disabled");
     }
 
+    @Test
     public void testDozeModeNonMetered_whitelisted() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".DozeModeNonMeteredTest",
                 "testBackgroundNetworkAccess_whitelisted");
     }
 
+    @Test
     public void testDozeModeNonMetered_enabled() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".DozeModeNonMeteredTest",
                 "testBackgroundNetworkAccess_enabled");
     }
 
+    @Test
     public void testDozeModeNonMetered_enabledButWhitelistedOnNotificationAction()
             throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".DozeModeNonMeteredTest",
@@ -275,46 +318,55 @@
      * Mixed modes tests. *
      **********************/
 
+    @Test
     public void testDataAndBatterySaverModes_meteredNetwork() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".MixedModesTest",
                 "testDataAndBatterySaverModes_meteredNetwork");
     }
 
+    @Test
     public void testDataAndBatterySaverModes_nonMeteredNetwork() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".MixedModesTest",
                 "testDataAndBatterySaverModes_nonMeteredNetwork");
     }
 
+    @Test
     public void testDozeAndBatterySaverMode_powerSaveWhitelists() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".MixedModesTest",
                 "testDozeAndBatterySaverMode_powerSaveWhitelists");
     }
 
+    @Test
     public void testDozeAndAppIdle_powerSaveWhitelists() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".MixedModesTest",
                 "testDozeAndAppIdle_powerSaveWhitelists");
     }
 
+    @Test
     public void testAppIdleAndDoze_tempPowerSaveWhitelists() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".MixedModesTest",
                 "testAppIdleAndDoze_tempPowerSaveWhitelists");
     }
 
+    @Test
     public void testAppIdleAndBatterySaver_tempPowerSaveWhitelists() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".MixedModesTest",
                 "testAppIdleAndBatterySaver_tempPowerSaveWhitelists");
     }
 
+    @Test
     public void testDozeAndAppIdle_appIdleWhitelist() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".MixedModesTest",
                 "testDozeAndAppIdle_appIdleWhitelist");
     }
 
+    @Test
     public void testAppIdleAndDoze_tempPowerSaveAndAppIdleWhitelists() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".MixedModesTest",
                 "testAppIdleAndDoze_tempPowerSaveAndAppIdleWhitelists");
     }
 
+    @Test
     public void testAppIdleAndBatterySaver_tempPowerSaveAndAppIdleWhitelists() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".MixedModesTest",
                 "testAppIdleAndBatterySaver_tempPowerSaveAndAppIdleWhitelists");
@@ -323,11 +375,14 @@
     /**************************
      * Restricted mode tests. *
      **************************/
+
+    @Test
     public void testNetworkAccess_restrictedMode() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".RestrictedModeTest",
                 "testNetworkAccess");
     }
 
+    @Test
     public void testNetworkAccess_restrictedMode_withBatterySaver() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".RestrictedModeTest",
                 "testNetworkAccess_withBatterySaver");
@@ -337,10 +392,12 @@
      * Expedited job tests. *
      ************************/
 
+    @Test
     public void testMeteredNetworkAccess_expeditedJob() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".ExpeditedJobMeteredTest");
     }
 
+    @Test
     public void testNonMeteredNetworkAccess_expeditedJob() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".ExpeditedJobNonMeteredTest");
     }
diff --git a/tests/cts/hostside/src/com/android/cts/net/HostsideSelfDeclaredNetworkCapabilitiesCheckTest.java b/tests/cts/hostside/src/com/android/cts/net/HostsideSelfDeclaredNetworkCapabilitiesCheckTest.java
index 4c2985d..c3bdb6d 100644
--- a/tests/cts/hostside/src/com/android/cts/net/HostsideSelfDeclaredNetworkCapabilitiesCheckTest.java
+++ b/tests/cts/hostside/src/com/android/cts/net/HostsideSelfDeclaredNetworkCapabilitiesCheckTest.java
@@ -15,6 +15,8 @@
  */
 package com.android.cts.net;
 
+import org.junit.Test;
+
 public class HostsideSelfDeclaredNetworkCapabilitiesCheckTest extends HostsideNetworkTestCase {
 
     private static final String TEST_WITH_PROPERTY_IN_CURRENT_SDK_APK =
@@ -34,6 +36,7 @@
             "requestNetwork_withoutRequestCapabilities";
 
 
+    @Test
     public void testRequestNetworkInCurrentSdkWithProperty() throws Exception {
         uninstallPackage(TEST_APP_PKG, false);
         installPackage(TEST_WITH_PROPERTY_IN_CURRENT_SDK_APK);
@@ -48,6 +51,7 @@
         uninstallPackage(TEST_APP_PKG, true);
     }
 
+    @Test
     public void testRequestNetworkInCurrentSdkWithoutProperty() throws Exception {
         uninstallPackage(TEST_APP_PKG, false);
         installPackage(TEST_WITHOUT_PROPERTY_IN_CURRENT_SDK_APK);
@@ -62,6 +66,7 @@
         uninstallPackage(TEST_APP_PKG, true);
     }
 
+    @Test
     public void testRequestNetworkInSdk33() throws Exception {
         uninstallPackage(TEST_APP_PKG, false);
         installPackage(TEST_IN_SDK_33_APK);
@@ -75,6 +80,7 @@
         uninstallPackage(TEST_APP_PKG, true);
     }
 
+    @Test
     public void testReinstallPackageWillUpdateProperty() throws Exception {
         uninstallPackage(TEST_APP_PKG, false);
         installPackage(TEST_WITHOUT_PROPERTY_IN_CURRENT_SDK_APK);
diff --git a/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java b/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java
index 3ca4775..242fd5d 100644
--- a/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java
+++ b/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java
@@ -18,95 +18,116 @@
 
 import android.platform.test.annotations.RequiresDevice;
 
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
 public class HostsideVpnTests extends HostsideNetworkTestCase {
 
-    @Override
-    protected void setUp() throws Exception {
+    @Before
+    public void setUp() throws Exception {
         super.setUp();
 
         uninstallPackage(TEST_APP2_PKG, false);
         installPackage(TEST_APP2_APK);
     }
 
-    @Override
-    protected void tearDown() throws Exception {
+    @After
+    public void tearDown() throws Exception {
         super.tearDown();
 
         uninstallPackage(TEST_APP2_PKG, true);
     }
 
+    @Test
     public void testChangeUnderlyingNetworks() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testChangeUnderlyingNetworks");
     }
 
+    @Test
     public void testDefault() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testDefault");
     }
 
+    @Test
     public void testAppAllowed() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testAppAllowed");
     }
 
+    @Test
     public void testAppDisallowed() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testAppDisallowed");
     }
 
+    @Test
     public void testSocketClosed() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testSocketClosed");
     }
 
+    @Test
     public void testGetConnectionOwnerUidSecurity() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testGetConnectionOwnerUidSecurity");
     }
 
+    @Test
     public void testSetProxy() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testSetProxy");
     }
 
+    @Test
     public void testSetProxyDisallowedApps() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testSetProxyDisallowedApps");
     }
 
+    @Test
     public void testNoProxy() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testNoProxy");
     }
 
+    @Test
     public void testBindToNetworkWithProxy() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testBindToNetworkWithProxy");
     }
 
+    @Test
     public void testVpnMeterednessWithNoUnderlyingNetwork() throws Exception {
         runDeviceTests(
                 TEST_PKG, TEST_PKG + ".VpnTest", "testVpnMeterednessWithNoUnderlyingNetwork");
     }
 
+    @Test
     public void testVpnMeterednessWithNullUnderlyingNetwork() throws Exception {
         runDeviceTests(
                 TEST_PKG, TEST_PKG + ".VpnTest", "testVpnMeterednessWithNullUnderlyingNetwork");
     }
 
+    @Test
     public void testVpnMeterednessWithNonNullUnderlyingNetwork() throws Exception {
         runDeviceTests(
                 TEST_PKG, TEST_PKG + ".VpnTest", "testVpnMeterednessWithNonNullUnderlyingNetwork");
     }
 
+    @Test
     public void testAlwaysMeteredVpnWithNullUnderlyingNetwork() throws Exception {
         runDeviceTests(
                 TEST_PKG, TEST_PKG + ".VpnTest", "testAlwaysMeteredVpnWithNullUnderlyingNetwork");
     }
 
     @RequiresDevice // Keepalive is not supported on virtual hardware
+    @Test
     public void testAutomaticOnOffKeepaliveModeClose() throws Exception {
         runDeviceTests(
                 TEST_PKG, TEST_PKG + ".VpnTest", "testAutomaticOnOffKeepaliveModeClose");
     }
 
     @RequiresDevice // Keepalive is not supported on virtual hardware
+    @Test
     public void testAutomaticOnOffKeepaliveModeNoClose() throws Exception {
         runDeviceTests(
                 TEST_PKG, TEST_PKG + ".VpnTest", "testAutomaticOnOffKeepaliveModeNoClose");
     }
 
+    @Test
     public void testAlwaysMeteredVpnWithNonNullUnderlyingNetwork() throws Exception {
         runDeviceTests(
                 TEST_PKG,
@@ -114,31 +135,38 @@
                 "testAlwaysMeteredVpnWithNonNullUnderlyingNetwork");
     }
 
+    @Test
     public void testB141603906() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testB141603906");
     }
 
+    @Test
     public void testDownloadWithDownloadManagerDisallowed() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest",
                 "testDownloadWithDownloadManagerDisallowed");
     }
 
+    @Test
     public void testExcludedRoutes() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testExcludedRoutes");
     }
 
+    @Test
     public void testIncludedRoutes() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testIncludedRoutes");
     }
 
+    @Test
     public void testInterleavedRoutes() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testInterleavedRoutes");
     }
 
+    @Test
     public void testBlockIncomingPackets() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testBlockIncomingPackets");
     }
 
+    @Test
     public void testSetVpnDefaultForUids() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testSetVpnDefaultForUids");
     }
diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
index 9f8a05d..98ea224 100644
--- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
+++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
@@ -164,6 +164,12 @@
     it.obj = obj
 }
 
+// On T and below, the native network is only created when the agent connects.
+// Starting in U, the native network was to be created as soon as the agent is registered,
+// but this has been flagged off for now pending resolution of race conditions.
+// TODO : enable this in a Mainline update or in V.
+private const val SHOULD_CREATE_NETWORKS_IMMEDIATELY = false
+
 @RunWith(DevSdkIgnoreRunner::class)
 // NetworkAgent is not updatable in R-, so this test does not need to be compatible with older
 // versions. NetworkAgent was also based on AsyncChannel before S so cannot be tested the same way.
@@ -1247,15 +1253,15 @@
 
         // Connect a third network. Because network1 is awaiting replacement, network3 is preferred
         // as soon as it validates (until then, it is outscored by network1).
-        // The fact that the first event seen by matchAllCallback is the connection of network3
+        // The fact that the first events seen by matchAllCallback is the connection of network3
         // implicitly ensures that no callbacks are sent since network1 was lost.
         val (agent3, network3) = connectNetwork()
+        matchAllCallback.expectAvailableThenValidatedCallbacks(network3)
+        testCallback.expectAvailableDoubleValidatedCallbacks(network3)
+
         // As soon as the replacement arrives, network1 is disconnected.
         // Check that this happens before the replacement timeout (5 seconds) fires.
-        matchAllCallback.expectAvailableCallbacks(network3, validated = false)
         matchAllCallback.expect<Lost>(network1, 2_000 /* timeoutMs */)
-        matchAllCallback.expectCaps(network3) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
-        testCallback.expectAvailableDoubleValidatedCallbacks(network3)
         agent1.expectCallback<OnNetworkUnwanted>()
 
         // Test lingering:
@@ -1301,8 +1307,8 @@
         val callback = TestableNetworkCallback()
         requestNetwork(makeTestNetworkRequest(specifier = specifier6), callback)
         val agent6 = createNetworkAgent(specifier = specifier6)
-        agent6.register()
-        if (SdkLevel.isAtLeastU()) {
+        val network6 = agent6.register()
+        if (SHOULD_CREATE_NETWORKS_IMMEDIATELY) {
             agent6.expectCallback<OnNetworkCreated>()
         } else {
             // No callbacks are sent, so check LinkProperties to wait for the network to be created.
@@ -1316,9 +1322,10 @@
         val timeoutMs = agent6.DEFAULT_TIMEOUT_MS.toInt() + 1_000
         agent6.unregisterAfterReplacement(timeoutMs)
         agent6.expectCallback<OnNetworkUnwanted>()
-        if (!SdkLevel.isAtLeastT() || SdkLevel.isAtLeastU()) {
+        if (!SdkLevel.isAtLeastT() || SHOULD_CREATE_NETWORKS_IMMEDIATELY) {
             // Before T, onNetworkDestroyed is called even if the network was never created.
-            // On U+, the network was created by register(). Destroying it sends onNetworkDestroyed.
+            // If immediate native network creation is supported, the network was created by
+            // register(). Destroying it sends onNetworkDestroyed.
             agent6.expectCallback<OnNetworkDestroyed>()
         }
         // Poll for LinkProperties becoming null, because when onNetworkUnwanted is called, the
@@ -1368,9 +1375,8 @@
 
         val (newWifiAgent, newWifiNetwork) = connectNetwork(TRANSPORT_WIFI)
         testCallback.expectAvailableCallbacks(newWifiNetwork, validated = true)
-        matchAllCallback.expectAvailableCallbacks(newWifiNetwork, validated = false)
+        matchAllCallback.expectAvailableThenValidatedCallbacks(newWifiNetwork)
         matchAllCallback.expect<Lost>(wifiNetwork)
-        matchAllCallback.expectCaps(newWifiNetwork) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
         wifiAgent.expectCallback<OnNetworkUnwanted>()
     }
 
@@ -1478,10 +1484,9 @@
 
     @Test
     fun testNativeNetworkCreation_PhysicalNetwork() {
-        // On T and below, the native network is only created when the agent connects.
-        // Starting in U, the native network is created as soon as the agent is registered.
-        doTestNativeNetworkCreation(expectCreatedImmediately = SdkLevel.isAtLeastU(),
-            intArrayOf(TRANSPORT_CELLULAR))
+        doTestNativeNetworkCreation(
+                expectCreatedImmediately = SHOULD_CREATE_NETWORKS_IMMEDIATELY,
+                intArrayOf(TRANSPORT_CELLULAR))
     }
 
     @Test
diff --git a/tests/cts/net/src/android/net/cts/RateLimitTest.java b/tests/cts/net/src/android/net/cts/RateLimitTest.java
index 36b98fc..5c93738 100644
--- a/tests/cts/net/src/android/net/cts/RateLimitTest.java
+++ b/tests/cts/net/src/android/net/cts/RateLimitTest.java
@@ -36,6 +36,7 @@
 import android.icu.text.MessageFormat;
 import android.net.ConnectivityManager;
 import android.net.ConnectivitySettingsManager;
+import android.net.ConnectivityThread;
 import android.net.InetAddresses;
 import android.net.IpPrefix;
 import android.net.LinkAddress;
@@ -189,7 +190,19 @@
             // whatever happens, don't leave the device in rate limited state.
             ConnectivitySettingsManager.setIngressRateLimitInBytesPerSecond(mContext, -1);
         }
-        if (mSocket != null) mSocket.close();
+        if (mSocket == null) {
+            // HACK(b/272147742): dump ConnectivityThread if test initialization failed.
+            final StackTraceElement[] elements = ConnectivityThread.get().getStackTrace();
+            final StringBuilder sb = new StringBuilder("ConnectivityThread stack:\n");
+            // Another thread's trace starts at its current frame; keep element 0 too.
+            for (final StackTraceElement element : elements) {
+                sb.append(element);
+                sb.append("\n");
+            }
+            Log.e(TAG, sb.toString());
+        } else {
+            mSocket.close();
+        }
         if (mNetworkAgent != null) mNetworkAgent.unregister();
         if (mTunInterface != null) mTunInterface.getFileDescriptor().close();
         if (mCm != null) mCm.unregisterNetworkCallback(mNetworkCallback);
diff --git a/tests/integration/OWNERS b/tests/integration/OWNERS
new file mode 100644
index 0000000..3101da5
--- /dev/null
+++ b/tests/integration/OWNERS
@@ -0,0 +1,2 @@
+# Bug template url: http://b/new?component=31808
+# TODO: move bug template config to common owners file once b/226427845 is resolved
\ No newline at end of file
diff --git a/tests/mts/OWNERS b/tests/mts/OWNERS
new file mode 100644
index 0000000..3101da5
--- /dev/null
+++ b/tests/mts/OWNERS
@@ -0,0 +1,2 @@
+# Bug template url: http://b/new?component=31808
+# TODO: move bug template config to common owners file once b/226427845 is resolved
\ No newline at end of file
diff --git a/tests/native/connectivity_native_test/OWNERS b/tests/native/connectivity_native_test/OWNERS
index 8dfa455..fbfcf92 100644
--- a/tests/native/connectivity_native_test/OWNERS
+++ b/tests/native/connectivity_native_test/OWNERS
@@ -1,3 +1,4 @@
-# Bug component: 31808
+# Bug template url: http://b/new?component=31808
+# TODO: move bug template config to common owners file once b/226427845 is resolved
 set noparent
 file:platform/packages/modules/Connectivity:master:/OWNERS_core_networking_xts
diff --git a/tests/unit/OWNERS b/tests/unit/OWNERS
new file mode 100644
index 0000000..3101da5
--- /dev/null
+++ b/tests/unit/OWNERS
@@ -0,0 +1,2 @@
+# Bug template url: http://b/new?component=31808
+# TODO: move bug template config to common owners file once b/226427845 is resolved
\ No newline at end of file
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 8de6a31..bbde9b4 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -2274,12 +2274,12 @@
             mDestroySocketsWrapper.destroyLiveTcpSocketsByOwnerUids(ownerUids);
         }
 
-        final ArrayTrackRecord<Long>.ReadHead mScheduledEvaluationTimeouts =
-                new ArrayTrackRecord<Long>().newReadHead();
+        final ArrayTrackRecord<Pair<Integer, Long>>.ReadHead mScheduledEvaluationTimeouts =
+                new ArrayTrackRecord<Pair<Integer, Long>>().newReadHead();
         @Override
         public void scheduleEvaluationTimeout(@NonNull Handler handler,
                 @NonNull final Network network, final long delayMs) {
-            mScheduledEvaluationTimeouts.add(delayMs);
+            mScheduledEvaluationTimeouts.add(new Pair<>(network.netId, delayMs));
             super.scheduleEvaluationTimeout(handler, network, delayMs);
         }
     }
@@ -2973,24 +2973,22 @@
         if (expectLingering) {
             generalCb.expectLosing(net1);
         }
+        generalCb.expectCaps(net2, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
+        defaultCb.expectAvailableDoubleValidatedCallbacks(net2);
 
         // Make sure cell 1 is unwanted immediately if the radio can't time share, but only
         // after some delay if it can.
         if (expectLingering) {
-            generalCb.expectCaps(net2, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
-            defaultCb.expectAvailableDoubleValidatedCallbacks(net2);
             net1.assertNotDisconnected(TEST_CALLBACK_TIMEOUT_MS); // always incurs the timeout
             generalCb.assertNoCallback();
             // assertNotDisconnected waited for TEST_CALLBACK_TIMEOUT_MS, so waiting for the
             // linger period gives TEST_CALLBACK_TIMEOUT_MS time for the event to process.
             net1.expectDisconnected(UNREASONABLY_LONG_ALARM_WAIT_MS);
-            generalCb.expect(LOST, net1);
         } else {
             net1.expectDisconnected(TEST_CALLBACK_TIMEOUT_MS);
-            generalCb.expect(LOST, net1);
-            generalCb.expectCaps(net2, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
-            defaultCb.expectAvailableDoubleValidatedCallbacks(net2);
         }
+        net1.disconnect();
+        generalCb.expect(LOST, net1);
 
         // Remove primary from net 2
         net2.setScore(new NetworkScore.Builder().build());
@@ -6153,7 +6151,7 @@
     }
 
     public void doTestPreferBadWifi(final boolean avoidBadWifi,
-            final boolean preferBadWifi,
+            final boolean preferBadWifi, final boolean explicitlySelected,
             @NonNull Predicate<Long> checkUnvalidationTimeout) throws Exception {
         // Pretend we're on a carrier that restricts switching away from bad wifi, and
         // depending on the parameter one that may indeed prefer bad wifi.
@@ -6177,10 +6175,13 @@
         mDefaultNetworkCallback.expectAvailableThenValidatedCallbacks(mCellAgent);
 
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
+        mWiFiAgent.explicitlySelected(explicitlySelected, false /* acceptUnvalidated */);
         mWiFiAgent.connect(false);
         wifiCallback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
 
-        mDeps.mScheduledEvaluationTimeouts.poll(TIMEOUT_MS, t -> checkUnvalidationTimeout.test(t));
+        assertNotNull(mDeps.mScheduledEvaluationTimeouts.poll(TIMEOUT_MS,
+                t -> t.first == mWiFiAgent.getNetwork().netId
+                        && checkUnvalidationTimeout.test(t.second)));
 
         if (!avoidBadWifi && preferBadWifi) {
             expectUnvalidationCheckWillNotify(mWiFiAgent, NotificationType.LOST_INTERNET);
@@ -6196,27 +6197,33 @@
         // Starting with U this mode is no longer supported and can't actually be tested
         assumeFalse(mDeps.isAtLeastU());
         doTestPreferBadWifi(false /* avoidBadWifi */, false /* preferBadWifi */,
-                timeout -> timeout < 14_000);
+                false /* explicitlySelected */, timeout -> timeout < 14_000);
     }
 
     @Test
-    public void testPreferBadWifi_doNotAvoid_doPrefer() throws Exception {
+    public void testPreferBadWifi_doNotAvoid_doPrefer_notExplicit() throws Exception {
         doTestPreferBadWifi(false /* avoidBadWifi */, true /* preferBadWifi */,
-                timeout -> timeout > 14_000);
+                false /* explicitlySelected */, timeout -> timeout > 14_000);
+    }
+
+    @Test
+    public void testPreferBadWifi_doNotAvoid_doPrefer_explicitlySelected() throws Exception {
+        doTestPreferBadWifi(false /* avoidBadWifi */, true /* preferBadWifi */,
+                true /* explicitlySelected */, timeout -> timeout < 14_000);
     }
 
     @Test
     public void testPreferBadWifi_doAvoid_doNotPrefer() throws Exception {
         // If avoidBadWifi=true, then preferBadWifi should be irrelevant. Test anyway.
         doTestPreferBadWifi(true /* avoidBadWifi */, false /* preferBadWifi */,
-                timeout -> timeout < 14_000);
+                false /* explicitlySelected */, timeout -> timeout < 14_000);
     }
 
     @Test
     public void testPreferBadWifi_doAvoid_doPrefer() throws Exception {
         // If avoidBadWifi=true, then preferBadWifi should be irrelevant. Test anyway.
         doTestPreferBadWifi(true /* avoidBadWifi */, true /* preferBadWifi */,
-                timeout -> timeout < 14_000);
+                false /* explicitlySelected */, timeout -> timeout < 14_000);
     }
 
     @Test
@@ -6850,17 +6857,19 @@
 
     @Test
     public void testPacketKeepalives() throws Exception {
-        InetAddress myIPv4 = InetAddress.getByName("192.0.2.129");
+        final LinkAddress v4Addr = new LinkAddress("192.0.2.129/24");
+        final InetAddress myIPv4 = v4Addr.getAddress();
         InetAddress notMyIPv4 = InetAddress.getByName("192.0.2.35");
         InetAddress myIPv6 = InetAddress.getByName("2001:db8::1");
         InetAddress dstIPv4 = InetAddress.getByName("8.8.8.8");
         InetAddress dstIPv6 = InetAddress.getByName("2001:4860:4860::8888");
-
+        doReturn(getClatInterfaceConfigParcel(v4Addr)).when(mMockNetd)
+                .interfaceGetCfg(CLAT_MOBILE_IFNAME);
         final int validKaInterval = 15;
         final int invalidKaInterval = 9;
 
         LinkProperties lp = new LinkProperties();
-        lp.setInterfaceName("wlan12");
+        lp.setInterfaceName(MOBILE_IFNAME);
         lp.addLinkAddress(new LinkAddress(myIPv6, 64));
         lp.addLinkAddress(new LinkAddress(myIPv4, 25));
         lp.addRoute(new RouteInfo(InetAddress.getByName("fe80::1234")));
diff --git a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
index 9e604e3..eeffbe1 100644
--- a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
@@ -52,6 +52,7 @@
 import android.content.res.Resources;
 import android.net.INetd;
 import android.net.ISocketKeepaliveCallback;
+import android.net.InetAddresses;
 import android.net.KeepalivePacketData;
 import android.net.LinkAddress;
 import android.net.LinkProperties;
@@ -116,7 +117,8 @@
     private static final int MOCK_RESOURCE_ID = 5;
     private static final int TEST_KEEPALIVE_INTERVAL_SEC = 10;
     private static final int TEST_KEEPALIVE_INVALID_INTERVAL_SEC = 9;
-
+    private static final byte[] V4_SRC_ADDR = new byte[] { (byte) 192, 0, 0, (byte) 129 };
+    private static final String TEST_V4_IFACE = "v4-testIface";
     private AutomaticOnOffKeepaliveTracker mAOOKeepaliveTracker;
     private HandlerThread mHandlerThread;
 
@@ -327,6 +329,8 @@
                 NetworkInfo.DetailedState.CONNECTED, "test reason", "test extra info");
         doReturn(new Network(TEST_NETID)).when(mNai).network();
         mNai.linkProperties = new LinkProperties();
+        doReturn(null).when(mNai).translateV4toClatV6(any());
+        doReturn(null).when(mNai).getClatv6SrcAddress();
 
         doReturn(PERMISSION_GRANTED).when(mCtx).checkPermission(any() /* permission */,
                 anyInt() /* pid */, anyInt() /* uid */);
@@ -429,8 +433,7 @@
     }
 
     private TestKeepaliveInfo doStartNattKeepalive(int intervalSeconds) throws Exception {
-        final InetAddress srcAddress = InetAddress.getByAddress(
-                new byte[] { (byte) 192, 0, 0, (byte) 129 });
+        final InetAddress srcAddress = InetAddress.getByAddress(V4_SRC_ADDR);
         final int srcPort = 12345;
         final InetAddress dstAddress = InetAddress.getByAddress(new byte[] {8, 8, 8, 8});
         final int dstPort = 12345;
@@ -610,6 +613,42 @@
     }
 
     @Test
+    public void testStartNattKeepalive_addressTranslationOnClat() throws Exception {
+        final InetAddress v6AddrSrc = InetAddresses.parseNumericAddress("2001:db8::1");
+        final InetAddress v6AddrDst = InetAddresses.parseNumericAddress("2001:db8::2");
+        doReturn(v6AddrDst).when(mNai).translateV4toClatV6(any());
+        doReturn(v6AddrSrc).when(mNai).getClatv6SrcAddress();
+        doReturn(InetAddress.getByAddress(V4_SRC_ADDR)).when(mNai).getClatv4SrcAddress();
+        // Setup nai to add clat address
+        final LinkProperties stacked = new LinkProperties();
+        stacked.setInterfaceName(TEST_V4_IFACE);
+        mNai.linkProperties.addStackedLink(stacked);
+
+        final TestKeepaliveInfo testInfo = doStartNattKeepalive();
+        final ArgumentCaptor<NattKeepalivePacketData> kpdCaptor =
+                ArgumentCaptor.forClass(NattKeepalivePacketData.class);
+        verify(mNai).onStartNattSocketKeepalive(
+                eq(TEST_SLOT), eq(TEST_KEEPALIVE_INTERVAL_SEC), kpdCaptor.capture());
+        final NattKeepalivePacketData kpd = kpdCaptor.getValue();
+        // Verify the addresses are updated to v6 when clat is started.
+        assertEquals(v6AddrSrc, kpd.getSrcAddress());
+        assertEquals(v6AddrDst, kpd.getDstAddress());
+
+        triggerEventKeepalive(TEST_SLOT, SocketKeepalive.SUCCESS);
+        verify(testInfo.socketKeepaliveCallback).onStarted();
+
+        // Remove clat address should stop the keepalive.
+        doReturn(null).when(mNai).getClatv6SrcAddress();
+        visibleOnHandlerThread(
+                mTestHandler, () -> mAOOKeepaliveTracker.handleCheckKeepalivesStillValid(mNai));
+        checkAndProcessKeepaliveStop();
+        assertNull(getAutoKiForBinder(testInfo.binder));
+
+        verify(testInfo.socketKeepaliveCallback).onError(SocketKeepalive.ERROR_INVALID_IP_ADDRESS);
+        verifyNoMoreInteractions(ignoreStubs(testInfo.socketKeepaliveCallback));
+    }
+
+    @Test
     public void testHandleEventSocketKeepalive_startingFailureHardwareError() throws Exception {
         final TestKeepaliveInfo testInfo = doStartNattKeepalive();
 
diff --git a/tests/unit/java/com/android/server/connectivity/KeepaliveStatsTrackerTest.java b/tests/unit/java/com/android/server/connectivity/KeepaliveStatsTrackerTest.java
index 0d2e540..0d1b548 100644
--- a/tests/unit/java/com/android/server/connectivity/KeepaliveStatsTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/KeepaliveStatsTrackerTest.java
@@ -19,6 +19,8 @@
 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
 import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
 
+import static com.android.testutils.DevSdkIgnoreRule.IgnoreAfter;
+import static com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;
 import static com.android.testutils.HandlerUtils.visibleOnHandlerThread;
 
 import static org.junit.Assert.assertArrayEquals;
@@ -31,6 +33,7 @@
 import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.verify;
 
 import android.content.BroadcastReceiver;
@@ -62,6 +65,7 @@
 import com.android.testutils.HandlerUtils;
 
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
@@ -103,6 +107,8 @@
                 .build();
     }
 
+    @Rule public final DevSdkIgnoreRule ignoreRule = new DevSdkIgnoreRule();
+
     private HandlerThread mHandlerThread;
     private Handler mTestHandler;
 
@@ -1192,4 +1198,43 @@
                     expectKeepaliveCarrierStats1, expectKeepaliveCarrierStats2
                 });
     }
+
+    @Test
+    @IgnoreAfter(Build.VERSION_CODES.S_V2)
+    public void testWriteMetrics_doNothingBeforeT() {
+        // Keepalive stats use repeated atoms, which are only supported on T+. If written to statsd
+        // on S- they will bootloop the system, so they must not be sent on S-. See b/289471411.
+        final int writeTime = 1000;
+        setElapsedRealtime(writeTime);
+        visibleOnHandlerThread(mTestHandler, () -> mKeepaliveStatsTracker.writeAndResetMetrics());
+        verify(mDependencies, never()).writeStats(any());
+    }
+
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.S_V2)
+    public void testWriteMetrics() {
+        final int writeTime = 1000;
+
+        final ArgumentCaptor<DailykeepaliveInfoReported> dailyKeepaliveInfoReportedCaptor =
+                ArgumentCaptor.forClass(DailykeepaliveInfoReported.class);
+
+        setElapsedRealtime(writeTime);
+        visibleOnHandlerThread(mTestHandler, () -> mKeepaliveStatsTracker.writeAndResetMetrics());
+        // Ensure writeStats is called with the correct DailykeepaliveInfoReported metrics.
+        verify(mDependencies).writeStats(dailyKeepaliveInfoReportedCaptor.capture());
+        final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
+                dailyKeepaliveInfoReportedCaptor.getValue();
+
+        // No keepalive was started, so the values match the no keepalive case.
+        final int[] expectRegisteredDurations = new int[] {writeTime};
+        final int[] expectActiveDurations = new int[] {writeTime};
+        assertDailyKeepaliveInfoReported(
+                dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 0,
+                /* expectAutoRequestsCount= */ 0,
+                /* expectAppUids= */ new int[0],
+                expectRegisteredDurations,
+                expectActiveDurations,
+                new KeepaliveCarrierStats[0]);
+    }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/Nat464XlatTest.java b/tests/unit/java/com/android/server/connectivity/Nat464XlatTest.java
index 06e0d6d..58c0114 100644
--- a/tests/unit/java/com/android/server/connectivity/Nat464XlatTest.java
+++ b/tests/unit/java/com/android/server/connectivity/Nat464XlatTest.java
@@ -20,10 +20,12 @@
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.anyInt;
 import static org.mockito.Mockito.anyString;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.never;
@@ -72,6 +74,7 @@
     static final String STACKED_IFACE = "v4-test0";
     static final LinkAddress V6ADDR = new LinkAddress("2001:db8:1::f00/64");
     static final LinkAddress ADDR = new LinkAddress("192.0.2.5/29");
+    static final String CLAT_V6 = "64:ff9b::1";
     static final String NAT64_PREFIX = "64:ff9b::/96";
     static final String OTHER_NAT64_PREFIX = "2001:db8:0:64::/96";
     static final int NETID = 42;
@@ -132,6 +135,8 @@
         when(mNetd.interfaceGetCfg(eq(STACKED_IFACE))).thenReturn(mConfig);
         mConfig.ipv4Addr = ADDR.getAddress().getHostAddress();
         mConfig.prefixLength =  ADDR.getPrefixLength();
+        doReturn(CLAT_V6).when(mClatCoordinator).clatStart(
+                BASE_IFACE, NETID, new IpPrefix(NAT64_PREFIX));
     }
 
     private void assertRequiresClat(boolean expected, NetworkAgentInfo nai) {
@@ -286,7 +291,8 @@
         assertFalse(c.getValue().getAllInterfaceNames().contains(STACKED_IFACE));
         verify(mDnsResolver).stopPrefix64Discovery(eq(NETID));
         assertIdle(nat);
-
+        // Verify mIPv6Address is cleared when clat is stopped.
+        assertNull(nat.mIPv6Address);
         // Stacked interface removed notification arrives and is ignored.
         nat.interfaceRemoved(STACKED_IFACE);
         mLooper.dispatchNext();
diff --git a/tests/unit/java/com/android/server/connectivity/NetworkNotificationManagerTest.java b/tests/unit/java/com/android/server/connectivity/NetworkNotificationManagerTest.java
index a27a0bf..967083e 100644
--- a/tests/unit/java/com/android/server/connectivity/NetworkNotificationManagerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/NetworkNotificationManagerTest.java
@@ -62,6 +62,7 @@
 import android.util.DisplayMetrics;
 import android.widget.TextView;
 
+import androidx.annotation.NonNull;
 import androidx.annotation.Nullable;
 import androidx.annotation.StringRes;
 import androidx.test.filters.SmallTest;
@@ -386,14 +387,37 @@
     }
 
     @Test
-    public void testNotifyNoInternetAsDialogWhenHighPriority() throws Exception {
-        doReturn(true).when(mResources).getBoolean(
+    public void testNotifyNoInternet_asNotification() throws Exception {
+        doTestNotifyNotificationAsDialogWhenHighPriority(false, NO_INTERNET);
+    }
+    @Test
+    public void testNotifyNoInternet_asDialog() throws Exception {
+        doTestNotifyNotificationAsDialogWhenHighPriority(true, NO_INTERNET);
+    }
+
+    @Test
+    public void testNotifyLostInternet_asNotification() throws Exception {
+        doTestNotifyNotificationAsDialogWhenHighPriority(false, LOST_INTERNET);
+    }
+
+    @Test
+    public void testNotifyLostInternet_asDialog() throws Exception {
+        doTestNotifyNotificationAsDialogWhenHighPriority(true, LOST_INTERNET);
+    }
+
+    private void doTestNotifyNotificationAsDialogWhenHighPriority(final boolean configActive,
+            @NonNull final NotificationType notifType) throws Exception {
+        doReturn(configActive).when(mResources).getBoolean(
                 R.bool.config_notifyNoInternetAsDialogWhenHighPriority);
 
         final Instrumentation instr = InstrumentationRegistry.getInstrumentation();
         final UiDevice uiDevice =  UiDevice.getInstance(instr);
         final Context ctx = instr.getContext();
         final PowerManager pm = ctx.getSystemService(PowerManager.class);
+        // If this notif's priority is below NETWORK_SWITCH's, it is the lowest-priority one and
+        // there is no lower-priority notif to test cancellation against, so skip those checks.
+        final boolean isLowestPrioNotif = NetworkNotificationManager.priority(notifType)
+                < NetworkNotificationManager.priority(NETWORK_SWITCH);
 
         // Wake up the device (it has no effect if the device is already awake).
         uiDevice.executeShellCommand("input keyevent KEYCODE_WAKEUP");
@@ -409,9 +433,13 @@
                 uiDevice.wait(Until.hasObject(By.pkg(launcherPackageName)),
                         UI_AUTOMATOR_WAIT_TIME_MILLIS));
 
-        mManager.showNotification(TEST_NOTIF_ID, NETWORK_SWITCH, mWifiNai, mCellNai, null, false);
-        // Non-"no internet" notifications are not affected
-        verify(mNotificationManager).notify(eq(TEST_NOTIF_TAG), eq(NETWORK_SWITCH.eventId), any());
+        if (!isLowestPrioNotif) {
+            mManager.showNotification(TEST_NOTIF_ID, NETWORK_SWITCH, mWifiNai, mCellNai,
+                    null, false);
+            // Non-"no internet" notifications are not affected
+            verify(mNotificationManager).notify(eq(TEST_NOTIF_TAG), eq(NETWORK_SWITCH.eventId),
+                    any());
+        }
 
         final String testAction = "com.android.connectivity.coverage.TEST_DIALOG";
         final Intent intent = new Intent(testAction)
@@ -420,22 +448,30 @@
         final PendingIntent pendingIntent = PendingIntent.getActivity(ctx, 0 /* requestCode */,
                 intent, PendingIntent.FLAG_CANCEL_CURRENT | PendingIntent.FLAG_IMMUTABLE);
 
-        mManager.showNotification(TEST_NOTIF_ID, NO_INTERNET, mWifiNai, null /* switchToNai */,
+        mManager.showNotification(TEST_NOTIF_ID, notifType, mWifiNai, null /* switchToNai */,
                 pendingIntent, true /* highPriority */);
 
-        // Previous notifications are still dismissed
-        verify(mNotificationManager).cancel(TEST_NOTIF_TAG, NETWORK_SWITCH.eventId);
+        if (!isLowestPrioNotif) {
+            // Previous notifications are still dismissed
+            verify(mNotificationManager).cancel(TEST_NOTIF_TAG, NETWORK_SWITCH.eventId);
+        }
 
-        // Verify that the activity is shown (the activity shows the action on screen)
-        final UiObject actionText = uiDevice.findObject(new UiSelector().text(testAction));
-        assertTrue("Activity not shown", actionText.waitForExists(TEST_TIMEOUT_MS));
+        if (configActive) {
+            // Verify that the activity is shown (the activity shows the action on screen)
+            final UiObject actionText = uiDevice.findObject(new UiSelector().text(testAction));
+            assertTrue("Activity not shown", actionText.waitForExists(TEST_TIMEOUT_MS));
 
-        // Tapping the text should dismiss the dialog
-        actionText.click();
-        assertTrue("Activity not dismissed", actionText.waitUntilGone(TEST_TIMEOUT_MS));
+            // Tapping the text should dismiss the dialog
+            actionText.click();
+            assertTrue("Activity not dismissed", actionText.waitUntilGone(TEST_TIMEOUT_MS));
 
-        // Verify no NO_INTERNET notification was posted
-        verify(mNotificationManager, never()).notify(any(), eq(NO_INTERNET.eventId), any());
+            // Verify that the notification was not posted
+            verify(mNotificationManager, never()).notify(any(), eq(notifType.eventId), any());
+        } else {
+            // The notification should have been posted with TEST_NOTIF_TAG. Cancellation of the
+            // earlier NETWORK_SWITCH notification (when one was shown) is verified above.
+            verify(mNotificationManager).notify(eq(TEST_NOTIF_TAG), eq(notifType.eventId), any());
+        }
     }
 
     private void doNotificationTextTest(NotificationType type, @StringRes int expectedTitleRes,
diff --git a/tests/unit/java/com/android/server/connectivity/VpnTest.java b/tests/unit/java/com/android/server/connectivity/VpnTest.java
index dc50773..7829cb6 100644
--- a/tests/unit/java/com/android/server/connectivity/VpnTest.java
+++ b/tests/unit/java/com/android/server/connectivity/VpnTest.java
@@ -2768,23 +2768,30 @@
                 new PersistableBundle());
     }
 
-    private void verifyMobikeTriggered(List<Network> expected) {
+    private void verifyMobikeTriggered(List<Network> expected, int retryIndex) {
+        // Verify the (retryIndex + 1)-th recovery was scheduled with the expected delay.
+        final long expectedDelayMs = mTestDeps.getValidationFailRecoveryMs(retryIndex);
+        final ArgumentCaptor<Long> delayCaptor = ArgumentCaptor.forClass(Long.class);
+        verify(mExecutor, times(retryIndex + 1)).schedule(
+                any(Runnable.class), delayCaptor.capture(), eq(TimeUnit.MILLISECONDS));
+        final List<Long> delays = delayCaptor.getAllValues();
+        assertEquals(expectedDelayMs, (long) delays.get(delays.size() - 1));
+
         final ArgumentCaptor<Network> networkCaptor = ArgumentCaptor.forClass(Network.class);
-        verify(mIkeSessionWrapper).setNetwork(networkCaptor.capture(),
-                anyInt() /* ipVersion */, anyInt() /* encapType */, anyInt() /* keepaliveDelay */);
+        verify(mIkeSessionWrapper, timeout(TEST_TIMEOUT_MS + expectedDelayMs))
+                .setNetwork(networkCaptor.capture(), anyInt() /* ipVersion */,
+                        anyInt() /* encapType */, anyInt() /* keepaliveDelay */);
         assertEquals(expected, Collections.singletonList(networkCaptor.getValue()));
     }
 
     @Test
     public void testDataStallInIkev2VpnMobikeDisabled() throws Exception {
-        verifySetupPlatformVpn(
+        final PlatformVpnSnapshot vpnSnapShot = verifySetupPlatformVpn(
                 createIkeConfig(createIkeConnectInfo(), false /* isMobikeEnabled */));
 
         doReturn(TEST_NETWORK).when(mMockNetworkAgent).getNetwork();
-        final ConnectivityDiagnosticsCallback connectivityDiagCallback =
-                getConnectivityDiagCallback();
-        final DataStallReport report = createDataStallReport();
-        connectivityDiagCallback.onDataStallSuspected(report);
+        ((Vpn.IkeV2VpnRunner) vpnSnapShot.vpn.mVpnRunner).onValidationStatus(
+                NetworkAgent.VALIDATION_STATUS_NOT_VALID);
 
         // Should not trigger MOBIKE if MOBIKE is not enabled
         verify(mIkeSessionWrapper, never()).setNetwork(any() /* network */,
@@ -2797,19 +2804,11 @@
                 createIkeConfig(createIkeConnectInfo(), true /* isMobikeEnabled */));
 
         doReturn(TEST_NETWORK).when(mMockNetworkAgent).getNetwork();
-        final ConnectivityDiagnosticsCallback connectivityDiagCallback =
-                getConnectivityDiagCallback();
-        final DataStallReport report = createDataStallReport();
-        connectivityDiagCallback.onDataStallSuspected(report);
-
+        ((Vpn.IkeV2VpnRunner) vpnSnapShot.vpn.mVpnRunner).onValidationStatus(
+                NetworkAgent.VALIDATION_STATUS_NOT_VALID);
         // Verify MOBIKE is triggered
-        verifyMobikeTriggered(vpnSnapShot.vpn.mNetworkCapabilities.getUnderlyingNetworks());
-
-        // Expect to skip other data stall event if MOBIKE was started.
-        reset(mIkeSessionWrapper);
-        connectivityDiagCallback.onDataStallSuspected(report);
-        verify(mIkeSessionWrapper, never()).setNetwork(any() /* network */,
-                anyInt() /* ipVersion */, anyInt() /* encapType */, anyInt() /* keepaliveDelay */);
+        verifyMobikeTriggered(vpnSnapShot.vpn.mNetworkCapabilities.getUnderlyingNetworks(),
+                0 /* retryIndex */);
 
         reset(mIkev2SessionCreator);
 
@@ -2819,14 +2818,6 @@
                 NetworkAgent.VALIDATION_STATUS_VALID);
         verify(mIkev2SessionCreator, never()).createIkeSession(
                 any(), any(), any(), any(), any(), any());
-
-        // Send invalid result to verify no ike session reset since the data stall suspected
-        // variables(timer counter and boolean) was reset.
-        ((Vpn.IkeV2VpnRunner) vpnSnapShot.vpn.mVpnRunner).onValidationStatus(
-                NetworkAgent.VALIDATION_STATUS_NOT_VALID);
-        verify(mExecutor, atLeastOnce()).schedule(any(Runnable.class), anyLong(), any());
-        verify(mIkev2SessionCreator, never()).createIkeSession(
-                any(), any(), any(), any(), any(), any());
     }
 
     @Test
@@ -2834,31 +2825,46 @@
         final PlatformVpnSnapshot vpnSnapShot = verifySetupPlatformVpn(
                 createIkeConfig(createIkeConnectInfo(), true /* isMobikeEnabled */));
 
-        final ConnectivityDiagnosticsCallback connectivityDiagCallback =
-                getConnectivityDiagCallback();
-
+        int retry = 0;
         doReturn(TEST_NETWORK).when(mMockNetworkAgent).getNetwork();
-        final DataStallReport report = createDataStallReport();
-        connectivityDiagCallback.onDataStallSuspected(report);
-
-        verifyMobikeTriggered(vpnSnapShot.vpn.mNetworkCapabilities.getUnderlyingNetworks());
+        ((Vpn.IkeV2VpnRunner) vpnSnapShot.vpn.mVpnRunner).onValidationStatus(
+                NetworkAgent.VALIDATION_STATUS_NOT_VALID);
+        verifyMobikeTriggered(vpnSnapShot.vpn.mNetworkCapabilities.getUnderlyingNetworks(),
+                retry++);
 
         reset(mIkev2SessionCreator);
 
+        // Second validation status update.
+        ((Vpn.IkeV2VpnRunner) vpnSnapShot.vpn.mVpnRunner).onValidationStatus(
+                NetworkAgent.VALIDATION_STATUS_NOT_VALID);
+        verifyMobikeTriggered(vpnSnapShot.vpn.mNetworkCapabilities.getUnderlyingNetworks(),
+                retry++);
+
+        // Use a real delay to verify that the session reset is not scheduled again while an
+        // existing recovery for resetting the session is still pending.
+        mExecutor.delayMs = TestExecutor.REAL_DELAY;
+        mExecutor.executeDirect = true;
         // Send validation status update should result in ike session reset.
         ((Vpn.IkeV2VpnRunner) vpnSnapShot.vpn.mVpnRunner).onValidationStatus(
                 NetworkAgent.VALIDATION_STATUS_NOT_VALID);
 
-        // Verify reset is scheduled and run.
-        verify(mExecutor, atLeastOnce()).schedule(any(Runnable.class), anyLong(), any());
+        // Verify session reset is scheduled
+        long expectedDelay = mTestDeps.getValidationFailRecoveryMs(retry++);
+        final ArgumentCaptor<Long> delayCaptor = ArgumentCaptor.forClass(Long.class);
+        verify(mExecutor, times(retry)).schedule(any(Runnable.class), delayCaptor.capture(),
+                eq(TimeUnit.MILLISECONDS));
+        final List<Long> delays = delayCaptor.getAllValues();
+        assertEquals(expectedDelay, (long) delays.get(delays.size() - 1));
 
         // Another invalid status reported should not trigger other scheduled recovery.
-        reset(mExecutor);
+        expectedDelay = mTestDeps.getValidationFailRecoveryMs(retry++);
         ((Vpn.IkeV2VpnRunner) vpnSnapShot.vpn.mVpnRunner).onValidationStatus(
                 NetworkAgent.VALIDATION_STATUS_NOT_VALID);
-        verify(mExecutor, never()).schedule(any(Runnable.class), anyLong(), any());
+        verify(mExecutor, never()).schedule(
+                any(Runnable.class), eq(expectedDelay), eq(TimeUnit.MILLISECONDS));
 
-        verify(mIkev2SessionCreator, timeout(TEST_TIMEOUT_MS))
+        // Verify that the session is reset.
+        verify(mIkev2SessionCreator, timeout(TEST_TIMEOUT_MS + expectedDelay))
                 .createIkeSession(any(), any(), any(), any(), any(), any());
     }
 
@@ -3137,6 +3143,12 @@
         }
 
         @Override
+        public long getValidationFailRecoveryMs(int retryCount) {
+            // Use retryCount * 100 ms as the delay before retrying.
+            return retryCount * 100L;
+        }
+
+        @Override
         public ScheduledThreadPoolExecutor newScheduledThreadPoolExecutor() {
             return mExecutor;
         }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
index c467f45..6a0334f 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
@@ -56,8 +56,8 @@
 private val TEST_ADDR = parseNumericAddress("2001:db8::123")
 private val TEST_LINKADDR = LinkAddress(TEST_ADDR, 64 /* prefixLength */)
 private val TEST_NETWORK_1 = mock(Network::class.java)
-private val TEST_SOCKETKEY_1 = mock(SocketKey::class.java)
-private val TEST_SOCKETKEY_2 = mock(SocketKey::class.java)
+private val TEST_SOCKETKEY_1 = SocketKey(1001 /* interfaceIndex */)
+private val TEST_SOCKETKEY_2 = SocketKey(1002 /* interfaceIndex */)
 private val TEST_HOSTNAME = arrayOf("Android_test", "local")
 private const val TEST_SUBTYPE = "_subtype"
 
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
index d2298fe..1a4ae5d 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
@@ -19,7 +19,6 @@
 import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
 
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.never;
@@ -212,31 +211,31 @@
         runOnHandler(() -> discoveryManager.onResponseReceived(
                 responseForServiceTypeOne, SOCKET_KEY_NULL_NETWORK));
         // Packets for network null are only processed by the ServiceTypeClient for network null
-        verify(mockServiceTypeClientType1NullNetwork).processResponse(responseForServiceTypeOne,
-                SOCKET_KEY_NULL_NETWORK.getInterfaceIndex(), SOCKET_KEY_NULL_NETWORK.getNetwork());
-        verify(mockServiceTypeClientType1Network1, never()).processResponse(any(), anyInt(), any());
-        verify(mockServiceTypeClientType2Network2, never()).processResponse(any(), anyInt(), any());
+        verify(mockServiceTypeClientType1NullNetwork).processResponse(
+                responseForServiceTypeOne, SOCKET_KEY_NULL_NETWORK);
+        verify(mockServiceTypeClientType1Network1, never()).processResponse(any(), any());
+        verify(mockServiceTypeClientType2Network2, never()).processResponse(any(), any());
 
         final MdnsPacket responseForServiceTypeTwo = createMdnsPacket(SERVICE_TYPE_2);
         runOnHandler(() -> discoveryManager.onResponseReceived(
                 responseForServiceTypeTwo, SOCKET_KEY_NETWORK_1));
-        verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(), anyInt(),
-                eq(SOCKET_KEY_NETWORK_1.getNetwork()));
-        verify(mockServiceTypeClientType1Network1).processResponse(responseForServiceTypeTwo,
-                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
-        verify(mockServiceTypeClientType2Network2, never()).processResponse(any(), anyInt(),
-                eq(SOCKET_KEY_NETWORK_1.getNetwork()));
+        verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(),
+                eq(SOCKET_KEY_NETWORK_1));
+        verify(mockServiceTypeClientType1Network1).processResponse(
+                responseForServiceTypeTwo, SOCKET_KEY_NETWORK_1);
+        verify(mockServiceTypeClientType2Network2, never()).processResponse(any(),
+                eq(SOCKET_KEY_NETWORK_1));
 
         final MdnsPacket responseForSubtype =
                 createMdnsPacket("subtype._sub._googlecast._tcp.local");
         runOnHandler(() -> discoveryManager.onResponseReceived(
                 responseForSubtype, SOCKET_KEY_NETWORK_2));
-        verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(), anyInt(),
-                eq(SOCKET_KEY_NETWORK_2.getNetwork()));
-        verify(mockServiceTypeClientType1Network1, never()).processResponse(any(), anyInt(),
-                eq(SOCKET_KEY_NETWORK_2.getNetwork()));
-        verify(mockServiceTypeClientType2Network2).processResponse(responseForSubtype,
-                SOCKET_KEY_NETWORK_2.getInterfaceIndex(), SOCKET_KEY_NETWORK_2.getNetwork());
+        verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(),
+                eq(SOCKET_KEY_NETWORK_2));
+        verify(mockServiceTypeClientType1Network1, never()).processResponse(any(),
+                eq(SOCKET_KEY_NETWORK_2));
+        verify(mockServiceTypeClientType2Network2).processResponse(
+                responseForSubtype, SOCKET_KEY_NETWORK_2);
     }
 
     @Test
@@ -260,15 +259,13 @@
         // Receive a response, it should be processed on both clients.
         final MdnsPacket response = createMdnsPacket(SERVICE_TYPE_1);
         runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1));
-        verify(mockServiceTypeClientType1Network1).processResponse(response,
-                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
-        verify(mockServiceTypeClientType2Network1).processResponse(response,
-                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
+        verify(mockServiceTypeClientType1Network1).processResponse(response, SOCKET_KEY_NETWORK_1);
+        verify(mockServiceTypeClientType2Network1).processResponse(response, SOCKET_KEY_NETWORK_1);
 
         // The first callback receives a notification that the network has been destroyed,
         // mockServiceTypeClientOne1 should send service removed notifications and remove from the
         // list of clients.
-        runOnHandler(() -> callback.onAllSocketsDestroyed(SOCKET_KEY_NETWORK_1));
+        runOnHandler(() -> callback.onSocketDestroyed(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1Network1).notifySocketDestroyed();
 
         // Receive a response again, it should be processed only on
@@ -276,23 +273,23 @@
         // removed from the list of clients, it is no longer able to process responses.
         runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1));
         // Still times(1) as a response was received once previously
-        verify(mockServiceTypeClientType1Network1, times(1)).processResponse(response,
-                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
-        verify(mockServiceTypeClientType2Network1, times(2)).processResponse(response,
-                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
+        verify(mockServiceTypeClientType1Network1, times(1)).processResponse(
+                response, SOCKET_KEY_NETWORK_1);
+        verify(mockServiceTypeClientType2Network1, times(2)).processResponse(
+                response, SOCKET_KEY_NETWORK_1);
 
         // The client for NETWORK_1 receives the callback that the NETWORK_2 has been destroyed,
         // mockServiceTypeClientTwo2 shouldn't send any notifications.
-        runOnHandler(() -> callback2.onAllSocketsDestroyed(SOCKET_KEY_NETWORK_2));
+        runOnHandler(() -> callback2.onSocketDestroyed(SOCKET_KEY_NETWORK_2));
         verify(mockServiceTypeClientType2Network1, never()).notifySocketDestroyed();
 
         // Receive a response again, mockServiceTypeClientType2Network1 is still in the list of
         // clients, it's still able to process responses.
         runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1));
-        verify(mockServiceTypeClientType1Network1, times(1)).processResponse(response,
-                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
-        verify(mockServiceTypeClientType2Network1, times(3)).processResponse(response,
-                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
+        verify(mockServiceTypeClientType1Network1, times(1)).processResponse(
+                response, SOCKET_KEY_NETWORK_1);
+        verify(mockServiceTypeClientType2Network1, times(3)).processResponse(
+                response, SOCKET_KEY_NETWORK_1);
     }
 
     @Test
@@ -310,17 +307,17 @@
         final MdnsPacket response = createMdnsPacket(SERVICE_TYPE_1);
         final int ifIndex = 1;
         runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NULL_NETWORK));
-        verify(mockServiceTypeClientType1NullNetwork).processResponse(response,
-                SOCKET_KEY_NULL_NETWORK.getInterfaceIndex(), SOCKET_KEY_NULL_NETWORK.getNetwork());
+        verify(mockServiceTypeClientType1NullNetwork).processResponse(
+                response, SOCKET_KEY_NULL_NETWORK);
 
-        runOnHandler(() -> callback.onAllSocketsDestroyed(SOCKET_KEY_NULL_NETWORK));
+        runOnHandler(() -> callback.onSocketDestroyed(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).notifySocketDestroyed();
 
         // Receive a response again, it should not be processed.
         runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NULL_NETWORK));
         // Still times(1) as a response was received once previously
-        verify(mockServiceTypeClientType1NullNetwork, times(1)).processResponse(response,
-                SOCKET_KEY_NULL_NETWORK.getInterfaceIndex(), SOCKET_KEY_NULL_NETWORK.getNetwork());
+        verify(mockServiceTypeClientType1NullNetwork, times(1)).processResponse(
+                response, SOCKET_KEY_NULL_NETWORK);
 
         // Unregister the listener, notifyNetworkUnrequested should be called but other stop methods
         // won't be call because the service type client was unregistered and destroyed. But those
@@ -329,7 +326,7 @@
         verify(socketClient).notifyNetworkUnrequested(mockListenerOne);
         verify(mockServiceTypeClientType1NullNetwork, never()).stopSendAndReceive(any());
         // The stopDiscovery() is only used by MdnsSocketClient, which doesn't send
-        // onAllSocketsDestroyed(). So the socket clients that send onAllSocketsDestroyed() do not
+        // onSocketDestroyed(). Socket clients that do send onSocketDestroyed() therefore do not
         // need to call stopDiscovery().
         verify(socketClient, never()).stopDiscovery();
     }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
index b812fa6..29de272 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
@@ -28,7 +28,6 @@
 import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyNoMoreInteractions;
 
 import android.net.InetAddresses;
 import android.net.Network;
@@ -67,18 +66,18 @@
     @Mock private MdnsServiceBrowserListener mListener;
     @Mock private MdnsSocketClientBase.Callback mCallback;
     @Mock private SocketCreationCallback mSocketCreationCallback;
-    @Mock private SocketKey mSocketKey;
     private MdnsMultinetworkSocketClient mSocketClient;
     private Handler mHandler;
+    private SocketKey mSocketKey;
 
     @Before
     public void setUp() throws SocketException {
         MockitoAnnotations.initMocks(this);
-        doReturn(mNetwork).when(mSocketKey).getNetwork();
 
         final HandlerThread thread = new HandlerThread("MdnsMultinetworkSocketClientTest");
         thread.start();
         mHandler = new Handler(thread.getLooper());
+        mSocketKey = new SocketKey(1000 /* interfaceIndex */);
         mSocketClient = new MdnsMultinetworkSocketClient(thread.getLooper(), mProvider);
         mHandler.post(() -> mSocketClient.setCallback(mCallback));
     }
@@ -125,10 +124,8 @@
             doReturn(createEmptyNetworkInterface()).when(socket).getInterface();
         }
 
-        final SocketKey tetherSocketKey1 = mock(SocketKey.class);
-        final SocketKey tetherSocketKey2 = mock(SocketKey.class);
-        doReturn(null).when(tetherSocketKey1).getNetwork();
-        doReturn(null).when(tetherSocketKey2).getNetwork();
+        final SocketKey tetherSocketKey1 = new SocketKey(1001 /* interfaceIndex */);
+        final SocketKey tetherSocketKey2 = new SocketKey(1002 /* interfaceIndex */);
         // Notify socket created
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
         verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
@@ -137,8 +134,8 @@
         callback.onSocketCreated(tetherSocketKey2, tetherIfaceSock2, List.of());
         verify(mSocketCreationCallback).onSocketCreated(tetherSocketKey2);
 
-        // Send packet to IPv4 with target network and verify sending has been called.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mNetwork,
+        // Send packet to IPv4 with mSocketKey and verify sending has been called.
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket).send(ipv4Packet);
@@ -146,30 +143,30 @@
         verify(tetherIfaceSock2, never()).send(any());
 
         // Send packet to IPv4 with onlyUseIpv6OnIpv6OnlyNetworks = true, the packet will be sent.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mNetwork,
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
                 true /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket, times(2)).send(ipv4Packet);
         verify(tetherIfaceSock1, never()).send(any());
         verify(tetherIfaceSock2, never()).send(any());
 
-        // Send packet to IPv6 without target network and verify sending has been called.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv6Packet, null,
+        // Send packet to IPv6 with tetherSocketKey1 and verify sending has been called.
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv6Packet, tetherSocketKey1,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket, never()).send(ipv6Packet);
         verify(tetherIfaceSock1).send(ipv6Packet);
-        verify(tetherIfaceSock2).send(ipv6Packet);
+        verify(tetherIfaceSock2, never()).send(ipv6Packet);
 
         // Send packet to IPv6 with onlyUseIpv6OnIpv6OnlyNetworks = true, the packet will not be
         // sent. Therefore, the tetherIfaceSock1.send() and tetherIfaceSock2.send() are still be
         // called once.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv6Packet, null,
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv6Packet, tetherSocketKey1,
                 true /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket, never()).send(ipv6Packet);
         verify(tetherIfaceSock1, times(1)).send(ipv6Packet);
-        verify(tetherIfaceSock2, times(1)).send(ipv6Packet);
+        verify(tetherIfaceSock2, never()).send(ipv6Packet);
     }
 
     @Test
@@ -240,8 +237,8 @@
         doReturn(createEmptyNetworkInterface()).when(socket2).getInterface();
         doReturn(createEmptyNetworkInterface()).when(socket3).getInterface();
 
-        final SocketKey socketKey2 = mock(SocketKey.class);
-        final SocketKey socketKey3 = mock(SocketKey.class);
+        final SocketKey socketKey2 = new SocketKey(1001 /* interfaceIndex */);
+        final SocketKey socketKey3 = new SocketKey(1002 /* interfaceIndex */);
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
         callback.onSocketCreated(socketKey2, socket2, List.of());
         callback.onSocketCreated(socketKey3, socket3, List.of());
@@ -249,8 +246,8 @@
         verify(mSocketCreationCallback).onSocketCreated(socketKey2);
         verify(mSocketCreationCallback).onSocketCreated(socketKey3);
 
-        // Send IPv4 packet on the non-null Network and verify sending has been called.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mNetwork,
+        // Send IPv4 packet on the mSocketKey and verify sending has been called.
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket).send(ipv4Packet);
@@ -278,43 +275,44 @@
         verify(socketCreationCb2).onSocketCreated(socketKey2);
         verify(socketCreationCb2).onSocketCreated(socketKey3);
 
-        // Send IPv4 packet to null network and verify sending to the 2 tethered interface sockets.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, null,
+        // Send IPv4 packet on socket2 and verify it is sent on socket2 only.
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, socketKey2,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         // ipv4Packet still sent only once on mSocket: times(1) matches the packet sent earlier on
         // mNetwork
         verify(mSocket, times(1)).send(ipv4Packet);
         verify(socket2).send(ipv4Packet);
-        verify(socket3).send(ipv4Packet);
+        verify(socket3, never()).send(ipv4Packet);
 
         // Unregister the second request
         mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(listener2));
         verify(mProvider, timeout(DEFAULT_TIMEOUT)).unrequestSocket(callback2);
 
         // Send IPv4 packet again and verify it's still sent a second time
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, null,
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, socketKey2,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(socket2, times(2)).send(ipv4Packet);
-        verify(socket3, times(2)).send(ipv4Packet);
+        verify(socket3, never()).send(ipv4Packet);
 
         // Unrequest remaining sockets
         mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener));
         verify(mProvider, timeout(DEFAULT_TIMEOUT)).unrequestSocket(callback);
 
         // Send IPv4 packet and verify no more sending.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, null,
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket, times(1)).send(ipv4Packet);
         verify(socket2, times(2)).send(ipv4Packet);
-        verify(socket3, times(2)).send(ipv4Packet);
+        verify(socket3, never()).send(ipv4Packet);
     }
 
     @Test
     public void testNotifyNetworkUnrequested_SocketsOnNullNetwork() {
         final MdnsInterfaceSocket otherSocket = mock(MdnsInterfaceSocket.class);
+        final SocketKey otherSocketKey = new SocketKey(1001 /* interfaceIndex */);
         final SocketCallback callback = expectSocketCallback(
                 mListener, null /* requestedNetwork */);
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
@@ -322,33 +320,36 @@
 
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
         verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
-        callback.onSocketCreated(mSocketKey, otherSocket, List.of());
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(mSocketKey);
+        callback.onSocketCreated(otherSocketKey, otherSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(otherSocketKey);
 
-        verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(mSocketKey);
+        verify(mSocketCreationCallback, never()).onSocketDestroyed(mSocketKey);
+        verify(mSocketCreationCallback, never()).onSocketDestroyed(otherSocketKey);
         mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener));
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
 
         verify(mProvider).unrequestSocket(callback);
-        verify(mSocketCreationCallback).onAllSocketsDestroyed(mSocketKey);
+        verify(mSocketCreationCallback).onSocketDestroyed(mSocketKey);
+        verify(mSocketCreationCallback).onSocketDestroyed(otherSocketKey);
     }
 
     @Test
     public void testSocketCreatedAndDestroyed_NullNetwork() throws IOException {
         final MdnsInterfaceSocket otherSocket = mock(MdnsInterfaceSocket.class);
+        final SocketKey otherSocketKey = new SocketKey(1001 /* interfaceIndex */);
         final SocketCallback callback = expectSocketCallback(mListener, null /* network */);
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
         doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
 
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
         verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
-        callback.onSocketCreated(mSocketKey, otherSocket, List.of());
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(mSocketKey);
+        callback.onSocketCreated(otherSocketKey, otherSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(otherSocketKey);
 
         // Notify socket destroyed
         callback.onInterfaceDestroyed(mSocketKey, mSocket);
-        verifyNoMoreInteractions(mSocketCreationCallback);
-        callback.onInterfaceDestroyed(mSocketKey, otherSocket);
-        verify(mSocketCreationCallback).onAllSocketsDestroyed(mSocketKey);
+        verify(mSocketCreationCallback).onSocketDestroyed(mSocketKey);
+        callback.onInterfaceDestroyed(otherSocketKey, otherSocket);
+        verify(mSocketCreationCallback).onSocketDestroyed(otherSocketKey);
     }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
index f091eea..b43bcf7 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
@@ -16,7 +16,6 @@
 
 package com.android.server.connectivity.mdns
 
-import android.net.Network
 import android.os.Build
 import android.os.Handler
 import android.os.HandlerThread
@@ -32,7 +31,6 @@
 import org.junit.Before
 import org.junit.Test
 import org.junit.runner.RunWith
-import org.mockito.Mockito.mock
 
 private const val SERVICE_NAME_1 = "service-instance-1"
 private const val SERVICE_NAME_2 = "service-instance-2"
@@ -44,7 +42,7 @@
 @RunWith(DevSdkIgnoreRunner::class)
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
 class MdnsServiceCacheTest {
-    private val network = mock(Network::class.java)
+    private val socketKey = SocketKey(null /* network */, INTERFACE_INDEX)
     private val thread = HandlerThread(MdnsServiceCacheTest::class.simpleName)
     private val handler by lazy {
         Handler(thread.looper)
@@ -71,39 +69,47 @@
         return future.get(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS)
     }
 
-    private fun addOrUpdateService(serviceType: String, network: Network, service: MdnsResponse):
-            Unit = runningOnHandlerAndReturn {
-        serviceCache.addOrUpdateService(serviceType, network, service) }
+    private fun addOrUpdateService(
+            serviceType: String,
+            socketKey: SocketKey,
+            service: MdnsResponse
+    ): Unit = runningOnHandlerAndReturn {
+        serviceCache.addOrUpdateService(serviceType, socketKey, service)
+    }
 
-    private fun removeService(serviceName: String, serviceType: String, network: Network):
+    private fun removeService(serviceName: String, serviceType: String, socketKey: SocketKey):
             Unit = runningOnHandlerAndReturn {
-        serviceCache.removeService(serviceName, serviceType, network) }
+        serviceCache.removeService(serviceName, serviceType, socketKey) }
 
-    private fun getService(serviceName: String, serviceType: String, network: Network):
+    private fun getService(serviceName: String, serviceType: String, socketKey: SocketKey):
             MdnsResponse? = runningOnHandlerAndReturn {
-        serviceCache.getCachedService(serviceName, serviceType, network) }
+        serviceCache.getCachedService(serviceName, serviceType, socketKey) }
 
-    private fun getServices(serviceType: String, network: Network): List<MdnsResponse> =
-        runningOnHandlerAndReturn { serviceCache.getCachedServices(serviceType, network) }
+    private fun getServices(serviceType: String, socketKey: SocketKey): List<MdnsResponse> =
+        runningOnHandlerAndReturn { serviceCache.getCachedServices(serviceType, socketKey) }
 
     @Test
     fun testAddAndRemoveService() {
-        addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
-        var response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, network)
+        addOrUpdateService(
+                SERVICE_TYPE_1, socketKey, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
+        var response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, socketKey)
         assertNotNull(response)
         assertEquals(SERVICE_NAME_1, response.serviceInstanceName)
-        removeService(SERVICE_NAME_1, SERVICE_TYPE_1, network)
-        response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, network)
+        removeService(SERVICE_NAME_1, SERVICE_TYPE_1, socketKey)
+        response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, socketKey)
         assertNull(response)
     }
 
     @Test
     fun testGetCachedServices_multipleServiceTypes() {
-        addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
-        addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
-        addOrUpdateService(SERVICE_TYPE_2, network, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
+        addOrUpdateService(
+                SERVICE_TYPE_1, socketKey, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
+        addOrUpdateService(
+                SERVICE_TYPE_1, socketKey, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
+        addOrUpdateService(
+                SERVICE_TYPE_2, socketKey, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
 
-        val responses1 = getServices(SERVICE_TYPE_1, network)
+        val responses1 = getServices(SERVICE_TYPE_1, socketKey)
         assertEquals(2, responses1.size)
         assertTrue(responses1.stream().anyMatch { response ->
             response.serviceInstanceName == SERVICE_NAME_1
@@ -111,19 +117,19 @@
         assertTrue(responses1.any { response ->
             response.serviceInstanceName == SERVICE_NAME_2
         })
-        val responses2 = getServices(SERVICE_TYPE_2, network)
+        val responses2 = getServices(SERVICE_TYPE_2, socketKey)
         assertEquals(1, responses2.size)
         assertTrue(responses2.any { response ->
             response.serviceInstanceName == SERVICE_NAME_2
         })
 
-        removeService(SERVICE_NAME_2, SERVICE_TYPE_1, network)
-        val responses3 = getServices(SERVICE_TYPE_1, network)
+        removeService(SERVICE_NAME_2, SERVICE_TYPE_1, socketKey)
+        val responses3 = getServices(SERVICE_TYPE_1, socketKey)
         assertEquals(1, responses3.size)
         assertTrue(responses3.any { response ->
             response.serviceInstanceName == SERVICE_NAME_1
         })
-        val responses4 = getServices(SERVICE_TYPE_2, network)
+        val responses4 = getServices(SERVICE_TYPE_2, socketKey)
         assertEquals(1, responses4.size)
         assertTrue(responses4.any { response ->
             response.serviceInstanceName == SERVICE_NAME_2
@@ -132,5 +138,5 @@
 
     private fun createResponse(serviceInstanceName: String, serviceType: String) = MdnsResponse(
         0 /* now */, "$serviceInstanceName.$serviceType".split(".").toTypedArray(),
-            INTERFACE_INDEX, network)
+            socketKey.interfaceIndex, socketKey.network)
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index 9892e9f..cf6275f 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -41,6 +41,8 @@
 import android.annotation.Nullable;
 import android.net.InetAddresses;
 import android.net.Network;
+import android.os.Handler;
+import android.os.HandlerThread;
 import android.text.TextUtils;
 
 import com.android.net.module.util.CollectionUtils;
@@ -49,7 +51,9 @@
 import com.android.server.connectivity.mdns.MdnsServiceTypeClient.QueryTaskConfig;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRunner;
+import com.android.testutils.HandlerUtils;
 
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Ignore;
 import org.junit.Test;
@@ -82,6 +86,7 @@
 @DevSdkIgnoreRule.IgnoreUpTo(SC_V2)
 public class MdnsServiceTypeClientTests {
     private static final int INTERFACE_INDEX = 999;
+    private static final long DEFAULT_TIMEOUT = 2000L;
     private static final String SERVICE_TYPE = "_googlecast._tcp.local";
     private static final String[] SERVICE_TYPE_LABELS = TextUtils.split(SERVICE_TYPE, "\\.");
     private static final InetSocketAddress IPV4_ADDRESS = new InetSocketAddress(
@@ -119,6 +124,8 @@
 
     private MdnsServiceTypeClient client;
     private SocketKey socketKey;
+    private HandlerThread thread;
+    private Handler handler;
 
     @Before
     @SuppressWarnings("DoNotMock")
@@ -174,9 +181,12 @@
                 .thenReturn(expectedIPv6Packets[14])
                 .thenReturn(expectedIPv6Packets[15]);
 
+        thread = new HandlerThread("MdnsServiceTypeClientTests");
+        thread.start();
+        handler = new Handler(thread.getLooper());
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, socketKey, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper()) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -184,11 +194,40 @@
                 };
     }
 
+    @After
+    public void tearDown() {
+        if (thread != null) {
+            thread.quitSafely();
+        }
+    }
+
+    private void runOnHandler(Runnable r) {
+        handler.post(r);
+        HandlerUtils.waitForIdle(handler, DEFAULT_TIMEOUT);
+    }
+
+    private void startSendAndReceive(MdnsServiceBrowserListener listener,
+            MdnsSearchOptions searchOptions) {
+        runOnHandler(() -> client.startSendAndReceive(listener, searchOptions));
+    }
+
+    private void processResponse(MdnsPacket packet, SocketKey socketKey) {
+        runOnHandler(() -> client.processResponse(packet, socketKey));
+    }
+
+    private void stopSendAndReceive(MdnsServiceBrowserListener listener) {
+        runOnHandler(() -> client.stopSendAndReceive(listener));
+    }
+
+    private void notifySocketDestroyed() {
+        runOnHandler(() -> client.notifySocketDestroyed());
+    }
+
     @Test
     public void sendQueries_activeScanMode() {
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build();
-        client.startSendAndReceive(mockListenerOne, searchOptions);
+        startSendAndReceive(mockListenerOne, searchOptions);
 
         // First burst, 3 queries.
         verifyAndSendQuery(0, 0, /* expectsUnicastResponse= */ true);
@@ -228,7 +267,7 @@
                 14, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
 
         // Stop sending packets.
-        client.stopSendAndReceive(mockListenerOne);
+        stopSendAndReceive(mockListenerOne);
         verify(expectedSendFutures[15]).cancel(true);
     }
 
@@ -236,7 +275,7 @@
     public void sendQueries_reentry_activeScanMode() {
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build();
-        client.startSendAndReceive(mockListenerOne, searchOptions);
+        startSendAndReceive(mockListenerOne, searchOptions);
 
         // First burst, first query is sent.
         verifyAndSendQuery(0, 0, /* expectsUnicastResponse= */ true);
@@ -248,7 +287,7 @@
                         .addSubtype("abcde")
                         .setIsPassiveMode(false)
                         .build();
-        client.startSendAndReceive(mockListenerOne, searchOptions);
+        startSendAndReceive(mockListenerOne, searchOptions);
         // The previous scheduled task should be canceled.
         verify(expectedSendFutures[1]).cancel(true);
 
@@ -260,7 +299,7 @@
                 3, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
 
         // Stop sending packets.
-        client.stopSendAndReceive(mockListenerOne);
+        stopSendAndReceive(mockListenerOne);
         verify(expectedSendFutures[5]).cancel(true);
     }
 
@@ -268,7 +307,7 @@
     public void sendQueries_passiveScanMode() {
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(true).build();
-        client.startSendAndReceive(mockListenerOne, searchOptions);
+        startSendAndReceive(mockListenerOne, searchOptions);
 
         // First burst, 3 query.
         verifyAndSendQuery(0, 0, /* expectsUnicastResponse= */ true);
@@ -284,15 +323,119 @@
                 false);
 
         // Stop sending packets.
-        client.stopSendAndReceive(mockListenerOne);
+        stopSendAndReceive(mockListenerOne);
         verify(expectedSendFutures[5]).cancel(true);
     }
 
     @Test
+    public void sendQueries_activeScanWithQueryBackoff() {
+        MdnsSearchOptions searchOptions =
+                MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(
+                        false).setNumOfQueriesBeforeBackoff(11).build();
+        startSendAndReceive(mockListenerOne, searchOptions);
+
+        // First burst, 3 queries.
+        verifyAndSendQuery(0, 0, /* expectsUnicastResponse= */ true);
+        verifyAndSendQuery(
+                1, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
+        verifyAndSendQuery(
+                2, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
+        // Second burst will be sent after initialTimeBetweenBurstsMs, 3 queries.
+        verifyAndSendQuery(
+                3, MdnsConfigs.initialTimeBetweenBurstsMs(), /* expectsUnicastResponse= */ false);
+        verifyAndSendQuery(
+                4, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
+        verifyAndSendQuery(
+                5, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
+        // Third burst will be sent after initialTimeBetweenBurstsMs * 2, 3 queries.
+        verifyAndSendQuery(
+                6, MdnsConfigs.initialTimeBetweenBurstsMs() * 2, /* expectsUnicastResponse= */
+                false);
+        verifyAndSendQuery(
+                7, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
+        verifyAndSendQuery(
+                8, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
+        // Fourth burst will be sent after initialTimeBetweenBurstsMs * 4, 3 queries.
+        verifyAndSendQuery(
+                9, MdnsConfigs.initialTimeBetweenBurstsMs() * 4, /* expectsUnicastResponse= */
+                false);
+        verifyAndSendQuery(
+                10, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
+        verifyAndSendQuery(
+                11, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
+        // In backoff mode, the current scheduled task will be canceled and rescheduled if the
+        // 0.8 * smallestRemainingTtl is larger than the time to the next run.
+        long currentTime = TEST_TTL / 2 + TEST_ELAPSED_REALTIME;
+        doReturn(currentTime).when(mockDecoderClock).elapsedRealtime();
+        processResponse(createResponse(
+                "service-instance-1", "192.0.2.123", 5353,
+                SERVICE_TYPE_LABELS,
+                Collections.emptyMap(), TEST_TTL), socketKey);
+        verifyAndSendQuery(12, (long) (TEST_TTL / 2 * 0.8), /* expectsUnicastResponse= */
+                false);
+        currentTime += (long) (TEST_TTL / 2 * 0.8);
+        doReturn(currentTime).when(mockDecoderClock).elapsedRealtime();
+        verifyAndSendQuery(
+                13, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
+    }
+
+    @Test
+    public void sendQueries_passiveScanWithQueryBackoff() {
+        MdnsSearchOptions searchOptions =
+                MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(
+                        true).setNumOfQueriesBeforeBackoff(3).build();
+        startSendAndReceive(mockListenerOne, searchOptions);
+        verifyAndSendQuery(0, 0, /* expectsUnicastResponse= */ true);
+        verifyAndSendQuery(
+                1, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
+        verifyAndSendQuery(
+                2, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
+        verifyAndSendQuery(3, MdnsConfigs.timeBetweenBurstsMs(), /* expectsUnicastResponse= */
+                false);
+        assertEquals(4, currentThreadExecutor.getNumOfScheduledFuture());
+
+        // In backoff mode, the current scheduled task will be canceled and rescheduled if the
+        // 0.8 * smallestRemainingTtl is larger than the time to the next run.
+        doReturn(TEST_ELAPSED_REALTIME + 20000).when(mockDecoderClock).elapsedRealtime();
+        processResponse(createResponse(
+                "service-instance-1", "192.0.2.123", 5353,
+                SERVICE_TYPE_LABELS,
+                Collections.emptyMap(), TEST_TTL), socketKey);
+        verify(expectedSendFutures[4]).cancel(true);
+        assertEquals(5, currentThreadExecutor.getNumOfScheduledFuture());
+        verifyAndSendQuery(4, 80000 /* timeInMs */, false /* expectsUnicastResponse */);
+        assertEquals(6, currentThreadExecutor.getNumOfScheduledFuture());
+        // Next run should also be scheduled in 0.8 * smallestRemainingTtl
+        verifyAndSendQuery(5, 80000 /* timeInMs */, false /* expectsUnicastResponse */);
+        assertEquals(7, currentThreadExecutor.getNumOfScheduledFuture());
+
+        // If the records are not refreshed, the current scheduled task will not be canceled.
+        doReturn(TEST_ELAPSED_REALTIME + 20001).when(mockDecoderClock).elapsedRealtime();
+        processResponse(createResponse(
+                "service-instance-1", "192.0.2.123", 5353,
+                SERVICE_TYPE_LABELS,
+                Collections.emptyMap(), TEST_TTL,
+                TEST_ELAPSED_REALTIME - 1), socketKey);
+        verify(expectedSendFutures[7], never()).cancel(true);
+
+        // In backoff mode, the current scheduled task will not be canceled if the
+        // 0.8 * smallestRemainingTtl is smaller than the time to the next run.
+        doReturn(TEST_ELAPSED_REALTIME).when(mockDecoderClock).elapsedRealtime();
+        processResponse(createResponse(
+                "service-instance-1", "192.0.2.123", 5353,
+                SERVICE_TYPE_LABELS,
+                Collections.emptyMap(), TEST_TTL), socketKey);
+        verify(expectedSendFutures[7], never()).cancel(true);
+
+        stopSendAndReceive(mockListenerOne);
+        verify(expectedSendFutures[7]).cancel(true);
+    }
+
+    @Test
     public void sendQueries_reentry_passiveScanMode() {
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(true).build();
-        client.startSendAndReceive(mockListenerOne, searchOptions);
+        startSendAndReceive(mockListenerOne, searchOptions);
 
         // First burst, first query is sent.
         verifyAndSendQuery(0, 0, /* expectsUnicastResponse= */ true);
@@ -304,7 +447,7 @@
                         .addSubtype("abcde")
                         .setIsPassiveMode(true)
                         .build();
-        client.startSendAndReceive(mockListenerOne, searchOptions);
+        startSendAndReceive(mockListenerOne, searchOptions);
         // The previous scheduled task should be canceled.
         verify(expectedSendFutures[1]).cancel(true);
 
@@ -316,7 +459,7 @@
                 3, MdnsConfigs.timeBetweenQueriesInBurstMs(), /* expectsUnicastResponse= */ false);
 
         // Stop sending packets.
-        client.stopSendAndReceive(mockListenerOne);
+        stopSendAndReceive(mockListenerOne);
         verify(expectedSendFutures[5]).cancel(true);
     }
 
@@ -328,7 +471,8 @@
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build();
         QueryTaskConfig config = new QueryTaskConfig(
                 searchOptions.getSubtypes(), searchOptions.isPassiveMode(),
-                false /* onlyUseIpv6OnIpv6OnlyNetworks */, 1, socketKey);
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */, 3 /* numOfQueriesBeforeBackoff */,
+                socketKey);
 
         // This is the first query. We will ask for unicast response.
         assertTrue(config.expectUnicastResponse);
@@ -358,7 +502,8 @@
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build();
         QueryTaskConfig config = new QueryTaskConfig(
                 searchOptions.getSubtypes(), searchOptions.isPassiveMode(),
-                false /* onlyUseIpv6OnIpv6OnlyNetworks */, 1, socketKey);
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */, 3 /* numOfQueriesBeforeBackoff */,
+                socketKey);
 
         // This is the first query. We will ask for unicast response.
         assertTrue(config.expectUnicastResponse);
@@ -386,7 +531,7 @@
     public void testIfPreviousTaskIsCanceledWhenNewSessionStarts() {
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(true).build();
-        client.startSendAndReceive(mockListenerOne, searchOptions);
+        startSendAndReceive(mockListenerOne, searchOptions);
         Runnable firstMdnsTask = currentThreadExecutor.getAndClearSubmittedRunnable();
 
         // Change the sutypes and start a new session.
@@ -396,7 +541,7 @@
                         .addSubtype("abcde")
                         .setIsPassiveMode(true)
                         .build();
-        client.startSendAndReceive(mockListenerOne, searchOptions);
+        startSendAndReceive(mockListenerOne, searchOptions);
 
         // Clear the scheduled runnable.
         currentThreadExecutor.getAndClearLastScheduledRunnable();
@@ -415,9 +560,9 @@
         //MdnsConfigsFlagsImpl.shouldCancelScanTaskWhenFutureIsNull.override(true);
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(true).build();
-        client.startSendAndReceive(mockListenerOne, searchOptions);
+        startSendAndReceive(mockListenerOne, searchOptions);
         // Change the sutypes and start a new session.
-        client.stopSendAndReceive(mockListenerOne);
+        stopSendAndReceive(mockListenerOne);
         // Clear the scheduled runnable.
         currentThreadExecutor.getAndClearLastScheduledRunnable();
 
@@ -432,19 +577,19 @@
     @Test
     public void testQueryScheduledWhenAnsweredFromCache() {
         final MdnsSearchOptions searchOptions = MdnsSearchOptions.getDefaultOptions();
-        client.startSendAndReceive(mockListenerOne, searchOptions);
+        startSendAndReceive(mockListenerOne, searchOptions);
         assertNotNull(currentThreadExecutor.getAndClearSubmittedRunnable());
 
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 "service-instance-1", "192.0.2.123", 5353,
                 SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), TEST_TTL), /* interfaceIndex= */ 20, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         verify(mockListenerOne).onServiceNameDiscovered(any());
         verify(mockListenerOne).onServiceFound(any());
 
         // File another identical query
-        client.startSendAndReceive(mockListenerTwo, searchOptions);
+        startSendAndReceive(mockListenerTwo, searchOptions);
 
         verify(mockListenerTwo).onServiceNameDiscovered(any());
         verify(mockListenerTwo).onServiceFound(any());
@@ -459,8 +604,7 @@
 
     private static void verifyServiceInfo(MdnsServiceInfo serviceInfo, String serviceName,
             String[] serviceType, List<String> ipv4Addresses, List<String> ipv6Addresses, int port,
-            List<String> subTypes, Map<String, String> attributes, int interfaceIndex,
-            Network network) {
+            List<String> subTypes, Map<String, String> attributes, SocketKey socketKey) {
         assertEquals(serviceName, serviceInfo.getServiceInstanceName());
         assertArrayEquals(serviceType, serviceInfo.getServiceType());
         assertEquals(ipv4Addresses, serviceInfo.getIpv4Addresses());
@@ -471,18 +615,18 @@
             assertTrue(attributes.containsKey(key));
             assertEquals(attributes.get(key), serviceInfo.getAttributeByKey(key));
         }
-        assertEquals(interfaceIndex, serviceInfo.getInterfaceIndex());
-        assertEquals(network, serviceInfo.getNetwork());
+        assertEquals(socketKey.getInterfaceIndex(), serviceInfo.getInterfaceIndex());
+        assertEquals(socketKey.getNetwork(), serviceInfo.getNetwork());
     }
 
     @Test
     public void processResponse_incompleteResponse() {
-        client.startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
+        startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
 
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 "service-instance-1", null /* host */, 0 /* port */,
                 SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
         verify(mockListenerOne).onServiceNameDiscovered(serviceInfoCaptor.capture());
         verifyServiceInfo(serviceInfoCaptor.getAllValues().get(0),
                 "service-instance-1",
@@ -492,8 +636,7 @@
                 /* port= */ 0,
                 /* subTypes= */ List.of(),
                 Collections.emptyMap(),
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
 
         verify(mockListenerOne, never()).onServiceFound(any(MdnsServiceInfo.class));
         verify(mockListenerOne, never()).onServiceUpdated(any(MdnsServiceInfo.class));
@@ -502,20 +645,20 @@
     @Test
     public void processIPv4Response_completeResponseForNewServiceInstance() throws Exception {
         final String ipV4Address = "192.168.1.1";
-        client.startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
+        startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
 
         // Process the initial response.
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 "service-instance-1", ipV4Address, 5353,
                 /* subtype= */ "ABCDE",
-                Collections.emptyMap(), TEST_TTL), /* interfaceIndex= */ 20, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Process a second response with a different port and updated text attributes.
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                         "service-instance-1", ipV4Address, 5354,
                         /* subtype= */ "ABCDE",
                         Collections.singletonMap("key", "value"), TEST_TTL),
-                /* interfaceIndex= */ 20, mockNetwork);
+                socketKey);
 
         // Verify onServiceNameDiscovered was called once for the initial response.
         verify(mockListenerOne).onServiceNameDiscovered(serviceInfoCaptor.capture());
@@ -527,8 +670,7 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                20 /* interfaceIndex */,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceFound was called once for the initial response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -538,8 +680,8 @@
         assertEquals(initialServiceInfo.getPort(), 5353);
         assertEquals(initialServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertNull(initialServiceInfo.getAttributeByKey("key"));
-        assertEquals(initialServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(mockNetwork, initialServiceInfo.getNetwork());
+        assertEquals(socketKey.getInterfaceIndex(), initialServiceInfo.getInterfaceIndex());
+        assertEquals(socketKey.getNetwork(), initialServiceInfo.getNetwork());
 
         // Verify onServiceUpdated was called once for the second response.
         verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -550,27 +692,27 @@
         assertTrue(updatedServiceInfo.hasSubtypes());
         assertEquals(updatedServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertEquals(updatedServiceInfo.getAttributeByKey("key"), "value");
-        assertEquals(updatedServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(mockNetwork, updatedServiceInfo.getNetwork());
+        assertEquals(socketKey.getInterfaceIndex(), updatedServiceInfo.getInterfaceIndex());
+        assertEquals(socketKey.getNetwork(), updatedServiceInfo.getNetwork());
     }
 
     @Test
     public void processIPv6Response_getCorrectServiceInfo() throws Exception {
         final String ipV6Address = "2000:3333::da6c:63ff:fe7c:7483";
-        client.startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
+        startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
 
         // Process the initial response.
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 "service-instance-1", ipV6Address, 5353,
                 /* subtype= */ "ABCDE",
-                Collections.emptyMap(), TEST_TTL), /* interfaceIndex= */ 20, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Process a second response with a different port and updated text attributes.
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                         "service-instance-1", ipV6Address, 5354,
                         /* subtype= */ "ABCDE",
                         Collections.singletonMap("key", "value"), TEST_TTL),
-                /* interfaceIndex= */ 20, mockNetwork);
+                socketKey);
 
         // Verify onServiceNameDiscovered was called once for the initial response.
         verify(mockListenerOne).onServiceNameDiscovered(serviceInfoCaptor.capture());
@@ -582,8 +724,7 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                20 /* interfaceIndex */,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceFound was called once for the initial response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -593,8 +734,8 @@
         assertEquals(initialServiceInfo.getPort(), 5353);
         assertEquals(initialServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertNull(initialServiceInfo.getAttributeByKey("key"));
-        assertEquals(initialServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(mockNetwork, initialServiceInfo.getNetwork());
+        assertEquals(socketKey.getInterfaceIndex(), initialServiceInfo.getInterfaceIndex());
+        assertEquals(socketKey.getNetwork(), initialServiceInfo.getNetwork());
 
         // Verify onServiceUpdated was called once for the second response.
         verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -605,8 +746,8 @@
         assertTrue(updatedServiceInfo.hasSubtypes());
         assertEquals(updatedServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertEquals(updatedServiceInfo.getAttributeByKey("key"), "value");
-        assertEquals(updatedServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(mockNetwork, updatedServiceInfo.getNetwork());
+        assertEquals(socketKey.getInterfaceIndex(), updatedServiceInfo.getInterfaceIndex());
+        assertEquals(socketKey.getNetwork(), updatedServiceInfo.getNetwork());
     }
 
     private void verifyServiceRemovedNoCallback(MdnsServiceBrowserListener listener) {
@@ -615,61 +756,61 @@
     }
 
     private void verifyServiceRemovedCallback(MdnsServiceBrowserListener listener,
-            String serviceName, String[] serviceType, int interfaceIndex, Network network) {
+            String serviceName, String[] serviceType, SocketKey socketKey) {
         verify(listener).onServiceRemoved(argThat(
                 info -> serviceName.equals(info.getServiceInstanceName())
                         && Arrays.equals(serviceType, info.getServiceType())
-                        && info.getInterfaceIndex() == interfaceIndex
-                        && network.equals(info.getNetwork())));
+                        && info.getInterfaceIndex() == socketKey.getInterfaceIndex()
+                        && socketKey.getNetwork().equals(info.getNetwork())));
         verify(listener).onServiceNameRemoved(argThat(
                 info -> serviceName.equals(info.getServiceInstanceName())
                         && Arrays.equals(serviceType, info.getServiceType())
-                        && info.getInterfaceIndex() == interfaceIndex
-                        && network.equals(info.getNetwork())));
+                        && info.getInterfaceIndex() == socketKey.getInterfaceIndex()
+                        && socketKey.getNetwork().equals(info.getNetwork())));
     }
 
     @Test
     public void processResponse_goodBye() throws Exception {
-        client.startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
-        client.startSendAndReceive(mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
+        startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
+        startSendAndReceive(mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
 
         final String serviceName = "service-instance-1";
         final String ipV6Address = "2000:3333::da6c:63ff:fe7c:7483";
         // Process the initial response.
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 serviceName, ipV6Address, 5353,
                 SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 "goodbye-service", ipV6Address, 5353,
                 SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), /* ptrTtlMillis= */ 0L), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), /* ptrTtlMillis= */ 0L), socketKey);
 
         // Verify removed callback won't be called if the service is not existed.
         verifyServiceRemovedNoCallback(mockListenerOne);
         verifyServiceRemovedNoCallback(mockListenerTwo);
 
         // Verify removed callback would be called.
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 serviceName, ipV6Address, 5353,
                 SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), 0L), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), 0L), socketKey);
         verifyServiceRemovedCallback(
-                mockListenerOne, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX, mockNetwork);
+                mockListenerOne, serviceName, SERVICE_TYPE_LABELS, socketKey);
         verifyServiceRemovedCallback(
-                mockListenerTwo, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX, mockNetwork);
+                mockListenerTwo, serviceName, SERVICE_TYPE_LABELS, socketKey);
     }
 
     @Test
     public void reportExistingServiceToNewlyRegisteredListeners() throws Exception {
         // Process the initial response.
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 "service-instance-1", "192.168.1.1", 5353,
                 /* subtype= */ "ABCDE",
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
-        client.startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
+        startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
 
         // Verify onServiceNameDiscovered was called once for the existing response.
         verify(mockListenerOne).onServiceNameDiscovered(serviceInfoCaptor.capture());
@@ -681,8 +822,7 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceFound was called once for the existing response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -694,12 +834,12 @@
         assertNull(existingServiceInfo.getAttributeByKey("key"));
 
         // Process a goodbye message for the existing response.
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 "service-instance-1", "192.168.1.1", 5353,
                 SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), /* ptrTtlMillis= */ 0L), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), /* ptrTtlMillis= */ 0L), socketKey);
 
-        client.startSendAndReceive(mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
+        startSendAndReceive(mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
 
         // Verify onServiceFound was not called on the newly registered listener after the existing
         // response is gone.
@@ -713,21 +853,23 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, socketKey, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper()) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
                     }
                 };
-        MdnsSearchOptions searchOptions = MdnsSearchOptions.newBuilder().setRemoveExpiredService(
-                true).build();
-        client.startSendAndReceive(mockListenerOne, searchOptions);
+        MdnsSearchOptions searchOptions = MdnsSearchOptions.newBuilder()
+                .setRemoveExpiredService(true)
+                .setNumOfQueriesBeforeBackoff(Integer.MAX_VALUE)
+                .build();
+        startSendAndReceive(mockListenerOne, searchOptions);
         Runnable firstMdnsTask = currentThreadExecutor.getAndClearSubmittedRunnable();
 
         // Process the initial response.
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 serviceInstanceName, "192.168.1.1", 5353, /* subtype= */ "ABCDE",
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Clear the scheduled runnable.
         currentThreadExecutor.getAndClearLastScheduledRunnable();
@@ -744,8 +886,8 @@
         firstMdnsTask.run();
 
         // Verify removed callback was called.
-        verifyServiceRemovedCallback(mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS,
-                INTERFACE_INDEX, mockNetwork);
+        verifyServiceRemovedCallback(
+                mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS, socketKey);
     }
 
     @Test
@@ -754,19 +896,19 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, socketKey, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper()) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
                     }
                 };
-        client.startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
+        startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
         Runnable firstMdnsTask = currentThreadExecutor.getAndClearSubmittedRunnable();
 
         // Process the initial response.
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 serviceInstanceName, "192.168.1.1", 5353, /* subtype= */ "ABCDE",
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Clear the scheduled runnable.
         currentThreadExecutor.getAndClearLastScheduledRunnable();
@@ -787,19 +929,19 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, socketKey, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper()) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
                     }
                 };
-        client.startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
+        startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
         Runnable firstMdnsTask = currentThreadExecutor.getAndClearSubmittedRunnable();
 
         // Process the initial response.
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 serviceInstanceName, "192.168.1.1", 5353, /* subtype= */ "ABCDE",
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Clear the scheduled runnable.
         currentThreadExecutor.getAndClearLastScheduledRunnable();
@@ -809,8 +951,8 @@
         firstMdnsTask.run();
 
         // Verify removed callback was called.
-        verifyServiceRemovedCallback(mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS,
-                INTERFACE_INDEX, mockNetwork);
+        verifyServiceRemovedCallback(
+                mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS, socketKey);
     }
 
     @Test
@@ -818,30 +960,30 @@
         final String serviceName = "service-instance";
         final String ipV4Address = "192.0.2.0";
         final String ipV6Address = "2001:db8::";
-        client.startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
+        startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
         InOrder inOrder = inOrder(mockListenerOne);
 
         // Process the initial response which is incomplete.
         final String subtype = "ABCDE";
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 serviceName, null, 5353, subtype,
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Process a second response which has ip address to make response become complete.
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 serviceName, ipV4Address, 5353, subtype,
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Process a third response with a different ip address, port and updated text attributes.
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 serviceName, ipV6Address, 5354, subtype,
-                Collections.singletonMap("key", "value"), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.singletonMap("key", "value"), TEST_TTL), socketKey);
 
         // Process the last response which is goodbye message (with the main type, not subtype).
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                         serviceName, ipV6Address, 5354, SERVICE_TYPE_LABELS,
                         Collections.singletonMap("key", "value"), /* ptrTtlMillis= */ 0L),
-                INTERFACE_INDEX, mockNetwork);
+                socketKey);
 
         // Verify onServiceNameDiscovered was first called for the initial response.
         inOrder.verify(mockListenerOne).onServiceNameDiscovered(serviceInfoCaptor.capture());
@@ -853,8 +995,7 @@
                 5353 /* port */,
                 Collections.singletonList(subtype) /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceFound was second called for the second response.
         inOrder.verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -866,8 +1007,7 @@
                 5353 /* port */,
                 Collections.singletonList(subtype) /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceUpdated was third called for the third response.
         inOrder.verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -879,8 +1019,7 @@
                 5354 /* port */,
                 Collections.singletonList(subtype) /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceRemoved was called for the last response.
         inOrder.verify(mockListenerOne).onServiceRemoved(serviceInfoCaptor.capture());
@@ -892,8 +1031,7 @@
                 5354 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceNameRemoved was called for the last response.
         inOrder.verify(mockListenerOne).onServiceNameRemoved(serviceInfoCaptor.capture());
@@ -905,14 +1043,13 @@
                 5354 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
     }
 
     @Test
     public void testProcessResponse_Resolve() throws Exception {
-        client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
+        client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
+                socketKey, mockSharedLog, thread.getLooper());
 
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
@@ -922,7 +1059,7 @@
         final MdnsSearchOptions resolveOptions = MdnsSearchOptions.newBuilder()
                 .setResolveInstanceName(instanceName).build();
 
-        client.startSendAndReceive(mockListenerOne, resolveOptions);
+        startSendAndReceive(mockListenerOne, resolveOptions);
         InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
 
         // Verify a query for SRV/TXT was sent, but no PTR query
@@ -932,7 +1069,7 @@
         // Send twice for IPv4 and IPv6
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
                 srvTxtQueryCaptor.capture(),
-                eq(mockNetwork), eq(false));
+                eq(socketKey), eq(false));
 
         final MdnsPacket srvTxtQueryPacket = MdnsPacket.parse(
                 new MdnsPacketReader(srvTxtQueryCaptor.getValue()));
@@ -955,7 +1092,7 @@
                 Collections.emptyList() /* authorityRecords */,
                 Collections.emptyList() /* additionalRecords */);
 
-        client.processResponse(srvTxtResponse, INTERFACE_INDEX, mockNetwork);
+        processResponse(srvTxtResponse, socketKey);
 
         // Expect a query for A/AAAA
         final ArgumentCaptor<DatagramPacket> addressQueryCaptor =
@@ -963,7 +1100,7 @@
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
                 addressQueryCaptor.capture(),
-                eq(mockNetwork), eq(false));
+                eq(socketKey), eq(false));
 
         final MdnsPacket addressQueryPacket = MdnsPacket.parse(
                 new MdnsPacketReader(addressQueryCaptor.getValue()));
@@ -985,7 +1122,7 @@
                 Collections.emptyList() /* additionalRecords */);
 
         inOrder.verify(mockListenerOne, never()).onServiceNameDiscovered(any());
-        client.processResponse(addressResponse, INTERFACE_INDEX, mockNetwork);
+        processResponse(addressResponse, socketKey);
 
         inOrder.verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
         verifyServiceInfo(serviceInfoCaptor.getValue(),
@@ -996,14 +1133,13 @@
                 1234 /* port */,
                 Collections.emptyList() /* subTypes */,
                 Collections.emptyMap() /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
     }
 
     @Test
     public void testRenewTxtSrvInResolve() throws Exception {
         client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                mockDecoderClock, socketKey, mockSharedLog);
+                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper());
 
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
@@ -1013,7 +1149,7 @@
         final MdnsSearchOptions resolveOptions = MdnsSearchOptions.newBuilder()
                 .setResolveInstanceName(instanceName).build();
 
-        client.startSendAndReceive(mockListenerOne, resolveOptions);
+        startSendAndReceive(mockListenerOne, resolveOptions);
         InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
 
         // Get the query for SRV/TXT
@@ -1023,7 +1159,7 @@
         // Send twice for IPv4 and IPv6
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
                 srvTxtQueryCaptor.capture(),
-                eq(mockNetwork), eq(false));
+                eq(socketKey), eq(false));
 
         final MdnsPacket srvTxtQueryPacket = MdnsPacket.parse(
                 new MdnsPacketReader(srvTxtQueryCaptor.getValue()));
@@ -1050,7 +1186,7 @@
                                 InetAddresses.parseNumericAddress(ipV6Address))),
                 Collections.emptyList() /* authorityRecords */,
                 Collections.emptyList() /* additionalRecords */);
-        client.processResponse(srvTxtResponse, INTERFACE_INDEX, mockNetwork);
+        processResponse(srvTxtResponse, socketKey);
         inOrder.verify(mockListenerOne).onServiceNameDiscovered(any());
         inOrder.verify(mockListenerOne).onServiceFound(any());
 
@@ -1069,7 +1205,7 @@
         // Second and later sends are sent as "expect multicast response" queries
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
                 renewalQueryCaptor.capture(),
-                eq(mockNetwork), eq(false));
+                eq(socketKey), eq(false));
         inOrder.verify(mockListenerOne).onDiscoveryQuerySent(any(), anyInt());
         final MdnsPacket renewalPacket = MdnsPacket.parse(
                 new MdnsPacketReader(renewalQueryCaptor.getValue()));
@@ -1095,7 +1231,7 @@
                                 InetAddresses.parseNumericAddress(ipV6Address))),
                 Collections.emptyList() /* authorityRecords */,
                 Collections.emptyList() /* additionalRecords */);
-        client.processResponse(refreshedSrvTxtResponse, INTERFACE_INDEX, mockNetwork);
+        processResponse(refreshedSrvTxtResponse, socketKey);
 
         // Advance time to updatedReceiptTime + 1, expected no refresh query because the cache
         // should contain the record that have update last receipt time.
@@ -1106,8 +1242,8 @@
 
     @Test
     public void testProcessResponse_ResolveExcludesOtherServices() {
-        client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
+        client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
+                socketKey, mockSharedLog, thread.getLooper());
 
         final String requestedInstance = "instance1";
         final String otherInstance = "instance2";
@@ -1119,30 +1255,30 @@
                 // Use different case in the options
                 .setResolveInstanceName(capitalizedRequestInstance).build();
 
-        client.startSendAndReceive(mockListenerOne, resolveOptions);
-        client.startSendAndReceive(mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
+        startSendAndReceive(mockListenerOne, resolveOptions);
+        startSendAndReceive(mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
 
         // Complete response from instanceName
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                         requestedInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
-                INTERFACE_INDEX, mockNetwork);
+                socketKey);
 
         // Complete response from otherInstanceName
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                         otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
-                INTERFACE_INDEX, mockNetwork);
+                socketKey);
 
         // Address update from otherInstanceName
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 otherInstance, ipV6Address, 5353, SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Goodbye from otherInstanceName
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 otherInstance, ipV6Address, 5353, SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), 0L /* ttl */), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), 0L /* ttl */), socketKey);
 
         // mockListenerOne gets notified for the requested instance
         verify(mockListenerOne).onServiceNameDiscovered(
@@ -1170,8 +1306,8 @@
 
     @Test
     public void testProcessResponse_SubtypeDiscoveryLimitedToSubtype() {
-        client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
+        client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
+                socketKey, mockSharedLog, thread.getLooper());
 
         final String matchingInstance = "instance1";
         final String subtype = "_subtype";
@@ -1183,8 +1319,8 @@
                 // Search with different case. Note MdnsSearchOptions subtype doesn't start with "_"
                 .addSubtype("Subtype").build();
 
-        client.startSendAndReceive(mockListenerOne, options);
-        client.startSendAndReceive(mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
+        startSendAndReceive(mockListenerOne, options);
+        startSendAndReceive(mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
 
         // Complete response from instanceName
         final MdnsPacket packetWithoutSubtype = createResponse(
@@ -1207,23 +1343,23 @@
                 newAnswers,
                 packetWithoutSubtype.authorityRecords,
                 packetWithoutSubtype.additionalRecords);
-        client.processResponse(packetWithSubtype, INTERFACE_INDEX, mockNetwork);
+        processResponse(packetWithSubtype, socketKey);
 
         // Complete response from otherInstanceName, without subtype
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                         otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
-                INTERFACE_INDEX, mockNetwork);
+                socketKey);
 
         // Address update from otherInstanceName
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 otherInstance, ipV6Address, 5353, SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Goodbye from otherInstanceName
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                 otherInstance, ipV6Address, 5353, SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), 0L /* ttl */), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), 0L /* ttl */), socketKey);
 
         // mockListenerOne gets notified for the requested instance
         final ArgumentMatcher<MdnsServiceInfo> subtypeInstanceMatcher = info ->
@@ -1251,21 +1387,24 @@
 
     @Test
     public void testNotifySocketDestroyed() throws Exception {
-        client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
+        client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
+                socketKey, mockSharedLog, thread.getLooper());
 
         final String requestedInstance = "instance1";
         final String otherInstance = "instance2";
         final String ipV4Address = "192.0.2.0";
 
         final MdnsSearchOptions resolveOptions = MdnsSearchOptions.newBuilder()
+                .setNumOfQueriesBeforeBackoff(Integer.MAX_VALUE)
                 .setResolveInstanceName("instance1").build();
 
-        client.startSendAndReceive(mockListenerOne, resolveOptions);
+        startSendAndReceive(mockListenerOne, resolveOptions);
         // Ensure the first task is executed so it schedules a future task
         currentThreadExecutor.getAndClearSubmittedFuture().get(
                 TEST_TIMEOUT_MS, TimeUnit.MILLISECONDS);
-        client.startSendAndReceive(mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
+        startSendAndReceive(mockListenerTwo,
+                MdnsSearchOptions.newBuilder().setNumOfQueriesBeforeBackoff(
+                        Integer.MAX_VALUE).build());
 
         // Filing the second request cancels the first future
         verify(expectedSendFutures[0]).cancel(true);
@@ -1275,19 +1414,19 @@
                 TEST_TIMEOUT_MS, TimeUnit.MILLISECONDS);
 
         // Complete response from instanceName
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                         requestedInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
-                INTERFACE_INDEX, mockNetwork);
+                socketKey);
 
         // Complete response from otherInstanceName
-        client.processResponse(createResponse(
+        processResponse(createResponse(
                         otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
-                INTERFACE_INDEX, mockNetwork);
+                socketKey);
 
         verify(expectedSendFutures[1], never()).cancel(true);
-        client.notifySocketDestroyed();
+        notifySocketDestroyed();
         verify(expectedSendFutures[1]).cancel(true);
 
         // mockListenerOne gets notified for the requested instance
@@ -1328,21 +1467,21 @@
 
     private void verifyAndSendQuery(int index, long timeInMs, boolean expectsUnicastResponse,
             boolean multipleSocketDiscovery) {
-        assertEquals(currentThreadExecutor.getAndClearLastScheduledDelayInMs(), timeInMs);
+        assertEquals(timeInMs, currentThreadExecutor.getAndClearLastScheduledDelayInMs());
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         if (expectsUnicastResponse) {
             verify(mockSocketClient).sendPacketRequestingUnicastResponse(
-                    expectedIPv4Packets[index], mockNetwork, false);
+                    expectedIPv4Packets[index], socketKey, false);
             if (multipleSocketDiscovery) {
                 verify(mockSocketClient).sendPacketRequestingUnicastResponse(
-                        expectedIPv6Packets[index], mockNetwork, false);
+                        expectedIPv6Packets[index], socketKey, false);
             }
         } else {
             verify(mockSocketClient).sendPacketRequestingMulticastResponse(
-                    expectedIPv4Packets[index], mockNetwork, false);
+                    expectedIPv4Packets[index], socketKey, false);
             if (multipleSocketDiscovery) {
                 verify(mockSocketClient).sendPacketRequestingMulticastResponse(
-                        expectedIPv6Packets[index], mockNetwork, false);
+                        expectedIPv6Packets[index], socketKey, false);
             }
         }
     }
@@ -1417,6 +1556,10 @@
             lastSubmittedFuture = null;
             return val;
         }
+
+        public int getNumOfScheduledFuture() {
+            return futureIndex - 1;
+        }
     }
 
     private MdnsPacket createResponse(
@@ -1435,7 +1578,7 @@
                 textAttributes, ptrTtlMillis);
     }
 
-    // Creates a mDNS response.
+
     private MdnsPacket createResponse(
             @NonNull String serviceInstanceName,
             @Nullable String host,
@@ -1443,6 +1586,19 @@
             @NonNull String[] type,
             @NonNull Map<String, String> textAttributes,
             long ptrTtlMillis) {
+        return createResponse(serviceInstanceName, host, port, type, textAttributes, ptrTtlMillis,
+                TEST_ELAPSED_REALTIME);
+    }
+
+    // Creates a mDNS response.
+    private MdnsPacket createResponse(
+            @NonNull String serviceInstanceName,
+            @Nullable String host,
+            int port,
+            @NonNull String[] type,
+            @NonNull Map<String, String> textAttributes,
+            long ptrTtlMillis,
+            long receiptTimeMillis) {
 
         final ArrayList<MdnsRecord> answerRecords = new ArrayList<>();
 
@@ -1453,7 +1609,7 @@
         final String[] serviceName = serviceNameList.toArray(new String[0]);
         final MdnsPointerRecord pointerRecord = new MdnsPointerRecord(
                 type,
-                TEST_ELAPSED_REALTIME /* receiptTimeMillis */,
+                receiptTimeMillis,
                 false /* cacheFlush */,
                 ptrTtlMillis,
                 serviceName);
@@ -1462,7 +1618,7 @@
         // Set SRV record.
         final MdnsServiceRecord serviceRecord = new MdnsServiceRecord(
                 serviceName,
-                TEST_ELAPSED_REALTIME /* receiptTimeMillis */,
+                receiptTimeMillis,
                 false /* cacheFlush */,
                 TEST_TTL,
                 0 /* servicePriority */,
@@ -1476,7 +1632,7 @@
             final InetAddress addr = InetAddresses.parseNumericAddress(host);
             final MdnsInetAddressRecord inetAddressRecord = new MdnsInetAddressRecord(
                     new String[] {"hostname"} /* name */,
-                    TEST_ELAPSED_REALTIME /* receiptTimeMillis */,
+                    receiptTimeMillis,
                     false /* cacheFlush */,
                     TEST_TTL,
                     addr);
@@ -1490,7 +1646,7 @@
         }
         final MdnsTextRecord textRecord = new MdnsTextRecord(
                 serviceName,
-                TEST_ELAPSED_REALTIME /* receiptTimeMillis */,
+                receiptTimeMillis,
                 false /* cacheFlush */,
                 TEST_TTL,
                 textEntries);
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
index 0eac5ec..e971de7 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
@@ -157,7 +157,6 @@
                 TETHERED_IFACE_NAME);
         doReturn(789).when(mDeps).getNetworkInterfaceIndexByName(
                 WIFI_P2P_IFACE_NAME);
-        doReturn(TETHERED_IFACE_IDX).when(mDeps).getInterfaceIndex(any());
         final HandlerThread thread = new HandlerThread("MdnsSocketProviderTest");
         thread.start();
         mHandler = new Handler(thread.getLooper());