Merge "Improve test for the changing of DnsResolver"
am: fd3e15e0c5

Change-Id: I95d74116c9de67b8915343f45b26a011d0248faa
diff --git a/tests/cts/net/src/android/net/cts/DnsResolverTest.java b/tests/cts/net/src/android/net/cts/DnsResolverTest.java
index 308f1ed..643d542 100644
--- a/tests/cts/net/src/android/net/cts/DnsResolverTest.java
+++ b/tests/cts/net/src/android/net/cts/DnsResolverTest.java
@@ -18,15 +18,18 @@
 
 import static android.net.DnsResolver.CLASS_IN;
 import static android.net.DnsResolver.FLAG_NO_CACHE_LOOKUP;
+import static android.net.DnsResolver.TYPE_A;
 import static android.net.DnsResolver.TYPE_AAAA;
 
 import android.annotation.NonNull;
+import android.annotation.Nullable;
 import android.content.Context;
 import android.net.ConnectivityManager;
 import android.net.DnsPacket;
 import android.net.DnsResolver;
 import android.net.Network;
 import android.net.NetworkCapabilities;
+import android.net.ParseException;
 import android.os.Handler;
 import android.os.Looper;
 import android.system.ErrnoException;
@@ -86,41 +89,55 @@
         return testableNetworks.toArray(new Network[0]);
     }
 
