Take the multicast lock on mDNS usage

When an mDNS request (discovery, advertising, resolving...) is
registered and gets assigned a socket on a wifi network, take the
multicast lock to ensure that it can reliably receive mDNS responses.

This is limited to when the application has importance
FOREGROUND_SERVICE or higher.
NsdManager is not documented to require usage of the multicast lock,
which has caused various reports about its reliability. Taking the lock
while the app is in the foreground should address the large majority of
cases, while limiting battery impact.

Going forward this should allow developers on U+ to not take the
lock themselves, allowing optimizations on devices supporting APF,
where instead of taking the lock APF would let through only select
mDNS packets.

Bug: 284389438
Test: atest
Change-Id: I1ce85220eac4a1529b6716d50727c1c462356846
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 47a1022..25aa693 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -17,6 +17,8 @@
 package com.android.server;
 
 import static android.net.ConnectivityManager.NETID_UNSET;
+import static android.net.NetworkCapabilities.TRANSPORT_VPN;
+import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
 import static android.net.nsd.NsdManager.MDNS_DISCOVERY_MANAGER_EVENT;
 import static android.net.nsd.NsdManager.MDNS_SERVICE_EVENT;
 import static android.net.nsd.NsdManager.RESOLVE_SERVICE_SUCCEEDED;
@@ -27,6 +29,7 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.app.ActivityManager;
 import android.content.Context;
 import android.content.Intent;
 import android.net.ConnectivityManager;
@@ -45,6 +48,7 @@
 import android.net.nsd.MDnsManager;
 import android.net.nsd.NsdManager;
 import android.net.nsd.NsdServiceInfo;
+import android.net.wifi.WifiManager;
 import android.os.Binder;
 import android.os.Handler;
 import android.os.HandlerThread;
@@ -53,7 +57,9 @@
 import android.os.Message;
 import android.os.RemoteException;
 import android.os.UserHandle;
+import android.provider.DeviceConfig;
 import android.text.TextUtils;
+import android.util.ArraySet;
 import android.util.Log;
 import android.util.Pair;
 import android.util.SparseArray;
@@ -62,6 +68,7 @@
 import com.android.internal.util.IndentingPrintWriter;
 import com.android.internal.util.State;
 import com.android.internal.util.StateMachine;
+import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.DeviceConfigUtils;
 import com.android.net.module.util.InetAddressUtils;
 import com.android.net.module.util.PermissionUtils;
@@ -69,6 +76,7 @@
 import com.android.server.connectivity.mdns.ExecutorProvider;
 import com.android.server.connectivity.mdns.MdnsAdvertiser;
 import com.android.server.connectivity.mdns.MdnsDiscoveryManager;
+import com.android.server.connectivity.mdns.MdnsInterfaceSocket;
 import com.android.server.connectivity.mdns.MdnsMultinetworkSocketClient;
 import com.android.server.connectivity.mdns.MdnsSearchOptions;
 import com.android.server.connectivity.mdns.MdnsServiceBrowserListener;
@@ -141,6 +149,14 @@
             "mdns_advertiser_allowlist_";
     private static final String MDNS_ALLOWLIST_FLAG_SUFFIX = "_version";
 
+    @VisibleForTesting
+    static final String MDNS_CONFIG_RUNNING_APP_ACTIVE_IMPORTANCE_CUTOFF =
+            "mdns_config_running_app_active_importance_cutoff";
+    @VisibleForTesting
+    static final int DEFAULT_RUNNING_APP_ACTIVE_IMPORTANCE_CUTOFF =
+            ActivityManager.RunningAppProcessInfo.IMPORTANCE_FOREGROUND;
+    private final int mRunningAppActiveImportanceCutoff;
+
     public static final boolean DBG = Log.isLoggable(TAG, Log.DEBUG);
     private static final long CLEANUP_DELAY_MS = 10000;
     private static final int IFACE_IDX_ANY = 0;
@@ -175,6 +191,16 @@
     /* A map from unique id to client info */
     private final SparseArray<ClientInfo> mIdToClientInfoMap= new SparseArray<>();
 
+    // Note this is not final to avoid depending on the Wi-Fi service starting before NsdService
+    @Nullable
+    private WifiManager.MulticastLock mHeldMulticastLock;
+    // Fulfilled network requests that require the Wi-Fi lock: key is the obtained Network
+    // (non-null), value is the requested Network (nullable)
+    @NonNull
+    private final ArraySet<Network> mWifiLockRequiredNetworks = new ArraySet<>();
+    @NonNull
+    private final ArraySet<Integer> mRunningAppActiveUids = new ArraySet<>();
+
     private final long mCleanupDelayMs;
 
     private static final int INVALID_ID = 0;
@@ -299,6 +325,104 @@
         }
     }
 
+    private class SocketRequestMonitor implements MdnsSocketProvider.SocketRequestMonitor {
+        @Override
+        public void onSocketRequestFulfilled(@Nullable Network socketNetwork,
+                @NonNull MdnsInterfaceSocket socket, @NonNull int[] transports) {
+            // The network may be null for Wi-Fi SoftAp interfaces (tethering), but there is no APF
+            // filtering on such interfaces, so taking the multicast lock is not necessary to
+            // disable APF filtering of multicast.
+            if (socketNetwork == null
+                    || !CollectionUtils.contains(transports, TRANSPORT_WIFI)
+                    || CollectionUtils.contains(transports, TRANSPORT_VPN)) {
+                return;
+            }
+
+            if (mWifiLockRequiredNetworks.add(socketNetwork)) {
+                updateMulticastLock();
+            }
+        }
+
+        @Override
+        public void onSocketDestroyed(@Nullable Network socketNetwork,
+                @NonNull MdnsInterfaceSocket socket) {
+            if (mWifiLockRequiredNetworks.remove(socketNetwork)) {
+                updateMulticastLock();
+            }
+        }
+    }
+
+    private class UidImportanceListener implements ActivityManager.OnUidImportanceListener {
+        private final Handler mHandler;
+
+        private UidImportanceListener(Handler handler) {
+            mHandler = handler;
+        }
+
+        @Override
+        public void onUidImportance(int uid, int importance) {
+            mHandler.post(() -> handleUidImportanceChanged(uid, importance));
+        }
+    }
+
+    private void handleUidImportanceChanged(int uid, int importance) {
+        // Lower importance values are more "important"
+        final boolean modified = importance <= mRunningAppActiveImportanceCutoff
+                ? mRunningAppActiveUids.add(uid)
+                : mRunningAppActiveUids.remove(uid);
+        if (modified) {
+            updateMulticastLock();
+        }
+    }
+
+    /**
+     * Take or release the lock based on updated internal state.
+     *
+     * This determines whether the lock needs to be held based on
+     * {@link #mWifiLockRequiredNetworks}, {@link #mRunningAppActiveUids} and
+     * {@link ClientInfo#mClientRequests}, so it must be called after any of the these have been
+     * updated.
+     */
+    private void updateMulticastLock() {
+        final int needsLockUid = getMulticastLockNeededUid();
+        if (needsLockUid >= 0 && mHeldMulticastLock == null) {
+            final WifiManager wm = mContext.getSystemService(WifiManager.class);
+            if (wm == null) {
+                Log.wtf(TAG, "Got a TRANSPORT_WIFI network without WifiManager");
+                return;
+            }
+            mHeldMulticastLock = wm.createMulticastLock(TAG);
+            mHeldMulticastLock.acquire();
+            mServiceLogs.log("Taking multicast lock for uid " + needsLockUid);
+        } else if (needsLockUid < 0 && mHeldMulticastLock != null) {
+            mHeldMulticastLock.release();
+            mHeldMulticastLock = null;
+            mServiceLogs.log("Released multicast lock");
+        }
+    }
+
+    /**
+     * @return The UID of an app requiring the multicast lock, or -1 if none.
+     */
+    private int getMulticastLockNeededUid() {
+        if (mWifiLockRequiredNetworks.size() == 0) {
+            // Return early if NSD is not active, or not on any relevant network
+            return -1;
+        }
+        for (int i = 0; i < mIdToClientInfoMap.size(); i++) {
+            final ClientInfo clientInfo = mIdToClientInfoMap.valueAt(i);
+            if (!mRunningAppActiveUids.contains(clientInfo.mUid)) {
+                // Ignore non-active UIDs
+                continue;
+            }
+
+            if (clientInfo.hasAnyJavaBackendRequestForNetworks(mWifiLockRequiredNetworks)) {
+                return clientInfo.mUid;
+            }
+        }
+        return -1;
+    }
+
     /**
      * Data class of mdns service callback information.
      */
