Merge "[Feature sync] Propagate network interface index to MdnsServiceInfo"
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsConfigs.java b/service/mdns/com/android/server/connectivity/mdns/MdnsConfigs.java
index 35a685d..41abba7 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsConfigs.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsConfigs.java
@@ -97,4 +97,12 @@
     public static boolean allowSearchOptionsToRemoveExpiredService() {
         return false;
     }
+
+    /**
+     * Whether the index of the network interface on which an mDNS packet was received should be
+     * passed to the response decoder and propagated to {@link MdnsServiceInfo}.
+     */
+    public static boolean allowNetworkInterfaceIndexPropagation() {
+        return true;
+    }
 }
\ No newline at end of file
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsResponse.java b/service/mdns/com/android/server/connectivity/mdns/MdnsResponse.java
index 9f3894f..c94e3c6 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsResponse.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsResponse.java
@@ -35,6 +35,7 @@
     private MdnsInetAddressRecord inet4AddressRecord;
     private MdnsInetAddressRecord inet6AddressRecord;
     private long lastUpdateTime;
+    private int interfaceIndex = MdnsSocket.INTERFACE_INDEX_UNSPECIFIED;
 
     /** Constructs a new, empty response. */
     public MdnsResponse(long now) {
@@ -203,6 +204,21 @@
         return true;
     }
 
+    /**
+     * Updates the index of the network interface at which this response was received. Pass
+     * {@link MdnsSocket#INTERFACE_INDEX_UNSPECIFIED} if the interface is not known.
+     */
+    public synchronized void setInterfaceIndex(int interfaceIndex) {
+        this.interfaceIndex = interfaceIndex;
+    }
+
+    /**
+     * Returns the index of the network interface at which this response was received, or
+     * {@link MdnsSocket#INTERFACE_INDEX_UNSPECIFIED} if it has not been set.
+     */
+    public synchronized int getInterfaceIndex() {
+        return interfaceIndex;
+    }
 
     /** Gets the IPv6 address record. */
     public synchronized MdnsInetAddressRecord getInet6AddressRecord() {
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java b/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
index 3e5fc42..57b241e 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
@@ -92,9 +92,12 @@
      * the responses for completeness; the caller should do that.
      *
      * @param packet The packet to read from.
-     * @return A list of mDNS responses, or null if the packet contained no appropriate responses.
+     * @param interfaceIndex the index of the network interface on which the packet was received,
+     *     or {@link MdnsSocket#INTERFACE_INDEX_UNSPECIFIED} if it is not known.
+     * @return {@link #SUCCESS} if the packet was decoded successfully, otherwise an error code.
      */
-    public int decode(@NonNull DatagramPacket packet, @NonNull List<MdnsResponse> responses) {
+    public int decode(@NonNull DatagramPacket packet, @NonNull List<MdnsResponse> responses,
+            int interfaceIndex) {
         MdnsPacketReader reader = new MdnsPacketReader(packet);
 
         List<MdnsRecord> records;
@@ -281,8 +284,10 @@
                 MdnsResponse response = findResponseWithHostName(responses, inetRecord.getName());
                 if (inetRecord.getInet4Address() != null && response != null) {
                     response.setInet4AddressRecord(inetRecord);
+                    response.setInterfaceIndex(interfaceIndex);
                 } else if (inetRecord.getInet6Address() != null && response != null) {
                     response.setInet6AddressRecord(inetRecord);
+                    response.setInterfaceIndex(interfaceIndex);
                 }
             }
         }
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsServiceInfo.java b/service/mdns/com/android/server/connectivity/mdns/MdnsServiceInfo.java
index d142280..7d645e3 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsServiceInfo.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsServiceInfo.java
@@ -57,7 +57,8 @@
                             source.readString(),
                             source.readString(),
                             source.createStringArrayList(),
-                            source.createTypedArrayList(TextEntry.CREATOR));
+                            source.createTypedArrayList(TextEntry.CREATOR),
+                            source.readInt());
                 }
 
                 @Override
@@ -76,6 +77,7 @@
     final List<String> textStrings;
     @Nullable
     final List<TextEntry> textEntries;
+    private final int interfaceIndex;
 
     private final Map<String, byte[]> attributes;
 
