Merge "Obtain the target socket directly to send packets" into main
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
index 097dbe0..2ef7368 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -34,7 +34,6 @@
 import java.net.Inet6Address;
 import java.net.InetSocketAddress;
 import java.util.List;
-import java.util.Objects;
 
 /**
  * The {@link MdnsMultinetworkSocketClient} manages the multinetwork socket for mDns
@@ -49,10 +48,9 @@
     @NonNull private final MdnsSocketProvider mSocketProvider;
     @NonNull private final SharedLog mSharedLog;
 
-    private final ArrayMap<MdnsServiceBrowserListener, InterfaceSocketCallback> mRequestedNetworks =
+    private final ArrayMap<MdnsServiceBrowserListener, InterfaceSocketCallback> mSocketRequests =
             new ArrayMap<>();
-    private final ArrayMap<MdnsInterfaceSocket, ReadPacketHandler> mSocketPacketHandlers =
-            new ArrayMap<>();
+    private final ArrayMap<SocketKey, ReadPacketHandler> mSocketPacketHandlers = new ArrayMap<>();
     private MdnsSocketClientBase.Callback mCallback = null;
     private int mReceivedPacketNumber = 0;
 
@@ -68,8 +66,7 @@
         @NonNull
         private final SocketCreationCallback mSocketCreationCallback;
         @NonNull
-        private final ArrayMap<MdnsInterfaceSocket, SocketKey> mActiveSockets =
-                new ArrayMap<>();
+        private final ArrayMap<SocketKey, MdnsInterfaceSocket> mActiveSockets = new ArrayMap<>();
 
         InterfaceSocketCallback(SocketCreationCallback socketCreationCallback) {
             mSocketCreationCallback = socketCreationCallback;
@@ -80,27 +77,27 @@
                 @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {
             // The socket may be already created by other request before, try to get the stored
             // ReadPacketHandler.
-            ReadPacketHandler handler = mSocketPacketHandlers.get(socket);
+            ReadPacketHandler handler = mSocketPacketHandlers.get(socketKey);
             if (handler == null) {
                 // First request to create this socket. Initial a ReadPacketHandler for this socket.
                 handler = new ReadPacketHandler(socketKey);
-                mSocketPacketHandlers.put(socket, handler);
+                mSocketPacketHandlers.put(socketKey, handler);
             }
             socket.addPacketHandler(handler);
-            mActiveSockets.put(socket, socketKey);
+            mActiveSockets.put(socketKey, socket);
             mSocketCreationCallback.onSocketCreated(socketKey);
         }
 
         @Override
         public void onInterfaceDestroyed(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket) {
-            notifySocketDestroyed(socket);
-            maybeCleanupPacketHandler(socket);
+            notifySocketDestroyed(socketKey);
+            maybeCleanupPacketHandler(socketKey);
         }
 
-        private void notifySocketDestroyed(@NonNull MdnsInterfaceSocket socket) {
-            final SocketKey socketKey = mActiveSockets.remove(socket);
-            if (!isSocketActive(socket)) {
+        private void notifySocketDestroyed(@NonNull SocketKey socketKey) {
+            mActiveSockets.remove(socketKey);
+            if (!isSocketActive(socketKey)) {
                 mSocketCreationCallback.onSocketDestroyed(socketKey);
             }
         }
@@ -108,35 +105,38 @@
         void onNetworkUnrequested() {
             for (int i = mActiveSockets.size() - 1; i >= 0; i--) {
                 // Iterate from the end so the socket can be removed
-                final MdnsInterfaceSocket socket = mActiveSockets.keyAt(i);
-                notifySocketDestroyed(socket);
-                maybeCleanupPacketHandler(socket);
+                final SocketKey socketKey = mActiveSockets.keyAt(i);
+                notifySocketDestroyed(socketKey);
+                maybeCleanupPacketHandler(socketKey);
             }
         }
     }
 
-    private boolean isSocketActive(@NonNull MdnsInterfaceSocket socket) {
-        for (int i = 0; i < mRequestedNetworks.size(); i++) {
-            final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
-            if (isc.mActiveSockets.containsKey(socket)) {
+    private boolean isSocketActive(@NonNull SocketKey socketKey) {
+        for (int i = 0; i < mSocketRequests.size(); i++) {
+            final InterfaceSocketCallback ifaceSocketCallback = mSocketRequests.valueAt(i);
+            if (ifaceSocketCallback.mActiveSockets.containsKey(socketKey)) {
                 return true;
             }
         }
         return false;
     }
 
-    private ArrayMap<MdnsInterfaceSocket, SocketKey> getActiveSockets() {
-        final ArrayMap<MdnsInterfaceSocket, SocketKey> sockets = new ArrayMap<>();
-        for (int i = 0; i < mRequestedNetworks.size(); i++) {
-            final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
-            sockets.putAll(isc.mActiveSockets);
+    @Nullable
+    private MdnsInterfaceSocket getTargetSocket(@NonNull SocketKey targetSocketKey) {
+        for (int i = 0; i < mSocketRequests.size(); i++) {
+            final InterfaceSocketCallback ifaceSocketCallback = mSocketRequests.valueAt(i);
+            final int index = ifaceSocketCallback.mActiveSockets.indexOfKey(targetSocketKey);
+            if (index >= 0) {
+                return ifaceSocketCallback.mActiveSockets.valueAt(index);
+            }
         }
-        return sockets;
+        return null;
     }
 
-    private void maybeCleanupPacketHandler(@NonNull MdnsInterfaceSocket socket) {
-        if (isSocketActive(socket)) return;
-        mSocketPacketHandlers.remove(socket);
+    private void maybeCleanupPacketHandler(@NonNull SocketKey socketKey) {
+        if (isSocketActive(socketKey)) return;
+        mSocketPacketHandlers.remove(socketKey);
     }
 
     private class ReadPacketHandler implements MulticastPacketReader.PacketHandler {
@@ -171,14 +171,14 @@
     public void notifyNetworkRequested(@NonNull MdnsServiceBrowserListener listener,
             @Nullable Network network, @NonNull SocketCreationCallback socketCreationCallback) {
         ensureRunningOnHandlerThread(mHandler);
-        InterfaceSocketCallback callback = mRequestedNetworks.get(listener);
+        InterfaceSocketCallback callback = mSocketRequests.get(listener);
         if (callback != null) {
             throw new IllegalArgumentException("Can not register duplicated listener");
         }
 
         if (DBG) mSharedLog.v("notifyNetworkRequested: network=" + network);
         callback = new InterfaceSocketCallback(socketCreationCallback);
-        mRequestedNetworks.put(listener, callback);
+        mSocketRequests.put(listener, callback);
         mSocketProvider.requestSocket(network, callback);
     }
 
@@ -186,14 +186,14 @@
     @Override
     public void notifyNetworkUnrequested(@NonNull MdnsServiceBrowserListener listener) {
         ensureRunningOnHandlerThread(mHandler);
-        final InterfaceSocketCallback callback = mRequestedNetworks.get(listener);
+        final InterfaceSocketCallback callback = mSocketRequests.get(listener);
         if (callback == null) {
             mSharedLog.e("Can not be unrequested with unknown listener=" + listener);
             return;
         }
         callback.onNetworkUnrequested();
-        // onNetworkUnrequested does cleanups based on mRequestedNetworks, only remove afterwards
-        mRequestedNetworks.remove(listener);
+        // onNetworkUnrequested does cleanups based on mSocketRequests, only remove afterwards
+        mSocketRequests.remove(listener);
         mSocketProvider.unrequestSocket(callback);
     }
 
@@ -209,42 +209,28 @@
 
     private void sendMdnsPacket(@NonNull DatagramPacket packet, @NonNull SocketKey targetSocketKey,
             boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+        final MdnsInterfaceSocket socket = getTargetSocket(targetSocketKey);
+        if (socket == null) {
+            mSharedLog.e("No socket matches targetSocketKey=" + targetSocketKey);
+            return;
+        }
+
         final boolean isIpv6 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
                 instanceof Inet6Address;
         final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
                 instanceof Inet4Address;
-        final ArrayMap<MdnsInterfaceSocket, SocketKey> activeSockets = getActiveSockets();
-        boolean shouldQueryIpv6 = !onlyUseIpv6OnIpv6OnlyNetworks || isIpv6OnlySockets(
-                activeSockets, targetSocketKey);
-        for (int i = 0; i < activeSockets.size(); i++) {
-            final MdnsInterfaceSocket socket = activeSockets.keyAt(i);
-            final SocketKey socketKey = activeSockets.valueAt(i);
-            // Check ip capability and network before sending packet
-            if (((isIpv6 && socket.hasJoinedIpv6() && shouldQueryIpv6)
-                    || (isIpv4 && socket.hasJoinedIpv4()))
-                    && Objects.equals(socketKey, targetSocketKey)) {
-                try {
-                    socket.send(packet);
-                } catch (IOException e) {
-                    mSharedLog.e("Failed to send a mDNS packet.", e);
-                }
+        final boolean shouldQueryIpv6 = !onlyUseIpv6OnIpv6OnlyNetworks || !socket.hasJoinedIpv4();
+        // Check ip capability and network before sending packet
+        if ((isIpv6 && socket.hasJoinedIpv6() && shouldQueryIpv6)
+                || (isIpv4 && socket.hasJoinedIpv4())) {
+            try {
+                socket.send(packet);
+            } catch (IOException e) {
+                mSharedLog.e("Failed to send a mDNS packet.", e);
             }
         }
     }
 
-    private boolean isIpv6OnlySockets(
-            @NonNull ArrayMap<MdnsInterfaceSocket, SocketKey> activeSockets,
-            @NonNull SocketKey targetSocketKey) {
-        for (int i = 0; i < activeSockets.size(); i++) {
-            final MdnsInterfaceSocket socket = activeSockets.keyAt(i);
-            final SocketKey socketKey = activeSockets.valueAt(i);
-            if (Objects.equals(socketKey, targetSocketKey) && socket.hasJoinedIpv4()) {
-                return false;
-            }
-        }
-        return true;
-    }
-
     private void processResponsePacket(byte[] recvbuf, int length, @NonNull SocketKey socketKey) {
         int packetNumber = ++mReceivedPacketNumber;