Use SocketKey in MdnsServiceCache

The MdnsServiceTypeClient is now created using a SocketKey, so
the MdnsServiceCache should also use the SocketKey to deal with
the caching services.

Bug: 265787401
Test: atest FrameworksNetTests
Change-Id: I6165ffd420a39e750c06778b4851142a3ba3cf44
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
index cd0be67..dc99e49 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
@@ -22,7 +22,6 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
-import android.net.Network;
 import android.os.Handler;
 import android.os.Looper;
 import android.util.ArrayMap;
@@ -45,15 +44,15 @@
 public class MdnsServiceCache {
     private static class CacheKey {
         @NonNull final String mLowercaseServiceType;
-        @Nullable final Network mNetwork;
+        @NonNull final SocketKey mSocketKey;
 
-        CacheKey(@NonNull String serviceType, @Nullable Network network) {
+        CacheKey(@NonNull String serviceType, @NonNull SocketKey socketKey) {
             mLowercaseServiceType = toDnsLowerCase(serviceType);
-            mNetwork = network;
+            mSocketKey = socketKey;
         }
 
         @Override public int hashCode() {
-            return Objects.hash(mLowercaseServiceType, mNetwork);
+            return Objects.hash(mLowercaseServiceType, mSocketKey);
         }
 
         @Override public boolean equals(Object other) {
@@ -64,11 +63,11 @@
                 return false;
             }
             return Objects.equals(mLowercaseServiceType, ((CacheKey) other).mLowercaseServiceType)
-                    && Objects.equals(mNetwork, ((CacheKey) other).mNetwork);
+                    && Objects.equals(mSocketKey, ((CacheKey) other).mSocketKey);
         }
     }
     /**
-     * A map of cached services. Key is composed of service name, type and network. Value is the
+     * A map of cached services. Key is composed of service name, type and socket. Value is the
      * service which use the service type to discover from each socket.
      */
     @NonNull
@@ -81,17 +80,17 @@
     }
 
     /**
-     * Get the cache services which are queried from given service type and network.
+     * Get the cache services which are queried from given service type and socket.
      *
      * @param serviceType the target service type.
-     * @param network the target network
+     * @param socketKey the target socket
      * @return the set of services which matches the given service type.
      */
     @NonNull
     public List<MdnsResponse> getCachedServices(@NonNull String serviceType,
-            @Nullable Network network) {
+            @NonNull SocketKey socketKey) {
         ensureRunningOnHandlerThread(mHandler);
-        final CacheKey key = new CacheKey(serviceType, network);
+        final CacheKey key = new CacheKey(serviceType, socketKey);
         return mCachedServices.containsKey(key)
                 ? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(key)))
                 : Collections.emptyList();
@@ -112,15 +111,15 @@
      *
      * @param serviceName the target service name.
      * @param serviceType the target service type.
