Merge "Update THREAD_ENABLED key string in ThreadPersistentSettings" into main
diff --git a/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java b/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
index 1022b06..f696885 100644
--- a/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
+++ b/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
@@ -176,6 +176,7 @@
         // tests, turn tethering on and off before running them.
         MyTetheringEventCallback callback = null;
         TestNetworkInterface testIface = null;
+        assumeTrue(sEm != null);
         try {
             // If the physical ethernet interface is available, do nothing.
             if (isInterfaceForTetheringAvailable()) return;
diff --git a/framework-t/src/android/net/EthernetManager.java b/framework-t/src/android/net/EthernetManager.java
index b8070f0..719f60d 100644
--- a/framework-t/src/android/net/EthernetManager.java
+++ b/framework-t/src/android/net/EthernetManager.java
@@ -642,7 +642,14 @@
     }
 
     /**
-     * Listen to changes in the state of ethernet.
+     * Register an IntConsumer to be called back on ethernet state changes.
+     *
+     * <p>{@link IntConsumer#accept} with the current ethernet state will be triggered immediately
+     * upon adding a listener. The same callback is invoked on Ethernet state change, i.e. when
+     * calling {@link #setEthernetEnabled}.
+     * <p>The reported state is one of:
+     * <ul><li>{@link #ETHERNET_STATE_DISABLED}: ethernet is now disabled.</li>
+     * <li>{@link #ETHERNET_STATE_ENABLED}: ethernet is now enabled.</li></ul>
      *
      * @param executor to run callbacks on.
      * @param listener to listen ethernet state changed.
diff --git a/framework/src/android/net/BpfNetMapsUtils.java b/framework/src/android/net/BpfNetMapsUtils.java
index 0be30bb..3c91db2 100644
--- a/framework/src/android/net/BpfNetMapsUtils.java
+++ b/framework/src/android/net/BpfNetMapsUtils.java
@@ -18,17 +18,22 @@
 
 import static android.net.BpfNetMapsConstants.ALLOW_CHAINS;
 import static android.net.BpfNetMapsConstants.BACKGROUND_MATCH;
+import static android.net.BpfNetMapsConstants.DATA_SAVER_ENABLED;
+import static android.net.BpfNetMapsConstants.DATA_SAVER_ENABLED_KEY;
 import static android.net.BpfNetMapsConstants.DENY_CHAINS;
 import static android.net.BpfNetMapsConstants.DOZABLE_MATCH;
+import static android.net.BpfNetMapsConstants.HAPPY_BOX_MATCH;
 import static android.net.BpfNetMapsConstants.LOW_POWER_STANDBY_MATCH;
 import static android.net.BpfNetMapsConstants.MATCH_LIST;
 import static android.net.BpfNetMapsConstants.NO_MATCH;
 import static android.net.BpfNetMapsConstants.OEM_DENY_1_MATCH;
 import static android.net.BpfNetMapsConstants.OEM_DENY_2_MATCH;
 import static android.net.BpfNetMapsConstants.OEM_DENY_3_MATCH;
+import static android.net.BpfNetMapsConstants.PENALTY_BOX_MATCH;
 import static android.net.BpfNetMapsConstants.POWERSAVE_MATCH;
 import static android.net.BpfNetMapsConstants.RESTRICTED_MATCH;
 import static android.net.BpfNetMapsConstants.STANDBY_MATCH;
+import static android.net.BpfNetMapsConstants.UID_RULES_CONFIGURATION_KEY;
 import static android.net.ConnectivityManager.FIREWALL_CHAIN_BACKGROUND;
 import static android.net.ConnectivityManager.FIREWALL_CHAIN_DOZABLE;
 import static android.net.ConnectivityManager.FIREWALL_CHAIN_LOW_POWER_STANDBY;
@@ -38,12 +43,21 @@
 import static android.net.ConnectivityManager.FIREWALL_CHAIN_POWERSAVE;
 import static android.net.ConnectivityManager.FIREWALL_CHAIN_RESTRICTED;
 import static android.net.ConnectivityManager.FIREWALL_CHAIN_STANDBY;
+import static android.net.ConnectivityManager.FIREWALL_RULE_ALLOW;
+import static android.net.ConnectivityManager.FIREWALL_RULE_DENY;
 import static android.system.OsConstants.EINVAL;
 
 import android.os.ServiceSpecificException;
+import android.system.ErrnoException;
+import android.system.Os;
 import android.util.Pair;
 
 import com.android.modules.utils.build.SdkLevel;
+import com.android.net.module.util.IBpfMap;
+import com.android.net.module.util.Struct;
+import com.android.net.module.util.Struct.S32;
+import com.android.net.module.util.Struct.U32;
+import com.android.net.module.util.Struct.U8;
 
 import java.util.StringJoiner;
 
@@ -56,6 +70,26 @@
 // Because modules could have different copies of this class if this is statically linked,
 // which would be problematic if the definitions in these modules are not synchronized.
 public class BpfNetMapsUtils {
+    // Bitmaps for calculating whether a given uid is blocked by firewall chains.
+    private static final long sMaskDropIfSet; // OR of the match bits of every deny chain
+    private static final long sMaskDropIfUnset; // OR of the match bits of every allow chain
+
+    static {
+        long maskDropIfSet = 0L;
+        long maskDropIfUnset = 0L;
+
+        for (int chain : BpfNetMapsConstants.ALLOW_CHAINS) {
+            final long match = getMatchByFirewallChain(chain);
+            maskDropIfUnset |= match; // enabled allow chain blocks uids lacking this bit
+        }
+        for (int chain : BpfNetMapsConstants.DENY_CHAINS) {
+            final long match = getMatchByFirewallChain(chain);
+            maskDropIfSet |= match; // enabled deny chain blocks uids having this bit
+        }
+        sMaskDropIfSet = maskDropIfSet;
+        sMaskDropIfUnset = maskDropIfUnset;
+    }
+
     // Prevent this class from being accidental instantiated.
     private BpfNetMapsUtils() {}
 
@@ -133,4 +167,122 @@
             throw new UnsupportedOperationException(msg);
         }
     }
+
+    /**
+     * Get whether the specified firewall chain is enabled in the uid rules configuration.
+     *
+     * @param configurationMap map from which the uid rules configuration entry is read
+     * @param chain firewall chain whose match bit is tested
+     * @return {@code true} if chain is enabled, {@code false} if chain is not enabled.
+     * @throws UnsupportedOperationException if called on pre-T devices.
+     * @throws ServiceSpecificException in case of failure, with an error code indicating the
+     *                                  cause of the failure.
+     */
+    public static boolean isChainEnabled(
+            final IBpfMap<S32, U32> configurationMap, final int chain) {
+        throwIfPreT("isChainEnabled is not available on pre-T devices");
+
+        final long match = getMatchByFirewallChain(chain); // match bit(s) of this chain
+        try {
+            final U32 config = configurationMap.getValue(UID_RULES_CONFIGURATION_KEY);
+            return (config.val & match) != 0; // enabled iff the chain's bit is set
+        } catch (ErrnoException e) { // rethrown with the same errno
+            throw new ServiceSpecificException(e.errno,
+                    "Unable to get firewall chain status: " + Os.strerror(e.errno));
+        }
+    }
+
+    /**
+     * Get the firewall rule that the specified firewall chain applies to the specified uid.
+     *
+     * @param uidOwnerMap map from which the uid's match bits are read.
+     * @param chain firewall chain whose match bit is tested.
+     * @param uid uid used as the lookup key in uidOwnerMap.
+     * @return either FIREWALL_RULE_ALLOW or FIREWALL_RULE_DENY
+     * @throws UnsupportedOperationException if called on pre-T devices.
+     * @throws ServiceSpecificException      in case of failure, with an error code indicating the
+     *                                       cause of the failure.
+     */
+    public static int getUidRule(final IBpfMap<S32, UidOwnerValue> uidOwnerMap,
+            final int chain, final int uid) {
+        throwIfPreT("getUidRule is not available on pre-T devices");
+
+        final long match = getMatchByFirewallChain(chain); // match bit(s) of this chain
+        final boolean isAllowList = isFirewallAllowList(chain); // true: set bit means ALLOW
+        try {
+            final UidOwnerValue uidMatch = uidOwnerMap.getValue(new S32(uid)); // may be null
+            final boolean isMatchEnabled = uidMatch != null && (uidMatch.rule & match) != 0;
+            return isMatchEnabled == isAllowList ? FIREWALL_RULE_ALLOW : FIREWALL_RULE_DENY;
+        } catch (ErrnoException e) { // rethrown with the same errno
+            throw new ServiceSpecificException(e.errno,
+                    "Unable to get uid rule status: " + Os.strerror(e.errno));
+        }
+    }
+
+    /**
+     * Return whether the network is blocked by firewall chains for the given uid.
+     *
+     * Note that {@link #getDataSaverEnabled(IBpfMap)} has a latency before V.
+     * @param uid The target uid.
+     * @param isNetworkMetered Whether the target network is metered.
+     * @param configurationMap Map from which the uid rules configuration is read.
+     * @param uidOwnerMap Map from which the uid's match bits are read.
+     * @param dataSaverEnabledMap Map from which the data saver status is read.
+     * @return True if the network is blocked. Otherwise, false.
+     * @throws ServiceSpecificException if the read fails.
+     * @hide
+     */
+    public static boolean isUidNetworkingBlocked(final int uid, boolean isNetworkMetered,
+            IBpfMap<S32, U32> configurationMap,
+            IBpfMap<S32, UidOwnerValue> uidOwnerMap,
+            IBpfMap<S32, U8> dataSaverEnabledMap
+    ) {
+        throwIfPreT("isUidNetworkingBlocked is not available on pre-T devices");
+
+        final long uidRuleConfig;
+        final long uidMatch;
+        try {
+            uidRuleConfig = configurationMap.getValue(UID_RULES_CONFIGURATION_KEY).val;
+            final UidOwnerValue value = uidOwnerMap.getValue(new Struct.S32(uid));
+            uidMatch = (value != null) ? value.rule : 0L; // null value: no match bits
+        } catch (ErrnoException e) {
+            throw new ServiceSpecificException(e.errno,
+                    "Unable to get firewall chain status: " + Os.strerror(e.errno));
+        }
+
+        final boolean blockedByAllowChains = 0 != (uidRuleConfig & ~uidMatch & sMaskDropIfUnset);
+        final boolean blockedByDenyChains = 0 != (uidRuleConfig & uidMatch & sMaskDropIfSet);
+        if (blockedByAllowChains || blockedByDenyChains) {
+            return true;
+        }
+
+        if (!isNetworkMetered) return false; // remaining checks only apply to metered networks
+        if ((uidMatch & PENALTY_BOX_MATCH) != 0) return true;
+        if ((uidMatch & HAPPY_BOX_MATCH) != 0) return false;
+        return getDataSaverEnabled(dataSaverEnabledMap); // otherwise data saver decides
+    }
+
+    /**
+     * Get whether Data Saver is enabled, as recorded in the given map.
+     *
+     * Note that before V, the data saver status in bpf is written by ConnectivityService
+     * when receiving {@link ConnectivityManager#ACTION_RESTRICT_BACKGROUND_CHANGED}. Thus,
+     * the status is not synchronized. On V+, the data saver status is set by platform code
+     * when enabling/disabling data saver, which is synchronized.
+     *
+     * @param dataSaverEnabledMap map from which the data saver status is read.
+     * @return {@code true} if Data Saver is enabled, {@code false} otherwise.
+     * @throws ServiceSpecificException in case of failure, with an error code indicating the
+     *                                  cause of the failure.
+     */
+    public static boolean getDataSaverEnabled(IBpfMap<S32, U8> dataSaverEnabledMap) {
+        throwIfPreT("getDataSaverEnabled is not available on pre-T devices");
+
+        try {
+            return dataSaverEnabledMap.getValue(DATA_SAVER_ENABLED_KEY).val == DATA_SAVER_ENABLED;
+        } catch (ErrnoException e) {
+            throw new ServiceSpecificException(e.errno, "Unable to get data saver: "
+                    + Os.strerror(e.errno));
+        }
+    }
 }
diff --git a/framework/src/android/net/NetworkStackBpfNetMaps.java b/framework/src/android/net/NetworkStackBpfNetMaps.java
index 346c997..b7c4e34 100644
--- a/framework/src/android/net/NetworkStackBpfNetMaps.java
+++ b/framework/src/android/net/NetworkStackBpfNetMaps.java
@@ -17,25 +17,14 @@
 package android.net;
 
 import static android.net.BpfNetMapsConstants.CONFIGURATION_MAP_PATH;
-import static android.net.BpfNetMapsConstants.DATA_SAVER_ENABLED;
-import static android.net.BpfNetMapsConstants.DATA_SAVER_ENABLED_KEY;
 import static android.net.BpfNetMapsConstants.DATA_SAVER_ENABLED_MAP_PATH;
-import static android.net.BpfNetMapsConstants.HAPPY_BOX_MATCH;
-import static android.net.BpfNetMapsConstants.PENALTY_BOX_MATCH;
 import static android.net.BpfNetMapsConstants.UID_OWNER_MAP_PATH;
-import static android.net.BpfNetMapsConstants.UID_RULES_CONFIGURATION_KEY;
-import static android.net.BpfNetMapsUtils.getMatchByFirewallChain;
-import static android.net.BpfNetMapsUtils.isFirewallAllowList;
-import static android.net.BpfNetMapsUtils.throwIfPreT;
-import static android.net.ConnectivityManager.FIREWALL_RULE_ALLOW;
-import static android.net.ConnectivityManager.FIREWALL_RULE_DENY;
 
 import android.annotation.NonNull;
 import android.annotation.RequiresApi;
 import android.os.Build;
 import android.os.ServiceSpecificException;
 import android.system.ErrnoException;
-import android.system.Os;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.modules.utils.build.SdkLevel;
@@ -51,7 +40,8 @@
  * {@link com.android.server.BpfNetMaps}
  * @hide
  */
-@RequiresApi(Build.VERSION_CODES.TIRAMISU)  // BPF maps were only mainlined in T
+// NetworkStack cannot use this before U due to b/326143935
+@RequiresApi(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
 public class NetworkStackBpfNetMaps {
     private static final String TAG = NetworkStackBpfNetMaps.class.getSimpleName();
 
@@ -67,26 +57,6 @@
     private final IBpfMap<S32, U8> mDataSaverEnabledMap;
     private final Dependencies mDeps;
 
-    // Bitmaps for calculating whether a given uid is blocked by firewall chains.
-    private static final long sMaskDropIfSet;
-    private static final long sMaskDropIfUnset;
-
-    static {
-        long maskDropIfSet = 0L;
-        long maskDropIfUnset = 0L;
-
-        for (int chain : BpfNetMapsConstants.ALLOW_CHAINS) {
-            final long match = getMatchByFirewallChain(chain);
-            maskDropIfUnset |= match;
-        }
-        for (int chain : BpfNetMapsConstants.DENY_CHAINS) {
-            final long match = getMatchByFirewallChain(chain);
-            maskDropIfSet |= match;
-        }
-        sMaskDropIfSet = maskDropIfSet;
-        sMaskDropIfUnset = maskDropIfUnset;
-    }
-
     private static class SingletonHolder {
         static final NetworkStackBpfNetMaps sInstance = new NetworkStackBpfNetMaps();
     }
@@ -162,7 +132,7 @@
      *                                  cause of the failure.
      */
     public boolean isChainEnabled(final int chain) {
-        return isChainEnabled(mConfigurationMap, chain);
+        return BpfNetMapsUtils.isChainEnabled(mConfigurationMap, chain);
     }
 
     /**
@@ -177,58 +147,7 @@
      *                                  cause of the failure.
      */
     public int getUidRule(final int chain, final int uid) {
-        return getUidRule(mUidOwnerMap, chain, uid);
-    }
-
-    /**
-     * Get the specified firewall chain's status.
-     *
-     * @param configurationMap target configurationMap
-     * @param chain target chain
-     * @return {@code true} if chain is enabled, {@code false} if chain is not enabled.
-     * @throws UnsupportedOperationException if called on pre-T devices.
-     * @throws ServiceSpecificException in case of failure, with an error code indicating the
-     *                                  cause of the failure.
-     */
-    public static boolean isChainEnabled(
-            final IBpfMap<S32, U32> configurationMap, final int chain) {
-        throwIfPreT("isChainEnabled is not available on pre-T devices");
-
-        final long match = getMatchByFirewallChain(chain);
-        try {
-            final U32 config = configurationMap.getValue(UID_RULES_CONFIGURATION_KEY);
-            return (config.val & match) != 0;
-        } catch (ErrnoException e) {
-            throw new ServiceSpecificException(e.errno,
-                    "Unable to get firewall chain status: " + Os.strerror(e.errno));
-        }
-    }
-
-    /**
-     * Get firewall rule of specified firewall chain on specified uid.
-     *
-     * @param uidOwnerMap target uidOwnerMap.
-     * @param chain target chain.
-     * @param uid target uid.
-     * @return either FIREWALL_RULE_ALLOW or FIREWALL_RULE_DENY
-     * @throws UnsupportedOperationException if called on pre-T devices.
-     * @throws ServiceSpecificException      in case of failure, with an error code indicating the
-     *                                       cause of the failure.
-     */
-    public static int getUidRule(final IBpfMap<S32, UidOwnerValue> uidOwnerMap,
-            final int chain, final int uid) {
-        throwIfPreT("getUidRule is not available on pre-T devices");
-
-        final long match = getMatchByFirewallChain(chain);
-        final boolean isAllowList = isFirewallAllowList(chain);
-        try {
-            final UidOwnerValue uidMatch = uidOwnerMap.getValue(new S32(uid));
-            final boolean isMatchEnabled = uidMatch != null && (uidMatch.rule & match) != 0;
-            return isMatchEnabled == isAllowList ? FIREWALL_RULE_ALLOW : FIREWALL_RULE_DENY;
-        } catch (ErrnoException e) {
-            throw new ServiceSpecificException(e.errno,
-                    "Unable to get uid rule status: " + Os.strerror(e.errno));
-        }
+        return BpfNetMapsUtils.getUidRule(mUidOwnerMap, chain, uid);
     }
 
     /**
@@ -245,29 +164,8 @@
      * @hide
      */
     public boolean isUidNetworkingBlocked(final int uid, boolean isNetworkMetered) {
-        throwIfPreT("isUidBlockedByFirewallChains is not available on pre-T devices");
-
-        final long uidRuleConfig;
-        final long uidMatch;
-        try {
-            uidRuleConfig = mConfigurationMap.getValue(UID_RULES_CONFIGURATION_KEY).val;
-            final UidOwnerValue value = mUidOwnerMap.getValue(new S32(uid));
-            uidMatch = (value != null) ? value.rule : 0L;
-        } catch (ErrnoException e) {
-            throw new ServiceSpecificException(e.errno,
-                    "Unable to get firewall chain status: " + Os.strerror(e.errno));
-        }
-
-        final boolean blockedByAllowChains = 0 != (uidRuleConfig & ~uidMatch & sMaskDropIfUnset);
-        final boolean blockedByDenyChains = 0 != (uidRuleConfig & uidMatch & sMaskDropIfSet);
-        if (blockedByAllowChains || blockedByDenyChains) {
-            return true;
-        }
-
-        if (!isNetworkMetered) return false;
-        if ((uidMatch & PENALTY_BOX_MATCH) != 0) return true;
-        if ((uidMatch & HAPPY_BOX_MATCH) != 0) return false;
-        return getDataSaverEnabled();
+        return BpfNetMapsUtils.isUidNetworkingBlocked(uid, isNetworkMetered,
+                mConfigurationMap, mUidOwnerMap, mDataSaverEnabledMap);
     }
 
     /**
@@ -284,13 +182,6 @@
      *                                  cause of the failure.
      */
     public boolean getDataSaverEnabled() {
-        throwIfPreT("getDataSaverEnabled is not available on pre-T devices");
-
-        try {
-            return mDataSaverEnabledMap.getValue(DATA_SAVER_ENABLED_KEY).val == DATA_SAVER_ENABLED;
-        } catch (ErrnoException e) {
-            throw new ServiceSpecificException(e.errno, "Unable to get data saver: "
-                    + Os.strerror(e.errno));
-        }
+        return BpfNetMapsUtils.getDataSaverEnabled(mDataSaverEnabledMap);
     }
 }
diff --git a/netbpfload/netbpfload.mainline.rc b/netbpfload/netbpfload.mainline.rc
index d7202f7..d38a503 100644
--- a/netbpfload/netbpfload.mainline.rc
+++ b/netbpfload/netbpfload.mainline.rc
@@ -10,6 +10,7 @@
     capabilities CHOWN SYS_ADMIN NET_ADMIN
     group system root graphics network_stack net_admin net_bw_acct net_bw_stats net_raw
     user system
+    file /dev/kmsg w
     rlimit memlock 1073741824 1073741824
     oneshot
     reboot_on_failure reboot,bpfloader-failed
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index cfb1a33..aca386f 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -1851,6 +1851,8 @@
                         mContext, MdnsFeatureFlags.NSD_UNICAST_REPLY_ENABLED))
                 .setIsAggressiveQueryModeEnabled(mDeps.isFeatureEnabled(
                         mContext, MdnsFeatureFlags.NSD_AGGRESSIVE_QUERY_MODE))
+                .setIsQueryWithKnownAnswerEnabled(mDeps.isFeatureEnabled(
+                        mContext, MdnsFeatureFlags.NSD_QUERY_WITH_KNOWN_ANSWER))
                 .setOverrideProvider(flag -> mDeps.isFeatureEnabled(
                         mContext, FORCE_ENABLE_FLAG_FOR_TEST_PREFIX + flag))
                 .build();
diff --git a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
index c4d3338..e61555a 100644
--- a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
+++ b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
@@ -23,6 +23,7 @@
 import android.text.TextUtils;
 import android.util.Pair;
 
+import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.SharedLog;
 import com.android.server.connectivity.mdns.util.MdnsUtils;
 
@@ -63,8 +64,6 @@
     @NonNull
     private final WeakReference<MdnsSocketClientBase> weakRequestSender;
     @NonNull
-    private final MdnsPacketWriter packetWriter;
-    @NonNull
     private final String[] serviceTypeLabels;
     @NonNull
     private final List<String> subtypes;
@@ -79,11 +78,15 @@
     private final MdnsUtils.Clock clock;
     @NonNull
     private final SharedLog sharedLog;
+    @NonNull
+    private final MdnsServiceTypeClient.Dependencies dependencies;
     private final boolean onlyUseIpv6OnIpv6OnlyNetworks;
+    private final byte[] packetCreationBuffer = new byte[1500]; // TODO: use interface MTU
+    @NonNull
+    private final List<MdnsResponse> existingServices;
 
     EnqueueMdnsQueryCallable(
             @NonNull MdnsSocketClientBase requestSender,
-            @NonNull MdnsPacketWriter packetWriter,
             @NonNull String serviceType,
             @NonNull Collection<String> subtypes,
             boolean expectUnicastResponse,
@@ -93,9 +96,10 @@
             boolean sendDiscoveryQueries,
             @NonNull Collection<MdnsResponse> servicesToResolve,
             @NonNull MdnsUtils.Clock clock,
-            @NonNull SharedLog sharedLog) {
+            @NonNull SharedLog sharedLog,
+            @NonNull MdnsServiceTypeClient.Dependencies dependencies,
+            @NonNull Collection<MdnsResponse> existingServices) {
         weakRequestSender = new WeakReference<>(requestSender);
-        this.packetWriter = packetWriter;
         serviceTypeLabels = TextUtils.split(serviceType, "\\.");
         this.subtypes = new ArrayList<>(subtypes);
         this.expectUnicastResponse = expectUnicastResponse;
@@ -106,6 +110,8 @@
         this.servicesToResolve = new ArrayList<>(servicesToResolve);
         this.clock = clock;
         this.sharedLog = sharedLog;
+        this.dependencies = dependencies;
+        this.existingServices = new ArrayList<>(existingServices);
     }
 
     /**
@@ -176,29 +182,52 @@
                 return Pair.create(INVALID_TRANSACTION_ID, new ArrayList<>());
             }
 
+            // Put the existing PTR records into the known-answer section.
+            final List<MdnsRecord> knownAnswers = new ArrayList<>();
+            if (sendDiscoveryQueries) {
+                for (MdnsResponse existingService : existingServices) {
+                    for (MdnsPointerRecord ptrRecord : existingService.getPointerRecords()) {
+                        // Ignore any PTR records that don't match the current query.
+                        if (!CollectionUtils.any(questions,
+                                q -> q instanceof MdnsPointerRecord
+                                        && MdnsUtils.equalsDnsLabelIgnoreDnsCase(
+                                                q.getName(), ptrRecord.getName()))) {
+                            continue;
+                        }
+
+                        knownAnswers.add(new MdnsPointerRecord(
+                                ptrRecord.getName(),
+                                ptrRecord.getReceiptTime(),
+                                ptrRecord.getCacheFlush(),
+                                ptrRecord.getRemainingTTL(now), // Put the remaining ttl.
+                                ptrRecord.getPointer()));
+                    }
+                }
+            }
+
             final MdnsPacket queryPacket = new MdnsPacket(
                     transactionId,
                     MdnsConstants.FLAGS_QUERY,
                     questions,
-                    Collections.emptyList(), /* answers */
+                    knownAnswers,
                     Collections.emptyList(), /* authorityRecords */
                     Collections.emptyList() /* additionalRecords */);
-            MdnsUtils.writeMdnsPacket(packetWriter, queryPacket);
-            sendPacketToIpv4AndIpv6(requestSender, MdnsConstants.MDNS_PORT);
+            sendPacketToIpv4AndIpv6(requestSender, MdnsConstants.MDNS_PORT, queryPacket);
             for (Integer emulatorPort : castShellEmulatorMdnsPorts) {
-                sendPacketToIpv4AndIpv6(requestSender, emulatorPort);
+                sendPacketToIpv4AndIpv6(requestSender, emulatorPort, queryPacket);
             }
             return Pair.create(transactionId, subtypes);
