Merge "Add MdnsMultinetworkSocketClient" am: 6ea0c1d05b

Original change: https://android-review.googlesource.com/c/platform/packages/modules/Connectivity/+/2278027

Change-Id: Icbeb23ac95c25a6eaa655473f3d706deccc8ca89
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
diff --git a/service/mdns/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java b/service/mdns/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
index f7871f3..fdd1478 100644
--- a/service/mdns/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
+++ b/service/mdns/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
@@ -18,7 +18,9 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.net.Network;
 import android.text.TextUtils;
+import android.util.Log;
 import android.util.Pair;
 
 import com.android.server.connectivity.mdns.util.MdnsLogger;
@@ -58,26 +60,29 @@
         }
     }
 
-    private final WeakReference<MdnsSocketClient> weakRequestSender;
+    private final WeakReference<MdnsSocketClientBase> weakRequestSender;
     private final MdnsPacketWriter packetWriter;
     private final String[] serviceTypeLabels;
     private final List<String> subtypes;
     private final boolean expectUnicastResponse;
     private final int transactionId;
+    private final Network network;
 
     EnqueueMdnsQueryCallable(
-            @NonNull MdnsSocketClient requestSender,
+            @NonNull MdnsSocketClientBase requestSender,
             @NonNull MdnsPacketWriter packetWriter,
             @NonNull String serviceType,
             @NonNull Collection<String> subtypes,
             boolean expectUnicastResponse,
-            int transactionId) {
+            int transactionId,
+            @Nullable Network network) {
         weakRequestSender = new WeakReference<>(requestSender);
         this.packetWriter = packetWriter;
         serviceTypeLabels = TextUtils.split(serviceType, "\\.");
         this.subtypes = new ArrayList<>(subtypes);
         this.expectUnicastResponse = expectUnicastResponse;
         this.transactionId = transactionId;
+        this.network = network;
     }
 
     // Incompatible return type for override of Callable#call().
@@ -86,7 +91,7 @@
     @Nullable
     public Pair<Integer, List<String>> call() {
         try {
-            MdnsSocketClient requestSender = weakRequestSender.get();
+            MdnsSocketClientBase requestSender = weakRequestSender.get();
             if (requestSender == null) {
                 return null;
             }
@@ -127,15 +132,24 @@
                     MdnsConstants.QCLASS_INTERNET
                             | (expectUnicastResponse ? MdnsConstants.QCLASS_UNICAST : 0));
 
-            InetAddress mdnsAddress = MdnsConstants.getMdnsIPv4Address();
-            if (requestSender.isOnIPv6OnlyNetwork()) {
-                mdnsAddress = MdnsConstants.getMdnsIPv6Address();
-            }
+            if (requestSender instanceof MdnsMultinetworkSocketClient) {
+                sendPacketToIpv4AndIpv6(requestSender, MdnsConstants.MDNS_PORT, network);
+                for (Integer emulatorPort : castShellEmulatorMdnsPorts) {
+                    sendPacketToIpv4AndIpv6(requestSender, emulatorPort, network);
+                }
+            } else if (requestSender instanceof MdnsSocketClient) {
+                final MdnsSocketClient client = (MdnsSocketClient) requestSender;
+                InetAddress mdnsAddress = MdnsConstants.getMdnsIPv4Address();
+                if (client.isOnIPv6OnlyNetwork()) {
+                    mdnsAddress = MdnsConstants.getMdnsIPv6Address();
+                }
 
-            sendPacketTo(requestSender,
-                    new InetSocketAddress(mdnsAddress, MdnsConstants.MDNS_PORT));
-            for (Integer emulatorPort : castShellEmulatorMdnsPorts) {
-                sendPacketTo(requestSender, new InetSocketAddress(mdnsAddress, emulatorPort));
+                sendPacketTo(client, new InetSocketAddress(mdnsAddress, MdnsConstants.MDNS_PORT));
+                for (Integer emulatorPort : castShellEmulatorMdnsPorts) {
+                    sendPacketTo(client, new InetSocketAddress(mdnsAddress, emulatorPort));
+                }
+            } else {
+                throw new IOException("Unknown socket client type: " + requestSender.getClass());
             }
             return Pair.create(transactionId, subtypes);
         } catch (IOException e) {
@@ -145,7 +159,7 @@
         }
     }
 
-    private void sendPacketTo(MdnsSocketClient requestSender, InetSocketAddress address)
+    private void sendPacketTo(MdnsSocketClientBase requestSender, InetSocketAddress address)
             throws IOException {
         DatagramPacket packet = packetWriter.getPacket(address);
         if (expectUnicastResponse) {
@@ -154,4 +168,31 @@
             requestSender.sendMulticastPacket(packet);
         }
     }
+
+    private void sendPacketFromNetwork(MdnsSocketClientBase requestSender,
+            InetSocketAddress address, Network network)
+            throws IOException {
+        DatagramPacket packet = packetWriter.getPacket(address);
+        if (expectUnicastResponse) {
+            requestSender.sendUnicastPacket(packet, network);
+        } else {
+            requestSender.sendMulticastPacket(packet, network);
+        }
+    }
+
+    private void sendPacketToIpv4AndIpv6(MdnsSocketClientBase requestSender, int port,
+            Network network) {
+        try {
+            sendPacketFromNetwork(requestSender,
+                    new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), port), network);
+        } catch (IOException e) {
+            Log.i(TAG, "Can't send packet to IPv4", e);
+        }
+        try {
+            sendPacketFromNetwork(requestSender,
+                    new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), port), network);
+        } catch (IOException e) {
+            Log.i(TAG, "Can't send packet to IPv6", e);
+        }
+    }
 }
\ No newline at end of file
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service/mdns/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index 0f3c23a..cc6b98b 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -34,18 +34,18 @@
  * This class keeps tracking the set of registered {@link MdnsServiceBrowserListener} instances, and
  * notify them when a mDNS service instance is found, updated, or removed?
  */
-public class MdnsDiscoveryManager implements MdnsSocketClient.Callback {
+public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback {
     private static final String TAG = MdnsDiscoveryManager.class.getSimpleName();
     public static final boolean DBG = Log.isLoggable(TAG, Log.DEBUG);
     private static final MdnsLogger LOGGER = new MdnsLogger("MdnsDiscoveryManager");
 
     private final ExecutorProvider executorProvider;
-    private final MdnsSocketClient socketClient;
+    private final MdnsSocketClientBase socketClient;
 
     private final Map<String, MdnsServiceTypeClient> serviceTypeClients = new ArrayMap<>();
 
-    public MdnsDiscoveryManager(
-            @NonNull ExecutorProvider executorProvider, @NonNull MdnsSocketClient socketClient) {
+    public MdnsDiscoveryManager(@NonNull ExecutorProvider executorProvider,
+            @NonNull MdnsSocketClientBase socketClient) {
         this.executorProvider = executorProvider;
         this.socketClient = socketClient;
     }
@@ -76,12 +76,16 @@
                 return;
             }
         }
+        // Request the network for discovery.
+        socketClient.notifyNetworkRequested(listener, searchOptions.getNetwork());
+
         // All listeners of the same service types shares the same MdnsServiceTypeClient.
         MdnsServiceTypeClient serviceTypeClient = serviceTypeClients.get(serviceType);
         if (serviceTypeClient == null) {
             serviceTypeClient = createServiceTypeClient(serviceType);
             serviceTypeClients.put(serviceType, serviceTypeClient);
         }
+        // TODO(b/264634275): Wait for a socket to be created before sending packets.
         serviceTypeClient.startSendAndReceive(listener, searchOptions);
     }
 
@@ -96,20 +100,22 @@
     public synchronized void unregisterListener(
             @NonNull String serviceType, @NonNull MdnsServiceBrowserListener listener) {
         LOGGER.log("Unregistering listener for service type: %s", serviceType);
+        if (DBG) Log.d(TAG, "Unregistering listener for serviceType:" + serviceType);
         MdnsServiceTypeClient serviceTypeClient = serviceTypeClients.get(serviceType);
         if (serviceTypeClient == null) {
             return;
         }
         if (serviceTypeClient.stopSendAndReceive(listener)) {
             // No listener is registered for the service type anymore, remove it from the list of
-          // the
-            // service type clients.
+            // the service type clients.
             serviceTypeClients.remove(serviceType);
             if (serviceTypeClients.isEmpty()) {
                 // No discovery request. Stops the socket client.
                 socketClient.stopDiscovery();
             }
         }
+        // Release the network request made for this listener.
+        socketClient.notifyNetworkUnrequested(listener);
     }
 
     @Override
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceSocket.java b/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceSocket.java
index 67c893d..d1290b6 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceSocket.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceSocket.java
@@ -22,10 +22,15 @@
 import android.annotation.NonNull;
 import android.net.LinkAddress;
 import android.net.util.SocketUtils;
+import android.os.Handler;
+import android.os.Looper;
 import android.os.ParcelFileDescriptor;
 import android.system.ErrnoException;
+import android.system.Os;
+import android.system.OsConstants;
 import android.util.Log;
 
+import java.io.FileDescriptor;
 import java.io.IOException;
 import java.net.DatagramPacket;
 import java.net.InetSocketAddress;
@@ -41,15 +46,19 @@
  * otherwise.
  *
  * @see MulticastSocket for javadoc of each public method.
