Use NsdServiceInfo ifIndex in MdnsDiscoveryManager

NsdServiceInfo may contain an interface index if it was obtained from a
previous discovery callback which found a service on an interface that
does not belong to any Network.

In that case, subsequent resolve or registerServiceInfoCallback queries
using that NsdServiceInfo should only be done on the matching interface.
Apps may otherwise receive replies from other networks, with different
interface addresses.

As a simple fix, still request sockets from all interfaces in that case,
but only create the MdnsServiceTypeClients on matching sockets.

The sockets are created / cleaned up based on the lifecycle of the
requests, which does not change.

Fixes: 332262554
Test: atest
Change-Id: I2bbb498d4a0d37725f019a4cf5e367fef76fa0e8
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index aca386f..46c435f 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -1105,8 +1105,11 @@
                             maybeStartMonitoringSockets();
                             final MdnsListener listener = new ResolutionListener(clientRequestId,
                                     transactionId, resolveServiceType);
+                            final int ifaceIdx = info.getNetwork() != null
+                                    ? 0 : info.getInterfaceIndex();
                             final MdnsSearchOptions options = MdnsSearchOptions.newBuilder()
                                     .setNetwork(info.getNetwork())
+                                    .setInterfaceIndex(ifaceIdx)
                                     .setQueryMode(mMdnsFeatureFlags.isAggressiveQueryModeEnabled()
                                             ? AGGRESSIVE_QUERY_MODE
                                             : PASSIVE_QUERY_MODE)
@@ -1205,8 +1208,11 @@
                         maybeStartMonitoringSockets();
                         final MdnsListener listener = new ServiceInfoListener(clientRequestId,
                                 transactionId, resolveServiceType);
+                        final int ifIndex = info.getNetwork() != null
+                                ? 0 : info.getInterfaceIndex();
                         final MdnsSearchOptions options = MdnsSearchOptions.newBuilder()
                                 .setNetwork(info.getNetwork())
+                                .setInterfaceIndex(ifIndex)
                                 .setQueryMode(mMdnsFeatureFlags.isAggressiveQueryModeEnabled()
                                         ? AGGRESSIVE_QUERY_MODE
                                         : PASSIVE_QUERY_MODE)
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index 7b0c738..0ab7a76 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -241,11 +241,30 @@
             }
         }
         // Request the network for discovery.
