Create an MdnsServiceTypeClient using a SocketKey

Now, MdnsServiceTypeClient is created using a network. However,
all tether interfaces use the same network (null), which
means they use the same client for the same service type. This is
not the intended behavior, as each interface should have its own
client. Therefore, MdnsServiceTypeClient creation should be
changed to use SocketKey, which includes both the network and
interface index. This will allow each interface to have its own
client.

Bug: 278018903
Test: atest FrameworksNetTests android.net.cts.NsdManagerTest
Change-Id: I34b7d983f00b67198befb5bf71fc511cf0dabae6
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index 39fceb9..c5d1bf6 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -22,7 +22,6 @@
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.annotation.RequiresPermission;
-import android.net.Network;
 import android.os.Handler;
 import android.os.HandlerThread;
 import android.util.ArrayMap;
@@ -36,7 +35,6 @@
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
-import java.util.Objects;
 
 /**
  * This class keeps tracking the set of registered {@link MdnsServiceBrowserListener} instances, and
@@ -50,54 +48,58 @@
     private final MdnsSocketClientBase socketClient;
     @NonNull private final SharedLog sharedLog;
 
-    @NonNull private final PerNetworkServiceTypeClients perNetworkServiceTypeClients;
+    @NonNull private final PerSocketServiceTypeClients perSocketServiceTypeClients;
     @NonNull private final Handler handler;
     @Nullable private final HandlerThread handlerThread;
 
-    private static class PerNetworkServiceTypeClients {
-        private final ArrayMap<Pair<String, Network>, MdnsServiceTypeClient> clients =
+    private static class PerSocketServiceTypeClients {
+        private final ArrayMap<Pair<String, SocketKey>, MdnsServiceTypeClient> clients =
                 new ArrayMap<>();
 
-        public void put(@NonNull String serviceType, @Nullable Network network,
+        public void put(@NonNull String serviceType, @NonNull SocketKey socketKey,
                 @NonNull MdnsServiceTypeClient client) {
             final String dnsLowerServiceType = MdnsUtils.toDnsLowerCase(serviceType);
-            final Pair<String, Network> perNetworkServiceType = new Pair<>(dnsLowerServiceType,
-                    network);
-            clients.put(perNetworkServiceType, client);
+            final Pair<String, SocketKey> perSocketServiceType = new Pair<>(dnsLowerServiceType,
+                    socketKey);
+            clients.put(perSocketServiceType, client);
         }
 
         @Nullable
-        public MdnsServiceTypeClient get(@NonNull String serviceType, @Nullable Network network) {
+        public MdnsServiceTypeClient get(
+                @NonNull String serviceType, @NonNull SocketKey socketKey) {
             final String dnsLowerServiceType = MdnsUtils.toDnsLowerCase(serviceType);
-            final Pair<String, Network> perNetworkServiceType = new Pair<>(dnsLowerServiceType,
-                    network);
-            return clients.getOrDefault(perNetworkServiceType, null);
+            final Pair<String, SocketKey> perSocketServiceType = new Pair<>(dnsLowerServiceType,
+                    socketKey);
+            return clients.getOrDefault(perSocketServiceType, null);
         }
 
         public List<MdnsServiceTypeClient> getByServiceType(@NonNull String serviceType) {
             final String dnsLowerServiceType = MdnsUtils.toDnsLowerCase(serviceType);
             final List<MdnsServiceTypeClient> list = new ArrayList<>();
             for (int i = 0; i < clients.size(); i++) {
-                final Pair<String, Network> perNetworkServiceType = clients.keyAt(i);
-                if (dnsLowerServiceType.equals(perNetworkServiceType.first)) {
+                final Pair<String, SocketKey> perSocketServiceType = clients.keyAt(i);
+                if (dnsLowerServiceType.equals(perSocketServiceType.first)) {
                     list.add(clients.valueAt(i));
                 }
             }
             return list;
         }
 
-        public List<MdnsServiceTypeClient> getByNetwork(@Nullable Network network) {
+        public List<MdnsServiceTypeClient> getBySocketKey(@NonNull SocketKey socketKey) {
             final List<MdnsServiceTypeClient> list = new ArrayList<>();
             for (int i = 0; i < clients.size(); i++) {
-                final Pair<String, Network> perNetworkServiceType = clients.keyAt(i);
-                final Network serviceTypeNetwork = perNetworkServiceType.second;
-                if (Objects.equals(network, serviceTypeNetwork)) {
+                final Pair<String, SocketKey> perSocketServiceType = clients.keyAt(i);
+                if (socketKey.equals(perSocketServiceType.second)) {
                     list.add(clients.valueAt(i));
                 }
             }
             return list;
         }
 
+        public List<MdnsServiceTypeClient> getAllMdnsServiceTypeClient() {
+            return new ArrayList<>(clients.values());
+        }
+
         public void remove(@NonNull MdnsServiceTypeClient client) {
             final int index = clients.indexOfValue(client);
             clients.removeAt(index);
@@ -113,7 +115,7 @@
         this.executorProvider = executorProvider;
         this.socketClient = socketClient;
         this.sharedLog = sharedLog;
-        this.perNetworkServiceTypeClients = new PerNetworkServiceTypeClients();
+        this.perSocketServiceTypeClients = new PerSocketServiceTypeClients();
         if (socketClient.getLooper() != null) {
             this.handlerThread = null;
             this.handler = new Handler(socketClient.getLooper());
@@ -164,7 +166,7 @@
             @NonNull String serviceType,
             @NonNull MdnsServiceBrowserListener listener,
             @NonNull MdnsSearchOptions searchOptions) {
-        if (perNetworkServiceTypeClients.isEmpty()) {
+        if (perSocketServiceTypeClients.isEmpty()) {
             // First listener. Starts the socket client.
             try {
                 socketClient.startDiscovery();
@@ -177,29 +179,29 @@
         socketClient.notifyNetworkRequested(listener, searchOptions.getNetwork(),
                 new MdnsSocketClientBase.SocketCreationCallback() {
                     @Override
-                    public void onSocketCreated(@Nullable Network network) {
+                    public void onSocketCreated(@NonNull SocketKey socketKey) {
                         ensureRunningOnHandlerThread(handler);
                         // All listeners of the same service types shares the same
                         // MdnsServiceTypeClient.
                         MdnsServiceTypeClient serviceTypeClient =
-                                perNetworkServiceTypeClients.get(serviceType, network);
+                                perSocketServiceTypeClients.get(serviceType, socketKey);
                         if (serviceTypeClient == null) {
-                            serviceTypeClient = createServiceTypeClient(serviceType, network);
-                            perNetworkServiceTypeClients.put(serviceType, network,
+                            serviceTypeClient = createServiceTypeClient(serviceType, socketKey);
+                            perSocketServiceTypeClients.put(serviceType, socketKey,
                                     serviceTypeClient);
                         }
                         serviceTypeClient.startSendAndReceive(listener, searchOptions);
                     }
 
                     @Override
-                    public void onAllSocketsDestroyed(@Nullable Network network) {
+                    public void onAllSocketsDestroyed(@NonNull SocketKey socketKey) {
                         ensureRunningOnHandlerThread(handler);
                         final MdnsServiceTypeClient serviceTypeClient =
-                                perNetworkServiceTypeClients.get(serviceType, network);
+                                perSocketServiceTypeClients.get(serviceType, socketKey);
                         if (serviceTypeClient == null) return;
                         // Notify all listeners that all services are removed from this socket.
                         serviceTypeClient.notifySocketDestroyed();
-                        perNetworkServiceTypeClients.remove(serviceTypeClient);
+                        perSocketServiceTypeClients.remove(serviceTypeClient);
                     }
                 });
     }
@@ -224,7 +226,7 @@
         socketClient.notifyNetworkUnrequested(listener);
 
         final List<MdnsServiceTypeClient> serviceTypeClients =
-                perNetworkServiceTypeClients.getByServiceType(serviceType);
+                perSocketServiceTypeClients.getByServiceType(serviceType);
         if (serviceTypeClients.isEmpty()) {
             return;
         }
@@ -233,60 +235,58 @@
             if (serviceTypeClient.stopSendAndReceive(listener)) {
                 // No listener is registered for the service type anymore, remove it from the list
                 // of the service type clients.
-                perNetworkServiceTypeClients.remove(serviceTypeClient);
+                perSocketServiceTypeClients.remove(serviceTypeClient);
             }
         }
-        if (perNetworkServiceTypeClients.isEmpty()) {
+        if (perSocketServiceTypeClients.isEmpty()) {
             // No discovery request. Stops the socket client.
             socketClient.stopDiscovery();
         }
     }
 
     @Override
-    public void onResponseReceived(@NonNull MdnsPacket packet,
-            int interfaceIndex, @Nullable Network network) {
+    public void onResponseReceived(@NonNull MdnsPacket packet, @NonNull SocketKey socketKey) {
         checkAndRunOnHandlerThread(() ->
-                handleOnResponseReceived(packet, interfaceIndex, network));
+                handleOnResponseReceived(packet, socketKey));
     }
 
-    private void handleOnResponseReceived(@NonNull MdnsPacket packet, int interfaceIndex,
-            @Nullable Network network) {
-        for (MdnsServiceTypeClient serviceTypeClient
-                : getMdnsServiceTypeClient(network)) {
-            serviceTypeClient.processResponse(packet, interfaceIndex, network);
+    private void handleOnResponseReceived(@NonNull MdnsPacket packet,
+            @NonNull SocketKey socketKey) {
+        for (MdnsServiceTypeClient serviceTypeClient : getMdnsServiceTypeClient(socketKey)) {
+            serviceTypeClient.processResponse(
+                    packet, socketKey.getInterfaceIndex(), socketKey.getNetwork());
         }
     }
 
-    private List<MdnsServiceTypeClient> getMdnsServiceTypeClient(@Nullable Network network) {
+    private List<MdnsServiceTypeClient> getMdnsServiceTypeClient(@NonNull SocketKey socketKey) {
         if (socketClient.supportsRequestingSpecificNetworks()) {
-            return perNetworkServiceTypeClients.getByNetwork(network);
+            return perSocketServiceTypeClients.getBySocketKey(socketKey);
         } else {
-            return perNetworkServiceTypeClients.getByNetwork(null);
+            return perSocketServiceTypeClients.getAllMdnsServiceTypeClient();
         }
     }
 
     @Override
     public void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode,
-            @Nullable Network network) {
+            @NonNull SocketKey socketKey) {
         checkAndRunOnHandlerThread(() ->
-                handleOnFailedToParseMdnsResponse(receivedPacketNumber, errorCode, network));
+                handleOnFailedToParseMdnsResponse(receivedPacketNumber, errorCode, socketKey));
     }
 
     private void handleOnFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode,
-            @Nullable Network network) {
-        for (MdnsServiceTypeClient serviceTypeClient
-                : getMdnsServiceTypeClient(network)) {
+            @NonNull SocketKey socketKey) {
+        for (MdnsServiceTypeClient serviceTypeClient : getMdnsServiceTypeClient(socketKey)) {
             serviceTypeClient.onFailedToParseMdnsResponse(receivedPacketNumber, errorCode);
         }
     }
 
     @VisibleForTesting
     MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType,
-            @Nullable Network network) {
-        sharedLog.log("createServiceTypeClient for type:" + serviceType + ", net:" + network);
+            @NonNull SocketKey socketKey) {
+        sharedLog.log("createServiceTypeClient for type:" + serviceType + " " + socketKey);
         return new MdnsServiceTypeClient(
                 serviceType, socketClient,
-                executorProvider.newServiceTypeClientSchedulerExecutor(), network,
-                sharedLog.forSubComponent(serviceType + "-" + network));
+                executorProvider.newServiceTypeClientSchedulerExecutor(), socketKey,
+                sharedLog.forSubComponent(serviceType + "-" + socketKey));
     }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
index 03be681..d0ca20e 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -84,7 +84,7 @@
             }
             socket.addPacketHandler(handler);
             mActiveNetworkSockets.put(socket, socketKey);
-            mSocketCreationCallback.onSocketCreated(socketKey.getNetwork());
+            mSocketCreationCallback.onSocketCreated(socketKey);
         }
 
         @Override
@@ -97,7 +97,7 @@
         private void notifySocketDestroyed(@NonNull MdnsInterfaceSocket socket) {
             final SocketKey socketKey = mActiveNetworkSockets.remove(socket);
             if (!isAnySocketActive(socketKey)) {
-                mSocketCreationCallback.onAllSocketsDestroyed(socketKey.getNetwork());
+                mSocketCreationCallback.onAllSocketsDestroyed(socketKey);
             }
         }
 
@@ -247,16 +247,14 @@
             if (e.code != MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE) {
                 Log.e(TAG, e.getMessage(), e);
                 if (mCallback != null) {
-                    mCallback.onFailedToParseMdnsResponse(
-                            packetNumber, e.code, socketKey.getNetwork());
+                    mCallback.onFailedToParseMdnsResponse(packetNumber, e.code, socketKey);
                 }
             }
             return;
         }
 
         if (mCallback != null) {
-            mCallback.onResponseReceived(
-                    response, socketKey.getInterfaceIndex(), socketKey.getNetwork());
+            mCallback.onResponseReceived(response, socketKey);
         }
     }
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index 49a376c..6cca0f5 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -58,7 +58,7 @@
     private final MdnsSocketClientBase socketClient;
     private final MdnsResponseDecoder responseDecoder;
     private final ScheduledExecutorService executor;
-    @Nullable private final Network network;
+    @NonNull private final SocketKey socketKey;
     @NonNull private final SharedLog sharedLog;
     private final Object lock = new Object();
     private final ArrayMap<MdnsServiceBrowserListener, MdnsSearchOptions> listeners =
@@ -90,9 +90,9 @@
             @NonNull String serviceType,
             @NonNull MdnsSocketClientBase socketClient,
             @NonNull ScheduledExecutorService executor,
-            @Nullable Network network,
+            @NonNull SocketKey socketKey,
             @NonNull SharedLog sharedLog) {
-        this(serviceType, socketClient, executor, new MdnsResponseDecoder.Clock(), network,
+        this(serviceType, socketClient, executor, new MdnsResponseDecoder.Clock(), socketKey,
                 sharedLog);
     }
 
@@ -102,7 +102,7 @@
             @NonNull MdnsSocketClientBase socketClient,
             @NonNull ScheduledExecutorService executor,
             @NonNull MdnsResponseDecoder.Clock clock,
-            @Nullable Network network,
+            @NonNull SocketKey socketKey,
             @NonNull SharedLog sharedLog) {
         this.serviceType = serviceType;
         this.socketClient = socketClient;
@@ -110,7 +110,7 @@
         this.serviceTypeLabels = TextUtils.split(serviceType, "\\.");
         this.responseDecoder = new MdnsResponseDecoder(clock, serviceTypeLabels);
         this.clock = clock;
-        this.network = network;
+        this.socketKey = socketKey;
         this.sharedLog = sharedLog;
     }
 
@@ -199,7 +199,7 @@
                     searchOptions.getSubtypes(),
                     searchOptions.isPassiveMode(),
                     currentSessionId,
-                    network);
+                    socketKey);
             if (hadReply) {
                 requestTaskFuture = scheduleNextRunLocked(taskConfig);
             } else {
@@ -432,10 +432,10 @@
         private int burstCounter;
         private int timeToRunNextTaskInMs;
         private boolean isFirstBurst;
-        @Nullable private final Network network;
+        @NonNull private final SocketKey socketKey;
 
         QueryTaskConfig(@NonNull Collection<String> subtypes, boolean usePassiveMode,
-                long sessionId, @Nullable Network network) {
+                long sessionId, @NonNull SocketKey socketKey) {
             this.usePassiveMode = usePassiveMode;
             this.subtypes = new ArrayList<>(subtypes);
             this.queriesPerBurst = QUERIES_PER_BURST;
@@ -457,7 +457,7 @@
                 // doubles until it maxes out at TIME_BETWEEN_BURSTS_MS.
                 this.timeBetweenBurstsInMs = INITIAL_TIME_BETWEEN_BURSTS_MS;
             }
-            this.network = network;
+            this.socketKey = socketKey;
         }
 
         QueryTaskConfig getConfigForNextRun() {
@@ -540,7 +540,7 @@
                 // Only the names are used to know which queries to send, other parameters like
                 // interfaceIndex do not matter.
                 servicesToResolve = makeResponsesForResolve(
-                        0 /* interfaceIndex */, config.network);
+                        0 /* interfaceIndex */, config.socketKey.getNetwork());
                 sendDiscoveryQueries = servicesToResolve.size() < listeners.size();
             }
             Pair<Integer, List<String>> result;
