Merge changes from topics "remove-napt", "support-multipleaddr" into main

* changes:
  Support multiple addresses in PacketBridge
  Remove NAPT from PacketForwarder
diff --git a/framework/src/android/net/ConnectivityManager.java b/framework/src/android/net/ConnectivityManager.java
index fa27d0e..1ea1815 100644
--- a/framework/src/android/net/ConnectivityManager.java
+++ b/framework/src/android/net/ConnectivityManager.java
@@ -6022,6 +6022,13 @@
     /**
      * Sets data saver switch.
      *
+     * <p>This API configures bandwidth control and fills in the data saver status in BpfMap,
+     * which is intended for internal use by the network stack to optimize performance
+     * when frequently checking data saver status for multiple UIDs without doing IPC.
+     * It does not directly control the global data saver mode that users manage in settings.
+     * To query the comprehensive data saver status for a specific UID, including allowlist
+     * considerations, use {@link #getRestrictBackgroundStatus}.
+     *
      * @param enable True if enable.
      * @throws IllegalStateException if failed.
      * @hide
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 6c25d76..9b2f80b 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -176,7 +176,7 @@
             "mdns_advertiser_allowlist_";
     private static final String MDNS_ALLOWLIST_FLAG_SUFFIX = "_version";
 
-
+    private static final String FORCE_ENABLE_FLAG_FOR_TEST_PREFIX = "test_";
 
     @VisibleForTesting
     static final String MDNS_CONFIG_RUNNING_APP_ACTIVE_IMPORTANCE_CUTOFF =
@@ -1739,6 +1739,10 @@
                         mContext, MdnsFeatureFlags.NSD_LIMIT_LABEL_COUNT))
                 .setIsKnownAnswerSuppressionEnabled(mDeps.isFeatureEnabled(
                         mContext, MdnsFeatureFlags.NSD_KNOWN_ANSWER_SUPPRESSION))
+                .setIsUnicastReplyEnabled(mDeps.isFeatureEnabled(
+                        mContext, MdnsFeatureFlags.NSD_UNICAST_REPLY_ENABLED))
+                .setOverrideProvider(flag -> mDeps.isFeatureEnabled(
+                        mContext, FORCE_ENABLE_FLAG_FOR_TEST_PREFIX + flag))
                 .build();
         mMdnsSocketClient =
                 new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider,
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
index 1ad47a3..9466162 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
@@ -15,6 +15,9 @@
  */
 package com.android.server.connectivity.mdns;
 
+import android.annotation.NonNull;
+import android.annotation.Nullable;
+
 /**
  * The class that contains mDNS feature flags;
  */
@@ -46,6 +49,14 @@
      */
     public static final String NSD_KNOWN_ANSWER_SUPPRESSION = "nsd_known_answer_suppression";
 
+    /**
+     * A feature flag to control whether unicast replies should be enabled.
+     *
+     * <p>Enabling this feature causes replies to queries with the Query Unicast (QU) flag set to be
+     * sent unicast instead of multicast, as per RFC6762 5.4.
+     */
+    public static final String NSD_UNICAST_REPLY_ENABLED = "nsd_unicast_reply_enabled";
+
     // Flag for offload feature
     public final boolean mIsMdnsOffloadFeatureEnabled;
 
@@ -61,6 +72,36 @@
     // Flag for known-answer suppression
     public final boolean mIsKnownAnswerSuppressionEnabled;
 
+    // Flag to enable replying unicast to queries requesting unicast replies
+    public final boolean mIsUnicastReplyEnabled;
+
+    @Nullable
+    private final FlagOverrideProvider mOverrideProvider;
+
+    /**
+     * A provider that can indicate whether a flag should be force-enabled for testing purposes.
+     */
+    public interface FlagOverrideProvider {
+        /**
+         * Indicates whether the flag should be force-enabled for testing purposes.
+         */
+        boolean isForceEnabledForTest(@NonNull String flag);
+    }
+
+    /**
+     * Indicates whether the flag should be force-enabled for testing purposes.
+     */
+    private boolean isForceEnabledForTest(@NonNull String flag) {
+        return mOverrideProvider != null && mOverrideProvider.isForceEnabledForTest(flag);
+    }
+
+    /**
+     * Indicates whether {@link #NSD_UNICAST_REPLY_ENABLED} is enabled, including for testing.
+     */
+    public boolean isUnicastReplyEnabled() {
+        return mIsUnicastReplyEnabled || isForceEnabledForTest(NSD_UNICAST_REPLY_ENABLED);
+    }
+
     /**
      * The constructor for {@link MdnsFeatureFlags}.
      */
@@ -68,12 +109,16 @@
             boolean includeInetAddressRecordsInProbing,
             boolean isExpiredServicesRemovalEnabled,
             boolean isLabelCountLimitEnabled,
-            boolean isKnownAnswerSuppressionEnabled) {
+            boolean isKnownAnswerSuppressionEnabled,
+            boolean isUnicastReplyEnabled,
+            @Nullable FlagOverrideProvider overrideProvider) {
         mIsMdnsOffloadFeatureEnabled = isOffloadFeatureEnabled;
         mIncludeInetAddressRecordsInProbing = includeInetAddressRecordsInProbing;
         mIsExpiredServicesRemovalEnabled = isExpiredServicesRemovalEnabled;
         mIsLabelCountLimitEnabled = isLabelCountLimitEnabled;
         mIsKnownAnswerSuppressionEnabled = isKnownAnswerSuppressionEnabled;
+        mIsUnicastReplyEnabled = isUnicastReplyEnabled;
+        mOverrideProvider = overrideProvider;
     }
 
 
@@ -90,6 +135,8 @@
         private boolean mIsExpiredServicesRemovalEnabled;
         private boolean mIsLabelCountLimitEnabled;
         private boolean mIsKnownAnswerSuppressionEnabled;
