Merge "Add expired services removal flag" into main
diff --git a/common/flags.aconfig b/common/flags.aconfig
index cadc44f..2e552a1 100644
--- a/common/flags.aconfig
+++ b/common/flags.aconfig
@@ -13,3 +13,10 @@
   description: "This flag controls the forbidden capability API"
   bug: "302997505"
 }
+
+flag {
+  name: "nsd_expired_services_removal"
+  namespace: "android_core_networking"
+  description: "Remove expired services from MdnsServiceCache"
+  bug: "304649384"
+}
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index b041260..cc3f019 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -1703,20 +1703,20 @@
         am.addOnUidImportanceListener(new UidImportanceListener(handler),
                 mRunningAppActiveImportanceCutoff);
 
+        final MdnsFeatureFlags flags = new MdnsFeatureFlags.Builder()
+                .setIsMdnsOffloadFeatureEnabled(mDeps.isTetheringFeatureNotChickenedOut(
+                        mContext, MdnsFeatureFlags.NSD_FORCE_DISABLE_MDNS_OFFLOAD))
+                .setIncludeInetAddressRecordsInProbing(mDeps.isFeatureEnabled(
+                        mContext, MdnsFeatureFlags.INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING))
+                .setIsExpiredServicesRemovalEnabled(mDeps.isTrunkStableFeatureEnabled(
+                        MdnsFeatureFlags.NSD_EXPIRED_SERVICES_REMOVAL))
+                .build();
         mMdnsSocketClient =
                 new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider,
                         LOGGER.forSubComponent("MdnsMultinetworkSocketClient"));
         mMdnsDiscoveryManager = deps.makeMdnsDiscoveryManager(new ExecutorProvider(),
-                mMdnsSocketClient, LOGGER.forSubComponent("MdnsDiscoveryManager"));
+                mMdnsSocketClient, LOGGER.forSubComponent("MdnsDiscoveryManager"), flags);
         handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager));
-        MdnsFeatureFlags flags = new MdnsFeatureFlags.Builder()
-                .setIsMdnsOffloadFeatureEnabled(
-                        mDeps.isTetheringFeatureNotChickenedOut(
-                                mContext, MdnsFeatureFlags.NSD_FORCE_DISABLE_MDNS_OFFLOAD))
-                .setIncludeInetAddressRecordsInProbing(
-                        mDeps.isFeatureEnabled(
-                                mContext, MdnsFeatureFlags.INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING))
-                .build();
         mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider,
                 new AdvertiserCallback(), LOGGER.forSubComponent("MdnsAdvertiser"), flags);
         mClock = deps.makeClock();