-        } catch (IOException e) {
+        } catch (Exception e) {
             sharedLog.e(String.format("Failed to create mDNS packet for subtype: %s.",
                     TextUtils.join(",", subtypes)), e);
             return Pair.create(INVALID_TRANSACTION_ID, new ArrayList<>());
         }
     }
 
-    private void sendPacket(MdnsSocketClientBase requestSender, InetSocketAddress address)
-            throws IOException {
-        DatagramPacket packet = packetWriter.getPacket(address);
+    private void sendPacket(MdnsSocketClientBase requestSender, InetSocketAddress address,
+            MdnsPacket mdnsPacket) throws IOException {
+        final DatagramPacket packet = dependencies.getDatagramPacketFromMdnsPacket(
+                packetCreationBuffer, mdnsPacket, address);
         if (expectUnicastResponse) {
             // MdnsMultinetworkSocketClient is only available on T+
             if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU
@@ -222,16 +251,17 @@
         }
     }
 
-    private void sendPacketToIpv4AndIpv6(MdnsSocketClientBase requestSender, int port) {
+    private void sendPacketToIpv4AndIpv6(MdnsSocketClientBase requestSender, int port,
+            MdnsPacket mdnsPacket) {
         try {
             sendPacket(requestSender,
-                    new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), port));
+                    new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), port), mdnsPacket);
         } catch (IOException e) {
             sharedLog.e("Can't send packet to IPv4", e);
         }
         try {
             sendPacket(requestSender,
-                    new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), port));
+                    new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), port), mdnsPacket);
         } catch (IOException e) {
             sharedLog.e("Can't send packet to IPv6", e);
         }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index 21b7069..7b0c738 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -362,7 +362,7 @@
         return new MdnsServiceTypeClient(
                 serviceType, socketClient,
                 executorProvider.newServiceTypeClientSchedulerExecutor(), socketKey,
-                sharedLog.forSubComponent(tag), looper, serviceCache);
+                sharedLog.forSubComponent(tag), looper, serviceCache, mdnsFeatureFlags);
     }
 
     /**
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
index 56202fd..f4a08ba 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
@@ -62,6 +62,11 @@
      */
     public static final String NSD_AGGRESSIVE_QUERY_MODE = "nsd_aggressive_query_mode";
 
+    /**
+     * A feature flag to control whether the query with known-answer should be enabled.
+     */
+    public static final String NSD_QUERY_WITH_KNOWN_ANSWER = "nsd_query_with_known_answer";
+
     // Flag for offload feature
     public final boolean mIsMdnsOffloadFeatureEnabled;
 
@@ -83,6 +88,9 @@
     // Flag for aggressive query mode
     public final boolean mIsAggressiveQueryModeEnabled;
 
+    // Flag for query with known-answer
+    public final boolean mIsQueryWithKnownAnswerEnabled;
+
     @Nullable
     private final FlagOverrideProvider mOverrideProvider;
 
@@ -126,6 +134,14 @@
     }
 
     /**
+     * Indicates whether {@link #NSD_QUERY_WITH_KNOWN_ANSWER} is enabled, including for testing.
+     */
+    public boolean isQueryWithKnownAnswerEnabled() {
+        return mIsQueryWithKnownAnswerEnabled // flag value given to the constructor
+                || isForceEnabledForTest(NSD_QUERY_WITH_KNOWN_ANSWER); // test override
+    }
+
+    /**
      * The constructor for {@link MdnsFeatureFlags}.
      */
     public MdnsFeatureFlags(boolean isOffloadFeatureEnabled,
@@ -135,6 +151,7 @@
             boolean isKnownAnswerSuppressionEnabled,
             boolean isUnicastReplyEnabled,
             boolean isAggressiveQueryModeEnabled,
+            boolean isQueryWithKnownAnswerEnabled,
             @Nullable FlagOverrideProvider overrideProvider) {
         mIsMdnsOffloadFeatureEnabled = isOffloadFeatureEnabled;
         mIncludeInetAddressRecordsInProbing = includeInetAddressRecordsInProbing;
@@ -143,6 +160,7 @@
         mIsKnownAnswerSuppressionEnabled = isKnownAnswerSuppressionEnabled;
         mIsUnicastReplyEnabled = isUnicastReplyEnabled;
         mIsAggressiveQueryModeEnabled = isAggressiveQueryModeEnabled;
+        mIsQueryWithKnownAnswerEnabled = isQueryWithKnownAnswerEnabled;
         mOverrideProvider = overrideProvider;
     }
 
@@ -162,6 +180,7 @@
         private boolean mIsKnownAnswerSuppressionEnabled;
         private boolean mIsUnicastReplyEnabled;
         private boolean mIsAggressiveQueryModeEnabled;
+        private boolean mIsQueryWithKnownAnswerEnabled;
         private FlagOverrideProvider mOverrideProvider;
 
         /**
@@ -175,6 +194,7 @@
             mIsKnownAnswerSuppressionEnabled = false;
             mIsUnicastReplyEnabled = true;
             mIsAggressiveQueryModeEnabled = false;
+            mIsQueryWithKnownAnswerEnabled = false;
             mOverrideProvider = null;
         }
 
@@ -261,6 +281,16 @@
         }
 
         /**
+         * Set whether the query with known-answer feature is enabled in the built flags.
+         *
+         * @see #NSD_QUERY_WITH_KNOWN_ANSWER
+         */
+        public Builder setIsQueryWithKnownAnswerEnabled(boolean isQueryWithKnownAnswerEnabled) {
+            mIsQueryWithKnownAnswerEnabled = isQueryWithKnownAnswerEnabled;
+            return this; // same builder, so setter calls can be chained
+        }
+
+        /**
          * Builds a {@link MdnsFeatureFlags} with the arguments supplied to this builder.
          */
         public MdnsFeatureFlags build() {
@@ -271,6 +301,7 @@
                     mIsKnownAnswerSuppressionEnabled,
                     mIsUnicastReplyEnabled,
                     mIsAggressiveQueryModeEnabled,
+                    mIsQueryWithKnownAnswerEnabled,
                     mOverrideProvider);
         }
     }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
index c1c7d5f..61eb766 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
@@ -22,12 +22,10 @@
 import android.annotation.Nullable;
 import android.annotation.RequiresApi;
 import android.net.LinkAddress;
-import android.net.nsd.NsdManager;
 import android.net.nsd.NsdServiceInfo;
 import android.os.Build;
 import android.os.Handler;
 import android.os.Looper;
-import android.util.ArraySet;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.net.module.util.HexDump;
@@ -284,6 +282,7 @@
         if (!mRecordRepository.hasActiveService(id)) return;
         mProber.stop(id);
         mAnnouncer.stop(id);
+        final String hostname = mRecordRepository.getHostnameForServiceId(id);
         final MdnsAnnouncer.ExitAnnouncementInfo exitInfo = mRecordRepository.exitService(id);
         if (exitInfo != null) {
             // This effectively schedules onAllServicesRemoved(), as it is to be called when the
@@ -303,6 +302,24 @@
                 }
             });
         }
+        // Re-probe/re-announce the services which have the same custom hostname. These services
+        // were probed/announced using host addresses which were just removed so they should be
+        // re-probed/re-announced without those addresses.
+        if (hostname != null) {
+            final List<MdnsProber.ProbingInfo> probingInfos =
+                    mRecordRepository.restartProbingForHostname(hostname);
+            for (MdnsProber.ProbingInfo probingInfo : probingInfos) {
+                mProber.stop(probingInfo.getServiceId());
+                mProber.startProbing(probingInfo);
+            }
+            final List<MdnsAnnouncer.AnnouncementInfo> announcementInfos =
+                    mRecordRepository.restartAnnouncingForHostname(hostname);
+            for (MdnsAnnouncer.AnnouncementInfo announcementInfo : announcementInfos) {
+                mAnnouncer.stop(announcementInfo.getServiceId());
+                mAnnouncer.startSending(
+                        announcementInfo.getServiceId(), announcementInfo, 0 /* initialDelayMs */);
+            }
+        }
     }
 
     /**
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
index ac64c3a..073e465 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
@@ -925,22 +925,79 @@
         }
     }
 
+    @Nullable
+    public String getHostnameForServiceId(int id) {
+        ServiceRegistration registration = mServices.get(id);
+        if (registration == null) {
+            return null;
+        }
+        return registration.serviceInfo.getHostname();
+    }
+
+    /**
+     * Restart probing the services which are being probed and using the given custom hostname.
+     *
+     * @return The list of {@link MdnsProber.ProbingInfo} to be used by advertiser.
+     */
+    public List<MdnsProber.ProbingInfo> restartProbingForHostname(@NonNull String hostname) {
+        final ArrayList<MdnsProber.ProbingInfo> probingInfos = new ArrayList<>();
+        forEachActiveServiceRegistrationWithHostname(
+                hostname,
+                (id, registration) -> {
+                    if (!registration.isProbing) {
+                        return;
+                    }
+                    probingInfos.add(makeProbingInfo(id, registration));
+                });
+        return probingInfos;
+    }
+
+    /**
+     * Restart announcing the non-probing services which are using the given custom hostname.
+     *
+     * @return The list of {@link MdnsAnnouncer.AnnouncementInfo} to be used by advertiser.
+     */
+    public List<MdnsAnnouncer.AnnouncementInfo> restartAnnouncingForHostname(
+            @NonNull String hostname) {
+        final ArrayList<MdnsAnnouncer.AnnouncementInfo> announcementInfos = new ArrayList<>();
+        forEachActiveServiceRegistrationWithHostname(
+                hostname,
+                (id, registration) -> {
+                    if (registration.isProbing) {
+                        return;
+                    }
+                    announcementInfos.add(makeAnnouncementInfo(id, registration));
+                });
+        return announcementInfos;
+    }
+
     /**
      * Called to indicate that probing succeeded for a service.
+     *
      * @param probeSuccessInfo The successful probing info.
      * @return The {@link MdnsAnnouncer.AnnouncementInfo} to send, now that probing has succeeded.
      */
     public MdnsAnnouncer.AnnouncementInfo onProbingSucceeded(
-            MdnsProber.ProbingInfo probeSuccessInfo)
-            throws IOException {
-
-        int serviceId = probeSuccessInfo.getServiceId();
+            MdnsProber.ProbingInfo probeSuccessInfo) throws IOException {
+        final int serviceId = probeSuccessInfo.getServiceId();
         final ServiceRegistration registration = mServices.get(serviceId);
         if (registration == null) {
             throw new IOException("Service is not registered: " + serviceId);
         }
         registration.setProbing(false);
 
+        return makeAnnouncementInfo(serviceId, registration);
+    }
+
+    /**
+     * Make the announcement info of the given service ID.
+     *
+     * @param serviceId The service ID.
+     * @param registration The service registration.
+     * @return The {@link MdnsAnnouncer.AnnouncementInfo} of the given service ID.
+     */
+    private MdnsAnnouncer.AnnouncementInfo makeAnnouncementInfo(
+            int serviceId, ServiceRegistration registration) {
         final Set<MdnsRecord> answersSet = new LinkedHashSet<>();
         final ArrayList<MdnsRecord> additionalAnswers = new ArrayList<>();
 
@@ -972,8 +1029,8 @@
         addNsecRecordsForUniqueNames(additionalAnswers,
                 mGeneralRecords.iterator(), registration.allRecords.iterator());
 
-        return new MdnsAnnouncer.AnnouncementInfo(
-                probeSuccessInfo.getServiceId(), new ArrayList<>(answersSet), additionalAnswers);
+        return new MdnsAnnouncer.AnnouncementInfo(serviceId,
+                new ArrayList<>(answersSet), additionalAnswers);
     }
 
     /**
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index 8f41b94..bfcd0b4 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -35,9 +35,12 @@
 import com.android.net.module.util.SharedLog;
 import com.android.server.connectivity.mdns.util.MdnsUtils;
 
+import java.io.IOException;
 import java.io.PrintWriter;
+import java.net.DatagramPacket;
 import java.net.Inet4Address;
 import java.net.Inet6Address;
+import java.net.InetSocketAddress;
 import java.time.Instant;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -55,7 +58,6 @@
 public class MdnsServiceTypeClient {
 
     private static final String TAG = MdnsServiceTypeClient.class.getSimpleName();
-    private static final int DEFAULT_MTU = 1500;
     @VisibleForTesting
     static final int EVENT_START_QUERYTASK = 1;
     static final int EVENT_QUERY_RESULT = 2;
@@ -84,6 +86,7 @@
                     notifyRemovedServiceToListeners(previousResponse, "Service record expired");
                 }
             };
+    @NonNull private final MdnsFeatureFlags featureFlags;
     private final ArrayMap<MdnsServiceBrowserListener, ListenerInfo> listeners =
             new ArrayMap<>();
     private final boolean removeServiceAfterTtlExpires =
@@ -142,7 +145,8 @@
                     // before sending the query, it needs to be called just before sending it.
                     final List<MdnsResponse> servicesToResolve = makeResponsesForResolve(socketKey);
                     final QueryTask queryTask = new QueryTask(taskArgs, servicesToResolve,
-                            getAllDiscoverySubtypes(), needSendDiscoveryQueries(listeners));
+                            getAllDiscoverySubtypes(), needSendDiscoveryQueries(listeners),
+                            getExistingServices());
                     executor.submit(queryTask);
                     break;
                 }
@@ -221,6 +225,16 @@
         public void sendMessage(@NonNull Handler handler, @NonNull Message message) {
             handler.sendMessage(message);
         }
+
+        /**
+         * Generate a DatagramPacket from given MdnsPacket and InetSocketAddress.
+         */
+        public DatagramPacket getDatagramPacketFromMdnsPacket(@NonNull byte[] packetCreationBuffer,
+                @NonNull MdnsPacket packet, @NonNull InetSocketAddress address) throws IOException {
+            final byte[] queryBuffer =
+                    MdnsUtils.createRawDnsPacket(packetCreationBuffer, packet);
+            return new DatagramPacket(queryBuffer, 0, queryBuffer.length, address);
+        }
     }
 
     /**
@@ -236,9 +250,10 @@
             @NonNull SocketKey socketKey,
             @NonNull SharedLog sharedLog,
             @NonNull Looper looper,
-            @NonNull MdnsServiceCache serviceCache) {
+            @NonNull MdnsServiceCache serviceCache,
+            @NonNull MdnsFeatureFlags featureFlags) {
         this(serviceType, socketClient, executor, new Clock(), socketKey, sharedLog, looper,
-                new Dependencies(), serviceCache);
+                new Dependencies(), serviceCache, featureFlags);
     }
 
     @VisibleForTesting
@@ -251,7 +266,8 @@
             @NonNull SharedLog sharedLog,
             @NonNull Looper looper,
             @NonNull Dependencies dependencies,
-            @NonNull MdnsServiceCache serviceCache) {
+            @NonNull MdnsServiceCache serviceCache,
+            @NonNull MdnsFeatureFlags featureFlags) {
         this.serviceType = serviceType;
         this.socketClient = socketClient;
         this.executor = executor;
@@ -265,6 +281,7 @@
         this.serviceCache = serviceCache;
         this.mdnsQueryScheduler = new MdnsQueryScheduler();
         this.cacheKey = new MdnsServiceCache.CacheKey(serviceType, socketKey);
+        this.featureFlags = featureFlags;
     }
 
     /**
@@ -327,6 +344,11 @@
                 now.plusMillis(response.getMinRemainingTtl(now.toEpochMilli())));
     }
 
+    private List<MdnsResponse> getExistingServices() {
+        return featureFlags.isQueryWithKnownAnswerEnabled()
+                ? serviceCache.getCachedServices(cacheKey) : Collections.emptyList();
+    }
+
     /**
      * Registers {@code listener} for receiving discovery event of mDNS service instances, and
      * starts
@@ -391,7 +413,8 @@
             final QueryTask queryTask = new QueryTask(
                     mdnsQueryScheduler.scheduleFirstRun(taskConfig, now,
                             minRemainingTtl, currentSessionId), servicesToResolve,
-                    getAllDiscoverySubtypes(), needSendDiscoveryQueries(listeners));
+                    getAllDiscoverySubtypes(), needSendDiscoveryQueries(listeners),
+                    getExistingServices());
             executor.submit(queryTask);
         }
 
@@ -617,11 +640,6 @@
         return searchOptions != null && searchOptions.removeExpiredService();
     }
 
-    @VisibleForTesting
-    MdnsPacketWriter createMdnsPacketWriter() {
-        return new MdnsPacketWriter(DEFAULT_MTU);
-    }
-
     private List<MdnsResponse> makeResponsesForResolve(@NonNull SocketKey socketKey) {
         final List<MdnsResponse> resolveResponses = new ArrayList<>();
         for (int i = 0; i < listeners.size(); i++) {
@@ -694,14 +712,16 @@
         private final List<MdnsResponse> servicesToResolve = new ArrayList<>();
         private final List<String> subtypes = new ArrayList<>();
         private final boolean sendDiscoveryQueries;
+        private final List<MdnsResponse> existingServices = new ArrayList<>();
         QueryTask(@NonNull MdnsQueryScheduler.ScheduledQueryTaskArgs taskArgs,
                 @NonNull Collection<MdnsResponse> servicesToResolve,
-                @NonNull Collection<String> subtypes,
-                boolean sendDiscoveryQueries) {
+                @NonNull Collection<String> subtypes, boolean sendDiscoveryQueries,
+                @NonNull Collection<MdnsResponse> existingServices) {
             this.taskArgs = taskArgs;
             this.servicesToResolve.addAll(servicesToResolve);
             this.subtypes.addAll(subtypes);
             this.sendDiscoveryQueries = sendDiscoveryQueries;
+            this.existingServices.addAll(existingServices);
         }
 
         @Override
@@ -711,7 +731,6 @@
                 result =
                         new EnqueueMdnsQueryCallable(
                                 socketClient,
-                                createMdnsPacketWriter(),
                                 serviceType,
                                 subtypes,
                                 taskArgs.config.expectUnicastResponse,
@@ -721,7 +740,9 @@
                                 sendDiscoveryQueries,
                                 servicesToResolve,
                                 clock,
-                                sharedLog)
+                                sharedLog,
+                                dependencies,
+                                existingServices)
                                 .call();
             } catch (RuntimeException e) {
                 sharedLog.e(String.format("Failed to run EnqueueMdnsQueryCallable for subtype: %s",
diff --git a/service/src/com/android/server/BpfNetMaps.java b/service/src/com/android/server/BpfNetMaps.java
index 487f25c..fc6d8c4 100644
--- a/service/src/com/android/server/BpfNetMaps.java
+++ b/service/src/com/android/server/BpfNetMaps.java
@@ -49,8 +49,8 @@
 
 import android.app.StatsManager;
 import android.content.Context;
+import android.net.BpfNetMapsUtils;
 import android.net.INetd;
-import android.net.NetworkStackBpfNetMaps;
 import android.net.UidOwnerValue;
 import android.os.Build;
 import android.os.RemoteException;
@@ -539,7 +539,7 @@
     @Deprecated
     @RequiresApi(Build.VERSION_CODES.TIRAMISU)
     public boolean isChainEnabled(final int childChain) {
-        return NetworkStackBpfNetMaps.isChainEnabled(sConfigurationMap, childChain);
+        return BpfNetMapsUtils.isChainEnabled(sConfigurationMap, childChain);
     }
 
     private Set<Integer> asSet(final int[] uids) {
@@ -634,7 +634,7 @@
      *                                  cause of the failure.
      */
     public int getUidRule(final int childChain, final int uid) {
-        return NetworkStackBpfNetMaps.getUidRule(sUidOwnerMap, childChain, uid);
+        return BpfNetMapsUtils.getUidRule(sUidOwnerMap, childChain, uid);
     }
 
     private Set<Integer> getUidsMatchEnabled(final int childChain) throws ErrnoException {
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index f27e645..123ad8f 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -848,11 +848,6 @@
     private static final int EVENT_UID_FROZEN_STATE_CHANGED = 61;
 
     /**
-     * Event to inform the ConnectivityService handler when a uid has lost carrier privileges.
-     */
-    private static final int EVENT_UID_CARRIER_PRIVILEGES_LOST = 62;
-
-    /**
      * Argument for {@link #EVENT_PROVISIONING_NOTIFICATION} to indicate that the notification
      * should be shown.
      */
@@ -1295,14 +1290,6 @@
     }
     private final LegacyTypeTracker mLegacyTypeTracker = new LegacyTypeTracker(this);
 
