Merge "Create an MdnsServiceTypeClient using a SocketKey"
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index 92a26f1..afad3b7 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -22,7 +22,6 @@
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.annotation.RequiresPermission;
-import android.net.Network;
 import android.os.Handler;
 import android.os.HandlerThread;
 import android.util.ArrayMap;
@@ -36,7 +35,6 @@
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
-import java.util.Objects;
 
 /**
  * This class keeps tracking the set of registered {@link MdnsServiceBrowserListener} instances, and
@@ -50,54 +48,58 @@
     private final MdnsSocketClientBase socketClient;
     @NonNull private final SharedLog sharedLog;
 
-    @NonNull private final PerNetworkServiceTypeClients perNetworkServiceTypeClients;
+    @NonNull private final PerSocketServiceTypeClients perSocketServiceTypeClients;
     @NonNull private final Handler handler;
     @Nullable private final HandlerThread handlerThread;
 
-    private static class PerNetworkServiceTypeClients {
-        private final ArrayMap<Pair<String, Network>, MdnsServiceTypeClient> clients =
+    private static class PerSocketServiceTypeClients {
+        private final ArrayMap<Pair<String, SocketKey>, MdnsServiceTypeClient> clients =
                 new ArrayMap<>();
 
-        public void put(@NonNull String serviceType, @Nullable Network network,
+        public void put(@NonNull String serviceType, @NonNull SocketKey socketKey,
                 @NonNull MdnsServiceTypeClient client) {
             final String dnsLowerServiceType = MdnsUtils.toDnsLowerCase(serviceType);
-            final Pair<String, Network> perNetworkServiceType = new Pair<>(dnsLowerServiceType,
-                    network);
-            clients.put(perNetworkServiceType, client);
+            final Pair<String, SocketKey> perSocketServiceType = new Pair<>(dnsLowerServiceType,
+                    socketKey);
+            clients.put(perSocketServiceType, client);
         }
 
         @Nullable
-        public MdnsServiceTypeClient get(@NonNull String serviceType, @Nullable Network network) {
+        public MdnsServiceTypeClient get(
+                @NonNull String serviceType, @NonNull SocketKey socketKey) {
             final String dnsLowerServiceType = MdnsUtils.toDnsLowerCase(serviceType);
-            final Pair<String, Network> perNetworkServiceType = new Pair<>(dnsLowerServiceType,
-                    network);
-            return clients.getOrDefault(perNetworkServiceType, null);
+            final Pair<String, SocketKey> perSocketServiceType = new Pair<>(dnsLowerServiceType,
+                    socketKey);
+            return clients.getOrDefault(perSocketServiceType, null);
         }
 
         public List<MdnsServiceTypeClient> getByServiceType(@NonNull String serviceType) {
             final String dnsLowerServiceType = MdnsUtils.toDnsLowerCase(serviceType);
             final List<MdnsServiceTypeClient> list = new ArrayList<>();
             for (int i = 0; i < clients.size(); i++) {
-                final Pair<String, Network> perNetworkServiceType = clients.keyAt(i);
-                if (dnsLowerServiceType.equals(perNetworkServiceType.first)) {
+                final Pair<String, SocketKey> perSocketServiceType = clients.keyAt(i);
+                if (dnsLowerServiceType.equals(perSocketServiceType.first)) {
                     list.add(clients.valueAt(i));
                 }
             }
             return list;
         }
 
-        public List<MdnsServiceTypeClient> getByNetwork(@Nullable Network network) {
+        public List<MdnsServiceTypeClient> getBySocketKey(@NonNull SocketKey socketKey) {
             final List<MdnsServiceTypeClient> list = new ArrayList<>();
             for (int i = 0; i < clients.size(); i++) {
-                final Pair<String, Network> perNetworkServiceType = clients.keyAt(i);
-                final Network serviceTypeNetwork = perNetworkServiceType.second;
-                if (Objects.equals(network, serviceTypeNetwork)) {
+                final Pair<String, SocketKey> perSocketServiceType = clients.keyAt(i);
+                if (socketKey.equals(perSocketServiceType.second)) {
                     list.add(clients.valueAt(i));
                 }
             }
             return list;
         }
 
+        public List<MdnsServiceTypeClient> getAllMdnsServiceTypeClient() {
+            return new ArrayList<>(clients.values());
+        }
+
         public void remove(@NonNull MdnsServiceTypeClient client) {
             final int index = clients.indexOfValue(client);
             clients.removeAt(index);
@@ -113,7 +115,7 @@
         this.executorProvider = executorProvider;
         this.socketClient = socketClient;
         this.sharedLog = sharedLog;
-        this.perNetworkServiceTypeClients = new PerNetworkServiceTypeClients();
+        this.perSocketServiceTypeClients = new PerSocketServiceTypeClients();
         if (socketClient.getLooper() != null) {
             this.handlerThread = null;
             this.handler = new Handler(socketClient.getLooper());
@@ -164,7 +166,7 @@
             @NonNull String serviceType,
             @NonNull MdnsServiceBrowserListener listener,
             @NonNull MdnsSearchOptions searchOptions) {
-        if (perNetworkServiceTypeClients.isEmpty()) {
+        if (perSocketServiceTypeClients.isEmpty()) {
             // First listener. Starts the socket client.
             try {
                 socketClient.startDiscovery();
@@ -177,29 +179,29 @@
         socketClient.notifyNetworkRequested(listener, searchOptions.getNetwork(),
                 new MdnsSocketClientBase.SocketCreationCallback() {
                     @Override
-                    public void onSocketCreated(@Nullable Network network) {
+                    public void onSocketCreated(@NonNull SocketKey socketKey) {
                         ensureRunningOnHandlerThread(handler);
                         // All listeners of the same service types shares the same
                         // MdnsServiceTypeClient.
                         MdnsServiceTypeClient serviceTypeClient =
-                                perNetworkServiceTypeClients.get(serviceType, network);
+                                perSocketServiceTypeClients.get(serviceType, socketKey);
                         if (serviceTypeClient == null) {
-                            serviceTypeClient = createServiceTypeClient(serviceType, network);
-                            perNetworkServiceTypeClients.put(serviceType, network,
+                            serviceTypeClient = createServiceTypeClient(serviceType, socketKey);
+                            perSocketServiceTypeClients.put(serviceType, socketKey,
                                     serviceTypeClient);
                         }
                         serviceTypeClient.startSendAndReceive(listener, searchOptions);
                     }
 
                     @Override
-                    public void onAllSocketsDestroyed(@Nullable Network network) {
+                    public void onAllSocketsDestroyed(@NonNull SocketKey socketKey) {
                         ensureRunningOnHandlerThread(handler);
                         final MdnsServiceTypeClient serviceTypeClient =
-                                perNetworkServiceTypeClients.get(serviceType, network);
+                                perSocketServiceTypeClients.get(serviceType, socketKey);
                         if (serviceTypeClient == null) return;
                         // Notify all listeners that all services are removed from this socket.
                         serviceTypeClient.notifySocketDestroyed();
-                        perNetworkServiceTypeClients.remove(serviceTypeClient);
+                        perSocketServiceTypeClients.remove(serviceTypeClient);
                     }
                 });
     }
@@ -224,7 +226,7 @@
         socketClient.notifyNetworkUnrequested(listener);
 
         final List<MdnsServiceTypeClient> serviceTypeClients =
-                perNetworkServiceTypeClients.getByServiceType(serviceType);
+                perSocketServiceTypeClients.getByServiceType(serviceType);
         if (serviceTypeClients.isEmpty()) {
             return;
         }
@@ -233,10 +235,10 @@
             if (serviceTypeClient.stopSendAndReceive(listener)) {
                 // No listener is registered for the service type anymore, remove it from the list
                 // of the service type clients.
-                perNetworkServiceTypeClients.remove(serviceTypeClient);
+                perSocketServiceTypeClients.remove(serviceTypeClient);
             }
         }
-        if (perNetworkServiceTypeClients.isEmpty()) {
+        if (perSocketServiceTypeClients.isEmpty()) {
             // No discovery request. Stops the socket client.
             sharedLog.i("All service type listeners unregistered; stopping discovery");
             socketClient.stopDiscovery();
@@ -244,50 +246,48 @@
     }
 
     @Override
-    public void onResponseReceived(@NonNull MdnsPacket packet,
-            int interfaceIndex, @Nullable Network network) {
+    public void onResponseReceived(@NonNull MdnsPacket packet, @NonNull SocketKey socketKey) {
         checkAndRunOnHandlerThread(() ->
-                handleOnResponseReceived(packet, interfaceIndex, network));
+                handleOnResponseReceived(packet, socketKey));
     }
 
-    private void handleOnResponseReceived(@NonNull MdnsPacket packet, int interfaceIndex,
-            @Nullable Network network) {
-        for (MdnsServiceTypeClient serviceTypeClient
-                : getMdnsServiceTypeClient(network)) {
-            serviceTypeClient.processResponse(packet, interfaceIndex, network);
+    private void handleOnResponseReceived(@NonNull MdnsPacket packet,
+            @NonNull SocketKey socketKey) {
+        for (MdnsServiceTypeClient serviceTypeClient : getMdnsServiceTypeClient(socketKey)) {
+            serviceTypeClient.processResponse(
+                    packet, socketKey.getInterfaceIndex(), socketKey.getNetwork());
         }
     }
 
-    private List<MdnsServiceTypeClient> getMdnsServiceTypeClient(@Nullable Network network) {
+    private List<MdnsServiceTypeClient> getMdnsServiceTypeClient(@NonNull SocketKey socketKey) {
         if (socketClient.supportsRequestingSpecificNetworks()) {
-            return perNetworkServiceTypeClients.getByNetwork(network);
+            return perSocketServiceTypeClients.getBySocketKey(socketKey);
         } else {
-            return perNetworkServiceTypeClients.getByNetwork(null);
+            return perSocketServiceTypeClients.getAllMdnsServiceTypeClient();
         }
     }
 
     @Override
     public void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode,
-            @Nullable Network network) {
+            @NonNull SocketKey socketKey) {
         checkAndRunOnHandlerThread(() ->
-                handleOnFailedToParseMdnsResponse(receivedPacketNumber, errorCode, network));
+                handleOnFailedToParseMdnsResponse(receivedPacketNumber, errorCode, socketKey));
     }
 
     private void handleOnFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode,
-            @Nullable Network network) {
-        for (MdnsServiceTypeClient serviceTypeClient
-                : getMdnsServiceTypeClient(network)) {
+            @NonNull SocketKey socketKey) {
+        for (MdnsServiceTypeClient serviceTypeClient : getMdnsServiceTypeClient(socketKey)) {
             serviceTypeClient.onFailedToParseMdnsResponse(receivedPacketNumber, errorCode);
         }
     }
 
     @VisibleForTesting
     MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType,
-            @Nullable Network network) {
-        sharedLog.log("createServiceTypeClient for type:" + serviceType + ", net:" + network);
+            @NonNull SocketKey socketKey) {
+        sharedLog.log("createServiceTypeClient for type:" + serviceType + " " + socketKey);
         return new MdnsServiceTypeClient(
                 serviceType, socketClient,
-                executorProvider.newServiceTypeClientSchedulerExecutor(), network,
-                sharedLog.forSubComponent(serviceType + "-" + network));
+                executorProvider.newServiceTypeClientSchedulerExecutor(), socketKey,
+                sharedLog.forSubComponent(serviceType + "-" + socketKey));
     }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
index 03be681..d0ca20e 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -84,7 +84,7 @@
             }
             socket.addPacketHandler(handler);
             mActiveNetworkSockets.put(socket, socketKey);
-            mSocketCreationCallback.onSocketCreated(socketKey.getNetwork());
+            mSocketCreationCallback.onSocketCreated(socketKey);
         }
 
         @Override
@@ -97,7 +97,7 @@
         private void notifySocketDestroyed(@NonNull MdnsInterfaceSocket socket) {
             final SocketKey socketKey = mActiveNetworkSockets.remove(socket);
             if (!isAnySocketActive(socketKey)) {
-                mSocketCreationCallback.onAllSocketsDestroyed(socketKey.getNetwork());
+                mSocketCreationCallback.onAllSocketsDestroyed(socketKey);
             }
         }
 
@@ -247,16 +247,14 @@
             if (e.code != MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE) {
                 Log.e(TAG, e.getMessage(), e);
                 if (mCallback != null) {
-                    mCallback.onFailedToParseMdnsResponse(
-                            packetNumber, e.code, socketKey.getNetwork());
+                    mCallback.onFailedToParseMdnsResponse(packetNumber, e.code, socketKey);
                 }
             }
             return;
         }
 
         if (mCallback != null) {
-            mCallback.onResponseReceived(
-                    response, socketKey.getInterfaceIndex(), socketKey.getNetwork());
+            mCallback.onResponseReceived(response, socketKey);
         }
     }
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index 0e3522c..bdc673e 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -58,7 +58,7 @@
     private final MdnsSocketClientBase socketClient;
     private final MdnsResponseDecoder responseDecoder;
     private final ScheduledExecutorService executor;
-    @Nullable private final Network network;
+    @NonNull private final SocketKey socketKey;
     @NonNull private final SharedLog sharedLog;
     private final Object lock = new Object();
     private final ArrayMap<MdnsServiceBrowserListener, MdnsSearchOptions> listeners =
@@ -90,9 +90,9 @@
             @NonNull String serviceType,
             @NonNull MdnsSocketClientBase socketClient,
             @NonNull ScheduledExecutorService executor,
-            @Nullable Network network,
+            @NonNull SocketKey socketKey,
             @NonNull SharedLog sharedLog) {
-        this(serviceType, socketClient, executor, new MdnsResponseDecoder.Clock(), network,
+        this(serviceType, socketClient, executor, new MdnsResponseDecoder.Clock(), socketKey,
                 sharedLog);
     }
 
@@ -102,7 +102,7 @@
             @NonNull MdnsSocketClientBase socketClient,
             @NonNull ScheduledExecutorService executor,
             @NonNull MdnsResponseDecoder.Clock clock,
-            @Nullable Network network,
+            @NonNull SocketKey socketKey,
             @NonNull SharedLog sharedLog) {
         this.serviceType = serviceType;
         this.socketClient = socketClient;
@@ -110,7 +110,7 @@
         this.serviceTypeLabels = TextUtils.split(serviceType, "\\.");
         this.responseDecoder = new MdnsResponseDecoder(clock, serviceTypeLabels);
         this.clock = clock;
-        this.network = network;
+        this.socketKey = socketKey;
         this.sharedLog = sharedLog;
     }
 
@@ -199,7 +199,7 @@
                     searchOptions.getSubtypes(),
                     searchOptions.isPassiveMode(),
                     currentSessionId,
-                    network);
+                    socketKey);
             if (hadReply) {
                 requestTaskFuture = scheduleNextRunLocked(taskConfig);
             } else {
@@ -437,10 +437,10 @@
         private int burstCounter;
         private int timeToRunNextTaskInMs;
         private boolean isFirstBurst;
-        @Nullable private final Network network;
+        @NonNull private final SocketKey socketKey;
 
         QueryTaskConfig(@NonNull Collection<String> subtypes, boolean usePassiveMode,
-                long sessionId, @Nullable Network network) {
+                long sessionId, @NonNull SocketKey socketKey) {
             this.usePassiveMode = usePassiveMode;
             this.subtypes = new ArrayList<>(subtypes);
             this.queriesPerBurst = QUERIES_PER_BURST;
@@ -462,7 +462,7 @@
                 // doubles until it maxes out at TIME_BETWEEN_BURSTS_MS.
                 this.timeBetweenBurstsInMs = INITIAL_TIME_BETWEEN_BURSTS_MS;
             }
-            this.network = network;
+            this.socketKey = socketKey;
         }
 
         QueryTaskConfig getConfigForNextRun() {
@@ -545,7 +545,7 @@
                 // Only the names are used to know which queries to send, other parameters like
                 // interfaceIndex do not matter.
                 servicesToResolve = makeResponsesForResolve(
-                        0 /* interfaceIndex */, config.network);
+                        0 /* interfaceIndex */, config.socketKey.getNetwork());
                 sendDiscoveryQueries = servicesToResolve.size() < listeners.size();
             }
             Pair<Integer, List<String>> result;
