Merge "Fix isUidNetworkingBlocked for system uids" into main
diff --git a/netbpfload/netbpfload.mainline.rc b/netbpfload/netbpfload.mainline.rc
index d7202f7..d38a503 100644
--- a/netbpfload/netbpfload.mainline.rc
+++ b/netbpfload/netbpfload.mainline.rc
@@ -10,6 +10,7 @@
     capabilities CHOWN SYS_ADMIN NET_ADMIN
     group system root graphics network_stack net_admin net_bw_acct net_bw_stats net_raw
     user system
+    file /dev/kmsg w
     rlimit memlock 1073741824 1073741824
     oneshot
     reboot_on_failure reboot,bpfloader-failed
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
index c1c7d5f..61eb766 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
@@ -22,12 +22,10 @@
 import android.annotation.Nullable;
 import android.annotation.RequiresApi;
 import android.net.LinkAddress;
-import android.net.nsd.NsdManager;
 import android.net.nsd.NsdServiceInfo;
 import android.os.Build;
 import android.os.Handler;
 import android.os.Looper;
-import android.util.ArraySet;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.net.module.util.HexDump;
@@ -284,6 +282,7 @@
         if (!mRecordRepository.hasActiveService(id)) return;
         mProber.stop(id);
         mAnnouncer.stop(id);
+        final String hostname = mRecordRepository.getHostnameForServiceId(id);
         final MdnsAnnouncer.ExitAnnouncementInfo exitInfo = mRecordRepository.exitService(id);
         if (exitInfo != null) {
             // This effectively schedules onAllServicesRemoved(), as it is to be called when the
@@ -303,6 +302,24 @@
                 }
             });
         }
