Merge "Stop throwing when the invalid capability is passed" into main
diff --git a/thread/service/java/com/android/server/thread/NsdPublisher.java b/thread/service/java/com/android/server/thread/NsdPublisher.java
index 2c14f1d..d0cb9b8 100644
--- a/thread/service/java/com/android/server/thread/NsdPublisher.java
+++ b/thread/service/java/com/android/server/thread/NsdPublisher.java
@@ -19,11 +19,15 @@
 import static android.net.nsd.NsdManager.PROTOCOL_DNS_SD;
 
 import android.annotation.NonNull;
+import android.annotation.Nullable;
 import android.content.Context;
+import android.net.DnsResolver;
 import android.net.InetAddresses;
+import android.net.Network;
 import android.net.nsd.DiscoveryRequest;
 import android.net.nsd.NsdManager;
 import android.net.nsd.NsdServiceInfo;
+import android.os.CancellationSignal;
 import android.os.Handler;
 import android.os.RemoteException;
 import android.text.TextUtils;
@@ -34,6 +38,7 @@
 import com.android.server.thread.openthread.DnsTxtAttribute;
 import com.android.server.thread.openthread.INsdDiscoverServiceCallback;
 import com.android.server.thread.openthread.INsdPublisher;
+import com.android.server.thread.openthread.INsdResolveHostCallback;
 import com.android.server.thread.openthread.INsdResolveServiceCallback;
 import com.android.server.thread.openthread.INsdStatusReceiver;
 
@@ -41,6 +46,7 @@
 import java.net.InetAddress;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -56,24 +62,48 @@
  * {@code mHandler} itself.
  */
 public final class NsdPublisher extends INsdPublisher.Stub {
-    // TODO: b/321883491 - specify network for mDNS operations
     private static final String TAG = NsdPublisher.class.getSimpleName();
+
+    // TODO: b/321883491 - specify network for mDNS operations
+    /**
+     * Network passed to {@link DnsResolver#query} when resolving hosts. {@code null} means
+     * {@link DnsResolver} queries on the default network.
+     */
+    @Nullable private Network mNetwork;
     private final NsdManager mNsdManager;
+    private final DnsResolver mDnsResolver;
     private final Handler mHandler;
     private final Executor mExecutor;
     private final SparseArray<RegistrationListener> mRegistrationListeners = new SparseArray<>(0);
     private final SparseArray<DiscoveryListener> mDiscoveryListeners = new SparseArray<>(0);
     private final SparseArray<ServiceInfoListener> mServiceInfoListeners = new SparseArray<>(0);
+    // In-flight host resolutions, keyed by the listener ID passed to resolveHost().
+    private final SparseArray<HostInfoListener> mHostInfoListeners = new SparseArray<>(0);
 
     @VisibleForTesting
-    public NsdPublisher(NsdManager nsdManager, Handler handler) {
+    public NsdPublisher(NsdManager nsdManager, DnsResolver dnsResolver, Handler handler) {
         mNsdManager = nsdManager;
+        mDnsResolver = dnsResolver;
         mHandler = handler;
         mExecutor = runnable -> mHandler.post(runnable);
     }
 
     public static NsdPublisher newInstance(Context context, Handler handler) {
-        return new NsdPublisher(context.getSystemService(NsdManager.class), handler);
+        return new NsdPublisher(
+                context.getSystemService(NsdManager.class), DnsResolver.getInstance(), handler);
+    }
+
+    /**
+     * Sets the network on which {@link #resolveHost} sends its DNS queries.
+     *
+     * <p>NOTE(review): {@code mNetwork} is written here without synchronization but read on the
+     * {@code mHandler} thread; presumably callers run on that same thread — confirm.
+     *
+     * @param network the network to query on, or {@code null} to query on the default network
+     */
+    // TODO: b/321883491 - NsdPublisher should be disabled when mNetwork is null
+    public void setNetworkForHostResolution(@Nullable Network network) {
+        mNetwork = network;
     }
 
     @Override
@@ -291,6 +309,62 @@
         }
     }
 
