Merge "Update DnsManagerTest for AIDL interface change" into rvc-dev
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index d9f01da..01b2e69 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -2488,10 +2488,12 @@
         final List<NetworkDiagnostics> netDiags = new ArrayList<NetworkDiagnostics>();
         final long DIAG_TIME_MS = 5000;
         for (NetworkAgentInfo nai : networksSortedById()) {
+            PrivateDnsConfig privateDnsCfg = mDnsManager.getPrivateDnsConfig(nai.network);
             // Start gathering diagnostic information.
             netDiags.add(new NetworkDiagnostics(
                     nai.network,
                     new LinkProperties(nai.linkProperties),  // Must be a copy.
+                    privateDnsCfg,
                     DIAG_TIME_MS));
         }
 
diff --git a/services/core/java/com/android/server/connectivity/DnsManager.java b/services/core/java/com/android/server/connectivity/DnsManager.java
index 506c8e3..cf6a7f6 100644
--- a/services/core/java/com/android/server/connectivity/DnsManager.java
+++ b/services/core/java/com/android/server/connectivity/DnsManager.java
@@ -57,6 +57,7 @@
 import java.util.Iterator;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Collectors;
 
 
@@ -64,7 +65,9 @@
  * Encapsulate the management of DNS settings for networks.
  *
  * This class it NOT designed for concurrent access. Furthermore, all non-static
- * methods MUST be called from ConnectivityService's thread.
+ * methods MUST be called from ConnectivityService's thread. The one exception is
+ * getPrivateDnsConfig(Network), which exists exclusively so that
+ * ConnectivityService#dumpNetworkDiagnostics() can call it from a binder thread.
  *
  * [ Private DNS ]
  * The code handling Private DNS is spread across several components, but this
@@ -236,8 +239,8 @@
     private final ContentResolver mContentResolver;
     private final IDnsResolver mDnsResolver;
     private final MockableSystemProperties mSystemProperties;
-    // TODO: Replace these Maps with SparseArrays.
-    private final Map<Integer, PrivateDnsConfig> mPrivateDnsMap;
+    private final ConcurrentHashMap<Integer, PrivateDnsConfig> mPrivateDnsMap;
+    // TODO: Replace these Maps with SparseArrays.
     private final Map<Integer, PrivateDnsValidationStatuses> mPrivateDnsValidationMap;
     private final Map<Integer, LinkProperties> mLinkPropertiesMap;
     private final Map<Integer, int[]> mTransportsMap;
@@ -247,15 +250,13 @@
     private int mSuccessThreshold;
     private int mMinSamples;
     private int mMaxSamples;
-    private String mPrivateDnsMode;
-    private String mPrivateDnsSpecifier;
 
     public DnsManager(Context ctx, IDnsResolver dnsResolver, MockableSystemProperties sp) {
         mContext = ctx;
         mContentResolver = mContext.getContentResolver();
         mDnsResolver = dnsResolver;
         mSystemProperties = sp;
-        mPrivateDnsMap = new HashMap<>();
+        mPrivateDnsMap = new ConcurrentHashMap<>();
         mPrivateDnsValidationMap = new HashMap<>();
         mLinkPropertiesMap = new HashMap<>();
         mTransportsMap = new HashMap<>();
@@ -275,6 +276,12 @@
         mLinkPropertiesMap.remove(network.netId);
     }
 