@@ -558,7 +558,7 @@
                                 config.subtypes,
                                 config.expectUnicastResponse,
                                 config.transactionId,
-                                config.network,
+                                config.socketKey.getNetwork(),
                                 sendDiscoveryQueries,
                                 servicesToResolve,
                                 clock)
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
index b982644..2b6e5d0 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
@@ -235,7 +235,7 @@
             throw new IllegalArgumentException("This socket client does not support requesting "
                     + "specific networks");
         }
-        socketCreationCallback.onSocketCreated(null);
+        socketCreationCallback.onSocketCreated(new SocketKey(multicastSocket.getInterfaceIndex()));
     }
 
     @Override
@@ -456,7 +456,8 @@
             LOGGER.w(String.format("Error while decoding %s packet (%d): %d",
                     responseType, packetNumber, e.code));
             if (callback != null) {
-                callback.onFailedToParseMdnsResponse(packetNumber, e.code, network);
+                callback.onFailedToParseMdnsResponse(packetNumber, e.code,
+                        new SocketKey(network, interfaceIndex));
             }
             return e.code;
         }
@@ -466,7 +467,8 @@
         }
 
         if (callback != null) {
-            callback.onResponseReceived(response, interfaceIndex, network);
+            callback.onResponseReceived(
+                    response, new SocketKey(network, interfaceIndex));
         }
 
         return MdnsResponseErrorCode.SUCCESS;
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
index e0762f9..a35925a 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
@@ -73,20 +73,19 @@
     /*** Callback for mdns response  */
     interface Callback {
         /*** Receive a mdns response */
-        void onResponseReceived(@NonNull MdnsPacket packet, int interfaceIndex,
-                @Nullable Network network);
+        void onResponseReceived(@NonNull MdnsPacket packet, @NonNull SocketKey socketKey);
 
         /*** Parse a mdns response failed */
         void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode,
-                @Nullable Network network);
+                @NonNull SocketKey socketKey);
     }
 
     /*** Callback for requested socket creation  */
     interface SocketCreationCallback {
         /*** Notify requested socket is created */
-        void onSocketCreated(@Nullable Network network);
+        void onSocketCreated(@NonNull SocketKey socketKey);
 
         /*** Notify requested socket is destroyed */
-        void onAllSocketsDestroyed(@Nullable Network network);
+        void onAllSocketsDestroyed(@NonNull SocketKey socketKey);
     }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