+ * @see MulticastPacketReader for how received packets are delivered to registered handlers.
  */
 public class MdnsInterfaceSocket {
     private static final String TAG = MdnsInterfaceSocket.class.getSimpleName();
     @NonNull private final MulticastSocket mMulticastSocket;
     @NonNull private final NetworkInterface mNetworkInterface;
+    @NonNull private final MulticastPacketReader mPacketReader;
+    @NonNull private final ParcelFileDescriptor mFileDescriptor;
     private boolean mJoinedIpv4 = false;
     private boolean mJoinedIpv6 = false;
 
-    public MdnsInterfaceSocket(@NonNull NetworkInterface networkInterface, int port)
+    public MdnsInterfaceSocket(@NonNull NetworkInterface networkInterface, int port,
+            @NonNull Looper looper, @NonNull byte[] packetReadBuffer)
             throws IOException {
         mNetworkInterface = networkInterface;
         mMulticastSocket = new MulticastSocket(port);
@@ -58,11 +67,19 @@
         mMulticastSocket.setNetworkInterface(networkInterface);
 
         // Bind socket to the interface for receiving from that interface only.
-        try (ParcelFileDescriptor pfd = ParcelFileDescriptor.fromDatagramSocket(mMulticastSocket)) {
-            SocketUtils.bindSocketToInterface(pfd.getFileDescriptor(), mNetworkInterface.getName());
+        mFileDescriptor = ParcelFileDescriptor.fromDatagramSocket(mMulticastSocket);
+        try {
+            final FileDescriptor fd = mFileDescriptor.getFileDescriptor();
+            final int flags = Os.fcntlInt(fd, OsConstants.F_GETFL, 0);
+            Os.fcntlInt(fd, OsConstants.F_SETFL, flags | OsConstants.O_NONBLOCK);
+            SocketUtils.bindSocketToInterface(fd, mNetworkInterface.getName());
         } catch (ErrnoException e) {
             throw new IOException("Error setting socket options", e);
         }
+
+        mPacketReader = new MulticastPacketReader(networkInterface.getName(), mFileDescriptor,
+                new Handler(looper), packetReadBuffer);
+        mPacketReader.start();
     }
 
     /**
@@ -74,23 +91,14 @@
         mMulticastSocket.send(packet);
     }
 
-    /**
-     * Receives a datagram packet from this socket.
-     *
-     * <p>This method could be used on any thread.
-     */
-    public void receive(@NonNull DatagramPacket packet) throws IOException {
-        mMulticastSocket.receive(packet);
-    }
-
-    private boolean hasIpv4Address(List<LinkAddress> addresses) {
+    private static boolean hasIpv4Address(@NonNull List<LinkAddress> addresses) {
         for (LinkAddress address : addresses) {
             if (address.isIpv4()) return true;
         }
         return false;
     }
 
-    private boolean hasIpv6Address(List<LinkAddress> addresses) {
+    private static boolean hasIpv6Address(@NonNull List<LinkAddress> addresses) {
         for (LinkAddress address : addresses) {
             if (address.isIpv6()) return true;
         }
@@ -103,7 +111,7 @@
         maybeJoinIpv6(addresses);
     }
 
-    private boolean joinGroup(InetSocketAddress multicastAddress) {
+    private boolean joinGroup(@NonNull InetSocketAddress multicastAddress) {
         try {
             mMulticastSocket.joinGroup(multicastAddress, mNetworkInterface);
             return true;
@@ -114,7 +122,7 @@
         }
     }
 
-    private void maybeJoinIpv4(List<LinkAddress> addresses) {
+    private void maybeJoinIpv4(@NonNull List<LinkAddress> addresses) {
         final boolean hasAddr = hasIpv4Address(addresses);
         if (!mJoinedIpv4 && hasAddr) {
             mJoinedIpv4 = joinGroup(MULTICAST_IPV4_ADDRESS);
@@ -124,7 +132,7 @@
         }
     }
 
-    private void maybeJoinIpv6(List<LinkAddress> addresses) {
+    private void maybeJoinIpv6(@NonNull List<LinkAddress> addresses) {
         final boolean hasAddr = hasIpv6Address(addresses);
         if (!mJoinedIpv6 && hasAddr) {
             mJoinedIpv6 = joinGroup(MULTICAST_IPV6_ADDRESS);
@@ -134,26 +142,26 @@
         }
     }
 
-    /*** Destroy this socket by leaving all joined multicast groups and closing this socket. */
+    /** Stops the packet reader, then closes the file descriptor and the multicast socket. */
     public void destroy() {
-        if (mJoinedIpv4) {
-            try {
-                mMulticastSocket.leaveGroup(MULTICAST_IPV4_ADDRESS, mNetworkInterface);
-            } catch (IOException e) {
-                Log.e(TAG, "Error leaving IPv4 group for " + mNetworkInterface, e);
-            }
-        }
-        if (mJoinedIpv6) {
-            try {
-                mMulticastSocket.leaveGroup(MULTICAST_IPV6_ADDRESS, mNetworkInterface);
-            } catch (IOException e) {
-                Log.e(TAG, "Error leaving IPv4 group for " + mNetworkInterface, e);
-            }
+        mPacketReader.stop();
+        try {
+            mFileDescriptor.close();
+        } catch (IOException e) {
+            Log.e(TAG, "Close file descriptor failed.", e);
         }
         mMulticastSocket.close();
     }
 
     /**
+     * Add a handler to receive callbacks when a packet is read from the socket. If the handler is
+     * already set, this is a no-op.
+     */
+    public void addPacketHandler(@NonNull MulticastPacketReader.PacketHandler handler) {
+        mPacketReader.addPacketHandler(handler);
+    }
+
+    /**
      * Returns the network interface that this socket is bound to.
      *
      * <p>This method could be used on any thread.
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service/mdns/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
new file mode 100644
index 0000000..d959065
--- /dev/null
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -0,0 +1,219 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+import static com.android.server.connectivity.mdns.MdnsSocketProvider.ensureRunningOnHandlerThread;
+import static com.android.server.connectivity.mdns.MdnsSocketProvider.isNetworkMatched;
+
+import android.annotation.NonNull;
+import android.annotation.Nullable;
+import android.net.LinkAddress;
+import android.net.Network;
+import android.os.Handler;
+import android.os.Looper;
+import android.util.ArrayMap;
+import android.util.Log;
+
+import java.io.IOException;
+import java.net.DatagramPacket;
+import java.net.Inet4Address;
+import java.net.Inet6Address;
+import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * The {@link MdnsMultinetworkSocketClient} manages the multinetwork sockets for mDNS.
+ *
+ * <p>This class is not thread safe.
+ */
+public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
+    private static final String TAG = MdnsMultinetworkSocketClient.class.getSimpleName();
+    private static final boolean DBG = MdnsDiscoveryManager.DBG;
+
+    @NonNull private final Handler mHandler;
+    @NonNull private final MdnsSocketProvider mSocketProvider;
+    @NonNull private final MdnsResponseDecoder mResponseDecoder;
+
+    private final Map<MdnsServiceBrowserListener, InterfaceSocketCallback> mRequestedNetworks =
+            new ArrayMap<>();
+    private final ArrayMap<MdnsInterfaceSocket, Network> mActiveNetworkSockets = new ArrayMap<>();
+    private final ArrayMap<MdnsInterfaceSocket, ReadPacketHandler> mSocketPacketHandlers =
+            new ArrayMap<>();
+    private MdnsSocketClientBase.Callback mCallback = null;
+    private int mReceivedPacketNumber = 0;
+
+    public MdnsMultinetworkSocketClient(@NonNull Looper looper,
+            @NonNull MdnsSocketProvider provider) {
+        mHandler = new Handler(looper);
+        mSocketProvider = provider;
+        mResponseDecoder = new MdnsResponseDecoder(
+                new MdnsResponseDecoder.Clock(), null /* serviceType */);
+    }
+
+    private class InterfaceSocketCallback implements MdnsSocketProvider.SocketCallback {
+        @Override
+        public void onSocketCreated(@NonNull Network network,
+                @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {
+            // The socket may be already created by other request before, try to get the stored
+            // ReadPacketHandler.
+            ReadPacketHandler handler = mSocketPacketHandlers.get(socket);
+            if (handler == null) {
+                // First request to create this socket. Initial a ReadPacketHandler for this socket.
+                handler = new ReadPacketHandler(network, socket.getInterface().getIndex());
+                mSocketPacketHandlers.put(socket, handler);
+            }
+            socket.addPacketHandler(handler);
+            mActiveNetworkSockets.put(socket, network);
+        }
+
+        @Override
+        public void onInterfaceDestroyed(@NonNull Network network,
+                @NonNull MdnsInterfaceSocket socket) {
+            mSocketPacketHandlers.remove(socket);
+            mActiveNetworkSockets.remove(socket);
+        }
+    }
+
+    private class ReadPacketHandler implements MulticastPacketReader.PacketHandler {
+        private final Network mNetwork;
+        private final int mInterfaceIndex;
+
+        ReadPacketHandler(@NonNull Network network, int interfaceIndex) {
+            mNetwork = network;
+            mInterfaceIndex = interfaceIndex;
+        }
+
+        @Override
+        public void handlePacket(byte[] recvbuf, int length, InetSocketAddress src) {
+            processResponsePacket(recvbuf, length, mInterfaceIndex, mNetwork);
+        }
+    }
+
+    /*** Set callback for receiving mDns response */
+    @Override
+    public void setCallback(@Nullable MdnsSocketClientBase.Callback callback) {
+        ensureRunningOnHandlerThread(mHandler);
+        mCallback = callback;
+    }
+
+    /***
+     * Notify that the given network is requested for mdns discovery / resolution
+     *
+     * @param listener the listener for discovery.
+     * @param network the target network for discovery. Null means discovery on all possible
+     *                interfaces.
+     */
+    @Override
+    public void notifyNetworkRequested(@NonNull MdnsServiceBrowserListener listener,
+            @Nullable Network network) {
+        ensureRunningOnHandlerThread(mHandler);
+        InterfaceSocketCallback callback = mRequestedNetworks.get(listener);
+        if (callback != null) {
+            throw new IllegalArgumentException("Can not register duplicated listener");
+        }
+
+        if (DBG) Log.d(TAG, "notifyNetworkRequested: network=" + network);
+        callback = new InterfaceSocketCallback();
+        mRequestedNetworks.put(listener, callback);
+        mSocketProvider.requestSocket(network, callback);
+    }
+
+    /*** Notify that the network is unrequested */
+    @Override
+    public void notifyNetworkUnrequested(@NonNull MdnsServiceBrowserListener listener) {
+        ensureRunningOnHandlerThread(mHandler);
+        final InterfaceSocketCallback callback = mRequestedNetworks.remove(listener);
+        if (callback == null) {
+            Log.e(TAG, "Can not be unrequested with unknown listener=" + listener);
+            return;
+        }
+        mSocketProvider.unrequestSocket(callback);
+    }
+
+    private void sendMdnsPacket(@NonNull DatagramPacket packet, @Nullable Network targetNetwork) {
+        final boolean isIpv6 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
+                instanceof Inet6Address;
+        final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
+                instanceof Inet4Address;
+        for (int i = 0; i < mActiveNetworkSockets.size(); i++) {
+            final MdnsInterfaceSocket socket = mActiveNetworkSockets.keyAt(i);
+            final Network network = mActiveNetworkSockets.valueAt(i);
+            // Check ip capability and network before sending packet
+            if (((isIpv6 && socket.hasJoinedIpv6()) || (isIpv4 && socket.hasJoinedIpv4()))
+                    && isNetworkMatched(targetNetwork, network)) {
+                try {
+                    socket.send(packet);
+                } catch (IOException e) {
+                    Log.e(TAG, "Failed to send a mDNS packet.", e);
+                }
+            }
+        }
+    }
+
+    private void processResponsePacket(byte[] recvbuf, int length, int interfaceIndex,
+            @NonNull Network network) {
+        int packetNumber = ++mReceivedPacketNumber;
+
+        final List<MdnsResponse> responses = new ArrayList<>();
+        final int errorCode = mResponseDecoder.decode(
+                recvbuf, length, responses, interfaceIndex, network);
+        if (errorCode == MdnsResponseDecoder.SUCCESS) {
+            for (MdnsResponse response : responses) {
+                if (mCallback != null) {
+                    mCallback.onResponseReceived(response);
+                }
+            }
+        } else if (errorCode != MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE) {
+            if (mCallback != null) {
+                mCallback.onFailedToParseMdnsResponse(packetNumber, errorCode);
+            }
+        }
+    }
+
+    /** Sends a mDNS request packet that asks for multicast response. */
+    @Override
+    public void sendMulticastPacket(@NonNull DatagramPacket packet) {
+        sendMulticastPacket(packet, null /* network */);
+    }
+
+    /**
+     * Sends a mDNS request packet via given network that asks for multicast response. Null network
+     * means sending packet via all networks.
+     */
+    @Override
+    public void sendMulticastPacket(@NonNull DatagramPacket packet, @Nullable Network network) {
+        mHandler.post(() -> sendMdnsPacket(packet, network));
+    }
+
+    /** Sends a mDNS request packet that asks for unicast response. */
+    @Override
+    public void sendUnicastPacket(@NonNull DatagramPacket packet) {
+        sendUnicastPacket(packet, null /* network */);
+    }
+
+    /**
+     * Sends a mDNS request packet via given network that asks for unicast response. Null network
+     * means sending packet via all networks.
+     */
+    @Override
+    public void sendUnicastPacket(@NonNull DatagramPacket packet, @Nullable Network network) {
+        // TODO: Separate unicast packet.
+        mHandler.post(() -> sendMdnsPacket(packet, network));
+    }
+}
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsPacketReader.java b/service/mdns/com/android/server/connectivity/mdns/MdnsPacketReader.java
index 856a2cd..aa38844 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsPacketReader.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsPacketReader.java
@@ -38,8 +38,13 @@
 
     /** Constructs a reader for the given packet. */
     public MdnsPacketReader(DatagramPacket packet) {
-        buf = packet.getData();
-        count = packet.getLength();
+        this(packet.getData(), packet.getLength());
+    }
+
+    /** Constructs a reader for the given packet. */
+    public MdnsPacketReader(byte[] buffer, int length) {
+        buf = buffer;
+        count = length;
         pos = 0;
         limit = -1;
         labelDictionary = new SparseArray<>(16);
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java b/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
index 7cf84f6..50f2069 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
@@ -101,7 +101,24 @@
      */
     public int decode(@NonNull DatagramPacket packet, @NonNull List<MdnsResponse> responses,
             int interfaceIndex, @Nullable Network network) {
-        MdnsPacketReader reader = new MdnsPacketReader(packet);
+        return decode(packet.getData(), packet.getLength(), responses, interfaceIndex, network);
+    }
+
+    /**
+     * Decodes all mDNS responses for the desired service type from a packet. The class does not
+     * check the responses for completeness; the caller should do that.
+     *
+     * @param recvbuf The received data buffer to read from.
+     * @param length The number of valid bytes in {@code recvbuf}.
+     * @param responses output list to which the decoded responses are added.
+     * @param interfaceIndex the network interface index (or {@link
+     *     MdnsSocket#INTERFACE_INDEX_UNSPECIFIED} if not known) at which the packet was received
+     * @param network the network at which the packet was received, or null if it is unknown.
+     * @return {@link #SUCCESS}, or one of the {@link MdnsResponseErrorCode} values on failure.
+     */
+    public int decode(@NonNull byte[] recvbuf, int length, @NonNull List<MdnsResponse> responses,
+            int interfaceIndex, @Nullable Network network) {
+        MdnsPacketReader reader = new MdnsPacketReader(recvbuf, length);
 
         List<MdnsRecord> records;
         try {
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsResponseErrorCode.java b/service/mdns/com/android/server/connectivity/mdns/MdnsResponseErrorCode.java
index fcf9058..73a7e3a 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsResponseErrorCode.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsResponseErrorCode.java
@@ -35,4 +35,6 @@
     public static final int ERROR_READING_TXT_RDATA = 10;
     public static final int ERROR_SKIPPING_UNKNOWN_RECORD = 11;
     public static final int ERROR_END_OF_FILE = 12;
+    public static final int ERROR_READING_NSEC_RDATA = 13;
+    public static final int ERROR_READING_ANY_RDATA = 14;
 }
\ No newline at end of file
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index 538f376..d26fbdb 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -20,6 +20,7 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.net.Network;
 import android.os.SystemClock;
 import android.text.TextUtils;
 import android.util.ArraySet;
@@ -52,7 +53,7 @@
 
     private final String serviceType;
     private final String[] serviceTypeLabels;
-    private final MdnsSocketClient socketClient;
+    private final MdnsSocketClientBase socketClient;
     private final ScheduledExecutorService executor;
     private final Object lock = new Object();
     private final Set<MdnsServiceBrowserListener> listeners = new ArraySet<>();
@@ -77,11 +78,11 @@
      * Constructor of {@link MdnsServiceTypeClient}.
      *
      * @param socketClient Sends and receives mDNS packet.
-     * @param executor     A {@link ScheduledExecutorService} used to schedule query tasks.
+     * @param executor     A {@link ScheduledExecutorService} used to schedule query tasks.
      */
     public MdnsServiceTypeClient(
             @NonNull String serviceType,
-            @NonNull MdnsSocketClient socketClient,
+            @NonNull MdnsSocketClientBase socketClient,
             @NonNull ScheduledExecutorService executor) {
         this.serviceType = serviceType;
         this.socketClient = socketClient;
@@ -169,7 +170,8 @@
                                     new QueryTaskConfig(
                                             searchOptions.getSubtypes(),
                                             searchOptions.isPassiveMode(),
-                                            ++currentSessionId)));
+                                            ++currentSessionId,
+                                            searchOptions.getNetwork())));
         }
     }
 
@@ -322,9 +324,10 @@
         private int burstCounter;
         private int timeToRunNextTaskInMs;
         private boolean isFirstBurst;
+        @Nullable private final Network network;
 
         QueryTaskConfig(@NonNull Collection<String> subtypes, boolean usePassiveMode,
-                long sessionId) {
+                long sessionId, @Nullable Network network) {
             this.usePassiveMode = usePassiveMode;
             this.subtypes = new ArrayList<>(subtypes);
             this.queriesPerBurst = QUERIES_PER_BURST;
@@ -346,6 +349,7 @@
                 // doubles until it maxes out at TIME_BETWEEN_BURSTS_MS.
                 this.timeBetweenBurstsInMs = INITIAL_TIME_BETWEEN_BURSTS_MS;
             }
+            this.network = network;
         }
 
         QueryTaskConfig getConfigForNextRun() {
@@ -405,7 +409,8 @@
                                 serviceType,
                                 config.subtypes,
                                 config.expectUnicastResponse,
-                                config.transactionId)
+                                config.transactionId,
+                                config.network)
                                 .call();
             } catch (RuntimeException e) {
                 LOGGER.e(String.format("Failed to run EnqueueMdnsQueryCallable for subtype: %s",
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java b/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java
index 64c4495..5fd1354 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java
@@ -40,9 +40,9 @@
     private static final MdnsLogger LOGGER = new MdnsLogger("MdnsSocket");
 
     static final int INTERFACE_INDEX_UNSPECIFIED = -1;
-    protected static final InetSocketAddress MULTICAST_IPV4_ADDRESS =
+    public static final InetSocketAddress MULTICAST_IPV4_ADDRESS =
             new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
-    protected static final InetSocketAddress MULTICAST_IPV6_ADDRESS =
+    public static final InetSocketAddress MULTICAST_IPV6_ADDRESS =
             new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
     private final MulticastNetworkInterfaceProvider multicastNetworkInterfaceProvider;
     private final MulticastSocket multicastSocket;
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java
index 6a321d1..907687e 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java
@@ -16,6 +16,8 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.server.connectivity.mdns.MdnsSocketClientBase.Callback;
+
 import android.Manifest.permission;
 import android.annotation.NonNull;
 import android.annotation.Nullable;
@@ -47,7 +49,7 @@
  *
  * <p>See https://tools.ietf.org/html/rfc6763 (namely sections 4 and 5).
  */
-public class MdnsSocketClient {
+public class MdnsSocketClient implements MdnsSocketClientBase {
 
     private static final String TAG = "MdnsClient";
     // TODO: The following values are copied from cast module. We need to think about the
@@ -116,11 +118,13 @@
         }
     }
 
+    @Override
     public synchronized void setCallback(@Nullable Callback callback) {
         this.callback = callback;
     }
 
     @RequiresPermission(permission.CHANGE_WIFI_MULTICAST_STATE)
+    @Override
     public synchronized void startDiscovery() throws IOException {
         if (multicastSocket != null) {
             LOGGER.w("Discovery is already in progress.");
@@ -160,6 +164,7 @@
     }
 
     @RequiresPermission(permission.CHANGE_WIFI_MULTICAST_STATE)
+    @Override
     public void stopDiscovery() {
         LOGGER.log("Stop discovery.");
         if (multicastSocket == null && unicastSocket == null) {
@@ -195,11 +200,13 @@
     }
 
     /** Sends a mDNS request packet that asks for multicast response. */
+    @Override
     public void sendMulticastPacket(@NonNull DatagramPacket packet) {
         sendMdnsPacket(packet, multicastPacketQueue);
     }
 
     /** Sends a mDNS request packet that asks for unicast response. */
+    @Override
     public void sendUnicastPacket(DatagramPacket packet) {
         if (useSeparateSocketForUnicast) {
             sendMdnsPacket(packet, unicastPacketQueue);
@@ -512,11 +519,4 @@
     public boolean isOnIPv6OnlyNetwork() {
         return multicastSocket != null && multicastSocket.isOnIPv6OnlyNetwork();
     }
-
-    /** Callback for {@link MdnsSocketClient}. */
-    public interface Callback {
-        void onResponseReceived(@NonNull MdnsResponse response);
-
-        void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode);
-    }
 }
\ No newline at end of file
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClientBase.java b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
new file mode 100644
index 0000000..23504a0
--- /dev/null
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+import android.annotation.NonNull;
+import android.annotation.Nullable;
+import android.net.Network;
+
+import java.io.IOException;
+import java.net.DatagramPacket;
+
+/**
+ * Common interface for mDNS socket clients that send queries and deliver responses via callback.
+ *
+ * @hide
+ */
+public interface MdnsSocketClientBase {
+    /*** Start mDNS discovery. The default implementation does nothing. */
+    default void startDiscovery() throws IOException { }
+
+    /*** Stop mDNS discovery. The default implementation does nothing. */
+    default void stopDiscovery() { }
+
+    /*** Set the callback for receiving mDNS responses. The callback may be null. */
+    void setCallback(@Nullable Callback callback);
+
+    /*** Send an mDNS request packet that asks for multicast responses. */
+    void sendMulticastPacket(@NonNull DatagramPacket packet);
+
+    /**
+     * Send an mDNS request packet that asks for multicast responses via the given network, or via
+     * all networks if null. Default: throws {@link UnsupportedOperationException}.
+     */
+    default void sendMulticastPacket(@NonNull DatagramPacket packet, @Nullable Network network) {
+        throw new UnsupportedOperationException(
+                "This socket client doesn't support per-network sending");
+    }
+
+    /*** Send an mDNS request packet that asks for unicast responses. */
+    void sendUnicastPacket(@NonNull DatagramPacket packet);
+
+    /**
+     * Send an mDNS request packet that asks for unicast responses via the given network, or via
+     * all networks if null. Default: throws {@link UnsupportedOperationException}.
+     */
+    default void sendUnicastPacket(@NonNull DatagramPacket packet, @Nullable Network network) {
+        throw new UnsupportedOperationException(
+                "This socket client doesn't support per-network sending");
+    }
+
+    /*** Notify that the listener requests mDNS discovery / resolution on the given network. */
+    default void notifyNetworkRequested(@NonNull MdnsServiceBrowserListener listener,
+            @Nullable Network network) { }
+
+    /*** Notify that the listener no longer requests its network. Default: no-op. */
+    default void notifyNetworkUnrequested(@NonNull MdnsServiceBrowserListener listener) { }
+
+    /*** Callback for mDNS responses received by a socket client. */
+    interface Callback {
+        /*** Called when an mDNS response is received. */
+        void onResponseReceived(@NonNull MdnsResponse response);
+
+        /*** Called when a received mDNS packet could not be parsed. */
+        void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode);
+    }
+}
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsSocketProvider.java b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketProvider.java
index d3bf060..9298852 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsSocketProvider.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketProvider.java
@@ -35,6 +35,7 @@
 import android.util.Log;
 
 import com.android.internal.annotations.VisibleForTesting;
+import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.LinkPropertiesUtils.CompareResult;
 import com.android.net.module.util.ip.NetlinkMonitor;
 import com.android.net.module.util.netlink.NetlinkConstants;
@@ -42,7 +43,6 @@
 import com.android.server.connectivity.mdns.util.MdnsLogger;
 
 import java.io.IOException;
-import java.net.InterfaceAddress;
 import java.net.NetworkInterface;
 import java.net.SocketException;
 import java.util.ArrayList;
@@ -60,8 +60,13 @@
 public class MdnsSocketProvider {
     private static final String TAG = MdnsSocketProvider.class.getSimpleName();
     private static final boolean DBG = MdnsDiscoveryManager.DBG;
+    // This buffer size matches what MdnsSocketClient uses currently.
+    // But 1440 bytes should generally be enough given the standard Ethernet MTU.
+    // Note: mdnsresponder mDNSEmbeddedAPI.h uses 8940 for Ethernet jumbo frames.
+    private static final int READ_BUFFER_SIZE = 2048;
     private static final MdnsLogger LOGGER = new MdnsLogger(TAG);
     @NonNull private final Context mContext;
+    @NonNull private final Looper mLooper;
     @NonNull private final Handler mHandler;
     @NonNull private final Dependencies mDependencies;
     @NonNull private final NetworkCallback mNetworkCallback;
@@ -75,6 +80,7 @@
             new ArrayMap<>();
     private final List<String> mLocalOnlyInterfaces = new ArrayList<>();
     private final List<String> mTetheredInterfaces = new ArrayList<>();
+    private final byte[] mPacketReadBuffer = new byte[READ_BUFFER_SIZE];
     private boolean mMonitoringSockets = false;
 
     public MdnsSocketProvider(@NonNull Context context, @NonNull Looper looper) {
@@ -84,6 +90,7 @@
     MdnsSocketProvider(@NonNull Context context, @NonNull Looper looper,
             @NonNull Dependencies deps) {
         mContext = context;
+        mLooper = looper;
         mHandler = new Handler(looper);
         mDependencies = deps;
         mNetworkCallback = new NetworkCallback() {
@@ -119,32 +126,33 @@
     @VisibleForTesting
     public static class Dependencies {
         /*** Get network interface by given interface name */
-        public NetworkInterfaceWrapper getNetworkInterfaceByName(String interfaceName)
+        public NetworkInterfaceWrapper getNetworkInterfaceByName(@NonNull String interfaceName)
                 throws SocketException {
             final NetworkInterface ni = NetworkInterface.getByName(interfaceName);
             return ni == null ? null : new NetworkInterfaceWrapper(ni);
         }
 
         /*** Check whether given network interface can support mdns */
-        public boolean canScanOnInterface(NetworkInterfaceWrapper networkInterface) {
+        public boolean canScanOnInterface(@NonNull NetworkInterfaceWrapper networkInterface) {
             return MulticastNetworkInterfaceProvider.canScanOnInterface(networkInterface);
         }
 
         /*** Create a MdnsInterfaceSocket */
-        public MdnsInterfaceSocket createMdnsInterfaceSocket(NetworkInterface networkInterface,
-                int port) throws IOException {
-            return new MdnsInterfaceSocket(networkInterface, port);
+        public MdnsInterfaceSocket createMdnsInterfaceSocket(
+                @NonNull NetworkInterface networkInterface, int port, @NonNull Looper looper,
+                @NonNull byte[] packetReadBuffer) throws IOException {
+            return new MdnsInterfaceSocket(networkInterface, port, looper, packetReadBuffer);
         }
     }
 
     /*** Data class for storing socket related info  */
     private static class SocketInfo {
         final MdnsInterfaceSocket mSocket;
-        final List<LinkAddress> mAddresses = new ArrayList<>();
+        final List<LinkAddress> mAddresses;
 
         SocketInfo(MdnsInterfaceSocket socket, List<LinkAddress> addresses) {
             mSocket = socket;
-            mAddresses.addAll(addresses);
+            mAddresses = new ArrayList<>(addresses);
         }
     }
 
@@ -160,8 +168,9 @@
         }
     }
 
-    private void ensureRunningOnHandlerThread() {
-        if (mHandler.getLooper().getThread() != Thread.currentThread()) {
+    /*** Throw IllegalStateException unless called on the given handler's looper thread. */
+    public static void ensureRunningOnHandlerThread(Handler handler) {
+        if (handler.getLooper().getThread() != Thread.currentThread()) {
             throw new IllegalStateException(
                     "Not running on Handler thread: " + Thread.currentThread().getName());
         }
@@ -169,7 +178,7 @@
 
     /*** Start monitoring sockets by listening callbacks for sockets creation or removal */
     public void startMonitoringSockets() {
-        ensureRunningOnHandlerThread();
+        ensureRunningOnHandlerThread(mHandler);
         if (mMonitoringSockets) {
             Log.d(TAG, "Already monitoring sockets.");
             return;
@@ -188,7 +197,7 @@
 
     /*** Stop monitoring sockets and unregister callbacks */
     public void stopMonitoringSockets() {
-        ensureRunningOnHandlerThread();
+        ensureRunningOnHandlerThread(mHandler);
         if (!mMonitoringSockets) {
             Log.d(TAG, "Monitoring sockets hasn't been started.");
             return;
@@ -204,19 +213,15 @@
         mMonitoringSockets = false;
     }
 
-    private static boolean isNetworkMatched(@Nullable Network targetNetwork,
+    /*** Check whether the target network matches the current one; a null target matches any. */
+    public static boolean isNetworkMatched(@Nullable Network targetNetwork,
             @NonNull Network currentNetwork) {
         return targetNetwork == null || targetNetwork.equals(currentNetwork);
     }
 
     private boolean matchRequestedNetwork(Network network) {
-        for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
-            final Network requestedNetwork =  mCallbacksToRequestedNetworks.valueAt(i);
-            if (isNetworkMatched(requestedNetwork, network)) {
-                return true;
-            }
-        }
-        return false;
+        return hasAllNetworksRequest()
+                || mCallbacksToRequestedNetworks.containsValue(network);
     }
 
     private boolean hasAllNetworksRequest() {
@@ -279,15 +284,6 @@
         current.addAll(updated);
     }
 
-    private static List<LinkAddress> getLinkAddressFromNetworkInterface(
-            NetworkInterfaceWrapper networkInterface) {
-        List<LinkAddress> addresses = new ArrayList<>();
-        for (InterfaceAddress address : networkInterface.getInterfaceAddresses()) {
-            addresses.add(new LinkAddress(address));
-        }
-        return addresses;
-    }
-
     private void createSocket(Network network, LinkProperties lp) {
         final String interfaceName = lp.getInterfaceName();
         if (interfaceName == null) {
@@ -307,10 +303,12 @@
                         + " with interfaceName:" + interfaceName);
             }
             final MdnsInterfaceSocket socket = mDependencies.createMdnsInterfaceSocket(
-                    networkInterface.getNetworkInterface(), MdnsConstants.MDNS_PORT);
+                    networkInterface.getNetworkInterface(), MdnsConstants.MDNS_PORT, mLooper,
+                    mPacketReadBuffer);
             final List<LinkAddress> addresses;
             if (network.netId == INetd.LOCAL_NET_ID) {
-                addresses = getLinkAddressFromNetworkInterface(networkInterface);
+                addresses = CollectionUtils.map(
+                        networkInterface.getInterfaceAddresses(), LinkAddress::new);
                 mTetherInterfaceSockets.put(interfaceName, new SocketInfo(socket, addresses));
             } else {
                 addresses = lp.getLinkAddresses();
@@ -402,7 +400,7 @@
      * @param cb the callback to listen the socket creation.
      */
     public void requestSocket(@Nullable Network network, @NonNull SocketCallback cb) {
-        ensureRunningOnHandlerThread();
+        ensureRunningOnHandlerThread(mHandler);
         mCallbacksToRequestedNetworks.put(cb, network);
         if (network == null) {
             // Does not specify a required network, create sockets for all possible
@@ -425,7 +423,7 @@
 
     /*** Unrequest the socket */
     public void unrequestSocket(@NonNull SocketCallback cb) {
-        ensureRunningOnHandlerThread();
+        ensureRunningOnHandlerThread(mHandler);
         mCallbacksToRequestedNetworks.remove(cb);
         if (hasAllNetworksRequest()) {
             // Still has a request for all networks (interfaces).
@@ -434,16 +432,24 @@
 
         // Check if remaining requests are matched any of sockets.
         for (int i = mNetworkSockets.size() - 1; i >= 0; i--) {
-            if (matchRequestedNetwork(mNetworkSockets.keyAt(i))) continue;
-            mNetworkSockets.removeAt(i).mSocket.destroy();
+            final Network network = mNetworkSockets.keyAt(i);
+            if (matchRequestedNetwork(network)) continue;
+            final SocketInfo info = mNetworkSockets.removeAt(i);
+            info.mSocket.destroy();
+            // Also notify the unrequesting callback that this socket was destroyed.
+            cb.onInterfaceDestroyed(network, info.mSocket);
         }
 
         // Remove all sockets for tethering interface because these sockets do not have associated
         // networks, and they should invoke by a request for all networks (interfaces). If there is
         // no such request, the sockets for tethering interface should be removed.
         for (int i = mTetherInterfaceSockets.size() - 1; i >= 0; i--) {
-            mTetherInterfaceSockets.removeAt(i).mSocket.destroy();
+            final SocketInfo info = mTetherInterfaceSockets.valueAt(i);
+            info.mSocket.destroy();
+            // Also notify the unrequesting callback that this socket was destroyed.
+            cb.onInterfaceDestroyed(new Network(INetd.LOCAL_NET_ID), info.mSocket);
         }
+        mTetherInterfaceSockets.clear();
     }
 
     /*** Callbacks for listening socket changes */
diff --git a/service/mdns/com/android/server/connectivity/mdns/MulticastPacketReader.java b/service/mdns/com/android/server/connectivity/mdns/MulticastPacketReader.java
new file mode 100644
index 0000000..20cc47f
--- /dev/null
+++ b/service/mdns/com/android/server/connectivity/mdns/MulticastPacketReader.java
@@ -0,0 +1,111 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+import static com.android.server.connectivity.mdns.MdnsSocketProvider.ensureRunningOnHandlerThread;
+
+import android.annotation.NonNull;
+import android.os.Handler;
+import android.os.ParcelFileDescriptor;
+import android.system.Os;
+import android.util.ArraySet;
+
+import com.android.net.module.util.FdEventsReader;
+
+import java.io.FileDescriptor;
+import java.net.InetSocketAddress;
+import java.util.Set;
+
+/** Reads mDNS packets from a socket and dispatches them to registered packet handlers. */
+public class MulticastPacketReader extends FdEventsReader<MulticastPacketReader.RecvBuffer> {
+    @NonNull
+    private final String mLogTag; // NOTE(review): assigned but not read in this class yet.
+    @NonNull
+    private final ParcelFileDescriptor mSocket; // Kept so the PFD finalizer won't close the fd.
+    @NonNull
+    private final Handler mHandler; // addPacketHandler() must be called on this thread.
+    @NonNull // Only mutated on the mHandler thread (see addPacketHandler).
+    private final Set<PacketHandler> mPacketHandlers = new ArraySet<>();
+
+    interface PacketHandler { // recvbuf is a shared buffer: only recvbuf[0, length) is valid.
+        void handlePacket(byte[] recvbuf, int length, InetSocketAddress src);
+    }
+
+    public static final class RecvBuffer { // Packet bytes plus sender address set by recvfrom().
+        final byte[] data;
+        final InetSocketAddress src;
+
+        private RecvBuffer(byte[] data, InetSocketAddress src) {
+            this.data = data;
+            this.src = src;
+        }
+    }
+
+    /**
+     * Create a new {@link MulticastPacketReader}; {@code interfaceTag} is appended to its log tag.
+     * @param socket Socket to read from. This will *not* be closed when the reader terminates.
+     * @param buffer Buffer to read packets into. Will only be used from the handler thread.
+     */
+    protected MulticastPacketReader(@NonNull String interfaceTag,
+            @NonNull ParcelFileDescriptor socket, @NonNull Handler handler,
+            @NonNull byte[] buffer) {
+        super(handler, new RecvBuffer(buffer, new InetSocketAddress())); // src set by recvfrom
+        mLogTag = MulticastPacketReader.class.getSimpleName() + "/" + interfaceTag;
+        mSocket = socket;
+        mHandler = handler;
+    }
+
+    @Override
+    protected int recvBufSize(@NonNull RecvBuffer buffer) {
+        return buffer.data.length;
+    }
+
+    @Override
+    protected FileDescriptor createFd() {
+        // Keep a reference to the PFD as it would close the fd in its finalizer otherwise
+        return mSocket.getFileDescriptor();
+    }
+
+    @Override
+    protected void onStop() {
+        // Do nothing: the socket is not owned by this reader, so the FD must not be closed here.
+    }
+
+    @Override
+    protected int readPacket(@NonNull FileDescriptor fd, @NonNull RecvBuffer buffer)
+            throws Exception {
+        return Os.recvfrom( // Fills buffer.data and records the sender in buffer.src.
+                fd, buffer.data, 0, buffer.data.length, 0 /* flags */, buffer.src);
+    }
+
+    @Override
+    protected void handlePacket(@NonNull RecvBuffer recvbuf, int length) {
+        for (PacketHandler handler : mPacketHandlers) { // Dispatch to every registered handler.
+            handler.handlePacket(recvbuf.data, length, recvbuf.src);
+        }
+    }
+
+    /**
+     * Add a packet handler for received packets. Must be called on the handler thread, otherwise
+     * IllegalStateException is thrown. Re-adding an already registered handler is a no-op.
+     */
+    public void addPacketHandler(@NonNull PacketHandler handler) {
+        ensureRunningOnHandlerThread(mHandler);
+        mPacketHandlers.add(handler);
+    }
+}
+
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
index 3e3c3bf..83e7696 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
@@ -46,7 +46,7 @@
     private static final String SERVICE_TYPE_2 = "_test._tcp.local";
 
     @Mock private ExecutorProvider executorProvider;
-    @Mock private MdnsSocketClient socketClient;
+    @Mock private MdnsSocketClientBase socketClient;
     @Mock private MdnsServiceTypeClient mockServiceTypeClientOne;
     @Mock private MdnsServiceTypeClient mockServiceTypeClientTwo;
 
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
new file mode 100644
index 0000000..9d42a65
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
@@ -0,0 +1,161 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+import static com.android.server.connectivity.mdns.MdnsSocketProvider.SocketCallback;
+import static com.android.server.connectivity.mdns.MulticastPacketReader.PacketHandler;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.timeout;
+import static org.mockito.Mockito.verify;
+
+import android.net.InetAddresses;
+import android.net.Network;
+import android.os.Build;
+import android.os.Handler;
+import android.os.HandlerThread;
+
+import com.android.net.module.util.HexDump;
+import com.android.testutils.DevSdkIgnoreRule;
+import com.android.testutils.DevSdkIgnoreRunner;
+import com.android.testutils.HandlerUtils;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.net.DatagramPacket;
+import java.net.NetworkInterface;
+import java.net.SocketException;
+import java.util.List;
+
+@RunWith(DevSdkIgnoreRunner.class)
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
+public class MdnsMultinetworkSocketClientTest {
+    private static final byte[] BUFFER = new byte[10];
+    private static final long DEFAULT_TIMEOUT = 2000L;
+    @Mock private Network mNetwork;
+    @Mock private MdnsSocketProvider mProvider;
+    @Mock private MdnsInterfaceSocket mSocket;
+    @Mock private MdnsServiceBrowserListener mListener;
+    @Mock private MdnsSocketClientBase.Callback mCallback;
+    private MdnsMultinetworkSocketClient mSocketClient;
+    private Handler mHandler;
+
+    @Before
+    public void setUp() throws SocketException {
+        MockitoAnnotations.initMocks(this);
+        final HandlerThread thread = new HandlerThread("MdnsMultinetworkSocketClientTest");
+        thread.start();
+        mHandler = new Handler(thread.getLooper());
+        mSocketClient = new MdnsMultinetworkSocketClient(thread.getLooper(), mProvider);
+        mHandler.post(() -> mSocketClient.setCallback(mCallback));
+    }
+
+    private SocketCallback expectSocketCallback() { // Returns the callback given to mProvider.
+        final ArgumentCaptor<SocketCallback> callbackCaptor =
+                ArgumentCaptor.forClass(SocketCallback.class);
+        mHandler.post(() -> mSocketClient.notifyNetworkRequested(mListener, mNetwork));
+        verify(mProvider, timeout(DEFAULT_TIMEOUT))
+                .requestSocket(eq(mNetwork), callbackCaptor.capture());
+        return callbackCaptor.getValue();
+    }
+
+    private NetworkInterface createEmptyNetworkInterface() { // No public ctor: use reflection.
+        try {
+            Constructor<NetworkInterface> constructor =
+                    NetworkInterface.class.getDeclaredConstructor();
+            constructor.setAccessible(true);
+            return constructor.newInstance();
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Test
+    public void testSendPacket() throws IOException {
+        final SocketCallback callback = expectSocketCallback();
+        final DatagramPacket ipv4Packet = new DatagramPacket(BUFFER, 0 /* offset */, BUFFER.length,
+                InetAddresses.parseNumericAddress("192.0.2.1"), 0 /* port */);
+        final DatagramPacket ipv6Packet = new DatagramPacket(BUFFER, 0 /* offset */, BUFFER.length,
+                InetAddresses.parseNumericAddress("2001:db8::"), 0 /* port */);
+        doReturn(true).when(mSocket).hasJoinedIpv4();
+        doReturn(true).when(mSocket).hasJoinedIpv6();
+        doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
+        // Notify socket created
+        callback.onSocketCreated(mNetwork, mSocket, List.of());
+
+        // Send packet to IPv4 with target network and verify sending has been called.
+        mSocketClient.sendMulticastPacket(ipv4Packet, mNetwork);
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        verify(mSocket).send(ipv4Packet);
+
+        // Send packet to IPv6 without target network and verify sending has been called.
+        mSocketClient.sendMulticastPacket(ipv6Packet);
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        verify(mSocket).send(ipv6Packet);
+    }
+
+    @Test
+    public void testReceivePacket() {
+        final SocketCallback callback = expectSocketCallback();
+        final byte[] data = HexDump.hexStringToByteArray(
+                // scapy.raw(scapy.dns_compress(
+                //     scapy.DNS(rd=0, qr=1, aa=1, qd = None,
+                //     an =
+                //     scapy.DNSRR(type='PTR', rrname='_testtype._tcp.local',
+                //         rdata='testservice._testtype._tcp.local', rclass='IN', ttl=4500) /
+                //     scapy.DNSRRSRV(rrname='testservice._testtype._tcp.local', rclass=0x8001,
+                //         port=31234, target='Android.local', ttl=120))
+                // )).hex().upper()
+                "000084000000000200000000095F7465737474797065045F746370056C6F63616C00000C0001000011"
+                        + "94000E0B7465737473657276696365C00CC02C00218001000000780010000000007A0207"
+                        + "416E64726F6964C01B");
+
+        doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
+        // Notify socket created
+        callback.onSocketCreated(mNetwork, mSocket, List.of());
+
+        final ArgumentCaptor<PacketHandler> handlerCaptor =
+                ArgumentCaptor.forClass(PacketHandler.class);
+        verify(mSocket).addPacketHandler(handlerCaptor.capture());
+
+        // Send the data and verify the received records.
+        final PacketHandler handler = handlerCaptor.getValue();
+        handler.handlePacket(data, data.length, null /* src */);
+        final ArgumentCaptor<MdnsResponse> responseCaptor =
+                ArgumentCaptor.forClass(MdnsResponse.class);
+        verify(mCallback).onResponseReceived(responseCaptor.capture());
+        final MdnsResponse response = responseCaptor.getValue();
+        assertTrue(response.hasPointerRecords());
+        assertArrayEquals("_testtype._tcp.local".split("\\."),
+                response.getPointerRecords().get(0).getName());
+        assertTrue(response.hasServiceRecord());
+        assertEquals("testservice", response.getServiceRecord().getServiceInstanceName());
+        assertArrayEquals("Android.local".split("\\."), // Compare array contents explicitly.
+                response.getServiceRecord().getServiceHost());
+    }
+}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index 697116c..a45ca68 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -62,7 +62,7 @@
 import java.net.DatagramPacket;
 import java.net.Inet4Address;
 import java.net.Inet6Address;
-import java.net.SocketAddress;
+import java.net.InetSocketAddress;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -80,7 +80,10 @@
     private static final int INTERFACE_INDEX = 999;
     private static final String SERVICE_TYPE = "_googlecast._tcp.local";
     private static final String[] SERVICE_TYPE_LABELS = TextUtils.split(SERVICE_TYPE, "\\.");
-    private static final Network NETWORK = mock(Network.class);
+    private static final InetSocketAddress IPV4_ADDRESS = new InetSocketAddress(
+            MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
+    private static final InetSocketAddress IPV6_ADDRESS = new InetSocketAddress(
+            MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
 
     @Mock
     private MdnsServiceBrowserListener mockListenerOne;
@@ -89,13 +92,16 @@
     @Mock
     private MdnsPacketWriter mockPacketWriter;
     @Mock
-    private MdnsSocketClient mockSocketClient;
+    private MdnsMultinetworkSocketClient mockSocketClient;
+    @Mock
+    private Network mockNetwork;
     @Captor
     private ArgumentCaptor<MdnsServiceInfo> serviceInfoCaptor;
 
     private final byte[] buf = new byte[10];
 
-    private DatagramPacket[] expectedPackets;
+    private DatagramPacket[] expectedIPv4Packets;
+    private DatagramPacket[] expectedIPv6Packets;
     private ScheduledFuture<?>[] expectedSendFutures;
     private FakeExecutor currentThreadExecutor = new FakeExecutor();
 
@@ -106,30 +112,52 @@
     public void setUp() throws IOException {
         MockitoAnnotations.initMocks(this);
 
-        expectedPackets = new DatagramPacket[16];
+        expectedIPv4Packets = new DatagramPacket[16];
+        expectedIPv6Packets = new DatagramPacket[16];
         expectedSendFutures = new ScheduledFuture<?>[16];
 
         for (int i = 0; i < expectedSendFutures.length; ++i) {
-            expectedPackets[i] = new DatagramPacket(buf, 0, 5);
+            expectedIPv4Packets[i] = new DatagramPacket(buf, 0 /* offset */, 5 /* length */,
+                    MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
+            expectedIPv6Packets[i] = new DatagramPacket(buf, 0 /* offset */, 5 /* length */,
+                    MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
             expectedSendFutures[i] = Mockito.mock(ScheduledFuture.class);
         }
-        when(mockPacketWriter.getPacket(any(SocketAddress.class)))
-                .thenReturn(expectedPackets[0])
-                .thenReturn(expectedPackets[1])
-                .thenReturn(expectedPackets[2])
-                .thenReturn(expectedPackets[3])
-                .thenReturn(expectedPackets[4])
-                .thenReturn(expectedPackets[5])
-                .thenReturn(expectedPackets[6])
-                .thenReturn(expectedPackets[7])
-                .thenReturn(expectedPackets[8])
-                .thenReturn(expectedPackets[9])
-                .thenReturn(expectedPackets[10])
-                .thenReturn(expectedPackets[11])
-                .thenReturn(expectedPackets[12])
-                .thenReturn(expectedPackets[13])
-                .thenReturn(expectedPackets[14])
-                .thenReturn(expectedPackets[15]);
+        when(mockPacketWriter.getPacket(IPV4_ADDRESS))
+                .thenReturn(expectedIPv4Packets[0])
+                .thenReturn(expectedIPv4Packets[1])
+                .thenReturn(expectedIPv4Packets[2])
+                .thenReturn(expectedIPv4Packets[3])
+                .thenReturn(expectedIPv4Packets[4])
+                .thenReturn(expectedIPv4Packets[5])
+                .thenReturn(expectedIPv4Packets[6])
+                .thenReturn(expectedIPv4Packets[7])
+                .thenReturn(expectedIPv4Packets[8])
+                .thenReturn(expectedIPv4Packets[9])
+                .thenReturn(expectedIPv4Packets[10])
+                .thenReturn(expectedIPv4Packets[11])
+                .thenReturn(expectedIPv4Packets[12])
+                .thenReturn(expectedIPv4Packets[13])
+                .thenReturn(expectedIPv4Packets[14])
+                .thenReturn(expectedIPv4Packets[15]);
+
+        when(mockPacketWriter.getPacket(IPV6_ADDRESS))
+                .thenReturn(expectedIPv6Packets[0])
+                .thenReturn(expectedIPv6Packets[1])
+                .thenReturn(expectedIPv6Packets[2])
+                .thenReturn(expectedIPv6Packets[3])
+                .thenReturn(expectedIPv6Packets[4])
+                .thenReturn(expectedIPv6Packets[5])
+                .thenReturn(expectedIPv6Packets[6])
+                .thenReturn(expectedIPv6Packets[7])
+                .thenReturn(expectedIPv6Packets[8])
+                .thenReturn(expectedIPv6Packets[9])
+                .thenReturn(expectedIPv6Packets[10])
+                .thenReturn(expectedIPv6Packets[11])
+                .thenReturn(expectedIPv6Packets[12])
+                .thenReturn(expectedIPv6Packets[13])
+                .thenReturn(expectedIPv6Packets[14])
+                .thenReturn(expectedIPv6Packets[15]);
 
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor) {
@@ -282,8 +310,8 @@
         //MdnsConfigsFlagsImpl.alwaysAskForUnicastResponseInEachBurst.override(true);
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build();
-        QueryTaskConfig config =
-                new QueryTaskConfig(searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1);
+        QueryTaskConfig config = new QueryTaskConfig(
+                searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, mockNetwork);
 
         // This is the first query. We will ask for unicast response.
         assertTrue(config.expectUnicastResponse);
@@ -311,8 +339,8 @@
     public void testQueryTaskConfig_askForUnicastInFirstQuery() {
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build();
-        QueryTaskConfig config =
-                new QueryTaskConfig(searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1);
+        QueryTaskConfig config = new QueryTaskConfig(
+                searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, mockNetwork);
 
         // This is the first query. We will ask for unicast response.
         assertTrue(config.expectUnicastResponse);
@@ -409,7 +437,7 @@
         MdnsResponse response = mock(MdnsResponse.class);
         when(response.getServiceInstanceName()).thenReturn("service-instance-1");
         doReturn(INTERFACE_INDEX).when(response).getInterfaceIndex();
-        doReturn(NETWORK).when(response).getNetwork();
+        doReturn(mockNetwork).when(response).getNetwork();
         when(response.isComplete()).thenReturn(false);
 
         client.processResponse(response);
@@ -423,7 +451,7 @@
                 List.of() /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
                 INTERFACE_INDEX,
-                NETWORK);
+                mockNetwork);
 
         verify(mockListenerOne, never()).onServiceFound(any(MdnsServiceInfo.class));
         verify(mockListenerOne, never()).onServiceUpdated(any(MdnsServiceInfo.class));
@@ -443,7 +471,7 @@
                         /* subtype= */ "ABCDE",
                         Collections.emptyMap(),
                         /* interfaceIndex= */ 20,
-                        NETWORK);
+                        mockNetwork);
         client.processResponse(initialResponse);
 
         // Process a second response with a different port and updated text attributes.
@@ -455,7 +483,7 @@
                         /* subtype= */ "ABCDE",
                         Collections.singletonMap("key", "value"),
                         /* interfaceIndex= */ 20,
-                        NETWORK);
+                        mockNetwork);
         client.processResponse(secondResponse);
 
         // Verify onServiceNameDiscovered was called once for the initial response.
@@ -469,7 +497,7 @@
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
                 20 /* interfaceIndex */,
-                NETWORK);
+                mockNetwork);
 
         // Verify onServiceFound was called once for the initial response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -480,7 +508,7 @@
         assertEquals(initialServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertNull(initialServiceInfo.getAttributeByKey("key"));
         assertEquals(initialServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(NETWORK, initialServiceInfo.getNetwork());
+        assertEquals(mockNetwork, initialServiceInfo.getNetwork());
 
         // Verify onServiceUpdated was called once for the second response.
         verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -492,7 +520,7 @@
         assertEquals(updatedServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertEquals(updatedServiceInfo.getAttributeByKey("key"), "value");
         assertEquals(updatedServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(NETWORK, updatedServiceInfo.getNetwork());
+        assertEquals(mockNetwork, updatedServiceInfo.getNetwork());
     }
 
     @Test
@@ -509,7 +537,7 @@
                         /* subtype= */ "ABCDE",
                         Collections.emptyMap(),
                         /* interfaceIndex= */ 20,
-                        NETWORK);
+                        mockNetwork);
         client.processResponse(initialResponse);
 
         // Process a second response with a different port and updated text attributes.
@@ -521,7 +549,7 @@
                         /* subtype= */ "ABCDE",
                         Collections.singletonMap("key", "value"),
                         /* interfaceIndex= */ 20,
-                        NETWORK);
+                        mockNetwork);
         client.processResponse(secondResponse);
 
         System.out.println("secondResponses ip"
@@ -538,7 +566,7 @@
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
                 20 /* interfaceIndex */,
-                NETWORK);
+                mockNetwork);
 
         // Verify onServiceFound was called once for the initial response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -549,7 +577,7 @@
         assertEquals(initialServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertNull(initialServiceInfo.getAttributeByKey("key"));
         assertEquals(initialServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(NETWORK, initialServiceInfo.getNetwork());
+        assertEquals(mockNetwork, initialServiceInfo.getNetwork());
 
         // Verify onServiceUpdated was called once for the second response.
         verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -561,7 +589,7 @@
         assertEquals(updatedServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertEquals(updatedServiceInfo.getAttributeByKey("key"), "value");
         assertEquals(updatedServiceInfo.getInterfaceIndex(), 20);
-        assertEquals(NETWORK, updatedServiceInfo.getNetwork());
+        assertEquals(mockNetwork, updatedServiceInfo.getNetwork());
     }
 
     private void verifyServiceRemovedNoCallback(MdnsServiceBrowserListener listener) {
@@ -599,12 +627,12 @@
                         /* subtype= */ "ABCDE",
                         Collections.emptyMap(),
                         INTERFACE_INDEX,
-                        NETWORK);
+                        mockNetwork);
         client.processResponse(initialResponse);
         MdnsResponse response = mock(MdnsResponse.class);
         doReturn("goodbye-service").when(response).getServiceInstanceName();
         doReturn(INTERFACE_INDEX).when(response).getInterfaceIndex();
-        doReturn(NETWORK).when(response).getNetwork();
+        doReturn(mockNetwork).when(response).getNetwork();
         doReturn(true).when(response).isGoodbye();
         client.processResponse(response);
         // Verify removed callback won't be called if the service is not existed.
@@ -615,9 +643,9 @@
         doReturn(serviceName).when(response).getServiceInstanceName();
         client.processResponse(response);
         verifyServiceRemovedCallback(
-                mockListenerOne, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX, NETWORK);
+                mockListenerOne, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX, mockNetwork);
         verifyServiceRemovedCallback(
-                mockListenerTwo, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX, NETWORK);
+                mockListenerTwo, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX, mockNetwork);
     }
 
     @Test
@@ -631,7 +659,7 @@
                         /* subtype= */ "ABCDE",
                         Collections.emptyMap(),
                         INTERFACE_INDEX,
-                        NETWORK);
+                        mockNetwork);
         client.processResponse(initialResponse);
 
         client.startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
@@ -647,7 +675,7 @@
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
                 INTERFACE_INDEX,
-                NETWORK);
+                mockNetwork);
 
         // Verify onServiceFound was called once for the existing response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -684,7 +712,7 @@
         MdnsResponse initialResponse =
                 createMockResponse(
                         serviceInstanceName, "192.168.1.1", 5353, List.of("ABCDE"),
-                        Map.of(), INTERFACE_INDEX, NETWORK);
+                        Map.of(), INTERFACE_INDEX, mockNetwork);
         client.processResponse(initialResponse);
 
         // Clear the scheduled runnable.
@@ -718,7 +746,7 @@
         MdnsResponse initialResponse =
                 createMockResponse(
                         serviceInstanceName, "192.168.1.1", 5353, List.of("ABCDE"),
-                        Map.of(), INTERFACE_INDEX, NETWORK);
+                        Map.of(), INTERFACE_INDEX, mockNetwork);
         client.processResponse(initialResponse);
 
         // Clear the scheduled runnable.
@@ -737,7 +765,7 @@
 
         // Verify removed callback was called.
         verifyServiceRemovedCallback(mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS,
-                INTERFACE_INDEX, NETWORK);
+                INTERFACE_INDEX, mockNetwork);
     }
 
     @Test
@@ -758,7 +786,7 @@
         MdnsResponse initialResponse =
                 createMockResponse(
                         serviceInstanceName, "192.168.1.1", 5353, List.of("ABCDE"),
-                        Map.of(), INTERFACE_INDEX, NETWORK);
+                        Map.of(), INTERFACE_INDEX, mockNetwork);
         client.processResponse(initialResponse);
 
         // Clear the scheduled runnable.
@@ -792,7 +820,7 @@
         MdnsResponse initialResponse =
                 createMockResponse(
                         serviceInstanceName, "192.168.1.1", 5353, List.of("ABCDE"),
-                        Map.of(), INTERFACE_INDEX, NETWORK);
+                        Map.of(), INTERFACE_INDEX, mockNetwork);
         client.processResponse(initialResponse);
 
         // Clear the scheduled runnable.
@@ -804,7 +832,7 @@
 
         // Verify removed callback was called.
         verifyServiceRemovedCallback(mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS,
-                INTERFACE_INDEX, NETWORK);
+                INTERFACE_INDEX, mockNetwork);
     }
 
     @Test
@@ -824,7 +852,7 @@
                         "ABCDE" /* subtype */,
                         Collections.emptyMap(),
                         INTERFACE_INDEX,
-                        NETWORK);
+                        mockNetwork);
         client.processResponse(initialResponse);
 
         // Process a second response which has ip address to make response become complete.
