Add MdnsMultinetworkSocketClient

Add MdnsMultinetworkSocketClient which is using for managing
multinetwork for discovery and resolution. If the requests are
specified the network to do the discovery or resolution, it
should send the queries and receive the responses on the active
networks only. This can save the resource by reducing unnecessary
queries and align the behavior with mdnsresponder.

Bug: 254166302
Test: atest FramworksNetTests
Change-Id: I9f49ac11e70cb945f9a90efc5eb684be87801286
diff --git a/service/mdns/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java b/service/mdns/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
index f7871f3..fdd1478 100644
--- a/service/mdns/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
+++ b/service/mdns/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
@@ -18,7 +18,9 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.net.Network;
 import android.text.TextUtils;
+import android.util.Log;
 import android.util.Pair;
 
 import com.android.server.connectivity.mdns.util.MdnsLogger;
@@ -58,26 +60,29 @@
         }
     }
 
-    private final WeakReference<MdnsSocketClient> weakRequestSender;
+    private final WeakReference<MdnsSocketClientBase> weakRequestSender;
     private final MdnsPacketWriter packetWriter;
     private final String[] serviceTypeLabels;
     private final List<String> subtypes;
     private final boolean expectUnicastResponse;
     private final int transactionId;
+    private final Network network;
 
     EnqueueMdnsQueryCallable(
-            @NonNull MdnsSocketClient requestSender,
+            @NonNull MdnsSocketClientBase requestSender,
             @NonNull MdnsPacketWriter packetWriter,
             @NonNull String serviceType,
             @NonNull Collection<String> subtypes,
             boolean expectUnicastResponse,
-            int transactionId) {
+            int transactionId,
+            @Nullable Network network) {
         weakRequestSender = new WeakReference<>(requestSender);
         this.packetWriter = packetWriter;
         serviceTypeLabels = TextUtils.split(serviceType, "\\.");
         this.subtypes = new ArrayList<>(subtypes);
         this.expectUnicastResponse = expectUnicastResponse;
         this.transactionId = transactionId;
+        this.network = network;
     }
 
     // Incompatible return type for override of Callable#call().
@@ -86,7 +91,7 @@
     @Nullable
     public Pair<Integer, List<String>> call() {
         try {
-            MdnsSocketClient requestSender = weakRequestSender.get();
+            MdnsSocketClientBase requestSender = weakRequestSender.get();
             if (requestSender == null) {
                 return null;
             }
@@ -127,15 +132,24 @@
                     MdnsConstants.QCLASS_INTERNET
                             | (expectUnicastResponse ? MdnsConstants.QCLASS_UNICAST : 0));
 
-            InetAddress mdnsAddress = MdnsConstants.getMdnsIPv4Address();
-            if (requestSender.isOnIPv6OnlyNetwork()) {
-                mdnsAddress = MdnsConstants.getMdnsIPv6Address();
-            }
+            if (requestSender instanceof MdnsMultinetworkSocketClient) {
+                sendPacketToIpv4AndIpv6(requestSender, MdnsConstants.MDNS_PORT, network);
+                for (Integer emulatorPort : castShellEmulatorMdnsPorts) {
+                    sendPacketToIpv4AndIpv6(requestSender, emulatorPort, network);
+                }
+            } else if (requestSender instanceof MdnsSocketClient) {
+                final MdnsSocketClient client = (MdnsSocketClient) requestSender;
+                InetAddress mdnsAddress = MdnsConstants.getMdnsIPv4Address();
+                if (client.isOnIPv6OnlyNetwork()) {
+                    mdnsAddress = MdnsConstants.getMdnsIPv6Address();
+                }
 
-            sendPacketTo(requestSender,
-                    new InetSocketAddress(mdnsAddress, MdnsConstants.MDNS_PORT));
-            for (Integer emulatorPort : castShellEmulatorMdnsPorts) {
-                sendPacketTo(requestSender, new InetSocketAddress(mdnsAddress, emulatorPort));
+                sendPacketTo(client, new InetSocketAddress(mdnsAddress, MdnsConstants.MDNS_PORT));
+                for (Integer emulatorPort : castShellEmulatorMdnsPorts) {
+                    sendPacketTo(client, new InetSocketAddress(mdnsAddress, emulatorPort));
+                }
+            } else {
+                throw new IOException("Unknown socket client type: " + requestSender.getClass());
             }
             return Pair.create(transactionId, subtypes);
         } catch (IOException e) {
@@ -145,7 +159,7 @@
         }
     }
 
-    private void sendPacketTo(MdnsSocketClient requestSender, InetSocketAddress address)
+    private void sendPacketTo(MdnsSocketClientBase requestSender, InetSocketAddress address)
             throws IOException {
         DatagramPacket packet = packetWriter.getPacket(address);
         if (expectUnicastResponse) {
@@ -154,4 +168,31 @@
             requestSender.sendMulticastPacket(packet);
         }
     }
+
+    private void sendPacketFromNetwork(MdnsSocketClientBase requestSender,
+            InetSocketAddress address, Network network)
+            throws IOException {
+        DatagramPacket packet = packetWriter.getPacket(address);
+        if (expectUnicastResponse) {
+            requestSender.sendUnicastPacket(packet, network);
+        } else {
+            requestSender.sendMulticastPacket(packet, network);
+        }
+    }
+
+    private void sendPacketToIpv4AndIpv6(MdnsSocketClientBase requestSender, int port,
+            Network network) {
+        try {
+            sendPacketFromNetwork(requestSender,
+                    new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), port), network);
+        } catch (IOException e) {
+            Log.i(TAG, "Can't send packet to IPv4", e);
+        }
+        try {
+            sendPacketFromNetwork(requestSender,
+                    new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), port), network);
+        } catch (IOException e) {
+            Log.i(TAG, "Can't send packet to IPv6", e);
+        }
+    }
 }
\ No newline at end of file
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service/mdns/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index 0f3c23a..cc6b98b 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -34,18 +34,18 @@
  * This class keeps tracking the set of registered {@link MdnsServiceBrowserListener} instances, and
  * notify them when a mDNS service instance is found, updated, or removed?
  */
-public class MdnsDiscoveryManager implements MdnsSocketClient.Callback {
+public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback {
     private static final String TAG = MdnsDiscoveryManager.class.getSimpleName();
     public static final boolean DBG = Log.isLoggable(TAG, Log.DEBUG);
     private static final MdnsLogger LOGGER = new MdnsLogger("MdnsDiscoveryManager");
 
     private final ExecutorProvider executorProvider;
-    private final MdnsSocketClient socketClient;
+    private final MdnsSocketClientBase socketClient;
 
     private final Map<String, MdnsServiceTypeClient> serviceTypeClients = new ArrayMap<>();
 
-    public MdnsDiscoveryManager(
-            @NonNull ExecutorProvider executorProvider, @NonNull MdnsSocketClient socketClient) {
+    public MdnsDiscoveryManager(@NonNull ExecutorProvider executorProvider,
+            @NonNull MdnsSocketClientBase socketClient) {
         this.executorProvider = executorProvider;
         this.socketClient = socketClient;
     }
@@ -76,12 +76,16 @@
                 return;
             }
         }
+        // Request the network for discovery.
+        socketClient.notifyNetworkRequested(listener, searchOptions.getNetwork());
+
         // All listeners of the same service types shares the same MdnsServiceTypeClient.
         MdnsServiceTypeClient serviceTypeClient = serviceTypeClients.get(serviceType);
         if (serviceTypeClient == null) {
             serviceTypeClient = createServiceTypeClient(serviceType);
             serviceTypeClients.put(serviceType, serviceTypeClient);
         }
+        // TODO(b/264634275): Wait for a socket to be created before sending packets.
         serviceTypeClient.startSendAndReceive(listener, searchOptions);
     }
 
@@ -96,20 +100,22 @@
     public synchronized void unregisterListener(
             @NonNull String serviceType, @NonNull MdnsServiceBrowserListener listener) {
         LOGGER.log("Unregistering listener for service type: %s", serviceType);
+        if (DBG) Log.d(TAG, "Unregistering listener for serviceType:" + serviceType);
         MdnsServiceTypeClient serviceTypeClient = serviceTypeClients.get(serviceType);
         if (serviceTypeClient == null) {
             return;
         }
         if (serviceTypeClient.stopSendAndReceive(listener)) {
             // No listener is registered for the service type anymore, remove it from the list of
-          // the
-            // service type clients.
+            // the service type clients.
             serviceTypeClients.remove(serviceType);
             if (serviceTypeClients.isEmpty()) {
                 // No discovery request. Stops the socket client.
                 socketClient.stopDiscovery();
             }
         }
+        // Unrequested the network.
+        socketClient.notifyNetworkUnrequested(listener);
     }
 
     @Override
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceSocket.java b/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceSocket.java
index 67c893d..d1290b6 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceSocket.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceSocket.java
@@ -22,10 +22,15 @@
 import android.annotation.NonNull;
 import android.net.LinkAddress;
 import android.net.util.SocketUtils;
+import android.os.Handler;
+import android.os.Looper;
 import android.os.ParcelFileDescriptor;
 import android.system.ErrnoException;