+        // This requests sockets on all networks even if the searchOptions have a given interface
+        // index (with getNetwork==null, for local interfaces), and only uses matching interfaces
+        // in that case. While this is a simple solution to only use matching sockets, a better
+        // practice would be to only request the correct socket for discovery.
+        // TODO: avoid requesting extra sockets after migrating P2P and tethering networks to local
+        // NetworkAgents.
         socketClient.notifyNetworkRequested(listener, searchOptions.getNetwork(),
                 new MdnsSocketClientBase.SocketCreationCallback() {
                     @Override
                     public void onSocketCreated(@NonNull SocketKey socketKey) {
                         discoveryExecutor.ensureRunningOnHandlerThread();
+                        final int searchInterfaceIndex = searchOptions.getInterfaceIndex();
+                        if (searchOptions.getNetwork() == null
+                                && searchInterfaceIndex > 0
+                                // The interface index in options should only match interfaces that
+                                // do not have any Network; a matching Network should be provided
+                                // otherwise.
+                                && (socketKey.getNetwork() != null
+                                    || socketKey.getInterfaceIndex() != searchInterfaceIndex)) {
+                            sharedLog.i("Skipping " + socketKey + " as ifIndex "
+                                    + searchInterfaceIndex + " was requested.");
+                            return;
+                        }
+
                         // All listeners of the same service types shares the same
                         // MdnsServiceTypeClient.
                         MdnsServiceTypeClient serviceTypeClient =
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java
index 086094b..73405ab 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSearchOptions.java
@@ -59,6 +59,7 @@
                             source.readInt(),
                             source.readInt() == 1,
                             source.readParcelable(null),
+                            source.readInt(),
                             source.readString(),
                             source.readInt() == 1,
                             source.readInt());
@@ -79,6 +80,8 @@
     private final boolean removeExpiredService;
     // The target network for searching. Null network means search on all possible interfaces.
     @Nullable private final Network mNetwork;
+    // If the target interface does not have a Network, set to the interface index, otherwise unset.
+    private final int mInterfaceIndex;
 
     /** Parcelable constructs for a {@link MdnsSearchOptions}. */
     MdnsSearchOptions(
@@ -86,6 +89,7 @@
             int queryMode,
             boolean removeExpiredService,
             @Nullable Network network,
+            int interfaceIndex,
             @Nullable String resolveInstanceName,
             boolean onlyUseIpv6OnIpv6OnlyNetworks,
             int numOfQueriesBeforeBackoff) {
@@ -98,6 +102,7 @@
         this.numOfQueriesBeforeBackoff = numOfQueriesBeforeBackoff;
         this.removeExpiredService = removeExpiredService;
         mNetwork = network;
+        mInterfaceIndex = interfaceIndex;
         this.resolveInstanceName = resolveInstanceName;
     }
 
@@ -148,15 +153,27 @@
     }
 
     /**
-     * Returns the network which the mdns query should target on.
+     * Returns the network which the mdns query should target.
      *
-     * @return the target network or null if search on all possible interfaces.
+     * @return the target network or null to search on all possible interfaces.
      */
     @Nullable
     public Network getNetwork() {
         return mNetwork;
     }
 
+
+    /**
+     * Returns the interface index which the mdns query should target.
+     *
+     * This is only set when the service is to be searched on an interface that does not have a
+     * Network, in which case {@link #getNetwork()} returns null.
+     * The interface index as per {@link java.net.NetworkInterface#getIndex}, or 0 if unset.
+     */
+    public int getInterfaceIndex() {
+        return mInterfaceIndex;
+    }
+
     /**
      * If non-null, queries should try to resolve all records of this specific service, rather than
      * discovering all services.
@@ -177,6 +194,7 @@
         out.writeInt(queryMode);
         out.writeInt(removeExpiredService ? 1 : 0);
         out.writeParcelable(mNetwork, 0);
+        out.writeInt(mInterfaceIndex);
         out.writeString(resolveInstanceName);
         out.writeInt(onlyUseIpv6OnIpv6OnlyNetworks ? 1 : 0);
         out.writeInt(numOfQueriesBeforeBackoff);
@@ -190,6 +208,7 @@
         private int numOfQueriesBeforeBackoff = 3;
         private boolean removeExpiredService;
         private Network mNetwork;
+        private int mInterfaceIndex;
         private String resolveInstanceName;
 
         private Builder() {
@@ -278,6 +297,16 @@
             return this;
         }
 
+        /**
+         * Set the interface index to use for the query, if not querying on a {@link Network}.
+         *
+         * @see MdnsSearchOptions#getInterfaceIndex()
+         */
+        public Builder setInterfaceIndex(int index) {
+            mInterfaceIndex = index;
+            return this;
+        }
+
         /** Builds a {@link MdnsSearchOptions} with the arguments supplied to this builder. */
         public MdnsSearchOptions build() {
             return new MdnsSearchOptions(
@@ -285,6 +314,7 @@
                     queryMode,
                     removeExpiredService,
                     mNetwork,
+                    mInterfaceIndex,
                     resolveInstanceName,
                     onlyUseIpv6OnIpv6OnlyNetworks,
                     numOfQueriesBeforeBackoff);
diff --git a/tests/unit/java/com/android/server/NsdServiceTest.java b/tests/unit/java/com/android/server/NsdServiceTest.java
index 881de56..d91e29c 100644
--- a/tests/unit/java/com/android/server/NsdServiceTest.java
+++ b/tests/unit/java/com/android/server/NsdServiceTest.java
@@ -521,6 +521,56 @@
     }
 
     @Test
+    @EnableCompatChanges(ENABLE_PLATFORM_MDNS_BACKEND)
+    public void testDiscoverOnTetheringDownstream_DiscoveryManager() throws Exception {
+        final NsdManager client = connectClient(mService);
+        final DiscoveryListener discListener = mock(DiscoveryListener.class);
+        client.discoverServices(SERVICE_TYPE, PROTOCOL, discListener);
+        waitForIdle();
+
+        final ArgumentCaptor<MdnsServiceBrowserListener> discoverListenerCaptor =
+                ArgumentCaptor.forClass(MdnsServiceBrowserListener.class);
+        final InOrder discManagerOrder = inOrder(mDiscoveryManager);
+        final String serviceTypeWithLocalDomain = SERVICE_TYPE + ".local";
+        discManagerOrder.verify(mDiscoveryManager).registerListener(eq(serviceTypeWithLocalDomain),
+                discoverListenerCaptor.capture(), any());
+
+        final int interfaceIdx = 123;
+        final MdnsServiceInfo mockServiceInfo = new MdnsServiceInfo(
+                SERVICE_NAME, /* serviceInstanceName */
+                serviceTypeWithLocalDomain.split("\\."), /* serviceType */
+                List.of(), /* subtypes */
+                new String[] {"android", "local"}, /* hostName */
+                12345, /* port */
+                List.of(IPV4_ADDRESS),
+                List.of(IPV6_ADDRESS),
+                List.of(), /* textStrings */
+                List.of(), /* textEntries */
+                interfaceIdx, /* interfaceIndex */
+                null /* network */,
+                Instant.MAX /* expirationTime */);
+
+        // Verify service is found with the interface index
+        discoverListenerCaptor.getValue().onServiceNameDiscovered(
+                mockServiceInfo, false /* isServiceFromCache */);
+        final ArgumentCaptor<NsdServiceInfo> foundInfoCaptor =
+                ArgumentCaptor.forClass(NsdServiceInfo.class);
+        verify(discListener, timeout(TIMEOUT_MS)).onServiceFound(foundInfoCaptor.capture());
+        final NsdServiceInfo foundInfo = foundInfoCaptor.getValue();
+        assertNull(foundInfo.getNetwork());
+        assertEquals(interfaceIdx, foundInfo.getInterfaceIndex());
+
+        // Using the returned service info to resolve or register callback uses the interface index
+        client.resolveService(foundInfo, mock(ResolveListener.class));
+        client.registerServiceInfoCallback(foundInfo, Runnable::run,
+                mock(ServiceInfoCallback.class));
+        waitForIdle();
+
+        discManagerOrder.verify(mDiscoveryManager, times(2)).registerListener(any(), any(), argThat(
+                o -> o.getNetwork() == null && o.getInterfaceIndex() == interfaceIdx));
+    }
+
+    @Test
     @DisableCompatChanges(ENABLE_PLATFORM_MDNS_BACKEND)
     @DevSdkIgnoreRule.IgnoreAfter(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
     public void testDiscoverOnBlackholeNetwork() throws Exception {
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
index 5251e2a..b5c0132 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
@@ -18,6 +18,8 @@
 
 import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
 
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
@@ -65,8 +67,9 @@
     private static final String SERVICE_TYPE_2 = "_test._tcp.local";
     private static final Network NETWORK_1 = Mockito.mock(Network.class);
     private static final Network NETWORK_2 = Mockito.mock(Network.class);
+    private static final int INTERFACE_INDEX_NULL_NETWORK = 123;
     private static final SocketKey SOCKET_KEY_NULL_NETWORK =
-            new SocketKey(null /* network */, 999 /* interfaceIndex */);
+            new SocketKey(null /* network */, INTERFACE_INDEX_NULL_NETWORK);
     private static final SocketKey SOCKET_KEY_NETWORK_1 =
             new SocketKey(NETWORK_1, 998 /* interfaceIndex */);
     private static final SocketKey SOCKET_KEY_NETWORK_2 =
@@ -97,6 +100,8 @@
     private HandlerThread thread;
     private Handler handler;
 
+    private int createdServiceTypeClientCount;
+
     @Before
     public void setUp() {
         MockitoAnnotations.initMocks(this);
@@ -106,11 +111,13 @@
         handler = new Handler(thread.getLooper());
         doReturn(thread.getLooper()).when(socketClient).getLooper();
         doReturn(true).when(socketClient).supportsRequestingSpecificNetworks();
+        createdServiceTypeClientCount = 0;
         discoveryManager = new MdnsDiscoveryManager(executorProvider, socketClient,
                 sharedLog, MdnsFeatureFlags.newBuilder().build()) {
                     @Override
                     MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType,
                             @NonNull SocketKey socketKey) {
+                        createdServiceTypeClientCount++;
                         final Pair<String, SocketKey> perSocketServiceType =
                                 Pair.create(serviceType, socketKey);
                         if (perSocketServiceType.equals(PER_SOCKET_SERVICE_TYPE_1_NULL_NETWORK)) {
@@ -128,6 +135,7 @@
                                 PER_SOCKET_SERVICE_TYPE_2_NETWORK_2)) {
                             return mockServiceTypeClientType2Network2;
                         }
+                        fail("Unexpected perSocketServiceType: " + perSocketServiceType);
                         return null;
                     }
                 };
@@ -324,7 +332,6 @@
 
         // Receive a response, it should be processed on the client.
         final MdnsPacket response = createMdnsPacket(SERVICE_TYPE_1);
-        final int ifIndex = 1;
         runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).processResponse(
                 response, SOCKET_KEY_NULL_NETWORK);
@@ -350,6 +357,39 @@
         verify(socketClient, never()).stopDiscovery();
     }
 