+    // Only ConnectivityService#dumpNetworkDiagnostics() calls this, off the handler thread.
+    public PrivateDnsConfig getPrivateDnsConfig(@NonNull Network network) {
+        final PrivateDnsConfig cfg = mPrivateDnsMap.get(network.netId);
+        return (cfg != null) ? cfg : PRIVATE_DNS_OFF;
+    }
+
     public PrivateDnsConfig updatePrivateDns(Network network, PrivateDnsConfig cfg) {
         Slog.w(TAG, "updatePrivateDns(" + network + ", " + cfg + ")");
         return (cfg != null)
diff --git a/services/core/java/com/android/server/connectivity/NetworkDiagnostics.java b/services/core/java/com/android/server/connectivity/NetworkDiagnostics.java
index a1a8e35..49c16ad 100644
--- a/services/core/java/com/android/server/connectivity/NetworkDiagnostics.java
+++ b/services/core/java/com/android/server/connectivity/NetworkDiagnostics.java
@@ -18,12 +18,15 @@
 
 import static android.system.OsConstants.*;
 
+import android.annotation.NonNull;
+import android.annotation.Nullable;
 import android.net.LinkAddress;
 import android.net.LinkProperties;
 import android.net.Network;
 import android.net.NetworkUtils;
 import android.net.RouteInfo;
 import android.net.TrafficStats;
+import android.net.shared.PrivateDnsConfig;
 import android.net.util.NetworkConstants;
 import android.os.SystemClock;
 import android.system.ErrnoException;
@@ -38,6 +41,8 @@
 import libcore.io.IoUtils;
 
 import java.io.Closeable;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
 import java.io.FileDescriptor;
 import java.io.IOException;
 import java.io.InterruptedIOException;
@@ -52,6 +57,7 @@
 import java.nio.ByteBuffer;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -59,6 +65,12 @@
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 
+import javax.net.ssl.SNIHostName;
+import javax.net.ssl.SNIServerName;
+import javax.net.ssl.SSLParameters;
+import javax.net.ssl.SSLSocket;
+import javax.net.ssl.SSLSocketFactory;
+
 /**
  * NetworkDiagnostics
  *
@@ -100,6 +112,7 @@
 
     private final Network mNetwork;
     private final LinkProperties mLinkProperties;
+    private final PrivateDnsConfig mPrivateDnsCfg;
     private final Integer mInterfaceIndex;
 
     private final long mTimeoutMs;
@@ -163,12 +176,15 @@
     private final Map<Pair<InetAddress, InetAddress>, Measurement> mExplicitSourceIcmpChecks =
             new HashMap<>();
     private final Map<InetAddress, Measurement> mDnsUdpChecks = new HashMap<>();
+    private final Map<InetAddress, Measurement> mDnsTlsChecks = new HashMap<>();
     private final String mDescription;
 
 
-    public NetworkDiagnostics(Network network, LinkProperties lp, long timeoutMs) {
+    public NetworkDiagnostics(Network network, LinkProperties lp,
+            @NonNull PrivateDnsConfig privateDnsCfg, long timeoutMs) {
         mNetwork = network;
         mLinkProperties = lp;
+        mPrivateDnsCfg = privateDnsCfg;
         mInterfaceIndex = getInterfaceIndex(mLinkProperties.getInterfaceName());
         mTimeoutMs = timeoutMs;
         mStartTime = now();
@@ -199,8 +215,22 @@
             }
         }
         for (InetAddress nameserver : mLinkProperties.getDnsServers()) {
-                prepareIcmpMeasurement(nameserver);
-                prepareDnsMeasurement(nameserver);
+            prepareIcmpMeasurement(nameserver);
+            prepareDnsMeasurement(nameserver);
+
+            // Unlike the DnsResolver which doesn't do certificate validation in opportunistic mode,
+            // DoT probes to the DNS servers will fail if certificate validation fails.
+            prepareDnsTlsMeasurement(null /* hostname */, nameserver);
+        }
+
+        for (InetAddress tlsNameserver : mPrivateDnsCfg.ips) {
+            // Reachability check is necessary since when resolving the strict mode hostname,
+            // NetworkMonitor always queries for both A and AAAA records, even if the network
+            // is IPv4-only or IPv6-only.
+            if (mLinkProperties.isReachable(tlsNameserver)) {
+                // If there are IPs, there must have been a name that resolved to them.
+                prepareDnsTlsMeasurement(mPrivateDnsCfg.hostname, tlsNameserver);
+            }
         }
 
         mCountDownLatch = new CountDownLatch(totalMeasurementCount());
