Merge "Stop monitoring sockets until all sockets are unrequested"
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 4ad39e1..8e06fde 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -317,7 +317,7 @@
             if (!mIsMonitoringSocketsStarted) return;
             if (isAnyRequestActive()) return;
 
-            mMdnsSocketProvider.stopMonitoringSockets();
+            mMdnsSocketProvider.requestStopWhenInactive();
             mIsMonitoringSocketsStarted = false;
         }
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
index 9298852..743f946 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
@@ -82,6 +82,7 @@
     private final List<String> mTetheredInterfaces = new ArrayList<>();
     private final byte[] mPacketReadBuffer = new byte[READ_BUFFER_SIZE];
     private boolean mMonitoringSockets = false;
+    private boolean mRequestStop = false;
 
     public MdnsSocketProvider(@NonNull Context context, @NonNull Looper looper) {
         this(context, looper, new Dependencies());
@@ -179,6 +180,7 @@
     /*** Start monitoring sockets by listening callbacks for sockets creation or removal */
     public void startMonitoringSockets() {
         ensureRunningOnHandlerThread(mHandler);
+        mRequestStop = false; // Reset stop request flag.
         if (mMonitoringSockets) {
             Log.d(TAG, "Already monitoring sockets.");
             return;
@@ -195,22 +197,50 @@
         mMonitoringSockets = true;
     }
 
-    /*** Stop monitoring sockets and unregister callbacks */
-    public void stopMonitoringSockets() {
+    /**
+     * Stop monitoring sockets if a stop has been requested and no socket request is active.
+     *
+     * <p>Invoked when a stop is requested and again after sockets are unrequested, so that a
+     * pending stop takes effect once the last socket request goes away. No-op if monitoring
+     * is not running or no stop has been requested.
+     */
+    private void maybeStopMonitoringSockets() {
+        if (!mMonitoringSockets) return; // Already unregistered.
+        if (!mRequestStop) return; // No stop request.
+        // Keep monitoring while any socket request is still active; the stop is retried once
+        // sockets are unrequested.
+        if (!mCallbacksToRequestedNetworks.isEmpty()) return;
+
+        if (DBG) Log.d(TAG, "Stop monitoring sockets.");
+        mContext.getSystemService(ConnectivityManager.class)
+                .unregisterNetworkCallback(mNetworkCallback);
+
+        final TetheringManager tetheringManager = mContext.getSystemService(
+                TetheringManager.class);
+        tetheringManager.unregisterTetheringEventCallback(mTetheringEventCallback);
+
+        mHandler.post(mNetlinkMonitor::stop);
+        mMonitoringSockets = false;
+        // The pending request is now fulfilled; startMonitoringSockets() resets it as well.
+        mRequestStop = false;
+    }
+
+    /**
+     * Request to stop monitoring sockets and unregister callbacks.
+     *
+     * <p>The stop is deferred until all sockets are unrequested: while any socket request is
+     * still active, the callbacks stay registered. A subsequent call to
+     * {@link #startMonitoringSockets()} cancels a pending stop request.
+     */
+    public void requestStopWhenInactive() {
         ensureRunningOnHandlerThread(mHandler);
         if (!mMonitoringSockets) {
             Log.d(TAG, "Monitoring sockets hasn't been started.");
             return;
         }
-        if (DBG) Log.d(TAG, "Stop monitoring sockets.");
-        mContext.getSystemService(ConnectivityManager.class)
-                .unregisterNetworkCallback(mNetworkCallback);
-
-        final TetheringManager tetheringManager = mContext.getSystemService(TetheringManager.class);
-        tetheringManager.unregisterTetheringEventCallback(mTetheringEventCallback);
-
-        mHandler.post(mNetlinkMonitor::stop);
-        mMonitoringSockets = false;
+        if (DBG) Log.d(TAG, "Try to stop monitoring sockets.");
+        mRequestStop = true;
+        maybeStopMonitoringSockets();
     }
 
     /*** Check whether the target network is matched current network */
@@ -450,6 +464,9 @@
             cb.onInterfaceDestroyed(new Network(INetd.LOCAL_NET_ID), info.mSocket);
         }
         mTetherInterfaceSockets.clear();
+
+        // Try to unregister network callback.
+        maybeStopMonitoringSockets();
     }
 
     /*** Callbacks for listening socket changes */
diff --git a/tests/unit/java/com/android/server/NsdServiceTest.java b/tests/unit/java/com/android/server/NsdServiceTest.java
index 0680772..8fc9252 100644
--- a/tests/unit/java/com/android/server/NsdServiceTest.java
+++ b/tests/unit/java/com/android/server/NsdServiceTest.java
@@ -932,7 +932,7 @@
         waitForIdle();
         verify(mDiscoveryManager).unregisterListener(eq(serviceTypeWithLocalDomain), any());
         verify(discListener, timeout(TIMEOUT_MS)).onDiscoveryStopped(SERVICE_TYPE);
-        verify(mSocketProvider, timeout(CLEANUP_DELAY_MS + TIMEOUT_MS)).stopMonitoringSockets();
+        verify(mSocketProvider, timeout(CLEANUP_DELAY_MS + TIMEOUT_MS)).requestStopWhenInactive();
     }
 
     @Test
