Implement onServiceConflict
Implement the onServiceConflict callback in MdnsAdvertiser, refactoring
the conflict detection to reuse it both in onServiceConflict (when a
conflict is detected on the network after add) and at service add time.
Bug: 241738458
Test: atest MdnsAdvertiserTest
Change-Id: I69128db936296bd2c5e90e9f00df19fd881e1748
diff --git a/service-t/src/com/android/server/mdns/MdnsAdvertiser.java b/service-t/src/com/android/server/mdns/MdnsAdvertiser.java
index 4e40efe..977478a 100644
--- a/service-t/src/com/android/server/mdns/MdnsAdvertiser.java
+++ b/service-t/src/com/android/server/mdns/MdnsAdvertiser.java
@@ -29,10 +29,10 @@
import com.android.internal.annotations.VisibleForTesting;
-import java.io.IOException;
import java.util.List;
import java.util.Map;
-import java.util.function.Predicate;
+import java.util.function.BiPredicate;
+import java.util.function.Consumer;
/**
* MdnsAdvertiser manages advertising services per {@link com.android.server.NsdService} requests.
@@ -85,7 +85,7 @@
public void onRegisterServiceSucceeded(
@NonNull MdnsInterfaceAdvertiser advertiser, int serviceId) {
// Wait for all current interfaces to be done probing before notifying of success.
- if (anyAdvertiser(a -> a.isProbing(serviceId))) return;
+ if (any(mAllAdvertisers, (k, a) -> a.isProbing(serviceId))) return;
// The service may still be unregistered/renamed if a conflict is found on a later added
// interface, or if a conflicting announcement/reply is detected (RFC6762 9.)
@@ -102,7 +102,37 @@
@Override
public void onServiceConflict(@NonNull MdnsInterfaceAdvertiser advertiser, int serviceId) {
- // TODO: handle conflicts found after registration (during or after probing)
+ if (DBG) {
+ Log.v(TAG, "Found conflict, restarted probing for service " + serviceId);
+ }
+
+ final Registration registration = mRegistrations.get(serviceId);
+ if (registration == null) return;
+ if (registration.mNotifiedRegistrationSuccess) {
+ // TODO: consider notifying clients that the service is no longer registered with
+ // the old name (back to probing). The legacy implementation did not send any
+ // callback though; it only sent onServiceRegistered after re-probing finishes
+ // (with the old, conflicting, actually not used name as argument... The new
+ // implementation will send callbacks with the new name).
+ registration.mNotifiedRegistrationSuccess = false;
+
+ // The service was done probing, just reset it to probing state (RFC6762 9.)
+ forAllAdvertisers(a -> a.restartProbingForConflict(serviceId));
+ return;
+ }
+
+ // Conflict was found during probing; rename once to find a name that has no conflict
+ registration.updateForConflict(
+ registration.makeNewServiceInfoForConflict(1 /* renameCount */),
+ 1 /* renameCount */);
+
+ // Keep renaming if the new name conflicts in local registrations
+ updateRegistrationUntilNoConflict((net, adv) -> adv.hasRegistration(registration),
+ registration);
+
+ // Update advertisers to use the new name
+ forAllAdvertisers(a -> a.renameServiceForConflict(
+ serviceId, registration.getServiceInfo()));
}
@Override
@@ -116,6 +146,25 @@
}
};
+ private boolean hasAnyConflict(
+ @NonNull BiPredicate<Network, InterfaceAdvertiserRequest> applicableAdvertiserFilter,
+ @NonNull NsdServiceInfo newInfo) {
+ return any(mAdvertiserRequests, (network, adv) ->
+ applicableAdvertiserFilter.test(network, adv) && adv.hasConflict(newInfo));
+ }
+
+ private void updateRegistrationUntilNoConflict(
+ @NonNull BiPredicate<Network, InterfaceAdvertiserRequest> applicableAdvertiserFilter,
+ @NonNull Registration registration) {
+ int renameCount = 0;
+ NsdServiceInfo newInfo = registration.getServiceInfo();
+ while (hasAnyConflict(applicableAdvertiserFilter, newInfo)) {
+ renameCount++;
+ newInfo = registration.makeNewServiceInfoForConflict(renameCount);
+ }
+ registration.updateForConflict(newInfo, renameCount);
+ }
+
/**
* A request for a {@link MdnsInterfaceAdvertiser}.
*
@@ -153,6 +202,21 @@
}
/**
+ * Return whether this {@link InterfaceAdvertiserRequest} has the given registration.
+ */
+ boolean hasRegistration(@NonNull Registration registration) {
+ return mPendingRegistrations.indexOfValue(registration) >= 0;
+ }
+
+ /**
+ * Return whether using the proposed new {@link NsdServiceInfo} to add a registration would
+ * cause a conflict in this {@link InterfaceAdvertiserRequest}.
+ */
+ boolean hasConflict(@NonNull NsdServiceInfo newInfo) {
+ return getConflictingService(newInfo) >= 0;
+ }
+
+ /**
* Get the ID of a conflicting service, or -1 if none.
*/
int getConflictingService(@NonNull NsdServiceInfo info) {
@@ -166,16 +230,19 @@
return -1;
}
- void addService(int id, Registration registration)
- throws NameConflictException {
- final int conflicting = getConflictingService(registration.getServiceInfo());
- if (conflicting >= 0) {
- throw new NameConflictException(conflicting);
- }
-
+ /**
+ * Add a service.
+ *
+ * Conflicts must be checked via {@link #getConflictingService} before attempting to add.
+ */
+ void addService(int id, Registration registration) {
mPendingRegistrations.put(id, registration);
for (int i = 0; i < mAdvertisers.size(); i++) {
- mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo());
+ try {
+ mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo());
+ } catch (NameConflictException e) {
+ Log.wtf(TAG, "Name conflict adding services that should have unique names", e);
+ }
}
}
@@ -239,32 +306,42 @@
/**
* Update the registration to use a different service name, after a conflict was found.
*
+ * @param newInfo New service info to use.
+ * @param renameCount How many renames were done before reaching the current name.
+ */
+ private void updateForConflict(@NonNull NsdServiceInfo newInfo, int renameCount) {
+ mConflictCount += renameCount;
+ mServiceInfo = newInfo;
+ }
+
+ /**
+ * Make a new service name for the registration, after a conflict was found.
+ *
* If a name conflict was found during probing or because different advertising requests
* used the same name, the registration is attempted again with a new name (here using
* a number suffix, (1), (2) etc). Registration success is notified once probing succeeds
* with a new name. This matches legacy behavior based on mdnsresponder, and appendix D of
* RFC6763.
- * @return The new service info with the updated name.
+ *
+ * @param renameCount How much to increase the number suffix for this conflict.
*/
@NonNull
- private NsdServiceInfo updateForConflict() {
- mConflictCount++;
+ public NsdServiceInfo makeNewServiceInfoForConflict(int renameCount) {
// In case of conflict choose a different service name. After the first conflict use
// "Name (2)", then "Name (3)" etc.
// TODO: use a hidden method in NsdServiceInfo once MdnsAdvertiser is moved to service-t
final NsdServiceInfo newInfo = new NsdServiceInfo();
- newInfo.setServiceName(mOriginalName + " (" + (mConflictCount + 1) + ")");
+ newInfo.setServiceName(mOriginalName + " (" + (mConflictCount + renameCount + 1) + ")");
newInfo.setServiceType(mServiceInfo.getServiceType());
for (Map.Entry<String, byte[]> attr : mServiceInfo.getAttributes().entrySet()) {
- newInfo.setAttribute(attr.getKey(), attr.getValue());
+ newInfo.setAttribute(attr.getKey(),
+ attr.getValue() == null ? null : new String(attr.getValue()));
}
newInfo.setHost(mServiceInfo.getHost());
newInfo.setPort(mServiceInfo.getPort());
newInfo.setNetwork(mServiceInfo.getNetwork());
// interfaceIndex is not set when registering
-
- mServiceInfo = newInfo;
- return mServiceInfo;
+ return newInfo;
}
@NonNull
@@ -338,55 +415,27 @@
Log.i(TAG, "Adding service " + service + " with ID " + id);
}
- try {
- final Registration registration = new Registration(service);
- while (!tryAddRegistration(id, registration)) {
- registration.updateForConflict();
- }
-
- mRegistrations.put(id, registration);
- } catch (IOException e) {
- Log.e(TAG, "Error adding service " + service, e);
- removeService(id);
- // TODO (b/264986328): add a more specific error code
- mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR);
- }
- }
-
- private boolean tryAddRegistration(int id, @NonNull Registration registration)
- throws IOException {
- final NsdServiceInfo serviceInfo = registration.getServiceInfo();
- final Network network = serviceInfo.getNetwork();
- try {
- InterfaceAdvertiserRequest advertiser = mAdvertiserRequests.get(network);
- if (advertiser == null) {
- advertiser = new InterfaceAdvertiserRequest(network);
- mAdvertiserRequests.put(network, advertiser);
- }
- advertiser.addService(id, registration);
- } catch (NameConflictException e) {
- if (DBG) {
- Log.i(TAG, "Service name conflicts: " + serviceInfo.getServiceName());
- }
- removeService(id);
- return false;
+ final Network network = service.getNetwork();
+ final Registration registration = new Registration(service);
+ final BiPredicate<Network, InterfaceAdvertiserRequest> checkConflictFilter;
+ if (network == null) {
+ // If registering on all networks, no advertiser must have conflicts
+ checkConflictFilter = (net, adv) -> true;
+ } else {
+ // If registering on one network, the matching network advertiser and the one for all
+ // networks must not have conflicts
+ checkConflictFilter = (net, adv) -> net == null || network.equals(net);
}
- // When adding a service to a specific network, check that it does not conflict with other
- // registrations advertising on all networks
- final InterfaceAdvertiserRequest allNetworksAdvertiser = mAdvertiserRequests.get(null);
- if (network != null && allNetworksAdvertiser != null
- && allNetworksAdvertiser.getConflictingService(serviceInfo) >= 0) {
- if (DBG) {
- Log.i(TAG, "Service conflicts with advertisement on all networks: "
- + serviceInfo.getServiceName());
- }
- removeService(id);
- return false;
- }
+ updateRegistrationUntilNoConflict(checkConflictFilter, registration);
+ InterfaceAdvertiserRequest advertiser = mAdvertiserRequests.get(network);
+ if (advertiser == null) {
+ advertiser = new InterfaceAdvertiserRequest(network);
+ mAdvertiserRequests.put(network, advertiser);
+ }
+ advertiser.addService(id, registration);
mRegistrations.put(id, registration);
- return true;
}
/**
@@ -406,12 +455,20 @@
mRegistrations.remove(id);
}
- private boolean anyAdvertiser(@NonNull Predicate<MdnsInterfaceAdvertiser> predicate) {
- for (int i = 0; i < mAllAdvertisers.size(); i++) {
- if (predicate.test(mAllAdvertisers.valueAt(i))) {
+ private static <K, V> boolean any(@NonNull ArrayMap<K, V> map,
+ @NonNull BiPredicate<K, V> predicate) {
+ for (int i = 0; i < map.size(); i++) {
+ if (predicate.test(map.keyAt(i), map.valueAt(i))) {
return true;
}
}
return false;
}
+
+ private void forAllAdvertisers(@NonNull Consumer<MdnsInterfaceAdvertiser> consumer) {
+ any(mAllAdvertisers, (socket, advertiser) -> {
+ consumer.accept(advertiser);
+ return false;
+ });
+ }
}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
index e2babb1..1febe6d 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
@@ -38,6 +38,7 @@
import org.mockito.Mockito.any
import org.mockito.Mockito.anyInt
import org.mockito.Mockito.argThat
+import org.mockito.Mockito.atLeastOnce
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock
import org.mockito.Mockito.never
@@ -161,6 +162,60 @@
verify(socketProvider).unrequestSocket(socketCb)
}
+ @Test
+ fun testAddService_Conflicts() {
+ val advertiser = MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps)
+ postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1) }
+
+ val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
+ verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture())
+ val oneNetSocketCb = oneNetSocketCbCaptor.value
+
+ // Register a service with the same name on all networks (name conflict)
+ postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE) }
+ val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
+ verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture())
+ val allNetSocketCb = allNetSocketCbCaptor.value
+
+ // Callbacks for matching network and all networks both get the socket
+ postSync {
+ oneNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
+ allNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
+ }
+
+ val expectedRenamed = NsdServiceInfo(
+ "${ALL_NETWORKS_SERVICE.serviceName} (2)", ALL_NETWORKS_SERVICE.serviceType).apply {
+ port = ALL_NETWORKS_SERVICE.port
+ host = ALL_NETWORKS_SERVICE.host
+ network = ALL_NETWORKS_SERVICE.network
+ }
+
+ val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
+ verify(mockDeps).makeAdvertiser(eq(mockSocket1), eq(listOf(TEST_LINKADDR)),
+ eq(thread.looper), any(), intAdvCbCaptor.capture())
+ verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1),
+ argThat { it.matches(SERVICE_1) })
+ verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_2),
+ argThat { it.matches(expectedRenamed) })
+
+ doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_1)
+ postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
+ mockInterfaceAdvertiser1, SERVICE_ID_1) }
+ verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_1), argThat { it.matches(SERVICE_1) })
+
+ doReturn(false).`when`(mockInterfaceAdvertiser1).isProbing(SERVICE_ID_2)
+ postSync { intAdvCbCaptor.value.onRegisterServiceSucceeded(
+ mockInterfaceAdvertiser1, SERVICE_ID_2) }
+ verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_2),
+ argThat { it.matches(expectedRenamed) })
+
+ postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
+ postSync { allNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
+
+ // destroyNow can be called multiple times
+ verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow()
+ }
+
private fun postSync(r: () -> Unit) {
handler.post(r)
handler.waitForIdle(TIMEOUT_MS)