+    /**
+     * Resolves the addresses of {@code name} in the {@code .local} domain. The work is posted to
+     * {@code mHandler}; {@link #stopHostResolution} with the same {@code listenerId} cancels it.
+     */
+    @Override
+    public void resolveHost(String name, INsdResolveHostCallback callback, int listenerId) {
+        mHandler.post(() -> resolveHostInternal(name, callback, listenerId));
+    }
+
+    // Queries "<name>.local" on mNetwork (skipping the DNS cache) and tracks the listener by ID.
+    private void resolveHostInternal(
+            String name, INsdResolveHostCallback callback, int listenerId) {
+        checkOnHandlerThread();
+
+        final CancellationSignal signal = new CancellationSignal();
+        final HostInfoListener hostInfoListener =
+                new HostInfoListener(name, callback, signal, listenerId);
+        final String queryName = name + ".local";
+        mDnsResolver.query(
+                mNetwork,
+                queryName,
+                DnsResolver.FLAG_NO_CACHE_LOOKUP,
+                mExecutor,
+                signal,
+                hostInfoListener);
+        mHostInfoListeners.append(listenerId, hostInfoListener);
+
+        Log.i(TAG, "Resolving host." + " Listener ID: " + listenerId + ", hostname: " + name);
+    }
+
+    /**
+     * Cancels the host resolution started with {@code listenerId}. The work is posted to {@code
+     * mHandler}; unknown IDs are logged and ignored.
+     */
+    @Override
+    public void stopHostResolution(int listenerId) {
+        mHandler.post(() -> stopHostResolutionInternal(listenerId));
+    }
+
+    private void stopHostResolutionInternal(int listenerId) {
+        checkOnHandlerThread();
+
+        final HostInfoListener hostInfoListener = mHostInfoListeners.get(listenerId);
+        if (hostInfoListener != null) {
+            Log.i(TAG, "Stopping host resolution. Listener: " + hostInfoListener);
+            hostInfoListener.cancel();
+            mHostInfoListeners.remove(listenerId);
+            return;
+        }
+        Log.w(
+                TAG,
+                "Failed to stop host resolution. Listener ID: "
+                        + listenerId
+                        + ". The listener is null.");
+    }
+
     private void checkOnHandlerThread() {
         if (mHandler.getLooper().getThread() != Thread.currentThread()) {
             throw new IllegalStateException(
@@ -586,4 +651,101 @@
             return "ID: " + mListenerId + ", service name: " + mName + ", service type: " + mType;
         }
     }
+
+    /**
+     * Forwards the {@link DnsResolver} result of one {@link #resolveHost} request to its {@link
+     * INsdResolveHostCallback}.
+     *
+     * <p>Inner (non-static) so it can call {@code checkOnHandlerThread()} and update {@code
+     * mHostInfoListeners}. Results are delivered through {@code mExecutor}, which posts to {@code
+     * mHandler}.
+     */
+    class HostInfoListener implements DnsResolver.Callback<List<InetAddress>> {
+        private final String mHostname;
+        private final INsdResolveHostCallback mResolveHostCallback;
+        private final CancellationSignal mCancellationSignal;
+        private final int mListenerId;
+
+        HostInfoListener(
+                @NonNull String hostname,
+                INsdResolveHostCallback resolveHostCallback,
+                CancellationSignal cancellationSignal,
+                int listenerId) {
+            this.mHostname = hostname;
+            this.mResolveHostCallback = resolveHostCallback;
+            this.mCancellationSignal = cancellationSignal;
+            this.mListenerId = listenerId;
+        }
+
+        /**
+         * Reports the resolved addresses (possibly none) to the client. {@code rcode} is only
+         * logged; a non-zero code still forwards whatever {@code answerList} contains.
+         */
+        @Override
+        public void onAnswer(@NonNull List<InetAddress> answerList, int rcode) {
+            checkOnHandlerThread();
+
+            Log.i(
+                    TAG,
+                    "Host is resolved."
+                            + " Listener ID: "
+                            + mListenerId
+                            + ", hostname: "
+                            + mHostname
+                            + ", addresses: "
+                            + answerList
+                            + ", return code: "
+                            + rcode);
+            List<String> addressStrings = new ArrayList<>(answerList.size());
+            for (InetAddress address : answerList) {
+                addressStrings.add(address.getHostAddress());
+            }
+            notifyHostResolved(addressStrings);
+        }
+
+        /** Reports the failure to the client as a resolution with an empty address list. */
+        @Override
+        public void onError(@NonNull DnsResolver.DnsException error) {
+            checkOnHandlerThread();
+
+            Log.w(
+                    TAG,
+                    "Failed to resolve host."
+                            + " Listener ID: "
+                            + mListenerId
+                            + ", hostname: "
+                            + mHostname,
+                    error);
+            notifyHostResolved(Collections.emptyList());
+        }
+
+        @Override
+        public String toString() {
+            return "ID: " + mListenerId + ", hostname: " + mHostname;
+        }
+
+        /** Cancels the pending DNS query and stops tracking this listener. */
+        void cancel() {
+            mCancellationSignal.cancel();
+            removeFromListeners();
+        }
+
+        // Delivers the final result to the client and stops tracking this listener.
+        private void notifyHostResolved(List<String> addresses) {
+            try {
+                mResolveHostCallback.onHostResolved(mHostname, addresses);
+            } catch (RemoteException ignored) {
+                // do nothing if the client is dead
+            }
+            removeFromListeners();
+        }
+
+        // Only drop the entry if it still maps to this listener, so that a newer resolveHost()
+        // call reusing the same listener ID keeps its own entry.
+        private void removeFromListeners() {
+            if (mHostInfoListeners.get(mListenerId) == this) {
+                mHostInfoListeners.remove(mListenerId);
+            }
+        }
+    }
 }
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
index af9abdf..737ec41 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
@@ -717,6 +717,7 @@
                 if (mNetworkToInterface.containsKey(mUpstreamNetwork)) {
                     enableBorderRouting(mNetworkToInterface.get(mUpstreamNetwork));
                 }
