Send packets on the specific socket

In the previous design, the network could be null for some
MdnsServiceTypeClient (tethering interfaces). However, the
MdnsMultinetworkSocketClient is designed to send packets on
specific networks. This means that if there are multiple
tethering interfaces, packets will be sent to unnecessary
sockets. Now, the MdnsServiceTypeClient is created by the socket,
which can identify which socket is the target for sending
packets. Therefore, the design should be updated to send packets
only on the specific socket.

Bug: 278018903
Test: atest FrameworksNetTests android.net.cts.NsdManagerTest
Change-Id: Id87e6d202599c57620281a6761d3c56acd2c929c
diff --git a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
index 2d5bb00..bd4ec20 100644
--- a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
+++ b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
@@ -18,7 +18,6 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
-import android.net.Network;
 import android.text.TextUtils;
 import android.util.Log;
 import android.util.Pair;
@@ -70,8 +69,8 @@
     private final List<String> subtypes;
     private final boolean expectUnicastResponse;
     private final int transactionId;
-    @Nullable
-    private final Network network;
+    @NonNull
+    private final SocketKey socketKey;
     private final boolean sendDiscoveryQueries;
     @NonNull
     private final List<MdnsResponse> servicesToResolve;
@@ -86,7 +85,7 @@
             @NonNull Collection<String> subtypes,
             boolean expectUnicastResponse,
             int transactionId,
-            @Nullable Network network,
+            @NonNull SocketKey socketKey,
             boolean onlyUseIpv6OnIpv6OnlyNetworks,
             boolean sendDiscoveryQueries,
             @NonNull Collection<MdnsResponse> servicesToResolve,
@@ -97,7 +96,7 @@
         this.subtypes = new ArrayList<>(subtypes);
         this.expectUnicastResponse = expectUnicastResponse;
         this.transactionId = transactionId;
-        this.network = network;
+        this.socketKey = socketKey;
         this.onlyUseIpv6OnIpv6OnlyNetworks = onlyUseIpv6OnIpv6OnlyNetworks;
         this.sendDiscoveryQueries = sendDiscoveryQueries;
         this.servicesToResolve = new ArrayList<>(servicesToResolve);
