Merge "Add ConfigInfrastructure stubs lib to Android.bp"
diff --git a/Tethering/jni/com_android_networkstack_tethering_BpfCoordinator.cpp b/Tethering/jni/com_android_networkstack_tethering_BpfCoordinator.cpp
index 27357f8..c8c86bc 100644
--- a/Tethering/jni/com_android_networkstack_tethering_BpfCoordinator.cpp
+++ b/Tethering/jni/com_android_networkstack_tethering_BpfCoordinator.cpp
@@ -17,7 +17,7 @@
 #include <jni.h>
 #include <nativehelper/JNIHelp.h>
 
-#include "bpf_tethering.h"
+#include "offload.h"
 
 namespace android {
 
diff --git a/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java b/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
index 6a5089d..51c7c9c 100644
--- a/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
+++ b/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
@@ -126,7 +126,7 @@
     private static final String DUMPSYS_RAWMAP_ARG_STATS = "--stats";
     private static final String DUMPSYS_RAWMAP_ARG_UPSTREAM4 = "--upstream4";
 
-    /** The names of all the BPF counters defined in bpf_tethering.h. */
+    /** The names of all the BPF counters defined in offload.h. */
     public static final String[] sBpfCounterNames = getBpfCounterNames();
 
     private static String makeMapPath(String which) {
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java
index 225fed7..53984a8 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java
@@ -2180,7 +2180,7 @@
                 new TetherDevValue(UPSTREAM_IFINDEX));
 
         // dumpCounters
-        // The error code is defined in packages/modules/Connectivity/bpf_progs/bpf_tethering.h.
+        // The error code is defined in packages/modules/Connectivity/bpf_progs/offload.h.
         mBpfErrorMap.insertEntry(
                 new S32(0 /* INVALID_IPV4_VERSION */),
                 new S32(1000 /* count */));
diff --git a/bpf_progs/bpf_shared.h b/bpf_progs/bpf_shared.h
index 7b1106a..cc88680 100644
--- a/bpf_progs/bpf_shared.h
+++ b/bpf_progs/bpf_shared.h
@@ -196,32 +196,4 @@
 // Entry in the configuration map that stores which stats map is currently in use.
 #define CURRENT_STATS_MAP_CONFIGURATION_KEY 1
 
-typedef struct {
-    uint32_t iif;            // The input interface index
-    struct in6_addr pfx96;   // The source /96 nat64 prefix, bottom 32 bits must be 0
-    struct in6_addr local6;  // The full 128-bits of the destination IPv6 address
-} ClatIngress6Key;
-STRUCT_SIZE(ClatIngress6Key, 4 + 2 * 16);  // 36
-
-typedef struct {
-    uint32_t oif;           // The output interface to redirect to (0 means don't redirect)
-    struct in_addr local4;  // The destination IPv4 address
-} ClatIngress6Value;
-STRUCT_SIZE(ClatIngress6Value, 4 + 4);  // 8
-
-typedef struct {
-    uint32_t iif;           // The input interface index
-    struct in_addr local4;  // The source IPv4 address
-} ClatEgress4Key;
-STRUCT_SIZE(ClatEgress4Key, 4 + 4);  // 8
-
-typedef struct {
-    uint32_t oif;            // The output interface to redirect to
-    struct in6_addr local6;  // The full 128-bits of the source IPv6 address
-    struct in6_addr pfx96;   // The destination /96 nat64 prefix, bottom 32 bits must be 0
-    bool oifIsEthernet;      // Whether the output interface requires ethernet header
-    uint8_t pad[3];
-} ClatEgress4Value;
-STRUCT_SIZE(ClatEgress4Value, 4 + 2 * 16 + 1 + 3);  // 40
-
 #undef STRUCT_SIZE
diff --git a/bpf_progs/clatd.c b/bpf_progs/clatd.c
index fc10d09..14cddf6 100644
--- a/bpf_progs/clatd.c
+++ b/bpf_progs/clatd.c
@@ -35,7 +35,7 @@
 
 #include "bpf_helpers.h"
 #include "bpf_net_helpers.h"
-#include "bpf_shared.h"
+#include "clatd.h"
 #include "clat_mark.h"
 
 // IP flags. (from kernel's include/net/ip.h)
diff --git a/bpf_progs/clatd.h b/bpf_progs/clatd.h
new file mode 100644
index 0000000..b5f1cdc
--- /dev/null
+++ b/bpf_progs/clatd.h
@@ -0,0 +1,60 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <linux/in.h>
+#include <linux/in6.h>
+
+#include <stdbool.h>
+#include <stdint.h>
+
+// This header file is shared by eBPF kernel programs (C) and netd (C++) and
+// some of the maps are also accessed directly from Java mainline module code.
+//
+// Hence: explicitly pad all relevant structures and assert that their size
+// is the sum of the sizes of their fields.
+#define STRUCT_SIZE(name, size) _Static_assert(sizeof(name) == (size), "Incorrect struct size.")
+
+typedef struct {
+    uint32_t iif;            // The input interface index
+    struct in6_addr pfx96;   // The source /96 nat64 prefix, bottom 32 bits must be 0
+    struct in6_addr local6;  // The full 128-bits of the destination IPv6 address
+} ClatIngress6Key;
+STRUCT_SIZE(ClatIngress6Key, 4 + 2 * 16);  // 36
+
+typedef struct {
+    uint32_t oif;           // The output interface to redirect to (0 means don't redirect)
+    struct in_addr local4;  // The destination IPv4 address
+} ClatIngress6Value;
+STRUCT_SIZE(ClatIngress6Value, 4 + 4);  // 8
+
+typedef struct {
+    uint32_t iif;           // The input interface index
+    struct in_addr local4;  // The source IPv4 address
+} ClatEgress4Key;
+STRUCT_SIZE(ClatEgress4Key, 4 + 4);  // 8
+
+typedef struct {
+    uint32_t oif;            // The output interface to redirect to
+    struct in6_addr local6;  // The full 128-bits of the source IPv6 address
+    struct in6_addr pfx96;   // The destination /96 nat64 prefix, bottom 32 bits must be 0
+    bool oifIsEthernet;      // Whether the output interface requires ethernet header
+    uint8_t pad[3];
+} ClatEgress4Value;
+STRUCT_SIZE(ClatEgress4Value, 4 + 2 * 16 + 1 + 3);  // 40
+
+#undef STRUCT_SIZE
diff --git a/bpf_progs/offload.c b/bpf_progs/offload.c
index e211d68..a8612df 100644
--- a/bpf_progs/offload.c
+++ b/bpf_progs/offload.c
@@ -48,7 +48,7 @@
 
 #include "bpf_helpers.h"
 #include "bpf_net_helpers.h"
-#include "bpf_tethering.h"
+#include "offload.h"
 
 // From kernel:include/net/ip.h
 #define IP_DF 0x4000  // Flag: "Don't Fragment"
diff --git a/bpf_progs/bpf_tethering.h b/bpf_progs/offload.h
similarity index 100%
rename from bpf_progs/bpf_tethering.h
rename to bpf_progs/offload.h
diff --git a/bpf_progs/test.c b/bpf_progs/test.c
index c11c358..d1f780f 100644
--- a/bpf_progs/test.c
+++ b/bpf_progs/test.c
@@ -46,7 +46,7 @@
 
 #include "bpf_helpers.h"
 #include "bpf_net_helpers.h"
-#include "bpf_tethering.h"
+#include "offload.h"
 
 // Used only by TetheringPrivilegedTests, not by production code.
 DEFINE_BPF_MAP_GRW(tether_downstream6_map, HASH, TetherDownstream6Key, Tether6Value, 16,
diff --git a/service/Android.bp b/service/Android.bp
index 9371b02..98bbbac 100644
--- a/service/Android.bp
+++ b/service/Android.bp
@@ -203,6 +203,7 @@
     libs: [
         "framework-annotations-lib",
         "framework-connectivity-pre-jarjar",
+        "framework-tethering",
         "framework-wifi",
         "service-connectivity-pre-jarjar",
     ],
diff --git a/service/mdns/com/android/server/connectivity/mdns/ConnectivityMonitor.java b/service/mdns/com/android/server/connectivity/mdns/ConnectivityMonitor.java
index 2b99d0a..1623669 100644
--- a/service/mdns/com/android/server/connectivity/mdns/ConnectivityMonitor.java
+++ b/service/mdns/com/android/server/connectivity/mdns/ConnectivityMonitor.java
@@ -16,6 +16,8 @@
 
 package com.android.server.connectivity.mdns;
 
+import android.net.Network;
+
 /** Interface for monitoring connectivity changes. */
 public interface ConnectivityMonitor {
     /**
@@ -29,6 +31,9 @@
 
     void notifyConnectivityChange();
 
+    /** Returns the available network received from connectivity changes, or null if unknown. */
+    Network getAvailableNetwork();
+
     /** Listener interface for receiving connectivity changes. */
     interface Listener {
         void onConnectivityChanged();
diff --git a/service/mdns/com/android/server/connectivity/mdns/ConnectivityMonitorWithConnectivityManager.java b/service/mdns/com/android/server/connectivity/mdns/ConnectivityMonitorWithConnectivityManager.java
index 3563d61..551e3db 100644
--- a/service/mdns/com/android/server/connectivity/mdns/ConnectivityMonitorWithConnectivityManager.java
+++ b/service/mdns/com/android/server/connectivity/mdns/ConnectivityMonitorWithConnectivityManager.java
@@ -16,6 +16,7 @@
 
 package com.android.server.connectivity.mdns;
 
+import android.annotation.Nullable;
 import android.annotation.TargetApi;
 import android.content.Context;
 import android.net.ConnectivityManager;
@@ -37,6 +38,7 @@
     // TODO(b/71901993): Ideally we shouldn't need this flag. However we still don't have clues why
     // the receiver is unregistered twice yet.
     private boolean isCallbackRegistered = false;
+    @Nullable private volatile Network lastAvailableNetwork = null;
 
     @SuppressWarnings({"nullness:assignment", "nullness:method.invocation"})
     @TargetApi(Build.VERSION_CODES.LOLLIPOP)
@@ -50,6 +52,7 @@
                     @Override
                     public void onAvailable(Network network) {
                         LOGGER.log("network available.");
+                        lastAvailableNetwork = network;
                         notifyConnectivityChange();
                     }
 
@@ -103,4 +106,10 @@
         connectivityManager.unregisterNetworkCallback(networkCallback);
         isCallbackRegistered = false;
     }
+
+    @Override
+    @Nullable
+    public Network getAvailableNetwork() {
+        return lastAvailableNetwork;
+    }
 }
\ No newline at end of file
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service/mdns/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index 1faa6ce..0f3c23a 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -21,6 +21,7 @@
 import android.annotation.RequiresPermission;
 import android.text.TextUtils;
 import android.util.ArrayMap;
+import android.util.Log;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.server.connectivity.mdns.util.MdnsLogger;
@@ -34,7 +35,8 @@
  * notify them when a mDNS service instance is found, updated, or removed?
  */
 public class MdnsDiscoveryManager implements MdnsSocketClient.Callback {
-
+    private static final String TAG = MdnsDiscoveryManager.class.getSimpleName();
+    public static final boolean DBG = Log.isLoggable(TAG, Log.DEBUG);
     private static final MdnsLogger LOGGER = new MdnsLogger("MdnsDiscoveryManager");
 
     private final ExecutorProvider executorProvider;
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceSocket.java b/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceSocket.java
new file mode 100644
index 0000000..6090415
--- /dev/null
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsInterfaceSocket.java
@@ -0,0 +1,201 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+import static com.android.server.connectivity.mdns.MdnsSocket.MULTICAST_IPV4_ADDRESS;
+import static com.android.server.connectivity.mdns.MdnsSocket.MULTICAST_IPV6_ADDRESS;
+
+import android.annotation.NonNull;
+import android.net.LinkAddress;
+import android.net.util.SocketUtils;
+import android.os.ParcelFileDescriptor;
+import android.system.ErrnoException;
+import android.util.Log;
+
+import java.io.IOException;
+import java.net.DatagramPacket;
+import java.net.InetSocketAddress;
+import java.net.MulticastSocket;
+import java.net.NetworkInterface;
+import java.util.List;
+
+/**
+ * {@link MdnsInterfaceSocket} provides a similar interface to {@link MulticastSocket} and binds to
+ * a single multicast-capable network interface.
+ *
+ * <p>This isn't thread safe and should be always called on the same thread unless specified
+ * otherwise.
+ *
+ * @see MulticastSocket for javadoc of each public method.
+ */
+public class MdnsInterfaceSocket {
+    private static final String TAG = MdnsInterfaceSocket.class.getSimpleName();
+    @NonNull private final MulticastSocket mMulticastSocket;
+    @NonNull private final NetworkInterface mNetworkInterface;
+    private boolean mJoinedIpv4 = false;
+    private boolean mJoinedIpv6 = false;
+
+    /**
+     * Creates a multicast socket on {@code port} that sends on, and only receives from,
+     * {@code networkInterface}.
+     *
+     * @throws IOException if the socket cannot be created or configured. A socket that was
+     *     created is closed before the exception propagates.
+     */
+    public MdnsInterfaceSocket(@NonNull NetworkInterface networkInterface, int port)
+            throws IOException {
+        mNetworkInterface = networkInterface;
+        mMulticastSocket = new MulticastSocket(port);
+        try {
+            // RFC Spec: https://tools.ietf.org/html/rfc6762. Time to live is set 255
+            mMulticastSocket.setTimeToLive(255);
+            mMulticastSocket.setNetworkInterface(networkInterface);
+            bindToInterface();
+        } catch (IOException | RuntimeException e) {
+            // The object is never handed out on failure, so close the socket here or it leaks.
+            mMulticastSocket.close();
+            throw e;
+        }
+    }
+
+    /** Binds the socket to the interface for receiving from that interface only. */
+    private void bindToInterface() throws IOException {
+        try (ParcelFileDescriptor pfd = ParcelFileDescriptor.fromDatagramSocket(mMulticastSocket)) {
+            SocketUtils.bindSocketToInterface(pfd.getFileDescriptor(), mNetworkInterface.getName());
+        } catch (ErrnoException e) {
+            throw new IOException("Error setting socket options", e);
+        }
+    }
+
+    /**
+     * Sends a datagram packet from this socket.
+     *
+     * <p>This method could be used on any thread.
+     */
+    public void send(@NonNull DatagramPacket packet) throws IOException {
+        mMulticastSocket.send(packet);
+    }
+
+    /**
+     * Receives a datagram packet from this socket.
+     *
+     * <p>This method could be used on any thread.
+     */
+    public void receive(@NonNull DatagramPacket packet) throws IOException {
+        mMulticastSocket.receive(packet);
+    }
+
+    /** Returns whether any of {@code addresses} is an IPv4 address. */
+    private boolean hasIpv4Address(List<LinkAddress> addresses) {
+        for (LinkAddress address : addresses) {
+            if (address.isIpv4()) return true;
+        }
+        return false;
+    }
+
+    /** Returns whether any of {@code addresses} is an IPv6 address. */
+    private boolean hasIpv6Address(List<LinkAddress> addresses) {
+        for (LinkAddress address : addresses) {
+            if (address.isIpv6()) return true;
+        }
+        return false;
+    }
+
+    /**
+     * Joins the IPv4 and/or IPv6 mDNS multicast groups, depending on which address families are
+     * present in {@code addresses}.
+     *
+     * <p>A family whose address has disappeared is marked as not joined, so that its group is
+     * joined again once an address of that family reappears.
+     */
+    public void joinGroup(@NonNull List<LinkAddress> addresses) {
+        maybeJoinIpv4(addresses);
+        maybeJoinIpv6(addresses);
+    }
+
+    /** Joins {@code multicastAddress} on the bound interface; returns whether it succeeded. */
+    private boolean joinGroup(InetSocketAddress multicastAddress) {
+        try {
+            mMulticastSocket.joinGroup(multicastAddress, mNetworkInterface);
+            return true;
+        } catch (IOException e) {
+            // The address may have just been removed
+            Log.e(TAG, "Error joining multicast group for " + mNetworkInterface, e);
+            return false;
+        }
+    }
+
+    private void maybeJoinIpv4(List<LinkAddress> addresses) {
+        final boolean hasAddr = hasIpv4Address(addresses);
+        if (!mJoinedIpv4 && hasAddr) {
+            mJoinedIpv4 = joinGroup(MULTICAST_IPV4_ADDRESS);
+        } else if (!hasAddr) {
+            // Lost IPv4 address
+            mJoinedIpv4 = false;
+        }
+    }
+
+    private void maybeJoinIpv6(List<LinkAddress> addresses) {
+        final boolean hasAddr = hasIpv6Address(addresses);
+        if (!mJoinedIpv6 && hasAddr) {
+            mJoinedIpv6 = joinGroup(MULTICAST_IPV6_ADDRESS);
+        } else if (!hasAddr) {
+            // Lost IPv6 address
+            mJoinedIpv6 = false;
+        }
+    }
+
+    /** Destroys this socket by leaving all joined multicast groups and closing this socket. */
+    public void destroy() {
+        if (mJoinedIpv4) {
+            leaveGroup(MULTICAST_IPV4_ADDRESS, "IPv4");
+        }
+        if (mJoinedIpv6) {
+            leaveGroup(MULTICAST_IPV6_ADDRESS, "IPv6");
+        }
+        mMulticastSocket.close();
+    }
+
+    /** Leaves {@code multicastAddress} on the bound interface, logging instead of throwing. */
+    private void leaveGroup(InetSocketAddress multicastAddress, String family) {
+        try {
+            mMulticastSocket.leaveGroup(multicastAddress, mNetworkInterface);
+        } catch (IOException e) {
+            Log.e(TAG, "Error leaving " + family + " group for " + mNetworkInterface, e);
+        }
+    }
+
+    /**
+     * Returns the index of the network interface that this socket is bound to. If the interface
+     * cannot be determined, returns -1.
+     *
+     * <p>This method could be used on any thread.
+     */
+    public int getInterfaceIndex() {
+        return mNetworkInterface.getIndex();
+    }
+
+    /** Returns whether this socket has joined the IPv4 multicast group. */
+    public boolean hasJoinedIpv4() {
+        return mJoinedIpv4;
+    }
+
+    /** Returns whether this socket has joined the IPv6 multicast group. */
+    public boolean hasJoinedIpv6() {
+        return mJoinedIpv6;
+    }
+}
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsResponse.java b/service/mdns/com/android/server/connectivity/mdns/MdnsResponse.java
index 623168c..3a41978 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsResponse.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsResponse.java
@@ -17,6 +17,7 @@
 package com.android.server.connectivity.mdns;
 
 import android.annotation.Nullable;
+import android.net.Network;
 
 import com.android.internal.annotations.VisibleForTesting;
 
@@ -35,13 +36,16 @@
     private MdnsInetAddressRecord inet4AddressRecord;
     private MdnsInetAddressRecord inet6AddressRecord;
     private long lastUpdateTime;
-    private int interfaceIndex = MdnsSocket.INTERFACE_INDEX_UNSPECIFIED;
+    private final int interfaceIndex;
+    @Nullable private final Network network;
 
     /** Constructs a new, empty response. */
-    public MdnsResponse(long now) {
+    public MdnsResponse(long now, int interfaceIndex, @Nullable Network network) {
         lastUpdateTime = now;
         records = new LinkedList<>();
         pointerRecords = new LinkedList<>();
+        this.interfaceIndex = interfaceIndex;
+        this.network = network;
     }
 
     // This generic typed helper compares records for equality.
@@ -208,21 +212,21 @@
     }
 
     /**
-     * Updates the index of the network interface at which this response was received. Can be set to
-     * {@link MdnsSocket#INTERFACE_INDEX_UNSPECIFIED} if unset.
-     */
-    public synchronized void setInterfaceIndex(int interfaceIndex) {
-        this.interfaceIndex = interfaceIndex;
-    }
-
-    /**
      * Returns the index of the network interface at which this response was received. Can be set to
      * {@link MdnsSocket#INTERFACE_INDEX_UNSPECIFIED} if unset.
      */
-    public synchronized int getInterfaceIndex() {
+    public int getInterfaceIndex() {
         return interfaceIndex;
     }
 
+    /**
+     * Returns the network at which this response was received, or null if the network is unknown.
+     */
+    @Nullable
+    public Network getNetwork() {
+        return network;
+    }
+
     /** Gets the IPv6 address record. */
     public synchronized MdnsInetAddressRecord getInet6AddressRecord() {
         return inet6AddressRecord;
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java b/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
index 6c2bc19..7cf84f6 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
@@ -18,6 +18,7 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.net.Network;
 import android.os.SystemClock;
 
 import com.android.server.connectivity.mdns.util.MdnsLogger;
@@ -95,10 +96,11 @@
      * @param packet The packet to read from.
      * @param interfaceIndex the network interface index (or {@link
      *     MdnsSocket#INTERFACE_INDEX_UNSPECIFIED} if not known) at which the packet was received
+     * @param network the network at which the packet was received, or null if it is unknown.
      * @return A list of mDNS responses, or null if the packet contained no appropriate responses.
      */
     public int decode(@NonNull DatagramPacket packet, @NonNull List<MdnsResponse> responses,
-            int interfaceIndex) {
+            int interfaceIndex, @Nullable Network network) {
         MdnsPacketReader reader = new MdnsPacketReader(packet);
 
         List<MdnsRecord> records;
@@ -253,12 +255,11 @@
                     MdnsResponse response = findResponseWithPointer(responses,
                             pointerRecord.getPointer());
                     if (response == null) {
-                        response = new MdnsResponse(now);
+                        response = new MdnsResponse(now, interfaceIndex, network);
                         responses.add(response);
                     }
-                    // Set interface index earlier because some responses have PTR record only.
-                    // Need to know every response is getting from which interface.
+                    // Responses record their interface index and network at construction, so even
+                    // responses with only a PTR record know where they were received.
-                    response.setInterfaceIndex(interfaceIndex);
                     response.addPointerRecord((MdnsPointerRecord) record);
                 }
             }
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsSearchOptions.java b/service/mdns/com/android/server/connectivity/mdns/MdnsSearchOptions.java
index 195bc8e..583c4a9 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsSearchOptions.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsSearchOptions.java
@@ -17,6 +17,8 @@
 package com.android.server.connectivity.mdns;
 
 import android.annotation.NonNull;
+import android.annotation.Nullable;
+import android.net.Network;
 import android.os.Parcel;
 import android.os.Parcelable;
 import android.text.TextUtils;
@@ -43,7 +45,8 @@
                 @Override
                 public MdnsSearchOptions createFromParcel(Parcel source) {
                     return new MdnsSearchOptions(source.createStringArrayList(),
-                            source.readBoolean(), source.readBoolean());
+                            source.readBoolean(), source.readBoolean(),
+                            source.readParcelable(null));
                 }
 
                 @Override
@@ -56,15 +59,19 @@
 
     private final boolean isPassiveMode;
     private final boolean removeExpiredService;
+    // The target network for searching. Null network means search on all possible interfaces.
+    @Nullable private final Network mNetwork;
 
-    /** Parcelable constructs for a {@link MdnsServiceInfo}. */
-    MdnsSearchOptions(List<String> subtypes, boolean isPassiveMode, boolean removeExpiredService) {
+    /** Parcelable constructs for a {@link MdnsSearchOptions}. */
+    MdnsSearchOptions(List<String> subtypes, boolean isPassiveMode, boolean removeExpiredService,
+            @Nullable Network network) {
         this.subtypes = new ArrayList<>();
         if (subtypes != null) {
             this.subtypes.addAll(subtypes);
         }
         this.isPassiveMode = isPassiveMode;
         this.removeExpiredService = removeExpiredService;
+        mNetwork = network;
     }
 
     /** Returns a {@link Builder} for {@link MdnsSearchOptions}. */
@@ -98,6 +105,16 @@
         return removeExpiredService;
     }
 
+    /**
+     * Returns the network which the mdns query should target on.
+     *
+     * @return the target network or null if search on all possible interfaces.
+     */
+    @Nullable
+    public Network getNetwork() {
+        return mNetwork;
+    }
+
     @Override
     public int describeContents() {
         return 0;
@@ -108,6 +125,7 @@
         out.writeStringList(subtypes);
         out.writeBoolean(isPassiveMode);
         out.writeBoolean(removeExpiredService);
+        out.writeParcelable(mNetwork, 0);
     }
 
     /** A builder to create {@link MdnsSearchOptions}. */
@@ -115,6 +133,7 @@
         private final Set<String> subtypes;
         private boolean isPassiveMode = true;
         private boolean removeExpiredService;
+        private Network mNetwork;
 
         private Builder() {
             subtypes = new ArraySet<>();
@@ -165,10 +184,20 @@
             return this;
         }
 
+        /**
+         * Sets the network the mdns query should target; null means all possible interfaces.
+         *
+         * @param network the target network, or null to search on all possible interfaces.
+         */
+        public Builder setNetwork(@Nullable Network network) {
+            mNetwork = network;
+            return this;
+        }
+
         /** Builds a {@link MdnsSearchOptions} with the arguments supplied to this builder. */
         public MdnsSearchOptions build() {
-            return new MdnsSearchOptions(
-                    new ArrayList<>(subtypes), isPassiveMode, removeExpiredService);
+            return new MdnsSearchOptions(new ArrayList<>(subtypes), isPassiveMode,
+                    removeExpiredService, mNetwork);
         }
     }
 }
\ No newline at end of file
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsServiceInfo.java b/service/mdns/com/android/server/connectivity/mdns/MdnsServiceInfo.java
index 9683bc9..938fc3f 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsServiceInfo.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsServiceInfo.java
@@ -16,8 +16,11 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.server.connectivity.mdns.MdnsSocket.INTERFACE_INDEX_UNSPECIFIED;
+
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.net.Network;
 import android.os.Parcel;
 import android.os.Parcelable;
 import android.text.TextUtils;
@@ -58,7 +61,8 @@
                             source.readString(),
                             source.createStringArrayList(),
                             source.createTypedArrayList(TextEntry.CREATOR),
-                            source.readInt());
+                            source.readInt(),
+                            source.readParcelable(null));
                 }
 
                 @Override
