Ensure MdnsDiscoveryManager calls to ServiceTypeClients on looper thread

ServiceTypeClients will store/access services on MdnsServiceCache
in subsequent changes. And MdnsServiceCache can be accessed from
looper thread only. So ensure MdnsDiscoveryManager calls to
ServiceTypeClients on the looper thread.

Bug: 265787401
Test: atest FrameworksNetTests
(cherry picked from https://android-review.googlesource.com/q/commit:bd4140ea913ad031057ff37bea8c2556b4970463)
Merged-In: I05e73140da58c029b49057bb0ccfdb8ed7818dfc
Change-Id: I05e73140da58c029b49057bb0ccfdb8ed7818dfc
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index 2281c81..1b3c464 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -16,18 +16,21 @@
 
 package com.android.server.connectivity.mdns;
 
+import static com.android.server.connectivity.mdns.util.MdnsUtils.ensureRunningOnHandlerThread;
 import static com.android.server.connectivity.mdns.util.MdnsUtils.isNetworkMatched;
+import static com.android.server.connectivity.mdns.util.MdnsUtils.isRunningOnHandlerThread;
 
 import android.Manifest.permission;
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.annotation.RequiresPermission;
 import android.net.Network;
+import android.os.Handler;
+import android.os.Looper;
 import android.util.ArrayMap;
 import android.util.Log;
 import android.util.Pair;
 
-import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.net.module.util.SharedLog;
 
@@ -47,8 +50,8 @@
     private final MdnsSocketClientBase socketClient;
     @NonNull private final SharedLog sharedLog;
 
-    @GuardedBy("this")
     @NonNull private final PerNetworkServiceTypeClients perNetworkServiceTypeClients;
+    @NonNull private final Handler handler;
 
     private static class PerNetworkServiceTypeClients {
         private final ArrayMap<Pair<String, Network>, MdnsServiceTypeClient> clients =
@@ -103,11 +106,21 @@
     }
 
     public MdnsDiscoveryManager(@NonNull ExecutorProvider executorProvider,
-            @NonNull MdnsSocketClientBase socketClient, @NonNull SharedLog sharedLog) {
+            @NonNull MdnsSocketClientBase socketClient, @NonNull SharedLog sharedLog,
+            @NonNull Looper looper) {
         this.executorProvider = executorProvider;
         this.socketClient = socketClient;
         this.sharedLog = sharedLog;
         perNetworkServiceTypeClients = new PerNetworkServiceTypeClients();
+        handler = new Handler(looper);
+    }
+
+    private void checkAndRunOnHandlerThread(@NonNull Runnable function) {
+        if (isRunningOnHandlerThread(handler)) {
+            function.run();
+        } else {
+            handler.post(function);
+        }
     }
 
     /**
@@ -120,11 +133,19 @@
      *                      serviceType}.
      */
     @RequiresPermission(permission.CHANGE_WIFI_MULTICAST_STATE)
-    public synchronized void registerListener(
+    public void registerListener(
             @NonNull String serviceType,
             @NonNull MdnsServiceBrowserListener listener,
             @NonNull MdnsSearchOptions searchOptions) {
         sharedLog.i("Registering listener for serviceType: " + serviceType);
+        checkAndRunOnHandlerThread(() ->
+                handleRegisterListener(serviceType, listener, searchOptions));
+    }
+
+    private void handleRegisterListener(
+            @NonNull String serviceType,
+            @NonNull MdnsServiceBrowserListener listener,
+            @NonNull MdnsSearchOptions searchOptions) {
         if (perNetworkServiceTypeClients.isEmpty()) {
             // First listener. Starts the socket client.
             try {
@@ -139,30 +160,28 @@
                 new MdnsSocketClientBase.SocketCreationCallback() {
                     @Override
                     public void onSocketCreated(@Nullable Network network) {
-                        synchronized (MdnsDiscoveryManager.this) {
-                            // All listeners of the same service types shares the same
-                            // MdnsServiceTypeClient.
-                            MdnsServiceTypeClient serviceTypeClient =
-                                    perNetworkServiceTypeClients.get(serviceType, network);
-                            if (serviceTypeClient == null) {
-                                serviceTypeClient = createServiceTypeClient(serviceType, network);
-                                perNetworkServiceTypeClients.put(serviceType, network,
-                                        serviceTypeClient);
-                            }
-                            serviceTypeClient.startSendAndReceive(listener, searchOptions);
+                        ensureRunningOnHandlerThread(handler);
+                        // All listeners of the same service types shares the same
+                        // MdnsServiceTypeClient.
+                        MdnsServiceTypeClient serviceTypeClient =
+                                perNetworkServiceTypeClients.get(serviceType, network);
+                        if (serviceTypeClient == null) {
+                            serviceTypeClient = createServiceTypeClient(serviceType, network);
+                            perNetworkServiceTypeClients.put(serviceType, network,
+                                    serviceTypeClient);
                         }
+                        serviceTypeClient.startSendAndReceive(listener, searchOptions);
                     }
 
                     @Override
                     public void onAllSocketsDestroyed(@Nullable Network network) {
-                        synchronized (MdnsDiscoveryManager.this) {
-                            final MdnsServiceTypeClient serviceTypeClient =
-                                    perNetworkServiceTypeClients.get(serviceType, network);
-                            if (serviceTypeClient == null) return;
-                            // Notify all listeners that all services are removed from this socket.
-                            serviceTypeClient.notifySocketDestroyed();
-                            perNetworkServiceTypeClients.remove(serviceTypeClient);
-                        }
+                        ensureRunningOnHandlerThread(handler);
+                        final MdnsServiceTypeClient serviceTypeClient =
+                                perNetworkServiceTypeClients.get(serviceType, network);
+                        if (serviceTypeClient == null) return;
+                        // Notify all listeners that all services are removed from this socket.
+                        serviceTypeClient.notifySocketDestroyed();
+                        perNetworkServiceTypeClients.remove(serviceTypeClient);
                     }
                 });
     }
@@ -175,9 +194,14 @@
      * @param listener    The {@link MdnsServiceBrowserListener} listener.
      */
     @RequiresPermission(permission.CHANGE_WIFI_MULTICAST_STATE)
-    public synchronized void unregisterListener(
+    public void unregisterListener(
             @NonNull String serviceType, @NonNull MdnsServiceBrowserListener listener) {
         sharedLog.i("Unregistering listener for serviceType:" + serviceType);
+        checkAndRunOnHandlerThread(() -> handleUnregisterListener(serviceType, listener));
+    }
+
+    private void handleUnregisterListener(
+            @NonNull String serviceType, @NonNull MdnsServiceBrowserListener listener) {
         final List<MdnsServiceTypeClient> serviceTypeClients =
                 perNetworkServiceTypeClients.getByServiceType(serviceType);
         if (serviceTypeClients.isEmpty()) {
@@ -200,8 +224,14 @@
     }
 
     @Override
-    public synchronized void onResponseReceived(@NonNull MdnsPacket packet,
-            int interfaceIndex, Network network) {
+    public void onResponseReceived(@NonNull MdnsPacket packet,
+            int interfaceIndex, @Nullable Network network) {
+        checkAndRunOnHandlerThread(() ->
+                handleOnResponseReceived(packet, interfaceIndex, network));
+    }
+
+    private void handleOnResponseReceived(@NonNull MdnsPacket packet, int interfaceIndex,
+            @Nullable Network network) {
         for (MdnsServiceTypeClient serviceTypeClient
                 : perNetworkServiceTypeClients.getByMatchingNetwork(network)) {
             serviceTypeClient.processResponse(packet, interfaceIndex, network);
@@ -209,8 +239,14 @@
     }
 
     @Override
-    public synchronized void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode,
-            Network network) {
+    public void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode,
+            @Nullable Network network) {
+        checkAndRunOnHandlerThread(() ->
+                handleOnFailedToParseMdnsResponse(receivedPacketNumber, errorCode, network));
+    }
+
+    private void handleOnFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode,
+            @Nullable Network network) {
         for (MdnsServiceTypeClient serviceTypeClient
                 : perNetworkServiceTypeClients.getByMatchingNetwork(network)) {
             serviceTypeClient.onFailedToParseMdnsResponse(receivedPacketNumber, errorCode);
diff --git a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
index 5cc789f..4ba4e6d 100644
--- a/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
+++ b/service-t/src/com/android/server/connectivity/mdns/util/MdnsUtils.java
@@ -68,12 +68,20 @@
 
     /*** Ensure that current running thread is same as given handler thread */
     public static void ensureRunningOnHandlerThread(@NonNull Handler handler) {
-        if (handler.getLooper().getThread() != Thread.currentThread()) {
+        if (!isRunningOnHandlerThread(handler)) {
             throw new IllegalStateException(
                     "Not running on Handler thread: " + Thread.currentThread().getName());
         }
     }
 
+    /*** Check that current running thread is same as given handler thread */
+    public static boolean isRunningOnHandlerThread(@NonNull Handler handler) {
+        if (handler.getLooper().getThread() == Thread.currentThread()) {
+            return true;
+        }
+        return false;
+    }
+
     /*** Check whether the target network is matched current network */
     public static boolean isNetworkMatched(@Nullable Network targetNetwork,
             @Nullable Network currentNetwork) {
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
index 45da874..89776e2 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
@@ -27,6 +27,8 @@
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.net.Network;
+import android.os.Handler;
+import android.os.HandlerThread;
 import android.text.TextUtils;
 import android.util.Pair;
 
@@ -34,7 +36,9 @@
 import com.android.server.connectivity.mdns.MdnsSocketClientBase.SocketCreationCallback;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRunner;
+import com.android.testutils.HandlerUtils;
 
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -53,7 +57,7 @@
 @RunWith(DevSdkIgnoreRunner.class)
 @DevSdkIgnoreRule.IgnoreUpTo(SC_V2)
 public class MdnsDiscoveryManagerTests {
-
+    private static final long DEFAULT_TIMEOUT = 2000L;
     private static final String SERVICE_TYPE_1 = "_googlecast._tcp.local";
     private static final String SERVICE_TYPE_2 = "_test._tcp.local";
     private static final Network NETWORK_1 = Mockito.mock(Network.class);
@@ -78,12 +82,18 @@
     @Mock MdnsServiceBrowserListener mockListenerTwo;
     @Mock SharedLog sharedLog;
     private MdnsDiscoveryManager discoveryManager;
+    private HandlerThread thread;
+    private Handler handler;
 
     @Before
     public void setUp() {
         MockitoAnnotations.initMocks(this);
 
-        discoveryManager = new MdnsDiscoveryManager(executorProvider, socketClient, sharedLog) {
+        thread = new HandlerThread("MdnsDiscoveryManagerTests");
+        thread.start();
+        handler = new Handler(thread.getLooper());
+        discoveryManager = new MdnsDiscoveryManager(executorProvider, socketClient, sharedLog,
+                    thread.getLooper()) {
                     @Override
                     MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType,
                             @Nullable Network network) {
@@ -103,11 +113,23 @@
                 };
     }
 
+    @After
+    public void tearDown() {
+        if (thread != null) {
+            thread.quitSafely();
+        }
+    }
+
+    private void runOnHandler(Runnable r) {
+        handler.post(r);
+        HandlerUtils.waitForIdle(handler, DEFAULT_TIMEOUT);
+    }
+
     private SocketCreationCallback expectSocketCreationCallback(String serviceType,
             MdnsServiceBrowserListener listener, MdnsSearchOptions options) throws IOException {
         final ArgumentCaptor<SocketCreationCallback> callbackCaptor =
                 ArgumentCaptor.forClass(SocketCreationCallback.class);
-        discoveryManager.registerListener(serviceType, listener, options);
+        runOnHandler(() -> discoveryManager.registerListener(serviceType, listener, options));
         verify(socketClient).startDiscovery();
         verify(socketClient).notifyNetworkRequested(
                 eq(listener), eq(options.getNetwork()), callbackCaptor.capture());
@@ -120,11 +142,11 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, options);
-        callback.onSocketCreated(null /* network */);
+        runOnHandler(() -> callback.onSocketCreated(null /* network */));
         verify(mockServiceTypeClientOne).startSendAndReceive(mockListenerOne, options);
 
         when(mockServiceTypeClientOne.stopSendAndReceive(mockListenerOne)).thenReturn(true);
-        discoveryManager.unregisterListener(SERVICE_TYPE_1, mockListenerOne);
+        runOnHandler(() -> discoveryManager.unregisterListener(SERVICE_TYPE_1, mockListenerOne));
         verify(mockServiceTypeClientOne).stopSendAndReceive(mockListenerOne);
         verify(socketClient).stopDiscovery();
     }
@@ -135,16 +157,16 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, options);
-        callback.onSocketCreated(null /* network */);
+        runOnHandler(() -> callback.onSocketCreated(null /* network */));
         verify(mockServiceTypeClientOne).startSendAndReceive(mockListenerOne, options);
-        callback.onSocketCreated(NETWORK_1);
+        runOnHandler(() -> callback.onSocketCreated(NETWORK_1));
         verify(mockServiceTypeClientOne1).startSendAndReceive(mockListenerOne, options);
 
         final SocketCreationCallback callback2 = expectSocketCreationCallback(
                 SERVICE_TYPE_2, mockListenerTwo, options);
-        callback2.onSocketCreated(null /* network */);
+        runOnHandler(() -> callback2.onSocketCreated(null /* network */));
         verify(mockServiceTypeClientTwo).startSendAndReceive(mockListenerTwo, options);
-        callback2.onSocketCreated(NETWORK_2);
+        runOnHandler(() -> callback2.onSocketCreated(NETWORK_2));
         verify(mockServiceTypeClientTwo2).startSendAndReceive(mockListenerTwo, options);
     }
 
@@ -154,21 +176,22 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, options1);
-        callback.onSocketCreated(null /* network */);
+        runOnHandler(() -> callback.onSocketCreated(null /* network */));
         verify(mockServiceTypeClientOne).startSendAndReceive(mockListenerOne, options1);
-        callback.onSocketCreated(NETWORK_1);
+        runOnHandler(() -> callback.onSocketCreated(NETWORK_1));
         verify(mockServiceTypeClientOne1).startSendAndReceive(mockListenerOne, options1);
 
         final MdnsSearchOptions options2 =
                 MdnsSearchOptions.newBuilder().setNetwork(NETWORK_2).build();
         final SocketCreationCallback callback2 = expectSocketCreationCallback(
                 SERVICE_TYPE_2, mockListenerTwo, options2);
-        callback2.onSocketCreated(NETWORK_2);
+        runOnHandler(() -> callback2.onSocketCreated(NETWORK_2));
         verify(mockServiceTypeClientTwo2).startSendAndReceive(mockListenerTwo, options2);
 
         final MdnsPacket responseForServiceTypeOne = createMdnsPacket(SERVICE_TYPE_1);
         final int ifIndex = 1;
-        discoveryManager.onResponseReceived(responseForServiceTypeOne, ifIndex, null /* network */);
+        runOnHandler(() -> discoveryManager.onResponseReceived(
+                responseForServiceTypeOne, ifIndex, null /* network */));
         verify(mockServiceTypeClientOne).processResponse(responseForServiceTypeOne, ifIndex,
                 null /* network */);
         verify(mockServiceTypeClientOne1).processResponse(responseForServiceTypeOne, ifIndex,
@@ -177,7 +200,8 @@
                 null /* network */);
 
         final MdnsPacket responseForServiceTypeTwo = createMdnsPacket(SERVICE_TYPE_2);
-        discoveryManager.onResponseReceived(responseForServiceTypeTwo, ifIndex, NETWORK_1);
+        runOnHandler(() -> discoveryManager.onResponseReceived(
+                responseForServiceTypeTwo, ifIndex, NETWORK_1));
         verify(mockServiceTypeClientOne).processResponse(responseForServiceTypeTwo, ifIndex,
                 NETWORK_1);
         verify(mockServiceTypeClientOne1).processResponse(responseForServiceTypeTwo, ifIndex,
@@ -187,7 +211,8 @@
 
         final MdnsPacket responseForSubtype =
                 createMdnsPacket("subtype._sub._googlecast._tcp.local");
-        discoveryManager.onResponseReceived(responseForSubtype, ifIndex, NETWORK_2);
+        runOnHandler(() -> discoveryManager.onResponseReceived(
+                responseForSubtype, ifIndex, NETWORK_2));
         verify(mockServiceTypeClientOne).processResponse(responseForSubtype, ifIndex, NETWORK_2);
         verify(mockServiceTypeClientOne1, never()).processResponse(
                 responseForSubtype, ifIndex, NETWORK_2);
@@ -201,7 +226,7 @@
                 MdnsSearchOptions.newBuilder().setNetwork(NETWORK_1).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, options1);
-        callback.onSocketCreated(NETWORK_1);
+        runOnHandler(() -> callback.onSocketCreated(NETWORK_1));
         verify(mockServiceTypeClientOne1).startSendAndReceive(mockListenerOne, options1);
 
         // Create a ServiceTypeClient for SERVICE_TYPE_2 and NETWORK_2
@@ -209,26 +234,28 @@
                 MdnsSearchOptions.newBuilder().setNetwork(NETWORK_2).build();
         final SocketCreationCallback callback2 = expectSocketCreationCallback(
                 SERVICE_TYPE_2, mockListenerTwo, options2);
-        callback2.onSocketCreated(NETWORK_2);
+        runOnHandler(() -> callback2.onSocketCreated(NETWORK_2));
         verify(mockServiceTypeClientTwo2).startSendAndReceive(mockListenerTwo, options2);
 
         // Receive a response, it should be processed on both clients.
         final MdnsPacket response = createMdnsPacket(SERVICE_TYPE_1);
         final int ifIndex = 1;
-        discoveryManager.onResponseReceived(response, ifIndex, null /* network */);
+        runOnHandler(() -> discoveryManager.onResponseReceived(
+                response, ifIndex, null /* network */));
         verify(mockServiceTypeClientOne1).processResponse(response, ifIndex, null /* network */);
         verify(mockServiceTypeClientTwo2).processResponse(response, ifIndex, null /* network */);
 
         // The client for NETWORK_1 receives the callback that the NETWORK_1 has been destroyed,
         // mockServiceTypeClientOne1 should send service removed notifications and remove from the
         // list of clients.
-        callback.onAllSocketsDestroyed(NETWORK_1);
+        runOnHandler(() -> callback.onAllSocketsDestroyed(NETWORK_1));
         verify(mockServiceTypeClientOne1).notifySocketDestroyed();
 
         // Receive a response again, it should be processed only on mockServiceTypeClientTwo2.
         // Because the mockServiceTypeClientOne1 is removed from the list of clients, it is no
         // longer able to process responses.
-        discoveryManager.onResponseReceived(response, ifIndex, null /* network */);
+        runOnHandler(() -> discoveryManager.onResponseReceived(
+                response, ifIndex, null /* network */));
         verify(mockServiceTypeClientOne1, times(1))
                 .processResponse(response, ifIndex, null /* network */);
         verify(mockServiceTypeClientTwo2, times(2))
@@ -236,12 +263,13 @@
 
         // The client for NETWORK_2 receives the callback that the NETWORK_1 has been destroyed,
         // mockServiceTypeClientTwo2 shouldn't send any notifications.
-        callback2.onAllSocketsDestroyed(NETWORK_1);
+        runOnHandler(() -> callback2.onAllSocketsDestroyed(NETWORK_1));
         verify(mockServiceTypeClientTwo2, never()).notifySocketDestroyed();
 
         // Receive a response again, mockServiceTypeClientTwo2 is still in the list of clients, it's
         // still able to process responses.
-        discoveryManager.onResponseReceived(response, ifIndex, null /* network */);
+        runOnHandler(() -> discoveryManager.onResponseReceived(
+                response, ifIndex, null /* network */));
         verify(mockServiceTypeClientOne1, times(1))
                 .processResponse(response, ifIndex, null /* network */);
         verify(mockServiceTypeClientTwo2, times(3))