Merge changes from topic "cherrypicker-L79500000960657219:N07500001368857425" into udc-dev am: 60437e59de am: 2eaf11d4b5

Original change: https://googleplex-android-review.googlesource.com/c/platform/packages/modules/Connectivity/+/23232553

Change-Id: I79095eb2d7120a9983aa442cb995ef29497e6da7
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 9a2cc5f..4af4c6a 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -727,7 +727,7 @@
                             // service type would generate service instance names like
                             // Name._subtype._sub._type._tcp, which is incorrect
                             // (it should be Name._type._tcp).
-                            mAdvertiser.addService(id, serviceInfo);
+                            mAdvertiser.addService(id, serviceInfo, typeSubtype.second);
                             storeAdvertiserRequestMap(clientId, id, clientInfo);
                         } else {
                             maybeStartDaemon();
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
index 655c364..cc08ea1 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
@@ -270,7 +270,8 @@
             mPendingRegistrations.put(id, registration);
             for (int i = 0; i < mAdvertisers.size(); i++) {
                 try {
-                    mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo());
+                    mAdvertisers.valueAt(i).addService(
+                            id, registration.getServiceInfo(), registration.getSubtype());
                 } catch (NameConflictException e) {
                     Log.wtf(TAG, "Name conflict adding services that should have unique names", e);
                 }
@@ -298,9 +299,10 @@
             }
             mAdvertisers.put(socket, advertiser);
             for (int i = 0; i < mPendingRegistrations.size(); i++) {
+                final Registration registration = mPendingRegistrations.valueAt(i);
                 try {
                     advertiser.addService(mPendingRegistrations.keyAt(i),
-                            mPendingRegistrations.valueAt(i).getServiceInfo());
+                            registration.getServiceInfo(), registration.getSubtype());
                 } catch (NameConflictException e) {
                     Log.wtf(TAG, "Name conflict adding services that should have unique names", e);
                 }
@@ -329,10 +331,13 @@
         private int mConflictCount;
         @NonNull
         private NsdServiceInfo mServiceInfo;
+        @Nullable
+        private final String mSubtype;
 
