Merge "SatelliteAccessController enhancements to handle delete user scenario" into main
diff --git a/Cronet/tests/common/Android.bp b/Cronet/tests/common/Android.bp
index edeb0b3..703f544 100644
--- a/Cronet/tests/common/Android.bp
+++ b/Cronet/tests/common/Android.bp
@@ -43,6 +43,7 @@
     jni_libs: [
         "cronet_aml_components_cronet_android_cronet_tests__testing",
         "cronet_aml_third_party_netty_tcnative_netty_tcnative_so__testing",
+        "libnativecoverage",
     ],
     data: [":cronet_javatests_resources"],
 }
diff --git a/TEST_MAPPING b/TEST_MAPPING
index ab3ed66..d8d4c21 100644
--- a/TEST_MAPPING
+++ b/TEST_MAPPING
@@ -246,6 +246,9 @@
         },
         {
           "exclude-annotation": "com.android.testutils.DnsResolverModuleTest"
+        },
+        {
+          "exclude-annotation": "com.android.testutils.NetworkStackModuleTest"
         }
       ]
     },
diff --git a/Tethering/Android.bp b/Tethering/Android.bp
index e4e6c70..19bcff9 100644
--- a/Tethering/Android.bp
+++ b/Tethering/Android.bp
@@ -74,6 +74,8 @@
         "net-utils-device-common-bpf",
         "net-utils-device-common-ip",
         "net-utils-device-common-netlink",
+        "net-utils-device-common-struct",
+        "net-utils-device-common-struct-base",
         "netd-client",
         "tetheringstatsprotos",
     ],
@@ -98,7 +100,6 @@
     ],
     static_libs: [
         "NetworkStackApiCurrentShims",
-        "net-utils-device-common-struct",
     ],
     apex_available: ["com.android.tethering"],
     lint: {
@@ -115,7 +116,6 @@
     ],
     static_libs: [
         "NetworkStackApiStableShims",
-        "net-utils-device-common-struct",
     ],
     apex_available: ["com.android.tethering"],
     lint: {
diff --git a/Tethering/tests/integration/Android.bp b/Tethering/tests/integration/Android.bp
index 07fa733..337d408 100644
--- a/Tethering/tests/integration/Android.bp
+++ b/Tethering/tests/integration/Android.bp
@@ -33,6 +33,7 @@
         "net-tests-utils",
         "net-utils-device-common",
         "net-utils-device-common-bpf",
+        "net-utils-device-common-struct-base",
         "testables",
         "connectivity-net-module-utils-bpf",
     ],
diff --git a/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java b/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
index 1022b06..f696885 100644
--- a/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
+++ b/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
@@ -176,6 +176,7 @@
         // tests, turn tethering on and off before running them.
         MyTetheringEventCallback callback = null;
         TestNetworkInterface testIface = null;
+        assumeTrue(sEm != null);
         try {
             // If the physical ethernet interface is available, do nothing.
             if (isInterfaceForTetheringAvailable()) return;
diff --git a/Tethering/tests/mts/Android.bp b/Tethering/tests/mts/Android.bp
index a80e49e..c4d5636 100644
--- a/Tethering/tests/mts/Android.bp
+++ b/Tethering/tests/mts/Android.bp
@@ -45,6 +45,7 @@
         "junit-params",
         "connectivity-net-module-utils-bpf",
         "net-utils-device-common-bpf",
+        "net-utils-device-common-struct-base",
     ],
 
     jni_libs: [
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
index 750bfce..f01e1bb 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
@@ -3622,6 +3622,43 @@
                 InetAddresses.parseNumericAddress(ifaceConfig.ipv4Addr), ifaceConfig.prefixLength);
         assertFalse(sapPrefix.equals(lohsPrefix));
     }
+
+    @Test
+    public void testWifiTetheringWhenP2pActive() throws Exception {
+        initTetheringOnTestThread();
+        // Enable wifi P2P.
+        sendWifiP2pConnectionChanged(true, true, TEST_P2P_IFNAME);
+        verifyInterfaceServingModeStarted(TEST_P2P_IFNAME);
+        verifyTetheringBroadcast(TEST_P2P_IFNAME, EXTRA_AVAILABLE_TETHER);
+        verifyTetheringBroadcast(TEST_P2P_IFNAME, EXTRA_ACTIVE_LOCAL_ONLY);
+        verify(mUpstreamNetworkMonitor).startObserveAllNetworks();
+        // Verify that upstream is never enabled when only P2P is active.
+        verify(mUpstreamNetworkMonitor, never()).setTryCell(true);
+        assertEquals(TETHER_ERROR_NO_ERROR, mTethering.getLastErrorForTest(TEST_P2P_IFNAME));
+
+        when(mWifiManager.startTetheredHotspot(any())).thenReturn(true);
+        // Emulate pressing the WiFi tethering button.
+        mTethering.startTethering(createTetheringRequestParcel(TETHERING_WIFI), TEST_CALLER_PKG,
+                null);
+        mLooper.dispatchAll();
+        verify(mWifiManager).startTetheredHotspot(null);
+        verifyNoMoreInteractions(mWifiManager);
+
+        mTethering.interfaceStatusChanged(TEST_WLAN_IFNAME, true);
+        sendWifiApStateChanged(WIFI_AP_STATE_ENABLED, TEST_WLAN_IFNAME, IFACE_IP_MODE_TETHERED);
+
+        verifyTetheringBroadcast(TEST_WLAN_IFNAME, EXTRA_AVAILABLE_TETHER);
+        verify(mWifiManager).updateInterfaceIpState(
+                TEST_WLAN_IFNAME, WifiManager.IFACE_IP_MODE_UNSPECIFIED);
+
+        verify(mWifiManager).updateInterfaceIpState(TEST_WLAN_IFNAME, IFACE_IP_MODE_TETHERED);
+        verifyNoMoreInteractions(mWifiManager);
+
+        verifyTetheringBroadcast(TEST_WLAN_IFNAME, EXTRA_ACTIVE_TETHER);
+        // FIXME: wifi tethering doesn't have upstream when P2P is enabled.
+        verify(mUpstreamNetworkMonitor, never()).setTryCell(true);
+    }
+
     // TODO: Test that a request for hotspot mode doesn't interfere with an
     // already operating tethering mode interface.
 }
diff --git a/bpf_progs/block.c b/bpf_progs/block.c
index 0a2b0b8..152dda6 100644
--- a/bpf_progs/block.c
+++ b/bpf_progs/block.c
@@ -19,8 +19,8 @@
 #include <netinet/in.h>
 #include <stdint.h>
 
-// The resulting .o needs to load on the Android T bpfloader
-#define BPFLOADER_MIN_VER BPFLOADER_T_VERSION
+// The resulting .o needs to load on Android T+
+#define BPFLOADER_MIN_VER BPFLOADER_MAINLINE_T_VERSION
 
 #include "bpf_helpers.h"
 
diff --git a/bpf_progs/clatd.c b/bpf_progs/clatd.c
index addb02f..f83e5ae 100644
--- a/bpf_progs/clatd.c
+++ b/bpf_progs/clatd.c
@@ -30,8 +30,8 @@
 #define __kernel_udphdr udphdr
 #include <linux/udp.h>
 
-// The resulting .o needs to load on the Android T bpfloader
-#define BPFLOADER_MIN_VER BPFLOADER_T_VERSION
+// The resulting .o needs to load on Android T+
+#define BPFLOADER_MIN_VER BPFLOADER_MAINLINE_T_VERSION
 
 #include "bpf_helpers.h"
 #include "bpf_net_helpers.h"
@@ -265,6 +265,10 @@
         *(struct iphdr*)data = ip;
     }
 
+    // Count successfully translated packet
+    __sync_fetch_and_add(&v->packets, 1);
+    __sync_fetch_and_add(&v->bytes, skb->len - l2_header_size);
+
     // Redirect, possibly back to same interface, so tcpdump sees packet twice.
     if (v->oif) return bpf_redirect(v->oif, BPF_F_INGRESS);
 
@@ -416,6 +420,10 @@
     // Copy over the new ipv6 header without an ethernet header.
     *(struct ipv6hdr*)data = ip6;
 
+    // Count successfully translated packet
+    __sync_fetch_and_add(&v->packets, 1);
+    __sync_fetch_and_add(&v->bytes, skb->len);
+
     // Redirect to non v4-* interface.  Tcpdump only sees packet after this redirect.
     return bpf_redirect(v->oif, 0 /* this is effectively BPF_F_EGRESS */);
 }
diff --git a/bpf_progs/clatd.h b/bpf_progs/clatd.h
index b5f1cdc..a75798f 100644
--- a/bpf_progs/clatd.h
+++ b/bpf_progs/clatd.h
@@ -39,8 +39,10 @@
 typedef struct {
     uint32_t oif;           // The output interface to redirect to (0 means don't redirect)
     struct in_addr local4;  // The destination IPv4 address
+    uint64_t packets;       // Count of translated gso (large) packets
+    uint64_t bytes;         // Sum of post-translation skb->len
 } ClatIngress6Value;
-STRUCT_SIZE(ClatIngress6Value, 4 + 4);  // 8
+STRUCT_SIZE(ClatIngress6Value, 4 + 4 + 8 + 8);  // 24
 
 typedef struct {
     uint32_t iif;           // The input interface index
@@ -54,7 +56,9 @@
     struct in6_addr pfx96;   // The destination /96 nat64 prefix, bottom 32 bits must be 0
     bool oifIsEthernet;      // Whether the output interface requires ethernet header
     uint8_t pad[3];
+    uint64_t packets;       // Count of translated gso (large) packets
+    uint64_t bytes;         // Sum of post-translation skb->len
 } ClatEgress4Value;
-STRUCT_SIZE(ClatEgress4Value, 4 + 2 * 16 + 1 + 3);  // 40
+STRUCT_SIZE(ClatEgress4Value, 4 + 2 * 16 + 1 + 3 + 8 + 8);  // 56
 
 #undef STRUCT_SIZE
diff --git a/bpf_progs/dscpPolicy.c b/bpf_progs/dscpPolicy.c
index e845a69..ed114e4 100644
--- a/bpf_progs/dscpPolicy.c
+++ b/bpf_progs/dscpPolicy.c
@@ -27,8 +27,8 @@
 #include <stdint.h>
 #include <string.h>
 
-// The resulting .o needs to load on the Android T bpfloader
-#define BPFLOADER_MIN_VER BPFLOADER_T_VERSION
+// The resulting .o needs to load on Android T+
+#define BPFLOADER_MIN_VER BPFLOADER_MAINLINE_T_VERSION
 
 #include "bpf_helpers.h"
 #include "dscpPolicy.h"
diff --git a/bpf_progs/netd.c b/bpf_progs/netd.c
index 5e401aa..dfc7699 100644
--- a/bpf_progs/netd.c
+++ b/bpf_progs/netd.c
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-// The resulting .o needs to load on the Android T bpfloader
-#define BPFLOADER_MIN_VER BPFLOADER_T_VERSION
+// The resulting .o needs to load on Android T+
+#define BPFLOADER_MIN_VER BPFLOADER_MAINLINE_T_VERSION
 
 #include <bpf_helpers.h>
 #include <linux/bpf.h>
@@ -103,13 +103,13 @@
 // A single-element configuration array, packet tracing is enabled when 'true'.
 DEFINE_BPF_MAP_EXT(packet_trace_enabled_map, ARRAY, uint32_t, bool, 1,
                    AID_ROOT, AID_SYSTEM, 0060, "fs_bpf_net_shared", "", PRIVATE,
-                   BPFLOADER_IGNORED_ON_VERSION, BPFLOADER_MAX_VER, LOAD_ON_ENG,
+                   BPFLOADER_MAINLINE_U_VERSION, BPFLOADER_MAX_VER, LOAD_ON_ENG,
                    LOAD_ON_USER, LOAD_ON_USERDEBUG)
 
 // A ring buffer on which packet information is pushed.
 DEFINE_BPF_RINGBUF_EXT(packet_trace_ringbuf, PacketTrace, PACKET_TRACE_BUF_SIZE,
                        AID_ROOT, AID_SYSTEM, 0060, "fs_bpf_net_shared", "", PRIVATE,
-                       BPFLOADER_IGNORED_ON_VERSION, BPFLOADER_MAX_VER, LOAD_ON_ENG,
+                       BPFLOADER_MAINLINE_U_VERSION, BPFLOADER_MAX_VER, LOAD_ON_ENG,
                        LOAD_ON_USER, LOAD_ON_USERDEBUG);
 
 DEFINE_BPF_MAP_RO_NETD(data_saver_enabled_map, ARRAY, uint32_t, bool,
@@ -516,7 +516,7 @@
 // This program is optional, and enables tracing on Android U+, 5.8+ on user builds.
 DEFINE_BPF_PROG_EXT("cgroupskb/ingress/stats$trace_user", AID_ROOT, AID_SYSTEM,
                     bpf_cgroup_ingress_trace_user, KVER_5_8, KVER_INF,
-                    BPFLOADER_IGNORED_ON_VERSION, BPFLOADER_MAX_VER, OPTIONAL,
+                    BPFLOADER_MAINLINE_U_VERSION, BPFLOADER_MAX_VER, OPTIONAL,
                     "fs_bpf_netd_readonly", "",
                     IGNORE_ON_ENG, LOAD_ON_USER, IGNORE_ON_USERDEBUG)
 (struct __sk_buff* skb) {
@@ -526,7 +526,7 @@
 // This program is required, and enables tracing on Android U+, 5.8+, userdebug/eng.
 DEFINE_BPF_PROG_EXT("cgroupskb/ingress/stats$trace", AID_ROOT, AID_SYSTEM,
                     bpf_cgroup_ingress_trace, KVER_5_8, KVER_INF,
-                    BPFLOADER_IGNORED_ON_VERSION, BPFLOADER_MAX_VER, MANDATORY,
+                    BPFLOADER_MAINLINE_U_VERSION, BPFLOADER_MAX_VER, MANDATORY,
                     "fs_bpf_netd_readonly", "",
                     LOAD_ON_ENG, IGNORE_ON_USER, LOAD_ON_USERDEBUG)
 (struct __sk_buff* skb) {
@@ -548,7 +548,7 @@
 // This program is optional, and enables tracing on Android U+, 5.8+ on user builds.
 DEFINE_BPF_PROG_EXT("cgroupskb/egress/stats$trace_user", AID_ROOT, AID_SYSTEM,
                     bpf_cgroup_egress_trace_user, KVER_5_8, KVER_INF,
-                    BPFLOADER_IGNORED_ON_VERSION, BPFLOADER_MAX_VER, OPTIONAL,
+                    BPFLOADER_MAINLINE_U_VERSION, BPFLOADER_MAX_VER, OPTIONAL,
                     "fs_bpf_netd_readonly", "",
                     IGNORE_ON_ENG, LOAD_ON_USER, IGNORE_ON_USERDEBUG)
 (struct __sk_buff* skb) {
@@ -558,7 +558,7 @@
 // This program is required, and enables tracing on Android U+, 5.8+, userdebug/eng.
 DEFINE_BPF_PROG_EXT("cgroupskb/egress/stats$trace", AID_ROOT, AID_SYSTEM,
                     bpf_cgroup_egress_trace, KVER_5_8, KVER_INF,
-                    BPFLOADER_IGNORED_ON_VERSION, BPFLOADER_MAX_VER, MANDATORY,
+                    BPFLOADER_MAINLINE_U_VERSION, BPFLOADER_MAX_VER, MANDATORY,
                     "fs_bpf_netd_readonly", "",
                     LOAD_ON_ENG, IGNORE_ON_USER, LOAD_ON_USERDEBUG)
 (struct __sk_buff* skb) {
diff --git a/bpf_progs/offload.c b/bpf_progs/offload.c
index dd59dca..4f152bf 100644
--- a/bpf_progs/offload.c
+++ b/bpf_progs/offload.c
@@ -28,11 +28,11 @@
 // BTF is incompatible with bpfloaders < v0.10, hence for S (v0.2) we must
 // ship a different file than for later versions, but we need bpfloader v0.25+
 // for obj@ver.o support
-#define BPFLOADER_MIN_VER BPFLOADER_OBJ_AT_VER_VERSION
+#define BPFLOADER_MIN_VER BPFLOADER_MAINLINE_T_VERSION
 #else /* MAINLINE */
 // The resulting .o needs to load on the Android S bpfloader
 #define BPFLOADER_MIN_VER BPFLOADER_S_VERSION
-#define BPFLOADER_MAX_VER BPFLOADER_OBJ_AT_VER_VERSION
+#define BPFLOADER_MAX_VER BPFLOADER_T_VERSION
 #endif /* MAINLINE */
 
 // Warning: values other than AID_ROOT don't work for map uid on BpfLoader < v0.21
diff --git a/bpf_progs/test.c b/bpf_progs/test.c
index e2b8ea5..fff3512 100644
--- a/bpf_progs/test.c
+++ b/bpf_progs/test.c
@@ -22,11 +22,11 @@
 // BTF is incompatible with bpfloaders < v0.10, hence for S (v0.2) we must
 // ship a different file than for later versions, but we need bpfloader v0.25+
 // for obj@ver.o support
-#define BPFLOADER_MIN_VER BPFLOADER_OBJ_AT_VER_VERSION
+#define BPFLOADER_MIN_VER BPFLOADER_MAINLINE_T_VERSION
 #else /* MAINLINE */
 // The resulting .o needs to load on the Android S bpfloader
 #define BPFLOADER_MIN_VER BPFLOADER_S_VERSION
-#define BPFLOADER_MAX_VER BPFLOADER_OBJ_AT_VER_VERSION
+#define BPFLOADER_MAX_VER BPFLOADER_T_VERSION
 #endif /* MAINLINE */
 
 // Warning: values other than AID_ROOT don't work for map uid on BpfLoader < v0.21
diff --git a/common/Android.bp b/common/Android.bp
index 0048a0a..5fabf41 100644
--- a/common/Android.bp
+++ b/common/Android.bp
@@ -26,7 +26,7 @@
 // as the above target may not exist
 // depending on the branch
 
-// The library requires the final artifact to contain net-utils-device-common-struct.
+// The library requires the final artifact to contain net-utils-device-common-struct-base.
 java_library {
     name: "connectivity-net-module-utils-bpf",
     srcs: [
@@ -45,7 +45,7 @@
         // For libraries which are statically linked in framework-connectivity, do not
         // statically link here because callers of this library might already have a static
         // version linked.
-        "net-utils-device-common-struct",
+        "net-utils-device-common-struct-base",
     ],
     apex_available: [
         "com.android.tethering",
diff --git a/common/src/com/android/net/module/util/bpf/ClatEgress4Value.java b/common/src/com/android/net/module/util/bpf/ClatEgress4Value.java
index 69fab09..71f7516 100644
--- a/common/src/com/android/net/module/util/bpf/ClatEgress4Value.java
+++ b/common/src/com/android/net/module/util/bpf/ClatEgress4Value.java
@@ -36,11 +36,24 @@
     @Field(order = 3, type = Type.U8, padding = 3)
     public final short oifIsEthernet; // Whether the output interface requires ethernet header
 
+    @Field(order = 4, type = Type.U63)
+    public final long packets; // Count of translated gso (large) packets
+
+    @Field(order = 5, type = Type.U63)
+    public final long bytes; // Sum of post-translation skb->len
+
     public ClatEgress4Value(final int oif, final Inet6Address local6, final Inet6Address pfx96,
-            final short oifIsEthernet) {
+            final short oifIsEthernet, final long packets, final long bytes) {
         this.oif = oif;
         this.local6 = local6;
         this.pfx96 = pfx96;
         this.oifIsEthernet = oifIsEthernet;
+        this.packets = packets;
+        this.bytes = bytes;
+    }
+
+    public ClatEgress4Value(final int oif, final Inet6Address local6, final Inet6Address pfx96,
+            final short oifIsEthernet) {
+        this(oif, local6, pfx96, oifIsEthernet, 0, 0);
     }
 }
diff --git a/common/src/com/android/net/module/util/bpf/ClatIngress6Value.java b/common/src/com/android/net/module/util/bpf/ClatIngress6Value.java
index fb81caa..25f737b 100644
--- a/common/src/com/android/net/module/util/bpf/ClatIngress6Value.java
+++ b/common/src/com/android/net/module/util/bpf/ClatIngress6Value.java
@@ -30,8 +30,21 @@
     @Field(order = 1, type = Type.Ipv4Address)
     public final Inet4Address local4; // The destination IPv4 address
 
-    public ClatIngress6Value(final int oif, final Inet4Address local4) {
+    @Field(order = 2, type = Type.U63)
+    public final long packets; // Count of translated gso (large) packets
+
+    @Field(order = 3, type = Type.U63)
+    public final long bytes; // Sum of post-translation skb->len
+
+    public ClatIngress6Value(final int oif, final Inet4Address local4, final long packets,
+            final long bytes) {
         this.oif = oif;
         this.local4 = local4;
+        this.packets = packets;
+        this.bytes = bytes;
+    }
+
+    public ClatIngress6Value(final int oif, final Inet4Address local4) {
+        this(oif, local4, 0, 0);
     }
 }
diff --git a/framework-t/src/android/net/EthernetManager.java b/framework-t/src/android/net/EthernetManager.java
index b8070f0..719f60d 100644
--- a/framework-t/src/android/net/EthernetManager.java
+++ b/framework-t/src/android/net/EthernetManager.java
@@ -642,7 +642,14 @@
     }
 
     /**
-     * Listen to changes in the state of ethernet.
+     * Register an IntConsumer to be called back on ethernet state changes.
+     *
+     * <p>{@link IntConsumer#accept} with the current ethernet state will be triggered immediately
+     * upon adding a listener. The same callback is invoked on Ethernet state change, i.e. when
+     * calling {@link #setEthernetEnabled}.
+     * <p>The reported state is one of:
+     * <ul><li>{@link #ETHERNET_STATE_DISABLED}: ethernet is now disabled.</li>
+     * <li>{@link #ETHERNET_STATE_ENABLED}: ethernet is now enabled.</li></ul>
      *
      * @param executor to run callbacks on.
      * @param listener to listen ethernet state changed.
diff --git a/framework/Android.bp b/framework/Android.bp
index 52f2c7c..8787167 100644
--- a/framework/Android.bp
+++ b/framework/Android.bp
@@ -96,6 +96,7 @@
     ],
     impl_only_static_libs: [
         "net-utils-device-common-bpf",
+        "net-utils-device-common-struct-base",
     ],
     libs: [
         "androidx.annotation_annotation",
@@ -124,6 +125,7 @@
         // Even if the library is included in "impl_only_static_libs" of defaults. This is still
         // needed because java_library which doesn't understand "impl_only_static_libs".
         "net-utils-device-common-bpf",
+        "net-utils-device-common-struct-base",
     ],
     libs: [
         // This cannot be in the defaults clause above because if it were, it would be used
diff --git a/framework/src/android/net/BpfNetMapsReader.java b/framework/src/android/net/BpfNetMapsReader.java
deleted file mode 100644
index ee422ab..0000000
--- a/framework/src/android/net/BpfNetMapsReader.java
+++ /dev/null
@@ -1,287 +0,0 @@
-/*
- * Copyright (C) 2023 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package android.net;
-
-import static android.net.BpfNetMapsConstants.CONFIGURATION_MAP_PATH;
-import static android.net.BpfNetMapsConstants.DATA_SAVER_ENABLED;
-import static android.net.BpfNetMapsConstants.DATA_SAVER_ENABLED_KEY;
-import static android.net.BpfNetMapsConstants.DATA_SAVER_ENABLED_MAP_PATH;
-import static android.net.BpfNetMapsConstants.HAPPY_BOX_MATCH;
-import static android.net.BpfNetMapsConstants.PENALTY_BOX_MATCH;
-import static android.net.BpfNetMapsConstants.UID_OWNER_MAP_PATH;
-import static android.net.BpfNetMapsConstants.UID_RULES_CONFIGURATION_KEY;
-import static android.net.BpfNetMapsUtils.getMatchByFirewallChain;
-import static android.net.BpfNetMapsUtils.isFirewallAllowList;
-import static android.net.BpfNetMapsUtils.throwIfPreT;
-import static android.net.ConnectivityManager.FIREWALL_RULE_ALLOW;
-import static android.net.ConnectivityManager.FIREWALL_RULE_DENY;
-
-import android.annotation.NonNull;
-import android.annotation.RequiresApi;
-import android.os.Build;
-import android.os.ServiceSpecificException;
-import android.system.ErrnoException;
-import android.system.Os;
-
-import com.android.internal.annotations.VisibleForTesting;
-import com.android.modules.utils.build.SdkLevel;
-import com.android.net.module.util.BpfMap;
-import com.android.net.module.util.IBpfMap;
-import com.android.net.module.util.Struct.S32;
-import com.android.net.module.util.Struct.U32;
-import com.android.net.module.util.Struct.U8;
-
-/**
- * A helper class to *read* java BpfMaps.
- * @hide
- */
-@RequiresApi(Build.VERSION_CODES.TIRAMISU)  // BPF maps were only mainlined in T
-public class BpfNetMapsReader {
-    private static final String TAG = BpfNetMapsReader.class.getSimpleName();
-
-    // Locally store the handle of bpf maps. The FileDescriptors are statically cached inside the
-    // BpfMap implementation.
-
-    // Bpf map to store various networking configurations, the format of the value is different
-    // for different keys. See BpfNetMapsConstants#*_CONFIGURATION_KEY for keys.
-    private final IBpfMap<S32, U32> mConfigurationMap;
-    // Bpf map to store per uid traffic control configurations.
-    // See {@link UidOwnerValue} for more detail.
-    private final IBpfMap<S32, UidOwnerValue> mUidOwnerMap;
-    private final IBpfMap<S32, U8> mDataSaverEnabledMap;
-    private final Dependencies mDeps;
-
-    // Bitmaps for calculating whether a given uid is blocked by firewall chains.
-    private static final long sMaskDropIfSet;
-    private static final long sMaskDropIfUnset;
-
-    static {
-        long maskDropIfSet = 0L;
-        long maskDropIfUnset = 0L;
-
-        for (int chain : BpfNetMapsConstants.ALLOW_CHAINS) {
-            final long match = getMatchByFirewallChain(chain);
-            maskDropIfUnset |= match;
-        }
-        for (int chain : BpfNetMapsConstants.DENY_CHAINS) {
-            final long match = getMatchByFirewallChain(chain);
-            maskDropIfSet |= match;
-        }
-        sMaskDropIfSet = maskDropIfSet;
-        sMaskDropIfUnset = maskDropIfUnset;
-    }
-
-    private static class SingletonHolder {
-        static final BpfNetMapsReader sInstance = new BpfNetMapsReader();
-    }
-
-    @NonNull
-    public static BpfNetMapsReader getInstance() {
-        return SingletonHolder.sInstance;
-    }
-
-    private BpfNetMapsReader() {
-        this(new Dependencies());
-    }
-
-    // While the production code uses the singleton to optimize for performance and deal with
-    // concurrent access, the test needs to use a non-static approach for dependency injection and
-    // mocking virtual bpf maps.
-    @VisibleForTesting
-    public BpfNetMapsReader(@NonNull Dependencies deps) {
-        if (!SdkLevel.isAtLeastT()) {
-            throw new UnsupportedOperationException(
-                    BpfNetMapsReader.class.getSimpleName() + " is not supported below Android T");
-        }
-        mDeps = deps;
-        mConfigurationMap = mDeps.getConfigurationMap();
-        mUidOwnerMap = mDeps.getUidOwnerMap();
-        mDataSaverEnabledMap = mDeps.getDataSaverEnabledMap();
-    }
-
-    /**
-     * Dependencies of BpfNetMapReader, for injection in tests.
-     */
-    @VisibleForTesting
-    public static class Dependencies {
-        /** Get the configuration map. */
-        public IBpfMap<S32, U32> getConfigurationMap() {
-            try {
-                return new BpfMap<>(CONFIGURATION_MAP_PATH, BpfMap.BPF_F_RDONLY,
-                        S32.class, U32.class);
-            } catch (ErrnoException e) {
-                throw new IllegalStateException("Cannot open configuration map", e);
-            }
-        }
-
-        /** Get the uid owner map. */
-        public IBpfMap<S32, UidOwnerValue> getUidOwnerMap() {
-            try {
-                return new BpfMap<>(UID_OWNER_MAP_PATH, BpfMap.BPF_F_RDONLY,
-                        S32.class, UidOwnerValue.class);
-            } catch (ErrnoException e) {
-                throw new IllegalStateException("Cannot open uid owner map", e);
-            }
-        }
-
-        /** Get the data saver enabled map. */
-        public  IBpfMap<S32, U8> getDataSaverEnabledMap() {
-            try {
-                return new BpfMap<>(DATA_SAVER_ENABLED_MAP_PATH, BpfMap.BPF_F_RDONLY, S32.class,
-                        U8.class);
-            } catch (ErrnoException e) {
-                throw new IllegalStateException("Cannot open data saver enabled map", e);
-            }
-        }
-    }
-
-    /**
-     * Get the specified firewall chain's status.
-     *
-     * @param chain target chain
-     * @return {@code true} if chain is enabled, {@code false} if chain is not enabled.
-     * @throws UnsupportedOperationException if called on pre-T devices.
-     * @throws ServiceSpecificException in case of failure, with an error code indicating the
-     *                                  cause of the failure.
-     */
-    public boolean isChainEnabled(final int chain) {
-        return isChainEnabled(mConfigurationMap, chain);
-    }
-
-    /**
-     * Get firewall rule of specified firewall chain on specified uid.
-     *
-     * @param chain target chain
-     * @param uid        target uid
-     * @return either {@link ConnectivityManager#FIREWALL_RULE_ALLOW} or
-     *         {@link ConnectivityManager#FIREWALL_RULE_DENY}.
-     * @throws UnsupportedOperationException if called on pre-T devices.
-     * @throws ServiceSpecificException in case of failure, with an error code indicating the
-     *                                  cause of the failure.
-     */
-    public int getUidRule(final int chain, final int uid) {
-        return getUidRule(mUidOwnerMap, chain, uid);
-    }
-
-    /**
-     * Get the specified firewall chain's status.
-     *
-     * @param configurationMap target configurationMap
-     * @param chain target chain
-     * @return {@code true} if chain is enabled, {@code false} if chain is not enabled.
-     * @throws UnsupportedOperationException if called on pre-T devices.
-     * @throws ServiceSpecificException in case of failure, with an error code indicating the
-     *                                  cause of the failure.
-     */
-    public static boolean isChainEnabled(
-            final IBpfMap<S32, U32> configurationMap, final int chain) {
-        throwIfPreT("isChainEnabled is not available on pre-T devices");
-
-        final long match = getMatchByFirewallChain(chain);
-        try {
-            final U32 config = configurationMap.getValue(UID_RULES_CONFIGURATION_KEY);
-            return (config.val & match) != 0;
-        } catch (ErrnoException e) {
-            throw new ServiceSpecificException(e.errno,
-                    "Unable to get firewall chain status: " + Os.strerror(e.errno));
-        }
-    }
-
-    /**
-     * Get firewall rule of specified firewall chain on specified uid.
-     *
-     * @param uidOwnerMap target uidOwnerMap.
-     * @param chain target chain.
-     * @param uid target uid.
-     * @return either FIREWALL_RULE_ALLOW or FIREWALL_RULE_DENY
-     * @throws UnsupportedOperationException if called on pre-T devices.
-     * @throws ServiceSpecificException      in case of failure, with an error code indicating the
-     *                                       cause of the failure.
-     */
-    public static int getUidRule(final IBpfMap<S32, UidOwnerValue> uidOwnerMap,
-            final int chain, final int uid) {
-        throwIfPreT("getUidRule is not available on pre-T devices");
-
-        final long match = getMatchByFirewallChain(chain);
-        final boolean isAllowList = isFirewallAllowList(chain);
-        try {
-            final UidOwnerValue uidMatch = uidOwnerMap.getValue(new S32(uid));
-            final boolean isMatchEnabled = uidMatch != null && (uidMatch.rule & match) != 0;
-            return isMatchEnabled == isAllowList ? FIREWALL_RULE_ALLOW : FIREWALL_RULE_DENY;
-        } catch (ErrnoException e) {
-            throw new ServiceSpecificException(e.errno,
-                    "Unable to get uid rule status: " + Os.strerror(e.errno));
-        }
-    }
-
-    /**
-     * Return whether the network is blocked by firewall chains for the given uid.
-     *
-     * @param uid The target uid.
-     * @param isNetworkMetered Whether the target network is metered.
-     * @param isDataSaverEnabled Whether the data saver is enabled.
-     *
-     * @return True if the network is blocked. Otherwise, false.
-     * @throws ServiceSpecificException if the read fails.
-     *
-     * @hide
-     */
-    public boolean isUidNetworkingBlocked(final int uid, boolean isNetworkMetered,
-            boolean isDataSaverEnabled) {
-        throwIfPreT("isUidBlockedByFirewallChains is not available on pre-T devices");
-
-        final long uidRuleConfig;
-        final long uidMatch;
-        try {
-            uidRuleConfig = mConfigurationMap.getValue(UID_RULES_CONFIGURATION_KEY).val;
-            final UidOwnerValue value = mUidOwnerMap.getValue(new S32(uid));
-            uidMatch = (value != null) ? value.rule : 0L;
-        } catch (ErrnoException e) {
-            throw new ServiceSpecificException(e.errno,
-                    "Unable to get firewall chain status: " + Os.strerror(e.errno));
-        }
-
-        final boolean blockedByAllowChains = 0 != (uidRuleConfig & ~uidMatch & sMaskDropIfUnset);
-        final boolean blockedByDenyChains = 0 != (uidRuleConfig & uidMatch & sMaskDropIfSet);
-        if (blockedByAllowChains || blockedByDenyChains) {
-            return true;
-        }
-
-        if (!isNetworkMetered) return false;
-        if ((uidMatch & PENALTY_BOX_MATCH) != 0) return true;
-        if ((uidMatch & HAPPY_BOX_MATCH) != 0) return false;
-        return isDataSaverEnabled;
-    }
-
-    /**
-     * Get Data Saver enabled or disabled
-     *
-     * @return whether Data Saver is enabled or disabled.
-     * @throws ServiceSpecificException in case of failure, with an error code indicating the
-     *                                  cause of the failure.
-     */
-    public boolean getDataSaverEnabled() {
-        throwIfPreT("getDataSaverEnabled is not available on pre-T devices");
-
-        try {
-            return mDataSaverEnabledMap.getValue(DATA_SAVER_ENABLED_KEY).val == DATA_SAVER_ENABLED;
-        } catch (ErrnoException e) {
-            throw new ServiceSpecificException(e.errno, "Unable to get data saver: "
-                    + Os.strerror(e.errno));
-        }
-    }
-}
diff --git a/framework/src/android/net/BpfNetMapsUtils.java b/framework/src/android/net/BpfNetMapsUtils.java
index 0be30bb..19ecafb 100644
--- a/framework/src/android/net/BpfNetMapsUtils.java
+++ b/framework/src/android/net/BpfNetMapsUtils.java
@@ -18,17 +18,22 @@
 
 import static android.net.BpfNetMapsConstants.ALLOW_CHAINS;
 import static android.net.BpfNetMapsConstants.BACKGROUND_MATCH;
+import static android.net.BpfNetMapsConstants.DATA_SAVER_ENABLED;
+import static android.net.BpfNetMapsConstants.DATA_SAVER_ENABLED_KEY;
 import static android.net.BpfNetMapsConstants.DENY_CHAINS;
 import static android.net.BpfNetMapsConstants.DOZABLE_MATCH;
+import static android.net.BpfNetMapsConstants.HAPPY_BOX_MATCH;
 import static android.net.BpfNetMapsConstants.LOW_POWER_STANDBY_MATCH;
 import static android.net.BpfNetMapsConstants.MATCH_LIST;
 import static android.net.BpfNetMapsConstants.NO_MATCH;
 import static android.net.BpfNetMapsConstants.OEM_DENY_1_MATCH;
 import static android.net.BpfNetMapsConstants.OEM_DENY_2_MATCH;
 import static android.net.BpfNetMapsConstants.OEM_DENY_3_MATCH;
+import static android.net.BpfNetMapsConstants.PENALTY_BOX_MATCH;
 import static android.net.BpfNetMapsConstants.POWERSAVE_MATCH;
 import static android.net.BpfNetMapsConstants.RESTRICTED_MATCH;
 import static android.net.BpfNetMapsConstants.STANDBY_MATCH;
+import static android.net.BpfNetMapsConstants.UID_RULES_CONFIGURATION_KEY;
 import static android.net.ConnectivityManager.FIREWALL_CHAIN_BACKGROUND;
 import static android.net.ConnectivityManager.FIREWALL_CHAIN_DOZABLE;
 import static android.net.ConnectivityManager.FIREWALL_CHAIN_LOW_POWER_STANDBY;
@@ -38,12 +43,22 @@
 import static android.net.ConnectivityManager.FIREWALL_CHAIN_POWERSAVE;
 import static android.net.ConnectivityManager.FIREWALL_CHAIN_RESTRICTED;
 import static android.net.ConnectivityManager.FIREWALL_CHAIN_STANDBY;
+import static android.net.ConnectivityManager.FIREWALL_RULE_ALLOW;
+import static android.net.ConnectivityManager.FIREWALL_RULE_DENY;
 import static android.system.OsConstants.EINVAL;
 
+import android.os.Process;
 import android.os.ServiceSpecificException;
+import android.system.ErrnoException;
+import android.system.Os;
 import android.util.Pair;
 
 import com.android.modules.utils.build.SdkLevel;
+import com.android.net.module.util.IBpfMap;
+import com.android.net.module.util.Struct;
+import com.android.net.module.util.Struct.S32;
+import com.android.net.module.util.Struct.U32;
+import com.android.net.module.util.Struct.U8;
 
 import java.util.StringJoiner;
 
@@ -56,6 +71,26 @@
 // Because modules could have different copies of this class if this is statically linked,
 // which would be problematic if the definitions in these modules are not synchronized.
 public class BpfNetMapsUtils {
+    // Bitmaps for calculating whether a given uid is blocked by firewall chains.
+    private static final long sMaskDropIfSet;
+    private static final long sMaskDropIfUnset;
+
+    static {
+        long maskDropIfSet = 0L;
+        long maskDropIfUnset = 0L;
+
+        for (int chain : BpfNetMapsConstants.ALLOW_CHAINS) {
+            final long match = getMatchByFirewallChain(chain);
+            maskDropIfUnset |= match;
+        }
+        for (int chain : BpfNetMapsConstants.DENY_CHAINS) {
+            final long match = getMatchByFirewallChain(chain);
+            maskDropIfSet |= match;
+        }
+        sMaskDropIfSet = maskDropIfSet;
+        sMaskDropIfUnset = maskDropIfUnset;
+    }
+
     // Prevent this class from being accidental instantiated.
     private BpfNetMapsUtils() {}
 
@@ -133,4 +168,128 @@
             throw new UnsupportedOperationException(msg);
         }
     }
+
+    /**
+     * Get the specified firewall chain's status.
+     *
+     * @param configurationMap target configurationMap
+     * @param chain target chain
+     * @return {@code true} if chain is enabled, {@code false} if chain is not enabled.
+     * @throws UnsupportedOperationException if called on pre-T devices.
+     * @throws ServiceSpecificException in case of failure, with an error code indicating the
+     *                                  cause of the failure.
+     */
+    public static boolean isChainEnabled(
+            final IBpfMap<S32, U32> configurationMap, final int chain) {
+        throwIfPreT("isChainEnabled is not available on pre-T devices");
+
+        final long match = getMatchByFirewallChain(chain);
+        try {
+            final U32 config = configurationMap.getValue(UID_RULES_CONFIGURATION_KEY);
+            return (config.val & match) != 0;
+        } catch (ErrnoException e) {
+            throw new ServiceSpecificException(e.errno,
+                    "Unable to get firewall chain status: " + Os.strerror(e.errno));
+        }
+    }
+
+    /**
+     * Get firewall rule of specified firewall chain on specified uid.
+     *
+     * @param uidOwnerMap target uidOwnerMap.
+     * @param chain target chain.
+     * @param uid target uid.
+     * @return either FIREWALL_RULE_ALLOW or FIREWALL_RULE_DENY
+     * @throws UnsupportedOperationException if called on pre-T devices.
+     * @throws ServiceSpecificException      in case of failure, with an error code indicating the
+     *                                       cause of the failure.
+     */
+    public static int getUidRule(final IBpfMap<S32, UidOwnerValue> uidOwnerMap,
+            final int chain, final int uid) {
+        throwIfPreT("getUidRule is not available on pre-T devices");
+
+        final long match = getMatchByFirewallChain(chain);
+        final boolean isAllowList = isFirewallAllowList(chain);
+        try {
+            final UidOwnerValue uidMatch = uidOwnerMap.getValue(new S32(uid));
+            final boolean isMatchEnabled = uidMatch != null && (uidMatch.rule & match) != 0;
+            return isMatchEnabled == isAllowList ? FIREWALL_RULE_ALLOW : FIREWALL_RULE_DENY;
+        } catch (ErrnoException e) {
+            throw new ServiceSpecificException(e.errno,
+                    "Unable to get uid rule status: " + Os.strerror(e.errno));
+        }
+    }
+
+    /**
+     * Return whether the network is blocked by firewall chains for the given uid.
+     *
+     * Before V, {@link #getDataSaverEnabled(IBpfMap)} may lag behind the actual setting.
+     *
+     * @param uid The target uid.
+     * @param isNetworkMetered Whether the target network is metered.
+     *
+     * @return True if the network is blocked. Otherwise, false.
+     * @throws ServiceSpecificException if the read fails.
+     *
+     * @hide
+     */
+    public static boolean isUidNetworkingBlocked(final int uid, boolean isNetworkMetered,
+            IBpfMap<S32, U32> configurationMap,
+            IBpfMap<S32, UidOwnerValue> uidOwnerMap,
+            IBpfMap<S32, U8> dataSaverEnabledMap
+    ) {
+        throwIfPreT("isUidBlockedByFirewallChains is not available on pre-T devices");
+
+        // System uid is not blocked by firewall chains, see bpf_progs/netd.c
+        // TODO: use UserHandle.isCore() once it is accessible
+        if (uid < Process.FIRST_APPLICATION_UID) {
+            return false;
+        }
+
+        final long uidRuleConfig;
+        final long uidMatch;
+        try {
+            uidRuleConfig = configurationMap.getValue(UID_RULES_CONFIGURATION_KEY).val;
+            final UidOwnerValue value = uidOwnerMap.getValue(new Struct.S32(uid));
+            uidMatch = (value != null) ? value.rule : 0L;
+        } catch (ErrnoException e) {
+            throw new ServiceSpecificException(e.errno,
+                    "Unable to get firewall chain status: " + Os.strerror(e.errno));
+        }
+
+        final boolean blockedByAllowChains = 0 != (uidRuleConfig & ~uidMatch & sMaskDropIfUnset);
+        final boolean blockedByDenyChains = 0 != (uidRuleConfig & uidMatch & sMaskDropIfSet);
+        if (blockedByAllowChains || blockedByDenyChains) {
+            return true;
+        }
+
+        if (!isNetworkMetered) return false;
+        if ((uidMatch & PENALTY_BOX_MATCH) != 0) return true;
+        if ((uidMatch & HAPPY_BOX_MATCH) != 0) return false;
+        return getDataSaverEnabled(dataSaverEnabledMap);
+    }
+
+    /**
+     * Get whether Data Saver is enabled.
+     *
+     * Note that before V, the data saver status in bpf is written by ConnectivityService
+     * when receiving {@link ConnectivityManager#ACTION_RESTRICT_BACKGROUND_CHANGED}. Thus,
+     * the status is not synchronized.
+     * On V+, the data saver status is set by platform code when enabling/disabling
+     * data saver, which is synchronized.
+     *
+     * @return whether Data Saver is enabled or disabled.
+     * @throws ServiceSpecificException in case of failure, with an error code indicating the
+     *                                  cause of the failure.
+     */
+    public static boolean getDataSaverEnabled(IBpfMap<S32, U8> dataSaverEnabledMap) {
+        throwIfPreT("getDataSaverEnabled is not available on pre-T devices");
+
+        try {
+            return dataSaverEnabledMap.getValue(DATA_SAVER_ENABLED_KEY).val == DATA_SAVER_ENABLED;
+        } catch (ErrnoException e) {
+            throw new ServiceSpecificException(e.errno, "Unable to get data saver: "
+                    + Os.strerror(e.errno));
+        }
+    }
 }
diff --git a/framework/src/android/net/ConnectivityManager.java b/framework/src/android/net/ConnectivityManager.java
index 915ec52..b1e636d 100644
--- a/framework/src/android/net/ConnectivityManager.java
+++ b/framework/src/android/net/ConnectivityManager.java
@@ -6284,15 +6284,16 @@
     @RequiresPermission(NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK)
     public boolean isUidNetworkingBlocked(int uid, boolean isNetworkMetered) {
         if (!SdkLevel.isAtLeastU()) {
-            Log.wtf(TAG, "isUidNetworkingBlocked is not supported on pre-U devices");
+            throw new IllegalStateException(
+                    "isUidNetworkingBlocked is not supported on pre-U devices");
         }
-        final BpfNetMapsReader reader = BpfNetMapsReader.getInstance();
+        final NetworkStackBpfNetMaps reader = NetworkStackBpfNetMaps.getInstance();
         // Note that before V, the data saver status in bpf is written by ConnectivityService
         // when receiving {@link #ACTION_RESTRICT_BACKGROUND_CHANGED}. Thus,
         // the status is not synchronized.
         // On V+, the data saver status is set by platform code when enabling/disabling
         // data saver, which is synchronized.
-        return reader.isUidNetworkingBlocked(uid, isNetworkMetered, reader.getDataSaverEnabled());
+        return reader.isUidNetworkingBlocked(uid, isNetworkMetered);
     }
 
     /** @hide */
diff --git a/framework/src/android/net/NetworkStackBpfNetMaps.java b/framework/src/android/net/NetworkStackBpfNetMaps.java
new file mode 100644
index 0000000..b7c4e34
--- /dev/null
+++ b/framework/src/android/net/NetworkStackBpfNetMaps.java
@@ -0,0 +1,187 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net;
+
+import static android.net.BpfNetMapsConstants.CONFIGURATION_MAP_PATH;
+import static android.net.BpfNetMapsConstants.DATA_SAVER_ENABLED_MAP_PATH;
+import static android.net.BpfNetMapsConstants.UID_OWNER_MAP_PATH;
+
+import android.annotation.NonNull;
+import android.annotation.RequiresApi;
+import android.os.Build;
+import android.os.ServiceSpecificException;
+import android.system.ErrnoException;
+
+import com.android.internal.annotations.VisibleForTesting;
+import com.android.modules.utils.build.SdkLevel;
+import com.android.net.module.util.BpfMap;
+import com.android.net.module.util.IBpfMap;
+import com.android.net.module.util.Struct.S32;
+import com.android.net.module.util.Struct.U32;
+import com.android.net.module.util.Struct.U8;
+
+/**
+ * A helper class to *read* java BpfMaps for network stack.
+ * BpfMap operations that are not used from network stack should be in
+ * {@link com.android.server.BpfNetMaps}
+ * @hide
+ */
+// NetworkStack can not use this before U due to b/326143935
+@RequiresApi(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
+public class NetworkStackBpfNetMaps {
+    private static final String TAG = NetworkStackBpfNetMaps.class.getSimpleName();
+
+    // Locally store the handle of bpf maps. The FileDescriptors are statically cached inside the
+    // BpfMap implementation.
+
+    // Bpf map to store various networking configurations, the format of the value is different
+    // for different keys. See BpfNetMapsConstants#*_CONFIGURATION_KEY for keys.
+    private final IBpfMap<S32, U32> mConfigurationMap;
+    // Bpf map to store per uid traffic control configurations.
+    // See {@link UidOwnerValue} for more detail.
+    private final IBpfMap<S32, UidOwnerValue> mUidOwnerMap;
+    private final IBpfMap<S32, U8> mDataSaverEnabledMap;
+    private final Dependencies mDeps;
+
+    private static class SingletonHolder {
+        static final NetworkStackBpfNetMaps sInstance = new NetworkStackBpfNetMaps();
+    }
+
+    @NonNull
+    public static NetworkStackBpfNetMaps getInstance() {
+        return SingletonHolder.sInstance;
+    }
+
+    private NetworkStackBpfNetMaps() {
+        this(new Dependencies());
+    }
+
+    // While the production code uses the singleton to optimize for performance and deal with
+    // concurrent access, the test needs to use a non-static approach for dependency injection and
+    // mocking virtual bpf maps.
+    @VisibleForTesting
+    public NetworkStackBpfNetMaps(@NonNull Dependencies deps) {
+        if (!SdkLevel.isAtLeastT()) {
+            throw new UnsupportedOperationException(
+                    NetworkStackBpfNetMaps.class.getSimpleName()
+                            + " is not supported below Android T");
+        }
+        mDeps = deps;
+        mConfigurationMap = mDeps.getConfigurationMap();
+        mUidOwnerMap = mDeps.getUidOwnerMap();
+        mDataSaverEnabledMap = mDeps.getDataSaverEnabledMap();
+    }
+
+    /**
+     * Dependencies of NetworkStackBpfNetMaps, for injection in tests.
+     */
+    @VisibleForTesting
+    public static class Dependencies {
+        /** Get the configuration map. */
+        public IBpfMap<S32, U32> getConfigurationMap() {
+            try {
+                return new BpfMap<>(CONFIGURATION_MAP_PATH, BpfMap.BPF_F_RDONLY,
+                        S32.class, U32.class);
+            } catch (ErrnoException e) {
+                throw new IllegalStateException("Cannot open configuration map", e);
+            }
+        }
+
+        /** Get the uid owner map. */
+        public IBpfMap<S32, UidOwnerValue> getUidOwnerMap() {
+            try {
+                return new BpfMap<>(UID_OWNER_MAP_PATH, BpfMap.BPF_F_RDONLY,
+                        S32.class, UidOwnerValue.class);
+            } catch (ErrnoException e) {
+                throw new IllegalStateException("Cannot open uid owner map", e);
+            }
+        }
+
+        /** Get the data saver enabled map. */
+        public  IBpfMap<S32, U8> getDataSaverEnabledMap() {
+            try {
+                return new BpfMap<>(DATA_SAVER_ENABLED_MAP_PATH, BpfMap.BPF_F_RDONLY, S32.class,
+                        U8.class);
+            } catch (ErrnoException e) {
+                throw new IllegalStateException("Cannot open data saver enabled map", e);
+            }
+        }
+    }
+
+    /**
+     * Get the specified firewall chain's status.
+     *
+     * @param chain target chain
+     * @return {@code true} if chain is enabled, {@code false} if chain is not enabled.
+     * @throws UnsupportedOperationException if called on pre-T devices.
+     * @throws ServiceSpecificException in case of failure, with an error code indicating the
+     *                                  cause of the failure.
+     */
+    public boolean isChainEnabled(final int chain) {
+        return BpfNetMapsUtils.isChainEnabled(mConfigurationMap, chain);
+    }
+
+    /**
+     * Get firewall rule of specified firewall chain on specified uid.
+     *
+     * @param chain target chain
+     * @param uid target uid
+     * @return either {@link ConnectivityManager#FIREWALL_RULE_ALLOW} or
+     *         {@link ConnectivityManager#FIREWALL_RULE_DENY}.
+     * @throws UnsupportedOperationException if called on pre-T devices.
+     * @throws ServiceSpecificException in case of failure, with an error code indicating the
+     *                                  cause of the failure.
+     */
+    public int getUidRule(final int chain, final int uid) {
+        return BpfNetMapsUtils.getUidRule(mUidOwnerMap, chain, uid);
+    }
+
+    /**
+     * Return whether the network is blocked by firewall chains for the given uid.
+     *
+     * Before V, {@link #getDataSaverEnabled()} may lag behind the actual setting.
+     *
+     * @param uid The target uid.
+     * @param isNetworkMetered Whether the target network is metered.
+     *
+     * @return True if the network is blocked. Otherwise, false.
+     * @throws ServiceSpecificException if the read fails.
+     *
+     * @hide
+     */
+    public boolean isUidNetworkingBlocked(final int uid, boolean isNetworkMetered) {
+        return BpfNetMapsUtils.isUidNetworkingBlocked(uid, isNetworkMetered,
+                mConfigurationMap, mUidOwnerMap, mDataSaverEnabledMap);
+    }
+
+    /**
+     * Get whether Data Saver is enabled.
+     *
+     * Note that before V, the data saver status in bpf is written by ConnectivityService
+     * when receiving {@link ConnectivityManager#ACTION_RESTRICT_BACKGROUND_CHANGED}. Thus,
+     * the status is not synchronized.
+     * On V+, the data saver status is set by platform code when enabling/disabling
+     * data saver, which is synchronized.
+     *
+     * @return whether Data Saver is enabled or disabled.
+     * @throws ServiceSpecificException in case of failure, with an error code indicating the
+     *                                  cause of the failure.
+     */
+    public boolean getDataSaverEnabled() {
+        return BpfNetMapsUtils.getDataSaverEnabled(mDataSaverEnabledMap);
+    }
+}
diff --git a/netbpfload/NetBpfLoad.cpp b/netbpfload/NetBpfLoad.cpp
index 0688e88..83bb98c 100644
--- a/netbpfload/NetBpfLoad.cpp
+++ b/netbpfload/NetBpfLoad.cpp
@@ -97,7 +97,7 @@
         },
 };
 
-int loadAllElfObjects(const android::bpf::Location& location) {
+int loadAllElfObjects(const unsigned int bpfloader_ver, const android::bpf::Location& location) {
     int retVal = 0;
     DIR* dir;
     struct dirent* ent;
@@ -111,7 +111,7 @@
             progPath += s;
 
             bool critical;
-            int ret = android::bpf::loadProg(progPath.c_str(), &critical, location);
+            int ret = android::bpf::loadProg(progPath.c_str(), &critical, bpfloader_ver, location);
             if (ret) {
                 if (critical) retVal = ret;
                 ALOGE("Failed to load object: %s, ret: %s", progPath.c_str(), std::strerror(-ret));
@@ -257,13 +257,8 @@
 
     logTetheringApexVersion();
 
-    if (has_platform_bpfloader_rc && !has_platform_netbpfload_rc) {
-        // Tethering apex shipped initrc file causes us to reach here
-        // but we're not ready to correctly handle anything before U QPR2
-        // in which the 'bpfloader' vs 'netbpfload' split happened
-        const char * args[] = { platformBpfLoader, NULL, };
-        execve(args[0], (char**)args, envp);
-        ALOGE("exec '%s' fail: %d[%s]", platformBpfLoader, errno, strerror(errno));
+    if (!isAtLeastT) {
+        ALOGE("Impossible - not reachable on Android <T.");
         return 1;
     }
 
@@ -318,14 +313,16 @@
         return 1;
     }
 
-    if (isAtLeastU) {
+    if (false && isAtLeastV) {
         // Linux 5.16-rc1 changed the default to 2 (disabled but changeable),
         // but we need 0 (enabled)
         // (this writeFile is known to fail on at least 4.19, but always defaults to 0 on
         // pre-5.13, on 5.13+ it depends on CONFIG_BPF_UNPRIV_DEFAULT_OFF)
         if (writeProcSysFile("/proc/sys/kernel/unprivileged_bpf_disabled", "0\n") &&
             android::bpf::isAtLeastKernelVersion(5, 13, 0)) return 1;
+    }
 
+    if (isAtLeastU) {
         // Enable the eBPF JIT -- but do note that on 64-bit kernels it is likely
         // already force enabled by the kernel config option BPF_JIT_ALWAYS_ON.
         // (Note: this (open) will fail with ENOENT 'No such file or directory' if
@@ -355,9 +352,15 @@
     // Thus we need to manually create the /sys/fs/bpf/loader subdirectory.
     if (createSysFsBpfSubDir("loader")) return 1;
 
+    // Version of Network BpfLoader depends on the Android OS version
+    unsigned int bpfloader_ver = 42u;  // [42] BPFLOADER_MAINLINE_VERSION
+    if (isAtLeastT) ++bpfloader_ver;   // [43] BPFLOADER_MAINLINE_T_VERSION
+    if (isAtLeastU) ++bpfloader_ver;   // [44] BPFLOADER_MAINLINE_U_VERSION
+    if (isAtLeastV) ++bpfloader_ver;   // [45] BPFLOADER_MAINLINE_V_VERSION
+
     // Load all ELF objects, create programs and maps, and pin them
     for (const auto& location : locations) {
-        if (loadAllElfObjects(location) != 0) {
+        if (loadAllElfObjects(bpfloader_ver, location) != 0) {
             ALOGE("=== CRITICAL FAILURE LOADING BPF PROGRAMS FROM %s ===", location.dir);
             ALOGE("If this triggers reliably, you're probably missing kernel options or patches.");
             ALOGE("If this triggers randomly, you might be hitting some memory allocation "
@@ -377,10 +380,15 @@
         return 1;
     }
 
-    ALOGI("done, transferring control to platform bpfloader.");
+    if (false && isAtLeastV) {
+        ALOGI("done, transferring control to platform bpfloader.");
 
-    const char * args[] = { platformBpfLoader, NULL, };
-    execve(args[0], (char**)args, envp);
-    ALOGE("FATAL: execve('%s'): %d[%s]", platformBpfLoader, errno, strerror(errno));
-    return 1;
+        const char * args[] = { platformBpfLoader, NULL, };
+        execve(args[0], (char**)args, envp);
+        ALOGE("FATAL: execve('%s'): %d[%s]", platformBpfLoader, errno, strerror(errno));
+        return 1;
+    }
+
+    ALOGI("mainline done!");
+    return 0;
 }
diff --git a/netbpfload/loader.cpp b/netbpfload/loader.cpp
index 013309e..9dd0d2a 100644
--- a/netbpfload/loader.cpp
+++ b/netbpfload/loader.cpp
@@ -31,15 +31,6 @@
 #include <sys/wait.h>
 #include <unistd.h>
 
-// This is Network BpfLoader v0.42
-// WARNING: If you ever hit cherrypick conflicts here you're doing it wrong:
-// You are NOT allowed to cherrypick bpfloader related patches out of order.
-// (indeed: cherrypicking is probably a bad idea and you should merge instead)
-// Mainline supports ONLY the published versions of the bpfloader for each Android release.
-#define BPFLOADER_VERSION_MAJOR 0u
-#define BPFLOADER_VERSION_MINOR 42u
-#define BPFLOADER_VERSION ((BPFLOADER_VERSION_MAJOR << 16) | BPFLOADER_VERSION_MINOR)
-
 #include "BpfSyscallWrappers.h"
 #include "bpf/BpfUtils.h"
 #include "bpf/bpf_map_def.h"
@@ -622,7 +613,8 @@
 }
 
 static int createMaps(const char* elfPath, ifstream& elfFile, vector<unique_fd>& mapFds,
-                      const char* prefix, const size_t sizeOfBpfMapDef) {
+                      const char* prefix, const size_t sizeOfBpfMapDef,
+                      const unsigned int bpfloader_ver) {
     int ret;
     vector<char> mdData;
     vector<struct bpf_map_def> md;
@@ -664,14 +656,14 @@
     for (int i = 0; i < (int)mapNames.size(); i++) {
         if (md[i].zero != 0) abort();
 
-        if (BPFLOADER_VERSION < md[i].bpfloader_min_ver) {
+        if (bpfloader_ver < md[i].bpfloader_min_ver) {
             ALOGI("skipping map %s which requires bpfloader min ver 0x%05x", mapNames[i].c_str(),
                   md[i].bpfloader_min_ver);
             mapFds.push_back(unique_fd());
             continue;
         }
 
-        if (BPFLOADER_VERSION >= md[i].bpfloader_max_ver) {
+        if (bpfloader_ver >= md[i].bpfloader_max_ver) {
             ALOGI("skipping map %s which requires bpfloader max ver 0x%05x", mapNames[i].c_str(),
                   md[i].bpfloader_max_ver);
             mapFds.push_back(unique_fd());
@@ -922,7 +914,7 @@
 }
 
 static int loadCodeSections(const char* elfPath, vector<codeSection>& cs, const string& license,
-                            const char* prefix) {
+                            const char* prefix, const unsigned int bpfloader_ver) {
     unsigned kvers = kernelVersion();
 
     if (!kvers) {
@@ -958,8 +950,8 @@
 
         ALOGD("cs[%d].name:%s requires bpfloader version [0x%05x,0x%05x)", i, name.c_str(),
               bpfMinVer, bpfMaxVer);
-        if (BPFLOADER_VERSION < bpfMinVer) continue;
-        if (BPFLOADER_VERSION >= bpfMaxVer) continue;
+        if (bpfloader_ver < bpfMinVer) continue;
+        if (bpfloader_ver >= bpfMaxVer) continue;
 
         if ((cs[i].prog_def->ignore_on_eng && isEng()) ||
             (cs[i].prog_def->ignore_on_user && isUser()) ||
@@ -1095,7 +1087,8 @@
     return 0;
 }
 
-int loadProg(const char* elfPath, bool* isCritical, const Location& location) {
+int loadProg(const char* const elfPath, bool* const isCritical, const unsigned int bpfloader_ver,
+             const Location& location) {
     vector<char> license;
     vector<char> critical;
     vector<codeSection> cs;
@@ -1134,27 +1127,27 @@
             readSectionUint("size_of_bpf_prog_def", elfFile, DEFAULT_SIZEOF_BPF_PROG_DEF);
 
     // inclusive lower bound check
-    if (BPFLOADER_VERSION < bpfLoaderMinVer) {
+    if (bpfloader_ver < bpfLoaderMinVer) {
         ALOGI("BpfLoader version 0x%05x ignoring ELF object %s with min ver 0x%05x",
-              BPFLOADER_VERSION, elfPath, bpfLoaderMinVer);
+              bpfloader_ver, elfPath, bpfLoaderMinVer);
         return 0;
     }
 
     // exclusive upper bound check
-    if (BPFLOADER_VERSION >= bpfLoaderMaxVer) {
+    if (bpfloader_ver >= bpfLoaderMaxVer) {
         ALOGI("BpfLoader version 0x%05x ignoring ELF object %s with max ver 0x%05x",
-              BPFLOADER_VERSION, elfPath, bpfLoaderMaxVer);
+              bpfloader_ver, elfPath, bpfLoaderMaxVer);
         return 0;
     }
 
-    if (BPFLOADER_VERSION < bpfLoaderMinRequiredVer) {
+    if (bpfloader_ver < bpfLoaderMinRequiredVer) {
         ALOGI("BpfLoader version 0x%05x failing due to ELF object %s with required min ver 0x%05x",
-              BPFLOADER_VERSION, elfPath, bpfLoaderMinRequiredVer);
+              bpfloader_ver, elfPath, bpfLoaderMinRequiredVer);
         return -1;
     }
 
     ALOGI("BpfLoader version 0x%05x processing ELF object %s with ver [0x%05x,0x%05x)",
-          BPFLOADER_VERSION, elfPath, bpfLoaderMinVer, bpfLoaderMaxVer);
+          bpfloader_ver, elfPath, bpfLoaderMinVer, bpfLoaderMaxVer);
 
     if (sizeOfBpfMapDef < DEFAULT_SIZEOF_BPF_MAP_DEF) {
         ALOGE("sizeof(bpf_map_def) of %zu is too small (< %d)", sizeOfBpfMapDef,
@@ -1177,7 +1170,7 @@
     /* Just for future debugging */
     if (0) dumpAllCs(cs);
 
-    ret = createMaps(elfPath, elfFile, mapFds, location.prefix, sizeOfBpfMapDef);
+    ret = createMaps(elfPath, elfFile, mapFds, location.prefix, sizeOfBpfMapDef, bpfloader_ver);
     if (ret) {
         ALOGE("Failed to create maps: (ret=%d) in %s", ret, elfPath);
         return ret;
@@ -1188,7 +1181,7 @@
 
     applyMapRelo(elfFile, mapFds, cs);
 
-    ret = loadCodeSections(elfPath, cs, string(license.data()), location.prefix);
+    ret = loadCodeSections(elfPath, cs, string(license.data()), location.prefix, bpfloader_ver);
     if (ret) ALOGE("Failed to load programs, loadCodeSections ret=%d", ret);
 
     return ret;
diff --git a/netbpfload/loader.h b/netbpfload/loader.h
index b884637..4da6830 100644
--- a/netbpfload/loader.h
+++ b/netbpfload/loader.h
@@ -70,7 +70,8 @@
 };
 
 // BPF loader implementation. Loads an eBPF ELF object
-int loadProg(const char* elfPath, bool* isCritical, const Location &location = {});
+int loadProg(const char* elfPath, bool* isCritical, const unsigned int bpfloader_ver,
+             const Location &location = {});
 
 // Exposed for testing
 unsigned int readSectionUint(const char* name, std::ifstream& elfFile, unsigned int defVal);
diff --git a/netbpfload/netbpfload.mainline.rc b/netbpfload/netbpfload.mainline.rc
index 0ac5de8..d38a503 100644
--- a/netbpfload/netbpfload.mainline.rc
+++ b/netbpfload/netbpfload.mainline.rc
@@ -1,8 +1,17 @@
-service bpfloader /apex/com.android.tethering/bin/netbpfload
+service mdnsd_loadbpf /system/bin/bpfloader
     capabilities CHOWN SYS_ADMIN NET_ADMIN
     group root graphics network_stack net_admin net_bw_acct net_bw_stats net_raw system
     user root
     rlimit memlock 1073741824 1073741824
     oneshot
     reboot_on_failure reboot,bpfloader-failed
+
+service bpfloader /apex/com.android.tethering/bin/netbpfload
+    capabilities CHOWN SYS_ADMIN NET_ADMIN
+    group system root graphics network_stack net_admin net_bw_acct net_bw_stats net_raw
+    user system
+    file /dev/kmsg w
+    rlimit memlock 1073741824 1073741824
+    oneshot
+    reboot_on_failure reboot,bpfloader-failed
     override
diff --git a/netd/BpfHandler.cpp b/netd/BpfHandler.cpp
index a00c363..0d75c05 100644
--- a/netd/BpfHandler.cpp
+++ b/netd/BpfHandler.cpp
@@ -165,9 +165,38 @@
 BpfHandler::BpfHandler(uint32_t perUidLimit, uint32_t totalLimit)
     : mPerUidStatsEntriesLimit(perUidLimit), mTotalUidStatsEntriesLimit(totalLimit) {}
 
+// copied with minor changes from waitForProgsLoaded()
+// p/m/C's staticlibs/native/bpf_headers/include/bpf/WaitForProgsLoaded.h
+static inline void waitForNetProgsLoaded() {
+    // infinite loop until success with 5/10/20/40/60/60/60... delay
+    for (int delay = 5;; delay *= 2) {
+        if (delay > 60) delay = 60;
+        if (base::WaitForProperty("init.svc.bpfloader", "stopped", std::chrono::seconds(delay))
+            && !access("/sys/fs/bpf/netd_shared", F_OK))
+            return;
+        ALOGW("Waited %ds for init.svc.bpfloader=stopped, still waiting...", delay);
+    }
+}
+
 Status BpfHandler::init(const char* cg2_path) {
-    // Make sure BPF programs are loaded before doing anything
-    android::bpf::waitForProgsLoaded();
+    if (base::GetProperty("bpf.progs_loaded", "") != "1") {
+        // Make sure BPF programs are loaded before doing anything
+        ALOGI("Waiting for BPF programs");
+
+        if (true || !modules::sdklevel::IsAtLeastV()) {
+            waitForNetProgsLoaded();
+            ALOGI("Networking BPF programs are loaded");
+
+            if (!base::SetProperty("ctl.start", "mdnsd_loadbpf")) {
+                ALOGE("Failed to set property ctl.start=mdnsd_loadbpf, see dmesg for reason.");
+                abort();
+            }
+
+            ALOGI("Waiting for remaining BPF programs");
+        }
+
+        android::bpf::waitForProgsLoaded();
+    }
     ALOGI("BPF programs are loaded");
 
     RETURN_IF_NOT_OK(initPrograms(cg2_path));
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index cfb1a33..aca386f 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -1851,6 +1851,8 @@
                         mContext, MdnsFeatureFlags.NSD_UNICAST_REPLY_ENABLED))
                 .setIsAggressiveQueryModeEnabled(mDeps.isFeatureEnabled(
                         mContext, MdnsFeatureFlags.NSD_AGGRESSIVE_QUERY_MODE))
+                .setIsQueryWithKnownAnswerEnabled(mDeps.isFeatureEnabled(
+                        mContext, MdnsFeatureFlags.NSD_QUERY_WITH_KNOWN_ANSWER))
                 .setOverrideProvider(flag -> mDeps.isFeatureEnabled(
                         mContext, FORCE_ENABLE_FLAG_FOR_TEST_PREFIX + flag))
                 .build();
diff --git a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
index c4d3338..54943c7 100644
--- a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
+++ b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
@@ -23,6 +23,7 @@
 import android.text.TextUtils;
 import android.util.Pair;
 
+import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.SharedLog;
 import com.android.server.connectivity.mdns.util.MdnsUtils;
 
@@ -63,8 +64,6 @@
     @NonNull
     private final WeakReference<MdnsSocketClientBase> weakRequestSender;
     @NonNull
-    private final MdnsPacketWriter packetWriter;
-    @NonNull
     private final String[] serviceTypeLabels;
     @NonNull
     private final List<String> subtypes;
@@ -79,11 +78,16 @@
     private final MdnsUtils.Clock clock;
     @NonNull
     private final SharedLog sharedLog;
+    @NonNull
+    private final MdnsServiceTypeClient.Dependencies dependencies;
     private final boolean onlyUseIpv6OnIpv6OnlyNetworks;
+    private final byte[] packetCreationBuffer = new byte[1500]; // TODO: use interface MTU
+    @NonNull
+    private final List<MdnsResponse> existingServices;
+    private final boolean isQueryWithKnownAnswer;
 
     EnqueueMdnsQueryCallable(
             @NonNull MdnsSocketClientBase requestSender,
-            @NonNull MdnsPacketWriter packetWriter,
             @NonNull String serviceType,
             @NonNull Collection<String> subtypes,
             boolean expectUnicastResponse,
@@ -93,9 +97,11 @@
             boolean sendDiscoveryQueries,
             @NonNull Collection<MdnsResponse> servicesToResolve,
             @NonNull MdnsUtils.Clock clock,
-            @NonNull SharedLog sharedLog) {
+            @NonNull SharedLog sharedLog,
+            @NonNull MdnsServiceTypeClient.Dependencies dependencies,
+            @NonNull Collection<MdnsResponse> existingServices,
+            boolean isQueryWithKnownAnswer) {
         weakRequestSender = new WeakReference<>(requestSender);
-        this.packetWriter = packetWriter;
         serviceTypeLabels = TextUtils.split(serviceType, "\\.");
         this.subtypes = new ArrayList<>(subtypes);
         this.expectUnicastResponse = expectUnicastResponse;
@@ -106,6 +112,9 @@
         this.servicesToResolve = new ArrayList<>(servicesToResolve);
         this.clock = clock;
         this.sharedLog = sharedLog;
+        this.dependencies = dependencies;
+        this.existingServices = new ArrayList<>(existingServices);
+        this.isQueryWithKnownAnswer = isQueryWithKnownAnswer;
     }
 
     /**
@@ -176,62 +185,86 @@
                 return Pair.create(INVALID_TRANSACTION_ID, new ArrayList<>());
             }
 
+            // Put the existing PTR records into the known-answer section.
+            final List<MdnsRecord> knownAnswers = new ArrayList<>();
+            if (sendDiscoveryQueries) {
+                for (MdnsResponse existingService : existingServices) {
+                    for (MdnsPointerRecord ptrRecord : existingService.getPointerRecords()) {
+                        // Ignore any PTR records that don't match the current query.
+                        if (!CollectionUtils.any(questions,
+                                q -> q instanceof MdnsPointerRecord
+                                        && MdnsUtils.equalsDnsLabelIgnoreDnsCase(
+                                                q.getName(), ptrRecord.getName()))) {
+                            continue;
+                        }
+
+                        knownAnswers.add(new MdnsPointerRecord(
+                                ptrRecord.getName(),
+                                ptrRecord.getReceiptTime(),
+                                ptrRecord.getCacheFlush(),
+                                ptrRecord.getRemainingTTL(now), // Put the remaining ttl.
+                                ptrRecord.getPointer()));
+                    }
+                }
+            }
+
             final MdnsPacket queryPacket = new MdnsPacket(
                     transactionId,
                     MdnsConstants.FLAGS_QUERY,
                     questions,
-                    Collections.emptyList(), /* answers */
+                    knownAnswers,
                     Collections.emptyList(), /* authorityRecords */
                     Collections.emptyList() /* additionalRecords */);
-            MdnsUtils.writeMdnsPacket(packetWriter, queryPacket);
-            sendPacketToIpv4AndIpv6(requestSender, MdnsConstants.MDNS_PORT);
+            sendPacketToIpv4AndIpv6(requestSender, MdnsConstants.MDNS_PORT, queryPacket);
             for (Integer emulatorPort : castShellEmulatorMdnsPorts) {
-                sendPacketToIpv4AndIpv6(requestSender, emulatorPort);
+                sendPacketToIpv4AndIpv6(requestSender, emulatorPort, queryPacket);
             }
             return Pair.create(transactionId, subtypes);
-        } catch (IOException e) {
+        } catch (Exception e) {
             sharedLog.e(String.format("Failed to create mDNS packet for subtype: %s.",
                     TextUtils.join(",", subtypes)), e);
             return Pair.create(INVALID_TRANSACTION_ID, new ArrayList<>());
         }
     }
 
-    private void sendPacket(MdnsSocketClientBase requestSender, InetSocketAddress address)
-            throws IOException {
-        DatagramPacket packet = packetWriter.getPacket(address);
+    private void sendPacket(MdnsSocketClientBase requestSender, InetSocketAddress address,
+            MdnsPacket mdnsPacket) throws IOException {
+        final List<DatagramPacket> packets = dependencies.getDatagramPacketsFromMdnsPacket(
+                packetCreationBuffer, mdnsPacket, address, isQueryWithKnownAnswer);
         if (expectUnicastResponse) {
             // MdnsMultinetworkSocketClient is only available on T+
             if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU
                     && requestSender instanceof MdnsMultinetworkSocketClient) {
                 ((MdnsMultinetworkSocketClient) requestSender).sendPacketRequestingUnicastResponse(
-                        packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
+                        packets, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
             } else {
                 requestSender.sendPacketRequestingUnicastResponse(
-                        packet, onlyUseIpv6OnIpv6OnlyNetworks);
+                        packets, onlyUseIpv6OnIpv6OnlyNetworks);
             }
         } else {
             if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU
                     && requestSender instanceof MdnsMultinetworkSocketClient) {
                 ((MdnsMultinetworkSocketClient) requestSender)
                         .sendPacketRequestingMulticastResponse(
-                                packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
+                                packets, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
             } else {
                 requestSender.sendPacketRequestingMulticastResponse(
-                        packet, onlyUseIpv6OnIpv6OnlyNetworks);
+                        packets, onlyUseIpv6OnIpv6OnlyNetworks);
             }
         }
     }
 
-    private void sendPacketToIpv4AndIpv6(MdnsSocketClientBase requestSender, int port) {
+    private void sendPacketToIpv4AndIpv6(MdnsSocketClientBase requestSender, int port,
+            MdnsPacket mdnsPacket) {
         try {
             sendPacket(requestSender,
-                    new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), port));
+                    new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), port), mdnsPacket);
         } catch (IOException e) {
             sharedLog.e("Can't send packet to IPv4", e);
         }
         try {
             sendPacket(requestSender,
-                    new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), port));
+                    new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), port), mdnsPacket);
         } catch (IOException e) {
             sharedLog.e("Can't send packet to IPv6", e);
         }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index 21b7069..7b0c738 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -362,7 +362,7 @@
         return new MdnsServiceTypeClient(
                 serviceType, socketClient,
                 executorProvider.newServiceTypeClientSchedulerExecutor(), socketKey,
-                sharedLog.forSubComponent(tag), looper, serviceCache);
+                sharedLog.forSubComponent(tag), looper, serviceCache, mdnsFeatureFlags);
     }
 
     /**
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
index 56202fd..f4a08ba 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
@@ -62,6 +62,11 @@
      */
     public static final String NSD_AGGRESSIVE_QUERY_MODE = "nsd_aggressive_query_mode";
 
+    /**
+     * A feature flag to control whether the query with known-answer should be enabled.
+     */
+    public static final String NSD_QUERY_WITH_KNOWN_ANSWER = "nsd_query_with_known_answer";
+
     // Flag for offload feature
     public final boolean mIsMdnsOffloadFeatureEnabled;
 
@@ -83,6 +88,9 @@
     // Flag for aggressive query mode
     public final boolean mIsAggressiveQueryModeEnabled;
 
+    // Flag for query with known-answer
+    public final boolean mIsQueryWithKnownAnswerEnabled;
+
     @Nullable
     private final FlagOverrideProvider mOverrideProvider;
 
@@ -126,6 +134,14 @@
     }
 
     /**
+     * Indicates whether {@link #NSD_QUERY_WITH_KNOWN_ANSWER} is enabled, including for testing.
+     */
+    public boolean isQueryWithKnownAnswerEnabled() {
+        return mIsQueryWithKnownAnswerEnabled
+                || isForceEnabledForTest(NSD_QUERY_WITH_KNOWN_ANSWER);
+    }
+
+    /**
      * The constructor for {@link MdnsFeatureFlags}.
      */
     public MdnsFeatureFlags(boolean isOffloadFeatureEnabled,
@@ -135,6 +151,7 @@
             boolean isKnownAnswerSuppressionEnabled,
             boolean isUnicastReplyEnabled,
             boolean isAggressiveQueryModeEnabled,
+            boolean isQueryWithKnownAnswerEnabled,
             @Nullable FlagOverrideProvider overrideProvider) {
         mIsMdnsOffloadFeatureEnabled = isOffloadFeatureEnabled;
         mIncludeInetAddressRecordsInProbing = includeInetAddressRecordsInProbing;
@@ -143,6 +160,7 @@
         mIsKnownAnswerSuppressionEnabled = isKnownAnswerSuppressionEnabled;
         mIsUnicastReplyEnabled = isUnicastReplyEnabled;
         mIsAggressiveQueryModeEnabled = isAggressiveQueryModeEnabled;
+        mIsQueryWithKnownAnswerEnabled = isQueryWithKnownAnswerEnabled;
         mOverrideProvider = overrideProvider;
     }
 
@@ -162,6 +180,7 @@
         private boolean mIsKnownAnswerSuppressionEnabled;
         private boolean mIsUnicastReplyEnabled;
         private boolean mIsAggressiveQueryModeEnabled;
+        private boolean mIsQueryWithKnownAnswerEnabled;
         private FlagOverrideProvider mOverrideProvider;
 
         /**
@@ -175,6 +194,7 @@
             mIsKnownAnswerSuppressionEnabled = false;
             mIsUnicastReplyEnabled = true;
             mIsAggressiveQueryModeEnabled = false;
+            mIsQueryWithKnownAnswerEnabled = false;
             mOverrideProvider = null;
         }
 
@@ -261,6 +281,16 @@
         }
 
         /**
+         * Set whether the query with known-answer is enabled.
+         *
+         * @see #NSD_QUERY_WITH_KNOWN_ANSWER
+         */
+        public Builder setIsQueryWithKnownAnswerEnabled(boolean isQueryWithKnownAnswerEnabled) {
+            mIsQueryWithKnownAnswerEnabled = isQueryWithKnownAnswerEnabled;
+            return this;
+        }
+
+        /**
          * Builds a {@link MdnsFeatureFlags} with the arguments supplied to this builder.
          */
         public MdnsFeatureFlags build() {
@@ -271,6 +301,7 @@
                     mIsKnownAnswerSuppressionEnabled,
                     mIsUnicastReplyEnabled,
                     mIsAggressiveQueryModeEnabled,
+                    mIsQueryWithKnownAnswerEnabled,
                     mOverrideProvider);
         }
     }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
index c1c7d5f..0b2003f 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
@@ -22,12 +22,10 @@
 import android.annotation.Nullable;
 import android.annotation.RequiresApi;
 import android.net.LinkAddress;
-import android.net.nsd.NsdManager;
 import android.net.nsd.NsdServiceInfo;
 import android.os.Build;
 import android.os.Handler;
 import android.os.Looper;
-import android.util.ArraySet;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.net.module.util.HexDump;
@@ -38,6 +36,7 @@
 
 import java.io.IOException;
 import java.net.InetSocketAddress;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -136,6 +135,15 @@
 
             mAnnouncer.startSending(info.getServiceId(), announcementInfo,
                     0L /* initialDelayMs */);
+
+            // Re-announce the services which have the same custom hostname.
+            final String hostname = mRecordRepository.getHostnameForServiceId(info.getServiceId());
+            if (hostname != null) {
+                final List<MdnsAnnouncer.AnnouncementInfo> announcementInfos =
+                        new ArrayList<>(mRecordRepository.restartAnnouncingForHostname(hostname));
+                announcementInfos.removeIf((i) -> i.getServiceId() == info.getServiceId());
+                reannounceServices(announcementInfos);
+            }
         }
     }
 
@@ -284,6 +292,7 @@
         if (!mRecordRepository.hasActiveService(id)) return;
         mProber.stop(id);
         mAnnouncer.stop(id);
+        final String hostname = mRecordRepository.getHostnameForServiceId(id);
         final MdnsAnnouncer.ExitAnnouncementInfo exitInfo = mRecordRepository.exitService(id);
         if (exitInfo != null) {
             // This effectively schedules onAllServicesRemoved(), as it is to be called when the
@@ -303,6 +312,17 @@
                 }
             });
         }
+        // Re-probe/re-announce the services which have the same custom hostname. These services
+        // were probed/announced using host addresses which were just removed so they should be
+        // re-probed/re-announced without those addresses.
+        if (hostname != null) {
+            final List<MdnsProber.ProbingInfo> probingInfos =
+                    mRecordRepository.restartProbingForHostname(hostname);
+            reprobeServices(probingInfos);
+            final List<MdnsAnnouncer.AnnouncementInfo> announcementInfos =
+                    mRecordRepository.restartAnnouncingForHostname(hostname);
+            reannounceServices(announcementInfos);
+        }
     }
 
     /**
@@ -447,4 +467,19 @@
             return new byte[0];
         }
     }
+
+    private void reprobeServices(List<MdnsProber.ProbingInfo> probingInfos) {
+        for (MdnsProber.ProbingInfo probingInfo : probingInfos) {
+            mProber.stop(probingInfo.getServiceId());
+            mProber.startProbing(probingInfo);
+        }
+    }
+
+    private void reannounceServices(List<MdnsAnnouncer.AnnouncementInfo> announcementInfos) {
+        for (MdnsAnnouncer.AnnouncementInfo announcementInfo : announcementInfos) {
+            mAnnouncer.stop(announcementInfo.getServiceId());
+            mAnnouncer.startSending(
+                    announcementInfo.getServiceId(), announcementInfo, 0 /* initialDelayMs */);
+        }
+    }
 }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
index 869ac9b..fcfb15f 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -27,6 +27,7 @@
 import android.os.Handler;
 import android.os.Looper;
 import android.util.ArrayMap;
+import android.util.Log;
 
 import com.android.net.module.util.SharedLog;
 
@@ -213,24 +214,30 @@
         return true;
     }
 
-    private void sendMdnsPacket(@NonNull DatagramPacket packet, @NonNull SocketKey targetSocketKey,
-            boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+    private void sendMdnsPackets(@NonNull List<DatagramPacket> packets,
+            @NonNull SocketKey targetSocketKey, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
         final MdnsInterfaceSocket socket = getTargetSocket(targetSocketKey);
         if (socket == null) {
             mSharedLog.e("No socket matches targetSocketKey=" + targetSocketKey);
             return;
         }
+        if (packets.isEmpty()) {
+            Log.wtf(TAG, "No mDns packets to send");
+            return;
+        }
 
-        final boolean isIpv6 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
-                instanceof Inet6Address;
-        final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
-                instanceof Inet4Address;
+        final boolean isIpv6 = ((InetSocketAddress) packets.get(0).getSocketAddress())
+                .getAddress() instanceof Inet6Address;
+        final boolean isIpv4 = ((InetSocketAddress) packets.get(0).getSocketAddress())
+                .getAddress() instanceof Inet4Address;
         final boolean shouldQueryIpv6 = !onlyUseIpv6OnIpv6OnlyNetworks || !socket.hasJoinedIpv4();
         // Check ip capability and network before sending packet
         if ((isIpv6 && socket.hasJoinedIpv6() && shouldQueryIpv6)
                 || (isIpv4 && socket.hasJoinedIpv4())) {
             try {
-                socket.send(packet);
+                for (DatagramPacket packet : packets) {
+                    socket.send(packet);
+                }
             } catch (IOException e) {
                 mSharedLog.e("Failed to send a mDNS packet.", e);
             }
@@ -259,34 +266,34 @@
     }
 
     /**
-     * Send a mDNS request packet via given socket key that asks for multicast response.
+     * Send mDNS request packets via given socket key that ask for multicast response.
      */
-    public void sendPacketRequestingMulticastResponse(@NonNull DatagramPacket packet,
+    public void sendPacketRequestingMulticastResponse(@NonNull List<DatagramPacket> packets,
             @NonNull SocketKey socketKey, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
-        mHandler.post(() -> sendMdnsPacket(packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks));
+        mHandler.post(() -> sendMdnsPackets(packets, socketKey, onlyUseIpv6OnIpv6OnlyNetworks));
     }
 
     @Override
     public void sendPacketRequestingMulticastResponse(
-            @NonNull DatagramPacket packet, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+            @NonNull List<DatagramPacket> packets, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
         throw new UnsupportedOperationException("This socket client need to specify the socket to"
                 + "send packet");
     }
 
     /**
-     * Send a mDNS request packet via given socket key that asks for unicast response.
+     * Send mDNS request packets via given socket key that ask for unicast response.
      *
      * <p>The socket client may use a null network to identify some or all interfaces, in which case
      * passing null sends the packet to these.
      */
-    public void sendPacketRequestingUnicastResponse(@NonNull DatagramPacket packet,
+    public void sendPacketRequestingUnicastResponse(@NonNull List<DatagramPacket> packets,
             @NonNull SocketKey socketKey, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
-        mHandler.post(() -> sendMdnsPacket(packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks));
+        mHandler.post(() -> sendMdnsPackets(packets, socketKey, onlyUseIpv6OnIpv6OnlyNetworks));
     }
 
     @Override
     public void sendPacketRequestingUnicastResponse(
-            @NonNull DatagramPacket packet, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+            @NonNull List<DatagramPacket> packets, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
         throw new UnsupportedOperationException("This socket client need to specify the socket to"
                 + "send packet");
     }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsPacket.java b/service-t/src/com/android/server/connectivity/mdns/MdnsPacket.java
index 1fabd49..83ecabc 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsPacket.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsPacket.java
@@ -42,7 +42,7 @@
     @NonNull
     public final List<MdnsRecord> additionalRecords;
 
-    MdnsPacket(int flags,
+    public MdnsPacket(int flags,
             @NonNull List<MdnsRecord> questions,
             @NonNull List<MdnsRecord> answers,
             @NonNull List<MdnsRecord> authorityRecords,
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
index 4b43989..1f9f42b 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
@@ -23,7 +23,8 @@
 import android.os.SystemClock;
 import android.text.TextUtils;
 
-import com.android.internal.annotations.VisibleForTesting;
+import androidx.annotation.VisibleForTesting;
+
 import com.android.server.connectivity.mdns.util.MdnsUtils;
 
 import java.io.IOException;
@@ -231,7 +232,7 @@
      * @param writer The writer to use.
      * @param now    The current system time. This is used when writing the updated TTL.
      */
-    @VisibleForTesting
+    @VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
     public final void write(MdnsPacketWriter writer, long now) throws IOException {
         writeHeaderFields(writer);
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
index ac64c3a..073e465 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
@@ -925,22 +925,79 @@
         }
     }
 
+    @Nullable
+    public String getHostnameForServiceId(int id) {
+        ServiceRegistration registration = mServices.get(id);
+        if (registration == null) {
+            return null;
+        }
+        return registration.serviceInfo.getHostname();
+    }
+
+    /**
+     * Restart probing the services which are being probed and using the given custom hostname.
+     *
+     * @return The list of {@link MdnsProber.ProbingInfo} to be used by advertiser.
+     */
+    public List<MdnsProber.ProbingInfo> restartProbingForHostname(@NonNull String hostname) {
+        final ArrayList<MdnsProber.ProbingInfo> probingInfos = new ArrayList<>();
+        forEachActiveServiceRegistrationWithHostname(
+                hostname,
+                (id, registration) -> {
+                    if (!registration.isProbing) {
+                        return;
+                    }
+                    probingInfos.add(makeProbingInfo(id, registration));
+                });
+        return probingInfos;
+    }
+
+    /**
+     * Restart announcing the services which are using the given custom hostname.
+     *
+     * @return The list of {@link MdnsAnnouncer.AnnouncementInfo} to be used by advertiser.
+     */
+    public List<MdnsAnnouncer.AnnouncementInfo> restartAnnouncingForHostname(
+            @NonNull String hostname) {
+        final ArrayList<MdnsAnnouncer.AnnouncementInfo> announcementInfos = new ArrayList<>();
+        forEachActiveServiceRegistrationWithHostname(
+                hostname,
+                (id, registration) -> {
+                    if (registration.isProbing) {
+                        return;
+                    }
+                    announcementInfos.add(makeAnnouncementInfo(id, registration));
+                });
+        return announcementInfos;
+    }
+
     /**
      * Called to indicate that probing succeeded for a service.
+     *
      * @param probeSuccessInfo The successful probing info.
      * @return The {@link MdnsAnnouncer.AnnouncementInfo} to send, now that probing has succeeded.
      */
     public MdnsAnnouncer.AnnouncementInfo onProbingSucceeded(
-            MdnsProber.ProbingInfo probeSuccessInfo)
-            throws IOException {
-
-        int serviceId = probeSuccessInfo.getServiceId();
+            MdnsProber.ProbingInfo probeSuccessInfo) throws IOException {
+        final int serviceId = probeSuccessInfo.getServiceId();
         final ServiceRegistration registration = mServices.get(serviceId);
         if (registration == null) {
             throw new IOException("Service is not registered: " + serviceId);
         }
         registration.setProbing(false);
 
+        return makeAnnouncementInfo(serviceId, registration);
+    }
+
+    /**
+     * Make the announcement info of the given service ID.
+     *
+     * @param serviceId The service ID.
+     * @param registration The service registration.
+     * @return The {@link MdnsAnnouncer.AnnouncementInfo} of the given service ID.
+     */
+    private MdnsAnnouncer.AnnouncementInfo makeAnnouncementInfo(
+            int serviceId, ServiceRegistration registration) {
         final Set<MdnsRecord> answersSet = new LinkedHashSet<>();
         final ArrayList<MdnsRecord> additionalAnswers = new ArrayList<>();
 
@@ -972,8 +1029,8 @@
         addNsecRecordsForUniqueNames(additionalAnswers,
                 mGeneralRecords.iterator(), registration.allRecords.iterator());
 
-        return new MdnsAnnouncer.AnnouncementInfo(
-                probeSuccessInfo.getServiceId(), new ArrayList<>(answersSet), additionalAnswers);
+        return new MdnsAnnouncer.AnnouncementInfo(serviceId,
+                new ArrayList<>(answersSet), additionalAnswers);
     }
 
     /**
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index 8f41b94..b3bdbe0 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -30,14 +30,18 @@
 import android.util.ArrayMap;
 import android.util.Pair;
 
-import com.android.internal.annotations.VisibleForTesting;
+import androidx.annotation.VisibleForTesting;
+
 import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.SharedLog;
 import com.android.server.connectivity.mdns.util.MdnsUtils;
 
+import java.io.IOException;
 import java.io.PrintWriter;
+import java.net.DatagramPacket;
 import java.net.Inet4Address;
 import java.net.Inet6Address;
+import java.net.InetSocketAddress;
 import java.time.Instant;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -55,7 +59,6 @@
 public class MdnsServiceTypeClient {
 
     private static final String TAG = MdnsServiceTypeClient.class.getSimpleName();
-    private static final int DEFAULT_MTU = 1500;
     @VisibleForTesting
     static final int EVENT_START_QUERYTASK = 1;
     static final int EVENT_QUERY_RESULT = 2;
@@ -84,6 +87,7 @@
                     notifyRemovedServiceToListeners(previousResponse, "Service record expired");
                 }
             };
+    @NonNull private final MdnsFeatureFlags featureFlags;
     private final ArrayMap<MdnsServiceBrowserListener, ListenerInfo> listeners =
             new ArrayMap<>();
     private final boolean removeServiceAfterTtlExpires =
@@ -142,7 +146,8 @@
                     // before sending the query, it needs to be called just before sending it.
                     final List<MdnsResponse> servicesToResolve = makeResponsesForResolve(socketKey);
                     final QueryTask queryTask = new QueryTask(taskArgs, servicesToResolve,
-                            getAllDiscoverySubtypes(), needSendDiscoveryQueries(listeners));
+                            getAllDiscoverySubtypes(), needSendDiscoveryQueries(listeners),
+                            getExistingServices());
                     executor.submit(queryTask);
                     break;
                 }
@@ -191,7 +196,7 @@
     /**
      * Dependencies of MdnsServiceTypeClient, for injection in tests.
      */
-    @VisibleForTesting
+    @VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
     public static class Dependencies {
         /**
          * @see Handler#sendMessageDelayed(Message, long)
@@ -221,6 +226,25 @@
         public void sendMessage(@NonNull Handler handler, @NonNull Message message) {
             handler.sendMessage(message);
         }
+
+        /**
+         * Generate the DatagramPackets from given MdnsPacket and InetSocketAddress.
+         *
+         * <p> If the query with known answer feature is enabled and the MdnsPacket is too large for
+         *     a single DatagramPacket, it will be split into multiple DatagramPackets.
+         */
+        public List<DatagramPacket> getDatagramPacketsFromMdnsPacket(
+                @NonNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet,
+                @NonNull InetSocketAddress address, boolean isQueryWithKnownAnswer)
+                throws IOException {
+            if (isQueryWithKnownAnswer) {
+                return MdnsUtils.createQueryDatagramPackets(packetCreationBuffer, packet, address);
+            } else {
+                final byte[] queryBuffer =
+                        MdnsUtils.createRawDnsPacket(packetCreationBuffer, packet);
+                return List.of(new DatagramPacket(queryBuffer, 0, queryBuffer.length, address));
+            }
+        }
     }
 
     /**
@@ -236,9 +260,10 @@
             @NonNull SocketKey socketKey,
             @NonNull SharedLog sharedLog,
             @NonNull Looper looper,
-            @NonNull MdnsServiceCache serviceCache) {
+            @NonNull MdnsServiceCache serviceCache,
+            @NonNull MdnsFeatureFlags featureFlags) {
         this(serviceType, socketClient, executor, new Clock(), socketKey, sharedLog, looper,
-                new Dependencies(), serviceCache);
+                new Dependencies(), serviceCache, featureFlags);
     }
 
     @VisibleForTesting
@@ -251,7 +276,8 @@
             @NonNull SharedLog sharedLog,
             @NonNull Looper looper,
             @NonNull Dependencies dependencies,
-            @NonNull MdnsServiceCache serviceCache) {
+            @NonNull MdnsServiceCache serviceCache,
+            @NonNull MdnsFeatureFlags featureFlags) {
         this.serviceType = serviceType;
         this.socketClient = socketClient;
         this.executor = executor;
@@ -265,6 +291,7 @@
         this.serviceCache = serviceCache;
         this.mdnsQueryScheduler = new MdnsQueryScheduler();
         this.cacheKey = new MdnsServiceCache.CacheKey(serviceType, socketKey);
+        this.featureFlags = featureFlags;
     }
 
     /**
@@ -327,6 +354,11 @@
                 now.plusMillis(response.getMinRemainingTtl(now.toEpochMilli())));
     }
 
+    private List<MdnsResponse> getExistingServices() {
+        return featureFlags.isQueryWithKnownAnswerEnabled()
+                ? serviceCache.getCachedServices(cacheKey) : Collections.emptyList();
+    }
+
     /**
      * Registers {@code listener} for receiving discovery event of mDNS service instances, and
      * starts
@@ -391,7 +423,8 @@
             final QueryTask queryTask = new QueryTask(
                     mdnsQueryScheduler.scheduleFirstRun(taskConfig, now,
                             minRemainingTtl, currentSessionId), servicesToResolve,
-                    getAllDiscoverySubtypes(), needSendDiscoveryQueries(listeners));
+                    getAllDiscoverySubtypes(), needSendDiscoveryQueries(listeners),
+                    getExistingServices());
             executor.submit(queryTask);
         }
 
@@ -617,11 +650,6 @@
         return searchOptions != null && searchOptions.removeExpiredService();
     }
 
-    @VisibleForTesting
-    MdnsPacketWriter createMdnsPacketWriter() {
-        return new MdnsPacketWriter(DEFAULT_MTU);
-    }
-
     private List<MdnsResponse> makeResponsesForResolve(@NonNull SocketKey socketKey) {
         final List<MdnsResponse> resolveResponses = new ArrayList<>();
         for (int i = 0; i < listeners.size(); i++) {
@@ -694,14 +722,16 @@
         private final List<MdnsResponse> servicesToResolve = new ArrayList<>();
         private final List<String> subtypes = new ArrayList<>();
         private final boolean sendDiscoveryQueries;
+        private final List<MdnsResponse> existingServices = new ArrayList<>();
         QueryTask(@NonNull MdnsQueryScheduler.ScheduledQueryTaskArgs taskArgs,
                 @NonNull Collection<MdnsResponse> servicesToResolve,
-                @NonNull Collection<String> subtypes,
-                boolean sendDiscoveryQueries) {
+                @NonNull Collection<String> subtypes, boolean sendDiscoveryQueries,
+                @NonNull Collection<MdnsResponse> existingServices) {
             this.taskArgs = taskArgs;
             this.servicesToResolve.addAll(servicesToResolve);
             this.subtypes.addAll(subtypes);
             this.sendDiscoveryQueries = sendDiscoveryQueries;
+            this.existingServices.addAll(existingServices);
         }
 
         @Override
@@ -711,7 +741,6 @@
                 result =
                         new EnqueueMdnsQueryCallable(
                                 socketClient,
-                                createMdnsPacketWriter(),
                                 serviceType,
                                 subtypes,
                                 taskArgs.config.expectUnicastResponse,
@@ -721,7 +750,10 @@
                                 sendDiscoveryQueries,
                                 servicesToResolve,
                                 clock,
-                                sharedLog)
+                                sharedLog,
+                                dependencies,
+                                existingServices,
+                                featureFlags.isQueryWithKnownAnswerEnabled())
                                 .call();
             } catch (RuntimeException e) {
                 sharedLog.e(String.format("Failed to run EnqueueMdnsQueryCallable for subtype: %s",
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
index 7b71e43..9cfcba1 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
@@ -25,6 +25,7 @@
 import android.net.wifi.WifiManager.MulticastLock;
 import android.os.SystemClock;
 import android.text.format.DateUtils;
+import android.util.Log;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.net.module.util.SharedLog;
@@ -206,18 +207,18 @@
     }
 
     @Override
-    public void sendPacketRequestingMulticastResponse(@NonNull DatagramPacket packet,
+    public void sendPacketRequestingMulticastResponse(@NonNull List<DatagramPacket> packets,
             boolean onlyUseIpv6OnIpv6OnlyNetworks) {
-        sendMdnsPacket(packet, multicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
+        sendMdnsPackets(packets, multicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
     }
 
     @Override
-    public void sendPacketRequestingUnicastResponse(@NonNull DatagramPacket packet,
+    public void sendPacketRequestingUnicastResponse(@NonNull List<DatagramPacket> packets,
             boolean onlyUseIpv6OnIpv6OnlyNetworks) {
         if (useSeparateSocketForUnicast) {
-            sendMdnsPacket(packet, unicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
+            sendMdnsPackets(packets, unicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
         } else {
-            sendMdnsPacket(packet, multicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
+            sendMdnsPackets(packets, multicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
         }
     }
 
@@ -238,17 +239,21 @@
         return false;
     }
 
-    private void sendMdnsPacket(DatagramPacket packet, Queue<DatagramPacket> packetQueueToUse,
-            boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+    private void sendMdnsPackets(List<DatagramPacket> packets,
+            Queue<DatagramPacket> packetQueueToUse, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
         if (shouldStopSocketLoop && !MdnsConfigs.allowAddMdnsPacketAfterDiscoveryStops()) {
             sharedLog.w("sendMdnsPacket() is called after discovery already stopped");
             return;
         }
+        if (packets.isEmpty()) {
+            Log.wtf(TAG, "No mDns packets to send");
+            return;
+        }
 
-        final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
-                instanceof Inet4Address;
-        final boolean isIpv6 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
-                instanceof Inet6Address;
+        final boolean isIpv4 = ((InetSocketAddress) packets.get(0).getSocketAddress())
+                .getAddress() instanceof Inet4Address;
+        final boolean isIpv6 = ((InetSocketAddress) packets.get(0).getSocketAddress())
+                .getAddress() instanceof Inet6Address;
         final boolean ipv6Only = multicastSocket != null && multicastSocket.isOnIPv6OnlyNetwork();
         if (isIpv4 && ipv6Only) {
             return;
@@ -258,10 +263,13 @@
         }
 
         synchronized (packetQueueToUse) {
-            while (packetQueueToUse.size() >= MdnsConfigs.mdnsPacketQueueMaxSize()) {
+            // Drop the oldest packets to make room. Stop once the queue is empty: remove()
+            // throws on an empty queue if the batch alone exceeds the maximum size.
+            while (!packetQueueToUse.isEmpty() && (packetQueueToUse.size() + packets.size())
+                    > MdnsConfigs.mdnsPacketQueueMaxSize()) {
                 packetQueueToUse.remove();
             }
-            packetQueueToUse.add(packet);
+            packetQueueToUse.addAll(packets);
         }
         triggerSendThread();
     }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
index b6000f0..b1a543a 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
@@ -23,6 +23,7 @@
 
 import java.io.IOException;
 import java.net.DatagramPacket;
+import java.util.List;
 
 /**
  * Base class for multicast socket client.
@@ -40,15 +41,15 @@
     void setCallback(@Nullable Callback callback);
 
     /**
-     * Send a mDNS request packet via given network that asks for multicast response.
+     * Send mDNS request packets via given network that asks for multicast response.
      */
-    void sendPacketRequestingMulticastResponse(@NonNull DatagramPacket packet,
+    void sendPacketRequestingMulticastResponse(@NonNull List<DatagramPacket> packets,
             boolean onlyUseIpv6OnIpv6OnlyNetworks);
 
     /**
-     * Send a mDNS request packet via given network that asks for unicast response.
+     * Send mDNS request packets via given network that asks for unicast response.
      */
-    void sendPacketRequestingUnicastResponse(@NonNull DatagramPacket packet,
+    void sendPacketRequestingUnicastResponse(@NonNull List<DatagramPacket> packets,
             boolean onlyUseIpv6OnIpv6OnlyNetworks);
 
     /*** Notify that the given network is requested for mdns discovery / resolution */
diff --git a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
index d553210..3c11a24 100644
--- a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
+++ b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
@@ -16,6 +16,8 @@
 
 package com.android.server.connectivity.mdns.util;
 
+import static com.android.server.connectivity.mdns.MdnsConstants.FLAG_TRUNCATED;
+
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.net.Network;
@@ -23,6 +25,7 @@
 import android.os.Handler;
 import android.os.SystemClock;
 import android.util.ArraySet;
+import android.util.Pair;
 
 import com.android.server.connectivity.mdns.MdnsConstants;
 import com.android.server.connectivity.mdns.MdnsPacket;
@@ -30,13 +33,18 @@
 import com.android.server.connectivity.mdns.MdnsRecord;
 
 import java.io.IOException;
+import java.net.DatagramPacket;
+import java.net.InetSocketAddress;
 import java.nio.ByteBuffer;
 import java.nio.CharBuffer;
 import java.nio.charset.Charset;
 import java.nio.charset.CharsetEncoder;
 import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Set;
 
 /**
@@ -226,6 +234,100 @@
     }
 
     /**
+     * Writes as much of a query MdnsPacket as fits into the data buffer.
+     *
+     * <p>This method is specifically for query packets: only the header, question and answer
+     *    sections are written. If not all records fit, FLAG_TRUNCATED is set in the header.
+     *
+     * @param packetCreationBuffer The data buffer for the query content.
+     * @param packet The MdnsPacket to be written into the data buffer.
+     * @return A Pair of (1) an MdnsPacket holding the records that did not fit, or null if
+     *         everything was written, and (2) the length of the data written to the buffer.
+     * @throws IOException if the buffer cannot hold even one record.
+     */
+    @NonNull
+    private static Pair<MdnsPacket, Integer> writePossibleMdnsPacket(
+            @NonNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet) throws IOException {
+        MdnsPacket remainingPacket;
+        final MdnsPacketWriter writer = new MdnsPacketWriter(packetCreationBuffer);
+        writer.writeUInt16(packet.transactionId); // Transaction ID
+
+        final int flagsPos = writer.getWritePosition();
+        writer.writeUInt16(0); // Flags, written later
+        writer.writeUInt16(0); // questions count, written later
+        writer.writeUInt16(0); // answers count, written later
+        writer.writeUInt16(0); // authority entries count, empty section for query
+        writer.writeUInt16(0); // additional records count, empty section for query
+
+        int writtenQuestions = 0;
+        int writtenAnswers = 0;
+        int lastValidPos = writer.getWritePosition();
+        try {
+            for (MdnsRecord record : packet.questions) {
+                // Questions do not have TTL or data
+                record.writeHeaderFields(writer);
+                writtenQuestions++;
+                lastValidPos = writer.getWritePosition();
+            }
+            for (MdnsRecord record : packet.answers) {
+                record.write(writer, 0L);
+                writtenAnswers++;
+                lastValidPos = writer.getWritePosition();
+            }
+            remainingPacket = null;
+        } catch (IOException e) {
+            // Went over the packet limit; truncate
+            if (writtenQuestions == 0 && writtenAnswers == 0) {
+                // No space to write even one record: just throw (as subclass of IOException)
+                throw e;
+            }
+
+            // Set the last valid position as the final position (not as a rewind)
+            writer.rewind(lastValidPos);
+            writer.clearRewind();
+
+            remainingPacket = new MdnsPacket(packet.flags,
+                    packet.questions.subList(
+                            writtenQuestions, packet.questions.size()),
+                    packet.answers.subList(
+                            writtenAnswers, packet.answers.size()),
+                    Collections.emptyList(), /* authorityRecords */
+                    Collections.emptyList() /* additionalRecords */);
+        }
+
+        final int len = writer.getWritePosition();
+        writer.rewind(flagsPos);
+        writer.writeUInt16(packet.flags | (remainingPacket == null ? 0 : FLAG_TRUNCATED));
+        writer.writeUInt16(writtenQuestions);
+        writer.writeUInt16(writtenAnswers);
+        writer.unrewind();
+
+        return Pair.create(remainingPacket, len);
+    }
+
+    /**
+     * Builds the DatagramPackets needed to send the given MdnsPacket to the given destination.
+     *
+     * <p> A packet that does not fit in the creation buffer is split across several
+     *     DatagramPackets, each holding its own copy of the bytes written for it.
+     */
+    public static List<DatagramPacket> createQueryDatagramPackets(
+            @NonNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet,
+            @NonNull InetSocketAddress destination) throws IOException {
+        final List<DatagramPacket> result = new ArrayList<>();
+        MdnsPacket pending = packet;
+        while (pending != null) {
+            final Pair<MdnsPacket, Integer> written =
+                    writePossibleMdnsPacket(packetCreationBuffer, pending);
+            pending = written.first;
+            final int writtenLength = written.second;
+            final byte[] payload = Arrays.copyOf(packetCreationBuffer, writtenLength);
+            result.add(new DatagramPacket(payload, 0, payload.length, destination));
+        }
+        return result;
+    }
+
+    /**
      * Checks if the MdnsRecord needs to be renewed or not.
      *
      * <p>As per RFC6762 7.1 no need to query if remaining TTL is more than half the original one,
diff --git a/service-t/src/com/android/server/ethernet/EthernetTracker.java b/service-t/src/com/android/server/ethernet/EthernetTracker.java
index bfcc171..9c8fd99 100644
--- a/service-t/src/com/android/server/ethernet/EthernetTracker.java
+++ b/service-t/src/com/android/server/ethernet/EthernetTracker.java
@@ -460,7 +460,7 @@
             if (!include) {
                 removeTestData();
             }
-            mHandler.post(() -> trackAvailableInterfaces());
+            trackAvailableInterfaces();
         });
     }
 
diff --git a/service/Android.bp b/service/Android.bp
index c35c4f8..1d74efc 100644
--- a/service/Android.bp
+++ b/service/Android.bp
@@ -199,7 +199,9 @@
         "PlatformProperties",
         "service-connectivity-protos",
         "service-connectivity-stats-protos",
-        "net-utils-multicast-forwarding-structs",
+        // The required dependency net-utils-device-common-struct-base is in the classpath via
+        // framework-connectivity
+        "net-utils-device-common-struct",
     ],
     apex_available: [
         "com.android.tethering",
diff --git a/service/ServiceConnectivityResources/res/values/config_thread.xml b/service/ServiceConnectivityResources/res/values/config_thread.xml
index f7e47f5..4783f2b 100644
--- a/service/ServiceConnectivityResources/res/values/config_thread.xml
+++ b/service/ServiceConnectivityResources/res/values/config_thread.xml
@@ -31,4 +31,26 @@
     Thread Network regulatory purposes.
     -->
     <bool name="config_thread_location_use_for_country_code_enabled">true</bool>
+
+    <!-- Specifies the UTF-8 vendor name of this device. If this value is not an empty string, it
+    will be included in TXT value (key is 'vn') of the "_meshcop._udp" mDNS service which is
+    published by the Thread service. A non-empty string value must not exceed length of 24 UTF-8
+    bytes.
+    -->
+    <string translatable="false" name="config_thread_vendor_name">Android</string>
+
+    <!-- Specifies the 24 bits vendor OUI of this device. If this value is not an empty string, it
+    will be included in TXT (key is 'vo') value of the "_meshcop._udp" mDNS service which is
+    published by the Thread service. The OUI can be represented as a base-16 number of six
+    hexadecimal digits, or octets separated by hyphens or colons. For example, "ACDE48", "AC-DE-48"
+    and "AC:DE:48" are all valid representations of the same OUI value.
+    -->
+    <string translatable="false" name="config_thread_vendor_oui"></string>
+
+    <!-- Specifies the UTF-8 product model name of this device. If this value is not an empty
+    string, it will be included in TXT (key is 'mn') value of the "_meshcop._udp" mDNS service
+    which is published by the Thread service. A non-empty string value must not exceed length of 24
+    UTF-8 bytes.
+    -->
+    <string translatable="false" name="config_thread_model_name">Thread Border Router</string>
 </resources>
diff --git a/service/ServiceConnectivityResources/res/values/overlayable.xml b/service/ServiceConnectivityResources/res/values/overlayable.xml
index d9af5a3..158b0c8 100644
--- a/service/ServiceConnectivityResources/res/values/overlayable.xml
+++ b/service/ServiceConnectivityResources/res/values/overlayable.xml
@@ -48,6 +48,9 @@
             <!-- Configuration values for ThreadNetworkService -->
             <item type="bool" name="config_thread_default_enabled" />
             <item type="bool" name="config_thread_location_use_for_country_code_enabled" />
+            <item type="string" name="config_thread_vendor_name" />
+            <item type="string" name="config_thread_vendor_oui" />
+            <item type="string" name="config_thread_model_name" />
         </policy>
     </overlayable>
 </resources>
diff --git a/service/jni/com_android_server_connectivity_ClatCoordinator.cpp b/service/jni/com_android_server_connectivity_ClatCoordinator.cpp
index c125bd6..4214bc9 100644
--- a/service/jni/com_android_server_connectivity_ClatCoordinator.cpp
+++ b/service/jni/com_android_server_connectivity_ClatCoordinator.cpp
@@ -113,7 +113,12 @@
     if (!modules::sdklevel::IsAtLeastT()) return;
 
     V("/sys/fs/bpf", S_IFDIR|S_ISVTX|0777, ROOT, ROOT, "fs_bpf", DIR);
-    V("/sys/fs/bpf/net_shared", S_IFDIR|S_ISVTX|0777, ROOT, ROOT, "fs_bpf_net_shared", DIR);
+
+    if (false && modules::sdklevel::IsAtLeastV()) {
+        V("/sys/fs/bpf/net_shared", S_IFDIR|01777, ROOT, ROOT, "fs_bpf_net_shared", DIR);
+    } else {
+        V("/sys/fs/bpf/net_shared", S_IFDIR|01777, SYSTEM, SYSTEM, "fs_bpf_net_shared", DIR);
+    }
 
     // pre-U we do not have selinux privs to getattr on bpf maps/progs
     // so while the below *should* be as listed, we have no way to actually verify
diff --git a/service/src/com/android/server/BpfNetMaps.java b/service/src/com/android/server/BpfNetMaps.java
index a7fddd0..42c1628 100644
--- a/service/src/com/android/server/BpfNetMaps.java
+++ b/service/src/com/android/server/BpfNetMaps.java
@@ -49,7 +49,7 @@
 
 import android.app.StatsManager;
 import android.content.Context;
-import android.net.BpfNetMapsReader;
+import android.net.BpfNetMapsUtils;
 import android.net.INetd;
 import android.net.UidOwnerValue;
 import android.os.Build;
@@ -535,14 +535,11 @@
      * @throws UnsupportedOperationException if called on pre-T devices.
      * @throws ServiceSpecificException in case of failure, with an error code indicating the
      *                                  cause of the failure.
-     *
-     * @deprecated Use {@link BpfNetMapsReader#isChainEnabled} instead.
      */
-    // TODO: Migrate the callers to use {@link BpfNetMapsReader#isChainEnabled} instead.
     @Deprecated
     @RequiresApi(Build.VERSION_CODES.TIRAMISU)
     public boolean isChainEnabled(final int childChain) {
-        return BpfNetMapsReader.isChainEnabled(sConfigurationMap, childChain);
+        return BpfNetMapsUtils.isChainEnabled(sConfigurationMap, childChain);
     }
 
     private Set<Integer> asSet(final int[] uids) {
@@ -635,12 +632,9 @@
      * @throws UnsupportedOperationException if called on pre-T devices.
      * @throws ServiceSpecificException in case of failure, with an error code indicating the
      *                                  cause of the failure.
-     *
-     * @deprecated use {@link BpfNetMapsReader#getUidRule} instead.
      */
-    // TODO: Migrate the callers to use {@link BpfNetMapsReader#getUidRule} instead.
     public int getUidRule(final int childChain, final int uid) {
-        return BpfNetMapsReader.getUidRule(sUidOwnerMap, childChain, uid);
+        return BpfNetMapsUtils.getUidRule(sUidOwnerMap, childChain, uid);
     }
 
     private Set<Integer> getUidsMatchEnabled(final int childChain) throws ErrnoException {
@@ -924,6 +918,25 @@
         }
     }
 
+    /**
+     * Tells whether firewall chains currently block networking for the given uid.
+     *
+     * Note that {@link #getDataSaverEnabled()} has a latency before V.
+     *
+     * @param uid The uid to check.
+     * @param isNetworkMetered Whether the network in question is metered.
+     *
+     * @return True when networking is blocked for the uid, false otherwise.
+     * @throws ServiceSpecificException if the read fails.
+     *
+     * @hide
+     */
+    @RequiresApi(Build.VERSION_CODES.TIRAMISU)
+    public boolean isUidNetworkingBlocked(final int uid, boolean isNetworkMetered) {
+        return BpfNetMapsUtils.isUidNetworkingBlocked(
+                uid, isNetworkMetered, sConfigurationMap, sUidOwnerMap, sDataSaverEnabledMap);
+    }
+
     /** Register callback for statsd to pull atom. */
     @RequiresApi(Build.VERSION_CODES.TIRAMISU)
     public void setPullAtomCallback(final Context context) {
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index f27e645..005d617 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -848,11 +848,6 @@
     private static final int EVENT_UID_FROZEN_STATE_CHANGED = 61;
 
     /**
-     * Event to inform the ConnectivityService handler when a uid has lost carrier privileges.
-     */
-    private static final int EVENT_UID_CARRIER_PRIVILEGES_LOST = 62;
-
-    /**
      * Argument for {@link #EVENT_PROVISIONING_NOTIFICATION} to indicate that the notification
      * should be shown.
      */
@@ -1295,14 +1290,6 @@
     }
     private final LegacyTypeTracker mLegacyTypeTracker = new LegacyTypeTracker(this);
 
-    @VisibleForTesting
-    void onCarrierPrivilegesLost(Integer uid, Integer subId) {
-        if (mRequestRestrictedWifiEnabled) {
-            mHandler.sendMessage(mHandler.obtainMessage(
-                    EVENT_UID_CARRIER_PRIVILEGES_LOST, uid, subId));
-        }
-    }
-
     final LocalPriorityDump mPriorityDumper = new LocalPriorityDump();
     /**
      * Helper class which parses out priority arguments and dumps sections according to their
@@ -1524,10 +1511,11 @@
                 @NonNull final Context context,
                 @NonNull final TelephonyManager tm,
                 boolean requestRestrictedWifiEnabled,
-                @NonNull BiConsumer<Integer, Integer> listener) {
+                @NonNull BiConsumer<Integer, Integer> listener,
+                @NonNull final Handler connectivityServiceHandler) {
             if (isAtLeastT()) {
-                return new CarrierPrivilegeAuthenticator(
-                        context, tm, requestRestrictedWifiEnabled, listener);
+                return new CarrierPrivilegeAuthenticator(context, tm, requestRestrictedWifiEnabled,
+                        listener, connectivityServiceHandler);
             } else {
                 return null;
             }
@@ -1812,7 +1800,7 @@
                 && mDeps.isFeatureEnabled(context, REQUEST_RESTRICTED_WIFI);
         mCarrierPrivilegeAuthenticator = mDeps.makeCarrierPrivilegeAuthenticator(
                 mContext, mTelephonyManager, mRequestRestrictedWifiEnabled,
-                this::onCarrierPrivilegesLost);
+                this::handleUidCarrierPrivilegesLost, mHandler);
 
         if (mDeps.isAtLeastU()
                 && mDeps
@@ -2247,7 +2235,11 @@
         final long ident = Binder.clearCallingIdentity();
         try {
             final boolean metered = nc == null ? true : nc.isMetered();
-            return mPolicyManager.isUidNetworkingBlocked(uid, metered);
+            if (mDeps.isAtLeastV()) {
+                return mBpfNetMaps.isUidNetworkingBlocked(uid, metered);
+            } else {
+                return mPolicyManager.isUidNetworkingBlocked(uid, metered);
+            }
         } finally {
             Binder.restoreCallingIdentity(ident);
         }
@@ -3831,6 +3823,10 @@
             mSatelliteAccessController.start();
         }
 
+        if (mCarrierPrivilegeAuthenticator != null) {
+            mCarrierPrivilegeAuthenticator.start();
+        }
+
         // On T+ devices, register callback for statsd to pull NETWORK_BPF_MAP_INFO atom
         if (mDeps.isAtLeastT()) {
             mBpfNetMaps.setPullAtomCallback(mContext);
@@ -6519,9 +6515,6 @@
                     UidFrozenStateChangedArgs args = (UidFrozenStateChangedArgs) msg.obj;
                     handleFrozenUids(args.mUids, args.mFrozenStates);
                     break;
-                case EVENT_UID_CARRIER_PRIVILEGES_LOST:
-                    handleUidCarrierPrivilegesLost(msg.arg1, msg.arg2);
-                    break;
             }
         }
     }
@@ -9260,6 +9253,9 @@
     }
 
     private void handleUidCarrierPrivilegesLost(int uid, int subId) {
+        if (!mRequestRestrictedWifiEnabled) {
+            return;
+        }
         ensureRunningOnConnectivityServiceThread();
         // A NetworkRequest needs to be revoked when all the conditions are met
         //   1. It requests restricted network
diff --git a/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java b/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java
index 04d0fc1..f5fa4fb 100644
--- a/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java
+++ b/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java
@@ -91,42 +91,62 @@
             @NonNull final TelephonyManager t,
             @NonNull final TelephonyManagerShim telephonyManagerShim,
             final boolean requestRestrictedWifiEnabled,
-            @NonNull BiConsumer<Integer, Integer> listener) {
+            @NonNull BiConsumer<Integer, Integer> listener,
+            @NonNull final Handler connectivityServiceHandler) {
         mContext = c;
         mTelephonyManager = t;
         mTelephonyManagerShim = telephonyManagerShim;
-        final HandlerThread thread = deps.makeHandlerThread();
-        thread.start();
-        mHandler = new Handler(thread.getLooper());
         mUseCallbacksForServiceChanged = deps.isFeatureEnabled(
                 c, CARRIER_SERVICE_CHANGED_USE_CALLBACK);
         mRequestRestrictedWifiEnabled = requestRestrictedWifiEnabled;
         mListener = listener;
+        if (mRequestRestrictedWifiEnabled) {
+            mHandler = connectivityServiceHandler;
+        } else {
+            final HandlerThread thread = deps.makeHandlerThread();
+            thread.start();
+            mHandler = new Handler(thread.getLooper());
+            synchronized (mLock) {
+                registerSimConfigChangedReceiver();
+                simConfigChanged();
+            }
+        }
+    }
+
+    private void registerSimConfigChangedReceiver() {
         final IntentFilter filter = new IntentFilter();
         filter.addAction(TelephonyManager.ACTION_MULTI_SIM_CONFIG_CHANGED);
-        synchronized (mLock) {
-            // Never unregistered because the system server never stops
-            c.registerReceiver(new BroadcastReceiver() {
-                @Override
-                public void onReceive(final Context context, final Intent intent) {
-                    switch (intent.getAction()) {
-                        case TelephonyManager.ACTION_MULTI_SIM_CONFIG_CHANGED:
-                            simConfigChanged();
-                            break;
-                        default:
-                            Log.d(TAG, "Unknown intent received, action: " + intent.getAction());
-                    }
+        // Never unregistered because the system server never stops
+        mContext.registerReceiver(new BroadcastReceiver() {
+            @Override
+            public void onReceive(final Context context, final Intent intent) {
+                switch (intent.getAction()) {
+                    case TelephonyManager.ACTION_MULTI_SIM_CONFIG_CHANGED:
+                        simConfigChanged();
+                        break;
+                    default:
+                        Log.d(TAG, "Unknown intent received, action: " + intent.getAction());
                 }
-            }, filter, null, mHandler);
-            simConfigChanged();
+            }
+        }, filter, null, mHandler);
+    }
+
+    /**
+     * Registers for SIM config changes and schedules the initial simConfigChanged() run.
+     * Does nothing when mRequestRestrictedWifiEnabled is false.
+     */
+    public void start() {
+        if (!mRequestRestrictedWifiEnabled) return;
+        registerSimConfigChangedReceiver();
+        mHandler.post(this::simConfigChanged);
+    }
 
     public CarrierPrivilegeAuthenticator(@NonNull final Context c,
             @NonNull final TelephonyManager t, final boolean requestRestrictedWifiEnabled,
-            @NonNull BiConsumer<Integer, Integer> listener) {
+            @NonNull BiConsumer<Integer, Integer> listener,
+            @NonNull final Handler connectivityServiceHandler) {
         this(c, new Dependencies(), t, TelephonyManagerShimImpl.newInstance(t),
-                requestRestrictedWifiEnabled, listener);
+                requestRestrictedWifiEnabled, listener, connectivityServiceHandler);
     }
 
     public static class Dependencies {
@@ -146,6 +166,10 @@
     }
 
     private void simConfigChanged() {
+        // When mRequestRestrictedWifiEnabled is false, the constructor calls this directly.
+        if (mRequestRestrictedWifiEnabled) {
+            ensureRunningOnHandlerThread();
+        }
         synchronized (mLock) {
             unregisterCarrierPrivilegesListeners();
             mModemCount = mTelephonyManager.getActiveModemCount();
@@ -188,6 +212,7 @@
         public void onCarrierPrivilegesChanged(
                 @NonNull List<String> privilegedPackageNames,
                 @NonNull int[] privilegedUids) {
+            ensureRunningOnHandlerThread();
             if (mUseCallbacksForServiceChanged) return;
             // Re-trigger the synchronous check (which is also very cheap due
             // to caching in CarrierPrivilegesTracker). This allows consistency
@@ -198,6 +223,7 @@
         @Override
         public void onCarrierServiceChanged(@Nullable final String carrierServicePackageName,
                 final int carrierServiceUid) {
+            ensureRunningOnHandlerThread();
             if (!mUseCallbacksForServiceChanged) {
                 // Re-trigger the synchronous check (which is also very cheap due
                 // to caching in CarrierPrivilegesTracker). This allows consistency
@@ -439,6 +465,17 @@
         }
     }
 
+    /**
+     * Throws IllegalStateException unless the caller is on the thread of mHandler's looper.
+     */
+    private void ensureRunningOnHandlerThread() {
+        final Thread current = Thread.currentThread();
+        if (mHandler.getLooper().getThread() == current) {
+            return;
+        }
+        throw new IllegalStateException("Not running on handler thread: " + current.getName());
+    }
+
     public void dump(IndentingPrintWriter pw) {
         pw.println("CarrierPrivilegeAuthenticator:");
         pw.println("mRequestRestrictedWifiEnabled = " + mRequestRestrictedWifiEnabled);
diff --git a/service/src/com/android/server/connectivity/ClatCoordinator.java b/service/src/com/android/server/connectivity/ClatCoordinator.java
index daaf91d..eea16bf 100644
--- a/service/src/com/android/server/connectivity/ClatCoordinator.java
+++ b/service/src/com/android/server/connectivity/ClatCoordinator.java
@@ -847,12 +847,12 @@
             if (mIngressMap.isEmpty()) {
                 pw.println("<empty>");
             }
-            pw.println("BPF ingress map: iif nat64Prefix v6Addr -> v4Addr oif");
+            pw.println("BPF ingress map: iif nat64Prefix v6Addr -> v4Addr oif (packets bytes)");
             pw.increaseIndent();
             mIngressMap.forEach((k, v) -> {
                 // TODO: print interface name
-                pw.println(String.format("%d %s/96 %s -> %s %d", k.iif, k.pfx96, k.local6,
-                        v.local4, v.oif));
+                pw.println(String.format("%d %s/96 %s -> %s %d (%d %d)", k.iif, k.pfx96, k.local6,
+                        v.local4, v.oif, v.packets, v.bytes));
             });
             pw.decreaseIndent();
         } catch (ErrnoException e) {
@@ -870,12 +870,13 @@
             if (mEgressMap.isEmpty()) {
                 pw.println("<empty>");
             }
-            pw.println("BPF egress map: iif v4Addr -> v6Addr nat64Prefix oif");
+            pw.println("BPF egress map: iif v4Addr -> v6Addr nat64Prefix oif (packets bytes)");
             pw.increaseIndent();
             mEgressMap.forEach((k, v) -> {
                 // TODO: print interface name
-                pw.println(String.format("%d %s -> %s %s/96 %d %s", k.iif, k.local4, v.local6,
-                        v.pfx96, v.oif, v.oifIsEthernet != 0 ? "ether" : "rawip"));
+                pw.println(String.format("%d %s -> %s %s/96 %d %s (%d %d)", k.iif, k.local4,
+                        v.local6, v.pfx96, v.oif, v.oifIsEthernet != 0 ? "ether" : "rawip",
+                        v.packets, v.bytes));
             });
             pw.decreaseIndent();
         } catch (ErrnoException e) {
diff --git a/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java b/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java
index 48af9fa..21dbb45 100644
--- a/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java
+++ b/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java
@@ -29,6 +29,7 @@
 import android.net.TelephonyNetworkSpecifier;
 import android.net.TransportInfo;
 import android.net.wifi.WifiInfo;
+import android.os.Build;
 import android.os.Handler;
 import android.os.SystemClock;
 import android.telephony.SubscriptionInfo;
@@ -39,6 +40,8 @@
 import android.util.SparseArray;
 import android.util.SparseIntArray;
 
+import androidx.annotation.RequiresApi;
+
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.metrics.DailykeepaliveInfoReported;
 import com.android.metrics.DurationForNumOfKeepalive;
@@ -279,6 +282,7 @@
          *
          * @param dailyKeepaliveInfoReported the proto to write to statsD.
          */
+        @RequiresApi(Build.VERSION_CODES.TIRAMISU)
         public void writeStats(DailykeepaliveInfoReported dailyKeepaliveInfoReported) {
             ConnectivityStatsLog.write(
                     ConnectivityStatsLog.DAILY_KEEPALIVE_INFO_REPORTED,
diff --git a/staticlibs/Android.bp b/staticlibs/Android.bp
index f7b42a6..ede6d3f 100644
--- a/staticlibs/Android.bp
+++ b/staticlibs/Android.bp
@@ -124,6 +124,8 @@
     ],
 }
 
+// The net-utils-device-common-bpf library requires the callers to contain
+// net-utils-device-common-struct-base.
 java_library {
     name: "net-utils-device-common-bpf",
     srcs: [
@@ -133,9 +135,7 @@
         "device/com/android/net/module/util/BpfUtils.java",
         "device/com/android/net/module/util/IBpfMap.java",
         "device/com/android/net/module/util/JniUtil.java",
-        "device/com/android/net/module/util/Struct.java",
         "device/com/android/net/module/util/TcUtils.java",
-        "framework/com/android/net/module/util/HexDump.java",
     ],
     sdk_version: "module_current",
     min_sdk_version: "30",
@@ -146,6 +146,7 @@
     libs: [
         "androidx.annotation_annotation",
         "framework-connectivity.stubs.module_lib",
+        "net-utils-device-common-struct-base",
     ],
     apex_available: [
         "com.android.tethering",
@@ -158,12 +159,9 @@
 }
 
 java_library {
-    name: "net-utils-device-common-struct",
+    name: "net-utils-device-common-struct-base",
     srcs: [
-        "device/com/android/net/module/util/Ipv6Utils.java",
-        "device/com/android/net/module/util/PacketBuilder.java",
         "device/com/android/net/module/util/Struct.java",
-        "device/com/android/net/module/util/structs/*.java",
     ],
     sdk_version: "module_current",
     min_sdk_version: "30",
@@ -176,6 +174,7 @@
     ],
     libs: [
         "androidx.annotation_annotation",
+        "framework-annotations-lib", // Required by InetAddressUtils.java
         "framework-connectivity.stubs.module_lib",
     ],
     apex_available: [
@@ -188,26 +187,30 @@
     },
 }
 
-// The net-utils-multicast-forwarding-structs library requires the callers to
-// contain net-utils-device-common-bpf.
+// The net-utils-device-common-struct library requires the callers to contain
+// net-utils-device-common-struct-base.
 java_library {
-    name: "net-utils-multicast-forwarding-structs",
+    name: "net-utils-device-common-struct",
     srcs: [
-        "device/com/android/net/module/util/structs/StructMf6cctl.java",
-        "device/com/android/net/module/util/structs/StructMif6ctl.java",
-        "device/com/android/net/module/util/structs/StructMrt6Msg.java",
+        "device/com/android/net/module/util/Ipv6Utils.java",
+        "device/com/android/net/module/util/PacketBuilder.java",
+        "device/com/android/net/module/util/structs/*.java",
     ],
     sdk_version: "module_current",
     min_sdk_version: "30",
     visibility: [
         "//packages/modules/Connectivity:__subpackages__",
+        "//packages/modules/NetworkStack:__subpackages__",
     ],
     libs: [
-        // Only Struct.java is needed from "net-utils-device-common-bpf"
-        "net-utils-device-common-bpf",
+        "androidx.annotation_annotation",
+        "framework-annotations-lib", // Required by IpUtils.java
+        "framework-connectivity.stubs.module_lib",
+        "net-utils-device-common-struct-base",
     ],
     apex_available: [
         "com.android.tethering",
+        "//apex_available:platform",
     ],
     lint: {
         strict_updatability_linting: true,
@@ -216,7 +219,7 @@
 }
 
 // The net-utils-device-common-netlink library requires the callers to contain
-// net-utils-device-common-struct.
+// net-utils-device-common-struct and net-utils-device-common-struct-base.
 java_library {
     name: "net-utils-device-common-netlink",
     srcs: [
@@ -235,6 +238,7 @@
         // statically link here because callers of this library might already have a static
         // version linked.
         "net-utils-device-common-struct",
+        "net-utils-device-common-struct-base",
     ],
     apex_available: [
         "com.android.tethering",
@@ -247,7 +251,7 @@
 }
 
 // The net-utils-device-common-ip library requires the callers to contain
-// net-utils-device-common-struct.
+// net-utils-device-common-struct and net-utils-device-common-struct-base.
 java_library {
     // TODO : this target should probably be folded into net-utils-device-common
     name: "net-utils-device-common-ip",
diff --git a/staticlibs/device/com/android/net/module/util/structs/PrefixInformationOption.java b/staticlibs/device/com/android/net/module/util/structs/PrefixInformationOption.java
index 49d7654..0fc85e4 100644
--- a/staticlibs/device/com/android/net/module/util/structs/PrefixInformationOption.java
+++ b/staticlibs/device/com/android/net/module/util/structs/PrefixInformationOption.java
@@ -21,6 +21,7 @@
 import android.net.IpPrefix;
 
 import androidx.annotation.NonNull;
+import androidx.annotation.VisibleForTesting;
 
 import com.android.net.module.util.Struct;
 import com.android.net.module.util.Struct.Field;
@@ -71,7 +72,8 @@
     @Field(order = 7, type = Type.ByteArray, arraysize = 16)
     public final byte[] prefix;
 
-    PrefixInformationOption(final byte type, final byte length, final byte prefixLen,
+    @VisibleForTesting
+    public PrefixInformationOption(final byte type, final byte length, final byte prefixLen,
             final byte flags, final long validLifetime, final long preferredLifetime,
             final int reserved, @NonNull final byte[] prefix) {
         this.type = type;
diff --git a/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h b/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h
index 92ddf44..dc7925e 100644
--- a/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h
+++ b/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h
@@ -37,17 +37,27 @@
 #define BPFLOADER_IGNORED_ON_VERSION 33u
 
 // Android U / 14 (api level 34) - various new program types added
-#define BPFLOADER_U_VERSION 37u
+#define BPFLOADER_U_VERSION 38u
 
 // Android V / 15 (api level 35) - platform only
 // (note: the platform bpfloader in V isn't really versioned at all,
 //  as there is no need as it can only load objects compiled at the
 //  same time as itself and the rest of the platform)
-#define BPFLOADER_V_VERSION 41u
+#define BPFLOADER_PLATFORM_VERSION 41u
 
-// Android Mainline - this bpfloader should eventually go back to T
+// Android Mainline - this bpfloader should eventually go back to T (or even S)
+// Note: this value (and the following +1u's) are hardcoded in NetBpfLoad.cpp
 #define BPFLOADER_MAINLINE_VERSION 42u
 
+// Android Mainline BpfLoader when running on Android T
+#define BPFLOADER_MAINLINE_T_VERSION (BPFLOADER_MAINLINE_VERSION + 1u)
+
+// Android Mainline BpfLoader when running on Android U
+#define BPFLOADER_MAINLINE_U_VERSION (BPFLOADER_MAINLINE_T_VERSION + 1u)
+
+// Android Mainline BpfLoader when running on Android V
+#define BPFLOADER_MAINLINE_V_VERSION (BPFLOADER_MAINLINE_U_VERSION + 1u)
+
 /* For mainline module use, you can #define BPFLOADER_{MIN/MAX}_VER
  * before #include "bpf_helpers.h" to change which bpfloaders will
  * process the resulting .o file.
@@ -57,7 +67,7 @@
  * In which case it's just best to use the default.
  */
 #ifndef BPFLOADER_MIN_VER
-#define BPFLOADER_MIN_VER BPFLOADER_V_VERSION
+#define BPFLOADER_MIN_VER BPFLOADER_PLATFORM_VERSION
 #endif
 
 #ifndef BPFLOADER_MAX_VER
diff --git a/staticlibs/tests/unit/Android.bp b/staticlibs/tests/unit/Android.bp
index 4c226cc..fa466f8 100644
--- a/staticlibs/tests/unit/Android.bp
+++ b/staticlibs/tests/unit/Android.bp
@@ -25,6 +25,7 @@
         "net-utils-device-common-async",
         "net-utils-device-common-bpf",
         "net-utils-device-common-ip",
+        "net-utils-device-common-struct-base",
         "net-utils-device-common-wear",
     ],
     libs: [
diff --git a/staticlibs/testutils/Android.bp b/staticlibs/testutils/Android.bp
index a8e5a69..9124ac0 100644
--- a/staticlibs/testutils/Android.bp
+++ b/staticlibs/testutils/Android.bp
@@ -40,6 +40,7 @@
         "net-utils-device-common-async",
         "net-utils-device-common-netlink",
         "net-utils-device-common-struct",
+        "net-utils-device-common-struct-base",
         "net-utils-device-common-wear",
         "modules-utils-build_system",
     ],
diff --git a/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt b/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
new file mode 100644
index 0000000..3ecbdc6
--- /dev/null
+++ b/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
@@ -0,0 +1,91 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.net.cts
+
+import android.content.pm.PackageManager.FEATURE_WIFI
+import android.net.ConnectivityManager
+import android.net.NetworkCapabilities
+import android.net.NetworkRequest
+import android.os.Build
+import android.platform.test.annotations.AppModeFull
+import android.system.OsConstants
+import androidx.test.platform.app.InstrumentationRegistry
+import com.android.compatibility.common.util.PropertyUtil.isVendorApiLevelNewerThan
+import com.android.compatibility.common.util.SystemUtil
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.NetworkStackModuleTest
+import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
+import com.android.testutils.TestableNetworkCallback
+import com.google.common.truth.Truth.assertThat
+import kotlin.test.assertEquals
+import kotlin.test.assertNotNull
+import org.junit.After
+import org.junit.Assume.assumeTrue
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+
+private const val TIMEOUT_MS = 2000L
+
+@AppModeFull(reason = "CHANGE_NETWORK_STATE permission can't be granted to instant apps")
+@RunWith(DevSdkIgnoreRunner::class)
+@NetworkStackModuleTest
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
+class ApfIntegrationTest {
+    private val context by lazy { InstrumentationRegistry.getInstrumentation().context }
+    private val cm by lazy { context.getSystemService(ConnectivityManager::class.java)!! }
+    private val pm by lazy { context.packageManager }
+    private lateinit var ifname: String
+    private lateinit var networkCallback: TestableNetworkCallback
+
+    @Before
+    fun setUp() {
+        assumeTrue(pm.hasSystemFeature(FEATURE_WIFI))
+        assumeTrue(isVendorApiLevelNewerThan(Build.VERSION_CODES.TIRAMISU))
+        networkCallback = TestableNetworkCallback()
+        cm.requestNetwork(
+                NetworkRequest.Builder()
+                        .addTransportType(NetworkCapabilities.TRANSPORT_WIFI)
+                        .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
+                        .build(),
+                networkCallback
+        )
+        networkCallback.eventuallyExpect<LinkPropertiesChanged>(TIMEOUT_MS) {
+            ifname = assertNotNull(it.lp.interfaceName)
+            true
+        }
+    }
+
+    @After
+    fun tearDown() {
+        if (::networkCallback.isInitialized) {
+            cm.unregisterNetworkCallback(networkCallback)
+        }
+    }
+
+    @Test
+    fun testGetApfCapabilities() {
+        val caps = SystemUtil.runShellCommand("cmd network_stack apf $ifname capabilities").trim()
+        val (version, maxLen, packetFormat) = caps.split(",").map { it.toInt() }
+        assertEquals(4, version)
+        assertThat(maxLen).isAtLeast(1024)
+        if (isVendorApiLevelNewerThan(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)) {
+            assertThat(maxLen).isAtLeast(2000)
+        }
+        assertEquals(OsConstants.ARPHRD_ETHER, packetFormat)
+    }
+}
diff --git a/tests/cts/net/src/android/net/cts/EthernetManagerTest.kt b/tests/cts/net/src/android/net/cts/EthernetManagerTest.kt
index 7af3c83..d052551 100644
--- a/tests/cts/net/src/android/net/cts/EthernetManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/EthernetManagerTest.kt
@@ -348,7 +348,9 @@
         }
     }
 
-    private fun isEthernetSupported() = em != null
+    private fun isEthernetSupported(): Boolean {
+        return context.getSystemService(EthernetManager::class.java) != null
+    }
 
     @Before
     fun setUp() {
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index 0daf7fe..6dd4857 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -1888,6 +1888,64 @@
         }
     }
 
+    @Test
+    fun testQueryWhenKnownAnswerSuppressionFlagSet() {
+        // The flag may be removed in the future but known-answer suppression should be enabled by
+        // default in that case. The rule will reset flags automatically on teardown.
+        deviceConfigRule.setConfig(NAMESPACE_TETHERING, "test_nsd_query_with_known_answer", "1")
+
+        // Start discovery on testNetwork1 and read the packets sent on its interface
+        val discoveryRecord = NsdDiscoveryRecord()
+        val packetReader = TapPacketReader(Handler(handlerThread.looper),
+                testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+        packetReader.startAsyncForTest()
+        handlerThread.waitForIdle(TIMEOUT_MS)
+
+        nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
+                testNetwork1.network, { it.run() }, discoveryRecord)
+
+        tryTest {
+            discoveryRecord.expectCallback<DiscoveryStarted>()
+            assertNotNull(packetReader.pollForQuery("$serviceType.local", DnsResolver.TYPE_PTR))
+            /*
+            Generated with:
+            scapy.raw(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an =
+                scapy.DNSRR(rrname='_nmt123456789._tcp.local', type='PTR', ttl=120,
+                rdata='NsdTest123456789._nmt123456789._tcp.local'))).hex()
+             */
+            val ptrResponsePayload = HexDump.hexStringToByteArray("0000840000000001000000000d5f6e" +
+                    "6d74313233343536373839045f746370056c6f63616c00000c000100000078002b104e736454" +
+                    "6573743132333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00")
+
+            replaceServiceNameAndTypeWithTestSuffix(ptrResponsePayload)
+            packetReader.sendResponse(buildMdnsPacket(ptrResponsePayload))
+
+            val serviceFound = discoveryRecord.expectCallback<ServiceFound>()
+            serviceFound.serviceInfo.let {
+                assertEquals(serviceName, it.serviceName)
+                // Discovered service types have a dot at the end
+                assertEquals("$serviceType.", it.serviceType)
+                assertEquals(testNetwork1.network, it.network)
+                // ServiceFound does not provide port, address or attributes (only information
+                // available in the PTR record is included in that callback, regardless of whether
+                // other records exist).
+                assertEquals(0, it.port)
+                assertEmpty(it.hostAddresses)
+                assertEquals(0, it.attributes.size)
+            }
+
+            // Expect a second PTR query that also carries the PTR record as its known answer
+            val query = packetReader.pollForMdnsPacket { pkt ->
+                pkt.isQueryFor("$serviceType.local", DnsResolver.TYPE_PTR) &&
+                        pkt.isReplyFor("$serviceType.local", DnsResolver.TYPE_PTR)
+            }
+            assertNotNull(query)
+        } cleanup {
+            nsdManager.stopServiceDiscovery(discoveryRecord)
+            discoveryRecord.expectCallback<DiscoveryStopped>()
+        }
+    }
+
     private fun makeLinkLocalAddressOfOtherDeviceOnPrefix(network: Network): Inet6Address {
         val lp = cm.getLinkProperties(network) ?: fail("No LinkProperties for net $network")
         // Expect to have a /64 link-local address
@@ -2148,6 +2206,66 @@
     }
 
     @Test
+    fun testAdvertisingAndDiscovery_reregisterCustomHostWithDifferentAddresses_newAddressesFound() {
+        val si1 = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.hostname = customHostname
+            it.hostAddresses = listOf(
+                    parseNumericAddress("192.0.2.23"),
+                    parseNumericAddress("2001:db8::1"))
+        }
+        val si2 = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.serviceName = serviceName
+            it.serviceType = serviceType
+            it.hostname = customHostname
+            it.port = TEST_PORT
+        }
+        val si3 = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.hostname = customHostname
+            it.hostAddresses = listOf(
+                    parseNumericAddress("192.0.2.24"),
+                    parseNumericAddress("2001:db8::2"))
+        }
+
+        val registrationRecord1 = NsdRegistrationRecord()
+        val registrationRecord2 = NsdRegistrationRecord()
+        val registrationRecord3 = NsdRegistrationRecord()
+
+        val discoveryRecord = NsdDiscoveryRecord()
+
+        tryTest {
+            registerService(registrationRecord1, si1)
+            registerService(registrationRecord2, si2)
+
+            nsdManager.unregisterService(registrationRecord1)
+            registrationRecord1.expectCallback<ServiceUnregistered>()
+
+            registerService(registrationRecord3, si3)
+
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
+                    testNetwork1.network, Executor { it.run() }, discoveryRecord)
+            val discoveredInfo = discoveryRecord.waitForServiceDiscovered(
+                    serviceName, serviceType, testNetwork1.network)
+            val resolvedInfo = resolveService(discoveredInfo)
+
+            assertEquals(serviceName, discoveredInfo.serviceName)
+            assertEquals(TEST_PORT, resolvedInfo.port)
+            assertEquals(customHostname, resolvedInfo.hostname)
+            assertAddressEquals(
+                    listOf(parseNumericAddress("192.0.2.24"), parseNumericAddress("2001:db8::2")),
+                    resolvedInfo.hostAddresses)
+        } cleanupStep {
+            nsdManager.stopServiceDiscovery(discoveryRecord)
+            discoveryRecord.expectCallbackEventually<DiscoveryStopped>()
+        } cleanup {
+            nsdManager.unregisterService(registrationRecord2)
+            nsdManager.unregisterService(registrationRecord3)
+        }
+    }
+
+    @Test
     fun testServiceTypeClientRemovedAfterSocketDestroyed() {
         val si = makeTestServiceInfo(testNetwork1.network)
         // Register service on testNetwork1
diff --git a/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt b/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
index 035f607..d2e46af 100644
--- a/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
+++ b/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
@@ -243,18 +243,19 @@
             super.makeHandlerThread(tag).also { handlerThreads.add(it) }
 
         override fun makeCarrierPrivilegeAuthenticator(
-            context: Context,
-            tm: TelephonyManager,
-            requestRestrictedWifiEnabled: Boolean,
-            listener: BiConsumer<Int, Int>
+                context: Context,
+                tm: TelephonyManager,
+                requestRestrictedWifiEnabled: Boolean,
+                listener: BiConsumer<Int, Int>,
+                handler: Handler
         ): CarrierPrivilegeAuthenticator {
             return CarrierPrivilegeAuthenticator(context,
-                object : CarrierPrivilegeAuthenticator.Dependencies() {
-                    override fun makeHandlerThread(): HandlerThread =
-                        super.makeHandlerThread().also { handlerThreads.add(it) }
-                },
-                tm, TelephonyManagerShimImpl.newInstance(tm),
-                requestRestrictedWifiEnabled, listener)
+                    object : CarrierPrivilegeAuthenticator.Dependencies() {
+                        override fun makeHandlerThread(): HandlerThread =
+                                super.makeHandlerThread().also { handlerThreads.add(it) }
+                    },
+                    tm, TelephonyManagerShimImpl.newInstance(tm),
+                    requestRestrictedWifiEnabled, listener, handler)
         }
 
         override fun makeSatelliteAccessController(
diff --git a/tests/integration/src/com/android/server/net/integrationtests/NetworkStatsIntegrationTest.kt b/tests/integration/src/com/android/server/net/integrationtests/NetworkStatsIntegrationTest.kt
index 765e56e..52e502d 100644
--- a/tests/integration/src/com/android/server/net/integrationtests/NetworkStatsIntegrationTest.kt
+++ b/tests/integration/src/com/android/server/net/integrationtests/NetworkStatsIntegrationTest.kt
@@ -299,7 +299,8 @@
         val buf = ByteArray(DEFAULT_BUFFER_SIZE)
 
         httpServer.addResponse(
-            TestHttpServer.Request(path, NanoHTTPD.Method.POST), NanoHTTPD.Response.Status.OK,
+            TestHttpServer.Request(path, NanoHTTPD.Method.POST),
+            NanoHTTPD.Response.Status.OK,
             content = getRandomString(downloadSize)
         )
         var httpConnection: HttpURLConnection? = null
@@ -349,15 +350,19 @@
     ) {
         operator fun plus(other: BareStats): BareStats {
             return BareStats(
-                this.rxBytes + other.rxBytes, this.rxPackets + other.rxPackets,
-                this.txBytes + other.txBytes, this.txPackets + other.txPackets
+                this.rxBytes + other.rxBytes,
+                this.rxPackets + other.rxPackets,
+                this.txBytes + other.txBytes,
+                this.txPackets + other.txPackets
             )
         }
 
         operator fun minus(other: BareStats): BareStats {
             return BareStats(
-                this.rxBytes - other.rxBytes, this.rxPackets - other.rxPackets,
-                this.txBytes - other.txBytes, this.txPackets - other.txPackets
+                this.rxBytes - other.rxBytes,
+                this.rxPackets - other.rxPackets,
+                this.txBytes - other.txBytes,
+                this.txPackets - other.txPackets
             )
         }
 
@@ -405,8 +410,12 @@
         private fun getUidDetail(iface: String, tag: Int): BareStats {
             return getNetworkStatsThat(iface, tag) { nsm, template ->
                 nsm.queryDetailsForUidTagState(
-                    template, Long.MIN_VALUE, Long.MAX_VALUE,
-                    Process.myUid(), tag, Bucket.STATE_ALL
+                    template,
+                    Long.MIN_VALUE,
+                    Long.MAX_VALUE,
+                    Process.myUid(),
+                    tag,
+                    Bucket.STATE_ALL
                 )
             }
         }
@@ -498,28 +507,36 @@
         assertInRange(
             "Unexpected iface traffic stats",
             after.iface,
-            before.trafficStatsIface, after.trafficStatsIface,
-            lower, upper
+            before.trafficStatsIface,
+            after.trafficStatsIface,
+            lower,
+            upper
         )
         // Uid traffic stats are counted in both direction because the external network
         // traffic is also attributed to the test uid.
         assertInRange(
             "Unexpected uid traffic stats",
             after.iface,
-            before.trafficStatsUid, after.trafficStatsUid,
-            lower + lower.reverse(), upper + upper.reverse()
+            before.trafficStatsUid,
+            after.trafficStatsUid,
+            lower + lower.reverse(),
+            upper + upper.reverse()
         )
         assertInRange(
             "Unexpected non-tagged summary stats",
             after.iface,
-            before.statsSummary, after.statsSummary,
-            lower, upper
+            before.statsSummary,
+            after.statsSummary,
+            lower,
+            upper
         )
         assertInRange(
             "Unexpected non-tagged uid stats",
             after.iface,
-            before.statsUid, after.statsUid,
-            lower, upper
+            before.statsUid,
+            after.statsUid,
+            lower,
+            upper
         )
     }
 
@@ -546,14 +563,16 @@
         assertInRange(
             "Unexpected tagged summary stats",
             after.iface,
-            before.taggedSummary, after.taggedSummary,
+            before.taggedSummary,
+            after.taggedSummary,
             lower,
             upper
         )
         assertInRange(
             "Unexpected tagged uid stats: ${Process.myUid()}",
             after.iface,
-            before.taggedUid, after.taggedUid,
+            before.taggedUid,
+            after.taggedUid,
             lower,
             upper
         )
@@ -570,7 +589,8 @@
     ) {
         // Passing the value after operation and the value before operation to dump the actual
         // numbers if it fails.
-        assertTrue(checkInRange(before, after, lower, upper),
+        assertTrue(
+            checkInRange(before, after, lower, upper),
             "$tag on $iface: $after - $before is not within range [$lower, $upper]"
         )
     }
diff --git a/tests/unit/Android.bp b/tests/unit/Android.bp
index a5d2f4a..2f88c41 100644
--- a/tests/unit/Android.bp
+++ b/tests/unit/Android.bp
@@ -72,17 +72,6 @@
     ],
 }
 
-// Subset of services-core used to by ConnectivityService tests to test VPN realistically.
-// This is stripped by jarjar (see rules below) from other unrelated classes, so tests do not
-// include most classes from services-core, which are unrelated and cause wrong code coverage
-// calculations.
-java_library {
-    name: "services.core-vpn",
-    static_libs: ["services.core"],
-    jarjar_rules: "vpn-jarjar-rules.txt",
-    visibility: ["//visibility:private"],
-}
-
 java_defaults {
     name: "FrameworksNetTestsDefaults",
     min_sdk_version: "30",
@@ -109,7 +98,6 @@
         "platform-test-annotations",
         "service-connectivity-pre-jarjar",
         "service-connectivity-tiramisu-pre-jarjar",
-        "services.core-vpn",
         "testables",
         "cts-net-utils",
     ],
diff --git a/tests/unit/java/android/net/BpfNetMapsReaderTest.kt b/tests/unit/java/android/net/NetworkStackBpfNetMapsTest.kt
similarity index 83%
rename from tests/unit/java/android/net/BpfNetMapsReaderTest.kt
rename to tests/unit/java/android/net/NetworkStackBpfNetMapsTest.kt
index 8919666..a9ccbdd 100644
--- a/tests/unit/java/android/net/BpfNetMapsReaderTest.kt
+++ b/tests/unit/java/android/net/NetworkStackBpfNetMapsTest.kt
@@ -26,6 +26,7 @@
 import android.net.BpfNetMapsConstants.UID_RULES_CONFIGURATION_KEY
 import android.net.BpfNetMapsUtils.getMatchByFirewallChain
 import android.os.Build.VERSION_CODES
+import android.os.Process.FIRST_APPLICATION_UID
 import com.android.net.module.util.IBpfMap
 import com.android.net.module.util.Struct.S32
 import com.android.net.module.util.Struct.U32
@@ -42,7 +43,7 @@
 import org.junit.Test
 import org.junit.runner.RunWith
 
-private const val TEST_UID1 = 1234
+private const val TEST_UID1 = 11234
 private const val TEST_UID2 = TEST_UID1 + 1
 private const val TEST_UID3 = TEST_UID2 + 1
 private const val NO_IIF = 0
@@ -50,7 +51,7 @@
 // pre-T devices does not support Bpf.
 @RunWith(DevSdkIgnoreRunner::class)
 @IgnoreUpTo(VERSION_CODES.S_V2)
-class BpfNetMapsReaderTest {
+class NetworkStackBpfNetMapsTest {
     @Rule
     @JvmField
     val ignoreRule = DevSdkIgnoreRule()
@@ -58,14 +59,15 @@
     private val testConfigurationMap: IBpfMap<S32, U32> = TestBpfMap()
     private val testUidOwnerMap: IBpfMap<S32, UidOwnerValue> = TestBpfMap()
     private val testDataSaverEnabledMap: IBpfMap<S32, U8> = TestBpfMap()
-    private val bpfNetMapsReader = BpfNetMapsReader(
-        TestDependencies(testConfigurationMap, testUidOwnerMap, testDataSaverEnabledMap))
+    private val bpfNetMapsReader = NetworkStackBpfNetMaps(
+        TestDependencies(testConfigurationMap, testUidOwnerMap, testDataSaverEnabledMap)
+    )
 
     class TestDependencies(
         private val configMap: IBpfMap<S32, U32>,
         private val uidOwnerMap: IBpfMap<S32, UidOwnerValue>,
         private val dataSaverEnabledMap: IBpfMap<S32, U8>
-    ) : BpfNetMapsReader.Dependencies() {
+    ) : NetworkStackBpfNetMaps.Dependencies() {
         override fun getConfigurationMap() = configMap
         override fun getUidOwnerMap() = uidOwnerMap
         override fun getDataSaverEnabledMap() = dataSaverEnabledMap
@@ -99,11 +101,16 @@
             Modifier.isStatic(it.modifiers) && it.name.startsWith("FIREWALL_CHAIN_")
         }
         // Verify the size matches, this also verifies no common item in allow and deny chains.
-        assertEquals(BpfNetMapsConstants.ALLOW_CHAINS.size +
-                BpfNetMapsConstants.DENY_CHAINS.size, declaredChains.size)
+        assertEquals(
+            BpfNetMapsConstants.ALLOW_CHAINS.size +
+                BpfNetMapsConstants.DENY_CHAINS.size,
+            declaredChains.size
+        )
         declaredChains.forEach {
-            assertTrue(BpfNetMapsConstants.ALLOW_CHAINS.contains(it.get(null)) ||
-                    BpfNetMapsConstants.DENY_CHAINS.contains(it.get(null)))
+            assertTrue(
+                BpfNetMapsConstants.ALLOW_CHAINS.contains(it.get(null)) ||
+                    BpfNetMapsConstants.DENY_CHAINS.contains(it.get(null))
+            )
         }
     }
 
@@ -117,11 +124,17 @@
         testConfigurationMap.updateEntry(UID_RULES_CONFIGURATION_KEY, U32(newConfig))
     }
 
-    fun isUidNetworkingBlocked(uid: Int, metered: Boolean = false, dataSaver: Boolean = false) =
-            bpfNetMapsReader.isUidNetworkingBlocked(uid, metered, dataSaver)
+    private fun mockDataSaverEnabled(enabled: Boolean) {
+        val dataSaverValue = if (enabled) {DATA_SAVER_ENABLED} else {DATA_SAVER_DISABLED}
+        testDataSaverEnabledMap.updateEntry(DATA_SAVER_ENABLED_KEY, U8(dataSaverValue))
+    }
+
+    fun isUidNetworkingBlocked(uid: Int, metered: Boolean = false) =
+            bpfNetMapsReader.isUidNetworkingBlocked(uid, metered)
 
     @Test
     fun testIsUidNetworkingBlockedByFirewallChains_allowChain() {
+        mockDataSaverEnabled(enabled = false)
         // With everything disabled by default, verify the return value is false.
         testConfigurationMap.updateEntry(UID_RULES_CONFIGURATION_KEY, U32(0))
         assertFalse(isUidNetworkingBlocked(TEST_UID1))
@@ -141,6 +154,7 @@
 
     @Test
     fun testIsUidNetworkingBlockedByFirewallChains_denyChain() {
+        mockDataSaverEnabled(enabled = false)
         // Enable standby chain but does not provide denied list. Verify the network is allowed
         // for all uids.
         testConfigurationMap.updateEntry(UID_RULES_CONFIGURATION_KEY, U32(0))
@@ -162,12 +176,14 @@
         testConfigurationMap.updateEntry(UID_RULES_CONFIGURATION_KEY, U32(0))
         mockChainEnabled(ConnectivityManager.FIREWALL_CHAIN_POWERSAVE, true)
         mockChainEnabled(ConnectivityManager.FIREWALL_CHAIN_STANDBY, true)
+        mockDataSaverEnabled(enabled = false)
         assertTrue(isUidNetworkingBlocked(TEST_UID1))
     }
 
     @IgnoreUpTo(VERSION_CODES.S_V2)
     @Test
     fun testIsUidNetworkingBlockedByDataSaver() {
+        mockDataSaverEnabled(enabled = false)
         // With everything disabled by default, verify the return value is false.
         testConfigurationMap.updateEntry(UID_RULES_CONFIGURATION_KEY, U32(0))
         assertFalse(isUidNetworkingBlocked(TEST_UID1, metered = true))
@@ -180,10 +196,11 @@
 
         // Enable data saver, verify the network is blocked for uid1, uid2, but uid3 in happy box
         // is not affected.
+        mockDataSaverEnabled(enabled = true)
         testUidOwnerMap.updateEntry(S32(TEST_UID3), UidOwnerValue(NO_IIF, HAPPY_BOX_MATCH))
-        assertTrue(isUidNetworkingBlocked(TEST_UID1, metered = true, dataSaver = true))
-        assertTrue(isUidNetworkingBlocked(TEST_UID2, metered = true, dataSaver = true))
-        assertFalse(isUidNetworkingBlocked(TEST_UID3, metered = true, dataSaver = true))
+        assertTrue(isUidNetworkingBlocked(TEST_UID1, metered = true))
+        assertTrue(isUidNetworkingBlocked(TEST_UID2, metered = true))
+        assertFalse(isUidNetworkingBlocked(TEST_UID3, metered = true))
 
         // Add uid1 to happy box as well, verify nothing is changed because penalty box has higher
         // priority.
@@ -191,18 +208,19 @@
             S32(TEST_UID1),
             UidOwnerValue(NO_IIF, PENALTY_BOX_MATCH or HAPPY_BOX_MATCH)
         )
-        assertTrue(isUidNetworkingBlocked(TEST_UID1, metered = true, dataSaver = true))
-        assertTrue(isUidNetworkingBlocked(TEST_UID2, metered = true, dataSaver = true))
-        assertFalse(isUidNetworkingBlocked(TEST_UID3, metered = true, dataSaver = true))
+        assertTrue(isUidNetworkingBlocked(TEST_UID1, metered = true))
+        assertTrue(isUidNetworkingBlocked(TEST_UID2, metered = true))
+        assertFalse(isUidNetworkingBlocked(TEST_UID3, metered = true))
 
         // Enable doze mode, verify uid3 is blocked even if it is in happy box.
         mockChainEnabled(ConnectivityManager.FIREWALL_CHAIN_DOZABLE, true)
-        assertTrue(isUidNetworkingBlocked(TEST_UID1, metered = true, dataSaver = true))
-        assertTrue(isUidNetworkingBlocked(TEST_UID2, metered = true, dataSaver = true))
-        assertTrue(isUidNetworkingBlocked(TEST_UID3, metered = true, dataSaver = true))
+        assertTrue(isUidNetworkingBlocked(TEST_UID1, metered = true))
+        assertTrue(isUidNetworkingBlocked(TEST_UID2, metered = true))
+        assertTrue(isUidNetworkingBlocked(TEST_UID3, metered = true))
 
         // Disable doze mode and data saver, only uid1 which is in penalty box is blocked.
         mockChainEnabled(ConnectivityManager.FIREWALL_CHAIN_DOZABLE, false)
+        mockDataSaverEnabled(enabled = false)
         assertTrue(isUidNetworkingBlocked(TEST_UID1, metered = true))
         assertFalse(isUidNetworkingBlocked(TEST_UID2, metered = true))
         assertFalse(isUidNetworkingBlocked(TEST_UID3, metered = true))
@@ -214,6 +232,24 @@
     }
 
     @Test
+    fun testIsUidNetworkingBlocked_SystemUid() {
+        mockDataSaverEnabled(enabled = false)
+        testConfigurationMap.updateEntry(UID_RULES_CONFIGURATION_KEY, U32(0))
+        mockChainEnabled(ConnectivityManager.FIREWALL_CHAIN_DOZABLE, true)
+
+        for (uid in FIRST_APPLICATION_UID - 5..FIRST_APPLICATION_UID + 5) {
+            // system uid is not blocked regardless of firewall chains
+            val expectBlocked = uid >= FIRST_APPLICATION_UID
+            testUidOwnerMap.updateEntry(S32(uid), UidOwnerValue(NO_IIF, PENALTY_BOX_MATCH))
+            assertEquals(
+                expectBlocked,
+                    isUidNetworkingBlocked(uid, metered = true),
+                    "isUidNetworkingBlocked returns unexpected value for uid = " + uid
+            )
+        }
+    }
+
+    @Test
     fun testGetDataSaverEnabled() {
         testDataSaverEnabledMap.updateEntry(DATA_SAVER_ENABLED_KEY, U8(DATA_SAVER_DISABLED))
         assertFalse(bpfNetMapsReader.dataSaverEnabled)
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index f5eee42..17c5901 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -368,7 +368,6 @@
 import android.os.UserHandle;
 import android.os.UserManager;
 import android.provider.Settings;
-import android.security.Credentials;
 import android.system.Os;
 import android.telephony.SubscriptionManager;
 import android.telephony.TelephonyManager;
@@ -389,7 +388,6 @@
 import com.android.internal.annotations.GuardedBy;
 import com.android.internal.app.IBatteryStats;
 import com.android.internal.net.VpnConfig;
-import com.android.internal.net.VpnProfile;
 import com.android.internal.util.WakeupMessage;
 import com.android.internal.util.test.BroadcastInterceptingContext;
 import com.android.internal.util.test.FakeSettingsProvider;
@@ -424,7 +422,6 @@
 import com.android.server.connectivity.SatelliteAccessController;
 import com.android.server.connectivity.TcpKeepaliveController;
 import com.android.server.connectivity.UidRangeUtils;
-import com.android.server.connectivity.VpnProfileStore;
 import com.android.server.net.NetworkPinner;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRunner;
@@ -464,7 +461,6 @@
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.Socket;
-import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -631,7 +627,6 @@
     @Mock TelephonyManager mTelephonyManager;
     @Mock EthernetManager mEthernetManager;
     @Mock NetworkPolicyManager mNetworkPolicyManager;
-    @Mock VpnProfileStore mVpnProfileStore;
     @Mock SystemConfigManager mSystemConfigManager;
     @Mock DevicePolicyManager mDevicePolicyManager;
     @Mock Resources mResources;
@@ -1667,23 +1662,11 @@
             waitForIdle();
         }
 
-        public void startLegacyVpnPrivileged(VpnProfile profile) {
-            switch (profile.type) {
-                case VpnProfile.TYPE_IKEV2_IPSEC_RSA:
-                case VpnProfile.TYPE_IKEV2_IPSEC_USER_PASS:
-                case VpnProfile.TYPE_IKEV2_IPSEC_PSK:
-                case VpnProfile.TYPE_IKEV2_FROM_IKE_TUN_CONN_PARAMS:
-                    startPlatformVpn();
-                    break;
-                case VpnProfile.TYPE_L2TP_IPSEC_PSK:
-                case VpnProfile.TYPE_L2TP_IPSEC_RSA:
-                case VpnProfile.TYPE_IPSEC_XAUTH_PSK:
-                case VpnProfile.TYPE_IPSEC_XAUTH_RSA:
-                case VpnProfile.TYPE_IPSEC_HYBRID_RSA:
-                    startLegacyVpn();
-                    break;
-                default:
-                    fail("Unknown VPN profile type");
+        public void startLegacyVpnPrivileged(boolean isIkev2Vpn) {
+            if (isIkev2Vpn) {
+                startPlatformVpn();
+            } else {
+                startLegacyVpn();
             }
         }
 
@@ -1736,6 +1719,8 @@
     private void mockUidNetworkingBlocked() {
         doAnswer(i -> isUidBlocked(mBlockedReasons, i.getArgument(1))
         ).when(mNetworkPolicyManager).isUidNetworkingBlocked(anyInt(), anyBoolean());
+        doAnswer(i -> isUidBlocked(mBlockedReasons, i.getArgument(1))
+        ).when(mBpfNetMaps).isUidNetworkingBlocked(anyInt(), anyBoolean());
     }
 
     private boolean isUidBlocked(int blockedReasons, boolean meteredNetwork) {
@@ -2055,12 +2040,16 @@
             };
         }
 
+        private BiConsumer<Integer, Integer> mCarrierPrivilegesLostListener;
+
         @Override
         public CarrierPrivilegeAuthenticator makeCarrierPrivilegeAuthenticator(
                 @NonNull final Context context,
                 @NonNull final TelephonyManager tm,
                 final boolean requestRestrictedWifiEnabled,
-                BiConsumer<Integer, Integer> listener) {
+                BiConsumer<Integer, Integer> listener,
+                @NonNull final Handler handler) {
+            mCarrierPrivilegesLostListener = listener;
             return mDeps.isAtLeastT() ? mCarrierPrivilegeAuthenticator : null;
         }
 
@@ -10213,24 +10202,6 @@
         doAsUid(Process.SYSTEM_UID, () -> mCm.unregisterNetworkCallback(perUidCb));
     }
 
-    private VpnProfile setupLockdownVpn(int profileType) {
-        final String profileName = "testVpnProfile";
-        final byte[] profileTag = profileName.getBytes(StandardCharsets.UTF_8);
-        doReturn(profileTag).when(mVpnProfileStore).get(Credentials.LOCKDOWN_VPN);
-
-        final VpnProfile profile = new VpnProfile(profileName);
-        profile.name = "My VPN";
-        profile.server = "192.0.2.1";
-        profile.dnsServers = "8.8.8.8";
-        profile.ipsecIdentifier = "My ipsecIdentifier";
-        profile.ipsecSecret = "My PSK";
-        profile.type = profileType;
-        final byte[] encodedProfile = profile.encode();
-        doReturn(encodedProfile).when(mVpnProfileStore).get(Credentials.VPN + profileName);
-
-        return profile;
-    }
-
     private void establishLegacyLockdownVpn(Network underlying) throws Exception {
         // The legacy lockdown VPN only supports userId 0, and must have an underlying network.
         assertNotNull(underlying);
@@ -10242,7 +10213,7 @@
         mMockVpn.connect(true);
     }
 
-    private void doTestLockdownVpn(VpnProfile profile, boolean expectSetVpnDefaultForUids)
+    private void doTestLockdownVpn(boolean isIkev2Vpn)
             throws Exception {
         mServiceContext.setPermission(
                 Manifest.permission.CONTROL_VPN, PERMISSION_GRANTED);
@@ -10280,8 +10251,8 @@
         b.expectBroadcast();
         // Simulate LockdownVpnTracker attempting to start the VPN since it received the
         // systemDefault callback.
-        mMockVpn.startLegacyVpnPrivileged(profile);
-        if (expectSetVpnDefaultForUids) {
+        mMockVpn.startLegacyVpnPrivileged(isIkev2Vpn);
+        if (isIkev2Vpn) {
             // setVpnDefaultForUids() releases the original network request and creates a VPN
             // request so LOST callback is received.
             defaultCallback.expect(LOST, mCellAgent);
@@ -10305,7 +10276,7 @@
         final NetworkCapabilities vpnNc = mCm.getNetworkCapabilities(mMockVpn.getNetwork());
         b2.expectBroadcast();
         b3.expectBroadcast();
-        if (expectSetVpnDefaultForUids) {
+        if (isIkev2Vpn) {
             // Due to the VPN default request, getActiveNetworkInfo() gets the VPN network as the
             // network satisfier which has TYPE_VPN.
             assertActiveNetworkInfo(TYPE_VPN, DetailedState.CONNECTED);
@@ -10351,14 +10322,15 @@
         // callback with different network.
         final ExpectedBroadcast b6 = expectConnectivityAction(TYPE_VPN, DetailedState.DISCONNECTED);
         mMockVpn.stopVpnRunnerPrivileged();
-        mMockVpn.startLegacyVpnPrivileged(profile);
+
+        mMockVpn.startLegacyVpnPrivileged(isIkev2Vpn);
         // VPN network is disconnected (to restart)
         callback.expect(LOST, mMockVpn);
         defaultCallback.expect(LOST, mMockVpn);
         // The network preference is cleared when VPN is disconnected so it receives callbacks for
         // the system-wide default.
         defaultCallback.expectAvailableCallbacksUnvalidatedAndBlocked(mWiFiAgent);
-        if (expectSetVpnDefaultForUids) {
+        if (isIkev2Vpn) {
             // setVpnDefaultForUids() releases the original network request and creates a VPN
             // request so LOST callback is received.
             defaultCallback.expect(LOST, mWiFiAgent);
@@ -10367,7 +10339,7 @@
         b6.expectBroadcast();
 
         // While the VPN is reconnecting on the new network, everything is blocked.
-        if (expectSetVpnDefaultForUids) {
+        if (isIkev2Vpn) {
             // Due to the VPN default request, getActiveNetworkInfo() gets the mNoServiceNetwork
             // as the network satisfier.
             assertNull(mCm.getActiveNetworkInfo());
@@ -10388,7 +10360,7 @@
         systemDefaultCallback.assertNoCallback();
         b7.expectBroadcast();
         b8.expectBroadcast();
-        if (expectSetVpnDefaultForUids) {
+        if (isIkev2Vpn) {
             // Due to the VPN default request, getActiveNetworkInfo() gets the VPN network as the
             // network satisfier which has TYPE_VPN.
             assertActiveNetworkInfo(TYPE_VPN, DetailedState.CONNECTED);
@@ -10414,7 +10386,7 @@
         defaultCallback.assertNoCallback();
         systemDefaultCallback.assertNoCallback();
 
-        if (expectSetVpnDefaultForUids) {
+        if (isIkev2Vpn) {
             // Due to the VPN default request, getActiveNetworkInfo() gets the VPN network as the
             // network satisfier which has TYPE_VPN.
             assertActiveNetworkInfo(TYPE_VPN, DetailedState.CONNECTED);
@@ -10455,14 +10427,12 @@
 
     @Test
     public void testLockdownVpn_LegacyVpnRunner() throws Exception {
-        final VpnProfile profile = setupLockdownVpn(VpnProfile.TYPE_IPSEC_XAUTH_PSK);
-        doTestLockdownVpn(profile, false /* expectSetVpnDefaultForUids */);
+        doTestLockdownVpn(false /* isIkev2Vpn */);
     }
 
     @Test
     public void testLockdownVpn_Ikev2VpnRunner() throws Exception {
-        final VpnProfile profile = setupLockdownVpn(VpnProfile.TYPE_IKEV2_IPSEC_PSK);
-        doTestLockdownVpn(profile, true /* expectSetVpnDefaultForUids */);
+        doTestLockdownVpn(true /* isIkev2Vpn */);
     }
 
     @Test @IgnoreUpTo(Build.VERSION_CODES.S_V2)
@@ -17449,7 +17419,10 @@
                 .isCarrierServiceUidForNetworkCapabilities(eq(Process.myUid()), any());
         doReturn(TEST_SUBSCRIPTION_ID).when(mCarrierPrivilegeAuthenticator)
                 .getSubIdFromNetworkCapabilities(any());
-        mService.onCarrierPrivilegesLost(lostPrivilegeUid, lostPrivilegeSubId);
+
+        visibleOnHandlerThread(mCsHandlerThread.getThreadHandler(), () -> {
+            mDeps.mCarrierPrivilegesLostListener.accept(lostPrivilegeUid, lostPrivilegeSubId);
+        });
         waitForIdle();
 
         if (expectCapChanged) {
@@ -17463,11 +17436,12 @@
         }
 
         mWiFiAgent.disconnect();
-        waitForIdle();
 
         if (expectUnavailable) {
+            testFactory.expectRequestRemove();
             testFactory.assertRequestCountEquals(0);
         } else {
+            testFactory.expectRequestAdd();
             testFactory.assertRequestCountEquals(1);
         }
 
diff --git a/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java b/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java
index 7bd2b56..ab81abc 100644
--- a/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java
+++ b/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java
@@ -21,6 +21,7 @@
 import static android.telephony.TelephonyManager.ACTION_MULTI_SIM_CONFIG_CHANGED;
 
 import static com.android.server.connectivity.ConnectivityFlags.CARRIER_SERVICE_CHANGED_USE_CALLBACK;
+import static com.android.testutils.HandlerUtils.visibleOnHandlerThread;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -45,6 +46,7 @@
 import android.net.NetworkCapabilities;
 import android.net.TelephonyNetworkSpecifier;
 import android.os.Build;
+import android.os.Handler;
 import android.os.HandlerThread;
 import android.telephony.TelephonyManager;
 
@@ -56,6 +58,7 @@
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;
 import com.android.testutils.DevSdkIgnoreRunner;
+import com.android.testutils.HandlerUtils;
 
 import org.junit.After;
 import org.junit.Rule;
@@ -85,6 +88,7 @@
 
     private static final int SUBSCRIPTION_COUNT = 2;
     private static final int TEST_SUBSCRIPTION_ID = 1;
+    private static final int TIMEOUT_MS = 1_000;
 
     @NonNull private final Context mContext;
     @NonNull private final TelephonyManager mTelephonyManager;
@@ -97,13 +101,16 @@
     private final String mTestPkg = "com.android.server.connectivity.test";
     private final BroadcastReceiver mMultiSimBroadcastReceiver;
     @NonNull private final HandlerThread mHandlerThread;
+    @NonNull private final Handler mCsHandler;
+    @NonNull private final HandlerThread mCsHandlerThread;
 
     public class TestCarrierPrivilegeAuthenticator extends CarrierPrivilegeAuthenticator {
         TestCarrierPrivilegeAuthenticator(@NonNull final Context c,
                 @NonNull final Dependencies deps,
-                @NonNull final TelephonyManager t) {
+                @NonNull final TelephonyManager t,
+                @NonNull final Handler handler) {
             super(c, deps, t, mTelephonyManagerShim, true /* requestRestrictedWifiEnabled */,
-                    mListener);
+                    mListener, handler);
         }
         @Override
         protected int getSubId(int slotIndex) {
@@ -112,8 +119,11 @@
     }
 
     @After
-    public void tearDown() {
+    public void tearDown() throws Exception {
         mHandlerThread.quit();
+        mHandlerThread.join();
+        mCsHandlerThread.quit();
+        mCsHandlerThread.join();
     }
 
     /** Parameters to test both using callbacks or the old broadcast */
@@ -141,8 +151,14 @@
         final ApplicationInfo applicationInfo = new ApplicationInfo();
         applicationInfo.uid = mCarrierConfigPkgUid;
         doReturn(applicationInfo).when(mPackageManager).getApplicationInfo(eq(mTestPkg), anyInt());
-        mCarrierPrivilegeAuthenticator =
-                new TestCarrierPrivilegeAuthenticator(mContext, deps, mTelephonyManager);
+        mCsHandlerThread = new HandlerThread(
+                CarrierPrivilegeAuthenticatorTest.class.getSimpleName() + "-CsHandlerThread");
+        mCsHandlerThread.start();
+        mCsHandler = new Handler(mCsHandlerThread.getLooper());
+        mCarrierPrivilegeAuthenticator = new TestCarrierPrivilegeAuthenticator(mContext, deps,
+                mTelephonyManager, mCsHandler);
+        mCarrierPrivilegeAuthenticator.start();
+        HandlerUtils.waitForIdle(mCsHandlerThread, TIMEOUT_MS);
         final ArgumentCaptor<BroadcastReceiver> receiverCaptor =
                 ArgumentCaptor.forClass(BroadcastReceiver.class);
         verify(mContext).registerReceiver(receiverCaptor.capture(), argThat(filter ->
@@ -178,7 +194,9 @@
         assertNotNull(initialListeners.get(1));
         assertEquals(2, initialListeners.size());
 
-        initialListeners.get(0).onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        visibleOnHandlerThread(mCsHandler, () -> {
+            initialListeners.get(0).onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        });
 
         final NetworkCapabilities.Builder ncBuilder = new NetworkCapabilities.Builder()
                 .addTransportType(TRANSPORT_CELLULAR)
@@ -201,10 +219,10 @@
 
         doReturn(1).when(mTelephonyManager).getActiveModemCount();
 
-        // This is a little bit cavalier in that the call to onReceive is not on the handler
-        // thread that was specified in registerReceiver.
-        // TODO : capture the handler and call this on it if this causes flakiness.
-        mMultiSimBroadcastReceiver.onReceive(mContext, buildTestMultiSimConfigBroadcastIntent());
+        visibleOnHandlerThread(mCsHandler, () -> {
+            mMultiSimBroadcastReceiver.onReceive(mContext,
+                    buildTestMultiSimConfigBroadcastIntent());
+        });
         // Check all listeners have been removed
         for (CarrierPrivilegesListenerShim listener : initialListeners.values()) {
             verify(mTelephonyManagerShim).removeCarrierPrivilegesListener(eq(listener));
@@ -216,7 +234,9 @@
         assertNotNull(newListeners.get(0));
         assertEquals(1, newListeners.size());
 
-        newListeners.get(0).onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        visibleOnHandlerThread(mCsHandler, () -> {
+            newListeners.get(0).onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        });
 
         final TelephonyNetworkSpecifier specifier =
                 new TelephonyNetworkSpecifier(TEST_SUBSCRIPTION_ID);
@@ -235,12 +255,17 @@
     public void testCarrierPrivilegesLostDueToCarrierServiceUpdate() throws Exception {
         final CarrierPrivilegesListenerShim l = getCarrierPrivilegesListeners().get(0);
 
-        l.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
-        l.onCarrierServiceChanged(null, mCarrierConfigPkgUid + 1);
+        visibleOnHandlerThread(mCsHandler, () -> {
+            l.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+            l.onCarrierServiceChanged(null, mCarrierConfigPkgUid + 1);
+        });
         if (mUseCallbacks) {
             verify(mListener).accept(eq(mCarrierConfigPkgUid), eq(TEST_SUBSCRIPTION_ID));
         }
-        l.onCarrierServiceChanged(null, mCarrierConfigPkgUid + 2);
+
+        visibleOnHandlerThread(mCsHandler, () -> {
+            l.onCarrierServiceChanged(null, mCarrierConfigPkgUid + 2);
+        });
         if (mUseCallbacks) {
             verify(mListener).accept(eq(mCarrierConfigPkgUid + 1), eq(TEST_SUBSCRIPTION_ID));
         }
@@ -260,8 +285,10 @@
         final ApplicationInfo applicationInfo = new ApplicationInfo();
         applicationInfo.uid = mCarrierConfigPkgUid + 1;
         doReturn(applicationInfo).when(mPackageManager).getApplicationInfo(eq(mTestPkg), anyInt());
-        listener.onCarrierPrivilegesChanged(Collections.emptyList(), new int[] {});
-        listener.onCarrierServiceChanged(null, applicationInfo.uid);
+        visibleOnHandlerThread(mCsHandler, () -> {
+            listener.onCarrierPrivilegesChanged(Collections.emptyList(), new int[]{});
+            listener.onCarrierServiceChanged(null, applicationInfo.uid);
+        });
 
         assertFalse(mCarrierPrivilegeAuthenticator.isCarrierServiceUidForNetworkCapabilities(
                 mCarrierConfigPkgUid, nc));
@@ -272,7 +299,9 @@
     @Test
     public void testDefaultSubscription() throws Exception {
         final CarrierPrivilegesListenerShim listener = getCarrierPrivilegesListeners().get(0);
-        listener.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        visibleOnHandlerThread(mCsHandler, () -> {
+            listener.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        });
 
         final NetworkCapabilities.Builder ncBuilder = new NetworkCapabilities.Builder();
         ncBuilder.addTransportType(TRANSPORT_CELLULAR);
@@ -297,7 +326,9 @@
     @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
     public void testNetworkCapabilitiesContainOneSubId() throws Exception {
         final CarrierPrivilegesListenerShim listener = getCarrierPrivilegesListeners().get(0);
-        listener.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        visibleOnHandlerThread(mCsHandler, () -> {
+            listener.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        });
 
         final NetworkCapabilities.Builder ncBuilder = new NetworkCapabilities.Builder();
         ncBuilder.addTransportType(TRANSPORT_WIFI);
@@ -311,7 +342,9 @@
     @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
     public void testNetworkCapabilitiesContainTwoSubIds() throws Exception {
         final CarrierPrivilegesListenerShim listener = getCarrierPrivilegesListeners().get(0);
-        listener.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        visibleOnHandlerThread(mCsHandler, () -> {
+            listener.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        });
 
         final NetworkCapabilities.Builder ncBuilder = new NetworkCapabilities.Builder();
         ncBuilder.addTransportType(TRANSPORT_WIFI);
diff --git a/tests/unit/java/com/android/server/connectivity/ClatCoordinatorTest.java b/tests/unit/java/com/android/server/connectivity/ClatCoordinatorTest.java
index 88044be..da7fda3 100644
--- a/tests/unit/java/com/android/server/connectivity/ClatCoordinatorTest.java
+++ b/tests/unit/java/com/android/server/connectivity/ClatCoordinatorTest.java
@@ -526,13 +526,13 @@
                     + "v4: /192.0.0.46, v6: /2001:db8:0:b11::464, pfx96: /64:ff9b::, "
                     + "pid: 10483, cookie: 27149", dumpStrings[0].trim());
             assertEquals("Forwarding rules:", dumpStrings[1].trim());
-            assertEquals("BPF ingress map: iif nat64Prefix v6Addr -> v4Addr oif",
+            assertEquals("BPF ingress map: iif nat64Prefix v6Addr -> v4Addr oif (packets bytes)",
                     dumpStrings[2].trim());
-            assertEquals("1000 /64:ff9b::/96 /2001:db8:0:b11::464 -> /192.0.0.46 1001",
+            assertEquals("1000 /64:ff9b::/96 /2001:db8:0:b11::464 -> /192.0.0.46 1001 (0 0)",
                     dumpStrings[3].trim());
-            assertEquals("BPF egress map: iif v4Addr -> v6Addr nat64Prefix oif",
+            assertEquals("BPF egress map: iif v4Addr -> v6Addr nat64Prefix oif (packets bytes)",
                     dumpStrings[4].trim());
-            assertEquals("1001 /192.0.0.46 -> /2001:db8:0:b11::464 /64:ff9b::/96 1000 ether",
+            assertEquals("1001 /192.0.0.46 -> /2001:db8:0:b11::464 /64:ff9b::/96 1000 ether (0 0)",
                     dumpStrings[5].trim());
         } else {
             assertEquals(1, dumpStrings.length);
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index 69fec85..629ac67 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -18,7 +18,6 @@
 
 import android.net.InetAddresses.parseNumericAddress
 import android.net.LinkAddress
-import android.net.nsd.NsdManager
 import android.net.nsd.NsdServiceInfo
 import android.os.Build
 import android.os.HandlerThread
@@ -55,6 +54,7 @@
 import org.mockito.Mockito.never
 import org.mockito.Mockito.times
 import org.mockito.Mockito.verify
+import org.mockito.Mockito.inOrder
 
 private const val LOG_TAG = "testlogtag"
 private const val TIMEOUT_MS = 10_000L
@@ -65,6 +65,7 @@
 
 private const val TEST_SERVICE_ID_1 = 42
 private const val TEST_SERVICE_ID_DUPLICATE = 43
+private const val TEST_SERVICE_ID_2 = 44
 private val TEST_SERVICE_1 = NsdServiceInfo().apply {
     serviceType = "_testservice._tcp"
     serviceName = "MyTestService"
@@ -78,6 +79,13 @@
     port = 12345
 }
 
+private val TEST_SERVICE_1_CUSTOM_HOST = NsdServiceInfo().apply {
+    serviceType = "_testservice._tcp"
+    serviceName = "MyTestService"
+    hostname = "MyTestHost"
+    port = 12345
+}
+
 @RunWith(DevSdkIgnoreRunner::class)
 @IgnoreUpTo(Build.VERSION_CODES.S_V2)
 class MdnsInterfaceAdvertiserTest {
@@ -183,6 +191,93 @@
     }
 
     @Test
+    fun testAddRemoveServiceWithCustomHost_restartProbingForProbingServices() {
+        val customHost1 = NsdServiceInfo().apply {
+            hostname = "MyTestHost"
+            hostAddresses = listOf(
+                    parseNumericAddress("192.0.2.23"),
+                    parseNumericAddress("2001:db8::1"))
+        }
+        addServiceAndFinishProbing(TEST_SERVICE_ID_1, customHost1)
+        addServiceAndFinishProbing(TEST_SERVICE_ID_2, TEST_SERVICE_1_CUSTOM_HOST)
+        repository.setServiceProbing(TEST_SERVICE_ID_2)
+        val probingInfo = mock(ProbingInfo::class.java)
+        doReturn("MyTestHost")
+                .`when`(repository).getHostnameForServiceId(TEST_SERVICE_ID_1)
+        doReturn(TEST_SERVICE_ID_2).`when`(probingInfo).serviceId
+        doReturn(listOf(probingInfo))
+                .`when`(repository).restartProbingForHostname("MyTestHost")
+        val inOrder = inOrder(prober, announcer)
+
+        // Remove the custom host: the custom host's announcement is stopped and the probing
+        // services which use that hostname are re-probed.
+        advertiser.removeService(TEST_SERVICE_ID_1)
+
+        inOrder.verify(prober).stop(TEST_SERVICE_ID_1)
+        inOrder.verify(announcer).stop(TEST_SERVICE_ID_1)
+        inOrder.verify(prober).stop(TEST_SERVICE_ID_2)
+        inOrder.verify(prober).startProbing(probingInfo)
+    }
+
+    @Test
+    fun testAddRemoveServiceWithCustomHost_restartAnnouncingForProbedServices() {
+        val customHost1 = NsdServiceInfo().apply {
+            hostname = "MyTestHost"
+            hostAddresses = listOf(
+                    parseNumericAddress("192.0.2.23"),
+                    parseNumericAddress("2001:db8::1"))
+        }
+        addServiceAndFinishProbing(TEST_SERVICE_ID_1, customHost1)
+        val announcementInfo =
+                addServiceAndFinishProbing(TEST_SERVICE_ID_2, TEST_SERVICE_1_CUSTOM_HOST)
+        doReturn("MyTestHost")
+                .`when`(repository).getHostnameForServiceId(TEST_SERVICE_ID_1)
+        doReturn(listOf(announcementInfo))
+                .`when`(repository).restartAnnouncingForHostname("MyTestHost")
+        val inOrder = inOrder(prober, announcer)
+
+        // Remove the custom host: the custom host's announcement is stopped and the probed services
+        // which use that hostname are re-announced.
+        advertiser.removeService(TEST_SERVICE_ID_1)
+
+        inOrder.verify(prober).stop(TEST_SERVICE_ID_1)
+        inOrder.verify(announcer).stop(TEST_SERVICE_ID_1)
+        inOrder.verify(announcer).stop(TEST_SERVICE_ID_2)
+        inOrder.verify(announcer).startSending(TEST_SERVICE_ID_2, announcementInfo, 0L /* initialDelayMs */)
+    }
+
+    @Test
+    fun testAddMoreAddressesForCustomHost_restartAnnouncingForProbedServices() {
+        val customHost = NsdServiceInfo().apply {
+            hostname = "MyTestHost"
+            hostAddresses = listOf(
+                parseNumericAddress("192.0.2.23"),
+                parseNumericAddress("2001:db8::1"))
+        }
+        doReturn("MyTestHost")
+            .`when`(repository).getHostnameForServiceId(TEST_SERVICE_ID_1)
+        doReturn("MyTestHost")
+            .`when`(repository).getHostnameForServiceId(TEST_SERVICE_ID_2)
+        val announcementInfo1 =
+            addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1_CUSTOM_HOST)
+
+        val probingInfo2 = addServiceAndStartProbing(TEST_SERVICE_ID_2, customHost)
+        val announcementInfo2 = AnnouncementInfo(TEST_SERVICE_ID_2, emptyList(), emptyList())
+        doReturn(announcementInfo2).`when`(repository).onProbingSucceeded(probingInfo2)
+        doReturn(listOf(announcementInfo1, announcementInfo2))
+            .`when`(repository).restartAnnouncingForHostname("MyTestHost")
+        probeCb.onFinished(probingInfo2)
+
+        val inOrder = inOrder(prober, announcer)
+
+        inOrder.verify(announcer)
+            .startSending(TEST_SERVICE_ID_2, announcementInfo2, 0L /* initialDelayMs */)
+        inOrder.verify(announcer).stop(TEST_SERVICE_ID_1)
+        inOrder.verify(announcer)
+            .startSending(TEST_SERVICE_ID_1, announcementInfo1, 0L /* initialDelayMs */)
+    }
+
+    @Test
     fun testDoubleRemove() {
         addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
 
@@ -422,8 +517,8 @@
         verify(prober, never()).startProbing(any())
     }
 
-    private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
-            AnnouncementInfo {
+    private fun addServiceAndStartProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
+            ProbingInfo {
         val testProbingInfo = mock(ProbingInfo::class.java)
         doReturn(serviceId).`when`(testProbingInfo).serviceId
         doReturn(testProbingInfo).`when`(repository).setServiceProbing(serviceId)
@@ -432,8 +527,15 @@
         verify(repository).addService(serviceId, serviceInfo, null /* ttl */)
         verify(prober).startProbing(testProbingInfo)
 
+        return testProbingInfo
+    }
+
+    private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
+            AnnouncementInfo {
+        val testProbingInfo = addServiceAndStartProbing(serviceId, serviceInfo)
+
         // Simulate probing success: continues to announcing
-        val testAnnouncementInfo = mock(AnnouncementInfo::class.java)
+        val testAnnouncementInfo = AnnouncementInfo(serviceId, emptyList(), emptyList())
         doReturn(testAnnouncementInfo).`when`(repository).onProbingSucceeded(testProbingInfo)
         probeCb.onFinished(testProbingInfo)
         return testAnnouncementInfo
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
index 9474464..fb3d183 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
@@ -23,6 +23,7 @@
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.timeout;
@@ -47,6 +48,7 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
+import org.mockito.InOrder;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
@@ -55,6 +57,7 @@
 import java.net.DatagramPacket;
 import java.net.NetworkInterface;
 import java.net.SocketException;
+import java.util.ArrayList;
 import java.util.List;
 
 @RunWith(DevSdkIgnoreRunner.class)
@@ -154,7 +157,7 @@
         verify(mSocketCreationCallback).onSocketCreated(tetherSocketKey2);
 
         // Send packet to IPv4 with mSocketKey and verify sending has been called.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
+        mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet), mSocketKey,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket).send(ipv4Packet);
@@ -162,7 +165,7 @@
         verify(tetherIfaceSock2, never()).send(any());
 
         // Send packet to IPv4 with onlyUseIpv6OnIpv6OnlyNetworks = true, the packet will be sent.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
+        mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet), mSocketKey,
                 true /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket, times(2)).send(ipv4Packet);
@@ -170,7 +173,7 @@
         verify(tetherIfaceSock2, never()).send(any());
 
         // Send packet to IPv6 with tetherSocketKey1 and verify sending has been called.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv6Packet, tetherSocketKey1,
+        mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv6Packet), tetherSocketKey1,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket, never()).send(ipv6Packet);
@@ -180,7 +183,7 @@
         // Send packet to IPv6 with onlyUseIpv6OnIpv6OnlyNetworks = true, the packet will not be
         // sent. Therefore, the tetherIfaceSock1.send() and tetherIfaceSock2.send() are still be
         // called once.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv6Packet, tetherSocketKey1,
+        mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv6Packet), tetherSocketKey1,
                 true /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket, never()).send(ipv6Packet);
@@ -266,7 +269,7 @@
         verify(mSocketCreationCallback).onSocketCreated(socketKey3);
 
         // Send IPv4 packet on the mSocketKey and verify sending has been called.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
+        mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet), mSocketKey,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket).send(ipv4Packet);
@@ -295,7 +298,7 @@
         verify(socketCreationCb2).onSocketCreated(socketKey3);
 
         // Send IPv4 packet on socket2 and verify sending to the socket2 only.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, socketKey2,
+        mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet), socketKey2,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         // ipv4Packet still sent only once on mSocket: times(1) matches the packet sent earlier on
@@ -309,7 +312,7 @@
         verify(mProvider, timeout(DEFAULT_TIMEOUT)).unrequestSocket(callback2);
 
         // Send IPv4 packet again and verify it's still sent a second time
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, socketKey2,
+        mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet), socketKey2,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(socket2, times(2)).send(ipv4Packet);
@@ -320,7 +323,7 @@
         verify(mProvider, timeout(DEFAULT_TIMEOUT)).unrequestSocket(callback);
 
         // Send IPv4 packet and verify no more sending.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
+        mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet), mSocketKey,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket, times(1)).send(ipv4Packet);
@@ -407,4 +410,31 @@
         verify(creationCallback3).onSocketDestroyed(mSocketKey);
         verify(creationCallback3, never()).onSocketDestroyed(socketKey2);
     }
+
+    @Test
+    public void testSendPacketWithMultipleDatagramPacket() throws IOException {
+        final SocketCallback socketCallback = expectSocketCallback();
+        final List<DatagramPacket> datagrams = new ArrayList<>();
+        for (int size = 10; size < 20; size++) {
+            datagrams.add(new DatagramPacket(new byte[size] /* buff */, 0 /* offset */,
+                    size /* length */, MdnsConstants.IPV4_SOCKET_ADDR));
+        }
+        doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
+        doReturn(true).when(mSocket).hasJoinedIpv4();
+        doReturn(true).when(mSocket).hasJoinedIpv6();
+
+        // Report the socket as created for mSocketKey.
+        socketCallback.onSocketCreated(mSocketKey, mSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
+
+        // Hand the whole list of IPv4 datagrams to the client in a single call, then check
+        // that the socket was asked to send each of them, in list order.
+        mSocketClient.sendPacketRequestingMulticastResponse(datagrams, mSocketKey,
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        final InOrder sendOrder = inOrder(mSocket);
+        for (final DatagramPacket datagram : datagrams) {
+            sendOrder.verify(mSocket).send(datagram);
+        }
+    }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index c69b1e1..271cc65 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -24,6 +24,7 @@
 import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
 import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_HOST
 import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_SERVICE
+import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_A
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_AAAA
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_PTR
@@ -51,6 +52,10 @@
 import org.junit.Before
 import org.junit.Test
 import org.junit.runner.RunWith
+import org.mockito.ArgumentCaptor
+import org.mockito.ArgumentMatchers.eq
+import org.mockito.Mockito.mock
+import org.mockito.Mockito.verify
 
 private const val TEST_SERVICE_ID_1 = 42
 private const val TEST_SERVICE_ID_2 = 43
@@ -112,6 +117,14 @@
     port = TEST_PORT
 }
 
+private val TEST_SERVICE_CUSTOM_HOST_NO_ADDRESSES = NsdServiceInfo().also { info ->
+    info.hostname = "TestHost"
+    info.hostAddresses = emptyList() // the host is referenced by name only, without addresses
+    info.serviceType = "_testservice._tcp"
+    info.serviceName = "TestService"
+    info.port = TEST_PORT
+}
+
 @RunWith(DevSdkIgnoreRunner::class)
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
 class MdnsRecordRepositoryTest {
@@ -1676,6 +1689,127 @@
         assertEquals(0, reply.additionalAnswers.size)
         assertEquals(knownAnswers, reply.knownAnswers)
     }
+
+    @Test
+    fun testRestartProbingForHostname() {
+        // A service naming "TestHost" is probing when TEST_CUSTOM_HOST_ID_1 is removed.
+        val repo = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags()).apply {
+            initWithService(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1,
+                    setOf(TEST_SUBTYPE, TEST_SUBTYPE2))
+            addService(TEST_SERVICE_CUSTOM_HOST_ID_1, TEST_SERVICE_CUSTOM_HOST_NO_ADDRESSES, null)
+            setServiceProbing(TEST_SERVICE_CUSTOM_HOST_ID_1)
+            removeService(TEST_CUSTOM_HOST_ID_1)
+        }
+        val restarted = repo.restartProbingForHostname("TestHost")
+
+        assertEquals(1, restarted.size)
+        val info = restarted[0]
+        assertEquals(TEST_SERVICE_CUSTOM_HOST_ID_1, info.serviceId)
+        val probe = info.getPacket(0)
+        assertEquals(0, probe.transactionId)
+        assertEquals(MdnsConstants.FLAGS_QUERY, probe.flags)
+        assertEquals(0, probe.answers.size)
+        assertEquals(0, probe.additionalRecords.size)
+        assertEquals(1, probe.questions.size)
+        val instanceLabels = arrayOf("TestService", "_testservice", "_tcp", "local")
+        assertEquals(MdnsAnyRecord(instanceLabels, false /* unicast */), probe.questions[0])
+        val expectedSrv = MdnsServiceRecord(
+                instanceLabels,
+                0L /* receiptTimeMillis */,
+                false /* cacheFlush */,
+                SHORT_TTL /* ttlMillis */,
+                0 /* servicePriority */,
+                0 /* serviceWeight */,
+                TEST_PORT,
+                TEST_CUSTOM_HOST_1_NAME)
+        assertThat(probe.authorityRecords).containsExactly(expectedSrv)
+    }
+
+    @Test
+    fun testRestartAnnouncingForHostname() {
+        // A service naming "TestHost" finishes probing, then TEST_CUSTOM_HOST_ID_1 is removed.
+        val repo = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        repo.initWithService(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1,
+                setOf(TEST_SUBTYPE, TEST_SUBTYPE2))
+        repo.addServiceAndFinishProbing(TEST_SERVICE_CUSTOM_HOST_ID_1,
+                TEST_SERVICE_CUSTOM_HOST_NO_ADDRESSES)
+        repo.removeService(TEST_CUSTOM_HOST_ID_1)
+
+        val restarted = repo.restartAnnouncingForHostname("TestHost")
+        assertEquals(1, restarted.size)
+        val info = restarted[0]
+        assertEquals(TEST_SERVICE_CUSTOM_HOST_ID_1, info.serviceId)
+        val announcement = info.getPacket(0)
+        assertEquals(0, announcement.transactionId)
+        assertEquals(0x8400 /* response, authoritative */, announcement.flags)
+        assertEquals(0, announcement.questions.size)
+        assertEquals(0, announcement.authorityRecords.size)
+        val instanceLabels = arrayOf("TestService", "_testservice", "_tcp", "local")
+        val typeLabels = arrayOf("_testservice", "_tcp", "local")
+        val v4Reverse = getReverseDnsAddress(TEST_ADDRESSES[0].address)
+        val v6FirstReverse = getReverseDnsAddress(TEST_ADDRESSES[1].address)
+        val v6SecondReverse = getReverseDnsAddress(TEST_ADDRESSES[2].address)
+        assertThat(announcement.answers).containsExactly(
+                MdnsPointerRecord(
+                        typeLabels,
+                        0L /* receiptTimeMillis */,
+                        // The service type name is not unique to this announcer: no cache flush
+                        false /* cacheFlush */,
+                        4500000L /* ttlMillis */,
+                        instanceLabels),
+                MdnsServiceRecord(
+                        instanceLabels,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        0 /* servicePriority */,
+                        0 /* serviceWeight */,
+                        TEST_PORT /* servicePort */,
+                        TEST_CUSTOM_HOST_1_NAME),
+                MdnsTextRecord(
+                        instanceLabels,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        4500000L /* ttlMillis */,
+                        emptyList() /* entries */),
+                MdnsPointerRecord(
+                        arrayOf("_services", "_dns-sd", "_udp", "local"),
+                        0L /* receiptTimeMillis */,
+                        false /* cacheFlush */,
+                        4500000L /* ttlMillis */,
+                        typeLabels))
+        assertThat(announcement.additionalRecords).containsExactly(
+                MdnsNsecRecord(v4Reverse,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        v4Reverse,
+                        intArrayOf(TYPE_PTR)),
+                MdnsNsecRecord(TEST_HOSTNAME,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        TEST_HOSTNAME,
+                        intArrayOf(TYPE_A, TYPE_AAAA)),
+                MdnsNsecRecord(v6FirstReverse,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        v6FirstReverse,
+                        intArrayOf(TYPE_PTR)),
+                MdnsNsecRecord(v6SecondReverse,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        v6SecondReverse,
+                        intArrayOf(TYPE_PTR)),
+                MdnsNsecRecord(instanceLabels,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        4500000L /* ttlMillis */,
+                        instanceLabels,
+                        intArrayOf(TYPE_TXT, TYPE_SRV)))
+    }
 }
 
 private fun MdnsRecordRepository.initWithService(
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index 09236b1..44fa55c 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -38,6 +38,7 @@
 import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.never;
@@ -117,8 +118,6 @@
     @Mock
     private MdnsServiceBrowserListener mockListenerTwo;
     @Mock
-    private MdnsPacketWriter mockPacketWriter;
-    @Mock
     private MdnsMultinetworkSocketClient mockSocketClient;
     @Mock
     private Network mockNetwork;
@@ -145,6 +144,7 @@
     private long latestDelayMs = 0;
     private Message delayMessage = null;
     private Handler realHandler = null;
+    private MdnsFeatureFlags featureFlags = MdnsFeatureFlags.newBuilder().build();
 
     @Before
     @SuppressWarnings("DoNotMock")
@@ -162,57 +162,59 @@
             expectedIPv6Packets[i] = new DatagramPacket(buf, 0 /* offset */, 5 /* length */,
                     MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
         }
-        when(mockPacketWriter.getPacket(IPV4_ADDRESS))
-                .thenReturn(expectedIPv4Packets[0])
-                .thenReturn(expectedIPv4Packets[1])
-                .thenReturn(expectedIPv4Packets[2])
-                .thenReturn(expectedIPv4Packets[3])
-                .thenReturn(expectedIPv4Packets[4])
-                .thenReturn(expectedIPv4Packets[5])
-                .thenReturn(expectedIPv4Packets[6])
-                .thenReturn(expectedIPv4Packets[7])
-                .thenReturn(expectedIPv4Packets[8])
-                .thenReturn(expectedIPv4Packets[9])
-                .thenReturn(expectedIPv4Packets[10])
-                .thenReturn(expectedIPv4Packets[11])
-                .thenReturn(expectedIPv4Packets[12])
-                .thenReturn(expectedIPv4Packets[13])
-                .thenReturn(expectedIPv4Packets[14])
-                .thenReturn(expectedIPv4Packets[15])
-                .thenReturn(expectedIPv4Packets[16])
-                .thenReturn(expectedIPv4Packets[17])
-                .thenReturn(expectedIPv4Packets[18])
-                .thenReturn(expectedIPv4Packets[19])
-                .thenReturn(expectedIPv4Packets[20])
-                .thenReturn(expectedIPv4Packets[21])
-                .thenReturn(expectedIPv4Packets[22])
-                .thenReturn(expectedIPv4Packets[23]);
+        when(mockDeps.getDatagramPacketsFromMdnsPacket(
+                any(), any(MdnsPacket.class), eq(IPV4_ADDRESS), anyBoolean()))
+                .thenReturn(List.of(expectedIPv4Packets[0]))
+                .thenReturn(List.of(expectedIPv4Packets[1]))
+                .thenReturn(List.of(expectedIPv4Packets[2]))
+                .thenReturn(List.of(expectedIPv4Packets[3]))
+                .thenReturn(List.of(expectedIPv4Packets[4]))
+                .thenReturn(List.of(expectedIPv4Packets[5]))
+                .thenReturn(List.of(expectedIPv4Packets[6]))
+                .thenReturn(List.of(expectedIPv4Packets[7]))
+                .thenReturn(List.of(expectedIPv4Packets[8]))
+                .thenReturn(List.of(expectedIPv4Packets[9]))
+                .thenReturn(List.of(expectedIPv4Packets[10]))
+                .thenReturn(List.of(expectedIPv4Packets[11]))
+                .thenReturn(List.of(expectedIPv4Packets[12]))
+                .thenReturn(List.of(expectedIPv4Packets[13]))
+                .thenReturn(List.of(expectedIPv4Packets[14]))
+                .thenReturn(List.of(expectedIPv4Packets[15]))
+                .thenReturn(List.of(expectedIPv4Packets[16]))
+                .thenReturn(List.of(expectedIPv4Packets[17]))
+                .thenReturn(List.of(expectedIPv4Packets[18]))
+                .thenReturn(List.of(expectedIPv4Packets[19]))
+                .thenReturn(List.of(expectedIPv4Packets[20]))
+                .thenReturn(List.of(expectedIPv4Packets[21]))
+                .thenReturn(List.of(expectedIPv4Packets[22]))
+                .thenReturn(List.of(expectedIPv4Packets[23]));
 
-        when(mockPacketWriter.getPacket(IPV6_ADDRESS))
-                .thenReturn(expectedIPv6Packets[0])
-                .thenReturn(expectedIPv6Packets[1])
-                .thenReturn(expectedIPv6Packets[2])
-                .thenReturn(expectedIPv6Packets[3])
-                .thenReturn(expectedIPv6Packets[4])
-                .thenReturn(expectedIPv6Packets[5])
-                .thenReturn(expectedIPv6Packets[6])
-                .thenReturn(expectedIPv6Packets[7])
-                .thenReturn(expectedIPv6Packets[8])
-                .thenReturn(expectedIPv6Packets[9])
-                .thenReturn(expectedIPv6Packets[10])
-                .thenReturn(expectedIPv6Packets[11])
-                .thenReturn(expectedIPv6Packets[12])
-                .thenReturn(expectedIPv6Packets[13])
-                .thenReturn(expectedIPv6Packets[14])
-                .thenReturn(expectedIPv6Packets[15])
-                .thenReturn(expectedIPv6Packets[16])
-                .thenReturn(expectedIPv6Packets[17])
-                .thenReturn(expectedIPv6Packets[18])
-                .thenReturn(expectedIPv6Packets[19])
-                .thenReturn(expectedIPv6Packets[20])
-                .thenReturn(expectedIPv6Packets[21])
-                .thenReturn(expectedIPv6Packets[22])
-                .thenReturn(expectedIPv6Packets[23]);
+        when(mockDeps.getDatagramPacketsFromMdnsPacket(
+                any(), any(MdnsPacket.class), eq(IPV6_ADDRESS), anyBoolean()))
+                .thenReturn(List.of(expectedIPv6Packets[0]))
+                .thenReturn(List.of(expectedIPv6Packets[1]))
+                .thenReturn(List.of(expectedIPv6Packets[2]))
+                .thenReturn(List.of(expectedIPv6Packets[3]))
+                .thenReturn(List.of(expectedIPv6Packets[4]))
+                .thenReturn(List.of(expectedIPv6Packets[5]))
+                .thenReturn(List.of(expectedIPv6Packets[6]))
+                .thenReturn(List.of(expectedIPv6Packets[7]))
+                .thenReturn(List.of(expectedIPv6Packets[8]))
+                .thenReturn(List.of(expectedIPv6Packets[9]))
+                .thenReturn(List.of(expectedIPv6Packets[10]))
+                .thenReturn(List.of(expectedIPv6Packets[11]))
+                .thenReturn(List.of(expectedIPv6Packets[12]))
+                .thenReturn(List.of(expectedIPv6Packets[13]))
+                .thenReturn(List.of(expectedIPv6Packets[14]))
+                .thenReturn(List.of(expectedIPv6Packets[15]))
+                .thenReturn(List.of(expectedIPv6Packets[16]))
+                .thenReturn(List.of(expectedIPv6Packets[17]))
+                .thenReturn(List.of(expectedIPv6Packets[18]))
+                .thenReturn(List.of(expectedIPv6Packets[19]))
+                .thenReturn(List.of(expectedIPv6Packets[20]))
+                .thenReturn(List.of(expectedIPv6Packets[21]))
+                .thenReturn(List.of(expectedIPv6Packets[22]))
+                .thenReturn(List.of(expectedIPv6Packets[23]));
 
         thread = new HandlerThread("MdnsServiceTypeClientTests");
         thread.start();
@@ -242,22 +244,13 @@
             return true;
         }).when(mockDeps).sendMessage(any(Handler.class), any(Message.class));
 
-        client = makeMdnsServiceTypeClient(mockPacketWriter);
+        client = makeMdnsServiceTypeClient();
     }
 
-    private MdnsServiceTypeClient makeMdnsServiceTypeClient(
-            @Nullable MdnsPacketWriter packetWriter) {
+    private MdnsServiceTypeClient makeMdnsServiceTypeClient() {
         return new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
                 mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps,
-                serviceCache) {
-            @Override
-            MdnsPacketWriter createMdnsPacketWriter() {
-                if (packetWriter == null) {
-                    return super.createMdnsPacketWriter();
-                }
-                return packetWriter;
-            }
-        };
+                serviceCache, featureFlags);
     }
 
     @After
@@ -697,27 +690,27 @@
 
     @Test
     public void testCombinedSubtypesQueriedWithMultipleListeners() throws Exception {
-        client = makeMdnsServiceTypeClient(/* packetWriter= */ null);
         final MdnsSearchOptions searchOptions1 = MdnsSearchOptions.newBuilder()
                 .addSubtype("subtype1").build();
         final MdnsSearchOptions searchOptions2 = MdnsSearchOptions.newBuilder()
                 .addSubtype("subtype2").build();
+        doCallRealMethod().when(mockDeps).getDatagramPacketsFromMdnsPacket(
+                any(), any(MdnsPacket.class), any(InetSocketAddress.class), anyBoolean());
         startSendAndReceive(mockListenerOne, searchOptions1);
-        currentThreadExecutor.getAndClearSubmittedRunnable().run();
+        currentThreadExecutor.getAndClearLastScheduledRunnable().run();
 
         InOrder inOrder = inOrder(mockListenerOne, mockSocketClient, mockDeps);
 
         // Verify the query asks for subtype1
-        final ArgumentCaptor<DatagramPacket> subtype1QueryCaptor =
-                ArgumentCaptor.forClass(DatagramPacket.class);
-        currentThreadExecutor.getAndClearLastScheduledRunnable().run();
+        final ArgumentCaptor<List<DatagramPacket>> subtype1QueryCaptor =
+                ArgumentCaptor.forClass(List.class);
         // Send twice for IPv4 and IPv6
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
                 subtype1QueryCaptor.capture(),
                 eq(socketKey), eq(false));
 
         final MdnsPacket subtype1Query = MdnsPacket.parse(
-                new MdnsPacketReader(subtype1QueryCaptor.getValue()));
+                new MdnsPacketReader(subtype1QueryCaptor.getValue().get(0)));
 
         assertEquals(2, subtype1Query.questions.size());
         assertTrue(hasQuestion(subtype1Query, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
@@ -729,8 +722,8 @@
         inOrder.verify(mockDeps).removeMessages(any(), eq(EVENT_START_QUERYTASK));
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
 
-        final ArgumentCaptor<DatagramPacket> combinedSubtypesQueryCaptor =
-                ArgumentCaptor.forClass(DatagramPacket.class);
+        final ArgumentCaptor<List<DatagramPacket>> combinedSubtypesQueryCaptor =
+                ArgumentCaptor.forClass(List.class);
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
                 combinedSubtypesQueryCaptor.capture(),
                 eq(socketKey), eq(false));
@@ -738,7 +731,7 @@
         inOrder.verify(mockDeps).sendMessageDelayed(any(), any(), anyLong());
 
         final MdnsPacket combinedSubtypesQuery = MdnsPacket.parse(
-                new MdnsPacketReader(combinedSubtypesQueryCaptor.getValue()));
+                new MdnsPacketReader(combinedSubtypesQueryCaptor.getValue().get(0)));
 
         assertEquals(3, combinedSubtypesQuery.questions.size());
         assertTrue(hasQuestion(combinedSubtypesQuery, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
@@ -754,15 +747,15 @@
         dispatchMessage();
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
 
-        final ArgumentCaptor<DatagramPacket> subtype2QueryCaptor =
-                ArgumentCaptor.forClass(DatagramPacket.class);
+        final ArgumentCaptor<List<DatagramPacket>> subtype2QueryCaptor =
+                ArgumentCaptor.forClass(List.class);
         // Send twice for IPv4 and IPv6
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
                 subtype2QueryCaptor.capture(),
                 eq(socketKey), eq(false));
 
         final MdnsPacket subtype2Query = MdnsPacket.parse(
-                new MdnsPacketReader(subtype2QueryCaptor.getValue()));
+                new MdnsPacketReader(subtype2QueryCaptor.getValue().get(0)));
 
         assertEquals(2, subtype2Query.questions.size());
         assertTrue(hasQuestion(subtype2Query, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
@@ -1022,7 +1015,6 @@
     public void processResponse_searchOptionsEnableServiceRemoval_shouldRemove()
             throws Exception {
         final String serviceInstanceName = "service-instance-1";
-        client = makeMdnsServiceTypeClient(mockPacketWriter);
         MdnsSearchOptions searchOptions = MdnsSearchOptions.newBuilder()
                 .setRemoveExpiredService(true)
                 .setNumOfQueriesBeforeBackoff(Integer.MAX_VALUE)
@@ -1060,7 +1052,6 @@
     public void processResponse_searchOptionsNotEnableServiceRemoval_shouldNotRemove()
             throws Exception {
         final String serviceInstanceName = "service-instance-1";
-        client = makeMdnsServiceTypeClient(mockPacketWriter);
         startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
         Runnable firstMdnsTask = currentThreadExecutor.getAndClearSubmittedRunnable();
 
@@ -1086,7 +1077,6 @@
             throws Exception {
         //MdnsConfigsFlagsImpl.removeServiceAfterTtlExpires.override(true);
         final String serviceInstanceName = "service-instance-1";
-        client = makeMdnsServiceTypeClient(mockPacketWriter);
         startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
         Runnable firstMdnsTask = currentThreadExecutor.getAndClearSubmittedRunnable();
 
@@ -1201,8 +1191,6 @@
 
     @Test
     public void testProcessResponse_Resolve() throws Exception {
-        client = makeMdnsServiceTypeClient(/* packetWriter= */ null);
-
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
         final String ipV4Address = "192.0.2.0";
@@ -1213,14 +1201,17 @@
         final MdnsSearchOptions resolveOptions2 = MdnsSearchOptions.newBuilder()
                 .setResolveInstanceName(instanceName).build();
 
+        doCallRealMethod().when(mockDeps).getDatagramPacketsFromMdnsPacket(
+                any(), any(MdnsPacket.class), any(InetSocketAddress.class), anyBoolean());
+
         startSendAndReceive(mockListenerOne, resolveOptions1);
         startSendAndReceive(mockListenerTwo, resolveOptions2);
         // No need to verify order for both listeners; and order is not guaranteed between them
         InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
 
         // Verify a query for SRV/TXT was sent, but no PTR query
-        final ArgumentCaptor<DatagramPacket> srvTxtQueryCaptor =
-                ArgumentCaptor.forClass(DatagramPacket.class);
+        final ArgumentCaptor<List<DatagramPacket>> srvTxtQueryCaptor =
+                ArgumentCaptor.forClass(List.class);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         // Send twice for IPv4 and IPv6
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
@@ -1232,7 +1223,7 @@
         verify(mockListenerTwo).onDiscoveryQuerySent(any(), anyInt());
 
         final MdnsPacket srvTxtQueryPacket = MdnsPacket.parse(
-                new MdnsPacketReader(srvTxtQueryCaptor.getValue()));
+                new MdnsPacketReader(srvTxtQueryCaptor.getValue().get(0)));
 
         final String[] serviceName = getTestServiceName(instanceName);
         assertEquals(1, srvTxtQueryPacket.questions.size());
@@ -1264,8 +1255,8 @@
 
         // Expect a query for A/AAAA
         dispatchMessage();
-        final ArgumentCaptor<DatagramPacket> addressQueryCaptor =
-                ArgumentCaptor.forClass(DatagramPacket.class);
+        final ArgumentCaptor<List<DatagramPacket>> addressQueryCaptor =
+                ArgumentCaptor.forClass(List.class);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
                 addressQueryCaptor.capture(),
@@ -1275,7 +1266,7 @@
         verify(mockListenerTwo, times(2)).onDiscoveryQuerySent(any(), anyInt());
 
         final MdnsPacket addressQueryPacket = MdnsPacket.parse(
-                new MdnsPacketReader(addressQueryCaptor.getValue()));
+                new MdnsPacketReader(addressQueryCaptor.getValue().get(0)));
         assertEquals(2, addressQueryPacket.questions.size());
         assertTrue(hasQuestion(addressQueryPacket, MdnsRecord.TYPE_A, hostname));
         assertTrue(hasQuestion(addressQueryPacket, MdnsRecord.TYPE_AAAA, hostname));
@@ -1317,8 +1308,6 @@
 
     @Test
     public void testRenewTxtSrvInResolve() throws Exception {
-        client = makeMdnsServiceTypeClient(/* packetWriter= */ null);
-
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
         final String ipV4Address = "192.0.2.0";
@@ -1327,12 +1316,15 @@
         final MdnsSearchOptions resolveOptions = MdnsSearchOptions.newBuilder()
                 .setResolveInstanceName(instanceName).build();
 
+        doCallRealMethod().when(mockDeps).getDatagramPacketsFromMdnsPacket(
+                any(), any(MdnsPacket.class), any(InetSocketAddress.class), anyBoolean());
+
         startSendAndReceive(mockListenerOne, resolveOptions);
         InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
 
         // Get the query for SRV/TXT
-        final ArgumentCaptor<DatagramPacket> srvTxtQueryCaptor =
-                ArgumentCaptor.forClass(DatagramPacket.class);
+        final ArgumentCaptor<List<DatagramPacket>> srvTxtQueryCaptor =
+                ArgumentCaptor.forClass(List.class);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         // Send twice for IPv4 and IPv6
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
@@ -1342,7 +1334,7 @@
         assertNotNull(delayMessage);
 
         final MdnsPacket srvTxtQueryPacket = MdnsPacket.parse(
-                new MdnsPacketReader(srvTxtQueryCaptor.getValue()));
+                new MdnsPacketReader(srvTxtQueryCaptor.getValue().get(0)));
 
         final String[] serviceName = getTestServiceName(instanceName);
         assertTrue(hasQuestion(srvTxtQueryPacket, MdnsRecord.TYPE_ANY, serviceName));
@@ -1386,8 +1378,8 @@
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
 
         // Expect a renewal query
-        final ArgumentCaptor<DatagramPacket> renewalQueryCaptor =
-                ArgumentCaptor.forClass(DatagramPacket.class);
+        final ArgumentCaptor<List<DatagramPacket>> renewalQueryCaptor =
+                ArgumentCaptor.forClass(List.class);
         // Second and later sends are sent as "expect multicast response" queries
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
                 renewalQueryCaptor.capture(),
@@ -1396,7 +1388,7 @@
         assertNotNull(delayMessage);
         inOrder.verify(mockListenerOne).onDiscoveryQuerySent(any(), anyInt());
         final MdnsPacket renewalPacket = MdnsPacket.parse(
-                new MdnsPacketReader(renewalQueryCaptor.getValue()));
+                new MdnsPacketReader(renewalQueryCaptor.getValue().get(0)));
         assertTrue(hasQuestion(renewalPacket, MdnsRecord.TYPE_ANY, serviceName));
         inOrder.verifyNoMoreInteractions();
 
@@ -1431,8 +1423,6 @@
 
     @Test
     public void testProcessResponse_ResolveExcludesOtherServices() {
-        client = makeMdnsServiceTypeClient(/* packetWriter= */ null);
-
         final String requestedInstance = "instance1";
         final String otherInstance = "instance2";
         final String ipV4Address = "192.0.2.0";
@@ -1499,8 +1489,6 @@
 
     @Test
     public void testProcessResponse_SubtypeDiscoveryLimitedToSubtype() {
-        client = makeMdnsServiceTypeClient(/* packetWriter= */ null);
-
         final String matchingInstance = "instance1";
         final String subtype = "_subtype";
         final String otherInstance = "instance2";
@@ -1587,8 +1575,6 @@
 
     @Test
     public void testProcessResponse_SubtypeChange() {
-        client = makeMdnsServiceTypeClient(/* packetWriter= */ null);
-
         final String matchingInstance = "instance1";
         final String subtype = "_subtype";
         final String ipV4Address = "192.0.2.0";
@@ -1670,8 +1656,6 @@
 
     @Test
     public void testNotifySocketDestroyed() throws Exception {
-        client = makeMdnsServiceTypeClient(/* packetWriter= */ null);
-
         final String requestedInstance = "instance1";
         final String otherInstance = "instance2";
         final String ipV4Address = "192.0.2.0";
@@ -1946,6 +1930,138 @@
                 16 /* scheduledCount */);
     }
 
+    @Test
+    public void testSendQueryWithKnownAnswers() throws Exception {
+        client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
+                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps,
+                serviceCache,
+                MdnsFeatureFlags.newBuilder().setIsQueryWithKnownAnswerEnabled(true).build());
+
+        doCallRealMethod().when(mockDeps).getDatagramPacketsFromMdnsPacket(
+                any(), any(MdnsPacket.class), any(InetSocketAddress.class), anyBoolean());
+
+        startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
+        InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
+
+        final ArgumentCaptor<List<DatagramPacket>> queryCaptor =
+                ArgumentCaptor.forClass(List.class);
+        currentThreadExecutor.getAndClearLastScheduledRunnable().run();
+        // Send twice for IPv4 and IPv6
+        inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
+                queryCaptor.capture(), eq(socketKey), eq(false));
+        verify(mockDeps, times(1)).sendMessage(any(), any(Message.class));
+        assertNotNull(delayMessage);
+
+        final MdnsPacket queryPacket = MdnsPacket.parse(
+                new MdnsPacketReader(queryCaptor.getValue().get(0)));
+        assertTrue(hasQuestion(queryPacket, MdnsRecord.TYPE_PTR));
+
+        // Process a response
+        final String serviceName = "service-instance";
+        final String ipV4Address = "192.0.2.0";
+        final String[] subtypeLabels = Stream.concat(Stream.of("_subtype", "_sub"),
+                        Arrays.stream(SERVICE_TYPE_LABELS)).toArray(String[]::new);
+        final MdnsPacket packetWithoutSubtype = createResponse(
+                serviceName, ipV4Address, 5353, SERVICE_TYPE_LABELS,
+                Collections.emptyMap() /* textAttributes */, TEST_TTL);
+        final MdnsPointerRecord originalPtr = (MdnsPointerRecord) CollectionUtils.findFirst(
+                packetWithoutSubtype.answers, r -> r instanceof MdnsPointerRecord);
+
+        // Add a subtype PTR record
+        final ArrayList<MdnsRecord> newAnswers = new ArrayList<>(packetWithoutSubtype.answers);
+        newAnswers.add(new MdnsPointerRecord(subtypeLabels, originalPtr.getReceiptTime(),
+                originalPtr.getCacheFlush(), originalPtr.getTtl(), originalPtr.getPointer()));
+        final MdnsPacket packetWithSubtype = new MdnsPacket(
+                packetWithoutSubtype.flags,
+                packetWithoutSubtype.questions,
+                newAnswers,
+                packetWithoutSubtype.authorityRecords,
+                packetWithoutSubtype.additionalRecords);
+        processResponse(packetWithSubtype, socketKey);
+
+        // Expect a query with known answers
+        dispatchMessage();
+        final ArgumentCaptor<List<DatagramPacket>> knownAnswersQueryCaptor =
+                ArgumentCaptor.forClass(List.class);
+        currentThreadExecutor.getAndClearLastScheduledRunnable().run();
+        inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
+                knownAnswersQueryCaptor.capture(), eq(socketKey), eq(false));
+
+        final MdnsPacket knownAnswersQueryPacket = MdnsPacket.parse(
+                new MdnsPacketReader(knownAnswersQueryCaptor.getValue().get(0)));
+        assertTrue(hasQuestion(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
+        assertTrue(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
+        assertFalse(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, subtypeLabels));
+    }
+
+    @Test
+    public void testSendQueryWithSubTypeWithKnownAnswers() throws Exception {
+        client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
+                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps,
+                serviceCache,
+                MdnsFeatureFlags.newBuilder().setIsQueryWithKnownAnswerEnabled(true).build());
+
+        doCallRealMethod().when(mockDeps).getDatagramPacketsFromMdnsPacket(
+                any(), any(MdnsPacket.class), any(InetSocketAddress.class), anyBoolean());
+
+        final MdnsSearchOptions options = MdnsSearchOptions.newBuilder()
+                .addSubtype("subtype").build();
+        startSendAndReceive(mockListenerOne, options);
+        InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
+
+        final ArgumentCaptor<List<DatagramPacket>> queryCaptor =
+                ArgumentCaptor.forClass(List.class);
+        currentThreadExecutor.getAndClearLastScheduledRunnable().run();
+        // Send twice for IPv4 and IPv6
+        inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
+                queryCaptor.capture(), eq(socketKey), eq(false));
+        verify(mockDeps, times(1)).sendMessage(any(), any(Message.class));
+        assertNotNull(delayMessage);
+
+        final MdnsPacket queryPacket = MdnsPacket.parse(
+                new MdnsPacketReader(queryCaptor.getValue().get(0)));
+        final String[] subtypeLabels = Stream.concat(Stream.of("_subtype", "_sub"),
+                Arrays.stream(SERVICE_TYPE_LABELS)).toArray(String[]::new);
+        assertTrue(hasQuestion(queryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
+        assertTrue(hasQuestion(queryPacket, MdnsRecord.TYPE_PTR, subtypeLabels));
+
+        // Process a response
+        final String serviceName = "service-instance";
+        final String ipV4Address = "192.0.2.0";
+        final MdnsPacket packetWithoutSubtype = createResponse(
+                serviceName, ipV4Address, 5353, SERVICE_TYPE_LABELS,
+                Collections.emptyMap() /* textAttributes */, TEST_TTL);
+        final MdnsPointerRecord originalPtr = (MdnsPointerRecord) CollectionUtils.findFirst(
+                packetWithoutSubtype.answers, r -> r instanceof MdnsPointerRecord);
+
+        // Add a subtype PTR record
+        final ArrayList<MdnsRecord> newAnswers = new ArrayList<>(packetWithoutSubtype.answers);
+        newAnswers.add(new MdnsPointerRecord(subtypeLabels, originalPtr.getReceiptTime(),
+                originalPtr.getCacheFlush(), originalPtr.getTtl(), originalPtr.getPointer()));
+        final MdnsPacket packetWithSubtype = new MdnsPacket(
+                packetWithoutSubtype.flags,
+                packetWithoutSubtype.questions,
+                newAnswers,
+                packetWithoutSubtype.authorityRecords,
+                packetWithoutSubtype.additionalRecords);
+        processResponse(packetWithSubtype, socketKey);
+
+        // Expect a query with known answers
+        dispatchMessage();
+        final ArgumentCaptor<List<DatagramPacket>> knownAnswersQueryCaptor =
+                ArgumentCaptor.forClass(List.class);
+        currentThreadExecutor.getAndClearLastScheduledRunnable().run();
+        inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
+                knownAnswersQueryCaptor.capture(), eq(socketKey), eq(false));
+
+        final MdnsPacket knownAnswersQueryPacket = MdnsPacket.parse(
+                new MdnsPacketReader(knownAnswersQueryCaptor.getValue().get(0)));
+        assertTrue(hasQuestion(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
+        assertTrue(hasQuestion(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, subtypeLabels));
+        assertTrue(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
+        assertTrue(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, subtypeLabels));
+    }
+
     private static MdnsServiceInfo matchServiceName(String name) {
         return argThat(info -> info.getServiceInstanceName().equals(name));
     }
@@ -1967,17 +2083,21 @@
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         if (expectsUnicastResponse) {
             verify(mockSocketClient).sendPacketRequestingUnicastResponse(
-                    expectedIPv4Packets[index], socketKey, false);
+                    argThat(pkts -> pkts.get(0).equals(expectedIPv4Packets[index])),
+                    eq(socketKey), eq(false));
             if (multipleSocketDiscovery) {
                 verify(mockSocketClient).sendPacketRequestingUnicastResponse(
-                        expectedIPv6Packets[index], socketKey, false);
+                        argThat(pkts -> pkts.get(0).equals(expectedIPv6Packets[index])),
+                        eq(socketKey), eq(false));
             }
         } else {
             verify(mockSocketClient).sendPacketRequestingMulticastResponse(
-                    expectedIPv4Packets[index], socketKey, false);
+                    argThat(pkts -> pkts.get(0).equals(expectedIPv4Packets[index])),
+                    eq(socketKey), eq(false));
             if (multipleSocketDiscovery) {
                 verify(mockSocketClient).sendPacketRequestingMulticastResponse(
-                        expectedIPv6Packets[index], socketKey, false);
+                        argThat(pkts -> pkts.get(0).equals(expectedIPv6Packets[index])),
+                        eq(socketKey), eq(false));
             }
         }
         verify(mockDeps, times(index + 1))
@@ -2006,6 +2126,12 @@
                 && (name == null || Arrays.equals(q.name, name)));
     }
 
+    private static boolean hasAnswer(MdnsPacket packet, int type, @NonNull String[] name) {
+        return packet.answers.stream().anyMatch(q -> {
+            return q.getType() == type && (Arrays.equals(q.name, name));
+        });
+    }
+
     // A fake ScheduledExecutorService that keeps tracking the last scheduled Runnable and its delay
     // time.
     private class FakeExecutor extends ScheduledThreadPoolExecutor {
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
index 7ced1cb..1989ed3 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
@@ -27,6 +27,7 @@
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.times;
@@ -53,6 +54,7 @@
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
 import org.mockito.ArgumentMatchers;
+import org.mockito.InOrder;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 import org.mockito.invocation.InvocationOnMock;
@@ -60,6 +62,8 @@
 import java.io.IOException;
 import java.net.DatagramPacket;
 import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -234,7 +238,7 @@
 
         // Sends a packet.
         DatagramPacket packet = getTestDatagramPacket();
-        mdnsClient.sendPacketRequestingMulticastResponse(packet,
+        mdnsClient.sendPacketRequestingMulticastResponse(List.of(packet),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         // mockMulticastSocket.send() will be called on another thread. If we verify it immediately,
         // it may not be called yet. So timeout is added.
@@ -242,7 +246,7 @@
         verify(mockUnicastSocket, timeout(TIMEOUT).times(0)).send(packet);
 
         // Verify the packet is sent by the unicast socket.
-        mdnsClient.sendPacketRequestingUnicastResponse(packet,
+        mdnsClient.sendPacketRequestingUnicastResponse(List.of(packet),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         verify(mockMulticastSocket, timeout(TIMEOUT).times(1)).send(packet);
         verify(mockUnicastSocket, timeout(TIMEOUT).times(1)).send(packet);
@@ -287,7 +291,7 @@
 
         // Sends a packet.
         DatagramPacket packet = getTestDatagramPacket();
-        mdnsClient.sendPacketRequestingMulticastResponse(packet,
+        mdnsClient.sendPacketRequestingMulticastResponse(List.of(packet),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         // mockMulticastSocket.send() will be called on another thread. If we verify it immediately,
         // it may not be called yet. So timeout is added.
@@ -295,7 +299,7 @@
         verify(mockUnicastSocket, timeout(TIMEOUT).times(0)).send(packet);
 
         // Verify the packet is sent by the multicast socket as well.
-        mdnsClient.sendPacketRequestingUnicastResponse(packet,
+        mdnsClient.sendPacketRequestingUnicastResponse(List.of(packet),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         verify(mockMulticastSocket, timeout(TIMEOUT).times(2)).send(packet);
         verify(mockUnicastSocket, timeout(TIMEOUT).times(0)).send(packet);
@@ -354,7 +358,7 @@
     public void testStopDiscovery_queueIsCleared() throws IOException {
         mdnsClient.startDiscovery();
         mdnsClient.stopDiscovery();
-        mdnsClient.sendPacketRequestingMulticastResponse(getTestDatagramPacket(),
+        mdnsClient.sendPacketRequestingMulticastResponse(List.of(getTestDatagramPacket()),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
 
         synchronized (mdnsClient.multicastPacketQueue) {
@@ -366,7 +370,7 @@
     public void testSendPacket_afterDiscoveryStops() throws IOException {
         mdnsClient.startDiscovery();
         mdnsClient.stopDiscovery();
-        mdnsClient.sendPacketRequestingMulticastResponse(getTestDatagramPacket(),
+        mdnsClient.sendPacketRequestingMulticastResponse(List.of(getTestDatagramPacket()),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
 
         synchronized (mdnsClient.multicastPacketQueue) {
@@ -380,7 +384,7 @@
         //MdnsConfigsFlagsImpl.mdnsPacketQueueMaxSize.override(2L);
         mdnsClient.startDiscovery();
         for (int i = 0; i < 100; i++) {
-            mdnsClient.sendPacketRequestingMulticastResponse(getTestDatagramPacket(),
+            mdnsClient.sendPacketRequestingMulticastResponse(List.of(getTestDatagramPacket()),
                     false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         }
 
@@ -478,9 +482,9 @@
 
         mdnsClient.startDiscovery();
         DatagramPacket packet = getTestDatagramPacket();
-        mdnsClient.sendPacketRequestingUnicastResponse(packet,
+        mdnsClient.sendPacketRequestingUnicastResponse(List.of(packet),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
-        mdnsClient.sendPacketRequestingMulticastResponse(packet,
+        mdnsClient.sendPacketRequestingMulticastResponse(List.of(packet),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
 
         // Wait for the timer to be triggered.
@@ -511,9 +515,9 @@
         assertFalse(mdnsClient.receivedUnicastResponse);
         assertFalse(mdnsClient.cannotReceiveMulticastResponse.get());
 
-        mdnsClient.sendPacketRequestingUnicastResponse(packet,
+        mdnsClient.sendPacketRequestingUnicastResponse(List.of(packet),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
-        mdnsClient.sendPacketRequestingMulticastResponse(packet,
+        mdnsClient.sendPacketRequestingMulticastResponse(List.of(packet),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         Thread.sleep(MdnsConfigs.checkMulticastResponseIntervalMs() * 2);
 
@@ -570,6 +574,26 @@
                 .onResponseReceived(any(), argThat(key -> key.getInterfaceIndex() == -1));
     }
 
+    @Test
+    public void testSendPacketWithMultipleDatagramPacket() throws IOException {
+        mdnsClient.startDiscovery();
+        final List<DatagramPacket> packets = new ArrayList<>();
+        for (int i = 0; i < 10; i++) {
+            packets.add(new DatagramPacket(new byte[10 + i] /* buff */, 0 /* offset */,
+                    10 + i /* length */, MdnsConstants.IPV4_SOCKET_ADDR));
+        }
+
+        // Sends packets.
+        mdnsClient.sendPacketRequestingMulticastResponse(packets,
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
+        InOrder inOrder = inOrder(mockMulticastSocket);
+        for (int i = 0; i < 10; i++) {
+            // mockMulticastSocket.send() will be called on another thread. If we verify it
+            // immediately, it may not be called yet. So timeout is added.
+            inOrder.verify(mockMulticastSocket, timeout(TIMEOUT)).send(packets.get(i));
+        }
+    }
+
     private DatagramPacket getTestDatagramPacket() {
         return new DatagramPacket(buf, 0, 5,
                 new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), 5353 /* port */));
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
index f705bcb..009205e 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
@@ -17,6 +17,13 @@
 package com.android.server.connectivity.mdns.util
 
 import android.os.Build
+import com.android.server.connectivity.mdns.MdnsConstants
+import com.android.server.connectivity.mdns.MdnsConstants.FLAG_TRUNCATED
+import com.android.server.connectivity.mdns.MdnsPacket
+import com.android.server.connectivity.mdns.MdnsPacketReader
+import com.android.server.connectivity.mdns.MdnsPointerRecord
+import com.android.server.connectivity.mdns.MdnsRecord
+import com.android.server.connectivity.mdns.util.MdnsUtils.createQueryDatagramPackets
 import com.android.server.connectivity.mdns.util.MdnsUtils.equalsDnsLabelIgnoreDnsCase
 import com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase
 import com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLabelsLowerCase
@@ -24,6 +31,8 @@
 import com.android.server.connectivity.mdns.util.MdnsUtils.truncateServiceName
 import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.DevSdkIgnoreRunner
+import java.net.DatagramPacket
+import kotlin.test.assertContentEquals
 import org.junit.Assert.assertArrayEquals
 import org.junit.Assert.assertEquals
 import org.junit.Assert.assertFalse
@@ -43,19 +52,27 @@
         assertEquals("ţést", toDnsLowerCase("ţést"))
         // Unicode characters 0x10000 (𐀀), 0x10001 (𐀁), 0x10041 (𐁁)
         // Note the last 2 bytes of 0x10041 are identical to 'A', but it should remain unchanged.
-        assertEquals("test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ",
-                toDnsLowerCase("Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- "))
+        assertEquals(
+            "test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ",
+                toDnsLowerCase("Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ")
+        )
         // Also test some characters where the first surrogate is not \ud800
-        assertEquals("test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
+        assertEquals(
+            "test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
                 "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<",
-                toDnsLowerCase("Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
-                        "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"))
+                toDnsLowerCase(
+                    "Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
+                        "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"
+                )
+        )
     }
 
     @Test
     fun testToDnsLabelsLowerCase() {
-        assertArrayEquals(arrayOf("test", "tÉst", "ţést"),
-            toDnsLabelsLowerCase(arrayOf("TeSt", "TÉST", "ţést")))
+        assertArrayEquals(
+            arrayOf("test", "tÉst", "ţést"),
+            toDnsLabelsLowerCase(arrayOf("TeSt", "TÉST", "ţést"))
+        )
     }
 
     @Test
@@ -67,13 +84,17 @@
         assertFalse(equalsIgnoreDnsCase("ŢÉST", "ţést"))
         // Unicode characters 0x10000 (𐀀), 0x10001 (𐀁), 0x10041 (𐁁)
         // Note the last 2 bytes of 0x10041 are identical to 'A', but it should remain unchanged.
-        assertTrue(equalsIgnoreDnsCase("test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ",
-                "Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- "))
+        assertTrue(equalsIgnoreDnsCase(
+            "test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ",
+                "Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- "
+        ))
         // Also test some characters where the first surrogate is not \ud800
-        assertTrue(equalsIgnoreDnsCase("test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
+        assertTrue(equalsIgnoreDnsCase(
+            "test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
                 "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<",
                 "Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
-                        "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"))
+                        "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"
+        ))
     }
 
     @Test
@@ -92,14 +113,84 @@
 
     @Test
     fun testTypeEqualsOrIsSubtype() {
-        assertTrue(MdnsUtils.typeEqualsOrIsSubtype(arrayOf("_type", "_tcp", "local"),
-            arrayOf("_type", "_TCP", "local")))
-        assertTrue(MdnsUtils.typeEqualsOrIsSubtype(arrayOf("_type", "_tcp", "local"),
-            arrayOf("a", "_SUB", "_type", "_TCP", "local")))
-        assertFalse(MdnsUtils.typeEqualsOrIsSubtype(arrayOf("_sub", "_type", "_tcp", "local"),
-                arrayOf("_type", "_TCP", "local")))
+        assertTrue(MdnsUtils.typeEqualsOrIsSubtype(
+            arrayOf("_type", "_tcp", "local"),
+            arrayOf("_type", "_TCP", "local")
+        ))
+        assertTrue(MdnsUtils.typeEqualsOrIsSubtype(
+            arrayOf("_type", "_tcp", "local"),
+            arrayOf("a", "_SUB", "_type", "_TCP", "local")
+        ))
+        assertFalse(MdnsUtils.typeEqualsOrIsSubtype(
+            arrayOf("_sub", "_type", "_tcp", "local"),
+                arrayOf("_type", "_TCP", "local")
+        ))
         assertFalse(MdnsUtils.typeEqualsOrIsSubtype(
                 arrayOf("a", "_other", "_type", "_tcp", "local"),
-                arrayOf("a", "_SUB", "_type", "_TCP", "local")))
+                arrayOf("a", "_SUB", "_type", "_TCP", "local")
+        ))
+    }
+
+    @Test
+    fun testCreateQueryDatagramPackets() {
+        // Question data bytes:
+        // Name label(17)(duplicated labels) + PTR type(2) + cacheFlush(2) = 21
+        //
+        // Known answers data bytes:
+        // Name label(17)(duplicated labels) + PTR type(2) + cacheFlush(2) + TTL(4)
+        // + Data length(2) + Pointer data(18)(duplicated labels) = 45
+        val questions = mutableListOf<MdnsRecord>()
+        val knownAnswers = mutableListOf<MdnsRecord>()
+        for (i in 1..100) {
+            questions.add(MdnsPointerRecord(arrayOf("_testservice$i", "_tcp", "local"), false))
+            knownAnswers.add(MdnsPointerRecord(
+                    arrayOf("_testservice$i", "_tcp", "local"),
+                    0L,
+                    false,
+                    4_500_000L,
+                    arrayOf("MyTestService$i", "_testservice$i", "_tcp", "local")
+            ))
+        }
+        // MdnsPacket data bytes:
+        // Questions(21 * 100) + Answers(45 * 100) = 6600 -> at least 5 packets
+        val query = MdnsPacket(
+                MdnsConstants.FLAGS_QUERY,
+                questions as List<MdnsRecord>,
+                knownAnswers as List<MdnsRecord>,
+                emptyList(),
+                emptyList()
+        )
+        // Expect the oversize MdnsPacket to be separated into 5 DatagramPackets.
+        val bufferSize = 1500
+        val packets = createQueryDatagramPackets(
+                ByteArray(bufferSize),
+                query,
+                MdnsConstants.IPV4_SOCKET_ADDR
+        )
+        assertEquals(5, packets.size)
+        assertTrue(packets.all { packet -> packet.length < bufferSize })
+
+        val mdnsPacket = createMdnsPacketFromMultipleDatagramPackets(packets)
+        assertEquals(query.flags, mdnsPacket.flags)
+        assertContentEquals(query.questions, mdnsPacket.questions)
+        assertContentEquals(query.answers, mdnsPacket.answers)
+    }
+
+    private fun createMdnsPacketFromMultipleDatagramPackets(
+            packets: List<DatagramPacket>
+    ): MdnsPacket {
+        var flags = 0
+        val questions = mutableListOf<MdnsRecord>()
+        val answers = mutableListOf<MdnsRecord>()
+        for ((index, packet) in packets.withIndex()) {
+            val mdnsPacket = MdnsPacket.parse(MdnsPacketReader(packet))
+            if (index != packets.size - 1) {
+                assertTrue((mdnsPacket.flags and FLAG_TRUNCATED) == FLAG_TRUNCATED)
+            }
+            flags = mdnsPacket.flags
+            questions.addAll(mdnsPacket.questions)
+            answers.addAll(mdnsPacket.answers)
+        }
+        return MdnsPacket(flags, questions, answers, emptyList(), emptyList())
     }
 }
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt b/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt
index e8664c1..bb7fb51 100644
--- a/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt
@@ -30,6 +30,7 @@
 import android.net.VpnTransportInfo
 import android.os.Build
 import androidx.test.filters.SmallTest
+import com.android.server.connectivity.ConnectivityFlags
 import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.DevSdkIgnoreRunner
 import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
@@ -56,7 +57,9 @@
                         TYPE_VPN_SERVICE,
                         "MySession12345",
                         false /* bypassable */,
-                        false /* longLivedTcpConnectionsExpensive */))
+                        false /* longLivedTcpConnectionsExpensive */
+                )
+        )
         .build()
 
 private fun wifiNc() = NetworkCapabilities.Builder()
@@ -286,4 +289,19 @@
         waitForIdle()
         verify(bpfNetMaps).removeIngressDiscardRule(IPV6_ADDRESS)
     }
+
+    @Test @FeatureFlags([Flag(ConnectivityFlags.INGRESS_TO_VPN_ADDRESS_FILTERING, false)])
+    fun testVpnIngressDiscardRule_FeatureDisabled() {
+        val nr = nr(TRANSPORT_VPN)
+        val cb = TestableNetworkCallback()
+        cm.registerNetworkCallback(nr, cb)
+        val nc = vpnNc()
+        val lp = lp(VPN_IFNAME, IPV6_LINK_ADDRESS, LOCAL_IPV6_LINK_ADDRRESS)
+        val agent = Agent(nc = nc, lp = lp)
+        agent.connect()
+        cb.expectAvailableCallbacks(agent.network, validated = false)
+
+        // IngressDiscardRule should not be added since feature is disabled
+        verify(bpfNetMaps, never()).setIngressDiscardRule(any(), any())
+    }
 }
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
index 6c9871c..bd26c63 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
@@ -79,11 +79,15 @@
 import java.util.concurrent.TimeUnit
 import java.util.function.BiConsumer
 import java.util.function.Consumer
+import kotlin.annotation.AnnotationRetention.RUNTIME
+import kotlin.annotation.AnnotationTarget.FUNCTION
 import kotlin.test.assertNotNull
 import kotlin.test.assertNull
 import kotlin.test.fail
 import org.junit.After
 import org.junit.Before
+import org.junit.Rule
+import org.junit.rules.TestName
 import org.mockito.AdditionalAnswers.delegatesTo
 import org.mockito.Mockito.doAnswer
 import org.mockito.Mockito.doReturn
@@ -126,6 +130,9 @@
 // TODO (b/272685721) : make ConnectivityServiceTest smaller and faster by moving the setup
 // parts into this class and moving the individual tests to multiple separate classes.
 open class CSTest {
+    @get:Rule
+    val testNameRule = TestName()
+
     companion object {
         val CSTestExecutor = Executors.newSingleThreadExecutor()
     }
@@ -155,8 +162,7 @@
         it[ConnectivityService.ALLOW_SATALLITE_NETWORK_FALLBACK] = true
         it[ConnectivityFlags.INGRESS_TO_VPN_ADDRESS_FILTERING] = true
     }
-    fun enableFeature(f: String) = enabledFeatures.set(f, true)
-    fun disableFeature(f: String) = enabledFeatures.set(f, false)
+    fun setFeatureEnabled(flag: String, enabled: Boolean) = enabledFeatures.set(flag, enabled)
 
     // When adding new members, consider if it's not better to build the object in CSTestHelpers
     // to keep this file clean of implementation details. Generally, CSTestHelpers should only
@@ -201,8 +207,32 @@
     lateinit var cm: ConnectivityManager
     lateinit var csHandler: Handler
 
+    // Tests can use this annotation to set flag values before constructing ConnectivityService
+    // e.g. @FeatureFlags([Flag(flagName1, true/false), Flag(flagName2, true/false)])
+    @Retention(RUNTIME)
+    @Target(FUNCTION)
+    annotation class FeatureFlags(val flags: Array<Flag>)
+
+    @Retention(RUNTIME)
+    @Target(FUNCTION)
+    annotation class Flag(val name: String, val enabled: Boolean)
+
     @Before
     fun setUp() {
+        // Set feature flags before constructing ConnectivityService
+        val testMethodName = testNameRule.methodName
+        try {
+            val testMethod = this::class.java.getMethod(testMethodName)
+            val featureFlags = testMethod.getAnnotation(FeatureFlags::class.java)
+            if (featureFlags != null) {
+                for (flag in featureFlags.flags) {
+                    setFeatureEnabled(flag.name, flag.enabled)
+                }
+            }
+        } catch (ignored: NoSuchMethodException) {
+            // This is expected for parameterized tests
+        }
+
         alarmHandlerThread = HandlerThread("TestAlarmManager").also { it.start() }
         alarmManager = makeMockAlarmManager(alarmHandlerThread)
         service = makeConnectivityService(context, netd, deps).also { it.systemReadyInternal() }
@@ -235,7 +265,8 @@
                 context: Context,
                 tm: TelephonyManager,
                 requestRestrictedWifiEnabled: Boolean,
-                listener: BiConsumer<Int, Int>
+                listener: BiConsumer<Int, Int>,
+                handler: Handler
         ) = if (SdkLevel.isAtLeastT()) mock<CarrierPrivilegeAuthenticator>() else null
 
         var satelliteNetworkFallbackUidUpdate: Consumer<Set<Int>>? = null
diff --git a/tests/unit/vpn-jarjar-rules.txt b/tests/unit/vpn-jarjar-rules.txt
deleted file mode 100644
index f74eab8..0000000
--- a/tests/unit/vpn-jarjar-rules.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-# Only keep classes imported by ConnectivityServiceTest
-keep com.android.server.connectivity.VpnProfileStore
diff --git a/thread/demoapp/Android.bp b/thread/demoapp/Android.bp
index fcfd469..117b4f9 100644
--- a/thread/demoapp/Android.bp
+++ b/thread/demoapp/Android.bp
@@ -34,7 +34,19 @@
     libs: [
         "framework-connectivity-t",
     ],
+    required: [
+        "privapp-permissions-com.android.threadnetwork.demoapp",
+    ],
+    system_ext_specific: true,
     certificate: "platform",
     privileged: true,
     platform_apis: true,
 }
+
+prebuilt_etc {
+    name: "privapp-permissions-com.android.threadnetwork.demoapp",
+    src: "privapp-permissions-com.android.threadnetwork.demoapp.xml",
+    sub_dir: "permissions",
+    filename_from_src: true,
+    system_ext_specific: true,
+}
diff --git a/thread/demoapp/privapp-permissions-com.android.threadnetwork.demoapp.xml b/thread/demoapp/privapp-permissions-com.android.threadnetwork.demoapp.xml
new file mode 100644
index 0000000..1995e60
--- /dev/null
+++ b/thread/demoapp/privapp-permissions-com.android.threadnetwork.demoapp.xml
@@ -0,0 +1,22 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!-- Copyright (C) 2024 The Android Open Source Project
+
+     Licensed under the Apache License, Version 2.0 (the "License");
+     you may not use this file except in compliance with the License.
+     You may obtain a copy of the License at
+
+          http://www.apache.org/licenses/LICENSE-2.0
+
+     Unless required by applicable law or agreed to in writing, software
+     distributed under the License is distributed on an "AS IS" BASIS,
+     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+     See the License for the specific language governing permissions and
+     limitations under the License.
+-->
+
+<!-- The privileged permissions needed by the com.android.threadnetwork.demoapp app. -->
+<permissions>
+    <privapp-permissions package="com.android.threadnetwork.demoapp">
+        <permission name="android.permission.THREAD_NETWORK_PRIVILEGED" />
+    </privapp-permissions>
+</permissions>
diff --git a/thread/framework/Android.bp b/thread/framework/Android.bp
index 846253c..f8fe422 100644
--- a/thread/framework/Android.bp
+++ b/thread/framework/Android.bp
@@ -30,3 +30,14 @@
         "//packages/modules/Connectivity:__subpackages__",
     ],
 }
+
+filegroup {
+    name: "framework-thread-ot-daemon-shared-aidl-sources",
+    srcs: [
+        "java/android/net/thread/ChannelMaxPower.aidl",
+    ],
+    path: "java",
+    visibility: [
+        "//external/ot-br-posix:__subpackages__",
+    ],
+}
diff --git a/thread/framework/java/android/net/thread/ActiveOperationalDataset.java b/thread/framework/java/android/net/thread/ActiveOperationalDataset.java
index 4a7d3a7..22457f5 100644
--- a/thread/framework/java/android/net/thread/ActiveOperationalDataset.java
+++ b/thread/framework/java/android/net/thread/ActiveOperationalDataset.java
@@ -18,7 +18,7 @@
 
 import static com.android.internal.util.Preconditions.checkArgument;
 import static com.android.internal.util.Preconditions.checkState;
-import static com.android.net.module.util.HexDump.dumpHexString;
+import static com.android.net.module.util.HexDump.toHexString;
 
 import static java.nio.charset.StandardCharsets.UTF_8;
 import static java.util.Objects.requireNonNull;
@@ -610,7 +610,7 @@
         sb.append("{networkName=")
                 .append(getNetworkName())
                 .append(", extendedPanId=")
-                .append(dumpHexString(getExtendedPanId()))
+                .append(toHexString(getExtendedPanId()))
                 .append(", panId=")
                 .append(getPanId())
                 .append(", channel=")
@@ -1109,7 +1109,7 @@
             sb.append("{rotation=")
                     .append(mRotationTimeHours)
                     .append(", flags=")
-                    .append(dumpHexString(mFlags))
+                    .append(toHexString(mFlags))
                     .append("}");
             return sb.toString();
         }
diff --git a/thread/framework/java/android/net/thread/ChannelMaxPower.aidl b/thread/framework/java/android/net/thread/ChannelMaxPower.aidl
new file mode 100644
index 0000000..bcda8a8
--- /dev/null
+++ b/thread/framework/java/android/net/thread/ChannelMaxPower.aidl
@@ -0,0 +1,28 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.thread;
+
+ /**
+  * Mapping from a channel to its max power.
+  *
+  * {@hide}
+  */
+parcelable ChannelMaxPower {
+    int channel; // The Thread radio channel.
+    int maxPower; // The max power in the unit of 0.01dBm. Passing INT16_MAX(32767) will
+                  // disable the channel.
+}
diff --git a/thread/framework/java/android/net/thread/IThreadNetworkController.aidl b/thread/framework/java/android/net/thread/IThreadNetworkController.aidl
index 485e25d..c5ca557 100644
--- a/thread/framework/java/android/net/thread/IThreadNetworkController.aidl
+++ b/thread/framework/java/android/net/thread/IThreadNetworkController.aidl
@@ -17,6 +17,7 @@
 package android.net.thread;
 
 import android.net.thread.ActiveOperationalDataset;
+import android.net.thread.ChannelMaxPower;
 import android.net.thread.IActiveOperationalDatasetReceiver;
 import android.net.thread.IOperationalDatasetCallback;
 import android.net.thread.IOperationReceiver;
@@ -39,6 +40,7 @@
     void leave(in IOperationReceiver receiver);
 
     void setTestNetworkAsUpstream(in String testNetworkInterfaceName, in IOperationReceiver receiver);
+    void setChannelMaxPowers(in ChannelMaxPower[] channelMaxPowers, in IOperationReceiver receiver);
 
     int getThreadVersion();
     void createRandomizedDataset(String networkName, IActiveOperationalDatasetReceiver receiver);
diff --git a/thread/framework/java/android/net/thread/ThreadNetworkController.java b/thread/framework/java/android/net/thread/ThreadNetworkController.java
index db761a3..8d6b40a 100644
--- a/thread/framework/java/android/net/thread/ThreadNetworkController.java
+++ b/thread/framework/java/android/net/thread/ThreadNetworkController.java
@@ -25,10 +25,12 @@
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.annotation.RequiresPermission;
+import android.annotation.Size;
 import android.annotation.SystemApi;
 import android.os.Binder;
 import android.os.OutcomeReceiver;
 import android.os.RemoteException;
+import android.util.SparseIntArray;
 
 import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
@@ -98,6 +100,12 @@
     /** Thread standard version 1.3. */
     public static final int THREAD_VERSION_1_3 = 4;
 
+    /** Minimum value of max power in unit of 0.01dBm. @hide */
+    private static final int POWER_LIMITATION_MIN = -32768;
+
+    /** Maximum value of max power in unit of 0.01dBm. @hide */
+    private static final int POWER_LIMITATION_MAX = 32767;
+
     /** @hide */
     @Retention(RetentionPolicy.SOURCE)
     @IntDef({THREAD_VERSION_1_3})
@@ -596,6 +604,98 @@
         }
     }
 
+    /**
+     * Sets max power of each channel.
+     *
+     * <p>If not set, the default max power is set by the Thread HAL service or the Thread radio
+     * chip firmware.
+     *
+     * <p>On success, the given max powers are set for the corresponding channels and
+     * {@link OutcomeReceiver#onResult} of {@code receiver} will be called; when failed, {@link
+     * OutcomeReceiver#onError} will be called with a specific error:
+     *
+     * <ul>
+     *   <li>{@link ThreadNetworkException#ERROR_UNSUPPORTED_OPERATION} the operation is not
+     *       supported by the platform.
+     * </ul>
+     *
+     * @param channelMaxPowers SparseIntArray (key: channel, value: max power) consists of channel
+     *     and corresponding max power. Valid channel values should be between {@link
+     *     ActiveOperationalDataset#CHANNEL_MIN_24_GHZ} and {@link
+     *     ActiveOperationalDataset#CHANNEL_MAX_24_GHZ}. The unit of the max power is 0.01dBm. Max
+     *     power values should be between INT16_MIN (-32768) and INT16_MAX (32767). If the max power
+     *     is set to INT16_MAX, the corresponding channel is not supported.
+     * @param executor the executor to execute {@code receiver}.
+     * @param receiver the receiver to receive the result of this operation.
+     * @throws IllegalArgumentException if the size of {@code channelMaxPowers} is smaller than 1,
+     *     or invalid channel or max power is configured.
+     * @hide
+     */
+    @RequiresPermission("android.permission.THREAD_NETWORK_PRIVILEGED")
+    public final void setChannelMaxPowers(
+            @NonNull @Size(min = 1) SparseIntArray channelMaxPowers,
+            @NonNull @CallbackExecutor Executor executor,
+            @NonNull OutcomeReceiver<Void, ThreadNetworkException> receiver) {
+        requireNonNull(channelMaxPowers, "channelMaxPowers cannot be null");
+        requireNonNull(executor, "executor cannot be null");
+        requireNonNull(receiver, "receiver cannot be null");
+
+        if (channelMaxPowers.size() < 1) {
+            throw new IllegalArgumentException("channelMaxPowers cannot be empty");
+        }
+
+        for (int i = 0; i < channelMaxPowers.size(); i++) {
+            int channel = channelMaxPowers.keyAt(i);
+            int maxPower = channelMaxPowers.get(channel);
+
+            if ((channel < ActiveOperationalDataset.CHANNEL_MIN_24_GHZ)
+                    || (channel > ActiveOperationalDataset.CHANNEL_MAX_24_GHZ)) {
+                throw new IllegalArgumentException(
+                        "Channel "
+                                + channel
+                                + " exceeds allowed range ["
+                                + ActiveOperationalDataset.CHANNEL_MIN_24_GHZ
+                                + ", "
+                                + ActiveOperationalDataset.CHANNEL_MAX_24_GHZ
+                                + "]");
+            }
+
+            if ((maxPower < POWER_LIMITATION_MIN) || (maxPower > POWER_LIMITATION_MAX)) {
+                throw new IllegalArgumentException(
+                        "Channel power ({channel: "
+                                + channel
+                                + ", maxPower: "
+                                + maxPower
+                                + "}) exceeds allowed range ["
+                                + POWER_LIMITATION_MIN
+                                + ", "
+                                + POWER_LIMITATION_MAX
+                                + "]");
+            }
+        }
+
+        try {
+            mControllerService.setChannelMaxPowers(
+                    toChannelMaxPowerArray(channelMaxPowers),
+                    new OperationReceiverProxy(executor, receiver));
+        } catch (RemoteException e) {
+            throw e.rethrowFromSystemServer();
+        }
+    }
+
+    private static ChannelMaxPower[] toChannelMaxPowerArray(
+            @NonNull SparseIntArray channelMaxPowers) {
+        final ChannelMaxPower[] powerArray = new ChannelMaxPower[channelMaxPowers.size()];
+
+        for (int i = 0; i < channelMaxPowers.size(); i++) {
+            powerArray[i] = new ChannelMaxPower();
+            powerArray[i].channel = channelMaxPowers.keyAt(i);
+            powerArray[i].maxPower = channelMaxPowers.get(powerArray[i].channel);
+        }
+
+        return powerArray;
+    }
+
     private static <T> void propagateError(
             Executor executor,
             OutcomeReceiver<T, ThreadNetworkException> receiver,
diff --git a/thread/framework/java/android/net/thread/ThreadNetworkException.java b/thread/framework/java/android/net/thread/ThreadNetworkException.java
index 4def0fb..f699c30 100644
--- a/thread/framework/java/android/net/thread/ThreadNetworkException.java
+++ b/thread/framework/java/android/net/thread/ThreadNetworkException.java
@@ -138,8 +138,17 @@
      */
     public static final int ERROR_THREAD_DISABLED = 12;
 
+    /**
+     * The operation failed because it is not supported by the platform. For example, some platforms
+     * may not support setting the target power of each channel. The caller should not retry and may
+     * return an error to the user.
+     *
+     * @hide
+     */
+    public static final int ERROR_UNSUPPORTED_OPERATION = 13;
+
     private static final int ERROR_MIN = ERROR_INTERNAL_ERROR;
-    private static final int ERROR_MAX = ERROR_THREAD_DISABLED;
+    private static final int ERROR_MAX = ERROR_UNSUPPORTED_OPERATION;
 
     private final int mErrorCode;
 
diff --git a/thread/service/Android.bp b/thread/service/Android.bp
index 6e2fac1..a82a499 100644
--- a/thread/service/Android.bp
+++ b/thread/service/Android.bp
@@ -45,6 +45,9 @@
         "modules-utils-shell-command-handler",
         "net-utils-device-common",
         "net-utils-device-common-netlink",
+        // The required dependency net-utils-device-common-struct-base is in the classpath via
+        // framework-connectivity
+        "net-utils-device-common-struct",
         "ot-daemon-aidl-java",
     ],
     apex_available: ["com.android.tethering"],
diff --git a/thread/service/java/com/android/server/thread/ActiveOperationalDatasetReceiverWrapper.java b/thread/service/java/com/android/server/thread/ActiveOperationalDatasetReceiverWrapper.java
index e3b4e1a..43ff336 100644
--- a/thread/service/java/com/android/server/thread/ActiveOperationalDatasetReceiverWrapper.java
+++ b/thread/service/java/com/android/server/thread/ActiveOperationalDatasetReceiverWrapper.java
@@ -16,10 +16,12 @@
 
 package com.android.server.thread;
 
+import static android.net.thread.ThreadNetworkException.ERROR_INTERNAL_ERROR;
 import static android.net.thread.ThreadNetworkException.ERROR_UNAVAILABLE;
 
 import android.net.thread.ActiveOperationalDataset;
 import android.net.thread.IActiveOperationalDatasetReceiver;
+import android.net.thread.ThreadNetworkException;
 import android.os.RemoteException;
 
 import com.android.internal.annotations.GuardedBy;
@@ -73,6 +75,17 @@
         }
     }
 
+    public void onError(Throwable e) {
+        if (e instanceof ThreadNetworkException) {
+            ThreadNetworkException threadException = (ThreadNetworkException) e;
+            onError(threadException.getErrorCode(), threadException.getMessage());
+        } else if (e instanceof RemoteException) {
+            onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+        } else {
+            throw new AssertionError(e);
+        }
+    }
+
     public void onError(int errorCode, String errorMessage) {
         synchronized (sPendingReceiversLock) {
             sPendingReceivers.remove(this);
diff --git a/thread/service/java/com/android/server/thread/NsdPublisher.java b/thread/service/java/com/android/server/thread/NsdPublisher.java
index 3c7a72b..2c14f1d 100644
--- a/thread/service/java/com/android/server/thread/NsdPublisher.java
+++ b/thread/service/java/com/android/server/thread/NsdPublisher.java
@@ -21,6 +21,7 @@
 import android.annotation.NonNull;
 import android.content.Context;
 import android.net.InetAddresses;
+import android.net.nsd.DiscoveryRequest;
 import android.net.nsd.NsdManager;
 import android.net.nsd.NsdServiceInfo;
 import android.os.Handler;
@@ -31,15 +32,18 @@
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.server.thread.openthread.DnsTxtAttribute;
+import com.android.server.thread.openthread.INsdDiscoverServiceCallback;
 import com.android.server.thread.openthread.INsdPublisher;
+import com.android.server.thread.openthread.INsdResolveServiceCallback;
 import com.android.server.thread.openthread.INsdStatusReceiver;
 
+import java.net.Inet6Address;
 import java.net.InetAddress;
-import java.util.ArrayDeque;
 import java.util.ArrayList;
-import java.util.Deque;
+import java.util.Arrays;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.Executor;
 
 /**
@@ -50,14 +54,6 @@
  *
  * <p>All the data members of this class MUST be accessed in the {@code mHandler}'s Thread except
  * {@code mHandler} itself.
- *
- * <p>TODO: b/323300118 - Remove the following mechanism when the race condition in NsdManager is
- * fixed.
- *
- * <p>There's always only one running registration job at any timepoint. All other pending jobs are
- * queued in {@code mRegistrationJobs}. When a registration job is complete (i.e. the according
- * method in {@link NsdManager.RegistrationListener} is called), it will start the next registration
- * job in the queue.
  */
 public final class NsdPublisher extends INsdPublisher.Stub {
     // TODO: b/321883491 - specify network for mDNS operations
@@ -66,7 +62,8 @@
     private final Handler mHandler;
     private final Executor mExecutor;
     private final SparseArray<RegistrationListener> mRegistrationListeners = new SparseArray<>(0);
-    private final Deque<Runnable> mRegistrationJobs = new ArrayDeque<>();
+    private final SparseArray<DiscoveryListener> mDiscoveryListeners = new SparseArray<>(0);
+    private final SparseArray<ServiceInfoListener> mServiceInfoListeners = new SparseArray<>(0);
 
     @VisibleForTesting
     public NsdPublisher(NsdManager nsdManager, Handler handler) {
@@ -89,13 +86,9 @@
             List<DnsTxtAttribute> txt,
             INsdStatusReceiver receiver,
             int listenerId) {
-        postRegistrationJob(
-                () -> {
-                    NsdServiceInfo serviceInfo =
-                            buildServiceInfoForService(
-                                    hostname, name, type, subTypeList, port, txt);
-                    registerInternal(serviceInfo, receiver, listenerId, "service");
-                });
+        NsdServiceInfo serviceInfo =
+                buildServiceInfoForService(hostname, name, type, subTypeList, port, txt);
+        mHandler.post(() -> registerInternal(serviceInfo, receiver, listenerId, "service"));
     }
 
     private static NsdServiceInfo buildServiceInfoForService(
@@ -124,11 +117,8 @@
     @Override
     public void registerHost(
             String name, List<String> addresses, INsdStatusReceiver receiver, int listenerId) {
-        postRegistrationJob(
-                () -> {
-                    NsdServiceInfo serviceInfo = buildServiceInfoForHost(name, addresses);
-                    registerInternal(serviceInfo, receiver, listenerId, "host");
-                });
+        NsdServiceInfo serviceInfo = buildServiceInfoForHost(name, addresses);
+        mHandler.post(() -> registerInternal(serviceInfo, receiver, listenerId, "host"));
     }
 
     private static NsdServiceInfo buildServiceInfoForHost(
@@ -170,7 +160,7 @@
     }
 
     public void unregister(INsdStatusReceiver receiver, int listenerId) {
-        postRegistrationJob(() -> unregisterInternal(receiver, listenerId));
+        mHandler.post(() -> unregisterInternal(receiver, listenerId));
     }
 
     public void unregisterInternal(INsdStatusReceiver receiver, int listenerId) {
@@ -197,6 +187,110 @@
         mNsdManager.unregisterService(registrationListener);
     }
 
+    @Override
+    public void discoverService(String type, INsdDiscoverServiceCallback callback, int listenerId) {
+        mHandler.post(() -> discoverServiceInternal(type, callback, listenerId));
+    }
+
+    private void discoverServiceInternal(
+            String type, INsdDiscoverServiceCallback callback, int listenerId) {
+        checkOnHandlerThread();
+        Log.i(
+                TAG,
+                "Discovering services."
+                        + " Listener ID: "
+                        + listenerId
+                        + ", service type: "
+                        + type);
+
+        DiscoveryListener listener = new DiscoveryListener(listenerId, type, callback);
+        mDiscoveryListeners.append(listenerId, listener);
+        DiscoveryRequest discoveryRequest =
+                new DiscoveryRequest.Builder(type).setNetwork(null).build();
+        mNsdManager.discoverServices(discoveryRequest, mExecutor, listener);
+    }
+
+    @Override
+    public void stopServiceDiscovery(int listenerId) {
+        mHandler.post(() -> stopServiceDiscoveryInternal(listenerId));
+    }
+
+    private void stopServiceDiscoveryInternal(int listenerId) {
+        checkOnHandlerThread();
+
+        DiscoveryListener listener = mDiscoveryListeners.get(listenerId);
+        if (listener == null) {
+            Log.w(
+                    TAG,
+                    "Failed to stop service discovery. Listener ID "
+                            + listenerId
+                            + ". The listener is null.");
+            return;
+        }
+
+        Log.i(TAG, "Stopping service discovery. Listener: " + listener);
+        mNsdManager.stopServiceDiscovery(listener);
+    }
+
+    @Override
+    public void resolveService(
+            String name, String type, INsdResolveServiceCallback callback, int listenerId) {
+        mHandler.post(() -> resolveServiceInternal(name, type, callback, listenerId));
+    }
+
+    private void resolveServiceInternal(
+            String name, String type, INsdResolveServiceCallback callback, int listenerId) {
+        checkOnHandlerThread();
+
+        NsdServiceInfo serviceInfo = new NsdServiceInfo();
+        serviceInfo.setServiceName(name);
+        serviceInfo.setServiceType(type);
+        serviceInfo.setNetwork(null);
+        Log.i(
+                TAG,
+                "Resolving service."
+                        + " Listener ID: "
+                        + listenerId
+                        + ", service name: "
+                        + name
+                        + ", service type: "
+                        + type);
+
+        ServiceInfoListener listener = new ServiceInfoListener(serviceInfo, listenerId, callback);
+        mServiceInfoListeners.append(listenerId, listener);
+        mNsdManager.registerServiceInfoCallback(serviceInfo, mExecutor, listener);
+    }
+
+    @Override
+    public void stopServiceResolution(int listenerId) {
+        mHandler.post(() -> stopServiceResolutionInternal(listenerId));
+    }
+
+    private void stopServiceResolutionInternal(int listenerId) {
+        checkOnHandlerThread();
+
+        ServiceInfoListener listener = mServiceInfoListeners.get(listenerId);
+        if (listener == null) {
+            Log.w(
+                    TAG,
+                    "Failed to stop service resolution. Listener ID: "
+                            + listenerId
+                            + ". The listener is null.");
+            return;
+        }
+
+        Log.i(TAG, "Stopping service resolution. Listener: " + listener);
+
+        try {
+            mNsdManager.unregisterServiceInfoCallback(listener);
+        } catch (IllegalArgumentException e) {
+            Log.w(
+                    TAG,
+                    "Failed to stop the service resolution because it's already stopped. Listener: "
+                            + listener);
+        }
+    }
+
     private void checkOnHandlerThread() {
         if (mHandler.getLooper().getThread() != Thread.currentThread()) {
             throw new IllegalStateException(
@@ -226,7 +320,6 @@
             }
         }
         mRegistrationListeners.clear();
-        mRegistrationJobs.clear();
     }
 
     /** On ot-daemon died, reset. */
@@ -234,39 +327,6 @@
         reset();
     }
 
-    // TODO: b/323300118 - Remove this mechanism when the race condition in NsdManager is fixed.
-    /** Fetch the first job from the queue and run it. See the class doc for more details. */
-    private void peekAndRun() {
-        if (mRegistrationJobs.isEmpty()) {
-            return;
-        }
-        Runnable job = mRegistrationJobs.getFirst();
-        job.run();
-    }
-
-    // TODO: b/323300118 - Remove this mechanism when the race condition in NsdManager is fixed.
-    /**
-     * Pop the first job from the queue and run the next job. See the class doc for more details.
-     */
-    private void popAndRunNext() {
-        if (mRegistrationJobs.isEmpty()) {
-            Log.i(TAG, "No registration jobs when trying to pop and run next.");
-            return;
-        }
-        mRegistrationJobs.removeFirst();
-        peekAndRun();
-    }
-
-    private void postRegistrationJob(Runnable registrationJob) {
-        mHandler.post(
-                () -> {
-                    mRegistrationJobs.addLast(registrationJob);
-                    if (mRegistrationJobs.size() == 1) {
-                        peekAndRun();
-                    }
-                });
-    }
-
     private final class RegistrationListener implements NsdManager.RegistrationListener {
         private final NsdServiceInfo mServiceInfo;
         private final int mListenerId;
@@ -304,7 +364,6 @@
             } catch (RemoteException ignored) {
                 // do nothing if the client is dead
             }
-            popAndRunNext();
         }
 
         @Override
@@ -326,7 +385,6 @@
                     // do nothing if the client is dead
                 }
             }
-            popAndRunNext();
         }
 
         @Override
@@ -344,7 +402,6 @@
             } catch (RemoteException ignored) {
                 // do nothing if the client is dead
             }
-            popAndRunNext();
         }
 
         @Override
@@ -365,7 +422,168 @@
                 }
             }
             mRegistrationListeners.remove(mListenerId);
-            popAndRunNext();
+        }
+    }
+
+    private final class DiscoveryListener implements NsdManager.DiscoveryListener {
+        private final int mListenerId;
+        private final String mType;
+        private final INsdDiscoverServiceCallback mDiscoverServiceCallback;
+
+        DiscoveryListener(
+                int listenerId,
+                @NonNull String type,
+                @NonNull INsdDiscoverServiceCallback discoverServiceCallback) {
+            mListenerId = listenerId;
+            mType = type;
+            mDiscoverServiceCallback = discoverServiceCallback;
+        }
+
+        @Override
+        public void onStartDiscoveryFailed(String serviceType, int errorCode) {
+            Log.e(
+                    TAG,
+                    "Failed to start service discovery."
+                            + " Error code: "
+                            + errorCode
+                            + ", listener: "
+                            + this);
+            mDiscoveryListeners.remove(mListenerId);
+        }
+
+        @Override
+        public void onStopDiscoveryFailed(String serviceType, int errorCode) {
+            Log.e(
+                    TAG,
+                    "Failed to stop service discovery."
+                            + " Error code: "
+                            + errorCode
+                            + ", listener: "
+                            + this);
+            mDiscoveryListeners.remove(mListenerId);
+        }
+
+        @Override
+        public void onDiscoveryStarted(String serviceType) {
+            Log.i(TAG, "Started service discovery. Listener: " + this);
+        }
+
+        @Override
+        public void onDiscoveryStopped(String serviceType) {
+            Log.i(TAG, "Stopped service discovery. Listener: " + this);
+            mDiscoveryListeners.remove(mListenerId);
+        }
+
+        @Override
+        public void onServiceFound(NsdServiceInfo serviceInfo) {
+            Log.i(TAG, "Found service: " + serviceInfo);
+            try {
+                mDiscoverServiceCallback.onServiceDiscovered(
+                        serviceInfo.getServiceName(), mType, true);
+            } catch (RemoteException e) {
+                // do nothing if the client is dead
+            }
+        }
+
+        @Override
+        public void onServiceLost(NsdServiceInfo serviceInfo) {
+            Log.i(TAG, "Lost service: " + serviceInfo);
+            try {
+                mDiscoverServiceCallback.onServiceDiscovered(
+                        serviceInfo.getServiceName(), mType, false);
+            } catch (RemoteException e) {
+                // do nothing if the client is dead
+            }
+        }
+
+        @Override
+        public String toString() {
+            return "ID: " + mListenerId + ", type: " + mType;
+        }
+    }
+
+    private final class ServiceInfoListener implements NsdManager.ServiceInfoCallback {
+        private final String mName;
+        private final String mType;
+        private final INsdResolveServiceCallback mResolveServiceCallback;
+        private final int mListenerId;
+
+        ServiceInfoListener(
+                @NonNull NsdServiceInfo serviceInfo,
+                int listenerId,
+                @NonNull INsdResolveServiceCallback resolveServiceCallback) {
+            mName = serviceInfo.getServiceName();
+            mType = serviceInfo.getServiceType();
+            mListenerId = listenerId;
+            mResolveServiceCallback = resolveServiceCallback;
+        }
+
+        @Override
+        public void onServiceInfoCallbackRegistrationFailed(int errorCode) {
+            Log.e(
+                    TAG,
+                    "Failed to register service info callback."
+                            + " Listener ID: "
+                            + mListenerId
+                            + ", error: "
+                            + errorCode
+                            + ", service name: "
+                            + mName
+                            + ", service type: "
+                            + mType);
+        }
+
+        @Override
+        public void onServiceUpdated(@NonNull NsdServiceInfo serviceInfo) {
+            Log.i(
+                    TAG,
+                    "Service is resolved. "
+                            + " Listener ID: "
+                            + mListenerId
+                            + ", serviceInfo: "
+                            + serviceInfo);
+            List<String> addresses = new ArrayList<>();
+            for (InetAddress address : serviceInfo.getHostAddresses()) {
+                if (address instanceof Inet6Address) {
+                    addresses.add(address.getHostAddress());
+                }
+            }
+            List<DnsTxtAttribute> txtList = new ArrayList<>();
+            for (Map.Entry<String, byte[]> entry : serviceInfo.getAttributes().entrySet()) {
+                DnsTxtAttribute attribute = new DnsTxtAttribute();
+                attribute.name = entry.getKey();
+                attribute.value = Arrays.copyOf(entry.getValue(), entry.getValue().length);
+                txtList.add(attribute);
+            }
+            // TODO: b/329018320 - Use the serviceInfo.getExpirationTime to derive TTL.
+            int ttlSeconds = 10;
+            try {
+                mResolveServiceCallback.onServiceResolved(
+                        serviceInfo.getHostname(),
+                        serviceInfo.getServiceName(),
+                        serviceInfo.getServiceType(),
+                        serviceInfo.getPort(),
+                        addresses,
+                        txtList,
+                        ttlSeconds);
+
+            } catch (RemoteException e) {
+                // do nothing if the client is dead
+            }
+        }
+
+        @Override
+        public void onServiceLost() {}
+
+        @Override
+        public void onServiceInfoCallbackUnregistered() {
+            Log.i(TAG, "The service info callback is unregistered. Listener: " + this);
+            mServiceInfoListeners.remove(mListenerId);
+        }
+
+        @Override
+        public String toString() {
+            return "ID: " + mListenerId + ", service name: " + mName + ", service type: " + mType;
         }
     }
 }
diff --git a/thread/service/java/com/android/server/thread/OperationReceiverWrapper.java b/thread/service/java/com/android/server/thread/OperationReceiverWrapper.java
index a8909bc..bad63f3 100644
--- a/thread/service/java/com/android/server/thread/OperationReceiverWrapper.java
+++ b/thread/service/java/com/android/server/thread/OperationReceiverWrapper.java
@@ -16,9 +16,11 @@
 
 package com.android.server.thread;
 
+import static android.net.thread.ThreadNetworkException.ERROR_INTERNAL_ERROR;
 import static android.net.thread.ThreadNetworkException.ERROR_UNAVAILABLE;
 
 import android.net.thread.IOperationReceiver;
+import android.net.thread.ThreadNetworkException;
 import android.os.RemoteException;
 
 import com.android.internal.annotations.GuardedBy;
@@ -29,6 +31,7 @@
 /** A {@link IOperationReceiver} wrapper which makes it easier to invoke the callbacks. */
 final class OperationReceiverWrapper {
     private final IOperationReceiver mReceiver;
+    private final boolean mExpectOtDaemonDied;
 
     private static final Object sPendingReceiversLock = new Object();
 
@@ -36,7 +39,19 @@
     private static final Set<OperationReceiverWrapper> sPendingReceivers = new HashSet<>();
 
     public OperationReceiverWrapper(IOperationReceiver receiver) {
-        this.mReceiver = receiver;
+        this(receiver, false /* expectOtDaemonDied */);
+    }
+
+    /**
+     * Creates a new {@link OperationReceiverWrapper}.
+     *
+     * <p>If {@code expectOtDaemonDied} is {@code true}, it's expected that ot-daemon dies before
+     * {@code receiver} is completed with {@code onSuccess} or {@code onError}; in that case
+     * {@code receiver#onSuccess} will be invoked when ot-daemon dies.
+     */
+    public OperationReceiverWrapper(IOperationReceiver receiver, boolean expectOtDaemonDied) {
+        mReceiver = receiver;
+        mExpectOtDaemonDied = expectOtDaemonDied;
 
         synchronized (sPendingReceiversLock) {
             sPendingReceivers.add(this);
@@ -47,7 +62,11 @@
         synchronized (sPendingReceiversLock) {
             for (OperationReceiverWrapper receiver : sPendingReceivers) {
                 try {
-                    receiver.mReceiver.onError(ERROR_UNAVAILABLE, "Thread daemon died");
+                    if (receiver.mExpectOtDaemonDied) {
+                        receiver.mReceiver.onSuccess();
+                    } else {
+                        receiver.mReceiver.onError(ERROR_UNAVAILABLE, "Thread daemon died");
+                    }
                 } catch (RemoteException e) {
                     // The client is dead, do nothing
                 }
@@ -68,6 +87,17 @@
         }
     }
 
+    public void onError(Throwable e) {
+        if (e instanceof ThreadNetworkException) {
+            ThreadNetworkException threadException = (ThreadNetworkException) e;
+            onError(threadException.getErrorCode(), threadException.getMessage());
+        } else if (e instanceof RemoteException) {
+            onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+        } else {
+            throw new AssertionError(e);
+        }
+    }
+
     public void onError(int errorCode, String errorMessage, Object... messageArgs) {
         synchronized (sPendingReceiversLock) {
             sPendingReceivers.remove(this);
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
index 1b36d2b..d80dcfb 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
@@ -40,6 +40,7 @@
 import static android.net.thread.ThreadNetworkException.ERROR_THREAD_DISABLED;
 import static android.net.thread.ThreadNetworkException.ERROR_TIMEOUT;
 import static android.net.thread.ThreadNetworkException.ERROR_UNSUPPORTED_CHANNEL;
+import static android.net.thread.ThreadNetworkException.ERROR_UNSUPPORTED_OPERATION;
 import static android.net.thread.ThreadNetworkManager.DISALLOW_THREAD_NETWORK;
 import static android.net.thread.ThreadNetworkManager.PERMISSION_THREAD_NETWORK_PRIVILEGED;
 
@@ -47,6 +48,7 @@
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_BUSY;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_FAILED_PRECONDITION;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_INVALID_STATE;
+import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_NOT_IMPLEMENTED;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_NO_BUFS;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_PARSE;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_REASSEMBLY_TIMEOUT;
@@ -68,6 +70,7 @@
 import android.content.Context;
 import android.content.Intent;
 import android.content.IntentFilter;
+import android.content.res.Resources;
 import android.net.ConnectivityManager;
 import android.net.InetAddresses;
 import android.net.LinkAddress;
@@ -85,6 +88,7 @@
 import android.net.TestNetworkSpecifier;
 import android.net.thread.ActiveOperationalDataset;
 import android.net.thread.ActiveOperationalDataset.SecurityPolicy;
+import android.net.thread.ChannelMaxPower;
 import android.net.thread.IActiveOperationalDatasetReceiver;
 import android.net.thread.IOperationReceiver;
 import android.net.thread.IOperationalDatasetCallback;
@@ -94,6 +98,7 @@
 import android.net.thread.PendingOperationalDataset;
 import android.net.thread.ThreadNetworkController;
 import android.net.thread.ThreadNetworkController.DeviceRole;
+import android.net.thread.ThreadNetworkException;
 import android.net.thread.ThreadNetworkException.ErrorCode;
 import android.os.Build;
 import android.os.Handler;
@@ -106,8 +111,10 @@
 import android.util.Log;
 import android.util.SparseArray;
 
+import com.android.connectivity.resources.R;
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.server.ServiceManagerWrapper;
+import com.android.server.connectivity.ConnectivityResources;
 import com.android.server.thread.openthread.BackboneRouterState;
 import com.android.server.thread.openthread.BorderRouterConfigurationParcel;
 import com.android.server.thread.openthread.IChannelMasksReceiver;
@@ -115,12 +122,16 @@
 import com.android.server.thread.openthread.IOtDaemonCallback;
 import com.android.server.thread.openthread.IOtStatusReceiver;
 import com.android.server.thread.openthread.Ipv6AddressInfo;
+import com.android.server.thread.openthread.MeshcopTxtAttributes;
 import com.android.server.thread.openthread.OtDaemonState;
 
+import libcore.util.HexEncoding;
+
 import java.io.IOException;
 import java.net.Inet6Address;
 import java.net.InetAddress;
 import java.net.UnknownHostException;
+import java.nio.charset.StandardCharsets;
 import java.security.SecureRandom;
 import java.time.Instant;
 import java.util.HashMap;
@@ -129,6 +140,7 @@
 import java.util.Objects;
 import java.util.Random;
 import java.util.function.Supplier;
+import java.util.regex.Pattern;
 
 /**
  * Implementation of the {@link ThreadNetworkController} API.
@@ -143,6 +155,16 @@
 final class ThreadNetworkControllerService extends IThreadNetworkController.Stub {
     private static final String TAG = "ThreadNetworkService";
 
+    // The max model name length in utf-8 bytes
+    private static final int MAX_MODEL_NAME_UTF8_BYTES = 24;
+
+    // The max vendor name length in utf-8 bytes
+    private static final int MAX_VENDOR_NAME_UTF8_BYTES = 24;
+
+    // This regex pattern allows "XXXXXX", "XX:XX:XX" and "XX-XX-XX" OUI formats.
+    // Note that this regex allows "XX:XX-XX" as well but we don't need to be a strict checker
+    private static final String OUI_REGEX = "^([0-9A-Fa-f]{2}[:-]?){2}([0-9A-Fa-f]{2})$";
+
     // Below member fields can be accessed from both the binder and handler threads
 
     private final Context mContext;
@@ -159,7 +181,11 @@
     private final InfraInterfaceController mInfraIfController;
     private final NsdPublisher mNsdPublisher;
     private final OtDaemonCallbackProxy mOtDaemonCallbackProxy = new OtDaemonCallbackProxy();
+    private final ConnectivityResources mResources;
+    private final Supplier<String> mCountryCodeSupplier;
 
+    // This should not be directly used for calling IOtDaemon APIs because ot-daemon may die and
+    // {@code mOtDaemon} will be set to {@code null}. Instead, use {@code getOtDaemon()}
     @Nullable private IOtDaemon mOtDaemon;
     @Nullable private NetworkAgent mNetworkAgent;
     @Nullable private NetworkAgent mTestNetworkAgent;
@@ -174,6 +200,7 @@
     private final ThreadPersistentSettings mPersistentSettings;
     private final UserManager mUserManager;
     private boolean mUserRestricted;
+    private boolean mForceStopOtDaemonEnabled;
 
     private BorderRouterConfigurationParcel mBorderRouterConfig;
 
@@ -188,7 +215,9 @@
             InfraInterfaceController infraIfController,
             ThreadPersistentSettings persistentSettings,
             NsdPublisher nsdPublisher,
-            UserManager userManager) {
+            UserManager userManager,
+            ConnectivityResources resources,
+            Supplier<String> countryCodeSupplier) {
         mContext = context;
         mHandler = handler;
         mNetworkProvider = networkProvider;
@@ -202,10 +231,14 @@
         mPersistentSettings = persistentSettings;
         mNsdPublisher = nsdPublisher;
         mUserManager = userManager;
+        mResources = resources;
+        mCountryCodeSupplier = countryCodeSupplier;
     }
 
     public static ThreadNetworkControllerService newInstance(
-            Context context, ThreadPersistentSettings persistentSettings) {
+            Context context,
+            ThreadPersistentSettings persistentSettings,
+            Supplier<String> countryCodeSupplier) {
         HandlerThread handlerThread = new HandlerThread("ThreadHandlerThread");
         handlerThread.start();
         Handler handler = new Handler(handlerThread.getLooper());
@@ -222,7 +255,9 @@
                 new InfraInterfaceController(),
                 persistentSettings,
                 NsdPublisher.newInstance(context, handler),
-                context.getSystemService(UserManager.class));
+                context.getSystemService(UserManager.class),
+                new ConnectivityResources(context),
+                countryCodeSupplier);
     }
 
     private static Inet6Address bytesToInet6Address(byte[] addressBytes) {
@@ -279,17 +314,31 @@
                 .build();
     }
 
-    private void initializeOtDaemon() {
+    private void maybeInitializeOtDaemon() {
+        if (!isEnabled()) {
+            return;
+        }
+
+        Log.i(TAG, "Starting OT daemon...");
+
         try {
             getOtDaemon();
         } catch (RemoteException e) {
-            Log.e(TAG, "Failed to initialize ot-daemon");
+            Log.e(TAG, "Failed to initialize ot-daemon", e);
+        } catch (ThreadNetworkException e) {
+            // no ThreadNetworkException.ERROR_THREAD_DISABLED error should be thrown
+            throw new AssertionError(e);
         }
     }
 
-    private IOtDaemon getOtDaemon() throws RemoteException {
+    private IOtDaemon getOtDaemon() throws RemoteException, ThreadNetworkException {
         checkOnHandlerThread();
 
+        if (mForceStopOtDaemonEnabled) {
+            throw new ThreadNetworkException(
+                    ERROR_THREAD_DISABLED, "ot-daemon is forcibly stopped");
+        }
+
         if (mOtDaemon != null) {
             return mOtDaemon;
         }
@@ -298,29 +347,75 @@
         if (otDaemon == null) {
             throw new RemoteException("Internal error: failed to start OT daemon");
         }
-        otDaemon.initialize(mTunIfController.getTunFd(), isEnabled(), mNsdPublisher);
-        otDaemon.registerStateCallback(mOtDaemonCallbackProxy, -1);
+
+        otDaemon.initialize(
+                mTunIfController.getTunFd(),
+                isEnabled(),
+                mNsdPublisher,
+                getMeshcopTxtAttributes(mResources.get()),
+                mOtDaemonCallbackProxy,
+                mCountryCodeSupplier.get());
         otDaemon.asBinder().linkToDeath(() -> mHandler.post(this::onOtDaemonDied), 0);
         mOtDaemon = otDaemon;
         return mOtDaemon;
     }
 
+    @VisibleForTesting
+    static MeshcopTxtAttributes getMeshcopTxtAttributes(Resources resources) {
+        final String modelName = resources.getString(R.string.config_thread_model_name);
+        final String vendorName = resources.getString(R.string.config_thread_vendor_name);
+        final String vendorOui = resources.getString(R.string.config_thread_vendor_oui);
+
+        if (!modelName.isEmpty()) {
+            if (modelName.getBytes(StandardCharsets.UTF_8).length > MAX_MODEL_NAME_UTF8_BYTES) {
+                throw new IllegalStateException(
+                        "Model name is longer than "
+                                + MAX_MODEL_NAME_UTF8_BYTES
+                                + " utf-8 bytes: "
+                                + modelName);
+            }
+        }
+
+        if (!vendorName.isEmpty()) {
+            if (vendorName.getBytes(StandardCharsets.UTF_8).length > MAX_VENDOR_NAME_UTF8_BYTES) {
+                throw new IllegalStateException(
+                        "Vendor name is longer than "
+                                + MAX_VENDOR_NAME_UTF8_BYTES
+                                + " utf-8 bytes: "
+                                + vendorName);
+            }
+        }
+
+        if (!vendorOui.isEmpty() && !Pattern.compile(OUI_REGEX).matcher(vendorOui).matches()) {
+            throw new IllegalStateException("Vendor OUI is invalid: " + vendorOui);
+        }
+
+        MeshcopTxtAttributes meshcopTxts = new MeshcopTxtAttributes();
+        meshcopTxts.modelName = modelName;
+        meshcopTxts.vendorName = vendorName;
+        meshcopTxts.vendorOui = HexEncoding.decode(vendorOui.replace("-", "").replace(":", ""));
+        return meshcopTxts;
+    }
+
     private void onOtDaemonDied() {
         checkOnHandlerThread();
-        Log.w(TAG, "OT daemon is dead, clean up and restart it...");
+        Log.w(TAG, "OT daemon is dead, clean up...");
 
         OperationReceiverWrapper.onOtDaemonDied();
         mOtDaemonCallbackProxy.onOtDaemonDied();
         mTunIfController.onOtDaemonDied();
         mNsdPublisher.onOtDaemonDied();
         mOtDaemon = null;
-        initializeOtDaemon();
+        maybeInitializeOtDaemon();
     }
 
     public void initialize() {
         mHandler.post(
                 () -> {
-                    Log.d(TAG, "Initializing Thread system service...");
+                    Log.d(
+                            TAG,
+                            "Initializing Thread system service: Thread is "
+                                    + (isEnabled() ? "enabled" : "disabled"));
                     try {
                         mTunIfController.createTunInterface();
                     } catch (IOException e) {
@@ -332,10 +427,59 @@
                     requestThreadNetwork();
                     mUserRestricted = isThreadUserRestricted();
                     registerUserRestrictionsReceiver();
-                    initializeOtDaemon();
+                    maybeInitializeOtDaemon();
                 });
     }
 
+    /**
+     * Force stops ot-daemon immediately and prevents ot-daemon from being restarted by
+     * system_server again.
+     *
+     * <p>This is for VTS testing only.
+     */
+    @RequiresPermission(PERMISSION_THREAD_NETWORK_PRIVILEGED)
+    void forceStopOtDaemonForTest(boolean enabled, @NonNull IOperationReceiver receiver) {
+        enforceAllPermissionsGranted(PERMISSION_THREAD_NETWORK_PRIVILEGED);
+
+        mHandler.post(
+                () ->
+                        forceStopOtDaemonForTestInternal(
+                                enabled,
+                                new OperationReceiverWrapper(
+                                        receiver, true /* expectOtDaemonDied */)));
+    }
+
+    private void forceStopOtDaemonForTestInternal(
+            boolean enabled, @NonNull OperationReceiverWrapper receiver) {
+        checkOnHandlerThread();
+        if (enabled == mForceStopOtDaemonEnabled) {
+            receiver.onSuccess();
+            return;
+        }
+
+        if (!enabled) {
+            mForceStopOtDaemonEnabled = false;
+            maybeInitializeOtDaemon();
+            receiver.onSuccess();
+            return;
+        }
+
+        try {
+            getOtDaemon().terminate();
+            // Do not invoke the {@code receiver} callback here but wait for ot-daemon to
+            // become dead, so that it's guaranteed that ot-daemon is stopped when {@code
+            // receiver} is completed
+        } catch (RemoteException e) {
+            Log.e(TAG, "otDaemon.terminate failed", e);
+            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+        } catch (ThreadNetworkException e) {
+            // No ThreadNetworkException.ERROR_THREAD_DISABLED error will be thrown
+            throw new AssertionError(e);
+        } finally {
+            mForceStopOtDaemonEnabled = true;
+        }
+    }
+
     public void setEnabled(boolean isEnabled, @NonNull IOperationReceiver receiver) {
         enforceAllPermissionsGranted(PERMISSION_THREAD_NETWORK_PRIVILEGED);
 
@@ -356,6 +500,8 @@
             return;
         }
 
+        Log.i(TAG, "Set Thread enabled: " + isEnabled + ", persist: " + persist);
+
         if (persist) {
             // The persistent setting keeps the desired enabled state, thus it's set regardless
             // the otDaemon set enabled state operation succeeded or not, so that it can recover
@@ -365,9 +511,9 @@
 
         try {
             getOtDaemon().setThreadEnabled(isEnabled, newOtStatusReceiver(receiver));
-        } catch (RemoteException e) {
+        } catch (RemoteException | ThreadNetworkException e) {
             Log.e(TAG, "otDaemon.setThreadEnabled failed", e);
-            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+            receiver.onError(e);
         }
     }
 
@@ -424,7 +570,9 @@
 
     /** Returns {@code true} if Thread is set enabled. */
     private boolean isEnabled() {
-        return !mUserRestricted && mPersistentSettings.get(ThreadPersistentSettings.THREAD_ENABLED);
+        return !mForceStopOtDaemonEnabled
+                && !mUserRestricted
+                && mPersistentSettings.get(ThreadPersistentSettings.THREAD_ENABLED);
     }
 
     /** Returns {@code true} if Thread has been restricted for the user. */
@@ -612,9 +760,9 @@
 
         try {
             getOtDaemon().getChannelMasks(newChannelMasksReceiver(networkName, receiver));
-        } catch (RemoteException e) {
+        } catch (RemoteException | ThreadNetworkException e) {
             Log.e(TAG, "otDaemon.getChannelMasks failed", e);
-            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+            receiver.onError(e);
         }
     }
 
@@ -786,6 +934,8 @@
                 return ERROR_ABORTED;
             case OT_ERROR_BUSY:
                 return ERROR_BUSY;
+            case OT_ERROR_NOT_IMPLEMENTED:
+                return ERROR_UNSUPPORTED_OPERATION;
             case OT_ERROR_NO_BUFS:
                 return ERROR_RESOURCE_EXHAUSTED;
             case OT_ERROR_PARSE:
@@ -824,9 +974,9 @@
         try {
             // The otDaemon.join() will leave first if this device is currently attached
             getOtDaemon().join(activeDataset.toThreadTlvs(), newOtStatusReceiver(receiver));
-        } catch (RemoteException e) {
+        } catch (RemoteException | ThreadNetworkException e) {
             Log.e(TAG, "otDaemon.join failed", e);
-            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+            receiver.onError(e);
         }
     }
 
@@ -849,9 +999,9 @@
             getOtDaemon()
                     .scheduleMigration(
                             pendingDataset.toThreadTlvs(), newOtStatusReceiver(receiver));
-        } catch (RemoteException e) {
+        } catch (RemoteException | ThreadNetworkException e) {
             Log.e(TAG, "otDaemon.scheduleMigration failed", e);
-            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+            receiver.onError(e);
         }
     }
 
@@ -867,9 +1017,9 @@
 
         try {
             getOtDaemon().leave(newOtStatusReceiver(receiver));
-        } catch (RemoteException e) {
-            // Oneway AIDL API should never throw?
-            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+        } catch (RemoteException | ThreadNetworkException e) {
+            Log.e(TAG, "otDaemon.leave failed", e);
+            receiver.onError(e);
         }
     }
 
@@ -891,11 +1041,18 @@
             String countryCode, @NonNull OperationReceiverWrapper receiver) {
         checkOnHandlerThread();
 
+        // Fails early to avoid waking up ot-daemon by the ThreadNetworkCountryCode class
+        if (!isEnabled()) {
+            receiver.onError(
+                    ERROR_THREAD_DISABLED, "Can't set country code when Thread is disabled");
+            return;
+        }
+
         try {
             getOtDaemon().setCountryCode(countryCode, newOtStatusReceiver(receiver));
-        } catch (RemoteException e) {
+        } catch (RemoteException | ThreadNetworkException e) {
             Log.e(TAG, "otDaemon.setCountryCode failed", e);
-            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+            receiver.onError(e);
         }
     }
 
@@ -931,6 +1088,30 @@
         }
     }
 
+    @RequiresPermission(PERMISSION_THREAD_NETWORK_PRIVILEGED)
+    public void setChannelMaxPowers(
+            @NonNull ChannelMaxPower[] channelMaxPowers, @NonNull IOperationReceiver receiver) {
+        enforceAllPermissionsGranted(PERMISSION_THREAD_NETWORK_PRIVILEGED);
+
+        mHandler.post(
+                () ->
+                        setChannelMaxPowersInternal(
+                                channelMaxPowers, new OperationReceiverWrapper(receiver)));
+    }
+
+    private void setChannelMaxPowersInternal(
+            @NonNull ChannelMaxPower[] channelMaxPowers,
+            @NonNull OperationReceiverWrapper receiver) {
+        checkOnHandlerThread();
+
+        try {
+            getOtDaemon().setChannelMaxPowers(channelMaxPowers, newOtStatusReceiver(receiver));
+        } catch (RemoteException | ThreadNetworkException e) {
+            Log.e(TAG, "otDaemon.setChannelMaxPowers failed", e);
+            receiver.onError(e);
+        }
+    }
+
     private void enableBorderRouting(String infraIfName) {
         if (mBorderRouterConfig.isBorderRoutingEnabled
                 && infraIfName.equals(mBorderRouterConfig.infraInterfaceName)) {
@@ -943,9 +1124,10 @@
                     mInfraIfController.createIcmp6Socket(infraIfName);
             mBorderRouterConfig.isBorderRoutingEnabled = true;
 
-            mOtDaemon.configureBorderRouter(
-                    mBorderRouterConfig, new ConfigureBorderRouterStatusReceiver());
-        } catch (RemoteException | IOException e) {
+            getOtDaemon()
+                    .configureBorderRouter(
+                            mBorderRouterConfig, new ConfigureBorderRouterStatusReceiver());
+        } catch (RemoteException | IOException | ThreadNetworkException e) {
             Log.w(TAG, "Failed to enable border routing", e);
         }
     }
@@ -956,9 +1138,10 @@
         mBorderRouterConfig.infraInterfaceIcmp6Socket = null;
         mBorderRouterConfig.isBorderRoutingEnabled = false;
         try {
-            mOtDaemon.configureBorderRouter(
-                    mBorderRouterConfig, new ConfigureBorderRouterStatusReceiver());
-        } catch (RemoteException e) {
+            getOtDaemon()
+                    .configureBorderRouter(
+                            mBorderRouterConfig, new ConfigureBorderRouterStatusReceiver());
+        } catch (RemoteException | ThreadNetworkException e) {
             Log.w(TAG, "Failed to disable border routing", e);
         }
     }
@@ -1126,8 +1309,8 @@
 
             try {
                 getOtDaemon().registerStateCallback(this, callbackMetadata.id);
-            } catch (RemoteException e) {
-                // oneway operation should never fail
+            } catch (RemoteException | ThreadNetworkException e) {
+                Log.e(TAG, "otDaemon.registerStateCallback failed", e);
             }
         }
 
@@ -1167,8 +1350,8 @@
 
             try {
                 getOtDaemon().registerStateCallback(this, callbackMetadata.id);
-            } catch (RemoteException e) {
-                // oneway operation should never fail
+            } catch (RemoteException | ThreadNetworkException e) {
+                Log.e(TAG, "otDaemon.registerStateCallback failed", e);
             }
         }
 
@@ -1187,8 +1370,11 @@
                 return;
             }
 
+            final int deviceRole = mState.deviceRole;
+            mState = null;
+
             // If this device is already STOPPED or DETACHED, do nothing
-            if (!ThreadNetworkController.isAttached(mState.deviceRole)) {
+            if (!ThreadNetworkController.isAttached(deviceRole)) {
                 return;
             }
 
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java b/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
index ffa7b44..a194114 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
@@ -16,6 +16,8 @@
 
 package com.android.server.thread;
 
+import static com.android.server.thread.ThreadPersistentSettings.THREAD_COUNTRY_CODE;
+
 import android.annotation.Nullable;
 import android.annotation.StringDef;
 import android.annotation.TargetApi;
@@ -83,6 +85,7 @@
                 COUNTRY_CODE_SOURCE_TELEPHONY,
                 COUNTRY_CODE_SOURCE_TELEPHONY_LAST,
                 COUNTRY_CODE_SOURCE_WIFI,
+                COUNTRY_CODE_SOURCE_SETTINGS,
             })
     private @interface CountryCodeSource {}
 
@@ -93,6 +96,7 @@
     private static final String COUNTRY_CODE_SOURCE_TELEPHONY = "Telephony";
     private static final String COUNTRY_CODE_SOURCE_TELEPHONY_LAST = "TelephonyLast";
     private static final String COUNTRY_CODE_SOURCE_WIFI = "Wifi";
+    private static final String COUNTRY_CODE_SOURCE_SETTINGS = "Settings";
 
     private static final CountryCodeInfo DEFAULT_COUNTRY_CODE_INFO =
             new CountryCodeInfo(DEFAULT_COUNTRY_CODE, COUNTRY_CODE_SOURCE_DEFAULT);
@@ -107,6 +111,7 @@
     private final SubscriptionManager mSubscriptionManager;
     private final Map<Integer, TelephonyCountryCodeSlotInfo> mTelephonyCountryCodeSlotInfoMap =
             new ArrayMap();
+    private final ThreadPersistentSettings mPersistentSettings;
 
     @Nullable private CountryCodeInfo mCurrentCountryCodeInfo;
     @Nullable private CountryCodeInfo mLocationCountryCodeInfo;
@@ -215,7 +220,8 @@
             Context context,
             TelephonyManager telephonyManager,
             SubscriptionManager subscriptionManager,
-            @Nullable String oemCountryCode) {
+            @Nullable String oemCountryCode,
+            ThreadPersistentSettings persistentSettings) {
         mLocationManager = locationManager;
         mThreadNetworkControllerService = threadNetworkControllerService;
         mGeocoder = geocoder;
@@ -224,14 +230,19 @@
         mContext = context;
         mTelephonyManager = telephonyManager;
         mSubscriptionManager = subscriptionManager;
+        mPersistentSettings = persistentSettings;
 
         if (oemCountryCode != null) {
             mOemCountryCodeInfo = new CountryCodeInfo(oemCountryCode, COUNTRY_CODE_SOURCE_OEM);
         }
+
+        mCurrentCountryCodeInfo = pickCountryCode();
     }
 
     public static ThreadNetworkCountryCode newInstance(
-            Context context, ThreadNetworkControllerService controllerService) {
+            Context context,
+            ThreadNetworkControllerService controllerService,
+            ThreadPersistentSettings persistentSettings) {
         return new ThreadNetworkCountryCode(
                 context.getSystemService(LocationManager.class),
                 controllerService,
@@ -241,7 +252,8 @@
                 context,
                 context.getSystemService(TelephonyManager.class),
                 context.getSystemService(SubscriptionManager.class),
-                ThreadNetworkProperties.country_code().orElse(null));
+                ThreadNetworkProperties.country_code().orElse(null),
+                persistentSettings);
     }
 
     /** Sets up this country code module to listen to location country code changes. */
@@ -485,6 +497,11 @@
             return mLocationCountryCodeInfo;
         }
 
+        String settingsCountryCode = mPersistentSettings.get(THREAD_COUNTRY_CODE);
+        if (settingsCountryCode != null) {
+            return new CountryCodeInfo(settingsCountryCode, COUNTRY_CODE_SOURCE_SETTINGS);
+        }
+
         if (mOemCountryCodeInfo != null) {
             return mOemCountryCodeInfo;
         }
@@ -498,6 +515,8 @@
             public void onSuccess() {
-                synchronized ("ThreadNetworkCountryCode.this") {
+                synchronized (ThreadNetworkCountryCode.this) { // same monitor as getCountryCode()
                     mCurrentCountryCodeInfo = countryCodeInfo;
+                    mPersistentSettings.put(
+                            THREAD_COUNTRY_CODE.key, countryCodeInfo.getCountryCode());
                 }
             }
 
@@ -536,10 +555,9 @@
                 newOperationReceiver(countryCodeInfo));
     }
 
-    /** Returns the current country code or {@code null} if no country code is set. */
-    @Nullable
+    /** Returns the current country code. */
     public synchronized String getCountryCode() {
-        return (mCurrentCountryCodeInfo != null) ? mCurrentCountryCodeInfo.getCountryCode() : null;
+        return mCurrentCountryCodeInfo.getCountryCode();
     }
 
     /**
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkService.java b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
index 5664922..30c67ca 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
@@ -18,6 +18,8 @@
 
 import static android.content.pm.PackageManager.PERMISSION_GRANTED;
 
+import static java.util.Objects.requireNonNull;
+
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.content.Context;
@@ -58,15 +60,19 @@
         if (phase == SystemService.PHASE_SYSTEM_SERVICES_READY) {
             mPersistentSettings.initialize();
             mControllerService =
-                    ThreadNetworkControllerService.newInstance(mContext, mPersistentSettings);
+                    ThreadNetworkControllerService.newInstance(
+                            mContext, mPersistentSettings, () -> mCountryCode.getCountryCode());
+            mCountryCode =
+                    ThreadNetworkCountryCode.newInstance(
+                            mContext, mControllerService, mPersistentSettings);
             mControllerService.initialize();
         } else if (phase == SystemService.PHASE_BOOT_COMPLETED) {
             // Country code initialization is delayed to the BOOT_COMPLETED phase because it will
             // call into Wi-Fi and Telephony service whose country code module is ready after
             // PHASE_ACTIVITY_MANAGER_READY and PHASE_THIRD_PARTY_APPS_CAN_START
-            mCountryCode = ThreadNetworkCountryCode.newInstance(mContext, mControllerService);
             mCountryCode.initialize();
-            mShellCommand = new ThreadNetworkShellCommand(mCountryCode);
+            mShellCommand =
+                    new ThreadNetworkShellCommand(requireNonNull(mControllerService), mCountryCode);
         }
     }
 
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java b/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java
index c17c5a7..c6a1618 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java
@@ -16,7 +16,10 @@
 
 package com.android.server.thread;
 
+import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.net.thread.IOperationReceiver;
+import android.net.thread.ThreadNetworkException;
 import android.os.Binder;
 import android.os.Process;
 import android.text.TextUtils;
@@ -25,7 +28,12 @@
 import com.android.modules.utils.BasicShellCommandHandler;
 
 import java.io.PrintWriter;
+import java.time.Duration;
 import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 
 /**
  * Interprets and executes 'adb shell cmd thread_network [args]'.
@@ -37,16 +45,22 @@
  * corresponding API permissions.
  */
 public class ThreadNetworkShellCommand extends BasicShellCommandHandler {
-    private static final String TAG = "ThreadNetworkShellCommand";
+    private static final Duration SET_ENABLED_TIMEOUT = Duration.ofSeconds(2);
+    private static final Duration FORCE_STOP_TIMEOUT = Duration.ofSeconds(1);
 
     // These don't require root access.
-    private static final List<String> NON_PRIVILEGED_COMMANDS = List.of("help", "get-country-code");
+    private static final List<String> NON_PRIVILEGED_COMMANDS =
+            List.of("help", "get-country-code", "enable", "disable");
 
-    @Nullable private final ThreadNetworkCountryCode mCountryCode;
+    @NonNull private final ThreadNetworkControllerService mControllerService;
+    @NonNull private final ThreadNetworkCountryCode mCountryCode;
     @Nullable private PrintWriter mOutputWriter;
     @Nullable private PrintWriter mErrorWriter;
 
-    ThreadNetworkShellCommand(@Nullable ThreadNetworkCountryCode countryCode) {
+    ThreadNetworkShellCommand(
+            @NonNull ThreadNetworkControllerService controllerService,
+            @NonNull ThreadNetworkCountryCode countryCode) {
+        mControllerService = controllerService;
         mCountryCode = countryCode;
     }
 
@@ -91,14 +105,14 @@
         }
 
         switch (cmd) {
+            case "enable":
+                return setThreadEnabled(true);
+            case "disable":
+                return setThreadEnabled(false);
+            case "force-stop-ot-daemon":
+                return forceStopOtDaemon();
             case "force-country-code":
                 boolean enabled;
-
-                if (mCountryCode == null) {
-                    perr.println("Thread country code operations are not supported");
-                    return -1;
-                }
-
                 try {
                     enabled = getNextArgRequiredTrueOrFalse("enabled", "disabled");
                 } catch (IllegalArgumentException e) {
@@ -124,11 +138,6 @@
                 }
                 return 0;
             case "get-country-code":
-                if (mCountryCode == null) {
-                    perr.println("Thread country code operations are not supported");
-                    return -1;
-                }
-
                 pw.println("Thread country code = " + mCountryCode.getCountryCode());
                 return 0;
             default:
@@ -136,6 +145,64 @@
         }
     }
 
+    private int setThreadEnabled(boolean enabled) {
+        CompletableFuture<Void> setEnabledFuture = new CompletableFuture<>();
+        mControllerService.setEnabled(enabled, newOperationReceiver(setEnabledFuture));
+        return waitForFuture(setEnabledFuture, SET_ENABLED_TIMEOUT, getErrorWriter());
+    }
+
+    private int forceStopOtDaemon() {
+        final PrintWriter errorWriter = getErrorWriter();
+        boolean enabled;
+        try {
+            enabled = getNextArgRequiredTrueOrFalse("enabled", "disabled");
+        } catch (IllegalArgumentException e) {
+            errorWriter.println("Invalid argument: " + e.getMessage());
+            return -1;
+        }
+
+        CompletableFuture<Void> forceStopFuture = new CompletableFuture<>();
+        mControllerService.forceStopOtDaemonForTest(enabled, newOperationReceiver(forceStopFuture));
+        return waitForFuture(forceStopFuture, FORCE_STOP_TIMEOUT, getErrorWriter());
+    }
+
+    private static IOperationReceiver newOperationReceiver(CompletableFuture<Void> future) {
+        return new IOperationReceiver.Stub() {
+            @Override
+            public void onSuccess() {
+                future.complete(null);
+            }
+
+            @Override
+            public void onError(int errorCode, String errorMessage) {
+                future.completeExceptionally(new ThreadNetworkException(errorCode, errorMessage));
+            }
+        };
+    }
+
+    /**
+     * Waits for the future to complete within given timeout (honored at millisecond precision).
+     *
+     * <p>Returns 0 if {@code future} completed successfully, or -1 if {@code future} failed to
+     * complete. When failed, error messages are printed to {@code errorWriter}.
+     */
+    private int waitForFuture(
+            CompletableFuture<Void> future, Duration timeout, PrintWriter errorWriter) {
+        try {
+            future.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
+            return 0;
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            errorWriter.println("Failed: " + e.getMessage());
+        } catch (ExecutionException e) {
+            errorWriter.println("Failed: " + e.getCause().getMessage());
+        } catch (TimeoutException e) {
+            errorWriter.println("Failed: command timeout for " + timeout);
+        }
+
+        return -1;
+    }
+
     private static boolean argTrueOrFalse(String arg, String trueString, String falseString) {
         if (trueString.equals(arg)) {
             return true;
@@ -159,6 +226,10 @@
     }
 
     private void onHelpNonPrivileged(PrintWriter pw) {
+        pw.println("  enable");
+        pw.println("    Enables Thread radio");
+        pw.println("  disable");
+        pw.println("    Disables Thread radio");
         pw.println("  get-country-code");
         pw.println("    Gets country code as a two-letter string");
     }
@@ -166,6 +237,8 @@
     private void onHelpPrivileged(PrintWriter pw) {
         pw.println("  force-country-code enabled <two-letter code> | disabled ");
         pw.println("    Sets country code to <two-letter code> or left for normal value");
+        pw.println("  force-stop-ot-daemon enabled | disabled ");
+        pw.println("    force stop ot-daemon service");
     }
 
     @Override
diff --git a/thread/service/java/com/android/server/thread/ThreadPersistentSettings.java b/thread/service/java/com/android/server/thread/ThreadPersistentSettings.java
index 5cb53fe..8aaff60 100644
--- a/thread/service/java/com/android/server/thread/ThreadPersistentSettings.java
+++ b/thread/service/java/com/android/server/thread/ThreadPersistentSettings.java
@@ -61,7 +61,10 @@
 
     /******** Thread persistent setting keys ***************/
     /** Stores the Thread feature toggle state, true for enabled and false for disabled. */
-    public static final Key<Boolean> THREAD_ENABLED = new Key<>("Thread_enabled", true);
+    public static final Key<Boolean> THREAD_ENABLED = new Key<>("thread_enabled", true);
+
+    /** Stores the Thread country code, null if no country code is stored. */
+    public static final Key<String> THREAD_COUNTRY_CODE = new Key<>("thread_country_code", null);
 
     /******** Thread persistent setting keys ***************/
 
@@ -123,7 +126,9 @@
     private <T> T getObject(String key, T defaultValue) {
         Object value;
         synchronized (mLock) {
-            if (defaultValue instanceof Boolean) {
+            if (defaultValue == null) {
+                value = mSettings.getString(key, null);
+            } else if (defaultValue instanceof Boolean) {
                 value = mSettings.getBoolean(key, (Boolean) defaultValue);
             } else if (defaultValue instanceof Integer) {
                 value = mSettings.getInt(key, (Integer) defaultValue);
diff --git a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
index 0591c87..9a81388 100644
--- a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
+++ b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
@@ -864,11 +864,12 @@
     @Test
     public void meshcopService_threadDisabled_notDiscovered() throws Exception {
         setUpTestNetwork();
-
         CompletableFuture<NsdServiceInfo> serviceLostFuture = new CompletableFuture<>();
         NsdManager.DiscoveryListener listener =
                 discoverForServiceLost(MESHCOP_SERVICE_TYPE, serviceLostFuture);
+
         setEnabledAndWait(mController, false);
+
         try {
             serviceLostFuture.get(SERVICE_LOST_TIMEOUT_MILLIS, MILLISECONDS);
         } catch (InterruptedException | ExecutionException | TimeoutException ignored) {
@@ -877,7 +878,6 @@
         } finally {
             mNsdManager.stopServiceDiscovery(listener);
         }
-
         assertThrows(
                 TimeoutException.class,
                 () -> discoverService(MESHCOP_SERVICE_TYPE, SERVICE_LOST_TIMEOUT_MILLIS));
@@ -1112,7 +1112,12 @@
                         serviceInfoFuture.complete(serviceInfo);
                     }
                 };
-        mNsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, listener);
+        mNsdManager.discoverServices(
+                serviceType,
+                NsdManager.PROTOCOL_DNS_SD,
+                mTestNetworkTracker.getNetwork(),
+                mExecutor,
+                listener);
         try {
             serviceInfoFuture.get(timeoutMilliseconds, MILLISECONDS);
         } finally {
@@ -1131,7 +1136,12 @@
                         serviceInfoFuture.complete(serviceInfo);
                     }
                 };
-        mNsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, listener);
+        mNsdManager.discoverServices(
+                serviceType,
+                NsdManager.PROTOCOL_DNS_SD,
+                mTestNetworkTracker.getNetwork(),
+                mExecutor,
+                listener);
         return listener;
     }
 
diff --git a/thread/tests/integration/Android.bp b/thread/tests/integration/Android.bp
index 9677ec5..94985b1 100644
--- a/thread/tests/integration/Android.bp
+++ b/thread/tests/integration/Android.bp
@@ -30,6 +30,7 @@
         "net-tests-utils",
         "net-utils-device-common",
         "net-utils-device-common-bpf",
+        "net-utils-device-common-struct-base",
         "testables",
         "ThreadNetworkTestUtils",
         "truth",
diff --git a/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java b/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
index b40685c..9b1c338 100644
--- a/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
+++ b/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
@@ -17,10 +17,7 @@
 package android.net.thread;
 
 import static android.Manifest.permission.MANAGE_TEST_NETWORKS;
-import static android.Manifest.permission.NETWORK_SETTINGS;
-import static android.net.thread.ThreadNetworkManager.PERMISSION_THREAD_NETWORK_PRIVILEGED;
 import static android.net.thread.utils.IntegrationTestUtils.JOIN_TIMEOUT;
-import static android.net.thread.utils.IntegrationTestUtils.RESTART_JOIN_TIMEOUT;
 import static android.net.thread.utils.IntegrationTestUtils.isExpectedIcmpv6Packet;
 import static android.net.thread.utils.IntegrationTestUtils.isFromIpv6Source;
 import static android.net.thread.utils.IntegrationTestUtils.isInMulticastGroup;
@@ -36,14 +33,13 @@
 import static com.android.testutils.TestPermissionUtil.runAsShell;
 
 import static com.google.common.io.BaseEncoding.base16;
-import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 
-import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static java.util.Objects.requireNonNull;
 
 import android.content.Context;
 import android.net.InetAddresses;
@@ -56,6 +52,7 @@
 import android.net.thread.utils.ThreadFeatureCheckerRule.RequiresIpv6MulticastRouting;
 import android.net.thread.utils.ThreadFeatureCheckerRule.RequiresSimulationThreadDevice;
 import android.net.thread.utils.ThreadFeatureCheckerRule.RequiresThreadFeature;
+import android.net.thread.utils.ThreadNetworkControllerWrapper;
 import android.os.Handler;
 import android.os.HandlerThread;
 
@@ -76,9 +73,6 @@
 import java.time.Duration;
 import java.util.ArrayList;
 import java.util.List;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.TimeUnit;
 import java.util.function.Predicate;
 
 /** Integration test cases for Thread Border Routing feature. */
@@ -110,7 +104,8 @@
     @Rule public final ThreadFeatureCheckerRule mThreadRule = new ThreadFeatureCheckerRule();
 
     private final Context mContext = ApplicationProvider.getApplicationContext();
-    private ThreadNetworkController mController;
+    private final ThreadNetworkControllerWrapper mController =
+            ThreadNetworkControllerWrapper.newInstance(mContext);
     private OtDaemonController mOtCtl;
     private HandlerThread mHandlerThread;
     private Handler mHandler;
@@ -121,11 +116,6 @@
 
     @Before
     public void setUp() throws Exception {
-        final ThreadNetworkManager manager = mContext.getSystemService(ThreadNetworkManager.class);
-        if (manager != null) {
-            mController = manager.getAllThreadNetworkControllers().get(0);
-        }
-
         // TODO: b/323301831 - This is a workaround to avoid unnecessary delay to re-form a network
         mOtCtl = new OtDaemonController();
         mOtCtl.factoryReset();
@@ -136,13 +126,12 @@
         mFtds = new ArrayList<>();
 
         setUpInfraNetwork();
-
-        // BR forms a network.
-        startBrLeader();
+        mController.setEnabledAndWait(true);
+        mController.joinAndWait(DEFAULT_DATASET);
 
         // Creates a infra network device.
         mInfraNetworkReader = newPacketReader(mInfraNetworkTracker.getTestIface(), mHandler);
-        startInfraDevice();
+        startInfraDeviceAndWaitForOnLinkAddr();
 
         // Create Ftds
         for (int i = 0; i < NUM_FTD; ++i) {
@@ -152,16 +141,8 @@
 
     @After
     public void tearDown() throws Exception {
-        runAsShell(
-                PERMISSION_THREAD_NETWORK_PRIVILEGED,
-                NETWORK_SETTINGS,
-                () -> {
-                    CountDownLatch latch = new CountDownLatch(2);
-                    mController.setTestNetworkAsUpstream(
-                            null, directExecutor(), v -> latch.countDown());
-                    mController.leave(directExecutor(), v -> latch.countDown());
-                    latch.await(10, TimeUnit.SECONDS);
-                });
+        mController.setTestNetworkAsUpstreamAndWait(null);
+        mController.leaveAndWait();
         tearDownInfraNetwork();
 
         mHandlerThread.quitSafely();
@@ -184,15 +165,13 @@
          * </pre>
          */
 
-        // Let ftd join the network.
         FullThreadDevice ftd = mFtds.get(0);
         startFtdChild(ftd);
 
-        // Infra device sends an echo request to FTD's OMR.
         mInfraDevice.sendEchoRequest(ftd.getOmrAddress());
 
         // Infra device receives an echo reply sent by FTD.
-        assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, null /* srcAddress */));
+        assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
     }
 
     @Test
@@ -207,10 +186,7 @@
          * </pre>
          */
 
-        // Form the network.
-        mOtCtl.factoryReset();
-        startBrLeader();
-        startInfraDevice();
+        startInfraDeviceAndWaitForOnLinkAddr();
         FullThreadDevice ftd = mFtds.get(0);
         startFtdChild(ftd);
 
@@ -220,6 +196,36 @@
     }
 
     @Test
+    public void unicastRouting_afterInfraNetworkSwitchInfraDevicePingThreadDeviceOmr_replyReceived()
+            throws Exception {
+        /*
+         * <pre>
+         * Topology:
+         *                 infra network                       Thread
+         * infra device -------------------- Border Router -------------- Full Thread device
+         *                                   (Cuttlefish)
+         * </pre>
+         */
+
+        FullThreadDevice ftd = mFtds.get(0);
+        startFtdChild(ftd);
+        Inet6Address ftdOmr = ftd.getOmrAddress();
+        // Create a new infra network and let Thread prefer it
+        TestNetworkTracker oldInfraNetworkTracker = mInfraNetworkTracker;
+        try {
+            setUpInfraNetwork();
+            mInfraNetworkReader = newPacketReader(mInfraNetworkTracker.getTestIface(), mHandler);
+            startInfraDeviceAndWaitForOnLinkAddr();
+
+            mInfraDevice.sendEchoRequest(ftd.getOmrAddress());
+
+            assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftdOmr));
+        } finally {
+            runAsShell(MANAGE_TEST_NETWORKS, () -> oldInfraNetworkTracker.teardown());
+        }
+    }
+
+    @Test
     public void unicastRouting_borderRouterSendsUdpToThreadDevice_datagramReceived()
             throws Exception {
         /*
@@ -231,19 +237,10 @@
          * </pre>
          */
 
-        // BR forms a network.
-        CompletableFuture<Void> joinFuture = new CompletableFuture<>();
-        runAsShell(
-                PERMISSION_THREAD_NETWORK_PRIVILEGED,
-                () -> mController.join(DEFAULT_DATASET, directExecutor(), joinFuture::complete));
-        joinFuture.get(RESTART_JOIN_TIMEOUT.toMillis(), MILLISECONDS);
-
-        // Creates a Full Thread Device (FTD) and lets it join the network.
         FullThreadDevice ftd = mFtds.get(0);
         startFtdChild(ftd);
-        Inet6Address ftdOmr = ftd.getOmrAddress();
-        Inet6Address ftdMlEid = ftd.getMlEid();
-        assertNotNull(ftdMlEid);
+        Inet6Address ftdOmr = requireNonNull(ftd.getOmrAddress());
+        Inet6Address ftdMlEid = requireNonNull(ftd.getMlEid());
 
         ftd.udpBind(ftdOmr, 12345);
         sendUdpMessage(ftdOmr, 12345, "aaaaaaaa");
@@ -508,7 +505,7 @@
         Inet6Address ftdLla = ftd.getLinkLocalAddress();
         assertNotNull(ftdLla);
 
-        ftd.ping(GROUP_ADDR_SCOPE_4, ftdLla, 100 /* size */, 1 /* count */);
+        ftd.ping(GROUP_ADDR_SCOPE_4, ftdLla);
 
         assertNull(
                 pollForPacketOnInfraNetwork(ICMPV6_ECHO_REQUEST_TYPE, ftdLla, GROUP_ADDR_SCOPE_4));
@@ -532,7 +529,7 @@
         assertFalse(ftdMlas.isEmpty());
 
         for (Inet6Address ftdMla : ftdMlas) {
-            ftd.ping(GROUP_ADDR_SCOPE_4, ftdMla, 100 /* size */, 1 /* count */);
+            ftd.ping(GROUP_ADDR_SCOPE_4, ftdMla);
 
             assertNull(
                     pollForPacketOnInfraNetwork(
@@ -562,7 +559,7 @@
         tearDownInfraNetwork();
         setUpInfraNetwork();
         mInfraNetworkReader = newPacketReader(mInfraNetworkTracker.getTestIface(), mHandler);
-        startInfraDevice();
+        startInfraDeviceAndWaitForOnLinkAddr();
 
         mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_5);
 
@@ -589,7 +586,7 @@
         tearDownInfraNetwork();
         setUpInfraNetwork();
         mInfraNetworkReader = newPacketReader(mInfraNetworkTracker.getTestIface(), mHandler);
-        startInfraDevice();
+        startInfraDeviceAndWaitForOnLinkAddr();
 
         ftd.ping(GROUP_ADDR_SCOPE_4);
 
@@ -597,38 +594,21 @@
                 pollForPacketOnInfraNetwork(ICMPV6_ECHO_REQUEST_TYPE, ftdOmr, GROUP_ADDR_SCOPE_4));
     }
 
-    private void setUpInfraNetwork() {
+    private void setUpInfraNetwork() throws Exception {
         mInfraNetworkTracker =
                 runAsShell(
                         MANAGE_TEST_NETWORKS,
                         () ->
                                 initTestNetwork(
                                         mContext, new LinkProperties(), 5000 /* timeoutMs */));
-        runAsShell(
-                PERMISSION_THREAD_NETWORK_PRIVILEGED,
-                NETWORK_SETTINGS,
-                () -> {
-                    CompletableFuture<Void> future = new CompletableFuture<>();
-                    mController.setTestNetworkAsUpstream(
-                            mInfraNetworkTracker.getTestIface().getInterfaceName(),
-                            directExecutor(),
-                            future::complete);
-                    future.get(5, TimeUnit.SECONDS);
-                });
+        mController.setTestNetworkAsUpstreamAndWait(
+                mInfraNetworkTracker.getTestIface().getInterfaceName());
     }
 
     private void tearDownInfraNetwork() {
         runAsShell(MANAGE_TEST_NETWORKS, () -> mInfraNetworkTracker.teardown());
     }
 
-    private void startBrLeader() throws Exception {
-        CompletableFuture<Void> joinFuture = new CompletableFuture<>();
-        runAsShell(
-                PERMISSION_THREAD_NETWORK_PRIVILEGED,
-                () -> mController.join(DEFAULT_DATASET, directExecutor(), joinFuture::complete));
-        joinFuture.get(RESTART_JOIN_TIMEOUT.toSeconds(), TimeUnit.SECONDS);
-    }
-
     private void startFtdChild(FullThreadDevice ftd) throws Exception {
         ftd.factoryReset();
         ftd.joinNetwork(DEFAULT_DATASET);
@@ -638,7 +618,7 @@
         assertNotNull(ftdOmr);
     }
 
-    private void startInfraDevice() throws Exception {
+    private void startInfraDeviceAndWaitForOnLinkAddr() throws Exception {
         mInfraDevice =
                 new InfraNetworkDevice(MacAddress.fromString("1:2:3:4:5:6"), mInfraNetworkReader);
         mInfraDevice.runSlaac(Duration.ofSeconds(60));
diff --git a/thread/tests/integration/src/android/net/thread/ServiceDiscoveryTest.java b/thread/tests/integration/src/android/net/thread/ServiceDiscoveryTest.java
index 4e6ee5f..491331c 100644
--- a/thread/tests/integration/src/android/net/thread/ServiceDiscoveryTest.java
+++ b/thread/tests/integration/src/android/net/thread/ServiceDiscoveryTest.java
@@ -16,11 +16,9 @@
 
 package android.net.thread;
 
-import static android.Manifest.permission.NETWORK_SETTINGS;
 import static android.net.InetAddresses.parseNumericAddress;
-import static android.net.thread.ThreadNetworkManager.PERMISSION_THREAD_NETWORK_PRIVILEGED;
+import static android.net.nsd.NsdManager.PROTOCOL_DNS_SD;
 import static android.net.thread.utils.IntegrationTestUtils.JOIN_TIMEOUT;
-import static android.net.thread.utils.IntegrationTestUtils.RESTART_JOIN_TIMEOUT;
 import static android.net.thread.utils.IntegrationTestUtils.SERVICE_DISCOVERY_TIMEOUT;
 import static android.net.thread.utils.IntegrationTestUtils.discoverForServiceLost;
 import static android.net.thread.utils.IntegrationTestUtils.discoverService;
@@ -28,25 +26,24 @@
 import static android.net.thread.utils.IntegrationTestUtils.resolveServiceUntil;
 import static android.net.thread.utils.IntegrationTestUtils.waitFor;
 
-import static com.android.testutils.TestPermissionUtil.runAsShell;
-
 import static com.google.common.io.BaseEncoding.base16;
 import static com.google.common.truth.Truth.assertThat;
-import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
 
 import static org.junit.Assert.assertThrows;
 
+import static java.nio.charset.StandardCharsets.UTF_8;
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
-import static java.util.concurrent.TimeUnit.SECONDS;
 
 import android.content.Context;
 import android.net.nsd.NsdManager;
 import android.net.nsd.NsdServiceInfo;
 import android.net.thread.utils.FullThreadDevice;
+import android.net.thread.utils.OtDaemonController;
 import android.net.thread.utils.TapTestNetworkTracker;
 import android.net.thread.utils.ThreadFeatureCheckerRule;
 import android.net.thread.utils.ThreadFeatureCheckerRule.RequiresSimulationThreadDevice;
 import android.net.thread.utils.ThreadFeatureCheckerRule.RequiresThreadFeature;
+import android.net.thread.utils.ThreadNetworkControllerWrapper;
 import android.os.HandlerThread;
 
 import androidx.test.core.app.ApplicationProvider;
@@ -70,6 +67,7 @@
 import java.util.Map;
 import java.util.Random;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeoutException;
 
 /** Integration test cases for Service Discovery feature. */
@@ -99,27 +97,18 @@
     @Rule public final ThreadFeatureCheckerRule mThreadRule = new ThreadFeatureCheckerRule();
 
     private final Context mContext = ApplicationProvider.getApplicationContext();
-
+    private final ThreadNetworkControllerWrapper mController =
+            ThreadNetworkControllerWrapper.newInstance(mContext);
+    private final OtDaemonController mOtCtl = new OtDaemonController();
     private HandlerThread mHandlerThread;
-    private ThreadNetworkController mController;
     private NsdManager mNsdManager;
     private TapTestNetworkTracker mTestNetworkTracker;
     private List<FullThreadDevice> mFtds;
+    private List<RegistrationListener> mRegistrationListeners = new ArrayList<>();
 
     @Before
     public void setUp() throws Exception {
-        final ThreadNetworkManager manager = mContext.getSystemService(ThreadNetworkManager.class);
-        if (manager != null) {
-            mController = manager.getAllThreadNetworkControllers().get(0);
-        }
-
-        // BR forms a network.
-        CompletableFuture<Void> joinFuture = new CompletableFuture<>();
-        runAsShell(
-                PERMISSION_THREAD_NETWORK_PRIVILEGED,
-                () -> mController.join(DEFAULT_DATASET, directExecutor(), joinFuture::complete));
-        joinFuture.get(RESTART_JOIN_TIMEOUT.toMillis(), MILLISECONDS);
-
+        mController.joinAndWait(DEFAULT_DATASET);
         mNsdManager = mContext.getSystemService(NsdManager.class);
 
         mHandlerThread = new HandlerThread(TAG);
@@ -127,17 +116,8 @@
 
         mTestNetworkTracker = new TapTestNetworkTracker(mContext, mHandlerThread.getLooper());
         assertThat(mTestNetworkTracker).isNotNull();
-        runAsShell(
-                PERMISSION_THREAD_NETWORK_PRIVILEGED,
-                NETWORK_SETTINGS,
-                () -> {
-                    CompletableFuture<Void> future = new CompletableFuture<>();
-                    mController.setTestNetworkAsUpstream(
-                            mTestNetworkTracker.getInterfaceName(),
-                            directExecutor(),
-                            v -> future.complete(null));
-                    future.get(5, SECONDS);
-                });
+        mController.setTestNetworkAsUpstreamAndWait(mTestNetworkTracker.getInterfaceName());
+
         // Create the FTDs in setUp() so that the FTDs can be safely released in tearDown().
         // Don't create new FTDs in test cases.
         mFtds = new ArrayList<>();
@@ -150,6 +130,9 @@
 
     @After
     public void tearDown() throws Exception {
+        for (RegistrationListener listener : mRegistrationListeners) {
+            unregisterService(listener);
+        }
         for (FullThreadDevice ftd : mFtds) {
             // Clear registered SRP hosts and services
             if (ftd.isSrpHostRegistered()) {
@@ -164,18 +147,8 @@
             mHandlerThread.quitSafely();
             mHandlerThread.join();
         }
-        runAsShell(
-                PERMISSION_THREAD_NETWORK_PRIVILEGED,
-                NETWORK_SETTINGS,
-                () -> {
-                    CompletableFuture<Void> setUpstreamFuture = new CompletableFuture<>();
-                    CompletableFuture<Void> leaveFuture = new CompletableFuture<>();
-                    mController.setTestNetworkAsUpstream(
-                            null, directExecutor(), v -> setUpstreamFuture.complete(null));
-                    mController.leave(directExecutor(), v -> leaveFuture.complete(null));
-                    setUpstreamFuture.get(5, SECONDS);
-                    leaveFuture.get(5, SECONDS);
-                });
+        mController.setTestNetworkAsUpstreamAndWait(null);
+        mController.leaveAndWait();
     }
 
     @Test
@@ -336,6 +309,187 @@
         assertThrows(TimeoutException.class, () -> discoverService(mNsdManager, "_test._udp"));
     }
 
+    @Test
+    public void meshcopOverlay_vendorAndModelNameAreSetToOverlayValue() throws Exception {
+        NsdServiceInfo discoveredService = discoverService(mNsdManager, "_meshcop._udp");
+        assertThat(discoveredService).isNotNull();
+        NsdServiceInfo meshcopService = resolveService(mNsdManager, discoveredService);
+
+        Map<String, byte[]> txtMap = meshcopService.getAttributes();
+        assertThat(txtMap.get("vn")).isEqualTo("Android".getBytes(UTF_8));
+        assertThat(txtMap.get("mn")).isEqualTo("Thread Border Router".getBytes(UTF_8));
+    }
+
+    @Test
+    public void discoveryProxy_multipleClientsBrowseAndResolveServiceOverMdns() throws Exception {
+        /*
+         * <pre>
+         * Topology:
+         *                    Thread
+         *  Border Router -------------- Full Thread device
+         *  (Cuttlefish)
+         * </pre>
+         */
+
+        RegistrationListener listener = new RegistrationListener();
+        NsdServiceInfo info = new NsdServiceInfo();
+        info.setServiceType("_testservice._tcp");
+        info.setServiceName("test-service");
+        info.setPort(12345);
+        info.setHostname("testhost");
+        info.setHostAddresses(List.of(parseNumericAddress("2001::1")));
+        info.setAttribute("key1", bytes(0x01, 0x02));
+        info.setAttribute("key2", bytes(0x03));
+        registerService(info, listener);
+        mRegistrationListeners.add(listener);
+        for (int i = 0; i < NUM_FTD; ++i) {
+            FullThreadDevice ftd = mFtds.get(i);
+            ftd.joinNetwork(DEFAULT_DATASET);
+            ftd.waitForStateAnyOf(List.of("router", "child"), JOIN_TIMEOUT);
+            ftd.setDnsServerAddress(mOtCtl.getMlEid().getHostAddress());
+        }
+        final ArrayList<NsdServiceInfo> browsedServices = new ArrayList<>();
+        final ArrayList<NsdServiceInfo> resolvedServices = new ArrayList<>();
+        final ArrayList<Thread> threads = new ArrayList<>();
+        for (int i = 0; i < NUM_FTD; ++i) {
+            browsedServices.add(null);
+            resolvedServices.add(null);
+        }
+        for (int i = 0; i < NUM_FTD; ++i) {
+            final FullThreadDevice ftd = mFtds.get(i);
+            final int index = i;
+            Runnable task =
+                    () -> {
+                        browsedServices.set(
+                                index,
+                                ftd.browseService("_testservice._tcp.default.service.arpa."));
+                        resolvedServices.set(
+                                index,
+                                ftd.resolveService(
+                                        "test-service", "_testservice._tcp.default.service.arpa."));
+                    };
+            threads.add(new Thread(task));
+        }
+        for (Thread thread : threads) {
+            thread.start();
+        }
+        for (Thread thread : threads) {
+            thread.join();
+        }
+
+        for (int i = 0; i < NUM_FTD; ++i) {
+            NsdServiceInfo browsedService = browsedServices.get(i);
+            assertThat(browsedService.getServiceName()).isEqualTo("test-service");
+            assertThat(browsedService.getPort()).isEqualTo(12345);
+
+            NsdServiceInfo resolvedService = resolvedServices.get(i);
+            assertThat(resolvedService.getServiceName()).isEqualTo("test-service");
+            assertThat(resolvedService.getPort()).isEqualTo(12345);
+            assertThat(resolvedService.getHostname()).isEqualTo("testhost.default.service.arpa.");
+            assertThat(resolvedService.getHostAddresses())
+                    .containsExactly(parseNumericAddress("2001::1"));
+            assertThat(resolvedService.getAttributes())
+                    .comparingValuesUsing(BYTE_ARRAY_EQUALITY)
+                    .containsExactly("key1", bytes(0x01, 0x02), "key2", bytes(3));
+        }
+    }
+
+    @Test
+    public void discoveryProxy_browseAndResolveServiceAtSrpServer() throws Exception {
+        /*
+         * <pre>
+         * Topology:
+         *                    Thread
+         *  Border Router -------+------ SRP client
+         *  (Cuttlefish)         |
+         *                       +------ DNS client
+         *
+         * </pre>
+         */
+        FullThreadDevice srpClient = mFtds.get(0);
+        srpClient.joinNetwork(DEFAULT_DATASET);
+        srpClient.waitForStateAnyOf(List.of("router", "child"), JOIN_TIMEOUT);
+        srpClient.setSrpHostname("my-host");
+        srpClient.setSrpHostAddresses(List.of((Inet6Address) parseNumericAddress("2001::1")));
+        srpClient.addSrpService(
+                "my-service",
+                "_test._udp",
+                List.of("_sub1"),
+                12345 /* port */,
+                Map.of("key1", bytes(0x01, 0x02), "key2", bytes(0x03)));
+
+        FullThreadDevice dnsClient = mFtds.get(1);
+        dnsClient.joinNetwork(DEFAULT_DATASET);
+        dnsClient.waitForStateAnyOf(List.of("router", "child"), JOIN_TIMEOUT);
+        dnsClient.setDnsServerAddress(mOtCtl.getMlEid().getHostAddress());
+
+        NsdServiceInfo browsedService = dnsClient.browseService("_test._udp.default.service.arpa.");
+        assertThat(browsedService.getServiceName()).isEqualTo("my-service");
+        assertThat(browsedService.getPort()).isEqualTo(12345);
+        assertThat(browsedService.getHostname()).isEqualTo("my-host.default.service.arpa.");
+        assertThat(browsedService.getHostAddresses())
+                .containsExactly(parseNumericAddress("2001::1"));
+        assertThat(browsedService.getAttributes())
+                .comparingValuesUsing(BYTE_ARRAY_EQUALITY)
+                .containsExactly("key1", bytes(0x01, 0x02), "key2", bytes(3));
+
+        NsdServiceInfo resolvedService =
+                dnsClient.resolveService("my-service", "_test._udp.default.service.arpa.");
+        assertThat(resolvedService.getServiceName()).isEqualTo("my-service");
+        assertThat(resolvedService.getPort()).isEqualTo(12345);
+        assertThat(resolvedService.getHostname()).isEqualTo("my-host.default.service.arpa.");
+        assertThat(resolvedService.getHostAddresses())
+                .containsExactly(parseNumericAddress("2001::1"));
+        assertThat(resolvedService.getAttributes())
+                .comparingValuesUsing(BYTE_ARRAY_EQUALITY)
+                .containsExactly("key1", bytes(0x01, 0x02), "key2", bytes(3));
+    }
+
+    private void registerService(NsdServiceInfo serviceInfo, RegistrationListener listener)
+            throws InterruptedException, ExecutionException, TimeoutException {
+        mNsdManager.registerService(serviceInfo, PROTOCOL_DNS_SD, listener);
+        listener.waitForRegistered();
+    }
+
+    private void unregisterService(RegistrationListener listener)
+            throws InterruptedException, ExecutionException, TimeoutException {
+        mNsdManager.unregisterService(listener);
+        listener.waitForUnregistered();
+    }
+
+    private static class RegistrationListener implements NsdManager.RegistrationListener {
+        private final CompletableFuture<Void> mRegisteredFuture = new CompletableFuture<>();
+        private final CompletableFuture<Void> mUnRegisteredFuture = new CompletableFuture<>();
+
+        RegistrationListener() {}
+
+        @Override
+        public void onRegistrationFailed(NsdServiceInfo serviceInfo, int errorCode) {}
+
+        @Override
+        public void onUnregistrationFailed(NsdServiceInfo serviceInfo, int errorCode) {}
+
+        @Override
+        public void onServiceRegistered(NsdServiceInfo serviceInfo) {
+            mRegisteredFuture.complete(null);
+        }
+
+        @Override
+        public void onServiceUnregistered(NsdServiceInfo serviceInfo) {
+            mUnRegisteredFuture.complete(null);
+        }
+
+        public void waitForRegistered()
+                throws InterruptedException, ExecutionException, TimeoutException {
+            mRegisteredFuture.get(SERVICE_DISCOVERY_TIMEOUT.toMillis(), MILLISECONDS);
+        }
+
+        public void waitForUnregistered()
+                throws InterruptedException, ExecutionException, TimeoutException {
+            mUnRegisteredFuture.get(SERVICE_DISCOVERY_TIMEOUT.toMillis(), MILLISECONDS);
+        }
+    }
+
     private static byte[] bytes(int... byteInts) {
         byte[] bytes = new byte[byteInts.length];
         for (int i = 0; i < byteInts.length; ++i) {
diff --git a/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java b/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java
index 9585d7d..c70f3af 100644
--- a/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java
+++ b/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java
@@ -16,31 +16,23 @@
 
 package android.net.thread;
 
-import static android.Manifest.permission.NETWORK_SETTINGS;
 import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_DETACHED;
 import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_LEADER;
 import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_STOPPED;
-import static android.net.thread.ThreadNetworkManager.PERMISSION_THREAD_NETWORK_PRIVILEGED;
 import static android.net.thread.utils.IntegrationTestUtils.CALLBACK_TIMEOUT;
-import static android.net.thread.utils.IntegrationTestUtils.LEAVE_TIMEOUT;
 import static android.net.thread.utils.IntegrationTestUtils.RESTART_JOIN_TIMEOUT;
-import static android.net.thread.utils.IntegrationTestUtils.waitForStateAnyOf;
 
 import static com.android.compatibility.common.util.SystemUtil.runShellCommand;
-import static com.android.testutils.TestPermissionUtil.runAsShell;
+import static com.android.compatibility.common.util.SystemUtil.runShellCommandOrThrow;
 
 import static com.google.common.io.BaseEncoding.base16;
 import static com.google.common.truth.Truth.assertThat;
-import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
 
-import static java.util.concurrent.TimeUnit.MILLISECONDS;
-
-import android.annotation.Nullable;
 import android.content.Context;
-import android.net.thread.ThreadNetworkController.StateCallback;
 import android.net.thread.utils.OtDaemonController;
 import android.net.thread.utils.ThreadFeatureCheckerRule;
 import android.net.thread.utils.ThreadFeatureCheckerRule.RequiresThreadFeature;
+import android.net.thread.utils.ThreadNetworkControllerWrapper;
 import android.os.SystemClock;
 
 import androidx.test.core.app.ApplicationProvider;
@@ -55,7 +47,6 @@
 
 import java.net.Inet6Address;
 import java.util.List;
-import java.util.concurrent.CompletableFuture;
 
 /** Tests for E2E Android Thread integration with ot-daemon, ConnectivityService, etc.. */
 @LargeTest
@@ -76,17 +67,14 @@
     @Rule public final ThreadFeatureCheckerRule mThreadRule = new ThreadFeatureCheckerRule();
 
     private final Context mContext = ApplicationProvider.getApplicationContext();
-    private ThreadNetworkController mController;
+    private final ThreadNetworkControllerWrapper mController =
+            ThreadNetworkControllerWrapper.newInstance(mContext);
     private OtDaemonController mOtCtl;
 
     @Before
     public void setUp() throws Exception {
-        mController =
-                mContext.getSystemService(ThreadNetworkManager.class)
-                        .getAllThreadNetworkControllers()
-                        .get(0);
         mOtCtl = new OtDaemonController();
-        leaveAndWait(mController);
+        mController.leaveAndWait();
 
         // TODO: b/323301831 - This is a workaround to avoid unnecessary delay to re-form a network
         mOtCtl.factoryReset();
@@ -94,53 +82,56 @@
 
     @After
     public void tearDown() throws Exception {
-        setTestUpStreamNetworkAndWait(mController, null);
-        leaveAndWait(mController);
+        mController.setTestNetworkAsUpstreamAndWait(null);
+        mController.leaveAndWait();
     }
 
     @Test
     public void otDaemonRestart_notJoinedAndStopped_deviceRoleIsStopped() throws Exception {
-        leaveAndWait(mController);
+        mController.leaveAndWait();
 
         runShellCommand("stop ot-daemon");
         // TODO(b/323331973): the sleep is needed to workaround the race conditions
         SystemClock.sleep(200);
 
-        waitForStateAnyOf(mController, List.of(DEVICE_ROLE_STOPPED), CALLBACK_TIMEOUT);
+        mController.waitForRole(DEVICE_ROLE_STOPPED, CALLBACK_TIMEOUT);
     }
 
     @Test
-    public void otDaemonRestart_JoinedNetworkAndStopped_autoRejoined() throws Exception {
-        joinAndWait(mController, DEFAULT_DATASET);
+    public void otDaemonRestart_JoinedNetworkAndStopped_autoRejoinedAndTunIfStateConsistent()
+            throws Exception {
+        mController.joinAndWait(DEFAULT_DATASET);
 
         runShellCommand("stop ot-daemon");
 
-        waitForStateAnyOf(mController, List.of(DEVICE_ROLE_DETACHED), CALLBACK_TIMEOUT);
-        waitForStateAnyOf(mController, List.of(DEVICE_ROLE_LEADER), RESTART_JOIN_TIMEOUT);
+        mController.waitForRole(DEVICE_ROLE_DETACHED, CALLBACK_TIMEOUT);
+        mController.waitForRole(DEVICE_ROLE_LEADER, RESTART_JOIN_TIMEOUT);
+        assertThat(mOtCtl.isInterfaceUp()).isTrue();
+        assertThat(runShellCommand("ifconfig thread-wpan")).contains("UP POINTOPOINT RUNNING");
     }
 
     @Test
     public void otDaemonFactoryReset_deviceRoleIsStopped() throws Exception {
-        joinAndWait(mController, DEFAULT_DATASET);
+        mController.joinAndWait(DEFAULT_DATASET);
 
         mOtCtl.factoryReset();
 
-        assertThat(getDeviceRole(mController)).isEqualTo(DEVICE_ROLE_STOPPED);
+        assertThat(mController.getDeviceRole()).isEqualTo(DEVICE_ROLE_STOPPED);
     }
 
     @Test
     public void otDaemonFactoryReset_addressesRemoved() throws Exception {
-        joinAndWait(mController, DEFAULT_DATASET);
+        mController.joinAndWait(DEFAULT_DATASET);
 
         mOtCtl.factoryReset();
-        String ifconfig = runShellCommand("ifconfig thread-wpan");
 
+        String ifconfig = runShellCommand("ifconfig thread-wpan");
         assertThat(ifconfig).doesNotContain("inet6 addr");
     }
 
     @Test
     public void tunInterface_joinedNetwork_otAddressesAddedToTunInterface() throws Exception {
-        joinAndWait(mController, DEFAULT_DATASET);
+        mController.joinAndWait(DEFAULT_DATASET);
 
         String ifconfig = runShellCommand("ifconfig thread-wpan");
         List<Inet6Address> otAddresses = mOtCtl.getAddresses();
@@ -150,48 +141,22 @@
         }
     }
 
+    @Test
+    public void otDaemonRestart_latestCountryCodeIsSetToOtDaemon() throws Exception {
+        runThreadCommand("force-country-code enabled CN");
+
+        runShellCommand("stop ot-daemon");
+        // TODO(b/323331973): the sleep is needed to work around the race conditions
+        SystemClock.sleep(200);
+        mController.waitForRole(DEVICE_ROLE_STOPPED, CALLBACK_TIMEOUT);
+
+        assertThat(mOtCtl.getCountryCode()).isEqualTo("CN");
+    }
+
+    private static String runThreadCommand(String cmd) {
+        return runShellCommandOrThrow("cmd thread_network " + cmd);
+    }
+
     // TODO (b/323300829): add more tests for integration with linux platform and
     // ConnectivityService
-
-    private static int getDeviceRole(ThreadNetworkController controller) throws Exception {
-        CompletableFuture<Integer> future = new CompletableFuture<>();
-        StateCallback callback = future::complete;
-        controller.registerStateCallback(directExecutor(), callback);
-        try {
-            return future.get(CALLBACK_TIMEOUT.toMillis(), MILLISECONDS);
-        } finally {
-            controller.unregisterStateCallback(callback);
-        }
-    }
-
-    private static void joinAndWait(
-            ThreadNetworkController controller, ActiveOperationalDataset activeDataset)
-            throws Exception {
-        runAsShell(
-                PERMISSION_THREAD_NETWORK_PRIVILEGED,
-                () -> controller.join(activeDataset, directExecutor(), result -> {}));
-        waitForStateAnyOf(controller, List.of(DEVICE_ROLE_LEADER), RESTART_JOIN_TIMEOUT);
-    }
-
-    private static void leaveAndWait(ThreadNetworkController controller) throws Exception {
-        CompletableFuture<Void> future = new CompletableFuture<>();
-        runAsShell(
-                PERMISSION_THREAD_NETWORK_PRIVILEGED,
-                () -> controller.leave(directExecutor(), future::complete));
-        future.get(LEAVE_TIMEOUT.toMillis(), MILLISECONDS);
-    }
-
-    private static void setTestUpStreamNetworkAndWait(
-            ThreadNetworkController controller, @Nullable String networkInterfaceName)
-            throws Exception {
-        CompletableFuture<Void> future = new CompletableFuture<>();
-        runAsShell(
-                PERMISSION_THREAD_NETWORK_PRIVILEGED,
-                NETWORK_SETTINGS,
-                () -> {
-                    controller.setTestNetworkAsUpstream(
-                            networkInterfaceName, directExecutor(), future::complete);
-                });
-        future.get(CALLBACK_TIMEOUT.toMillis(), MILLISECONDS);
-    }
 }
diff --git a/thread/tests/integration/src/android/net/thread/ThreadNetworkControllerTest.java b/thread/tests/integration/src/android/net/thread/ThreadNetworkControllerTest.java
new file mode 100644
index 0000000..ba04348
--- /dev/null
+++ b/thread/tests/integration/src/android/net/thread/ThreadNetworkControllerTest.java
@@ -0,0 +1,179 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.thread;
+
+import static android.net.thread.ThreadNetworkException.ERROR_UNSUPPORTED_OPERATION;
+
+import static androidx.test.platform.app.InstrumentationRegistry.getInstrumentation;
+
+import static com.android.testutils.TestPermissionUtil.runAsShell;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import static org.junit.Assert.assertThrows;
+
+import android.content.Context;
+import android.net.thread.utils.ThreadFeatureCheckerRule;
+import android.net.thread.utils.ThreadFeatureCheckerRule.RequiresThreadFeature;
+import android.os.OutcomeReceiver;
+import android.util.SparseIntArray;
+
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.filters.LargeTest;
+import androidx.test.runner.AndroidJUnit4;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+/** Tests for the hidden APIs of {@link ThreadNetworkController}. */
+@LargeTest
+@RequiresThreadFeature
+@RunWith(AndroidJUnit4.class)
+public class ThreadNetworkControllerTest {
+    private static final int VALID_POWER = 32_767;
+    private static final int INVALID_POWER = 32_768;
+    private static final int VALID_CHANNEL = 20;
+    private static final int INVALID_CHANNEL = 10;
+    private static final String THREAD_NETWORK_PRIVILEGED =
+            "android.permission.THREAD_NETWORK_PRIVILEGED";
+
+    private static final SparseIntArray CHANNEL_MAX_POWERS =
+            newChannelMaxPowers(VALID_CHANNEL, VALID_POWER);
+
+    @Rule public final ThreadFeatureCheckerRule mThreadRule = new ThreadFeatureCheckerRule();
+
+    private final Context mContext = ApplicationProvider.getApplicationContext();
+    private ExecutorService mExecutor;
+    private ThreadNetworkController mController;
+
+    @Before
+    public void setUp() throws Exception {
+        mController =
+                mContext.getSystemService(ThreadNetworkManager.class)
+                        .getAllThreadNetworkControllers()
+                        .get(0);
+
+        mExecutor = Executors.newSingleThreadExecutor();
+    }
+
+    @After
+    public void tearDown() throws Exception {
+        dropAllPermissions();
+        // A new executor is created for every test; shut it down so that its worker thread does
+        // not outlive the test. It is still null if setUp() failed before creating it.
+        if (mExecutor != null) {
+            mExecutor.shutdown();
+        }
+    }
+
+    /**
+     * Verifies that a privileged caller can set the channel max powers. The only failure that is
+     * accepted is a {@link ThreadNetworkException} with {@code ERROR_UNSUPPORTED_OPERATION}.
+     */
+    @Test
+    public void setChannelMaxPowers_withPrivilegedPermission_success() throws Exception {
+        CompletableFuture<Void> powerFuture = new CompletableFuture<>();
+
+        runAsShell(
+                THREAD_NETWORK_PRIVILEGED,
+                () ->
+                        mController.setChannelMaxPowers(
+                                CHANNEL_MAX_POWERS, mExecutor, newOutcomeReceiver(powerFuture)));
+
+        try {
+            assertThat(powerFuture.get()).isNull();
+        } catch (ExecutionException exception) {
+            // Fail with a readable assertion instead of a ClassCastException if the operation
+            // failed for an unexpected reason.
+            assertThat(exception).hasCauseThat().isInstanceOf(ThreadNetworkException.class);
+            ThreadNetworkException thrown = (ThreadNetworkException) exception.getCause();
+            assertThat(thrown.getErrorCode()).isEqualTo(ERROR_UNSUPPORTED_OPERATION);
+        }
+    }
+
+    @Test
+    public void setChannelMaxPowers_withoutPrivilegedPermission_throwsSecurityException()
+            throws Exception {
+        dropAllPermissions();
+
+        assertThrows(
+                SecurityException.class,
+                () -> mController.setChannelMaxPowers(CHANNEL_MAX_POWERS, mExecutor, v -> {}));
+    }
+
+    @Test
+    public void setChannelMaxPowers_emptyChannelMaxPower_throwsIllegalArgumentException() {
+        assertThrows(
+                IllegalArgumentException.class,
+                () -> mController.setChannelMaxPowers(new SparseIntArray(), mExecutor, v -> {}));
+    }
+
+    @Test
+    public void setChannelMaxPowers_invalidChannel_throwsIllegalArgumentException() {
+        final SparseIntArray invalidChannelPowers =
+                newChannelMaxPowers(INVALID_CHANNEL, VALID_POWER);
+
+        assertThrows(
+                IllegalArgumentException.class,
+                () -> mController.setChannelMaxPowers(invalidChannelPowers, mExecutor, v -> {}));
+    }
+
+    @Test
+    public void setChannelMaxPowers_invalidPower_throwsIllegalArgumentException() {
+        final SparseIntArray invalidMaxPowers =
+                newChannelMaxPowers(VALID_CHANNEL, INVALID_POWER);
+
+        assertThrows(
+                IllegalArgumentException.class,
+                () -> mController.setChannelMaxPowers(invalidMaxPowers, mExecutor, v -> {}));
+    }
+
+    private static void dropAllPermissions() {
+        getInstrumentation().getUiAutomation().dropShellPermissionIdentity();
+    }
+
+    /** Returns a new {@link SparseIntArray} that maps {@code channel} to {@code maxPower}. */
+    private static SparseIntArray newChannelMaxPowers(int channel, int maxPower) {
+        SparseIntArray channelMaxPowers = new SparseIntArray();
+        channelMaxPowers.put(channel, maxPower);
+        return channelMaxPowers;
+    }
+
+    /** Returns an {@link OutcomeReceiver} that completes {@code future} with the outcome. */
+    private static <V> OutcomeReceiver<V, ThreadNetworkException> newOutcomeReceiver(
+            CompletableFuture<V> future) {
+        return new OutcomeReceiver<V, ThreadNetworkException>() {
+            @Override
+            public void onResult(V result) {
+                future.complete(result);
+            }
+
+            @Override
+            public void onError(ThreadNetworkException e) {
+                future.completeExceptionally(e);
+            }
+        };
+    }
+}
diff --git a/thread/tests/integration/src/android/net/thread/ThreadNetworkShellCommandTest.java b/thread/tests/integration/src/android/net/thread/ThreadNetworkShellCommandTest.java
new file mode 100644
index 0000000..8835f40
--- /dev/null
+++ b/thread/tests/integration/src/android/net/thread/ThreadNetworkShellCommandTest.java
@@ -0,0 +1,129 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.thread;
+
+import static android.net.thread.ThreadNetworkController.STATE_DISABLED;
+import static android.net.thread.ThreadNetworkController.STATE_ENABLED;
+import static android.net.thread.ThreadNetworkException.ERROR_THREAD_DISABLED;
+
+import static com.android.compatibility.common.util.SystemUtil.runShellCommandOrThrow;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import static org.junit.Assert.assertThrows;
+
+import android.content.Context;
+import android.net.thread.utils.ThreadFeatureCheckerRule;
+import android.net.thread.utils.ThreadFeatureCheckerRule.RequiresThreadFeature;
+import android.net.thread.utils.ThreadNetworkControllerWrapper;
+
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.filters.LargeTest;
+import androidx.test.runner.AndroidJUnit4;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import java.util.concurrent.ExecutionException;
+
+/** Integration tests for {@link ThreadNetworkShellCommand}. */
+@LargeTest
+@RequiresThreadFeature
+@RunWith(AndroidJUnit4.class)
+public class ThreadNetworkShellCommandTest {
+    @Rule public final ThreadFeatureCheckerRule mThreadRule = new ThreadFeatureCheckerRule();
+
+    private final Context mContext = ApplicationProvider.getApplicationContext();
+    private final ThreadNetworkControllerWrapper mController =
+            ThreadNetworkControllerWrapper.newInstance(mContext);
+
+    @Before
+    public void setUp() {
+        ensureThreadEnabled();
+    }
+
+    @After
+    public void tearDown() {
+        ensureThreadEnabled();
+    }
+
+    private static void ensureThreadEnabled() {
+        runThreadCommand("force-stop-ot-daemon disabled");
+        runThreadCommand("enable");
+    }
+
+    @Test
+    public void enable_threadStateIsEnabled() throws Exception {
+        runThreadCommand("enable");
+
+        assertThat(mController.getEnabledState()).isEqualTo(STATE_ENABLED);
+    }
+
+    @Test
+    public void disable_threadStateIsDisabled() throws Exception {
+        runThreadCommand("disable");
+
+        assertThat(mController.getEnabledState()).isEqualTo(STATE_DISABLED);
+    }
+
+    @Test
+    public void forceStopOtDaemon_forceStopEnabled_otDaemonServiceDisappear() {
+        runThreadCommand("force-stop-ot-daemon enabled");
+
+        assertThat(runShellCommandOrThrow("service list")).doesNotContain("ot_daemon");
+    }
+
+    @Test
+    public void forceStopOtDaemon_forceStopEnabled_canNotEnableThread() throws Exception {
+        runThreadCommand("force-stop-ot-daemon enabled");
+
+        ExecutionException thrown =
+                assertThrows(ExecutionException.class, () -> mController.setEnabledAndWait(true));
+        ThreadNetworkException cause = (ThreadNetworkException) thrown.getCause();
+        assertThat(cause.getErrorCode()).isEqualTo(ERROR_THREAD_DISABLED);
+    }
+
+    @Test
+    public void forceStopOtDaemon_forceStopDisabled_otDaemonServiceAppears() throws Exception {
+        runThreadCommand("force-stop-ot-daemon disabled");
+
+        assertThat(runShellCommandOrThrow("service list")).contains("ot_daemon");
+    }
+
+    @Test
+    public void forceStopOtDaemon_forceStopDisabled_canEnableThread() throws Exception {
+        runThreadCommand("force-stop-ot-daemon disabled");
+
+        mController.setEnabledAndWait(true);
+        assertThat(mController.getEnabledState()).isEqualTo(STATE_ENABLED);
+    }
+
+    @Test
+    public void forceCountryCode_setCN_getCountryCodeReturnsCN() {
+        runThreadCommand("force-country-code enabled CN");
+
+        final String result = runThreadCommand("get-country-code");
+        assertThat(result).contains("Thread country code = CN");
+    }
+
+    private static String runThreadCommand(String cmd) {
+        return runShellCommandOrThrow("cmd thread_network " + cmd);
+    }
+}
diff --git a/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java b/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java
index 6306a65..f7bb9ff 100644
--- a/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java
+++ b/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java
@@ -24,6 +24,7 @@
 
 import android.net.InetAddresses;
 import android.net.IpPrefix;
+import android.net.nsd.NsdServiceInfo;
 import android.net.thread.ActiveOperationalDataset;
 
 import com.google.errorprone.annotations.FormatMethod;
@@ -34,6 +35,7 @@
 import java.io.InputStreamReader;
 import java.io.OutputStreamWriter;
 import java.net.Inet6Address;
+import java.net.InetAddress;
 import java.nio.charset.StandardCharsets;
 import java.time.Duration;
 import java.util.ArrayList;
@@ -52,6 +54,13 @@
  * available commands.
  */
 public final class FullThreadDevice {
+    private static final int HOP_LIMIT = 64;
+    private static final int PING_INTERVAL = 1;
+    private static final int PING_SIZE = 100;
+    // There may not be a response for the ping command, using a short timeout to keep the tests
+    // short.
+    private static final float PING_TIMEOUT_SECONDS = 0.1f;
+
     private final Process mProcess;
     private final BufferedReader mReader;
     private final BufferedWriter mWriter;
@@ -320,6 +329,55 @@
         return false;
     }
 
+    /** Configures {@code address} as the DNS server through the "dns config" CLI command. */
+    public void setDnsServerAddress(String address) {
+        executeCommand("dns config %s", address);
+    }
+
+    /** Browses {@code serviceType} and returns the first service instance of the response. */
+    public NsdServiceInfo browseService(String serviceType) {
+        // Sample CLI output:
+        // DNS browse response for _testservice._tcp.default.service.arpa.
+        // test-service
+        //    Port:12345, Priority:0, Weight:0, TTL:10
+        //    Host:testhost.default.service.arpa.
+        //    HostAddress:2001:0:0:0:0:0:0:1 TTL:10
+        //    TXT:[key1=0102, key2=03] TTL:10
+
+        List<String> reply = executeCommand("dns browse %s", serviceType);
+        NsdServiceInfo found = new NsdServiceInfo();
+        found.setServiceType(serviceType);
+        found.setServiceName(reply.get(1));
+        found.setPort(DnsServiceCliOutputParser.parsePort(reply.get(2)));
+        found.setHostname(DnsServiceCliOutputParser.parseHostname(reply.get(3)));
+        found.setHostAddresses(List.of(DnsServiceCliOutputParser.parseHostAddress(reply.get(4))));
+        DnsServiceCliOutputParser.parseTxtIntoServiceInfo(reply.get(5), found);
+
+        return found;
+    }
+
+    /** Resolves {@code serviceName} of type {@code serviceType} via "dns service". */
+    public NsdServiceInfo resolveService(String serviceName, String serviceType) {
+        // Sample CLI output:
+        // DNS service resolution response for test-service for service
+        // _test._tcp.default.service.arpa.
+        // Port:12345, Priority:0, Weight:0, TTL:10
+        // Host:Android.default.service.arpa.
+        // HostAddress:2001:0:0:0:0:0:0:1 TTL:10
+        // TXT:[key1=0102, key2=03] TTL:10
+
+        List<String> reply = executeCommand("dns service " + serviceName + " " + serviceType);
+        NsdServiceInfo resolved = new NsdServiceInfo();
+        resolved.setServiceType(serviceType);
+        resolved.setServiceName(serviceName);
+        resolved.setPort(DnsServiceCliOutputParser.parsePort(reply.get(1)));
+        resolved.setHostname(DnsServiceCliOutputParser.parseHostname(reply.get(2)));
+        resolved.setHostAddresses(
+                List.of(DnsServiceCliOutputParser.parseHostAddress(reply.get(3))));
+        DnsServiceCliOutputParser.parseTxtIntoServiceInfo(reply.get(4), resolved);
+        return resolved;
+    }
+
     /** Runs the "factoryreset" command on the device. */
     public void factoryReset() {
         try {
@@ -339,7 +397,36 @@
         executeCommand("ipmaddr add " + address.getHostAddress());
     }
 
-    public void ping(Inet6Address address, Inet6Address source, int size, int count) {
+    public void ping(Inet6Address address, Inet6Address source) {
+        ping(
+                address,
+                source,
+                PING_SIZE /* size */,
+                1 /* count */,
+                PING_INTERVAL /* interval */,
+                HOP_LIMIT /* hopLimit */,
+                PING_TIMEOUT_SECONDS /* timeout; kept short, a reply may never arrive */);
+    }
+
+    public void ping(Inet6Address address) {
+        ping(
+                address,
+                null /* source; a null source omits the "-I" option from the command */,
+                PING_SIZE /* size */,
+                1 /* count */,
+                PING_INTERVAL /* interval */,
+                HOP_LIMIT /* hopLimit */,
+                PING_TIMEOUT_SECONDS /* timeout */);
+    }
+
+    private void ping(
+            Inet6Address address,
+            Inet6Address source,
+            int size,
+            int count,
+            int interval,
+            int hopLimit,
+            float timeout) {
         String cmd =
                 "ping"
                         + ((source == null) ? "" : (" -I " + source.getHostAddress()))
@@ -348,14 +435,16 @@
                         + " "
                         + size
                         + " "
-                        + count;
+                        + count
+                        + " "
+                        + interval
+                        + " "
+                        + hopLimit
+                        + " "
+                        + timeout;
         executeCommand(cmd);
     }
 
-    public void ping(Inet6Address address) {
-        ping(address, null, 100 /* size */, 1 /* count */);
-    }
-
     @FormatMethod
     private List<String> executeCommand(String commandFormat, Object... args) {
         return executeCommand(String.format(commandFormat, args));
@@ -416,4 +505,45 @@
     private static String toHexString(byte[] bytes) {
         return base16().encode(bytes);
     }
+
+    private static final class DnsServiceCliOutputParser {
+        /** Returns the first match of {@code regex} in {@code input}; group() throws if none. */
+        private static Matcher firstMatchOf(String input, String regex) {
+            Matcher matcher = Pattern.compile(regex).matcher(input);
+            matcher.find();
+            return matcher;
+        }
+
+        // Example: "Port:12345"
+        private static int parsePort(String line) {
+            return Integer.parseInt(firstMatchOf(line, "Port:(\\d+)").group(1));
+        }
+
+        // Example: "Host:Android.default.service.arpa."
+        private static String parseHostname(String line) {
+            return firstMatchOf(line, "Host:(.+)").group(1);
+        }
+
+        // Example: "HostAddress:2001:0:0:0:0:0:0:1"
+        private static InetAddress parseHostAddress(String line) {
+            return InetAddresses.parseNumericAddress(
+                    firstMatchOf(line, "HostAddress:([^ ]+)").group(1));
+        }
+
+        // Example: "TXT:[key1=0102, key2=03]". Values are hex text (0-9, a-f, A-F), 2 digits/byte.
+        private static void parseTxtIntoServiceInfo(String line, NsdServiceInfo serviceInfo) {
+            String txtString = firstMatchOf(line, "TXT:\\[([^\\]]+)\\]").group(1);
+            for (String txtEntry : txtString.split(",")) {
+                String[] nameAndValue = txtEntry.trim().split("=");
+                String value = nameAndValue[1];
+                byte[] bytes = new byte[value.length() / 2];
+                for (int i = 0; i < value.length(); i += 2) {
+                    int high = Character.digit(value.charAt(i), 16);
+                    int low = Character.digit(value.charAt(i + 1), 16);
+                    bytes[i / 2] = (byte) (high << 4 | low);
+                }
+                serviceInfo.setAttribute(nameAndValue[0], bytes);
+            }
+        }
+    }
 }
diff --git a/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java b/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
index 4a06fe8..f39a064 100644
--- a/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
+++ b/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
@@ -62,6 +62,24 @@
                 .toList();
     }
 
+    /** Reports whether the output of the "ifconfig" command contains "up". */
+    public boolean isInterfaceUp() {
+        final String ifconfigOutput = executeCommand("ifconfig");
+        return ifconfigOutput.indexOf("up") >= 0;
+    }
+
+    /** Returns the ML-EID of the device, parsed from the first line of "ipaddr mleid". */
+    public Inet6Address getMlEid() {
+        final String firstLine = executeCommand("ipaddr mleid").split("\n")[0];
+        return (Inet6Address) InetAddresses.parseNumericAddress(firstLine.trim());
+    }
+
+    /** Returns the trimmed first line printed by the "region" command. */
+    public String getCountryCode() {
+        final String[] outputLines = executeCommand("region").split("\n");
+        return outputLines[0].trim();
+    }
+
     public String executeCommand(String cmd) {
         return SystemUtil.runShellCommand(OT_CTL + " " + cmd);
     }
diff --git a/thread/tests/integration/src/android/net/thread/utils/ThreadNetworkControllerWrapper.java b/thread/tests/integration/src/android/net/thread/utils/ThreadNetworkControllerWrapper.java
new file mode 100644
index 0000000..7e84233
--- /dev/null
+++ b/thread/tests/integration/src/android/net/thread/utils/ThreadNetworkControllerWrapper.java
@@ -0,0 +1,208 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.net.thread.utils;
+
+import static android.Manifest.permission.ACCESS_NETWORK_STATE;
+import static android.Manifest.permission.NETWORK_SETTINGS;
+import static android.net.thread.ThreadNetworkManager.PERMISSION_THREAD_NETWORK_PRIVILEGED;
+import static android.net.thread.utils.IntegrationTestUtils.CALLBACK_TIMEOUT;
+
+import static com.android.testutils.TestPermissionUtil.runAsShell;
+
+import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
+
+import static java.util.concurrent.TimeUnit.SECONDS;
+
+import android.annotation.Nullable;
+import android.content.Context;
+import android.net.thread.ActiveOperationalDataset;
+import android.net.thread.ThreadNetworkController;
+import android.net.thread.ThreadNetworkController.StateCallback;
+import android.net.thread.ThreadNetworkException;
+import android.net.thread.ThreadNetworkManager;
+import android.os.OutcomeReceiver;
+
+import java.time.Duration;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeoutException;
+
+/** A helper class which provides synchronous API wrappers for {@link ThreadNetworkController}. */
+public final class ThreadNetworkControllerWrapper {
+    public static final Duration JOIN_TIMEOUT = Duration.ofSeconds(10);
+    public static final Duration LEAVE_TIMEOUT = Duration.ofSeconds(2);
+    private static final Duration CALLBACK_TIMEOUT = Duration.ofSeconds(1);
+    private static final Duration SET_ENABLED_TIMEOUT = Duration.ofSeconds(2);
+
+    private final ThreadNetworkController mController;
+
+    /**
+     * Creates a wrapper for the first controller of {@link ThreadNetworkManager}, or returns
+     * {@code null} if that system service is unavailable (Thread unsupported on this device).
+     */
+    @Nullable
+    public static ThreadNetworkControllerWrapper newInstance(Context context) {
+        final ThreadNetworkManager service = context.getSystemService(ThreadNetworkManager.class);
+        return (service == null)
+                ? null
+                : new ThreadNetworkControllerWrapper(
+                        service.getAllThreadNetworkControllers().get(0));
+    }
+
+    private ThreadNetworkControllerWrapper(ThreadNetworkController controller) {
+        mController = controller;
+    }
+
+    /**
+     * Queries the current Thread enabled state through a temporarily registered callback.
+     *
+     * <p>The result is one of the {@code ThreadNetworkController#STATE_*} values.
+     */
+    public final int getEnabledState()
+            throws InterruptedException, ExecutionException, TimeoutException {
+        final CompletableFuture<Integer> enabledStateFuture = new CompletableFuture<>();
+        final StateCallback listener =
+                new StateCallback() {
+                    @Override
+                    public void onDeviceRoleChanged(int deviceRole) {}
+
+                    @Override
+                    public void onThreadEnableStateChanged(int enabledState) {
+                        enabledStateFuture.complete(enabledState);
+                    }
+                };
+
+        runAsShell(
+                ACCESS_NETWORK_STATE,
+                () -> mController.registerStateCallback(directExecutor(), listener));
+        try {
+            return enabledStateFuture.get(CALLBACK_TIMEOUT.toSeconds(), SECONDS);
+        } finally {
+            runAsShell(ACCESS_NETWORK_STATE, () -> mController.unregisterStateCallback(listener));
+        }
+    }
+
+    /**
+     * Queries the current Thread device role through a temporarily registered callback.
+     *
+     * <p>The result is one of the {@code ThreadNetworkController#DEVICE_ROLE_*} values.
+     */
+    public final int getDeviceRole()
+            throws InterruptedException, ExecutionException, TimeoutException {
+        final CompletableFuture<Integer> roleFuture = new CompletableFuture<>();
+        final StateCallback listener = role -> roleFuture.complete(role);
+
+        runAsShell(
+                ACCESS_NETWORK_STATE,
+                () -> mController.registerStateCallback(directExecutor(), listener));
+        try {
+            return roleFuture.get(CALLBACK_TIMEOUT.toSeconds(), SECONDS);
+        } finally {
+            runAsShell(ACCESS_NETWORK_STATE, () -> mController.unregisterStateCallback(listener));
+        }
+    }
+
+    /** A synchronous variant of {@link ThreadNetworkController#setEnabled}. */
+    public void setEnabledAndWait(boolean enabled)
+            throws InterruptedException, ExecutionException, TimeoutException {
+        final CompletableFuture<Void> done = new CompletableFuture<>();
+        final OutcomeReceiver<Void, ThreadNetworkException> receiver = newOutcomeReceiver(done);
+        runAsShell(
+                PERMISSION_THREAD_NETWORK_PRIVILEGED,
+                () -> mController.setEnabled(enabled, directExecutor(), receiver));
+        // Blocks until the receiver completes the future or SET_ENABLED_TIMEOUT elapses.
+        done.get(SET_ENABLED_TIMEOUT.toSeconds(), SECONDS);
+    }
+
+    /** Joins the network of {@code activeDataset} and waits for the operation to complete. */
+    public void joinAndWait(ActiveOperationalDataset activeDataset)
+            throws InterruptedException, ExecutionException, TimeoutException {
+        final CompletableFuture<Void> joined = new CompletableFuture<>();
+        final OutcomeReceiver<Void, ThreadNetworkException> receiver = newOutcomeReceiver(joined);
+        runAsShell(
+                PERMISSION_THREAD_NETWORK_PRIVILEGED,
+                () -> mController.join(activeDataset, directExecutor(), receiver));
+        // A join error is rethrown here wrapped in an ExecutionException.
+        joined.get(JOIN_TIMEOUT.toSeconds(), SECONDS);
+    }
+
+    /** A synchronous variant of {@link ThreadNetworkController#leave}; errors are rethrown. */
+    public void leaveAndWait() throws InterruptedException, ExecutionException, TimeoutException {
+        CompletableFuture<Void> future = new CompletableFuture<>(); // completed on result or error
+        runAsShell(
+                PERMISSION_THREAD_NETWORK_PRIVILEGED,
+                () -> mController.leave(directExecutor(), newOutcomeReceiver(future)));
+        future.get(LEAVE_TIMEOUT.toSeconds(), SECONDS);
+    }
+
+    /** Waits up to {@code timeout} for the device role to become {@code deviceRole}. */
+    public int waitForRole(int deviceRole, Duration timeout)
+            throws InterruptedException, ExecutionException, TimeoutException {
+        return waitForRoleAnyOf(List.of(deviceRole), timeout);
+    }
+
+    /** Waits for the device role to become one of the values specified in {@code deviceRoles}. */
+    public int waitForRoleAnyOf(List<Integer> deviceRoles, Duration timeout)
+            throws InterruptedException, ExecutionException, TimeoutException {
+        CompletableFuture<Integer> future = new CompletableFuture<>();
+        ThreadNetworkController.StateCallback callback =
+                newRole -> {
+                    if (deviceRoles.contains(newRole)) {
+                        future.complete(newRole);
+                    }
+                };
+
+        runAsShell(
+                ACCESS_NETWORK_STATE,
+                () -> mController.registerStateCallback(directExecutor(), callback));
+        try {
+            return future.get(timeout.toSeconds(), SECONDS);
+        } finally {
+            // Unregister under the same shell permission that was used to register.
+            runAsShell(ACCESS_NETWORK_STATE, () -> mController.unregisterStateCallback(callback));
+        }
+    }
+
+    /** A synchronous variant of {@link ThreadNetworkController#setTestNetworkAsUpstream}. */
+    public void setTestNetworkAsUpstreamAndWait(@Nullable String networkInterfaceName)
+            throws InterruptedException, ExecutionException, TimeoutException {
+        CompletableFuture<Void> future = new CompletableFuture<>(); // completed on result or error
+        runAsShell(
+                PERMISSION_THREAD_NETWORK_PRIVILEGED,
+                NETWORK_SETTINGS,
+                () -> {
+                    mController.setTestNetworkAsUpstream(
+                            networkInterfaceName, directExecutor(), newOutcomeReceiver(future));
+                });
+        future.get(CALLBACK_TIMEOUT.toSeconds(), SECONDS);
+    }
+
+    private static <V> OutcomeReceiver<V, ThreadNetworkException> newOutcomeReceiver(
+            CompletableFuture<V> future) {
+        return new OutcomeReceiver<V, ThreadNetworkException>() {
+            @Override
+            public void onResult(V result) {
+                future.complete(result); // success: hand the result to the waiting caller
+            }
+
+            @Override
+            public void onError(ThreadNetworkException e) {
+                future.completeExceptionally(e); // failure: future.get() throws ExecutionException
+            }
+        };
+    }
+}
diff --git a/thread/tests/unit/src/android/net/thread/ThreadNetworkControllerTest.java b/thread/tests/unit/src/android/net/thread/ThreadNetworkControllerTest.java
index 75eb043..ac74372 100644
--- a/thread/tests/unit/src/android/net/thread/ThreadNetworkControllerTest.java
+++ b/thread/tests/unit/src/android/net/thread/ThreadNetworkControllerTest.java
@@ -19,6 +19,7 @@
 import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_CHILD;
 import static android.net.thread.ThreadNetworkException.ERROR_UNAVAILABLE;
 import static android.net.thread.ThreadNetworkException.ERROR_UNSUPPORTED_CHANNEL;
+import static android.net.thread.ThreadNetworkException.ERROR_UNSUPPORTED_OPERATION;
 import static android.os.Process.SYSTEM_UID;
 
 import static com.google.common.io.BaseEncoding.base16;
@@ -33,6 +34,7 @@
 import android.os.Binder;
 import android.os.OutcomeReceiver;
 import android.os.Process;
+import android.util.SparseIntArray;
 
 import androidx.test.ext.junit.runners.AndroidJUnit4;
 import androidx.test.filters.SmallTest;
@@ -77,6 +79,13 @@
     private static final ActiveOperationalDataset DEFAULT_DATASET =
             ActiveOperationalDataset.fromThreadTlvs(DEFAULT_DATASET_TLVS);
 
+    private static final SparseIntArray DEFAULT_CHANNEL_POWERS =
+            new SparseIntArray() {
+                {
+                    put(20, 32767);
+                }
+            };
+
     @Before
     public void setUp() {
         MockitoAnnotations.initMocks(this);
@@ -111,6 +120,10 @@
         return (IOperationReceiver) invocation.getArguments()[1];
     }
 
+    private static IOperationReceiver getSetChannelMaxPowersReceiver(InvocationOnMock invocation) {
+        return (IOperationReceiver) invocation.getArguments()[1]; // 2nd argument of the call
+    }
+
     private static IActiveOperationalDatasetReceiver getCreateDatasetReceiver(
             InvocationOnMock invocation) {
         return (IActiveOperationalDatasetReceiver) invocation.getArguments()[1];
@@ -361,6 +374,51 @@
     }
 
     @Test
+    public void setChannelMaxPowers_callbackIsInvokedWithCallingAppIdentity() throws Exception {
+        setBinderUid(SYSTEM_UID); // the callbacks below must NOT observe this UID
+
+        AtomicInteger successCallbackUid = new AtomicInteger(0);
+        AtomicInteger errorCallbackUid = new AtomicInteger(0);
+
+        doAnswer( // first call: the mocked service reports success
+                        invoke -> {
+                            getSetChannelMaxPowersReceiver(invoke).onSuccess();
+                            return null;
+                        })
+                .when(mMockService)
+                .setChannelMaxPowers(any(ChannelMaxPower[].class), any(IOperationReceiver.class));
+        mController.setChannelMaxPowers(
+                DEFAULT_CHANNEL_POWERS,
+                Runnable::run,
+                v -> successCallbackUid.set(Binder.getCallingUid()));
+        doAnswer( // second call: the mocked service reports an error
+                        invoke -> {
+                            getSetChannelMaxPowersReceiver(invoke)
+                                    .onError(ERROR_UNSUPPORTED_OPERATION, "");
+                            return null;
+                        })
+                .when(mMockService)
+                .setChannelMaxPowers(any(ChannelMaxPower[].class), any(IOperationReceiver.class));
+        mController.setChannelMaxPowers(
+                DEFAULT_CHANNEL_POWERS,
+                Runnable::run,
+                new OutcomeReceiver<>() {
+                    @Override
+                    public void onResult(Void unused) {}
+
+                    @Override
+                    public void onError(ThreadNetworkException e) {
+                        errorCallbackUid.set(Binder.getCallingUid());
+                    }
+                });
+
+        assertThat(successCallbackUid.get()).isNotEqualTo(SYSTEM_UID);
+        assertThat(successCallbackUid.get()).isEqualTo(Process.myUid());
+        assertThat(errorCallbackUid.get()).isNotEqualTo(SYSTEM_UID);
+        assertThat(errorCallbackUid.get()).isEqualTo(Process.myUid());
+    }
+
+    @Test
     public void setTestNetworkAsUpstream_callbackIsInvokedWithCallingAppIdentity()
             throws Exception {
         setBinderUid(SYSTEM_UID);
diff --git a/thread/tests/unit/src/android/net/thread/ThreadNetworkExceptionTest.java b/thread/tests/unit/src/android/net/thread/ThreadNetworkExceptionTest.java
index f62b437..5908c20 100644
--- a/thread/tests/unit/src/android/net/thread/ThreadNetworkExceptionTest.java
+++ b/thread/tests/unit/src/android/net/thread/ThreadNetworkExceptionTest.java
@@ -32,6 +32,6 @@
     public void constructor_tooLargeErrorCode_throwsIllegalArgumentException() throws Exception {
         // TODO (b/323791003): move this test case to cts/ThreadNetworkExceptionTest when mainline
         // CTS is ready.
-        assertThrows(IllegalArgumentException.class, () -> new ThreadNetworkException(13, "13"));
+        assertThrows(IllegalArgumentException.class, () -> new ThreadNetworkException(14, "14"));
     }
 }
diff --git a/thread/tests/unit/src/com/android/server/thread/NsdPublisherTest.java b/thread/tests/unit/src/com/android/server/thread/NsdPublisherTest.java
index d860166..8886c73 100644
--- a/thread/tests/unit/src/com/android/server/thread/NsdPublisherTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/NsdPublisherTest.java
@@ -23,6 +23,7 @@
 
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.spy;
@@ -30,24 +31,30 @@
 import static org.mockito.Mockito.verify;
 
 import android.net.InetAddresses;
+import android.net.nsd.DiscoveryRequest;
 import android.net.nsd.NsdManager;
 import android.net.nsd.NsdServiceInfo;
 import android.os.Handler;
 import android.os.test.TestLooper;
 
 import com.android.server.thread.openthread.DnsTxtAttribute;
+import com.android.server.thread.openthread.INsdDiscoverServiceCallback;
+import com.android.server.thread.openthread.INsdResolveServiceCallback;
 import com.android.server.thread.openthread.INsdStatusReceiver;
 
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
+import org.mockito.ArgumentMatcher;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
 import java.net.InetAddress;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.Executor;
 
@@ -57,6 +64,8 @@
 
     @Mock private INsdStatusReceiver mRegistrationReceiver;
     @Mock private INsdStatusReceiver mUnregistrationReceiver;
+    @Mock private INsdDiscoverServiceCallback mDiscoverServiceCallback;
+    @Mock private INsdResolveServiceCallback mResolveServiceCallback;
 
     private TestLooper mTestLooper;
     private NsdPublisher mNsdPublisher;
@@ -469,6 +478,165 @@
     }
 
     @Test
+    public void discoverService_serviceDiscovered() throws Exception {
+        prepareTest();
+
+        mNsdPublisher.discoverService("_test._tcp", mDiscoverServiceCallback, 10 /* listenerId */);
+        mTestLooper.dispatchAll(); // run the jobs queued on the test looper
+        ArgumentCaptor<NsdManager.DiscoveryListener> discoveryListenerArgumentCaptor =
+                ArgumentCaptor.forClass(NsdManager.DiscoveryListener.class);
+        verify(mMockNsdManager, times(1))
+                .discoverServices(
+                        eq(new DiscoveryRequest.Builder(PROTOCOL_DNS_SD, "_test._tcp").build()),
+                        any(Executor.class),
+                        discoveryListenerArgumentCaptor.capture());
+        NsdManager.DiscoveryListener actualDiscoveryListener =
+                discoveryListenerArgumentCaptor.getValue();
+        NsdServiceInfo serviceInfo = new NsdServiceInfo();
+        serviceInfo.setServiceName("test");
+        serviceInfo.setServiceType(null); // the callback must still report "_test._tcp"
+        actualDiscoveryListener.onServiceFound(serviceInfo); // simulate a found service
+        mTestLooper.dispatchAll();
+
+        verify(mDiscoverServiceCallback, times(1))
+                .onServiceDiscovered("test", "_test._tcp", true /* isFound */);
+    }
+
+    @Test
+    public void discoverService_serviceLost() throws Exception {
+        prepareTest();
+
+        mNsdPublisher.discoverService("_test._tcp", mDiscoverServiceCallback, 10 /* listenerId */);
+        mTestLooper.dispatchAll(); // run the jobs queued on the test looper
+        ArgumentCaptor<NsdManager.DiscoveryListener> discoveryListenerArgumentCaptor =
+                ArgumentCaptor.forClass(NsdManager.DiscoveryListener.class);
+        verify(mMockNsdManager, times(1))
+                .discoverServices(
+                        eq(new DiscoveryRequest.Builder(PROTOCOL_DNS_SD, "_test._tcp").build()),
+                        any(Executor.class),
+                        discoveryListenerArgumentCaptor.capture());
+        NsdManager.DiscoveryListener actualDiscoveryListener =
+                discoveryListenerArgumentCaptor.getValue();
+        NsdServiceInfo serviceInfo = new NsdServiceInfo();
+        serviceInfo.setServiceName("test");
+        serviceInfo.setServiceType(null); // the callback must still report "_test._tcp"
+        actualDiscoveryListener.onServiceLost(serviceInfo); // simulate a lost service
+        mTestLooper.dispatchAll();
+
+        verify(mDiscoverServiceCallback, times(1))
+                .onServiceDiscovered("test", "_test._tcp", false /* isFound */);
+    }
+
+    @Test
+    public void stopServiceDiscovery() {
+        prepareTest();
+
+        mNsdPublisher.discoverService("_test._tcp", mDiscoverServiceCallback, 10 /* listenerId */);
+        mTestLooper.dispatchAll(); // run the jobs queued on the test looper
+        ArgumentCaptor<NsdManager.DiscoveryListener> discoveryListenerArgumentCaptor =
+                ArgumentCaptor.forClass(NsdManager.DiscoveryListener.class);
+        verify(mMockNsdManager, times(1))
+                .discoverServices(
+                        eq(new DiscoveryRequest.Builder(PROTOCOL_DNS_SD, "_test._tcp").build()),
+                        any(Executor.class),
+                        discoveryListenerArgumentCaptor.capture());
+        NsdManager.DiscoveryListener actualDiscoveryListener =
+                discoveryListenerArgumentCaptor.getValue();
+        NsdServiceInfo serviceInfo = new NsdServiceInfo();
+        serviceInfo.setServiceName("test");
+        serviceInfo.setServiceType(null);
+        actualDiscoveryListener.onServiceFound(serviceInfo);
+        mNsdPublisher.stopServiceDiscovery(10 /* listenerId */); // same id as discoverService()
+        mTestLooper.dispatchAll();
+
+        verify(mMockNsdManager, times(1)).stopServiceDiscovery(actualDiscoveryListener);
+    }
+
+    @Test
+    public void resolveService_serviceResolved() throws Exception {
+        prepareTest();
+
+        mNsdPublisher.resolveService(
+                "test", "_test._tcp", mResolveServiceCallback, 10 /* listenerId */);
+        mTestLooper.dispatchAll(); // run the jobs queued on the test looper
+        ArgumentCaptor<NsdServiceInfo> serviceInfoArgumentCaptor =
+                ArgumentCaptor.forClass(NsdServiceInfo.class);
+        ArgumentCaptor<NsdManager.ServiceInfoCallback> serviceInfoCallbackArgumentCaptor =
+                ArgumentCaptor.forClass(NsdManager.ServiceInfoCallback.class);
+        verify(mMockNsdManager, times(1))
+                .registerServiceInfoCallback(
+                        serviceInfoArgumentCaptor.capture(),
+                        any(Executor.class),
+                        serviceInfoCallbackArgumentCaptor.capture());
+        assertThat(serviceInfoArgumentCaptor.getValue().getServiceName()).isEqualTo("test");
+        assertThat(serviceInfoArgumentCaptor.getValue().getServiceType()).isEqualTo("_test._tcp");
+        NsdServiceInfo serviceInfo = new NsdServiceInfo(); // the update fed to the callback
+        serviceInfo.setServiceName("test");
+        serviceInfo.setServiceType("_test._tcp");
+        serviceInfo.setPort(12345);
+        serviceInfo.setHostname("test-host");
+        serviceInfo.setHostAddresses(
+                List.of(
+                        InetAddress.parseNumericAddress("2001::1"),
+                        InetAddress.parseNumericAddress("2001::2")));
+        serviceInfo.setAttribute("key1", new byte[] {(byte) 0x01, (byte) 0x02});
+        serviceInfo.setAttribute("key2", new byte[] {(byte) 0x03});
+        serviceInfoCallbackArgumentCaptor.getValue().onServiceUpdated(serviceInfo);
+        mTestLooper.dispatchAll();
+
+        verify(mResolveServiceCallback, times(1))
+                .onServiceResolved(
+                        eq("test-host"),
+                        eq("test"),
+                        eq("_test._tcp"),
+                        eq(12345),
+                        eq(List.of("2001::1", "2001::2")),
+                        argThat(
+                                new TxtMatcher(
+                                        List.of(
+                                                makeTxtAttribute("key1", List.of(0x01, 0x02)),
+                                                makeTxtAttribute("key2", List.of(0x03))))),
+                        anyInt());
+    }
+
+    @Test
+    public void stopServiceResolution() throws Exception {
+        prepareTest();
+
+        mNsdPublisher.resolveService(
+                "test", "_test._tcp", mResolveServiceCallback, 10 /* listenerId */);
+        mTestLooper.dispatchAll(); // run the jobs queued on the test looper
+        ArgumentCaptor<NsdServiceInfo> serviceInfoArgumentCaptor =
+                ArgumentCaptor.forClass(NsdServiceInfo.class);
+        ArgumentCaptor<NsdManager.ServiceInfoCallback> serviceInfoCallbackArgumentCaptor =
+                ArgumentCaptor.forClass(NsdManager.ServiceInfoCallback.class);
+        verify(mMockNsdManager, times(1))
+                .registerServiceInfoCallback(
+                        serviceInfoArgumentCaptor.capture(),
+                        any(Executor.class),
+                        serviceInfoCallbackArgumentCaptor.capture());
+        assertThat(serviceInfoArgumentCaptor.getValue().getServiceName()).isEqualTo("test");
+        assertThat(serviceInfoArgumentCaptor.getValue().getServiceType()).isEqualTo("_test._tcp");
+        NsdServiceInfo serviceInfo = new NsdServiceInfo();
+        serviceInfo.setServiceName("test");
+        serviceInfo.setServiceType("_test._tcp");
+        serviceInfo.setPort(12345);
+        serviceInfo.setHostname("test-host");
+        serviceInfo.setHostAddresses(
+                List.of(
+                        InetAddress.parseNumericAddress("2001::1"),
+                        InetAddress.parseNumericAddress("2001::2")));
+        serviceInfo.setAttribute("key1", new byte[] {(byte) 0x01, (byte) 0x02});
+        serviceInfo.setAttribute("key2", new byte[] {(byte) 0x03});
+        serviceInfoCallbackArgumentCaptor.getValue().onServiceUpdated(serviceInfo);
+        mNsdPublisher.stopServiceResolution(10 /* listenerId */); // same id as resolveService()
+        mTestLooper.dispatchAll();
+
+        verify(mMockNsdManager, times(1))
+                .unregisterServiceInfoCallback(serviceInfoCallbackArgumentCaptor.getValue());
+    }
+
+    @Test
     public void reset_unregisterAll() {
         prepareTest();
 
@@ -582,6 +750,30 @@
         return addresses;
     }
 
+    private static class TxtMatcher implements ArgumentMatcher<List<DnsTxtAttribute>> {
+        private final List<DnsTxtAttribute> mAttributes;
+
+        TxtMatcher(List<DnsTxtAttribute> attributes) {
+            mAttributes = attributes;
+        }
+
+        @Override
+        public boolean matches(List<DnsTxtAttribute> argument) {
+            if (argument.size() != mAttributes.size()) {
+                return false;
+            }
+            for (int i = 0; i < argument.size(); ++i) {
+                if (!Objects.equals(argument.get(i).name, mAttributes.get(i).name)) {
+                    return false;
+                }
+                if (!Arrays.equals(argument.get(i).value, mAttributes.get(i).value)) {
+                    return false;
+                }
+            }
+            return true;
+        }
+    }
+
     // @Before and @Test run in different threads. NsdPublisher requires the jobs are run on the
     // thread looper, so TestLooper needs to be created inside each test case to install the
     // correct looper.
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
index 60a5f2b..85b6873 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
@@ -21,9 +21,11 @@
 import static android.net.thread.ThreadNetworkController.STATE_ENABLED;
 import static android.net.thread.ThreadNetworkException.ERROR_FAILED_PRECONDITION;
 import static android.net.thread.ThreadNetworkException.ERROR_INTERNAL_ERROR;
+import static android.net.thread.ThreadNetworkException.ERROR_THREAD_DISABLED;
 import static android.net.thread.ThreadNetworkManager.DISALLOW_THREAD_NETWORK;
 import static android.net.thread.ThreadNetworkManager.PERMISSION_THREAD_NETWORK_PRIVILEGED;
 
+import static com.android.server.thread.ThreadNetworkCountryCode.DEFAULT_COUNTRY_CODE;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_INVALID_STATE;
 
 import static com.google.common.io.BaseEncoding.base16;
@@ -37,6 +39,7 @@
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.spy;
@@ -47,6 +50,7 @@
 import android.content.BroadcastReceiver;
 import android.content.Context;
 import android.content.Intent;
+import android.content.res.Resources;
 import android.net.ConnectivityManager;
 import android.net.NetworkAgent;
 import android.net.NetworkProvider;
@@ -65,6 +69,9 @@
 import androidx.test.ext.junit.runners.AndroidJUnit4;
 import androidx.test.filters.SmallTest;
 
+import com.android.connectivity.resources.R;
+import com.android.server.connectivity.ConnectivityResources;
+import com.android.server.thread.openthread.MeshcopTxtAttributes;
 import com.android.server.thread.openthread.testing.FakeOtDaemon;
 
 import org.junit.Before;
@@ -72,7 +79,9 @@
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Captor;
+import org.mockito.InOrder;
 import org.mockito.Mock;
+import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
 
 import java.util.concurrent.CompletableFuture;
@@ -110,6 +119,11 @@
     private static final int DEFAULT_SELECTED_CHANNEL = 11;
     private static final byte[] DEFAULT_SUPPORTED_CHANNEL_MASK_ARRAY = base16().decode("001FFFE0");
 
+    private static final String TEST_VENDOR_OUI = "AC-DE-48";
+    private static final byte[] TEST_VENDOR_OUI_BYTES = new byte[] {(byte) 0xAC, (byte) 0xDE, 0x48};
+    private static final String TEST_VENDOR_NAME = "test vendor";
+    private static final String TEST_MODEL_NAME = "test model";
+
     @Mock private ConnectivityManager mMockConnectivityManager;
     @Mock private NetworkAgent mMockNetworkAgent;
     @Mock private TunInterfaceController mMockTunIfController;
@@ -119,6 +133,9 @@
     @Mock private NsdPublisher mMockNsdPublisher;
     @Mock private UserManager mMockUserManager;
     @Mock private IBinder mIBinder;
+    @Mock Resources mResources;
+    @Mock ConnectivityResources mConnectivityResources;
+
     private Context mContext;
     private TestLooper mTestLooper;
     private FakeOtDaemon mFakeOtDaemon;
@@ -146,6 +163,14 @@
         when(mMockPersistentSettings.get(any())).thenReturn(true);
         when(mMockUserManager.hasUserRestriction(eq(DISALLOW_THREAD_NETWORK))).thenReturn(false);
 
+        when(mConnectivityResources.get()).thenReturn(mResources);
+        when(mResources.getString(eq(R.string.config_thread_vendor_name)))
+                .thenReturn(TEST_VENDOR_NAME);
+        when(mResources.getString(eq(R.string.config_thread_vendor_oui)))
+                .thenReturn(TEST_VENDOR_OUI);
+        when(mResources.getString(eq(R.string.config_thread_model_name)))
+                .thenReturn(TEST_MODEL_NAME);
+
         mService =
                 new ThreadNetworkControllerService(
                         mContext,
@@ -157,7 +182,9 @@
                         mMockInfraIfController,
                         mMockPersistentSettings,
                         mMockNsdPublisher,
-                        mMockUserManager);
+                        mMockUserManager,
+                        mConnectivityResources,
+                        () -> DEFAULT_COUNTRY_CODE);
         mService.setTestNetworkAgent(mMockNetworkAgent);
     }
 
@@ -174,6 +201,93 @@
     }
 
     @Test
+    public void initialize_vendorAndModelNameInResourcesAreSetToOtDaemon() throws Exception {
+        when(mResources.getString(eq(R.string.config_thread_vendor_name)))
+                .thenReturn(TEST_VENDOR_NAME);
+        when(mResources.getString(eq(R.string.config_thread_vendor_oui)))
+                .thenReturn(TEST_VENDOR_OUI);
+        when(mResources.getString(eq(R.string.config_thread_model_name)))
+                .thenReturn(TEST_MODEL_NAME);
+
+        mService.initialize();
+        mTestLooper.dispatchAll();
+
+        MeshcopTxtAttributes meshcopTxts = mFakeOtDaemon.getOverriddenMeshcopTxtAttributes();
+        assertThat(meshcopTxts.vendorName).isEqualTo(TEST_VENDOR_NAME);
+        assertThat(meshcopTxts.vendorOui).isEqualTo(TEST_VENDOR_OUI_BYTES);
+        assertThat(meshcopTxts.modelName).isEqualTo(TEST_MODEL_NAME);
+    }
+
+    @Test
+    public void getMeshcopTxtAttributes_emptyVendorName_accepted() {
+        when(mResources.getString(eq(R.string.config_thread_vendor_name))).thenReturn("");
+
+        MeshcopTxtAttributes meshcopTxts =
+                ThreadNetworkControllerService.getMeshcopTxtAttributes(mResources);
+
+        assertThat(meshcopTxts.vendorName).isEqualTo("");
+    }
+
+    @Test
+    public void getMeshcopTxtAttributes_tooLongVendorName_throwsIllegalStateException() {
+        when(mResources.getString(eq(R.string.config_thread_vendor_name)))
+                .thenReturn("vendor name is 25 bytes!!");
+
+        assertThrows(
+                IllegalStateException.class,
+                () -> ThreadNetworkControllerService.getMeshcopTxtAttributes(mResources));
+    }
+
+    @Test
+    public void getMeshcopTxtAttributes_tooLongModelName_throwsIllegalStateException() {
+        when(mResources.getString(eq(R.string.config_thread_model_name)))
+                .thenReturn("model name is 25 bytes!!!");
+
+        assertThrows(
+                IllegalStateException.class,
+                () -> ThreadNetworkControllerService.getMeshcopTxtAttributes(mResources));
+    }
+
+    @Test
+    public void getMeshcopTxtAttributes_emptyModelName_accepted() {
+        when(mResources.getString(eq(R.string.config_thread_model_name))).thenReturn("");
+
+        var meshcopTxts = ThreadNetworkControllerService.getMeshcopTxtAttributes(mResources);
+        assertThat(meshcopTxts.modelName).isEqualTo("");
+    }
+
+    @Test
+    public void getMeshcopTxtAttributes_invalidVendorOui_throwsIllegalStateException() {
+        assertThrows(
+                IllegalStateException.class, () -> getMeshcopTxtAttributesWithVendorOui("ABCDEFA"));
+        assertThrows(
+                IllegalStateException.class, () -> getMeshcopTxtAttributesWithVendorOui("ABCDEG"));
+        assertThrows(
+                IllegalStateException.class, () -> getMeshcopTxtAttributesWithVendorOui("ABCD"));
+        assertThrows(
+                IllegalStateException.class,
+                () -> getMeshcopTxtAttributesWithVendorOui("AB.CD.EF"));
+    }
+
+    @Test
+    public void getMeshcopTxtAttributes_validVendorOui_accepted() {
+        assertThat(getMeshcopTxtAttributesWithVendorOui("010203")).isEqualTo(new byte[] {1, 2, 3});
+        assertThat(getMeshcopTxtAttributesWithVendorOui("01-02-03"))
+                .isEqualTo(new byte[] {1, 2, 3});
+        assertThat(getMeshcopTxtAttributesWithVendorOui("01:02:03"))
+                .isEqualTo(new byte[] {1, 2, 3});
+        assertThat(getMeshcopTxtAttributesWithVendorOui("ABCDEF"))
+                .isEqualTo(new byte[] {(byte) 0xAB, (byte) 0xCD, (byte) 0xEF});
+        assertThat(getMeshcopTxtAttributesWithVendorOui("abcdef"))
+                .isEqualTo(new byte[] {(byte) 0xAB, (byte) 0xCD, (byte) 0xEF});
+    }
+
+    private byte[] getMeshcopTxtAttributesWithVendorOui(String vendorOui) {
+        when(mResources.getString(eq(R.string.config_thread_vendor_oui))).thenReturn(vendorOui);
+        return ThreadNetworkControllerService.getMeshcopTxtAttributes(mResources).vendorOui;
+    }
+
+    @Test
     public void join_otDaemonRemoteFailure_returnsInternalError() throws Exception {
         mService.initialize();
         final IOperationReceiver mockReceiver = mock(IOperationReceiver.class);
@@ -204,13 +318,13 @@
     }
 
     @Test
-    public void userRestriction_initWithUserRestricted_threadIsDisabled() {
+    public void userRestriction_initWithUserRestricted_otDaemonNotStarted() {
         when(mMockUserManager.hasUserRestriction(eq(DISALLOW_THREAD_NETWORK))).thenReturn(true);
 
         mService.initialize();
         mTestLooper.dispatchAll();
 
-        assertThat(mFakeOtDaemon.getEnabledState()).isEqualTo(STATE_DISABLED);
+        assertThat(mFakeOtDaemon.isInitialized()).isFalse();
     }
 
     @Test
@@ -332,4 +446,70 @@
         verify(mockReceiver, never()).onSuccess(any(ActiveOperationalDataset.class));
         verify(mockReceiver, times(1)).onError(eq(ERROR_INTERNAL_ERROR), anyString());
     }
+
+    @Test
+    public void forceStopOtDaemonForTest_noPermission_throwsSecurityException() {
+        doThrow(new SecurityException(""))
+                .when(mContext)
+                .enforceCallingOrSelfPermission(eq(PERMISSION_THREAD_NETWORK_PRIVILEGED), any());
+
+        assertThrows(
+                SecurityException.class,
+                () -> mService.forceStopOtDaemonForTest(true, new IOperationReceiver.Default()));
+    }
+
+    @Test
+    public void forceStopOtDaemonForTest_enabled_otDaemonDiesAndJoinFails() throws Exception {
+        mService.initialize();
+        IOperationReceiver mockReceiver = mock(IOperationReceiver.class);
+        IOperationReceiver mockJoinReceiver = mock(IOperationReceiver.class);
+
+        mService.forceStopOtDaemonForTest(true, mockReceiver);
+        mTestLooper.dispatchAll();
+        mService.join(DEFAULT_ACTIVE_DATASET, mockJoinReceiver);
+        mTestLooper.dispatchAll();
+
+        verify(mockReceiver, times(1)).onSuccess();
+        assertThat(mFakeOtDaemon.isInitialized()).isFalse();
+        verify(mockJoinReceiver, times(1)).onError(eq(ERROR_THREAD_DISABLED), anyString());
+    }
+
+    @Test
+    public void forceStopOtDaemonForTest_disable_otDaemonRestartsAndJoinSuccess() throws Exception {
+        mService.initialize();
+        IOperationReceiver mockReceiver = mock(IOperationReceiver.class);
+        IOperationReceiver mockJoinReceiver = mock(IOperationReceiver.class);
+
+        mService.forceStopOtDaemonForTest(true, mock(IOperationReceiver.class));
+        mTestLooper.dispatchAll();
+        mService.forceStopOtDaemonForTest(false, mockReceiver);
+        mTestLooper.dispatchAll();
+        mService.join(DEFAULT_ACTIVE_DATASET, mockJoinReceiver);
+        mTestLooper.dispatchAll();
+        mTestLooper.moveTimeForward(FakeOtDaemon.JOIN_DELAY.toMillis() + 100);
+        mTestLooper.dispatchAll();
+
+        verify(mockReceiver, times(1)).onSuccess();
+        assertThat(mFakeOtDaemon.isInitialized()).isTrue();
+        verify(mockJoinReceiver, times(1)).onSuccess();
+    }
+
+    @Test
+    public void onOtDaemonDied_joinedNetwork_interfaceStateBackToUp() throws Exception {
+        mService.initialize();
+        final IOperationReceiver mockReceiver = mock(IOperationReceiver.class);
+        mService.join(DEFAULT_ACTIVE_DATASET, mockReceiver);
+        mTestLooper.dispatchAll();
+        mTestLooper.moveTimeForward(FakeOtDaemon.JOIN_DELAY.toMillis() + 100);
+        mTestLooper.dispatchAll();
+
+        Mockito.reset(mMockInfraIfController);
+        mFakeOtDaemon.terminate();
+        mTestLooper.dispatchAll();
+
+        verify(mMockTunIfController, times(1)).onOtDaemonDied();
+        InOrder inOrder = Mockito.inOrder(mMockTunIfController);
+        inOrder.verify(mMockTunIfController, times(1)).setInterfaceUp(false);
+        inOrder.verify(mMockTunIfController, times(1)).setInterfaceUp(true);
+    }
 }
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
index 5ca6511..ca9741d 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
@@ -19,6 +19,7 @@
 import static android.net.thread.ThreadNetworkException.ERROR_INTERNAL_ERROR;
 
 import static com.android.server.thread.ThreadNetworkCountryCode.DEFAULT_COUNTRY_CODE;
+import static com.android.server.thread.ThreadPersistentSettings.THREAD_COUNTRY_CODE;
 
 import static com.google.common.truth.Truth.assertThat;
 
@@ -104,6 +105,7 @@
     @Mock List<SubscriptionInfo> mSubscriptionInfoList;
     @Mock SubscriptionInfo mSubscriptionInfo0;
     @Mock SubscriptionInfo mSubscriptionInfo1;
+    @Mock ThreadPersistentSettings mPersistentSettings;
 
     private ThreadNetworkCountryCode mThreadNetworkCountryCode;
     private boolean mErrorSetCountryCode;
@@ -164,7 +166,8 @@
                 mContext,
                 mTelephonyManager,
                 mSubscriptionManager,
-                oemCountryCode);
+                oemCountryCode,
+                mPersistentSettings);
     }
 
     private static Address newAddress(String countryCode) {
@@ -450,6 +453,14 @@
     }
 
     @Test
+    public void settingsCountryCode_settingsCountryCodeIsActive_settingsCountryCodeIsUsed() {
+        when(mPersistentSettings.get(THREAD_COUNTRY_CODE)).thenReturn(TEST_COUNTRY_CODE_CN);
+        mThreadNetworkCountryCode.initialize();
+
+        assertThat(mThreadNetworkCountryCode.getCountryCode()).isEqualTo(TEST_COUNTRY_CODE_CN);
+    }
+
+    @Test
     public void dump_allCountryCodeInfoAreDumped() {
         StringWriter stringWriter = new StringWriter();
         PrintWriter printWriter = new PrintWriter(stringWriter);
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java
index c7e0eca..9f2d0cb 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java
@@ -16,10 +16,15 @@
 
 package com.android.server.thread;
 
+import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.contains;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.validateMockitoUsage;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
@@ -45,19 +50,19 @@
 @SmallTest
 public class ThreadNetworkShellCommandTest {
     private static final String TAG = "ThreadNetworkShellCommandTTest";
-    @Mock ThreadNetworkService mThreadNetworkService;
-    @Mock ThreadNetworkCountryCode mThreadNetworkCountryCode;
+    @Mock ThreadNetworkControllerService mControllerService;
+    @Mock ThreadNetworkCountryCode mCountryCode;
     @Mock PrintWriter mErrorWriter;
     @Mock PrintWriter mOutputWriter;
 
-    ThreadNetworkShellCommand mThreadNetworkShellCommand;
+    ThreadNetworkShellCommand mShellCommand;
 
     @Before
     public void setUp() throws Exception {
         MockitoAnnotations.initMocks(this);
 
-        mThreadNetworkShellCommand = new ThreadNetworkShellCommand(mThreadNetworkCountryCode);
-        mThreadNetworkShellCommand.setPrintWriters(mOutputWriter, mErrorWriter);
+        mShellCommand = new ThreadNetworkShellCommand(mControllerService, mCountryCode);
+        mShellCommand.setPrintWriters(mOutputWriter, mErrorWriter);
     }
 
     @After
@@ -68,9 +73,9 @@
     @Test
     public void getCountryCode_executeInUnrootedShell_allowed() {
         BinderUtil.setUid(Process.SHELL_UID);
-        when(mThreadNetworkCountryCode.getCountryCode()).thenReturn("US");
+        when(mCountryCode.getCountryCode()).thenReturn("US");
 
-        mThreadNetworkShellCommand.exec(
+        mShellCommand.exec(
                 new Binder(),
                 new FileDescriptor(),
                 new FileDescriptor(),
@@ -84,14 +89,14 @@
     public void forceSetCountryCodeEnabled_executeInUnrootedShell_notAllowed() {
         BinderUtil.setUid(Process.SHELL_UID);
 
-        mThreadNetworkShellCommand.exec(
+        mShellCommand.exec(
                 new Binder(),
                 new FileDescriptor(),
                 new FileDescriptor(),
                 new FileDescriptor(),
                 new String[] {"force-country-code", "enabled", "US"});
 
-        verify(mThreadNetworkCountryCode, never()).setOverrideCountryCode(eq("US"));
+        verify(mCountryCode, never()).setOverrideCountryCode(eq("US"));
         verify(mErrorWriter).println(contains("force-country-code"));
     }
 
@@ -99,28 +104,28 @@
     public void forceSetCountryCodeEnabled_executeInRootedShell_allowed() {
         BinderUtil.setUid(Process.ROOT_UID);
 
-        mThreadNetworkShellCommand.exec(
+        mShellCommand.exec(
                 new Binder(),
                 new FileDescriptor(),
                 new FileDescriptor(),
                 new FileDescriptor(),
                 new String[] {"force-country-code", "enabled", "US"});
 
-        verify(mThreadNetworkCountryCode).setOverrideCountryCode(eq("US"));
+        verify(mCountryCode).setOverrideCountryCode(eq("US"));
     }
 
     @Test
     public void forceSetCountryCodeDisabled_executeInUnrootedShell_notAllowed() {
         BinderUtil.setUid(Process.SHELL_UID);
 
-        mThreadNetworkShellCommand.exec(
+        mShellCommand.exec(
                 new Binder(),
                 new FileDescriptor(),
                 new FileDescriptor(),
                 new FileDescriptor(),
                 new String[] {"force-country-code", "disabled"});
 
-        verify(mThreadNetworkCountryCode, never()).setOverrideCountryCode(any());
+        verify(mCountryCode, never()).setOverrideCountryCode(any());
         verify(mErrorWriter).println(contains("force-country-code"));
     }
 
@@ -128,13 +133,64 @@
     public void forceSetCountryCodeDisabled_executeInRootedShell_allowed() {
         BinderUtil.setUid(Process.ROOT_UID);
 
-        mThreadNetworkShellCommand.exec(
+        mShellCommand.exec(
                 new Binder(),
                 new FileDescriptor(),
                 new FileDescriptor(),
                 new FileDescriptor(),
                 new String[] {"force-country-code", "disabled"});
 
-        verify(mThreadNetworkCountryCode).clearOverrideCountryCode();
+        verify(mCountryCode).clearOverrideCountryCode();
+    }
+
+    @Test
+    public void forceStopOtDaemon_executeInUnrootedShell_failedAndServiceApiNotCalled() {
+        BinderUtil.setUid(Process.SHELL_UID);
+
+        mShellCommand.exec(
+                new Binder(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new String[] {"force-stop-ot-daemon", "enabled"});
+
+        verify(mControllerService, never()).forceStopOtDaemonForTest(anyBoolean(), any());
+        verify(mErrorWriter, atLeastOnce()).println(contains("force-stop-ot-daemon"));
+        verify(mOutputWriter, never()).println();
+    }
+
+    @Test
+    public void forceStopOtDaemon_serviceThrows_failed() {
+        BinderUtil.setUid(Process.ROOT_UID);
+        doThrow(new SecurityException(""))
+                .when(mControllerService)
+                .forceStopOtDaemonForTest(eq(true), any());
+
+        mShellCommand.exec(
+                new Binder(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new String[] {"force-stop-ot-daemon", "enabled"});
+
+        verify(mControllerService, times(1)).forceStopOtDaemonForTest(eq(true), any());
+        verify(mOutputWriter, never()).println();
+    }
+
+    @Test
+    public void forceStopOtDaemon_serviceApiTimeout_failedWithTimeoutError() {
+        BinderUtil.setUid(Process.ROOT_UID);
+        doNothing().when(mControllerService).forceStopOtDaemonForTest(eq(true), any());
+
+        mShellCommand.exec(
+                new Binder(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new String[] {"force-stop-ot-daemon", "enabled"});
+
+        verify(mControllerService, times(1)).forceStopOtDaemonForTest(eq(true), any());
+        verify(mErrorWriter, atLeastOnce()).println(contains("timeout"));
+        verify(mOutputWriter, never()).println();
     }
 }
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadPersistentSettingsTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadPersistentSettingsTest.java
index 9406a2f..7d2fe91 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadPersistentSettingsTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadPersistentSettingsTest.java
@@ -16,6 +16,7 @@
 
 package com.android.server.thread;
 
+import static com.android.server.thread.ThreadPersistentSettings.THREAD_COUNTRY_CODE;
 import static com.android.server.thread.ThreadPersistentSettings.THREAD_ENABLED;
 
 import static com.google.common.truth.Truth.assertThat;
@@ -54,6 +55,8 @@
 @RunWith(AndroidJUnit4.class)
 @SmallTest
 public class ThreadPersistentSettingsTest {
+    private static final String TEST_COUNTRY_CODE = "CN";
+
     @Mock private AtomicFile mAtomicFile;
     @Mock Resources mResources;
     @Mock ConnectivityResources mConnectivityResources;
@@ -131,6 +134,28 @@
         verify(mAtomicFile).finishWrite(any());
     }
 
+    @Test
+    public void put_ThreadCountryCodeString_returnsString() throws Exception {
+        mThreadPersistentSettings.put(THREAD_COUNTRY_CODE.key, TEST_COUNTRY_CODE);
+
+        assertThat(mThreadPersistentSettings.get(THREAD_COUNTRY_CODE)).isEqualTo(TEST_COUNTRY_CODE);
+
+        // Confirm that file writes have been triggered.
+        verify(mAtomicFile).startWrite();
+        verify(mAtomicFile).finishWrite(any());
+    }
+
+    @Test
+    public void put_ThreadCountryCodeNull_returnsNull() throws Exception {
+        mThreadPersistentSettings.put(THREAD_COUNTRY_CODE.key, null);
+
+        assertThat(mThreadPersistentSettings.get(THREAD_COUNTRY_CODE)).isNull();
+
+        // Confirm that file writes have been triggered.
+        verify(mAtomicFile).startWrite();
+        verify(mAtomicFile).finishWrite(any());
+    }
+
     private byte[] createXmlForParsing(String key, Boolean value) throws Exception {
         PersistableBundle bundle = new PersistableBundle();
         ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
diff --git a/thread/tests/utils/Android.bp b/thread/tests/utils/Android.bp
index 24e9bb9..726ec9d 100644
--- a/thread/tests/utils/Android.bp
+++ b/thread/tests/utils/Android.bp
@@ -27,6 +27,7 @@
         "net-tests-utils",
         "net-utils-device-common",
         "net-utils-device-common-bpf",
+        "net-utils-device-common-struct-base",
     ],
     srcs: [
         "src/**/*.java",
diff --git a/thread/tests/utils/src/android/net/thread/utils/TapTestNetworkTracker.java b/thread/tests/utils/src/android/net/thread/utils/TapTestNetworkTracker.java
index 43f177d..b586a19 100644
--- a/thread/tests/utils/src/android/net/thread/utils/TapTestNetworkTracker.java
+++ b/thread/tests/utils/src/android/net/thread/utils/TapTestNetworkTracker.java
@@ -62,6 +62,7 @@
     private final Looper mLooper;
     private TestNetworkInterface mInterface;
     private TestableNetworkAgent mAgent;
+    private Network mNetwork;
     private final TestableNetworkCallback mNetworkCallback;
     private final ConnectivityManager mConnectivityManager;
 
@@ -91,6 +92,11 @@
         return mInterface.getInterfaceName();
     }
 
+    /** Returns the {@link android.net.Network} of the test network. */
+    public Network getNetwork() {
+        return mNetwork;
+    }
+
     private void setUpTestNetwork() throws Exception {
         mInterface = mContext.getSystemService(TestNetworkManager.class).createTapInterface();
 
@@ -105,13 +111,13 @@
                         newNetworkCapabilities(),
                         lp,
                         new NetworkAgentConfig.Builder().build());
-        final Network network = mAgent.register();
+        mNetwork = mAgent.register();
         mAgent.markConnected();
 
         PollingCheck.check(
                 "No usable address on interface",
                 TIMEOUT.toMillis(),
-                () -> hasUsableAddress(network, getInterfaceName()));
+                () -> hasUsableAddress(mNetwork, getInterfaceName()));
 
         lp.setLinkAddresses(makeLinkAddresses());
         mAgent.sendLinkProperties(lp);