+        // Re-probe/re-announce the services which have the same custom hostname. These services
+        // were probed/announced using host addresses which were just removed so they should be
+        // re-probed/re-announced without those addresses.
+        if (hostname != null) {
+            final List<MdnsProber.ProbingInfo> probingInfos =
+                    mRecordRepository.restartProbingForHostname(hostname);
+            for (MdnsProber.ProbingInfo probingInfo : probingInfos) {
+                mProber.stop(probingInfo.getServiceId());
+                mProber.startProbing(probingInfo);
+            }
+            final List<MdnsAnnouncer.AnnouncementInfo> announcementInfos =
+                    mRecordRepository.restartAnnouncingForHostname(hostname);
+            for (MdnsAnnouncer.AnnouncementInfo announcementInfo : announcementInfos) {
+                mAnnouncer.stop(announcementInfo.getServiceId());
+                mAnnouncer.startSending(
+                        announcementInfo.getServiceId(), announcementInfo, 0 /* initialDelayMs */);
+            }
+        }
     }
 
     /**
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
index ac64c3a..073e465 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
@@ -925,22 +925,79 @@
         }
     }
 
+    /**
+     * Returns the hostname of the service registered as {@code id}, or null if there is none.
+     */
+    @Nullable
+    public String getHostnameForServiceId(int id) {
+        final ServiceRegistration registration = mServices.get(id);
+        return registration == null ? null : registration.serviceInfo.getHostname();
+    }
+
+    /**
+     * Makes fresh probing info for every active service registration that uses
+     * {@code hostname} and is still probing, so the advertiser can restart probing them.
+     *
+     * @param hostname The custom hostname shared by the services.
+     * @return One {@link MdnsProber.ProbingInfo} per matching service that is still probing;
+     *     empty if there is none. Services that finished probing are not included.
+     */
+    public List<MdnsProber.ProbingInfo> restartProbingForHostname(@NonNull String hostname) {
+        final List<MdnsProber.ProbingInfo> result = new ArrayList<>();
+        forEachActiveServiceRegistrationWithHostname(hostname, (serviceId, reg) -> {
+            if (reg.isProbing) {
+                result.add(makeProbingInfo(serviceId, reg));
+            }
+        });
+        return result;
+    }
+
+    /**
+     * Makes fresh announcement info for every active service registration that uses
+     * {@code hostname} and has finished probing, so the advertiser can restart announcing them.
+     *
+     * @param hostname The custom hostname shared by the services.
+     * @return One {@link MdnsAnnouncer.AnnouncementInfo} per matching service that is no longer
+     *     probing; empty if there is none. Services still probing are not included.
+     */
+    public List<MdnsAnnouncer.AnnouncementInfo> restartAnnouncingForHostname(
+            @NonNull String hostname) {
+        final List<MdnsAnnouncer.AnnouncementInfo> result = new ArrayList<>();
+        forEachActiveServiceRegistrationWithHostname(hostname, (serviceId, reg) -> {
+            if (!reg.isProbing) {
+                result.add(makeAnnouncementInfo(serviceId, reg));
+            }
+        });
+        return result;
+    }
+
     /**
      * Called to indicate that probing succeeded for a service.
+     *
      * @param probeSuccessInfo The successful probing info.
      * @return The {@link MdnsAnnouncer.AnnouncementInfo} to send, now that probing has succeeded.
      */
     public MdnsAnnouncer.AnnouncementInfo onProbingSucceeded(
-            MdnsProber.ProbingInfo probeSuccessInfo)
-            throws IOException {
-
-        int serviceId = probeSuccessInfo.getServiceId();
+            MdnsProber.ProbingInfo probeSuccessInfo) throws IOException {
+        final int serviceId = probeSuccessInfo.getServiceId();
         final ServiceRegistration registration = mServices.get(serviceId);
         if (registration == null) {
             throw new IOException("Service is not registered: " + serviceId);
         }
         registration.setProbing(false);
 
+        return makeAnnouncementInfo(serviceId, registration); // shared with restart announcing
+    }
+
+    /**
+     * Make the announcement info of the given service ID.
+     *
+     * @param serviceId The service ID.
+     * @param registration The service registration.
+     * @return The {@link MdnsAnnouncer.AnnouncementInfo} of the given service ID.
+     */
+    private MdnsAnnouncer.AnnouncementInfo makeAnnouncementInfo(
+            int serviceId, ServiceRegistration registration) {
         final Set<MdnsRecord> answersSet = new LinkedHashSet<>();
         final ArrayList<MdnsRecord> additionalAnswers = new ArrayList<>();
 
@@ -972,8 +1029,8 @@
         addNsecRecordsForUniqueNames(additionalAnswers,
                 mGeneralRecords.iterator(), registration.allRecords.iterator());
 
-        return new MdnsAnnouncer.AnnouncementInfo(
-                probeSuccessInfo.getServiceId(), new ArrayList<>(answersSet), additionalAnswers);
+        return new MdnsAnnouncer.AnnouncementInfo(serviceId,
+                new ArrayList<>(answersSet), additionalAnswers);
     }
 
     /**
diff --git a/staticlibs/device/com/android/net/module/util/structs/PrefixInformationOption.java b/staticlibs/device/com/android/net/module/util/structs/PrefixInformationOption.java
index 49d7654..0fc85e4 100644
--- a/staticlibs/device/com/android/net/module/util/structs/PrefixInformationOption.java
+++ b/staticlibs/device/com/android/net/module/util/structs/PrefixInformationOption.java
@@ -21,6 +21,7 @@
 import android.net.IpPrefix;
 
 import androidx.annotation.NonNull;
+import androidx.annotation.VisibleForTesting;
 
 import com.android.net.module.util.Struct;
 import com.android.net.module.util.Struct.Field;
@@ -71,7 +72,8 @@
     @Field(order = 7, type = Type.ByteArray, arraysize = 16)
     public final byte[] prefix;
 
-    PrefixInformationOption(final byte type, final byte length, final byte prefixLen,
+    @VisibleForTesting
+    public PrefixInformationOption(final byte type, final byte length, final byte prefixLen,
             final byte flags, final long validLifetime, final long preferredLifetime,
             final int reserved, @NonNull final byte[] prefix) {
         this.type = type;
diff --git a/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt b/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
new file mode 100644
index 0000000..e92c906
--- /dev/null
+++ b/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
@@ -0,0 +1,97 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.net.cts
+
+import android.content.pm.PackageManager.FEATURE_WIFI
+import android.net.ConnectivityManager
+import android.net.NetworkCapabilities
+import android.net.NetworkRequest
+import android.os.Build
+import android.system.OsConstants
+import androidx.test.platform.app.InstrumentationRegistry
+import com.android.compatibility.common.util.PropertyUtil.isVendorApiLevelNewerThan
+import com.android.compatibility.common.util.SystemUtil
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.NetworkStackModuleTest
+import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
+import com.android.testutils.TestableNetworkCallback
+import com.google.common.truth.Truth.assertThat
+import kotlin.test.assertEquals
+import kotlin.test.assertNotNull
+import org.junit.After
+import org.junit.Assume.assumeTrue
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+
+private const val TIMEOUT_MS = 2000L
+
+@RunWith(DevSdkIgnoreRunner::class)
+@NetworkStackModuleTest
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
+class ApfIntegrationTest {
+    private val context by lazy { InstrumentationRegistry.getInstrumentation().context }
+    private val cm by lazy { context.getSystemService(ConnectivityManager::class.java)!! }
+    private val pm by lazy { context.packageManager }
+    private lateinit var wifiIfaceName: String
+    // Callback of the wifi network request made in setUp(); released in tearDown().
+    private var wifiCallback: TestableNetworkCallback? = null
+
+    @Before
+    fun setUp() {
+        assumeTrue(pm.hasSystemFeature(FEATURE_WIFI))
+        assumeTrue(isVendorApiLevelNewerThan(Build.VERSION_CODES.TIRAMISU))
+        val cb = TestableNetworkCallback()
+        cm.requestNetwork(
+                NetworkRequest.Builder()
+                        .addTransportType(NetworkCapabilities.TRANSPORT_WIFI)
+                        .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
+                        .build(),
+                cb
+        )
+        // Record the callback as soon as it is registered so that tearDown() releases the
+        // request even if waiting for the link properties below fails.
+        wifiCallback = cb
+        cb.eventuallyExpect<LinkPropertiesChanged>(TIMEOUT_MS) {
+            wifiIfaceName = assertNotNull(it.lp.interfaceName)
+            true
+        }
+    }
+
+    @After
+    fun tearDown() {
+        // Without this, the wifi network request made in setUp() would outlive the test.
+        wifiCallback?.let { cm.unregisterNetworkCallback(it) }
+        wifiCallback = null
+    }
+
+    @Test
+    fun testGetApfCapabilities() {
+        val capabilities = SystemUtil
+                .runShellCommand("cmd network_stack apf $wifiIfaceName capabilities").trim()
+        // Expected output: three comma-separated integers "version,maxLen,packetFormat".
+        val fields = capabilities.split(",").map { it.trim().toInt() }
+        assertEquals(3, fields.size, "Unexpected APF capabilities: $capabilities")
+        val (version, maxLen, packetFormat) = fields
+        assertEquals(4, version)
+        assertThat(maxLen).isAtLeast(1024)
+        if (isVendorApiLevelNewerThan(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)) {
+            assertThat(maxLen).isAtLeast(2000)
+        }
+        assertEquals(OsConstants.ARPHRD_ETHER, packetFormat)
+    }
+}
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index 61117df..6dd4857 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -2206,6 +2206,66 @@
     }
 
     @Test
+    fun testAdvertisingAndDiscovery_reregisterCustomHostWithDifferentAddresses_newAddressesFound() {
+        val si1 = NsdServiceInfo().also { // custom host with the initial addresses
+            it.network = testNetwork1.network
+            it.hostname = customHostname
+            it.hostAddresses = listOf(
+                    parseNumericAddress("192.0.2.23"),
+                    parseNumericAddress("2001:db8::1"))
+        }
+        val si2 = NsdServiceInfo().also { // service advertised on the custom host
+            it.network = testNetwork1.network
+            it.serviceName = serviceName
+            it.serviceType = serviceType
+            it.hostname = customHostname
+            it.port = TEST_PORT
+        }
+        val si3 = NsdServiceInfo().also { // same custom host, different addresses
+            it.network = testNetwork1.network
+            it.hostname = customHostname
+            it.hostAddresses = listOf(
+                    parseNumericAddress("192.0.2.24"),
+                    parseNumericAddress("2001:db8::2"))
+        }
+
+        val registrationRecord1 = NsdRegistrationRecord()
+        val registrationRecord2 = NsdRegistrationRecord()
+        val registrationRecord3 = NsdRegistrationRecord()
+
+        val discoveryRecord = NsdDiscoveryRecord()
+
+        tryTest {
+            registerService(registrationRecord1, si1)
+            registerService(registrationRecord2, si2)
+
+            nsdManager.unregisterService(registrationRecord1) // drop the old addresses
+            registrationRecord1.expectCallback<ServiceUnregistered>()
+
+            registerService(registrationRecord3, si3) // re-register with new addresses
+
+            nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD,
+                    testNetwork1.network, Executor { it.run() }, discoveryRecord)
+            val discoveredInfo = discoveryRecord.waitForServiceDiscovered(
+                    serviceName, serviceType, testNetwork1.network)
+            val resolvedInfo = resolveService(discoveredInfo)
+
+            assertEquals(serviceName, discoveredInfo.serviceName)
+            assertEquals(TEST_PORT, resolvedInfo.port)
+            assertEquals(customHostname, resolvedInfo.hostname)
+            assertAddressEquals( // the re-registered (si3) addresses
+                    listOf(parseNumericAddress("192.0.2.24"), parseNumericAddress("2001:db8::2")),
+                    resolvedInfo.hostAddresses)
+        } cleanupStep {
+            nsdManager.stopServiceDiscovery(discoveryRecord)
+            discoveryRecord.expectCallbackEventually<DiscoveryStopped>()
+        } cleanup {
+            nsdManager.unregisterService(registrationRecord2)
+            nsdManager.unregisterService(registrationRecord3)
+        }
+    }
+
+    @Test
     fun testServiceTypeClientRemovedAfterSocketDestroyed() {
         val si = makeTestServiceInfo(testNetwork1.network)
         // Register service on testNetwork1
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index 69fec85..7ac7bee 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -18,7 +18,6 @@
 
 import android.net.InetAddresses.parseNumericAddress
 import android.net.LinkAddress
-import android.net.nsd.NsdManager
 import android.net.nsd.NsdServiceInfo
 import android.os.Build
 import android.os.HandlerThread
@@ -48,6 +47,7 @@
 import org.mockito.Mockito.anyInt
 import org.mockito.Mockito.anyString
 import org.mockito.Mockito.argThat
+import org.mockito.Mockito.atLeastOnce
 import org.mockito.Mockito.doAnswer
 import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.eq
@@ -55,6 +55,8 @@
 import org.mockito.Mockito.never
 import org.mockito.Mockito.times
 import org.mockito.Mockito.verify
+import org.mockito.Mockito.clearInvocations
+import org.mockito.Mockito.inOrder
 
 private const val LOG_TAG = "testlogtag"
 private const val TIMEOUT_MS = 10_000L
@@ -65,6 +67,7 @@
 
 private const val TEST_SERVICE_ID_1 = 42
 private const val TEST_SERVICE_ID_DUPLICATE = 43
+private const val TEST_SERVICE_ID_2 = 44
 private val TEST_SERVICE_1 = NsdServiceInfo().apply {
     serviceType = "_testservice._tcp"
     serviceName = "MyTestService"
@@ -78,6 +81,13 @@
     port = 12345
 }
 
+private val TEST_SERVICE_1_CUSTOM_HOST = NsdServiceInfo().apply {
+    serviceType = "_testservice._tcp"
+    serviceName = "MyTestService"
+    hostname = "MyTestHost" // custom host; no hostAddresses set on this service
+    port = 12345
+}
+
 @RunWith(DevSdkIgnoreRunner::class)
 @IgnoreUpTo(Build.VERSION_CODES.S_V2)
 class MdnsInterfaceAdvertiserTest {
@@ -183,6 +193,63 @@
     }
 
     @Test
+    fun testAddRemoveServiceWithCustomHost_restartProbingForProbingServices() {
+        val customHost1 = NsdServiceInfo().apply {
+            hostname = "MyTestHost"
+            hostAddresses = listOf(
+                    parseNumericAddress("192.0.2.23"),
+                    parseNumericAddress("2001:db8::1"))
+        }
+        addServiceAndFinishProbing(TEST_SERVICE_ID_1, customHost1)
+        addServiceAndFinishProbing(TEST_SERVICE_ID_2, TEST_SERVICE_1_CUSTOM_HOST)
+        repository.setServiceProbing(TEST_SERVICE_ID_2)
+        val probingInfo = mock(ProbingInfo::class.java)
+        doReturn("MyTestHost")
+                .`when`(repository).getHostnameForServiceId(TEST_SERVICE_ID_1)
+        doReturn(TEST_SERVICE_ID_2).`when`(probingInfo).serviceId
+        doReturn(listOf(probingInfo))
+                .`when`(repository).restartProbingForHostname("MyTestHost")
+        val inOrder = inOrder(prober, announcer)
+
+        // Remove the custom host: the custom host's probing/announcement is stopped, and the
+        // services still probing with that hostname are stopped and re-probed.
+        advertiser.removeService(TEST_SERVICE_ID_1)
+
+        inOrder.verify(prober).stop(TEST_SERVICE_ID_1)
+        inOrder.verify(announcer).stop(TEST_SERVICE_ID_1)
+        inOrder.verify(prober).stop(TEST_SERVICE_ID_2)
+        inOrder.verify(prober).startProbing(probingInfo)
+    }
+
+    @Test
+    fun testAddRemoveServiceWithCustomHost_restartAnnouncingForProbedServices() {
+        val customHost1 = NsdServiceInfo().apply {
+            hostname = "MyTestHost"
+            hostAddresses = listOf(
+                    parseNumericAddress("192.0.2.23"),
+                    parseNumericAddress("2001:db8::1"))
+        }
+        addServiceAndFinishProbing(TEST_SERVICE_ID_1, customHost1)
+        val announcementInfo =
+                addServiceAndFinishProbing(TEST_SERVICE_ID_2, TEST_SERVICE_1_CUSTOM_HOST)
+        doReturn("MyTestHost")
+                .`when`(repository).getHostnameForServiceId(TEST_SERVICE_ID_1)
+        doReturn(TEST_SERVICE_ID_2).`when`(announcementInfo).serviceId
+        doReturn(listOf(announcementInfo))
+                .`when`(repository).restartAnnouncingForHostname("MyTestHost")
+        clearInvocations(announcer) // ignore announcer calls made while adding the services
+
+        // Remove the custom host: the custom host's announcement is stopped and the probed services
+        // which use that hostname are re-announced.
+        advertiser.removeService(TEST_SERVICE_ID_1)
+
+        verify(prober).stop(TEST_SERVICE_ID_1)
+        verify(announcer, atLeastOnce()).stop(TEST_SERVICE_ID_1)
+        verify(announcer).stop(TEST_SERVICE_ID_2)
+        verify(announcer).startSending(TEST_SERVICE_ID_2, announcementInfo, 0L /* initialDelayMs */)
+    }
+
+    @Test
     fun testDoubleRemove() {
         addServiceAndFinishProbing(TEST_SERVICE_ID_1, TEST_SERVICE_1)
 
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index c69b1e1..271cc65 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -24,6 +24,7 @@
 import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
 import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_HOST
 import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.CONFLICT_SERVICE
+import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_A
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_AAAA
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_PTR
@@ -51,6 +52,10 @@
 import org.junit.Before
 import org.junit.Test
 import org.junit.runner.RunWith
+import org.mockito.ArgumentCaptor
+import org.mockito.ArgumentMatchers.eq
+import org.mockito.Mockito.mock
+import org.mockito.Mockito.verify
 
 private const val TEST_SERVICE_ID_1 = 42
 private const val TEST_SERVICE_ID_2 = 43
@@ -112,6 +117,14 @@
     port = TEST_PORT
 }
 
