Put the Network info in MdnsServiceInfo

In Nsd, every request has Network info to assign the specific
network to do the mdns query. But the response MdnsServiceInfo
only has interface index which is not very useful and need to
transfer to a Network every time when using it because most
APIs for apps to use the network are based on Network object.
Thus, put the Network info into MdnsServiceInfo.

Bug: 254166302
Test: atest FrameworksNetTests
Change-Id: I2206a84636981fc7d9aa9deda0f18f60642bc7d7
diff --git a/service/mdns/com/android/server/connectivity/mdns/ConnectivityMonitor.java b/service/mdns/com/android/server/connectivity/mdns/ConnectivityMonitor.java
index 2b99d0a..1623669 100644
--- a/service/mdns/com/android/server/connectivity/mdns/ConnectivityMonitor.java
+++ b/service/mdns/com/android/server/connectivity/mdns/ConnectivityMonitor.java
@@ -16,6 +16,8 @@
 
 package com.android.server.connectivity.mdns;
 
+import android.net.Network;
+
 /** Interface for monitoring connectivity changes. */
 public interface ConnectivityMonitor {
     /**
@@ -29,6 +31,9 @@
 
     void notifyConnectivityChange();
 
+    /** Get available network which is received from connectivity change. */
+    Network getAvailableNetwork();
+
     /** Listener interface for receiving connectivity changes. */
     interface Listener {
         void onConnectivityChanged();
diff --git a/service/mdns/com/android/server/connectivity/mdns/ConnectivityMonitorWithConnectivityManager.java b/service/mdns/com/android/server/connectivity/mdns/ConnectivityMonitorWithConnectivityManager.java
index 3563d61..551e3db 100644
--- a/service/mdns/com/android/server/connectivity/mdns/ConnectivityMonitorWithConnectivityManager.java
+++ b/service/mdns/com/android/server/connectivity/mdns/ConnectivityMonitorWithConnectivityManager.java
@@ -16,6 +16,7 @@
 
 package com.android.server.connectivity.mdns;
 
+import android.annotation.Nullable;
 import android.annotation.TargetApi;
 import android.content.Context;
 import android.net.ConnectivityManager;
@@ -37,6 +38,7 @@
     // TODO(b/71901993): Ideally we shouldn't need this flag. However we still don't have clues why
     // the receiver is unregistered twice yet.
     private boolean isCallbackRegistered = false;
+    private Network lastAvailableNetwork = null;
 
     @SuppressWarnings({"nullness:assignment", "nullness:method.invocation"})
     @TargetApi(Build.VERSION_CODES.LOLLIPOP)
@@ -50,6 +52,7 @@
                     @Override
                     public void onAvailable(Network network) {
                         LOGGER.log("network available.");
+                        lastAvailableNetwork = network;
                         notifyConnectivityChange();
                     }
 
@@ -103,4 +106,10 @@
         connectivityManager.unregisterNetworkCallback(networkCallback);
         isCallbackRegistered = false;
     }
+
+    @Override
+    @Nullable
+    public Network getAvailableNetwork() {
+        return lastAvailableNetwork;
+    }
 }
\ No newline at end of file
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsResponse.java b/service/mdns/com/android/server/connectivity/mdns/MdnsResponse.java
index 623168c..3a41978 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsResponse.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsResponse.java
@@ -17,6 +17,7 @@
 package com.android.server.connectivity.mdns;
 
 import android.annotation.Nullable;
+import android.net.Network;
 
 import com.android.internal.annotations.VisibleForTesting;
 
@@ -35,13 +36,16 @@
     private MdnsInetAddressRecord inet4AddressRecord;
     private MdnsInetAddressRecord inet6AddressRecord;
     private long lastUpdateTime;
-    private int interfaceIndex = MdnsSocket.INTERFACE_INDEX_UNSPECIFIED;
+    private final int interfaceIndex;
+    @Nullable private final Network network;
 
     /** Constructs a new, empty response. */
-    public MdnsResponse(long now) {
+    public MdnsResponse(long now, int interfaceIndex, @Nullable Network network) {
         lastUpdateTime = now;
         records = new LinkedList<>();
         pointerRecords = new LinkedList<>();
+        this.interfaceIndex = interfaceIndex;
+        this.network = network;
     }
 
     // This generic typed helper compares records for equality.
@@ -208,21 +212,21 @@
     }
 
     /**
-     * Updates the index of the network interface at which this response was received. Can be set to
-     * {@link MdnsSocket#INTERFACE_INDEX_UNSPECIFIED} if unset.
-     */
-    public synchronized void setInterfaceIndex(int interfaceIndex) {
-        this.interfaceIndex = interfaceIndex;
-    }
-
-    /**
      * Returns the index of the network interface at which this response was received. Can be set to
      * {@link MdnsSocket#INTERFACE_INDEX_UNSPECIFIED} if unset.
      */
