Ensure all DatagramPackets with the same destination address

The MdnsSocketclient#sendMdnsPacket() uses the address of the
first packet only. Since there's no guarantee that all packets
have the same address, skip sending packets if any have a
different address.

Fix: 335125073
Test: atest FrameworksNetTests NsdManagerTest
Change-Id: Ief0fd81eacf46b33bb2e8c0fcc4da7981ac2972d
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
index fcfb15f..c575d40 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -29,7 +29,9 @@
 import android.util.ArrayMap;
 import android.util.Log;
 
+import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.SharedLog;
+import com.android.server.connectivity.mdns.util.MdnsUtils;
 
 import java.io.IOException;
 import java.net.DatagramPacket;
@@ -225,6 +227,12 @@
             Log.wtf(TAG, "No mDns packets to send");
             return;
         }
+        // Check all packets with the same address
+        if (!MdnsUtils.checkAllPacketsWithSameAddress(packets)) {
+            Log.wtf(TAG, "Some mDNS packets have a different target address. addresses="
+                    + CollectionUtils.map(packets, DatagramPacket::getSocketAddress));
+            return;
+        }
 
         final boolean isIpv6 = ((InetSocketAddress) packets.get(0).getSocketAddress())
                 .getAddress() instanceof Inet6Address;
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
index 9cfcba1..17e5b31 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
@@ -28,7 +28,9 @@
 import android.util.Log;
 
 import com.android.internal.annotations.VisibleForTesting;
+import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.SharedLog;
+import com.android.server.connectivity.mdns.util.MdnsUtils;
 
 import java.io.IOException;
 import java.net.DatagramPacket;
@@ -249,6 +251,12 @@
             Log.wtf(TAG, "No mDns packets to send");
             return;
         }
+        // Check all packets with the same address
+        if (!MdnsUtils.checkAllPacketsWithSameAddress(packets)) {
+            Log.wtf(TAG, "Some mDNS packets have a different target address. addresses="
+                    + CollectionUtils.map(packets, DatagramPacket::getSocketAddress));
+            return;
+        }
 
         final boolean isIpv4 = ((InetSocketAddress) packets.get(0).getSocketAddress())
                 .getAddress() instanceof Inet4Address;
diff --git a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
index 3c11a24..226867f 100644
--- a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
+++ b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
@@ -34,6 +34,7 @@
 
 import java.io.IOException;
 import java.net.DatagramPacket;
+import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.nio.ByteBuffer;
 import java.nio.CharBuffer;
@@ -361,4 +362,23 @@
             return SystemClock.elapsedRealtime();
         }
     }
+
+    /**
+     * Check all DatagramPackets with the same destination address.
+     */
+    public static boolean checkAllPacketsWithSameAddress(List<DatagramPacket> packets) {
+        // No packet for address check
+        if (packets.isEmpty()) {
+            return true;
+        }
+
+        final InetAddress address =
+                ((InetSocketAddress) packets.get(0).getSocketAddress()).getAddress();
+        for (DatagramPacket packet : packets) {
+            if (!address.equals(((InetSocketAddress) packet.getSocketAddress()).getAddress())) {
+                return false;
+            }
+        }
+        return true;
+    }
 }
\ No newline at end of file
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
index fb3d183..4c71991 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
@@ -18,8 +18,10 @@
 
 import static com.android.server.connectivity.mdns.MdnsSocketProvider.SocketCallback;
 import static com.android.server.connectivity.mdns.MulticastPacketReader.PacketHandler;
+import static com.android.testutils.Cleanup.testAndCleanup;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
@@ -35,6 +37,7 @@
 import android.os.Build;
 import android.os.Handler;
 import android.os.HandlerThread;
+import android.util.Log;
 
 import com.android.net.module.util.HexDump;
 import com.android.net.module.util.SharedLog;
@@ -59,6 +62,7 @@
 import java.net.SocketException;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 @RunWith(DevSdkIgnoreRunner.class)
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
@@ -437,4 +441,34 @@
             inOrder.verify(mSocket).send(packets.get(i));
         }
     }
+
+    @Test
+    public void testSendPacketWithMultiplePacketsWithDifferentAddresses() throws IOException {
+        final SocketCallback callback = expectSocketCallback();
+        final DatagramPacket ipv4Packet = new DatagramPacket(BUFFER, 0 /* offset */, BUFFER.length,
+                InetAddresses.parseNumericAddress("192.0.2.1"), 0 /* port */);
+        final DatagramPacket ipv6Packet = new DatagramPacket(BUFFER, 0 /* offset */, BUFFER.length,
+                InetAddresses.parseNumericAddress("2001:db8::"), 0 /* port */);
+        doReturn(true).when(mSocket).hasJoinedIpv4();
+        doReturn(true).when(mSocket).hasJoinedIpv6();
+        doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
+
+        // Notify socket created
+        callback.onSocketCreated(mSocketKey, mSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
+
+        // Send packets with IPv4 and IPv6 then verify wtf logs and sending has never been called.
+        // Override the default TerribleFailureHandler, as that handler might terminate the process
+        // (if we're on an eng build).
+        final AtomicBoolean hasFailed = new AtomicBoolean(false);
+        final Log.TerribleFailureHandler originalHandler =
+                Log.setWtfHandler((tag, what, system) -> hasFailed.set(true));
+        testAndCleanup(() -> {
+            mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet, ipv6Packet),
+                    mSocketKey, false /* onlyUseIpv6OnIpv6OnlyNetworks */);
+            HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+            assertTrue(hasFailed.get());
+            verify(mSocket, never()).send(any());
+        }, () -> Log.setWtfHandler(originalHandler));
+    }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
index 1989ed3..ab70e38 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
@@ -16,6 +16,7 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.testutils.Cleanup.testAndCleanup;
 import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
 
 import static org.junit.Assert.assertFalse;