+private val TEST_SERVICE_CUSTOM_HOST_NO_ADDRESSES = NsdServiceInfo().apply {
+    hostname = "TestHost"
+    hostAddresses = listOf() // no addresses of its own (e.g. set by a host registration)
+    serviceType = "_testservice._tcp"
+    serviceName = "TestService"
+    port = TEST_PORT
+}
+
 @RunWith(DevSdkIgnoreRunner::class)
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
 class MdnsRecordRepositoryTest {
@@ -1676,6 +1689,127 @@
         assertEquals(0, reply.additionalAnswers.size)
         assertEquals(knownAnswers, reply.knownAnswers)
     }
+
+    @Test
+    fun testRestartProbingForHostname() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        repository.initWithService(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1,
+                setOf(TEST_SUBTYPE, TEST_SUBTYPE2))
+        repository.addService(TEST_SERVICE_CUSTOM_HOST_ID_1,
+                TEST_SERVICE_CUSTOM_HOST_NO_ADDRESSES, null)
+        repository.setServiceProbing(TEST_SERVICE_CUSTOM_HOST_ID_1) // keep it in probing state
+        repository.removeService(TEST_CUSTOM_HOST_ID_1) // remove the custom host registration
+
+        val probingInfos = repository.restartProbingForHostname("TestHost")
+
+        assertEquals(1, probingInfos.size)
+        val probingInfo = probingInfos.get(0)
+        assertEquals(TEST_SERVICE_CUSTOM_HOST_ID_1, probingInfo.serviceId)
+        val packet = probingInfo.getPacket(0)
+        assertEquals(0, packet.transactionId)
+        assertEquals(MdnsConstants.FLAGS_QUERY, packet.flags)
+        assertEquals(0, packet.answers.size)
+        assertEquals(0, packet.additionalRecords.size)
+        assertEquals(1, packet.questions.size)
+        val serviceName = arrayOf("TestService", "_testservice", "_tcp", "local")
+        assertEquals(MdnsAnyRecord(serviceName, false /* unicast */), packet.questions[0])
+        assertThat(packet.authorityRecords).containsExactly( // SRV only: no host address records
+                MdnsServiceRecord(
+                        serviceName,
+                        0L /* receiptTimeMillis */,
+                        false /* cacheFlush */,
+                        SHORT_TTL /* ttlMillis */,
+                        0 /* servicePriority */,
+                        0 /* serviceWeight */,
+                        TEST_PORT,
+                        TEST_CUSTOM_HOST_1_NAME))
+    }
+
+    @Test
+    fun testRestartAnnouncingForHostname() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        repository.initWithService(TEST_CUSTOM_HOST_ID_1, TEST_CUSTOM_HOST_1,
+                setOf(TEST_SUBTYPE, TEST_SUBTYPE2))
+        repository.addServiceAndFinishProbing(TEST_SERVICE_CUSTOM_HOST_ID_1,
+                TEST_SERVICE_CUSTOM_HOST_NO_ADDRESSES)
+        repository.removeService(TEST_CUSTOM_HOST_ID_1) // remove the custom host registration
+
+        val announcementInfos = repository.restartAnnouncingForHostname("TestHost")
+
+        assertEquals(1, announcementInfos.size)
+        val announcementInfo = announcementInfos.get(0)
+        assertEquals(TEST_SERVICE_CUSTOM_HOST_ID_1, announcementInfo.serviceId)
+        val packet = announcementInfo.getPacket(0)
+        assertEquals(0, packet.transactionId)
+        assertEquals(0x8400 /* response, authoritative */, packet.flags)
+        assertEquals(0, packet.questions.size)
+        assertEquals(0, packet.authorityRecords.size)
+        val serviceName = arrayOf("TestService", "_testservice", "_tcp", "local")
+        val serviceType = arrayOf("_testservice", "_tcp", "local")
+        val v4AddrRev = getReverseDnsAddress(TEST_ADDRESSES[0].address)
+        val v6Addr1Rev = getReverseDnsAddress(TEST_ADDRESSES[1].address)
+        val v6Addr2Rev = getReverseDnsAddress(TEST_ADDRESSES[2].address)
+        assertThat(packet.answers).containsExactly( // no A/AAAA records for the custom host
+                MdnsPointerRecord(
+                        serviceType,
+                        0L /* receiptTimeMillis */,
+                        // Not a unique name owned by the announcer, so cacheFlush=false
+                        false /* cacheFlush */,
+                        4500000L /* ttlMillis */,
+                        serviceName),
+                MdnsServiceRecord(
+                        serviceName,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        0 /* servicePriority */,
+                        0 /* serviceWeight */,
+                        TEST_PORT /* servicePort */,
+                        TEST_CUSTOM_HOST_1_NAME),
+                MdnsTextRecord(
+                        serviceName,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        4500000L /* ttlMillis */,
+                        emptyList() /* entries */),
+                MdnsPointerRecord(
+                        arrayOf("_services", "_dns-sd", "_udp", "local"),
+                        0L /* receiptTimeMillis */,
+                        false /* cacheFlush */,
+                        4500000L /* ttlMillis */,
+                        serviceType))
+        assertThat(packet.additionalRecords).containsExactly(
+                MdnsNsecRecord(v4AddrRev,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        v4AddrRev,
+                        intArrayOf(TYPE_PTR)),
+                MdnsNsecRecord(TEST_HOSTNAME,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        TEST_HOSTNAME,
+                        intArrayOf(TYPE_A, TYPE_AAAA)),
+                MdnsNsecRecord(v6Addr1Rev,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        v6Addr1Rev,
+                        intArrayOf(TYPE_PTR)),
+                MdnsNsecRecord(v6Addr2Rev,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        120000L /* ttlMillis */,
+                        v6Addr2Rev,
+                        intArrayOf(TYPE_PTR)),
+                MdnsNsecRecord(serviceName,
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        4500000L /* ttlMillis */,
+                        serviceName,
+                        intArrayOf(TYPE_TXT, TYPE_SRV)))
+    }
 }
 
 private fun MdnsRecordRepository.initWithService(
diff --git a/thread/service/java/com/android/server/thread/NsdPublisher.java b/thread/service/java/com/android/server/thread/NsdPublisher.java
index 3c7a72b..72e3980 100644
--- a/thread/service/java/com/android/server/thread/NsdPublisher.java
+++ b/thread/service/java/com/android/server/thread/NsdPublisher.java
@@ -21,6 +21,7 @@
 import android.annotation.NonNull;
 import android.content.Context;
 import android.net.InetAddresses;
+import android.net.nsd.DiscoveryRequest;
 import android.net.nsd.NsdManager;
 import android.net.nsd.NsdServiceInfo;
 import android.os.Handler;
@@ -31,15 +32,20 @@
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.server.thread.openthread.DnsTxtAttribute;
+import com.android.server.thread.openthread.INsdDiscoverServiceCallback;
 import com.android.server.thread.openthread.INsdPublisher;
+import com.android.server.thread.openthread.INsdResolveServiceCallback;
 import com.android.server.thread.openthread.INsdStatusReceiver;
 
+import java.net.Inet6Address;
 import java.net.InetAddress;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Deque;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.Executor;
 
 /**
@@ -66,6 +72,8 @@
     private final Handler mHandler;
     private final Executor mExecutor;
     private final SparseArray<RegistrationListener> mRegistrationListeners = new SparseArray<>(0);
+    private final SparseArray<DiscoveryListener> mDiscoveryListeners = new SparseArray<>(0);
+    private final SparseArray<ServiceInfoListener> mServiceInfoListeners = new SparseArray<>(0);
     private final Deque<Runnable> mRegistrationJobs = new ArrayDeque<>();
 
     @VisibleForTesting
@@ -197,6 +205,110 @@
         mNsdManager.unregisterService(registrationListener);
     }
 
+    @Override
+    public void discoverService(String type, INsdDiscoverServiceCallback callback, int listenerId) {
+        // NsdManager is only touched from mHandler's thread; see discoverServiceInternal().
+        mHandler.post(() -> discoverServiceInternal(type, callback, listenerId));
+    }
+
+    /**
+     * Registers a {@link DiscoveryListener} under {@code listenerId} and starts an NsdManager
+     * discovery of {@code type} with no network set. Must run on the handler thread.
+     */
+    private void discoverServiceInternal(
+            String type, INsdDiscoverServiceCallback callback, int listenerId) {
+        checkOnHandlerThread();
+        Log.i(TAG, "Discovering services." + " Listener ID: " + listenerId
+                + ", service type: " + type);
+
+        final DiscoveryListener listener = new DiscoveryListener(listenerId, type, callback);
+        mDiscoveryListeners.append(listenerId, listener);
+        final DiscoveryRequest request =
+                new DiscoveryRequest.Builder(type).setNetwork(null).build();
+        mNsdManager.discoverServices(request, mExecutor, listener);
+    }
+
+    @Override
+    public void stopServiceDiscovery(int listenerId) {
+        mHandler.post(() -> stopServiceDiscoveryInternal(listenerId));
+    }
+
+    private void stopServiceDiscoveryInternal(int listenerId) {
+        checkOnHandlerThread();
+        DiscoveryListener listener = mDiscoveryListeners.get(listenerId);
+        if (listener == null) {
+            Log.w(TAG, "Failed to stop service discovery. Listener ID " + listenerId
+                    + ". The listener is null.");
+            return;
+        }
+        Log.i(TAG, "Stopping service discovery. Listener: " + listener);
+        try {
+            mNsdManager.stopServiceDiscovery(listener);
+        } catch (IllegalArgumentException e) {
+            mDiscoveryListeners.remove(listenerId); // NsdManager no longer tracks it: stale
+            Log.w(TAG, "Listener is not registered, cannot stop discovery: " + listener, e);
+        }
+    }
+
+    @Override
+    public void resolveService(
+            String name, String type, INsdResolveServiceCallback callback, int listenerId) {
+        // NsdManager is only touched from mHandler's thread; see resolveServiceInternal().
+        mHandler.post(() -> resolveServiceInternal(name, type, callback, listenerId));
+    }
+
+    /**
+     * Registers a {@link ServiceInfoListener} under {@code listenerId} and registers it with
+     * NsdManager for the service {@code name} of {@code type}, with no network set.
+     * Must run on the handler thread.
+     */
+    private void resolveServiceInternal(
+            String name, String type, INsdResolveServiceCallback callback, int listenerId) {
+        checkOnHandlerThread();
+
+        final NsdServiceInfo target = new NsdServiceInfo();
+        target.setServiceName(name);
+        target.setServiceType(type);
+        target.setNetwork(null);
+        Log.i(TAG, "Resolving service." + " Listener ID: " + listenerId
+                + ", service name: " + name + ", service type: " + type);
+
+        final ServiceInfoListener listener =
+                new ServiceInfoListener(target, listenerId, callback);
+        mServiceInfoListeners.append(listenerId, listener);
+        mNsdManager.registerServiceInfoCallback(target, mExecutor, listener);
+    }
+
+    @Override
+    public void stopServiceResolution(int listenerId) {
+        mHandler.post(() -> stopServiceResolutionInternal(listenerId));
+    }
+
+    private void stopServiceResolutionInternal(int listenerId) {
+        checkOnHandlerThread();
+
+        ServiceInfoListener listener = mServiceInfoListeners.get(listenerId);
+        if (listener == null) {
+            Log.w(TAG, "Failed to stop service resolution. Listener ID: " + listenerId
+                    + ". The listener is null.");
+            return;
+        }
+
+        Log.i(TAG, "Stopping service resolution. Listener: " + listener);
+
+        try {
+            mNsdManager.unregisterServiceInfoCallback(listener);
+        } catch (IllegalArgumentException e) {
+            // NsdManager no longer tracks this listener, so presumably no callback will clean up
+            // our entry either; drop it here instead of keeping a stale listener forever.
+            mServiceInfoListeners.remove(listenerId);
+            Log.w(
+                    TAG,
+                    "Failed to stop the service resolution because it's already stopped. Listener: "
+                            + listener);
+        }
+    }
+
     private void checkOnHandlerThread() {
         if (mHandler.getLooper().getThread() != Thread.currentThread()) {
             throw new IllegalStateException(
@@ -368,4 +480,166 @@
             popAndRunNext();
         }
     }
+
+    /**
+     * Relays {@link NsdManager} discovery events for a single service type to the client's
+     * {@link INsdDiscoverServiceCallback}.
+     *
+     * <p>The listener removes its own entry from {@code mDiscoveryListeners} when discovery stops
+     * or when starting/stopping discovery fails.
+     */
+    private final class DiscoveryListener implements NsdManager.DiscoveryListener {
+        private final int mListenerId;
+        private final String mType;
+        private final INsdDiscoverServiceCallback mDiscoverServiceCallback;
+
+        DiscoveryListener(
+                int listenerId,
+                @NonNull String type,
+                @NonNull INsdDiscoverServiceCallback discoverServiceCallback) {
+            this.mListenerId = listenerId;
+            this.mType = type;
+            this.mDiscoverServiceCallback = discoverServiceCallback;
+        }
+
+        @Override
+        public void onStartDiscoveryFailed(String serviceType, int errorCode) {
+            logDiscoveryError("Failed to start service discovery.", errorCode);
+            mDiscoveryListeners.remove(mListenerId);
+        }
+
+        @Override
+        public void onStopDiscoveryFailed(String serviceType, int errorCode) {
+            logDiscoveryError("Failed to stop service discovery.", errorCode);
+            mDiscoveryListeners.remove(mListenerId);
+        }
+
+        @Override
+        public void onDiscoveryStarted(String serviceType) {
+            Log.i(TAG, "Started service discovery. Listener: " + this);
+        }
+
+        @Override
+        public void onDiscoveryStopped(String serviceType) {
+            Log.i(TAG, "Stopped service discovery. Listener: " + this);
+            mDiscoveryListeners.remove(mListenerId);
+        }
+
+        @Override
+        public void onServiceFound(NsdServiceInfo serviceInfo) {
+            Log.i(TAG, "Found service: " + serviceInfo);
+            reportToClient(serviceInfo, true /* isFound */);
+        }
+
+        @Override
+        public void onServiceLost(NsdServiceInfo serviceInfo) {
+            Log.i(TAG, "Lost service: " + serviceInfo);
+            reportToClient(serviceInfo, false /* isFound */);
+        }
+
+        @Override
+        public String toString() {
+            return "ID: " + mListenerId + ", type: " + mType;
+        }
+
+        /** Logs a discovery error together with its error code and this listener. */
+        private void logDiscoveryError(String what, int errorCode) {
+            Log.e(TAG, what + " Error code: " + errorCode + ", listener: " + this);
+        }
+
+        /**
+         * Tells the client that the instance in {@code serviceInfo} was found or lost, reporting
+         * it under the type this listener was created for.
+         */
+        private void reportToClient(NsdServiceInfo serviceInfo, boolean isFound) {
+            final String instanceName = serviceInfo.getServiceName();
+            try {
+                mDiscoverServiceCallback.onServiceDiscovered(instanceName, mType, isFound);
+            } catch (RemoteException ignored) {
+                // Nothing to do when the client is dead.
+            }
+        }
+    }
+
+    /**
+     * Relays {@link NsdManager} service-info updates for a single service instance to the
+     * client's {@link INsdResolveServiceCallback}.
+     *
+     * <p>The listener removes its own entry from {@code mServiceInfoListeners} once NsdManager
+     * reports that the callback was unregistered or failed to register.
+     */
+    private final class ServiceInfoListener implements NsdManager.ServiceInfoCallback {
+        private final String mName;
+        private final String mType;
+        private final INsdResolveServiceCallback mResolveServiceCallback;
+        private final int mListenerId;
+
+        ServiceInfoListener(
+                @NonNull NsdServiceInfo serviceInfo,
+                int listenerId,
+                @NonNull INsdResolveServiceCallback resolveServiceCallback) {
+            mName = serviceInfo.getServiceName();
+            mType = serviceInfo.getServiceType();
+            mListenerId = listenerId;
+            mResolveServiceCallback = resolveServiceCallback;
+        }
+
+        @Override
+        public void onServiceInfoCallbackRegistrationFailed(int errorCode) {
+            Log.e(
+                    TAG,
+                    "Failed to register service info callback."
+                            + " Listener ID: "
+                            + mListenerId
+                            + ", error: "
+                            + errorCode
+                            + ", service name: "
+                            + mName
+                            + ", service type: "
+                            + mType);
+            // A callback that never got registered is never reported as unregistered, so release
+            // the entry here (DiscoveryListener does the same in onStartDiscoveryFailed()).
+            mServiceInfoListeners.remove(mListenerId);
+        }
+
+        /**
+         * Forwards the resolved {@code serviceInfo} to the client. Only IPv6 host addresses are
+         * forwarded, and TXT attribute values are copied before being handed over.
+         */
+        @Override
+        public void onServiceUpdated(@NonNull NsdServiceInfo serviceInfo) {
+            Log.i(
+                    TAG,
+                    "Service is resolved. "
+                            + " Listener ID: "
+                            + mListenerId
+                            + ", serviceInfo: "
+                            + serviceInfo);
+            List<String> addresses = new ArrayList<>();
+            for (InetAddress address : serviceInfo.getHostAddresses()) {
+                if (address instanceof Inet6Address) {
+                    addresses.add(address.getHostAddress());
+                }
+            }
+            List<DnsTxtAttribute> txtList = new ArrayList<>();
+            for (Map.Entry<String, byte[]> entry : serviceInfo.getAttributes().entrySet()) {
+                DnsTxtAttribute attribute = new DnsTxtAttribute();
+                attribute.name = entry.getKey();
+                // NsdServiceInfo allows null attribute values (a key without a value); send an
+                // empty value instead of throwing a NullPointerException from Arrays.copyOf().
+                final byte[] value = entry.getValue();
+                attribute.value =
+                        (value == null) ? new byte[0] : Arrays.copyOf(value, value.length);
+                txtList.add(attribute);
+            }
+            // TODO: b/329018320 - Use the serviceInfo.getExpirationTime to derive TTL.
+            int ttlSeconds = 10;
+            try {
+                mResolveServiceCallback.onServiceResolved(
+                        serviceInfo.getHostname(),
+                        serviceInfo.getServiceName(),
+                        serviceInfo.getServiceType(),
+                        serviceInfo.getPort(),
+                        addresses,
+                        txtList,
+                        ttlSeconds);
+            } catch (RemoteException e) {
+                // do nothing if the client is dead
+            }
+        }
+
+        @Override
+        public void onServiceLost() {}
+
+        @Override
+        public void onServiceInfoCallbackUnregistered() {
+            Log.i(TAG, "The service info callback is unregistered. Listener: " + this);
+            mServiceInfoListeners.remove(mListenerId);
+        }
+
+        @Override
+        public String toString() {
+            return "ID: " + mListenerId + ", service name: " + mName + ", service type: " + mType;
+        }
+    }
 }
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
index 5f4627f..155296d 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
@@ -346,8 +346,8 @@
                 mTunIfController.getTunFd(),
                 isEnabled(),
                 mNsdPublisher,
-                getMeshcopTxtAttributes(mResources.get()));
-        otDaemon.registerStateCallback(mOtDaemonCallbackProxy, -1);
+                getMeshcopTxtAttributes(mResources.get()),
+                mOtDaemonCallbackProxy);
         otDaemon.asBinder().linkToDeath(() -> mHandler.post(this::onOtDaemonDied), 0);
         mOtDaemon = otDaemon;
         return mOtDaemon;
