Merge "Factor out response decoding into MdnsPacket"
diff --git a/service-t/src/com/android/server/mdns/MdnsPacket.java b/service-t/src/com/android/server/mdns/MdnsPacket.java
index eae084a..27002b9 100644
--- a/service-t/src/com/android/server/mdns/MdnsPacket.java
+++ b/service-t/src/com/android/server/mdns/MdnsPacket.java
@@ -16,6 +16,13 @@
package com.android.server.connectivity.mdns;
+import android.annotation.NonNull;
+import android.annotation.Nullable;
+import android.util.Log;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -23,21 +30,202 @@
* A class holding data that can be included in a mDNS packet.
*/
public class MdnsPacket {
+ private static final String TAG = MdnsPacket.class.getSimpleName();
+
public final int flags;
+ @NonNull
public final List<MdnsRecord> questions;
+ @NonNull
public final List<MdnsRecord> answers;
+ @NonNull
public final List<MdnsRecord> authorityRecords;
+ @NonNull
public final List<MdnsRecord> additionalRecords;
MdnsPacket(int flags,
- List<MdnsRecord> questions,
- List<MdnsRecord> answers,
- List<MdnsRecord> authorityRecords,
- List<MdnsRecord> additionalRecords) {
+ @NonNull List<MdnsRecord> questions,
+ @NonNull List<MdnsRecord> answers,
+ @NonNull List<MdnsRecord> authorityRecords,
+ @NonNull List<MdnsRecord> additionalRecords) {
this.flags = flags;
this.questions = Collections.unmodifiableList(questions);
this.answers = Collections.unmodifiableList(answers);
this.authorityRecords = Collections.unmodifiableList(authorityRecords);
this.additionalRecords = Collections.unmodifiableList(additionalRecords);
}
+
+ /**
+ * Exception thrown on parse errors.
+ */
+ public static class ParseException extends IOException {
+ public final int code;
+
+ public ParseException(int code, @NonNull String message, @Nullable Throwable cause) {
+ super(message, cause);
+ this.code = code;
+ }
+ }
+
+ /**
+ * Parse the packet in the provided {@link MdnsPacketReader}.
+ */
+ @NonNull
+ public static MdnsPacket parse(@NonNull MdnsPacketReader reader) throws ParseException {
+ final int flags;
+ try {
+ reader.readUInt16(); // transaction ID (not used)
+ flags = reader.readUInt16();
+ } catch (EOFException e) {
+ throw new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE,
+ "Reached the end of the mDNS response unexpectedly.", e);
+ }
+ return parseRecordsSection(reader, flags);
+ }
+
+ /**
+ * Parse the records section of a mDNS packet in the provided {@link MdnsPacketReader}.
+ *
+ * The records section starts with the questions count, just after the packet flags.
+ */
+ public static MdnsPacket parseRecordsSection(@NonNull MdnsPacketReader reader, int flags)
+ throws ParseException {
+ try {
+ final int numQuestions = reader.readUInt16();
+ final int numAnswers = reader.readUInt16();
+ final int numAuthority = reader.readUInt16();
+ final int numAdditional = reader.readUInt16();
+
+ final ArrayList<MdnsRecord> questions = parseRecords(reader, numQuestions, true);
+ final ArrayList<MdnsRecord> answers = parseRecords(reader, numAnswers, false);
+ final ArrayList<MdnsRecord> authority = parseRecords(reader, numAuthority, false);
+ final ArrayList<MdnsRecord> additional = parseRecords(reader, numAdditional, false);
+
+ return new MdnsPacket(flags, questions, answers, authority, additional);
+ } catch (EOFException e) {
+ throw new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE,
+ "Reached the end of the mDNS response unexpectedly.", e);
+ }
+ }
+
+ private static ArrayList<MdnsRecord> parseRecords(@NonNull MdnsPacketReader reader, int count,
+ boolean isQuestion)
+ throws ParseException {
+ final ArrayList<MdnsRecord> records = new ArrayList<>(count);
+ for (int i = 0; i < count; ++i) {
+ final MdnsRecord record = parseRecord(reader, isQuestion);
+ if (record != null) {
+ records.add(record);
+ }
+ }
+ return records;
+ }
+
+ @Nullable
+ private static MdnsRecord parseRecord(@NonNull MdnsPacketReader reader, boolean isQuestion)
+ throws ParseException {
+ String[] name;
+ try {
+ name = reader.readLabels();
+ } catch (IOException e) {
+ throw new ParseException(MdnsResponseErrorCode.ERROR_READING_RECORD_NAME,
+ "Failed to read labels from mDNS response.", e);
+ }
+
+ final int type;
+ try {
+ type = reader.readUInt16();
+ } catch (EOFException e) {
+ throw new ParseException(MdnsResponseErrorCode.ERROR_END_OF_FILE,
+ "Reached the end of the mDNS response unexpectedly.", e);
+ }
+
+ switch (type) {
+ case MdnsRecord.TYPE_A: {
+ try {
+ return new MdnsInetAddressRecord(name, MdnsRecord.TYPE_A, reader, isQuestion);
+ } catch (IOException e) {
+ throw new ParseException(MdnsResponseErrorCode.ERROR_READING_A_RDATA,
+ "Failed to read A record from mDNS response.", e);
+ }
+ }
+
+ case MdnsRecord.TYPE_AAAA: {
+ try {
+ return new MdnsInetAddressRecord(name,
+ MdnsRecord.TYPE_AAAA, reader, isQuestion);
+ } catch (IOException e) {
+ throw new ParseException(MdnsResponseErrorCode.ERROR_READING_AAAA_RDATA,
+ "Failed to read AAAA record from mDNS response.", e);
+ }
+ }
+
+ case MdnsRecord.TYPE_PTR: {
+ try {
+ return new MdnsPointerRecord(name, reader, isQuestion);
+ } catch (IOException e) {
+ throw new ParseException(MdnsResponseErrorCode.ERROR_READING_PTR_RDATA,
+ "Failed to read PTR record from mDNS response.", e);
+ }
+ }
+
+ case MdnsRecord.TYPE_SRV: {
+ try {
+ return new MdnsServiceRecord(name, reader, isQuestion);
+ } catch (IOException e) {
+ throw new ParseException(MdnsResponseErrorCode.ERROR_READING_SRV_RDATA,
+ "Failed to read SRV record from mDNS response.", e);
+ }
+ }
+
+ case MdnsRecord.TYPE_TXT: {
+ try {
+ return new MdnsTextRecord(name, reader, isQuestion);
+ } catch (IOException e) {
+ throw new ParseException(MdnsResponseErrorCode.ERROR_READING_TXT_RDATA,
+ "Failed to read TXT record from mDNS response.", e);
+ }
+ }
+
+ case MdnsRecord.TYPE_NSEC: {
+ try {
+ return new MdnsNsecRecord(name, reader, isQuestion);
+ } catch (IOException e) {
+ throw new ParseException(MdnsResponseErrorCode.ERROR_READING_NSEC_RDATA,
+ "Failed to read NSEC record from mDNS response.", e);
+ }
+ }
+
+ case MdnsRecord.TYPE_ANY: {
+ try {
+ return new MdnsAnyRecord(name, reader);
+ } catch (IOException e) {
+ throw new ParseException(MdnsResponseErrorCode.ERROR_READING_ANY_RDATA,
+ "Failed to read TYPE_ANY record from mDNS response.", e);
+ }
+ }
+
+ default: {
+ try {
+ if (MdnsAdvertiser.DBG) {
+ Log.i(TAG, "Skipping parsing of record of unhandled type " + type);
+ }
+ skipMdnsRecord(reader, isQuestion);
+ return null;
+ } catch (IOException e) {
+ throw new ParseException(MdnsResponseErrorCode.ERROR_SKIPPING_UNKNOWN_RECORD,
+ "Failed to skip mDNS record.", e);
+ }
+ }
+ }
+ }
+
+ private static void skipMdnsRecord(@NonNull MdnsPacketReader reader, boolean isQuestion)
+ throws IOException {
+ reader.skip(2); // Skip the class
+ if (isQuestion) return;
+ // Skip TTL and data
+ reader.skip(4);
+ int dataLength = reader.readUInt16();
+ reader.skip(dataLength);
+ }
}
diff --git a/service-t/src/com/android/server/mdns/MdnsResponseDecoder.java b/service-t/src/com/android/server/mdns/MdnsResponseDecoder.java
index 50f2069..82da2e4 100644
--- a/service-t/src/com/android/server/mdns/MdnsResponseDecoder.java
+++ b/service-t/src/com/android/server/mdns/MdnsResponseDecoder.java
@@ -24,11 +24,9 @@
import com.android.server.connectivity.mdns.util.MdnsLogger;
import java.io.EOFException;
-import java.io.IOException;
import java.net.DatagramPacket;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.LinkedList;
import java.util.List;
/** A class that decodes mDNS responses from UDP packets. */
@@ -48,12 +46,6 @@
this.serviceType = serviceType;
}
- private static void skipMdnsRecord(MdnsPacketReader reader) throws IOException {
- reader.skip(2 + 4); // skip the class and TTL
- int dataLength = reader.readUInt16();
- reader.skip(dataLength);
- }
-
private static MdnsResponse findResponseWithPointer(
List<MdnsResponse> responses, String[] pointer) {
if (responses != null) {
@@ -120,7 +112,7 @@
int interfaceIndex, @Nullable Network network) {
MdnsPacketReader reader = new MdnsPacketReader(recvbuf, length);
- List<MdnsRecord> records;
+ final MdnsPacket mdnsPacket;
try {
reader.readUInt16(); // transaction ID (not used)
int flags = reader.readUInt16();
@@ -128,111 +120,25 @@
return MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE;
}
- int numQuestions = reader.readUInt16();
- int numAnswers = reader.readUInt16();
- int numAuthority = reader.readUInt16();
- int numRecords = reader.readUInt16();
-
- LOGGER.log(String.format(
- "num questions: %d, num answers: %d, num authority: %d, num records: %d",
- numQuestions, numAnswers, numAuthority, numRecords));
-
- if (numAnswers < 1) {
+ mdnsPacket = MdnsPacket.parseRecordsSection(reader, flags);
+ if (mdnsPacket.answers.size() < 1) {
return MdnsResponseErrorCode.ERROR_NO_ANSWERS;
}
-
- records = new LinkedList<>();
-
- for (int i = 0; i < (numAnswers + numAuthority + numRecords); ++i) {
- String[] name;
- try {
- name = reader.readLabels();
- } catch (IOException e) {
- LOGGER.e("Failed to read labels from mDNS response.", e);
- return MdnsResponseErrorCode.ERROR_READING_RECORD_NAME;
- }
- int type = reader.readUInt16();
-
- switch (type) {
- case MdnsRecord.TYPE_A: {
- try {
- records.add(new MdnsInetAddressRecord(name, MdnsRecord.TYPE_A, reader));
- } catch (IOException e) {
- LOGGER.e("Failed to read A record from mDNS response.", e);
- return MdnsResponseErrorCode.ERROR_READING_A_RDATA;
- }
- break;
- }
-
- case MdnsRecord.TYPE_AAAA: {
- try {
- // AAAA should only contain the IPv6 address.
- MdnsInetAddressRecord record =
- new MdnsInetAddressRecord(name, MdnsRecord.TYPE_AAAA, reader);
- if (record.getInet6Address() != null) {
- records.add(record);
- }
- } catch (IOException e) {
- LOGGER.e("Failed to read AAAA record from mDNS response.", e);
- return MdnsResponseErrorCode.ERROR_READING_AAAA_RDATA;
- }
- break;
- }
-
- case MdnsRecord.TYPE_PTR: {
- try {
- records.add(new MdnsPointerRecord(name, reader));
- } catch (IOException e) {
- LOGGER.e("Failed to read PTR record from mDNS response.", e);
- return MdnsResponseErrorCode.ERROR_READING_PTR_RDATA;
- }
- break;
- }
-
- case MdnsRecord.TYPE_SRV: {
- if (name.length == 4) {
- try {
- records.add(new MdnsServiceRecord(name, reader));
- } catch (IOException e) {
- LOGGER.e("Failed to read SRV record from mDNS response.", e);
- return MdnsResponseErrorCode.ERROR_READING_SRV_RDATA;
- }
- } else {
- try {
- skipMdnsRecord(reader);
- } catch (IOException e) {
- LOGGER.e("Failed to skip SVR record from mDNS response.", e);
- return MdnsResponseErrorCode.ERROR_SKIPPING_SRV_RDATA;
- }
- }
- break;
- }
-
- case MdnsRecord.TYPE_TXT: {
- try {
- records.add(new MdnsTextRecord(name, reader));
- } catch (IOException e) {
- LOGGER.e("Failed to read TXT record from mDNS response.", e);
- return MdnsResponseErrorCode.ERROR_READING_TXT_RDATA;
- }
- break;
- }
-
- default: {
- try {
- skipMdnsRecord(reader);
- } catch (IOException e) {
- LOGGER.e("Failed to skip mDNS record.", e);
- return MdnsResponseErrorCode.ERROR_SKIPPING_UNKNOWN_RECORD;
- }
- }
- }
- }
} catch (EOFException e) {
LOGGER.e("Reached the end of the mDNS response unexpectedly.", e);
return MdnsResponseErrorCode.ERROR_END_OF_FILE;
+ } catch (MdnsPacket.ParseException e) {
+ LOGGER.e(e.getMessage(), e);
+ return e.code;
}
+ final ArrayList<MdnsRecord> records = new ArrayList<>(
+ mdnsPacket.questions.size() + mdnsPacket.answers.size()
+ + mdnsPacket.authorityRecords.size() + mdnsPacket.additionalRecords.size());
+ records.addAll(mdnsPacket.answers);
+ records.addAll(mdnsPacket.authorityRecords);
+ records.addAll(mdnsPacket.additionalRecords);
+
// The response records are structured in a hierarchy, where some records reference
// others, as follows:
//
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
index 650607d..334f99d 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
@@ -91,7 +91,7 @@
scapy.raw(scapy.dns_compress(scapy.DNS(rd=0, qr=1, aa=1,
qd = None,
an =
- scapy.DNSRR(type='PTR', rrname='123.0.2.192.in-addr.arpa.', rdata='Android.local',
+ scapy.DNSRR(type='PTR', rrname='123.2.0.192.in-addr.arpa.', rdata='Android.local',
rclass=0x8001, ttl=120) /
scapy.DNSRR(type='PTR',
rrname='3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.0.1.0.0.2.ip6.arpa',
@@ -111,8 +111,8 @@
scapy.DNSRR(type='AAAA', rrname='Android.local', rclass=0x8001, rdata='2001:db8::456',
ttl=120),
ar =
- scapy.DNSRRNSEC(rrname='123.0.2.192.in-addr.arpa.', rclass=0x8001, ttl=120,
- nextname='123.0.2.192.in-addr.arpa.', typebitmaps=[12]) /
+ scapy.DNSRRNSEC(rrname='123.2.0.192.in-addr.arpa.', rclass=0x8001, ttl=120,
+ nextname='123.2.0.192.in-addr.arpa.', typebitmaps=[12]) /
scapy.DNSRRNSEC(
rrname='3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.B.D.0.1.0.0.2.ip6.arpa',
rclass=0x8001, ttl=120,
@@ -131,7 +131,7 @@
typebitmaps=[1, 28]))
)).hex().upper()
*/
- val expected = "00008400000000090000000503313233013001320331393207696E2D61646472046172706" +
+ val expected = "00008400000000090000000503313233013201300331393207696E2D61646472046172706" +
"100000C800100000078000F07416E64726F6964056C6F63616C00013301320131013001300130013" +
"00130013001300130013001300130013001300130013001300130013001300130013001380142014" +
"40130013101300130013203697036C020000C8001000000780002C030013601350134C045000C800" +
@@ -149,7 +149,7 @@
val v4Addr = parseNumericAddress("192.0.2.123")
val v6Addr1 = parseNumericAddress("2001:DB8::123")
val v6Addr2 = parseNumericAddress("2001:DB8::456")
- val v4AddrRev = arrayOf("123", "0", "2", "192", "in-addr", "arpa")
+ val v4AddrRev = getReverseDnsAddress(v4Addr)
val v6Addr1Rev = getReverseDnsAddress(v6Addr1)
val v6Addr2Rev = getReverseDnsAddress(v6Addr2)
@@ -254,7 +254,10 @@
verify(socket, atLeast(i + 1)).send(any())
val now = SystemClock.elapsedRealtime()
assertTrue(now > timeStart + startDelay + i * FIRST_ANNOUNCES_DELAY)
- assertTrue(now < timeStart + startDelay + (i + 1) * FIRST_ANNOUNCES_DELAY)
+ // Loops can be much slower than the expected timing (>100ms delay), use
+ // TEST_TIMEOUT_MS as tolerance.
+ assertTrue(now < timeStart + startDelay + (i + 1) * FIRST_ANNOUNCES_DELAY +
+ TEST_TIMEOUT_MS)
}
// Subsequent announces should happen quickly (NEXT_ANNOUNCES_DELAY)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketTest.kt
new file mode 100644
index 0000000..f88da1f
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketTest.kt
@@ -0,0 +1,70 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns
+
+import android.net.InetAddresses
+import com.android.net.module.util.HexDump
+import com.android.testutils.DevSdkIgnoreRunner
+import kotlin.test.assertContentEquals
+import kotlin.test.assertEquals
+import kotlin.test.assertTrue
+import org.junit.Test
+import org.junit.runner.RunWith
+
+@RunWith(DevSdkIgnoreRunner::class)
+class MdnsPacketTest {
+ @Test
+ fun testParseQuery() {
+ // Probe packet with 1 question for Android.local, and 4 authority records with 4 addresses
+ // for Android.local (similar to legacy mdnsresponder probes, although it used to put 4
+ // identical questions(!!) for Android.local when there were 4 addresses).
+ val packetHex = "00000000000100000004000007416e64726f6964056c6f63616c0000ff0001c00c000100" +
+ "01000000780004c000027bc00c001c000100000078001020010db8000000000000000000000123c0" +
+ "0c001c000100000078001020010db8000000000000000000000456c00c001c000100000078001020" +
+ "010db8000000000000000000000789"
+
+ val bytes = HexDump.hexStringToByteArray(packetHex)
+ val reader = MdnsPacketReader(bytes, bytes.size)
+ val packet = MdnsPacket.parse(reader)
+
+ assertEquals(1, packet.questions.size)
+ assertEquals(0, packet.answers.size)
+ assertEquals(4, packet.authorityRecords.size)
+ assertEquals(0, packet.additionalRecords.size)
+
+ val hostname = arrayOf("Android", "local")
+ packet.questions[0].let {
+ assertTrue(it is MdnsAnyRecord)
+ assertContentEquals(hostname, it.name)
+ }
+
+ packet.authorityRecords.forEach {
+ assertTrue(it is MdnsInetAddressRecord)
+ assertContentEquals(hostname, it.name)
+ assertEquals(120000, it.ttl)
+ }
+
+ assertEquals(InetAddresses.parseNumericAddress("192.0.2.123"),
+ (packet.authorityRecords[0] as MdnsInetAddressRecord).inet4Address)
+ assertEquals(InetAddresses.parseNumericAddress("2001:db8::123"),
+ (packet.authorityRecords[1] as MdnsInetAddressRecord).inet6Address)
+ assertEquals(InetAddresses.parseNumericAddress("2001:db8::456"),
+ (packet.authorityRecords[2] as MdnsInetAddressRecord).inet6Address)
+ assertEquals(InetAddresses.parseNumericAddress("2001:db8::789"),
+ (packet.authorityRecords[3] as MdnsInetAddressRecord).inet6Address)
+ }
+}