index a24664e..d2298fe 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
@@ -28,7 +28,6 @@
 import static org.mockito.Mockito.when;
 
 import android.annotation.NonNull;
-import android.annotation.Nullable;
 import android.net.Network;
 import android.os.Handler;
 import android.os.HandlerThread;
@@ -65,17 +64,22 @@
     private static final String SERVICE_TYPE_2 = "_test._tcp.local";
     private static final Network NETWORK_1 = Mockito.mock(Network.class);
     private static final Network NETWORK_2 = Mockito.mock(Network.class);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_1_NULL_NETWORK =
-            Pair.create(SERVICE_TYPE_1, null);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_1_NETWORK_1 =
-            Pair.create(SERVICE_TYPE_1, NETWORK_1);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_2_NULL_NETWORK =
-            Pair.create(SERVICE_TYPE_2, null);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_2_NETWORK_1 =
-            Pair.create(SERVICE_TYPE_2, NETWORK_1);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_2_NETWORK_2 =
-            Pair.create(SERVICE_TYPE_2, NETWORK_2);
-
+    private static final SocketKey SOCKET_KEY_NULL_NETWORK =
+            new SocketKey(null /* network */, 999 /* interfaceIndex */);
+    private static final SocketKey SOCKET_KEY_NETWORK_1 =
+            new SocketKey(NETWORK_1, 998 /* interfaceIndex */);
+    private static final SocketKey SOCKET_KEY_NETWORK_2 =
+            new SocketKey(NETWORK_2, 997 /* interfaceIndex */);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_1_NULL_NETWORK =
+            Pair.create(SERVICE_TYPE_1, SOCKET_KEY_NULL_NETWORK);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_2_NULL_NETWORK =
+            Pair.create(SERVICE_TYPE_2, SOCKET_KEY_NULL_NETWORK);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_1_NETWORK_1 =
+            Pair.create(SERVICE_TYPE_1, SOCKET_KEY_NETWORK_1);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_2_NETWORK_1 =
+            Pair.create(SERVICE_TYPE_2, SOCKET_KEY_NETWORK_1);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_2_NETWORK_2 =
+            Pair.create(SERVICE_TYPE_2, SOCKET_KEY_NETWORK_2);
     @Mock private ExecutorProvider executorProvider;
     @Mock private MdnsSocketClientBase socketClient;
     @Mock private MdnsServiceTypeClient mockServiceTypeClientType1NullNetwork;
