Merge "Include/exclude mDNS interfaces based on transport"
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
index 2823f92..0952e88 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
@@ -16,6 +16,10 @@
 
 package com.android.server.connectivity.mdns;
 
+import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
+import static android.net.NetworkCapabilities.TRANSPORT_VPN;
+import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
+
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.content.Context;
@@ -24,6 +28,7 @@
 import android.net.LinkAddress;
 import android.net.LinkProperties;
 import android.net.Network;
+import android.net.NetworkCapabilities;
 import android.net.NetworkRequest;
 import android.net.TetheringManager;
 import android.net.TetheringManager.TetheringEventCallback;
@@ -71,6 +76,7 @@
     private final ArrayMap<String, SocketInfo> mTetherInterfaceSockets = new ArrayMap<>();
     private final ArrayMap<Network, LinkProperties> mActiveNetworksLinkProperties =
             new ArrayMap<>();
+    private final ArrayMap<Network, int[]> mActiveNetworksTransports = new ArrayMap<>();
     private final ArrayMap<SocketCallback, Network> mCallbacksToRequestedNetworks =
             new ArrayMap<>();
     private final List<String> mLocalOnlyInterfaces = new ArrayList<>();
@@ -93,10 +99,17 @@
             @Override
             public void onLost(Network network) {
                 mActiveNetworksLinkProperties.remove(network);
+                mActiveNetworksTransports.remove(network);
                 removeNetworkSocket(network);
             }
 
             @Override
+            public void onCapabilitiesChanged(@NonNull Network network,
+                    @NonNull NetworkCapabilities networkCapabilities) {
+                mActiveNetworksTransports.put(network, networkCapabilities.getTransportTypes());
+            }
+
+            @Override
             public void onLinkPropertiesChanged(Network network, LinkProperties lp) {
                 handleLinkPropertiesChanged(network, lp);
             }
@@ -129,11 +142,6 @@
             return ni == null ? null : new NetworkInterfaceWrapper(ni);
         }
 