+        private boolean mIsUnicastReplyEnabled;
+        private FlagOverrideProvider mOverrideProvider;
 
         /**
          * The constructor for {@link Builder}.
@@ -100,6 +147,8 @@
             mIsExpiredServicesRemovalEnabled = false;
             mIsLabelCountLimitEnabled = true; // Default enabled.
             mIsKnownAnswerSuppressionEnabled = false;
+            mIsUnicastReplyEnabled = true;
+            mOverrideProvider = null;
         }
 
         /**
@@ -154,6 +203,27 @@
         }
 
         /**
+         * Set whether the unicast reply feature is enabled.
+         *
+         * @see #NSD_UNICAST_REPLY_ENABLED
+         */
+        public Builder setIsUnicastReplyEnabled(boolean isUnicastReplyEnabled) {
+            mIsUnicastReplyEnabled = isUnicastReplyEnabled;
+            return this;
+        }
+
+        /**
+         * Set a {@link FlagOverrideProvider} to be used by {@link #isForceEnabledForTest(String)}.
+         *
+         * <p>If non-null, features that use {@link #isForceEnabledForTest(String)} will use that
+         * provider to query whether the flag should be force-enabled.
+         */
+        public Builder setOverrideProvider(@Nullable FlagOverrideProvider overrideProvider) {
+            mOverrideProvider = overrideProvider;
+            return this;
+        }
+
+        /**
          * Builds a {@link MdnsFeatureFlags} with the arguments supplied to this builder.
          */
         public MdnsFeatureFlags build() {
@@ -161,7 +231,9 @@
                     mIncludeInetAddressRecordsInProbing,
                     mIsExpiredServicesRemovalEnabled,
                     mIsLabelCountLimitEnabled,
-                    mIsKnownAnswerSuppressionEnabled);
+                    mIsKnownAnswerSuppressionEnabled,
+                    mIsUnicastReplyEnabled,
+                    mOverrideProvider);
         }
     }
 }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
index aa40c92..3a04dcd 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
@@ -373,12 +373,14 @@
             }
             return;
         }