@@ -1774,12 +1774,21 @@
         }
 
         /**
+         * @see DeviceConfigUtils#isTrunkStableFeatureEnabled
+         */
+        public boolean isTrunkStableFeatureEnabled(String feature) {
+            return DeviceConfigUtils.isTrunkStableFeatureEnabled(feature);
+        }
+
+        /**
          * @see MdnsDiscoveryManager
          */
         public MdnsDiscoveryManager makeMdnsDiscoveryManager(
                 @NonNull ExecutorProvider executorProvider,
-                @NonNull MdnsMultinetworkSocketClient socketClient, @NonNull SharedLog sharedLog) {
-            return new MdnsDiscoveryManager(executorProvider, socketClient, sharedLog);
+                @NonNull MdnsMultinetworkSocketClient socketClient, @NonNull SharedLog sharedLog,
+                @NonNull MdnsFeatureFlags featureFlags) {
+            return new MdnsDiscoveryManager(
+                    executorProvider, socketClient, sharedLog, featureFlags);
         }
 
         /**
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index 24e9fa8..766f999 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -53,6 +53,7 @@
     @NonNull private final Handler handler;
     @Nullable private final HandlerThread handlerThread;
     @NonNull private final MdnsServiceCache serviceCache;
+    @NonNull private final MdnsFeatureFlags mdnsFeatureFlags;
 
     private static class PerSocketServiceTypeClients {
         private final ArrayMap<Pair<String, SocketKey>, MdnsServiceTypeClient> clients =
@@ -117,20 +118,22 @@
     }
 
     public MdnsDiscoveryManager(@NonNull ExecutorProvider executorProvider,
-            @NonNull MdnsSocketClientBase socketClient, @NonNull SharedLog sharedLog) {
+            @NonNull MdnsSocketClientBase socketClient, @NonNull SharedLog sharedLog,
+            @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
         this.executorProvider = executorProvider;
         this.socketClient = socketClient;
         this.sharedLog = sharedLog;
         this.perSocketServiceTypeClients = new PerSocketServiceTypeClients();
+        this.mdnsFeatureFlags = mdnsFeatureFlags;
         if (socketClient.getLooper() != null) {
             this.handlerThread = null;
             this.handler = new Handler(socketClient.getLooper());
-            this.serviceCache = new MdnsServiceCache(socketClient.getLooper());
+            this.serviceCache = new MdnsServiceCache(socketClient.getLooper(), mdnsFeatureFlags);
         } else {
             this.handlerThread = new HandlerThread(MdnsDiscoveryManager.class.getSimpleName());
             this.handlerThread.start();
             this.handler = new Handler(handlerThread.getLooper());
-            this.serviceCache = new MdnsServiceCache(handlerThread.getLooper());
+            this.serviceCache = new MdnsServiceCache(handlerThread.getLooper(), mdnsFeatureFlags);
         }
     }
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
index a6f7871..6f7645e 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
@@ -20,16 +20,21 @@
  */
 public class MdnsFeatureFlags {
     /**
-     * The feature flag for control whether the  mDNS offload is enabled or not.
+     * A feature flag to control whether the mDNS offload is enabled or not.
      */
     public static final String NSD_FORCE_DISABLE_MDNS_OFFLOAD = "nsd_force_disable_mdns_offload";
 
     /**
-     * The feature flag for controlling whether the probing question should include
+     * A feature flag to control whether the probing question should include
      * InetAddressRecords or not.
      */
     public static final String INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING =
             "include_inet_address_records_in_probing";
+    /**
+     * A feature flag to control whether expired services are removed from MdnsServiceCache.
+     */
+    public static final String NSD_EXPIRED_SERVICES_REMOVAL =
+            "nsd_expired_services_removal";
 
     // Flag for offload feature
     public final boolean mIsMdnsOffloadFeatureEnabled;
@@ -37,13 +42,17 @@
     // Flag for including InetAddressRecords in probing questions.
     public final boolean mIncludeInetAddressRecordsInProbing;
 
+    // Flag for expired services removal
+    public final boolean mIsExpiredServicesRemovalEnabled;
+
     /**
      * The constructor for {@link MdnsFeatureFlags}.
      */
     public MdnsFeatureFlags(boolean isOffloadFeatureEnabled,
-            boolean includeInetAddressRecordsInProbing) {
+            boolean includeInetAddressRecordsInProbing, boolean isExpiredServicesRemovalEnabled) {
         mIsMdnsOffloadFeatureEnabled = isOffloadFeatureEnabled;
         mIncludeInetAddressRecordsInProbing = includeInetAddressRecordsInProbing;
+        mIsExpiredServicesRemovalEnabled = isExpiredServicesRemovalEnabled;
     }
 
 
@@ -57,6 +66,7 @@
 
         private boolean mIsMdnsOffloadFeatureEnabled;
         private boolean mIncludeInetAddressRecordsInProbing;
+        private boolean mIsExpiredServicesRemovalEnabled;
 
         /**
          * The constructor for {@link Builder}.
@@ -64,10 +74,13 @@
         public Builder() {
             mIsMdnsOffloadFeatureEnabled = false;
             mIncludeInetAddressRecordsInProbing = false;
+            mIsExpiredServicesRemovalEnabled = true; // Default enabled.
         }
 
         /**
-         * Set if the mDNS offload  feature is enabled.
+         * Set whether the mDNS offload feature is enabled.
+         *
+         * @see #NSD_FORCE_DISABLE_MDNS_OFFLOAD
          */
         public Builder setIsMdnsOffloadFeatureEnabled(boolean isMdnsOffloadFeatureEnabled) {
             mIsMdnsOffloadFeatureEnabled = isMdnsOffloadFeatureEnabled;
@@ -75,7 +88,9 @@
         }
 
         /**
-         * Set if the probing question should include InetAddressRecords.
+         * Set whether the probing question should include InetAddressRecords.
+         *
+         * @see #INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING
          */
         public Builder setIncludeInetAddressRecordsInProbing(
                 boolean includeInetAddressRecordsInProbing) {
@@ -84,12 +99,21 @@
         }
 
         /**
+         * Set whether expired services removal is enabled.
+         *
+         * @see #NSD_EXPIRED_SERVICES_REMOVAL
+         */
+        public Builder setIsExpiredServicesRemovalEnabled(boolean isExpiredServicesRemovalEnabled) {
+            mIsExpiredServicesRemovalEnabled = isExpiredServicesRemovalEnabled;
+            return this;
+        }
+
+        /**
          * Builds a {@link MdnsFeatureFlags} with the arguments supplied to this builder.
          */
         public MdnsFeatureFlags build() {
-            return new MdnsFeatureFlags(
-                    mIsMdnsOffloadFeatureEnabled, mIncludeInetAddressRecordsInProbing);
+            return new MdnsFeatureFlags(mIsMdnsOffloadFeatureEnabled,
+                    mIncludeInetAddressRecordsInProbing, mIsExpiredServicesRemovalEnabled);
         }
-
     }
 }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