@@ -553,7 +553,7 @@
                                 config.subtypes,
                                 config.expectUnicastResponse,
                                 config.transactionId,
-                                config.network,
+                                config.socketKey.getNetwork(),
                                 sendDiscoveryQueries,
                                 servicesToResolve,
                                 clock)
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
index b982644..2b6e5d0 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
@@ -235,7 +235,7 @@
             throw new IllegalArgumentException("This socket client does not support requesting "
                     + "specific networks");
         }
-        socketCreationCallback.onSocketCreated(null);
+        socketCreationCallback.onSocketCreated(new SocketKey(multicastSocket.getInterfaceIndex()));
     }
 
     @Override
@@ -456,7 +456,8 @@
             LOGGER.w(String.format("Error while decoding %s packet (%d): %d",
                     responseType, packetNumber, e.code));
             if (callback != null) {
-                callback.onFailedToParseMdnsResponse(packetNumber, e.code, network);
+                callback.onFailedToParseMdnsResponse(packetNumber, e.code,
+                        new SocketKey(network, interfaceIndex));
             }
             return e.code;
         }
@@ -466,7 +467,8 @@
         }
 
         if (callback != null) {
-            callback.onResponseReceived(response, interfaceIndex, network);
+            callback.onResponseReceived(
+                    response, new SocketKey(network, interfaceIndex));
         }
 
         return MdnsResponseErrorCode.SUCCESS;
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
index e0762f9..a35925a 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
@@ -73,20 +73,19 @@
     /*** Callback for mdns response  */
     interface Callback {
         /*** Receive a mdns response */
-        void onResponseReceived(@NonNull MdnsPacket packet, int interfaceIndex,
-                @Nullable Network network);
+        void onResponseReceived(@NonNull MdnsPacket packet, @NonNull SocketKey socketKey);
 
         /*** Parse a mdns response failed */
         void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode,
-                @Nullable Network network);
+                @NonNull SocketKey socketKey);
     }
 
     /*** Callback for requested socket creation  */
     interface SocketCreationCallback {
         /*** Notify requested socket is created */
-        void onSocketCreated(@Nullable Network network);
+        void onSocketCreated(@NonNull SocketKey socketKey);
 
         /*** Notify requested socket is destroyed */
-        void onAllSocketsDestroyed(@Nullable Network network);
+        void onAllSocketsDestroyed(@NonNull SocketKey socketKey);
     }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
