Merge "Do not immediately send new queries on cache hit"
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 1e0d544..4af4c6a 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -19,6 +19,7 @@
 import static android.net.ConnectivityManager.NETID_UNSET;
 import static android.net.nsd.NsdManager.MDNS_DISCOVERY_MANAGER_EVENT;
 import static android.net.nsd.NsdManager.MDNS_SERVICE_EVENT;
+import static android.net.nsd.NsdManager.RESOLVE_SERVICE_SUCCEEDED;
 import static android.provider.DeviceConfig.NAMESPACE_TETHERING;
 
 import static com.android.modules.utils.build.SdkLevel.isAtLeastU;
@@ -54,6 +55,7 @@
 import android.os.UserHandle;
 import android.text.TextUtils;
 import android.util.Log;
+import android.util.Pair;
 import android.util.SparseArray;
 
 import com.android.internal.annotations.VisibleForTesting;
@@ -83,6 +85,7 @@
 import java.net.SocketException;
 import java.net.UnknownHostException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -242,14 +245,14 @@
         public void onServiceNameDiscovered(@NonNull MdnsServiceInfo serviceInfo) {
             mNsdStateMachine.sendMessage(MDNS_DISCOVERY_MANAGER_EVENT, mTransactionId,
                     NsdManager.SERVICE_FOUND,
-                    new MdnsEvent(mClientId, mReqServiceInfo.getServiceType(), serviceInfo));
+                    new MdnsEvent(mClientId, serviceInfo));
         }
 
         @Override
         public void onServiceNameRemoved(@NonNull MdnsServiceInfo serviceInfo) {
             mNsdStateMachine.sendMessage(MDNS_DISCOVERY_MANAGER_EVENT, mTransactionId,
                     NsdManager.SERVICE_LOST,
-                    new MdnsEvent(mClientId, mReqServiceInfo.getServiceType(), serviceInfo));
+                    new MdnsEvent(mClientId, serviceInfo));
         }
     }
 
@@ -264,7 +267,7 @@
         public void onServiceFound(MdnsServiceInfo serviceInfo) {
             mNsdStateMachine.sendMessage(MDNS_DISCOVERY_MANAGER_EVENT, mTransactionId,
                     NsdManager.RESOLVE_SERVICE_SUCCEEDED,
-                    new MdnsEvent(mClientId, mReqServiceInfo.getServiceType(), serviceInfo));
+                    new MdnsEvent(mClientId, serviceInfo));
         }
     }
 
@@ -279,21 +282,21 @@
         public void onServiceFound(@NonNull MdnsServiceInfo serviceInfo) {
             mNsdStateMachine.sendMessage(MDNS_DISCOVERY_MANAGER_EVENT, mTransactionId,
                     NsdManager.SERVICE_UPDATED,
-                    new MdnsEvent(mClientId, mReqServiceInfo.getServiceType(), serviceInfo));
+                    new MdnsEvent(mClientId, serviceInfo));
         }
 
         @Override
         public void onServiceUpdated(@NonNull MdnsServiceInfo serviceInfo) {
             mNsdStateMachine.sendMessage(MDNS_DISCOVERY_MANAGER_EVENT, mTransactionId,
                     NsdManager.SERVICE_UPDATED,
-                    new MdnsEvent(mClientId, mReqServiceInfo.getServiceType(), serviceInfo));
+                    new MdnsEvent(mClientId, serviceInfo));
         }
 
         @Override
         public void onServiceRemoved(@NonNull MdnsServiceInfo serviceInfo) {
             mNsdStateMachine.sendMessage(MDNS_DISCOVERY_MANAGER_EVENT, mTransactionId,
                     NsdManager.SERVICE_UPDATED_LOST,
-                    new MdnsEvent(mClientId, mReqServiceInfo.getServiceType(), serviceInfo));
+                    new MdnsEvent(mClientId, serviceInfo));
         }
     }
 