+import android.system.Os;
+import android.system.OsConstants;
 import android.util.Log;
 
+import java.io.FileDescriptor;
 import java.io.IOException;
 import java.net.DatagramPacket;
 import java.net.InetSocketAddress;
@@ -41,15 +46,19 @@
  * otherwise.
  *
  * @see MulticastSocket for javadoc of each public method.
+ * @see MulticastSocket for javadoc of each public method.
  */
 public class MdnsInterfaceSocket {
     private static final String TAG = MdnsInterfaceSocket.class.getSimpleName();
     @NonNull private final MulticastSocket mMulticastSocket;
     @NonNull private final NetworkInterface mNetworkInterface;
+    @NonNull private final MulticastPacketReader mPacketReader;
+    @NonNull private final ParcelFileDescriptor mFileDescriptor;
     private boolean mJoinedIpv4 = false;
     private boolean mJoinedIpv6 = false;
 
-    public MdnsInterfaceSocket(@NonNull NetworkInterface networkInterface, int port)
+    public MdnsInterfaceSocket(@NonNull NetworkInterface networkInterface, int port,
+            @NonNull Looper looper, @NonNull byte[] packetReadBuffer)
             throws IOException {
         mNetworkInterface = networkInterface;
         mMulticastSocket = new MulticastSocket(port);
@@ -58,11 +67,19 @@
         mMulticastSocket.setNetworkInterface(networkInterface);
 
         // Bind socket to the interface for receiving from that interface only.
-        try (ParcelFileDescriptor pfd = ParcelFileDescriptor.fromDatagramSocket(mMulticastSocket)) {
-            SocketUtils.bindSocketToInterface(pfd.getFileDescriptor(), mNetworkInterface.getName());
+        mFileDescriptor = ParcelFileDescriptor.fromDatagramSocket(mMulticastSocket);
+        try {
+            final FileDescriptor fd = mFileDescriptor.getFileDescriptor();
+            final int flags = Os.fcntlInt(fd, OsConstants.F_GETFL, 0);
+            Os.fcntlInt(fd, OsConstants.F_SETFL, flags | OsConstants.SOCK_NONBLOCK);
+            SocketUtils.bindSocketToInterface(fd, mNetworkInterface.getName());
         } catch (ErrnoException e) {
             throw new IOException("Error setting socket options", e);
         }
+
+        mPacketReader = new MulticastPacketReader(networkInterface.getName(), mFileDescriptor,
+                new Handler(looper), packetReadBuffer);
+        mPacketReader.start();
     }
 
     /**
@@ -74,23 +91,14 @@
         mMulticastSocket.send(packet);
     }
 
-    /**
-     * Receives a datagram packet from this socket.
-     *
-     * <p>This method could be used on any thread.
-     */
-    public void receive(@NonNull DatagramPacket packet) throws IOException {
-        mMulticastSocket.receive(packet);
-    }
-
-    private boolean hasIpv4Address(List<LinkAddress> addresses) {
+    private static boolean hasIpv4Address(@NonNull List<LinkAddress> addresses) {
         for (LinkAddress address : addresses) {
             if (address.isIpv4()) return true;
         }
         return false;
     }
 
-    private boolean hasIpv6Address(List<LinkAddress> addresses) {
+    private static boolean hasIpv6Address(@NonNull List<LinkAddress> addresses) {
         for (LinkAddress address : addresses) {
             if (address.isIpv6()) return true;
         }
@@ -103,7 +111,7 @@
         maybeJoinIpv6(addresses);
     }
 
-    private boolean joinGroup(InetSocketAddress multicastAddress) {
+    private boolean joinGroup(@NonNull InetSocketAddress multicastAddress) {
         try {
             mMulticastSocket.joinGroup(multicastAddress, mNetworkInterface);
             return true;
@@ -114,7 +122,7 @@
         }
     }
 
-    private void maybeJoinIpv4(List<LinkAddress> addresses) {
+    private void maybeJoinIpv4(@NonNull List<LinkAddress> addresses) {
         final boolean hasAddr = hasIpv4Address(addresses);
         if (!mJoinedIpv4 && hasAddr) {
             mJoinedIpv4 = joinGroup(MULTICAST_IPV4_ADDRESS);
@@ -124,7 +132,7 @@
         }
     }
 
-    private void maybeJoinIpv6(List<LinkAddress> addresses) {
+    private void maybeJoinIpv6(@NonNull List<LinkAddress> addresses) {
         final boolean hasAddr = hasIpv6Address(addresses);
         if (!mJoinedIpv6 && hasAddr) {
             mJoinedIpv6 = joinGroup(MULTICAST_IPV6_ADDRESS);
@@ -134,26 +142,26 @@
         }
     }
 
-    /*** Destroy this socket by leaving all joined multicast groups and closing this socket. */
+    /*** Destroy the socket */
     public void destroy() {
-        if (mJoinedIpv4) {
-            try {
-                mMulticastSocket.leaveGroup(MULTICAST_IPV4_ADDRESS, mNetworkInterface);
-            } catch (IOException e) {
-                Log.e(TAG, "Error leaving IPv4 group for " + mNetworkInterface, e);
-            }
-        }
-        if (mJoinedIpv6) {
-            try {
-                mMulticastSocket.leaveGroup(MULTICAST_IPV6_ADDRESS, mNetworkInterface);
-            } catch (IOException e) {
-                Log.e(TAG, "Error leaving IPv4 group for " + mNetworkInterface, e);
-            }
+        mPacketReader.stop();
+        try {
+            mFileDescriptor.close();
+        } catch (IOException e) {
+            Log.e(TAG, "Close file descriptor failed.");
         }
         mMulticastSocket.close();
     }
 
     /**
+     * Add a handler to receive callbacks when reads the packet from socket. If the handler is
+     * already set, this is a no-op.
+     */
+    public void addPacketHandler(@NonNull MulticastPacketReader.PacketHandler handler) {
+        mPacketReader.addPacketHandler(handler);
+    }
+
+    /**
      * Returns the network interface that this socket is bound to.
      *
      * <p>This method could be used on any thread.
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service/mdns/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
new file mode 100644
index 0000000..d959065
--- /dev/null
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -0,0 +1,219 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+import static com.android.server.connectivity.mdns.MdnsSocketProvider.ensureRunningOnHandlerThread;
+import static com.android.server.connectivity.mdns.MdnsSocketProvider.isNetworkMatched;
+
+import android.annotation.NonNull;
+import android.annotation.Nullable;
+import android.net.LinkAddress;
+import android.net.Network;
+import android.os.Handler;
+import android.os.Looper;
+import android.util.ArrayMap;
+import android.util.Log;
+
+import java.io.IOException;
+import java.net.DatagramPacket;
+import java.net.Inet4Address;
+import java.net.Inet6Address;
+import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * The {@link MdnsMultinetworkSocketClient} manages the multinetwork socket for mDns
+ *
+ *  * <p>This class is not thread safe.
+ */
+public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
+    private static final String TAG = MdnsMultinetworkSocketClient.class.getSimpleName();
+    private static final boolean DBG = MdnsDiscoveryManager.DBG;
+
+    @NonNull private final Handler mHandler;
+    @NonNull private final MdnsSocketProvider mSocketProvider;
+    @NonNull private final MdnsResponseDecoder mResponseDecoder;
+
+    private final Map<MdnsServiceBrowserListener, InterfaceSocketCallback> mRequestedNetworks =
+            new ArrayMap<>();
+    private final ArrayMap<MdnsInterfaceSocket, Network> mActiveNetworkSockets = new ArrayMap<>();
+    private final ArrayMap<MdnsInterfaceSocket, ReadPacketHandler> mSocketPacketHandlers =
+            new ArrayMap<>();
+    private MdnsSocketClientBase.Callback mCallback = null;
+    private int mReceivedPacketNumber = 0;
+
+    public MdnsMultinetworkSocketClient(@NonNull Looper looper,
+            @NonNull MdnsSocketProvider provider) {
+        mHandler = new Handler(looper);
+        mSocketProvider = provider;
+        mResponseDecoder = new MdnsResponseDecoder(
+                new MdnsResponseDecoder.Clock(), null /* serviceType */);
+    }
+
+    private class InterfaceSocketCallback implements MdnsSocketProvider.SocketCallback {
+        @Override
+        public void onSocketCreated(@NonNull Network network,
+                @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {
+            // The socket may be already created by other request before, try to get the stored
+            // ReadPacketHandler.
+            ReadPacketHandler handler = mSocketPacketHandlers.get(socket);
+            if (handler == null) {
+                // First request to create this socket. Initial a ReadPacketHandler for this socket.
+                handler = new ReadPacketHandler(network, socket.getInterface().getIndex());
+                mSocketPacketHandlers.put(socket, handler);
+            }
+            socket.addPacketHandler(handler);
+            mActiveNetworkSockets.put(socket, network);
+        }
+
+        @Override
+        public void onInterfaceDestroyed(@NonNull Network network,
+                @NonNull MdnsInterfaceSocket socket) {
+            mSocketPacketHandlers.remove(socket);
+            mActiveNetworkSockets.remove(socket);
+        }
+    }
+
+    private class ReadPacketHandler implements MulticastPacketReader.PacketHandler {
+        private final Network mNetwork;
+        private final int mInterfaceIndex;
+
+        ReadPacketHandler(@NonNull Network network, int interfaceIndex) {
+            mNetwork = network;
+            mInterfaceIndex = interfaceIndex;
+        }
+
+        @Override
+        public void handlePacket(byte[] recvbuf, int length, InetSocketAddress src) {
+            processResponsePacket(recvbuf, length, mInterfaceIndex, mNetwork);
+        }
+    }
+
+    /*** Set callback for receiving mDns response */
+    @Override
+    public void setCallback(@Nullable MdnsSocketClientBase.Callback callback) {
+        ensureRunningOnHandlerThread(mHandler);
+        mCallback = callback;
+    }
+
+    /***
+     * Notify that the given network is requested for mdns discovery / resolution
+     *
+     * @param listener the listener for discovery.
+     * @param network the target network for discovery. Null means discovery on all possible
+     *                interfaces.
+     */
+    @Override
+    public void notifyNetworkRequested(@NonNull MdnsServiceBrowserListener listener,
+            @Nullable Network network) {
+        ensureRunningOnHandlerThread(mHandler);
+        InterfaceSocketCallback callback = mRequestedNetworks.get(listener);
+        if (callback != null) {
+            throw new IllegalArgumentException("Can not register duplicated listener");
+        }
+
+        if (DBG) Log.d(TAG, "notifyNetworkRequested: network=" + network);
+        callback = new InterfaceSocketCallback();
+        mRequestedNetworks.put(listener, callback);
+        mSocketProvider.requestSocket(network, callback);
+    }
+
+    /*** Notify that the network is unrequested */
+    @Override
+    public void notifyNetworkUnrequested(@NonNull MdnsServiceBrowserListener listener) {
+        ensureRunningOnHandlerThread(mHandler);
+        final InterfaceSocketCallback callback = mRequestedNetworks.remove(listener);
+        if (callback == null) {
+            Log.e(TAG, "Can not be unrequested with unknown listener=" + listener);
+            return;
+        }
+        mSocketProvider.unrequestSocket(callback);
+    }
+
+    private void sendMdnsPacket(@NonNull DatagramPacket packet, @Nullable Network targetNetwork) {
+        final boolean isIpv6 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
+                instanceof Inet6Address;
+        final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
+                instanceof Inet4Address;
+        for (int i = 0; i < mActiveNetworkSockets.size(); i++) {
+            final MdnsInterfaceSocket socket = mActiveNetworkSockets.keyAt(i);
+            final Network network = mActiveNetworkSockets.valueAt(i);
+            // Check ip capability and network before sending packet
+            if (((isIpv6 && socket.hasJoinedIpv6()) || (isIpv4 && socket.hasJoinedIpv4()))
+                    && isNetworkMatched(targetNetwork, network)) {
+                try {
+                    socket.send(packet);
+                } catch (IOException e) {
+                    Log.e(TAG, "Failed to send a mDNS packet.", e);
+                }
+            }
+        }
+    }
+
+    private void processResponsePacket(byte[] recvbuf, int length, int interfaceIndex,
+            @NonNull Network network) {
+        int packetNumber = ++mReceivedPacketNumber;
+
+        final List<MdnsResponse> responses = new ArrayList<>();
+        final int errorCode = mResponseDecoder.decode(
+                recvbuf, length, responses, interfaceIndex, network);
+        if (errorCode == MdnsResponseDecoder.SUCCESS) {
+            for (MdnsResponse response : responses) {
+                if (mCallback != null) {
+                    mCallback.onResponseReceived(response);
+                }
+            }
+        } else if (errorCode != MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE) {
+            if (mCallback != null) {
+                mCallback.onFailedToParseMdnsResponse(packetNumber, errorCode);
+            }
+        }
+    }
+
+    /** Sends a mDNS request packet that asks for multicast response. */
+    @Override
+    public void sendMulticastPacket(@NonNull DatagramPacket packet) {
+        sendMulticastPacket(packet, null /* network */);
+    }
+
+    /**
+     * Sends a mDNS request packet via given network that asks for multicast response. Null network
+     * means sending packet via all networks.
+     */
+    @Override
+    public void sendMulticastPacket(@NonNull DatagramPacket packet, @Nullable Network network) {
+        mHandler.post(() -> sendMdnsPacket(packet, network));
+    }
+
+    /** Sends a mDNS request packet that asks for unicast response. */
+    @Override
+    public void sendUnicastPacket(@NonNull DatagramPacket packet) {
+        sendUnicastPacket(packet, null /* network */);
+    }
+
+    /**
+     * Sends a mDNS request packet via given network that asks for unicast response. Null network
+     * means sending packet via all networks.
+     */
+    @Override
+    public void sendUnicastPacket(@NonNull DatagramPacket packet, @Nullable Network network) {
+        // TODO: Separate unicast packet.
+        mHandler.post(() -> sendMdnsPacket(packet, network));
+    }
+}
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsPacketReader.java b/service/mdns/com/android/server/connectivity/mdns/MdnsPacketReader.java
index 856a2cd..aa38844 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsPacketReader.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsPacketReader.java
@@ -38,8 +38,13 @@
 
     /** Constructs a reader for the given packet. */
     public MdnsPacketReader(DatagramPacket packet) {
-        buf = packet.getData();
-        count = packet.getLength();
+        this(packet.getData(), packet.getLength());
+    }
+
+    /** Constructs a reader for the given packet. */
+    public MdnsPacketReader(byte[] buffer, int length) {
+        buf = buffer;
+        count = length;
         pos = 0;
         limit = -1;
         labelDictionary = new SparseArray<>(16);
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java b/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
index 7cf84f6..50f2069 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
@@ -101,7 +101,24 @@
      */
     public int decode(@NonNull DatagramPacket packet, @NonNull List<MdnsResponse> responses,
             int interfaceIndex, @Nullable Network network) {
-        MdnsPacketReader reader = new MdnsPacketReader(packet);
+        return decode(packet.getData(), packet.getLength(), responses, interfaceIndex, network);
+    }
+
+    /**
+     * Decodes all mDNS responses for the desired service type from a packet. The class does not
+     * check
+     * the responses for completeness; the caller should do that.
+     *
+     * @param recvbuf The received data buffer to read from.
+     * @param length The length of received data buffer.
+     * @param interfaceIndex the network interface index (or {@link
+     *     MdnsSocket#INTERFACE_INDEX_UNSPECIFIED} if not known) at which the packet was received
+     * @param network the network at which the packet was received, or null if it is unknown.
+     * @return A list of mDNS responses, or null if the packet contained no appropriate responses.
+     */
+    public int decode(@NonNull byte[] recvbuf, int length, @NonNull List<MdnsResponse> responses,
+            int interfaceIndex, @Nullable Network network) {
+        MdnsPacketReader reader = new MdnsPacketReader(recvbuf, length);
 
         List<MdnsRecord> records;
         try {
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsResponseErrorCode.java b/service/mdns/com/android/server/connectivity/mdns/MdnsResponseErrorCode.java
index fcf9058..73a7e3a 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsResponseErrorCode.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsResponseErrorCode.java
@@ -35,4 +35,6 @@
     public static final int ERROR_READING_TXT_RDATA = 10;
     public static final int ERROR_SKIPPING_UNKNOWN_RECORD = 11;
     public static final int ERROR_END_OF_FILE = 12;
+    public static final int ERROR_READING_NSEC_RDATA = 13;
+    public static final int ERROR_READING_ANY_RDATA = 14;
 }
\ No newline at end of file
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index 538f376..d26fbdb 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -20,6 +20,7 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.net.Network;
 import android.os.SystemClock;
 import android.text.TextUtils;
 import android.util.ArraySet;
@@ -52,7 +53,7 @@
 
     private final String serviceType;
     private final String[] serviceTypeLabels;
-    private final MdnsSocketClient socketClient;
+    private final MdnsSocketClientBase socketClient;
     private final ScheduledExecutorService executor;
     private final Object lock = new Object();
     private final Set<MdnsServiceBrowserListener> listeners = new ArraySet<>();
@@ -77,11 +78,11 @@
      * Constructor of {@link MdnsServiceTypeClient}.
      *
      * @param socketClient Sends and receives mDNS packet.
-     * @param executor     A {@link ScheduledExecutorService} used to schedule query tasks.
+     * @param executor         A {@link ScheduledExecutorService} used to schedule query tasks.
      */
     public MdnsServiceTypeClient(
             @NonNull String serviceType,
-            @NonNull MdnsSocketClient socketClient,
+            @NonNull MdnsSocketClientBase socketClient,
             @NonNull ScheduledExecutorService executor) {
         this.serviceType = serviceType;
         this.socketClient = socketClient;
@@ -169,7 +170,8 @@
                                     new QueryTaskConfig(
                                             searchOptions.getSubtypes(),
                                             searchOptions.isPassiveMode(),
-                                            ++currentSessionId)));
+                                            ++currentSessionId,
+                                            searchOptions.getNetwork())));
         }
     }
 
@@ -322,9 +324,10 @@
         private int burstCounter;
         private int timeToRunNextTaskInMs;
         private boolean isFirstBurst;
+        @Nullable private final Network network;
 
         QueryTaskConfig(@NonNull Collection<String> subtypes, boolean usePassiveMode,
-                long sessionId) {
+                long sessionId, @Nullable Network network) {
             this.usePassiveMode = usePassiveMode;
             this.subtypes = new ArrayList<>(subtypes);
             this.queriesPerBurst = QUERIES_PER_BURST;
@@ -346,6 +349,7 @@
                 // doubles until it maxes out at TIME_BETWEEN_BURSTS_MS.
                 this.timeBetweenBurstsInMs = INITIAL_TIME_BETWEEN_BURSTS_MS;
             }
+            this.network = network;
         }
 
         QueryTaskConfig getConfigForNextRun() {
@@ -405,7 +409,8 @@
                                 serviceType,
                                 config.subtypes,
                                 config.expectUnicastResponse,
-                                config.transactionId)
+                                config.transactionId,
+                                config.network)
                                 .call();
             } catch (RuntimeException e) {
                 LOGGER.e(String.format("Failed to run EnqueueMdnsQueryCallable for subtype: %s",
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java b/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java
index 64c4495..5fd1354 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java
@@ -40,9 +40,9 @@
     private static final MdnsLogger LOGGER = new MdnsLogger("MdnsSocket");
 
     static final int INTERFACE_INDEX_UNSPECIFIED = -1;
-    protected static final InetSocketAddress MULTICAST_IPV4_ADDRESS =
+    public static final InetSocketAddress MULTICAST_IPV4_ADDRESS =
             new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
-    protected static final InetSocketAddress MULTICAST_IPV6_ADDRESS =
+    public static final InetSocketAddress MULTICAST_IPV6_ADDRESS =
             new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
     private final MulticastNetworkInterfaceProvider multicastNetworkInterfaceProvider;
     private final MulticastSocket multicastSocket;
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java
index 6a321d1..907687e 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java
@@ -16,6 +16,8 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.server.connectivity.mdns.MdnsSocketClientBase.Callback;
+
 import android.Manifest.permission;
 import android.annotation.NonNull;
 import android.annotation.Nullable;
@@ -47,7 +49,7 @@
  *
  * <p>See https://tools.ietf.org/html/rfc6763 (namely sections 4 and 5).
  */
-public class MdnsSocketClient {
+public class MdnsSocketClient implements MdnsSocketClientBase {
 
     private static final String TAG = "MdnsClient";
     // TODO: The following values are copied from cast module. We need to think about the
@@ -116,11 +118,13 @@
         }
     }
 
+    @Override
     public synchronized void setCallback(@Nullable Callback callback) {
         this.callback = callback;
     }
 
     @RequiresPermission(permission.CHANGE_WIFI_MULTICAST_STATE)
+    @Override
     public synchronized void startDiscovery() throws IOException {
         if (multicastSocket != null) {
             LOGGER.w("Discovery is already in progress.");
@@ -160,6 +164,7 @@
     }
 
     @RequiresPermission(permission.CHANGE_WIFI_MULTICAST_STATE)
+    @Override
     public void stopDiscovery() {
         LOGGER.log("Stop discovery.");
         if (multicastSocket == null && unicastSocket == null) {
@@ -195,11 +200,13 @@
     }
 
     /** Sends a mDNS request packet that asks for multicast response. */
+    @Override
     public void sendMulticastPacket(@NonNull DatagramPacket packet) {
         sendMdnsPacket(packet, multicastPacketQueue);
     }
 
     /** Sends a mDNS request packet that asks for unicast response. */
+    @Override
     public void sendUnicastPacket(DatagramPacket packet) {
         if (useSeparateSocketForUnicast) {
             sendMdnsPacket(packet, unicastPacketQueue);
@@ -512,11 +519,4 @@
     public boolean isOnIPv6OnlyNetwork() {
         return multicastSocket != null && multicastSocket.isOnIPv6OnlyNetwork();
     }
-
-    /** Callback for {@link MdnsSocketClient}. */
-    public interface Callback {
-        void onResponseReceived(@NonNull MdnsResponse response);
-
-        void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode);
-    }
 }
\ No newline at end of file
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClientBase.java b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
new file mode 100644
index 0000000..23504a0
--- /dev/null
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+import android.annotation.NonNull;
+import android.annotation.Nullable;
+import android.net.Network;
+
+import java.io.IOException;
+import java.net.DatagramPacket;
+
+/**
+ * Base class for multicast socket client.
+ *
+ * @hide
+ */
+public interface MdnsSocketClientBase {
+    /*** Start mDns discovery on given network. */
+    default void startDiscovery() throws IOException { }
+
+    /*** Stop mDns discovery. */
+    default void stopDiscovery() { }
+
+    /*** Set callback for receiving mDns response */
+    void setCallback(@Nullable Callback callback);
+
+    /*** Sends a mDNS request packet that asks for multicast response. */
+    void sendMulticastPacket(@NonNull DatagramPacket packet);
+
+    /**
+     * Sends a mDNS request packet via given network that asks for multicast response. Null network
+     * means sending packet via all networks.
+     */
+    default void sendMulticastPacket(@NonNull DatagramPacket packet, @Nullable Network network) {
+        throw new UnsupportedOperationException(
+                "This socket client doesn't support per-network sending");
+    }
+
+    /*** Sends a mDNS request packet that asks for unicast response. */
+    void sendUnicastPacket(@NonNull DatagramPacket packet);
+
+    /**
+     * Sends a mDNS request packet via given network that asks for unicast response. Null network
+     * means sending packet via all networks.
+     */
+    default void sendUnicastPacket(@NonNull DatagramPacket packet, @Nullable Network network) {
+        throw new UnsupportedOperationException(
+                "This socket client doesn't support per-network sending");
+    }
+
+    /*** Notify that the given network is requested for mdns discovery / resolution */
+    default void notifyNetworkRequested(@NonNull MdnsServiceBrowserListener listener,
+            @Nullable Network network) { }
+
+    /*** Notify that the network is unrequested */
+    default void notifyNetworkUnrequested(@NonNull MdnsServiceBrowserListener listener) { }
+
+    /*** Callback for mdns response  */
+    interface Callback {
+        /*** Receive a mdns response */
+        void onResponseReceived(@NonNull MdnsResponse response);
+
+        /*** Parse a mdns response failed */
+        void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode);
+    }
+}
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsSocketProvider.java b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketProvider.java
index d3bf060..9298852 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsSocketProvider.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketProvider.java
@@ -35,6 +35,7 @@
 import android.util.Log;
 
 import com.android.internal.annotations.VisibleForTesting;
+import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.LinkPropertiesUtils.CompareResult;
 import com.android.net.module.util.ip.NetlinkMonitor;
 import com.android.net.module.util.netlink.NetlinkConstants;
@@ -42,7 +43,6 @@
 import com.android.server.connectivity.mdns.util.MdnsLogger;
 
 import java.io.IOException;
-import java.net.InterfaceAddress;
 import java.net.NetworkInterface;
 import java.net.SocketException;
 import java.util.ArrayList;
@@ -60,8 +60,13 @@
 public class MdnsSocketProvider {
     private static final String TAG = MdnsSocketProvider.class.getSimpleName();
     private static final boolean DBG = MdnsDiscoveryManager.DBG;
+    // This buffer size matches what MdnsSocketClient uses currently.
+    // But 1440 should generally be enough because of standard Ethernet.
+    // Note: mdnsresponder mDNSEmbeddedAPI.h uses 8940 for Ethernet jumbo frames.
+    private static final int READ_BUFFER_SIZE = 2048;
     private static final MdnsLogger LOGGER = new MdnsLogger(TAG);
     @NonNull private final Context mContext;
+    @NonNull private final Looper mLooper;
     @NonNull private final Handler mHandler;
     @NonNull private final Dependencies mDependencies;
     @NonNull private final NetworkCallback mNetworkCallback;
@@ -75,6 +80,7 @@
             new ArrayMap<>();
     private final List<String> mLocalOnlyInterfaces = new ArrayList<>();
     private final List<String> mTetheredInterfaces = new ArrayList<>();
+    private final byte[] mPacketReadBuffer = new byte[READ_BUFFER_SIZE];
     private boolean mMonitoringSockets = false;
 
     public MdnsSocketProvider(@NonNull Context context, @NonNull Looper looper) {
@@ -84,6 +90,7 @@
     MdnsSocketProvider(@NonNull Context context, @NonNull Looper looper,
             @NonNull Dependencies deps) {
         mContext = context;
+        mLooper = looper;
         mHandler = new Handler(looper);
         mDependencies = deps;
         mNetworkCallback = new NetworkCallback() {
@@ -119,32 +126,33 @@
     @VisibleForTesting
     public static class Dependencies {
         /*** Get network interface by given interface name */
-        public NetworkInterfaceWrapper getNetworkInterfaceByName(String interfaceName)
+        public NetworkInterfaceWrapper getNetworkInterfaceByName(@NonNull String interfaceName)
                 throws SocketException {
             final NetworkInterface ni = NetworkInterface.getByName(interfaceName);
             return ni == null ? null : new NetworkInterfaceWrapper(ni);
         }
 
         /*** Check whether given network interface can support mdns */
-        public boolean canScanOnInterface(NetworkInterfaceWrapper networkInterface) {
+        public boolean canScanOnInterface(@NonNull NetworkInterfaceWrapper networkInterface) {
             return MulticastNetworkInterfaceProvider.canScanOnInterface(networkInterface);
         }
 
         /*** Create a MdnsInterfaceSocket */
-        public MdnsInterfaceSocket createMdnsInterfaceSocket(NetworkInterface networkInterface,
-                int port) throws IOException {
-            return new MdnsInterfaceSocket(networkInterface, port);
+        public MdnsInterfaceSocket createMdnsInterfaceSocket(
+                @NonNull NetworkInterface networkInterface, int port, @NonNull Looper looper,
+                @NonNull byte[] packetReadBuffer) throws IOException {
+            return new MdnsInterfaceSocket(networkInterface, port, looper, packetReadBuffer);
         }
     }
 
     /*** Data class for storing socket related info  */
     private static class SocketInfo {
         final MdnsInterfaceSocket mSocket;
-        final List<LinkAddress> mAddresses = new ArrayList<>();
+        final List<LinkAddress> mAddresses;
 
         SocketInfo(MdnsInterfaceSocket socket, List<LinkAddress> addresses) {
             mSocket = socket;
-            mAddresses.addAll(addresses);
+            mAddresses = new ArrayList<>(addresses);
         }
     }
 
@@ -160,8 +168,9 @@
         }
     }
 
-    private void ensureRunningOnHandlerThread() {
-        if (mHandler.getLooper().getThread() != Thread.currentThread()) {
+    /*** Ensure that current running thread is same as given handler thread */
+    public static void ensureRunningOnHandlerThread(Handler handler) {
+        if (handler.getLooper().getThread() != Thread.currentThread()) {
             throw new IllegalStateException(
                     "Not running on Handler thread: " + Thread.currentThread().getName());
         }
@@ -169,7 +178,7 @@
 
     /*** Start monitoring sockets by listening callbacks for sockets creation or removal */
     public void startMonitoringSockets() {
-        ensureRunningOnHandlerThread();
+        ensureRunningOnHandlerThread(mHandler);
         if (mMonitoringSockets) {
             Log.d(TAG, "Already monitoring sockets.");
             return;
@@ -188,7 +197,7 @@
 
     /*** Stop monitoring sockets and unregister callbacks */
     public void stopMonitoringSockets() {
-        ensureRunningOnHandlerThread();
+        ensureRunningOnHandlerThread(mHandler);
         if (!mMonitoringSockets) {
             Log.d(TAG, "Monitoring sockets hasn't been started.");
             return;
@@ -204,19 +213,15 @@
         mMonitoringSockets = false;
     }
 
-    private static boolean isNetworkMatched(@Nullable Network targetNetwork,
+    /*** Check whether the target network is matched current network */
+    public static boolean isNetworkMatched(@Nullable Network targetNetwork,
             @NonNull Network currentNetwork) {
         return targetNetwork == null || targetNetwork.equals(currentNetwork);
     }
 
     private boolean matchRequestedNetwork(Network network) {
-        for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
-            final Network requestedNetwork =  mCallbacksToRequestedNetworks.valueAt(i);
-            if (isNetworkMatched(requestedNetwork, network)) {
-                return true;
-            }
-        }
-        return false;
+        return hasAllNetworksRequest()
+                || mCallbacksToRequestedNetworks.containsValue(network);
     }
 
     private boolean hasAllNetworksRequest() {
@@ -279,15 +284,6 @@
         current.addAll(updated);
     }
 
-    private static List<LinkAddress> getLinkAddressFromNetworkInterface(
-            NetworkInterfaceWrapper networkInterface) {
-        List<LinkAddress> addresses = new ArrayList<>();
-        for (InterfaceAddress address : networkInterface.getInterfaceAddresses()) {
-            addresses.add(new LinkAddress(address));
-        }
-        return addresses;
-    }
-
     private void createSocket(Network network, LinkProperties lp) {
         final String interfaceName = lp.getInterfaceName();
         if (interfaceName == null) {
@@ -307,10 +303,12 @@
                         + " with interfaceName:" + interfaceName);
             }
             final MdnsInterfaceSocket socket = mDependencies.createMdnsInterfaceSocket(
-                    networkInterface.getNetworkInterface(), MdnsConstants.MDNS_PORT);
+                    networkInterface.getNetworkInterface(), MdnsConstants.MDNS_PORT, mLooper,
+                    mPacketReadBuffer);
             final List<LinkAddress> addresses;
             if (network.netId == INetd.LOCAL_NET_ID) {
-                addresses = getLinkAddressFromNetworkInterface(networkInterface);
+                addresses = CollectionUtils.map(
+                        networkInterface.getInterfaceAddresses(), LinkAddress::new);
                 mTetherInterfaceSockets.put(interfaceName, new SocketInfo(socket, addresses));
             } else {
                 addresses = lp.getLinkAddresses();
@@ -402,7 +400,7 @@
      * @param cb the callback to listen the socket creation.
      */
     public void requestSocket(@Nullable Network network, @NonNull SocketCallback cb) {
-        ensureRunningOnHandlerThread();
+        ensureRunningOnHandlerThread(mHandler);
         mCallbacksToRequestedNetworks.put(cb, network);
         if (network == null) {
             // Does not specify a required network, create sockets for all possible
@@ -425,7 +423,7 @@
 
     /*** Unrequest the socket */
     public void unrequestSocket(@NonNull SocketCallback cb) {
-        ensureRunningOnHandlerThread();
+        ensureRunningOnHandlerThread(mHandler);
         mCallbacksToRequestedNetworks.remove(cb);
         if (hasAllNetworksRequest()) {
             // Still has a request for all networks (interfaces).
@@ -434,16 +432,24 @@
 
         // Check if remaining requests are matched any of sockets.
         for (int i = mNetworkSockets.size() - 1; i >= 0; i--) {
-            if (matchRequestedNetwork(mNetworkSockets.keyAt(i))) continue;
-            mNetworkSockets.removeAt(i).mSocket.destroy();
+            final Network network = mNetworkSockets.keyAt(i);
+            if (matchRequestedNetwork(network)) continue;
+            final SocketInfo info = mNetworkSockets.removeAt(i);
+            info.mSocket.destroy();
+            // Still notify to unrequester for socket destroy.
+            cb.onInterfaceDestroyed(network, info.mSocket);
         }
 
         // Remove all sockets for tethering interface because these sockets do not have associated
         // networks, and they should invoke by a request for all networks (interfaces). If there is
         // no such request, the sockets for tethering interface should be removed.
         for (int i = mTetherInterfaceSockets.size() - 1; i >= 0; i--) {
-            mTetherInterfaceSockets.removeAt(i).mSocket.destroy();
+            final SocketInfo info = mTetherInterfaceSockets.valueAt(i);
+            info.mSocket.destroy();
+            // Still notify to unrequester for socket destroy.
+            cb.onInterfaceDestroyed(new Network(INetd.LOCAL_NET_ID), info.mSocket);
         }
+        mTetherInterfaceSockets.clear();
     }
 
     /*** Callbacks for listening socket changes */
diff --git a/service/mdns/com/android/server/connectivity/mdns/MulticastPacketReader.java b/service/mdns/com/android/server/connectivity/mdns/MulticastPacketReader.java
new file mode 100644
index 0000000..20cc47f
--- /dev/null
+++ b/service/mdns/com/android/server/connectivity/mdns/MulticastPacketReader.java
@@ -0,0 +1,111 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+import static com.android.server.connectivity.mdns.MdnsSocketProvider.ensureRunningOnHandlerThread;
+
+import android.annotation.NonNull;
+import android.os.Handler;
+import android.os.ParcelFileDescriptor;
+import android.system.Os;
+import android.util.ArraySet;
+
+import com.android.net.module.util.FdEventsReader;
+
+import java.io.FileDescriptor;
+import java.net.InetSocketAddress;
+import java.util.Set;
+
+/** Simple reader for mDNS packets. */
+public class MulticastPacketReader extends FdEventsReader<MulticastPacketReader.RecvBuffer> {
+    @NonNull
+    private final String mLogTag;
+    @NonNull
+    private final ParcelFileDescriptor mSocket;
+    @NonNull
+    private final Handler mHandler;
+    @NonNull
+    private final Set<PacketHandler> mPacketHandlers = new ArraySet<>();
+
+    interface PacketHandler {
+        void handlePacket(byte[] recvbuf, int length, InetSocketAddress src);
+    }
+
+    public static final class RecvBuffer {
+        final byte[] data;
+        final InetSocketAddress src;
+
+        private RecvBuffer(byte[] data, InetSocketAddress src) {
+            this.data = data;
+            this.src = src;
+        }
+    }
+
+    /**
+     * Create a new {@link MulticastPacketReader}.
+     * @param socket Socket to read from. This will *not* be closed when the reader terminates.
+     * @param buffer Buffer to read packets into. Will only be used from the handler thread.
+     */
+    protected MulticastPacketReader(@NonNull String interfaceTag,
+            @NonNull ParcelFileDescriptor socket, @NonNull Handler handler,
+            @NonNull byte[] buffer) {
+        super(handler, new RecvBuffer(buffer, new InetSocketAddress()));
+        mLogTag = MulticastPacketReader.class.getSimpleName() + "/" + interfaceTag;
+        mSocket = socket;
+        mHandler = handler;
+    }
+
+    @Override
+    protected int recvBufSize(@NonNull RecvBuffer buffer) {
+        return buffer.data.length;
+    }
+
+    @Override
+    protected FileDescriptor createFd() {
+        // Keep a reference to the PFD as it would close the fd in its finalizer otherwise
+        return mSocket.getFileDescriptor();
+    }
+
+    @Override
+    protected void onStop() {
+        // Do nothing (do not close the FD)
+    }
+
+    @Override
+    protected int readPacket(@NonNull FileDescriptor fd, @NonNull RecvBuffer buffer)
+            throws Exception {
+        return Os.recvfrom(
+                fd, buffer.data, 0, buffer.data.length, 0 /* flags */, buffer.src);
+    }
+
+    @Override
+    protected void handlePacket(@NonNull RecvBuffer recvbuf, int length) {
+        for (PacketHandler handler : mPacketHandlers) {
+            handler.handlePacket(recvbuf.data, length, recvbuf.src);
+        }
+    }
+
+    /**
+     * Add a packet handler to deal with received packets. If the handler is already set,
+     * this is a no-op.
+     */
+    public void addPacketHandler(@NonNull PacketHandler handler) {
+        ensureRunningOnHandlerThread(mHandler);
+        mPacketHandlers.add(handler);
+    }
+}
+
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
index 3e3c3bf..83e7696 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
@@ -46,7 +46,7 @@
     private static final String SERVICE_TYPE_2 = "_test._tcp.local";
 
     @Mock private ExecutorProvider executorProvider;
-    @Mock private MdnsSocketClient socketClient;
+    @Mock private MdnsSocketClientBase socketClient;
     @Mock private MdnsServiceTypeClient mockServiceTypeClientOne;
     @Mock private MdnsServiceTypeClient mockServiceTypeClientTwo;
 
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
new file mode 100644
index 0000000..9d42a65
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
@@ -0,0 +1,161 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+import static com.android.server.connectivity.mdns.MdnsSocketProvider.SocketCallback;
+import static com.android.server.connectivity.mdns.MulticastPacketReader.PacketHandler;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.timeout;
+import static org.mockito.Mockito.verify;
+
+import android.net.InetAddresses;
+import android.net.Network;
+import android.os.Build;
+import android.os.Handler;
+import android.os.HandlerThread;
+
+import com.android.net.module.util.HexDump;
+import com.android.testutils.DevSdkIgnoreRule;
+import com.android.testutils.DevSdkIgnoreRunner;
+import com.android.testutils.HandlerUtils;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.net.DatagramPacket;
+import java.net.NetworkInterface;
+import java.net.SocketException;
+import java.util.List;
+
+@RunWith(DevSdkIgnoreRunner.class)
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
+public class MdnsMultinetworkSocketClientTest {
+    private static final byte[] BUFFER = new byte[10];
+    private static final long DEFAULT_TIMEOUT = 2000L;
+    @Mock private Network mNetwork;
+    @Mock private MdnsSocketProvider mProvider;
+    @Mock private MdnsInterfaceSocket mSocket;
+    @Mock private MdnsServiceBrowserListener mListener;
+    @Mock private MdnsSocketClientBase.Callback mCallback;
+    private MdnsMultinetworkSocketClient mSocketClient;
+    private Handler mHandler;
+
+    @Before
+    public void setUp() throws SocketException {
+        MockitoAnnotations.initMocks(this);
+        final HandlerThread thread = new HandlerThread("MdnsMultinetworkSocketClientTest");
+        thread.start();
+        mHandler = new Handler(thread.getLooper());
+        mSocketClient = new MdnsMultinetworkSocketClient(thread.getLooper(), mProvider);
+        mHandler.post(() -> mSocketClient.setCallback(mCallback));
+    }
+
+    private SocketCallback expectSocketCallback() {
+        final ArgumentCaptor<SocketCallback> callbackCaptor =
+                ArgumentCaptor.forClass(SocketCallback.class);
+        mHandler.post(() -> mSocketClient.notifyNetworkRequested(mListener, mNetwork));
+        verify(mProvider, timeout(DEFAULT_TIMEOUT))
+                .requestSocket(eq(mNetwork), callbackCaptor.capture());
+        return callbackCaptor.getValue();
+    }
+
+    private NetworkInterface createEmptyNetworkInterface() {
+        try {
+            Constructor<NetworkInterface> constructor =
+                    NetworkInterface.class.getDeclaredConstructor();
+            constructor.setAccessible(true);
+            return constructor.newInstance();
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Test
+    public void testSendPacket() throws IOException {
+        final SocketCallback callback = expectSocketCallback();
+        final DatagramPacket ipv4Packet = new DatagramPacket(BUFFER, 0 /* offset */, BUFFER.length,
+                InetAddresses.parseNumericAddress("192.0.2.1"), 0 /* port */);
+        final DatagramPacket ipv6Packet = new DatagramPacket(BUFFER, 0 /* offset */, BUFFER.length,
+                InetAddresses.parseNumericAddress("2001:db8::"), 0 /* port */);
+        doReturn(true).when(mSocket).hasJoinedIpv4();
+        doReturn(true).when(mSocket).hasJoinedIpv6();
+        doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
+        // Notify socket created
+        callback.onSocketCreated(mNetwork, mSocket, List.of());
+
+        // Send packet to IPv4 with target network and verify sending has been called.
+        mSocketClient.sendMulticastPacket(ipv4Packet, mNetwork);
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        verify(mSocket).send(ipv4Packet);
+
+        // Send packet to IPv6 without target network and verify sending has been called.
+        mSocketClient.sendMulticastPacket(ipv6Packet);
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        verify(mSocket).send(ipv6Packet);
+    }
+
+    @Test
+    public void testReceivePacket() {
+        final SocketCallback callback = expectSocketCallback();
+        final byte[] data = HexDump.hexStringToByteArray(
+                // scapy.raw(scapy.dns_compress(
+                //     scapy.DNS(rd=0, qr=1, aa=1, qd = None,
+                //     an =
+                //     scapy.DNSRR(type='PTR', rrname='_testtype._tcp.local',
+                //         rdata='testservice._testtype._tcp.local', rclass='IN', ttl=4500) /
+                //     scapy.DNSRRSRV(rrname='testservice._testtype._tcp.local', rclass=0x8001,
+                //         port=31234, target='Android.local', ttl=120))
+                // )).hex().upper()
+                "000084000000000200000000095F7465737474797065045F746370056C6F63616C00000C0001000011"
+                        + "94000E0B7465737473657276696365C00CC02C00218001000000780010000000007A0207"
+                        + "416E64726F6964C01B");
+
+        doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
+        // Notify socket created
+        callback.onSocketCreated(mNetwork, mSocket, List.of());
+
+        final ArgumentCaptor<PacketHandler> handlerCaptor =
+                ArgumentCaptor.forClass(PacketHandler.class);
+        verify(mSocket).addPacketHandler(handlerCaptor.capture());
+
+        // Send the data and verify the received records.
+        final PacketHandler handler = handlerCaptor.getValue();
+        handler.handlePacket(data, data.length, null /* src */);
+        final ArgumentCaptor<MdnsResponse> responseCaptor =
+                ArgumentCaptor.forClass(MdnsResponse.class);
+        verify(mCallback).onResponseReceived(responseCaptor.capture());
+        final MdnsResponse response = responseCaptor.getValue();
+        assertTrue(response.hasPointerRecords());
+        assertArrayEquals("_testtype._tcp.local".split("\\."),
+                response.getPointerRecords().get(0).getName());
+        assertTrue(response.hasServiceRecord());
+        assertEquals("testservice", response.getServiceRecord().getServiceInstanceName());
+        assertEquals("Android.local".split("\\."),
+                response.getServiceRecord().getServiceHost());
+    }
+}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index 697116c..a45ca68 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -62,7 +62,7 @@
 import java.net.DatagramPacket;
 import java.net.Inet4Address;
 import java.net.Inet6Address;
-import java.net.SocketAddress;
+import java.net.InetSocketAddress;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -80,7 +80,10 @@
     private static final int INTERFACE_INDEX = 999;
     private static final String SERVICE_TYPE = "_googlecast._tcp.local";
     private static final String[] SERVICE_TYPE_LABELS = TextUtils.split(SERVICE_TYPE, "\\.");
-    private static final Network NETWORK = mock(Network.class);
+    private static final InetSocketAddress IPV4_ADDRESS = new InetSocketAddress(
+            MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
+    private static final InetSocketAddress IPV6_ADDRESS = new InetSocketAddress(
+            MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
 
     @Mock
     private MdnsServiceBrowserListener mockListenerOne;
@@ -89,13 +92,16 @@
     @Mock
     private MdnsPacketWriter mockPacketWriter;
     @Mock
-    private MdnsSocketClient mockSocketClient;
+    private MdnsMultinetworkSocketClient mockSocketClient;
+    @Mock
+    private Network mockNetwork;
     @Captor
     private ArgumentCaptor<MdnsServiceInfo> serviceInfoCaptor;
 
     private final byte[] buf = new byte[10];
 
-    private DatagramPacket[] expectedPackets;
+    private DatagramPacket[] expectedIPv4Packets;
+    private DatagramPacket[] expectedIPv6Packets;
     private ScheduledFuture<?>[] expectedSendFutures;
     private FakeExecutor currentThreadExecutor = new FakeExecutor();
 
@@ -106,30 +112,52 @@
     public void setUp() throws IOException {
         MockitoAnnotations.initMocks(this);
 
-        expectedPackets = new DatagramPacket[16];
+        expectedIPv4Packets = new DatagramPacket[16];
+        expectedIPv6Packets = new DatagramPacket[16];
         expectedSendFutures = new ScheduledFuture<?>[16];
 
         for (int i = 0; i < expectedSendFutures.length; ++i) {
-            expectedPackets[i] = new DatagramPacket(buf, 0, 5);
+            expectedIPv4Packets[i] = new DatagramPacket(buf, 0 /* offset */, 5 /* length */,
+                    MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
+            expectedIPv6Packets[i] = new DatagramPacket(buf, 0 /* offset */, 5 /* length */,
+                    MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
             expectedSendFutures[i] = Mockito.mock(ScheduledFuture.class);
         }
-        when(mockPacketWriter.getPacket(any(SocketAddress.class)))
-                .thenReturn(expectedPackets[0])
-                .thenReturn(expectedPackets[1])
-                .thenReturn(expectedPackets[2])
-                .thenReturn(expectedPackets[3])
-                .thenReturn(expectedPackets[4])
-                .thenReturn(expectedPackets[5])
-                .thenReturn(expectedPackets[6])
-                .thenReturn(expectedPackets[7])
-                .thenReturn(expectedPackets[8])
-                .thenReturn(expectedPackets[9])
-                .thenReturn(expectedPackets[10])
-                .thenReturn(expectedPackets[11])
-                .thenReturn(expectedPackets[12])
-                .thenReturn(expectedPackets[13])
-                .thenReturn(expectedPackets[14])
-                .thenReturn(expectedPackets[15]);
+        when(mockPacketWriter.getPacket(IPV4_ADDRESS))
+                .thenReturn(expectedIPv4Packets[0])
+                .thenReturn(expectedIPv4Packets[1])
+                .thenReturn(expectedIPv4Packets[2])
+                .thenReturn(expectedIPv4Packets[3])
+                .thenReturn(expectedIPv4Packets[4])
+                .thenReturn(expectedIPv4Packets[5])
+                .thenReturn(expectedIPv4Packets[6])
+                .thenReturn(expectedIPv4Packets[7])
+                .thenReturn(expectedIPv4Packets[8])
+                .thenReturn(expectedIPv4Packets[9])
+                .thenReturn(expectedIPv4Packets[10])
+                .thenReturn(expectedIPv4Packets[11])
+                .thenReturn(expectedIPv4Packets[12])
+                .thenReturn(expectedIPv4Packets[13])
+                .thenReturn(expectedIPv4Packets[14])
+                .thenReturn(expectedIPv4Packets[15]);
+
+        when(mockPacketWriter.getPacket(IPV6_ADDRESS))
+                .thenReturn(expectedIPv6Packets[0])
+                .thenReturn(expectedIPv6Packets[1])
+                .thenReturn(expectedIPv6Packets[2])
+                .thenReturn(expectedIPv6Packets[3])
+                .thenReturn(expectedIPv6Packets[4])
+                .thenReturn(expectedIPv6Packets[5])
+                .thenReturn(expectedIPv6Packets[6])
+                .thenReturn(expectedIPv6Packets[7])
+                .thenReturn(expectedIPv6Packets[8])
+                .thenReturn(expectedIPv6Packets[9])
+                .thenReturn(expectedIPv6Packets[10])
+                .thenReturn(expectedIPv6Packets[11])
+                .thenReturn(expectedIPv6Packets[12])
+                .thenReturn(expectedIPv6Packets[13])
+                .thenReturn(expectedIPv6Packets[14])
+                .thenReturn(expectedIPv6Packets[15]);
 
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor) {
@@ -282,8 +310,8 @@
         //MdnsConfigsFlagsImpl.alwaysAskForUnicastResponseInEachBurst.override(true);
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build();
-        QueryTaskConfig config =
-                new QueryTaskConfig(searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1);
+        QueryTaskConfig config = new QueryTaskConfig(
+                searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, mockNetwork);
 
         // This is the first query. We will ask for unicast response.
         assertTrue(config.expectUnicastResponse);
@@ -311,8 +339,8 @@
     public void testQueryTaskConfig_askForUnicastInFirstQuery() {
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build();
-        QueryTaskConfig config =
-                new QueryTaskConfig(searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1);
+        QueryTaskConfig config = new QueryTaskConfig(
+                searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, mockNetwork);
 
         // This is the first query. We will ask for unicast response.
         assertTrue(config.expectUnicastResponse);
@@ -409,7 +437,7 @@
         MdnsResponse response = mock(MdnsResponse.class);
         when(response.getServiceInstanceName()).thenReturn("service-instance-1");
         doReturn(INTERFACE_INDEX).when(response).getInterfaceIndex();
-        doReturn(NETWORK).when(response).getNetwork();
+        doReturn(mockNetwork).when(response).getNetwork();
         when(response.isComplete()).thenReturn(false);
 
         client.processResponse(response);
@@ -423,7 +451,7 @@
                 List.of() /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
                 INTERFACE_INDEX,
-                NETWORK);
+                mockNetwork);
 
         verify(mockListenerOne, never()).onServiceFound(any(MdnsServiceInfo.class));
         verify(mockListenerOne, never()).onServiceUpdated(any(MdnsServiceInfo.class));
@@ -443,7 +471,7 @@
                         /* subtype= */ "ABCDE",
                         Collections.emptyMap(),
                         /* interfaceIndex= */ 20,
-                        NETWORK);
+                        mockNetwork);
         client.processResponse(initialResponse);
 
         // Process a second response with a different port and updated text attributes.
@@ -455,7 +483,7 @@
                         /* subtype= */ "ABCDE",
                         Collections.singletonMap("key", "value"),
                         /* interfaceIndex= */ 20,
-                        NETWORK);
+                        mockNetwork);
         client.processResponse(secondResponse);
 
         // Verify onServiceNameDiscovered was called once for the initial response.
@@ -469,7 +497,7 @@
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
                 20 /* interfaceIndex */,
-                NETWORK);
+                mockNetwork);
 
         // Verify onServiceFound was called once for the initial response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -480,7 +508,7 @@
         assertEquals(initialServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertNull(initialServiceInfo.getAttributeByKey("key"));
         assertEquals(initialServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(NETWORK, initialServiceInfo.getNetwork());
+        assertEquals(mockNetwork, initialServiceInfo.getNetwork());
 
         // Verify onServiceUpdated was called once for the second response.
         verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -492,7 +520,7 @@
         assertEquals(updatedServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertEquals(updatedServiceInfo.getAttributeByKey("key"), "value");
         assertEquals(updatedServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(NETWORK, updatedServiceInfo.getNetwork());
+        assertEquals(mockNetwork, updatedServiceInfo.getNetwork());
     }
 
     @Test
@@ -509,7 +537,7 @@
                         /* subtype= */ "ABCDE",
                         Collections.emptyMap(),
                         /* interfaceIndex= */ 20,
-                        NETWORK);
+                        mockNetwork);
         client.processResponse(initialResponse);
 
         // Process a second response with a different port and updated text attributes.
@@ -521,7 +549,7 @@
                         /* subtype= */ "ABCDE",
                         Collections.singletonMap("key", "value"),
                         /* interfaceIndex= */ 20,
-                        NETWORK);
+                        mockNetwork);
         client.processResponse(secondResponse);
 
         System.out.println("secondResponses ip"
@@ -538,7 +566,7 @@
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
                 20 /* interfaceIndex */,
-                NETWORK);
+                mockNetwork);
 
         // Verify onServiceFound was called once for the initial response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -549,7 +577,7 @@
         assertEquals(initialServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertNull(initialServiceInfo.getAttributeByKey("key"));
         assertEquals(initialServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(NETWORK, initialServiceInfo.getNetwork());
+        assertEquals(mockNetwork, initialServiceInfo.getNetwork());
 
         // Verify onServiceUpdated was called once for the second response.
         verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -561,7 +589,7 @@
         assertEquals(updatedServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertEquals(updatedServiceInfo.getAttributeByKey("key"), "value");
         assertEquals(updatedServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(NETWORK, updatedServiceInfo.getNetwork());
+        assertEquals(mockNetwork, updatedServiceInfo.getNetwork());
     }
 
     private void verifyServiceRemovedNoCallback(MdnsServiceBrowserListener listener) {
@@ -599,12 +627,12 @@
                         /* subtype= */ "ABCDE",
                         Collections.emptyMap(),
                         INTERFACE_INDEX,
-                        NETWORK);
+                        mockNetwork);
         client.processResponse(initialResponse);
         MdnsResponse response = mock(MdnsResponse.class);
         doReturn("goodbye-service").when(response).getServiceInstanceName();
         doReturn(INTERFACE_INDEX).when(response).getInterfaceIndex();
-        doReturn(NETWORK).when(response).getNetwork();
+        doReturn(mockNetwork).when(response).getNetwork();
         doReturn(true).when(response).isGoodbye();
         client.processResponse(response);
         // Verify removed callback won't be called if the service is not existed.
@@ -615,9 +643,9 @@
         doReturn(serviceName).when(response).getServiceInstanceName();
         client.processResponse(response);
         verifyServiceRemovedCallback(
-                mockListenerOne, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX, NETWORK);
+                mockListenerOne, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX, mockNetwork);
         verifyServiceRemovedCallback(
-                mockListenerTwo, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX, NETWORK);
+                mockListenerTwo, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX, mockNetwork);
     }
 
     @Test
@@ -631,7 +659,7 @@
                         /* subtype= */ "ABCDE",
                         Collections.emptyMap(),
                         INTERFACE_INDEX,
-                        NETWORK);
+                        mockNetwork);
         client.processResponse(initialResponse);
 
         client.startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
@@ -647,7 +675,7 @@
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
                 INTERFACE_INDEX,
-                NETWORK);
+                mockNetwork);
 
         // Verify onServiceFound was called once for the existing response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -684,7 +712,7 @@
         MdnsResponse initialResponse =
                 createMockResponse(
                         serviceInstanceName, "192.168.1.1", 5353, List.of("ABCDE"),
-                        Map.of(), INTERFACE_INDEX, NETWORK);
+                        Map.of(), INTERFACE_INDEX, mockNetwork);
         client.processResponse(initialResponse);
 
         // Clear the scheduled runnable.
@@ -718,7 +746,7 @@
         MdnsResponse initialResponse =
                 createMockResponse(
                         serviceInstanceName, "192.168.1.1", 5353, List.of("ABCDE"),
-                        Map.of(), INTERFACE_INDEX, NETWORK);
+                        Map.of(), INTERFACE_INDEX, mockNetwork);
         client.processResponse(initialResponse);
 
         // Clear the scheduled runnable.
@@ -737,7 +765,7 @@
 
         // Verify removed callback was called.
         verifyServiceRemovedCallback(mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS,
-                INTERFACE_INDEX, NETWORK);
+                INTERFACE_INDEX, mockNetwork);
     }
 
     @Test
@@ -758,7 +786,7 @@
         MdnsResponse initialResponse =
                 createMockResponse(
                         serviceInstanceName, "192.168.1.1", 5353, List.of("ABCDE"),
-                        Map.of(), INTERFACE_INDEX, NETWORK);
+                        Map.of(), INTERFACE_INDEX, mockNetwork);
         client.processResponse(initialResponse);
 
         // Clear the scheduled runnable.
@@ -792,7 +820,7 @@
         MdnsResponse initialResponse =
                 createMockResponse(
                         serviceInstanceName, "192.168.1.1", 5353, List.of("ABCDE"),
-                        Map.of(), INTERFACE_INDEX, NETWORK);
+                        Map.of(), INTERFACE_INDEX, mockNetwork);
         client.processResponse(initialResponse);
 
         // Clear the scheduled runnable.
@@ -804,7 +832,7 @@
 
         // Verify removed callback was called.
         verifyServiceRemovedCallback(mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS,
-                INTERFACE_INDEX, NETWORK);
+                INTERFACE_INDEX, mockNetwork);
     }
 
     @Test
@@ -824,7 +852,7 @@
                         "ABCDE" /* subtype */,
                         Collections.emptyMap(),
                         INTERFACE_INDEX,
-                        NETWORK);
+                        mockNetwork);
         client.processResponse(initialResponse);
 
         // Process a second response which has ip address to make response become complete.
@@ -836,7 +864,7 @@
                         "ABCDE" /* subtype */,
                         Collections.emptyMap(),
                         INTERFACE_INDEX,
-                        NETWORK);
+                        mockNetwork);
         client.processResponse(secondResponse);
 
         // Process a third response with a different ip address, port and updated text attributes.
