Merge "Add CachedService Class" into main
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
index 4ae8701..7c72fb1 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
@@ -78,6 +78,17 @@
             return "CacheKey{ ServiceType=" + mUpperCaseServiceType + ", " + mSocketKey + " }";
         }
     }
+
+    /** Cache entry: an {@link MdnsResponse} plus an expired flag that starts out false. */
+    public static class CachedService {
+        @NonNull final MdnsResponse mService;
+        boolean mServiceExpired = false;
+
+        CachedService(@NonNull MdnsResponse service) {
+            mService = service;
+        }
+    }
+
     /**
      * A map of cached services. Key is composed of service type and socket. Value is the list of
      * services which are discovered from the given CacheKey.
@@ -86,7 +97,7 @@
      * removal process to progress through the expiration check efficiently.
      */
     @NonNull
-    private final ArrayMap<CacheKey, List<MdnsResponse>> mCachedServices = new ArrayMap<>();
+    private final ArrayMap<CacheKey, List<CachedService>> mCachedServices = new ArrayMap<>();
     /**
      * A map of service expire callbacks. Key is composed of service type and socket and value is
      * the callback listener.
@@ -113,6 +124,14 @@
         mClock = clock;
     }
 
+    private List<MdnsResponse> cachedServicesToResponses(List<CachedService> cachedServices) {
+        final List<MdnsResponse> responses = new ArrayList<>(cachedServices.size());
+        for (CachedService cachedService : cachedServices) {
+            responses.add(cachedService.mService);
+        }
+        return responses;
+    }
+
     /**
      * Get the cache services which are queried from given service type and socket.
      *
@@ -126,7 +145,8 @@
             maybeRemoveExpiredServices(cacheKey, mClock.elapsedRealtime());
         }
         return mCachedServices.containsKey(cacheKey)
-                ? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(cacheKey)))
+                ? Collections.unmodifiableList(
+                        cachedServicesToResponses(mCachedServices.get(cacheKey)))
                 : Collections.emptyList();
     }
 
@@ -147,6 +167,16 @@
         return null;
     }
 
+    @Nullable private static CachedService findMatchedCachedService(
+            @NonNull List<CachedService> cachedServices, @NonNull String serviceName) {
+        for (CachedService cachedService : cachedServices) {
+            if (equalsIgnoreDnsCase(serviceName, cachedService.mService.getServiceInstanceName())) {
+                return cachedService;
+            }
+        }
+        return null;
+    }
+
     /**
      * Get the cache service.
      *
@@ -160,22 +190,23 @@
         if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
             maybeRemoveExpiredServices(cacheKey, mClock.elapsedRealtime());
         }
-        final List<MdnsResponse> responses = mCachedServices.get(cacheKey);
-        if (responses == null) {
+        final List<CachedService> cachedServices = mCachedServices.get(cacheKey);
+        if (cachedServices == null) {
             return null;
         }
-        final MdnsResponse response = findMatchedResponse(responses, serviceName);
-        return response != null ? new MdnsResponse(response) : null;
+        final CachedService cachedService = findMatchedCachedService(cachedServices, serviceName);
+        return cachedService != null ? new MdnsResponse(cachedService.mService) : null;
     }
 
-    static void insertResponseAndSortList(
-            List<MdnsResponse> responses, MdnsResponse response, long now) {
+    static void insertServiceAndSortList(
+            List<CachedService> cachedServices, CachedService cachedService, long now) {
         // binarySearch returns "the index of the search key, if it is contained in the list;
         // otherwise, (-(insertion point) - 1)"
-        final int searchRes = Collections.binarySearch(responses, response,
+        final int searchRes = Collections.binarySearch(cachedServices, cachedService,
                 // Sort the list by ttl.
-                (o1, o2) -> Long.compare(o1.getMinRemainingTtl(now), o2.getMinRemainingTtl(now)));
-        responses.add(searchRes >= 0 ? searchRes : (-searchRes - 1), response);
+                (o1, o2) -> Long.compare(o1.mService.getMinRemainingTtl(now),
+                        o2.mService.getMinRemainingTtl(now)));
+        cachedServices.add(searchRes >= 0 ? searchRes : (-searchRes - 1), cachedService);
     }
 
     /**
@@ -186,20 +217,22 @@
      */
     public void addOrUpdateService(@NonNull CacheKey cacheKey, @NonNull MdnsResponse response) {
         ensureRunningOnHandlerThread(mHandler);
-        final List<MdnsResponse> responses = mCachedServices.computeIfAbsent(
+        final List<CachedService> cachedServices = mCachedServices.computeIfAbsent(
                 cacheKey, key -> new ArrayList<>());
         // Remove existing service if present.
-        final MdnsResponse existing =
-                findMatchedResponse(responses, response.getServiceInstanceName());
-        responses.remove(existing);
+        final CachedService existing = findMatchedCachedService(cachedServices,
+                response.getServiceInstanceName());
+        cachedServices.remove(existing);
+
+        final CachedService cachedService = new CachedService(response);
         if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
             final long now = mClock.elapsedRealtime();
             // Insert and sort service
-            insertResponseAndSortList(responses, response, now);
+            insertServiceAndSortList(cachedServices, cachedService, now);
             // Update the next expiration check time when a new service is added.
             mNextExpirationTime = getNextExpirationTime(now);
         } else {
-            responses.add(response);
+            cachedServices.add(cachedService);
         }
     }
 