@@ -303,14 +306,10 @@
     private static class MdnsEvent {
         final int mClientId;
         @NonNull
-        final String mRequestedServiceType;
-        @NonNull
         final MdnsServiceInfo mMdnsServiceInfo;
 
-        MdnsEvent(int clientId, @NonNull String requestedServiceType,
-                @NonNull MdnsServiceInfo mdnsServiceInfo) {
+        MdnsEvent(int clientId, @NonNull MdnsServiceInfo mdnsServiceInfo) {
             mClientId = clientId;
-            mRequestedServiceType = requestedServiceType;
             mMdnsServiceInfo = mdnsServiceInfo;
         }
     }
@@ -601,7 +600,10 @@
 
                         final NsdServiceInfo info = args.serviceInfo;
                         id = getUniqueId();
-                        final String serviceType = constructServiceType(info.getServiceType());
+                        final Pair<String, String> typeAndSubtype =
+                                parseTypeAndSubtype(info.getServiceType());
+                        final String serviceType = typeAndSubtype == null
+                                ? null : typeAndSubtype.first;
                         if (clientInfo.mUseJavaBackend
                                 || mDeps.isMdnsDiscoveryManagerEnabled(mContext)
                                 || useDiscoveryManagerForType(serviceType)) {
@@ -615,12 +617,17 @@
                             maybeStartMonitoringSockets();
                             final MdnsListener listener =
                                     new DiscoveryListener(clientId, id, info, listenServiceType);
-                            final MdnsSearchOptions options = MdnsSearchOptions.newBuilder()
-                                    .setNetwork(info.getNetwork())
-                                    .setIsPassiveMode(true)
-                                    .build();
+                            final MdnsSearchOptions.Builder optionsBuilder =
+                                    MdnsSearchOptions.newBuilder()
+                                            .setNetwork(info.getNetwork())
+                                            .setIsPassiveMode(true);
+                            if (typeAndSubtype.second != null) {
+                                // The parsing ensures subtype starts with an underscore.
+                                // MdnsSearchOptions expects the underscore to not be present.
+                                optionsBuilder.addSubtype(typeAndSubtype.second.substring(1));
+                            }
                             mMdnsDiscoveryManager.registerListener(
-                                    listenServiceType, listener, options);
+                                    listenServiceType, listener, optionsBuilder.build());
                             storeDiscoveryManagerRequestMap(clientId, id, listener, clientInfo);
                             clientInfo.onDiscoverServicesStarted(clientId, info);
                             clientInfo.log("Register a DiscoveryListener " + id
@@ -699,7 +706,9 @@
                         id = getUniqueId();
                         final NsdServiceInfo serviceInfo = args.serviceInfo;
                         final String serviceType = serviceInfo.getServiceType();
-                        final String registerServiceType = constructServiceType(serviceType);
+                        final Pair<String, String> typeSubtype = parseTypeAndSubtype(serviceType);
+                        final String registerServiceType = typeSubtype == null
+                                ? null : typeSubtype.first;
                         if (clientInfo.mUseJavaBackend
                                 || mDeps.isMdnsAdvertiserEnabled(mContext)
                                 || useAdvertiserForType(registerServiceType)) {
@@ -714,7 +723,11 @@
                                     serviceInfo.getServiceName()));
 
                             maybeStartMonitoringSockets();
-                            mAdvertiser.addService(id, serviceInfo);
+                            // The subtype is passed in separately. Including the subtype in the
+                            // service type would generate service instance names like
+                            // Name._subtype._sub._type._tcp, which is incorrect
+                            // (it should be Name._type._tcp).
+                            mAdvertiser.addService(id, serviceInfo, typeSubtype.second);
                             storeAdvertiserRequestMap(clientId, id, clientInfo);
                         } else {
                             maybeStartDaemon();
@@ -780,7 +793,10 @@
 
                         final NsdServiceInfo info = args.serviceInfo;
                         id = getUniqueId();
-                        final String serviceType = constructServiceType(info.getServiceType());
+                        final Pair<String, String> typeSubtype =
+                                parseTypeAndSubtype(info.getServiceType());
+                        final String serviceType = typeSubtype == null
+                                ? null : typeSubtype.first;
                         if (clientInfo.mUseJavaBackend
                                 ||  mDeps.isMdnsDiscoveryManagerEnabled(mContext)
                                 || useDiscoveryManagerForType(serviceType)) {
@@ -873,7 +889,10 @@
 
                         final NsdServiceInfo info = args.serviceInfo;
                         id = getUniqueId();
-                        final String serviceType = constructServiceType(info.getServiceType());
+                        final Pair<String, String> typeAndSubtype =
+                                parseTypeAndSubtype(info.getServiceType());
+                        final String serviceType = typeAndSubtype == null
+                                ? null : typeAndSubtype.first;
                         if (serviceType == null) {
                             clientInfo.onServiceInfoCallbackRegistrationFailed(clientId,
                                     NsdManager.FAILURE_BAD_PARAMETERS);
@@ -1104,9 +1123,38 @@
                 return true;
             }
 
-            private NsdServiceInfo buildNsdServiceInfoFromMdnsEvent(final MdnsEvent event) {
+            @Nullable
+            private NsdServiceInfo buildNsdServiceInfoFromMdnsEvent(
+                    final MdnsEvent event, int code) {
                 final MdnsServiceInfo serviceInfo = event.mMdnsServiceInfo;
-                final String serviceType = event.mRequestedServiceType;
+                final String[] typeArray = serviceInfo.getServiceType();
+                final String joinedType;
+                if (typeArray.length == 0
+                        || !typeArray[typeArray.length - 1].equals(LOCAL_DOMAIN_NAME)) {
+                    Log.wtf(TAG, "MdnsServiceInfo type does not end in .local: "
+                            + Arrays.toString(typeArray));
+                    return null;
+                } else {
+                    joinedType = TextUtils.join(".",
+                            Arrays.copyOfRange(typeArray, 0, typeArray.length - 1));
+                }
+                final String serviceType;
+                switch (code) {
+                    case NsdManager.SERVICE_FOUND:
+                    case NsdManager.SERVICE_LOST:
+                        // For consistency with historical behavior, discovered service types have
+                        // a dot at the end.
+                        serviceType = joinedType + ".";
+                        break;
+                    case RESOLVE_SERVICE_SUCCEEDED:
+                        // For consistency with historical behavior, resolved service types have
+                        // a dot at the beginning.
+                        serviceType = "." + joinedType;
+                        break;
+                    default:
+                        serviceType = joinedType;
+                        break;
+                }
                 final String serviceName = serviceInfo.getServiceInstanceName();
                 final NsdServiceInfo servInfo = new NsdServiceInfo(serviceName, serviceType);
                 final Network network = serviceInfo.getNetwork();
@@ -1131,7 +1179,9 @@
 
                 final MdnsEvent event = (MdnsEvent) obj;
                 final int clientId = event.mClientId;
-                final NsdServiceInfo info = buildNsdServiceInfoFromMdnsEvent(event);
+                final NsdServiceInfo info = buildNsdServiceInfoFromMdnsEvent(event, code);
+                // Errors are already logged if null
+                if (info == null) return false;
                 if (DBG) {
                     Log.d(TAG, String.format("MdnsDiscoveryManager event code=%s transactionId=%d",
                             NsdManager.nameOf(code), transactionId));
@@ -1150,8 +1200,6 @@
                             break;
                         }
                         final MdnsServiceInfo serviceInfo = event.mMdnsServiceInfo;
-                        // Add '.' in front of the service type that aligns with historical behavior
-                        info.setServiceType("." + event.mRequestedServiceType);
                         info.setPort(serviceInfo.getPort());
 
                         Map<String, String> attrs = serviceInfo.getAttributes();
@@ -1288,28 +1336,39 @@
      * Check the given service type is valid and construct it to a service type
      * which can use for discovery / resolution service.
      *
-     * <p> The valid service type should be 2 labels, or 3 labels if the query is for a
+     * <p>The valid service type should be 2 labels, or 3 labels if the query is for a
      * subtype (see RFC6763 7.1). Each label is up to 63 characters and must start with an
      * underscore; they are alphanumerical characters or dashes or underscore, except the
      * last one that is just alphanumerical. The last label must be _tcp or _udp.
      *
+     * <p>The subtype may also be specified with a comma after the service type, for example
+     * _type._tcp,_subtype.
+     *
      * @param serviceType the request service type for discovery / resolution service
      * @return constructed service type or null if the given service type is invalid.
      */
     @Nullable
-    public static String constructServiceType(String serviceType) {
+    public static Pair<String, String> parseTypeAndSubtype(String serviceType) {
         if (TextUtils.isEmpty(serviceType)) return null;
 
+        final String typeOrSubtypePattern = "_[a-zA-Z0-9-_]{1,61}[a-zA-Z0-9]";
         final Pattern serviceTypePattern = Pattern.compile(
-                "^(_[a-zA-Z0-9-_]{1,61}[a-zA-Z0-9]\\.)?"
-                        + "(_[a-zA-Z0-9-_]{1,61}[a-zA-Z0-9]\\._(?:tcp|udp))"
+                // Optional leading subtype (_subtype._type._tcp)
+                // (?: xxx) is a non-capturing parenthesis, don't capture the dot
+                "^(?:(" + typeOrSubtypePattern + ")\\.)?"
+                        // Actual type (_type._tcp)
+                        + "(" + typeOrSubtypePattern + "\\._(?:tcp|udp))"
                         // Drop '.' at the end of service type that is compatible with old backend.
-                        + "\\.?$");
+                        // e.g. allow "_type._tcp."
+                        + "\\.?"
+                        // Optional subtype after comma, for "_type._tcp,_subtype" format
+                        + "(?:,(" + typeOrSubtypePattern + "))?"
+                        + "$");
         final Matcher matcher = serviceTypePattern.matcher(serviceType);
         if (!matcher.matches()) return null;
-        return matcher.group(1) == null
-                ? matcher.group(2)
-                : matcher.group(1) + "_sub." + matcher.group(2);
+        // Use the subtype either at the beginning or after the comma
+        final String subtype = matcher.group(1) != null ? matcher.group(1) : matcher.group(3);
+        return new Pair<>(matcher.group(2), subtype);
     }
 
     @VisibleForTesting
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
index 655c364..cc08ea1 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
@@ -270,7 +270,8 @@
             mPendingRegistrations.put(id, registration);
             for (int i = 0; i < mAdvertisers.size(); i++) {
                 try {
-                    mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo());
+                    mAdvertisers.valueAt(i).addService(
+                            id, registration.getServiceInfo(), registration.getSubtype());
                 } catch (NameConflictException e) {
                     Log.wtf(TAG, "Name conflict adding services that should have unique names", e);
                 }
@@ -298,9 +299,10 @@
             }
             mAdvertisers.put(socket, advertiser);
             for (int i = 0; i < mPendingRegistrations.size(); i++) {
+                final Registration registration = mPendingRegistrations.valueAt(i);
                 try {
                     advertiser.addService(mPendingRegistrations.keyAt(i),
-                            mPendingRegistrations.valueAt(i).getServiceInfo());
+                            registration.getServiceInfo(), registration.getSubtype());
                 } catch (NameConflictException e) {
                     Log.wtf(TAG, "Name conflict adding services that should have unique names", e);
                 }
@@ -329,10 +331,13 @@
         private int mConflictCount;
         @NonNull
         private NsdServiceInfo mServiceInfo;
+        @Nullable
+        private final String mSubtype;
 
-        private Registration(@NonNull NsdServiceInfo serviceInfo) {
+        private Registration(@NonNull NsdServiceInfo serviceInfo, @Nullable String subtype) {
             this.mOriginalName = serviceInfo.getServiceName();
             this.mServiceInfo = serviceInfo;
+            this.mSubtype = subtype;
         }
 
         /**
@@ -387,6 +392,11 @@
         public NsdServiceInfo getServiceInfo() {
             return mServiceInfo;
         }
+
+        @Nullable
+        public String getSubtype() {
+            return mSubtype;
+        }
     }
 
     /**
@@ -443,8 +453,9 @@
      * Add a service to advertise.
      * @param id A unique ID for the service.
      * @param service The service info to advertise.
+     * @param subtype An optional subtype to advertise the service with.
      */
-    public void addService(int id, NsdServiceInfo service) {
+    public void addService(int id, NsdServiceInfo service, @Nullable String subtype) {
         checkThread();
         if (mRegistrations.get(id) != null) {
             Log.e(TAG, "Adding duplicate registration for " + service);
@@ -453,10 +464,10 @@
             return;
         }
 
-        mSharedLog.i("Adding service " + service + " with ID " + id);
+        mSharedLog.i("Adding service " + service + " with ID " + id + " and subtype " + subtype);
 
         final Network network = service.getNetwork();
-        final Registration registration = new Registration(service);
+        final Registration registration = new Registration(service, subtype);
         final BiPredicate<Network, InterfaceAdvertiserRequest> checkConflictFilter;
         if (network == null) {
             // If registering on all networks, no advertiser must have conflicts
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
index 4e09515..724a704 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
@@ -212,8 +212,9 @@
      *
      * @throws NameConflictException There is already a service being advertised with that name.
      */
-    public void addService(int id, NsdServiceInfo service) throws NameConflictException {
-        final int replacedExitingService = mRecordRepository.addService(id, service);
+    public void addService(int id, NsdServiceInfo service, @Nullable String subtype)
+            throws NameConflictException {
+        final int replacedExitingService = mRecordRepository.addService(id, service, subtype);
         // Cancel announcements for the existing service. This only happens for exiting services
         // (so cancelling exiting announcements), as per RecordRepository.addService.
         if (replacedExitingService >= 0) {
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
index 1329172..f756459 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
@@ -69,6 +69,8 @@
 
     // Top-level domain for link-local queries, as per RFC6762 3.
     private static final String LOCAL_TLD = "local";
+    // Subtype separator as per RFC6763 7.1 (_printer._sub._http._tcp.local)
+    private static final String SUBTYPE_SEPARATOR = "_sub";
 
     // Service type for service enumeration (RFC6763 9.)
     private static final String[] DNS_SD_SERVICE_TYPE =
@@ -156,13 +158,15 @@
         @NonNull
         public final List<RecordInfo<?>> allRecords;
         @NonNull
-        public final RecordInfo<MdnsPointerRecord> ptrRecord;
+        public final List<RecordInfo<MdnsPointerRecord>> ptrRecords;
         @NonNull
         public final RecordInfo<MdnsServiceRecord> srvRecord;
         @NonNull
         public final RecordInfo<MdnsTextRecord> txtRecord;
         @NonNull
         public final NsdServiceInfo serviceInfo;
+        @Nullable
+        public final String subtype;
 
         /**
          * Whether the service is sending exit announcements and will be destroyed soon.
@@ -175,14 +179,16 @@
          * @param deviceHostname Hostname of the device (for the interface used)
          * @param serviceInfo Service to advertise
          */
-        ServiceRegistration(@NonNull String[] deviceHostname, @NonNull NsdServiceInfo serviceInfo) {
+        ServiceRegistration(@NonNull String[] deviceHostname, @NonNull NsdServiceInfo serviceInfo,
+                @Nullable String subtype) {
             this.serviceInfo = serviceInfo;
+            this.subtype = subtype;
 
             final String[] serviceType = splitServiceType(serviceInfo);
             final String[] serviceName = splitFullyQualifiedName(serviceInfo, serviceType);
 
             // Service PTR record
-            ptrRecord = new RecordInfo<>(
+            final RecordInfo<MdnsPointerRecord> ptrRecord = new RecordInfo<>(
                     serviceInfo,
                     new MdnsPointerRecord(
                             serviceType,
@@ -192,6 +198,26 @@
                             serviceName),
                     true /* sharedName */, true /* probing */);
 
+            if (subtype == null) {
+                this.ptrRecords = Collections.singletonList(ptrRecord);
+            } else {
+                final String[] subtypeName = new String[serviceType.length + 2];
+                System.arraycopy(serviceType, 0, subtypeName, 2, serviceType.length);
+                subtypeName[0] = subtype;
+                subtypeName[1] = SUBTYPE_SEPARATOR;
+                final RecordInfo<MdnsPointerRecord> subtypeRecord = new RecordInfo<>(
+                        serviceInfo,
+                        new MdnsPointerRecord(
+                                subtypeName,
+                                0L /* receiptTimeMillis */,
+                                false /* cacheFlush */,
+                                NON_NAME_RECORDS_TTL_MILLIS,
+                                serviceName),
+                        true /* sharedName */, true /* probing */);
+
+                this.ptrRecords = List.of(ptrRecord, subtypeRecord);
+            }
+
             srvRecord = new RecordInfo<>(
                     serviceInfo,
                     new MdnsServiceRecord(serviceName,
@@ -211,8 +237,8 @@
                             attrsToTextEntries(serviceInfo.getAttributes())),
                     false /* sharedName */, true /* probing */);
 
-            final ArrayList<RecordInfo<?>> allRecords = new ArrayList<>(4);
-            allRecords.add(ptrRecord);
+            final ArrayList<RecordInfo<?>> allRecords = new ArrayList<>(5);
+            allRecords.addAll(ptrRecords);
             allRecords.add(srvRecord);
             allRecords.add(txtRecord);
             // Service type enumeration record (RFC6763 9.)
@@ -275,7 +301,8 @@
      *         ID of the replaced service.
      * @throws NameConflictException There is already a (non-exiting) service using the name.
      */
-    public int addService(int serviceId, NsdServiceInfo serviceInfo) throws NameConflictException {
+    public int addService(int serviceId, NsdServiceInfo serviceInfo, @Nullable String subtype)
+            throws NameConflictException {
         if (mServices.contains(serviceId)) {
             throw new IllegalArgumentException(
                     "Service ID must not be reused across registrations: " + serviceId);
@@ -288,7 +315,7 @@
         }
 
         final ServiceRegistration registration = new ServiceRegistration(
-                mDeviceHostname, serviceInfo);
+                mDeviceHostname, serviceInfo, subtype);
         mServices.put(serviceId, registration);
 
         // Remove existing exiting service
@@ -344,24 +371,25 @@
         if (registration == null) return null;
         if (registration.exiting) return null;
 
-        // Send exit (TTL 0) for the PTR record, if the record was sent (in particular don't send
+        // Send exit (TTL 0) for the PTR records, if at least one was sent (in particular don't send
         // if still probing)
-        if (registration.ptrRecord.lastSentTimeMs == 0L) {
+        if (CollectionUtils.all(registration.ptrRecords, r -> r.lastSentTimeMs == 0L)) {
             return null;
         }
 
         registration.exiting = true;
-        final MdnsPointerRecord expiredRecord = new MdnsPointerRecord(
-                registration.ptrRecord.record.getName(),
-                0L /* receiptTimeMillis */,
-                true /* cacheFlush */,
-                0L /* ttlMillis */,
-                registration.ptrRecord.record.getPointer());
+        final List<MdnsRecord> expiredRecords = CollectionUtils.map(registration.ptrRecords,
+                r -> new MdnsPointerRecord(
+                        r.record.getName(),
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        0L /* ttlMillis */,
+                        r.record.getPointer()));
 
         // Exit should be skipped if the record is still advertised by another service, but that
         // would be a conflict (2 service registrations with the same service name), so it would
         // not have been allowed by the repository.
-        return new MdnsAnnouncer.ExitAnnouncementInfo(id, Collections.singletonList(expiredRecord));
+        return new MdnsAnnouncer.ExitAnnouncementInfo(id, expiredRecords);
     }
 
     public void removeService(int id) {
@@ -442,7 +470,7 @@
             for (int i = 0; i < mServices.size(); i++) {
                 final ServiceRegistration registration = mServices.valueAt(i);
                 if (registration.exiting) continue;
-                addReplyFromService(question, registration.allRecords, registration.ptrRecord,
+                addReplyFromService(question, registration.allRecords, registration.ptrRecords,
                         registration.srvRecord, registration.txtRecord, replyUnicast, now,
                         answerInfo, additionalAnswerRecords);
             }
@@ -499,7 +527,7 @@
      */
     private void addReplyFromService(@NonNull MdnsRecord question,
             @NonNull List<RecordInfo<?>> serviceRecords,
-            @Nullable RecordInfo<MdnsPointerRecord> servicePtrRecord,
+            @Nullable List<RecordInfo<MdnsPointerRecord>> servicePtrRecords,
             @Nullable RecordInfo<MdnsServiceRecord> serviceSrvRecord,
             @Nullable RecordInfo<MdnsTextRecord> serviceTxtRecord,
             boolean replyUnicast, long now, @NonNull List<RecordInfo<?>> answerInfo,
@@ -531,7 +559,8 @@
             }
 
             hasKnownAnswer = true;
-            hasDnsSdPtrRecordAnswer |= (info == servicePtrRecord);
+            hasDnsSdPtrRecordAnswer |= (servicePtrRecords != null
+                    && CollectionUtils.any(servicePtrRecords, r -> info == r));
             hasDnsSdSrvRecordAnswer |= (info == serviceSrvRecord);
 
             // TODO: responses to probe queries should bypass this check and only ensure the
@@ -791,10 +820,11 @@
      */
     @Nullable
     public MdnsProber.ProbingInfo renameServiceForConflict(int serviceId, NsdServiceInfo newInfo) {
-        if (!mServices.contains(serviceId)) return null;
+        final ServiceRegistration existing = mServices.get(serviceId);
+        if (existing == null) return null;
 
         final ServiceRegistration newService = new ServiceRegistration(
-                mDeviceHostname, newInfo);
+                mDeviceHostname, newInfo, existing.subtype);
         mServices.put(serviceId, newService);
         return makeProbingInfo(serviceId, newService.srvRecord.record);
     }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index 298e632..4e6571f 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -28,13 +28,16 @@
 
 import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
+import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.SharedLog;
+import com.android.server.connectivity.mdns.util.MdnsUtils;
 
 import java.net.Inet4Address;
 import java.net.Inet6Address;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
@@ -209,9 +212,22 @@
 
     private boolean responseMatchesOptions(@NonNull MdnsResponse response,
             @NonNull MdnsSearchOptions options) {
-        if (options.getResolveInstanceName() == null) return true;
-        // DNS is case-insensitive, so ignore case in the comparison
-        return options.getResolveInstanceName().equalsIgnoreCase(response.getServiceInstanceName());
+        final boolean matchesInstanceName = options.getResolveInstanceName() == null
+                // DNS is case-insensitive, so ignore case in the comparison
+                || MdnsUtils.equalsIgnoreDnsCase(options.getResolveInstanceName(),
+                        response.getServiceInstanceName());
+
+        // If discovery is requiring some subtypes, the response must have one that matches a
+        // requested one.
+        final List<String> responseSubtypes = response.getSubtypes() == null
+                ? Collections.emptyList() : response.getSubtypes();
+        final boolean matchesSubtype = options.getSubtypes().size() == 0
+                || CollectionUtils.any(options.getSubtypes(), requiredSub ->
+                CollectionUtils.any(responseSubtypes, actualSub ->
+                        MdnsUtils.equalsIgnoreDnsCase(
+                                MdnsConstants.SUBTYPE_PREFIX + requiredSub, actualSub)));
+
+        return matchesInstanceName && matchesSubtype;
     }
 
     /**
diff --git a/service/src/com/android/server/BpfNetMaps.java b/service/src/com/android/server/BpfNetMaps.java
index 84e581e..ec168dd 100644
--- a/service/src/com/android/server/BpfNetMaps.java
+++ b/service/src/com/android/server/BpfNetMaps.java
@@ -384,7 +384,6 @@
      * ALLOWLIST means the firewall denies all by default, uids must be explicitly allowed
      * DENYLIST means the firewall allows all by default, uids must be explicitly denyed
      */
-    @VisibleForTesting
     public boolean isFirewallAllowList(final int chain) {
         switch (chain) {
             case FIREWALL_CHAIN_DOZABLE:
@@ -745,6 +744,65 @@
         }
     }
 
+    private Set<Integer> getUidsMatchEnabled(final int childChain) throws ErrnoException {
+        final long match = getMatchByFirewallChain(childChain);
+        Set<Integer> uids = new ArraySet<>();
+        synchronized (sUidOwnerMap) {
+            sUidOwnerMap.forEach((uid, val) -> {
+                if (val == null) {
+                    Log.wtf(TAG, "sUidOwnerMap entry was deleted while holding a lock");
+                } else {
+                    if ((val.rule & match) != 0) {
+                        uids.add(uid.val);
+                    }
+                }
+            });
+        }
+        return uids;
+    }
+
+    /**
+     * Get uids that have FIREWALL_RULE_ALLOW on allowlist chain.
+     * Allowlist means the firewall denies all by default, uids must be explicitly allowed.
+     *
+     * Note that uids that have FIREWALL_RULE_DENY on allowlist chain cannot be computed from the
+     * bpf map, since all the uids that do not have explicit FIREWALL_RULE_ALLOW rule in bpf map
+     * are determined to have FIREWALL_RULE_DENY.
+     *
+     * @param childChain target chain
+     * @return Set of uids
+     */
+    public Set<Integer> getUidsWithAllowRuleOnAllowListChain(final int childChain)
+            throws ErrnoException {
+        if (!isFirewallAllowList(childChain)) {
+            throw new IllegalArgumentException("getUidsWithAllowRuleOnAllowListChain is called with"
+                    + " denylist chain:" + childChain);
+        }
+        // Corresponding match is enabled for uids that have FIREWALL_RULE_ALLOW on allowlist chain.
+        return getUidsMatchEnabled(childChain);
+    }
+
+    /**
+     * Get uids that have FIREWALL_RULE_DENY on denylist chain.
+     * Denylist means the firewall allows all by default, uids must be explicitly denied.
+     *
+     * Note that uids that have FIREWALL_RULE_ALLOW on denylist chain cannot be computed from the
+     * bpf map, since all the uids that do not have explicit FIREWALL_RULE_DENY rule in bpf map
+     * are determined to have FIREWALL_RULE_ALLOW.
+     *
+     * @param childChain target chain
+     * @return Set of uids
+     */
+    public Set<Integer> getUidsWithDenyRuleOnDenyListChain(final int childChain)
+            throws ErrnoException {
+        if (isFirewallAllowList(childChain)) {
+            throw new IllegalArgumentException("getUidsWithDenyRuleOnDenyListChain is called with"
+                    + " allowlist chain:" + childChain);
+        }
+        // Corresponding match is enabled for uids that have FIREWALL_RULE_DENY on denylist chain.
+        return getUidsMatchEnabled(childChain);
+    }
+
     /**
      * Add ingress interface filtering rules to a list of UIDs
      *
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index b5c9b0a..b17af99 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -1509,6 +1509,16 @@
                 throws SocketException, InterruptedIOException, ErrnoException {
             InetDiagMessage.destroyLiveTcpSockets(ranges, exemptUids);
         }
+
+        /**
+         * Call {@link InetDiagMessage#destroyLiveTcpSocketsByOwnerUids(Set)}
+         *
+         * @param ownerUids target uids to close sockets
+         */
+        public void destroyLiveTcpSocketsByOwnerUids(final Set<Integer> ownerUids)
+                throws SocketException, InterruptedIOException, ErrnoException {
+            InetDiagMessage.destroyLiveTcpSocketsByOwnerUids(ownerUids);
+        }
     }
 
     public ConnectivityService(Context context) {
@@ -12048,6 +12058,23 @@
         return rule;
     }
 
+    private void closeSocketsForFirewallChainLocked(final int chain)
+            throws ErrnoException, SocketException, InterruptedIOException {
+        if (mBpfNetMaps.isFirewallAllowList(chain)) {
+            // Allowlist means the firewall denies all by default, uids must be explicitly allowed
+            // So, close all non-system sockets owned by uids that are not explicitly allowed
+            Set<Range<Integer>> ranges = new ArraySet<>();
+            ranges.add(new Range<>(Process.FIRST_APPLICATION_UID, Integer.MAX_VALUE));
+            final Set<Integer> exemptUids = mBpfNetMaps.getUidsWithAllowRuleOnAllowListChain(chain);
+            mDeps.destroyLiveTcpSockets(ranges, exemptUids);
+        } else {
+            // Denylist means the firewall allows all by default, uids must be explicitly denied
+            // So, close sockets owned by uids that are explicitly denied
+            final Set<Integer> ownerUids = mBpfNetMaps.getUidsWithDenyRuleOnDenyListChain(chain);
+            mDeps.destroyLiveTcpSocketsByOwnerUids(ownerUids);
+        }
+    }
+
     @Override
     public void setFirewallChainEnabled(final int chain, final boolean enable) {
         enforceNetworkStackOrSettingsPermission();
@@ -12057,6 +12084,14 @@
         } catch (ServiceSpecificException e) {
             throw new IllegalStateException(e);
         }
+
+        if (SdkLevel.isAtLeastU() && enable) {
+            try {
+                closeSocketsForFirewallChainLocked(chain);
+            } catch (ErrnoException | SocketException | InterruptedIOException e) {
+                Log.e(TAG, "Failed to close sockets after enabling chain (" + chain + "): " + e);
+            }
+        }
     }
 
     @Override
diff --git a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
index 8b059e3..ee2f6bb 100644
--- a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
@@ -225,6 +225,7 @@
 import java.net.InetSocketAddress;
 import java.net.MalformedURLException;
 import java.net.Socket;
+import java.net.SocketException;
 import java.net.URL;
 import java.net.UnknownHostException;
 import java.nio.charset.StandardCharsets;
@@ -278,6 +279,7 @@
     // TODO(b/252972908): reset the original timer when aosp/2188755 is ramped up.
     private static final int LISTEN_ACTIVITY_TIMEOUT_MS = 30_000;
     private static final int NO_CALLBACK_TIMEOUT_MS = 100;
+    private static final int NETWORK_REQUEST_TIMEOUT_MS = 3000;
     private static final int SOCKET_TIMEOUT_MS = 100;
     private static final int NUM_TRIES_MULTIPATH_PREF_CHECK = 20;
     private static final long INTERVAL_MULTIPATH_PREF_CHECK_MS = 500;
@@ -3553,6 +3555,103 @@
         doTestFirewallBlocking(FIREWALL_CHAIN_OEM_DENY_3, DENYLIST);
     }
 
+    private void assertSocketOpen(final Socket socket) throws Exception {
+        mCtsNetUtils.testHttpRequest(socket);
+    }
+
+    private void assertSocketClosed(final Socket socket) throws Exception {
+        try {
+            mCtsNetUtils.testHttpRequest(socket);
+            fail("Socket is expected to be closed");
+        } catch (SocketException expected) {
+        }
+    }
+
+    private static final boolean EXPECT_OPEN = false;
+    private static final boolean EXPECT_CLOSE = true;
+
+    private void doTestFirewallCloseSocket(final int chain, final int rule, final int targetUid,
+            final boolean expectClose) {
+        runWithShellPermissionIdentity(() -> {
+            // Firewall chain status will be restored after the test.
+            final boolean wasChainEnabled = mCm.getFirewallChainEnabled(chain);
+            final int previousUidFirewallRule = mCm.getUidFirewallRule(chain, targetUid);
+            final Socket socket = new Socket(TEST_HOST, HTTP_PORT);
+            socket.setSoTimeout(NETWORK_REQUEST_TIMEOUT_MS);
+            testAndCleanup(() -> {
+                mCm.setFirewallChainEnabled(chain, false /* enable */);
+                assertSocketOpen(socket);
+
+                try {
+                    mCm.setUidFirewallRule(chain, targetUid, rule);
+                } catch (IllegalStateException ignored) {
+                    // Removing match causes an exception when the rule entry for the uid does
+                    // not exist. But this is fine and can be ignored.
+                }
+                mCm.setFirewallChainEnabled(chain, true /* enable */);
+
+                if (expectClose) {
+                    assertSocketClosed(socket);
+                } else {
+                    assertSocketOpen(socket);
+                }
+            }, /* cleanup */ () -> {
+                    // Restore the global chain status
+                    mCm.setFirewallChainEnabled(chain, wasChainEnabled);
+                }, /* cleanup */ () -> {
+                    // Restore the uid firewall rule status
+                    try {
+                        mCm.setUidFirewallRule(chain, targetUid, previousUidFirewallRule);
+                    } catch (IllegalStateException ignored) {
+                        // Removing match causes an exception when the rule entry for the uid does
+                        // not exist. But this is fine and can be ignored.
+                    }
+                }, /* cleanup */ () -> {
+                    socket.close();
+                });
+        }, NETWORK_SETTINGS);
+    }
+
+    @Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @ConnectivityModuleTest
+    public void testFirewallCloseSocketAllowlistChainAllow() {
+        doTestFirewallCloseSocket(FIREWALL_CHAIN_DOZABLE, FIREWALL_RULE_ALLOW,
+                Process.myUid(), EXPECT_OPEN);
+    }
+
+    @Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @ConnectivityModuleTest
+    public void testFirewallCloseSocketAllowlistChainDeny() {
+        doTestFirewallCloseSocket(FIREWALL_CHAIN_DOZABLE, FIREWALL_RULE_DENY,
+                Process.myUid(), EXPECT_CLOSE);
+    }
+
+    @Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @ConnectivityModuleTest
+    public void testFirewallCloseSocketAllowlistChainOtherUid() {
+        doTestFirewallCloseSocket(FIREWALL_CHAIN_DOZABLE, FIREWALL_RULE_ALLOW,
+                Process.myUid() + 1, EXPECT_CLOSE);
+        doTestFirewallCloseSocket(FIREWALL_CHAIN_DOZABLE, FIREWALL_RULE_DENY,
+                Process.myUid() + 1, EXPECT_CLOSE);
+    }
+
+    @Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @ConnectivityModuleTest
+    public void testFirewallCloseSocketDenylistChainAllow() {
+        doTestFirewallCloseSocket(FIREWALL_CHAIN_STANDBY, FIREWALL_RULE_ALLOW,
+                Process.myUid(), EXPECT_OPEN);
+    }
+
+    @Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @ConnectivityModuleTest
+    public void testFirewallCloseSocketDenylistChainDeny() {
+        doTestFirewallCloseSocket(FIREWALL_CHAIN_STANDBY, FIREWALL_RULE_DENY,
+                Process.myUid(), EXPECT_CLOSE);
+    }
+
+    @Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @ConnectivityModuleTest
+    public void testFirewallCloseSocketDenylistChainOtherUid() {
+        doTestFirewallCloseSocket(FIREWALL_CHAIN_STANDBY, FIREWALL_RULE_ALLOW,
+                Process.myUid() + 1, EXPECT_OPEN);
+        doTestFirewallCloseSocket(FIREWALL_CHAIN_STANDBY, FIREWALL_RULE_DENY,
+                Process.myUid() + 1, EXPECT_OPEN);
+    }
+
     private void assumeTestSApis() {
         // Cannot use @IgnoreUpTo(Build.VERSION_CODES.R) because this test also requires API 31
         // shims, and @IgnoreUpTo does not check that.
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index db7f38c..88b9baf 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -282,13 +282,17 @@
 
         fun waitForServiceDiscovered(
             serviceName: String,
+            serviceType: String,
             expectedNetwork: Network? = null
         ): NsdServiceInfo {
-            return expectCallbackEventually<ServiceFound> {
+            val serviceFound = expectCallbackEventually<ServiceFound> {
                 it.serviceInfo.serviceName == serviceName &&
                         (expectedNetwork == null ||
                                 expectedNetwork == nsdShim.getNetwork(it.serviceInfo))
             }.serviceInfo
+            // Discovered service types have a dot at the end
+            assertEquals("$serviceType.", serviceFound.serviceType)
+            return serviceFound
         }
     }
 
@@ -497,6 +501,10 @@
         val registeredInfo = registrationRecord.expectCallback<ServiceRegistered>(
                 REGISTRATION_TIMEOUT_MS).serviceInfo
 
+        // Only service name is included in ServiceRegistered callbacks
+        assertNull(registeredInfo.serviceType)
+        assertEquals(si.serviceName, registeredInfo.serviceName)
+
         val discoveryRecord = NsdDiscoveryRecord()
         // Test discovering without an Executor
         nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, discoveryRecord)
@@ -505,12 +513,15 @@
         discoveryRecord.expectCallback<DiscoveryStarted>()
 
         // Expect a service record to be discovered
-        val foundInfo = discoveryRecord.waitForServiceDiscovered(registeredInfo.serviceName)
+        val foundInfo = discoveryRecord.waitForServiceDiscovered(
+                registeredInfo.serviceName, serviceType)
 
         // Test resolving without an Executor
         val resolveRecord = NsdResolveRecord()
         nsdManager.resolveService(foundInfo, resolveRecord)
         val resolvedService = resolveRecord.expectCallback<ServiceResolved>().serviceInfo
+        assertEquals(".$serviceType", resolvedService.serviceType)
+        assertEquals(registeredInfo.serviceName, resolvedService.serviceName)
 
         // Check Txt attributes
         assertEquals(8, resolvedService.attributes.size)
@@ -538,9 +549,11 @@
         registrationRecord.expectCallback<ServiceUnregistered>()
 
         // Expect a callback for service lost
-        discoveryRecord.expectCallbackEventually<ServiceLost> {
+        val lostCb = discoveryRecord.expectCallbackEventually<ServiceLost> {
             it.serviceInfo.serviceName == serviceName
         }
+        // Lost service types have a dot at the end
+        assertEquals("$serviceType.", lostCb.serviceInfo.serviceType)
 
         // Register service again to see if NsdManager can discover it
         val si2 = NsdServiceInfo()
@@ -554,7 +567,8 @@
 
         // Expect a service record to be discovered (and filter the ones
         // that are unrelated to this test)
-        val foundInfo2 = discoveryRecord.waitForServiceDiscovered(registeredInfo2.serviceName)
+        val foundInfo2 = discoveryRecord.waitForServiceDiscovered(
+                registeredInfo2.serviceName, serviceType)
 
         // Resolve the service
         val resolveRecord2 = NsdResolveRecord()
@@ -591,7 +605,7 @@
                     testNetwork1.network, Executor { it.run() }, discoveryRecord)
 
             val foundInfo = discoveryRecord.waitForServiceDiscovered(
-                    serviceName, testNetwork1.network)
+                    serviceName, serviceType, testNetwork1.network)
             assertEquals(testNetwork1.network, nsdShim.getNetwork(foundInfo))
 
             // Rewind to ensure the service is not found on the other interface
@@ -638,6 +652,8 @@
 
             val serviceDiscovered = discoveryRecord.expectCallback<ServiceFound>()
             assertEquals(registeredInfo1.serviceName, serviceDiscovered.serviceInfo.serviceName)
+            // Discovered service types have a dot at the end
+            assertEquals("$serviceType.", serviceDiscovered.serviceInfo.serviceType)
             assertEquals(testNetwork1.network, nsdShim.getNetwork(serviceDiscovered.serviceInfo))
 
             // Unregister, then register the service back: it should be lost and found again
@@ -650,6 +666,7 @@
             val registeredInfo2 = registerService(registrationRecord, si, executor)
             val serviceDiscovered2 = discoveryRecord.expectCallback<ServiceFound>()
             assertEquals(registeredInfo2.serviceName, serviceDiscovered2.serviceInfo.serviceName)
+            assertEquals("$serviceType.", serviceDiscovered2.serviceInfo.serviceType)
             assertEquals(testNetwork1.network, nsdShim.getNetwork(serviceDiscovered2.serviceInfo))
 
             // Teardown, then bring back up a network on the test interface: the service should
@@ -665,6 +682,7 @@
             val newNetwork = newAgent.network ?: fail("Registered agent should have a network")
             val serviceDiscovered3 = discoveryRecord.expectCallback<ServiceFound>()
             assertEquals(registeredInfo2.serviceName, serviceDiscovered3.serviceInfo.serviceName)
+            assertEquals("$serviceType.", serviceDiscovered3.serviceInfo.serviceType)
             assertEquals(newNetwork, nsdShim.getNetwork(serviceDiscovered3.serviceInfo))
         } cleanupStep {
             nsdManager.stopServiceDiscovery(discoveryRecord)
@@ -740,12 +758,12 @@
             nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, discoveryRecord)
 
             val foundInfo1 = discoveryRecord.waitForServiceDiscovered(
-                    serviceName, testNetwork1.network)
+                    serviceName, serviceType, testNetwork1.network)
             assertEquals(testNetwork1.network, nsdShim.getNetwork(foundInfo1))
             // Rewind as the service could be found on each interface in any order
             discoveryRecord.nextEvents.rewind(0)
             val foundInfo2 = discoveryRecord.waitForServiceDiscovered(
-                    serviceName, testNetwork2.network)
+                    serviceName, serviceType, testNetwork2.network)
             assertEquals(testNetwork2.network, nsdShim.getNetwork(foundInfo2))
 
             nsdShim.resolveService(nsdManager, foundInfo1, Executor { it.run() }, resolveRecord)
@@ -790,7 +808,7 @@
                 testNetwork1.network, Executor { it.run() }, discoveryRecord)
             // Expect that service is found on testNetwork1
             val foundInfo = discoveryRecord.waitForServiceDiscovered(
-                serviceName, testNetwork1.network)
+                serviceName, serviceType, testNetwork1.network)
             assertEquals(testNetwork1.network, nsdShim.getNetwork(foundInfo))
 
             // Discover service on testNetwork2.
@@ -805,7 +823,7 @@
                 null as Network? /* network */, Executor { it.run() }, discoveryRecord3)
             // Expect that service is found on testNetwork1
             val foundInfo3 = discoveryRecord3.waitForServiceDiscovered(
-                    serviceName, testNetwork1.network)
+                    serviceName, serviceType, testNetwork1.network)
             assertEquals(testNetwork1.network, nsdShim.getNetwork(foundInfo3))
         } cleanupStep {
             nsdManager.stopServiceDiscovery(discoveryRecord2)
@@ -835,7 +853,7 @@
             nsdManager.discoverServices(
                 serviceType, NsdManager.PROTOCOL_DNS_SD, discoveryRecord
             )
-            val foundInfo = discoveryRecord.waitForServiceDiscovered(serviceNames)
+            val foundInfo = discoveryRecord.waitForServiceDiscovered(serviceNames, serviceType)
 
             // Expect that resolving the service name works properly even service name contains
             // non-standard characters.
@@ -985,7 +1003,7 @@
             nsdShim.discoverServices(nsdManager, serviceType, NsdManager.PROTOCOL_DNS_SD,
                     testNetwork1.network, Executor { it.run() }, discoveryRecord)
             val foundInfo = discoveryRecord.waitForServiceDiscovered(
-                    serviceName, testNetwork1.network)
+                    serviceName, serviceType, testNetwork1.network)
 
             // Register service callback and check the addresses are the same as network addresses
             nsdShim.registerServiceInfoCallback(nsdManager, foundInfo, { it.run() }, cbRecord)
diff --git a/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java b/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java
index 0c4f794..ce789fc 100644
--- a/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java
+++ b/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java
@@ -426,7 +426,7 @@
                 .build();
     }
 
-    private void testHttpRequest(Socket s) throws IOException {
+    public void testHttpRequest(Socket s) throws IOException {
         OutputStream out = s.getOutputStream();
         InputStream in = s.getInputStream();
 
@@ -434,7 +434,9 @@
         byte[] responseBytes = new byte[4096];
         out.write(requestBytes);
         in.read(responseBytes);
-        assertTrue(new String(responseBytes, "UTF-8").startsWith("HTTP/1.0 204 No Content\r\n"));
+        final String response = new String(responseBytes, "UTF-8");
+        assertTrue("Received unexpected response: " + response,
+                response.startsWith("HTTP/1.0 204 No Content\r\n"));
     }
 
     private Socket getBoundSocket(Network network, String host, int port) throws IOException {
diff --git a/tests/unit/java/com/android/server/BpfNetMapsTest.java b/tests/unit/java/com/android/server/BpfNetMapsTest.java
index d189848..19fa41d 100644
--- a/tests/unit/java/com/android/server/BpfNetMapsTest.java
+++ b/tests/unit/java/com/android/server/BpfNetMapsTest.java
@@ -66,6 +66,7 @@
 import android.os.Build;
 import android.os.ServiceSpecificException;
 import android.system.ErrnoException;
+import android.util.ArraySet;
 import android.util.IndentingPrintWriter;
 
 import androidx.test.filters.SmallTest;
@@ -1151,4 +1152,33 @@
         mCookieTagMap.updateEntry(new CookieTagMapKey(123), new CookieTagMapValue(456, 0x789));
         assertDumpContains(getDump(), "cookie=123 tag=0x789 uid=456");
     }
+
+    @Test
+    public void testGetUids() throws ErrnoException {
+        final int uid0 = TEST_UIDS[0];
+        final int uid1 = TEST_UIDS[1];
+        final long match0 = DOZABLE_MATCH | POWERSAVE_MATCH;
+        final long match1 = DOZABLE_MATCH | STANDBY_MATCH;
+        mUidOwnerMap.updateEntry(new S32(uid0), new UidOwnerValue(NULL_IIF, match0));
+        mUidOwnerMap.updateEntry(new S32(uid1), new UidOwnerValue(NULL_IIF, match1));
+
+        assertEquals(new ArraySet<>(List.of(uid0, uid1)),
+                mBpfNetMaps.getUidsWithAllowRuleOnAllowListChain(FIREWALL_CHAIN_DOZABLE));
+        assertEquals(new ArraySet<>(List.of(uid0)),
+                mBpfNetMaps.getUidsWithAllowRuleOnAllowListChain(FIREWALL_CHAIN_POWERSAVE));
+
+        assertEquals(new ArraySet<>(List.of(uid1)),
+                mBpfNetMaps.getUidsWithDenyRuleOnDenyListChain(FIREWALL_CHAIN_STANDBY));
+        assertEquals(new ArraySet<>(),
+                mBpfNetMaps.getUidsWithDenyRuleOnDenyListChain(FIREWALL_CHAIN_OEM_DENY_1));
+    }
+
+    @Test
+    public void testGetUidsIllegalArgument() {
+        final Class<IllegalArgumentException> expected = IllegalArgumentException.class;
+        assertThrows(expected,
+                () -> mBpfNetMaps.getUidsWithDenyRuleOnDenyListChain(FIREWALL_CHAIN_DOZABLE));
+        assertThrows(expected,
+                () -> mBpfNetMaps.getUidsWithAllowRuleOnAllowListChain(FIREWALL_CHAIN_OEM_DENY_1));
+    }
 }
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 9d7b21f..31f3124 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -2173,6 +2173,11 @@
                 final Set<Integer> exemptUids) {
             // This function is empty since the invocation of this method is verified by mocks
         }
+
+        @Override
+        public void destroyLiveTcpSocketsByOwnerUids(final Set<Integer> ownerUids) {
+            // This function is empty since the invocation of this method is verified by mocks
+        }
     }
 
     private class AutomaticOnOffKeepaliveTrackerDependencies
