Include/exclude mDNS interfaces based on transport

Regardless of IFF_MULTICAST (and IFF_BROADCAST), always include wifi
transport interfaces, and always exclude cellular transport interfaces.

Some interfaces do not have the multicast or broadcast flag set
properly. Use the transport to determine whether to use them, rather
than the interface flags.

Bug: 268138840
Test: atest MdnsSocketProviderTest
Change-Id: Idbddfa9d2cc05ce1850786aa634da4c38afd3fc0
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
index 2823f92..0952e88 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
@@ -16,6 +16,10 @@
 
 package com.android.server.connectivity.mdns;
 
+import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
+import static android.net.NetworkCapabilities.TRANSPORT_VPN;
+import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
+
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.content.Context;
@@ -24,6 +28,7 @@
 import android.net.LinkAddress;
 import android.net.LinkProperties;
 import android.net.Network;
+import android.net.NetworkCapabilities;
 import android.net.NetworkRequest;
 import android.net.TetheringManager;
 import android.net.TetheringManager.TetheringEventCallback;
@@ -71,6 +76,7 @@
     private final ArrayMap<String, SocketInfo> mTetherInterfaceSockets = new ArrayMap<>();
     private final ArrayMap<Network, LinkProperties> mActiveNetworksLinkProperties =
             new ArrayMap<>();
+    private final ArrayMap<Network, int[]> mActiveNetworksTransports = new ArrayMap<>();
     private final ArrayMap<SocketCallback, Network> mCallbacksToRequestedNetworks =
             new ArrayMap<>();
     private final List<String> mLocalOnlyInterfaces = new ArrayList<>();
@@ -93,10 +99,17 @@
             @Override
             public void onLost(Network network) {
                 mActiveNetworksLinkProperties.remove(network);
+                mActiveNetworksTransports.remove(network);
                 removeNetworkSocket(network);
             }
 
             @Override
+            public void onCapabilitiesChanged(@NonNull Network network,
+                    @NonNull NetworkCapabilities networkCapabilities) {
+                mActiveNetworksTransports.put(network, networkCapabilities.getTransportTypes());
+            }
+
+            @Override
             public void onLinkPropertiesChanged(Network network, LinkProperties lp) {
                 handleLinkPropertiesChanged(network, lp);
             }
@@ -129,11 +142,6 @@
             return ni == null ? null : new NetworkInterfaceWrapper(ni);
         }
 
