Merge "Notify socket changes using a SocketKey"
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
index 5f27b6a..158d7a3 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
@@ -287,7 +287,7 @@
         }
 
         @Override
-        public void onSocketCreated(@NonNull Network network,
+        public void onSocketCreated(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket,
                 @NonNull List<LinkAddress> addresses) {
             MdnsInterfaceAdvertiser advertiser = mAllAdvertisers.get(socket);
@@ -311,14 +311,14 @@
         }
 
         @Override
-        public void onInterfaceDestroyed(@NonNull Network network,
+        public void onInterfaceDestroyed(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket) {
             final MdnsInterfaceAdvertiser advertiser = mAdvertisers.get(socket);
             if (advertiser != null) advertiser.destroyNow();
         }
 
         @Override
-        public void onAddressesChanged(@NonNull Network network,
+        public void onAddressesChanged(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {
             final MdnsInterfaceAdvertiser advertiser = mAdvertisers.get(socket);
             if (advertiser != null) advertiser.updateAddresses(addresses);
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
index 73e4497..03be681 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -64,7 +64,7 @@
         @NonNull
         private final SocketCreationCallback mSocketCreationCallback;
         @NonNull
-        private final ArrayMap<MdnsInterfaceSocket, Network> mActiveNetworkSockets =
+        private final ArrayMap<MdnsInterfaceSocket, SocketKey> mActiveNetworkSockets =
                 new ArrayMap<>();
 
         InterfaceSocketCallback(SocketCreationCallback socketCreationCallback) {
@@ -72,32 +72,32 @@
         }
 
         @Override
-        public void onSocketCreated(@Nullable Network network,
+        public void onSocketCreated(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {
             // The socket may be already created by other request before, try to get the stored
             // ReadPacketHandler.
             ReadPacketHandler handler = mSocketPacketHandlers.get(socket);
             if (handler == null) {
                 // First request to create this socket. Initial a ReadPacketHandler for this socket.
-                handler = new ReadPacketHandler(network, socket.getInterface().getIndex());
+                handler = new ReadPacketHandler(socketKey);
                 mSocketPacketHandlers.put(socket, handler);
             }
             socket.addPacketHandler(handler);
-            mActiveNetworkSockets.put(socket, network);
-            mSocketCreationCallback.onSocketCreated(network);
+            mActiveNetworkSockets.put(socket, socketKey);
+            mSocketCreationCallback.onSocketCreated(socketKey.getNetwork());
         }
 
         @Override
-        public void onInterfaceDestroyed(@Nullable Network network,
+        public void onInterfaceDestroyed(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket) {
             notifySocketDestroyed(socket);
             maybeCleanupPacketHandler(socket);
         }
 
         private void notifySocketDestroyed(@NonNull MdnsInterfaceSocket socket) {
-            final Network network = mActiveNetworkSockets.remove(socket);
-            if (!isAnySocketActive(network)) {
-                mSocketCreationCallback.onAllSocketsDestroyed(network);
+            final SocketKey socketKey = mActiveNetworkSockets.remove(socket);
+            if (!isAnySocketActive(socketKey)) {
+                mSocketCreationCallback.onAllSocketsDestroyed(socketKey.getNetwork());
             }
         }
 
@@ -121,18 +121,18 @@
         return false;
     }
 
-    private boolean isAnySocketActive(@Nullable Network network) {
+    private boolean isAnySocketActive(@NonNull SocketKey socketKey) {
         for (int i = 0; i < mRequestedNetworks.size(); i++) {
             final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
-            if (isc.mActiveNetworkSockets.containsValue(network)) {
+            if (isc.mActiveNetworkSockets.containsValue(socketKey)) {
                 return true;
             }
         }
         return false;
     }
 
-    private ArrayMap<MdnsInterfaceSocket, Network> getActiveSockets() {
-        final ArrayMap<MdnsInterfaceSocket, Network> sockets = new ArrayMap<>();
+    private ArrayMap<MdnsInterfaceSocket, SocketKey> getActiveSockets() {
+        final ArrayMap<MdnsInterfaceSocket, SocketKey> sockets = new ArrayMap<>();
         for (int i = 0; i < mRequestedNetworks.size(); i++) {
             final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
             sockets.putAll(isc.mActiveNetworkSockets);
@@ -146,17 +146,15 @@
     }
 
     private class ReadPacketHandler implements MulticastPacketReader.PacketHandler {
-        private final Network mNetwork;
-        private final int mInterfaceIndex;
+        @NonNull private final SocketKey mSocketKey;
 
-        ReadPacketHandler(@NonNull Network network, int interfaceIndex) {
-            mNetwork = network;
-            mInterfaceIndex = interfaceIndex;
+        ReadPacketHandler(@NonNull SocketKey socketKey) {
+            mSocketKey = socketKey;
         }
 
         @Override
         public void handlePacket(byte[] recvbuf, int length, InetSocketAddress src) {
-            processResponsePacket(recvbuf, length, mInterfaceIndex, mNetwork);
+            processResponsePacket(recvbuf, length, mSocketKey);
         }
     }
 
@@ -220,10 +218,10 @@
                 instanceof Inet6Address;
         final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
                 instanceof Inet4Address;
-        final ArrayMap<MdnsInterfaceSocket, Network> activeSockets = getActiveSockets();
+        final ArrayMap<MdnsInterfaceSocket, SocketKey> activeSockets = getActiveSockets();
         for (int i = 0; i < activeSockets.size(); i++) {
             final MdnsInterfaceSocket socket = activeSockets.keyAt(i);
-            final Network network = activeSockets.valueAt(i);
+            final Network network = activeSockets.valueAt(i).getNetwork();
             // Check ip capability and network before sending packet
             if (((isIpv6 && socket.hasJoinedIpv6()) || (isIpv4 && socket.hasJoinedIpv4()))
                     // Contrary to MdnsUtils.isNetworkMatched, only send packets targeting
@@ -239,8 +237,7 @@
         }
     }
 
-    private void processResponsePacket(byte[] recvbuf, int length, int interfaceIndex,
-            @NonNull Network network) {
+    private void processResponsePacket(byte[] recvbuf, int length, @NonNull SocketKey socketKey) {
         int packetNumber = ++mReceivedPacketNumber;
 
         final MdnsPacket response;
@@ -250,14 +247,16 @@
             if (e.code != MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE) {
                 Log.e(TAG, e.getMessage(), e);
                 if (mCallback != null) {
-                    mCallback.onFailedToParseMdnsResponse(packetNumber, e.code, network);
+                    mCallback.onFailedToParseMdnsResponse(
+                            packetNumber, e.code, socketKey.getNetwork());
                 }
             }
             return;
         }
 
         if (mCallback != null) {
-            mCallback.onResponseReceived(response, interfaceIndex, network);
+            mCallback.onResponseReceived(
+                    response, socketKey.getInterfaceIndex(), socketKey.getNetwork());
         }
     }
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
index d90f67f..3df6313 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
@@ -258,6 +258,11 @@
                 @NonNull final NetLinkMonitorCallBack cb) {
             return SocketNetLinkMonitorFactory.createNetLinkMonitor(handler, log, cb);
         }
+
+        /** Returns the index of the network interface reported by the given socket. */
+        public int getInterfaceIndex(@NonNull MdnsInterfaceSocket socket) {
+            return socket.getInterface().getIndex();
+        }
     }
     /**
      * The callback interface for the netlink monitor messages.
@@ -597,8 +602,10 @@
         for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
             final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
             if (isNetworkMatched(requestedNetwork, network)) {
-                mCallbacksToRequestedNetworks.keyAt(i).onSocketCreated(network, socketInfo.mSocket,
-                        socketInfo.mAddresses);
+                final int ifaceIndex = mDependencies.getInterfaceIndex(socketInfo.mSocket);
+                final SocketKey socketKey = new SocketKey(network, ifaceIndex);
+                mCallbacksToRequestedNetworks.keyAt(i).onSocketCreated(socketKey,
+                        socketInfo.mSocket, socketInfo.mAddresses);
                 mSocketRequestMonitor.onSocketRequestFulfilled(network, socketInfo.mSocket,
                         socketInfo.mTransports);
             }
@@ -609,7 +616,9 @@
         for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
             final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
             if (isNetworkMatched(requestedNetwork, network)) {
-                mCallbacksToRequestedNetworks.keyAt(i).onInterfaceDestroyed(network, socket);
+                final int ifaceIndex = mDependencies.getInterfaceIndex(socket);
+                mCallbacksToRequestedNetworks.keyAt(i)
+                        .onInterfaceDestroyed(new SocketKey(network, ifaceIndex), socket);
             }
         }
     }
@@ -619,8 +628,9 @@
         for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
             final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
             if (isNetworkMatched(requestedNetwork, network)) {
+                final int ifaceIndex = mDependencies.getInterfaceIndex(socket);
                 mCallbacksToRequestedNetworks.keyAt(i)
-                        .onAddressesChanged(network, socket, addresses);
+                        .onAddressesChanged(new SocketKey(network, ifaceIndex), socket, addresses);
             }
         }
     }
@@ -637,7 +647,9 @@
             createSocket(new NetworkAsKey(network), lp);
         } else {
             // Notify the socket for requested network.
-            cb.onSocketCreated(network, socketInfo.mSocket, socketInfo.mAddresses);
+            final int ifaceIndex = mDependencies.getInterfaceIndex(socketInfo.mSocket);
+            final SocketKey socketKey = new SocketKey(network, ifaceIndex);
+            cb.onSocketCreated(socketKey, socketInfo.mSocket, socketInfo.mAddresses);
             mSocketRequestMonitor.onSocketRequestFulfilled(network, socketInfo.mSocket,
                     socketInfo.mTransports);
         }
@@ -652,8 +664,9 @@
                     createLPForTetheredInterface(interfaceName, ifaceIndex));
         } else {
             // Notify the socket for requested network.
-            cb.onSocketCreated(
-                    null /* network */, socketInfo.mSocket, socketInfo.mAddresses);
+            final int ifaceIndex = mDependencies.getInterfaceIndex(socketInfo.mSocket);
+            final SocketKey socketKey = new SocketKey(ifaceIndex);
+            cb.onSocketCreated(socketKey, socketInfo.mSocket, socketInfo.mAddresses);
             mSocketRequestMonitor.onSocketRequestFulfilled(null /* socketNetwork */,
                     socketInfo.mSocket, socketInfo.mTransports);
         }
@@ -741,21 +754,21 @@
          * This may be called immediately when the request is registered with an existing socket,
          * if it had been created previously for other requests.
          */
-        default void onSocketCreated(@Nullable Network network, @NonNull MdnsInterfaceSocket socket,
-                @NonNull List<LinkAddress> addresses) {}
+        default void onSocketCreated(@NonNull SocketKey socketKey,
+                @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {}
 
         /**
          * Notify that the interface was destroyed, so the provided socket cannot be used anymore.
          *
          * This indicates that although the socket was still requested, it had to be destroyed.
          */
-        default void onInterfaceDestroyed(@Nullable Network network,
+        default void onInterfaceDestroyed(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket) {}
 
         /**
          * Notify the interface addresses have changed for the network.
          */
-        default void onAddressesChanged(@Nullable Network network,
+        default void onAddressesChanged(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {}
     }
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/SocketKey.java b/service-t/src/com/android/server/connectivity/mdns/SocketKey.java
new file mode 100644
index 0000000..a893acb
--- /dev/null
+++ b/service-t/src/com/android/server/connectivity/mdns/SocketKey.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+import android.annotation.Nullable;
+import android.net.Network;
+
+import java.util.Objects;
+
+/**
+ * Identifies an mDNS socket by its network and interface index.
+ *
+ * <p>A socket is usually created for a {@link Network}, but sockets on tethering interfaces have
+ * no associated network, only an interface index, so the network alone cannot tell such sockets
+ * apart. This key therefore pairs the (nullable) network with the interface index; equality and
+ * hashing take both fields into account.
+ */
+public class SocketKey {
+    @Nullable
+    private final Network mNetwork;
+    private final int mInterfaceIndex;
+
+    SocketKey(int interfaceIndex) {
+        this(null /* network */, interfaceIndex);
+    }
+
+    SocketKey(@Nullable Network network, int interfaceIndex) {
+        mNetwork = network;
+        mInterfaceIndex = interfaceIndex;
+    }
+
+    public Network getNetwork() { // May be null: the one-argument constructor stores no network.
+        return mNetwork;
+    }
+
+    public int getInterfaceIndex() {
+        return mInterfaceIndex;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(mNetwork, mInterfaceIndex);
+    }
+
+    @Override
+    public boolean equals(@Nullable Object other) {
+        if (!(other instanceof SocketKey)) {
+            return false;
+        }
+        return Objects.equals(mNetwork, ((SocketKey) other).mNetwork)
+                && mInterfaceIndex == ((SocketKey) other).mInterfaceIndex;
+    }
+
+    @Override
+    public String toString() {
+        return "SocketKey{ network=" + mNetwork + " interfaceIndex=" + mInterfaceIndex + " }";
+    }
+}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
index d9acc61..c467f45 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
@@ -56,7 +56,8 @@
 private val TEST_ADDR = parseNumericAddress("2001:db8::123")
 private val TEST_LINKADDR = LinkAddress(TEST_ADDR, 64 /* prefixLength */)
 private val TEST_NETWORK_1 = mock(Network::class.java)
-private val TEST_NETWORK_2 = mock(Network::class.java)
+private val TEST_SOCKETKEY_1 = mock(SocketKey::class.java)
+private val TEST_SOCKETKEY_2 = mock(SocketKey::class.java)
 private val TEST_HOSTNAME = arrayOf("Android_test", "local")
 private const val TEST_SUBTYPE = "_subtype"
 
@@ -145,7 +146,7 @@
         verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture())
 
         val socketCb = socketCbCaptor.value
-        postSync { socketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) }
+        postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR)) }
 
         val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
         verify(mockDeps).makeAdvertiser(
@@ -163,7 +164,7 @@
                 mockInterfaceAdvertiser1, SERVICE_ID_1) }
         verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_1), argThat { it.matches(SERVICE_1) })
 
-        postSync { socketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
+        postSync { socketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
         verify(mockInterfaceAdvertiser1).destroyNow()
     }
 
@@ -177,8 +178,8 @@
                 socketCbCaptor.capture())
 
         val socketCb = socketCbCaptor.value
-        postSync { socketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) }
-        postSync { socketCb.onSocketCreated(TEST_NETWORK_2, mockSocket2, listOf(TEST_LINKADDR)) }
+        postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR)) }
+        postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_2, mockSocket2, listOf(TEST_LINKADDR)) }
 
         val intAdvCbCaptor1 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
         val intAdvCbCaptor2 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