+        // recvbuf and src are reused after this returns; ensure references to src are not kept.
+        final InetSocketAddress srcCopy = new InetSocketAddress(src.getAddress(), src.getPort());
 
         if (DBG) {
             mSharedLog.v("Parsed packet with " + packet.questions.size() + " questions, "
                     + packet.answers.size() + " answers, "
                     + packet.authorityRecords.size() + " authority, "
-                    + packet.additionalRecords.size() + " additional from " + src);
+                    + packet.additionalRecords.size() + " additional from " + srcCopy);
         }
 
         for (int conflictServiceId : mRecordRepository.getConflictingServices(packet)) {
@@ -389,7 +391,7 @@
         // happen when the incoming packet has answer records (not a question), so there will be no
         // answer. One exception is simultaneous probe tiebreaking (rfc6762 8.2), in which case the
         // conflicting service is still probing and won't reply either.
-        final MdnsReplyInfo answers = mRecordRepository.getReply(packet, src);
+        final MdnsReplyInfo answers = mRecordRepository.getReply(packet, srcCopy);
 
         if (answers == null) return;
         mReplySender.queueReply(answers);
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
index 28bd1b4..4b43989 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
@@ -176,6 +176,16 @@
     }
 
     /**
+     * For questions, returns whether a unicast reply was requested.
+     *
+     * <p>In practice this is identical to {@link #getCacheFlush()}, as the "cache flush" flag in
+     * replies is the same as "unicast reply requested" in questions.
+     */
+    public final boolean isUnicastReplyRequested() {
+        return (cls & MdnsConstants.QCLASS_UNICAST) != 0;
+    }
+
+    /**
      * Returns the record's remaining TTL.
      *
      * If the record was not sent yet (receipt time {@link #RECEIPT_TIME_NOT_SENT}), this is the
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
index 6b6632c..585b097 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
@@ -499,26 +499,30 @@
     @Nullable
     public MdnsReplyInfo getReply(MdnsPacket packet, InetSocketAddress src) {
         final long now = SystemClock.elapsedRealtime();
-        final boolean replyUnicast = (packet.flags & MdnsConstants.QCLASS_UNICAST) != 0;
 
         // Use LinkedHashSet for preserving the insert order of the RRs, so that RRs of the same
         // service or host are grouped together (which is more developer-friendly).
         final Set<RecordInfo<?>> answerInfo = new LinkedHashSet<>();
         final Set<RecordInfo<?>> additionalAnswerInfo = new LinkedHashSet<>();
-
+        // Reply unicast if the feature is enabled AND all replied questions request unicast
+        final boolean replyUnicastEnabled = mMdnsFeatureFlags.isUnicastReplyEnabled();
+        boolean replyUnicast = replyUnicastEnabled;
         for (MdnsRecord question : packet.questions) {
             // Add answers from general records
-            addReplyFromService(question, mGeneralRecords, null /* servicePtrRecord */,
-                    null /* serviceSrvRecord */, null /* serviceTxtRecord */, replyUnicast, now,
-                    answerInfo, additionalAnswerInfo, Collections.emptyList());
+            if (addReplyFromService(question, mGeneralRecords, null /* servicePtrRecord */,
+                    null /* serviceSrvRecord */, null /* serviceTxtRecord */, replyUnicastEnabled,
+                    now, answerInfo, additionalAnswerInfo, Collections.emptyList())) {
+                replyUnicast &= question.isUnicastReplyRequested();
+            }
 
             // Add answers from each service
             for (int i = 0; i < mServices.size(); i++) {
                 final ServiceRegistration registration = mServices.valueAt(i);
                 if (registration.exiting || registration.isProbing) continue;
                 if (addReplyFromService(question, registration.allRecords, registration.ptrRecords,
-                        registration.srvRecord, registration.txtRecord, replyUnicast, now,
+                        registration.srvRecord, registration.txtRecord, replyUnicastEnabled, now,
                         answerInfo, additionalAnswerInfo, packet.answers)) {
+                    replyUnicast &= question.isUnicastReplyRequested();
                     registration.repliedServiceCount++;
                     registration.sentPacketCount++;
                 }
@@ -570,6 +574,12 @@
         // Determine the send destination
         final InetSocketAddress dest;
         if (replyUnicast) {
+            // As per RFC6762 5.4, "if the responder has not multicast that record recently (within
+            // one quarter of its TTL), then the responder SHOULD instead multicast the response so
+            // as to keep all the peer caches up to date": this SHOULD is not implemented, to
+            // minimize latency for queriers that have just started and so did not receive previous
+            // multicast responses. Unicast replies are faster as they do not need to wait for the
+            // beacon interval on Wi-Fi.
             dest = src;
         } else if (src.getAddress() instanceof Inet4Address) {
             dest = IPV4_SOCKET_ADDR;
@@ -608,7 +618,7 @@
             @Nullable List<RecordInfo<MdnsPointerRecord>> servicePtrRecords,
             @Nullable RecordInfo<MdnsServiceRecord> serviceSrvRecord,
             @Nullable RecordInfo<MdnsTextRecord> serviceTxtRecord,
-            boolean replyUnicast, long now, @NonNull Set<RecordInfo<?>> answerInfo,
+            boolean replyUnicastEnabled, long now, @NonNull Set<RecordInfo<?>> answerInfo,
             @NonNull Set<RecordInfo<?>> additionalAnswerInfo,
             @NonNull List<MdnsRecord> knownAnswerRecords) {
         boolean hasDnsSdPtrRecordAnswer = false;
@@ -659,7 +669,8 @@
 
             // TODO: responses to probe queries should bypass this check and only ensure the
             // reply is sent 250ms after the last sent time (RFC 6762 p.15)
-            if (!replyUnicast && info.lastAdvertisedTimeMs > 0L
+            if (!(replyUnicastEnabled && question.isUnicastReplyRequested())
+                    && info.lastAdvertisedTimeMs > 0L
                     && now - info.lastAdvertisedTimeMs < MIN_MULTICAST_REPLY_INTERVAL_MS) {
                 continue;
             }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MulticastPacketReader.java b/service-t/src/com/android/server/connectivity/mdns/MulticastPacketReader.java
index 3cd77a4..70451f3 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MulticastPacketReader.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MulticastPacketReader.java
@@ -42,6 +42,12 @@
     private final Set<PacketHandler> mPacketHandlers = MdnsUtils.newSet();
 
     interface PacketHandler {
+        /**
+         * Handle an incoming packet.
+         *
+         * <p>The recvbuf and src <b>will be reused and modified</b> after this method returns, so
+         * implementers must ensure that they are not accessed after handlePacket returns.
+         */
         void handlePacket(byte[] recvbuf, int length, InetSocketAddress src);
     }
 
diff --git a/staticlibs/testutils/devicetests/NSResponder.kt b/staticlibs/testutils/devicetests/NSResponder.kt
new file mode 100644
index 0000000..f7619cd
--- /dev/null
+++ b/staticlibs/testutils/devicetests/NSResponder.kt
@@ -0,0 +1,86 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.MacAddress
+import android.util.Log
+import com.android.net.module.util.Ipv6Utils
+import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN
+import com.android.net.module.util.NetworkStackConstants.ICMPV6_ND_OPTION_TLLA
+import com.android.net.module.util.NetworkStackConstants.NEIGHBOR_ADVERTISEMENT_FLAG_SOLICITED
+import com.android.net.module.util.Struct
+import com.android.net.module.util.structs.Icmpv6Header
+import com.android.net.module.util.structs.Ipv6Header
+import com.android.net.module.util.structs.LlaOption
+import com.android.net.module.util.structs.NsHeader
+import com.android.testutils.PacketReflector.IPV6_HEADER_LENGTH
+import java.lang.IllegalArgumentException
+import java.net.Inet6Address
+import java.nio.ByteBuffer
+
+private const val NS_TYPE = 135.toShort()
+
+/**
+ * A class that can be used to reply to Neighbor Solicitation packets on a [TapPacketReader].
+ */
+class NSResponder(
+    reader: TapPacketReader,
+    table: Map<Inet6Address, MacAddress>,
+    name: String = NSResponder::class.java.simpleName
+) : PacketResponder(reader, Icmpv6Filter(), name) {
+    companion object {
+        private val TAG = NSResponder::class.simpleName
+    }
+
+    // Copy the map if not already immutable (toMap) to make sure it is not modified
+    private val table = table.toMap()
+
+    override fun replyToPacket(packet: ByteArray, reader: TapPacketReader) {
+        if (packet.size < IPV6_HEADER_LENGTH) {
+            return
+        }
+        val buf = ByteBuffer.wrap(packet, ETHER_HEADER_LEN, packet.size - ETHER_HEADER_LEN)
+        val ipv6Header = parseOrLog(Ipv6Header::class.java, buf) ?: return
+        val icmpHeader = parseOrLog(Icmpv6Header::class.java, buf) ?: return
+        if (icmpHeader.type != NS_TYPE) {
+            return
+        }
+        val ns = parseOrLog(NsHeader::class.java, buf) ?: return
+        val replyMacAddr = table[ns.target] ?: return
+        val slla = parseOrLog(LlaOption::class.java, buf) ?: return
+        val requesterMac = slla.linkLayerAddress
+
+        val tlla = LlaOption.build(ICMPV6_ND_OPTION_TLLA.toByte(), replyMacAddr)
+        reader.sendResponse(Ipv6Utils.buildNaPacket(
+            replyMacAddr /* srcMac */,
+            requesterMac /* dstMac */,
+            ns.target /* srcIp */,
+            ipv6Header.srcIp /* dstIp */,
+            NEIGHBOR_ADVERTISEMENT_FLAG_SOLICITED,
+            ns.target,
+            tlla))
+    }
+
+    private fun <T> parseOrLog(clazz: Class<T>, buf: ByteBuffer): T? where T : Struct {
+        return try {
+            Struct.parse(clazz, buf)
+        } catch (e: IllegalArgumentException) {
+            Log.e(TAG, "Invalid ${clazz.simpleName} in ICMPv6 packet", e)
+            null
+        }
+    }
+}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt b/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt
index 3d98cc3..68248ca 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt
@@ -22,12 +22,12 @@
 import android.util.Log
 import com.android.modules.utils.build.SdkLevel
 import com.android.testutils.FunctionalUtils.ThrowingRunnable
-import org.junit.rules.TestRule
-import org.junit.runner.Description
-import org.junit.runners.model.Statement
 import java.util.concurrent.CompletableFuture
 import java.util.concurrent.Executor
 import java.util.concurrent.TimeUnit
+import org.junit.rules.TestRule
+import org.junit.runner.Description
+import org.junit.runners.model.Statement
 
 private val TAG = DeviceConfigRule::class.simpleName
 
@@ -147,11 +147,11 @@
         return tryTest {
             runAsShell(*readWritePermissions) {
                 DeviceConfig.addOnPropertiesChangedListener(
-                        DeviceConfig.NAMESPACE_CONNECTIVITY,
+                        namespace,
                         inlineExecutor,
                         listener)
                 DeviceConfig.setProperty(
-                        DeviceConfig.NAMESPACE_CONNECTIVITY,
+                        namespace,
                         key,
                         value,
                         false /* makeDefault */)
diff --git a/staticlibs/testutils/hostdevice/com/android/testutils/PacketFilter.kt b/staticlibs/testutils/hostdevice/com/android/testutils/PacketFilter.kt
index 1bb6d68..a73a58a 100644
--- a/staticlibs/testutils/hostdevice/com/android/testutils/PacketFilter.kt
+++ b/staticlibs/testutils/hostdevice/com/android/testutils/PacketFilter.kt
@@ -110,6 +110,12 @@
     override fun test(t: ByteArray) = impl.test(t)
 }
 
+class Icmpv6Filter : Predicate<ByteArray> {
+    private val impl = OffsetFilter(ETHER_TYPE_OFFSET, 0x86.toByte(), 0xdd.toByte() /* IPv6 */).and(
+        OffsetFilter(IPV6_PROTOCOL_OFFSET, 58 /* ICMPv6 */))
+    override fun test(t: ByteArray) = impl.test(t)
+}
+
 /**
  * A [Predicate] that matches ethernet-encapped DHCP packets sent from a DHCP client.
  */
diff --git a/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt b/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt
index eef3f87..5ba6c4c 100644
--- a/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt
+++ b/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt
@@ -23,11 +23,15 @@
 import com.android.net.module.util.ArrayTrackRecord
 import com.android.net.module.util.DnsPacket
 import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN
+import com.android.net.module.util.NetworkStackConstants.IPV6_ADDR_LEN
+import com.android.net.module.util.NetworkStackConstants.IPV6_DST_ADDR_OFFSET
 import com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN
 import com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN
 import com.android.net.module.util.TrackRecord
 import com.android.testutils.IPv6UdpFilter
 import com.android.testutils.TapPacketReader
+import java.net.Inet6Address
+import java.net.InetAddress
 import kotlin.test.assertEquals
 import kotlin.test.assertNotNull
 import kotlin.test.assertNull
@@ -236,19 +240,28 @@
 private fun getMdnsPayload(packet: ByteArray) = packet.copyOfRange(
     ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, packet.size)
 
+private fun getDstAddr(packet: ByteArray): Inet6Address {
+    val v6AddrPos = ETHER_HEADER_LEN + IPV6_DST_ADDR_OFFSET
+    return Inet6Address.getByAddress(packet.copyOfRange(v6AddrPos, v6AddrPos + IPV6_ADDR_LEN))
+            as Inet6Address
+}
+
 fun TapPacketReader.pollForMdnsPacket(
     timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS,
     predicate: (TestDnsPacket) -> Boolean
 ): TestDnsPacket? {
     val mdnsProbeFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT).and {
+        val dst = getDstAddr(it)
         val mdnsPayload = getMdnsPayload(it)
         try {
-            predicate(TestDnsPacket(mdnsPayload))
+            predicate(TestDnsPacket(mdnsPayload, dst))
         } catch (e: DnsPacket.ParseException) {
             false
         }
     }
-    return poll(timeoutMs, mdnsProbeFilter)?.let { TestDnsPacket(getMdnsPayload(it)) }
+    return poll(timeoutMs, mdnsProbeFilter)?.let {
+        TestDnsPacket(getMdnsPayload(it), getDstAddr(it))
+    }
 }
 
 fun TapPacketReader.pollForProbe(
@@ -281,7 +294,7 @@
     it.isReplyFor("$serviceName.$serviceType.local")
 }
 
-class TestDnsPacket(data: ByteArray) : DnsPacket(data) {
+class TestDnsPacket(data: ByteArray, val dstAddr: InetAddress) : DnsPacket(data) {
     val header: DnsHeader
         get() = mHeader
     val records: Array<List<DnsRecord>>
@@ -290,9 +303,10 @@
         it.dName == name && it.nsType == DnsResolver.TYPE_ANY
     }
 
-    fun isReplyFor(name: String): Boolean = mRecords[ANSECTION].any {
-        it.dName == name && it.nsType == DnsResolver.TYPE_SRV
-    }
+    fun isReplyFor(name: String, type: Int = DnsResolver.TYPE_SRV): Boolean =
+        mRecords[ANSECTION].any {
+            it.dName == name && it.nsType == type
+        }
 
     fun isQueryFor(name: String, vararg requiredTypes: Int): Boolean = requiredTypes.all { type ->
         mRecords[QDSECTION].any {
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index a040201..1309e79 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -61,6 +61,7 @@
 import android.os.Handler
 import android.os.HandlerThread
 import android.platform.test.annotations.AppModeFull
+import android.provider.DeviceConfig.NAMESPACE_TETHERING
 import android.system.ErrnoException
 import android.system.Os
 import android.system.OsConstants.AF_INET6
@@ -69,6 +70,7 @@
 import android.system.OsConstants.ETH_P_IPV6
 import android.system.OsConstants.IPPROTO_IPV6
 import android.system.OsConstants.IPPROTO_UDP
+import android.system.OsConstants.RT_SCOPE_LINK
 import android.system.OsConstants.SOCK_DGRAM
 import android.util.Log
 import androidx.test.filters.SmallTest
@@ -78,12 +80,14 @@
 import com.android.modules.utils.build.SdkLevel.isAtLeastU
 import com.android.net.module.util.DnsPacket
 import com.android.net.module.util.HexDump
+import com.android.net.module.util.NetworkStackConstants.IPV6_ADDR_LEN
 import com.android.net.module.util.PacketBuilder
 import com.android.testutils.ConnectivityModuleTest
 import com.android.testutils.DevSdkIgnoreRule
-import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
 import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.DeviceConfigRule
+import com.android.testutils.NSResponder
 import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
 import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
 import com.android.testutils.TapPacketReader
@@ -133,6 +137,7 @@
 private const val TEST_PORT = 12345
 private const val MDNS_PORT = 5353.toShort()
 private val multicastIpv6Addr = parseNumericAddress("ff02::fb") as Inet6Address
+private val testSrcAddr = parseNumericAddress("2001:db8::123") as Inet6Address
 
 @AppModeFull(reason = "Socket cannot bind in instant app mode")
 @RunWith(DevSdkIgnoreRunner::class)
@@ -144,6 +149,9 @@
     @get:Rule
     val ignoreRule = DevSdkIgnoreRule()
 
+    @get:Rule
+    val deviceConfigRule = DeviceConfigRule()
+
     private val context by lazy { InstrumentationRegistry.getInstrumentation().context }
     private val nsdManager by lazy {
         context.getSystemService(NsdManager::class.java) ?: fail("Could not get NsdManager service")
@@ -682,7 +690,7 @@
         assertEquals(OffloadEngine.OFFLOAD_TYPE_REPLY.toLong(), serviceInfo.offloadType)
         val offloadPayload = serviceInfo.offloadPayload
         assertNotNull(offloadPayload)
-        val dnsPacket = TestDnsPacket(offloadPayload)
+        val dnsPacket = TestDnsPacket(offloadPayload, dstAddr = multicastIpv6Addr)
         assertEquals(0x8400, dnsPacket.header.flags)
         assertEquals(0, dnsPacket.records[DnsPacket.QDSECTION].size)
         assertTrue(dnsPacket.records[DnsPacket.ANSECTION].size >= 5)
@@ -1286,7 +1294,8 @@
         // Resolve service on testNetwork1
         val resolveRecord = NsdResolveRecord()
         val packetReader = TapPacketReader(Handler(handlerThread.looper),
-                testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+                testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */
+        )
         packetReader.startAsyncForTest()
         handlerThread.waitForIdle(TIMEOUT_MS)
 
@@ -1349,6 +1358,68 @@
                 serviceResolved.serviceInfo.hostAddresses.toSet())
     }
 
+    @Test
+    fun testUnicastReplyUsedWhenQueryUnicastFlagSet() {
+        // The flag may be removed in the future but unicast replies should be enabled by default
+        // in that case. The rule will reset flags automatically on teardown.
+        deviceConfigRule.setConfig(NAMESPACE_TETHERING, "test_nsd_unicast_reply_enabled", "1")
+
+        val si = makeTestServiceInfo(testNetwork1.network)
+
+        // Register service on testNetwork1
+        val registrationRecord = NsdRegistrationRecord()
+        var nsResponder: NSResponder? = null
+        tryTest {
+            registerService(registrationRecord, si)
+            val packetReader = TapPacketReader(Handler(handlerThread.looper),
+                testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+            packetReader.startAsyncForTest()
+
+            handlerThread.waitForIdle(TIMEOUT_MS)
+            /*
+            Send a "query unicast" query.
+            Generated with:
+            scapy.raw(scapy.DNS(rd=0, qr=0, aa=0, qd =
+                    scapy.DNSQR(qname='_nmt123456789._tcp.local', qtype='PTR', qclass=0x8001)
+            )).hex()
+            */
+            val mdnsPayload = HexDump.hexStringToByteArray("0000000000010000000000000d5f6e6d74313" +
+                    "233343536373839045f746370056c6f63616c00000c8001")
+            replaceServiceNameAndTypeWithTestSuffix(mdnsPayload)
+
+            val testSrcAddr = makeLinkLocalAddressOfOtherDeviceOnPrefix(testNetwork1.network)
+            nsResponder = NSResponder(packetReader, mapOf(
+                testSrcAddr to MacAddress.fromString("01:02:03:04:05:06")
+            )).apply { start() }
+
+            packetReader.sendResponse(buildMdnsPacket(mdnsPayload, testSrcAddr))
+            // The reply is sent unicast to the source address. There may be announcements sent
+            // multicast around this time, so filter by destination address.
+            val reply = packetReader.pollForMdnsPacket { pkt ->
+                pkt.isReplyFor("$serviceType.local", DnsResolver.TYPE_PTR) &&
+                        pkt.dstAddr == testSrcAddr
+            }
+            assertNotNull(reply)
+        } cleanup {
+            nsResponder?.stop()
+            nsdManager.unregisterService(registrationRecord)
+            registrationRecord.expectCallback<ServiceUnregistered>()
+        }
+    }
+
+    private fun makeLinkLocalAddressOfOtherDeviceOnPrefix(network: Network): Inet6Address {
+        val lp = cm.getLinkProperties(network) ?: fail("No LinkProperties for net $network")
+        // Expect to have a /64 link-local address
+        val linkAddr = lp.linkAddresses.firstOrNull {
+            it.isIPv6 && it.scope == RT_SCOPE_LINK && it.prefixLength == 64
+        } ?: fail("No /64 link-local address found in ${lp.linkAddresses} for net $network")
+
+        // Add one to the device address to simulate the address of another device on the prefix
+        val addrBytes = linkAddr.address.address
+        addrBytes[IPV6_ADDR_LEN - 1]++
+        return Inet6Address.getByAddress(addrBytes) as Inet6Address
+    }
+
     private fun buildConflictingAnnouncement(): ByteBuffer {
         /*
         Generated with:
@@ -1393,7 +1464,10 @@
         replaceAll(buffer, source, replacement)
     }
 
-    private fun buildMdnsPacket(mdnsPayload: ByteArray): ByteBuffer {
+    private fun buildMdnsPacket(
+        mdnsPayload: ByteArray,
+        srcAddr: Inet6Address = testSrcAddr
+    ): ByteBuffer {
         val packetBuffer = PacketBuilder.allocate(true /* hasEther */, IPPROTO_IPV6,
                 IPPROTO_UDP, mdnsPayload.size)
         val packetBuilder = PacketBuilder(packetBuffer)
@@ -1408,7 +1482,7 @@
                 0x60000000, // version=6, traffic class=0x0, flowlabel=0x0
                 IPPROTO_UDP.toByte(),
                 64 /* hop limit */,
-                parseNumericAddress("2001:db8::123") as Inet6Address /* srcIp */,
+                srcAddr,
                 multicastIpv6Addr /* dstIp */)
         packetBuilder.writeUdpHeader(MDNS_PORT /* srcPort */, MDNS_PORT /* dstPort */)
         packetBuffer.put(mdnsPayload)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index 0c04bff..ee0bd1a 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -35,6 +35,7 @@
 import java.net.InetSocketAddress
 import kotlin.test.assertContentEquals
 import kotlin.test.assertEquals
+import kotlin.test.assertNotSame
 import kotlin.test.assertTrue
 import org.junit.After
 import org.junit.Before
@@ -213,7 +214,12 @@
         packetHandler.handlePacket(query, query.size, src)
 
         val packetCaptor = ArgumentCaptor.forClass(MdnsPacket::class.java)
-        verify(repository).getReply(packetCaptor.capture(), eq(src))
+        val srcCaptor = ArgumentCaptor.forClass(InetSocketAddress::class.java)
+        verify(repository).getReply(packetCaptor.capture(), srcCaptor.capture())
+
+        assertEquals(src, srcCaptor.value)
+        assertNotSame(src, srcCaptor.value, "src will be reused by the packetHandler, references " +
+                "to it should not be used outside of handlePacket.")
 
         packetCaptor.value.let {
             assertEquals(1, it.questions.size)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index 4b1f166..1edc806 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -24,7 +24,6 @@
 import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_A
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_AAAA
-import com.android.server.connectivity.mdns.MdnsRecord.TYPE_ANY
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_PTR
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_SRV
 import com.android.server.connectivity.mdns.MdnsRecord.TYPE_TXT
@@ -95,7 +94,6 @@
         override fun getInterfaceInetAddresses(iface: NetworkInterface) =
                 Collections.enumeration(TEST_ADDRESSES.map { it.address })
     }
-    private val flags = MdnsFeatureFlags.newBuilder().build()
 
     @Before
     fun setUp() {
@@ -108,9 +106,19 @@
         thread.join()
     }
 
+    // Creates MdnsFeatureFlags for a test; only the three flags listed here are set explicitly.
+    private fun makeFlags(
+        includeInetAddressesInProbing: Boolean = false,
+        isKnownAnswerSuppressionEnabled: Boolean = false,
+        unicastReplyEnabled: Boolean = true
+    ) = MdnsFeatureFlags.Builder()
+        .setIsUnicastReplyEnabled(unicastReplyEnabled)
+        .setIsKnownAnswerSuppressionEnabled(isKnownAnswerSuppressionEnabled)
+        .setIncludeInetAddressRecordsInProbing(includeInetAddressesInProbing).build()
+
     @Test
     fun testAddServiceAndProbe() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         assertEquals(0, repository.servicesCount)
         assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1))
         assertEquals(1, repository.servicesCount)
@@ -144,7 +152,7 @@
 
     @Test
     fun testAddAndConflicts() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         assertFailsWith(NameConflictException::class) {
             repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1)
@@ -156,7 +164,7 @@
 
     @Test
     fun testAddAndUpdates() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
 
         assertFailsWith(IllegalArgumentException::class) {
@@ -190,7 +198,7 @@
 
     @Test
     fun testInvalidReuseOfServiceId() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         assertFailsWith(IllegalArgumentException::class) {
             repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2)
@@ -199,7 +207,7 @@
 
     @Test
     fun testHasActiveService() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1))
 
         repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
@@ -216,7 +224,7 @@
 
     @Test
     fun testExitAnnouncements() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
 
@@ -246,7 +254,7 @@
 
     @Test
     fun testExitAnnouncements_WithSubtypes() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
                 setOf(TEST_SUBTYPE, TEST_SUBTYPE2))
         repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
@@ -288,7 +296,7 @@
 
     @Test
     fun testExitingServiceReAdded() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
         repository.exitService(TEST_SERVICE_ID_1)
@@ -303,7 +311,7 @@
 
     @Test
     fun testOnProbingSucceeded() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
                 setOf(TEST_SUBTYPE, TEST_SUBTYPE2))
         repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