@@ -216,7 +215,7 @@
         if (expectUnicastResponse) {
             if (requestSender instanceof MdnsMultinetworkSocketClient) {
                 ((MdnsMultinetworkSocketClient) requestSender).sendPacketRequestingUnicastResponse(
-                        packet, network, onlyUseIpv6OnIpv6OnlyNetworks);
+                        packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
             } else {
                 requestSender.sendPacketRequestingUnicastResponse(
                         packet, onlyUseIpv6OnIpv6OnlyNetworks);
@@ -225,7 +224,7 @@
             if (requestSender instanceof MdnsMultinetworkSocketClient) {
                 ((MdnsMultinetworkSocketClient) requestSender)
                         .sendPacketRequestingMulticastResponse(
-                                packet, network, onlyUseIpv6OnIpv6OnlyNetworks);
+                                packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
             } else {
                 requestSender.sendPacketRequestingMulticastResponse(
                         packet, onlyUseIpv6OnIpv6OnlyNetworks);
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsConstants.java b/service-t/src/com/android/server/connectivity/mdns/MdnsConstants.java
index f0e1717..ce5f540 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsConstants.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsConstants.java
@@ -16,6 +16,8 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.internal.annotations.VisibleForTesting.Visibility.PACKAGE;
+
 import static java.nio.charset.StandardCharsets.UTF_8;
 
 import com.android.internal.annotations.VisibleForTesting;
@@ -25,7 +27,7 @@
 import java.nio.charset.Charset;
 
 /** mDNS-related constants. */
-@VisibleForTesting
+@VisibleForTesting(visibility = PACKAGE)
 public final class MdnsConstants {
     public static final int MDNS_PORT = 5353;
     // Flags word format is:
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index afad3b7..7288828 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -254,8 +254,7 @@
     private void handleOnResponseReceived(@NonNull MdnsPacket packet,
             @NonNull SocketKey socketKey) {
         for (MdnsServiceTypeClient serviceTypeClient : getMdnsServiceTypeClient(socketKey)) {
-            serviceTypeClient.processResponse(
-                    packet, socketKey.getInterfaceIndex(), socketKey.getNetwork());
+            serviceTypeClient.processResponse(packet, socketKey);
         }
     }
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
index 1253444..18e460e 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -213,25 +213,22 @@
         return true;
     }
 
-    private void sendMdnsPacket(@NonNull DatagramPacket packet, @Nullable Network targetNetwork,
+    private void sendMdnsPacket(@NonNull DatagramPacket packet, @NonNull SocketKey targetSocketKey,
             boolean onlyUseIpv6OnIpv6OnlyNetworks) {
         final boolean isIpv6 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
                 instanceof Inet6Address;
         final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
                 instanceof Inet4Address;
         final ArrayMap<MdnsInterfaceSocket, SocketKey> activeSockets = getActiveSockets();
-        boolean shouldQueryIpv6 = !onlyUseIpv6OnIpv6OnlyNetworks || isIpv6OnlyNetworks(
-                activeSockets, targetNetwork);
+        boolean shouldQueryIpv6 = !onlyUseIpv6OnIpv6OnlyNetworks || isIpv6OnlySockets(
+                activeSockets, targetSocketKey);
         for (int i = 0; i < activeSockets.size(); i++) {
             final MdnsInterfaceSocket socket = activeSockets.keyAt(i);
-            final Network network = activeSockets.valueAt(i).getNetwork();
+            final SocketKey socketKey = activeSockets.valueAt(i);
             // Check ip capability and network before sending packet
             if (((isIpv6 && socket.hasJoinedIpv6() && shouldQueryIpv6)
                     || (isIpv4 && socket.hasJoinedIpv4()))
-                    // Contrary to MdnsUtils.isNetworkMatched, only send packets targeting
-                    // the null network to interfaces that have the null network (tethering
-                    // downstream interfaces).
-                    && Objects.equals(network, targetNetwork)) {
+                    && Objects.equals(socketKey, targetSocketKey)) {
                 try {
                     socket.send(packet);
                 } catch (IOException e) {
@@ -241,13 +238,13 @@
         }
     }
 
-    private boolean isIpv6OnlyNetworks(
+    private boolean isIpv6OnlySockets(
             @NonNull ArrayMap<MdnsInterfaceSocket, SocketKey> activeSockets,
-            @Nullable Network targetNetwork) {
+            @NonNull SocketKey targetSocketKey) {
         for (int i = 0; i < activeSockets.size(); i++) {
             final MdnsInterfaceSocket socket = activeSockets.keyAt(i);
-            final Network network = activeSockets.valueAt(i).getNetwork();
-            if (Objects.equals(network, targetNetwork) && socket.hasJoinedIpv4()) {
+            final SocketKey socketKey = activeSockets.valueAt(i);
+            if (Objects.equals(socketKey, targetSocketKey) && socket.hasJoinedIpv4()) {
                 return false;
             }
         }
@@ -276,38 +273,35 @@
     }
 
     /**
-     * Send a mDNS request packet via given network that asks for multicast response.
-     *
-     * <p>The socket client may use a null network to identify some or all interfaces, in which case
-     * passing null sends the packet to these.
+     * Send a mDNS request packet via given socket key that asks for multicast response.
      */
     public void sendPacketRequestingMulticastResponse(@NonNull DatagramPacket packet,
-            @Nullable Network network, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
-        mHandler.post(() -> sendMdnsPacket(packet, network, onlyUseIpv6OnIpv6OnlyNetworks));
+            @NonNull SocketKey socketKey, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+        mHandler.post(() -> sendMdnsPacket(packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks));
     }
 
     @Override
     public void sendPacketRequestingMulticastResponse(
             @NonNull DatagramPacket packet, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
-        sendPacketRequestingMulticastResponse(
-                packet, null /* network */, onlyUseIpv6OnIpv6OnlyNetworks);
+        throw new UnsupportedOperationException("This socket client need to specify the socket to"
+                + "send packet");
     }
 
     /**
-     * Send a mDNS request packet via given network that asks for unicast response.
+     * Send a mDNS request packet via given socket key that asks for unicast response.
      *
      * <p>The socket client may use a null network to identify some or all interfaces, in which case
      * passing null sends the packet to these.
      */
     public void sendPacketRequestingUnicastResponse(@NonNull DatagramPacket packet,
-            @Nullable Network network, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
-        mHandler.post(() -> sendMdnsPacket(packet, network, onlyUseIpv6OnIpv6OnlyNetworks));
+            @NonNull SocketKey socketKey, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+        mHandler.post(() -> sendMdnsPacket(packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks));
     }
 
     @Override
     public void sendPacketRequestingUnicastResponse(
             @NonNull DatagramPacket packet, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
-        sendPacketRequestingUnicastResponse(
-                packet, null /* network */, onlyUseIpv6OnIpv6OnlyNetworks);
+        throw new UnsupportedOperationException("This socket client need to specify the socket to"
+                + "send packet");
     }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index a36eb1b..48e4724 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -20,7 +20,6 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
-import android.net.Network;
 import android.text.TextUtils;
 import android.util.ArrayMap;
 import android.util.ArraySet;
@@ -262,15 +261,13 @@
     /**
      * Process an incoming response packet.
      */
-    public synchronized void processResponse(@NonNull MdnsPacket packet, int interfaceIndex,
-            Network network) {
+    public synchronized void processResponse(@NonNull MdnsPacket packet,
+            @NonNull SocketKey socketKey) {
         synchronized (lock) {
             // Augment the list of current known responses, and generated responses for resolve
             // requests if there is no known response
             final List<MdnsResponse> currentList = new ArrayList<>(instanceNameToResponse.values());
-
-            List<MdnsResponse> additionalResponses = makeResponsesForResolve(interfaceIndex,
-                    network);
+            List<MdnsResponse> additionalResponses = makeResponsesForResolve(socketKey);
             for (MdnsResponse additionalResponse : additionalResponses) {
                 if (!instanceNameToResponse.containsKey(
                         additionalResponse.getServiceInstanceName())) {
@@ -278,7 +275,8 @@
                 }
             }
             final Pair<ArraySet<MdnsResponse>, ArrayList<MdnsResponse>> augmentedResult =
-                    responseDecoder.augmentResponses(packet, currentList, interfaceIndex, network);
+                    responseDecoder.augmentResponses(packet, currentList,
+                            socketKey.getInterfaceIndex(), socketKey.getNetwork());
 
             final ArraySet<MdnsResponse> modifiedResponse = augmentedResult.first;
             final ArrayList<MdnsResponse> allResponses = augmentedResult.second;
@@ -508,8 +506,7 @@
         }
     }
 
-    private List<MdnsResponse> makeResponsesForResolve(int interfaceIndex,
-            @NonNull Network network) {
+    private List<MdnsResponse> makeResponsesForResolve(@NonNull SocketKey socketKey) {
         final List<MdnsResponse> resolveResponses = new ArrayList<>();
         for (int i = 0; i < listeners.size(); i++) {
             final String resolveName = listeners.valueAt(i).getResolveInstanceName();
@@ -524,7 +521,7 @@
                 instanceFullName.addAll(Arrays.asList(serviceTypeLabels));
                 knownResponse = new MdnsResponse(
                         0L /* lastUpdateTime */, instanceFullName.toArray(new String[0]),
-                        interfaceIndex, network);
+                        socketKey.getInterfaceIndex(), socketKey.getNetwork());
             }
             resolveResponses.add(knownResponse);
         }
@@ -548,10 +545,7 @@
                 // The listener is requesting to resolve a service that has no info in
                 // cache. Use the provided name to generate a minimal response, so other records are
                 // queried to complete it.
-                // Only the names are used to know which queries to send, other parameters like
-                // interfaceIndex do not matter.
-                servicesToResolve = makeResponsesForResolve(
-                        0 /* interfaceIndex */, config.socketKey.getNetwork());
+                servicesToResolve = makeResponsesForResolve(config.socketKey);
                 sendDiscoveryQueries = servicesToResolve.size() < listeners.size();
             }
             Pair<Integer, List<String>> result;
@@ -564,7 +558,7 @@
                                 config.subtypes,
                                 config.expectUnicastResponse,
                                 config.transactionId,
-                                config.socketKey.getNetwork(),
+                                config.socketKey,
                                 config.onlyUseIpv6OnIpv6OnlyNetworks,
                                 sendDiscoveryQueries,
                                 servicesToResolve,
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
index d2298fe..81e1cd4 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
@@ -19,7 +19,6 @@
 import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
 
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.never;
@@ -212,31 +211,31 @@
         runOnHandler(() -> discoveryManager.onResponseReceived(
                 responseForServiceTypeOne, SOCKET_KEY_NULL_NETWORK));
         // Packets for network null are only processed by the ServiceTypeClient for network null
-        verify(mockServiceTypeClientType1NullNetwork).processResponse(responseForServiceTypeOne,
-                SOCKET_KEY_NULL_NETWORK.getInterfaceIndex(), SOCKET_KEY_NULL_NETWORK.getNetwork());
-        verify(mockServiceTypeClientType1Network1, never()).processResponse(any(), anyInt(), any());
-        verify(mockServiceTypeClientType2Network2, never()).processResponse(any(), anyInt(), any());
+        verify(mockServiceTypeClientType1NullNetwork).processResponse(
+                responseForServiceTypeOne, SOCKET_KEY_NULL_NETWORK);
+        verify(mockServiceTypeClientType1Network1, never()).processResponse(any(), any());
+        verify(mockServiceTypeClientType2Network2, never()).processResponse(any(), any());
 
         final MdnsPacket responseForServiceTypeTwo = createMdnsPacket(SERVICE_TYPE_2);
         runOnHandler(() -> discoveryManager.onResponseReceived(
                 responseForServiceTypeTwo, SOCKET_KEY_NETWORK_1));
-        verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(), anyInt(),
-                eq(SOCKET_KEY_NETWORK_1.getNetwork()));
-        verify(mockServiceTypeClientType1Network1).processResponse(responseForServiceTypeTwo,
-                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
-        verify(mockServiceTypeClientType2Network2, never()).processResponse(any(), anyInt(),
-                eq(SOCKET_KEY_NETWORK_1.getNetwork()));
+        verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(),
+                eq(SOCKET_KEY_NETWORK_1));
+        verify(mockServiceTypeClientType1Network1).processResponse(
+                responseForServiceTypeTwo, SOCKET_KEY_NETWORK_1);
+        verify(mockServiceTypeClientType2Network2, never()).processResponse(any(),
+                eq(SOCKET_KEY_NETWORK_1));
 
         final MdnsPacket responseForSubtype =
                 createMdnsPacket("subtype._sub._googlecast._tcp.local");
         runOnHandler(() -> discoveryManager.onResponseReceived(
                 responseForSubtype, SOCKET_KEY_NETWORK_2));
-        verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(), anyInt(),
-                eq(SOCKET_KEY_NETWORK_2.getNetwork()));
-        verify(mockServiceTypeClientType1Network1, never()).processResponse(any(), anyInt(),
-                eq(SOCKET_KEY_NETWORK_2.getNetwork()));
-        verify(mockServiceTypeClientType2Network2).processResponse(responseForSubtype,
-                SOCKET_KEY_NETWORK_2.getInterfaceIndex(), SOCKET_KEY_NETWORK_2.getNetwork());
+        verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(),
+                eq(SOCKET_KEY_NETWORK_2));
+        verify(mockServiceTypeClientType1Network1, never()).processResponse(any(),
+                eq(SOCKET_KEY_NETWORK_2));
+        verify(mockServiceTypeClientType2Network2).processResponse(
+                responseForSubtype, SOCKET_KEY_NETWORK_2);
     }
 
     @Test
@@ -260,10 +259,8 @@
         // Receive a response, it should be processed on both clients.
         final MdnsPacket response = createMdnsPacket(SERVICE_TYPE_1);
         runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1));
-        verify(mockServiceTypeClientType1Network1).processResponse(response,
-                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
-        verify(mockServiceTypeClientType2Network1).processResponse(response,
-                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
+        verify(mockServiceTypeClientType1Network1).processResponse(response, SOCKET_KEY_NETWORK_1);
+        verify(mockServiceTypeClientType2Network1).processResponse(response, SOCKET_KEY_NETWORK_1);
 
         // The first callback receives a notification that the network has been destroyed,
         // mockServiceTypeClientOne1 should send service removed notifications and remove from the
@@ -276,10 +273,10 @@
         // removed from the list of clients, it is no longer able to process responses.
         runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1));
         // Still times(1) as a response was received once previously
-        verify(mockServiceTypeClientType1Network1, times(1)).processResponse(response,
-                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
-        verify(mockServiceTypeClientType2Network1, times(2)).processResponse(response,
-                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
+        verify(mockServiceTypeClientType1Network1, times(1)).processResponse(
+                response, SOCKET_KEY_NETWORK_1);
+        verify(mockServiceTypeClientType2Network1, times(2)).processResponse(
+                response, SOCKET_KEY_NETWORK_1);
 
         // The client for NETWORK_1 receives the callback that the NETWORK_2 has been destroyed,
         // mockServiceTypeClientTwo2 shouldn't send any notifications.
@@ -289,10 +286,10 @@
         // Receive a response again, mockServiceTypeClientType2Network1 is still in the list of
         // clients, it's still able to process responses.
         runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1));
-        verify(mockServiceTypeClientType1Network1, times(1)).processResponse(response,
-                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
-        verify(mockServiceTypeClientType2Network1, times(3)).processResponse(response,
-                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
+        verify(mockServiceTypeClientType1Network1, times(1)).processResponse(
+                response, SOCKET_KEY_NETWORK_1);
+        verify(mockServiceTypeClientType2Network1, times(3)).processResponse(
+                response, SOCKET_KEY_NETWORK_1);
     }
 
     @Test
@@ -310,8 +307,8 @@
         final MdnsPacket response = createMdnsPacket(SERVICE_TYPE_1);
         final int ifIndex = 1;
         runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NULL_NETWORK));
-        verify(mockServiceTypeClientType1NullNetwork).processResponse(response,
-                SOCKET_KEY_NULL_NETWORK.getInterfaceIndex(), SOCKET_KEY_NULL_NETWORK.getNetwork());
+        verify(mockServiceTypeClientType1NullNetwork).processResponse(
+                response, SOCKET_KEY_NULL_NETWORK);
 
         runOnHandler(() -> callback.onAllSocketsDestroyed(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).notifySocketDestroyed();
@@ -319,8 +316,8 @@
         // Receive a response again, it should not be processed.
         runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NULL_NETWORK));
         // Still times(1) as a response was received once previously
