Merge "Never create native network immediately."
diff --git a/Cronet/tests/cts/src/android/net/http/cts/UrlRequestTest.java b/Cronet/tests/cts/src/android/net/http/cts/UrlRequestTest.java
index 07e7d45..3c4d134 100644
--- a/Cronet/tests/cts/src/android/net/http/cts/UrlRequestTest.java
+++ b/Cronet/tests/cts/src/android/net/http/cts/UrlRequestTest.java
@@ -363,6 +363,116 @@
                 .containsAtLeastElementsIn(expectedHeaders);
     }
 
+    @Test
+    public void testUrlRequest_getHttpMethod() throws Exception {
+        UrlRequest.Builder builder = createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        final String method = "POST";
+
+        builder.setHttpMethod(method);
+        UrlRequest request = builder.build();
+        assertThat(request.getHttpMethod()).isEqualTo(method);
+    }
+
+    @Test
+    public void testUrlRequest_getHeaders_asList() throws Exception {
+        UrlRequest.Builder builder = createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        final List<Map.Entry<String, String>> expectedHeaders = Arrays.asList(
+                Map.entry("Authorization", "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="),
+                Map.entry("Max-Forwards", "10"),
+                Map.entry("X-Client-Data", "random custom header content"));
+
+        for (Map.Entry<String, String> header : expectedHeaders) {
+            builder.addHeader(header.getKey(), header.getValue());
+        }
+
+        UrlRequest request = builder.build();
+        assertThat(request.getHeaders().getAsList()).containsAtLeastElementsIn(expectedHeaders);
+    }
+
+    @Test
+    public void testUrlRequest_getHeaders_asMap() throws Exception {
+        UrlRequest.Builder builder = createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        final Map<String, List<String>> expectedHeaders = Map.of(
+                "Authorization", Arrays.asList("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="),
+                "Max-Forwards", Arrays.asList("10"),
+                "X-Client-Data", Arrays.asList("random custom header content"));
+
+        for (Map.Entry<String, List<String>> header : expectedHeaders.entrySet()) {
+            builder.addHeader(header.getKey(), header.getValue().get(0));
+        }
+
+        UrlRequest request = builder.build();
+        assertThat(request.getHeaders().getAsMap()).containsAtLeastEntriesIn(expectedHeaders);
+    }
+
+    @Test
+    public void testUrlRequest_isCacheDisabled() throws Exception {
+        UrlRequest.Builder builder = createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        final boolean isCacheDisabled = true;
+
+        builder.setCacheDisabled(isCacheDisabled);
+        UrlRequest request = builder.build();
+        assertThat(request.isCacheDisabled()).isEqualTo(isCacheDisabled);
+    }
+
+    @Test
+    public void testUrlRequest_isDirectExecutorAllowed() throws Exception {
+        UrlRequest.Builder builder = createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        final boolean isDirectExecutorAllowed = true;
+
+        builder.setDirectExecutorAllowed(isDirectExecutorAllowed);
+        UrlRequest request = builder.build();
+        assertThat(request.isDirectExecutorAllowed()).isEqualTo(isDirectExecutorAllowed);
+    }
+
+    @Test
+    public void testUrlRequest_getPriority() throws Exception {
+        UrlRequest.Builder builder = createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        final int priority = UrlRequest.REQUEST_PRIORITY_LOW;
+
+        builder.setPriority(priority);
+        UrlRequest request = builder.build();
+        assertThat(request.getPriority()).isEqualTo(priority);
+    }
+
+    @Test
+    public void testUrlRequest_hasTrafficStatsTag() throws Exception {
+        UrlRequest.Builder builder = createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        // A request built after a tag was set on the builder must report having one.
+        builder.setTrafficStatsTag(10);
+        UrlRequest request = builder.build();
+        assertThat(request.hasTrafficStatsTag()).isTrue();
+    }
+
+    @Test
+    public void testUrlRequest_getTrafficStatsTag() throws Exception {
+        UrlRequest.Builder builder = createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        final int trafficStatsTag = 10;
+
+        builder.setTrafficStatsTag(trafficStatsTag);
+        UrlRequest request = builder.build();
+        assertThat(request.getTrafficStatsTag()).isEqualTo(trafficStatsTag);
+    }
+
+    @Test
+    public void testUrlRequest_hasTrafficStatsUid() throws Exception {
+        UrlRequest.Builder builder = createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        // A request built after a UID was set on the builder must report having one.
+        builder.setTrafficStatsUid(10);
+        UrlRequest request = builder.build();
+        assertThat(request.hasTrafficStatsUid()).isTrue();
+    }
+
+    @Test
+    public void testUrlRequest_getTrafficStatsUid() throws Exception {
+        UrlRequest.Builder builder = createUrlRequestBuilder(mTestServer.getSuccessUrl());
+        final int trafficStatsUid = 10;
+
+        builder.setTrafficStatsUid(trafficStatsUid);
+        UrlRequest request = builder.build();
+        assertThat(request.getTrafficStatsUid()).isEqualTo(trafficStatsUid);
+    }
+
     private static List<Map.Entry<String, String>> extractEchoedHeaders(HeaderBlock headers) {
         return headers.getAsList()
                 .stream()
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index 7288828..05b1dcf 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -194,7 +194,7 @@
                     }
 
                     @Override