@@ -98,7 +100,32 @@
                 ipv4Address,
                 ipv6Address,
                 textStrings,
-                /* textEntries= */ null);
+                /* textEntries= */ null,
+                /* interfaceIndex= */ MdnsSocket.INTERFACE_INDEX_UNSPECIFIED);
+    }
+
+    /** Constructs a {@link MdnsServiceInfo} object with an unknown network interface index. */
+    public MdnsServiceInfo(
+            String serviceInstanceName,
+            String[] serviceType,
+            List<String> subtypes,
+            String[] hostName,
+            int port,
+            String ipv4Address,
+            String ipv6Address,
+            List<String> textStrings,
+            @Nullable List<TextEntry> textEntries) {
+        this(
+                serviceInstanceName,
+                serviceType,
+                subtypes,
+                hostName,
+                port,
+                ipv4Address,
+                ipv6Address,
+                textStrings,
+                textEntries,
+                /* interfaceIndex= */ MdnsSocket.INTERFACE_INDEX_UNSPECIFIED);
     }
 
     /**
@@ -115,7 +142,8 @@
             String ipv4Address,
             String ipv6Address,
             List<String> textStrings,
-            @Nullable List<TextEntry> textEntries) {
+            @Nullable List<TextEntry> textEntries,
+            int interfaceIndex) {
         this.serviceInstanceName = serviceInstanceName;
         this.serviceType = serviceType;
         this.subtypes = new ArrayList<>();
@@ -149,6 +177,7 @@
             }
         }
         this.attributes = Collections.unmodifiableMap(attributes);
+        this.interfaceIndex = interfaceIndex;
     }
 
     private static List<TextEntry> parseTextStrings(List<String> textStrings) {
@@ -206,6 +235,14 @@
     }
 
     /**
+     * Returns the index of the network interface at which the response describing this service
+     * was received, or -1 if the index is not known.
+     */
+    public int getInterfaceIndex() {
+        return interfaceIndex;
+    }
+
+    /**
      * Returns attribute value for {@code key} as a UTF-8 string. It's the caller who must make sure
      * that the value of {@code key} is indeed a UTF-8 string. {@code null} will be returned if no
      * attribute value exists for {@code key}.
@@ -253,6 +290,7 @@
         out.writeString(ipv6Address);
         out.writeStringList(textStrings);
         out.writeTypedList(textEntries);
+        out.writeInt(interfaceIndex);
     }
 
     @Override
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index 3747323..be993e2 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -111,7 +111,8 @@
                 ipv4Address,
                 ipv6Address,
                 response.getTextRecord().getStrings(),
-                response.getTextRecord().getEntries());
+                response.getTextRecord().getEntries(),
+                response.getInterfaceIndex());
     }
 
     /**
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java b/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java
index 34db7f0..3442430 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsSocket.java
@@ -19,11 +19,13 @@
 import android.annotation.NonNull;
 
 import com.android.internal.annotations.VisibleForTesting;
+import com.android.server.connectivity.mdns.util.MdnsLogger;
 
 import java.io.IOException;
 import java.net.DatagramPacket;
 import java.net.InetSocketAddress;
 import java.net.MulticastSocket;
+import java.net.SocketException;
 import java.util.List;
 
 /**
@@ -35,6 +37,9 @@
 // TODO(b/242631897): Resolve nullness suppression.
 @SuppressWarnings("nullness")
 public class MdnsSocket {
+    private static final MdnsLogger LOGGER = new MdnsLogger("MdnsSocket");
+
+    static final int INTERFACE_INDEX_UNSPECIFIED = -1;
     private static final InetSocketAddress MULTICAST_IPV4_ADDRESS =
             new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
     private static final InetSocketAddress MULTICAST_IPV6_ADDRESS =
@@ -103,6 +108,19 @@
         multicastNetworkInterfaceProvider.stopWatchingConnectivityChanges();
     }
 
+    /** Returns the bound interface index, or {@link #INTERFACE_INDEX_UNSPECIFIED} if unknown. */
+    public int getInterfaceIndex() {
+        try {
+            // Without an interface set, getNetworkInterface() returns a placeholder whose index
+            // (0 on OpenJDK) is not a real interface index; report that as unspecified.
+            int index = multicastSocket.getNetworkInterface().getIndex();
+            return index > 0 ? index : INTERFACE_INDEX_UNSPECIFIED;
+        } catch (SocketException e) {
+            LOGGER.e("Failed to retrieve interface index for socket.", e);
+            return INTERFACE_INDEX_UNSPECIFIED;
+        }
+    }
+
     @VisibleForTesting
     MulticastSocket createMulticastSocket(int port) throws IOException {
         return new MulticastSocket(port);
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java
index 010f761..6cbe3c7 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsSocketClient.java
@@ -79,6 +79,8 @@
     private final boolean checkMulticastResponse = MdnsConfigs.checkMulticastResponse();
     private final long checkMulticastResponseIntervalMs =
             MdnsConfigs.checkMulticastResponseIntervalMs();
+    private final boolean propagateInterfaceIndex =
+            MdnsConfigs.allowNetworkInterfaceIndexPropagation();
     private final Object socketLock = new Object();
     private final Object timerObject = new Object();
     // If multicast response was received in the current session. The value is reset in the
@@ -382,7 +384,12 @@
 
                 if (!shouldStopSocketLoop) {
                     String responseType = socket == multicastSocket ? MULTICAST_TYPE : UNICAST_TYPE;
-                    processResponsePacket(packet, responseType);
+                    processResponsePacket(
+                            packet,
+                            responseType,
+                            /* interfaceIndex= */ (socket == null || !propagateInterfaceIndex)
+                                    ? MdnsSocket.INTERFACE_INDEX_UNSPECIFIED
+                                    : socket.getInterfaceIndex());
                 }
             } catch (IOException e) {
                 if (!shouldStopSocketLoop) {
@@ -393,12 +400,12 @@
         LOGGER.log("Receive thread stopped.");
     }
 
-    private int processResponsePacket(@NonNull DatagramPacket packet, String responseType)
-            throws IOException {
+    private int processResponsePacket(
+            @NonNull DatagramPacket packet, String responseType, int interfaceIndex) {
         int packetNumber = ++receivedPacketNumber;
 
         List<MdnsResponse> responses = new LinkedList<>();
-        int errorCode = responseDecoder.decode(packet, responses);
+        int errorCode = responseDecoder.decode(packet, responses, interfaceIndex);
         if (errorCode == MdnsResponseDecoder.SUCCESS) {
             if (responseType.equals(MULTICAST_TYPE)) {
                 receivedMulticastResponse = true;
@@ -414,7 +421,8 @@
             }
             for (MdnsResponse response : responses) {
                 String serviceInstanceName = response.getServiceInstanceName();
-                LOGGER.log("mDNS %s response received: %s", responseType, serviceInstanceName);
+                LOGGER.log("mDNS %s response received: %s at ifIndex %d", responseType,
+                        serviceInstanceName, interfaceIndex);
                 if (callback != null) {
                     callback.onResponseReceived(response);
                 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java
index ea9156c..8d0ace5 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java
@@ -106,9 +106,9 @@
             + "63616C0000018001000000780004C0A8010A000001800100000078"
             + "0004C0A8010A00000000000000");
 
-    private static final String DUMMY_CAST_SERVICE_NAME = "_googlecast";
-    private static final String[] DUMMY_CAST_SERVICE_TYPE =
-            new String[] {DUMMY_CAST_SERVICE_NAME, "_tcp", "local"};
+    private static final String CAST_SERVICE_NAME = "_googlecast";
+    private static final String[] CAST_SERVICE_TYPE =
+            new String[] {CAST_SERVICE_NAME, "_tcp", "local"};
 
     private final List<MdnsResponse> responses = new LinkedList<>();
 
@@ -116,13 +116,13 @@
 
     @Before
     public void setUp() {
-        MdnsResponseDecoder decoder = new MdnsResponseDecoder(mClock, DUMMY_CAST_SERVICE_TYPE);
+        MdnsResponseDecoder decoder = new MdnsResponseDecoder(mClock, CAST_SERVICE_TYPE);
         assertNotNull(data);
         DatagramPacket packet = new DatagramPacket(data, data.length);
         packet.setSocketAddress(
                 new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT));
         responses.clear();
-        int errorCode = decoder.decode(packet, responses);
+        int errorCode = decoder.decode(packet, responses, MdnsSocket.INTERFACE_INDEX_UNSPECIFIED);
         assertEquals(MdnsResponseDecoder.SUCCESS, errorCode);
         assertEquals(1, responses.size());
     }
@@ -135,7 +135,7 @@
         packet.setSocketAddress(
                 new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT));
         responses.clear();
-        int errorCode = decoder.decode(packet, responses);
+        int errorCode = decoder.decode(packet, responses, MdnsSocket.INTERFACE_INDEX_UNSPECIFIED);
         assertEquals(MdnsResponseDecoder.SUCCESS, errorCode);
         assertEquals(2, responses.size());
     }
@@ -153,7 +153,7 @@
 
         MdnsServiceRecord serviceRecord = response.getServiceRecord();
         String serviceName = serviceRecord.getServiceName();
-        assertEquals(DUMMY_CAST_SERVICE_NAME, serviceName);
+        assertEquals(CAST_SERVICE_NAME, serviceName);
 
         String serviceInstanceName = serviceRecord.getServiceInstanceName();
         assertEquals("Johnny's Chromecast", serviceInstanceName);
@@ -187,14 +187,14 @@
 
     @Test
     public void testDecodeIPv6AnswerPacket() throws IOException {
-        MdnsResponseDecoder decoder = new MdnsResponseDecoder(mClock, DUMMY_CAST_SERVICE_TYPE);
+        MdnsResponseDecoder decoder = new MdnsResponseDecoder(mClock, CAST_SERVICE_TYPE);
         assertNotNull(data6);
         DatagramPacket packet = new DatagramPacket(data6, data6.length);
         packet.setSocketAddress(
                 new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT));
 
         responses.clear();
-        int errorCode = decoder.decode(packet, responses);
+        int errorCode = decoder.decode(packet, responses, MdnsSocket.INTERFACE_INDEX_UNSPECIFIED);
         assertEquals(MdnsResponseDecoder.SUCCESS, errorCode);
 
         MdnsResponse response = responses.get(0);
@@ -234,4 +234,19 @@
         response.setTextRecord(null);
         assertFalse(response.isComplete());
     }
+
+    @Test
+    public void decode_withInterfaceIndex_populatesInterfaceIndex() {
+        MdnsResponseDecoder decoder = new MdnsResponseDecoder(mClock, CAST_SERVICE_TYPE);
+        assertNotNull(data6);
+        DatagramPacket packet = new DatagramPacket(data6, data6.length);
+        packet.setSocketAddress(
+                new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT));
+
+        responses.clear();
+        int errorCode = decoder.decode(packet, responses, /* interfaceIndex= */ 10);
+        assertEquals(MdnsResponseDecoder.SUCCESS, errorCode);
+        assertEquals(1, responses.size());
+        assertEquals(10, responses.get(0).getInterfaceIndex());
+    }
 }