index a24664e..d2298fe 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
@@ -28,7 +28,6 @@
 import static org.mockito.Mockito.when;
 
 import android.annotation.NonNull;
-import android.annotation.Nullable;
 import android.net.Network;
 import android.os.Handler;
 import android.os.HandlerThread;
@@ -65,17 +64,22 @@
     private static final String SERVICE_TYPE_2 = "_test._tcp.local";
     private static final Network NETWORK_1 = Mockito.mock(Network.class);
     private static final Network NETWORK_2 = Mockito.mock(Network.class);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_1_NULL_NETWORK =
-            Pair.create(SERVICE_TYPE_1, null);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_1_NETWORK_1 =
-            Pair.create(SERVICE_TYPE_1, NETWORK_1);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_2_NULL_NETWORK =
-            Pair.create(SERVICE_TYPE_2, null);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_2_NETWORK_1 =
-            Pair.create(SERVICE_TYPE_2, NETWORK_1);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_2_NETWORK_2 =
-            Pair.create(SERVICE_TYPE_2, NETWORK_2);
-
+    private static final SocketKey SOCKET_KEY_NULL_NETWORK =
+            new SocketKey(null /* network */, 999 /* interfaceIndex */);
+    private static final SocketKey SOCKET_KEY_NETWORK_1 =
+            new SocketKey(NETWORK_1, 998 /* interfaceIndex */);
+    private static final SocketKey SOCKET_KEY_NETWORK_2 =
+            new SocketKey(NETWORK_2, 997 /* interfaceIndex */);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_1_NULL_NETWORK =
+            Pair.create(SERVICE_TYPE_1, SOCKET_KEY_NULL_NETWORK);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_2_NULL_NETWORK =
+            Pair.create(SERVICE_TYPE_2, SOCKET_KEY_NULL_NETWORK);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_1_NETWORK_1 =
+            Pair.create(SERVICE_TYPE_1, SOCKET_KEY_NETWORK_1);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_2_NETWORK_1 =
+            Pair.create(SERVICE_TYPE_2, SOCKET_KEY_NETWORK_1);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_2_NETWORK_2 =
+            Pair.create(SERVICE_TYPE_2, SOCKET_KEY_NETWORK_2);
     @Mock private ExecutorProvider executorProvider;
     @Mock private MdnsSocketClientBase socketClient;
     @Mock private MdnsServiceTypeClient mockServiceTypeClientType1NullNetwork;
