Merge "Use NsdServiceInfo ifIndex in MdnsDiscoveryManager" into main
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index aca386f..46c435f 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -1105,8 +1105,13 @@
                             maybeStartMonitoringSockets();
                             final MdnsListener listener = new ResolutionListener(clientRequestId,
                                     transactionId, resolveServiceType);
+                            // The interface index is only set for interfaces that have no
+                            // Network; otherwise the Network identifies the interface.
+                            final int ifIndex = info.getNetwork() != null
+                                    ? 0 : info.getInterfaceIndex();
                             final MdnsSearchOptions options = MdnsSearchOptions.newBuilder()
                                     .setNetwork(info.getNetwork())
+                                    .setInterfaceIndex(ifIndex)
                                     .setQueryMode(mMdnsFeatureFlags.isAggressiveQueryModeEnabled()
                                             ? AGGRESSIVE_QUERY_MODE
                                             : PASSIVE_QUERY_MODE)
@@ -1205,8 +1210,13 @@
                         maybeStartMonitoringSockets();
                         final MdnsListener listener = new ServiceInfoListener(clientRequestId,
                                 transactionId, resolveServiceType);
+                        // The interface index is only set for interfaces that have no
+                        // Network; otherwise the Network identifies the interface.
+                        final int ifIndex = info.getNetwork() != null
+                                ? 0 : info.getInterfaceIndex();
                         final MdnsSearchOptions options = MdnsSearchOptions.newBuilder()
                                 .setNetwork(info.getNetwork())
+                                .setInterfaceIndex(ifIndex)
                                 .setQueryMode(mMdnsFeatureFlags.isAggressiveQueryModeEnabled()
                                         ? AGGRESSIVE_QUERY_MODE
                                         : PASSIVE_QUERY_MODE)
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index 7b0c738..0ab7a76 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -241,11 +241,30 @@
             }
         }
         // Request the network for discovery.