-    @VisibleForTesting
-    void onCarrierPrivilegesLost(Integer uid, Integer subId) {
-        if (mRequestRestrictedWifiEnabled) {
-            mHandler.sendMessage(mHandler.obtainMessage(
-                    EVENT_UID_CARRIER_PRIVILEGES_LOST, uid, subId));
-        }
-    }
-
     final LocalPriorityDump mPriorityDumper = new LocalPriorityDump();
     /**
      * Helper class which parses out priority arguments and dumps sections according to their
@@ -1524,10 +1511,11 @@
                 @NonNull final Context context,
                 @NonNull final TelephonyManager tm,
                 boolean requestRestrictedWifiEnabled,
-                @NonNull BiConsumer<Integer, Integer> listener) {
+                @NonNull BiConsumer<Integer, Integer> listener,
+                @NonNull final Handler connectivityServiceHandler) {
             if (isAtLeastT()) {
-                return new CarrierPrivilegeAuthenticator(
-                        context, tm, requestRestrictedWifiEnabled, listener);
+                return new CarrierPrivilegeAuthenticator(context, tm, requestRestrictedWifiEnabled,
+                        listener, connectivityServiceHandler);
             } else {
                 return null;
             }
@@ -1812,7 +1800,7 @@
                 && mDeps.isFeatureEnabled(context, REQUEST_RESTRICTED_WIFI);
         mCarrierPrivilegeAuthenticator = mDeps.makeCarrierPrivilegeAuthenticator(
                 mContext, mTelephonyManager, mRequestRestrictedWifiEnabled,
-                this::onCarrierPrivilegesLost);
+                this::handleUidCarrierPrivilegesLost, mHandler);
 
         if (mDeps.isAtLeastU()
                 && mDeps
@@ -3831,6 +3819,10 @@
             mSatelliteAccessController.start();
         }
 
+        if (mCarrierPrivilegeAuthenticator != null) {
+            mCarrierPrivilegeAuthenticator.start();
+        }
+
         // On T+ devices, register callback for statsd to pull NETWORK_BPF_MAP_INFO atom
         if (mDeps.isAtLeastT()) {
             mBpfNetMaps.setPullAtomCallback(mContext);
@@ -6519,9 +6511,6 @@
                     UidFrozenStateChangedArgs args = (UidFrozenStateChangedArgs) msg.obj;
                     handleFrozenUids(args.mUids, args.mFrozenStates);
                     break;
-                case EVENT_UID_CARRIER_PRIVILEGES_LOST:
-                    handleUidCarrierPrivilegesLost(msg.arg1, msg.arg2);
-                    break;
             }
         }
     }
@@ -9260,6 +9249,9 @@
     }
 
     private void handleUidCarrierPrivilegesLost(int uid, int subId) {
+        if (!mRequestRestrictedWifiEnabled) {
+            return;
+        }
         ensureRunningOnConnectivityServiceThread();
         // A NetworkRequest needs to be revoked when all the conditions are met
         //   1. It requests restricted network
diff --git a/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java b/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java
index 04d0fc1..f5fa4fb 100644
--- a/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java
+++ b/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java
@@ -91,42 +91,62 @@
             @NonNull final TelephonyManager t,
             @NonNull final TelephonyManagerShim telephonyManagerShim,
             final boolean requestRestrictedWifiEnabled,
-            @NonNull BiConsumer<Integer, Integer> listener) {
+            @NonNull BiConsumer<Integer, Integer> listener,
+            @NonNull final Handler connectivityServiceHandler) {
         mContext = c;
         mTelephonyManager = t;
         mTelephonyManagerShim = telephonyManagerShim;
-        final HandlerThread thread = deps.makeHandlerThread();
-        thread.start();
-        mHandler = new Handler(thread.getLooper());
         mUseCallbacksForServiceChanged = deps.isFeatureEnabled(
                 c, CARRIER_SERVICE_CHANGED_USE_CALLBACK);
         mRequestRestrictedWifiEnabled = requestRestrictedWifiEnabled;
         mListener = listener;
+        if (mRequestRestrictedWifiEnabled) {
+            mHandler = connectivityServiceHandler;
+        } else {
+            final HandlerThread thread = deps.makeHandlerThread();
+            thread.start();
+            mHandler = new Handler(thread.getLooper());
+            synchronized (mLock) {
+                registerSimConfigChangedReceiver();
+                simConfigChanged();
+            }
+        }
+    }
+
+    private void registerSimConfigChangedReceiver() {
         final IntentFilter filter = new IntentFilter();
         filter.addAction(TelephonyManager.ACTION_MULTI_SIM_CONFIG_CHANGED);
-        synchronized (mLock) {
-            // Never unregistered because the system server never stops
-            c.registerReceiver(new BroadcastReceiver() {
-                @Override
-                public void onReceive(final Context context, final Intent intent) {
-                    switch (intent.getAction()) {
-                        case TelephonyManager.ACTION_MULTI_SIM_CONFIG_CHANGED:
-                            simConfigChanged();
-                            break;
-                        default:
-                            Log.d(TAG, "Unknown intent received, action: " + intent.getAction());
-                    }
+        // Never unregistered because the system server never stops
+        mContext.registerReceiver(new BroadcastReceiver() {
+            @Override
+            public void onReceive(final Context context, final Intent intent) {
+                switch (intent.getAction()) {
+                    case TelephonyManager.ACTION_MULTI_SIM_CONFIG_CHANGED:
+                        simConfigChanged();
+                        break;
+                    default:
+                        Log.d(TAG, "Unknown intent received, action: " + intent.getAction());
                 }
-            }, filter, null, mHandler);
-            simConfigChanged();
+            }
+        }, filter, null, mHandler);
+    }
+
+    /**
+     * Start listening for SIM config changes; no-op unless mRequestRestrictedWifiEnabled.
+     */
+    public void start() {
+        if (mRequestRestrictedWifiEnabled) {
+            registerSimConfigChangedReceiver();
+            mHandler.post(this::simConfigChanged);
         }
     }
 
     public CarrierPrivilegeAuthenticator(@NonNull final Context c,
             @NonNull final TelephonyManager t, final boolean requestRestrictedWifiEnabled,
-            @NonNull BiConsumer<Integer, Integer> listener) {
+            @NonNull BiConsumer<Integer, Integer> listener,
+            @NonNull final Handler connectivityServiceHandler) {
         this(c, new Dependencies(), t, TelephonyManagerShimImpl.newInstance(t),
-                requestRestrictedWifiEnabled, listener);
+                requestRestrictedWifiEnabled, listener, connectivityServiceHandler);
     }
 
     public static class Dependencies {
@@ -146,6 +166,10 @@
     }
 
     private void simConfigChanged() {
+        // If mRequestRestrictedWifiEnabled is false, the constructor calls this directly.
+        if (mRequestRestrictedWifiEnabled) {
+            ensureRunningOnHandlerThread();
+        }
         synchronized (mLock) {
             unregisterCarrierPrivilegesListeners();
             mModemCount = mTelephonyManager.getActiveModemCount();
@@ -188,6 +212,7 @@
         public void onCarrierPrivilegesChanged(
                 @NonNull List<String> privilegedPackageNames,
                 @NonNull int[] privilegedUids) {
+            ensureRunningOnHandlerThread();
             if (mUseCallbacksForServiceChanged) return;
             // Re-trigger the synchronous check (which is also very cheap due
             // to caching in CarrierPrivilegesTracker). This allows consistency
@@ -198,6 +223,7 @@
         @Override
         public void onCarrierServiceChanged(@Nullable final String carrierServicePackageName,
                 final int carrierServiceUid) {
+            ensureRunningOnHandlerThread();
             if (!mUseCallbacksForServiceChanged) {
                 // Re-trigger the synchronous check (which is also very cheap due
                 // to caching in CarrierPrivilegesTracker). This allows consistency
@@ -439,6 +465,13 @@
         }
     }
 
+    private void ensureRunningOnHandlerThread() {
+        if (mHandler.getLooper().getThread() != Thread.currentThread()) {
+            throw new IllegalStateException(
+                    "Not running on handler thread: " + Thread.currentThread().getName());
+        }
+    }
+
     public void dump(IndentingPrintWriter pw) {
         pw.println("CarrierPrivilegeAuthenticator:");
         pw.println("mRequestRestrictedWifiEnabled = " + mRequestRestrictedWifiEnabled);
diff --git a/staticlibs/device/com/android/net/module/util/structs/PrefixInformationOption.java b/staticlibs/device/com/android/net/module/util/structs/PrefixInformationOption.java
index 49d7654..0fc85e4 100644
--- a/staticlibs/device/com/android/net/module/util/structs/PrefixInformationOption.java
+++ b/staticlibs/device/com/android/net/module/util/structs/PrefixInformationOption.java
@@ -21,6 +21,7 @@
 import android.net.IpPrefix;
 
 import androidx.annotation.NonNull;
+import androidx.annotation.VisibleForTesting;
 
 import com.android.net.module.util.Struct;
 import com.android.net.module.util.Struct.Field;
@@ -71,7 +72,8 @@
     @Field(order = 7, type = Type.ByteArray, arraysize = 16)
     public final byte[] prefix;
 
-    PrefixInformationOption(final byte type, final byte length, final byte prefixLen,
+    @VisibleForTesting
+    public PrefixInformationOption(final byte type, final byte length, final byte prefixLen,
             final byte flags, final long validLifetime, final long preferredLifetime,
             final int reserved, @NonNull final byte[] prefix) {
         this.type = type;
diff --git a/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt b/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
new file mode 100644
index 0000000..e92c906
--- /dev/null
+++ b/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
@@ -0,0 +1,81 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.net.cts
+
+import android.content.pm.PackageManager.FEATURE_WIFI
+import android.net.ConnectivityManager
+import android.net.NetworkCapabilities
+import android.net.NetworkRequest
+import android.os.Build
+import android.system.OsConstants
+import androidx.test.platform.app.InstrumentationRegistry
+import com.android.compatibility.common.util.PropertyUtil.isVendorApiLevelNewerThan
+import com.android.compatibility.common.util.SystemUtil
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.NetworkStackModuleTest
+import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
+import com.android.testutils.TestableNetworkCallback
+import com.google.common.truth.Truth.assertThat
+import kotlin.test.assertEquals
+import kotlin.test.assertNotNull
+import org.junit.Assume.assumeTrue
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+
+private const val TIMEOUT_MS = 2000L
+
+@RunWith(DevSdkIgnoreRunner::class)
+@NetworkStackModuleTest
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
+class ApfIntegrationTest {
+    private val context by lazy { InstrumentationRegistry.getInstrumentation().context }
+    private val cm by lazy { context.getSystemService(ConnectivityManager::class.java)!! }
+    private val pm by lazy { context.packageManager }
+    private lateinit var wifiIfaceName: String
+    @Before
+    fun setUp() {
+        assumeTrue(pm.hasSystemFeature(FEATURE_WIFI))
+        assumeTrue(isVendorApiLevelNewerThan(Build.VERSION_CODES.TIRAMISU))
+        val cb = TestableNetworkCallback()
+        cm.requestNetwork(
+                NetworkRequest.Builder()
+                        .addTransportType(NetworkCapabilities.TRANSPORT_WIFI)
+                        .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
+                        .build(),
+                cb
+        )
+        cb.eventuallyExpect<LinkPropertiesChanged>(TIMEOUT_MS) {
+            wifiIfaceName = assertNotNull(it.lp.interfaceName)
+            true
+        }
+        assertNotNull(wifiIfaceName)
+    }
+
+    @Test
+    fun testGetApfCapabilities() {
+        val capabilities = SystemUtil
+                .runShellCommand("cmd network_stack apf $wifiIfaceName capabilities").trim()
+        val (version, maxLen, packetFormat) = capabilities.split(",").map { it.toInt() }
+        assertEquals(4, version)
+        assertThat(maxLen).isAtLeast(1024)
+        if (isVendorApiLevelNewerThan(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)) {
+            assertThat(maxLen).isAtLeast(2000)
+        }
+        assertEquals(OsConstants.ARPHRD_ETHER, packetFormat)
+    }
+}
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index 0daf7fe..6dd4857 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -1888,6 +1888,64 @@
         }
     }
 
+    @Test
+    fun testQueryWhenKnownAnswerSuppressionFlagSet() {
+        // The flag may be removed in the future but sending queries with known answers should be
+        // enabled by default in that case. The rule will reset flags automatically on teardown.
+        deviceConfigRule.setConfig(NAMESPACE_TETHERING, "test_nsd_query_with_known_answer", "1")
+
+        // Discover services on testNetwork1
+        val discoveryRecord = NsdDiscoveryRecord()
+        val packetReader = TapPacketReader(Handler(handlerThread.looper),
+                testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+        packetReader.startAsyncForTest()
+        handlerThread.waitForIdle(TIMEOUT_MS)
+
+        nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
+                testNetwork1.network, { it.run() }, discoveryRecord)
+
+        tryTest {
+            discoveryRecord.expectCallback<DiscoveryStarted>()
+            assertNotNull(packetReader.pollForQuery("$serviceType.local", DnsResolver.TYPE_PTR))
+            /*
+            Generated with:
+            scapy.raw(scapy.DNS(rd=0, qr=1, aa=1, qd = None, an =
+                scapy.DNSRR(rrname='_nmt123456789._tcp.local', type='PTR', ttl=120,
+                rdata='NsdTest123456789._nmt123456789._tcp.local'))).hex()
+             */
+            val ptrResponsePayload = HexDump.hexStringToByteArray("0000840000000001000000000d5f6e" +
+                    "6d74313233343536373839045f746370056c6f63616c00000c000100000078002b104e736454" +
+                    "6573743132333435363738390d5f6e6d74313233343536373839045f746370056c6f63616c00")
+
+            replaceServiceNameAndTypeWithTestSuffix(ptrResponsePayload)
+            packetReader.sendResponse(buildMdnsPacket(ptrResponsePayload))
+
+            val serviceFound = discoveryRecord.expectCallback<ServiceFound>()
+            serviceFound.serviceInfo.let {
+                assertEquals(serviceName, it.serviceName)
+                // Discovered service types have a dot at the end
+                assertEquals("$serviceType.", it.serviceType)
+                assertEquals(testNetwork1.network, it.network)
+                // ServiceFound does not provide port, address or attributes (only information
+                // available in the PTR record is included in that callback, regardless of whether
+                // other records exist).
+                assertEquals(0, it.port)
+                assertEmpty(it.hostAddresses)
+                assertEquals(0, it.attributes.size)
+            }
+
+            // Expect the second query with a known answer
+            val query = packetReader.pollForMdnsPacket { pkt ->
+                pkt.isQueryFor("$serviceType.local", DnsResolver.TYPE_PTR) &&
+                        pkt.isReplyFor("$serviceType.local", DnsResolver.TYPE_PTR)
+            }
+            assertNotNull(query)
+        } cleanup {
+            nsdManager.stopServiceDiscovery(discoveryRecord)
+            discoveryRecord.expectCallback<DiscoveryStopped>()
+        }
+    }
+
     private fun makeLinkLocalAddressOfOtherDeviceOnPrefix(network: Network): Inet6Address {
         val lp = cm.getLinkProperties(network) ?: fail("No LinkProperties for net $network")
         // Expect to have a /64 link-local address
@@ -2148,6 +2206,66 @@
     }
 
     @Test
+    fun testAdvertisingAndDiscovery_reregisterCustomHostWithDifferentAddresses_newAddressesFound() {
+        val si1 = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.hostname = customHostname
+            it.hostAddresses = listOf(
+                    parseNumericAddress("192.0.2.23"),
+                    parseNumericAddress("2001:db8::1"))
+        }
+        val si2 = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.serviceName = serviceName
+            it.serviceType = serviceType
+            it.hostname = customHostname
+            it.port = TEST_PORT
+        }
+        val si3 = NsdServiceInfo().also {
+            it.network = testNetwork1.network
+            it.hostname = customHostname
+            it.hostAddresses = listOf(
+                    parseNumericAddress("192.0.2.24"),
+                    parseNumericAddress("2001:db8::2"))
+        }
+
+        val registrationRecord1 = NsdRegistrationRecord()
+        val registrationRecord2 = NsdRegistrationRecord()
+        val registrationRecord3 = NsdRegistrationRecord()
+
+        val discoveryRecord = NsdDiscoveryRecord()
+
+        tryTest {
+            registerService(registrationRecord1, si1)
+            registerService(registrationRecord2, si2)
+
+            nsdManager.unregisterService(registrationRecord1)
+            registrationRecord1.expectCallback<ServiceUnregistered>()
+
+            registerService(registrationRecord3, si3)
+
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
+                    testNetwork1.network, Executor { it.run() }, discoveryRecord)
+            val discoveredInfo = discoveryRecord.waitForServiceDiscovered(
+                    serviceName, serviceType, testNetwork1.network)
+            val resolvedInfo = resolveService(discoveredInfo)
+
+            assertEquals(serviceName, discoveredInfo.serviceName)
+            assertEquals(TEST_PORT, resolvedInfo.port)
+            assertEquals(customHostname, resolvedInfo.hostname)
+            assertAddressEquals(
+                    listOf(parseNumericAddress("192.0.2.24"), parseNumericAddress("2001:db8::2")),
+                    resolvedInfo.hostAddresses)
+        } cleanupStep {
+            nsdManager.stopServiceDiscovery(discoveryRecord)
+            discoveryRecord.expectCallbackEventually<DiscoveryStopped>()
+        } cleanup {
+            nsdManager.unregisterService(registrationRecord2)
+            nsdManager.unregisterService(registrationRecord3)
+        }
+    }
+
+    @Test
     fun testServiceTypeClientRemovedAfterSocketDestroyed() {
         val si = makeTestServiceInfo(testNetwork1.network)
         // Register service on testNetwork1
diff --git a/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt b/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
index 035f607..d2e46af 100644
--- a/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
+++ b/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
@@ -243,18 +243,19 @@
             super.makeHandlerThread(tag).also { handlerThreads.add(it) }
 
         override fun makeCarrierPrivilegeAuthenticator(
-            context: Context,
-            tm: TelephonyManager,
-            requestRestrictedWifiEnabled: Boolean,
-            listener: BiConsumer<Int, Int>
+                context: Context,
+                tm: TelephonyManager,
+                requestRestrictedWifiEnabled: Boolean,
+                listener: BiConsumer<Int, Int>,
+                handler: Handler
         ): CarrierPrivilegeAuthenticator {
             return CarrierPrivilegeAuthenticator(context,
-                object : CarrierPrivilegeAuthenticator.Dependencies() {
-                    override fun makeHandlerThread(): HandlerThread =
-                        super.makeHandlerThread().also { handlerThreads.add(it) }
-                },
-                tm, TelephonyManagerShimImpl.newInstance(tm),
-                requestRestrictedWifiEnabled, listener)
+                    object : CarrierPrivilegeAuthenticator.Dependencies() {
+                        override fun makeHandlerThread(): HandlerThread =
+                                super.makeHandlerThread().also { handlerThreads.add(it) }
+                    },
+                    tm, TelephonyManagerShimImpl.newInstance(tm),
+                    requestRestrictedWifiEnabled, listener, handler)
         }
 
         override fun makeSatelliteAccessController(
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 8c30776..f41d7b2 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -2038,12 +2038,16 @@
             };
         }
 
+        private BiConsumer<Integer, Integer> mCarrierPrivilegesLostListener;
+
         @Override
         public CarrierPrivilegeAuthenticator makeCarrierPrivilegeAuthenticator(
                 @NonNull final Context context,
                 @NonNull final TelephonyManager tm,
                 final boolean requestRestrictedWifiEnabled,
-                BiConsumer<Integer, Integer> listener) {
+                BiConsumer<Integer, Integer> listener,
+                @NonNull final Handler handler) {
+            mCarrierPrivilegesLostListener = listener;
             return mDeps.isAtLeastT() ? mCarrierPrivilegeAuthenticator : null;
         }
 
@@ -17413,7 +17417,10 @@
                 .isCarrierServiceUidForNetworkCapabilities(eq(Process.myUid()), any());
         doReturn(TEST_SUBSCRIPTION_ID).when(mCarrierPrivilegeAuthenticator)
                 .getSubIdFromNetworkCapabilities(any());