-        verify(mockServiceTypeClientType1NullNetwork, times(1)).processResponse(response,
-                SOCKET_KEY_NULL_NETWORK.getInterfaceIndex(), SOCKET_KEY_NULL_NETWORK.getNetwork());
+        verify(mockServiceTypeClientType1NullNetwork, times(1)).processResponse(
+                response, SOCKET_KEY_NULL_NETWORK);
 
         // Unregister the listener, notifyNetworkUnrequested should be called but other stop methods
         // won't be call because the service type client was unregistered and destroyed. But those
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
index b812fa6..6f75d7e 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
@@ -74,7 +74,6 @@
     @Before
     public void setUp() throws SocketException {
         MockitoAnnotations.initMocks(this);
-        doReturn(mNetwork).when(mSocketKey).getNetwork();
 
         final HandlerThread thread = new HandlerThread("MdnsMultinetworkSocketClientTest");
         thread.start();
@@ -127,8 +126,6 @@
 
         final SocketKey tetherSocketKey1 = mock(SocketKey.class);
         final SocketKey tetherSocketKey2 = mock(SocketKey.class);
-        doReturn(null).when(tetherSocketKey1).getNetwork();
-        doReturn(null).when(tetherSocketKey2).getNetwork();
         // Notify socket created
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
         verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
@@ -137,8 +134,8 @@
         callback.onSocketCreated(tetherSocketKey2, tetherIfaceSock2, List.of());
         verify(mSocketCreationCallback).onSocketCreated(tetherSocketKey2);
 
-        // Send packet to IPv4 with target network and verify sending has been called.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mNetwork,
+        // Send packet to IPv4 with mSocketKey and verify sending has been called.
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket).send(ipv4Packet);
@@ -146,30 +143,30 @@
         verify(tetherIfaceSock2, never()).send(any());
 
         // Send packet to IPv4 with onlyUseIpv6OnIpv6OnlyNetworks = true, the packet will be sent.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mNetwork,
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
                 true /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket, times(2)).send(ipv4Packet);
         verify(tetherIfaceSock1, never()).send(any());
         verify(tetherIfaceSock2, never()).send(any());
 