-    public synchronized int getInterfaceIndex() {
+    public int getInterfaceIndex() {
         return interfaceIndex;
     }
 
+    /**
+     * Returns the network at which this response was received, or null if the network is unknown.
+     */
+    @Nullable
+    public Network getNetwork() {
+        return network;
+    }
+
     /** Gets the IPv6 address record. */
     public synchronized MdnsInetAddressRecord getInet6AddressRecord() {
         return inet6AddressRecord;
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java b/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
index 6c2bc19..7cf84f6 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
@@ -18,6 +18,7 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.net.Network;
 import android.os.SystemClock;
 
 import com.android.server.connectivity.mdns.util.MdnsLogger;
@@ -95,10 +96,11 @@
      * @param packet The packet to read from.
      * @param interfaceIndex the network interface index (or {@link
      *     MdnsSocket#INTERFACE_INDEX_UNSPECIFIED} if not known) at which the packet was received
+     * @param network the network at which the packet was received, or null if it is unknown.
      * @return A list of mDNS responses, or null if the packet contained no appropriate responses.
      */
     public int decode(@NonNull DatagramPacket packet, @NonNull List<MdnsResponse> responses,
-            int interfaceIndex) {
+            int interfaceIndex, @Nullable Network network) {
         MdnsPacketReader reader = new MdnsPacketReader(packet);
 
         List<MdnsRecord> records;
@@ -253,12 +255,11 @@
                     MdnsResponse response = findResponseWithPointer(responses,
                             pointerRecord.getPointer());
                     if (response == null) {
-                        response = new MdnsResponse(now);
+                        response = new MdnsResponse(now, interfaceIndex, network);
                         responses.add(response);
                     }
                     // Set interface index earlier because some responses have PTR record only.
                     // Need to know every response is getting from which interface.
-                    response.setInterfaceIndex(interfaceIndex);
                     response.addPointerRecord((MdnsPointerRecord) record);
                 }
             }
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsServiceInfo.java b/service/mdns/com/android/server/connectivity/mdns/MdnsServiceInfo.java
index 9683bc9..938fc3f 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsServiceInfo.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsServiceInfo.java
@@ -16,8 +16,11 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.server.connectivity.mdns.MdnsSocket.INTERFACE_INDEX_UNSPECIFIED;
+
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.net.Network;
 import android.os.Parcel;
 import android.os.Parcelable;
 import android.text.TextUtils;
@@ -58,7 +61,8 @@
                             source.readString(),
                             source.createStringArrayList(),
                             source.createTypedArrayList(TextEntry.CREATOR),
-                            source.readInt());
+                            source.readInt(),
+                            source.readParcelable(null));
                 }
 
                 @Override
@@ -82,6 +86,8 @@
     private final int interfaceIndex;
 
     private final Map<String, byte[]> attributes;
+    @Nullable
+    private final Network network;
 
     /** Constructs a {@link MdnsServiceInfo} object with default values. */
     public MdnsServiceInfo(
@@ -103,7 +109,8 @@
                 ipv6Address,
                 textStrings,
                 /* textEntries= */ null,
-                /* interfaceIndex= */ -1);
+                /* interfaceIndex= */ INTERFACE_INDEX_UNSPECIFIED,
+                /* network= */ null);
     }
 
     /** Constructs a {@link MdnsServiceInfo} object with default values. */
@@ -127,7 +134,8 @@
                 ipv6Address,
                 textStrings,
                 textEntries,
-                /* interfaceIndex= */ -1);
+                /* interfaceIndex= */ INTERFACE_INDEX_UNSPECIFIED,
+                /* network= */ null);
     }
 
     /**
@@ -146,6 +154,37 @@
             @Nullable List<String> textStrings,
             @Nullable List<TextEntry> textEntries,
             int interfaceIndex) {
+        this(
+                serviceInstanceName,
+                serviceType,
+                subtypes,
+                hostName,
+                port,
+                ipv4Address,
+                ipv6Address,
+                textStrings,
+                textEntries,
+                interfaceIndex,
+                /* network= */ null);
+    }
+
+    /**
+     * Constructs a {@link MdnsServiceInfo} object with default values.
+     *
+     * @hide
+     */
+    public MdnsServiceInfo(
+            String serviceInstanceName,
+            String[] serviceType,
+            @Nullable List<String> subtypes,
+            String[] hostName,
+            int port,
+            @Nullable String ipv4Address,
+            @Nullable String ipv6Address,
+            @Nullable List<String> textStrings,
+            @Nullable List<TextEntry> textEntries,
+            int interfaceIndex,
+            @Nullable Network network) {
         this.serviceInstanceName = serviceInstanceName;
         this.serviceType = serviceType;
         this.subtypes = new ArrayList<>();
@@ -180,6 +219,7 @@
         }
         this.attributes = Collections.unmodifiableMap(attributes);
         this.interfaceIndex = interfaceIndex;