-        private Registration(@NonNull NsdServiceInfo serviceInfo) {
+        private Registration(@NonNull NsdServiceInfo serviceInfo, @Nullable String subtype) {
             this.mOriginalName = serviceInfo.getServiceName();
             this.mServiceInfo = serviceInfo;
+            this.mSubtype = subtype;
         }
 
         /**
@@ -387,6 +392,11 @@
         public NsdServiceInfo getServiceInfo() {
             return mServiceInfo;
         }
+
+        @Nullable
+        public String getSubtype() {
+            return mSubtype;
+        }
     }
 
     /**
@@ -443,8 +453,9 @@
      * Add a service to advertise.
      * @param id A unique ID for the service.
      * @param service The service info to advertise.
+     * @param subtype An optional subtype to advertise the service with.
      */
-    public void addService(int id, NsdServiceInfo service) {
+    public void addService(int id, NsdServiceInfo service, @Nullable String subtype) {
         checkThread();
         if (mRegistrations.get(id) != null) {
             Log.e(TAG, "Adding duplicate registration for " + service);
@@ -453,10 +464,10 @@
             return;
         }
 
-        mSharedLog.i("Adding service " + service + " with ID " + id);
+        mSharedLog.i("Adding service " + service + " with ID " + id + " and subtype " + subtype);
 
         final Network network = service.getNetwork();
-        final Registration registration = new Registration(service);
+        final Registration registration = new Registration(service, subtype);
         final BiPredicate<Network, InterfaceAdvertiserRequest> checkConflictFilter;
         if (network == null) {
             // If registering on all networks, no advertiser must have conflicts
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
index 4e09515..724a704 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
@@ -212,8 +212,9 @@
      *
      * @throws NameConflictException There is already a service being advertised with that name.
      */
-    public void addService(int id, NsdServiceInfo service) throws NameConflictException {
-        final int replacedExitingService = mRecordRepository.addService(id, service);
+    public void addService(int id, NsdServiceInfo service, @Nullable String subtype)
+            throws NameConflictException {
+        final int replacedExitingService = mRecordRepository.addService(id, service, subtype);
         // Cancel announcements for the existing service. This only happens for exiting services
         // (so cancelling exiting announcements), as per RecordRepository.addService.
         if (replacedExitingService >= 0) {
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
index 1329172..f756459 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
@@ -69,6 +69,8 @@
 
     // Top-level domain for link-local queries, as per RFC6762 3.
     private static final String LOCAL_TLD = "local";
+    // Subtype separator as per RFC6763 7.1 (_printer._sub._http._tcp.local)
+    private static final String SUBTYPE_SEPARATOR = "_sub";
 
     // Service type for service enumeration (RFC6763 9.)
     private static final String[] DNS_SD_SERVICE_TYPE =
@@ -156,13 +158,15 @@
         @NonNull
         public final List<RecordInfo<?>> allRecords;
         @NonNull
-        public final RecordInfo<MdnsPointerRecord> ptrRecord;
+        public final List<RecordInfo<MdnsPointerRecord>> ptrRecords;
         @NonNull
         public final RecordInfo<MdnsServiceRecord> srvRecord;
         @NonNull
         public final RecordInfo<MdnsTextRecord> txtRecord;
         @NonNull
         public final NsdServiceInfo serviceInfo;
+        @Nullable
+        public final String subtype;
 
         /**
          * Whether the service is sending exit announcements and will be destroyed soon.
@@ -175,14 +179,16 @@
          * @param deviceHostname Hostname of the device (for the interface used)
          * @param serviceInfo Service to advertise
          */
-        ServiceRegistration(@NonNull String[] deviceHostname, @NonNull NsdServiceInfo serviceInfo) {
+        ServiceRegistration(@NonNull String[] deviceHostname, @NonNull NsdServiceInfo serviceInfo,
+                @Nullable String subtype) {
             this.serviceInfo = serviceInfo;
+            this.subtype = subtype;
 
             final String[] serviceType = splitServiceType(serviceInfo);
             final String[] serviceName = splitFullyQualifiedName(serviceInfo, serviceType);
 
             // Service PTR record
-            ptrRecord = new RecordInfo<>(
+            final RecordInfo<MdnsPointerRecord> ptrRecord = new RecordInfo<>(
                     serviceInfo,
                     new MdnsPointerRecord(
                             serviceType,
@@ -192,6 +198,26 @@
                             serviceName),
                     true /* sharedName */, true /* probing */);
 
+            if (subtype == null) {
+                this.ptrRecords = Collections.singletonList(ptrRecord);
+            } else {
+                final String[] subtypeName = new String[serviceType.length + 2];
+                System.arraycopy(serviceType, 0, subtypeName, 2, serviceType.length);
+                subtypeName[0] = subtype;
+                subtypeName[1] = SUBTYPE_SEPARATOR;
+                final RecordInfo<MdnsPointerRecord> subtypeRecord = new RecordInfo<>(
+                        serviceInfo,
+                        new MdnsPointerRecord(
+                                subtypeName,
+                                0L /* receiptTimeMillis */,
+                                false /* cacheFlush */,
+                                NON_NAME_RECORDS_TTL_MILLIS,
+                                serviceName),
+                        true /* sharedName */, true /* probing */);
+
+                this.ptrRecords = List.of(ptrRecord, subtypeRecord);
+            }
+
             srvRecord = new RecordInfo<>(
                     serviceInfo,
                     new MdnsServiceRecord(serviceName,
@@ -211,8 +237,8 @@
                             attrsToTextEntries(serviceInfo.getAttributes())),
                     false /* sharedName */, true /* probing */);
 
-            final ArrayList<RecordInfo<?>> allRecords = new ArrayList<>(4);
-            allRecords.add(ptrRecord);
+            final ArrayList<RecordInfo<?>> allRecords = new ArrayList<>(5);
+            allRecords.addAll(ptrRecords);
             allRecords.add(srvRecord);
             allRecords.add(txtRecord);
             // Service type enumeration record (RFC6763 9.)
@@ -275,7 +301,8 @@
      *         ID of the replaced service.
      * @throws NameConflictException There is already a (non-exiting) service using the name.
      */
-    public int addService(int serviceId, NsdServiceInfo serviceInfo) throws NameConflictException {
+    public int addService(int serviceId, NsdServiceInfo serviceInfo, @Nullable String subtype)
+            throws NameConflictException {
         if (mServices.contains(serviceId)) {
             throw new IllegalArgumentException(
                     "Service ID must not be reused across registrations: " + serviceId);
@@ -288,7 +315,7 @@
         }
 
         final ServiceRegistration registration = new ServiceRegistration(
-                mDeviceHostname, serviceInfo);
+                mDeviceHostname, serviceInfo, subtype);
         mServices.put(serviceId, registration);
 
         // Remove existing exiting service
@@ -344,24 +371,25 @@
         if (registration == null) return null;
         if (registration.exiting) return null;
 
-        // Send exit (TTL 0) for the PTR record, if the record was sent (in particular don't send
+        // Send exit (TTL 0) for the PTR records, if at least one was sent (in particular don't send
         // if still probing)
-        if (registration.ptrRecord.lastSentTimeMs == 0L) {
+        if (CollectionUtils.all(registration.ptrRecords, r -> r.lastSentTimeMs == 0L)) {
             return null;
         }
 
         registration.exiting = true;
-        final MdnsPointerRecord expiredRecord = new MdnsPointerRecord(
-                registration.ptrRecord.record.getName(),
-                0L /* receiptTimeMillis */,
-                true /* cacheFlush */,
-                0L /* ttlMillis */,
-                registration.ptrRecord.record.getPointer());
+        final List<MdnsRecord> expiredRecords = CollectionUtils.map(registration.ptrRecords,
+                r -> new MdnsPointerRecord(
+                        r.record.getName(),
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        0L /* ttlMillis */,
+                        r.record.getPointer()));
 
         // Exit should be skipped if the record is still advertised by another service, but that
         // would be a conflict (2 service registrations with the same service name), so it would
         // not have been allowed by the repository.
-        return new MdnsAnnouncer.ExitAnnouncementInfo(id, Collections.singletonList(expiredRecord));
+        return new MdnsAnnouncer.ExitAnnouncementInfo(id, expiredRecords);
     }
 
     public void removeService(int id) {
@@ -442,7 +470,7 @@
             for (int i = 0; i < mServices.size(); i++) {
                 final ServiceRegistration registration = mServices.valueAt(i);
                 if (registration.exiting) continue;
-                addReplyFromService(question, registration.allRecords, registration.ptrRecord,
+                addReplyFromService(question, registration.allRecords, registration.ptrRecords,
                         registration.srvRecord, registration.txtRecord, replyUnicast, now,
                         answerInfo, additionalAnswerRecords);
             }
@@ -499,7 +527,7 @@
      */
     private void addReplyFromService(@NonNull MdnsRecord question,
             @NonNull List<RecordInfo<?>> serviceRecords,
-            @Nullable RecordInfo<MdnsPointerRecord> servicePtrRecord,
+            @Nullable List<RecordInfo<MdnsPointerRecord>> servicePtrRecords,
             @Nullable RecordInfo<MdnsServiceRecord> serviceSrvRecord,
             @Nullable RecordInfo<MdnsTextRecord> serviceTxtRecord,
             boolean replyUnicast, long now, @NonNull List<RecordInfo<?>> answerInfo,
@@ -531,7 +559,8 @@
             }
 
             hasKnownAnswer = true;
-            hasDnsSdPtrRecordAnswer |= (info == servicePtrRecord);
+            hasDnsSdPtrRecordAnswer |= (servicePtrRecords != null
+                    && CollectionUtils.any(servicePtrRecords, r -> info == r));
             hasDnsSdSrvRecordAnswer |= (info == serviceSrvRecord);
 
             // TODO: responses to probe queries should bypass this check and only ensure the
@@ -791,10 +820,11 @@
      */
     @Nullable
     public MdnsProber.ProbingInfo renameServiceForConflict(int serviceId, NsdServiceInfo newInfo) {
-        if (!mServices.contains(serviceId)) return null;
+        final ServiceRegistration existing = mServices.get(serviceId);
+        if (existing == null) return null;
 
         final ServiceRegistration newService = new ServiceRegistration(
-                mDeviceHostname, newInfo);
+                mDeviceHostname, newInfo, existing.subtype);
         mServices.put(serviceId, newService);
         return makeProbingInfo(serviceId, newService.srvRecord.record);
     }
diff --git a/tests/unit/java/com/android/server/NsdServiceTest.java b/tests/unit/java/com/android/server/NsdServiceTest.java
index 322b4d2..b3e8cc8 100644
--- a/tests/unit/java/com/android/server/NsdServiceTest.java
+++ b/tests/unit/java/com/android/server/NsdServiceTest.java
@@ -985,10 +985,9 @@
         final RegistrationListener regListener = mock(RegistrationListener.class);
         client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
         waitForIdle();
-        // TODO: also pass the subtype to MdnsAdvertiser
         verify(mAdvertiser).addService(anyInt(), argThat(s ->
                 "Instance".equals(s.getServiceName())
-                        && SERVICE_TYPE.equals(s.getServiceType())));
+                        && SERVICE_TYPE.equals(s.getServiceType())), eq("_subtype"));
 
         final DiscoveryListener discListener = mock(DiscoveryListener.class);
         client.discoverServices(typeWithSubtype, PROTOCOL, network, Runnable::run, discListener);
@@ -1090,7 +1089,7 @@
 
         final ArgumentCaptor<Integer> serviceIdCaptor = ArgumentCaptor.forClass(Integer.class);
         verify(mAdvertiser).addService(serviceIdCaptor.capture(),
-                argThat(info -> matches(info, regInfo)));
+                argThat(info -> matches(info, regInfo)), eq(null) /* subtype */);
 
         client.unregisterService(regListenerWithoutFeature);
         waitForIdle();
@@ -1147,8 +1146,10 @@
         waitForIdle();
 
         // The advertiser is enabled for _type2 but not _type1
-        verify(mAdvertiser, never()).addService(anyInt(), argThat(info -> matches(info, service1)));
-        verify(mAdvertiser).addService(anyInt(), argThat(info -> matches(info, service2)));
+        verify(mAdvertiser, never()).addService(
+                anyInt(), argThat(info -> matches(info, service1)), eq(null) /* subtype */);
+        verify(mAdvertiser).addService(
+                anyInt(), argThat(info -> matches(info, service2)), eq(null) /* subtype */);
     }
 
     @Test
@@ -1173,7 +1174,7 @@
         verify(mSocketProvider).startMonitoringSockets();
         final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
         verify(mAdvertiser).addService(idCaptor.capture(), argThat(info ->
-                matches(info, regInfo)));
+                matches(info, regInfo)), eq(null) /* subtype */);
 
         // Verify onServiceRegistered callback
         final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue();
@@ -1209,7 +1210,7 @@
 
         client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
         waitForIdle();
-        verify(mAdvertiser, never()).addService(anyInt(), any());
+        verify(mAdvertiser, never()).addService(anyInt(), any(), any());
 
         verify(regListener, timeout(TIMEOUT_MS)).onRegistrationFailed(
                 argThat(info -> matches(info, regInfo)), eq(FAILURE_INTERNAL_ERROR));
@@ -1237,7 +1238,8 @@
         final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
         // Service name is truncated to 63 characters
         verify(mAdvertiser).addService(idCaptor.capture(),
-                argThat(info -> info.getServiceName().equals("a".repeat(63))));
+                argThat(info -> info.getServiceName().equals("a".repeat(63))),
+                eq(null) /* subtype */);
 
         // Verify onServiceRegistered callback
         final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue();
@@ -1319,7 +1321,7 @@
         client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
         waitForIdle();
         verify(mSocketProvider).startMonitoringSockets();
-        verify(mAdvertiser).addService(anyInt(), any());
+        verify(mAdvertiser).addService(anyInt(), any(), any());
 
         // Verify the discovery uses MdnsDiscoveryManager
         final DiscoveryListener discListener = mock(DiscoveryListener.class);
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
index 3bb08a6..b539fe0 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
@@ -57,6 +57,7 @@
 private val TEST_NETWORK_1 = mock(Network::class.java)
 private val TEST_NETWORK_2 = mock(Network::class.java)
 private val TEST_HOSTNAME = arrayOf("Android_test", "local")
+private const val TEST_SUBTYPE = "_subtype"
 
 private val SERVICE_1 = NsdServiceInfo("TestServiceName", "_advertisertest._tcp").apply {
     port = 12345
@@ -130,7 +131,7 @@
     @Test
     fun testAddService_OneNetwork() {
         val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog)
-        postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) }
+        postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) }
 
         val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
         verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture())
@@ -161,7 +162,7 @@
     @Test
     fun testAddService_AllNetworks() {
         val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog)
-        postSync { advertiser.addService(SERVICE_ID_1, ALL_NETWORKS_SERVICE) }
+        postSync { advertiser.addService(SERVICE_ID_1, ALL_NETWORKS_SERVICE, TEST_SUBTYPE) }
 
         val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
         verify(socketProvider).requestSocket(eq(ALL_NETWORKS_SERVICE.network),
@@ -179,6 +180,10 @@
         verify(mockDeps).makeAdvertiser(eq(mockSocket2), eq(listOf(TEST_LINKADDR)),
                 eq(thread.looper), any(), intAdvCbCaptor2.capture(), eq(TEST_HOSTNAME), any()
         )
+        verify(mockInterfaceAdvertiser1).addService(
+                anyInt(), eq(ALL_NETWORKS_SERVICE), eq(TEST_SUBTYPE))
+        verify(mockInterfaceAdvertiser2).addService(
+                anyInt(), eq(ALL_NETWORKS_SERVICE), eq(TEST_SUBTYPE))
 
         doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1)
         postSync { intAdvCbCaptor1.value.onRegisterServiceSucceeded(
@@ -207,20 +212,21 @@
     @Test
     fun testAddService_Conflicts() {
         val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog)
-        postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) }
+        postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) }
 
         val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
         verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture())
         val oneNetSocketCb = oneNetSocketCbCaptor.value
 
         // Register a service with the same name on all networks (name conflict)
