Merge "Factor out response decoding into MdnsPacket"
diff --git a/service-t/src/com/android/server/mdns/MdnsPacket.java b/service-t/src/com/android/server/mdns/MdnsPacket.java
index eae084a..27002b9 100644
--- a/service-t/src/com/android/server/mdns/MdnsPacket.java
+++ b/service-t/src/com/android/server/mdns/MdnsPacket.java
@@ -16,6 +16,13 @@
 
 package com.android.server.connectivity.mdns;
 
+import android.annotation.NonNull;
+import android.annotation.Nullable;
+import android.util.Log;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 
@@ -23,21 +30,248 @@
  * A class holding data that can be included in a mDNS packet.
  */
 public class MdnsPacket {
+    private static final String TAG = MdnsPacket.class.getSimpleName();
+
+    /**
+     * Upper bound on the initial capacity of the record lists built while parsing. Record
+     * counts are read from the packet header, which is not trusted, so they are not used
+     * as-is to presize allocations.
+     */
+    private static final int MAX_INITIAL_RECORDS_CAPACITY = 16;
+
     public final int flags;
+    @NonNull
     public final List<MdnsRecord> questions;
+    @NonNull
     public final List<MdnsRecord> answers;
+    @NonNull
     public final List<MdnsRecord> authorityRecords;
+    @NonNull
     public final List<MdnsRecord> additionalRecords;
 
     MdnsPacket(int flags,
-            List<MdnsRecord> questions,
-            List<MdnsRecord> answers,
-            List<MdnsRecord> authorityRecords,
-            List<MdnsRecord> additionalRecords) {
+            @NonNull List<MdnsRecord> questions,
+            @NonNull List<MdnsRecord> answers,
+            @NonNull List<MdnsRecord> authorityRecords,
+            @NonNull List<MdnsRecord> additionalRecords) {
         this.flags = flags;
         this.questions = Collections.unmodifiableList(questions);
         this.answers = Collections.unmodifiableList(answers);
         this.authorityRecords = Collections.unmodifiableList(authorityRecords);
         this.additionalRecords = Collections.unmodifiableList(additionalRecords);
     }
+
+    /**
+     * Exception thrown on parse errors.
+     *
+     * <p>{@link #code} holds the {@link MdnsResponseErrorCode} constant identifying which
+     * part of the packet could not be parsed.
+     */
+    public static class ParseException extends IOException {
+        public final int code;
+
+        public ParseException(int code, @NonNull String message, @Nullable Throwable cause) {
+            super(message, cause);
+            this.code = code;
+        }
+    }
+
+    /**
+     * Parse the packet in the provided {@link MdnsPacketReader}.
+     *
+     * <p>Reads the header (transaction ID, which is ignored, then flags), then the records
+     * section as done by {@link #parseRecordsSection}.
+     *
+     * @throws ParseException if the packet is truncated or a record cannot be parsed.
+     */
+    @NonNull
+    public static MdnsPacket parse(@NonNull MdnsPacketReader reader) throws ParseException {
+        final int flags;
+        try {
+            reader.readUInt16(); // transaction ID (not used)
+            flags = reader.readUInt16();
+        } catch (EOFException e) {
+            throw unexpectedEndOfFile(e);
+        }
+        return parseRecordsSection(reader, flags);
+    }
+
+    /**
+     * Parse the records section of a mDNS packet in the provided {@link MdnsPacketReader}.
+     *
+     * <p>The records section starts with the questions count, just after the packet flags.
+     * Records of unhandled types are skipped, so each returned list may hold fewer records
+     * than the count announced in the packet header.
+     *
+     * @param flags the packet flags, already read by the caller.
+     * @throws ParseException if the packet is truncated or a record cannot be parsed.
+     */
+    @NonNull
+    public static MdnsPacket parseRecordsSection(@NonNull MdnsPacketReader reader, int flags)
+            throws ParseException {
+        try {
+            final int numQuestions = reader.readUInt16();
+            final int numAnswers = reader.readUInt16();
+            final int numAuthority = reader.readUInt16();
+            final int numAdditional = reader.readUInt16();
+
+            final ArrayList<MdnsRecord> questions = parseRecords(reader, numQuestions, true);
+            final ArrayList<MdnsRecord> answers = parseRecords(reader, numAnswers, false);
+            final ArrayList<MdnsRecord> authority = parseRecords(reader, numAuthority, false);
+            final ArrayList<MdnsRecord> additional = parseRecords(reader, numAdditional, false);
+
+            return new MdnsPacket(flags, questions, answers, authority, additional);
+        } catch (EOFException e) {
+            throw unexpectedEndOfFile(e);
+        }
+    }
+
+    /**
+     * Parse {@code count} consecutive records, dropping those of unhandled types.
+     */
+    @NonNull
+    private static ArrayList<MdnsRecord> parseRecords(@NonNull MdnsPacketReader reader, int count,
+            boolean isQuestion)
+            throws ParseException {
+        // count comes from the packet header: cap the presized capacity so that a short,
+        // malformed packet announcing many records cannot trigger a large allocation.
+        final ArrayList<MdnsRecord> records =
+                new ArrayList<>(Math.min(count, MAX_INITIAL_RECORDS_CAPACITY));
+        for (int i = 0; i < count; ++i) {
+            final MdnsRecord record = parseRecord(reader, isQuestion);
+            if (record != null) {
+                records.add(record);
+            }
+        }
+        return records;
+    }
+
+    /**
+     * Parse a single record starting at the current position of the reader.
+     *
+     * @return the parsed record, or {@code null} if its type is not handled and it was skipped.
+     * @throws ParseException if the record cannot be read or skipped.
+     */
+    @Nullable
+    private static MdnsRecord parseRecord(@NonNull MdnsPacketReader reader, boolean isQuestion)
+            throws ParseException {
+        final String[] name;
+        try {
+            name = reader.readLabels();
+        } catch (IOException e) {
+            throw new ParseException(MdnsResponseErrorCode.ERROR_READING_RECORD_NAME,
+                    "Failed to read labels from mDNS response.", e);
+        }
+
+        final int type;
+        try {
+            type = reader.readUInt16();
+        } catch (EOFException e) {
+            throw unexpectedEndOfFile(e);
+        }
+
+        switch (type) {
+            case MdnsRecord.TYPE_A: {
+                try {
+                    return new MdnsInetAddressRecord(name, MdnsRecord.TYPE_A, reader, isQuestion);
+                } catch (IOException e) {
+                    throw new ParseException(MdnsResponseErrorCode.ERROR_READING_A_RDATA,
+                            "Failed to read A record from mDNS response.", e);
+                }
+            }
+
+            case MdnsRecord.TYPE_AAAA: {
+                try {
+                    return new MdnsInetAddressRecord(name,
+                            MdnsRecord.TYPE_AAAA, reader, isQuestion);
+                } catch (IOException e) {
+                    throw new ParseException(MdnsResponseErrorCode.ERROR_READING_AAAA_RDATA,
+                            "Failed to read AAAA record from mDNS response.", e);
+                }
+            }
+
+            case MdnsRecord.TYPE_PTR: {
+                try {
+                    return new MdnsPointerRecord(name, reader, isQuestion);
+                } catch (IOException e) {
+                    throw new ParseException(MdnsResponseErrorCode.ERROR_READING_PTR_RDATA,
+                            "Failed to read PTR record from mDNS response.", e);
+                }
+            }
+
+            case MdnsRecord.TYPE_SRV: {
+                try {
+                    return new MdnsServiceRecord(name, reader, isQuestion);
+                } catch (IOException e) {
+                    throw new ParseException(MdnsResponseErrorCode.ERROR_READING_SRV_RDATA,
+                            "Failed to read SRV record from mDNS response.", e);
+                }
+            }
+
+            case MdnsRecord.TYPE_TXT: {
+                try {
+                    return new MdnsTextRecord(name, reader, isQuestion);
+                } catch (IOException e) {
+                    throw new ParseException(MdnsResponseErrorCode.ERROR_READING_TXT_RDATA,
+                            "Failed to read TXT record from mDNS response.", e);
+                }
+            }
+
+            case MdnsRecord.TYPE_NSEC: {
+                try {
+                    return new MdnsNsecRecord(name, reader, isQuestion);
+                } catch (IOException e) {
+                    throw new ParseException(MdnsResponseErrorCode.ERROR_READING_NSEC_RDATA,
+                            "Failed to read NSEC record from mDNS response.", e);
+                }
+            }
+
+            case MdnsRecord.TYPE_ANY: {
+                try {
+                    return new MdnsAnyRecord(name, reader);
+                } catch (IOException e) {
+                    throw new ParseException(MdnsResponseErrorCode.ERROR_READING_ANY_RDATA,
+                            "Failed to read TYPE_ANY record from mDNS response.", e);
+                }
+            }
+
+            default: {
+                try {
+                    if (MdnsAdvertiser.DBG) {
+                        Log.i(TAG, "Skipping parsing of record of unhandled type " + type);
+                    }
+                    skipMdnsRecord(reader, isQuestion);
+                    return null;
+                } catch (IOException e) {
+                    throw new ParseException(MdnsResponseErrorCode.ERROR_SKIPPING_UNKNOWN_RECORD,
+                            "Failed to skip mDNS record.", e);
+                }
+            }
+        }
+    }
+
+    /**
+     * Skip the remainder of a record whose name and type have already been read.
+     *
+     * <p>Questions only have a class after the type; other records also have a TTL and
+     * length-prefixed data.
+     */
+    private static void skipMdnsRecord(@NonNull MdnsPacketReader reader, boolean isQuestion)
+            throws IOException {
+        reader.skip(2); // Skip the class
+        if (isQuestion) return;
+        // Skip TTL and data
+        reader.skip(4);
+        int dataLength = reader.readUInt16();
+        reader.skip(dataLength);
+    }
+
+    /**
+     * Create the {@link ParseException} reported when the packet ends before parsing completes.
+     */
+    @NonNull
+    private static ParseException unexpectedEndOfFile(@NonNull EOFException cause) {
+        return new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE,
+                "Reached the end of the mDNS response unexpectedly.", cause);
+    }
 }