@@ -82,6 +86,8 @@
     private final int interfaceIndex;
 
     private final Map<String, byte[]> attributes;
+    @Nullable
+    private final Network network;
 
     /** Constructs a {@link MdnsServiceInfo} object with default values. */
     public MdnsServiceInfo(
@@ -103,7 +109,8 @@
                 ipv6Address,
                 textStrings,
                 /* textEntries= */ null,
-                /* interfaceIndex= */ -1);
+                /* interfaceIndex= */ INTERFACE_INDEX_UNSPECIFIED,
+                /* network= */ null);
     }
 
     /** Constructs a {@link MdnsServiceInfo} object with default values. */
@@ -127,7 +134,8 @@
                 ipv6Address,
                 textStrings,
                 textEntries,
-                /* interfaceIndex= */ -1);
+                /* interfaceIndex= */ INTERFACE_INDEX_UNSPECIFIED,
+                /* network= */ null);
     }
 
     /**
@@ -146,6 +154,37 @@
             @Nullable List<String> textStrings,
             @Nullable List<TextEntry> textEntries,
             int interfaceIndex) {
+        this(
+                serviceInstanceName,
+                serviceType,
+                subtypes,
+                hostName,
+                port,
+                ipv4Address,
+                ipv6Address,
+                textStrings,
+                textEntries,
+                interfaceIndex,
+                /* network= */ null);
+    }
+
+    /**
+     * Constructs a {@link MdnsServiceInfo} object, including the network it was received on.
+     *
+     * @hide
+     */
+    public MdnsServiceInfo(
+            String serviceInstanceName,
+            String[] serviceType,
+            @Nullable List<String> subtypes,
+            String[] hostName,
+            int port,
+            @Nullable String ipv4Address,
+            @Nullable String ipv6Address,
+            @Nullable List<String> textStrings,
+            @Nullable List<TextEntry> textEntries,
+            int interfaceIndex,
+            @Nullable Network network) {
         this.serviceInstanceName = serviceInstanceName;
         this.serviceType = serviceType;
         this.subtypes = new ArrayList<>();
@@ -180,6 +219,7 @@
         }
         this.attributes = Collections.unmodifiableMap(attributes);
         this.interfaceIndex = interfaceIndex;
