Merge changes from topic "PerNetworkServiceTypeClient"

* changes:
  Create MdnsServiceTypeClient per network
  Wait for a socket to be created before sending packets
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index 67059e7..fb8af8d 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -16,19 +16,25 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.server.connectivity.mdns.MdnsSocketProvider.isNetworkMatched;
+
 import android.Manifest.permission;
 import android.annotation.NonNull;
+import android.annotation.Nullable;
 import android.annotation.RequiresPermission;
 import android.net.Network;
 import android.text.TextUtils;
 import android.util.ArrayMap;
 import android.util.Log;
+import android.util.Pair;
 
+import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.server.connectivity.mdns.util.MdnsLogger;
 
 import java.io.IOException;
-import java.util.Map;
+import java.util.ArrayList;
+import java.util.List;
 
 /**
  * This class keeps tracking the set of registered {@link MdnsServiceBrowserListener} instances, and
@@ -42,12 +48,62 @@
     private final ExecutorProvider executorProvider;
     private final MdnsSocketClientBase socketClient;
 
-    private final Map<String, MdnsServiceTypeClient> serviceTypeClients = new ArrayMap<>();
+    @GuardedBy("this")
+    @NonNull private final PerNetworkServiceTypeClients perNetworkServiceTypeClients;
+
+    private static class PerNetworkServiceTypeClients {
+        private final ArrayMap<Pair<String, Network>, MdnsServiceTypeClient> clients =
+                new ArrayMap<>();
+
+        public void put(@NonNull String serviceType, @Nullable Network network,
+                @NonNull MdnsServiceTypeClient client) {
+            final Pair<String, Network> perNetworkServiceType = new Pair<>(serviceType, network);
+            clients.put(perNetworkServiceType, client);
+        }
+
+        @Nullable
+        public MdnsServiceTypeClient get(@NonNull String serviceType, @Nullable Network network) {
+            final Pair<String, Network> perNetworkServiceType = new Pair<>(serviceType, network);
+            return clients.getOrDefault(perNetworkServiceType, null);
+        }
+
+        public List<MdnsServiceTypeClient> getByServiceType(@NonNull String serviceType) {
+            final List<MdnsServiceTypeClient> list = new ArrayList<>();
+            for (int i = 0; i < clients.size(); i++) {
+                final Pair<String, Network> perNetworkServiceType = clients.keyAt(i);
+                if (serviceType.equals(perNetworkServiceType.first)) {
+                    list.add(clients.valueAt(i));
+                }
+            }
+            return list;
+        }
+
+        public List<MdnsServiceTypeClient> getByMatchingNetwork(@Nullable Network network) {
+            final List<MdnsServiceTypeClient> list = new ArrayList<>();
+            for (int i = 0; i < clients.size(); i++) {
+                final Pair<String, Network> perNetworkServiceType = clients.keyAt(i);
+                if (isNetworkMatched(network, perNetworkServiceType.second)) {
+                    list.add(clients.valueAt(i));
+                }
+            }
+            return list;
+        }
+
+        public void remove(@NonNull MdnsServiceTypeClient client) {
+            final int index = clients.indexOfValue(client); // Negative if client is not stored.
+            if (index >= 0) clients.removeAt(index); // removeAt() would throw on a negative index.
+        }
+
+        public boolean isEmpty() {
+            return clients.isEmpty();
+        }
+    }
 
     public MdnsDiscoveryManager(@NonNull ExecutorProvider executorProvider,
             @NonNull MdnsSocketClientBase socketClient) {
         this.executorProvider = executorProvider;
         this.socketClient = socketClient;
+        perNetworkServiceTypeClients = new PerNetworkServiceTypeClients();
     }
 
     /**
@@ -67,7 +123,7 @@
         LOGGER.log(
                 "Registering listener for subtypes: %s",
                 TextUtils.join(",", searchOptions.getSubtypes()));
-        if (serviceTypeClients.isEmpty()) {
+        if (perNetworkServiceTypeClients.isEmpty()) {
             // First listener. Starts the socket client.
             try {
                 socketClient.startDiscovery();
@@ -77,16 +133,18 @@
             }
         }
         // Request the network for discovery.
-        socketClient.notifyNetworkRequested(listener, searchOptions.getNetwork());
-
-        // All listeners of the same service types shares the same MdnsServiceTypeClient.
-        MdnsServiceTypeClient serviceTypeClient = serviceTypeClients.get(serviceType);
-        if (serviceTypeClient == null) {
-            serviceTypeClient = createServiceTypeClient(serviceType);
-            serviceTypeClients.put(serviceType, serviceTypeClient);
-        }
-        // TODO(b/264634275): Wait for a socket to be created before sending packets.
-        serviceTypeClient.startSendAndReceive(listener, searchOptions);
+        socketClient.notifyNetworkRequested(listener, searchOptions.getNetwork(), network -> {
+            synchronized (this) {
+                // Listeners of the same service type and network share one MdnsServiceTypeClient.
+                MdnsServiceTypeClient serviceTypeClient =
+                        perNetworkServiceTypeClients.get(serviceType, network);
+                if (serviceTypeClient == null) {
+                    serviceTypeClient = createServiceTypeClient(serviceType, network);
+                    perNetworkServiceTypeClients.put(serviceType, network, serviceTypeClient);
+                }
+                serviceTypeClient.startSendAndReceive(listener, searchOptions);
+            }
+        });
     }
 
     /**
@@ -101,17 +159,21 @@
             @NonNull String serviceType, @NonNull MdnsServiceBrowserListener listener) {
         LOGGER.log("Unregistering listener for service type: %s", serviceType);
         if (DBG) Log.d(TAG, "Unregistering listener for serviceType:" + serviceType);
-        MdnsServiceTypeClient serviceTypeClient = serviceTypeClients.get(serviceType);
-        if (serviceTypeClient == null) {
+        final List<MdnsServiceTypeClient> serviceTypeClients =
+                perNetworkServiceTypeClients.getByServiceType(serviceType);
+        if (serviceTypeClients.isEmpty()) {
             return;
         }
-        if (serviceTypeClient.stopSendAndReceive(listener)) {
-            // No listener is registered for the service type anymore, remove it from the list of
-            // the service type clients.
-            serviceTypeClients.remove(serviceType);
-            if (serviceTypeClients.isEmpty()) {
-                // No discovery request. Stops the socket client.
-                socketClient.stopDiscovery();
+        for (int i = 0; i < serviceTypeClients.size(); i++) {
+            final MdnsServiceTypeClient serviceTypeClient = serviceTypeClients.get(i);
+            if (serviceTypeClient.stopSendAndReceive(listener)) {
+                // No listener is registered for the service type anymore, remove it from the list
+                // of the service type clients.
+                perNetworkServiceTypeClients.remove(serviceTypeClient);
+                if (perNetworkServiceTypeClients.isEmpty()) {
+                    // No discovery request. Stops the socket client.
+                    socketClient.stopDiscovery();
+                }
             }
         }
         // Unrequested the network.
@@ -121,22 +183,26 @@
     @Override
     public synchronized void onResponseReceived(@NonNull MdnsPacket packet,
             int interfaceIndex, Network network) {
-        for (MdnsServiceTypeClient serviceTypeClient : serviceTypeClients.values()) {
+        for (MdnsServiceTypeClient serviceTypeClient
+                : perNetworkServiceTypeClients.getByMatchingNetwork(network)) {
             serviceTypeClient.processResponse(packet, interfaceIndex, network);
         }
     }
 
     @Override
-    public synchronized void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode) {
-        for (MdnsServiceTypeClient serviceTypeClient : serviceTypeClients.values()) {
+    public synchronized void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode,
+            Network network) {
+        for (MdnsServiceTypeClient serviceTypeClient
+                : perNetworkServiceTypeClients.getByMatchingNetwork(network)) {
             serviceTypeClient.onFailedToParseMdnsResponse(receivedPacketNumber, errorCode);
         }
     }
 
     @VisibleForTesting
-    MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType) {
+    MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType,
+            @Nullable Network network) {
         return new MdnsServiceTypeClient(
                 serviceType, socketClient,
-                executorProvider.newServiceTypeClientSchedulerExecutor());
+                executorProvider.newServiceTypeClientSchedulerExecutor(), network);
     }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
index 93972d9..5254342 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -63,6 +63,12 @@
     }
 
     private class InterfaceSocketCallback implements MdnsSocketProvider.SocketCallback {
+        private final SocketCreationCallback mSocketCreationCallback;
+
+        InterfaceSocketCallback(SocketCreationCallback socketCreationCallback) {
+            mSocketCreationCallback = socketCreationCallback;
+        }
+
         @Override
         public void onSocketCreated(@NonNull Network network,
                 @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {
@@ -76,6 +82,7 @@
             }
             socket.addPacketHandler(handler);
             mActiveNetworkSockets.put(socket, network);
+            mSocketCreationCallback.onSocketCreated(network);
         }
 
         @Override
@@ -114,10 +121,11 @@
      * @param listener the listener for discovery.
      * @param network the target network for discovery. Null means discovery on all possible
      *                interfaces.
+     * @param socketCreationCallback the callback to notify socket creation.
      */
     @Override
     public void notifyNetworkRequested(@NonNull MdnsServiceBrowserListener listener,
-            @Nullable Network network) {
+            @Nullable Network network, @NonNull SocketCreationCallback socketCreationCallback) {
         ensureRunningOnHandlerThread(mHandler);
         InterfaceSocketCallback callback = mRequestedNetworks.get(listener);
         if (callback != null) {
@@ -125,7 +133,7 @@
         }
 
         if (DBG) Log.d(TAG, "notifyNetworkRequested: network=" + network);
-        callback = new InterfaceSocketCallback();
+        callback = new InterfaceSocketCallback(socketCreationCallback);
         mRequestedNetworks.put(listener, callback);
         mSocketProvider.requestSocket(network, callback);
     }