@@ -435,7 +443,7 @@
 
     @Test
     fun testGetOffloadPacket() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
         val serviceType = arrayOf("_testservice", "_tcp", "local")
@@ -497,7 +505,7 @@
 
     @Test
     fun testGetReplyCaseInsensitive() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         val questionsCaseInSensitive = listOf(
                 MdnsPointerRecord(arrayOf("_TESTSERVICE", "_TCP", "local"), false /* isUnicast */))
@@ -517,7 +525,7 @@
     private fun makeQuery(vararg queries: Pair<Int, Array<String>>): MdnsPacket {
         val questions = queries.map { (type, name) -> makeQuestionRecord(name, type) }
         return MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
-                listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+            listOf() /* authorityRecords */, listOf() /* additionalRecords */)
     }
 
     private fun makeQuestionRecord(name: Array<String>, type: Int): MdnsRecord {
@@ -532,7 +540,7 @@
 
     @Test
     fun testGetReply_singlePtrQuestion_returnsSrvTxtAddressNsecRecords() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
         val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
         val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
@@ -563,7 +571,7 @@
 
     @Test
     fun testGetReply_singleSubtypePtrQuestion_returnsSrvTxtAddressNsecRecords() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
         val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
         val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
@@ -596,7 +604,7 @@
 
     @Test
     fun testGetReply_duplicatePtrQuestions_doesNotReturnDuplicateRecords() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
         val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
         val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
@@ -629,7 +637,7 @@
 
     @Test
     fun testGetReply_multiplePtrQuestionsWithSubtype_doesNotReturnDuplicateRecords() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
         val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
         val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
@@ -665,7 +673,7 @@
 
     @Test
     fun testGetReply_txtQuestion_returnsNoNsecRecord() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
         val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
         val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
@@ -682,7 +690,7 @@
 
     @Test
     fun testGetReply_AAAAQuestionButNoIpv6Address_returnsNsecRecord() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(
                 TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE),
                 listOf(LinkAddress(parseNumericAddress("192.0.2.111"), 24)))