@@ -404,7 +528,7 @@
                         try {
                             cb.asBinder().linkToDeath(arg.connector, 0);
                             final String tag = "Client" + arg.uid + "-" + mClientNumberId++;
-                            cInfo = new ClientInfo(cb, arg.useJavaBackend,
+                            cInfo = new ClientInfo(cb, arg.uid, arg.useJavaBackend,
                                     mServiceLogs.forSubComponent(tag));
                             mClients.put(arg.connector, cInfo);
                         } catch (RemoteException e) {
@@ -529,9 +653,11 @@
             }
 
             private void storeAdvertiserRequestMap(int clientId, int globalId,
-                    ClientInfo clientInfo) {
-                clientInfo.mClientRequests.put(clientId, new AdvertiserClientRequest(globalId));
+                    ClientInfo clientInfo, @Nullable Network requestedNetwork) {
+                clientInfo.mClientRequests.put(clientId,
+                        new AdvertiserClientRequest(globalId, requestedNetwork));
                 mIdToClientInfoMap.put(globalId, clientInfo);
+                updateMulticastLock();
             }
 
             private void removeRequestMap(int clientId, int globalId, ClientInfo clientInfo) {
@@ -544,14 +670,17 @@
                     maybeScheduleStop();
                 } else {
                     maybeStopMonitoringSocketsIfNoActiveRequest();
+                    updateMulticastLock();
                 }
             }
 
             private void storeDiscoveryManagerRequestMap(int clientId, int globalId,
-                    MdnsListener listener, ClientInfo clientInfo) {
+                    MdnsListener listener, ClientInfo clientInfo,
+                    @Nullable Network requestedNetwork) {
                 clientInfo.mClientRequests.put(clientId,
-                        new DiscoveryManagerRequest(globalId, listener));
+                        new DiscoveryManagerRequest(globalId, listener, requestedNetwork));
                 mIdToClientInfoMap.put(globalId, clientInfo);
+                updateMulticastLock();
             }
 
             /**
@@ -628,7 +757,8 @@
                             }
                             mMdnsDiscoveryManager.registerListener(
                                     listenServiceType, listener, optionsBuilder.build());
-                            storeDiscoveryManagerRequestMap(clientId, id, listener, clientInfo);
+                            storeDiscoveryManagerRequestMap(clientId, id, listener, clientInfo,
+                                    info.getNetwork());
                             clientInfo.onDiscoverServicesStarted(clientId, info);
                             clientInfo.log("Register a DiscoveryListener " + id
                                     + " for service type:" + listenServiceType);
@@ -728,7 +858,8 @@
                             // Name._subtype._sub._type._tcp, which is incorrect
                             // (it should be Name._type._tcp).
                             mAdvertiser.addService(id, serviceInfo, typeSubtype.second);
-                            storeAdvertiserRequestMap(clientId, id, clientInfo);
+                            storeAdvertiserRequestMap(clientId, id, clientInfo,
+                                    serviceInfo.getNetwork());
                         } else {
                             maybeStartDaemon();
                             if (registerService(id, serviceInfo)) {
@@ -818,7 +949,8 @@
                                     .build();
                             mMdnsDiscoveryManager.registerListener(
                                     resolveServiceType, listener, options);
-                            storeDiscoveryManagerRequestMap(clientId, id, listener, clientInfo);
+                            storeDiscoveryManagerRequestMap(clientId, id, listener, clientInfo,
+                                    info.getNetwork());
                             clientInfo.log("Register a ResolutionListener " + id
                                     + " for service type:" + resolveServiceType);
                         } else {
@@ -912,7 +1044,8 @@
                                 .build();
                         mMdnsDiscoveryManager.registerListener(
                                 resolveServiceType, listener, options);
-                        storeDiscoveryManagerRequestMap(clientId, id, listener, clientInfo);
+                        storeDiscoveryManagerRequestMap(clientId, id, listener, clientInfo,
+                                info.getNetwork());
                         clientInfo.log("Register a ServiceInfoListener " + id
                                 + " for service type:" + resolveServiceType);
                         break;
@@ -1389,10 +1522,20 @@
         mDeps = deps;
 
         mMdnsSocketProvider = deps.makeMdnsSocketProvider(ctx, handler.getLooper(),
-                LOGGER.forSubComponent("MdnsSocketProvider"));
+                LOGGER.forSubComponent("MdnsSocketProvider"), new SocketRequestMonitor());
         // Netlink monitor starts on boot, and intentionally never stopped, to ensure that all
         // address events are received.
         handler.post(mMdnsSocketProvider::startNetLinkMonitor);
+
+        // NsdService is started after ActivityManager (startOtherServices in SystemServer, vs.
+        // startBootstrapServices).
+        mRunningAppActiveImportanceCutoff = mDeps.getDeviceConfigInt(
+                MDNS_CONFIG_RUNNING_APP_ACTIVE_IMPORTANCE_CUTOFF,
+                DEFAULT_RUNNING_APP_ACTIVE_IMPORTANCE_CUTOFF);
+        final ActivityManager am = ctx.getSystemService(ActivityManager.class);
+        am.addOnUidImportanceListener(new UidImportanceListener(handler),
+                mRunningAppActiveImportanceCutoff);
+
         mMdnsSocketClient =
                 new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider);
         mMdnsDiscoveryManager = deps.makeMdnsDiscoveryManager(new ExecutorProvider(),
@@ -1471,8 +1614,23 @@
          * @see MdnsSocketProvider
          */
         public MdnsSocketProvider makeMdnsSocketProvider(@NonNull Context context,
-                @NonNull Looper looper, @NonNull SharedLog sharedLog) {
-            return new MdnsSocketProvider(context, looper, sharedLog);
+                @NonNull Looper looper, @NonNull SharedLog sharedLog,
+                @NonNull MdnsSocketProvider.SocketRequestMonitor socketCreationCallback) {
+            return new MdnsSocketProvider(context, looper, sharedLog, socketCreationCallback);
+        }
+
+        /**
+         * @see DeviceConfig#getInt(String, String, int)
+         */
+        public int getDeviceConfigInt(@NonNull String config, int defaultValue) {
+            return DeviceConfig.getInt(NAMESPACE_TETHERING, config, defaultValue);
+        }
+
+        /**
+         * @see Binder#getCallingUid()
+         */
+        public int getCallingUid() {
+            return Binder.getCallingUid();
         }
     }
 