diff --git a/service-t/src/com/android/server/mdns/MdnsResponseDecoder.java b/service-t/src/com/android/server/mdns/MdnsResponseDecoder.java
index 50f2069..82da2e4 100644
--- a/service-t/src/com/android/server/mdns/MdnsResponseDecoder.java
+++ b/service-t/src/com/android/server/mdns/MdnsResponseDecoder.java
@@ -24,11 +24,9 @@
 import com.android.server.connectivity.mdns.util.MdnsLogger;
 
 import java.io.EOFException;
-import java.io.IOException;
 import java.net.DatagramPacket;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.LinkedList;
 import java.util.List;
 
 /** A class that decodes mDNS responses from UDP packets. */
@@ -48,12 +46,6 @@
         this.serviceType = serviceType;
     }
 
-    private static void skipMdnsRecord(MdnsPacketReader reader) throws IOException {
-        reader.skip(2 + 4); // skip the class and TTL
-        int dataLength = reader.readUInt16();
-        reader.skip(dataLength);
-    }
-
     private static MdnsResponse findResponseWithPointer(
             List<MdnsResponse> responses, String[] pointer) {
         if (responses != null) {
@@ -120,7 +112,7 @@
             int interfaceIndex, @Nullable Network network) {
         MdnsPacketReader reader = new MdnsPacketReader(recvbuf, length);
 
-        List<MdnsRecord> records;
+        final MdnsPacket mdnsPacket;
         try {
             reader.readUInt16(); // transaction ID (not used)
             int flags = reader.readUInt16();
@@ -128,111 +120,25 @@
                 return MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE;
             }
 
-            int numQuestions = reader.readUInt16();
-            int numAnswers = reader.readUInt16();
-            int numAuthority = reader.readUInt16();
-            int numRecords = reader.readUInt16();
-
-            LOGGER.log(String.format(
-                    "num questions: %d, num answers: %d, num authority: %d, num records: %d",
-                    numQuestions, numAnswers, numAuthority, numRecords));
-
-            if (numAnswers < 1) {
+            mdnsPacket = MdnsPacket.parseRecordsSection(reader, flags);
+            if (mdnsPacket.answers.size() < 1) {
                 return MdnsResponseErrorCode.ERROR_NO_ANSWERS;
             }
-
-            records = new LinkedList<>();
-
-            for (int i = 0; i < (numAnswers + numAuthority + numRecords); ++i) {
-                String[] name;
-                try {
-                    name = reader.readLabels();
-                } catch (IOException e) {
-                    LOGGER.e("Failed to read labels from mDNS response.", e);
-                    return MdnsResponseErrorCode.ERROR_READING_RECORD_NAME;
-                }
-                int type = reader.readUInt16();
-
-                switch (type) {
-                    case MdnsRecord.TYPE_A: {
-                        try {
-                            records.add(new MdnsInetAddressRecord(name, MdnsRecord.TYPE_A, reader));
-                        } catch (IOException e) {
-                            LOGGER.e("Failed to read A record from mDNS response.", e);
-                            return MdnsResponseErrorCode.ERROR_READING_A_RDATA;
-                        }
-                        break;
-                    }
-
-                    case MdnsRecord.TYPE_AAAA: {
-                        try {
-                            // AAAA should only contain the IPv6 address.
-                            MdnsInetAddressRecord record =
-                                    new MdnsInetAddressRecord(name, MdnsRecord.TYPE_AAAA, reader);
-                            if (record.getInet6Address() != null) {
-                                records.add(record);
-                            }
-                        } catch (IOException e) {
-                            LOGGER.e("Failed to read AAAA record from mDNS response.", e);
-                            return MdnsResponseErrorCode.ERROR_READING_AAAA_RDATA;
-                        }
-                        break;
-                    }
-
-                    case MdnsRecord.TYPE_PTR: {
-                        try {
-                            records.add(new MdnsPointerRecord(name, reader));
-                        } catch (IOException e) {
-                            LOGGER.e("Failed to read PTR record from mDNS response.", e);
-                            return MdnsResponseErrorCode.ERROR_READING_PTR_RDATA;
-                        }
-                        break;
-                    }
-
-                    case MdnsRecord.TYPE_SRV: {
-                        if (name.length == 4) {
-                            try {
-                                records.add(new MdnsServiceRecord(name, reader));
-                            } catch (IOException e) {
-                                LOGGER.e("Failed to read SRV record from mDNS response.", e);
-                                return MdnsResponseErrorCode.ERROR_READING_SRV_RDATA;
-                            }
-                        } else {
-                            try {
-                                skipMdnsRecord(reader);
-                            } catch (IOException e) {
-                                LOGGER.e("Failed to skip SVR record from mDNS response.", e);
-                                return MdnsResponseErrorCode.ERROR_SKIPPING_SRV_RDATA;
-                            }
-                        }
-                        break;
-                    }
-
-                    case MdnsRecord.TYPE_TXT: {
-                        try {
-                            records.add(new MdnsTextRecord(name, reader));
-                        } catch (IOException e) {
-                            LOGGER.e("Failed to read TXT record from mDNS response.", e);
-                            return MdnsResponseErrorCode.ERROR_READING_TXT_RDATA;
-                        }
-                        break;
-                    }
-
-                    default: {
-                        try {
-                            skipMdnsRecord(reader);
-                        } catch (IOException e) {
-                            LOGGER.e("Failed to skip mDNS record.", e);
-                            return MdnsResponseErrorCode.ERROR_SKIPPING_UNKNOWN_RECORD;
-                        }
-                    }
-                }
-            }
         } catch (EOFException e) {
             LOGGER.e("Reached the end of the mDNS response unexpectedly.", e);
             return MdnsResponseErrorCode.ERROR_END_OF_FILE;
+        } catch (MdnsPacket.ParseException e) {
+            LOGGER.e(e.getMessage(), e);
+            return e.code;
         }
 
+        final ArrayList<MdnsRecord> records = new ArrayList<>(
+                mdnsPacket.questions.size() + mdnsPacket.answers.size()
+                        + mdnsPacket.authorityRecords.size() + mdnsPacket.additionalRecords.size());
+        records.addAll(mdnsPacket.answers);
+        records.addAll(mdnsPacket.authorityRecords);
+        records.addAll(mdnsPacket.additionalRecords);
+
         // The response records are structured in a hierarchy, where some records reference
         // others, as follows:
         //
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
index 650607d..334f99d 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
@@ -91,7 +91,7 @@
         scapy.raw(scapy.dns_compress(scapy.DNS(rd=0, qr=1, aa=1,
         qd = None,
         an =
-        scapy.DNSRR(type='PTR', rrname='123.0.2.192.in-addr.arpa.', rdata='Android.local',
+        scapy.DNSRR(type='PTR', rrname='123.2.0.192.in-addr.arpa.', rdata='Android.local',
             rclass=0x8001, ttl=120) /
         scapy.DNSRR(type='PTR',
             rrname='3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.0.1.0.0.2.ip6.arpa',
@@ -111,8 +111,8 @@
         scapy.DNSRR(type='AAAA', rrname='Android.local', rclass=0x8001, rdata='2001:db8::456',
             ttl=120),
         ar =
-        scapy.DNSRRNSEC(rrname='123.0.2.192.in-addr.arpa.', rclass=0x8001, ttl=120,
-            nextname='123.0.2.192.in-addr.arpa.', typebitmaps=[12]) /
+        scapy.DNSRRNSEC(rrname='123.2.0.192.in-addr.arpa.', rclass=0x8001, ttl=120,
+            nextname='123.2.0.192.in-addr.arpa.', typebitmaps=[12]) /
         scapy.DNSRRNSEC(
             rrname='3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.0.1.0.0.2.ip6.arpa',
             rclass=0x8001, ttl=120,
@@ -131,7 +131,7 @@
             typebitmaps=[1, 28]))
         )).hex().upper()
         */
-        val expected = "00008400000000090000000503313233013001320331393207696E2D61646472046172706" +
+        val expected = "00008400000000090000000503313233013201300331393207696E2D61646472046172706" +
                 "100000C800100000078000F07416E64726F6964056C6F63616C00013301320131013001300130013" +
                 "00130013001300130013001300130013001300130013001300130013001300130013001380142014" +
                 "40130013101300130013203697036C020000C8001000000780002C030013601350134C045000C800" +
@@ -149,7 +149,7 @@
         val v4Addr = parseNumericAddress("192.0.2.123")
         val v6Addr1 = parseNumericAddress("2001:DB8::123")
         val v6Addr2 = parseNumericAddress("2001:DB8::456")
-        val v4AddrRev = arrayOf("123", "0", "2", "192", "in-addr", "arpa")
+        val v4AddrRev = getReverseDnsAddress(v4Addr)
         val v6Addr1Rev = getReverseDnsAddress(v6Addr1)
         val v6Addr2Rev = getReverseDnsAddress(v6Addr2)
 
@@ -254,7 +254,10 @@
             verify(socket, atLeast(i + 1)).send(any())
             val now = SystemClock.elapsedRealtime()
             assertTrue(now > timeStart + startDelay + i * FIRST_ANNOUNCES_DELAY)
-            assertTrue(now < timeStart + startDelay + (i + 1) * FIRST_ANNOUNCES_DELAY)
+            // Loops can be much slower than the expected timing (>100ms delay), use
+            // TEST_TIMEOUT_MS as tolerance.
+            assertTrue(now < timeStart + startDelay + (i + 1) * FIRST_ANNOUNCES_DELAY +
+                TEST_TIMEOUT_MS)
         }
 
         // Subsequent announces should happen quickly (NEXT_ANNOUNCES_DELAY)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketTest.kt
new file mode 100644
index 0000000..f88da1f
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketTest.kt
@@ -0,0 +1,71 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns
+
+import android.net.InetAddresses
+import com.android.net.module.util.HexDump
+import com.android.testutils.DevSdkIgnoreRunner
+import kotlin.test.assertContentEquals
+import kotlin.test.assertEquals
+import kotlin.test.assertTrue
+import org.junit.Test
+import org.junit.runner.RunWith
+
+@RunWith(DevSdkIgnoreRunner::class)
+class MdnsPacketTest {
+    @Test
+    fun testParseQuery() {
+        // Probe packet with 1 question for Android.local, and 4 authority records with 4 addresses
+        // for Android.local (similar to legacy mdnsresponder probes, although it used to put 4
+        // identical questions(!!) for Android.local when there were 4 addresses).
+        val packetHex = "00000000000100000004000007416e64726f6964056c6f63616c0000ff0001c00c000100" +
+                "01000000780004c000027bc00c001c000100000078001020010db8000000000000000000000123c0" +
+                "0c001c000100000078001020010db8000000000000000000000456c00c001c000100000078001020" +
+                "010db8000000000000000000000789"
+
+        val bytes = HexDump.hexStringToByteArray(packetHex)
+        val reader = MdnsPacketReader(bytes, bytes.size)
+        val packet = MdnsPacket.parse(reader)
+
+        assertEquals(0, packet.flags) // Query (QR bit unset), no other flag set
+        assertEquals(1, packet.questions.size)
+        assertEquals(0, packet.answers.size)
+        assertEquals(4, packet.authorityRecords.size)
+        assertEquals(0, packet.additionalRecords.size)
+
+        val hostname = arrayOf("Android", "local")
+        packet.questions[0].let {
+            assertTrue(it is MdnsAnyRecord)
+            assertContentEquals(hostname, it.name)
+        }
+
+        packet.authorityRecords.forEach {
+            assertTrue(it is MdnsInetAddressRecord)
+            assertContentEquals(hostname, it.name)
+            assertEquals(120000, it.ttl)
+        }
+
+        assertEquals(InetAddresses.parseNumericAddress("192.0.2.123"),
+                (packet.authorityRecords[0] as MdnsInetAddressRecord).inet4Address)
+        assertEquals(InetAddresses.parseNumericAddress("2001:db8::123"),
+                (packet.authorityRecords[1] as MdnsInetAddressRecord).inet6Address)
+        assertEquals(InetAddresses.parseNumericAddress("2001:db8::456"),
+                (packet.authorityRecords[2] as MdnsInetAddressRecord).inet6Address)
+        assertEquals(InetAddresses.parseNumericAddress("2001:db8::789"),
+                (packet.authorityRecords[3] as MdnsInetAddressRecord).inet6Address)
+    }
+}