Merge "[mdns] fix race conditions in MdnsAdvertiser" into main
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
index c162bcc..98c2d86 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
@@ -241,13 +241,10 @@
         }
 
         @Override
-        public void onDestroyed(@NonNull MdnsInterfaceSocket socket) {
-            for (int i = mAdvertiserRequests.size() - 1; i >= 0; i--) {
-                if (mAdvertiserRequests.valueAt(i).onAdvertiserDestroyed(socket)) {
-                    mAdvertiserRequests.removeAt(i);
-                }
-            }
-            mAllAdvertisers.remove(socket);
+        public void onAllServicesRemoved(@NonNull MdnsInterfaceSocket socket) {
+            if (DBG) { mSharedLog.i("onAllServicesRemoved: " + socket); }
+            // All services have been removed: try destroying the advertiser for this socket.
+            destroyAdvertiser(socket, /* interfaceDestroyed= */ false);
         }
     };
 
@@ -318,6 +315,30 @@
     }
 
     /**
+     * Destroys the advertiser for the interface indicated by {@code socket} and notifies every
+     * {@link InterfaceAdvertiserRequest}, removing those that report they should be deleted.
+     *
+     * @param socket the socket of the interface whose advertiser should be destroyed
+     * @param interfaceDestroyed {@code true} if called because the interface has been destroyed
+     */
+    private void destroyAdvertiser(MdnsInterfaceSocket socket, boolean interfaceDestroyed) {
+        final MdnsInterfaceAdvertiser advertiser = mAllAdvertisers.remove(socket);
+        if (advertiser != null) {
+            advertiser.destroyNow();
+            if (DBG) { mSharedLog.i("MdnsInterfaceAdvertiser is destroyed: " + advertiser); }
+        }
+
+        // Iterate backwards so removeAt(i) does not shift entries that are still to be visited.
+        for (int i = mAdvertiserRequests.size() - 1; i >= 0; i--) {
+            final InterfaceAdvertiserRequest advertiserRequest = mAdvertiserRequests.valueAt(i);
+            if (advertiserRequest.onAdvertiserDestroyed(socket, interfaceDestroyed)) {
+                if (DBG) { mSharedLog.i("AdvertiserRequest is removed: " + advertiserRequest); }
+                mAdvertiserRequests.removeAt(i);
+            }
+        }
+    }
+
+    /**
      * A request for a {@link MdnsInterfaceAdvertiser}.
      *
      * This class tracks services to be advertised on all sockets provided via a registered
@@ -336,13 +357,22 @@
         }
 
         /**
-         * Called when an advertiser was destroyed, after all services were unregistered and it sent
-         * exit announcements, or the interface is gone.
+         * Called when the interface advertiser associated with {@code socket} has been destroyed.
          *
-         * @return true if this {@link InterfaceAdvertiserRequest} should now be deleted.
+         * {@code interfaceDestroyed} should be set to {@code true} if this method is called because
+         * the associated interface has been destroyed.
+         *
+         * @return true if the {@link InterfaceAdvertiserRequest} should now be deleted
          */