+        this.network = network;
     }
 
     private static List<TextEntry> parseTextStrings(List<String> textStrings) {
@@ -244,6 +284,14 @@
     }
 
     /**
+     * Returns the network at which this response was received, or null if the network is unknown.
+     */
+    @Nullable
+    public Network getNetwork() {
+        return network;
+    }
+
+    /**
      * Returns attribute value for {@code key} as a UTF-8 string. It's the caller who must make sure
      * that the value of {@code key} is indeed a UTF-8 string. {@code null} will be returned if no
      * attribute value exists for {@code key}.
@@ -293,6 +341,7 @@
         out.writeStringList(textStrings);
         out.writeTypedList(textEntries);
         out.writeInt(interfaceIndex);
+        out.writeParcelable(network, 0);
     }
 
     @Override
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index dd4ff9b..538f376 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -130,7 +130,8 @@
                 ipv6Address,
                 textStrings,
                 textEntries,
-                response.getInterfaceIndex());
+                response.getInterfaceIndex(),
+                response.getNetwork());
     }
 
     /**
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java b/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java
index 0a9b2fc..f8452c9 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java
@@ -17,6 +17,8 @@
 package com.android.server.connectivity.mdns;
 
 import android.annotation.NonNull;
+import android.annotation.Nullable;
+import android.net.Network;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.server.connectivity.mdns.util.MdnsLogger;
@@ -125,6 +127,14 @@
         }
     }
 
+    /**
+     * Returns the available network that this socket is used to, or null if the network is unknown.
+     */
+    @Nullable
+    public Network getNetwork() {
+        return multicastNetworkInterfaceProvider.getAvailableNetwork();
+    }
+
     public boolean isOnIPv6OnlyNetwork() {
         return isOnIPv6OnlyNetwork;
     }
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java
index 758221a..6a321d1 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java
@@ -21,6 +21,7 @@
 import android.annotation.Nullable;
 import android.annotation.RequiresPermission;
 import android.content.Context;
+import android.net.Network;
 import android.net.wifi.WifiManager.MulticastLock;
 import android.os.SystemClock;
 import android.text.format.DateUtils;
@@ -397,7 +398,8 @@
                             responseType,
                             /* interfaceIndex= */ (socket == null || !propagateInterfaceIndex)
                                     ? MdnsSocket.INTERFACE_INDEX_UNSPECIFIED
-                                    : socket.getInterfaceIndex());
+                                    : socket.getInterfaceIndex(),
+                            /* network= */ socket.getNetwork());
                 }
             } catch (IOException e) {
                 if (!shouldStopSocketLoop) {
@@ -408,12 +410,12 @@
         LOGGER.log("Receive thread stopped.");
     }
 
-    private int processResponsePacket(
-            @NonNull DatagramPacket packet, String responseType, int interfaceIndex) {
+    private int processResponsePacket(@NonNull DatagramPacket packet, String responseType,
+            int interfaceIndex, @Nullable Network network) {
         int packetNumber = ++receivedPacketNumber;
 
         List<MdnsResponse> responses = new LinkedList<>();
-        int errorCode = responseDecoder.decode(packet, responses, interfaceIndex);
+        int errorCode = responseDecoder.decode(packet, responses, interfaceIndex, network);
         if (errorCode == MdnsResponseDecoder.SUCCESS) {
             if (responseType.equals(MULTICAST_TYPE)) {
                 receivedMulticastResponse = true;
diff --git a/service/mdns/com/android/server/connectivity/mdns/MulticastNetworkInterfaceProvider.java b/service/mdns/com/android/server/connectivity/mdns/MulticastNetworkInterfaceProvider.java
index e0d8fa6..644460d 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MulticastNetworkInterfaceProvider.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MulticastNetworkInterfaceProvider.java
@@ -19,6 +19,7 @@
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.content.Context;
+import android.net.Network;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.server.connectivity.mdns.util.MdnsLogger;
@@ -56,7 +57,7 @@
                 context, this::onConnectivityChanged);
     }
 
-    private void onConnectivityChanged() {
+    private synchronized void onConnectivityChanged() {
         connectivityChanged = true;
     }
 
@@ -141,6 +142,11 @@
         return networkInterfaceWrappers;
     }
 
+    @Nullable
+    public Network getAvailableNetwork() {
+        return connectivityMonitor.getAvailableNetwork();
+    }
+
     private boolean canScanOnInterface(@Nullable NetworkInterfaceWrapper networkInterface) {
         try {
             if ((networkInterface == null)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/ConnectivityMonitorWithConnectivityManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/ConnectivityMonitorWithConnectivityManagerTests.java
index f84e2d8..8fb7be1 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/ConnectivityMonitorWithConnectivityManagerTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/ConnectivityMonitorWithConnectivityManagerTests.java
@@ -21,6 +21,7 @@
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.inOrder;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 
@@ -111,7 +112,7 @@
                 any(NetworkRequest.class), callbackCaptor.capture());
 
         final NetworkCallback callback = callbackCaptor.getValue();
-        final Network testNetwork = new Network(1 /* netId */);
+        final Network testNetwork = mock(Network.class);
 
         // Simulate network available.
         callback.onAvailable(testNetwork);
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java
index 02e00c2..4cae447 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java
@@ -27,6 +27,7 @@
 import static org.mockito.Mockito.mock;
 
 import android.net.InetAddresses;
+import android.net.Network;
 
 import com.android.net.module.util.HexDump;
 import com.android.testutils.DevSdkIgnoreRule;
@@ -165,7 +166,8 @@
         packet.setSocketAddress(
                 new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT));
         responses.clear();
-        int errorCode = decoder.decode(packet, responses, MdnsSocket.INTERFACE_INDEX_UNSPECIFIED);
+        int errorCode = decoder.decode(
+                packet, responses, MdnsSocket.INTERFACE_INDEX_UNSPECIFIED, mock(Network.class));
         assertEquals(MdnsResponseDecoder.SUCCESS, errorCode);
         assertEquals(1, responses.size());
     }
@@ -178,7 +180,8 @@
         packet.setSocketAddress(
                 new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT));
         responses.clear();
-        int errorCode = decoder.decode(packet, responses, MdnsSocket.INTERFACE_INDEX_UNSPECIFIED);
+        int errorCode = decoder.decode(
+                packet, responses, MdnsSocket.INTERFACE_INDEX_UNSPECIFIED, mock(Network.class));
         assertEquals(MdnsResponseDecoder.SUCCESS, errorCode);
         assertEquals(2, responses.size());
     }
@@ -237,7 +240,8 @@
                 new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT));
 
         responses.clear();
