Merge "Adds document for self certified network capabilities"
diff --git a/Cronet/tests/cts/src/android/net/http/cts/UrlRequestTest.java b/Cronet/tests/cts/src/android/net/http/cts/UrlRequestTest.java
index d7d3679..54c1ee3 100644
--- a/Cronet/tests/cts/src/android/net/http/cts/UrlRequestTest.java
+++ b/Cronet/tests/cts/src/android/net/http/cts/UrlRequestTest.java
@@ -21,6 +21,7 @@
 
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.greaterThan;
+import static org.junit.Assert.assertSame;
 
 import android.content.Context;
 import android.net.http.HttpEngine;
@@ -91,4 +92,21 @@
         request.getStatus(statusListener);
         statusListener.expectStatus(Status.INVALID);
     }
+
+    /**
+     * Checks that cancelling a request after ON_RESPONSE_STARTED delivers ON_CANCELED.
+     */
+    @Test
+    public void testUrlRequestCancel_CancelCalled() throws Exception {
+        UrlRequest request = buildUrlRequest(mTestServer.getSuccessUrl());
+        mCallback.setAutoAdvance(false);
+
+        request.start();
+        mCallback.waitForNextStep();
+        // assertSame takes (expected, actual); this order keeps failure messages accurate.
+        assertSame(ResponseStep.ON_RESPONSE_STARTED, mCallback.mResponseStep);
+
+        request.cancel();
+        mCallback.expectCallback(ResponseStep.ON_CANCELED);
+    }
 }
diff --git a/framework/src/android/net/NetworkAgent.java b/framework/src/android/net/NetworkAgent.java
index 62e4fe1..3ec00d9 100644
--- a/framework/src/android/net/NetworkAgent.java
+++ b/framework/src/android/net/NetworkAgent.java
@@ -491,7 +491,7 @@
      * TCP sockets are open over a VPN. The system will check periodically for presence of
      * such open sockets, and this message is what triggers the re-evaluation.
      *
-     * obj = AutomaticOnOffKeepaliveObject.
+     * obj = the IBinder of the keepalive's ISocketKeepaliveCallback, used to look it up.
      * @hide
      */
     public static final int CMD_MONITOR_AUTOMATIC_KEEPALIVE = BASE + 30;
diff --git a/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp b/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp
index 4c37b8d..aeadb4a 100644
--- a/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp
+++ b/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp
@@ -50,6 +50,9 @@
 void NetworkTraceHandler::InitPerfettoTracing() {
   perfetto::TracingInitArgs args = {};
   args.backends |= perfetto::kSystemBackend;
+  // Nothing here consumes traces (only a data source is registered below), so the system
+  // consumer endpoint is not needed.
+  args.enable_system_consumer = false;
   perfetto::Tracing::Initialize(args);
   NetworkTraceHandler::RegisterDataSource();
 }
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 49c6ef0..4ad39e1 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -118,13 +118,15 @@
     private final NsdStateMachine mNsdStateMachine;
     private final MDnsManager mMDnsManager;
     private final MDnsEventCallback mMDnsEventCallback;
-    @Nullable
+    @NonNull
+    private final Dependencies mDeps;
+    @NonNull
     private final MdnsMultinetworkSocketClient mMdnsSocketClient;
-    @Nullable
+    @NonNull
     private final MdnsDiscoveryManager mMdnsDiscoveryManager;
-    @Nullable
+    @NonNull
     private final MdnsSocketProvider mMdnsSocketProvider;
-    @Nullable
+    @NonNull
     private final MdnsAdvertiser mAdvertiser;
     // WARNING : Accessing these values in any thread is not safe, it must only be changed in the
     // state machine thread. If change this outside state machine, it will need to introduce
@@ -311,21 +313,14 @@
             mIsMonitoringSocketsStarted = true;
         }
 
-        private void maybeStopMonitoringSockets() {
-            if (!mIsMonitoringSocketsStarted) {
-                if (DBG) Log.d(TAG, "Socket monitoring has not been started.");
-                return;
-            }
+        private void maybeStopMonitoringSocketsIfNoActiveRequest() {
+            if (!mIsMonitoringSocketsStarted) return; // Not currently monitoring.
+            if (isAnyRequestActive()) return; // Keep monitoring while any request is active.
+
             mMdnsSocketProvider.stopMonitoringSockets();
             mIsMonitoringSocketsStarted = false;
         }
 
-        private void maybeStopMonitoringSocketsIfNoActiveRequest() {
-            if (!isAnyRequestActive()) {
-                maybeStopMonitoringSockets();
-            }
-        }
-
         NsdStateMachine(String name, Handler handler) {
             super(name, handler);
             addState(mDefaultState);
@@ -362,9 +357,7 @@
                                 mLegacyClientCount -= 1;
                             }
                         }