@@ -1626,7 +1784,7 @@
         final INsdServiceConnector connector = new NsdServiceConnector();
         mNsdStateMachine.sendMessage(mNsdStateMachine.obtainMessage(NsdManager.REGISTER_CLIENT,
                 new ConnectorArgs((NsdServiceConnector) connector, cb, useJavaBackend,
-                        Binder.getCallingUid())));
+                        mDeps.getCallingUid())));
         return connector;
     }
 
@@ -1857,18 +2015,34 @@
         }
     }
 
-    private static class AdvertiserClientRequest extends ClientRequest {
-        private AdvertiserClientRequest(int globalId) {
+    private abstract static class JavaBackendClientRequest extends ClientRequest {
+        @Nullable
+        private final Network mRequestedNetwork;
+
+        private JavaBackendClientRequest(int globalId, @Nullable Network requestedNetwork) {
             super(globalId);
+            mRequestedNetwork = requestedNetwork;
+        }
+
+        @Nullable
+        public Network getRequestedNetwork() {
+            return mRequestedNetwork;
         }
     }
 
-    private static class DiscoveryManagerRequest extends ClientRequest {
+    private static class AdvertiserClientRequest extends JavaBackendClientRequest {
+        private AdvertiserClientRequest(int globalId, @Nullable Network requestedNetwork) {
+            super(globalId, requestedNetwork);
+        }
+    }
+
+    private static class DiscoveryManagerRequest extends JavaBackendClientRequest {
         @NonNull
         private final MdnsListener mListener;
 
-        private DiscoveryManagerRequest(int globalId, @NonNull MdnsListener listener) {
-            super(globalId);
+        private DiscoveryManagerRequest(int globalId, @NonNull MdnsListener listener,
+                @Nullable Network requestedNetwork) {
+            super(globalId, requestedNetwork);
             mListener = listener;
         }
     }
@@ -1886,13 +2060,16 @@
 
         // The target SDK of this client < Build.VERSION_CODES.S
         private boolean mIsPreSClient = false;
+        private final int mUid;
         // The flag of using java backend if the client's target SDK >= U
         private final boolean mUseJavaBackend;
         // Store client logs
         private final SharedLog mClientLogs;
 
-        private ClientInfo(INsdManagerCallback cb, boolean useJavaBackend, SharedLog sharedLog) {
+        private ClientInfo(INsdManagerCallback cb, int uid, boolean useJavaBackend,
+                SharedLog sharedLog) {
             mCb = cb;
+            mUid = uid;
             mUseJavaBackend = useJavaBackend;
             mClientLogs = sharedLog;
             mClientLogs.log("New client. useJavaBackend=" + useJavaBackend);
@@ -1903,6 +2080,8 @@
             StringBuilder sb = new StringBuilder();
             sb.append("mResolvedService ").append(mResolvedService).append("\n");
             sb.append("mIsLegacy ").append(mIsPreSClient).append("\n");
+            sb.append("mUseJavaBackend ").append(mUseJavaBackend).append("\n");
+            sb.append("mUid ").append(mUid).append("\n");
             for (int i = 0; i < mClientRequests.size(); i++) {
                 int clientID = mClientRequests.keyAt(i);
                 sb.append("clientId ")
@@ -1974,6 +2153,26 @@
                 }
             }
             mClientRequests.clear();
+            updateMulticastLock();
+        }
+
+        /**
+         * Returns true if this client has any Java backend request that requests one of the given
+         * networks.
+         */
+        boolean hasAnyJavaBackendRequestForNetworks(@NonNull ArraySet<Network> networks) {
+            for (int i = 0; i < mClientRequests.size(); i++) {
+                final ClientRequest req = mClientRequests.valueAt(i);
+                if (!(req instanceof JavaBackendClientRequest)) {
+                    continue;
+                }
+                final Network reqNetwork = ((JavaBackendClientRequest) mClientRequests.valueAt(i))
+                        .getRequestedNetwork();
+                if (MdnsUtils.isAnyNetworkMatched(reqNetwork, networks)) {
+                    return true;
+                }
+            }
+            return false;
         }
 
         // mClientRequests is a sparse array of listener id -> ClientRequest.  For a given
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
index dc09bef..d90f67f 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
@@ -97,6 +97,8 @@
     // the netlink monitor is never stop and the old states must be kept.
     private final SparseArray<LinkProperties> mIfaceIdxToLinkProperties = new SparseArray<>();
     private final byte[] mPacketReadBuffer = new byte[READ_BUFFER_SIZE];
+    @NonNull
+    private final SocketRequestMonitor mSocketRequestMonitor;
     private boolean mMonitoringSockets = false;
     private boolean mRequestStop = false;
     private String mWifiP2pTetherInterface = null;
@@ -155,17 +157,20 @@
     }
 
     public MdnsSocketProvider(@NonNull Context context, @NonNull Looper looper,
-            @NonNull SharedLog sharedLog) {
-        this(context, looper, new Dependencies(), sharedLog);
+            @NonNull SharedLog sharedLog,
+            @NonNull SocketRequestMonitor socketRequestMonitor) {
+        this(context, looper, new Dependencies(), sharedLog, socketRequestMonitor);
     }
 
     MdnsSocketProvider(@NonNull Context context, @NonNull Looper looper,
-            @NonNull Dependencies deps, @NonNull SharedLog sharedLog) {
+            @NonNull Dependencies deps, @NonNull SharedLog sharedLog,
+            @NonNull SocketRequestMonitor socketRequestMonitor) {
         mContext = context;
         mLooper = looper;
         mHandler = new Handler(looper);
         mDependencies = deps;
         mSharedLog = sharedLog;
+        mSocketRequestMonitor = socketRequestMonitor;
         mNetworkCallback = new NetworkCallback() {
             @Override
             public void onLost(Network network) {
@@ -312,10 +317,12 @@
     private static class SocketInfo {
         final MdnsInterfaceSocket mSocket;
         final List<LinkAddress> mAddresses;
+        final int[] mTransports;
 
-        SocketInfo(MdnsInterfaceSocket socket, List<LinkAddress> addresses) {
+        SocketInfo(MdnsInterfaceSocket socket, List<LinkAddress> addresses, int[] transports) {
             mSocket = socket;
             mAddresses = new ArrayList<>(addresses);
+            mTransports = transports;
         }
     }
 
@@ -498,9 +505,13 @@
             if (networkKey == LOCAL_NET) {
                 transports = new int[0];
             } else {
-                transports = mActiveNetworksTransports.get(((NetworkAsKey) networkKey).mNetwork);
-                if (transports == null) {
+                final int[] knownTransports =
+                        mActiveNetworksTransports.get(((NetworkAsKey) networkKey).mNetwork);
+                if (knownTransports != null) {
+                    transports = knownTransports;
+                } else {
                     Log.wtf(TAG, "transports is missing for key: " + networkKey);
+                    transports = new int[0];
                 }
             }
             if (networkInterface == null || !isMdnsCapableInterface(networkInterface, transports)) {
@@ -512,21 +523,22 @@
                     networkInterface.getNetworkInterface(), MdnsConstants.MDNS_PORT, mLooper,
                     mPacketReadBuffer);
             final List<LinkAddress> addresses = lp.getLinkAddresses();
+            // TODO: technically transport types are mutable, although generally not in ways that
+            // would meaningfully impact the logic using it here. Consider updating logic to
+            // support transports being added/removed.
+            final SocketInfo socketInfo = new SocketInfo(socket, addresses, transports);
             if (networkKey == LOCAL_NET) {
-                mTetherInterfaceSockets.put(interfaceName, new SocketInfo(socket, addresses));
+                mTetherInterfaceSockets.put(interfaceName, socketInfo);
             } else {
-                mNetworkSockets.put(((NetworkAsKey) networkKey).mNetwork,
-                        new SocketInfo(socket, addresses));
+                mNetworkSockets.put(((NetworkAsKey) networkKey).mNetwork, socketInfo);
             }
             // Try to join IPv4/IPv6 group.
             socket.joinGroup(addresses);
 
             // Notify the listeners which need this socket.
-            if (networkKey == LOCAL_NET) {
-                notifySocketCreated(null /* network */, socket, addresses);
-            } else {
-                notifySocketCreated(((NetworkAsKey) networkKey).mNetwork, socket, addresses);
-            }
+            final Network network =
+                    networkKey == LOCAL_NET ? null : ((NetworkAsKey) networkKey).mNetwork;
+            notifySocketCreated(network, socketInfo);
         } catch (IOException e) {
             mSharedLog.e("Create socket failed ifName:" + interfaceName, e);
         }
@@ -568,6 +580,7 @@
 
         socketInfo.mSocket.destroy();
         notifyInterfaceDestroyed(network, socketInfo.mSocket);
+        mSocketRequestMonitor.onSocketDestroyed(network, socketInfo.mSocket);
         mSharedLog.log("Remove socket on net:" + network);
     }
 
@@ -576,15 +589,18 @@
         if (socketInfo == null) return;
         socketInfo.mSocket.destroy();
         notifyInterfaceDestroyed(null /* network */, socketInfo.mSocket);
+        mSocketRequestMonitor.onSocketDestroyed(null /* network */, socketInfo.mSocket);
         mSharedLog.log("Remove socket on ifName:" + interfaceName);
     }
 
-    private void notifySocketCreated(Network network, MdnsInterfaceSocket socket,
-            List<LinkAddress> addresses) {
+    private void notifySocketCreated(Network network, SocketInfo socketInfo) {
         for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
             final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
             if (isNetworkMatched(requestedNetwork, network)) {
-                mCallbacksToRequestedNetworks.keyAt(i).onSocketCreated(network, socket, addresses);
+                mCallbacksToRequestedNetworks.keyAt(i).onSocketCreated(network, socketInfo.mSocket,
+                        socketInfo.mAddresses);
+                mSocketRequestMonitor.onSocketRequestFulfilled(network, socketInfo.mSocket,
+                        socketInfo.mTransports);
             }
         }
     }
@@ -622,6 +638,8 @@
         } else {
             // Notify the socket for requested network.
             cb.onSocketCreated(network, socketInfo.mSocket, socketInfo.mAddresses);
+            mSocketRequestMonitor.onSocketRequestFulfilled(network, socketInfo.mSocket,
+                    socketInfo.mTransports);
         }
     }
 
@@ -636,6 +654,8 @@
             // Notify the socket for requested network.
             cb.onSocketCreated(
                     null /* network */, socketInfo.mSocket, socketInfo.mAddresses);
+            mSocketRequestMonitor.onSocketRequestFulfilled(null /* socketNetwork */,
+                    socketInfo.mSocket, socketInfo.mTransports);
         }
     }
 
@@ -690,6 +710,7 @@
             if (matchRequestedNetwork(network)) continue;
             final SocketInfo info = mNetworkSockets.removeAt(i);
             info.mSocket.destroy();
+            mSocketRequestMonitor.onSocketDestroyed(network, info.mSocket);
             mSharedLog.log("Remove socket on net:" + network + " after unrequestSocket");
         }
 