@@ -104,22 +108,22 @@
                 sharedLog) {
                     @Override
                     MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType,
-                            @Nullable Network network) {
-                        final Pair<String, Network> perNetworkServiceType =
-                                Pair.create(serviceType, network);
-                        if (perNetworkServiceType.equals(PER_NETWORK_SERVICE_TYPE_1_NULL_NETWORK)) {
+                            @NonNull SocketKey socketKey) {
+                        final Pair<String, SocketKey> perSocketServiceType =
+                                Pair.create(serviceType, socketKey);
+                        if (perSocketServiceType.equals(PER_SOCKET_SERVICE_TYPE_1_NULL_NETWORK)) {
                             return mockServiceTypeClientType1NullNetwork;
-                        } else if (perNetworkServiceType.equals(
-                                PER_NETWORK_SERVICE_TYPE_1_NETWORK_1)) {
+                        } else if (perSocketServiceType.equals(
+                                PER_SOCKET_SERVICE_TYPE_1_NETWORK_1)) {
                             return mockServiceTypeClientType1Network1;
-                        } else if (perNetworkServiceType.equals(
-                                PER_NETWORK_SERVICE_TYPE_2_NULL_NETWORK)) {
+                        } else if (perSocketServiceType.equals(
+                                PER_SOCKET_SERVICE_TYPE_2_NULL_NETWORK)) {
                             return mockServiceTypeClientType2NullNetwork;
-                        } else if (perNetworkServiceType.equals(
-                                PER_NETWORK_SERVICE_TYPE_2_NETWORK_1)) {
+                        } else if (perSocketServiceType.equals(
+                                PER_SOCKET_SERVICE_TYPE_2_NETWORK_1)) {
                             return mockServiceTypeClientType2Network1;
-                        } else if (perNetworkServiceType.equals(
-                                PER_NETWORK_SERVICE_TYPE_2_NETWORK_2)) {
+                        } else if (perSocketServiceType.equals(
+                                PER_SOCKET_SERVICE_TYPE_2_NETWORK_2)) {
                             return mockServiceTypeClientType2Network2;
                         }
                         return null;
@@ -156,7 +160,7 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, options);
-        runOnHandler(() -> callback.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(mockListenerOne, options);
 
         when(mockServiceTypeClientType1NullNetwork.stopSendAndReceive(mockListenerOne))
@@ -172,16 +176,16 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, options);
-        runOnHandler(() -> callback.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(mockListenerOne, options);
-        runOnHandler(() -> callback.onSocketCreated(NETWORK_1));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1Network1).startSendAndReceive(mockListenerOne, options);
 
         final SocketCreationCallback callback2 = expectSocketCreationCallback(
                 SERVICE_TYPE_2, mockListenerTwo, options);
-        runOnHandler(() -> callback2.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType2NullNetwork).startSendAndReceive(mockListenerTwo, options);
-        runOnHandler(() -> callback2.onSocketCreated(NETWORK_2));
+        runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NETWORK_2));
         verify(mockServiceTypeClientType2Network2).startSendAndReceive(mockListenerTwo, options);
     }
 