+        this.network = network;
     }
 
     private static List<TextEntry> parseTextStrings(List<String> textStrings) {
@@ -244,6 +284,14 @@
     }
 
     /**
+     * Returns the network at which this response was received, or null if the network is unknown.
+     */
+    @Nullable
+    public Network getNetwork() {
+        return network;
+    }
+
+    /**
      * Returns attribute value for {@code key} as a UTF-8 string. It's the caller who must make sure
      * that the value of {@code key} is indeed a UTF-8 string. {@code null} will be returned if no
      * attribute value exists for {@code key}.
@@ -293,6 +341,7 @@
         out.writeStringList(textStrings);
         out.writeTypedList(textEntries);
         out.writeInt(interfaceIndex);
+        out.writeParcelable(network, 0);
     }
 
     @Override
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index dd4ff9b..538f376 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -130,7 +130,8 @@
                 ipv6Address,
                 textStrings,
                 textEntries,
-                response.getInterfaceIndex());
+                response.getInterfaceIndex(),
+                response.getNetwork());
     }
 
     /**
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java b/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java
index 0a9b2fc..64c4495 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java
@@ -17,6 +17,8 @@
 package com.android.server.connectivity.mdns;
 
 import android.annotation.NonNull;
+import android.annotation.Nullable;
+import android.net.Network;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.server.connectivity.mdns.util.MdnsLogger;
@@ -38,9 +40,9 @@
     private static final MdnsLogger LOGGER = new MdnsLogger("MdnsSocket");
 
     static final int INTERFACE_INDEX_UNSPECIFIED = -1;
-    private static final InetSocketAddress MULTICAST_IPV4_ADDRESS =
+    protected static final InetSocketAddress MULTICAST_IPV4_ADDRESS =
             new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
-    private static final InetSocketAddress MULTICAST_IPV6_ADDRESS =
+    protected static final InetSocketAddress MULTICAST_IPV6_ADDRESS =
             new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
     private final MulticastNetworkInterfaceProvider multicastNetworkInterfaceProvider;
     private final MulticastSocket multicastSocket;
@@ -125,6 +127,14 @@
         }
     }
 
+    /**
+     * Returns the network this socket is used on, or null if the network is unknown.
+     */
+    @Nullable
+    public Network getNetwork() {
+        return multicastNetworkInterfaceProvider.getAvailableNetwork();
+    }
+
     public boolean isOnIPv6OnlyNetwork() {
         return isOnIPv6OnlyNetwork;
     }
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java
index 758221a..6a321d1 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java
@@ -21,6 +21,7 @@
 import android.annotation.Nullable;
 import android.annotation.RequiresPermission;
 import android.content.Context;
+import android.net.Network;
 import android.net.wifi.WifiManager.MulticastLock;
 import android.os.SystemClock;
 import android.text.format.DateUtils;
@@ -397,7 +398,8 @@
                             responseType,
                             /* interfaceIndex= */ (socket == null || !propagateInterfaceIndex)
                                     ? MdnsSocket.INTERFACE_INDEX_UNSPECIFIED