@@ -699,6 +720,7 @@
         for (int i = mTetherInterfaceSockets.size() - 1; i >= 0; i--) {
             final SocketInfo info = mTetherInterfaceSockets.valueAt(i);
             info.mSocket.destroy();
+            mSocketRequestMonitor.onSocketDestroyed(null /* network */, info.mSocket);
             mSharedLog.log("Remove socket on ifName:" + mTetherInterfaceSockets.keyAt(i)
                     + " after unrequestSocket");
         }
@@ -709,19 +731,61 @@
     }
 
 
-    /*** Callbacks for listening socket changes */
+    /**
+     * Callback used to register socket requests.
+     */
     public interface SocketCallback {
-        /*** Notify the socket is created */
+        /**
+         * Notify the socket was created for the registered request.
+         *
+         * This may be called immediately when the request is registered with an existing socket,
+         * if it had been created previously for other requests.
+         */
         default void onSocketCreated(@Nullable Network network, @NonNull MdnsInterfaceSocket socket,
                 @NonNull List<LinkAddress> addresses) {}
-        /*** Notify the interface is destroyed */
+
+        /**
+         * Notify that the interface was destroyed, so the provided socket cannot be used anymore.
+         *
+         * This indicates that although the socket was still requested, it had to be destroyed.
+         */
         default void onInterfaceDestroyed(@Nullable Network network,
                 @NonNull MdnsInterfaceSocket socket) {}
-        /*** Notify the addresses is changed on the network */
+
+        /**
+         * Notify the interface addresses have changed for the network.
+         */
         default void onAddressesChanged(@Nullable Network network,
                 @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {}
     }
 
+    /**
+     * Global callback indicating when sockets are created or destroyed for requests.
+     */
+    public interface SocketRequestMonitor {
+        /**
+         * Indicates that the socket was used to fulfill the request of one requester.
+         *
+         * There is always at most one socket created for each interface. The interface is available
+         * in {@link MdnsInterfaceSocket#getInterface()}.
+         * @param socketNetwork The network of the socket interface, if any.
+         * @param socket The socket that was provided to a requester.
+         * @param transports Array of TRANSPORT_* from {@link NetworkCapabilities}. Empty if the
+         *                   interface is not part of a network with known transports.
+         */
+        default void onSocketRequestFulfilled(@Nullable Network socketNetwork,
+                @NonNull MdnsInterfaceSocket socket, @NonNull int[] transports) {}
+
+        /**
+         * Indicates that a previously created socket was destroyed.
+         *
+         * @param socketNetwork The network of the socket interface, if any.
+         * @param socket The destroyed socket.
+         */
+        default void onSocketDestroyed(@Nullable Network socketNetwork,
+                @NonNull MdnsInterfaceSocket socket) {}
+    }
+
     private interface NetworkKey {
     }
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
index eb12b9a..3180a6f 100644
--- a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
+++ b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
@@ -20,6 +20,7 @@
 import android.annotation.Nullable;
 import android.net.Network;
 import android.os.Handler;
+import android.util.ArraySet;
 
 import com.android.server.connectivity.mdns.MdnsConstants;
 import com.android.server.connectivity.mdns.MdnsRecord;
@@ -129,12 +130,21 @@
         return false;
     }
 