@@ -191,49 +195,48 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, options1);
-        runOnHandler(() -> callback.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(
                 mockListenerOne, options1);
-        runOnHandler(() -> callback.onSocketCreated(NETWORK_1));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1Network1).startSendAndReceive(mockListenerOne, options1);
 
         final MdnsSearchOptions options2 =
                 MdnsSearchOptions.newBuilder().setNetwork(NETWORK_2).build();
         final SocketCreationCallback callback2 = expectSocketCreationCallback(
                 SERVICE_TYPE_2, mockListenerTwo, options2);
-        runOnHandler(() -> callback2.onSocketCreated(NETWORK_2));
+        runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NETWORK_2));
         verify(mockServiceTypeClientType2Network2).startSendAndReceive(mockListenerTwo, options2);
 
         final MdnsPacket responseForServiceTypeOne = createMdnsPacket(SERVICE_TYPE_1);
-        final int ifIndex = 1;
         runOnHandler(() -> discoveryManager.onResponseReceived(
-                responseForServiceTypeOne, ifIndex, null /* network */));
+                responseForServiceTypeOne, SOCKET_KEY_NULL_NETWORK));
         // Packets for network null are only processed by the ServiceTypeClient for network null
         verify(mockServiceTypeClientType1NullNetwork).processResponse(responseForServiceTypeOne,
-                ifIndex, null /* network */);
+                SOCKET_KEY_NULL_NETWORK.getInterfaceIndex(), SOCKET_KEY_NULL_NETWORK.getNetwork());
         verify(mockServiceTypeClientType1Network1, never()).processResponse(any(), anyInt(), any());
         verify(mockServiceTypeClientType2Network2, never()).processResponse(any(), anyInt(), any());
 
         final MdnsPacket responseForServiceTypeTwo = createMdnsPacket(SERVICE_TYPE_2);
         runOnHandler(() -> discoveryManager.onResponseReceived(
-                responseForServiceTypeTwo, ifIndex, NETWORK_1));
+                responseForServiceTypeTwo, SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(), anyInt(),
-                eq(NETWORK_1));
+                eq(SOCKET_KEY_NETWORK_1.getNetwork()));
         verify(mockServiceTypeClientType1Network1).processResponse(responseForServiceTypeTwo,
-                ifIndex, NETWORK_1);
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
         verify(mockServiceTypeClientType2Network2, never()).processResponse(any(), anyInt(),
-                eq(NETWORK_1));
+                eq(SOCKET_KEY_NETWORK_1.getNetwork()));
 
         final MdnsPacket responseForSubtype =
                 createMdnsPacket("subtype._sub._googlecast._tcp.local");
         runOnHandler(() -> discoveryManager.onResponseReceived(
-                responseForSubtype, ifIndex, NETWORK_2));
-        verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(
-                any(), anyInt(), eq(NETWORK_2));
-        verify(mockServiceTypeClientType1Network1, never()).processResponse(
-                any(), anyInt(), eq(NETWORK_2));
-        verify(mockServiceTypeClientType2Network2).processResponse(
-                responseForSubtype, ifIndex, NETWORK_2);
+                responseForSubtype, SOCKET_KEY_NETWORK_2));
+        verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(), anyInt(),
+                eq(SOCKET_KEY_NETWORK_2.getNetwork()));
+        verify(mockServiceTypeClientType1Network1, never()).processResponse(any(), anyInt(),
+                eq(SOCKET_KEY_NETWORK_2.getNetwork()));
+        verify(mockServiceTypeClientType2Network2).processResponse(responseForSubtype,
+                SOCKET_KEY_NETWORK_2.getInterfaceIndex(), SOCKET_KEY_NETWORK_2.getNetwork());
     }
 
     @Test
