Support DNS-over-TLS probes in NetworkDiagnostics
Probe DNS servers to see they support DNS-over-TLS. Use system
CAs to verify whether the certificates sent by DNS servers are
trusted or not. An error is thrown to cause the probe failed if
DNS servers send untrusted certificates.
Unlike the DnsResolver which doesn't verify the certificates
in opportunistic mode, all of the DoT probes from NetworkDiagnostics
check certificates.
DoT probes apply to the DNS servers gotten from LinkProperties
and the DoT servers gotten from PrivateDnsConfig whatever private
DNS mode is.
A common example in DNS strict mode:
. DNS TLS dst{8.8.8.8} hostname{dns.google} src{192.168.43.2:48436} qtype{1} qname{815149-android-ds.metric.gstatic.com}: SUCCEEDED: 1/1 NOERROR (432ms)
F DNS TLS dst{192.168.43.144} hostname{}: FAILED: java.net.ConnectException: failed to connect to /192.168.43.144 (port 853) from /192.168.43.2 (port 41770) after 2500ms: isConnected failed: ECONNREFUSED (Connection refused) (172ms)
. DNS TLS dst{8.8.4.4} hostname{dns.google} src{192.168.43.2:37598} qtype{1} qname{759312-android-ds.metric.gstatic.com}: SUCCEEDED: 1/1 NOERROR (427ms)
An example when the CA is not trusted:
F DNS TLS dst{8.8.8.8} hostname{dns.google}: FAILED: javax.net.ssl.SSLHandshakeException: java.security.cert.CertPathValidatorException: Trust anchor for certification path not found. (16ms)
An example when TCP/TLS handshake timeout:
F DNS TLS dst{8.8.8.8} hostname{dns.google}: FAILED: java.net.SocketTimeoutException: failed to connect to /8.8.8.8 (port 853) from /192.168.2.108 (port 45680) after 2500ms (2514ms)
Bug: 132925257
Bug: 118369977
Test: atest FrameworksNetTests
Change-Id: I1b54abed0e931ca4b8a97149459cde54da1c3d6f
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index 320f3fb..80efd52 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -2479,10 +2479,12 @@
final List<NetworkDiagnostics> netDiags = new ArrayList<NetworkDiagnostics>();
final long DIAG_TIME_MS = 5000;
for (NetworkAgentInfo nai : networksSortedById()) {
+ PrivateDnsConfig privateDnsCfg = mDnsManager.getPrivateDnsConfig(nai.network);
// Start gathering diagnostic information.
netDiags.add(new NetworkDiagnostics(
nai.network,
new LinkProperties(nai.linkProperties), // Must be a copy.
+ privateDnsCfg,
DIAG_TIME_MS));
}
diff --git a/services/core/java/com/android/server/connectivity/DnsManager.java b/services/core/java/com/android/server/connectivity/DnsManager.java
index 506c8e3..cf6a7f6 100644
--- a/services/core/java/com/android/server/connectivity/DnsManager.java
+++ b/services/core/java/com/android/server/connectivity/DnsManager.java
@@ -57,6 +57,7 @@
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
@@ -64,7 +65,9 @@
* Encapsulate the management of DNS settings for networks.
*
* This class it NOT designed for concurrent access. Furthermore, all non-static
- * methods MUST be called from ConnectivityService's thread.
+ * methods MUST be called from ConnectivityService's thread. However, an exceptional
+ * case is getPrivateDnsConfig(Network) which is exclusively for
+ * ConnectivityService#dumpNetworkDiagnostics() on a random binder thread.
*
* [ Private DNS ]
* The code handling Private DNS is spread across several components, but this
@@ -236,8 +239,8 @@
private final ContentResolver mContentResolver;
private final IDnsResolver mDnsResolver;
private final MockableSystemProperties mSystemProperties;
- // TODO: Replace these Maps with SparseArrays.
- private final Map<Integer, PrivateDnsConfig> mPrivateDnsMap;
+ private final ConcurrentHashMap<Integer, PrivateDnsConfig> mPrivateDnsMap;
+ // TODO: Replace the Map with SparseArrays.
private final Map<Integer, PrivateDnsValidationStatuses> mPrivateDnsValidationMap;
private final Map<Integer, LinkProperties> mLinkPropertiesMap;
private final Map<Integer, int[]> mTransportsMap;
@@ -247,15 +250,13 @@
private int mSuccessThreshold;
private int mMinSamples;
private int mMaxSamples;
- private String mPrivateDnsMode;
- private String mPrivateDnsSpecifier;
public DnsManager(Context ctx, IDnsResolver dnsResolver, MockableSystemProperties sp) {
mContext = ctx;
mContentResolver = mContext.getContentResolver();
mDnsResolver = dnsResolver;
mSystemProperties = sp;
- mPrivateDnsMap = new HashMap<>();
+ mPrivateDnsMap = new ConcurrentHashMap<>();
mPrivateDnsValidationMap = new HashMap<>();
mLinkPropertiesMap = new HashMap<>();
mTransportsMap = new HashMap<>();
@@ -275,6 +276,12 @@
mLinkPropertiesMap.remove(network.netId);
}
+ // This is exclusively called by ConnectivityService#dumpNetworkDiagnostics() which
+ // is not on the ConnectivityService handler thread.
+ public PrivateDnsConfig getPrivateDnsConfig(@NonNull Network network) {
+ return mPrivateDnsMap.getOrDefault(network.netId, PRIVATE_DNS_OFF);
+ }
+
public PrivateDnsConfig updatePrivateDns(Network network, PrivateDnsConfig cfg) {
Slog.w(TAG, "updatePrivateDns(" + network + ", " + cfg + ")");
return (cfg != null)
diff --git a/services/core/java/com/android/server/connectivity/NetworkDiagnostics.java b/services/core/java/com/android/server/connectivity/NetworkDiagnostics.java
index a1a8e35..49c16ad 100644
--- a/services/core/java/com/android/server/connectivity/NetworkDiagnostics.java
+++ b/services/core/java/com/android/server/connectivity/NetworkDiagnostics.java
@@ -18,12 +18,15 @@
import static android.system.OsConstants.*;
+import android.annotation.NonNull;
+import android.annotation.Nullable;
import android.net.LinkAddress;
import android.net.LinkProperties;
import android.net.Network;
import android.net.NetworkUtils;
import android.net.RouteInfo;
import android.net.TrafficStats;
+import android.net.shared.PrivateDnsConfig;
import android.net.util.NetworkConstants;
import android.os.SystemClock;
import android.system.ErrnoException;
@@ -38,6 +41,8 @@
import libcore.io.IoUtils;
import java.io.Closeable;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
import java.io.FileDescriptor;
import java.io.IOException;
import java.io.InterruptedIOException;
@@ -52,6 +57,7 @@
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -59,6 +65,12 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
+import javax.net.ssl.SNIHostName;
+import javax.net.ssl.SNIServerName;
+import javax.net.ssl.SSLParameters;
+import javax.net.ssl.SSLSocket;
+import javax.net.ssl.SSLSocketFactory;
+
/**
* NetworkDiagnostics
*
@@ -100,6 +112,7 @@
private final Network mNetwork;
private final LinkProperties mLinkProperties;
+ private final PrivateDnsConfig mPrivateDnsCfg;
private final Integer mInterfaceIndex;
private final long mTimeoutMs;
@@ -163,12 +176,15 @@
private final Map<Pair<InetAddress, InetAddress>, Measurement> mExplicitSourceIcmpChecks =
new HashMap<>();
private final Map<InetAddress, Measurement> mDnsUdpChecks = new HashMap<>();
+ private final Map<InetAddress, Measurement> mDnsTlsChecks = new HashMap<>();
private final String mDescription;
- public NetworkDiagnostics(Network network, LinkProperties lp, long timeoutMs) {
+ public NetworkDiagnostics(Network network, LinkProperties lp,
+ @NonNull PrivateDnsConfig privateDnsCfg, long timeoutMs) {
mNetwork = network;
mLinkProperties = lp;
+ mPrivateDnsCfg = privateDnsCfg;
mInterfaceIndex = getInterfaceIndex(mLinkProperties.getInterfaceName());
mTimeoutMs = timeoutMs;
mStartTime = now();
@@ -199,8 +215,22 @@
}
}
for (InetAddress nameserver : mLinkProperties.getDnsServers()) {
- prepareIcmpMeasurement(nameserver);
- prepareDnsMeasurement(nameserver);
+ prepareIcmpMeasurement(nameserver);
+ prepareDnsMeasurement(nameserver);
+
+ // Unlike the DnsResolver which doesn't do certificate validation in opportunistic mode,
+ // DoT probes to the DNS servers will fail if certificate validation fails.
+ prepareDnsTlsMeasurement(null /* hostname */, nameserver);
+ }
+
+ for (InetAddress tlsNameserver : mPrivateDnsCfg.ips) {
+ // Reachability check is necessary since when resolving the strict mode hostname,
+ // NetworkMonitor always queries for both A and AAAA records, even if the network
+ // is IPv4-only or IPv6-only.
+ if (mLinkProperties.isReachable(tlsNameserver)) {
+ // If there are IPs, there must have been a name that resolved to them.
+ prepareDnsTlsMeasurement(mPrivateDnsCfg.hostname, tlsNameserver);
+ }
}
mCountDownLatch = new CountDownLatch(totalMeasurementCount());
@@ -222,6 +252,15 @@
}
}
+ private static String socketAddressToString(@NonNull SocketAddress sockAddr) {
+ // The default toString() implementation is not the prettiest.
+ InetSocketAddress inetSockAddr = (InetSocketAddress) sockAddr;
+ InetAddress localAddr = inetSockAddr.getAddress();
+ return String.format(
+ (localAddr instanceof Inet6Address ? "[%s]:%d" : "%s:%d"),
+ localAddr.getHostAddress(), inetSockAddr.getPort());
+ }
+
private void prepareIcmpMeasurement(InetAddress target) {
if (!mIcmpChecks.containsKey(target)) {
Measurement measurement = new Measurement();
@@ -252,8 +291,19 @@
}
}
+ private void prepareDnsTlsMeasurement(@Nullable String hostname, @NonNull InetAddress target) {
+ // This might overwrite an existing entry in mDnsTlsChecks, because |target| can be an IP
+ // address configured by the network as well as an IP address learned by resolving the
+ // strict mode DNS hostname. If the entry is overwritten, the overwritten measurement
+ // thread will not execute.
+ Measurement measurement = new Measurement();
+ measurement.thread = new Thread(new DnsTlsCheck(hostname, target, measurement));
+ mDnsTlsChecks.put(target, measurement);
+ }
+
private int totalMeasurementCount() {
- return mIcmpChecks.size() + mExplicitSourceIcmpChecks.size() + mDnsUdpChecks.size();
+ return mIcmpChecks.size() + mExplicitSourceIcmpChecks.size() + mDnsUdpChecks.size()
+ + mDnsTlsChecks.size();
}
private void startMeasurements() {
@@ -266,6 +316,9 @@
for (Measurement measurement : mDnsUdpChecks.values()) {
measurement.thread.start();
}
+ for (Measurement measurement : mDnsTlsChecks.values()) {
+ measurement.thread.start();
+ }
}
public void waitForMeasurements() {
@@ -297,6 +350,11 @@
measurements.add(entry.getValue());
}
}
+ for (Map.Entry<InetAddress, Measurement> entry : mDnsTlsChecks.entrySet()) {
+ if (entry.getKey() instanceof Inet4Address) {
+ measurements.add(entry.getValue());
+ }
+ }
// IPv6 measurements second.
for (Map.Entry<InetAddress, Measurement> entry : mIcmpChecks.entrySet()) {
@@ -315,6 +373,11 @@
measurements.add(entry.getValue());
}
}
+ for (Map.Entry<InetAddress, Measurement> entry : mDnsTlsChecks.entrySet()) {
+ if (entry.getKey() instanceof Inet6Address) {
+ measurements.add(entry.getValue());
+ }
+ }
return measurements;
}
@@ -387,6 +450,8 @@
try {
mFileDescriptor = Os.socket(mAddressFamily, sockType, protocol);
} finally {
+ // TODO: The tag should remain set until all traffic is sent and received.
+ // Consider tagging the socket after the measurement thread is started.
TrafficStats.setThreadStatsTag(oldTag);
}
// Setting SNDTIMEO is purely for defensive purposes.
@@ -403,13 +468,12 @@
mSocketAddress = Os.getsockname(mFileDescriptor);
}
- protected String getSocketAddressString() {
- // The default toString() implementation is not the prettiest.
- InetSocketAddress inetSockAddr = (InetSocketAddress) mSocketAddress;
- InetAddress localAddr = inetSockAddr.getAddress();
- return String.format(
- (localAddr instanceof Inet6Address ? "[%s]:%d" : "%s:%d"),
- localAddr.getHostAddress(), inetSockAddr.getPort());
+ protected boolean ensureMeasurementNecessary() {
+ if (mMeasurement.finishTime == 0) return false;
+
+ // Countdown latch was not decremented when the measurement failed during setup.
+ mCountDownLatch.countDown();
+ return true;
}
@Override
@@ -448,13 +512,7 @@
@Override
public void run() {
- // Check if this measurement has already failed during setup.
- if (mMeasurement.finishTime > 0) {
- // If the measurement failed during construction it didn't
- // decrement the countdown latch; do so here.
- mCountDownLatch.countDown();
- return;
- }
+ if (ensureMeasurementNecessary()) return;
try {
setupSocket(SOCK_DGRAM, mProtocol, TIMEOUT_SEND, TIMEOUT_RECV, 0);
@@ -462,7 +520,7 @@
mMeasurement.recordFailure(e.toString());
return;
}
- mMeasurement.description += " src{" + getSocketAddressString() + "}";
+ mMeasurement.description += " src{" + socketAddressToString(mSocketAddress) + "}";
// Build a trivial ICMP packet.
final byte[] icmpPacket = {
@@ -507,10 +565,10 @@
private static final int RR_TYPE_AAAA = 28;
private static final int PACKET_BUFSIZE = 512;
- private final Random mRandom = new Random();
+ protected final Random mRandom = new Random();
// Should be static, but the compiler mocks our puny, human attempts at reason.
- private String responseCodeStr(int rcode) {
+ protected String responseCodeStr(int rcode) {
try {
return DnsResponseCode.values()[rcode].toString();
} catch (IndexOutOfBoundsException e) {
@@ -518,7 +576,7 @@
}
}
- private final int mQueryType;
+ protected final int mQueryType;
public DnsUdpCheck(InetAddress target, Measurement measurement) {
super(target, measurement);
@@ -535,13 +593,7 @@
@Override
public void run() {
- // Check if this measurement has already failed during setup.
- if (mMeasurement.finishTime > 0) {
- // If the measurement failed during construction it didn't
- // decrement the countdown latch; do so here.
- mCountDownLatch.countDown();
- return;
- }
+ if (ensureMeasurementNecessary()) return;
try {
setupSocket(SOCK_DGRAM, IPPROTO_UDP, TIMEOUT_SEND, TIMEOUT_RECV,
@@ -550,12 +602,10 @@
mMeasurement.recordFailure(e.toString());
return;
}
- mMeasurement.description += " src{" + getSocketAddressString() + "}";
// This needs to be fixed length so it can be dropped into the pre-canned packet.
final String sixRandomDigits = String.valueOf(mRandom.nextInt(900000) + 100000);
- mMeasurement.description += " qtype{" + mQueryType + "}"
- + " qname{" + sixRandomDigits + "-android-ds.metric.gstatic.com}";
+ appendDnsToMeasurementDescription(sixRandomDigits, mSocketAddress);
// Build a trivial DNS packet.
final byte[] dnsPacket = getDnsQueryPacket(sixRandomDigits);
@@ -592,7 +642,7 @@
close();
}
- private byte[] getDnsQueryPacket(String sixRandomDigits) {
+ protected byte[] getDnsQueryPacket(String sixRandomDigits) {
byte[] rnd = sixRandomDigits.getBytes(StandardCharsets.US_ASCII);
return new byte[] {
(byte) mRandom.nextInt(), (byte) mRandom.nextInt(), // [0-1] query ID
@@ -611,5 +661,97 @@
0, 1 // QCLASS, set to 1 = IN (Internet)
};
}
+
+ protected void appendDnsToMeasurementDescription(
+ String sixRandomDigits, SocketAddress sockAddr) {
+ mMeasurement.description += " src{" + socketAddressToString(sockAddr) + "}"
+ + " qtype{" + mQueryType + "}"
+ + " qname{" + sixRandomDigits + "-android-ds.metric.gstatic.com}";
+ }
+ }
+
+ // TODO: Have it inherited from SimpleSocketCheck, and separate common DNS helpers out of
+ // DnsUdpCheck.
+ private class DnsTlsCheck extends DnsUdpCheck {
+ private static final int TCP_CONNECT_TIMEOUT_MS = 2500;
+ private static final int TCP_TIMEOUT_MS = 2000;
+ private static final int DNS_TLS_PORT = 853;
+ private static final int DNS_HEADER_SIZE = 12;
+
+ private final String mHostname;
+
+ public DnsTlsCheck(@Nullable String hostname, @NonNull InetAddress target,
+ @NonNull Measurement measurement) {
+ super(target, measurement);
+
+ mHostname = hostname;
+ mMeasurement.description = "DNS TLS dst{" + mTarget.getHostAddress() + "} hostname{"
+ + TextUtils.emptyIfNull(mHostname) + "}";
+ }
+
+ private SSLSocket setupSSLSocket() throws IOException {
+ // A TrustManager will be created and initialized with a KeyStore containing system
+ // CaCerts. During SSL handshake, it will be used to validate the certificates from
+ // the server.
+ SSLSocket sslSocket = (SSLSocket) SSLSocketFactory.getDefault().createSocket();
+ sslSocket.setSoTimeout(TCP_TIMEOUT_MS);
+
+ if (!TextUtils.isEmpty(mHostname)) {
+ // Set SNI.
+ final List<SNIServerName> names =
+ Collections.singletonList(new SNIHostName(mHostname));
+ SSLParameters params = sslSocket.getSSLParameters();
+ params.setServerNames(names);
+ sslSocket.setSSLParameters(params);
+ }
+
+ mNetwork.bindSocket(sslSocket);
+ return sslSocket;
+ }
+
+ private void sendDoTProbe(@Nullable SSLSocket sslSocket) throws IOException {
+ final String sixRandomDigits = String.valueOf(mRandom.nextInt(900000) + 100000);
+ final byte[] dnsPacket = getDnsQueryPacket(sixRandomDigits);
+
+ mMeasurement.startTime = now();
+ sslSocket.connect(new InetSocketAddress(mTarget, DNS_TLS_PORT), TCP_CONNECT_TIMEOUT_MS);
+
+ // Synchronous call waiting for the TLS handshake complete.
+ sslSocket.startHandshake();
+ appendDnsToMeasurementDescription(sixRandomDigits, sslSocket.getLocalSocketAddress());
+
+ final DataOutputStream output = new DataOutputStream(sslSocket.getOutputStream());
+ output.writeShort(dnsPacket.length);
+ output.write(dnsPacket, 0, dnsPacket.length);
+
+ final DataInputStream input = new DataInputStream(sslSocket.getInputStream());
+ final int replyLength = Short.toUnsignedInt(input.readShort());
+ final byte[] reply = new byte[replyLength];
+ int bytesRead = 0;
+ while (bytesRead < replyLength) {
+ bytesRead += input.read(reply, bytesRead, replyLength - bytesRead);
+ }
+
+ if (bytesRead > DNS_HEADER_SIZE && bytesRead == replyLength) {
+ mMeasurement.recordSuccess("1/1 " + responseCodeStr((int) (reply[3]) & 0x0f));
+ } else {
+ mMeasurement.recordFailure("1/1 Read " + bytesRead + " bytes while expected to be "
+ + replyLength + " bytes");
+ }
+ }
+
+ @Override
+ public void run() {
+ if (ensureMeasurementNecessary()) return;
+
+ // No need to restore the tag, since this thread is only used for this measurement.
+ TrafficStats.getAndSetThreadStatsTag(TrafficStatsConstants.TAG_SYSTEM_PROBE);
+
+ try (SSLSocket sslSocket = setupSSLSocket()) {
+ sendDoTProbe(sslSocket);
+ } catch (IOException e) {
+ mMeasurement.recordFailure(e.toString());
+ }
+ }
}
}
diff --git a/tests/net/java/com/android/server/connectivity/DnsManagerTest.java b/tests/net/java/com/android/server/connectivity/DnsManagerTest.java
index 0a603b8..26a28da 100644
--- a/tests/net/java/com/android/server/connectivity/DnsManagerTest.java
+++ b/tests/net/java/com/android/server/connectivity/DnsManagerTest.java
@@ -62,6 +62,8 @@
import com.android.internal.util.MessageUtils;
import com.android.internal.util.test.FakeSettingsProvider;
+import libcore.net.InetAddressUtils;
+
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -379,4 +381,49 @@
assertEquals(name, dnsTransTypes.get(i));
}
}
+
+ @Test
+ public void testGetPrivateDnsConfigForNetwork() throws Exception {
+ final Network network = new Network(TEST_NETID);
+ final InetAddress dnsAddr = InetAddressUtils.parseNumericAddress("3.3.3.3");
+ final InetAddress[] tlsAddrs = new InetAddress[]{
+ InetAddressUtils.parseNumericAddress("6.6.6.6"),
+ InetAddressUtils.parseNumericAddress("2001:db8:66:66::1")
+ };
+ final String tlsName = "strictmode.com";
+ LinkProperties lp = new LinkProperties();
+ lp.addDnsServer(dnsAddr);
+
+ // The PrivateDnsConfig map is empty, so the default PRIVATE_DNS_OFF is returned.
+ PrivateDnsConfig privateDnsCfg = mDnsManager.getPrivateDnsConfig(network);
+ assertFalse(privateDnsCfg.useTls);
+ assertEquals("", privateDnsCfg.hostname);
+ assertEquals(new InetAddress[0], privateDnsCfg.ips);
+
+ // An entry with default PrivateDnsConfig is added to the PrivateDnsConfig map.
+ mDnsManager.updatePrivateDns(network, mDnsManager.getPrivateDnsConfig());
+ mDnsManager.noteDnsServersForNetwork(TEST_NETID, lp);
+ mDnsManager.updatePrivateDnsValidation(
+ new DnsManager.PrivateDnsValidationUpdate(TEST_NETID, dnsAddr, "", true));
+ mDnsManager.updatePrivateDnsStatus(TEST_NETID, lp);
+ privateDnsCfg = mDnsManager.getPrivateDnsConfig(network);
+ assertTrue(privateDnsCfg.useTls);
+ assertEquals("", privateDnsCfg.hostname);
+ assertEquals(new InetAddress[0], privateDnsCfg.ips);
+
+ // The original entry is overwritten by a new PrivateDnsConfig.
+ mDnsManager.updatePrivateDns(network, new PrivateDnsConfig(tlsName, tlsAddrs));
+ mDnsManager.updatePrivateDnsStatus(TEST_NETID, lp);
+ privateDnsCfg = mDnsManager.getPrivateDnsConfig(network);
+ assertTrue(privateDnsCfg.useTls);
+ assertEquals(tlsName, privateDnsCfg.hostname);
+ assertEquals(tlsAddrs, privateDnsCfg.ips);
+
+ // The network is removed, so the PrivateDnsConfig map becomes empty again.
+ mDnsManager.removeNetwork(network);
+ privateDnsCfg = mDnsManager.getPrivateDnsConfig(network);
+ assertFalse(privateDnsCfg.useTls);
+ assertEquals("", privateDnsCfg.hostname);
+ assertEquals(new InetAddress[0], privateDnsCfg.ips);
+ }
}