-    /*** Check whether the target network is matched current network */
+    /*** Check whether the target network matches the current network */
     public static boolean isNetworkMatched(@Nullable Network targetNetwork,
             @Nullable Network currentNetwork) {
         return targetNetwork == null || targetNetwork.equals(currentNetwork);
     }
 
+    /*** Check whether the target network matches any of the current networks */
+    public static boolean isAnyNetworkMatched(@Nullable Network targetNetwork,
+            ArraySet<Network> currentNetworks) {
+        if (targetNetwork == null) {
+            return !currentNetworks.isEmpty();
+        }
+        return currentNetworks.contains(targetNetwork);
+    }
+
     /**
      * Truncate a service name to up to maxLength UTF-8 bytes.
      */
diff --git a/tests/unit/java/com/android/server/NsdServiceTest.java b/tests/unit/java/com/android/server/NsdServiceTest.java
index 1997215..f51b28d 100644
--- a/tests/unit/java/com/android/server/NsdServiceTest.java
+++ b/tests/unit/java/com/android/server/NsdServiceTest.java
@@ -16,13 +16,21 @@
 
 package com.android.server;
 
+import static android.app.ActivityManager.RunningAppProcessInfo.IMPORTANCE_CACHED;
+import static android.app.ActivityManager.RunningAppProcessInfo.IMPORTANCE_FOREGROUND;
+import static android.app.ActivityManager.RunningAppProcessInfo.IMPORTANCE_GONE;
+import static android.app.ActivityManager.RunningAppProcessInfo.IMPORTANCE_VISIBLE;
 import static android.net.InetAddresses.parseNumericAddress;
+import static android.net.NetworkCapabilities.TRANSPORT_ETHERNET;
+import static android.net.NetworkCapabilities.TRANSPORT_VPN;
+import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
 import static android.net.connectivity.ConnectivityCompatChanges.ENABLE_PLATFORM_MDNS_BACKEND;
 import static android.net.connectivity.ConnectivityCompatChanges.RUN_NATIVE_NSD_ONLY_IF_LEGACY_APPS_T_AND_LATER;
 import static android.net.nsd.NsdManager.FAILURE_BAD_PARAMETERS;
 import static android.net.nsd.NsdManager.FAILURE_INTERNAL_ERROR;
 import static android.net.nsd.NsdManager.FAILURE_OPERATION_NOT_RUNNING;
 
+import static com.android.server.NsdService.DEFAULT_RUNNING_APP_ACTIVE_IMPORTANCE_CUTOFF;
 import static com.android.server.NsdService.parseTypeAndSubtype;
 import static com.android.testutils.ContextUtils.mockService;
 
@@ -42,6 +50,7 @@
 import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.reset;
@@ -51,6 +60,8 @@
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
 
+import android.app.ActivityManager;
+import android.app.ActivityManager.OnUidImportanceListener;
 import android.compat.testing.PlatformCompatChangeRule;
 import android.content.ContentResolver;
 import android.content.Context;
@@ -68,7 +79,9 @@
 import android.net.nsd.NsdManager.DiscoveryListener;
 import android.net.nsd.NsdManager.RegistrationListener;
 import android.net.nsd.NsdManager.ResolveListener;
+import android.net.nsd.NsdManager.ServiceInfoCallback;
 import android.net.nsd.NsdServiceInfo;
+import android.net.wifi.WifiManager;
 import android.os.Binder;
 import android.os.Build;
 import android.os.Handler;
@@ -85,10 +98,12 @@
 import com.android.server.NsdService.Dependencies;
 import com.android.server.connectivity.mdns.MdnsAdvertiser;
 import com.android.server.connectivity.mdns.MdnsDiscoveryManager;
+import com.android.server.connectivity.mdns.MdnsInterfaceSocket;
 import com.android.server.connectivity.mdns.MdnsSearchOptions;
 import com.android.server.connectivity.mdns.MdnsServiceBrowserListener;
 import com.android.server.connectivity.mdns.MdnsServiceInfo;
 import com.android.server.connectivity.mdns.MdnsSocketProvider;
+import com.android.server.connectivity.mdns.MdnsSocketProvider.SocketRequestMonitor;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRunner;
 import com.android.testutils.HandlerUtils;
@@ -101,6 +116,7 @@
 import org.junit.runner.RunWith;
 import org.mockito.AdditionalAnswers;
 import org.mockito.ArgumentCaptor;
+import org.mockito.InOrder;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
@@ -145,6 +161,11 @@
     @Mock MdnsDiscoveryManager mDiscoveryManager;
     @Mock MdnsAdvertiser mAdvertiser;
     @Mock MdnsSocketProvider mSocketProvider;