-        mService.onCarrierPrivilegesLost(lostPrivilegeUid, lostPrivilegeSubId);
+
+        visibleOnHandlerThread(mCsHandlerThread.getThreadHandler(), () -> {
+            mDeps.mCarrierPrivilegesLostListener.accept(lostPrivilegeUid, lostPrivilegeSubId);
+        });
         waitForIdle();
 
         if (expectCapChanged) {
diff --git a/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java b/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java
index 7bd2b56..ab81abc 100644
--- a/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java
+++ b/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java
@@ -21,6 +21,7 @@
 import static android.telephony.TelephonyManager.ACTION_MULTI_SIM_CONFIG_CHANGED;
 
 import static com.android.server.connectivity.ConnectivityFlags.CARRIER_SERVICE_CHANGED_USE_CALLBACK;
+import static com.android.testutils.HandlerUtils.visibleOnHandlerThread;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -45,6 +46,7 @@
 import android.net.NetworkCapabilities;
 import android.net.TelephonyNetworkSpecifier;
 import android.os.Build;
+import android.os.Handler;
 import android.os.HandlerThread;
 import android.telephony.TelephonyManager;
 
@@ -56,6 +58,7 @@
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;
 import com.android.testutils.DevSdkIgnoreRunner;
+import com.android.testutils.HandlerUtils;
 
 import org.junit.After;
 import org.junit.Rule;
@@ -85,6 +88,7 @@
 
     private static final int SUBSCRIPTION_COUNT = 2;
     private static final int TEST_SUBSCRIPTION_ID = 1;
+    private static final int TIMEOUT_MS = 1_000;
 
     @NonNull private final Context mContext;
     @NonNull private final TelephonyManager mTelephonyManager;
@@ -97,13 +101,16 @@
     private final String mTestPkg = "com.android.server.connectivity.test";
     private final BroadcastReceiver mMultiSimBroadcastReceiver;
     @NonNull private final HandlerThread mHandlerThread;
+    @NonNull private final Handler mCsHandler;
+    @NonNull private final HandlerThread mCsHandlerThread;
 
     public class TestCarrierPrivilegeAuthenticator extends CarrierPrivilegeAuthenticator {
         TestCarrierPrivilegeAuthenticator(@NonNull final Context c,
                 @NonNull final Dependencies deps,
-                @NonNull final TelephonyManager t) {
+                @NonNull final TelephonyManager t,
+                @NonNull final Handler handler) {
             super(c, deps, t, mTelephonyManagerShim, true /* requestRestrictedWifiEnabled */,
-                    mListener);
+                    mListener, handler);
         }
         @Override
         protected int getSubId(int slotIndex) {
@@ -112,8 +119,11 @@
     }
 
     @After
-    public void tearDown() {
+    public void tearDown() throws Exception {
         mHandlerThread.quit();
+        mHandlerThread.join();
+        mCsHandlerThread.quit();
+        mCsHandlerThread.join();
     }
 
     /** Parameters to test both using callbacks or the old broadcast */
@@ -141,8 +151,14 @@
         final ApplicationInfo applicationInfo = new ApplicationInfo();
         applicationInfo.uid = mCarrierConfigPkgUid;
         doReturn(applicationInfo).when(mPackageManager).getApplicationInfo(eq(mTestPkg), anyInt());
-        mCarrierPrivilegeAuthenticator =
-                new TestCarrierPrivilegeAuthenticator(mContext, deps, mTelephonyManager);
+        mCsHandlerThread = new HandlerThread(
+                CarrierPrivilegeAuthenticatorTest.class.getSimpleName() + "-CsHandlerThread");
+        mCsHandlerThread.start();
+        mCsHandler = new Handler(mCsHandlerThread.getLooper());
+        mCarrierPrivilegeAuthenticator = new TestCarrierPrivilegeAuthenticator(mContext, deps,
+                mTelephonyManager, mCsHandler);
+        mCarrierPrivilegeAuthenticator.start();
+        HandlerUtils.waitForIdle(mCsHandlerThread, TIMEOUT_MS);
         final ArgumentCaptor<BroadcastReceiver> receiverCaptor =
                 ArgumentCaptor.forClass(BroadcastReceiver.class);
         verify(mContext).registerReceiver(receiverCaptor.capture(), argThat(filter ->
@@ -178,7 +194,9 @@
         assertNotNull(initialListeners.get(1));
         assertEquals(2, initialListeners.size());
 
-        initialListeners.get(0).onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        visibleOnHandlerThread(mCsHandler, () -> {
+            initialListeners.get(0).onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        });
 
         final NetworkCapabilities.Builder ncBuilder = new NetworkCapabilities.Builder()
                 .addTransportType(TRANSPORT_CELLULAR)
@@ -201,10 +219,10 @@
 
         doReturn(1).when(mTelephonyManager).getActiveModemCount();
 
-        // This is a little bit cavalier in that the call to onReceive is not on the handler
-        // thread that was specified in registerReceiver.
-        // TODO : capture the handler and call this on it if this causes flakiness.
-        mMultiSimBroadcastReceiver.onReceive(mContext, buildTestMultiSimConfigBroadcastIntent());
+        visibleOnHandlerThread(mCsHandler, () -> {
+            mMultiSimBroadcastReceiver.onReceive(mContext,
+                    buildTestMultiSimConfigBroadcastIntent());
+        });
         // Check all listeners have been removed
         for (CarrierPrivilegesListenerShim listener : initialListeners.values()) {
             verify(mTelephonyManagerShim).removeCarrierPrivilegesListener(eq(listener));
@@ -216,7 +234,9 @@
         assertNotNull(newListeners.get(0));
         assertEquals(1, newListeners.size());
 
-        newListeners.get(0).onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        visibleOnHandlerThread(mCsHandler, () -> {
+            newListeners.get(0).onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        });
 
         final TelephonyNetworkSpecifier specifier =
                 new TelephonyNetworkSpecifier(TEST_SUBSCRIPTION_ID);
@@ -235,12 +255,17 @@
     public void testCarrierPrivilegesLostDueToCarrierServiceUpdate() throws Exception {
         final CarrierPrivilegesListenerShim l = getCarrierPrivilegesListeners().get(0);
 
-        l.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
-        l.onCarrierServiceChanged(null, mCarrierConfigPkgUid + 1);
+        visibleOnHandlerThread(mCsHandler, () -> {
+            l.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+            l.onCarrierServiceChanged(null, mCarrierConfigPkgUid + 1);
+        });
         if (mUseCallbacks) {
             verify(mListener).accept(eq(mCarrierConfigPkgUid), eq(TEST_SUBSCRIPTION_ID));
         }
-        l.onCarrierServiceChanged(null, mCarrierConfigPkgUid + 2);
+
+        visibleOnHandlerThread(mCsHandler, () -> {
+            l.onCarrierServiceChanged(null, mCarrierConfigPkgUid + 2);
+        });
         if (mUseCallbacks) {
             verify(mListener).accept(eq(mCarrierConfigPkgUid + 1), eq(TEST_SUBSCRIPTION_ID));
         }
@@ -260,8 +285,10 @@
         final ApplicationInfo applicationInfo = new ApplicationInfo();
         applicationInfo.uid = mCarrierConfigPkgUid + 1;
         doReturn(applicationInfo).when(mPackageManager).getApplicationInfo(eq(mTestPkg), anyInt());
-        listener.onCarrierPrivilegesChanged(Collections.emptyList(), new int[] {});
-        listener.onCarrierServiceChanged(null, applicationInfo.uid);
+        visibleOnHandlerThread(mCsHandler, () -> {
+            listener.onCarrierPrivilegesChanged(Collections.emptyList(), new int[]{});
+            listener.onCarrierServiceChanged(null, applicationInfo.uid);
+        });
 
         assertFalse(mCarrierPrivilegeAuthenticator.isCarrierServiceUidForNetworkCapabilities(
                 mCarrierConfigPkgUid, nc));
@@ -272,7 +299,9 @@
     @Test
     public void testDefaultSubscription() throws Exception {
         final CarrierPrivilegesListenerShim listener = getCarrierPrivilegesListeners().get(0);
-        listener.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        visibleOnHandlerThread(mCsHandler, () -> {
+            listener.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        });
 
         final NetworkCapabilities.Builder ncBuilder = new NetworkCapabilities.Builder();
         ncBuilder.addTransportType(TRANSPORT_CELLULAR);
@@ -297,7 +326,9 @@
     @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
     public void testNetworkCapabilitiesContainOneSubId() throws Exception {
         final CarrierPrivilegesListenerShim listener = getCarrierPrivilegesListeners().get(0);
-        listener.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        visibleOnHandlerThread(mCsHandler, () -> {
+            listener.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        });
 
         final NetworkCapabilities.Builder ncBuilder = new NetworkCapabilities.Builder();
         ncBuilder.addTransportType(TRANSPORT_WIFI);
@@ -311,7 +342,9 @@
     @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
     public void testNetworkCapabilitiesContainTwoSubIds() throws Exception {
         final CarrierPrivilegesListenerShim listener = getCarrierPrivilegesListeners().get(0);
-        listener.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        visibleOnHandlerThread(mCsHandler, () -> {
+            listener.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
+        });
 
         final NetworkCapabilities.Builder ncBuilder = new NetworkCapabilities.Builder();
         ncBuilder.addTransportType(TRANSPORT_WIFI);
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index 69fec85..7ac7bee 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -18,7 +18,6 @@
 
 import android.net.InetAddresses.parseNumericAddress
 import android.net.LinkAddress
-import android.net.nsd.NsdManager
 import android.net.nsd.NsdServiceInfo
 import android.os.Build
 import android.os.HandlerThread
@@ -48,6 +47,7 @@
 import org.mockito.Mockito.anyInt
 import org.mockito.Mockito.anyString
 import org.mockito.Mockito.argThat
+import org.mockito.Mockito.atLeastOnce
 import org.mockito.Mockito.doAnswer
 import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.eq
@@ -55,6 +55,8 @@
 import org.mockito.Mockito.never
 import org.mockito.Mockito.times
 import org.mockito.Mockito.verify
+import org.mockito.Mockito.clearInvocations
+import org.mockito.Mockito.inOrder
 
 private const val LOG_TAG = "testlogtag"
 private const val TIMEOUT_MS = 10_000L
@@ -65,6 +67,7 @@
 
 private const val TEST_SERVICE_ID_1 = 42
 private const val TEST_SERVICE_ID_DUPLICATE = 43
+private const val TEST_SERVICE_ID_2 = 44
 private val TEST_SERVICE_1 = NsdServiceInfo().apply {
     serviceType = "_testservice._tcp"
     serviceName = "MyTestService"
@@ -78,6 +81,13 @@
     port = 12345
 }
 
+private val TEST_SERVICE_1_CUSTOM_HOST = NsdServiceInfo().apply {
+    serviceType = "_testservice._tcp"
+    serviceName = "MyTestService"
+    hostname = "MyTestHost"
+    port = 12345
+}
+
 @RunWith(DevSdkIgnoreRunner::class)
 @IgnoreUpTo(Build.VERSION_CODES.S_V2)
 class MdnsInterfaceAdvertiserTest {
@@ -183,6 +193,63 @@
     }
 
     @Test
+    fun testAddRemoveServiceWithCustomHost_restartProbingForProbingServices() {
+        val customHost1 = NsdServiceInfo().apply {
+            hostname = "MyTestHost"
+            hostAddresses = listOf(
+                    parseNumericAddress("192.0.2.23"),
+                    parseNumericAddress("2001:db8::1"))
+        }
+        addServiceAndFinishProbing(TEST_SERVICE_ID_1, customHost1)
+        addServiceAndFinishProbing(TEST_SERVICE_ID_2, TEST_SERVICE_1_CUSTOM_HOST)
+        repository.setServiceProbing(TEST_SERVICE_ID_2)
+        val probingInfo = mock(ProbingInfo::class.java)
+        doReturn("MyTestHost")
+                .`when`(repository).getHostnameForServiceId(TEST_SERVICE_ID_1)
+        doReturn(TEST_SERVICE_ID_2).`when`(probingInfo).serviceId
+        doReturn(listOf(probingInfo))
+                .`when`(repository).restartProbingForHostname("MyTestHost")
+        val inOrder = inOrder(prober, announcer)
+
+        // Remove the custom host: the custom host's announcement is stopped and probing is
+        // restarted for the probing services which use that hostname.
+        advertiser.removeService(TEST_SERVICE_ID_1)
+
+        inOrder.verify(prober).stop(TEST_SERVICE_ID_1)
+        inOrder.verify(announcer).stop(TEST_SERVICE_ID_1)
+        inOrder.verify(prober).stop(TEST_SERVICE_ID_2)
+        inOrder.verify(prober).startProbing(probingInfo)
+    }
+
+    @Test
+    fun testAddRemoveServiceWithCustomHost_restartAnnouncingForProbedServices() {
+        val customHost1 = NsdServiceInfo().apply {
+            hostname = "MyTestHost"
+            hostAddresses = listOf(
+                    parseNumericAddress("192.0.2.23"),
+                    parseNumericAddress("2001:db8::1"))
+        }
+        addServiceAndFinishProbing(TEST_SERVICE_ID_1, customHost1)
+        val announcementInfo =
+                addServiceAndFinishProbing(TEST_SERVICE_ID_2, TEST_SERVICE_1_CUSTOM_HOST)
+        doReturn("MyTestHost")
+                .`when`(repository).getHostnameForServiceId(TEST_SERVICE_ID_1)
+        doReturn(TEST_SERVICE_ID_2).`when`(announcementInfo).serviceId
+        doReturn(listOf(announcementInfo))
+                .`when`(repository).restartAnnouncingForHostname("MyTestHost")
+        clearInvocations(announcer)
+
+        // Remove the custom host: the custom host's announcement is stopped and the probed services
+        // which use that hostname are re-announced.
+        advertiser.removeService(TEST_SERVICE_ID_1)
+
+        verify(prober).stop(TEST_SERVICE_ID_1)
+        verify(announcer, atLeastOnce()).stop(TEST_SERVICE_ID_1)
+        verify(announcer).stop(TEST_SERVICE_ID_2)
+        verify(announcer).startSending(TEST_SERVICE_ID_2, announcementInfo, 0L /* initialDelayMs */)
+    }
+
+    @Test
     fun testDoubleRemove() {
         addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
 
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index c69b1e1..271cc65 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -24,6 +24,7 @@
 import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
 import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_HOST
 import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_SERVICE
+import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_A
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_AAAA
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_PTR
@@ -51,6 +52,10 @@
 import org.junit.Before
 import org.junit.Test
 import org.junit.runner.RunWith
+import org.mockito.ArgumentCaptor
+import org.mockito.ArgumentMatchers.eq
+import org.mockito.Mockito.mock
+import org.mockito.Mockito.verify
 
 private const val TEST_SERVICE_ID_1 = 42
 private const val TEST_SERVICE_ID_2 = 43
@@ -112,6 +117,14 @@
     port = TEST_PORT
 }
 
+private val TEST_SERVICE_CUSTOM_HOST_NO_ADDRESSES = NsdServiceInfo().apply {
+    hostname = "TestHost"
+    hostAddresses = listOf()
+    serviceType = "_testservice._tcp"
+    serviceName = "TestService"
+    port = TEST_PORT
+}
+
 @RunWith(DevSdkIgnoreRunner::class)
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
 class MdnsRecordRepositoryTest {
@@ -1676,6 +1689,127 @@
         assertEquals(0, reply.additionalAnswers.size)
         assertEquals(knownAnswers, reply.knownAnswers)
     }
+
+    @Test
+    fun testRestartProbingForHostname() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        repository.initWithService(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1,
+                setOf(TEST_SUBTYPE, TEST_SUBTYPE2))
+        repository.addService(TEST_SERVICE_CUSTOM_HOST_ID_1,
+                TEST_SERVICE_CUSTOM_HOST_NO_ADDRESSES, null)
+        repository.setServiceProbing(TEST_SERVICE_CUSTOM_HOST_ID_1)
+        repository.removeService(TEST_CUSTOM_HOST_ID_1)
+
+        val probingInfos = repository.restartProbingForHostname("TestHost")
+
+        assertEquals(1, probingInfos.size)
+        val probingInfo = probingInfos.get(0)
+        assertEquals(TEST_SERVICE_CUSTOM_HOST_ID_1, probingInfo.serviceId)
+        val packet = probingInfo.getPacket(0)
+        assertEquals(0, packet.transactionId)
+        assertEquals(MdnsConstants.FLAGS_QUERY, packet.flags)
+        assertEquals(0, packet.answers.size)
+        assertEquals(0, packet.additionalRecords.size)
+        assertEquals(1, packet.questions.size)
+        val serviceName = arrayOf("TestService", "_testservice", "_tcp", "local")
+        assertEquals(MdnsAnyRecord(serviceName, false /* unicast */), packet.questions[0])
+        assertThat(packet.authorityRecords).containsExactly(
+                MdnsServiceRecord(
+                        serviceName,
+                        0L /* receiptTimeMillis */,
+                        false /* cacheFlush */,
+                        SHORT_TTL /* ttlMillis */,
+                        0 /* servicePriority */,
+                        0 /* serviceWeight */,
+                        TEST_PORT,
+                        TEST_CUSTOM_HOST_1_NAME))
+    }
+
+    @Test
+    fun testRestartAnnouncingForHostname() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        repository.initWithService(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1,
+                setOf(TEST_SUBTYPE, TEST_SUBTYPE2))
+        repository.addServiceAndFinishProbing(TEST_SERVICE_CUSTOM_HOST_ID_1,
+                TEST_SERVICE_CUSTOM_HOST_NO_ADDRESSES)
+        repository.removeService(TEST_CUSTOM_HOST_ID_1)
+
+        val announcementInfos = repository.restartAnnouncingForHostname("TestHost")
+
+        assertEquals(1, announcementInfos.size)
+        val announcementInfo = announcementInfos.get(0)
+        assertEquals(TEST_SERVICE_CUSTOM_HOST_ID_1, announcementInfo.serviceId)
+        val packet = announcementInfo.getPacket(0)
+        assertEquals(0, packet.transactionId)
+        assertEquals(0x8400 /* response, authoritative */, packet.flags)
+        assertEquals(0, packet.questions.size)
+        assertEquals(0, packet.authorityRecords.size)
+        val serviceName = arrayOf("TestService", "_testservice", "_tcp", "local")
+        val serviceType = arrayOf("_testservice", "_tcp", "local")
+        val v4AddrRev = getReverseDnsAddress(TEST_ADDRESSES[0].address)
+        val v6Addr1Rev = getReverseDnsAddress(TEST_ADDRESSES[1].address)
+        val v6Addr2Rev = getReverseDnsAddress(TEST_ADDRESSES[2].address)
+        assertThat(packet.answers).containsExactly(
+                MdnsPointerRecord(
+                        serviceType,
+                        0L /* receiptTimeMillis */,
+                        // Not a unique name owned by the announcer, so cacheFlush=false
+                        false /* cacheFlush */,
+                        4500000L /* ttlMillis */,
+                        serviceName),
+                MdnsServiceRecord(
+                        serviceName,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        0 /* servicePriority */,
+                        0 /* serviceWeight */,
+                        TEST_PORT /* servicePort */,
+                        TEST_CUSTOM_HOST_1_NAME),
+                MdnsTextRecord(
+                        serviceName,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        4500000L /* ttlMillis */,
+                        emptyList() /* entries */),
+                MdnsPointerRecord(
+                        arrayOf("_services", "_dns-sd", "_udp", "local"),
+                        0L /* receiptTimeMillis */,
+                        false /* cacheFlush */,
+                        4500000L /* ttlMillis */,
+                        serviceType))
+        assertThat(packet.additionalRecords).containsExactly(
+                MdnsNsecRecord(v4AddrRev,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        v4AddrRev,
+                        intArrayOf(TYPE_PTR)),
+                MdnsNsecRecord(TEST_HOSTNAME,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        TEST_HOSTNAME,
+                        intArrayOf(TYPE_A, TYPE_AAAA)),
+                MdnsNsecRecord(v6Addr1Rev,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        v6Addr1Rev,
+                        intArrayOf(TYPE_PTR)),
+                MdnsNsecRecord(v6Addr2Rev,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        v6Addr2Rev,
+                        intArrayOf(TYPE_PTR)),
+                MdnsNsecRecord(serviceName,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        4500000L /* ttlMillis */,
+                        serviceName,
+                        intArrayOf(TYPE_TXT, TYPE_SRV)))
+    }
 }
 
 private fun MdnsRecordRepository.initWithService(
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index f279c5a..2eb9440 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -38,6 +38,7 @@
 import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.never;
@@ -117,8 +118,6 @@
     @Mock
     private MdnsServiceBrowserListener mockListenerTwo;
     @Mock
-    private MdnsPacketWriter mockPacketWriter;
-    @Mock
     private MdnsMultinetworkSocketClient mockSocketClient;
     @Mock
     private Network mockNetwork;
@@ -145,6 +144,7 @@
     private long latestDelayMs = 0;
     private Message delayMessage = null;
     private Handler realHandler = null;
+    private MdnsFeatureFlags featureFlags = MdnsFeatureFlags.newBuilder().build();
 
     @Before
     @SuppressWarnings("DoNotMock")
@@ -162,7 +162,8 @@
             expectedIPv6Packets[i] = new DatagramPacket(buf, 0 /* offset */, 5 /* length */,
                     MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
         }
-        when(mockPacketWriter.getPacket(IPV4_ADDRESS))
+        when(mockDeps.getDatagramPacketFromMdnsPacket(
+                any(), any(MdnsPacket.class), eq(IPV4_ADDRESS)))
                 .thenReturn(expectedIPv4Packets[0])
                 .thenReturn(expectedIPv4Packets[1])
                 .thenReturn(expectedIPv4Packets[2])
@@ -188,7 +189,8 @@
                 .thenReturn(expectedIPv4Packets[22])
                 .thenReturn(expectedIPv4Packets[23]);
 
-        when(mockPacketWriter.getPacket(IPV6_ADDRESS))
+        when(mockDeps.getDatagramPacketFromMdnsPacket(
+                any(), any(MdnsPacket.class), eq(IPV6_ADDRESS)))
                 .thenReturn(expectedIPv6Packets[0])
                 .thenReturn(expectedIPv6Packets[1])
                 .thenReturn(expectedIPv6Packets[2])
@@ -242,22 +244,13 @@
             return true;
         }).when(mockDeps).sendMessage(any(Handler.class), any(Message.class));
 
-        client = makeMdnsServiceTypeClient(mockPacketWriter);
+        client = makeMdnsServiceTypeClient();
     }
 
-    private MdnsServiceTypeClient makeMdnsServiceTypeClient(
-            @Nullable MdnsPacketWriter packetWriter) {
+    private MdnsServiceTypeClient makeMdnsServiceTypeClient() {
         return new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
                 mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps,
-                serviceCache) {
-            @Override
-            MdnsPacketWriter createMdnsPacketWriter() {
-                if (packetWriter == null) {
-                    return super.createMdnsPacketWriter();
-                }
-                return packetWriter;
-            }
-        };
+                serviceCache, featureFlags);
     }
 
     @After
@@ -697,11 +690,12 @@
 
     @Test
     public void testCombinedSubtypesQueriedWithMultipleListeners() throws Exception {
-        client = makeMdnsServiceTypeClient(/* packetWriter= */ null);
         final MdnsSearchOptions searchOptions1 = MdnsSearchOptions.newBuilder()
                 .addSubtype("subtype1").build();
         final MdnsSearchOptions searchOptions2 = MdnsSearchOptions.newBuilder()
                 .addSubtype("subtype2").build();
+        doCallRealMethod().when(mockDeps).getDatagramPacketFromMdnsPacket(
+                any(), any(MdnsPacket.class), any(InetSocketAddress.class));
         startSendAndReceive(mockListenerOne, searchOptions1);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
 
@@ -1021,7 +1015,6 @@
     public void processResponse_searchOptionsEnableServiceRemoval_shouldRemove()
             throws Exception {
         final String serviceInstanceName = "service-instance-1";
-        client = makeMdnsServiceTypeClient(mockPacketWriter);
         MdnsSearchOptions searchOptions = MdnsSearchOptions.newBuilder()
                 .setRemoveExpiredService(true)
                 .setNumOfQueriesBeforeBackoff(Integer.MAX_VALUE)
@@ -1059,7 +1052,6 @@
     public void processResponse_searchOptionsNotEnableServiceRemoval_shouldNotRemove()
             throws Exception {
         final String serviceInstanceName = "service-instance-1";
-        client = makeMdnsServiceTypeClient(mockPacketWriter);
         startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
         Runnable firstMdnsTask = currentThreadExecutor.getAndClearSubmittedRunnable();
 
@@ -1085,7 +1077,6 @@
             throws Exception {
         //MdnsConfigsFlagsImpl.removeServiceAfterTtlExpires.override(true);
         final String serviceInstanceName = "service-instance-1";
-        client = makeMdnsServiceTypeClient(mockPacketWriter);
         startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
         Runnable firstMdnsTask = currentThreadExecutor.getAndClearSubmittedRunnable();
 
@@ -1200,8 +1191,6 @@
 
     @Test
     public void testProcessResponse_Resolve() throws Exception {
-        client = makeMdnsServiceTypeClient(/* packetWriter= */ null);
-
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
         final String ipV4Address = "192.0.2.0";
@@ -1212,6 +1201,9 @@
         final MdnsSearchOptions resolveOptions2 = MdnsSearchOptions.newBuilder()
                 .setResolveInstanceName(instanceName).build();
 
+        doCallRealMethod().when(mockDeps).getDatagramPacketFromMdnsPacket(
+                any(), any(MdnsPacket.class), any(InetSocketAddress.class));
+
         startSendAndReceive(mockListenerOne, resolveOptions1);
         startSendAndReceive(mockListenerTwo, resolveOptions2);
         // No need to verify order for both listeners; and order is not guaranteed between them
@@ -1316,8 +1308,6 @@
 
     @Test
     public void testRenewTxtSrvInResolve() throws Exception {
-        client = makeMdnsServiceTypeClient(/* packetWriter= */ null);
-
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
         final String ipV4Address = "192.0.2.0";
@@ -1326,6 +1316,9 @@
         final MdnsSearchOptions resolveOptions = MdnsSearchOptions.newBuilder()
                 .setResolveInstanceName(instanceName).build();
 
+        doCallRealMethod().when(mockDeps).getDatagramPacketFromMdnsPacket(
+                any(), any(MdnsPacket.class), any(InetSocketAddress.class));
+
         startSendAndReceive(mockListenerOne, resolveOptions);
         InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
 
@@ -1430,8 +1423,6 @@
 
     @Test
     public void testProcessResponse_ResolveExcludesOtherServices() {
-        client = makeMdnsServiceTypeClient(/* packetWriter= */ null);
-
         final String requestedInstance = "instance1";
         final String otherInstance = "instance2";
         final String ipV4Address = "192.0.2.0";
@@ -1498,8 +1489,6 @@
 
     @Test
     public void testProcessResponse_SubtypeDiscoveryLimitedToSubtype() {
-        client = makeMdnsServiceTypeClient(/* packetWriter= */ null);
-
         final String matchingInstance = "instance1";
         final String subtype = "_subtype";
         final String otherInstance = "instance2";
@@ -1586,8 +1575,6 @@
 
     @Test
     public void testProcessResponse_SubtypeChange() {
-        client = makeMdnsServiceTypeClient(/* packetWriter= */ null);
-
         final String matchingInstance = "instance1";
         final String subtype = "_subtype";
         final String ipV4Address = "192.0.2.0";
@@ -1669,8 +1656,6 @@
 
     @Test
     public void testNotifySocketDestroyed() throws Exception {
-        client = makeMdnsServiceTypeClient(/* packetWriter= */ null);
-
         final String requestedInstance = "instance1";
         final String otherInstance = "instance2";
         final String ipV4Address = "192.0.2.0";
@@ -1945,6 +1930,138 @@
                 16 /* scheduledCount */);
     }
 
+    @Test
+    public void testSendQueryWithKnownAnswers() throws Exception {
+        client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
+                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps,
+                serviceCache,
+                MdnsFeatureFlags.newBuilder().setIsQueryWithKnownAnswerEnabled(true).build());
+        // Use the real datagram builder so that the sent queries can be parsed back below.
+        doCallRealMethod().when(mockDeps).getDatagramPacketFromMdnsPacket(
+                any(), any(MdnsPacket.class), any(InetSocketAddress.class));
+
+        startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
+        InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
+
+        final ArgumentCaptor<DatagramPacket> queryCaptor =
+                ArgumentCaptor.forClass(DatagramPacket.class);
+        currentThreadExecutor.getAndClearLastScheduledRunnable().run();
+        // The query is sent twice (IPv4 and IPv6), requesting unicast responses
+        inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
+                queryCaptor.capture(), eq(socketKey), eq(false));
+        verify(mockDeps, times(1)).sendMessage(any(), any(Message.class));
+        assertNotNull(delayMessage);
+
+        final MdnsPacket queryPacket = MdnsPacket.parse(
+                new MdnsPacketReader(queryCaptor.getValue()));
+        assertTrue(hasQuestion(queryPacket, MdnsRecord.TYPE_PTR));
+
+        // Build a response for the service type; a subtype PTR is added to it before processing
+        final String serviceName = "service-instance";
+        final String ipV4Address = "192.0.2.0";
+        final String[] subtypeLabels = Stream.concat(Stream.of("_subtype", "_sub"),
+                        Arrays.stream(SERVICE_TYPE_LABELS)).toArray(String[]::new);
+        final MdnsPacket packetWithoutSubtype = createResponse(
+                serviceName, ipV4Address, 5353, SERVICE_TYPE_LABELS,
+                Collections.emptyMap() /* textAttributes */, TEST_TTL);
+        final MdnsPointerRecord originalPtr = (MdnsPointerRecord) CollectionUtils.findFirst(
+                packetWithoutSubtype.answers, r -> r instanceof MdnsPointerRecord);
+
+        // Add a subtype PTR record with the same pointer, TTL and receipt time as the original PTR
+        final ArrayList<MdnsRecord> newAnswers = new ArrayList<>(packetWithoutSubtype.answers);
+        newAnswers.add(new MdnsPointerRecord(subtypeLabels, originalPtr.getReceiptTime(),
+                originalPtr.getCacheFlush(), originalPtr.getTtl(), originalPtr.getPointer()));
+        final MdnsPacket packetWithSubtype = new MdnsPacket(
+                packetWithoutSubtype.flags,
+                packetWithoutSubtype.questions,
+                newAnswers,
+                packetWithoutSubtype.authorityRecords,
+                packetWithoutSubtype.additionalRecords);
+        processResponse(packetWithSubtype, socketKey);
+
+        // Expect a query whose known answers hold the service PTR but not the subtype PTR
+        dispatchMessage();
+        final ArgumentCaptor<DatagramPacket> knownAnswersQueryCaptor =
+                ArgumentCaptor.forClass(DatagramPacket.class);
+        currentThreadExecutor.getAndClearLastScheduledRunnable().run();
+        inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
+                knownAnswersQueryCaptor.capture(), eq(socketKey), eq(false));
+
+        final MdnsPacket knownAnswersQueryPacket = MdnsPacket.parse(
+                new MdnsPacketReader(knownAnswersQueryCaptor.getValue()));
+        assertTrue(hasQuestion(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
+        assertTrue(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
+        assertFalse(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, subtypeLabels));
+    }
+
+    @Test
+    public void testSendQueryWithSubTypeWithKnownAnswers() throws Exception {
+        client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
+                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps,
+                serviceCache,
+                MdnsFeatureFlags.newBuilder().setIsQueryWithKnownAnswerEnabled(true).build());
+        // Use the real datagram builder so that the sent queries can be parsed back below.
+        doCallRealMethod().when(mockDeps).getDatagramPacketFromMdnsPacket(
+                any(), any(MdnsPacket.class), any(InetSocketAddress.class));
+
+        final MdnsSearchOptions options = MdnsSearchOptions.newBuilder()
+                .addSubtype("subtype").build();
+        startSendAndReceive(mockListenerOne, options);
+        InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
+
+        final ArgumentCaptor<DatagramPacket> queryCaptor =
+                ArgumentCaptor.forClass(DatagramPacket.class);
+        currentThreadExecutor.getAndClearLastScheduledRunnable().run();
+        // The query is sent twice (IPv4 and IPv6), requesting unicast responses
+        inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
+                queryCaptor.capture(), eq(socketKey), eq(false));
+        verify(mockDeps, times(1)).sendMessage(any(), any(Message.class));
+        assertNotNull(delayMessage);
+
+        final MdnsPacket queryPacket = MdnsPacket.parse(
+                new MdnsPacketReader(queryCaptor.getValue()));
+        final String[] subtypeLabels = Stream.concat(Stream.of("_subtype", "_sub"),
+                Arrays.stream(SERVICE_TYPE_LABELS)).toArray(String[]::new);
+        assertTrue(hasQuestion(queryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
+        assertTrue(hasQuestion(queryPacket, MdnsRecord.TYPE_PTR, subtypeLabels));
+
+        // Build a response for the service type; a subtype PTR is added to it before processing
+        final String serviceName = "service-instance";
+        final String ipV4Address = "192.0.2.0";
+        final MdnsPacket packetWithoutSubtype = createResponse(
+                serviceName, ipV4Address, 5353, SERVICE_TYPE_LABELS,
+                Collections.emptyMap() /* textAttributes */, TEST_TTL);
+        final MdnsPointerRecord originalPtr = (MdnsPointerRecord) CollectionUtils.findFirst(
+                packetWithoutSubtype.answers, r -> r instanceof MdnsPointerRecord);
+
+        // Add a subtype PTR record with the same pointer, TTL and receipt time as the original PTR
+        final ArrayList<MdnsRecord> newAnswers = new ArrayList<>(packetWithoutSubtype.answers);
+        newAnswers.add(new MdnsPointerRecord(subtypeLabels, originalPtr.getReceiptTime(),
+                originalPtr.getCacheFlush(), originalPtr.getTtl(), originalPtr.getPointer()));
+        final MdnsPacket packetWithSubtype = new MdnsPacket(
+                packetWithoutSubtype.flags,
+                packetWithoutSubtype.questions,
+                newAnswers,
+                packetWithoutSubtype.authorityRecords,
+                packetWithoutSubtype.additionalRecords);
+        processResponse(packetWithSubtype, socketKey);
+
+        // Expect a query whose known answers hold both the service PTR and the subtype PTR
+        dispatchMessage();
+        final ArgumentCaptor<DatagramPacket> knownAnswersQueryCaptor =
+                ArgumentCaptor.forClass(DatagramPacket.class);
+        currentThreadExecutor.getAndClearLastScheduledRunnable().run();
+        inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
+                knownAnswersQueryCaptor.capture(), eq(socketKey), eq(false));
+
+        final MdnsPacket knownAnswersQueryPacket = MdnsPacket.parse(
+                new MdnsPacketReader(knownAnswersQueryCaptor.getValue()));
+        assertTrue(hasQuestion(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
+        assertTrue(hasQuestion(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, subtypeLabels));
+        assertTrue(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
+        assertTrue(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, subtypeLabels));
+    }
+
     private static MdnsServiceInfo matchServiceName(String name) {
         return argThat(info -> info.getServiceInstanceName().equals(name));
     }
@@ -2005,6 +2122,12 @@
                 && (name == null || Arrays.equals(q.name, name)));
     }
 
+    /** Returns true if {@code packet} has an answer record of {@code type} named {@code name}. */
+    private static boolean hasAnswer(MdnsPacket packet, int type, @NonNull String[] name) {
+        return packet.answers.stream()
+                .anyMatch(answer -> answer.getType() == type && Arrays.equals(answer.name, name));
+    }
+
     // A fake ScheduledExecutorService that keeps tracking the last scheduled Runnable and its delay
     // time.
     private class FakeExecutor extends ScheduledThreadPoolExecutor {
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
index 6c9871c..5c4617b 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
@@ -235,7 +235,8 @@
                 context: Context,
                 tm: TelephonyManager,
                 requestRestrictedWifiEnabled: Boolean,
-                listener: BiConsumer<Int, Int>
+                listener: BiConsumer<Int, Int>,
+                handler: Handler
         ) = if (SdkLevel.isAtLeastT()) mock<CarrierPrivilegeAuthenticator>() else null
 
         var satelliteNetworkFallbackUidUpdate: Consumer<Set<Int>>? = null
diff --git a/thread/demoapp/Android.bp b/thread/demoapp/Android.bp
index 00f8090..117b4f9 100644
--- a/thread/demoapp/Android.bp
+++ b/thread/demoapp/Android.bp
@@ -37,6 +37,7 @@
     required: [
         "privapp-permissions-com.android.threadnetwork.demoapp",
     ],
+    system_ext_specific: true,
     certificate: "platform",
     privileged: true,
     platform_apis: true,
@@ -47,4 +48,5 @@
     src: "privapp-permissions-com.android.threadnetwork.demoapp.xml",
     sub_dir: "permissions",
     filename_from_src: true,
+    system_ext_specific: true,
 }
diff --git a/thread/framework/Android.bp b/thread/framework/Android.bp
index 846253c..f8fe422 100644
--- a/thread/framework/Android.bp
+++ b/thread/framework/Android.bp
@@ -30,3 +30,14 @@
         "//packages/modules/Connectivity:__subpackages__",
     ],
 }
+
+filegroup {
+    name: "framework-thread-ot-daemon-shared-aidl-sources",
+    srcs: [
+        "java/android/net/thread/ChannelMaxPower.aidl",
+    ],
+    path: "java",
+    visibility: [
+        "//external/ot-br-posix:__subpackages__",
+    ],
+}
diff --git a/thread/framework/java/android/net/thread/ChannelMaxPower.aidl b/thread/framework/java/android/net/thread/ChannelMaxPower.aidl
new file mode 100644
index 0000000..bcda8a8
--- /dev/null
+++ b/thread/framework/java/android/net/thread/ChannelMaxPower.aidl
@@ -0,0 +1,28 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.thread;
+
+/**
+ * Mapping from a channel to its max power.
+ *
+ * {@hide}
+ */
+parcelable ChannelMaxPower {
+    int channel; // The Thread radio channel.
+    int maxPower; // The max power in the unit of 0.01dBm. Passing INT16_MAX (32767) will
+                  // disable the channel.
+}
diff --git a/thread/framework/java/android/net/thread/IThreadNetworkController.aidl b/thread/framework/java/android/net/thread/IThreadNetworkController.aidl
index 485e25d..c5ca557 100644
--- a/thread/framework/java/android/net/thread/IThreadNetworkController.aidl
+++ b/thread/framework/java/android/net/thread/IThreadNetworkController.aidl
@@ -17,6 +17,7 @@
 package android.net.thread;
 
 import android.net.thread.ActiveOperationalDataset;
+import android.net.thread.ChannelMaxPower;
 import android.net.thread.IActiveOperationalDatasetReceiver;
 import android.net.thread.IOperationalDatasetCallback;
 import android.net.thread.IOperationReceiver;
@@ -39,6 +40,7 @@
     void leave(in IOperationReceiver receiver);
 
     void setTestNetworkAsUpstream(in String testNetworkInterfaceName, in IOperationReceiver receiver);
+    void setChannelMaxPowers(in ChannelMaxPower[] channelMaxPowers, in IOperationReceiver receiver);
 
     int getThreadVersion();
     void createRandomizedDataset(String networkName, IActiveOperationalDatasetReceiver receiver);
diff --git a/thread/framework/java/android/net/thread/ThreadNetworkController.java b/thread/framework/java/android/net/thread/ThreadNetworkController.java
index db761a3..8d6b40a 100644
--- a/thread/framework/java/android/net/thread/ThreadNetworkController.java
+++ b/thread/framework/java/android/net/thread/ThreadNetworkController.java
@@ -25,10 +25,12 @@
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.annotation.RequiresPermission;
+import android.annotation.Size;
 import android.annotation.SystemApi;
 import android.os.Binder;
 import android.os.OutcomeReceiver;
 import android.os.RemoteException;
+import android.util.SparseIntArray;
 
 import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
@@ -98,6 +100,12 @@
     /** Thread standard version 1.3. */
     public static final int THREAD_VERSION_1_3 = 4;
 
+    /** Minimum value of max power in unit of 0.01dBm. @hide */
+    private static final int POWER_LIMITATION_MIN = -32768;
+
+    /** Maximum value of max power in unit of 0.01dBm. @hide */
+    private static final int POWER_LIMITATION_MAX = 32767;
+
     /** @hide */
     @Retention(RetentionPolicy.SOURCE)
     @IntDef({THREAD_VERSION_1_3})
@@ -596,6 +604,98 @@
         }
     }
 
+    /**
+     * Sets max power of each channel.
+     *
+     * <p>If not set, the default max power is set by the Thread HAL service or the Thread radio
+     * chip firmware.
+     *
+     * <p>On success, {@link OutcomeReceiver#onResult} of {@code receiver} will be called. On
+     * failure, {@link OutcomeReceiver#onError} of {@code receiver} will be called with a specific
+     * error:
+     *
+     * <ul>
+     *   <li>{@link ThreadNetworkException#ERROR_UNSUPPORTED_OPERATION} the operation is not
+     *       supported by the platform.
+     * </ul>
+     *
+     * @param channelMaxPowers SparseIntArray (key: channel, value: max power) consists of channel
+     *     and corresponding max power. Valid channel values should be between {@link
+     *     ActiveOperationalDataset#CHANNEL_MIN_24_GHZ} and {@link
+     *     ActiveOperationalDataset#CHANNEL_MAX_24_GHZ}. The unit of the max power is 0.01dBm. Max
+     *     power values should be between INT16_MIN (-32768) and INT16_MAX (32767). If the max power
+     *     is set to INT16_MAX, the corresponding channel is not supported.
+     * @param executor the executor to execute {@code receiver}.
+     * @param receiver the receiver to receive the result of this operation.
+     * @throws IllegalArgumentException if the size of {@code channelMaxPowers} is smaller than 1,
+     *     or invalid channel or max power is configured.
+     * @hide
+     */
+    @RequiresPermission("android.permission.THREAD_NETWORK_PRIVILEGED")
+    public final void setChannelMaxPowers(
+            @NonNull @Size(min = 1) SparseIntArray channelMaxPowers,
+            @NonNull @CallbackExecutor Executor executor,
+            @NonNull OutcomeReceiver<Void, ThreadNetworkException> receiver) {
+        requireNonNull(channelMaxPowers, "channelMaxPowers cannot be null");
+        requireNonNull(executor, "executor cannot be null");
+        requireNonNull(receiver, "receiver cannot be null");
+
+        if (channelMaxPowers.size() < 1) {
+            throw new IllegalArgumentException("channelMaxPowers cannot be empty");
+        }
+        // Validate every entry before anything is sent to the controller service.
+        for (int i = 0; i < channelMaxPowers.size(); i++) {
+            int channel = channelMaxPowers.keyAt(i);
+            int maxPower = channelMaxPowers.valueAt(i);
+
+            if ((channel < ActiveOperationalDataset.CHANNEL_MIN_24_GHZ)
+                    || (channel > ActiveOperationalDataset.CHANNEL_MAX_24_GHZ)) {
+                throw new IllegalArgumentException(
+                        "Channel "
+                                + channel
+                                + " exceeds allowed range ["
+                                + ActiveOperationalDataset.CHANNEL_MIN_24_GHZ
+                                + ", "
+                                + ActiveOperationalDataset.CHANNEL_MAX_24_GHZ
+                                + "]");
+            }
+
+            if ((maxPower < POWER_LIMITATION_MIN) || (maxPower > POWER_LIMITATION_MAX)) {
+                throw new IllegalArgumentException(
+                        "Channel power ({channel: "
+                                + channel
+                                + ", maxPower: "
+                                + maxPower
+                                + "}) exceeds allowed range ["
+                                + POWER_LIMITATION_MIN
+                                + ", "
+                                + POWER_LIMITATION_MAX
+                                + "]");
+            }
+        }
+
+        try {
+            mControllerService.setChannelMaxPowers(
+                    toChannelMaxPowerArray(channelMaxPowers),
+                    new OperationReceiverProxy(executor, receiver));
+        } catch (RemoteException e) {
+            throw e.rethrowFromSystemServer();
+        }
+    }
+
+    /** Converts each (channel, max power) entry of {@code channelMaxPowers} into a parcelable. */
+    private static ChannelMaxPower[] toChannelMaxPowerArray(
+            @NonNull SparseIntArray channelMaxPowers) {
+        final ChannelMaxPower[] powerArray = new ChannelMaxPower[channelMaxPowers.size()];
+        for (int i = 0; i < powerArray.length; i++) {
+            final ChannelMaxPower entry = new ChannelMaxPower();
+            entry.channel = channelMaxPowers.keyAt(i);
+            entry.maxPower = channelMaxPowers.valueAt(i); // Value paired with keyAt(i).
+            powerArray[i] = entry;
+        }
+        return powerArray;
+    }
+
     private static <T> void propagateError(
             Executor executor,
             OutcomeReceiver<T, ThreadNetworkException> receiver,
diff --git a/thread/framework/java/android/net/thread/ThreadNetworkException.java b/thread/framework/java/android/net/thread/ThreadNetworkException.java
index 4def0fb..f699c30 100644
--- a/thread/framework/java/android/net/thread/ThreadNetworkException.java
+++ b/thread/framework/java/android/net/thread/ThreadNetworkException.java
@@ -138,8 +138,17 @@
      */
     public static final int ERROR_THREAD_DISABLED = 12;
 
+    /**
+     * The operation failed because it is not supported by the platform. For example, some platforms
+     * may not support setting the target power of each channel. The caller should not retry and may
+     * return an error to the user.
+     *
+     * @hide
+     */
+    public static final int ERROR_UNSUPPORTED_OPERATION = 13;
+
     private static final int ERROR_MIN = ERROR_INTERNAL_ERROR;
-    private static final int ERROR_MAX = ERROR_THREAD_DISABLED;
+    private static final int ERROR_MAX = ERROR_UNSUPPORTED_OPERATION;
 
     private final int mErrorCode;
 
diff --git a/thread/service/java/com/android/server/thread/ActiveOperationalDatasetReceiverWrapper.java b/thread/service/java/com/android/server/thread/ActiveOperationalDatasetReceiverWrapper.java
index e3b4e1a..43ff336 100644
--- a/thread/service/java/com/android/server/thread/ActiveOperationalDatasetReceiverWrapper.java
+++ b/thread/service/java/com/android/server/thread/ActiveOperationalDatasetReceiverWrapper.java
@@ -16,10 +16,12 @@
 
 package com.android.server.thread;
 
+import static android.net.thread.ThreadNetworkException.ERROR_INTERNAL_ERROR;
 import static android.net.thread.ThreadNetworkException.ERROR_UNAVAILABLE;
 
 import android.net.thread.ActiveOperationalDataset;
 import android.net.thread.IActiveOperationalDatasetReceiver;
+import android.net.thread.ThreadNetworkException;
 import android.os.RemoteException;
 
 import com.android.internal.annotations.GuardedBy;
@@ -73,6 +75,17 @@
         }
     }
 
+    /** Forwards {@code e} to {@link #onError(int, String)}; unexpected types are fatal. */
+    public void onError(Throwable e) {
+        if (e instanceof ThreadNetworkException) {
+            final ThreadNetworkException cause = (ThreadNetworkException) e;
+            onError(cause.getErrorCode(), cause.getMessage());
+            return;
+        }
+        if (!(e instanceof RemoteException)) throw new AssertionError(e);
+        onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+    }
+
     public void onError(int errorCode, String errorMessage) {
         synchronized (sPendingReceiversLock) {
             sPendingReceivers.remove(this);
diff --git a/thread/service/java/com/android/server/thread/NsdPublisher.java b/thread/service/java/com/android/server/thread/NsdPublisher.java
index 3c7a72b..72e3980 100644
--- a/thread/service/java/com/android/server/thread/NsdPublisher.java
+++ b/thread/service/java/com/android/server/thread/NsdPublisher.java
@@ -21,6 +21,7 @@
 import android.annotation.NonNull;
 import android.content.Context;
 import android.net.InetAddresses;
+import android.net.nsd.DiscoveryRequest;
 import android.net.nsd.NsdManager;
 import android.net.nsd.NsdServiceInfo;
 import android.os.Handler;
@@ -31,15 +32,20 @@
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.server.thread.openthread.DnsTxtAttribute;
+import com.android.server.thread.openthread.INsdDiscoverServiceCallback;
 import com.android.server.thread.openthread.INsdPublisher;
+import com.android.server.thread.openthread.INsdResolveServiceCallback;
 import com.android.server.thread.openthread.INsdStatusReceiver;
 
+import java.net.Inet6Address;
 import java.net.InetAddress;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Deque;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.Executor;
 
 /**
@@ -66,6 +72,8 @@
     private final Handler mHandler;
     private final Executor mExecutor;
     private final SparseArray<RegistrationListener> mRegistrationListeners = new SparseArray<>(0);
+    private final SparseArray<DiscoveryListener> mDiscoveryListeners = new SparseArray<>(0);
+    private final SparseArray<ServiceInfoListener> mServiceInfoListeners = new SparseArray<>(0);
     private final Deque<Runnable> mRegistrationJobs = new ArrayDeque<>();
 
     @VisibleForTesting
@@ -197,6 +205,110 @@
         mNsdManager.unregisterService(registrationListener);
     }
 
+    /** Posts a request to start discovering services of {@code type} to the handler. */
+    @Override
+    public void discoverService(String type, INsdDiscoverServiceCallback callback, int listenerId) {
+        mHandler.post(() -> discoverServiceInternal(type, callback, listenerId));
+    }
+
+    /**
+     * Records a new {@code DiscoveryListener} under {@code listenerId} and starts discovering
+     * services of {@code type} with it; the request's network is explicitly set to null.
+     */
+    private void discoverServiceInternal(
+            String type, INsdDiscoverServiceCallback callback, int listenerId) {
+        checkOnHandlerThread();
+        Log.i(TAG, "Discovering services. Listener ID: " + listenerId + ", service type: " + type);
+
+        final DiscoveryListener discoveryListener =
+                new DiscoveryListener(listenerId, type, callback);
+        mDiscoveryListeners.append(listenerId, discoveryListener);
+        final DiscoveryRequest request =
+                new DiscoveryRequest.Builder(type).setNetwork(null).build();
+        mNsdManager.discoverServices(request, mExecutor, discoveryListener);
+    }
+
+    /** Posts a request to stop the discovery recorded under {@code listenerId}. */
+    @Override
+    public void stopServiceDiscovery(int listenerId) {
+        mHandler.post(() -> stopServiceDiscoveryInternal(listenerId));
+    }
+
+    /**
+     * Stops the discovery whose listener is recorded under {@code listenerId}; logs a warning
+     * and does nothing else when no such listener is recorded.
+     */
+    private void stopServiceDiscoveryInternal(int listenerId) {
+        checkOnHandlerThread();
+        final DiscoveryListener discoveryListener = mDiscoveryListeners.get(listenerId);
+        if (discoveryListener != null) {
+            Log.i(TAG, "Stopping service discovery. Listener: " + discoveryListener);
+            mNsdManager.stopServiceDiscovery(discoveryListener);
+            return;
+        }
+        Log.w(TAG, "Failed to stop service discovery. Listener ID " + listenerId
+                + ". The listener is null.");
+    }
+
+    /** Posts a request to resolve the service {@code name} of {@code type} to the handler. */
+    @Override
+    public void resolveService(
+            String name, String type, INsdResolveServiceCallback callback, int listenerId) {
+        mHandler.post(() -> resolveServiceInternal(name, type, callback, listenerId));
+    }
+
+    /**
+     * Records a new {@code ServiceInfoListener} under {@code listenerId} and registers it for
+     * the service with the given name and type; the service's network is explicitly set to
+     * null.
+     */
+    private void resolveServiceInternal(
+            String name, String type, INsdResolveServiceCallback callback, int listenerId) {
+        checkOnHandlerThread();
+
+        final NsdServiceInfo target = new NsdServiceInfo();
+        target.setServiceName(name);
+        target.setServiceType(type);
+        target.setNetwork(null);
+        Log.i(TAG, "Resolving service. Listener ID: " + listenerId
+                + ", service name: " + name + ", service type: " + type);
+
+        final ServiceInfoListener infoListener =
+                new ServiceInfoListener(target, listenerId, callback);
+        mServiceInfoListeners.append(listenerId, infoListener);
+        mNsdManager.registerServiceInfoCallback(target, mExecutor, infoListener);
+    }
+
+    /** Posts a request to stop the resolution recorded under {@code listenerId}. */
+    @Override
+    public void stopServiceResolution(int listenerId) {
+        mHandler.post(() -> stopServiceResolutionInternal(listenerId));
+    }
+
+    /**
+     * Unregisters the {@code ServiceInfoListener} recorded under {@code listenerId}; logs a
+     * warning and does nothing else when no such listener is recorded.
+     */
+    private void stopServiceResolutionInternal(int listenerId) {
+        checkOnHandlerThread();
+
+        final ServiceInfoListener infoListener = mServiceInfoListeners.get(listenerId);
+        if (infoListener == null) {
+            Log.w(TAG, "Failed to stop service resolution. Listener ID: " + listenerId
+                    + ". The listener is null.");
+            return;
+        }
+
+        Log.i(TAG, "Stopping service resolution. Listener: " + infoListener);
+        try {
+            mNsdManager.unregisterServiceInfoCallback(infoListener);
+        } catch (IllegalArgumentException e) {
+            // Not rethrown: the failure is treated as "already stopped" and only logged.
+            Log.w(TAG, "Failed to stop the service resolution because it's already stopped."
+                    + " Listener: " + infoListener);
+        }
+    }
+
     private void checkOnHandlerThread() {
         if (mHandler.getLooper().getThread() != Thread.currentThread()) {
             throw new IllegalStateException(
@@ -368,4 +480,166 @@
             popAndRunNext();
         }
     }
+
+    private final class DiscoveryListener implements NsdManager.DiscoveryListener {
+        private final int mListenerId;
+        private final String mType;
+        private final INsdDiscoverServiceCallback mDiscoverServiceCallback;
+
+        DiscoveryListener(
+                int listenerId,
+                @NonNull String type,
+                @NonNull INsdDiscoverServiceCallback discoverServiceCallback) {
+            mListenerId = listenerId;
+            mType = type;
+            mDiscoverServiceCallback = discoverServiceCallback;
+        }
+
+        @Override
+        public void onStartDiscoveryFailed(String serviceType, int errorCode) {
+            Log.e(
+                    TAG,
+                    "Failed to start service discovery."
+                            + " Error code: "
+                            + errorCode
+                            + ", listener: "
+                            + this);
+            mDiscoveryListeners.remove(mListenerId);
+        }
+
+        @Override
+        public void onStopDiscoveryFailed(String serviceType, int errorCode) {
+            Log.e(
+                    TAG,
+                    "Failed to stop service discovery."
+                            + " Error code: "
+                            + errorCode
+                            + ", listener: "
+                            + this);
+            mDiscoveryListeners.remove(mListenerId);
+        }
+
+        @Override
+        public void onDiscoveryStarted(String serviceType) {
+            Log.i(TAG, "Started service discovery. Listener: " + this);
+        }
+
+        @Override
+        public void onDiscoveryStopped(String serviceType) {
+            Log.i(TAG, "Stopped service discovery. Listener: " + this);
+            mDiscoveryListeners.remove(mListenerId);
+        }
+
+        @Override
+        public void onServiceFound(NsdServiceInfo serviceInfo) {
+            Log.i(TAG, "Found service: " + serviceInfo);
+            try {
+                mDiscoverServiceCallback.onServiceDiscovered(
+                        serviceInfo.getServiceName(), mType, true);
+            } catch (RemoteException e) {
+                // do nothing if the client is dead
+            }
+        }
+
+        @Override
+        public void onServiceLost(NsdServiceInfo serviceInfo) {
+            Log.i(TAG, "Lost service: " + serviceInfo);
+            try {
+                mDiscoverServiceCallback.onServiceDiscovered(
+                        serviceInfo.getServiceName(), mType, false);
+            } catch (RemoteException e) {
+                // do nothing if the client is dead
+            }
+        }
+
+        @Override
+        public String toString() {
+            return "ID: " + mListenerId + ", type: " + mType;
+        }
+    }
+
+    private final class ServiceInfoListener implements NsdManager.ServiceInfoCallback {
+        private final String mName;
+        private final String mType;
+        private final INsdResolveServiceCallback mResolveServiceCallback;
+        private final int mListenerId;
+
+        ServiceInfoListener(
+                @NonNull NsdServiceInfo serviceInfo,
+                int listenerId,
+                @NonNull INsdResolveServiceCallback resolveServiceCallback) {
+            mName = serviceInfo.getServiceName();
+            mType = serviceInfo.getServiceType();
+            mListenerId = listenerId;
+            mResolveServiceCallback = resolveServiceCallback;
+        }
+
+        @Override
+        public void onServiceInfoCallbackRegistrationFailed(int errorCode) {
+            Log.e(
+                    TAG,
+                    "Failed to register service info callback. Listener ID: "
+                            + mListenerId
+                            + ", error: "
+                            + errorCode
+                            + ", service name: "
+                            + mName
+                            + ", service type: "
+                            + mType);
+            mServiceInfoListeners.remove(mListenerId); // avoid leaking the failed listener
+        }
+
+        @Override
+        public void onServiceUpdated(@NonNull NsdServiceInfo serviceInfo) {
+            Log.i(
+                    TAG,
+                    "Service is resolved. "
+                            + " Listener ID: "
+                            + mListenerId
+                            + ", serviceInfo: "
+                            + serviceInfo);
+            List<String> addresses = new ArrayList<>();
+            for (InetAddress address : serviceInfo.getHostAddresses()) {
+                if (address instanceof Inet6Address) {
+                    addresses.add(address.getHostAddress());
+                }
+            }
+            List<DnsTxtAttribute> txtList = new ArrayList<>();
+            for (Map.Entry<String, byte[]> entry : serviceInfo.getAttributes().entrySet()) {
+                DnsTxtAttribute attribute = new DnsTxtAttribute();
+                attribute.name = entry.getKey();
+                attribute.value = Arrays.copyOf(entry.getValue(), entry.getValue().length);
+                txtList.add(attribute);
+            }
+            // TODO: b/329018320 - Use the serviceInfo.getExpirationTime to derive TTL.
+            int ttlSeconds = 10;
+            try {
+                mResolveServiceCallback.onServiceResolved(
+                        serviceInfo.getHostname(),
+                        serviceInfo.getServiceName(),
+                        serviceInfo.getServiceType(),
+                        serviceInfo.getPort(),
+                        addresses,
+                        txtList,
+                        ttlSeconds);
+
+            } catch (RemoteException e) {
+                // do nothing if the client is dead
+            }
+        }
+
+        @Override
+        public void onServiceLost() {}
+
+        @Override
+        public void onServiceInfoCallbackUnregistered() {
+            Log.i(TAG, "The service info callback is unregistered. Listener: " + this);
+            mServiceInfoListeners.remove(mListenerId);
+        }
+
+        @Override
+        public String toString() {
+            return "ID: " + mListenerId + ", service name: " + mName + ", service type: " + mType;
+        }
+    }
 }
diff --git a/thread/service/java/com/android/server/thread/OperationReceiverWrapper.java b/thread/service/java/com/android/server/thread/OperationReceiverWrapper.java
index a8909bc..bad63f3 100644
--- a/thread/service/java/com/android/server/thread/OperationReceiverWrapper.java
+++ b/thread/service/java/com/android/server/thread/OperationReceiverWrapper.java
@@ -16,9 +16,11 @@
 
 package com.android.server.thread;
 
+import static android.net.thread.ThreadNetworkException.ERROR_INTERNAL_ERROR;
 import static android.net.thread.ThreadNetworkException.ERROR_UNAVAILABLE;
 
 import android.net.thread.IOperationReceiver;
+import android.net.thread.ThreadNetworkException;
 import android.os.RemoteException;
 
 import com.android.internal.annotations.GuardedBy;
@@ -29,6 +31,7 @@
 /** A {@link IOperationReceiver} wrapper which makes it easier to invoke the callbacks. */
 final class OperationReceiverWrapper {
     private final IOperationReceiver mReceiver;
+    private final boolean mExpectOtDaemonDied;
 
     private static final Object sPendingReceiversLock = new Object();
 
@@ -36,7 +39,19 @@
     private static final Set<OperationReceiverWrapper> sPendingReceivers = new HashSet<>();
 
     public OperationReceiverWrapper(IOperationReceiver receiver) {
-        this.mReceiver = receiver;
+        this(receiver, false /* expectOtDaemonDied */);
+    }
+
+    /**
+     * Creates a new {@link OperationReceiverWrapper}.
+     *
+     * <p>If {@code expectOtDaemonDied} is {@code true}, it's expected that ot-daemon becomes dead
+     * before {@code receiver} is completed with {@code onSuccess} or {@code onError}, and {@code
+     * receiver#onSuccess} will be invoked in this case.
+     */
+    public OperationReceiverWrapper(IOperationReceiver receiver, boolean expectOtDaemonDied) {
+        mReceiver = receiver;
+        mExpectOtDaemonDied = expectOtDaemonDied;
 
         synchronized (sPendingReceiversLock) {
             sPendingReceivers.add(this);
@@ -47,7 +62,11 @@
         synchronized (sPendingReceiversLock) {
             for (OperationReceiverWrapper receiver : sPendingReceivers) {
                 try {
-                    receiver.mReceiver.onError(ERROR_UNAVAILABLE, "Thread daemon died");
+                    if (receiver.mExpectOtDaemonDied) {
+                        receiver.mReceiver.onSuccess();
+                    } else {
+                        receiver.mReceiver.onError(ERROR_UNAVAILABLE, "Thread daemon died");
+                    }
                 } catch (RemoteException e) {
                     // The client is dead, do nothing
                 }
@@ -68,6 +87,17 @@
         }
     }
 
+    public void onError(Throwable e) {
+        if (e instanceof ThreadNetworkException) {
+            ThreadNetworkException threadException = (ThreadNetworkException) e;
+            onError(threadException.getErrorCode(), threadException.getMessage());
+        } else if (e instanceof RemoteException) {
+            onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+        } else {
+            throw new AssertionError(e);
+        }
+    }
+
     public void onError(int errorCode, String errorMessage, Object... messageArgs) {
         synchronized (sPendingReceiversLock) {
             sPendingReceivers.remove(this);
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
index 1235c30..155296d 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
@@ -40,6 +40,7 @@
 import static android.net.thread.ThreadNetworkException.ERROR_THREAD_DISABLED;
 import static android.net.thread.ThreadNetworkException.ERROR_TIMEOUT;
 import static android.net.thread.ThreadNetworkException.ERROR_UNSUPPORTED_CHANNEL;
+import static android.net.thread.ThreadNetworkException.ERROR_UNSUPPORTED_OPERATION;
 import static android.net.thread.ThreadNetworkManager.DISALLOW_THREAD_NETWORK;
 import static android.net.thread.ThreadNetworkManager.PERMISSION_THREAD_NETWORK_PRIVILEGED;
 
@@ -47,6 +48,7 @@
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_BUSY;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_FAILED_PRECONDITION;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_INVALID_STATE;
+import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_NOT_IMPLEMENTED;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_NO_BUFS;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_PARSE;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_REASSEMBLY_TIMEOUT;
@@ -86,6 +88,7 @@
 import android.net.TestNetworkSpecifier;
 import android.net.thread.ActiveOperationalDataset;
 import android.net.thread.ActiveOperationalDataset.SecurityPolicy;
+import android.net.thread.ChannelMaxPower;
 import android.net.thread.IActiveOperationalDatasetReceiver;
 import android.net.thread.IOperationReceiver;
 import android.net.thread.IOperationalDatasetCallback;
@@ -95,6 +98,7 @@
 import android.net.thread.PendingOperationalDataset;
 import android.net.thread.ThreadNetworkController;
 import android.net.thread.ThreadNetworkController.DeviceRole;
+import android.net.thread.ThreadNetworkException;
 import android.net.thread.ThreadNetworkException.ErrorCode;
 import android.os.Build;
 import android.os.Handler;
@@ -179,6 +183,8 @@
     private final OtDaemonCallbackProxy mOtDaemonCallbackProxy = new OtDaemonCallbackProxy();
     private final ConnectivityResources mResources;
 
+    // This should not be directly used for calling IOtDaemon APIs because ot-daemon may die and
+    // {@code mOtDaemon} will be set to {@code null}. Instead, use {@code getOtDaemon()}.
     @Nullable private IOtDaemon mOtDaemon;
     @Nullable private NetworkAgent mNetworkAgent;
     @Nullable private NetworkAgent mTestNetworkAgent;
@@ -193,6 +199,7 @@
     private final ThreadPersistentSettings mPersistentSettings;
     private final UserManager mUserManager;
     private boolean mUserRestricted;
+    private boolean mForceStopOtDaemonEnabled;
 
     private BorderRouterConfigurationParcel mBorderRouterConfig;
 
@@ -301,17 +308,31 @@
                 .build();
     }
 
-    private void initializeOtDaemon() {
+    private void maybeInitializeOtDaemon() {
+        if (!isEnabled()) {
+            return;
+        }
+
+        Log.i(TAG, "Starting OT daemon...");
+
         try {
             getOtDaemon();
         } catch (RemoteException e) {
-            Log.e(TAG, "Failed to initialize ot-daemon");
+            Log.e(TAG, "Failed to initialize ot-daemon", e);
+        } catch (ThreadNetworkException e) {
+            // no ThreadNetworkException.ERROR_THREAD_DISABLED error should be thrown
+            throw new AssertionError(e);
         }
     }
 
-    private IOtDaemon getOtDaemon() throws RemoteException {
+    private IOtDaemon getOtDaemon() throws RemoteException, ThreadNetworkException {
         checkOnHandlerThread();
 
+        if (mForceStopOtDaemonEnabled) {
+            throw new ThreadNetworkException(
+                    ERROR_THREAD_DISABLED, "ot-daemon is forcibly stopped");
+        }
+
         if (mOtDaemon != null) {
             return mOtDaemon;
         }
@@ -325,8 +346,8 @@
                 mTunIfController.getTunFd(),
                 isEnabled(),
                 mNsdPublisher,
-                getMeshcopTxtAttributes(mResources.get()));
-        otDaemon.registerStateCallback(mOtDaemonCallbackProxy, -1);
+                getMeshcopTxtAttributes(mResources.get()),
+                mOtDaemonCallbackProxy);
         otDaemon.asBinder().linkToDeath(() -> mHandler.post(this::onOtDaemonDied), 0);
         mOtDaemon = otDaemon;
         return mOtDaemon;
@@ -371,14 +392,14 @@
 
     private void onOtDaemonDied() {
         checkOnHandlerThread();
-        Log.w(TAG, "OT daemon is dead, clean up and restart it...");
+        Log.w(TAG, "OT daemon is dead, clean up...");
 
         OperationReceiverWrapper.onOtDaemonDied();
         mOtDaemonCallbackProxy.onOtDaemonDied();
         mTunIfController.onOtDaemonDied();
         mNsdPublisher.onOtDaemonDied();
         mOtDaemon = null;
-        initializeOtDaemon();
+        maybeInitializeOtDaemon();
     }
 
     public void initialize() {
@@ -396,10 +417,59 @@
                     requestThreadNetwork();
                     mUserRestricted = isThreadUserRestricted();
                     registerUserRestrictionsReceiver();
-                    initializeOtDaemon();
+                    maybeInitializeOtDaemon();
                 });
     }
 
+    /**
+     * Force stops ot-daemon immediately and prevents ot-daemon from being restarted by
+     * system_server again.
+     *
+     * <p>This is for VTS testing only.
+     */
+    @RequiresPermission(PERMISSION_THREAD_NETWORK_PRIVILEGED)
+    void forceStopOtDaemonForTest(boolean enabled, @NonNull IOperationReceiver receiver) {
+        enforceAllPermissionsGranted(PERMISSION_THREAD_NETWORK_PRIVILEGED);
+
+        mHandler.post(
+                () ->
+                        forceStopOtDaemonForTestInternal(
+                                enabled,
+                                new OperationReceiverWrapper(
+                                        receiver, true /* expectOtDaemonDied */)));
+    }
+
+    private void forceStopOtDaemonForTestInternal(
+            boolean enabled, @NonNull OperationReceiverWrapper receiver) {
+        checkOnHandlerThread();
+        if (enabled == mForceStopOtDaemonEnabled) {
+            receiver.onSuccess();
+            return;
+        }
+
+        if (!enabled) {
+            mForceStopOtDaemonEnabled = false;
+            maybeInitializeOtDaemon();
+            receiver.onSuccess();
+            return;
+        }
+
+        try {
+            getOtDaemon().terminate();
+            // Do not invoke the {@code receiver} callback here but wait for ot-daemon to
+            // become dead, so that it's guaranteed that ot-daemon is stopped when {@code
+            // receiver} is completed
+        } catch (RemoteException e) {
+            Log.e(TAG, "otDaemon.terminate failed", e);
+            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+        } catch (ThreadNetworkException e) {
+            // No ThreadNetworkException.ERROR_THREAD_DISABLED error will be thrown
+            throw new AssertionError(e);
+        } finally {
+            mForceStopOtDaemonEnabled = true;
+        }
+    }
+
     public void setEnabled(boolean isEnabled, @NonNull IOperationReceiver receiver) {
         enforceAllPermissionsGranted(PERMISSION_THREAD_NETWORK_PRIVILEGED);
 
@@ -429,9 +499,9 @@
 
         try {
             getOtDaemon().setThreadEnabled(isEnabled, newOtStatusReceiver(receiver));
-        } catch (RemoteException e) {
+        } catch (RemoteException | ThreadNetworkException e) {
             Log.e(TAG, "otDaemon.setThreadEnabled failed", e);
-            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+            receiver.onError(e);
         }
     }
 
@@ -488,7 +558,9 @@
 
     /** Returns {@code true} if Thread is set enabled. */
     private boolean isEnabled() {
-        return !mUserRestricted && mPersistentSettings.get(ThreadPersistentSettings.THREAD_ENABLED);
+        return !mForceStopOtDaemonEnabled
+                && !mUserRestricted
+                && mPersistentSettings.get(ThreadPersistentSettings.THREAD_ENABLED);
     }
 
     /** Returns {@code true} if Thread has been restricted for the user. */
@@ -676,9 +748,9 @@
 
         try {
             getOtDaemon().getChannelMasks(newChannelMasksReceiver(networkName, receiver));
-        } catch (RemoteException e) {
+        } catch (RemoteException | ThreadNetworkException e) {
             Log.e(TAG, "otDaemon.getChannelMasks failed", e);
-            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+            receiver.onError(e);
         }
     }
 
@@ -850,6 +922,8 @@
                 return ERROR_ABORTED;
             case OT_ERROR_BUSY:
                 return ERROR_BUSY;
+            case OT_ERROR_NOT_IMPLEMENTED:
+                return ERROR_UNSUPPORTED_OPERATION;
             case OT_ERROR_NO_BUFS:
                 return ERROR_RESOURCE_EXHAUSTED;
             case OT_ERROR_PARSE:
@@ -888,9 +962,9 @@
         try {
             // The otDaemon.join() will leave first if this device is currently attached
             getOtDaemon().join(activeDataset.toThreadTlvs(), newOtStatusReceiver(receiver));
-        } catch (RemoteException e) {
+        } catch (RemoteException | ThreadNetworkException e) {
             Log.e(TAG, "otDaemon.join failed", e);
-            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+            receiver.onError(e);
         }
     }
 
@@ -913,9 +987,9 @@
             getOtDaemon()
                     .scheduleMigration(
                             pendingDataset.toThreadTlvs(), newOtStatusReceiver(receiver));
-        } catch (RemoteException e) {
+        } catch (RemoteException | ThreadNetworkException e) {
             Log.e(TAG, "otDaemon.scheduleMigration failed", e);
-            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+            receiver.onError(e);
         }
     }
 
@@ -931,9 +1005,9 @@
 
         try {
             getOtDaemon().leave(newOtStatusReceiver(receiver));
-        } catch (RemoteException e) {
-            // Oneway AIDL API should never throw?
-            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+        } catch (RemoteException | ThreadNetworkException e) {
+            Log.e(TAG, "otDaemon.leave failed", e);
+            receiver.onError(e);
         }
     }
 
@@ -955,11 +1029,18 @@
             String countryCode, @NonNull OperationReceiverWrapper receiver) {
         checkOnHandlerThread();
 
+        // Fails early to avoid waking up ot-daemon by the ThreadNetworkCountryCode class
+        if (!isEnabled()) {
+            receiver.onError(
+                    ERROR_THREAD_DISABLED, "Can't set country code when Thread is disabled");
+            return;
+        }
+
         try {
             getOtDaemon().setCountryCode(countryCode, newOtStatusReceiver(receiver));
-        } catch (RemoteException e) {
+        } catch (RemoteException | ThreadNetworkException e) {
             Log.e(TAG, "otDaemon.setCountryCode failed", e);
-            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+            receiver.onError(e);
         }
     }
 
@@ -995,6 +1076,30 @@
         }
     }
 
