Wait for a socket to be created before sending packets

The required sockets may not have been created yet when
MdnsServiceTypeClient#startSendAndReceive is called on
MdnsDiscoveryManager#registerListener, so the first send would go
nowhere, and only later retries would be sent. Ideally the code
would wait for some sockets to be created before calling
startSendAndReceive.

Bug: 265787401
Bug: 264634275
Test: atest FrameworksNetTests
Change-Id: Id789d564d125c0192e742d7dd246367afdb93413
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index 67059e7..e63b2e1 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -77,16 +77,15 @@
             }
         }
         // Request the network for discovery.
-        socketClient.notifyNetworkRequested(listener, searchOptions.getNetwork());
-
-        // All listeners of the same service types shares the same MdnsServiceTypeClient.
-        MdnsServiceTypeClient serviceTypeClient = serviceTypeClients.get(serviceType);
-        if (serviceTypeClient == null) {
-            serviceTypeClient = createServiceTypeClient(serviceType);
-            serviceTypeClients.put(serviceType, serviceTypeClient);
-        }
-        // TODO(b/264634275): Wait for a socket to be created before sending packets.
-        serviceTypeClient.startSendAndReceive(listener, searchOptions);
+        socketClient.notifyNetworkRequested(listener, searchOptions.getNetwork(), network -> {
+            // All listeners of the same service types shares the same MdnsServiceTypeClient.
+            MdnsServiceTypeClient serviceTypeClient = serviceTypeClients.get(serviceType);
+            if (serviceTypeClient == null) {
+                serviceTypeClient = createServiceTypeClient(serviceType);
+                serviceTypeClients.put(serviceType, serviceTypeClient);
+            }
+            serviceTypeClient.startSendAndReceive(listener, searchOptions);
+        });
     }
 
     /**
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
index 93972d9..3d00ace 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -63,6 +63,12 @@
     }
 
     private class InterfaceSocketCallback implements MdnsSocketProvider.SocketCallback {
+        private final SocketCreationCallback mSocketCreationCallback;
+
+        InterfaceSocketCallback(SocketCreationCallback socketCreationCallback) {
+            mSocketCreationCallback = socketCreationCallback;
+        }
+
         @Override
         public void onSocketCreated(@NonNull Network network,
                 @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {
@@ -76,6 +82,7 @@
             }
             socket.addPacketHandler(handler);
             mActiveNetworkSockets.put(socket, network);
+            mSocketCreationCallback.onSocketCreated(network);
         }
 
         @Override
@@ -114,10 +121,11 @@
      * @param listener the listener for discovery.
      * @param network the target network for discovery. Null means discovery on all possible
      *                interfaces.
+     * @param socketCreationCallback the callback to notify socket creation.
      */
     @Override
     public void notifyNetworkRequested(@NonNull MdnsServiceBrowserListener listener,
-            @Nullable Network network) {
+            @Nullable Network network, @NonNull SocketCreationCallback socketCreationCallback) {
         ensureRunningOnHandlerThread(mHandler);
         InterfaceSocketCallback callback = mRequestedNetworks.get(listener);
         if (callback != null) {
@@ -125,7 +133,7 @@
         }
 
         if (DBG) Log.d(TAG, "notifyNetworkRequested: network=" + network);
-        callback = new InterfaceSocketCallback();
+        callback = new InterfaceSocketCallback(socketCreationCallback);
         mRequestedNetworks.put(listener, callback);
         mSocketProvider.requestSocket(network, callback);
     }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
index 796dc83..25d2074 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
@@ -64,7 +64,9 @@
 
     /*** Notify that the given network is requested for mdns discovery / resolution */
     default void notifyNetworkRequested(@NonNull MdnsServiceBrowserListener listener,
-            @Nullable Network network) { }
+            @Nullable Network network, @NonNull SocketCreationCallback socketCreationCallback) {
+        socketCreationCallback.onSocketCreated(network);
+    }
 
     /*** Notify that the network is unrequested */
     default void notifyNetworkUnrequested(@NonNull MdnsServiceBrowserListener listener) { }
@@ -78,4 +80,10 @@
         /*** Parse a mdns response failed */
         void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode);
     }
+
+    /*** Callback for requested socket creation  */
+    interface SocketCreationCallback {
+        /*** Notify requested socket is created */
+        void onSocketCreated(@Nullable Network network);
+    }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
index e6b8326..57e26d1 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
@@ -18,6 +18,8 @@
 
 import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
 
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
@@ -26,12 +28,14 @@
 import android.net.Network;
 import android.text.TextUtils;
 
+import com.android.server.connectivity.mdns.MdnsSocketClientBase.SocketCreationCallback;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRunner;
 
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
@@ -80,13 +84,23 @@
                 };
     }
 