+    @Mock WifiManager mWifiManager;
+    @Mock WifiManager.MulticastLock mMulticastLock;
+    @Mock ActivityManager mActivityManager;
+    SocketRequestMonitor mSocketRequestMonitor;
+    OnUidImportanceListener mUidImportanceListener;
     HandlerThread mThread;
     TestHandler mHandler;
     NsdService mService;
@@ -167,9 +188,13 @@
         mHandler = new TestHandler(mThread.getLooper());
         when(mContext.getContentResolver()).thenReturn(mResolver);
         mockService(mContext, MDnsManager.class, MDnsManager.MDNS_SERVICE, mMockMDnsM);
+        mockService(mContext, WifiManager.class, Context.WIFI_SERVICE, mWifiManager);
+        mockService(mContext, ActivityManager.class, Context.ACTIVITY_SERVICE, mActivityManager);
         if (mContext.getSystemService(MDnsManager.class) == null) {
             // Test is using mockito-extended
             doCallRealMethod().when(mContext).getSystemService(MDnsManager.class);
+            doCallRealMethod().when(mContext).getSystemService(WifiManager.class);
+            doCallRealMethod().when(mContext).getSystemService(ActivityManager.class);
         }
         doReturn(true).when(mMockMDnsM).registerService(
                 anyInt(), anyString(), anyString(), anyInt(), any(), anyInt());
@@ -180,9 +205,21 @@
         doReturn(false).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class));
         doReturn(mDiscoveryManager).when(mDeps)
                 .makeMdnsDiscoveryManager(any(), any(), any());
-        doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any(), any());
+        doReturn(mMulticastLock).when(mWifiManager).createMulticastLock(any());
+        doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any(), any(), any());
+        doReturn(DEFAULT_RUNNING_APP_ACTIVE_IMPORTANCE_CUTOFF).when(mDeps).getDeviceConfigInt(
+                eq(NsdService.MDNS_CONFIG_RUNNING_APP_ACTIVE_IMPORTANCE_CUTOFF), anyInt());
         doReturn(mAdvertiser).when(mDeps).makeMdnsAdvertiser(any(), any(), any(), any());
         mService = makeService();
+        final ArgumentCaptor<SocketRequestMonitor> cbMonitorCaptor =
+                ArgumentCaptor.forClass(SocketRequestMonitor.class);
+        verify(mDeps).makeMdnsSocketProvider(any(), any(), any(), cbMonitorCaptor.capture());
+        mSocketRequestMonitor = cbMonitorCaptor.getValue();
+
+        final ArgumentCaptor<OnUidImportanceListener> uidListenerCaptor =
+                ArgumentCaptor.forClass(OnUidImportanceListener.class);
+        verify(mActivityManager).addOnUidImportanceListener(uidListenerCaptor.capture(), anyInt());
+        mUidImportanceListener = uidListenerCaptor.getValue();
     }
 
     @After