@@ -701,7 +709,7 @@
 
     @Test
     fun testGetReply_ptrAndSrvQuestions_doesNotReturnSrvRecordInAdditionalAnswerSection() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
         val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
         val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
@@ -723,7 +731,7 @@
 
     @Test
     fun testGetReply_srvTxtAddressQuestions_returnsAllRecordsInAnswerSectionExceptNsec() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
         val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
         val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
@@ -757,7 +765,7 @@
 
     @Test
     fun testGetReply_queryWithIpv4Address_replyWithIpv4Address() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
         val query = makeQuery(TYPE_PTR to arrayOf("_testservice", "_tcp", "local"))
 
@@ -771,7 +779,7 @@
 
     @Test
     fun testGetReply_queryWithIpv6Address_replyWithIpv6Address() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
         val query = makeQuery(TYPE_PTR to arrayOf("_testservice", "_tcp", "local"))
 
@@ -785,7 +793,7 @@
 
     @Test
     fun testGetConflictingServices() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)
 
@@ -813,7 +821,7 @@
 
     @Test
     fun testGetConflictingServicesCaseInsensitive() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)
 
@@ -841,7 +849,7 @@
 
     @Test
     fun testGetConflictingServices_IdenticalService() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)
 
@@ -870,7 +878,7 @@
 
     @Test
     fun testGetConflictingServicesCaseInsensitive_IdenticalService() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)
 