@@ -848,7 +876,7 @@
                         "ABCDE" /* subtype */,
                         Collections.singletonMap("key", "value"),
                         INTERFACE_INDEX,
-                        NETWORK);
+                        mockNetwork);
         client.processResponse(thirdResponse);
 
         // Process the last response which is goodbye message.
@@ -868,7 +896,7 @@
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
                 INTERFACE_INDEX,
-                NETWORK);
+                mockNetwork);
 
         // Verify onServiceFound was second called for the second response.
         inOrder.verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -881,7 +909,7 @@
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
                 INTERFACE_INDEX,
-                NETWORK);
+                mockNetwork);
 
         // Verify onServiceUpdated was third called for the third response.
         inOrder.verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -894,7 +922,7 @@
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
                 INTERFACE_INDEX,
-                NETWORK);
+                mockNetwork);
 
         // Verify onServiceRemoved was called for the last response.
         inOrder.verify(mockListenerOne).onServiceRemoved(serviceInfoCaptor.capture());
@@ -907,7 +935,7 @@
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
                 INTERFACE_INDEX,
-                NETWORK);
+                mockNetwork);
 
         // Verify onServiceNameRemoved was called for the last response.
         inOrder.verify(mockListenerOne).onServiceNameRemoved(serviceInfoCaptor.capture());