-                        if (mMdnsDiscoveryManager != null || mAdvertiser != null) {
-                            maybeStopMonitoringSocketsIfNoActiveRequest();
-                        }
+                        maybeStopMonitoringSocketsIfNoActiveRequest();
                         maybeScheduleStop();
                         break;
                     case NsdManager.DISCOVER_SERVICES:
@@ -579,7 +572,7 @@
 
                         final NsdServiceInfo info = args.serviceInfo;
                         id = getUniqueId();
-                        if (mMdnsDiscoveryManager != null) {
+                        if (mDeps.isMdnsDiscoveryManagerEnabled(mContext)) {
                             final String serviceType = constructServiceType(info.getServiceType());
                             if (serviceType == null) {
                                 clientInfo.onDiscoverServicesFailed(clientId,
@@ -634,6 +627,9 @@
                             break;
                         }
                         id = request.mGlobalId;
+                        // Note isMdnsDiscoveryManagerEnabled may have changed to false at this
+                        // point, so this needs to check the type of the original request to
+                        // unregister instead of looking at the flag value.
                         if (request instanceof DiscoveryManagerRequest) {
                             final MdnsListener listener =
                                     ((DiscoveryManagerRequest) request).mListener;
@@ -671,7 +667,7 @@
                         }
 
                         id = getUniqueId();
-                        if (mAdvertiser != null) {
+                        if (mDeps.isMdnsAdvertiserEnabled(mContext)) {
                             final NsdServiceInfo serviceInfo = args.serviceInfo;
                             final String serviceType = serviceInfo.getServiceType();
                             final String registerServiceType = constructServiceType(serviceType);
@@ -722,7 +718,10 @@
                         id = request.mGlobalId;
                         removeRequestMap(clientId, id, clientInfo);
 
-                        if (mAdvertiser != null) {
+                        // Note isMdnsAdvertiserEnabled may have changed to false at this point,
+                        // so this needs to check the type of the original request to unregister
+                        // instead of looking at the flag value.
+                        if (request instanceof AdvertiserClientRequest) {
                             mAdvertiser.removeService(id);
                             clientInfo.onUnregisterServiceSucceeded(clientId);
                         } else {
@@ -749,7 +748,7 @@
 
                         final NsdServiceInfo info = args.serviceInfo;
                         id = getUniqueId();
-                        if (mMdnsDiscoveryManager != null) {
+                        if (mDeps.isMdnsDiscoveryManagerEnabled(mContext)) {
                             final String serviceType = constructServiceType(info.getServiceType());
                             if (serviceType == null) {
                                 clientInfo.onResolveServiceFailed(clientId,
@@ -1241,32 +1240,16 @@
         mNsdStateMachine.start();
         mMDnsManager = ctx.getSystemService(MDnsManager.class);
         mMDnsEventCallback = new MDnsEventCallback(mNsdStateMachine);
+        mDeps = deps;
 
-        final boolean discoveryManagerEnabled = deps.isMdnsDiscoveryManagerEnabled(ctx);
-        final boolean advertiserEnabled = deps.isMdnsAdvertiserEnabled(ctx);
-        if (discoveryManagerEnabled || advertiserEnabled) {
-            mMdnsSocketProvider = deps.makeMdnsSocketProvider(ctx, handler.getLooper());
-        } else {
-            mMdnsSocketProvider = null;
-        }
-
-        if (discoveryManagerEnabled) {
-            mMdnsSocketClient =
-                    new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider);
-            mMdnsDiscoveryManager =
-                    deps.makeMdnsDiscoveryManager(new ExecutorProvider(), mMdnsSocketClient);
-            handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager));
-        } else {
-            mMdnsSocketClient = null;
-            mMdnsDiscoveryManager = null;
-        }
-
-        if (advertiserEnabled) {
-            mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider,
-                    new AdvertiserCallback());
-        } else {
-            mAdvertiser = null;
-        }
+        mMdnsSocketProvider = deps.makeMdnsSocketProvider(ctx, handler.getLooper());
+        mMdnsSocketClient =
+                new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider);
+        mMdnsDiscoveryManager =
+                deps.makeMdnsDiscoveryManager(new ExecutorProvider(), mMdnsSocketClient);
+        handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager));
+        mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider,
+                new AdvertiserCallback());
     }
 
     /**
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 330a1da..f1c68cb 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -5547,7 +5547,9 @@
                     break;
                 }
                 case NetworkAgent.CMD_MONITOR_AUTOMATIC_KEEPALIVE: {
-                    final AutomaticOnOffKeepalive ki = (AutomaticOnOffKeepalive) msg.obj;
+                    final AutomaticOnOffKeepalive ki =
+                            mKeepaliveTracker.getKeepaliveForBinder((IBinder) msg.obj);
+                    if (null == ki) return; // The callback was unregistered before the alarm fired
 
                     final Network network = ki.getNetwork();
                     boolean networkFound = false;
diff --git a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
index 46fff6c..18e2dd8 100644
--- a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
@@ -45,6 +45,7 @@
 import android.net.Network;
 import android.net.NetworkAgent;
 import android.net.SocketKeepalive.InvalidSocketException;
+import android.os.Bundle;
 import android.os.FileUtils;
 import android.os.Handler;
 import android.os.IBinder;
@@ -60,6 +61,7 @@
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.IndentingPrintWriter;
 import com.android.modules.utils.build.SdkLevel;
+import com.android.net.module.util.BinderUtils;
 import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.DeviceConfigUtils;
 import com.android.net.module.util.HexDump;
@@ -92,8 +94,7 @@
     private static final int[] ADDRESS_FAMILIES = new int[] {AF_INET6, AF_INET};
     private static final String ACTION_TCP_POLLING_ALARM =
             "com.android.server.connectivity.KeepaliveTracker.TCP_POLLING_ALARM";
-    private static final String EXTRA_NETWORK = "network_id";
-    private static final String EXTRA_SLOT = "slot";
+    private static final String EXTRA_BINDER_TOKEN = "token";
     private static final long DEFAULT_TCP_POLLING_INTERVAL_MS = 120_000L;
     private static final String AUTOMATIC_ON_OFF_KEEPALIVE_VERSION =
             "automatic_on_off_keepalive_version";
@@ -159,11 +160,15 @@
         public void onReceive(Context context, Intent intent) {
             if (ACTION_TCP_POLLING_ALARM.equals(intent.getAction())) {
                 Log.d(TAG, "Received TCP polling intent");
-                final Network network = intent.getParcelableExtra(EXTRA_NETWORK);
-                final int slot = intent.getIntExtra(EXTRA_SLOT, -1);
+                // The token travels in a Bundle because Intent has no putExtra for Binders.
+                final Bundle extras = intent.getBundleExtra(EXTRA_BINDER_TOKEN);
+                if (null == extras) {
+                    Log.e(TAG, "TCP polling intent has no binder token, ignoring");
+                    return;
+                }
+                final IBinder token = extras.getBinder(EXTRA_BINDER_TOKEN);
                 mConnectivityServiceHandler.obtainMessage(
-                        NetworkAgent.CMD_MONITOR_AUTOMATIC_KEEPALIVE,
-                        slot, 0 , network).sendToTarget();
+                        NetworkAgent.CMD_MONITOR_AUTOMATIC_KEEPALIVE, token).sendToTarget();
             }
         }
     };
@@ -183,6 +183,8 @@
     public class AutomaticOnOffKeepalive {
         @NonNull
         private final KeepaliveTracker.KeepaliveInfo mKi;
+        @NonNull
+        private final ISocketKeepaliveCallback mCallback;
         @Nullable
         private final FileDescriptor mFd;
         @Nullable
@@ -193,6 +195,7 @@
         AutomaticOnOffKeepalive(@NonNull final KeepaliveTracker.KeepaliveInfo ki,
                 final boolean autoOnOff, @NonNull Context context) throws InvalidSocketException {
             this.mKi = Objects.requireNonNull(ki);
+            mCallback = ki.mCallback;
             if (autoOnOff && mDependencies.isFeatureEnabled(AUTOMATIC_ON_OFF_KEEPALIVE_VERSION)) {
                 mAutomaticOnOffState = STATE_ENABLED;
                 if (null == ki.mFd) {
@@ -205,8 +208,7 @@
                     Log.e(TAG, "Cannot dup fd: ", e);
                     throw new InvalidSocketException(ERROR_INVALID_SOCKET, e);
                 }
-                mTcpPollingAlarm = createTcpPollingAlarmIntent(
-                        context, ki.getNai().network(), ki.getSlot());
+                mTcpPollingAlarm = createTcpPollingAlarmIntent(context, mCallback.asBinder());
             } else {
                 mAutomaticOnOffState = STATE_ALWAYS_ON;
                 // A null fd is acceptable in KeepaliveInfo for backward compatibility of
@@ -226,12 +228,22 @@
         }
 
+        /**
+         * Creates the PendingIntent sent by this keepalive's TCP polling alarm. It carries
+         * {@code token} so the receiver can find the keepalive again.
+         */
         private PendingIntent createTcpPollingAlarmIntent(@NonNull Context context,
-                @NonNull Network network, int slot) {
+                @NonNull IBinder token) {
             final Intent intent = new Intent(ACTION_TCP_POLLING_ALARM);
-            intent.putExtra(EXTRA_NETWORK, network);
-            intent.putExtra(EXTRA_SLOT, slot);
-            return PendingIntent.getBroadcast(
-                    context, 0 /* requestCode */, intent, PendingIntent.FLAG_IMMUTABLE);
+            // Intent doesn't expose methods to put extra Binders, but Bundle does.
+            final Bundle b = new Bundle();
+            b.putBinder(EXTRA_BINDER_TOKEN, token);
+            intent.putExtra(EXTRA_BINDER_TOKEN, b);
+            // PendingIntent identity uses Intent#filterEquals, which ignores extras. Without a
+            // unique identifier all keepalives would get the same PendingIntent, still carrying
+            // the token of whichever keepalive created it first.
+            intent.setIdentifier(java.util.UUID.randomUUID().toString());
+            return BinderUtils.withCleanCallingIdentity(() -> PendingIntent.getBroadcast(
+                    context, 0 /* requestCode */, intent, PendingIntent.FLAG_IMMUTABLE));
         }
     }
 
@@ -318,13 +322,14 @@
             newKi = autoKi.mKi.withFd(autoKi.mFd);
         } catch (InvalidSocketException | IllegalArgumentException | SecurityException e) {
             Log.e(TAG, "Fail to construct keepalive", e);
-            mKeepaliveTracker.notifyErrorCallback(autoKi.mKi.mCallback, ERROR_INVALID_SOCKET);
+            mKeepaliveTracker.notifyErrorCallback(autoKi.mCallback, ERROR_INVALID_SOCKET);
             return;
         }
         autoKi.mAutomaticOnOffState = STATE_ENABLED;
         handleResumeKeepalive(newKi);
     }
 