-        /*** Check whether given network interface can support mdns */
-        public boolean canScanOnInterface(@NonNull NetworkInterfaceWrapper networkInterface) {
-            return MulticastNetworkInterfaceProvider.canScanOnInterface(networkInterface);
-        }
-
         /*** Create a MdnsInterfaceSocket */
         public MdnsInterfaceSocket createMdnsInterfaceSocket(
                 @NonNull NetworkInterface networkInterface, int port, @NonNull Looper looper,
@@ -303,7 +311,17 @@
         try {
             final NetworkInterfaceWrapper networkInterface =
                     mDependencies.getNetworkInterfaceByName(interfaceName);
-            if (networkInterface == null || !mDependencies.canScanOnInterface(networkInterface)) {
+            // There are no transports for tethered interfaces. Other interfaces should always
+            // have transports since LinkProperties updates are always sent after
+            // NetworkCapabilities updates.
+            final int[] transports;
+            if (networkKey == LOCAL_NET) {
+                transports = new int[0];
+            } else {
+                transports = mActiveNetworksTransports.getOrDefault(
+                        ((NetworkAsKey) networkKey).mNetwork, new int[0]);
+            }
+            if (networkInterface == null || !isMdnsCapableInterface(networkInterface, transports)) {
                 return;
             }
 
@@ -339,6 +357,36 @@
         }
     }
 
+    private boolean isMdnsCapableInterface(
+            @NonNull NetworkInterfaceWrapper iface, @NonNull int[] transports) {
+        try {
+            // Never try mDNS on cellular, or on interfaces with incompatible flags
+            if (CollectionUtils.contains(transports, TRANSPORT_CELLULAR)
+                    || iface.isLoopback()
+                    || iface.isPointToPoint()
+                    || iface.isVirtual()
+                    || !iface.isUp()) {
+                return false;
+            }
+
+            // Otherwise, always try mDNS on non-VPN Wifi.
+            if (!CollectionUtils.contains(transports, TRANSPORT_VPN)
+                    && CollectionUtils.contains(transports, TRANSPORT_WIFI)) {
+                return true;
+            }
+
+            // For other transports, or no transports (tethering downstreams), do mDNS based on the
+            // interface flags. This is not always reliable (for example some Wifi interfaces may
+            // not have the MULTICAST flag even though they can do mDNS, and some cellular
+            // interfaces may have the BROADCAST or MULTICAST flags), so checks are done based on
+            // transports above in priority.
+            return iface.supportsMulticast();
+        } catch (SocketException e) {
+            Log.e(TAG, "Error checking interface flags", e);
+            return false;
+        }
+    }
+
     private void removeNetworkSocket(Network network) {
         final SocketInfo socketInfo = mNetworkSockets.remove(network);
         if (socketInfo == null) return;
diff --git a/service-t/src/com/android/server/connectivity/mdns/MulticastNetworkInterfaceProvider.java b/service-t/src/com/android/server/connectivity/mdns/MulticastNetworkInterfaceProvider.java
index ade7b95..f248c98 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MulticastNetworkInterfaceProvider.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MulticastNetworkInterfaceProvider.java
@@ -148,7 +148,7 @@
     }
 
     /*** Check whether given network interface can support mdns */
-    public static boolean canScanOnInterface(@Nullable NetworkInterfaceWrapper networkInterface) {
+    private static boolean canScanOnInterface(@Nullable NetworkInterfaceWrapper networkInterface) {
         try {
             if ((networkInterface == null)
                     || networkInterface.isLoopback()
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
index 004ea7c..d9420b8 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
@@ -16,6 +16,11 @@
 
 package com.android.server.connectivity.mdns;
 
+import static android.net.NetworkCapabilities.TRANSPORT_BLUETOOTH;
+import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
+import static android.net.NetworkCapabilities.TRANSPORT_VPN;
+import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
+
 import static com.android.testutils.ContextUtils.mockService;
 
 import static org.junit.Assert.assertEquals;
@@ -38,6 +43,7 @@
 import android.net.LinkAddress;
 import android.net.LinkProperties;
 import android.net.Network;
+import android.net.NetworkCapabilities;
 import android.net.TetheringManager;
 import android.net.TetheringManager.TetheringEventCallback;
 import android.os.Build;
@@ -98,8 +104,13 @@
             // Test is using mockito-extended
             doCallRealMethod().when(mContext).getSystemService(TetheringManager.class);
         }
-        doReturn(true).when(mDeps).canScanOnInterface(any());
         doReturn(mTestNetworkIfaceWrapper).when(mDeps).getNetworkInterfaceByName(anyString());
+        doReturn(true).when(mTestNetworkIfaceWrapper).isUp();
+        doReturn(true).when(mLocalOnlyIfaceWrapper).isUp();
+        doReturn(true).when(mTetheredIfaceWrapper).isUp();
+        doReturn(true).when(mTestNetworkIfaceWrapper).supportsMulticast();
+        doReturn(true).when(mLocalOnlyIfaceWrapper).supportsMulticast();
+        doReturn(true).when(mTetheredIfaceWrapper).supportsMulticast();
         doReturn(mLocalOnlyIfaceWrapper).when(mDeps)
                 .getNetworkInterfaceByName(LOCAL_ONLY_IFACE_NAME);
         doReturn(mTetheredIfaceWrapper).when(mDeps).getNetworkInterfaceByName(TETHERED_IFACE_NAME);
@@ -205,6 +216,24 @@
         }
     }
 
+    private static NetworkCapabilities makeCapabilities(int... transports) {
+        final NetworkCapabilities nc = new NetworkCapabilities();
+        for (int transport : transports) {
+            nc.addTransportType(transport);
+        }
+        return nc;
+    }
+
+    private void postNetworkAvailable(int... transports) {
+        final LinkProperties testLp = new LinkProperties();
+        testLp.setInterfaceName(TEST_IFACE_NAME);
+        testLp.setLinkAddresses(List.of(LINKADDRV4));
+        final NetworkCapabilities testNc = makeCapabilities(transports);
+        mHandler.post(() -> mNetworkCallback.onCapabilitiesChanged(TEST_NETWORK, testNc));
+        mHandler.post(() -> mNetworkCallback.onLinkPropertiesChanged(TEST_NETWORK, testLp));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+    }
+
     @Test
     public void testSocketRequestAndUnrequestSocket() {
         startMonitoringSockets();
@@ -214,12 +243,7 @@
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         testCallback1.expectedNoCallback();
 
-        final LinkProperties testLp = new LinkProperties();
-        testLp.setInterfaceName(TEST_IFACE_NAME);
-        testLp.setLinkAddresses(List.of(LINKADDRV4));
-        mHandler.post(() -> mNetworkCallback.onLinkPropertiesChanged(TEST_NETWORK, testLp));
-        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
-        verify(mTestNetworkIfaceWrapper).getNetworkInterface();
+        postNetworkAvailable(TRANSPORT_WIFI);
         testCallback1.expectedSocketCreatedForNetwork(TEST_NETWORK, List.of(LINKADDRV4));
 
         final TestSocketCallback testCallback2 = new TestSocketCallback();
@@ -286,12 +310,7 @@
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         testCallback.expectedNoCallback();
 
-        final LinkProperties testLp = new LinkProperties();
-        testLp.setInterfaceName(TEST_IFACE_NAME);
-        testLp.setLinkAddresses(List.of(LINKADDRV4));
-        mHandler.post(() -> mNetworkCallback.onLinkPropertiesChanged(TEST_NETWORK, testLp));
-        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
-        verify(mTestNetworkIfaceWrapper, times(1)).getNetworkInterface();
+        postNetworkAvailable(TRANSPORT_WIFI);
         testCallback.expectedSocketCreatedForNetwork(TEST_NETWORK, List.of(LINKADDRV4));
 
         final LinkProperties newTestLp = new LinkProperties();
@@ -299,7 +318,6 @@
         newTestLp.setLinkAddresses(List.of(LINKADDRV4, LINKADDRV6));
         mHandler.post(() -> mNetworkCallback.onLinkPropertiesChanged(TEST_NETWORK, newTestLp));
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
-        verify(mTestNetworkIfaceWrapper, times(1)).getNetworkInterface();
         testCallback.expectedAddressesChangedForNetwork(
                 TEST_NETWORK, List.of(LINKADDRV4, LINKADDRV6));
     }
@@ -403,4 +421,77 @@
         verify(mTestNetworkIfaceWrapper, times(2)).getNetworkInterface();
         testCallback.expectedSocketCreatedForNetwork(otherNetwork, List.of(otherAddress));
     }
+
+    @Test
+    public void testNoSocketCreatedForCellular() {
+        startMonitoringSockets();
+
+        final TestSocketCallback testCallback = new TestSocketCallback();
+        mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback));
+
+        postNetworkAvailable(TRANSPORT_CELLULAR);
+        testCallback.expectedNoCallback();
+    }
+
+    @Test
+    public void testNoSocketCreatedForNonMulticastInterface() throws Exception {
+        doReturn(false).when(mTestNetworkIfaceWrapper).supportsMulticast();
+        startMonitoringSockets();
+
+        final TestSocketCallback testCallback = new TestSocketCallback();
+        mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback));
+
+        postNetworkAvailable(TRANSPORT_BLUETOOTH);
+        testCallback.expectedNoCallback();
+    }
+
+    @Test
+    public void testSocketCreatedForMulticastInterface() throws Exception {
+        doReturn(true).when(mTestNetworkIfaceWrapper).supportsMulticast();
+        startMonitoringSockets();
+
+        final TestSocketCallback testCallback = new TestSocketCallback();
+        mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback));
+
+        postNetworkAvailable(TRANSPORT_BLUETOOTH);
+        testCallback.expectedSocketCreatedForNetwork(TEST_NETWORK, List.of(LINKADDRV4));
+    }
+
+    @Test
+    public void testNoSocketCreatedForPTPInterface() throws Exception {
+        doReturn(true).when(mTestNetworkIfaceWrapper).isPointToPoint();
+        startMonitoringSockets();
+
+        final TestSocketCallback testCallback = new TestSocketCallback();
+        mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback));
+
+        postNetworkAvailable(TRANSPORT_BLUETOOTH);
+        testCallback.expectedNoCallback();
+    }
+
+    @Test
+    public void testNoSocketCreatedForVPNInterface() throws Exception {
+        // VPN interfaces generally also have IFF_POINTOPOINT, but even if they don't, they should
+        // not be included even with TRANSPORT_WIFI.
+        doReturn(false).when(mTestNetworkIfaceWrapper).supportsMulticast();
+        startMonitoringSockets();
+
+        final TestSocketCallback testCallback = new TestSocketCallback();
+        mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback));
+
+        postNetworkAvailable(TRANSPORT_VPN, TRANSPORT_WIFI);
+        testCallback.expectedNoCallback();
+    }
+
+    @Test
+    public void testSocketCreatedForWifiWithoutMulticastFlag() throws Exception {
+        doReturn(false).when(mTestNetworkIfaceWrapper).supportsMulticast();
+        startMonitoringSockets();
+
+        final TestSocketCallback testCallback = new TestSocketCallback();
+        mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback));
+
+        postNetworkAvailable(TRANSPORT_WIFI);
+        testCallback.expectedSocketCreatedForNetwork(TEST_NETWORK, List.of(LINKADDRV4));
+    }
 }