@@ -243,55 +246,53 @@
                 MdnsSearchOptions.newBuilder().setNetwork(NETWORK_1).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, network1Options);
-        runOnHandler(() -> callback.onSocketCreated(NETWORK_1));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1Network1).startSendAndReceive(
                 mockListenerOne, network1Options);
 
         // Create a ServiceTypeClient for SERVICE_TYPE_2 and NETWORK_1
         final SocketCreationCallback callback2 = expectSocketCreationCallback(
                 SERVICE_TYPE_2, mockListenerTwo, network1Options);
-        runOnHandler(() -> callback2.onSocketCreated(NETWORK_1));
+        runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType2Network1).startSendAndReceive(
                 mockListenerTwo, network1Options);
 
         // Receive a response, it should be processed on both clients.
         final MdnsPacket response = createMdnsPacket(SERVICE_TYPE_1);
-        final int ifIndex = 1;
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, NETWORK_1));
-        verify(mockServiceTypeClientType1Network1).processResponse(response, ifIndex, NETWORK_1);
-        verify(mockServiceTypeClientType2Network1).processResponse(response, ifIndex, NETWORK_1);
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1));
+        verify(mockServiceTypeClientType1Network1).processResponse(response,
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
+        verify(mockServiceTypeClientType2Network1).processResponse(response,
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
 
         // The first callback receives a notification that the network has been destroyed,
         // mockServiceTypeClientOne1 should send service removed notifications and remove from the
         // list of clients.
-        runOnHandler(() -> callback.onAllSocketsDestroyed(NETWORK_1));
+        runOnHandler(() -> callback.onAllSocketsDestroyed(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1Network1).notifySocketDestroyed();
 
         // Receive a response again, it should be processed only on
         // mockServiceTypeClientType2Network1. Because the mockServiceTypeClientType1Network1 is
         // removed from the list of clients, it is no longer able to process responses.
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, NETWORK_1));
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1));
         // Still times(1) as a response was received once previously
-        verify(mockServiceTypeClientType1Network1, times(1))
-                .processResponse(response, ifIndex, NETWORK_1);
-        verify(mockServiceTypeClientType2Network1, times(2))
-                .processResponse(response, ifIndex, NETWORK_1);
+        verify(mockServiceTypeClientType1Network1, times(1)).processResponse(response,
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
+        verify(mockServiceTypeClientType2Network1, times(2)).processResponse(response,
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
 
         // The client for NETWORK_1 receives the callback that the NETWORK_2 has been destroyed,
         // mockServiceTypeClientTwo2 shouldn't send any notifications.
-        runOnHandler(() -> callback2.onAllSocketsDestroyed(NETWORK_2));
+        runOnHandler(() -> callback2.onAllSocketsDestroyed(SOCKET_KEY_NETWORK_2));
         verify(mockServiceTypeClientType2Network1, never()).notifySocketDestroyed();
 
         // Receive a response again, mockServiceTypeClientType2Network1 is still in the list of
         // clients, it's still able to process responses.
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, NETWORK_1));
-        verify(mockServiceTypeClientType1Network1, times(1))
-                .processResponse(response, ifIndex, NETWORK_1);
-        verify(mockServiceTypeClientType2Network1, times(3))
-                .processResponse(response, ifIndex, NETWORK_1);
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1));
+        verify(mockServiceTypeClientType1Network1, times(1)).processResponse(response,
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
+        verify(mockServiceTypeClientType2Network1, times(3)).processResponse(response,
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
     }
 
     @Test
@@ -301,27 +302,25 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, network1Options);
-        runOnHandler(() -> callback.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(
                 mockListenerOne, network1Options);
 
         // Receive a response, it should be processed on the client.
         final MdnsPacket response = createMdnsPacket(SERVICE_TYPE_1);
         final int ifIndex = 1;
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, null /* network */));
-        verify(mockServiceTypeClientType1NullNetwork).processResponse(
-                response, ifIndex, null /* network */);
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NULL_NETWORK));
+        verify(mockServiceTypeClientType1NullNetwork).processResponse(response,
+                SOCKET_KEY_NULL_NETWORK.getInterfaceIndex(), SOCKET_KEY_NULL_NETWORK.getNetwork());
 
-        runOnHandler(() -> callback.onAllSocketsDestroyed(null /* network */));
+        runOnHandler(() -> callback.onAllSocketsDestroyed(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).notifySocketDestroyed();
 
         // Receive a response again, it should not be processed.
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, null /* network */));
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NULL_NETWORK));
         // Still times(1) as a response was received once previously
