Add ServiceExpiredCallback
This is a no-op change and refactors the design for subsequent
TTL expiration check changes.
- Add a ServiceExpiredCallback to notify expired services.
- To simplify the design, pass the CacheKey to MdnsServiceCache
methods instead.
Bug: 265787401
Test: atest FrameworksNetTests CtsNetTestCases
Change-Id: I930a4f7baf9b8d3d0037dc6aefd717dbdd486520
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
index ec6af9b..f9ee0df 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
@@ -42,7 +42,7 @@
* to their default value (0, false or null).
*/
public class MdnsServiceCache {
- private static class CacheKey {
+ static class CacheKey {
@NonNull final String mLowercaseServiceType;
@NonNull final SocketKey mSocketKey;
@@ -72,6 +72,12 @@
*/
@NonNull
private final ArrayMap<CacheKey, List<MdnsResponse>> mCachedServices = new ArrayMap<>();
+ /**
+ * A map of service expire callbacks. Key is composed of service type and socket and value is
+ * the callback listener.
+ */
+ @NonNull
+ private final ArrayMap<CacheKey, ServiceExpiredCallback> mCallbacks = new ArrayMap<>();
@NonNull
private final Handler mHandler;
@@ -82,17 +88,14 @@
/**
* Get the cache services which are queried from given service type and socket.
*
- * @param serviceType the target service type.
- * @param socketKey the target socket
+ * @param cacheKey the target CacheKey.
* @return the set of services which matches the given service type.
*/
@NonNull
- public List<MdnsResponse> getCachedServices(@NonNull String serviceType,
- @NonNull SocketKey socketKey) {
+ public List<MdnsResponse> getCachedServices(@NonNull CacheKey cacheKey) {
ensureRunningOnHandlerThread(mHandler);
- final CacheKey key = new CacheKey(serviceType, socketKey);
- return mCachedServices.containsKey(key)
- ? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(key)))
+ return mCachedServices.containsKey(cacheKey)
+ ? Collections.unmodifiableList(new ArrayList<>(mCachedServices.get(cacheKey)))
: Collections.emptyList();
}
@@ -117,16 +120,13 @@
* Get the cache service.
*
* @param serviceName the target service name.
- * @param serviceType the target service type.
- * @param socketKey the target socket
+ * @param cacheKey the target CacheKey.
* @return the service which matches given conditions.
*/
@Nullable
- public MdnsResponse getCachedService(@NonNull String serviceName,
- @NonNull String serviceType, @NonNull SocketKey socketKey) {
+ public MdnsResponse getCachedService(@NonNull String serviceName, @NonNull CacheKey cacheKey) {
ensureRunningOnHandlerThread(mHandler);
- final List<MdnsResponse> responses =
- mCachedServices.get(new CacheKey(serviceType, socketKey));
+ final List<MdnsResponse> responses = mCachedServices.get(cacheKey);
if (responses == null) {
return null;
}
@@ -137,15 +137,13 @@
/**
* Add or update a service.
*
- * @param serviceType the service type.
- * @param socketKey the target socket
+ * @param cacheKey the target CacheKey.
* @param response the response of the discovered service.
*/
- public void addOrUpdateService(@NonNull String serviceType, @NonNull SocketKey socketKey,
- @NonNull MdnsResponse response) {
+ public void addOrUpdateService(@NonNull CacheKey cacheKey, @NonNull MdnsResponse response) {
ensureRunningOnHandlerThread(mHandler);
final List<MdnsResponse> responses = mCachedServices.computeIfAbsent(
- new CacheKey(serviceType, socketKey), key -> new ArrayList<>());
+ cacheKey, key -> new ArrayList<>());
// Remove existing service if present.
final MdnsResponse existing =
findMatchedResponse(responses, response.getServiceInstanceName());
@@ -157,15 +155,12 @@
* Remove a service which matches the given service name, type and socket.
*
* @param serviceName the target service name.
- * @param serviceType the target service type.
- * @param socketKey the target socket.
+ * @param cacheKey the target CacheKey.
*/
@Nullable
- public MdnsResponse removeService(@NonNull String serviceName, @NonNull String serviceType,
- @NonNull SocketKey socketKey) {
+ public MdnsResponse removeService(@NonNull String serviceName, @NonNull CacheKey cacheKey) {
ensureRunningOnHandlerThread(mHandler);
- final List<MdnsResponse> responses =
- mCachedServices.get(new CacheKey(serviceType, socketKey));
+ final List<MdnsResponse> responses = mCachedServices.get(cacheKey);
if (responses == null) {
return null;
}
@@ -180,5 +175,37 @@
return null;
}
+ /**
+ * Register a callback to listen to service expiration.
+ *
+ * <p> Registering the same callback instance twice is a no-op, since MdnsServiceTypeClient
+ * relies on this.
+ *
+ * @param cacheKey the target CacheKey.
+ * @param callback the callback that notify the service is expired.
+ */
+ public void registerServiceExpiredCallback(@NonNull CacheKey cacheKey,
+ @NonNull ServiceExpiredCallback callback) {
+ ensureRunningOnHandlerThread(mHandler);
+ mCallbacks.put(cacheKey, callback);
+ }
+
+ /**
+ * Unregister the service expired callback.
+ *
+ * @param cacheKey the CacheKey that is registered to listen service expiration before.
+ */
+ public void unregisterServiceExpiredCallback(@NonNull CacheKey cacheKey) {
+ ensureRunningOnHandlerThread(mHandler);
+ mCallbacks.remove(cacheKey);
+ }
+
+ /*** Callbacks for listening service expiration */
+ public interface ServiceExpiredCallback {
+ /*** Notify the service is expired */
+ void onServiceRecordExpired(@NonNull MdnsResponse previousResponse,
+ @Nullable MdnsResponse newResponse);
+ }
+
// TODO: check ttl expiration for each service and notify to the clients.
}
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index bbe8f4c..0a03186 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -16,6 +16,7 @@
package com.android.server.connectivity.mdns;
+import static com.android.server.connectivity.mdns.MdnsServiceCache.ServiceExpiredCallback;
import static com.android.server.connectivity.mdns.MdnsServiceCache.findMatchedResponse;
import static com.android.server.connectivity.mdns.util.MdnsUtils.Clock;
import static com.android.server.connectivity.mdns.util.MdnsUtils.ensureRunningOnHandlerThread;
@@ -71,6 +72,15 @@
* The service caches for each socket. It should be accessed from looper thread only.
*/
@NonNull private final MdnsServiceCache serviceCache;
+ @NonNull private final MdnsServiceCache.CacheKey cacheKey;
+ @NonNull private final ServiceExpiredCallback serviceExpiredCallback =
+ new ServiceExpiredCallback() {
+ @Override
+ public void onServiceRecordExpired(@NonNull MdnsResponse previousResponse,
+ @Nullable MdnsResponse newResponse) {
+ notifyRemovedServiceToListeners(previousResponse, "Service record expired");
+ }
+ };
private final ArrayMap<MdnsServiceBrowserListener, MdnsSearchOptions> listeners =
new ArrayMap<>();
private final boolean removeServiceAfterTtlExpires =
@@ -225,6 +235,16 @@
this.dependencies = dependencies;
this.serviceCache = serviceCache;
this.mdnsQueryScheduler = new MdnsQueryScheduler();
+ this.cacheKey = new MdnsServiceCache.CacheKey(serviceType, socketKey);
+ }
+
+ /**
+ * Do the cleanup of the MdnsServiceTypeClient
+ */
+ private void shutDown() {
+ removeScheduledTask();
+ mdnsQueryScheduler.cancelScheduledRun();
+ serviceCache.unregisterServiceExpiredCallback(cacheKey);
}
private static MdnsServiceInfo buildMdnsServiceInfoFromResponse(
@@ -293,7 +313,7 @@
boolean hadReply = false;
if (listeners.put(listener, searchOptions) == null) {
for (MdnsResponse existingResponse :
- serviceCache.getCachedServices(serviceType, socketKey)) {
+ serviceCache.getCachedServices(cacheKey)) {
if (!responseMatchesOptions(existingResponse, searchOptions)) continue;
final MdnsServiceInfo info =
buildMdnsServiceInfoFromResponse(existingResponse, serviceTypeLabels);
@@ -341,6 +361,8 @@
servicesToResolve.size() < listeners.size() /* sendDiscoveryQueries */);
executor.submit(queryTask);
}
+
+ serviceCache.registerServiceExpiredCallback(cacheKey, serviceExpiredCallback);
}
/**
@@ -390,8 +412,7 @@
return listeners.isEmpty();
}
if (listeners.isEmpty()) {
- removeScheduledTask();
- mdnsQueryScheduler.cancelScheduledRun();
+ shutDown();
}
return listeners.isEmpty();
}
@@ -404,8 +425,7 @@
ensureRunningOnHandlerThread(handler);
// Augment the list of current known responses, and generated responses for resolve
// requests if there is no known response
- final List<MdnsResponse> cachedList =
- serviceCache.getCachedServices(serviceType, socketKey);
+ final List<MdnsResponse> cachedList = serviceCache.getCachedServices(cacheKey);
final List<MdnsResponse> currentList = new ArrayList<>(cachedList);
List<MdnsResponse> additionalResponses = makeResponsesForResolve(socketKey);
for (MdnsResponse additionalResponse : additionalResponses) {
@@ -432,7 +452,7 @@
} else if (findMatchedResponse(cachedList, serviceInstanceName) != null) {
// If the response is not modified and already in the cache. The cache will
// need to be updated to refresh the last receipt time.
- serviceCache.addOrUpdateService(serviceType, socketKey, response);
+ serviceCache.addOrUpdateService(cacheKey, response);
}
}
if (dependencies.hasMessages(handler, EVENT_START_QUERYTASK)) {
@@ -458,44 +478,50 @@
}
}
- /** Notify all services are removed because the socket is destroyed. */
- public void notifySocketDestroyed() {
- ensureRunningOnHandlerThread(handler);
- for (MdnsResponse response : serviceCache.getCachedServices(serviceType, socketKey)) {
- final String name = response.getServiceInstanceName();
- if (name == null) continue;
- for (int i = 0; i < listeners.size(); i++) {
- if (!responseMatchesOptions(response, listeners.valueAt(i))) continue;
- final MdnsServiceBrowserListener listener = listeners.keyAt(i);
- final MdnsServiceInfo serviceInfo =
- buildMdnsServiceInfoFromResponse(response, serviceTypeLabels);
+ private void notifyRemovedServiceToListeners(@NonNull MdnsResponse response,
+ @NonNull String message) {
+ for (int i = 0; i < listeners.size(); i++) {
+ if (!responseMatchesOptions(response, listeners.valueAt(i))) continue;
+ final MdnsServiceBrowserListener listener = listeners.keyAt(i);
+ if (response.getServiceInstanceName() != null) {
+ final MdnsServiceInfo serviceInfo = buildMdnsServiceInfoFromResponse(
+ response, serviceTypeLabels);
if (response.isComplete()) {
- sharedLog.log("Socket destroyed. onServiceRemoved: " + name);
+ sharedLog.log(message + ". onServiceRemoved: " + serviceInfo);
listener.onServiceRemoved(serviceInfo);
}
- sharedLog.log("Socket destroyed. onServiceNameRemoved: " + name);
+ sharedLog.log(message + ". onServiceNameRemoved: " + serviceInfo);
listener.onServiceNameRemoved(serviceInfo);
}
}
- removeScheduledTask();
- mdnsQueryScheduler.cancelScheduledRun();
+ }
+
+ /** Notify all services are removed because the socket is destroyed. */
+ public void notifySocketDestroyed() {
+ ensureRunningOnHandlerThread(handler);
+ for (MdnsResponse response : serviceCache.getCachedServices(cacheKey)) {
+ final String name = response.getServiceInstanceName();
+ if (name == null) continue;
+ notifyRemovedServiceToListeners(response, "Socket destroyed");
+ }
+ shutDown();
}
private void onResponseModified(@NonNull MdnsResponse response) {
final String serviceInstanceName = response.getServiceInstanceName();
final MdnsResponse currentResponse =
- serviceCache.getCachedService(serviceInstanceName, serviceType, socketKey);
+ serviceCache.getCachedService(serviceInstanceName, cacheKey);
boolean newServiceFound = false;
boolean serviceBecomesComplete = false;
if (currentResponse == null) {
newServiceFound = true;
if (serviceInstanceName != null) {
- serviceCache.addOrUpdateService(serviceType, socketKey, response);
+ serviceCache.addOrUpdateService(cacheKey, response);
}
} else {
boolean before = currentResponse.isComplete();
- serviceCache.addOrUpdateService(serviceType, socketKey, response);
+ serviceCache.addOrUpdateService(cacheKey, response);
boolean after = response.isComplete();
serviceBecomesComplete = !before && after;
}
@@ -529,22 +555,11 @@
private void onGoodbyeReceived(@Nullable String serviceInstanceName) {
final MdnsResponse response =
- serviceCache.removeService(serviceInstanceName, serviceType, socketKey);
+ serviceCache.removeService(serviceInstanceName, cacheKey);
if (response == null) {
return;
}
- for (int i = 0; i < listeners.size(); i++) {
- if (!responseMatchesOptions(response, listeners.valueAt(i))) continue;
- final MdnsServiceBrowserListener listener = listeners.keyAt(i);
- final MdnsServiceInfo serviceInfo =
- buildMdnsServiceInfoFromResponse(response, serviceTypeLabels);
- if (response.isComplete()) {
- sharedLog.log("onServiceRemoved: " + serviceInfo);
- listener.onServiceRemoved(serviceInfo);
- }
- sharedLog.log("onServiceNameRemoved: " + serviceInfo);
- listener.onServiceNameRemoved(serviceInfo);
- }
+ notifyRemovedServiceToListeners(response, "Goodbye received");
}
private boolean shouldRemoveServiceAfterTtlExpires() {
@@ -567,7 +582,7 @@
continue;
}
MdnsResponse knownResponse =
- serviceCache.getCachedService(resolveName, serviceType, socketKey);
+ serviceCache.getCachedService(resolveName, cacheKey);
if (knownResponse == null) {
final ArrayList<String> instanceFullName = new ArrayList<>(
serviceTypeLabels.length + 1);
@@ -585,36 +600,18 @@
private void tryRemoveServiceAfterTtlExpires() {
if (!shouldRemoveServiceAfterTtlExpires()) return;
- Iterator<MdnsResponse> iter =
- serviceCache.getCachedServices(serviceType, socketKey).iterator();
+ final Iterator<MdnsResponse> iter = serviceCache.getCachedServices(cacheKey).iterator();
while (iter.hasNext()) {
MdnsResponse existingResponse = iter.next();
- final String serviceInstanceName = existingResponse.getServiceInstanceName();
if (existingResponse.hasServiceRecord()
&& existingResponse.getServiceRecord()
.getRemainingTTL(clock.elapsedRealtime()) == 0) {
- serviceCache.removeService(serviceInstanceName, serviceType, socketKey);
- for (int i = 0; i < listeners.size(); i++) {
- if (!responseMatchesOptions(existingResponse, listeners.valueAt(i))) {
- continue;
- }
- final MdnsServiceBrowserListener listener = listeners.keyAt(i);
- if (serviceInstanceName != null) {
- final MdnsServiceInfo serviceInfo = buildMdnsServiceInfoFromResponse(
- existingResponse, serviceTypeLabels);
- if (existingResponse.isComplete()) {
- sharedLog.log("TTL expired. onServiceRemoved: " + serviceInfo);
- listener.onServiceRemoved(serviceInfo);
- }
- sharedLog.log("TTL expired. onServiceNameRemoved: " + serviceInfo);
- listener.onServiceNameRemoved(serviceInfo);
- }
- }
+ serviceCache.removeService(existingResponse.getServiceInstanceName(), cacheKey);
+ notifyRemovedServiceToListeners(existingResponse, "TTL expired");
}
}
}
-
private static class QuerySentArguments {
private final int transactionId;
private final List<String> subTypes = new ArrayList<>();
@@ -672,7 +669,7 @@
private long getMinRemainingTtl(long now) {
long minRemainingTtl = Long.MAX_VALUE;
- for (MdnsResponse response : serviceCache.getCachedServices(serviceType, socketKey)) {
+ for (MdnsResponse response : serviceCache.getCachedServices(cacheKey)) {
if (!response.isComplete()) {
continue;
}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
index b43bcf7..1b6f120 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
@@ -19,6 +19,7 @@
import android.os.Build
import android.os.Handler
import android.os.HandlerThread
+import com.android.server.connectivity.mdns.MdnsServiceCache.CacheKey
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner
import java.util.concurrent.CompletableFuture
@@ -43,6 +44,8 @@
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
class MdnsServiceCacheTest {
private val socketKey = SocketKey(null /* network */, INTERFACE_INDEX)
+ private val cacheKey1 = CacheKey(SERVICE_TYPE_1, socketKey)
+ private val cacheKey2 = CacheKey(SERVICE_TYPE_2, socketKey)
private val thread = HandlerThread(MdnsServiceCacheTest::class.simpleName)
private val handler by lazy {
Handler(thread.looper)
@@ -69,47 +72,36 @@
return future.get(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS)
}
- private fun addOrUpdateService(
- serviceType: String,
- socketKey: SocketKey,
- service: MdnsResponse
- ): Unit = runningOnHandlerAndReturn {
- serviceCache.addOrUpdateService(serviceType, socketKey, service)
- }
+ private fun addOrUpdateService(cacheKey: CacheKey, service: MdnsResponse): Unit =
+ runningOnHandlerAndReturn { serviceCache.addOrUpdateService(cacheKey, service) }
- private fun removeService(serviceName: String, serviceType: String, socketKey: SocketKey):
- Unit = runningOnHandlerAndReturn {
- serviceCache.removeService(serviceName, serviceType, socketKey) }
+ private fun removeService(serviceName: String, cacheKey: CacheKey): Unit =
+ runningOnHandlerAndReturn { serviceCache.removeService(serviceName, cacheKey) }
- private fun getService(serviceName: String, serviceType: String, socketKey: SocketKey):
- MdnsResponse? = runningOnHandlerAndReturn {
- serviceCache.getCachedService(serviceName, serviceType, socketKey) }
+ private fun getService(serviceName: String, cacheKey: CacheKey): MdnsResponse? =
+ runningOnHandlerAndReturn { serviceCache.getCachedService(serviceName, cacheKey) }
- private fun getServices(serviceType: String, socketKey: SocketKey): List<MdnsResponse> =
- runningOnHandlerAndReturn { serviceCache.getCachedServices(serviceType, socketKey) }
+ private fun getServices(cacheKey: CacheKey): List<MdnsResponse> =
+ runningOnHandlerAndReturn { serviceCache.getCachedServices(cacheKey) }
@Test
fun testAddAndRemoveService() {
- addOrUpdateService(
- SERVICE_TYPE_1, socketKey, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
- var response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, socketKey)
+ addOrUpdateService(cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
+ var response = getService(SERVICE_NAME_1, cacheKey1)
assertNotNull(response)
assertEquals(SERVICE_NAME_1, response.serviceInstanceName)
- removeService(SERVICE_NAME_1, SERVICE_TYPE_1, socketKey)
- response = getService(SERVICE_NAME_1, SERVICE_TYPE_1, socketKey)
+ removeService(SERVICE_NAME_1, cacheKey1)
+ response = getService(SERVICE_NAME_1, cacheKey1)
assertNull(response)
}
@Test
fun testGetCachedServices_multipleServiceTypes() {
- addOrUpdateService(
- SERVICE_TYPE_1, socketKey, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
- addOrUpdateService(
- SERVICE_TYPE_1, socketKey, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
- addOrUpdateService(
- SERVICE_TYPE_2, socketKey, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
+ addOrUpdateService(cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
+ addOrUpdateService(cacheKey1, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
+ addOrUpdateService(cacheKey2, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
- val responses1 = getServices(SERVICE_TYPE_1, socketKey)
+ val responses1 = getServices(cacheKey1)
assertEquals(2, responses1.size)
assertTrue(responses1.stream().anyMatch { response ->
response.serviceInstanceName == SERVICE_NAME_1
@@ -117,19 +109,19 @@
assertTrue(responses1.any { response ->
response.serviceInstanceName == SERVICE_NAME_2
})
- val responses2 = getServices(SERVICE_TYPE_2, socketKey)
+ val responses2 = getServices(cacheKey2)
assertEquals(1, responses2.size)
assertTrue(responses2.any { response ->
response.serviceInstanceName == SERVICE_NAME_2
})
- removeService(SERVICE_NAME_2, SERVICE_TYPE_1, socketKey)
- val responses3 = getServices(SERVICE_TYPE_1, socketKey)
+ removeService(SERVICE_NAME_2, cacheKey1)
+ val responses3 = getServices(cacheKey1)
assertEquals(1, responses3.size)
assertTrue(responses3.any { response ->
response.serviceInstanceName == SERVICE_NAME_1
})
- val responses4 = getServices(SERVICE_TYPE_2, socketKey)
+ val responses4 = getServices(cacheKey2)
assertEquals(1, responses4.size)
assertTrue(responses4.any { response ->
response.serviceInstanceName == SERVICE_NAME_2