@@ -104,22 +108,22 @@
                 sharedLog) {
                     @Override
                     MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType,
-                            @Nullable Network network) {
-                        final Pair<String, Network> perNetworkServiceType =
-                                Pair.create(serviceType, network);
-                        if (perNetworkServiceType.equals(PER_NETWORK_SERVICE_TYPE_1_NULL_NETWORK)) {
+                            @NonNull SocketKey socketKey) {
+                        final Pair<String, SocketKey> perSocketServiceType =
+                                Pair.create(serviceType, socketKey);
+                        if (perSocketServiceType.equals(PER_SOCKET_SERVICE_TYPE_1_NULL_NETWORK)) {
                             return mockServiceTypeClientType1NullNetwork;
-                        } else if (perNetworkServiceType.equals(
-                                PER_NETWORK_SERVICE_TYPE_1_NETWORK_1)) {
+                        } else if (perSocketServiceType.equals(
+                                PER_SOCKET_SERVICE_TYPE_1_NETWORK_1)) {
                             return mockServiceTypeClientType1Network1;
-                        } else if (perNetworkServiceType.equals(
-                                PER_NETWORK_SERVICE_TYPE_2_NULL_NETWORK)) {
+                        } else if (perSocketServiceType.equals(
+                                PER_SOCKET_SERVICE_TYPE_2_NULL_NETWORK)) {
                             return mockServiceTypeClientType2NullNetwork;
-                        } else if (perNetworkServiceType.equals(
-                                PER_NETWORK_SERVICE_TYPE_2_NETWORK_1)) {
+                        } else if (perSocketServiceType.equals(
+                                PER_SOCKET_SERVICE_TYPE_2_NETWORK_1)) {
                             return mockServiceTypeClientType2Network1;
-                        } else if (perNetworkServiceType.equals(
-                                PER_NETWORK_SERVICE_TYPE_2_NETWORK_2)) {
+                        } else if (perSocketServiceType.equals(
+                                PER_SOCKET_SERVICE_TYPE_2_NETWORK_2)) {
                             return mockServiceTypeClientType2Network2;
                         }
                         return null;
@@ -156,7 +160,7 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, options);
-        runOnHandler(() -> callback.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(mockListenerOne, options);
 
         when(mockServiceTypeClientType1NullNetwork.stopSendAndReceive(mockListenerOne))
@@ -172,16 +176,16 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, options);
-        runOnHandler(() -> callback.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(mockListenerOne, options);
-        runOnHandler(() -> callback.onSocketCreated(NETWORK_1));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1Network1).startSendAndReceive(mockListenerOne, options);
 
         final SocketCreationCallback callback2 = expectSocketCreationCallback(
                 SERVICE_TYPE_2, mockListenerTwo, options);
-        runOnHandler(() -> callback2.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType2NullNetwork).startSendAndReceive(mockListenerTwo, options);
-        runOnHandler(() -> callback2.onSocketCreated(NETWORK_2));
+        runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NETWORK_2));
         verify(mockServiceTypeClientType2Network2).startSendAndReceive(mockListenerTwo, options);
     }
 