@@ -38,9 +39,11 @@
 import android.annotation.RequiresPermission;
 import android.content.Context;
 import android.net.ConnectivityManager;
+import android.net.InetAddresses;
 import android.net.wifi.WifiManager;
 import android.net.wifi.WifiManager.MulticastLock;
 import android.text.format.DateUtils;
+import android.util.Log;
 
 import com.android.net.module.util.HexDump;
 import com.android.net.module.util.SharedLog;
@@ -594,6 +597,29 @@
         }
     }
 
+    @Test
+    public void testSendPacketWithMultiplePacketsWithDifferentAddresses() throws IOException {
+        mdnsClient.startDiscovery();
+        final byte[] buffer = new byte[10];
+        final DatagramPacket ipv4Packet = new DatagramPacket(buffer, 0 /* offset */, buffer.length,
+                InetAddresses.parseNumericAddress("192.0.2.1"), 0 /* port */);
+        final DatagramPacket ipv6Packet = new DatagramPacket(buffer, 0 /* offset */, buffer.length,
+                InetAddresses.parseNumericAddress("2001:db8::"), 0 /* port */);
+
+        // Send packets with IPv4 and IPv6 then verify wtf logs and sending has never been called.
+        // Override the default TerribleFailureHandler, as that handler might terminate the process
+        // (if we're on an eng build).
+        final AtomicBoolean hasFailed = new AtomicBoolean(false);
+        final Log.TerribleFailureHandler originalHandler =
+                Log.setWtfHandler((tag, what, system) -> hasFailed.set(true));
+        testAndCleanup(() -> {
+            mdnsClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet, ipv6Packet),
+                    false /* onlyUseIpv6OnIpv6OnlyNetworks */);
+            assertTrue(hasFailed.get());
+            verify(mockMulticastSocket, never()).send(any());
+        }, () -> Log.setWtfHandler(originalHandler));
+    }
+
     private DatagramPacket getTestDatagramPacket() {
         return new DatagramPacket(buf, 0, 5,
                 new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), 5353 /* port */));
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
index 009205e..cf88d05 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
@@ -16,9 +16,12 @@
 
 package com.android.server.connectivity.mdns.util
 
+import android.net.InetAddresses
 import android.os.Build
 import com.android.server.connectivity.mdns.MdnsConstants
 import com.android.server.connectivity.mdns.MdnsConstants.FLAG_TRUNCATED
+import com.android.server.connectivity.mdns.MdnsConstants.IPV4_SOCKET_ADDR
+import com.android.server.connectivity.mdns.MdnsConstants.IPV6_SOCKET_ADDR
 import com.android.server.connectivity.mdns.MdnsPacket
 import com.android.server.connectivity.mdns.MdnsPacketReader
 import com.android.server.connectivity.mdns.MdnsPointerRecord
@@ -193,4 +196,31 @@
         }
         return MdnsPacket(flags, questions, answers, emptyList(), emptyList())
     }
+
+    @Test
+    fun testCheckAllPacketsWithSameAddress() {
+        val buffer = ByteArray(10)
+        val v4Packet = DatagramPacket(buffer, buffer.size, IPV4_SOCKET_ADDR)
+        val otherV4Packet = DatagramPacket(
+            buffer,
+            buffer.size,
+            InetAddresses.parseNumericAddress("192.0.2.1"),
+            1234
+        )
+        val v6Packet = DatagramPacket(ByteArray(10), 10, IPV6_SOCKET_ADDR)
+        val otherV6Packet = DatagramPacket(
+            buffer,
+            buffer.size,
+            InetAddresses.parseNumericAddress("2001:db8::"),
+            1234
+        )
+        assertTrue(MdnsUtils.checkAllPacketsWithSameAddress(listOf()))
+        assertTrue(MdnsUtils.checkAllPacketsWithSameAddress(listOf(v4Packet)))
+        assertTrue(MdnsUtils.checkAllPacketsWithSameAddress(listOf(v4Packet, v4Packet)))
+        assertFalse(MdnsUtils.checkAllPacketsWithSameAddress(listOf(v4Packet, otherV4Packet)))
+        assertTrue(MdnsUtils.checkAllPacketsWithSameAddress(listOf(v6Packet)))
+        assertTrue(MdnsUtils.checkAllPacketsWithSameAddress(listOf(v6Packet, v6Packet)))
+        assertFalse(MdnsUtils.checkAllPacketsWithSameAddress(listOf(v6Packet, otherV6Packet)))
+        assertFalse(MdnsUtils.checkAllPacketsWithSameAddress(listOf(v4Packet, v6Packet)))
+    }
 }