-                    public void onAllSocketsDestroyed(@NonNull SocketKey socketKey) {
+                    public void onSocketDestroyed(@NonNull SocketKey socketKey) {
                         ensureRunningOnHandlerThread(handler);
                         final MdnsServiceTypeClient serviceTypeClient =
                                 perSocketServiceTypeClients.get(serviceType, socketKey);
@@ -284,9 +284,11 @@
     MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType,
             @NonNull SocketKey socketKey) {
         sharedLog.log("createServiceTypeClient for type:" + serviceType + " " + socketKey);
+        final String tag = serviceType + "-" + socketKey.getNetwork()
+                + "/" + socketKey.getInterfaceIndex();
         return new MdnsServiceTypeClient(
                 serviceType, socketClient,
                 executorProvider.newServiceTypeClientSchedulerExecutor(), socketKey,
-                sharedLog.forSubComponent(serviceType + "-" + socketKey));
+                sharedLog.forSubComponent(tag));
     }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
index 18e460e..d1fa57c 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -64,7 +64,7 @@
         @NonNull
         private final SocketCreationCallback mSocketCreationCallback;
         @NonNull
-        private final ArrayMap<MdnsInterfaceSocket, SocketKey> mActiveNetworkSockets =
+        private final ArrayMap<MdnsInterfaceSocket, SocketKey> mActiveSockets =
                 new ArrayMap<>();
 
         InterfaceSocketCallback(SocketCreationCallback socketCreationCallback) {
@@ -83,7 +83,7 @@
                 mSocketPacketHandlers.put(socket, handler);
             }
             socket.addPacketHandler(handler);
-            mActiveNetworkSockets.put(socket, socketKey);
+            mActiveSockets.put(socket, socketKey);
             mSocketCreationCallback.onSocketCreated(socketKey);
         }
 