+    // TODO : this method should be removed ; the keepalives should always be indexed by callback
     private int findAutomaticOnOffKeepaliveIndex(@NonNull Network network, int slot) {
         ensureRunningOnHandlerThread();
 
@@ -338,6 +343,7 @@
         return -1;
     }
 
+    // TODO : this method should be removed ; the keepalives should always be indexed by callback
     @Nullable
     private AutomaticOnOffKeepalive findAutomaticOnOffKeepalive(@NonNull Network network,
             int slot) {
@@ -348,6 +354,18 @@
     }
 
     /**
+     * Find the AutomaticOnOffKeepalive whose callback's binder is {@code token}.
+     * @return the matching keepalive, or null if none
+     */
+    @Nullable
+    public AutomaticOnOffKeepalive getKeepaliveForBinder(@NonNull final IBinder token) {
+        ensureRunningOnHandlerThread();
+
+        return CollectionUtils.findFirst(mAutomaticOnOffKeepalives,
+                it -> it.mCallback.asBinder().equals(token));
+    }
+
+    /**
      * Handle keepalive events from lower layer.
      *
      * Forward to KeepaliveTracker.
diff --git a/service/src/com/android/server/connectivity/KeepaliveTracker.java b/service/src/com/android/server/connectivity/KeepaliveTracker.java
index 63b76c7..7cb613b 100644
--- a/service/src/com/android/server/connectivity/KeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/KeepaliveTracker.java
@@ -125,8 +125,9 @@
      * which is only returned when the hardware has successfully started the keepalive.
      */
     class KeepaliveInfo implements IBinder.DeathRecipient {
-        // Bookkeeping data.
+        // TODO : remove this member. Only AutomaticOnOffKeepalive should reference the callback.
         public final ISocketKeepaliveCallback mCallback;
+        // Bookkeeping data.
         private final int mUid;
         private final int mPid;
         private final boolean mPrivileged;
diff --git a/tests/common/AndroidTest_Coverage.xml b/tests/common/AndroidTest_Coverage.xml
index 48d26b8..c94ec27 100644
--- a/tests/common/AndroidTest_Coverage.xml
+++ b/tests/common/AndroidTest_Coverage.xml
@@ -13,7 +13,7 @@
      limitations under the License.
 -->
 <configuration description="Runs coverage tests for Connectivity">
-    <target_preparer class="com.android.tradefed.targetprep.TestAppInstallSetup">
+    <target_preparer class="com.android.tradefed.targetprep.suite.SuiteApkInstaller">
       <option name="test-file-name" value="ConnectivityCoverageTests.apk" />
       <option name="install-arg" value="-t" />
     </target_preparer>
diff --git a/tests/unit/java/com/android/server/NsdServiceTest.java b/tests/unit/java/com/android/server/NsdServiceTest.java
index 98a8ed2..a2c4b9b 100644
--- a/tests/unit/java/com/android/server/NsdServiceTest.java
+++ b/tests/unit/java/com/android/server/NsdServiceTest.java
@@ -45,6 +45,7 @@
 import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
 
 import android.compat.testing.PlatformCompatChangeRule;
@@ -170,6 +171,9 @@
         doReturn(true).when(mMockMDnsM).resolve(
                 anyInt(), anyString(), anyString(), anyString(), anyInt());
         doReturn(false).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class));