@@ -836,7 +864,7 @@
                         "ABCDE" /* subtype */,
                         Collections.emptyMap(),
                         INTERFACE_INDEX,
-                        NETWORK);
+                        mockNetwork);
         client.processResponse(secondResponse);
 
         // Process a third response with a different ip address, port and updated text attributes.
@@ -848,7 +876,7 @@
                         "ABCDE" /* subtype */,
                         Collections.singletonMap("key", "value"),
                         INTERFACE_INDEX,
-                        NETWORK);
+                        mockNetwork);
         client.processResponse(thirdResponse);
 
         // Process the last response which is goodbye message.
@@ -868,7 +896,7 @@
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
                 INTERFACE_INDEX,
-                NETWORK);
+                mockNetwork);
 
         // Verify onServiceFound was second called for the second response.
         inOrder.verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -881,7 +909,7 @@
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
                 INTERFACE_INDEX,
-                NETWORK);
+                mockNetwork);
 
         // Verify onServiceUpdated was third called for the third response.
         inOrder.verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -894,7 +922,7 @@
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
                 INTERFACE_INDEX,
-                NETWORK);
+                mockNetwork);
 
         // Verify onServiceRemoved was called for the last response.
         inOrder.verify(mockListenerOne).onServiceRemoved(serviceInfoCaptor.capture());
