Add MdnsUtils#createQueryDatagramPackets method
This method can generate query packets from an oversized
MdnsPacket.
Bug: 312657709
Test: atest FrameworksNetTests
Change-Id: I21c1e082e01e0a6421920692115b0afdca0c9f9c
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsPacket.java b/service-t/src/com/android/server/connectivity/mdns/MdnsPacket.java
index 1fabd49..83ecabc 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsPacket.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsPacket.java
@@ -42,7 +42,7 @@
@NonNull
public final List<MdnsRecord> additionalRecords;
- MdnsPacket(int flags,
+ public MdnsPacket(int flags,
@NonNull List<MdnsRecord> questions,
@NonNull List<MdnsRecord> answers,
@NonNull List<MdnsRecord> authorityRecords,
diff --git a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
index d553210..3c11a24 100644
--- a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
+++ b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
@@ -16,6 +16,8 @@
package com.android.server.connectivity.mdns.util;
+import static com.android.server.connectivity.mdns.MdnsConstants.FLAG_TRUNCATED;
+
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.net.Network;
@@ -23,6 +25,7 @@
import android.os.Handler;
import android.os.SystemClock;
import android.util.ArraySet;
+import android.util.Pair;
import com.android.server.connectivity.mdns.MdnsConstants;
import com.android.server.connectivity.mdns.MdnsPacket;
@@ -30,13 +33,18 @@
import com.android.server.connectivity.mdns.MdnsRecord;
import java.io.IOException;
+import java.net.DatagramPacket;
+import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.Charset;
import java.nio.charset.CharsetEncoder;
import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.HashSet;
+import java.util.List;
import java.util.Set;
/**
@@ -226,6 +234,100 @@
}
/**
+ * Writes the possible query content of an MdnsPacket into the data buffer.
+ *
+ * <p>This method is specifically for query packets. It writes the question and answer sections
+ * into the data buffer only.
+ *
+ * @param packetCreationBuffer The data buffer for the query content.
+ * @param packet The MdnsPacket to be written into the data buffer.
+ * @return A Pair containing:
+ * 1. The remaining MdnsPacket data that could not fit in the buffer.
+ * 2. The length of the data written to the buffer.
+ */
+ @Nullable
+ private static Pair<MdnsPacket, Integer> writePossibleMdnsPacket(
+ @NonNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet) throws IOException {
+ MdnsPacket remainingPacket;
+ final MdnsPacketWriter writer = new MdnsPacketWriter(packetCreationBuffer);
+ writer.writeUInt16(packet.transactionId); // Transaction ID
+
+ final int flagsPos = writer.getWritePosition();
+ writer.writeUInt16(0); // Flags, written later
+ writer.writeUInt16(0); // questions count, written later
+ writer.writeUInt16(0); // answers count, written later
+ writer.writeUInt16(0); // authority entries count, empty session for query
+ writer.writeUInt16(0); // additional records count, empty session for query
+
+ int writtenQuestions = 0;
+ int writtenAnswers = 0;
+ int lastValidPos = writer.getWritePosition();
+ try {
+ for (MdnsRecord record : packet.questions) {
+ // Questions do not have TTL or data
+ record.writeHeaderFields(writer);
+ writtenQuestions++;
+ lastValidPos = writer.getWritePosition();
+ }
+ for (MdnsRecord record : packet.answers) {
+ record.write(writer, 0L);
+ writtenAnswers++;
+ lastValidPos = writer.getWritePosition();
+ }
+ remainingPacket = null;
+ } catch (IOException e) {
+ // Went over the packet limit; truncate
+ if (writtenQuestions == 0 && writtenAnswers == 0) {
+ // No space to write even one record: just throw (as subclass of IOException)
+ throw e;
+ }
+
+ // Set the last valid position as the final position (not as a rewind)
+ writer.rewind(lastValidPos);
+ writer.clearRewind();
+
+ remainingPacket = new MdnsPacket(packet.flags,
+ packet.questions.subList(
+ writtenQuestions, packet.questions.size()),
+ packet.answers.subList(
+ writtenAnswers, packet.answers.size()),
+ Collections.emptyList(), /* authorityRecords */
+ Collections.emptyList() /* additionalRecords */);
+ }
+
+ final int len = writer.getWritePosition();
+ writer.rewind(flagsPos);
+ writer.writeUInt16(packet.flags | (remainingPacket == null ? 0 : FLAG_TRUNCATED));
+ writer.writeUInt16(writtenQuestions);
+ writer.writeUInt16(writtenAnswers);
+ writer.unrewind();
+
+ return Pair.create(remainingPacket, len);
+ }
+
+ /**
+ * Create Datagram packets from given MdnsPacket and InetSocketAddress.
+ *
+ * <p> If the MdnsPacket is too large for a single DatagramPacket, it will be split into
+ * multiple DatagramPackets.
+ */
+ public static List<DatagramPacket> createQueryDatagramPackets(
+ @NonNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet,
+ @NonNull InetSocketAddress destination) throws IOException {
+ final List<DatagramPacket> datagramPackets = new ArrayList<>();
+ MdnsPacket remainingPacket = packet;
+ while (remainingPacket != null) {
+ final Pair<MdnsPacket, Integer> result =
+ writePossibleMdnsPacket(packetCreationBuffer, remainingPacket);
+ remainingPacket = result.first;
+ final int len = result.second;
+ final byte[] outBuffer = Arrays.copyOfRange(packetCreationBuffer, 0, len);
+ datagramPackets.add(new DatagramPacket(outBuffer, 0, outBuffer.length, destination));
+ }
+ return datagramPackets;
+ }
+
+ /**
* Checks if the MdnsRecord needs to be renewed or not.
*
* <p>As per RFC6762 7.1 no need to query if remaining TTL is more than half the original one,
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
index f705bcb..b1a7233 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
@@ -17,6 +17,13 @@
package com.android.server.connectivity.mdns.util
import android.os.Build
+import com.android.server.connectivity.mdns.MdnsConstants
+import com.android.server.connectivity.mdns.MdnsConstants.FLAG_TRUNCATED
+import com.android.server.connectivity.mdns.MdnsPacket
+import com.android.server.connectivity.mdns.MdnsPacketReader
+import com.android.server.connectivity.mdns.MdnsPointerRecord
+import com.android.server.connectivity.mdns.MdnsRecord
+import com.android.server.connectivity.mdns.util.MdnsUtils.createQueryDatagramPackets
import com.android.server.connectivity.mdns.util.MdnsUtils.equalsDnsLabelIgnoreDnsCase
import com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase
import com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLabelsLowerCase
@@ -24,6 +31,8 @@
import com.android.server.connectivity.mdns.util.MdnsUtils.truncateServiceName
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner
+import java.net.DatagramPacket
+import kotlin.test.assertContentEquals
import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
@@ -102,4 +111,67 @@
arrayOf("a", "_other", "_type", "_tcp", "local"),
arrayOf("a", "_SUB", "_type", "_TCP", "local")))
}
+
+ @Test
+ fun testCreateQueryDatagramPackets() {
+ // Question data bytes:
+ // Name label(17)(duplicated labels) + PTR type(2) + cacheFlush(2) = 21
+ //
+ // Known answers data bytes:
+ // Name label(17)(duplicated labels) + PTR type(2) + cacheFlush(2) + receiptTimeMillis(4)
+ // + Data length(2) + Pointer data(18)(duplicated labels) = 45
+ val questions = mutableListOf<MdnsRecord>()
+ val knownAnswers = mutableListOf<MdnsRecord>()
+ for (i in 1..100) {
+ questions.add(MdnsPointerRecord(arrayOf("_testservice$i", "_tcp", "local"), false))
+ knownAnswers.add(MdnsPointerRecord(
+ arrayOf("_testservice$i", "_tcp", "local"),
+ 0L,
+ false,
+ 4_500_000L,
+ arrayOf("MyTestService$i", "_testservice$i", "_tcp", "local")
+ ))
+ }
+ // MdnsPacket data bytes:
+ // Questions(21 * 100) + Answers(45 * 100) = 6600 -> at least 5 packets
+ val query = MdnsPacket(
+ MdnsConstants.FLAGS_QUERY,
+ questions as List<MdnsRecord>,
+ knownAnswers as List<MdnsRecord>,
+ emptyList(),
+ emptyList()
+ )
+ // Expect the oversize MdnsPacket to be separated into 5 DatagramPackets.
+ val bufferSize = 1500
+ val packets = createQueryDatagramPackets(
+ ByteArray(bufferSize),
+ query,
+ MdnsConstants.IPV4_SOCKET_ADDR
+ )
+ assertEquals(5, packets.size)
+ assertTrue(packets.all { packet -> packet.length < bufferSize })
+
+ val mdnsPacket = createMdnsPacketFromMultipleDatagramPackets(packets)
+ assertEquals(query.flags, mdnsPacket.flags)
+ assertContentEquals(query.questions, mdnsPacket.questions)
+ assertContentEquals(query.answers, mdnsPacket.answers)
+ }
+
+ private fun createMdnsPacketFromMultipleDatagramPackets(
+ packets: List<DatagramPacket>
+ ): MdnsPacket {
+ var flags = 0
+ val questions = mutableListOf<MdnsRecord>()
+ val answers = mutableListOf<MdnsRecord>()
+ for ((index, packet) in packets.withIndex()) {
+ val mdnsPacket = MdnsPacket.parse(MdnsPacketReader(packet))
+ if (index != packets.size - 1) {
+ assertTrue((mdnsPacket.flags and FLAG_TRUNCATED) == FLAG_TRUNCATED)
+ }
+ flags = mdnsPacket.flags
+ questions.addAll(mdnsPacket.questions)
+ answers.addAll(mdnsPacket.answers)
+ }
+ return MdnsPacket(flags, questions, answers, emptyList(), emptyList())
+ }
}