+    @RequiresPermission(PERMISSION_THREAD_NETWORK_PRIVILEGED)
+    public void setChannelMaxPowers(
+            @NonNull ChannelMaxPower[] channelMaxPowers, @NonNull IOperationReceiver receiver) {
+        enforceAllPermissionsGranted(PERMISSION_THREAD_NETWORK_PRIVILEGED);
+
+        mHandler.post(
+                () ->
+                        setChannelMaxPowersInternal(
+                                channelMaxPowers, new OperationReceiverWrapper(receiver)));
+    }
+
+    private void setChannelMaxPowersInternal(
+            @NonNull ChannelMaxPower[] channelMaxPowers,
+            @NonNull OperationReceiverWrapper receiver) {
+        checkOnHandlerThread();
+
+        try {
+            getOtDaemon().setChannelMaxPowers(channelMaxPowers, newOtStatusReceiver(receiver));
+        } catch (RemoteException | ThreadNetworkException e) {
+            Log.e(TAG, "otDaemon.setChannelMaxPowers failed", e);
+            receiver.onError(e); // keeps the ThreadNetworkException's own error code
+        }
+    }
+
     private void enableBorderRouting(String infraIfName) {
         if (mBorderRouterConfig.isBorderRoutingEnabled
                 && infraIfName.equals(mBorderRouterConfig.infraInterfaceName)) {
@@ -1007,9 +1112,10 @@
                     mInfraIfController.createIcmp6Socket(infraIfName);
             mBorderRouterConfig.isBorderRoutingEnabled = true;
 
-            mOtDaemon.configureBorderRouter(
-                    mBorderRouterConfig, new ConfigureBorderRouterStatusReceiver());
-        } catch (RemoteException | IOException e) {
+            getOtDaemon()
+                    .configureBorderRouter(
+                            mBorderRouterConfig, new ConfigureBorderRouterStatusReceiver());
+        } catch (RemoteException | IOException | ThreadNetworkException e) {
             Log.w(TAG, "Failed to enable border routing", e);
         }
     }