-        int errorCode = decoder.decode(packet, responses, MdnsSocket.INTERFACE_INDEX_UNSPECIFIED);
+        int errorCode = decoder.decode(
+                packet, responses, MdnsSocket.INTERFACE_INDEX_UNSPECIFIED, mock(Network.class));
         assertEquals(MdnsResponseDecoder.SUCCESS, errorCode);
 
         MdnsResponse response = responses.get(0);
@@ -287,10 +291,13 @@
                 new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT));
 
         responses.clear();
-        int errorCode = decoder.decode(packet, responses, /* interfaceIndex= */ 10);
+        final Network network = mock(Network.class);
+        int errorCode = decoder.decode(
+                packet, responses, /* interfaceIndex= */ 10, network);
         assertEquals(errorCode, MdnsResponseDecoder.SUCCESS);
         assertEquals(responses.size(), 1);
         assertEquals(responses.get(0).getInterfaceIndex(), 10);
+        assertEquals(network, responses.get(0).getNetwork());
     }
 
     @Test
@@ -306,7 +313,8 @@
                 new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT));
 
         responses.clear();
-        int errorCode = decoder.decode(packet, responses, /* interfaceIndex= */ 0);
+        int errorCode = decoder.decode(
+                packet, responses, /* interfaceIndex= */ 0, mock(Network.class));
         assertEquals(MdnsResponseDecoder.SUCCESS, errorCode);
 
         // This should emit two records:
@@ -340,7 +348,8 @@
                 new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT));
 
         responses.clear();
-        int errorCode = decoder.decode(packet, responses, /* interfaceIndex= */ 0);
+        int errorCode = decoder.decode(
+                packet, responses, /* interfaceIndex= */ 0, mock(Network.class));
         assertEquals(MdnsResponseDecoder.SUCCESS, errorCode);
 
         // This should emit only two records:
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java
index 771e42c..ec57dc8 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java
@@ -21,8 +21,12 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+
+import android.net.Network;
 
 import com.android.net.module.util.HexDump;
 import com.android.testutils.DevSdkIgnoreRule;
@@ -92,6 +96,9 @@
             + "3839300878797A3D"
             + "21402324");
 
