Optionally use MdnsAdvertiser for advertising

Based on a flag, use MdnsAdvertiser for advertising instead of the
legacy mdnsresponder implementation.

Bug: 241738458
Test: atest NsdServiceTest
Change-Id: I2d5069097c11f2959e3792cac326d179a3116479
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 84b9f12..9510c12 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -60,6 +60,7 @@
 import com.android.net.module.util.DeviceConfigUtils;
 import com.android.net.module.util.PermissionUtils;
 import com.android.server.connectivity.mdns.ExecutorProvider;
+import com.android.server.connectivity.mdns.MdnsAdvertiser;
 import com.android.server.connectivity.mdns.MdnsDiscoveryManager;
 import com.android.server.connectivity.mdns.MdnsMultinetworkSocketClient;
 import com.android.server.connectivity.mdns.MdnsSearchOptions;
@@ -74,6 +75,11 @@
 import java.net.NetworkInterface;
 import java.net.SocketException;
 import java.net.UnknownHostException;
+import java.nio.ByteBuffer;
+import java.nio.CharBuffer;
+import java.nio.charset.Charset;
+import java.nio.charset.CharsetEncoder;
+import java.nio.charset.StandardCharsets;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -89,10 +95,22 @@
 public class NsdService extends INsdManager.Stub {
     private static final String TAG = "NsdService";
     private static final String MDNS_TAG = "mDnsConnector";
+    /**
+     * Enable discovery using the Java DiscoveryManager, instead of the legacy mdnsresponder
+     * implementation.
+     */
     private static final String MDNS_DISCOVERY_MANAGER_VERSION = "mdns_discovery_manager_version";
     private static final String LOCAL_DOMAIN_NAME = "local";
+    // Max label length as per RFC 1034/1035
+    private static final int MAX_LABEL_LENGTH = 63;
 
-    private static final boolean DBG = Log.isLoggable(TAG, Log.DEBUG);
+    /**
+     * Enable advertising using the Java MdnsAdvertiser, instead of the legacy mdnsresponder
+     * implementation.
+     */
+    private static final String MDNS_ADVERTISER_VERSION = "mdns_advertiser_version";
+
+    public static final boolean DBG = Log.isLoggable(TAG, Log.DEBUG);
     private static final long CLEANUP_DELAY_MS = 10000;
     private static final int IFACE_IDX_ANY = 0;
 
@@ -106,6 +124,8 @@
     private final MdnsDiscoveryManager mMdnsDiscoveryManager;
     @Nullable
     private final MdnsSocketProvider mMdnsSocketProvider;
+    @Nullable
+    private final MdnsAdvertiser mAdvertiser;
     // WARNING : Accessing these values in any thread is not safe, it must only be changed in the
     // state machine thread. If change this outside state machine, it will need to introduce
     // synchronization.
@@ -345,7 +365,7 @@
                                 mLegacyClientCount -= 1;
                             }
                         }
-                        if (mMdnsDiscoveryManager != null) {
+                        if (mMdnsDiscoveryManager != null || mAdvertiser != null) {
                             maybeStopMonitoringSocketsIfNoActiveRequest();
                         }
                         maybeScheduleStop();
@@ -446,6 +466,7 @@
                 clientInfo.mClientRequests.delete(clientId);
                 mIdToClientInfoMap.remove(globalId);
                 maybeScheduleStop();
+                maybeStopMonitoringSocketsIfNoActiveRequest();
             }
 
             private void storeListenerMap(int clientId, int transactionId, MdnsListener listener,
@@ -484,8 +505,31 @@
                 final Matcher matcher = serviceTypePattern.matcher(serviceType);
                 if (!matcher.matches()) return null;
                 return matcher.group(1) == null
-                        ? serviceType + ".local"
-                        : matcher.group(1) + "_sub." + matcher.group(2) + ".local";
+                        ? serviceType
+                        : matcher.group(1) + "_sub." + matcher.group(2);
+            }
+
+            /**
+             * Truncate a service name to up to 63 UTF-8 bytes.
+             *
+             * See RFC6763 4.1.1: service instance names are UTF-8 and up to 63 bytes. Truncating
+             * names used in registerService follows historical behavior (see mdnsresponder
+             * handle_regservice_request).
+             */
+            @NonNull
+            private String truncateServiceName(@NonNull String originalName) {
+                // UTF-8 is at most 4 bytes per character; return early in the common case where
+                // the name can't possibly be over the limit given its string length.
+                if (originalName.length() <= MAX_LABEL_LENGTH / 4) return originalName;
+
+                final Charset utf8 = StandardCharsets.UTF_8;
+                final CharsetEncoder encoder = utf8.newEncoder();
+                final ByteBuffer out = ByteBuffer.allocate(MAX_LABEL_LENGTH);
+                // encode will write as many characters as possible to the out buffer, and just
+                // return an overflow code if there were too many characters (no need to check the
+                // return code here, this method truncates the name on purpose).
+                encoder.encode(CharBuffer.wrap(originalName), out, true /* endOfInput */);
+                return new String(out.array(), 0, out.position(), utf8);
             }
 
             @Override