@@ -212,30 +245,30 @@
     @Nullable
     public MdnsResponse removeService(@NonNull String serviceName, @NonNull CacheKey cacheKey) {
         ensureRunningOnHandlerThread(mHandler);
-        final List<MdnsResponse> responses = mCachedServices.get(cacheKey);
-        if (responses == null) {
+        final List<CachedService> cachedServices = mCachedServices.get(cacheKey);
+        if (cachedServices == null) {
             return null;
         }
-        final Iterator<MdnsResponse> iterator = responses.iterator();
-        MdnsResponse removedResponse = null;
+        final Iterator<CachedService> iterator = cachedServices.iterator();
+        CachedService removedService = null;
         while (iterator.hasNext()) {
-            final MdnsResponse response = iterator.next();
-            if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) {
+            final CachedService cachedService = iterator.next();
+            if (equalsIgnoreDnsCase(serviceName, cachedService.mService.getServiceInstanceName())) {
                 iterator.remove();
-                removedResponse = response;
+                removedService = cachedService;
                 break;
             }
         }
 
         if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
             // Remove the serviceType if no response.
-            if (responses.isEmpty()) {
+            if (cachedServices.isEmpty()) {
                 mCachedServices.remove(cacheKey);
             }
             // Update the next expiration check time when a service is removed.
             mNextExpirationTime = getNextExpirationTime(mClock.elapsedRealtime());
         }
-        return removedResponse;
+        return removedService == null ? null : removedService.mService;
     }
 
     /**
@@ -288,24 +321,25 @@
         mHandler.post(()-> callback.onServiceRecordExpired(previousResponse, newResponse));
     }
 
-    static List<MdnsResponse> removeExpiredServices(@NonNull List<MdnsResponse> responses,
+    static List<CachedService> removeExpiredServices(@NonNull List<CachedService> cachedServices,
             long now) {
-        final List<MdnsResponse> removedResponses = new ArrayList<>();
-        final Iterator<MdnsResponse> iterator = responses.iterator();
+        final List<CachedService> removedServices = new ArrayList<>();
+        final Iterator<CachedService> iterator = cachedServices.iterator();
         while (iterator.hasNext()) {
-            final MdnsResponse response = iterator.next();
+            final CachedService cachedService = iterator.next();
             // TODO: Check other records (A, AAAA, TXT) ttl time and remove the record if it's
             //  expired. Then send service update notification.
-            if (!response.hasServiceRecord() || response.getMinRemainingTtl(now) > 0) {
+            if (!cachedService.mService.hasServiceRecord()
+                    || cachedService.mService.getMinRemainingTtl(now) > 0) {
                 // The responses are sorted by the service record ttl time. Break out of loop
                 // early if service is not expired or no service record.
                 break;
             }
             // Remove the ttl expired service.
             iterator.remove();
-            removedResponses.add(response);
+            removedServices.add(cachedService);
         }
-        return removedResponses;
+        return removedServices;
     }
 
     private long getNextExpirationTime(long now) {
@@ -319,7 +353,7 @@
                     // The empty lists are not kept in the map, so there's always at least one
                     // element in the list. Therefore, it's fine to get the first element without a
                     // null check.
-                    mCachedServices.valueAt(i).get(0).getMinRemainingTtl(now));
+                    mCachedServices.valueAt(i).get(0).mService.getMinRemainingTtl(now));
         }
         return minRemainingTtl == EXPIRATION_NEVER ? EXPIRATION_NEVER : now + minRemainingTtl;
     }
@@ -334,24 +368,24 @@
             return;
         }
 
-        final List<MdnsResponse> responses = mCachedServices.get(cacheKey);
-        if (responses == null) {
+        final List<CachedService> cachedServices = mCachedServices.get(cacheKey);
+        if (cachedServices == null) {
             // No such services.
             return;
         }
 
-        final List<MdnsResponse> removedResponses = removeExpiredServices(responses, now);
-        if (removedResponses.isEmpty()) {
+        final List<CachedService> removedServices = removeExpiredServices(cachedServices, now);
+        if (removedServices.isEmpty()) {
             // No expired services.
             return;
         }
 
-        for (MdnsResponse previousResponse : removedResponses) {
-            notifyServiceExpired(cacheKey, previousResponse, null /* newResponse */);
+        for (CachedService previousService : removedServices) {
+            notifyServiceExpired(cacheKey, previousService.mService, null /* newResponse */);
         }
 
         // Remove the serviceType if no response.