@@ -1020,9 +1126,10 @@
         mBorderRouterConfig.infraInterfaceIcmp6Socket = null;
         mBorderRouterConfig.isBorderRoutingEnabled = false;
         try {
-            mOtDaemon.configureBorderRouter(
-                    mBorderRouterConfig, new ConfigureBorderRouterStatusReceiver());
-        } catch (RemoteException e) {
+            getOtDaemon()
+                    .configureBorderRouter(
+                            mBorderRouterConfig, new ConfigureBorderRouterStatusReceiver());
+        } catch (RemoteException | ThreadNetworkException e) {
             Log.w(TAG, "Failed to disable border routing", e);
         }
     }
@@ -1190,8 +1297,8 @@
 
             try {
                 getOtDaemon().registerStateCallback(this, callbackMetadata.id);
-            } catch (RemoteException e) {
-                // oneway operation should never fail
+            } catch (RemoteException | ThreadNetworkException e) {
+                Log.e(TAG, "otDaemon.registerStateCallback failed", e);
             }
         }
 
@@ -1231,8 +1338,8 @@
 
             try {
                 getOtDaemon().registerStateCallback(this, callbackMetadata.id);
-            } catch (RemoteException e) {
-                // oneway operation should never fail
+            } catch (RemoteException | ThreadNetworkException e) {
+                Log.e(TAG, "otDaemon.registerStateCallback failed", e);
             }
         }
 
@@ -1251,8 +1358,11 @@
                 return;
             }
 
+            final int deviceRole = mState.deviceRole;
+            mState = null;
+
             // If this device is already STOPPED or DETACHED, do nothing
-            if (!ThreadNetworkController.isAttached(mState.deviceRole)) {
+            if (!ThreadNetworkController.isAttached(deviceRole)) {
                 return;
             }
 
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkService.java b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
index 5664922..37c1cf1 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
@@ -18,6 +18,8 @@
 
 import static android.content.pm.PackageManager.PERMISSION_GRANTED;
 
+import static java.util.Objects.requireNonNull;
+
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.content.Context;
@@ -66,7 +68,8 @@
             // PHASE_ACTIVITY_MANAGER_READY and PHASE_THIRD_PARTY_APPS_CAN_START
             mCountryCode = ThreadNetworkCountryCode.newInstance(mContext, mControllerService);
             mCountryCode.initialize();
-            mShellCommand = new ThreadNetworkShellCommand(mCountryCode);
+            mShellCommand =
+                    new ThreadNetworkShellCommand(requireNonNull(mControllerService), mCountryCode);
         }
     }
 
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java b/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java
index c17c5a7..c6a1618 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkShellCommand.java
@@ -16,7 +16,10 @@
 
 package com.android.server.thread;
 
+import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.net.thread.IOperationReceiver;
+import android.net.thread.ThreadNetworkException;
 import android.os.Binder;
 import android.os.Process;
 import android.text.TextUtils;
@@ -25,7 +28,12 @@
 import com.android.modules.utils.BasicShellCommandHandler;
 
 import java.io.PrintWriter;
+import java.time.Duration;
 import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 
 /**
  * Interprets and executes 'adb shell cmd thread_network [args]'.
@@ -37,16 +45,22 @@
  * corresponding API permissions.
  */
 public class ThreadNetworkShellCommand extends BasicShellCommandHandler {
-    private static final String TAG = "ThreadNetworkShellCommand";
+    private static final Duration SET_ENABLED_TIMEOUT = Duration.ofSeconds(2);
+    private static final Duration FORCE_STOP_TIMEOUT = Duration.ofSeconds(1);
 
     // These don't require root access.
-    private static final List<String> NON_PRIVILEGED_COMMANDS = List.of("help", "get-country-code");
+    private static final List<String> NON_PRIVILEGED_COMMANDS =
+            List.of("help", "get-country-code", "enable", "disable");
 
-    @Nullable private final ThreadNetworkCountryCode mCountryCode;
+    @NonNull private final ThreadNetworkControllerService mControllerService;
+    @NonNull private final ThreadNetworkCountryCode mCountryCode;
     @Nullable private PrintWriter mOutputWriter;
     @Nullable private PrintWriter mErrorWriter;
 
-    ThreadNetworkShellCommand(@Nullable ThreadNetworkCountryCode countryCode) {
+    ThreadNetworkShellCommand(
+            @NonNull ThreadNetworkControllerService controllerService,
+            @NonNull ThreadNetworkCountryCode countryCode) {
+        mControllerService = controllerService;
         mCountryCode = countryCode;
     }
 
@@ -91,14 +105,14 @@
         }
 
         switch (cmd) {
+            case "enable":
+                return setThreadEnabled(true);
+            case "disable":
+                return setThreadEnabled(false);
+            case "force-stop-ot-daemon":
+                return forceStopOtDaemon();
             case "force-country-code":
                 boolean enabled;
-
-                if (mCountryCode == null) {
-                    perr.println("Thread country code operations are not supported");
-                    return -1;
-                }
-
                 try {
                     enabled = getNextArgRequiredTrueOrFalse("enabled", "disabled");
                 } catch (IllegalArgumentException e) {
@@ -124,11 +138,6 @@
                 }
                 return 0;
             case "get-country-code":
-                if (mCountryCode == null) {
-                    perr.println("Thread country code operations are not supported");
-                    return -1;
-                }
-
                 pw.println("Thread country code = " + mCountryCode.getCountryCode());
                 return 0;
             default:
@@ -136,6 +145,64 @@
         }
     }
 
+    private int setThreadEnabled(boolean enabled) { // returns 0 on success, -1 on failure
+        CompletableFuture<Void> setEnabledFuture = new CompletableFuture<>();
+        mControllerService.setEnabled(enabled, newOperationReceiver(setEnabledFuture));
+        return waitForFuture(setEnabledFuture, SET_ENABLED_TIMEOUT, getErrorWriter());
+    }
+
+    private int forceStopOtDaemon() {
+        final PrintWriter errorWriter = getErrorWriter();
+        boolean enabled;
+        try {
+            enabled = getNextArgRequiredTrueOrFalse("enabled", "disabled");
+        } catch (IllegalArgumentException e) {
+            errorWriter.println("Invalid argument: " + e.getMessage());
+            return -1;
+        }
+
+        CompletableFuture<Void> forceStopFuture = new CompletableFuture<>();
+        mControllerService.forceStopOtDaemonForTest(enabled, newOperationReceiver(forceStopFuture));
+        return waitForFuture(forceStopFuture, FORCE_STOP_TIMEOUT, getErrorWriter());
+    }
+
+    private static IOperationReceiver newOperationReceiver(CompletableFuture<Void> future) {
+        return new IOperationReceiver.Stub() {
+            @Override
+            public void onSuccess() {
+                future.complete(null);
+            }
+
+            @Override
+            public void onError(int errorCode, String errorMessage) {
+                future.completeExceptionally(new ThreadNetworkException(errorCode, errorMessage));
+            }
+        };
+    }
+
+    /**
+     * Blocks until {@code future} completes or {@code timeout} (millisecond precision) elapses.
+     *
+     * <p>Returns 0 if {@code future} completed successfully, or -1 if it failed, timed out or the
+     * wait was interrupted. On failure, an error message is printed to {@code errorWriter}.
+     */
+    private int waitForFuture(
+            CompletableFuture<Void> future, Duration timeout, PrintWriter errorWriter) {
+        try {
+            future.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
+            return 0;
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            errorWriter.println("Failed: " + e.getMessage());
+        } catch (ExecutionException e) {
+            errorWriter.println("Failed: " + e.getCause().getMessage());
+        } catch (TimeoutException e) {
+            errorWriter.println("Failed: command timeout for " + timeout);
+        }
+
+        return -1;
+    }
+
     private static boolean argTrueOrFalse(String arg, String trueString, String falseString) {
         if (trueString.equals(arg)) {
             return true;
@@ -159,6 +226,10 @@
     }
 
     private void onHelpNonPrivileged(PrintWriter pw) {
+        pw.println("  enable");
+        pw.println("    Enables Thread radio");
+        pw.println("  disable");
+        pw.println("    Disables Thread radio");
         pw.println("  get-country-code");
         pw.println("    Gets country code as a two-letter string");
     }
@@ -166,6 +237,8 @@
     private void onHelpPrivileged(PrintWriter pw) {
         pw.println("  force-country-code enabled <two-letter code> | disabled ");
         pw.println("    Sets country code to <two-letter code> or left for normal value");
+        pw.println("  force-stop-ot-daemon enabled | disabled ");
+        pw.println("    force stop ot-daemon service");
     }
 
     @Override
diff --git a/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java b/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
index 353db10..5fe4325 100644
--- a/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
+++ b/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
@@ -164,15 +164,13 @@
          * </pre>
          */
 
-        // Let ftd join the network.
         FullThreadDevice ftd = mFtds.get(0);
         startFtdChild(ftd);
 
-        // Infra device sends an echo request to FTD's OMR.
         mInfraDevice.sendEchoRequest(ftd.getOmrAddress());
 
         // Infra device receives an echo reply sent by FTD.
-        assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, null /* srcAddress */));
+        assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
     }
 
     @Test
diff --git a/thread/tests/integration/src/android/net/thread/ServiceDiscoveryTest.java b/thread/tests/integration/src/android/net/thread/ServiceDiscoveryTest.java
index 39a1671..491331c 100644
--- a/thread/tests/integration/src/android/net/thread/ServiceDiscoveryTest.java
+++ b/thread/tests/integration/src/android/net/thread/ServiceDiscoveryTest.java
@@ -17,6 +17,7 @@
 package android.net.thread;
 
 import static android.net.InetAddresses.parseNumericAddress;
+import static android.net.nsd.NsdManager.PROTOCOL_DNS_SD;
 import static android.net.thread.utils.IntegrationTestUtils.JOIN_TIMEOUT;
 import static android.net.thread.utils.IntegrationTestUtils.SERVICE_DISCOVERY_TIMEOUT;
 import static android.net.thread.utils.IntegrationTestUtils.discoverForServiceLost;
@@ -37,6 +38,7 @@
 import android.net.nsd.NsdManager;
 import android.net.nsd.NsdServiceInfo;
 import android.net.thread.utils.FullThreadDevice;
+import android.net.thread.utils.OtDaemonController;
 import android.net.thread.utils.TapTestNetworkTracker;
 import android.net.thread.utils.ThreadFeatureCheckerRule;
 import android.net.thread.utils.ThreadFeatureCheckerRule.RequiresSimulationThreadDevice;
@@ -65,6 +67,7 @@
 import java.util.Map;
 import java.util.Random;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeoutException;
 
 /** Integration test cases for Service Discovery feature. */
@@ -96,15 +99,15 @@
     private final Context mContext = ApplicationProvider.getApplicationContext();
     private final ThreadNetworkControllerWrapper mController =
             ThreadNetworkControllerWrapper.newInstance(mContext);
-
+    private final OtDaemonController mOtCtl = new OtDaemonController();
     private HandlerThread mHandlerThread;
     private NsdManager mNsdManager;
     private TapTestNetworkTracker mTestNetworkTracker;
     private List<FullThreadDevice> mFtds;