-                                    : socket.getInterfaceIndex());
+                                    : socket.getInterfaceIndex(),
+                            /* network= */ socket == null ? null : socket.getNetwork());
                 }
             } catch (IOException e) {
                 if (!shouldStopSocketLoop) {
@@ -408,12 +410,12 @@
         LOGGER.log("Receive thread stopped.");
     }
 
-    private int processResponsePacket(
-            @NonNull DatagramPacket packet, String responseType, int interfaceIndex) {
+    private int processResponsePacket(@NonNull DatagramPacket packet, String responseType,
+            int interfaceIndex, @Nullable Network network) {
         int packetNumber = ++receivedPacketNumber;
 
         List<MdnsResponse> responses = new LinkedList<>();
-        int errorCode = responseDecoder.decode(packet, responses, interfaceIndex);
+        int errorCode = responseDecoder.decode(packet, responses, interfaceIndex, network);
         if (errorCode == MdnsResponseDecoder.SUCCESS) {
             if (responseType.equals(MULTICAST_TYPE)) {
                 receivedMulticastResponse = true;
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsSocketProvider.java b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketProvider.java
new file mode 100644
index 0000000..b8c324e
--- /dev/null
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketProvider.java
@@ -0,0 +1,511 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+import android.annotation.NonNull;
+import android.annotation.Nullable;
+import android.content.Context;
+import android.net.ConnectivityManager;
+import android.net.ConnectivityManager.NetworkCallback;
+import android.net.INetd;
+import android.net.LinkAddress;
+import android.net.LinkProperties;
+import android.net.Network;
+import android.net.NetworkRequest;
+import android.net.TetheringManager;
+import android.net.TetheringManager.TetheringEventCallback;
+import android.os.Handler;
+import android.os.Looper;
+import android.system.OsConstants;
+import android.util.ArrayMap;
+import android.util.Log;
+
+import com.android.internal.annotations.VisibleForTesting;
+import com.android.net.module.util.LinkPropertiesUtils.CompareResult;
+import com.android.net.module.util.ip.NetlinkMonitor;
+import com.android.net.module.util.netlink.NetlinkConstants;
+import com.android.net.module.util.netlink.NetlinkMessage;
+import com.android.server.connectivity.mdns.util.MdnsLogger;
+
+import java.io.IOException;
+import java.net.InterfaceAddress;
+import java.net.NetworkInterface;
+import java.net.SocketException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * The {@link MdnsSocketProvider} manages the mDNS sockets of networks and tethering interfaces.
+ *
+ * <p>This class is not thread safe, it is intended to be used only from the looper thread.
+ * However, the constructor is an exception, as it is called on another thread;
+ * therefore for thread safety all members of this class MUST either be final or initialized
+ * to their default value (0, false or null).
+ *
+ */
+public class MdnsSocketProvider {
+    private static final String TAG = MdnsSocketProvider.class.getSimpleName();
+    private static final boolean DBG = MdnsDiscoveryManager.DBG;
+    private static final MdnsLogger LOGGER = new MdnsLogger(TAG);
+    @NonNull private final Context mContext;
+    @NonNull private final Handler mHandler;
+    @NonNull private final Dependencies mDependencies;
+    @NonNull private final NetworkCallback mNetworkCallback;
+    @NonNull private final TetheringEventCallback mTetheringEventCallback;
+    @NonNull private final NetlinkMonitor mNetlinkMonitor;
+    // Sockets for networks reported by ConnectivityManager, keyed by network.
+    private final ArrayMap<Network, SocketInfo> mNetworkSockets = new ArrayMap<>();
+    // Sockets for tethered and local-only interfaces, keyed by interface name.
+    private final ArrayMap<String, SocketInfo> mTetherInterfaceSockets = new ArrayMap<>();
+    // Latest LinkProperties of every network that has not been reported lost.
+    private final ArrayMap<Network, LinkProperties> mActiveNetworksLinkProperties =
+            new ArrayMap<>();
+    // Network requested by each callback; a null value means all networks (interfaces).
+    private final ArrayMap<SocketCallback, Network> mCallbacksToRequestedNetworks =
+            new ArrayMap<>();
+    // Latest interface lists reported by TetheringManager.
+    private final List<String> mLocalOnlyInterfaces = new ArrayList<>();
+    private final List<String> mTetheredInterfaces = new ArrayList<>();
+    // True between startMonitoringSockets() and stopMonitoringSockets().
+    private boolean mMonitoringSockets = false;
+
+    /**
+     * Creates a provider bound to {@code looper}; all public methods must then be called on
+     * that looper's thread.
+     */
+    public MdnsSocketProvider(@NonNull Context context, @NonNull Looper looper) {
+        this(context, looper, new Dependencies());
+    }
+
+    MdnsSocketProvider(@NonNull Context context, @NonNull Looper looper,
+            @NonNull Dependencies deps) {
+        mContext = context;
+        mHandler = new Handler(looper);
+        mDependencies = deps;
+        mNetworkCallback = new NetworkCallback() {
+            @Override
+            public void onLost(Network network) {
+                mActiveNetworksLinkProperties.remove(network);
+                removeSocket(network, null /* interfaceName */);
+            }
+
+            @Override
+            public void onLinkPropertiesChanged(Network network, LinkProperties lp) {
+                handleLinkPropertiesChanged(network, lp);
+            }
+        };
+        mTetheringEventCallback = new TetheringEventCallback() {
+            @Override
+            public void onLocalOnlyInterfacesChanged(@NonNull List<String> interfaces) {
+                handleTetherInterfacesChanged(mLocalOnlyInterfaces, interfaces);
+            }
+
+            @Override
+            public void onTetheredInterfacesChanged(@NonNull List<String> interfaces) {
+                handleTetherInterfacesChanged(mTetheredInterfaces, interfaces);
+            }
+        };
+
+        mNetlinkMonitor = new SocketNetlinkMonitor(mHandler);
+    }
+
+    /**
+     * Dependencies of MdnsSocketProvider, for injection in tests.
+     */
+    @VisibleForTesting
+    public static class Dependencies {
+        /*** Get network interface by given interface name */
+        public NetworkInterfaceWrapper getNetworkInterfaceByName(String interfaceName)
+                throws SocketException {
+            final NetworkInterface ni = NetworkInterface.getByName(interfaceName);
+            return ni == null ? null : new NetworkInterfaceWrapper(ni);
+        }
+
+        /*** Check whether given network interface can support mdns */
+        public boolean canScanOnInterface(NetworkInterfaceWrapper networkInterface) {
+            return MulticastNetworkInterfaceProvider.canScanOnInterface(networkInterface);
+        }
+
+        /*** Create a MdnsInterfaceSocket */
+        public MdnsInterfaceSocket createMdnsInterfaceSocket(NetworkInterface networkInterface,
+                int port) throws IOException {
+            return new MdnsInterfaceSocket(networkInterface, port);
+        }
+    }
+
+    /*** Data class for storing socket related info  */
+    private static class SocketInfo {
+        final MdnsInterfaceSocket mSocket;
+        final List<LinkAddress> mAddresses = new ArrayList<>();
+
+        SocketInfo(MdnsInterfaceSocket socket, List<LinkAddress> addresses) {
+            mSocket = socket;
+            mAddresses.addAll(addresses);
+        }
+    }
+
+    private static class SocketNetlinkMonitor extends NetlinkMonitor {
+        SocketNetlinkMonitor(Handler handler) {
+            super(handler, LOGGER.mLog, TAG, OsConstants.NETLINK_ROUTE,
+                    NetlinkConstants.RTMGRP_IPV4_IFADDR | NetlinkConstants.RTMGRP_IPV6_IFADDR);
+        }
+
+        @Override
+        public void processNetlinkMessage(NetlinkMessage nlMsg, long whenMs) {
+            // TODO: Handle netlink message.
+        }
+    }
+
+    private void ensureRunningOnHandlerThread() {
+        if (mHandler.getLooper().getThread() != Thread.currentThread()) {
+            throw new IllegalStateException(
+                    "Not running on Handler thread: " + Thread.currentThread().getName());
+        }
+    }
+
+    /*** Start monitoring sockets by listening callbacks for sockets creation or removal */
+    public void startMonitoringSockets() {
+        ensureRunningOnHandlerThread();
+        if (mMonitoringSockets) {
+            Log.d(TAG, "Already monitoring sockets.");
+            return;
+        }
+        if (DBG) Log.d(TAG, "Start monitoring sockets.");
+        mContext.getSystemService(ConnectivityManager.class).registerNetworkCallback(
+                new NetworkRequest.Builder().clearCapabilities().build(),
+                mNetworkCallback, mHandler);
+
+        final TetheringManager tetheringManager = mContext.getSystemService(TetheringManager.class);
+        tetheringManager.registerTetheringEventCallback(mHandler::post, mTetheringEventCallback);
+
+        mHandler.post(mNetlinkMonitor::start);
+        mMonitoringSockets = true;
+    }
+
+    /*** Stop monitoring sockets and unregister callbacks */
+    public void stopMonitoringSockets() {
+        ensureRunningOnHandlerThread();
+        if (!mMonitoringSockets) {
+            Log.d(TAG, "Monitoring sockets hasn't been started.");
+            return;
+        }
+        if (DBG) Log.d(TAG, "Stop monitoring sockets.");
+        mContext.getSystemService(ConnectivityManager.class)
+                .unregisterNetworkCallback(mNetworkCallback);
+
+        final TetheringManager tetheringManager = mContext.getSystemService(TetheringManager.class);
+        tetheringManager.unregisterTetheringEventCallback(mTetheringEventCallback);
+
+        mHandler.post(mNetlinkMonitor::stop);
+        mMonitoringSockets = false;
+    }
+
+    /**
+     * Returns whether a request for {@code targetNetwork} covers {@code currentNetwork}. A null
+     * target stands for a request on all networks and matches everything.
+     */
+    private static boolean isNetworkMatched(@Nullable Network targetNetwork,
+            @NonNull Network currentNetwork) {
+        return targetNetwork == null || targetNetwork.equals(currentNetwork);
+    }
+
+    /** Returns whether any registered callback requested {@code network} or all networks. */
+    private boolean matchRequestedNetwork(Network network) {
+        for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
+            final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
+            if (isNetworkMatched(requestedNetwork, network)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /** Returns whether some callback requested sockets on all networks (a null network). */
+    private boolean hasAllNetworksRequest() {
+        return mCallbacksToRequestedNetworks.containsValue(null);
+    }
+
+    /**
+     * Caches {@code lp} for {@code network} and, if that network is requested, creates its
+     * socket or refreshes the addresses of the existing one.
+     */
+    private void handleLinkPropertiesChanged(Network network, LinkProperties lp) {
+        mActiveNetworksLinkProperties.put(network, lp);
+        if (!matchRequestedNetwork(network)) {
+            if (DBG) {
+                Log.d(TAG, "Ignore LinkProperties change. There is no request for the"
+                        + " Network:" + network);
+            }
+            return;
+        }
+
+        final SocketInfo socketInfo = mNetworkSockets.get(network);
+        if (socketInfo == null) {
+            createSocket(network, lp);
+        } else {
+            // Update the addresses of this socket.
+            final List<LinkAddress> addresses = lp.getLinkAddresses();
+            socketInfo.mAddresses.clear();
+            socketInfo.mAddresses.addAll(addresses);
+            // Try to join the group again.
+            socketInfo.mSocket.joinGroup(addresses);
+
+            notifyAddressesChanged(network, lp);
+        }
+    }
+
+    private static LinkProperties createLPForTetheredInterface(String interfaceName) {
+        final LinkProperties linkProperties = new LinkProperties();
+        linkProperties.setInterfaceName(interfaceName);
+        // TODO: Use NetlinkMonitor to update addresses for tethering interfaces.
+        return linkProperties;
+    }
+
+    /**
+     * Updates {@code current} to the interface list reported by TetheringManager and keeps the
+     * sockets of tethered/local-only interfaces in sync with it.
+     *
+     * <p>The list is always updated so that a later request for all networks sees the current
+     * interfaces instead of a stale snapshot. Sockets of removed interfaces are always
+     * destroyed, but sockets for added interfaces are only created if there is a request for
+     * all networks.
+     */
+    private void handleTetherInterfacesChanged(List<String> current, List<String> updated) {
+        final CompareResult<String> interfaceDiff = new CompareResult<>(current, updated);
+        current.clear();
+        current.addAll(updated);
+
+        for (String name : interfaceDiff.removed) {
+            removeSocket(new Network(INetd.LOCAL_NET_ID), name);
+        }
+
+        if (!hasAllNetworksRequest()) {
+            // Currently, the network for tethering can not be requested, so the sockets for
+            // tethering are only created if there is a request for all networks (interfaces).
+            // Therefore, socket creation can be skipped if there is no such request.
+            if (DBG) {
+                Log.d(TAG, "Ignore tether interfaces change. There is no request for all"
+                        + " networks.");
+            }
+            return;
+        }
+
+        for (String name : interfaceDiff.added) {
+            createSocket(new Network(INetd.LOCAL_NET_ID), createLPForTetheredInterface(name));
+        }
+    }
+
+    private static List<LinkAddress> getLinkAddressFromNetworkInterface(
+            NetworkInterfaceWrapper networkInterface) {
+        List<LinkAddress> addresses = new ArrayList<>();
+        for (InterfaceAddress address : networkInterface.getInterfaceAddresses()) {
+            addresses.add(new LinkAddress(address));
+        }
+        return addresses;
+    }
+
+    /**
+     * Creates a socket on the interface named by {@code lp}, tries to join the IPv4/IPv6 mDNS
+     * groups and notifies the matching callbacks. Nothing is created if the interface is
+     * missing or not suitable for mDNS; I/O errors are logged.
+     */
+    private void createSocket(Network network, LinkProperties lp) {
+        final String interfaceName = lp.getInterfaceName();
+        if (interfaceName == null) {
+            Log.e(TAG, "Can not create socket with null interface name.");
+            return;
+        }
+
+        try {
+            final NetworkInterfaceWrapper networkInterface =
+                    mDependencies.getNetworkInterfaceByName(interfaceName);
+            if (networkInterface == null || !mDependencies.canScanOnInterface(networkInterface)) {
+                return;
+            }
+
+            if (DBG) {
+                Log.d(TAG, "Create a socket on network:" + network
+                        + " with interfaceName:" + interfaceName);
+            }
+            final MdnsInterfaceSocket socket = mDependencies.createMdnsInterfaceSocket(
+                    networkInterface.getNetworkInterface(), MdnsConstants.MDNS_PORT);
+            final List<LinkAddress> addresses;
+            if (network.netId == INetd.LOCAL_NET_ID) {
+                addresses = getLinkAddressFromNetworkInterface(networkInterface);
+                mTetherInterfaceSockets.put(interfaceName, new SocketInfo(socket, addresses));
+            } else {
+                addresses = lp.getLinkAddresses();
+                mNetworkSockets.put(network, new SocketInfo(socket, addresses));
+            }
+            // Try to join IPv4/IPv6 group.
+            socket.joinGroup(addresses);
+
+            // Notify the listeners which need this socket.
+            notifySocketCreated(network, socket, addresses);
+        } catch (IOException e) {
+            Log.e(TAG, "Create a socket failed with interface=" + interfaceName, e);
+        }
+    }
+
+    /**
+     * Destroys the socket of {@code network} and notifies the matching callbacks. Sockets of
+     * the local network used by tethering are looked up by {@code interfaceName} instead. Does
+     * nothing if there is no such socket.
+     */
+    private void removeSocket(Network network, String interfaceName) {
+        final SocketInfo socketInfo = network.netId == INetd.LOCAL_NET_ID
+                ? mTetherInterfaceSockets.remove(interfaceName)
+                : mNetworkSockets.remove(network);
+        if (socketInfo == null) return;
+
+        socketInfo.mSocket.destroy();
+        notifyInterfaceDestroyed(network, socketInfo.mSocket);
+    }
+
+    private void notifySocketCreated(Network network, MdnsInterfaceSocket socket,
+            List<LinkAddress> addresses) {
+        for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
+            final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
+            if (isNetworkMatched(requestedNetwork, network)) {
+                mCallbacksToRequestedNetworks.keyAt(i).onSocketCreated(network, socket, addresses);
+            }
+        }
+    }
+
+    private void notifyInterfaceDestroyed(Network network, MdnsInterfaceSocket socket) {
+        for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
+            final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
+            if (isNetworkMatched(requestedNetwork, network)) {
+                mCallbacksToRequestedNetworks.keyAt(i).onInterfaceDestroyed(network, socket);
+            }
+        }
+    }
+
+    private void notifyAddressesChanged(Network network, LinkProperties lp) {
+        for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
+            final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
+            if (isNetworkMatched(requestedNetwork, network)) {
+                mCallbacksToRequestedNetworks.keyAt(i)
+                        .onAddressesChanged(network, lp.getLinkAddresses());
+            }
+        }
+    }
+
+    /**
+     * Sends {@code cb} the existing socket of {@code network}, or creates one (notifying every
+     * matching callback) if the LinkProperties of {@code network} are already known.
+     */
+    private void retrieveAndNotifySocketFromNetwork(Network network, SocketCallback cb) {
+        final SocketInfo socketInfo = mNetworkSockets.get(network);
+        if (socketInfo == null) {
+            final LinkProperties lp = mActiveNetworksLinkProperties.get(network);
+            if (lp == null) {
+                // The requested network does not exist yet; wait for a LinkProperties change.
+                if (DBG) Log.d(TAG, "There is no LinkProperties for this network:" + network);
+                return;
+            }
+            createSocket(network, lp);
+        } else {
+            // Notify the socket for requested network.
+            cb.onSocketCreated(network, socketInfo.mSocket, socketInfo.mAddresses);
+        }
+    }
+
+    /**
+     * Sends {@code cb} the socket of a tethered/local-only interface, creating it if needed.
+     */
+    private void retrieveAndNotifySocketFromInterface(String interfaceName, SocketCallback cb) {
+        final SocketInfo socketInfo = mTetherInterfaceSockets.get(interfaceName);
+        if (socketInfo == null) {
+            createSocket(
+                    new Network(INetd.LOCAL_NET_ID), createLPForTetheredInterface(interfaceName));
+        } else {
+            // Notify the socket for requested network.
+            cb.onSocketCreated(
+                    new Network(INetd.LOCAL_NET_ID), socketInfo.mSocket, socketInfo.mAddresses);
+        }
+    }
+
+    /**
+     * Request a socket for given network.
+     *
+     * @param network the required network for a socket. Null means create sockets on all possible
+     *                networks (interfaces).
+     * @param cb the callback to listen the socket creation.
+     */
+    public void requestSocket(@Nullable Network network, @NonNull SocketCallback cb) {
+        ensureRunningOnHandlerThread();
+        mCallbacksToRequestedNetworks.put(cb, network);
+        if (network == null) {
+            // Does not specify a required network, create sockets for all possible
+            // networks (interfaces).
+            for (int i = 0; i < mActiveNetworksLinkProperties.size(); i++) {
+                retrieveAndNotifySocketFromNetwork(mActiveNetworksLinkProperties.keyAt(i), cb);
+            }
+
+            for (String localInterface : mLocalOnlyInterfaces) {
+                retrieveAndNotifySocketFromInterface(localInterface, cb);
+            }
+
+            for (String tetheredInterface : mTetheredInterfaces) {
+                retrieveAndNotifySocketFromInterface(tetheredInterface, cb);
+            }
+        } else {
+            retrieveAndNotifySocketFromNetwork(network, cb);
+        }
+    }
+
+    /**
+     * Unregisters {@code cb}, then destroys every network socket that no remaining request
+     * matches. Sockets on tethering interfaces are destroyed as well unless a request for all
+     * networks remains.
+     */
+    public void unrequestSocket(@NonNull SocketCallback cb) {
+        ensureRunningOnHandlerThread();
+        mCallbacksToRequestedNetworks.remove(cb);
+        if (hasAllNetworksRequest()) {
+            // Still has a request for all networks (interfaces).
+            return;
+        }
+
+        // Destroy the sockets that no remaining request matches.
+        for (int i = mNetworkSockets.size() - 1; i >= 0; i--) {
+            if (matchRequestedNetwork(mNetworkSockets.keyAt(i))) continue;
+            mNetworkSockets.removeAt(i).mSocket.destroy();
+        }
+
+        // Remove all sockets for tethering interfaces: they have no associated network and are
+        // only created for a request for all networks (interfaces). Since no such request is
+        // left, these sockets are no longer needed.
+        for (int i = mTetherInterfaceSockets.size() - 1; i >= 0; i--) {
+            mTetherInterfaceSockets.removeAt(i).mSocket.destroy();
+        }
+    }
+
+    /*** Callbacks for listening socket changes */
+    public interface SocketCallback {
+        /*** Notify the socket is created */
+        default void onSocketCreated(@NonNull Network network, @NonNull MdnsInterfaceSocket socket,
+                @NonNull List<LinkAddress> addresses) {}
+        /*** Notify the interface is destroyed */
+        default void onInterfaceDestroyed(@NonNull Network network,
+                @NonNull MdnsInterfaceSocket socket) {}
+        /*** Notify the addresses are changed on the network */
+        default void onAddressesChanged(@NonNull Network network,
+                @NonNull List<LinkAddress> addresses) {}
+    }
+}
diff --git a/service/mdns/com/android/server/connectivity/mdns/MulticastNetworkInterfaceProvider.java b/service/mdns/com/android/server/connectivity/mdns/MulticastNetworkInterfaceProvider.java
index e0d8fa6..ade7b95 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MulticastNetworkInterfaceProvider.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MulticastNetworkInterfaceProvider.java
@@ -19,6 +19,7 @@
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.content.Context;
+import android.net.Network;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.server.connectivity.mdns.util.MdnsLogger;
@@ -56,7 +57,7 @@
                 context, this::onConnectivityChanged);
     }
 
-    private void onConnectivityChanged() {
+    private synchronized void onConnectivityChanged() {
         connectivityChanged = true;
     }
 
@@ -141,7 +142,14 @@
         return networkInterfaceWrappers;
     }
 
-    private boolean canScanOnInterface(@Nullable NetworkInterfaceWrapper networkInterface) {
+    /** Returns the network reported by the connectivity monitor, which may be null. */
+    @Nullable
+    public Network getAvailableNetwork() {
+        return connectivityMonitor.getAvailableNetwork();
+    }
+
+    /*** Check whether given network interface can support mdns */
+    public static boolean canScanOnInterface(@Nullable NetworkInterfaceWrapper networkInterface) {
         try {
             if ((networkInterface == null)
                     || networkInterface.isLoopback()
@@ -160,7 +168,7 @@
         return false;
     }
 
-    private boolean hasInet4Address(@NonNull NetworkInterfaceWrapper networkInterface) {
+    private static boolean hasInet4Address(@NonNull NetworkInterfaceWrapper networkInterface) {
         for (InterfaceAddress ifAddr : networkInterface.getInterfaceAddresses()) {
             if (ifAddr.getAddress() instanceof Inet4Address) {
                 return true;
@@ -169,7 +177,7 @@
         return false;
     }
 
-    private boolean hasInet6Address(@NonNull NetworkInterfaceWrapper networkInterface) {
+    private static boolean hasInet6Address(@NonNull NetworkInterfaceWrapper networkInterface) {
         for (InterfaceAddress ifAddr : networkInterface.getInterfaceAddresses()) {
             if (ifAddr.getAddress() instanceof Inet6Address) {
                 return true;
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 4c9e3a3..a44494c 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -738,6 +738,12 @@
     private static final int EVENT_INITIAL_EVALUATION_TIMEOUT = 57;
 
     /**
+     * Used internally when the captive portal app responds with {@code APP_RETURN_UNWANTED}.
+     * obj = Network
+     */
+    private static final int EVENT_USER_DOES_NOT_WANT = 58;
+
+    /**
      * Argument for {@link #EVENT_PROVISIONING_NOTIFICATION} to indicate that the notification
      * should be shown.
      */
@@ -5065,6 +5071,10 @@
         public void appResponse(final int response) {
             if (response == CaptivePortal.APP_RETURN_WANTED_AS_IS) {
                 enforceSettingsPermission();
+            } else if (response == CaptivePortal.APP_RETURN_UNWANTED) {
+                mHandler.sendMessage(mHandler.obtainMessage(EVENT_USER_DOES_NOT_WANT, mNetwork));
+                // Since the network will be disconnected, skip notifying NetworkMonitor
+                return;
             }
 
             final NetworkMonitorManager nm = getNetworkMonitorManager(mNetwork);
@@ -5508,6 +5518,12 @@
                 case EVENT_INGRESS_RATE_LIMIT_CHANGED:
                     handleIngressRateLimitChanged();
                     break;
+                case EVENT_USER_DOES_NOT_WANT:
+                    final NetworkAgentInfo nai = getNetworkAgentInfoForNetwork((Network) msg.obj);
+                    if (nai == null) break;
+                    nai.onPreventAutomaticReconnect();
+                    nai.disconnect();
+                    break;
             }
         }
     }
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 67cc7bd..07d3d95 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -228,6 +228,7 @@
 import android.content.pm.UserInfo;
 import android.content.res.Resources;
 import android.location.LocationManager;
+import android.net.CaptivePortal;
 import android.net.CaptivePortalData;
 import android.net.ConnectionInfo;
 import android.net.ConnectivityDiagnosticsManager.DataStallReport;
@@ -4440,6 +4441,27 @@
         validatedCallback.expect(CallbackEntry.LOST, mWiFiNetworkAgent);
     }
 
+    private Intent startCaptivePortalApp(TestNetworkAgentWrapper networkAgent) throws Exception {
+        Network network = networkAgent.getNetwork();
+        // Starting the app through ConnectivityManager must make NetworkMonitor launch it.
+        mCm.startCaptivePortalApp(network);
+        waitForIdle();
+        verify(networkAgent.mNetworkMonitor).launchCaptivePortalApp();
+
+        // Simulate NetworkMonitor, which calls startCaptivePortalApp(Network, Bundle).
+        final Bundle testBundle = new Bundle();
+        final String testKey = "testkey";
+        final String testValue = "testvalue";
+        testBundle.putString(testKey, testValue);
+        mServiceContext.setPermission(NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK,
+                PERMISSION_GRANTED);
+        mCm.startCaptivePortalApp(network, testBundle);
+        final Intent signInIntent = mServiceContext.expectStartActivityIntent(TIMEOUT_MS);
+        assertEquals(ACTION_CAPTIVE_PORTAL_SIGN_IN, signInIntent.getAction());
+        assertEquals(testValue, signInIntent.getStringExtra(testKey));
+        return signInIntent;
+    }
+
     @Test
     public void testCaptivePortalApp() throws Exception {
         final TestNetworkCallback captivePortalCallback = new TestNetworkCallback();
@@ -4476,22 +4498,7 @@
         captivePortalCallback.expect(CallbackEntry.NETWORK_CAPS_UPDATED,
                 mWiFiNetworkAgent);
 
-        // Check that startCaptivePortalApp sends the expected command to NetworkMonitor.
-        mCm.startCaptivePortalApp(wifiNetwork);
-        waitForIdle();
-        verify(mWiFiNetworkAgent.mNetworkMonitor).launchCaptivePortalApp();
-
-        // NetworkMonitor uses startCaptivePortal(Network, Bundle) (startCaptivePortalAppInternal)
-        final Bundle testBundle = new Bundle();
-        final String testKey = "testkey";
-        final String testValue = "testvalue";
-        testBundle.putString(testKey, testValue);
-        mServiceContext.setPermission(NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK,
-                PERMISSION_GRANTED);
-        mCm.startCaptivePortalApp(wifiNetwork, testBundle);
-        final Intent signInIntent = mServiceContext.expectStartActivityIntent(TIMEOUT_MS);
-        assertEquals(ACTION_CAPTIVE_PORTAL_SIGN_IN, signInIntent.getAction());
-        assertEquals(testValue, signInIntent.getStringExtra(testKey));
+        startCaptivePortalApp(mWiFiNetworkAgent);
 
         // Report that the captive portal is dismissed, and check that callbacks are fired
         mWiFiNetworkAgent.setNetworkValid(false /* isStrictMode */);
@@ -4504,6 +4511,37 @@
     }
 
     @Test
+    public void testCaptivePortalApp_IgnoreNetwork() throws Exception {
+        final TestNetworkCallback captivePortalCallback = new TestNetworkCallback();
+        final NetworkRequest captivePortalRequest = new NetworkRequest.Builder()
+                .addCapability(NET_CAPABILITY_CAPTIVE_PORTAL).build();
+        mCm.registerNetworkCallback(captivePortalRequest, captivePortalCallback);
+
+        mWiFiNetworkAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
+        mWiFiNetworkAgent.connectWithCaptivePortal(TEST_REDIRECT_URL, false);
+        captivePortalCallback.expectAvailableCallbacksUnvalidated(mWiFiNetworkAgent);
+
+        final Intent signInIntent = startCaptivePortalApp(mWiFiNetworkAgent);
+        final CaptivePortal captivePortal = signInIntent
+                .getParcelableExtra(ConnectivityManager.EXTRA_CAPTIVE_PORTAL);
+
+        captivePortal.ignoreNetwork();
+        waitForIdle();
+
+        // Because the network is torn down, NetworkMonitor must not receive the app response.
+        verify(mWiFiNetworkAgent.mNetworkMonitor, never())
+                .notifyCaptivePortalAppFinished(CaptivePortal.APP_RETURN_UNWANTED);
+
+        // Expect the network to disconnect with automatic reconnection prevented.
+        mWiFiNetworkAgent.expectDisconnected();
+        mWiFiNetworkAgent.expectPreventReconnectReceived();
+        verify(mWiFiNetworkAgent.mNetworkMonitor).notifyNetworkDisconnected();
+        captivePortalCallback.expect(CallbackEntry.LOST, mWiFiNetworkAgent);
+
+        mCm.unregisterNetworkCallback(captivePortalCallback);
+    }
+
+    @Test
     public void testAvoidOrIgnoreCaptivePortals() throws Exception {
         final TestNetworkCallback captivePortalCallback = new TestNetworkCallback();
         final NetworkRequest captivePortalRequest = new NetworkRequest.Builder()
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/ConnectivityMonitorWithConnectivityManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/ConnectivityMonitorWithConnectivityManagerTests.java
index f84e2d8..8fb7be1 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/ConnectivityMonitorWithConnectivityManagerTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/ConnectivityMonitorWithConnectivityManagerTests.java
@@ -21,6 +21,7 @@
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.inOrder;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 
@@ -111,7 +112,7 @@
                 any(NetworkRequest.class), callbackCaptor.capture());
 
         final NetworkCallback callback = callbackCaptor.getValue();
-        final Network testNetwork = new Network(1 /* netId */);
+        final Network testNetwork = mock(Network.class);
 
         // Simulate network available.
         callback.onAvailable(testNetwork);
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java
index 02e00c2..4cae447 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java
@@ -27,6 +27,7 @@
 import static org.mockito.Mockito.mock;
 
 import android.net.InetAddresses;
+import android.net.Network;
 
 import com.android.net.module.util.HexDump;
 import com.android.testutils.DevSdkIgnoreRule;
@@ -165,7 +166,8 @@
         packet.setSocketAddress(
                 new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT));
         responses.clear();
-        int errorCode = decoder.decode(packet, responses, MdnsSocket.INTERFACE_INDEX_UNSPECIFIED);
+        int errorCode = decoder.decode(
+                packet, responses, MdnsSocket.INTERFACE_INDEX_UNSPECIFIED, mock(Network.class));
         assertEquals(MdnsResponseDecoder.SUCCESS, errorCode);
         assertEquals(1, responses.size());
     }
@@ -178,7 +180,8 @@
         packet.setSocketAddress(
                 new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT));
         responses.clear();
-        int errorCode = decoder.decode(packet, responses, MdnsSocket.INTERFACE_INDEX_UNSPECIFIED);
+        int errorCode = decoder.decode(
+                packet, responses, MdnsSocket.INTERFACE_INDEX_UNSPECIFIED, mock(Network.class));
         assertEquals(MdnsResponseDecoder.SUCCESS, errorCode);
         assertEquals(2, responses.size());
     }
@@ -237,7 +240,8 @@
                 new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT));
 
         responses.clear();
-        int errorCode = decoder.decode(packet, responses, MdnsSocket.INTERFACE_INDEX_UNSPECIFIED);
+        int errorCode = decoder.decode(
+                packet, responses, MdnsSocket.INTERFACE_INDEX_UNSPECIFIED, mock(Network.class));
         assertEquals(MdnsResponseDecoder.SUCCESS, errorCode);
 
         MdnsResponse response = responses.get(0);
@@ -287,10 +291,13 @@
                 new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT));
 
         responses.clear();
-        int errorCode = decoder.decode(packet, responses, /* interfaceIndex= */ 10);
+        final Network network = mock(Network.class);
+        int errorCode = decoder.decode(
+                packet, responses, /* interfaceIndex= */ 10, network);
         assertEquals(errorCode, MdnsResponseDecoder.SUCCESS);
         assertEquals(responses.size(), 1);
         assertEquals(responses.get(0).getInterfaceIndex(), 10);
+        assertEquals(network, responses.get(0).getNetwork());
     }
 
     @Test