-        verify(mockServiceTypeClientType1NullNetwork, times(1))
-                .processResponse(response, ifIndex, null /* network */);
+        verify(mockServiceTypeClientType1NullNetwork, times(1)).processResponse(response,
+                SOCKET_KEY_NULL_NETWORK.getInterfaceIndex(), SOCKET_KEY_NULL_NETWORK.getNetwork());
 
         // Unregister the listener, notifyNetworkUnrequested should be called but other stop methods
         // won't be call because the service type client was unregistered and destroyed. But those
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
index a0a302f..f7ef077 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
@@ -21,7 +21,6 @@
 
 import static org.junit.Assert.assertEquals;
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
@@ -132,11 +131,11 @@
         doReturn(null).when(tetherSocketKey2).getNetwork();
         // Notify socket created
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
         callback.onSocketCreated(tetherSocketKey1, tetherIfaceSock1, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(null);
+        verify(mSocketCreationCallback).onSocketCreated(tetherSocketKey1);
         callback.onSocketCreated(tetherSocketKey2, tetherIfaceSock2, List.of());
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
+        verify(mSocketCreationCallback).onSocketCreated(tetherSocketKey2);
 
         // Send packet to IPv4 with target network and verify sending has been called.
         mSocketClient.sendMulticastPacket(ipv4Packet, mNetwork);
@@ -172,7 +171,7 @@
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
         // Notify socket created
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
 
         final ArgumentCaptor<PacketHandler> handlerCaptor =
                 ArgumentCaptor.forClass(PacketHandler.class);
@@ -183,7 +182,7 @@
         handler.handlePacket(data, data.length, null /* src */);
         final ArgumentCaptor<MdnsPacket> responseCaptor =
                 ArgumentCaptor.forClass(MdnsPacket.class);
-        verify(mCallback).onResponseReceived(responseCaptor.capture(), anyInt(), any());
+        verify(mCallback).onResponseReceived(responseCaptor.capture(), any());
         final MdnsPacket response = responseCaptor.getValue();
         assertEquals(0, response.questions.size());
         assertEquals(0, response.additionalRecords.size());
@@ -222,12 +221,13 @@
         doReturn(createEmptyNetworkInterface()).when(socket3).getInterface();
 
         final SocketKey socketKey2 = mock(SocketKey.class);
-        doReturn(null).when(socketKey2).getNetwork();
+        final SocketKey socketKey3 = mock(SocketKey.class);
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
         callback.onSocketCreated(socketKey2, socket2, List.of());
-        callback.onSocketCreated(socketKey2, socket3, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
+        callback.onSocketCreated(socketKey3, socket3, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
+        verify(mSocketCreationCallback).onSocketCreated(socketKey2);
+        verify(mSocketCreationCallback).onSocketCreated(socketKey3);
 
         // Send IPv4 packet on the non-null Network and verify sending has been called.
         mSocketClient.sendMulticastPacket(ipv4Packet, mNetwork);
@@ -252,9 +252,10 @@
         // Notify socket created for all networks.
         callback2.onSocketCreated(mSocketKey, mSocket, List.of());
         callback2.onSocketCreated(socketKey2, socket2, List.of());
-        callback2.onSocketCreated(socketKey2, socket3, List.of());
-        verify(socketCreationCb2).onSocketCreated(mNetwork);
-        verify(socketCreationCb2, times(2)).onSocketCreated(null);
+        callback2.onSocketCreated(socketKey3, socket3, List.of());
+        verify(socketCreationCb2).onSocketCreated(mSocketKey);
+        verify(socketCreationCb2).onSocketCreated(socketKey2);
+        verify(socketCreationCb2).onSocketCreated(socketKey3);
 
         // Send IPv4 packet to null network and verify sending to the 2 tethered interface sockets.
         mSocketClient.sendMulticastPacket(ipv4Packet, null);
@@ -296,16 +297,16 @@
         doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
 
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
         callback.onSocketCreated(mSocketKey, otherSocket, List.of());
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(mNetwork);
+        verify(mSocketCreationCallback, times(2)).onSocketCreated(mSocketKey);
 
-        verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(mNetwork);
+        verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(mSocketKey);
         mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener));
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
 
         verify(mProvider).unrequestSocket(callback);
-        verify(mSocketCreationCallback).onAllSocketsDestroyed(mNetwork);
+        verify(mSocketCreationCallback).onAllSocketsDestroyed(mSocketKey);
     }
 
     @Test
@@ -316,14 +317,14 @@
         doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
 
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
         callback.onSocketCreated(mSocketKey, otherSocket, List.of());
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(mNetwork);
+        verify(mSocketCreationCallback, times(2)).onSocketCreated(mSocketKey);
 
         // Notify socket destroyed
         callback.onInterfaceDestroyed(mSocketKey, mSocket);
         verifyNoMoreInteractions(mSocketCreationCallback);
         callback.onInterfaceDestroyed(mSocketKey, otherSocket);
-        verify(mSocketCreationCallback).onAllSocketsDestroyed(mNetwork);
+        verify(mSocketCreationCallback).onAllSocketsDestroyed(mSocketKey);
     }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index d1adecf..635a1d4 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -118,6 +118,7 @@
     private FakeExecutor currentThreadExecutor = new FakeExecutor();
 
     private MdnsServiceTypeClient client;
+    private SocketKey socketKey;
 
     @Before
     @SuppressWarnings("DoNotMock")
@@ -128,6 +129,7 @@
         expectedIPv4Packets = new DatagramPacket[16];
         expectedIPv6Packets = new DatagramPacket[16];
         expectedSendFutures = new ScheduledFuture<?>[16];