+    private void verifyListenerRegistration(String serviceType, MdnsServiceBrowserListener listener,
+            MdnsServiceTypeClient client) throws IOException {
+        final ArgumentCaptor<SocketCreationCallback> callbackCaptor =
+                ArgumentCaptor.forClass(SocketCreationCallback.class);
+        discoveryManager.registerListener(serviceType, listener,
+                MdnsSearchOptions.getDefaultOptions());
+        verify(socketClient).startDiscovery();
+        verify(socketClient).notifyNetworkRequested(
+                eq(listener), any(), callbackCaptor.capture());
+        final SocketCreationCallback callback = callbackCaptor.getValue();
+        callback.onSocketCreated(null /* network */);
+        verify(client).startSendAndReceive(listener, MdnsSearchOptions.getDefaultOptions());
+    }
+
     @Test
     public void registerListener_unregisterListener() throws IOException {
-        discoveryManager.registerListener(
-                SERVICE_TYPE_1, mockListenerOne, MdnsSearchOptions.getDefaultOptions());
-        verify(socketClient).startDiscovery();
-        verify(mockServiceTypeClientOne)
-                .startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
+        verifyListenerRegistration(SERVICE_TYPE_1, mockListenerOne, mockServiceTypeClientOne);
 
         when(mockServiceTypeClientOne.stopSendAndReceive(mockListenerOne)).thenReturn(true);
         discoveryManager.unregisterListener(SERVICE_TYPE_1, mockListenerOne);
@@ -96,24 +110,14 @@
 
     @Test
     public void registerMultipleListeners() throws IOException {
-        discoveryManager.registerListener(
-                SERVICE_TYPE_1, mockListenerOne, MdnsSearchOptions.getDefaultOptions());
-        verify(socketClient).startDiscovery();
-        verify(mockServiceTypeClientOne)
-                .startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
-
-        discoveryManager.registerListener(
-                SERVICE_TYPE_2, mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
-        verify(mockServiceTypeClientTwo)
-                .startSendAndReceive(mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
+        verifyListenerRegistration(SERVICE_TYPE_1, mockListenerOne, mockServiceTypeClientOne);
+        verifyListenerRegistration(SERVICE_TYPE_2, mockListenerTwo, mockServiceTypeClientTwo);
     }
 
     @Test
-    public void onResponseReceived() {
-        discoveryManager.registerListener(
-                SERVICE_TYPE_1, mockListenerOne, MdnsSearchOptions.getDefaultOptions());
-        discoveryManager.registerListener(
-                SERVICE_TYPE_2, mockListenerTwo, MdnsSearchOptions.getDefaultOptions());
+    public void onResponseReceived() throws IOException {
+        verifyListenerRegistration(SERVICE_TYPE_1, mockListenerOne, mockServiceTypeClientOne);
+        verifyListenerRegistration(SERVICE_TYPE_2, mockListenerTwo, mockServiceTypeClientTwo);
 
         MdnsPacket responseForServiceTypeOne = createMdnsPacket(SERVICE_TYPE_1);
         final int ifIndex = 1;
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
index 1e322e4..90c43e5 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
@@ -62,6 +62,7 @@
     @Mock private MdnsInterfaceSocket mSocket;
     @Mock private MdnsServiceBrowserListener mListener;
     @Mock private MdnsSocketClientBase.Callback mCallback;
+    @Mock private MdnsSocketClientBase.SocketCreationCallback mSocketCreationCallback;
     private MdnsMultinetworkSocketClient mSocketClient;
     private Handler mHandler;
 
@@ -78,7 +79,8 @@
     private SocketCallback expectSocketCallback() {
         final ArgumentCaptor<SocketCallback> callbackCaptor =
                 ArgumentCaptor.forClass(SocketCallback.class);
-        mHandler.post(() -> mSocketClient.notifyNetworkRequested(mListener, mNetwork));
+        mHandler.post(() -> mSocketClient.notifyNetworkRequested(
+                mListener, mNetwork, mSocketCreationCallback));
         verify(mProvider, timeout(DEFAULT_TIMEOUT))
                 .requestSocket(eq(mNetwork), callbackCaptor.capture());
         return callbackCaptor.getValue();
@@ -107,6 +109,7 @@
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
         // Notify socket created
         callback.onSocketCreated(mNetwork, mSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
 
         // Send packet to IPv4 with target network and verify sending has been called.
         mSocketClient.sendMulticastPacket(ipv4Packet, mNetwork);
@@ -138,6 +141,7 @@
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
         // Notify socket created
         callback.onSocketCreated(mNetwork, mSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
 
         final ArgumentCaptor<PacketHandler> handlerCaptor =
                 ArgumentCaptor.forClass(PacketHandler.class);