@@ -95,16 +95,16 @@
         }
 
         private void notifySocketDestroyed(@NonNull MdnsInterfaceSocket socket) {
-            final SocketKey socketKey = mActiveNetworkSockets.remove(socket);
-            if (!isAnySocketActive(socketKey)) {
-                mSocketCreationCallback.onAllSocketsDestroyed(socketKey);
+            final SocketKey socketKey = mActiveSockets.remove(socket);
+            if (!isSocketActive(socket)) {
+                mSocketCreationCallback.onSocketDestroyed(socketKey);
             }
         }
 
         void onNetworkUnrequested() {
-            for (int i = mActiveNetworkSockets.size() - 1; i >= 0; i--) {
+            for (int i = mActiveSockets.size() - 1; i >= 0; i--) {
                 // Iterate from the end so the socket can be removed
-                final MdnsInterfaceSocket socket = mActiveNetworkSockets.keyAt(i);
+                final MdnsInterfaceSocket socket = mActiveSockets.keyAt(i);
                 notifySocketDestroyed(socket);
                 maybeCleanupPacketHandler(socket);
             }
@@ -114,17 +114,7 @@
     private boolean isSocketActive(@NonNull MdnsInterfaceSocket socket) {
         for (int i = 0; i < mRequestedNetworks.size(); i++) {
             final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
-            if (isc.mActiveNetworkSockets.containsKey(socket)) {
-                return true;
-            }
-        }
-        return false;
-    }
-
-    private boolean isAnySocketActive(@NonNull SocketKey socketKey) {
-        for (int i = 0; i < mRequestedNetworks.size(); i++) {
-            final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
-            if (isc.mActiveNetworkSockets.containsValue(socketKey)) {
+            if (isc.mActiveSockets.containsKey(socket)) {
                 return true;
             }
         }
@@ -135,7 +125,7 @@
         final ArrayMap<MdnsInterfaceSocket, SocketKey> sockets = new ArrayMap<>();
         for (int i = 0; i < mRequestedNetworks.size(); i++) {
             final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
-            sockets.putAll(isc.mActiveNetworkSockets);
+            sockets.putAll(isc.mActiveSockets);
         }
         return sockets;
     }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
index 5e4a8b5..b6000f0 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
@@ -82,6 +82,6 @@
         void onSocketCreated(@NonNull SocketKey socketKey);
 
         /*** Notify requested socket is destroyed */
-        void onAllSocketsDestroyed(@NonNull SocketKey socketKey);
+        void onSocketDestroyed(@NonNull SocketKey socketKey);
     }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
index 3df6313..e963ab7 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
@@ -258,11 +258,6 @@
                 @NonNull final NetLinkMonitorCallBack cb) {
             return SocketNetLinkMonitorFactory.createNetLinkMonitor(handler, log, cb);
         }
-
-        /*** Get interface index by given socket */
-        public int getInterfaceIndex(@NonNull MdnsInterfaceSocket socket) {
-            return socket.getInterface().getIndex();
-        }
     }
     /**
      * The callback interface for the netlink monitor messages.
@@ -323,11 +318,14 @@
         final MdnsInterfaceSocket mSocket;
         final List<LinkAddress> mAddresses;
         final int[] mTransports;
+        @NonNull final SocketKey mSocketKey;
 
-        SocketInfo(MdnsInterfaceSocket socket, List<LinkAddress> addresses, int[] transports) {
+        SocketInfo(MdnsInterfaceSocket socket, List<LinkAddress> addresses, int[] transports,
+                @NonNull SocketKey socketKey) {
             mSocket = socket;
             mAddresses = new ArrayList<>(addresses);
             mTransports = transports;
+            mSocketKey = socketKey;
         }
     }
 
@@ -447,7 +445,7 @@
         // Try to join the group again.
         socketInfo.mSocket.joinGroup(addresses);
 
-        notifyAddressesChanged(network, socketInfo.mSocket, addresses);
+        notifyAddressesChanged(network, socketInfo, addresses);
     }
     private LinkProperties createLPForTetheredInterface(@NonNull final String interfaceName,
             int ifaceIndex) {
@@ -528,21 +526,22 @@
                     networkInterface.getNetworkInterface(), MdnsConstants.MDNS_PORT, mLooper,
                     mPacketReadBuffer);
             final List<LinkAddress> addresses = lp.getLinkAddresses();
+            final Network network =
+                    networkKey == LOCAL_NET ? null : ((NetworkAsKey) networkKey).mNetwork;
+            final SocketKey socketKey = new SocketKey(network, networkInterface.getIndex());
             // TODO: technically transport types are mutable, although generally not in ways that
             // would meaningfully impact the logic using it here. Consider updating logic to
             // support transports being added/removed.
-            final SocketInfo socketInfo = new SocketInfo(socket, addresses, transports);
+            final SocketInfo socketInfo = new SocketInfo(socket, addresses, transports, socketKey);
             if (networkKey == LOCAL_NET) {
                 mTetherInterfaceSockets.put(interfaceName, socketInfo);
             } else {
-                mNetworkSockets.put(((NetworkAsKey) networkKey).mNetwork, socketInfo);
+                mNetworkSockets.put(network, socketInfo);
             }
             // Try to join IPv4/IPv6 group.
             socket.joinGroup(addresses);
 
             // Notify the listeners which need this socket.
-            final Network network =
-                    networkKey == LOCAL_NET ? null : ((NetworkAsKey) networkKey).mNetwork;
             notifySocketCreated(network, socketInfo);
         } catch (IOException e) {
             mSharedLog.e("Create socket failed ifName:" + interfaceName, e);
@@ -584,7 +583,7 @@
         if (socketInfo == null) return;
 
         socketInfo.mSocket.destroy();
-        notifyInterfaceDestroyed(network, socketInfo.mSocket);
+        notifyInterfaceDestroyed(network, socketInfo);
         mSocketRequestMonitor.onSocketDestroyed(network, socketInfo.mSocket);
         mSharedLog.log("Remove socket on net:" + network);
     }
@@ -593,7 +592,7 @@
         final SocketInfo socketInfo = mTetherInterfaceSockets.remove(interfaceName);
         if (socketInfo == null) return;
         socketInfo.mSocket.destroy();
-        notifyInterfaceDestroyed(null /* network */, socketInfo.mSocket);
+        notifyInterfaceDestroyed(null /* network */, socketInfo);
         mSocketRequestMonitor.onSocketDestroyed(null /* network */, socketInfo.mSocket);
         mSharedLog.log("Remove socket on ifName:" + interfaceName);
     }
