[mdns] fix race conditions in MdnsAdvertiser

There is a race condition in MdnsAdvertiser when removing the last
service:
1. When the last service is removed, it will call the destroyNow()
   to tear down the interface advertiser and schedule a onDestroyed()
   callback for MdnsAdvetiser to do further cleanup
2. If another serivce is registered after removing the last service but
   before onDestroyed() is invoked, then onAdvertiserDestroyed() will
   return false and the InterfaceAdvertiserRequest is kept in the map in
   MdnsAdvertiser. This is because the newly registered service is in
   the mPendingRestionstrations.
3. The problem is that InterfaceAdvertiserRequest now have no
   InterfaceAdvertiser (because the only advertiser has already
   destroyed) and cause the newly registered service can't be
   advertised.

It needs to schedule onDestroyed() on the handler thread because we
can't do modify-during-iteration, but this introduces the race condition
here.

The fix in this commit changes to not destroy the InterfaceAdvertiser
immediately when the last service is removed but destroy when there are
no pending registrations in the InterfaceAdvertiserRequest.

Test: atest --iterations 100 CtsNetTestCases:android.net.cts.NsdManagerTest#testRegisterService_registerImmediatelyAfterUnregister_serviceFound
Change-Id: I6b6f01e6ddf103d16a8b0296af000dc8262ea65b
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
index c162bcc..98c2d86 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
@@ -241,13 +241,10 @@
         }
 
         @Override
-        public void onDestroyed(@NonNull MdnsInterfaceSocket socket) {
-            for (int i = mAdvertiserRequests.size() - 1; i >= 0; i--) {
-                if (mAdvertiserRequests.valueAt(i).onAdvertiserDestroyed(socket)) {
-                    mAdvertiserRequests.removeAt(i);
-                }
-            }
-            mAllAdvertisers.remove(socket);
+        public void onAllServicesRemoved(@NonNull MdnsInterfaceSocket socket) {
+            if (DBG) { mSharedLog.i("onAllServicesRemoved: " + socket); }
+            // Try destroying the advertiser if all services has been removed
+            destroyAdvertiser(socket, false /* interfaceDestroyed */);
         }
     };
 
@@ -318,6 +315,30 @@
     }
 
     /**
+     * Destroys the advertiser for the interface indicated by {@code socket}.
+     *
+     * {@code interfaceDestroyed} should be set to {@code true} if this method is called because
+     * the associated interface has been destroyed.
+     */
+    private void destroyAdvertiser(MdnsInterfaceSocket socket, boolean interfaceDestroyed) {
+        InterfaceAdvertiserRequest advertiserRequest;
+
+        MdnsInterfaceAdvertiser advertiser = mAllAdvertisers.remove(socket);
+        if (advertiser != null) {
+            advertiser.destroyNow();
+            if (DBG) { mSharedLog.i("MdnsInterfaceAdvertiser is destroyed: " + advertiser); }
+        }
+
+        for (int i = mAdvertiserRequests.size() - 1; i >= 0; i--) {
+            advertiserRequest = mAdvertiserRequests.valueAt(i);
+            if (advertiserRequest.onAdvertiserDestroyed(socket, interfaceDestroyed)) {
+                if (DBG) { mSharedLog.i("AdvertiserRequest is removed: " + advertiserRequest); }
+                mAdvertiserRequests.removeAt(i);
+            }
+        }
+    }
+
+    /**
      * A request for a {@link MdnsInterfaceAdvertiser}.
      *
      * This class tracks services to be advertised on all sockets provided via a registered
@@ -336,13 +357,22 @@
         }
 
         /**
-         * Called when an advertiser was destroyed, after all services were unregistered and it sent
-         * exit announcements, or the interface is gone.
+         * Called when the interface advertiser associated with {@code socket} has been destroyed.
          *
-         * @return true if this {@link InterfaceAdvertiserRequest} should now be deleted.
+         * {@code interfaceDestroyed} should be set to {@code true} if this method is called because
+         * the associated interface has been destroyed.
+         *
+         * @return true if the {@link InterfaceAdvertiserRequest} should now be deleted
          */