\ No newline at end of file
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java
index ae16f2b..771e42c 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseTests.java
@@ -222,6 +222,19 @@
     }
 
     @Test
+    public void getInterfaceIndex_returnsDefaultValue() {
+        MdnsResponse response = new MdnsResponse(/* now= */ 0);
+        assertEquals(MdnsSocket.INTERFACE_INDEX_UNSPECIFIED, response.getInterfaceIndex());
+    }
+
+    @Test
+    public void getInterfaceIndex_afterSet_returnsValue() {
+        MdnsResponse response = new MdnsResponse(/* now= */ 0);
+        response.setInterfaceIndex(5);
+        assertEquals(5, response.getInterfaceIndex());
+    }
+
+    @Test
     public void mergeRecordsFrom_indicates_change_on_ipv4_address() throws IOException {
         MdnsResponse response = makeMdnsResponse(
                 0,
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceInfoTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceInfoTest.java
index 79d6046..d3934c2 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceInfoTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceInfoTest.java
@@ -116,6 +116,40 @@
     }
 
     @Test
+    public void getInterfaceIndex_constructorWithDefaultValues_returnsMinusOne() {
+        MdnsServiceInfo info =
+                new MdnsServiceInfo(
+                        "my-mdns-service",
+                        new String[] {"_googlecast", "_tcp"},
+                        List.of(),
+                        new String[] {"my-host", "local"},
+                        12345,
+                        "192.168.1.1",
+                        "2001::1",
+                        List.of());
+
+        assertEquals(-1, info.getInterfaceIndex());
+    }
+
+    @Test
+    public void getInterfaceIndex_constructorWithInterfaceIndex_returnsProvidedIndex() {
+        MdnsServiceInfo info =
+                new MdnsServiceInfo(
+                        "my-mdns-service",
+                        new String[] {"_googlecast", "_tcp"},
+                        List.of(),
+                        new String[] {"my-host", "local"},
+                        12345,
+                        "192.168.1.1",
+                        "2001::1",
+                        List.of(),
+                        /* textEntries= */ null,
+                        /* interfaceIndex= */ 20);
+
+        assertEquals(20, info.getInterfaceIndex());
+    }
+
+    @Test
     public void parcelable_canBeParceledAndUnparceled() {
         Parcel parcel = Parcel.obtain();
         MdnsServiceInfo beforeParcel =
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index c84c386..6b10c71 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -401,7 +401,8 @@
                         ipV4Address,
                         5353,
                         Collections.singletonList("ABCDE"),
-                        Collections.emptyMap());
+                        Collections.emptyMap(),
+                        /* interfaceIndex= */ 20);
         client.processResponse(initialResponse);
 
         // Process a second response with a different port and updated text attributes.