@@ -1358,8 +1358,11 @@
                 return;
             }
 
+            final int deviceRole = mState.deviceRole;
+            mState = null;
+
             // If this device is already STOPPED or DETACHED, do nothing
-            if (!ThreadNetworkController.isAttached(mState.deviceRole)) {
+            if (!ThreadNetworkController.isAttached(deviceRole)) {
                 return;
             }
 
diff --git a/thread/service/java/com/android/server/thread/ThreadPersistentSettings.java b/thread/service/java/com/android/server/thread/ThreadPersistentSettings.java
index 5cb53fe..923f002 100644
--- a/thread/service/java/com/android/server/thread/ThreadPersistentSettings.java
+++ b/thread/service/java/com/android/server/thread/ThreadPersistentSettings.java
@@ -61,7 +61,7 @@
 
     /******** Thread persistent setting keys ***************/
     /** Stores the Thread feature toggle state, true for enabled and false for disabled. */
-    public static final Key<Boolean> THREAD_ENABLED = new Key<>("Thread_enabled", true);
+    public static final Key<Boolean> THREAD_ENABLED = new Key<>("thread_enabled", true);
 
     /******** Thread persistent setting keys ***************/
 
diff --git a/thread/tests/integration/src/android/net/thread/ServiceDiscoveryTest.java b/thread/tests/integration/src/android/net/thread/ServiceDiscoveryTest.java
index 39a1671..491331c 100644
--- a/thread/tests/integration/src/android/net/thread/ServiceDiscoveryTest.java
+++ b/thread/tests/integration/src/android/net/thread/ServiceDiscoveryTest.java
@@ -17,6 +17,7 @@
 package android.net.thread;
 
 import static android.net.InetAddresses.parseNumericAddress;
+import static android.net.nsd.NsdManager.PROTOCOL_DNS_SD;
 import static android.net.thread.utils.IntegrationTestUtils.JOIN_TIMEOUT;
 import static android.net.thread.utils.IntegrationTestUtils.SERVICE_DISCOVERY_TIMEOUT;
 import static android.net.thread.utils.IntegrationTestUtils.discoverForServiceLost;
@@ -37,6 +38,7 @@
 import android.net.nsd.NsdManager;
 import android.net.nsd.NsdServiceInfo;
 import android.net.thread.utils.FullThreadDevice;
+import android.net.thread.utils.OtDaemonController;
 import android.net.thread.utils.TapTestNetworkTracker;
 import android.net.thread.utils.ThreadFeatureCheckerRule;
 import android.net.thread.utils.ThreadFeatureCheckerRule.RequiresSimulationThreadDevice;
@@ -65,6 +67,7 @@
 import java.util.Map;
 import java.util.Random;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeoutException;
 
 /** Integration test cases for Service Discovery feature. */
@@ -96,15 +99,15 @@
     private final Context mContext = ApplicationProvider.getApplicationContext();
     private final ThreadNetworkControllerWrapper mController =
             ThreadNetworkControllerWrapper.newInstance(mContext);
-
+    private final OtDaemonController mOtCtl = new OtDaemonController();
     private HandlerThread mHandlerThread;
     private NsdManager mNsdManager;
     private TapTestNetworkTracker mTestNetworkTracker;
     private List<FullThreadDevice> mFtds;
+    private List<RegistrationListener> mRegistrationListeners = new ArrayList<>();
 
     @Before
     public void setUp() throws Exception {
-
         mController.joinAndWait(DEFAULT_DATASET);
         mNsdManager = mContext.getSystemService(NsdManager.class);
 
@@ -127,6 +130,9 @@
 
     @After
     public void tearDown() throws Exception {
+        for (RegistrationListener listener : mRegistrationListeners) {
+            unregisterService(listener);
+        }
         for (FullThreadDevice ftd : mFtds) {
             // Clear registered SRP hosts and services
             if (ftd.isSrpHostRegistered()) {
@@ -314,6 +320,176 @@
         assertThat(txtMap.get("mn")).isEqualTo("Thread Border Router".getBytes(UTF_8));
     }
 
+    @Test
+    public void discoveryProxy_multipleClientsBrowseAndResolveServiceOverMdns() throws Exception {
+        /*
+         * <pre>
+         * Topology:
+         *                    Thread
+         *  Border Router -------------- Full Thread device
+         *  (Cuttlefish)
+         * </pre>
+         */
+
+        RegistrationListener listener = new RegistrationListener();
+        NsdServiceInfo info = new NsdServiceInfo();
+        info.setServiceType("_testservice._tcp");
+        info.setServiceName("test-service");
+        info.setPort(12345);
+        info.setHostname("testhost");
+        info.setHostAddresses(List.of(parseNumericAddress("2001::1")));
+        info.setAttribute("key1", bytes(0x01, 0x02));
+        info.setAttribute("key2", bytes(0x03));
+        registerService(info, listener);
+        // Tracked only after a successful registration so that tearDown() unregisters it.
+        mRegistrationListeners.add(listener);
+        for (int i = 0; i < NUM_FTD; ++i) {
+            FullThreadDevice ftd = mFtds.get(i);
+            ftd.joinNetwork(DEFAULT_DATASET);
+            ftd.waitForStateAnyOf(List.of("router", "child"), JOIN_TIMEOUT);
+            ftd.setDnsServerAddress(mOtCtl.getMlEid().getHostAddress());
+        }
+
+        // Every FTD browses and then resolves on its own thread so the clients query
+        // concurrently. Results come back through futures: a failure on a worker thread then fails
+        // the test with its real cause, instead of being lost on that thread and surfacing later
+        // as a NullPointerException on a missing result.
+        final List<CompletableFuture<NsdServiceInfo>> browsedServices = new ArrayList<>();
+        final List<CompletableFuture<NsdServiceInfo>> resolvedServices = new ArrayList<>();
+        final List<Thread> threads = new ArrayList<>();
+        for (int i = 0; i < NUM_FTD; ++i) {
+            final FullThreadDevice ftd = mFtds.get(i);
+            final CompletableFuture<NsdServiceInfo> browsed = new CompletableFuture<>();
+            final CompletableFuture<NsdServiceInfo> resolved = new CompletableFuture<>();
+            browsedServices.add(browsed);
+            resolvedServices.add(resolved);
+            Runnable task =
+                    () -> {
+                        try {
+                            browsed.complete(
+                                    ftd.browseService("_testservice._tcp.default.service.arpa."));
+                            resolved.complete(
+                                    ftd.resolveService(
+                                            "test-service",
+                                            "_testservice._tcp.default.service.arpa."));
+                        } catch (Throwable t) {
+                            // No-op for a future that already completed normally.
+                            browsed.completeExceptionally(t);
+                            resolved.completeExceptionally(t);
+                        }
+                    };
+            threads.add(new Thread(task));
+        }
+        for (Thread thread : threads) {
+            thread.start();
+        }
+        for (Thread thread : threads) {
+            thread.join();
+        }
+
+        for (int i = 0; i < NUM_FTD; ++i) {
+            // All workers are joined, so get() does not block; it rethrows a worker failure
+            // wrapped in an ExecutionException.
+            NsdServiceInfo browsedService = browsedServices.get(i).get();
+            assertThat(browsedService.getServiceName()).isEqualTo("test-service");
+            assertThat(browsedService.getPort()).isEqualTo(12345);
+
+            NsdServiceInfo resolvedService = resolvedServices.get(i).get();
+            assertThat(resolvedService.getServiceName()).isEqualTo("test-service");
+            assertThat(resolvedService.getPort()).isEqualTo(12345);
+            assertThat(resolvedService.getHostname()).isEqualTo("testhost.default.service.arpa.");
+            assertThat(resolvedService.getHostAddresses())
+                    .containsExactly(parseNumericAddress("2001::1"));
+            assertThat(resolvedService.getAttributes())
+                    .comparingValuesUsing(BYTE_ARRAY_EQUALITY)
+                    .containsExactly("key1", bytes(0x01, 0x02), "key2", bytes(3));
+        }
+    }
+
+    @Test
+    public void discoveryProxy_browseAndResolveServiceAtSrpServer() throws Exception {
+        /*
+         * <pre>
+         * Topology:
+         *                    Thread
+         *  Border Router -------+------ SRP client
+         *  (Cuttlefish)         |
+         *                       +------ DNS client
+         *
+         * </pre>
+         */
+        final String hostAddress = "2001::1";
+        final String fullHostname = "my-host.default.service.arpa.";
+        final String fullServiceType = "_test._udp.default.service.arpa.";
+
+        // The SRP client registers "my-service" on the border router's SRP server.
+        FullThreadDevice srpClient = mFtds.get(0);
+        srpClient.joinNetwork(DEFAULT_DATASET);
+        srpClient.waitForStateAnyOf(List.of("router", "child"), JOIN_TIMEOUT);
+        srpClient.setSrpHostname("my-host");
+        srpClient.setSrpHostAddresses(List.of((Inet6Address) parseNumericAddress(hostAddress)));
+        srpClient.addSrpService(
+                "my-service",
+                "_test._udp",
+                List.of("_sub1"),
+                12345 /* port */,
+                Map.of("key1", bytes(0x01, 0x02), "key2", bytes(0x03)));
+
+        // The DNS client queries the border router (its ML-EID) as its DNS server.
+        FullThreadDevice dnsClient = mFtds.get(1);
+        dnsClient.joinNetwork(DEFAULT_DATASET);
+        dnsClient.waitForStateAnyOf(List.of("router", "child"), JOIN_TIMEOUT);
+        dnsClient.setDnsServerAddress(mOtCtl.getMlEid().getHostAddress());
+
+        final NsdServiceInfo browsed = dnsClient.browseService(fullServiceType);
+        assertThat(browsed.getServiceName()).isEqualTo("my-service");
+        assertThat(browsed.getPort()).isEqualTo(12345);
+        assertThat(browsed.getHostname()).isEqualTo(fullHostname);
+        assertThat(browsed.getHostAddresses()).containsExactly(parseNumericAddress(hostAddress));
+        assertThat(browsed.getAttributes())
+                .comparingValuesUsing(BYTE_ARRAY_EQUALITY)
+                .containsExactly("key1", bytes(0x01, 0x02), "key2", bytes(3));
+
+        final NsdServiceInfo resolved = dnsClient.resolveService("my-service", fullServiceType);
+        assertThat(resolved.getServiceName()).isEqualTo("my-service");
+        assertThat(resolved.getPort()).isEqualTo(12345);
+        assertThat(resolved.getHostname()).isEqualTo(fullHostname);
+        assertThat(resolved.getHostAddresses()).containsExactly(parseNumericAddress(hostAddress));
+        assertThat(resolved.getAttributes())
+                .comparingValuesUsing(BYTE_ARRAY_EQUALITY)
+                .containsExactly("key1", bytes(0x01, 0x02), "key2", bytes(3));
+    }
+
+    /**
+     * Registers {@code serviceInfo} with NsdManager over DNS-SD and waits for {@code listener} to
+     * confirm the registration.
+     */
+    private void registerService(NsdServiceInfo serviceInfo, RegistrationListener listener)
+            throws InterruptedException, ExecutionException, TimeoutException {
+        final NsdManager nsdManager = mNsdManager;
+        nsdManager.registerService(serviceInfo, PROTOCOL_DNS_SD, listener);
+        listener.waitForRegistered();
+    }
+
+    /** Unregisters the service owned by {@code listener} and waits for it to confirm. */
+    private void unregisterService(RegistrationListener listener)
+            throws InterruptedException, ExecutionException, TimeoutException {
+        final NsdManager nsdManager = mNsdManager;
+        nsdManager.unregisterService(listener);
+        listener.waitForUnregistered();
+    }
+
+    /**
+     * {@link NsdManager.RegistrationListener} that exposes registration and unregistration
+     * results as futures so tests can wait for them.
+     *
+     * <p>Failure callbacks complete the corresponding future exceptionally, so a failed
+     * (un)registration fails the waiting test right away with the NSD error code instead of
+     * waiting for the full timeout.
+     */
+    private static class RegistrationListener implements NsdManager.RegistrationListener {
+        private final CompletableFuture<Void> mRegisteredFuture = new CompletableFuture<>();
+        private final CompletableFuture<Void> mUnregisteredFuture = new CompletableFuture<>();
+
+        RegistrationListener() {}
+
+        @Override
+        public void onRegistrationFailed(NsdServiceInfo serviceInfo, int errorCode) {
+            mRegisteredFuture.completeExceptionally(
+                    new IllegalStateException(
+                            "Failed to register service: "
+                                    + serviceInfo
+                                    + ", error code: "
+                                    + errorCode));
+        }
+
+        @Override
+        public void onUnregistrationFailed(NsdServiceInfo serviceInfo, int errorCode) {
+            mUnregisteredFuture.completeExceptionally(
+                    new IllegalStateException(
+                            "Failed to unregister service: "
+                                    + serviceInfo
+                                    + ", error code: "
+                                    + errorCode));
+        }
+
+        @Override
+        public void onServiceRegistered(NsdServiceInfo serviceInfo) {
+            mRegisteredFuture.complete(null);
+        }
+
+        @Override
+        public void onServiceUnregistered(NsdServiceInfo serviceInfo) {
+            mUnregisteredFuture.complete(null);
+        }
+
+        /**
+         * Waits up to {@code SERVICE_DISCOVERY_TIMEOUT} for the registration result.
+         *
+         * @throws ExecutionException if registration failed (the cause carries the error code)
+         * @throws TimeoutException if no result arrived in time
+         */
+        public void waitForRegistered()
+                throws InterruptedException, ExecutionException, TimeoutException {
+            mRegisteredFuture.get(SERVICE_DISCOVERY_TIMEOUT.toMillis(), MILLISECONDS);
+        }
+
+        /**
+         * Waits up to {@code SERVICE_DISCOVERY_TIMEOUT} for the unregistration result.
+         *
+         * @throws ExecutionException if unregistration failed (the cause carries the error code)
+         * @throws TimeoutException if no result arrived in time
+         */
+        public void waitForUnregistered()
+                throws InterruptedException, ExecutionException, TimeoutException {
+            mUnregisteredFuture.get(SERVICE_DISCOVERY_TIMEOUT.toMillis(), MILLISECONDS);
+        }
+    }
+
     private static byte[] bytes(int... byteInts) {
         byte[] bytes = new byte[byteInts.length];
         for (int i = 0; i < byteInts.length; ++i) {
diff --git a/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java b/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java
index 4a006cf..bfded1d 100644
--- a/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java
+++ b/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java
@@ -97,13 +97,16 @@
     }
 
     @Test
-    public void otDaemonRestart_JoinedNetworkAndStopped_autoRejoined() throws Exception {
+    public void otDaemonRestart_JoinedNetworkAndStopped_autoRejoinedAndTunIfStateConsistent()
+            throws Exception {
         mController.joinAndWait(DEFAULT_DATASET);
 
         runShellCommand("stop ot-daemon");
 
         mController.waitForRole(DEVICE_ROLE_DETACHED, CALLBACK_TIMEOUT);
         mController.waitForRole(DEVICE_ROLE_LEADER, RESTART_JOIN_TIMEOUT);
+        assertThat(mOtCtl.isInterfaceUp()).isTrue();
+        assertThat(runShellCommand("ifconfig thread-wpan")).contains("UP POINTOPOINT RUNNING");
     }
 
     @Test
@@ -120,8 +123,8 @@
         mController.joinAndWait(DEFAULT_DATASET);
 
         mOtCtl.factoryReset();
-        String ifconfig = runShellCommand("ifconfig thread-wpan");
 
+        String ifconfig = runShellCommand("ifconfig thread-wpan");
         assertThat(ifconfig).doesNotContain("inet6 addr");
     }
 
diff --git a/thread/tests/integration/src/android/net/thread/ThreadNetworkControllerTest.java b/thread/tests/integration/src/android/net/thread/ThreadNetworkControllerTest.java
index 496ec9f..ba04348 100644
--- a/thread/tests/integration/src/android/net/thread/ThreadNetworkControllerTest.java
+++ b/thread/tests/integration/src/android/net/thread/ThreadNetworkControllerTest.java
@@ -27,6 +27,7 @@
 import static org.junit.Assert.assertThrows;
 
 import android.content.Context;
+import android.net.thread.utils.ThreadFeatureCheckerRule;
 import android.net.thread.utils.ThreadFeatureCheckerRule.RequiresThreadFeature;
 import android.os.OutcomeReceiver;
 import android.util.SparseIntArray;
@@ -37,6 +38,7 @@
 
 import org.junit.After;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
@@ -64,6 +66,8 @@
                 }
             };
 
+    @Rule public final ThreadFeatureCheckerRule mThreadRule = new ThreadFeatureCheckerRule();
+
     private final Context mContext = ApplicationProvider.getApplicationContext();
     private ExecutorService mExecutor;
     private ThreadNetworkController mController;
diff --git a/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java b/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java
index 600b662..f7bb9ff 100644
--- a/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java
+++ b/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java
@@ -24,6 +24,7 @@
 
 import android.net.InetAddresses;
 import android.net.IpPrefix;
+import android.net.nsd.NsdServiceInfo;
 import android.net.thread.ActiveOperationalDataset;
 
 import com.google.errorprone.annotations.FormatMethod;
@@ -34,6 +35,7 @@
 import java.io.InputStreamReader;
 import java.io.OutputStreamWriter;
 import java.net.Inet6Address;
+import java.net.InetAddress;
 import java.nio.charset.StandardCharsets;
 import java.time.Duration;
 import java.util.ArrayList;
@@ -327,6 +329,55 @@
         return false;
     }
 
+    /**
+     * Sets the DNS server address used by this device's DNS client.
+     *
+     * <p>Runs the CLI command {@code dns config <address>}.
+     */
+    public void setDnsServerAddress(String address) {
+        // Like resolveService(), pass the value as a format argument instead of concatenating it
+        // into the command format string, so a '%' in it can't be read as a format specifier.
+        executeCommand("dns config %s", address);
+    }
+
+    /**
+     * Browses for {@code serviceType} with the {@code dns browse} CLI command and returns the
+     * first service instance in the response.
+     *
+     * @throws IllegalStateException if the CLI output has fewer lines than the layout parsed below
+     */
+    public NsdServiceInfo browseService(String serviceType) {
+        // CLI output:
+        // DNS browse response for _testservice._tcp.default.service.arpa.
+        // test-service
+        //    Port:12345, Priority:0, Weight:0, TTL:10
+        //    Host:testhost.default.service.arpa.
+        //    HostAddress:2001:0:0:0:0:0:0:1 TTL:10
+        //    TXT:[key1=0102, key2=03] TTL:10
+
+        // Pass the type as a format argument, consistent with resolveService().
+        List<String> lines = executeCommand("dns browse %s", serviceType);
+        // The parsing below reads fixed line offsets; report the raw output instead of an
+        // opaque IndexOutOfBoundsException when the response does not have that shape.
+        if (lines.size() < 6) {
+            throw new IllegalStateException(
+                    "Unexpected output of dns browse " + serviceType + ": " + lines);
+        }
+        NsdServiceInfo info = new NsdServiceInfo();
+        info.setServiceName(lines.get(1));
+        info.setServiceType(serviceType);
+        info.setPort(DnsServiceCliOutputParser.parsePort(lines.get(2)));
+        info.setHostname(DnsServiceCliOutputParser.parseHostname(lines.get(3)));
+        info.setHostAddresses(List.of(DnsServiceCliOutputParser.parseHostAddress(lines.get(4))));
+        DnsServiceCliOutputParser.parseTxtIntoServiceInfo(lines.get(5), info);
+
+        return info;
+    }
+
+    /**
+     * Resolves the instance {@code serviceName} of {@code serviceType} with the {@code dns
+     * service} CLI command and returns the parsed result.
+     *
+     * @throws IllegalStateException if the CLI output has fewer lines than the layout parsed below
+     */
+    public NsdServiceInfo resolveService(String serviceName, String serviceType) {
+        // CLI output (the first line is wrapped here for readability):
+        // DNS service resolution response for test-service for service
+        // _test._tcp.default.service.arpa.
+        // Port:12345, Priority:0, Weight:0, TTL:10
+        // Host:Android.default.service.arpa.
+        // HostAddress:2001:0:0:0:0:0:0:1 TTL:10
+        // TXT:[key1=0102, key2=03] TTL:10
+
+        List<String> lines = executeCommand("dns service %s %s", serviceName, serviceType);
+        // The parsing below reads fixed line offsets; report the raw output instead of an
+        // opaque IndexOutOfBoundsException when the response does not have that shape.
+        if (lines.size() < 5) {
+            throw new IllegalStateException(
+                    "Unexpected output of dns service "
+                            + serviceName
+                            + " "
+                            + serviceType
+                            + ": "
+                            + lines);
+        }
+        NsdServiceInfo info = new NsdServiceInfo();
+        info.setServiceName(serviceName);
+        info.setServiceType(serviceType);
+        info.setPort(DnsServiceCliOutputParser.parsePort(lines.get(1)));
+        info.setHostname(DnsServiceCliOutputParser.parseHostname(lines.get(2)));
+        info.setHostAddresses(List.of(DnsServiceCliOutputParser.parseHostAddress(lines.get(3))));
+        DnsServiceCliOutputParser.parseTxtIntoServiceInfo(lines.get(4), info);
+
+        return info;
+    }
+
     /** Runs the "factoryreset" command on the device. */
     public void factoryReset() {
         try {
@@ -454,4 +505,45 @@
     private static String toHexString(byte[] bytes) {
         return base16().encode(bytes);
     }
+
+    /** Parses fields out of the OpenThread CLI output of {@code dns browse} / {@code dns service}. */
+    private static final class DnsServiceCliOutputParser {
+        /**
+         * Returns a matcher positioned at the first match of {@code regex} in {@code input}.
+         *
+         * @throws IllegalArgumentException if {@code input} contains no match
+         */
+        private static Matcher firstMatchOf(String input, String regex) {
+            Matcher matcher = Pattern.compile(regex).matcher(input);
+            // Without this check a miss only surfaces later as an opaque "No match found".
+            if (!matcher.find()) {
+                throw new IllegalArgumentException(
+                        "No match for \"" + regex + "\" in CLI output line: " + input);
+            }
+            return matcher;
+        }
+
+        // Example: "Port:12345"
+        private static int parsePort(String line) {
+            return Integer.parseInt(firstMatchOf(line, "Port:(\\d+)").group(1));
+        }
+
+        // Example: "Host:Android.default.service.arpa."
+        private static String parseHostname(String line) {
+            return firstMatchOf(line, "Host:(.+)").group(1);
+        }
+
+        // Example: "HostAddress:2001:0:0:0:0:0:0:1"
+        private static InetAddress parseHostAddress(String line) {
+            return InetAddresses.parseNumericAddress(
+                    firstMatchOf(line, "HostAddress:([^ ]+)").group(1));
+        }
+
+        // Example: "TXT:[key1=0102, key2=03]". "TXT:[]" yields no attributes; an entry without
+        // '=' (presumably a valueless key) is stored with a null value.
+        private static void parseTxtIntoServiceInfo(String line, NsdServiceInfo serviceInfo) {
+            String txtString = firstMatchOf(line, "TXT:\\[([^\\]]*)\\]").group(1);
+            for (String txtEntry : txtString.split(",")) {
+                String entry = txtEntry.trim();
+                if (entry.isEmpty()) {
+                    continue;
+                }
+                int separator = entry.indexOf('=');
+                if (separator < 0) {
+                    serviceInfo.setAttribute(entry, (String) null);
+                } else {
+                    String name = entry.substring(0, separator);
+                    String hexValue = entry.substring(separator + 1);
+                    serviceInfo.setAttribute(name, decodeHex(hexValue));
+                }
+            }
+        }
+
+        /**
+         * Decodes a hex string such as "0a1B" (either letter case) into bytes.
+         *
+         * @throws IllegalArgumentException if {@code hex} has odd length or a non-hex character
+         */
+        private static byte[] decodeHex(String hex) {
+            if (hex.length() % 2 != 0) {
+                throw new IllegalArgumentException("Odd-length hex string: " + hex);
+            }
+            byte[] bytes = new byte[hex.length() / 2];
+            for (int i = 0; i < hex.length(); i += 2) {
+                // Character.digit handles 'a'-'f'/'A'-'F', which plain (c - '0') arithmetic gets
+                // wrong.
+                int high = Character.digit(hex.charAt(i), 16);
+                int low = Character.digit(hex.charAt(i + 1), 16);
+                if (high < 0 || low < 0) {
+                    throw new IllegalArgumentException("Invalid hex string: " + hex);
+                }
+                bytes[i / 2] = (byte) ((high << 4) | low);
+            }
+            return bytes;
+        }
+    }
 }
diff --git a/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java b/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
index 4a06fe8..ade0669 100644
--- a/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
+++ b/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
@@ -62,6 +62,18 @@
                 .toList();
     }
 
+    /**
+     * Returns {@code true} if the Thread interface is up.
+     *
+     * <p>Reads the first line of the {@code ifconfig} output, the same way {@link #getMlEid()}
+     * reads its value. Assumes that line holds the state ("up"/"down"); matching only that line
+     * avoids false positives from "up" appearing anywhere else in the output.
+     */
+    public boolean isInterfaceUp() {
+        String state = executeCommand("ifconfig").split("\n")[0].trim();
+        return state.equals("up");
+    }
+
+    /**
+     * Returns the mesh-local EID of the device, parsed from the first line of the {@code ipaddr
+     * mleid} CLI output.
+     */
+    public Inet6Address getMlEid() {
+        final String output = executeCommand("ipaddr mleid");
+        final String firstLine = output.split("\n")[0];
+        return (Inet6Address) InetAddresses.parseNumericAddress(firstLine.trim());
+    }
+
     public String executeCommand(String cmd) {
         return SystemUtil.runShellCommand(OT_CTL + " " + cmd);
     }
diff --git a/thread/tests/unit/src/com/android/server/thread/NsdPublisherTest.java b/thread/tests/unit/src/com/android/server/thread/NsdPublisherTest.java
index d860166..8886c73 100644
--- a/thread/tests/unit/src/com/android/server/thread/NsdPublisherTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/NsdPublisherTest.java
@@ -23,6 +23,7 @@
 
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.spy;
@@ -30,24 +31,30 @@
 import static org.mockito.Mockito.verify;
 
 import android.net.InetAddresses;
+import android.net.nsd.DiscoveryRequest;
 import android.net.nsd.NsdManager;
 import android.net.nsd.NsdServiceInfo;
 import android.os.Handler;
 import android.os.test.TestLooper;
 
 import com.android.server.thread.openthread.DnsTxtAttribute;
+import com.android.server.thread.openthread.INsdDiscoverServiceCallback;
+import com.android.server.thread.openthread.INsdResolveServiceCallback;
 import com.android.server.thread.openthread.INsdStatusReceiver;
 
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
+import org.mockito.ArgumentMatcher;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
 import java.net.InetAddress;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.Executor;
 
@@ -57,6 +64,8 @@
 
     @Mock private INsdStatusReceiver mRegistrationReceiver;
     @Mock private INsdStatusReceiver mUnregistrationReceiver;
+    @Mock private INsdDiscoverServiceCallback mDiscoverServiceCallback;
+    @Mock private INsdResolveServiceCallback mResolveServiceCallback;
 
     private TestLooper mTestLooper;
     private NsdPublisher mNsdPublisher;
@@ -469,6 +478,165 @@
     }
 
     @Test
+    public void discoverService_serviceDiscovered() throws Exception {
+        prepareTest();
+        final DiscoveryRequest expectedRequest =
+                new DiscoveryRequest.Builder(PROTOCOL_DNS_SD, "_test._tcp").build();
+        final ArgumentCaptor<NsdManager.DiscoveryListener> listenerCaptor =
+                ArgumentCaptor.forClass(NsdManager.DiscoveryListener.class);
+
+        // Start discovery and capture the listener that NsdPublisher registers with NsdManager.
+        mNsdPublisher.discoverService("_test._tcp", mDiscoverServiceCallback, 10 /* listenerId */);
+        mTestLooper.dispatchAll();
+        verify(mMockNsdManager, times(1))
+                .discoverServices(
+                        eq(expectedRequest), any(Executor.class), listenerCaptor.capture());
+
+        // Report a found service with no type; the callback should see the queried type.
+        final NsdServiceInfo foundService = new NsdServiceInfo();
+        foundService.setServiceName("test");
+        foundService.setServiceType(null);
+        listenerCaptor.getValue().onServiceFound(foundService);
+        mTestLooper.dispatchAll();
+        verify(mDiscoverServiceCallback, times(1))
+                .onServiceDiscovered("test", "_test._tcp", true /* isFound */);
+    }
+
+    @Test
+    public void discoverService_serviceLost() throws Exception {
+        prepareTest();
+        final DiscoveryRequest expectedRequest =
+                new DiscoveryRequest.Builder(PROTOCOL_DNS_SD, "_test._tcp").build();
+        final ArgumentCaptor<NsdManager.DiscoveryListener> listenerCaptor =
+                ArgumentCaptor.forClass(NsdManager.DiscoveryListener.class);
+
+        // Start discovery and capture the listener that NsdPublisher registers with NsdManager.
+        mNsdPublisher.discoverService("_test._tcp", mDiscoverServiceCallback, 10 /* listenerId */);
+        mTestLooper.dispatchAll();
+        verify(mMockNsdManager, times(1))
+                .discoverServices(
+                        eq(expectedRequest), any(Executor.class), listenerCaptor.capture());
+
+        // Report a lost service with no type; the callback should see isFound == false.
+        final NsdServiceInfo lostService = new NsdServiceInfo();
+        lostService.setServiceName("test");
+        lostService.setServiceType(null);
+        listenerCaptor.getValue().onServiceLost(lostService);
+        mTestLooper.dispatchAll();
+        verify(mDiscoverServiceCallback, times(1))
+                .onServiceDiscovered("test", "_test._tcp", false /* isFound */);
+    }
+
+    @Test
+    public void stopServiceDiscovery() {
+        prepareTest();
+        final DiscoveryRequest expectedRequest =
+                new DiscoveryRequest.Builder(PROTOCOL_DNS_SD, "_test._tcp").build();
+        final ArgumentCaptor<NsdManager.DiscoveryListener> listenerCaptor =
+                ArgumentCaptor.forClass(NsdManager.DiscoveryListener.class);
+
+        mNsdPublisher.discoverService("_test._tcp", mDiscoverServiceCallback, 10 /* listenerId */);
+        mTestLooper.dispatchAll();
+        verify(mMockNsdManager, times(1))
+                .discoverServices(
+                        eq(expectedRequest), any(Executor.class), listenerCaptor.capture());
+        final NsdManager.DiscoveryListener listener = listenerCaptor.getValue();
+
+        // Deliver one result, then stop discovery for the same listener ID.
+        final NsdServiceInfo foundService = new NsdServiceInfo();
+        foundService.setServiceName("test");
+        foundService.setServiceType(null);
+        listener.onServiceFound(foundService);
+        mNsdPublisher.stopServiceDiscovery(10 /* listenerId */);
+        mTestLooper.dispatchAll();
+        verify(mMockNsdManager, times(1)).stopServiceDiscovery(listener);
+    }
+
+    @Test
+    public void resolveService_serviceResolved() throws Exception {
+        prepareTest();
+        final ArgumentCaptor<NsdServiceInfo> requestCaptor =
+                ArgumentCaptor.forClass(NsdServiceInfo.class);
+        final ArgumentCaptor<NsdManager.ServiceInfoCallback> callbackCaptor =
+                ArgumentCaptor.forClass(NsdManager.ServiceInfoCallback.class);
+        // Resolving should register a ServiceInfoCallback for the requested name and type.
+        mNsdPublisher.resolveService(
+                "test", "_test._tcp", mResolveServiceCallback, 10 /* listenerId */);
+        mTestLooper.dispatchAll();
+        verify(mMockNsdManager, times(1))
+                .registerServiceInfoCallback(
+                        requestCaptor.capture(), any(Executor.class), callbackCaptor.capture());
+        final NsdServiceInfo request = requestCaptor.getValue();
+        assertThat(request.getServiceName()).isEqualTo("test");
+        assertThat(request.getServiceType()).isEqualTo("_test._tcp");
+
+        // Push a fully populated update through the captured callback.
+        final NsdServiceInfo resolved = new NsdServiceInfo();
+        resolved.setServiceName("test");
+        resolved.setServiceType("_test._tcp");
+        resolved.setPort(12345);
+        resolved.setHostname("test-host");
+        resolved.setHostAddresses(
+                List.of(
+                        InetAddress.parseNumericAddress("2001::1"),
+                        InetAddress.parseNumericAddress("2001::2")));
+        resolved.setAttribute("key1", new byte[] {(byte) 0x01, (byte) 0x02});
+        resolved.setAttribute("key2", new byte[] {(byte) 0x03});
+        callbackCaptor.getValue().onServiceUpdated(resolved);
+        mTestLooper.dispatchAll();
+        verify(mResolveServiceCallback, times(1))
+                .onServiceResolved(
+                        eq("test-host"),
+                        eq("test"),
+                        eq("_test._tcp"),
+                        eq(12345),
+                        eq(List.of("2001::1", "2001::2")),
+                        argThat(
+                                new TxtMatcher(
+                                        List.of(
+                                                makeTxtAttribute("key1", List.of(0x01, 0x02)),
+                                                makeTxtAttribute("key2", List.of(0x03))))),
+                        anyInt());
+    }
+
+    @Test
+    public void stopServiceResolution() throws Exception {
+        prepareTest();
+        final ArgumentCaptor<NsdServiceInfo> requestCaptor =
+                ArgumentCaptor.forClass(NsdServiceInfo.class);
+        final ArgumentCaptor<NsdManager.ServiceInfoCallback> callbackCaptor =
+                ArgumentCaptor.forClass(NsdManager.ServiceInfoCallback.class);
+
+        mNsdPublisher.resolveService(
+                "test", "_test._tcp", mResolveServiceCallback, 10 /* listenerId */);
+        mTestLooper.dispatchAll();
+        verify(mMockNsdManager, times(1))
+                .registerServiceInfoCallback(
+                        requestCaptor.capture(), any(Executor.class), callbackCaptor.capture());
+        final NsdServiceInfo request = requestCaptor.getValue();
+        assertThat(request.getServiceName()).isEqualTo("test");
+        assertThat(request.getServiceType()).isEqualTo("_test._tcp");
+        final NsdManager.ServiceInfoCallback callback = callbackCaptor.getValue();
+
+        // Deliver one update, then stop resolution for the same listener ID.
+        final NsdServiceInfo resolved = new NsdServiceInfo();
+        resolved.setServiceName("test");
+        resolved.setServiceType("_test._tcp");
+        resolved.setPort(12345);
+        resolved.setHostname("test-host");
+        resolved.setHostAddresses(
+                List.of(
+                        InetAddress.parseNumericAddress("2001::1"),
+                        InetAddress.parseNumericAddress("2001::2")));
+        resolved.setAttribute("key1", new byte[] {(byte) 0x01, (byte) 0x02});
+        resolved.setAttribute("key2", new byte[] {(byte) 0x03});
+        callback.onServiceUpdated(resolved);
+        mNsdPublisher.stopServiceResolution(10 /* listenerId */);
+        mTestLooper.dispatchAll();
+        verify(mMockNsdManager, times(1)).unregisterServiceInfoCallback(callback);
+    }
+
+    @Test
     public void reset_unregisterAll() {
         prepareTest();
 
@@ -582,6 +750,30 @@
         return addresses;
     }
 
+    /** Matches a list of TXT attributes element-wise, by name and by value bytes. */
+    private static class TxtMatcher implements ArgumentMatcher<List<DnsTxtAttribute>> {
+        private final List<DnsTxtAttribute> mAttributes;
+
+        TxtMatcher(List<DnsTxtAttribute> attributes) {
+            mAttributes = attributes;
+        }
+
+        @Override
+        public boolean matches(List<DnsTxtAttribute> argument) {
+            // The actual argument can be null; report a mismatch instead of throwing an NPE.
+            if (argument == null || argument.size() != mAttributes.size()) {
+                return false;
+            }
+            for (int i = 0; i < argument.size(); ++i) {
+                if (!Objects.equals(argument.get(i).name, mAttributes.get(i).name)
+                        || !Arrays.equals(argument.get(i).value, mAttributes.get(i).value)) {
+                    return false;
+                }
+            }
+            return true;
+        }
+    }
+
     // @Before and @Test run in different threads. NsdPublisher requires the jobs are run on the
     // thread looper, so TestLooper needs to be created inside each test case to install the
     // correct looper.
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
index 151ed5b..0c7d086 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
@@ -78,7 +78,9 @@
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Captor;
+import org.mockito.InOrder;
 import org.mockito.Mock;
+import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
 
 import java.util.concurrent.CompletableFuture;
@@ -489,4 +491,23 @@
         assertThat(mFakeOtDaemon.isInitialized()).isTrue();
         verify(mockJoinReceiver, times(1)).onSuccess();
     }
+
+    @Test
+    public void onOtDaemonDied_joinedNetwork_interfaceStateBackToUp() throws Exception {
+        mService.initialize();
+        mService.join(DEFAULT_ACTIVE_DATASET, mock(IOperationReceiver.class));
+        mTestLooper.dispatchAll();
+        mTestLooper.moveTimeForward(FakeOtDaemon.JOIN_DELAY.toMillis() + 100);
+        mTestLooper.dispatchAll();
+        // Kill the fake ot-daemon after advancing past its join delay.
+        Mockito.reset(mMockInfraIfController);
+        mFakeOtDaemon.terminate();
+        mTestLooper.dispatchAll();
+
+        // The TUN interface is expected to go down and then come back up, in that order.
+        verify(mMockTunIfController, times(1)).onOtDaemonDied();
+        final InOrder tunOrder = Mockito.inOrder(mMockTunIfController);
+        tunOrder.verify(mMockTunIfController, times(1)).setInterfaceUp(false);
+        tunOrder.verify(mMockTunIfController, times(1)).setInterfaceUp(true);
+    }
 }