@@ -222,6 +252,15 @@
         }
     }
 
+    private static String socketAddressToString(@NonNull SocketAddress sockAddr) {
+        // Formats as "addr:port", bracketing IPv6 literals, because the default toString()
+        // output is harder to read.
+        final InetSocketAddress inetAddr = (InetSocketAddress) sockAddr;
+        final InetAddress ip = inetAddr.getAddress();
+        final String pattern = (ip instanceof Inet6Address) ? "[%s]:%d" : "%s:%d";
+        return String.format(pattern, ip.getHostAddress(), inetAddr.getPort());
+    }
+
     private void prepareIcmpMeasurement(InetAddress target) {
         if (!mIcmpChecks.containsKey(target)) {
             Measurement measurement = new Measurement();
@@ -252,8 +291,19 @@
         }
     }
 
+    private void prepareDnsTlsMeasurement(@Nullable String hostname, @NonNull InetAddress target) {
+        // |target| may be both a network-configured DNS server and an address obtained by
+        // resolving the strict mode hostname, so put() can replace an earlier entry for it.
+        // A replaced measurement is dropped from the map and its thread is never started.
+        final Measurement check = new Measurement();
+        final DnsTlsCheck probe = new DnsTlsCheck(hostname, target, check);
+        check.thread = new Thread(probe);
+        mDnsTlsChecks.put(target, check);
+    }
+
     private int totalMeasurementCount() {
-        return mIcmpChecks.size() + mExplicitSourceIcmpChecks.size() + mDnsUdpChecks.size();
+        return mIcmpChecks.size() + mExplicitSourceIcmpChecks.size() + mDnsUdpChecks.size()
+                + mDnsTlsChecks.size();
     }
 
     private void startMeasurements() {
@@ -266,6 +316,9 @@
         for (Measurement measurement : mDnsUdpChecks.values()) {
             measurement.thread.start();
         }
+        for (Measurement measurement : mDnsTlsChecks.values()) {
+            measurement.thread.start();
+        }
     }
 
     public void waitForMeasurements() {
@@ -297,6 +350,11 @@
                 measurements.add(entry.getValue());
             }
         }
+        for (Map.Entry<InetAddress, Measurement> entry : mDnsTlsChecks.entrySet()) {
+            if (entry.getKey() instanceof Inet4Address) {
+                measurements.add(entry.getValue());
+            }
+        }
 
         // IPv6 measurements second.
         for (Map.Entry<InetAddress, Measurement> entry : mIcmpChecks.entrySet()) {
@@ -315,6 +373,11 @@
                 measurements.add(entry.getValue());
             }
         }
+        for (Map.Entry<InetAddress, Measurement> entry : mDnsTlsChecks.entrySet()) {
+            if (entry.getKey() instanceof Inet6Address) {
+                measurements.add(entry.getValue());
+            }
+        }
 
         return measurements;
     }
@@ -387,6 +450,8 @@
             try {
                 mFileDescriptor = Os.socket(mAddressFamily, sockType, protocol);
             } finally {
+                // TODO: The tag should remain set until all traffic is sent and received.
+                // Consider tagging the socket after the measurement thread is started.
                 TrafficStats.setThreadStatsTag(oldTag);
             }
             // Setting SNDTIMEO is purely for defensive purposes.
@@ -403,13 +468,12 @@
             mSocketAddress = Os.getsockname(mFileDescriptor);
         }
 