-        if (responses.isEmpty()) {
+        if (cachedServices.isEmpty()) {
             mCachedServices.remove(cacheKey);
         }
 
@@ -368,8 +402,9 @@
         for (int i = 0; i < mCachedServices.size(); i++) {
             final CacheKey key = mCachedServices.keyAt(i);
             pw.println(indent + key);
-            for (MdnsResponse response : mCachedServices.valueAt(i)) {
-                pw.println(indent + "  Response{ " + response + " }");
+            for (CachedService cachedService : mCachedServices.valueAt(i)) {
+                pw.println(indent + "  Response{ " + cachedService.mService
+                        + " } Expired=" + cachedService.mServiceExpired);
             }
             pw.println();
         }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
index 976dfa9..2ebe87a 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
@@ -21,6 +21,7 @@
 import android.os.HandlerThread
 import com.android.net.module.util.ArrayTrackRecord
 import com.android.server.connectivity.mdns.MdnsServiceCache.CacheKey
+import com.android.server.connectivity.mdns.MdnsServiceCache.CachedService
 import com.android.server.connectivity.mdns.MdnsServiceCacheTest.ExpiredRecord.ExpiredEvent.ServiceRecordExpired
 import com.android.server.connectivity.mdns.util.MdnsUtils
 import com.android.testutils.DevSdkIgnoreRule
@@ -289,32 +290,40 @@
 
     @Test
     fun testInsertResponseAndSortList() {
-        val responses = ArrayList<MdnsResponse>()
-        val response1 = createResponse(SERVICE_NAME_1, SERVICE_TYPE_1, 100L /* ttlTime */)
-        MdnsServiceCache.insertResponseAndSortList(responses, response1, TEST_ELAPSED_REALTIME_MS)
-        assertEquals(1, responses.size)
-        assertEquals(response1, responses[0])
+        val services = ArrayList<CachedService>()
+        val service1 = CachedService(
+                createResponse(SERVICE_NAME_1, SERVICE_TYPE_1, 100L /* ttlTime */)
+        )
+        MdnsServiceCache.insertServiceAndSortList(services, service1, TEST_ELAPSED_REALTIME_MS)
+        assertEquals(1, services.size)
+        assertEquals(service1, services[0])
 
-        val response2 = createResponse(SERVICE_NAME_2, SERVICE_TYPE_1, 50L /* ttlTime */)
-        MdnsServiceCache.insertResponseAndSortList(responses, response2, TEST_ELAPSED_REALTIME_MS)
-        assertEquals(2, responses.size)
-        assertEquals(response2, responses[0])
-        assertEquals(response1, responses[1])
+        val service2 = CachedService(
+                createResponse(SERVICE_NAME_2, SERVICE_TYPE_1, 50L /* ttlTime */)
+        )
+        MdnsServiceCache.insertServiceAndSortList(services, service2, TEST_ELAPSED_REALTIME_MS)
+        assertEquals(2, services.size)
+        assertEquals(service2, services[0])
+        assertEquals(service1, services[1])
 
-        val response3 = createResponse(SERVICE_NAME_3, SERVICE_TYPE_1, 75L /* ttlTime */)
-        MdnsServiceCache.insertResponseAndSortList(responses, response3, TEST_ELAPSED_REALTIME_MS)
-        assertEquals(3, responses.size)
-        assertEquals(response2, responses[0])
-        assertEquals(response3, responses[1])
-        assertEquals(response1, responses[2])
+        val service3 = CachedService(
+                createResponse(SERVICE_NAME_3, SERVICE_TYPE_1, 75L /* ttlTime */)
+        )
+        MdnsServiceCache.insertServiceAndSortList(services, service3, TEST_ELAPSED_REALTIME_MS)
+        assertEquals(3, services.size)
+        assertEquals(service2, services[0])
+        assertEquals(service3, services[1])
+        assertEquals(service1, services[2])
 
-        val response4 = createResponse("service-instance-4", SERVICE_TYPE_1, 125L /* ttlTime */)
-        MdnsServiceCache.insertResponseAndSortList(responses, response4, TEST_ELAPSED_REALTIME_MS)
-        assertEquals(4, responses.size)
-        assertEquals(response2, responses[0])
-        assertEquals(response3, responses[1])
-        assertEquals(response1, responses[2])
-        assertEquals(response4, responses[3])
+        val service4 = CachedService(
+                createResponse("service-instance-4", SERVICE_TYPE_1, 125L /* ttlTime */)
+        )
+        MdnsServiceCache.insertServiceAndSortList(services, service4, TEST_ELAPSED_REALTIME_MS)
+        assertEquals(4, services.size)
+        assertEquals(service2, services[0])
+        assertEquals(service3, services[1])
+        assertEquals(service1, services[2])
+        assertEquals(service4, services[3])
     }
 
     @Test