+        // Sockets are requested on all networks even when searchOptions specifies an interface
+        // index (getNetwork() == null, for local interfaces); onSocketCreated then skips sockets
+        // that have a Network or are on another interface. Requesting only the matching socket
+        // would be cleaner.
+        // TODO: avoid requesting extra sockets after migrating P2P and tethering networks to local
+        // NetworkAgents.
         socketClient.notifyNetworkRequested(listener, searchOptions.getNetwork(),
                 new MdnsSocketClientBase.SocketCreationCallback() {
                     @Override
                     public void onSocketCreated(@NonNull SocketKey socketKey) {
                         discoveryExecutor.ensureRunningOnHandlerThread();
+                        final int requestedIfIndex = searchOptions.getInterfaceIndex();
+                        // An interface index in the options only selects interfaces without any
+                        // Network: when the target interface has a Network, that Network is
+                        // expected in the options instead.
+                        final boolean restrictToIfIndex =
+                                searchOptions.getNetwork() == null && requestedIfIndex > 0;
+                        if (restrictToIfIndex && (socketKey.getNetwork() != null
+                                || socketKey.getInterfaceIndex() != requestedIfIndex)) {
+                            sharedLog.i("Skipping " + socketKey + " as ifIndex "
+                                    + requestedIfIndex + " was requested.");
+                            return;
+                        }
+
                         // All listeners of the same service types shares the same
                         // MdnsServiceTypeClient.
                         MdnsServiceTypeClient serviceTypeClient =
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java
index 086094b..73405ab 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java
@@ -59,6 +59,7 @@
                             source.readInt(),
                             source.readInt() == 1,
                             source.readParcelable(null),
+                            source.readInt(),
                             source.readString(),
                             source.readInt() == 1,
                             source.readInt());
@@ -79,6 +80,8 @@
     private final boolean removeExpiredService;
     // The target network for searching. Null network means search on all possible interfaces.
     @Nullable private final Network mNetwork;
+    // If the target interface does not have a Network, set to the interface index, otherwise 0.
+    private final int mInterfaceIndex;
 
     /** Parcelable constructs for a {@link MdnsSearchOptions}. */
     MdnsSearchOptions(
@@ -86,6 +89,7 @@
             int queryMode,
             boolean removeExpiredService,
             @Nullable Network network,
+            int interfaceIndex,
             @Nullable String resolveInstanceName,
             boolean onlyUseIpv6OnIpv6OnlyNetworks,
             int numOfQueriesBeforeBackoff) {
@@ -98,6 +102,7 @@
         this.numOfQueriesBeforeBackoff = numOfQueriesBeforeBackoff;
         this.removeExpiredService = removeExpiredService;
         mNetwork = network;
+        mInterfaceIndex = interfaceIndex;
         this.resolveInstanceName = resolveInstanceName;
     }
 
@@ -148,15 +153,27 @@
     }
 
     /**
-     * Returns the network which the mdns query should target on.
+     * Returns the network which the mdns query should target.
      *
-     * @return the target network or null if search on all possible interfaces.
+     * @return the target network or null to search on all possible interfaces.
      */
     @Nullable
     public Network getNetwork() {
         return mNetwork;
     }
 
+    /**
+     * Returns the interface index which the mdns query should target.
+     *
+     * <p>This is only set when the service is to be searched on an interface that does not have
+     * a Network, in which case {@link #getNetwork()} returns null.
+     *
+     * @return the interface index ({@link java.net.NetworkInterface#getIndex}), or 0 if unset.
+     */
+    public int getInterfaceIndex() {
+        return mInterfaceIndex;
+    }
+
     /**
      * If non-null, queries should try to resolve all records of this specific service, rather than
      * discovering all services.
@@ -177,6 +194,7 @@
         out.writeInt(queryMode);
         out.writeInt(removeExpiredService ? 1 : 0);
         out.writeParcelable(mNetwork, 0);
+        out.writeInt(mInterfaceIndex);
         out.writeString(resolveInstanceName);
         out.writeInt(onlyUseIpv6OnIpv6OnlyNetworks ? 1 : 0);
         out.writeInt(numOfQueriesBeforeBackoff);
@@ -190,6 +208,7 @@
         private int numOfQueriesBeforeBackoff = 3;
         private boolean removeExpiredService;
         private Network mNetwork;
+        private int mInterfaceIndex;
         private String resolveInstanceName;
 
         private Builder() {
@@ -278,6 +297,16 @@
             return this;
         }
 
+        /**
+         * Set the interface index to use for the query, if not querying on a {@link Network}.
+         *
+         * @see MdnsSearchOptions#getInterfaceIndex()
+         */
+        public Builder setInterfaceIndex(int index) {
+            mInterfaceIndex = index;
+            return this;
+        }
+
         /** Builds a {@link MdnsSearchOptions} with the arguments supplied to this builder. */
         public MdnsSearchOptions build() {
             return new MdnsSearchOptions(
@@ -285,6 +314,7 @@
                     queryMode,
                     removeExpiredService,
                     mNetwork,
+                    mInterfaceIndex,
                     resolveInstanceName,
                     onlyUseIpv6OnIpv6OnlyNetworks,
                     numOfQueriesBeforeBackoff);
diff --git a/tests/unit/java/com/android/server/NsdServiceTest.java b/tests/unit/java/com/android/server/NsdServiceTest.java
index 881de56..d91e29c 100644
--- a/tests/unit/java/com/android/server/NsdServiceTest.java
+++ b/tests/unit/java/com/android/server/NsdServiceTest.java
@@ -521,6 +521,56 @@
     }
 
     @Test
+    @EnableCompatChanges(ENABLE_PLATFORM_MDNS_BACKEND)
+    public void testDiscoverOnTetheringDownstream_DiscoveryManager() throws Exception {
+        final NsdManager client = connectClient(mService);
+        final DiscoveryListener discListener = mock(DiscoveryListener.class);
+        client.discoverServices(SERVICE_TYPE, PROTOCOL, discListener);
+        waitForIdle();
+
+        final ArgumentCaptor<MdnsServiceBrowserListener> discoverListenerCaptor =
+                ArgumentCaptor.forClass(MdnsServiceBrowserListener.class);
+        final InOrder discManagerOrder = inOrder(mDiscoveryManager);
+        final String serviceTypeWithLocalDomain = SERVICE_TYPE + ".local";
+        discManagerOrder.verify(mDiscoveryManager).registerListener(eq(serviceTypeWithLocalDomain),
+                discoverListenerCaptor.capture(), any());
+
+        final int interfaceIdx = 123; // Interface with no Network (tethering downstream).
+        final MdnsServiceInfo mockServiceInfo = new MdnsServiceInfo(
+                SERVICE_NAME, /* serviceInstanceName */
+                serviceTypeWithLocalDomain.split("\\."), /* serviceType */
+                List.of(), /* subtypes */
+                new String[] {"android", "local"}, /* hostName */
+                12345, /* port */
+                List.of(IPV4_ADDRESS),
+                List.of(IPV6_ADDRESS),
+                List.of(), /* textStrings */
+                List.of(), /* textEntries */
+                interfaceIdx, /* interfaceIndex */
+                null /* network */,
+                Instant.MAX /* expirationTime */);
+
+        // Verify the service is found with the interface index and no Network.
+        discoverListenerCaptor.getValue().onServiceNameDiscovered(
+                mockServiceInfo, false /* isServiceFromCache */);
+        final ArgumentCaptor<NsdServiceInfo> foundInfoCaptor =
+                ArgumentCaptor.forClass(NsdServiceInfo.class);
+        verify(discListener, timeout(TIMEOUT_MS)).onServiceFound(foundInfoCaptor.capture());
+        final NsdServiceInfo foundInfo = foundInfoCaptor.getValue();
+        assertNull(foundInfo.getNetwork());
+        assertEquals(interfaceIdx, foundInfo.getInterfaceIndex());
+
+        // Resolving or registering a callback on the found service should use its interface index
+        client.resolveService(foundInfo, mock(ResolveListener.class));
+        client.registerServiceInfoCallback(foundInfo, Runnable::run,
+                mock(ServiceInfoCallback.class));
+        waitForIdle();
+
+        discManagerOrder.verify(mDiscoveryManager, times(2)).registerListener(any(), any(), argThat(
+                o -> o.getNetwork() == null && o.getInterfaceIndex() == interfaceIdx));
+    }
+
+    @Test
     @DisableCompatChanges(ENABLE_PLATFORM_MDNS_BACKEND)
     @DevSdkIgnoreRule.IgnoreAfter(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
     public void testDiscoverOnBlackholeNetwork() throws Exception {
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
index 5251e2a..b5c0132 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
@@ -18,6 +18,8 @@
 
 import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
 
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
@@ -65,8 +67,9 @@
     private static final String SERVICE_TYPE_2 = "_test._tcp.local";
     private static final Network NETWORK_1 = Mockito.mock(Network.class);
     private static final Network NETWORK_2 = Mockito.mock(Network.class);
+    private static final int INTERFACE_INDEX_NULL_NETWORK = 123;
     private static final SocketKey SOCKET_KEY_NULL_NETWORK =
-            new SocketKey(null /* network */, 999 /* interfaceIndex */);
+            new SocketKey(null /* network */, INTERFACE_INDEX_NULL_NETWORK);
     private static final SocketKey SOCKET_KEY_NETWORK_1 =
             new SocketKey(NETWORK_1, 998 /* interfaceIndex */);
     private static final SocketKey SOCKET_KEY_NETWORK_2 =
@@ -97,6 +100,8 @@
     private HandlerThread thread;
     private Handler handler;
 
+    private int createdServiceTypeClientCount; // Counts createServiceTypeClient calls.
+
     @Before
     public void setUp() {
         MockitoAnnotations.initMocks(this);
@@ -106,11 +111,13 @@
         handler = new Handler(thread.getLooper());
         doReturn(thread.getLooper()).when(socketClient).getLooper();
         doReturn(true).when(socketClient).supportsRequestingSpecificNetworks();
+        createdServiceTypeClientCount = 0;
         discoveryManager = new MdnsDiscoveryManager(executorProvider, socketClient,
                 sharedLog, MdnsFeatureFlags.newBuilder().build()) {
                     @Override
                     MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType,
                             @NonNull SocketKey socketKey) {
+                        createdServiceTypeClientCount++;
                         final Pair<String, SocketKey> perSocketServiceType =
                                 Pair.create(serviceType, socketKey);
                         if (perSocketServiceType.equals(PER_SOCKET_SERVICE_TYPE_1_NULL_NETWORK)) {
@@ -128,6 +135,7 @@
                                 PER_SOCKET_SERVICE_TYPE_2_NETWORK_2)) {
                             return mockServiceTypeClientType2Network2;
                         }
+                        fail("Unexpected perSocketServiceType: " + perSocketServiceType);
                         return null;
                     }
                 };
@@ -324,7 +332,6 @@
 
         // Receive a response, it should be processed on the client.
         final MdnsPacket response = createMdnsPacket(SERVICE_TYPE_1);
-        final int ifIndex = 1;
         runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).processResponse(
                 response, SOCKET_KEY_NULL_NETWORK);
@@ -350,6 +357,39 @@
         verify(socketClient, never()).stopDiscovery();
     }
 