@@ -191,49 +195,48 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, options1);
-        runOnHandler(() -> callback.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(
                 mockListenerOne, options1);
-        runOnHandler(() -> callback.onSocketCreated(NETWORK_1));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1Network1).startSendAndReceive(mockListenerOne, options1);
 
         final MdnsSearchOptions options2 =
                 MdnsSearchOptions.newBuilder().setNetwork(NETWORK_2).build();
         final SocketCreationCallback callback2 = expectSocketCreationCallback(
                 SERVICE_TYPE_2, mockListenerTwo, options2);
-        runOnHandler(() -> callback2.onSocketCreated(NETWORK_2));
+        runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NETWORK_2));
         verify(mockServiceTypeClientType2Network2).startSendAndReceive(mockListenerTwo, options2);
 
         final MdnsPacket responseForServiceTypeOne = createMdnsPacket(SERVICE_TYPE_1);
-        final int ifIndex = 1;
         runOnHandler(() -> discoveryManager.onResponseReceived(
-                responseForServiceTypeOne, ifIndex, null /* network */));
+                responseForServiceTypeOne, SOCKET_KEY_NULL_NETWORK));
         // Packets for network null are only processed by the ServiceTypeClient for network null
         verify(mockServiceTypeClientType1NullNetwork).processResponse(responseForServiceTypeOne,
-                ifIndex, null /* network */);
+                SOCKET_KEY_NULL_NETWORK.getInterfaceIndex(), SOCKET_KEY_NULL_NETWORK.getNetwork());
         verify(mockServiceTypeClientType1Network1, never()).processResponse(any(), anyInt(), any());
         verify(mockServiceTypeClientType2Network2, never()).processResponse(any(), anyInt(), any());
 
         final MdnsPacket responseForServiceTypeTwo = createMdnsPacket(SERVICE_TYPE_2);
         runOnHandler(() -> discoveryManager.onResponseReceived(
-                responseForServiceTypeTwo, ifIndex, NETWORK_1));
+                responseForServiceTypeTwo, SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(), anyInt(),
-                eq(NETWORK_1));
+                eq(SOCKET_KEY_NETWORK_1.getNetwork()));
         verify(mockServiceTypeClientType1Network1).processResponse(responseForServiceTypeTwo,
-                ifIndex, NETWORK_1);
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
         verify(mockServiceTypeClientType2Network2, never()).processResponse(any(), anyInt(),
-                eq(NETWORK_1));
+                eq(SOCKET_KEY_NETWORK_1.getNetwork()));
 
         final MdnsPacket responseForSubtype =
                 createMdnsPacket("subtype._sub._googlecast._tcp.local");
         runOnHandler(() -> discoveryManager.onResponseReceived(
-                responseForSubtype, ifIndex, NETWORK_2));
-        verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(
-                any(), anyInt(), eq(NETWORK_2));
-        verify(mockServiceTypeClientType1Network1, never()).processResponse(
-                any(), anyInt(), eq(NETWORK_2));
-        verify(mockServiceTypeClientType2Network2).processResponse(
-                responseForSubtype, ifIndex, NETWORK_2);
+                responseForSubtype, SOCKET_KEY_NETWORK_2));
+        verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(), anyInt(),
+                eq(SOCKET_KEY_NETWORK_2.getNetwork()));
+        verify(mockServiceTypeClientType1Network1, never()).processResponse(any(), anyInt(),
+                eq(SOCKET_KEY_NETWORK_2.getNetwork()));
+        verify(mockServiceTypeClientType2Network2).processResponse(responseForSubtype,
+                SOCKET_KEY_NETWORK_2.getInterfaceIndex(), SOCKET_KEY_NETWORK_2.getNetwork());
     }
 
     @Test
