Do not send socket destroyed on unregistration
When a SocketCallback is unregistered from MdnsSocketProvider, do not
send socket destroyed callbacks. Callers may not expect getting
callbacks after unregistration, and the current callbacks are also
broken when an unrequested socket is still in use by another requester.
MdnsAdvertiser already does not depend on getting this callback, as it
only unregisters the SocketCallback after it is done using the socket.
This change fixes MdnsMultinetworkSocketClient to destroy the socket by
itself when unrequesting.
Bug: 276177548
Test: atest
Change-Id: If95f833e293f3aab91128aab1c9852ebfd41995d
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
index 6414453..ad8cb64 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -34,7 +34,6 @@
import java.net.Inet6Address;
import java.net.InetSocketAddress;
import java.util.List;
-import java.util.Map;
/**
* The {@link MdnsMultinetworkSocketClient} manages the multinetwork socket for mDns
@@ -48,9 +47,8 @@
@NonNull private final Handler mHandler;
@NonNull private final MdnsSocketProvider mSocketProvider;
- private final Map<MdnsServiceBrowserListener, InterfaceSocketCallback> mRequestedNetworks =
+ private final ArrayMap<MdnsServiceBrowserListener, InterfaceSocketCallback> mRequestedNetworks =
new ArrayMap<>();
- private final ArrayMap<MdnsInterfaceSocket, Network> mActiveNetworkSockets = new ArrayMap<>();
private final ArrayMap<MdnsInterfaceSocket, ReadPacketHandler> mSocketPacketHandlers =
new ArrayMap<>();
private MdnsSocketClientBase.Callback mCallback = null;
@@ -63,7 +61,11 @@
}
private class InterfaceSocketCallback implements MdnsSocketProvider.SocketCallback {
+ @NonNull
private final SocketCreationCallback mSocketCreationCallback;
+ @NonNull
+ private final ArrayMap<MdnsInterfaceSocket, Network> mActiveNetworkSockets =
+ new ArrayMap<>();
InterfaceSocketCallback(SocketCreationCallback socketCreationCallback) {
mSocketCreationCallback = socketCreationCallback;
@@ -88,10 +90,47 @@
@Override
public void onInterfaceDestroyed(@Nullable Network network,
@NonNull MdnsInterfaceSocket socket) {
- mSocketPacketHandlers.remove(socket);
- mActiveNetworkSockets.remove(socket);
+ notifySocketDestroyed(socket);
+ maybeCleanupPacketHandler(socket);
+ }
+
+ private void notifySocketDestroyed(@NonNull MdnsInterfaceSocket socket) {
+ final Network network = mActiveNetworkSockets.remove(socket);
mSocketCreationCallback.onSocketDestroyed(network);
}
+
+ void onNetworkUnrequested() {
+ for (int i = mActiveNetworkSockets.size() - 1; i >= 0; i--) {
+ // Iterate from the end so the socket can be removed
+ final MdnsInterfaceSocket socket = mActiveNetworkSockets.keyAt(i);
+ notifySocketDestroyed(socket);
+ maybeCleanupPacketHandler(socket);
+ }
+ }
+ }
+
+ private boolean isSocketActive(@NonNull MdnsInterfaceSocket socket) {
+ for (int i = 0; i < mRequestedNetworks.size(); i++) {
+ final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
+ if (isc.mActiveNetworkSockets.containsKey(socket)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ private ArrayMap<MdnsInterfaceSocket, Network> getActiveSockets() {
+ final ArrayMap<MdnsInterfaceSocket, Network> sockets = new ArrayMap<>();
+ for (int i = 0; i < mRequestedNetworks.size(); i++) {
+ final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
+ sockets.putAll(isc.mActiveNetworkSockets);
+ }
+ return sockets;
+ }
+
+ private void maybeCleanupPacketHandler(@NonNull MdnsInterfaceSocket socket) {
+ if (isSocketActive(socket)) return;
+ mSocketPacketHandlers.remove(socket);
}
private class ReadPacketHandler implements MulticastPacketReader.PacketHandler {
@@ -149,6 +188,7 @@
return;
}
mSocketProvider.unrequestSocket(callback);
+ callback.onNetworkUnrequested();
}
private void sendMdnsPacket(@NonNull DatagramPacket packet, @Nullable Network targetNetwork) {
@@ -156,9 +196,10 @@
instanceof Inet6Address;
final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
instanceof Inet4Address;
- for (int i = 0; i < mActiveNetworkSockets.size(); i++) {
- final MdnsInterfaceSocket socket = mActiveNetworkSockets.keyAt(i);
- final Network network = mActiveNetworkSockets.valueAt(i);
+ final ArrayMap<MdnsInterfaceSocket, Network> activeSockets = getActiveSockets();
+ for (int i = 0; i < activeSockets.size(); i++) {
+ final MdnsInterfaceSocket socket = activeSockets.keyAt(i);
+ final Network network = activeSockets.valueAt(i);
// Check ip capability and network before sending packet
if (((isIpv6 && socket.hasJoinedIpv6()) || (isIpv4 && socket.hasJoinedIpv4()))
&& isNetworkMatched(targetNetwork, network)) {
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
index ca61d34..e245ff1 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
@@ -599,8 +599,6 @@
if (matchRequestedNetwork(network)) continue;
final SocketInfo info = mNetworkSockets.removeAt(i);
info.mSocket.destroy();
- // Still notify to unrequester for socket destroy.
- cb.onInterfaceDestroyed(network, info.mSocket);
mSharedLog.log("Remove socket on net:" + network + " after unrequestSocket");
}
@@ -610,8 +608,6 @@
for (int i = mTetherInterfaceSockets.size() - 1; i >= 0; i--) {
final SocketInfo info = mTetherInterfaceSockets.valueAt(i);
info.mSocket.destroy();
- // Still notify to unrequester for socket destroy.
- cb.onInterfaceDestroyed(null /* network */, info.mSocket);
mSharedLog.log("Remove socket on ifName:" + mTetherInterfaceSockets.keyAt(i)
+ " after unrequestSocket");
}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
index 90c43e5..d5d2902 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
@@ -24,7 +24,9 @@
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout;
+import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import android.net.InetAddresses;
@@ -77,12 +79,17 @@
}
private SocketCallback expectSocketCallback() {
+ return expectSocketCallback(mListener, mNetwork);
+ }
+
+ private SocketCallback expectSocketCallback(MdnsServiceBrowserListener listener,
+ Network requestedNetwork) {
final ArgumentCaptor<SocketCallback> callbackCaptor =
ArgumentCaptor.forClass(SocketCallback.class);
mHandler.post(() -> mSocketClient.notifyNetworkRequested(
- mListener, mNetwork, mSocketCreationCallback));
+ listener, requestedNetwork, mSocketCreationCallback));
verify(mProvider, timeout(DEFAULT_TIMEOUT))
- .requestSocket(eq(mNetwork), callbackCaptor.capture());
+ .requestSocket(eq(requestedNetwork), callbackCaptor.capture());
return callbackCaptor.getValue();
}
@@ -169,4 +176,83 @@
new String[] { "Android", "local" } /* serviceHost */)
), response.answers);
}
+
+ @Test
+ public void testSocketRemovedAfterNetworkUnrequested() throws IOException {
+ // Request a socket
+ final SocketCallback callback = expectSocketCallback(mListener, mNetwork);
+ final DatagramPacket ipv4Packet = new DatagramPacket(BUFFER, 0 /* offset */, BUFFER.length,
+ InetAddresses.parseNumericAddress("192.0.2.1"), 0 /* port */);
+ doReturn(true).when(mSocket).hasJoinedIpv4();
+ doReturn(true).when(mSocket).hasJoinedIpv6();
+ doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
+ // Notify socket created
+ callback.onSocketCreated(mNetwork, mSocket, List.of());
+ verify(mSocketCreationCallback).onSocketCreated(mNetwork);
+
+ // Send IPv4 packet and verify sending has been called.
+ mSocketClient.sendMulticastPacket(ipv4Packet);
+ HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+ verify(mSocket).send(ipv4Packet);
+
+ // Request another socket with null network
+ final MdnsServiceBrowserListener listener2 = mock(MdnsServiceBrowserListener.class);
+ final Network network2 = mock(Network.class);
+ final MdnsInterfaceSocket socket2 = mock(MdnsInterfaceSocket.class);
+ final SocketCallback callback2 = expectSocketCallback(listener2, null);
+ doReturn(true).when(socket2).hasJoinedIpv4();
+ doReturn(true).when(socket2).hasJoinedIpv6();
+ doReturn(createEmptyNetworkInterface()).when(socket2).getInterface();
+ // Notify socket created for two networks.
+ callback2.onSocketCreated(mNetwork, mSocket, List.of());
+ callback2.onSocketCreated(network2, socket2, List.of());
+ verify(mSocketCreationCallback, times(2)).onSocketCreated(mNetwork);
+ verify(mSocketCreationCallback).onSocketCreated(network2);
+
+ // Send IPv4 packet and verify sending to two sockets.
+ mSocketClient.sendMulticastPacket(ipv4Packet);
+ HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+ verify(mSocket, times(2)).send(ipv4Packet);
+ verify(socket2).send(ipv4Packet);
+
+ // Unrequest another socket
+ mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(listener2));
+ verify(mProvider, timeout(DEFAULT_TIMEOUT)).unrequestSocket(callback2);
+
+ // Send IPv4 packet again and verify only sending via mSocket
+ mSocketClient.sendMulticastPacket(ipv4Packet);
+ HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+ verify(mSocket, times(3)).send(ipv4Packet);
+ verify(socket2).send(ipv4Packet);
+
+ // Unrequest remaining socket
+ mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener));
+ verify(mProvider, timeout(DEFAULT_TIMEOUT)).unrequestSocket(callback);
+
+ // Send IPv4 packet and verify no more sending.
+ mSocketClient.sendMulticastPacket(ipv4Packet);
+ HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+ verify(mSocket, times(3)).send(ipv4Packet);
+ verify(socket2).send(ipv4Packet);
+ }
+
+ @Test
+ public void testNotifyNetworkUnrequested_SocketsOnNullNetwork() {
+ final MdnsInterfaceSocket otherSocket = mock(MdnsInterfaceSocket.class);
+ final SocketCallback callback = expectSocketCallback(
+ mListener, null /* requestedNetwork */);
+ doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
+ doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
+
+ callback.onSocketCreated(null /* network */, mSocket, List.of());
+ verify(mSocketCreationCallback).onSocketCreated(null);
+ callback.onSocketCreated(null /* network */, otherSocket, List.of());
+ verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
+
+ mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener));
+ HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+
+ verify(mProvider).unrequestSocket(callback);
+ verify(mSocketCreationCallback, times(2)).onSocketDestroyed(null /* network */);
+ }
}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
index 744ec84..4b87556 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
@@ -349,8 +349,8 @@
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
testCallback1.expectedNoCallback();
testCallback2.expectedNoCallback();
- // Expect the socket destroy for tethered interface.
- testCallback3.expectedInterfaceDestroyedForNetwork(null /* network */);
+ // There was still a tethered interface, but no callback should be sent once unregistered
+ testCallback3.expectedNoCallback();
}
private RtNetlinkAddressMessage createNetworkAddressUpdateNetLink(
@@ -528,7 +528,8 @@
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
mHandler.post(()-> mSocketProvider.unrequestSocket(testCallback));
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
- testCallback.expectedInterfaceDestroyedForNetwork(TEST_NETWORK);
+ // No callback sent when unregistered
+ testCallback.expectedNoCallback();
verify(mCm, times(1)).unregisterNetworkCallback(any(NetworkCallback.class));
verify(mTm, times(1)).unregisterTetheringEventCallback(any());