-        // Send packet to IPv6 without target network and verify sending has been called.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv6Packet, null,
+        // Send packet to IPv6 with tetherSocketKey1 and verify sending has been called.
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv6Packet, tetherSocketKey1,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket, never()).send(ipv6Packet);
         verify(tetherIfaceSock1).send(ipv6Packet);
-        verify(tetherIfaceSock2).send(ipv6Packet);
+        verify(tetherIfaceSock2, never()).send(ipv6Packet);
 
         // Send packet to IPv6 with onlyUseIpv6OnIpv6OnlyNetworks = true, the packet will not be
         // sent. Therefore, the tetherIfaceSock1.send() and tetherIfaceSock2.send() are still be
         // called once.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv6Packet, null,
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv6Packet, tetherSocketKey1,
                 true /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket, never()).send(ipv6Packet);
         verify(tetherIfaceSock1, times(1)).send(ipv6Packet);
-        verify(tetherIfaceSock2, times(1)).send(ipv6Packet);
+        verify(tetherIfaceSock2, never()).send(ipv6Packet);
     }
 
     @Test
@@ -249,8 +246,8 @@
         verify(mSocketCreationCallback).onSocketCreated(socketKey2);
         verify(mSocketCreationCallback).onSocketCreated(socketKey3);
 