+        socketKey = new SocketKey(mockNetwork, INTERFACE_INDEX);
 
         for (int i = 0; i < expectedSendFutures.length; ++i) {
             expectedIPv4Packets[i] = new DatagramPacket(buf, 0 /* offset */, 5 /* length */,
@@ -174,7 +176,7 @@
 
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, mockNetwork, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -325,7 +327,7 @@
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build();
         QueryTaskConfig config = new QueryTaskConfig(
-                searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, mockNetwork);
+                searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, socketKey);
 
         // This is the first query. We will ask for unicast response.
         assertTrue(config.expectUnicastResponse);
@@ -354,7 +356,7 @@
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build();
         QueryTaskConfig config = new QueryTaskConfig(
-                searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, mockNetwork);
+                searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, socketKey);
 
         // This is the first query. We will ask for unicast response.
         assertTrue(config.expectUnicastResponse);
@@ -508,9 +510,9 @@
 
         // Process a second response with a different port and updated text attributes.
         client.processResponse(createResponse(
-                "service-instance-1", ipV4Address, 5354,
-                /* subtype= */ "ABCDE",
-                Collections.singletonMap("key", "value"), TEST_TTL),
+                        "service-instance-1", ipV4Address, 5354,
+                        /* subtype= */ "ABCDE",
+                        Collections.singletonMap("key", "value"), TEST_TTL),
                 /* interfaceIndex= */ 20, mockNetwork);
 
         // Verify onServiceNameDiscovered was called once for the initial response.
@@ -563,9 +565,9 @@
 
         // Process a second response with a different port and updated text attributes.
         client.processResponse(createResponse(
-                "service-instance-1", ipV6Address, 5354,
-                /* subtype= */ "ABCDE",
-                Collections.singletonMap("key", "value"), TEST_TTL),
+                        "service-instance-1", ipV6Address, 5354,
+                        /* subtype= */ "ABCDE",
+                        Collections.singletonMap("key", "value"), TEST_TTL),
                 /* interfaceIndex= */ 20, mockNetwork);
 
         // Verify onServiceNameDiscovered was called once for the initial response.
@@ -709,7 +711,7 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, mockNetwork, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -750,7 +752,7 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, mockNetwork, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -783,7 +785,7 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, mockNetwork, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -835,8 +837,8 @@
 
         // Process the last response which is goodbye message (with the main type, not subtype).
         client.processResponse(createResponse(
-                serviceName, ipV6Address, 5354, SERVICE_TYPE_LABELS,
-                Collections.singletonMap("key", "value"), /* ptrTtlMillis= */ 0L),
+                        serviceName, ipV6Address, 5354, SERVICE_TYPE_LABELS,
+                        Collections.singletonMap("key", "value"), /* ptrTtlMillis= */ 0L),
                 INTERFACE_INDEX, mockNetwork);
 
         // Verify onServiceNameDiscovered was first called for the initial response.
@@ -908,7 +910,7 @@
     @Test
     public void testProcessResponse_Resolve() throws Exception {
         client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);
+                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
 
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
@@ -998,7 +1000,7 @@
     @Test
     public void testRenewTxtSrvInResolve() throws Exception {
         client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                mockDecoderClock, mockNetwork, mockSharedLog);
+                mockDecoderClock, socketKey, mockSharedLog);
 
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
@@ -1102,7 +1104,7 @@
     @Test
     public void testProcessResponse_ResolveExcludesOtherServices() {
         client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);
+                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
 
         final String requestedInstance = "instance1";
         final String otherInstance = "instance2";
@@ -1119,13 +1121,13 @@
 
         // Complete response from instanceName
         client.processResponse(createResponse(
-                requestedInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
+                        requestedInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
                 INTERFACE_INDEX, mockNetwork);
 
         // Complete response from otherInstanceName
         client.processResponse(createResponse(
-                otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
+                        otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
                 INTERFACE_INDEX, mockNetwork);
 
@@ -1166,7 +1168,7 @@
     @Test
     public void testProcessResponse_SubtypeDiscoveryLimitedToSubtype() {
         client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);
+                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
 
         final String matchingInstance = "instance1";
         final String subtype = "_subtype";
@@ -1247,7 +1249,7 @@
     @Test
     public void testNotifySocketDestroyed() throws Exception {
         client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);
+                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
 
         final String requestedInstance = "instance1";
         final String otherInstance = "instance2";
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
index abb1747..e30c249 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
@@ -23,6 +23,7 @@
 import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.never;
@@ -370,7 +371,7 @@
         mdnsClient.startDiscovery();
 
         verify(mockCallback, timeout(TIMEOUT).atLeast(1))
-                .onResponseReceived(any(MdnsPacket.class), anyInt(), any());
+                .onResponseReceived(any(MdnsPacket.class), any(SocketKey.class));
     }
 
     @Test
@@ -379,7 +380,7 @@
         mdnsClient.startDiscovery();
 
         verify(mockCallback, timeout(TIMEOUT).atLeastOnce())
-                .onResponseReceived(any(MdnsPacket.class), anyInt(), any());
+                .onResponseReceived(any(MdnsPacket.class), any(SocketKey.class));
 
         mdnsClient.stopDiscovery();
     }
@@ -513,7 +514,7 @@
         mdnsClient.startDiscovery();
 
         verify(mockCallback, timeout(TIMEOUT).atLeastOnce())
-                .onResponseReceived(any(), eq(21), any());
+                .onResponseReceived(any(), argThat(key -> key.getInterfaceIndex() == 21));
     }
 
     @Test
@@ -536,6 +537,7 @@
         mdnsClient.startDiscovery();
 
         verify(mockMulticastSocket, never()).getInterfaceIndex();
-        verify(mockCallback, timeout(TIMEOUT).atLeast(1)).onResponseReceived(any(), eq(-1), any());
+        verify(mockCallback, timeout(TIMEOUT).atLeast(1))
+                .onResponseReceived(any(), argThat(key -> key.getInterfaceIndex() == -1));
     }
 }
\ No newline at end of file