+    @Test
+    public void testInterfaceIndexRequested_OnlyUsesSelectedInterface() throws IOException {
+        final MdnsSearchOptions searchOptions =
+                MdnsSearchOptions.newBuilder()
+                        .setNetwork(null /* network */)
+                        .setInterfaceIndex(INTERFACE_INDEX_NULL_NETWORK)
+                        .build();
+
+        final SocketCreationCallback callback = expectSocketCreationCallback(
+                SERVICE_TYPE_1, mockListenerOne, searchOptions);
+        final SocketKey unusedIfaceKey = new SocketKey(null, INTERFACE_INDEX_NULL_NETWORK + 1);
+        final SocketKey matchingIfaceWithNetworkKey =
+                new SocketKey(Mockito.mock(Network.class), INTERFACE_INDEX_NULL_NETWORK);
+        runOnHandler(() -> {
+            callback.onSocketCreated(unusedIfaceKey);
+            callback.onSocketCreated(matchingIfaceWithNetworkKey);
+            callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK);
+            callback.onSocketCreated(SOCKET_KEY_NETWORK_1);
+        });
+        // Only the client for INTERFACE_INDEX_NULL_NETWORK is created
+        verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(
+                mockListenerOne, searchOptions);
+        assertEquals(1, createdServiceTypeClientCount);
+
+        runOnHandler(() -> {
+            callback.onSocketDestroyed(SOCKET_KEY_NETWORK_1);
+            callback.onSocketDestroyed(SOCKET_KEY_NULL_NETWORK);
+            callback.onSocketDestroyed(matchingIfaceWithNetworkKey);
+            callback.onSocketDestroyed(unusedIfaceKey);
+        });
+        verify(mockServiceTypeClientType1NullNetwork).notifySocketDestroyed();
+    }
+
     private MdnsPacket createMdnsPacket(String serviceType) {
         final String[] type = TextUtils.split(serviceType, "\\.");
         final ArrayList<String> name = new ArrayList<>(type.length + 1);