Merge "Parse DnsRecord by factory method"
diff --git a/staticlibs/framework/com/android/net/module/util/DnsPacket.java b/staticlibs/framework/com/android/net/module/util/DnsPacket.java
index 79ca3a3..0dcdf1e 100644
--- a/staticlibs/framework/com/android/net/module/util/DnsPacket.java
+++ b/staticlibs/framework/com/android/net/module/util/DnsPacket.java
@@ -19,7 +19,6 @@
 import static android.net.DnsResolver.TYPE_A;
 import static android.net.DnsResolver.TYPE_AAAA;
 
-import static com.android.internal.annotations.VisibleForTesting.Visibility.PACKAGE;
 import static com.android.internal.annotations.VisibleForTesting.Visibility.PRIVATE;
 import static com.android.net.module.util.DnsPacketUtils.DnsRecordParser.domainNameToLabels;
 
@@ -245,9 +244,11 @@
      *     /                     RDATA                     /
      *     /                                               /
      *     +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+     *
+     * Note that this class is meant to be used by composition and not inheritance, and
+     * that classes implementing more specific DNS records should call #parse.
      */
-    // TODO: Make DnsResourceRecord and DnsQuestion subclasses of DnsRecord, and construct
-    //  corresponding object from factory methods.
+    // TODO: Make DnsResourceRecord and DnsQuestion subclasses of DnsRecord.
     public static class DnsRecord {
         // Refer to RFC 1035 section 2.3.4 for MAXNAMESIZE.
         // NAME_NORMAL and NAME_COMPRESSION are used for checking name compression,
@@ -281,8 +282,7 @@
          * @param buf ByteBuffer input of record, must be in network byte order
          *         (which is the default).
          */
-        @VisibleForTesting(visibility = PACKAGE)
-        public DnsRecord(@RecordType int rType, @NonNull ByteBuffer buf)
+        private DnsRecord(@RecordType int rType, @NonNull ByteBuffer buf)
                 throws BufferUnderflowException, ParseException {
             Objects.requireNonNull(buf);
             this.rType = rType;
@@ -307,6 +307,31 @@
         }
 
         /**
+         * Creates a DnsRecord, or a subclass of DnsRecord, from a positioned ByteBuffer.
+         *
+         * Peeks at the nsType that follows the record name, then rewinds the buffer so that the
+         * constructor chosen for that type parses the whole record from its original position.
+         */
+        @VisibleForTesting(visibility = PRIVATE)
+        public static DnsRecord parse(@RecordType int rType, @NonNull ByteBuffer buf)
+                throws BufferUnderflowException, ParseException {
+            Objects.requireNonNull(buf);
+            final int oldPos = buf.position();
+            // The parsed name is discarded; parsing it only advances buf to the nsType field.
+            DnsRecordParser.parseName(buf, 0 /* Parse depth */,
+                    true /* isNameCompressionSupported */);
+            // Peek the nsType as an unsigned 16-bit value; the position is restored right after.
+            final int nsType = Short.toUnsignedInt(buf.getShort());
+            buf.position(oldPos);
+            // Return a plain DnsRecord by default for backward compatibility. This is useful when
+            // a partner supports a new type of DnsRecord but does not inherit from DnsRecord.
+            switch (nsType) {
+                default:
+                    return new DnsRecord(rType, buf);
+            }
+        }
+
+        /**
          * Make an A or AAAA record based on the specified parameters.
          *
          * @param rType Type of the record, can be {@link #ANSECTION}, {@link #ARSECTION}
@@ -507,7 +532,7 @@
             mRecords[i] = new ArrayList(count);
             for (int j = 0; j < count; ++j) {
                 try {
-                    mRecords[i].add(new DnsRecord(i, buffer));
+                    mRecords[i].add(DnsRecord.parse(i, buffer));
                 } catch (BufferUnderflowException e) {
                     throw new ParseException("Parse record fail", e);
                 }
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/DnsPacketTest.java b/staticlibs/tests/unit/src/com/android/net/module/util/DnsPacketTest.java
index 409c1eb..28e183a 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/DnsPacketTest.java
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/DnsPacketTest.java
@@ -219,7 +219,7 @@
                 0x00, 0x04, /* Data length */
                 (byte) 0xac, (byte) 0xd9, (byte) 0xa1, (byte) 0x84 /* Address */};
         final DnsPacket.DnsRecord questionsFromBytes =
-                new DnsPacket.DnsRecord(DnsPacket.QDSECTION, ByteBuffer.wrap(qdWithTTLRData));
+                DnsPacket.DnsRecord.parse(DnsPacket.QDSECTION, ByteBuffer.wrap(qdWithTTLRData));
         assertEquals(0, questionsFromBytes.ttl);
         assertNull(questionsFromBytes.getRR());
 
@@ -230,12 +230,12 @@
                 0x00, 0x01, /* Type */
                 0x00, 0x01, /* Class */};
         assertThrows(BufferUnderflowException.class, () ->
-                new DnsPacket.DnsRecord(DnsPacket.ANSECTION, ByteBuffer.wrap(anWithoutTTLRData)));
+                DnsPacket.DnsRecord.parse(DnsPacket.ANSECTION, ByteBuffer.wrap(anWithoutTTLRData)));
     }
 
     private void assertDnsRecordRoundTrip(DnsPacket.DnsRecord before)
             throws IOException {
-        final DnsPacket.DnsRecord after = new DnsPacket.DnsRecord(before.rType,
+        final DnsPacket.DnsRecord after = DnsPacket.DnsRecord.parse(before.rType,
                 ByteBuffer.wrap(before.getBytes()));
         assertEquals(after, before);
     }
@@ -393,7 +393,7 @@
                 "test.com", TYPE_AAAA, CLASS_IN);
         final DnsPacket.DnsRecord testAnswer = DnsPacket.DnsRecord.makeCNameRecord(
                 DnsPacket.ANSECTION, "test.com", CLASS_IN, 9, "www.test.com");
-        final DnsPacket.DnsRecord questionFromBytes = new DnsPacket.DnsRecord(DnsPacket.QDSECTION,
+        final DnsPacket.DnsRecord questionFromBytes = DnsPacket.DnsRecord.parse(DnsPacket.QDSECTION,
                 ByteBuffer.wrap(testQuestion.getBytes()));
         assertEquals(testQuestion, questionFromBytes);
         assertEquals(testQuestion.hashCode(), questionFromBytes.hashCode());