@@ -306,7 +313,8 @@
                 new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT));
 
         responses.clear();
-        int errorCode = decoder.decode(packet, responses, /* interfaceIndex= */ 0);
+        int errorCode = decoder.decode(
+                packet, responses, /* interfaceIndex= */ 0, mock(Network.class));
         assertEquals(MdnsResponseDecoder.SUCCESS, errorCode);
 
         // This should emit two records:
@@ -340,7 +348,8 @@
                 new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT));
 
         responses.clear();
-        int errorCode = decoder.decode(packet, responses, /* interfaceIndex= */ 0);
+        int errorCode = decoder.decode(
+                packet, responses, /* interfaceIndex= */ 0, mock(Network.class));
         assertEquals(MdnsResponseDecoder.SUCCESS, errorCode);
 
         // This should emit only two records:
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java
index 771e42c..ec57dc8 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java
@@ -21,8 +21,12 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+
+import android.net.Network;
 
 import com.android.net.module.util.HexDump;
 import com.android.testutils.DevSdkIgnoreRule;
@@ -92,6 +96,9 @@
             + "3839300878797A3D"
             + "21402324");
 
+    private static final int INTERFACE_INDEX = 999;
+    private final Network mNetwork = mock(Network.class);
+
     // The following helper classes act as wrappers so that IPv4 and IPv6 address records can
     // be explicitly created by type using same constructor signature as all other records.
     static class MdnsInet4AddressRecord extends MdnsInetAddressRecord {
@@ -127,7 +134,7 @@
     // Construct an MdnsResponse with the specified data packets applied.
     private MdnsResponse makeMdnsResponse(long time, List<PacketAndRecordClass> responseList)
             throws IOException {
-        MdnsResponse response = new MdnsResponse(time);
+        MdnsResponse response = new MdnsResponse(time, INTERFACE_INDEX, mNetwork);
         for (PacketAndRecordClass responseData : responseList) {
             DatagramPacket packet =
                     new DatagramPacket(responseData.packetData, responseData.packetData.length);
@@ -159,7 +166,7 @@
         String[] name = reader.readLabels();
         reader.skip(2); // skip record type indication.
         MdnsInetAddressRecord record = new MdnsInetAddressRecord(name, MdnsRecord.TYPE_A, reader);
-        MdnsResponse response = new MdnsResponse(0);
+        MdnsResponse response = new MdnsResponse(0, INTERFACE_INDEX, mNetwork);
         assertFalse(response.hasInet4AddressRecord());
         assertTrue(response.setInet4AddressRecord(record));
         assertEquals(response.getInet4AddressRecord(), record);
@@ -173,7 +180,7 @@
         reader.skip(2); // skip record type indication.
         MdnsInetAddressRecord record =
                 new MdnsInetAddressRecord(name, MdnsRecord.TYPE_AAAA, reader);
-        MdnsResponse response = new MdnsResponse(0);
+        MdnsResponse response = new MdnsResponse(0, INTERFACE_INDEX, mNetwork);
         assertFalse(response.hasInet6AddressRecord());
         assertTrue(response.setInet6AddressRecord(record));
         assertEquals(response.getInet6AddressRecord(), record);
@@ -186,7 +193,7 @@
         String[] name = reader.readLabels();
         reader.skip(2); // skip record type indication.
         MdnsPointerRecord record = new MdnsPointerRecord(name, reader);
-        MdnsResponse response = new MdnsResponse(0);
+        MdnsResponse response = new MdnsResponse(0, INTERFACE_INDEX, mNetwork);
         assertFalse(response.hasPointerRecords());
         assertTrue(response.addPointerRecord(record));
         List<MdnsPointerRecord> recordList = response.getPointerRecords();
@@ -202,7 +209,7 @@
         String[] name = reader.readLabels();
         reader.skip(2); // skip record type indication.
         MdnsServiceRecord record = new MdnsServiceRecord(name, reader);
-        MdnsResponse response = new MdnsResponse(0);
+        MdnsResponse response = new MdnsResponse(0, INTERFACE_INDEX, mNetwork);
         assertFalse(response.hasServiceRecord());
         assertTrue(response.setServiceRecord(record));
         assertEquals(response.getServiceRecord(), record);
@@ -215,23 +222,31 @@
         String[] name = reader.readLabels();
         reader.skip(2); // skip record type indication.
         MdnsTextRecord record = new MdnsTextRecord(name, reader);
-        MdnsResponse response = new MdnsResponse(0);
+        MdnsResponse response = new MdnsResponse(0, INTERFACE_INDEX, mNetwork);
         assertFalse(response.hasTextRecord());
         assertTrue(response.setTextRecord(record));
         assertEquals(response.getTextRecord(), record);
     }
 
     @Test
-    public void getInterfaceIndex_returnsDefaultValue() {
-        MdnsResponse response = new MdnsResponse(/* now= */ 0);
-        assertEquals(response.getInterfaceIndex(), -1);
+    public void getInterfaceIndex() {
+        final MdnsResponse response1 = new MdnsResponse(/* now= */ 0, INTERFACE_INDEX, mNetwork);
+        assertEquals(INTERFACE_INDEX, response1.getInterfaceIndex());
+
+        final MdnsResponse response2 =
+                new MdnsResponse(/* now= */ 0, /* interfaceIndex= */ 1234, mNetwork);
+        assertEquals(1234, response2.getInterfaceIndex());
     }
 
     @Test
-    public void getInterfaceIndex_afterSet_returnsValue() {
-        MdnsResponse response = new MdnsResponse(/* now= */ 0);
-        response.setInterfaceIndex(5);
-        assertEquals(response.getInterfaceIndex(), 5);
+    public void testGetNetwork() {
+        final MdnsResponse response1 =
+                new MdnsResponse(/* now= */ 0, INTERFACE_INDEX, /* network= */ null);
+        assertNull(response1.getNetwork());
+
+        final MdnsResponse response2 =
+                new MdnsResponse(/* now= */ 0, /* interfaceIndex= */ 1234, mNetwork);
+        assertEquals(mNetwork, response2.getNetwork());
     }
 
     @Test
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceInfoTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceInfoTest.java
index ebdb73f..76728cf 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceInfoTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceInfoTest.java
@@ -16,13 +16,16 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.server.connectivity.mdns.MdnsSocket.INTERFACE_INDEX_UNSPECIFIED;
 import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
 
+import android.net.Network;
 import android.os.Parcel;
 
 import com.android.server.connectivity.mdns.MdnsServiceInfo.TextEntry;
@@ -128,7 +131,7 @@
                         "2001::1",
                         List.of());
 
-        assertEquals(info.getInterfaceIndex(), -1);
+        assertEquals(info.getInterfaceIndex(), INTERFACE_INDEX_UNSPECIFIED);
     }
 
     @Test
@@ -150,6 +153,41 @@
     }
 
     @Test
+    public void testGetNetwork() {
+        final MdnsServiceInfo info1 =
+                new MdnsServiceInfo(
+                        "my-mdns-service",
+                        new String[] {"_googlecast", "_tcp"},
+                        List.of(),
+                        new String[] {"my-host", "local"},
+                        12345,
+                        "192.168.1.1",
+                        "2001::1",
+                        List.of(),
+                        /* textEntries= */ null,
+                        /* interfaceIndex= */ 20);
+
+        assertNull(info1.getNetwork()); // info1 was built without a Network argument.
+
+        final Network network = mock(Network.class);
+        final MdnsServiceInfo info2 =
+                new MdnsServiceInfo(
+                        "my-mdns-service",
+                        new String[] {"_googlecast", "_tcp"},
+                        List.of(),
+                        new String[] {"my-host", "local"},
+                        12345,
+                        "192.168.1.1",
+                        "2001::1",
+                        List.of(),
+                        /* textEntries= */ null,
+                        /* interfaceIndex= */ 20,
+                        network);
+
+        assertEquals(network, info2.getNetwork()); // Same instance passed as the last ctor arg.
+    }
+
+    @Test
     public void parcelable_canBeParceledAndUnparceled() {
         Parcel parcel = Parcel.obtain();
         MdnsServiceInfo beforeParcel =
@@ -165,7 +203,9 @@
                         List.of(
                                 MdnsServiceInfo.TextEntry.fromString("vn=Google Inc."),
                                 MdnsServiceInfo.TextEntry.fromString("mn=Google Nest Hub Max"),
-                                MdnsServiceInfo.TextEntry.fromString("test=")));
+                                MdnsServiceInfo.TextEntry.fromString("test=")),
+                        20 /* interfaceIndex */,
+                        new Network(123));
 
         beforeParcel.writeToParcel(parcel, 0);
         parcel.setDataPosition(0);
