Merge "Add support to update the registered service in place" into main am: 283ff70585

Original change: https://android-review.googlesource.com/c/platform/packages/modules/Connectivity/+/2817455

Change-Id: I634ef0dfa46e78fbb048c13baf7ad3f4fe819cb3
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 8cf6db7..2640332 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -91,6 +91,7 @@
 import com.android.net.module.util.SharedLog;
 import com.android.server.connectivity.mdns.ExecutorProvider;
 import com.android.server.connectivity.mdns.MdnsAdvertiser;
+import com.android.server.connectivity.mdns.MdnsAdvertisingOptions;
 import com.android.server.connectivity.mdns.MdnsDiscoveryManager;
 import com.android.server.connectivity.mdns.MdnsFeatureFlags;
 import com.android.server.connectivity.mdns.MdnsInterfaceSocket;
@@ -850,7 +851,9 @@
                             // service type would generate service instance names like
                             // Name._subtype._sub._type._tcp, which is incorrect
                             // (it should be Name._type._tcp).
-                            mAdvertiser.addService(transactionId, serviceInfo, typeSubtype.second);
+                            mAdvertiser.addOrUpdateService(transactionId, serviceInfo,
+                                    typeSubtype.second,
+                                    MdnsAdvertisingOptions.newBuilder().build());
                             storeAdvertiserRequestMap(clientRequestId, transactionId, clientInfo,
                                     serviceInfo.getNetwork());
                         } else {
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
index 28e3924..fc0e11b 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
@@ -43,6 +43,7 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.UUID;
 import java.util.function.BiPredicate;
 import java.util.function.Consumer;
@@ -342,16 +343,16 @@
         }
 
         /**
-         * Add a service.
+         * Add a service to advertise.
          *
          * Conflicts must be checked via {@link #getConflictingService} before attempting to add.
          */
