Merge "Add isFeatureEnabled API for CaptivePortalLogin experiment namespace." into main
diff --git a/TEST_MAPPING b/TEST_MAPPING
index ab3ed66..d8d4c21 100644
--- a/TEST_MAPPING
+++ b/TEST_MAPPING
@@ -246,6 +246,9 @@
         },
         {
           "exclude-annotation": "com.android.testutils.DnsResolverModuleTest"
+        },
+        {
+          "exclude-annotation": "com.android.testutils.NetworkStackModuleTest"
         }
       ]
     },
diff --git a/framework/src/android/net/NetworkCapabilities.java b/framework/src/android/net/NetworkCapabilities.java
index 84a0d29..85b1dac 100644
--- a/framework/src/android/net/NetworkCapabilities.java
+++ b/framework/src/android/net/NetworkCapabilities.java
@@ -1775,8 +1775,7 @@
                 // use the same specifier, TelephonyNetworkSpecifier.
                 && mTransportTypes != (1L << TRANSPORT_TEST)
                 && Long.bitCount(mTransportTypes & ~(1L << TRANSPORT_TEST)) != 1
-                && (mTransportTypes & ~(1L << TRANSPORT_TEST))
-                != (1 << TRANSPORT_CELLULAR | 1 << TRANSPORT_SATELLITE)) {
+                && !specifierAcceptableForMultipleTransports(mTransportTypes)) {
             throw new IllegalStateException("Must have a single non-test transport specified to "
                     + "use setNetworkSpecifier");
         }
@@ -1786,6 +1785,16 @@
         return this;
     }
 
+    /**
+     * Returns whether the given transports may share one network specifier even though more than
+     * one non-test transport is set. Only the exact cellular + satellite combination qualifies.
+     */
+    private static boolean specifierAcceptableForMultipleTransports(long transportTypes) {
+        return (transportTypes & ~(1L << TRANSPORT_TEST))
+                // Cellular and satellite use the same NetworkSpecifier.
+                == (1L << TRANSPORT_CELLULAR | 1L << TRANSPORT_SATELLITE);
+    }
+
     /**
      * Sets the optional transport specific information.
      *
diff --git a/netd/BpfHandler.cpp b/netd/BpfHandler.cpp
index e6fc825..0d75c05 100644
--- a/netd/BpfHandler.cpp
+++ b/netd/BpfHandler.cpp
@@ -179,22 +179,24 @@
 }
 
 Status BpfHandler::init(const char* cg2_path) {
-    // Make sure BPF programs are loaded before doing anything
-    ALOGI("Waiting for BPF programs");
+    if (base::GetProperty("bpf.progs_loaded", "") != "1") {
+        // Make sure BPF programs are loaded before doing anything
+        ALOGI("Waiting for BPF programs");
 
-    if (true || !modules::sdklevel::IsAtLeastV()) {
-        waitForNetProgsLoaded();
-        ALOGI("Networking BPF programs are loaded");
+        if (true || !modules::sdklevel::IsAtLeastV()) {
+            waitForNetProgsLoaded();
+            ALOGI("Networking BPF programs are loaded");
 
-        if (!base::SetProperty("ctl.start", "mdnsd_loadbpf")) {
-            ALOGE("Failed to set property ctl.start=mdnsd_loadbpf, see dmesg for reason.");
-            abort();
+            if (!base::SetProperty("ctl.start", "mdnsd_loadbpf")) {
+                ALOGE("Failed to set property ctl.start=mdnsd_loadbpf, see dmesg for reason.");
+                abort();
+            }
+
+            ALOGI("Waiting for remaining BPF programs");
         }
 
-        ALOGI("Waiting for remaining BPF programs");
+        android::bpf::waitForProgsLoaded();
     }
-
-    android::bpf::waitForProgsLoaded();
     ALOGI("BPF programs are loaded");
 
     RETURN_IF_NOT_OK(initPrograms(cg2_path));
diff --git a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
index e61555a..54943c7 100644
--- a/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
+++ b/service-t/src/com/android/server/connectivity/mdns/EnqueueMdnsQueryCallable.java
@@ -84,6 +84,7 @@
     private final byte[] packetCreationBuffer = new byte[1500]; // TODO: use interface MTU
     @NonNull
     private final List<MdnsResponse> existingServices;
+    private final boolean isQueryWithKnownAnswer;
 
     EnqueueMdnsQueryCallable(
             @NonNull MdnsSocketClientBase requestSender,
@@ -98,7 +99,8 @@
             @NonNull MdnsUtils.Clock clock,
             @NonNull SharedLog sharedLog,
             @NonNull MdnsServiceTypeClient.Dependencies dependencies,
-            @NonNull Collection<MdnsResponse> existingServices) {
+            @NonNull Collection<MdnsResponse> existingServices,
+            boolean isQueryWithKnownAnswer) {
         weakRequestSender = new WeakReference<>(requestSender);
         serviceTypeLabels = TextUtils.split(serviceType, "\\.");
         this.subtypes = new ArrayList<>(subtypes);
@@ -112,6 +114,7 @@
         this.sharedLog = sharedLog;
         this.dependencies = dependencies;
         this.existingServices = new ArrayList<>(existingServices);
+        this.isQueryWithKnownAnswer = isQueryWithKnownAnswer;
     }
 
     /**
@@ -226,27 +229,27 @@
 
     private void sendPacket(MdnsSocketClientBase requestSender, InetSocketAddress address,
             MdnsPacket mdnsPacket) throws IOException {
-        final DatagramPacket packet = dependencies.getDatagramPacketFromMdnsPacket(
-                packetCreationBuffer, mdnsPacket, address);
+        final List<DatagramPacket> packets = dependencies.getDatagramPacketsFromMdnsPacket(
+                packetCreationBuffer, mdnsPacket, address, isQueryWithKnownAnswer);
         if (expectUnicastResponse) {
             // MdnsMultinetworkSocketClient is only available on T+
             if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU
                     && requestSender instanceof MdnsMultinetworkSocketClient) {
                 ((MdnsMultinetworkSocketClient) requestSender).sendPacketRequestingUnicastResponse(
-                        packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
+                        packets, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
             } else {
                 requestSender.sendPacketRequestingUnicastResponse(
-                        packet, onlyUseIpv6OnIpv6OnlyNetworks);
+                        packets, onlyUseIpv6OnIpv6OnlyNetworks);
             }
         } else {
             if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU
                     && requestSender instanceof MdnsMultinetworkSocketClient) {
                 ((MdnsMultinetworkSocketClient) requestSender)
                         .sendPacketRequestingMulticastResponse(
-                                packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
+                                packets, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
             } else {
                 requestSender.sendPacketRequestingMulticastResponse(
-                        packet, onlyUseIpv6OnIpv6OnlyNetworks);
+                        packets, onlyUseIpv6OnIpv6OnlyNetworks);
             }
         }
     }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
index 869ac9b..fcfb15f 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -27,6 +27,7 @@
 import android.os.Handler;
 import android.os.Looper;
 import android.util.ArrayMap;
+import android.util.Log;
 
 import com.android.net.module.util.SharedLog;
 
@@ -213,24 +214,30 @@
         return true;
     }
 
-    private void sendMdnsPacket(@NonNull DatagramPacket packet, @NonNull SocketKey targetSocketKey,
-            boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+    private void sendMdnsPackets(@NonNull List<DatagramPacket> packets,
+            @NonNull SocketKey targetSocketKey, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
         final MdnsInterfaceSocket socket = getTargetSocket(targetSocketKey);
         if (socket == null) {
             mSharedLog.e("No socket matches targetSocketKey=" + targetSocketKey);
             return;
         }
+        if (packets.isEmpty()) {
+            Log.wtf(TAG, "No mDns packets to send");
+            return;
+        }
 
-        final boolean isIpv6 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
-                instanceof Inet6Address;
-        final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
-                instanceof Inet4Address;
+        final boolean isIpv6 = ((InetSocketAddress) packets.get(0).getSocketAddress())
+                .getAddress() instanceof Inet6Address;
+        final boolean isIpv4 = ((InetSocketAddress) packets.get(0).getSocketAddress())
+                .getAddress() instanceof Inet4Address;
         final boolean shouldQueryIpv6 = !onlyUseIpv6OnIpv6OnlyNetworks || !socket.hasJoinedIpv4();
         // Check ip capability and network before sending packet
         if ((isIpv6 && socket.hasJoinedIpv6() && shouldQueryIpv6)
                 || (isIpv4 && socket.hasJoinedIpv4())) {
             try {
-                socket.send(packet);
+                for (DatagramPacket packet : packets) {
+                    socket.send(packet);
+                }
             } catch (IOException e) {
                 mSharedLog.e("Failed to send a mDNS packet.", e);
             }
@@ -259,34 +266,34 @@
     }
 
     /**
-     * Send a mDNS request packet via given socket key that asks for multicast response.
+     * Send mDNS request packets via given socket key that asks for multicast response.
      */