@@ -179,6 +219,8 @@
         assertEquals(beforeParcel.getIpv4Address(), afterParcel.getIpv4Address());
         assertEquals(beforeParcel.getIpv6Address(), afterParcel.getIpv6Address());
         assertEquals(beforeParcel.getAttributes(), afterParcel.getAttributes());
+        assertEquals(beforeParcel.getInterfaceIndex(), afterParcel.getInterfaceIndex());
+        assertEquals(beforeParcel.getNetwork(), afterParcel.getNetwork());
     }
 
     @Test
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index 462685a..697116c 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -39,6 +39,7 @@
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.net.InetAddresses;
+import android.net.Network;
 import android.text.TextUtils;
 
 import com.android.server.connectivity.mdns.MdnsServiceInfo.TextEntry;
@@ -79,6 +80,7 @@
     private static final int INTERFACE_INDEX = 999;
     private static final String SERVICE_TYPE = "_googlecast._tcp.local";
     private static final String[] SERVICE_TYPE_LABELS = TextUtils.split(SERVICE_TYPE, "\\.");
+    private static final Network NETWORK = mock(Network.class);
 
     @Mock
     private MdnsServiceBrowserListener mockListenerOne;
@@ -385,7 +387,8 @@
 
     private static void verifyServiceInfo(MdnsServiceInfo serviceInfo, String serviceName,
             String[] serviceType, String ipv4Address, String ipv6Address, int port,
-            List<String> subTypes, Map<String, String> attributes, int interfaceIndex) {
+            List<String> subTypes, Map<String, String> attributes, int interfaceIndex,
+            Network network) {
         assertEquals(serviceName, serviceInfo.getServiceInstanceName());
         assertArrayEquals(serviceType, serviceInfo.getServiceType());
         assertEquals(ipv4Address, serviceInfo.getIpv4Address());
@@ -396,6 +399,7 @@
             assertEquals(attributes.get(key), serviceInfo.getAttributeByKey(key));
         }
         assertEquals(interfaceIndex, serviceInfo.getInterfaceIndex());
+        assertEquals(network, serviceInfo.getNetwork());
     }
 
     @Test
@@ -405,6 +409,7 @@
         MdnsResponse response = mock(MdnsResponse.class);
         when(response.getServiceInstanceName()).thenReturn("service-instance-1");
         doReturn(INTERFACE_INDEX).when(response).getInterfaceIndex();
+        doReturn(NETWORK).when(response).getNetwork();
         when(response.isComplete()).thenReturn(false);
 
         client.processResponse(response);
@@ -417,7 +422,8 @@
                 0 /* port */,
                 List.of() /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                INTERFACE_INDEX);
+                INTERFACE_INDEX,
+                NETWORK);
 
         verify(mockListenerOne, never()).onServiceFound(any(MdnsServiceInfo.class));
         verify(mockListenerOne, never()).onServiceUpdated(any(MdnsServiceInfo.class));
@@ -436,7 +442,8 @@
                         5353,
                         /* subtype= */ "ABCDE",
                         Collections.emptyMap(),
-                        /* interfaceIndex= */ 20);
+                        /* interfaceIndex= */ 20,
+                        NETWORK);
         client.processResponse(initialResponse);
 
         // Process a second response with a different port and updated text attributes.
@@ -447,7 +454,8 @@
                         5354,
                         /* subtype= */ "ABCDE",
                         Collections.singletonMap("key", "value"),
-                        /* interfaceIndex= */ 20);
+                        /* interfaceIndex= */ 20,
+                        NETWORK);
         client.processResponse(secondResponse);
 
         // Verify onServiceNameDiscovered was called once for the initial response.
@@ -460,7 +468,8 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                20 /* interfaceIndex */);
+                20 /* interfaceIndex */,
+                NETWORK);
 
         // Verify onServiceFound was called once for the initial response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -471,6 +480,7 @@
         assertEquals(initialServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertNull(initialServiceInfo.getAttributeByKey("key"));
         assertEquals(initialServiceInfo.getInterfaceIndex(), 20);
+        assertEquals(NETWORK, initialServiceInfo.getNetwork());
 
         // Verify onServiceUpdated was called once for the second response.
         verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -482,6 +492,7 @@
         assertEquals(updatedServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertEquals(updatedServiceInfo.getAttributeByKey("key"), "value");
         assertEquals(updatedServiceInfo.getInterfaceIndex(), 20);
+        assertEquals(NETWORK, updatedServiceInfo.getNetwork());
     }
 
     @Test
@@ -497,7 +508,8 @@
                         5353,
                         /* subtype= */ "ABCDE",
                         Collections.emptyMap(),
-                        /* interfaceIndex= */ 20);
+                        /* interfaceIndex= */ 20,
+                        NETWORK);
         client.processResponse(initialResponse);
 
         // Process a second response with a different port and updated text attributes.
@@ -508,7 +520,8 @@
                         5354,
                         /* subtype= */ "ABCDE",
                         Collections.singletonMap("key", "value"),
-                        /* interfaceIndex= */ 20);
+                        /* interfaceIndex= */ 20,
+                        NETWORK);
         client.processResponse(secondResponse);
 
         System.out.println("secondResponses ip"
@@ -524,7 +537,8 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                20 /* interfaceIndex */);
+                20 /* interfaceIndex */,
+                NETWORK);
 
         // Verify onServiceFound was called once for the initial response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -535,6 +549,7 @@
         assertEquals(initialServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertNull(initialServiceInfo.getAttributeByKey("key"));
         assertEquals(initialServiceInfo.getInterfaceIndex(), 20);
+        assertEquals(NETWORK, initialServiceInfo.getNetwork());
 
         // Verify onServiceUpdated was called once for the second response.
         verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -546,6 +561,7 @@
         assertEquals(updatedServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertEquals(updatedServiceInfo.getAttributeByKey("key"), "value");
         assertEquals(updatedServiceInfo.getInterfaceIndex(), 20);
+        assertEquals(NETWORK, updatedServiceInfo.getNetwork());
     }
 
     private void verifyServiceRemovedNoCallback(MdnsServiceBrowserListener listener) {
@@ -554,15 +570,17 @@
     }
 
     private void verifyServiceRemovedCallback(MdnsServiceBrowserListener listener,
-            String serviceName, String[] serviceType, int interfaceIndex) {
+            String serviceName, String[] serviceType, int interfaceIndex, Network network) {
         verify(listener).onServiceRemoved(argThat(
                 info -> serviceName.equals(info.getServiceInstanceName())
                         && Arrays.equals(serviceType, info.getServiceType())
-                        && info.getInterfaceIndex() == interfaceIndex));
+                        && info.getInterfaceIndex() == interfaceIndex
+                        && java.util.Objects.equals(network, info.getNetwork())));
         verify(listener).onServiceNameRemoved(argThat(
                 info -> serviceName.equals(info.getServiceInstanceName())
                         && Arrays.equals(serviceType, info.getServiceType())
-                        && info.getInterfaceIndex() == interfaceIndex));
+                        && info.getInterfaceIndex() == interfaceIndex
+                        && java.util.Objects.equals(network, info.getNetwork())));
     }
 
     @Test
@@ -580,11 +598,13 @@
                         5353 /* port */,
                         /* subtype= */ "ABCDE",
                         Collections.emptyMap(),
-                        INTERFACE_INDEX);
+                        INTERFACE_INDEX,
+                        NETWORK);
         client.processResponse(initialResponse);
         MdnsResponse response = mock(MdnsResponse.class);
         doReturn("goodbye-service").when(response).getServiceInstanceName();
         doReturn(INTERFACE_INDEX).when(response).getInterfaceIndex();