-        void addService(int id, Registration registration) {
+        void addService(int id, @NonNull Registration registration) {
             mPendingRegistrations.put(id, registration);
             for (int i = 0; i < mAdvertisers.size(); i++) {
                 try {
-                    mAdvertisers.valueAt(i).addService(
-                            id, registration.getServiceInfo(), registration.getSubtype());
+                    mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo(),
+                            registration.getSubtype());
                 } catch (NameConflictException e) {
                     mSharedLog.wtf("Name conflict adding services that should have unique names",
                             e);
@@ -359,6 +360,17 @@
             }
         }
 
+        /**
+         * Update an already registered service and push its new subtype to every advertiser.
+         * The caller is expected to check that the update doesn't change the service name.
+         */
+        void updateService(int id, @NonNull Registration registration) {
+            mPendingRegistrations.put(id, registration);
+            for (int i = 0; i < mAdvertisers.size(); i++) {
+                mAdvertisers.valueAt(i).updateService(id, registration.getSubtype());
+            }
+        }
+
         void removeService(int id) {
             mPendingRegistrations.remove(id);
             for (int i = 0; i < mAdvertisers.size(); i++) {
@@ -474,7 +486,8 @@
         @NonNull
         private NsdServiceInfo mServiceInfo;
         @Nullable
-        private final String mSubtype;
+        private String mSubtype;
+
         int mConflictDuringProbingCount;
         int mConflictAfterProbingCount;
 
@@ -485,6 +498,22 @@
         }
 
         /**
+         * True if newInfo is non-null with this registration's original name, type and network.
+         */
+        public boolean matches(@Nullable NsdServiceInfo newInfo) {
+            return newInfo != null && Objects.equals(newInfo.getServiceName(), mOriginalName)
+                    && Objects.equals(newInfo.getServiceType(), mServiceInfo.getServiceType())
+                    && Objects.equals(newInfo.getNetwork(), mServiceInfo.getNetwork());
+        }
+
+        /**
+         * Replace the subtype of this registration; {@code null} means no subtype.
+         */
+        public void updateSubtype(@Nullable String subtype) {
+            this.mSubtype = subtype;
+        }
+
+        /**
          * Update the registration to use a different service name, after a conflict was found.
          *
          * @param newInfo New service info to use.
@@ -632,42 +661,68 @@
     }
 
     /**
-     * Add a service to advertise.
+     * Add a new service to advertise, or update the subtype of an already added one.
+     *
      * @param id A unique ID for the service.
      * @param service The service info to advertise.
      * @param subtype An optional subtype to advertise the service with.
+     * @param advertisingOptions isOnlyUpdate() requests a subtype-only update of service id.
      */
-    public void addService(int id, NsdServiceInfo service, @Nullable String subtype) {
+    public void addOrUpdateService(int id, NsdServiceInfo service, @Nullable String subtype,
+            MdnsAdvertisingOptions advertisingOptions) {
         checkThread();
-        if (mRegistrations.get(id) != null) {
-            mSharedLog.e("Adding duplicate registration for " + service);
-            // TODO (b/264986328): add a more specific error code
-            mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR);
-            return;
-        }
-
-        mSharedLog.i("Adding service " + service + " with ID " + id + " and subtype " + subtype);
-
+        final Registration existingRegistration = mRegistrations.get(id);
         final Network network = service.getNetwork();
-        final Registration registration = new Registration(service, subtype);
-        final BiPredicate<Network, InterfaceAdvertiserRequest> checkConflictFilter;
-        if (network == null) {
-            // If registering on all networks, no advertiser must have conflicts
-            checkConflictFilter = (net, adv) -> true;
-        } else {
-            // If registering on one network, the matching network advertiser and the one for all
-            // networks must not have conflicts
-            checkConflictFilter = (net, adv) -> net == null || network.equals(net);
-        }
+        Registration registration;
+        if (advertisingOptions.isOnlyUpdate()) {
+            if (existingRegistration == null) {
+                mSharedLog.e("Update non existing registration for " + service);
+                mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR);
+                return;
+            }
+            if (!(existingRegistration.matches(service))) {
+                mSharedLog.e("Update request can only update subType, serviceInfo: " + service
+                        + ", existing serviceInfo: " + existingRegistration.getServiceInfo());
+                mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR);
+                return;
 
-        updateRegistrationUntilNoConflict(checkConflictFilter, registration);
+            }
+            mSharedLog.i("Update service " + service + " with ID " + id + " and subtype " + subtype
+                    + " advertisingOptions " + advertisingOptions);
+            registration = existingRegistration; // Reused: only its subtype changes below.
+            registration.updateSubtype(subtype);
+        } else {
+            if (existingRegistration != null) {
+                mSharedLog.e("Adding duplicate registration for " + service);
+                // TODO (b/264986328): add a more specific error code
+                mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR);
+                return;
+            }
+            mSharedLog.i("Adding service " + service + " with ID " + id + " and subtype " + subtype
+                    + " advertisingOptions " + advertisingOptions);
+            registration = new Registration(service, subtype);
+            final BiPredicate<Network, InterfaceAdvertiserRequest> checkConflictFilter;
+            if (network == null) {
+                // If registering on all networks, no advertiser must have conflicts
+                checkConflictFilter = (net, adv) -> true;
+            } else {
+                // If registering on one network, the matching network advertiser and the one
+                // for all networks must not have conflicts
+                checkConflictFilter = (net, adv) -> net == null || network.equals(net);
+            }
+            updateRegistrationUntilNoConflict(checkConflictFilter, registration);
+        }
 
         InterfaceAdvertiserRequest advertiser = mAdvertiserRequests.get(network);
         if (advertiser == null) {
             advertiser = new InterfaceAdvertiserRequest(network);
             mAdvertiserRequests.put(network, advertiser);
         }