@@ -602,9 +601,7 @@
         for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
             final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
             if (isNetworkMatched(requestedNetwork, network)) {
-                final int ifaceIndex = mDependencies.getInterfaceIndex(socketInfo.mSocket);
-                final SocketKey socketKey = new SocketKey(network, ifaceIndex);
-                mCallbacksToRequestedNetworks.keyAt(i).onSocketCreated(socketKey,
+                mCallbacksToRequestedNetworks.keyAt(i).onSocketCreated(socketInfo.mSocketKey,
                         socketInfo.mSocket, socketInfo.mAddresses);
                 mSocketRequestMonitor.onSocketRequestFulfilled(network, socketInfo.mSocket,
                         socketInfo.mTransports);
@@ -612,25 +609,23 @@
         }
     }
 
-    private void notifyInterfaceDestroyed(Network network, MdnsInterfaceSocket socket) {
+    private void notifyInterfaceDestroyed(Network network, SocketInfo socketInfo) {
         for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
             final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
             if (isNetworkMatched(requestedNetwork, network)) {
-                final int ifaceIndex = mDependencies.getInterfaceIndex(socket);
                 mCallbacksToRequestedNetworks.keyAt(i)
-                        .onInterfaceDestroyed(new SocketKey(network, ifaceIndex), socket);
+                        .onInterfaceDestroyed(socketInfo.mSocketKey, socketInfo.mSocket);
             }
         }
     }
 
-    private void notifyAddressesChanged(Network network, MdnsInterfaceSocket socket,
+    private void notifyAddressesChanged(Network network, SocketInfo socketInfo,
             List<LinkAddress> addresses) {
         for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
             final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
             if (isNetworkMatched(requestedNetwork, network)) {
-                final int ifaceIndex = mDependencies.getInterfaceIndex(socket);
                 mCallbacksToRequestedNetworks.keyAt(i)
-                        .onAddressesChanged(new SocketKey(network, ifaceIndex), socket, addresses);
+                        .onAddressesChanged(socketInfo.mSocketKey, socketInfo.mSocket, addresses);
             }
         }
     }
@@ -647,9 +642,7 @@
             createSocket(new NetworkAsKey(network), lp);
         } else {
             // Notify the socket for requested network.
-            final int ifaceIndex = mDependencies.getInterfaceIndex(socketInfo.mSocket);
-            final SocketKey socketKey = new SocketKey(network, ifaceIndex);
-            cb.onSocketCreated(socketKey, socketInfo.mSocket, socketInfo.mAddresses);
+            cb.onSocketCreated(socketInfo.mSocketKey, socketInfo.mSocket, socketInfo.mAddresses);
             mSocketRequestMonitor.onSocketRequestFulfilled(network, socketInfo.mSocket,
                     socketInfo.mTransports);
         }
@@ -664,9 +657,7 @@
                     createLPForTetheredInterface(interfaceName, ifaceIndex));
         } else {
             // Notify the socket for requested network.
-            final int ifaceIndex = mDependencies.getInterfaceIndex(socketInfo.mSocket);
-            final SocketKey socketKey = new SocketKey(ifaceIndex);
-            cb.onSocketCreated(socketKey, socketInfo.mSocket, socketInfo.mAddresses);
+            cb.onSocketCreated(socketInfo.mSocketKey, socketInfo.mSocket, socketInfo.mAddresses);
             mSocketRequestMonitor.onSocketRequestFulfilled(null /* socketNetwork */,
                     socketInfo.mSocket, socketInfo.mTransports);
         }
diff --git a/service-t/src/com/android/server/connectivity/mdns/NetworkInterfaceWrapper.java b/service-t/src/com/android/server/connectivity/mdns/NetworkInterfaceWrapper.java
index 0ecae48..48c396e 100644
--- a/service-t/src/com/android/server/connectivity/mdns/NetworkInterfaceWrapper.java
+++ b/service-t/src/com/android/server/connectivity/mdns/NetworkInterfaceWrapper.java
@@ -57,6 +57,10 @@
         return networkInterface.getInterfaceAddresses();
     }
 