-        postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE) }
+        postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE, null /* subtype */) }
         val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
         verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture())
         val allNetSocketCb = allNetSocketCbCaptor.value
 
-        postSync { advertiser.addService(LONG_SERVICE_ID_1, LONG_SERVICE_1) }
-        postSync { advertiser.addService(LONG_SERVICE_ID_2, LONG_ALL_NETWORKS_SERVICE) }
+        postSync { advertiser.addService(LONG_SERVICE_ID_1, LONG_SERVICE_1, null /* subtype */) }
+        postSync { advertiser.addService(LONG_SERVICE_ID_2, LONG_ALL_NETWORKS_SERVICE,
+                null /* subtype */) }
 
         // Callbacks for matching network and all networks both get the socket
         postSync {
@@ -248,13 +254,13 @@
                 eq(thread.looper), any(), intAdvCbCaptor.capture(), eq(TEST_HOSTNAME), any()
         )
         verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1),
-                argThat { it.matches(SERVICE_1) })
+                argThat { it.matches(SERVICE_1) }, eq(null))
         verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_2),
-                argThat { it.matches(expectedRenamed) })
+                argThat { it.matches(expectedRenamed) }, eq(null))
         verify(mockInterfaceAdvertiser1).addService(eq(LONG_SERVICE_ID_1),
-                argThat { it.matches(LONG_SERVICE_1) })
+                argThat { it.matches(LONG_SERVICE_1) }, eq(null))
         verify(mockInterfaceAdvertiser1).addService(eq(LONG_SERVICE_ID_2),
-            argThat { it.matches(expectedLongRenamed) })
+            argThat { it.matches(expectedLongRenamed) }, eq(null))
 
         doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1)
         postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
