Add MdnsUtils#createQueryDatagramPackets method

This method can generate query packets from an oversized
MdnsPacket.

Bug: 312657709
Test: atest FrameworksNetTests
Change-Id: I21c1e082e01e0a6421920692115b0afdca0c9f9c
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsPacket.java b/service-t/src/com/android/server/connectivity/mdns/MdnsPacket.java
index 1fabd49..83ecabc 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsPacket.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsPacket.java
@@ -42,7 +42,7 @@
     @NonNull
     public final List<MdnsRecord> additionalRecords;
 
-    MdnsPacket(int flags,
+    public MdnsPacket(int flags,
             @NonNull List<MdnsRecord> questions,
             @NonNull List<MdnsRecord> answers,
             @NonNull List<MdnsRecord> authorityRecords,
diff --git a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
index d553210..3c11a24 100644
--- a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
+++ b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
@@ -16,6 +16,8 @@
 
 package com.android.server.connectivity.mdns.util;
 
+import static com.android.server.connectivity.mdns.MdnsConstants.FLAG_TRUNCATED;
+
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.net.Network;
@@ -23,6 +25,7 @@
 import android.os.Handler;
 import android.os.SystemClock;
 import android.util.ArraySet;
+import android.util.Pair;
 
 import com.android.server.connectivity.mdns.MdnsConstants;
 import com.android.server.connectivity.mdns.MdnsPacket;
@@ -30,13 +33,18 @@
 import com.android.server.connectivity.mdns.MdnsRecord;
 
 import java.io.IOException;
+import java.net.DatagramPacket;
+import java.net.InetSocketAddress;
 import java.nio.ByteBuffer;
 import java.nio.CharBuffer;
 import java.nio.charset.Charset;
 import java.nio.charset.CharsetEncoder;
 import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Set;
 
 /**
@@ -226,6 +234,100 @@
     }
 
     /**
+     * Writes the possible query content of an MdnsPacket into the data buffer.
+     *
+     * <p>This method is specifically for query packets. It writes the question and answer sections
+     *    into the data buffer only.
+     *
+     * @param packetCreationBuffer The data buffer for the query content.
+     * @param packet The MdnsPacket to be written into the data buffer.
+     * @return A Pair containing:
+     *         1. The remaining MdnsPacket data that could not fit in the buffer.
+     *         2. The length of the data written to the buffer.
+     */
+    @Nullable
+    private static Pair<MdnsPacket, Integer> writePossibleMdnsPacket(
+            @NonNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet) throws IOException {
+        MdnsPacket remainingPacket;
+        final MdnsPacketWriter writer = new MdnsPacketWriter(packetCreationBuffer);
+        writer.writeUInt16(packet.transactionId); // Transaction ID
+
+        final int flagsPos = writer.getWritePosition();
+        writer.writeUInt16(0); // Flags, written later
+        writer.writeUInt16(0); // questions count, written later
+        writer.writeUInt16(0); // answers count, written later
+        writer.writeUInt16(0); // authority entries count, empty session for query
+        writer.writeUInt16(0); // additional records count, empty session for query
+
+        int writtenQuestions = 0;
+        int writtenAnswers = 0;
+        int lastValidPos = writer.getWritePosition();
+        try {
+            for (MdnsRecord record : packet.questions) {
+                // Questions do not have TTL or data
+                record.writeHeaderFields(writer);
+                writtenQuestions++;
+                lastValidPos = writer.getWritePosition();
+            }
+            for (MdnsRecord record : packet.answers) {
+                record.write(writer, 0L);
+                writtenAnswers++;
+                lastValidPos = writer.getWritePosition();
+            }
+            remainingPacket = null;
+        } catch (IOException e) {
+            // Went over the packet limit; truncate
+            if (writtenQuestions == 0 && writtenAnswers == 0) {
+                // No space to write even one record: just throw (as subclass of IOException)
+                throw e;
+            }
+
+            // Set the last valid position as the final position (not as a rewind)
+            writer.rewind(lastValidPos);
+            writer.clearRewind();
+
+            remainingPacket = new MdnsPacket(packet.flags,
+                    packet.questions.subList(
+                            writtenQuestions, packet.questions.size()),
+                    packet.answers.subList(
+                            writtenAnswers, packet.answers.size()),
+                    Collections.emptyList(), /* authorityRecords */
+                    Collections.emptyList() /* additionalRecords */);
+        }
+
+        final int len = writer.getWritePosition();
+        writer.rewind(flagsPos);
+        writer.writeUInt16(packet.flags | (remainingPacket == null ? 0 : FLAG_TRUNCATED));
+        writer.writeUInt16(writtenQuestions);
+        writer.writeUInt16(writtenAnswers);
+        writer.unrewind();
+
+        return Pair.create(remainingPacket, len);
+    }
+
+    /**
+     * Create Datagram packets from given MdnsPacket and InetSocketAddress.
+     *
+     * <p> If the MdnsPacket is too large for a single DatagramPacket, it will be split into
+     *     multiple DatagramPackets.
+     */
+    public static List<DatagramPacket> createQueryDatagramPackets(
+            @NonNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet,
+            @NonNull InetSocketAddress destination) throws IOException {
+        final List<DatagramPacket> datagramPackets = new ArrayList<>();
+        MdnsPacket remainingPacket = packet;
+        while (remainingPacket != null) {
+            final Pair<MdnsPacket, Integer> result =
+                    writePossibleMdnsPacket(packetCreationBuffer, remainingPacket);
+            remainingPacket = result.first;
+            final int len = result.second;
+            final byte[] outBuffer = Arrays.copyOfRange(packetCreationBuffer, 0, len);
+            datagramPackets.add(new DatagramPacket(outBuffer, 0, outBuffer.length, destination));
+        }
+        return datagramPackets;
+    }
+
+    /**
      * Checks if the MdnsRecord needs to be renewed or not.
      *
      * <p>As per RFC6762 7.1 no need to query if remaining TTL is more than half the original one,
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
index f705bcb..b1a7233 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
@@ -17,6 +17,13 @@
 package com.android.server.connectivity.mdns.util
 
 import android.os.Build
+import com.android.server.connectivity.mdns.MdnsConstants
+import com.android.server.connectivity.mdns.MdnsConstants.FLAG_TRUNCATED
+import com.android.server.connectivity.mdns.MdnsPacket
+import com.android.server.connectivity.mdns.MdnsPacketReader
+import com.android.server.connectivity.mdns.MdnsPointerRecord
+import com.android.server.connectivity.mdns.MdnsRecord
+import com.android.server.connectivity.mdns.util.MdnsUtils.createQueryDatagramPackets
 import com.android.server.connectivity.mdns.util.MdnsUtils.equalsDnsLabelIgnoreDnsCase
 import com.android.server.connectivity.mdns.util.MdnsUtils.equalsIgnoreDnsCase
 import com.android.server.connectivity.mdns.util.MdnsUtils.toDnsLabelsLowerCase
@@ -24,6 +31,8 @@
 import com.android.server.connectivity.mdns.util.MdnsUtils.truncateServiceName
 import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.DevSdkIgnoreRunner
+import java.net.DatagramPacket
+import kotlin.test.assertContentEquals
 import org.junit.Assert.assertArrayEquals
 import org.junit.Assert.assertEquals
 import org.junit.Assert.assertFalse
@@ -102,4 +111,67 @@
                 arrayOf("a", "_other", "_type", "_tcp", "local"),
                 arrayOf("a", "_SUB", "_type", "_TCP", "local")))
     }
+
+    @Test
+    fun testCreateQueryDatagramPackets() {
+        // Question data bytes:
+        // Name label(17)(duplicated labels) + PTR type(2) + cacheFlush(2) = 21
+        //
+        // Known answers data bytes:
+        // Name label(17)(duplicated labels) + PTR type(2) + cacheFlush(2) + receiptTimeMillis(4)
+        // + Data length(2) + Pointer data(18)(duplicated labels) = 45
+        val questions = mutableListOf<MdnsRecord>()
+        val knownAnswers = mutableListOf<MdnsRecord>()
+        for (i in 1..100) {
+            questions.add(MdnsPointerRecord(arrayOf("_testservice$i", "_tcp", "local"), false))
+            knownAnswers.add(MdnsPointerRecord(
+                    arrayOf("_testservice$i", "_tcp", "local"),
+                    0L,
+                    false,
+                    4_500_000L,
+                    arrayOf("MyTestService$i", "_testservice$i", "_tcp", "local")
+            ))
+        }
+        // MdnsPacket data bytes:
+        // Questions(21 * 100) + Answers(45 * 100) = 6600 -> at least 5 packets
+        val query = MdnsPacket(
+                MdnsConstants.FLAGS_QUERY,
+                questions as List<MdnsRecord>,
+                knownAnswers as List<MdnsRecord>,
+                emptyList(),
+                emptyList()
+        )
+        // Expect the oversize MdnsPacket to be separated into 5 DatagramPackets.
+        val bufferSize = 1500
+        val packets = createQueryDatagramPackets(
+                ByteArray(bufferSize),
+                query,
+                MdnsConstants.IPV4_SOCKET_ADDR
+        )
+        assertEquals(5, packets.size)
+        assertTrue(packets.all { packet -> packet.length < bufferSize })
+
+        val mdnsPacket = createMdnsPacketFromMultipleDatagramPackets(packets)
+        assertEquals(query.flags, mdnsPacket.flags)
+        assertContentEquals(query.questions, mdnsPacket.questions)
+        assertContentEquals(query.answers, mdnsPacket.answers)
+    }
+
+    private fun createMdnsPacketFromMultipleDatagramPackets(
+            packets: List<DatagramPacket>
+    ): MdnsPacket {
+        var flags = 0
+        val questions = mutableListOf<MdnsRecord>()
+        val answers = mutableListOf<MdnsRecord>()
+        for ((index, packet) in packets.withIndex()) {
+            val mdnsPacket = MdnsPacket.parse(MdnsPacketReader(packet))
+            if (index != packets.size - 1) {
+                assertTrue((mdnsPacket.flags and FLAG_TRUNCATED) == FLAG_TRUNCATED)
+            }
+            flags = mdnsPacket.flags
+            questions.addAll(mdnsPacket.questions)
+            answers.addAll(mdnsPacket.answers)
+        }
+        return MdnsPacket(flags, questions, answers, emptyList(), emptyList())
+    }
 }