@@ -907,7 +935,7 @@
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
                 INTERFACE_INDEX,
-                NETWORK);
+                mockNetwork);
 
         // Verify onServiceNameRemoved was called for the last response.
         inOrder.verify(mockListenerOne).onServiceNameRemoved(serviceInfoCaptor.capture());
@@ -920,18 +948,34 @@
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
                 INTERFACE_INDEX,
-                NETWORK);
+                mockNetwork);
     }
 
     // verifies that the right query was enqueued with the right delay, and send query by executing
     // the runnable.
     private void verifyAndSendQuery(int index, long timeInMs, boolean expectsUnicastResponse) {
+        verifyAndSendQuery(
+                index, timeInMs, expectsUnicastResponse, true /* multipleSocketDiscovery */);
+    }
+
+    private void verifyAndSendQuery(int index, long timeInMs, boolean expectsUnicastResponse,
+            boolean multipleSocketDiscovery) {
         assertEquals(currentThreadExecutor.getAndClearLastScheduledDelayInMs(), timeInMs);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         if (expectsUnicastResponse) {
-            verify(mockSocketClient).sendUnicastPacket(expectedPackets[index]);
+            verify(mockSocketClient).sendUnicastPacket(
+                    expectedIPv4Packets[index], null /* network */);
+            if (multipleSocketDiscovery) {
+                verify(mockSocketClient).sendUnicastPacket(
+                        expectedIPv6Packets[index], null /* network */);
+            }
         } else {
-            verify(mockSocketClient).sendMulticastPacket(expectedPackets[index]);
+            verify(mockSocketClient).sendMulticastPacket(
+                    expectedIPv4Packets[index], null /* network */);
+            if (multipleSocketDiscovery) {
+                verify(mockSocketClient).sendMulticastPacket(
+                        expectedIPv6Packets[index], null /* network */);
+            }
         }
     }
 
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
index ef73030..07bbbb5 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
@@ -96,7 +96,7 @@
                 .getNetworkInterfaceByName(LOCAL_ONLY_IFACE_NAME);
         doReturn(mTetheredIfaceWrapper).when(mDeps).getNetworkInterfaceByName(TETHERED_IFACE_NAME);
         doReturn(mock(MdnsInterfaceSocket.class))