+                mNsdPublisher.setNetworkForHostResolution(mUpstreamNetwork);
             }
         }
     }
diff --git a/thread/tests/unit/src/com/android/server/thread/NsdPublisherTest.java b/thread/tests/unit/src/com/android/server/thread/NsdPublisherTest.java
index 8886c73..3cae84f 100644
--- a/thread/tests/unit/src/com/android/server/thread/NsdPublisherTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/NsdPublisherTest.java
@@ -16,6 +16,7 @@
 
 package com.android.server.thread;
 
+import static android.net.DnsResolver.ERROR_SYSTEM;
 import static android.net.nsd.NsdManager.FAILURE_INTERNAL_ERROR;
 import static android.net.nsd.NsdManager.PROTOCOL_DNS_SD;
 
@@ -30,15 +31,19 @@
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 
+import android.net.DnsResolver;
 import android.net.InetAddresses;
+import android.net.Network;
 import android.net.nsd.DiscoveryRequest;
 import android.net.nsd.NsdManager;
 import android.net.nsd.NsdServiceInfo;
+import android.os.CancellationSignal;
 import android.os.Handler;
 import android.os.test.TestLooper;
 
 import com.android.server.thread.openthread.DnsTxtAttribute;
 import com.android.server.thread.openthread.INsdDiscoverServiceCallback;
+import com.android.server.thread.openthread.INsdResolveHostCallback;
 import com.android.server.thread.openthread.INsdResolveServiceCallback;
 import com.android.server.thread.openthread.INsdStatusReceiver;
 
@@ -61,11 +66,14 @@
 /** Unit tests for {@link NsdPublisher}. */
 public final class NsdPublisherTest {
     @Mock private NsdManager mMockNsdManager;
+    @Mock private DnsResolver mMockDnsResolver;
 
     @Mock private INsdStatusReceiver mRegistrationReceiver;
     @Mock private INsdStatusReceiver mUnregistrationReceiver;
     @Mock private INsdDiscoverServiceCallback mDiscoverServiceCallback;
     @Mock private INsdResolveServiceCallback mResolveServiceCallback;
+    @Mock private INsdResolveHostCallback mResolveHostCallback;
+    @Mock private Network mNetwork;
 
     private TestLooper mTestLooper;
     private NsdPublisher mNsdPublisher;
@@ -637,6 +645,81 @@
     }
 
     @Test
