Add removeServices method in MdnsServiceCache

This method can remove all services that match the given
CacheKey.

Bug: 355421878
Test: atest FrameworksNetTests NsdManagerTest
Change-Id: I83e1a0cbdd4355d92510b5f9a550f96ef0154682
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
index 7eea93a..a8a4ef1 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
@@ -233,6 +233,21 @@
     }
 
     /**
+     * Remove services which matches the given type and socket.
+     *
+     * @param cacheKey the target CacheKey.
+     */
+    public void removeServices(@NonNull CacheKey cacheKey) {
+        ensureRunningOnHandlerThread(mHandler);
+        // Remove all services
+        if (mCachedServices.remove(cacheKey) == null) {
+            return;
+        }
+        // Update the next expiration check time if services are removed.
+        mNextExpirationTime = getNextExpirationTime(mClock.elapsedRealtime());
+    }
+
+    /**
      * Register a callback to listen to service expiration.
      *
      * <p> Registering the same callback instance twice is a no-op, since MdnsServiceTypeClient
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
index b040ab6..0a8f108 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
@@ -154,6 +154,11 @@
         serviceCache.registerServiceExpiredCallback(cacheKey, callback)
     }
 
+    private fun removeServices(
+            serviceCache: MdnsServiceCache,
+            cacheKey: CacheKey
+    ): Unit = runningOnHandlerAndReturn { serviceCache.removeServices(cacheKey) }
+
     @Test
     fun testAddAndRemoveService() {
         val serviceCache = MdnsServiceCache(thread.looper, makeFlags(), clock)
@@ -291,6 +296,37 @@
         assertEquals(response4, responses[3])
     }
 
+    @Test
+    fun testRemoveServices() {
+        val serviceCache = MdnsServiceCache(thread.looper, makeFlags(), clock)
+        addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
+        addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
+        addOrUpdateService(serviceCache, cacheKey2, createResponse(SERVICE_NAME_1, SERVICE_TYPE_2))
+        val responses1 = getServices(serviceCache, cacheKey1)
+        assertEquals(2, responses1.size)
+        assertTrue(responses1.stream().anyMatch { response ->
+            response.serviceInstanceName == SERVICE_NAME_1
+        })
+        assertTrue(responses1.any { response ->
+            response.serviceInstanceName == SERVICE_NAME_2
+        })
+        val responses2 = getServices(serviceCache, cacheKey2)
+        assertEquals(1, responses2.size)
+        assertTrue(responses2.stream().anyMatch { response ->
+            response.serviceInstanceName == SERVICE_NAME_1
+        })
+
+        removeServices(serviceCache, cacheKey1)
+        val responses3 = getServices(serviceCache, cacheKey1)
+        assertEquals(0, responses3.size)
+        val responses4 = getServices(serviceCache, cacheKey2)
+        assertEquals(1, responses4.size)
+
+        removeServices(serviceCache, cacheKey2)
+        val responses5 = getServices(serviceCache, cacheKey2)
+        assertEquals(0, responses5.size)
+    }
+
     private fun createResponse(
             serviceInstanceName: String,
             serviceType: String,