-                .when(mDeps).createMdnsInterfaceSocket(any(), anyInt());
+                .when(mDeps).createMdnsInterfaceSocket(any(), anyInt(), any(), any());
         final HandlerThread thread = new HandlerThread("MdnsSocketProviderTest");
         thread.start();
         mHandler = new Handler(thread.getLooper());
@@ -165,7 +165,7 @@
         }
 
         public void expectedSocketCreatedForNetwork(Network network, List<LinkAddress> addresses) {
-            final SocketEvent event = mHistory.poll(DEFAULT_TIMEOUT, c -> true);
+            final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
             assertNotNull(event);
             assertTrue(event instanceof SocketCreatedEvent);
             assertEquals(network, event.mNetwork);
@@ -173,7 +173,7 @@
         }
 
         public void expectedInterfaceDestroyedForNetwork(Network network) {
-            final SocketEvent event = mHistory.poll(DEFAULT_TIMEOUT, c -> true);
+            final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
             assertNotNull(event);
             assertTrue(event instanceof InterfaceDestroyedEvent);
             assertEquals(network, event.mNetwork);
@@ -181,7 +181,7 @@
 
         public void expectedAddressesChangedForNetwork(Network network,
                 List<LinkAddress> addresses) {
-            final SocketEvent event = mHistory.poll(DEFAULT_TIMEOUT, c -> true);
+            final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
             assertNotNull(event);
             assertTrue(event instanceof AddressesChangedEvent);
             assertEquals(network, event.mNetwork);
@@ -260,7 +260,8 @@
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         testCallback1.expectedNoCallback();
         testCallback2.expectedNoCallback();
-        testCallback3.expectedNoCallback();
+        // Expect the socket destroy for tethered interface.
+        testCallback3.expectedInterfaceDestroyedForNetwork(LOCAL_NETWORK);
     }
 
     @Test