-        advertiser.addService(id, registration);
+        if (advertisingOptions.isOnlyUpdate()) {
+            advertiser.updateService(id, registration);
+        } else {
+            advertiser.addService(id, registration);
+        }
         mRegistrations.put(id, registration);
     }
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertisingOptions.java b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertisingOptions.java
new file mode 100644
index 0000000..e7a6ca7
--- /dev/null
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertisingOptions.java
@@ -0,0 +1,92 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+/**
+ * Configuration parameters for a request to advertise an mDNS service.
+ *
+ * <p>Use {@link MdnsAdvertisingOptions.Builder} to create {@link MdnsAdvertisingOptions}.
+ *
+ * @hide
+ */
+public class MdnsAdvertisingOptions {
+
+    private static MdnsAdvertisingOptions sDefaultOptions; // Guarded by the class lock.
+    private final boolean mIsOnlyUpdate;
+
+    /**
+     * Creates options with the given update flag; prefer {@link #newBuilder()}.
+     */
+    MdnsAdvertisingOptions(
+            boolean isOnlyUpdate) {
+        this.mIsOnlyUpdate = isOnlyUpdate;
+    }
+
+    /**
+     * Returns a {@link Builder} for {@link MdnsAdvertisingOptions}.
+     */
+    public static Builder newBuilder() {
+        return new Builder();
+    }
+
+    /**
+     * Returns the shared default options (isOnlyUpdate() false), created on first use.
+     */
+    public static synchronized MdnsAdvertisingOptions getDefaultOptions() {
+        if (sDefaultOptions == null) {
+            sDefaultOptions = newBuilder().build();
+        }
+        return sDefaultOptions;
+    }
+
+    /**
+     * @return {@code true} if the request only updates an already registered service.
+     */
+    public boolean isOnlyUpdate() {
+        return mIsOnlyUpdate;
+    }
+
+    @Override
+    public String toString() {
+        return "MdnsAdvertisingOptions{" + "mIsOnlyUpdate=" + mIsOnlyUpdate + '}';
+    }
+
+    /**
+     * A builder to create {@link MdnsAdvertisingOptions}.
+     */
+    public static final class Builder {
+        private boolean mIsOnlyUpdate = false;
+
+        private Builder() {
+        }
+
+        /**
+         * Sets whether the request only updates an already registered service.
+         */
+        public Builder setIsOnlyUpdate(boolean isOnlyUpdate) {
+            this.mIsOnlyUpdate = isOnlyUpdate;
+            return this;
+        }
+
+        /**
+         * Builds a {@link MdnsAdvertisingOptions} with the arguments supplied to this builder.
+         */
+        public MdnsAdvertisingOptions build() {
+            return new MdnsAdvertisingOptions(mIsOnlyUpdate);
+        }
+    }
+}
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
index 62c37ad..463df63 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
@@ -229,6 +229,18 @@
     }
 
     /**
+     * Update a registered service's subtype without sending exit/re-announcement packets.
+     *
+     * @param id An existing service id
+     * @param subtype The new subtype, or {@code null} for none
+     */
+    public void updateService(int id, @Nullable String subtype) {
+        // Only the record repository is updated; nothing is re-announced. This is intended for
+        // cases where subtypes don't get announced.
+        mRecordRepository.updateService(id, subtype);
+    }
+
+    /**
      * Start advertising a service.
      *
      * @throws NameConflictException There is already a service being advertised with that name.
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
index e34778f..48ece68 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
@@ -167,7 +167,7 @@
         /**
          * Whether the service is sending exit announcements and will be destroyed soon.
          */