@@ -10269,6 +10274,50 @@
         }
     }
 
+    private void doTestSetFirewallChainEnabledCloseSocket(final int chain,
+            final boolean isAllowList) throws Exception {
+        reset(mDeps);
+
+        mCm.setFirewallChainEnabled(chain, true /* enabled */);
+        final Set<Integer> uids =
+                new ArraySet<>(List.of(TEST_PACKAGE_UID, TEST_PACKAGE_UID2));
+        if (isAllowList) {
+            final Set<Range<Integer>> range = new ArraySet<>(
+                    List.of(new Range<>(Process.FIRST_APPLICATION_UID, Integer.MAX_VALUE)));
+            verify(mDeps).destroyLiveTcpSockets(range, uids);
+        } else {
+            verify(mDeps).destroyLiveTcpSocketsByOwnerUids(uids);
+        }
+
+        mCm.setFirewallChainEnabled(chain, false /* enabled */);
+        verifyNoMoreInteractions(mDeps);
+    }
+
+    @Test @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
+    public void testSetFirewallChainEnabledCloseSocket() throws Exception {
+        doReturn(new ArraySet<>(Arrays.asList(TEST_PACKAGE_UID, TEST_PACKAGE_UID2)))
+                .when(mBpfNetMaps)
+                .getUidsWithDenyRuleOnDenyListChain(anyInt());
+        doReturn(new ArraySet<>(Arrays.asList(TEST_PACKAGE_UID, TEST_PACKAGE_UID2)))
+                .when(mBpfNetMaps)
+                .getUidsWithAllowRuleOnAllowListChain(anyInt());
+
+        final boolean allowlist = true;
+        final boolean denylist = false;
+
+        doReturn(true).when(mBpfNetMaps).isFirewallAllowList(anyInt());
+        doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_DOZABLE, allowlist);
+        doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_POWERSAVE, allowlist);
+        doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_RESTRICTED, allowlist);
+        doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_LOW_POWER_STANDBY, allowlist);
+
+        doReturn(false).when(mBpfNetMaps).isFirewallAllowList(anyInt());
+        doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_STANDBY, denylist);
+        doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_OEM_DENY_1, denylist);
+        doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_OEM_DENY_2, denylist);
+        doTestSetFirewallChainEnabledCloseSocket(FIREWALL_CHAIN_OEM_DENY_3, denylist);
+    }
+
     private void doTestReplaceFirewallChain(final int chain) {
         final int[] uids = new int[] {1001, 1002};
         mCm.replaceFirewallChain(chain, uids);
diff --git a/tests/unit/java/com/android/server/NsdServiceTest.java b/tests/unit/java/com/android/server/NsdServiceTest.java
index 14b4280..b3e8cc8 100644
--- a/tests/unit/java/com/android/server/NsdServiceTest.java
+++ b/tests/unit/java/com/android/server/NsdServiceTest.java
@@ -23,7 +23,7 @@
 import static android.net.nsd.NsdManager.FAILURE_INTERNAL_ERROR;
 import static android.net.nsd.NsdManager.FAILURE_OPERATION_NOT_RUNNING;
 
-import static com.android.server.NsdService.constructServiceType;
+import static com.android.server.NsdService.parseTypeAndSubtype;
 import static com.android.testutils.ContextUtils.mockService;
 
 import static libcore.junit.util.compat.CoreCompatChangeRule.DisableCompatChanges;
@@ -77,6 +77,7 @@
 import android.os.Looper;
 import android.os.Message;
 import android.os.RemoteException;
+import android.util.Pair;
 
 import androidx.annotation.NonNull;
 import androidx.test.filters.SmallTest;
@@ -84,6 +85,7 @@
 import com.android.server.NsdService.Dependencies;
 import com.android.server.connectivity.mdns.MdnsAdvertiser;
 import com.android.server.connectivity.mdns.MdnsDiscoveryManager;
+import com.android.server.connectivity.mdns.MdnsSearchOptions;
 import com.android.server.connectivity.mdns.MdnsServiceBrowserListener;
 import com.android.server.connectivity.mdns.MdnsServiceInfo;
 import com.android.server.connectivity.mdns.MdnsSocketProvider;
@@ -104,6 +106,7 @@
 
 import java.net.InetAddress;
 import java.net.UnknownHostException;
+import java.util.Collections;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Objects;
@@ -908,7 +911,8 @@
         listener.onServiceNameDiscovered(foundInfo);
         verify(discListener, timeout(TIMEOUT_MS)).onServiceFound(argThat(info ->
                 info.getServiceName().equals(SERVICE_NAME)
-                        && info.getServiceType().equals(SERVICE_TYPE)
+                        // Service type in discovery callbacks has a dot at the end
+                        && info.getServiceType().equals(SERVICE_TYPE + ".")
                         && info.getNetwork().equals(network)));
 
         final MdnsServiceInfo removedInfo = new MdnsServiceInfo(
@@ -927,7 +931,8 @@
         listener.onServiceNameRemoved(removedInfo);
         verify(discListener, timeout(TIMEOUT_MS)).onServiceLost(argThat(info ->
                 info.getServiceName().equals(SERVICE_NAME)
-                        && info.getServiceType().equals(SERVICE_TYPE)
+                        // Service type in discovery callbacks has a dot at the end
+                        && info.getServiceType().equals(SERVICE_TYPE + ".")
                         && info.getNetwork().equals(network)));
 
         client.stopServiceDiscovery(discListener);
@@ -967,6 +972,34 @@
     }
 
     @Test
+    @EnableCompatChanges(ENABLE_PLATFORM_MDNS_BACKEND)
+    public void testDiscoveryWithMdnsDiscoveryManager_UsesSubtypes() {
+        final String typeWithSubtype = SERVICE_TYPE + ",_subtype";
+        final NsdManager client = connectClient(mService);
+        final NsdServiceInfo regInfo = new NsdServiceInfo("Instance", typeWithSubtype);
+        final Network network = new Network(999);
+        regInfo.setHostAddresses(List.of(parseNumericAddress("192.0.2.123")));
+        regInfo.setPort(12345);
+        regInfo.setNetwork(network);
+
+        final RegistrationListener regListener = mock(RegistrationListener.class);
+        client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
+        waitForIdle();
+        verify(mAdvertiser).addService(anyInt(), argThat(s ->
+                "Instance".equals(s.getServiceName())
+                        && SERVICE_TYPE.equals(s.getServiceType())), eq("_subtype"));
+
+        final DiscoveryListener discListener = mock(DiscoveryListener.class);
+        client.discoverServices(typeWithSubtype, PROTOCOL, network, Runnable::run, discListener);
+        waitForIdle();
+        final ArgumentCaptor<MdnsSearchOptions> optionsCaptor =
+                ArgumentCaptor.forClass(MdnsSearchOptions.class);
+        verify(mDiscoveryManager).registerListener(eq(SERVICE_TYPE + ".local"), any(),
+                optionsCaptor.capture());
+        assertEquals(Collections.singletonList("subtype"), optionsCaptor.getValue().getSubtypes());
+    }
+
+    @Test
     public void testResolutionWithMdnsDiscoveryManager() throws UnknownHostException {
         setMdnsDiscoveryManagerEnabled();
 
@@ -974,7 +1007,7 @@
         final ResolveListener resolveListener = mock(ResolveListener.class);
         final Network network = new Network(999);
         final String serviceType = "_nsd._service._tcp";
-        final String constructedServiceType = "_nsd._sub._service._tcp.local";
+        final String constructedServiceType = "_service._tcp.local";
         final ArgumentCaptor<MdnsServiceBrowserListener> listenerCaptor =
                 ArgumentCaptor.forClass(MdnsServiceBrowserListener.class);
         final NsdServiceInfo request = new NsdServiceInfo(SERVICE_NAME, serviceType);
@@ -982,10 +1015,14 @@
         client.resolveService(request, resolveListener);
         waitForIdle();
         verify(mSocketProvider).startMonitoringSockets();
+        final ArgumentCaptor<MdnsSearchOptions> optionsCaptor =
+                ArgumentCaptor.forClass(MdnsSearchOptions.class);
         verify(mDiscoveryManager).registerListener(eq(constructedServiceType),
-                listenerCaptor.capture(), argThat(options ->
-                        network.equals(options.getNetwork())
-                                && SERVICE_NAME.equals(options.getResolveInstanceName())));
+                listenerCaptor.capture(),
+                optionsCaptor.capture());
+        assertEquals(network, optionsCaptor.getValue().getNetwork());
+        // Subtypes are not used for resolution, only for discovery
+        assertEquals(Collections.emptyList(), optionsCaptor.getValue().getSubtypes());
 
         final MdnsServiceBrowserListener listener = listenerCaptor.getValue();
         final MdnsServiceInfo mdnsServiceInfo = new MdnsServiceInfo(
@@ -1009,7 +1046,7 @@
         verify(resolveListener, timeout(TIMEOUT_MS)).onServiceResolved(infoCaptor.capture());
         final NsdServiceInfo info = infoCaptor.getValue();
         assertEquals(SERVICE_NAME, info.getServiceName());
-        assertEquals("." + serviceType, info.getServiceType());
+        assertEquals("._service._tcp", info.getServiceType());
         assertEquals(PORT, info.getPort());
         assertTrue(info.getAttributes().containsKey("key"));
         assertEquals(1, info.getAttributes().size());
@@ -1052,7 +1089,7 @@
 
         final ArgumentCaptor<Integer> serviceIdCaptor = ArgumentCaptor.forClass(Integer.class);
         verify(mAdvertiser).addService(serviceIdCaptor.capture(),
-                argThat(info -> matches(info, regInfo)));
+                argThat(info -> matches(info, regInfo)), eq(null) /* subtype */);
 
         client.unregisterService(regListenerWithoutFeature);
         waitForIdle();
@@ -1109,8 +1146,10 @@
         waitForIdle();
 
         // The advertiser is enabled for _type2 but not _type1
-        verify(mAdvertiser, never()).addService(anyInt(), argThat(info -> matches(info, service1)));
-        verify(mAdvertiser).addService(anyInt(), argThat(info -> matches(info, service2)));
+        verify(mAdvertiser, never()).addService(
+                anyInt(), argThat(info -> matches(info, service1)), eq(null) /* subtype */);
+        verify(mAdvertiser).addService(
+                anyInt(), argThat(info -> matches(info, service2)), eq(null) /* subtype */);
     }
 
     @Test