+    public void resolveHost_hostResolved() throws Exception {
+        prepareTest();
+
+        mNsdPublisher.resolveHost("test", mResolveHostCallback, 10 /* listenerId */);
+        mTestLooper.dispatchAll();
+
+        verifyHostQuery("test.local")
+                .onAnswer(
+                        List.of(
+                                InetAddresses.parseNumericAddress("2001::1"),
+                                InetAddresses.parseNumericAddress("2001::2")),
+                        0);
+        mTestLooper.dispatchAll();
+
+        verify(mResolveHostCallback, times(1))
+                .onHostResolved("test", List.of("2001::1", "2001::2"));
+    }
+
+    @Test
+    public void resolveHost_errorReported() throws Exception {
+        prepareTest();
+
+        mNsdPublisher.resolveHost("test", mResolveHostCallback, 10 /* listenerId */);
+        mTestLooper.dispatchAll();
+
+        verifyHostQuery("test.local")
+                .onError(new DnsResolver.DnsException(ERROR_SYSTEM, null /* cause */));
+        mTestLooper.dispatchAll();
+
+        verify(mResolveHostCallback, times(1)).onHostResolved("test", Collections.emptyList());
+    }
+
+    @Test
+    public void stopHostResolution() throws Exception {
+        prepareTest();
+
+        mNsdPublisher.resolveHost("test", mResolveHostCallback, 10 /* listenerId */);
+        mTestLooper.dispatchAll();
+        ArgumentCaptor<CancellationSignal> cancellationSignalArgumentCaptor =
+                ArgumentCaptor.forClass(CancellationSignal.class);
+        verify(mMockDnsResolver, times(1))
+                .query(
+                        eq(mNetwork),
+                        eq("test.local"),
+                        eq(DnsResolver.FLAG_NO_CACHE_LOOKUP),
+                        any(Executor.class),
+                        cancellationSignalArgumentCaptor.capture(),
+                        any(DnsResolver.Callback.class));
+
+        mNsdPublisher.stopHostResolution(10 /* listenerId */);
+        mTestLooper.dispatchAll();
+
+        assertThat(cancellationSignalArgumentCaptor.getValue().isCanceled()).isTrue();
+    }
+
+    /**
+     * Verifies exactly one {@link DnsResolver#query} for {@code expectedQueryName} on {@code
+     * mNetwork} with {@link DnsResolver#FLAG_NO_CACHE_LOOKUP}, and returns its callback.
+     */
+    @SuppressWarnings("unchecked") // Mockito cannot build a captor for a parameterized type.
+    private DnsResolver.Callback<List<InetAddress>> verifyHostQuery(String expectedQueryName) {
+        ArgumentCaptor<DnsResolver.Callback<List<InetAddress>>> callbackCaptor =
+                ArgumentCaptor.forClass(DnsResolver.Callback.class);
+        verify(mMockDnsResolver, times(1))
+                .query(
+                        eq(mNetwork),
+                        eq(expectedQueryName),
+                        eq(DnsResolver.FLAG_NO_CACHE_LOOKUP),
+                        any(Executor.class),
+                        any(CancellationSignal.class),
+                        callbackCaptor.capture());
+        return callbackCaptor.getValue();
+    }
+
+    @Test
     public void reset_unregisterAll() {
         prepareTest();
 
@@ -780,6 +866,9 @@
     private void prepareTest() {
         mTestLooper = new TestLooper();
         Handler handler = new Handler(mTestLooper.getLooper());
-        mNsdPublisher = new NsdPublisher(mMockNsdManager, handler);
+        NsdPublisher publisher = new NsdPublisher(mMockNsdManager, mMockDnsResolver, handler);
+        // The host-resolution tests match DnsResolver#query() against eq(mNetwork).
+        publisher.setNetworkForHostResolution(mNetwork);
+        mNsdPublisher = publisher;
     }
 }