@@ -920,18 +948,34 @@
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
                 INTERFACE_INDEX,
-                NETWORK);
+                mockNetwork);
     }
 
     // verifies that the right query was enqueued with the right delay, and send query by executing
     // the runnable.
     private void verifyAndSendQuery(int index, long timeInMs, boolean expectsUnicastResponse) {
+        verifyAndSendQuery(
+                index, timeInMs, expectsUnicastResponse, true /* multipleSocketDiscovery */);
+    }
+
+    private void verifyAndSendQuery(int index, long timeInMs, boolean expectsUnicastResponse,
+            boolean multipleSocketDiscovery) {
         assertEquals(currentThreadExecutor.getAndClearLastScheduledDelayInMs(), timeInMs);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         if (expectsUnicastResponse) {
-            verify(mockSocketClient).sendUnicastPacket(expectedPackets[index]);
+            verify(mockSocketClient).sendUnicastPacket(
+                    expectedIPv4Packets[index], null /* network */);
+            if (multipleSocketDiscovery) {
+                verify(mockSocketClient).sendUnicastPacket(
+                        expectedIPv6Packets[index], null /* network */);
+            }
         } else {
-            verify(mockSocketClient).sendMulticastPacket(expectedPackets[index]);
+            verify(mockSocketClient).sendMulticastPacket(
+                    expectedIPv4Packets[index], null /* network */);
+            if (multipleSocketDiscovery) {
+                verify(mockSocketClient).sendMulticastPacket(
+                        expectedIPv6Packets[index], null /* network */);
+            }
         }
     }
 
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
index ef73030..07bbbb5 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
@@ -96,7 +96,7 @@
                 .getNetworkInterfaceByName(LOCAL_ONLY_IFACE_NAME);
         doReturn(mTetheredIfaceWrapper).when(mDeps).getNetworkInterfaceByName(TETHERED_IFACE_NAME);
         doReturn(mock(MdnsInterfaceSocket.class))