+    private static final int INTERFACE_INDEX = 999;
+    private final Network mNetwork = mock(Network.class);
+
     // The following helper classes act as wrappers so that IPv4 and IPv6 address records can
     // be explicitly created by type using same constructor signature as all other records.
     static class MdnsInet4AddressRecord extends MdnsInetAddressRecord {
@@ -127,7 +134,7 @@
     // Construct an MdnsResponse with the specified data packets applied.
     private MdnsResponse makeMdnsResponse(long time, List<PacketAndRecordClass> responseList)
             throws IOException {
-        MdnsResponse response = new MdnsResponse(time);
+        MdnsResponse response = new MdnsResponse(time, INTERFACE_INDEX, mNetwork);
         for (PacketAndRecordClass responseData : responseList) {
             DatagramPacket packet =
                     new DatagramPacket(responseData.packetData, responseData.packetData.length);
@@ -159,7 +166,7 @@
         String[] name = reader.readLabels();
         reader.skip(2); // skip record type indication.
         MdnsInetAddressRecord record = new MdnsInetAddressRecord(name, MdnsRecord.TYPE_A, reader);
-        MdnsResponse response = new MdnsResponse(0);
+        MdnsResponse response = new MdnsResponse(0, INTERFACE_INDEX, mNetwork);
         assertFalse(response.hasInet4AddressRecord());
         assertTrue(response.setInet4AddressRecord(record));
         assertEquals(response.getInet4AddressRecord(), record);
@@ -173,7 +180,7 @@
         reader.skip(2); // skip record type indication.
         MdnsInetAddressRecord record =
                 new MdnsInetAddressRecord(name, MdnsRecord.TYPE_AAAA, reader);
-        MdnsResponse response = new MdnsResponse(0);
+        MdnsResponse response = new MdnsResponse(0, INTERFACE_INDEX, mNetwork);
         assertFalse(response.hasInet6AddressRecord());
         assertTrue(response.setInet6AddressRecord(record));
         assertEquals(response.getInet6AddressRecord(), record);
@@ -186,7 +193,7 @@
         String[] name = reader.readLabels();
         reader.skip(2); // skip record type indication.
         MdnsPointerRecord record = new MdnsPointerRecord(name, reader);
-        MdnsResponse response = new MdnsResponse(0);
+        MdnsResponse response = new MdnsResponse(0, INTERFACE_INDEX, mNetwork);
         assertFalse(response.hasPointerRecords());
         assertTrue(response.addPointerRecord(record));
         List<MdnsPointerRecord> recordList = response.getPointerRecords();
@@ -202,7 +209,7 @@
         String[] name = reader.readLabels();
         reader.skip(2); // skip record type indication.
         MdnsServiceRecord record = new MdnsServiceRecord(name, reader);
-        MdnsResponse response = new MdnsResponse(0);
+        MdnsResponse response = new MdnsResponse(0, INTERFACE_INDEX, mNetwork);
         assertFalse(response.hasServiceRecord());
         assertTrue(response.setServiceRecord(record));
         assertEquals(response.getServiceRecord(), record);
@@ -215,23 +222,31 @@
         String[] name = reader.readLabels();
         reader.skip(2); // skip record type indication.
         MdnsTextRecord record = new MdnsTextRecord(name, reader);
-        MdnsResponse response = new MdnsResponse(0);
+        MdnsResponse response = new MdnsResponse(0, INTERFACE_INDEX, mNetwork);
         assertFalse(response.hasTextRecord());
         assertTrue(response.setTextRecord(record));
         assertEquals(response.getTextRecord(), record);
     }
 
     @Test
-    public void getInterfaceIndex_returnsDefaultValue() {
-        MdnsResponse response = new MdnsResponse(/* now= */ 0);
-        assertEquals(response.getInterfaceIndex(), -1);
+    public void getInterfaceIndex() {
+        final MdnsResponse response1 = new MdnsResponse(/* now= */ 0, INTERFACE_INDEX, mNetwork);
+        assertEquals(INTERFACE_INDEX, response1.getInterfaceIndex());
+
+        final MdnsResponse response2 =
+                new MdnsResponse(/* now= */ 0, 1234 /* interfaceIndex */, mNetwork);
+        assertEquals(1234, response2.getInterfaceIndex());
     }
 
     @Test
-    public void getInterfaceIndex_afterSet_returnsValue() {
-        MdnsResponse response = new MdnsResponse(/* now= */ 0);
-        response.setInterfaceIndex(5);
-        assertEquals(response.getInterfaceIndex(), 5);
+    public void testGetNetwork() {
+        final MdnsResponse response1 =
+                new MdnsResponse(/* now= */ 0, INTERFACE_INDEX, null /* network */);
+        assertNull(response1.getNetwork());
+
+        final MdnsResponse response2 =
+                new MdnsResponse(/* now= */ 0, 1234 /* interfaceIndex */, mNetwork);
+        assertEquals(mNetwork, response2.getNetwork());
     }
 
     @Test
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceInfoTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceInfoTest.java
index ebdb73f..76728cf 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceInfoTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceInfoTest.java
@@ -16,13 +16,16 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.server.connectivity.mdns.MdnsSocket.INTERFACE_INDEX_UNSPECIFIED;
 import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
 
+import android.net.Network;
 import android.os.Parcel;
 
 import com.android.server.connectivity.mdns.MdnsServiceInfo.TextEntry;
@@ -128,7 +131,7 @@
                         "2001::1",
                         List.of());
 
-        assertEquals(info.getInterfaceIndex(), -1);
+        assertEquals(info.getInterfaceIndex(), INTERFACE_INDEX_UNSPECIFIED);
     }
 
     @Test
@@ -150,6 +153,41 @@
     }
 
     @Test
+    public void testGetNetwork() {
+        final MdnsServiceInfo info1 =
+                new MdnsServiceInfo(
+                        "my-mdns-service",
+                        new String[] {"_googlecast", "_tcp"},
+                        List.of(),
+                        new String[] {"my-host", "local"},
+                        12345,
+                        "192.168.1.1",
+                        "2001::1",
+                        List.of(),
+                        /* textEntries= */ null,
+                        /* interfaceIndex= */ 20);
+
+        assertNull(info1.getNetwork());
+
+        final Network network = mock(Network.class);
+        final MdnsServiceInfo info2 =
+                new MdnsServiceInfo(
+                        "my-mdns-service",
+                        new String[] {"_googlecast", "_tcp"},
+                        List.of(),
+                        new String[] {"my-host", "local"},
+                        12345,
+                        "192.168.1.1",
+                        "2001::1",
+                        List.of(),
+                        /* textEntries= */ null,
+                        /* interfaceIndex= */ 20,
+                        network);
+
+        assertEquals(network, info2.getNetwork());
+    }
+
+    @Test
     public void parcelable_canBeParceledAndUnparceled() {
         Parcel parcel = Parcel.obtain();
         MdnsServiceInfo beforeParcel =
@@ -165,7 +203,9 @@
                         List.of(
                                 MdnsServiceInfo.TextEntry.fromString("vn=Google Inc."),
                                 MdnsServiceInfo.TextEntry.fromString("mn=Google Nest Hub Max"),
-                                MdnsServiceInfo.TextEntry.fromString("test=")));
+                                MdnsServiceInfo.TextEntry.fromString("test=")),
+                        20 /* interfaceIndex */,
+                        new Network(123));
 
         beforeParcel.writeToParcel(parcel, 0);
         parcel.setDataPosition(0);