-        boolean onAdvertiserDestroyed(@NonNull MdnsInterfaceSocket socket) {
+        boolean onAdvertiserDestroyed(
+                @NonNull MdnsInterfaceSocket socket, boolean interfaceDestroyed) {
             final MdnsInterfaceAdvertiser removedAdvertiser = mAdvertisers.remove(socket);
+            if (removedAdvertiser != null
+                    && !interfaceDestroyed && mPendingRegistrations.size() > 0) {
+                mSharedLog.wtf(
+                        "unexpected onAdvertiserDestroyed() when there are pending registrations");
+            }
+
             if (mMdnsFeatureFlags.mIsMdnsOffloadFeatureEnabled && removedAdvertiser != null) {
                 final String interfaceName = removedAdvertiser.getSocketInterfaceName();
                 // If the interface is destroyed, stop all hardware offloading on that
@@ -528,7 +558,7 @@
         public void onInterfaceDestroyed(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket) {
             final MdnsInterfaceAdvertiser advertiser = mAdvertisers.get(socket);
-            if (advertiser != null) advertiser.destroyNow();
+            if (advertiser != null) destroyAdvertiser(socket, true /* interfaceDestroyed */);
         }
 
         @Override
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
index c2363c0..c1c7d5f 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
@@ -102,12 +102,15 @@
                 @NonNull MdnsInterfaceAdvertiser advertiser, int serviceId, int conflictType);
 
         /**
-         * Called by the advertiser when it destroyed itself.
+         * Called when all services on this interface advertiser has already been removed and exit
+         * announcements have been sent.
          *
-         * This can happen after a call to {@link #destroyNow()}, or after all services were
-         * unregistered and the advertiser finished sending exit announcements.
+         * <p>It's guaranteed that there are no service registrations in the
+         * MdnsInterfaceAdvertiser when this callback is invoked.
+         *
+         * <p>This is typically listened by the {@link MdnsAdvertiser} to release the resources
          */
-        void onDestroyed(@NonNull MdnsInterfaceSocket socket);
+        void onAllServicesRemoved(@NonNull MdnsInterfaceSocket socket);
     }
 
     /**
@@ -149,10 +152,11 @@
         public void onFinished(@NonNull BaseAnnouncementInfo info) {
             if (info instanceof MdnsAnnouncer.ExitAnnouncementInfo) {
                 mRecordRepository.removeService(info.getServiceId());
-
-                if (mRecordRepository.getServicesCount() == 0) {
-                    destroyNow();
-                }
+                mCbHandler.post(() -> {
+                    if (mRecordRepository.getServicesCount() == 0) {
+                        mCb.onAllServicesRemoved(mSocket);
+                    }
+                });
             }
         }
     }
@@ -234,8 +238,7 @@
      * Start the advertiser.
      *
      * The advertiser will stop itself when all services are removed and exit announcements sent,
-     * notifying via {@link Callback#onDestroyed}. This can also be triggered manually via
-     * {@link #destroyNow()}.
+     * notifying via {@link Callback#onAllServicesRemoved}.
      */
     public void start() {
         mSocket.addPacketHandler(this);
@@ -283,8 +286,8 @@
         mAnnouncer.stop(id);
         final MdnsAnnouncer.ExitAnnouncementInfo exitInfo = mRecordRepository.exitService(id);
         if (exitInfo != null) {
-            // This effectively schedules destroyNow(), as it is to be called when the exit
-            // announcement finishes if there is no service left.
+            // This effectively schedules onAllServicesRemoved(), as it is to be called when the
+            // exit announcement finishes if there is no service left.
             // A non-zero exit announcement delay follows legacy mdnsresponder behavior, and is
             // also useful to ensure that when a host receives the exit announcement, the service
             // has been unregistered on all interfaces; so an announcement sent from interface A
@@ -294,9 +297,11 @@
         } else {
             // No exit announcement necessary: remove the service immediately.
             mRecordRepository.removeService(id);
-            if (mRecordRepository.getServicesCount() == 0) {
-                destroyNow();
-            }
+            mCbHandler.post(() -> {
+                if (mRecordRepository.getServicesCount() == 0) {
+                    mCb.onAllServicesRemoved(mSocket);
+                }
+            });
         }
     }
 
@@ -330,7 +335,8 @@
     /**
      * Destroy the advertiser immediately, not sending any exit announcement.
      *
-     * <p>Useful when the underlying network went away. This will trigger an onDestroyed callback.
+     * <p>This is typically called when all services on the interface are removed or when the
+     * underlying network went away.
      */
     public void destroyNow() {
         for (int serviceId : mRecordRepository.clearServices()) {
@@ -339,7 +345,6 @@
         }
         mReplySender.cancelAll();
         mSocket.removePacketHandler(this);
-        mCbHandler.post(() -> mCb.onDestroyed(mSocket));
     }
 
     /**
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index dbececf..8dbcf2f 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -2108,6 +2108,46 @@
     }
 
     @Test
+    fun testRegisterService_registerImmediatelyAfterUnregister_serviceFound() {
+        val info1 = makeTestServiceInfo(network = testNetwork1.network).apply {
+            serviceName = "service11111"
+            port = 11111
+        }
+        val info2 = makeTestServiceInfo(network = testNetwork1.network).apply {
+            serviceName = "service22222"
+            port = 22222
+        }
+        val registrationRecord1 = NsdRegistrationRecord()
+        val discoveryRecord1 = NsdDiscoveryRecord()
+        val registrationRecord2 = NsdRegistrationRecord()
+        val discoveryRecord2 = NsdDiscoveryRecord()
+        tryTest {
+            registerService(registrationRecord1, info1)
+            nsdManager.discoverServices(serviceType,
+                    NsdManager.PROTOCOL_DNS_SD, testNetwork1.network, { it.run() },
+                    discoveryRecord1)
+            discoveryRecord1.waitForServiceDiscovered(info1.serviceName,
+                    serviceType, testNetwork1.network)
+            nsdManager.stopServiceDiscovery(discoveryRecord1)
+
+            nsdManager.unregisterService(registrationRecord1)
+            registerService(registrationRecord2, info2)
+            nsdManager.discoverServices(serviceType,
+                    NsdManager.PROTOCOL_DNS_SD, testNetwork1.network, { it.run() },
+                    discoveryRecord2)
+            val infoDiscovered = discoveryRecord2.waitForServiceDiscovered(info2.serviceName,
+                    serviceType, testNetwork1.network)
+            val infoResolved = resolveService(infoDiscovered)
+            assertEquals(22222, infoResolved.port)
+        } cleanupStep {
+            nsdManager.stopServiceDiscovery(discoveryRecord2)
+            discoveryRecord2.expectCallback<DiscoveryStopped>()
+        } cleanup {
+            nsdManager.unregisterService(registrationRecord2)
+        }
+    }
+
+    @Test
     fun testServiceTypeClientRemovedAfterSocketDestroyed() {
         val si = makeTestServiceInfo(testNetwork1.network)
         // Register service on testNetwork1
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
index b8ebf0f..df48f6c 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
@@ -286,7 +286,6 @@
 
         postSync { socketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
         verify(mockInterfaceAdvertiser1).destroyNow()
-        postSync { intAdvCbCaptor.value.onDestroyed(mockSocket1) }
         verify(cb).onOffloadStop(eq(TEST_INTERFACE1), eq(OFFLOAD_SERVICEINFO_NO_SUBTYPE2))
     }
 
@@ -364,10 +363,10 @@
         verify(cb).onOffloadStop(eq(TEST_INTERFACE1), eq(OFFLOAD_SERVICEINFO))
         verify(cb).onOffloadStop(eq(TEST_INTERFACE2), eq(OFFLOAD_SERVICEINFO))
 
-        // Interface advertisers call onDestroyed after sending exit announcements
-        postSync { intAdvCbCaptor1.value.onDestroyed(mockSocket1) }
+        // Interface advertisers call onAllServicesRemoved after sending exit announcements
+        postSync { intAdvCbCaptor1.value.onAllServicesRemoved(mockSocket1) }
         verify(socketProvider, never()).unrequestSocket(any())
-        postSync { intAdvCbCaptor2.value.onDestroyed(mockSocket2) }
+        postSync { intAdvCbCaptor2.value.onAllServicesRemoved(mockSocket2) }
         verify(socketProvider).unrequestSocket(socketCb)
     }
 
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index 28608bb..69fec85 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -179,7 +179,7 @@
         // Exit announcements finish: the advertiser has no left service and destroys itself
         announceCb.onFinished(testExitInfo)
         thread.waitForIdle(TIMEOUT_MS)
-        verify(cb).onDestroyed(socket)
+        verify(cb).onAllServicesRemoved(socket)
     }
 
     @Test