Implement onServiceConflict

Implement the onServiceConflict callback in MdnsAdvertiser, refactoring
the conflict detection to reuse it both in onServiceConflict (when a
conflict is detected on the network after add) and at service add time.

Bug: 241738458
Test: atest MdnsAdvertiserTest
Change-Id: I69128db936296bd2c5e90e9f00df19fd881e1748
diff --git a/service-t/src/com/android/server/mdns/MdnsAdvertiser.java b/service-t/src/com/android/server/mdns/MdnsAdvertiser.java
index 4e40efe..977478a 100644
--- a/service-t/src/com/android/server/mdns/MdnsAdvertiser.java
+++ b/service-t/src/com/android/server/mdns/MdnsAdvertiser.java
@@ -29,10 +29,10 @@
 
 import com.android.internal.annotations.VisibleForTesting;
 
-import java.io.IOException;
 import java.util.List;
 import java.util.Map;
-import java.util.function.Predicate;
+import java.util.function.BiPredicate;
+import java.util.function.Consumer;
 
 /**
  * MdnsAdvertiser manages advertising services per {@link com.android.server.NsdService} requests.
@@ -85,7 +85,7 @@
         public void onRegisterServiceSucceeded(
                 @NonNull MdnsInterfaceAdvertiser advertiser, int serviceId) {
             // Wait for all current interfaces to be done probing before notifying of success.
-            if (anyAdvertiser(a -> a.isProbing(serviceId))) return;
+            if (any(mAllAdvertisers, (k, a) -> a.isProbing(serviceId))) return;
             // The service may still be unregistered/renamed if a conflict is found on a later added
             // interface, or if a conflicting announcement/reply is detected (RFC6762 9.)
 
@@ -102,7 +102,37 @@
 
         @Override
         public void onServiceConflict(@NonNull MdnsInterfaceAdvertiser advertiser, int serviceId) {
-            // TODO: handle conflicts found after registration (during or after probing)
+            if (DBG) {
+                Log.v(TAG, "Found conflict, restarted probing for service " + serviceId);
+            }
+
+            final Registration registration = mRegistrations.get(serviceId);
+            if (registration == null) return;
+            if (registration.mNotifiedRegistrationSuccess) {
+                // TODO: consider notifying clients that the service is no longer registered with
+                // the old name (back to probing). The legacy implementation did not send any
+                // callback though; it only sent onServiceRegistered after re-probing finishes
+                // (with the old, conflicting, actually not used name as argument... The new
+                // implementation will send callbacks with the new name).
+                registration.mNotifiedRegistrationSuccess = false;
+
+                // The service was done probing, just reset it to probing state (RFC6762 9.)
+                forAllAdvertisers(a -> a.restartProbingForConflict(serviceId));
+                return;
+            }
+
+            // Conflict was found during probing; rename once to find a name that has no conflict
+            registration.updateForConflict(
+                    registration.makeNewServiceInfoForConflict(1 /* renameCount */),
+                    1 /* renameCount */);
+
+            // Keep renaming if the new name conflicts in local registrations
+            updateRegistrationUntilNoConflict((net, adv) -> adv.hasRegistration(registration),
+                    registration);
+
+            // Update advertisers to use the new name
+            forAllAdvertisers(a -> a.renameServiceForConflict(
+                    serviceId, registration.getServiceInfo()));
         }
 
         @Override
@@ -116,6 +146,25 @@
         }
     };
 
+    private boolean hasAnyConflict(
+            @NonNull BiPredicate<Network, InterfaceAdvertiserRequest> applicableAdvertiserFilter,
+            @NonNull NsdServiceInfo newInfo) {
+        return any(mAdvertiserRequests, (network, adv) ->
+                applicableAdvertiserFilter.test(network, adv) && adv.hasConflict(newInfo));
+    }
+
+    private void updateRegistrationUntilNoConflict(
+            @NonNull BiPredicate<Network, InterfaceAdvertiserRequest> applicableAdvertiserFilter,
+            @NonNull Registration registration) {
+        int renameCount = 0;
+        NsdServiceInfo newInfo = registration.getServiceInfo();
+        while (hasAnyConflict(applicableAdvertiserFilter, newInfo)) {
+            renameCount++;
+            newInfo = registration.makeNewServiceInfoForConflict(renameCount);
+        }
+        registration.updateForConflict(newInfo, renameCount);
+    }
+
     /**
      * A request for a {@link MdnsInterfaceAdvertiser}.
      *
@@ -153,6 +202,21 @@
         }
 
         /**
+         * Return whether this {@link InterfaceAdvertiserRequest} has the given registration.
+         */
+        boolean hasRegistration(@NonNull Registration registration) {
+            return mPendingRegistrations.indexOfValue(registration) >= 0;
+        }
+
+        /**
+         * Return whether using the proposed new {@link NsdServiceInfo} to add a registration would
+         * cause a conflict in this {@link InterfaceAdvertiserRequest}.
+         */
+        boolean hasConflict(@NonNull NsdServiceInfo newInfo) {
+            return getConflictingService(newInfo) >= 0;
+        }
+
+        /**
          * Get the ID of a conflicting service, or -1 if none.
          */
         int getConflictingService(@NonNull NsdServiceInfo info) {
@@ -166,16 +230,19 @@
             return -1;
         }
 