-        boolean onAdvertiserDestroyed(@NonNull MdnsInterfaceSocket socket) {
+        boolean onAdvertiserDestroyed(
+                @NonNull MdnsInterfaceSocket socket, boolean interfaceDestroyed) {
             final MdnsInterfaceAdvertiser removedAdvertiser = mAdvertisers.remove(socket);
+            if (removedAdvertiser != null
+                    && !interfaceDestroyed && mPendingRegistrations.size() > 0) {
+                mSharedLog.wtf(
+                        "unexpected onAdvertiserDestroyed() when there are pending registrations");
+            }
+
             if (mMdnsFeatureFlags.mIsMdnsOffloadFeatureEnabled && removedAdvertiser != null) {
                 final String interfaceName = removedAdvertiser.getSocketInterfaceName();
                 // If the interface is destroyed, stop all hardware offloading on that
@@ -528,7 +558,7 @@
         public void onInterfaceDestroyed(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket) {
             final MdnsInterfaceAdvertiser advertiser = mAdvertisers.get(socket);
-            if (advertiser != null) advertiser.destroyNow();
+            if (advertiser != null) destroyAdvertiser(socket, /* interfaceDestroyed= */ true);
         }
 
         @Override
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
index c2363c0..c1c7d5f 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
@@ -102,12 +102,15 @@
                 @NonNull MdnsInterfaceAdvertiser advertiser, int serviceId, int conflictType);
 
          /**
-         * Called by the advertiser when it destroyed itself.
+         * Called when all services on this interface advertiser have been removed, after any
+         * required exit announcements have been sent.
          *
-         * This can happen after a call to {@link #destroyNow()}, or after all services were
-         * unregistered and the advertiser finished sending exit announcements.
+         * <p>It's guaranteed that there are no service registrations left in the
+         * MdnsInterfaceAdvertiser when this callback is invoked.
+         *
+         * <p>This is typically listened to by {@link MdnsAdvertiser} to release resources.
          */
-        void onDestroyed(@NonNull MdnsInterfaceSocket socket);
+        void onAllServicesRemoved(@NonNull MdnsInterfaceSocket socket);
     }
 
     /**
@@ -149,10 +152,11 @@
         public void onFinished(@NonNull BaseAnnouncementInfo info) {
             if (info instanceof MdnsAnnouncer.ExitAnnouncementInfo) {
                 mRecordRepository.removeService(info.getServiceId());
-
-                if (mRecordRepository.getServicesCount() == 0) {
-                    destroyNow();
-                }
+                mCbHandler.post(() -> {
+                    // Count is read when the task runs, so services added meanwhile are seen.
+                    if (mRecordRepository.getServicesCount() != 0) return;
+                    mCb.onAllServicesRemoved(mSocket);
+                });
             }
         }
     }
@@ -234,8 +238,7 @@
      * Start the advertiser.
      *
      * The advertiser will stop itself when all services are removed and exit announcements sent,
-     * notifying via {@link Callback#onDestroyed}. This can also be triggered manually via
-     * {@link #destroyNow()}.
+     * notifying via {@link Callback#onAllServicesRemoved}.
      */
     public void start() {
         mSocket.addPacketHandler(this);
@@ -283,8 +286,8 @@
         mAnnouncer.stop(id);
         final MdnsAnnouncer.ExitAnnouncementInfo exitInfo = mRecordRepository.exitService(id);
         if (exitInfo != null) {
-            // This effectively schedules destroyNow(), as it is to be called when the exit
-            // announcement finishes if there is no service left.
+            // This effectively schedules onAllServicesRemoved(), as it is to be called when the
+            // exit announcement finishes if there is no service left.
             // A non-zero exit announcement delay follows legacy mdnsresponder behavior, and is
             // also useful to ensure that when a host receives the exit announcement, the service
             // has been unregistered on all interfaces; so an announcement sent from interface A
@@ -294,9 +297,11 @@
         } else {
             // No exit announcement necessary: remove the service immediately.
             mRecordRepository.removeService(id);
-            if (mRecordRepository.getServicesCount() == 0) {
-                destroyNow();
-            }
+            mCbHandler.post(() -> {
+                if (mRecordRepository.getServicesCount() == 0) {
+                    mCb.onAllServicesRemoved(mSocket);
+                }
+            });
         }
     }
 
@@ -330,7 +335,8 @@
     /**
      * Destroy the advertiser immediately, not sending any exit announcement.
      *
-     * <p>Useful when the underlying network went away. This will trigger an onDestroyed callback.
+     * <p>This is typically called when all services on the interface are removed or when the
+     * underlying network went away.
      */
     public void destroyNow() {
         for (int serviceId : mRecordRepository.clearServices()) {
@@ -339,7 +345,6 @@
         }
         mReplySender.cancelAll();
         mSocket.removePacketHandler(this);
-        mCbHandler.post(() -> mCb.onDestroyed(mSocket));
     }
 
     /**
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index dbececf..8dbcf2f 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -2108,6 +2108,46 @@
     }
 
     @Test
+    fun testRegisterService_registerImmediatelyAfterUnregister_serviceFound() {
+        val info1 = makeTestServiceInfo(network = testNetwork1.network).apply {
+            serviceName = "service11111"
+            port = 11111
+        }
+        val info2 = makeTestServiceInfo(network = testNetwork1.network).apply {
+            serviceName = "service22222"
+            port = 22222
+        }
+        val registrationRecord1 = NsdRegistrationRecord()
+        val discoveryRecord1 = NsdDiscoveryRecord()
+        val registrationRecord2 = NsdRegistrationRecord()
+        val discoveryRecord2 = NsdDiscoveryRecord()
+        tryTest {
+            registerService(registrationRecord1, info1)
+            nsdManager.discoverServices(serviceType,
+                    NsdManager.PROTOCOL_DNS_SD, testNetwork1.network, { it.run() },
+                    discoveryRecord1)
+            discoveryRecord1.waitForServiceDiscovered(info1.serviceName,
+                    serviceType, testNetwork1.network)
+            nsdManager.stopServiceDiscovery(discoveryRecord1)
+            // Register a second service right away, without waiting for the unregistration.
+            nsdManager.unregisterService(registrationRecord1)
+            registerService(registrationRecord2, info2)
+            nsdManager.discoverServices(serviceType,
+                    NsdManager.PROTOCOL_DNS_SD, testNetwork1.network, { it.run() },
+                    discoveryRecord2)
+            val infoDiscovered = discoveryRecord2.waitForServiceDiscovered(info2.serviceName,
+                    serviceType, testNetwork1.network)
+            val infoResolved = resolveService(infoDiscovered)
+            assertEquals(22222, infoResolved.port)
+        } cleanupStep {
+            nsdManager.stopServiceDiscovery(discoveryRecord2)
+            discoveryRecord2.expectCallback<DiscoveryStopped>()
+        } cleanup {
+            nsdManager.unregisterService(registrationRecord2)
+        }
+    }
+
+    @Test
     fun testServiceTypeClientRemovedAfterSocketDestroyed() {
         val si = makeTestServiceInfo(testNetwork1.network)
         // Register service on testNetwork1
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
index b8ebf0f..df48f6c 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
@@ -286,7 +286,6 @@
 
         postSync { socketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
         verify(mockInterfaceAdvertiser1).destroyNow()
-        postSync { intAdvCbCaptor.value.onDestroyed(mockSocket1) }
         verify(cb).onOffloadStop(eq(TEST_INTERFACE1), eq(OFFLOAD_SERVICEINFO_NO_SUBTYPE2))
     }
 
@@ -364,10 +363,10 @@
         verify(cb).onOffloadStop(eq(TEST_INTERFACE1), eq(OFFLOAD_SERVICEINFO))
         verify(cb).onOffloadStop(eq(TEST_INTERFACE2), eq(OFFLOAD_SERVICEINFO))
 
-        // Interface advertisers call onDestroyed after sending exit announcements
-        postSync { intAdvCbCaptor1.value.onDestroyed(mockSocket1) }
+        // Interface advertisers call onAllServicesRemoved after sending exit announcements
+        postSync { intAdvCbCaptor1.value.onAllServicesRemoved(mockSocket1) }
         verify(socketProvider, never()).unrequestSocket(any())
-        postSync { intAdvCbCaptor2.value.onDestroyed(mockSocket2) }
+        postSync { intAdvCbCaptor2.value.onAllServicesRemoved(mockSocket2) }
         verify(socketProvider).unrequestSocket(socketCb)
     }
 
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index 28608bb..69fec85 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -179,7 +179,7 @@
         // Exit announcements finish: the advertiser has no left service and destroys itself
         announceCb.onFinished(testExitInfo)
         thread.waitForIdle(TIMEOUT_MS)
-        verify(cb).onDestroyed(socket)
+        verify(cb).onAllServicesRemoved(socket)
     }
 
     @Test