Deal with responses on MdnsServiceCache

To cache services and respond quickly, the MdnsServiceTypeClient
should add all services to the MdnsServiceCache and remove any
services from there as well.

Bug: 265787401
Test: atest FrameworksNetTests
Change-Id: If0a9e6b563a0992ac25b8cde7f3beb00700f1c11
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index dfaec75..f386dd4 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -51,6 +51,7 @@
     @NonNull private final PerSocketServiceTypeClients perSocketServiceTypeClients;
     @NonNull private final Handler handler;
     @Nullable private final HandlerThread handlerThread;
+    @NonNull private final MdnsServiceCache serviceCache;
 
     private static class PerSocketServiceTypeClients {
         private final ArrayMap<Pair<String, SocketKey>, MdnsServiceTypeClient> clients =
@@ -119,10 +120,12 @@
         if (socketClient.getLooper() != null) {
             this.handlerThread = null;
             this.handler = new Handler(socketClient.getLooper());
+            this.serviceCache = new MdnsServiceCache(socketClient.getLooper());
         } else {
             this.handlerThread = new HandlerThread(MdnsDiscoveryManager.class.getSimpleName());
             this.handlerThread.start();
             this.handler = new Handler(handlerThread.getLooper());
+            this.serviceCache = new MdnsServiceCache(handlerThread.getLooper());
         }
     }
 
@@ -289,6 +292,6 @@
         return new MdnsServiceTypeClient(
                 serviceType, socketClient,
                 executorProvider.newServiceTypeClientSchedulerExecutor(), socketKey,
-                sharedLog.forSubComponent(tag), handler.getLooper());
+                sharedLog.forSubComponent(tag), handler.getLooper(), serviceCache);
     }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
index dc99e49..ec6af9b 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
@@ -96,7 +96,14 @@
                 : Collections.emptyList();
     }
 
