Factor out response decoding into MdnsPacket

Factor out generic packet decoding into MdnsPacket, out from
MdnsResponseDecoder.

This allows reusing the same decoding code for replies, and other kinds
of MdnsPackets.

Bug: 241738458
Test: atest MdnsResponseDecoderTests MdnsPacketTest
Change-Id: I4380f80240ed5a367accfc1b0c595967ee475578
diff --git a/service-t/src/com/android/server/mdns/MdnsPacket.java b/service-t/src/com/android/server/mdns/MdnsPacket.java
index eae084a..27002b9 100644
--- a/service-t/src/com/android/server/mdns/MdnsPacket.java
+++ b/service-t/src/com/android/server/mdns/MdnsPacket.java
@@ -16,6 +16,13 @@
 
 package com.android.server.connectivity.mdns;
 
+import android.annotation.NonNull;
+import android.annotation.Nullable;
+import android.util.Log;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 
@@ -23,21 +30,202 @@
  * A class holding data that can be included in a mDNS packet.
  */
 public class MdnsPacket {
+    private static final String TAG = MdnsPacket.class.getSimpleName();
+
     public final int flags;
+    @NonNull
     public final List<MdnsRecord> questions;
+    @NonNull
     public final List<MdnsRecord> answers;
+    @NonNull
     public final List<MdnsRecord> authorityRecords;
+    @NonNull
     public final List<MdnsRecord> additionalRecords;
 
     MdnsPacket(int flags,
-            List<MdnsRecord> questions,
-            List<MdnsRecord> answers,
-            List<MdnsRecord> authorityRecords,
-            List<MdnsRecord> additionalRecords) {
+            @NonNull List<MdnsRecord> questions,
+            @NonNull List<MdnsRecord> answers,
+            @NonNull List<MdnsRecord> authorityRecords,
+            @NonNull List<MdnsRecord> additionalRecords) {
         this.flags = flags;
         this.questions = Collections.unmodifiableList(questions);
         this.answers = Collections.unmodifiableList(answers);
         this.authorityRecords = Collections.unmodifiableList(authorityRecords);
         this.additionalRecords = Collections.unmodifiableList(additionalRecords);
     }
+
+    /**
+     * Exception thrown on parse errors.
+     */
+    public static class ParseException extends IOException {
+        public final int code;
+
+        public ParseException(int code, @NonNull String message, @Nullable Throwable cause) {
+            super(message, cause);
+            this.code = code;
+        }
+    }
+
+    /**
+     * Parse the packet in the provided {@link MdnsPacketReader}.
+     */
+    @NonNull
+    public static MdnsPacket parse(@NonNull MdnsPacketReader reader) throws ParseException {
+        final int flags;
+        try {
+            reader.readUInt16(); // transaction ID (not used)
+            flags = reader.readUInt16();
+        } catch (EOFException e) {
+            throw new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE,
+                    "Reached the end of the mDNS response unexpectedly.", e);
+        }
+        return parseRecordsSection(reader, flags);
+    }
+
+    /**
+     * Parse the records section of a mDNS packet in the provided {@link MdnsPacketReader}.
+     *
+     * The records section starts with the questions count, just after the packet flags.
+     */
+    public static MdnsPacket parseRecordsSection(@NonNull MdnsPacketReader reader, int flags)
+            throws ParseException {
+        try {
+            final int numQuestions = reader.readUInt16();
+            final int numAnswers = reader.readUInt16();
+            final int numAuthority = reader.readUInt16();
+            final int numAdditional = reader.readUInt16();
+
+            final ArrayList<MdnsRecord> questions = parseRecords(reader, numQuestions, true);
+            final ArrayList<MdnsRecord> answers = parseRecords(reader, numAnswers, false);
+            final ArrayList<MdnsRecord> authority = parseRecords(reader, numAuthority, false);
+            final ArrayList<MdnsRecord> additional = parseRecords(reader, numAdditional, false);
+
+            return new MdnsPacket(flags, questions, answers, authority, additional);
+        } catch (EOFException e) {
+            throw new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE,
+                    "Reached the end of the mDNS response unexpectedly.", e);
+        }
+    }
+
+    private static ArrayList<MdnsRecord> parseRecords(@NonNull MdnsPacketReader reader, int count,
+            boolean isQuestion)
+            throws ParseException {
+        final ArrayList<MdnsRecord> records = new ArrayList<>(count);
+        for (int i = 0; i < count; ++i) {
+            final MdnsRecord record = parseRecord(reader, isQuestion);
+            if (record != null) {
+                records.add(record);
+            }
+        }
+        return records;
+    }
+
+    @Nullable
+    private static MdnsRecord parseRecord(@NonNull MdnsPacketReader reader, boolean isQuestion)
+            throws ParseException {
+        String[] name;
+        try {
+            name = reader.readLabels();
+        } catch (IOException e) {
+            throw new ParseException(MdnsResponseErrorCode.ERROR_READING_RECORD_NAME,
+                    "Failed to read labels from mDNS response.", e);
+        }
+
+        final int type;
+        try {
+            type = reader.readUInt16();
+        } catch (EOFException e) {
+            throw new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE,
+                    "Reached the end of the mDNS response unexpectedly.", e);
+        }
+
+        switch (type) {
+            case MdnsRecord.TYPE_A: {
+                try {
+                    return new MdnsInetAddressRecord(name, MdnsRecord.TYPE_A, reader, isQuestion);
+                } catch (IOException e) {
+                    throw new ParseException(MdnsResponseErrorCode.ERROR_READING_A_RDATA,
+                            "Failed to read A record from mDNS response.", e);
+                }
+            }
+
+            case MdnsRecord.TYPE_AAAA: {
+                try {
+                    return new MdnsInetAddressRecord(name,
+                            MdnsRecord.TYPE_AAAA, reader, isQuestion);
+                } catch (IOException e) {
+                    throw new ParseException(MdnsResponseErrorCode.ERROR_READING_AAAA_RDATA,
+                            "Failed to read AAAA record from mDNS response.", e);
+                }
+            }
+
+            case MdnsRecord.TYPE_PTR: {
+                try {
+                    return new MdnsPointerRecord(name, reader, isQuestion);
+                } catch (IOException e) {
+                    throw new ParseException(MdnsResponseErrorCode.ERROR_READING_PTR_RDATA,
+                            "Failed to read PTR record from mDNS response.", e);
+                }
+            }
+
+            case MdnsRecord.TYPE_SRV: {
+                try {
+                    return new MdnsServiceRecord(name, reader, isQuestion);
+                } catch (IOException e) {
+                    throw new ParseException(MdnsResponseErrorCode.ERROR_READING_SRV_RDATA,
+                            "Failed to read SRV record from mDNS response.", e);
+                }
+            }
+
+            case MdnsRecord.TYPE_TXT: {
+                try {
+                    return new MdnsTextRecord(name, reader, isQuestion);
+                } catch (IOException e) {
+                    throw new ParseException(MdnsResponseErrorCode.ERROR_READING_TXT_RDATA,
+                            "Failed to read TXT record from mDNS response.", e);
+                }
+            }
+
+            case MdnsRecord.TYPE_NSEC: {
+                try {
+                    return new MdnsNsecRecord(name, reader, isQuestion);
+                } catch (IOException e) {
+                    throw new ParseException(MdnsResponseErrorCode.ERROR_READING_NSEC_RDATA,
+                            "Failed to read NSEC record from mDNS response.", e);
+                }
+            }
+
+            case MdnsRecord.TYPE_ANY: {
+                try {
+                    return new MdnsAnyRecord(name, reader);
+                } catch (IOException e) {
+                    throw new ParseException(MdnsResponseErrorCode.ERROR_READING_ANY_RDATA,
+                            "Failed to read TYPE_ANY record from mDNS response.", e);
+                }
+            }
+
+            default: {
+                try {
+                    if (MdnsAdvertiser.DBG) {
+                        Log.i(TAG, "Skipping parsing of record of unhandled type " + type);
+                    }
+                    skipMdnsRecord(reader, isQuestion);
+                    return null;
+                } catch (IOException e) {
+                    throw new ParseException(MdnsResponseErrorCode.ERROR_SKIPPING_UNKNOWN_RECORD,
+                            "Failed to skip mDNS record.", e);
+                }
+            }
+        }
+    }
+
+    private static void skipMdnsRecord(@NonNull MdnsPacketReader reader, boolean isQuestion)
+            throws IOException {
+        reader.skip(2); // Skip the class
+        if (isQuestion) return;
+        // Skip TTL and data
+        reader.skip(4);
+        int dataLength = reader.readUInt16();
+        reader.skip(dataLength);
+    }
 }