@@ -1135,7 +1174,7 @@
         verify(mSocketProvider).startMonitoringSockets();
         final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
         verify(mAdvertiser).addService(idCaptor.capture(), argThat(info ->
-                matches(info, regInfo)));
+                matches(info, regInfo)), eq(null) /* subtype */);
 
         // Verify onServiceRegistered callback
         final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue();
@@ -1171,7 +1210,7 @@
 
         client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
         waitForIdle();
-        verify(mAdvertiser, never()).addService(anyInt(), any());
+        verify(mAdvertiser, never()).addService(anyInt(), any(), any());
 
         verify(regListener, timeout(TIMEOUT_MS)).onRegistrationFailed(
                 argThat(info -> matches(info, regInfo)), eq(FAILURE_INTERNAL_ERROR));
@@ -1199,7 +1238,8 @@
         final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
         // Service name is truncated to 63 characters
         verify(mAdvertiser).addService(idCaptor.capture(),
-                argThat(info -> info.getServiceName().equals("a".repeat(63))));
+                argThat(info -> info.getServiceName().equals("a".repeat(63))),
+                eq(null) /* subtype */);
 
         // Verify onServiceRegistered callback
         final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue();
@@ -1217,7 +1257,7 @@
         final ResolveListener resolveListener = mock(ResolveListener.class);
         final Network network = new Network(999);
         final String serviceType = "_nsd._service._tcp";