-        public boolean exiting = false;
+        public boolean exiting;
 
         /**
          * The replied query packet count of this service.
@@ -185,13 +185,20 @@
         private boolean isProbing;
 
         /**
+         * Create a copy of this ServiceRegistration with a new subtype and all other state kept.
+         */
+        ServiceRegistration withSubtype(@Nullable String newSubType) {
+            return new ServiceRegistration(srvRecord.record.getServiceHost(), serviceInfo,
+                    newSubType, repliedServiceCount, sentPacketCount, exiting, isProbing);
+        }
+
+
+        /**
          * Create a ServiceRegistration for dns-sd service registration (RFC6763).
-         *
-         * @param deviceHostname Hostname of the device (for the interface used)
-         * @param serviceInfo Service to advertise
          */
         ServiceRegistration(@NonNull String[] deviceHostname, @NonNull NsdServiceInfo serviceInfo,
-                @Nullable String subtype, int repliedServiceCount, int sentPacketCount) {
+                @Nullable String subtype, int repliedServiceCount, int sentPacketCount,
+                boolean exiting, boolean isProbing) {
             this.serviceInfo = serviceInfo;
             this.subtype = subtype;
 
@@ -266,7 +273,20 @@
             this.allRecords = Collections.unmodifiableList(allRecords);
             this.repliedServiceCount = repliedServiceCount;
             this.sentPacketCount = sentPacketCount;
-            this.isProbing = true;
+            this.isProbing = isProbing;
+            this.exiting = exiting;
+        }
+
+        /**
+         * Create a new dns-sd ServiceRegistration (RFC6763) that is probing and not exiting.
+         *
+         * @param deviceHostname Hostname of the device (for the interface used)
+         * @param serviceInfo Service to advertise
+         */
+        ServiceRegistration(@NonNull String[] deviceHostname, @NonNull NsdServiceInfo serviceInfo,
+                @Nullable String subtype, int repliedServiceCount, int sentPacketCount) {
+            this(deviceHostname, serviceInfo, subtype, repliedServiceCount, sentPacketCount,
+                    false /* exiting */, true /* isProbing */);
         }
 
         void setProbing(boolean probing) {
@@ -305,6 +325,24 @@
     }
 
     /**
+     * Replace the subtype of a service already registered in the repository.
+     *
+     * @param serviceId An existing service ID.
+     * @param subtype The new subtype, or {@code null} for none.
+     * @throws IllegalArgumentException if {@code serviceId} is not registered.
+     */
+    public void updateService(int serviceId, @Nullable String subtype) {
+        final ServiceRegistration existingRegistration = mServices.get(serviceId);
+        if (existingRegistration == null) {
+            throw new IllegalArgumentException(
+                    "Service ID must already exist for an update request: " + serviceId);
+        }
+        final ServiceRegistration updatedRegistration = existingRegistration.withSubtype(
+                subtype);
+        mServices.put(serviceId, updatedRegistration);
+    }
+
+    /**
      * Add a service to the repository.
      *
      * This may remove/replace any existing service that used the name added but is exiting.
diff --git a/tests/unit/java/com/android/server/NsdServiceTest.java b/tests/unit/java/com/android/server/NsdServiceTest.java
index ad87d28..32014c2 100644
--- a/tests/unit/java/com/android/server/NsdServiceTest.java
+++ b/tests/unit/java/com/android/server/NsdServiceTest.java
@@ -1115,9 +1115,9 @@
         final RegistrationListener regListener = mock(RegistrationListener.class);
         client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
         waitForIdle();
-        verify(mAdvertiser).addService(anyInt(), argThat(s ->
+        verify(mAdvertiser).addOrUpdateService(anyInt(), argThat(s ->
                 "Instance".equals(s.getServiceName())
-                        && SERVICE_TYPE.equals(s.getServiceType())), eq("_subtype"));
+                        && SERVICE_TYPE.equals(s.getServiceType())), eq("_subtype"), any());
 
         final DiscoveryListener discListener = mock(DiscoveryListener.class);
         client.discoverServices(typeWithSubtype, PROTOCOL, network, Runnable::run, discListener);
@@ -1222,8 +1222,8 @@
         waitForIdle();
 
         final ArgumentCaptor<Integer> serviceIdCaptor = ArgumentCaptor.forClass(Integer.class);
-        verify(mAdvertiser).addService(serviceIdCaptor.capture(),
-                argThat(info -> matches(info, regInfo)), eq(null) /* subtype */);
+        verify(mAdvertiser).addOrUpdateService(serviceIdCaptor.capture(),
+                argThat(info -> matches(info, regInfo)), eq(null) /* subtype */, any());
 
         client.unregisterService(regListenerWithoutFeature);
         waitForIdle();
@@ -1282,10 +1282,10 @@
         waitForIdle();
 
         // The advertiser is enabled for _type2 but not _type1
-        verify(mAdvertiser, never()).addService(
-                anyInt(), argThat(info -> matches(info, service1)), eq(null) /* subtype */);
-        verify(mAdvertiser).addService(
-                anyInt(), argThat(info -> matches(info, service2)), eq(null) /* subtype */);
+        verify(mAdvertiser, never()).addOrUpdateService(anyInt(),
+                argThat(info -> matches(info, service1)), eq(null) /* subtype */, any());
+        verify(mAdvertiser).addOrUpdateService(anyInt(), argThat(info -> matches(info, service2)),
+                eq(null) /* subtype */, any());
     }
 
     @Test