-        // Send IPv4 packet on the non-null Network and verify sending has been called.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mNetwork,
+        // Send IPv4 packet on the mSocketKey and verify sending has been called.
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket).send(ipv4Packet);
@@ -278,38 +275,38 @@
         verify(socketCreationCb2).onSocketCreated(socketKey2);
         verify(socketCreationCb2).onSocketCreated(socketKey3);
 
-        // Send IPv4 packet to null network and verify sending to the 2 tethered interface sockets.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, null,
+        // Send IPv4 packet on socket2 and verify sending to the socket2 only.
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, socketKey2,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         // ipv4Packet still sent only once on mSocket: times(1) matches the packet sent earlier on
         // mNetwork
         verify(mSocket, times(1)).send(ipv4Packet);
         verify(socket2).send(ipv4Packet);
-        verify(socket3).send(ipv4Packet);
+        verify(socket3, never()).send(ipv4Packet);
 
         // Unregister the second request
         mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(listener2));
         verify(mProvider, timeout(DEFAULT_TIMEOUT)).unrequestSocket(callback2);
 
         // Send IPv4 packet again and verify it's still sent a second time
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, null,
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, socketKey2,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(socket2, times(2)).send(ipv4Packet);
-        verify(socket3, times(2)).send(ipv4Packet);
+        verify(socket3, never()).send(ipv4Packet);
 
         // Unrequest remaining sockets
         mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener));
         verify(mProvider, timeout(DEFAULT_TIMEOUT)).unrequestSocket(callback);
 
         // Send IPv4 packet and verify no more sending.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, null,
+        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket, times(1)).send(ipv4Packet);
         verify(socket2, times(2)).send(ipv4Packet);
-        verify(socket3, times(2)).send(ipv4Packet);
+        verify(socket3, never()).send(ipv4Packet);
     }
 
     @Test
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index 9892e9f..03e893f 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -438,7 +438,7 @@
         client.processResponse(createResponse(
                 "service-instance-1", "192.0.2.123", 5353,
                 SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), TEST_TTL), /* interfaceIndex= */ 20, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         verify(mockListenerOne).onServiceNameDiscovered(any());
         verify(mockListenerOne).onServiceFound(any());
@@ -459,8 +459,7 @@
 
     private static void verifyServiceInfo(MdnsServiceInfo serviceInfo, String serviceName,
             String[] serviceType, List<String> ipv4Addresses, List<String> ipv6Addresses, int port,
-            List<String> subTypes, Map<String, String> attributes, int interfaceIndex,
-            Network network) {
+            List<String> subTypes, Map<String, String> attributes, SocketKey socketKey) {
         assertEquals(serviceName, serviceInfo.getServiceInstanceName());
         assertArrayEquals(serviceType, serviceInfo.getServiceType());
         assertEquals(ipv4Addresses, serviceInfo.getIpv4Addresses());
@@ -471,8 +470,8 @@
             assertTrue(attributes.containsKey(key));
             assertEquals(attributes.get(key), serviceInfo.getAttributeByKey(key));
         }
-        assertEquals(interfaceIndex, serviceInfo.getInterfaceIndex());
-        assertEquals(network, serviceInfo.getNetwork());
+        assertEquals(socketKey.getInterfaceIndex(), serviceInfo.getInterfaceIndex());
+        assertEquals(socketKey.getNetwork(), serviceInfo.getNetwork());
     }
 
     @Test