@@ -179,6 +219,8 @@
         assertEquals(beforeParcel.getIpv4Address(), afterParcel.getIpv4Address());
         assertEquals(beforeParcel.getIpv6Address(), afterParcel.getIpv6Address());
         assertEquals(beforeParcel.getAttributes(), afterParcel.getAttributes());
+        assertEquals(beforeParcel.getInterfaceIndex(), afterParcel.getInterfaceIndex());
+        assertEquals(beforeParcel.getNetwork(), afterParcel.getNetwork());
     }
 
     @Test
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index 462685a..697116c 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -39,6 +39,7 @@
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.net.InetAddresses;
+import android.net.Network;
 import android.text.TextUtils;
 
 import com.android.server.connectivity.mdns.MdnsServiceInfo.TextEntry;
@@ -79,6 +80,7 @@
     private static final int INTERFACE_INDEX = 999;
     private static final String SERVICE_TYPE = "_googlecast._tcp.local";
     private static final String[] SERVICE_TYPE_LABELS = TextUtils.split(SERVICE_TYPE, "\\.");
+    private static final Network NETWORK = mock(Network.class);
 
     @Mock
     private MdnsServiceBrowserListener mockListenerOne;
@@ -385,7 +387,8 @@
 
     private static void verifyServiceInfo(MdnsServiceInfo serviceInfo, String serviceName,
             String[] serviceType, String ipv4Address, String ipv6Address, int port,
-            List<String> subTypes, Map<String, String> attributes, int interfaceIndex) {
+            List<String> subTypes, Map<String, String> attributes, int interfaceIndex,
+            Network network) {
         assertEquals(serviceName, serviceInfo.getServiceInstanceName());
         assertArrayEquals(serviceType, serviceInfo.getServiceType());
         assertEquals(ipv4Address, serviceInfo.getIpv4Address());
@@ -396,6 +399,7 @@
             assertEquals(attributes.get(key), serviceInfo.getAttributeByKey(key));
         }
         assertEquals(interfaceIndex, serviceInfo.getInterfaceIndex());
+        assertEquals(network, serviceInfo.getNetwork());
     }
 
     @Test
@@ -405,6 +409,7 @@
         MdnsResponse response = mock(MdnsResponse.class);
         when(response.getServiceInstanceName()).thenReturn("service-instance-1");
         doReturn(INTERFACE_INDEX).when(response).getInterfaceIndex();
+        doReturn(NETWORK).when(response).getNetwork();
         when(response.isComplete()).thenReturn(false);
 
         client.processResponse(response);
@@ -417,7 +422,8 @@
                 0 /* port */,
                 List.of() /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                INTERFACE_INDEX);
+                INTERFACE_INDEX,
+                NETWORK);
 
         verify(mockListenerOne, never()).onServiceFound(any(MdnsServiceInfo.class));
         verify(mockListenerOne, never()).onServiceUpdated(any(MdnsServiceInfo.class));
@@ -436,7 +442,8 @@
                         5353,
                         /* subtype= */ "ABCDE",
                         Collections.emptyMap(),
-                        /* interfaceIndex= */ 20);
+                        /* interfaceIndex= */ 20,
+                        NETWORK);
         client.processResponse(initialResponse);
 
         // Process a second response with a different port and updated text attributes.
@@ -447,7 +454,8 @@
                         5354,
                         /* subtype= */ "ABCDE",
                         Collections.singletonMap("key", "value"),
-                        /* interfaceIndex= */ 20);
+                        /* interfaceIndex= */ 20,
+                        NETWORK);
         client.processResponse(secondResponse);
 
         // Verify onServiceNameDiscovered was called once for the initial response.
@@ -460,7 +468,8 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                20 /* interfaceIndex */);
+                20 /* interfaceIndex */,
+                NETWORK);
 
         // Verify onServiceFound was called once for the initial response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -471,6 +480,7 @@
         assertEquals(initialServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertNull(initialServiceInfo.getAttributeByKey("key"));
         assertEquals(initialServiceInfo.getInterfaceIndex(), 20);
+        assertEquals(NETWORK, initialServiceInfo.getNetwork());
 
         // Verify onServiceUpdated was called once for the second response.
         verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -482,6 +492,7 @@
         assertEquals(updatedServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertEquals(updatedServiceInfo.getAttributeByKey("key"), "value");
         assertEquals(updatedServiceInfo.getInterfaceIndex(), 20);
+        assertEquals(NETWORK, updatedServiceInfo.getNetwork());
     }
 
     @Test
@@ -497,7 +508,8 @@
                         5353,
                         /* subtype= */ "ABCDE",
                         Collections.emptyMap(),
-                        /* interfaceIndex= */ 20);
+                        /* interfaceIndex= */ 20,
+                        NETWORK);
         client.processResponse(initialResponse);
 
         // Process a second response with a different port and updated text attributes.
@@ -508,7 +520,8 @@
                         5354,
                         /* subtype= */ "ABCDE",
                         Collections.singletonMap("key", "value"),