@@ -243,55 +246,53 @@
                 MdnsSearchOptions.newBuilder().setNetwork(NETWORK_1).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, network1Options);
-        runOnHandler(() -> callback.onSocketCreated(NETWORK_1));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1Network1).startSendAndReceive(
                 mockListenerOne, network1Options);
 
         // Create a ServiceTypeClient for SERVICE_TYPE_2 and NETWORK_1
         final SocketCreationCallback callback2 = expectSocketCreationCallback(
                 SERVICE_TYPE_2, mockListenerTwo, network1Options);
-        runOnHandler(() -> callback2.onSocketCreated(NETWORK_1));
+        runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType2Network1).startSendAndReceive(
                 mockListenerTwo, network1Options);
 
         // Receive a response, it should be processed on both clients.
         final MdnsPacket response = createMdnsPacket(SERVICE_TYPE_1);
-        final int ifIndex = 1;
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, NETWORK_1));
-        verify(mockServiceTypeClientType1Network1).processResponse(response, ifIndex, NETWORK_1);
-        verify(mockServiceTypeClientType2Network1).processResponse(response, ifIndex, NETWORK_1);
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1));
+        verify(mockServiceTypeClientType1Network1).processResponse(response,
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
+        verify(mockServiceTypeClientType2Network1).processResponse(response,
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
 
         // The first callback receives a notification that the network has been destroyed,
         // mockServiceTypeClientOne1 should send service removed notifications and remove from the
         // list of clients.
-        runOnHandler(() -> callback.onAllSocketsDestroyed(NETWORK_1));
+        runOnHandler(() -> callback.onAllSocketsDestroyed(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1Network1).notifySocketDestroyed();
 
         // Receive a response again, it should be processed only on
         // mockServiceTypeClientType2Network1. Because the mockServiceTypeClientType1Network1 is
         // removed from the list of clients, it is no longer able to process responses.
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, NETWORK_1));
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1));
         // Still times(1) as a response was received once previously
-        verify(mockServiceTypeClientType1Network1, times(1))
-                .processResponse(response, ifIndex, NETWORK_1);
-        verify(mockServiceTypeClientType2Network1, times(2))
-                .processResponse(response, ifIndex, NETWORK_1);
+        verify(mockServiceTypeClientType1Network1, times(1)).processResponse(response,
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
+        verify(mockServiceTypeClientType2Network1, times(2)).processResponse(response,
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
 
         // The client for NETWORK_1 receives the callback that the NETWORK_2 has been destroyed,
         // mockServiceTypeClientTwo2 shouldn't send any notifications.
-        runOnHandler(() -> callback2.onAllSocketsDestroyed(NETWORK_2));
+        runOnHandler(() -> callback2.onAllSocketsDestroyed(SOCKET_KEY_NETWORK_2));
         verify(mockServiceTypeClientType2Network1, never()).notifySocketDestroyed();
 
         // Receive a response again, mockServiceTypeClientType2Network1 is still in the list of
         // clients, it's still able to process responses.
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, NETWORK_1));
-        verify(mockServiceTypeClientType1Network1, times(1))
-                .processResponse(response, ifIndex, NETWORK_1);
-        verify(mockServiceTypeClientType2Network1, times(3))
-                .processResponse(response, ifIndex, NETWORK_1);
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1));
+        verify(mockServiceTypeClientType1Network1, times(1)).processResponse(response,
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
+        verify(mockServiceTypeClientType2Network1, times(3)).processResponse(response,
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
     }
 
     @Test
@@ -301,27 +302,25 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, network1Options);
-        runOnHandler(() -> callback.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(
                 mockListenerOne, network1Options);
 
         // Receive a response, it should be processed on the client.
         final MdnsPacket response = createMdnsPacket(SERVICE_TYPE_1);
         final int ifIndex = 1;
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, null /* network */));
-        verify(mockServiceTypeClientType1NullNetwork).processResponse(
-                response, ifIndex, null /* network */);
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NULL_NETWORK));
+        verify(mockServiceTypeClientType1NullNetwork).processResponse(response,
+                SOCKET_KEY_NULL_NETWORK.getInterfaceIndex(), SOCKET_KEY_NULL_NETWORK.getNetwork());
 
-        runOnHandler(() -> callback.onAllSocketsDestroyed(null /* network */));
+        runOnHandler(() -> callback.onAllSocketsDestroyed(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).notifySocketDestroyed();
 
         // Receive a response again, it should not be processed.
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, null /* network */));
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NULL_NETWORK));
         // Still times(1) as a response was received once previously