-        final String constructedServiceType = "_nsd._sub._service._tcp.local";
+        final String constructedServiceType = "_service._tcp.local";
         final ArgumentCaptor<MdnsServiceBrowserListener> listenerCaptor =
                 ArgumentCaptor.forClass(MdnsServiceBrowserListener.class);
         final NsdServiceInfo request = new NsdServiceInfo(SERVICE_NAME, serviceType);
@@ -1225,8 +1265,14 @@
         client.resolveService(request, resolveListener);
         waitForIdle();
         verify(mSocketProvider).startMonitoringSockets();
+        final ArgumentCaptor<MdnsSearchOptions> optionsCaptor =
+                ArgumentCaptor.forClass(MdnsSearchOptions.class);
         verify(mDiscoveryManager).registerListener(eq(constructedServiceType),
-                listenerCaptor.capture(), argThat(options -> network.equals(options.getNetwork())));
+                listenerCaptor.capture(),
+                optionsCaptor.capture());
+        assertEquals(network, optionsCaptor.getValue().getNetwork());
+        // Subtypes are not used for resolution, only for discovery
+        assertEquals(Collections.emptyList(), optionsCaptor.getValue().getSubtypes());
 
         client.stopServiceResolution(resolveListener);
         waitForIdle();
@@ -1241,16 +1287,22 @@
     }
 
     @Test