-                .when(mDeps).createMdnsInterfaceSocket(any(), anyInt());
+                .when(mDeps).createMdnsInterfaceSocket(any(), anyInt(), any(), any());
         final HandlerThread thread = new HandlerThread("MdnsSocketProviderTest");
         thread.start();
         mHandler = new Handler(thread.getLooper());
@@ -165,7 +165,7 @@
         }
 
         public void expectedSocketCreatedForNetwork(Network network, List<LinkAddress> addresses) {
-            final SocketEvent event = mHistory.poll(DEFAULT_TIMEOUT, c -> true);
+            final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
             assertNotNull(event);
             assertTrue(event instanceof SocketCreatedEvent);
             assertEquals(network, event.mNetwork);
@@ -173,7 +173,7 @@
         }
 
         public void expectedInterfaceDestroyedForNetwork(Network network) {
-            final SocketEvent event = mHistory.poll(DEFAULT_TIMEOUT, c -> true);
+            final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
             assertNotNull(event);
             assertTrue(event instanceof InterfaceDestroyedEvent);
             assertEquals(network, event.mNetwork);
@@ -181,7 +181,7 @@
 
         public void expectedAddressesChangedForNetwork(Network network,
                 List<LinkAddress> addresses) {
-            final SocketEvent event = mHistory.poll(DEFAULT_TIMEOUT, c -> true);
+            final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
             assertNotNull(event);
             assertTrue(event instanceof AddressesChangedEvent);
             assertEquals(network, event.mNetwork);
@@ -260,7 +260,8 @@
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         testCallback1.expectedNoCallback();
         testCallback2.expectedNoCallback();
-        testCallback3.expectedNoCallback();
+        // Expect the socket destroy for tethered interface.
+        testCallback3.expectedInterfaceDestroyedForNetwork(LOCAL_NETWORK);
     }
 
     @Test