-        verify(mockServiceTypeClientType1NullNetwork, times(1))
-                .processResponse(response, ifIndex, null /* network */);
+        verify(mockServiceTypeClientType1NullNetwork, times(1)).processResponse(response,
+                SOCKET_KEY_NULL_NETWORK.getInterfaceIndex(), SOCKET_KEY_NULL_NETWORK.getNetwork());
 
         // Unregister the listener, notifyNetworkUnrequested should be called but other stop methods
         // won't be call because the service type client was unregistered and destroyed. But those
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
index a0a302f..f7ef077 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
@@ -21,7 +21,6 @@
 
 import static org.junit.Assert.assertEquals;
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
@@ -132,11 +131,11 @@
         doReturn(null).when(tetherSocketKey2).getNetwork();
         // Notify socket created
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
         callback.onSocketCreated(tetherSocketKey1, tetherIfaceSock1, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(null);
+        verify(mSocketCreationCallback).onSocketCreated(tetherSocketKey1);
         callback.onSocketCreated(tetherSocketKey2, tetherIfaceSock2, List.of());
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
+        verify(mSocketCreationCallback).onSocketCreated(tetherSocketKey2);
 
         // Send packet to IPv4 with target network and verify sending has been called.
         mSocketClient.sendMulticastPacket(ipv4Packet, mNetwork);
@@ -172,7 +171,7 @@
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
         // Notify socket created
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
 
         final ArgumentCaptor<PacketHandler> handlerCaptor =
                 ArgumentCaptor.forClass(PacketHandler.class);
@@ -183,7 +182,7 @@
         handler.handlePacket(data, data.length, null /* src */);
         final ArgumentCaptor<MdnsPacket> responseCaptor =
                 ArgumentCaptor.forClass(MdnsPacket.class);
-        verify(mCallback).onResponseReceived(responseCaptor.capture(), anyInt(), any());
+        verify(mCallback).onResponseReceived(responseCaptor.capture(), any());
         final MdnsPacket response = responseCaptor.getValue();
         assertEquals(0, response.questions.size());
         assertEquals(0, response.additionalRecords.size());
@@ -222,12 +221,13 @@
         doReturn(createEmptyNetworkInterface()).when(socket3).getInterface();
 
         final SocketKey socketKey2 = mock(SocketKey.class);
-        doReturn(null).when(socketKey2).getNetwork();
+        final SocketKey socketKey3 = mock(SocketKey.class);
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
         callback.onSocketCreated(socketKey2, socket2, List.of());
-        callback.onSocketCreated(socketKey2, socket3, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
+        callback.onSocketCreated(socketKey3, socket3, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
+        verify(mSocketCreationCallback).onSocketCreated(socketKey2);
+        verify(mSocketCreationCallback).onSocketCreated(socketKey3);
 
         // Send IPv4 packet on the non-null Network and verify sending has been called.
         mSocketClient.sendMulticastPacket(ipv4Packet, mNetwork);
@@ -252,9 +252,10 @@
         // Notify socket created for all networks.
         callback2.onSocketCreated(mSocketKey, mSocket, List.of());
         callback2.onSocketCreated(socketKey2, socket2, List.of());
-        callback2.onSocketCreated(socketKey2, socket3, List.of());
-        verify(socketCreationCb2).onSocketCreated(mNetwork);
-        verify(socketCreationCb2, times(2)).onSocketCreated(null);
+        callback2.onSocketCreated(socketKey3, socket3, List.of());
+        verify(socketCreationCb2).onSocketCreated(mSocketKey);
+        verify(socketCreationCb2).onSocketCreated(socketKey2);
+        verify(socketCreationCb2).onSocketCreated(socketKey3);
 
         // Send IPv4 packet to null network and verify sending to the 2 tethered interface sockets.
         mSocketClient.sendMulticastPacket(ipv4Packet, null);
@@ -296,16 +297,16 @@
         doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
 
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
         callback.onSocketCreated(mSocketKey, otherSocket, List.of());
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(mNetwork);
+        verify(mSocketCreationCallback, times(2)).onSocketCreated(mSocketKey);
 
-        verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(mNetwork);
+        verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(mSocketKey);
         mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener));
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
 
         verify(mProvider).unrequestSocket(callback);
-        verify(mSocketCreationCallback).onAllSocketsDestroyed(mNetwork);
+        verify(mSocketCreationCallback).onAllSocketsDestroyed(mSocketKey);
     }
 
     @Test
@@ -316,14 +317,14 @@
         doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
 
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
         callback.onSocketCreated(mSocketKey, otherSocket, List.of());
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(mNetwork);
+        verify(mSocketCreationCallback, times(2)).onSocketCreated(mSocketKey);
 
         // Notify socket destroyed
         callback.onInterfaceDestroyed(mSocketKey, mSocket);
         verifyNoMoreInteractions(mSocketCreationCallback);
         callback.onInterfaceDestroyed(mSocketKey, otherSocket);
-        verify(mSocketCreationCallback).onAllSocketsDestroyed(mNetwork);
+        verify(mSocketCreationCallback).onAllSocketsDestroyed(mSocketKey);
     }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index d1adecf..635a1d4 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -118,6 +118,7 @@
     private FakeExecutor currentThreadExecutor = new FakeExecutor();
 
     private MdnsServiceTypeClient client;
+    private SocketKey socketKey;
 
     @Before
     @SuppressWarnings("DoNotMock")
@@ -128,6 +129,7 @@
         expectedIPv4Packets = new DatagramPacket[16];
         expectedIPv6Packets = new DatagramPacket[16];
         expectedSendFutures = new ScheduledFuture<?>[16];
+        socketKey = new SocketKey(mockNetwork, INTERFACE_INDEX);
 
         for (int i = 0; i < expectedSendFutures.length; ++i) {
             expectedIPv4Packets[i] = new DatagramPacket(buf, 0 /* offset */, 5 /* length */,
@@ -174,7 +176,7 @@
 
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, mockNetwork, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -325,7 +327,7 @@
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build();
         QueryTaskConfig config = new QueryTaskConfig(
-                searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, mockNetwork);
+                searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, socketKey);
 
         // This is the first query. We will ask for unicast response.
         assertTrue(config.expectUnicastResponse);
@@ -354,7 +356,7 @@
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build();
         QueryTaskConfig config = new QueryTaskConfig(
-                searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, mockNetwork);
+                searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, socketKey);
 
         // This is the first query. We will ask for unicast response.
         assertTrue(config.expectUnicastResponse);