@@ -411,7 +412,8 @@
                         ipV4Address,
                         5354,
                         Collections.singletonList("ABCDE"),
-                        Collections.singletonMap("key", "value"));
+                        Collections.singletonMap("key", "value"),
+                        /* interfaceIndex= */ 20);
         client.processResponse(secondResponse);
 
         // Verify onServiceFound was called once for the initial response.
@@ -422,6 +424,7 @@
         assertEquals(initialServiceInfo.getPort(), 5353);
         assertEquals(initialServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertNull(initialServiceInfo.getAttributeByKey("key"));
+        assertEquals(initialServiceInfo.getInterfaceIndex(), 20);
 
         // Verify onServiceUpdated was called once for the second response.
         verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -432,6 +435,7 @@
         assertTrue(updatedServiceInfo.hasSubtypes());
         assertEquals(updatedServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertEquals(updatedServiceInfo.getAttributeByKey("key"), "value");
+        assertEquals(updatedServiceInfo.getInterfaceIndex(), 20);
     }
 
     @Test
@@ -446,7 +450,8 @@
                         ipV6Address,
                         5353,
                         Collections.singletonList("ABCDE"),
-                        Collections.emptyMap());
+                        Collections.emptyMap(),
+                        /* interfaceIndex= */ 20);
         client.processResponse(initialResponse);
 
         // Process a second response with a different port and updated text attributes.