-        protected String getSocketAddressString() {
-            // The default toString() implementation is not the prettiest.
-            InetSocketAddress inetSockAddr = (InetSocketAddress) mSocketAddress;
-            InetAddress localAddr = inetSockAddr.getAddress();
-            return String.format(
-                    (localAddr instanceof Inet6Address ? "[%s]:%d" : "%s:%d"),
-                    localAddr.getHostAddress(), inetSockAddr.getPort());
+        protected boolean ensureMeasurementNecessary() {
+            if (mMeasurement.finishTime == 0) return false;
+            // NOTE: despite the name, true tells the caller to skip; finishTime is already set.
+            // Countdown latch was not decremented when the measurement failed during setup.
+            mCountDownLatch.countDown();
+            return true;
         }
 
         @Override
@@ -448,13 +512,7 @@
 
         @Override
         public void run() {
-            // Check if this measurement has already failed during setup.
-            if (mMeasurement.finishTime > 0) {
-                // If the measurement failed during construction it didn't
-                // decrement the countdown latch; do so here.
-                mCountDownLatch.countDown();
-                return;
-            }
+            if (ensureMeasurementNecessary()) return;
 
             try {
                 setupSocket(SOCK_DGRAM, mProtocol, TIMEOUT_SEND, TIMEOUT_RECV, 0);
@@ -462,7 +520,7 @@
                 mMeasurement.recordFailure(e.toString());
                 return;
             }
-            mMeasurement.description += " src{" + getSocketAddressString() + "}";
+            mMeasurement.description += " src{" + socketAddressToString(mSocketAddress) + "}";
 
             // Build a trivial ICMP packet.
             final byte[] icmpPacket = {
@@ -507,10 +565,10 @@
         private static final int RR_TYPE_AAAA = 28;
         private static final int PACKET_BUFSIZE = 512;
 
-        private final Random mRandom = new Random();
+        protected final Random mRandom = new Random();
 
         // Should be static, but the compiler mocks our puny, human attempts at reason.
-        private String responseCodeStr(int rcode) {
+        protected String responseCodeStr(int rcode) {
             try {
                 return DnsResponseCode.values()[rcode].toString();
             } catch (IndexOutOfBoundsException e) {
@@ -518,7 +576,7 @@
             }
         }
 
-        private final int mQueryType;
+        protected final int mQueryType;
 
         public DnsUdpCheck(InetAddress target, Measurement measurement) {
             super(target, measurement);
@@ -535,13 +593,7 @@
 
         @Override
         public void run() {
-            // Check if this measurement has already failed during setup.
-            if (mMeasurement.finishTime > 0) {
-                // If the measurement failed during construction it didn't
-                // decrement the countdown latch; do so here.
-                mCountDownLatch.countDown();
-                return;
-            }
+            if (ensureMeasurementNecessary()) return;
 
             try {
                 setupSocket(SOCK_DGRAM, IPPROTO_UDP, TIMEOUT_SEND, TIMEOUT_RECV,
@@ -550,12 +602,10 @@
                 mMeasurement.recordFailure(e.toString());
                 return;
             }
-            mMeasurement.description += " src{" + getSocketAddressString() + "}";
 
             // This needs to be fixed length so it can be dropped into the pre-canned packet.
             final String sixRandomDigits = String.valueOf(mRandom.nextInt(900000) + 100000);
-            mMeasurement.description += " qtype{" + mQueryType + "}"
-                    + " qname{" + sixRandomDigits + "-android-ds.metric.gstatic.com}";
+            appendDnsToMeasurementDescription(sixRandomDigits, mSocketAddress);
 
             // Build a trivial DNS packet.
             final byte[] dnsPacket = getDnsQueryPacket(sixRandomDigits);
@@ -592,7 +642,7 @@
             close();
         }
 
-        private byte[] getDnsQueryPacket(String sixRandomDigits) {
+        protected byte[] getDnsQueryPacket(String sixRandomDigits) {
             byte[] rnd = sixRandomDigits.getBytes(StandardCharsets.US_ASCII);
             return new byte[] {
                 (byte) mRandom.nextInt(), (byte) mRandom.nextInt(),  // [0-1]   query ID
@@ -611,5 +661,97 @@
                 0, 1  // QCLASS, set to 1 = IN (Internet)
             };
         }
+
+        protected void appendDnsToMeasurementDescription(
+                String sixRandomDigits, SocketAddress sockAddr) {
+            final String src = socketAddressToString(sockAddr);
+            mMeasurement.description += " src{" + src + "}" + " qtype{" + mQueryType + "}"
+                    + " qname{" + sixRandomDigits + "-android-ds.metric.gstatic.com}";
+        }
+    }
+
+    // Probes DNS-over-TLS: connects to port 853, handshakes, and sends a single DNS query.
+    // TODO: Inherit from SimpleSocketCheck; separate common DNS helpers out of DnsUdpCheck.
+    private class DnsTlsCheck extends DnsUdpCheck {
+        private static final int TCP_CONNECT_TIMEOUT_MS = 2500;
+        private static final int TCP_TIMEOUT_MS = 2000;
+        private static final int DNS_TLS_PORT = 853;
+        private static final int DNS_HEADER_SIZE = 12;
+
+        private final String mHostname;
+
+        public DnsTlsCheck(@Nullable String hostname, @NonNull InetAddress target,
+                @NonNull Measurement measurement) {
+            super(target, measurement);
+
+            mHostname = hostname;
+            mMeasurement.description = "DNS TLS dst{" + mTarget.getHostAddress() + "} hostname{"
+                    + TextUtils.emptyIfNull(mHostname) + "}";
+        }
+
+        private SSLSocket setupSSLSocket() throws IOException {
+            // A TrustManager will be created and initialized with a KeyStore containing system
+            // CaCerts. During SSL handshake, it will be used to validate the certificates from
+            // the server.
+            SSLSocket sslSocket = (SSLSocket) SSLSocketFactory.getDefault().createSocket();
+            sslSocket.setSoTimeout(TCP_TIMEOUT_MS);
+
+            if (!TextUtils.isEmpty(mHostname)) {
+                // Set SNI.
+                final List<SNIServerName> names =
+                        Collections.singletonList(new SNIHostName(mHostname));
+                SSLParameters params = sslSocket.getSSLParameters();
+                params.setServerNames(names);
+                sslSocket.setSSLParameters(params);
+            }
+
+            mNetwork.bindSocket(sslSocket);
+            return sslSocket;
+        }
+
+        private void sendDoTProbe(@NonNull SSLSocket sslSocket) throws IOException {
+            final String sixRandomDigits = String.valueOf(mRandom.nextInt(900000) + 100000);
+            final byte[] dnsPacket = getDnsQueryPacket(sixRandomDigits);
+            mMeasurement.startTime = now();
+            sslSocket.connect(new InetSocketAddress(mTarget, DNS_TLS_PORT), TCP_CONNECT_TIMEOUT_MS);
+            // Synchronous call waiting for the TLS handshake to complete.
+            sslSocket.startHandshake();
+            appendDnsToMeasurementDescription(sixRandomDigits, sslSocket.getLocalSocketAddress());
+
+            // Both the query and the reply are preceded by their length as a 16-bit value.
+            final DataOutputStream output = new DataOutputStream(sslSocket.getOutputStream());
+            output.writeShort(dnsPacket.length);
+            output.write(dnsPacket, 0, dnsPacket.length);
+            final DataInputStream input = new DataInputStream(sslSocket.getInputStream());
+            final int replyLength = Short.toUnsignedInt(input.readShort());
+            final byte[] reply = new byte[replyLength];
+            int bytesRead = 0;
+            while (bytesRead < replyLength) {
+                final int n = input.read(reply, bytesRead, replyLength - bytesRead);
+                if (n < 0) break;  // EOF: stop here so the short read is reported below.
+                bytesRead += n;
+            }
+
+            if (bytesRead > DNS_HEADER_SIZE && bytesRead == replyLength) {
+                mMeasurement.recordSuccess("1/1 " + responseCodeStr((int) (reply[3]) & 0x0f));
+            } else {
+                mMeasurement.recordFailure("1/1 Read " + bytesRead + " bytes while expected to be "
+                        + replyLength + " bytes");
+            }
+        }
+
+        @Override
+        public void run() {
+            if (ensureMeasurementNecessary()) return;
+
+            // No need to restore the tag, since this thread is only used for this measurement.
+            TrafficStats.getAndSetThreadStatsTag(TrafficStatsConstants.TAG_SYSTEM_PROBE);
+
+            try (SSLSocket sslSocket = setupSSLSocket()) {
+                sendDoTProbe(sslSocket);
+            } catch (IOException e) {
+                mMeasurement.recordFailure(e.toString());
+            }
+        }
     }
 }
diff --git a/tests/net/java/com/android/server/connectivity/DnsManagerTest.java b/tests/net/java/com/android/server/connectivity/DnsManagerTest.java
index 7cdffc6..508b5cd 100644
--- a/tests/net/java/com/android/server/connectivity/DnsManagerTest.java
+++ b/tests/net/java/com/android/server/connectivity/DnsManagerTest.java
@@ -62,6 +62,8 @@
 import com.android.internal.util.MessageUtils;
 import com.android.internal.util.test.FakeSettingsProvider;
 
+import libcore.net.InetAddressUtils;
+
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -380,4 +382,49 @@
             assertEquals(name, dnsTransTypes.get(i));
         }
     }
+
+    @Test
+    public void testGetPrivateDnsConfigForNetwork() throws Exception {
+        final Network network = new Network(TEST_NETID);
+        final InetAddress dnsAddr = InetAddressUtils.parseNumericAddress("3.3.3.3");
+        final InetAddress[] tlsAddrs = new InetAddress[]{
+            InetAddressUtils.parseNumericAddress("6.6.6.6"),
+            InetAddressUtils.parseNumericAddress("2001:db8:66:66::1")
+        };
+        final String tlsName = "strictmode.com";
+        LinkProperties lp = new LinkProperties();
+        lp.addDnsServer(dnsAddr);
+
+        // No entry exists for this network yet, so the PRIVATE_DNS_OFF default is returned.
+        PrivateDnsConfig privateDnsCfg = mDnsManager.getPrivateDnsConfig(network);
+        assertFalse(privateDnsCfg.useTls);
+        assertEquals("", privateDnsCfg.hostname);
+        assertEquals(new InetAddress[0], privateDnsCfg.ips);
+
+        // Store the no-arg getPrivateDnsConfig() result: TLS on, no hostname (asserted below).
+        mDnsManager.updatePrivateDns(network, mDnsManager.getPrivateDnsConfig());
+        mDnsManager.noteDnsServersForNetwork(TEST_NETID, lp);
+        mDnsManager.updatePrivateDnsValidation(
+                new DnsManager.PrivateDnsValidationUpdate(TEST_NETID, dnsAddr, "", true));
+        mDnsManager.updatePrivateDnsStatus(TEST_NETID, lp);
+        privateDnsCfg = mDnsManager.getPrivateDnsConfig(network);
+        assertTrue(privateDnsCfg.useTls);
+        assertEquals("", privateDnsCfg.hostname);
+        assertEquals(new InetAddress[0], privateDnsCfg.ips);
+
+        // A config carrying a hostname and resolved IPs replaces the previous entry.
+        mDnsManager.updatePrivateDns(network, new PrivateDnsConfig(tlsName, tlsAddrs));
+        mDnsManager.updatePrivateDnsStatus(TEST_NETID, lp);
+        privateDnsCfg = mDnsManager.getPrivateDnsConfig(network);
+        assertTrue(privateDnsCfg.useTls);
+        assertEquals(tlsName, privateDnsCfg.hostname);
+        assertEquals(tlsAddrs, privateDnsCfg.ips);
+
+        // Removing the network drops its entry, so PRIVATE_DNS_OFF is returned once more.
+        mDnsManager.removeNetwork(network);
+        privateDnsCfg = mDnsManager.getPrivateDnsConfig(network);
+        assertFalse(privateDnsCfg.useTls);
+        assertEquals("", privateDnsCfg.hostname);
+        assertEquals(new InetAddress[0], privateDnsCfg.ips);
+    }
 }