@@ -899,7 +907,7 @@
 
     @Test
     fun testGetServiceRepliedRequestsCount() {
-        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         // Verify that there is no packet replied.
         assertEquals(MdnsConstants.NO_PACKET,
@@ -924,7 +932,7 @@
     @Test
     fun testIncludeInetAddressRecordsInProbing() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME,
-                MdnsFeatureFlags.newBuilder().setIncludeInetAddressRecordsInProbing(true).build())
+            makeFlags(includeInetAddressesInProbing = true))
         repository.updateAddresses(TEST_ADDRESSES)
         assertEquals(0, repository.servicesCount)
         assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1))
@@ -990,7 +998,7 @@
             expectReply: Boolean
     ) {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME,
-                MdnsFeatureFlags.newBuilder().setIsKnownAnswerSuppressionEnabled(true).build())
+            makeFlags(isKnownAnswerSuppressionEnabled = true))
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         val query = MdnsPacket(0 /* flags */, questions, knownAnswers,
                 listOf() /* authorityRecords */, listOf() /* additionalRecords */)
@@ -1222,23 +1230,109 @@
                 MdnsPointerRecord(queriedName, false /* isUnicast */),
                 MdnsServiceRecord(serviceName, false /* isUnicast */))
         val knownAnswers = listOf(
-                MdnsPointerRecord(
-                        queriedName,
-                        0L /* receiptTimeMillis */,
-                        false /* cacheFlush */,
-                        LONG_TTL - 1000L,
-                        serviceName),
-                MdnsServiceRecord(
-                        serviceName,
-                        0L /* receiptTimeMillis */,
-                        false /* cacheFlush */,
-                        SHORT_TTL - 15_000L,
-                        0 /* servicePriority */,
-                        0 /* serviceWeight */,
-                        TEST_PORT,
-                        TEST_HOSTNAME))
-        doGetReplyWithAnswersTest(questions, knownAnswers, listOf() /* replyAnswers */,
-                listOf() /* additionalAnswers */, false /* expectReply */)
+            MdnsPointerRecord(
+                queriedName,
+                0L /* receiptTimeMillis */,
+                false /* cacheFlush */,
+                LONG_TTL - 1000L,
+                serviceName
+            ),
+            MdnsServiceRecord(
+                serviceName,
+                0L /* receiptTimeMillis */,
+                false /* cacheFlush */,
+                SHORT_TTL - 15_000L,
+                0 /* servicePriority */,
+                0 /* serviceWeight */,
+                TEST_PORT,
+                TEST_HOSTNAME
+            )
+        )
+        doGetReplyWithAnswersTest(
+            questions, knownAnswers, listOf() /* replyAnswers */,
+            listOf() /* additionalAnswers */, false /* expectReply */
+        )
+    }
+
+    @Test
+    fun testReplyUnicastToQueryUnicastQuestions() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+
+        // Ask for 2 services: both questions request a unicast reply, only the first is known
+        val questions = listOf(
+            MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"), true /* isUnicast */),
+            MdnsPointerRecord(arrayOf("_otherservice", "_tcp", "local"), true /* isUnicast */))
+        val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
+            listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+        val src = InetSocketAddress(parseNumericAddress("2001:db8::123"), 5353)
+
+        // Get the reply and verify its destination is the address the query came from.
+        val reply = repository.getReply(query, src)
+        assertNotNull(reply)
+        assertEquals(src, reply.destination)
+    }
+
+    @Test
+    fun testReplyMulticastToQueryUnicastAndMulticastMixedQuestions() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        repository.addServiceAndFinishProbing(TEST_SERVICE_ID_2, NsdServiceInfo().apply {
+            serviceType = "_otherservice._tcp"
+            serviceName = "OtherTestService"
+            port = TEST_PORT
+        })
+
+        // Ask for 2 services, both are known and only the first one requests unicast reply
+        val questions = listOf(
+            MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"), true /* isUnicast */),
+            MdnsPointerRecord(arrayOf("_otherservice", "_tcp", "local"), false /* isUnicast */))
+        val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
+            listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+        val src = InetSocketAddress(parseNumericAddress("2001:db8::123"), 5353)
+
+        // One known service was queried without requesting unicast: expect a multicast reply.
+        val reply = repository.getReply(query, src)
+        assertNotNull(reply)
+        assertEquals(MdnsConstants.getMdnsIPv6Address(), reply.destination.address)
+    }
+
+    @Test
+    fun testReplyMulticastWhenNoUnicastQueryMatches() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+
+        // Ask for 2 services, the first one requests a unicast reply but is unknown
+        val questions = listOf(
+            MdnsPointerRecord(arrayOf("_otherservice", "_tcp", "local"), true /* isUnicast */),
+            MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"), false /* isUnicast */))
+        val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
+            listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+        val src = InetSocketAddress(parseNumericAddress("2001:db8::123"), 5353)
+
+        // Only the multicast question is for a known service: expect a multicast reply.
+        val reply = repository.getReply(query, src)
+        assertNotNull(reply)
+        assertEquals(MdnsConstants.getMdnsIPv6Address(), reply.destination.address)
+    }
+
+    @Test
+    fun testReplyMulticastWhenUnicastFeatureDisabled() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME,
+            makeFlags(unicastReplyEnabled = false))
+        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+
+        // The service is known and requests unicast reply, but the feature is disabled
+        val questions = listOf(
+            MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"), true /* isUnicast */))
+        val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
+            listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+        val src = InetSocketAddress(parseNumericAddress("2001:db8::123"), 5353)
+
+        // With the unicast reply flag off, expect the reply to be sent multicast regardless.
+        val reply = repository.getReply(query, src)
+        assertNotNull(reply)
+        assertEquals(MdnsConstants.getMdnsIPv6Address(), reply.destination.address)
     }
 }
 