@@ -173,7 +181,7 @@
             if (e.code != MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE) {
                 Log.e(TAG, e.getMessage(), e);
                 if (mCallback != null) {
-                    mCallback.onFailedToParseMdnsResponse(packetNumber, e.code);
+                    mCallback.onFailedToParseMdnsResponse(packetNumber, e.code, network);
                 }
             }
             return;
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index df270bb..f87804b 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -57,6 +57,7 @@
     private final MdnsSocketClientBase socketClient;
     private final MdnsResponseDecoder responseDecoder;
     private final ScheduledExecutorService executor;
+    @Nullable private final Network network;
     private final Object lock = new Object();
     private final ArrayMap<MdnsServiceBrowserListener, MdnsSearchOptions> listeners =
             new ArrayMap<>();
@@ -88,8 +89,9 @@
     public MdnsServiceTypeClient(
             @NonNull String serviceType,
             @NonNull MdnsSocketClientBase socketClient,
-            @NonNull ScheduledExecutorService executor) {
-        this(serviceType, socketClient, executor, new MdnsResponseDecoder.Clock());
+            @NonNull ScheduledExecutorService executor,
+            @Nullable Network network) {
+        this(serviceType, socketClient, executor, new MdnsResponseDecoder.Clock(), network);
     }
 
     @VisibleForTesting