@@ -1016,7 +1016,7 @@
         // Verify the listener has been unregistered.
         verify(mDiscoveryManager, timeout(TIMEOUT_MS))
                 .unregisterListener(eq(constructedServiceType), any());
-        verify(mSocketProvider, timeout(CLEANUP_DELAY_MS + TIMEOUT_MS)).stopMonitoringSockets();
+        verify(mSocketProvider, timeout(CLEANUP_DELAY_MS + TIMEOUT_MS)).requestStopWhenInactive();
     }
 
     @Test
@@ -1090,7 +1090,7 @@
         verify(mAdvertiser).removeService(idCaptor.getValue());
         verify(regListener, timeout(TIMEOUT_MS)).onServiceUnregistered(
                 argThat(info -> matches(info, regInfo)));
-        verify(mSocketProvider, timeout(TIMEOUT_MS)).stopMonitoringSockets();
+        verify(mSocketProvider, timeout(TIMEOUT_MS)).requestStopWhenInactive();
     }
 
     @Test
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
index 635b296..b9cb255 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
@@ -27,6 +27,7 @@
 import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 
@@ -109,12 +110,15 @@
         final HandlerThread thread = new HandlerThread("MdnsSocketProviderTest");
         thread.start();
         mHandler = new Handler(thread.getLooper());
+        mSocketProvider = new MdnsSocketProvider(mContext, thread.getLooper(), mDeps);
+    }
 
+    private void startMonitoringSockets() {
         final ArgumentCaptor<NetworkCallback> nwCallbackCaptor =
                 ArgumentCaptor.forClass(NetworkCallback.class);
         final ArgumentCaptor<TetheringEventCallback> teCallbackCaptor =
                 ArgumentCaptor.forClass(TetheringEventCallback.class);
-        mSocketProvider = new MdnsSocketProvider(mContext, thread.getLooper(), mDeps);
+
         mHandler.post(mSocketProvider::startMonitoringSockets);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mCm).registerNetworkCallback(any(), nwCallbackCaptor.capture(), any());
@@ -205,6 +209,8 @@
 
     @Test
     public void testSocketRequestAndUnrequestSocket() {
+        startMonitoringSockets();
+
         final TestSocketCallback testCallback1 = new TestSocketCallback();
         mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback1));
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
@@ -275,6 +281,8 @@
 
     @Test
     public void testAddressesChanged() throws Exception {
+        startMonitoringSockets();
+
         final TestSocketCallback testCallback = new TestSocketCallback();
         mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback));
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
@@ -297,4 +305,53 @@
         testCallback.expectedAddressesChangedForNetwork(
                 TEST_NETWORK, List.of(LINKADDRV4, LINKADDRV6));
     }
+
+    @Test
+    public void testStartAndStopMonitoringSockets() {
+        // Stop monitoring sockets before start. Should not unregister any network callback.
+        mHandler.post(mSocketProvider::requestStopWhenInactive);
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        verify(mCm, never()).unregisterNetworkCallback(any(NetworkCallback.class));
+        verify(mTm, never()).unregisterTetheringEventCallback(any(TetheringEventCallback.class));
+
+        // Start sockets monitoring.
+        startMonitoringSockets();
+        // Request a socket then unrequest it. Expect no network callback unregistration.
+        final TestSocketCallback testCallback = new TestSocketCallback();
+        mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        testCallback.expectedNoCallback();
+        mHandler.post(() -> mSocketProvider.unrequestSocket(testCallback));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        verify(mCm, never()).unregisterNetworkCallback(any(NetworkCallback.class));
+        verify(mTm, never()).unregisterTetheringEventCallback(any(TetheringEventCallback.class));
+        // Request stop and it should unregister network callback immediately because there is no
+        // socket request.
+        mHandler.post(mSocketProvider::requestStopWhenInactive);
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        verify(mCm, times(1)).unregisterNetworkCallback(any(NetworkCallback.class));
+        verify(mTm, times(1)).unregisterTetheringEventCallback(any(TetheringEventCallback.class));
+
+        // Start sockets monitoring and request a socket again.
+        mHandler.post(mSocketProvider::startMonitoringSockets);
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        verify(mCm, times(2)).registerNetworkCallback(any(), any(NetworkCallback.class), any());
+        verify(mTm, times(2)).registerTetheringEventCallback(
+                any(), any(TetheringEventCallback.class));
+        final TestSocketCallback testCallback2 = new TestSocketCallback();
+        mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback2));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        testCallback2.expectedNoCallback();
+        // Request stop while a socket is still requested: it must be deferred until all sockets
+        // are unrequested.
+        mHandler.post(mSocketProvider::requestStopWhenInactive);
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        verify(mCm, times(1)).unregisterNetworkCallback(any(NetworkCallback.class));
+        verify(mTm, times(1)).unregisterTetheringEventCallback(any(TetheringEventCallback.class));
+        // Unrequest the socket then network callbacks should be unregistered.
+        mHandler.post(() -> mSocketProvider.unrequestSocket(testCallback2));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        verify(mCm, times(2)).unregisterNetworkCallback(any(NetworkCallback.class));
+        verify(mTm, times(2)).unregisterTetheringEventCallback(any(TetheringEventCallback.class));
+    }
 }