Merge "Limit the label count when parsing a mDns packet" into udc-mainline-prod
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index c74f229..43357e4 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -1647,10 +1647,12 @@
mContext, MdnsFeatureFlags.INCLUDE_INET_ADDRESS_RECORDS_IN_PROBING))
.setIsExpiredServicesRemovalEnabled(mDeps.isTrunkStableFeatureEnabled(
MdnsFeatureFlags.NSD_EXPIRED_SERVICES_REMOVAL))
+ .setIsLabelCountLimitEnabled(mDeps.isTetheringFeatureNotChickenedOut(
+ mContext, MdnsFeatureFlags.NSD_LIMIT_LABEL_COUNT))
.build();
mMdnsSocketClient =
new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider,
- LOGGER.forSubComponent("MdnsMultinetworkSocketClient"));
+ LOGGER.forSubComponent("MdnsMultinetworkSocketClient"), flags);
mMdnsDiscoveryManager = deps.makeMdnsDiscoveryManager(new ExecutorProvider(),
mMdnsSocketClient, LOGGER.forSubComponent("MdnsDiscoveryManager"), flags);
handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager));
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
index 6f7645e..738c151 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
@@ -36,6 +36,11 @@
public static final String NSD_EXPIRED_SERVICES_REMOVAL =
"nsd_expired_services_removal";
+ /**
+ * A feature flag to control whether the label count is limited when parsing mDNS packets.
+ */
+ public static final String NSD_LIMIT_LABEL_COUNT = "nsd_limit_label_count";
+
// Flag for offload feature
public final boolean mIsMdnsOffloadFeatureEnabled;
@@ -45,14 +50,20 @@
// Flag for expired services removal
public final boolean mIsExpiredServicesRemovalEnabled;
+ // Flag for label count limit
+ public final boolean mIsLabelCountLimitEnabled;
+
/**
* The constructor for {@link MdnsFeatureFlags}.
*/
public MdnsFeatureFlags(boolean isOffloadFeatureEnabled,
- boolean includeInetAddressRecordsInProbing, boolean isExpiredServicesRemovalEnabled) {
+ boolean includeInetAddressRecordsInProbing,
+ boolean isExpiredServicesRemovalEnabled,
+ boolean isLabelCountLimitEnabled) {
mIsMdnsOffloadFeatureEnabled = isOffloadFeatureEnabled;
mIncludeInetAddressRecordsInProbing = includeInetAddressRecordsInProbing;
mIsExpiredServicesRemovalEnabled = isExpiredServicesRemovalEnabled;
+ mIsLabelCountLimitEnabled = isLabelCountLimitEnabled;
}
@@ -67,6 +78,7 @@
private boolean mIsMdnsOffloadFeatureEnabled;
private boolean mIncludeInetAddressRecordsInProbing;
private boolean mIsExpiredServicesRemovalEnabled;
+ private boolean mIsLabelCountLimitEnabled;
/**
* The constructor for {@link Builder}.
@@ -75,6 +87,7 @@
mIsMdnsOffloadFeatureEnabled = false;
mIncludeInetAddressRecordsInProbing = false;
mIsExpiredServicesRemovalEnabled = true; // Default enabled.
+ mIsLabelCountLimitEnabled = true; // Default enabled.
}
/**
@@ -109,11 +122,23 @@
}
/**
+ * Set whether the label count limit is enabled.
+ *
+ * @see #NSD_LIMIT_LABEL_COUNT
+ */
+ public Builder setIsLabelCountLimitEnabled(boolean isLabelCountLimitEnabled) {
+ mIsLabelCountLimitEnabled = isLabelCountLimitEnabled;
+ return this;
+ }
+
+ /**
* Builds a {@link MdnsFeatureFlags} with the arguments supplied to this builder.
*/
public MdnsFeatureFlags build() {
return new MdnsFeatureFlags(mIsMdnsOffloadFeatureEnabled,
- mIncludeInetAddressRecordsInProbing, mIsExpiredServicesRemovalEnabled);
+ mIncludeInetAddressRecordsInProbing,
+ mIsExpiredServicesRemovalEnabled,
+ mIsLabelCountLimitEnabled);
}
}
}
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
index 42a6b0d..62c37ad 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
@@ -65,11 +65,12 @@
private final MdnsProber mProber;
@NonNull
private final MdnsReplySender mReplySender;
-
@NonNull
private final SharedLog mSharedLog;
@NonNull
private final byte[] mPacketCreationBuffer;
+ @NonNull
+ private final MdnsFeatureFlags mMdnsFeatureFlags;
/**
* Callbacks called by {@link MdnsInterfaceAdvertiser} to report status updates.
@@ -213,6 +214,7 @@
mProber = deps.makeMdnsProber(sharedLog.getTag(), looper, mReplySender, mProbingCallback,
sharedLog);
mSharedLog = sharedLog;
+ mMdnsFeatureFlags = mdnsFeatureFlags;
}
/**
@@ -351,7 +353,7 @@
public void handlePacket(byte[] recvbuf, int length, InetSocketAddress src) {
final MdnsPacket packet;
try {
- packet = MdnsPacket.parse(new MdnsPacketReader(recvbuf, length));
+ packet = MdnsPacket.parse(new MdnsPacketReader(recvbuf, length, mMdnsFeatureFlags));
} catch (MdnsPacket.ParseException e) {
mSharedLog.e("Error parsing mDNS packet", e);
if (DBG) {
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
index 4ba6912..e7b0eaa 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -50,6 +50,7 @@
@NonNull private final Handler mHandler;
@NonNull private final MdnsSocketProvider mSocketProvider;
@NonNull private final SharedLog mSharedLog;
+ @NonNull private final MdnsFeatureFlags mMdnsFeatureFlags;
private final ArrayMap<MdnsServiceBrowserListener, InterfaceSocketCallback> mSocketRequests =
new ArrayMap<>();
@@ -58,11 +59,12 @@
private int mReceivedPacketNumber = 0;
public MdnsMultinetworkSocketClient(@NonNull Looper looper,
- @NonNull MdnsSocketProvider provider,
- @NonNull SharedLog sharedLog) {
+ @NonNull MdnsSocketProvider provider, @NonNull SharedLog sharedLog,
+ @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
mHandler = new Handler(looper);
mSocketProvider = provider;
mSharedLog = sharedLog;
+ mMdnsFeatureFlags = mdnsFeatureFlags;
}
private class InterfaceSocketCallback implements MdnsSocketProvider.SocketCallback {
@@ -239,7 +241,7 @@
final MdnsPacket response;
try {
- response = MdnsResponseDecoder.parseResponse(recvbuf, length);
+ response = MdnsResponseDecoder.parseResponse(recvbuf, length, mMdnsFeatureFlags);
} catch (MdnsPacket.ParseException e) {
if (e.code != MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE) {
mSharedLog.e(e.getMessage(), e);
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsPacketReader.java b/service-t/src/com/android/server/connectivity/mdns/MdnsPacketReader.java
index aa38844..32060a2 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsPacketReader.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsPacketReader.java
@@ -16,6 +16,7 @@
package com.android.server.connectivity.mdns;
+import android.annotation.NonNull;
import android.annotation.Nullable;
import android.util.SparseArray;
@@ -30,24 +31,31 @@
/** Simple decoder for mDNS packets. */
public class MdnsPacketReader {
+ // A DNS name is at most 255 bytes on the wire, label length bytes included (RFC 9267).
+ // Each non-empty label takes at least 2 bytes (1 length byte + 1 character), so a valid
+ // name cannot hold more than 128 labels; use that as the cap when following pointers.
+ // https://www.rfc-editor.org/rfc/rfc9267.html#name-label-and-name-length-valid
+ private static final int LABEL_COUNT_LIMIT = 128;
private final byte[] buf;
private final int count;
private final SparseArray<LabelEntry> labelDictionary;
+ private final MdnsFeatureFlags mMdnsFeatureFlags;
private int pos;
private int limit;
/** Constructs a reader for the given packet. */
public MdnsPacketReader(DatagramPacket packet) {
- this(packet.getData(), packet.getLength());
+ this(packet.getData(), packet.getLength(), MdnsFeatureFlags.newBuilder().build());
}
/** Constructs a reader for the given packet. */
- public MdnsPacketReader(byte[] buffer, int length) {
+ public MdnsPacketReader(byte[] buffer, int length, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
buf = buffer;
count = length;
pos = 0;
limit = -1;
labelDictionary = new SparseArray<>(16);
+ mMdnsFeatureFlags = mdnsFeatureFlags;
}
/**
@@ -137,6 +145,7 @@
public String[] readLabels() throws IOException {
List<String> result = new ArrayList<>(5);
LabelEntry previousEntry = null;
+ int tracingHops = 0;
while (getRemaining() > 0) {
byte nextByte = peekByte();
@@ -161,6 +170,10 @@
// Follow the chain of labels starting at this pointer, adding all of them onto the
// result.
while (labelOffset != 0) {
+ if (mMdnsFeatureFlags.mIsLabelCountLimitEnabled
+ && tracingHops > LABEL_COUNT_LIMIT) {
+ throw new IOException("Invalid MDNS response packet: Too many labels.");
+ }
LabelEntry entry = labelDictionary.get(labelOffset);
if (entry == null) {
throw new IOException(
@@ -169,6 +182,7 @@
}
result.add(entry.label);
labelOffset = entry.nextOffset;
+ tracingHops++;
}
break;
} else {
@@ -269,4 +283,4 @@
this.label = label;
}
}
-}
\ No newline at end of file
+}
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsResponseDecoder.java b/service-t/src/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
index 050913f..b812bb4 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsResponseDecoder.java
@@ -84,9 +84,9 @@
* @throws MdnsPacket.ParseException if a response packet could not be parsed.
*/
@NonNull
- public static MdnsPacket parseResponse(@NonNull byte[] recvbuf, int length)
- throws MdnsPacket.ParseException {
- MdnsPacketReader reader = new MdnsPacketReader(recvbuf, length);
+ public static MdnsPacket parseResponse(@NonNull byte[] recvbuf, int length,
+ @NonNull MdnsFeatureFlags mdnsFeatureFlags) throws MdnsPacket.ParseException {
+ final MdnsPacketReader reader = new MdnsPacketReader(recvbuf, length, mdnsFeatureFlags);
final MdnsPacket mdnsPacket;
try {
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
index d18a19b..82c8c5b 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
@@ -105,9 +105,10 @@
private AtomicInteger packetsCount;
@Nullable private Timer checkMulticastResponseTimer;
private final SharedLog sharedLog;
+ @NonNull private final MdnsFeatureFlags mdnsFeatureFlags;
public MdnsSocketClient(@NonNull Context context, @NonNull MulticastLock multicastLock,
- SharedLog sharedLog) {
+ SharedLog sharedLog, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
this.sharedLog = sharedLog;
this.context = context;
this.multicastLock = multicastLock;
@@ -116,6 +117,7 @@
} else {
unicastReceiverBuffer = null;
}
+ this.mdnsFeatureFlags = mdnsFeatureFlags;
}
@Override
@@ -454,7 +456,8 @@
final MdnsPacket response;
try {
- response = MdnsResponseDecoder.parseResponse(packet.getData(), packet.getLength());
+ response = MdnsResponseDecoder.parseResponse(
+ packet.getData(), packet.getLength(), mdnsFeatureFlags);
} catch (MdnsPacket.ParseException e) {
sharedLog.w(String.format("Error while decoding %s packet (%d): %d",
responseType, packetNumber, e.code));
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
index 8917ed3..ad30ce0 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
@@ -82,8 +82,8 @@
mHandlerThread.start();
mHandler = new Handler(mHandlerThread.getLooper());
mSocketKey = new SocketKey(1000 /* interfaceIndex */);
- mSocketClient = new MdnsMultinetworkSocketClient(
- mHandlerThread.getLooper(), mProvider, mSharedLog);
+ mSocketClient = new MdnsMultinetworkSocketClient(mHandlerThread.getLooper(), mProvider,
+ mSharedLog, MdnsFeatureFlags.newBuilder().build());
mHandler.post(() -> mSocketClient.setCallback(mCallback));
}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketReaderTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketReaderTests.java
index 19d8a00..0168b61 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketReaderTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketReaderTests.java
@@ -19,8 +19,10 @@
import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
import static org.junit.Assert.fail;
+import com.android.net.module.util.HexDump;
import com.android.testutils.DevSdkIgnoreRule;
import com.android.testutils.DevSdkIgnoreRunner;
@@ -75,7 +77,7 @@
+ "the packet length");
} catch (IOException e) {
// Expected
- } catch (Exception e) {
+ } catch (RuntimeException e) {
fail(String.format(
Locale.ROOT,
"Should not have thrown any other exception except " + "for IOException: %s",
@@ -83,4 +85,17 @@
}
assertEquals(data.length, packetReader.getRemaining());
}
-}
\ No newline at end of file
+
+ @Test
+ public void testInfinitePtrLoop() {
+ // Label portion of a fake mDNS response packet that contains an infinite pointer loop.
+ final byte[] infinitePtrLoopData = HexDump.hexStringToByteArray(
+ "054C4142454C" // label "LABEL"
+ + "0454455354" // label "TEST"
+ + "C006"); // PTR to second label.
+ MdnsPacketReader packetReader = new MdnsPacketReader(
+ infinitePtrLoopData, infinitePtrLoopData.length,
+ MdnsFeatureFlags.newBuilder().setIsLabelCountLimitEnabled(true).build());
+ assertThrows(IOException.class, packetReader::readLabels);
+ }
+}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketTest.kt
index 28ea4b6..fc4796b 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsPacketTest.kt
@@ -21,12 +21,16 @@
import com.android.testutils.DevSdkIgnoreRunner
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
import kotlin.test.assertTrue
import org.junit.Test
import org.junit.runner.RunWith
@RunWith(DevSdkIgnoreRunner::class)
class MdnsPacketTest {
+ private fun makeFlags(isLabelCountLimitEnabled: Boolean = false): MdnsFeatureFlags =
+ MdnsFeatureFlags.newBuilder()
+ .setIsLabelCountLimitEnabled(isLabelCountLimitEnabled).build()
@Test
fun testParseQuery() {
// Probe packet with 1 question for Android.local, and 4 additionalRecords with 4 addresses
@@ -38,7 +42,7 @@
"010db8000000000000000000000789"
val bytes = HexDump.hexStringToByteArray(packetHex)
- val reader = MdnsPacketReader(bytes, bytes.size)
+ val reader = MdnsPacketReader(bytes, bytes.size, makeFlags())
val packet = MdnsPacket.parse(reader)
assertEquals(123, packet.transactionId)
@@ -68,4 +72,17 @@
assertEquals(InetAddresses.parseNumericAddress("2001:db8::789"),
(packet.authorityRecords[3] as MdnsInetAddressRecord).inet6Address!!)
}
+
+ @Test
+ fun testParseQueryWithLabelLoop_ThrowsParseException() {
+ val packetWithErrorHex = "000084000000000100000000054C4142454C0454455354C006000C800100000" +
+ "07800140454455354056C6F63616C00"
+
+ val bytes = HexDump.hexStringToByteArray(packetWithErrorHex)
+ val reader = MdnsPacketReader(
+ bytes, bytes.size, makeFlags(isLabelCountLimitEnabled = true))
+ assertFailsWith<MdnsPacket.ParseException> {
+ MdnsPacket.parse(reader)
+ }
+ }
}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java
index 3fc656a..a22e8c6 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsResponseDecoderTests.java
@@ -17,8 +17,10 @@
package com.android.server.connectivity.mdns;
import static android.net.InetAddresses.parseNumericAddress;
+
import static com.android.server.connectivity.mdns.util.MdnsUtils.Clock;
import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
+
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
@@ -337,7 +339,8 @@
packet.setSocketAddress(
new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT));
- final MdnsPacket parsedPacket = MdnsResponseDecoder.parseResponse(data6, data6.length);
+ final MdnsPacket parsedPacket = MdnsResponseDecoder.parseResponse(
+ data6, data6.length, MdnsFeatureFlags.newBuilder().build());
assertNotNull(parsedPacket);
final Network network = mock(Network.class);
@@ -636,7 +639,8 @@
private ArraySet<MdnsResponse> decode(MdnsResponseDecoder decoder, byte[] data,
Collection<MdnsResponse> existingResponses) throws MdnsPacket.ParseException {
- final MdnsPacket parsedPacket = MdnsResponseDecoder.parseResponse(data, data.length);
+ final MdnsPacket parsedPacket = MdnsResponseDecoder.parseResponse(
+ data, data.length, MdnsFeatureFlags.newBuilder().build());
assertNotNull(parsedPacket);
return new ArraySet<>(decoder.augmentResponses(parsedPacket,
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
index 74f1c37..8b7ab71 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
@@ -78,6 +78,7 @@
@Mock private SharedLog sharedLog;
private MdnsSocketClient mdnsClient;
+ private MdnsFeatureFlags flags = MdnsFeatureFlags.newBuilder().build();
@Before
public void setup() throws RuntimeException, IOException {
@@ -86,7 +87,7 @@
when(mockWifiManager.createMulticastLock(ArgumentMatchers.anyString()))
.thenReturn(mockMulticastLock);
- mdnsClient = new MdnsSocketClient(mContext, mockMulticastLock, sharedLog) {
+ mdnsClient = new MdnsSocketClient(mContext, mockMulticastLock, sharedLog, flags) {
@Override
MdnsSocket createMdnsSocket(int port, SharedLog sharedLog) throws IOException {
if (port == MdnsConstants.MDNS_PORT) {
@@ -515,7 +516,7 @@
//MdnsConfigsFlagsImpl.allowNetworkInterfaceIndexPropagation.override(true);
when(mockMulticastSocket.getInterfaceIndex()).thenReturn(21);
- mdnsClient = new MdnsSocketClient(mContext, mockMulticastLock, sharedLog) {
+ mdnsClient = new MdnsSocketClient(mContext, mockMulticastLock, sharedLog, flags) {
@Override
MdnsSocket createMdnsSocket(int port, SharedLog sharedLog) {
if (port == MdnsConstants.MDNS_PORT) {
@@ -538,7 +539,7 @@
//MdnsConfigsFlagsImpl.allowNetworkInterfaceIndexPropagation.override(false);
when(mockMulticastSocket.getInterfaceIndex()).thenReturn(21);
- mdnsClient = new MdnsSocketClient(mContext, mockMulticastLock, sharedLog) {
+ mdnsClient = new MdnsSocketClient(mContext, mockMulticastLock, sharedLog, flags) {
@Override
MdnsSocket createMdnsSocket(int port, SharedLog sharedLog) {
if (port == MdnsConstants.MDNS_PORT) {