@@ -97,13 +99,15 @@
             @NonNull String serviceType,
             @NonNull MdnsSocketClientBase socketClient,
             @NonNull ScheduledExecutorService executor,
-            @NonNull MdnsResponseDecoder.Clock clock) {
+            @NonNull MdnsResponseDecoder.Clock clock,
+            @Nullable Network network) {
         this.serviceType = serviceType;
         this.socketClient = socketClient;
         this.executor = executor;
         this.serviceTypeLabels = TextUtils.split(serviceType, "\\.");
         this.responseDecoder = new MdnsResponseDecoder(clock, serviceTypeLabels);
         this.clock = clock;
+        this.network = network;
     }
 
     private static MdnsServiceInfo buildMdnsServiceInfoFromResponse(
@@ -204,7 +208,9 @@
      */
     public boolean stopSendAndReceive(@NonNull MdnsServiceBrowserListener listener) {
         synchronized (lock) {
-            listeners.remove(listener);
+            if (listeners.remove(listener) == null) {
+                return listeners.isEmpty();
+            }
             if (listeners.isEmpty() && requestTaskFuture != null) {
                 requestTaskFuture.cancel(true);
                 requestTaskFuture = null;
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
index c03e6aa..783b18a 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
@@ -423,7 +423,7 @@
             LOGGER.w(String.format("Error while decoding %s packet (%d): %d",
                     responseType, packetNumber, e.code));
             if (callback != null) {
-                callback.onFailedToParseMdnsResponse(packetNumber, e.code);
+                callback.onFailedToParseMdnsResponse(packetNumber, e.code, network);
             }
             return e.code;
         }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
index 796dc83..ebafc49 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
@@ -64,7 +64,9 @@
 
     /*** Notify that the given network is requested for mdns discovery / resolution */
     default void notifyNetworkRequested(@NonNull MdnsServiceBrowserListener listener,
-            @Nullable Network network) { }
+            @Nullable Network network, @NonNull SocketCreationCallback socketCreationCallback) {
+        socketCreationCallback.onSocketCreated(network);
+    }
 
     /*** Notify that the network is unrequested */
     default void notifyNetworkUnrequested(@NonNull MdnsServiceBrowserListener listener) { }
@@ -76,6 +78,13 @@
                 @Nullable Network network);
 
         /*** Parse a mdns response failed */
-        void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode);
+        void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode,
+                @Nullable Network network);
+    }
+
+    /*** Callback for requested socket creation */
+    interface SocketCreationCallback {
+        /*** Notify that the requested socket is created */
+        void onSocketCreated(@Nullable Network network);
     }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