-    public void sendPacketRequestingMulticastResponse(@NonNull DatagramPacket packet,
+    public void sendPacketRequestingMulticastResponse(@NonNull List<DatagramPacket> packets,
             @NonNull SocketKey socketKey, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
-        mHandler.post(() -> sendMdnsPacket(packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks));
+        mHandler.post(() -> sendMdnsPackets(packets, socketKey, onlyUseIpv6OnIpv6OnlyNetworks));
     }
 
     @Override
     public void sendPacketRequestingMulticastResponse(
-            @NonNull DatagramPacket packet, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+            @NonNull List<DatagramPacket> packets, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
         throw new UnsupportedOperationException("This socket client need to specify the socket to"
                 + "send packet");
     }
 
     /**
-     * Send a mDNS request packet via given socket key that asks for unicast response.
+     * Send mDNS request packets via given socket key that asks for unicast response.
      *
      * <p>The socket client may use a null network to identify some or all interfaces, in which case
      * passing null sends the packet to these.
      */
-    public void sendPacketRequestingUnicastResponse(@NonNull DatagramPacket packet,
+    public void sendPacketRequestingUnicastResponse(@NonNull List<DatagramPacket> packets,
             @NonNull SocketKey socketKey, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
-        mHandler.post(() -> sendMdnsPacket(packet, socketKey, onlyUseIpv6OnIpv6OnlyNetworks));
+        mHandler.post(() -> sendMdnsPackets(packets, socketKey, onlyUseIpv6OnIpv6OnlyNetworks));
     }
 
     @Override
     public void sendPacketRequestingUnicastResponse(
-            @NonNull DatagramPacket packet, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+            @NonNull List<DatagramPacket> packets, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
         throw new UnsupportedOperationException("This socket client need to specify the socket to"
                 + "send packet");
     }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
index 4b43989..1f9f42b 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
@@ -23,7 +23,8 @@
 import android.os.SystemClock;
 import android.text.TextUtils;
 
-import com.android.internal.annotations.VisibleForTesting;
+import androidx.annotation.VisibleForTesting;
+
 import com.android.server.connectivity.mdns.util.MdnsUtils;
 
 import java.io.IOException;
@@ -231,7 +232,7 @@
      * @param writer The writer to use.
      * @param now    The current system time. This is used when writing the updated TTL.
      */
-    @VisibleForTesting
+    @VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
     public final void write(MdnsPacketWriter writer, long now) throws IOException {
         writeHeaderFields(writer);
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index bfcd0b4..b3bdbe0 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -30,7 +30,8 @@
 import android.util.ArrayMap;
 import android.util.Pair;
 
-import com.android.internal.annotations.VisibleForTesting;
+import androidx.annotation.VisibleForTesting;
+
 import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.SharedLog;
 import com.android.server.connectivity.mdns.util.MdnsUtils;
@@ -195,7 +196,7 @@
     /**
      * Dependencies of MdnsServiceTypeClient, for injection in tests.
      */
-    @VisibleForTesting
+    @VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
     public static class Dependencies {
         /**
          * @see Handler#sendMessageDelayed(Message, long)
@@ -227,13 +228,22 @@
         }
 
         /**
-         * Generate a DatagramPacket from given MdnsPacket and InetSocketAddress.
+         * Generate the DatagramPackets from given MdnsPacket and InetSocketAddress.
+         *
+         * <p> If the query with known answer feature is enabled and the MdnsPacket is too large for
+         *     a single DatagramPacket, it will be split into multiple DatagramPackets.
          */
-        public DatagramPacket getDatagramPacketFromMdnsPacket(@NonNull byte[] packetCreationBuffer,
-                @NonNull MdnsPacket packet, @NonNull InetSocketAddress address) throws IOException {
-            final byte[] queryBuffer =
-                    MdnsUtils.createRawDnsPacket(packetCreationBuffer, packet);
-            return new DatagramPacket(queryBuffer, 0, queryBuffer.length, address);
+        public List<DatagramPacket> getDatagramPacketsFromMdnsPacket(
+                @NonNull byte[] packetCreationBuffer, @NonNull MdnsPacket packet,
+                @NonNull InetSocketAddress address, boolean isQueryWithKnownAnswer)
+                throws IOException {
+            if (isQueryWithKnownAnswer) {
+                return MdnsUtils.createQueryDatagramPackets(packetCreationBuffer, packet, address);
+            } else {
+                final byte[] queryBuffer =
+                        MdnsUtils.createRawDnsPacket(packetCreationBuffer, packet);
+                return List.of(new DatagramPacket(queryBuffer, 0, queryBuffer.length, address));
+            }
         }
     }
 
@@ -742,7 +752,8 @@
                                 clock,
                                 sharedLog,
                                 dependencies,
-                                existingServices)
+                                existingServices,
+                                featureFlags.isQueryWithKnownAnswerEnabled())
                                 .call();
             } catch (RuntimeException e) {
                 sharedLog.e(String.format("Failed to run EnqueueMdnsQueryCallable for subtype: %s",
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
index 7b71e43..9cfcba1 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
@@ -25,6 +25,7 @@
 import android.net.wifi.WifiManager.MulticastLock;
 import android.os.SystemClock;
 import android.text.format.DateUtils;
+import android.util.Log;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.net.module.util.SharedLog;
@@ -206,18 +207,18 @@
     }
 
     @Override
-    public void sendPacketRequestingMulticastResponse(@NonNull DatagramPacket packet,
+    public void sendPacketRequestingMulticastResponse(@NonNull List<DatagramPacket> packets,
             boolean onlyUseIpv6OnIpv6OnlyNetworks) {
-        sendMdnsPacket(packet, multicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
+        sendMdnsPackets(packets, multicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
     }
 
     @Override
-    public void sendPacketRequestingUnicastResponse(@NonNull DatagramPacket packet,
+    public void sendPacketRequestingUnicastResponse(@NonNull List<DatagramPacket> packets,
             boolean onlyUseIpv6OnIpv6OnlyNetworks) {
         if (useSeparateSocketForUnicast) {
-            sendMdnsPacket(packet, unicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
+            sendMdnsPackets(packets, unicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
         } else {
-            sendMdnsPacket(packet, multicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
+            sendMdnsPackets(packets, multicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
         }
     }
 
@@ -238,17 +239,21 @@
         return false;
     }
 
-    private void sendMdnsPacket(DatagramPacket packet, Queue<DatagramPacket> packetQueueToUse,
-            boolean onlyUseIpv6OnIpv6OnlyNetworks) {
+    private void sendMdnsPackets(List<DatagramPacket> packets,
+            Queue<DatagramPacket> packetQueueToUse, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
         if (shouldStopSocketLoop && !MdnsConfigs.allowAddMdnsPacketAfterDiscoveryStops()) {
             sharedLog.w("sendMdnsPacket() is called after discovery already stopped");
             return;
         }
+        if (packets.isEmpty()) {
+            Log.wtf(TAG, "No mDns packets to send");
+            return;
+        }
 
-        final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
-                instanceof Inet4Address;
-        final boolean isIpv6 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
-                instanceof Inet6Address;
+        final boolean isIpv4 = ((InetSocketAddress) packets.get(0).getSocketAddress())
+                .getAddress() instanceof Inet4Address;
+        final boolean isIpv6 = ((InetSocketAddress) packets.get(0).getSocketAddress())
+                .getAddress() instanceof Inet6Address;
         final boolean ipv6Only = multicastSocket != null && multicastSocket.isOnIPv6OnlyNetwork();
         if (isIpv4 && ipv6Only) {
             return;
@@ -258,10 +263,14 @@
         }

         synchronized (packetQueueToUse) {
-            while (packetQueueToUse.size() >= MdnsConfigs.mdnsPacketQueueMaxSize()) {
+            // Drop packets from the head of the queue to make room for the new batch. Stop once
+            // the queue is empty, since remove() throws on an empty queue; an oversized batch is
+            // then enqueued in full.
+            while (!packetQueueToUse.isEmpty() && (packetQueueToUse.size() + packets.size()
+                    > MdnsConfigs.mdnsPacketQueueMaxSize())) {
                 packetQueueToUse.remove();
             }
-            packetQueueToUse.add(packet);
+            packetQueueToUse.addAll(packets);
         }
         triggerSendThread();
     }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
index b6000f0..b1a543a 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
@@ -23,6 +23,7 @@
 
 import java.io.IOException;
 import java.net.DatagramPacket;
+import java.util.List;
 
 /**
  * Base class for multicast socket client.
@@ -40,15 +41,15 @@
     void setCallback(@Nullable Callback callback);
 
     /**
-     * Send a mDNS request packet via given network that asks for multicast response.
+     * Send mDNS request packets via given network that asks for multicast response.
      */
-    void sendPacketRequestingMulticastResponse(@NonNull DatagramPacket packet,
+    void sendPacketRequestingMulticastResponse(@NonNull List<DatagramPacket> packets,
             boolean onlyUseIpv6OnIpv6OnlyNetworks);
 
     /**
-     * Send a mDNS request packet via given network that asks for unicast response.
+     * Send mDNS request packets via given network that asks for unicast response.
      */
-    void sendPacketRequestingUnicastResponse(@NonNull DatagramPacket packet,
+    void sendPacketRequestingUnicastResponse(@NonNull List<DatagramPacket> packets,
             boolean onlyUseIpv6OnIpv6OnlyNetworks);
 
     /*** Notify that the given network is requested for mdns discovery / resolution */
diff --git a/service/src/com/android/server/BpfNetMaps.java b/service/src/com/android/server/BpfNetMaps.java
index fc6d8c4..42c1628 100644
--- a/service/src/com/android/server/BpfNetMaps.java
+++ b/service/src/com/android/server/BpfNetMaps.java
@@ -918,6 +918,25 @@
         }
     }
 
+    /**
+     * Return whether the network is blocked by firewall chains for the given uid.
+     *
+     * Note that {@link #getDataSaverEnabled()} has a latency before V.
+     *
+     * @param uid The target uid.
+     * @param isNetworkMetered Whether the target network is metered.
+     *
+     * @return True if the network is blocked. Otherwise, false.
+     * @throws ServiceSpecificException if the read fails.
+     *
+     * @hide
+     */
+    @RequiresApi(Build.VERSION_CODES.TIRAMISU)
+    public boolean isUidNetworkingBlocked(final int uid, boolean isNetworkMetered) {
+        return BpfNetMapsUtils.isUidNetworkingBlocked(uid, isNetworkMetered,
+                sConfigurationMap, sUidOwnerMap, sDataSaverEnabledMap);
+    }
+
     /** Register callback for statsd to pull atom. */
     @RequiresApi(Build.VERSION_CODES.TIRAMISU)
     public void setPullAtomCallback(final Context context) {
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 123ad8f..005d617 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -2235,7 +2235,11 @@
         final long ident = Binder.clearCallingIdentity();
         try {
             final boolean metered = nc == null ? true : nc.isMetered();
-            return mPolicyManager.isUidNetworkingBlocked(uid, metered);
+            if (mDeps.isAtLeastV()) {
+                return mBpfNetMaps.isUidNetworkingBlocked(uid, metered);
+            } else {
+                return mPolicyManager.isUidNetworkingBlocked(uid, metered);
+            }
         } finally {
             Binder.restoreCallingIdentity(ident);
         }
diff --git a/service/src/com/android/server/connectivity/SatelliteAccessController.java b/service/src/com/android/server/connectivity/SatelliteAccessController.java
index b53abce..2cdc932 100644
--- a/service/src/com/android/server/connectivity/SatelliteAccessController.java
+++ b/service/src/com/android/server/connectivity/SatelliteAccessController.java
@@ -20,7 +20,10 @@
 import android.annotation.NonNull;
 import android.app.role.OnRoleHoldersChangedListener;
 import android.app.role.RoleManager;
+import android.content.BroadcastReceiver;
 import android.content.Context;
+import android.content.Intent;
+import android.content.IntentFilter;
 import android.content.pm.ApplicationInfo;
 import android.content.pm.PackageManager;
 import android.os.Handler;
@@ -49,7 +52,6 @@
     private final Context mContext;
     private final Dependencies mDeps;
     private final DefaultMessageRoleListener mDefaultMessageRoleListener;
-    private final UserManager mUserManager;
     private final Consumer<Set<Integer>> mCallback;
     private final Handler mConnectivityServiceHandler;
 
@@ -114,7 +116,6 @@
             @NonNull final Handler connectivityServiceInternalHandler) {
         mContext = c;
         mDeps = deps;
-        mUserManager = mContext.getSystemService(UserManager.class);
         mDefaultMessageRoleListener = new DefaultMessageRoleListener();
         mCallback = callback;
         mConnectivityServiceHandler = connectivityServiceInternalHandler;
@@ -165,9 +166,6 @@
     }
 
     // on Role sms change triggered by OnRoleHoldersChangedListener()
-    // TODO(b/326373613): using UserLifecycleListener, callback to be received when user removed for
-    // user delete scenario. This to be used to update uid list and ML Layer request can also be
-    // updated.
     private void onRoleSmsChanged(@NonNull UserHandle userHandle) {
         int userId = userHandle.getIdentifier();
         if (userId == Process.INVALID_UID) {
@@ -184,9 +182,8 @@
                 mAllUsersSatelliteNetworkFallbackUidCache.get(userId, new ArraySet<>());
 
         Log.i(TAG, "currentUser : role_sms_packages: " + userId + " : " + packageNames);
-        final Set<Integer> newUidsForUser = !packageNames.isEmpty()
-                ? updateSatelliteNetworkFallbackUidListCache(packageNames, userHandle)
-                : new ArraySet<>();
+        final Set<Integer> newUidsForUser =
+                updateSatelliteNetworkFallbackUidListCache(packageNames, userHandle);
         Log.i(TAG, "satellite_fallback_uid: " + newUidsForUser);
 
         // on Role change, update the multilayer request at ConnectivityService with updated
@@ -197,6 +194,11 @@
 
         mAllUsersSatelliteNetworkFallbackUidCache.put(userId, newUidsForUser);
 
+        // Report the merged fallback UIDs of all users so the multilayer request is updated.
+        reportSatelliteNetworkFallbackUids();
+    }
+
+    private void reportSatelliteNetworkFallbackUids() {
         // Merge all uids of multiple users available
         Set<Integer> mergedSatelliteNetworkFallbackUidCache = new ArraySet<>();
         for (int i = 0; i < mAllUsersSatelliteNetworkFallbackUidCache.size(); i++) {
@@ -210,27 +212,48 @@
         mCallback.accept(mergedSatelliteNetworkFallbackUidCache);
     }
 
-    private List<String> getRoleSmsChangedPackageName(UserHandle userHandle) {
-        try {
-            return mDeps.getRoleHoldersAsUser(RoleManager.ROLE_SMS, userHandle);
-        } catch (RuntimeException e) {
-            Log.wtf(TAG, "Could not get package name at role sms change update due to: " + e);
-            return null;
-        }
-    }
-
-    /** Register OnRoleHoldersChangedListener */
     public void start() {
         mConnectivityServiceHandler.post(this::updateAllUserRoleSmsUids);
+
+        // register sms OnRoleHoldersChangedListener
         mDefaultMessageRoleListener.register();
+
+        // Monitor for User removal intent, to update satellite fallback uids.
+        IntentFilter userRemovedFilter = new IntentFilter(Intent.ACTION_USER_REMOVED);
+        mContext.registerReceiver(new BroadcastReceiver() {
+            @Override
+            public void onReceive(Context context, Intent intent) {
+                final String action = intent.getAction();
+                if (Intent.ACTION_USER_REMOVED.equals(action)) {
+                    final UserHandle userHandle = intent.getParcelableExtra(Intent.EXTRA_USER);
+                    if (userHandle == null) return;
+                    updateSatelliteFallbackUidListOnUserRemoval(userHandle.getIdentifier());
+                } else {
+                    Log.wtf(TAG, "received unexpected intent: " + action);
+                }
+            }
+        }, userRemovedFilter, null, mConnectivityServiceHandler);
+
     }
 
     private void updateAllUserRoleSmsUids() {
-        List<UserHandle> existingUsers = mUserManager.getUserHandles(true /* excludeDying */);
+        UserManager userManager = mContext.getSystemService(UserManager.class);
+        // get existing user handles of available users
+        List<UserHandle> existingUsers = userManager.getUserHandles(true /*excludeDying*/);
+
         // Iterate through the user handles and obtain their uids with role sms and satellite
         // communication permission
+        Log.i(TAG, "existing users: " + existingUsers);
         for (UserHandle userHandle : existingUsers) {
             onRoleSmsChanged(userHandle);
         }
     }
+
+    private void updateSatelliteFallbackUidListOnUserRemoval(int userIdRemoved) {
+        Log.i(TAG, "user id removed:" + userIdRemoved);
+        if (mAllUsersSatelliteNetworkFallbackUidCache.contains(userIdRemoved)) {
+            mAllUsersSatelliteNetworkFallbackUidCache.remove(userIdRemoved);
+            reportSatelliteNetworkFallbackUids();
+        }
+    }
 }
diff --git a/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt b/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
new file mode 100644
index 0000000..3ecbdc6
--- /dev/null
+++ b/tests/cts/net/src/android/net/cts/ApfIntegrationTest.kt
@@ -0,0 +1,91 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.net.cts
+
+import android.content.pm.PackageManager.FEATURE_WIFI
+import android.net.ConnectivityManager
+import android.net.NetworkCapabilities
+import android.net.NetworkRequest
+import android.os.Build
+import android.platform.test.annotations.AppModeFull
+import android.system.OsConstants
+import androidx.test.platform.app.InstrumentationRegistry
+import com.android.compatibility.common.util.PropertyUtil.isVendorApiLevelNewerThan
+import com.android.compatibility.common.util.SystemUtil
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.NetworkStackModuleTest
+import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
+import com.android.testutils.TestableNetworkCallback
+import com.google.common.truth.Truth.assertThat
+import kotlin.test.assertEquals
+import kotlin.test.assertNotNull
+import org.junit.After
+import org.junit.Assume.assumeTrue
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+
+private const val TIMEOUT_MS = 2000L
+
+@AppModeFull(reason = "CHANGE_NETWORK_STATE permission can't be granted to instant apps")
+@RunWith(DevSdkIgnoreRunner::class)
+@NetworkStackModuleTest
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
+class ApfIntegrationTest {
+    private val context by lazy { InstrumentationRegistry.getInstrumentation().context }
+    private val cm by lazy { context.getSystemService(ConnectivityManager::class.java)!! }
+    private val pm by lazy { context.packageManager }
+    private lateinit var ifname: String
+    private lateinit var networkCallback: TestableNetworkCallback
+
+    @Before
+    fun setUp() {
+        assumeTrue(pm.hasSystemFeature(FEATURE_WIFI))
+        assumeTrue(isVendorApiLevelNewerThan(Build.VERSION_CODES.TIRAMISU))
+        networkCallback = TestableNetworkCallback()
+        cm.requestNetwork(
+                NetworkRequest.Builder()
+                        .addTransportType(NetworkCapabilities.TRANSPORT_WIFI)
+                        .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
+                        .build(),
+                networkCallback
+        )
+        networkCallback.eventuallyExpect<LinkPropertiesChanged>(TIMEOUT_MS) {
+            ifname = assertNotNull(it.lp.interfaceName)
+            true
+        }
+    }
+
+    @After
+    fun tearDown() {
+        if (::networkCallback.isInitialized) {
+            cm.unregisterNetworkCallback(networkCallback)
+        }
+    }
+
+    @Test
+    fun testGetApfCapabilities() {
+        val caps = SystemUtil.runShellCommand("cmd network_stack apf $ifname capabilities").trim()
+        val (version, maxLen, packetFormat) = caps.split(",").map { it.toInt() }
+        assertEquals(4, version)
+        assertThat(maxLen).isAtLeast(1024)
+        if (isVendorApiLevelNewerThan(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)) {
+            assertThat(maxLen).isAtLeast(2000)
+        }
+        assertEquals(OsConstants.ARPHRD_ETHER, packetFormat)
+    }
+}
diff --git a/tests/integration/src/com/android/server/net/integrationtests/NetworkStatsIntegrationTest.kt b/tests/integration/src/com/android/server/net/integrationtests/NetworkStatsIntegrationTest.kt
index 765e56e..52e502d 100644
--- a/tests/integration/src/com/android/server/net/integrationtests/NetworkStatsIntegrationTest.kt
+++ b/tests/integration/src/com/android/server/net/integrationtests/NetworkStatsIntegrationTest.kt
@@ -299,7 +299,8 @@
         val buf = ByteArray(DEFAULT_BUFFER_SIZE)
 
         httpServer.addResponse(
-            TestHttpServer.Request(path, NanoHTTPD.Method.POST), NanoHTTPD.Response.Status.OK,
+            TestHttpServer.Request(path, NanoHTTPD.Method.POST),
+            NanoHTTPD.Response.Status.OK,
             content = getRandomString(downloadSize)
         )
         var httpConnection: HttpURLConnection? = null
@@ -349,15 +350,19 @@
     ) {
         operator fun plus(other: BareStats): BareStats {
             return BareStats(
-                this.rxBytes + other.rxBytes, this.rxPackets + other.rxPackets,
-                this.txBytes + other.txBytes, this.txPackets + other.txPackets
+                this.rxBytes + other.rxBytes,
+                this.rxPackets + other.rxPackets,
+                this.txBytes + other.txBytes,
+                this.txPackets + other.txPackets
             )
         }
 
         operator fun minus(other: BareStats): BareStats {
             return BareStats(
-                this.rxBytes - other.rxBytes, this.rxPackets - other.rxPackets,
-                this.txBytes - other.txBytes, this.txPackets - other.txPackets
+                this.rxBytes - other.rxBytes,
+                this.rxPackets - other.rxPackets,
+                this.txBytes - other.txBytes,
+                this.txPackets - other.txPackets
             )
         }
 
@@ -405,8 +410,12 @@
         private fun getUidDetail(iface: String, tag: Int): BareStats {
             return getNetworkStatsThat(iface, tag) { nsm, template ->
                 nsm.queryDetailsForUidTagState(
-                    template, Long.MIN_VALUE, Long.MAX_VALUE,
-                    Process.myUid(), tag, Bucket.STATE_ALL
+                    template,
+                    Long.MIN_VALUE,
+                    Long.MAX_VALUE,
+                    Process.myUid(),
+                    tag,
+                    Bucket.STATE_ALL
                 )
             }
         }
@@ -498,28 +507,36 @@
         assertInRange(
             "Unexpected iface traffic stats",
             after.iface,
-            before.trafficStatsIface, after.trafficStatsIface,
-            lower, upper
+            before.trafficStatsIface,
+            after.trafficStatsIface,
+            lower,
+            upper
         )
         // Uid traffic stats are counted in both direction because the external network
         // traffic is also attributed to the test uid.
         assertInRange(
             "Unexpected uid traffic stats",
             after.iface,
-            before.trafficStatsUid, after.trafficStatsUid,
-            lower + lower.reverse(), upper + upper.reverse()
+            before.trafficStatsUid,
+            after.trafficStatsUid,
+            lower + lower.reverse(),
+            upper + upper.reverse()
         )
         assertInRange(
             "Unexpected non-tagged summary stats",
             after.iface,
-            before.statsSummary, after.statsSummary,
-            lower, upper
+            before.statsSummary,
+            after.statsSummary,
+            lower,
+            upper
         )
         assertInRange(
             "Unexpected non-tagged uid stats",
             after.iface,
-            before.statsUid, after.statsUid,
-            lower, upper
+            before.statsUid,
+            after.statsUid,
+            lower,
+            upper
         )
     }
 
@@ -546,14 +563,16 @@
         assertInRange(
             "Unexpected tagged summary stats",
             after.iface,
-            before.taggedSummary, after.taggedSummary,
+            before.taggedSummary,
+            after.taggedSummary,
             lower,
             upper
         )
         assertInRange(
             "Unexpected tagged uid stats: ${Process.myUid()}",
             after.iface,
-            before.taggedUid, after.taggedUid,
+            before.taggedUid,
+            after.taggedUid,
             lower,
             upper
         )
@@ -570,7 +589,8 @@
     ) {
         // Passing the value after operation and the value before operation to dump the actual
         // numbers if it fails.
-        assertTrue(checkInRange(before, after, lower, upper),
+        assertTrue(
+            checkInRange(before, after, lower, upper),
             "$tag on $iface: $after - $before is not within range [$lower, $upper]"
         )
     }
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 1f8a743..17c5901 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -1719,6 +1719,8 @@
     private void mockUidNetworkingBlocked() {
         doAnswer(i -> isUidBlocked(mBlockedReasons, i.getArgument(1))
         ).when(mNetworkPolicyManager).isUidNetworkingBlocked(anyInt(), anyBoolean());
+        doAnswer(i -> isUidBlocked(mBlockedReasons, i.getArgument(1))
+        ).when(mBpfNetMaps).isUidNetworkingBlocked(anyInt(), anyBoolean());
     }
 
     private boolean isUidBlocked(int blockedReasons, boolean meteredNetwork) {
diff --git a/tests/unit/java/com/android/server/connectivity/SatelliteAccessControllerTest.kt b/tests/unit/java/com/android/server/connectivity/SatelliteAccessControllerTest.kt
index 193078b..7885325 100644
--- a/tests/unit/java/com/android/server/connectivity/SatelliteAccessControllerTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/SatelliteAccessControllerTest.kt
@@ -18,17 +18,22 @@
 import android.Manifest
 import android.app.role.OnRoleHoldersChangedListener
 import android.app.role.RoleManager
+import android.content.BroadcastReceiver
 import android.content.Context
+import android.content.Intent
+import android.content.IntentFilter
 import android.content.pm.ApplicationInfo
 import android.content.pm.PackageManager
-import android.content.pm.UserInfo
 import android.os.Build
 import android.os.Handler
+import android.os.Looper
 import android.os.UserHandle
+import android.os.UserManager
 import android.util.ArraySet
-import com.android.server.makeMockUserManager
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
 import com.android.testutils.DevSdkIgnoreRunner
+import java.util.concurrent.Executor
+import java.util.function.Consumer
 import org.junit.Before
 import org.junit.Test
 import org.junit.runner.RunWith
@@ -36,30 +41,32 @@
 import org.mockito.ArgumentMatchers.any
 import org.mockito.ArgumentMatchers.anyInt
 import org.mockito.ArgumentMatchers.eq
+import org.mockito.ArgumentMatchers.isNull
 import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.mock
 import org.mockito.Mockito.never
+import org.mockito.Mockito.timeout
+import org.mockito.Mockito.times
 import org.mockito.Mockito.verify
-import java.util.concurrent.Executor
-import java.util.function.Consumer
 
-private const val USER = 0
-val USER_INFO = UserInfo(USER, "" /* name */, UserInfo.FLAG_PRIMARY)
-val USER_HANDLE = UserHandle(USER)
 private const val PRIMARY_USER = 0
 private const val SECONDARY_USER = 10
 private val PRIMARY_USER_HANDLE = UserHandle.of(PRIMARY_USER)
 private val SECONDARY_USER_HANDLE = UserHandle.of(SECONDARY_USER)
+
 // sms app names
 private const val SMS_APP1 = "sms_app_1"
 private const val SMS_APP2 = "sms_app_2"
+
 // sms app ids
 private const val SMS_APP_ID1 = 100
 private const val SMS_APP_ID2 = 101
+
 // UID for app1 and app2 on primary user
 // These app could become default sms app for user1
 private val PRIMARY_USER_SMS_APP_UID1 = UserHandle.getUid(PRIMARY_USER, SMS_APP_ID1)
 private val PRIMARY_USER_SMS_APP_UID2 = UserHandle.getUid(PRIMARY_USER, SMS_APP_ID2)
+
 // UID for app1 and app2 on secondary user
 // These app could become default sms app for user2
 private val SECONDARY_USER_SMS_APP_UID1 = UserHandle.getUid(SECONDARY_USER, SMS_APP_ID1)
@@ -69,154 +76,259 @@
 @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
 class SatelliteAccessControllerTest {
     private val context = mock(Context::class.java)
-    private val mPackageManager = mock(PackageManager::class.java)
-    private val mHandler = mock(Handler::class.java)
-    private val mRoleManager =
-        mock(SatelliteAccessController.Dependencies::class.java)
+    private val primaryUserContext = mock(Context::class.java)
+    private val secondaryUserContext = mock(Context::class.java)
+    private val mPackageManagerPrimaryUser = mock(PackageManager::class.java)
+    private val mPackageManagerSecondaryUser = mock(PackageManager::class.java)
+    private val mDeps = mock(SatelliteAccessController.Dependencies::class.java)
     private val mCallback = mock(Consumer::class.java) as Consumer<Set<Int>>
-    private val mSatelliteAccessController =
-        SatelliteAccessController(context, mRoleManager, mCallback, mHandler)
+    private val userManager = mock(UserManager::class.java)
+    private val mHandler = Handler(Looper.getMainLooper())
+    private var mSatelliteAccessController =
+        SatelliteAccessController(context, mDeps, mCallback, mHandler)
     private lateinit var mRoleHolderChangedListener: OnRoleHoldersChangedListener
+    private lateinit var mUserRemovedReceiver: BroadcastReceiver
+
+    private fun <T> mockService(name: String, clazz: Class<T>, service: T) {
+        doReturn(name).`when`(context).getSystemServiceName(clazz)
+        doReturn(service).`when`(context).getSystemService(name)
+        if (context.getSystemService(clazz) == null) {
+            // Test is using mockito-extended
+            doReturn(service).`when`(context).getSystemService(clazz)
+        }
+    }
+
     @Before
     @Throws(PackageManager.NameNotFoundException::class)
     fun setup() {
-        makeMockUserManager(USER_INFO, USER_HANDLE)
-        doReturn(context).`when`(context).createContextAsUser(any(), anyInt())
-        doReturn(mPackageManager).`when`(context).packageManager
+        doReturn(emptyList<UserHandle>()).`when`(userManager).getUserHandles(true)
+        mockService(Context.USER_SERVICE, UserManager::class.java, userManager)
 
-        doReturn(PackageManager.PERMISSION_GRANTED)
-            .`when`(mPackageManager)
-            .checkPermission(Manifest.permission.SATELLITE_COMMUNICATION, SMS_APP1)
-        doReturn(PackageManager.PERMISSION_GRANTED)
-            .`when`(mPackageManager)
-            .checkPermission(Manifest.permission.SATELLITE_COMMUNICATION, SMS_APP2)
+        doReturn(primaryUserContext).`when`(context).createContextAsUser(PRIMARY_USER_HANDLE, 0)
+        doReturn(mPackageManagerPrimaryUser).`when`(primaryUserContext).packageManager
 
-        // Initialise default message application primary user package1
+        doReturn(secondaryUserContext).`when`(context).createContextAsUser(SECONDARY_USER_HANDLE, 0)
+        doReturn(mPackageManagerSecondaryUser).`when`(secondaryUserContext).packageManager
+
+        for (app in listOf(SMS_APP1, SMS_APP2)) {
+            doReturn(PackageManager.PERMISSION_GRANTED)
+                .`when`(mPackageManagerPrimaryUser)
+                .checkPermission(Manifest.permission.SATELLITE_COMMUNICATION, app)
+            doReturn(PackageManager.PERMISSION_GRANTED)
+                .`when`(mPackageManagerSecondaryUser)
+                .checkPermission(Manifest.permission.SATELLITE_COMMUNICATION, app)
+        }
+
+        // Initialise message application primary user package1
         val applicationInfo1 = ApplicationInfo()
         applicationInfo1.uid = PRIMARY_USER_SMS_APP_UID1
         doReturn(applicationInfo1)
-            .`when`(mPackageManager)
+            .`when`(mPackageManagerPrimaryUser)
             .getApplicationInfo(eq(SMS_APP1), anyInt())
 
-        // Initialise default message application primary user package2
+        // Initialise message application primary user package2
         val applicationInfo2 = ApplicationInfo()
         applicationInfo2.uid = PRIMARY_USER_SMS_APP_UID2
         doReturn(applicationInfo2)
-            .`when`(mPackageManager)
+            .`when`(mPackageManagerPrimaryUser)
             .getApplicationInfo(eq(SMS_APP2), anyInt())
 
-        // Get registered listener using captor
-        val listenerCaptor = ArgumentCaptor.forClass(
-            OnRoleHoldersChangedListener::class.java
-        )
-        mSatelliteAccessController.start()
-        verify(mRoleManager).addOnRoleHoldersChangedListenerAsUser(
-            any(Executor::class.java), listenerCaptor.capture(), any(UserHandle::class.java))
-        mRoleHolderChangedListener = listenerCaptor.value
+        // Initialise message application secondary user package1
+        val applicationInfo3 = ApplicationInfo()
+        applicationInfo3.uid = SECONDARY_USER_SMS_APP_UID1
+        doReturn(applicationInfo3)
+            .`when`(mPackageManagerSecondaryUser)
+            .getApplicationInfo(eq(SMS_APP1), anyInt())
+
+        // Initialise message application secondary user package2
+        val applicationInfo4 = ApplicationInfo()
+        applicationInfo4.uid = SECONDARY_USER_SMS_APP_UID2
+        doReturn(applicationInfo4)
+            .`when`(mPackageManagerSecondaryUser)
+            .getApplicationInfo(eq(SMS_APP2), anyInt())
     }
 
     @Test
     fun test_onRoleHoldersChanged_SatelliteFallbackUid_Changed_SingleUser() {
-        doReturn(listOf<String>()).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
-            PRIMARY_USER_HANDLE)
+        startSatelliteAccessController()
+        doReturn(listOf<String>()).`when`(mDeps).getRoleHoldersAsUser(
+            RoleManager.ROLE_SMS,
+            PRIMARY_USER_HANDLE
+        )
         mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
         verify(mCallback, never()).accept(any())
 
         // check DEFAULT_MESSAGING_APP1 is available as satellite network fallback uid
         doReturn(listOf(SMS_APP1))
-            .`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+            .`when`(mDeps).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
         mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
         verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID1))
 
         // check SMS_APP2 is available as satellite network Fallback uid
-        doReturn(listOf(SMS_APP2)).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
-            PRIMARY_USER_HANDLE)
+        doReturn(listOf(SMS_APP2)).`when`(mDeps).getRoleHoldersAsUser(
+            RoleManager.ROLE_SMS,
+            PRIMARY_USER_HANDLE
+        )
         mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
         verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID2))
 
         // check no uid is available as satellite network fallback uid
-        doReturn(listOf<String>()).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
-            PRIMARY_USER_HANDLE)
+        doReturn(listOf<String>()).`when`(mDeps).getRoleHoldersAsUser(
+            RoleManager.ROLE_SMS,
+            PRIMARY_USER_HANDLE
+        )
         mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
         verify(mCallback).accept(ArraySet())
     }
 
     @Test
     fun test_onRoleHoldersChanged_NoSatelliteCommunicationPermission() {
-        doReturn(listOf<Any>()).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
-            PRIMARY_USER_HANDLE)
+        startSatelliteAccessController()
+        doReturn(listOf<Any>()).`when`(mDeps).getRoleHoldersAsUser(
+            RoleManager.ROLE_SMS,
+            PRIMARY_USER_HANDLE
+        )
         mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
         verify(mCallback, never()).accept(any())
 
         // check DEFAULT_MESSAGING_APP1 is not available as satellite network fallback uid
         // since satellite communication permission not available.
         doReturn(PackageManager.PERMISSION_DENIED)
-            .`when`(mPackageManager)
+            .`when`(mPackageManagerPrimaryUser)
             .checkPermission(Manifest.permission.SATELLITE_COMMUNICATION, SMS_APP1)
         doReturn(listOf(SMS_APP1))
-            .`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+            .`when`(mDeps).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
         mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
         verify(mCallback, never()).accept(any())
     }
 
     @Test
     fun test_onRoleHoldersChanged_RoleSms_NotAvailable() {
+        startSatelliteAccessController()
         doReturn(listOf(SMS_APP1))
-            .`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
-        mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_BROWSER,
-            PRIMARY_USER_HANDLE)
+            .`when`(mDeps).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+        mRoleHolderChangedListener.onRoleHoldersChanged(
+            RoleManager.ROLE_BROWSER,
+            PRIMARY_USER_HANDLE
+        )
         verify(mCallback, never()).accept(any())
     }
 
     @Test
     fun test_onRoleHoldersChanged_SatelliteNetworkFallbackUid_Changed_multiUser() {
-        doReturn(listOf<String>()).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
-            PRIMARY_USER_HANDLE)
+        startSatelliteAccessController()
+        doReturn(listOf<String>()).`when`(mDeps).getRoleHoldersAsUser(
+            RoleManager.ROLE_SMS,
+            PRIMARY_USER_HANDLE
+        )
         mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
         verify(mCallback, never()).accept(any())
 
         // check SMS_APP1 is available as satellite network fallback uid at primary user
         doReturn(listOf(SMS_APP1))
-            .`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+            .`when`(mDeps).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
         mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
         verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID1))
 
         // check SMS_APP2 is available as satellite network fallback uid at primary user
-        doReturn(listOf(SMS_APP2)).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
-            PRIMARY_USER_HANDLE)
+        doReturn(listOf(SMS_APP2)).`when`(mDeps).getRoleHoldersAsUser(
+            RoleManager.ROLE_SMS,
+            PRIMARY_USER_HANDLE
+        )
         mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
         verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID2))
 
         // check SMS_APP1 is available as satellite network fallback uid at secondary user
-        val applicationInfo1 = ApplicationInfo()
-        applicationInfo1.uid = SECONDARY_USER_SMS_APP_UID1
-        doReturn(applicationInfo1).`when`(mPackageManager)
-            .getApplicationInfo(eq(SMS_APP1), anyInt())
-        doReturn(listOf(SMS_APP1)).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
-            SECONDARY_USER_HANDLE)
+        doReturn(listOf(SMS_APP1)).`when`(mDeps).getRoleHoldersAsUser(
+            RoleManager.ROLE_SMS,
+            SECONDARY_USER_HANDLE
+        )
         mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, SECONDARY_USER_HANDLE)
         verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID2, SECONDARY_USER_SMS_APP_UID1))
 
         // check no uid is available as satellite network fallback uid at primary user
-        doReturn(listOf<String>()).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
-            PRIMARY_USER_HANDLE)
-        mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS,
-            PRIMARY_USER_HANDLE)
+        doReturn(listOf<String>()).`when`(mDeps).getRoleHoldersAsUser(
+            RoleManager.ROLE_SMS,
+            PRIMARY_USER_HANDLE
+        )
+        mRoleHolderChangedListener.onRoleHoldersChanged(
+            RoleManager.ROLE_SMS,
+            PRIMARY_USER_HANDLE
+        )
         verify(mCallback).accept(setOf(SECONDARY_USER_SMS_APP_UID1))
 
         // check SMS_APP2 is available as satellite network fallback uid at secondary user
-        applicationInfo1.uid = SECONDARY_USER_SMS_APP_UID2
-        doReturn(applicationInfo1).`when`(mPackageManager)
-            .getApplicationInfo(eq(SMS_APP2), anyInt())
         doReturn(listOf(SMS_APP2))
-            .`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS, SECONDARY_USER_HANDLE)
+            .`when`(mDeps).getRoleHoldersAsUser(RoleManager.ROLE_SMS, SECONDARY_USER_HANDLE)
         mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, SECONDARY_USER_HANDLE)
         verify(mCallback).accept(setOf(SECONDARY_USER_SMS_APP_UID2))
 
         // check no uid is available as satellite network fallback uid at secondary user
-        doReturn(listOf<String>()).`when`(mRoleManager).getRoleHoldersAsUser(RoleManager.ROLE_SMS,
-            SECONDARY_USER_HANDLE)
+        doReturn(listOf<String>()).`when`(mDeps).getRoleHoldersAsUser(
+            RoleManager.ROLE_SMS,
+            SECONDARY_USER_HANDLE
+        )
         mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, SECONDARY_USER_HANDLE)
         verify(mCallback).accept(ArraySet())
     }
+
+    @Test
+    fun test_SatelliteFallbackUidCallback_OnUserRemoval() {
+        startSatelliteAccessController()
+        // check SMS_APP2 is available as satellite network fallback uid at primary user
+        doReturn(listOf(SMS_APP2)).`when`(mDeps).getRoleHoldersAsUser(
+            RoleManager.ROLE_SMS,
+            PRIMARY_USER_HANDLE
+        )
+        mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+        verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID2))
+
+        // check SMS_APP1 is available as satellite network fallback uid at secondary user
+        doReturn(listOf(SMS_APP1)).`when`(mDeps).getRoleHoldersAsUser(
+            RoleManager.ROLE_SMS,
+            SECONDARY_USER_HANDLE
+        )
+        mRoleHolderChangedListener.onRoleHoldersChanged(RoleManager.ROLE_SMS, SECONDARY_USER_HANDLE)
+        verify(mCallback).accept(setOf(PRIMARY_USER_SMS_APP_UID2, SECONDARY_USER_SMS_APP_UID1))
+
+        val userRemovalIntent = Intent(Intent.ACTION_USER_REMOVED)
+        userRemovalIntent.putExtra(Intent.EXTRA_USER, SECONDARY_USER_HANDLE)
+        mUserRemovedReceiver.onReceive(context, userRemovalIntent)
+        verify(mCallback, times(2)).accept(setOf(PRIMARY_USER_SMS_APP_UID2))
+    }
+
+    @Test
+    fun testOnStartUpCallbackSatelliteFallbackUidWithExistingUsers() {
+        doReturn(
+            listOf(PRIMARY_USER_HANDLE)
+        ).`when`(userManager).getUserHandles(true)
+        doReturn(listOf(SMS_APP1))
+            .`when`(mDeps).getRoleHoldersAsUser(RoleManager.ROLE_SMS, PRIMARY_USER_HANDLE)
+        // At start up, SatelliteAccessController must call CS callback with existing users'
+        // default messaging apps uids.
+        startSatelliteAccessController()
+        verify(mCallback, timeout(500)).accept(setOf(PRIMARY_USER_SMS_APP_UID1))
+    }
+
+    private fun startSatelliteAccessController() {
+        mSatelliteAccessController.start()
+        // Get registered listener using captor
+        val listenerCaptor = ArgumentCaptor.forClass(OnRoleHoldersChangedListener::class.java)
+        verify(mDeps).addOnRoleHoldersChangedListenerAsUser(
+            any(Executor::class.java),
+            listenerCaptor.capture(),
+            any(UserHandle::class.java)
+        )
+        mRoleHolderChangedListener = listenerCaptor.value
+
+        // Get registered receiver using captor
+        val userRemovedReceiverCaptor = ArgumentCaptor.forClass(BroadcastReceiver::class.java)
+        verify(context).registerReceiver(
+            userRemovedReceiverCaptor.capture(),
+            any(IntentFilter::class.java),
+            isNull(),
+            any(Handler::class.java)
+        )
+         mUserRemovedReceiver = userRemovedReceiverCaptor.value
+    }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
index 9474464..fb3d183 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
@@ -23,6 +23,7 @@
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.timeout;
@@ -47,6 +48,7 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
+import org.mockito.InOrder;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
@@ -55,6 +57,7 @@
 import java.net.DatagramPacket;
 import java.net.NetworkInterface;
 import java.net.SocketException;
+import java.util.ArrayList;
 import java.util.List;
 
 @RunWith(DevSdkIgnoreRunner.class)
@@ -154,7 +157,7 @@
         verify(mSocketCreationCallback).onSocketCreated(tetherSocketKey2);
 
         // Send packet to IPv4 with mSocketKey and verify sending has been called.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
+        mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet), mSocketKey,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket).send(ipv4Packet);
@@ -162,7 +165,7 @@
         verify(tetherIfaceSock2, never()).send(any());
 
         // Send packet to IPv4 with onlyUseIpv6OnIpv6OnlyNetworks = true, the packet will be sent.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
+        mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet), mSocketKey,
                 true /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket, times(2)).send(ipv4Packet);
@@ -170,7 +173,7 @@
         verify(tetherIfaceSock2, never()).send(any());
 
         // Send packet to IPv6 with tetherSocketKey1 and verify sending has been called.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv6Packet, tetherSocketKey1,
+        mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv6Packet), tetherSocketKey1,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket, never()).send(ipv6Packet);
@@ -180,7 +183,7 @@
         // Send packet to IPv6 with onlyUseIpv6OnIpv6OnlyNetworks = true, the packet will not be
         // sent. Therefore, the tetherIfaceSock1.send() and tetherIfaceSock2.send() are still be
         // called once.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv6Packet, tetherSocketKey1,
+        mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv6Packet), tetherSocketKey1,
                 true /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket, never()).send(ipv6Packet);
@@ -266,7 +269,7 @@
         verify(mSocketCreationCallback).onSocketCreated(socketKey3);
 
         // Send IPv4 packet on the mSocketKey and verify sending has been called.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
+        mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet), mSocketKey,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket).send(ipv4Packet);
@@ -295,7 +298,7 @@
         verify(socketCreationCb2).onSocketCreated(socketKey3);
 
         // Send IPv4 packet on socket2 and verify sending to the socket2 only.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, socketKey2,
+        mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet), socketKey2,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         // ipv4Packet still sent only once on mSocket: times(1) matches the packet sent earlier on
@@ -309,7 +312,7 @@
         verify(mProvider, timeout(DEFAULT_TIMEOUT)).unrequestSocket(callback2);
 
         // Send IPv4 packet again and verify it's still sent a second time
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, socketKey2,
+        mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet), socketKey2,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(socket2, times(2)).send(ipv4Packet);
@@ -320,7 +323,7 @@
         verify(mProvider, timeout(DEFAULT_TIMEOUT)).unrequestSocket(callback);
 
         // Send IPv4 packet and verify no more sending.
-        mSocketClient.sendPacketRequestingMulticastResponse(ipv4Packet, mSocketKey,
+        mSocketClient.sendPacketRequestingMulticastResponse(List.of(ipv4Packet), mSocketKey,
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
         verify(mSocket, times(1)).send(ipv4Packet);
@@ -407,4 +410,31 @@
         verify(creationCallback3).onSocketDestroyed(mSocketKey);
         verify(creationCallback3, never()).onSocketDestroyed(socketKey2);
     }
+
+    @Test
+    public void testSendPacketWithMultipleDatagramPacket() throws IOException {
+        final SocketCallback callback = expectSocketCallback();
+        final List<DatagramPacket> packets = new ArrayList<>();
+        for (int i = 0; i < 10; i++) {
+            packets.add(new DatagramPacket(new byte[10 + i] /* buff */, 0 /* offset */,
+                    10 + i /* length */, MdnsConstants.IPV4_SOCKET_ADDR));
+        }
+        doReturn(true).when(mSocket).hasJoinedIpv4();
+        doReturn(true).when(mSocket).hasJoinedIpv6();
+        doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
+
+        // Notify socket created
+        callback.onSocketCreated(mSocketKey, mSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
+
+        // Send packets to IPv4 with mSocketKey then verify sending has been called and the
+        // sequence is correct.
+        mSocketClient.sendPacketRequestingMulticastResponse(packets, mSocketKey,
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
+        HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
+        InOrder inOrder = inOrder(mSocket);
+        for (int i = 0; i < 10; i++) {
+            inOrder.verify(mSocket).send(packets.get(i));
+        }
+    }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index 2eb9440..44fa55c 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -162,59 +162,59 @@
             expectedIPv6Packets[i] = new DatagramPacket(buf, 0 /* offset */, 5 /* length */,
                     MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
         }
-        when(mockDeps.getDatagramPacketFromMdnsPacket(
-                any(), any(MdnsPacket.class), eq(IPV4_ADDRESS)))
-                .thenReturn(expectedIPv4Packets[0])
-                .thenReturn(expectedIPv4Packets[1])
-                .thenReturn(expectedIPv4Packets[2])
-                .thenReturn(expectedIPv4Packets[3])
-                .thenReturn(expectedIPv4Packets[4])
-                .thenReturn(expectedIPv4Packets[5])
-                .thenReturn(expectedIPv4Packets[6])
-                .thenReturn(expectedIPv4Packets[7])
-                .thenReturn(expectedIPv4Packets[8])
-                .thenReturn(expectedIPv4Packets[9])
-                .thenReturn(expectedIPv4Packets[10])
-                .thenReturn(expectedIPv4Packets[11])
-                .thenReturn(expectedIPv4Packets[12])
-                .thenReturn(expectedIPv4Packets[13])
-                .thenReturn(expectedIPv4Packets[14])
-                .thenReturn(expectedIPv4Packets[15])
-                .thenReturn(expectedIPv4Packets[16])
-                .thenReturn(expectedIPv4Packets[17])
-                .thenReturn(expectedIPv4Packets[18])
-                .thenReturn(expectedIPv4Packets[19])
-                .thenReturn(expectedIPv4Packets[20])
-                .thenReturn(expectedIPv4Packets[21])
-                .thenReturn(expectedIPv4Packets[22])
-                .thenReturn(expectedIPv4Packets[23]);
+        when(mockDeps.getDatagramPacketsFromMdnsPacket(
+                any(), any(MdnsPacket.class), eq(IPV4_ADDRESS), anyBoolean()))
+                .thenReturn(List.of(expectedIPv4Packets[0]))
+                .thenReturn(List.of(expectedIPv4Packets[1]))
+                .thenReturn(List.of(expectedIPv4Packets[2]))
+                .thenReturn(List.of(expectedIPv4Packets[3]))
+                .thenReturn(List.of(expectedIPv4Packets[4]))
+                .thenReturn(List.of(expectedIPv4Packets[5]))
+                .thenReturn(List.of(expectedIPv4Packets[6]))
+                .thenReturn(List.of(expectedIPv4Packets[7]))
+                .thenReturn(List.of(expectedIPv4Packets[8]))
+                .thenReturn(List.of(expectedIPv4Packets[9]))
+                .thenReturn(List.of(expectedIPv4Packets[10]))
+                .thenReturn(List.of(expectedIPv4Packets[11]))
+                .thenReturn(List.of(expectedIPv4Packets[12]))
+                .thenReturn(List.of(expectedIPv4Packets[13]))
+                .thenReturn(List.of(expectedIPv4Packets[14]))
+                .thenReturn(List.of(expectedIPv4Packets[15]))
+                .thenReturn(List.of(expectedIPv4Packets[16]))
+                .thenReturn(List.of(expectedIPv4Packets[17]))
+                .thenReturn(List.of(expectedIPv4Packets[18]))
+                .thenReturn(List.of(expectedIPv4Packets[19]))
+                .thenReturn(List.of(expectedIPv4Packets[20]))
+                .thenReturn(List.of(expectedIPv4Packets[21]))
+                .thenReturn(List.of(expectedIPv4Packets[22]))
+                .thenReturn(List.of(expectedIPv4Packets[23]));
 
-        when(mockDeps.getDatagramPacketFromMdnsPacket(
-                any(), any(MdnsPacket.class), eq(IPV6_ADDRESS)))
-                .thenReturn(expectedIPv6Packets[0])
-                .thenReturn(expectedIPv6Packets[1])
-                .thenReturn(expectedIPv6Packets[2])
-                .thenReturn(expectedIPv6Packets[3])
-                .thenReturn(expectedIPv6Packets[4])
-                .thenReturn(expectedIPv6Packets[5])
-                .thenReturn(expectedIPv6Packets[6])
-                .thenReturn(expectedIPv6Packets[7])
-                .thenReturn(expectedIPv6Packets[8])
-                .thenReturn(expectedIPv6Packets[9])
-                .thenReturn(expectedIPv6Packets[10])
-                .thenReturn(expectedIPv6Packets[11])
-                .thenReturn(expectedIPv6Packets[12])
-                .thenReturn(expectedIPv6Packets[13])
-                .thenReturn(expectedIPv6Packets[14])
-                .thenReturn(expectedIPv6Packets[15])
-                .thenReturn(expectedIPv6Packets[16])
-                .thenReturn(expectedIPv6Packets[17])
-                .thenReturn(expectedIPv6Packets[18])
-                .thenReturn(expectedIPv6Packets[19])
-                .thenReturn(expectedIPv6Packets[20])
-                .thenReturn(expectedIPv6Packets[21])
-                .thenReturn(expectedIPv6Packets[22])
-                .thenReturn(expectedIPv6Packets[23]);
+        when(mockDeps.getDatagramPacketsFromMdnsPacket(
+                any(), any(MdnsPacket.class), eq(IPV6_ADDRESS), anyBoolean()))
+                .thenReturn(List.of(expectedIPv6Packets[0]))
+                .thenReturn(List.of(expectedIPv6Packets[1]))
+                .thenReturn(List.of(expectedIPv6Packets[2]))
+                .thenReturn(List.of(expectedIPv6Packets[3]))
+                .thenReturn(List.of(expectedIPv6Packets[4]))
+                .thenReturn(List.of(expectedIPv6Packets[5]))
+                .thenReturn(List.of(expectedIPv6Packets[6]))
+                .thenReturn(List.of(expectedIPv6Packets[7]))
+                .thenReturn(List.of(expectedIPv6Packets[8]))
+                .thenReturn(List.of(expectedIPv6Packets[9]))
+                .thenReturn(List.of(expectedIPv6Packets[10]))
+                .thenReturn(List.of(expectedIPv6Packets[11]))
+                .thenReturn(List.of(expectedIPv6Packets[12]))
+                .thenReturn(List.of(expectedIPv6Packets[13]))
+                .thenReturn(List.of(expectedIPv6Packets[14]))
+                .thenReturn(List.of(expectedIPv6Packets[15]))
+                .thenReturn(List.of(expectedIPv6Packets[16]))
+                .thenReturn(List.of(expectedIPv6Packets[17]))
+                .thenReturn(List.of(expectedIPv6Packets[18]))
+                .thenReturn(List.of(expectedIPv6Packets[19]))
+                .thenReturn(List.of(expectedIPv6Packets[20]))
+                .thenReturn(List.of(expectedIPv6Packets[21]))
+                .thenReturn(List.of(expectedIPv6Packets[22]))
+                .thenReturn(List.of(expectedIPv6Packets[23]));
 
         thread = new HandlerThread("MdnsServiceTypeClientTests");
         thread.start();
@@ -694,23 +694,23 @@
                 .addSubtype("subtype1").build();
         final MdnsSearchOptions searchOptions2 = MdnsSearchOptions.newBuilder()
                 .addSubtype("subtype2").build();
-        doCallRealMethod().when(mockDeps).getDatagramPacketFromMdnsPacket(
-                any(), any(MdnsPacket.class), any(InetSocketAddress.class));
+        doCallRealMethod().when(mockDeps).getDatagramPacketsFromMdnsPacket(
+                any(), any(MdnsPacket.class), any(InetSocketAddress.class), anyBoolean());
         startSendAndReceive(mockListenerOne, searchOptions1);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
 
         InOrder inOrder = inOrder(mockListenerOne, mockSocketClient, mockDeps);
 
         // Verify the query asks for subtype1
-        final ArgumentCaptor<DatagramPacket> subtype1QueryCaptor =
-                ArgumentCaptor.forClass(DatagramPacket.class);
+        final ArgumentCaptor<List<DatagramPacket>> subtype1QueryCaptor =
+                ArgumentCaptor.forClass(List.class);
         // Send twice for IPv4 and IPv6
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
                 subtype1QueryCaptor.capture(),
                 eq(socketKey), eq(false));
 
         final MdnsPacket subtype1Query = MdnsPacket.parse(
-                new MdnsPacketReader(subtype1QueryCaptor.getValue()));
+                new MdnsPacketReader(subtype1QueryCaptor.getValue().get(0)));
 
         assertEquals(2, subtype1Query.questions.size());
         assertTrue(hasQuestion(subtype1Query, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
@@ -722,8 +722,8 @@
         inOrder.verify(mockDeps).removeMessages(any(), eq(EVENT_START_QUERYTASK));
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
 
-        final ArgumentCaptor<DatagramPacket> combinedSubtypesQueryCaptor =
-                ArgumentCaptor.forClass(DatagramPacket.class);
+        final ArgumentCaptor<List<DatagramPacket>> combinedSubtypesQueryCaptor =
+                ArgumentCaptor.forClass(List.class);
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
                 combinedSubtypesQueryCaptor.capture(),
                 eq(socketKey), eq(false));
@@ -731,7 +731,7 @@
         inOrder.verify(mockDeps).sendMessageDelayed(any(), any(), anyLong());
 
         final MdnsPacket combinedSubtypesQuery = MdnsPacket.parse(
-                new MdnsPacketReader(combinedSubtypesQueryCaptor.getValue()));
+                new MdnsPacketReader(combinedSubtypesQueryCaptor.getValue().get(0)));
 
         assertEquals(3, combinedSubtypesQuery.questions.size());
         assertTrue(hasQuestion(combinedSubtypesQuery, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
@@ -747,15 +747,15 @@
         dispatchMessage();
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
 
-        final ArgumentCaptor<DatagramPacket> subtype2QueryCaptor =
-                ArgumentCaptor.forClass(DatagramPacket.class);
+        final ArgumentCaptor<List<DatagramPacket>> subtype2QueryCaptor =
+                ArgumentCaptor.forClass(List.class);
         // Send twice for IPv4 and IPv6
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
                 subtype2QueryCaptor.capture(),
                 eq(socketKey), eq(false));
 
         final MdnsPacket subtype2Query = MdnsPacket.parse(
-                new MdnsPacketReader(subtype2QueryCaptor.getValue()));
+                new MdnsPacketReader(subtype2QueryCaptor.getValue().get(0)));
 
         assertEquals(2, subtype2Query.questions.size());
         assertTrue(hasQuestion(subtype2Query, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
@@ -1201,8 +1201,8 @@
         final MdnsSearchOptions resolveOptions2 = MdnsSearchOptions.newBuilder()
                 .setResolveInstanceName(instanceName).build();
 
-        doCallRealMethod().when(mockDeps).getDatagramPacketFromMdnsPacket(
-                any(), any(MdnsPacket.class), any(InetSocketAddress.class));
+        doCallRealMethod().when(mockDeps).getDatagramPacketsFromMdnsPacket(
+                any(), any(MdnsPacket.class), any(InetSocketAddress.class), anyBoolean());
 
         startSendAndReceive(mockListenerOne, resolveOptions1);
         startSendAndReceive(mockListenerTwo, resolveOptions2);
@@ -1210,8 +1210,8 @@
         InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
 
         // Verify a query for SRV/TXT was sent, but no PTR query
-        final ArgumentCaptor<DatagramPacket> srvTxtQueryCaptor =
-                ArgumentCaptor.forClass(DatagramPacket.class);
+        final ArgumentCaptor<List<DatagramPacket>> srvTxtQueryCaptor =
+                ArgumentCaptor.forClass(List.class);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         // Send twice for IPv4 and IPv6
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
@@ -1223,7 +1223,7 @@
         verify(mockListenerTwo).onDiscoveryQuerySent(any(), anyInt());
 
         final MdnsPacket srvTxtQueryPacket = MdnsPacket.parse(
-                new MdnsPacketReader(srvTxtQueryCaptor.getValue()));
+                new MdnsPacketReader(srvTxtQueryCaptor.getValue().get(0)));
 
         final String[] serviceName = getTestServiceName(instanceName);
         assertEquals(1, srvTxtQueryPacket.questions.size());
@@ -1255,8 +1255,8 @@
 
         // Expect a query for A/AAAA
         dispatchMessage();
-        final ArgumentCaptor<DatagramPacket> addressQueryCaptor =
-                ArgumentCaptor.forClass(DatagramPacket.class);
+        final ArgumentCaptor<List<DatagramPacket>> addressQueryCaptor =
+                ArgumentCaptor.forClass(List.class);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
                 addressQueryCaptor.capture(),
@@ -1266,7 +1266,7 @@
         verify(mockListenerTwo, times(2)).onDiscoveryQuerySent(any(), anyInt());
 
         final MdnsPacket addressQueryPacket = MdnsPacket.parse(
-                new MdnsPacketReader(addressQueryCaptor.getValue()));
+                new MdnsPacketReader(addressQueryCaptor.getValue().get(0)));
         assertEquals(2, addressQueryPacket.questions.size());
         assertTrue(hasQuestion(addressQueryPacket, MdnsRecord.TYPE_A, hostname));
         assertTrue(hasQuestion(addressQueryPacket, MdnsRecord.TYPE_AAAA, hostname));
@@ -1316,15 +1316,15 @@
         final MdnsSearchOptions resolveOptions = MdnsSearchOptions.newBuilder()
                 .setResolveInstanceName(instanceName).build();
 
-        doCallRealMethod().when(mockDeps).getDatagramPacketFromMdnsPacket(
-                any(), any(MdnsPacket.class), any(InetSocketAddress.class));
+        doCallRealMethod().when(mockDeps).getDatagramPacketsFromMdnsPacket(
+                any(), any(MdnsPacket.class), any(InetSocketAddress.class), anyBoolean());
 
         startSendAndReceive(mockListenerOne, resolveOptions);
         InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
 
         // Get the query for SRV/TXT
-        final ArgumentCaptor<DatagramPacket> srvTxtQueryCaptor =
-                ArgumentCaptor.forClass(DatagramPacket.class);
+        final ArgumentCaptor<List<DatagramPacket>> srvTxtQueryCaptor =
+                ArgumentCaptor.forClass(List.class);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         // Send twice for IPv4 and IPv6
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
@@ -1334,7 +1334,7 @@
         assertNotNull(delayMessage);
 
         final MdnsPacket srvTxtQueryPacket = MdnsPacket.parse(
-                new MdnsPacketReader(srvTxtQueryCaptor.getValue()));
+                new MdnsPacketReader(srvTxtQueryCaptor.getValue().get(0)));
 
         final String[] serviceName = getTestServiceName(instanceName);
         assertTrue(hasQuestion(srvTxtQueryPacket, MdnsRecord.TYPE_ANY, serviceName));
@@ -1378,8 +1378,8 @@
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
 
         // Expect a renewal query
-        final ArgumentCaptor<DatagramPacket> renewalQueryCaptor =
-                ArgumentCaptor.forClass(DatagramPacket.class);
+        final ArgumentCaptor<List<DatagramPacket>> renewalQueryCaptor =
+                ArgumentCaptor.forClass(List.class);
         // Second and later sends are sent as "expect multicast response" queries
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
                 renewalQueryCaptor.capture(),
@@ -1388,7 +1388,7 @@
         assertNotNull(delayMessage);
         inOrder.verify(mockListenerOne).onDiscoveryQuerySent(any(), anyInt());
         final MdnsPacket renewalPacket = MdnsPacket.parse(
-                new MdnsPacketReader(renewalQueryCaptor.getValue()));
+                new MdnsPacketReader(renewalQueryCaptor.getValue().get(0)));
         assertTrue(hasQuestion(renewalPacket, MdnsRecord.TYPE_ANY, serviceName));
         inOrder.verifyNoMoreInteractions();
 
@@ -1937,14 +1937,14 @@
                 serviceCache,
                 MdnsFeatureFlags.newBuilder().setIsQueryWithKnownAnswerEnabled(true).build());
 
-        doCallRealMethod().when(mockDeps).getDatagramPacketFromMdnsPacket(
-                any(), any(MdnsPacket.class), any(InetSocketAddress.class));
+        doCallRealMethod().when(mockDeps).getDatagramPacketsFromMdnsPacket(
+                any(), any(MdnsPacket.class), any(InetSocketAddress.class), anyBoolean());
 
         startSendAndReceive(mockListenerOne, MdnsSearchOptions.getDefaultOptions());
         InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
 
-        final ArgumentCaptor<DatagramPacket> queryCaptor =
-                ArgumentCaptor.forClass(DatagramPacket.class);
+        final ArgumentCaptor<List<DatagramPacket>> queryCaptor =
+                ArgumentCaptor.forClass(List.class);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         // Send twice for IPv4 and IPv6
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
@@ -1953,7 +1953,7 @@
         assertNotNull(delayMessage);
 
         final MdnsPacket queryPacket = MdnsPacket.parse(
-                new MdnsPacketReader(queryCaptor.getValue()));
+                new MdnsPacketReader(queryCaptor.getValue().get(0)));
         assertTrue(hasQuestion(queryPacket, MdnsRecord.TYPE_PTR));
 
         // Process a response
@@ -1981,14 +1981,14 @@
 
         // Expect a query with known answers
         dispatchMessage();
-        final ArgumentCaptor<DatagramPacket> knownAnswersQueryCaptor =
-                ArgumentCaptor.forClass(DatagramPacket.class);
+        final ArgumentCaptor<List<DatagramPacket>> knownAnswersQueryCaptor =
+                ArgumentCaptor.forClass(List.class);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
                 knownAnswersQueryCaptor.capture(), eq(socketKey), eq(false));
 
         final MdnsPacket knownAnswersQueryPacket = MdnsPacket.parse(
-                new MdnsPacketReader(knownAnswersQueryCaptor.getValue()));
+                new MdnsPacketReader(knownAnswersQueryCaptor.getValue().get(0)));
         assertTrue(hasQuestion(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
         assertTrue(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
         assertFalse(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, subtypeLabels));
@@ -2001,16 +2001,16 @@
                 serviceCache,
                 MdnsFeatureFlags.newBuilder().setIsQueryWithKnownAnswerEnabled(true).build());
 
-        doCallRealMethod().when(mockDeps).getDatagramPacketFromMdnsPacket(
-                any(), any(MdnsPacket.class), any(InetSocketAddress.class));
+        doCallRealMethod().when(mockDeps).getDatagramPacketsFromMdnsPacket(
+                any(), any(MdnsPacket.class), any(InetSocketAddress.class), anyBoolean());
 
         final MdnsSearchOptions options = MdnsSearchOptions.newBuilder()
                 .addSubtype("subtype").build();
         startSendAndReceive(mockListenerOne, options);
         InOrder inOrder = inOrder(mockListenerOne, mockSocketClient);
 
-        final ArgumentCaptor<DatagramPacket> queryCaptor =
-                ArgumentCaptor.forClass(DatagramPacket.class);
+        final ArgumentCaptor<List<DatagramPacket>> queryCaptor =
+                ArgumentCaptor.forClass(List.class);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         // Send twice for IPv4 and IPv6
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingUnicastResponse(
@@ -2019,7 +2019,7 @@
         assertNotNull(delayMessage);
 
         final MdnsPacket queryPacket = MdnsPacket.parse(
-                new MdnsPacketReader(queryCaptor.getValue()));
+                new MdnsPacketReader(queryCaptor.getValue().get(0)));
         final String[] subtypeLabels = Stream.concat(Stream.of("_subtype", "_sub"),
                 Arrays.stream(SERVICE_TYPE_LABELS)).toArray(String[]::new);
         assertTrue(hasQuestion(queryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
@@ -2048,14 +2048,14 @@
 
         // Expect a query with known answers
         dispatchMessage();
-        final ArgumentCaptor<DatagramPacket> knownAnswersQueryCaptor =
-                ArgumentCaptor.forClass(DatagramPacket.class);
+        final ArgumentCaptor<List<DatagramPacket>> knownAnswersQueryCaptor =
+                ArgumentCaptor.forClass(List.class);
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         inOrder.verify(mockSocketClient, times(2)).sendPacketRequestingMulticastResponse(
                 knownAnswersQueryCaptor.capture(), eq(socketKey), eq(false));
 
         final MdnsPacket knownAnswersQueryPacket = MdnsPacket.parse(
-                new MdnsPacketReader(knownAnswersQueryCaptor.getValue()));
+                new MdnsPacketReader(knownAnswersQueryCaptor.getValue().get(0)));
         assertTrue(hasQuestion(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
         assertTrue(hasQuestion(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, subtypeLabels));
         assertTrue(hasAnswer(knownAnswersQueryPacket, MdnsRecord.TYPE_PTR, SERVICE_TYPE_LABELS));
@@ -2083,17 +2083,21 @@
         currentThreadExecutor.getAndClearLastScheduledRunnable().run();
         if (expectsUnicastResponse) {
             verify(mockSocketClient).sendPacketRequestingUnicastResponse(
-                    expectedIPv4Packets[index], socketKey, false);
+                    argThat(pkts -> pkts.get(0).equals(expectedIPv4Packets[index])),
+                    eq(socketKey), eq(false));
             if (multipleSocketDiscovery) {
                 verify(mockSocketClient).sendPacketRequestingUnicastResponse(
-                        expectedIPv6Packets[index], socketKey, false);
+                        argThat(pkts -> pkts.get(0).equals(expectedIPv6Packets[index])),
+                        eq(socketKey), eq(false));
             }
         } else {
             verify(mockSocketClient).sendPacketRequestingMulticastResponse(
-                    expectedIPv4Packets[index], socketKey, false);
+                    argThat(pkts -> pkts.get(0).equals(expectedIPv4Packets[index])),
+                    eq(socketKey), eq(false));
             if (multipleSocketDiscovery) {
                 verify(mockSocketClient).sendPacketRequestingMulticastResponse(
-                        expectedIPv6Packets[index], socketKey, false);
+                        argThat(pkts -> pkts.get(0).equals(expectedIPv6Packets[index])),
+                        eq(socketKey), eq(false));
             }
         }
         verify(mockDeps, times(index + 1))
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
index 7ced1cb..1989ed3 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
@@ -27,6 +27,7 @@
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.times;
@@ -53,6 +54,7 @@
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
 import org.mockito.ArgumentMatchers;
+import org.mockito.InOrder;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 import org.mockito.invocation.InvocationOnMock;
@@ -60,6 +62,8 @@
 import java.io.IOException;
 import java.net.DatagramPacket;
 import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -234,7 +238,7 @@
 
         // Sends a packet.
         DatagramPacket packet = getTestDatagramPacket();
-        mdnsClient.sendPacketRequestingMulticastResponse(packet,
+        mdnsClient.sendPacketRequestingMulticastResponse(List.of(packet),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         // mockMulticastSocket.send() will be called on another thread. If we verify it immediately,
         // it may not be called yet. So timeout is added.
@@ -242,7 +246,7 @@
         verify(mockUnicastSocket, timeout(TIMEOUT).times(0)).send(packet);
 
         // Verify the packet is sent by the unicast socket.
-        mdnsClient.sendPacketRequestingUnicastResponse(packet,
+        mdnsClient.sendPacketRequestingUnicastResponse(List.of(packet),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         verify(mockMulticastSocket, timeout(TIMEOUT).times(1)).send(packet);
         verify(mockUnicastSocket, timeout(TIMEOUT).times(1)).send(packet);
@@ -287,7 +291,7 @@
 
         // Sends a packet.
         DatagramPacket packet = getTestDatagramPacket();
-        mdnsClient.sendPacketRequestingMulticastResponse(packet,
+        mdnsClient.sendPacketRequestingMulticastResponse(List.of(packet),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         // mockMulticastSocket.send() will be called on another thread. If we verify it immediately,
         // it may not be called yet. So timeout is added.
@@ -295,7 +299,7 @@
         verify(mockUnicastSocket, timeout(TIMEOUT).times(0)).send(packet);
 
         // Verify the packet is sent by the multicast socket as well.
-        mdnsClient.sendPacketRequestingUnicastResponse(packet,
+        mdnsClient.sendPacketRequestingUnicastResponse(List.of(packet),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         verify(mockMulticastSocket, timeout(TIMEOUT).times(2)).send(packet);
         verify(mockUnicastSocket, timeout(TIMEOUT).times(0)).send(packet);
@@ -354,7 +358,7 @@
     public void testStopDiscovery_queueIsCleared() throws IOException {
         mdnsClient.startDiscovery();
         mdnsClient.stopDiscovery();
-        mdnsClient.sendPacketRequestingMulticastResponse(getTestDatagramPacket(),
+        mdnsClient.sendPacketRequestingMulticastResponse(List.of(getTestDatagramPacket()),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
 
         synchronized (mdnsClient.multicastPacketQueue) {
@@ -366,7 +370,7 @@
     public void testSendPacket_afterDiscoveryStops() throws IOException {
         mdnsClient.startDiscovery();
         mdnsClient.stopDiscovery();
-        mdnsClient.sendPacketRequestingMulticastResponse(getTestDatagramPacket(),
+        mdnsClient.sendPacketRequestingMulticastResponse(List.of(getTestDatagramPacket()),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
 
         synchronized (mdnsClient.multicastPacketQueue) {
@@ -380,7 +384,7 @@
         //MdnsConfigsFlagsImpl.mdnsPacketQueueMaxSize.override(2L);
         mdnsClient.startDiscovery();
         for (int i = 0; i < 100; i++) {
-            mdnsClient.sendPacketRequestingMulticastResponse(getTestDatagramPacket(),
+            mdnsClient.sendPacketRequestingMulticastResponse(List.of(getTestDatagramPacket()),
                     false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         }
 
@@ -478,9 +482,9 @@
 
         mdnsClient.startDiscovery();
         DatagramPacket packet = getTestDatagramPacket();
-        mdnsClient.sendPacketRequestingUnicastResponse(packet,
+        mdnsClient.sendPacketRequestingUnicastResponse(List.of(packet),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
-        mdnsClient.sendPacketRequestingMulticastResponse(packet,
+        mdnsClient.sendPacketRequestingMulticastResponse(List.of(packet),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
 
         // Wait for the timer to be triggered.
@@ -511,9 +515,9 @@
         assertFalse(mdnsClient.receivedUnicastResponse);
         assertFalse(mdnsClient.cannotReceiveMulticastResponse.get());
 
-        mdnsClient.sendPacketRequestingUnicastResponse(packet,
+        mdnsClient.sendPacketRequestingUnicastResponse(List.of(packet),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
-        mdnsClient.sendPacketRequestingMulticastResponse(packet,
+        mdnsClient.sendPacketRequestingMulticastResponse(List.of(packet),
                 false /* onlyUseIpv6OnIpv6OnlyNetworks */);
         Thread.sleep(MdnsConfigs.checkMulticastResponseIntervalMs() * 2);
 
@@ -570,6 +574,26 @@
                 .onResponseReceived(any(), argThat(key -> key.getInterfaceIndex() == -1));
     }
 
+    @Test
+    public void testSendPacketWithMultipleDatagramPacket() throws IOException {
+        mdnsClient.startDiscovery();
+        final List<DatagramPacket> packets = new ArrayList<>();
+        for (int i = 0; i < 10; i++) {
+            packets.add(new DatagramPacket(new byte[10 + i] /* buff */, 0 /* offset */,
+                    10 + i /* length */, MdnsConstants.IPV4_SOCKET_ADDR));
+        }
+
+        // Sends packets.
+        mdnsClient.sendPacketRequestingMulticastResponse(packets,
+                false /* onlyUseIpv6OnIpv6OnlyNetworks */);
+        InOrder inOrder = inOrder(mockMulticastSocket);
+        for (int i = 0; i < 10; i++) {
+            // mockMulticastSocket.send() will be called on another thread. If we verify it
+            // immediately, it may not be called yet. So timeout is added.
+            inOrder.verify(mockMulticastSocket, timeout(TIMEOUT)).send(packets.get(i));
+        }
+    }
+
     private DatagramPacket getTestDatagramPacket() {
         return new DatagramPacket(buf, 0, 5,
                 new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), 5353 /* port */));
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
index b1a7233..009205e 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/util/MdnsUtilsTest.kt
@@ -52,19 +52,27 @@
         assertEquals("ţést", toDnsLowerCase("ţést"))
         // Unicode characters 0x10000 (𐀀), 0x10001 (𐀁), 0x10041 (𐁁)
         // Note the last 2 bytes of 0x10041 are identical to 'A', but it should remain unchanged.
-        assertEquals("test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ",
-                toDnsLowerCase("Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- "))
+        assertEquals(
+            "test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ",
+                toDnsLowerCase("Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ")
+        )
         // Also test some characters where the first surrogate is not \ud800
-        assertEquals("test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
+        assertEquals(
+            "test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
                 "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<",
-                toDnsLowerCase("Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
-                        "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"))
+                toDnsLowerCase(
+                    "Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
+                        "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"
+                )
+        )
     }
 
     @Test
     fun testToDnsLabelsLowerCase() {
-        assertArrayEquals(arrayOf("test", "tÉst", "ţést"),
-            toDnsLabelsLowerCase(arrayOf("TeSt", "TÉST", "ţést")))
+        assertArrayEquals(
+            arrayOf("test", "tÉst", "ţést"),
+            toDnsLabelsLowerCase(arrayOf("TeSt", "TÉST", "ţést"))
+        )
     }
 
     @Test
@@ -76,13 +84,17 @@
         assertFalse(equalsIgnoreDnsCase("ŢÉST", "ţést"))
         // Unicode characters 0x10000 (𐀀), 0x10001 (𐀁), 0x10041 (𐁁)
         // Note the last 2 bytes of 0x10041 are identical to 'A', but it should remain unchanged.
-        assertTrue(equalsIgnoreDnsCase("test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ",
-                "Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- "))
+        assertTrue(equalsIgnoreDnsCase(
+            "test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- ",
+                "Test: -->\ud800\udc00 \ud800\udc01 \ud800\udc41<-- "
+        ))
         // Also test some characters where the first surrogate is not \ud800
-        assertTrue(equalsIgnoreDnsCase("test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
+        assertTrue(equalsIgnoreDnsCase(
+            "test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
                 "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<",
                 "Test: >\ud83c\udff4\udb40\udc67\udb40\udc62\udb40" +
-                        "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"))
+                        "\udc77\udb40\udc6c\udb40\udc73\udb40\udc7f<"
+        ))
     }
 
     @Test
@@ -101,15 +113,22 @@
 
     @Test
     fun testTypeEqualsOrIsSubtype() {
-        assertTrue(MdnsUtils.typeEqualsOrIsSubtype(arrayOf("_type", "_tcp", "local"),
-            arrayOf("_type", "_TCP", "local")))
-        assertTrue(MdnsUtils.typeEqualsOrIsSubtype(arrayOf("_type", "_tcp", "local"),
-            arrayOf("a", "_SUB", "_type", "_TCP", "local")))
-        assertFalse(MdnsUtils.typeEqualsOrIsSubtype(arrayOf("_sub", "_type", "_tcp", "local"),
-                arrayOf("_type", "_TCP", "local")))
+        assertTrue(MdnsUtils.typeEqualsOrIsSubtype(
+            arrayOf("_type", "_tcp", "local"),
+            arrayOf("_type", "_TCP", "local")
+        ))
+        assertTrue(MdnsUtils.typeEqualsOrIsSubtype(
+            arrayOf("_type", "_tcp", "local"),
+            arrayOf("a", "_SUB", "_type", "_TCP", "local")
+        ))
+        assertFalse(MdnsUtils.typeEqualsOrIsSubtype(
+            arrayOf("_sub", "_type", "_tcp", "local"),
+                arrayOf("_type", "_TCP", "local")
+        ))
         assertFalse(MdnsUtils.typeEqualsOrIsSubtype(
                 arrayOf("a", "_other", "_type", "_tcp", "local"),
-                arrayOf("a", "_SUB", "_type", "_TCP", "local")))
+                arrayOf("a", "_SUB", "_type", "_TCP", "local")
+        ))
     }
 
     @Test
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt b/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt
index e8664c1..bb7fb51 100644
--- a/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt
@@ -30,6 +30,7 @@
 import android.net.VpnTransportInfo
 import android.os.Build
 import androidx.test.filters.SmallTest
+import com.android.server.connectivity.ConnectivityFlags
 import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.DevSdkIgnoreRunner
 import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
@@ -56,7 +57,9 @@
                         TYPE_VPN_SERVICE,
                         "MySession12345",
                         false /* bypassable */,
-                        false /* longLivedTcpConnectionsExpensive */))
+                        false /* longLivedTcpConnectionsExpensive */
+                )
+        )
         .build()
 
 private fun wifiNc() = NetworkCapabilities.Builder()
@@ -286,4 +289,19 @@
         waitForIdle()
         verify(bpfNetMaps).removeIngressDiscardRule(IPV6_ADDRESS)
     }
+
+    @Test @FeatureFlags([Flag(ConnectivityFlags.INGRESS_TO_VPN_ADDRESS_FILTERING, false)])
+    fun testVpnIngressDiscardRule_FeatureDisabled() {
+        val nr = nr(TRANSPORT_VPN)
+        val cb = TestableNetworkCallback()
+        cm.registerNetworkCallback(nr, cb)
+        val nc = vpnNc()
+        val lp = lp(VPN_IFNAME, IPV6_LINK_ADDRESS, LOCAL_IPV6_LINK_ADDRRESS)
+        val agent = Agent(nc = nc, lp = lp)
+        agent.connect()
+        cb.expectAvailableCallbacks(agent.network, validated = false)
+
+        // IngressDiscardRule should not be added since feature is disabled
+        verify(bpfNetMaps, never()).setIngressDiscardRule(any(), any())
+    }
 }
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
index 5c4617b..bd26c63 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
@@ -79,11 +79,15 @@
 import java.util.concurrent.TimeUnit
 import java.util.function.BiConsumer
 import java.util.function.Consumer
+import kotlin.annotation.AnnotationRetention.RUNTIME
+import kotlin.annotation.AnnotationTarget.FUNCTION
 import kotlin.test.assertNotNull
 import kotlin.test.assertNull
 import kotlin.test.fail
 import org.junit.After
 import org.junit.Before
+import org.junit.Rule
+import org.junit.rules.TestName
 import org.mockito.AdditionalAnswers.delegatesTo
 import org.mockito.Mockito.doAnswer
 import org.mockito.Mockito.doReturn
@@ -126,6 +130,9 @@
 // TODO (b/272685721) : make ConnectivityServiceTest smaller and faster by moving the setup
 // parts into this class and moving the individual tests to multiple separate classes.
 open class CSTest {
+    @get:Rule
+    val testNameRule = TestName()
+
     companion object {
         val CSTestExecutor = Executors.newSingleThreadExecutor()
     }
@@ -155,8 +162,7 @@
         it[ConnectivityService.ALLOW_SATALLITE_NETWORK_FALLBACK] = true
         it[ConnectivityFlags.INGRESS_TO_VPN_ADDRESS_FILTERING] = true
     }
-    fun enableFeature(f: String) = enabledFeatures.set(f, true)
-    fun disableFeature(f: String) = enabledFeatures.set(f, false)
+    fun setFeatureEnabled(flag: String, enabled: Boolean) = enabledFeatures.set(flag, enabled)
 
     // When adding new members, consider if it's not better to build the object in CSTestHelpers
     // to keep this file clean of implementation details. Generally, CSTestHelpers should only
@@ -201,8 +207,32 @@
     lateinit var cm: ConnectivityManager
     lateinit var csHandler: Handler
 
+    // Tests can use this annotation to set flag values before constructing ConnectivityService
+    // e.g. @FeatureFlags([Flag(flagName1, true/false), Flag(flagName2, true/false)])
+    @Retention(RUNTIME)
+    @Target(FUNCTION)
+    annotation class FeatureFlags(val flags: Array<Flag>)
+
+    @Retention(RUNTIME)
+    @Target(FUNCTION)
+    annotation class Flag(val name: String, val enabled: Boolean)
+
     @Before
     fun setUp() {
+        // Set feature flags before constructing ConnectivityService
+        val testMethodName = testNameRule.methodName
+        try {
+            val testMethod = this::class.java.getMethod(testMethodName)
+            val featureFlags = testMethod.getAnnotation(FeatureFlags::class.java)
+            if (featureFlags != null) {
+                for (flag in featureFlags.flags) {
+                    setFeatureEnabled(flag.name, flag.enabled)
+                }
+            }
+        } catch (ignored: NoSuchMethodException) {
+            // This is expected for parameterized tests
+        }
+
         alarmHandlerThread = HandlerThread("TestAlarmManager").also { it.start() }
         alarmManager = makeMockAlarmManager(alarmHandlerThread)
         service = makeConnectivityService(context, netd, deps).also { it.systemReadyInternal() }
diff --git a/thread/service/java/com/android/server/thread/NsdPublisher.java b/thread/service/java/com/android/server/thread/NsdPublisher.java
index 72e3980..2c14f1d 100644
--- a/thread/service/java/com/android/server/thread/NsdPublisher.java
+++ b/thread/service/java/com/android/server/thread/NsdPublisher.java
@@ -39,10 +39,8 @@
 
 import java.net.Inet6Address;
 import java.net.InetAddress;
-import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Deque;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -56,14 +54,6 @@
  *
  * <p>All the data members of this class MUST be accessed in the {@code mHandler}'s Thread except
  * {@code mHandler} itself.
- *
- * <p>TODO: b/323300118 - Remove the following mechanism when the race condition in NsdManager is
- * fixed.
- *
- * <p>There's always only one running registration job at any timepoint. All other pending jobs are
- * queued in {@code mRegistrationJobs}. When a registration job is complete (i.e. the according
- * method in {@link NsdManager.RegistrationListener} is called), it will start the next registration
- * job in the queue.
  */
 public final class NsdPublisher extends INsdPublisher.Stub {
     // TODO: b/321883491 - specify network for mDNS operations
@@ -74,7 +64,6 @@
     private final SparseArray<RegistrationListener> mRegistrationListeners = new SparseArray<>(0);
     private final SparseArray<DiscoveryListener> mDiscoveryListeners = new SparseArray<>(0);
     private final SparseArray<ServiceInfoListener> mServiceInfoListeners = new SparseArray<>(0);
-    private final Deque<Runnable> mRegistrationJobs = new ArrayDeque<>();
 
     @VisibleForTesting
     public NsdPublisher(NsdManager nsdManager, Handler handler) {
@@ -97,13 +86,9 @@
             List<DnsTxtAttribute> txt,
             INsdStatusReceiver receiver,
             int listenerId) {
-        postRegistrationJob(
-                () -> {
-                    NsdServiceInfo serviceInfo =
-                            buildServiceInfoForService(
-                                    hostname, name, type, subTypeList, port, txt);
-                    registerInternal(serviceInfo, receiver, listenerId, "service");
-                });
+        NsdServiceInfo serviceInfo =
+                buildServiceInfoForService(hostname, name, type, subTypeList, port, txt);
+        mHandler.post(() -> registerInternal(serviceInfo, receiver, listenerId, "service"));
     }
 
     private static NsdServiceInfo buildServiceInfoForService(
@@ -132,11 +117,8 @@
     @Override
     public void registerHost(
             String name, List<String> addresses, INsdStatusReceiver receiver, int listenerId) {
-        postRegistrationJob(
-                () -> {
-                    NsdServiceInfo serviceInfo = buildServiceInfoForHost(name, addresses);
-                    registerInternal(serviceInfo, receiver, listenerId, "host");
-                });
+        NsdServiceInfo serviceInfo = buildServiceInfoForHost(name, addresses);
+        mHandler.post(() -> registerInternal(serviceInfo, receiver, listenerId, "host"));
     }
 
     private static NsdServiceInfo buildServiceInfoForHost(
@@ -178,7 +160,7 @@
     }
 
     public void unregister(INsdStatusReceiver receiver, int listenerId) {
-        postRegistrationJob(() -> unregisterInternal(receiver, listenerId));
+        mHandler.post(() -> unregisterInternal(receiver, listenerId));
     }
 
     public void unregisterInternal(INsdStatusReceiver receiver, int listenerId) {
@@ -338,7 +320,6 @@
             }
         }
         mRegistrationListeners.clear();
-        mRegistrationJobs.clear();
     }
 
     /** On ot-daemon died, reset. */
@@ -346,39 +327,6 @@
         reset();
     }
 
-    // TODO: b/323300118 - Remove this mechanism when the race condition in NsdManager is fixed.
-    /** Fetch the first job from the queue and run it. See the class doc for more details. */
-    private void peekAndRun() {
-        if (mRegistrationJobs.isEmpty()) {
-            return;
-        }
-        Runnable job = mRegistrationJobs.getFirst();
-        job.run();
-    }
-
-    // TODO: b/323300118 - Remove this mechanism when the race condition in NsdManager is fixed.
-    /**
-     * Pop the first job from the queue and run the next job. See the class doc for more details.
-     */
-    private void popAndRunNext() {
-        if (mRegistrationJobs.isEmpty()) {
-            Log.i(TAG, "No registration jobs when trying to pop and run next.");
-            return;
-        }
-        mRegistrationJobs.removeFirst();
-        peekAndRun();
-    }
-
-    private void postRegistrationJob(Runnable registrationJob) {
-        mHandler.post(
-                () -> {
-                    mRegistrationJobs.addLast(registrationJob);
-                    if (mRegistrationJobs.size() == 1) {
-                        peekAndRun();
-                    }
-                });
-    }
-
     private final class RegistrationListener implements NsdManager.RegistrationListener {
         private final NsdServiceInfo mServiceInfo;
         private final int mListenerId;
@@ -416,7 +364,6 @@
             } catch (RemoteException ignored) {
                 // do nothing if the client is dead
             }
-            popAndRunNext();
         }
 
         @Override
@@ -438,7 +385,6 @@
                     // do nothing if the client is dead
                 }
             }
-            popAndRunNext();
         }
 
         @Override
@@ -456,7 +402,6 @@
             } catch (RemoteException ignored) {
                 // do nothing if the client is dead
             }
-            popAndRunNext();
         }
 
         @Override
@@ -477,7 +422,6 @@
                 }
             }
             mRegistrationListeners.remove(mListenerId);
-            popAndRunNext();
         }
     }
 
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
index 0b13d1b..d80dcfb 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
@@ -182,6 +182,7 @@
     private final NsdPublisher mNsdPublisher;
     private final OtDaemonCallbackProxy mOtDaemonCallbackProxy = new OtDaemonCallbackProxy();
     private final ConnectivityResources mResources;
+    private final Supplier<String> mCountryCodeSupplier;
 
     // This should not be directly used for calling IOtDaemon APIs because ot-daemon may die and
     // {@code mOtDaemon} will be set to {@code null}. Instead, use {@code getOtDaemon()}
@@ -215,7 +216,8 @@
             ThreadPersistentSettings persistentSettings,
             NsdPublisher nsdPublisher,
             UserManager userManager,
-            ConnectivityResources resources) {
+            ConnectivityResources resources,
+            Supplier<String> countryCodeSupplier) {
         mContext = context;
         mHandler = handler;
         mNetworkProvider = networkProvider;
@@ -230,10 +232,13 @@
         mNsdPublisher = nsdPublisher;
         mUserManager = userManager;
         mResources = resources;
+        mCountryCodeSupplier = countryCodeSupplier;
     }
 
     public static ThreadNetworkControllerService newInstance(
-            Context context, ThreadPersistentSettings persistentSettings) {
+            Context context,
+            ThreadPersistentSettings persistentSettings,
+            Supplier<String> countryCodeSupplier) {
         HandlerThread handlerThread = new HandlerThread("ThreadHandlerThread");
         handlerThread.start();
         Handler handler = new Handler(handlerThread.getLooper());
@@ -251,7 +256,8 @@
                 persistentSettings,
                 NsdPublisher.newInstance(context, handler),
                 context.getSystemService(UserManager.class),
-                new ConnectivityResources(context));
+                new ConnectivityResources(context),
+                countryCodeSupplier);
     }
 
     private static Inet6Address bytesToInet6Address(byte[] addressBytes) {
@@ -347,7 +353,8 @@
                 isEnabled(),
                 mNsdPublisher,
                 getMeshcopTxtAttributes(mResources.get()),
-                mOtDaemonCallbackProxy);
+                mOtDaemonCallbackProxy,
+                mCountryCodeSupplier.get());
         otDaemon.asBinder().linkToDeath(() -> mHandler.post(this::onOtDaemonDied), 0);
         mOtDaemon = otDaemon;
         return mOtDaemon;
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java b/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
index ffa7b44..a194114 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkCountryCode.java
@@ -16,6 +16,8 @@
 
 package com.android.server.thread;
 
+import static com.android.server.thread.ThreadPersistentSettings.THREAD_COUNTRY_CODE;
+
 import android.annotation.Nullable;
 import android.annotation.StringDef;
 import android.annotation.TargetApi;
@@ -83,6 +85,7 @@
                 COUNTRY_CODE_SOURCE_TELEPHONY,
                 COUNTRY_CODE_SOURCE_TELEPHONY_LAST,
                 COUNTRY_CODE_SOURCE_WIFI,
+                COUNTRY_CODE_SOURCE_SETTINGS,
             })
     private @interface CountryCodeSource {}
 
@@ -93,6 +96,7 @@
     private static final String COUNTRY_CODE_SOURCE_TELEPHONY = "Telephony";
     private static final String COUNTRY_CODE_SOURCE_TELEPHONY_LAST = "TelephonyLast";
     private static final String COUNTRY_CODE_SOURCE_WIFI = "Wifi";
+    private static final String COUNTRY_CODE_SOURCE_SETTINGS = "Settings";
 
     private static final CountryCodeInfo DEFAULT_COUNTRY_CODE_INFO =
             new CountryCodeInfo(DEFAULT_COUNTRY_CODE, COUNTRY_CODE_SOURCE_DEFAULT);
@@ -107,6 +111,7 @@
     private final SubscriptionManager mSubscriptionManager;
     private final Map<Integer, TelephonyCountryCodeSlotInfo> mTelephonyCountryCodeSlotInfoMap =
             new ArrayMap();
+    private final ThreadPersistentSettings mPersistentSettings;
 
     @Nullable private CountryCodeInfo mCurrentCountryCodeInfo;
     @Nullable private CountryCodeInfo mLocationCountryCodeInfo;
@@ -215,7 +220,8 @@
             Context context,
             TelephonyManager telephonyManager,
             SubscriptionManager subscriptionManager,
-            @Nullable String oemCountryCode) {
+            @Nullable String oemCountryCode,
+            ThreadPersistentSettings persistentSettings) {
         mLocationManager = locationManager;
         mThreadNetworkControllerService = threadNetworkControllerService;
         mGeocoder = geocoder;
@@ -224,14 +230,19 @@
         mContext = context;
         mTelephonyManager = telephonyManager;
         mSubscriptionManager = subscriptionManager;
+        mPersistentSettings = persistentSettings;
 
         if (oemCountryCode != null) {
             mOemCountryCodeInfo = new CountryCodeInfo(oemCountryCode, COUNTRY_CODE_SOURCE_OEM);
         }
+
+        mCurrentCountryCodeInfo = pickCountryCode();
     }
 
     public static ThreadNetworkCountryCode newInstance(
-            Context context, ThreadNetworkControllerService controllerService) {
+            Context context,
+            ThreadNetworkControllerService controllerService,
+            ThreadPersistentSettings persistentSettings) {
         return new ThreadNetworkCountryCode(
                 context.getSystemService(LocationManager.class),
                 controllerService,
@@ -241,7 +252,8 @@
                 context,
                 context.getSystemService(TelephonyManager.class),
                 context.getSystemService(SubscriptionManager.class),
-                ThreadNetworkProperties.country_code().orElse(null));
+                ThreadNetworkProperties.country_code().orElse(null),
+                persistentSettings);
     }
 
     /** Sets up this country code module to listen to location country code changes. */
@@ -485,6 +497,11 @@
             return mLocationCountryCodeInfo;
         }
 
+        String settingsCountryCode = mPersistentSettings.get(THREAD_COUNTRY_CODE);
+        if (settingsCountryCode != null) {
+            return new CountryCodeInfo(settingsCountryCode, COUNTRY_CODE_SOURCE_SETTINGS);
+        }
+
         if (mOemCountryCodeInfo != null) {
             return mOemCountryCodeInfo;
         }
@@ -498,6 +515,8 @@
             public void onSuccess() {
                 synchronized ("ThreadNetworkCountryCode.this") {
                     mCurrentCountryCodeInfo = countryCodeInfo;
+                    mPersistentSettings.put(
+                            THREAD_COUNTRY_CODE.key, countryCodeInfo.getCountryCode());
                 }
             }
 
@@ -536,10 +555,9 @@
                 newOperationReceiver(countryCodeInfo));
     }
 
-    /** Returns the current country code or {@code null} if no country code is set. */
-    @Nullable
+    /** Returns the current country code. */
     public synchronized String getCountryCode() {
-        return (mCurrentCountryCodeInfo != null) ? mCurrentCountryCodeInfo.getCountryCode() : null;
+        return mCurrentCountryCodeInfo.getCountryCode();
     }
 
     /**
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkService.java b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
index 37c1cf1..30c67ca 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkService.java
@@ -60,13 +60,16 @@
         if (phase == SystemService.PHASE_SYSTEM_SERVICES_READY) {
             mPersistentSettings.initialize();
             mControllerService =
-                    ThreadNetworkControllerService.newInstance(mContext, mPersistentSettings);
+                    ThreadNetworkControllerService.newInstance(
+                            mContext, mPersistentSettings, () -> mCountryCode.getCountryCode());
+            mCountryCode =
+                    ThreadNetworkCountryCode.newInstance(
+                            mContext, mControllerService, mPersistentSettings);
             mControllerService.initialize();
         } else if (phase == SystemService.PHASE_BOOT_COMPLETED) {
             // Country code initialization is delayed to the BOOT_COMPLETED phase because it will
             // call into Wi-Fi and Telephony service whose country code module is ready after
             // PHASE_ACTIVITY_MANAGER_READY and PHASE_THIRD_PARTY_APPS_CAN_START
-            mCountryCode = ThreadNetworkCountryCode.newInstance(mContext, mControllerService);
             mCountryCode.initialize();
             mShellCommand =
                     new ThreadNetworkShellCommand(requireNonNull(mControllerService), mCountryCode);
diff --git a/thread/service/java/com/android/server/thread/ThreadPersistentSettings.java b/thread/service/java/com/android/server/thread/ThreadPersistentSettings.java
index 923f002..8aaff60 100644
--- a/thread/service/java/com/android/server/thread/ThreadPersistentSettings.java
+++ b/thread/service/java/com/android/server/thread/ThreadPersistentSettings.java
@@ -63,6 +63,9 @@
     /** Stores the Thread feature toggle state, true for enabled and false for disabled. */
     public static final Key<Boolean> THREAD_ENABLED = new Key<>("thread_enabled", true);
 
+    /** Stores the Thread country code, null if no country code is stored. */
+    public static final Key<String> THREAD_COUNTRY_CODE = new Key<>("thread_country_code", null);
+
     /******** Thread persistent setting keys ***************/
 
     @GuardedBy("mLock")
@@ -123,7 +126,9 @@
     private <T> T getObject(String key, T defaultValue) {
         Object value;
         synchronized (mLock) {
-            if (defaultValue instanceof Boolean) {
+            if (defaultValue == null) {
+                value = mSettings.getString(key, null);
+            } else if (defaultValue instanceof Boolean) {
                 value = mSettings.getBoolean(key, (Boolean) defaultValue);
             } else if (defaultValue instanceof Integer) {
                 value = mSettings.getInt(key, (Integer) defaultValue);
diff --git a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
index 0591c87..9a81388 100644
--- a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
+++ b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
@@ -864,11 +864,12 @@
     @Test
     public void meshcopService_threadDisabled_notDiscovered() throws Exception {
         setUpTestNetwork();
-
         CompletableFuture<NsdServiceInfo> serviceLostFuture = new CompletableFuture<>();
         NsdManager.DiscoveryListener listener =
                 discoverForServiceLost(MESHCOP_SERVICE_TYPE, serviceLostFuture);
+
         setEnabledAndWait(mController, false);
+
         try {
             serviceLostFuture.get(SERVICE_LOST_TIMEOUT_MILLIS, MILLISECONDS);
         } catch (InterruptedException | ExecutionException | TimeoutException ignored) {
@@ -877,7 +878,6 @@
         } finally {
             mNsdManager.stopServiceDiscovery(listener);
         }
-
         assertThrows(
                 TimeoutException.class,
                 () -> discoverService(MESHCOP_SERVICE_TYPE, SERVICE_LOST_TIMEOUT_MILLIS));
@@ -1112,7 +1112,12 @@
                         serviceInfoFuture.complete(serviceInfo);
                     }
                 };
-        mNsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, listener);
+        mNsdManager.discoverServices(
+                serviceType,
+                NsdManager.PROTOCOL_DNS_SD,
+                mTestNetworkTracker.getNetwork(),
+                mExecutor,
+                listener);
         try {
             serviceInfoFuture.get(timeoutMilliseconds, MILLISECONDS);
         } finally {
@@ -1131,7 +1136,12 @@
                         serviceInfoFuture.complete(serviceInfo);
                     }
                 };
-        mNsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, listener);
+        mNsdManager.discoverServices(
+                serviceType,
+                NsdManager.PROTOCOL_DNS_SD,
+                mTestNetworkTracker.getNetwork(),
+                mExecutor,
+                listener);
         return listener;
     }
 
diff --git a/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java b/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java
index bfded1d..c70f3af 100644
--- a/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java
+++ b/thread/tests/integration/src/android/net/thread/ThreadIntegrationTest.java
@@ -23,6 +23,7 @@
 import static android.net.thread.utils.IntegrationTestUtils.RESTART_JOIN_TIMEOUT;
 
 import static com.android.compatibility.common.util.SystemUtil.runShellCommand;
+import static com.android.compatibility.common.util.SystemUtil.runShellCommandOrThrow;
 
 import static com.google.common.io.BaseEncoding.base16;
 import static com.google.common.truth.Truth.assertThat;
@@ -140,6 +141,22 @@
         }
     }
 
+    @Test
+    public void otDaemonRestart_latestCountryCodeIsSetToOtDaemon() throws Exception {
+        runThreadCommand("force-country-code enabled CN");
+
+        runShellCommand("stop ot-daemon");
+        // TODO(b/323331973): the sleep is needed to workaround the race conditions
+        SystemClock.sleep(200);
+        mController.waitForRole(DEVICE_ROLE_STOPPED, CALLBACK_TIMEOUT);
+
+        assertThat(mOtCtl.getCountryCode()).isEqualTo("CN");
+    }
+
+    private static String runThreadCommand(String cmd) {
+        return runShellCommandOrThrow("cmd thread_network " + cmd);
+    }
+
     // TODO (b/323300829): add more tests for integration with linux platform and
     // ConnectivityService
 }
diff --git a/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java b/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
index ade0669..f39a064 100644
--- a/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
+++ b/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
@@ -74,6 +74,12 @@
         return (Inet6Address) InetAddresses.parseNumericAddress(addressStr);
     }
 
+    /** Returns the country code on ot-daemon. */
+    public String getCountryCode() {
+        String countryCodeStr = executeCommand("region").split("\n")[0].trim();
+        return countryCodeStr;
+    }
+
     public String executeCommand(String cmd) {
         return SystemUtil.runShellCommand(OT_CTL + " " + cmd);
     }
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
index 0c7d086..85b6873 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
@@ -25,6 +25,7 @@
 import static android.net.thread.ThreadNetworkManager.DISALLOW_THREAD_NETWORK;
 import static android.net.thread.ThreadNetworkManager.PERMISSION_THREAD_NETWORK_PRIVILEGED;
 
+import static com.android.server.thread.ThreadNetworkCountryCode.DEFAULT_COUNTRY_CODE;
 import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_INVALID_STATE;
 
 import static com.google.common.io.BaseEncoding.base16;
@@ -182,7 +183,8 @@
                         mMockPersistentSettings,
                         mMockNsdPublisher,
                         mMockUserManager,
-                        mConnectivityResources);
+                        mConnectivityResources,
+                        () -> DEFAULT_COUNTRY_CODE);
         mService.setTestNetworkAgent(mMockNetworkAgent);
     }
 
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
index 5ca6511..ca9741d 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkCountryCodeTest.java
@@ -19,6 +19,7 @@
 import static android.net.thread.ThreadNetworkException.ERROR_INTERNAL_ERROR;
 
 import static com.android.server.thread.ThreadNetworkCountryCode.DEFAULT_COUNTRY_CODE;
+import static com.android.server.thread.ThreadPersistentSettings.THREAD_COUNTRY_CODE;
 
 import static com.google.common.truth.Truth.assertThat;
 
@@ -104,6 +105,7 @@
     @Mock List<SubscriptionInfo> mSubscriptionInfoList;
     @Mock SubscriptionInfo mSubscriptionInfo0;
     @Mock SubscriptionInfo mSubscriptionInfo1;
+    @Mock ThreadPersistentSettings mPersistentSettings;
 
     private ThreadNetworkCountryCode mThreadNetworkCountryCode;
     private boolean mErrorSetCountryCode;
@@ -164,7 +166,8 @@
                 mContext,
                 mTelephonyManager,
                 mSubscriptionManager,
-                oemCountryCode);
+                oemCountryCode,
+                mPersistentSettings);
     }
 
     private static Address newAddress(String countryCode) {
@@ -450,6 +453,14 @@
     }
 
     @Test
+    public void settingsCountryCode_settingsCountryCodeIsActive_settingsCountryCodeIsUsed() {
+        when(mPersistentSettings.get(THREAD_COUNTRY_CODE)).thenReturn(TEST_COUNTRY_CODE_CN);
+        mThreadNetworkCountryCode.initialize();
+
+        assertThat(mThreadNetworkCountryCode.getCountryCode()).isEqualTo(TEST_COUNTRY_CODE_CN);
+    }
+
+    @Test
     public void dump_allCountryCodeInfoAreDumped() {
         StringWriter stringWriter = new StringWriter();
         PrintWriter printWriter = new PrintWriter(stringWriter);
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadPersistentSettingsTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadPersistentSettingsTest.java
index 9406a2f..7d2fe91 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadPersistentSettingsTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadPersistentSettingsTest.java
@@ -16,6 +16,7 @@
 
 package com.android.server.thread;
 
+import static com.android.server.thread.ThreadPersistentSettings.THREAD_COUNTRY_CODE;
 import static com.android.server.thread.ThreadPersistentSettings.THREAD_ENABLED;
 
 import static com.google.common.truth.Truth.assertThat;
@@ -54,6 +55,8 @@
 @RunWith(AndroidJUnit4.class)
 @SmallTest
 public class ThreadPersistentSettingsTest {
+    private static final String TEST_COUNTRY_CODE = "CN";
+
     @Mock private AtomicFile mAtomicFile;
     @Mock Resources mResources;
     @Mock ConnectivityResources mConnectivityResources;
@@ -131,6 +134,28 @@
         verify(mAtomicFile).finishWrite(any());
     }
 
+    @Test
+    public void put_ThreadCountryCodeString_returnsString() throws Exception {
+        mThreadPersistentSettings.put(THREAD_COUNTRY_CODE.key, TEST_COUNTRY_CODE);
+        // A stored country code string is read back unchanged via get().
+        assertThat(mThreadPersistentSettings.get(THREAD_COUNTRY_CODE)).isEqualTo(TEST_COUNTRY_CODE);
+
+        // Confirm that file writes have been triggered.
+        verify(mAtomicFile).startWrite();
+        verify(mAtomicFile).finishWrite(any());
+    }
+
+    @Test
+    public void put_ThreadCountryCodeNull_returnsNull() throws Exception {
+        mThreadPersistentSettings.put(THREAD_COUNTRY_CODE.key, null);
+        // Storing null is accepted and is read back as null via get().
+        assertThat(mThreadPersistentSettings.get(THREAD_COUNTRY_CODE)).isNull();
+
+        // Confirm that file writes have been triggered.
+        verify(mAtomicFile).startWrite();
+        verify(mAtomicFile).finishWrite(any());
+    }
+
     private byte[] createXmlForParsing(String key, Boolean value) throws Exception {
         PersistableBundle bundle = new PersistableBundle();
         ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
diff --git a/thread/tests/utils/src/android/net/thread/utils/TapTestNetworkTracker.java b/thread/tests/utils/src/android/net/thread/utils/TapTestNetworkTracker.java
index 43f177d..b586a19 100644
--- a/thread/tests/utils/src/android/net/thread/utils/TapTestNetworkTracker.java
+++ b/thread/tests/utils/src/android/net/thread/utils/TapTestNetworkTracker.java
@@ -62,6 +62,7 @@
     private final Looper mLooper;
     private TestNetworkInterface mInterface;
     private TestableNetworkAgent mAgent;
+    private Network mNetwork;
     private final TestableNetworkCallback mNetworkCallback;
     private final ConnectivityManager mConnectivityManager;
 
@@ -91,6 +92,11 @@
         return mInterface.getInterfaceName();
     }
 
+    /** Returns the test network's {@link Network}; null until setUpTestNetwork() has run. */
+    public Network getNetwork() {
+        return mNetwork;
+    }
+
     private void setUpTestNetwork() throws Exception {
         mInterface = mContext.getSystemService(TestNetworkManager.class).createTapInterface();
 
@@ -105,13 +111,13 @@
                         newNetworkCapabilities(),
                         lp,
                         new NetworkAgentConfig.Builder().build());
-        final Network network = mAgent.register();
+        mNetwork = mAgent.register();
         mAgent.markConnected();
 
         PollingCheck.check(
                 "No usable address on interface",
                 TIMEOUT.toMillis(),
-                () -> hasUsableAddress(network, getInterfaceName()));
+                () -> hasUsableAddress(mNetwork, getInterfaceName()));
 
         lp.setLinkAddresses(makeLinkAddresses());
         mAgent.sendLinkProperties(lp);