@@ -482,7 +481,7 @@
         client.processResponse(createResponse(
                 "service-instance-1", null /* host */, 0 /* port */,
                 SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
         verify(mockListenerOne).onServiceNameDiscovered(serviceInfoCaptor.capture());
         verifyServiceInfo(serviceInfoCaptor.getAllValues().get(0),
                 "service-instance-1",
@@ -492,8 +491,7 @@
                 /* port= */ 0,
                 /* subTypes= */ List.of(),
                 Collections.emptyMap(),
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
 
         verify(mockListenerOne, never()).onServiceFound(any(MdnsServiceInfo.class));
         verify(mockListenerOne, never()).onServiceUpdated(any(MdnsServiceInfo.class));
@@ -508,14 +506,14 @@
         client.processResponse(createResponse(
                 "service-instance-1", ipV4Address, 5353,
                 /* subtype= */ "ABCDE",
-                Collections.emptyMap(), TEST_TTL), /* interfaceIndex= */ 20, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Process a second response with a different port and updated text attributes.
         client.processResponse(createResponse(
                         "service-instance-1", ipV4Address, 5354,
                         /* subtype= */ "ABCDE",
                         Collections.singletonMap("key", "value"), TEST_TTL),
-                /* interfaceIndex= */ 20, mockNetwork);
+                socketKey);
 
         // Verify onServiceNameDiscovered was called once for the initial response.
         verify(mockListenerOne).onServiceNameDiscovered(serviceInfoCaptor.capture());
@@ -527,8 +525,7 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                20 /* interfaceIndex */,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceFound was called once for the initial response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -538,8 +535,8 @@
         assertEquals(initialServiceInfo.getPort(), 5353);
         assertEquals(initialServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertNull(initialServiceInfo.getAttributeByKey("key"));
-        assertEquals(initialServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(mockNetwork, initialServiceInfo.getNetwork());
+        assertEquals(socketKey.getInterfaceIndex(), initialServiceInfo.getInterfaceIndex());
+        assertEquals(socketKey.getNetwork(), initialServiceInfo.getNetwork());
 
         // Verify onServiceUpdated was called once for the second response.
         verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -550,8 +547,8 @@
         assertTrue(updatedServiceInfo.hasSubtypes());
         assertEquals(updatedServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertEquals(updatedServiceInfo.getAttributeByKey("key"), "value");
-        assertEquals(updatedServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(mockNetwork, updatedServiceInfo.getNetwork());
+        assertEquals(socketKey.getInterfaceIndex(), updatedServiceInfo.getInterfaceIndex());
+        assertEquals(socketKey.getNetwork(), updatedServiceInfo.getNetwork());
     }
 
     @Test
@@ -563,14 +560,14 @@
         client.processResponse(createResponse(
                 "service-instance-1", ipV6Address, 5353,
                 /* subtype= */ "ABCDE",
-                Collections.emptyMap(), TEST_TTL), /* interfaceIndex= */ 20, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Process a second response with a different port and updated text attributes.
         client.processResponse(createResponse(
                         "service-instance-1", ipV6Address, 5354,
                         /* subtype= */ "ABCDE",
                         Collections.singletonMap("key", "value"), TEST_TTL),
-                /* interfaceIndex= */ 20, mockNetwork);
+                socketKey);
 
         // Verify onServiceNameDiscovered was called once for the initial response.
         verify(mockListenerOne).onServiceNameDiscovered(serviceInfoCaptor.capture());
@@ -582,8 +579,7 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                20 /* interfaceIndex */,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceFound was called once for the initial response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -593,8 +589,8 @@
         assertEquals(initialServiceInfo.getPort(), 5353);
         assertEquals(initialServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertNull(initialServiceInfo.getAttributeByKey("key"));
-        assertEquals(initialServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(mockNetwork, initialServiceInfo.getNetwork());
+        assertEquals(socketKey.getInterfaceIndex(), initialServiceInfo.getInterfaceIndex());
+        assertEquals(socketKey.getNetwork(), initialServiceInfo.getNetwork());
 
         // Verify onServiceUpdated was called once for the second response.
         verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -605,8 +601,8 @@
         assertTrue(updatedServiceInfo.hasSubtypes());
         assertEquals(updatedServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertEquals(updatedServiceInfo.getAttributeByKey("key"), "value");
-        assertEquals(updatedServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(mockNetwork, updatedServiceInfo.getNetwork());
+        assertEquals(socketKey.getInterfaceIndex(), updatedServiceInfo.getInterfaceIndex());
+        assertEquals(socketKey.getNetwork(), updatedServiceInfo.getNetwork());
     }
 
     private void verifyServiceRemovedNoCallback(MdnsServiceBrowserListener listener) {
@@ -615,17 +611,17 @@
     }
 
     private void verifyServiceRemovedCallback(MdnsServiceBrowserListener listener,
-            String serviceName, String[] serviceType, int interfaceIndex, Network network) {
+            String serviceName, String[] serviceType, SocketKey socketKey) {
         verify(listener).onServiceRemoved(argThat(
                 info -> serviceName.equals(info.getServiceInstanceName())
                         && Arrays.equals(serviceType, info.getServiceType())
-                        && info.getInterfaceIndex() == interfaceIndex
-                        && network.equals(info.getNetwork())));
+                        && info.getInterfaceIndex() == socketKey.getInterfaceIndex()
+                        && socketKey.getNetwork().equals(info.getNetwork())));
         verify(listener).onServiceNameRemoved(argThat(
                 info -> serviceName.equals(info.getServiceInstanceName())
                         && Arrays.equals(serviceType, info.getServiceType())
-                        && info.getInterfaceIndex() == interfaceIndex
-                        && network.equals(info.getNetwork())));
+                        && info.getInterfaceIndex() == socketKey.getInterfaceIndex()
+                        && socketKey.getNetwork().equals(info.getNetwork())));
     }
 
     @Test
@@ -639,12 +635,12 @@
         client.processResponse(createResponse(
                 serviceName, ipV6Address, 5353,
                 SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         client.processResponse(createResponse(
                 "goodbye-service", ipV6Address, 5353,
                 SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), /* ptrTtlMillis= */ 0L), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), /* ptrTtlMillis= */ 0L), socketKey);
 
         // Verify removed callback won't be called if the service is not existed.
         verifyServiceRemovedNoCallback(mockListenerOne);
@@ -654,11 +650,11 @@
         client.processResponse(createResponse(
                 serviceName, ipV6Address, 5353,
                 SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), 0L), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), 0L), socketKey);
         verifyServiceRemovedCallback(
-                mockListenerOne, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX, mockNetwork);
+                mockListenerOne, serviceName, SERVICE_TYPE_LABELS, socketKey);
         verifyServiceRemovedCallback(
-                mockListenerTwo, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX, mockNetwork);
+                mockListenerTwo, serviceName, SERVICE_TYPE_LABELS, socketKey);
     }
 
     @Test