index e6b8326..7e7e6a4 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
@@ -18,20 +18,25 @@
 
 import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
 
-import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 import android.annotation.NonNull;
+import android.annotation.Nullable;
 import android.net.Network;
 import android.text.TextUtils;
+import android.util.Pair;
 
+import com.android.server.connectivity.mdns.MdnsSocketClientBase.SocketCreationCallback;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRunner;
 
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
@@ -48,6 +53,10 @@
 
     private static final String SERVICE_TYPE_1 = "_googlecast._tcp.local";
     private static final String SERVICE_TYPE_2 = "_test._tcp.local";
+    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_1 =
+            Pair.create(SERVICE_TYPE_1, null);
+    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_2 =
+            Pair.create(SERVICE_TYPE_2, null);
 
     @Mock private ExecutorProvider executorProvider;
     @Mock private MdnsSocketClientBase socketClient;
@@ -69,10 +78,13 @@
 
         discoveryManager = new MdnsDiscoveryManager(executorProvider, socketClient) {
                     @Override
-                    MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType) {
-                        if (serviceType.equals(SERVICE_TYPE_1)) {
+                    MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType,
+                            @Nullable Network network) {
+                        final Pair<String, Network> perNetworkServiceType =
+                                Pair.create(serviceType, network);
+                        if (perNetworkServiceType.equals(PER_NETWORK_SERVICE_TYPE_1)) {
                             return mockServiceTypeClientOne;
-                        } else if (serviceType.equals(SERVICE_TYPE_2)) {
+                        } else if (perNetworkServiceType.equals(PER_NETWORK_SERVICE_TYPE_2)) {
                             return mockServiceTypeClientTwo;
                         }
                         return null;
@@ -80,13 +92,23 @@
                 };
     }
 