+        doReturn(mDiscoveryManager).when(mDeps).makeMdnsDiscoveryManager(any(), any());
+        doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any());
+        doReturn(mAdvertiser).when(mDeps).makeMdnsAdvertiser(any(), any(), any());
 
         mService = makeService();
     }
@@ -824,40 +828,50 @@
                 client.unregisterServiceInfoCallback(serviceInfoCallback));
     }
 
-    private void makeServiceWithMdnsDiscoveryManagerEnabled() {
+    private void setMdnsDiscoveryManagerEnabled() {
         doReturn(true).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class));
-        doReturn(mDiscoveryManager).when(mDeps).makeMdnsDiscoveryManager(any(), any());
-        doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any());
-
-        mService = makeService();
-        verify(mDeps).makeMdnsDiscoveryManager(any(), any());
-        verify(mDeps).makeMdnsSocketProvider(any(), any());
     }
 
-    private void makeServiceWithMdnsAdvertiserEnabled() {
+    private void setMdnsAdvertiserEnabled() {
         doReturn(true).when(mDeps).isMdnsAdvertiserEnabled(any(Context.class));
-        doReturn(mAdvertiser).when(mDeps).makeMdnsAdvertiser(any(), any(), any());
-        doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any());
-
-        mService = makeService();
-        verify(mDeps).makeMdnsAdvertiser(any(), any(), any());
-        verify(mDeps).makeMdnsSocketProvider(any(), any());
     }
 
     @Test
     public void testMdnsDiscoveryManagerFeature() {
         // Create NsdService w/o feature enabled.
-        connectClient(mService);
-        verify(mDeps, never()).makeMdnsDiscoveryManager(any(), any());
-        verify(mDeps, never()).makeMdnsSocketProvider(any(), any());
+        final NsdManager client = connectClient(mService);
+        final DiscoveryListener discListenerWithoutFeature = mock(DiscoveryListener.class);
+        client.discoverServices(SERVICE_TYPE, PROTOCOL, discListenerWithoutFeature);
+        waitForIdle();
 
-        // Create NsdService again w/ feature enabled.
-        makeServiceWithMdnsDiscoveryManagerEnabled();
+        final ArgumentCaptor<Integer> legacyIdCaptor = ArgumentCaptor.forClass(Integer.class);
+        verify(mMockMDnsM).discover(legacyIdCaptor.capture(), any(), anyInt());
+        verifyNoMoreInteractions(mDiscoveryManager);
+
+        setMdnsDiscoveryManagerEnabled();
+        final DiscoveryListener discListenerWithFeature = mock(DiscoveryListener.class);
+        client.discoverServices(SERVICE_TYPE, PROTOCOL, discListenerWithFeature);
+        waitForIdle();
+
+        final String serviceTypeWithLocalDomain = SERVICE_TYPE + ".local";
+        final ArgumentCaptor<MdnsServiceBrowserListener> listenerCaptor =
+                ArgumentCaptor.forClass(MdnsServiceBrowserListener.class);
+        verify(mDiscoveryManager).registerListener(eq(serviceTypeWithLocalDomain),
+                listenerCaptor.capture(), any());
+
+        client.stopServiceDiscovery(discListenerWithoutFeature);
+        waitForIdle();
+        verify(mMockMDnsM).stopOperation(legacyIdCaptor.getValue());
+
+        client.stopServiceDiscovery(discListenerWithFeature);
+        waitForIdle();
+        verify(mDiscoveryManager).unregisterListener(serviceTypeWithLocalDomain,
+                listenerCaptor.getValue());
     }
 
     @Test
     public void testDiscoveryWithMdnsDiscoveryManager() {
-        makeServiceWithMdnsDiscoveryManagerEnabled();
+        setMdnsDiscoveryManagerEnabled();
 
         final NsdManager client = connectClient(mService);
         final DiscoveryListener discListener = mock(DiscoveryListener.class);
@@ -922,7 +936,7 @@
 
     @Test
     public void testDiscoveryWithMdnsDiscoveryManager_FailedWithInvalidServiceType() {
-        makeServiceWithMdnsDiscoveryManagerEnabled();
+        setMdnsDiscoveryManagerEnabled();
 
         final NsdManager client = connectClient(mService);
         final DiscoveryListener discListener = mock(DiscoveryListener.class);
@@ -951,7 +965,7 @@
 
     @Test
     public void testResolutionWithMdnsDiscoveryManager() throws UnknownHostException {
-        makeServiceWithMdnsDiscoveryManagerEnabled();
+        setMdnsDiscoveryManagerEnabled();
 
         final NsdManager client = connectClient(mService);
         final ResolveListener resolveListener = mock(ResolveListener.class);
@@ -1005,8 +1019,43 @@
     }
 
     @Test
+    public void testMdnsAdvertiserFeatureFlagging() {
+        // Advertiser flag not enabled yet: registration should go through legacy MDnsManager.
+        final NsdManager client = connectClient(mService);
+        final NsdServiceInfo regInfo = new NsdServiceInfo(SERVICE_NAME, SERVICE_TYPE);
+        regInfo.setHost(parseNumericAddress("192.0.2.123"));
+        regInfo.setPort(12345);
+        final RegistrationListener regListenerWithoutFeature = mock(RegistrationListener.class);
+        client.registerService(regInfo, PROTOCOL, regListenerWithoutFeature);
+        waitForIdle();
+
+        final ArgumentCaptor<Integer> legacyIdCaptor = ArgumentCaptor.forClass(Integer.class);
+        verify(mMockMDnsM).registerService(legacyIdCaptor.capture(), any(), any(), anyInt(),
+                any(), anyInt());
+        verifyNoMoreInteractions(mAdvertiser);
+
+        setMdnsAdvertiserEnabled(); // Same service instance; the flag is read per request.
+        final RegistrationListener regListenerWithFeature = mock(RegistrationListener.class);
+        client.registerService(regInfo, PROTOCOL, regListenerWithFeature);
+        waitForIdle();
+
+        final ArgumentCaptor<Integer> serviceIdCaptor = ArgumentCaptor.forClass(Integer.class);
+        verify(mAdvertiser).addService(serviceIdCaptor.capture(),
+                argThat(info -> matches(info, regInfo)));
+
+        client.unregisterService(regListenerWithoutFeature);
+        waitForIdle();
+        verify(mMockMDnsM).stopOperation(legacyIdCaptor.getValue());
+        verify(mAdvertiser, never()).removeService(anyInt());
+
+        client.unregisterService(regListenerWithFeature);
+        waitForIdle();
+        verify(mAdvertiser).removeService(serviceIdCaptor.getValue());
+    }
+
+    @Test
     public void testAdvertiseWithMdnsAdvertiser() {
-        makeServiceWithMdnsAdvertiserEnabled();
+        setMdnsAdvertiserEnabled();
 
         final NsdManager client = connectClient(mService);
         final RegistrationListener regListener = mock(RegistrationListener.class);
@@ -1045,7 +1094,7 @@
 
     @Test
     public void testAdvertiseWithMdnsAdvertiser_FailedWithInvalidServiceType() {
-        makeServiceWithMdnsAdvertiserEnabled();
+        setMdnsAdvertiserEnabled();
 
         final NsdManager client = connectClient(mService);
         final RegistrationListener regListener = mock(RegistrationListener.class);
@@ -1070,7 +1119,7 @@
 
     @Test
     public void testAdvertiseWithMdnsAdvertiser_LongServiceName() {
-        makeServiceWithMdnsAdvertiserEnabled();
+        setMdnsAdvertiserEnabled();
 
         final NsdManager client = connectClient(mService);
         final RegistrationListener regListener = mock(RegistrationListener.class);