@@ -456,7 +461,8 @@
                         ipV6Address,
                         5354,
                         Collections.singletonList("ABCDE"),
-                        Collections.singletonMap("key", "value"));
+                        Collections.singletonMap("key", "value"),
+                        /* interfaceIndex= */ 20);
         client.processResponse(secondResponse);
 
         System.out.println("secondResponses ip"
@@ -470,6 +476,7 @@
         assertEquals(initialServiceInfo.getPort(), 5353);
         assertEquals(initialServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertNull(initialServiceInfo.getAttributeByKey("key"));
+        assertEquals(initialServiceInfo.getInterfaceIndex(), 20);
 
         // Verify onServiceUpdated was called once for the second response.
         verify(mockListenerOne).onServiceUpdated(serviceInfoCaptor.capture());
@@ -480,6 +487,7 @@
         assertTrue(updatedServiceInfo.hasSubtypes());
         assertEquals(updatedServiceInfo.getSubtypes(), Collections.singletonList("ABCDE"));
         assertEquals(updatedServiceInfo.getAttributeByKey("key"), "value");
+        assertEquals(updatedServiceInfo.getInterfaceIndex(), 20);
     }
 
     @Test
@@ -727,7 +735,6 @@
         }
     }
 
-    // Creates a complete mDNS response.
     private MdnsResponse createResponse(
             @NonNull String serviceInstanceName,
             @NonNull String host,
@@ -735,6 +742,19 @@
             @NonNull List<String> subtypes,
             @NonNull Map<String, String> textAttributes)
             throws Exception {
+        return createResponse(serviceInstanceName, host, port, subtypes, textAttributes,
+                /* interfaceIndex= */ -1);
+    }
+
+    // Creates a complete mDNS response.
+    private MdnsResponse createResponse(
+            @NonNull String serviceInstanceName,
+            @NonNull String host,
+            int port,
+            @NonNull List<String> subtypes,
+            @NonNull Map<String, String> textAttributes,
+            int interfaceIndex)
+            throws Exception {
         String[] hostName = new String[]{"hostname"};
         MdnsServiceRecord serviceRecord = mock(MdnsServiceRecord.class);
         when(serviceRecord.getServiceHost()).thenReturn(hostName);
@@ -747,10 +767,12 @@
             when(inetAddressRecord.getInet6Address())
                     .thenReturn((Inet6Address) Inet6Address.getByName(host));
             response.setInet6AddressRecord(inetAddressRecord);
+            response.setInterfaceIndex(interfaceIndex);
         } else {
             when(inetAddressRecord.getInet4Address())
                     .thenReturn((Inet4Address) Inet4Address.getByName(host));
             response.setInet4AddressRecord(inetAddressRecord);
+            response.setInterfaceIndex(interfaceIndex);
         }
 
         MdnsTextRecord textRecord = mock(MdnsTextRecord.class);
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
index 21ed7eb..f84ebfb 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
@@ -18,6 +18,7 @@
 
 import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
 
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
@@ -25,6 +26,7 @@
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
@@ -45,6 +47,7 @@
 import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.ArgumentMatchers;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