@@ -278,7 +284,7 @@
     fun testRemoveService_whenAllServiceRemoved_thenUpdateHostName() {
         val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog)
         verify(mockDeps, times(1)).generateHostname()
-        postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) }
+        postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) }
         postSync { advertiser.removeService(SERVICE_ID_1) }
         verify(mockDeps, times(2)).generateHostname()
     }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index ee190af..dd458b8 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -117,7 +117,7 @@
             knownServices.add(inv.getArgument(0))
 
             -1
-        }.`when`(repository).addService(anyInt(), any())
+        }.`when`(repository).addService(anyInt(), any(), any())
         doAnswer { inv ->
             knownServices.remove(inv.getArgument(0))
             null
@@ -278,8 +278,8 @@
         doReturn(serviceId).`when`(testProbingInfo).serviceId
         doReturn(testProbingInfo).`when`(repository).setServiceProbing(serviceId)
 
-        advertiser.addService(serviceId, serviceInfo)
-        verify(repository).addService(serviceId, serviceInfo)
+        advertiser.addService(serviceId, serviceInfo, null /* subtype */)
+        verify(repository).addService(serviceId, serviceInfo, null /* subtype */)
         verify(prober).startProbing(testProbingInfo)
 
         // Simulate probing success: continues to announcing
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index 44e0d08..4a39b93 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -44,6 +44,7 @@
 private const val TEST_SERVICE_ID_1 = 42
 private const val TEST_SERVICE_ID_2 = 43
 private const val TEST_PORT = 12345
+private const val TEST_SUBTYPE = "_subtype"
 private val TEST_HOSTNAME = arrayOf("Android_000102030405060708090A0B0C0D0E0F", "local")
 private val TEST_ADDRESSES = listOf(
         LinkAddress(parseNumericAddress("192.0.2.111"), 24),
@@ -86,7 +87,8 @@
     fun testAddServiceAndProbe() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
         assertEquals(0, repository.servicesCount)
-        assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1))
+        assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
+                null /* subtype */))
         assertEquals(1, repository.servicesCount)
 
         val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