-    public void testInetAddressQuery() throws ErrnoException {
+    public void testQueryWithInetAddressCallback() {
         final String dname = "www.google.com";
-        final String msg = "InetAddress query " + dname;
+        final String msg = "Query with InetAddressAnswerCallback " + dname;
         for (Network network : getTestableNetworks()) {
             final CountDownLatch latch = new CountDownLatch(1);
             final AtomicReference<List<InetAddress>> answers = new AtomicReference<>();
-
-            mDns.query(network, dname, FLAG_NO_CACHE_LOOKUP, mHandler, answerList -> {
-                        answers.set(answerList);
-                        for (InetAddress addr : answerList) {
-                            Log.d(TAG, "Reported addr:" + addr.toString());
-                        }
-                        latch.countDown();
+            final DnsResolver.InetAddressAnswerCallback callback =
+                    new DnsResolver.InetAddressAnswerCallback() {
+                @Override
+                public void onAnswer(@NonNull List<InetAddress> answerList) {
+                    answers.set(answerList);
+                    for (InetAddress addr : answerList) {
+                        Log.d(TAG, "Reported addr: " + addr.toString());
                     }
-            );
+                    latch.countDown();
+                }
+
+                @Override
+                public void onParseException(@NonNull ParseException e) {
+                    fail(msg + e.getMessage());
+                }
+
+                @Override
+                public void onQueryException(@NonNull ErrnoException e) {
+                    fail(msg + e.getMessage());
+                }
+            };
+            mDns.query(network, dname, CLASS_IN, TYPE_A, FLAG_NO_CACHE_LOOKUP, mHandler, callback);
             try {
                 assertTrue(msg + " but no valid answer after " + TIMEOUT_MS + "ms.",
                         latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
                 assertGreaterThan(msg + " returned 0 result", answers.get().size(), 0);
             } catch (InterruptedException e) {
+                fail(msg + " Waiting for DNS lookup was interrupted");
             }
         }
     }
 
-    static private void assertGreaterThan(String msg, int a, int b) {
-        assertTrue(msg + ": " + a + " > " + b, a > b);
+    static private void assertGreaterThan(String msg, int first, int second) {
+        assertTrue(msg + " Expected " + first + " to be greater than " + second, first > second);
     }
 
     static private void assertValidAnswer(String msg, @NonNull DnsAnswer ans) {
         // Check rcode field.(0, No error condition).
         assertTrue(msg + " Response error, rcode: " + ans.getRcode(), ans.getRcode() == 0);
         // Check answer counts.
-        assertTrue(msg + " No answer found", ans.getANCount() > 0);
+        assertGreaterThan(msg + " No answer found", ans.getANCount(), 0);
         // Check question counts.
-        assertTrue(msg + " No question found", ans.getQDCount() > 0);
+        assertGreaterThan(msg + " No question found", ans.getQDCount(), 0);
     }
 
     static private void assertValidEmptyAnswer(String msg, @NonNull DnsAnswer ans) {
@@ -129,10 +146,10 @@
         // Check answer counts. Expect 0 answer.
         assertTrue(msg + " Not an empty answer", ans.getANCount() == 0);
         // Check question counts.
-        assertTrue(msg + " No question found", ans.getQDCount() > 0);
+        assertGreaterThan(msg + " No question found", ans.getQDCount(), 0);
     }
 
-    private class DnsAnswer extends DnsPacket {
+    private static class DnsAnswer extends DnsPacket {
         DnsAnswer(@NonNull byte[] data) throws ParseException {
             super(data);
             // Check QR field.(query (0), or a response (1)).
@@ -152,34 +169,56 @@
         }
     }
 
-    public void testRawQuery() throws ErrnoException {
+    class RawAnswerCallbackImpl extends DnsResolver.RawAnswerCallback {
+        private final CountDownLatch mLatch = new CountDownLatch(1);
+        private final String mMsg;
+        RawAnswerCallbackImpl(String msg) {
+            this.mMsg = msg;
+        }
+
+        public boolean waitForAnswer() throws InterruptedException {
+            return mLatch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS);
+        }
+
+        @Override
+        public void onAnswer(@NonNull byte[] answer) {
+            try {
+                assertValidAnswer(mMsg, new DnsAnswer(answer));
+                Log.d(TAG, "Reported blob: " + byteArrayToHexString(answer));
+                mLatch.countDown();
+            } catch (ParseException e) {
+                fail(mMsg + e.getMessage());
+            }
+        }
+
+        @Override
+        public void onParseException(@NonNull ParseException e) {
+            fail(mMsg + e.getMessage());
+        }
+
+        @Override
+        public void onQueryException(@NonNull ErrnoException e) {
+            fail(mMsg + e.getMessage());
+        }
+    }
+
+    public void testQueryWithRawAnswerCallback() {
         final String dname = "www.google.com";
-        final String msg = "Raw query " + dname;
+        final String msg = "Query with RawAnswerCallback " + dname;
         for (Network network : getTestableNetworks()) {
-            final CountDownLatch latch = new CountDownLatch(1);
-            mDns.query(network, dname, CLASS_IN, TYPE_AAAA, FLAG_NO_CACHE_LOOKUP, mHandler,
-                    answer -> {
-                        if (answer == null) {
-                            fail(msg + " no answer returned");
-                        }
-                        try {
-                            assertValidAnswer(msg, new DnsAnswer(answer));
-                            Log.d(TAG, "Reported blob:" + byteArrayToHexString(answer));
-                            latch.countDown();
-                        } catch (DnsPacket.ParseException e) {
-                            fail(msg + e.getMessage());
-                        }
-                    }
-            );
+            final RawAnswerCallbackImpl callback = new RawAnswerCallbackImpl(msg);
+            mDns.query(network, dname, CLASS_IN, TYPE_AAAA, FLAG_NO_CACHE_LOOKUP,
+                    mHandler, callback);
             try {
                 assertTrue(msg + " but no answer after " + TIMEOUT_MS + "ms.",
-                        latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
+                        callback.waitForAnswer());
             } catch (InterruptedException e) {
+                fail(msg + " Waiting for DNS lookup was interrupted");
             }
         }
     }
 
-    public void testRawQueryWithBlob() throws ErrnoException {
+    public void testQueryBlobWithRawAnswerCallback() {
         final byte[] blob = new byte[]{
             /* Header */
             0x55, 0x66, /* Transaction ID */
@@ -194,74 +233,90 @@
             0x00, 0x01, /* Type */
             0x00, 0x01  /* Class */
         };
-        final String msg = "Raw query with blob " + byteArrayToHexString(blob);
+        final String msg = "Query with RawAnswerCallback " + byteArrayToHexString(blob);
         for (Network network : getTestableNetworks()) {
-            final CountDownLatch latch = new CountDownLatch(1);
-            mDns.query(network, blob, FLAG_NO_CACHE_LOOKUP, mHandler, answer -> {
-                        if (answer == null) {
-                            fail(msg + " no answer returned");
-                        }
-                        try {
-                            assertValidAnswer(msg, new DnsAnswer(answer));
-                            Log.d(TAG, "Reported blob:" + byteArrayToHexString(answer));
-                            latch.countDown();
-                        } catch (DnsPacket.ParseException e) {
-                            fail(msg + e.getMessage());
-                        }
-                    }
-            );
+            final RawAnswerCallbackImpl callback = new RawAnswerCallbackImpl(msg);
+            mDns.query(network, blob, FLAG_NO_CACHE_LOOKUP, mHandler, callback);
             try {
                 assertTrue(msg + " but no answer after " + TIMEOUT_MS + "ms.",
-                        latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
+                        callback.waitForAnswer());
             } catch (InterruptedException e) {
+                fail(msg + " Waiting for DNS lookup was interrupted");
             }
         }
     }
 
-    public void testEmptyQuery() throws ErrnoException {
+    public void testQueryRoot() {
         final String dname = "";
-        final String msg = "Raw query empty dname(ROOT)";
+        final String msg = "Query with RawAnswerCallback empty dname(ROOT) ";
         for (Network network : getTestableNetworks()) {
             final CountDownLatch latch = new CountDownLatch(1);
-            mDns.query(network, dname, CLASS_IN, TYPE_AAAA, FLAG_NO_CACHE_LOOKUP, mHandler,
-                    answer -> {
-                        if (answer == null) {
-                            fail(msg + " no answer returned");
-                        }
-                        try {
-                            // Except no answer record because of querying with empty dname(ROOT)
-                            assertValidEmptyAnswer(msg, new DnsAnswer(answer));
-                            latch.countDown();
-                        } catch (DnsPacket.ParseException e) {
-                            fail(msg + e.getMessage());
-                        }
+            final DnsResolver.RawAnswerCallback callback = new DnsResolver.RawAnswerCallback() {
+                @Override
+                public void onAnswer(@NonNull byte[] answer) {
+                    try {
+                        // Expect no answer record because of querying with empty dname(ROOT)
+                        assertValidEmptyAnswer(msg, new DnsAnswer(answer));
+                        latch.countDown();
+                    } catch (ParseException e) {
+                        fail(msg + e.getMessage());
                     }
-            );
+                }
+
+                @Override
+                public void onParseException(@NonNull ParseException e) {
+                    fail(msg + e.getMessage());
+                }
+
+                @Override
+                public void onQueryException(@NonNull ErrnoException e) {
+                    fail(msg + e.getMessage());
+                }
+            };
+            mDns.query(network, dname, CLASS_IN, TYPE_AAAA, FLAG_NO_CACHE_LOOKUP,
+                    mHandler, callback);
             try {
-                assertTrue(msg + " but no answer after " + TIMEOUT_MS + "ms.",
+                assertTrue(msg + "but no answer after " + TIMEOUT_MS + "ms.",
                         latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
             } catch (InterruptedException e) {
+                fail(msg + "Waiting for DNS lookup was interrupted");
             }
         }
     }
 
-    public void testNXQuery() throws ErrnoException {
+    public void testQueryNXDomain() {
         final String dname = "test1-nx.metric.gstatic.com";
-        final String msg = "InetAddress query " + dname;
+        final String msg = "Query with InetAddressAnswerCallback " + dname;
         for (Network network : getTestableNetworks()) {
             final CountDownLatch latch = new CountDownLatch(1);
-            mDns.query(network, dname, FLAG_NO_CACHE_LOOKUP, mHandler, answerList -> {
-                        if (answerList.size() == 0) {
-                            latch.countDown();
-                            return;
-                        }
-                        fail(msg + " but get valid answers");
+            final DnsResolver.InetAddressAnswerCallback callback =
+                    new DnsResolver.InetAddressAnswerCallback() {
+                @Override
+                public void onAnswer(@NonNull List<InetAddress> answerList) {
+                    if (answerList.size() == 0) {
+                        latch.countDown();
+                        return;
                     }
-            );
+                    fail(msg + " but get valid answers");
+                }
+
+                @Override
+                public void onParseException(@NonNull ParseException e) {
+                    fail(msg + e.getMessage());
+                }
+
+                @Override
+                public void onQueryException(@NonNull ErrnoException e) {
+                    fail(msg + e.getMessage());
+                }
+            };
+            mDns.query(network, dname, CLASS_IN, TYPE_AAAA, FLAG_NO_CACHE_LOOKUP,
+                    mHandler, callback);
             try {
                 assertTrue(msg + " but no answer after " + TIMEOUT_MS + "ms.",
                         latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
             } catch (InterruptedException e) {
+                fail(msg + " Waiting for DNS lookup was interrupted");
             }
         }
     }