-        /*** Check whether given network interface can support mdns */
-        public boolean canScanOnInterface(@NonNull NetworkInterfaceWrapper networkInterface) {
-            return MulticastNetworkInterfaceProvider.canScanOnInterface(networkInterface);
-        }
-
         /*** Create a MdnsInterfaceSocket */
         public MdnsInterfaceSocket createMdnsInterfaceSocket(
                 @NonNull NetworkInterface networkInterface, int port, @NonNull Looper looper,
@@ -303,7 +311,17 @@
         try {
             final NetworkInterfaceWrapper networkInterface =
                     mDependencies.getNetworkInterfaceByName(interfaceName);
-            if (networkInterface == null || !mDependencies.canScanOnInterface(networkInterface)) {
+            // There are no transports for tethered interfaces. Other interfaces should always
+            // have transports since LinkProperties updates are always sent after
+            // NetworkCapabilities updates.
+            final int[] transports;
+            if (networkKey == LOCAL_NET) {
+                transports = new int[0];
+            } else {
+                transports = mActiveNetworksTransports.getOrDefault(
+                        ((NetworkAsKey) networkKey).mNetwork, new int[0]);
+            }
+            if (networkInterface == null || !isMdnsCapableInterface(networkInterface, transports)) {
                 return;
             }
 
@@ -339,6 +357,36 @@
         }
     }
 
+    /** Returns whether mDNS should be attempted on {@code iface} given its transports. */
+    private boolean isMdnsCapableInterface(
+            @NonNull NetworkInterfaceWrapper iface, @NonNull int[] transports) {
+        // Transports are checked before interface flags because flags are not always reliable:
+        // some Wifi interfaces lack the MULTICAST flag even though they can do mDNS, and some
+        // cellular interfaces have the BROADCAST or MULTICAST flags.
+        final boolean onCellular = CollectionUtils.contains(transports, TRANSPORT_CELLULAR);
+        try {
+            // Cellular networks and interfaces with incompatible flags never do mDNS.
+            if (onCellular || iface.isLoopback() || iface.isPointToPoint()
+                    || iface.isVirtual() || !iface.isUp()) {
+                return false;
+            }
+
+            // Wifi that is not also a VPN always does mDNS, whatever its flags say.
+            final boolean onVpn = CollectionUtils.contains(transports, TRANSPORT_VPN);
+            if (!onVpn && CollectionUtils.contains(transports, TRANSPORT_WIFI)) {
+                return true;
+            }
+
+            // Any other transport, or no transport at all (tethering downstreams), depends on
+            // the interface's multicast support.
+            return iface.supportsMulticast();
+        } catch (SocketException e) {
+            // The interface flags could not be read: treat the interface as not capable.
+            Log.e(TAG, "Error checking interface flags", e);
+            return false;
+        }
+    }
+
     private void removeNetworkSocket(Network network) {
         final SocketInfo socketInfo = mNetworkSockets.remove(network);
         if (socketInfo == null) return;
diff --git a/service-t/src/com/android/server/connectivity/mdns/MulticastNetworkInterfaceProvider.java b/service-t/src/com/android/server/connectivity/mdns/MulticastNetworkInterfaceProvider.java
index ade7b95..f248c98 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MulticastNetworkInterfaceProvider.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MulticastNetworkInterfaceProvider.java
@@ -148,7 +148,7 @@
     }
 
     /*** Check whether given network interface can support mdns */
-    public static boolean canScanOnInterface(@Nullable NetworkInterfaceWrapper networkInterface) {
+    private static boolean canScanOnInterface(@Nullable NetworkInterfaceWrapper networkInterface) {
         try {
             if ((networkInterface == null)
                     || networkInterface.isLoopback()
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
index 004ea7c..d9420b8 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
@@ -16,6 +16,11 @@
 
 package com.android.server.connectivity.mdns;
 
+import static android.net.NetworkCapabilities.TRANSPORT_BLUETOOTH;
+import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
+import static android.net.NetworkCapabilities.TRANSPORT_VPN;
+import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
+
 import static com.android.testutils.ContextUtils.mockService;
 
 import static org.junit.Assert.assertEquals;
@@ -38,6 +43,7 @@
 import android.net.LinkAddress;
 import android.net.LinkProperties;
 import android.net.Network;
+import android.net.NetworkCapabilities;
 import android.net.TetheringManager;
 import android.net.TetheringManager.TetheringEventCallback;
 import android.os.Build;
@@ -98,8 +104,13 @@
             // Test is using mockito-extended
             doCallRealMethod().when(mContext).getSystemService(TetheringManager.class);
         }
-        doReturn(true).when(mDeps).canScanOnInterface(any());
         doReturn(mTestNetworkIfaceWrapper).when(mDeps).getNetworkInterfaceByName(anyString());
+        doReturn(true).when(mTestNetworkIfaceWrapper).isUp();
+        doReturn(true).when(mLocalOnlyIfaceWrapper).isUp();
+        doReturn(true).when(mTetheredIfaceWrapper).isUp();
+        doReturn(true).when(mTestNetworkIfaceWrapper).supportsMulticast();
+        doReturn(true).when(mLocalOnlyIfaceWrapper).supportsMulticast();
+        doReturn(true).when(mTetheredIfaceWrapper).supportsMulticast();
         doReturn(mLocalOnlyIfaceWrapper).when(mDeps)
                 .getNetworkInterfaceByName(LOCAL_ONLY_IFACE_NAME);
         doReturn(mTetheredIfaceWrapper).when(mDeps).getNetworkInterfaceByName(TETHERED_IFACE_NAME);
@@ -205,6 +216,24 @@
         }
     }
 
+    private static NetworkCapabilities makeCapabilities(int... transports) {
+        final NetworkCapabilities capabilities = new NetworkCapabilities();
+        for (int i = 0; i < transports.length; i++) {
+            capabilities.addTransportType(transports[i]);
+        }
+        return capabilities;
+    }
+
+    private void postNetworkAvailable(int... transports) {
+        final NetworkCapabilities nc = makeCapabilities(transports);
+        final LinkProperties lp = new LinkProperties();
+        lp.setInterfaceName(TEST_IFACE_NAME);
+        lp.setLinkAddresses(List.of(LINKADDRV4));
+        mHandler.post(() -> mNetworkCallback.onCapabilitiesChanged(TEST_NETWORK, nc));
+        mHandler.post(() -> mNetworkCallback.onLinkPropertiesChanged(TEST_NETWORK, lp));
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+    }
+
     @Test
     public void testSocketRequestAndUnrequestSocket() {
         startMonitoringSockets();
@@ -214,12 +243,7 @@
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         testCallback1.expectedNoCallback();
 
-        final LinkProperties testLp = new LinkProperties();
-        testLp.setInterfaceName(TEST_IFACE_NAME);
-        testLp.setLinkAddresses(List.of(LINKADDRV4));
-        mHandler.post(() -> mNetworkCallback.onLinkPropertiesChanged(TEST_NETWORK, testLp));
-        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
-        verify(mTestNetworkIfaceWrapper).getNetworkInterface();
+        postNetworkAvailable(TRANSPORT_WIFI);
         testCallback1.expectedSocketCreatedForNetwork(TEST_NETWORK, List.of(LINKADDRV4));
 
         final TestSocketCallback testCallback2 = new TestSocketCallback();
@@ -286,12 +310,7 @@
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         testCallback.expectedNoCallback();
 
-        final LinkProperties testLp = new LinkProperties();
-        testLp.setInterfaceName(TEST_IFACE_NAME);
-        testLp.setLinkAddresses(List.of(LINKADDRV4));
-        mHandler.post(() -> mNetworkCallback.onLinkPropertiesChanged(TEST_NETWORK, testLp));
-        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
-        verify(mTestNetworkIfaceWrapper, times(1)).getNetworkInterface();
+        postNetworkAvailable(TRANSPORT_WIFI);
         testCallback.expectedSocketCreatedForNetwork(TEST_NETWORK, List.of(LINKADDRV4));
 
         final LinkProperties newTestLp = new LinkProperties();
@@ -299,7 +318,6 @@
         newTestLp.setLinkAddresses(List.of(LINKADDRV4, LINKADDRV6));
         mHandler.post(() -> mNetworkCallback.onLinkPropertiesChanged(TEST_NETWORK, newTestLp));
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
-        verify(mTestNetworkIfaceWrapper, times(1)).getNetworkInterface();
         testCallback.expectedAddressesChangedForNetwork(
                 TEST_NETWORK, List.of(LINKADDRV4, LINKADDRV6));
     }
@@ -403,4 +421,77 @@
         verify(mTestNetworkIfaceWrapper, times(2)).getNetworkInterface();
         testCallback.expectedSocketCreatedForNetwork(otherNetwork, List.of(otherAddress));
     }
+
+    @Test
+    public void testNoSocketCreatedForCellular() {
+        startMonitoringSockets();
+        // Ask for a socket on TEST_NETWORK before the network is announced.
+        final TestSocketCallback callback = new TestSocketCallback();
+        mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, callback));
+        // A cellular transport rules out mDNS, so the request must stay unanswered.
+        postNetworkAvailable(TRANSPORT_CELLULAR);
+        callback.expectedNoCallback();
+    }
+
+    @Test
+    public void testNoSocketCreatedForNonMulticastInterface() throws Exception {
+        doReturn(false).when(mTestNetworkIfaceWrapper).supportsMulticast();
+        startMonitoringSockets();
+        // Ask for a socket on TEST_NETWORK before the network is announced.
+        final TestSocketCallback callback = new TestSocketCallback();
+        mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, callback));
+        // Bluetooth is decided by interface flags alone: no multicast support, no socket.
+        postNetworkAvailable(TRANSPORT_BLUETOOTH);
+        callback.expectedNoCallback();
+    }
+
+    @Test
+    public void testSocketCreatedForMulticastInterface() throws Exception {
+        doReturn(true).when(mTestNetworkIfaceWrapper).supportsMulticast();
+        startMonitoringSockets();
+        // Ask for a socket on TEST_NETWORK before the network is announced.
+        final TestSocketCallback callback = new TestSocketCallback();
+        mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, callback));
+        // Bluetooth is decided by interface flags alone: multicast support yields a socket.
+        postNetworkAvailable(TRANSPORT_BLUETOOTH);
+        callback.expectedSocketCreatedForNetwork(TEST_NETWORK, List.of(LINKADDRV4));
+    }
+
+    @Test
+    public void testNoSocketCreatedForPTPInterface() throws Exception {
+        doReturn(true).when(mTestNetworkIfaceWrapper).isPointToPoint();
+        startMonitoringSockets();
+        // Ask for a socket on TEST_NETWORK before the network is announced.
+        final TestSocketCallback callback = new TestSocketCallback();
+        mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, callback));
+        // Point-to-point interfaces are rejected regardless of multicast support.
+        postNetworkAvailable(TRANSPORT_BLUETOOTH);
+        callback.expectedNoCallback();
+    }
+
+    @Test
+    public void testNoSocketCreatedForVPNInterface() throws Exception {
+        // VPN interfaces generally also have IFF_POINTOPOINT. Even when they don't, TRANSPORT_VPN
+        // disables the unconditional Wifi case, leaving the decision to the multicast flag.
+        doReturn(false).when(mTestNetworkIfaceWrapper).supportsMulticast();
+        startMonitoringSockets();
+        // Ask for a socket on TEST_NETWORK before the network is announced.
+        final TestSocketCallback callback = new TestSocketCallback();
+        mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, callback));
+        // Wifi over VPN without multicast support must not get a socket.
+        postNetworkAvailable(TRANSPORT_VPN, TRANSPORT_WIFI);
+        callback.expectedNoCallback();
+    }
+
+    @Test
+    public void testSocketCreatedForWifiWithoutMulticastFlag() throws Exception {
+        doReturn(false).when(mTestNetworkIfaceWrapper).supportsMulticast();
+        startMonitoringSockets();
+        // Ask for a socket on TEST_NETWORK before the network is announced.
+        final TestSocketCallback callback = new TestSocketCallback();
+        mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, callback));
+        // Non-VPN Wifi always gets a socket, even when the multicast flag is missing.
+        postNetworkAvailable(TRANSPORT_WIFI);
+        callback.expectedSocketCreatedForNetwork(TEST_NETWORK, List.of(LINKADDRV4));
+    }
 }