-                        /* interfaceIndex= */ 20);
+                        /* interfaceIndex= */ 20,
+                        NETWORK);
         client.processResponse(secondResponse);
 
         System.out.println("secondResponses ip"
@@ -524,7 +537,8 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                20 /* interfaceIndex */);
+                20 /* interfaceIndex */,
+                NETWORK);
 
         // Verify onServiceFound was called once for the initial response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -535,6 +549,7 @@
         assertEquals(initialServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertNull(initialServiceInfo.getAttributeByKey("key"));
         assertEquals(initialServiceInfo.getInterfaceIndex(), 20);
+        assertEquals(NETWORK, initialServiceInfo.getNetwork());
 
         // Verify onServiceUpdated was called once for the second response.
         verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -546,6 +561,7 @@
         assertEquals(updatedServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertEquals(updatedServiceInfo.getAttributeByKey("key"), "value");
         assertEquals(updatedServiceInfo.getInterfaceIndex(), 20);
+        assertEquals(NETWORK, updatedServiceInfo.getNetwork());
     }
 
     private void verifyServiceRemovedNoCallback(MdnsServiceBrowserListener listener) {
@@ -554,15 +570,17 @@
     }
 
     private void verifyServiceRemovedCallback(MdnsServiceBrowserListener listener,
-            String serviceName, String[] serviceType, int interfaceIndex) {
+            String serviceName, String[] serviceType, int interfaceIndex, Network network) {
         verify(listener).onServiceRemoved(argThat(
                 info -> serviceName.equals(info.getServiceInstanceName())
                         && Arrays.equals(serviceType, info.getServiceType())
-                        && info.getInterfaceIndex() == interfaceIndex));
+                        && info.getInterfaceIndex() == interfaceIndex
+                        && network.equals(info.getNetwork())));
         verify(listener).onServiceNameRemoved(argThat(
                 info -> serviceName.equals(info.getServiceInstanceName())
                         && Arrays.equals(serviceType, info.getServiceType())
-                        && info.getInterfaceIndex() == interfaceIndex));
+                        && info.getInterfaceIndex() == interfaceIndex
+                        && network.equals(info.getNetwork())));
     }
 
     @Test
@@ -580,11 +598,13 @@
                         5353 /* port */,
                         /* subtype= */ "ABCDE",
                         Collections.emptyMap(),
-                        INTERFACE_INDEX);
+                        INTERFACE_INDEX,
+                        NETWORK);
         client.processResponse(initialResponse);
         MdnsResponse response = mock(MdnsResponse.class);
         doReturn("goodbye-service").when(response).getServiceInstanceName();
         doReturn(INTERFACE_INDEX).when(response).getInterfaceIndex();
+        doReturn(NETWORK).when(response).getNetwork();
         doReturn(true).when(response).isGoodbye();
         client.processResponse(response);
         // Verify removed callback won't be called if the service is not existed.
@@ -595,9 +615,9 @@
         doReturn(serviceName).when(response).getServiceInstanceName();
         client.processResponse(response);
         verifyServiceRemovedCallback(
-                mockListenerOne, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX);
+                mockListenerOne, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX, NETWORK);
         verifyServiceRemovedCallback(
-                mockListenerTwo, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX);
+                mockListenerTwo, serviceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX, NETWORK);
     }
 
     @Test
@@ -610,7 +630,8 @@
                         5353,
                         /* subtype= */ "ABCDE",
                         Collections.emptyMap(),
-                        INTERFACE_INDEX);
+                        INTERFACE_INDEX,
+                        NETWORK);
         client.processResponse(initialResponse);
 
         client.startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
@@ -625,7 +646,8 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                INTERFACE_INDEX);
+                INTERFACE_INDEX,
+                NETWORK);
 
         // Verify onServiceFound was called once for the existing response.
         verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -662,7 +684,7 @@
         MdnsResponse initialResponse =
                 createMockResponse(
                         serviceInstanceName, "192.168.1.1", 5353, List.of("ABCDE"),
-                        Map.of(), INTERFACE_INDEX);
+                        Map.of(), INTERFACE_INDEX, NETWORK);
         client.processResponse(initialResponse);
 
         // Clear the scheduled runnable.
@@ -696,7 +718,7 @@
         MdnsResponse initialResponse =
                 createMockResponse(
                         serviceInstanceName, "192.168.1.1", 5353, List.of("ABCDE"),
-                        Map.of(), INTERFACE_INDEX);
+                        Map.of(), INTERFACE_INDEX, NETWORK);
         client.processResponse(initialResponse);
 
         // Clear the scheduled runnable.
@@ -714,8 +736,8 @@
         firstMdnsTask.run();
 
         // Verify removed callback was called.
-        verifyServiceRemovedCallback(
-                mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX);
+        verifyServiceRemovedCallback(mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS,
+                INTERFACE_INDEX, NETWORK);
     }
 
     @Test
@@ -736,7 +758,7 @@
         MdnsResponse initialResponse =
                 createMockResponse(
                         serviceInstanceName, "192.168.1.1", 5353, List.of("ABCDE"),
-                        Map.of(), INTERFACE_INDEX);
+                        Map.of(), INTERFACE_INDEX, NETWORK);
         client.processResponse(initialResponse);
 
         // Clear the scheduled runnable.