@@ -667,7 +663,7 @@
         client.processResponse(createResponse(
                 "service-instance-1", "192.168.1.1", 5353,
                 /* subtype= */ "ABCDE",
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         client.startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
 
@@ -681,8 +677,7 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceFound was called once for the existing response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -697,7 +692,7 @@
         client.processResponse(createResponse(
                 "service-instance-1", "192.168.1.1", 5353,
                 SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), /* ptrTtlMillis= */ 0L), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), /* ptrTtlMillis= */ 0L), socketKey);
 
         client.startSendAndReceive(mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
 
@@ -727,7 +722,7 @@
         // Process the initial response.
         client.processResponse(createResponse(
                 serviceInstanceName, "192.168.1.1", 5353, /* subtype= */ "ABCDE",
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Clear the scheduled runnable.
         currentThreadExecutor.getAndClearLastScheduledRunnable();
@@ -744,8 +739,8 @@
         firstMdnsTask.run();
 
         // Verify removed callback was called.
-        verifyServiceRemovedCallback(mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS,
-                INTERFACE_INDEX, mockNetwork);
+        verifyServiceRemovedCallback(
+                mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS, socketKey);
     }
 
     @Test
@@ -766,7 +761,7 @@
         // Process the initial response.
         client.processResponse(createResponse(
                 serviceInstanceName, "192.168.1.1", 5353, /* subtype= */ "ABCDE",
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Clear the scheduled runnable.
         currentThreadExecutor.getAndClearLastScheduledRunnable();
@@ -799,7 +794,7 @@
         // Process the initial response.
         client.processResponse(createResponse(
                 serviceInstanceName, "192.168.1.1", 5353, /* subtype= */ "ABCDE",
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Clear the scheduled runnable.
         currentThreadExecutor.getAndClearLastScheduledRunnable();
@@ -809,8 +804,8 @@
         firstMdnsTask.run();
 
         // Verify removed callback was called.
-        verifyServiceRemovedCallback(mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS,
-                INTERFACE_INDEX, mockNetwork);
+        verifyServiceRemovedCallback(
+                mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS, socketKey);
     }
 
     @Test
@@ -825,23 +820,23 @@
         final String subtype = "ABCDE";
         client.processResponse(createResponse(
                 serviceName, null, 5353, subtype,
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Process a second response which has ip address to make response become complete.
         client.processResponse(createResponse(
                 serviceName, ipV4Address, 5353, subtype,
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Process a third response with a different ip address, port and updated text attributes.
         client.processResponse(createResponse(
                 serviceName, ipV6Address, 5354, subtype,
-                Collections.singletonMap("key", "value"), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.singletonMap("key", "value"), TEST_TTL), socketKey);
 
         // Process the last response which is goodbye message (with the main type, not subtype).
         client.processResponse(createResponse(
                         serviceName, ipV6Address, 5354, SERVICE_TYPE_LABELS,
                         Collections.singletonMap("key", "value"), /* ptrTtlMillis= */ 0L),
-                INTERFACE_INDEX, mockNetwork);
+                socketKey);
 
         // Verify onServiceNameDiscovered was first called for the initial response.
         inOrder.verify(mockListenerOne).onServiceNameDiscovered(serviceInfoCaptor.capture());
@@ -853,8 +848,7 @@
                 5353 /* port */,
                 Collections.singletonList(subtype) /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceFound was second called for the second response.
         inOrder.verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -866,8 +860,7 @@
                 5353 /* port */,
                 Collections.singletonList(subtype) /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceUpdated was third called for the third response.
         inOrder.verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -879,8 +872,7 @@
                 5354 /* port */,
                 Collections.singletonList(subtype) /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceRemoved was called for the last response.
         inOrder.verify(mockListenerOne).onServiceRemoved(serviceInfoCaptor.capture());
@@ -892,8 +884,7 @@
                 5354 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
 
         // Verify onServiceNameRemoved was called for the last response.
         inOrder.verify(mockListenerOne).onServiceNameRemoved(serviceInfoCaptor.capture());
@@ -905,8 +896,7 @@
                 5354 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
     }
 
     @Test
@@ -932,7 +922,7 @@
         // Send twice for IPv4 and IPv6
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
                 srvTxtQueryCaptor.capture(),
-                eq(mockNetwork), eq(false));
+                eq(socketKey), eq(false));
 
         final MdnsPacket srvTxtQueryPacket = MdnsPacket.parse(
                 new MdnsPacketReader(srvTxtQueryCaptor.getValue()));
@@ -955,7 +945,7 @@
                 Collections.emptyList() /* authorityRecords */,
                 Collections.emptyList() /* additionalRecords */);
 
-        client.processResponse(srvTxtResponse, INTERFACE_INDEX, mockNetwork);
+        client.processResponse(srvTxtResponse, socketKey);
 
         // Expect a query for A/AAAA
         final ArgumentCaptor<DatagramPacket> addressQueryCaptor =
@@ -963,7 +953,7 @@
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
                 addressQueryCaptor.capture(),
-                eq(mockNetwork), eq(false));
+                eq(socketKey), eq(false));
 
         final MdnsPacket addressQueryPacket = MdnsPacket.parse(
                 new MdnsPacketReader(addressQueryCaptor.getValue()));
@@ -985,7 +975,7 @@
                 Collections.emptyList() /* additionalRecords */);
 
         inOrder.verify(mockListenerOne, never()).onServiceNameDiscovered(any());
-        client.processResponse(addressResponse, INTERFACE_INDEX, mockNetwork);
+        client.processResponse(addressResponse, socketKey);
 
         inOrder.verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
         verifyServiceInfo(serviceInfoCaptor.getValue(),
@@ -996,8 +986,7 @@
                 1234 /* port */,
                 Collections.emptyList() /* subTypes */,
                 Collections.emptyMap() /* attributes */,
-                INTERFACE_INDEX,
-                mockNetwork);
+                socketKey);
     }
 
     @Test