+    private void verifyListenerRegistration(String serviceType, MdnsServiceBrowserListener listener,
+            MdnsServiceTypeClient client) throws IOException {
+        final ArgumentCaptor<SocketCreationCallback> callbackCaptor =
+                ArgumentCaptor.forClass(SocketCreationCallback.class);
+        discoveryManager.registerListener(serviceType, listener,
+                MdnsSearchOptions.getDefaultOptions());
+        verify(socketClient).startDiscovery();
+        verify(socketClient).notifyNetworkRequested(
+                eq(listener), any(), callbackCaptor.capture());
+        final SocketCreationCallback callback = callbackCaptor.getValue();
+        callback.onSocketCreated(null /* network */);
+        verify(client).startSendAndReceive(listener, MdnsSearchOptions.getDefaultOptions());
+    }
+
     @Test
     public void registerListener_unregisterListener() throws IOException {
-        discoveryManager.registerListener(
-                SERVICE_TYPE_1, mockListenerOne, MdnsSearchOptions.getDefaultOptions());
-        verify(socketClient).startDiscovery();
-        verify(mockServiceTypeClientOne)
-                .startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
+        verifyListenerRegistration(SERVICE_TYPE_1, mockListenerOne, mockServiceTypeClientOne);
 
         when(mockServiceTypeClientOne.stopSendAndReceive(mockListenerOne)).thenReturn(true);
         discoveryManager.unregisterListener(SERVICE_TYPE_1, mockListenerOne);
@@ -96,40 +118,30 @@
 
     @Test
     public void registerMultipleListeners() throws IOException {
-        discoveryManager.registerListener(
-                SERVICE_TYPE_1, mockListenerOne, MdnsSearchOptions.getDefaultOptions());
-        verify(socketClient).startDiscovery();
-        verify(mockServiceTypeClientOne)
-                .startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
-
-        discoveryManager.registerListener(
-                SERVICE_TYPE_2, mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
-        verify(mockServiceTypeClientTwo)
-                .startSendAndReceive(mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
+        verifyListenerRegistration(SERVICE_TYPE_1, mockListenerOne, mockServiceTypeClientOne);
+        verifyListenerRegistration(SERVICE_TYPE_2, mockListenerTwo, mockServiceTypeClientTwo);
     }
 
     @Test
-    public void onResponseReceived() {
-        discoveryManager.registerListener(
-                SERVICE_TYPE_1, mockListenerOne, MdnsSearchOptions.getDefaultOptions());
-        discoveryManager.registerListener(
-                SERVICE_TYPE_2, mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
+    public void onResponseReceived() throws IOException {
+        verifyListenerRegistration(SERVICE_TYPE_1, mockListenerOne, mockServiceTypeClientOne);
+        verifyListenerRegistration(SERVICE_TYPE_2, mockListenerTwo, mockServiceTypeClientTwo);
 
         MdnsPacket responseForServiceTypeOne = createMdnsPacket(SERVICE_TYPE_1);
         final int ifIndex = 1;
-        final Network network = mock(Network.class);
-        discoveryManager.onResponseReceived(responseForServiceTypeOne, ifIndex, network);
+        discoveryManager.onResponseReceived(responseForServiceTypeOne, ifIndex, null /* network */);
         verify(mockServiceTypeClientOne).processResponse(responseForServiceTypeOne, ifIndex,
-                network);
+                null /* network */);
 
         MdnsPacket responseForServiceTypeTwo = createMdnsPacket(SERVICE_TYPE_2);
-        discoveryManager.onResponseReceived(responseForServiceTypeTwo, ifIndex, network);
+        discoveryManager.onResponseReceived(responseForServiceTypeTwo, ifIndex, null /* network */);
         verify(mockServiceTypeClientTwo).processResponse(responseForServiceTypeTwo, ifIndex,
-                network);
+                null /* network */);
 
         MdnsPacket responseForSubtype = createMdnsPacket("subtype._sub._googlecast._tcp.local");
-        discoveryManager.onResponseReceived(responseForSubtype, ifIndex, network);
-        verify(mockServiceTypeClientOne).processResponse(responseForSubtype, ifIndex, network);
+        discoveryManager.onResponseReceived(responseForSubtype, ifIndex, null /* network */);
+        verify(mockServiceTypeClientOne).processResponse(responseForSubtype, ifIndex,
+                null /* network */);
     }
 
     private MdnsPacket createMdnsPacket(String serviceType) {
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
index 1e322e4..90c43e5 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
@@ -62,6 +62,7 @@
     @Mock private MdnsInterfaceSocket mSocket;
     @Mock private MdnsServiceBrowserListener mListener;
     @Mock private MdnsSocketClientBase.Callback mCallback;
+    @Mock private MdnsSocketClientBase.SocketCreationCallback mSocketCreationCallback;
     private MdnsMultinetworkSocketClient mSocketClient;
     private Handler mHandler;
 
@@ -78,7 +79,8 @@
     private SocketCallback expectSocketCallback() {
         final ArgumentCaptor<SocketCallback> callbackCaptor =
                 ArgumentCaptor.forClass(SocketCallback.class);
-        mHandler.post(() -> mSocketClient.notifyNetworkRequested(mListener, mNetwork));
+        mHandler.post(() -> mSocketClient.notifyNetworkRequested(
+                mListener, mNetwork, mSocketCreationCallback));
         verify(mProvider, timeout(DEFAULT_TIMEOUT))
                 .requestSocket(eq(mNetwork), callbackCaptor.capture());
         return callbackCaptor.getValue();
@@ -107,6 +109,7 @@
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
         // Notify socket created
         callback.onSocketCreated(mNetwork, mSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
 
         // Send packet to IPv4 with target network and verify sending has been called.
         mSocketClient.sendMulticastPacket(ipv4Packet, mNetwork);
@@ -138,6 +141,7 @@
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
         // Notify socket created
         callback.onSocketCreated(mNetwork, mSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
 
         final ArgumentCaptor<PacketHandler> handlerCaptor =
                 ArgumentCaptor.forClass(PacketHandler.class);
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index d9fa10f..3e2ea35 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -166,7 +166,7 @@
 
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock) {
+                        mockDecoderClock, mockNetwork) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -701,7 +701,7 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock) {
+                        mockDecoderClock, mockNetwork) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -740,7 +740,7 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock) {
+                        mockDecoderClock, mockNetwork) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -773,7 +773,7 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock) {
+                        mockDecoderClock, mockNetwork) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -897,7 +897,8 @@
 
     @Test
     public void testProcessResponse_Resolve() throws Exception {
-        client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor);
+        client = new MdnsServiceTypeClient(
+                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork);
 
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
index 9048686..abb1747 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
@@ -412,7 +412,8 @@
         mdnsClient.startDiscovery();
 
         verify(mockCallback, timeout(TIMEOUT).atLeast(1))
-                .onFailedToParseMdnsResponse(anyInt(), eq(MdnsResponseErrorCode.ERROR_END_OF_FILE));
+                .onFailedToParseMdnsResponse(
+                        anyInt(), eq(MdnsResponseErrorCode.ERROR_END_OF_FILE), any());
 
         mdnsClient.stopDiscovery();
     }
@@ -433,7 +434,8 @@
         mdnsClient.startDiscovery();
 
         verify(mockCallback, timeout(TIMEOUT).atLeast(1))
-                .onFailedToParseMdnsResponse(1, MdnsResponseErrorCode.ERROR_END_OF_FILE);
+                .onFailedToParseMdnsResponse(
+                        eq(1), eq(MdnsResponseErrorCode.ERROR_END_OF_FILE), any());
 
         mdnsClient.stopDiscovery();
     }