-    public void testConstructServiceType() {
+    public void testParseTypeAndSubtype() {
         final String serviceType1 = "test._tcp";
         final String serviceType2 = "_test._quic";
-        final String serviceType3 = "_123._udp.";
-        final String serviceType4 = "_TEST._999._tcp.";
+        final String serviceType3 = "_test._quic,_test1,_test2";
+        final String serviceType4 = "_123._udp.";
+        final String serviceType5 = "_TEST._999._tcp.";
+        final String serviceType6 = "_998._tcp.,_TEST";
+        final String serviceType7 = "_997._tcp,_TEST";
 
-        assertEquals(null, constructServiceType(serviceType1));
-        assertEquals(null, constructServiceType(serviceType2));
-        assertEquals("_123._udp", constructServiceType(serviceType3));
-        assertEquals("_TEST._sub._999._tcp", constructServiceType(serviceType4));
+        assertNull(parseTypeAndSubtype(serviceType1));
+        assertNull(parseTypeAndSubtype(serviceType2));
+        assertNull(parseTypeAndSubtype(serviceType3));
+        assertEquals(new Pair<>("_123._udp", null), parseTypeAndSubtype(serviceType4));
+        assertEquals(new Pair<>("_999._tcp", "_TEST"), parseTypeAndSubtype(serviceType5));
+        assertEquals(new Pair<>("_998._tcp", "_TEST"), parseTypeAndSubtype(serviceType6));
+        assertEquals(new Pair<>("_997._tcp", "_TEST"), parseTypeAndSubtype(serviceType7));
     }
 
     @Test
@@ -1269,7 +1321,7 @@
         client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
         waitForIdle();
         verify(mSocketProvider).startMonitoringSockets();
-        verify(mAdvertiser).addService(anyInt(), any());
+        verify(mAdvertiser).addService(anyInt(), any(), any());
 
         // Verify the discovery uses MdnsDiscoveryManager
         final DiscoveryListener discListener = mock(DiscoveryListener.class);
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
index 3bb08a6..b539fe0 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
@@ -57,6 +57,7 @@
 private val TEST_NETWORK_1 = mock(Network::class.java)
 private val TEST_NETWORK_2 = mock(Network::class.java)
 private val TEST_HOSTNAME = arrayOf("Android_test", "local")
+private const val TEST_SUBTYPE = "_subtype"
 
 private val SERVICE_1 = NsdServiceInfo("TestServiceName", "_advertisertest._tcp").apply {
     port = 12345
@@ -130,7 +131,7 @@
     @Test
     fun testAddService_OneNetwork() {
         val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog)
-        postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) }
+        postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) }
 
         val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
         verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture())