index f9ee0df..d3493c7 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceCache.java
@@ -80,9 +80,12 @@
     private final ArrayMap<CacheKey, ServiceExpiredCallback> mCallbacks = new ArrayMap<>();
     @NonNull
     private final Handler mHandler;
+    @NonNull
+    private final MdnsFeatureFlags mMdnsFeatureFlags;
 
-    public MdnsServiceCache(@NonNull Looper looper) {
+    public MdnsServiceCache(@NonNull Looper looper, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
         mHandler = new Handler(looper);
+        mMdnsFeatureFlags = mdnsFeatureFlags;
     }
 
     /**
diff --git a/tests/unit/java/com/android/server/NsdServiceTest.java b/tests/unit/java/com/android/server/NsdServiceTest.java
index 71bd330..771edb2 100644
--- a/tests/unit/java/com/android/server/NsdServiceTest.java
+++ b/tests/unit/java/com/android/server/NsdServiceTest.java
@@ -35,14 +35,17 @@
 import static android.net.nsd.NsdManager.FAILURE_BAD_PARAMETERS;
 import static android.net.nsd.NsdManager.FAILURE_INTERNAL_ERROR;
 import static android.net.nsd.NsdManager.FAILURE_OPERATION_NOT_RUNNING;
+
 import static com.android.networkstack.apishim.api33.ConstantsShim.REGISTER_NSD_OFFLOAD_ENGINE;
 import static com.android.server.NsdService.DEFAULT_RUNNING_APP_ACTIVE_IMPORTANCE_CUTOFF;
 import static com.android.server.NsdService.MdnsListener;
 import static com.android.server.NsdService.NO_TRANSACTION;
 import static com.android.server.NsdService.parseTypeAndSubtype;
 import static com.android.testutils.ContextUtils.mockService;
+
 import static libcore.junit.util.compat.CoreCompatChangeRule.DisableCompatChanges;
 import static libcore.junit.util.compat.CoreCompatChangeRule.EnableCompatChanges;
+
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
@@ -220,7 +223,7 @@
                 anyInt(), anyString(), anyString(), anyString(), anyInt());
         doReturn(false).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class));
         doReturn(mDiscoveryManager).when(mDeps)
-                .makeMdnsDiscoveryManager(any(), any(), any());
+                .makeMdnsDiscoveryManager(any(), any(), any(), any());
         doReturn(mMulticastLock).when(mWifiManager).createMulticastLock(any());
         doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any(), any(), any());
         doReturn(DEFAULT_RUNNING_APP_ACTIVE_IMPORTANCE_CUTOFF).when(mDeps).getDeviceConfigInt(
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
index e869b91..331a5b6 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
@@ -106,7 +106,7 @@
         doReturn(thread.getLooper()).when(socketClient).getLooper();
         doReturn(true).when(socketClient).supportsRequestingSpecificNetworks();
         discoveryManager = new MdnsDiscoveryManager(executorProvider, socketClient,
-                sharedLog) {
+                sharedLog, MdnsFeatureFlags.newBuilder().build()) {
                     @Override
                     MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType,
                             @NonNull SocketKey socketKey) {
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
index 1b6f120..2b3b834 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceCacheTest.kt
@@ -50,9 +50,6 @@
     private val handler by lazy {
         Handler(thread.looper)
     }
-    private val serviceCache by lazy {
-        MdnsServiceCache(thread.looper)
-    }
 
     @Before
     fun setUp() {
@@ -64,6 +61,11 @@
         thread.quitSafely()
     }
 
+    private fun makeFlags(isExpiredServicesRemovalEnabled: Boolean = false) =
+            MdnsFeatureFlags.Builder()
+                    .setIsExpiredServicesRemovalEnabled(isExpiredServicesRemovalEnabled)
+                    .build()
+
     private fun <T> runningOnHandlerAndReturn(functor: (() -> T)): T {
         val future = CompletableFuture<T>()
         handler.post {
@@ -72,36 +74,51 @@
         return future.get(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS)
     }
 
-    private fun addOrUpdateService(cacheKey: CacheKey, service: MdnsResponse): Unit =
-        runningOnHandlerAndReturn { serviceCache.addOrUpdateService(cacheKey, service) }
+    private fun addOrUpdateService(
+            serviceCache: MdnsServiceCache,
+            cacheKey: CacheKey,
+            service: MdnsResponse
+    ): Unit = runningOnHandlerAndReturn { serviceCache.addOrUpdateService(cacheKey, service) }
 
-    private fun removeService(serviceName: String, cacheKey: CacheKey): Unit =
-        runningOnHandlerAndReturn { serviceCache.removeService(serviceName, cacheKey) }
+    private fun removeService(
+            serviceCache: MdnsServiceCache,
+            serviceName: String,
+            cacheKey: CacheKey
+    ): Unit = runningOnHandlerAndReturn { serviceCache.removeService(serviceName, cacheKey) }
 
-    private fun getService(serviceName: String, cacheKey: CacheKey): MdnsResponse? =
-        runningOnHandlerAndReturn { serviceCache.getCachedService(serviceName, cacheKey) }
+    private fun getService(
+            serviceCache: MdnsServiceCache,
+            serviceName: String,
+            cacheKey: CacheKey,
+    ): MdnsResponse? = runningOnHandlerAndReturn {
+        serviceCache.getCachedService(serviceName, cacheKey)
+    }
 
-    private fun getServices(cacheKey: CacheKey): List<MdnsResponse> =
-        runningOnHandlerAndReturn { serviceCache.getCachedServices(cacheKey) }
+    private fun getServices(
+            serviceCache: MdnsServiceCache,
+            cacheKey: CacheKey,
+    ): List<MdnsResponse> = runningOnHandlerAndReturn { serviceCache.getCachedServices(cacheKey) }
 
     @Test
     fun testAddAndRemoveService() {
-        addOrUpdateService(cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
-        var response = getService(SERVICE_NAME_1, cacheKey1)
+        val serviceCache = MdnsServiceCache(thread.looper, makeFlags())
+        addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
+        var response = getService(serviceCache, SERVICE_NAME_1, cacheKey1)
         assertNotNull(response)
         assertEquals(SERVICE_NAME_1, response.serviceInstanceName)
-        removeService(SERVICE_NAME_1, cacheKey1)
-        response = getService(SERVICE_NAME_1, cacheKey1)
+        removeService(serviceCache, SERVICE_NAME_1, cacheKey1)
+        response = getService(serviceCache, SERVICE_NAME_1, cacheKey1)
         assertNull(response)
     }
 
     @Test
     fun testGetCachedServices_multipleServiceTypes() {
-        addOrUpdateService(cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
-        addOrUpdateService(cacheKey1, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
-        addOrUpdateService(cacheKey2, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
+        val serviceCache = MdnsServiceCache(thread.looper, makeFlags())
+        addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_1, SERVICE_TYPE_1))
+        addOrUpdateService(serviceCache, cacheKey1, createResponse(SERVICE_NAME_2, SERVICE_TYPE_1))
+        addOrUpdateService(serviceCache, cacheKey2, createResponse(SERVICE_NAME_2, SERVICE_TYPE_2))
 
-        val responses1 = getServices(cacheKey1)
+        val responses1 = getServices(serviceCache, cacheKey1)
         assertEquals(2, responses1.size)
         assertTrue(responses1.stream().anyMatch { response ->
             response.serviceInstanceName == SERVICE_NAME_1
@@ -109,19 +126,19 @@
         assertTrue(responses1.any { response ->
             response.serviceInstanceName == SERVICE_NAME_2
         })
-        val responses2 = getServices(cacheKey2)
+        val responses2 = getServices(serviceCache, cacheKey2)
         assertEquals(1, responses2.size)
         assertTrue(responses2.any { response ->
             response.serviceInstanceName == SERVICE_NAME_2
         })
 
-        removeService(SERVICE_NAME_2, cacheKey1)
-        val responses3 = getServices(cacheKey1)
+        removeService(serviceCache, SERVICE_NAME_2, cacheKey1)
+        val responses3 = getServices(serviceCache, cacheKey1)
         assertEquals(1, responses3.size)
         assertTrue(responses3.any { response ->
             response.serviceInstanceName == SERVICE_NAME_1
         })
-        val responses4 = getServices(cacheKey2)
+        val responses4 = getServices(serviceCache, cacheKey2)
         assertEquals(1, responses4.size)
         assertTrue(responses4.any { response ->
             response.serviceInstanceName == SERVICE_NAME_2
@@ -129,6 +146,6 @@
     }
 
     private fun createResponse(serviceInstanceName: String, serviceType: String) = MdnsResponse(
-        0 /* now */, "$serviceInstanceName.$serviceType".split(".").toTypedArray(),
+            0 /* now */, "$serviceInstanceName.$serviceType".split(".").toTypedArray(),
             socketKey.interfaceIndex, socketKey.network)
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index fde5abd..ce154dd 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -193,7 +193,8 @@
         thread = new HandlerThread("MdnsServiceTypeClientTests");
         thread.start();
         handler = new Handler(thread.getLooper());
-        serviceCache = new MdnsServiceCache(thread.getLooper());
+        serviceCache = new MdnsServiceCache(
+                thread.getLooper(), MdnsFeatureFlags.newBuilder().build());
 
         doAnswer(inv -> {
             latestDelayMs = 0;