@@ -738,8 +775,8 @@
     public void testRegisterAndUnregisterServiceInfoCallback() {
         final NsdManager client = connectClient(mService);
         final NsdServiceInfo request = new NsdServiceInfo(SERVICE_NAME, SERVICE_TYPE);
-        final NsdManager.ServiceInfoCallback serviceInfoCallback = mock(
-                NsdManager.ServiceInfoCallback.class);
+        final ServiceInfoCallback serviceInfoCallback = mock(
+                ServiceInfoCallback.class);
         final String serviceTypeWithLocalDomain = SERVICE_TYPE + ".local";
         final Network network = new Network(999);
         request.setNetwork(network);
@@ -813,8 +850,8 @@
         final NsdManager client = connectClient(mService);
         final String invalidServiceType = "a_service";
         final NsdServiceInfo request = new NsdServiceInfo(SERVICE_NAME, invalidServiceType);
-        final NsdManager.ServiceInfoCallback serviceInfoCallback = mock(
-                NsdManager.ServiceInfoCallback.class);
+        final ServiceInfoCallback serviceInfoCallback = mock(
+                ServiceInfoCallback.class);
         client.registerServiceInfoCallback(request, Runnable::run, serviceInfoCallback);
         waitForIdle();
 
@@ -826,8 +863,8 @@
     @Test
     public void testUnregisterNotRegisteredCallback() {
         final NsdManager client = connectClient(mService);
-        final NsdManager.ServiceInfoCallback serviceInfoCallback = mock(
-                NsdManager.ServiceInfoCallback.class);
+        final ServiceInfoCallback serviceInfoCallback = mock(
+                ServiceInfoCallback.class);
 
         assertThrows(IllegalArgumentException.class, () ->
                 client.unregisterServiceInfoCallback(serviceInfoCallback));
@@ -1336,6 +1373,194 @@
         verify(mDiscoveryManager, times(2)).registerListener(anyString(), any(), any());
     }
 
+    @Test
+    @EnableCompatChanges(ENABLE_PLATFORM_MDNS_BACKEND)
+    public void testTakeMulticastLockOnBehalfOfClient_ForWifiNetworksOnly() {
+        // Test on one client in the foreground
+        mUidImportanceListener.onUidImportance(123, IMPORTANCE_FOREGROUND);
+        doReturn(123).when(mDeps).getCallingUid();
+        final NsdManager client = connectClient(mService);
+
+        final NsdServiceInfo regInfo = new NsdServiceInfo(SERVICE_NAME, SERVICE_TYPE);
+        regInfo.setHostAddresses(List.of(parseNumericAddress("192.0.2.123")));
+        regInfo.setPort(12345);
+        // File a request for all networks
+        regInfo.setNetwork(null);
+
+        final RegistrationListener regListener = mock(RegistrationListener.class);
+        client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
+        waitForIdle();
+        verify(mSocketProvider).startMonitoringSockets();
+        verify(mAdvertiser).addService(anyInt(), any(), any());
+
+        final Network wifiNetwork1 = new Network(123);
+        final Network wifiNetwork2 = new Network(124);
+        final Network ethernetNetwork = new Network(125);
+
+        final MdnsInterfaceSocket wifiNetworkSocket1 = mock(MdnsInterfaceSocket.class);
+        final MdnsInterfaceSocket wifiNetworkSocket2 = mock(MdnsInterfaceSocket.class);
+        final MdnsInterfaceSocket ethernetNetworkSocket = mock(MdnsInterfaceSocket.class);
+
+        // Nothing happens for networks with no transports, no Wi-Fi transport, or VPN transport
+        mHandler.post(() -> {
+            mSocketRequestMonitor.onSocketRequestFulfilled(
+                    new Network(125), mock(MdnsInterfaceSocket.class), new int[0]);
+            mSocketRequestMonitor.onSocketRequestFulfilled(
+                    ethernetNetwork, ethernetNetworkSocket,
+                    new int[] { TRANSPORT_ETHERNET });
+            mSocketRequestMonitor.onSocketRequestFulfilled(
+                    new Network(127), mock(MdnsInterfaceSocket.class),
+                    new int[] { TRANSPORT_WIFI, TRANSPORT_VPN });
+        });
+        waitForIdle();
+        verify(mWifiManager, never()).createMulticastLock(any());
+
+        // First Wi-Fi network
+        mHandler.post(() -> mSocketRequestMonitor.onSocketRequestFulfilled(
+                wifiNetwork1, wifiNetworkSocket1, new int[] { TRANSPORT_WIFI }));
+        waitForIdle();
+        verify(mWifiManager).createMulticastLock(any());
+        verify(mMulticastLock).acquire();
+
+        // Second Wi-Fi network
+        mHandler.post(() -> mSocketRequestMonitor.onSocketRequestFulfilled(
+                wifiNetwork2, wifiNetworkSocket2, new int[] { TRANSPORT_WIFI }));
+        waitForIdle();
+        verifyNoMoreInteractions(mMulticastLock);
+
+        // One Wi-Fi network becomes unused, nothing happens
+        mHandler.post(() -> mSocketRequestMonitor.onSocketDestroyed(
+                wifiNetwork1, wifiNetworkSocket1));
+        waitForIdle();
+        verifyNoMoreInteractions(mMulticastLock);
+
+        // Ethernet network becomes unused, still nothing
+        mHandler.post(() -> mSocketRequestMonitor.onSocketDestroyed(
+                ethernetNetwork, ethernetNetworkSocket));
+        waitForIdle();
+        verifyNoMoreInteractions(mMulticastLock);
+
+        // The second Wi-Fi network becomes unused, the lock is released
+        mHandler.post(() -> mSocketRequestMonitor.onSocketDestroyed(
+                wifiNetwork2, wifiNetworkSocket2));
+        waitForIdle();
+        verify(mMulticastLock).release();
+    }
+
+    @Test
+    @EnableCompatChanges(ENABLE_PLATFORM_MDNS_BACKEND)
+    public void testTakeMulticastLockOnBehalfOfClient_ForForegroundAppsOnly() {
+        final int uid1 = 12;
+        final int uid2 = 34;
+        final int uid3 = 56;
+        final int uid4 = 78;
+        final InOrder lockOrder = inOrder(mMulticastLock);
+        // Connect one client without any foreground info
+        doReturn(uid1).when(mDeps).getCallingUid();
+        final NsdManager client1 = connectClient(mService);
+
+        // Connect client2 as visible, but not foreground
+        mUidImportanceListener.onUidImportance(uid2, IMPORTANCE_VISIBLE);
+        waitForIdle();
+        doReturn(uid2).when(mDeps).getCallingUid();
+        final NsdManager client2 = connectClient(mService);
+
+        // Connect client3, client4 as foreground
+        mUidImportanceListener.onUidImportance(uid3, IMPORTANCE_FOREGROUND);
+        waitForIdle();
+        doReturn(uid3).when(mDeps).getCallingUid();
+        final NsdManager client3 = connectClient(mService);
+
+        mUidImportanceListener.onUidImportance(uid4, IMPORTANCE_FOREGROUND);
+        waitForIdle();
+        doReturn(uid4).when(mDeps).getCallingUid();
+        final NsdManager client4 = connectClient(mService);
+
+        // First client advertises on any network
+        final NsdServiceInfo regInfo = new NsdServiceInfo(SERVICE_NAME, SERVICE_TYPE);
+        regInfo.setHostAddresses(List.of(parseNumericAddress("192.0.2.123")));
+        regInfo.setPort(12345);
+        regInfo.setNetwork(null);
+        final RegistrationListener regListener = mock(RegistrationListener.class);
+        client1.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
+        waitForIdle();
+
+        final MdnsInterfaceSocket wifiSocket = mock(MdnsInterfaceSocket.class);
+        final Network wifiNetwork = new Network(123);
+
+        final MdnsInterfaceSocket ethSocket = mock(MdnsInterfaceSocket.class);
+        final Network ethNetwork = new Network(234);
+
+        mHandler.post(() -> {
+            mSocketRequestMonitor.onSocketRequestFulfilled(
+                    wifiNetwork, wifiSocket, new int[] { TRANSPORT_WIFI });
+            mSocketRequestMonitor.onSocketRequestFulfilled(
+                    ethNetwork, ethSocket, new int[] { TRANSPORT_ETHERNET });
+        });
+        waitForIdle();
+
+        // No multicast lock since client1 has no foreground info
+        lockOrder.verifyNoMoreInteractions();
+
+        // Second client discovers specifically on the Wi-Fi network
+        final DiscoveryListener discListener = mock(DiscoveryListener.class);
+        client2.discoverServices(SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD, wifiNetwork,
+                Runnable::run, discListener);
+        waitForIdle();
+        mHandler.post(() -> mSocketRequestMonitor.onSocketRequestFulfilled(
+                wifiNetwork, wifiSocket, new int[] { TRANSPORT_WIFI }));
+        waitForIdle();
+        // No multicast lock since client2 is not visible enough
+        lockOrder.verifyNoMoreInteractions();
+
+        // Third client registers a callback on all networks
+        final NsdServiceInfo cbInfo = new NsdServiceInfo(SERVICE_NAME, SERVICE_TYPE);
+        cbInfo.setNetwork(null);
+        final ServiceInfoCallback infoCb = mock(ServiceInfoCallback.class);
+        client3.registerServiceInfoCallback(cbInfo, Runnable::run, infoCb);
+        waitForIdle();
+        mHandler.post(() -> {
+            mSocketRequestMonitor.onSocketRequestFulfilled(
+                    wifiNetwork, wifiSocket, new int[] { TRANSPORT_WIFI });
+            mSocketRequestMonitor.onSocketRequestFulfilled(
+                    ethNetwork, ethSocket, new int[] { TRANSPORT_ETHERNET });
+        });
+        waitForIdle();
+
+        // Multicast lock is taken for third client
+        lockOrder.verify(mMulticastLock).acquire();
+
+        // Client3 goes to the background
+        mUidImportanceListener.onUidImportance(uid3, IMPORTANCE_CACHED);
+        waitForIdle();
+        lockOrder.verify(mMulticastLock).release();
+
+        // client4 resolves on a different network
+        final ResolveListener resolveListener = mock(ResolveListener.class);
+        final NsdServiceInfo resolveInfo = new NsdServiceInfo(SERVICE_NAME, SERVICE_TYPE);
+        resolveInfo.setNetwork(ethNetwork);
+        client4.resolveService(resolveInfo, Runnable::run, resolveListener);
+        waitForIdle();
+        mHandler.post(() -> mSocketRequestMonitor.onSocketRequestFulfilled(
+                ethNetwork, ethSocket, new int[] { TRANSPORT_ETHERNET }));
+        waitForIdle();
+
+        // client4 is foreground, but not Wi-Fi
+        lockOrder.verifyNoMoreInteractions();
+
+        // Second client becomes foreground
+        mUidImportanceListener.onUidImportance(uid2, IMPORTANCE_FOREGROUND);
+        waitForIdle();
+
+        lockOrder.verify(mMulticastLock).acquire();
+
+        // Second client is lost
+        mUidImportanceListener.onUidImportance(uid2, IMPORTANCE_GONE);
+        waitForIdle();
+
+        lockOrder.verify(mMulticastLock).release();
+    }
+
     private void waitForIdle() {
         HandlerUtils.waitForIdle(mHandler, TIMEOUT_MS);
     }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
index 4f56857..4ef64cb 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
@@ -32,10 +32,12 @@
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.argThat;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.times;
@@ -70,6 +72,7 @@
 import com.android.net.module.util.netlink.StructIfaddrMsg;
 import com.android.net.module.util.netlink.StructNlMsgHdr;
 import com.android.server.connectivity.mdns.MdnsSocketProvider.Dependencies;
+import com.android.server.connectivity.mdns.MdnsSocketProvider.SocketRequestMonitor;
 import com.android.server.connectivity.mdns.internal.SocketNetlinkMonitor;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRunner;
@@ -79,6 +82,7 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
+import org.mockito.InOrder;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
@@ -114,6 +118,7 @@
     @Mock private NetworkInterfaceWrapper mTestNetworkIfaceWrapper;
     @Mock private NetworkInterfaceWrapper mLocalOnlyIfaceWrapper;
     @Mock private NetworkInterfaceWrapper mTetheredIfaceWrapper;
+    @Mock private SocketRequestMonitor mSocketRequestMonitor;
     private Handler mHandler;
     private MdnsSocketProvider mSocketProvider;
     private NetworkCallback mNetworkCallback;
@@ -165,7 +170,8 @@
             return mTestSocketNetLinkMonitor;
         }).when(mDeps).createSocketNetlinkMonitor(any(), any(),
                 any());