@@ -490,4 +493,58 @@
         assertFalse(mdnsClient.receivedUnicastResponse);
         assertFalse(mdnsClient.cannotReceiveMulticastResponse.get());
     }
+
+    @Test
+    public void startDiscovery_andPropagateInterfaceIndex_includesInterfaceIndex()
+            throws Exception {
+        // Relies on MdnsConfigs.allowNetworkInterfaceIndexPropagation() returning true.
+
+        when(mockMulticastSocket.getInterfaceIndex()).thenReturn(21);
+        mdnsClient =
+                new MdnsSocketClient(mContext, mockMulticastLock) {
+                    @Override
+                    MdnsSocket createMdnsSocket(int port) {
+                        if (port == MdnsConstants.MDNS_PORT) {
+                            return mockMulticastSocket;
+                        }
+                        return mockUnicastSocket;
+                    }
+                };
+        mdnsClient.setCallback(mockCallback);
+        mdnsClient.startDiscovery();
+
+        ArgumentCaptor<MdnsResponse> mdnsResponseCaptor =
+                ArgumentCaptor.forClass(MdnsResponse.class);
+        verify(mockCallback, timeout(TIMEOUT).atLeast(1))
+                .onResponseReceived(mdnsResponseCaptor.capture());
+        assertEquals(21, mdnsResponseCaptor.getValue().getInterfaceIndex());
+    }
+
+    @Test
+    @Ignore("MdnsConfigs is not configurable currently.")
+    public void startDiscovery_andDoNotPropagateInterfaceIndex_doesNotIncludeInterfaceIndex()
+            throws Exception {
+        // Needs MdnsConfigs.allowNetworkInterfaceIndexPropagation() to return false (see @Ignore).
+
+        when(mockMulticastSocket.getInterfaceIndex()).thenReturn(21);
+        mdnsClient =
+                new MdnsSocketClient(mContext, mockMulticastLock) {
+                    @Override
+                    MdnsSocket createMdnsSocket(int port) {
+                        if (port == MdnsConstants.MDNS_PORT) {
+                            return mockMulticastSocket;
+                        }
+                        return mockUnicastSocket;
+                    }
+                };
+        mdnsClient.setCallback(mockCallback);
+        mdnsClient.startDiscovery();
+
+        ArgumentCaptor<MdnsResponse> mdnsResponseCaptor =
+                ArgumentCaptor.forClass(MdnsResponse.class);
+        verify(mockCallback, timeout(TIMEOUT).atLeast(1))
+                .onResponseReceived(mdnsResponseCaptor.capture());
+        verify(mockMulticastSocket, never()).getInterfaceIndex();
+        assertEquals(MdnsSocket.INTERFACE_INDEX_UNSPECIFIED,
+                mdnsResponseCaptor.getValue().getInterfaceIndex());
 }
\ No newline at end of file