@@ -1250,6 +1344,13 @@
 ): AnnouncementInfo {
     updateAddresses(addresses)
     serviceInfo.setSubtypes(subtypes)
+    return addServiceAndFinishProbing(serviceId, serviceInfo)
+}
+
+private fun MdnsRecordRepository.addServiceAndFinishProbing(
+    serviceId: Int,
+    serviceInfo: NsdServiceInfo
+): AnnouncementInfo {
     addService(serviceId, serviceInfo)
     val probingInfo = setServiceProbing(serviceId)
     assertNotNull(probingInfo)
diff --git a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
index e02e74d..7a6c9aa 100644
--- a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
+++ b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
@@ -24,6 +24,8 @@
 
 import static androidx.test.platform.app.InstrumentationRegistry.getInstrumentation;
 
+import static com.android.testutils.TestPermissionUtil.runAsShell;
+
 import static com.google.common.truth.Truth.assertThat;
 import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
 
@@ -34,6 +36,10 @@
 
 import android.Manifest.permission;
 import android.content.Context;
+import android.net.ConnectivityManager;
+import android.net.Network;
+import android.net.NetworkCapabilities;
+import android.net.NetworkRequest;
 import android.net.thread.ActiveOperationalDataset;
 import android.net.thread.OperationalDatasetTimestamp;
 import android.net.thread.PendingOperationalDataset;
@@ -74,6 +80,8 @@
 @RunWith(DevSdkIgnoreRunner.class)
 @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) // Thread is available on only U+
 public class ThreadNetworkControllerTest {
+    private static final int JOIN_TIMEOUT_MILLIS = 30 * 1000;
+    private static final int NETWORK_CALLBACK_TIMEOUT_MILLIS = 10 * 1000;
     private static final int CALLBACK_TIMEOUT_MILLIS = 1000;
     private static final String PERMISSION_THREAD_NETWORK_PRIVILEGED =
             "android.permission.THREAD_NETWORK_PRIVILEGED";
@@ -750,4 +758,46 @@
             assertThat(dataset.getMeshLocalPrefix().getRawAddress()[0]).isEqualTo((byte) 0xfd);
         }
     }