-        void addService(int id, Registration registration)
-                throws NameConflictException {
-            final int conflicting = getConflictingService(registration.getServiceInfo());
-            if (conflicting >= 0) {
-                throw new NameConflictException(conflicting);
-            }
-
+        /**
+         * Add a service.
+         *
+         * Conflicts must be checked via {@link #getConflictingService} before attempting to add.
+         */
+        void addService(int id, Registration registration) {
             mPendingRegistrations.put(id, registration);
             for (int i = 0; i < mAdvertisers.size(); i++) {
-                mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo());
+                try {
+                    mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo());
+                } catch (NameConflictException e) {
+                    Log.wtf(TAG, "Name conflict adding services that should have unique names", e);
+                }
             }
         }
 
@@ -239,32 +306,42 @@
         /**
          * Update the registration to use a different service name, after a conflict was found.
          *
+         * @param newInfo New service info to use.
+         * @param renameCount How many renames were done before reaching the current name.
+         */
+        private void updateForConflict(@NonNull NsdServiceInfo newInfo, int renameCount) {
+            mConflictCount += renameCount;
+            mServiceInfo = newInfo;
+        }
+
+        /**
+         * Make a new service name for the registration, after a conflict was found.
+         *
          * If a name conflict was found during probing or because different advertising requests
          * used the same name, the registration is attempted again with a new name (here using
          * a number suffix, (1), (2) etc). Registration success is notified once probing succeeds
          * with a new name. This matches legacy behavior based on mdnsresponder, and appendix D of
          * RFC6763.
-         * @return The new service info with the updated name.
+         *
+         * @param renameCount How much to increase the number suffix for this conflict.
          */
         @NonNull
-        private NsdServiceInfo updateForConflict() {
-            mConflictCount++;
+        public NsdServiceInfo makeNewServiceInfoForConflict(int renameCount) {
             // In case of conflict choose a different service name. After the first conflict use
             // "Name (2)", then "Name (3)" etc.
             // TODO: use a hidden method in NsdServiceInfo once MdnsAdvertiser is moved to service-t
             final NsdServiceInfo newInfo = new NsdServiceInfo();
-            newInfo.setServiceName(mOriginalName + " (" + (mConflictCount + 1) + ")");
+            newInfo.setServiceName(mOriginalName + " (" + (mConflictCount + renameCount + 1) + ")");
             newInfo.setServiceType(mServiceInfo.getServiceType());
             for (Map.Entry<String, byte[]> attr : mServiceInfo.getAttributes().entrySet()) {
-                newInfo.setAttribute(attr.getKey(), attr.getValue());
+                newInfo.setAttribute(attr.getKey(),
+                        attr.getValue() == null ? null : new String(attr.getValue()));
             }
             newInfo.setHost(mServiceInfo.getHost());
             newInfo.setPort(mServiceInfo.getPort());
             newInfo.setNetwork(mServiceInfo.getNetwork());
             // interfaceIndex is not set when registering
-
-            mServiceInfo = newInfo;
-            return mServiceInfo;
+            return newInfo;
         }
 
         @NonNull
@@ -338,55 +415,27 @@
             Log.i(TAG, "Adding service " + service + " with ID " + id);
         }
 