@@ -161,7 +162,7 @@
     @Test
     fun testAddService_AllNetworks() {
         val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog)
-        postSync { advertiser.addService(SERVICE_ID_1, ALL_NETWORKS_SERVICE) }
+        postSync { advertiser.addService(SERVICE_ID_1, ALL_NETWORKS_SERVICE, TEST_SUBTYPE) }
 
         val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
         verify(socketProvider).requestSocket(eq(ALL_NETWORKS_SERVICE.network),
@@ -179,6 +180,10 @@
         verify(mockDeps).makeAdvertiser(eq(mockSocket2), eq(listOf(TEST_LINKADDR)),
                 eq(thread.looper), any(), intAdvCbCaptor2.capture(), eq(TEST_HOSTNAME), any()
         )
+        verify(mockInterfaceAdvertiser1).addService(
+                anyInt(), eq(ALL_NETWORKS_SERVICE), eq(TEST_SUBTYPE))
+        verify(mockInterfaceAdvertiser2).addService(
+                anyInt(), eq(ALL_NETWORKS_SERVICE), eq(TEST_SUBTYPE))
 
         doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1)
         postSync { intAdvCbCaptor1.value.onRegisterServiceSucceeded(
@@ -207,20 +212,21 @@
     @Test
     fun testAddService_Conflicts() {
         val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog)
-        postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) }
+        postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) }
 
         val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
         verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture())
         val oneNetSocketCb = oneNetSocketCbCaptor.value
 
         // Register a service with the same name on all networks (name conflict)
-        postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE) }
+        postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE, null /* subtype */) }
         val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
         verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture())
         val allNetSocketCb = allNetSocketCbCaptor.value
 
-        postSync { advertiser.addService(LONG_SERVICE_ID_1, LONG_SERVICE_1) }
-        postSync { advertiser.addService(LONG_SERVICE_ID_2, LONG_ALL_NETWORKS_SERVICE) }
+        postSync { advertiser.addService(LONG_SERVICE_ID_1, LONG_SERVICE_1, null /* subtype */) }
+        postSync { advertiser.addService(LONG_SERVICE_ID_2, LONG_ALL_NETWORKS_SERVICE,
+                null /* subtype */) }
 
         // Callbacks for matching network and all networks both get the socket
         postSync {
@@ -248,13 +254,13 @@
                 eq(thread.looper), any(), intAdvCbCaptor.capture(), eq(TEST_HOSTNAME), any()
         )
         verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1),
-                argThat { it.matches(SERVICE_1) })
+                argThat { it.matches(SERVICE_1) }, eq(null))
         verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_2),
-                argThat { it.matches(expectedRenamed) })
+                argThat { it.matches(expectedRenamed) }, eq(null))
         verify(mockInterfaceAdvertiser1).addService(eq(LONG_SERVICE_ID_1),
-                argThat { it.matches(LONG_SERVICE_1) })
+                argThat { it.matches(LONG_SERVICE_1) }, eq(null))
         verify(mockInterfaceAdvertiser1).addService(eq(LONG_SERVICE_ID_2),
-            argThat { it.matches(expectedLongRenamed) })
+            argThat { it.matches(expectedLongRenamed) }, eq(null))
 
         doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1)
         postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
@@ -278,7 +284,7 @@
     fun testRemoveService_whenAllServiceRemoved_thenUpdateHostName() {
         val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog)
         verify(mockDeps, times(1)).generateHostname()
-        postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) }
+        postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) }
         postSync { advertiser.removeService(SERVICE_ID_1) }
         verify(mockDeps, times(2)).generateHostname()
     }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index ee190af..dd458b8 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -117,7 +117,7 @@
             knownServices.add(inv.getArgument(0))
 
             -1
-        }.`when`(repository).addService(anyInt(), any())
+        }.`when`(repository).addService(anyInt(), any(), any())
         doAnswer { inv ->
             knownServices.remove(inv.getArgument(0))
             null
@@ -278,8 +278,8 @@
         doReturn(serviceId).`when`(testProbingInfo).serviceId
         doReturn(testProbingInfo).`when`(repository).setServiceProbing(serviceId)
 
-        advertiser.addService(serviceId, serviceInfo)
-        verify(repository).addService(serviceId, serviceInfo)
+        advertiser.addService(serviceId, serviceInfo, null /* subtype */)
+        verify(repository).addService(serviceId, serviceInfo, null /* subtype */)
         verify(prober).startProbing(testProbingInfo)
 
         // Simulate probing success: continues to announcing
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index 44e0d08..4a39b93 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -44,6 +44,7 @@
 private const val TEST_SERVICE_ID_1 = 42
 private const val TEST_SERVICE_ID_2 = 43
 private const val TEST_PORT = 12345
+private const val TEST_SUBTYPE = "_subtype"
 private val TEST_HOSTNAME = arrayOf("Android_000102030405060708090A0B0C0D0E0F", "local")
 private val TEST_ADDRESSES = listOf(
         LinkAddress(parseNumericAddress("192.0.2.111"), 24),
@@ -86,7 +87,8 @@
     fun testAddServiceAndProbe() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
         assertEquals(0, repository.servicesCount)
-        assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1))
+        assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
+                null /* subtype */))
         assertEquals(1, repository.servicesCount)
 
         val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
@@ -118,18 +120,18 @@
     @Test
     fun testAddAndConflicts() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
-        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
         assertFailsWith(NameConflictException::class) {
-            repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1)
+            repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */)
         }
     }
 
     @Test
     fun testInvalidReuseOfServiceId() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
-        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
         assertFailsWith(IllegalArgumentException::class) {
-            repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2)
+            repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2, null /* subtype */)
         }
     }
 
@@ -138,7 +140,7 @@
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
         assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1))
 
-        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
         assertTrue(repository.hasActiveService(TEST_SERVICE_ID_1))
 
         val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
@@ -180,13 +182,49 @@
     }
 
     @Test