@@ -1023,7 +1012,7 @@
         // Send twice for IPv4 and IPv6
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
                 srvTxtQueryCaptor.capture(),
-                eq(mockNetwork), eq(false));
+                eq(socketKey), eq(false));
 
         final MdnsPacket srvTxtQueryPacket = MdnsPacket.parse(
                 new MdnsPacketReader(srvTxtQueryCaptor.getValue()));
@@ -1050,7 +1039,7 @@
                                 InetAddresses.parseNumericAddress(ipV6Address))),
                 Collections.emptyList() /* authorityRecords */,
                 Collections.emptyList() /* additionalRecords */);
-        client.processResponse(srvTxtResponse, INTERFACE_INDEX, mockNetwork);
+        client.processResponse(srvTxtResponse, socketKey);
         inOrder.verify(mockListenerOne).onServiceNameDiscovered(any());
         inOrder.verify(mockListenerOne).onServiceFound(any());
 
@@ -1069,7 +1058,7 @@
         // Second and later sends are sent as "expect multicast response" queries
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
                 renewalQueryCaptor.capture(),
-                eq(mockNetwork), eq(false));
+                eq(socketKey), eq(false));
         inOrder.verify(mockListenerOne).onDiscoveryQuerySent(any(), anyInt());
         final MdnsPacket renewalPacket = MdnsPacket.parse(
                 new MdnsPacketReader(renewalQueryCaptor.getValue()));
@@ -1095,7 +1084,7 @@
                                 InetAddresses.parseNumericAddress(ipV6Address))),
                 Collections.emptyList() /* authorityRecords */,
                 Collections.emptyList() /* additionalRecords */);
-        client.processResponse(refreshedSrvTxtResponse, INTERFACE_INDEX, mockNetwork);
+        client.processResponse(refreshedSrvTxtResponse, socketKey);
 
         // Advance time to updatedReceiptTime + 1, expected no refresh query because the cache
         // should contain the record that have update last receipt time.
@@ -1126,23 +1115,23 @@
         client.processResponse(createResponse(
                         requestedInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
-                INTERFACE_INDEX, mockNetwork);
+                socketKey);
 
         // Complete response from otherInstanceName
         client.processResponse(createResponse(
                         otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
-                INTERFACE_INDEX, mockNetwork);
+                socketKey);
 
         // Address update from otherInstanceName
         client.processResponse(createResponse(
                 otherInstance, ipV6Address, 5353, SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Goodbye from otherInstanceName
         client.processResponse(createResponse(
                 otherInstance, ipV6Address, 5353, SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), 0L /* ttl */), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), 0L /* ttl */), socketKey);
 
         // mockListenerOne gets notified for the requested instance
         verify(mockListenerOne).onServiceNameDiscovered(
@@ -1207,23 +1196,23 @@
                 newAnswers,
                 packetWithoutSubtype.authorityRecords,
                 packetWithoutSubtype.additionalRecords);
-        client.processResponse(packetWithSubtype, INTERFACE_INDEX, mockNetwork);
+        client.processResponse(packetWithSubtype, socketKey);
 
         // Complete response from otherInstanceName, without subtype
         client.processResponse(createResponse(
                         otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
-                INTERFACE_INDEX, mockNetwork);
+                socketKey);
 
         // Address update from otherInstanceName
         client.processResponse(createResponse(
                 otherInstance, ipV6Address, 5353, SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), TEST_TTL), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), TEST_TTL), socketKey);
 
         // Goodbye from otherInstanceName
         client.processResponse(createResponse(
                 otherInstance, ipV6Address, 5353, SERVICE_TYPE_LABELS,
-                Collections.emptyMap(), 0L /* ttl */), INTERFACE_INDEX, mockNetwork);
+                Collections.emptyMap(), 0L /* ttl */), socketKey);
 
         // mockListenerOne gets notified for the requested instance
         final ArgumentMatcher<MdnsServiceInfo> subtypeInstanceMatcher = info ->
@@ -1278,13 +1267,13 @@
         client.processResponse(createResponse(
                         requestedInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
-                INTERFACE_INDEX, mockNetwork);
+                socketKey);
 
         // Complete response from otherInstanceName
         client.processResponse(createResponse(
                         otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
-                INTERFACE_INDEX, mockNetwork);
+                socketKey);
 
         verify(expectedSendFutures[1], never()).cancel(true);
         client.notifySocketDestroyed();
@@ -1332,17 +1321,17 @@
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         if (expectsUnicastResponse) {
             verify(mockSocketClient).sendPacketRequestingUnicastResponse(
-                    expectedIPv4Packets[index], mockNetwork, false);
+                    expectedIPv4Packets[index], socketKey, false);
             if (multipleSocketDiscovery) {
                 verify(mockSocketClient).sendPacketRequestingUnicastResponse(
-                        expectedIPv6Packets[index], mockNetwork, false);
+                        expectedIPv6Packets[index], socketKey, false);
             }
         } else {
             verify(mockSocketClient).sendPacketRequestingMulticastResponse(
-                    expectedIPv4Packets[index], mockNetwork, false);
+                    expectedIPv4Packets[index], socketKey, false);
             if (multipleSocketDiscovery) {
                 verify(mockSocketClient).sendPacketRequestingMulticastResponse(
-                        expectedIPv6Packets[index], mockNetwork, false);
+                        expectedIPv6Packets[index], socketKey, false);
             }
         }
     }