@@ -118,18 +120,18 @@
     @Test
     fun testAddAndConflicts() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
-        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
         assertFailsWith(NameConflictException::class) {
-            repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1)
+            repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */)
         }
     }
 
     @Test
     fun testInvalidReuseOfServiceId() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
-        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
         assertFailsWith(IllegalArgumentException::class) {
-            repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2)
+            repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2, null /* subtype */)
         }
     }
 
@@ -138,7 +140,7 @@
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
         assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1))
 
-        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
         assertTrue(repository.hasActiveService(TEST_SERVICE_ID_1))
 
         val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
@@ -180,13 +182,49 @@
     }
 
     @Test
+    fun testExitAnnouncements_WithSubtype() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
+        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, TEST_SUBTYPE)
+        repository.onAdvertisementSent(TEST_SERVICE_ID_1)
+
+        val exitAnnouncement = repository.exitService(TEST_SERVICE_ID_1)
+        assertNotNull(exitAnnouncement)
+        assertEquals(1, repository.servicesCount)
+        val packet = exitAnnouncement.getPacket(0)
+
+        assertEquals(0x8400 /* response, authoritative */, packet.flags)
+        assertEquals(0, packet.questions.size)
+        assertEquals(0, packet.authorityRecords.size)
+        assertEquals(0, packet.additionalRecords.size)
+
+        assertContentEquals(listOf(
+                MdnsPointerRecord(
+                        arrayOf("_testservice", "_tcp", "local"),
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        0L /* ttlMillis */,
+                        arrayOf("MyTestService", "_testservice", "_tcp", "local")),
+                MdnsPointerRecord(
+                        arrayOf("_subtype", "_sub", "_testservice", "_tcp", "local"),
+                        0L /* receiptTimeMillis */,
+                        true /* cacheFlush */,
+                        0L /* ttlMillis */,
+                        arrayOf("MyTestService", "_testservice", "_tcp", "local")),
+        ), packet.answers)
+
+        repository.removeService(TEST_SERVICE_ID_1)
+        assertEquals(0, repository.servicesCount)
+    }
+
+    @Test
     fun testExitingServiceReAdded() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
         repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         repository.onAdvertisementSent(TEST_SERVICE_ID_1)
         repository.exitService(TEST_SERVICE_ID_1)
 