@@ -1309,8 +1309,8 @@
         waitForIdle();
         verify(mSocketProvider).startMonitoringSockets();
         final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
-        verify(mAdvertiser).addService(idCaptor.capture(), argThat(info ->
-                matches(info, regInfo)), eq(null) /* subtype */);
+        verify(mAdvertiser).addOrUpdateService(idCaptor.capture(), argThat(info ->
+                matches(info, regInfo)), eq(null) /* subtype */, any());
 
         // Verify onServiceRegistered callback
         final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue();
@@ -1358,7 +1358,7 @@
 
         client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
         waitForIdle();
-        verify(mAdvertiser, never()).addService(anyInt(), any(), any());
+        verify(mAdvertiser, never()).addOrUpdateService(anyInt(), any(), any(), any());
 
         verify(regListener, timeout(TIMEOUT_MS)).onRegistrationFailed(
                 argThat(info -> matches(info, regInfo)), eq(FAILURE_INTERNAL_ERROR));
@@ -1387,9 +1387,9 @@
         waitForIdle();
         final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
         // Service name is truncated to 63 characters
-        verify(mAdvertiser).addService(idCaptor.capture(),
+        verify(mAdvertiser).addOrUpdateService(idCaptor.capture(),
                 argThat(info -> info.getServiceName().equals("a".repeat(63))),
-                eq(null) /* subtype */);
+                eq(null) /* subtype */, any());
 
         // Verify onServiceRegistered callback
         final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue();
@@ -1479,7 +1479,7 @@
         client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
         waitForIdle();
         verify(mSocketProvider).startMonitoringSockets();
-        verify(mAdvertiser).addService(anyInt(), any(), any());
+        verify(mAdvertiser).addOrUpdateService(anyInt(), any(), any(), any());
 
         // Verify the discovery uses MdnsDiscoveryManager
         final DiscoveryListener discListener = mock(DiscoveryListener.class);
@@ -1512,7 +1512,7 @@
         client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
         waitForIdle();
         verify(mSocketProvider).startMonitoringSockets();
-        verify(mAdvertiser).addService(anyInt(), any(), any());
+        verify(mAdvertiser).addOrUpdateService(anyInt(), any(), any(), any());
 
         final Network wifiNetwork1 = new Network(123);
         final Network wifiNetwork2 = new Network(124);
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
index a86f923..f0cb6df 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
@@ -19,6 +19,7 @@
 import android.net.InetAddresses.parseNumericAddress
 import android.net.LinkAddress
 import android.net.Network
+import android.net.nsd.NsdManager
 import android.net.nsd.NsdServiceInfo
 import android.net.nsd.OffloadEngine
 import android.net.nsd.OffloadServiceInfo
@@ -71,6 +72,7 @@
 private val TEST_INTERFACE2 = "test_iface2"
 private val TEST_OFFLOAD_PACKET1 = byteArrayOf(0x01, 0x02, 0x03)
 private val TEST_OFFLOAD_PACKET2 = byteArrayOf(0x02, 0x03, 0x04)
+private val DEFAULT_ADVERTISING_OPTION = MdnsAdvertisingOptions.getDefaultOptions()
 
 private val SERVICE_1 = NsdServiceInfo("TestServiceName", "_advertisertest._tcp").apply {
     port = 12345
@@ -186,7 +188,8 @@
     fun testAddService_OneNetwork() {
         val advertiser =
             MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
-        postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) }
+        postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1,
+                null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
 
         val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
         verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture())
@@ -247,7 +250,8 @@
     fun testAddService_AllNetworks() {
         val advertiser =
             MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
-        postSync { advertiser.addService(SERVICE_ID_1, ALL_NETWORKS_SERVICE, TEST_SUBTYPE) }
+        postSync { advertiser.addOrUpdateService(SERVICE_ID_1, ALL_NETWORKS_SERVICE,
+                TEST_SUBTYPE, DEFAULT_ADVERTISING_OPTION) }
 
         val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
         verify(socketProvider).requestSocket(eq(ALL_NETWORKS_SERVICE.network),
@@ -318,24 +322,27 @@
     fun testAddService_Conflicts() {
         val advertiser =
             MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
-        postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) }
+        postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1,
+                null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
 
         val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
         verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture())
         val oneNetSocketCb = oneNetSocketCbCaptor.value
 
         // Register a service with the same name on all networks (name conflict)
