Merge "Improve testOpenConnection() reliability with fallback IPv6 check" into main
diff --git a/staticlibs/framework/com/android/net/module/util/DnsPacket.java b/staticlibs/framework/com/android/net/module/util/DnsPacket.java
index 63106a1..b0c5e2e 100644
--- a/staticlibs/framework/com/android/net/module/util/DnsPacket.java
+++ b/staticlibs/framework/com/android/net/module/util/DnsPacket.java
@@ -50,7 +50,7 @@
  *
  * @hide
  */
-public abstract class DnsPacket {
+public class DnsPacket {
     /**
      * Type of the canonical name for an alias. Refer to RFC 1035 section 3.2.2.
      */
@@ -515,7 +515,14 @@
     protected final DnsHeader mHeader;
     protected final List<DnsRecord>[] mRecords;
 
-    protected DnsPacket(@NonNull byte[] data) throws ParseException {
+    /**
+     * Returns the DNS records stored for {@code section}; the list itself is returned, not a copy.
+     */
+    public List<DnsRecord> getRecords(@RecordType int section) {
+        return mRecords[section];
+    }
+
+    public DnsPacket(@NonNull byte[] data) throws ParseException {
         if (null == data) {
             throw new ParseException("Parse header failed, null input data");
         }
@@ -548,7 +555,7 @@
      *
      * Note that authority records section and additional records section is not supported.
      */
-    protected DnsPacket(@NonNull DnsHeader header, @NonNull List<DnsRecord> qd,
+    public DnsPacket(@NonNull DnsHeader header, @NonNull List<DnsRecord> qd,
             @NonNull List<DnsRecord> an) {
         mHeader = Objects.requireNonNull(header);
         mRecords = new List[NUM_SECTIONS];
diff --git a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
index 3a8252a..88309ed 100644
--- a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
@@ -145,6 +145,7 @@
 import android.net.ConnectivityManager;
 import android.net.ConnectivityManager.NetworkCallback;
 import android.net.ConnectivitySettingsManager;
+import android.net.DnsResolver;
 import android.net.InetAddresses;
 import android.net.IpSecManager;
 import android.net.IpSecManager.UdpEncapsulationSocket;
@@ -201,6 +202,7 @@
 import com.android.internal.util.ArrayUtils;
 import com.android.modules.utils.build.SdkLevel;
 import com.android.net.module.util.CollectionUtils;
+import com.android.net.module.util.DnsPacket;
 import com.android.networkstack.apishim.ConnectivityManagerShimImpl;
 import com.android.networkstack.apishim.ConstantsShim;
 import com.android.networkstack.apishim.NetworkInformationShimImpl;
@@ -249,7 +251,6 @@
 import java.net.Socket;
 import java.net.SocketException;
 import java.net.URL;
-import java.net.UnknownHostException;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -306,6 +307,7 @@
     private static final int LISTEN_ACTIVITY_TIMEOUT_MS = 30_000;
     private static final int NO_CALLBACK_TIMEOUT_MS = 100;
     private static final int NETWORK_REQUEST_TIMEOUT_MS = 3000;
+    private static final int DNS_REQUEST_TIMEOUT_MS = 1000;
     private static final int SOCKET_TIMEOUT_MS = 100;
     private static final int NUM_TRIES_MULTIPATH_PREF_CHECK = 20;
     private static final long INTERVAL_MULTIPATH_PREF_CHECK_MS = 500;
@@ -861,6 +863,55 @@
         });
     }
 
+    /**
+     * Asks ns1.google.com, over a UDP socket bound to {@code network}, for the TXT record of
+     * o-o.myaddr.l.google.com and returns the text of the first answer, which the caller
+     * parses as an IP address. The query is attempted up to three times; if the last attempt
+     * also fails, its IOException is rethrown.
+     */
+    @NonNull
+    private static String getDeviceIpv6AddressThroughDnsQuery(Network network) throws Exception {
+        final InetAddress dnsAddr = getAddrByName("ns1.google.com", AF_INET6);
+        assertNotNull("IPv6 address for ns1.google.com should not be null", dnsAddr);
+
+        try (DatagramSocket udpSocket = new DatagramSocket()) {
+            network.bindSocket(udpSocket);
+            // The receive timeout is a socket-wide setting, so configure it once up front.
+            udpSocket.setSoTimeout(DNS_REQUEST_TIMEOUT_MS);
+            final DnsPacket queryDnsPkt = new DnsPacket(
+                    new DnsPacket.DnsHeader(new Random().nextInt(), DnsResolver.FLAG_EMPTY,
+                            1 /* qdcount */, 0 /* ancount */),
+                    List.of(DnsPacket.DnsRecord.makeQuestion("o-o.myaddr.l.google.com",
+                            DnsResolver.TYPE_TXT, DnsResolver.CLASS_IN)),
+                    List.of() /* an */);
+            final byte[] queryDnsRawBytes = queryDnsPkt.getBytes();
+            final byte[] receiveBuffer = new byte[1500];
+            final int maxRetry = 3;
+            for (int attempt = 1; attempt <= maxRetry; ++attempt) {
+                try {
+                    udpSocket.send(new DatagramPacket(queryDnsRawBytes,
+                            queryDnsRawBytes.length, dnsAddr, 53 /* port */));
+                    udpSocket.receive(new DatagramPacket(receiveBuffer, receiveBuffer.length));
+                    break;
+                } catch (IOException e) {
+                    // Only the final failure is fatal; earlier ones are logged and retried.
+                    if (attempt == maxRetry) throw e;
+                    Log.e(TAG, "DNS request failed (attempt " + attempt + ")", e);
+                }
+            }
+
+            final DnsPacket replyDnsPkt = new DnsPacket(receiveBuffer);
+            final List<DnsPacket.DnsRecord> answers = replyDnsPkt.getRecords(DnsPacket.ANSECTION);
+            assertFalse("DNS reply for o-o.myaddr.l.google.com has no answer", answers.isEmpty());
+            final byte[] txtReplyRecord = answers.get(0).getRR();
+            // TXT RDATA is a length-prefixed character-string; the length octet is unsigned.
+            assertTrue("Empty TXT RDATA in DNS reply", txtReplyRecord.length > 0);
+            final int dataLength = txtReplyRecord[0] & 0xff;
+            assertEquals("Unexpected TXT string length", txtReplyRecord.length - 1, dataLength);
+            return new String(txtReplyRecord, 1, dataLength, StandardCharsets.UTF_8);
+        }
+    }
+
     /**
      * Tests that connections can be opened on WiFi and cellphone networks,
      * and that they are made from different IP addresses.
@@ -886,8 +937,31 @@
 
         // Verify that the IP addresses that the requests appeared to come from are actually on the
         // respective networks.
-        assertOnNetwork(wifiAddressString, wifiNetwork);
-        assertOnNetwork(cellAddressString, cellNetwork);
+        final InetAddress wifiAddress = InetAddresses.parseNumericAddress(wifiAddressString);
+        final LinkProperties wifiLinkProperties = mCm.getLinkProperties(wifiNetwork);
+        // To make sure that the request went out on the right network, check that
+        // the IP address seen by the server is assigned to the expected network.
+        // We can only do this for IPv6 addresses, because in IPv4 we will likely
+        // have a private IPv4 address, and that won't match what the server sees.
+        if (wifiAddress instanceof Inet6Address) {
+            assertContains(wifiLinkProperties.getAddresses(), wifiAddress);
+        }
+
+        final LinkProperties cellLinkProperties = mCm.getLinkProperties(cellNetwork);
+        final InetAddress cellAddress = InetAddresses.parseNumericAddress(cellAddressString);
+        final List<InetAddress> cellNetworkAddresses = cellLinkProperties.getAddresses();
+        // On debuggable builds, if the address seen by the server is not assigned to the
+        // cellular network, re-verify using the address obtained through a DNS query instead.
+        boolean isUserDebug = Build.isDebuggable();
+        if (cellAddress instanceof Inet6Address) {
+            if (isUserDebug && !cellNetworkAddresses.contains(cellAddress)) {
+                final InetAddress ipv6AddressThroughDns = InetAddresses.parseNumericAddress(
+                        getDeviceIpv6AddressThroughDnsQuery(cellNetwork));
+                assertContains(cellNetworkAddresses, ipv6AddressThroughDns);
+            } else {
+                assertContains(cellNetworkAddresses, cellAddress);
+            }
+        }
 
         assertFalse("Unexpectedly equal: " + wifiNetwork, wifiNetwork.equals(cellNetwork));
     }
@@ -919,17 +993,6 @@
         }
     }
 
-    private void assertOnNetwork(String adressString, Network network) throws UnknownHostException {
-        InetAddress address = InetAddress.getByName(adressString);
-        LinkProperties linkProperties = mCm.getLinkProperties(network);
-        // To make sure that the request went out on the right network, check that
-        // the IP address seen by the server is assigned to the expected network.
-        // We can only do this for IPv6 addresses, because in IPv4 we will likely
-        // have a private IPv4 address, and that won't match what the server sees.
-        if (address instanceof Inet6Address) {
-            assertContains(linkProperties.getAddresses(), address);
-        }
-    }
 
     private static<T> void assertContains(Collection<T> collection, T element) {
         assertTrue(element + " not found in " + collection, collection.contains(element));
@@ -1713,7 +1776,8 @@
         }
     }
 
-    private InetAddress getAddrByName(final String hostname, final int family) throws Exception {
+    private static InetAddress getAddrByName(final String hostname, final int family)
+            throws Exception {
         final InetAddress[] allAddrs = InetAddress.getAllByName(hostname);
         for (InetAddress addr : allAddrs) {
             if (family == AF_INET && addr instanceof Inet4Address) return addr;