+    @Test
+    public void testInterfaceIndexRequested_OnlyUsesSelectedInterface() throws IOException {
+        final MdnsSearchOptions searchOptions =
+                MdnsSearchOptions.newBuilder()
+                        .setNetwork(null /* network */)
+                        .setInterfaceIndex(INTERFACE_INDEX_NULL_NETWORK)
+                        .build();
+
+        final SocketCreationCallback callback = expectSocketCreationCallback(
+                SERVICE_TYPE_1, mockListenerOne, searchOptions);
+        final SocketKey unusedIfaceKey = new SocketKey(null, INTERFACE_INDEX_NULL_NETWORK + 1);
+        final SocketKey matchingIfaceWithNetworkKey =
+                new SocketKey(Mockito.mock(Network.class), INTERFACE_INDEX_NULL_NETWORK);
+        runOnHandler(() -> {
+            callback.onSocketCreated(unusedIfaceKey);
+            callback.onSocketCreated(matchingIfaceWithNetworkKey);
+            callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK);
+            callback.onSocketCreated(SOCKET_KEY_NETWORK_1);
+        });
+        // Only the Network-less socket on INTERFACE_INDEX_NULL_NETWORK gets a client
+        verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(
+                mockListenerOne, searchOptions);
+        assertEquals(1, createdServiceTypeClientCount);
+
+        runOnHandler(() -> {
+            callback.onSocketDestroyed(SOCKET_KEY_NETWORK_1);
+            callback.onSocketDestroyed(SOCKET_KEY_NULL_NETWORK);
+            callback.onSocketDestroyed(matchingIfaceWithNetworkKey);
+            callback.onSocketDestroyed(unusedIfaceKey);
+        });
+        verify(mockServiceTypeClientType1NullNetwork).notifySocketDestroyed();
+    }
+
     private MdnsPacket createMdnsPacket(String serviceType) {
         final String[] type = TextUtils.split(serviceType, "\\.");
         final ArrayList<String> name = new ArrayList<>(type.length + 1);