+    fun testExitAnnouncements_WithSubtype() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
+        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, TEST_SUBTYPE)
+        repository.onAdvertisementSent(TEST_SERVICE_ID_1)
+
+        val exitAnnouncement = repository.exitService(TEST_SERVICE_ID_1)
+        assertNotNull(exitAnnouncement)
+        assertEquals(1, repository.servicesCount)
+        val packet = exitAnnouncement.getPacket(0)
+
+        assertEquals(0x8400 /* response, authoritative */, packet.flags)
+        assertEquals(0, packet.questions.size)
+        assertEquals(0, packet.authorityRecords.size)
+        assertEquals(0, packet.additionalRecords.size)
+
+        assertContentEquals(listOf(
+                MdnsPointerRecord(
+                        arrayOf("_testservice", "_tcp", "local"),
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        0L /* ttlMillis */,
+                        arrayOf("MyTestService", "_testservice", "_tcp", "local")),
+                MdnsPointerRecord(
+                        arrayOf("_subtype", "_sub", "_testservice", "_tcp", "local"),
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        0L /* ttlMillis */,
+                        arrayOf("MyTestService", "_testservice", "_tcp", "local")),
+        ), packet.answers)
+
+        repository.removeService(TEST_SERVICE_ID_1)
+        assertEquals(0, repository.servicesCount)
+    }
+
+    @Test
     fun testExitingServiceReAdded() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         repository.onAdvertisementSent(TEST_SERVICE_ID_1)
         repository.exitService(TEST_SERVICE_ID_1)
 
-        assertEquals(TEST_SERVICE_ID_1, repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1))
+        assertEquals(TEST_SERVICE_ID_1,
+                repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */))
         assertEquals(1, repository.servicesCount)
 
         repository.removeService(TEST_SERVICE_ID_2)
@@ -196,7 +234,8 @@
     @Test
     fun testOnProbingSucceeded() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
-        val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
+                TEST_SUBTYPE)
         repository.onAdvertisementSent(TEST_SERVICE_ID_1)
         val packet = announcementInfo.getPacket(0)
 
@@ -205,6 +244,7 @@
         assertEquals(0, packet.authorityRecords.size)
 
         val serviceType = arrayOf("_testservice", "_tcp", "local")
+        val serviceSubtype = arrayOf(TEST_SUBTYPE, "_sub", "_testservice", "_tcp", "local")
         val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
         val v4AddrRev = getReverseDnsAddress(TEST_ADDRESSES[0].address)
         val v6Addr1Rev = getReverseDnsAddress(TEST_ADDRESSES[1].address)
@@ -250,6 +290,13 @@
                         false /* cacheFlush */,
                         4500000L /* ttlMillis */,
                         serviceName),
+                MdnsPointerRecord(
+                        serviceSubtype,
+                        0L /* receiptTimeMillis */,
+                        // Not a unique name owned by the announcer, so cacheFlush=false
+                        false /* cacheFlush */,
+                        4500000L /* ttlMillis */,
+                        serviceName),
                 MdnsServiceRecord(
                         serviceName,
                         0L /* receiptTimeMillis */,
@@ -319,9 +366,21 @@
 
     @Test
     fun testGetReply() {
+        doGetReplyTest(subtype = null)
+    }
+
+    @Test
+    fun testGetReply_WithSubtype() {
+        doGetReplyTest(TEST_SUBTYPE)
+    }
+
+    private fun doGetReplyTest(subtype: String?) {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
-        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
-        val questions = listOf(MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"),
+        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, subtype)
+        val queriedName = if (subtype == null) arrayOf("_testservice", "_tcp", "local")
+        else arrayOf(subtype, "_sub", "_testservice", "_tcp", "local")
+
+        val questions = listOf(MdnsPointerRecord(queriedName,
                 0L /* receiptTimeMillis */,
                 false /* cacheFlush */,
                 // TTL and data is empty for a question
@@ -344,7 +403,7 @@
 
         assertEquals(listOf(
                 MdnsPointerRecord(
-                        arrayOf("_testservice", "_tcp", "local"),
+                        queriedName,
                         0L /* receiptTimeMillis */,
                         false /* cacheFlush */,
                         longTtl,
@@ -405,8 +464,8 @@
     @Test
     fun testGetConflictingServices() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
-        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
-        repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)
+        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
+        repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */)
 
         val packet = MdnsPacket(
                 0 /* flags */,
@@ -433,8 +492,8 @@
     @Test
     fun testGetConflictingServices_IdenticalService() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
-        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
-        repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)
+        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
+        repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */)
 
         val otherTtlMillis = 1234L
         val packet = MdnsPacket(
@@ -460,10 +519,13 @@
     }
 }
 
-private fun MdnsRecordRepository.initWithService(serviceId: Int, serviceInfo: NsdServiceInfo):
-        AnnouncementInfo {
+private fun MdnsRecordRepository.initWithService(
+    serviceId: Int,
+    serviceInfo: NsdServiceInfo,
+    subtype: String? = null
+): AnnouncementInfo {
     updateAddresses(TEST_ADDRESSES)
-    addService(serviceId, serviceInfo)
+    addService(serviceId, serviceInfo, subtype)
     val probingInfo = setServiceProbing(serviceId)
     assertNotNull(probingInfo)
     return onProbingSucceeded(probingInfo)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index 98cea62..bd59156 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -42,6 +42,7 @@
 import android.net.Network;
 import android.text.TextUtils;
 
+import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.SharedLog;
 import com.android.server.connectivity.mdns.MdnsServiceInfo.TextEntry;
 import com.android.server.connectivity.mdns.MdnsServiceTypeClient.QueryTaskConfig;
@@ -53,6 +54,7 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
+import org.mockito.ArgumentMatcher;
 import org.mockito.Captor;
 import org.mockito.InOrder;
 import org.mockito.Mock;
@@ -1086,6 +1088,87 @@
     }
 
     @Test
+    public void testProcessResponse_SubtypeDiscoveryLimitedToSubtype() {
+        client = new MdnsServiceTypeClient(
+                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);
+
+        final String matchingInstance = "instance1";
+        final String subtype = "_subtype";
+        final String otherInstance = "instance2";
+        final String ipV4Address = "192.0.2.0";
+        final String ipV6Address = "2001:db8::";
+
+        final MdnsSearchOptions options = MdnsSearchOptions.newBuilder()
+                // Search with different case. Note MdnsSearchOptions subtype doesn't start with "_"
+                .addSubtype("Subtype").build();
+
+        client.startSendAndReceive(mockListenerOne, options);
+        client.startSendAndReceive(mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
+
+        // Complete response from matchingInstance (its subtype PTR record is added below)
+        final MdnsPacket packetWithoutSubtype = createResponse(
+                matchingInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
+                Collections.emptyMap() /* textAttributes */, TEST_TTL);
+        final MdnsPointerRecord originalPtr = (MdnsPointerRecord) CollectionUtils.findFirst(
+                packetWithoutSubtype.answers, r -> r instanceof MdnsPointerRecord);
+
+        // Add a subtype PTR record
+        final ArrayList<MdnsRecord> newAnswers = new ArrayList<>(packetWithoutSubtype.answers);
+        newAnswers.add(new MdnsPointerRecord(
+                // PTR should be _subtype._sub._type._tcp.local -> instance1._type._tcp.local
+                Stream.concat(Stream.of(subtype, "_sub"), Arrays.stream(SERVICE_TYPE_LABELS))
+                        .toArray(String[]::new),
+                originalPtr.getReceiptTime(), originalPtr.getCacheFlush(), originalPtr.getTtl(),
+                originalPtr.getPointer()));
+        final MdnsPacket packetWithSubtype = new MdnsPacket(
+                packetWithoutSubtype.flags,
+                packetWithoutSubtype.questions,
+                newAnswers,
+                packetWithoutSubtype.authorityRecords,
+                packetWithoutSubtype.additionalRecords);
+        client.processResponse(packetWithSubtype, INTERFACE_INDEX, mockNetwork);
+
+        // Complete response from otherInstance, without subtype
+        client.processResponse(createResponse(
+                        otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
+                        Collections.emptyMap() /* textAttributes */, TEST_TTL),
+                INTERFACE_INDEX, mockNetwork);
+
+        // Address update from otherInstance
+        client.processResponse(createResponse(
+                otherInstance, ipV6Address, 5353, SERVICE_TYPE_LABELS,
+                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+
+        // Goodbye from otherInstance
+        client.processResponse(createResponse(
+                otherInstance, ipV6Address, 5353, SERVICE_TYPE_LABELS,
+                Collections.emptyMap(), 0L /* ttl */), INTERFACE_INDEX, mockNetwork);
+
+        // mockListenerOne gets notified for the instance advertising the requested subtype
+        final ArgumentMatcher<MdnsServiceInfo> subtypeInstanceMatcher = info ->
+                info.getServiceInstanceName().equals(matchingInstance)
+                        && info.getSubtypes().equals(Collections.singletonList(subtype));
+        verify(mockListenerOne).onServiceNameDiscovered(argThat(subtypeInstanceMatcher));
+        verify(mockListenerOne).onServiceFound(argThat(subtypeInstanceMatcher));
+
+        // ...but does not get any callback for the other instance
+        verify(mockListenerOne, never()).onServiceFound(matchServiceName(otherInstance));
+        verify(mockListenerOne, never()).onServiceNameDiscovered(matchServiceName(otherInstance));
+        verify(mockListenerOne, never()).onServiceUpdated(matchServiceName(otherInstance));
+        verify(mockListenerOne, never()).onServiceRemoved(matchServiceName(otherInstance));
+
+        // mockListenerTwo gets notified for both though
+        final InOrder inOrder = inOrder(mockListenerTwo);
+        inOrder.verify(mockListenerTwo).onServiceNameDiscovered(argThat(subtypeInstanceMatcher));
+        inOrder.verify(mockListenerTwo).onServiceFound(argThat(subtypeInstanceMatcher));
+
+        inOrder.verify(mockListenerTwo).onServiceNameDiscovered(matchServiceName(otherInstance));
+        inOrder.verify(mockListenerTwo).onServiceFound(matchServiceName(otherInstance));
+        inOrder.verify(mockListenerTwo).onServiceUpdated(matchServiceName(otherInstance));
+        inOrder.verify(mockListenerTwo).onServiceRemoved(matchServiceName(otherInstance));
+    }
+
+    @Test
     public void testNotifyAllServicesRemoved() {
         client = new MdnsServiceTypeClient(
                 SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);