@@ -523,14 +567,16 @@
                                 break;
                             }
 
+                            final String listenServiceType = serviceType + ".local";
                             maybeStartMonitoringSockets();
                             final MdnsListener listener =
-                                    new DiscoveryListener(clientId, id, info, serviceType);
+                                    new DiscoveryListener(clientId, id, info, listenServiceType);
                             final MdnsSearchOptions options = MdnsSearchOptions.newBuilder()
                                     .setNetwork(info.getNetwork())
                                     .setIsPassiveMode(true)
                                     .build();
-                            mMdnsDiscoveryManager.registerListener(serviceType, listener, options);
+                            mMdnsDiscoveryManager.registerListener(
+                                    listenServiceType, listener, options);
                             storeListenerMap(clientId, id, listener, clientInfo);
                             clientInfo.onDiscoverServicesStarted(clientId, info);
                         } else {
@@ -608,16 +654,36 @@
                             break;
                         }
 
-                        maybeStartDaemon();
                         id = getUniqueId();
-                        if (registerService(id, args.serviceInfo)) {
-                            if (DBG) Log.d(TAG, "Register " + clientId + " " + id);
+                        if (mAdvertiser != null) {
+                            final NsdServiceInfo serviceInfo = args.serviceInfo;
+                            final String serviceType = serviceInfo.getServiceType();
+                            final String registerServiceType = constructServiceType(serviceType);
+                            if (registerServiceType == null) {
+                                Log.e(TAG, "Invalid service type: " + serviceType);
+                                clientInfo.onRegisterServiceFailed(clientId,
+                                        NsdManager.FAILURE_INTERNAL_ERROR);
+                                break;
+                            }
+                            serviceInfo.setServiceType(registerServiceType);
+                            serviceInfo.setServiceName(truncateServiceName(
+                                    serviceInfo.getServiceName()));
+
+                            maybeStartMonitoringSockets();
+                            mAdvertiser.addService(id, serviceInfo);
                             storeRequestMap(clientId, id, clientInfo, msg.what);
-                            // Return success after mDns reports success
                         } else {
-                            unregisterService(id);
-                            clientInfo.onRegisterServiceFailed(
-                                    clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                            maybeStartDaemon();
+                            if (registerService(id, args.serviceInfo)) {
+                                if (DBG) Log.d(TAG, "Register " + clientId + " " + id);
+                                storeRequestMap(clientId, id, clientInfo, msg.what);
+                                // Return success after mDns reports success
+                            } else {
+                                unregisterService(id);
+                                clientInfo.onRegisterServiceFailed(
+                                        clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                            }
+
                         }
                         break;
                     case NsdManager.UNREGISTER_SERVICE:
@@ -633,11 +699,17 @@
                         }
                         id = clientInfo.mClientIds.get(clientId);
                         removeRequestMap(clientId, id, clientInfo);
-                        if (unregisterService(id)) {
+
+                        if (mAdvertiser != null) {
+                            mAdvertiser.removeService(id);
                             clientInfo.onUnregisterServiceSucceeded(clientId);
                         } else {
-                            clientInfo.onUnregisterServiceFailed(
-                                    clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                            if (unregisterService(id)) {
+                                clientInfo.onUnregisterServiceSucceeded(clientId);
+                            } else {
+                                clientInfo.onUnregisterServiceFailed(
+                                        clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                            }
                         }
                         break;
                     case NsdManager.RESOLVE_SERVICE: {
@@ -661,15 +733,17 @@
                                         NsdManager.FAILURE_INTERNAL_ERROR);
                                 break;
                             }
+                            final String resolveServiceType = serviceType + ".local";
 
                             maybeStartMonitoringSockets();
                             final MdnsListener listener =
-                                    new ResolutionListener(clientId, id, info, serviceType);
+                                    new ResolutionListener(clientId, id, info, resolveServiceType);
                             final MdnsSearchOptions options = MdnsSearchOptions.newBuilder()
                                     .setNetwork(info.getNetwork())
                                     .setIsPassiveMode(true)
                                     .build();
-                            mMdnsDiscoveryManager.registerListener(serviceType, listener, options);
+                            mMdnsDiscoveryManager.registerListener(
+                                    resolveServiceType, listener, options);
                             storeListenerMap(clientId, id, listener, clientInfo);
                         } else {
                             if (clientInfo.mResolvedService != null) {
@@ -1005,18 +1079,32 @@
         mNsdStateMachine.start();
         mMDnsManager = ctx.getSystemService(MDnsManager.class);
         mMDnsEventCallback = new MDnsEventCallback(mNsdStateMachine);
-        if (deps.isMdnsDiscoveryManagerEnabled(ctx)) {
+
+        final boolean discoveryManagerEnabled = deps.isMdnsDiscoveryManagerEnabled(ctx);
+        final boolean advertiserEnabled = deps.isMdnsAdvertiserEnabled(ctx);
+        if (discoveryManagerEnabled || advertiserEnabled) {
             mMdnsSocketProvider = deps.makeMdnsSocketProvider(ctx, handler.getLooper());
+        } else {
+            mMdnsSocketProvider = null;
+        }
+
+        if (discoveryManagerEnabled) {
             mMdnsSocketClient =
                     new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider);
             mMdnsDiscoveryManager =
                     deps.makeMdnsDiscoveryManager(new ExecutorProvider(), mMdnsSocketClient);
             handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager));
         } else {
-            mMdnsSocketProvider = null;
             mMdnsSocketClient = null;
             mMdnsDiscoveryManager = null;
         }
+
+        if (advertiserEnabled) {
+            mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider,
+                    new AdvertiserCallback());
+        } else {
+            mAdvertiser = null;
+        }
     }
 
     /**
@@ -1025,10 +1113,10 @@
     @VisibleForTesting
     public static class Dependencies {
         /**
-         * Check whether or not MdnsDiscoveryManager feature is enabled.
+         * Check whether the MdnsDiscoveryManager feature is enabled.
          *
          * @param context The global context information about an app environment.
-         * @return true if MdnsDiscoveryManager feature is enabled.
+         * @return true if the MdnsDiscoveryManager feature is enabled.
          */
         public boolean isMdnsDiscoveryManagerEnabled(Context context) {
             return DeviceConfigUtils.isFeatureEnabled(context, NAMESPACE_CONNECTIVITY,
@@ -1036,6 +1124,17 @@
         }
 
         /**
+         * Check whether the MdnsAdvertiser feature is enabled.
+         *
+         * @param context The global context information about an app environment.
+         * @return true if the MdnsAdvertiser feature is enabled.
+         */
+        public boolean isMdnsAdvertiserEnabled(Context context) {
+            return DeviceConfigUtils.isFeatureEnabled(context, NAMESPACE_CONNECTIVITY,
+                    MDNS_ADVERTISER_VERSION, false /* defaultEnabled */);
+        }
+
+        /**
          * @see MdnsDiscoveryManager
          */
         public MdnsDiscoveryManager makeMdnsDiscoveryManager(
@@ -1044,6 +1143,15 @@
         }
 
         /**
+         * @see MdnsAdvertiser
+         */
+        public MdnsAdvertiser makeMdnsAdvertiser(
+                @NonNull Looper looper, @NonNull MdnsSocketProvider socketProvider,
+                @NonNull MdnsAdvertiser.AdvertiserCallback cb) {
+            return new MdnsAdvertiser(looper, socketProvider, cb);
+        }
+
+        /**
          * @see MdnsSocketProvider
          */
         public MdnsSocketProvider makeMdnsSocketProvider(Context context, Looper looper) {
@@ -1101,6 +1209,49 @@
         }
     }
 
+    private class AdvertiserCallback implements MdnsAdvertiser.AdvertiserCallback {
+        @Override
+        public void onRegisterServiceSucceeded(int serviceId, NsdServiceInfo registeredInfo) {
+            final ClientInfo clientInfo = getClientInfoOrLog(serviceId);
+            if (clientInfo == null) return;
+
+            final int clientId = getClientIdOrLog(clientInfo, serviceId);
+            if (clientId < 0) return;
+
+            // onRegisterServiceSucceeded only has the service name in its info. This aligns with
+            // historical behavior.
+            final NsdServiceInfo cbInfo = new NsdServiceInfo(registeredInfo.getServiceName(), null);
+            clientInfo.onRegisterServiceSucceeded(clientId, cbInfo);
+        }
+
+        @Override
+        public void onRegisterServiceFailed(int serviceId, int errorCode) {
+            final ClientInfo clientInfo = getClientInfoOrLog(serviceId);
+            if (clientInfo == null) return;
+
+            final int clientId = getClientIdOrLog(clientInfo, serviceId);
+            if (clientId < 0) return;
+
+            clientInfo.onRegisterServiceFailed(clientId, errorCode);
+        }
+
+        private ClientInfo getClientInfoOrLog(int serviceId) {
+            final ClientInfo clientInfo = mIdToClientInfoMap.get(serviceId);
+            if (clientInfo == null) {
+                Log.e(TAG, String.format("Callback for service %d has no client", serviceId));
+            }
+            return clientInfo;
+        }
+
+        private int getClientIdOrLog(@NonNull ClientInfo info, int serviceId) {
+            final int clientId = info.getClientId(serviceId);
+            if (clientId < 0) {
+                Log.e(TAG, String.format("Client ID not found for service %d", serviceId));
+            }
+            return clientId;
+        }
+    }
+
     @Override
     public INsdServiceConnector connect(INsdManagerCallback cb) {
         mContext.enforceCallingOrSelfPermission(android.Manifest.permission.INTERNET, "NsdService");
@@ -1364,7 +1515,11 @@
                         stopResolveService(globalId);
                         break;
                     case NsdManager.REGISTER_SERVICE:
-                        unregisterService(globalId);
+                        if (mAdvertiser != null) {
+                            mAdvertiser.removeService(globalId);
+                        } else {
+                            unregisterService(globalId);
+                        }
                         break;
                     default:
                         break;
diff --git a/tests/unit/java/com/android/server/NsdServiceTest.java b/tests/unit/java/com/android/server/NsdServiceTest.java
index 1bd49a5..9edfbe6 100644
--- a/tests/unit/java/com/android/server/NsdServiceTest.java
+++ b/tests/unit/java/com/android/server/NsdServiceTest.java
@@ -16,6 +16,7 @@
 
 package com.android.server;
 
+import static android.net.InetAddresses.parseNumericAddress;
 import static android.net.nsd.NsdManager.FAILURE_INTERNAL_ERROR;
 
 import static com.android.testutils.ContextUtils.mockService;
@@ -47,7 +48,6 @@
 import android.content.ContentResolver;
 import android.content.Context;
 import android.net.INetd;
-import android.net.InetAddresses;
 import android.net.Network;
 import android.net.mdns.aidl.DiscoveryInfo;
 import android.net.mdns.aidl.GetAddressInfo;
@@ -74,6 +74,7 @@
 import androidx.test.filters.SmallTest;
 
 import com.android.server.NsdService.Dependencies;
+import com.android.server.connectivity.mdns.MdnsAdvertiser;
 import com.android.server.connectivity.mdns.MdnsDiscoveryManager;
 import com.android.server.connectivity.mdns.MdnsServiceBrowserListener;
 import com.android.server.connectivity.mdns.MdnsServiceInfo;
@@ -96,6 +97,7 @@
 import java.net.UnknownHostException;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Objects;
 import java.util.Queue;
 
 // TODOs:
@@ -129,6 +131,7 @@
     @Mock MDnsManager mMockMDnsM;
     @Mock Dependencies mDeps;
     @Mock MdnsDiscoveryManager mDiscoveryManager;
+    @Mock MdnsAdvertiser mAdvertiser;
     @Mock MdnsSocketProvider mSocketProvider;
     HandlerThread mThread;
     TestHandler mHandler;
@@ -399,7 +402,7 @@
         final NsdServiceInfo resolvedService = resInfoCaptor.getValue();
         assertEquals(SERVICE_NAME, resolvedService.getServiceName());
         assertEquals("." + SERVICE_TYPE, resolvedService.getServiceType());
-        assertEquals(InetAddresses.parseNumericAddress(serviceAddress), resolvedService.getHost());
+        assertEquals(parseNumericAddress(serviceAddress), resolvedService.getHost());
         assertEquals(servicePort, resolvedService.getPort());
         assertNull(resolvedService.getNetwork());
         assertEquals(interfaceIdx, resolvedService.getInterfaceIndex());
@@ -582,6 +585,16 @@
         verify(mDeps).makeMdnsSocketProvider(any(), any());
     }
 
+    private void makeServiceWithMdnsAdvertiserEnabled() {
+        doReturn(true).when(mDeps).isMdnsAdvertiserEnabled(any(Context.class));
+        doReturn(mAdvertiser).when(mDeps).makeMdnsAdvertiser(any(), any(), any());
+        doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any());
+
+        mService = makeService();
+        verify(mDeps).makeMdnsAdvertiser(any(), any(), any());
+        verify(mDeps).makeMdnsSocketProvider(any(), any());
+    }
+
     @Test
     public void testMdnsDiscoveryManagerFeature() {
         // Create NsdService w/o feature enabled.
@@ -733,7 +746,7 @@
         assertTrue(info.getAttributes().containsKey("key"));
         assertEquals(1, info.getAttributes().size());
         assertArrayEquals(new byte[]{(byte) 0xFF, (byte) 0xFE}, info.getAttributes().get("key"));
-        assertEquals(InetAddresses.parseNumericAddress(IPV4_ADDRESS), info.getHost());
+        assertEquals(parseNumericAddress(IPV4_ADDRESS), info.getHost());
         assertEquals(network, info.getNetwork());
 
         // Verify the listener has been unregistered.
@@ -742,6 +755,102 @@
         verify(mSocketProvider, timeout(CLEANUP_DELAY_MS + TIMEOUT_MS)).stopMonitoringSockets();
     }
 