+    public int getIndex() {
+        return networkInterface.getIndex();
+    }
+
     @Override
     public String toString() {
         return networkInterface.toString();
diff --git a/service-t/src/com/android/server/connectivity/mdns/SocketKey.java b/service-t/src/com/android/server/connectivity/mdns/SocketKey.java
index a893acb..f13d0e0 100644
--- a/service-t/src/com/android/server/connectivity/mdns/SocketKey.java
+++ b/service-t/src/com/android/server/connectivity/mdns/SocketKey.java
@@ -43,6 +43,7 @@
         mInterfaceIndex = interfaceIndex;
     }
 
+    @Nullable
     public Network getNetwork() {
         return mNetwork;
     }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
index c467f45..6a0334f 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
@@ -56,8 +56,8 @@
 private val TEST_ADDR = parseNumericAddress("2001:db8::123")
 private val TEST_LINKADDR = LinkAddress(TEST_ADDR, 64 /* prefixLength */)
 private val TEST_NETWORK_1 = mock(Network::class.java)
-private val TEST_SOCKETKEY_1 = mock(SocketKey::class.java)
-private val TEST_SOCKETKEY_2 = mock(SocketKey::class.java)
+private val TEST_SOCKETKEY_1 = SocketKey(1001 /* interfaceIndex */)
+private val TEST_SOCKETKEY_2 = SocketKey(1002 /* interfaceIndex */)
 private val TEST_HOSTNAME = arrayOf("Android_test", "local")
 private const val TEST_SUBTYPE = "_subtype"
 
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
index 81e1cd4..1a4ae5d 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
@@ -265,7 +265,7 @@
         // The first callback receives a notification that the network has been destroyed,
         // mockServiceTypeClientOne1 should send service removed notifications and remove from the
         // list of clients.
-        runOnHandler(() -> callback.onAllSocketsDestroyed(SOCKET_KEY_NETWORK_1));
+        runOnHandler(() -> callback.onSocketDestroyed(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1Network1).notifySocketDestroyed();
 
         // Receive a response again, it should be processed only on
@@ -280,7 +280,7 @@
 
         // The client for NETWORK_1 receives the callback that the NETWORK_2 has been destroyed,
         // mockServiceTypeClientTwo2 shouldn't send any notifications.
-        runOnHandler(() -> callback2.onAllSocketsDestroyed(SOCKET_KEY_NETWORK_2));
+        runOnHandler(() -> callback2.onSocketDestroyed(SOCKET_KEY_NETWORK_2));
         verify(mockServiceTypeClientType2Network1, never()).notifySocketDestroyed();
 
         // Receive a response again, mockServiceTypeClientType2Network1 is still in the list of
@@ -310,7 +310,7 @@
         verify(mockServiceTypeClientType1NullNetwork).processResponse(
                 response, SOCKET_KEY_NULL_NETWORK);
 
-        runOnHandler(() -> callback.onAllSocketsDestroyed(SOCKET_KEY_NULL_NETWORK));
+        runOnHandler(() -> callback.onSocketDestroyed(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).notifySocketDestroyed();
 
         // Receive a response again, it should not be processed.
@@ -326,7 +326,7 @@
         verify(socketClient).notifyNetworkUnrequested(mockListenerOne);
         verify(mockServiceTypeClientType1NullNetwork, never()).stopSendAndReceive(any());
         // The stopDiscovery() is only used by MdnsSocketClient, which doesn't send
-        // onAllSocketsDestroyed(). So the socket clients that send onAllSocketsDestroyed() do not
+        // onSocketDestroyed(). So the socket clients that send onSocketDestroyed() do not
         // need to call stopDiscovery().
         verify(socketClient, never()).stopDiscovery();
     }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
index 6f75d7e..29de272 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
@@ -28,7 +28,6 @@
 import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyNoMoreInteractions;
 
 import android.net.InetAddresses;
 import android.net.Network;
@@ -67,9 +66,9 @@
     @Mock private MdnsServiceBrowserListener mListener;
     @Mock private MdnsSocketClientBase.Callback mCallback;
     @Mock private SocketCreationCallback mSocketCreationCallback;
-    @Mock private SocketKey mSocketKey;
     private MdnsMultinetworkSocketClient mSocketClient;
     private Handler mHandler;
+    private SocketKey mSocketKey;
 
     @Before
     public void setUp() throws SocketException {
@@ -78,6 +77,7 @@
         final HandlerThread thread = new HandlerThread("MdnsMultinetworkSocketClientTest");
         thread.start();
         mHandler = new Handler(thread.getLooper());
+        mSocketKey = new SocketKey(1000 /* interfaceIndex */);
         mSocketClient = new MdnsMultinetworkSocketClient(thread.getLooper(), mProvider);
         mHandler.post(() -> mSocketClient.setCallback(mCallback));
     }
@@ -124,8 +124,8 @@
             doReturn(createEmptyNetworkInterface()).when(socket).getInterface();
         }
 