+    private List<RegistrationListener> mRegistrationListeners = new ArrayList<>();
 
     @Before
     public void setUp() throws Exception {
-
         mController.joinAndWait(DEFAULT_DATASET);
         mNsdManager = mContext.getSystemService(NsdManager.class);
 
@@ -127,6 +130,9 @@
 
     @After
     public void tearDown() throws Exception {
+        for (RegistrationListener listener : mRegistrationListeners) {
+            unregisterService(listener);
+        }
         for (FullThreadDevice ftd : mFtds) {
             // Clear registered SRP hosts and services
             if (ftd.isSrpHostRegistered()) {
@@ -314,6 +320,176 @@
         assertThat(txtMap.get("mn")).isEqualTo("Thread Border Router".getBytes(UTF_8));
     }
 
+    @Test
+    public void discoveryProxy_multipleClientsBrowseAndResolveServiceOverMdns() throws Exception {
+        /*
+         * <pre>
+         * Topology:
+         *                    Thread
+         *  Border Router -------------- Full Thread device
+         *  (Cuttlefish)
+         * </pre>
+         */
+
+        RegistrationListener listener = new RegistrationListener();
+        NsdServiceInfo info = new NsdServiceInfo();
+        info.setServiceType("_testservice._tcp");
+        info.setServiceName("test-service");
+        info.setPort(12345);
+        info.setHostname("testhost");
+        info.setHostAddresses(List.of(parseNumericAddress("2001::1")));
+        info.setAttribute("key1", bytes(0x01, 0x02));
+        info.setAttribute("key2", bytes(0x03));
+        registerService(info, listener);
+        mRegistrationListeners.add(listener);
+        for (int i = 0; i < NUM_FTD; ++i) {
+            FullThreadDevice ftd = mFtds.get(i);
+            ftd.joinNetwork(DEFAULT_DATASET);
+            ftd.waitForStateAnyOf(List.of("router", "child"), JOIN_TIMEOUT);
+            ftd.setDnsServerAddress(mOtCtl.getMlEid().getHostAddress());
+        }
+        final ArrayList<NsdServiceInfo> browsedServices = new ArrayList<>();
+        final ArrayList<NsdServiceInfo> resolvedServices = new ArrayList<>();
+        final ArrayList<Thread> threads = new ArrayList<>();
+        for (int i = 0; i < NUM_FTD; ++i) {
+            browsedServices.add(null);
+            resolvedServices.add(null);
+        }
+        for (int i = 0; i < NUM_FTD; ++i) {
+            final FullThreadDevice ftd = mFtds.get(i);
+            final int index = i;
+            Runnable task =
+                    () -> {
+                        browsedServices.set(
+                                index,
+                                ftd.browseService("_testservice._tcp.default.service.arpa."));
+                        resolvedServices.set(
+                                index,
+                                ftd.resolveService(
+                                        "test-service", "_testservice._tcp.default.service.arpa."));
+                    };
+            threads.add(new Thread(task));
+        }
+        for (Thread thread : threads) {
+            thread.start();
+        }
+        for (Thread thread : threads) {
+            thread.join();
+        }
+
+        for (int i = 0; i < NUM_FTD; ++i) {
+            NsdServiceInfo browsedService = browsedServices.get(i);
+            assertThat(browsedService.getServiceName()).isEqualTo("test-service");
+            assertThat(browsedService.getPort()).isEqualTo(12345);
+
+            NsdServiceInfo resolvedService = resolvedServices.get(i);
+            assertThat(resolvedService.getServiceName()).isEqualTo("test-service");
+            assertThat(resolvedService.getPort()).isEqualTo(12345);
+            assertThat(resolvedService.getHostname()).isEqualTo("testhost.default.service.arpa.");
+            assertThat(resolvedService.getHostAddresses())
+                    .containsExactly(parseNumericAddress("2001::1"));
+            assertThat(resolvedService.getAttributes())
+                    .comparingValuesUsing(BYTE_ARRAY_EQUALITY)
+                    .containsExactly("key1", bytes(0x01, 0x02), "key2", bytes(3));
+        }
+    }
+
+    @Test
+    public void discoveryProxy_browseAndResolveServiceAtSrpServer() throws Exception {
+        /*
+         * <pre>
+         * Topology:
+         *                    Thread
+         *  Border Router -------+------ SRP client
+         *  (Cuttlefish)         |
+         *                       +------ DNS client
+         *
+         * </pre>
+         */
+        FullThreadDevice srpClient = mFtds.get(0);
+        srpClient.joinNetwork(DEFAULT_DATASET);
+        srpClient.waitForStateAnyOf(List.of("router", "child"), JOIN_TIMEOUT);
+        srpClient.setSrpHostname("my-host");
+        srpClient.setSrpHostAddresses(List.of((Inet6Address) parseNumericAddress("2001::1")));
+        srpClient.addSrpService(
+                "my-service",
+                "_test._udp",
+                List.of("_sub1"),
+                12345 /* port */,
+                Map.of("key1", bytes(0x01, 0x02), "key2", bytes(0x03)));
+
+        FullThreadDevice dnsClient = mFtds.get(1);
+        dnsClient.joinNetwork(DEFAULT_DATASET);
+        dnsClient.waitForStateAnyOf(List.of("router", "child"), JOIN_TIMEOUT);
+        dnsClient.setDnsServerAddress(mOtCtl.getMlEid().getHostAddress());
+
+        NsdServiceInfo browsedService = dnsClient.browseService("_test._udp.default.service.arpa.");
+        assertThat(browsedService.getServiceName()).isEqualTo("my-service");
+        assertThat(browsedService.getPort()).isEqualTo(12345);
+        assertThat(browsedService.getHostname()).isEqualTo("my-host.default.service.arpa.");
+        assertThat(browsedService.getHostAddresses())
+                .containsExactly(parseNumericAddress("2001::1"));
+        assertThat(browsedService.getAttributes())
+                .comparingValuesUsing(BYTE_ARRAY_EQUALITY)
+                .containsExactly("key1", bytes(0x01, 0x02), "key2", bytes(3));
+
+        NsdServiceInfo resolvedService =
+                dnsClient.resolveService("my-service", "_test._udp.default.service.arpa.");
+        assertThat(resolvedService.getServiceName()).isEqualTo("my-service");
+        assertThat(resolvedService.getPort()).isEqualTo(12345);
+        assertThat(resolvedService.getHostname()).isEqualTo("my-host.default.service.arpa.");
+        assertThat(resolvedService.getHostAddresses())
+                .containsExactly(parseNumericAddress("2001::1"));
+        assertThat(resolvedService.getAttributes())
+                .comparingValuesUsing(BYTE_ARRAY_EQUALITY)
+                .containsExactly("key1", bytes(0x01, 0x02), "key2", bytes(3));
+    }
+
+    private void registerService(NsdServiceInfo serviceInfo, RegistrationListener listener)
+            throws InterruptedException, ExecutionException, TimeoutException {
+        mNsdManager.registerService(serviceInfo, PROTOCOL_DNS_SD, listener);
+        listener.waitForRegistered();
+    }
+
+    private void unregisterService(RegistrationListener listener)
+            throws InterruptedException, ExecutionException, TimeoutException {
+        mNsdManager.unregisterService(listener);
+        listener.waitForUnregistered();
+    }
+
+    private static class RegistrationListener implements NsdManager.RegistrationListener {
+        private final CompletableFuture<Void> mRegisteredFuture = new CompletableFuture<>();
+        private final CompletableFuture<Void> mUnRegisteredFuture = new CompletableFuture<>();
+
+        RegistrationListener() {}
+
+        @Override
+        public void onRegistrationFailed(NsdServiceInfo serviceInfo, int errorCode) {}
+
+        @Override
+        public void onUnregistrationFailed(NsdServiceInfo serviceInfo, int errorCode) {}
+
+        @Override
+        public void onServiceRegistered(NsdServiceInfo serviceInfo) {
+            mRegisteredFuture.complete(null);
+        }
+
+        @Override
+        public void onServiceUnregistered(NsdServiceInfo serviceInfo) {
+            mUnRegisteredFuture.complete(null);
+        }
+
+        public void waitForRegistered()
+                throws InterruptedException, ExecutionException, TimeoutException {
+            mRegisteredFuture.get(SERVICE_DISCOVERY_TIMEOUT.toMillis(), MILLISECONDS);
+        }
+
+        public void waitForUnregistered()
+                throws InterruptedException, ExecutionException, TimeoutException {
+            mUnRegisteredFuture.get(SERVICE_DISCOVERY_TIMEOUT.toMillis(), MILLISECONDS);
+        }
+    }
+
     private static byte[] bytes(int... byteInts) {
         byte[] bytes = new byte[byteInts.length];
         for (int i = 0; i < byteInts.length; ++i) {
diff --git a/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java b/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java
index 4a006cf..bfded1d 100644
--- a/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java
+++ b/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java
@@ -97,13 +97,16 @@
     }
 
     @Test
-    public void otDaemonRestart_JoinedNetworkAndStopped_autoRejoined() throws Exception {
+    public void otDaemonRestart_JoinedNetworkAndStopped_autoRejoinedAndTunIfStateConsistent()
+            throws Exception {
         mController.joinAndWait(DEFAULT_DATASET);
 
         runShellCommand("stop ot-daemon");
 
         mController.waitForRole(DEVICE_ROLE_DETACHED, CALLBACK_TIMEOUT);
         mController.waitForRole(DEVICE_ROLE_LEADER, RESTART_JOIN_TIMEOUT);
+        assertThat(mOtCtl.isInterfaceUp()).isTrue();
+        assertThat(runShellCommand("ifconfig thread-wpan")).contains("UP POINTOPOINT RUNNING");
     }
 
     @Test
@@ -120,8 +123,8 @@
         mController.joinAndWait(DEFAULT_DATASET);
 
         mOtCtl.factoryReset();
-        String ifconfig = runShellCommand("ifconfig thread-wpan");
 
+        String ifconfig = runShellCommand("ifconfig thread-wpan");
         assertThat(ifconfig).doesNotContain("inet6 addr");
     }
 
diff --git a/thread/tests/integration/src/android/net/thread/ThreadNetworkControllerTest.java b/thread/tests/integration/src/android/net/thread/ThreadNetworkControllerTest.java
new file mode 100644
index 0000000..ba04348
--- /dev/null
+++ b/thread/tests/integration/src/android/net/thread/ThreadNetworkControllerTest.java
@@ -0,0 +1,171 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.thread;
+
+import static android.net.thread.ThreadNetworkException.ERROR_UNSUPPORTED_OPERATION;
+
+import static androidx.test.platform.app.InstrumentationRegistry.getInstrumentation;
+
+import static com.android.testutils.TestPermissionUtil.runAsShell;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import static org.junit.Assert.assertThrows;
+
+import android.content.Context;
+import android.net.thread.utils.ThreadFeatureCheckerRule;
+import android.net.thread.utils.ThreadFeatureCheckerRule.RequiresThreadFeature;
+import android.os.OutcomeReceiver;
+import android.util.SparseIntArray;
+
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.filters.LargeTest;
+import androidx.test.runner.AndroidJUnit4;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+/** Tests for hide methods of {@link ThreadNetworkController}. */
+@LargeTest
+@RequiresThreadFeature
+@RunWith(AndroidJUnit4.class)
+public class ThreadNetworkControllerTest {
+    private static final int VALID_POWER = 32_767;
+    private static final int INVALID_POWER = 32_768;
+    private static final int VALID_CHANNEL = 20;
+    private static final int INVALID_CHANNEL = 10;
+    private static final String THREAD_NETWORK_PRIVILEGED =
+            "android.permission.THREAD_NETWORK_PRIVILEGED";
+
+    private static final SparseIntArray CHANNEL_MAX_POWERS =
+            new SparseIntArray() {
+                {
+                    put(20, 32767);
+                }
+            };
+
+    @Rule public final ThreadFeatureCheckerRule mThreadRule = new ThreadFeatureCheckerRule();
+
+    private final Context mContext = ApplicationProvider.getApplicationContext();
+    private ExecutorService mExecutor;
+    private ThreadNetworkController mController;
+
+    @Before
+    public void setUp() throws Exception {
+        mExecutor = Executors.newSingleThreadExecutor(); // First, so tearDown() never sees null.
+        mController =
+                mContext.getSystemService(ThreadNetworkManager.class)
+                        .getAllThreadNetworkControllers()
+                        .get(0);
+    }
+
+    @After
+    public void tearDown() throws Exception {
+        dropAllPermissions();
+        mExecutor.shutdownNow(); // Release the worker thread created in setUp().
+    }
+
+    @Test
+    public void setChannelMaxPowers_withPrivilegedPermission_success() throws Exception {
+        CompletableFuture<Void> powerFuture = new CompletableFuture<>();
+
+        runAsShell(
+                THREAD_NETWORK_PRIVILEGED,
+                () ->
+                        mController.setChannelMaxPowers(
+                                CHANNEL_MAX_POWERS, mExecutor, newOutcomeReceiver(powerFuture)));
+
+        try {
+            assertThat(powerFuture.get()).isNull();
+        } catch (ExecutionException exception) {
+            ThreadNetworkException thrown = (ThreadNetworkException) exception.getCause();
+            assertThat(thrown.getErrorCode()).isEqualTo(ERROR_UNSUPPORTED_OPERATION);
+        }
+    }
+
+    @Test
+    public void setChannelMaxPowers_withoutPrivilegedPermission_throwsSecurityException()
+            throws Exception {
+        dropAllPermissions();
+
+        assertThrows(
+                SecurityException.class,
+                () -> mController.setChannelMaxPowers(CHANNEL_MAX_POWERS, mExecutor, v -> {}));
+    }
+
+    @Test
+    public void setChannelMaxPowers_emptyChannelMaxPower_throwsIllegalArgumentException() {
+        assertThrows(
+                IllegalArgumentException.class,
+                () -> mController.setChannelMaxPowers(new SparseIntArray(), mExecutor, v -> {}));
+    }
+
+    @Test
+    public void setChannelMaxPowers_invalidChannel_throwsIllegalArgumentException() {
+        // Pairs INVALID_CHANNEL with VALID_POWER so the channel is the only suspect entry.
+        // The array is a plain SparseIntArray filled via put(); double-brace initialization
+        // would create a needless anonymous subclass, and local variables use lowerCamelCase
+        // rather than constant-style names.
+        final SparseIntArray invalidChannelPowers = new SparseIntArray();
+        invalidChannelPowers.put(INVALID_CHANNEL, VALID_POWER);
+
+        assertThrows(
+                IllegalArgumentException.class,
+                () -> mController.setChannelMaxPowers(invalidChannelPowers, mExecutor, v -> {}));
+    }
+
+    @Test
+    public void setChannelMaxPowers_invalidPower_throwsIllegalArgumentException() {
+        // Pairs VALID_CHANNEL with INVALID_POWER so the power value is the only suspect entry.
+        // The array is a plain SparseIntArray filled via put(); double-brace initialization
+        // would create a needless anonymous subclass, and local variables use lowerCamelCase
+        // rather than constant-style names.
+        final SparseIntArray invalidPowers = new SparseIntArray();
+        invalidPowers.put(VALID_CHANNEL, INVALID_POWER);
+
+        assertThrows(
+                IllegalArgumentException.class,
+                () -> mController.setChannelMaxPowers(invalidPowers, mExecutor, v -> {}));
+    }
+
+    private static void dropAllPermissions() {
+        getInstrumentation().getUiAutomation().dropShellPermissionIdentity();
+    }
+
+    private static <V> OutcomeReceiver<V, ThreadNetworkException> newOutcomeReceiver(
+            CompletableFuture<V> future) {
+        return new OutcomeReceiver<V, ThreadNetworkException>() {
+            @Override
+            public void onResult(V result) {
+                future.complete(result);
+            }
+
+            @Override
+            public void onError(ThreadNetworkException e) {
+                future.completeExceptionally(e);
+            }
+        };
+    }
+}
diff --git a/thread/tests/integration/src/android/net/thread/ThreadNetworkShellCommandTest.java b/thread/tests/integration/src/android/net/thread/ThreadNetworkShellCommandTest.java
new file mode 100644
index 0000000..8835f40
--- /dev/null
+++ b/thread/tests/integration/src/android/net/thread/ThreadNetworkShellCommandTest.java
@@ -0,0 +1,129 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.thread;
+
+import static android.net.thread.ThreadNetworkController.STATE_DISABLED;
+import static android.net.thread.ThreadNetworkController.STATE_ENABLED;
+import static android.net.thread.ThreadNetworkException.ERROR_THREAD_DISABLED;
+
+import static com.android.compatibility.common.util.SystemUtil.runShellCommandOrThrow;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import static org.junit.Assert.assertThrows;
+
+import android.content.Context;
+import android.net.thread.utils.ThreadFeatureCheckerRule;
+import android.net.thread.utils.ThreadFeatureCheckerRule.RequiresThreadFeature;
+import android.net.thread.utils.ThreadNetworkControllerWrapper;
+
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.filters.LargeTest;
+import androidx.test.runner.AndroidJUnit4;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import java.util.concurrent.ExecutionException;
+
+/** Integration tests for {@link ThreadNetworkShellCommand}. */
+@LargeTest
+@RequiresThreadFeature
+@RunWith(AndroidJUnit4.class)
+public class ThreadNetworkShellCommandTest {
+    @Rule public final ThreadFeatureCheckerRule mThreadRule = new ThreadFeatureCheckerRule();
+
+    private final Context mContext = ApplicationProvider.getApplicationContext();
+    private final ThreadNetworkControllerWrapper mController =
+            ThreadNetworkControllerWrapper.newInstance(mContext);
+
+    @Before
+    public void setUp() {
+        ensureThreadEnabled();
+    }
+
+    @After
+    public void tearDown() {
+        ensureThreadEnabled();
+    }
+
+    private static void ensureThreadEnabled() {
+        runThreadCommand("force-stop-ot-daemon disabled");
+        runThreadCommand("enable");
+    }
+
+    @Test
+    public void enable_threadStateIsEnabled() throws Exception {
+        runThreadCommand("enable");
+
+        assertThat(mController.getEnabledState()).isEqualTo(STATE_ENABLED);
+    }
+
+    @Test
+    public void disable_threadStateIsDisabled() throws Exception {
+        runThreadCommand("disable");
+
+        assertThat(mController.getEnabledState()).isEqualTo(STATE_DISABLED);
+    }
+
+    @Test
+    public void forceStopOtDaemon_forceStopEnabled_otDaemonServiceDisappear() {
+        runThreadCommand("force-stop-ot-daemon enabled");
+
+        assertThat(runShellCommandOrThrow("service list")).doesNotContain("ot_daemon");
+    }
+
+    @Test
+    public void forceStopOtDaemon_forceStopEnabled_canNotEnableThread() throws Exception {
+        runThreadCommand("force-stop-ot-daemon enabled");
+
+        ExecutionException thrown =
+                assertThrows(ExecutionException.class, () -> mController.setEnabledAndWait(true));
+        ThreadNetworkException cause = (ThreadNetworkException) thrown.getCause();
+        assertThat(cause.getErrorCode()).isEqualTo(ERROR_THREAD_DISABLED);
+    }
+
+    @Test
+    public void forceStopOtDaemon_forceStopDisabled_otDaemonServiceAppears() throws Exception {
+        runThreadCommand("force-stop-ot-daemon disabled");
+
+        assertThat(runShellCommandOrThrow("service list")).contains("ot_daemon");
+    }
+
+    @Test
+    public void forceStopOtDaemon_forceStopDisabled_canEnableThread() throws Exception {
+        runThreadCommand("force-stop-ot-daemon disabled");
+
+        mController.setEnabledAndWait(true);
+        assertThat(mController.getEnabledState()).isEqualTo(STATE_ENABLED);
+    }
+
+    @Test
+    public void forceCountryCode_setCN_getCountryCodeReturnsCN() {
+        runThreadCommand("force-country-code enabled CN");
+        final String result = runThreadCommand("get-country-code");
+        runThreadCommand("force-country-code disabled"); // Undo override for later tests.
+        assertThat(result).contains("Thread country code = CN");
+    }
+
+    private static String runThreadCommand(String cmd) {
+        return runShellCommandOrThrow("cmd thread_network " + cmd);
+    }
+}
diff --git a/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java b/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java
index 600b662..f7bb9ff 100644
--- a/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java
+++ b/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java
@@ -24,6 +24,7 @@
 
 import android.net.InetAddresses;
 import android.net.IpPrefix;
+import android.net.nsd.NsdServiceInfo;
 import android.net.thread.ActiveOperationalDataset;
 
 import com.google.errorprone.annotations.FormatMethod;
@@ -34,6 +35,7 @@
 import java.io.InputStreamReader;
 import java.io.OutputStreamWriter;
 import java.net.Inet6Address;
+import java.net.InetAddress;
 import java.nio.charset.StandardCharsets;
 import java.time.Duration;
 import java.util.ArrayList;
@@ -327,6 +329,55 @@
         return false;
     }
 
+    /** Sets the DNS server address by running the {@code dns config} CLI command. */
+    public void setDnsServerAddress(String address) {
+        executeCommand("dns config %s", address);
+    }
+
+    /** Returns the first browsed service instance of {@code serviceType}. */
+    public NsdServiceInfo browseService(String serviceType) {
+        // Expected CLI output; the lines.get(N) indices below follow this layout:
+        // DNS browse response for _testservice._tcp.default.service.arpa.
+        // test-service
+        //    Port:12345, Priority:0, Weight:0, TTL:10
+        //    Host:testhost.default.service.arpa.
+        //    HostAddress:2001:0:0:0:0:0:0:1 TTL:10
+        //    TXT:[key1=0102, key2=03] TTL:10
+
+        List<String> lines = executeCommand("dns browse %s", serviceType);
+        NsdServiceInfo info = new NsdServiceInfo();
+        info.setServiceName(lines.get(1));
+        info.setServiceType(serviceType);
+        info.setPort(DnsServiceCliOutputParser.parsePort(lines.get(2)));
+        info.setHostname(DnsServiceCliOutputParser.parseHostname(lines.get(3)));
+        info.setHostAddresses(List.of(DnsServiceCliOutputParser.parseHostAddress(lines.get(4))));
+        DnsServiceCliOutputParser.parseTxtIntoServiceInfo(lines.get(5), info);
+
+        return info;
+    }
+
+    /** Returns the resolved service instance. */
+    public NsdServiceInfo resolveService(String serviceName, String serviceType) {
+        // CLI output:
+        // DNS service resolution response for test-service for service
+        // _test._tcp.default.service.arpa.
+        // Port:12345, Priority:0, Weight:0, TTL:10
+        // Host:Android.default.service.arpa.
+        // HostAddress:2001:0:0:0:0:0:0:1 TTL:10
+        // TXT:[key1=0102, key2=03] TTL:10
+
+        List<String> lines = executeCommand("dns service %s %s", serviceName, serviceType);
+        NsdServiceInfo info = new NsdServiceInfo();
+        info.setServiceName(serviceName);
+        info.setServiceType(serviceType);
+        info.setPort(DnsServiceCliOutputParser.parsePort(lines.get(1)));
+        info.setHostname(DnsServiceCliOutputParser.parseHostname(lines.get(2)));
+        info.setHostAddresses(List.of(DnsServiceCliOutputParser.parseHostAddress(lines.get(3))));
+        DnsServiceCliOutputParser.parseTxtIntoServiceInfo(lines.get(4), info);
+
+        return info;
+    }
+
     /** Runs the "factoryreset" command on the device. */
     public void factoryReset() {
         try {
@@ -454,4 +505,45 @@
     private static String toHexString(byte[] bytes) {
         return base16().encode(bytes);
     }
+
+    private static final class DnsServiceCliOutputParser {
+        /** Returns a matcher positioned at the first match of {@code regex} in {@code input}. */
+        private static Matcher firstMatchOf(String input, String regex) {
+            Matcher matcher = Pattern.compile(regex).matcher(input);
+            matcher.find(); // On no match, the caller's group() throws IllegalStateException.
+            return matcher;
+        }
+
+        /** Parses the port number from a line such as "Port:12345". */
+        private static int parsePort(String line) {
+            return Integer.parseInt(firstMatchOf(line, "Port:(\\d+)").group(1));
+        }
+
+        /** Parses the host name from a line such as "Host:Android.default.service.arpa.". */
+        private static String parseHostname(String line) {
+            return firstMatchOf(line, "Host:(.+)").group(1);
+        }
+
+        /** Parses the address from a line such as "HostAddress:2001:0:0:0:0:0:0:1". */
+        private static InetAddress parseHostAddress(String line) {
+            return InetAddresses.parseNumericAddress(
+                    firstMatchOf(line, "HostAddress:([^ ]+)").group(1));
+        }
+
+        /** Stores hex-encoded TXT entries as attributes, e.g. "TXT:[key1=0102, key2=03]". */
+        private static void parseTxtIntoServiceInfo(String line, NsdServiceInfo serviceInfo) {
+            String txtString = firstMatchOf(line, "TXT:\\[([^\\]]+)\\]").group(1);
+            for (String txtEntry : txtString.split(",")) {
+                String[] nameAndValue = txtEntry.trim().split("=", 2);
+                String name = nameAndValue[0];
+                String value = nameAndValue[1];
+                byte[] bytes = new byte[value.length() / 2];
+                for (int i = 0; i < value.length(); i += 2) {
+                    int high = Character.digit(value.charAt(i), 16); // Handles a-f/A-F too.
+                    bytes[i / 2] = (byte) (high << 4 | Character.digit(value.charAt(i + 1), 16));
+                }
+                serviceInfo.setAttribute(name, bytes);
+            }
+        }
+    }
 }
diff --git a/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java b/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
index 4a06fe8..ade0669 100644
--- a/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
+++ b/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
@@ -62,6 +62,18 @@
                 .toList();
     }
 
+    /** Returns {@code true} if the first line of the {@code ifconfig} command output is "up". */
+    public boolean isInterfaceUp() {
+        String state = executeCommand("ifconfig").split("\n")[0].trim();
+        return "up".equals(state);
+    }
+
+    /** Returns the device's ML-EID, parsed from the first line of {@code ipaddr mleid} output. */
+    public Inet6Address getMlEid() {
+        String[] outputLines = executeCommand("ipaddr mleid").split("\n");
+        return (Inet6Address) InetAddresses.parseNumericAddress(outputLines[0].trim());
+    }
+
     public String executeCommand(String cmd) {
         return SystemUtil.runShellCommand(OT_CTL + " " + cmd);
     }
diff --git a/thread/tests/integration/src/android/net/thread/utils/ThreadNetworkControllerWrapper.java b/thread/tests/integration/src/android/net/thread/utils/ThreadNetworkControllerWrapper.java
index e7b4cd9..7e84233 100644
--- a/thread/tests/integration/src/android/net/thread/utils/ThreadNetworkControllerWrapper.java
+++ b/thread/tests/integration/src/android/net/thread/utils/ThreadNetworkControllerWrapper.java
@@ -46,6 +46,7 @@
     public static final Duration JOIN_TIMEOUT = Duration.ofSeconds(10);
     public static final Duration LEAVE_TIMEOUT = Duration.ofSeconds(2);
     private static final Duration CALLBACK_TIMEOUT = Duration.ofSeconds(1);
+    private static final Duration SET_ENABLED_TIMEOUT = Duration.ofSeconds(2);
 
     private final ThreadNetworkController mController;
 
@@ -115,6 +116,18 @@
         }
     }
 
+    /** A synchronous variant of {@link ThreadNetworkController#setEnabled}. */
+    public void setEnabledAndWait(boolean enabled)
+            throws InterruptedException, ExecutionException, TimeoutException {
+        CompletableFuture<Void> future = new CompletableFuture<>();
+        runAsShell(
+                PERMISSION_THREAD_NETWORK_PRIVILEGED,
+                () ->
+                        mController.setEnabled(
+                                enabled, directExecutor(), newOutcomeReceiver(future)));
+        future.get(SET_ENABLED_TIMEOUT.toSeconds(), SECONDS);
+    }
+
     /** Joins the given network and wait for this device to become attached. */
     public void joinAndWait(ActiveOperationalDataset activeDataset)
             throws InterruptedException, ExecutionException, TimeoutException {
diff --git a/thread/tests/unit/src/android/net/thread/ThreadNetworkControllerTest.java b/thread/tests/unit/src/android/net/thread/ThreadNetworkControllerTest.java
index 75eb043..ac74372 100644
--- a/thread/tests/unit/src/android/net/thread/ThreadNetworkControllerTest.java
+++ b/thread/tests/unit/src/android/net/thread/ThreadNetworkControllerTest.java
@@ -19,6 +19,7 @@
 import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_CHILD;
 import static android.net.thread.ThreadNetworkException.ERROR_UNAVAILABLE;
 import static android.net.thread.ThreadNetworkException.ERROR_UNSUPPORTED_CHANNEL;
+import static android.net.thread.ThreadNetworkException.ERROR_UNSUPPORTED_OPERATION;
 import static android.os.Process.SYSTEM_UID;
 
 import static com.google.common.io.BaseEncoding.base16;
@@ -33,6 +34,7 @@
 import android.os.Binder;
 import android.os.OutcomeReceiver;
 import android.os.Process;
+import android.util.SparseIntArray;
 
 import androidx.test.ext.junit.runners.AndroidJUnit4;
 import androidx.test.filters.SmallTest;
@@ -77,6 +79,13 @@
     private static final ActiveOperationalDataset DEFAULT_DATASET =
             ActiveOperationalDataset.fromThreadTlvs(DEFAULT_DATASET_TLVS);
 
+    private static final SparseIntArray DEFAULT_CHANNEL_POWERS =
+            new SparseIntArray() {
+                {
+                    put(20, 32767);
+                }
+            };
+
     @Before
     public void setUp() {
         MockitoAnnotations.initMocks(this);
@@ -111,6 +120,10 @@
         return (IOperationReceiver) invocation.getArguments()[1];
     }
 
+    private static IOperationReceiver getSetChannelMaxPowersReceiver(InvocationOnMock invocation) {
+        return (IOperationReceiver) invocation.getArguments()[1];
+    }
+
     private static IActiveOperationalDatasetReceiver getCreateDatasetReceiver(
             InvocationOnMock invocation) {
         return (IActiveOperationalDatasetReceiver) invocation.getArguments()[1];
@@ -361,6 +374,51 @@
     }
 
     @Test
+    public void setChannelMaxPowers_callbackIsInvokedWithCallingAppIdentity() throws Exception {
+        setBinderUid(SYSTEM_UID);
+
+        AtomicInteger successCallbackUid = new AtomicInteger(0);
+        AtomicInteger errorCallbackUid = new AtomicInteger(0);
+
+        doAnswer(
+                        invoke -> {
+                            getSetChannelMaxPowersReceiver(invoke).onSuccess();
+                            return null;
+                        })
+                .when(mMockService)
+                .setChannelMaxPowers(any(ChannelMaxPower[].class), any(IOperationReceiver.class));
+        mController.setChannelMaxPowers(
+                DEFAULT_CHANNEL_POWERS,
+                Runnable::run,
+                v -> successCallbackUid.set(Binder.getCallingUid()));
+        doAnswer(
+                        invoke -> {
+                            getSetChannelMaxPowersReceiver(invoke)
+                                    .onError(ERROR_UNSUPPORTED_OPERATION, "");
+                            return null;
+                        })
+                .when(mMockService)
+                .setChannelMaxPowers(any(ChannelMaxPower[].class), any(IOperationReceiver.class));
+        mController.setChannelMaxPowers(
+                DEFAULT_CHANNEL_POWERS,
+                Runnable::run,
+                new OutcomeReceiver<>() {
+                    @Override
+                    public void onResult(Void unused) {}
+
+                    @Override
+                    public void onError(ThreadNetworkException e) {
+                        errorCallbackUid.set(Binder.getCallingUid());
+                    }
+                });
+
+        assertThat(successCallbackUid.get()).isNotEqualTo(SYSTEM_UID);
+        assertThat(successCallbackUid.get()).isEqualTo(Process.myUid());
+        assertThat(errorCallbackUid.get()).isNotEqualTo(SYSTEM_UID);
+        assertThat(errorCallbackUid.get()).isEqualTo(Process.myUid());
+    }
+
+    @Test
     public void setTestNetworkAsUpstream_callbackIsInvokedWithCallingAppIdentity()
             throws Exception {
         setBinderUid(SYSTEM_UID);
diff --git a/thread/tests/unit/src/android/net/thread/ThreadNetworkExceptionTest.java b/thread/tests/unit/src/android/net/thread/ThreadNetworkExceptionTest.java
index f62b437..5908c20 100644
--- a/thread/tests/unit/src/android/net/thread/ThreadNetworkExceptionTest.java
+++ b/thread/tests/unit/src/android/net/thread/ThreadNetworkExceptionTest.java
@@ -32,6 +32,6 @@
     public void constructor_tooLargeErrorCode_throwsIllegalArgumentException() throws Exception {
         // TODO (b/323791003): move this test case to cts/ThreadNetworkExceptionTest when mainline
         // CTS is ready.
-        assertThrows(IllegalArgumentException.class, () -> new ThreadNetworkException(13, "13"));
+        assertThrows(IllegalArgumentException.class, () -> new ThreadNetworkException(14, "14"));
     }
 }
diff --git a/thread/tests/unit/src/com/android/server/thread/NsdPublisherTest.java b/thread/tests/unit/src/com/android/server/thread/NsdPublisherTest.java
index d860166..8886c73 100644
--- a/thread/tests/unit/src/com/android/server/thread/NsdPublisherTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/NsdPublisherTest.java
@@ -23,6 +23,7 @@
 
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.spy;
@@ -30,24 +31,30 @@
 import static org.mockito.Mockito.verify;
 
 import android.net.InetAddresses;
+import android.net.nsd.DiscoveryRequest;
 import android.net.nsd.NsdManager;
 import android.net.nsd.NsdServiceInfo;
 import android.os.Handler;
 import android.os.test.TestLooper;
 
 import com.android.server.thread.openthread.DnsTxtAttribute;
+import com.android.server.thread.openthread.INsdDiscoverServiceCallback;
+import com.android.server.thread.openthread.INsdResolveServiceCallback;
 import com.android.server.thread.openthread.INsdStatusReceiver;
 
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
+import org.mockito.ArgumentMatcher;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
 import java.net.InetAddress;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.Executor;
 
@@ -57,6 +64,8 @@
 
     @Mock private INsdStatusReceiver mRegistrationReceiver;
     @Mock private INsdStatusReceiver mUnregistrationReceiver;
+    @Mock private INsdDiscoverServiceCallback mDiscoverServiceCallback;
+    @Mock private INsdResolveServiceCallback mResolveServiceCallback;
 
     private TestLooper mTestLooper;
     private NsdPublisher mNsdPublisher;
@@ -469,6 +478,165 @@
     }
 
     @Test
+    public void discoverService_serviceDiscovered() throws Exception {
+        prepareTest();
+
+        mNsdPublisher.discoverService("_test._tcp", mDiscoverServiceCallback, 10 /* listenerId */);
+        mTestLooper.dispatchAll();
+        ArgumentCaptor<NsdManager.DiscoveryListener> discoveryListenerArgumentCaptor =
+                ArgumentCaptor.forClass(NsdManager.DiscoveryListener.class);
+        verify(mMockNsdManager, times(1))
+                .discoverServices(
+                        eq(new DiscoveryRequest.Builder(PROTOCOL_DNS_SD, "_test._tcp").build()),
+                        any(Executor.class),
+                        discoveryListenerArgumentCaptor.capture());
+        NsdManager.DiscoveryListener actualDiscoveryListener =
+                discoveryListenerArgumentCaptor.getValue();
+        NsdServiceInfo serviceInfo = new NsdServiceInfo();
+        serviceInfo.setServiceName("test");
+        serviceInfo.setServiceType(null);
+        actualDiscoveryListener.onServiceFound(serviceInfo);
+        mTestLooper.dispatchAll();
+
+        verify(mDiscoverServiceCallback, times(1))
+                .onServiceDiscovered("test", "_test._tcp", true /* isFound */);
+    }
+
+    @Test
+    public void discoverService_serviceLost() throws Exception {
+        prepareTest();
+
+        mNsdPublisher.discoverService("_test._tcp", mDiscoverServiceCallback, 10 /* listenerId */);
+        mTestLooper.dispatchAll();
+        ArgumentCaptor<NsdManager.DiscoveryListener> discoveryListenerArgumentCaptor =
+                ArgumentCaptor.forClass(NsdManager.DiscoveryListener.class);
+        verify(mMockNsdManager, times(1))
+                .discoverServices(
+                        eq(new DiscoveryRequest.Builder(PROTOCOL_DNS_SD, "_test._tcp").build()),
+                        any(Executor.class),
+                        discoveryListenerArgumentCaptor.capture());
+        NsdManager.DiscoveryListener actualDiscoveryListener =
+                discoveryListenerArgumentCaptor.getValue();
+        NsdServiceInfo serviceInfo = new NsdServiceInfo();
+        serviceInfo.setServiceName("test");
+        serviceInfo.setServiceType(null);
+        actualDiscoveryListener.onServiceLost(serviceInfo);
+        mTestLooper.dispatchAll();
+
+        verify(mDiscoverServiceCallback, times(1))
+                .onServiceDiscovered("test", "_test._tcp", false /* isFound */);
+    }
+
+    @Test
+    public void stopServiceDiscovery() {
+        prepareTest();
+
+        mNsdPublisher.discoverService("_test._tcp", mDiscoverServiceCallback, 10 /* listenerId */);
+        mTestLooper.dispatchAll();
+        ArgumentCaptor<NsdManager.DiscoveryListener> discoveryListenerArgumentCaptor =
+                ArgumentCaptor.forClass(NsdManager.DiscoveryListener.class);
+        verify(mMockNsdManager, times(1))
+                .discoverServices(
+                        eq(new DiscoveryRequest.Builder(PROTOCOL_DNS_SD, "_test._tcp").build()),
+                        any(Executor.class),
+                        discoveryListenerArgumentCaptor.capture());
+        NsdManager.DiscoveryListener actualDiscoveryListener =
+                discoveryListenerArgumentCaptor.getValue();
+        NsdServiceInfo serviceInfo = new NsdServiceInfo();
+        serviceInfo.setServiceName("test");
+        serviceInfo.setServiceType(null);
+        actualDiscoveryListener.onServiceFound(serviceInfo);
+        mNsdPublisher.stopServiceDiscovery(10 /* listenerId */);
+        mTestLooper.dispatchAll();
+
+        verify(mMockNsdManager, times(1)).stopServiceDiscovery(actualDiscoveryListener);
+    }
+
+    @Test
+    public void resolveService_serviceResolved() throws Exception {
+        prepareTest();
+
+        mNsdPublisher.resolveService(
+                "test", "_test._tcp", mResolveServiceCallback, 10 /* listenerId */);
+        mTestLooper.dispatchAll();
+        ArgumentCaptor<NsdServiceInfo> serviceInfoArgumentCaptor =
+                ArgumentCaptor.forClass(NsdServiceInfo.class);
+        ArgumentCaptor<NsdManager.ServiceInfoCallback> serviceInfoCallbackArgumentCaptor =
+                ArgumentCaptor.forClass(NsdManager.ServiceInfoCallback.class);
+        verify(mMockNsdManager, times(1))
+                .registerServiceInfoCallback(
+                        serviceInfoArgumentCaptor.capture(),
+                        any(Executor.class),
+                        serviceInfoCallbackArgumentCaptor.capture());
+        assertThat(serviceInfoArgumentCaptor.getValue().getServiceName()).isEqualTo("test");
+        assertThat(serviceInfoArgumentCaptor.getValue().getServiceType()).isEqualTo("_test._tcp");
+        NsdServiceInfo serviceInfo = new NsdServiceInfo();
+        serviceInfo.setServiceName("test");
+        serviceInfo.setServiceType("_test._tcp");
+        serviceInfo.setPort(12345);
+        serviceInfo.setHostname("test-host");
+        serviceInfo.setHostAddresses(
+                List.of(
+                        InetAddress.parseNumericAddress("2001::1"),
+                        InetAddress.parseNumericAddress("2001::2")));
+        serviceInfo.setAttribute("key1", new byte[] {(byte) 0x01, (byte) 0x02});
+        serviceInfo.setAttribute("key2", new byte[] {(byte) 0x03});
+        serviceInfoCallbackArgumentCaptor.getValue().onServiceUpdated(serviceInfo);
+        mTestLooper.dispatchAll();
+
+        verify(mResolveServiceCallback, times(1))
+                .onServiceResolved(
+                        eq("test-host"),
+                        eq("test"),
+                        eq("_test._tcp"),
+                        eq(12345),
+                        eq(List.of("2001::1", "2001::2")),
+                        argThat(
+                                new TxtMatcher(
+                                        List.of(
+                                                makeTxtAttribute("key1", List.of(0x01, 0x02)),
+                                                makeTxtAttribute("key2", List.of(0x03))))),
+                        anyInt());
+    }
+
+    @Test
+    public void stopServiceResolution() throws Exception {
+        prepareTest();
+
+        mNsdPublisher.resolveService(
+                "test", "_test._tcp", mResolveServiceCallback, 10 /* listenerId */);
+        mTestLooper.dispatchAll();
+        ArgumentCaptor<NsdServiceInfo> serviceInfoArgumentCaptor =
+                ArgumentCaptor.forClass(NsdServiceInfo.class);
+        ArgumentCaptor<NsdManager.ServiceInfoCallback> serviceInfoCallbackArgumentCaptor =
+                ArgumentCaptor.forClass(NsdManager.ServiceInfoCallback.class);
+        verify(mMockNsdManager, times(1))
+                .registerServiceInfoCallback(
+                        serviceInfoArgumentCaptor.capture(),
+                        any(Executor.class),
+                        serviceInfoCallbackArgumentCaptor.capture());
+        assertThat(serviceInfoArgumentCaptor.getValue().getServiceName()).isEqualTo("test");
+        assertThat(serviceInfoArgumentCaptor.getValue().getServiceType()).isEqualTo("_test._tcp");
+        NsdServiceInfo serviceInfo = new NsdServiceInfo();
+        serviceInfo.setServiceName("test");
+        serviceInfo.setServiceType("_test._tcp");
+        serviceInfo.setPort(12345);
+        serviceInfo.setHostname("test-host");
+        serviceInfo.setHostAddresses(
+                List.of(
+                        InetAddress.parseNumericAddress("2001::1"),
+                        InetAddress.parseNumericAddress("2001::2")));
+        serviceInfo.setAttribute("key1", new byte[] {(byte) 0x01, (byte) 0x02});
+        serviceInfo.setAttribute("key2", new byte[] {(byte) 0x03});
+        serviceInfoCallbackArgumentCaptor.getValue().onServiceUpdated(serviceInfo);
+        mNsdPublisher.stopServiceResolution(10 /* listenerId */);
+        mTestLooper.dispatchAll();
+
+        verify(mMockNsdManager, times(1))
+                .unregisterServiceInfoCallback(serviceInfoCallbackArgumentCaptor.getValue());
+    }
+
+    @Test
     public void reset_unregisterAll() {
         prepareTest();
 
@@ -582,6 +750,30 @@
         return addresses;
     }
 
+    private static class TxtMatcher implements ArgumentMatcher<List<DnsTxtAttribute>> {
+        private final List<DnsTxtAttribute> mAttributes;
+
+        TxtMatcher(List<DnsTxtAttribute> attributes) {
+            mAttributes = attributes;
+        }
+
+        @Override
+        public boolean matches(List<DnsTxtAttribute> argument) {
+            if (argument.size() != mAttributes.size()) {
+                return false;
+            }
+            for (int i = 0; i < argument.size(); ++i) {
+                if (!Objects.equals(argument.get(i).name, mAttributes.get(i).name)) {
+                    return false;
+                }
+                if (!Arrays.equals(argument.get(i).value, mAttributes.get(i).value)) {
+                    return false;
+                }
+            }
+            return true;
+        }
+    }
+
     // @Before and @Test run in different threads. NsdPublisher requires the jobs are run on the
     // thread looper, so TestLooper needs to be created inside each test case to install the
     // correct looper.
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
index f54edfe..0c7d086 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
@@ -21,6 +21,7 @@
 import static android.net.thread.ThreadNetworkController.STATE_ENABLED;
 import static android.net.thread.ThreadNetworkException.ERROR_FAILED_PRECONDITION;
 import static android.net.thread.ThreadNetworkException.ERROR_INTERNAL_ERROR;
+import static android.net.thread.ThreadNetworkException.ERROR_THREAD_DISABLED;
 import static android.net.thread.ThreadNetworkManager.DISALLOW_THREAD_NETWORK;
 import static android.net.thread.ThreadNetworkManager.PERMISSION_THREAD_NETWORK_PRIVILEGED;
 
@@ -37,6 +38,7 @@
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.spy;
@@ -76,7 +78,9 @@
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Captor;
+import org.mockito.InOrder;
 import org.mockito.Mock;
+import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
 
 import java.util.concurrent.CompletableFuture;
@@ -312,13 +316,13 @@
     }
 
     @Test