@@ -770,7 +792,7 @@
         MdnsResponse initialResponse =
                 createMockResponse(
                         serviceInstanceName, "192.168.1.1", 5353, List.of("ABCDE"),
-                        Map.of(), INTERFACE_INDEX);
+                        Map.of(), INTERFACE_INDEX, NETWORK);
         client.processResponse(initialResponse);
 
         // Clear the scheduled runnable.
@@ -781,8 +803,8 @@
         firstMdnsTask.run();
 
         // Verify removed callback was called.
-        verifyServiceRemovedCallback(
-                mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS, INTERFACE_INDEX);
+        verifyServiceRemovedCallback(mockListenerOne, serviceInstanceName, SERVICE_TYPE_LABELS,
+                INTERFACE_INDEX, NETWORK);
     }
 
     @Test
@@ -801,7 +823,8 @@
                         5353,
                         "ABCDE" /* subtype */,
                         Collections.emptyMap(),
-                        INTERFACE_INDEX);
+                        INTERFACE_INDEX,
+                        NETWORK);
         client.processResponse(initialResponse);
 
         // Process a second response which has ip address to make response become complete.
@@ -812,7 +835,8 @@
                         5353,
                         "ABCDE" /* subtype */,
                         Collections.emptyMap(),
-                        INTERFACE_INDEX);
+                        INTERFACE_INDEX,
+                        NETWORK);
         client.processResponse(secondResponse);
 
         // Process a third response with a different ip address, port and updated text attributes.
@@ -823,7 +847,8 @@
                         5354,
                         "ABCDE" /* subtype */,
                         Collections.singletonMap("key", "value"),
-                        INTERFACE_INDEX);
+                        INTERFACE_INDEX,
+                        NETWORK);
         client.processResponse(thirdResponse);
 
         // Process the last response which is goodbye message.
@@ -842,7 +867,8 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                INTERFACE_INDEX);
+                INTERFACE_INDEX,
+                NETWORK);
 
         // Verify onServiceFound was second called for the second response.
         inOrder.verify(mockListenerOne).onServiceFound(serviceInfoCaptor.capture());
@@ -854,7 +880,8 @@
                 5353 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", null) /* attributes */,
-                INTERFACE_INDEX);
+                INTERFACE_INDEX,
+                NETWORK);
 
         // Verify onServiceUpdated was third called for the third response.
         inOrder.verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -866,7 +893,8 @@
                 5354 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
-                INTERFACE_INDEX);
+                INTERFACE_INDEX,
+                NETWORK);
 
         // Verify onServiceRemoved was called for the last response.
         inOrder.verify(mockListenerOne).onServiceRemoved(serviceInfoCaptor.capture());
@@ -878,7 +906,8 @@
                 5354 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
-                INTERFACE_INDEX);
+                INTERFACE_INDEX,
+                NETWORK);
 
         // Verify onServiceNameRemoved was called for the last response.
         inOrder.verify(mockListenerOne).onServiceNameRemoved(serviceInfoCaptor.capture());
@@ -890,7 +919,8 @@
                 5354 /* port */,
                 Collections.singletonList("ABCDE") /* subTypes */,
                 Collections.singletonMap("key", "value") /* attributes */,
-                INTERFACE_INDEX);
+                INTERFACE_INDEX,
+                NETWORK);
     }
 
     // verifies that the right query was enqueued with the right delay, and send query by executing
@@ -962,26 +992,25 @@
             int port,
             @NonNull List<String> subtypes,
             @NonNull Map<String, String> textAttributes,
-            int interfaceIndex)
+            int interfaceIndex,
+            Network network)
             throws Exception {
         String[] hostName = new String[]{"hostname"};
         MdnsServiceRecord serviceRecord = mock(MdnsServiceRecord.class);
         when(serviceRecord.getServiceHost()).thenReturn(hostName);
         when(serviceRecord.getServicePort()).thenReturn(port);
 
-        MdnsResponse response = spy(new MdnsResponse(0));
+        MdnsResponse response = spy(new MdnsResponse(0, interfaceIndex, network));
 
         MdnsInetAddressRecord inetAddressRecord = mock(MdnsInetAddressRecord.class);
         if (host.contains(":")) {
             when(inetAddressRecord.getInet6Address())
                     .thenReturn((Inet6Address) Inet6Address.getByName(host));
             response.setInet6AddressRecord(inetAddressRecord);
-            response.setInterfaceIndex(interfaceIndex);
         } else {
             when(inetAddressRecord.getInet4Address())
                     .thenReturn((Inet4Address) Inet4Address.getByName(host));
             response.setInet4AddressRecord(inetAddressRecord);
-            response.setInterfaceIndex(interfaceIndex);
         }
 
         MdnsTextRecord textRecord = mock(MdnsTextRecord.class);
@@ -1011,10 +1040,10 @@
             int port,
             @NonNull String subtype,
             @NonNull Map<String, String> textAttributes,
-            int interfaceIndex)
+            int interfaceIndex,
+            Network network)
             throws Exception {
-        MdnsResponse response = new MdnsResponse(0);
-        response.setInterfaceIndex(interfaceIndex);
+        MdnsResponse response = new MdnsResponse(0, interfaceIndex, network);
 
         // Set PTR record
         final MdnsPointerRecord pointerRecord = new MdnsPointerRecord(