Merge "Implement NetworkRequestsState atom" into main
diff --git a/framework/src/android/net/ConnectivityManager.java b/framework/src/android/net/ConnectivityManager.java
index fa27d0e..1ea1815 100644
--- a/framework/src/android/net/ConnectivityManager.java
+++ b/framework/src/android/net/ConnectivityManager.java
@@ -6022,6 +6022,13 @@
/**
* Sets data saver switch.
*
+ * <p>This API configures bandwidth control and fills the data saver status into a BpfMap,
+ * which is intended for internal use by the network stack to optimize performance
+ * when frequently checking data saver status for multiple uids without doing IPC.
+ * It does not directly control the global data saver mode that users manage in settings.
+ * To query the comprehensive data saver status for a specific UID, including allowlist
+ * considerations, use {@link #getRestrictBackgroundStatus}.
+ *
* @param enable True if enable.
* @throws IllegalStateException if failed.
* @hide
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 6c25d76..9b2f80b 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -176,7 +176,7 @@
"mdns_advertiser_allowlist_";
private static final String MDNS_ALLOWLIST_FLAG_SUFFIX = "_version";
-
+ private static final String FORCE_ENABLE_FLAG_FOR_TEST_PREFIX = "test_";
@VisibleForTesting
static final String MDNS_CONFIG_RUNNING_APP_ACTIVE_IMPORTANCE_CUTOFF =
@@ -1739,6 +1739,10 @@
mContext, MdnsFeatureFlags.NSD_LIMIT_LABEL_COUNT))
.setIsKnownAnswerSuppressionEnabled(mDeps.isFeatureEnabled(
mContext, MdnsFeatureFlags.NSD_KNOWN_ANSWER_SUPPRESSION))
+ .setIsUnicastReplyEnabled(mDeps.isFeatureEnabled(
+ mContext, MdnsFeatureFlags.NSD_UNICAST_REPLY_ENABLED))
+ .setOverrideProvider(flag -> mDeps.isFeatureEnabled(
+ mContext, FORCE_ENABLE_FLAG_FOR_TEST_PREFIX + flag))
.build();
mMdnsSocketClient =
new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider,
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
index 1ad47a3..9466162 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
@@ -15,6 +15,9 @@
*/
package com.android.server.connectivity.mdns;
+import android.annotation.NonNull;
+import android.annotation.Nullable;
+
/**
* The class that contains mDNS feature flags;
*/
@@ -46,6 +49,14 @@
*/
public static final String NSD_KNOWN_ANSWER_SUPPRESSION = "nsd_known_answer_suppression";
+ /**
+ * A feature flag to control whether unicast replies should be enabled.
+ *
+ * <p>Enabling this feature causes replies to queries with the Query Unicast (QU) flag set to be
+ * sent unicast instead of multicast, as per RFC6762 5.4.
+ */
+ public static final String NSD_UNICAST_REPLY_ENABLED = "nsd_unicast_reply_enabled";
+
// Flag for offload feature
public final boolean mIsMdnsOffloadFeatureEnabled;
@@ -61,6 +72,36 @@
// Flag for known-answer suppression
public final boolean mIsKnownAnswerSuppressionEnabled;
+ // Flag to enable replying unicast to queries requesting unicast replies
+ public final boolean mIsUnicastReplyEnabled;
+
+ @Nullable
+ private final FlagOverrideProvider mOverrideProvider;
+
+    /**
+     * A provider that can indicate whether a flag should be force-enabled for testing purposes.
+     */
+    @FunctionalInterface
+    public interface FlagOverrideProvider {
+        /**
+         * Indicates whether the flag should be force-enabled for testing purposes.
+         *
+         * @param flag the name of the feature flag, e.g.
+         *             {@link MdnsFeatureFlags#NSD_UNICAST_REPLY_ENABLED}.
+         * @return true if the flag should be considered enabled regardless of its configured
+         *         value.
+         */
+        boolean isForceEnabledForTest(@NonNull String flag);
+    }
+
+    /**
+     * Indicates whether the flag should be force-enabled for testing purposes.
+     *
+     * @param flag the name of the feature flag.
+     * @return false if no {@link FlagOverrideProvider} was set, otherwise the provider's answer.
+     */
+    private boolean isForceEnabledForTest(@NonNull String flag) {
+        return mOverrideProvider != null && mOverrideProvider.isForceEnabledForTest(flag);
+    }
+
+    /**
+     * Indicates whether {@link #NSD_UNICAST_REPLY_ENABLED} is enabled, including for testing.
+     *
+     * @return true if the flag was enabled when building this instance, or if the
+     *         {@link FlagOverrideProvider} force-enables it.
+     */
+    public boolean isUnicastReplyEnabled() {
+        return mIsUnicastReplyEnabled || isForceEnabledForTest(NSD_UNICAST_REPLY_ENABLED);
+    }
+
/**
* The constructor for {@link MdnsFeatureFlags}.
*/
@@ -68,12 +109,16 @@
boolean includeInetAddressRecordsInProbing,
boolean isExpiredServicesRemovalEnabled,
boolean isLabelCountLimitEnabled,
- boolean isKnownAnswerSuppressionEnabled) {
+ boolean isKnownAnswerSuppressionEnabled,
+ boolean isUnicastReplyEnabled,
+ @Nullable FlagOverrideProvider overrideProvider) {
mIsMdnsOffloadFeatureEnabled = isOffloadFeatureEnabled;
mIncludeInetAddressRecordsInProbing = includeInetAddressRecordsInProbing;
mIsExpiredServicesRemovalEnabled = isExpiredServicesRemovalEnabled;
mIsLabelCountLimitEnabled = isLabelCountLimitEnabled;
mIsKnownAnswerSuppressionEnabled = isKnownAnswerSuppressionEnabled;
+ mIsUnicastReplyEnabled = isUnicastReplyEnabled;
+ mOverrideProvider = overrideProvider;
}
@@ -90,6 +135,8 @@
private boolean mIsExpiredServicesRemovalEnabled;
private boolean mIsLabelCountLimitEnabled;
private boolean mIsKnownAnswerSuppressionEnabled;
+ private boolean mIsUnicastReplyEnabled;
+ private FlagOverrideProvider mOverrideProvider;
/**
* The constructor for {@link Builder}.
@@ -100,6 +147,8 @@
mIsExpiredServicesRemovalEnabled = false;
mIsLabelCountLimitEnabled = true; // Default enabled.
mIsKnownAnswerSuppressionEnabled = false;
+ mIsUnicastReplyEnabled = true;
+ mOverrideProvider = null;
}
/**
@@ -154,6 +203,27 @@
}
/**
+         * Set whether the unicast reply feature is enabled.
+         *
+         * <p>Defaults to true if not set.
+         *
+         * @param isUnicastReplyEnabled whether to enable the feature.
+         * @return this builder.
+         * @see #NSD_UNICAST_REPLY_ENABLED
+         */
+        public Builder setIsUnicastReplyEnabled(boolean isUnicastReplyEnabled) {
+            mIsUnicastReplyEnabled = isUnicastReplyEnabled;
+            return this;
+        }
+
+        /**
+         * Set a {@link FlagOverrideProvider} used to force-enable flags for testing.
+         *
+         * <p>If non-null, flag getters that support test overrides (such as
+         * {@link MdnsFeatureFlags#isUnicastReplyEnabled()}) will use that provider to query
+         * whether the flag should be force-enabled. Defaults to null (no overrides).
+         *
+         * @param overrideProvider the provider to use, or null to disable test overrides.
+         * @return this builder.
+         */
+        public Builder setOverrideProvider(@Nullable FlagOverrideProvider overrideProvider) {
+            mOverrideProvider = overrideProvider;
+            return this;
+        }
+
+ /**
* Builds a {@link MdnsFeatureFlags} with the arguments supplied to this builder.
*/
public MdnsFeatureFlags build() {
@@ -161,7 +231,9 @@
mIncludeInetAddressRecordsInProbing,
mIsExpiredServicesRemovalEnabled,
mIsLabelCountLimitEnabled,
- mIsKnownAnswerSuppressionEnabled);
+ mIsKnownAnswerSuppressionEnabled,
+ mIsUnicastReplyEnabled,
+ mOverrideProvider);
}
}
}
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
index aa40c92..3a04dcd 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiser.java
@@ -373,12 +373,14 @@
}
return;
}
+ // recvbuf and src are reused after this returns; ensure references to src are not kept.
+ final InetSocketAddress srcCopy = new InetSocketAddress(src.getAddress(), src.getPort());
if (DBG) {
mSharedLog.v("Parsed packet with " + packet.questions.size() + " questions, "
+ packet.answers.size() + " answers, "
+ packet.authorityRecords.size() + " authority, "
- + packet.additionalRecords.size() + " additional from " + src);
+ + packet.additionalRecords.size() + " additional from " + srcCopy);
}
for (int conflictServiceId : mRecordRepository.getConflictingServices(packet)) {
@@ -389,7 +391,7 @@
// happen when the incoming packet has answer records (not a question), so there will be no
// answer. One exception is simultaneous probe tiebreaking (rfc6762 8.2), in which case the
// conflicting service is still probing and won't reply either.
- final MdnsReplyInfo answers = mRecordRepository.getReply(packet, src);
+ final MdnsReplyInfo answers = mRecordRepository.getReply(packet, srcCopy);
if (answers == null) return;
mReplySender.queueReply(answers);
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
index 28bd1b4..4b43989 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
@@ -176,6 +176,16 @@
}
/**
+     * For questions, returns whether a unicast reply was requested.
+     *
+     * In practice this is identical to {@link #getCacheFlush()}, as the "cache flush" flag in
+     * replies is the same as "unicast reply requested" in questions.
+     *
+     * @return true if the {@link MdnsConstants#QCLASS_UNICAST} bit is set in the record class.
+     */
+    public final boolean isUnicastReplyRequested() {
+        return (cls & MdnsConstants.QCLASS_UNICAST) != 0;
+    }
+
+ /**
* Returns the record's remaining TTL.
*
* If the record was not sent yet (receipt time {@link #RECEIPT_TIME_NOT_SENT}), this is the
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
index 6b6632c..585b097 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
@@ -499,26 +499,30 @@
@Nullable
public MdnsReplyInfo getReply(MdnsPacket packet, InetSocketAddress src) {
final long now = SystemClock.elapsedRealtime();
- final boolean replyUnicast = (packet.flags & MdnsConstants.QCLASS_UNICAST) != 0;
// Use LinkedHashSet for preserving the insert order of the RRs, so that RRs of the same
// service or host are grouped together (which is more developer-friendly).
final Set<RecordInfo<?>> answerInfo = new LinkedHashSet<>();
final Set<RecordInfo<?>> additionalAnswerInfo = new LinkedHashSet<>();
-
+ // Reply unicast if the feature is enabled AND all replied questions request unicast
+ final boolean replyUnicastEnabled = mMdnsFeatureFlags.isUnicastReplyEnabled();
+ boolean replyUnicast = replyUnicastEnabled;
for (MdnsRecord question : packet.questions) {
// Add answers from general records
- addReplyFromService(question, mGeneralRecords, null /* servicePtrRecord */,
- null /* serviceSrvRecord */, null /* serviceTxtRecord */, replyUnicast, now,
- answerInfo, additionalAnswerInfo, Collections.emptyList());
+ if (addReplyFromService(question, mGeneralRecords, null /* servicePtrRecord */,
+ null /* serviceSrvRecord */, null /* serviceTxtRecord */, replyUnicastEnabled,
+ now, answerInfo, additionalAnswerInfo, Collections.emptyList())) {
+ replyUnicast &= question.isUnicastReplyRequested();
+ }
// Add answers from each service
for (int i = 0; i < mServices.size(); i++) {
final ServiceRegistration registration = mServices.valueAt(i);
if (registration.exiting || registration.isProbing) continue;
if (addReplyFromService(question, registration.allRecords, registration.ptrRecords,
- registration.srvRecord, registration.txtRecord, replyUnicast, now,
+ registration.srvRecord, registration.txtRecord, replyUnicastEnabled, now,
answerInfo, additionalAnswerInfo, packet.answers)) {
+ replyUnicast &= question.isUnicastReplyRequested();
registration.repliedServiceCount++;
registration.sentPacketCount++;
}
@@ -570,6 +574,12 @@
// Determine the send destination
final InetSocketAddress dest;
if (replyUnicast) {
+ // As per RFC6762 5.4, "if the responder has not multicast that record recently (within
+ // one quarter of its TTL), then the responder SHOULD instead multicast the response so
+ // as to keep all the peer caches up to date": this SHOULD is not implemented to
+ // minimize latency for queriers who have just started, so they did not receive previous
+ // multicast responses. Unicast replies are faster as they do not need to wait for the
+ // beacon interval on Wi-Fi.
dest = src;
} else if (src.getAddress() instanceof Inet4Address) {
dest = IPV4_SOCKET_ADDR;
@@ -608,7 +618,7 @@
@Nullable List<RecordInfo<MdnsPointerRecord>> servicePtrRecords,
@Nullable RecordInfo<MdnsServiceRecord> serviceSrvRecord,
@Nullable RecordInfo<MdnsTextRecord> serviceTxtRecord,
- boolean replyUnicast, long now, @NonNull Set<RecordInfo<?>> answerInfo,
+ boolean replyUnicastEnabled, long now, @NonNull Set<RecordInfo<?>> answerInfo,
@NonNull Set<RecordInfo<?>> additionalAnswerInfo,
@NonNull List<MdnsRecord> knownAnswerRecords) {
boolean hasDnsSdPtrRecordAnswer = false;
@@ -659,7 +669,8 @@
// TODO: responses to probe queries should bypass this check and only ensure the
// reply is sent 250ms after the last sent time (RFC 6762 p.15)
- if (!replyUnicast && info.lastAdvertisedTimeMs > 0L
+ if (!(replyUnicastEnabled && question.isUnicastReplyRequested())
+ && info.lastAdvertisedTimeMs > 0L
&& now - info.lastAdvertisedTimeMs < MIN_MULTICAST_REPLY_INTERVAL_MS) {
continue;
}
diff --git a/service-t/src/com/android/server/connectivity/mdns/MulticastPacketReader.java b/service-t/src/com/android/server/connectivity/mdns/MulticastPacketReader.java
index 3cd77a4..70451f3 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MulticastPacketReader.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MulticastPacketReader.java
@@ -42,6 +42,12 @@
private final Set<PacketHandler> mPacketHandlers = MdnsUtils.newSet();
interface PacketHandler {
+ /**
+ * Handle an incoming packet.
+ *
+ * The recvbuf and src <b>will be reused and modified</b> after this method returns, so
+ * implementers must ensure that they are not accessed after handlePacket returns.
+ */
void handlePacket(byte[] recvbuf, int length, InetSocketAddress src);
}
diff --git a/staticlibs/testutils/devicetests/NSResponder.kt b/staticlibs/testutils/devicetests/NSResponder.kt
new file mode 100644
index 0000000..f7619cd
--- /dev/null
+++ b/staticlibs/testutils/devicetests/NSResponder.kt
@@ -0,0 +1,104 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.MacAddress
+import android.util.Log
+import com.android.net.module.util.Ipv6Utils
+import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN
+import com.android.net.module.util.NetworkStackConstants.ICMPV6_ND_OPTION_SLLA
+import com.android.net.module.util.NetworkStackConstants.ICMPV6_ND_OPTION_TLLA
+import com.android.net.module.util.NetworkStackConstants.NEIGHBOR_ADVERTISEMENT_FLAG_SOLICITED
+import com.android.net.module.util.Struct
+import com.android.net.module.util.structs.Icmpv6Header
+import com.android.net.module.util.structs.Ipv6Header
+import com.android.net.module.util.structs.LlaOption
+import com.android.net.module.util.structs.NsHeader
+import com.android.testutils.PacketReflector.IPV6_HEADER_LENGTH
+import java.lang.IllegalArgumentException
+import java.net.Inet6Address
+import java.nio.ByteBuffer
+
+private const val NS_TYPE = 135.toShort()
+
+/**
+ * A class that can be used to reply to Neighbor Solicitation packets on a [TapPacketReader].
+ *
+ * Each received NS whose target address is a key of the table is answered with a solicited
+ * Neighbor Advertisement sent from the mapped MAC address to the MAC address found in the
+ * NS source link-layer address option.
+ */
+class NSResponder(
+    reader: TapPacketReader,
+    table: Map<Inet6Address, MacAddress>,
+    name: String = NSResponder::class.java.simpleName
+) : PacketResponder(reader, Icmpv6Filter(), name) {
+    companion object {
+        private val TAG = NSResponder::class.simpleName
+    }
+
+    // toMap() always returns a read-only copy, so later changes by the caller to the map passed
+    // in do not affect the responder.
+    private val table = table.toMap()
+
+    override fun replyToPacket(packet: ByteArray, reader: TapPacketReader) {
+        // The packet starts with an Ethernet header, which the minimum length must include.
+        if (packet.size < ETHER_HEADER_LEN + IPV6_HEADER_LENGTH) {
+            return
+        }
+        val buf = ByteBuffer.wrap(packet, ETHER_HEADER_LEN, packet.size - ETHER_HEADER_LEN)
+        val ipv6Header = parseOrLog(Ipv6Header::class.java, buf) ?: return
+        val icmpHeader = parseOrLog(Icmpv6Header::class.java, buf) ?: return
+        if (icmpHeader.type != NS_TYPE) {
+            return
+        }
+        val ns = parseOrLog(NsHeader::class.java, buf) ?: return
+        val replyMacAddr = table[ns.target] ?: return
+        // Only the first option is read: it must be the source link-layer address option,
+        // otherwise its bytes would be misinterpreted as the requester MAC address.
+        val slla = parseOrLog(LlaOption::class.java, buf) ?: return
+        if (slla.type != ICMPV6_ND_OPTION_SLLA.toByte()) {
+            return
+        }
+        val requesterMac = slla.linkLayerAddress
+
+        val tlla = LlaOption.build(ICMPV6_ND_OPTION_TLLA.toByte(), replyMacAddr)
+        reader.sendResponse(Ipv6Utils.buildNaPacket(
+                replyMacAddr /* srcMac */,
+                requesterMac /* dstMac */,
+                ns.target /* srcIp */,
+                ipv6Header.srcIp /* dstIp */,
+                NEIGHBOR_ADVERTISEMENT_FLAG_SOLICITED,
+                ns.target,
+                tlla))
+    }
+
+    /**
+     * Parses a [Struct] of type [clazz] from the current position of [buf].
+     *
+     * @return the parsed struct, or null (after logging an error) if parsing threw an
+     *         [IllegalArgumentException].
+     */
+    private fun <T> parseOrLog(clazz: Class<T>, buf: ByteBuffer): T? where T : Struct {
+        return try {
+            Struct.parse(clazz, buf)
+        } catch (e: IllegalArgumentException) {
+            Log.e(TAG, "Invalid ${clazz.simpleName} in ICMPv6 packet", e)
+            null
+        }
+    }
+}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt b/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt
index 3d98cc3..68248ca 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt
@@ -22,12 +22,12 @@
import android.util.Log
import com.android.modules.utils.build.SdkLevel
import com.android.testutils.FunctionalUtils.ThrowingRunnable
-import org.junit.rules.TestRule
-import org.junit.runner.Description
-import org.junit.runners.model.Statement
import java.util.concurrent.CompletableFuture
import java.util.concurrent.Executor
import java.util.concurrent.TimeUnit
+import org.junit.rules.TestRule
+import org.junit.runner.Description
+import org.junit.runners.model.Statement
private val TAG = DeviceConfigRule::class.simpleName
@@ -147,11 +147,11 @@
return tryTest {
runAsShell(*readWritePermissions) {
DeviceConfig.addOnPropertiesChangedListener(
- DeviceConfig.NAMESPACE_CONNECTIVITY,
+ namespace,
inlineExecutor,
listener)
DeviceConfig.setProperty(
- DeviceConfig.NAMESPACE_CONNECTIVITY,
+ namespace,
key,
value,
false /* makeDefault */)
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/NatExternalPacketForwarder.kt b/staticlibs/testutils/devicetests/com/android/testutils/NatExternalPacketForwarder.kt
deleted file mode 100644
index d7961a0..0000000
--- a/staticlibs/testutils/devicetests/com/android/testutils/NatExternalPacketForwarder.kt
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Copyright (C) 2023 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.testutils
-
-import java.io.FileDescriptor
-import java.net.InetAddress
-
-/**
- * A class that forwards packets from the external {@link TestNetworkInterface} to the internal
- * {@link TestNetworkInterface} with NAT. See {@link NatPacketForwarderBase} for detail.
- */
-class NatExternalPacketForwarder(
- srcFd: FileDescriptor,
- mtu: Int,
- dstFd: FileDescriptor,
- extAddr: InetAddress,
- natMap: PacketBridge.NatMap
-) : NatPacketForwarderBase(srcFd, mtu, dstFd, extAddr, natMap) {
-
- /**
- * Rewrite addresses, ports and fix up checksums for packets received on the external
- * interface.
- *
- * Incoming response from external interface which is being forwarded to the internal
- * interface with translated address, e.g. 1.2.3.4:80 -> 8.8.8.8:1234
- * will be translated into 8.8.8.8:80 -> 192.168.1.1:5678.
- *
- * For packets that are not an incoming response, do not forward them to the
- * internal interface.
- */
- override fun preparePacketForForwarding(buf: ByteArray, len: Int, version: Int, proto: Int) {
- val (addrPos, addrLen) = getAddressPositionAndLength(version)
-
- // TODO: support one external address per ip version.
- val extAddrBuf = mExtAddr.address
- if (addrLen != extAddrBuf.size) throw IllegalStateException("Packet IP version mismatch")
-
- // Get internal address by port.
- val transportOffset =
- if (version == 4) PacketReflector.IPV4_HEADER_LENGTH
- else PacketReflector.IPV6_HEADER_LENGTH
- val dstPort = getPortAt(buf, transportOffset + DESTINATION_PORT_OFFSET)
- val intAddrInfo = synchronized(mNatMap) { mNatMap.fromExternalPort(dstPort) }
- // No mapping, skip. This usually happens if the connection is initiated directly on
- // the external interface, e.g. DNS64 resolution, network validation, etc.
- if (intAddrInfo == null) return
-
- val intAddrBuf = intAddrInfo.address.address
- val intPort = intAddrInfo.port
-
- // Copy the original destination to into the source address.
- for (i in 0 until addrLen) {
- buf[addrPos + i] = buf[addrPos + addrLen + i]
- }
-
- // Copy the internal address into the destination address.
- for (i in 0 until addrLen) {
- buf[addrPos + addrLen + i] = intAddrBuf[i]
- }
-
- // Copy the internal port into the destination port.
- setPortAt(intPort, buf, transportOffset + DESTINATION_PORT_OFFSET)
-
- // Fix IP and Transport layer checksum.
- fixPacketChecksum(buf, len, version, proto.toByte())
- }
-}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/NatInternalPacketForwarder.kt b/staticlibs/testutils/devicetests/com/android/testutils/NatInternalPacketForwarder.kt
deleted file mode 100644
index fa39d19..0000000
--- a/staticlibs/testutils/devicetests/com/android/testutils/NatInternalPacketForwarder.kt
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Copyright (C) 2023 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.testutils
-
-import java.io.FileDescriptor
-import java.net.InetAddress
-
-/**
- * A class that forwards packets from the internal {@link TestNetworkInterface} to the external
- * {@link TestNetworkInterface} with NAT. See {@link NatPacketForwarderBase} for detail.
- */
-class NatInternalPacketForwarder(
- srcFd: FileDescriptor,
- mtu: Int,
- dstFd: FileDescriptor,
- extAddr: InetAddress,
- natMap: PacketBridge.NatMap
-) : NatPacketForwarderBase(srcFd, mtu, dstFd, extAddr, natMap) {
-
- /**
- * Rewrite addresses, ports and fix up checksums for packets received on the internal
- * interface.
- *
- * Outgoing packet from the internal interface which is being forwarded to the
- * external interface with translated address, e.g. 192.168.1.1:5678 -> 8.8.8.8:80
- * will be translated into 8.8.8.8:1234 -> 1.2.3.4:80.
- *
- * The external port, e.g. 1234 in the above example, is the port number assigned by
- * the forwarder when creating the mapping to identify the source address and port when
- * the response is coming from the external interface. See {@link PacketBridge.NatMap}
- * for detail.
- */
- override fun preparePacketForForwarding(buf: ByteArray, len: Int, version: Int, proto: Int) {
- val (addrPos, addrLen) = getAddressPositionAndLength(version)
-
- // TODO: support one external address per ip version.
- val extAddrBuf = mExtAddr.address
- if (addrLen != extAddrBuf.size) throw IllegalStateException("Packet IP version mismatch")
-
- val srcAddr = getInetAddressAt(buf, addrPos, addrLen)
-
- // Copy the original destination to into the source address.
- for (i in 0 until addrLen) {
- buf[addrPos + i] = buf[addrPos + addrLen + i]
- }
-
- // Copy the external address into the destination address.
- for (i in 0 until addrLen) {
- buf[addrPos + addrLen + i] = extAddrBuf[i]
- }
-
- // Add an entry to NAT mapping table.
- val transportOffset =
- if (version == 4) PacketReflector.IPV4_HEADER_LENGTH
- else PacketReflector.IPV6_HEADER_LENGTH
- val srcPort = getPortAt(buf, transportOffset)
- val extPort = synchronized(mNatMap) { mNatMap.toExternalPort(srcAddr, srcPort, proto) }
- // Copy the external port to into the source port.
- setPortAt(extPort, buf, transportOffset)
-
- // Fix IP and Transport layer checksum.
- fixPacketChecksum(buf, len, version, proto.toByte())
- }
-}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/PacketBridge.kt b/staticlibs/testutils/devicetests/com/android/testutils/PacketBridge.kt
index d50f78a..1a2cc88 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/PacketBridge.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/PacketBridge.kt
@@ -16,6 +16,7 @@
package com.android.testutils
+import android.annotation.SuppressLint
import android.content.Context
import android.net.ConnectivityManager
import android.net.LinkAddress
@@ -31,29 +32,26 @@
import java.net.InetAddress
import libcore.io.IoUtils
-private const val MIN_PORT_NUMBER = 1025
-private const val MAX_PORT_NUMBER = 65535
-
/**
- * A class that set up two {@link TestNetworkInterface} with NAT, and forward packets between them.
+ * A class that sets up two {@link TestNetworkInterface}s and forwards packets between them.
*
- * See {@link NatPacketForwarder} for more detailed information.
+ * See {@link PacketForwarder} for more detailed information.
*/
class PacketBridge(
context: Context,
- internalAddr: LinkAddress,
- externalAddr: LinkAddress,
+ addresses: List<LinkAddress>,
dnsAddr: InetAddress
) {
- private val natMap = NatMap()
private val binder = Binder()
private val cm = context.getSystemService(ConnectivityManager::class.java)!!
private val tnm = context.getSystemService(TestNetworkManager::class.java)!!
- // Create test networks.
- private val internalIface = tnm.createTunInterface(listOf(internalAddr))
- private val externalIface = tnm.createTunInterface(listOf(externalAddr))
+ // Create test networks. The needed permissions should be supplied by the callers.
+ @SuppressLint("MissingPermission")
+ private val internalIface = tnm.createTunInterface(addresses)
+ @SuppressLint("MissingPermission")
+ private val externalIface = tnm.createTunInterface(addresses)
// Register test networks to ConnectivityService.
private val internalNetworkCallback: TestableNetworkCallback
@@ -61,32 +59,20 @@
val internalNetwork: Network
val externalNetwork: Network
init {
- val (inCb, inNet) = createTestNetwork(internalIface, internalAddr, dnsAddr)
- val (exCb, exNet) = createTestNetwork(externalIface, externalAddr, dnsAddr)
+ val (inCb, inNet) = createTestNetwork(internalIface, addresses, dnsAddr)
+ val (exCb, exNet) = createTestNetwork(externalIface, addresses, dnsAddr)
internalNetworkCallback = inCb
externalNetworkCallback = exCb
internalNetwork = inNet
externalNetwork = exNet
}
- // Setup the packet bridge.
+ // Set up the packet bridge.
private val internalFd = internalIface.fileDescriptor.fileDescriptor
private val externalFd = externalIface.fileDescriptor.fileDescriptor
- private val pr1 = NatInternalPacketForwarder(
- internalFd,
- 1500,
- externalFd,
- externalAddr.address,
- natMap
- )
- private val pr2 = NatExternalPacketForwarder(
- externalFd,
- 1500,
- internalFd,
- externalAddr.address,
- natMap
- )
+ private val pr1 = PacketForwarder(internalFd, 1500, externalFd)
+ private val pr2 = PacketForwarder(externalFd, 1500, internalFd)
fun start() {
IoUtils.setBlocking(internalFd, true /* blocking */)
@@ -107,7 +93,7 @@
*/
private fun createTestNetwork(
testIface: TestNetworkInterface,
- addr: LinkAddress,
+ addresses: List<LinkAddress>,
dnsAddr: InetAddress
): Pair<TestableNetworkCallback, Network> {
// Make a network request to hold the test network
@@ -120,7 +106,7 @@
cm.requestNetwork(nr, testCb)
val lp = LinkProperties().apply {
- addLinkAddress(addr)
+ setLinkAddresses(addresses)
interfaceName = testIface.interfaceName
addDnsServer(dnsAddr)
}
@@ -130,44 +116,4 @@
val network = testCb.expect<Available>().network
return testCb to network
}
-
- /**
- * A helper class to maintain the mappings between internal addresses/ports and external
- * ports.
- *
- * This class assigns an unused external port number if the mapping between
- * srcaddress:srcport:protocol and the external port does not exist yet.
- *
- * Note that this class is not thread-safe. The instance of the class needs to be
- * synchronized in the callers when being used in multiple threads.
- */
- class NatMap {
- data class AddressInfo(val address: InetAddress, val port: Int, val protocol: Int)
-
- private val mToExternalPort = HashMap<AddressInfo, Int>()
- private val mFromExternalPort = HashMap<Int, AddressInfo>()
-
- // Skip well-known port 0~1024.
- private var nextExternalPort = MIN_PORT_NUMBER
-
- fun toExternalPort(addr: InetAddress, port: Int, protocol: Int): Int {
- val info = AddressInfo(addr, port, protocol)
- val extPort: Int
- if (!mToExternalPort.containsKey(info)) {
- extPort = nextExternalPort++
- if (nextExternalPort > MAX_PORT_NUMBER) {
- throw IllegalStateException("Available ports are exhausted")
- }
- mToExternalPort[info] = extPort
- mFromExternalPort[extPort] = info
- } else {
- extPort = mToExternalPort[info]!!
- }
- return extPort
- }
-
- fun fromExternalPort(port: Int): AddressInfo? {
- return mFromExternalPort[port]
- }
- }
}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/NatPacketForwarderBase.java b/staticlibs/testutils/devicetests/com/android/testutils/PacketForwarder.java
similarity index 62%
rename from staticlibs/testutils/devicetests/com/android/testutils/NatPacketForwarderBase.java
rename to staticlibs/testutils/devicetests/com/android/testutils/PacketForwarder.java
index 0a2b5d4..d8efb7d 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/NatPacketForwarderBase.java
+++ b/staticlibs/testutils/devicetests/com/android/testutils/PacketForwarder.java
@@ -30,16 +30,13 @@
import android.system.Os;
import android.util.Log;
-import androidx.annotation.GuardedBy;
-
import java.io.FileDescriptor;
import java.io.IOException;
-import java.net.InetAddress;
import java.util.Objects;
/**
* A class that forwards packets from a {@link TestNetworkInterface} to another
- * {@link TestNetworkInterface} with NAT.
+ * {@link TestNetworkInterface}.
*
* For testing purposes, a {@link TestNetworkInterface} provides a {@link FileDescriptor}
* which allows content injection on the test network. However, this could be hard to use
@@ -54,30 +51,14 @@
*
* To make it work, an internal interface and an external interface are defined, where
* the client might send packets from the internal interface which are originated from
- * multiple addresses to a server that listens on the external address.
- *
- * When forwarding the outgoing packet on the internal interface, a simple NAT mechanism
- * is implemented during forwarding, which will swap the source and destination,
- * but replacing the source address with the external address,
- * e.g. 192.168.1.1:1234 -> 8.8.8.8:80 will be translated into 8.8.8.8:1234 -> 1.2.3.4:80.
- *
- * For the above example, a client who sends http request will have a hallucination that
- * it is talking to a remote server at 8.8.8.8. Also, the server listens on 1.2.3.4 will
- * have a different hallucination that the request is sent from a remote client at 8.8.8.8,
- * to a local address 1.2.3.4.
- *
- * And a NAT mapping is created at the time when the outgoing packet is forwarded.
- * With a different internal source port, the instance learned that when a response with the
- * destination port 1234, it should forward the packet to the internal address 192.168.1.1.
+ * multiple addresses to a server that listens on a different port.
*
* For the incoming packet received from external interface, for example a http response sent
* from the http server, the same mechanism is applied but in a different direction,
- * where the source and destination will be swapped, and the source address will be replaced
- * with the internal address, which is obtained from the NAT mapping described above.
+ * where the source and destination will be swapped.
*/
-public abstract class NatPacketForwarderBase extends Thread {
- private static final String TAG = "NatPacketForwarder";
- static final int DESTINATION_PORT_OFFSET = 2;
+public class PacketForwarder extends Thread {
+ private static final String TAG = "PacketForwarder";
// The source fd to read packets from.
@NonNull
@@ -88,27 +69,12 @@
// The destination fd to write packets to.
@NonNull
final FileDescriptor mDstFd;
- // The NAT mapping table shared between two NatPacketForwarder instances to map from
- // the source port to the associated internal address. The map can be read/write from two
- // different threads on any given time whenever receiving packets on the
- // {@link TestNetworkInterface}. Thus, synchronize on the object when reading/writing is needed.
- @GuardedBy("mNatMap")
- @NonNull
- final PacketBridge.NatMap mNatMap;
- // The address of the external interface. See {@link NatPacketForwarder}.
- @NonNull
- final InetAddress mExtAddr;
/**
- * Construct a {@link NatPacketForwarderBase}.
+ * Construct a {@link PacketForwarder}.
*
* This class reads packets from {@code srcFd} of a {@link TestNetworkInterface}, and
- * forwards them to the {@code dstFd} of another {@link TestNetworkInterface} with
- * NAT applied. See {@link NatPacketForwarderBase}.
- *
- * To apply NAT, the address of the external interface needs to be supplied through
- * {@code extAddr} to identify the external interface. And a shared NAT mapping table,
- * {@code natMap} is needed to be shared between these two instances.
+ * forwards them to the {@code dstFd} of another {@link TestNetworkInterface}.
*
* Note that this class is not useful if the instance is not managed by a
* {@link PacketBridge} to set up a two-way communication.
@@ -116,29 +82,15 @@
* @param srcFd {@link FileDescriptor} to read packets from.
* @param mtu MTU of the test network.
* @param dstFd {@link FileDescriptor} to write packets to.
- * @param extAddr the external address, which is the address of the external interface.
- * See {@link NatPacketForwarderBase}.
- * @param natMap the NAT mapping table shared between two {@link NatPacketForwarderBase}
- * instance.
*/
- public NatPacketForwarderBase(@NonNull FileDescriptor srcFd, int mtu,
- @NonNull FileDescriptor dstFd, @NonNull InetAddress extAddr,
- @NonNull PacketBridge.NatMap natMap) {
+ public PacketForwarder(@NonNull FileDescriptor srcFd, int mtu,
+ @NonNull FileDescriptor dstFd) {
super(TAG);
mSrcFd = Objects.requireNonNull(srcFd);
mBuf = new byte[mtu];
mDstFd = Objects.requireNonNull(dstFd);
- mExtAddr = Objects.requireNonNull(extAddr);
- mNatMap = Objects.requireNonNull(natMap);
}
- /**
- * A method to prepare forwarding packets between two instances of {@link TestNetworkInterface},
- * which includes re-write addresses, ports and fix up checksums.
- * Subclasses should override this method to implement a simple NAT.
- */
- abstract void preparePacketForForwarding(@NonNull byte[] buf, int len, int version, int proto);
-
private void forwardPacket(@NonNull byte[] buf, int len) {
try {
Os.write(mDstFd, buf, 0, len);
@@ -190,8 +142,9 @@
if (len < ipHdrLen + transportHdrLen) {
throw new IllegalStateException("Unexpected buffer length: " + len);
}
- // Re-write addresses, ports and fix up checksums.
- preparePacketForForwarding(mBuf, len, version, proto);
+ // Swap addresses.
+ PacketReflectorUtil.swapAddresses(mBuf, version);
+
// Send the packet to the destination fd.
forwardPacket(mBuf, len);
}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/PacketReflector.java b/staticlibs/testutils/devicetests/com/android/testutils/PacketReflector.java
index 69392d4..ce20d67 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/PacketReflector.java
+++ b/staticlibs/testutils/devicetests/com/android/testutils/PacketReflector.java
@@ -87,31 +87,6 @@
mBuf = new byte[mtu];
}
- private static void swapBytes(@NonNull byte[] buf, int pos1, int pos2, int len) {
- for (int i = 0; i < len; i++) {
- byte b = buf[pos1 + i];
- buf[pos1 + i] = buf[pos2 + i];
- buf[pos2 + i] = b;
- }
- }
-
- private static void swapAddresses(@NonNull byte[] buf, int version) {
- int addrPos, addrLen;
- switch (version) {
- case 4:
- addrPos = IPV4_ADDR_OFFSET;
- addrLen = IPV4_ADDR_LENGTH;
- break;
- case 6:
- addrPos = IPV6_ADDR_OFFSET;
- addrLen = IPV6_ADDR_LENGTH;
- break;
- default:
- throw new IllegalArgumentException();
- }
- swapBytes(buf, addrPos, addrPos + addrLen, addrLen);
- }
-
// Reflect TCP packets: swap the source and destination addresses, but don't change the ports.
// This is used by the test to "connect to itself" through the VPN.
private void processTcpPacket(@NonNull byte[] buf, int version, int len, int hdrLen) {
@@ -120,7 +95,7 @@
}
// Swap src and dst IP addresses.
- swapAddresses(buf, version);
+ PacketReflectorUtil.swapAddresses(buf, version);
// Send the packet back.
writePacket(buf, len);
@@ -134,11 +109,11 @@
}
// Swap src and dst IP addresses.
- swapAddresses(buf, version);
+ PacketReflectorUtil.swapAddresses(buf, version);
// Swap dst and src ports.
int portOffset = hdrLen;
- swapBytes(buf, portOffset, portOffset + 2, 2);
+ PacketReflectorUtil.swapBytes(buf, portOffset, portOffset + 2, 2);
// Send the packet back.
writePacket(buf, len);
@@ -160,7 +135,7 @@
// Swap src and dst IP addresses, and send the packet back.
// This effectively pings the device to see if it replies.
- swapAddresses(buf, version);
+ PacketReflectorUtil.swapAddresses(buf, version);
writePacket(buf, len);
// The device should have replied, and buf should now contain a ping response.
@@ -202,7 +177,7 @@
}
// Now swap the addresses again and reflect the packet. This sends a ping reply.
- swapAddresses(buf, version);
+ PacketReflectorUtil.swapAddresses(buf, version);
writePacket(buf, len);
}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/PacketReflectorUtil.kt b/staticlibs/testutils/devicetests/com/android/testutils/PacketReflectorUtil.kt
index 498b1a3..ad259c5 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/PacketReflectorUtil.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/PacketReflectorUtil.kt
@@ -112,3 +112,28 @@
else -> throw IllegalArgumentException("Unsupported protocol: $protocol")
}
}
+
+fun swapBytes(buf: ByteArray, pos1: Int, pos2: Int, len: Int) {
+ for (i in 0 until len) {
+ val b = buf[pos1 + i]
+ buf[pos1 + i] = buf[pos2 + i]
+ buf[pos2 + i] = b
+ }
+}
+
+fun swapAddresses(buf: ByteArray, version: Int) {
+ val addrPos: Int
+ val addrLen: Int
+ when (version) {
+ 4 -> {
+ addrPos = PacketReflector.IPV4_ADDR_OFFSET
+ addrLen = PacketReflector.IPV4_ADDR_LENGTH
+ }
+ 6 -> {
+ addrPos = PacketReflector.IPV6_ADDR_OFFSET
+ addrLen = PacketReflector.IPV6_ADDR_LENGTH
+ }
+ else -> throw java.lang.IllegalArgumentException()
+ }
+ swapBytes(buf, addrPos, addrPos + addrLen, addrLen)
+}
diff --git a/staticlibs/testutils/hostdevice/com/android/testutils/PacketFilter.kt b/staticlibs/testutils/hostdevice/com/android/testutils/PacketFilter.kt
index 1bb6d68..a73a58a 100644
--- a/staticlibs/testutils/hostdevice/com/android/testutils/PacketFilter.kt
+++ b/staticlibs/testutils/hostdevice/com/android/testutils/PacketFilter.kt
@@ -110,6 +110,12 @@
override fun test(t: ByteArray) = impl.test(t)
}
+class Icmpv6Filter : Predicate<ByteArray> {
+ private val impl = OffsetFilter(ETHER_TYPE_OFFSET, 0x86.toByte(), 0xdd.toByte() /* IPv6 */).and(
+ OffsetFilter(IPV6_PROTOCOL_OFFSET, 58 /* ICMPv6 */))
+ override fun test(t: ByteArray) = impl.test(t)
+}
+
/**
* A [Predicate] that matches ethernet-encapped DHCP packets sent from a DHCP client.
*/
diff --git a/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt b/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt
index eef3f87..5ba6c4c 100644
--- a/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt
+++ b/tests/cts/net/src/android/net/cts/MdnsTestUtils.kt
@@ -23,11 +23,15 @@
import com.android.net.module.util.ArrayTrackRecord
import com.android.net.module.util.DnsPacket
import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN
+import com.android.net.module.util.NetworkStackConstants.IPV6_ADDR_LEN
+import com.android.net.module.util.NetworkStackConstants.IPV6_DST_ADDR_OFFSET
import com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN
import com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN
import com.android.net.module.util.TrackRecord
import com.android.testutils.IPv6UdpFilter
import com.android.testutils.TapPacketReader
+import java.net.Inet6Address
+import java.net.InetAddress
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertNull
@@ -236,19 +240,28 @@
private fun getMdnsPayload(packet: ByteArray) = packet.copyOfRange(
ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN, packet.size)
+private fun getDstAddr(packet: ByteArray): Inet6Address {
+ val v6AddrPos = ETHER_HEADER_LEN + IPV6_DST_ADDR_OFFSET
+ return Inet6Address.getByAddress(packet.copyOfRange(v6AddrPos, v6AddrPos + IPV6_ADDR_LEN))
+ as Inet6Address
+}
+
fun TapPacketReader.pollForMdnsPacket(
timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS,
predicate: (TestDnsPacket) -> Boolean
): TestDnsPacket? {
val mdnsProbeFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT).and {
+ val dst = getDstAddr(it)
val mdnsPayload = getMdnsPayload(it)
try {
- predicate(TestDnsPacket(mdnsPayload))
+ predicate(TestDnsPacket(mdnsPayload, dst))
} catch (e: DnsPacket.ParseException) {
false
}
}
- return poll(timeoutMs, mdnsProbeFilter)?.let { TestDnsPacket(getMdnsPayload(it)) }
+ return poll(timeoutMs, mdnsProbeFilter)?.let {
+ TestDnsPacket(getMdnsPayload(it), getDstAddr(it))
+ }
}
fun TapPacketReader.pollForProbe(
@@ -281,7 +294,7 @@
it.isReplyFor("$serviceName.$serviceType.local")
}
-class TestDnsPacket(data: ByteArray) : DnsPacket(data) {
+class TestDnsPacket(data: ByteArray, val dstAddr: InetAddress) : DnsPacket(data) {
val header: DnsHeader
get() = mHeader
val records: Array<List<DnsRecord>>
@@ -290,9 +303,10 @@
it.dName == name && it.nsType == DnsResolver.TYPE_ANY
}
- fun isReplyFor(name: String): Boolean = mRecords[ANSECTION].any {
- it.dName == name && it.nsType == DnsResolver.TYPE_SRV
- }
+ fun isReplyFor(name: String, type: Int = DnsResolver.TYPE_SRV): Boolean =
+ mRecords[ANSECTION].any {
+ it.dName == name && it.nsType == type
+ }
fun isQueryFor(name: String, vararg requiredTypes: Int): Boolean = requiredTypes.all { type ->
mRecords[QDSECTION].any {
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index a040201..1309e79 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -61,6 +61,7 @@
import android.os.Handler
import android.os.HandlerThread
import android.platform.test.annotations.AppModeFull
+import android.provider.DeviceConfig.NAMESPACE_TETHERING
import android.system.ErrnoException
import android.system.Os
import android.system.OsConstants.AF_INET6
@@ -69,6 +70,7 @@
import android.system.OsConstants.ETH_P_IPV6
import android.system.OsConstants.IPPROTO_IPV6
import android.system.OsConstants.IPPROTO_UDP
+import android.system.OsConstants.RT_SCOPE_LINK
import android.system.OsConstants.SOCK_DGRAM
import android.util.Log
import androidx.test.filters.SmallTest
@@ -78,12 +80,14 @@
import com.android.modules.utils.build.SdkLevel.isAtLeastU
import com.android.net.module.util.DnsPacket
import com.android.net.module.util.HexDump
+import com.android.net.module.util.NetworkStackConstants.IPV6_ADDR_LEN
import com.android.net.module.util.PacketBuilder
import com.android.testutils.ConnectivityModuleTest
import com.android.testutils.DevSdkIgnoreRule
-import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.DeviceConfigRule
+import com.android.testutils.NSResponder
import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
import com.android.testutils.TapPacketReader
@@ -133,6 +137,7 @@
private const val TEST_PORT = 12345
private const val MDNS_PORT = 5353.toShort()
private val multicastIpv6Addr = parseNumericAddress("ff02::fb") as Inet6Address
+private val testSrcAddr = parseNumericAddress("2001:db8::123") as Inet6Address
@AppModeFull(reason = "Socket cannot bind in instant app mode")
@RunWith(DevSdkIgnoreRunner::class)
@@ -144,6 +149,9 @@
@get:Rule
val ignoreRule = DevSdkIgnoreRule()
+ @get:Rule
+ val deviceConfigRule = DeviceConfigRule()
+
private val context by lazy { InstrumentationRegistry.getInstrumentation().context }
private val nsdManager by lazy {
context.getSystemService(NsdManager::class.java) ?: fail("Could not get NsdManager service")
@@ -682,7 +690,7 @@
assertEquals(OffloadEngine.OFFLOAD_TYPE_REPLY.toLong(), serviceInfo.offloadType)
val offloadPayload = serviceInfo.offloadPayload
assertNotNull(offloadPayload)
- val dnsPacket = TestDnsPacket(offloadPayload)
+ val dnsPacket = TestDnsPacket(offloadPayload, dstAddr = multicastIpv6Addr)
assertEquals(0x8400, dnsPacket.header.flags)
assertEquals(0, dnsPacket.records[DnsPacket.QDSECTION].size)
assertTrue(dnsPacket.records[DnsPacket.ANSECTION].size >= 5)
@@ -1286,7 +1294,8 @@
// Resolve service on testNetwork1
val resolveRecord = NsdResolveRecord()
val packetReader = TapPacketReader(Handler(handlerThread.looper),
- testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+ testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */
+ )
packetReader.startAsyncForTest()
handlerThread.waitForIdle(TIMEOUT_MS)
@@ -1349,6 +1358,68 @@
serviceResolved.serviceInfo.hostAddresses.toSet())
}
+ @Test
+ fun testUnicastReplyUsedWhenQueryUnicastFlagSet() {
+ // The flag may be removed in the future but unicast replies should be enabled by default
+ // in that case. The rule will reset flags automatically on teardown.
+ deviceConfigRule.setConfig(NAMESPACE_TETHERING, "test_nsd_unicast_reply_enabled", "1")
+
+ val si = makeTestServiceInfo(testNetwork1.network)
+
+ // Register service on testNetwork1
+ val registrationRecord = NsdRegistrationRecord()
+ var nsResponder: NSResponder? = null
+ tryTest {
+ registerService(registrationRecord, si)
+ val packetReader = TapPacketReader(Handler(handlerThread.looper),
+ testNetwork1.iface.fileDescriptor.fileDescriptor, 1500 /* maxPacketSize */)
+ packetReader.startAsyncForTest()
+
+ handlerThread.waitForIdle(TIMEOUT_MS)
+ /*
+ Send a "query unicast" query.
+ Generated with:
+ scapy.raw(scapy.DNS(rd=0, qr=0, aa=0, qd =
+ scapy.DNSQR(qname='_nmt123456789._tcp.local', qtype='PTR', qclass=0x8001)
+ )).hex()
+ */
+ val mdnsPayload = HexDump.hexStringToByteArray("0000000000010000000000000d5f6e6d74313" +
+ "233343536373839045f746370056c6f63616c00000c8001")
+ replaceServiceNameAndTypeWithTestSuffix(mdnsPayload)
+
+ val testSrcAddr = makeLinkLocalAddressOfOtherDeviceOnPrefix(testNetwork1.network)
+ nsResponder = NSResponder(packetReader, mapOf(
+ testSrcAddr to MacAddress.fromString("01:02:03:04:05:06")
+ )).apply { start() }
+
+ packetReader.sendResponse(buildMdnsPacket(mdnsPayload, testSrcAddr))
+ // The reply is sent unicast to the source address. There may be announcements sent
+ // multicast around this time, so filter by destination address.
+ val reply = packetReader.pollForMdnsPacket { pkt ->
+ pkt.isReplyFor("$serviceType.local", DnsResolver.TYPE_PTR) &&
+ pkt.dstAddr == testSrcAddr
+ }
+ assertNotNull(reply)
+ } cleanup {
+ nsResponder?.stop()
+ nsdManager.unregisterService(registrationRecord)
+ registrationRecord.expectCallback<ServiceUnregistered>()
+ }
+ }
+
+ private fun makeLinkLocalAddressOfOtherDeviceOnPrefix(network: Network): Inet6Address {
+ val lp = cm.getLinkProperties(network) ?: fail("No LinkProperties for net $network")
+ // Expect to have a /64 link-local address
+ val linkAddr = lp.linkAddresses.firstOrNull {
+ it.isIPv6 && it.scope == RT_SCOPE_LINK && it.prefixLength == 64
+ } ?: fail("No /64 link-local address found in ${lp.linkAddresses} for net $network")
+
+ // Add one to the device address to simulate the address of another device on the prefix
+ val addrBytes = linkAddr.address.address
+ addrBytes[IPV6_ADDR_LEN - 1]++
+ return Inet6Address.getByAddress(addrBytes) as Inet6Address
+ }
+
private fun buildConflictingAnnouncement(): ByteBuffer {
/*
Generated with:
@@ -1393,7 +1464,10 @@
replaceAll(buffer, source, replacement)
}
- private fun buildMdnsPacket(mdnsPayload: ByteArray): ByteBuffer {
+ private fun buildMdnsPacket(
+ mdnsPayload: ByteArray,
+ srcAddr: Inet6Address = testSrcAddr
+ ): ByteBuffer {
val packetBuffer = PacketBuilder.allocate(true /* hasEther */, IPPROTO_IPV6,
IPPROTO_UDP, mdnsPayload.size)
val packetBuilder = PacketBuilder(packetBuffer)
@@ -1408,7 +1482,7 @@
0x60000000, // version=6, traffic class=0x0, flowlabel=0x0
IPPROTO_UDP.toByte(),
64 /* hop limit */,
- parseNumericAddress("2001:db8::123") as Inet6Address /* srcIp */,
+ srcAddr,
multicastIpv6Addr /* dstIp */)
packetBuilder.writeUdpHeader(MDNS_PORT /* srcPort */, MDNS_PORT /* dstPort */)
packetBuffer.put(mdnsPayload)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
index 0c04bff..ee0bd1a 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsInterfaceAdvertiserTest.kt
@@ -35,6 +35,7 @@
import java.net.InetSocketAddress
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals
+import kotlin.test.assertNotSame
import kotlin.test.assertTrue
import org.junit.After
import org.junit.Before
@@ -213,7 +214,12 @@
packetHandler.handlePacket(query, query.size, src)
val packetCaptor = ArgumentCaptor.forClass(MdnsPacket::class.java)
- verify(repository).getReply(packetCaptor.capture(), eq(src))
+ val srcCaptor = ArgumentCaptor.forClass(InetSocketAddress::class.java)
+ verify(repository).getReply(packetCaptor.capture(), srcCaptor.capture())
+
+ assertEquals(src, srcCaptor.value)
+ assertNotSame(src, srcCaptor.value, "src will be reused by the packetHandler, references " +
+ "to it should not be used outside of handlePacket.")
packetCaptor.value.let {
assertEquals(1, it.questions.size)
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
index 4b1f166..1edc806 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsRecordRepositoryTest.kt
@@ -24,7 +24,6 @@
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
import com.android.server.connectivity.mdns.MdnsRecord.TYPE_A
import com.android.server.connectivity.mdns.MdnsRecord.TYPE_AAAA
-import com.android.server.connectivity.mdns.MdnsRecord.TYPE_ANY
import com.android.server.connectivity.mdns.MdnsRecord.TYPE_PTR
import com.android.server.connectivity.mdns.MdnsRecord.TYPE_SRV
import com.android.server.connectivity.mdns.MdnsRecord.TYPE_TXT
@@ -95,7 +94,6 @@
override fun getInterfaceInetAddresses(iface: NetworkInterface) =
Collections.enumeration(TEST_ADDRESSES.map { it.address })
}
- private val flags = MdnsFeatureFlags.newBuilder().build()
@Before
fun setUp() {
@@ -108,9 +106,19 @@
thread.join()
}
+ private fun makeFlags(
+ includeInetAddressesInProbing: Boolean = false,
+ isKnownAnswerSuppressionEnabled: Boolean = false,
+ unicastReplyEnabled: Boolean = true
+ ) = MdnsFeatureFlags.Builder()
+ .setIncludeInetAddressRecordsInProbing(includeInetAddressesInProbing)
+ .setIsKnownAnswerSuppressionEnabled(isKnownAnswerSuppressionEnabled)
+ .setIsUnicastReplyEnabled(unicastReplyEnabled)
+ .build()
+
@Test
fun testAddServiceAndProbe() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
assertEquals(0, repository.servicesCount)
assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1))
assertEquals(1, repository.servicesCount)
@@ -144,7 +152,7 @@
@Test
fun testAddAndConflicts() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
assertFailsWith(NameConflictException::class) {
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1)
@@ -156,7 +164,7 @@
@Test
fun testAddAndUpdates() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
assertFailsWith(IllegalArgumentException::class) {
@@ -190,7 +198,7 @@
@Test
fun testInvalidReuseOfServiceId() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
assertFailsWith(IllegalArgumentException::class) {
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_2)
@@ -199,7 +207,7 @@
@Test
fun testHasActiveService() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
assertFalse(repository.hasActiveService(TEST_SERVICE_ID_1))
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
@@ -216,7 +224,7 @@
@Test
fun testExitAnnouncements() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
@@ -246,7 +254,7 @@
@Test
fun testExitAnnouncements_WithSubtypes() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
setOf(TEST_SUBTYPE, TEST_SUBTYPE2))
repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
@@ -288,7 +296,7 @@
@Test
fun testExitingServiceReAdded() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
repository.exitService(TEST_SERVICE_ID_1)
@@ -303,7 +311,7 @@
@Test
fun testOnProbingSucceeded() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
val announcementInfo = repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1,
setOf(TEST_SUBTYPE, TEST_SUBTYPE2))
repository.onAdvertisementSent(TEST_SERVICE_ID_1, 2 /* sentPacketCount */)
@@ -435,7 +443,7 @@
@Test
fun testGetOffloadPacket() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
val serviceType = arrayOf("_testservice", "_tcp", "local")
@@ -497,7 +505,7 @@
@Test
fun testGetReplyCaseInsensitive() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val questionsCaseInSensitive = listOf(
MdnsPointerRecord(arrayOf("_TESTSERVICE", "_TCP", "local"), false /* isUnicast */))
@@ -517,7 +525,7 @@
private fun makeQuery(vararg queries: Pair<Int, Array<String>>): MdnsPacket {
val questions = queries.map { (type, name) -> makeQuestionRecord(name, type) }
return MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
- listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+ listOf() /* authorityRecords */, listOf() /* additionalRecords */)
}
private fun makeQuestionRecord(name: Array<String>, type: Int): MdnsRecord {
@@ -532,7 +540,7 @@
@Test
fun testGetReply_singlePtrQuestion_returnsSrvTxtAddressNsecRecords() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
@@ -563,7 +571,7 @@
@Test
fun testGetReply_singleSubtypePtrQuestion_returnsSrvTxtAddressNsecRecords() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
@@ -596,7 +604,7 @@
@Test
fun testGetReply_duplicatePtrQuestions_doesNotReturnDuplicateRecords() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
@@ -629,7 +637,7 @@
@Test
fun testGetReply_multiplePtrQuestionsWithSubtype_doesNotReturnDuplicateRecords() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
@@ -665,7 +673,7 @@
@Test
fun testGetReply_txtQuestion_returnsNoNsecRecord() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
@@ -682,7 +690,7 @@
@Test
fun testGetReply_AAAAQuestionButNoIpv6Address_returnsNsecRecord() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(
TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE),
listOf(LinkAddress(parseNumericAddress("192.0.2.111"), 24)))
@@ -701,7 +709,7 @@
@Test
fun testGetReply_ptrAndSrvQuestions_doesNotReturnSrvRecordInAdditionalAnswerSection() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
@@ -723,7 +731,7 @@
@Test
fun testGetReply_srvTxtAddressQuestions_returnsAllRecordsInAnswerSectionExceptNsec() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
@@ -757,7 +765,7 @@
@Test
fun testGetReply_queryWithIpv4Address_replyWithIpv4Address() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
val query = makeQuery(TYPE_PTR to arrayOf("_testservice", "_tcp", "local"))
@@ -771,7 +779,7 @@
@Test
fun testGetReply_queryWithIpv6Address_replyWithIpv6Address() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1, setOf(TEST_SUBTYPE))
val query = makeQuery(TYPE_PTR to arrayOf("_testservice", "_tcp", "local"))
@@ -785,7 +793,7 @@
@Test
fun testGetConflictingServices() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)
@@ -813,7 +821,7 @@
@Test
fun testGetConflictingServicesCaseInsensitive() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)
@@ -841,7 +849,7 @@
@Test
fun testGetConflictingServices_IdenticalService() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)
@@ -870,7 +878,7 @@
@Test
fun testGetConflictingServicesCaseInsensitive_IdenticalService() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_2)
@@ -899,7 +907,7 @@
@Test
fun testGetServiceRepliedRequestsCount() {
- val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
// Verify that there is no packet replied.
assertEquals(MdnsConstants.NO_PACKET,
@@ -924,7 +932,7 @@
@Test
fun testIncludeInetAddressRecordsInProbing() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME,
- MdnsFeatureFlags.newBuilder().setIncludeInetAddressRecordsInProbing(true).build())
+ makeFlags(includeInetAddressesInProbing = true))
repository.updateAddresses(TEST_ADDRESSES)
assertEquals(0, repository.servicesCount)
assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1))
@@ -990,7 +998,7 @@
expectReply: Boolean
) {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME,
- MdnsFeatureFlags.newBuilder().setIsKnownAnswerSuppressionEnabled(true).build())
+ makeFlags(isKnownAnswerSuppressionEnabled = true))
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
val query = MdnsPacket(0 /* flags */, questions, knownAnswers,
listOf() /* authorityRecords */, listOf() /* additionalRecords */)
@@ -1222,23 +1230,109 @@
MdnsPointerRecord(queriedName, false /* isUnicast */),
MdnsServiceRecord(serviceName, false /* isUnicast */))
val knownAnswers = listOf(
- MdnsPointerRecord(
- queriedName,
- 0L /* receiptTimeMillis */,
- false /* cacheFlush */,
- LONG_TTL - 1000L,
- serviceName),
- MdnsServiceRecord(
- serviceName,
- 0L /* receiptTimeMillis */,
- false /* cacheFlush */,
- SHORT_TTL - 15_000L,
- 0 /* servicePriority */,
- 0 /* serviceWeight */,
- TEST_PORT,
- TEST_HOSTNAME))
- doGetReplyWithAnswersTest(questions, knownAnswers, listOf() /* replyAnswers */,
- listOf() /* additionalAnswers */, false /* expectReply */)
+ MdnsPointerRecord(
+ queriedName,
+ 0L /* receiptTimeMillis */,
+ false /* cacheFlush */,
+ LONG_TTL - 1000L,
+ serviceName
+ ),
+ MdnsServiceRecord(
+ serviceName,
+ 0L /* receiptTimeMillis */,
+ false /* cacheFlush */,
+ SHORT_TTL - 15_000L,
+ 0 /* servicePriority */,
+ 0 /* serviceWeight */,
+ TEST_PORT,
+ TEST_HOSTNAME
+ )
+ )
+ doGetReplyWithAnswersTest(
+ questions, knownAnswers, listOf() /* replyAnswers */,
+ listOf() /* additionalAnswers */, false /* expectReply */
+ )
+ }
+
+ @Test
+ fun testReplyUnicastToQueryUnicastQuestions() {
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+ repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+
+ // Ask for 2 services, only the first one is known and requests unicast reply
+ val questions = listOf(
+ MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"), true /* isUnicast */),
+ MdnsPointerRecord(arrayOf("_otherservice", "_tcp", "local"), true /* isUnicast */))
+ val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
+ listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+ val src = InetSocketAddress(parseNumericAddress("2001:db8::123"), 5353)
+
+ // Reply to the question and verify it is sent to the source.
+ val reply = repository.getReply(query, src)
+ assertNotNull(reply)
+ assertEquals(src, reply.destination)
+ }
+
+ @Test
+ fun testReplyMulticastToQueryUnicastAndMulticastMixedQuestions() {
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+ repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+ repository.addServiceAndFinishProbing(TEST_SERVICE_ID_2, NsdServiceInfo().apply {
+ serviceType = "_otherservice._tcp"
+ serviceName = "OtherTestService"
+ port = TEST_PORT
+ })
+
+ // Ask for 2 services, both are known and only the first one requests unicast reply
+ val questions = listOf(
+ MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"), true /* isUnicast */),
+ MdnsPointerRecord(arrayOf("_otherservice", "_tcp", "local"), false /* isUnicast */))
+ val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
+ listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+ val src = InetSocketAddress(parseNumericAddress("2001:db8::123"), 5353)
+
+ // Reply to the question and verify it is sent multicast.
+ val reply = repository.getReply(query, src)
+ assertNotNull(reply)
+ assertEquals(MdnsConstants.getMdnsIPv6Address(), reply.destination.address)
+ }
+
+ @Test
+ fun testReplyMulticastWhenNoUnicastQueryMatches() {
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, makeFlags())
+ repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+
+ // Ask for 2 services, the first one requests a unicast reply but is unknown
+ val questions = listOf(
+ MdnsPointerRecord(arrayOf("_otherservice", "_tcp", "local"), true /* isUnicast */),
+ MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"), false /* isUnicast */))
+ val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
+ listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+ val src = InetSocketAddress(parseNumericAddress("2001:db8::123"), 5353)
+
+ // Reply to the question and verify it is sent multicast.
+ val reply = repository.getReply(query, src)
+ assertNotNull(reply)
+ assertEquals(MdnsConstants.getMdnsIPv6Address(), reply.destination.address)
+ }
+
+ @Test
+ fun testReplyMulticastWhenUnicastFeatureDisabled() {
+ val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME,
+ makeFlags(unicastReplyEnabled = false))
+ repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
+
+ // The service is known and requests unicast reply, but the feature is disabled
+ val questions = listOf(
+ MdnsPointerRecord(arrayOf("_testservice", "_tcp", "local"), true /* isUnicast */))
+ val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
+ listOf() /* authorityRecords */, listOf() /* additionalRecords */)
+ val src = InetSocketAddress(parseNumericAddress("2001:db8::123"), 5353)
+
+ // Reply to the question and verify it is sent multicast.
+ val reply = repository.getReply(query, src)
+ assertNotNull(reply)
+ assertEquals(MdnsConstants.getMdnsIPv6Address(), reply.destination.address)
}
}
@@ -1250,6 +1344,13 @@
): AnnouncementInfo {
updateAddresses(addresses)
serviceInfo.setSubtypes(subtypes)
+ return addServiceAndFinishProbing(serviceId, serviceInfo)
+}
+
+private fun MdnsRecordRepository.addServiceAndFinishProbing(
+ serviceId: Int,
+ serviceInfo: NsdServiceInfo
+): AnnouncementInfo {
addService(serviceId, serviceInfo)
val probingInfo = setServiceProbing(serviceId)
assertNotNull(probingInfo)
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSBasicMethodsTest.kt b/tests/unit/java/com/android/server/connectivityservice/CSBasicMethodsTest.kt
index 58f20a9..a5d5297 100644
--- a/tests/unit/java/com/android/server/connectivityservice/CSBasicMethodsTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/CSBasicMethodsTest.kt
@@ -23,11 +23,12 @@
import androidx.test.filters.SmallTest
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
import com.android.testutils.DevSdkIgnoreRunner
-import org.junit.Test
-import org.junit.runner.RunWith
import kotlin.test.assertFalse
import kotlin.test.assertTrue
+import org.junit.Test
+import org.junit.runner.RunWith
+@DevSdkIgnoreRunner.MonitorThreadLeak
@RunWith(DevSdkIgnoreRunner::class)
@SmallTest
@IgnoreUpTo(Build.VERSION_CODES.R)
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSBpfNetMapsTest.kt b/tests/unit/java/com/android/server/connectivityservice/CSBpfNetMapsTest.kt
index c26ec53..8155fd0 100644
--- a/tests/unit/java/com/android/server/connectivityservice/CSBpfNetMapsTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/CSBpfNetMapsTest.kt
@@ -38,6 +38,7 @@
import org.mockito.Mockito.never
import org.mockito.Mockito.verify
+@DevSdkIgnoreRunner.MonitorThreadLeak
@RunWith(DevSdkIgnoreRunner::class)
@SmallTest
@IgnoreUpTo(Build.VERSION_CODES.S_V2) // Bpf only supports in T+.
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSDestroyedNetworkTests.kt b/tests/unit/java/com/android/server/connectivityservice/CSDestroyedNetworkTests.kt
index 572c7bb..5c29e3a 100644
--- a/tests/unit/java/com/android/server/connectivityservice/CSDestroyedNetworkTests.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/CSDestroyedNetworkTests.kt
@@ -30,6 +30,7 @@
private const val LONG_TIMEOUT_MS = 5_000
+@DevSdkIgnoreRunner.MonitorThreadLeak
@RunWith(DevSdkIgnoreRunner::class)
@SmallTest
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.R)
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSKeepConnectedTest.kt b/tests/unit/java/com/android/server/connectivityservice/CSKeepConnectedTest.kt
index a753922..94c68c0 100644
--- a/tests/unit/java/com/android/server/connectivityservice/CSKeepConnectedTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/CSKeepConnectedTest.kt
@@ -22,8 +22,8 @@
import android.net.NetworkCapabilities.TRANSPORT_WIFI
import android.net.NetworkRequest
import android.net.NetworkScore
-import android.net.NetworkScore.KEEP_CONNECTED_LOCAL_NETWORK
import android.net.NetworkScore.KEEP_CONNECTED_FOR_TEST
+import android.net.NetworkScore.KEEP_CONNECTED_LOCAL_NETWORK
import android.os.Build
import androidx.test.filters.SmallTest
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
@@ -33,6 +33,7 @@
import org.junit.Test
import org.junit.runner.RunWith
+@DevSdkIgnoreRunner.MonitorThreadLeak
@RunWith(DevSdkIgnoreRunner::class)
@SmallTest
@IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentCreationTests.kt b/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentCreationTests.kt
index 6add6b9..cb98454 100644
--- a/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentCreationTests.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentCreationTests.kt
@@ -33,6 +33,7 @@
import com.android.testutils.DevSdkIgnoreRunner
import com.android.testutils.RecorderCallback.CallbackEntry.Available
import com.android.testutils.TestableNetworkCallback
+import kotlin.test.assertFailsWith
import org.junit.Assert.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
@@ -41,7 +42,6 @@
import org.mockito.Mockito.inOrder
import org.mockito.Mockito.never
import org.mockito.Mockito.timeout
-import kotlin.test.assertFailsWith
private const val TIMEOUT_MS = 2_000L
private const val NO_CALLBACK_TIMEOUT_MS = 200L
@@ -51,6 +51,7 @@
private fun defaultLnc() = FromS(LocalNetworkConfig.Builder().build())
+@DevSdkIgnoreRunner.MonitorThreadLeak
@RunWith(DevSdkIgnoreRunner::class)
@SmallTest
@IgnoreUpTo(Build.VERSION_CODES.R)
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentTests.kt b/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentTests.kt
index dd0706b..ba14775 100644
--- a/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentTests.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/CSLocalAgentTests.kt
@@ -42,6 +42,7 @@
import com.android.testutils.RecorderCallback.CallbackEntry.LocalInfoChanged
import com.android.testutils.RecorderCallback.CallbackEntry.Lost
import com.android.testutils.TestableNetworkCallback
+import kotlin.test.assertFailsWith
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.Mockito.clearInvocations
@@ -49,7 +50,6 @@
import org.mockito.Mockito.never
import org.mockito.Mockito.timeout
import org.mockito.Mockito.verify
-import kotlin.test.assertFailsWith
private const val TIMEOUT_MS = 200L
private const val MEDIUM_TIMEOUT_MS = 1_000L
@@ -79,6 +79,7 @@
NetworkScore.Builder().setKeepConnectedReason(KEEP_CONNECTED_FOR_TEST).build()
)
+@DevSdkIgnoreRunner.MonitorThreadLeak
@RunWith(DevSdkIgnoreRunner::class)
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
class CSLocalAgentTests : CSTest() {
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSNetworkActivityTest.kt b/tests/unit/java/com/android/server/connectivityservice/CSNetworkActivityTest.kt
index 526ec9d..df0a2cc 100644
--- a/tests/unit/java/com/android/server/connectivityservice/CSNetworkActivityTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/CSNetworkActivityTest.kt
@@ -63,6 +63,7 @@
private const val PACKAGE_UID = 123
private const val TIMEOUT_MS = 250L
+@DevSdkIgnoreRunner.MonitorThreadLeak
@RunWith(DevSdkIgnoreRunner::class)
@SmallTest
@IgnoreUpTo(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
index 21396f2..5322799 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
@@ -70,6 +70,7 @@
import java.util.concurrent.Executors
import kotlin.test.assertNull
import kotlin.test.fail
+import org.junit.After
import org.mockito.AdditionalAnswers.delegatesTo
import org.mockito.Mockito.doAnswer
import org.mockito.Mockito.doReturn
@@ -160,7 +161,8 @@
val clatCoordinator = mock<ClatCoordinator>()
val networkRequestStateStatsMetrics = mock<NetworkRequestStateStatsMetrics>()
val proxyTracker = ProxyTracker(context, mock<Handler>(), 16 /* EVENT_PROXY_HAS_CHANGED */)
- val alarmManager = makeMockAlarmManager()
+ val alrmHandlerThread = HandlerThread("TestAlarmManager").also { it.start() }
+ val alarmManager = makeMockAlarmManager(alrmHandlerThread)
val systemConfigManager = makeMockSystemConfigManager()
val batteryStats = mock<IBatteryStats>()
val batteryManager = BatteryStatsManager(batteryStats)
@@ -173,6 +175,14 @@
val cm = ConnectivityManager(context, service)
val csHandler = Handler(csHandlerThread.looper)
+ @After
+ fun tearDown() {
+ csHandlerThread.quitSafely()
+ csHandlerThread.join()
+ alrmHandlerThread.quitSafely()
+ alrmHandlerThread.join()
+ }
+
inner class CSDeps : ConnectivityService.Dependencies() {
override fun getResources(ctx: Context) = connResources
override fun getBpfNetMaps(context: Context, netd: INetd) = this@CSTest.bpfNetMaps
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSTestHelpers.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSTestHelpers.kt
index c1828b2..8ff790c 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSTestHelpers.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSTestHelpers.kt
@@ -53,6 +53,7 @@
import com.android.modules.utils.build.SdkLevel
import com.android.server.ConnectivityService.Dependencies
import com.android.server.connectivity.ConnectivityResources
+import kotlin.test.fail
import org.mockito.ArgumentMatchers
import org.mockito.ArgumentMatchers.any
import org.mockito.ArgumentMatchers.anyInt
@@ -64,7 +65,6 @@
import org.mockito.Mockito.doAnswer
import org.mockito.Mockito.doNothing
import org.mockito.Mockito.doReturn
-import kotlin.test.fail
internal inline fun <reified T> mock() = Mockito.mock(T::class.java)
internal inline fun <reified T> any() = any(T::class.java)
@@ -128,8 +128,8 @@
}
private val UNREASONABLY_LONG_ALARM_WAIT_MS = 1000
-internal fun makeMockAlarmManager() = mock<AlarmManager>().also { am ->
- val alrmHdlr = HandlerThread("TestAlarmManager").also { it.start() }.threadHandler
+internal fun makeMockAlarmManager(handlerThread: HandlerThread) = mock<AlarmManager>().also { am ->
+ val alrmHdlr = handlerThread.threadHandler
doAnswer {
val (_, date, _, wakeupMsg, handler) = it.arguments
wakeupMsg as WakeupMessage
diff --git a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
index e02e74d..7a6c9aa 100644
--- a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
+++ b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
@@ -24,6 +24,8 @@
import static androidx.test.platform.app.InstrumentationRegistry.getInstrumentation;
+import static com.android.testutils.TestPermissionUtil.runAsShell;
+
import static com.google.common.truth.Truth.assertThat;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
@@ -34,6 +36,10 @@
import android.Manifest.permission;
import android.content.Context;
+import android.net.ConnectivityManager;
+import android.net.Network;
+import android.net.NetworkCapabilities;
+import android.net.NetworkRequest;
import android.net.thread.ActiveOperationalDataset;
import android.net.thread.OperationalDatasetTimestamp;
import android.net.thread.PendingOperationalDataset;
@@ -74,6 +80,8 @@
@RunWith(DevSdkIgnoreRunner.class)
@IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) // Thread is available on only U+
public class ThreadNetworkControllerTest {
+ private static final int JOIN_TIMEOUT_MILLIS = 30 * 1000;
+ private static final int NETWORK_CALLBACK_TIMEOUT_MILLIS = 10 * 1000;
private static final int CALLBACK_TIMEOUT_MILLIS = 1000;
private static final String PERMISSION_THREAD_NETWORK_PRIVILEGED =
"android.permission.THREAD_NETWORK_PRIVILEGED";
@@ -750,4 +758,36 @@
assertThat(dataset.getMeshLocalPrefix().getRawAddress()[0]).isEqualTo((byte) 0xfd);
}
}
+
+ @Test
+ public void threadNetworkCallback_deviceAttached_threadNetworkIsAvailable() throws Exception {
+ ThreadNetworkController controller = mManager.getAllThreadNetworkControllers().get(0);
+ ActiveOperationalDataset activeDataset = newRandomizedDataset("TestNet", controller);
+ SettableFuture<Void> joinFuture = SettableFuture.create();
+ SettableFuture<Network> networkFuture = SettableFuture.create();
+ ConnectivityManager cm = mContext.getSystemService(ConnectivityManager.class);
+ NetworkRequest networkRequest =
+ new NetworkRequest.Builder()
+ .addTransportType(NetworkCapabilities.TRANSPORT_THREAD)
+ .build();
+ ConnectivityManager.NetworkCallback networkCallback =
+ new ConnectivityManager.NetworkCallback() {
+ @Override
+ public void onAvailable(Network network) {
+ networkFuture.set(network);
+ }
+ };
+
+ runAsShell(
+ PERMISSION_THREAD_NETWORK_PRIVILEGED,
+ () -> controller.join(activeDataset, mExecutor, newOutcomeReceiver(joinFuture)));
+ runAsShell(
+ permission.ACCESS_NETWORK_STATE,
+ () -> cm.registerNetworkCallback(networkRequest, networkCallback));
+
+ joinFuture.get(JOIN_TIMEOUT_MILLIS, MILLISECONDS);
+ runAsShell(
+ permission.ACCESS_NETWORK_STATE, () -> assertThat(isAttached(controller)).isTrue());
+ assertThat(networkFuture.get(NETWORK_CALLBACK_TIMEOUT_MILLIS, MILLISECONDS)).isNotNull();
+ }
}