-    public void userRestriction_initWithUserRestricted_threadIsDisabled() {
+    public void userRestriction_initWithUserRestricted_otDaemonNotStarted() {
         when(mMockUserManager.hasUserRestriction(eq(DISALLOW_THREAD_NETWORK))).thenReturn(true);
 
         mService.initialize();
         mTestLooper.dispatchAll();
 
-        assertThat(mFakeOtDaemon.getEnabledState()).isEqualTo(STATE_DISABLED);
+        assertThat(mFakeOtDaemon.isInitialized()).isFalse();
     }
 
     @Test
@@ -440,4 +444,70 @@
         verify(mockReceiver, never()).onSuccess(any(ActiveOperationalDataset.class));
         verify(mockReceiver, times(1)).onError(eq(ERROR_INTERNAL_ERROR), anyString());
     }
+
+    @Test
+    public void forceStopOtDaemonForTest_noPermission_throwsSecurityException() {
+        doThrow(new SecurityException(""))
+                .when(mContext)
+                .enforceCallingOrSelfPermission(eq(PERMISSION_THREAD_NETWORK_PRIVILEGED), any());
+
+        assertThrows(
+                SecurityException.class,
+                () -> mService.forceStopOtDaemonForTest(true, new IOperationReceiver.Default()));
+    }
+
+    @Test
+    public void forceStopOtDaemonForTest_enabled_otDaemonDiesAndJoinFails() throws Exception {
+        mService.initialize();
+        IOperationReceiver mockReceiver = mock(IOperationReceiver.class);
+        IOperationReceiver mockJoinReceiver = mock(IOperationReceiver.class);
+
+        mService.forceStopOtDaemonForTest(true, mockReceiver);
+        mTestLooper.dispatchAll();
+        mService.join(DEFAULT_ACTIVE_DATASET, mockJoinReceiver);
+        mTestLooper.dispatchAll();
+
+        verify(mockReceiver, times(1)).onSuccess();
+        assertThat(mFakeOtDaemon.isInitialized()).isFalse();
+        verify(mockJoinReceiver, times(1)).onError(eq(ERROR_THREAD_DISABLED), anyString());
+    }
+
+    @Test
+    public void forceStopOtDaemonForTest_disable_otDaemonRestartsAndJoinSccess() throws Exception {
+        mService.initialize();
+        IOperationReceiver mockReceiver = mock(IOperationReceiver.class);
+        IOperationReceiver mockJoinReceiver = mock(IOperationReceiver.class);
+
+        mService.forceStopOtDaemonForTest(true, mock(IOperationReceiver.class));
+        mTestLooper.dispatchAll();
+        mService.forceStopOtDaemonForTest(false, mockReceiver);
+        mTestLooper.dispatchAll();
+        mService.join(DEFAULT_ACTIVE_DATASET, mockJoinReceiver);
+        mTestLooper.dispatchAll();
+        mTestLooper.moveTimeForward(FakeOtDaemon.JOIN_DELAY.toMillis() + 100);
+        mTestLooper.dispatchAll();
+
+        verify(mockReceiver, times(1)).onSuccess();
+        assertThat(mFakeOtDaemon.isInitialized()).isTrue();
+        verify(mockJoinReceiver, times(1)).onSuccess();
+    }
+
+    @Test
+    public void onOtDaemonDied_joinedNetwork_interfaceStateBackToUp() throws Exception {
+        mService.initialize();
+        final IOperationReceiver mockReceiver = mock(IOperationReceiver.class);
+        mService.join(DEFAULT_ACTIVE_DATASET, mockReceiver);
+        mTestLooper.dispatchAll();
+        mTestLooper.moveTimeForward(FakeOtDaemon.JOIN_DELAY.toMillis() + 100);
+        mTestLooper.dispatchAll();
+
+        Mockito.reset(mMockInfraIfController);
+        mFakeOtDaemon.terminate();
+        mTestLooper.dispatchAll();
+
+        verify(mMockTunIfController, times(1)).onOtDaemonDied();
+        InOrder inOrder = Mockito.inOrder(mMockTunIfController);
+        inOrder.verify(mMockTunIfController, times(1)).setInterfaceUp(false);
+        inOrder.verify(mMockTunIfController, times(1)).setInterfaceUp(true);
+    }
 }
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java
index c7e0eca..9f2d0cb 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkShellCommandTest.java
@@ -16,10 +16,15 @@
 
 package com.android.server.thread;
 
+import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.contains;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.validateMockitoUsage;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
@@ -45,19 +50,19 @@
 @SmallTest
 public class ThreadNetworkShellCommandTest {
     private static final String TAG = "ThreadNetworkShellCommandTTest";
-    @Mock ThreadNetworkService mThreadNetworkService;
-    @Mock ThreadNetworkCountryCode mThreadNetworkCountryCode;
+    @Mock ThreadNetworkControllerService mControllerService;
+    @Mock ThreadNetworkCountryCode mCountryCode;
     @Mock PrintWriter mErrorWriter;
     @Mock PrintWriter mOutputWriter;
 
-    ThreadNetworkShellCommand mThreadNetworkShellCommand;
+    ThreadNetworkShellCommand mShellCommand;
 
     @Before
     public void setUp() throws Exception {
         MockitoAnnotations.initMocks(this);
 
-        mThreadNetworkShellCommand = new ThreadNetworkShellCommand(mThreadNetworkCountryCode);
-        mThreadNetworkShellCommand.setPrintWriters(mOutputWriter, mErrorWriter);
+        mShellCommand = new ThreadNetworkShellCommand(mControllerService, mCountryCode);
+        mShellCommand.setPrintWriters(mOutputWriter, mErrorWriter);
     }
 
     @After
@@ -68,9 +73,9 @@
     @Test
     public void getCountryCode_executeInUnrootedShell_allowed() {
         BinderUtil.setUid(Process.SHELL_UID);
-        when(mThreadNetworkCountryCode.getCountryCode()).thenReturn("US");
+        when(mCountryCode.getCountryCode()).thenReturn("US");
 
-        mThreadNetworkShellCommand.exec(
+        mShellCommand.exec(
                 new Binder(),
                 new FileDescriptor(),
                 new FileDescriptor(),
@@ -84,14 +89,14 @@
     public void forceSetCountryCodeEnabled_executeInUnrootedShell_notAllowed() {
         BinderUtil.setUid(Process.SHELL_UID);
 
-        mThreadNetworkShellCommand.exec(
+        mShellCommand.exec(
                 new Binder(),
                 new FileDescriptor(),
                 new FileDescriptor(),
                 new FileDescriptor(),
                 new String[] {"force-country-code", "enabled", "US"});
 
-        verify(mThreadNetworkCountryCode, never()).setOverrideCountryCode(eq("US"));
+        verify(mCountryCode, never()).setOverrideCountryCode(eq("US"));
         verify(mErrorWriter).println(contains("force-country-code"));
     }
 
@@ -99,28 +104,28 @@
     public void forceSetCountryCodeEnabled_executeInRootedShell_allowed() {
         BinderUtil.setUid(Process.ROOT_UID);
 
-        mThreadNetworkShellCommand.exec(
+        mShellCommand.exec(
                 new Binder(),
                 new FileDescriptor(),
                 new FileDescriptor(),
                 new FileDescriptor(),
                 new String[] {"force-country-code", "enabled", "US"});
 
-        verify(mThreadNetworkCountryCode).setOverrideCountryCode(eq("US"));
+        verify(mCountryCode).setOverrideCountryCode(eq("US"));
     }
 
     @Test
     public void forceSetCountryCodeDisabled_executeInUnrootedShell_notAllowed() {
         BinderUtil.setUid(Process.SHELL_UID);
 
-        mThreadNetworkShellCommand.exec(
+        mShellCommand.exec(
                 new Binder(),
                 new FileDescriptor(),
                 new FileDescriptor(),
                 new FileDescriptor(),
                 new String[] {"force-country-code", "disabled"});
 
-        verify(mThreadNetworkCountryCode, never()).setOverrideCountryCode(any());
+        verify(mCountryCode, never()).setOverrideCountryCode(any());
         verify(mErrorWriter).println(contains("force-country-code"));
     }
 
@@ -128,13 +133,64 @@
     public void forceSetCountryCodeDisabled_executeInRootedShell_allowed() {
         BinderUtil.setUid(Process.ROOT_UID);
 
-        mThreadNetworkShellCommand.exec(
+        mShellCommand.exec(
                 new Binder(),
                 new FileDescriptor(),
                 new FileDescriptor(),
                 new FileDescriptor(),
                 new String[] {"force-country-code", "disabled"});
 
-        verify(mThreadNetworkCountryCode).clearOverrideCountryCode();
+        verify(mCountryCode).clearOverrideCountryCode();
+    }
+
+    @Test
+    public void forceStopOtDaemon_executeInUnrootedShell_failedAndServiceApiNotCalled() {
+        BinderUtil.setUid(Process.SHELL_UID);
+
+        mShellCommand.exec(
+                new Binder(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new String[] {"force-stop-ot-daemon", "enabled"});
+
+        verify(mControllerService, never()).forceStopOtDaemonForTest(anyBoolean(), any());
+        verify(mErrorWriter, atLeastOnce()).println(contains("force-stop-ot-daemon"));
+        verify(mOutputWriter, never()).println();
+    }
+
+    @Test
+    public void forceStopOtDaemon_serviceThrows_failed() {
+        BinderUtil.setUid(Process.ROOT_UID);
+        doThrow(new SecurityException(""))
+                .when(mControllerService)
+                .forceStopOtDaemonForTest(eq(true), any());
+
+        mShellCommand.exec(
+                new Binder(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new String[] {"force-stop-ot-daemon", "enabled"});
+
+        verify(mControllerService, times(1)).forceStopOtDaemonForTest(eq(true), any());
+        verify(mOutputWriter, never()).println();
+    }
+
+    @Test
+    public void forceStopOtDaemon_serviceApiTimeout_failedWithTimeoutError() {
+        BinderUtil.setUid(Process.ROOT_UID);
+        doNothing().when(mControllerService).forceStopOtDaemonForTest(eq(true), any());
+
+        mShellCommand.exec(
+                new Binder(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new FileDescriptor(),
+                new String[] {"force-stop-ot-daemon", "enabled"});
+
+        verify(mControllerService, times(1)).forceStopOtDaemonForTest(eq(true), any());
+        verify(mErrorWriter, atLeastOnce()).println(contains("timeout"));
+        verify(mOutputWriter, never()).println();
     }
 }