Merge "Add isFeatureEnabled API for CaptivePortalLogin experiment namespace." into main
diff --git a/TEST_MAPPING b/TEST_MAPPING
index ab3ed66..d8d4c21 100644
--- a/TEST_MAPPING
+++ b/TEST_MAPPING
@@ -246,6 +246,9 @@
},
{
"exclude-annotation": "com.android.testutils.DnsResolverModuleTest"
+ },
+ {
+ "exclude-annotation": "com.android.testutils.NetworkStackModuleTest"
}
]
},
diff --git a/framework/src/android/net/NetworkCapabilities.java b/framework/src/android/net/NetworkCapabilities.java
index 84a0d29..85b1dac 100644
--- a/framework/src/android/net/NetworkCapabilities.java
+++ b/framework/src/android/net/NetworkCapabilities.java
@@ -1775,8 +1775,7 @@
// use the same specifier, TelephonyNetworkSpecifier.
&& mTransportTypes != (1L << TRANSPORT_TEST)
&& Long.bitCount(mTransportTypes & ~(1L << TRANSPORT_TEST)) != 1
- && (mTransportTypes & ~(1L << TRANSPORT_TEST))
- != (1 << TRANSPORT_CELLULAR | 1 << TRANSPORT_SATELLITE)) {
+ && !specifierAcceptableForMultipleTransports(mTransportTypes)) {
throw new IllegalStateException("Must have a single non-test transport specified to "
+ "use setNetworkSpecifier");
}
@@ -1786,6 +1785,12 @@
return this;
}
+ /**
+ * Returns whether the non-test transports in {@code transportTypes} are exactly cellular and
+ * satellite. Those share one NetworkSpecifier type (TelephonyNetworkSpecifier), so a specifier
+ * may be set even though more than one non-test transport is present.
+ */
+ private static boolean specifierAcceptableForMultipleTransports(long transportTypes) {
+ // Use long shifts, consistent with the transport-type masks used by the caller.
+ return (transportTypes & ~(1L << TRANSPORT_TEST))
+ == (1L << TRANSPORT_CELLULAR | 1L << TRANSPORT_SATELLITE);
+ }
+
/**
* Sets the optional transport specific information.
*
diff --git a/netd/BpfHandler.cpp b/netd/BpfHandler.cpp
index e6fc825..0d75c05 100644
--- a/netd/BpfHandler.cpp
+++ b/netd/BpfHandler.cpp
@@ -179,22 +179,24 @@
}
Status BpfHandler::init(const char* cg2_path) {
- // Make sure BPF programs are loaded before doing anything
- ALOGI("Waiting for BPF programs");
+ if (base::GetProperty("bpf.progs_loaded", "") != "1") {
+ // Make sure BPF programs are loaded before doing anything
+ ALOGI("Waiting for BPF programs");
- if (true || !modules::sdklevel::IsAtLeastV()) {
- waitForNetProgsLoaded();
- ALOGI("Networking BPF programs are loaded");
+ if (true || !modules::sdklevel::IsAtLeastV()) {
+ waitForNetProgsLoaded();
+ ALOGI("Networking BPF programs are loaded");
- if (!base::SetProperty("ctl.start", "mdnsd_loadbpf")) {
- ALOGE("Failed to set property ctl.start=mdnsd_loadbpf, see dmesg for reason.");
- abort();
+ if (!base::SetProperty("ctl.start", "mdnsd_loadbpf")) {
+ ALOGE("Failed to set property ctl.start=mdnsd_loadbpf, see dmesg for reason.");
+ abort();
+ }
+
+ ALOGI("Waiting for remaining BPF programs");
}
- ALOGI("Waiting for remaining BPF programs");
+ android::bpf::waitForProgsLoaded();
}
-
- android::bpf::waitForProgsLoaded();
ALOGI("BPF programs are loaded");
RETURN_IF_NOT_OK(initPrograms(cg2_path));
diff --git a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
index e61555a..54943c7 100644
--- a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
+++ b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
@@ -84,6 +84,7 @@
private final byte[] packetCreationBuffer = new byte[1500]; // TODO: use interface MTU
@NonNull
private final List<MdnsResponse> existingServices;
+ private final boolean isQueryWithKnownAnswer;
EnqueueMdnsQueryCallable(
@NonNull MdnsSocketClientBase requestSender,
@@ -98,7 +99,8 @@
@NonNull MdnsUtils.Clock clock,
@NonNull SharedLog sharedLog,
@NonNull MdnsServiceTypeClient.Dependencies dependencies,
- @NonNull Collection<MdnsResponse> existingServices) {
+ @NonNull Collection<MdnsResponse> existingServices,
+ boolean isQueryWithKnownAnswer) {
weakRequestSender = new WeakReference<>(requestSender);
serviceTypeLabels = TextUtils.split(serviceType, "\\.");
this.subtypes = new ArrayList<>(subtypes);
@@ -112,6 +114,7 @@
this.sharedLog = sharedLog;
this.dependencies = dependencies;
this.existingServices = new ArrayList<>(existingServices);
+ this.isQueryWithKnownAnswer = isQueryWithKnownAnswer;
}
/**
@@ -226,27 +229,27 @@
private void sendPacket(MdnsSocketClientBase requestSender, InetSocketAddress address,
MdnsPacket mdnsPacket) throws IOException {
- final DatagramPacket packet = dependencies.getDatagramPacketFromMdnsPacket(
- packetCreationBuffer, mdnsPacket, address);
+ // With the known-answer feature enabled, one query may be split into several
+ // datagrams; the whole list is handed to the socket client in a single call below.
+ final List<DatagramPacket> packets = dependencies.getDatagramPacketsFromMdnsPacket(
+ packetCreationBuffer, mdnsPacket, address, isQueryWithKnownAnswer);
if (expectUnicastResponse) {
// MdnsMultinetworkSocketClient is only available on T+
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU
&& requestSender instanceof MdnsMultinetworkSocketClient) {
((MdnsMultinetworkSocketClient) requestSender).sendPacketRequestingUnicastResponse(
- packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
+ packets, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
} else {
requestSender.sendPacketRequestingUnicastResponse(
- packet, onlyUseIpv6OnIpv6OnlyNetworks);
+ packets, onlyUseIpv6OnIpv6OnlyNetworks);
}
} else {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU
&& requestSender instanceof MdnsMultinetworkSocketClient) {
((MdnsMultinetworkSocketClient) requestSender)
.sendPacketRequestingMulticastResponse(
- packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
+ packets, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
} else {
requestSender.sendPacketRequestingMulticastResponse(
- packet, onlyUseIpv6OnIpv6OnlyNetworks);
+ packets, onlyUseIpv6OnIpv6OnlyNetworks);
}
}
}
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
index 869ac9b..fcfb15f 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -27,6 +27,7 @@
import android.os.Handler;
import android.os.Looper;
import android.util.ArrayMap;
+import android.util.Log;
import com.android.net.module.util.SharedLog;
@@ -213,24 +214,30 @@
return true;
}
- private void sendMdnsPacket(@NonNull DatagramPacket packet, @NonNull SocketKey targetSocketKey,
- boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+ private void sendMdnsPackets(@NonNull List<DatagramPacket> packets,
+ @NonNull SocketKey targetSocketKey, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
final MdnsInterfaceSocket socket = getTargetSocket(targetSocketKey);
if (socket == null) {
mSharedLog.e("No socket matches targetSocketKey=" + targetSocketKey);
return;
}
+ if (packets.isEmpty()) {
+ Log.wtf(TAG, "No mDns packets to send");
+ return;
+ }
- final boolean isIpv6 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
- instanceof Inet6Address;
- final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
- instanceof Inet4Address;
+ final boolean isIpv6 = ((InetSocketAddress) packets.get(0).getSocketAddress())
+ .getAddress() instanceof Inet6Address;
+ final boolean isIpv4 = ((InetSocketAddress) packets.get(0).getSocketAddress())
+ .getAddress() instanceof Inet4Address;
final boolean shouldQueryIpv6 = !onlyUseIpv6OnIpv6OnlyNetworks || !socket.hasJoinedIpv4();
// Check ip capability and network before sending packet
if ((isIpv6 && socket.hasJoinedIpv6() && shouldQueryIpv6)
|| (isIpv4 && socket.hasJoinedIpv4())) {
try {
- socket.send(packet);
+ for (DatagramPacket packet : packets) {
+ socket.send(packet);
+ }
} catch (IOException e) {
mSharedLog.e("Failed to send a mDNS packet.", e);
}
@@ -259,34 +266,34 @@
}
/**
- * Send a mDNS request packet via given socket key that asks for multicast response.
+ * Send mDNS request packets via given socket key that asks for multicast response.
*/
- public void sendPacketRequestingMulticastResponse(@NonNull DatagramPacket packet,
+ public void sendPacketRequestingMulticastResponse(@NonNull List<DatagramPacket> packets,
@NonNull SocketKey socketKey, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
- mHandler.post(() -> sendMdnsPacket(packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks));
+ mHandler.post(() -> sendMdnsPackets(packets, socketKey, onlyUseIpv6OnIpv6OnlyNetworks));
}
@Override
public void sendPacketRequestingMulticastResponse(
- @NonNull DatagramPacket packet, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+ @NonNull List<DatagramPacket> packets, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
- throw new UnsupportedOperationException("This socket client need to specify the socket to"
- + "send packet");
+ // This client only sends through an explicit socket: callers must use the SocketKey overload.
+ throw new UnsupportedOperationException("This socket client needs to specify the socket to "
+ + "send packet");
}
/**
- * Send a mDNS request packet via given socket key that asks for unicast response.
+ * Send mDNS request packets via given socket key that asks for unicast response.
*
* <p>The socket client may use a null network to identify some or all interfaces, in which case
* passing null sends the packet to these.
*/
- public void sendPacketRequestingUnicastResponse(@NonNull DatagramPacket packet,
+ public void sendPacketRequestingUnicastResponse(@NonNull List<DatagramPacket> packets,
@NonNull SocketKey socketKey, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
- mHandler.post(() -> sendMdnsPacket(packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks));
+ mHandler.post(() -> sendMdnsPackets(packets, socketKey, onlyUseIpv6OnIpv6OnlyNetworks));
}
@Override
public void sendPacketRequestingUnicastResponse(
- @NonNull DatagramPacket packet, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+ @NonNull List<DatagramPacket> packets, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
- throw new UnsupportedOperationException("This socket client need to specify the socket to"
- + "send packet");
+ // This client only sends through an explicit socket: callers must use the SocketKey overload.
+ throw new UnsupportedOperationException("This socket client needs to specify the socket to "
+ + "send packet");
}
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
index 4b43989..1f9f42b 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
@@ -23,7 +23,8 @@
import android.os.SystemClock;
import android.text.TextUtils;
-import com.android.internal.annotations.VisibleForTesting;
+import androidx.annotation.VisibleForTesting;
+
import com.android.server.connectivity.mdns.util.MdnsUtils;
import java.io.IOException;
@@ -231,7 +232,7 @@
* @param writer The writer to use.
* @param now The current system time. This is used when writing the updated TTL.
*/
- @VisibleForTesting
+ @VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
public final void write(MdnsPacketWriter writer, long now) throws IOException {
writeHeaderFields(writer);
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index bfcd0b4..b3bdbe0 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -30,7 +30,8 @@
import android.util.ArrayMap;
import android.util.Pair;
-import com.android.internal.annotations.VisibleForTesting;
+import androidx.annotation.VisibleForTesting;
+
import com.android.net.module.util.CollectionUtils;
import com.android.net.module.util.SharedLog;
import com.android.server.connectivity.mdns.util.MdnsUtils;
@@ -195,7 +196,7 @@
/**
* Dependencies of MdnsServiceTypeClient, for injection in tests.
*/
- @VisibleForTesting
+ @VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
public static class Dependencies {
/**
* @see Handler#sendMessageDelayed(Message, long)
@@ -227,13 +228,22 @@
}
/**
- * Generate a DatagramPacket from given MdnsPacket and InetSocketAddress.
+ * Builds the DatagramPackets that carry {@code packet} to {@code address}.
+ *
+ * <p>When {@code isQueryWithKnownAnswer} is true the packets come from
+ * {@link MdnsUtils#createQueryDatagramPackets}, and a query too large for one
+ * DatagramPacket is split across several. Otherwise exactly one DatagramPacket is returned.
*/
- public DatagramPacket getDatagramPacketFromMdnsPacket(@NonNull byte[] packetCreationBuffer,
- @NonNull MdnsPacket packet, @NonNull InetSocketAddress address) throws IOException {
- final byte[] queryBuffer =
- MdnsUtils.createRawDnsPacket(packetCreationBuffer, packet);
- return new DatagramPacket(queryBuffer, 0, queryBuffer.length, address);
+ public List<DatagramPacket> getDatagramPacketsFromMdnsPacket(
+ @NonNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet,
+ @NonNull InetSocketAddress address, boolean isQueryWithKnownAnswer)
+ throws IOException {
+ if (!isQueryWithKnownAnswer) {
+ // Known-answer support off: serialize once and wrap the bytes in a single datagram.
+ final byte[] rawQuery = MdnsUtils.createRawDnsPacket(packetCreationBuffer, packet);
+ return List.of(new DatagramPacket(rawQuery, 0, rawQuery.length, address));
+ }
+ return MdnsUtils.createQueryDatagramPackets(packetCreationBuffer, packet, address);
}
}
@@ -742,7 +752,8 @@
clock,
sharedLog,
dependencies,
- existingServices)
+ existingServices,
+ featureFlags.isQueryWithKnownAnswerEnabled())
.call();
} catch (RuntimeException e) {
sharedLog.e(String.format("Failed to run EnqueueMdnsQueryCallable for subtype: %s",
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
index 7b71e43..9cfcba1 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
@@ -25,6 +25,7 @@
import android.net.wifi.WifiManager.MulticastLock;
import android.os.SystemClock;
import android.text.format.DateUtils;
+import android.util.Log;
import com.android.internal.annotations.VisibleForTesting;
import com.android.net.module.util.SharedLog;
@@ -206,18 +207,18 @@
}
@Override
- public void sendPacketRequestingMulticastResponse(@NonNull DatagramPacket packet,
+ public void sendPacketRequestingMulticastResponse(@NonNull List<DatagramPacket> packets,
boolean onlyUseIpv6OnIpv6OnlyNetworks) {
- sendMdnsPacket(packet, multicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
+ sendMdnsPackets(packets, multicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
}
@Override
- public void sendPacketRequestingUnicastResponse(@NonNull DatagramPacket packet,
+ public void sendPacketRequestingUnicastResponse(@NonNull List<DatagramPacket> packets,
boolean onlyUseIpv6OnIpv6OnlyNetworks) {
- if (useSeparateSocketForUnicast) {
- sendMdnsPacket(packet, unicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
- } else {
- sendMdnsPacket(packet, multicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
- }
+ // The unicast queue is used only when useSeparateSocketForUnicast is set; otherwise
+ // unicast-response queries share the multicast queue.
+ final Queue<DatagramPacket> targetQueue =
+ useSeparateSocketForUnicast ? unicastPacketQueue : multicastPacketQueue;
+ sendMdnsPackets(packets, targetQueue, onlyUseIpv6OnIpv6OnlyNetworks);
}
@@ -238,17 +239,21 @@
return false;
}
- private void sendMdnsPacket(DatagramPacket packet, Queue<DatagramPacket> packetQueueToUse,
- boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+ private void sendMdnsPackets(List<DatagramPacket> packets,
+ Queue<DatagramPacket> packetQueueToUse, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
if (shouldStopSocketLoop && !MdnsConfigs.allowAddMdnsPacketAfterDiscoveryStops()) {
sharedLog.w("sendMdnsPacket() is called after discovery already stopped");
return;
}
+ if (packets.isEmpty()) {
+ Log.wtf(TAG, "No mDns packets to send");
+ return;
+ }
- final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
- instanceof Inet4Address;
- final boolean isIpv6 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
- instanceof Inet6Address;
+ final boolean isIpv4 = ((InetSocketAddress) packets.get(0).getSocketAddress())
+ .getAddress() instanceof Inet4Address;
+ final boolean isIpv6 = ((InetSocketAddress) packets.get(0).getSocketAddress())
+ .getAddress() instanceof Inet6Address;
final boolean ipv6Only = multicastSocket != null && multicastSocket.isOnIPv6OnlyNetwork();
if (isIpv4 && ipv6Only) {
return;
@@ -258,10 +263,11 @@
}
synchronized (packetQueueToUse) {
- while (packetQueueToUse.size() >= MdnsConfigs.mdnsPacketQueueMaxSize()) {
+ // Drop packets from the head of the queue to make room for the whole batch. Stop once the
+ // queue is empty: Queue#remove() throws NoSuchElementException on an empty queue, which
+ // would otherwise happen whenever a single batch is larger than the queue limit. Such a
+ // batch is still enqueued in full rather than truncated.
+ while (!packetQueueToUse.isEmpty()
+ && (packetQueueToUse.size() + packets.size())
+ > MdnsConfigs.mdnsPacketQueueMaxSize()) {
packetQueueToUse.remove();
}
- packetQueueToUse.add(packet);
+ packetQueueToUse.addAll(packets);
}
triggerSendThread();
}
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
index b6000f0..b1a543a 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
@@ -23,6 +23,7 @@
import java.io.IOException;
import java.net.DatagramPacket;
+import java.util.List;
/**
* Base class for multicast socket client.
@@ -40,15 +41,15 @@
void setCallback(@Nullable Callback callback);
/**
- * Send a mDNS request packet via given network that asks for multicast response.
+ * Send mDNS request packets via given network that asks for multicast response.
+ *
+ * <p>{@code packets} should be non-empty and share one destination address: the
+ * implementations in this module check the address family of the first packet only.
*/
- void sendPacketRequestingMulticastResponse(@NonNull DatagramPacket packet,
+ void sendPacketRequestingMulticastResponse(@NonNull List<DatagramPacket> packets,
boolean onlyUseIpv6OnIpv6OnlyNetworks);
/**
- * Send a mDNS request packet via given network that asks for unicast response.
+ * Send mDNS request packets via given network that asks for unicast response.
+ *
+ * <p>{@code packets} should be non-empty and share one destination address: the
+ * implementations in this module check the address family of the first packet only.
*/
- void sendPacketRequestingUnicastResponse(@NonNull DatagramPacket packet,
+ void sendPacketRequestingUnicastResponse(@NonNull List<DatagramPacket> packets,
boolean onlyUseIpv6OnIpv6OnlyNetworks);
/*** Notify that the given network is requested for mdns discovery / resolution */
diff --git a/service/src/com/android/server/BpfNetMaps.java b/service/src/com/android/server/BpfNetMaps.java
index fc6d8c4..42c1628 100644
--- a/service/src/com/android/server/BpfNetMaps.java
+++ b/service/src/com/android/server/BpfNetMaps.java
@@ -918,6 +918,25 @@
}
}
+ /**
+ * Return whether networking is blocked for the given uid on a network with the given
+ * meteredness, as computed by {@code BpfNetMapsUtils.isUidNetworkingBlocked} from the
+ * configuration, uid owner and data saver maps.
+ *
+ * Note that {@link #getDataSaverEnabled()} has a latency before V.
+ *
+ * @param uid The target uid.
+ * @param isNetworkMetered Whether the target network is metered.
+ *
+ * @return True if the network is blocked. Otherwise, false.
+ * @throws ServiceSpecificException if the read fails.
+ *
+ * @hide
+ */
+ @RequiresApi(Build.VERSION_CODES.TIRAMISU)
+ public boolean isUidNetworkingBlocked(final int uid, boolean isNetworkMetered) {
+ return BpfNetMapsUtils.isUidNetworkingBlocked(uid, isNetworkMetered,
+ sConfigurationMap, sUidOwnerMap, sDataSaverEnabledMap);
+ }
+
/** Register callback for statsd to pull atom. */
@RequiresApi(Build.VERSION_CODES.TIRAMISU)
public void setPullAtomCallback(final Context context) {
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 123ad8f..005d617 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -2235,7 +2235,11 @@
final long ident = Binder.clearCallingIdentity();
try {
final boolean metered = nc == null ? true : nc.isMetered();
- return mPolicyManager.isUidNetworkingBlocked(uid, metered);
+ if (mDeps.isAtLeastV()) {
+ return mBpfNetMaps.isUidNetworkingBlocked(uid, metered);
+ } else {
+ return mPolicyManager.isUidNetworkingBlocked(uid, metered);
+ }
} finally {
Binder.restoreCallingIdentity(ident);
}
diff --git a/service/src/com/android/server/connectivity/SatelliteAccessController.java b/service/src/com/android/server/connectivity/SatelliteAccessController.java
index b53abce..2cdc932 100644
--- a/service/src/com/android/server/connectivity/SatelliteAccessController.java
+++ b/service/src/com/android/server/connectivity/SatelliteAccessController.java
@@ -20,7 +20,10 @@
import android.annotation.NonNull;
import android.app.role.OnRoleHoldersChangedListener;
import android.app.role.RoleManager;
+import android.content.BroadcastReceiver;
import android.content.Context;
+import android.content.Intent;
+import android.content.IntentFilter;
import android.content.pm.ApplicationInfo;
import android.content.pm.PackageManager;
import android.os.Handler;
@@ -49,7 +52,6 @@
private final Context mContext;
private final Dependencies mDeps;
private final DefaultMessageRoleListener mDefaultMessageRoleListener;
- private final UserManager mUserManager;
private final Consumer<Set<Integer>> mCallback;
private final Handler mConnectivityServiceHandler;
@@ -114,7 +116,6 @@
@NonNull final Handler connectivityServiceInternalHandler) {
mContext = c;
mDeps = deps;
- mUserManager = mContext.getSystemService(UserManager.class);
mDefaultMessageRoleListener = new DefaultMessageRoleListener();
mCallback = callback;
mConnectivityServiceHandler = connectivityServiceInternalHandler;
@@ -165,9 +166,6 @@
}
// on Role sms change triggered by OnRoleHoldersChangedListener()
- // TODO(b/326373613): using UserLifecycleListener, callback to be received when user removed for
- // user delete scenario. This to be used to update uid list and ML Layer request can also be
- // updated.
private void onRoleSmsChanged(@NonNull UserHandle userHandle) {
int userId = userHandle.getIdentifier();
if (userId == Process.INVALID_UID) {
@@ -184,9 +182,8 @@
mAllUsersSatelliteNetworkFallbackUidCache.get(userId, new ArraySet<>());
Log.i(TAG, "currentUser : role_sms_packages: " + userId + " : " + packageNames);
- final Set<Integer> newUidsForUser = !packageNames.isEmpty()
- ? updateSatelliteNetworkFallbackUidListCache(packageNames, userHandle)
- : new ArraySet<>();
+ final Set<Integer> newUidsForUser =
+ updateSatelliteNetworkFallbackUidListCache(packageNames, userHandle);
Log.i(TAG, "satellite_fallback_uid: " + newUidsForUser);
// on Role change, update the multilayer request at ConnectivityService with updated
@@ -197,6 +194,11 @@
mAllUsersSatelliteNetworkFallbackUidCache.put(userId, newUidsForUser);
+ // Update all users fallback cache for user, send cs fallback to update ML request
+ reportSatelliteNetworkFallbackUids();
+ }
+
+ private void reportSatelliteNetworkFallbackUids() {
// Merge all uids of multiple users available
Set<Integer> mergedSatelliteNetworkFallbackUidCache = new ArraySet<>();
for (int i = 0; i < mAllUsersSatelliteNetworkFallbackUidCache.size(); i++) {
@@ -210,27 +212,48 @@
mCallback.accept(mergedSatelliteNetworkFallbackUidCache);
}
- private List<String> getRoleSmsChangedPackageName(UserHandle userHandle) {
- try {
- return mDeps.getRoleHoldersAsUser(RoleManager.ROLE_SMS, userHandle);
- } catch (RuntimeException e) {
- Log.wtf(TAG, "Could not get package name at role sms change update due to: " + e);
- return null;
- }
- }
-
- /** Register OnRoleHoldersChangedListener */
+ /**
+ * Starts tracking satellite network fallback uids.
+ *
+ * <p>Posts an initial refresh of all existing users to the ConnectivityService handler,
+ * registers the SMS role holder listener, and registers a receiver for
+ * {@link Intent#ACTION_USER_REMOVED} whose callbacks run on that same handler.
+ */
public void start() {
mConnectivityServiceHandler.post(this::updateAllUserRoleSmsUids);
+
+ // register sms OnRoleHoldersChangedListener
mDefaultMessageRoleListener.register();
+
+ // Monitor for User removal intent, to update satellite fallback uids.
+ final IntentFilter userRemovedFilter = new IntentFilter(Intent.ACTION_USER_REMOVED);
+ mContext.registerReceiver(new BroadcastReceiver() {
+ @Override
+ public void onReceive(Context context, Intent intent) {
+ final String action = intent.getAction();
+ if (Intent.ACTION_USER_REMOVED.equals(action)) {
+ final UserHandle userHandle = intent.getParcelableExtra(Intent.EXTRA_USER);
+ if (userHandle == null) return;
+ updateSatelliteFallbackUidListOnUserRemoval(userHandle.getIdentifier());
+ } else {
+ Log.wtf(TAG, "received unexpected intent: " + action);
+ }
+ }
+ }, userRemovedFilter, null, mConnectivityServiceHandler);
}
private void updateAllUserRoleSmsUids() {
- List<UserHandle> existingUsers = mUserManager.getUserHandles(true /* excludeDying */);
- // Iterate through the user handles and obtain their uids with role sms and satellite
- // communication permission
- for (UserHandle userHandle : existingUsers) {
- onRoleSmsChanged(userHandle);
- }
+ // Snapshot the current users (dying users excluded), then refresh each user's satellite
+ // fallback uids from its SMS role holders.
+ final UserManager userManager = mContext.getSystemService(UserManager.class);
+ final List<UserHandle> existingUsers = userManager.getUserHandles(true /* excludeDying */);
+ Log.i(TAG, "existing users: " + existingUsers);
+ existingUsers.forEach(this::onRoleSmsChanged);
}
+
+ /**
+ * Forgets the cached satellite fallback uids of {@code userIdRemoved} and, if that user had
+ * an entry, reports the updated merged uid set.
+ */
+ private void updateSatelliteFallbackUidListOnUserRemoval(int userIdRemoved) {
+ Log.i(TAG, "user id removed:" + userIdRemoved);
+ if (!mAllUsersSatelliteNetworkFallbackUidCache.contains(userIdRemoved)) {
+ // Nothing cached for this user, so the merged set is unchanged.
+ return;
+ }
+ mAllUsersSatelliteNetworkFallbackUidCache.remove(userIdRemoved);
+ reportSatelliteNetworkFallbackUids();
+ }
}
diff --git a/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt b/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
new file mode 100644
index 0000000..3ecbdc6
--- /dev/null
+++ b/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
@@ -0,0 +1,91 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.net.cts
+
+import android.content.pm.PackageManager.FEATURE_WIFI
+import android.net.ConnectivityManager
+import android.net.NetworkCapabilities
+import android.net.NetworkRequest
+import android.os.Build
+import android.platform.test.annotations.AppModeFull
+import android.system.OsConstants
+import androidx.test.platform.app.InstrumentationRegistry
+import com.android.compatibility.common.util.PropertyUtil.isVendorApiLevelNewerThan
+import com.android.compatibility.common.util.SystemUtil
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.NetworkStackModuleTest
+import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
+import com.android.testutils.TestableNetworkCallback
+import com.google.common.truth.Truth.assertThat
+import kotlin.test.assertEquals
+import kotlin.test.assertNotNull
+import org.junit.After
+import org.junit.Assume.assumeTrue
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+
+private const val TIMEOUT_MS = 2000L
+
+/**
+ * Integration tests for APF capabilities reported by the NetworkStack, queried through the
+ * `cmd network_stack apf <ifname> capabilities` shell command on a Wi-Fi network interface.
+ */
+@AppModeFull(reason = "CHANGE_NETWORK_STATE permission can't be granted to instant apps")
+@RunWith(DevSdkIgnoreRunner::class)
+@NetworkStackModuleTest
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
+class ApfIntegrationTest {
+ private val context by lazy { InstrumentationRegistry.getInstrumentation().context }
+ private val cm by lazy { context.getSystemService(ConnectivityManager::class.java)!! }
+ private val pm by lazy { context.packageManager }
+ private lateinit var ifname: String
+ private lateinit var networkCallback: TestableNetworkCallback
+
+ @Before
+ fun setUp() {
+ assumeTrue(pm.hasSystemFeature(FEATURE_WIFI))
+ assumeTrue(isVendorApiLevelNewerThan(Build.VERSION_CODES.TIRAMISU))
+ networkCallback = TestableNetworkCallback()
+ cm.requestNetwork(
+ NetworkRequest.Builder()
+ .addTransportType(NetworkCapabilities.TRANSPORT_WIFI)
+ .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
+ .build(),
+ networkCallback
+ )
+ networkCallback.eventuallyExpect<LinkPropertiesChanged>(TIMEOUT_MS) {
+ ifname = assertNotNull(it.lp.interfaceName)
+ true
+ }
+ }
+
+ @After
+ fun tearDown() {
+ // setUp may bail out via assumeTrue before the callback is created.
+ if (::networkCallback.isInitialized) {
+ cm.unregisterNetworkCallback(networkCallback)
+ }
+ }
+
+ @Test
+ fun testGetApfCapabilities() {
+ val caps = SystemUtil.runShellCommand("cmd network_stack apf $ifname capabilities").trim()
+ // Expected output: "<version>,<maxLen>,<packetFormat>". Check the field count first so a
+ // malformed output fails with a readable message instead of an IndexOutOfBoundsException.
+ val fields = caps.split(",").map { it.trim().toInt() }
+ assertEquals(3, fields.size, "Unexpected APF capabilities output: $caps")
+ val (version, maxLen, packetFormat) = fields
+ assertEquals(4, version)
+ assertThat(maxLen).isAtLeast(1024)
+ if (isVendorApiLevelNewerThan(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)) {
+ assertThat(maxLen).isAtLeast(2000)
+ }
+ assertEquals(OsConstants.ARPHRD_ETHER, packetFormat)
+ }
+}
diff --git a/tests/integration/src/com/android/server/net/integrationtests/NetworkStatsIntegrationTest.kt b/tests/integration/src/com/android/server/net/integrationtests/NetworkStatsIntegrationTest.kt
index 765e56e..52e502d 100644
--- a/tests/integration/src/com/android/server/net/integrationtests/NetworkStatsIntegrationTest.kt
+++ b/tests/integration/src/com/android/server/net/integrationtests/NetworkStatsIntegrationTest.kt
@@ -299,7 +299,8 @@
val buf = ByteArray(DEFAULT_BUFFER_SIZE)
httpServer.addResponse(
- TestHttpServer.Request(path, NanoHTTPD.Method.POST), NanoHTTPD.Response.Status.OK,
+ TestHttpServer.Request(path, NanoHTTPD.Method.POST),
+ NanoHTTPD.Response.Status.OK,
content = getRandomString(downloadSize)
)
var httpConnection: HttpURLConnection? = null
@@ -349,15 +350,19 @@
) {
operator fun plus(other: BareStats): BareStats {
return BareStats(
- this.rxBytes + other.rxBytes, this.rxPackets + other.rxPackets,
- this.txBytes + other.txBytes, this.txPackets + other.txPackets
+ this.rxBytes + other.rxBytes,
+ this.rxPackets + other.rxPackets,
+ this.txBytes + other.txBytes,
+ this.txPackets + other.txPackets
)
}
operator fun minus(other: BareStats): BareStats {
return BareStats(
- this.rxBytes - other.rxBytes, this.rxPackets - other.rxPackets,
- this.txBytes - other.txBytes, this.txPackets - other.txPackets
+ this.rxBytes - other.rxBytes,
+ this.rxPackets - other.rxPackets,
+ this.txBytes - other.txBytes,
+ this.txPackets - other.txPackets
)
}
@@ -405,8 +410,12 @@
private fun getUidDetail(iface: String, tag: Int): BareStats {
return getNetworkStatsThat(iface, tag) { nsm, template ->
nsm.queryDetailsForUidTagState(
- template, Long.MIN_VALUE, Long.MAX_VALUE,
- Process.myUid(), tag, Bucket.STATE_ALL
+ template,
+ Long.MIN_VALUE,
+ Long.MAX_VALUE,
+ Process.myUid(),
+ tag,
+ Bucket.STATE_ALL
)
}
}
@@ -498,28 +507,36 @@
assertInRange(
"Unexpected iface traffic stats",
after.iface,
- before.trafficStatsIface, after.trafficStatsIface,
- lower, upper
+ before.trafficStatsIface,
+ after.trafficStatsIface,
+ lower,
+ upper
)
// Uid traffic stats are counted in both direction because the external network
// traffic is also attributed to the test uid.
assertInRange(
"Unexpected uid traffic stats",
after.iface,
- before.trafficStatsUid, after.trafficStatsUid,
- lower + lower.reverse(), upper + upper.reverse()
+ before.trafficStatsUid,
+ after.trafficStatsUid,
+ lower + lower.reverse(),
+ upper + upper.reverse()
)
assertInRange(
"Unexpected non-tagged summary stats",
after.iface,
- before.statsSummary, after.statsSummary,
- lower, upper
+ before.statsSummary,
+ after.statsSummary,
+ lower,
+ upper
)
assertInRange(
"Unexpected non-tagged uid stats",
after.iface,
- before.statsUid, after.statsUid,
- lower, upper
+ before.statsUid,
+ after.statsUid,
+ lower,
+ upper
)
}
@@ -546,14 +563,16 @@
assertInRange(
"Unexpected tagged summary stats",
after.iface,
- before.taggedSummary, after.taggedSummary,
+ before.taggedSummary,
+ after.taggedSummary,
lower,
upper
)
assertInRange(
"Unexpected tagged uid stats: ${Process.myUid()}",
after.iface,
- before.taggedUid, after.taggedUid,
+ before.taggedUid,
+ after.taggedUid,
lower,
upper
)
@@ -570,7 +589,8 @@
) {
// Passing the value after operation and the value before operation to dump the actual
// numbers if it fails.
- assertTrue(checkInRange(before, after, lower, upper),
+ assertTrue(
+ checkInRange(before, after, lower, upper),
"$tag on $iface: $after - $before is not within range [$lower, $upper]"
)
}
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 1f8a743..17c5901 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -1719,6 +1719,8 @@
private void mockUidNetworkingBlocked() {
doAnswer(i -> isUidBlocked(mBlockedReasons, i.getArgument(1))
).when(mNetworkPolicyManager).isUidNetworkingBlocked(anyInt(), anyBoolean());
+ doAnswer(i -> isUidBlocked(mBlockedReasons, i.getArgument(1))
+ ).when(mBpfNetMaps).isUidNetworkingBlocked(anyInt(), anyBoolean());
}
private boolean isUidBlocked(int blockedReasons, boolean meteredNetwork) {
diff --git a/tests/unit/java/com/android/server/connectivity/SatelliteAccessControllerTest.kt b/tests/unit/java/com/android/server/connectivity/SatelliteAccessControllerTest.kt
index 193078b..7885325 100644
--- a/tests/unit/java/com/android/server/connectivity/SatelliteAccessControllerTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/SatelliteAccessControllerTest.kt
@@ -18,17 +18,22 @@
import android.Manifest
import android.app.role.OnRoleHoldersChangedListener
import android.app.role.RoleManager
+import android.content.BroadcastReceiver
import android.content.Context
+import android.content.Intent
+import android.content.IntentFilter
import android.content.pm.ApplicationInfo
import android.content.pm.PackageManager
-import android.content.pm.UserInfo
import android.os.Build
import android.os.Handler
+import android.os.Looper
import android.os.UserHandle
+import android.os.UserManager
import android.util.ArraySet
-import com.android.server.makeMockUserManager
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
import com.android.testutils.DevSdkIgnoreRunner
+import java.util.concurrent.Executor
+import java.util.function.Consumer
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
@@ -36,30 +41,32 @@
import org.mockito.ArgumentMatchers.any
import org.mockito.ArgumentMatchers.anyInt
import org.mockito.ArgumentMatchers.eq
+import org.mockito.ArgumentMatchers.isNull
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock
import org.mockito.Mockito.never
+import org.mockito.Mockito.timeout
+import org.mockito.Mockito.times
import org.mockito.Mockito.verify
-import java.util.concurrent.Executor
-import java.util.function.Consumer
-private const val USER = 0
-val USER_INFO = UserInfo(USER, "" /* name */, UserInfo.FLAG_PRIMARY)
-val USER_HANDLE = UserHandle(USER)
private const val PRIMARY_USER = 0
private const val SECONDARY_USER = 10
private val PRIMARY_USER_HANDLE = UserHandle.of(PRIMARY_USER)
private val SECONDARY_USER_HANDLE = UserHandle.of(SECONDARY_USER)
+
// sms app names
private const val SMS_APP1 = "sms_app_1"
private const val SMS_APP2 = "sms_app_2"
+
// sms app ids
private const val SMS_APP_ID1 = 100
private const val SMS_APP_ID2 = 101
+
// UID for app1 and app2 on primary user
// These app could become default sms app for user1
private val PRIMARY_USER_SMS_APP_UID1 = UserHandle.getUid(PRIMARY_USER, SMS_APP_ID1)
private val PRIMARY_USER_SMS_APP_UID2 = UserHandle.getUid(PRIMARY_USER, SMS_APP_ID2)
+
// UID for app1 and app2 on secondary user
// These app could become default sms app for user2
private val SECONDARY_USER_SMS_APP_UID1 = UserHandle.getUid(SECONDARY_USER, SMS_APP_ID1)
@@ -69,154 +76,259 @@
@IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
class SatelliteAccessControllerTest {
private val context = mock(Context::class.java)
- private val mPackageManager = mock(PackageManager::class.java)
- private val mHandler = mock(Handler::class.java)
- private val mRoleManager =
- mock(SatelliteAccessController.Dependencies::class.java)
+ private val primaryUserContext = mock(Context::class.java)
+ private val secondaryUserContext = mock(Context::class.java)
+ private val mPackageManagerPrimaryUser = mock(PackageManager::class.java)
+ private val mPackageManagerSecondaryUser = mock(PackageManager::class.java)
+ private val mDeps = mock(SatelliteAccessController.Dependencies::class.java)
private val mCallback = mock(Consumer::class.java) as Consumer<Set<Int>>
- private val mSatelliteAccessController =
- SatelliteAccessController(context, mRoleManager, mCallback, mHandler)
+ private val userManager = mock(UserManager::class.java)
+ private val mHandler = Handler(Looper.getMainLooper())
+ private var mSatelliteAccessController =
+ SatelliteAccessController(context, mDeps, mCallback, mHandler)
private lateinit var mRoleHolderChangedListener: OnRoleHoldersChangedListener
+ private lateinit var mUserRemovedReceiver: BroadcastReceiver
+
+ private fun <T> mockService(name: String, clazz: Class<T>, service: T) {
+ doReturn(name).`when`(context).getSystemServiceName(clazz)
+ doReturn(service).`when`(context).getSystemService(name)
+ if (context.getSystemService(clazz) == null) {
+ // Test is using mockito-extended
+ doReturn(service).`when`(context).getSystemService(clazz)
+ }
+ }
+
@Before
@Throws(PackageManager.NameNotFoundException::class)
fun setup() {
- makeMockUserManager(USER_INFO, USER_HANDLE)
- doReturn(context).`when`(context).createContextAsUser(any(), anyInt())
- doReturn(mPackageManager).`when`(context).packageManager
+ doReturn(emptyList<UserHandle>()).`when`(userManager).getUserHandles(true)
+ mockService(Context.USER_SERVICE, UserManager::class.java, userManager)
- doReturn(PackageManager.PERMISSION_GRANTED)
- .`when`(mPackageManager)
- .checkPermission(Manifest.permission.SATELLITE_COMMUNICATION, SMS_APP1)
- doReturn(PackageManager.PERMISSION_GRANTED)
- .`when`(mPackageManager)
- .checkPermission(Manifest.permission.SATELLITE_COMMUNICATION, SMS_APP2)
+ doReturn(primaryUserContext).`when`(context).createContextAsUser(PRIMARY_USER_HANDLE, 0)
+ doReturn(mPackageManagerPrimaryUser).`when`(primaryUserContext).packageManager
- // Initialise default message application primary user package1
+ doReturn(secondaryUserContext).`when`(context).createContextAsUser(SECONDARY_USER_HANDLE, 0)
+ doReturn(mPackageManagerSecondaryUser).`when`(secondaryUserContext).packageManager
+
+ for (app in listOf(SMS_APP1, SMS_APP2)) {
+ doReturn(PackageManager.PERMISSION_GRANTED)
+ .`when`(mPackageManagerPrimaryUser)
+ .checkPermission(Manifest.permission.SATELLITE_COMMUNICATION, app)
+ doReturn(PackageManager.PERMISSION_GRANTED)
+ .`when`(mPackageManagerSecondaryUser)
+ .checkPermission(Manifest.permission.SATELLITE_COMMUNICATION, app)
+ }
+
+ // Initialise message application primary user package1
val applicationInfo1 = ApplicationInfo()
applicationInfo1.uid = PRIMARY_USER_SMS_APP_UID1
doReturn(applicationInfo1)
- .`when`(mPackageManager)
+ .`when`(mPackageManagerPrimaryUser)
.getApplicationInfo(eq(SMS_APP1), anyInt())
- // Initialise default message application primary user package2
+ // Initialise message application primary user package2
val applicationInfo2 = ApplicationInfo()
applicationInfo2.uid = PRIMARY_USER_SMS_APP_UID2
doReturn(applicationInfo2)
- .`when`(mPackageManager)
+ .`when`(mPackageManagerPrimaryUser)
.getApplicationInfo(eq(SMS_APP2), anyInt())
- // Get registered listener using captor
- val listenerCaptor = ArgumentCaptor.forClass(
- OnRoleHoldersChangedListener::class.java
- )
- mSatelliteAccessController.start()
- verify(mRoleManager).addOnRoleHoldersChangedListenerAsUser(
- any(Executor::class.java), listenerCaptor.capture(), any(UserHandle::class.java))
- mRoleHolderChangedListener = listenerCaptor.value
+ // Initialise message application secondary user package1
+ val applicationInfo3 = ApplicationInfo()
+ applicationInfo3.uid = SECONDARY_USER_SMS_APP_UID1
+ doReturn(applicationInfo3)
+ .`when`(mPackageManagerSecondaryUser)
+ .getApplicationInfo(eq(SMS_APP1), anyInt())
+
+ // Initialise message application secondary user package2
+ val applicationInfo4 = ApplicationInfo()
+ applicationInfo4.uid = SECONDARY_USER_SMS_APP_UID2
+ doReturn(applicationInfo4)
+ .`when`(mPackageManagerSecondaryUser)
+ .getApplicationInfo(eq(SMS_APP2), anyInt())
}
@Test
fun test_onRoleHoldersChanged_SatelliteFallbackUid_Changed_SingleUser() {
- doReturn(listOf<String>()).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
- PRIMARY_USER_HANDLE)
+ startSatelliteAccessController()
+ doReturn(listOf<String>()).`when`(mDeps).getRoleHoldersAsUser(
+ RoleManager.ROLE_SMS,
+ PRIMARY_USER_HANDLE
+ )
mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
verify(mCallback, never()).accept(any())
// check DEFAULT_MESSAGING_APP1 is available as satellite network fallback uid
doReturn(listOf(SMS_APP1))
- .`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+ .`when`(mDeps).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID1))
// check SMS_APP2 is available as satellite network Fallback uid
- doReturn(listOf(SMS_APP2)).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
- PRIMARY_USER_HANDLE)
+ doReturn(listOf(SMS_APP2)).`when`(mDeps).getRoleHoldersAsUser(
+ RoleManager.ROLE_SMS,
+ PRIMARY_USER_HANDLE
+ )
mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID2))
// check no uid is available as satellite network fallback uid
- doReturn(listOf<String>()).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
- PRIMARY_USER_HANDLE)
+ doReturn(listOf<String>()).`when`(mDeps).getRoleHoldersAsUser(
+ RoleManager.ROLE_SMS,
+ PRIMARY_USER_HANDLE
+ )
mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
verify(mCallback).accept(ArraySet())
}
@Test
fun test_onRoleHoldersChanged_NoSatelliteCommunicationPermission() {
- doReturn(listOf<Any>()).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
- PRIMARY_USER_HANDLE)
+ startSatelliteAccessController()
+ doReturn(listOf<Any>()).`when`(mDeps).getRoleHoldersAsUser(
+ RoleManager.ROLE_SMS,
+ PRIMARY_USER_HANDLE
+ )
mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
verify(mCallback, never()).accept(any())
// check DEFAULT_MESSAGING_APP1 is not available as satellite network fallback uid
// since satellite communication permission not available.
doReturn(PackageManager.PERMISSION_DENIED)
- .`when`(mPackageManager)
+ .`when`(mPackageManagerPrimaryUser)
.checkPermission(Manifest.permission.SATELLITE_COMMUNICATION, SMS_APP1)
doReturn(listOf(SMS_APP1))
- .`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+ .`when`(mDeps).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
verify(mCallback, never()).accept(any())
}
@Test
fun test_onRoleHoldersChanged_RoleSms_NotAvailable() {
+ startSatelliteAccessController()
doReturn(listOf(SMS_APP1))
- .`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
- mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_BROWSER,
- PRIMARY_USER_HANDLE)
+ .`when`(mDeps).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+ mRoleHolderChangedListener.onRoleHoldersChanged(
+ RoleManager.ROLE_BROWSER,
+ PRIMARY_USER_HANDLE
+ )
verify(mCallback, never()).accept(any())
}
@Test
fun test_onRoleHoldersChanged_SatelliteNetworkFallbackUid_Changed_multiUser() {
- doReturn(listOf<String>()).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
- PRIMARY_USER_HANDLE)
+ startSatelliteAccessController()
+ doReturn(listOf<String>()).`when`(mDeps).getRoleHoldersAsUser(
+ RoleManager.ROLE_SMS,
+ PRIMARY_USER_HANDLE
+ )
mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
verify(mCallback, never()).accept(any())
// check SMS_APP1 is available as satellite network fallback uid at primary user
doReturn(listOf(SMS_APP1))
- .`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+ .`when`(mDeps).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID1))
// check SMS_APP2 is available as satellite network fallback uid at primary user
- doReturn(listOf(SMS_APP2)).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
- PRIMARY_USER_HANDLE)
+ doReturn(listOf(SMS_APP2)).`when`(mDeps).getRoleHoldersAsUser(
+ RoleManager.ROLE_SMS,
+ PRIMARY_USER_HANDLE
+ )
mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID2))
// check SMS_APP1 is available as satellite network fallback uid at secondary user
- val applicationInfo1 = ApplicationInfo()
- applicationInfo1.uid = SECONDARY_USER_SMS_APP_UID1
- doReturn(applicationInfo1).`when`(mPackageManager)
- .getApplicationInfo(eq(SMS_APP1), anyInt())
- doReturn(listOf(SMS_APP1)).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
- SECONDARY_USER_HANDLE)
+ doReturn(listOf(SMS_APP1)).`when`(mDeps).getRoleHoldersAsUser(
+ RoleManager.ROLE_SMS,
+ SECONDARY_USER_HANDLE
+ )
mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, SECONDARY_USER_HANDLE)
verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID2, SECONDARY_USER_SMS_APP_UID1))
// check no uid is available as satellite network fallback uid at primary user
- doReturn(listOf<String>()).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
- PRIMARY_USER_HANDLE)
- mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS,
- PRIMARY_USER_HANDLE)
+ doReturn(listOf<String>()).`when`(mDeps).getRoleHoldersAsUser(
+ RoleManager.ROLE_SMS,
+ PRIMARY_USER_HANDLE
+ )
+ mRoleHolderChangedListener.onRoleHoldersChanged(
+ RoleManager.ROLE_SMS,
+ PRIMARY_USER_HANDLE
+ )
verify(mCallback).accept(setOf(SECONDARY_USER_SMS_APP_UID1))
// check SMS_APP2 is available as satellite network fallback uid at secondary user
- applicationInfo1.uid = SECONDARY_USER_SMS_APP_UID2
- doReturn(applicationInfo1).`when`(mPackageManager)
- .getApplicationInfo(eq(SMS_APP2), anyInt())
doReturn(listOf(SMS_APP2))
- .`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS, SECONDARY_USER_HANDLE)
+ .`when`(mDeps).getRoleHoldersAsUser(RoleManager.ROLE_SMS, SECONDARY_USER_HANDLE)
mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, SECONDARY_USER_HANDLE)
verify(mCallback).accept(setOf(SECONDARY_USER_SMS_APP_UID2))
// check no uid is available as satellite network fallback uid at secondary user
- doReturn(listOf<String>()).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
- SECONDARY_USER_HANDLE)
+ doReturn(listOf<String>()).`when`(mDeps).getRoleHoldersAsUser(
+ RoleManager.ROLE_SMS,
+ SECONDARY_USER_HANDLE
+ )
mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, SECONDARY_USER_HANDLE)
verify(mCallback).accept(ArraySet())
}
+
+ @Test
+ fun test_SatelliteFallbackUidCallback_OnUserRemoval() {
+ startSatelliteAccessController()
+ // check SMS_APP2 is available as satellite network fallback uid at primary user
+ doReturn(listOf(SMS_APP2)).`when`(mDeps).getRoleHoldersAsUser(
+ RoleManager.ROLE_SMS,
+ PRIMARY_USER_HANDLE
+ )
+ mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+ verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID2))
+
+ // check SMS_APP1 is available as satellite network fallback uid at secondary user
+ doReturn(listOf(SMS_APP1)).`when`(mDeps).getRoleHoldersAsUser(
+ RoleManager.ROLE_SMS,
+ SECONDARY_USER_HANDLE
+ )
+ mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, SECONDARY_USER_HANDLE)
+ verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID2, SECONDARY_USER_SMS_APP_UID1))
+
+ val userRemovalIntent = Intent(Intent.ACTION_USER_REMOVED)
+ userRemovalIntent.putExtra(Intent.EXTRA_USER, SECONDARY_USER_HANDLE)
+ mUserRemovedReceiver.onReceive(context, userRemovalIntent)
+ verify(mCallback, times(2)).accept(setOf(PRIMARY_USER_SMS_APP_UID2))
+ }
+
+ @Test
+ fun testOnStartUpCallbackSatelliteFallbackUidWithExistingUsers() {
+ doReturn(
+ listOf(PRIMARY_USER_HANDLE)
+ ).`when`(userManager).getUserHandles(true)
+ doReturn(listOf(SMS_APP1))
+ .`when`(mDeps).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+ // At start up, SatelliteAccessController must call CS callback with existing users'
+ // default messaging apps uids.
+ startSatelliteAccessController()
+ verify(mCallback, timeout(500)).accept(setOf(PRIMARY_USER_SMS_APP_UID1))
+ }
+
+ private fun startSatelliteAccessController() {
+ mSatelliteAccessController.start()
+ // Get registered listener using captor
+ val listenerCaptor = ArgumentCaptor.forClass(OnRoleHoldersChangedListener::class.java)
+ verify(mDeps).addOnRoleHoldersChangedListenerAsUser(
+ any(Executor::class.java),
+ listenerCaptor.capture(),
+ any(UserHandle::class.java)
+ )
+ mRoleHolderChangedListener = listenerCaptor.value
+
+ // Get registered receiver using captor
+ val userRemovedReceiverCaptor = ArgumentCaptor.forClass(BroadcastReceiver::class.java)
+ verify(context).registerReceiver(
+ userRemovedReceiverCaptor.capture(),
+ any(IntentFilter::class.java),
+ isNull(),
+ any(Handler::class.java)
+ )
+ mUserRemovedReceiver = userRemovedReceiverCaptor.value
+ }
}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
index 9474464..fb3d183 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
@@ -23,6 +23,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.timeout;
@@ -47,6 +48,7 @@
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
+import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
@@ -55,6 +57,7 @@
import java.net.DatagramPacket;
import java.net.NetworkInterface;
import java.net.SocketException;
+import java.util.ArrayList;
import java.util.List;
@RunWith(DevSdkIgnoreRunner.class)
@@ -154,7 +157,7 @@
verify(mSocketCreationCallback).onSocketCreated(tetherSocketKey2);
// Send packet to IPv4 with mSocketKey and verify sending has been called.
- mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
+ mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet), mSocketKey,
false /* onlyUseIpv6OnIpv6OnlyNetworks */);
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mSocket).send(ipv4Packet);
@@ -162,7 +165,7 @@
verify(tetherIfaceSock2, never()).send(any());
// Send packet to IPv4 with onlyUseIpv6OnIpv6OnlyNetworks = true, the packet will be sent.
- mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
+ mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet), mSocketKey,
true /* onlyUseIpv6OnIpv6OnlyNetworks */);
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mSocket, times(2)).send(ipv4Packet);
@@ -170,7 +173,7 @@
verify(tetherIfaceSock2, never()).send(any());
// Send packet to IPv6 with tetherSocketKey1 and verify sending has been called.
- mSocketClient.sendPacketRequestingMulticastResponse(ipv6Packet, tetherSocketKey1,
+ mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv6Packet), tetherSocketKey1,
false /* onlyUseIpv6OnIpv6OnlyNetworks */);
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mSocket, never()).send(ipv6Packet);
@@ -180,7 +183,7 @@
// Send packet to IPv6 with onlyUseIpv6OnIpv6OnlyNetworks = true, the packet will not be
// sent. Therefore, the tetherIfaceSock1.send() and tetherIfaceSock2.send() are still be
// called once.
- mSocketClient.sendPacketRequestingMulticastResponse(ipv6Packet, tetherSocketKey1,
+ mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv6Packet), tetherSocketKey1,
true /* onlyUseIpv6OnIpv6OnlyNetworks */);
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mSocket, never()).send(ipv6Packet);
@@ -266,7 +269,7 @@
verify(mSocketCreationCallback).onSocketCreated(socketKey3);
// Send IPv4 packet on the mSocketKey and verify sending has been called.
- mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
+ mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet), mSocketKey,
false /* onlyUseIpv6OnIpv6OnlyNetworks */);
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mSocket).send(ipv4Packet);
@@ -295,7 +298,7 @@
verify(socketCreationCb2).onSocketCreated(socketKey3);
// Send IPv4 packet on socket2 and verify sending to the socket2 only.
- mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, socketKey2,
+ mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet), socketKey2,
false /* onlyUseIpv6OnIpv6OnlyNetworks */);
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
// ipv4Packet still sent only once on mSocket: times(1) matches the packet sent earlier on
@@ -309,7 +312,7 @@
verify(mProvider, timeout(DEFAULT_TIMEOUT)).unrequestSocket(callback2);
// Send IPv4 packet again and verify it's still sent a second time
- mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, socketKey2,
+ mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet), socketKey2,
false /* onlyUseIpv6OnIpv6OnlyNetworks */);
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(socket2, times(2)).send(ipv4Packet);
@@ -320,7 +323,7 @@
verify(mProvider, timeout(DEFAULT_TIMEOUT)).unrequestSocket(callback);
// Send IPv4 packet and verify no more sending.
- mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
+ mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet), mSocketKey,
false /* onlyUseIpv6OnIpv6OnlyNetworks */);
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mSocket, times(1)).send(ipv4Packet);
@@ -407,4 +410,31 @@
verify(creationCallback3).onSocketDestroyed(mSocketKey);
verify(creationCallback3, never()).onSocketDestroyed(socketKey2);
}
+
+ @Test
+ public void testSendPacketWithMultipleDatagramPacket() throws IOException {
+ final SocketCallback callback = expectSocketCallback();
+ final List<DatagramPacket> packets = new ArrayList<>();
+ for (int i = 0; i < 10; i++) {
+ packets.add(new DatagramPacket(new byte[10 + i] /* buff */, 0 /* offset */,
+ 10 + i /* length */, MdnsConstants.IPV4_SOCKET_ADDR));
+ }
+ doReturn(true).when(mSocket).hasJoinedIpv4();
+ doReturn(true).when(mSocket).hasJoinedIpv6();
+ doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
+
+ // Notify socket created
+ callback.onSocketCreated(mSocketKey, mSocket, List.of());
+ verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
+
+ // Send packets to IPv4 with mSocketKey then verify sending has been called and the
+ // sequence is correct.
+ mSocketClient.sendPacketRequestingMulticastResponse(packets, mSocketKey,
+ false /* onlyUseIpv6OnIpv6OnlyNetworks */);
+ HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+ InOrder inOrder = inOrder(mSocket);
+ for (int i = 0; i < 10; i++) {
+ inOrder.verify(mSocket).send(packets.get(i));
+ }
+ }
}
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index 2eb9440..44fa55c 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -162,59 +162,59 @@
expectedIPv6Packets[i] = new DatagramPacket(buf, 0 /* offset */, 5 /* length */,
MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
}
- when(mockDeps.getDatagramPacketFromMdnsPacket(
- any(), any(MdnsPacket.class), eq(IPV4_ADDRESS)))
- .thenReturn(expectedIPv4Packets[0])
- .thenReturn(expectedIPv4Packets[1])
- .thenReturn(expectedIPv4Packets[2])
- .thenReturn(expectedIPv4Packets[3])
- .thenReturn(expectedIPv4Packets[4])
- .thenReturn(expectedIPv4Packets[5])
- .thenReturn(expectedIPv4Packets[6])
- .thenReturn(expectedIPv4Packets[7])
- .thenReturn(expectedIPv4Packets[8])
- .thenReturn(expectedIPv4Packets[9])
- .thenReturn(expectedIPv4Packets[10])
- .thenReturn(expectedIPv4Packets[11])
- .thenReturn(expectedIPv4Packets[12])
- .thenReturn(expectedIPv4Packets[13])
- .thenReturn(expectedIPv4Packets[14])
- .thenReturn(expectedIPv4Packets[15])
- .thenReturn(expectedIPv4Packets[16])
- .thenReturn(expectedIPv4Packets[17])
- .thenReturn(expectedIPv4Packets[18])
- .thenReturn(expectedIPv4Packets[19])
- .thenReturn(expectedIPv4Packets[20])
- .thenReturn(expectedIPv4Packets[21])
- .thenReturn(expectedIPv4Packets[22])
- .thenReturn(expectedIPv4Packets[23]);
+ when(mockDeps.getDatagramPacketsFromMdnsPacket(
+ any(), any(MdnsPacket.class), eq(IPV4_ADDRESS), anyBoolean()))
+ .thenReturn(List.of(expectedIPv4Packets[0]))
+ .thenReturn(List.of(expectedIPv4Packets[1]))
+ .thenReturn(List.of(expectedIPv4Packets[2]))
+ .thenReturn(List.of(expectedIPv4Packets[3]))
+ .thenReturn(List.of(expectedIPv4Packets[4]))
+ .thenReturn(List.of(expectedIPv4Packets[5]))
+ .thenReturn(List.of(expectedIPv4Packets[6]))
+ .thenReturn(List.of(expectedIPv4Packets[7]))
+ .thenReturn(List.of(expectedIPv4Packets[8]))
+ .thenReturn(List.of(expectedIPv4Packets[9]))
+ .thenReturn(List.of(expectedIPv4Packets[10]))
+ .thenReturn(List.of(expectedIPv4Packets[11]))
+ .thenReturn(List.of(expectedIPv4Packets[12]))
+ .thenReturn(List.of(expectedIPv4Packets[13]))
+ .thenReturn(List.of(expectedIPv4Packets[14]))
+ .thenReturn(List.of(expectedIPv4Packets[15]))
+ .thenReturn(List.of(expectedIPv4Packets[16]))
+ .thenReturn(List.of(expectedIPv4Packets[17]))
+ .thenReturn(List.of(expectedIPv4Packets[18]))
+ .thenReturn(List.of(expectedIPv4Packets[19]))
+ .thenReturn(List.of(expectedIPv4Packets[20]))
+ .thenReturn(List.of(expectedIPv4Packets[21]))
+ .thenReturn(List.of(expectedIPv4Packets[22]))
+ .thenReturn(List.of(expectedIPv4Packets[23]));
- when(mockDeps.getDatagramPacketFromMdnsPacket(
- any(), any(MdnsPacket.class), eq(IPV6_ADDRESS)))
- .thenReturn(expectedIPv6Packets[0])
- .thenReturn(expectedIPv6Packets[1])
- .thenReturn(expectedIPv6Packets[2])
- .thenReturn(expectedIPv6Packets[3])
- .thenReturn(expectedIPv6Packets[4])
- .thenReturn(expectedIPv6Packets[5])
- .thenReturn(expectedIPv6Packets[6])
- .thenReturn(expectedIPv6Packets[7])
- .thenReturn(expectedIPv6Packets[8])
- .thenReturn(expectedIPv6Packets[9])
- .thenReturn(expectedIPv6Packets[10])
- .thenReturn(expectedIPv6Packets[11])
- .thenReturn(expectedIPv6Packets[12])
- .thenReturn(expectedIPv6Packets[13])
- .thenReturn(expectedIPv6Packets[14])
- .thenReturn(expectedIPv6Packets[15])
- .thenReturn(expectedIPv6Packets[16])
- .thenReturn(expectedIPv6Packets[17])
- .thenReturn(expectedIPv6Packets[18])
- .thenReturn(expectedIPv6Packets[19])
- .thenReturn(expectedIPv6Packets[20])
- .thenReturn(expectedIPv6Packets[21])
- .thenReturn(expectedIPv6Packets[22])
- .thenReturn(expectedIPv6Packets[23]);
+ when(mockDeps.getDatagramPacketsFromMdnsPacket(
+ any(), any(MdnsPacket.class), eq(IPV6_ADDRESS), anyBoolean()))
+ .thenReturn(List.of(expectedIPv6Packets[0]))
+ .thenReturn(List.of(expectedIPv6Packets[1]))
+ .thenReturn(List.of(expectedIPv6Packets[2]))
+ .thenReturn(List.of(expectedIPv6Packets[3]))
+ .thenReturn(List.of(expectedIPv6Packets[4]))
+ .thenReturn(List.of(expectedIPv6Packets[5]))
+ .thenReturn(List.of(expectedIPv6Packets[6]))
+ .thenReturn(List.of(expectedIPv6Packets[7]))
+ .thenReturn(List.of(expectedIPv6Packets[8]))
+ .thenReturn(List.of(expectedIPv6Packets[9]))
+ .thenReturn(List.of(expectedIPv6Packets[10]))
+ .thenReturn(List.of(expectedIPv6Packets[11]))
+ .thenReturn(List.of(expectedIPv6Packets[12]))
+ .thenReturn(List.of(expectedIPv6Packets[13]))
+ .thenReturn(List.of(expectedIPv6Packets[14]))
+ .thenReturn(List.of(expectedIPv6Packets[15]))
+ .thenReturn(List.of(expectedIPv6Packets[16]))
+ .thenReturn(List.of(expectedIPv6Packets[17]))
+ .thenReturn(List.of(expectedIPv6Packets[18]))
+ .thenReturn(List.of(expectedIPv6Packets[19]))
+ .thenReturn(List.of(expectedIPv6Packets[20]))
+ .thenReturn(List.of(expectedIPv6Packets[21]))
+ .thenReturn(List.of(expectedIPv6Packets[22]))
+ .thenReturn(List.of(expectedIPv6Packets[23]));
thread = new HandlerThread("MdnsServiceTypeClientTests");
thread.start();
@@ -694,23 +694,23 @@
.addSubtype("subtype1").build();
final MdnsSearchOptions searchOptions2 = MdnsSearchOptions.newBuilder()
.addSubtype("subtype2").build();
- doCallRealMethod().when(mockDeps).getDatagramPacketFromMdnsPacket(
- any(), any(MdnsPacket.class), any(InetSocketAddress.class));
+ doCallRealMethod().when(mockDeps).getDatagramPacketsFromMdnsPacket(
+ any(), any(MdnsPacket.class), any(InetSocketAddress.class), anyBoolean());
startSendAndReceive(mockListenerOne, searchOptions1);
currentThreadExecutor.getAndClearLastScheduledRunnable().run();
InOrder inOrder = inOrder(mockListenerOne, mockSocketClient, mockDeps);
// Verify the query asks for subtype1
- final ArgumentCaptor<DatagramPacket> subtype1QueryCaptor =
- ArgumentCaptor.forClass(DatagramPacket.class);
+ final ArgumentCaptor<List<DatagramPacket>> subtype1QueryCaptor =
+ ArgumentCaptor.forClass(List.class);
// Send twice for IPv4 and IPv6
inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
subtype1QueryCaptor.capture(),
eq(socketKey), eq(false));
final MdnsPacket subtype1Query = MdnsPacket.parse(
- new MdnsPacketReader(subtype1QueryCaptor.getValue()));
+ new MdnsPacketReader(subtype1QueryCaptor.getValue().get(0)));
assertEquals(2, subtype1Query.questions.size());
assertTrue(hasQuestion(subtype1Query, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
@@ -722,8 +722,8 @@
inOrder.verify(mockDeps).removeMessages(any(), eq(EVENT_START_QUERYTASK));
currentThreadExecutor.getAndClearLastScheduledRunnable().run();
- final ArgumentCaptor<DatagramPacket> combinedSubtypesQueryCaptor =
- ArgumentCaptor.forClass(DatagramPacket.class);
+ final ArgumentCaptor<List<DatagramPacket>> combinedSubtypesQueryCaptor =
+ ArgumentCaptor.forClass(List.class);
inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
combinedSubtypesQueryCaptor.capture(),
eq(socketKey), eq(false));
@@ -731,7 +731,7 @@
inOrder.verify(mockDeps).sendMessageDelayed(any(), any(), anyLong());
final MdnsPacket combinedSubtypesQuery = MdnsPacket.parse(
- new MdnsPacketReader(combinedSubtypesQueryCaptor.getValue()));
+ new MdnsPacketReader(combinedSubtypesQueryCaptor.getValue().get(0)));
assertEquals(3, combinedSubtypesQuery.questions.size());
assertTrue(hasQuestion(combinedSubtypesQuery, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
@@ -747,15 +747,15 @@
dispatchMessage();
currentThreadExecutor.getAndClearLastScheduledRunnable().run();
- final ArgumentCaptor<DatagramPacket> subtype2QueryCaptor =
- ArgumentCaptor.forClass(DatagramPacket.class);
+ final ArgumentCaptor<List<DatagramPacket>> subtype2QueryCaptor =
+ ArgumentCaptor.forClass(List.class);
// Send twice for IPv4 and IPv6
inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
subtype2QueryCaptor.capture(),
eq(socketKey), eq(false));
final MdnsPacket subtype2Query = MdnsPacket.parse(
- new MdnsPacketReader(subtype2QueryCaptor.getValue()));
+ new MdnsPacketReader(subtype2QueryCaptor.getValue().get(0)));
assertEquals(2, subtype2Query.questions.size());
assertTrue(hasQuestion(subtype2Query, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
@@ -1201,8 +1201,8 @@
final MdnsSearchOptions resolveOptions2 = MdnsSearchOptions.newBuilder()
.setResolveInstanceName(instanceName).build();
- doCallRealMethod().when(mockDeps).getDatagramPacketFromMdnsPacket(
- any(), any(MdnsPacket.class), any(InetSocketAddress.class));
+ doCallRealMethod().when(mockDeps).getDatagramPacketsFromMdnsPacket(
+ any(), any(MdnsPacket.class), any(InetSocketAddress.class), anyBoolean());
startSendAndReceive(mockListenerOne, resolveOptions1);
startSendAndReceive(mockListenerTwo, resolveOptions2);
@@ -1210,8 +1210,8 @@
InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
// Verify a query for SRV/TXT was sent, but no PTR query
- final ArgumentCaptor<DatagramPacket> srvTxtQueryCaptor =
- ArgumentCaptor.forClass(DatagramPacket.class);
+ final ArgumentCaptor<List<DatagramPacket>> srvTxtQueryCaptor =
+ ArgumentCaptor.forClass(List.class);
currentThreadExecutor.getAndClearLastScheduledRunnable().run();
// Send twice for IPv4 and IPv6
inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
@@ -1223,7 +1223,7 @@
verify(mockListenerTwo).onDiscoveryQuerySent(any(), anyInt());
final MdnsPacket srvTxtQueryPacket = MdnsPacket.parse(
- new MdnsPacketReader(srvTxtQueryCaptor.getValue()));
+ new MdnsPacketReader(srvTxtQueryCaptor.getValue().get(0)));
final String[] serviceName = getTestServiceName(instanceName);
assertEquals(1, srvTxtQueryPacket.questions.size());
@@ -1255,8 +1255,8 @@
// Expect a query for A/AAAA
dispatchMessage();
- final ArgumentCaptor<DatagramPacket> addressQueryCaptor =
- ArgumentCaptor.forClass(DatagramPacket.class);
+ final ArgumentCaptor<List<DatagramPacket>> addressQueryCaptor =
+ ArgumentCaptor.forClass(List.class);
currentThreadExecutor.getAndClearLastScheduledRunnable().run();
inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
addressQueryCaptor.capture(),
@@ -1266,7 +1266,7 @@
verify(mockListenerTwo, times(2)).onDiscoveryQuerySent(any(), anyInt());
final MdnsPacket addressQueryPacket = MdnsPacket.parse(
- new MdnsPacketReader(addressQueryCaptor.getValue()));
+ new MdnsPacketReader(addressQueryCaptor.getValue().get(0)));
assertEquals(2, addressQueryPacket.questions.size());
assertTrue(hasQuestion(addressQueryPacket, MdnsRecord.TYPE_A, hostname));
assertTrue(hasQuestion(addressQueryPacket, MdnsRecord.TYPE_AAAA, hostname));
@@ -1316,15 +1316,15 @@
final MdnsSearchOptions resolveOptions = MdnsSearchOptions.newBuilder()
.setResolveInstanceName(instanceName).build();
- doCallRealMethod().when(mockDeps).getDatagramPacketFromMdnsPacket(
- any(), any(MdnsPacket.class), any(InetSocketAddress.class));
+ doCallRealMethod().when(mockDeps).getDatagramPacketsFromMdnsPacket(
+ any(), any(MdnsPacket.class), any(InetSocketAddress.class), anyBoolean());
startSendAndReceive(mockListenerOne, resolveOptions);
InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
// Get the query for SRV/TXT
- final ArgumentCaptor<DatagramPacket> srvTxtQueryCaptor =
- ArgumentCaptor.forClass(DatagramPacket.class);
+ final ArgumentCaptor<List<DatagramPacket>> srvTxtQueryCaptor =
+ ArgumentCaptor.forClass(List.class);
currentThreadExecutor.getAndClearLastScheduledRunnable().run();
// Send twice for IPv4 and IPv6
inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
@@ -1334,7 +1334,7 @@
assertNotNull(delayMessage);
final MdnsPacket srvTxtQueryPacket = MdnsPacket.parse(
- new MdnsPacketReader(srvTxtQueryCaptor.getValue()));
+ new MdnsPacketReader(srvTxtQueryCaptor.getValue().get(0)));
final String[] serviceName = getTestServiceName(instanceName);
assertTrue(hasQuestion(srvTxtQueryPacket, MdnsRecord.TYPE_ANY, serviceName));
@@ -1378,8 +1378,8 @@
currentThreadExecutor.getAndClearLastScheduledRunnable().run();
// Expect a renewal query
- final ArgumentCaptor<DatagramPacket> renewalQueryCaptor =
- ArgumentCaptor.forClass(DatagramPacket.class);
+ final ArgumentCaptor<List<DatagramPacket>> renewalQueryCaptor =
+ ArgumentCaptor.forClass(List.class);
// Second and later sends are sent as "expect multicast response" queries
inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
renewalQueryCaptor.capture(),
@@ -1388,7 +1388,7 @@
assertNotNull(delayMessage);
inOrder.verify(mockListenerOne).onDiscoveryQuerySent(any(), anyInt());
final MdnsPacket renewalPacket = MdnsPacket.parse(
- new MdnsPacketReader(renewalQueryCaptor.getValue()));
+ new MdnsPacketReader(renewalQueryCaptor.getValue().get(0)));
assertTrue(hasQuestion(renewalPacket, MdnsRecord.TYPE_ANY, serviceName));
inOrder.verifyNoMoreInteractions();
@@ -1937,14 +1937,14 @@
serviceCache,
MdnsFeatureFlags.newBuilder().setIsQueryWithKnownAnswerEnabled(true).build());
- doCallRealMethod().when(mockDeps).getDatagramPacketFromMdnsPacket(
- any(), any(MdnsPacket.class), any(InetSocketAddress.class));
+ doCallRealMethod().when(mockDeps).getDatagramPacketsFromMdnsPacket(
+ any(), any(MdnsPacket.class), any(InetSocketAddress.class), anyBoolean());
startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
- final ArgumentCaptor<DatagramPacket> queryCaptor =
- ArgumentCaptor.forClass(DatagramPacket.class);
+ final ArgumentCaptor<List<DatagramPacket>> queryCaptor =
+ ArgumentCaptor.forClass(List.class);
currentThreadExecutor.getAndClearLastScheduledRunnable().run();
// Send twice for IPv4 and IPv6
inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
@@ -1953,7 +1953,7 @@
assertNotNull(delayMessage);
final MdnsPacket queryPacket = MdnsPacket.parse(
- new MdnsPacketReader(queryCaptor.getValue()));
+ new MdnsPacketReader(queryCaptor.getValue().get(0)));
assertTrue(hasQuestion(queryPacket, MdnsRecord.TYPE_PTR));
// Process a response
@@ -1981,14 +1981,14 @@
// Expect a query with known answers
dispatchMessage();
- final ArgumentCaptor<DatagramPacket> knownAnswersQueryCaptor =
- ArgumentCaptor.forClass(DatagramPacket.class);
+ final ArgumentCaptor<List<DatagramPacket>> knownAnswersQueryCaptor =
+ ArgumentCaptor.forClass(List.class);
currentThreadExecutor.getAndClearLastScheduledRunnable().run();
inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
knownAnswersQueryCaptor.capture(), eq(socketKey), eq(false));
final MdnsPacket knownAnswersQueryPacket = MdnsPacket.parse(
- new MdnsPacketReader(knownAnswersQueryCaptor.getValue()));
+ new MdnsPacketReader(knownAnswersQueryCaptor.getValue().get(0)));
assertTrue(hasQuestion(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
assertTrue(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
assertFalse(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, subtypeLabels));
@@ -2001,16 +2001,16 @@
serviceCache,
MdnsFeatureFlags.newBuilder().setIsQueryWithKnownAnswerEnabled(true).build());
- doCallRealMethod().when(mockDeps).getDatagramPacketFromMdnsPacket(
- any(), any(MdnsPacket.class), any(InetSocketAddress.class));
+ doCallRealMethod().when(mockDeps).getDatagramPacketsFromMdnsPacket(
+ any(), any(MdnsPacket.class), any(InetSocketAddress.class), anyBoolean());
final MdnsSearchOptions options = MdnsSearchOptions.newBuilder()
.addSubtype("subtype").build();
startSendAndReceive(mockListenerOne, options);
InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
- final ArgumentCaptor<DatagramPacket> queryCaptor =
- ArgumentCaptor.forClass(DatagramPacket.class);
+ final ArgumentCaptor<List<DatagramPacket>> queryCaptor =
+ ArgumentCaptor.forClass(List.class);
currentThreadExecutor.getAndClearLastScheduledRunnable().run();
// Send twice for IPv4 and IPv6
inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
@@ -2019,7 +2019,7 @@
assertNotNull(delayMessage);
final MdnsPacket queryPacket = MdnsPacket.parse(
- new MdnsPacketReader(queryCaptor.getValue()));
+ new MdnsPacketReader(queryCaptor.getValue().get(0)));
final String[] subtypeLabels = Stream.concat(Stream.of("_subtype", "_sub"),
Arrays.stream(SERVICE_TYPE_LABELS)).toArray(String[]::new);
assertTrue(hasQuestion(queryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
@@ -2048,14 +2048,14 @@
// Expect a query with known answers
dispatchMessage();
- final ArgumentCaptor<DatagramPacket> knownAnswersQueryCaptor =
- ArgumentCaptor.forClass(DatagramPacket.class);
+ final ArgumentCaptor<List<DatagramPacket>> knownAnswersQueryCaptor =
+ ArgumentCaptor.forClass(List.class);
currentThreadExecutor.getAndClearLastScheduledRunnable().run();
inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
knownAnswersQueryCaptor.capture(), eq(socketKey), eq(false));
final MdnsPacket knownAnswersQueryPacket = MdnsPacket.parse(
- new MdnsPacketReader(knownAnswersQueryCaptor.getValue()));
+ new MdnsPacketReader(knownAnswersQueryCaptor.getValue().get(0)));
assertTrue(hasQuestion(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
assertTrue(hasQuestion(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, subtypeLabels));
assertTrue(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
@@ -2083,17 +2083,21 @@
currentThreadExecutor.getAndClearLastScheduledRunnable().run();
if (expectsUnicastResponse) {
verify(mockSocketClient).sendPacketRequestingUnicastResponse(
- expectedIPv4Packets[index], socketKey, false);
+ argThat(pkts -> pkts.get(0).equals(expectedIPv4Packets[index])),
+ eq(socketKey), eq(false));
if (multipleSocketDiscovery) {
verify(mockSocketClient).sendPacketRequestingUnicastResponse(
- expectedIPv6Packets[index], socketKey, false);
+ argThat(pkts -> pkts.get(0).equals(expectedIPv6Packets[index])),
+ eq(socketKey), eq(false));
}
} else {
verify(mockSocketClient).sendPacketRequestingMulticastResponse(
- expectedIPv4Packets[index], socketKey, false);
+ argThat(pkts -> pkts.get(0).equals(expectedIPv4Packets[index])),
+ eq(socketKey), eq(false));
if (multipleSocketDiscovery) {
verify(mockSocketClient).sendPacketRequestingMulticastResponse(
- expectedIPv6Packets[index], socketKey, false);
+ argThat(pkts -> pkts.get(0).equals(expectedIPv6Packets[index])),
+ eq(socketKey), eq(false));
}
}
verify(mockDeps, times(index + 1))
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
index 7ced1cb..1989ed3 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
@@ -27,6 +27,7 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
@@ -53,6 +54,7 @@
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
+import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
@@ -60,6 +62,8 @@
import java.io.IOException;
import java.net.DatagramPacket;
import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -234,7 +238,7 @@
// Sends a packet.
DatagramPacket packet = getTestDatagramPacket();
- mdnsClient.sendPacketRequestingMulticastResponse(packet,
+ mdnsClient.sendPacketRequestingMulticastResponse(List.of(packet),
false /* onlyUseIpv6OnIpv6OnlyNetworks */);
// mockMulticastSocket.send() will be called on another thread. If we verify it immediately,
// it may not be called yet. So timeout is added.
@@ -242,7 +246,7 @@
verify(mockUnicastSocket, timeout(TIMEOUT).times(0)).send(packet);
// Verify the packet is sent by the unicast socket.
- mdnsClient.sendPacketRequestingUnicastResponse(packet,
+ mdnsClient.sendPacketRequestingUnicastResponse(List.of(packet),
false /* onlyUseIpv6OnIpv6OnlyNetworks */);
verify(mockMulticastSocket, timeout(TIMEOUT).times(1)).send(packet);
verify(mockUnicastSocket, timeout(TIMEOUT).times(1)).send(packet);
@@ -287,7 +291,7 @@
// Sends a packet.
DatagramPacket packet = getTestDatagramPacket();
- mdnsClient.sendPacketRequestingMulticastResponse(packet,
+ mdnsClient.sendPacketRequestingMulticastResponse(List.of(packet),
false /* onlyUseIpv6OnIpv6OnlyNetworks */);
// mockMulticastSocket.send() will be called on another thread. If we verify it immediately,
// it may not be called yet. So timeout is added.
@@ -295,7 +299,7 @@
verify(mockUnicastSocket, timeout(TIMEOUT).times(0)).send(packet);
// Verify the packet is sent by the multicast socket as well.
- mdnsClient.sendPacketRequestingUnicastResponse(packet,
+ mdnsClient.sendPacketRequestingUnicastResponse(List.of(packet),
false /* onlyUseIpv6OnIpv6OnlyNetworks */);
verify(mockMulticastSocket, timeout(TIMEOUT).times(2)).send(packet);
verify(mockUnicastSocket, timeout(TIMEOUT).times(0)).send(packet);
@@ -354,7 +358,7 @@
public void testStopDiscovery_queueIsCleared() throws IOException {
mdnsClient.startDiscovery();
mdnsClient.stopDiscovery();
- mdnsClient.sendPacketRequestingMulticastResponse(getTestDatagramPacket(),
+ mdnsClient.sendPacketRequestingMulticastResponse(List.of(getTestDatagramPacket()),
false /* onlyUseIpv6OnIpv6OnlyNetworks */);
synchronized (mdnsClient.multicastPacketQueue) {
@@ -366,7 +370,7 @@
public void testSendPacket_afterDiscoveryStops() throws IOException {
mdnsClient.startDiscovery();
mdnsClient.stopDiscovery();
- mdnsClient.sendPacketRequestingMulticastResponse(getTestDatagramPacket(),
+ mdnsClient.sendPacketRequestingMulticastResponse(List.of(getTestDatagramPacket()),
false /* onlyUseIpv6OnIpv6OnlyNetworks */);
synchronized (mdnsClient.multicastPacketQueue) {
@@ -380,7 +384,7 @@
//MdnsConfigsFlagsImpl.mdnsPacketQueueMaxSize.override(2L);
mdnsClient.startDiscovery();
for (int i = 0; i < 100; i++) {
- mdnsClient.sendPacketRequestingMulticastResponse(getTestDatagramPacket(),
+ mdnsClient.sendPacketRequestingMulticastResponse(List.of(getTestDatagramPacket()),
false /* onlyUseIpv6OnIpv6OnlyNetworks */);
}
@@ -478,9 +482,9 @@
mdnsClient.startDiscovery();
DatagramPacket packet = getTestDatagramPacket();
- mdnsClient.sendPacketRequestingUnicastResponse(packet,
+ mdnsClient.sendPacketRequestingUnicastResponse(List.of(packet),
false /* onlyUseIpv6OnIpv6OnlyNetworks */);
- mdnsClient.sendPacketRequestingMulticastResponse(packet,
+ mdnsClient.sendPacketRequestingMulticastResponse(List.of(packet),
false /* onlyUseIpv6OnIpv6OnlyNetworks */);
// Wait for the timer to be triggered.
@@ -511,9 +515,9 @@
assertFalse(mdnsClient.receivedUnicastResponse);
assertFalse(mdnsClient.cannotReceiveMulticastResponse.get());
- mdnsClient.sendPacketRequestingUnicastResponse(packet,
+ mdnsClient.sendPacketRequestingUnicastResponse(List.of(packet),
false /* onlyUseIpv6OnIpv6OnlyNetworks */);
- mdnsClient.sendPacketRequestingMulticastResponse(packet,
+ mdnsClient.sendPacketRequestingMulticastResponse(List.of(packet),
false /* onlyUseIpv6OnIpv6OnlyNetworks */);
Thread.sleep(MdnsConfigs.checkMulticastResponseIntervalMs() * 2);
@@ -570,6 +574,26 @@
.onResponseReceived(any(), argThat(key -> key.getInterfaceIndex() == -1));
}
+ @Test
+ public void testSendPacketWithMultipleDatagramPacket() throws IOException {
+ mdnsClient.startDiscovery();
+ final List<DatagramPacket> packets = new ArrayList<>();
+ for (int i = 0; i < 10; i++) {
+ packets.add(new DatagramPacket(new byte[10 + i] /* buff */, 0 /* offset */,
+ 10 + i /* length */, MdnsConstants.IPV4_SOCKET_ADDR));
+ }
+
+ // Sends packets.
+ mdnsClient.sendPacketRequestingMulticastResponse(packets,
+ false /* onlyUseIpv6OnIpv6OnlyNetworks */);
+ InOrder inOrder = inOrder(mockMulticastSocket);
+ for (int i = 0; i < 10; i++) {
+ // mockMulticastSocket.send() will be called on another thread. If we verify it
+ // immediately, it may not be called yet. So timeout is added.
+ inOrder.verify(mockMulticastSocket, timeout(TIMEOUT)).send(packets.get(i));
+ }
+ }
+
private DatagramPacket getTestDatagramPacket() {
return new DatagramPacket(buf, 0, 5,
new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), 5353 /* port */));
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
index b1a7233..009205e 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
@@ -52,19 +52,27 @@
assertEquals("ţést", toDnsLowerCase("ţést"))
// Unicode characters 0x10000 (𐀀), 0x10001 (𐀁), 0x10041 (𐁁)
// Note the last 2 bytes of 0x10041 are identical to 'A', but it should remain unchanged.
- assertEquals("test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ",
- toDnsLowerCase("Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- "))
+ assertEquals(
+ "test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ",
+ toDnsLowerCase("Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ")
+ )
// Also test some characters where the first surrogate is not \ud800
- assertEquals("test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
+ assertEquals(
+ "test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
"\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<",
- toDnsLowerCase("Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
- "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"))
+ toDnsLowerCase(
+ "Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
+ "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"
+ )
+ )
}
@Test
fun testToDnsLabelsLowerCase() {
- assertArrayEquals(arrayOf("test", "tÉst", "ţést"),
- toDnsLabelsLowerCase(arrayOf("TeSt", "TÉST", "ţést")))
+ assertArrayEquals(
+ arrayOf("test", "tÉst", "ţést"),
+ toDnsLabelsLowerCase(arrayOf("TeSt", "TÉST", "ţést"))
+ )
}
@Test
@@ -76,13 +84,17 @@
assertFalse(equalsIgnoreDnsCase("ŢÉST", "ţést"))
// Unicode characters 0x10000 (𐀀), 0x10001 (𐀁), 0x10041 (𐁁)
// Note the last 2 bytes of 0x10041 are identical to 'A', but it should remain unchanged.
- assertTrue(equalsIgnoreDnsCase("test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ",
- "Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- "))
+ assertTrue(equalsIgnoreDnsCase(
+ "test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ",
+ "Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- "
+ ))
// Also test some characters where the first surrogate is not \ud800
- assertTrue(equalsIgnoreDnsCase("test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
+ assertTrue(equalsIgnoreDnsCase(
+ "test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
"\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<",
"Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
- "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"))
+ "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"
+ ))
}
@Test
@@ -101,15 +113,22 @@
@Test
fun testTypeEqualsOrIsSubtype() {
- assertTrue(MdnsUtils.typeEqualsOrIsSubtype(arrayOf("_type", "_tcp", "local"),
- arrayOf("_type", "_TCP", "local")))
- assertTrue(MdnsUtils.typeEqualsOrIsSubtype(arrayOf("_type", "_tcp", "local"),
- arrayOf("a", "_SUB", "_type", "_TCP", "local")))
- assertFalse(MdnsUtils.typeEqualsOrIsSubtype(arrayOf("_sub", "_type", "_tcp", "local"),
- arrayOf("_type", "_TCP", "local")))
+ assertTrue(MdnsUtils.typeEqualsOrIsSubtype(
+ arrayOf("_type", "_tcp", "local"),
+ arrayOf("_type", "_TCP", "local")
+ ))
+ assertTrue(MdnsUtils.typeEqualsOrIsSubtype(
+ arrayOf("_type", "_tcp", "local"),
+ arrayOf("a", "_SUB", "_type", "_TCP", "local")
+ ))
+ assertFalse(MdnsUtils.typeEqualsOrIsSubtype(
+ arrayOf("_sub", "_type", "_tcp", "local"),
+ arrayOf("_type", "_TCP", "local")
+ ))
assertFalse(MdnsUtils.typeEqualsOrIsSubtype(
arrayOf("a", "_other", "_type", "_tcp", "local"),
- arrayOf("a", "_SUB", "_type", "_TCP", "local")))
+ arrayOf("a", "_SUB", "_type", "_TCP", "local")
+ ))
}
@Test
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt b/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt
index e8664c1..bb7fb51 100644
--- a/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt
@@ -30,6 +30,7 @@
import android.net.VpnTransportInfo
import android.os.Build
import androidx.test.filters.SmallTest
+import com.android.server.connectivity.ConnectivityFlags
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner
import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
@@ -56,7 +57,9 @@
TYPE_VPN_SERVICE,
"MySession12345",
false /* bypassable */,
- false /* longLivedTcpConnectionsExpensive */))
+ false /* longLivedTcpConnectionsExpensive */
+ )
+ )
.build()
private fun wifiNc() = NetworkCapabilities.Builder()
@@ -286,4 +289,19 @@
waitForIdle()
verify(bpfNetMaps).removeIngressDiscardRule(IPV6_ADDRESS)
}
+
+ @Test @FeatureFlags([Flag(ConnectivityFlags.INGRESS_TO_VPN_ADDRESS_FILTERING, false)])
+ fun testVpnIngressDiscardRule_FeatureDisabled() {
+ val nr = nr(TRANSPORT_VPN)
+ val cb = TestableNetworkCallback()
+ cm.registerNetworkCallback(nr, cb)
+ val nc = vpnNc()
+ val lp = lp(VPN_IFNAME, IPV6_LINK_ADDRESS, LOCAL_IPV6_LINK_ADDRRESS)
+ val agent = Agent(nc = nc, lp = lp)
+ agent.connect()
+ cb.expectAvailableCallbacks(agent.network, validated = false)
+
+ // IngressDiscardRule should not be added since feature is disabled
+ verify(bpfNetMaps, never()).setIngressDiscardRule(any(), any())
+ }
}
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
index 5c4617b..bd26c63 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
@@ -79,11 +79,15 @@
import java.util.concurrent.TimeUnit
import java.util.function.BiConsumer
import java.util.function.Consumer
+import kotlin.annotation.AnnotationRetention.RUNTIME
+import kotlin.annotation.AnnotationTarget.FUNCTION
import kotlin.test.assertNotNull
import kotlin.test.assertNull
import kotlin.test.fail
import org.junit.After
import org.junit.Before
+import org.junit.Rule
+import org.junit.rules.TestName
import org.mockito.AdditionalAnswers.delegatesTo
import org.mockito.Mockito.doAnswer
import org.mockito.Mockito.doReturn
@@ -126,6 +130,9 @@
// TODO (b/272685721) : make ConnectivityServiceTest smaller and faster by moving the setup
// parts into this class and moving the individual tests to multiple separate classes.
open class CSTest {
+ @get:Rule
+ val testNameRule = TestName()
+
companion object {
val CSTestExecutor = Executors.newSingleThreadExecutor()
}
@@ -155,8 +162,7 @@
it[ConnectivityService.ALLOW_SATALLITE_NETWORK_FALLBACK] = true
it[ConnectivityFlags.INGRESS_TO_VPN_ADDRESS_FILTERING] = true
}
- fun enableFeature(f: String) = enabledFeatures.set(f, true)
- fun disableFeature(f: String) = enabledFeatures.set(f, false)
+ fun setFeatureEnabled(flag: String, enabled: Boolean) = enabledFeatures.set(flag, enabled)
// When adding new members, consider if it's not better to build the object in CSTestHelpers
// to keep this file clean of implementation details. Generally, CSTestHelpers should only
@@ -201,8 +207,32 @@
lateinit var cm: ConnectivityManager
lateinit var csHandler: Handler
+ // Tests can use this annotation to set flag values before constructing ConnectivityService
+ // e.g. @FeatureFlags([Flag(flagName1, true/false), Flag(flagName2, true/false)])
+ @Retention(RUNTIME)
+ @Target(FUNCTION)
+ annotation class FeatureFlags(val flags: Array<Flag>)
+
+ @Retention(RUNTIME)
+ @Target(FUNCTION)
+ annotation class Flag(val name: String, val enabled: Boolean)
+
@Before
fun setUp() {
+ // Set feature flags before constructing ConnectivityService
+ val testMethodName = testNameRule.methodName
+ try {
+ val testMethod = this::class.java.getMethod(testMethodName)
+ val featureFlags = testMethod.getAnnotation(FeatureFlags::class.java)
+ if (featureFlags != null) {
+ for (flag in featureFlags.flags) {
+ setFeatureEnabled(flag.name, flag.enabled)
+ }
+ }
+ } catch (ignored: NoSuchMethodException) {
+ // This is expected for parameterized tests
+ }
+
alarmHandlerThread = HandlerThread("TestAlarmManager").also { it.start() }
alarmManager = makeMockAlarmManager(alarmHandlerThread)
service = makeConnectivityService(context, netd, deps).also { it.systemReadyInternal() }
diff --git a/thread/service/java/com/android/server/thread/NsdPublisher.java b/thread/service/java/com/android/server/thread/NsdPublisher.java
index 72e3980..2c14f1d 100644
--- a/thread/service/java/com/android/server/thread/NsdPublisher.java
+++ b/thread/service/java/com/android/server/thread/NsdPublisher.java
@@ -39,10 +39,8 @@
import java.net.Inet6Address;
import java.net.InetAddress;
-import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -56,14 +54,6 @@
*
* <p>All the data members of this class MUST be accessed in the {@code mHandler}'s Thread except
* {@code mHandler} itself.
- *
- * <p>TODO: b/323300118 - Remove the following mechanism when the race condition in NsdManager is
- * fixed.
- *
- * <p>There's always only one running registration job at any timepoint. All other pending jobs are
- * queued in {@code mRegistrationJobs}. When a registration job is complete (i.e. the according
- * method in {@link NsdManager.RegistrationListener} is called), it will start the next registration
- * job in the queue.
*/
public final class NsdPublisher extends INsdPublisher.Stub {
// TODO: b/321883491 - specify network for mDNS operations
@@ -74,7 +64,6 @@
private final SparseArray<RegistrationListener> mRegistrationListeners = new SparseArray<>(0);
private final SparseArray<DiscoveryListener> mDiscoveryListeners = new SparseArray<>(0);
private final SparseArray<ServiceInfoListener> mServiceInfoListeners = new SparseArray<>(0);
- private final Deque<Runnable> mRegistrationJobs = new ArrayDeque<>();
@VisibleForTesting
public NsdPublisher(NsdManager nsdManager, Handler handler) {
@@ -97,13 +86,9 @@
List<DnsTxtAttribute> txt,
INsdStatusReceiver receiver,
int listenerId) {
- postRegistrationJob(
- () -> {
- NsdServiceInfo serviceInfo =
- buildServiceInfoForService(
- hostname, name, type, subTypeList, port, txt);
- registerInternal(serviceInfo, receiver, listenerId, "service");
- });
+ NsdServiceInfo serviceInfo =
+ buildServiceInfoForService(hostname, name, type, subTypeList, port, txt);
+ mHandler.post(() -> registerInternal(serviceInfo, receiver, listenerId, "service"));
}
private static NsdServiceInfo buildServiceInfoForService(
@@ -132,11 +117,8 @@
@Override
public void registerHost(
String name, List<String> addresses, INsdStatusReceiver receiver, int listenerId) {
- postRegistrationJob(
- () -> {
- NsdServiceInfo serviceInfo = buildServiceInfoForHost(name, addresses);
- registerInternal(serviceInfo, receiver, listenerId, "host");
- });
+ NsdServiceInfo serviceInfo = buildServiceInfoForHost(name, addresses);
+ mHandler.post(() -> registerInternal(serviceInfo, receiver, listenerId, "host"));
}
private static NsdServiceInfo buildServiceInfoForHost(
@@ -178,7 +160,7 @@
}
public void unregister(INsdStatusReceiver receiver, int listenerId) {
- postRegistrationJob(() -> unregisterInternal(receiver, listenerId));
+ mHandler.post(() -> unregisterInternal(receiver, listenerId));
}
public void unregisterInternal(INsdStatusReceiver receiver, int listenerId) {
@@ -338,7 +320,6 @@
}
}
mRegistrationListeners.clear();
- mRegistrationJobs.clear();
}
/** On ot-daemon died, reset. */
@@ -346,39 +327,6 @@
reset();
}
- // TODO: b/323300118 - Remove this mechanism when the race condition in NsdManager is fixed.
- /** Fetch the first job from the queue and run it. See the class doc for more details. */
- private void peekAndRun() {
- if (mRegistrationJobs.isEmpty()) {
- return;
- }
- Runnable job = mRegistrationJobs.getFirst();
- job.run();
- }
-
- // TODO: b/323300118 - Remove this mechanism when the race condition in NsdManager is fixed.
- /**
- * Pop the first job from the queue and run the next job. See the class doc for more details.
- */
- private void popAndRunNext() {
- if (mRegistrationJobs.isEmpty()) {
- Log.i(TAG, "No registration jobs when trying to pop and run next.");
- return;
- }
- mRegistrationJobs.removeFirst();
- peekAndRun();
- }
-
- private void postRegistrationJob(Runnable registrationJob) {
- mHandler.post(
- () -> {
- mRegistrationJobs.addLast(registrationJob);
- if (mRegistrationJobs.size() == 1) {
- peekAndRun();
- }
- });
- }
-
private final class RegistrationListener implements NsdManager.RegistrationListener {
private final NsdServiceInfo mServiceInfo;
private final int mListenerId;
@@ -416,7 +364,6 @@
} catch (RemoteException ignored) {
// do nothing if the client is dead
}
- popAndRunNext();
}
@Override
@@ -438,7 +385,6 @@
// do nothing if the client is dead
}
}
- popAndRunNext();
}
@Override
@@ -456,7 +402,6 @@
} catch (RemoteException ignored) {
// do nothing if the client is dead
}
- popAndRunNext();
}
@Override
@@ -477,7 +422,6 @@
}
}
mRegistrationListeners.remove(mListenerId);
- popAndRunNext();
}
}
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
index 0b13d1b..d80dcfb 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
@@ -182,6 +182,7 @@
private final NsdPublisher mNsdPublisher;
private final OtDaemonCallbackProxy mOtDaemonCallbackProxy = new OtDaemonCallbackProxy();
private final ConnectivityResources mResources;
+ private final Supplier<String> mCountryCodeSupplier;
// This should not be directly used for calling IOtDaemon APIs because ot-daemon may die and
// {@code mOtDaemon} will be set to {@code null}. Instead, use {@code getOtDaemon()}
@@ -215,7 +216,8 @@
ThreadPersistentSettings persistentSettings,
NsdPublisher nsdPublisher,
UserManager userManager,
- ConnectivityResources resources) {
+ ConnectivityResources resources,
+ Supplier<String> countryCodeSupplier) {
mContext = context;
mHandler = handler;
mNetworkProvider = networkProvider;
@@ -230,10 +232,13 @@
mNsdPublisher = nsdPublisher;
mUserManager = userManager;
mResources = resources;
+ mCountryCodeSupplier = countryCodeSupplier;
}
public static ThreadNetworkControllerService newInstance(
- Context context, ThreadPersistentSettings persistentSettings) {
+ Context context,
+ ThreadPersistentSettings persistentSettings,
+ Supplier<String> countryCodeSupplier) {
HandlerThread handlerThread = new HandlerThread("ThreadHandlerThread");
handlerThread.start();
Handler handler = new Handler(handlerThread.getLooper());
@@ -251,7 +256,8 @@
persistentSettings,
NsdPublisher.newInstance(context, handler),
context.getSystemService(UserManager.class),
- new ConnectivityResources(context));
+ new ConnectivityResources(context),
+ countryCodeSupplier);
}
private static Inet6Address bytesToInet6Address(byte[] addressBytes) {
@@ -347,7 +353,8 @@
isEnabled(),
mNsdPublisher,
getMeshcopTxtAttributes(mResources.get()),
- mOtDaemonCallbackProxy);
+ mOtDaemonCallbackProxy,
+ mCountryCodeSupplier.get());
otDaemon.asBinder().linkToDeath(() -> mHandler.post(this::onOtDaemonDied), 0);
mOtDaemon = otDaemon;
return mOtDaemon;
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java b/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
index ffa7b44..a194114 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
@@ -16,6 +16,8 @@
package com.android.server.thread;
+import static com.android.server.thread.ThreadPersistentSettings.THREAD_COUNTRY_CODE;
+
import android.annotation.Nullable;
import android.annotation.StringDef;
import android.annotation.TargetApi;
@@ -83,6 +85,7 @@
COUNTRY_CODE_SOURCE_TELEPHONY,
COUNTRY_CODE_SOURCE_TELEPHONY_LAST,
COUNTRY_CODE_SOURCE_WIFI,
+ COUNTRY_CODE_SOURCE_SETTINGS,
})
private @interface CountryCodeSource {}
@@ -93,6 +96,7 @@
private static final String COUNTRY_CODE_SOURCE_TELEPHONY = "Telephony";
private static final String COUNTRY_CODE_SOURCE_TELEPHONY_LAST = "TelephonyLast";
private static final String COUNTRY_CODE_SOURCE_WIFI = "Wifi";
+ private static final String COUNTRY_CODE_SOURCE_SETTINGS = "Settings";
private static final CountryCodeInfo DEFAULT_COUNTRY_CODE_INFO =
new CountryCodeInfo(DEFAULT_COUNTRY_CODE, COUNTRY_CODE_SOURCE_DEFAULT);
@@ -107,6 +111,7 @@
private final SubscriptionManager mSubscriptionManager;
private final Map<Integer, TelephonyCountryCodeSlotInfo> mTelephonyCountryCodeSlotInfoMap =
new ArrayMap();
+ private final ThreadPersistentSettings mPersistentSettings;
@Nullable private CountryCodeInfo mCurrentCountryCodeInfo;
@Nullable private CountryCodeInfo mLocationCountryCodeInfo;
@@ -215,7 +220,8 @@
Context context,
TelephonyManager telephonyManager,
SubscriptionManager subscriptionManager,
- @Nullable String oemCountryCode) {
+ @Nullable String oemCountryCode,
+ ThreadPersistentSettings persistentSettings) {
mLocationManager = locationManager;
mThreadNetworkControllerService = threadNetworkControllerService;
mGeocoder = geocoder;
@@ -224,14 +230,19 @@
mContext = context;
mTelephonyManager = telephonyManager;
mSubscriptionManager = subscriptionManager;
+ mPersistentSettings = persistentSettings;
if (oemCountryCode != null) {
mOemCountryCodeInfo = new CountryCodeInfo(oemCountryCode, COUNTRY_CODE_SOURCE_OEM);
}
+
+ mCurrentCountryCodeInfo = pickCountryCode();
}
public static ThreadNetworkCountryCode newInstance(
- Context context, ThreadNetworkControllerService controllerService) {
+ Context context,
+ ThreadNetworkControllerService controllerService,
+ ThreadPersistentSettings persistentSettings) {
return new ThreadNetworkCountryCode(
context.getSystemService(LocationManager.class),
controllerService,
@@ -241,7 +252,8 @@
context,
context.getSystemService(TelephonyManager.class),
context.getSystemService(SubscriptionManager.class),
- ThreadNetworkProperties.country_code().orElse(null));
+ ThreadNetworkProperties.country_code().orElse(null),
+ persistentSettings);
}
/** Sets up this country code module to listen to location country code changes. */
@@ -485,6 +497,11 @@
return mLocationCountryCodeInfo;
}
+ String settingsCountryCode = mPersistentSettings.get(THREAD_COUNTRY_CODE);
+ if (settingsCountryCode != null) {
+ return new CountryCodeInfo(settingsCountryCode, COUNTRY_CODE_SOURCE_SETTINGS);
+ }
+
if (mOemCountryCodeInfo != null) {
return mOemCountryCodeInfo;
}
@@ -498,6 +515,8 @@
public void onSuccess() {
synchronized ("ThreadNetworkCountryCode.this") {
mCurrentCountryCodeInfo = countryCodeInfo;
+ mPersistentSettings.put(
+ THREAD_COUNTRY_CODE.key, countryCodeInfo.getCountryCode());
}
}
@@ -536,10 +555,9 @@
newOperationReceiver(countryCodeInfo));
}
- /** Returns the current country code or {@code null} if no country code is set. */
- @Nullable
+ /** Returns the current country code. */
public synchronized String getCountryCode() {
- return (mCurrentCountryCodeInfo != null) ? mCurrentCountryCodeInfo.getCountryCode() : null;
+ return mCurrentCountryCodeInfo.getCountryCode();
}
/**
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkService.java b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
index 37c1cf1..30c67ca 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
@@ -60,13 +60,16 @@
if (phase == SystemService.PHASE_SYSTEM_SERVICES_READY) {
mPersistentSettings.initialize();
mControllerService =
- ThreadNetworkControllerService.newInstance(mContext, mPersistentSettings);
+ ThreadNetworkControllerService.newInstance(
+ mContext, mPersistentSettings, () -> mCountryCode.getCountryCode());
+ mCountryCode =
+ ThreadNetworkCountryCode.newInstance(
+ mContext, mControllerService, mPersistentSettings);
mControllerService.initialize();
} else if (phase == SystemService.PHASE_BOOT_COMPLETED) {
// Country code initialization is delayed to the BOOT_COMPLETED phase because it will
// call into Wi-Fi and Telephony service whose country code module is ready after
// PHASE_ACTIVITY_MANAGER_READY and PHASE_THIRD_PARTY_APPS_CAN_START
- mCountryCode = ThreadNetworkCountryCode.newInstance(mContext, mControllerService);
mCountryCode.initialize();
mShellCommand =
new ThreadNetworkShellCommand(requireNonNull(mControllerService), mCountryCode);
diff --git a/thread/service/java/com/android/server/thread/ThreadPersistentSettings.java b/thread/service/java/com/android/server/thread/ThreadPersistentSettings.java
index 923f002..8aaff60 100644
--- a/thread/service/java/com/android/server/thread/ThreadPersistentSettings.java
+++ b/thread/service/java/com/android/server/thread/ThreadPersistentSettings.java
@@ -63,6 +63,9 @@
/** Stores the Thread feature toggle state, true for enabled and false for disabled. */
public static final Key<Boolean> THREAD_ENABLED = new Key<>("thread_enabled", true);
+ /** Stores the Thread country code, null if no country code is stored. */
+ public static final Key<String> THREAD_COUNTRY_CODE = new Key<>("thread_country_code", null);
+
/******** Thread persistent setting keys ***************/
@GuardedBy("mLock")
@@ -123,7 +126,9 @@
private <T> T getObject(String key, T defaultValue) {
Object value;
synchronized (mLock) {
- if (defaultValue instanceof Boolean) {
+ if (defaultValue == null) {
+ value = mSettings.getString(key, null);
+ } else if (defaultValue instanceof Boolean) {
value = mSettings.getBoolean(key, (Boolean) defaultValue);
} else if (defaultValue instanceof Integer) {
value = mSettings.getInt(key, (Integer) defaultValue);
diff --git a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
index 0591c87..9a81388 100644
--- a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
+++ b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
@@ -864,11 +864,12 @@
@Test
public void meshcopService_threadDisabled_notDiscovered() throws Exception {
setUpTestNetwork();
-
CompletableFuture<NsdServiceInfo> serviceLostFuture = new CompletableFuture<>();
NsdManager.DiscoveryListener listener =
discoverForServiceLost(MESHCOP_SERVICE_TYPE, serviceLostFuture);
+
setEnabledAndWait(mController, false);
+
try {
serviceLostFuture.get(SERVICE_LOST_TIMEOUT_MILLIS, MILLISECONDS);
} catch (InterruptedException | ExecutionException | TimeoutException ignored) {
@@ -877,7 +878,6 @@
} finally {
mNsdManager.stopServiceDiscovery(listener);
}
-
assertThrows(
TimeoutException.class,
() -> discoverService(MESHCOP_SERVICE_TYPE, SERVICE_LOST_TIMEOUT_MILLIS));
@@ -1112,7 +1112,12 @@
serviceInfoFuture.complete(serviceInfo);
}
};
- mNsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, listener);
+ mNsdManager.discoverServices(
+ serviceType,
+ NsdManager.PROTOCOL_DNS_SD,
+ mTestNetworkTracker.getNetwork(),
+ mExecutor,
+ listener);
try {
serviceInfoFuture.get(timeoutMilliseconds, MILLISECONDS);
} finally {
@@ -1131,7 +1136,12 @@
serviceInfoFuture.complete(serviceInfo);
}
};
- mNsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, listener);
+ mNsdManager.discoverServices(
+ serviceType,
+ NsdManager.PROTOCOL_DNS_SD,
+ mTestNetworkTracker.getNetwork(),
+ mExecutor,
+ listener);
return listener;
}
diff --git a/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java b/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java
index bfded1d..c70f3af 100644
--- a/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java
+++ b/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java
@@ -23,6 +23,7 @@
import static android.net.thread.utils.IntegrationTestUtils.RESTART_JOIN_TIMEOUT;
import static com.android.compatibility.common.util.SystemUtil.runShellCommand;
+import static com.android.compatibility.common.util.SystemUtil.runShellCommandOrThrow;
import static com.google.common.io.BaseEncoding.base16;
import static com.google.common.truth.Truth.assertThat;
@@ -140,6 +141,22 @@
}
}
+ @Test
+ public void otDaemonRestart_latestCountryCodeIsSetToOtDaemon() throws Exception {
+ runThreadCommand("force-country-code enabled CN");
+
+ runShellCommand("stop ot-daemon");
+ // TODO(b/323331973): the sleep is needed to workaround the race conditions
+ SystemClock.sleep(200);
+ mController.waitForRole(DEVICE_ROLE_STOPPED, CALLBACK_TIMEOUT);
+
+ assertThat(mOtCtl.getCountryCode()).isEqualTo("CN");
+ }
+
+ private static String runThreadCommand(String cmd) {
+ return runShellCommandOrThrow("cmd thread_network " + cmd);
+ }
+
// TODO (b/323300829): add more tests for integration with linux platform and
// ConnectivityService
}
diff --git a/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java b/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
index ade0669..f39a064 100644
--- a/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
+++ b/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
@@ -74,6 +74,12 @@
return (Inet6Address) InetAddresses.parseNumericAddress(addressStr);
}
+ /** Returns the country code on ot-daemon. */
+ public String getCountryCode() {
+ String countryCodeStr = executeCommand("region").split("\n")[0].trim();
+ return countryCodeStr;
+ }
+
public String executeCommand(String cmd) {
return SystemUtil.runShellCommand(OT_CTL + " " + cmd);
}
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
index 0c7d086..85b6873 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
@@ -25,6 +25,7 @@
import static android.net.thread.ThreadNetworkManager.DISALLOW_THREAD_NETWORK;
import static android.net.thread.ThreadNetworkManager.PERMISSION_THREAD_NETWORK_PRIVILEGED;
+import static com.android.server.thread.ThreadNetworkCountryCode.DEFAULT_COUNTRY_CODE;
import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_INVALID_STATE;
import static com.google.common.io.BaseEncoding.base16;
@@ -182,7 +183,8 @@
mMockPersistentSettings,
mMockNsdPublisher,
mMockUserManager,
- mConnectivityResources);
+ mConnectivityResources,
+ () -> DEFAULT_COUNTRY_CODE);
mService.setTestNetworkAgent(mMockNetworkAgent);
}
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
index 5ca6511..ca9741d 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
@@ -19,6 +19,7 @@
import static android.net.thread.ThreadNetworkException.ERROR_INTERNAL_ERROR;
import static com.android.server.thread.ThreadNetworkCountryCode.DEFAULT_COUNTRY_CODE;
+import static com.android.server.thread.ThreadPersistentSettings.THREAD_COUNTRY_CODE;
import static com.google.common.truth.Truth.assertThat;
@@ -104,6 +105,7 @@
@Mock List<SubscriptionInfo> mSubscriptionInfoList;
@Mock SubscriptionInfo mSubscriptionInfo0;
@Mock SubscriptionInfo mSubscriptionInfo1;
+ @Mock ThreadPersistentSettings mPersistentSettings;
private ThreadNetworkCountryCode mThreadNetworkCountryCode;
private boolean mErrorSetCountryCode;
@@ -164,7 +166,8 @@
mContext,
mTelephonyManager,
mSubscriptionManager,
- oemCountryCode);
+ oemCountryCode,
+ mPersistentSettings);
}
private static Address newAddress(String countryCode) {
@@ -450,6 +453,14 @@
}
@Test
+ public void settingsCountryCode_settingsCountryCodeIsActive_settingsCountryCodeIsUsed() {
+ when(mPersistentSettings.get(THREAD_COUNTRY_CODE)).thenReturn(TEST_COUNTRY_CODE_CN);
+ mThreadNetworkCountryCode.initialize();
+
+ assertThat(mThreadNetworkCountryCode.getCountryCode()).isEqualTo(TEST_COUNTRY_CODE_CN);
+ }
+
+ @Test
public void dump_allCountryCodeInfoAreDumped() {
StringWriter stringWriter = new StringWriter();
PrintWriter printWriter = new PrintWriter(stringWriter);
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadPersistentSettingsTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadPersistentSettingsTest.java
index 9406a2f..7d2fe91 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadPersistentSettingsTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadPersistentSettingsTest.java
@@ -16,6 +16,7 @@
package com.android.server.thread;
+import static com.android.server.thread.ThreadPersistentSettings.THREAD_COUNTRY_CODE;
import static com.android.server.thread.ThreadPersistentSettings.THREAD_ENABLED;
import static com.google.common.truth.Truth.assertThat;
@@ -54,6 +55,8 @@
@RunWith(AndroidJUnit4.class)
@SmallTest
public class ThreadPersistentSettingsTest {
+ private static final String TEST_COUNTRY_CODE = "CN";
+
@Mock private AtomicFile mAtomicFile;
@Mock Resources mResources;
@Mock ConnectivityResources mConnectivityResources;
@@ -131,6 +134,28 @@
verify(mAtomicFile).finishWrite(any());
}
+ @Test
+ public void put_ThreadCountryCodeString_returnsString() throws Exception {
+ mThreadPersistentSettings.put(THREAD_COUNTRY_CODE.key, TEST_COUNTRY_CODE);
+
+ assertThat(mThreadPersistentSettings.get(THREAD_COUNTRY_CODE)).isEqualTo(TEST_COUNTRY_CODE);
+
+ // Confirm that file writes have been triggered.
+ verify(mAtomicFile).startWrite();
+ verify(mAtomicFile).finishWrite(any());
+ }
+
+ @Test
+ public void put_ThreadCountryCodeNull_returnsNull() throws Exception {
+ mThreadPersistentSettings.put(THREAD_COUNTRY_CODE.key, null);
+
+ assertThat(mThreadPersistentSettings.get(THREAD_COUNTRY_CODE)).isNull();
+
+ // Confirm that file writes have been triggered.
+ verify(mAtomicFile).startWrite();
+ verify(mAtomicFile).finishWrite(any());
+ }
+
private byte[] createXmlForParsing(String key, Boolean value) throws Exception {
PersistableBundle bundle = new PersistableBundle();
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
diff --git a/thread/tests/utils/src/android/net/thread/utils/TapTestNetworkTracker.java b/thread/tests/utils/src/android/net/thread/utils/TapTestNetworkTracker.java
index 43f177d..b586a19 100644
--- a/thread/tests/utils/src/android/net/thread/utils/TapTestNetworkTracker.java
+++ b/thread/tests/utils/src/android/net/thread/utils/TapTestNetworkTracker.java
@@ -62,6 +62,7 @@
private final Looper mLooper;
private TestNetworkInterface mInterface;
private TestableNetworkAgent mAgent;
+ private Network mNetwork;
private final TestableNetworkCallback mNetworkCallback;
private final ConnectivityManager mConnectivityManager;
@@ -91,6 +92,11 @@
return mInterface.getInterfaceName();
}
+ /** Returns the {@link android.net.Network} of the test network. */
+ public Network getNetwork() {
+ return mNetwork;
+ }
+
private void setUpTestNetwork() throws Exception {
mInterface = mContext.getSystemService(TestNetworkManager.class).createTapInterface();
@@ -105,13 +111,13 @@
newNetworkCapabilities(),
lp,
new NetworkAgentConfig.Builder().build());
- final Network network = mAgent.register();
+ mNetwork = mAgent.register();
mAgent.markConnected();
PollingCheck.check(
"No usable address on interface",
TIMEOUT.toMillis(),
- () -> hasUsableAddress(network, getInterfaceName()));
+ () -> hasUsableAddress(mNetwork, getInterfaceName()));
lp.setLinkAddresses(makeLinkAddresses());
mAgent.sendLinkProperties(lp);