-        postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE, null /* subtype */) }
+        postSync { advertiser.addOrUpdateService(SERVICE_ID_2, ALL_NETWORKS_SERVICE,
+                null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
         val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
         verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture())
         val allNetSocketCb = allNetSocketCbCaptor.value
 
-        postSync { advertiser.addService(LONG_SERVICE_ID_1, LONG_SERVICE_1, null /* subtype */) }
-        postSync { advertiser.addService(LONG_SERVICE_ID_2, LONG_ALL_NETWORKS_SERVICE,
-                null /* subtype */) }
+        postSync { advertiser.addOrUpdateService(LONG_SERVICE_ID_1, LONG_SERVICE_1,
+                null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
+        postSync { advertiser.addOrUpdateService(LONG_SERVICE_ID_2, LONG_ALL_NETWORKS_SERVICE,
+                null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
 
-        postSync { advertiser.addService(CASE_INSENSITIVE_TEST_SERVICE_ID, ALL_NETWORKS_SERVICE_2,
-                null /* subtype */) }
+        postSync { advertiser.addOrUpdateService(CASE_INSENSITIVE_TEST_SERVICE_ID,
+                ALL_NETWORKS_SERVICE_2, null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
 
         // Callbacks for matching network and all networks both get the socket
         postSync {
@@ -400,11 +407,51 @@
     }
 
     @Test
+    fun testAddOrUpdateService_Updates() {
+        val advertiser =
+                MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
+        postSync { advertiser.addOrUpdateService(SERVICE_ID_1, ALL_NETWORKS_SERVICE,
+                null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
+
+        val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
+        verify(socketProvider).requestSocket(eq(null), socketCbCaptor.capture())
+
+        val socketCb = socketCbCaptor.value
+        postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR)) }
+        // The interface advertiser created for socket 1 gets the service with no subtype.
+        verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1),
+                argThat { it.matches(ALL_NETWORKS_SERVICE) }, eq(null))
+
+        val updateOptions = MdnsAdvertisingOptions.newBuilder().setIsOnlyUpdate(true).build()
+
+        // Updating a service ID that is not registered yet should fail
+        postSync { advertiser.addOrUpdateService(SERVICE_ID_2, ALL_NETWORKS_SERVICE, TEST_SUBTYPE,
+                updateOptions) }
+        verify(cb).onRegisterServiceFailed(SERVICE_ID_2, NsdManager.FAILURE_INTERNAL_ERROR)
+
+        // Updating with a different NsdServiceInfo (different network) should fail
+        postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1, TEST_SUBTYPE,
+                updateOptions) }
+        verify(cb).onRegisterServiceFailed(SERVICE_ID_1, NsdManager.FAILURE_INTERNAL_ERROR)
+
+        // Updating only the subtype of the same NsdServiceInfo should succeed
+        postSync { advertiser.addOrUpdateService(SERVICE_ID_1, ALL_NETWORKS_SERVICE, TEST_SUBTYPE,
+                updateOptions) }
+        verify(mockInterfaceAdvertiser1).updateService(eq(SERVICE_ID_1), eq(TEST_SUBTYPE))
+
+        // Advertisers created after the update get addService() with the new subtype.
+        postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_2, mockSocket2, listOf(TEST_LINKADDR2)) }
+        verify(mockInterfaceAdvertiser2).addService(eq(SERVICE_ID_1),
+                argThat { it.matches(ALL_NETWORKS_SERVICE) }, eq(TEST_SUBTYPE))
+    }
+
+    @Test
     fun testRemoveService_whenAllServiceRemoved_thenUpdateHostName() {
         val advertiser =
             MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
         verify(mockDeps, times(1)).generateHostname()
-        postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) }
+        postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1,
+                null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
         postSync { advertiser.removeService(SERVICE_ID_1) }
         verify(mockDeps, times(2)).generateHostname()
     }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index db41a6a..f85d71d 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -48,6 +48,7 @@
 import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.eq
 import org.mockito.Mockito.mock
