Merge "Add CachedService Class" into main
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
index 4ae8701..7c72fb1 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
@@ -78,6 +78,17 @@
return "CacheKey{ ServiceType=" + mUpperCaseServiceType + ", " + mSocketKey + " }";
}
}
+
+ /**
+ * An entry of the service cache: a {@link MdnsResponse} paired with an expiry flag.
+ */
+ public static class CachedService {
+ /** The response held by this cache entry. */
+ @NonNull final MdnsResponse mService;
+ /** Expiry flag of this entry; the constructor always initializes it to false. */
+ boolean mServiceExpired;
+
+ /**
+ * Create a cache entry that is not flagged as expired.
+ *
+ * @param service the response to wrap; stored by reference, not copied.
+ */
+ CachedService(@NonNull MdnsResponse service) {
+ mService = service;
+ mServiceExpired = false;
+ }
+ }
+
/**
* A map of cached services. Key is composed of service type and socket. Value is the list of
* services which are discovered from the given CacheKey.
@@ -86,7 +97,7 @@
* removal process to progress through the expiration check efficiently.
*/
@NonNull
- private final ArrayMap<CacheKey, List<MdnsResponse>> mCachedServices = new ArrayMap<>();
+ private final ArrayMap<CacheKey, List<CachedService>> mCachedServices = new ArrayMap<>();
/**
* A map of service expire callbacks. Key is composed of service type and socket and value is
* the callback listener.
@@ -113,6 +124,14 @@
mClock = clock;
}
+ /**
+ * Extract the {@link MdnsResponse} of every cached entry, preserving the list order.
+ *
+ * @param cachedServices the cached entries to unwrap.
+ * @return a new mutable list holding each entry's response; the responses themselves are
+ *         not copied.
+ */
+ @NonNull
+ private static List<MdnsResponse> cachedServicesToResponses(
+ @NonNull List<CachedService> cachedServices) {
+ // Exactly one response is produced per entry, so the final size is known up front.
+ final List<MdnsResponse> responses = new ArrayList<>(cachedServices.size());
+ for (CachedService cachedService : cachedServices) {
+ responses.add(cachedService.mService);
+ }
+ return responses;
+ }
+
/**
* Get the cache services which are queried from given service type and socket.
*
@@ -126,7 +145,8 @@
maybeRemoveExpiredServices(cacheKey, mClock.elapsedRealtime());
}
return mCachedServices.containsKey(cacheKey)
- ? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(cacheKey)))
+ ? Collections.unmodifiableList(
+ cachedServicesToResponses(mCachedServices.get(cacheKey)))
: Collections.emptyList();
}
@@ -147,6 +167,16 @@
return null;
}
+ /**
+ * Find the cached entry whose service instance name matches the given name, as compared by
+ * {@code equalsIgnoreDnsCase}.
+ *
+ * @param cachedServices the entries to search, in order.
+ * @param serviceName the service instance name to look for.
+ * @return the first matching entry, or null if no entry matches.
+ */
+ @Nullable
+ private static CachedService findMatchedCachedService(
+ @NonNull List<CachedService> cachedServices, @NonNull String serviceName) {
+ for (CachedService cachedService : cachedServices) {
+ if (equalsIgnoreDnsCase(serviceName, cachedService.mService.getServiceInstanceName())) {
+ return cachedService;
+ }
+ }
+ return null;
+ }
+
/**
* Get the cache service.
*
@@ -160,22 +190,23 @@
if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
maybeRemoveExpiredServices(cacheKey, mClock.elapsedRealtime());
}
- final List<MdnsResponse> responses = mCachedServices.get(cacheKey);
- if (responses == null) {
+ final List<CachedService> cachedServices = mCachedServices.get(cacheKey);
+ if (cachedServices == null) {
return null;
}
- final MdnsResponse response = findMatchedResponse(responses, serviceName);
- return response != null ? new MdnsResponse(response) : null;
+ final CachedService cachedService = findMatchedCachedService(cachedServices, serviceName);
+ return cachedService != null ? new MdnsResponse(cachedService.mService) : null;
}
- static void insertResponseAndSortList(
- List<MdnsResponse> responses, MdnsResponse response, long now) {
+ /**
+ * Insert a cached service at the position given by a binary search on the minimum
+ * remaining ttl of the wrapped response at {@code now}.
+ *
+ * <p>The list is expected to already be in ascending order of that value, which is a
+ * precondition of {@link Collections#binarySearch}; the insertion keeps it in that order.
+ *
+ * @param cachedServices the list to insert into; modified in place.
+ * @param cachedService the entry to insert.
+ * @param now the time passed to getMinRemainingTtl when comparing entries.
+ */
+ static void insertServiceAndSortList(
+ List<CachedService> cachedServices, CachedService cachedService, long now) {
// binarySearch returns "the index of the search key, if it is contained in the list;
// otherwise, (-(insertion point) - 1)"
- final int searchRes = Collections.binarySearch(responses, response,
+ final int searchRes = Collections.binarySearch(cachedServices, cachedService,
// Sort the list by ttl.
- (o1, o2) -> Long.compare(o1.getMinRemainingTtl(now), o2.getMinRemainingTtl(now)));
- responses.add(searchRes >= 0 ? searchRes : (-searchRes - 1), response);
+ (o1, o2) -> Long.compare(o1.mService.getMinRemainingTtl(now),
+ o2.mService.getMinRemainingTtl(now)));
+ cachedServices.add(searchRes >= 0 ? searchRes : (-searchRes - 1), cachedService);
}
/**
@@ -186,20 +217,22 @@
*/
public void addOrUpdateService(@NonNull CacheKey cacheKey, @NonNull MdnsResponse response) {
ensureRunningOnHandlerThread(mHandler);
- final List<MdnsResponse> responses = mCachedServices.computeIfAbsent(
+ final List<CachedService> cachedServices = mCachedServices.computeIfAbsent(
cacheKey, key -> new ArrayList<>());
// Remove existing service if present.
- final MdnsResponse existing =
- findMatchedResponse(responses, response.getServiceInstanceName());
- responses.remove(existing);
+ final CachedService existing = findMatchedCachedService(cachedServices,
+ response.getServiceInstanceName());
+ // existing is null when no entry matched; remove(null) then leaves the list untouched
+ // because this method only ever adds non-null entries.
+ cachedServices.remove(existing);
+
+ // Wrap the response in a fresh entry, so any previous expiry flag is not carried over.
+ final CachedService cachedService = new CachedService(response);
if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
final long now = mClock.elapsedRealtime();
// Insert and sort service
- insertResponseAndSortList(responses, response, now);
+ insertServiceAndSortList(cachedServices, cachedService, now);
// Update the next expiration check time when a new service is added.
mNextExpirationTime = getNextExpirationTime(now);
} else {
- responses.add(response);
+ cachedServices.add(cachedService);
}
}
@@ -212,30 +245,30 @@
@Nullable
public MdnsResponse removeService(@NonNull String serviceName, @NonNull CacheKey cacheKey) {
ensureRunningOnHandlerThread(mHandler);
- final List<MdnsResponse> responses = mCachedServices.get(cacheKey);
- if (responses == null) {
+ final List<CachedService> cachedServices = mCachedServices.get(cacheKey);
+ if (cachedServices == null) {
return null;
}
- final Iterator<MdnsResponse> iterator = responses.iterator();
- MdnsResponse removedResponse = null;
+ final Iterator<CachedService> iterator = cachedServices.iterator();
+ CachedService removedService = null;
while (iterator.hasNext()) {
- final MdnsResponse response = iterator.next();
- if (equalsIgnoreDnsCase(serviceName, response.getServiceInstanceName())) {
+ final CachedService cachedService = iterator.next();
+ if (equalsIgnoreDnsCase(serviceName, cachedService.mService.getServiceInstanceName())) {
iterator.remove();
- removedResponse = response;
+ removedService = cachedService;
break;
}
}
if (mMdnsFeatureFlags.mIsExpiredServicesRemovalEnabled) {
// Remove the serviceType if no response.
- if (responses.isEmpty()) {
+ if (cachedServices.isEmpty()) {
mCachedServices.remove(cacheKey);
}
// Update the next expiration check time when a service is removed.
mNextExpirationTime = getNextExpirationTime(mClock.elapsedRealtime());
}
- return removedResponse;
+ // Unwrap the entry: the return type is still MdnsResponse, and null means that no entry
+ // matched serviceName.
+ return removedService == null ? null : removedService.mService;
}
/**
@@ -288,24 +321,25 @@
mHandler.post(()-> callback.onServiceRecordExpired(previousResponse, newResponse));
}
- static List<MdnsResponse> removeExpiredServices(@NonNull List<MdnsResponse> responses,
+ /**
+ * Remove the leading entries of the list that have a service record and whose minimum
+ * remaining ttl at {@code now} is not positive, stopping at the first entry that does not
+ * qualify.
+ *
+ * @param cachedServices the list to prune; modified in place.
+ * @param now the time passed to getMinRemainingTtl for the expiration check.
+ * @return the removed entries in their original order; empty if nothing was removed.
+ */
+ static List<CachedService> removeExpiredServices(@NonNull List<CachedService> cachedServices,
long now) {
- final List<MdnsResponse> removedResponses = new ArrayList<>();
- final Iterator<MdnsResponse> iterator = responses.iterator();
+ final List<CachedService> removedServices = new ArrayList<>();
+ final Iterator<CachedService> iterator = cachedServices.iterator();
while (iterator.hasNext()) {
- final MdnsResponse response = iterator.next();
+ final CachedService cachedService = iterator.next();
// TODO: Check other records (A, AAAA, TXT) ttl time and remove the record if it's
// expired. Then send service update notification.
- if (!response.hasServiceRecord() || response.getMinRemainingTtl(now) > 0) {
+ if (!cachedService.mService.hasServiceRecord()
+ || cachedService.mService.getMinRemainingTtl(now) > 0) {
// The responses are sorted by the service record ttl time. Break out of loop
// early if service is not expired or no service record.
break;
}
// Remove the ttl expired service.
iterator.remove();
- removedResponses.add(response);
+ removedServices.add(cachedService);
}
- return removedResponses;
+ return removedServices;
}
private long getNextExpirationTime(long now) {
@@ -319,7 +353,7 @@
// The empty lists are not kept in the map, so there's always at least one
// element in the list. Therefore, it's fine to get the first element without a
// null check.
- mCachedServices.valueAt(i).get(0).getMinRemainingTtl(now));
+ mCachedServices.valueAt(i).get(0).mService.getMinRemainingTtl(now));
}
return minRemainingTtl == EXPIRATION_NEVER ? EXPIRATION_NEVER : now + minRemainingTtl;
}
@@ -334,24 +368,24 @@
return;
}
- final List<MdnsResponse> responses = mCachedServices.get(cacheKey);
- if (responses == null) {
+ final List<CachedService> cachedServices = mCachedServices.get(cacheKey);
+ if (cachedServices == null) {
// No such services.
return;
}
- final List<MdnsResponse> removedResponses = removeExpiredServices(responses, now);
- if (removedResponses.isEmpty()) {
+ final List<CachedService> removedServices = removeExpiredServices(cachedServices, now);
+ if (removedServices.isEmpty()) {
// No expired services.
return;
}
- for (MdnsResponse previousResponse : removedResponses) {
- notifyServiceExpired(cacheKey, previousResponse, null /* newResponse */);
+ for (CachedService previousService : removedServices) {
+ notifyServiceExpired(cacheKey, previousService.mService, null /* newResponse */);
}
// Remove the serviceType if no response.
- if (responses.isEmpty()) {
+ if (cachedServices.isEmpty()) {
mCachedServices.remove(cacheKey);
}
@@ -368,8 +402,9 @@
for (int i = 0; i < mCachedServices.size(); i++) {
final CacheKey key = mCachedServices.keyAt(i);
pw.println(indent + key);
- for (MdnsResponse response : mCachedServices.valueAt(i)) {
- pw.println(indent + " Response{ " + response + " }");
+ for (CachedService cachedService : mCachedServices.valueAt(i)) {
+ pw.println(indent + " Response{ " + cachedService.mService
+ + " } Expired=" + cachedService.mServiceExpired);
}
pw.println();
}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
index 976dfa9..2ebe87a 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
@@ -21,6 +21,7 @@
import android.os.HandlerThread
import com.android.net.module.util.ArrayTrackRecord
import com.android.server.connectivity.mdns.MdnsServiceCache.CacheKey
+import com.android.server.connectivity.mdns.MdnsServiceCache.CachedService
import com.android.server.connectivity.mdns.MdnsServiceCacheTest.ExpiredRecord.ExpiredEvent.ServiceRecordExpired
import com.android.server.connectivity.mdns.util.MdnsUtils
import com.android.testutils.DevSdkIgnoreRule
@@ -289,32 +290,40 @@
@Test
fun testInsertResponseAndSortList() {
- val responses = ArrayList<MdnsResponse>()
- val response1 = createResponse(SERVICE_NAME_1, SERVICE_TYPE_1, 100L /* ttlTime */)
- MdnsServiceCache.insertResponseAndSortList(responses, response1, TEST_ELAPSED_REALTIME_MS)
- assertEquals(1, responses.size)
- assertEquals(response1, responses[0])
+ // NOTE(review): the test name still uses the old insertResponseAndSortList name, while
+ // the method under test is now insertServiceAndSortList.
+ // Each step inserts one entry and checks the list stays in ascending ttl order.
+ val services = ArrayList<CachedService>()
+ val service1 = CachedService(
+ createResponse(SERVICE_NAME_1, SERVICE_TYPE_1, 100L /* ttlTime */)
+ )
+ MdnsServiceCache.insertServiceAndSortList(services, service1, TEST_ELAPSED_REALTIME_MS)
+ assertEquals(1, services.size)
+ assertEquals(service1, services[0])
- val response2 = createResponse(SERVICE_NAME_2, SERVICE_TYPE_1, 50L /* ttlTime */)
- MdnsServiceCache.insertResponseAndSortList(responses, response2, TEST_ELAPSED_REALTIME_MS)
- assertEquals(2, responses.size)
- assertEquals(response2, responses[0])
- assertEquals(response1, responses[1])
+ // Smallest ttl so far: expected at the head of the list.
+ val service2 = CachedService(
+ createResponse(SERVICE_NAME_2, SERVICE_TYPE_1, 50L /* ttlTime */)
+ )
+ MdnsServiceCache.insertServiceAndSortList(services, service2, TEST_ELAPSED_REALTIME_MS)
+ assertEquals(2, services.size)
+ assertEquals(service2, services[0])
+ assertEquals(service1, services[1])
- val response3 = createResponse(SERVICE_NAME_3, SERVICE_TYPE_1, 75L /* ttlTime */)
- MdnsServiceCache.insertResponseAndSortList(responses, response3, TEST_ELAPSED_REALTIME_MS)
- assertEquals(3, responses.size)
- assertEquals(response2, responses[0])
- assertEquals(response3, responses[1])
- assertEquals(response1, responses[2])
+ // Ttl between the two existing entries: expected in the middle.
+ val service3 = CachedService(
+ createResponse(SERVICE_NAME_3, SERVICE_TYPE_1, 75L /* ttlTime */)
+ )
+ MdnsServiceCache.insertServiceAndSortList(services, service3, TEST_ELAPSED_REALTIME_MS)
+ assertEquals(3, services.size)
+ assertEquals(service2, services[0])
+ assertEquals(service3, services[1])
+ assertEquals(service1, services[2])
- val response4 = createResponse("service-instance-4", SERVICE_TYPE_1, 125L /* ttlTime */)
- MdnsServiceCache.insertResponseAndSortList(responses, response4, TEST_ELAPSED_REALTIME_MS)
- assertEquals(4, responses.size)
- assertEquals(response2, responses[0])
- assertEquals(response3, responses[1])
- assertEquals(response1, responses[2])
- assertEquals(response4, responses[3])
+ // Largest ttl so far: expected at the tail of the list.
+ val service4 = CachedService(
+ createResponse("service-instance-4", SERVICE_TYPE_1, 125L /* ttlTime */)
+ )
+ MdnsServiceCache.insertServiceAndSortList(services, service4, TEST_ELAPSED_REALTIME_MS)
+ assertEquals(4, services.size)
+ assertEquals(service2, services[0])
+ assertEquals(service3, services[1])
+ assertEquals(service1, services[2])
+ assertEquals(service4, services[3])
}
@Test