@@ -508,9 +510,9 @@
 
         // Process a second response with a different port and updated text attributes.
         client.processResponse(createResponse(
-                "service-instance-1", ipV4Address, 5354,
-                /* subtype= */ "ABCDE",
-                Collections.singletonMap("key", "value"), TEST_TTL),
+                        "service-instance-1", ipV4Address, 5354,
+                        /* subtype= */ "ABCDE",
+                        Collections.singletonMap("key", "value"), TEST_TTL),
                 /* interfaceIndex= */ 20, mockNetwork);
 
         // Verify onServiceNameDiscovered was called once for the initial response.
@@ -563,9 +565,9 @@
 
         // Process a second response with a different port and updated text attributes.
         client.processResponse(createResponse(
-                "service-instance-1", ipV6Address, 5354,
-                /* subtype= */ "ABCDE",
-                Collections.singletonMap("key", "value"), TEST_TTL),
+                        "service-instance-1", ipV6Address, 5354,
+                        /* subtype= */ "ABCDE",
+                        Collections.singletonMap("key", "value"), TEST_TTL),
                 /* interfaceIndex= */ 20, mockNetwork);
 
         // Verify onServiceNameDiscovered was called once for the initial response.
@@ -709,7 +711,7 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, mockNetwork, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -750,7 +752,7 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, mockNetwork, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -783,7 +785,7 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, mockNetwork, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -835,8 +837,8 @@
 
         // Process the last response which is goodbye message (with the main type, not subtype).
         client.processResponse(createResponse(
-                serviceName, ipV6Address, 5354, SERVICE_TYPE_LABELS,
-                Collections.singletonMap("key", "value"), /* ptrTtlMillis= */ 0L),
+                        serviceName, ipV6Address, 5354, SERVICE_TYPE_LABELS,
+                        Collections.singletonMap("key", "value"), /* ptrTtlMillis= */ 0L),
                 INTERFACE_INDEX, mockNetwork);
 
         // Verify onServiceNameDiscovered was first called for the initial response.
@@ -908,7 +910,7 @@
     @Test
     public void testProcessResponse_Resolve() throws Exception {
         client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);
+                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
 
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
@@ -998,7 +1000,7 @@
     @Test
     public void testRenewTxtSrvInResolve() throws Exception {
         client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                mockDecoderClock, mockNetwork, mockSharedLog);
+                mockDecoderClock, socketKey, mockSharedLog);
 
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
@@ -1102,7 +1104,7 @@
     @Test
     public void testProcessResponse_ResolveExcludesOtherServices() {
         client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);
+                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
 
         final String requestedInstance = "instance1";
         final String otherInstance = "instance2";
@@ -1119,13 +1121,13 @@
 
         // Complete response from instanceName
         client.processResponse(createResponse(
-                requestedInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
+                        requestedInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
                 INTERFACE_INDEX, mockNetwork);
 
         // Complete response from otherInstanceName
         client.processResponse(createResponse(
-                otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
+                        otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
                 INTERFACE_INDEX, mockNetwork);
 
@@ -1166,7 +1168,7 @@
     @Test
     public void testProcessResponse_SubtypeDiscoveryLimitedToSubtype() {
         client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);
+                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
 
         final String matchingInstance = "instance1";
         final String subtype = "_subtype";
@@ -1247,7 +1249,7 @@
     @Test
     public void testNotifySocketDestroyed() throws Exception {
         client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);
+                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
 
         final String requestedInstance = "instance1";
         final String otherInstance = "instance2";
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
index abb1747..e30c249 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
@@ -23,6 +23,7 @@
 import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.never;
@@ -370,7 +371,7 @@
         mdnsClient.startDiscovery();
 
         verify(mockCallback, timeout(TIMEOUT).atLeast(1))
-                .onResponseReceived(any(MdnsPacket.class), anyInt(), any());
+                .onResponseReceived(any(MdnsPacket.class), any(SocketKey.class));
     }
 
     @Test
@@ -379,7 +380,7 @@
         mdnsClient.startDiscovery();
 
         verify(mockCallback, timeout(TIMEOUT).atLeastOnce())
-                .onResponseReceived(any(MdnsPacket.class), anyInt(), any());
+                .onResponseReceived(any(MdnsPacket.class), any(SocketKey.class));
 
         mdnsClient.stopDiscovery();
     }
@@ -513,7 +514,7 @@
         mdnsClient.startDiscovery();
 
         verify(mockCallback, timeout(TIMEOUT).atLeastOnce())
-                .onResponseReceived(any(), eq(21), any());
+                .onResponseReceived(any(), argThat(key -> key.getInterfaceIndex() == 21));
     }
 
     @Test
@@ -536,6 +537,7 @@
         mdnsClient.startDiscovery();
 
         verify(mockMulticastSocket, never()).getInterfaceIndex();
-        verify(mockCallback, timeout(TIMEOUT).atLeast(1)).onResponseReceived(any(), eq(-1), any());
+        verify(mockCallback, timeout(TIMEOUT).atLeast(1))
+                .onResponseReceived(any(), argThat(key -> key.getInterfaceIndex() == -1));
     }
 }
\ No newline at end of file