-        mSocketProvider = new MdnsSocketProvider(mContext, thread.getLooper(), mDeps, mLog);
+        mSocketProvider = new MdnsSocketProvider(mContext, thread.getLooper(), mDeps, mLog,
+                mSocketRequestMonitor);
     }
 
     private void runOnHandler(Runnable r) {
@@ -319,23 +325,30 @@
     public void testSocketRequestAndUnrequestSocket() {
         startMonitoringSockets();
 
+        final InOrder cbMonitorOrder = inOrder(mSocketRequestMonitor);
         final TestSocketCallback testCallback1 = new TestSocketCallback();
         runOnHandler(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback1));
         testCallback1.expectedNoCallback();
 
         postNetworkAvailable(TRANSPORT_WIFI);
         testCallback1.expectedSocketCreatedForNetwork(TEST_NETWORK, List.of(LINKADDRV4));
+        cbMonitorOrder.verify(mSocketRequestMonitor).onSocketRequestFulfilled(eq(TEST_NETWORK),
+                any(), eq(new int[] { TRANSPORT_WIFI }));
 
         final TestSocketCallback testCallback2 = new TestSocketCallback();
         runOnHandler(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback2));
         testCallback1.expectedNoCallback();
         testCallback2.expectedSocketCreatedForNetwork(TEST_NETWORK, List.of(LINKADDRV4));
+        cbMonitorOrder.verify(mSocketRequestMonitor).onSocketRequestFulfilled(eq(TEST_NETWORK),
+                any(), eq(new int[] { TRANSPORT_WIFI }));
 
         final TestSocketCallback testCallback3 = new TestSocketCallback();
         runOnHandler(() -> mSocketProvider.requestSocket(null /* network */, testCallback3));
         testCallback1.expectedNoCallback();
         testCallback2.expectedNoCallback();
         testCallback3.expectedSocketCreatedForNetwork(TEST_NETWORK, List.of(LINKADDRV4));
+        cbMonitorOrder.verify(mSocketRequestMonitor).onSocketRequestFulfilled(eq(TEST_NETWORK),
+                any(), eq(new int[] { TRANSPORT_WIFI }));
 
         runOnHandler(() -> mTetheringEventCallback.onLocalOnlyInterfacesChanged(
                 List.of(LOCAL_ONLY_IFACE_NAME)));
@@ -343,6 +356,8 @@
         testCallback1.expectedNoCallback();
         testCallback2.expectedNoCallback();
         testCallback3.expectedSocketCreatedForNetwork(null /* network */, List.of());
+        cbMonitorOrder.verify(mSocketRequestMonitor).onSocketRequestFulfilled(eq(null),
+                any(), eq(new int[0]));
 
         runOnHandler(() -> mTetheringEventCallback.onTetheredInterfacesChanged(
                 List.of(TETHERED_IFACE_NAME)));
@@ -350,6 +365,8 @@
         testCallback1.expectedNoCallback();
         testCallback2.expectedNoCallback();
         testCallback3.expectedSocketCreatedForNetwork(null /* network */, List.of());
+        cbMonitorOrder.verify(mSocketRequestMonitor).onSocketRequestFulfilled(eq(null),
+                any(), eq(new int[0]));
 
         runOnHandler(() -> mSocketProvider.unrequestSocket(testCallback1));
         testCallback1.expectedNoCallback();
@@ -360,17 +377,22 @@
         testCallback1.expectedNoCallback();
         testCallback2.expectedInterfaceDestroyedForNetwork(TEST_NETWORK);
         testCallback3.expectedInterfaceDestroyedForNetwork(TEST_NETWORK);
+        cbMonitorOrder.verify(mSocketRequestMonitor).onSocketDestroyed(eq(TEST_NETWORK), any());
 
         runOnHandler(() -> mTetheringEventCallback.onLocalOnlyInterfacesChanged(List.of()));
         testCallback1.expectedNoCallback();
         testCallback2.expectedNoCallback();
         testCallback3.expectedInterfaceDestroyedForNetwork(null /* network */);
+        cbMonitorOrder.verify(mSocketRequestMonitor).onSocketDestroyed(eq(null), any());
 
         runOnHandler(() -> mSocketProvider.unrequestSocket(testCallback3));
         testCallback1.expectedNoCallback();
         testCallback2.expectedNoCallback();
         // There was still a tethered interface, but no callback should be sent once unregistered
         testCallback3.expectedNoCallback();
+
+        // However the socket is getting destroyed, so the callback monitor is notified
+        cbMonitorOrder.verify(mSocketRequestMonitor).onSocketDestroyed(eq(null), any());
     }
 
     private RtNetlinkAddressMessage createNetworkAddressUpdateNetLink(