diff --git a/service-t/src/com/android/server/mdns/MdnsResponseDecoder.java b/service-t/src/com/android/server/mdns/MdnsResponseDecoder.java
index 50f2069..82da2e4 100644
--- a/service-t/src/com/android/server/mdns/MdnsResponseDecoder.java
+++ b/service-t/src/com/android/server/mdns/MdnsResponseDecoder.java
@@ -24,11 +24,9 @@
 import com.android.server.connectivity.mdns.util.MdnsLogger;
 
 import java.io.EOFException;
-import java.io.IOException;
 import java.net.DatagramPacket;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.LinkedList;
 import java.util.List;
 
 /** A class that decodes mDNS responses from UDP packets. */
@@ -48,12 +46,6 @@
         this.serviceType = serviceType;
     }
 
-    private static void skipMdnsRecord(MdnsPacketReader reader) throws IOException {
-        reader.skip(2 + 4); // skip the class and TTL
-        int dataLength = reader.readUInt16();
-        reader.skip(dataLength);
-    }
-
     private static MdnsResponse findResponseWithPointer(
             List<MdnsResponse> responses, String[] pointer) {
         if (responses != null) {
@@ -120,7 +112,7 @@
             int interfaceIndex, @Nullable Network network) {
         MdnsPacketReader reader = new MdnsPacketReader(recvbuf, length);
 
-        List<MdnsRecord> records;
+        final MdnsPacket mdnsPacket;
         try {
             reader.readUInt16(); // transaction ID (not used)
             int flags = reader.readUInt16();
@@ -128,111 +120,25 @@
                 return MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE;
             }
 
-            int numQuestions = reader.readUInt16();
-            int numAnswers = reader.readUInt16();
-            int numAuthority = reader.readUInt16();
-            int numRecords = reader.readUInt16();
-
-            LOGGER.log(String.format(
-                    "num questions: %d, num answers: %d, num authority: %d, num records: %d",
-                    numQuestions, numAnswers, numAuthority, numRecords));
-
-            if (numAnswers < 1) {
+            mdnsPacket = MdnsPacket.parseRecordsSection(reader, flags);
+            if (mdnsPacket.answers.size() < 1) {
                 return MdnsResponseErrorCode.ERROR_NO_ANSWERS;
             }
-
-            records = new LinkedList<>();
-
-            for (int i = 0; i < (numAnswers + numAuthority + numRecords); ++i) {
-                String[] name;
-                try {
-                    name = reader.readLabels();
-                } catch (IOException e) {
-                    LOGGER.e("Failed to read labels from mDNS response.", e);
-                    return MdnsResponseErrorCode.ERROR_READING_RECORD_NAME;
-                }
-                int type = reader.readUInt16();
-
-                switch (type) {
-                    case MdnsRecord.TYPE_A: {
-                        try {
-                            records.add(new MdnsInetAddressRecord(name, MdnsRecord.TYPE_A, reader));
-                        } catch (IOException e) {
-                            LOGGER.e("Failed to read A record from mDNS response.", e);
-                            return MdnsResponseErrorCode.ERROR_READING_A_RDATA;
-                        }
-                        break;
-                    }
-
-                    case MdnsRecord.TYPE_AAAA: {
-                        try {
-                            // AAAA should only contain the IPv6 address.
-                            MdnsInetAddressRecord record =
-                                    new MdnsInetAddressRecord(name, MdnsRecord.TYPE_AAAA, reader);
-                            if (record.getInet6Address() != null) {
-                                records.add(record);
-                            }
-                        } catch (IOException e) {
-                            LOGGER.e("Failed to read AAAA record from mDNS response.", e);
-                            return MdnsResponseErrorCode.ERROR_READING_AAAA_RDATA;
-                        }
-                        break;
-                    }
-
-                    case MdnsRecord.TYPE_PTR: {
-                        try {
-                            records.add(new MdnsPointerRecord(name, reader));
-                        } catch (IOException e) {
-                            LOGGER.e("Failed to read PTR record from mDNS response.", e);
-                            return MdnsResponseErrorCode.ERROR_READING_PTR_RDATA;
-                        }
-                        break;
-                    }
-
-                    case MdnsRecord.TYPE_SRV: {
-                        if (name.length == 4) {
-                            try {
-                                records.add(new MdnsServiceRecord(name, reader));
-                            } catch (IOException e) {
-                                LOGGER.e("Failed to read SRV record from mDNS response.", e);
-                                return MdnsResponseErrorCode.ERROR_READING_SRV_RDATA;
-                            }
-                        } else {
-                            try {
-                                skipMdnsRecord(reader);
-                            } catch (IOException e) {
-                                LOGGER.e("Failed to skip SVR record from mDNS response.", e);
-                                return MdnsResponseErrorCode.ERROR_SKIPPING_SRV_RDATA;
-                            }
-                        }
-                        break;
-                    }
-
-                    case MdnsRecord.TYPE_TXT: {
-                        try {
-                            records.add(new MdnsTextRecord(name, reader));
-                        } catch (IOException e) {
-                            LOGGER.e("Failed to read TXT record from mDNS response.", e);
-                            return MdnsResponseErrorCode.ERROR_READING_TXT_RDATA;
-                        }
-                        break;
-                    }
-
-                    default: {
-                        try {
-                            skipMdnsRecord(reader);
-                        } catch (IOException e) {
-                            LOGGER.e("Failed to skip mDNS record.", e);
-                            return MdnsResponseErrorCode.ERROR_SKIPPING_UNKNOWN_RECORD;
-                        }
-                    }
-                }
-            }
         } catch (EOFException e) {
             LOGGER.e("Reached the end of the mDNS response unexpectedly.", e);
             return MdnsResponseErrorCode.ERROR_END_OF_FILE;
+        } catch (MdnsPacket.ParseException e) {
+            LOGGER.e(e.getMessage(), e);
+            return e.code;
         }
 
+        final ArrayList<MdnsRecord> records = new ArrayList<>(
+                mdnsPacket.questions.size() + mdnsPacket.answers.size()
+                        + mdnsPacket.authorityRecords.size() + mdnsPacket.additionalRecords.size());
+        records.addAll(mdnsPacket.answers);
+        records.addAll(mdnsPacket.authorityRecords);
+        records.addAll(mdnsPacket.additionalRecords);
+
         // The response records are structured in a hierarchy, where some records reference
         // others, as follows:
         //
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
index 650607d..334f99d 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
@@ -91,7 +91,7 @@
         scapy.raw(scapy.dns_compress(scapy.DNS(rd=0, qr=1, aa=1,
         qd = None,
         an =
-        scapy.DNSRR(type='PTR', rrname='123.0.2.192.in-addr.arpa.', rdata='Android.local',
+        scapy.DNSRR(type='PTR', rrname='123.2.0.192.in-addr.arpa.', rdata='Android.local',
             rclass=0x8001, ttl=120) /
         scapy.DNSRR(type='PTR',
             rrname='3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.0.1.0.0.2.ip6.arpa',
@@ -111,8 +111,8 @@
         scapy.DNSRR(type='AAAA', rrname='Android.local', rclass=0x8001, rdata='2001:db8::456',
             ttl=120),
         ar =
-        scapy.DNSRRNSEC(rrname='123.0.2.192.in-addr.arpa.', rclass=0x8001, ttl=120,
-            nextname='123.0.2.192.in-addr.arpa.', typebitmaps=[12]) /
+        scapy.DNSRRNSEC(rrname='123.2.0.192.in-addr.arpa.', rclass=0x8001, ttl=120,
+            nextname='123.2.0.192.in-addr.arpa.', typebitmaps=[12]) /
         scapy.DNSRRNSEC(
             rrname='3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.0.1.0.0.2.ip6.arpa',
             rclass=0x8001, ttl=120,
@@ -131,7 +131,7 @@
             typebitmaps=[1, 28]))
         )).hex().upper()
         */
-        val expected = "00008400000000090000000503313233013001320331393207696E2D61646472046172706" +
+        val expected = "00008400000000090000000503313233013201300331393207696E2D61646472046172706" +
                 "100000C800100000078000F07416E64726F6964056C6F63616C00013301320131013001300130013" +
                 "00130013001300130013001300130013001300130013001300130013001300130013001380142014" +
                 "40130013101300130013203697036C020000C8001000000780002C030013601350134C045000C800" +
@@ -149,7 +149,7 @@
         val v4Addr = parseNumericAddress("192.0.2.123")
         val v6Addr1 = parseNumericAddress("2001:DB8::123")
         val v6Addr2 = parseNumericAddress("2001:DB8::456")
-        val v4AddrRev = arrayOf("123", "0", "2", "192", "in-addr", "arpa")
+        val v4AddrRev = getReverseDnsAddress(v4Addr)
         val v6Addr1Rev = getReverseDnsAddress(v6Addr1)
         val v6Addr2Rev = getReverseDnsAddress(v6Addr2)
 
@@ -254,7 +254,10 @@
             verify(socket, atLeast(i + 1)).send(any())
             val now = SystemClock.elapsedRealtime()
             assertTrue(now > timeStart + startDelay + i * FIRST_ANNOUNCES_DELAY)
-            assertTrue(now < timeStart + startDelay + (i + 1) * FIRST_ANNOUNCES_DELAY)
+            // Loops can be much slower than the expected timing (>100ms delay), use
+            // TEST_TIMEOUT_MS as tolerance.
+            assertTrue(now < timeStart + startDelay + (i + 1) * FIRST_ANNOUNCES_DELAY +
+                TEST_TIMEOUT_MS)
         }
 
         // Subsequent announces should happen quickly (NEXT_ANNOUNCES_DELAY)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketTest.kt
new file mode 100644
index 0000000..f88da1f
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketTest.kt
@@ -0,0 +1,70 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns
+
+import android.net.InetAddresses
+import com.android.net.module.util.HexDump
+import com.android.testutils.DevSdkIgnoreRunner
+import kotlin.test.assertContentEquals
+import kotlin.test.assertEquals
+import kotlin.test.assertTrue
+import org.junit.Test
+import org.junit.runner.RunWith
+
+@RunWith(DevSdkIgnoreRunner::class)
+class MdnsPacketTest {
+    @Test
+    fun testParseQuery() {
+        // Probe packet with 1 question for Android.local, and 4 additionalRecords with 4 addresses
+        // for Android.local (similar to legacy mdnsresponder probes, although it used to put 4
+        // identical questions(!!) for Android.local when there were 4 addresses).
+        val packetHex = "00000000000100000004000007416e64726f6964056c6f63616c0000ff0001c00c000100" +
+                "01000000780004c000027bc00c001c000100000078001020010db8000000000000000000000123c0" +
+                "0c001c000100000078001020010db8000000000000000000000456c00c001c000100000078001020" +
+                "010db8000000000000000000000789"
+
+        val bytes = HexDump.hexStringToByteArray(packetHex)
+        val reader = MdnsPacketReader(bytes, bytes.size)
+        val packet = MdnsPacket.parse(reader)
+
+        assertEquals(1, packet.questions.size)
+        assertEquals(0, packet.answers.size)
+        assertEquals(4, packet.authorityRecords.size)
+        assertEquals(0, packet.additionalRecords.size)
+
+        val hostname = arrayOf("Android", "local")
+        packet.questions[0].let {
+            assertTrue(it is MdnsAnyRecord)
+            assertContentEquals(hostname, it.name)
+        }
+
+        packet.authorityRecords.forEach {
+            assertTrue(it is MdnsInetAddressRecord)
+            assertContentEquals(hostname, it.name)
+            assertEquals(120000, it.ttl)
+        }
+
+        assertEquals(InetAddresses.parseNumericAddress("192.0.2.123"),
+                (packet.authorityRecords[0] as MdnsInetAddressRecord).inet4Address)
+        assertEquals(InetAddresses.parseNumericAddress("2001:db8::123"),
+                (packet.authorityRecords[1] as MdnsInetAddressRecord).inet6Address)
+        assertEquals(InetAddresses.parseNumericAddress("2001:db8::456"),
+                (packet.authorityRecords[2] as MdnsInetAddressRecord).inet6Address)
+        assertEquals(InetAddresses.parseNumericAddress("2001:db8::789"),
+                (packet.authorityRecords[3] as MdnsInetAddressRecord).inet6Address)
+    }
+}