+        doReturn(NETWORK).when(response).getNetwork();
         doReturn(true).when(response).isGoodbye();
         client.processResponse(response);
         // Verify removed callback won't be called if the service is not existed.
@@ -595,9 +615,9 @@
         doReturn(serviceName).when(response).getServiceInstanceName();
         client.processResponse(response);
         verifyServiceRemovedCallback(
-                mockListenerOne, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX);
+                mockListenerOne, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX, NETWORK);
         verifyServiceRemovedCallback(
-                mockListenerTwo, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX);
+                mockListenerTwo, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX, NETWORK);
     }
 
     @Test
@@ -610,7 +630,8 @@
                         5353,
                         /* subtype= */ "ABCDE",
                         Collections.emptyMap(),
-                        INTERFACE_INDEX);
+                        INTERFACE_INDEX,
+                        NETWORK);
         client.processResponse(initialResponse);
 
         client.startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
@@ -625,7 +646,8 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                INTERFACE_INDEX);
+                INTERFACE_INDEX,
+                NETWORK);
 
         // Verify onServiceFound was called once for the existing response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -662,7 +684,7 @@
         MdnsResponse initialResponse =
                 createMockResponse(
                         serviceInstanceName, "192.168.1.1", 5353, List.of("ABCDE"),
-                        Map.of(), INTERFACE_INDEX);
+                        Map.of(), INTERFACE_INDEX, NETWORK);
         client.processResponse(initialResponse);
 
         // Clear the scheduled runnable.
@@ -696,7 +718,7 @@
         MdnsResponse initialResponse =
                 createMockResponse(
                         serviceInstanceName, "192.168.1.1", 5353, List.of("ABCDE"),
-                        Map.of(), INTERFACE_INDEX);
+                        Map.of(), INTERFACE_INDEX, NETWORK);
         client.processResponse(initialResponse);
 
         // Clear the scheduled runnable.
@@ -714,8 +736,8 @@
         firstMdnsTask.run();
 
         // Verify removed callback was called.
-        verifyServiceRemovedCallback(
-                mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX);
+        verifyServiceRemovedCallback(mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS,
+                INTERFACE_INDEX, NETWORK);
     }
 
     @Test
@@ -736,7 +758,7 @@
         MdnsResponse initialResponse =
                 createMockResponse(
                         serviceInstanceName, "192.168.1.1", 5353, List.of("ABCDE"),
-                        Map.of(), INTERFACE_INDEX);
+                        Map.of(), INTERFACE_INDEX, NETWORK);
         client.processResponse(initialResponse);
 
         // Clear the scheduled runnable.
@@ -770,7 +792,7 @@
         MdnsResponse initialResponse =
                 createMockResponse(
                         serviceInstanceName, "192.168.1.1", 5353, List.of("ABCDE"),
-                        Map.of(), INTERFACE_INDEX);
+                        Map.of(), INTERFACE_INDEX, NETWORK);
         client.processResponse(initialResponse);
 
         // Clear the scheduled runnable.
@@ -781,8 +803,8 @@
         firstMdnsTask.run();
 
         // Verify removed callback was called.
-        verifyServiceRemovedCallback(
-                mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX);
+        verifyServiceRemovedCallback(mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS,
+                INTERFACE_INDEX, NETWORK);
     }
 
     @Test
@@ -801,7 +823,8 @@
                         5353,
                         "ABCDE" /* subtype */,
                         Collections.emptyMap(),
-                        INTERFACE_INDEX);
+                        INTERFACE_INDEX,
+                        NETWORK);
         client.processResponse(initialResponse);
 
         // Process a second response which has ip address to make response become complete.
@@ -812,7 +835,8 @@
                         5353,
                         "ABCDE" /* subtype */,
                         Collections.emptyMap(),
-                        INTERFACE_INDEX);
+                        INTERFACE_INDEX,
+                        NETWORK);
         client.processResponse(secondResponse);
 
         // Process a third response with a different ip address, port and updated text attributes.
@@ -823,7 +847,8 @@
                         5354,
                         "ABCDE" /* subtype */,
                         Collections.singletonMap("key", "value"),
-                        INTERFACE_INDEX);
+                        INTERFACE_INDEX,
+                        NETWORK);
         client.processResponse(thirdResponse);
 
         // Process the last response which is goodbye message.
@@ -842,7 +867,8 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                INTERFACE_INDEX);
+                INTERFACE_INDEX,
+                NETWORK);
 
         // Verify onServiceFound was second called for the second response.
         inOrder.verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -854,7 +880,8 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                INTERFACE_INDEX);
+                INTERFACE_INDEX,
+                NETWORK);
 
         // Verify onServiceUpdated was third called for the third response.
         inOrder.verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -866,7 +893,8 @@
                 5354 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
-                INTERFACE_INDEX);
+                INTERFACE_INDEX,
+                NETWORK);
 
         // Verify onServiceRemoved was called for the last response.
         inOrder.verify(mockListenerOne).onServiceRemoved(serviceInfoCaptor.capture());
@@ -878,7 +906,8 @@
                 5354 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
-                INTERFACE_INDEX);
+                INTERFACE_INDEX,
+                NETWORK);
 
         // Verify onServiceNameRemoved was called for the last response.
         inOrder.verify(mockListenerOne).onServiceNameRemoved(serviceInfoCaptor.capture());
@@ -890,7 +919,8 @@
                 5354 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
-                INTERFACE_INDEX);
+                INTERFACE_INDEX,
+                NETWORK);
     }
 
     // verifies that the right query was enqueued with the right delay, and send query by executing
@@ -962,26 +992,25 @@
             int port,
             @NonNull List<String> subtypes,
             @NonNull Map<String, String> textAttributes,