-        try {
-            final Registration registration = new Registration(service);
-            while (!tryAddRegistration(id, registration)) {
-                registration.updateForConflict();
-            }
-
-            mRegistrations.put(id, registration);
-        } catch (IOException e) {
-            Log.e(TAG, "Error adding service " + service, e);
-            removeService(id);
-            // TODO (b/264986328): add a more specific error code
-            mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR);
-        }
-    }
-
-    private boolean tryAddRegistration(int id, @NonNull Registration registration)
-            throws IOException {
-        final NsdServiceInfo serviceInfo = registration.getServiceInfo();
-        final Network network = serviceInfo.getNetwork();
-        try {
-            InterfaceAdvertiserRequest advertiser = mAdvertiserRequests.get(network);
-            if (advertiser == null) {
-                advertiser = new InterfaceAdvertiserRequest(network);
-                mAdvertiserRequests.put(network, advertiser);
-            }
-            advertiser.addService(id, registration);
-        } catch (NameConflictException e) {
-            if (DBG) {
-                Log.i(TAG, "Service name conflicts: " + serviceInfo.getServiceName());
-            }
-            removeService(id);
-            return false;
+        final Network network = service.getNetwork();
+        final Registration registration = new Registration(service);
+        final BiPredicate<Network, InterfaceAdvertiserRequest> checkConflictFilter;
+        if (network == null) {
+            // If registering on all networks, no advertiser must have conflicts
+            checkConflictFilter = (net, adv) -> true;
+        } else {
+            // If registering on one network, the matching network advertiser and the one for all
+            // networks must not have conflicts
+            checkConflictFilter = (net, adv) -> net == null || network.equals(net);
         }
 
-        // When adding a service to a specific network, check that it does not conflict with other
-        // registrations advertising on all networks
-        final InterfaceAdvertiserRequest allNetworksAdvertiser = mAdvertiserRequests.get(null);
-        if (network != null && allNetworksAdvertiser != null
-                && allNetworksAdvertiser.getConflictingService(serviceInfo) >= 0) {
-            if (DBG) {
-                Log.i(TAG, "Service conflicts with advertisement on all networks: "
-                        + serviceInfo.getServiceName());
-            }
-            removeService(id);
-            return false;
-        }
+        updateRegistrationUntilNoConflict(checkConflictFilter, registration);
 
+        InterfaceAdvertiserRequest advertiser = mAdvertiserRequests.get(network);
+        if (advertiser == null) {
+            advertiser = new InterfaceAdvertiserRequest(network);
+            mAdvertiserRequests.put(network, advertiser);
+        }
+        advertiser.addService(id, registration);
         mRegistrations.put(id, registration);
-        return true;
     }
 
     /**
@@ -406,12 +455,20 @@
         mRegistrations.remove(id);
     }
 
-    private boolean anyAdvertiser(@NonNull Predicate<MdnsInterfaceAdvertiser> predicate) {
-        for (int i = 0; i < mAllAdvertisers.size(); i++) {
-            if (predicate.test(mAllAdvertisers.valueAt(i))) {
+    private static <K, V> boolean any(@NonNull ArrayMap<K, V> map,
+            @NonNull BiPredicate<K, V> predicate) {
+        for (int i = 0; i < map.size(); i++) {
+            if (predicate.test(map.keyAt(i), map.valueAt(i))) {
                 return true;
             }
         }
         return false;
     }
+
+    private void forAllAdvertisers(@NonNull Consumer<MdnsInterfaceAdvertiser> consumer) {
+        any(mAllAdvertisers, (socket, advertiser) -> {
+            consumer.accept(advertiser);
+            return false;
+        });
+    }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
index e2babb1..1febe6d 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
@@ -38,6 +38,7 @@
 import org.mockito.Mockito.any
 import org.mockito.Mockito.anyInt
 import org.mockito.Mockito.argThat
+import org.mockito.Mockito.atLeastOnce
 import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.mock
 import org.mockito.Mockito.never
@@ -161,6 +162,60 @@
         verify(socketProvider).unrequestSocket(socketCb)
     }
 
+    @Test
+    fun testAddService_Conflicts() {
+        val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps)
+        postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) }
+
+        val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
+        verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture())
+        val oneNetSocketCb = oneNetSocketCbCaptor.value
+
+        // Register a service with the same name on all networks (name conflict)
+        postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE) }
+        val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
+        verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture())
+        val allNetSocketCb = allNetSocketCbCaptor.value
+
+        // Callbacks for matching network and all networks both get the socket
+        postSync {
+            oneNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
+            allNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
+        }
+
+        val expectedRenamed = NsdServiceInfo(
+                "${ALL_NETWORKS_SERVICE.serviceName} (2)", ALL_NETWORKS_SERVICE.serviceType).apply {
+            port = ALL_NETWORKS_SERVICE.port
+            host = ALL_NETWORKS_SERVICE.host
+            network = ALL_NETWORKS_SERVICE.network
+        }
+
+        val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
+        verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)),
+                eq(thread.looper), any(), intAdvCbCaptor.capture())
+        verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1),
+                argThat { it.matches(SERVICE_1) })
+        verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_2),
+                argThat { it.matches(expectedRenamed) })
+
+        doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1)
+        postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
+                mockInterfaceAdvertiser1, SERVICE_ID_1) }
+        verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_1), argThat { it.matches(SERVICE_1) })
+
+        doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_2)
+        postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
+                mockInterfaceAdvertiser1, SERVICE_ID_2) }
+        verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_2),
+                argThat { it.matches(expectedRenamed) })
+
+        postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
+        postSync { allNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
+
+        // destroyNow can be called multiple times
+        verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow()
+    }
+
     private fun postSync(r: () -> Unit) {
         handler.post(r)
         handler.waitForIdle(TIMEOUT_MS)