-        final SocketKey tetherSocketKey1 = mock(SocketKey.class);
-        final SocketKey tetherSocketKey2 = mock(SocketKey.class);
+        final SocketKey tetherSocketKey1 = new SocketKey(1001 /* interfaceIndex */);
+        final SocketKey tetherSocketKey2 = new SocketKey(1002 /* interfaceIndex */);
         // Notify socket created
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
         verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
@@ -237,8 +237,8 @@
         doReturn(createEmptyNetworkInterface()).when(socket2).getInterface();
         doReturn(createEmptyNetworkInterface()).when(socket3).getInterface();
 
-        final SocketKey socketKey2 = mock(SocketKey.class);
-        final SocketKey socketKey3 = mock(SocketKey.class);
+        final SocketKey socketKey2 = new SocketKey(1001 /* interfaceIndex */);
+        final SocketKey socketKey3 = new SocketKey(1002 /* interfaceIndex */);
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
         callback.onSocketCreated(socketKey2, socket2, List.of());
         callback.onSocketCreated(socketKey3, socket3, List.of());
@@ -312,6 +312,7 @@
     @Test
     public void testNotifyNetworkUnrequested_SocketsOnNullNetwork() {
         final MdnsInterfaceSocket otherSocket = mock(MdnsInterfaceSocket.class);
+        final SocketKey otherSocketKey = new SocketKey(1001 /* interfaceIndex */);
         final SocketCallback callback = expectSocketCallback(
                 mListener, null /* requestedNetwork */);
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
@@ -319,33 +320,36 @@
 
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
         verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
-        callback.onSocketCreated(mSocketKey, otherSocket, List.of());
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(mSocketKey);
+        callback.onSocketCreated(otherSocketKey, otherSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(otherSocketKey);
 
-        verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(mSocketKey);
+        verify(mSocketCreationCallback, never()).onSocketDestroyed(mSocketKey);
+        verify(mSocketCreationCallback, never()).onSocketDestroyed(otherSocketKey);
         mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener));
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
 
         verify(mProvider).unrequestSocket(callback);
-        verify(mSocketCreationCallback).onAllSocketsDestroyed(mSocketKey);
+        verify(mSocketCreationCallback).onSocketDestroyed(mSocketKey);
+        verify(mSocketCreationCallback).onSocketDestroyed(otherSocketKey);
     }
 
     @Test
     public void testSocketCreatedAndDestroyed_NullNetwork() throws IOException {
         final MdnsInterfaceSocket otherSocket = mock(MdnsInterfaceSocket.class);
+        final SocketKey otherSocketKey = new SocketKey(1001 /* interfaceIndex */);
         final SocketCallback callback = expectSocketCallback(mListener, null /* network */);
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
         doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
 
         callback.onSocketCreated(mSocketKey, mSocket, List.of());
         verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
-        callback.onSocketCreated(mSocketKey, otherSocket, List.of());
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(mSocketKey);
+        callback.onSocketCreated(otherSocketKey, otherSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(otherSocketKey);
 
         // Notify socket destroyed
         callback.onInterfaceDestroyed(mSocketKey, mSocket);
-        verifyNoMoreInteractions(mSocketCreationCallback);
-        callback.onInterfaceDestroyed(mSocketKey, otherSocket);
-        verify(mSocketCreationCallback).onAllSocketsDestroyed(mSocketKey);
+        verify(mSocketCreationCallback).onSocketDestroyed(mSocketKey);
+        callback.onInterfaceDestroyed(otherSocketKey, otherSocket);
+        verify(mSocketCreationCallback).onSocketDestroyed(otherSocketKey);
     }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
index 0eac5ec..e971de7 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
@@ -157,7 +157,6 @@
                 TETHERED_IFACE_NAME);
         doReturn(789).when(mDeps).getNetworkInterfaceIndexByName(
                 WIFI_P2P_IFACE_NAME);
-        doReturn(TETHERED_IFACE_IDX).when(mDeps).getInterfaceIndex(any());
         final HandlerThread thread = new HandlerThread("MdnsSocketProviderTest");
         thread.start();
         mHandler = new Handler(thread.getLooper());