-            int interfaceIndex)
+            int interfaceIndex,
+            Network network)
             throws Exception {
         String[] hostName = new String[]{"hostname"};
         MdnsServiceRecord serviceRecord = mock(MdnsServiceRecord.class);
         when(serviceRecord.getServiceHost()).thenReturn(hostName);
         when(serviceRecord.getServicePort()).thenReturn(port);
 
-        MdnsResponse response = spy(new MdnsResponse(0));
+        MdnsResponse response = spy(new MdnsResponse(0, interfaceIndex, network));
 
         MdnsInetAddressRecord inetAddressRecord = mock(MdnsInetAddressRecord.class);
         if (host.contains(":")) {
             when(inetAddressRecord.getInet6Address())
                     .thenReturn((Inet6Address) Inet6Address.getByName(host));
             response.setInet6AddressRecord(inetAddressRecord);
-            response.setInterfaceIndex(interfaceIndex);
         } else {
             when(inetAddressRecord.getInet4Address())
                     .thenReturn((Inet4Address) Inet4Address.getByName(host));
             response.setInet4AddressRecord(inetAddressRecord);
-            response.setInterfaceIndex(interfaceIndex);
         }
 
         MdnsTextRecord textRecord = mock(MdnsTextRecord.class);
@@ -1011,10 +1040,10 @@
             int port,
             @NonNull String subtype,
             @NonNull Map<String, String> textAttributes,
-            int interfaceIndex)
+            int interfaceIndex,
+            Network network)
             throws Exception {
-        MdnsResponse response = new MdnsResponse(0);
-        response.setInterfaceIndex(interfaceIndex);
+        MdnsResponse response = new MdnsResponse(0, interfaceIndex, network);
 
         // Set PTR record
         final MdnsPointerRecord pointerRecord = new MdnsPointerRecord(
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
index b4442a5..1d61cd3 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
@@ -501,8 +501,7 @@
         //MdnsConfigsFlagsImpl.allowNetworkInterfaceIndexPropagation.override(true);
 
         when(mockMulticastSocket.getInterfaceIndex()).thenReturn(21);
-        mdnsClient =
-                new MdnsSocketClient(mContext, mockMulticastLock) {
+        mdnsClient = new MdnsSocketClient(mContext, mockMulticastLock) {
                     @Override
                     MdnsSocket createMdnsSocket(int port) {
                         if (port == MdnsConstants.MDNS_PORT) {
@@ -525,8 +524,7 @@
         //MdnsConfigsFlagsImpl.allowNetworkInterfaceIndexPropagation.override(false);
 
         when(mockMulticastSocket.getInterfaceIndex()).thenReturn(21);
-        mdnsClient =
-                new MdnsSocketClient(mContext, mockMulticastLock) {
+        mdnsClient = new MdnsSocketClient(mContext, mockMulticastLock) {
                     @Override
                     MdnsSocket createMdnsSocket(int port) {
                         if (port == MdnsConstants.MDNS_PORT) {
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
new file mode 100644
index 0000000..2bb61a6a
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
@@ -0,0 +1,289 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+import static com.android.testutils.ContextUtils.mockService;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import android.content.Context;
+import android.net.ConnectivityManager;
+import android.net.ConnectivityManager.NetworkCallback;
+import android.net.INetd;
+import android.net.LinkAddress;
+import android.net.LinkProperties;
+import android.net.Network;
+import android.net.TetheringManager;
+import android.net.TetheringManager.TetheringEventCallback;
+import android.os.Build;
+import android.os.Handler;
+import android.os.HandlerThread;
+
+import com.android.net.module.util.ArrayTrackRecord;
+import com.android.server.connectivity.mdns.MdnsSocketProvider.Dependencies;
+import com.android.testutils.DevSdkIgnoreRule;
+import com.android.testutils.DevSdkIgnoreRunner;
+import com.android.testutils.HandlerUtils;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+
+@RunWith(DevSdkIgnoreRunner.class)
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
+public class MdnsSocketProviderTest {
+    private static final String TEST_IFACE_NAME = "test";
+    private static final String LOCAL_ONLY_IFACE_NAME = "local_only";
+    private static final String TETHERED_IFACE_NAME = "tethered";
+    private static final long DEFAULT_TIMEOUT = 2000L;
+    private static final long NO_CALLBACK_TIMEOUT = 200L;
+    private static final LinkAddress LINKADDRV4 = new LinkAddress("192.0.2.0/24");
+    private static final LinkAddress LINKADDRV6 =
+            new LinkAddress("2001:0db8:85a3:0000:0000:8a2e:0370:7334/64");
+    private static final Network TEST_NETWORK = new Network(123);
+    private static final Network LOCAL_NETWORK = new Network(INetd.LOCAL_NET_ID);
+
+    @Mock private Context mContext;
+    @Mock private Dependencies mDeps;
+    @Mock private ConnectivityManager mCm;
+    @Mock private TetheringManager mTm;
+    @Mock private NetworkInterfaceWrapper mTestNetworkIfaceWrapper;
+    @Mock private NetworkInterfaceWrapper mLocalOnlyIfaceWrapper;
+    @Mock private NetworkInterfaceWrapper mTetheredIfaceWrapper;
+    private Handler mHandler;
+    private MdnsSocketProvider mSocketProvider;
+    private NetworkCallback mNetworkCallback;
+    private TetheringEventCallback mTetheringEventCallback;
+
+    @Before
+    public void setUp() throws IOException {
+        MockitoAnnotations.initMocks(this);
+        mockService(mContext, ConnectivityManager.class, Context.CONNECTIVITY_SERVICE, mCm);
+        mockService(mContext, TetheringManager.class, Context.TETHERING_SERVICE, mTm);
+        doReturn(true).when(mDeps).canScanOnInterface(any());
+        doReturn(mTestNetworkIfaceWrapper).when(mDeps).getNetworkInterfaceByName(TEST_IFACE_NAME);
+        doReturn(mLocalOnlyIfaceWrapper).when(mDeps)
+                .getNetworkInterfaceByName(LOCAL_ONLY_IFACE_NAME);
+        doReturn(mTetheredIfaceWrapper).when(mDeps).getNetworkInterfaceByName(TETHERED_IFACE_NAME);
+        doReturn(mock(MdnsInterfaceSocket.class))
+                .when(mDeps).createMdnsInterfaceSocket(any(), anyInt());
+        final HandlerThread thread = new HandlerThread("MdnsSocketProviderTest");
+        thread.start();
+        mHandler = new Handler(thread.getLooper());
+
+        final ArgumentCaptor<NetworkCallback> nwCallbackCaptor =
+                ArgumentCaptor.forClass(NetworkCallback.class);
+        final ArgumentCaptor<TetheringEventCallback> teCallbackCaptor =
+                ArgumentCaptor.forClass(TetheringEventCallback.class);
+        mSocketProvider = new MdnsSocketProvider(mContext, thread.getLooper(), mDeps);
+        mHandler.post(mSocketProvider::startMonitoringSockets);
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        verify(mCm).registerNetworkCallback(any(), nwCallbackCaptor.capture(), any());
+        verify(mTm).registerTetheringEventCallback(any(), teCallbackCaptor.capture());
+
+        mNetworkCallback = nwCallbackCaptor.getValue();
+        mTetheringEventCallback = teCallbackCaptor.getValue();
+    }
+
+    private class TestSocketCallback implements MdnsSocketProvider.SocketCallback {
+        private class SocketEvent {
+            public final Network mNetwork;
+            public final List<LinkAddress> mAddresses;
+
+            SocketEvent(Network network, List<LinkAddress> addresses) {
+                mNetwork = network;
+                mAddresses = Collections.unmodifiableList(addresses);
+            }
+        }
+
+        private class SocketCreatedEvent extends SocketEvent {
+            SocketCreatedEvent(Network nw, List<LinkAddress> addresses) {
+                super(nw, addresses);
+            }
+        }
+
+        private class InterfaceDestroyedEvent extends SocketEvent {
+            InterfaceDestroyedEvent(Network nw, List<LinkAddress> addresses) {
+                super(nw, addresses);
+            }
+        }
+
+        private class AddressesChangedEvent extends SocketEvent {
+            AddressesChangedEvent(Network nw, List<LinkAddress> addresses) {
+                super(nw, addresses);
+            }
+        }
+
+        private final ArrayTrackRecord<SocketEvent>.ReadHead mHistory =
+                new ArrayTrackRecord<SocketEvent>().newReadHead();
+
+        @Override
+        public void onSocketCreated(Network network, MdnsInterfaceSocket socket,
+                List<LinkAddress> addresses) {
+            mHistory.add(new SocketCreatedEvent(network, addresses));
+        }
+
+        @Override
+        public void onInterfaceDestroyed(Network network, MdnsInterfaceSocket socket) {
+            mHistory.add(new InterfaceDestroyedEvent(network, List.of()));
+        }
+
+        @Override
+        public void onAddressesChanged(Network network, List<LinkAddress> addresses) {
+            mHistory.add(new AddressesChangedEvent(network, addresses));
+        }
+
+        public void expectedSocketCreatedForNetwork(Network network, List<LinkAddress> addresses) {
+            final SocketEvent event = mHistory.poll(DEFAULT_TIMEOUT, c -> true);
+            assertNotNull(event);
+            assertTrue(event instanceof SocketCreatedEvent);
+            assertEquals(network, event.mNetwork);
+            assertEquals(addresses, event.mAddresses);
+        }
+
+        public void expectedInterfaceDestroyedForNetwork(Network network) {
+            final SocketEvent event = mHistory.poll(DEFAULT_TIMEOUT, c -> true);
+            assertNotNull(event);
+            assertTrue(event instanceof InterfaceDestroyedEvent);
+            assertEquals(network, event.mNetwork);
+        }
+
+        public void expectedAddressesChangedForNetwork(Network network,
+                List<LinkAddress> addresses) {
+            final SocketEvent event = mHistory.poll(DEFAULT_TIMEOUT, c -> true);
+            assertNotNull(event);
+            assertTrue(event instanceof AddressesChangedEvent);
+            assertEquals(network, event.mNetwork);
+            assertEquals(event.mAddresses, addresses);
+        }
+
+        public void expectedNoCallback() {
+            final SocketEvent event = mHistory.poll(NO_CALLBACK_TIMEOUT, c -> true);
+            assertNull(event);
+        }
+    }
+
+    @Test
+    public void testSocketRequestAndUnrequestSocket() {
+        final TestSocketCallback testCallback1 = new TestSocketCallback();
+        mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback1));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        testCallback1.expectedNoCallback();
+
+        final LinkProperties testLp = new LinkProperties();
+        testLp.setInterfaceName(TEST_IFACE_NAME);
+        testLp.setLinkAddresses(List.of(LINKADDRV4));
+        mHandler.post(() -> mNetworkCallback.onLinkPropertiesChanged(TEST_NETWORK, testLp));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        verify(mTestNetworkIfaceWrapper).getNetworkInterface();
+        testCallback1.expectedSocketCreatedForNetwork(TEST_NETWORK, List.of(LINKADDRV4));
+
+        final TestSocketCallback testCallback2 = new TestSocketCallback();
+        mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback2));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        testCallback1.expectedNoCallback();
+        testCallback2.expectedSocketCreatedForNetwork(TEST_NETWORK, List.of(LINKADDRV4));
+
+        final TestSocketCallback testCallback3 = new TestSocketCallback();
+        mHandler.post(() -> mSocketProvider.requestSocket(null /* network */, testCallback3));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        testCallback1.expectedNoCallback();
+        testCallback2.expectedNoCallback();
+        testCallback3.expectedSocketCreatedForNetwork(TEST_NETWORK, List.of(LINKADDRV4));
+
+        mHandler.post(() -> mTetheringEventCallback.onLocalOnlyInterfacesChanged(
+                List.of(LOCAL_ONLY_IFACE_NAME)));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        verify(mLocalOnlyIfaceWrapper).getNetworkInterface();
+        testCallback1.expectedNoCallback();
+        testCallback2.expectedNoCallback();
+        testCallback3.expectedSocketCreatedForNetwork(LOCAL_NETWORK, List.of());
+
+        mHandler.post(() -> mTetheringEventCallback.onTetheredInterfacesChanged(
+                List.of(TETHERED_IFACE_NAME)));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        verify(mTetheredIfaceWrapper).getNetworkInterface();
+        testCallback1.expectedNoCallback();
+        testCallback2.expectedNoCallback();
+        testCallback3.expectedSocketCreatedForNetwork(LOCAL_NETWORK, List.of());
+
+        mHandler.post(() -> mSocketProvider.unrequestSocket(testCallback1));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        testCallback1.expectedNoCallback();
+        testCallback2.expectedNoCallback();
+        testCallback3.expectedNoCallback();
+
+        mHandler.post(() -> mNetworkCallback.onLost(TEST_NETWORK));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        testCallback1.expectedNoCallback();
+        testCallback2.expectedInterfaceDestroyedForNetwork(TEST_NETWORK);
+        testCallback3.expectedInterfaceDestroyedForNetwork(TEST_NETWORK);
+
+        mHandler.post(() -> mTetheringEventCallback.onLocalOnlyInterfacesChanged(List.of()));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        testCallback1.expectedNoCallback();
+        testCallback2.expectedNoCallback();
+        testCallback3.expectedInterfaceDestroyedForNetwork(LOCAL_NETWORK);
+
+        mHandler.post(() -> mSocketProvider.unrequestSocket(testCallback3));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        testCallback1.expectedNoCallback();
+        testCallback2.expectedNoCallback();
+        testCallback3.expectedNoCallback();
+    }
+
+    @Test
+    public void testAddressesChanged() throws Exception {
+        final TestSocketCallback testCallback = new TestSocketCallback();
+        mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        testCallback.expectedNoCallback();
+
+        final LinkProperties testLp = new LinkProperties();
+        testLp.setInterfaceName(TEST_IFACE_NAME);
+        testLp.setLinkAddresses(List.of(LINKADDRV4));
+        mHandler.post(() -> mNetworkCallback.onLinkPropertiesChanged(TEST_NETWORK, testLp));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        verify(mTestNetworkIfaceWrapper, times(1)).getNetworkInterface();
+        testCallback.expectedSocketCreatedForNetwork(TEST_NETWORK, List.of(LINKADDRV4));
+
+        final LinkProperties newTestLp = new LinkProperties();
+        newTestLp.setInterfaceName(TEST_IFACE_NAME);
+        newTestLp.setLinkAddresses(List.of(LINKADDRV4, LINKADDRV6));
+        mHandler.post(() -> mNetworkCallback.onLinkPropertiesChanged(TEST_NETWORK, newTestLp));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        verify(mTestNetworkIfaceWrapper, times(1)).getNetworkInterface();
+        testCallback.expectedAddressesChangedForNetwork(
+                TEST_NETWORK, List.of(LINKADDRV4, LINKADDRV6));
+    }
+}
diff --git a/tools/gn2bp/Android.bp.swp b/tools/gn2bp/Android.bp.swp
index 4d0ed32..0b0ad9e 100644
--- a/tools/gn2bp/Android.bp.swp
+++ b/tools/gn2bp/Android.bp.swp
@@ -9848,6 +9848,56 @@
         "net/android/java/src/org/chromium/net/X509Util.java",
         "url/android/java/src/org/chromium/url/IDNStringUtil.java",
     ],
+    apex_available: [
+        "//apex_available:platform",
+        "com.android.tethering",
+    ],
+    libs: [
+        "android-support-multidex",
+        "androidx.annotation_annotation",
+        "androidx.annotation_annotation-experimental-nodeps",
+        "androidx.collection_collection",
+        "androidx.core_core-nodeps",
+        "framework-connectivity-t.stubs.module_lib",
+        "framework-connectivity.stubs.module_lib",
+        "framework-mediaprovider.stubs.module_lib",
+        "framework-tethering.stubs.module_lib",
+        "framework-wifi.stubs.module_lib",
+        "jsr305",
+    ],
+    aidl: {
+        include_dirs: [
+            "frameworks/base/core/java/",
+        ],
+        local_include_dirs: [
+            "base/android/java/src/",
+        ],
+    },
+    plugins: [
+        "cronet_aml_java_jni_annotation_preprocessor",
+    ],
+    sdk_version: "module_current",
+}
+
+// GN: //base/android/jni_generator:jni_processor
+java_plugin {
+    name: "cronet_aml_java_jni_annotation_preprocessor",
+    srcs: [
+        ":cronet_aml_build_android_build_config_gen",
+        "base/android/java/src/org/chromium/base/JniException.java",
+        "base/android/java/src/org/chromium/base/JniStaticTestMocker.java",
+        "base/android/java/src/org/chromium/base/NativeLibraryLoadedStatus.java",
+        "base/android/java/src/org/chromium/base/annotations/NativeMethods.java",
+        "base/android/jni_generator/java/src/org/chromium/jni_generator/JniProcessor.java",
+        "build/android/java/src/org/chromium/build/annotations/CheckDiscard.java",
+        "build/android/java/src/org/chromium/build/annotations/MainDex.java",
+    ],
+    static_libs: [
+        "auto_service_annotations",
+        "guava",
+        "javapoet",
+    ],
+    processor_class: "org.chromium.jni_generator.JniProcessor",
 }
 
 // GN: //net/android:net_android_java_enums_srcjar
diff --git a/tools/gn2bp/gen_android_bp b/tools/gn2bp/gen_android_bp
index 42cb494..46918c6 100755
--- a/tools/gn2bp/gen_android_bp
+++ b/tools/gn2bp/gen_android_bp
@@ -314,8 +314,13 @@
     self.cppflags = set()
     self.rtti = False
     # Name of the output. Used for setting .so file name for libcronet
+    self.libs = set()
     self.stem = None
     self.compile_multilib = None
+    self.aidl = dict()
+    self.plugins = set()
+    self.processor_class = None
+    self.sdk_version = None
 
   def to_string(self, output):
     if self.comment:
@@ -362,8 +367,13 @@
     self._output_field(output, 'proto')
     self._output_field(output, 'linker_scripts')
     self._output_field(output, 'cppflags')
+    self._output_field(output, 'libs')
     self._output_field(output, 'stem')
     self._output_field(output, 'compile_multilib')
+    self._output_field(output, 'aidl')
+    self._output_field(output, 'plugins')
+    self._output_field(output, 'processor_class')
+    self._output_field(output, 'sdk_version')
     if self.rtti:
       self._output_field(output, 'rtti')
 
@@ -1298,16 +1308,62 @@
                     (dep_module.name, target.name, dep_module.type))
   return module
 
+def create_java_jni_preprocessor(blueprint):
+  bp_module_name = module_prefix + 'java_jni_annotation_preprocessor'
+  module = Module('java_plugin', bp_module_name, '//base/android/jni_generator:jni_processor')
+  module.srcs.update(
+  [
+    "base/android/jni_generator/java/src/org/chromium/jni_generator/JniProcessor.java",
+    # Avoids a circular dependency with base:base_java. This is okay because
+    # no target should ever expect to package an annotation processor.
+    "build/android/java/src/org/chromium/build/annotations/CheckDiscard.java",
+    "build/android/java/src/org/chromium/build/annotations/MainDex.java",
+    "base/android/java/src/org/chromium/base/JniStaticTestMocker.java",
+    "base/android/java/src/org/chromium/base/NativeLibraryLoadedStatus.java",
+    "base/android/java/src/org/chromium/base/annotations/NativeMethods.java",
+    "base/android/java/src/org/chromium/base/JniException.java",
+    ":cronet_aml_build_android_build_config_gen",
+  ])
+  module.static_libs.update({
+      "javapoet",
+      "guava",
+      "auto_service_annotations",
+  })
+  module.processor_class = "org.chromium.jni_generator.JniProcessor"
+  blueprint.add_module(module)
+  return module
+
 def create_java_module(blueprint, gn):
   bp_module_name = module_prefix + 'java'
   module = Module('java_library', bp_module_name, '//gn:java')
   module.srcs.update([gn_utils.label_to_path(source) for source in gn.java_sources])
+  module.libs = {
+    "androidx.annotation_annotation",
+    "jsr305",
+    "androidx.core_core-nodeps",
+    "androidx.collection_collection",
+    "androidx.annotation_annotation-experimental-nodeps",
+    "android-support-multidex",
+    "framework-connectivity.stubs.module_lib",
+    "framework-connectivity-t.stubs.module_lib",
+    "framework-tethering.stubs.module_lib",
+    "framework-wifi.stubs.module_lib",
+    "framework-mediaprovider.stubs.module_lib",
+  }
+  module.aidl["include_dirs"] = {"frameworks/base/core/java/"}
+  module.aidl["local_include_dirs"] = {"base/android/java/src/"}
+  module.sdk_version = "module_current"
+  module.apex_available.add(tethering_apex)
+  # TODO: remove following workaround required to make this module visible to make (b/203203405)
+  module.apex_available.add("//apex_available:platform")
   for dep in gn.java_actions:
     target = gn.get_target(dep)
     if target.script == '//build/android/gyp/gcc_preprocess.py':
       module.srcs.add(':' + create_gcc_preprocess_modules(blueprint, target).name)
     else:
       module.srcs.add(':' + create_action_module(blueprint, target, 'java_genrule').name)
+  preprocessor_module = create_java_jni_preprocessor(blueprint)
+  module.plugins.add(preprocessor_module.name)
   blueprint.add_module(module)
 
 def update_jni_registration_module(module, gn):