-     * @param network the target network
+     * @param socketKey the target socket
      * @return the service which matches given conditions.
      */
     @Nullable
     public MdnsResponse getCachedService(@NonNull String serviceName,
-            @NonNull String serviceType, @Nullable Network network) {
+            @NonNull String serviceType, @NonNull SocketKey socketKey) {
         ensureRunningOnHandlerThread(mHandler);
         final List<MdnsResponse> responses =
-                mCachedServices.get(new CacheKey(serviceType, network));
+                mCachedServices.get(new CacheKey(serviceType, socketKey));
         if (responses == null) {
             return null;
         }
@@ -132,14 +131,14 @@
      * Add or update a service.
      *
      * @param serviceType the service type.
-     * @param network the target network
+     * @param socketKey the target socket
      * @param response the response of the discovered service.
      */
-    public void addOrUpdateService(@NonNull String serviceType, @Nullable Network network,
+    public void addOrUpdateService(@NonNull String serviceType, @NonNull SocketKey socketKey,
             @NonNull MdnsResponse response) {
         ensureRunningOnHandlerThread(mHandler);
         final List<MdnsResponse> responses = mCachedServices.computeIfAbsent(
-                new CacheKey(serviceType, network), key -> new ArrayList<>());
+                new CacheKey(serviceType, socketKey), key -> new ArrayList<>());
         // Remove existing service if present.
         final MdnsResponse existing =
                 findMatchedResponse(responses, response.getServiceInstanceName());
@@ -148,18 +147,18 @@
     }
 
     /**
-     * Remove a service which matches the given service name, type and network.
+     * Remove a service which matches the given service name, type and socket.
      *
      * @param serviceName the target service name.
      * @param serviceType the target service type.
-     * @param network the target network.
+     * @param socketKey the target socket.
      */
     @Nullable
     public MdnsResponse removeService(@NonNull String serviceName, @NonNull String serviceType,
-            @Nullable Network network) {
+            @NonNull SocketKey socketKey) {
         ensureRunningOnHandlerThread(mHandler);
         final List<MdnsResponse> responses =
-                mCachedServices.get(new CacheKey(serviceType, network));
+                mCachedServices.get(new CacheKey(serviceType, socketKey));
         if (responses == null) {
             return null;
         }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
index f091eea..b43bcf7 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
@@ -16,7 +16,6 @@
 
 package com.android.server.connectivity.mdns
 
-import android.net.Network
 import android.os.Build
 import android.os.Handler
 import android.os.HandlerThread
@@ -32,7 +31,6 @@
 import org.junit.Before
 import org.junit.Test
 import org.junit.runner.RunWith
-import org.mockito.Mockito.mock
 
 private const val SERVICE_NAME_1 = "service-instance-1"
 private const val SERVICE_NAME_2 = "service-instance-2"
@@ -44,7 +42,7 @@
 @RunWith(DevSdkIgnoreRunner::class)
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
 class MdnsServiceCacheTest {
-    private val network = mock(Network::class.java)
+    private val socketKey = SocketKey(null /* network */, INTERFACE_INDEX)
     private val thread = HandlerThread(MdnsServiceCacheTest::class.simpleName)
     private val handler by lazy {
         Handler(thread.looper)
@@ -71,39 +69,47 @@
         return future.get(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS)
     }
 
-    private fun addOrUpdateService(serviceType: String, network: Network, service: MdnsResponse):
-            Unit = runningOnHandlerAndReturn {
-        serviceCache.addOrUpdateService(serviceType, network, service) }
+    private fun addOrUpdateService(
+            serviceType: String,
+            socketKey: SocketKey,
+            service: MdnsResponse
+    ): Unit = runningOnHandlerAndReturn {
+        serviceCache.addOrUpdateService(serviceType, socketKey, service)
+    }
 
-    private fun removeService(serviceName: String, serviceType: String, network: Network):
+    private fun removeService(serviceName: String, serviceType: String, socketKey: SocketKey):
             Unit = runningOnHandlerAndReturn {
-        serviceCache.removeService(serviceName, serviceType, network) }
+        serviceCache.removeService(serviceName, serviceType, socketKey) }
 
-    private fun getService(serviceName: String, serviceType: String, network: Network):
+    private fun getService(serviceName: String, serviceType: String, socketKey: SocketKey):
             MdnsResponse? = runningOnHandlerAndReturn {
-        serviceCache.getCachedService(serviceName, serviceType, network) }
+        serviceCache.getCachedService(serviceName, serviceType, socketKey) }
 
-    private fun getServices(serviceType: String, network: Network): List<MdnsResponse> =
-        runningOnHandlerAndReturn { serviceCache.getCachedServices(serviceType, network) }
+    private fun getServices(serviceType: String, socketKey: SocketKey): List<MdnsResponse> =
+        runningOnHandlerAndReturn { serviceCache.getCachedServices(serviceType, socketKey) }
 
     @Test
     fun testAddAndRemoveService() {
-        addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
-        var response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, network)
+        addOrUpdateService(
+                SERVICE_TYPE_1, socketKey, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
+        var response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, socketKey)
         assertNotNull(response)
         assertEquals(SERVICE_NAME_1, response.serviceInstanceName)
-        removeService(SERVICE_NAME_1, SERVICE_TYPE_1, network)
-        response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, network)
+        removeService(SERVICE_NAME_1, SERVICE_TYPE_1, socketKey)
+        response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, socketKey)
         assertNull(response)
     }
 
     @Test
     fun testGetCachedServices_multipleServiceTypes() {
-        addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
-        addOrUpdateService(SERVICE_TYPE_1, network, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
-        addOrUpdateService(SERVICE_TYPE_2, network, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
+        addOrUpdateService(
+                SERVICE_TYPE_1, socketKey, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
+        addOrUpdateService(
+                SERVICE_TYPE_1, socketKey, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
+        addOrUpdateService(
+                SERVICE_TYPE_2, socketKey, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
 
-        val responses1 = getServices(SERVICE_TYPE_1, network)
+        val responses1 = getServices(SERVICE_TYPE_1, socketKey)
         assertEquals(2, responses1.size)
         assertTrue(responses1.stream().anyMatch { response ->
             response.serviceInstanceName == SERVICE_NAME_1
@@ -111,19 +117,19 @@
         assertTrue(responses1.any { response ->
             response.serviceInstanceName == SERVICE_NAME_2
         })
-        val responses2 = getServices(SERVICE_TYPE_2, network)
+        val responses2 = getServices(SERVICE_TYPE_2, socketKey)
         assertEquals(1, responses2.size)
         assertTrue(responses2.any { response ->
             response.serviceInstanceName == SERVICE_NAME_2
         })
 
-        removeService(SERVICE_NAME_2, SERVICE_TYPE_1, network)
-        val responses3 = getServices(SERVICE_TYPE_1, network)
+        removeService(SERVICE_NAME_2, SERVICE_TYPE_1, socketKey)
+        val responses3 = getServices(SERVICE_TYPE_1, socketKey)
         assertEquals(1, responses3.size)
         assertTrue(responses3.any { response ->
             response.serviceInstanceName == SERVICE_NAME_1
         })
-        val responses4 = getServices(SERVICE_TYPE_2, network)
+        val responses4 = getServices(SERVICE_TYPE_2, socketKey)
         assertEquals(1, responses4.size)
         assertTrue(responses4.any { response ->
             response.serviceInstanceName == SERVICE_NAME_2
@@ -132,5 +138,5 @@
 
     private fun createResponse(serviceInstanceName: String, serviceType: String) = MdnsResponse(
         0 /* now */, "$serviceInstanceName.$serviceType".split(".").toTypedArray(),
-            INTERFACE_INDEX, network)
+            socketKey.interfaceIndex, socketKey.network)
 }