+
+    // Checks that after join() completes and isAttached(controller) holds, a callback registered
+    // for TRANSPORT_THREAD has been told that a matching network is available.
+    @Test
+    public void threadNetworkCallback_deviceAttached_threadNetworkIsAvailable() throws Exception {
+        ThreadNetworkController controller = mManager.getAllThreadNetworkControllers().get(0);
+        ActiveOperationalDataset activeDataset = newRandomizedDataset("TestNet", controller);
+        SettableFuture<Void> joinFuture = SettableFuture.create();
+        SettableFuture<Network> networkFuture = SettableFuture.create();
+        ConnectivityManager cm = mContext.getSystemService(ConnectivityManager.class);
+        NetworkRequest networkRequest =
+                new NetworkRequest.Builder()
+                        .addTransportType(NetworkCapabilities.TRANSPORT_THREAD)
+                        .build();
+        ConnectivityManager.NetworkCallback networkCallback =
+                new ConnectivityManager.NetworkCallback() {
+                    @Override
+                    public void onAvailable(Network network) {
+                        networkFuture.set(network);
+                    }
+                };
+
+        runAsShell(
+                PERMISSION_THREAD_NETWORK_PRIVILEGED,
+                () -> controller.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture)));
+        runAsShell(
+                permission.ACCESS_NETWORK_STATE,
+                () -> cm.registerNetworkCallback(networkRequest, networkCallback));
+
+        try {
+            joinFuture.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
+            runAsShell(
+                    permission.ACCESS_NETWORK_STATE,
+                    () -> assertThat(isAttached(controller)).isTrue());
+            assertThat(networkFuture.get(NETWORK_CALLBACK_TIMEOUT_MILLIS, MILLISECONDS))
+                    .isNotNull();
+        } finally {
+            // Unregister even when a wait or assertion above fails, so the callback is not left
+            // registered after this test. Only reached once registration has succeeded.
+            cm.unregisterNetworkCallback(networkCallback);
+        }
+    }
 }