-    private MdnsResponse findMatchedResponse(@NonNull List<MdnsResponse> responses,
+    /**
+     * Find a matched response for given service name
+     *
+     * @param responses the responses to be searched.
+     * @param serviceName the target service name
+     * @return the response which matches the given service name or null if not found.
+     */
+    public static MdnsResponse findMatchedResponse(@NonNull List<MdnsResponse> responses,
             @NonNull String serviceName) {
         for (MdnsResponse response : responses) {
             if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) {
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index 7035c90..3b5193a 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -16,6 +16,7 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.server.connectivity.mdns.MdnsServiceCache.findMatchedResponse;
 import static com.android.server.connectivity.mdns.util.MdnsUtils.ensureRunningOnHandlerThread;
 
 import android.annotation.NonNull;
@@ -40,10 +41,8 @@
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
-import java.util.Map;
 import java.util.concurrent.ScheduledExecutorService;
 
 /**
@@ -67,10 +66,12 @@
     @NonNull private final SharedLog sharedLog;
     @NonNull private final Handler handler;
     @NonNull private final Dependencies dependencies;
+    /**
+     * The service caches for each socket. It should be accessed from looper thread only.
+     */
+    @NonNull private final MdnsServiceCache serviceCache;
     private final ArrayMap<MdnsServiceBrowserListener, MdnsSearchOptions> listeners =
             new ArrayMap<>();
-    // TODO: change instanceNameToResponse to TreeMap with case insensitive comparator.
-    private final Map<String, MdnsResponse> instanceNameToResponse = new HashMap<>();
     private final boolean removeServiceAfterTtlExpires =
             MdnsConfigs.removeServiceAfterTtlExpires();
     private final MdnsResponseDecoder.Clock clock;
@@ -190,9 +191,10 @@
             @NonNull ScheduledExecutorService executor,
             @NonNull SocketKey socketKey,
             @NonNull SharedLog sharedLog,
-            @NonNull Looper looper) {
+            @NonNull Looper looper,
+            @NonNull MdnsServiceCache serviceCache) {
         this(serviceType, socketClient, executor, new MdnsResponseDecoder.Clock(), socketKey,
-                sharedLog, looper, new Dependencies());
+                sharedLog, looper, new Dependencies(), serviceCache);
     }
 
     @VisibleForTesting
@@ -204,7 +206,8 @@
             @NonNull SocketKey socketKey,
             @NonNull SharedLog sharedLog,
             @NonNull Looper looper,
-            @NonNull Dependencies dependencies) {
+            @NonNull Dependencies dependencies,
+            @NonNull MdnsServiceCache serviceCache) {
         this.serviceType = serviceType;
         this.socketClient = socketClient;
         this.executor = executor;
@@ -215,6 +218,7 @@
         this.sharedLog = sharedLog;
         this.handler = new QueryTaskHandler(looper);
         this.dependencies = dependencies;
+        this.serviceCache = serviceCache;
     }
 
     private static MdnsServiceInfo buildMdnsServiceInfoFromResponse(
@@ -281,7 +285,8 @@
         this.searchOptions = searchOptions;
         boolean hadReply = false;
         if (listeners.put(listener, searchOptions) == null) {
-            for (MdnsResponse existingResponse : instanceNameToResponse.values()) {
+            for (MdnsResponse existingResponse :
+                    serviceCache.getCachedServices(serviceType, socketKey)) {
                 if (!responseMatchesOptions(existingResponse, searchOptions)) continue;
                 final MdnsServiceInfo info =
                         buildMdnsServiceInfoFromResponse(existingResponse, serviceTypeLabels);
@@ -377,11 +382,13 @@
         ensureRunningOnHandlerThread(handler);
         // Augment the list of current known responses, and generated responses for resolve
         // requests if there is no known response
-        final List<MdnsResponse> currentList = new ArrayList<>(instanceNameToResponse.values());
+        final List<MdnsResponse> cachedList =
+                serviceCache.getCachedServices(serviceType, socketKey);
+        final List<MdnsResponse> currentList = new ArrayList<>(cachedList);
         List<MdnsResponse> additionalResponses = makeResponsesForResolve(socketKey);
         for (MdnsResponse additionalResponse : additionalResponses) {
-            if (!instanceNameToResponse.containsKey(
-                    additionalResponse.getServiceInstanceName())) {
+            if (findMatchedResponse(
+                    cachedList, additionalResponse.getServiceInstanceName()) == null) {
                 currentList.add(additionalResponse);
             }
         }
@@ -393,16 +400,17 @@
         final ArrayList<MdnsResponse> allResponses = augmentedResult.second;
 
         for (MdnsResponse response : allResponses) {
+            final String serviceInstanceName = response.getServiceInstanceName();
             if (modifiedResponse.contains(response)) {
                 if (response.isGoodbye()) {
-                    onGoodbyeReceived(response.getServiceInstanceName());
+                    onGoodbyeReceived(serviceInstanceName);
                 } else {
                     onResponseModified(response);
                 }
-            } else if (instanceNameToResponse.containsKey(response.getServiceInstanceName())) {
+            } else if (findMatchedResponse(cachedList, serviceInstanceName) != null) {
                 // If the response is not modified and already in the cache. The cache will
                 // need to be updated to refresh the last receipt time.
-                instanceNameToResponse.put(response.getServiceInstanceName(), response);
+                serviceCache.addOrUpdateService(serviceType, socketKey, response);
             }
         }
         if (dependencies.hasMessages(handler, EVENT_START_QUERYTASK)
@@ -431,7 +439,7 @@
     /** Notify all services are removed because the socket is destroyed. */
     public void notifySocketDestroyed() {
         ensureRunningOnHandlerThread(handler);
-        for (MdnsResponse response : instanceNameToResponse.values()) {
+        for (MdnsResponse response : serviceCache.getCachedServices(serviceType, socketKey)) {
             final String name = response.getServiceInstanceName();
             if (name == null) continue;
             for (int i = 0; i < listeners.size(); i++) {
@@ -453,18 +461,18 @@
     private void onResponseModified(@NonNull MdnsResponse response) {
         final String serviceInstanceName = response.getServiceInstanceName();
         final MdnsResponse currentResponse =
-                instanceNameToResponse.get(serviceInstanceName);
+                serviceCache.getCachedService(serviceInstanceName, serviceType, socketKey);
 
         boolean newServiceFound = false;
         boolean serviceBecomesComplete = false;
         if (currentResponse == null) {
             newServiceFound = true;
             if (serviceInstanceName != null) {
-                instanceNameToResponse.put(serviceInstanceName, response);
+                serviceCache.addOrUpdateService(serviceType, socketKey, response);
             }
         } else {
             boolean before = currentResponse.isComplete();
-            instanceNameToResponse.put(serviceInstanceName, response);
+            serviceCache.addOrUpdateService(serviceType, socketKey, response);
             boolean after = response.isComplete();
             serviceBecomesComplete = !before && after;
         }
@@ -497,7 +505,8 @@
     }
 
     private void onGoodbyeReceived(@Nullable String serviceInstanceName) {
-        final MdnsResponse response = instanceNameToResponse.remove(serviceInstanceName);
+        final MdnsResponse response =
+                serviceCache.removeService(serviceInstanceName, serviceType, socketKey);
         if (response == null) {
             return;
         }
@@ -673,7 +682,8 @@
             if (resolveName == null) {
                 continue;
             }
-            MdnsResponse knownResponse = instanceNameToResponse.get(resolveName);
+            MdnsResponse knownResponse =
+                    serviceCache.getCachedService(resolveName, serviceType, socketKey);
             if (knownResponse == null) {
                 final ArrayList<String> instanceFullName = new ArrayList<>(
                         serviceTypeLabels.length + 1);
@@ -691,19 +701,21 @@
     private void tryRemoveServiceAfterTtlExpires() {
         if (!shouldRemoveServiceAfterTtlExpires()) return;
 
-        Iterator<MdnsResponse> iter = instanceNameToResponse.values().iterator();
+        Iterator<MdnsResponse> iter =
+                serviceCache.getCachedServices(serviceType, socketKey).iterator();
         while (iter.hasNext()) {
             MdnsResponse existingResponse = iter.next();
+            final String serviceInstanceName = existingResponse.getServiceInstanceName();
             if (existingResponse.hasServiceRecord()
                     && existingResponse.getServiceRecord()
                     .getRemainingTTL(clock.elapsedRealtime()) == 0) {
-                iter.remove();
+                serviceCache.removeService(serviceInstanceName, serviceType, socketKey);
                 for (int i = 0; i < listeners.size(); i++) {
                     if (!responseMatchesOptions(existingResponse, listeners.valueAt(i))) {
                         continue;
                     }
                     final MdnsServiceBrowserListener listener = listeners.keyAt(i);
-                    if (existingResponse.getServiceInstanceName() != null) {
+                    if (serviceInstanceName != null) {
                         final MdnsServiceInfo serviceInfo = buildMdnsServiceInfoFromResponse(
                                 existingResponse, serviceTypeLabels);
                         if (existingResponse.isComplete()) {
@@ -812,7 +824,7 @@
 
     private long getMinRemainingTtl(long now) {
         long minRemainingTtl = Long.MAX_VALUE;
-        for (MdnsResponse response : instanceNameToResponse.values()) {
+        for (MdnsResponse response : serviceCache.getCachedServices(serviceType, socketKey)) {
             if (!response.isComplete()) {
                 continue;
             }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index 4328053..11c9653 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -131,6 +131,7 @@
     private SocketKey socketKey;
     private HandlerThread thread;
     private Handler handler;
+    private MdnsServiceCache serviceCache;
     private long latestDelayMs = 0;
     private Message delayMessage = null;
     private Handler realHandler = null;
@@ -190,6 +191,7 @@
         thread = new HandlerThread("MdnsServiceTypeClientTests");
         thread.start();
         handler = new Handler(thread.getLooper());
+        serviceCache = new MdnsServiceCache(thread.getLooper());
 
         doAnswer(inv -> {
             latestDelayMs = 0;
@@ -213,7 +215,8 @@
 
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps) {
+                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps,
+                        serviceCache) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -908,7 +911,8 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps) {
+                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps,
+                        serviceCache) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -953,7 +957,8 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps) {
+                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps,
+                        serviceCache) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -986,7 +991,8 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps) {
+                        mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps,
+                        serviceCache) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -1106,7 +1112,8 @@
     @Test
     public void testProcessResponse_Resolve() throws Exception {
         client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps);
+                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps,
+                serviceCache);
 
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
@@ -1199,7 +1206,8 @@
     @Test
     public void testRenewTxtSrvInResolve() throws Exception {
         client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps);
+                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps,
+                serviceCache);
 
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
@@ -1312,7 +1320,8 @@
     @Test
     public void testProcessResponse_ResolveExcludesOtherServices() {
         client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps);
+                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps,
+                serviceCache);
 
         final String requestedInstance = "instance1";
         final String otherInstance = "instance2";
@@ -1376,7 +1385,8 @@
     @Test
     public void testProcessResponse_SubtypeDiscoveryLimitedToSubtype() {
         client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps);
+                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps,
+                serviceCache);
 
         final String matchingInstance = "instance1";
         final String subtype = "_subtype";
@@ -1457,7 +1467,8 @@
     @Test
     public void testNotifySocketDestroyed() throws Exception {
         client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps);
+                mockDecoderClock, socketKey, mockSharedLog, thread.getLooper(), mockDeps,
+                serviceCache);
 
         final String requestedInstance = "instance1";
         final String otherInstance = "instance2";
@@ -1512,16 +1523,109 @@
         verify(mockListenerOne, never()).onServiceNameRemoved(matchServiceName(otherInstance));
 
         // mockListenerTwo gets notified for both though
-        final InOrder inOrder2 = inOrder(mockListenerTwo);
-        inOrder2.verify(mockListenerTwo).onServiceNameDiscovered(
+        verify(mockListenerTwo).onServiceNameDiscovered(
                 matchServiceName(requestedInstance));
-        inOrder2.verify(mockListenerTwo).onServiceFound(matchServiceName(requestedInstance));
-        inOrder2.verify(mockListenerTwo).onServiceNameDiscovered(matchServiceName(otherInstance));
-        inOrder2.verify(mockListenerTwo).onServiceFound(matchServiceName(otherInstance));
-        inOrder2.verify(mockListenerTwo).onServiceRemoved(matchServiceName(otherInstance));
-        inOrder2.verify(mockListenerTwo).onServiceNameRemoved(matchServiceName(otherInstance));
-        inOrder2.verify(mockListenerTwo).onServiceRemoved(matchServiceName(requestedInstance));
-        inOrder2.verify(mockListenerTwo).onServiceNameRemoved(matchServiceName(requestedInstance));
+        verify(mockListenerTwo).onServiceFound(matchServiceName(requestedInstance));
+        verify(mockListenerTwo).onServiceNameDiscovered(matchServiceName(otherInstance));
+        verify(mockListenerTwo).onServiceFound(matchServiceName(otherInstance));
+        verify(mockListenerTwo).onServiceRemoved(matchServiceName(otherInstance));
+        verify(mockListenerTwo).onServiceNameRemoved(matchServiceName(otherInstance));
+        verify(mockListenerTwo).onServiceRemoved(matchServiceName(requestedInstance));
+        verify(mockListenerTwo).onServiceNameRemoved(matchServiceName(requestedInstance));
+    }
+
+    @Test
+    public void testServicesAreCached() throws Exception {
+        final String serviceName = "service-instance";
+        final String ipV4Address = "192.0.2.0";
+        // Register a listener
+        startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
+        verify(mockDeps, times(1)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
+        InOrder inOrder = inOrder(mockListenerOne);
+
+        // Process a response which has ip address to make response become complete.
+        final String subtype = "ABCDE";
+        processResponse(createResponse(
+                        serviceName, ipV4Address, 5353, subtype,
+                        Collections.emptyMap(), TEST_TTL),
+                socketKey);
+
+        // Verify that onServiceNameDiscovered is called.
+        inOrder.verify(mockListenerOne).onServiceNameDiscovered(serviceInfoCaptor.capture());
+        verifyServiceInfo(serviceInfoCaptor.getAllValues().get(0),
+                serviceName,
+                SERVICE_TYPE_LABELS,
+                List.of(ipV4Address) /* ipv4Address */,
+                List.of() /* ipv6Address */,
+                5353 /* port */,
+                Collections.singletonList(subtype) /* subTypes */,
+                Collections.singletonMap("key", null) /* attributes */,
+                socketKey);
+
+        // Verify that onServiceFound is called.
+        inOrder.verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
+        verifyServiceInfo(serviceInfoCaptor.getAllValues().get(1),
+                serviceName,
+                SERVICE_TYPE_LABELS,
+                List.of(ipV4Address) /* ipv4Address */,
+                List.of() /* ipv6Address */,
+                5353 /* port */,
+                Collections.singletonList(subtype) /* subTypes */,
+                Collections.singletonMap("key", null) /* attributes */,
+                socketKey);
+
+        // Unregister the listener
+        stopSendAndReceive(mockListenerOne);
+        verify(mockDeps, times(2)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
+
+        // Register another listener.
+        startSendAndReceive(mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
+        verify(mockDeps, times(3)).removeMessages(any(), eq(EVENT_START_QUERYTASK));
+        InOrder inOrder2 = inOrder(mockListenerTwo);
+
+        // The services are cached in MdnsServiceCache, verify that onServiceNameDiscovered is
+        // called immediately.
+        inOrder2.verify(mockListenerTwo).onServiceNameDiscovered(serviceInfoCaptor.capture());
+        verifyServiceInfo(serviceInfoCaptor.getAllValues().get(2),
+                serviceName,
+                SERVICE_TYPE_LABELS,
+                List.of(ipV4Address) /* ipv4Address */,
+                List.of() /* ipv6Address */,
+                5353 /* port */,
+                Collections.singletonList(subtype) /* subTypes */,
+                Collections.singletonMap("key", null) /* attributes */,
+                socketKey);
+
+        // The services are cached in MdnsServiceCache, verify that onServiceFound is
+        // called immediately.
+        inOrder2.verify(mockListenerTwo).onServiceFound(serviceInfoCaptor.capture());
+        verifyServiceInfo(serviceInfoCaptor.getAllValues().get(3),
+                serviceName,
+                SERVICE_TYPE_LABELS,
+                List.of(ipV4Address) /* ipv4Address */,
+                List.of() /* ipv6Address */,
+                5353 /* port */,
+                Collections.singletonList(subtype) /* subTypes */,
+                Collections.singletonMap("key", null) /* attributes */,
+                socketKey);
+
+        // Process a response with a different ip address, port and updated text attributes.
+        final String ipV6Address = "2001:db8::";
+        processResponse(createResponse(
+                serviceName, ipV6Address, 5354, subtype,
+                Collections.singletonMap("key", "value"), TEST_TTL), socketKey);
+
+        // Verify the onServiceUpdated is called.
+        inOrder2.verify(mockListenerTwo).onServiceUpdated(serviceInfoCaptor.capture());
+        verifyServiceInfo(serviceInfoCaptor.getAllValues().get(4),
+                serviceName,
+                SERVICE_TYPE_LABELS,
+                List.of(ipV4Address) /* ipv4Address */,
+                List.of(ipV6Address) /* ipv6Address */,
+                5354 /* port */,
+                Collections.singletonList(subtype) /* subTypes */,
+                Collections.singletonMap("key", "value") /* attributes */,
+                socketKey);
     }
 
     private static MdnsServiceInfo matchServiceName(String name) {