+import org.mockito.Mockito.never
 import org.mockito.Mockito.times
 import org.mockito.Mockito.verify
 
@@ -59,6 +60,7 @@
 private val TEST_HOSTNAME = arrayOf("Android_test", "local")
 
 private const val TEST_SERVICE_ID_1 = 42
+private const val TEST_SERVICE_ID_DUPLICATE = 43
 private val TEST_SERVICE_1 = NsdServiceInfo().apply {
     serviceType = "_testservice._tcp"
     serviceName = "MyTestService"
@@ -272,6 +274,28 @@
         verify(prober).restartForConflict(mockProbingInfo)
     }
 
+    @Test
+    fun testReplaceExitingService() {
+        doReturn(TEST_SERVICE_ID_DUPLICATE).`when`(repository)
+                .addService(eq(TEST_SERVICE_ID_DUPLICATE), any(), any())
+        val subType = "_sub"
+        advertiser.addService(TEST_SERVICE_ID_DUPLICATE, TEST_SERVICE_1, subType)
+        verify(repository).addService(eq(TEST_SERVICE_ID_DUPLICATE), any(), any())
+        verify(announcer).stop(TEST_SERVICE_ID_DUPLICATE)
+        verify(prober).startProbing(any()) // A full re-add probes again, unlike updateService()
+    }
+
+    @Test
+    fun testUpdateExistingService() {
+        // repository.updateService() returns nothing, so the mocked repository needs no
+        // stubbing; the update must be forwarded without stopping announcements or re-probing.
+        val subType = "_sub"
+        advertiser.updateService(TEST_SERVICE_ID_DUPLICATE, subType)
+        verify(repository).updateService(eq(TEST_SERVICE_ID_DUPLICATE), eq(subType))
+        verify(announcer, never()).stop(TEST_SERVICE_ID_DUPLICATE)
+        verify(prober, never()).startProbing(any())
+    }
+
     private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
             AnnouncementInfo {
         val testProbingInfo = mock(ProbingInfo::class.java)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index f26f7e1..582e7b1 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -129,7 +129,7 @@
     @Test
     fun testAddAndConflicts() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
-        repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
+        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
         assertFailsWith(NameConflictException::class) {
             repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */)
         }
@@ -139,6 +139,45 @@
     }
 
     @Test
+    fun testAddAndUpdates() {
+        val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+        repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+        // Updating a service ID that was never added must throw.
+        assertFailsWith(IllegalArgumentException::class) {
+            repository.updateService(TEST_SERVICE_ID_2, null /* subtype */)
+        }
+        // Give the registered service a subtype in place.
+        repository.updateService(TEST_SERVICE_ID_1, TEST_SUBTYPE)
+        // A PTR query for the subtype name must now be answered with the service.
+        val queriedName = arrayOf(TEST_SUBTYPE, "_sub", "_testservice", "_tcp", "local")
+        val questions = listOf(MdnsPointerRecord(queriedName,
+                0L /* receiptTimeMillis */,
+                false /* cacheFlush */,
+                // TTL and data are empty for a question
+                0L /* ttlMillis */,
+                null /* pointer */))
+        val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
+                listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+        val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
+        val reply = repository.getReply(query, src)
+
+        assertNotNull(reply)
+
+        // TTLs as per RFC6762 10.
+        val longTtl = 4_500_000L
+        val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
+
+        assertEquals(listOf(
+                MdnsPointerRecord(
+                        queriedName,
+                        0L /* receiptTimeMillis */,
+                        false /* cacheFlush */,
+                        longTtl,
+                        serviceName),
+        ), reply.answers)
+    }
+
+    @Test
     fun testInvalidReuseOfServiceId() {
         val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
         repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */)
@@ -758,7 +797,7 @@
 private fun MdnsRecordRepository.initWithService(
     serviceId: Int,
     serviceInfo: NsdServiceInfo,
-    subtype: String? = null
+    subtype: String? = null,
 ): AnnouncementInfo {
     updateAddresses(TEST_ADDRESSES)
     addService(serviceId, serviceInfo, subtype)