-        assertEquals(TEST_SERVICE_ID_1, repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1))
+        assertEquals(TEST_SERVICE_ID_1,
+                repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */))
         assertEquals(1, repository.servicesCount)
 
         repository.removeService(TEST_SERVICE_ID_2)
@@ -196,7 +234,8 @@
     @Test
     fun testOnProbingSucceeded() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
-        val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
+                TEST_SUBTYPE)
         repository.onAdvertisementSent(TEST_SERVICE_ID_1)
         val packet = announcementInfo.getPacket(0)
 
@@ -205,6 +244,7 @@
         assertEquals(0, packet.authorityRecords.size)
 
         val serviceType = arrayOf("_testservice", "_tcp", "local")
+        val serviceSubtype = arrayOf(TEST_SUBTYPE, "_sub", "_testservice", "_tcp", "local")
         val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
         val v4AddrRev = getReverseDnsAddress(TEST_ADDRESSES[0].address)
         val v6Addr1Rev = getReverseDnsAddress(TEST_ADDRESSES[1].address)
@@ -250,6 +290,13 @@
                         false /* cacheFlush */,
                         4500000L /* ttlMillis */,
                         serviceName),
+                MdnsPointerRecord(
+                        serviceSubtype,
+                        0L /* receiptTimeMillis */,
+                        // Not a unique name owned by the announcer, so cacheFlush=false
+                        false /* cacheFlush */,
+                        4500000L /* ttlMillis */,
+                        serviceName),
                 MdnsServiceRecord(
                         serviceName,
                         0L /* receiptTimeMillis */,
@@ -319,9 +366,21 @@
 
     @Test
     fun testGetReply() {
+        doGetReplyTest(subtype = null)
+    }
+
+    @Test
+    fun testGetReply_WithSubtype() {
+        doGetReplyTest(TEST_SUBTYPE)
+    }
+
+    private fun doGetReplyTest(subtype: String?) {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
-        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
-        val questions = listOf(MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"),
+        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, subtype)
+        val queriedName = if (subtype == null) arrayOf("_testservice", "_tcp", "local")
+        else arrayOf(subtype, "_sub", "_testservice", "_tcp", "local")
+
+        val questions = listOf(MdnsPointerRecord(queriedName,
                 0L /* receiptTimeMillis */,
                 false /* cacheFlush */,
                 // TTL and data is empty for a question
@@ -344,7 +403,7 @@
 
         assertEquals(listOf(
                 MdnsPointerRecord(
-                        arrayOf("_testservice", "_tcp", "local"),
+                        queriedName,
                         0L /* receiptTimeMillis */,
                         false /* cacheFlush */,
                         longTtl,
@@ -405,8 +464,8 @@
     @Test
     fun testGetConflictingServices() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
-        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
-        repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)
+        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
+        repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */)
 
         val packet = MdnsPacket(
                 0 /* flags */,
@@ -433,8 +492,8 @@
     @Test
     fun testGetConflictingServices_IdenticalService() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME)
-        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
-        repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)
+        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
+        repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2, null /* subtype */)
 
         val otherTtlMillis = 1234L
         val packet = MdnsPacket(
@@ -460,10 +519,13 @@
     }
 }
 
-private fun MdnsRecordRepository.initWithService(serviceId: Int, serviceInfo: NsdServiceInfo):
-        AnnouncementInfo {
+private fun MdnsRecordRepository.initWithService(
+    serviceId: Int,
+    serviceInfo: NsdServiceInfo,
+    subtype: String? = null
+): AnnouncementInfo {
     updateAddresses(TEST_ADDRESSES)
-    addService(serviceId, serviceInfo)
+    addService(serviceId, serviceInfo, subtype)
     val probingInfo = setServiceProbing(serviceId)
     assertNotNull(probingInfo)
     return onProbingSucceeded(probingInfo)