+    @Test
+    public void testAdvertiseWithMdnsAdvertiser() {
+        makeServiceWithMdnsAdvertiserEnabled();
+
+        final NsdManager client = connectClient(mService);
+        final RegistrationListener regListener = mock(RegistrationListener.class);
+        // final String serviceTypeWithLocalDomain = SERVICE_TYPE + ".local";
+        final ArgumentCaptor<MdnsAdvertiser.AdvertiserCallback> cbCaptor =
+                ArgumentCaptor.forClass(MdnsAdvertiser.AdvertiserCallback.class);
+        verify(mDeps).makeMdnsAdvertiser(any(), any(), cbCaptor.capture());
+
+        final NsdServiceInfo regInfo = new NsdServiceInfo(SERVICE_NAME, SERVICE_TYPE);
+        regInfo.setHost(parseNumericAddress("192.0.2.123"));
+        regInfo.setPort(12345);
+        regInfo.setAttribute("testattr", "testvalue");
+        regInfo.setNetwork(new Network(999));
+
+        client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
+        waitForIdle();
+        verify(mSocketProvider).startMonitoringSockets();
+        final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
+        verify(mAdvertiser).addService(idCaptor.capture(), argThat(info ->
+                matches(info, regInfo)));
+
+        // Verify onServiceRegistered callback
+        final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue();
+        cb.onRegisterServiceSucceeded(idCaptor.getValue(), regInfo);
+
+        verify(regListener, timeout(TIMEOUT_MS)).onServiceRegistered(argThat(info -> matches(info,
+                new NsdServiceInfo(regInfo.getServiceName(), null))));
+
+        client.unregisterService(regListener);
+        waitForIdle();
+        verify(mAdvertiser).removeService(idCaptor.getValue());
+        verify(regListener, timeout(TIMEOUT_MS)).onServiceUnregistered(
+                argThat(info -> matches(info, regInfo)));
+        verify(mSocketProvider, timeout(TIMEOUT_MS)).stopMonitoringSockets();
+    }
+
+    @Test
+    public void testAdvertiseWithMdnsAdvertiser_FailedWithInvalidServiceType() {
+        makeServiceWithMdnsAdvertiserEnabled();
+
+        final NsdManager client = connectClient(mService);
+        final RegistrationListener regListener = mock(RegistrationListener.class);
+        // final String serviceTypeWithLocalDomain = SERVICE_TYPE + ".local";
+        final ArgumentCaptor<MdnsAdvertiser.AdvertiserCallback> cbCaptor =
+                ArgumentCaptor.forClass(MdnsAdvertiser.AdvertiserCallback.class);
+        verify(mDeps).makeMdnsAdvertiser(any(), any(), cbCaptor.capture());
+
+        final NsdServiceInfo regInfo = new NsdServiceInfo(SERVICE_NAME, "invalid_type");
+        regInfo.setHost(parseNumericAddress("192.0.2.123"));
+        regInfo.setPort(12345);
+        regInfo.setAttribute("testattr", "testvalue");
+        regInfo.setNetwork(new Network(999));
+
+        client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
+        waitForIdle();
+        verify(mAdvertiser, never()).addService(anyInt(), any());
+
+        verify(regListener, timeout(TIMEOUT_MS)).onRegistrationFailed(
+                argThat(info -> matches(info, regInfo)), eq(FAILURE_INTERNAL_ERROR));
+    }
+
+    @Test
+    public void testAdvertiseWithMdnsAdvertiser_LongServiceName() {
+        makeServiceWithMdnsAdvertiserEnabled();
+
+        final NsdManager client = connectClient(mService);
+        final RegistrationListener regListener = mock(RegistrationListener.class);
+        // final String serviceTypeWithLocalDomain = SERVICE_TYPE + ".local";
+        final ArgumentCaptor<MdnsAdvertiser.AdvertiserCallback> cbCaptor =
+                ArgumentCaptor.forClass(MdnsAdvertiser.AdvertiserCallback.class);
+        verify(mDeps).makeMdnsAdvertiser(any(), any(), cbCaptor.capture());
+
+        final NsdServiceInfo regInfo = new NsdServiceInfo("a".repeat(70), SERVICE_TYPE);
+        regInfo.setHost(parseNumericAddress("192.0.2.123"));
+        regInfo.setPort(12345);
+        regInfo.setAttribute("testattr", "testvalue");
+        regInfo.setNetwork(new Network(999));
+
+        client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
+        waitForIdle();
+        final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
+        // Service name is truncated to 63 characters
+        verify(mAdvertiser).addService(idCaptor.capture(),
+                argThat(info -> info.getServiceName().equals("a".repeat(63))));
+
+        // Verify onServiceRegistered callback
+        final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue();
+        cb.onRegisterServiceSucceeded(idCaptor.getValue(), regInfo);
+
+        verify(regListener, timeout(TIMEOUT_MS)).onServiceRegistered(
+                argThat(info -> matches(info, new NsdServiceInfo(regInfo.getServiceName(), null))));
+    }
+
     private void waitForIdle() {
         HandlerUtils.waitForIdle(mHandler, TIMEOUT_MS);
     }
@@ -786,6 +895,19 @@
         verify(mMockMDnsM, timeout(cleanupDelayMs + TIMEOUT_MS)).stopDaemon();
     }
 
+    /**
+     * Return true if two service info are the same.
+     *
+     * Useful for argument matchers as {@link NsdServiceInfo} does not implement equals.
+     */
+    private boolean matches(NsdServiceInfo a, NsdServiceInfo b) {
+        return Objects.equals(a.getServiceName(), b.getServiceName())
+                && Objects.equals(a.getServiceType(), b.getServiceType())
+                && Objects.equals(a.getHost(), b.getHost())
+                && Objects.equals(a.getNetwork(), b.getNetwork())
+                && Objects.equals(a.getAttributes(), b.getAttributes());
+    }
+
     public static class TestHandler extends Handler {
         public Message lastMessage;