@@ -241,8 +242,8 @@
 
         // Callbacks for matching network and all networks both get the socket
         postSync {
-            oneNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
-            allNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
+            oneNetSocketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR))
+            allNetSocketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR))
         }
 
         val expectedRenamed = NsdServiceInfo(
@@ -294,8 +295,8 @@
         verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_2),
                 argThat { it.matches(expectedRenamed) })
 
-        postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
-        postSync { allNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
+        postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
+        postSync { allNetSocketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
 
         // destroyNow can be called multiple times
         verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow()
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
index 87ba5d7..a0a302f 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
@@ -68,12 +68,15 @@
     @Mock private MdnsServiceBrowserListener mListener;
     @Mock private MdnsSocketClientBase.Callback mCallback;
     @Mock private SocketCreationCallback mSocketCreationCallback;
+    @Mock private SocketKey mSocketKey;
     private MdnsMultinetworkSocketClient mSocketClient;
     private Handler mHandler;
 
     @Before
     public void setUp() throws SocketException {
         MockitoAnnotations.initMocks(this);
+        doReturn(mNetwork).when(mSocketKey).getNetwork();
+
         final HandlerThread thread = new HandlerThread("MdnsMultinetworkSocketClientTest");
         thread.start();
         mHandler = new Handler(thread.getLooper());
@@ -123,12 +126,16 @@
             doReturn(createEmptyNetworkInterface()).when(socket).getInterface();
         }
 
+        final SocketKey tetherSocketKey1 = mock(SocketKey.class);
+        final SocketKey tetherSocketKey2 = mock(SocketKey.class);
+        doReturn(null).when(tetherSocketKey1).getNetwork();
+        doReturn(null).when(tetherSocketKey2).getNetwork();
         // Notify socket created
-        callback.onSocketCreated(mNetwork, mSocket, List.of());
+        callback.onSocketCreated(mSocketKey, mSocket, List.of());
         verify(mSocketCreationCallback).onSocketCreated(mNetwork);
-        callback.onSocketCreated(null, tetherIfaceSock1, List.of());
+        callback.onSocketCreated(tetherSocketKey1, tetherIfaceSock1, List.of());
         verify(mSocketCreationCallback).onSocketCreated(null);
-        callback.onSocketCreated(null, tetherIfaceSock2, List.of());
+        callback.onSocketCreated(tetherSocketKey2, tetherIfaceSock2, List.of());
         verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
 
         // Send packet to IPv4 with target network and verify sending has been called.
@@ -164,7 +171,7 @@
 
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
         // Notify socket created
-        callback.onSocketCreated(mNetwork, mSocket, List.of());
+        callback.onSocketCreated(mSocketKey, mSocket, List.of());
         verify(mSocketCreationCallback).onSocketCreated(mNetwork);
 
         final ArgumentCaptor<PacketHandler> handlerCaptor =
@@ -214,9 +221,11 @@
         doReturn(createEmptyNetworkInterface()).when(socket2).getInterface();
         doReturn(createEmptyNetworkInterface()).when(socket3).getInterface();
 
-        callback.onSocketCreated(mNetwork, mSocket, List.of());
-        callback.onSocketCreated(null, socket2, List.of());
-        callback.onSocketCreated(null, socket3, List.of());
+        final SocketKey socketKey2 = mock(SocketKey.class);
+        doReturn(null).when(socketKey2).getNetwork();
+        callback.onSocketCreated(mSocketKey, mSocket, List.of());
+        callback.onSocketCreated(socketKey2, socket2, List.of());
+        callback.onSocketCreated(socketKey2, socket3, List.of());
         verify(mSocketCreationCallback).onSocketCreated(mNetwork);
         verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
 
@@ -241,9 +250,9 @@
         final SocketCallback callback2 = callback2Captor.getAllValues().get(1);
 
         // Notify socket created for all networks.
-        callback2.onSocketCreated(mNetwork, mSocket, List.of());
-        callback2.onSocketCreated(null, socket2, List.of());
-        callback2.onSocketCreated(null, socket3, List.of());
+        callback2.onSocketCreated(mSocketKey, mSocket, List.of());
+        callback2.onSocketCreated(socketKey2, socket2, List.of());
+        callback2.onSocketCreated(socketKey2, socket3, List.of());
         verify(socketCreationCb2).onSocketCreated(mNetwork);
         verify(socketCreationCb2, times(2)).onSocketCreated(null);
 
@@ -286,17 +295,17 @@
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
         doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
 
-        callback.onSocketCreated(null /* network */, mSocket, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(null);
-        callback.onSocketCreated(null /* network */, otherSocket, List.of());
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
+        callback.onSocketCreated(mSocketKey, mSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
+        callback.onSocketCreated(mSocketKey, otherSocket, List.of());
+        verify(mSocketCreationCallback, times(2)).onSocketCreated(mNetwork);
 
-        verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(null /* network */);
+        verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(mNetwork);
         mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener));
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
 
         verify(mProvider).unrequestSocket(callback);
-        verify(mSocketCreationCallback).onAllSocketsDestroyed(null /* network */);
+        verify(mSocketCreationCallback).onAllSocketsDestroyed(mNetwork);
     }
 
     @Test
@@ -306,15 +315,15 @@
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
         doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
 
-        callback.onSocketCreated(null /* network */, mSocket, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(null);
-        callback.onSocketCreated(null /* network */, otherSocket, List.of());
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
+        callback.onSocketCreated(mSocketKey, mSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
+        callback.onSocketCreated(mSocketKey, otherSocket, List.of());
+        verify(mSocketCreationCallback, times(2)).onSocketCreated(mNetwork);
 
         // Notify socket destroyed
-        callback.onInterfaceDestroyed(null /* network */, mSocket);
+        callback.onInterfaceDestroyed(mSocketKey, mSocket);
         verifyNoMoreInteractions(mSocketCreationCallback);
-        callback.onInterfaceDestroyed(null /* network */, otherSocket);
-        verify(mSocketCreationCallback).onAllSocketsDestroyed(null /* network */);
+        callback.onInterfaceDestroyed(mSocketKey, otherSocket);
+        verify(mSocketCreationCallback).onAllSocketsDestroyed(mNetwork);
     }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
index 4ef64cb..0eac5ec 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
@@ -157,6 +157,7 @@
                 TETHERED_IFACE_NAME);
         doReturn(789).when(mDeps).getNetworkInterfaceIndexByName(
                 WIFI_P2P_IFACE_NAME);
+        doReturn(TETHERED_IFACE_IDX).when(mDeps).getInterfaceIndex(any());
         final HandlerThread thread = new HandlerThread("MdnsSocketProviderTest");
         thread.start();
         mHandler = new Handler(thread.getLooper());
@@ -227,30 +228,30 @@
 
     private class TestSocketCallback implements MdnsSocketProvider.SocketCallback {
         private class SocketEvent {
-            public final Network mNetwork;
+            public final SocketKey mSocketKey;
             public final List<LinkAddress> mAddresses;
 
-            SocketEvent(Network network, List<LinkAddress> addresses) {
-                mNetwork = network;
+            SocketEvent(SocketKey socketKey, List<LinkAddress> addresses) {
+                mSocketKey = socketKey;
                 mAddresses = Collections.unmodifiableList(addresses);
             }
         }
 
         private class SocketCreatedEvent extends SocketEvent {
-            SocketCreatedEvent(Network nw, List<LinkAddress> addresses) {
-                super(nw, addresses);
+            SocketCreatedEvent(SocketKey socketKey, List<LinkAddress> addresses) {
+                super(socketKey, addresses);
             }
         }
 
         private class InterfaceDestroyedEvent extends SocketEvent {
-            InterfaceDestroyedEvent(Network nw, List<LinkAddress> addresses) {
-                super(nw, addresses);
+            InterfaceDestroyedEvent(SocketKey socketKey, List<LinkAddress> addresses) {
+                super(socketKey, addresses);
             }
         }
 
         private class AddressesChangedEvent extends SocketEvent {
-            AddressesChangedEvent(Network nw, List<LinkAddress> addresses) {
-                super(nw, addresses);
+            AddressesChangedEvent(SocketKey socketKey, List<LinkAddress> addresses) {
+                super(socketKey, addresses);
             }
         }
 
@@ -258,27 +259,27 @@
                 new ArrayTrackRecord<SocketEvent>().newReadHead();
 
         @Override
-        public void onSocketCreated(Network network, MdnsInterfaceSocket socket,
+        public void onSocketCreated(SocketKey socketKey, MdnsInterfaceSocket socket,
                 List<LinkAddress> addresses) {
-            mHistory.add(new SocketCreatedEvent(network, addresses));
+            mHistory.add(new SocketCreatedEvent(socketKey, addresses));
         }
 
         @Override
-        public void onInterfaceDestroyed(Network network, MdnsInterfaceSocket socket) {
-            mHistory.add(new InterfaceDestroyedEvent(network, List.of()));
+        public void onInterfaceDestroyed(SocketKey socketKey, MdnsInterfaceSocket socket) {
+            mHistory.add(new InterfaceDestroyedEvent(socketKey, List.of()));
         }
 
         @Override
-        public void onAddressesChanged(Network network, MdnsInterfaceSocket socket,
+        public void onAddressesChanged(SocketKey socketKey, MdnsInterfaceSocket socket,
                 List<LinkAddress> addresses) {
-            mHistory.add(new AddressesChangedEvent(network, addresses));
+            mHistory.add(new AddressesChangedEvent(socketKey, addresses));
         }
 
         public void expectedSocketCreatedForNetwork(Network network, List<LinkAddress> addresses) {
             final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
             assertNotNull(event);
             assertTrue(event instanceof SocketCreatedEvent);
-            assertEquals(network, event.mNetwork);
+            assertEquals(network, event.mSocketKey.getNetwork());
             assertEquals(addresses, event.mAddresses);
         }
 
@@ -286,7 +287,7 @@
             final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
             assertNotNull(event);
             assertTrue(event instanceof InterfaceDestroyedEvent);
-            assertEquals(network, event.mNetwork);
+            assertEquals(network, event.mSocketKey.getNetwork());
         }
 
         public void expectedAddressesChangedForNetwork(Network network,
@@ -294,7 +295,7 @@
             final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
             assertNotNull(event);
             assertTrue(event instanceof AddressesChangedEvent);
-            assertEquals(network, event.mNetwork);
+            assertEquals(network, event.mSocketKey.getNetwork());
             assertEquals(event.mAddresses, addresses);
         }